Skip to content

Base Environment

This is the base wrapper around TorchRL's EnvBase, with additional functionality.

RL4COEnvBase

RL4COEnvBase(
    *,
    data_dir: str = "data/",
    train_file: str = None,
    val_file: str = None,
    test_file: str = None,
    val_dataloader_names: list = None,
    test_dataloader_names: list = None,
    check_solution: bool = True,
    dataset_cls: callable = TensorDictDataset,
    seed: int = None,
    device: str = "cpu",
    batch_size: Size = None,
    run_type_checks: bool = False,
    allow_done_after_reset: bool = False,
    _torchrl_mode: bool = False,
    **kwargs
)

Bases: EnvBase

Base class for RL4CO environments based on TorchRL EnvBase. The environment has the usual methods for stepping, resetting, and getting the specifications of the environment that should be implemented by subclasses of this class. It also has methods for getting the reward and the action mask, checking the validity of the solution, and generating and loading the datasets (supporting multiple dataloaders for validation and testing).

Parameters:

  • data_dir (str, default: 'data/' ) –

    Root directory for the dataset

  • train_file (str, default: None ) –

    Name of the training file

  • val_file (str, default: None ) –

    Name of the validation file

  • test_file (str, default: None ) –

    Name of the test file

  • val_dataloader_names (list, default: None ) –

    Names of the dataloaders to use for validation

  • test_dataloader_names (list, default: None ) –

    Names of the dataloaders to use for testing

  • check_solution (bool, default: True ) –

    Whether to check the validity of the solution at the end of the episode

  • dataset_cls (callable, default: TensorDictDataset ) –

    Dataset class to use for the environment (which can influence performance)

  • seed (int, default: None ) –

    Seed for the environment

  • device (str, default: 'cpu' ) –

    Device to use. Generally, no need to set as tensors are updated on the fly

  • batch_size (Size, default: None ) –

    Batch size to use for the environment. Generally, no need to set as tensors are updated on the fly

  • run_type_checks (bool, default: False ) –

    If True, run type checks on the TensorDicts at each step

  • allow_done_after_reset (bool, default: False ) –

    If True, an environment can be done after a reset

  • _torchrl_mode (bool, default: False ) –

    Whether to use the TorchRL mode (see `step` for more details)

Source code in rl4co/envs/common/base.py (lines 45-119)
def __init__(
    self,
    *,
    data_dir: str = "data/",
    train_file: str = None,
    val_file: str = None,
    test_file: str = None,
    val_dataloader_names: list = None,
    test_dataloader_names: list = None,
    check_solution: bool = True,
    dataset_cls: callable = TensorDictDataset,
    seed: int = None,
    device: str = "cpu",
    batch_size: torch.Size = None,
    run_type_checks: bool = False,
    allow_done_after_reset: bool = False,
    _torchrl_mode: bool = False,
    **kwargs,
):
    """Set up the TorchRL base environment, resolve dataset paths and seed the environment.

    Args:
        data_dir: Root directory that every dataset file name is joined onto.
        train_file: Training file name relative to ``data_dir``, or None.
        val_file: Validation file name, or a (non-string) iterable of names for
            multiple validation dataloaders.
        test_file: Test file name, or a (non-string) iterable of names for
            multiple test dataloaders.
        val_dataloader_names: Names for the validation dataloaders. When several
            files are given and this is None, the names "0", "1", ... are used.
        test_dataloader_names: Same as ``val_dataloader_names``, for the test files.
        check_solution: Whether ``get_reward`` validates solutions before scoring.
        dataset_cls: Class used to wrap loaded or generated data.
        seed: Random seed; when None, a random int64 seed is drawn with torch.
        device, batch_size, run_type_checks, allow_done_after_reset: Forwarded
            unchanged to ``EnvBase.__init__``.
        _torchrl_mode: If True, ``step`` goes through ``_torchrl_step``.
        **kwargs: Unexpected extra arguments are reported with ``log.error``
            (a ``name`` entry is accepted silently).
    """
    super().__init__(
        device=device,
        batch_size=batch_size,
        run_type_checks=run_type_checks,
        allow_done_after_reset=allow_done_after_reset,
    )

    # `name` is an accepted extra kwarg, so it is dropped before reporting leftovers.
    kwargs.pop("name", None)
    if kwargs:
        log.error(
            f"Unused keyword arguments: {', '.join(kwargs.keys())}. "
            "Please check the base class documentation at https://rl4co.readthedocs.io/en/latest/_content/api/envs/base.html. "
            "In case you would like to pass data generation arguments, please pass a `generator` method instead "
            "or for example: `generator_kwargs=dict(num_loc=50)` to the constructor."
        )

    def is_multi(entry):
        # A non-string iterable means "several files / several dataloaders".
        return isinstance(entry, Iterable) and not isinstance(entry, str)

    def resolve_paths(entry):
        # Join one name or every name of a collection onto `data_dir`.
        if entry is None:
            return None
        if is_multi(entry):
            return [pjoin(data_dir, item) for item in entry]
        return pjoin(data_dir, entry)

    def resolve_names(files, names):
        # Derive / validate dataloader names for the resolved file entry.
        if files is None:
            return names
        if not is_multi(files):
            if names is not None:
                log.warning(
                    "Ignoring dataloader names since only one dataloader is provided"
                )
            return names
        if names is None:
            return [f"{idx}" for idx in range(len(files))]
        assert len(names) == len(
            files
        ), "Number of dataloader names must match number of files"
        return names

    self.data_dir = data_dir
    self.train_file = None if train_file is None else pjoin(data_dir, train_file)
    self._torchrl_mode = _torchrl_mode
    self.dataset_cls = dataset_cls

    self.val_file = resolve_paths(val_file)
    self.test_file = resolve_paths(test_file)
    self.val_dataloader_names = resolve_names(self.val_file, val_dataloader_names)
    self.test_dataloader_names = resolve_names(self.test_file, test_dataloader_names)
    self.check_solution = check_solution

    # Always seed the environment: draw a random int64 seed when none is provided.
    if seed is None:
        seed = torch.empty((), dtype=torch.int64).random_().item()
    self.set_seed(seed)

step

step(td: TensorDict) -> TensorDict

Step function to call at each step of the episode containing an action. If _torchrl_mode is True, we call _torchrl_step instead, which sets the `next` key of the TensorDict to the next state. This is the usual way to do it in TorchRL, but it is inefficient in our case. Otherwise, _step is called directly and its result is returned as {"next": td}.

Source code in rl4co/envs/common/base.py (lines 121-133)
def step(self, td: TensorDict) -> TensorDict:
    """Advance the episode by one action.

    By default the transition is computed directly with ``_step`` and wrapped as
    ``{"next": td}``. When ``_torchrl_mode`` is enabled, ``_torchrl_step`` is used
    instead. That path sets the ``next`` key the standard TorchRL way, but it is
    slower for our use case.
    """
    if self._torchrl_mode:
        # Standard TorchRL path (kept for compatibility with TorchRL tooling).
        return self._torchrl_step(td)
    # Fast path: skip TorchRL's additional checks and wrap the result directly.
    next_td = self._step(td)
    return {"next": next_td}

reset

reset(
    td: Optional[TensorDict] = None, batch_size=None
) -> TensorDict

Reset function to call at the beginning of each episode

Source code in rl4co/envs/common/base.py (lines 135-143)
def reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict:
    """Reset the environment at the start of an episode.

    If ``batch_size`` is not given, it is taken from ``td`` (or from the environment
    when ``td`` is None). A new instance is sampled with ``self.generator`` when
    ``td`` is None or empty. An integer batch size is wrapped into a list. The
    environment is moved to the TensorDict's device before delegating to
    ``EnvBase.reset``.
    """
    if batch_size is None:
        batch_size = td.batch_size if td is not None else self.batch_size
    needs_generation = td is None or td.is_empty()
    if needs_generation:
        td = self.generator(batch_size=batch_size)
    if isinstance(batch_size, int):
        batch_size = [batch_size]
    self.to(td.device)
    return super().reset(td, batch_size=batch_size)

get_reward

get_reward(td: TensorDict, actions: Tensor) -> Tensor

Function to compute the reward. Can be called by the agent to compute the reward of the current state. This is faster than calling step() and getting the reward from the returned TensorDict at each step for CO tasks. If check_solution is True, the solution is first validated with check_solution_validity.

Source code in rl4co/envs/common/base.py (lines 182-188)
def get_reward(self, td: TensorDict, actions: torch.Tensor) -> torch.Tensor:
    """Compute the reward for the given state and actions.

    This is meant to be called by the agent directly, which is faster than stepping
    through the environment and reading the reward from the returned TensorDict.
    When ``self.check_solution`` is set, the solution is validated with
    ``check_solution_validity`` before the reward is computed.
    """
    if self.check_solution:
        self.check_solution_validity(td, actions)
    reward = self._get_reward(td, actions)
    return reward

get_action_mask

get_action_mask(td: TensorDict) -> Tensor

Function to compute the action mask (feasible actions) for the current state. The action mask is 1 if the action is feasible, 0 otherwise.

Source code in rl4co/envs/common/base.py (lines 197-201)
def get_action_mask(self, td: TensorDict) -> torch.Tensor:
    """Return the mask of feasible actions for the current state.

    An entry is 1 when the corresponding action is feasible and 0 otherwise.

    Raises:
        NotImplementedError: Always; environments must override this method.
    """
    raise NotImplementedError()

check_solution_validity

check_solution_validity(
    td: TensorDict, actions: Tensor
) -> None

Function to check whether the solution is valid. Can be called by the agent to check the validity of the current state. This is called with the full solution (i.e. all actions) at the end of the episode.

Source code in rl4co/envs/common/base.py (lines 209-213)
def check_solution_validity(self, td: TensorDict, actions: torch.Tensor) -> None:
    """Validate a complete solution.

    Intended to be called with the full solution (i.e. all actions) at the end of
    the episode. ``get_reward`` calls it when ``check_solution`` is enabled.

    Raises:
        NotImplementedError: Always; environments must override this method.
    """
    raise NotImplementedError()

replace_selected_actions

replace_selected_actions(
    cur_actions: Tensor,
    new_actions: Tensor,
    selection_mask: Tensor,
) -> Tensor

Replace selected current actions with updated actions based on selection_mask.

Source code in rl4co/envs/common/base.py (lines 215-224)
def replace_selected_actions(
    self,
    cur_actions: torch.Tensor,
    new_actions: torch.Tensor,
    selection_mask: torch.Tensor,
) -> torch.Tensor:
    """Swap in ``new_actions`` wherever ``selection_mask`` selects them.

    Args:
        cur_actions: Current actions.
        new_actions: Candidate replacement actions.
        selection_mask: Mask selecting which current actions are replaced.

    Raises:
        NotImplementedError: Always; environments must override this method.
    """
    raise NotImplementedError()

local_search

local_search(
    td: TensorDict, actions: Tensor, **kwargs
) -> Tensor

Function to improve the solution. Can be called by the agent to improve the current state. This is called with the full solution (i.e. all actions) at the end of the episode.

Source code in rl4co/envs/common/base.py (lines 226-234)
def local_search(
    self, td: TensorDict, actions: torch.Tensor, **kwargs
) -> torch.Tensor:
    """Function to improve the solution. Can be called by the agent to improve the current state.

    This is called with the full solution (i.e. all actions) at the end of the episode.

    Raises:
        NotImplementedError: Always in the base class; environments that support
            local search override this method.
    """
    # Fall back to the class name so that a missing `name` attribute cannot turn the
    # intended NotImplementedError into an AttributeError.
    env_name = getattr(self, "name", type(self).__name__)
    raise NotImplementedError(
        f"Local search is not implemented yet for {env_name} environment"
    )

dataset

dataset(batch_size=[], phase='train', filename=None)

Return a dataset of observations. Generates the dataset if no file is set for the phase; otherwise loads it from file, falling back to generation if the file is not found.

Source code in rl4co/envs/common/base.py (lines 236-270)
def dataset(self, batch_size=None, phase="train", filename=None):
    """Return a dataset of observations for ``phase``.

    If no file is configured (``{phase}_file`` is None and no ``filename`` is given),
    the data is generated with ``self.generator``. Otherwise it is loaded with
    ``self.load_data``. When the file entry is a (non-string) iterable of paths, a
    dict mapping dataloader names to datasets is returned instead. If a file is not
    found, an error is logged and the data is generated instead.

    Args:
        batch_size: Batch size forwarded to the generator / ``load_data``.
            Defaults to ``[]``.
        phase: Phase name; ``{phase}_file`` must exist as an attribute unless
            ``filename`` is given (e.g. "train", "val", "test").
        filename: Optional path (or iterable of paths) overriding ``{phase}_file``.

    Returns:
        A ``self.dataset_cls`` instance, or a dict of them for multiple files.
    """
    # A shared mutable default would leak state between calls; a fresh list keeps
    # the old default value for callers that omit batch_size.
    if batch_size is None:
        batch_size = []
    if filename is not None:
        log.info(f"Overriding dataset filename from {filename}")
    f = getattr(self, f"{phase}_file") if filename is None else filename
    if f is None:
        if phase != "train":
            log.warning(f"{phase}_file not set. Generating dataset instead")
        td = self.generator(batch_size)
    else:
        log.info(f"Loading {phase} dataset from {f}")
        if phase == "train":
            log.warning(
                "Loading training dataset from file. This may not be desired in RL since "
                "the dataset is fixed and the agent will not be able to explore new states"
            )
        try:
            if isinstance(f, Iterable) and not isinstance(f, str):
                files = list(f)
                # Names can be missing, e.g. for a list passed via `filename`, or for
                # phase "train", which has no `*_dataloader_names` attribute. Fall back
                # to index names instead of crashing in zip(None, ...).
                names = getattr(self, f"{phase}_dataloader_names", None)
                if names is None:
                    names = [f"{i}" for i in range(len(files))]
                return {
                    name: self.dataset_cls(self.load_data(_f, batch_size))
                    for name, _f in zip(names, files)
                }
            else:
                td = self.load_data(f, batch_size)
        except FileNotFoundError:
            log.error(
                f"Provided file name {f} not found. Make sure to provide a file in the right path first or "
                f"unset {phase}_file to generate data automatically instead"
            )
            td = self.generator(batch_size)

    return self.dataset_cls(td)

transform

transform()

Used for converting TensorDict variables efficiently (such as with torch.cat); see https://pytorch.org/rl/reference/generated/torchrl.envs.transforms.Transform.html. By default, we do not need to transform the environment since we use specific embeddings, so the environment itself is returned.

Source code in rl4co/envs/common/base.py (lines 272-277)
def transform(self):
    """Return the environment to use when applying TorchRL transforms.

    TorchRL transforms
    (https://pytorch.org/rl/reference/generated/torchrl.envs.transforms.Transform.html)
    efficiently convert TensorDict variables, e.g. via ``torch.cat``. Because this
    library relies on dedicated embeddings instead, no transform is applied. The
    environment is returned unchanged.
    """
    return self

render

render(*args, **kwargs)

Render the environment

Source code in rl4co/envs/common/base.py (lines 279-281)
def render(self, *args, **kwargs):
    """Render the environment.

    Raises:
        NotImplementedError: Always; environments that support rendering override this.
    """
    raise NotImplementedError()

load_data staticmethod

load_data(fpath, batch_size=[])

Dataset loading from file

Source code in rl4co/envs/common/base.py (lines 283-286)
@staticmethod
def load_data(fpath, batch_size=None):
    """Dataset loading from file.

    Args:
        fpath: Path of the file to load, passed to ``load_npz_to_tensordict``
            (presumably an ``.npz`` archive; confirm against that helper).
        batch_size: Unused by this base implementation. It is accepted so that callers
            such as ``dataset`` can pass it uniformly. The default is None instead of a
            shared mutable ``[]``.

    Returns:
        Whatever ``load_npz_to_tensordict`` returns for ``fpath``.
    """
    return load_npz_to_tensordict(fpath)

to

to(device)

Override of the `to` device method for safety against a None device (which may be found in a TensorDict). Passing None is a no-op that returns the environment unchanged.

Source code in rl4co/envs/common/base.py (lines 293-298)
def to(self, device):
    """Move the environment to ``device``.

    A ``None`` device (which can come from a ``TensorDict``) is treated as a no-op,
    and ``self`` is returned unchanged.
    """
    if device is not None:
        return super().to(device)
    return self

solve staticmethod

solve(
    instances: TensorDict,
    max_runtime: float,
    num_procs: int = 1,
    **kwargs
) -> tuple[Tensor, Tensor]

Classical solver for the environment. This is a wrapper for the baselines solver.

Parameters:

  • instances (TensorDict) –

    The instances to solve

  • max_runtime (float) –

    The maximum runtime for the solver

  • num_procs (int, default: 1 ) –

    The number of processes to use

Returns:

  • tuple[Tensor, Tensor] –

    A tuple containing the action and the cost, respectively

Source code in rl4co/envs/common/base.py (lines 300-317)
@staticmethod
def solve(
    instances: TensorDict,
    max_runtime: float,
    num_procs: int = 1,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Solve ``instances`` with a classical solver (wrapper for the baselines solver).

    Args:
        instances: Problem instances to solve.
        max_runtime: Maximum runtime allowed for the solver.
        num_procs: Number of processes to use.

    Returns:
        A pair ``(actions, costs)``: the action and the cost, respectively.

    Raises:
        NotImplementedError: Always; environments with a classical solver override this.
    """
    raise NotImplementedError()

ImprovementEnvBase

ImprovementEnvBase(**kwargs)

Bases: RL4COEnvBase

Base class for Improvement environments based on RL4CO EnvBase. Note that this class assumes that the solution is stored in a linked-list format: if rec[i] = j, node i is connected to node j, i.e., edge i-j is in the solution. For example, if edges 0-1, 1-5, and 2-10 are in the solution, then rec[0]=1, rec[1]=5, and rec[2]=10. See https://github.com/yining043/VRP-DACT/blob/new_version/Play_with_DACT.ipynb for an example for TSP (at the end of the notebook).

Source code in rl4co/envs/common/base.py (lines 344-348)
def __init__(self, **kwargs):
    """Initialize the improvement environment.

    Every keyword argument is forwarded unchanged to ``RL4COEnvBase.__init__``.
    """
    super().__init__(**kwargs)

Utilities

These contain utilities such as the base Generator class and get_sampler.

Generator

Generator(**kwargs)

Base data generator class, to be called with env.generator(batch_size)

Source code in rl4co/envs/common/utils.py (lines 22-23)
def __init__(self, **kwargs):
    """Store all keyword arguments, as a dict, on ``self.kwargs``."""
    self.kwargs = kwargs

get_sampler

get_sampler(
    val_name: str,
    distribution: Union[int, float, str, type, Callable],
    low: float = 0,
    high: float = 1.0,
    **kwargs
)

Get the sampler for the variable with the given distribution. If kwargs are passed, they will be parsed e.g. with val_name + _dist_arg (e.g. loc_std for Normal distribution).

Parameters:

  • val_name (str) –

    Name of the variable

  • distribution (Union[int, float, str, type, Callable]) –

    int/float value (as constant distribution), or string with the distribution name (supporting uniform, normal, exponential, and poisson) or PyTorch Distribution type or a callable function that returns a PyTorch Distribution

  • low (float, default: 0 ) –

    Minimum value for the variable, used for Uniform distribution

  • high (float, default: 1.0 ) –

    Maximum value for the variable, used for Uniform distribution

  • kwargs

    Additional arguments for the distribution

Example
sampler_uniform = get_sampler("loc", "uniform", 0, 1)
sampler_normal = get_sampler("loc", "normal", loc_mean=0.5, loc_std=.2)
Source code in rl4co/envs/common/utils.py (lines 34-100)
def get_sampler(
    val_name: str,
    distribution: Union[int, float, str, type, Callable],
    low: float = 0,
    high: float = 1.0,
    **kwargs,
):
    """Get the sampler for the variable with the given distribution.
    If kwargs are passed, they will be parsed e.g. with `val_name` + `_dist_arg` (e.g. `loc_std` for Normal distribution).

    Args:
        val_name: Name of the variable
        distribution: int/float value (as constant distribution), or string with the distribution name (supporting
            uniform, normal, exponential, and poisson) or PyTorch Distribution type or a callable function that
            returns a PyTorch Distribution
        low: Minimum value for the variable, used for Uniform distribution
        high: Maximum value for the variable, used for Uniform distribution
        kwargs: Additional arguments for the distribution

    Example:
        ```python
        sampler_uniform = get_sampler("loc", "uniform", 0, 1)
        sampler_normal = get_sampler("loc", "normal", loc_mean=0.5, loc_std=.2)
        ```
    """
    if isinstance(distribution, (int, float)):
        return Uniform(low=distribution, high=distribution)
    elif distribution == Uniform or distribution == "uniform":
        return Uniform(low=low, high=high)
    elif distribution == Normal or distribution == "normal" or distribution == "gaussian":
        assert (
            kwargs.get(val_name + "_mean", None) is not None
        ), "mean is required for Normal distribution"
        assert (
            kwargs.get(val_name + "_std", None) is not None
        ), "std is required for Normal distribution"
        return Normal(loc=kwargs[val_name + "_mean"], scale=kwargs[val_name + "_std"])
    elif distribution == Exponential or distribution == "exponential":
        assert (
            kwargs.get(val_name + "_rate", None) is not None
        ), "rate is required for Exponential/Poisson distribution"
        return Exponential(rate=kwargs[val_name + "_rate"])
    elif distribution == Poisson or distribution == "poisson":
        assert (
            kwargs.get(val_name + "_rate", None) is not None
        ), "rate is required for Exponential/Poisson distribution"
        return Poisson(rate=kwargs[val_name + "_rate"])
    elif distribution == "center":
        return Uniform(low=(high - low) / 2, high=(high - low) / 2)
    elif distribution == "corner":
        return Uniform(
            low=low, high=low
        )  # todo: should be also `low, high` and any other corner
    elif isinstance(distribution, Callable):
        return distribution(**kwargs)
    elif distribution == "gaussian_mixture":
        return Gaussian_Mixture(num_modes=kwargs["num_modes"], cdist=kwargs["cdist"])
    elif distribution == "cluster":
        return Cluster(kwargs["n_cluster"])
    elif distribution == "mixed":
        return Mixed(kwargs["n_cluster_mix"])
    elif distribution == "mix_distribution":
        return Mix_Distribution(kwargs["n_cluster"], kwargs["n_cluster_mix"])
    elif distribution == "mix_multi_distributions":
        return Mix_Multi_Distributions()
    else:
        raise ValueError(f"Invalid distribution type of {distribution}")

batch_to_scalar

batch_to_scalar(param)

Return first element if in batch. Used for batched parameters that are the same for all elements in the batch.

Source code in rl4co/envs/common/utils.py (lines 103-109)
def batch_to_scalar(param):
    """Return first element if in batch. Used for batched parameters that are the same for all elements in the batch.

    Args:
        param: A tensor (or another object exposing ``shape``, indexing and
            ``item``), or a plain Python scalar.

    Returns:
        ``param[0].item()`` for inputs with at least one dimension, ``param.item()``
        for 0-dim ``torch.Tensor`` inputs, and ``param`` unchanged otherwise (e.g.
        plain ``int``/``float`` values).
    """
    # Plain Python scalars have no `shape`. Reading it unconditionally (as before)
    # raised AttributeError and made the final `return param` unreachable for them.
    shape = getattr(param, "shape", None)
    if shape is not None and len(shape) > 0:
        return param[0].item()
    if isinstance(param, torch.Tensor):
        return param.item()
    return param