Training: Checkpoints, Logging, and Callbacks¶
In this notebook we will cover a quickstart training run on the Split Delivery Vehicle Routing Problem (SDVRP), with some additional comments along the way. The SDVRP is a variant of the VRP in which a vehicle may deliver only part of a customer's demand and return later to deliver the rest.
Installation¶
Uncomment the following line to install the package from PyPI. Remember to choose a GPU runtime for faster training!
Note: You may need to restart the runtime in Colab after this
# !pip install rl4co
## NOTE: to install latest version from Github (may be unstable) install from source instead:
# !pip install git+https://github.com/ai4co/rl4co.git
Imports¶
import torch
from lightning.pytorch.callbacks import ModelCheckpoint, RichModelSummary
from rl4co.envs import SDVRPEnv
from rl4co.models.zoo import AttentionModel
from rl4co.utils.trainer import RL4COTrainer
Main Setup¶
Environment, Model and LitModule¶
# RL4CO env based on TorchRL
env = SDVRPEnv(generator_params=dict(num_loc=20))
# Model: default is AM with REINFORCE and greedy rollout baseline
model = AttentionModel(env,
                       baseline='rollout',
                       train_data_size=100_000,  # really small size for demo
                       val_data_size=10_000)
/home/botu/mambaforge/envs/rl4co/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:199: Attribute 'env' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['env'])`.
/home/botu/mambaforge/envs/rl4co/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:199: Attribute 'policy' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['policy'])`.
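Before rolling out a policy, it can help to sanity-check the generator by resetting the environment and inspecting the resulting TensorDict of states. The snippet below is a minimal sketch; the exact state keys (locations, demands, etc.) may differ across rl4co versions.
# Sketch: generate a small batch of SDVRP instances and inspect the state
td_check = env.reset(batch_size=[2])
print(td_check.batch_size)      # torch.Size([2])
print(sorted(td_check.keys()))  # state keys (e.g. locations, demands)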
Test greedy rollout with untrained model and plot¶
# Greedy rollouts over untrained policy
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
td_init = env.reset(batch_size=[3]).to(device)
policy = model.policy.to(device)
out = policy(td_init.clone(), env, phase="test", decode_type="greedy")
# Plotting
print(f"Tour lengths: {[f'{-r.item():.2f}' for r in out['reward']]}")
for td, actions in zip(td_init, out['actions'].cpu()):
    env.render(td, actions)
Tour lengths: ['29.45', '14.26', '21.15']
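Greedy decoding is only one option; the policy also supports stochastic rollouts. The comparison below is a sketch that assumes decode_type="sampling" is available, as in other rl4co examples.
# Sketch: compare greedy vs. sampled rollouts of the same untrained policy
with torch.inference_mode():
    out_sample = policy(td_init.clone(), env, phase="test", decode_type="sampling")
print("Greedy mean tour length:  ", -out["reward"].mean().item())
print("Sampling mean tour length:", -out_sample["reward"].mean().item())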
Training¶
# Checkpointing callback: save models when validation reward improves
checkpoint_callback = ModelCheckpoint(dirpath="checkpoints",          # save to checkpoints/
                                      filename="epoch_{epoch:03d}",   # save as epoch_XXX.ckpt
                                      save_top_k=1,                   # save only the best model
                                      save_last=True,                 # save the last model
                                      monitor="val/reward",           # monitor validation reward
                                      mode="max")                     # maximize validation reward
# Print model summary
rich_model_summary = RichModelSummary(max_depth=3)
# Callbacks list
callbacks = [checkpoint_callback, rich_model_summary]
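Any other standard Lightning callback can go into the same list. For example, a LearningRateMonitor from lightning.pytorch.callbacks logs the learning rate alongside the other metrics; this is optional and shown only as a sketch.
from lightning.pytorch.callbacks import LearningRateMonitor

# Optional: also log the learning rate at every optimizer step
lr_monitor = LearningRateMonitor(logging_interval="step")
callbacks = [checkpoint_callback, rich_model_summary, lr_monitor]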
We make sure we're logged into W&B so that our experiments are associated with our account. Uncomment the lines below to log in; skip them if you don't want to use W&B.
# import wandb
# wandb.login()
## Comment out the following two lines if you don't want logging
from lightning.pytorch.loggers import WandbLogger
logger = WandbLogger(project="rl4co", name="sdvrp-am")
## Uncomment the line below instead if you don't want logging
# logger = None
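If you prefer to avoid W&B entirely, any other Lightning logger plugs in the same way; for instance, a local TensorBoard logger (a sketch, kept commented like the alternatives above):
## Alternative: log locally with TensorBoard instead of W&B
# from lightning.pytorch.loggers import TensorBoardLogger
# logger = TensorBoardLogger("logs", name="sdvrp-am")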
Trainer¶
The RL4COTrainer is a wrapper around PyTorch Lightning's Trainer class that adds some functionality and more convenient defaults. It handles logging, checkpointing, and more for you.
from rl4co.utils.trainer import RL4COTrainer
trainer = RL4COTrainer(
    max_epochs=2,
    accelerator="gpu",
    devices=1,
    logger=logger,
    callbacks=callbacks,
)
Using 16bit Automatic Mixed Precision (AMP)
Trainer already configured with model summary callbacks: [<class 'lightning.pytorch.callbacks.rich_model_summary.RichModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
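As the log above shows, the trainer enables 16-bit automatic mixed precision by default. RL4COTrainer accepts the standard Lightning Trainer arguments, so as a sketch (assuming your Lightning version supports the "32-true" precision flag) you could fall back to full precision if AMP misbehaves on your hardware:
## Sketch: train in full precision instead of the default 16-bit AMP
# trainer = RL4COTrainer(
#     max_epochs=2,
#     accelerator="gpu",
#     devices=1,
#     logger=logger,
#     callbacks=callbacks,
#     precision="32-true",
# )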
Fit the model¶
trainer.fit(model)
val_file not set. Generating dataset instead
test_file not set. Generating dataset instead
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
   | Name                                    | Type                  | Params
------------------------------------------------------------------------------
 0 | env                                     | SDVRPEnv              | 0
 1 | policy                                  | AttentionModelPolicy  | 694 K
 2 | policy.encoder                          | AttentionModelEncoder | 595 K
 3 | policy.encoder.init_embedding           | VRPInitEmbedding      | 896
 4 | policy.encoder.net                      | GraphAttentionNetwork | 594 K
 5 | policy.decoder                          | AttentionModelDecoder | 98.8 K
 6 | policy.decoder.context_embedding        | VRPContext            | 16.5 K
 7 | policy.decoder.dynamic_embedding        | SDVRPDynamicEmbedding | 384
 8 | policy.decoder.pointer                  | PointerAttention      | 16.4 K
 9 | policy.decoder.project_node_embeddings  | Linear                | 49.2 K
10 | policy.decoder.project_fixed_context    | Linear                | 16.4 K
11 | baseline                                | WarmupBaseline        | 694 K
12 | baseline.baseline                       | RolloutBaseline       | 694 K
13 | baseline.baseline.policy                | AttentionModelPolicy  | 694 K
14 | baseline.warmup_baseline                | ExponentialBaseline   | 0
Trainable params: 1.4 M
Non-trainable params: 0
Total params: 1.4 M
Total estimated model params size (MB): 5
/home/botu/mambaforge/envs/rl4co/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument to `num_workers=31` in the `DataLoader` to improve performance.
/home/botu/mambaforge/envs/rl4co/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument to `num_workers=31` in the `DataLoader` to improve performance.
`Trainer.fit` stopped: `max_epochs=2` reached.
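Since save_last=True was set in the ModelCheckpoint callback, a last.ckpt file is written to the checkpoints/ directory. As a sketch, training can be resumed from it via Lightning's standard ckpt_path argument:
## Sketch: resume training from the last checkpoint written by ModelCheckpoint
# trainer.fit(model, ckpt_path="checkpoints/last.ckpt")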
Testing¶
Plotting¶
Here we plot the greedy-rollout solutions of the trained policy on the same initial instances as before.
# Greedy rollouts over trained model (same states as previous plot)
policy = model.policy.to(device)
out = policy(td_init.clone(), env, phase="test", decode_type="greedy")
# Plotting
print(f"Tour lengths: {[f'{-r.item():.2f}' for r in out['reward']]}")
for td, actions in zip(td_init, out['actions'].cpu()):
    env.render(td, actions)
Tour lengths: ['9.12', '7.16', '9.55']
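Finally, the best checkpoint tracked by the ModelCheckpoint callback can be reloaded for evaluation. The sketch below uses the standard LightningModule load_from_checkpoint API; depending on which hyperparameters were saved, you may need to pass the env explicitly.
# Sketch: reload the best checkpoint and roll it out greedily on the same instances
best_model = AttentionModel.load_from_checkpoint(checkpoint_callback.best_model_path)
best_policy = best_model.policy.to(device)
out_best = best_policy(td_init.clone(), env, phase="test", decode_type="greedy")
print(f"Tour lengths: {[f'{-r.item():.2f}' for r in out_best['reward']]}")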