
Constructive Autoregressive Methods

Attention Model (AM)

Classes:

AttentionModel

AttentionModel(
    env: RL4COEnvBase,
    policy: AttentionModelPolicy = None,
    baseline: REINFORCEBaseline | str = "rollout",
    policy_kwargs={},
    baseline_kwargs={},
    **kwargs
)

Bases: REINFORCE

Attention Model based on REINFORCE: https://arxiv.org/abs/1803.08475. Check :class:REINFORCE and :class:rl4co.models.RL4COLitModule for more details such as additional parameters including batch size.

Parameters:

  • env (RL4COEnvBase) –

    Environment to use for the algorithm

  • policy (AttentionModelPolicy, default: None ) –

    Policy to use for the algorithm

  • baseline (REINFORCEBaseline | str, default: 'rollout' ) –

    REINFORCE baseline. Defaults to "rollout": an exponential moving-average baseline for the first epoch, then a greedy rollout baseline

  • policy_kwargs

    Keyword arguments for policy

  • baseline_kwargs

    Keyword arguments for baseline

  • **kwargs

    Keyword arguments passed to the superclass
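The default "rollout" baseline can be pictured as a schedule that blends a cheap exponential moving-average baseline into the greedy rollout baseline during warm-up. The sketch below is a minimal pure-Python illustration of that schedule, not rl4co's implementation: the helper names `warmup_baseline_value` and `update_exponential` and the decay `beta=0.8` are assumptions made for illustration.

```python
def update_exponential(prev, batch_mean_reward, beta=0.8):
    """Exponential moving average of the batch mean reward (hypothetical helper).

    beta=0.8 is an assumed decay value, chosen only for illustration.
    """
    if prev is None:
        return batch_mean_reward  # first batch initializes the average
    return beta * prev + (1.0 - beta) * batch_mean_reward


def warmup_baseline_value(epoch, exp_value, rollout_value, n_warmup_epochs=1):
    """Blend the exponential baseline into the rollout baseline (hypothetical helper).

    alpha grows linearly from 0 to 1 over the warm-up epochs, then stays at 1,
    so after warm-up only the greedy rollout value is used.
    """
    alpha = min(1.0, epoch / n_warmup_epochs)
    return alpha * rollout_value + (1.0 - alpha) * exp_value
```

With one warm-up epoch, epoch 0 relies entirely on the exponential baseline and every later epoch on the rollout value.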

Source code in rl4co/models/zoo/am/model.py
def __init__(
    self,
    env: RL4COEnvBase,
    policy: AttentionModelPolicy = None,
    baseline: REINFORCEBaseline | str = "rollout",
    policy_kwargs={},
    baseline_kwargs={},
    **kwargs,
):
    if policy is None:
        policy = AttentionModelPolicy(env_name=env.name, **policy_kwargs)

    super().__init__(env, policy, baseline, baseline_kwargs, **kwargs)
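The constructor only builds a default AttentionModelPolicy and hands everything to REINFORCE. The policy-gradient loss that REINFORCE minimizes with a baseline can be sketched as below. This is a pure-Python illustration assuming per-instance rewards (for routing, typically negative tour lengths), baseline values, and log-likelihoods of the sampled solutions; the function name `reinforce_loss` is hypothetical.

```python
def reinforce_loss(rewards, baselines, log_likelihoods):
    """REINFORCE loss with a baseline, averaged over a batch (hypothetical helper).

    loss = -mean((reward - baseline) * log_likelihood)
    The baseline only shifts the reward; in a real autograd setting it is
    treated as a constant, so no gradient flows through it.
    """
    assert len(rewards) == len(baselines) == len(log_likelihoods)
    total = 0.0
    for r, b, ll in zip(rewards, baselines, log_likelihoods):
        advantage = r - b  # positive when the sample beat the baseline
        total += -advantage * ll
    return total / len(rewards)
```

Minimizing this loss raises the log-likelihood of solutions that beat the baseline and lowers it for those that did worse.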

Classes:

AttentionModelPolicy

AttentionModelPolicy(
    encoder: Module = None,
    decoder: Module = None,
    embed_dim: int = 128,
    num_encoder_layers: int = 3,
    num_heads: int = 8,
    normalization: str = "batch",
    feedforward_hidden: int = 512,
    env_name: str = "tsp",
    encoder_network: Module = None,
    init_embedding: Module = None,
    context_embedding: Module = None,
    dynamic_embedding: Module = None,
    use_graph_context: bool = True,
    linear_bias_decoder: bool = False,
    sdpa_fn: Callable = None,
    sdpa_fn_encoder: Callable = None,
    sdpa_fn_decoder: Callable = None,
    mask_inner: bool = True,
    out_bias_pointer_attn: bool = False,
    check_nan: bool = True,
    temperature: float = 1.0,
    tanh_clipping: float = 10.0,
    mask_logits: bool = True,
    train_decode_type: str = "sampling",
    val_decode_type: str = "greedy",
    test_decode_type: str = "greedy",
    moe_kwargs: dict = {"encoder": None, "decoder": None},
    **unused_kwargs
)

Bases: AutoregressivePolicy

Attention Model Policy based on Kool et al. (2019): https://arxiv.org/abs/1803.08475. The model first encodes the input graph with a Graph Attention Network (GAT) (:class:AttentionModelEncoder) and then decodes the solution with a pointer network (:class:AttentionModelDecoder). The node embeddings are cached after encoding, so the decoder reuses them at every step instead of recomputing them. See :class:rl4co.models.common.constructive.autoregressive.policy.AutoregressivePolicy for more details on the inference process.
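The encode-once, decode-step-by-step flow described above can be sketched in plain Python. Everything here is hypothetical: `encode` stands in for the encoder (its output plays the role of the cache), and `decoder_step` scores nodes by negative distance instead of attention, so the loop reduces to a nearest-neighbour TSP heuristic. Only the control flow (cache reuse, masking of visited nodes, greedy selection) mirrors the description.

```python
import math


def encode(coords):
    # Hypothetical stand-in for the encoder: runs once per instance.
    return [tuple(c) for c in coords]


def decoder_step(cache, current, mask):
    # Hypothetical stand-in for the decoder: scores every node from the cached
    # embeddings (here simply negative distance) and masks infeasible nodes.
    cx, cy = cache[current]
    logits = []
    for i, (x, y) in enumerate(cache):
        score = -math.hypot(x - cx, y - cy)
        logits.append(score if mask[i] else float("-inf"))
    return logits


def greedy_rollout(coords, start=0):
    cache = encode(coords)  # computed once, reused at every decoding step
    n = len(coords)
    mask = [True] * n       # True = node may still be selected
    mask[start] = False
    tour = [start]
    while len(tour) < n:
        logits = decoder_step(cache, tour[-1], mask)
        nxt = max(range(n), key=lambda i: logits[i])  # greedy decoding
        tour.append(nxt)
        mask[nxt] = False   # visited nodes become infeasible
    return tour
```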

Parameters:

  • encoder (Module, default: None ) –

    Encoder module, defaults to :class:AttentionModelEncoder

  • decoder (Module, default: None ) –

    Decoder module, defaults to :class:AttentionModelDecoder

  • embed_dim (int, default: 128 ) –

    Dimension of the node embeddings

  • num_encoder_layers (int, default: 3 ) –

    Number of layers in the encoder

  • num_heads (int, default: 8 ) –

    Number of heads in the attention layers

  • normalization (str, default: 'batch' ) –

    Normalization type in the attention layers

  • feedforward_hidden (int, default: 512 ) –

    Dimension of the hidden layer in the feedforward network

  • env_name (str, default: 'tsp' ) –

    Name of the environment used to initialize embeddings

  • encoder_network (Module, default: None ) –

    Network to use for the encoder

  • init_embedding (Module, default: None ) –

    Module to use for the initialization of the embeddings

  • context_embedding (Module, default: None ) –

    Module to use for the context embedding

  • dynamic_embedding (Module, default: None ) –

    Module to use for the dynamic embedding

  • use_graph_context (bool, default: True ) –

    Whether to use the graph context

  • linear_bias_decoder (bool, default: False ) –

    Whether to use a bias in the linear layer of the decoder

  • sdpa_fn_encoder (Callable, default: None ) –

    Function to use for the scaled dot product attention in the encoder

  • sdpa_fn_decoder (Callable, default: None ) –

    Function to use for the scaled dot product attention in the decoder

  • sdpa_fn (Callable, default: None ) –

    (deprecated) Function to use for the scaled dot product attention

  • mask_inner (bool, default: True ) –

    Whether to apply the mask in the inner (glimpse) attention of the decoder

  • out_bias_pointer_attn (bool, default: False ) –

    Whether to use a bias in the pointer attention

  • check_nan (bool, default: True ) –

    Whether to check for NaN values during decoding

  • temperature (float, default: 1.0 ) –

    Temperature for the softmax

  • tanh_clipping (float, default: 10.0 ) –

    Tanh clipping value (see Bello et al., 2016)

  • mask_logits (bool, default: True ) –

    Whether to mask the logits during decoding

  • train_decode_type (str, default: 'sampling' ) –

    Type of decoding to use during training

  • val_decode_type (str, default: 'greedy' ) –

    Type of decoding to use during validation

  • test_decode_type (str, default: 'greedy' ) –

    Type of decoding to use during testing

  • moe_kwargs (dict, default: {'encoder': None, 'decoder': None} ) –

    Keyword arguments for MoE, e.g., {"encoder": {"hidden_act": "ReLU", "num_experts": 4, "k": 2, "noisy_gating": True}, "decoder": {"light_version": True, ...}}
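`tanh_clipping`, `mask_logits`, `temperature` and the `*_decode_type` arguments act on the decoder logits at every step. The sketch below shows one plausible way to combine them (tanh clipping as in Bello et al., 2016, then masking, then temperature scaling and a log-softmax). The exact order of operations and the helper names `process_logits` and `select_action` are assumptions, not rl4co's code; the mask here uses `True` = feasible.

```python
import math
import random


def process_logits(logits, mask, tanh_clipping=10.0, temperature=1.0, mask_logits=True):
    """Return log-probabilities after clipping, masking and temperature scaling.

    Hypothetical helper; assumes at least one action is feasible.
    """
    out = []
    for z, feasible in zip(logits, mask):
        if tanh_clipping > 0:
            z = tanh_clipping * math.tanh(z)  # bounds logits to (-C, C)
        if mask_logits and not feasible:
            z = float("-inf")                 # infeasible actions get zero probability
        out.append(z / temperature)           # temperature > 1 flattens the distribution
    # numerically stable log-softmax
    m = max(out)
    log_z = m + math.log(sum(math.exp(z - m) for z in out))
    return [z - log_z for z in out]


def select_action(log_probs, decode_type="greedy", rng=None):
    """Pick an action index greedily or by sampling (hypothetical helper)."""
    if decode_type == "greedy":
        return max(range(len(log_probs)), key=lambda i: log_probs[i])
    if decode_type == "sampling":
        rng = rng or random.Random()
        probs = [math.exp(lp) for lp in log_probs]
        return rng.choices(range(len(probs)), weights=probs, k=1)[0]
    raise ValueError(f"unknown decode_type: {decode_type}")
```

The defaults above (sampling for training, greedy for validation and testing) follow the usual pattern: sampling explores during training, while greedy decoding gives deterministic evaluation.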

Source code in rl4co/models/zoo/am/policy.py
def __init__(
    self,
    encoder: nn.Module = None,
    decoder: nn.Module = None,
    embed_dim: int = 128,
    num_encoder_layers: int = 3,
    num_heads: int = 8,
    normalization: str = "batch",
    feedforward_hidden: int = 512,
    env_name: str = "tsp",
    encoder_network: nn.Module = None,
    init_embedding: nn.Module = None,
    context_embedding: nn.Module = None,
    dynamic_embedding: nn.Module = None,
    use_graph_context: bool = True,
    linear_bias_decoder: bool = False,
    sdpa_fn: Callable = None,
    sdpa_fn_encoder: Callable = None,
    sdpa_fn_decoder: Callable = None,
    mask_inner: bool = True,
    out_bias_pointer_attn: bool = False,
    check_nan: bool = True,
    temperature: float = 1.0,
    tanh_clipping: float = 10.0,
    mask_logits: bool = True,
    train_decode_type: str = "sampling",
    val_decode_type: str = "greedy",
    test_decode_type: str = "greedy",
    moe_kwargs: dict = {"encoder": None, "decoder": None},
    **unused_kwargs,
):
    if encoder is None:
        encoder = AttentionModelEncoder(
            embed_dim=embed_dim,
            num_heads=num_heads,
            num_layers=num_encoder_layers,
            env_name=env_name,
            normalization=normalization,
            feedforward_hidden=feedforward_hidden,
            net=encoder_network,
            init_embedding=init_embedding,
            sdpa_fn=sdpa_fn if sdpa_fn_encoder is None else sdpa_fn_encoder,
            moe_kwargs=moe_kwargs["encoder"],
        )

    if decoder is None:
        decoder = AttentionModelDecoder(
            embed_dim=embed_dim,
            num_heads=num_heads,
            env_name=env_name,
            context_embedding=context_embedding,
            dynamic_embedding=dynamic_embedding,
            sdpa_fn=sdpa_fn if sdpa_fn_decoder is None else sdpa_fn_decoder,
            mask_inner=mask_inner,
            out_bias_pointer_attn=out_bias_pointer_attn,
            linear_bias=linear_bias_decoder,
            use_graph_context=use_graph_context,
            check_nan=check_nan,
            moe_kwargs=moe_kwargs["decoder"],
        )

    super(AttentionModelPolicy, self).__init__(
        encoder=encoder,
        decoder=decoder,
        env_name=env_name,
        temperature=temperature,
        tanh_clipping=tanh_clipping,
        mask_logits=mask_logits,
        train_decode_type=train_decode_type,
        val_decode_type=val_decode_type,
        test_decode_type=test_decode_type,
        **unused_kwargs,
    )
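The `sdpa_fn`, `sdpa_fn_encoder` and `sdpa_fn_decoder` arguments accept a scaled dot-product attention callable, and the source above uses the module-specific one when given, falling back to the deprecated `sdpa_fn` otherwise. A single-head, pure-Python sketch of what such a callable computes, plus that fallback rule, is shown below. The names `scaled_dot_product_attention` and `resolve_sdpa` are illustrative, and the boolean mask follows the PyTorch convention `True` = may attend, which is the opposite of the mask convention documented for HeterogenousMHA.forward.

```python
import math


def scaled_dot_product_attention(q, k, v, mask=None):
    """Single-head SDPA on nested lists (illustrative sketch).

    q: [n_query][d], k: [n_key][d], v: [n_key][d_v]
    mask: optional [n_query][n_key] booleans, True = attention allowed.
    Assumes every query row can attend to at least one key.
    """
    d = len(q[0])
    out = []
    for i, qi in enumerate(q):
        scores = []
        for j, kj in enumerate(k):
            s = sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d)
            if mask is not None and not mask[i][j]:
                s = float("-inf")
            scores.append(s)
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]  # stable softmax
        total = sum(weights)
        weights = [w / total for w in weights]
        # weighted sum of the value vectors
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) for c in range(len(v[0]))])
    return out


def resolve_sdpa(sdpa_fn, sdpa_fn_module):
    # Mirrors `sdpa_fn if sdpa_fn_encoder is None else sdpa_fn_encoder`
    return sdpa_fn if sdpa_fn_module is None else sdpa_fn_module
```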

Attention Model - PPO (AM-PPO)

Classes:

  • AMPPO

    Proximal Policy Optimization (PPO) model with an Attention Model policy.

AMPPO

AMPPO(
    env: RL4COEnvBase,
    policy: Module = None,
    critic: CriticNetwork = None,
    policy_kwargs: dict = {},
    critic_kwargs: dict = {},
    **kwargs
)

Bases: PPO

Proximal Policy Optimization (PPO) model with an Attention Model policy. By default, it uses :class:AttentionModelPolicy and a :class:CriticNetwork initialized from a copy of the policy encoder.

Parameters:

  • env (RL4COEnvBase) –

    Environment to use for the algorithm

  • policy (Module, default: None ) –

    Policy to use for the algorithm

  • critic (CriticNetwork, default: None ) –

    Critic to use for the algorithm

  • policy_kwargs (dict, default: {} ) –

    Keyword arguments for policy

  • critic_kwargs (dict, default: {} ) –

    Keyword arguments for critic
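PPO replaces the plain REINFORCE gradient with a clipped surrogate objective. A minimal pure-Python sketch of that objective follows; it covers the policy term only (critic and entropy terms are omitted), the helper name is hypothetical, and `clip_range=0.2` is an assumed value.

```python
import math


def ppo_clipped_loss(new_log_probs, old_log_probs, advantages, clip_range=0.2):
    """Negative PPO clipped surrogate objective, averaged over samples (sketch)."""
    total = 0.0
    for new_lp, old_lp, adv in zip(new_log_probs, old_log_probs, advantages):
        ratio = math.exp(new_lp - old_lp)  # pi_new(a|s) / pi_old(a|s)
        clipped = min(max(ratio, 1 - clip_range), 1 + clip_range)
        # take the pessimistic (smaller) objective, negate it to get a loss
        total += -min(ratio * adv, clipped * adv)
    return total / len(advantages)
```

Clipping removes the incentive to move the ratio beyond `1 ± clip_range` when the advantage is positive, while the unclipped term still penalizes large moves that hurt a negative-advantage sample.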

Source code in rl4co/models/zoo/amppo/model.py
def __init__(
    self,
    env: RL4COEnvBase,
    policy: nn.Module = None,
    critic: CriticNetwork = None,
    policy_kwargs: dict = {},
    critic_kwargs: dict = {},
    **kwargs,
):
    if policy is None:
        policy = AttentionModelPolicy(env_name=env.name, **policy_kwargs)

    if critic is None:
        log.info("Creating critic network for {}".format(env.name))
        # initialize the critic from a deep copy of the policy encoder (copied, not shared)
        encoder = getattr(policy, "encoder", None)
        if encoder is None:
            raise ValueError("Critic network requires an encoder")
        critic = CriticNetwork(
            copy.deepcopy(encoder).to(next(encoder.parameters()).device),
            **critic_kwargs,
        )

    super().__init__(env, policy, critic, **kwargs)
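The critic above is initialized from `copy.deepcopy(encoder)`: it starts from the same weights as the policy encoder but is trained separately, rather than sharing parameters. The toy sketch below shows that difference with a hypothetical `TinyEncoder` class standing in for an `nn.Module`.

```python
import copy


class TinyEncoder:
    """Hypothetical stand-in for a policy encoder holding mutable weights."""

    def __init__(self):
        self.weights = [0.5, -0.25]


policy_encoder = TinyEncoder()
critic_encoder = copy.deepcopy(policy_encoder)  # same initial values...
critic_encoder.weights[0] += 1.0                # ...but updates stay separate
shared_encoder = policy_encoder                 # sharing would alias the same object
```

Copying lets the critic start from a useful representation without its value-regression updates disturbing the policy encoder.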

Heterogeneous Attention Model (HAM)

Classes:

HeterogeneousAttentionModel

HeterogeneousAttentionModel(
    env: RL4COEnvBase,
    policy: HeterogeneousAttentionModelPolicy = None,
    baseline: REINFORCEBaseline | str = "rollout",
    policy_kwargs={},
    baseline_kwargs={},
    **kwargs
)

Bases: REINFORCE

Heterogeneous Attention Model for solving the Pickup and Delivery Problem (PDP), based on REINFORCE: https://arxiv.org/abs/2110.02634.

Parameters:

  • env (RL4COEnvBase) –

    Environment to use for the algorithm

  • policy (HeterogeneousAttentionModelPolicy, default: None ) –

    Policy to use for the algorithm

  • baseline (REINFORCEBaseline | str, default: 'rollout' ) –

    REINFORCE baseline. Defaults to "rollout": an exponential moving-average baseline for the first epoch, then a greedy rollout baseline

  • policy_kwargs

    Keyword arguments for policy

  • baseline_kwargs

    Keyword arguments for baseline

  • **kwargs

    Keyword arguments passed to the superclass

Source code in rl4co/models/zoo/ham/model.py
def __init__(
    self,
    env: RL4COEnvBase,
    policy: HeterogeneousAttentionModelPolicy = None,
    baseline: REINFORCEBaseline | str = "rollout",
    policy_kwargs={},
    baseline_kwargs={},
    **kwargs,
):
    assert (
        env.name == "pdp"
    ), "HeterogeneousAttentionModel only works for PDP (Pickup and Delivery Problem)"
    if policy is None:
        policy = HeterogeneousAttentionModelPolicy(env_name=env.name, **policy_kwargs)

    super().__init__(env, policy, baseline, baseline_kwargs, **kwargs)

Classes:

HeterogeneousAttentionModelPolicy

HeterogeneousAttentionModelPolicy(
    encoder: Module = None,
    env_name: str = "pdp",
    init_embedding: Module = None,
    embed_dim: int = 128,
    num_encoder_layers: int = 3,
    num_heads: int = 8,
    normalization: str = "batch",
    feedforward_hidden: int = 512,
    sdpa_fn: Optional[Callable] = None,
    **kwargs
)

Bases: AttentionModelPolicy

Heterogeneous Attention Model Policy based on https://ieeexplore.ieee.org/document/9352489. The most important arguments from the paper are re-declared here for convenience. See :class:rl4co.models.zoo.am.AttentionModelPolicy for more details.

Parameters:

  • encoder (Module, default: None ) –

    Encoder module. Can be passed by sub-classes

  • env_name (str, default: 'pdp' ) –

    Name of the environment used to initialize embeddings

  • init_embedding (Module, default: None ) –

    Module to use for the initial embedding. If None, the default embedding for the environment is used

  • embed_dim (int, default: 128 ) –

    Dimension of the embeddings

  • num_encoder_layers (int, default: 3 ) –

    Number of layers in the encoder

  • num_heads (int, default: 8 ) –

    Number of heads for the attention in encoder

  • normalization (str, default: 'batch' ) –

    Normalization to use for the attention layers

  • feedforward_hidden (int, default: 512 ) –

    Dimension of the hidden layer in the feedforward network

  • sdpa_fn (Optional[Callable], default: None ) –

    Function to use for the scaled dot product attention

  • **kwargs

    Keyword arguments passed to :class:rl4co.models.zoo.am.AttentionModelPolicy

Source code in rl4co/models/zoo/ham/policy.py
def __init__(
    self,
    encoder: nn.Module = None,
    env_name: str = "pdp",
    init_embedding: nn.Module = None,
    embed_dim: int = 128,
    num_encoder_layers: int = 3,
    num_heads: int = 8,
    normalization: str = "batch",
    feedforward_hidden: int = 512,
    sdpa_fn: Optional[Callable] = None,
    **kwargs,
):
    if encoder is None:
        encoder = GraphHeterogeneousAttentionEncoder(
            init_embedding=init_embedding,
            num_heads=num_heads,
            embed_dim=embed_dim,
            num_encoder_layers=num_encoder_layers,
            env_name=env_name,
            normalization=normalization,
            feedforward_hidden=feedforward_hidden,
            sdpa_fn=sdpa_fn,
        )

    super(HeterogeneousAttentionModelPolicy, self).__init__(
        env_name=env_name,
        encoder=encoder,
        embed_dim=embed_dim,
        num_encoder_layers=num_encoder_layers,
        num_heads=num_heads,
        normalization=normalization,
        **kwargs,
    )

Classes:

HeterogenousMHA

HeterogenousMHA(
    num_heads,
    input_dim,
    embed_dim=None,
    val_dim=None,
    key_dim=None,
)

Bases: Module

Methods:

Source code in rl4co/models/zoo/ham/attention.py
def __init__(self, num_heads, input_dim, embed_dim=None, val_dim=None, key_dim=None):
    """
    Heterogeneous Multi-Head Attention for Pickup and Delivery problems
    https://arxiv.org/abs/2110.02634
    """
    super(HeterogenousMHA, self).__init__()

    if val_dim is None:
        assert embed_dim is not None, "Provide either embed_dim or val_dim"
        val_dim = embed_dim // num_heads
    if key_dim is None:
        key_dim = val_dim

    self.num_heads = num_heads
    self.input_dim = input_dim
    self.embed_dim = embed_dim
    self.val_dim = val_dim
    self.key_dim = key_dim

    self.norm_factor = 1 / math.sqrt(key_dim)  # See Attention is all you need

    self.W_query = nn.Parameter(torch.Tensor(num_heads, input_dim, key_dim))
    self.W_key = nn.Parameter(torch.Tensor(num_heads, input_dim, key_dim))
    self.W_val = nn.Parameter(torch.Tensor(num_heads, input_dim, val_dim))

    # Pickup weights
    self.W1_query = nn.Parameter(torch.Tensor(num_heads, input_dim, key_dim))
    self.W2_query = nn.Parameter(torch.Tensor(num_heads, input_dim, key_dim))
    self.W3_query = nn.Parameter(torch.Tensor(num_heads, input_dim, key_dim))

    # Delivery weights
    self.W4_query = nn.Parameter(torch.Tensor(num_heads, input_dim, key_dim))
    self.W5_query = nn.Parameter(torch.Tensor(num_heads, input_dim, key_dim))
    self.W6_query = nn.Parameter(torch.Tensor(num_heads, input_dim, key_dim))

    if embed_dim is not None:
        self.W_out = nn.Parameter(torch.Tensor(num_heads, key_dim, embed_dim))

    self.init_parameters()
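The dimension defaults in the constructor can be checked by hand. The helper below is a hypothetical pure-Python mirror of that logic: it derives `val_dim`, `key_dim` and `norm_factor` the same way and counts the parameters of `W_query`, `W_key`, `W_val`, the six pickup and delivery query matrices `W1_query` to `W6_query`, and `W_out`.

```python
import math


def mha_dims(num_heads, input_dim, embed_dim=None, val_dim=None, key_dim=None):
    """Mirror the dimension defaults of HeterogenousMHA.__init__ (sketch)."""
    if val_dim is None:
        assert embed_dim is not None, "Provide either embed_dim or val_dim"
        val_dim = embed_dim // num_heads  # per-head value size
    if key_dim is None:
        key_dim = val_dim
    norm_factor = 1 / math.sqrt(key_dim)  # scaling from "Attention Is All You Need"
    # 7 query matrices (W_query + W1..W6) and W_key use key_dim; W_val uses val_dim
    params = num_heads * input_dim * (7 * key_dim + key_dim + val_dim)
    if embed_dim is not None:
        params += num_heads * key_dim * embed_dim  # W_out
    return val_dim, key_dim, norm_factor, params
```

With the default `embed_dim=128` and `num_heads=8`, each head works in 16 dimensions and the compatibility scores are scaled by 0.25.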

forward

forward(q, h=None, mask=None)

Parameters:

  • q

    queries (batch_size, n_query, input_dim)

  • h

    data (batch_size, graph_size, input_dim)

  • mask

    mask (batch_size, n_query, graph_size) or viewable as that (i.e. can be 2 dim if n_query == 1)

Mask should contain 1 if attention is not possible (i.e. mask is negative adjacency)
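`forward` assumes a fixed node layout: index 0 is the depot, indices 1 to n_pick are pickups, and pickup i is paired with delivery i + n_pick, which is why `graph_size` must be odd. The sketch below (hypothetical helpers, pure Python) reproduces that layout and the width of the concatenated compatibility tensor built in the source code: graph_size regular scores plus, for each of the two node types, one paired score and two groups of n_pick scores.

```python
def pdp_layout(graph_size):
    """Split node indices into depot, pickups and deliveries (sketch)."""
    assert graph_size % 2 == 1, "expected 1 depot + n_pick pickups + n_pick deliveries"
    n_pick = (graph_size - 1) // 2
    depot = 0
    pickups = list(range(1, n_pick + 1))
    deliveries = list(range(n_pick + 1, graph_size))
    pairs = {p: p + n_pick for p in pickups}  # pickup i is served by delivery i + n_pick
    return depot, pickups, deliveries, pairs


def compatibility_width(graph_size):
    """Last dimension of the concatenated score tensor in forward() (sketch)."""
    n_pick = (graph_size - 1) // 2
    # graph_size regular scores
    # + 1 (own delivery) + n_pick (all pickups) + n_pick (all deliveries)
    # + 1 (own pickup)   + n_pick (all deliveries) + n_pick (all pickups)
    return graph_size + 2 + 4 * n_pick
```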

Source code in rl4co/models/zoo/ham/attention.py
def forward(self, q, h=None, mask=None):
    """
    Args:
        q: queries (batch_size, n_query, input_dim)
        h: data (batch_size, graph_size, input_dim)
        mask: mask (batch_size, n_query, graph_size) or viewable as that (i.e. can be 2 dim if n_query == 1)

    Mask should contain 1 if attention is not possible (i.e. mask is negative adjacency)
    """
    if h is None:
        h = q  # compute self-attention

    # h should be (batch_size, graph_size, input_dim)
    batch_size, graph_size, input_dim = h.size()

    # Check that the graph size is odd (1 depot + n_pick pickups + n_pick deliveries)
    assert graph_size % 2 == 1, (
        "Graph size should be an odd number of nodes for the pickup and delivery "
        "problem (n/2 pickups, n/2 deliveries, 1 depot)"
    )

    n_query = q.size(1)
    assert q.size(0) == batch_size
    assert q.size(2) == input_dim
    assert input_dim == self.input_dim, "Wrong embedding dimension of input"

    hflat = h.contiguous().view(-1, input_dim)  # [batch_size * graph_size, input_dim]
    qflat = q.contiguous().view(-1, input_dim)  # [batch_size * n_query, input_dim]

    # last dimension can be different for keys and values
    shp = (self.num_heads, batch_size, graph_size, -1)
    shp_q = (self.num_heads, batch_size, n_query, -1)

    # pickup -> its delivery attention
    n_pick = (graph_size - 1) // 2
    shp_delivery = (self.num_heads, batch_size, n_pick, -1)
    shp_q_pick = (self.num_heads, batch_size, n_pick, -1)

    # pickup -> all pickups attention
    shp_allpick = (self.num_heads, batch_size, n_pick, -1)
    shp_q_allpick = (self.num_heads, batch_size, n_pick, -1)

    # pickup -> all deliveries attention
    shp_alldelivery = (self.num_heads, batch_size, n_pick, -1)
    shp_q_alldelivery = (self.num_heads, batch_size, n_pick, -1)

    # Calculate queries (num_heads, batch_size, n_query, key_size)
    Q = torch.matmul(qflat, self.W_query).view(shp_q)
    # Calculate keys and values (num_heads, batch_size, graph_size, key/val_size)
    K = torch.matmul(hflat, self.W_key).view(shp)
    V = torch.matmul(hflat, self.W_val).view(shp)

    # pickup -> its delivery
    pick_flat = (
        h[:, 1 : n_pick + 1, :].contiguous().view(-1, input_dim)
    )  # [batch_size * n_pick, embed_dim]
    delivery_flat = (
        h[:, n_pick + 1 :, :].contiguous().view(-1, input_dim)
    )  # [batch_size * n_pick, embed_dim]

    # pickup -> its delivery attention
    Q_pick = torch.matmul(pick_flat, self.W1_query).view(
        shp_q_pick
    )  # (self.num_heads, batch_size, n_pick, key_size)
    K_delivery = torch.matmul(delivery_flat, self.W_key).view(
        shp_delivery
    )  # (self.num_heads, batch_size, n_pick, -1)
    V_delivery = torch.matmul(delivery_flat, self.W_val).view(
        shp_delivery
    )  # (num_heads, batch_size, n_pick, key/val_size)

    # pickup -> all pickups attention
    Q_pick_allpick = torch.matmul(pick_flat, self.W2_query).view(
        shp_q_allpick
    )  # (self.num_heads, batch_size, n_pick, -1)
    K_allpick = torch.matmul(pick_flat, self.W_key).view(
        shp_allpick
    )  # [self.num_heads, batch_size, n_pick, key_size]
    V_allpick = torch.matmul(pick_flat, self.W_val).view(
        shp_allpick
    )  # [self.num_heads, batch_size, n_pick, key_size]

    # pickup -> all delivery
    Q_pick_alldelivery = torch.matmul(pick_flat, self.W3_query).view(
        shp_q_alldelivery
    )  # (self.num_heads, batch_size, n_pick, key_size)
    K_alldelivery = torch.matmul(delivery_flat, self.W_key).view(
        shp_alldelivery
    )  # (self.num_heads, batch_size, n_pick, -1)
    V_alldelivery = torch.matmul(delivery_flat, self.W_val).view(
        shp_alldelivery
    )  # (num_heads, batch_size, n_pick, key/val_size)

    # pickup -> its delivery
    V_additional_delivery = torch.cat(
        [  # [num_heads, batch_size, graph_size, key_size]
            torch.zeros(
                self.num_heads,
                batch_size,
                1,
                self.input_dim // self.num_heads,
                dtype=V.dtype,
                device=V.device,
            ),
            V_delivery,  # [num_heads, batch_size, n_pick, key/val_size]
            torch.zeros(
                self.num_heads,
                batch_size,
                n_pick,
                self.input_dim // self.num_heads,
                dtype=V.dtype,
                device=V.device,
            ),
        ],
        2,
    )

    # delivery -> its pickup attention
    Q_delivery = torch.matmul(delivery_flat, self.W4_query).view(
        shp_delivery
    )  # (self.num_heads, batch_size, n_pick, key_size)
    K_pick = torch.matmul(pick_flat, self.W_key).view(
        shp_q_pick
    )  # (self.num_heads, batch_size, n_pick, -1)
    V_pick = torch.matmul(pick_flat, self.W_val).view(
        shp_q_pick
    )  # (num_heads, batch_size, n_pick, key/val_size)

    # delivery -> all delivery attention
    Q_delivery_alldelivery = torch.matmul(delivery_flat, self.W5_query).view(
        shp_alldelivery
    )  # (self.num_heads, batch_size, n_pick, -1)
    K_alldelivery2 = torch.matmul(delivery_flat, self.W_key).view(
        shp_alldelivery
    )  # [self.num_heads, batch_size, n_pick, key_size]
    V_alldelivery2 = torch.matmul(delivery_flat, self.W_val).view(
        shp_alldelivery
    )  # [self.num_heads, batch_size, n_pick, key_size]

    # delivery -> all pickup
    Q_delivery_allpickup = torch.matmul(delivery_flat, self.W6_query).view(
        shp_alldelivery
    )  # (self.num_heads, batch_size, n_pick, key_size)
    K_allpickup2 = torch.matmul(pick_flat, self.W_key).view(
        shp_q_alldelivery
    )  # (self.num_heads, batch_size, n_pick, -1)
    V_allpickup2 = torch.matmul(pick_flat, self.W_val).view(
        shp_q_alldelivery
    )  # (num_heads, batch_size, n_pick, key/val_size)

    # delivery -> its pick up
    V_additional_pick = torch.cat(
        [  # [num_heads, batch_size, graph_size, key_size]
            torch.zeros(
                self.num_heads,
                batch_size,
                1,
                self.input_dim // self.num_heads,
                dtype=V.dtype,
                device=V.device,
            ),
            torch.zeros(
                self.num_heads,
                batch_size,
                n_pick,
                self.input_dim // self.num_heads,
                dtype=V.dtype,
                device=V.device,
            ),
            V_pick,  # [num_heads, batch_size, n_pick, key/val_size]
        ],
        2,
    )

    # Calculate compatibility (num_heads, batch_size, n_query, graph_size)
    compatibility = self.norm_factor * torch.matmul(Q, K.transpose(2, 3))

    # Pickup pair attention
    compatibility_pick_delivery = self.norm_factor * torch.sum(
        Q_pick * K_delivery, -1
    )  # element_wise, [num_heads, batch_size, n_pick]
    compatibility_pick_allpick = self.norm_factor * torch.matmul(
        Q_pick_allpick, K_allpick.transpose(2, 3)
    )  # [num_heads, batch_size, n_pick, n_pick]
    compatibility_pick_alldelivery = self.norm_factor * torch.matmul(
        Q_pick_alldelivery, K_alldelivery.transpose(2, 3)
    )  # [num_heads, batch_size, n_pick, n_pick]

    # Delivery pair attention
    compatibility_delivery_pick = self.norm_factor * torch.sum(
        Q_delivery * K_pick, -1
    )  # element_wise, [num_heads, batch_size, n_pick]
    compatibility_delivery_alldelivery = self.norm_factor * torch.matmul(
        Q_delivery_alldelivery, K_alldelivery2.transpose(2, 3)
    )  # [num_heads, batch_size, n_pick, n_pick]
    compatibility_delivery_allpick = self.norm_factor * torch.matmul(
        Q_delivery_allpickup, K_allpickup2.transpose(2, 3)
    )  # [num_heads, batch_size, n_pick, n_pick]

    # Pickup -> its delivery: only the pickup rows (1 : n_pick + 1) get a real score (size 1);
    # the depot and delivery rows are filled with -inf
    compatibility_additional_delivery = torch.cat(
        [  # [num_heads, batch_size, graph_size, 1]
            float("-inf")
            * torch.ones(
                self.num_heads,
                batch_size,
                1,
                dtype=compatibility.dtype,
                device=compatibility.device,
            ),
            compatibility_pick_delivery,  # [num_heads, batch_size, n_pick]
            float("-inf")
            * torch.ones(
                self.num_heads,
                batch_size,
                n_pick,
                dtype=compatibility.dtype,
                device=compatibility.device,
            ),
        ],
        -1,
    ).view(self.num_heads, batch_size, graph_size, 1)

    compatibility_additional_allpick = torch.cat(
        [  # [num_heads, batch_size, graph_size, n_pick]
            float("-inf")
            * torch.ones(
                self.num_heads,
                batch_size,
                1,
                n_pick,
                dtype=compatibility.dtype,
                device=compatibility.device,
            ),
            compatibility_pick_allpick,  # [num_heads, batch_size, n_pick, n_pick]
            float("-inf")
            * torch.ones(
                self.num_heads,
                batch_size,
                n_pick,
                n_pick,
                dtype=compatibility.dtype,
                device=compatibility.device,
            ),
        ],
        2,
    ).view(self.num_heads, batch_size, graph_size, n_pick)

    compatibility_additional_alldelivery = torch.cat(
        [  # [num_heads, batch_size, graph_size, n_pick]
            float("-inf")
            * torch.ones(
                self.num_heads,
                batch_size,
                1,
                n_pick,
                dtype=compatibility.dtype,
                device=compatibility.device,
            ),
            compatibility_pick_alldelivery,  # [num_heads, batch_size, n_pick, n_pick]
            float("-inf")
            * torch.ones(
                self.num_heads,
                batch_size,
                n_pick,
                n_pick,
                dtype=compatibility.dtype,
                device=compatibility.device,
            ),
        ],
        2,
    ).view(self.num_heads, batch_size, graph_size, n_pick)
    # [num_heads, batch_size, n_query, graph_size+1+n_pick+n_pick]

    # Delivery
    compatibility_additional_pick = torch.cat(
        [  # [num_heads, batch_size, graph_size, 1]
            float("-inf")
            * torch.ones(
                self.num_heads,
                batch_size,
                1,
                dtype=compatibility.dtype,
                device=compatibility.device,
            ),
            float("-inf")
            * torch.ones(
                self.num_heads,
                batch_size,
                n_pick,
                dtype=compatibility.dtype,
                device=compatibility.device,
            ),
            compatibility_delivery_pick,  # [num_heads, batch_size, n_pick]
        ],
        -1,
    ).view(self.num_heads, batch_size, graph_size, 1)

    compatibility_additional_alldelivery2 = torch.cat(
        [  # [num_heads, batch_size, graph_size, n_pick]
            float("-inf")
            * torch.ones(
                self.num_heads,
                batch_size,
                1,
                n_pick,
                dtype=compatibility.dtype,
                device=compatibility.device,
            ),
            float("-inf")
            * torch.ones(
                self.num_heads,
                batch_size,
                n_pick,
                n_pick,
                dtype=compatibility.dtype,
                device=compatibility.device,
            ),
            compatibility_delivery_alldelivery,  # [num_heads, batch_size, n_pick, n_pick]
        ],
        2,
    ).view(self.num_heads, batch_size, graph_size, n_pick)

    compatibility_additional_allpick2 = torch.cat(
        [  # [num_heads, batch_size, graph_size, n_pick]
            float("-inf")
            * torch.ones(
                self.num_heads,
                batch_size,
                1,
                n_pick,
                dtype=compatibility.dtype,
                device=compatibility.device,
            ),
            float("-inf")
            * torch.ones(
                self.num_heads,
                batch_size,
                n_pick,
                n_pick,
                dtype=compatibility.dtype,
                device=compatibility.device,
            ),
            compatibility_delivery_allpick,  # [num_heads, batch_size, n_pick, n_pick]
        ],
        2,
    ).view(self.num_heads, batch_size, graph_size, n_pick)

    compatibility = torch.cat(
        [
            compatibility,
            compatibility_additional_delivery,
            compatibility_additional_allpick,
            compatibility_additional_alldelivery,
            compatibility_additional_pick,
            compatibility_additional_alldelivery2,
            compatibility_additional_allpick2,
        ],
        dim=-1,
    )

    # Optionally apply mask to prevent attention
    if mask is not None:
        mask = mask.view(1, batch_size, n_query, graph_size).expand_as(compatibility)
        compatibility[mask] = float("-inf")

    attn = torch.softmax(
        compatibility, dim=-1
    )  # [num_heads, batch_size, n_query, graph_size+1+n_pick*2] (graph_size include depot)

    # If there are nodes with no neighbours then softmax returns nan so we fix them to 0
    if mask is not None:
        attnc = attn.clone()
        attnc[mask] = 0
        attn = attnc

    # heads: [num_heads, batch_size, n_query, val_size] pick -> its delivery
    heads = torch.matmul(
        attn[:, :, :, :graph_size], V
    )  # V: (self.num_heads, batch_size, graph_size, val_size)
    heads = (
        heads
        + attn[:, :, :, graph_size].view(self.num_heads, batch_size, graph_size, 1)
        * V_additional_delivery
    )  # V_addi:[num_heads, batch_size, graph_size, key_size]

    # Heads pick -> otherpick, V_allpick: # [num_heads, batch_size, n_pick, key_size]
    heads = heads + torch.matmul(
        attn[:, :, :, graph_size + 1 : graph_size + 1 + n_pick].view(
            self.num_heads, batch_size, graph_size, n_pick
        ),
        V_allpick,
    )

    # V_alldelivery: # (num_heads, batch_size, n_pick, key/val_size)
    heads = heads + torch.matmul(
        attn[:, :, :, graph_size + 1 + n_pick : graph_size + 1 + 2 * n_pick].view(
            self.num_heads, batch_size, graph_size, n_pick
        ),
        V_alldelivery,
    )

    # Delivery
    heads = (
        heads
        + attn[:, :, :, graph_size + 1 + 2 * n_pick].view(
            self.num_heads, batch_size, graph_size, 1
        )
        * V_additional_pick
    )
    heads = heads + torch.matmul(
        attn[
            :,
            :,
            :,
            graph_size + 1 + 2 * n_pick + 1 : graph_size + 1 + 3 * n_pick + 1,
        ].view(self.num_heads, batch_size, graph_size, n_pick),
        V_alldelivery2,
    )
    heads = heads + torch.matmul(
        attn[:, :, :, graph_size + 1 + 3 * n_pick + 1 :].view(
            self.num_heads, batch_size, graph_size, n_pick
        ),
        V_allpickup2,
    )

    out = torch.mm(
        heads.permute(1, 2, 0, 3)
        .contiguous()
        .view(-1, self.num_heads * self.val_dim),
        self.W_out.view(-1, self.embed_dim),
    ).view(batch_size, n_query, self.embed_dim)

    return out
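The final part of this function applies one softmax over the regular compatibilities concatenated with the extra pickup/delivery compatibilities. It then slices the attention weights and multiplies each slice with its own value tensor. Below is a minimal sketch of that pattern, reduced to a single extra key group and without the head dimension. The function name and the mask convention (True = not allowed, as in the code above) are illustrative assumptions, not rl4co API:

```python
import torch


def split_value_attention(scores_main, scores_extra, v_main, v_extra, mask=None):
    """Softmax jointly over two key groups, then aggregate each group's values separately.

    scores_main:  [B, Q, N]    compatibilities with the N regular nodes
    scores_extra: [B, Q, K]    compatibilities with K extra keys
    v_main:       [B, N, D]    values of the regular nodes
    v_extra:      [B, K, D]    values of the extra keys
    mask:         [B, Q, N+K]  True where attention is NOT allowed
    """
    n = scores_main.size(-1)
    scores = torch.cat([scores_main, scores_extra], dim=-1)
    if mask is not None:
        scores = scores.masked_fill(mask, float("-inf"))
    attn = torch.softmax(scores, dim=-1)  # a single softmax over all N + K keys
    if mask is not None:
        attn = attn.masked_fill(mask, 0.0)  # fully masked rows give NaN; zero them
    # Slice the joint weights and apply each slice to its own value tensor
    return attn[..., :n] @ v_main + attn[..., n:] @ v_extra
```

Because the softmax is shared, the result equals ordinary attention over the concatenated values. Keeping the values split just lets each key group keep its own value tensor.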

Matrix Encoding Network (MatNet)

Classes:

MatNetPolicy

MatNetPolicy(
    env_name: str = "atsp",
    embed_dim: int = 256,
    num_encoder_layers: int = 5,
    num_heads: int = 16,
    normalization: str = "instance",
    init_embedding_kwargs: dict = {"mode": "RandomOneHot"},
    use_graph_context: bool = False,
    bias: bool = False,
    **kwargs
)

Bases: AutoregressivePolicy

MatNet Policy from Kwon et al., 2021. Reference: https://arxiv.org/abs/2106.11113

Warning

This implementation is under development and subject to change.

Parameters:

  • env_name (str, default: 'atsp' ) –

    Name of the environment used to initialize embeddings

  • embed_dim (int, default: 256 ) –

    Dimension of the node embeddings

  • num_encoder_layers (int, default: 5 ) –

    Number of layers in the encoder

  • num_heads (int, default: 16 ) –

    Number of heads in the attention layers

  • normalization (str, default: 'instance' ) –

    Normalization type in the attention layers

  • **kwargs

    Keyword arguments passed to the AutoregressivePolicy

Default parameters are adopted from the original implementation.

Source code in rl4co/models/zoo/matnet/policy.py
def __init__(
    self,
    env_name: str = "atsp",
    embed_dim: int = 256,
    num_encoder_layers: int = 5,
    num_heads: int = 16,
    normalization: str = "instance",
    init_embedding_kwargs: dict = {"mode": "RandomOneHot"},
    use_graph_context: bool = False,
    bias: bool = False,
    **kwargs,
):
    if env_name not in ["atsp", "ffsp"]:
        log.error(f"env_name {env_name} is not originally implemented in MatNet")

    if env_name == "ffsp":
        decoder = MatNetFFSPDecoder(
            embed_dim=embed_dim,
            num_heads=num_heads,
            use_graph_context=use_graph_context,
            out_bias=True,
        )

    else:
        decoder = MatNetDecoder(
            env_name=env_name,
            embed_dim=embed_dim,
            num_heads=num_heads,
            use_graph_context=use_graph_context,
        )

    super(MatNetPolicy, self).__init__(
        env_name=env_name,
        encoder=MatNetEncoder(
            embed_dim=embed_dim,
            num_heads=num_heads,
            num_layers=num_encoder_layers,
            normalization=normalization,
            init_embedding_kwargs=init_embedding_kwargs,
            bias=bias,
        ),
        decoder=decoder,
        embed_dim=embed_dim,
        num_encoder_layers=num_encoder_layers,
        num_heads=num_heads,
        normalization=normalization,
        **kwargs,
    )
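MatNetPolicy defaults to init_embedding_kwargs={"mode": "RandomOneHot"}. In the MatNet paper, column nodes start from one-hot vectors assigned in random order, so that otherwise indistinguishable nodes receive distinct initial embeddings. The sketch below illustrates only that idea. The function is hypothetical and is not rl4co's init embedding:

```python
import torch


def random_one_hot_embedding(batch_size: int, num_items: int, embed_dim: int, generator=None):
    """Give each of `num_items` items a distinct one-hot vector, drawn at random per instance."""
    if num_items > embed_dim:
        raise ValueError("need embed_dim >= num_items for distinct one-hot vectors")
    # For each instance, pick num_items distinct positions out of embed_dim
    idx = torch.stack(
        [torch.randperm(embed_dim, generator=generator)[:num_items] for _ in range(batch_size)]
    )  # [batch_size, num_items]
    emb = torch.zeros(batch_size, num_items, embed_dim)
    emb.scatter_(2, idx.unsqueeze(-1), 1.0)  # write a single 1 per item
    return emb
```

Drawing a fresh permutation per instance means the network cannot rely on any fixed node ordering.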

MultiStageFFSPPolicy

MultiStageFFSPPolicy(
    stage_cnt: int,
    embed_dim: int = 512,
    num_heads: int = 16,
    num_encoder_layers: int = 5,
    use_graph_context: bool = False,
    normalization: str = "instance",
    feedforward_hidden: int = 512,
    bias: bool = False,
    train_decode_type: str = "sampling",
    val_decode_type: str = "sampling",
    test_decode_type: str = "sampling",
)

Bases: Module

Policy for solving the FFSP using a separate encoder and decoder for each stage. This requires the 'while not td["done"].all()' loop to run at the policy level instead of the decoder level.

Source code in rl4co/models/zoo/matnet/policy.py
def __init__(
    self,
    stage_cnt: int,
    embed_dim: int = 512,
    num_heads: int = 16,
    num_encoder_layers: int = 5,
    use_graph_context: bool = False,
    normalization: str = "instance",
    feedforward_hidden: int = 512,
    bias: bool = False,
    train_decode_type: str = "sampling",
    val_decode_type: str = "sampling",
    test_decode_type: str = "sampling",
):
    super().__init__()
    self.stage_cnt = stage_cnt

    self.encoders: list[MatNetEncoder] = nn.ModuleList(
        [
            MatNetEncoder(
                embed_dim=embed_dim,
                num_heads=num_heads,
                num_layers=num_encoder_layers,
                normalization=normalization,
                feedforward_hidden=feedforward_hidden,
                bias=bias,
                init_embedding_kwargs={"mode": "RandomOneHot"},
            )
            for _ in range(self.stage_cnt)
        ]
    )
    self.decoders: list[MultiStageFFSPDecoder] = nn.ModuleList(
        [
            MultiStageFFSPDecoder(embed_dim, num_heads, use_graph_context)
            for _ in range(self.stage_cnt)
        ]
    )

    self.train_decode_type = train_decode_type
    self.val_decode_type = val_decode_type
    self.test_decode_type = test_decode_type
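MultiStageFFSPPolicy keeps one MatNetEncoder and one MultiStageFFSPDecoder per stage and runs the decoding loop itself. The toy below mimics only that control flow, under heavy simplifications: no machines, no processing times, no TensorDict. Every name in it is hypothetical, and nn.Embedding / nn.Linear stand in for the real encoder and decoder:

```python
import torch
from torch import nn


class ToyMultiStagePolicy(nn.Module):
    """Hypothetical toy: one encoder/decoder pair per stage, decoding loop at policy level.

    Deliberately ignores machines and timing; each stage just orders all jobs greedily.
    """

    def __init__(self, stage_cnt: int, num_jobs: int, embed_dim: int = 8):
        super().__init__()
        self.stage_cnt = stage_cnt
        self.num_jobs = num_jobs
        self.encoders = nn.ModuleList([nn.Embedding(num_jobs, embed_dim) for _ in range(stage_cnt)])
        self.decoders = nn.ModuleList([nn.Linear(embed_dim, 1) for _ in range(stage_cnt)])

    @torch.no_grad()
    def forward(self):
        job_ids = torch.arange(self.num_jobs)
        scheduled = torch.zeros(self.stage_cnt, self.num_jobs, dtype=torch.bool)
        done = torch.zeros(self.stage_cnt, dtype=torch.bool)
        actions = []
        while not done.all():  # the loop sits in the policy, not inside a decoder
            stage = int((~done).nonzero()[0])  # first unfinished stage
            # Use only the encoder/decoder that belong to this stage
            logits = self.decoders[stage](self.encoders[stage](job_ids)).squeeze(-1)
            logits = logits.masked_fill(scheduled[stage], float("-inf"))
            job = int(logits.argmax())
            scheduled[stage, job] = True
            done[stage] = scheduled[stage].all()
            actions.append((stage, job))
        return actions
```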

Classes:

MixedScoresSDPA

MixedScoresSDPA(
    num_heads: int,
    num_scores: int = 1,
    mixer_hidden_dim: int = 16,
    mix1_init: float = (1 / 2) ** (1 / 2),
    mix2_init: float = (1 / 16) ** (1 / 2),
)

Bases: Module

Methods:

  • forward

    Scaled Dot-Product Attention with MatNet Scores Mixer

Source code in rl4co/models/zoo/matnet/encoder.py
def __init__(
    self,
    num_heads: int,
    num_scores: int = 1,
    mixer_hidden_dim: int = 16,
    mix1_init: float = (1 / 2) ** (1 / 2),
    mix2_init: float = (1 / 16) ** (1 / 2),
):
    super().__init__()
    self.num_heads = num_heads
    self.num_scores = num_scores
    mix_W1 = torch.distributions.Uniform(low=-mix1_init, high=mix1_init).sample(
        (num_heads, self.num_scores + 1, mixer_hidden_dim)
    )
    mix_b1 = torch.distributions.Uniform(low=-mix1_init, high=mix1_init).sample(
        (num_heads, mixer_hidden_dim)
    )
    self.mix_W1 = nn.Parameter(mix_W1)
    self.mix_b1 = nn.Parameter(mix_b1)

    mix_W2 = torch.distributions.Uniform(low=-mix2_init, high=mix2_init).sample(
        (num_heads, mixer_hidden_dim, 1)
    )
    mix_b2 = torch.distributions.Uniform(low=-mix2_init, high=mix2_init).sample(
        (num_heads, 1)
    )
    self.mix_W2 = nn.Parameter(mix_W2)
    self.mix_b2 = nn.Parameter(mix_b2)

forward

forward(q, k, v, attn_mask=None, dmat=None, dropout_p=0.0)

Scaled Dot-Product Attention with MatNet Scores Mixer

Source code in rl4co/models/zoo/matnet/encoder.py
def forward(self, q, k, v, attn_mask=None, dmat=None, dropout_p=0.0):
    """Scaled Dot-Product Attention with MatNet Scores Mixer"""
    assert dmat is not None
    b, m, n = dmat.shape[:3]
    dmat = dmat.reshape(b, m, n, self.num_scores)

    # Calculate scaled dot product
    attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (k.size(-1) ** 0.5)
    # [b, h, m, n, num_scores+1]
    mix_attn_scores = torch.cat(
        [
            attn_scores.unsqueeze(-1),
            dmat[:, None, ...].expand(b, self.num_heads, m, n, self.num_scores),
        ],
        dim=-1,
    )
    # [b, h, m, n]
    attn_scores = (
        (
            torch.matmul(
                F.relu(
                    torch.matmul(mix_attn_scores.transpose(1, 2), self.mix_W1)
                    + self.mix_b1[None, None, :, None, :]
                ),
                self.mix_W2,
            )
            + self.mix_b2[None, None, :, None, :]
        )
        .transpose(1, 2)
        .squeeze(-1)
    )

    # Apply the provided attention mask
    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_mask[~attn_mask.any(-1)] = True
            attn_scores.masked_fill_(~attn_mask, float("-inf"))
        else:
            attn_scores += attn_mask

    # Softmax to get attention weights
    attn_weights = F.softmax(attn_scores, dim=-1)

    # Apply dropout
    if dropout_p > 0.0:
        attn_weights = F.dropout(attn_weights, p=dropout_p)

    # Compute the weighted sum of values
    return torch.matmul(attn_weights, v)
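The mixer step in forward concatenates each scaled dot-product score with the corresponding matrix entries. It then feeds every [score, entries] vector through a small two-layer MLP whose weights differ per head. The same computation can also be written with einsum. The sketch below is meant to reproduce the matmul/transpose chain above; the function name is hypothetical, and W1, b1, W2, b2 mirror mix_W1, mix_b1, mix_W2, mix_b2:

```python
import torch
import torch.nn.functional as F


def mix_scores(attn_scores, dmat, W1, b1, W2, b2):
    """Per-head MLP over (dot-product score, matrix entries), written with einsum.

    attn_scores: [b, h, m, n]      dmat: [b, m, n, s]
    W1: [h, s+1, k]   b1: [h, k]   W2: [h, k, 1]   b2: [h, 1]
    returns:     [b, h, m, n]
    """
    h = attn_scores.size(1)
    x = torch.cat(
        [attn_scores.unsqueeze(-1), dmat.unsqueeze(1).expand(-1, h, -1, -1, -1)], dim=-1
    )  # [b, h, m, n, s+1]
    hidden = F.relu(torch.einsum("bhmns,hsk->bhmnk", x, W1) + b1[None, :, None, None, :])
    out = torch.einsum("bhmnk,hko->bhmno", hidden, W2) + b2[None, :, None, None, :]
    return out.squeeze(-1)
```

Writing the head index explicitly in the einsum subscripts avoids the transpose(1, 2) round trip that the matmul version needs to line up heads with the per-head weights.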

MatNetMHA

MatNetMHA(
    embed_dim: int, num_heads: int, bias: bool = False
)

Bases: Module

Source code in rl4co/models/zoo/matnet/encoder.py
def __init__(self, embed_dim: int, num_heads: int, bias: bool = False):
    super().__init__()
    self.row_encoding_block = MatNetCrossMHA(embed_dim, num_heads, bias)
    self.col_encoding_block = MatNetCrossMHA(embed_dim, num_heads, bias)

forward

forward(row_emb, col_emb, dmat, attn_mask=None)

Parameters:

  • row_emb (Tensor) –

    [b, m, d]

  • col_emb (Tensor) –

    [b, n, d]

  • dmat (Tensor) –

    [b, m, n]

Returns:

  • Updated row_emb (Tensor): [b, m, d]

  • Updated col_emb (Tensor): [b, n, d]

Source code in rl4co/models/zoo/matnet/encoder.py
def forward(self, row_emb, col_emb, dmat, attn_mask=None):
    """
    Args:
        row_emb (Tensor): [b, m, d]
        col_emb (Tensor): [b, n, d]
        dmat (Tensor): [b, m, n]

    Returns:
        Updated row_emb (Tensor): [b, m, d]
        Updated col_emb (Tensor): [b, n, d]
    """
    updated_row_emb = self.row_encoding_block(
        row_emb, col_emb, dmat=dmat, cross_attn_mask=attn_mask
    )
    attn_mask_t = attn_mask.transpose(-2, -1) if attn_mask is not None else None
    updated_col_emb = self.col_encoding_block(
        col_emb,
        row_emb,
        dmat=dmat.transpose(-2, -1),
        cross_attn_mask=attn_mask_t,
    )
    return updated_row_emb, updated_col_emb
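MatNetMHA updates both sides of the bipartite graph: rows attend to columns using dmat, and columns attend to rows using the transposed matrix and mask. The sketch below keeps that structure but replaces MatNetCrossMHA with single-head attention in which dmat is simply added to the scores. This is a deliberate simplification, not MatNet's learned mixer. The boolean mask follows MixedScoresSDPA above (True = allowed), and both function names are hypothetical:

```python
import torch
import torch.nn.functional as F


def simple_cross_attn(q_emb, kv_emb, dmat, mask=None):
    """Single-head cross attention; dmat enters as an additive bias (a simplification,
    not MatNet's learned score mixer). mask: True where attention IS allowed."""
    scores = q_emb @ kv_emb.transpose(-2, -1) / q_emb.size(-1) ** 0.5 + dmat
    if mask is not None:
        scores = scores.masked_fill(~mask, float("-inf"))
    return F.softmax(scores, dim=-1) @ kv_emb


def bidirectional_update(row_emb, col_emb, dmat, mask=None):
    """Rows attend to columns through dmat; columns attend to rows through dmat^T."""
    new_row = simple_cross_attn(row_emb, col_emb, dmat, mask)
    mask_t = mask.transpose(-2, -1) if mask is not None else None
    new_col = simple_cross_attn(col_emb, row_emb, dmat.transpose(-2, -1), mask_t)
    return new_row, new_col
```

Swapping the roles of rows and columns (and transposing dmat) just swaps the two outputs. The transposes in forward are what preserve this symmetry.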

MatNetLayer

MatNetLayer(
    embed_dim: int,
    num_heads: int,
    bias: bool = False,
    feedforward_hidden: int = 512,
    normalization: Optional[str] = "instance",
)

Bases: Module

Source code in rl4co/models/zoo/matnet/encoder.py
def __init__(
    self,
    embed_dim: int,
    num_heads: int,
    bias: bool = False,
    feedforward_hidden: int = 512,
    normalization: Optional[str] = "instance",
):
    super().__init__()
    self.MHA = MatNetMHA(embed_dim, num_heads, bias)
    self.F_a = TransformerFFN(embed_dim, feedforward_hidden, normalization)
    self.F_b = TransformerFFN(embed_dim, feedforward_hidden, normalization)

forward

forward(row_emb, col_emb, dmat, attn_mask=None)

Parameters:

  • row_emb (Tensor) –

    [b, m, d]

  • col_emb (Tensor) –

    [b, n, d]

  • dmat (Tensor) –

    [b, m, n]

Returns:

  • Updated row_emb (Tensor): [b, m, d]

  • Updated col_emb (Tensor): [b, n, d]

Source code in rl4co/models/zoo/matnet/encoder.py
def forward(self, row_emb, col_emb, dmat, attn_mask=None):
    """
    Args:
        row_emb (Tensor): [b, m, d]
        col_emb (Tensor): [b, n, d]
        dmat (Tensor): [b, m, n]

    Returns:
        Updated row_emb (Tensor): [b, m, d]
        Updated col_emb (Tensor): [b, n, d]
    """

    row_emb_out, col_emb_out = self.MHA(row_emb, col_emb, dmat, attn_mask)
    row_emb_out = self.F_a(row_emb_out, row_emb)
    col_emb_out = self.F_b(col_emb_out, col_emb)
    return row_emb_out, col_emb_out

Classes:

  • MultiStageFFSPDecoder

    Decoder class for solving the FFSP using a separate MatNet decoder for each stage

MultiStageFFSPDecoder

MultiStageFFSPDecoder(
    embed_dim: int,
    num_heads: int,
    use_graph_context: bool = True,
    tanh_clipping: float = 10,
    **kwargs
)

Bases: MatNetFFSPDecoder

Decoder class for solving the FFSP using a separate MatNet decoder for each stage, as originally implemented by Kwon et al. (2021)

Source code in rl4co/models/zoo/matnet/decoder.py
def __init__(
    self,
    embed_dim: int,
    num_heads: int,
    use_graph_context: bool = True,
    tanh_clipping: float = 10,
    **kwargs,
):
    super().__init__(
        embed_dim=embed_dim,
        num_heads=num_heads,
        use_graph_context=use_graph_context,
        **kwargs,
    )
    self.cached_embs: PrecomputedCache = None
    self.tanh_clipping = tanh_clipping
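tanh_clipping bounds the decoder logits before they become action probabilities, which limits how peaked the action distribution can get (10 is the default here). The sketch below shows the usual clip-then-mask step. The function name and the feasibility-mask convention (True = feasible) are assumptions, and it does not show where exactly MultiStageFFSPDecoder applies the clipping:

```python
import torch


def clip_and_mask_logits(logits, tanh_clipping=10.0, mask=None):
    """Squash logits into [-C, C] with C * tanh(x), then block infeasible actions.

    mask: True where the action IS feasible (an assumption for this sketch).
    """
    if tanh_clipping > 0:
        logits = tanh_clipping * torch.tanh(logits)
    if mask is not None:
        logits = logits.masked_fill(~mask, float("-inf"))
    return logits
```

Masking after clipping keeps infeasible actions at exactly -inf, so they still receive zero probability after the softmax.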

Multi-Decoder Attention Model (MDAM)

Classes:

  • MDAM

    Multi-Decoder Attention Model (MDAM) is a model to train multiple diverse policies

Functions:

  • rollout

    In this case the reward from the model is [batch, num_paths] and the baseline is the maximum reward over paths

MDAM

MDAM(
    env: RL4COEnvBase,
    policy: MDAMPolicy = None,
    baseline: REINFORCEBaseline | str = "rollout",
    policy_kwargs={},
    baseline_kwargs={},
    **kwargs
)

Bases: REINFORCE

Multi-Decoder Attention Model (MDAM) is a model to train multiple diverse policies, which effectively increases the chance of finding good solutions compared with existing methods that train only one policy. Reference link: https://arxiv.org/abs/2012.10638; Implementation reference: https://github.com/liangxinedu/MDAM.

Parameters:

  • env (RL4COEnvBase) –

    Environment to use for the algorithm

  • policy (MDAMPolicy, default: None ) –

    Policy to use for the algorithm

  • baseline (REINFORCEBaseline | str, default: 'rollout' ) –

    REINFORCE baseline. Defaults to rollout (1 epoch of exponential, then greedy rollout baseline)

  • policy_kwargs

    Keyword arguments for policy

  • baseline_kwargs

    Keyword arguments for baseline

  • **kwargs

    Keyword arguments passed to the superclass

Source code in rl4co/models/zoo/mdam/model.py
def __init__(
    self,
    env: RL4COEnvBase,
    policy: MDAMPolicy = None,
    baseline: REINFORCEBaseline | str = "rollout",
    policy_kwargs={},
    baseline_kwargs={},
    **kwargs,
):
    if policy is None:
        policy = MDAMPolicy(env_name=env.name, **policy_kwargs)

    super().__init__(env, policy, baseline, baseline_kwargs, **kwargs)

    # Change rollout of baseline to the rollout function
    if isinstance(self.baseline, WarmupBaseline):
        if isinstance(self.baseline.baseline, RolloutBaseline):
            self.baseline.baseline.rollout = partial(rollout, self.baseline.baseline)
    elif isinstance(self.baseline, RolloutBaseline):
        self.baseline.rollout = partial(rollout, self.baseline)
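MDAM swaps in its own rollout function by binding it to the baseline instance with functools.partial, so the greedy rollout returns one (maximum) reward per instance. The pure-Python sketch below shows the same binding trick with hypothetical stand-in classes:

```python
from functools import partial


class ToyRolloutBaseline:
    """Hypothetical stand-in for a rollout baseline that owns a dataset."""

    def __init__(self, dataset):
        self.dataset = dataset

    def rollout(self, model):
        return [model(x) for x in self.dataset]


def best_path_rollout(self, model):
    """Replacement rollout: the model returns one reward per path; keep the best."""
    return [max(model(x)) for x in self.dataset]


bl = ToyRolloutBaseline([1, 2, 3])
# Bind the module-level function to this instance, as MDAM does with partial(rollout, baseline)
bl.rollout = partial(best_path_rollout, bl)
```

The patch affects only that instance. The class and any other baseline objects keep the original method; types.MethodType(best_path_rollout, bl) would be an equivalent alternative.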

calculate_loss

calculate_loss(
    td, batch, policy_out, reward=None, log_likelihood=None
)

Calculate the loss for the REINFORCE algorithm. Same as in REINFORCE, but bl_val is simply unsqueezed to match the reward shape (i.e., [batch, num_paths]).

Parameters:

  • td

    TensorDict containing the current state of the environment

  • batch

    Batch of data. This is used to get the extra loss terms, e.g., REINFORCE baseline

  • policy_out

    Output of the policy network

  • reward

    Reward tensor. If None, it is taken from policy_out

  • log_likelihood

    Log-likelihood tensor. If None, it is taken from policy_out

Source code in rl4co/models/zoo/mdam/model.py
def calculate_loss(
    self,
    td,
    batch,
    policy_out,
    reward=None,
    log_likelihood=None,
):
    """Calculate loss for REINFORCE algorithm.
    Same as in :class:`REINFORCE`, but the bl_val is simply unsqueezed to match
    the reward shape (i.e., [batch, num_paths])

    Args:
        td: TensorDict containing the current state of the environment
        batch: Batch of data. This is used to get the extra loss terms, e.g., REINFORCE baseline
        policy_out: Output of the policy network
        reward: Reward tensor. If None, it is taken from `policy_out`
        log_likelihood: Log-likelihood tensor. If None, it is taken from `policy_out`
    """
    # Extra: this is used for additional loss terms, e.g., REINFORCE baseline
    extra = batch.get("extra", None)
    reward = reward if reward is not None else policy_out["reward"]
    log_likelihood = (
        log_likelihood if log_likelihood is not None else policy_out["log_likelihood"]
    )

    # REINFORCE baseline
    bl_val, bl_loss = (
        self.baseline.eval(td, reward, self.env) if extra is None else (extra, 0)
    )

    # Main loss function
    # reward: [batch, num_paths]. Note that the baseline value is the max reward
    # if bl_val is a tensor, unsqueeze it to match the reward shape
    if isinstance(bl_val, torch.Tensor):
        if len(bl_val.shape) > 0:
            bl_val = bl_val.unsqueeze(1)
    advantage = reward - bl_val  # advantage = reward - baseline
    reinforce_loss = -(advantage * log_likelihood).mean()
    loss = reinforce_loss + bl_loss
    policy_out.update(
        {
            "loss": loss,
            "reinforce_loss": reinforce_loss,
            "bl_loss": bl_loss,
            "bl_val": bl_val,
        }
    )
    return policy_out
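A small worked example of the multi-path loss. The baseline holds one value per instance, so it is unsqueezed to [batch, 1] and broadcast over all paths. Without the unsqueeze, subtracting a [batch] tensor from [batch, num_paths] either raises a broadcasting error or, when batch == num_paths, silently pairs the wrong baselines with the wrong paths. The function name is hypothetical:

```python
import torch


def multi_path_reinforce_loss(reward, log_likelihood, bl_val):
    """REINFORCE loss when every instance has several decoded paths.

    reward, log_likelihood: [batch, num_paths]; bl_val: [batch] tensor or a scalar
    """
    if isinstance(bl_val, torch.Tensor) and bl_val.dim() > 0:
        bl_val = bl_val.unsqueeze(1)  # [batch, 1] broadcasts over num_paths
    advantage = reward - bl_val
    return -(advantage * log_likelihood).mean()


reward = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])  # [batch=2, num_paths=3]
log_likelihood = torch.ones(2, 3)
bl_val = reward.max(1).values  # max over paths, as in the MDAM rollout
loss = multi_path_reinforce_loss(reward, log_likelihood, bl_val)
```

Here the advantages are [[-2, -1, 0], [-2, -1, 0]], their mean is -1, and the loss is therefore 1.0.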

rollout

rollout(
    self,
    model,
    env,
    batch_size=64,
    device="cpu",
    dataset=None,
)

In this case the reward from the model is [batch, num_paths] and the baseline takes the maximum reward from the model as the baseline. https://github.com/liangxinedu/MDAM/blob/19b0bf813fb2dbec2fcde9e22eb50e04675400cd/train.py#L38C29-L38C33

Source code in rl4co/models/zoo/mdam/model.py
def rollout(self, model, env, batch_size=64, device="cpu", dataset=None):
    """In this case the reward from the model is [batch, num_paths]
    and the baseline takes the maximum reward from the model as the baseline.
    https://github.com/liangxinedu/MDAM/blob/19b0bf813fb2dbec2fcde9e22eb50e04675400cd/train.py#L38C29-L38C33
    """
    # if dataset is None, use the dataset of the baseline
    dataset = self.dataset if dataset is None else dataset

    model.eval()
    model = model.to(device)

    def eval_model(batch):
        with torch.inference_mode():
            batch = env.reset(batch.to(device))
            return model(batch, env, decode_type="greedy")["reward"].max(1).values

    dl = DataLoader(dataset, batch_size=batch_size, collate_fn=dataset.collate_fn)

    rewards = torch.cat([eval_model(batch) for batch in dl], 0)
    return rewards

Classes:

  • MDAMPolicy

    Multi-Decoder Attention Model (MDAM) policy.

MDAMPolicy

MDAMPolicy(
    encoder: MDAMGraphAttentionEncoder = None,
    decoder: MDAMDecoder = None,
    embed_dim: int = 128,
    env_name: str = "tsp",
    num_encoder_layers: int = 3,
    num_heads: int = 8,
    normalization: str = "batch",
    **decoder_kwargs
)

Bases: AutoregressivePolicy

Multi-Decoder Attention Model (MDAM) policy.

Source code in rl4co/models/zoo/mdam/policy.py
def __init__(
    self,
    encoder: MDAMGraphAttentionEncoder = None,
    decoder: MDAMDecoder = None,
    embed_dim: int = 128,
    env_name: str = "tsp",
    num_encoder_layers: int = 3,
    num_heads: int = 8,
    normalization: str = "batch",
    **decoder_kwargs,
):
    encoder = (
        MDAMGraphAttentionEncoder(
            num_heads=num_heads,
            embed_dim=embed_dim,
            num_layers=num_encoder_layers,
            normalization=normalization,
        )
        if encoder is None
        else encoder
    )

    decoder = (
        MDAMDecoder(
            env_name=env_name,
            embed_dim=embed_dim,
            num_heads=num_heads,
            **decoder_kwargs,
        )
        if decoder is None
        else decoder
    )

    super(MDAMPolicy, self).__init__(
        env_name=env_name, encoder=encoder, decoder=decoder
    )

    self.init_embedding = env_init_embedding(env_name, {"embed_dim": embed_dim})

Classes:

MDAMGraphAttentionEncoder

MDAMGraphAttentionEncoder(
    num_heads,
    embed_dim,
    num_layers,
    node_dim=None,
    normalization="batch",
    feedforward_hidden=512,
    sdpa_fn: Optional[Callable] = None,
)

Bases: Module

Source code in rl4co/models/zoo/mdam/encoder.py
def __init__(
    self,
    num_heads,
    embed_dim,
    num_layers,
    node_dim=None,
    normalization="batch",
    feedforward_hidden=512,
    sdpa_fn: Optional[Callable] = None,
):
    super(MDAMGraphAttentionEncoder, self).__init__()

    # To map input to embedding space
    self.init_embed = nn.Linear(node_dim, embed_dim) if node_dim is not None else None

    self.layers = nn.Sequential(
        *(
            MultiHeadAttentionLayer(
                embed_dim,
                num_heads,
                feedforward_hidden,
                normalization,
                sdpa_fn=sdpa_fn,
            )
            for _ in range(num_layers - 1)  # because last layer is different
        )
    )
    self.attention_layer = MultiHeadAttentionMDAM(
        embed_dim, num_heads, sdpa_fn=sdpa_fn, last_one=True
    )
    self.BN1 = Normalization(embed_dim, normalization)
    self.projection = SkipConnection(
        nn.Sequential(
            nn.Linear(embed_dim, feedforward_hidden),
            nn.ReLU(),
            nn.Linear(feedforward_hidden, embed_dim),
        )
        if feedforward_hidden > 0
        else nn.Linear(embed_dim, embed_dim)
    )
    self.BN2 = Normalization(embed_dim, normalization)

forward

forward(x, mask=None, return_transform_loss=False)

Returns:

    • h [batch_size, graph_size, embed_dim]
    • h.mean(dim=1) [batch_size, embed_dim]
    • attn [num_head, batch_size, graph_size, graph_size]
    • V [num_head, batch_size, graph_size, key_dim]
    • h_old [batch_size, graph_size, embed_dim]
Source code in rl4co/models/zoo/mdam/encoder.py
def forward(self, x, mask=None, return_transform_loss=False):
    """
    Returns:
        - h [batch_size, graph_size, embed_dim]
        - h.mean(dim=1) [batch_size, embed_dim]
        - attn [num_head, batch_size, graph_size, graph_size]
        - V [num_head, batch_size, graph_size, key_dim]
        - h_old [batch_size, graph_size, embed_dim]
    """
    assert mask is None, "TODO mask not yet supported!"

    h_embeded = x
    h_old = self.layers(h_embeded)
    h_new, attn, V = self.attention_layer(h_old)
    h = h_new + h_old
    h = self.BN1(h)
    h = self.projection(h)
    h = self.BN2(h)

    return (h, h.mean(dim=1), attn, V, h_old)
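After the first num_layers - 1 standard layers, forward applies the MDAM attention layer, then a residual add, a normalization, a feed-forward projection wrapped in SkipConnection, and a second normalization. It returns the mean node embedding as the graph embedding. The sketch below reproduces that tail with plain PyTorch modules. Treating Normalization(embed_dim, "batch") as BatchNorm1d over the flattened node axis is an assumption, and the class name is hypothetical:

```python
import torch
from torch import nn


class EncoderTailSketch(nn.Module):
    """Residual add + norm, skip-connected feed-forward, norm, then mean pooling."""

    def __init__(self, embed_dim: int, feedforward_hidden: int = 512):
        super().__init__()
        self.bn1 = nn.BatchNorm1d(embed_dim)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, feedforward_hidden),
            nn.ReLU(),
            nn.Linear(feedforward_hidden, embed_dim),
        )
        self.bn2 = nn.BatchNorm1d(embed_dim)

    @staticmethod
    def _norm(bn, h):
        # BatchNorm1d expects [N, C]; treat every node of every instance as a sample
        b, n, d = h.shape
        return bn(h.reshape(b * n, d)).view(b, n, d)

    def forward(self, h_new, h_old):
        h = self._norm(self.bn1, h_new + h_old)  # like BN1(h_new + h_old)
        h = h + self.ff(h)  # like SkipConnection(projection)
        h = self._norm(self.bn2, h)  # like BN2
        return h, h.mean(dim=1)  # node embeddings and graph embedding
```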

POMO

Classes:

  • POMO

    POMO Model for neural combinatorial optimization based on REINFORCE

POMO

POMO(
    env: RL4COEnvBase,
    policy: Module = None,
    policy_kwargs={},
    baseline: str = "shared",
    num_augment: int = 8,
    augment_fn: str | Callable = "dihedral8",
    first_aug_identity: bool = True,
    feats: list = None,
    num_starts: int = None,
    **kwargs
)

Bases: REINFORCE

POMO model for neural combinatorial optimization based on REINFORCE. Based on Kwon et al. (2020): http://arxiv.org/abs/2010.16011.

Note

If no policy is passed, we use the Attention Model policy with the following defaults, which differ from those of the base class (any key given in policy_kwargs overrides them):

  • num_encoder_layers=6 (instead of 3)
  • normalization="instance" (instead of "batch")
  • use_graph_context=False (instead of True)

The last choice follows the paper, which does not use the graph context in the policy; the graph context appears to make the policy overfit to the training graph size.

Parameters:

  • env (RL4COEnvBase) –

    TorchRL Environment

  • policy (Module, default: None ) –

    Policy to use for the algorithm

  • policy_kwargs

    Keyword arguments for policy

  • baseline (str, default: 'shared' ) –

    Baseline to use for the algorithm. POMO only supports the shared baseline; passing anything else raises an AssertionError.

  • num_augment (int, default: 8 ) –

    Number of augmentations (used only for validation and test)

  • augment_fn (str | Callable, default: 'dihedral8' ) –

    Function to use for augmentation, defaulting to dihedral8

  • first_aug_identity (bool, default: True ) –

    Whether to include the identity augmentation in the first position

  • feats (list, default: None ) –

    List of features to augment

  • num_starts (int, default: None ) –

    Number of starts for multi-start. If None, use the number of available actions

  • **kwargs

    Keyword arguments passed to the superclass
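For 2D coordinates in the unit square, the eight dihedral augmentations are the symmetries of the square (rotations and reflections), which leave tour lengths unchanged. The sketch below shows that transform. The function name, the ordering of the eight variants, and the augmentation-major batch layout are assumptions and may differ from rl4co's StateAugmentation with augment_fn="dihedral8":

```python
import torch


def dihedral8(xy: torch.Tensor) -> torch.Tensor:
    """Eight symmetric copies of coordinates in the unit square.

    xy: [batch, num_nodes, 2] -> [8 * batch, num_nodes, 2]; the first copy is the identity.
    """
    x, y = xy[..., :1], xy[..., 1:]
    variants = [
        (x, y), (y, x), (x, 1 - y), (y, 1 - x),
        (1 - x, y), (1 - y, x), (1 - x, 1 - y), (1 - y, 1 - x),
    ]
    return torch.cat([torch.cat(v, dim=-1) for v in variants], dim=0)
```

Putting the identity first corresponds to first_aug_identity=True: the first batch-sized slice is the original instance.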

Source code in rl4co/models/zoo/pomo/model.py
def __init__(
    self,
    env: RL4COEnvBase,
    policy: nn.Module = None,
    policy_kwargs={},
    baseline: str = "shared",
    num_augment: int = 8,
    augment_fn: str | Callable = "dihedral8",
    first_aug_identity: bool = True,
    feats: list = None,
    num_starts: int = None,
    **kwargs,
):
    self.save_hyperparameters(logger=False)

    if policy is None:
        policy_kwargs_with_defaults = {
            "num_encoder_layers": 6,
            "normalization": "instance",
            "use_graph_context": False,
        }
        policy_kwargs_with_defaults.update(policy_kwargs)
        policy = AttentionModelPolicy(
            env_name=env.name, **policy_kwargs_with_defaults
        )

    assert baseline == "shared", "POMO only supports shared baseline"

    # Initialize with the shared baseline
    super(POMO, self).__init__(env, policy, baseline, **kwargs)

    self.num_starts = num_starts
    self.num_augment = num_augment
    if self.num_augment > 1:
        self.augment = StateAugmentation(
            num_augment=self.num_augment,
            augment_fn=augment_fn,
            first_aug_identity=first_aug_identity,
            feats=feats,
        )
    else:
        self.augment = None

    # Add `_multistart` to decode type for train, val and test in policy
    for phase in ["train", "val", "test"]:
        self.set_decode_type_multistart(phase)
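POMO's shared baseline uses the average reward over all starts of the same instance as the baseline, so each start's advantage is measured against its siblings. The sketch below assumes rewards and log-likelihoods laid out as [batch, num_starts]; how rl4co batches multi-start rollouts internally may differ, and the function name is hypothetical:

```python
import torch


def shared_baseline_loss(reward: torch.Tensor, log_likelihood: torch.Tensor) -> torch.Tensor:
    """POMO-style shared baseline: the mean reward over the starts of the same instance.

    reward, log_likelihood: [batch, num_starts]
    """
    advantage = reward - reward.mean(dim=-1, keepdim=True)  # zero-mean per instance
    return -(advantage * log_likelihood).mean()
```

Unlike the rollout baseline, this needs no separate baseline network or extra greedy rollout: the baseline comes for free from the other starts of the same instance.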

Pointer Network (PtrNet)

Classes:

  • PointerNetwork

    Pointer Network for neural combinatorial optimization based on REINFORCE

PointerNetwork

PointerNetwork(
    env: RL4COEnvBase,
    policy: PointerNetworkPolicy = None,
    baseline: REINFORCEBaseline | str = "rollout",
    policy_kwargs={},
    baseline_kwargs={},
    **kwargs
)

Bases: REINFORCE

Pointer Network for neural combinatorial optimization based on REINFORCE. Based on Vinyals et al. (2015): https://arxiv.org/abs/1506.03134. Refactored from the reference implementation: https://github.com/wouterkool/attention-learn-to-route

Parameters:

  • env (RL4COEnvBase) –

    Environment to use for the algorithm

  • policy (PointerNetworkPolicy, default: None ) –

    Policy to use for the algorithm

  • baseline (REINFORCEBaseline | str, default: 'rollout' ) –

    REINFORCE baseline. Defaults to rollout (1 epoch of exponential, then greedy rollout baseline)

  • policy_kwargs

    Keyword arguments for policy

  • baseline_kwargs

    Keyword arguments for baseline

  • **kwargs

    Keyword arguments passed to the superclass

Source code in rl4co/models/zoo/ptrnet/model.py
def __init__(
    self,
    env: RL4COEnvBase,
    policy: PointerNetworkPolicy = None,
    baseline: REINFORCEBaseline | str = "rollout",
    policy_kwargs={},
    baseline_kwargs={},
    **kwargs,
):
    policy = (
        PointerNetworkPolicy(env=env, **policy_kwargs) if policy is None else policy
    )
    super().__init__(env, policy, baseline, baseline_kwargs, **kwargs)

Classes:

  • Encoder

    Maps a graph represented as an input sequence to a hidden vector

Encoder

Encoder(input_dim, hidden_dim)

Bases: Module

Maps a graph represented as an input sequence to a hidden vector

Source code in rl4co/models/zoo/ptrnet/encoder.py
def __init__(self, input_dim, hidden_dim):
    super(Encoder, self).__init__()
    self.hidden_dim = hidden_dim
    self.lstm = nn.LSTM(input_dim, hidden_dim)
    self.init_hx, self.init_cx = self.init_hidden(hidden_dim)

init_hidden

init_hidden(hidden_dim)

Trainable initial hidden state

Source code in rl4co/models/zoo/ptrnet/encoder.py
def init_hidden(self, hidden_dim):
    """Trainable initial hidden state"""
    std = 1.0 / math.sqrt(hidden_dim)
    enc_init_hx = nn.Parameter(torch.FloatTensor(hidden_dim))
    enc_init_hx.data.uniform_(-std, std)

    enc_init_cx = nn.Parameter(torch.FloatTensor(hidden_dim))
    enc_init_cx.data.uniform_(-std, std)
    return enc_init_hx, enc_init_cx
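init_hidden returns two learnable vectors of size hidden_dim. nn.LSTM, however, expects initial states of shape [num_layers, batch, hidden_dim], so the vectors must be broadcast to the batch before use. How rl4co's Encoder.forward does this is not shown here; the sketch below is one way to do it. The class name is hypothetical, and torch.empty(...).uniform_ replaces the legacy torch.FloatTensor constructor with the same initialization range:

```python
import math

import torch
from torch import nn


class TinyLSTMEncoder(nn.Module):
    """Sketch: an LSTM encoder whose initial (h0, c0) are learned vectors shared by the batch."""

    def __init__(self, input_dim: int, hidden_dim: int):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim)  # expects [seq_len, batch, input_dim]
        std = 1.0 / math.sqrt(hidden_dim)
        self.init_hx = nn.Parameter(torch.empty(hidden_dim).uniform_(-std, std))
        self.init_cx = nn.Parameter(torch.empty(hidden_dim).uniform_(-std, std))

    def forward(self, x):
        batch_size = x.size(1)
        # Broadcast the learned vectors to [num_layers=1, batch, hidden_dim]
        h0 = self.init_hx.expand(1, batch_size, -1).contiguous()
        c0 = self.init_cx.expand(1, batch_size, -1).contiguous()
        return self.lstm(x, (h0, c0))
```

Because expand keeps the autograd link, gradients from every sequence in the batch accumulate into the same two parameters.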

Classes:

SimpleAttention

SimpleAttention(dim, use_tanh=False, C=10)

Bases: Module

A generic attention module for a decoder in seq2seq

Source code in rl4co/models/zoo/ptrnet/decoder.py
def __init__(self, dim, use_tanh=False, C=10):
    super(SimpleAttention, self).__init__()
    self.use_tanh = use_tanh
    self.project_query = nn.Linear(dim, dim)
    self.project_ref = nn.Conv1d(dim, dim, 1, 1)
    self.C = C  # tanh exploration

    self.v = nn.Parameter(torch.FloatTensor(dim))
    self.v.data.uniform_(-(1.0 / math.sqrt(dim)), 1.0 / math.sqrt(dim))

forward

forward(query, ref)

Parameters:

  • query

    Hidden state of the decoder at the current time step, of shape [batch x dim]

  • ref

    Set of hidden states from the encoder, of shape [sourceL x batch x hidden_dim]

Source code in rl4co/models/zoo/ptrnet/decoder.py
def forward(self, query, ref):
    """
    Args:
        query: is the hidden state of the decoder at the current
            time step. batch x dim
        ref: the set of hidden states from the encoder.
            sourceL x batch x hidden_dim
    """
    # ref is now [batch_size x hidden_dim x sourceL]
    ref = ref.permute(1, 2, 0)
    q = self.project_query(query).unsqueeze(2)  # batch x dim x 1
    e = self.project_ref(ref)  # batch_size x hidden_dim x sourceL
    # expand the query by sourceL
    # batch x dim x sourceL
    expanded_q = q.repeat(1, 1, e.size(2))
    # batch x 1 x hidden_dim
    v_view = self.v.unsqueeze(0).expand(expanded_q.size(0), len(self.v)).unsqueeze(1)
    # [batch_size x 1 x hidden_dim] * [batch_size x hidden_dim x sourceL]
    u = torch.bmm(v_view, F.tanh(expanded_q + e)).squeeze(1)
    if self.use_tanh:
        logits = self.C * F.tanh(u)
    else:
        logits = u
    return e, logits
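SimpleAttention scores each encoder position j as u_j = v^T tanh(W_q q + W_r r_j) and, when use_tanh is set, clips the scores to [-C, C] with C * tanh(u). The NumPy sketch below reproduces that computation with the same input layout and return values (the projected references e and the logits); the weight arguments are hypothetical stand-ins for project_query, project_ref and v.

```python
import numpy as np


def simple_attention_sketch(query, ref, W_q, b_q, W_r, b_r, v, use_tanh=False, C=10.0):
    """Additive attention with the same tensor layout as SimpleAttention.forward.

    query: [batch, dim]; ref: [sourceL, batch, dim].
    Returns (e, logits) with e: [batch, dim, sourceL] and logits: [batch, sourceL].
    """
    ref = ref.transpose(1, 2, 0)                      # [batch, dim, sourceL]
    q = (query @ W_q.T + b_q)[:, :, None]             # project_query: [batch, dim, 1]
    # project_ref is a kernel-size-1 Conv1d, i.e. the same linear map at every position
    e = np.einsum("oi,bil->bol", W_r, ref) + b_r[None, :, None]  # [batch, dim, sourceL]
    # u_j = v^T tanh(q + e_j); broadcasting replaces q.repeat(1, 1, sourceL)
    u = np.einsum("d,bdl->bl", v, np.tanh(q + e))    # [batch, sourceL]
    logits = C * np.tanh(u) if use_tanh else u        # optional clipping to [-C, C]
    return e, logits


rng = np.random.default_rng(0)
batch, dim, source_len = 4, 16, 10
bound = 1.0 / np.sqrt(dim)
e, logits = simple_attention_sketch(
    query=rng.normal(size=(batch, dim)),
    ref=rng.normal(size=(source_len, batch, dim)),
    W_q=rng.normal(size=(dim, dim)),
    b_q=np.zeros(dim),
    W_r=rng.normal(size=(dim, dim)),
    b_r=np.zeros(dim),
    v=rng.uniform(-bound, bound, size=dim),
    use_tanh=True,
    C=10.0,
)
print(e.shape, logits.shape)  # (4, 16, 10) (4, 10)
```

Returning e as well as the logits lets a caller reuse the projected references, for example to form an attention-weighted glimpse.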

Decoder

Decoder(
    embed_dim: int = 128,
    hidden_dim: int = 128,
    tanh_exploration: float = 10.0,
    use_tanh: bool = True,
    num_glimpses=1,
    mask_glimpses=True,
    mask_logits=True,
)

Bases: Module


Source code in rl4co/models/zoo/ptrnet/decoder.py
def __init__(
    self,
    embed_dim: int = 128,
    hidden_dim: int = 128,
    tanh_exploration: float = 10.0,
    use_tanh: bool = True,
    num_glimpses=1,
    mask_glimpses=True,
    mask_logits=True,
):
    super(Decoder, self).__init__()

    self.embed_dim = embed_dim
    self.hidden_dim = hidden_dim
    self.num_glimpses = num_glimpses
    self.mask_glimpses = mask_glimpses
    self.mask_logits = mask_logits
    self.use_tanh = use_tanh
    self.tanh_exploration = tanh_exploration

    self.lstm = nn.LSTMCell(embed_dim, hidden_dim)
    self.pointer = SimpleAttention(hidden_dim, use_tanh=use_tanh, C=tanh_exploration)
    self.glimpse = SimpleAttention(hidden_dim, use_tanh=False)

forward

forward(
    decoder_input,
    embedded_inputs,
    hidden,
    context,
    decode_type="sampling",
    eval_tours=None,
)

Parameters:

  • decoder_input

    Initial input to the decoder, of shape [batch_size x embed_dim] (a trainable parameter).

  • embedded_inputs

    Embedded inputs, of shape [sourceL x batch_size x embed_dim].

  • hidden

    Previous hidden state, of shape [batch_size x hidden_dim]. Initially set to (enc_h[-1], enc_c[-1]).

  • context

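    Encoder outputs, of shape [sourceL x batch_size x hidden_dim].

  • decode_type

    Decoding strategy passed to decode_logprobs at each step. Defaults to "sampling".

  • eval_tours

    Optional [batch_size x sourceL] tensor of actions. If given, step i uses eval_tours[:, i] instead of decoding, so the returned log-probabilities are those of the given tours.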

Source code in rl4co/models/zoo/ptrnet/decoder.py
def forward(
    self,
    decoder_input,
    embedded_inputs,
    hidden,
    context,
    decode_type="sampling",
    eval_tours=None,
):
    """
    Args:
        decoder_input: The initial input to the decoder
            size is [batch_size x embed_dim]. Trainable parameter.
        embedded_inputs: [sourceL x batch_size x embed_dim]
        hidden: the prev hidden state, size is [batch_size x hidden_dim].
            Initially this is set to (enc_h[-1], enc_c[-1])
        context: encoder outputs, [sourceL x batch_size x hidden_dim]
    """

    batch_size = context.size(1)
    outputs = []
    selections = []
    steps = range(embedded_inputs.size(0))
    idxs = None
    mask = torch.ones(
        embedded_inputs.size(1),
        embedded_inputs.size(0),
        dtype=torch.bool,
        device=embedded_inputs.device,
    )

    for i in steps:
        hidden, log_p, mask = self.recurrence(
            decoder_input, hidden, mask, idxs, i, context
        )
        # select the index of the next node for each instance [batch_size]
        idxs = (
            decode_logprobs(log_p, mask, decode_type=decode_type)
            if eval_tours is None
            else eval_tours[:, i]
        )
        # select logp of chosen action
        log_p = gather_by_index(log_p, idxs, dim=1)

        # Otherwise PyTorch complains that it wants a reward (TODO: implement this more properly?)
        idxs = idxs.detach()
        # Gather input embedding of selected
        decoder_input = torch.gather(
            embedded_inputs,
            0,
            idxs.contiguous()
            .view(1, batch_size, 1)
            .expand(1, batch_size, *embedded_inputs.size()[2:]),
        ).squeeze(0)

        # use outs to point to next object
        outputs.append(log_p)
        selections.append(idxs)
    return (torch.stack(outputs, 1), torch.stack(selections, 1)), hidden
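The loop above repeatedly scores the remaining nodes, picks one, records its log-probability and masks it so that it cannot be chosen again. The pure-Python sketch below reproduces that control flow for greedy decoding with a masked softmax. The fixed score table and the helper name are hypothetical, and it assumes (as the all-True initial mask suggests) that True marks nodes that are still available. In the Decoder itself the mask is returned by recurrence and the choice is made by decode_logprobs.

```python
import math


def greedy_pointer_decode(scores):
    """scores[i][j]: hypothetical logit for choosing node j at step i.

    Returns (tour, log_probs), where log_probs[i] is the log-probability of the
    node chosen at step i under a softmax restricted to available nodes.
    """
    num_nodes = len(scores[0])
    mask = [True] * num_nodes  # True = still available
    tour, log_probs = [], []
    for step in range(num_nodes):
        avail = [j for j in range(num_nodes) if mask[j]]
        # Numerically stable log-sum-exp over the available nodes only
        m = max(scores[step][j] for j in avail)
        log_z = m + math.log(sum(math.exp(scores[step][j] - m) for j in avail))
        # Greedy choice: the highest-scoring available node
        idx = max(avail, key=lambda j: scores[step][j])
        tour.append(idx)
        log_probs.append(scores[step][idx] - log_z)
        mask[idx] = False  # a node can be pointed to only once
    return tour, log_probs


tour, log_probs = greedy_pointer_decode([[0.1, 2.0, 0.5], [3.0, 1.0, 0.2], [0.0, 0.0, 5.0]])
print(tour)  # [1, 0, 2]
```

Summing log_probs gives the log-likelihood of the whole tour, which is what a REINFORCE loss multiplies by the advantage.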

Classes:

CriticNetworkLSTM

CriticNetworkLSTM(
    embed_dim,
    hidden_dim,
    n_process_block_iters,
    tanh_exploration,
    use_tanh,
)

Bases: Module

LSTM critic network that predicts one scalar value per instance. Useful as a learned baseline in REINFORCE updates.


Source code in rl4co/models/zoo/ptrnet/critic.py
def __init__(
    self,
    embed_dim,
    hidden_dim,
    n_process_block_iters,
    tanh_exploration,
    use_tanh,
):
    super(CriticNetworkLSTM, self).__init__()

    self.hidden_dim = hidden_dim
    self.n_process_block_iters = n_process_block_iters

    self.encoder = Encoder(embed_dim, hidden_dim)

    self.process_block = SimpleAttention(
        hidden_dim, use_tanh=use_tanh, C=tanh_exploration
    )
    self.sm = nn.Softmax(dim=1)
    self.decoder = nn.Sequential(
        nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 1)
    )

forward

forward(inputs)

Parameters:

  • inputs

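    Embedded inputs of shape [batch_size x sourceL x embed_dim]. forward transposes them to [sourceL x batch_size x embed_dim], the layout expected by the LSTM encoder.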

Source code in rl4co/models/zoo/ptrnet/critic.py
def forward(self, inputs):
    """
    Args:
        inputs: [batch_size x sourceL x embed_dim] of embedded inputs
    """
    inputs = inputs.transpose(0, 1).contiguous()

    encoder_hx = (
        self.encoder.init_hx.unsqueeze(0).repeat(inputs.size(1), 1).unsqueeze(0)
    )
    encoder_cx = (
        self.encoder.init_cx.unsqueeze(0).repeat(inputs.size(1), 1).unsqueeze(0)
    )

    # encoder forward pass
    enc_outputs, (enc_h_t, enc_c_t) = self.encoder(inputs, (encoder_hx, encoder_cx))

    # grab the hidden state and process it via the process block
    process_block_state = enc_h_t[-1]
    for i in range(self.n_process_block_iters):
        ref, logits = self.process_block(process_block_state, enc_outputs)
        process_block_state = torch.bmm(ref, self.sm(logits).unsqueeze(2)).squeeze(2)
    # produce the final scalar output
    out = self.decoder(process_block_state)
    return out
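Each process-block iteration replaces the state with an attention-weighted average of the projected encoder states: torch.bmm(ref, softmax(logits).unsqueeze(2)).squeeze(2). The NumPy sketch below isolates that single step (the helper name is hypothetical) and checks two limiting cases: uniform logits give the plain mean over positions, and one dominant logit selects that position.

```python
import numpy as np


def glimpse_step(ref, logits):
    """ref: [batch, hidden, sourceL]; logits: [batch, sourceL] -> [batch, hidden]."""
    z = logits - logits.max(axis=1, keepdims=True)            # stabilize the softmax
    attn = np.exp(z) / np.exp(z).sum(axis=1, keepdims=True)   # softmax over sourceL
    return np.einsum("bhl,bl->bh", ref, attn)                 # weighted sum of positions


rng = np.random.default_rng(0)
ref = rng.normal(size=(2, 3, 4))  # [batch, hidden, sourceL]
uniform = glimpse_step(ref, np.zeros((2, 4)))
peaked = glimpse_step(ref, np.array([[0.0, 50.0, 0.0, 0.0]] * 2))
```

Repeating the step n_process_block_iters times lets the critic refine its summary of the instance before the final MLP maps it to a scalar.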

SymNCO

Classes:

  • SymNCO

    SymNCO Model based on REINFORCE with shared baselines.

SymNCO

SymNCO(
    env: RL4COEnvBase,
    policy: Module | SymNCOPolicy = None,
    policy_kwargs: dict = {},
    baseline: str = "symnco",
    num_augment: int = 4,
    augment_fn: str | Callable = "symmetric",
    feats: list = None,
    alpha: float = 0.2,
    beta: float = 1,
    num_starts: int = 0,
    **kwargs
)

Bases: REINFORCE

SymNCO Model based on REINFORCE with shared baselines. Based on Kim et al. (2022) https://arxiv.org/abs/2205.13209.

Parameters:

  • env (RL4COEnvBase) –

    TorchRL environment to use for the algorithm

  • policy (Module | SymNCOPolicy, default: None ) –

    Policy to use for the algorithm

  • policy_kwargs (dict, default: {} ) –

    Keyword arguments for policy

  • num_augment (int, default: 4 ) –

    Number of augmentations

  • augment_fn (str | Callable, default: 'symmetric' ) –

    Augmentation function, or the name of a registered one. Defaults to 'symmetric'.

  • feats (list, default: None ) –

    List of features to augment

  • alpha (float, default: 0.2 ) –

    Weight for the invariance loss

  • beta (float, default: 1 ) –

    Weight for the solution symmetricity loss

  • num_starts (int, default: 0 ) –

    Number of starts for multi-start decoding. Multi-start decode types are enabled for train, val and test only when num_starts > 1. If None, use the number of available actions.

  • **kwargs

    Keyword arguments passed to the superclass

Source code in rl4co/models/zoo/symnco/model.py
def __init__(
    self,
    env: RL4COEnvBase,
    policy: nn.Module | SymNCOPolicy = None,
    policy_kwargs: dict = {},
    baseline: str = "symnco",
    num_augment: int = 4,
    augment_fn: str | Callable = "symmetric",
    feats: list = None,
    alpha: float = 0.2,
    beta: float = 1,
    num_starts: int = 0,
    **kwargs,
):
    self.save_hyperparameters(logger=False)

    if policy is None:
        policy = SymNCOPolicy(env_name=env.name, **policy_kwargs)

    assert baseline == "symnco", "SymNCO only supports custom-symnco baseline"
    # Pass no baseline to superclass since there are multiple custom baselines
    baseline = "no"
    super().__init__(env, policy, baseline, **kwargs)

    self.num_starts = num_starts
    self.num_augment = num_augment
    self.augment = StateAugmentation(
        num_augment=self.num_augment, augment_fn=augment_fn, feats=feats
    )
    self.alpha = alpha  # weight for invariance loss
    self.beta = beta  # weight for solution symmetricity loss

    # Add `_multistart` to decode type for train, val and test in policy if num_starts > 1
    if self.num_starts > 1:
        for phase in ["train", "val", "test"]:
            self.set_decode_type_multistart(phase)

Classes:

  • SymNCOPolicy

    SymNCO Policy based on AutoregressivePolicy.

SymNCOPolicy

SymNCOPolicy(
    embed_dim: int = 128,
    env_name: str = "tsp",
    num_encoder_layers: int = 3,
    num_heads: int = 8,
    normalization: str = "batch",
    projection_head: Module = None,
    use_projection_head: bool = True,
    **kwargs
)

Bases: AttentionModelPolicy

SymNCO policy based on AutoregressivePolicy (it subclasses AttentionModelPolicy). It differs from the default AutoregressivePolicy in that it passes the initial embeddings through a projection head and returns the projected embeddings, which the SymNCO algorithm uses to compute the invariance loss. Based on Kim et al. (2022) https://arxiv.org/abs/2205.13209.

Parameters:

  • embed_dim (int, default: 128 ) –

    Dimension of the embedding

  • env_name (str, default: 'tsp' ) –

    Name of the environment

  • num_encoder_layers (int, default: 3 ) –

    Number of layers in the encoder

  • num_heads (int, default: 8 ) –

    Number of heads in the encoder

  • normalization (str, default: 'batch' ) –

    Normalization to use in the encoder

  • projection_head (Module, default: None ) –

    Projection head to use

  • use_projection_head (bool, default: True ) –

    Whether to use projection head

  • **kwargs

    Keyword arguments passed to the superclass

Source code in rl4co/models/zoo/symnco/policy.py
def __init__(
    self,
    embed_dim: int = 128,
    env_name: str = "tsp",
    num_encoder_layers: int = 3,
    num_heads: int = 8,
    normalization: str = "batch",
    projection_head: nn.Module = None,
    use_projection_head: bool = True,
    **kwargs,
):
    super(SymNCOPolicy, self).__init__(
        env_name=env_name,
        embed_dim=embed_dim,
        num_encoder_layers=num_encoder_layers,
        num_heads=num_heads,
        normalization=normalization,
        **kwargs,
    )

    self.use_projection_head = use_projection_head

    if self.use_projection_head:
        self.projection_head = (
            MLP(embed_dim, embed_dim, 1, embed_dim, nn.ReLU)
            if projection_head is None
            else projection_head
        )
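The projection head is applied node-wise to the initial embeddings before they reach the invariance loss. The NumPy sketch below assumes a two-layer Linear, ReLU, Linear head that keeps embed_dim; the exact layer layout of the default MLP is an assumption here, and the weights are random placeholders.

```python
import numpy as np


def projection_head_sketch(init_embeds, W1, b1, W2, b2):
    """Hypothetical Linear -> ReLU -> Linear head applied to the last (embedding) dim."""
    hidden = np.maximum(init_embeds @ W1 + b1, 0.0)
    return hidden @ W2 + b2


rng = np.random.default_rng(0)
embed_dim = 128
init_embeds = rng.normal(size=(2, 20, embed_dim))  # [batch, num_nodes, embed_dim]
proj = projection_head_sketch(
    init_embeds,
    rng.normal(size=(embed_dim, embed_dim)) * 0.05,
    np.zeros(embed_dim),
    rng.normal(size=(embed_dim, embed_dim)) * 0.05,
    np.zeros(embed_dim),
)
print(proj.shape)  # (2, 20, 128)
```

Keeping the head separate from the encoder means the invariance objective shapes a projected space rather than constraining the embeddings the decoder consumes directly.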

Functions:

problem_symmetricity_loss

problem_symmetricity_loss(reward, log_likelihood, dim=1)

REINFORCE loss for problem symmetricity. The baseline is the average reward over all augmented problems. Corresponds to L_ps in the SymNCO paper.

Source code in rl4co/models/zoo/symnco/losses.py
def problem_symmetricity_loss(reward, log_likelihood, dim=1):
    """REINFORCE loss for problem symmetricity
    Baseline is the average reward for all augmented problems
    Corresponds to `L_ps` in the SymNCO paper
    """
    num_augment = reward.shape[dim]
    if num_augment < 2:
        return 0
    advantage = reward - reward.mean(dim=dim, keepdim=True)
    loss = -advantage * log_likelihood
    return loss.mean()
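As a worked example, take one instance with two augmentations whose rewards are -4 and -6 (negative tour lengths). The shared baseline is -5, so the advantages are +1 and -1. The pure-Python sketch below (a hypothetical re-implementation on nested lists, with dim=1 as the augmentation axis) computes the same mean of -advantage * log_likelihood.

```python
def problem_symmetricity_loss_sketch(reward, log_likelihood):
    """reward, log_likelihood: nested lists of shape [batch][num_augment]."""
    num_augment = len(reward[0])
    if num_augment < 2:
        return 0.0  # no shared baseline without at least two augmentations
    terms = []
    for r_row, ll_row in zip(reward, log_likelihood):
        baseline = sum(r_row) / num_augment  # mean reward over augmentations
        for r, ll in zip(r_row, ll_row):
            terms.append(-(r - baseline) * ll)  # REINFORCE term
    return sum(terms) / len(terms)  # loss.mean()


loss = problem_symmetricity_loss_sketch([[-4.0, -6.0]], [[-1.0, -2.0]])
print(loss)  # -0.5
```

Minimizing this loss raises the log-likelihood of the better-than-average augmentation and lowers that of the worse one, without any separate baseline network.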

solution_symmetricity_loss

solution_symmetricity_loss(reward, log_likelihood, dim=-1)

REINFORCE loss for solution symmetricity. The baseline is the average reward over all start nodes. Corresponds to L_ss in the SymNCO paper.

Source code in rl4co/models/zoo/symnco/losses.py
def solution_symmetricity_loss(reward, log_likelihood, dim=-1):
    """REINFORCE loss for solution symmetricity
    Baseline is the average reward for all start nodes
    Corresponds to `L_ss` in the SymNCO paper
    """
    num_starts = reward.shape[dim]
    if num_starts < 2:
        return 0
    advantage = reward - reward.mean(dim=dim, keepdim=True)
    loss = -advantage * log_likelihood
    return loss.mean()
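The two REINFORCE losses differ only in the axis used for the baseline: dim=1 by default for problem symmetricity and dim=-1 for solution symmetricity. The NumPy sketch below assumes a reward tensor laid out as [batch, num_augment, num_starts] (a hypothetical layout chosen to match those defaults) and shows that each advantage is centered along its own axis.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, num_augment, num_starts = 2, 4, 3
# Assumed layout: [batch, num_augment, num_starts]
reward = rng.normal(size=(batch, num_augment, num_starts))
log_likelihood = rng.normal(size=(batch, num_augment, num_starts))

adv_ps = reward - reward.mean(axis=1, keepdims=True)   # baseline over augmentations (dim=1)
adv_ss = reward - reward.mean(axis=-1, keepdims=True)  # baseline over start nodes (dim=-1)

loss_ps = float((-adv_ps * log_likelihood).mean())
loss_ss = float((-adv_ss * log_likelihood).mean())

# If every start has the same log-likelihood, the solution symmetricity term vanishes
loss_ss_flat = float((-adv_ss * np.ones_like(reward)).mean())
```

Because the advantages sum to zero along the baseline axis, only relative differences between augmentations (or between starts) drive the gradient.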

invariance_loss

invariance_loss(proj_embed, num_augment)

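Loss for invariant representation on projected nodes. Corresponds to L_inv in the SymNCO paper. The returned value is the cosine similarity between each augmented copy and the first copy, summed over the other augmentations and averaged, so larger values mean more invariant embeddings.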

Source code in rl4co/models/zoo/symnco/losses.py
def invariance_loss(proj_embed, num_augment):
    """Loss for invariant representation on projected nodes
    Corresponds to `L_inv` in the SymNCO paper
    """
    pe = rearrange(proj_embed, "(b a) ... -> b a ...", a=num_augment)
    similarity = sum(
        [cosine_similarity(pe[:, 0], pe[:, i], dim=-1) for i in range(1, num_augment)]
    )
    return similarity.mean()
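The einops pattern "(b a) ... -> b a ..." assumes the augmented copies of each instance are stored contiguously, so a plain reshape recovers them. The NumPy sketch below (a hypothetical re-implementation without the small epsilon that torch's cosine_similarity adds) confirms that identical copies reach the maximum value num_augment - 1.

```python
import numpy as np


def invariance_similarity_sketch(proj_embed, num_augment):
    """proj_embed: [(batch * num_augment), ...] with each instance's copies contiguous."""
    pe = proj_embed.reshape(-1, num_augment, *proj_embed.shape[1:])  # (b a) ... -> b a ...

    def cos(a, b):
        return (a * b).sum(-1) / (np.linalg.norm(a, axis=-1) * np.linalg.norm(b, axis=-1))

    # Similarity of every other copy to the first one, summed over copies
    similarity = sum(cos(pe[:, 0], pe[:, i]) for i in range(1, num_augment))
    return float(np.mean(similarity))


rng = np.random.default_rng(0)
x = rng.normal(size=(3, 5, 8))          # [batch, num_nodes, dim]
proj = np.repeat(x, 4, axis=0)          # 4 identical copies per instance, contiguous
sim = invariance_similarity_sketch(proj, num_augment=4)
print(round(sim, 6))  # 3.0
```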