Environment Embeddings

In autoregressive policies, environment embeddings transfer data from feature space to hidden space:

  • Initial Embeddings: encode global problem features
  • Context Embeddings: modify current node embedding during decoding
  • Dynamic Embeddings: modify all node embeddings during decoding

Context Embeddings

The context embedding is used to modify the query embedding of the problem node of the current partial solution. It usually consists of a projection of gathered node embeddings and state features to the embedding space.

EnvContext

EnvContext(
    embed_dim, step_context_dim=None, linear_bias=False
)

Bases: Module

Base class for environment context embeddings. The context embedding is used to modify the query embedding of the problem node of the current partial solution. It consists of a linear layer that projects the node features to the embedding space.

Source code in rl4co/models/nn/env_embeddings/context.py
def __init__(self, embed_dim, step_context_dim=None, linear_bias=False):
    super(EnvContext, self).__init__()
    self.embed_dim = embed_dim
    step_context_dim = step_context_dim if step_context_dim is not None else embed_dim
    self.project_context = nn.Linear(step_context_dim, embed_dim, bias=linear_bias)
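
The constructor above only creates the projection layer. As a rough illustration of how such a layer is typically used at a decoding step, the following minimal sketch gathers the current node embedding, appends one scalar state feature, and projects the result back to embed_dim. The class name ScalarStateContext, the forward signature, and the tensor layout are assumptions made for this example and are not taken from the library.

```python
import torch
import torch.nn as nn


class ScalarStateContext(nn.Module):
    """Hypothetical context: current node embedding plus one scalar state feature."""

    def __init__(self, embed_dim, linear_bias=False):
        super().__init__()
        self.embed_dim = embed_dim
        # step_context_dim = embed_dim (current node) + 1 (scalar state)
        self.project_context = nn.Linear(embed_dim + 1, embed_dim, bias=linear_bias)

    def forward(self, embeddings, current_node, state):
        # embeddings: [batch, num_nodes, embed_dim]
        # current_node: [batch] (long), state: [batch, 1]
        index = current_node[:, None, None].expand(-1, 1, embeddings.size(-1))
        cur_node_embedding = embeddings.gather(1, index).squeeze(1)  # [batch, embed_dim]
        step_context = torch.cat([cur_node_embedding, state], dim=-1)
        return self.project_context(step_context)  # [batch, embed_dim]


context = ScalarStateContext(embed_dim=8)
embeddings = torch.randn(4, 10, 8)
current_node = torch.tensor([0, 3, 5, 9])
remaining = torch.rand(4, 1)  # e.g. a remaining-capacity-like scalar
out = context(embeddings, current_node, remaining)
```

The width of the concatenated step context is what step_context_dim has to match; when it is left as None, the base class assumes the step context is exactly embed_dim wide.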

FFSPContext

FFSPContext(embed_dim, stage_cnt=None)

Bases: EnvContext

Source code in rl4co/models/nn/env_embeddings/context.py
def __init__(self, embed_dim, stage_cnt=None):
    self.has_stage_emb = stage_cnt is not None
    step_context_dim = (1 + int(self.has_stage_emb)) * embed_dim
    super().__init__(embed_dim=embed_dim, step_context_dim=step_context_dim)
    if self.has_stage_emb:
        self.stage_emb = nn.Parameter(torch.rand(stage_cnt, embed_dim))

TSPContext

TSPContext(embed_dim)

Bases: EnvContext

Context embedding for the Traveling Salesman Problem (TSP). Project the following to the embedding space:

- first node embedding
- current node embedding
Source code in rl4co/models/nn/env_embeddings/context.py
def __init__(self, embed_dim):
    super(TSPContext, self).__init__(embed_dim, 2 * embed_dim)
    self.W_placeholder = nn.Parameter(
        torch.Tensor(2 * self.embed_dim).uniform_(-1, 1)
    )
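
At the first decoding step there is no first or current node yet, which is presumably why the constructor registers a learned W_placeholder of width 2 * embed_dim. A minimal sketch of that idea follows; the class name FirstCurrentContext and the convention of passing None at step 0 are assumptions for illustration only.

```python
import torch
import torch.nn as nn


class FirstCurrentContext(nn.Module):
    """Hypothetical TSP-style context: [first node, current node] -> embed_dim."""

    def __init__(self, embed_dim):
        super().__init__()
        self.project_context = nn.Linear(2 * embed_dim, embed_dim, bias=False)
        # Learned stand-in used before any node has been visited
        self.W_placeholder = nn.Parameter(torch.empty(2 * embed_dim).uniform_(-1, 1))

    def forward(self, embeddings, first_node=None, current_node=None):
        batch_size = embeddings.size(0)
        if first_node is None:
            # Step 0: broadcast the same placeholder to every instance in the batch
            step_context = self.W_placeholder[None, :].expand(batch_size, -1)
        else:
            batch_idx = torch.arange(batch_size)
            step_context = torch.cat(
                [embeddings[batch_idx, first_node], embeddings[batch_idx, current_node]],
                dim=-1,
            )
        return self.project_context(step_context)


context = FirstCurrentContext(embed_dim=8)
embeddings = torch.randn(3, 5, 8)
step0 = context(embeddings)  # placeholder path
step1 = context(embeddings, torch.tensor([0, 1, 2]), torch.tensor([4, 3, 2]))
```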

VRPContext

VRPContext(embed_dim)

Bases: EnvContext

Context embedding for the Capacitated Vehicle Routing Problem (CVRP). Project the following to the embedding space:

- current node embedding
- remaining capacity (vehicle_capacity - used_capacity)
Source code in rl4co/models/nn/env_embeddings/context.py
def __init__(self, embed_dim):
    super(VRPContext, self).__init__(
        embed_dim=embed_dim, step_context_dim=embed_dim + 1
    )

VRPTWContext

VRPTWContext(embed_dim)

Bases: VRPContext

Context embedding for the Capacitated Vehicle Routing Problem with Time Windows (CVRPTW). Project the following to the embedding space:

- current node embedding
- remaining capacity (vehicle_capacity - used_capacity)
- current time
Source code in rl4co/models/nn/env_embeddings/context.py
def __init__(self, embed_dim):
    super(VRPContext, self).__init__(
        embed_dim=embed_dim, step_context_dim=embed_dim + 2
    )
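
Note that the constructor calls super(VRPContext, self) rather than super(VRPTWContext, self). In Python this starts the method lookup after VRPContext in the method resolution order, so the grandparent constructor runs directly and the parent's fixed step_context_dim is bypassed. The plain-Python sketch below uses hypothetical classes Base, Parent, and Child to show the effect.

```python
class Base:
    def __init__(self, embed_dim, step_context_dim=None):
        self.embed_dim = embed_dim
        self.step_context_dim = (
            embed_dim if step_context_dim is None else step_context_dim
        )


class Parent(Base):
    def __init__(self, embed_dim):
        # Parent fixes the input width to embed_dim + 1
        super(Parent, self).__init__(embed_dim, step_context_dim=embed_dim + 1)


class Child(Parent):
    def __init__(self, embed_dim):
        # super(Parent, self) starts the lookup *after* Parent in the MRO,
        # so Base.__init__ runs directly and Parent.__init__ is skipped.
        super(Parent, self).__init__(embed_dim, step_context_dim=embed_dim + 2)


child = Child(embed_dim=128)
print(child.step_context_dim)     # 130
print(isinstance(child, Parent))  # True
```

The child still inherits every other method from the parent; only the parent's constructor is skipped.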

SVRPContext

SVRPContext(embed_dim)

Bases: EnvContext

Context embedding for the Skill Vehicle Routing Problem (SVRP). Project the following to the embedding space:

- current node embedding
- current technician
Source code in rl4co/models/nn/env_embeddings/context.py
def __init__(self, embed_dim):
    super(SVRPContext, self).__init__(embed_dim=embed_dim, step_context_dim=embed_dim)

PCTSPContext

PCTSPContext(embed_dim)

Bases: EnvContext

Context embedding for the Prize Collecting TSP (PCTSP). Project the following to the embedding space:

- current node embedding
- remaining prize (prize_required - cur_total_prize)
Source code in rl4co/models/nn/env_embeddings/context.py
def __init__(self, embed_dim):
    super(PCTSPContext, self).__init__(embed_dim, embed_dim + 1)

OPContext

OPContext(embed_dim)

Bases: EnvContext

Context embedding for the Orienteering Problem (OP). Project the following to the embedding space:

- current node embedding
- remaining distance (max_length - tour_length)
Source code in rl4co/models/nn/env_embeddings/context.py
def __init__(self, embed_dim):
    super(OPContext, self).__init__(embed_dim, embed_dim + 1)

DPPContext

DPPContext(embed_dim)

Bases: EnvContext

Context embedding for the Decap Placement Problem (DPP), an electronic design automation (EDA) task. Project the following to the embedding space:

- current cell embedding
Source code in rl4co/models/nn/env_embeddings/context.py
def __init__(self, embed_dim):
    super(DPPContext, self).__init__(embed_dim)

forward

forward(embeddings, td)

Context cannot be defined by a single node embedding for DPP, hence 0. We modify the dynamic embedding instead to capture placed items.

Source code in rl4co/models/nn/env_embeddings/context.py
def forward(self, embeddings, td):
    """Context cannot be defined by a single node embedding for DPP, hence 0.
    We modify the dynamic embedding instead to capture placed items
    """
    return embeddings.new_zeros(embeddings.size(0), self.embed_dim)

PDPContext

PDPContext(embed_dim)

Bases: EnvContext

Context embedding for the Pickup and Delivery Problem (PDP). Project the following to the embedding space:

- current node embedding
Source code in rl4co/models/nn/env_embeddings/context.py
def __init__(self, embed_dim):
    super(PDPContext, self).__init__(embed_dim, embed_dim)

MTSPContext

MTSPContext(embed_dim, linear_bias=False)

Bases: EnvContext

Context embedding for the Multiple Traveling Salesman Problem (mTSP). Project the following to the embedding space:

- current node embedding
- remaining_agents
- current_length
- max_subtour_length
- distance_from_depot
Source code in rl4co/models/nn/env_embeddings/context.py
def __init__(self, embed_dim, linear_bias=False):
    super(MTSPContext, self).__init__(embed_dim, 2 * embed_dim)
    proj_in_dim = (
        4  # remaining_agents, current_length, max_subtour_length, distance_from_depot
    )
    self.proj_dynamic_feats = nn.Linear(proj_in_dim, embed_dim, bias=linear_bias)
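
The 2 * embed_dim input width suggests that the four scalar features are first lifted to embed_dim by proj_dynamic_feats and then concatenated with the current node embedding. The sketch below illustrates that reading with stand-alone layers; the feature values and the way they are stacked are assumptions for this example.

```python
import torch
import torch.nn as nn

embed_dim, batch_size = 8, 4
proj_dynamic_feats = nn.Linear(4, embed_dim, bias=False)
project_context = nn.Linear(2 * embed_dim, embed_dim, bias=False)

cur_node_embedding = torch.randn(batch_size, embed_dim)
dynamic_feats = torch.stack(
    [
        torch.full((batch_size,), 3.0),  # remaining_agents
        torch.rand(batch_size),          # current_length
        torch.rand(batch_size),          # max_subtour_length
        torch.rand(batch_size),          # distance_from_depot
    ],
    dim=-1,
)  # [batch, 4]

state_embedding = proj_dynamic_feats(dynamic_feats)  # [batch, embed_dim]
context = project_context(
    torch.cat([cur_node_embedding, state_embedding], dim=-1)  # [batch, 2 * embed_dim]
)
```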

SMTWTPContext

SMTWTPContext(embed_dim)

Bases: EnvContext

Context embedding for the Single Machine Total Weighted Tardiness Problem (SMTWTP). Project the following to the embedding space:

- current node embedding
- current time
Source code in rl4co/models/nn/env_embeddings/context.py
def __init__(self, embed_dim):
    super(SMTWTPContext, self).__init__(embed_dim, embed_dim + 1)

MDCPDPContext

MDCPDPContext(embed_dim)

Bases: EnvContext

Context embedding for the MDCPDP. Project the following to the embedding space:

- current node embedding
Source code in rl4co/models/nn/env_embeddings/context.py
def __init__(self, embed_dim):
    super(MDCPDPContext, self).__init__(embed_dim, embed_dim)

MTVRPContext

MTVRPContext(embed_dim)

Bases: VRPContext

Context embedding for Multi-Task VRPEnv. Project the following to the embedding space:

- current node embedding
- remaining_linehaul_capacity (vehicle_capacity - used_capacity_linehaul)
- remaining_backhaul_capacity (vehicle_capacity - used_capacity_backhaul)
- current time
- current_route_length
- open route indicator
Source code in rl4co/models/nn/env_embeddings/context.py
def __init__(self, embed_dim):
    super(VRPContext, self).__init__(
        embed_dim=embed_dim, step_context_dim=embed_dim + 5
    )

env_context_embedding

env_context_embedding(
    env_name: str, config: dict
) -> Module

Get environment context embedding. The context embedding is used to modify the query embedding of the problem node of the current partial solution. It usually consists of a projection of gathered node embeddings and features to the embedding space.

Parameters:

  • env_name (str) –

    Name of the environment.

  • config (dict) –

    A dictionary of configuration options for the environment.

Source code in rl4co/models/nn/env_embeddings/context.py
def env_context_embedding(env_name: str, config: dict) -> nn.Module:
    """Get environment context embedding. The context embedding is used to modify the
    query embedding of the problem node of the current partial solution.
    Usually consists of a projection of gathered node embeddings and features to the embedding space.

    Args:
        env: Environment or its name.
        config: A dictionary of configuration options for the environment.
    """
    embedding_registry = {
        "tsp": TSPContext,
        "atsp": TSPContext,
        "cvrp": VRPContext,
        "cvrptw": VRPTWContext,
        "ffsp": FFSPContext,
        "svrp": SVRPContext,
        "sdvrp": VRPContext,
        "pctsp": PCTSPContext,
        "spctsp": PCTSPContext,
        "op": OPContext,
        "dpp": DPPContext,
        "mdpp": DPPContext,
        "pdp": PDPContext,
        "mtsp": MTSPContext,
        "smtwtp": SMTWTPContext,
        "mdcpdp": MDCPDPContext,
        "mtvrp": MTVRPContext,
    }

    if env_name not in embedding_registry:
        raise ValueError(
            f"Unknown environment name '{env_name}'. Available context embeddings: {embedding_registry.keys()}"
        )

    return embedding_registry[env_name](**config)
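
Because the selected class is called with **config, every key in config must be a valid constructor argument of that class. The plain-Python sketch below mirrors the lookup with hypothetical stub classes (TSPContextStub, MTSPContextStub) and a hypothetical helper make_context_embedding, so it runs without the library installed.

```python
class TSPContextStub:
    def __init__(self, embed_dim):
        self.embed_dim = embed_dim


class MTSPContextStub:
    def __init__(self, embed_dim, linear_bias=False):
        self.embed_dim = embed_dim
        self.linear_bias = linear_bias


def make_context_embedding(env_name, config):
    registry = {"tsp": TSPContextStub, "mtsp": MTSPContextStub}
    if env_name not in registry:
        raise ValueError(
            f"Unknown environment name '{env_name}'. Available: {sorted(registry)}"
        )
    # **config unpacks the dict into keyword arguments of the constructor
    return registry[env_name](**config)


tsp_context = make_context_embedding("tsp", {"embed_dim": 128})
mtsp_context = make_context_embedding("mtsp", {"embed_dim": 128, "linear_bias": True})

try:
    make_context_embedding("tsp", {"embed_dim": 128, "linear_bias": True})
except TypeError as err:
    unexpected_key_error = err  # TSPContextStub has no linear_bias argument
```

In the registry shown above, only MTSPContext accepts linear_bias and FFSPContext additionally accepts stage_cnt, so a config shared across environments may need filtering before it is passed in.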

Dynamic Embeddings

The dynamic embedding is used to modify the query, key and value vectors of the attention mechanism based on the current state of the environment (which changes during the rollout). It generally consists of a linear layer that projects the node features to the embedding space.

StaticEmbedding

StaticEmbedding(*args, **kwargs)

Bases: Module

Static embedding for general problems. This is used for problems that do not have any dynamic information, except for the information regarding the current action (e.g. the current node in TSP). See context embedding for more details.

Source code in rl4co/models/nn/env_embeddings/dynamic.py
def __init__(self, *args, **kwargs):
    super(StaticEmbedding, self).__init__()

SDVRPDynamicEmbedding

SDVRPDynamicEmbedding(embed_dim, linear_bias=False)

Bases: Module

Dynamic embedding for the Split Delivery Vehicle Routing Problem (SDVRP). Embed the following node features to the embedding space:

- demand_with_depot: demand of the customers and the depot

The demand with depot is used to modify the query, key and value vectors of the attention mechanism based on the current state of the environment (which is changing during the rollout).

Source code in rl4co/models/nn/env_embeddings/dynamic.py
def __init__(self, embed_dim, linear_bias=False):
    super(SDVRPDynamicEmbedding, self).__init__()
    self.projection = nn.Linear(1, 3 * embed_dim, bias=linear_bias)
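
The output width of 3 * embed_dim indicates one embed_dim-wide update per attention vector. A minimal sketch of how a per-node scalar demand could be projected and split is shown below; the tensor layout, the depot at index 0, and the names of the three chunks are assumptions for illustration.

```python
import torch
import torch.nn as nn

embed_dim = 8
projection = nn.Linear(1, 3 * embed_dim, bias=False)

# demand_with_depot: [batch, num_nodes], with the depot assumed at index 0
demand_with_depot = torch.rand(2, 6)
demand_with_depot[:, 0] = 0.0  # the depot has no demand

features = demand_with_depot[..., None]  # [batch, num_nodes, 1]
# Split the projection into three embed_dim-wide updates (names are hypothetical)
update_a, update_b, update_c = projection(features).chunk(3, dim=-1)
```

With linear_bias=False, a node whose demand is zero receives an all-zero update, so fully served customers no longer shift the attention vectors.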

env_dynamic_embedding

env_dynamic_embedding(
    env_name: str, config: dict
) -> Module

Get environment dynamic embedding. The dynamic embedding is used to modify the query, key and value vectors of the attention mechanism based on the current state of the environment (which changes during the rollout). It consists of a linear layer that projects the node features to the embedding space.

Parameters:

  • env_name (str) –

    Name of the environment.

  • config (dict) –

    A dictionary of configuration options for the environment.

Source code in rl4co/models/nn/env_embeddings/dynamic.py
def env_dynamic_embedding(env_name: str, config: dict) -> nn.Module:
    """Get environment dynamic embedding. The dynamic embedding is used to modify query, key and value vectors of the attention mechanism
    based on the current state of the environment (which is changing during the rollout).
    Consists of a linear layer that projects the node features to the embedding space.

    Args:
        env: Environment or its name.
        config: A dictionary of configuration options for the environment.
    """
    embedding_registry = {
        "tsp": StaticEmbedding,
        "atsp": StaticEmbedding,
        "cvrp": StaticEmbedding,
        "cvrptw": StaticEmbedding,
        "ffsp": StaticEmbedding,
        "svrp": StaticEmbedding,
        "sdvrp": SDVRPDynamicEmbedding,
        "pctsp": StaticEmbedding,
        "spctsp": StaticEmbedding,
        "op": StaticEmbedding,
        "dpp": StaticEmbedding,
        "mdpp": StaticEmbedding,
        "pdp": StaticEmbedding,
        "mtsp": StaticEmbedding,
        "smtwtp": StaticEmbedding,
        "jssp": JSSPDynamicEmbedding,
        "fjsp": JSSPDynamicEmbedding,
        "mtvrp": StaticEmbedding,
    }

    if env_name not in embedding_registry:
        log.warning(
            f"Unknown environment name '{env_name}'. Available dynamic embeddings: {embedding_registry.keys()}. Defaulting to StaticEmbedding."
        )
    return embedding_registry.get(env_name, StaticEmbedding)(**config)
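
Unlike the context lookup, this function does not raise on an unknown name: it logs a warning and falls back to StaticEmbedding. The sketch below reproduces that behaviour with hypothetical stub classes to show the practical consequence, namely that a misspelled environment name still returns a working (but static) embedding.

```python
import logging

log = logging.getLogger(__name__)


class StaticStub:
    """Stand-in for an embedding without dynamic information; accepts any config."""

    def __init__(self, *args, **kwargs):
        pass


class DemandStub:
    """Stand-in for an embedding that needs real constructor arguments."""

    def __init__(self, embed_dim, linear_bias=False):
        self.embed_dim = embed_dim


def make_dynamic_embedding(env_name, config):
    registry = {"tsp": StaticStub, "sdvrp": DemandStub}
    if env_name not in registry:
        log.warning("Unknown environment name '%s'. Defaulting to StaticStub.", env_name)
    # dict.get returns the fallback class instead of raising KeyError
    return registry.get(env_name, StaticStub)(**config)


known = make_dynamic_embedding("sdvrp", {"embed_dim": 128})
fallback = make_dynamic_embedding("sdvpr", {"embed_dim": 128})  # typo in the name
```

Checking the log output for the warning is therefore the only signal that the intended dynamic embedding was not selected.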

Init Embeddings

The init embedding is used to initialize the general embedding of the problem nodes without any solution information. It generally consists of a linear layer that projects the node features to the embedding space.

TSPInitEmbedding

TSPInitEmbedding(embed_dim, linear_bias=True)

Bases: Module

Initial embedding for the Traveling Salesman Problem (TSP). Embed the following node features to the embedding space:

- locs: x, y coordinates of the cities
Source code in rl4co/models/nn/env_embeddings/init.py
def __init__(self, embed_dim, linear_bias=True):
    super(TSPInitEmbedding, self).__init__()
    node_dim = 2  # x, y
    self.init_embed = nn.Linear(node_dim, embed_dim, linear_bias)

MatNetInitEmbedding

MatNetInitEmbedding(
    embed_dim: int, mode: str = "RandomOneHot"
)

Bases: Module

Preparing the initial row and column embeddings for MatNet.

Reference: https://github.com/yd-kwon/MatNet/blob/782698b60979effe2e7b61283cca155b7cdb727f/ATSP/ATSP_MatNet/ATSPModel.py#L51

Source code in rl4co/models/nn/env_embeddings/init.py
def __init__(self, embed_dim: int, mode: str = "RandomOneHot") -> None:
    super().__init__()

    self.embed_dim = embed_dim
    assert mode in {
        "RandomOneHot",
        "Random",
    }, "mode must be one of ['RandomOneHot', 'Random']"
    self.mode = mode

VRPInitEmbedding

VRPInitEmbedding(
    embed_dim, linear_bias=True, node_dim: int = 3
)

Bases: Module

Initial embedding for the Vehicle Routing Problems (VRP). Embed the following node features to the embedding space:

- locs: x, y coordinates of the nodes (depot and customers separately)
- demand: demand of the customers
Source code in rl4co/models/nn/env_embeddings/init.py
def __init__(self, embed_dim, linear_bias=True, node_dim: int = 3):
    super(VRPInitEmbedding, self).__init__()
    node_dim = node_dim  # 3: x, y, demand
    self.init_embed = nn.Linear(node_dim, embed_dim, linear_bias)
    self.init_embed_depot = nn.Linear(2, embed_dim, linear_bias)  # depot embedding
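
The two linear layers reflect that the depot has only coordinates while customers also carry a demand. A minimal sketch of combining them into one sequence of node embeddings follows; it assumes the depot is stored first in locs and that demand covers customers only, which is an assumption of this example.

```python
import torch
import torch.nn as nn

embed_dim = 16
init_embed = nn.Linear(3, embed_dim)        # customers: x, y, demand
init_embed_depot = nn.Linear(2, embed_dim)  # depot: x, y

locs = torch.rand(2, 11, 2)  # [batch, 1 depot + 10 customers, 2]
demand = torch.rand(2, 10)   # [batch, 10 customers]

depot, customers = locs[:, :1, :], locs[:, 1:, :]
depot_embedding = init_embed_depot(depot)  # [2, 1, 16]
customer_embedding = init_embed(
    torch.cat([customers, demand[..., None]], dim=-1)  # [2, 10, 3]
)  # [2, 10, 16]

# Depot first, then customers, so node indices match the input ordering
node_embeddings = torch.cat([depot_embedding, customer_embedding], dim=-2)
```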

PCTSPInitEmbedding

PCTSPInitEmbedding(embed_dim, linear_bias=True)

Bases: Module

Initial embedding for the Prize Collecting Traveling Salesman Problem (PCTSP). Embed the following node features to the embedding space:

- locs: x, y coordinates of the nodes (depot and customers separately)
- expected_prize: expected prize for visiting the customers.
    In PCTSP, this is the actual prize. In SPCTSP, this is the expected prize.
- penalty: penalty for not visiting the customers
Source code in rl4co/models/nn/env_embeddings/init.py
def __init__(self, embed_dim, linear_bias=True):
    super(PCTSPInitEmbedding, self).__init__()
    node_dim = 4  # x, y, prize, penalty
    self.init_embed = nn.Linear(node_dim, embed_dim, linear_bias)
    self.init_embed_depot = nn.Linear(2, embed_dim, linear_bias)

OPInitEmbedding

OPInitEmbedding(embed_dim, linear_bias=True)

Bases: Module

Initial embedding for the Orienteering Problem (OP). Embed the following node features to the embedding space:

- locs: x, y coordinates of the nodes (depot and customers separately)
- prize: prize for visiting the customers
Source code in rl4co/models/nn/env_embeddings/init.py
def __init__(self, embed_dim, linear_bias=True):
    super(OPInitEmbedding, self).__init__()
    node_dim = 3  # x, y, prize
    self.init_embed = nn.Linear(node_dim, embed_dim, linear_bias)
    self.init_embed_depot = nn.Linear(2, embed_dim, linear_bias)  # depot embedding

DPPInitEmbedding

DPPInitEmbedding(embed_dim, linear_bias=True)

Bases: Module

Initial embedding for the Decap Placement Problem (DPP), an electronic design automation (EDA) task. Embed the following node features to the embedding space:

- locs: x, y coordinates of the nodes (cells)
- probe: index of the (single) probe cell. We embed the euclidean distance from the probe to all cells.
Source code in rl4co/models/nn/env_embeddings/init.py
def __init__(self, embed_dim, linear_bias=True):
    super(DPPInitEmbedding, self).__init__()
    node_dim = 2  # x, y
    self.init_embed = nn.Linear(node_dim, embed_dim // 2, linear_bias)  # locs
    self.init_embed_probe = nn.Linear(1, embed_dim // 2, linear_bias)  # probe
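
Each layer outputs embed_dim // 2 features, so concatenating the location embedding and the probe-distance embedding yields embed_dim in total (for an even embed_dim). The sketch below illustrates this with stand-alone layers; the probe tensor holding one cell index per instance is an assumed layout.

```python
import torch
import torch.nn as nn

embed_dim = 16
init_embed = nn.Linear(2, embed_dim // 2)        # locs
init_embed_probe = nn.Linear(1, embed_dim // 2)  # distance to probe

locs = torch.rand(2, 9, 2)    # [batch, num_cells, 2]
probe = torch.tensor([4, 0])  # [batch], hypothetical probe cell index per instance

probe_loc = locs[torch.arange(2), probe][:, None, :]  # [batch, 1, 2]
# Euclidean distance from the probe to every cell
probe_dist = torch.linalg.vector_norm(locs - probe_loc, dim=-1, keepdim=True)

node_embeddings = torch.cat(
    [init_embed(locs), init_embed_probe(probe_dist)], dim=-1
)  # [batch, num_cells, embed_dim]
```

An odd embed_dim would produce embeddings one feature narrower than requested, since both halves are rounded down.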

MDPPInitEmbedding

MDPPInitEmbedding(embed_dim, linear_bias=True)

Bases: Module

Initial embedding for the Multi-port Placement Problem (MDPP), EDA (electronic design automation). Embed the following node features to the embedding space:

- locs: x, y coordinates of the nodes (cells)
- probe: indexes of the probe cells (multiple). We embed the euclidean distance of each cell to the closest probe.
Source code in rl4co/models/nn/env_embeddings/init.py
def __init__(self, embed_dim, linear_bias=True):
    super(MDPPInitEmbedding, self).__init__()
    node_dim = 2  # x, y
    self.init_embed = nn.Linear(node_dim, embed_dim, linear_bias)  # locs
    self.init_embed_probe_distance = nn.Linear(
        1, embed_dim, linear_bias
    )  # probe_distance
    self.project_out = nn.Linear(embed_dim * 2, embed_dim, linear_bias)

PDPInitEmbedding

PDPInitEmbedding(embed_dim, linear_bias=True)

Bases: Module

Initial embedding for the Pickup and Delivery Problem (PDP). Embed the following node features to the embedding space:

- locs: x, y coordinates of the nodes (depot, pickups and deliveries separately)
   Note that pickups and deliveries are interleaved in the input.
Source code in rl4co/models/nn/env_embeddings/init.py
def __init__(self, embed_dim, linear_bias=True):
    super(PDPInitEmbedding, self).__init__()
    node_dim = 2  # x, y
    self.init_embed_depot = nn.Linear(2, embed_dim, linear_bias)
    self.init_embed_pick = nn.Linear(node_dim * 2, embed_dim, linear_bias)
    self.init_embed_delivery = nn.Linear(node_dim, embed_dim, linear_bias)
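
The pickup layer takes node_dim * 2 inputs, which suggests that each pickup is embedded together with the coordinates of its paired delivery. The sketch below illustrates that idea only; it assumes pickups and deliveries are already separated and that pickup i pairs with delivery i, which may differ from the ordering the environment actually uses.

```python
import torch
import torch.nn as nn

embed_dim = 16
init_embed_depot = nn.Linear(2, embed_dim)
init_embed_pick = nn.Linear(4, embed_dim)      # pickup x, y + paired delivery x, y
init_embed_delivery = nn.Linear(2, embed_dim)

depot = torch.rand(2, 1, 2)       # [batch, 1, 2]
pickups = torch.rand(2, 5, 2)     # [batch, num_pairs, 2]
deliveries = torch.rand(2, 5, 2)  # [batch, num_pairs, 2], delivery i pairs with pickup i

# A pickup "sees" where its delivery is; a delivery only sees itself
pick_feats = torch.cat([pickups, deliveries], dim=-1)  # [batch, num_pairs, 4]

node_embeddings = torch.cat(
    [
        init_embed_depot(depot),
        init_embed_pick(pick_feats),
        init_embed_delivery(deliveries),
    ],
    dim=-2,
)  # [batch, 1 + 2 * num_pairs, embed_dim]
```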

MTSPInitEmbedding

MTSPInitEmbedding(embed_dim, linear_bias=True)

Bases: Module

Initial embedding for the Multiple Traveling Salesman Problem (mTSP). Embed the following node features to the embedding space:

- locs: x, y coordinates of the nodes (depot, cities)
Source code in rl4co/models/nn/env_embeddings/init.py
def __init__(self, embed_dim, linear_bias=True):
    """NOTE: new made by Fede. May need to be checked"""
    super(MTSPInitEmbedding, self).__init__()
    node_dim = 2  # x, y
    self.init_embed = nn.Linear(node_dim, embed_dim, linear_bias)
    self.init_embed_depot = nn.Linear(2, embed_dim, linear_bias)  # depot embedding

SMTWTPInitEmbedding

SMTWTPInitEmbedding(embed_dim, linear_bias=True)

Bases: Module

Initial embedding for the Single Machine Total Weighted Tardiness Problem (SMTWTP). Embed the following node features to the embedding space:

- job_due_time: due time of the jobs
- job_weight: weights of the jobs
- job_process_time: the processing time of jobs
Source code in rl4co/models/nn/env_embeddings/init.py
def __init__(self, embed_dim, linear_bias=True):
    super(SMTWTPInitEmbedding, self).__init__()
    node_dim = 3  # job_due_time, job_weight, job_process_time
    self.init_embed = nn.Linear(node_dim, embed_dim, linear_bias)

MDCPDPInitEmbedding

MDCPDPInitEmbedding(embed_dim, linear_bias=True)

Bases: Module

Initial embedding for the MDCPDP environment. Embed the following node features to the embedding space:

- locs: x, y coordinates of the nodes (depot, pickups and deliveries separately)
   Note that pickups and deliveries are interleaved in the input.
Source code in rl4co/models/nn/env_embeddings/init.py
def __init__(self, embed_dim, linear_bias=True):
    super(MDCPDPInitEmbedding, self).__init__()
    node_dim = 2  # x, y
    self.init_embed_depot = nn.Linear(2, embed_dim, linear_bias)
    self.init_embed_pick = nn.Linear(node_dim * 2, embed_dim, linear_bias)
    self.init_embed_delivery = nn.Linear(node_dim, embed_dim, linear_bias)

env_init_embedding

env_init_embedding(env_name: str, config: dict) -> Module

Get environment initial embedding. The init embedding is used to initialize the general embedding of the problem nodes without any solution information. It consists of a linear layer that projects the node features to the embedding space.

Parameters:

  • env_name (str) –

    Name of the environment.

  • config (dict) –

    A dictionary of configuration options for the environment.

Source code in rl4co/models/nn/env_embeddings/init.py
def env_init_embedding(env_name: str, config: dict) -> nn.Module:
    """Get environment initial embedding. The init embedding is used to initialize the
    general embedding of the problem nodes without any solution information.
    Consists of a linear layer that projects the node features to the embedding space.

    Args:
        env: Environment or its name.
        config: A dictionary of configuration options for the environment.
    """
    embedding_registry = {
        "tsp": TSPInitEmbedding,
        "atsp": TSPInitEmbedding,
        "matnet": MatNetInitEmbedding,
        "cvrp": VRPInitEmbedding,
        "cvrptw": VRPTWInitEmbedding,
        "svrp": SVRPInitEmbedding,
        "sdvrp": VRPInitEmbedding,
        "pctsp": PCTSPInitEmbedding,
        "spctsp": PCTSPInitEmbedding,
        "op": OPInitEmbedding,
        "dpp": DPPInitEmbedding,
        "mdpp": MDPPInitEmbedding,
        "pdp": PDPInitEmbedding,
        "pdp_ruin_repair": TSPInitEmbedding,
        "tsp_kopt": TSPInitEmbedding,
        "mtsp": MTSPInitEmbedding,
        "smtwtp": SMTWTPInitEmbedding,
        "mdcpdp": MDCPDPInitEmbedding,
        "fjsp": FJSPInitEmbedding,
        "jssp": FJSPInitEmbedding,
        "mtvrp": MTVRPInitEmbedding,
    }

    if env_name not in embedding_registry:
        raise ValueError(
            f"Unknown environment name '{env_name}'. Available init embeddings: {embedding_registry.keys()}"
        )

    return embedding_registry[env_name](**config)