Skip to content

Neural Network Modules

Critic Network

CriticNetwork

CriticNetwork(
    encoder: Module,
    value_head: Optional[Module] = None,
    embed_dim: int = 128,
    hidden_dim: int = 512,
    customized: bool = False,
)

Bases: Module

Create a critic network given an encoder (e.g. as the one in the policy network) with a value head to transform the embeddings to a scalar value.

Parameters:

  • encoder (Module) –

    Encoder module to encode the input

  • value_head (Optional[Module], default: None ) –

    Value head to transform the embeddings to a scalar value

  • embed_dim (int, default: 128 ) –

    Dimension of the embeddings of the value head

  • hidden_dim (int, default: 512 ) –

    Dimension of the hidden layer of the value head

Source code in rl4co/models/rl/common/critic.py
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
def __init__(
    self,
    encoder: nn.Module,
    value_head: Optional[nn.Module] = None,
    embed_dim: int = 128,
    hidden_dim: int = 512,
    customized: bool = False,
):
    """Build a critic from an encoder and a value head.

    Args:
        encoder: module that encodes the input state
        value_head: module mapping embeddings to a scalar value. When None, a
            ``Linear(embed_dim, hidden_dim) -> ReLU -> Linear(hidden_dim, 1)``
            head is created.
        embed_dim: input size of the default value head. If the encoder exposes
            an ``embed_dim`` attribute with a different value, the encoder's
            value is used instead (a warning is logged).
        hidden_dim: hidden size of the default value head
        customized: stored as ``self.customized``; selects the calling
            convention used in ``forward``
    """
    super(CriticNetwork, self).__init__()

    self.encoder = encoder

    if value_head is None:
        # Prefer the encoder's own embedding size so the default head matches it
        encoder_dim = getattr(encoder, "embed_dim", embed_dim)
        if encoder_dim != embed_dim:
            log.warning(
                f"Found encoder with different embed_dim {encoder.embed_dim} than the value head {embed_dim}. Using encoder embed_dim for value head."
            )
            embed_dim = encoder_dim
        value_head = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    self.value_head = value_head
    self.customized = customized

forward

forward(
    x: Union[Tensor, TensorDict], hidden=None
) -> Tensor

Forward pass of the critic network: encode the input in embedding space and return the value

Parameters:

  • x (Union[Tensor, TensorDict]) –

    Input containing the environment state. Can be a Tensor or a TensorDict

Returns:

  • Tensor

    Value of the input state

Source code in rl4co/models/rl/common/critic.py
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
def forward(self, x: Union[Tensor, TensorDict], hidden=None) -> Tensor:
    """Estimate the value of the input state.

    Args:
        x: Input containing the environment state. Can be a Tensor or a TensorDict
        hidden: extra input passed to the value head; only used when
            ``self.customized`` is True

    Returns:
        Value of the input state
    """
    if self.customized:
        # Custom encoder/value-head pair: the encoder output is used as-is and
        # the value head additionally receives `hidden`.
        embeddings = self.encoder(x)
        return self.value_head(embeddings, hidden)

    # Default path (most constructive tasks): the encoder returns a 2-tuple whose
    # first element holds the embeddings, presumably [batch_size, N, embed_dim].
    # The value head scores every entry and the scores are averaged over dim 1
    # (with the default head this gives [batch_size, 1]).
    embeddings, _ = self.encoder(x)
    return self.value_head(embeddings).mean(1)

Graph Neural Networks

MultiHeadAttentionLayer

MultiHeadAttentionLayer(
    embed_dim: int,
    num_heads: int = 8,
    feedforward_hidden: int = 512,
    normalization: Optional[str] = "batch",
    bias: bool = True,
    sdpa_fn: Optional[Callable] = None,
    moe_kwargs: Optional[dict] = None,
)

Bases: Sequential

Multi-Head Attention Layer with normalization and feed-forward layer

Parameters:

  • embed_dim (int) –

    dimension of the embeddings

  • num_heads (int, default: 8 ) –

    number of heads in the MHA

  • feedforward_hidden (int, default: 512 ) –

    dimension of the hidden layer in the feed-forward layer

  • normalization (Optional[str], default: 'batch' ) –

    type of normalization to use (batch, layer, none)

  • sdpa_fn (Optional[Callable], default: None ) –

    scaled dot product attention function (SDPA)

  • moe_kwargs (Optional[dict], default: None ) –

    Keyword arguments for MoE

Source code in rl4co/models/nn/graph/attnnet.py
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
def __init__(
    self,
    embed_dim: int,
    num_heads: int = 8,
    feedforward_hidden: int = 512,
    normalization: Optional[str] = "batch",
    bias: bool = True,
    sdpa_fn: Optional[Callable] = None,
    moe_kwargs: Optional[dict] = None,
):
    """Stack MHA -> norm -> feed-forward -> norm, each sub-layer with a skip connection.

    Args:
        embed_dim: dimension of the embeddings
        num_heads: number of attention heads
        feedforward_hidden: hidden size of the feed-forward block; a value <= 0
            builds the block without a hidden layer
        normalization: normalization type passed to ``Normalization``
        bias: whether the attention projections use a bias (the feed-forward
            block is built without this argument)
        sdpa_fn: scaled dot product attention implementation for the MHA
        moe_kwargs: when given, the feed-forward block is a ``MoE`` built with
            these keyword arguments instead of an ``MLP``
    """
    hidden_layers = []
    if feedforward_hidden > 0:
        hidden_layers.append(feedforward_hidden)

    # The feed-forward module is built first (before the attention module) so
    # parameter initialisation consumes the RNG in the same order as before.
    if moe_kwargs is None:
        ffn = MLP(
            input_dim=embed_dim,
            output_dim=embed_dim,
            num_neurons=hidden_layers,
            hidden_act="ReLU",
        )
    else:
        ffn = MoE(embed_dim, embed_dim, num_neurons=hidden_layers, **moe_kwargs)

    sublayers = [
        SkipConnection(
            MultiHeadAttention(embed_dim, num_heads, bias=bias, sdpa_fn=sdpa_fn)
        ),
        Normalization(embed_dim, normalization),
        SkipConnection(ffn),
        Normalization(embed_dim, normalization),
    ]
    super(MultiHeadAttentionLayer, self).__init__(*sublayers)

GraphAttentionNetwork

GraphAttentionNetwork(
    num_heads: int,
    embed_dim: int,
    num_layers: int,
    normalization: str = "batch",
    feedforward_hidden: int = 512,
    sdpa_fn: Optional[Callable] = None,
    moe_kwargs: Optional[dict] = None,
)

Bases: Module

Graph Attention Network to encode embeddings with a series of MHA layers consisting of a MHA layer, normalization, feed-forward layer, and normalization. Similar to Transformer encoder, as used in Kool et al. (2019).

Parameters:

  • num_heads (int) –

    number of heads in the MHA

  • embed_dim (int) –

    dimension of the embeddings

  • num_layers (int) –

    number of MHA layers

  • normalization (str, default: 'batch' ) –

    type of normalization to use (batch, layer, none)

  • feedforward_hidden (int, default: 512 ) –

    dimension of the hidden layer in the feed-forward layer

  • sdpa_fn (Optional[Callable], default: None ) –

    scaled dot product attention function (SDPA)

  • moe_kwargs (Optional[dict], default: None ) –

    Keyword arguments for MoE

Source code in rl4co/models/nn/graph/attnnet.py
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
def __init__(
    self,
    num_heads: int,
    embed_dim: int,
    num_layers: int,
    normalization: str = "batch",
    feedforward_hidden: int = 512,
    sdpa_fn: Optional[Callable] = None,
    moe_kwargs: Optional[dict] = None,
):
    """Create ``num_layers`` identical ``MultiHeadAttentionLayer`` blocks in sequence.

    Args:
        num_heads: number of heads in each MHA
        embed_dim: dimension of the embeddings
        num_layers: number of stacked attention layers
        normalization: normalization type used inside every layer
        feedforward_hidden: hidden size of every feed-forward block
        sdpa_fn: scaled dot product attention implementation
        moe_kwargs: keyword arguments for MoE feed-forward blocks (None = MLP)
    """
    super(GraphAttentionNetwork, self).__init__()

    attention_layers = [
        MultiHeadAttentionLayer(
            embed_dim,
            num_heads,
            feedforward_hidden=feedforward_hidden,
            normalization=normalization,
            sdpa_fn=sdpa_fn,
            moe_kwargs=moe_kwargs,
        )
        for _ in range(num_layers)
    ]
    self.layers = nn.Sequential(*attention_layers)

forward

forward(x: Tensor, mask: Optional[Tensor] = None) -> Tensor

Forward pass of the encoder

Parameters:

  • x (Tensor) –

    [batch_size, graph_size, embed_dim] initial embeddings to process

  • mask (Optional[Tensor], default: None ) –

    [batch_size, graph_size, graph_size] mask for the input embeddings. Unused for now.

Source code in rl4co/models/nn/graph/attnnet.py
 94
 95
 96
 97
 98
 99
100
101
102
103
def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
    """Run the stacked attention layers over the input embeddings.

    Args:
        x: [batch_size, graph_size, embed_dim] initial embeddings to process
        mask: [batch_size, graph_size, graph_size] mask for the input embeddings.
            Masking is not implemented, so this must be None.

    Returns:
        The output of ``self.layers`` applied to ``x``.

    Raises:
        AssertionError: if a mask is passed
    """
    assert mask is None, "Mask not yet supported!"
    return self.layers(x)

GCNEncoder

GCNEncoder(
    env_name: str,
    embed_dim: int,
    num_layers: int,
    init_embedding: Module = None,
    residual: bool = True,
    edge_idx_fn: EdgeIndexFnSignature = None,
    dropout: float = 0.5,
    bias: bool = True,
)

Bases: Module

Graph Convolutional Network to encode embeddings with a series of GCN layers from the pytorch geometric package

Parameters:

  • embed_dim (int) –

    dimension of the embeddings

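  • num_layers (int) –

    number of GCN layers

  • env_name (str) –

    name of the environment, used to build the default initial embedding

  • init_embedding (Module, default: None ) –

    module computing the initial node embeddings; if None, the environment's default init embedding is used

  • edge_idx_fn (EdgeIndexFnSignature, default: None ) –

    function returning the edge index of the graph; if None, a fully connected graph is assumed

  • dropout (float, default: 0.5 ) –

    dropout probability applied after every GCN layer except the last

  • bias (bool, default: True ) –

    whether the GCN layers use a bias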

  • residual (bool, default: True ) –

    whether to use residual connection

Source code in rl4co/models/nn/graph/gcn.py
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
def __init__(
    self,
    env_name: str,
    embed_dim: int,
    num_layers: int,
    init_embedding: nn.Module = None,
    residual: bool = True,
    edge_idx_fn: EdgeIndexFnSignature = None,
    dropout: float = 0.5,
    bias: bool = True,
):
    """Set up the initial embedding, edge-index function and GCN layers.

    Args:
        env_name: environment name, used to build the default initial embedding
        embed_dim: dimension of the embeddings (input and output of every GCN layer)
        num_layers: number of GCN layers
        init_embedding: initial embedding module; if None, it is created with
            ``env_init_embedding(env_name, {"embed_dim": embed_dim})``
        residual: whether ``forward`` adds the initial embedding to the output
        edge_idx_fn: function returning the edge index; if None, a warning is
            logged and ``edge_idx_fn_wrapper`` (fully connected graph) is used
        dropout: dropout probability used in ``forward``
        bias: whether the GCN layers use a bias
    """
    super().__init__()

    self.env_name = env_name
    self.embed_dim = embed_dim
    self.residual = residual
    self.dropout = dropout

    if init_embedding is None:
        init_embedding = env_init_embedding(self.env_name, {"embed_dim": embed_dim})
    self.init_embedding = init_embedding

    if edge_idx_fn is None:
        log.warning("No edge indices passed. Assume a fully connected graph")
        edge_idx_fn = edge_idx_fn_wrapper
    self.edge_idx_fn = edge_idx_fn

    # embed_dim -> embed_dim so layers can be chained and the residual applies
    self.gcn_layers = nn.ModuleList(
        GCNConv(embed_dim, embed_dim, bias=bias) for _ in range(num_layers)
    )

forward

forward(
    td: TensorDict, mask: Union[Tensor, None] = None
) -> Tuple[Tensor, Tensor]

Forward pass of the encoder. Transform the input TensorDict into a latent representation.

Parameters:

  • td (TensorDict) –

    Input TensorDict containing the environment state

  • mask (Union[Tensor, None], default: None ) –

    Not used by this encoder; accepted only for interface compatibility with other encoders.

Returns:

  • h ( Tensor ) –

    Latent representation of the input

  • init_h ( Tensor ) –

    Initial embedding of the input

Source code in rl4co/models/nn/graph/gcn.py
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
def forward(
    self, td: TensorDict, mask: Union[Tensor, None] = None
) -> Tuple[Tensor, Tensor]:
    """Encode the input TensorDict with the stack of GCN layers.

    Args:
        td: Input TensorDict containing the environment state
        mask: not used by this encoder; accepted for interface compatibility

    Returns:
        h: Latent representation of the input, shape (batch_size, num_nodes, emb_dim)
        init_h: Initial embedding of the input
    """
    init_h = self.init_embedding(td)
    batch_size, n_nodes, dim = init_h.shape

    # Merge batch and node axes so the GCN layers see one flat set of nodes:
    # (batch_size * n_nodes, dim)
    node_feats = init_h.reshape(-1, dim)
    # Edge index from the configured function, shape (2, num_edges)
    edge_index = self.edge_idx_fn(td, n_nodes)

    # Hidden layers: convolution -> ReLU -> dropout (active only in training mode)
    for gcn in self.gcn_layers[:-1]:
        activated = F.relu(gcn(node_feats, edge_index))
        node_feats = F.dropout(activated, training=self.training, p=self.dropout)

    # The final layer is applied without ReLU or dropout
    node_feats = self.gcn_layers[-1](node_feats, edge_index)

    # Restore the batch axis
    h = node_feats.view(batch_size, n_nodes, dim)
    if self.residual:
        h = h + init_h
    return h, init_h

MessagePassingEncoder

MessagePassingEncoder(
    env_name: str,
    embed_dim: int,
    num_nodes: int,
    num_layers: int,
    init_embedding: Module = None,
    aggregation: str = "add",
    self_loop: bool = False,
    residual: bool = True,
)

Bases: Module

Source code in rl4co/models/nn/graph/mpnn.py
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
def __init__(
    self,
    env_name: str,
    embed_dim: int,
    num_nodes: int,
    num_layers: int,
    init_embedding: nn.Module = None,
    aggregation: str = "add",
    self_loop: bool = False,
    residual: bool = True,
):
    """Message passing encoder over a fixed-size, fully connected graph.

    Args:
        env_name: environment name, used to build the default initial embedding
        embed_dim: dimension of the node embeddings
        num_nodes: number of nodes of the graph (the edge index is built once here)
        num_layers: number of message passing layers
        init_embedding: initial embedding module; if None, it is created with
            ``env_init_embedding(env_name, {"embed_dim": embed_dim})``
        aggregation: aggregation passed to every ``MessagePassingLayer``
        self_loop: whether every node also has an edge to itself
        residual: passed to every ``MessagePassingLayer``

    Note:
        - Support fully connected graph for now.
    """
    super(MessagePassingEncoder, self).__init__()

    self.env_name = env_name

    self.init_embedding = (
        env_init_embedding(self.env_name, {"embed_dim": embed_dim})
        if init_embedding is None
        else init_embedding
    )

    # Edge index for a fully connected graph. The diagonal (node -> itself)
    # must be removed exactly when self-loops are NOT requested.
    adj_matrix = torch.ones(num_nodes, num_nodes)
    if not self_loop:
        adj_matrix.fill_diagonal_(0)  # No self-loops
    # (2, num_edges): row 0 holds the adjacency row indices, row 1 the column indices
    self.edge_index = torch.permute(torch.nonzero(adj_matrix), (1, 0))

    # Init message passing models
    self.mpnn_layers = nn.ModuleList(
        [
            MessagePassingLayer(
                node_indim=embed_dim,
                node_outdim=embed_dim,
                edge_indim=1,
                edge_outdim=1,
                aggregation=aggregation,
                residual=residual,
            )
            for _ in range(num_layers)
        ]
    )

    # Record parameters
    self.self_loop = self_loop

Attention Mechanisms

MultiHeadAttention

MultiHeadAttention(
    embed_dim: int,
    num_heads: int,
    bias: bool = True,
    attention_dropout: float = 0.0,
    causal: bool = False,
    device: str = None,
    dtype: dtype = None,
    sdpa_fn: Optional[Callable] = None,
)

Bases: Module

PyTorch native implementation of Flash Multi-Head Attention with automatic mixed precision support. Uses PyTorch's native scaled_dot_product_attention implementation, available from 2.0

Note

If scaled_dot_product_attention is not available, use custom implementation of scaled_dot_product_attention without Flash Attention.

Parameters:

  • embed_dim (int) –

    total dimension of the model

  • num_heads (int) –

    number of heads

  • bias (bool, default: True ) –

    whether to use bias

  • attention_dropout (float, default: 0.0 ) –

    dropout rate for attention weights

  • causal (bool, default: False ) –

    whether to apply causal mask to attention scores

  • device (str, default: None ) –

    torch device

  • dtype (dtype, default: None ) –

    torch dtype

  • sdpa_fn (Optional[Callable], default: None ) –

    scaled dot product attention function (SDPA) implementation

Source code in rl4co/models/nn/attention.py
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
def __init__(
    self,
    embed_dim: int,
    num_heads: int,
    bias: bool = True,
    attention_dropout: float = 0.0,
    causal: bool = False,
    device: str = None,
    dtype: torch.dtype = None,
    sdpa_fn: Optional[Callable] = None,
) -> None:
    """Set up the fused QKV projection and the output projection.

    Args:
        embed_dim: total dimension of the model
        num_heads: number of heads; ``embed_dim`` must be divisible by it
        bias: whether the projections use a bias
        attention_dropout: dropout rate for the attention weights
        causal: whether attention should be causal (stored as ``self.causal``)
        device: torch device for the projection weights
        dtype: torch dtype for the projection weights
        sdpa_fn: scaled dot product attention implementation; defaults to
            ``scaled_dot_product_attention``

    Raises:
        AssertionError: if ``embed_dim`` is not divisible by ``num_heads``, or the
            resulting head dimension is not a multiple of 8 that is <= 128
    """
    factory_kwargs = {"device": device, "dtype": dtype}
    super().__init__()
    self.embed_dim = embed_dim
    self.causal = causal
    self.attention_dropout = attention_dropout
    self.sdpa_fn = sdpa_fn if sdpa_fn is not None else scaled_dot_product_attention

    self.num_heads = num_heads
    assert (
        self.embed_dim % num_heads == 0
    ), f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
    self.head_dim = self.embed_dim // num_heads
    # NOTE(review): head-size restriction presumably stems from fused (Flash)
    # attention kernels; confirm it is still required for the chosen sdpa_fn.
    assert (
        self.head_dim % 8 == 0 and self.head_dim <= 128
    ), "Only support head_dim <= 128 and divisible by 8"

    # One fused projection producing query, key and value: embed_dim -> 3 * embed_dim
    self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
    self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)

forward

forward(x, attn_mask=None)

x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) attn_mask: bool tensor of shape (batch, seqlen)

Source code in rl4co/models/nn/attention.py
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
def forward(self, x, attn_mask=None):
    """Apply multi-head self-attention to ``x``.

    Args:
        x: (batch, seqlen, hidden_dim) where hidden_dim = num_heads * head_dim
        attn_mask: optional mask of shape (batch, seqlen) or (batch, seqlen, seqlen).
            It is expanded so it broadcasts over heads (and over queries for the
            2-D form) before being passed to ``self.sdpa_fn``. A boolean mask
            presumably follows the PyTorch SDPA convention (True = attend).

    Returns:
        Tensor of shape (batch, seqlen, hidden_dim).
    """
    batch, seqlen, _ = x.shape

    # Fused projection split into query/key/value per head:
    # (b, s, 3*h*d) -> (b, s, 3, h, d) -> (3, b, h, s, d)
    qkv = self.Wqkv(x).view(batch, seqlen, 3, self.num_heads, -1)
    q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(dim=0)

    if attn_mask is not None:
        attn_mask = (
            attn_mask.unsqueeze(1)
            if attn_mask.ndim == 3
            else attn_mask.unsqueeze(1).unsqueeze(2)
        )

    # SDPA applies dropout whenever dropout_p > 0 regardless of train/eval mode,
    # so attention dropout must be disabled explicitly outside training.
    dropout_p = self.attention_dropout if self.training else 0.0
    # Forward `is_causal` only when requested, so custom sdpa_fn implementations
    # that do not accept this keyword keep working in the default (non-causal) case.
    causal_kwargs = {"is_causal": True} if self.causal else {}

    out = self.sdpa_fn(
        q,
        k,
        v,
        attn_mask=attn_mask,
        dropout_p=dropout_p,
        **causal_kwargs,
    )
    # (b, h, s, d) -> (b, s, h*d)
    return self.out_proj(out.transpose(1, 2).reshape(batch, seqlen, -1))

MultiHeadCrossAttention

MultiHeadCrossAttention(
    embed_dim: int,
    num_heads: int,
    bias: bool = False,
    attention_dropout: float = 0.0,
    device: str = None,
    dtype: dtype = None,
    sdpa_fn: Optional[Union[Callable, Module]] = None,
)

Bases: Module

PyTorch native implementation of Flash Multi-Head Cross Attention with automatic mixed precision support. Uses PyTorch's native scaled_dot_product_attention implementation, available from 2.0

Note

If scaled_dot_product_attention is not available, use custom implementation of scaled_dot_product_attention without Flash Attention.

Parameters:

  • embed_dim (int) –

    total dimension of the model

  • num_heads (int) –

    number of heads

  • bias (bool, default: False ) –

    whether to use bias

  • attention_dropout (float, default: 0.0 ) –

    dropout rate for attention weights

  • device (str, default: None ) –

    torch device

  • dtype (dtype, default: None ) –

    torch dtype

  • sdpa_fn (Optional[Union[Callable, Module]], default: None ) –

    scaled dot product attention function (SDPA)

Source code in rl4co/models/nn/attention.py
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
def __init__(
    self,
    embed_dim: int,
    num_heads: int,
    bias: bool = False,
    attention_dropout: float = 0.0,
    device: str = None,
    dtype: torch.dtype = None,
    sdpa_fn: Optional[Union[Callable, nn.Module]] = None,
) -> None:
    """Set up the query, key/value and output projections for cross attention.

    Args:
        embed_dim: total dimension of the model
        num_heads: number of heads; ``embed_dim`` must be divisible by it
        bias: whether the projections use a bias
        attention_dropout: dropout rate for the attention weights
        device: torch device for the projection weights
        dtype: torch dtype for the projection weights
        sdpa_fn: scaled dot product attention implementation; defaults to
            ``sdpa_fn_wrapper``

    Raises:
        AssertionError: if ``embed_dim`` is not divisible by ``num_heads``, or the
            resulting head dimension is not a multiple of 8 that is <= 128
    """
    factory_kwargs = {"device": device, "dtype": dtype}
    super().__init__()
    self.embed_dim = embed_dim
    self.attention_dropout = attention_dropout

    # Fall back to the module's default SDPA wrapper when none is supplied
    if sdpa_fn is None:
        sdpa_fn = sdpa_fn_wrapper
    self.sdpa_fn = sdpa_fn

    self.num_heads = num_heads
    assert (
        self.embed_dim % num_heads == 0
    ), f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
    self.head_dim = self.embed_dim // num_heads
    assert (
        self.head_dim % 8 == 0 and self.head_dim <= 128
    ), "Only support head_dim <= 128 and divisible by 8"

    # Queries are projected separately; keys and values share one fused projection
    self.Wq = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
    self.Wkv = nn.Linear(embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs)
    self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)

PointerAttention

PointerAttention(
    embed_dim: int,
    num_heads: int,
    mask_inner: bool = True,
    out_bias: bool = False,
    check_nan: bool = True,
    sdpa_fn: Optional[Callable] = None,
    **kwargs
)

Bases: Module

Calculate logits given query, key and value and logit key. This follows the pointer mechanism of Vinyals et al. (2015) (https://arxiv.org/abs/1506.03134).

Note

With Flash Attention, masking is not supported

Performs the following
  1. Apply cross attention to get the heads
  2. Project heads to get glimpse
  3. Compute attention score between glimpse and logit key

Parameters:

  • embed_dim (int) –

    total dimension of the model

  • num_heads (int) –

    number of heads

  • mask_inner (bool, default: True ) –

    whether to mask inner attention

  • out_bias (bool, default: False ) –

    whether to use bias in the output projection

  • check_nan (bool, default: True ) –

    whether to check for NaNs in logits

  • sdpa_fn (Optional[Callable], default: None ) –

    scaled dot product attention function (SDPA) implementation

Source code in rl4co/models/nn/attention.py
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
def __init__(
    self,
    embed_dim: int,
    num_heads: int,
    mask_inner: bool = True,
    out_bias: bool = False,
    check_nan: bool = True,
    sdpa_fn: Optional[Callable] = None,
    **kwargs,
):
    """Configure the pointer attention.

    Args:
        embed_dim: total dimension of the model
        num_heads: number of heads
        mask_inner: whether to mask the inner attention
        out_bias: whether the output projection uses a bias
        check_nan: whether ``forward`` asserts that the logits contain no NaNs
        sdpa_fn: scaled dot product attention implementation; defaults to
            ``scaled_dot_product_attention``
        **kwargs: accepted and ignored
    """
    super(PointerAttention, self).__init__()
    self.num_heads = num_heads
    self.mask_inner = mask_inner

    # Query, key and value arrive already projected; only the output needs one
    self.project_out = nn.Linear(embed_dim, embed_dim, bias=out_bias)

    if sdpa_fn is None:
        sdpa_fn = scaled_dot_product_attention
    self.sdpa_fn = sdpa_fn
    self.check_nan = check_nan

forward

forward(query, key, value, logit_key, attn_mask=None)

Compute attention logits given query, key, value, logit key and attention mask.

Parameters:

  • query

    query tensor of shape [B, ..., L, E]

  • key

    key tensor of shape [B, ..., S, E]

  • value

    value tensor of shape [B, ..., S, E]

  • logit_key

    logit key tensor of shape [B, ..., S, E]

  • attn_mask

    attention mask tensor of shape [B, ..., S]. Note that True means that the value should take part in attention as described in the PyTorch Documentation

Source code in rl4co/models/nn/attention.py
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
def forward(self, query, key, value, logit_key, attn_mask=None):
    """Compute pointer logits from query, key, value, logit key and attention mask.

    Args:
        query: query tensor of shape [B, ..., L, E]
        key: key tensor of shape [B, ..., S, E]
        value: value tensor of shape [B, ..., S, E]
        logit_key: logit key tensor of shape [B, ..., S, E]
        attn_mask: attention mask tensor of shape [B, ..., S]. Note that `True` means that the value _should_ take part in attention
            as described in the [PyTorch Documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)

    Returns:
        Logits ``glimpse @ logit_key^T / sqrt(E)``, with a singleton step
        dimension squeezed away.

    Raises:
        AssertionError: if ``self.check_nan`` is set and the logits contain NaNs
    """
    # Multi-head attention without projections, then the output projection
    heads = self._inner_mha(query, key, value, attn_mask)
    glimpse = self._project_out(heads, attn_mask)

    # bmm is slightly faster than einsum / matmul for this batched product
    keys_t = logit_key.squeeze(-2).transpose(-2, -1)
    scale = math.sqrt(glimpse.size(-1))
    logits = torch.bmm(glimpse, keys_t).squeeze(-2) / scale

    if self.check_nan:
        assert not torch.isnan(logits).any(), "Logits contain NaNs"
    return logits

PointerAttnMoE

PointerAttnMoE(
    embed_dim: int,
    num_heads: int,
    mask_inner: bool = True,
    out_bias: bool = False,
    check_nan: bool = True,
    sdpa_fn: Optional[Callable] = None,
    moe_kwargs: Optional[dict] = None,
)

Bases: PointerAttention

Calculate logits given query, key and value and logit key. This follows the pointer mechanism of Vinyals et al. (2015) https://arxiv.org/abs/1506.03134, and the MoE gating mechanism of Zhou et al. (2024) https://arxiv.org/abs/2405.01029.

Note

With Flash Attention, masking is not supported

Performs the following
  1. Apply cross attention to get the heads
  2. Project heads to get glimpse
  3. Compute attention score between glimpse and logit key

Parameters:

  • embed_dim (int) –

    total dimension of the model

  • num_heads (int) –

    number of heads

  • mask_inner (bool, default: True ) –

    whether to mask inner attention

  • out_bias (bool, default: False ) –

    whether to use bias in the output projection

  • check_nan (bool, default: True ) –

    whether to check for NaNs in logits

  • sdpa_fn (Optional[Callable], default: None ) –

    scaled dot product attention function (SDPA) implementation

  • moe_kwargs (Optional[dict], default: None ) –

    Keyword arguments for MoE

Source code in rl4co/models/nn/attention.py
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
def __init__(
    self,
    embed_dim: int,
    num_heads: int,
    mask_inner: bool = True,
    out_bias: bool = False,
    check_nan: bool = True,
    sdpa_fn: Optional[Callable] = None,
    moe_kwargs: Optional[dict] = None,
):
    """Pointer attention whose output projection is a mixture of experts.

    Args:
        embed_dim: total dimension of the model
        num_heads: number of heads
        mask_inner: whether to mask the inner attention
        out_bias: whether the output projection(s) use a bias
        check_nan: whether to check the logits for NaNs
        sdpa_fn: scaled dot product attention implementation
        moe_kwargs: keyword arguments for ``MoE``. Despite the None default it
            must be a mapping (it is unpacked with ``**``) and must contain the
            key ``"light_version"``.
    """
    super(PointerAttnMoE, self).__init__(
        embed_dim,
        num_heads,
        mask_inner=mask_inner,
        out_bias=out_bias,
        check_nan=check_nan,
        sdpa_fn=sdpa_fn,
    )
    self.moe_kwargs = moe_kwargs

    # Drop the parent's dense output projection; the MoE projection replaces it
    self.project_out = None
    self.project_out_moe = MoE(
        embed_dim, embed_dim, num_neurons=[], out_bias=out_bias, **moe_kwargs
    )

    use_light_version = self.moe_kwargs["light_version"]
    if use_light_version:
        # Light version: a 2-way gate plus a dense projection alongside the MoE one
        self.dense_or_moe = nn.Linear(embed_dim, 2, bias=False)
        self.project_out = nn.Linear(embed_dim, embed_dim, bias=out_bias)

MultiHeadCompat

MultiHeadCompat(
    n_heads,
    input_dim,
    embed_dim=None,
    val_dim=None,
    key_dim=None,
)

Bases: Module

Source code in rl4co/models/nn/attention.py
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
def __init__(self, n_heads, input_dim, embed_dim=None, val_dim=None, key_dim=None):
    """Per-head query/key weights for raw compatibility scores.

    Args:
        n_heads: number of heads
        input_dim: dimension of the inputs
        embed_dim: total embedding dimension; used to derive ``val_dim`` as
            ``embed_dim // n_heads`` when ``val_dim`` is not given (so one of the
            two must be provided)
        val_dim: per-head value dimension
        key_dim: per-head key dimension; defaults to ``val_dim``
    """
    super(MultiHeadCompat, self).__init__()

    if val_dim is None:
        val_dim = embed_dim // n_heads
    key_dim = val_dim if key_dim is None else key_dim

    self.n_heads = n_heads
    self.input_dim = input_dim
    self.embed_dim = embed_dim
    self.val_dim = val_dim
    self.key_dim = key_dim

    # torch.Tensor(*sizes) allocates uninitialized memory; init_parameters()
    # presumably fills these weights in.
    weight_shape = (n_heads, input_dim, key_dim)
    self.W_query = nn.Parameter(torch.Tensor(*weight_shape))
    self.W_key = nn.Parameter(torch.Tensor(*weight_shape))

    self.init_parameters()

forward

forward(q, h=None, mask=None)

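Compute raw (unscaled) per-head compatibility scores between queries and data. :param q: queries (batch_size, n_query, input_dim) :param h: data (batch_size, graph_size, input_dim); defaults to q (self-compatibility) :param mask: mask (batch_size, n_query, graph_size) or viewable as that (i.e. can be 2 dim if n_query == 1). Mask should contain 1 if attention is not possible (i.e. mask is negative adjacency). Note: the mask is currently not applied by this method. :return: compatibility scores of shape (n_heads, batch_size, n_query, graph_size)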

Source code in rl4co/models/nn/attention.py
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
def forward(self, q, h=None, mask=None):
    """Compute raw (unscaled) per-head compatibility scores.

    :param q: queries (batch_size, n_query, input_dim)
    :param h: data (batch_size, graph_size, input_dim); defaults to ``q``
    :param mask: mask (batch_size, n_query, graph_size) or viewable as that (i.e. can be 2 dim if n_query == 1)
    Mask should contain 1 if attention is not possible (i.e. mask is negative adjacency).
    NOTE: the mask is currently not applied by this method.
    :return: compatibility scores (n_heads, batch_size, n_query, graph_size)
    """
    data = q if h is None else h  # self-compatibility when no data is given

    batch_size, graph_size, input_dim = data.size()
    n_query = q.size(1)

    query_rows = q.contiguous().view(-1, input_dim)
    data_rows = data.contiguous().view(-1, input_dim)

    # (rows, input_dim) @ (n_heads, input_dim, key_dim) -> (n_heads, rows, key_dim),
    # then split rows back into (batch_size, n_query / graph_size)
    queries = torch.matmul(query_rows, self.W_query).view(
        self.n_heads, batch_size, n_query, -1
    )
    keys = torch.matmul(data_rows, self.W_key).view(
        self.n_heads, batch_size, graph_size, -1
    )

    # (n_heads, batch_size, n_query, graph_size)
    return torch.matmul(queries, keys.transpose(2, 3))

PolyNetAttention

PolyNetAttention(
    k: int,
    embed_dim: int,
    poly_layer_dim: int,
    num_heads: int,
    **kwargs
)

Bases: PointerAttention

Calculate logits given query, key and value and logit key. This implements a modified version of the pointer mechanism of Vinyals et al. (2015) (https://arxiv.org/abs/1506.03134) as described in Hottung et al. (2024) (https://arxiv.org/abs/2402.14048). PolyNetAttention conditions the attention logits on a set of k different binary vectors, allowing it to learn k different solution strategies.

Note

With Flash Attention, masking is not supported

Performs the following
  1. Apply cross attention to get the heads
  2. Project heads to get glimpse
  3. Apply PolyNet layers
  4. Compute attention score between glimpse and logit key

Parameters:

  • k (int) –

    Number of unique bit vectors used to compute attention score

  • embed_dim (int) –

    total dimension of the model

  • poly_layer_dim (int) –

    Dimension of the PolyNet layers

  • num_heads (int) –

    number of heads

  • mask_inner

    whether to mask inner attention

  • out_bias

    whether to use bias in the output projection (forwarded to PointerAttention)

  • check_nan

    whether to check for NaNs in logits

  • sdpa_fn

    scaled dot product attention function (SDPA) implementation

Source code in rl4co/models/nn/attention.py
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
def __init__(
    self, k: int, embed_dim: int, poly_layer_dim: int, num_heads: int, **kwargs
):
    """Pointer attention conditioned on one of ``k`` fixed binary strategy codes.

    Args:
        k: number of distinct strategies (binary codes)
        embed_dim: total dimension of the model
        poly_layer_dim: hidden size of the PolyNet MLP
        num_heads: number of heads
        **kwargs: forwarded to ``PointerAttention``
    """
    super(PolyNetAttention, self).__init__(embed_dim, num_heads, **kwargs)

    self.k = k
    # Fewest bits that give each of the k strategies its own code
    self.binary_vector_dim = math.ceil(math.log2(k))
    all_codes = list(itertools.product([0, 1], repeat=self.binary_vector_dim))
    # First k codes as a fixed (non-trainable) parameter, one row per strategy
    self.binary_vectors = torch.nn.Parameter(
        torch.Tensor(all_codes[:k]), requires_grad=False
    )

    # The MLP input is the glimpse concatenated with the strategy code
    self.poly_layer_1 = nn.Linear(embed_dim + self.binary_vector_dim, poly_layer_dim)
    self.poly_layer_2 = nn.Linear(poly_layer_dim, embed_dim)

forward

forward(query, key, value, logit_key, attn_mask=None)

Compute attention logits given query, key, value, logit key and attention mask.

Parameters:

  • query

    query tensor of shape [B, ..., L, E]

  • key

    key tensor of shape [B, ..., S, E]

  • value

    value tensor of shape [B, ..., S, E]

  • logit_key

    logit key tensor of shape [B, ..., S, E]

  • attn_mask

    attention mask tensor of shape [B, ..., S]. Note that True means that the value should take part in attention as described in the PyTorch Documentation

Source code in rl4co/models/nn/attention.py
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
def forward(self, query, key, value, logit_key, attn_mask=None):
    """Compute strategy-conditioned pointer logits.

    Args:
        query: query tensor of shape [B, ..., L, E]
        key: key tensor of shape [B, ..., S, E]
        value: value tensor of shape [B, ..., S, E]
        logit_key: logit key tensor of shape [B, ..., S, E]
        attn_mask: attention mask tensor of shape [B, ..., S]. Note that `True` means that the value _should_ take part in attention
            as described in the [PyTorch Documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)

    Returns:
        Logits ``glimpse @ logit_key^T / sqrt(E)`` where the glimpse has been
        refined by the PolyNet MLP.

    Raises:
        AssertionError: if ``self.check_nan`` is set and the logits contain NaNs
    """
    # Multi-head attention without projections, then the dense output projection
    heads = self._inner_mha(query, key, value, attn_mask)
    glimpse = self.project_out(heads)

    batch_size, num_solutions = glimpse.shape[0], glimpse.shape[1]

    # Solution i receives strategy code i mod k: tile the k codes, then trim
    n_tiles = math.ceil(num_solutions / self.k)
    codes = self.binary_vectors.repeat(n_tiles, 1)[:num_solutions]
    codes = codes.unsqueeze(0).expand(
        batch_size, num_solutions, self.binary_vector_dim
    )

    # Residual PolyNet MLP conditioned on the strategy code
    poly_hidden = F.relu(self.poly_layer_1(torch.cat((glimpse, codes), dim=2)))
    glimpse += self.poly_layer_2(poly_hidden)

    # bmm is slightly faster than einsum / matmul for this batched product
    keys_t = logit_key.squeeze(-2).transpose(-2, -1)
    scale = math.sqrt(glimpse.size(-1))
    logits = torch.bmm(glimpse, keys_t).squeeze(-2) / scale

    if self.check_nan:
        assert not torch.isnan(logits).any(), "Logits contain NaNs"
    return logits

scaled_dot_product_attention_simple

scaled_dot_product_attention_simple(
    q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False
)

Simple Scaled Dot-Product Attention in PyTorch without Flash Attention

Source code in rl4co/models/nn/attention.py
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
def scaled_dot_product_attention_simple(
    q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None
):
    """Simple Scaled Dot-Product Attention in PyTorch without Flash Attention.

    This is a plain-PyTorch counterpart of
    ``torch.nn.functional.scaled_dot_product_attention``.

    Args:
        q: query tensor of shape [..., L, E]
        k: key tensor of shape [..., S, E]
        v: value tensor of shape [..., S, Ev]
        attn_mask: optional mask broadcastable to the score shape [..., L, S].
            In a boolean mask, ``True`` marks positions that take part in
            attention. A mask of any other dtype is added to the scores.
        dropout_p: dropout probability for the attention weights. It is
            applied whenever ``dropout_p > 0``, because ``F.dropout`` is
            called with its default ``training=True``.
        is_causal: if True, query position ``i`` attends only to key
            positions ``<= i`` (upper triangle above the diagonal is masked).
        scale: multiplier applied to the raw ``q @ k^T`` scores. ``None``
            (default) divides by ``sqrt(E)``.

    Returns:
        Tensor of shape [..., L, Ev].

    Raises:
        ValueError: if both ``is_causal`` and ``attn_mask`` are given.
    """
    # Check for causal and attn_mask conflict
    if is_causal and attn_mask is not None:
        raise ValueError("Cannot set both is_causal and attn_mask")

    # Calculate (scaled) dot product
    scores = torch.matmul(q, k.transpose(-2, -1))
    if scale is None:
        # Default path: identical arithmetic to the original 1/sqrt(E) scaling.
        scores = scores / (k.size(-1) ** 0.5)
    else:
        scores = scores * scale

    # Apply the provided attention mask (scores is a fresh tensor, so in-place is safe)
    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            scores.masked_fill_(~attn_mask, float("-inf"))
        else:
            scores += attn_mask

    # Apply causal mask: forbid attending to later key positions
    if is_causal:
        n_q, n_k = scores.size(-2), scores.size(-1)
        future = torch.ones(
            (n_q, n_k), dtype=torch.bool, device=scores.device
        ).triu(diagonal=1)
        scores.masked_fill_(future, float("-inf"))

    # Softmax to get attention weights
    attn_weights = F.softmax(scores, dim=-1)

    # Apply dropout
    if dropout_p > 0.0:
        attn_weights = F.dropout(attn_weights, p=dropout_p)

    # Compute the weighted sum of values
    return torch.matmul(attn_weights, v)

Multi-Layer Perceptron

MLP

MLP(
    input_dim: int,
    output_dim: int,
    num_neurons: List[int] = [64, 32],
    dropout_probs: Union[None, List[float]] = None,
    hidden_act: str = "ReLU",
    out_act: str = "Identity",
    input_norm: str = "None",
    output_norm: str = "None",
)

Bases: Module

Source code in rl4co/models/nn/mlp.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
def __init__(
    self,
    input_dim: int,
    output_dim: int,
    num_neurons: List[int] = [64, 32],
    dropout_probs: Union[None, List[float]] = None,
    hidden_act: str = "ReLU",
    out_act: str = "Identity",
    input_norm: str = "None",
    output_norm: str = "None",
):
    """Build the layers of a fully connected network.

    Args:
        input_dim: size of the input features.
        output_dim: size of the output features.
        num_neurons: widths of the hidden layers. One ``nn.Linear`` is created
            per consecutive pair in ``[input_dim, *num_neurons, output_dim]``.
        dropout_probs: one dropout probability per hidden layer. If ``None``,
            or if its length does not match ``num_neurons``, every probability
            is set to 0.0.
        hidden_act: name of a ``torch.nn`` activation class, instantiated with
            no arguments and stored as ``self.hidden_act``.
        out_act: name of a ``torch.nn`` activation class, instantiated with no
            arguments and stored as ``self.out_act``.
        input_norm: one of ``"Batch"``, ``"Layer"``, ``"None"``. It is passed
            to ``self._get_norm_layer`` together with ``input_dim``.
        output_norm: one of ``"Batch"``, ``"Layer"``, ``"None"``. It is passed
            to ``self._get_norm_layer`` together with ``output_dim``.
    """
    super(MLP, self).__init__()

    assert input_norm in ["Batch", "Layer", "None"]
    assert output_norm in ["Batch", "Layer", "None"]

    # Copy so the instance never aliases the shared default list; this also
    # lets callers pass a tuple (list + tuple concatenation below would fail).
    num_neurons = list(num_neurons)

    if dropout_probs is None:
        dropout_probs = [0.0] * len(num_neurons)
    elif len(dropout_probs) != len(num_neurons):
        # Misconfiguration being silently overridden: surface it as a warning.
        log.warning(
            "dropout_probs List length should match the num_neurons List length for MLP, dropouts set to False instead"
        )
        dropout_probs = [0.0] * len(num_neurons)

    self.input_dim = input_dim
    self.output_dim = output_dim
    self.num_neurons = num_neurons
    self.hidden_act = getattr(nn, hidden_act)()
    self.out_act = getattr(nn, out_act)()
    # nn.ModuleList (not a plain list) registers the dropout layers as
    # submodules, so .train()/.eval() switch them on/off with the rest of the
    # model. Dropout holds no parameters or buffers, so state_dict is unchanged.
    self.dropouts = nn.ModuleList(nn.Dropout(p=p) for p in dropout_probs)

    input_dims = [input_dim] + num_neurons
    output_dims = num_neurons + [output_dim]

    self.lins = nn.ModuleList(
        nn.Linear(in_dim, out_dim) for in_dim, out_dim in zip(input_dims, output_dims)
    )

    self.input_norm = self._get_norm_layer(input_norm, input_dim)
    self.output_norm = self._get_norm_layer(output_norm, output_dim)

Operations

PositionalEncoding

PositionalEncoding(
    embed_dim: int,
    dropout: float = 0.1,
    max_len: int = 1000,
)

Bases: Module

Source code in rl4co/models/nn/ops.py
60
61
62
63
64
65
66
67
68
69
70
71
72
73
def __init__(self, embed_dim: int, dropout: float = 0.1, max_len: int = 1000):
    """Precompute a table of sinusoidal positional encodings.

    Even columns hold ``sin(pos * w_i)`` and odd columns hold
    ``cos(pos * w_i)``, with ``w_i = 10000^(-2i / embed_dim)``.

    Args:
        embed_dim: width of each encoding (stored as ``self.d_model``). Odd
            widths are supported.
        dropout: probability for the ``nn.Dropout`` stored as ``self.dropout``.
        max_len: number of positions (rows) precomputed in the table.
    """
    super().__init__()
    self.dropout = nn.Dropout(p=dropout)
    self.d_model = embed_dim
    position = torch.arange(max_len).unsqueeze(1)  # [max_len, 1]
    div_term = torch.exp(
        torch.arange(0, self.d_model, 2) * (-math.log(10000.0) / self.d_model)
    )  # [ceil(d_model / 2)]
    pe = torch.zeros(max_len, 1, self.d_model)
    pe[:, 0, 0::2] = torch.sin(position * div_term)
    # There are only floor(d_model / 2) odd columns; trim div_term so that an
    # odd embed_dim does not raise a shape mismatch (no-op for even widths).
    pe[:, 0, 1::2] = torch.cos(position * div_term[: self.d_model // 2])
    pe = pe.transpose(0, 1)  # [1, max_len, d_model]
    self.register_buffer("pe", pe)

forward

forward(hidden: Tensor, seq_pos) -> Tensor

Parameters:

  • hidden

    Tensor, shape [batch_size, seq_len, embedding_dim]

  • seq_pos

    Tensor, shape [batch_size, seq_len]

Source code in rl4co/models/nn/ops.py
75
76
77
78
79
80
81
82
83
84
85
def forward(self, hidden: torch.Tensor, seq_pos) -> torch.Tensor:
    """Add the positional encodings selected by ``seq_pos`` to ``hidden``.

    Arguments:
        hidden: Tensor, shape ``[batch_size, seq_len, embedding_dim]``
        seq_pos: integer Tensor, shape ``[batch_size, seq_len]``; each entry
            is a row index into the precomputed table ``self.pe``.

    Returns:
        ``hidden`` plus the selected encodings, passed through ``self.dropout``.
    """
    # self.pe is [1, max_len, d_model]. Indexing its single slice with seq_pos
    # yields [*seq_pos.shape, d_model], the same result the former
    # expand + gather produced, without requiring int64 indices.
    pes = self.pe[0][seq_pos]
    hidden = hidden + pes
    return self.dropout(hidden)

RandomEncoding

RandomEncoding(embed_dim: int, max_classes: int = 100)

Bases: Module

This is like torch.nn.Embedding, except that the rows of the embedding table are randomly permuted in each forward pass before the lookup. This can be useful when classes have no fixed meaning but instead indicate a connection between different elements in a sequence. The reference is the MatNet model.

Source code in rl4co/models/nn/ops.py
118
119
120
121
122
123
def __init__(self, embed_dim: int, max_classes: int = 100):
    """Allocate a fixed random embedding table.

    Args:
        embed_dim: width of each embedding row.
        max_classes: number of rows in the table.

    The table is drawn once with ``torch.rand`` (uniform on [0, 1)). It is
    registered as the buffer ``emb``, so it is part of the module state but
    not a trainable parameter.
    """
    super().__init__()
    self.max_classes = max_classes
    self.embed_dim = embed_dim
    table = torch.rand(self.max_classes, self.embed_dim)
    self.register_buffer("emb", table)