Neural Network Modules
Critic Network
CriticNetwork
CriticNetwork(
encoder: Module,
value_head: Optional[Module] = None,
embed_dim: int = 128,
hidden_dim: int = 512,
customized: bool = False,
)
Bases: Module
Create a critic network from an encoder (e.g., an encoder like the one in the policy network) and a value head that transforms the embeddings into a scalar value.
Parameters:
- encoder (Module) – Encoder module to encode the input
- value_head (Optional[Module], default: None) – Value head to transform the embeddings to a scalar value
- embed_dim (int, default: 128) – Dimension of the embeddings of the value head
- hidden_dim (int, default: 512) – Dimension of the hidden layer of the value head
Methods:
- forward – Forward pass of the critic network: encode the input into embedding space and return the value
Source code in rl4co/models/rl/common/critic.py
forward
forward(x: Tensor | TensorDict, hidden=None) -> Tensor
Forward pass of the critic network: encode the input into embedding space and return the value.
Parameters:
- x (Tensor | TensorDict) – Input containing the environment state. Can be a Tensor or a TensorDict
Returns:
- Tensor – Value of the input state
Source code in rl4co/models/rl/common/critic.py
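A minimal sketch of the idea in plain PyTorch, assuming the encoder returns a tuple whose first element is a node-embedding tensor of shape [batch, num_nodes, embed_dim], and that the value head is a two-layer MLP (embed_dim to hidden_dim to 1) applied per node and then averaged over nodes. ToyEncoder and SketchCritic are hypothetical names, not RL4CO classes, and the head and pooling used in critic.py may differ.

```python
import torch
import torch.nn as nn


class ToyEncoder(nn.Module):
    """Hypothetical stand-in encoder: returns (node embeddings, initial embeddings)."""

    def __init__(self, in_dim: int, embed_dim: int):
        super().__init__()
        self.proj = nn.Linear(in_dim, embed_dim)

    def forward(self, x: torch.Tensor):
        h = self.proj(x)  # [batch, num_nodes, embed_dim]
        return h, h


class SketchCritic(nn.Module):
    """Hypothetical critic sketch: encode, apply an MLP value head per node, average over nodes."""

    def __init__(self, encoder: nn.Module, embed_dim: int = 128, hidden_dim: int = 512):
        super().__init__()
        self.encoder = encoder
        self.value_head = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h, _ = self.encoder(x)         # [batch, num_nodes, embed_dim]
        per_node = self.value_head(h)  # [batch, num_nodes, 1]
        return per_node.mean(dim=1)    # [batch, 1]: one scalar value per instance


critic = SketchCritic(ToyEncoder(in_dim=2, embed_dim=128))
locs = torch.rand(4, 20, 2)  # hypothetical input: 4 instances, 20 nodes with (x, y) coordinates
value = critic(locs)
print(value.shape)  # torch.Size([4, 1])
```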
Graph Neural Networks
MultiHeadAttentionLayer
MultiHeadAttentionLayer(
embed_dim: int,
num_heads: int = 8,
feedforward_hidden: int = 512,
normalization: Optional[str] = "batch",
bias: bool = True,
sdpa_fn: Optional[Callable] = None,
moe_kwargs: Optional[dict] = None,
)
Bases: Sequential
Multi-head attention layer with normalization and a feed-forward layer.
Parameters:
- embed_dim (int) – dimension of the embeddings
- num_heads (int, default: 8) – number of heads in the MHA
- feedforward_hidden (int, default: 512) – dimension of the hidden layer in the feed-forward layer
- normalization (Optional[str], default: 'batch') – type of normalization to use (batch, layer, none)
- sdpa_fn (Optional[Callable], default: None) – scaled dot product attention function (SDPA)
- moe_kwargs (Optional[dict], default: None) – Keyword arguments for MoE
Source code in rl4co/models/nn/graph/attnnet.py
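The layer structure described above can be sketched as follows. SketchMHALayer is hypothetical and built from torch.nn.MultiheadAttention, assuming post-normalization with skip connections around both sublayers (as in Kool et al., 2019) and normalization="batch" applied over the embedding dimension. sdpa_fn and moe_kwargs are not modeled, and the real class is a Sequential whose exact submodules may differ.

```python
import torch
import torch.nn as nn


class BatchNormEmbed(nn.Module):
    """BatchNorm1d over the embedding dimension of a [batch, nodes, embed_dim] tensor."""

    def __init__(self, embed_dim: int):
        super().__init__()
        self.bn = nn.BatchNorm1d(embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, n, e = x.shape
        return self.bn(x.reshape(b * n, e)).reshape(b, n, e)


class SketchMHALayer(nn.Module):
    """Hypothetical sketch: MHA and feed-forward sublayers, each with a skip connection and normalization."""

    def __init__(self, embed_dim: int, num_heads: int = 8, feedforward_hidden: int = 512):
        super().__init__()
        self.mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.norm1 = BatchNormEmbed(embed_dim)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, feedforward_hidden),
            nn.ReLU(),
            nn.Linear(feedforward_hidden, embed_dim),
        )
        self.norm2 = BatchNormEmbed(embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        attn_out, _ = self.mha(x, x, x, need_weights=False)  # self-attention over nodes
        h = self.norm1(x + attn_out)                          # skip connection + normalization
        return self.norm2(h + self.ff(h))                     # feed-forward + skip + normalization


layer = SketchMHALayer(embed_dim=128)
out = layer(torch.randn(4, 20, 128))
print(out.shape)  # torch.Size([4, 20, 128])
```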
GraphAttentionNetwork
GraphAttentionNetwork(
num_heads: int,
embed_dim: int,
num_layers: int,
normalization: str = "batch",
feedforward_hidden: int = 512,
sdpa_fn: Optional[Callable] = None,
moe_kwargs: Optional[dict] = None,
)
Bases: Module
Graph Attention Network that encodes embeddings with a series of MHA layers, each consisting of an MHA sublayer, normalization, a feed-forward sublayer, and normalization. This is similar to the Transformer encoder, as used in Kool et al. (2019).
Parameters:
- num_heads (int) – number of heads in the MHA
- embed_dim (int) – dimension of the embeddings
- num_layers (int) – number of MHA layers
- normalization (str, default: 'batch') – type of normalization to use (batch, layer, none)
- feedforward_hidden (int, default: 512) – dimension of the hidden layer in the feed-forward layer
- sdpa_fn (Optional[Callable], default: None) – scaled dot product attention function (SDPA)
- moe_kwargs (Optional[dict], default: None) – Keyword arguments for MoE
Methods:
- forward – Forward pass of the encoder
Source code in rl4co/models/nn/graph/attnnet.py
forward
Forward pass of the encoder
Parameters:
- x (Tensor) – [batch_size, graph_size, embed_dim] initial embeddings to process
- mask (Optional[Tensor], default: None) – [batch_size, graph_size, graph_size] mask for the input embeddings. Unused for now.
Source code in rl4co/models/nn/graph/attnnet.py
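Because the stack is described as similar to a Transformer encoder, PyTorch's built-in post-norm nn.TransformerEncoder gives a rough, runnable analogue. This is an illustration rather than RL4CO's implementation: PyTorch's layer uses LayerNorm (closer to normalization="layer" than to the default "batch"), and dropout is disabled here.

```python
import torch
import torch.nn as nn

embed_dim, num_heads, num_layers = 128, 8, 3

# Post-norm layer: MHA -> add & norm -> feed-forward -> add & norm.
# Note: this uses LayerNorm, roughly normalization="layer", not BatchNorm.
layer = nn.TransformerEncoderLayer(
    d_model=embed_dim,
    nhead=num_heads,
    dim_feedforward=512,
    dropout=0.0,
    activation="relu",
    batch_first=True,
    norm_first=False,
)
# nn.TransformerEncoder stacks num_layers independent copies of the layer
encoder = nn.TransformerEncoder(layer, num_layers=num_layers)

x = torch.randn(4, 20, embed_dim)  # [batch_size, graph_size, embed_dim]
h = encoder(x)
print(h.shape)  # torch.Size([4, 20, 128])
```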
GCNEncoder
GCNEncoder(
env_name: str,
embed_dim: int,
num_layers: int,
init_embedding: Module = None,
residual: bool = True,
edge_idx_fn: EdgeIndexFnSignature = None,
dropout: float = 0.5,
bias: bool = True,
)
Bases: Module
Graph Convolutional Network that encodes embeddings with a series of GCN layers from the PyTorch Geometric package.
Parameters:
- embed_dim (int) – dimension of the embeddings
- num_nodes – number of nodes in the graph
- num_layers (int) – number of GCN layers
- self_loop – whether to add self loops to the graph
- residual (bool, default: True) – whether to use a residual connection
Methods:
- forward – Forward pass of the encoder.
Source code in rl4co/models/nn/graph/gcn.py
forward
forward(
td: TensorDict, mask: Tensor | None = None
) -> Tuple[Tensor, Tensor]
Forward pass of the encoder. Transform the input TensorDict into a latent representation.
Parameters:
- td (TensorDict) – Input TensorDict containing the environment state
- mask (Tensor | None, default: None) – Mask to apply to the attention
Returns:
- h (Tensor) – Latent representation of the input
- init_h (Tensor) – Initial embedding of the input
Source code in rl4co/models/nn/graph/gcn.py
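A minimal dense sketch of the GCN update h' = ReLU(Â h W) with a residual connection, where Â = D^-1/2 (A + I) D^-1/2 is the symmetrically normalized adjacency with self loops. It is written in plain PyTorch instead of PyTorch Geometric, assumes a fully connected graph in place of whatever edge_idx_fn would produce, and uses hypothetical helper names (normalized_adjacency, SketchGCNLayer). Mirroring the forward method above, it keeps both the final and the initial embeddings.

```python
import torch
import torch.nn as nn


def normalized_adjacency(adj: torch.Tensor) -> torch.Tensor:
    """Symmetric GCN normalization D^-1/2 (A + I) D^-1/2 for a dense [n, n] adjacency."""
    a_hat = adj + torch.eye(adj.size(0), dtype=adj.dtype)
    deg_inv_sqrt = a_hat.sum(dim=1).pow(-0.5)
    return deg_inv_sqrt[:, None] * a_hat * deg_inv_sqrt[None, :]


class SketchGCNLayer(nn.Module):
    """Hypothetical dense GCN layer with an optional residual connection."""

    def __init__(self, embed_dim: int, residual: bool = True):
        super().__init__()
        self.lin = nn.Linear(embed_dim, embed_dim)
        self.residual = residual

    def forward(self, h: torch.Tensor, a_norm: torch.Tensor) -> torch.Tensor:
        out = torch.relu(a_norm @ self.lin(h))  # aggregate transformed neighbor features
        return h + out if self.residual else out


n, embed_dim = 10, 64
locs = torch.rand(n, 2)                  # hypothetical node coordinates
init_h = nn.Linear(2, embed_dim)(locs)   # initial embedding
adj = torch.ones(n, n) - torch.eye(n)    # fully connected graph (self loops added during normalization)
a_norm = normalized_adjacency(adj)

h = init_h
for layer in [SketchGCNLayer(embed_dim) for _ in range(3)]:
    h = layer(h, a_norm)
print(h.shape, init_h.shape)  # torch.Size([10, 64]) torch.Size([10, 64])
```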
MessagePassingEncoder
MessagePassingEncoder(
env_name: str,
embed_dim: int,
num_nodes: int,
num_layers: int,
init_embedding: Module = None,
aggregation: str = "add",
self_loop: bool = False,
residual: bool = True,
)
Bases: Module
Source code in rl4co/models/nn/graph/mpnn.py
Attention Mechanisms
Classes:
- MultiHeadAttention – PyTorch native implementation of Flash Multi-Head Attention with automatic mixed precision support.
- MultiHeadCrossAttention – PyTorch native implementation of Flash Multi-Head Cross Attention with automatic mixed precision support.
- PointerAttention – Calculate logits given a query, key, value, and logit key.
- PointerAttnMoE – Calculate logits given a query, key, value, and logit key.
- MultiHeadCompat
- PolyNetAttention – Calculate logits given a query, key, value, and logit key.
Functions:
- scaled_dot_product_attention_simple – Simple (exact) Scaled Dot-Product Attention in RL4CO without customized kernels (i.e., no Flash Attention).
MultiHeadAttention
MultiHeadAttention(
embed_dim: int,
num_heads: int,
bias: bool = True,
attention_dropout: float = 0.0,
causal: bool = False,
device: str = None,
dtype: dtype = None,
sdpa_fn: Optional[Callable] = None,
)
Bases: Module
PyTorch native implementation of Flash Multi-Head Attention with automatic mixed precision support.
Uses PyTorch's native scaled_dot_product_attention implementation, available from PyTorch 2.0.
Note: If scaled_dot_product_attention is not available, a custom implementation of scaled dot-product attention without Flash Attention is used.
Parameters:
- embed_dim (int) – total dimension of the model
- num_heads (int) – number of heads
- bias (bool, default: True) – whether to use bias
- attention_dropout (float, default: 0.0) – dropout rate for attention weights
- causal (bool, default: False) – whether to apply causal mask to attention scores
- device (str, default: None) – torch device
- dtype (dtype, default: None) – torch dtype
- sdpa_fn (Optional[Callable], default: None) – scaled dot product attention function (SDPA) implementation
Methods:
- forward – Forward pass; x has shape (batch, seqlen, hidden_dim), where hidden_dim = num_heads * head_dim
Source code in rl4co/models/nn/attention.py
forward
forward(x, attn_mask=None)
Parameters:
- x – (batch, seqlen, hidden_dim), where hidden_dim = num_heads * head_dim
- attn_mask – bool tensor of shape (batch, seqlen)
Source code in rl4co/models/nn/attention.py
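The shape convention above (hidden_dim = num_heads * head_dim) and a boolean (batch, seqlen) mask can be illustrated with torch.nn.functional.scaled_dot_product_attention (PyTorch 2.0+). This sketch assumes the mask marks which keys may be attended to (True = attend, PyTorch's convention) and broadcasts it over heads and queries; the fused QKV projection is an illustrative choice, not necessarily what the module does internally.

```python
import torch
import torch.nn.functional as F

batch, seqlen, num_heads, head_dim = 2, 5, 4, 8
hidden_dim = num_heads * head_dim  # hidden_dim = num_heads * head_dim

x = torch.randn(batch, seqlen, hidden_dim)
qkv = torch.nn.Linear(hidden_dim, 3 * hidden_dim)(x)  # illustrative fused projection

# Split into q, k, v and reshape each to [batch, num_heads, seqlen, head_dim]
q, k, v = (
    t.reshape(batch, seqlen, num_heads, head_dim).transpose(1, 2)
    for t in qkv.chunk(3, dim=-1)
)

# Boolean key mask of shape (batch, seqlen): True = this position may be attended to
attn_mask = torch.ones(batch, seqlen, dtype=torch.bool)
attn_mask[1, -2:] = False  # hypothetical padding in the second instance

# Broadcast to [batch, 1, 1, seqlen] so every head and every query share the same key mask
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask[:, None, None, :])

# Merge heads back to (batch, seqlen, hidden_dim)
out = out.transpose(1, 2).reshape(batch, seqlen, hidden_dim)
print(out.shape)  # torch.Size([2, 5, 32])
```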
MultiHeadCrossAttention
MultiHeadCrossAttention(
embed_dim: int,
num_heads: int,
bias: bool = False,
attention_dropout: float = 0.0,
device: str = None,
dtype: dtype = None,
sdpa_fn: Optional[Callable | Module] = None,
)
Bases: Module
PyTorch native implementation of Flash Multi-Head Cross Attention with automatic mixed precision support.
Uses PyTorch's native scaled_dot_product_attention implementation, available from PyTorch 2.0.
Note: If scaled_dot_product_attention is not available, a custom implementation of scaled dot-product attention without Flash Attention is used.
Parameters:
- embed_dim (int) – total dimension of the model
- num_heads (int) – number of heads
- bias (bool, default: False) – whether to use bias
- attention_dropout (float, default: 0.0) – dropout rate for attention weights
- device (str, default: None) – torch device
- dtype (dtype, default: None) – torch dtype
- sdpa_fn (Optional[Callable | Module], default: None) – scaled dot product attention function (SDPA)
Source code in rl4co/models/nn/attention.py
PointerAttention
PointerAttention(
embed_dim: int,
num_heads: int,
mask_inner: bool = True,
out_bias: bool = False,
check_nan: bool = True,
sdpa_fn: Callable | str = "default",
**kwargs
)
Bases: Module
Calculate logits given a query, key, value, and logit key. This follows the pointer mechanism of Vinyals et al. (2015) (https://arxiv.org/abs/1506.03134).
Note: With Flash Attention, masking is not supported.
Performs the following steps:
- Apply cross attention to get the heads
- Project the heads to get the glimpse
- Compute the attention score between the glimpse and the logit key
Parameters:
- embed_dim (int) – total dimension of the model
- num_heads (int) – number of heads
- mask_inner (bool, default: True) – whether to mask inner attention
- out_bias (bool, default: False) – whether to use bias in the output linear projection
- check_nan (bool, default: True) – whether to check for NaNs in logits
- sdpa_fn (Callable | str, default: 'default') – scaled dot product attention function (SDPA) implementation
Methods:
- forward – Compute attention logits given query, key, value, logit key and attention mask.
Source code in rl4co/models/nn/attention.py
forward
forward(query, key, value, logit_key, attn_mask=None)
Compute attention logits given query, key, value, logit key and attention mask.
Parameters:
- query – query tensor of shape [B, ..., L, E]
- key – key tensor of shape [B, ..., S, E]
- value – value tensor of shape [B, ..., S, E]
- logit_key – logit key tensor of shape [B, ..., S, E]
- attn_mask – attention mask tensor of shape [B, ..., S]. Note that True means that the value should take part in attention, as described in the PyTorch documentation.
Source code in rl4co/models/nn/attention.py
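The three steps listed for PointerAttention can be sketched for a single decoding step (L = 1) as below. This is an illustration under stated assumptions, not the module's code: it uses scaled_dot_product_attention for the inner attention with the mask applied (as with mask_inner=True), a bias-free output projection (as with out_bias=False), scales the final logits by sqrt(E), and omits details such as logit clipping and NaN checks. Shapes follow the forward docstring.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

B, S, E, H = 2, 6, 16, 4  # batch, number of nodes, embed_dim, num_heads
query = torch.randn(B, 1, E)  # one decoding step: L = 1
key = torch.randn(B, S, E)
value = torch.randn(B, S, E)
logit_key = torch.randn(B, S, E)
attn_mask = torch.ones(B, S, dtype=torch.bool)
attn_mask[:, 0] = False  # hypothetical: node 0 is infeasible (True = may attend)


def split_heads(t: torch.Tensor) -> torch.Tensor:
    # [B, N, E] -> [B, H, N, E // H]
    return t.reshape(t.size(0), t.size(1), H, E // H).transpose(1, 2)


# 1) Cross attention from the query to the nodes, per head, with the mask applied
heads = F.scaled_dot_product_attention(
    split_heads(query), split_heads(key), split_heads(value),
    attn_mask=attn_mask[:, None, None, :],
)
# 2) Merge heads and project them to obtain the glimpse
project_out = nn.Linear(E, E, bias=False)
glimpse = project_out(heads.transpose(1, 2).reshape(B, 1, E))
# 3) Score the glimpse against the logit key, scaled by sqrt(E); mask infeasible nodes
logits = glimpse @ logit_key.transpose(-2, -1) / math.sqrt(E)  # [B, 1, S]
logits = logits.squeeze(1).masked_fill(~attn_mask, float("-inf"))
probs = logits.softmax(dim=-1)  # masked node 0 receives probability 0
print(probs.shape)  # torch.Size([2, 6])
```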
PointerAttnMoE
PointerAttnMoE(
embed_dim: int,
num_heads: int,
mask_inner: bool = True,
out_bias: bool = False,
check_nan: bool = True,
sdpa_fn: Optional[Callable] = None,
moe_kwargs: Optional[dict] = None,
)
Bases: PointerAttention
Calculate logits given a query, key, value, and logit key. This follows the pointer mechanism of Vinyals et al. (2015) (https://arxiv.org/abs/1506.03134) and the MoE gating mechanism of Zhou et al. (2024) (https://arxiv.org/abs/2405.01029).
Note: With Flash Attention, masking is not supported.
Performs the following steps:
- Apply cross attention to get the heads
- Project the heads to get the glimpse
- Compute the attention score between the glimpse and the logit key
Parameters:
- embed_dim (int) – total dimension of the model
- num_heads (int) – number of heads
- mask_inner (bool, default: True) – whether to mask inner attention
- out_bias (bool, default: False) – whether to use bias in the output linear projection
- check_nan (bool, default: True) – whether to check for NaNs in logits
- sdpa_fn (Optional[Callable], default: None) – scaled dot product attention function (SDPA) implementation
- moe_kwargs (Optional[dict], default: None) – Keyword arguments for MoE
Source code in rl4co/models/nn/attention.py
MultiHeadCompat
MultiHeadCompat(
n_heads,
input_dim,
embed_dim=None,
val_dim=None,
key_dim=None,
)
Bases: Module
Methods:
- forward – Forward pass taking queries q, optional data h, and an optional mask (see below).
Source code in rl4co/models/nn/attention.py
forward
forward(q, h=None, mask=None)
Parameters:
- q – queries (batch_size, n_query, input_dim)
- h – data (batch_size, graph_size, input_dim)
- mask – mask (batch_size, n_query, graph_size), or viewable as that (i.e., it can be 2-dimensional if n_query == 1). The mask should contain 1 where attention is not possible (i.e., the mask is a negative adjacency).
Source code in rl4co/models/nn/attention.py
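MultiHeadCompat uses the opposite mask convention from PointerAttention: here 1 (True) marks pairs where attention is not possible, while PyTorch's SDPA treats True as "may attend". A minimal sketch of computing masked, scaled compatibility scores under this convention follows; the [n_heads, batch, n_query, graph_size] layout and the per-head Q and K tensors are illustrative assumptions, not the module's internals.

```python
import math

import torch

batch, n_query, graph_size, n_heads, key_dim = 2, 3, 5, 4, 8
Q = torch.randn(n_heads, batch, n_query, key_dim)     # per-head queries (illustrative layout)
K = torch.randn(n_heads, batch, graph_size, key_dim)  # per-head keys

# MultiHeadCompat convention: mask == 1 (True) where attention is NOT possible
mask = torch.zeros(batch, n_query, graph_size, dtype=torch.bool)
mask[:, :, -1] = True  # hypothetical: last node unreachable from every query

compat = (Q @ K.transpose(-2, -1)) / math.sqrt(key_dim)  # [n_heads, batch, n_query, graph_size]
compat = compat.masked_fill(mask[None], float("-inf"))   # broadcast the mask over heads

# Converting to the PyTorch SDPA convention (True = may attend) is a simple negation
sdpa_style_mask = ~mask
print(compat.shape)  # torch.Size([4, 2, 3, 5])
```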
PolyNetAttention
Bases: PointerAttention
Calculate logits given a query, key, value, and logit key. This implements a modified version of the pointer mechanism of Vinyals et al. (2015) (https://arxiv.org/abs/1506.03134), as described in Hottung et al. (2024) (https://arxiv.org/abs/2402.14048). PolyNetAttention conditions the attention logits on a set of k different binary vectors, allowing the model to learn k different solution strategies.
Note: With Flash Attention, masking is not supported.
Performs the following steps:
- Apply cross attention to get the heads
- Project the heads to get the glimpse
- Apply the PolyNet layers
- Compute the attention score between the glimpse and the logit key
Parameters:
- k (int) – Number of unique bit vectors used to compute the attention score
- embed_dim (int) – total dimension of the model
- poly_layer_dim (int) – Dimension of the PolyNet layers
- num_heads (int) – number of heads
- mask_inner – whether to mask inner attention
- linear_bias – whether to use bias in linear projection
- check_nan – whether to check for NaNs in logits
- sdpa_fn – scaled dot product attention function (SDPA) implementation
Methods:
- forward – Compute attention logits given query, key, value, logit key and attention mask.
Source code in rl4co/models/nn/attention.py
forward
forward(query, key, value, logit_key, attn_mask=None)
Compute attention logits given query, key, value, logit key and attention mask.
Parameters:
- query – query tensor of shape [B, ..., L, E]
- key – key tensor of shape [B, ..., S, E]
- value – value tensor of shape [B, ..., S, E]
- logit_key – logit key tensor of shape [B, ..., S, E]
- attn_mask – attention mask tensor of shape [B, ..., S]. Note that True means that the value should take part in attention, as described in the PyTorch documentation.
Source code in rl4co/models/nn/attention.py
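A sketch of how k unique bit vectors could condition the glimpse. It assumes (this page does not confirm it) that the vectors are the binary expansions of 0 to k-1 using ceil(log2 k) bits, and that the PolyNet layers are a two-layer MLP of width poly_layer_dim whose output is added to the glimpse before the logits are computed. All names are hypothetical.

```python
import math

import torch
import torch.nn as nn

k, embed_dim, poly_layer_dim = 4, 16, 32
bit_dim = max(1, math.ceil(math.log2(k)))  # enough bits to give k vectors distinct codes

# k unique binary vectors: the (little-endian) binary expansions of 0 .. k-1
ids = torch.arange(k)
bit_vectors = ((ids[:, None] >> torch.arange(bit_dim)) & 1).float()  # [k, bit_dim]

# Hypothetical PolyNet layers: MLP on [glimpse, bit vector] producing a residual update
poly_layers = nn.Sequential(
    nn.Linear(embed_dim + bit_dim, poly_layer_dim),
    nn.ReLU(),
    nn.Linear(poly_layer_dim, embed_dim),
)

batch = 3
glimpse = torch.randn(batch * k, 1, embed_dim)  # one glimpse per (instance, strategy) pair
z = bit_vectors.repeat(batch, 1)[:, None, :]    # row b * k + i gets bit vector i
glimpse = glimpse + poly_layers(torch.cat([glimpse, z], dim=-1))  # residual conditioning
print(bit_vectors.tolist())  # [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
```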
scaled_dot_product_attention_simple
scaled_dot_product_attention_simple(
q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False
)
Simple (exact) Scaled Dot-Product Attention in RL4CO without customized kernels (i.e. no Flash Attention).
Source code in rl4co/models/nn/attention.py
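What an exact, kernel-free scaled dot-product attention computes can be sketched as softmax(q k^T / sqrt(d)) v, with PyTorch-style masks (a boolean True means "may attend"; a float mask is added to the scores). sdpa_simple is a hypothetical name; RL4CO's function body may differ in details, and the comparison against torch.nn.functional.scaled_dot_product_attention (PyTorch 2.0+) is only a sanity check.

```python
import math

import torch
import torch.nn.functional as F


def sdpa_simple(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False):
    """Exact attention softmax(q k^T / sqrt(d)) v with PyTorch-style masks (True = attend)."""
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    if is_causal:
        L, S = q.size(-2), k.size(-2)
        causal = torch.ones(L, S, dtype=torch.bool, device=q.device).tril()
        scores = scores.masked_fill(~causal, float("-inf"))
    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            scores = scores.masked_fill(~attn_mask, float("-inf"))
        else:
            scores = scores + attn_mask  # float masks are added to the scores
    weights = scores.softmax(dim=-1)
    if dropout_p > 0.0:
        weights = F.dropout(weights, p=dropout_p)
    return weights @ v


q, k, v = torch.randn(2, 4, 5, 8), torch.randn(2, 4, 7, 8), torch.randn(2, 4, 7, 8)
mask = torch.rand(2, 1, 5, 7) > 0.3
mask[..., 0] = True  # keep at least one key per query so no row is fully masked
out = sdpa_simple(q, k, v, attn_mask=mask)
ref = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(torch.allclose(out, ref, atol=1e-5))
```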
Multi-Layer Perceptron
MLP
MLP(
input_dim: int,
output_dim: int,
num_neurons: list[int] = [64, 32],
dropout_probs: None | list[float] = None,
hidden_act: str = "ReLU",
out_act: str = "Identity",
input_norm: str = "None",
output_norm: str = "None",
)
Bases: Module
Source code in rl4co/models/nn/mlp.py
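The MLP signature suggests hidden widths given as a list (num_neurons) and activations named by strings such as "ReLU" and "Identity". A minimal sketch, assuming those strings are torch.nn class names resolved with getattr and leaving out the input_norm/output_norm options; build_mlp is a hypothetical helper, not the RL4CO constructor.

```python
import torch
import torch.nn as nn


def build_mlp(input_dim, output_dim, num_neurons=(64, 32), dropout_probs=None,
              hidden_act="ReLU", out_act="Identity"):
    """Hypothetical builder: Linear -> activation (-> Dropout) per hidden layer, then an output Linear."""
    dims = [input_dim, *num_neurons]
    dropout_probs = dropout_probs or [0.0] * len(num_neurons)
    layers = []
    for in_d, out_d, p in zip(dims[:-1], dims[1:], dropout_probs):
        layers += [nn.Linear(in_d, out_d), getattr(nn, hidden_act)()]  # e.g. nn.ReLU()
        if p > 0:
            layers.append(nn.Dropout(p))
    layers += [nn.Linear(dims[-1], output_dim), getattr(nn, out_act)()]  # e.g. nn.Identity()
    return nn.Sequential(*layers)


mlp = build_mlp(input_dim=10, output_dim=1, num_neurons=[64, 32], hidden_act="ReLU")
y = mlp(torch.randn(8, 10))
print(y.shape)  # torch.Size([8, 1])
```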
Operations
PositionalEncoding
Bases: Module
Methods:
- forward – Forward pass taking the input x and the positions seq_pos (see below).
Source code in rl4co/models/nn/ops.py
forward
Parameters:
- x – Tensor of shape [batch_size, seq_len, embedding_dim]
- seq_pos – Tensor of shape [batch_size, seq_len]
Source code in rl4co/models/nn/ops.py
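The forward parameters suggest that each element carries its own position (seq_pos) rather than positions being assumed to run 0 to seq_len-1. A minimal sketch, assuming the standard sinusoidal Transformer encoding is looked up by seq_pos and added to x; dropout and any scaling used by the RL4CO module are omitted, and sinusoidal_table is a hypothetical helper.

```python
import math

import torch


def sinusoidal_table(max_len: int, d_model: int) -> torch.Tensor:
    """PE[pos, 2i] = sin(pos / 10000^(2i/d)), PE[pos, 2i+1] = cos(pos / 10000^(2i/d))."""
    position = torch.arange(max_len).unsqueeze(1)  # [max_len, 1]
    div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe


batch_size, seq_len, embedding_dim = 2, 4, 8
pe = sinusoidal_table(max_len=100, d_model=embedding_dim)
x = torch.zeros(batch_size, seq_len, embedding_dim)
seq_pos = torch.tensor([[0, 1, 2, 3], [3, 2, 1, 0]])  # hypothetical per-element positions
out = x + pe[seq_pos]  # table lookup -> [batch_size, seq_len, embedding_dim]
print(out[0, 0])  # position 0: sin(0) = 0 in even slots, cos(0) = 1 in odd slots
```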
RandomEncoding
Bases: Module
This is like torch.nn.Embedding, but the rows of the embedding table are randomly permuted in each forward pass before the lookup operation. This can be useful when classes have no fixed meaning but rather indicate a connection between different elements in a sequence. This follows the MatNet model.
Source code in rl4co/models/nn/ops.py
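The description above can be sketched as an nn.Embedding whose rows are shuffled with torch.randperm before each lookup. SketchRandomEncoding is hypothetical, and drawing a single permutation per call (rather than, say, one per batch element) is an assumption.

```python
import torch
import torch.nn as nn


class SketchRandomEncoding(nn.Module):
    """Hypothetical sketch: an embedding table whose rows are shuffled on every forward pass."""

    def __init__(self, num_embeddings: int, embedding_dim: int):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        perm = torch.randperm(self.embedding.num_embeddings, device=idx.device)
        table = self.embedding.weight[perm]  # randomly permuted rows
        return table[idx]                    # same lookup semantics as nn.Embedding


enc = SketchRandomEncoding(num_embeddings=10, embedding_dim=4)
idx = torch.tensor([[0, 0, 3]])
e1 = enc(idx)
print(torch.equal(e1[0, 0], e1[0, 1]))  # True: equal ids share a row within one call
```

Across calls the id-to-row assignment changes, so an id only signals "same as other elements with this id" within one forward pass.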