Local Search¶
In this notebook, we show how to improve a solution at hand using local search and related techniques. Here we solve the TSP and refine the model's solutions with 2-opt; you can check how the improvement works for other problems in each Env's local_search method.
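Before diving in, here is a minimal, self-contained sketch of the 2-opt idea itself: repeatedly reverse a segment of the tour whenever doing so makes the tour shorter. This is only an illustration (the two_opt helper below is hypothetical and NumPy-based); the env.local_search call used later in this notebook delegates to PyVRP's operators instead.

import numpy as np

def two_opt(tour: np.ndarray, coords: np.ndarray, max_passes: int = 10) -> np.ndarray:
    """Greedy 2-opt sketch: reverse tour segments while doing so shortens the tour."""

    def tour_length(t: np.ndarray) -> float:
        pts = coords[t]
        # Sum of edge lengths, including the edge closing the tour
        return float(np.linalg.norm(pts - np.roll(pts, -1, axis=0), axis=1).sum())

    best = tour.copy()
    best_len = tour_length(best)
    for _ in range(max_passes):
        improved = False
        for i in range(1, len(best) - 1):
            for j in range(i + 1, len(best)):
                candidate = best.copy()
                candidate[i:j] = best[i:j][::-1]  # reverse one segment: the 2-opt move
                cand_len = tour_length(candidate)
                if cand_len < best_len:
                    best, best_len, improved = candidate, cand_len, True
        if not improved:  # stop once a full pass yields no improving move
            break
    return best

For a single TSP instance from the environment, coords would be the (num_loc, 2) node coordinates (e.g. td["locs"][i].cpu().numpy()) and tour a permutation of node indices such as actions[i].cpu().numpy().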
Note that this notebook builds on 1-quickstart and uses the checkpoint file saved there. If you haven't gone through that notebook yet, we recommend doing so first.
Installation¶
We use the LocalSearch operator provided by PyVRP. See https://github.com/PyVRP/PyVRP for more details.
Uncomment the following line to install the package from PyPI. Remember to choose a GPU runtime for faster training!
Note: You may need to restart the runtime in Colab after this
In [1]:
# !pip install rl4co[routing] # include pyvrp
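After installing, you can optionally verify that the routing extras (including pyvrp) are available, for instance with a quick version check via the standard-library importlib.metadata:

from importlib.metadata import version

# Optional sanity check that the routing extras were installed
print("pyvrp:", version("pyvrp"))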
Imports¶
In [2]:
import torch

from rl4co.envs import TSPEnv
from rl4co.models.zoo import AttentionModel
Environment, Policy, and Model from saved checkpoint¶
In [3]:
# RL4CO env based on TorchRL
env = TSPEnv(num_loc=50)

checkpoint_path = "../tsp-quickstart.ckpt" # checkpoint from the ../1-quickstart.ipynb

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model: default is AM with REINFORCE and greedy rollout baseline
model = AttentionModel.load_from_checkpoint(checkpoint_path, load_baseline=False)
(Output: Lightning prints warnings that the 'env' and 'policy' attributes are nn.Module instances already saved during checkpointing, and that the baseline.* weights found in the checkpoint are not loaded into the model state dict. The latter is expected here, since we pass load_baseline=False.)
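If you have not generated the checkpoint from 1-quickstart, you can still follow along by instantiating a fresh (untrained) model instead. A minimal sketch following the quickstart's constructor is below; the argument values are illustrative, and the starting solutions will of course be much worse without training.

# Fallback sketch: untrained AttentionModel when no checkpoint is available
model = AttentionModel(
    env,
    baseline="rollout",
    train_data_size=100_000,
    val_data_size=10_000,
)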
Testing with Solution Improvement¶
In [4]:
# Greedy rollouts over trained model (same states as previous plot)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
td_init = env.reset(batch_size=[3]).to(device)
model = model.to(device)
out = model(td_init.clone(), phase="test", decode_type="greedy")
actions = out['actions']
# Improve solutions using LocalSearch
improved_actions = env.local_search(td_init, actions, rng=0)
improved_rewards = env.get_reward(td_init, improved_actions)
# Plotting
import matplotlib.pyplot as plt
for i, td in enumerate(td_init):
    fig, axs = plt.subplots(1, 2, figsize=(11, 5))
    env.render(td, actions[i], ax=axs[0])
    env.render(td, improved_actions[i], ax=axs[1])
    axs[0].set_title(f"Before improvement | Cost = {-out['reward'][i].item():.3f}")
    axs[1].set_title(f"After improvement | Cost = {-improved_rewards[i].item():.3f}")
We can see that the solutions have improved after applying 2-opt.
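Beyond the plots, we can also quantify the gain directly from the tensors computed above. A small sketch using the out and improved_rewards variables from the previous cell:

# Summarize the cost reduction from local search on this small batch
costs_before = -out["reward"]    # model rollout costs (negative rewards)
costs_after = -improved_rewards  # costs after 2-opt local search
gain = 100 * (costs_before - costs_after) / costs_before
print(f"Mean cost before 2-opt: {costs_before.mean().item():.3f}")
print(f"Mean cost after 2-opt:  {costs_after.mean().item():.3f}")
print(f"Mean improvement:       {gain.mean().item():.2f}%")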