
Data

Datasets

FastTdDataset

FastTdDataset(td: TensorDict)

Bases: Dataset

Note

Check out the issue on tensordict for more details: https://github.com/pytorch-labs/tensordict/issues/374.

Source code in rl4co/data/dataset.py
def __init__(self, td: TensorDict):
    self.data_len = td.batch_size[0]
    self.data = td

collate_fn staticmethod

collate_fn(batch: Union[dict, TensorDict])

Collate function compatible with TensorDicts that reassembles a list of dicts.

Source code in rl4co/data/dataset.py
@staticmethod
def collate_fn(batch: Union[dict, TensorDict]):
    """Collate function compatible with TensorDicts that reassembles a list of dicts."""
    return batch
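
A minimal construction sketch (not from the library docs; the contents of td below are made up for illustration): wrap an existing TensorDict and let the dataset keep it in memory.

import torch
from tensordict import TensorDict
from rl4co.data.dataset import FastTdDataset

td = TensorDict({"locs": torch.rand(100, 20, 2)}, batch_size=[100])
dataset = FastTdDataset(td)
print(dataset.data_len)  # 100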

TensorDictDataset

TensorDictDataset(td: TensorDict)

Bases: Dataset

Dataset compatible with TensorDicts with low CPU usage. Fast loading but somewhat slow instantiation due to list comprehension since we "disassemble" the TensorDict into a list of dicts.

Note

Check out the issue on tensordict for more details: https://github.com/pytorch-labs/tensordict/issues/374.

Source code in rl4co/data/dataset.py
def __init__(self, td: TensorDict):
    self.data_len = td.batch_size[0]
    self.data = [
        {key: value[i] for key, value in td.items()} for i in range(self.data_len)
    ]

collate_fn staticmethod

collate_fn(batch: Union[dict, TensorDict])

Collate function compatible with TensorDicts that reassembles a list of dicts.

Source code in rl4co/data/dataset.py
@staticmethod
def collate_fn(batch: Union[dict, TensorDict]):
    """Collate function compatible with TensorDicts that reassembles a list of dicts."""
    return TensorDict(
        {key: torch.stack([b[key] for b in batch]) for key in batch[0].keys()},
        batch_size=torch.Size([len(batch)]),
        **td_kwargs,
    )
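
A usage sketch, assuming the dataset exposes the standard __len__/__getitem__ interface (not shown above): pair it with a torch DataLoader and the collate_fn defined here to reassemble per-sample dicts into a batched TensorDict.

import torch
from tensordict import TensorDict
from torch.utils.data import DataLoader
from rl4co.data.dataset import TensorDictDataset

td = TensorDict({"locs": torch.rand(128, 20, 2)}, batch_size=[128])
dataset = TensorDictDataset(td)
dataloader = DataLoader(
    dataset, batch_size=32, collate_fn=TensorDictDataset.collate_fn
)
for batch in dataloader:
    print(batch.batch_size)  # torch.Size([32])
    break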

ExtraKeyDataset

ExtraKeyDataset(
    dataset: TensorDictDataset,
    extra: Tensor,
    key_name="extra",
)

Bases: TensorDictDataset

Dataset that includes an extra key to add to the data dict. This is useful, for example, for adding a REINFORCE baseline reward to the data dict. Note that this is faster to instantiate than using a list comprehension.

Source code in rl4co/data/dataset.py
def __init__(self, dataset: TensorDictDataset, extra: torch.Tensor, key_name="extra"):
    self.data_len = len(dataset)
    assert self.data_len == len(extra), "Data and extra must be same length"
    self.data = dataset.data
    self.extra = extra
    self.key_name = key_name
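
A sketch of the intended use; the exact structure returned by indexing is an assumption, since the class's __getitem__ is not shown on this page.

import torch
from tensordict import TensorDict
from rl4co.data.dataset import ExtraKeyDataset, TensorDictDataset

td = TensorDict({"locs": torch.rand(128, 20, 2)}, batch_size=[128])
base = TensorDictDataset(td)
baseline = torch.randn(base.data_len)   # e.g. a REINFORCE baseline reward per instance
dataset = ExtraKeyDataset(base, baseline, key_name="extra")
sample = dataset[0]                     # expected: the instance dict plus an "extra" entry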

TensorDictDatasetFastGeneration

TensorDictDatasetFastGeneration(td: TensorDict)

Bases: Dataset

Dataset compatible with TensorDicts. Loading performance is similar to the list-comprehension approach, but instantiation is more than 10x faster than :class:`TensorDictDatasetList`.

Warning

Note that directly indexing TensorDicts may be faster when creating the dataset, but it uses more than 3x the CPU. We generally recommend using :class:`TensorDictDatasetList` instead.

Note

Check out the issue on tensordict for more details: https://github.com/pytorch-labs/tensordict/issues/374.

Source code in rl4co/data/dataset.py
def __init__(self, td: TensorDict):
    self.data = td

collate_fn staticmethod

collate_fn(batch: Union[dict, TensorDict])

Equivalent to collating with lambda x: x

Source code in rl4co/data/dataset.py
@staticmethod
def collate_fn(batch: Union[dict, TensorDict]):
    """Equivalent to collating with `lambda x: x`"""
    return batch
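
Since the batch handed to the loader is assumed to already be a batched TensorDict, the collate function is a pass-through; a minimal illustration:

import torch
from tensordict import TensorDict
from rl4co.data.dataset import TensorDictDatasetFastGeneration

td = TensorDict({"locs": torch.rand(128, 20, 2)}, batch_size=[128])
dataset = TensorDictDatasetFastGeneration(td)   # keeps the TensorDict as-is
sub_td = td[:32]
batch = TensorDictDatasetFastGeneration.collate_fn(sub_td)
assert batch is sub_td   # identity collate: the already-batched TensorDict is returned unchanged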

Data Generation

generate_env_data

generate_env_data(env_type, *args, **kwargs)

Generate data for a given environment type in the form of a dictionary

Source code in rl4co/data/generate_data.py
def generate_env_data(env_type, *args, **kwargs):
    """Generate data for a given environment type in the form of a dictionary"""
    try:
        # breakpoint()
        # remove all None values from args
        args = [arg for arg in args if arg is not None]

        return getattr(sys.modules[__name__], f"generate_{env_type}_data")(
            *args, **kwargs
        )
    except AttributeError:
        raise NotImplementedError(f"Environment type {env_type} not implemented")
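
A dispatch sketch: the call below assumes a generate_tsp_data generator with a (dataset_size, graph_size) positional signature exists in this module, which is inferred from the naming convention rather than shown here.

from rl4co.data.generate_data import generate_env_data

data = generate_env_data("tsp", 1000, 50)   # dispatches to generate_tsp_data(1000, 50)
print(type(data))                           # expected: a dict of numpy arrays, e.g. with a "locs" key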

generate_mdpp_data

generate_mdpp_data(
    dataset_size,
    size=10,
    num_probes_min=2,
    num_probes_max=5,
    num_keepout_min=1,
    num_keepout_max=50,
    lock_size=True,
)

Generate data for the nDPP problem. If lock_size is True, then the size is fixed and generation is skipped whenever size is not 10. This is because the RL environment is based on a real-world PCB (parametrized with data).

Source code in rl4co/data/generate_data.py
def generate_mdpp_data(
    dataset_size,
    size=10,
    num_probes_min=2,
    num_probes_max=5,
    num_keepout_min=1,
    num_keepout_max=50,
    lock_size=True,
):
    """Generate data for the nDPP problem.
    If `lock_size` is True, then the size is fixed and generation is skipped whenever `size` is not 10.
    This is because the RL environment is based on a real-world PCB (parametrized with data)
    """
    if lock_size and size != 10:
        # log.info("Locking size to 10, skipping generate_mdpp_data with size {}".format(size))
        return None

    bs = dataset_size  # bs = batch_size to generate data in batch
    m = n = size
    if isinstance(bs, int):
        bs = [bs]

    locs = np.stack(np.meshgrid(np.arange(m), np.arange(n)), axis=-1).reshape(-1, 2)
    locs = locs / np.array([m, n], dtype=np.float32)
    locs = np.expand_dims(locs, axis=0)
    locs = np.repeat(locs, bs[0], axis=0)

    available = np.ones((bs[0], m * n), dtype=bool)

    probe = np.random.randint(0, high=m * n, size=(bs[0], 1))
    np.put_along_axis(available, probe, False, axis=1)

    num_probe = np.random.randint(num_probes_min, num_probes_max + 1, size=(bs[0], 1))
    probes = np.zeros((bs[0], m * n), dtype=bool)
    for i in range(bs[0]):
        p = np.random.choice(m * n, num_probe[i], replace=False)
        np.put_along_axis(available[i], p, False, axis=0)
        np.put_along_axis(probes[i], p, True, axis=0)

    num_keepout = np.random.randint(num_keepout_min, num_keepout_max + 1, size=(bs[0], 1))
    for i in range(bs[0]):
        k = np.random.choice(m * n, num_keepout[i], replace=False)
        np.put_along_axis(available[i], k, False, axis=0)

    return {
        "locs": locs.astype(np.float32),
        "probe": probes.astype(bool),
        "action_mask": available.astype(bool),
    }
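
A small sketch grounded in the code above: the returned dict holds grid coordinates, probe indicators, and an availability mask, and any non-default size is skipped when lock_size is True.

from rl4co.data.generate_data import generate_mdpp_data

data = generate_mdpp_data(dataset_size=16)               # default 10x10 grid
print(data["locs"].shape)                                # (16, 100, 2)
print(data["probe"].shape, data["action_mask"].shape)    # (16, 100) (16, 100)
assert generate_mdpp_data(dataset_size=16, size=20) is None   # skipped since lock_size=True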

generate_dataset

generate_dataset(
    filename: Union[str, List[str]] = None,
    data_dir: str = "data",
    name: str = None,
    problem: Union[str, List[str]] = "all",
    data_distribution: str = "all",
    dataset_size: int = 10000,
    graph_sizes: Union[int, List[int]] = [20, 50, 100],
    overwrite: bool = False,
    seed: int = 1234,
    disable_warning: bool = True,
    distributions_per_problem: Union[int, dict] = None,
)

We keep a similar structure as in Kool et al. 2019 but save and load the data as npz. This is way faster and more memory efficient than pickle and also allows for easy transfer to TensorDict.

Parameters:

  • filename (Union[str, List[str]], default: None ) –

    Filename to save the data to. If None, the data is saved to data_dir/problem/problem_graph_size_seed.npz. Defaults to None.

  • data_dir (str, default: 'data' ) –

    Directory to save the data to. Defaults to "data".

  • name (str, default: None ) –

    Name of the dataset. Defaults to None.

  • problem (Union[str, List[str]], default: 'all' ) –

    Problem to generate data for. Defaults to "all".

  • data_distribution (str, default: 'all' ) –

    Data distribution to generate data for. Defaults to "all".

  • dataset_size (int, default: 10000 ) –

    Number of instances per dataset to generate. Defaults to 10000.

  • graph_sizes (Union[int, List[int]], default: [20, 50, 100] ) –

    Graph size to generate data for. Defaults to [20, 50, 100].

  • overwrite (bool, default: False ) –

    Whether to overwrite existing files. Defaults to False.

  • seed (int, default: 1234 ) –

    Random seed. Defaults to 1234.

  • disable_warning (bool, default: True ) –

    Whether to disable warnings. Defaults to True.

  • distributions_per_problem (Union[int, dict], default: None ) –

    Number of distributions to generate per problem. Defaults to None.

Source code in rl4co/data/generate_data.py
def generate_dataset(
    filename: Union[str, List[str]] = None,
    data_dir: str = "data",
    name: str = None,
    problem: Union[str, List[str]] = "all",
    data_distribution: str = "all",
    dataset_size: int = 10000,
    graph_sizes: Union[int, List[int]] = [20, 50, 100],
    overwrite: bool = False,
    seed: int = 1234,
    disable_warning: bool = True,
    distributions_per_problem: Union[int, dict] = None,
):
    """We keep a similar structure as in Kool et al. 2019 but save and load the data as npz
    This is way faster and more memory efficient than pickle and also allows for easy transfer to TensorDict

    Args:
        filename: Filename to save the data to. If None, the data is saved to data_dir/problem/problem_graph_size_seed.npz. Defaults to None.
        data_dir: Directory to save the data to. Defaults to "data".
        name: Name of the dataset. Defaults to None.
        problem: Problem to generate data for. Defaults to "all".
        data_distribution: Data distribution to generate data for. Defaults to "all".
        dataset_size: Number of instances per dataset to generate. Defaults to 10000.
        graph_sizes: Graph size to generate data for. Defaults to [20, 50, 100].
        overwrite: Whether to overwrite existing files. Defaults to False.
        seed: Random seed. Defaults to 1234.
        disable_warning: Whether to disable warnings. Defaults to True.
        distributions_per_problem: Number of distributions to generate per problem. Defaults to None.
    """

    if isinstance(problem, list) and len(problem) == 1:
        problem = problem[0]

    graph_sizes = [graph_sizes] if isinstance(graph_sizes, int) else graph_sizes

    if distributions_per_problem is None:
        distributions_per_problem = DISTRIBUTIONS_PER_PROBLEM

    if problem == "all":
        problems = distributions_per_problem
    else:
        problems = {
            problem: distributions_per_problem[problem]
            if data_distribution == "all"
            else [data_distribution]
        }

    # Support multiple filenames if necessary
    filenames = [filename] if isinstance(filename, str) else filename
    iter = 0

    # Main loop for data generation. We loop over all problems, distributions and sizes
    for problem, distributions in problems.items():
        for distribution in distributions or [None]:
            for graph_size in graph_sizes:
                if filename is None:
                    datadir = os.path.join(data_dir, problem)
                    os.makedirs(datadir, exist_ok=True)
                    fname = os.path.join(
                        datadir,
                        "{}{}{}_{}_seed{}.npz".format(
                            problem,
                            "_{}".format(distribution)
                            if distribution is not None
                            else "",
                            graph_size,
                            name,
                            seed,
                        ),
                    )
                else:
                    try:
                        fname = filenames[iter]
                        # make directory if necessary
                        os.makedirs(os.path.dirname(fname), exist_ok=True)
                        iter += 1
                    except Exception:
                        raise ValueError(
                            "Number of filenames does not match number of problems"
                        )
                    fname = check_extension(filename, extension=".npz")

                if not overwrite and os.path.isfile(
                    check_extension(fname, extension=".npz")
                ):
                    if not disable_warning:
                        log.info(
                            "File {} already exists! Run with -f option to overwrite. Skipping...".format(
                                fname
                            )
                        )
                    continue

                # Set seed
                np.random.seed(seed)

                # Automatically generate dataset
                dataset = generate_env_data(
                    problem, dataset_size, graph_size, distribution
                )

                # A function can return None in case of an error or a skip
                if dataset is not None:
                    # Save to disk as dict
                    log.info("Saving {} dataset to {}".format(problem, fname))
                    np.savez(fname, **dataset)
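
An illustrative call (the sizes are arbitrary, and the output filenames are what the format string above would produce, assuming the TSP problem has no named data distributions):

from rl4co.data.generate_data import generate_dataset

generate_dataset(
    data_dir="data",
    name="val",
    problem="tsp",
    dataset_size=1_000,
    graph_sizes=[20, 50],
    seed=4321,
)
# expected output files: data/tsp/tsp20_val_seed4321.npz and data/tsp/tsp50_val_seed4321.npz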

generate_default_datasets

generate_default_datasets(data_dir, generate_eda=False)

Generate the default datasets used in the paper and save them to data_dir/problem

Source code in rl4co/data/generate_data.py
def generate_default_datasets(data_dir, generate_eda=False):
    """Generate the default datasets used in the paper and save them to data_dir/problem"""
    generate_dataset(data_dir=data_dir, name="val", problem="all", seed=4321)
    generate_dataset(data_dir=data_dir, name="test", problem="all", seed=1234)

    # By default, we skip the EDA datasets since they can easily be generated on the fly when needed
    if generate_eda:
        generate_dataset(
            data_dir=data_dir,
            name="test",
            problem="mdpp",
            seed=1234,
            graph_sizes=[10],
            dataset_size=100,
        )  # EDA (mDPP)

Transforms

StateAugmentation

StateAugmentation(
    num_augment: int = 8,
    augment_fn: Union[str, callable] = "symmetric",
    first_aug_identity: bool = True,
    normalize: bool = False,
    feats: list = None,
)

Bases: object

Augment state by N times via symmetric rotation/reflection transform

Parameters:

  • num_augment (int, default: 8 ) –

    number of augmentations

  • augment_fn (Union[str, callable], default: 'symmetric' ) –

    augmentation function to use, e.g. 'symmetric' (default) or 'dihedral8'; if a callable is passed, it is used directly. If 'dihedral8', then num_augment must be 8

  • first_aug_identity (bool, default: True ) –

    whether to keep the first augmentation as the identity (i.e., leave the first copy of the data unaugmented)

  • normalize (bool, default: False ) –

    whether to normalize the augmented data

  • feats (list, default: None ) –

    list of features to augment

Source code in rl4co/data/transforms.py
def __init__(
    self,
    num_augment: int = 8,
    augment_fn: Union[str, callable] = 'symmetric', 
    first_aug_identity: bool = True,
    normalize: bool = False,
    feats: list = None,
):
    self.augmentation = get_augment_function(augment_fn)
    assert not (
        self.augmentation == dihedral_8_augmentation_wrapper and num_augment != 8
    ), "When using the `dihedral8` augmentation function, then num_augment must be 8"

    if feats is None:
        log.info("Features not passed, defaulting to 'locs'")
        self.feats = ["locs"]
    else:
        self.feats = feats
    self.num_augment = num_augment
    self.normalize = normalize
    self.first_aug_identity = first_aug_identity
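
A construction sketch reflecting the assertion above (how the instance is then applied to a TensorDict is not shown on this page):

from rl4co.data.transforms import StateAugmentation

aug = StateAugmentation(num_augment=8, augment_fn="dihedral8", feats=["locs"])
# StateAugmentation(num_augment=4, augment_fn="dihedral8")  # would fail the assert above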

dihedral_8_augmentation

dihedral_8_augmentation(xy: Tensor) -> Tensor

Augmentation (x8) for grid-based data (x, y) as done in POMO. This is a Dihedral group of order 8 (rotations and reflections) https://en.wikipedia.org/wiki/Examples_of_groups#dihedral_group_of_order_8

Parameters:

  • xy (Tensor) –

    [batch, graph, 2] tensor of x and y coordinates

Source code in rl4co/data/transforms.py
def dihedral_8_augmentation(xy: Tensor) -> Tensor:
    """
    Augmentation (x8) for grid-based data (x, y) as done in POMO.
    This is a Dihedral group of order 8 (rotations and reflections)
    https://en.wikipedia.org/wiki/Examples_of_groups#dihedral_group_of_order_8

    Args:
        xy: [batch, graph, 2] tensor of x and y coordinates
    """
    # [batch, graph, 2]
    x, y = xy.split(1, dim=2)
    # augmentations [batch, graph, 2]
    z0 = torch.cat((x, y), dim=2)
    z1 = torch.cat((1 - x, y), dim=2)
    z2 = torch.cat((x, 1 - y), dim=2)
    z3 = torch.cat((1 - x, 1 - y), dim=2)
    z4 = torch.cat((y, x), dim=2)
    z5 = torch.cat((1 - y, x), dim=2)
    z6 = torch.cat((y, 1 - x), dim=2)
    z7 = torch.cat((1 - y, 1 - x), dim=2)
    # [batch*8, graph, 2]
    aug_xy = torch.cat((z0, z1, z2, z3, z4, z5, z6, z7), dim=0)
    return aug_xy
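
A shape sketch: each input instance is mapped to the eight elements of the dihedral group, with the identity copy first.

import torch
from rl4co.data.transforms import dihedral_8_augmentation

xy = torch.rand(4, 20, 2)                 # [batch, graph, 2] coordinates in [0, 1]
aug = dihedral_8_augmentation(xy)
print(aug.shape)                          # torch.Size([32, 20, 2])
torch.testing.assert_close(aug[:4], xy)   # z0 is the identity transform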

dihedral_8_augmentation_wrapper

dihedral_8_augmentation_wrapper(
    xy: Tensor, reduce: bool = True, *args, **kw
) -> Tensor

Wrapper for dihedral_8_augmentation. If reduce, only the first 1/8 of the input is kept before augmenting, since the augmentation multiplies the data by 8 (so the output batch size matches the input).

Source code in rl4co/data/transforms.py
def dihedral_8_augmentation_wrapper(
    xy: Tensor, reduce: bool = True, *args, **kw
) -> Tensor:
    """Wrapper for dihedral_8_augmentation. If reduce, only return the first 1/8 of the augmented data
    since the augmentation augments the data 8 times.
    """
    xy = xy[: xy.shape[0] // 8, ...] if reduce else xy
    return dihedral_8_augmentation(xy)

symmetric_transform

symmetric_transform(
    x: Tensor, y: Tensor, phi: Tensor, offset: float = 0.5
)

SR group transform with rotation and reflection. Like the one in SymNCO, but a vectorized version.

Parameters:

  • x (Tensor) –

    [batch, graph, 1] tensor of x coordinates

  • y (Tensor) –

    [batch, graph, 1] tensor of y coordinates

  • phi (Tensor) –

    [batch, 1] tensor of random rotation angles

  • offset (float, default: 0.5 ) –

    offset for x and y coordinates

Source code in rl4co/data/transforms.py
def symmetric_transform(x: Tensor, y: Tensor, phi: Tensor, offset: float = 0.5):
    """SR group transform with rotation and reflection
    Like the one in SymNCO, but a vectorized version

    Args:
        x: [batch, graph, 1] tensor of x coordinates
        y: [batch, graph, 1] tensor of y coordinates
        phi: [batch, 1] tensor of random rotation angles
        offset: offset for x and y coordinates
    """
    x, y = x - offset, y - offset
    # random rotation
    x_prime = torch.cos(phi) * x - torch.sin(phi) * y
    y_prime = torch.sin(phi) * x + torch.cos(phi) * y
    # make random reflection if phi > 2*pi (i.e. 50% of the time)
    mask = phi > 2 * math.pi
    # vectorized random reflection: swap axes x and y if mask
    xy = torch.cat((x_prime, y_prime), dim=-1)
    xy = torch.where(mask, xy.flip(-1), xy)
    return xy + offset
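
A usage sketch; note that the caller in symmetric_augmentation below passes phi with shape [batch, 1, 1] so that it broadcasts against the [batch, graph, 1] coordinates.

import math
import torch
from rl4co.data.transforms import symmetric_transform

x, y = torch.rand(4, 20, 1), torch.rand(4, 20, 1)
phi = torch.rand(4, 1, 1) * 4 * math.pi   # angles > 2*pi additionally trigger a reflection
xy_t = symmetric_transform(x, y, phi)
print(xy_t.shape)                         # torch.Size([4, 20, 2])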

symmetric_augmentation

symmetric_augmentation(
    xy: Tensor,
    num_augment: int = 8,
    first_augment: bool = False,
)

Augment xy data by num_augment times via symmetric rotation transform and concatenate to original data

Parameters:

  • xy (Tensor) –

    [batch, graph, 2] tensor of x and y coordinates

  • num_augment (int, default: 8 ) –

    number of augmentations

  • first_augment (bool, default: False ) –

    whether to augment the first data point

Source code in rl4co/data/transforms.py
def symmetric_augmentation(xy: Tensor, num_augment: int = 8, first_augment: bool = False):
    """Augment xy data by `num_augment` times via symmetric rotation transform and concatenate to original data

    Args:
        xy: [batch, graph, 2] tensor of x and y coordinates
        num_augment: number of augmentations
        first_augment: whether to augment the first data point
    """
    # create random rotation angles (4*pi for reflection, 2*pi for rotation)
    phi = torch.rand(xy.shape[0], device=xy.device) * 4 * math.pi

    # set phi to 0 for the first copy, i.e. no augmentation as in SymNCO
    if not first_augment:
        phi[: xy.shape[0] // num_augment] = 0.0
    x, y = xy[..., [0]], xy[..., [1]]
    return symmetric_transform(x, y, phi[:, None, None])
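
A sketch assuming the input has already been repeated num_augment times along the batch dimension (which is what the phi[: xy.shape[0] // num_augment] = 0 line implies); the first repetition is then left unchanged by default.

import torch
from rl4co.data.transforms import symmetric_augmentation

xy = torch.rand(4, 20, 2).repeat(8, 1, 1)        # [32, 20, 2]: 8 copies of 4 instances
xy_aug = symmetric_augmentation(xy, num_augment=8)
print(xy_aug.shape)                              # torch.Size([32, 20, 2])
torch.testing.assert_close(xy_aug[:4], xy[:4])   # first copy: phi = 0, identity transform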

Utils

load_npz_to_tensordict

load_npz_to_tensordict(filename)

Load an npz file directly into a TensorDict. We assume the npz file contains a dictionary of numpy arrays. This is at least an order of magnitude faster than pickle.

Source code in rl4co/data/utils.py
def load_npz_to_tensordict(filename):
    """Load a npz file directly into a TensorDict
    We assume that the npz file contains a dictionary of numpy arrays
    This is at least an order of magnitude faster than pickle
    """
    x = np.load(filename)
    x_dict = dict(x)
    batch_size = x_dict[list(x_dict.keys())[0]].shape[0]
    return TensorDict(x_dict, batch_size=batch_size)
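
A round-trip sketch with a throwaway file (the key name and shapes are illustrative):

import numpy as np
from rl4co.data.utils import load_npz_to_tensordict

np.savez("instances.npz", locs=np.random.rand(100, 20, 2).astype(np.float32))
td = load_npz_to_tensordict("instances.npz")
print(td.batch_size)      # torch.Size([100])
print(td["locs"].shape)   # torch.Size([100, 20, 2])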

save_tensordict_to_npz

save_tensordict_to_npz(
    tensordict, filename, compress: bool = False
)

Save a TensorDict to an npz file. We assume that the TensorDict contains a dictionary of tensors.

Source code in rl4co/data/utils.py
def save_tensordict_to_npz(tensordict, filename, compress: bool = False):
    """Save a TensorDict to a npz file
    We assume that the TensorDict contains a dictionary of tensors
    """
    x_dict = {k: v.numpy() for k, v in tensordict.items()}
    if compress:
        np.savez_compressed(filename, **x_dict)
    else:
        np.savez(filename, **x_dict)
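
The inverse direction, again as a sketch: dump a TensorDict of CPU tensors and read it back with load_npz_to_tensordict above.

import torch
from tensordict import TensorDict
from rl4co.data.utils import load_npz_to_tensordict, save_tensordict_to_npz

td = TensorDict({"locs": torch.rand(100, 20, 2)}, batch_size=[100])
save_tensordict_to_npz(td, "instances.npz", compress=True)
td_reloaded = load_npz_to_tensordict("instances.npz")
torch.testing.assert_close(td_reloaded["locs"], td["locs"])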

check_extension

check_extension(filename, extension='.npz')

Check that filename has extension, otherwise add it

Source code in rl4co/data/utils.py
def check_extension(filename, extension=".npz"):
    """Check that filename has extension, otherwise add it"""
    if os.path.splitext(filename)[1] != extension:
        return filename + extension
    return filename

load_solomon_instance

load_solomon_instance(name, path=None, edge_weights=False)

Load a Solomon instance from a file

Source code in rl4co/data/utils.py
def load_solomon_instance(name, path=None, edge_weights=False):
    """Load solomon instance from a file"""
    import vrplib

    if not path:
        path = "data/solomon/instances/"
        path = os.path.join(ROOT_PATH, path)
    if not os.path.isdir(path):
        os.makedirs(path)
    file_path = f"{path}{name}.txt"
    if not os.path.isfile(file_path):
        vrplib.download_instance(name=name, path=path)
    return vrplib.read_instance(
        path=file_path,
        instance_format="solomon",
        compute_edge_weights=edge_weights,
    )
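
A usage sketch; "C101" is a standard Solomon benchmark name, and the returned dict follows vrplib's Solomon parser, so its exact fields should be checked against that library.

from rl4co.data.utils import load_solomon_instance

instance = load_solomon_instance("C101", edge_weights=True)   # downloads the file if missing
print(sorted(instance.keys()))   # vrplib's parsed fields, e.g. node coordinates and time windows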

load_solomon_solution

load_solomon_solution(name, path=None)

Load a Solomon solution from a file

Source code in rl4co/data/utils.py
def load_solomon_solution(name, path=None):
    """Load solomon solution from a file"""
    import vrplib

    if not path:
        path = "data/solomon/solutions/"
        path = os.path.join(ROOT_PATH, path)
    if not os.path.isdir(path):
        os.makedirs(path)
    file_path = f"{path}{name}.sol"
    if not os.path.isfile(file_path):
        vrplib.download_solution(name=name, path=path)
    return vrplib.read_solution(path=file_path)
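
And the matching reference solution, as a sketch (the field names are whatever vrplib's solution parser returns):

from rl4co.data.utils import load_solomon_solution

solution = load_solomon_solution("C101")
print(solution.keys())   # e.g. the route list and its cost, as parsed by vrplib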