Training: Checkpoints, Logging, and Callbacks¶
In this notebook we will cover a quickstart for training a model on the Split Delivery Vehicle Routing Problem (SDVRP), with some additional comments along the way. The SDVRP is a variant of the VRP in which a vehicle may deliver only part of a customer's demand and return later to deliver the rest.
Installation¶
Uncomment the following line to install the package from PyPI. Remember to choose a GPU runtime for faster training!
Note: You may need to restart the runtime in Colab after this
# !pip install rl4co
## NOTE: to install the latest version from GitHub (may be unstable), install from source instead:
# !pip install git+https://github.com/ai4co/rl4co.git
Imports¶
import torch
from lightning.pytorch.callbacks import ModelCheckpoint, RichModelSummary
from rl4co.envs import SDVRPEnv
from rl4co.models.zoo import AttentionModel
from rl4co.utils.trainer import RL4COTrainer
Main Setup¶
Environment, Model and LitModule¶
# RL4CO env based on TorchRL
env = SDVRPEnv(generator_params=dict(num_loc=20))
# Model: default is AM with REINFORCE and greedy rollout baseline
model = AttentionModel(env,
                       baseline='rollout',
                       train_data_size=100_000,  # really small size for demo
                       val_data_size=10_000)
/home/botu/mambaforge/envs/rl4co/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:199: Attribute 'env' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['env'])`.
/home/botu/mambaforge/envs/rl4co/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:199: Attribute 'policy' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['policy'])`.
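The constructor accepts further training options; for example, the optimizer learning rate is passed via optimizer_kwargs (used later in this notebook). Below is a minimal commented-out sketch; batch_size is assumed to be a LitModule argument, so check the RL4CO documentation for your installed version.
# model = AttentionModel(env,
#                        baseline='rollout',
#                        batch_size=512,                 # training batch size (assumed kwarg)
#                        optimizer_kwargs={'lr': 1e-4},  # learning rate, as used later in this notebook
#                        train_data_size=100_000,
#                        val_data_size=10_000)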
Test greedy rollout with untrained model and plot¶
# Greedy rollouts over untrained policy
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
td_init = env.reset(batch_size=[3]).to(device)
policy = model.policy.to(device)
out = policy(td_init.clone(), env, phase="test", decode_type="greedy")
# Plotting
print(f"Tour lengths: {[f'{-r.item():.2f}' for r in out['reward']]}")
for td, actions in zip(td_init, out['actions'].cpu()):
    env.render(td, actions)
Tour lengths: ['29.45', '14.26', '21.15']
Training¶
# Checkpointing callback: save models when validation reward improves
checkpoint_callback = ModelCheckpoint(dirpath="checkpoints",  # save to checkpoints/
                                      filename="epoch_{epoch:03d}",  # save as epoch_XXX.ckpt
                                      save_top_k=1,  # save only the best model
                                      save_last=True,  # save the last model
                                      monitor="val/reward",  # monitor validation reward
                                      mode="max")  # maximize validation reward
# Print model summary
rich_model_summary = RichModelSummary(max_depth=3)
# Callbacks list
callbacks = [checkpoint_callback, rich_model_summary]
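After trainer.fit has run, the callback keeps track of the checkpoints it wrote; best_model_path and best_model_score are standard Lightning ModelCheckpoint attributes. A small sketch (to be run after training):
# Retrieve the best checkpoint saved by the callback
# print(checkpoint_callback.best_model_path)   # e.g. a .ckpt file under checkpoints/
# print(checkpoint_callback.best_model_score)  # best monitored value (val/reward)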
We make sure we're logged into W&B so that our experiments can be associated with our account. You may comment out the lines below if you don't want to use it.
# import wandb
# wandb.login()
## Comment out the following two lines if you don't want logging
from lightning.pytorch.loggers import WandbLogger
logger = WandbLogger(project="rl4co", name="sdvrp-am")
## Uncomment the line below if you don't want logging
# logger = None
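If you would rather log locally instead of to W&B, a TensorBoard logger is a drop-in alternative (standard Lightning API); a minimal sketch:
# from lightning.pytorch.loggers import TensorBoardLogger
# logger = TensorBoardLogger("logs", name="sdvrp-am")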
Trainer¶
The RL4COTrainer is a wrapper around PyTorch Lightning's Trainer class that adds some functionality and more efficient defaults. It handles logging, checkpointing, and more for you.
from rl4co.utils.trainer import RL4COTrainer
trainer = RL4COTrainer(
max_epochs=2,
accelerator="gpu",
devices=1,
logger=logger,
callbacks=callbacks,
)
Using 16bit Automatic Mixed Precision (AMP)
Trainer already configured with model summary callbacks: [<class 'lightning.pytorch.callbacks.rich_model_summary.RichModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
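Since RL4COTrainer wraps Lightning's Trainer, standard Trainer arguments such as gradient clipping or precision should pass through as well. The sketch below assumes this pass-through, so verify against your installed version; note that mixed precision is already the default, as shown in the log above.
# trainer = RL4COTrainer(
#     max_epochs=2,
#     accelerator="gpu",
#     devices=1,
#     gradient_clip_val=1.0,   # standard Lightning Trainer argument
#     precision="16-mixed",    # mixed precision (standard Lightning Trainer argument)
#     logger=logger,
#     callbacks=callbacks,
# )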
Fit the model¶
trainer.fit(model)
wandb: Currently logged in as: silab-kaist. Use `wandb login --relogin` to force relogin
./wandb/run-20240428_182146-xcgdzio4
val_file not set. Generating dataset instead
test_file not set. Generating dataset instead
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
   | Name                                    | Type                  | Params
--------------------------------------------------------------------------------
0  | env                                     | SDVRPEnv              | 0
1  | policy                                  | AttentionModelPolicy  | 694 K
2  | policy.encoder                          | AttentionModelEncoder | 595 K
3  | policy.encoder.init_embedding           | VRPInitEmbedding      | 896
4  | policy.encoder.net                      | GraphAttentionNetwork | 594 K
5  | policy.decoder                          | AttentionModelDecoder | 98.8 K
6  | policy.decoder.context_embedding        | VRPContext            | 16.5 K
7  | policy.decoder.dynamic_embedding        | SDVRPDynamicEmbedding | 384
8  | policy.decoder.pointer                  | PointerAttention      | 16.4 K
9  | policy.decoder.project_node_embeddings  | Linear                | 49.2 K
10 | policy.decoder.project_fixed_context    | Linear                | 16.4 K
11 | baseline                                | WarmupBaseline        | 694 K
12 | baseline.baseline                       | RolloutBaseline       | 694 K
13 | baseline.baseline.policy                | AttentionModelPolicy  | 694 K
14 | baseline.warmup_baseline                | ExponentialBaseline   | 0
Trainable params: 1.4 M
Non-trainable params: 0
Total params: 1.4 M
Total estimated model params size (MB): 5
/home/botu/mambaforge/envs/rl4co/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument to `num_workers=31` in the `DataLoader` to improve performance.
/home/botu/mambaforge/envs/rl4co/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument to `num_workers=31` in the `DataLoader` to improve performance.
`Trainer.fit` stopped: `max_epochs=2` reached.
Testing¶
Plotting¶
Here we plot the solutions (greedy rollouts) of the trained policy on the initial problem instances
# Greedy rollouts over trained model (same states as previous plot)
policy = model.policy.to(device)
out = policy(td_init.clone(), env, phase="test", decode_type="greedy")
# Plotting
print(f"Tour lengths: {[f'{-r.item():.2f}' for r in out['reward']]}")
for td, actions in zip(td_init, out['actions'].cpu()):
    env.render(td, actions)
Tour lengths: ['9.12', '7.16', '9.55']
Test function¶
By default, the dataset is generated or loaded by the environment. You may load your own dataset by setting test_file in the environment config:
env = SDVRPEnv(
    ...
    test_file="path/to/test/file"
)
Here, since test_file is not set, we test directly on the generated test dataset.
trainer.test(model)
val_file not set. Generating dataset instead
test_file not set. Generating dataset instead
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
/home/botu/mambaforge/envs/rl4co/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument to `num_workers=31` in the `DataLoader` to improve performance.
     Test metric             DataLoader 0
     test/reward          -7.363526344299316
[{'test/reward': -7.363526344299316}]
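We can also evaluate a specific saved checkpoint instead of the in-memory weights; ckpt_path is a standard Lightning Trainer.test argument. A small sketch:
# Test the best checkpoint saved by the callback instead of the current weights
# trainer.test(model, ckpt_path=checkpoint_callback.best_model_path)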
Test generalization to new dataset¶
Here we can load a new dataset (with 50 nodes) and test the trained model on it
# Test generalization to 50 nodes (not going to be great due to few epochs, but hey)
env = SDVRPEnv(generator_params=dict(num_loc=50))
# Generate new data (50 instances) and wrap it in a dataloader (batch size 100)
new_dataset = env.dataset(50)
dataloader = model._dataloader(new_dataset, batch_size=100)
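Besides plotting a few instances below, we can roll out greedily over the whole new dataset and average the reward to get an aggregate number; this sketch just reuses the calls from this notebook.
# Average greedy tour length over the full 50-node dataset (sketch)
policy = model.policy.to(device)
rewards = []
with torch.no_grad():
    for batch in dataloader:
        td = env.reset(batch).to(device)
        out = policy(td, env, phase="test", decode_type="greedy")
        rewards.append(out["reward"].cpu())
print(f"Mean tour length on 50 nodes: {-torch.cat(rewards).mean():.2f}")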
Plotting generalization¶
# Greedy rollouts over the trained policy on the new 50-node instances
init_states = next(iter(dataloader))[:3]
td_init_generalization = env.reset(init_states).to(device)
policy = model.policy.to(device)
out = policy(td_init_generalization.clone(), env, phase="test", decode_type="greedy")
# Plotting
print(f"Tour lengths: {[f'{-r.item():.2f}' for r in out['reward']]}")
for td, actions in zip(td_init_generalization, out['actions'].cpu()):
    env.render(td, actions)
Tour lengths: ['11.84', '12.49', '12.20']
Loading model¶
Thanks to PyTorch Lightning, we can easily save and load a model to and from a checkpoint! Checkpointing is configured in the Trainer via the model checkpoint callback. For example, we can load the last model via the last.ckpt file located in the folder we specified in the checkpoint callback.
Checkpointing¶
# Environment, Model, and Lightning Module (reinstantiate from scratch)
model = AttentionModel(env,
                       baseline="rollout",
                       train_data_size=100_000,
                       test_data_size=10_000,
                       optimizer_kwargs={'lr': 1e-4})
# Note that by default, Lightning names checkpoints from newer runs with a "-v{version}" suffix
# unless you specify the checkpoint path explicitly
new_model_checkpoint = AttentionModel.load_from_checkpoint("checkpoints/last.ckpt", strict=False)
/home/botu/mambaforge/envs/rl4co/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:199: Attribute 'env' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['env'])`.
/home/botu/mambaforge/envs/rl4co/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:199: Attribute 'policy' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['policy'])`.
/home/botu/mambaforge/envs/rl4co/lib/python3.11/site-packages/lightning/pytorch/core/saving.py:188: Found keys that are not in the model state dict but in the checkpoint: ['baseline.baseline.policy.encoder.init_embedding.init_embed.weight', ...]
val_file not set. Generating dataset instead
test_file not set. Generating dataset instead
Now we can load both the model and environment from the checkpoint!
# Greedy rollouts over trained model (same states as previous plot, with 20 nodes)
policy_new = new_model_checkpoint.policy.to(device)
env = new_model_checkpoint.env.to(device)
out = policy_new(td_init.clone(), env, phase="test", decode_type="greedy")
# Plotting
print(f"Tour lengths: {[f'{-r.item():.2f}' for r in out['reward']]}")
for td, actions in zip(td_init, out['actions'].cpu()):
    env.render(td, actions)
Tour lengths: ['9.12', '7.16', '9.55']
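If you only need the trained policy for inference, you can also export its weights with plain PyTorch serialization rather than a full Lightning checkpoint; this is a workflow sketch, not an RL4CO-specific API.
# Save just the policy weights and load them into a freshly built model (sketch)
torch.save(new_model_checkpoint.policy.state_dict(), "sdvrp_policy.pt")
fresh_model = AttentionModel(SDVRPEnv(generator_params=dict(num_loc=20)), baseline="rollout")
fresh_model.policy.load_state_dict(torch.load("sdvrp_policy.pt"))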
Additional resources¶
Documentation | Getting Started | Usage | Contributing | Paper | Citation
Have feedback about this notebook? Feel free to contribute by either opening an issue or a pull request! ;)