Neural Network Modules¶
Critic Network¶
CriticNetwork¶
CriticNetwork(
encoder: Module,
value_head: Optional[Module] = None,
embed_dim: int = 128,
hidden_dim: int = 512,
customized: bool = False,
)
Bases: Module
Create a critic network from a given encoder (e.g., the one used in the policy network) and a value head that transforms the embeddings into a scalar value.
Parameters:
- encoder (Module) – Encoder module to encode the input
- value_head (Optional[Module], default: None) – Value head to transform the embeddings to a scalar value
- embed_dim (int, default: 128) – Dimension of the embeddings of the value head
- hidden_dim (int, default: 512) – Dimension of the hidden layer of the value head
Source code in rl4co/models/rl/common/critic.py
forward¶
forward(
x: Union[Tensor, TensorDict], hidden=None
) -> Tensor
Forward pass of the critic network: encode the input in embedding space and return the value.
Parameters:
- x (Union[Tensor, TensorDict]) – Input containing the environment state. Can be a Tensor or a TensorDict
Returns:
- Tensor – Value of the input state
Source code in rl4co/models/rl/common/critic.py
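For orientation, here is a minimal usage sketch. The import path is assumed from the source location above, and the toy encoder follows the (embeddings, initial embeddings) return convention of the encoders documented further down; treat it as an illustration rather than the canonical setup.

```python
import torch
import torch.nn as nn

from rl4co.models.rl.common.critic import CriticNetwork  # import path assumed from the source location above


class ToyEncoder(nn.Module):
    """Stand-in encoder: maps [batch, n, 2] coordinates to [batch, n, embed_dim] embeddings.
    Returns (embeddings, init_embeddings), the convention assumed for rl4co encoders."""

    def __init__(self, embed_dim=128):
        super().__init__()
        self.proj = nn.Linear(2, embed_dim)

    def forward(self, x):
        h = self.proj(x)
        return h, h


critic = CriticNetwork(encoder=ToyEncoder(embed_dim=128), embed_dim=128, hidden_dim=512)

x = torch.rand(32, 20, 2)  # e.g. 20 node coordinates per instance
value = critic(x)          # scalar value estimate per instance
print(value.shape)
```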
Graph Neural Networks¶
MultiHeadAttentionLayer¶
MultiHeadAttentionLayer(
embed_dim: int,
num_heads: int = 8,
feedforward_hidden: int = 512,
normalization: Optional[str] = "batch",
bias: bool = True,
sdpa_fn: Optional[Callable] = None,
moe_kwargs: Optional[dict] = None,
)
Bases: Sequential
Multi-Head Attention Layer with normalization and feed-forward layer
Parameters:
- embed_dim (int) – dimension of the embeddings
- num_heads (int, default: 8) – number of heads in the MHA
- feedforward_hidden (int, default: 512) – dimension of the hidden layer in the feed-forward layer
- normalization (Optional[str], default: 'batch') – type of normalization to use (batch, layer, none)
- sdpa_fn (Optional[Callable], default: None) – scaled dot product attention function (SDPA)
- moe_kwargs (Optional[dict], default: None) – Keyword arguments for MoE
Source code in rl4co/models/nn/graph/attnnet.py
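A minimal sketch of applying a single layer to a batch of node embeddings (import path assumed from the source location above); as a residual attention block, it preserves the input shape:

```python
import torch
from rl4co.models.nn.graph.attnnet import MultiHeadAttentionLayer  # import path assumed from the source location above

layer = MultiHeadAttentionLayer(embed_dim=128, num_heads=8, feedforward_hidden=512, normalization="batch")

x = torch.randn(32, 50, 128)  # [batch_size, graph_size, embed_dim]
h = layer(x)                  # MHA + normalization + feed-forward + normalization
print(h.shape)                # expected: torch.Size([32, 50, 128])
```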
GraphAttentionNetwork¶
GraphAttentionNetwork(
num_heads: int,
embed_dim: int,
num_layers: int,
normalization: str = "batch",
feedforward_hidden: int = 512,
sdpa_fn: Optional[Callable] = None,
moe_kwargs: Optional[dict] = None,
)
Bases: Module
Graph Attention Network to encode embeddings with a series of MHA layers, each consisting of a MHA sublayer, normalization, a feed-forward layer, and normalization. Similar to the Transformer encoder, as used in Kool et al. (2019).
Parameters:
- num_heads (int) – number of heads in the MHA
- embed_dim (int) – dimension of the embeddings
- num_layers (int) – number of MHA layers
- normalization (str, default: 'batch') – type of normalization to use (batch, layer, none)
- feedforward_hidden (int, default: 512) – dimension of the hidden layer in the feed-forward layer
- sdpa_fn (Optional[Callable], default: None) – scaled dot product attention function (SDPA)
- moe_kwargs (Optional[dict], default: None) – Keyword arguments for MoE
Source code in rl4co/models/nn/graph/attnnet.py
forward¶
Forward pass of the encoder
Parameters:
- x (Tensor) – [batch_size, graph_size, embed_dim] initial embeddings to process
- mask (Optional[Tensor], default: None) – [batch_size, graph_size, graph_size] mask for the input embeddings. Unused for now.
Source code in rl4co/models/nn/graph/attnnet.py
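A sketch of encoding a batch of initial embeddings with the full stack (import path assumed from the source location above; since the mask is currently unused, it is omitted):

```python
import torch
from rl4co.models.nn.graph.attnnet import GraphAttentionNetwork  # import path assumed from the source location above

net = GraphAttentionNetwork(num_heads=8, embed_dim=128, num_layers=3, normalization="batch")

x = torch.randn(32, 50, 128)  # [batch_size, graph_size, embed_dim] initial embeddings
h = net(x)                    # processed embeddings, same shape as the input
print(h.shape)                # expected: torch.Size([32, 50, 128])
```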
GCNEncoder¶
GCNEncoder(
env_name: str,
embed_dim: int,
num_layers: int,
init_embedding: Module = None,
residual: bool = True,
edge_idx_fn: EdgeIndexFnSignature = None,
dropout: float = 0.5,
bias: bool = True,
)
Bases: Module
Graph Convolutional Network to encode embeddings with a series of GCN layers from the PyTorch Geometric package.
Parameters:
- embed_dim (int) – dimension of the embeddings
- num_nodes – number of nodes in the graph
- num_gcn_layer – number of GCN layers
- self_loop – whether to add self loop in the graph
- residual (bool, default: True) – whether to use residual connection
Source code in rl4co/models/nn/graph/gcn.py
forward¶
Forward pass of the encoder. Transform the input TensorDict into a latent representation.
Parameters:
- td (TensorDict) – Input TensorDict containing the environment state
- mask (Union[Tensor, None], default: None) – Mask to apply to the attention
Returns:
- h (Tensor) – Latent representation of the input
- init_h (Tensor) – Initial embedding of the input
Source code in rl4co/models/nn/graph/gcn.py
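A hedged usage sketch, assuming PyTorch Geometric is installed, that the import path matches the source location above, and that the default init_embedding for env_name="tsp" reads node coordinates from td["locs"]:

```python
import torch
from tensordict import TensorDict

from rl4co.models.nn.graph.gcn import GCNEncoder  # import path assumed from the source location above

encoder = GCNEncoder(env_name="tsp", embed_dim=128, num_layers=3)

# Assumption: the default TSP init embedding reads coordinates from td["locs"]
td = TensorDict({"locs": torch.rand(32, 20, 2)}, batch_size=[32])
h, init_h = encoder(td)        # latent representation and initial embeddings (see Returns above)
print(h.shape, init_h.shape)   # expected: [32, 20, 128] each
```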
MessagePassingEncoder¶
MessagePassingEncoder(
env_name: str,
embed_dim: int,
num_nodes: int,
num_layers: int,
init_embedding: Module = None,
aggregation: str = "add",
self_loop: bool = False,
residual: bool = True,
)
Bases: Module
Source code in rl4co/models/nn/graph/mpnn.py
Attention Mechanisms¶
MultiHeadAttention¶
MultiHeadAttention(
embed_dim: int,
num_heads: int,
bias: bool = True,
attention_dropout: float = 0.0,
causal: bool = False,
device: str = None,
dtype: dtype = None,
sdpa_fn: Optional[Callable] = None,
)
Bases: Module
PyTorch native implementation of Flash Multi-Head Attention with automatic mixed precision support.
Uses PyTorch's native scaled_dot_product_attention implementation, available from PyTorch 2.0.
Note
If scaled_dot_product_attention is not available, a custom implementation of scaled dot-product attention without Flash Attention is used.
Parameters:
- embed_dim (int) – total dimension of the model
- num_heads (int) – number of heads
- bias (bool, default: True) – whether to use bias
- attention_dropout (float, default: 0.0) – dropout rate for attention weights
- causal (bool, default: False) – whether to apply causal mask to attention scores
- device (str, default: None) – torch device
- dtype (dtype, default: None) – torch dtype
- sdpa_fn (Optional[Callable], default: None) – scaled dot product attention function (SDPA) implementation
Source code in rl4co/models/nn/attention.py
forward¶
forward(x, attn_mask=None)
Parameters:
- x – tensor of shape (batch, seqlen, hidden_dim), where hidden_dim = num_heads * head_dim
- attn_mask – bool tensor of shape (batch, seqlen)
Source code in rl4co/models/nn/attention.py
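A minimal self-attention sketch (import path assumed from the source location above); the mask is optional and omitted here:

```python
import torch
from rl4co.models.nn.attention import MultiHeadAttention  # import path assumed from the source location above

mha = MultiHeadAttention(embed_dim=128, num_heads=8)

x = torch.randn(32, 50, 128)  # (batch, seqlen, hidden_dim) with hidden_dim = num_heads * head_dim
out = mha(x)                  # attn_mask defaults to None
print(out.shape)              # expected: torch.Size([32, 50, 128])
```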
MultiHeadCrossAttention¶
MultiHeadCrossAttention(
embed_dim: int,
num_heads: int,
bias: bool = False,
attention_dropout: float = 0.0,
device: str = None,
dtype: dtype = None,
sdpa_fn: Optional[Union[Callable, Module]] = None,
)
Bases: Module
PyTorch native implementation of Flash Multi-Head Cross Attention with automatic mixed precision support.
Uses PyTorch's native scaled_dot_product_attention implementation, available from PyTorch 2.0.
Note
If scaled_dot_product_attention is not available, a custom implementation of scaled dot-product attention without Flash Attention is used.
Parameters:
- embed_dim (int) – total dimension of the model
- num_heads (int) – number of heads
- bias (bool, default: False) – whether to use bias
- attention_dropout (float, default: 0.0) – dropout rate for attention weights
- device (str, default: None) – torch device
- dtype (dtype, default: None) – torch dtype
- sdpa_fn (Optional[Union[Callable, Module]], default: None) – scaled dot product attention function (SDPA)
Source code in rl4co/models/nn/attention.py
PointerAttention¶
PointerAttention(
embed_dim: int,
num_heads: int,
mask_inner: bool = True,
out_bias: bool = False,
check_nan: bool = True,
sdpa_fn: Optional[Callable] = None,
**kwargs
)
Bases: Module
Calculate logits given query, key, value, and logit key. This follows the pointer mechanism of Vinyals et al. (2015) (https://arxiv.org/abs/1506.03134).
Note
With Flash Attention, masking is not supported.
Performs the following steps:
- Apply cross attention to get the heads
- Project the heads to get the glimpse
- Compute the attention score between the glimpse and the logit key
Parameters:
- embed_dim (int) – total dimension of the model
- num_heads (int) – number of heads
- mask_inner (bool, default: True) – whether to mask inner attention
- linear_bias – whether to use bias in linear projection
- check_nan (bool, default: True) – whether to check for NaNs in logits
- sdpa_fn (Optional[Callable], default: None) – scaled dot product attention function (SDPA) implementation
Source code in rl4co/models/nn/attention.py
forward¶
forward(query, key, value, logit_key, attn_mask=None)
Compute attention logits given query, key, value, logit key and attention mask.
Parameters:
- query – query tensor of shape [B, ..., L, E]
- key – key tensor of shape [B, ..., S, E]
- value – value tensor of shape [B, ..., S, E]
- logit_key – logit key tensor of shape [B, ..., S, E]
- attn_mask – attention mask tensor of shape [B, ..., S]. Note that True means the value should take part in attention, as described in the PyTorch documentation
Source code in rl4co/models/nn/attention.py
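A sketch of a single pointer-style decoding step (import path assumed from the source location above; the printed logits shape is an assumption based on the tensor shapes documented above):

```python
import torch
from rl4co.models.nn.attention import PointerAttention  # import path assumed from the source location above

pointer = PointerAttention(embed_dim=128, num_heads=8)

B, L, S, E = 32, 1, 50, 128          # batch, query length, number of nodes, embedding dim
query = torch.randn(B, L, E)         # current decoder context
key = torch.randn(B, S, E)
value = torch.randn(B, S, E)
logit_key = torch.randn(B, S, E)
attn_mask = torch.ones(B, S, dtype=torch.bool)  # True = node may take part in attention

logits = pointer(query, key, value, logit_key, attn_mask=attn_mask)
print(logits.shape)                  # assumption: [B, L, S], one logit per node
```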
PointerAttnMoE¶
PointerAttnMoE(
embed_dim: int,
num_heads: int,
mask_inner: bool = True,
out_bias: bool = False,
check_nan: bool = True,
sdpa_fn: Optional[Callable] = None,
moe_kwargs: Optional[dict] = None,
)
Bases: PointerAttention
Calculate logits given query, key, value, and logit key. This follows the pointer mechanism of Vinyals et al. (2015) (https://arxiv.org/abs/1506.03134) and the MoE gating mechanism of Zhou et al. (2024) (https://arxiv.org/abs/2405.01029).
Note
With Flash Attention, masking is not supported.
Performs the following steps:
- Apply cross attention to get the heads
- Project the heads to get the glimpse
- Compute the attention score between the glimpse and the logit key
Parameters:
- embed_dim (int) – total dimension of the model
- num_heads (int) – number of heads
- mask_inner (bool, default: True) – whether to mask inner attention
- linear_bias – whether to use bias in linear projection
- check_nan (bool, default: True) – whether to check for NaNs in logits
- sdpa_fn (Optional[Callable], default: None) – scaled dot product attention function (SDPA) implementation
- moe_kwargs (Optional[dict], default: None) – Keyword arguments for MoE
Source code in rl4co/models/nn/attention.py
MultiHeadCompat¶
MultiHeadCompat(
n_heads,
input_dim,
embed_dim=None,
val_dim=None,
key_dim=None,
)
Bases: Module
Source code in rl4co/models/nn/attention.py
forward¶
forward(q, h=None, mask=None)
Parameters:
- q – queries, shape (batch_size, n_query, input_dim)
- h – data, shape (batch_size, graph_size, input_dim)
- mask – mask of shape (batch_size, n_query, graph_size), or viewable as such (i.e., can be 2-dimensional if n_query == 1). The mask should contain 1 if attention is not possible (i.e., the mask is a negative adjacency)
Source code in rl4co/models/nn/attention.py
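A small sketch computing per-head compatibility scores between a query and the graph nodes (import path assumed from the source location above; the exact output layout is not specified here, so only the shape is printed):

```python
import torch
from rl4co.models.nn.attention import MultiHeadCompat  # import path assumed from the source location above

compat = MultiHeadCompat(n_heads=8, input_dim=128, embed_dim=128)

q = torch.randn(32, 1, 128)   # (batch_size, n_query, input_dim)
h = torch.randn(32, 50, 128)  # (batch_size, graph_size, input_dim)

scores = compat(q, h)         # per-head compatibility between queries and nodes
print(scores.shape)
```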
PolyNetAttention¶
Bases: PointerAttention
Calculate logits given query, key, value, and logit key. This implements a modified version of the pointer mechanism of Vinyals et al. (2015) (https://arxiv.org/abs/1506.03134), as described in Hottung et al. (2024) (https://arxiv.org/abs/2402.14048). PolyNetAttention conditions the attention logits on a set of k different binary vectors, allowing it to learn k different solution strategies.
Note
With Flash Attention, masking is not supported.
Performs the following steps:
- Apply cross attention to get the heads
- Project the heads to get the glimpse
- Apply the PolyNet layers
- Compute the attention score between the glimpse and the logit key
Parameters:
- k (int) – number of unique bit vectors used to compute the attention score
- embed_dim (int) – total dimension of the model
- poly_layer_dim (int) – dimension of the PolyNet layers
- num_heads (int) – number of heads
- mask_inner – whether to mask inner attention
- linear_bias – whether to use bias in linear projection
- check_nan – whether to check for NaNs in logits
- sdpa_fn – scaled dot product attention function (SDPA) implementation
Source code in rl4co/models/nn/attention.py
forward¶
forward(query, key, value, logit_key, attn_mask=None)
Compute attention logits given query, key, value, logit key and attention mask.
Parameters:
- query – query tensor of shape [B, ..., L, E]
- key – key tensor of shape [B, ..., S, E]
- value – value tensor of shape [B, ..., S, E]
- logit_key – logit key tensor of shape [B, ..., S, E]
- attn_mask – attention mask tensor of shape [B, ..., S]. Note that True means the value should take part in attention, as described in the PyTorch documentation
Source code in rl4co/models/nn/attention.py
scaled_dot_product_attention_simple¶
scaled_dot_product_attention_simple(
q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False
)
Simple Scaled Dot-Product Attention in PyTorch without Flash Attention
Source code in rl4co/models/nn/attention.py
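For reference, a fallback of this kind typically computes softmax(QKᵀ/√d)·V directly. The sketch below is a generic reference implementation under that assumption, not necessarily the exact rl4co code:

```python
import math
import torch
import torch.nn.functional as F


def sdpa_reference(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False):
    # q: [..., L, E], k and v: [..., S, E]
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))  # [..., L, S]
    if is_causal:
        L, S = scores.shape[-2:]
        causal = torch.ones(L, S, dtype=torch.bool, device=scores.device).tril()
        scores = scores.masked_fill(~causal, float("-inf"))
    if attn_mask is not None:
        # assumed boolean convention: True = attend, False = ignore
        scores = scores.masked_fill(~attn_mask, float("-inf"))
    attn = F.softmax(scores, dim=-1)
    attn = F.dropout(attn, p=dropout_p)
    return attn @ v  # [..., L, E]
```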
Multi-Layer Perceptron¶
MLP¶
MLP(
input_dim: int,
output_dim: int,
num_neurons: List[int] = [64, 32],
dropout_probs: Union[None, List[float]] = None,
hidden_act: str = "ReLU",
out_act: str = "Identity",
input_norm: str = "None",
output_norm: str = "None",
)
Bases: Module
Source code in rl4co/models/nn/mlp.py
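A minimal usage sketch based on the constructor signature above (import path assumed from the source location above):

```python
import torch
from rl4co.models.nn.mlp import MLP  # import path assumed from the source location above

mlp = MLP(input_dim=128, output_dim=1, num_neurons=[64, 32], hidden_act="ReLU", out_act="Identity")

x = torch.randn(32, 128)
y = mlp(x)
print(y.shape)  # expected: torch.Size([32, 1])
```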
Operations¶
PositionalEncoding¶
Bases: Module
Source code in rl4co/models/nn/ops.py
forward¶
Parameters:
- x – Tensor, shape [batch_size, seq_len, embedding_dim]
- seq_pos – Tensor, shape [batch_size, seq_len]
Source code in rl4co/models/nn/ops.py
RandomEncoding¶
Bases: Module
This is like torch.nn.Embedding, but the rows of the embedding table are randomly permuted in each forward pass before the lookup operation. This might be useful in cases where classes have no fixed meaning but rather indicate a connection between different elements in a sequence. The reference is the MatNet model.
Source code in rl4co/models/nn/ops.py