New Environment: Creating and Modeling
In this notebook, we will show how to extend RL4CO to solve new problems from zero to hero! 🚀
Problem: TSP
We will build an environment and model for the Traveling Salesman Problem (TSP). The TSP is a well-known combinatorial optimization problem that consists of finding the shortest route that visits each city in a given list exactly once and returns to the origin city. The TSP is NP-hard, and it is one of the most studied problems in combinatorial optimization.
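The objective can be made concrete with a small brute-force sketch in plain Python. It is only an illustration and not part of RL4CO: tour_length and brute_force_tsp are hypothetical helpers. Fixing the start city still leaves (n-1)! orderings to check, so exhaustive search only works for a handful of cities.

```python
import itertools
import math


def tour_length(coords, tour):
    """Length of the closed tour, including the edge back to the start city."""
    return sum(
        math.dist(coords[tour[k]], coords[tour[(k + 1) % len(tour)]])
        for k in range(len(tour))
    )


def brute_force_tsp(coords):
    """Exhaustive search over all tours; only feasible for very small n."""
    n = len(coords)
    best_tour, best_len = None, math.inf
    # Fix city 0 as the start: rotations of the same tour have equal length
    for perm in itertools.permutations(range(1, n)):
        tour = (0, *perm)
        length = tour_length(coords, tour)
        if length < best_len:
            best_tour, best_len = tour, length
    return best_tour, best_len


# Four corners of the unit square: the optimal tour follows the perimeter
square = [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)]
tour, length = brute_force_tsp(square)
print(tour, length)  # (0, 1, 2, 3) 4.0
```

Crossing the square diagonally, e.g. the tour (0, 2, 1, 3), costs 2 + 2√2 ≈ 4.83, so the search keeps the perimeter tour.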
Installation
## Uncomment the following line to install the package from PyPI
## You may need to restart the runtime in Colab after this
## Remember to choose a GPU runtime for faster training!
# !pip install rl4co
Imports
from typing import Optional

import torch
import torch.nn as nn

from tensordict.tensordict import TensorDict
from torchrl.data import (
    Bounded,
    Composite,
    Unbounded,
)
from rl4co.utils.decoding import rollout, random_policy
from rl4co.envs.common import RL4COEnvBase, Generator, get_sampler
from rl4co.models.zoo import AttentionModel, AttentionModelPolicy
from rl4co.utils.ops import gather_by_index, get_tour_length
from rl4co.utils.trainer import RL4COTrainer
Environment Creation
We will base environment creation on the RL4COEnvBase class, which builds on TorchRL. See the RL4CO documentation for more information.
Reset
The _reset function initializes the environment to its initial state and returns that state as a TensorDict.
def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict:
    # Initialize locations
    init_locs = td["locs"] if td is not None else None
    if batch_size is None:
        batch_size = self.batch_size if init_locs is None else init_locs.shape[:-2]
    device = init_locs.device if init_locs is not None else self.device
    self.to(device)
    if init_locs is None:
        # No instances given: sample new ones from the environment's generator
        init_locs = self.generator(batch_size=batch_size).to(device)["locs"]
    batch_size = [batch_size] if isinstance(batch_size, int) else batch_size

    # We do not enforce loading from self for flexibility
    num_loc = init_locs.shape[-2]

    # Other variables
    current_node = torch.zeros((batch_size), dtype=torch.int64, device=device)
    available = torch.ones(
        (*batch_size, num_loc), dtype=torch.bool, device=device
    )  # 1 means not visited, i.e. action is allowed
    i = torch.zeros((*batch_size, 1), dtype=torch.int64, device=device)

    return TensorDict(
        {
            "locs": init_locs,
            "first_node": current_node,
            "current_node": current_node,
            "i": i,
            "action_mask": available,
            # Create the reward on the same device as the other tensors
            "reward": torch.zeros((*batch_size, 1), dtype=torch.float32, device=device),
        },
        batch_size=batch_size,
    )
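To keep track of what _reset returns, the hypothetical helper below (plain Python, not an RL4CO function) lists the shape of every entry for a given batch size and number of cities. It simply mirrors the tensor constructors in _reset and is meant as a reading aid.

```python
def initial_state_shapes(batch_size, num_loc):
    """Hypothetical helper mirroring the tensor shapes created by _reset."""
    # Like _reset, accept either an int or a list/tuple of batch dimensions
    batch_size = [batch_size] if isinstance(batch_size, int) else list(batch_size)
    return {
        "locs": (*batch_size, num_loc, 2),      # x, y coordinates per city
        "first_node": tuple(batch_size),        # index of the starting city
        "current_node": tuple(batch_size),      # index of the last visited city
        "i": (*batch_size, 1),                  # step counter
        "action_mask": (*batch_size, num_loc),  # True = city not visited yet
        "reward": (*batch_size, 1),             # placeholder, computed at the end
    }


for key, shape in initial_state_shapes(32, 20).items():
    print(f"{key:>12}: {shape}")
```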
Step
Environment _step: this defines the state update of the TSP given a TensorDict (td in the code) of the current state and the action to take:
def _step(self, td: TensorDict) -> TensorDict:
    current_node = td["action"]
    # At the first step (i == 0 for the whole batch) the chosen node starts the tour
    first_node = current_node if (td["i"] == 0).all() else td["first_node"]

    # Set not visited to 0 (i.e., we visited the node)
    # Note: we may also use a separate function for obtaining the mask for more flexibility
    available = td["action_mask"].scatter(
        -1, current_node.unsqueeze(-1).expand_as(td["action_mask"]), 0
    )

    # We are done when there are no unvisited locations
    done = torch.sum(available, dim=-1) == 0

    # The reward is calculated outside via get_reward for efficiency, so we set it to 0 here
    # (as a float tensor rather than a boolean one)
    reward = torch.zeros_like(done, dtype=torch.float32)

    td.update(
        {
            "first_node": first_node,
            "current_node": current_node,
            "i": td["i"] + 1,
            "action_mask": available,
            "reward": reward,
            "done": done,
        },
    )
    return td
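The batched scatter call in _step can be hard to read at first. The sketch below replays the same update for a single instance with plain Python lists; step_mask is a hypothetical helper, and True means "not visited yet", as in action_mask. It also shows that done only becomes true once the last city has been chosen.

```python
def step_mask(action_mask, action):
    """Single-instance analogue of the scatter in _step: mark `action` as visited."""
    new_mask = list(action_mask)
    new_mask[action] = False  # False = already visited, no longer selectable
    done = not any(new_mask)  # done once every city has been visited
    return new_mask, done


mask, done = [True] * 4, False
for action in [2, 0, 3, 1]:
    assert mask[action], "the policy may only pick unvisited cities"
    mask, done = step_mask(mask, action)
    print(action, mask, done)
# The last iteration prints: 1 [False, False, False, False] True
```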
[Optional] Separate Action Mask Function
The get_action_mask function simply returns a mask of the valid actions for the current updated state. It can be called from _step and _reset in larger environments with several constraints, and it is useful for modularity.
def get_action_mask(self, td: TensorDict) -> torch.Tensor:
    # Here: your logic
    return td["action_mask"]
[Optional] Check Solution Validity
Another optional utility: this checks whether the solution is feasible and can help identify bugs.
def check_solution_validity(self, td: TensorDict, actions: torch.Tensor):
    """Check that solution is valid: nodes are visited exactly once"""
    # Sorting each tour must give exactly [0, 1, ..., n - 1]
    assert (
        torch.arange(actions.size(1), device=actions.device)
        .view(1, -1)
        .expand_as(actions)
        == actions.sort(1)[0]
    ).all(), "Invalid tour"
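The batched check above sorts every tour and compares it with 0, 1, ..., n-1. A plain-Python sketch of the same idea for one solution (is_valid_tour is a hypothetical helper) looks like this. Unlike the batched version, which takes n from actions.size(1), the sketch takes num_loc explicitly, so a tour that skips cities is caught as well.

```python
def is_valid_tour(actions, num_loc):
    """Valid iff the visited indices are a permutation of 0, 1, ..., num_loc - 1."""
    return sorted(actions) == list(range(num_loc))


print(is_valid_tour([2, 0, 3, 1], 4))  # True
print(is_valid_tour([2, 0, 2, 1], 4))  # False: city 2 twice, city 3 never
print(is_valid_tour([0, 1, 2], 4))     # False: city 3 missing
```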
Reward function
The _get_reward function evaluates the reward of a solution (the actions).
def _get_reward(self, td: TensorDict, actions: torch.Tensor) -> torch.Tensor:
    # Sanity check if enabled
    if self.check_solution:
        self.check_solution_validity(td, actions)

    # Gather locations in order of tour and return distance between them (i.e., -reward)
    locs_ordered = gather_by_index(td["locs"], actions)
    return -get_tour_length(locs_ordered)
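For a single instance, the reward computation reduces to reordering the coordinates by visiting order (what gather_by_index does in the batched code) and summing the edge lengths of the closed tour. The plain-Python sketch below assumes that get_tour_length includes the edge from the last city back to the first; tsp_reward is a hypothetical name. The length is negated because training maximizes reward, so shorter tours must score higher.

```python
import math


def tsp_reward(locs, actions):
    """Single-instance sketch of _get_reward: minus the closed-tour length."""
    # Equivalent of gather_by_index: coordinates in visiting order
    ordered = [locs[a] for a in actions]
    # Consecutive edges plus the edge back to the start city
    length = sum(
        math.dist(ordered[k], ordered[(k + 1) % len(ordered)])
        for k in range(len(ordered))
    )
    return -length


locs = [(0.0, 0.0), (3.0, 0.0), (3.0, 4.0)]
print(tsp_reward(locs, [0, 1, 2]))  # -12.0 (edges of length 3, 4 and 5)
```

Rotating the tour, e.g. [2, 0, 1], leaves the reward unchanged, since it describes the same closed loop.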
Environment Action Specs
This defines the input and output domains of the environment, similar to Gym's spaces. It is not strictly necessary, but it gives a clear definition of the environment's action and observation spaces and lets us sample actions using TorchRL's utilities.
def _make_spec(self, generator):
    """Make the observation and action specs from the parameters"""
    self.observation_spec = Composite(
        locs=Bounded(
            low=generator.min_loc,
            high=generator.max_loc,
            shape=(generator.num_loc, 2),
            dtype=torch.float32,
        ),
        first_node=Unbounded(
            shape=(1,),
            dtype=torch.int64,
        ),
        current_node=Unbounded(
            shape=(1,),
            dtype=torch.int64,
        ),
        i=Unbounded(
            shape=(1,),
            dtype=torch.int64,
        ),
        action_mask=Unbounded(
            shape=(generator.num_loc,),
            dtype=torch.bool,
        ),
        shape=(),
    )
    self.action_spec = Bounded(
        shape=(1,),
        dtype=torch.int64,
        low=0,
        high=generator.num_loc,
    )
    self.reward_spec = Unbounded(shape=(1,))
    self.done_spec = Unbounded(shape=(1,), dtype=torch.bool)
Data generator
The generator creates random instances of the problem. Note that this is a simplified example: it can be extended with additional distributions via the rl4co.envs.common.utils.get_sampler method!
class TSPGenerator(Generator):
    def __init__(
        self,
        num_loc: int = 20,
        min_loc: float = 0.0,
        max_loc: float = 1.0,
    ):
        self.num_loc = num_loc
        self.min_loc = min_loc
        self.max_loc = max_loc
        self.loc_sampler = torch.distributions.Uniform(
            low=min_loc, high=max_loc
        )

    def _generate(self, batch_size) -> TensorDict:
        # Sample locations
        locs = self.loc_sampler.sample((*batch_size, self.num_loc, 2))
        return TensorDict({"locs": locs}, batch_size=batch_size)
# Test generator
generator = TSPGenerator(num_loc=20)
locs = generator(32)
print(locs["locs"].shape)
torch.Size([32, 20, 2])
Render function
The render function is optional, but it can be useful for quickly visualizing the results of your algorithm!
def render(self, td, actions=None, ax=None):
    import matplotlib.pyplot as plt
    import numpy as np

    if ax is None:
        # Create a plot of the nodes
        _, ax = plt.subplots()

    td = td.detach().cpu()

    if actions is None:
        actions = td.get("action", None)

    # If the TensorDict is batched, render only the first instance
    if td.batch_size != torch.Size([]):
        td = td[0]
        if actions is not None:
            actions = actions[0]

    locs = td["locs"]

    # Gather locs in order of action if available
    if actions is None:
        print("No action in TensorDict, rendering unsorted locs")
    else:
        actions = actions.detach().cpu()
        locs = gather_by_index(locs, actions, dim=0)

    # Cat the first node to the end to complete the tour
    locs = torch.cat((locs, locs[0:1]))
    x, y = locs[:, 0], locs[:, 1]

    # Plot the visited nodes
    ax.scatter(x, y, color="tab:blue")

    # Add arrows between visited nodes as a quiver plot
    dx, dy = np.diff(x), np.diff(y)
    ax.quiver(
        x[:-1], y[:-1], dx, dy, scale_units="xy", angles="xy", scale=1, color="k"
    )

    # Setup limits and show
    ax.set_xlim(-0.05, 1.05)
    ax.set_ylim(-0.05, 1.05)
Putting everything together
class TSPEnv(RL4COEnvBase):
    """Traveling Salesman Problem (TSP) environment"""

    name = "tsp"

    def __init__(
        self,
        generator=TSPGenerator,
        generator_params: Optional[dict] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Use None instead of a shared mutable default ({}) for the parameters
        self.generator = generator(**(generator_params or {}))
        self._make_spec(self.generator)

    _reset = _reset
    _step = _step
    _get_reward = _get_reward
    check_solution_validity = check_solution_validity
    get_action_mask = get_action_mask
    _make_spec = _make_spec
    render = render
batch_size = 2
env = TSPEnv(generator_params=dict(num_loc=20))
reward, td, actions = rollout(env, env.reset(batch_size=[batch_size]), random_policy)
env.render(td, actions)
Modeling
Now we need to model the problem by transforming the input information into a latent space where it can be processed. Here we focus on AttentionModel-based embeddings with an encoder-decoder structure. In RL4CO, we divide embeddings into 3 parts:
- init_embedding: (encoder) embed the initial states of the problem
- context_embedding: (decoder) embed context information of the problem for the current partial solution to modify the query
- dynamic_embedding: (decoder) embed dynamic information of the problem for the current partial solution to modify the query, key, and value (i.e. if other nodes also change state)
Init Embedding
Embed the initial problem into the latent space. In our case, we can project the coordinates of the cities into a latent space.
class TSPInitEmbedding(nn.Module):
    """Initial embedding for the Traveling Salesman Problem (TSP).
    Embed the following node features to the embedding space:
        - locs: x, y coordinates of the cities
    """

    def __init__(self, embed_dim, linear_bias=True):
        super().__init__()
        node_dim = 2  # x, y
        self.init_embed = nn.Linear(node_dim, embed_dim, linear_bias)

    def forward(self, td):
        out = self.init_embed(td["locs"])
        return out
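TSPInitEmbedding applies one linear map to every city independently: nn.Linear acts on the last dimension of td["locs"], so a (batch, num_loc, 2) input becomes (batch, num_loc, embed_dim). For a single city the computation is y = W x + b, sketched below in plain Python. The helper linear and the values of W and b are hypothetical; in the model the weights are learned.

```python
def linear(x, weight, bias):
    """What nn.Linear computes for one input vector: y = W x + b.
    `weight` has shape (out_features, in_features), like nn.Linear.weight."""
    return [
        sum(w_ij * x_j for w_ij, x_j in zip(row, x)) + b_i
        for row, b_i in zip(weight, bias)
    ]


# Toy example: embed one city (x, y) = (0.5, 0.25) into embed_dim = 3
W = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # made-up weights, shape (3, 2)
b = [0.0, 0.0, 0.5]
print(linear([0.5, 0.25], W, b))  # [0.5, 0.25, 1.25]
```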
Context Embedding
Context embedding takes the current context and returns a vector representation of it. In TSP, we can take the embedding of the first node visited (since we need to return to it to complete the tour) as well as the embedding of the current node. At the first step no node has been visited yet, so a learned placeholder is used instead.
class TSPContext(nn.Module):
    """Context embedding for the Traveling Salesman Problem (TSP).
    Project the following to the embedding space:
        - first node embedding
        - current node embedding
    """

    def __init__(self, embed_dim, linear_bias=True):
        super().__init__()
        self.W_placeholder = nn.Parameter(
            torch.Tensor(2 * embed_dim).uniform_(-1, 1)
        )
        self.project_context = nn.Linear(
            embed_dim * 2, embed_dim, bias=linear_bias
        )

    def forward(self, embeddings, td):
        batch_size = embeddings.size(0)
        # By default, node_dim = -1 (we only have one node embedding per node)
        node_dim = (
            (-1,) if td["first_node"].dim() == 1 else (td["first_node"].size(-1), -1)
        )
        if td["i"][(0,) * td["i"].dim()].item() < 1:  # get first item fast
            context_embedding = self.W_placeholder[None, :].expand(
                batch_size, self.W_placeholder.size(-1)
            )
        else:
            context_embedding = gather_by_index(
                embeddings,
                torch.stack([td["first_node"], td["current_node"]], -1).view(
                    batch_size, -1
                ),
            ).view(batch_size, *node_dim)
        return self.project_context(context_embedding)
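Before the final projection, TSPContext builds a vector of size 2 * embed_dim. The single-instance sketch below (plain Python, hypothetical helper tsp_context, made-up numbers) shows both cases: the placeholder at the first step, and afterwards the concatenation of the first-node and current-node embeddings, in the same order as torch.stack([first_node, current_node], -1) in forward.

```python
def tsp_context(node_embeddings, first_node, current_node, step, placeholder):
    """Single-instance sketch of TSPContext before project_context.

    Returns a vector of size 2 * embed_dim: the placeholder at the first step,
    otherwise [embedding of first node, embedding of current node].
    """
    if step < 1:
        # No node has been visited yet, so there is nothing to gather
        return list(placeholder)
    return list(node_embeddings[first_node]) + list(node_embeddings[current_node])


emb = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]  # 3 cities, embed_dim = 2
placeholder = [9.0, 9.0, 9.0, 9.0]          # stands in for W_placeholder
print(tsp_context(emb, 0, 0, step=0, placeholder=placeholder))  # [9.0, 9.0, 9.0, 9.0]
print(tsp_context(emb, 0, 2, step=2, placeholder=placeholder))  # [0.1, 0.2, 0.5, 0.6]
```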
Dynamic Embedding
Since the node states do not change (apart from which nodes have been visited, which the action mask already tracks), we do not need to modify the keys and values. Therefore, the dynamic embedding simply returns zeros.
class StaticEmbedding(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, td):
        # No dynamic updates: return zeros for all three dynamic terms
        return 0, 0, 0
Training our Model
# Instantiate our environment
env = TSPEnv(generator_params=dict(num_loc=20))

# Instantiate policy with the embeddings we created above
emb_dim = 128
policy = AttentionModelPolicy(
    env_name=env.name,  # this is actually not needed since we are initializing the embeddings!
    embed_dim=emb_dim,
    init_embedding=TSPInitEmbedding(emb_dim),
    context_embedding=TSPContext(emb_dim),
    dynamic_embedding=StaticEmbedding(emb_dim),
)

# Model: default is AM with REINFORCE and greedy rollout baseline
model = AttentionModel(
    env,
    policy=policy,
    baseline="rollout",
    train_data_size=100_000,
    val_data_size=10_000,
)
/home/botu/mambaforge/envs/rl4co/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:199: Attribute 'env' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['env'])`.
/home/botu/mambaforge/envs/rl4co/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:199: Attribute 'policy' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['policy'])`.
Rollout untrained model
# Greedy rollouts over untrained model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
td_init = env.reset(batch_size=[3]).to(device)
policy = model.policy.to(device)
out = policy(td_init.clone(), env, phase="test", decode_type="greedy")
actions_untrained = out["actions"].cpu().detach()
rewards_untrained = out["reward"].cpu().detach()

for i in range(3):
    print(f"Problem {i+1} | Cost: {-rewards_untrained[i]:.3f}")
    env.render(td_init[i], actions_untrained[i])
Problem 1 | Cost: 11.545
Problem 2 | Cost: 8.525
Problem 3 | Cost: 12.461
Training loop
# We use our own wrapper around Lightning's `Trainer` to make it easier to use
trainer = RL4COTrainer(max_epochs=3, devices=1)
trainer.fit(model)
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/botu/mambaforge/envs/rl4co/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
val_file not set. Generating dataset instead
test_file not set. Generating dataset instead
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name     | Type                 | Params
--------------------------------------------------
0 | env      | TSPEnv               | 0
1 | policy   | AttentionModelPolicy | 710 K
2 | baseline | WarmupBaseline       | 710 K
--------------------------------------------------
1.4 M     Trainable params
0         Non-trainable params
1.4 M     Total params
5.682     Total estimated model params size (MB)
/home/botu/mambaforge/envs/rl4co/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.
/home/botu/mambaforge/envs/rl4co/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.
`Trainer.fit` stopped: `max_epochs=3` reached.
Evaluation
# Greedy rollouts over trained policy (same states as previous plot)
policy = model.policy.to(device)
out = policy(td_init.clone(), env, phase="test", decode_type="greedy")
actions_trained = out["actions"].cpu().detach()

# Plotting
import matplotlib.pyplot as plt

for i, td in enumerate(td_init):
    fig, axs = plt.subplots(1, 2, figsize=(11, 5))
    env.render(td, actions_untrained[i], ax=axs[0])
    env.render(td, actions_trained[i], ax=axs[1])
    axs[0].set_title(f"Untrained | Cost = {-rewards_untrained[i].item():.3f}")
    axs[1].set_title(r"Trained $\pi_\theta$" + f" | Cost = {-out['reward'][i].item():.3f}")