{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Test Model on TSPLib\n", "\n", "In this notebook, we will test the trained model's performance on the TSPLib benchmark. We will use the trained model from the previous notebook.\n", "\n", "[TSPLib](http://comopt.ifi.uni-heidelberg.de/software/TSPLIB95/) is a library of sample instances for the TSP (and related problems) from various sources and of various types. In the TSPLib, there are several problems, including *TSP, HCP, ATSP*, etc. In this notebook, we will focus on testing the model on the TSP problem.\n", "\n", "\n", "\n", "\"Open\n", "\n", "\n", "## Before we start\n", "\n", "Before we test the model on TSPLib dataset, we need to prepare the dataset first by the following steps:\n", "\n", "**Step 1**. You may come to [here](http://comopt.ifi.uni-heidelberg.de/software/TSPLIB95/tsp/) to download the *symmetric traveling salesman problem* data in TSPLib dataset and unzip to a folder;\n", "\n", "Note that the downloaded data is `gzip` file with the following file tree:\n", "```\n", ".\n", "└── ALL_tsp/\n", " ├── a280.opt.tour.gz\n", " ├── a280.tsp.gz\n", " ├── ali535.tsp.gz\n", " └── ... (other problems)\n", "```\n", "We need to unzip the `gzip` file to get the `.tsp` and `.opt.tour` files. We can use the following command to unzip them to the same folder:\n", "```bash\n", "find . -name \"*.gz\" -exec gunzip {} +\n", "```\n", "\n", "After doing this, we will get the following file tree:\n", "```\n", ".\n", "└── ALL_tsp/\n", " ├── a280.opt.tour\n", " ├── a280.tsp\n", " ├── ali535.tsp\n", " └── ... (other problems)\n", "```\n", "\n", "**Step 2**. To read the TSPLib problem and optimal solution, we choose to use the `tsplib95` package. Use `pip install tsplib95` to install the package." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Installation\n", "\n", "Uncomment the following line to install the package from PyPI. Remember to choose a GPU runtime for faster training!\n", "\n", "> Note: You may need to restart the runtime in Colab after this\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# !pip install rl4co[graph] # include torch-geometric\n", "\n", "## NOTE: to install latest version from Github (may be unstable) install from source instead:\n", "# !pip install git+https://github.com/ai4co/rl4co.git" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Install the `tsplib95` package\n", "# !pip install tsplib95" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/cbhua/.local/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", "\n", "import os\n", "import re\n", "import torch\n", "\n", "from rl4co.envs import TSPEnv, CVRPEnv\n", "from rl4co.models.zoo.am import AttentionModel\n", "from rl4co.utils.trainer import RL4COTrainer\n", "from rl4co.utils.decoding import get_log_likelihood\n", "from rl4co.models.zoo import EAS, EASLay, EASEmb, ActiveSearch\n", "\n", "from math import ceil\n", "from einops import repeat\n", "from tsplib95.loaders import load_problem, load_solution" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load a trained model" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/cbhua/miniconda3/envs/rl4co-user/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:198: Attribute 'env' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['env'])`.\n", "/home/cbhua/miniconda3/envs/rl4co-user/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:198: Attribute 'policy' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['policy'])`.\n", "/home/cbhua/miniconda3/envs/rl4co-user/lib/python3.10/site-packages/lightning/pytorch/core/saving.py:177: Found keys that are not in the model state dict but in the checkpoint: ['baseline.baseline.model.encoder.init_embedding.init_embed.weight', 'baseline.baseline.model.encoder.init_embedding.init_embed.bias', 'baseline.baseline.model.encoder.net.layers.0.0.module.Wqkv.weight', 'baseline.baseline.model.encoder.net.layers.0.0.module.Wqkv.bias', 'baseline.baseline.model.encoder.net.layers.0.0.module.out_proj.weight', 'baseline.baseline.model.encoder.net.layers.0.0.module.out_proj.bias', 'baseline.baseline.model.encoder.net.layers.0.1.normalizer.weight', 'baseline.baseline.model.encoder.net.layers.0.1.normalizer.bias', 'baseline.baseline.model.encoder.net.layers.0.1.normalizer.running_mean', 'baseline.baseline.model.encoder.net.layers.0.1.normalizer.running_var', 'baseline.baseline.model.encoder.net.layers.0.1.normalizer.num_batches_tracked', 'baseline.baseline.model.encoder.net.layers.0.2.module.0.weight', 'baseline.baseline.model.encoder.net.layers.0.2.module.0.bias', 'baseline.baseline.model.encoder.net.layers.0.2.module.2.weight', 'baseline.baseline.model.encoder.net.layers.0.2.module.2.bias', 'baseline.baseline.model.encoder.net.layers.0.3.normalizer.weight', 'baseline.baseline.model.encoder.net.layers.0.3.normalizer.bias', 'baseline.baseline.model.encoder.net.layers.0.3.normalizer.running_mean', 'baseline.baseline.model.encoder.net.layers.0.3.normalizer.running_var', 'baseline.baseline.model.encoder.net.layers.0.3.normalizer.num_batches_tracked', 'baseline.baseline.model.encoder.net.layers.1.0.module.Wqkv.weight', 'baseline.baseline.model.encoder.net.layers.1.0.module.Wqkv.bias', 'baseline.baseline.model.encoder.net.layers.1.0.module.out_proj.weight', 'baseline.baseline.model.encoder.net.layers.1.0.module.out_proj.bias', 'baseline.baseline.model.encoder.net.layers.1.1.normalizer.weight', 'baseline.baseline.model.encoder.net.layers.1.1.normalizer.bias', 'baseline.baseline.model.encoder.net.layers.1.1.normalizer.running_mean', 
'baseline.baseline.model.encoder.net.layers.1.1.normalizer.running_var', 'baseline.baseline.model.encoder.net.layers.1.1.normalizer.num_batches_tracked', 'baseline.baseline.model.encoder.net.layers.1.2.module.0.weight', 'baseline.baseline.model.encoder.net.layers.1.2.module.0.bias', 'baseline.baseline.model.encoder.net.layers.1.2.module.2.weight', 'baseline.baseline.model.encoder.net.layers.1.2.module.2.bias', 'baseline.baseline.model.encoder.net.layers.1.3.normalizer.weight', 'baseline.baseline.model.encoder.net.layers.1.3.normalizer.bias', 'baseline.baseline.model.encoder.net.layers.1.3.normalizer.running_mean', 'baseline.baseline.model.encoder.net.layers.1.3.normalizer.running_var', 'baseline.baseline.model.encoder.net.layers.1.3.normalizer.num_batches_tracked', 'baseline.baseline.model.encoder.net.layers.2.0.module.Wqkv.weight', 'baseline.baseline.model.encoder.net.layers.2.0.module.Wqkv.bias', 'baseline.baseline.model.encoder.net.layers.2.0.module.out_proj.weight', 'baseline.baseline.model.encoder.net.layers.2.0.module.out_proj.bias', 'baseline.baseline.model.encoder.net.layers.2.1.normalizer.weight', 'baseline.baseline.model.encoder.net.layers.2.1.normalizer.bias', 'baseline.baseline.model.encoder.net.layers.2.1.normalizer.running_mean', 'baseline.baseline.model.encoder.net.layers.2.1.normalizer.running_var', 'baseline.baseline.model.encoder.net.layers.2.1.normalizer.num_batches_tracked', 'baseline.baseline.model.encoder.net.layers.2.2.module.0.weight', 'baseline.baseline.model.encoder.net.layers.2.2.module.0.bias', 'baseline.baseline.model.encoder.net.layers.2.2.module.2.weight', 'baseline.baseline.model.encoder.net.layers.2.2.module.2.bias', 'baseline.baseline.model.encoder.net.layers.2.3.normalizer.weight', 'baseline.baseline.model.encoder.net.layers.2.3.normalizer.bias', 'baseline.baseline.model.encoder.net.layers.2.3.normalizer.running_mean', 'baseline.baseline.model.encoder.net.layers.2.3.normalizer.running_var', 'baseline.baseline.model.encoder.net.layers.2.3.normalizer.num_batches_tracked', 'baseline.baseline.model.decoder.context_embedding.W_placeholder', 'baseline.baseline.model.decoder.context_embedding.project_context.weight', 'baseline.baseline.model.decoder.project_node_embeddings.weight', 'baseline.baseline.model.decoder.project_fixed_context.weight', 'baseline.baseline.model.decoder.logit_attention.project_out.weight']\n" ] } ], "source": [ "# Load from checkpoint; alternatively, simply instantiate a new model\n", "# Note the model is trained for TSP problem\n", "checkpoint_path = \"../tsp-20.ckpt\" # modify the path to your checkpoint file\n", "\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "# load checkpoint\n", "# checkpoint = torch.load(checkpoint_path)\n", "\n", "lit_model = AttentionModel.load_from_checkpoint(checkpoint_path, load_baseline=False)\n", "policy, env = lit_model.policy, lit_model.env\n", "policy = policy.to(device)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load tsp problems\n", "\n", "Note that in the TSPLib, only part of the problems have optimal solutions with the same problem name but with `.opt.tour` suffix. For example, `a280.tsp` has the optimal solution `a280.opt.tour`. 
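\n", "\n", "Below is a minimal sketch of how one problem/solution pair can be read (an illustrative example only: it assumes the prepared files sit under `./tsplib` as in the cells below, and uses the non-deprecated `tsplib95.load` API, whereas the loops later in this notebook use the older `load_problem`/`load_solution` helpers):\n", "\n", "```python\n", "import tsplib95\n", "\n", "problem = tsplib95.load('./tsplib/a280.tsp')        # problem instance\n", "solution = tsplib95.load('./tsplib/a280.opt.tour')  # matching optimal tour\n", "print(problem.name, problem.dimension, len(solution.tours[0]))\n", "```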
" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Load the problem from TSPLib\n", "tsplib_dir = './tsplib'# modify this to the directory of your prepared files\n", "files = os.listdir(tsplib_dir)\n", "problem_files_full = [file for file in files if file.endswith('.tsp')]\n", "\n", "# Load the optimal solution files from TSPLib\n", "solution_files = [file for file in files if file.endswith('.opt.tour')] " ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# Utils function\n", "def normalize_coord(coord:torch.Tensor) -> torch.Tensor:\n", " x, y = coord[:, 0], coord[:, 1]\n", " x_min, x_max = x.min(), x.max()\n", " y_min, y_max = y.min(), y.max()\n", " \n", " x_scaled = (x - x_min) / (x_max - x_min) \n", " y_scaled = (y - y_min) / (y_max - y_min)\n", " coord_scaled = torch.stack([x_scaled, y_scaled], dim=1)\n", " return coord_scaled " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Test the greedy\n", "\n", "Note that run all experiments will take long time and require large VRAM. For simple testing, we can use a subset of the problems. " ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "# Customized problem cases\n", "problem_files_custom = [\n", " \"eil51.tsp\", \"berlin52.tsp\", \"st70.tsp\", \"eil76.tsp\", \n", " \"pr76.tsp\", \"rat99.tsp\", \"kroA100.tsp\", \"kroB100.tsp\", \n", " \"kroC100.tsp\", \"kroD100.tsp\", \"kroE100.tsp\", \"rd100.tsp\", \n", " \"eil101.tsp\", \"lin105.tsp\", \"pr124.tsp\", \"bier127.tsp\", \n", " \"ch130.tsp\", \"pr136.tsp\", \"pr144.tsp\", \"kroA150.tsp\", \n", " \"kroB150.tsp\", \"pr152.tsp\", \"u159.tsp\", \"rat195.tsp\", \n", " \"kroA200.tsp\", \"ts225.tsp\", \"tsp225.tsp\", \"pr226.tsp\"\n", "]" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/tmp/ipykernel_3883036/2632546508.py:5: DeprecationWarning: Call to deprecated function (or staticmethod) load_problem. (Will be removed in newer versions. Use `tsplib95.load` instead.) -- Deprecated since version 7.0.0.\n", " problem = load_problem(os.path.join(tsplib_dir, problem_files[problem_idx]))\n", "/tmp/ipykernel_3883036/2632546508.py:43: DeprecationWarning: Call to deprecated function (or staticmethod) load_solution. (Will be removed in newer versions. Use `tsplib95.load` instead.) 
-- Deprecated since version 7.0.0.\n", " solution = load_solution(os.path.join(tsplib_dir, problem_files[problem_idx][:-4] + '.opt.tour'))\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Problem: eil51 Cost: 493 Optimal Cost: 426 \t Gap: 15.73%\n", "problem: eil51 cost: 493 \n", "problem: berlin52 cost: 8957 \n", "Problem: st70 Cost: 806 Optimal Cost: 675 \t Gap: 19.41%\n", "problem: st70 cost: 806 \n", "Problem: eil76 Cost: 693 Optimal Cost: 538 \t Gap: 28.81%\n", "problem: eil76 cost: 693 \n", "Problem: pr76 Cost: 132292 Optimal Cost: 108159 \t Gap: 22.31%\n", "problem: pr76 cost: 132292 \n", "problem: rat99 cost: 2053 \n", "Problem: kroA100 Cost: 30791 Optimal Cost: 21282 \t Gap: 44.68%\n", "problem: kroA100 cost: 30791 \n", "problem: kroB100 cost: 30347 \n", "Problem: kroC100 Cost: 28339 Optimal Cost: 20749 \t Gap: 36.58%\n", "problem: kroC100 cost: 28339 \n", "Problem: kroD100 Cost: 27600 Optimal Cost: 21294 \t Gap: 29.61%\n", "problem: kroD100 cost: 27600 \n", "problem: kroE100 cost: 28396 \n", "Problem: rd100 Cost: 10695 Optimal Cost: 7910 \t Gap: 35.21%\n", "problem: rd100 cost: 10695 \n", "problem: eil101 cost: 919 \n", "Problem: lin105 Cost: 21796 Optimal Cost: 14379 \t Gap: 51.58%\n", "problem: lin105 cost: 21796 \n", "problem: pr124 cost: 75310 \n", "problem: bier127 cost: 177471 \n", "problem: ch130 cost: 8169 \n", "problem: pr136 cost: 135974 \n", "problem: pr144 cost: 71599 \n", "problem: kroA150 cost: 40376 \n", "problem: kroB150 cost: 37076 \n", "problem: pr152 cost: 94805 \n", "problem: u159 cost: 64768 \n", "problem: rat195 cost: 4465 \n", "problem: kroA200 cost: 44181 \n", "problem: ts225 cost: 210475 \n", "Problem: tsp225 Cost: 6212 Optimal Cost: 3919 \t Gap: 58.51%\n", "problem: tsp225 cost: 6212 \n", "problem: pr226 cost: 98849 \n" ] } ], "source": [ "# problem_files = problem_files_full # if you want to test on all the problems\n", "problem_files = problem_files_custom # if you want to test on the customized problems\n", "\n", "for problem_idx in range(len(problem_files)):\n", " problem = load_problem(os.path.join(tsplib_dir, problem_files[problem_idx]))\n", "\n", " # NOTE: in some problem files (e.g. 
hk48), the node coordinates are not available\n", "    # we temporarily skip these problems\n", "    if not len(problem.node_coords):\n", "        continue\n", "\n", "    # Load the node coordinates\n", "    coords = []\n", "    for _, v in problem.node_coords.items():\n", "        coords.append(v)\n", "    coords = torch.tensor(coords).float().to(device)  # [n, 2]\n", "    coords_norm = normalize_coord(coords)\n", "\n", "    # Prepare the tensordict\n", "    batch_size = 2\n", "    td = env.reset(batch_size=(batch_size,)).to(device)\n", "    td['locs'] = repeat(coords_norm, 'n d -> b n d', b=batch_size, d=2)\n", "    td['action_mask'] = torch.ones(batch_size, coords_norm.shape[0], dtype=torch.bool)\n", "\n", "    # Get the solution from the policy\n", "    with torch.no_grad():\n", "        out = policy(\n", "            td.clone(),\n", "            decode_type=\"greedy\",\n", "            return_actions=True,\n", "            num_starts=0\n", "        )\n", "\n", "    # Calculate the cost on the original scale\n", "    td['locs'] = repeat(coords, 'n d -> b n d', b=batch_size, d=2)\n", "    neg_reward = env.get_reward(td, out['actions'])\n", "    cost = ceil(-1 * neg_reward[0].item())\n", "\n", "    # Check if there exists an optimal solution\n", "    try:\n", "        # Load the optimal solution\n", "        solution = load_solution(os.path.join(tsplib_dir, problem_files[problem_idx][:-4] + '.opt.tour'))\n", "        matches = re.findall(r'\\((.*?)\\)', solution.comment)\n", "\n", "        # NOTE: in some solution files (e.g. ch130), the optimal cost is not written in parentheses\n", "        # in the comment, so we temporarily skip the gap calculation for these problems\n", "        optimal_cost = int(matches[0])\n", "        gap = (cost - optimal_cost) / optimal_cost\n", "        print(f'Problem: {problem_files[problem_idx][:-4]:<10} Cost: {cost:<10} Optimal Cost: {optimal_cost:<10}\\t Gap: {gap:.2%}')\n", "    except:\n", "        continue\n", "    finally:\n", "        print(f'problem: {problem_files[problem_idx][:-4]:<10} cost: {cost:<10}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Test augmentation\n", "\n", "Instance augmentation applies symmetric transformations (e.g., rotations and reflections) to the node coordinates, decodes each augmented copy greedily, and keeps the best tour found among them." ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/tmp/ipykernel_3883036/2898406631.py:13: DeprecationWarning: Call to deprecated function (or staticmethod) load_problem. (Will be removed in newer versions. Use `tsplib95.load` instead.) -- Deprecated since version 7.0.0.\n", "  problem = load_problem(os.path.join(tsplib_dir, problem_files[problem_idx]))\n", "/tmp/ipykernel_3883036/2898406631.py:56: DeprecationWarning: Call to deprecated function (or staticmethod) load_solution. (Will be removed in newer versions. Use `tsplib95.load` instead.) 
-- Deprecated since version 7.0.0.\n", " solution = load_solution(os.path.join(tsplib_dir, problem_files[problem_idx][:-4] + '.opt.tour'))\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Problem: eil51\t Cost: 457\t Optimal Cost: 426\t Gap: 7.28%\n", "problem: eil51\t cost: 457\t\n", "problem: berlin52\t cost: 8256\t\n", "Problem: st70\t Cost: 777\t Optimal Cost: 675\t Gap: 15.11%\n", "problem: st70\t cost: 777\t\n", "Problem: eil76\t Cost: 652\t Optimal Cost: 538\t Gap: 21.19%\n", "problem: eil76\t cost: 652\t\n", "Problem: pr76\t Cost: 124939\t Optimal Cost: 108159\t Gap: 15.51%\n", "problem: pr76\t cost: 124939\t\n", "problem: rat99\t cost: 1614\t\n", "Problem: kroA100\t Cost: 27694\t Optimal Cost: 21282\t Gap: 30.13%\n", "problem: kroA100\t cost: 27694\t\n", "problem: kroB100\t cost: 28244\t\n", "Problem: kroC100\t Cost: 25032\t Optimal Cost: 20749\t Gap: 20.64%\n", "problem: kroC100\t cost: 25032\t\n", "Problem: kroD100\t Cost: 26811\t Optimal Cost: 21294\t Gap: 25.91%\n", "problem: kroD100\t cost: 26811\t\n", "problem: kroE100\t cost: 27831\t\n", "Problem: rd100\t Cost: 9657\t Optimal Cost: 7910\t Gap: 22.09%\n", "problem: rd100\t cost: 9657\t\n", "problem: eil101\t cost: 781\t\n", "Problem: lin105\t Cost: 18769\t Optimal Cost: 14379\t Gap: 30.53%\n", "problem: lin105\t cost: 18769\t\n", "problem: pr124\t cost: 72115\t\n", "problem: bier127\t cost: 154518\t\n", "problem: ch130\t cost: 7543\t\n", "problem: pr136\t cost: 128134\t\n", "problem: pr144\t cost: 69755\t\n", "problem: kroA150\t cost: 35967\t\n", "problem: kroB150\t cost: 35196\t\n", "problem: pr152\t cost: 88602\t\n", "problem: u159\t cost: 59484\t\n", "problem: rat195\t cost: 3631\t\n", "problem: kroA200\t cost: 42061\t\n", "problem: ts225\t cost: 196545\t\n", "Problem: tsp225\t Cost: 5680\t Optimal Cost: 3919\t Gap: 44.93%\n", "problem: tsp225\t cost: 5680\t\n", "problem: pr226\t cost: 94540\t\n" ] } ], "source": [ "# problem_files = problem_files_full # if you want to test on all the problems\n", "problem_files = problem_files_custom # if you want to test on the customized problems\n", "\n", "# Import augmented utils\n", "from rl4co.data.transforms import (\n", " StateAugmentation as SymmetricStateAugmentation)\n", "from rl4co.utils.ops import batchify, unbatchify\n", "\n", "num_augment = 100\n", "augmentation = SymmetricStateAugmentation(num_augment=num_augment)\n", "\n", "for problem_idx in range(len(problem_files)):\n", " problem = load_problem(os.path.join(tsplib_dir, problem_files[problem_idx]))\n", "\n", " # NOTE: in some problem files (e.g. 
hk48), the node coordinates are not available\n", "    # we temporarily skip these problems\n", "    if not len(problem.node_coords):\n", "        continue\n", "\n", "    # Load the node coordinates\n", "    coords = []\n", "    for _, v in problem.node_coords.items():\n", "        coords.append(v)\n", "    coords = torch.tensor(coords).float().to(device)  # [n, 2]\n", "    coords_norm = normalize_coord(coords)\n", "\n", "    # Prepare the tensordict\n", "    batch_size = 2\n", "    td = env.reset(batch_size=(batch_size,)).to(device)\n", "    td['locs'] = repeat(coords_norm, 'n d -> b n d', b=batch_size, d=2)\n", "    td['action_mask'] = torch.ones(batch_size, coords_norm.shape[0], dtype=torch.bool)\n", "\n", "    # Augmentation\n", "    td = augmentation(td)\n", "\n", "    # Get the solution from the policy\n", "    with torch.no_grad():\n", "        out = policy(\n", "            td.clone(),\n", "            decode_type=\"greedy\",\n", "            return_actions=True,\n", "            num_starts=0\n", "        )\n", "\n", "    # Calculate the cost on the original scale\n", "    coords_repeat = repeat(coords, 'n d -> b n d', b=batch_size, d=2)\n", "    td['locs'] = batchify(coords_repeat, num_augment)\n", "    reward = env.get_reward(td, out['actions'])\n", "    reward = unbatchify(reward, num_augment)\n", "    cost = ceil(-1 * torch.max(reward).item())\n", "\n", "    # Check if there exists an optimal solution\n", "    try:\n", "        # Load the optimal solution\n", "        solution = load_solution(os.path.join(tsplib_dir, problem_files[problem_idx][:-4] + '.opt.tour'))\n", "        matches = re.findall(r'\\((.*?)\\)', solution.comment)\n", "\n", "        # NOTE: in some solution files (e.g. ch130), the optimal cost is not written in parentheses\n", "        # in the comment, so we temporarily skip the gap calculation for these problems\n", "        optimal_cost = int(matches[0])\n", "        gap = (cost - optimal_cost) / optimal_cost\n", "        print(f'Problem: {problem_files[problem_idx][:-4]}\\t Cost: {cost}\\t Optimal Cost: {optimal_cost}\\t Gap: {gap:.2%}')\n", "    except:\n", "        continue\n", "    finally:\n", "        print(f'problem: {problem_files[problem_idx][:-4]}\\t cost: {cost}\\t')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Test sampling\n", "\n", "Sampling decodes multiple solutions stochastically from the policy (here with a softmax temperature) and keeps the best one." ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/tmp/ipykernel_3883036/2154301274.py:9: DeprecationWarning: Call to deprecated function (or staticmethod) load_problem. (Will be removed in newer versions. Use `tsplib95.load` instead.) -- Deprecated since version 7.0.0.\n", "  problem = load_problem(os.path.join(tsplib_dir, problem_files[problem_idx]))\n", "/tmp/ipykernel_3883036/2154301274.py:53: DeprecationWarning: Call to deprecated function (or staticmethod) load_solution. (Will be removed in newer versions. Use `tsplib95.load` instead.) 
-- Deprecated since version 7.0.0.\n", " solution = load_solution(os.path.join(tsplib_dir, problem_files[problem_idx][:-4] + '.opt.tour'))\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Problem: eil51\t Cost: 482\t Optimal Cost: 426\t Gap: 13.15%\n", "problem: eil51\t cost: 482\t\n", "problem: berlin52\t cost: 8955\t\n", "Problem: st70\t Cost: 794\t Optimal Cost: 675\t Gap: 17.63%\n", "problem: st70\t cost: 794\t\n", "Problem: eil76\t Cost: 673\t Optimal Cost: 538\t Gap: 25.09%\n", "problem: eil76\t cost: 673\t\n", "Problem: pr76\t Cost: 127046\t Optimal Cost: 108159\t Gap: 17.46%\n", "problem: pr76\t cost: 127046\t\n", "problem: rat99\t cost: 1886\t\n", "Problem: kroA100\t Cost: 29517\t Optimal Cost: 21282\t Gap: 38.69%\n", "problem: kroA100\t cost: 29517\t\n", "problem: kroB100\t cost: 28892\t\n", "Problem: kroC100\t Cost: 26697\t Optimal Cost: 20749\t Gap: 28.67%\n", "problem: kroC100\t cost: 26697\t\n", "Problem: kroD100\t Cost: 27122\t Optimal Cost: 21294\t Gap: 27.37%\n", "problem: kroD100\t cost: 27122\t\n", "problem: kroE100\t cost: 28016\t\n", "Problem: rd100\t Cost: 10424\t Optimal Cost: 7910\t Gap: 31.78%\n", "problem: rd100\t cost: 10424\t\n", "problem: eil101\t cost: 837\t\n", "Problem: lin105\t Cost: 19618\t Optimal Cost: 14379\t Gap: 36.44%\n", "problem: lin105\t cost: 19618\t\n", "problem: pr124\t cost: 74699\t\n", "problem: bier127\t cost: 170255\t\n", "problem: ch130\t cost: 7985\t\n", "problem: pr136\t cost: 129964\t\n", "problem: pr144\t cost: 70477\t\n", "problem: kroA150\t cost: 37185\t\n", "problem: kroB150\t cost: 35172\t\n", "problem: pr152\t cost: 97244\t\n", "problem: u159\t cost: 59792\t\n", "problem: rat195\t cost: 4325\t\n", "problem: kroA200\t cost: 42059\t\n", "problem: ts225\t cost: 205982\t\n", "Problem: tsp225\t Cost: 5970\t Optimal Cost: 3919\t Gap: 52.33%\n", "problem: tsp225\t cost: 5970\t\n", "problem: pr226\t cost: 103135\t\n" ] } ], "source": [ "# problem_files = problem_files_full # if you want to test on all the problems\n", "problem_files = problem_files_custom # if you want to test on the customized problems\n", "\n", "# Parameters for sampling\n", "num_samples = 100\n", "softmax_temp = 0.05\n", "\n", "for problem_idx in range(len(problem_files)):\n", " problem = load_problem(os.path.join(tsplib_dir, problem_files[problem_idx]))\n", "\n", " # NOTE: in some problem files (e.g. 
hk48), the node coordinates are not available\n", "    # we temporarily skip these problems\n", "    if not len(problem.node_coords):\n", "        continue\n", "\n", "    # Load the node coordinates\n", "    coords = []\n", "    for _, v in problem.node_coords.items():\n", "        coords.append(v)\n", "    coords = torch.tensor(coords).float().to(device)  # [n, 2]\n", "    coords_norm = normalize_coord(coords)\n", "\n", "    # Prepare the tensordict\n", "    batch_size = 2\n", "    td = env.reset(batch_size=(batch_size,)).to(device)\n", "    td['locs'] = repeat(coords_norm, 'n d -> b n d', b=batch_size, d=2)\n", "    td['action_mask'] = torch.ones(batch_size, coords_norm.shape[0], dtype=torch.bool)\n", "\n", "    # Sampling\n", "    td = batchify(td, num_samples)\n", "\n", "    # Get the solution from the policy\n", "    with torch.no_grad():\n", "        out = policy(\n", "            td.clone(),\n", "            decode_type=\"sampling\",\n", "            return_actions=True,\n", "            num_starts=0,\n", "            softmax_temp=softmax_temp\n", "        )\n", "\n", "    # Calculate the cost on the original scale\n", "    coords_repeat = repeat(coords, 'n d -> b n d', b=batch_size, d=2)\n", "    td['locs'] = batchify(coords_repeat, num_samples)\n", "    reward = env.get_reward(td, out['actions'])\n", "    reward = unbatchify(reward, num_samples)\n", "    cost = ceil(-1 * torch.max(reward).item())\n", "\n", "    # Check if there exists an optimal solution\n", "    try:\n", "        # Load the optimal solution\n", "        solution = load_solution(os.path.join(tsplib_dir, problem_files[problem_idx][:-4] + '.opt.tour'))\n", "        matches = re.findall(r'\\((.*?)\\)', solution.comment)\n", "\n", "        # NOTE: in some solution files (e.g. ch130), the optimal cost is not written in parentheses\n", "        # in the comment, so we temporarily skip the gap calculation for these problems\n", "        optimal_cost = int(matches[0])\n", "        gap = (cost - optimal_cost) / optimal_cost\n", "        print(f'Problem: {problem_files[problem_idx][:-4]}\\t Cost: {cost}\\t Optimal Cost: {optimal_cost}\\t Gap: {gap:.2%}')\n", "    except:\n", "        continue\n", "    finally:\n", "        print(f'problem: {problem_files[problem_idx][:-4]}\\t cost: {cost}\\t')" ] } ], "metadata": { "kernelspec": { "display_name": "rl4co-user", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.0" } }, "nbformat": 4, "nbformat_minor": 2 }