{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Encoder Customization\n", "\n", "This tutorial shows how to customize the encoder used by an RL4CO model.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Installation\n", "\n", "Uncomment the following line to install the package from PyPI. Remember to choose a GPU runtime for faster training!\n", "\n", "> Note: You may need to restart the runtime in Colab after installing.\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# !pip install rl4co[graph] # include torch-geometric\n", "\n", "## NOTE: to install the latest version from GitHub (may be unstable), install from source instead:\n", "# !pip install git+https://github.com/ai4co/rl4co.git" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from rl4co.envs import CVRPEnv\n", "\n", "from rl4co.models.zoo import AttentionModel\n", "from rl4co.utils.trainer import RL4COTrainer" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## A default minimal training script\n", "\n", "Here we use the CVRP environment and the Attention Model (AM) as a minimal example of a training script. By default, the AM is initialized with a Graph Attention Encoder, but we can replace it with any other encoder." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/datasets/home/botu/mambaforge/envs/rl4co-new/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:198: Attribute 'env' is an instance of `nn.Module` and is already saved during checkpointing. 
It is recommended to ignore them using `self.save_hyperparameters(ignore=['env'])`.\n", "/datasets/home/botu/mambaforge/envs/rl4co-new/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:198: Attribute 'policy' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['policy'])`.\n", "Using 16bit Automatic Mixed Precision (AMP)\n", "GPU available: True (cuda), used: True\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "TPU available: False, using: 0 TPU cores\n", "IPU available: False, using: 0 IPUs\n", "HPU available: False, using: 0 HPUs\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Encoder: GraphAttentionEncoder\n" ] } ], "source": [ "# Init env, model, trainer\n", "env = CVRPEnv(generator_params=dict(num_loc=20))\n", "\n", "model = AttentionModel(\n", "    env,\n", "    baseline='rollout',\n", "    train_data_size=100_000, # really small size for demo\n", "    val_data_size=10_000\n", ")\n", "\n", "trainer = RL4COTrainer(\n", "    max_epochs=3, # few epochs for demo\n", "    accelerator='gpu',\n", "    devices=1,\n", "    logger=False,\n", ")\n", "\n", "# By default the AM uses the Graph Attention Encoder\n", "print(f'Encoder: {model.policy.encoder._get_name()}')" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/datasets/home/botu/mambaforge/envs/rl4co-new/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:630: Checkpoint directory /datasets/home/botu/Dev/rl4co/notebooks/tutorials/checkpoints exists and is not empty.\n", "val_file not set. Generating dataset instead\n", "test_file not set. 
Generating dataset instead\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", "\n", " | Name | Type | Params\n", "--------------------------------------------------\n", "0 | env | CVRPEnv | 0 \n", "1 | policy | AttentionModelPolicy | 694 K \n", "2 | baseline | WarmupBaseline | 694 K \n", "--------------------------------------------------\n", "1.4 M Trainable params\n", "0 Non-trainable params\n", "1.4 M Total params\n", "5.553 Total estimated model params size (MB)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "3db02ec8f6dc4913a26462bbc1a851e8", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Sanity Checking: | | 0/? [00:00 Note: while we provide these examples, you can also implement your own encoder and use it in RL4CO! For instance, you may use different encoders (and decoders) to solve problems that require e.g. distance matrices as input" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# Before we init, we need to install the graph neural network dependencies\n", "# !pip install rl4co[graph]" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/datasets/home/botu/mambaforge/envs/rl4co-new/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:198: Attribute 'env' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['env'])`.\n", "/datasets/home/botu/mambaforge/envs/rl4co-new/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:198: Attribute 'policy' is an instance of `nn.Module` and is already saved during checkpointing. 
It is recommended to ignore them using `self.save_hyperparameters(ignore=['policy'])`.\n", "Using 16bit Automatic Mixed Precision (AMP)\n", "GPU available: True (cuda), used: True\n", "TPU available: False, using: 0 TPU cores\n", "IPU available: False, using: 0 IPUs\n", "HPU available: False, using: 0 HPUs\n" ] } ], "source": [ "# Init the model with a different encoder\n", "from rl4co.models.nn.graph.gcn import GCNEncoder\n", "from rl4co.models.nn.graph.mpnn import MessagePassingEncoder\n", "\n", "gcn_encoder = GCNEncoder(\n", "    env_name='cvrp',\n", "    embed_dim=128,\n", "    num_nodes=20,\n", "    num_layers=3,\n", ")\n", "\n", "mpnn_encoder = MessagePassingEncoder(\n", "    env_name='cvrp',\n", "    embed_dim=128,\n", "    num_nodes=20,\n", "    num_layers=3,\n", ")\n", "\n", "model = AttentionModel(\n", "    env,\n", "    baseline='rollout',\n", "    train_data_size=100_000, # really small size for demo\n", "    val_data_size=10_000,\n", "    policy_kwargs={\n", "        'encoder': gcn_encoder # pass mpnn_encoder here to use the message passing encoder instead\n", "    }\n", ")\n", "\n", "trainer = RL4COTrainer(\n", "    max_epochs=3, # few epochs for demo\n", "    accelerator='gpu',\n", "    devices=1,\n", "    logger=False,\n", ")" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/datasets/home/botu/mambaforge/envs/rl4co-new/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:630: Checkpoint directory /datasets/home/botu/Dev/rl4co/notebooks/tutorials/checkpoints exists and is not empty.\n", "val_file not set. Generating dataset instead\n", "test_file not set. 
Generating dataset instead\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", "\n", " | Name | Type | Params\n", "--------------------------------------------------\n", "0 | env | CVRPEnv | 0 \n", "1 | policy | AttentionModelPolicy | 148 K \n", "2 | baseline | WarmupBaseline | 148 K \n", "--------------------------------------------------\n", "297 K Trainable params\n", "0 Non-trainable params\n", "297 K Total params\n", "1.191 Total estimated model params size (MB)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "81c30fe25912497bb53cfb492810c655", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Sanity Checking: | | 0/? [00:00 Tuple[Tensor, Tensor]:\n", " \"\"\"\n", " Args:\n", " td: Input TensorDict containing the environment state\n", " mask: Mask to apply to the attention\n", "\n", " Returns:\n", " h: Latent representation of the input\n", " init_h: Initial embedding of the input\n", " \"\"\"\n", " init_h = self.init_embedding(td)\n", " h = None\n", " return h, init_h" ] } ], "metadata": { "kernelspec": { "display_name": "rl4co", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.6" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }