
Constructive Policies Base Classes

ConstructiveEncoder

Bases: Module

Base class for the encoder of constructive models

forward abstractmethod

forward(td: TensorDict) -> Tuple[Any, Tensor]

Forward pass for the encoder

Parameters:

  • td (TensorDict) –

    TensorDict containing the input data

Returns:

  • Tuple[Any, Tensor]

    Tuple containing:

    • latent representation (any type)
    • initial embeddings (from feature space to embedding space)
Source code in rl4co/models/common/constructive/base.py
@abc.abstractmethod
def forward(self, td: TensorDict) -> Tuple[Any, Tensor]:
    """Forward pass for the encoder

    Args:
        td: TensorDict containing the input data

    Returns:
        Tuple containing:
          - latent representation (any type)
          - initial embeddings (from feature space to embedding space)
    """
    raise NotImplementedError("Implement me in subclass!")
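
As an illustration, a minimal concrete encoder could look like the sketch below. This is a hypothetical example, not part of rl4co: the linear projection and the "locs" feature key are assumptions, and real encoders (e.g. the Attention Model encoder) use environment-specific init embeddings and attention layers.

import torch.nn as nn
from tensordict import TensorDict
from rl4co.models.common.constructive.base import ConstructiveEncoder

class LinearProjectionEncoder(ConstructiveEncoder):
    """Hypothetical encoder: projects 2D node coordinates into the embedding space."""

    def __init__(self, embed_dim: int = 128):
        super().__init__()
        # Assumes td["locs"] holds node coordinates of shape [batch, num_nodes, 2]
        self.project = nn.Linear(2, embed_dim)

    def forward(self, td: TensorDict):
        init_embeds = self.project(td["locs"])  # feature space -> embedding space
        hidden = init_embeds  # latent representation consumed by the decoder
        return hidden, init_embeds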

ConstructiveDecoder

Bases: Module

Base decoder model for constructive models. The decoder is responsible for generating the logits for the action

forward abstractmethod

forward(
    td: TensorDict, hidden: Any = None, num_starts: int = 0
) -> Tuple[Tensor, Tensor]

Obtain the logits for selecting the next action given the current state

Parameters:

  • td (TensorDict) –

    TensorDict containing the input data

  • hidden (Any, default: None ) –

    Hidden state from the encoder. Can be any type

  • num_starts (int, default: 0 ) –

    Number of starts for multistart decoding

Returns:

  • Tuple[Tensor, Tensor]

    Tuple containing the logits and the action mask

Source code in rl4co/models/common/constructive/base.py
@abc.abstractmethod
def forward(
    self, td: TensorDict, hidden: Any = None, num_starts: int = 0
) -> Tuple[Tensor, Tensor]:
    """Obtain logits for current action to the next ones

    Args:
        td: TensorDict containing the input data
        hidden: Hidden state from the encoder. Can be any type
        num_starts: Number of starts for multistart decoding

    Returns:
        Tuple containing the logits and the action mask
    """
    raise NotImplementedError("Implement me in subclass!")
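
A matching decoder sketch, again hypothetical: it scores each node by a dot product between a learned query and the node embeddings produced by an encoder such as the one sketched above, and reads the feasibility mask from td["action_mask"], the key used by rl4co environments.

import torch
import torch.nn as nn
from tensordict import TensorDict
from rl4co.models.common.constructive.base import ConstructiveDecoder

class DotProductDecoder(ConstructiveDecoder):
    """Hypothetical decoder: scores nodes with a learned query vector."""

    def __init__(self, embed_dim: int = 128):
        super().__init__()
        self.query = nn.Parameter(torch.randn(embed_dim))

    def forward(self, td: TensorDict, hidden=None, num_starts: int = 0):
        # hidden: node embeddings from the encoder, shape [batch, num_nodes, embed_dim]
        logits = torch.einsum("bnd,d->bn", hidden, self.query)
        return logits, td["action_mask"]  # mask of feasible actions at this step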

pre_decoder_hook

pre_decoder_hook(
    td: TensorDict,
    env: RL4COEnvBase,
    hidden: Any = None,
    num_starts: int = 0,
) -> Tuple[TensorDict, Any, RL4COEnvBase]

By default, we don't need to do anything here.

Parameters:

  • td (TensorDict) –

    TensorDict containing the input data

  • hidden (Any, default: None ) –

    Hidden state from the encoder

  • env (RL4COEnvBase) –

    Environment for decoding

  • num_starts (int, default: 0 ) –

    Number of starts for multistart decoding

Returns:

  • Tuple[TensorDict, Any, RL4COEnvBase]

Tuple containing the updated TensorDict, environment, and hidden state

Source code in rl4co/models/common/constructive/base.py
def pre_decoder_hook(
    self, td: TensorDict, env: RL4COEnvBase, hidden: Any = None, num_starts: int = 0
) -> Tuple[TensorDict, Any, RL4COEnvBase]:
    """By default, we don't need to do anything here.

    Args:
        td: TensorDict containing the input data
        hidden: Hidden state from the encoder
        env: Environment for decoding
        num_starts: Number of starts for multistart decoding

    Returns:
        Tuple containing the updated hidden state, TensorDict, and environment
    """
    return td, env, hidden

NoEncoder

Bases: ConstructiveEncoder

Default encoder for decoder-only models, i.e., autoregressive models that re-encode the full state at each decoding step.

forward

forward(td: TensorDict) -> Tuple[Tensor, Tensor]

Return Nones for the hidden state and initial embeddings

Source code in rl4co/models/common/constructive/base.py
def forward(self, td: TensorDict) -> Tuple[Tensor, Tensor]:
    """Return Nones for the hidden state and initial embeddings"""
    return None, None

ConstructivePolicy

ConstructivePolicy(
    encoder: Union[ConstructiveEncoder, Callable],
    decoder: Union[ConstructiveDecoder, Callable],
    env_name: str = "tsp",
    temperature: float = 1.0,
    tanh_clipping: float = 0,
    mask_logits: bool = True,
    train_decode_type: str = "sampling",
    val_decode_type: str = "greedy",
    test_decode_type: str = "greedy",
    **unused_kw
)

Bases: Module

Base class for constructive policies. Constructive policies take an instance as input and output a solution (a sequence of actions). "Constructive" means that a solution is created from scratch by the model.

The structure roughly follows these steps:
  1. Create a hidden state from the encoder
  2. Initialize decoding strategy (such as greedy, sampling, etc.)
  3. Decode the action given the hidden state and the environment state at the current step
  4. Update the environment state with the action. Repeat 3-4 until all sequences are done
  5. Obtain log likelihood, rewards etc.

Note that an encoder is not strictly needed (see :class:NoEncoder). A decoder, however, is always needed, either in the form of a network or a function.

Note

There are major differences between this decoding process and most RL problems. The most important one is that the reward may not be defined for partial solutions, hence we have to wait for the environment to reach a terminal state before we can compute the reward with env.get_reward().

Warning

We assume that environments in the done state are still available for sampling. This is because in NCO we need to wait for all environments in the batch to reach a terminal state before we can stop the decoding process. This is in contrast with the TorchRL framework (at the moment), where the env.rollout function resets automatically. You can follow the work on tighter integration with TorchRL here: https://github.com/ai4co/rl4co/issues/72.

Parameters:

  • encoder (Union[ConstructiveEncoder, Callable]) –

    Encoder to use

  • decoder (Union[ConstructiveDecoder, Callable]) –

    Decoder to use

  • env_name (str, default: 'tsp' ) –

    Environment name to solve (used for automatically instantiating networks)

  • temperature (float, default: 1.0 ) –

    Temperature for the softmax during decoding

  • tanh_clipping (float, default: 0 ) –

    Clipping value for the tanh activation (see Bello et al. 2016) during decoding

  • mask_logits (bool, default: True ) –

    Whether to mask the logits or not during decoding

  • train_decode_type (str, default: 'sampling' ) –

    Decoding strategy for training

  • val_decode_type (str, default: 'greedy' ) –

    Decoding strategy for validation

  • test_decode_type (str, default: 'greedy' ) –

    Decoding strategy for testing

Source code in rl4co/models/common/constructive/base.py
def __init__(
    self,
    encoder: Union[ConstructiveEncoder, Callable],
    decoder: Union[ConstructiveDecoder, Callable],
    env_name: str = "tsp",
    temperature: float = 1.0,
    tanh_clipping: float = 0,
    mask_logits: bool = True,
    train_decode_type: str = "sampling",
    val_decode_type: str = "greedy",
    test_decode_type: str = "greedy",
    **unused_kw,
):
    super(ConstructivePolicy, self).__init__()

    if len(unused_kw) > 0:
        log.error(f"Found {len(unused_kw)} unused kwargs: {unused_kw}")

    self.env_name = env_name

    # Encoder and decoder
    if encoder is None:
        log.warning("`None` was provided as encoder. Using `NoEncoder`.")
        encoder = NoEncoder()
    self.encoder = encoder
    self.decoder = decoder

    # Decoding strategies
    self.temperature = temperature
    self.tanh_clipping = tanh_clipping
    self.mask_logits = mask_logits
    self.train_decode_type = train_decode_type
    self.val_decode_type = val_decode_type
    self.test_decode_type = test_decode_type

forward

forward(
    td: TensorDict,
    env: Optional[Union[str, RL4COEnvBase]] = None,
    phase: str = "train",
    calc_reward: bool = True,
    return_actions: bool = False,
    return_entropy: bool = False,
    return_hidden: bool = False,
    return_init_embeds: bool = False,
    return_sum_log_likelihood: bool = True,
    actions=None,
    max_steps=1000000,
    **decoding_kwargs
) -> dict

Forward pass of the policy.

Parameters:

  • td (TensorDict) –

    TensorDict containing the environment state

  • env (Optional[Union[str, RL4COEnvBase]], default: None ) –

    Environment to use for decoding. If None, the environment is instantiated from env_name. Note that it is more efficient to pass an already instantiated environment each time for fine-grained control

  • phase (str, default: 'train' ) –

    Phase of the algorithm (train, val, test)

  • calc_reward (bool, default: True ) –

    Whether to calculate the reward

  • return_actions (bool, default: False ) –

    Whether to return the actions

  • return_entropy (bool, default: False ) –

    Whether to return the entropy

  • return_hidden (bool, default: False ) –

    Whether to return the hidden state

  • return_init_embeds (bool, default: False ) –

    Whether to return the initial embeddings

  • return_sum_log_likelihood (bool, default: True ) –

    Whether to return the sum of the log likelihood

  • actions

    Actions to use for evaluating the policy. If passed, use these actions instead of sampling from the policy to calculate log likelihood

  • max_steps

    Maximum number of decoding steps for sanity check to avoid infinite loops if envs are buggy (i.e. do not reach done)

  • decoding_kwargs

    Keyword arguments for the decoding strategy. See :class:rl4co.utils.decoding.DecodingStrategy for more information.

Returns:

  • out ( dict ) –

    Dictionary containing the reward, log likelihood, and optionally the actions and entropy

Source code in rl4co/models/common/constructive/base.py
def forward(
    self,
    td: TensorDict,
    env: Optional[Union[str, RL4COEnvBase]] = None,
    phase: str = "train",
    calc_reward: bool = True,
    return_actions: bool = False,
    return_entropy: bool = False,
    return_hidden: bool = False,
    return_init_embeds: bool = False,
    return_sum_log_likelihood: bool = True,
    actions=None,
    max_steps=1_000_000,
    **decoding_kwargs,
) -> dict:
    """Forward pass of the policy.

    Args:
        td: TensorDict containing the environment state
        env: Environment to use for decoding. If None, the environment is instantiated from `env_name`. Note that
            it is more efficient to pass an already instantiated environment each time for fine-grained control
        phase: Phase of the algorithm (train, val, test)
        calc_reward: Whether to calculate the reward
        return_actions: Whether to return the actions
        return_entropy: Whether to return the entropy
        return_hidden: Whether to return the hidden state
        return_init_embeds: Whether to return the initial embeddings
        return_sum_log_likelihood: Whether to return the sum of the log likelihood
        actions: Actions to use for evaluating the policy.
            If passed, use these actions instead of sampling from the policy to calculate log likelihood
        max_steps: Maximum number of decoding steps for sanity check to avoid infinite loops if envs are buggy (i.e. do not reach `done`)
        decoding_kwargs: Keyword arguments for the decoding strategy. See :class:`rl4co.utils.decoding.DecodingStrategy` for more information.

    Returns:
        out: Dictionary containing the reward, log likelihood, and optionally the actions and entropy
    """

    # Encoder: get encoder output and initial embeddings from initial state
    hidden, init_embeds = self.encoder(td)

    # Instantiate environment if needed
    if isinstance(env, str) or env is None:
        env_name = self.env_name if env is None else env
        log.info(f"Instantiated environment not provided; instantiating {env_name}")
        env = get_env(env_name)

    # Get decode type depending on phase and whether actions are passed for evaluation
    decode_type = decoding_kwargs.pop("decode_type", None)
    if actions is not None:
        decode_type = "evaluate"
    elif decode_type is None:
        decode_type = getattr(self, f"{phase}_decode_type")

    # Setup decoding strategy
    # we pop arguments that are not part of the decoding strategy
    decode_strategy: DecodingStrategy = get_decoding_strategy(
        decode_type,
        temperature=decoding_kwargs.pop("temperature", self.temperature),
        tanh_clipping=decoding_kwargs.pop("tanh_clipping", self.tanh_clipping),
        mask_logits=decoding_kwargs.pop("mask_logits", self.mask_logits),
        store_all_logp=decoding_kwargs.pop("store_all_logp", return_entropy),
        **decoding_kwargs,
    )

    # Pre-decoding hook: used for the initial step(s) of the decoding strategy
    td, env, num_starts = decode_strategy.pre_decoder_hook(td, env)

    # Additionally call a decoder hook if needed before main decoding
    td, env, hidden = self.decoder.pre_decoder_hook(td, env, hidden, num_starts)

    # Main decoding: loop until all sequences are done
    step = 0
    while not td["done"].all():
        logits, mask = self.decoder(td, hidden, num_starts)
        td = decode_strategy.step(
            logits,
            mask,
            td,
            action=actions[..., step] if actions is not None else None,
        )
        td = env.step(td)["next"]
        step += 1
        if step > max_steps:
            log.error(
                f"Exceeded maximum number of steps ({max_steps}) duing decoding"
            )
            break

    # Post-decoding hook: used for the final step(s) of the decoding strategy
    logprobs, actions, td, env = decode_strategy.post_decoder_hook(td, env)

    # Output dictionary construction
    if calc_reward:
        td.set("reward", env.get_reward(td, actions))

    outdict = {
        "reward": td["reward"],
        "log_likelihood": get_log_likelihood(
            logprobs, actions, td.get("mask", None), return_sum_log_likelihood
        ),
    }

    if return_actions:
        outdict["actions"] = actions
    if return_entropy:
        outdict["entropy"] = calculate_entropy(logprobs)
    if return_hidden:
        outdict["hidden"] = hidden
    if return_init_embeds:
        outdict["init_embeds"] = init_embeds

    return outdict
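
A minimal rollout sketch, assuming the TSP environment and the policy constructed above; it further assumes that the get_env helper used in the forward pass is importable from rl4co.envs, and that env.reset follows the TorchRL-style TensorDict interface used throughout rl4co.

from rl4co.envs import get_env

env = get_env("tsp")             # instantiate the environment once for efficiency
td = env.reset(batch_size=[32])  # initial states as a TensorDict
out = policy(td, env, phase="test", return_actions=True)
print(out["reward"].shape)       # [32]
print(out["actions"].shape)      # [32, num_steps]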

Autoregressive Policies

AutoregressiveEncoder

Bases: ConstructiveEncoder

Template class for an autoregressive encoder, simple wrapper around :class:rl4co.models.common.constructive.base.ConstructiveEncoder.

Tip

This class will not work as it is and is just a template. An example of an autoregressive encoder is :class:rl4co.models.zoo.am.encoder.AttentionModelEncoder.

AutoregressiveDecoder

Bases: ConstructiveDecoder

Template class for an autoregressive decoder, simple wrapper around :class:rl4co.models.common.constructive.base.ConstructiveDecoder

Tip

This class will not work as it is and is just a template. An example of an autoregressive decoder is :class:rl4co.models.zoo.am.decoder.AttentionModelDecoder.

AutoregressivePolicy

AutoregressivePolicy(
    encoder: AutoregressiveEncoder,
    decoder: AutoregressiveDecoder,
    env_name: str = "tsp",
    temperature: float = 1.0,
    tanh_clipping: float = 0,
    mask_logits: bool = True,
    train_decode_type: str = "sampling",
    val_decode_type: str = "greedy",
    test_decode_type: str = "greedy",
    **unused_kw
)

Bases: ConstructivePolicy

Template class for an autoregressive policy, simple wrapper around :class:rl4co.models.common.constructive.base.ConstructivePolicy.

Note

While a decoder is required, an encoder is optional and will be initialized to :class:rl4co.models.common.constructive.autoregressive.encoder.NoEncoder. This can be used for decoder-only models, in which actions at each step do not depend on a previously encoded state.

Source code in rl4co/models/common/constructive/autoregressive/policy.py
def __init__(
    self,
    encoder: AutoregressiveEncoder,
    decoder: AutoregressiveDecoder,
    env_name: str = "tsp",
    temperature: float = 1.0,
    tanh_clipping: float = 0,
    mask_logits: bool = True,
    train_decode_type: str = "sampling",
    val_decode_type: str = "greedy",
    test_decode_type: str = "greedy",
    **unused_kw,
):
    # We raise an error for the user if no decoder was provided
    if decoder is None:
        raise ValueError("AutoregressivePolicy requires a decoder to be provided.")

    super(AutoregressivePolicy, self).__init__(
        encoder=encoder,
        decoder=decoder,
        env_name=env_name,
        temperature=temperature,
        tanh_clipping=tanh_clipping,
        mask_logits=mask_logits,
        train_decode_type=train_decode_type,
        val_decode_type=val_decode_type,
        test_decode_type=test_decode_type,
        **unused_kw,
    )
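
Concrete models in the zoo (such as the Attention Model) follow this pattern: they subclass AutoregressivePolicy and pass their own encoder and decoder to the parent constructor. A hypothetical sketch, reusing the toy modules from the constructive base section:

from rl4co.models.common.constructive.autoregressive.policy import AutoregressivePolicy

class MyARPolicy(AutoregressivePolicy):
    """Hypothetical policy wiring a custom encoder/decoder pair into the
    shared rollout implemented by ConstructivePolicy.forward."""

    def __init__(self, embed_dim: int = 128, env_name: str = "tsp", **kwargs):
        super().__init__(
            encoder=LinearProjectionEncoder(embed_dim),  # hypothetical, sketched earlier
            decoder=DotProductDecoder(embed_dim),        # hypothetical, sketched earlier
            env_name=env_name,
            **kwargs,
        )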

Nonautoregressive Policies

NonAutoregressiveEncoder

Bases: ConstructiveEncoder

Template class for a nonautoregressive encoder, simple wrapper around :class:rl4co.models.common.constructive.base.ConstructiveEncoder.

Tip

This class will not work as it is and is just a template. For concrete nonautoregressive encoders (typically graph neural networks that output edge heatmaps), see the model zoo.

NonAutoregressiveDecoder

Bases: ConstructiveDecoder

The nonautoregressive decoder is a simple callable class that takes the tensor dictionary and the heatmap logits and returns the logits for the current action and the action mask.

heatmap_to_logits staticmethod

heatmap_to_logits(
    td: TensorDict, heatmaps_logits: Tensor, num_starts: int
)

Obtain the logits over the next actions from the heatmap, given the current action

Source code in rl4co/models/common/constructive/nonautoregressive/decoder.py
@staticmethod
def heatmap_to_logits(td: TensorDict, heatmaps_logits: torch.Tensor, num_starts: int):
    """Obtain heatmap logits for current action to the next ones"""
    current_action = td.get("action", None)
    if current_action is None:
        logits = heatmaps_logits.mean(-1)
    else:
        batch_size = heatmaps_logits.shape[0]
        _indexer = _multistart_batched_index(batch_size, num_starts)
        logits = heatmaps_logits[_indexer, current_action, :]
    return logits, td["action_mask"]
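
A small, self-contained shape sketch of how the heatmap is consumed; the [batch, num_nodes, num_nodes] edge-heatmap shape and the single-start case are assumptions made for illustration.

import torch

batch_size, num_nodes = 4, 10
heatmaps_logits = torch.randn(batch_size, num_nodes, num_nodes)  # edge scores from a NAR encoder

# First step: td has no "action" yet, so the logits are the mean of the heatmap
# over its last dimension (one score per node).
first_step_logits = heatmaps_logits.mean(-1)                     # [batch_size, num_nodes]

# Later steps: the heatmap row of the node selected at the previous step becomes
# the logits over the next node.
current_action = torch.zeros(batch_size, dtype=torch.long)       # pretend node 0 was just visited
next_step_logits = heatmaps_logits[torch.arange(batch_size), current_action, :]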

NonAutoregressivePolicy

NonAutoregressivePolicy(
    encoder: NonAutoregressiveEncoder,
    decoder: NonAutoregressiveDecoder = None,
    env_name: str = "tsp",
    temperature: float = 1.0,
    tanh_clipping: float = 0,
    mask_logits: bool = True,
    train_decode_type: str = "sampling",
    val_decode_type: str = "greedy",
    test_decode_type: str = "greedy",
    **unused_kw
)

Bases: ConstructivePolicy

Template class for a nonautoregressive policy, simple wrapper around :class:rl4co.models.common.constructive.base.ConstructivePolicy.

Source code in rl4co/models/common/constructive/nonautoregressive/policy.py
def __init__(
    self,
    encoder: NonAutoregressiveEncoder,
    decoder: NonAutoregressiveDecoder = None,
    env_name: str = "tsp",
    temperature: float = 1.0,
    tanh_clipping: float = 0,
    mask_logits: bool = True,
    train_decode_type: str = "sampling",
    val_decode_type: str = "greedy",
    test_decode_type: str = "greedy",
    **unused_kw,
):
    # If decoder is not passed, we default to the non-autoregressive decoder that decodes the heatmap
    if decoder is None:
        decoder = NonAutoregressiveDecoder()

    super(NonAutoregressivePolicy, self).__init__(
        encoder=encoder,
        decoder=decoder,
        env_name=env_name,
        temperature=temperature,
        tanh_clipping=tanh_clipping,
        mask_logits=mask_logits,
        train_decode_type=train_decode_type,
        val_decode_type=val_decode_type,
        test_decode_type=test_decode_type,
        **unused_kw,
    )
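
A usage sketch with a toy heatmap encoder. The encoder below is hypothetical (it emits random edge logits and assumes a "locs" key), but it is enough to exercise the default NonAutoregressiveDecoder; real models use GNN encoders that produce learned heatmaps.

import torch
from rl4co.models.common.constructive.base import ConstructiveEncoder
from rl4co.models.common.constructive.nonautoregressive.policy import NonAutoregressivePolicy

class RandomHeatmapEncoder(ConstructiveEncoder):
    """Toy encoder for illustration only: emits random edge logits as the heatmap."""

    def forward(self, td):
        num_nodes = td["locs"].shape[-2]
        heatmap = torch.randn(*td.batch_size, num_nodes, num_nodes, device=td.device)
        return heatmap, None  # (hidden = heatmap logits, init_embeds = None)

nar_policy = NonAutoregressivePolicy(encoder=RandomHeatmapEncoder(), env_name="tsp")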

Improvement Policies (Base Classes)

ImprovementEncoder

ImprovementEncoder(
    embed_dim: int = 128,
    init_embedding: Module = None,
    pos_embedding: Module = None,
    env_name: str = "pdp_ruin_repair",
    pos_type: str = "CPE",
    num_heads: int = 4,
    num_layers: int = 3,
    normalization: str = "layer",
    feedforward_hidden: int = 128,
    linear_bias: bool = False,
)

Bases: Module

Base class for the encoder of improvement models

Source code in rl4co/models/common/improvement/base.py
def __init__(
    self,
    embed_dim: int = 128,
    init_embedding: nn.Module = None,
    pos_embedding: nn.Module = None,
    env_name: str = "pdp_ruin_repair",
    pos_type: str = "CPE",
    num_heads: int = 4,
    num_layers: int = 3,
    normalization: str = "layer",
    feedforward_hidden: int = 128,
    linear_bias: bool = False,
):
    super(ImprovementEncoder, self).__init__()

    if isinstance(env_name, RL4COEnvBase):
        env_name = env_name.name
    self.env_name = env_name
    self.init_embedding = (
        env_init_embedding(
            self.env_name, {"embed_dim": embed_dim, "linear_bias": linear_bias}
        )
        if init_embedding is None
        else init_embedding
    )

    self.pos_type = pos_type
    self.pos_embedding = (
        pos_init_embedding(self.pos_type, {"embed_dim": embed_dim})
        if pos_embedding is None
        else pos_embedding
    )

forward

forward(td: TensorDict) -> Tuple[Tensor, Tensor]

Forward pass of the encoder. Transform the input TensorDict into a latent representation.

Parameters:

  • td (TensorDict) –

    Input TensorDict containing the environment state

Returns:

  • final_h ( Tensor ) –

    Final node embeddings (latent representation of the input)

  • final_p ( Tensor ) –

    Final positional embeddings

Source code in rl4co/models/common/improvement/base.py
def forward(self, td: TensorDict) -> Tuple[Tensor, Tensor]:
    """Forward pass of the encoder.
    Transform the input TensorDict into a latent representation.

    Args:
        td: Input TensorDict containing the environment state

    Returns:
        h: Latent representation of the input
        init_h: Initial embedding of the input
    """
    # Transfer to embedding space (node)
    init_h = self.init_embedding(td)

    # Transfer to embedding space (solution)
    init_p = self.pos_embedding(td)

    # Process embedding
    final_h, final_p = self._encoder_forward(init_h, init_p)

    # Return latent representation and initial embedding
    return final_h, final_p
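
The only piece a subclass must supply is _encoder_forward, which maps the initial node and positional embeddings to their final versions. A toy sketch is shown below; real encoders (e.g. in DACT/N2S-style models) run attention layers at this point.

from torch import Tensor
from rl4co.models.common.improvement.base import ImprovementEncoder

class PassthroughImprovementEncoder(ImprovementEncoder):
    """Toy subclass for illustration: returns the initial embeddings unchanged."""

    def _encoder_forward(self, init_h: Tensor, init_p: Tensor):
        # A real encoder would process init_h / init_p with attention blocks;
        # this sketch simply passes them through.
        return init_h, init_p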

ImprovementDecoder

Bases: Module

Base decoder model for improvement models. The decoder is responsible for generating the logits of the action

forward abstractmethod

forward(
    td: TensorDict, final_h: Tensor, final_p: Tensor
) -> Tensor

Obtain logits over the operators that transform the current solution into a new one

Parameters:

  • td (TensorDict) –

    TensorDict with the current environment state

  • final_h (Tensor) –

    final node embeddings

  • final_p (Tensor) –

    final positional embeddings

Returns:

  • Tensor

Tensor containing the logits

Source code in rl4co/models/common/improvement/base.py
@abc.abstractmethod
def forward(self, td: TensorDict, final_h: Tensor, final_p: Tensor) -> Tensor:
    """Obtain logits to perform operators that improve the current solution to the next ones

    Args:
        td: TensorDict with the current environment state
        final_h: final node embeddings
        final_p: final positional embeddings

    Returns:
        Tuple containing the logits
    """
    raise NotImplementedError("Implement me in subclass!")
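
A hypothetical decoder sketch that scores node pairs (i, j) for a pairwise operator (e.g. a 2-opt-like move) from the final embeddings; the fusion by summation is an assumption made for brevity.

import torch
from torch import Tensor
from tensordict import TensorDict
from rl4co.models.common.improvement.base import ImprovementDecoder

class PairScoreDecoder(ImprovementDecoder):
    """Hypothetical decoder: scores node pairs via a dot product of fused embeddings."""

    def forward(self, td: TensorDict, final_h: Tensor, final_p: Tensor) -> Tensor:
        h = final_h + final_p                        # fuse node and positional embeddings
        logits = torch.einsum("bid,bjd->bij", h, h)  # [batch, num_nodes, num_nodes]
        return logits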

ImprovementPolicy

Bases: Module

Base class for improvement policies. Improvement policies take an instance + a solution as input and output a specific operator that changes the current solution to a new one.

"Improvement" means that a solution is (potentially) improved to a new one by the model.

forward abstractmethod

forward(
    td: TensorDict,
    env: Union[str, RL4COEnvBase] = None,
    phase: str = "train",
    return_actions: bool = False,
    return_entropy: bool = False,
    return_init_embeds: bool = False,
    actions=None,
    **decoding_kwargs
) -> dict

Forward pass of the policy.

Parameters:

  • td (TensorDict) –

    TensorDict containing the environment state

  • env (Union[str, RL4COEnvBase], default: None ) –

    Environment to use for decoding. If None, the environment is instantiated from env_name. Note that it is more efficient to pass an already instantiated environment each time for fine-grained control

  • phase (str, default: 'train' ) –

    Phase of the algorithm (train, val, test)

  • return_actions (bool, default: False ) –

    Whether to return the actions

  • return_entropy (bool, default: False ) –

    Whether to return the entropy

  • return_init_embeds (bool, default: False ) –

    Whether to return the initial embeddings

  • actions

    Actions to use for evaluating the policy. If passed, use these actions instead of sampling from the policy to calculate log likelihood

  • decoding_kwargs

    Keyword arguments for the decoding strategy. See :class:rl4co.utils.decoding.DecodingStrategy for more information.

Returns:

  • out ( dict ) –

    Dictionary containing the reward, log likelihood, and optionally the actions and entropy

Source code in rl4co/models/common/improvement/base.py
@abc.abstractmethod
def forward(
    self,
    td: TensorDict,
    env: Union[str, RL4COEnvBase] = None,
    phase: str = "train",
    return_actions: bool = False,
    return_entropy: bool = False,
    return_init_embeds: bool = False,
    actions=None,
    **decoding_kwargs,
) -> dict:
    """Forward pass of the policy.

    Args:
        td: TensorDict containing the environment state
        env: Environment to use for decoding. If None, the environment is instantiated from `env_name`. Note that
            it is more efficient to pass an already instantiated environment each time for fine-grained control
        phase: Phase of the algorithm (train, val, test)
        return_actions: Whether to return the actions
        return_entropy: Whether to return the entropy
        return_init_embeds: Whether to return the initial embeddings
        actions: Actions to use for evaluating the policy.
            If passed, use these actions instead of sampling from the policy to calculate log likelihood
        decoding_kwargs: Keyword arguments for the decoding strategy. See :class:`rl4co.utils.decoding.DecodingStrategy` for more information.

    Returns:
        out: Dictionary containing the reward, log likelihood, and optionally the actions and entropy
    """
    raise NotImplementedError("Implement me in subclass!")