RL4COLitModule
The RL4COLitModule is a wrapper around PyTorch Lightning's LightningModule that provides additional functionality for RL algorithms. It is the parent class for all RL algorithms in the library.
RL4COLitModule
RL4COLitModule(
env: RL4COEnvBase,
policy: Module,
batch_size: int = 512,
val_batch_size: list[int] | int = None,
test_batch_size: list[int] | int = None,
train_data_size: int = 100000,
val_data_size: int = 10000,
test_data_size: int = 10000,
optimizer: str | Optimizer | partial = "Adam",
optimizer_kwargs: dict = {"lr": 0.0001},
lr_scheduler: str | LRScheduler | partial = None,
lr_scheduler_kwargs: dict = {
"milestones": [80, 95],
"gamma": 0.1,
},
lr_scheduler_interval: str = "epoch",
lr_scheduler_monitor: str = "val/reward",
generate_default_data: bool = False,
shuffle_train_dataloader: bool = False,
dataloader_num_workers: int = 0,
data_dir: str = "data/",
log_on_step: bool = True,
metrics: dict = {},
**litmodule_kwargs
)
Bases: LightningModule
Base class for Lightning modules for RL4CO. This defines the general training loop in terms of RL algorithms. Subclasses mainly need to implement shared_step, which defines the specific loss functions and optimization routines.
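The split described above (the base class drives train/val/test, subclasses supply shared_step) can be illustrated with a library-free sketch. The classes below are hypothetical stand-ins, not RL4CO or Lightning classes, and we assume for illustration that each phase simply forwards to shared_step with a phase string.

```python
from typing import Any


class BaseLitModuleSketch:
    """Hypothetical stand-in for a base module: every phase delegates to shared_step."""

    def shared_step(self, batch: Any, batch_idx: int, phase: str) -> dict:
        # Subclasses must provide the algorithm-specific loss computation
        raise NotImplementedError("shared_step must be implemented in a subclass")

    def training_step(self, batch: Any, batch_idx: int) -> dict:
        return self.shared_step(batch, batch_idx, phase="train")

    def validation_step(self, batch: Any, batch_idx: int) -> dict:
        return self.shared_step(batch, batch_idx, phase="val")

    def test_step(self, batch: Any, batch_idx: int) -> dict:
        return self.shared_step(batch, batch_idx, phase="test")


class ToyAlgorithm(BaseLitModuleSketch):
    """Hypothetical subclass: 'reward' is the batch mean, 'loss' is its negation."""

    def shared_step(self, batch, batch_idx, phase):
        reward = sum(batch) / len(batch)
        out = {"reward": reward}
        if phase == "train":
            # Only the training phase needs a loss to optimize
            out["loss"] = -reward
        return out


algo = ToyAlgorithm()
print(algo.training_step([1.0, 2.0, 3.0], 0))  # {'reward': 2.0, 'loss': -2.0}
print(algo.validation_step([4.0], 0))          # {'reward': 4.0}
```

Keeping the phase logic in one method means an algorithm only has to describe its loss once; the phase argument lets it skip work (such as computing a loss) that evaluation does not need.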
Parameters:
- env (RL4COEnvBase) – RL4CO environment
- policy (Module) – policy network (actor)
- batch_size (int, default: 512) – batch size (the general one, used by default for training)
- val_batch_size (list[int] | int, default: None) – specific batch size for validation. If None, batch_size is used. If a list, one batch size is used for each validation dataset.
- test_batch_size (list[int] | int, default: None) – specific batch size for testing. If None, val_batch_size is used. If a list, one batch size is used for each test dataset.
- train_data_size (int, default: 100000) – size of the training dataset for one epoch
- val_data_size (int, default: 10000) – size of the validation dataset for one epoch
- test_data_size (int, default: 10000) – size of the testing dataset for one epoch
- optimizer (str | Optimizer | partial, default: 'Adam') – optimizer or optimizer name
- optimizer_kwargs (dict, default: {'lr': 0.0001}) – optimizer keyword arguments
- lr_scheduler (str | LRScheduler | partial, default: None) – learning rate scheduler or learning rate scheduler name
- lr_scheduler_kwargs (dict, default: {'milestones': [80, 95], 'gamma': 0.1}) – learning rate scheduler keyword arguments
- lr_scheduler_interval (str, default: 'epoch') – interval at which the learning rate scheduler is stepped (e.g. 'epoch' or 'step')
- lr_scheduler_monitor (str, default: 'val/reward') – metric monitored by the learning rate scheduler
- generate_default_data (bool, default: False) – whether to generate default datasets, filling up the data directory
- shuffle_train_dataloader (bool, default: False) – whether to shuffle the training dataloader. Defaults to False since the dataset is recreated every epoch.
- dataloader_num_workers (int, default: 0) – number of workers for the dataloader
- data_dir (str, default: 'data/') – data directory
- log_on_step (bool, default: True) – whether to log metrics at every step
- metrics (dict, default: {}) – metrics to be logged at each phase (see instantiate_metrics)
- litmodule_kwargs – kwargs for LightningModule
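The batch-size fallbacks described above chain together: validation falls back to batch_size, and testing falls back to the resolved validation batch size. The helper below is a hypothetical, library-free sketch of those documented rules only; resolve_batch_sizes is not an RL4CO function.

```python
from typing import Optional, Union

BatchSize = Union[int, list[int]]


def resolve_batch_sizes(
    batch_size: int,
    val_batch_size: Optional[BatchSize] = None,
    test_batch_size: Optional[BatchSize] = None,
) -> tuple[int, BatchSize, BatchSize]:
    """Hypothetical helper mirroring the documented fallback rules."""
    # Validation falls back to the general (training) batch size
    val_bs = batch_size if val_batch_size is None else val_batch_size
    # Testing falls back to the *validation* batch size, not the training one
    test_bs = val_bs if test_batch_size is None else test_batch_size
    return batch_size, val_bs, test_bs


print(resolve_batch_sizes(512))                       # (512, 512, 512)
print(resolve_batch_sizes(512, val_batch_size=1024))  # (512, 1024, 1024)
print(resolve_batch_sizes(512, [256, 128]))           # (512, [256, 128], [256, 128])
```

A list value (as in the last call) is meant to be matched one-to-one with multiple validation or test datasets.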
Methods:
- instantiate_metrics – Instantiate the dictionary of metrics to be logged at each phase
- setup – Base LightningModule setup method. This sets up the datasets and dataloaders
- setup_loggers – Log all hyperparameters except nn.Module instances (such as the policy)
- post_setup_hook – Hook to be called after setup. Can be used to set up subclasses without overriding setup
- configure_optimizers – Configure the optimizer and, optionally, the learning rate scheduler
- log_metrics – Log metrics to logger and progress bar
- forward – Forward pass for the model. Simple wrapper around policy. Uses env from the module if not provided.
- shared_step – Shared step between train/val/test. To be implemented in subclasses
- on_train_epoch_end – Called at the end of the training epoch. This can be used, for instance, to update the train dataset
- wrap_dataset – Wrap dataset with policy-specific wrapper. This is useful, e.g., in REINFORCE, where we need to collect the greedy rollout baseline outputs.
Source code in rl4co/models/rl/common/base.py
instantiate_metrics
instantiate_metrics(metrics: dict)
Instantiate the dictionary of metrics to be logged at each phase (train, val, test).
Source code in rl4co/models/rl/common/base.py
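This page does not spell out the layout of the metrics dict. The sketch below assumes it is keyed by phase, with a list of metric names per phase, and that phases left out fall back to defaults. Both the layout and the default lists are assumptions for illustration, and instantiate_metrics_sketch and max_reward are hypothetical names; check the source for the actual defaults.

```python
def instantiate_metrics_sketch(metrics: dict) -> dict[str, list[str]]:
    """Hypothetical: resolve per-phase metric names, falling back to assumed defaults."""
    # Assumed defaults, for illustration only
    defaults = {"train": ["loss", "reward"], "val": ["reward"], "test": ["reward"]}
    return {phase: list(metrics.get(phase, names)) for phase, names in defaults.items()}


resolved = instantiate_metrics_sketch({"val": ["reward", "max_reward"]})
print(resolved["train"])  # ['loss', 'reward']
print(resolved["val"])    # ['reward', 'max_reward']
```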
setup
setup(stage='fit')
Base LightningModule setup method. This sets up the datasets and dataloaders.
Note
We also send to the loggers all hyperparameters that are not nn.Module instances (i.e. the policy), since PyTorch Lightning apparently does not do this by default.
Source code in rl4co/models/rl/common/base.py
setup_loggers
setup_loggers()
Log all hyperparameters except nn.Module instances (such as the policy).
Source code in rl4co/models/rl/common/base.py
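The filtering idea behind setup_loggers (log plain hyperparameters, skip network modules) can be sketched without torch. The Module class below is a hypothetical placeholder for torch.nn.Module, and loggable_hparams is not an RL4CO function; in real code you would pass torch.nn.Module as the excluded type.

```python
from typing import Any


class Module:
    """Hypothetical stand-in for torch.nn.Module (torch is not imported here)."""


def loggable_hparams(hparams: dict[str, Any], exclude: type = Module) -> dict[str, Any]:
    """Keep every hyperparameter that is not an instance of `exclude`."""
    return {k: v for k, v in hparams.items() if not isinstance(v, exclude)}


hparams = {"batch_size": 512, "optimizer": "Adam", "policy": Module()}
print(loggable_hparams(hparams))  # {'batch_size': 512, 'optimizer': 'Adam'}
```

Skipping modules keeps the logged configuration small and serializable, since a network object has no meaningful scalar representation for a hyperparameter table.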
post_setup_hook
post_setup_hook()
Hook to be called after setup. Can be used to set up subclasses without overriding setup
Source code in rl4co/models/rl/common/base.py
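The hook pattern lets a subclass add its own initialization while still reusing the base setup. The sketch below assumes the hook runs at the very end of setup, after datasets and loggers are prepared; BaseSketch and WithExtraState are hypothetical classes, and the recorded call names only stand in for the real work.

```python
class BaseSketch:
    """Hypothetical base class: setup() runs shared work, then calls the hook."""

    def __init__(self) -> None:
        self.calls: list[str] = []

    def setup(self, stage: str = "fit") -> None:
        self.calls.append("datasets")  # stands in for dataset/dataloader creation
        self.calls.append("loggers")   # stands in for hyperparameter logging
        self.post_setup_hook()

    def post_setup_hook(self) -> None:
        """No-op by default; subclasses override this instead of setup()."""


class WithExtraState(BaseSketch):
    def post_setup_hook(self) -> None:
        # Subclass-specific initialization, e.g. building a baseline (hypothetical)
        self.calls.append("subclass")


m = WithExtraState()
m.setup()
print(m.calls)  # ['datasets', 'loggers', 'subclass']
```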
configure_optimizers
configure_optimizers(parameters=None)
Configure the optimizer and, optionally, the learning rate scheduler.
Parameters:
- parameters – parameters to be optimized. If None, self.parameters() is used, i.e. all parameters
Source code in rl4co/models/rl/common/base.py
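The default lr_scheduler_kwargs ({'milestones': [80, 95], 'gamma': 0.1}) match the arguments of PyTorch's torch.optim.lr_scheduler.MultiStepLR. Note that lr_scheduler defaults to None, so no scheduler is used unless one is set. Assuming MultiStepLR is selected, stepped per epoch, with the default lr of 0.0001, the learning rate drops tenfold at epoch 80 and again at epoch 95. The pure-Python function below reproduces that closed-form rule for illustration; it is not part of RL4CO.

```python
from bisect import bisect_right


def multistep_lr(base_lr: float, epoch: int, milestones: list[int], gamma: float) -> float:
    """Learning rate at `epoch` under a MultiStepLR-style rule."""
    # Number of milestones already reached (milestones are assumed sorted)
    decays = bisect_right(milestones, epoch)
    return base_lr * gamma**decays


for epoch in (0, 79, 80, 94, 95, 99):
    print(epoch, multistep_lr(1e-4, epoch, [80, 95], 0.1))
```

Epochs 0 to 79 use 1e-4, epochs 80 to 94 use roughly 1e-5, and epochs from 95 onward use roughly 1e-6 (up to floating-point rounding).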
log_metrics
Log metrics to logger and progress bar
Source code in rl4co/models/rl/common/base.py
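The default lr_scheduler_monitor value 'val/reward' suggests that logged metrics are keyed as '<phase>/<metric>'. The sketch below shows that naming step. It is hypothetical and assumes two further things: only the metrics selected for the phase are kept, and list-valued (per-instance) metrics are averaged into one scalar. format_metrics is not an RL4CO function.

```python
from statistics import fmean


def format_metrics(metric_dict: dict, phase: str, selected: list[str]) -> dict[str, float]:
    """Hypothetical: keep the phase's selected metrics and prefix them with the phase."""
    out = {}
    for name, value in metric_dict.items():
        if name not in selected:
            continue  # skip metrics not requested for this phase
        # Average per-instance values into one scalar for the logger/progress bar
        scalar = fmean(value) if isinstance(value, (list, tuple)) else float(value)
        out[f"{phase}/{name}"] = scalar
    return out


print(format_metrics({"reward": [-5.0, -7.0], "loss": 0.3}, "val", ["reward"]))
# {'val/reward': -6.0}
```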
forward
forward(td, **kwargs)
Forward pass for the model. Simple wrapper around policy. Uses env from the module if not provided.
Source code in rl4co/models/rl/common/base.py
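The env fallback in forward can be sketched with plain callables. ForwardSketch, toy_policy and its decode_type argument are hypothetical and only illustrate the described behavior: an explicitly passed env wins, otherwise the module's own env is forwarded to the policy.

```python
from typing import Any, Callable, Optional


class ForwardSketch:
    """Hypothetical module holding a default env and a policy callable."""

    def __init__(self, env: Any, policy: Callable[..., Any]) -> None:
        self.env = env
        self.policy = policy

    def forward(self, td: Any, env: Optional[Any] = None, **kwargs: Any) -> Any:
        # Fall back to the module's own env when none is passed explicitly
        env = self.env if env is None else env
        return self.policy(td, env, **kwargs)


def toy_policy(td, env, decode_type="greedy"):
    return {"td": td, "env": env, "decode_type": decode_type}


model = ForwardSketch(env="tsp", policy=toy_policy)
print(model.forward("batch")["env"])              # tsp
print(model.forward("batch", env="cvrp")["env"])  # cvrp
```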
shared_step
Shared step between train/val/test. To be implemented in subclasses.
Source code in rl4co/models/rl/common/base.py
on_train_epoch_end
on_train_epoch_end()
Called at the end of the training epoch. This can be used for instance to update the train dataset with new data (which is the case in RL).
Source code in rl4co/models/rl/common/base.py
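Regenerating the training data at the end of every epoch is also why shuffle_train_dataloader defaults to False: each epoch already sees fresh instances. The sketch below is hypothetical. It assumes instances are random 2D coordinates, as in a routing problem. In Lightning, the dataloaders themselves are only rebuilt if the Trainer is configured to reload them (e.g. via reload_dataloaders_every_n_epochs).

```python
import random


class EpochDataSketch:
    """Hypothetical: regenerate a fresh random training set after every epoch."""

    def __init__(self, train_data_size: int, num_nodes: int, seed: int = 0) -> None:
        self.train_data_size = train_data_size
        self.num_nodes = num_nodes
        self.rng = random.Random(seed)
        self.train_dataset = self._generate()

    def _generate(self) -> list[list[tuple[float, float]]]:
        # Each instance: num_nodes random 2D coordinates in the unit square
        return [
            [(self.rng.random(), self.rng.random()) for _ in range(self.num_nodes)]
            for _ in range(self.train_data_size)
        ]

    def on_train_epoch_end(self) -> None:
        # New instances each epoch, so shuffling the dataloader adds little
        self.train_dataset = self._generate()


data = EpochDataSketch(train_data_size=4, num_nodes=3)
before = data.train_dataset
data.on_train_epoch_end()
print(len(data.train_dataset), data.train_dataset != before)  # 4 True
```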
wrap_dataset
wrap_dataset(dataset)
Wrap dataset with policy-specific wrapper. This is useful, e.g., in REINFORCE, where we need to collect the greedy rollout baseline outputs.
Source code in rl4co/models/rl/common/base.py
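In REINFORCE with a greedy rollout baseline, the baseline value of each instance can be computed once when the dataset is wrapped. Training then reads it next to the instance instead of re-running the baseline policy. The wrapper below is a hypothetical, library-free sketch of that idea. greedy_reward is a toy stand-in for a greedy rollout of a frozen baseline policy.

```python
from typing import Callable, Sequence


class BaselineWrappedDataset:
    """Hypothetical wrapper pairing each instance with a precomputed baseline value."""

    def __init__(self, dataset: Sequence, baseline_fn: Callable[[object], float]) -> None:
        self.dataset = dataset
        # Evaluate the baseline once, when the dataset is wrapped
        self.baseline_values = [baseline_fn(x) for x in dataset]

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, idx: int) -> tuple[object, float]:
        return self.dataset[idx], self.baseline_values[idx]


def greedy_reward(instance: list[float]) -> float:
    """Toy stand-in for the reward of a greedy rollout of the baseline policy."""
    return -sum(instance)


wrapped = BaselineWrappedDataset([[1.0, 2.0], [3.0]], greedy_reward)
print(len(wrapped), wrapped[0])  # 2 ([1.0, 2.0], -3.0)
```

The advantage used in the REINFORCE update would then be the sampled reward minus the stored baseline value.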
Transductive Learning
Transductive models are learning algorithms that optimize on a specific instance. They improve solutions by updating policy parameters \(\theta\), which means that we run optimization (backprop) at test time. Transductive learning can be performed with different policies: for example, EAS updates (a part of) the parameters of AR policies to obtain better solutions, and there may be other ways to optimize at test time.
Tip
You may refer to the definition of inductive vs. transductive RL. In inductive RL, we train to generalize to new instances. In transductive RL, we train (or finetune) to solve only specific ones.
Classes:
- TransductiveModel – Base class for transductive algorithms (i.e. algorithms that optimize policy parameters for specific instances)
TransductiveModel
TransductiveModel(
env,
policy,
dataset: Dataset | str,
batch_size: int = 1,
max_iters: int = 100,
max_runtime: Optional[int] = 86400,
save_path: Optional[str] = None,
**kwargs
)
Bases: RL4COLitModule
Base class for transductive algorithms (i.e. algorithms that optimize policy parameters for specific instances; see https://en.wikipedia.org/wiki/Transduction_(machine_learning)). Transductive algorithms are used online to find better solutions for a given dataset: given a policy, they improve (a part of) its parameters such that the policy performs better on the given dataset.
Note
By default, we use manual optimization to handle the search.
Parameters:
- env – RL4CO environment
- policy – policy network
- dataset (Dataset | str) – dataset to use for training
- batch_size (int, default: 1) – batch size
- max_iters (int, default: 100) – maximum number of iterations
- max_runtime (Optional[int], default: 86400) – maximum runtime in seconds (86400 s = 24 h)
- save_path (Optional[str], default: None) – path to save the model
- **kwargs – additional arguments
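The two budgets max_iters and max_runtime bound the search. The toy, library-free sketch below shows test-time optimization on a single instance under both budgets. The "policy parameters" are a plain list, the instance is a target vector, and the loss is a squared distance with an analytic gradient. This is not EAS or any RL4CO algorithm, and transductive_search is a hypothetical name.

```python
import time


def transductive_search(
    theta: list[float],
    target: list[float],
    max_iters: int = 100,
    max_runtime: float = 86400.0,
    lr: float = 0.1,
) -> tuple[list[float], float, int]:
    """Toy test-time optimization: fit parameters to ONE instance (`target`)."""
    start = time.monotonic()
    theta = list(theta)  # do not mutate the caller's parameters
    best_cost = float("inf")
    it = 0
    for it in range(1, max_iters + 1):
        # Instance-specific loss: squared distance between parameters and target
        cost = sum((t - y) ** 2 for t, y in zip(theta, target))
        best_cost = min(best_cost, cost)
        # Gradient step on the instance loss (the "backprop at test time" part)
        theta = [t - lr * 2 * (t - y) for t, y in zip(theta, target)]
        if time.monotonic() - start > max_runtime:
            break  # stop early once the wall-clock budget is exhausted
    return theta, best_cost, it


theta, best, iters = transductive_search([0.0, 0.0], target=[1.0, -2.0], max_iters=50)
print(iters, best < 1e-6)  # 50 True
```

Tracking the best cost separately from the current parameters matters in real search, where an update can make the solution worse.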
Methods:
- setup – Set up the dataset and attributes.
- on_train_batch_start – Called before training (i.e. search) for a new batch begins.
- training_step – Main search loop. We use the training step to effectively adapt to a batch of instances.
- on_train_batch_end – Called when the train batch ends. This can be used, for instance, for logging or clearing cache.
- on_train_epoch_end – Called when training ends.
- validation_step – Not used during search
- test_step – Not used during search
Source code in rl4co/models/common/transductive/base.py
setup
setup(stage='fit')
Set up the dataset and attributes. The RL4COLitModule base class automatically loads the data.
Source code in rl4co/models/common/transductive/base.py
on_train_batch_start
Called before training (i.e. search) for a new batch begins. This can be used to perform changes to the model or optimizer at the start of each batch.
Source code in rl4co/models/common/transductive/base.py
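One plausible use of this hook is to restore the original parameters before each new batch, so that adaptation to one batch does not leak into the next. Whether a given transductive algorithm resets its parameters this way is an assumption here. PerBatchAdapter and its "bias" parameter are hypothetical.

```python
import copy


class PerBatchAdapter:
    """Hypothetical: restore the original parameters before each new batch."""

    def __init__(self, params: dict[str, float]) -> None:
        self.params = params
        self._initial = copy.deepcopy(params)  # snapshot taken once, before any search

    def on_train_batch_start(self, batch, batch_idx: int) -> None:
        # Undo adaptation from the previous batch so batches are solved independently
        self.params = copy.deepcopy(self._initial)

    def training_step(self, batch, batch_idx: int) -> None:
        # Toy "adaptation": shift a parameter by the batch mean
        self.params["bias"] += sum(batch) / len(batch)


adapter = PerBatchAdapter({"bias": 0.0})
for i, batch in enumerate([[1.0, 3.0], [10.0]]):
    adapter.on_train_batch_start(batch, i)
    adapter.training_step(batch, i)
    print(i, adapter.params["bias"])  # prints 0 2.0, then 1 10.0
```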
training_step (abstractmethod)
training_step(batch, batch_idx)
Main search loop. We use the training step to effectively adapt to a batch of instances.
Source code in rl4co/models/common/transductive/base.py
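With Python's abc module, a method marked @abstractmethod on an ABC-derived class must be overridden before the class can be instantiated. The sketch below shows that contract with hypothetical classes. Whether TransductiveModel enforces it in the same way depends on its metaclass, which this page does not show, so treat that as an assumption.

```python
from abc import ABC, abstractmethod


class TransductiveSketch(ABC):
    """Hypothetical stand-in: the search loop itself is left abstract."""

    @abstractmethod
    def training_step(self, batch, batch_idx: int):
        """Adapt to `batch` of instances; must be provided by each algorithm."""


class GreedySearch(TransductiveSketch):
    def training_step(self, batch, batch_idx: int):
        # Toy search: pick the best (lowest-cost) candidate in the batch
        return min(batch)


try:
    TransductiveSketch()
except TypeError as err:
    print("cannot instantiate:", type(err).__name__)  # cannot instantiate: TypeError
print(GreedySearch().training_step([3, 1, 2], 0))  # 1
```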
on_train_batch_end
Called when the train batch ends. This can be used for instance for logging or clearing cache.
Source code in rl4co/models/common/transductive/base.py
on_train_epoch_end
on_train_epoch_end() -> None
Called when training ends.
Source code in rl4co/models/common/transductive/base.py
validation_step
Not used during search
Source code in rl4co/models/common/transductive/base.py