
RL4COLitModule

The RL4COLitModule is a wrapper around PyTorch Lightning's LightningModule that provides additional functionality for RL algorithms. It is the parent class for all RL algorithms in the library.

RL4COLitModule

RL4COLitModule(
    env: RL4COEnvBase,
    policy: Module,
    batch_size: int = 512,
    val_batch_size: list[int] | int = None,
    test_batch_size: list[int] | int = None,
    train_data_size: int = 100000,
    val_data_size: int = 10000,
    test_data_size: int = 10000,
    optimizer: str | Optimizer | partial = "Adam",
    optimizer_kwargs: dict = {"lr": 0.0001},
    lr_scheduler: str | LRScheduler | partial = None,
    lr_scheduler_kwargs: dict = {
        "milestones": [80, 95],
        "gamma": 0.1,
    },
    lr_scheduler_interval: str = "epoch",
    lr_scheduler_monitor: str = "val/reward",
    generate_default_data: bool = False,
    shuffle_train_dataloader: bool = False,
    dataloader_num_workers: int = 0,
    data_dir: str = "data/",
    log_on_step: bool = True,
    metrics: dict = {},
    **litmodule_kwargs
)

Bases: LightningModule

Base class for RL4CO Lightning modules. It defines the general training loop for RL algorithms; subclasses mainly need to implement shared_step, which defines the algorithm-specific loss functions and optimization routines.

Parameters:

  • env (RL4COEnvBase) –

    RL4CO environment

  • policy (Module) –

    policy network (actor)

  • batch_size (int, default: 512 ) –

    general batch size, used for training and as the fallback for val_batch_size

  • val_batch_size (list[int] | int, default: None ) –

    batch size for validation. If None, batch_size is used. If a list, one batch size is used per validation dataset

  • test_batch_size (list[int] | int, default: None ) –

    batch size for testing. If None, the resolved validation batch size is used. If a list, one batch size is used per test dataset

  • train_data_size (int, default: 100000 ) –

    number of training instances generated per epoch (the training dataset is regenerated at the end of every epoch except the last)

  • val_data_size (int, default: 10000 ) –

    size of the validation dataset

  • test_data_size (int, default: 10000 ) –

    size of the test dataset

  • optimizer (str | Optimizer | partial, default: 'Adam' ) –

    optimizer, given as a name (resolved with create_optimizer), a functools.partial, or an optimizer class

  • optimizer_kwargs (dict, default: {'lr': 0.0001} ) –

    keyword arguments passed to the optimizer

  • lr_scheduler (str | LRScheduler | partial, default: None ) –

    learning rate scheduler, given as a name (resolved with create_scheduler), a functools.partial, or a scheduler class. If None, no scheduler is used

  • lr_scheduler_kwargs (dict, default: {'milestones': [80, 95], 'gamma': 0.1} ) –

    keyword arguments passed to the learning rate scheduler

  • lr_scheduler_interval (str, default: 'epoch' ) –

    how often the learning rate scheduler is stepped ("epoch" or "step")

  • lr_scheduler_monitor (str, default: 'val/reward' ) –

    metric monitored by the learning rate scheduler (used by schedulers such as ReduceLROnPlateau)

  • generate_default_data (bool, default: False ) –

    whether to generate the default datasets in data_dir; existing files are not overwritten

  • shuffle_train_dataloader (bool, default: False ) –

    whether to shuffle the training dataloader. Defaults to False because the training dataset is regenerated every epoch

  • dataloader_num_workers (int, default: 0 ) –

    number of workers for dataloader

  • data_dir (str, default: 'data/' ) –

    data directory

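  • log_on_step (bool, default: True ) –

    whether to log training metrics at every step. This value overrides any "log_on_step" key in metrics

  • metrics (dict, default: {} ) –

    metric names to log per phase, under the optional keys "train", "val" and "test" (defaults: ["loss", "reward"] for training, ["reward"] for validation and testing)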

  • litmodule_kwargs

    additional keyword arguments passed to LightningModule

Methods:

  • instantiate_metrics

    Dictionary of metrics to be logged at each phase

  • setup

    Base LightningModule setup method. This will set up the datasets and dataloaders

  • setup_loggers

    Log all hyperparameters except nn.Module instances (such as the policy)

  • post_setup_hook

    Hook to be called after setup. Can be used to set up subclasses without overriding setup

  • configure_optimizers

    Instantiate the optimizer and, if lr_scheduler is set, the learning rate scheduler

  • log_metrics

    Log metrics to logger and progress bar

  • forward

    Forward pass for the model. Simple wrapper around policy. Uses env from the module if not provided.

  • shared_step

    Shared step between train/val/test. Must be implemented by subclasses

  • on_train_epoch_end

    Called at the end of the training epoch. This can be used, for instance, to update the train dataset with new data (which is the case in RL)

  • wrap_dataset

    Wrap dataset with a policy-specific wrapper. This is useful, e.g., in REINFORCE, where we need to collect the greedy rollout baseline outputs

Source code in rl4co/models/rl/common/base.py
def __init__(
    self,
    env: RL4COEnvBase,
    policy: nn.Module,
    batch_size: int = 512,
    val_batch_size: list[int] | int = None,
    test_batch_size: list[int] | int = None,
    train_data_size: int = 100_000,
    val_data_size: int = 10_000,
    test_data_size: int = 10_000,
    optimizer: str | torch.optim.Optimizer | partial = "Adam",
    optimizer_kwargs: dict = {"lr": 1e-4},
    lr_scheduler: str | torch.optim.lr_scheduler.LRScheduler | partial = None,
    lr_scheduler_kwargs: dict = {
        "milestones": [80, 95],
        "gamma": 0.1,
    },
    lr_scheduler_interval: str = "epoch",
    lr_scheduler_monitor: str = "val/reward",
    generate_default_data: bool = False,
    shuffle_train_dataloader: bool = False,
    dataloader_num_workers: int = 0,
    data_dir: str = "data/",
    log_on_step: bool = True,
    metrics: dict = {},
    **litmodule_kwargs,
):
    super().__init__(**litmodule_kwargs)

    # This line ensures params passed to LightningModule will be saved to ckpt
    # it also allows to access params with 'self.hparams' attribute
    # Note: we send them to the loggers with `logger.log_hyperparams` in `setup_loggers`
    self.save_hyperparameters(logger=False)

    self.env = env
    self.policy = policy

    self.instantiate_metrics(metrics)
    self.log_on_step = log_on_step

    self.data_cfg = {
        "batch_size": batch_size,
        "val_batch_size": val_batch_size,
        "test_batch_size": test_batch_size,
        "generate_default_data": generate_default_data,
        "data_dir": data_dir,
        "train_data_size": train_data_size,
        "val_data_size": val_data_size,
        "test_data_size": test_data_size,
    }

    self._optimizer_name_or_cls: str | torch.optim.Optimizer = optimizer
    self.optimizer_kwargs: dict = optimizer_kwargs
    self._lr_scheduler_name_or_cls: str | torch.optim.lr_scheduler.LRScheduler = (
        lr_scheduler
    )
    self.lr_scheduler_kwargs: dict = lr_scheduler_kwargs
    self.lr_scheduler_interval: str = lr_scheduler_interval
    self.lr_scheduler_monitor: str = lr_scheduler_monitor

    self.shuffle_train_dataloader = shuffle_train_dataloader
    self.dataloader_num_workers = dataloader_num_workers

instantiate_metrics

instantiate_metrics(metrics: dict)

Dictionary of metrics to be logged at each phase

Source code in rl4co/models/rl/common/base.py
def instantiate_metrics(self, metrics: dict):
    """Dictionary of metrics to be logged at each phase"""

    if not metrics:
        log.info("No metrics specified, using default")
    self.train_metrics = metrics.get("train", ["loss", "reward"])
    self.val_metrics = metrics.get("val", ["reward"])
    self.test_metrics = metrics.get("test", ["reward"])
    self.log_on_step = metrics.get("log_on_step", True)
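
The defaults applied by instantiate_metrics can be reproduced with plain dictionaries. The sketch below is a standalone re-implementation for illustration, not the library function. It also mirrors the order of operations in __init__, which assigns self.log_on_step = log_on_step after calling instantiate_metrics, so the constructor argument wins over a "log_on_step" key in metrics. The metric name "max_reward" is only an illustrative value.

```python
def resolve_metrics(metrics: dict, log_on_step: bool = True) -> dict:
    """Standalone sketch of the defaults used by instantiate_metrics."""
    resolved = {
        "train": metrics.get("train", ["loss", "reward"]),
        "val": metrics.get("val", ["reward"]),
        "test": metrics.get("test", ["reward"]),
        # instantiate_metrics reads this key first...
        "log_on_step": metrics.get("log_on_step", True),
    }
    # ...but __init__ then overwrites it with its own log_on_step argument
    resolved["log_on_step"] = log_on_step
    return resolved


cfg = resolve_metrics({"val": ["reward", "max_reward"], "log_on_step": False})
print(cfg["train"], cfg["val"], cfg["log_on_step"])
```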

setup

setup(stage='fit')

Base LightningModule setup method. This will set up the datasets and dataloaders

Note

We also send to the loggers all hyperparameters that are not nn.Module instances (so the policy is excluded). Apparently PyTorch Lightning does not do this by default.

Source code in rl4co/models/rl/common/base.py
def setup(self, stage="fit"):
    """Base LightningModule setup method. This will set up the datasets and dataloaders

    Note:
        We also send to the loggers all hyperparams that are not `nn.Module` (i.e. the policy).
        Apparently PyTorch Lightning does not do this by default.
    """

    log.info("Setting up batch sizes for train/val/test")
    train_bs, val_bs, test_bs = (
        self.data_cfg["batch_size"],
        self.data_cfg["val_batch_size"],
        self.data_cfg["test_batch_size"],
    )
    self.train_batch_size = train_bs
    self.val_batch_size = train_bs if val_bs is None else val_bs
    self.test_batch_size = self.val_batch_size if test_bs is None else test_bs

    if self.data_cfg["generate_default_data"]:
        log.info(
            "Generating default datasets. If found, they will not be overwritten"
        )
        generate_default_datasets(data_dir=self.data_cfg["data_dir"])

    log.info("Setting up datasets")
    self.train_dataset = self.wrap_dataset(
        self.env.dataset(self.data_cfg["train_data_size"], phase="train")
    )
    self.val_dataset = self.env.dataset(self.data_cfg["val_data_size"], phase="val")
    self.test_dataset = self.env.dataset(
        self.data_cfg["test_data_size"], phase="test"
    )
    self.dataloader_names = None
    self.setup_loggers()
    self.post_setup_hook()
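
The batch size fallback chain in setup (validation falls back to the training batch size, testing falls back to the resolved validation batch size) can be isolated as a pure function. This is an illustrative sketch; resolve_batch_sizes is a hypothetical helper, not part of RL4CO.

```python
from typing import Optional, Union

BatchSize = Union[int, list[int]]


def resolve_batch_sizes(
    batch_size: int,
    val_batch_size: Optional[BatchSize] = None,
    test_batch_size: Optional[BatchSize] = None,
) -> tuple[int, BatchSize, BatchSize]:
    """Mirror the fallback chain in setup: val -> train, test -> val."""
    train_bs = batch_size
    val_bs = train_bs if val_batch_size is None else val_batch_size
    test_bs = val_bs if test_batch_size is None else test_batch_size
    return train_bs, val_bs, test_bs


print(resolve_batch_sizes(512))
print(resolve_batch_sizes(512, val_batch_size=[256, 128]))  # one size per val dataset
print(resolve_batch_sizes(512, test_batch_size=64))
```

Note that a list given for val_batch_size is inherited by testing when test_batch_size is None.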

setup_loggers

setup_loggers()

Log all hyperparameters except nn.Module instances (such as the policy)

Source code in rl4co/models/rl/common/base.py
def setup_loggers(self):
    """Log all hyperparameters except `nn.Module` instances"""
    if self.loggers is not None:
        hparams_save = {
            k: v for k, v in self.hparams.items() if not isinstance(v, nn.Module)
        }
        for logger in self.loggers:
            logger.log_hyperparams(hparams_save)
            logger.log_graph(self)
            logger.save()
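
The filtering step in setup_loggers is a dictionary comprehension on the value type. The sketch below reproduces it without PyTorch; Module and Policy are hypothetical stand-ins for torch.nn.Module and a policy network, and the hyperparameter values are illustrative.

```python
class Module:
    """Hypothetical stand-in for torch.nn.Module so the sketch runs without PyTorch."""


class Policy(Module):
    """Hypothetical policy network."""


def hparams_to_log(hparams: dict) -> dict:
    """Same filter as setup_loggers: keep every hyperparameter that is not a module."""
    return {k: v for k, v in hparams.items() if not isinstance(v, Module)}


hparams = {"policy": Policy(), "batch_size": 512, "optimizer": "Adam"}
print(hparams_to_log(hparams))
```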

post_setup_hook

post_setup_hook()

Hook to be called after setup. Can be used to set up subclasses without overriding setup

Source code in rl4co/models/rl/common/base.py
def post_setup_hook(self):
    """Hook to be called after setup. Can be used to set up subclasses without overriding `setup`"""
    pass

configure_optimizers

configure_optimizers(parameters=None)

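Instantiate the optimizer and, if lr_scheduler is set, the learning rate scheduler. Returns the optimizer alone when no scheduler is configured; otherwise returns ([optimizer], scheduler_config), where scheduler_config holds the scheduler, interval and monitor.

Parameters: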

  • parameters

    parameters to be optimized. If None, will use self.parameters(), i.e. all parameters

Source code in rl4co/models/rl/common/base.py
def configure_optimizers(self, parameters=None):
    """
    Args:
        parameters: parameters to be optimized. If None, will use `self.parameters()`, i.e. all parameters
    """

    if parameters is None:
        parameters = self.parameters()

    log.info(f"Instantiating optimizer <{self._optimizer_name_or_cls}>")
    if isinstance(self._optimizer_name_or_cls, str):
        optimizer = create_optimizer(
            parameters, self._optimizer_name_or_cls, **self.optimizer_kwargs
        )
    elif isinstance(self._optimizer_name_or_cls, partial):
        optimizer = self._optimizer_name_or_cls(parameters, **self.optimizer_kwargs)
    else:  # User-defined optimizer
        opt_cls = self._optimizer_name_or_cls
        optimizer = opt_cls(parameters, **self.optimizer_kwargs)
        assert isinstance(optimizer, torch.optim.Optimizer)

    # instantiate lr scheduler
    if self._lr_scheduler_name_or_cls is None:
        return optimizer
    else:
        log.info(f"Instantiating LR scheduler <{self._lr_scheduler_name_or_cls}>")
        if isinstance(self._lr_scheduler_name_or_cls, str):
            scheduler = create_scheduler(
                optimizer, self._lr_scheduler_name_or_cls, **self.lr_scheduler_kwargs
            )
        elif isinstance(self._lr_scheduler_name_or_cls, partial):
            scheduler = self._lr_scheduler_name_or_cls(
                optimizer, **self.lr_scheduler_kwargs
            )
        else:  # User-defined scheduler
            scheduler_cls = self._lr_scheduler_name_or_cls
            scheduler = scheduler_cls(optimizer, **self.lr_scheduler_kwargs)
            assert isinstance(scheduler, torch.optim.lr_scheduler.LRScheduler)
        return [optimizer], {
            "scheduler": scheduler,
            "interval": self.lr_scheduler_interval,
            "monitor": self.lr_scheduler_monitor,
        }
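
The three accepted forms for optimizer and lr_scheduler (a name, a functools.partial, or a class) and the two possible return shapes can be sketched without PyTorch. DummyOptimizer, DummyScheduler and the OPTIMIZERS/SCHEDULERS lookup tables are hypothetical stand-ins for torch classes and for create_optimizer/create_scheduler; the partial and class branches are merged here because they are called the same way.

```python
from functools import partial


class DummyOptimizer:
    """Hypothetical stand-in for a torch.optim.Optimizer subclass."""

    def __init__(self, params, lr=1e-3):
        self.params = list(params)
        self.lr = lr


class DummyScheduler:
    """Hypothetical stand-in for a learning rate scheduler."""

    def __init__(self, optimizer, gamma=0.1):
        self.optimizer = optimizer
        self.gamma = gamma


# Hypothetical lookup tables standing in for create_optimizer / create_scheduler
OPTIMIZERS = {"dummy": DummyOptimizer}
SCHEDULERS = {"dummy": DummyScheduler}


def instantiate(name_or_cls, registry, first_arg, **kwargs):
    """Build an object from a registered name, a functools.partial, or a class."""
    if isinstance(name_or_cls, str):
        return registry[name_or_cls](first_arg, **kwargs)
    # partial or class: called identically; call-time kwargs override the partial's keywords
    return name_or_cls(first_arg, **kwargs)


def configure(params, optimizer="dummy", optimizer_kwargs=None,
              lr_scheduler=None, lr_scheduler_kwargs=None,
              interval="epoch", monitor="val/reward"):
    opt = instantiate(optimizer, OPTIMIZERS, params, **(optimizer_kwargs or {}))
    if lr_scheduler is None:
        return opt  # no scheduler: return the optimizer alone
    sched = instantiate(lr_scheduler, SCHEDULERS, opt, **(lr_scheduler_kwargs or {}))
    return [opt], {"scheduler": sched, "interval": interval, "monitor": monitor}


opt = configure([1.0, 2.0], optimizer_kwargs={"lr": 1e-4})
opts, sched_cfg = configure(
    [1.0, 2.0],
    optimizer=partial(DummyOptimizer, lr=5e-4),
    lr_scheduler=DummyScheduler,
    lr_scheduler_kwargs={"gamma": 0.5},
)
```

One consequence of using functools.partial: if optimizer_kwargs also contains a key already bound in the partial (such as lr), the value from optimizer_kwargs takes precedence.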

log_metrics

log_metrics(
    metric_dict: dict,
    phase: str,
    dataloader_idx: int | None = None,
)

Log metrics to logger and progress bar

Source code in rl4co/models/rl/common/base.py
def log_metrics(
    self, metric_dict: dict, phase: str, dataloader_idx: int | None = None
):
    """Log metrics to logger and progress bar"""
    metrics = getattr(self, f"{phase}_metrics")
    dataloader_name = ""
    if dataloader_idx is not None and self.dataloader_names is not None:
        dataloader_name = "/" + self.dataloader_names[dataloader_idx]
    metrics = {
        f"{phase}/{k}{dataloader_name}": (
            v.mean() if isinstance(v, torch.Tensor) else v
        )
        for k, v in metric_dict.items()
        if k in metrics
    }
    log_on_step = self.log_on_step if phase == "train" else False
    on_epoch = False if phase == "train" else True
    self.log_dict(
        metrics,
        on_step=log_on_step,
        on_epoch=on_epoch,
        prog_bar=True,
        sync_dist=True,
        add_dataloader_idx=False,  # we add manually above
    )
    return metrics
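
The key naming, filtering and per-phase logging flags in log_metrics can be checked in isolation. In this sketch, plain lists stand in for tensors (statistics.mean plays the role of tensor.mean()), and the dataloader names "tsp20" and "tsp50" are illustrative.

```python
from statistics import mean


def format_metrics(metric_dict, phase, tracked, dataloader_idx=None, dataloader_names=None):
    """Build '{phase}/{metric}[/{dataloader}]' keys for tracked metrics only."""
    suffix = ""
    if dataloader_idx is not None and dataloader_names is not None:
        suffix = "/" + dataloader_names[dataloader_idx]
    return {
        f"{phase}/{k}{suffix}": (mean(v) if isinstance(v, list) else v)
        for k, v in metric_dict.items()
        if k in tracked
    }


def log_flags(phase, log_on_step=True):
    """(on_step, on_epoch): per-step logging only for training, per-epoch otherwise."""
    on_step = log_on_step if phase == "train" else False
    on_epoch = phase != "train"
    return on_step, on_epoch


print(format_metrics({"reward": [1.0, 2.0, 3.0], "loss": 0.5}, "val", ["reward"], 1, ["tsp20", "tsp50"]))
print(log_flags("train"), log_flags("val"))
```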

forward

forward(td, **kwargs)

Forward pass for the model. Simple wrapper around policy. Uses env from the module if not provided.

Source code in rl4co/models/rl/common/base.py
def forward(self, td, **kwargs):
    """Forward pass for the model. Simple wrapper around `policy`. Uses `env` from the module if not provided."""
    if kwargs.get("env", None) is None:
        env = self.env
    else:
        log.info("Using env from kwargs")
        env = kwargs.pop("env")
    return self.policy(td, env, **kwargs)
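
The env selection rule in forward can be isolated as a small helper. This is a hedged variant rather than a copy: select_env is hypothetical, and it pops "env" unconditionally, whereas the listing above only pops it when its value is not None. With the listing as written, an explicit env=None stays in kwargs and is forwarded to self.policy alongside the positional env, which may clash with the policy's signature. The "decode_type" kwarg is only an illustrative pass-through value.

```python
def select_env(default_env, kwargs: dict):
    """Return (env, remaining_kwargs) without mutating the input dict."""
    kwargs = dict(kwargs)
    env = kwargs.pop("env", None)  # always remove the key, even when its value is None
    if env is None:
        env = default_env  # fall back to the module's own env
    return env, kwargs


print(select_env("module_env", {"env": None, "decode_type": "greedy"}))
print(select_env("module_env", {"env": "other_env"}))
```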

shared_step

shared_step(
    batch: Any, batch_idx: int, phase: str, **kwargs
)

Shared step between train/val/test. Must be implemented by subclasses.

Source code in rl4co/models/rl/common/base.py
def shared_step(self, batch: Any, batch_idx: int, phase: str, **kwargs):
    """Shared step between train/val/test. To be implemented in subclass"""
    raise NotImplementedError("Shared step is required to be implemented in subclass")
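
A subclass overrides shared_step to compute its algorithm-specific loss. The toy sketch below uses plain Python in place of Lightning and TensorDict: BaseModule and ToyReinforce are hypothetical, the batch fields "reward" and "log_likelihood" are assumed to be precomputed lists, and the batch-mean baseline is a simple choice for illustration, not necessarily what RL4CO's REINFORCE uses. The returned keys match the default training metrics ["loss", "reward"].

```python
from statistics import mean


class BaseModule:
    """Hypothetical stand-in for the shared_step contract of RL4COLitModule."""

    def shared_step(self, batch, batch_idx, phase, **kwargs):
        raise NotImplementedError("Shared step is required to be implemented in subclass")


class ToyReinforce(BaseModule):
    """Toy REINFORCE-style loss with a batch-mean baseline (assumption for illustration)."""

    def shared_step(self, batch, batch_idx, phase, **kwargs):
        rewards = batch["reward"]
        log_likelihoods = batch["log_likelihood"]
        baseline = mean(rewards)
        advantages = [r - baseline for r in rewards]
        # Policy-gradient loss: minimize -E[advantage * log-likelihood]
        loss = -mean(a * ll for a, ll in zip(advantages, log_likelihoods))
        return {"loss": loss, "reward": mean(rewards)}


out = ToyReinforce().shared_step({"reward": [1.0, 3.0], "log_likelihood": [-0.5, -1.0]}, 0, "train")
print(out)
```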

on_train_epoch_end

on_train_epoch_end()

Called at the end of the training epoch. This can be used for instance to update the train dataset with new data (which is the case in RL).

Source code in rl4co/models/rl/common/base.py
def on_train_epoch_end(self):
    """Called at the end of the training epoch. This can be used for instance to update the train dataset
    with new data (which is the case in RL).
    """
    # Generate a fresh training dataset for the next epoch
    # If last epoch, we don't need to update since we will not use the dataset anymore
    if self.current_epoch < self.trainer.max_epochs - 1:
        log.info("Generating training dataset for next epoch...")
        train_dataset = self.env.dataset(self.data_cfg["train_data_size"], "train")
        self.train_dataset = self.wrap_dataset(train_dataset)
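
The regeneration schedule in on_train_epoch_end (a new dataset after every epoch except the last) can be simulated with a plain loop. make_train_dataset and simulate_training are hypothetical; the former stands in for env.dataset(size, "train"), and the per-epoch seed is only there to make each dataset different.

```python
import random


def make_train_dataset(size: int, seed: int) -> list[float]:
    """Hypothetical stand-in for env.dataset(size, "train")."""
    rng = random.Random(seed)
    return [rng.random() for _ in range(size)]


def simulate_training(max_epochs: int, train_data_size: int, wrap=lambda ds: ds):
    """Return the epochs after which a new training dataset is generated."""
    dataset = wrap(make_train_dataset(train_data_size, seed=0))  # built in setup
    regenerated_after = []
    for current_epoch in range(max_epochs):
        # ... one epoch of training on `dataset` would run here ...
        # on_train_epoch_end: skip regeneration after the final epoch
        if current_epoch < max_epochs - 1:
            dataset = wrap(make_train_dataset(train_data_size, seed=current_epoch + 1))
            regenerated_after.append(current_epoch)
    return regenerated_after


print(simulate_training(max_epochs=3, train_data_size=8))
```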

wrap_dataset

wrap_dataset(dataset)

Wrap dataset with a policy-specific wrapper. This is useful, e.g., in REINFORCE, where we need to collect the greedy rollout baseline outputs.

Source code in rl4co/models/rl/common/base.py
def wrap_dataset(self, dataset):
    """Wrap dataset with a policy-specific wrapper. This is useful, e.g., in REINFORCE, where we need to
    collect the greedy rollout baseline outputs.
    """
    return dataset
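
An override of wrap_dataset can attach precomputed values to every instance, which is the pattern described for the greedy rollout baseline. In this sketch, greedy_baseline and BaselineWrappedDataset are hypothetical: the "baseline" is just the negated sum of an instance's values, standing in for the reward of a greedy rollout.

```python
def greedy_baseline(instance: list[float]) -> float:
    """Hypothetical stand-in for the reward of a greedy rollout on one instance."""
    return -sum(instance)


class BaselineWrappedDataset:
    """Pairs each instance with a precomputed baseline value."""

    def __init__(self, dataset, baseline_values):
        assert len(dataset) == len(baseline_values)
        self.dataset = dataset
        self.baseline_values = baseline_values

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx], self.baseline_values[idx]


def wrap_dataset(dataset):
    # Compute baseline values once per dataset, before training on it
    return BaselineWrappedDataset(dataset, [greedy_baseline(x) for x in dataset])


ds = wrap_dataset([[1.0, 2.0], [0.5]])
print(len(ds), ds[0], ds[1])
```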

Transductive Learning

Transductive models are learning algorithms that optimize on specific instances. They improve solutions by updating the policy parameters \(\theta\), which means that optimization (backpropagation) runs at test time. Transductive learning can be applied to different policies: for example, EAS (Efficient Active Search) updates a part of the parameters of autoregressive (AR) policies to obtain better solutions, and other policies could in principle be optimized at test time in the same way.

Tip

You may refer to the definition of inductive vs. transductive RL. In inductive RL, we train to generalize to new instances; in transductive RL, we train (or fine-tune) to solve only specific ones.

Classes:

  • TransductiveModel

    Base class for transductive algorithms (i.e. algorithms that optimize policy parameters for specific instances)

TransductiveModel

TransductiveModel(
    env,
    policy,
    dataset: Dataset | str,
    batch_size: int = 1,
    max_iters: int = 100,
    max_runtime: Optional[int] = 86400,
    save_path: Optional[str] = None,
    **kwargs
)

Bases: RL4COLitModule

Base class for transductive algorithms (i.e. algorithms that optimize policy parameters for specific instances, see https://en.wikipedia.org/wiki/Transduction_(machine_learning)). Transductive algorithms are used online to find better solutions for a given dataset: given a policy, they improve (a part of) its parameters so that the policy performs better on that dataset.

Note

By default, we use manual optimization to handle the search.

Parameters:

  • env

    RL4CO environment

  • policy

    policy network

  • dataset (Dataset | str) –

    dataset to optimize on, or a path to a dataset file (loaded with env.dataset in setup)

  • batch_size (int, default: 1 ) –

    batch size

  • max_iters (int, default: 100 ) –

    maximum number of iterations

  • max_runtime (Optional[int], default: 86400 ) –

    maximum runtime in seconds

  • save_path (Optional[str], default: None ) –

    path to save the model

  • **kwargs

    additional keyword arguments passed to RL4COLitModule

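Methods:

  • setup

    Set up the dataset and attributes

  • on_train_batch_start

    Called before training (i.e. search) for a new batch begins

  • training_step

    Main search loop. Must be implemented by subclasses

  • on_train_batch_end

    Called when the train batch ends

  • on_train_epoch_end

    Called at the end of the training epoch

  • validation_step

    Not used during search

  • test_step

    Not used during search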

Source code in rl4co/models/common/transductive/base.py
def __init__(
    self,
    env,
    policy,
    dataset: Dataset | str,
    batch_size: int = 1,
    max_iters: int = 100,
    max_runtime: Optional[int] = 86_400,
    save_path: Optional[str] = None,
    **kwargs,
):
    self.save_hyperparameters(logger=False)
    super().__init__(env, policy, **kwargs)
    self.dataset = dataset
    self.automatic_optimization = False  # we optimize manually

setup

setup(stage='fit')

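Set up the dataset and attributes. The RL4COLitModule base class automatically loads the data; here, the given dataset (loaded from file if a path is given) is used for the train, validation and test splits, each with batch_size.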

Source code in rl4co/models/common/transductive/base.py
def setup(self, stage="fit"):
    """Setup the dataset and attributes.
    The RL4COLitModule base class automatically loads the data.
    """
    if isinstance(self.dataset, str):
        # load from file
        self.dataset = self.env.dataset(filename=self.dataset)

    # Set all datasets and batch size as the same
    for split in ["train", "val", "test"]:
        setattr(self, f"{split}_dataset", self.dataset)
        setattr(self, f"{split}_batch_size", self.hparams.batch_size)

    # Setup loggers
    self.setup_loggers()
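
The effect of the transductive setup (one dataset object and one batch size shared by all three splits) can be reproduced on a plain namespace. transductive_setup and load_dataset are hypothetical; load_dataset stands in for env.dataset(filename=...), and "instances.npz" is an illustrative path.

```python
from types import SimpleNamespace


def load_dataset(path: str) -> list:
    """Hypothetical stand-in for env.dataset(filename=path)."""
    return [f"{path}#{i}" for i in range(3)]


def transductive_setup(module, dataset, batch_size: int):
    if isinstance(dataset, str):
        dataset = load_dataset(dataset)  # load from file
    module.dataset = dataset
    # Same object and batch size for every split
    for split in ["train", "val", "test"]:
        setattr(module, f"{split}_dataset", dataset)
        setattr(module, f"{split}_batch_size", batch_size)
    return module


m = transductive_setup(SimpleNamespace(), "instances.npz", batch_size=1)
print(m.dataset, m.test_batch_size)
```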

on_train_batch_start

on_train_batch_start(batch: Any, batch_idx: int)

Called before training (i.e. search) for a new batch begins. This can be used to perform changes to the model or optimizer at the start of each batch.

Source code in rl4co/models/common/transductive/base.py
def on_train_batch_start(self, batch: Any, batch_idx: int):
    """Called before training (i.e. search) for a new batch begins.
    This can be used to perform changes to the model or optimizer at the start of each batch.
    """
    pass  # Implement in subclass

training_step (abstract method)

training_step(batch, batch_idx)

Main search loop. We use the training step to effectively adapt to a batch of instances.

Source code in rl4co/models/common/transductive/base.py
@abc.abstractmethod
def training_step(self, batch, batch_idx):
    """Main search loop. We use the training step to effectively adapt to a `batch` of instances."""
    raise NotImplementedError("Implement in subclass")
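
The listing does not show how max_iters and max_runtime are consumed; one plausible pattern for a subclass's search loop is to stop at whichever limit is reached first. The sketch below assumes exactly that. search is hypothetical, and step is a hypothetical callable that would perform one manual optimization update on the batch and return the resulting reward.

```python
import time
from typing import Callable, Optional


def search(step: Callable[[int], float], max_iters: int = 100,
           max_runtime: Optional[float] = 86_400) -> tuple[int, float]:
    """Run step() until max_iters iterations or max_runtime seconds, whichever comes first.

    Returns the number of iterations run and the best (highest) reward seen.
    """
    start = time.monotonic()
    best = float("-inf")
    iters = 0
    for i in range(max_iters):
        if max_runtime is not None and time.monotonic() - start > max_runtime:
            break  # time budget exhausted
        best = max(best, step(i))  # one update, keep the best reward so far
        iters += 1
    return iters, best


print(search(lambda i: -abs(i - 3), max_iters=10))
```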

on_train_batch_end

on_train_batch_end(
    outputs: STEP_OUTPUT, batch: Any, batch_idx: int
) -> None

Called when the train batch ends. This can be used for instance for logging or clearing cache.

Source code in rl4co/models/common/transductive/base.py
def on_train_batch_end(
    self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int
) -> None:
    """Called when the train batch ends. This can be used for
    instance for logging or clearing cache.
    """
    pass  # Implement in subclass

on_train_epoch_end

on_train_epoch_end() -> None

Called at the end of the training epoch.

Source code in rl4co/models/common/transductive/base.py
def on_train_epoch_end(self) -> None:
    """Called at the end of the training epoch."""
    pass  # Implement in subclass

validation_step

validation_step(batch: Any, batch_idx: int)

Not used during search

Source code in rl4co/models/common/transductive/base.py
def validation_step(self, batch: Any, batch_idx: int):
    """Not used during search"""
    pass

test_step

test_step(batch: Any, batch_idx: int)

Not used during search

Source code in rl4co/models/common/transductive/base.py
def test_step(self, batch: Any, batch_idx: int):
    """Not used during search"""
    pass