Solving the Flexible Job-Shop Scheduling Problem (FJSP)
This notebook introduces the FJSP and walks through the solution construction process using an encoder-decoder architecture based on a Heterogeneous Graph Neural Network (HetGNN).
! pip install torch_geometric
%load_ext autoreload
%autoreload 2
import torch
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
import time
import networkx as nx
from rl4co.envs import FJSPEnv
from rl4co.models.zoo.l2d import L2DModel
from rl4co.models.zoo.l2d.policy import L2DPolicy
from rl4co.models.zoo.l2d.decoder import L2DDecoder
from rl4co.models.nn.graph.hgnn import HetGNNEncoder
from rl4co.utils.trainer import RL4COTrainer
generator_params = {
    "num_jobs": 5,  # the total number of jobs
    "num_machines": 5,  # the total number of machines that can process operations
    "min_ops_per_job": 1,  # minimum number of operations per job
    "max_ops_per_job": 2,  # maximum number of operations per job
    "min_processing_time": 1,  # the minimum time required for a machine to process an operation
    "max_processing_time": 20,  # the maximum time required for a machine to process an operation
    "min_eligible_ma_per_op": 1,  # the minimum number of machines capable of processing an operation
    "max_eligible_ma_per_op": 2,  # the maximum number of machines capable of processing an operation
}
env = FJSPEnv(generator_params=generator_params)
td = env.reset(batch_size=[1])
Visualize the Problem
Below we visualize the generated instance of the FJSP. Blue nodes correspond to machines, red nodes to operations and yellow nodes to jobs. A machine may process an operation if there exists an edge between the two.
The thickness of a machine-operation edge reflects the processing time that machine needs for the operation (thicker edge = longer processing time).
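The machine-operation edges in this figure come straight from the processing-time matrix. As a minimal plain-Python sketch, assuming (as the plotting code in the next cell does) that proc_times is a num_machines x num_operations matrix in which 0 marks an ineligible machine-operation pair, the eligible machines of each operation can be listed as follows. The helper eligible_machines and the toy values are hypothetical, not part of RL4CO.

```python
def eligible_machines(proc_times):
    """Map each operation index to {machine index: processing time}.

    Assumption: proc_times[m][o] == 0 means machine m cannot process operation o.
    """
    num_machines = len(proc_times)
    num_operations = len(proc_times[0]) if num_machines else 0
    return {
        o: {m: proc_times[m][o] for m in range(num_machines) if proc_times[m][o] != 0}
        for o in range(num_operations)
    }


# Hypothetical toy instance: 2 machines, 3 operations
toy_proc_times = [
    [4, 0, 7],  # machine 0
    [6, 3, 0],  # machine 1
]
print(eligible_machines(toy_proc_times))
# {0: {0: 4, 1: 6}, 1: {1: 3}, 2: {0: 7}}
```

Each operation with more than one entry is where the "flexible" part of the FJSP comes in: the policy must pick one of the eligible machines.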
Each operation belongs to exactly one job; an edge between a job node and an operation node indicates that the operation belongs to that job. The number on an operation-job edge gives the operation's position in the job's precedence order, i.e., the order in which the operations of a job must be processed. A job is done when all of its operations are scheduled, and the instance is solved when all jobs are fully scheduled.
Also note that some operation nodes are not connected. These are padding nodes, added so that all instances in a batch have the same number of operations (the maximum number of operations is num_jobs * max_ops_per_job).
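Both the precedence order and the padding can be read off the job-operation adjacency. The sketch below is an illustration, not the environment's own code: it assumes job_ops_adj is a num_jobs x num_operations 0/1 matrix (as in the plotting code below) and that each job's operations appear from left to right in precedence order. Columns that belong to no job are treated as padding. The helper describe_operations is hypothetical.

```python
def describe_operations(job_ops_adj):
    """Return (1-based position in its job per real operation, list of padded operation indices).

    Assumptions: job_ops_adj[j][o] != 0 iff operation o belongs to job j, and a job's
    operations appear in precedence order from left to right.
    """
    num_operations = len(job_ops_adj[0]) if job_ops_adj else 0
    order, padded = {}, []
    for o in range(num_operations):
        owners = [j for j, row in enumerate(job_ops_adj) if row[o] != 0]
        if not owners:
            padded.append(o)  # no job owns this column -> padding operation
            continue
        job = owners[0]  # each real operation belongs to exactly one job
        # position = number of this job's operations in columns 0..o
        order[o] = sum(1 for k in range(o + 1) if job_ops_adj[job][k] != 0)
    return order, padded


# Hypothetical toy instance: 2 jobs, max_ops_per_job = 2 -> 2 * 2 = 4 operation slots
toy_job_ops_adj = [
    [1, 1, 0, 0],  # job 0 has two operations
    [0, 0, 1, 0],  # job 1 has one operation; column 3 is padding
]
print(describe_operations(toy_job_ops_adj))
# ({0: 1, 1: 2, 2: 1}, [3])
```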
# Create a graph with machine, operation and job nodes from the instance tensors
G = nx.Graph()
proc_times = td["proc_times"].squeeze(0)
job_ops_adj = td["job_ops_adj"].squeeze(0)
order = td["ops_sequence_order"].squeeze(0) + 1
num_machines, num_operations = proc_times.shape
num_jobs = job_ops_adj.size(0)
jobs = [f"j{i+1}" for i in range(num_jobs)]
machines = [f"m{i+1}" for i in range(num_machines)]
operations = [f"o{i+1}" for i in range(num_operations)]
# Add nodes from each set
G.add_nodes_from(machines, bipartite=0)
G.add_nodes_from(operations, bipartite=1)
G.add_nodes_from(jobs, bipartite=2)
# Add machine-operation edges; the edge weight is the processing time
for i in range(num_machines):
    for j in range(num_operations):
        edge_weight = proc_times[i][j]
        if edge_weight != 0:
            G.add_edge(f"m{i+1}", f"o{j+1}", weight=edge_weight)
# Add job-operation edges, labeled with the operation's position in its job
for i in range(num_jobs):
    for j in range(num_operations):
        if job_ops_adj[i][j] != 0:
            G.add_edge(f"j{i+1}", f"o{j+1}", weight=3, label=order[j])
widths = [x / 3 for x in nx.get_edge_attributes(G, 'weight').values()]
plt.figure(figsize=(10,6))
# Plot the graph
machines = [n for n, d in G.nodes(data=True) if d['bipartite'] == 0]
operations = [n for n, d in G.nodes(data=True) if d['bipartite'] == 1]
jobs = [n for n, d in G.nodes(data=True) if d['bipartite'] == 2]
pos = {}
pos.update((node, (1, index)) for index, node in enumerate(machines))
pos.update((node, (2, index)) for index, node in enumerate(operations))
pos.update((node, (3, index)) for index, node in enumerate(jobs))
edge_labels = {(u, v): d['label'].item() for u, v, d in G.edges(data=True) if d.get("label") is not None}
nx.draw_networkx_edge_labels(G, {k: (v[0]+.12, v[1]) for k,v in pos.items()}, edge_labels=edge_labels, rotate=False)
nx.draw_networkx_nodes(G, pos, nodelist=machines, node_color='b', label="Machine")
nx.draw_networkx_nodes(G, pos, nodelist=operations, node_color='r', label="Operation")
nx.draw_networkx_nodes(G, pos, nodelist=jobs, node_color='y', label="Job")
nx.draw_networkx_edges(G, pos, width=widths, alpha=0.6)
plt.title('Visualization of the FJSP')
plt.legend(bbox_to_anchor=(.95, 1.05))
plt.axis('off')
plt.show()
Build a Model to Solve the FJSP
In the FJSP, we typically encode operations and machines separately, since they are different node types in a k-partite graph. The encoder for the FJSP therefore returns two hidden representations: the first contains the operation embeddings and the second the machine embeddings (for the instance generated below, 60 padded operations and 5 machines, respectively):
# Let's generate a more complex instance
generator_params = {
    "num_jobs": 10,  # the total number of jobs
    "num_machines": 5,  # the total number of machines that can process operations
    "min_ops_per_job": 4,  # minimum number of operations per job
    "max_ops_per_job": 6,  # maximum number of operations per job
    "min_processing_time": 1,  # the minimum time required for a machine to process an operation
    "max_processing_time": 20,  # the maximum time required for a machine to process an operation
    "min_eligible_ma_per_op": 1,  # the minimum number of machines capable of processing an operation
    "max_eligible_ma_per_op": 5,  # the maximum number of machines capable of processing an operation
}
env = FJSPEnv(generator_params=generator_params)
td = env.reset(batch_size=[1])
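The tensor sizes in the next cells follow directly from generator_params. A small sketch of that arithmetic, assuming operations are padded to num_jobs * max_ops_per_job (as described above) and that the decoder's action space holds one waiting action plus one entry per job-machine pair. expected_sizes is a hypothetical helper, not part of RL4CO.

```python
def expected_sizes(params):
    """Return (padded operations, machines, decoder actions) for a generator config."""
    num_ops = params["num_jobs"] * params["max_ops_per_job"]  # padded operation slots
    num_machines = params["num_machines"]
    num_actions = 1 + params["num_jobs"] * num_machines  # wait action + job-machine pairs
    return num_ops, num_machines, num_actions


# Subset of the generator_params defined above
params = {"num_jobs": 10, "num_machines": 5, "max_ops_per_job": 6}
print(expected_sizes(params))  # (60, 5, 51)
```

For this instance that is 60 operation embeddings, 5 machine embeddings and 51 logits, which matches the shapes printed below.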
encoder = HetGNNEncoder(embed_dim=32, num_layers=2)
# the encoder returns the operation embeddings first and the machine embeddings second
(op_emb, ma_emb), init = encoder(td)
print(op_emb.shape)
print(ma_emb.shape)
torch.Size([1, 60, 32]) torch.Size([1, 5, 32])
The decoder returns logits over a composite action space of size 1 + num_jobs * num_machines, with one entry for every job-machine combination plus one waiting action. The selected action specifies which job is processed next and on which machine; more precisely, the next pending operation of the selected job is processed. This operation can be retrieved from td["next_op"]:
# next operation per job
td["next_op"]
tensor([[ 0, 4, 10, 15, 21, 27, 33, 39, 45, 49]])
decoder = L2DDecoder(env_name=env.name, embed_dim=32)
logits, mask = decoder(td, (op_emb, ma_emb), num_starts=0)
# (1 + num_jobs * num_machines)
print(logits.shape)
torch.Size([1, 51])
def make_step(td):
    # score all actions, mask out infeasible ones and pick the best one greedily
    logits, mask = decoder(td, (op_emb, ma_emb), num_starts=0)
    action = logits.masked_fill(~mask, -torch.inf).argmax(1)
    td["action"] = action
    td = env.step(td)["next"]
    return td
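To make the composite action space concrete, here is a hypothetical decoding of a flat action index. The layout is an assumption on our part (index 0 = wait, index 1 + job * num_machines + machine = schedule the job's next operation on that machine) and may differ from FJSPEnv's internal ordering; the next_op values are copied from the td["next_op"] output above.

```python
def decode_action(action, num_machines):
    """Hypothetical mapping from a flat action index to (job, machine).

    Assumed layout: index 0 is the wait action, and index
    1 + job * num_machines + machine schedules the next operation of `job` on `machine`.
    """
    if action == 0:
        return None  # wait action
    return divmod(action - 1, num_machines)


num_machines = 5
next_op = [0, 4, 10, 15, 21, 27, 33, 39, 45, 49]  # td["next_op"] from the output above

job, machine = decode_action(50, num_machines)
print(job, machine, next_op[job])  # 9 4 49
```

Because the operation is looked up via next_op rather than chosen directly, the policy never has to consider operations whose predecessors are still unscheduled.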
Visualize Solution Construction
Starting at $t=0$, the decoder uses the operation and machine embeddings produced by the encoder to decide which job-machine combination to schedule next. Note that, due to the precedence constraints, the next operation of each job is fixed. It is therefore sufficient to select the next job (and the machine for its next operation), which significantly reduces the action space.
Once operations have been scheduled such that either all machines are busy or the currently active operation of every job has been scheduled, the environment transitions to a new time step $t$. The new $t$ equals the earliest time at which a machine finishes an operation in the partial schedule. When an operation finishes, the machine that processed it is immediately available for the next operation, and the next operation of the respective job can be scheduled.
The start time of an operation is always equal to the time step in which it is scheduled. The finish time of an operation is equal to its start time plus the processing time required by the machine on which it is being processed.
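The time-advance rule described above can be sketched in a few lines of plain Python. This is a simplified illustration, not the FJSPEnv implementation: it only tracks when each machine becomes free and ignores job precedence and machine eligibility. The helpers schedule and next_time_step are hypothetical.

```python
def schedule(t, proc_time):
    """Start an operation at the current time step t; return (start, finish)."""
    return t, t + proc_time


def next_time_step(t, machine_free_at):
    """Earliest finish time after t among busy machines (t itself if none is busy)."""
    future = [f for f in machine_free_at if f > t]
    return min(future, default=t)


# Hypothetical toy example: two machines, both idle at t = 0
machine_free_at = [0, 0]
start, finish = schedule(0, 5)  # operation A on machine 0 runs from 0 to 5
machine_free_at[0] = finish
start, finish = schedule(0, 3)  # operation B on machine 1 runs from 0 to 3
machine_free_at[1] = finish
print(next_time_step(0, machine_free_at))  # 3: machine 1 finishes first
```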
The figure below visualizes this process.
env.render(td, 0)
# Update the plot within a while loop until all instances are done
while not td["done"].all():
    # Clear the previous output for the next iteration
    clear_output(wait=True)
    td = make_step(td)
    env.render(td, 0)
    # Display updated plot
    display(plt.gcf())
    # Pause for a moment to see the changes
    time.sleep(.4)
if torch.cuda.is_available():
    accelerator = "gpu"
    batch_size = 256
    train_data_size = 2_000
    embed_dim = 128
    num_encoder_layers = 4
else:
    accelerator = "cpu"
    batch_size = 32
    train_data_size = 1_000
    embed_dim = 64
    num_encoder_layers = 2
# Policy: neural network, in this case with encoder-decoder architecture
policy = L2DPolicy(embed_dim=embed_dim, num_encoder_layers=num_encoder_layers, env_name="fjsp")
# Model: L2D trained with REINFORCE and a greedy rollout baseline
model = L2DModel(env,
policy=policy,
baseline="rollout",
batch_size=batch_size,
train_data_size=train_data_size,
val_data_size=1_000,
optimizer_kwargs={"lr": 1e-4})
trainer = RL4COTrainer(
max_epochs=3,
accelerator=accelerator,
devices=1,
logger=None,
)
trainer.fit(model)
/Users/luttmann/opt/miniconda3/envs/rl4co/lib/python3.9/site-packages/lightning/pytorch/utilities/parsing.py:198: Attribute 'env' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['env'])`.
/Users/luttmann/opt/miniconda3/envs/rl4co/lib/python3.9/site-packages/lightning/pytorch/utilities/parsing.py:198: Attribute 'policy' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['policy'])`.
/Users/luttmann/opt/miniconda3/envs/rl4co/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/accelerator_connector.py:551: You passed `Trainer(accelerator='cpu', precision='16-mixed')` but AMP with fp16 is not supported on CPU. Using `precision='bf16-mixed'` instead.
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Users/luttmann/opt/miniconda3/envs/rl4co/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:67: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
val_file not set. Generating dataset instead
test_file not set. Generating dataset instead
  | Name     | Type           | Params
--------------------------------------------
0 | env      | FJSPEnv        | 0
1 | policy   | L2DPolicy      | 81.2 K
2 | baseline | WarmupBaseline | 81.2 K
--------------------------------------------
162 K     Trainable params
0         Non-trainable params
162 K     Total params
0.649     Total estimated model params size (MB)
/Users/luttmann/opt/miniconda3/envs/rl4co/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/Users/luttmann/opt/miniconda3/envs/rl4co/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/Users/luttmann/opt/miniconda3/envs/rl4co/lib/python3.9/site-packages/lightning/pytorch/loops/fit_loop.py:293: The number of training batches (32) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
/Users/luttmann/opt/miniconda3/envs/rl4co/lib/python3.9/site-packages/lightning/pytorch/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...
Solving the Job-Shop Scheduling Problem (JSSP)
import gc
from rl4co.envs import JSSPEnv
from rl4co.models.zoo.l2d.model import L2DPPOModel
from rl4co.models.zoo.l2d.policy import L2DPolicy4PPO
from torch.utils.data import DataLoader
# Let's generate a larger JSSP instance
generator_params = {
"num_jobs": 15, # the total number of jobs
"num_machines": 15, # the total number of machines that can process operations
"min_processing_time": 1, # the minimum time required for a machine to process an operation
"max_processing_time": 99, # the maximum time required for a machine to process an operation
}
env = JSSPEnv(
generator_params=generator_params,
_torchrl_mode=True,
stepwise_reward=True
)
Train on Synthetic Data and Test on the Taillard Benchmark
# Policy: neural network, in this case with encoder-decoder architecture
policy = L2DPolicy4PPO(
embed_dim=embed_dim,
num_encoder_layers=num_encoder_layers,
env_name="jssp",
het_emb=False
)
model = L2DPPOModel(
env=env,
policy=policy,
batch_size=batch_size,
train_data_size=train_data_size,
val_data_size=1_000,
optimizer_kwargs={"lr": 1e-4}
)
CHECKPOINT_PATH = "last.ckpt"
device = "cuda" if torch.cuda.is_available() else "cpu"
try:
    model = L2DPPOModel.load_from_checkpoint(CHECKPOINT_PATH)
except FileNotFoundError:
    trainer = RL4COTrainer(
        max_epochs=1,
        accelerator=accelerator,
        devices=1,
        logger=None,
    )
    trainer.fit(model)
finally:
    model = model.to(device)
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Overriding gradient_clip_val to None for 'automatic_optimization=False' models
/Users/luttmann/opt/miniconda3/envs/rl4co/lib/python3.9/site-packages/lightning/pytorch/utilities/parsing.py:43: attribute 'policy' removed from hparams because it cannot be pickled
val_file not set. Generating dataset instead
test_file not set. Generating dataset instead
  | Name       | Type          | Params
---------------------------------------------
0 | env        | JSSPEnv       | 0
1 | policy     | L2DPolicy4PPO | 25.5 K
2 | policy_old | L2DPolicy4PPO | 25.5 K
---------------------------------------------
51.1 K    Trainable params
0         Non-trainable params
51.1 K    Total params
0.204     Total estimated model params size (MB)
%%bash
# Define the folder path
DATA_PATH="./taillard"
# Check if the folder exists
if [ -d "$DATA_PATH" ]; then
    echo "Folder already exists."
else
    echo "Folder does not exist. Creating folder and downloading taillard instances..."
    mkdir -p "$DATA_PATH"
    # plain bash commands: the IPython "!" prefix is not needed (in bash it negates the exit status)
    wget http://mistic.heig-vd.ch/taillard/problemes.dir/ordonnancement.dir/jobshop.dir/tai20_15.txt -O "$DATA_PATH/tai20_15.txt"
    wget http://mistic.heig-vd.ch/taillard/problemes.dir/ordonnancement.dir/jobshop.dir/tai15_15.txt -O "$DATA_PATH/tai15_15.txt"
    wget http://mistic.heig-vd.ch/taillard/problemes.dir/ordonnancement.dir/jobshop.dir/tai20_20.txt -O "$DATA_PATH/tai20_20.txt"
    wget http://mistic.heig-vd.ch/taillard/problemes.dir/ordonnancement.dir/jobshop.dir/tai30_15.txt -O "$DATA_PATH/tai30_15.txt"
    wget http://mistic.heig-vd.ch/taillard/problemes.dir/ordonnancement.dir/jobshop.dir/tai30_20.txt -O "$DATA_PATH/tai30_20.txt"
fi
exit 0
Folder does not exist. Creating folder and downloading taillard instances...
bash: line 11: wget: command not found
bash: line 12: wget: command not found
bash: line 13: wget: command not found
bash: line 14: wget: command not found
bash: line 15: wget: command not found
# path to taillard instances
FILE_PATH = "./taillard/tai{instance_type}.txt"
results = {}
instance_types = ["15_15", "20_15", "20_20", "30_15", "30_20"]
for instance_type in instance_types:
    dataset = env.dataset(batch_size=[10], phase="test", filename=FILE_PATH.format(instance_type=instance_type))
    dl = DataLoader(dataset, batch_size=5, collate_fn=dataset.collate_fn)
    rewards = []
    for batch in dl:
        td = env.reset(batch).to(device)
        # use policy.generate to avoid gradient computations, which can lead to OOM errors
        out = model.policy.generate(td, env=env, phase="test", decode_type="multistart_sampling", num_starts=100, select_best=True)
        rewards.append(out["reward"])
    reward = torch.cat(rewards, dim=0).mean().item()
    results[instance_type] = reward
    print("Done evaluating instance type %s with reward %s" % (instance_type, reward))
    # avoid OOMs caused by the cache not being cleared
    model.rb.empty()
    gc.collect()
    torch.cuda.empty_cache()
Provided file name ../../ai4co/rl4co/data/jssp/taillard/15j_15m not found. Make sure to provide a file in the right path first or unset test_file to generate data automatically instead
Done evaluating instance type 15j_15m with reward -1408.0999755859375
Provided file name ../../ai4co/rl4co/data/jssp/taillard/20j_15m not found. Make sure to provide a file in the right path first or unset test_file to generate data automatically instead
Provided file name ../../ai4co/rl4co/data/jssp/taillard/20j_20m not found. Make sure to provide a file in the right path first or unset test_file to generate data automatically instead
Done evaluating instance type 20j_15m with reward -1380.699951171875
Provided file name ../../ai4co/rl4co/data/jssp/taillard/30j_15m not found. Make sure to provide a file in the right path first or unset test_file to generate data automatically instead
Done evaluating instance type 20j_20m with reward -1349.9000244140625
Provided file name ../../ai4co/rl4co/data/jssp/taillard/30j_20m not found. Make sure to provide a file in the right path first or unset test_file to generate data automatically instead
Done evaluating instance type 30j_15m with reward -1374.0999755859375
Done evaluating instance type 30j_20m with reward -1371.699951171875
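With select_best=True, each reported number is, as we understand it, the mean over instances of the best of num_starts sampled solutions, where the reward is the negative makespan (hence the negative values above). A plain-Python sketch of that aggregation with made-up numbers; best_of_n_mean is a hypothetical helper and this is not RL4CO's code.

```python
def best_of_n_mean(rewards_per_instance):
    """Mean over instances of the best (highest) reward among the sampled starts.

    Assumption: reward = -makespan, so the best start has the highest reward.
    """
    best = [max(starts) for starts in rewards_per_instance]
    return sum(best) / len(best)


# Hypothetical rewards for 2 instances x 3 sampled starts
toy_rewards = [
    [-1420.0, -1395.0, -1410.0],
    [-1380.0, -1402.0, -1375.0],
]
print(best_of_n_mean(toy_rewards))  # -1385.0
```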