Decoding Strategies

DecodingStrategy

DecodingStrategy(
    temperature: float = 1.0,
    top_p: float = 0.0,
    top_k: int = 0,
    mask_logits: bool = True,
    tanh_clipping: float = 0,
    multistart: bool = False,
    multisample: bool = False,
    num_starts: Optional[int] = None,
    select_start_nodes_fn: Optional[callable] = None,
    improvement_method_mode: bool = False,
    select_best: bool = False,
    store_all_logp: bool = False,
    **kwargs
)

Base class for decoding strategies. Subclasses should implement the _step method. Includes hooks for pre- and post-decoding operations.

Parameters:

  • temperature (float, default: 1.0 ) –

    Temperature scaling. Higher values make the distribution more uniform (exploration), lower values make it more peaky (exploitation). Defaults to 1.0.

  • top_p (float, default: 0.0 ) –

    Top-p sampling, a.k.a. Nucleus Sampling (https://arxiv.org/abs/1904.09751). Defaults to 0.0.

  • top_k (int, default: 0 ) –

    Top-k sampling, i.e. restrict sampling to the top k logits. If 0, do not perform. Defaults to 0.

  • mask_logits (bool, default: True ) –

    Whether to mask logits of infeasible actions. Defaults to True.

  • tanh_clipping (float, default: 0 ) –

    Tanh clipping (https://arxiv.org/abs/1611.09940). Defaults to 0.

  • multistart (bool, default: False ) –

    Whether to use multistart decoding. Defaults to False.

  • multisample (bool, default: False ) –

    Whether to use multi-sample decoding, i.e. sample multiple solutions per instance. Defaults to False.

  • num_starts (Optional[int], default: None ) –

    Number of starts for multistart decoding. Defaults to None.

Source code in rl4co/utils/decoding.py
def __init__(
    self,
    temperature: float = 1.0,
    top_p: float = 0.0,
    top_k: int = 0,
    mask_logits: bool = True,
    tanh_clipping: float = 0,
    multistart: bool = False,
    multisample: bool = False,
    num_starts: Optional[int] = None,
    select_start_nodes_fn: Optional[callable] = None,
    improvement_method_mode: bool = False,
    select_best: bool = False,
    store_all_logp: bool = False,
    **kwargs,
) -> None:
    self.temperature = temperature
    self.top_p = top_p
    self.top_k = top_k
    self.mask_logits = mask_logits
    self.tanh_clipping = tanh_clipping
    self.multistart = multistart
    self.multisample = multisample
    self.num_starts = num_starts
    self.select_start_nodes_fn = select_start_nodes_fn
    self.improvement_method_mode = improvement_method_mode
    self.select_best = select_best
    self.store_all_logp = store_all_logp
    # initialize buffers
    self.actions = []
    self.logprobs = []
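
A minimal sketch of the intended decode loop, assuming an RL4CO environment (here TSPEnv; the constructor arguments are an assumption and may differ between versions) and random logits standing in for a policy network's decoder output:

import torch
from rl4co.envs import TSPEnv
from rl4co.utils.decoding import Sampling

env = TSPEnv(generator_params={"num_loc": 10})  # assumed constructor arguments
td = env.reset(batch_size=[4])

strategy = Sampling(temperature=1.0)
td, env, num_starts = strategy.pre_decoder_hook(td, env)

while not td["done"].all():
    logits = torch.randn(*td["action_mask"].shape)      # stand-in for model logits
    td = strategy.step(logits, td["action_mask"], td)   # masks, samples, stores action and logprob
    td = env.step(td)["next"]

# The per-step buffers can be stacked into [batch, seq_len] tensors
actions = torch.stack(strategy.actions, dim=1)
logprobs = torch.stack(strategy.logprobs, dim=1)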

pre_decoder_hook

pre_decoder_hook(
    td: TensorDict, env: RL4COEnvBase, action: Tensor = None
)

Pre decoding hook. This method is called before the main decoding operation.

Source code in rl4co/utils/decoding.py
def pre_decoder_hook(
    self, td: TensorDict, env: RL4COEnvBase, action: torch.Tensor = None
):
    """Pre decoding hook. This method is called before the main decoding operation."""

    # Multi-start decoding. If num_starts is None, we use the number of actions in the action mask
    if self.multistart or self.multisample:
        if self.num_starts is None:
            self.num_starts = env.get_num_starts(td)
            if self.multisample:
                log.warn(
                    f"num_starts is not provided for sampling, using num_starts={self.num_starts}"
                )
    else:
        if self.num_starts is not None:
            if self.num_starts >= 1:
                log.warn(
                    f"num_starts={self.num_starts} is ignored for decode_type={self.name}"
                )

        self.num_starts = 0

    # Multi-start decoding: first action is chosen by ad-hoc node selection
    if self.num_starts >= 1:
        if self.multistart:
            if action is None:  # if action is provided, we use it as the first action
                if self.select_start_nodes_fn is not None:
                    action = self.select_start_nodes_fn(td, env, self.num_starts)
                else:
                    action = env.select_start_nodes(td, num_starts=self.num_starts)

            # Expand td to batch_size * num_starts
            td = batchify(td, self.num_starts)

            td.set("action", action)
            td = env.step(td)["next"]
            # first logprobs is 0, so p = logprobs.exp() = 1
            if self.store_all_logp:
                logprobs = torch.zeros_like(td["action_mask"])  # [B, N]
            else:
                logprobs = torch.zeros_like(action, device=td.device)  # [B]

            self.logprobs.append(logprobs)
            self.actions.append(action)
        else:
            # Expand td to batch_size * num_starts
            td = batchify(td, self.num_starts)

    return td, env, self.num_starts
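
For multistart decoding (e.g. POMO-style rollouts with one start per node), the hook expands the batch and applies the first action itself; a brief sketch, again assuming a TSPEnv with hypothetical constructor arguments:

from rl4co.envs import TSPEnv
from rl4co.utils.decoding import Greedy

env = TSPEnv(generator_params={"num_loc": 10})  # assumed constructor arguments
td = env.reset(batch_size=[4])

strategy = Greedy(multistart=True)  # num_starts falls back to env.get_num_starts(td)
td, env, num_starts = strategy.pre_decoder_hook(td, env)
# td is now expanded to batch_size * num_starts and the first action has already been applied,
# so the main decode loop continues from the second step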

step

step(
    logits: Tensor,
    mask: Tensor,
    td: TensorDict = None,
    action: Tensor = None,
    **kwargs
) -> TensorDict

Main decoding operation. This method should be called in a loop until all sequences are done.

Parameters:

  • logits (Tensor) –

    Logits from the model.

  • mask (Tensor) –

    Action mask. 1 if feasible, 0 otherwise (so we keep if 1 as done in PyTorch).

  • td (TensorDict, default: None ) –

    TensorDict containing the current state of the environment.

  • action (Tensor, default: None ) –

    Optional action to use, e.g. for evaluating log probabilities.

Source code in rl4co/utils/decoding.py
def step(
    self,
    logits: torch.Tensor,
    mask: torch.Tensor,
    td: TensorDict = None,
    action: torch.Tensor = None,
    **kwargs,
) -> TensorDict:
    """Main decoding operation. This method should be called in a loop until all sequences are done.

    Args:
        logits: Logits from the model.
        mask: Action mask. 1 if feasible, 0 otherwise (so we keep if 1 as done in PyTorch).
        td: TensorDict containing the current state of the environment.
        action: Optional action to use, e.g. for evaluating log probabilities.
    """
    if not self.mask_logits:  # set mask_logit to None if mask_logits is False
        mask = None

    logprobs = process_logits(
        logits,
        mask,
        temperature=self.temperature,
        top_p=self.top_p,
        top_k=self.top_k,
        tanh_clipping=self.tanh_clipping,
        mask_logits=self.mask_logits,
    )
    logprobs, selected_action, td = self._step(
        logprobs, mask, td, action=action, **kwargs
    )

    # directly return for improvement methods, since the action for improvement methods is finalized in its own policy
    if self.improvement_method_mode:
        return logprobs, selected_action
    # for others
    if not self.store_all_logp:
        logprobs = gather_by_index(logprobs, selected_action, dim=1)
    td.set("action", selected_action)
    self.actions.append(selected_action)
    self.logprobs.append(logprobs)
    return td
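
When an action is provided, it is taken as given and only its log probability is recorded; this is how the Evaluate strategy (documented below) can be used to re-score existing solutions. A small sketch on synthetic tensors with a bare TensorDict as the state:

import torch
from tensordict import TensorDict
from rl4co.utils.decoding import Evaluate

logits = torch.randn(2, 5)
mask = torch.ones(2, 5, dtype=torch.bool)
given_action = torch.tensor([3, 1])

strategy = Evaluate()
td = TensorDict({}, batch_size=[2])
td = strategy.step(logits, mask, td, action=given_action)
print(td["action"])           # the provided actions, unchanged
print(strategy.logprobs[-1])  # log probability of each provided action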

greedy staticmethod

greedy(logprobs, mask=None)

Select the action with the highest probability.

Source code in rl4co/utils/decoding.py
@staticmethod
def greedy(logprobs, mask=None):
    """Select the action with the highest probability."""
    # [BS], [BS]
    selected = logprobs.argmax(dim=-1)
    if mask is not None:
        assert (
            not (~mask).gather(1, selected.unsqueeze(-1)).data.any()
        ), "infeasible action selected"

    return selected
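
These static helpers operate on already-processed log probabilities. A self-contained example that first masks and normalizes raw logits with process_logits (documented further below) and then picks the argmax:

import torch
from rl4co.utils.decoding import DecodingStrategy, process_logits

logits = torch.randn(3, 4)
mask = torch.tensor([[True, True, False, True]]).expand(3, 4)

logprobs = process_logits(logits, mask)             # infeasible entries become -inf, then log-softmax
selected = DecodingStrategy.greedy(logprobs, mask)  # argmax per row; never an infeasible index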

sampling staticmethod

sampling(logprobs, mask=None)

Sample an action with a multinomial distribution given by the log probabilities.

Source code in rl4co/utils/decoding.py
@staticmethod
def sampling(logprobs, mask=None):
    """Sample an action with a multinomial distribution given by the log probabilities."""
    probs = logprobs.exp()
    selected = torch.multinomial(probs, 1).squeeze(1)

    if mask is not None:
        while (~mask).gather(1, selected.unsqueeze(-1)).data.any():
            log.info("Sampled bad values, resampling!")
            selected = probs.multinomial(1).squeeze(1)
        assert (
            not (~mask).gather(1, selected.unsqueeze(-1)).data.any()
        ), "infeasible action selected"

    return selected
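
The sampling counterpart draws one action per row from the masked distribution:

import torch
from rl4co.utils.decoding import DecodingStrategy, process_logits

mask = torch.tensor([[True, False, True, True]]).expand(3, 4)
logprobs = process_logits(torch.randn(3, 4), mask)
selected = DecodingStrategy.sampling(logprobs, mask)  # one multinomial draw per row; index 1 is never sampled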

Greedy

Greedy(
    temperature: float = 1.0,
    top_p: float = 0.0,
    top_k: int = 0,
    mask_logits: bool = True,
    tanh_clipping: float = 0,
    multistart: bool = False,
    multisample: bool = False,
    num_starts: Optional[int] = None,
    select_start_nodes_fn: Optional[callable] = None,
    improvement_method_mode: bool = False,
    select_best: bool = False,
    store_all_logp: bool = False,
    **kwargs
)

Bases: DecodingStrategy

Source code in rl4co/utils/decoding.py
def __init__(
    self,
    temperature: float = 1.0,
    top_p: float = 0.0,
    top_k: int = 0,
    mask_logits: bool = True,
    tanh_clipping: float = 0,
    multistart: bool = False,
    multisample: bool = False,
    num_starts: Optional[int] = None,
    select_start_nodes_fn: Optional[callable] = None,
    improvement_method_mode: bool = False,
    select_best: bool = False,
    store_all_logp: bool = False,
    **kwargs,
) -> None:
    self.temperature = temperature
    self.top_p = top_p
    self.top_k = top_k
    self.mask_logits = mask_logits
    self.tanh_clipping = tanh_clipping
    self.multistart = multistart
    self.multisample = multisample
    self.num_starts = num_starts
    self.select_start_nodes_fn = select_start_nodes_fn
    self.improvement_method_mode = improvement_method_mode
    self.select_best = select_best
    self.store_all_logp = store_all_logp
    # initialize buffers
    self.actions = []
    self.logprobs = []

Sampling

Sampling(
    temperature: float = 1.0,
    top_p: float = 0.0,
    top_k: int = 0,
    mask_logits: bool = True,
    tanh_clipping: float = 0,
    multistart: bool = False,
    multisample: bool = False,
    num_starts: Optional[int] = None,
    select_start_nodes_fn: Optional[callable] = None,
    improvement_method_mode: bool = False,
    select_best: bool = False,
    store_all_logp: bool = False,
    **kwargs
)

Bases: DecodingStrategy

Source code in rl4co/utils/decoding.py
def __init__(
    self,
    temperature: float = 1.0,
    top_p: float = 0.0,
    top_k: int = 0,
    mask_logits: bool = True,
    tanh_clipping: float = 0,
    multistart: bool = False,
    multisample: bool = False,
    num_starts: Optional[int] = None,
    select_start_nodes_fn: Optional[callable] = None,
    improvement_method_mode: bool = False,
    select_best: bool = False,
    store_all_logp: bool = False,
    **kwargs,
) -> None:
    self.temperature = temperature
    self.top_p = top_p
    self.top_k = top_k
    self.mask_logits = mask_logits
    self.tanh_clipping = tanh_clipping
    self.multistart = multistart
    self.multisample = multisample
    self.num_starts = num_starts
    self.select_start_nodes_fn = select_start_nodes_fn
    self.improvement_method_mode = improvement_method_mode
    self.select_best = select_best
    self.store_all_logp = store_all_logp
    # initialize buffers
    self.actions = []
    self.logprobs = []

Evaluate

Evaluate(
    temperature: float = 1.0,
    top_p: float = 0.0,
    top_k: int = 0,
    mask_logits: bool = True,
    tanh_clipping: float = 0,
    multistart: bool = False,
    multisample: bool = False,
    num_starts: Optional[int] = None,
    select_start_nodes_fn: Optional[callable] = None,
    improvement_method_mode: bool = False,
    select_best: bool = False,
    store_all_logp: bool = False,
    **kwargs
)

Bases: DecodingStrategy

Source code in rl4co/utils/decoding.py
def __init__(
    self,
    temperature: float = 1.0,
    top_p: float = 0.0,
    top_k: int = 0,
    mask_logits: bool = True,
    tanh_clipping: float = 0,
    multistart: bool = False,
    multisample: bool = False,
    num_starts: Optional[int] = None,
    select_start_nodes_fn: Optional[callable] = None,
    improvement_method_mode: bool = False,
    select_best: bool = False,
    store_all_logp: bool = False,
    **kwargs,
) -> None:
    self.temperature = temperature
    self.top_p = top_p
    self.top_k = top_k
    self.mask_logits = mask_logits
    self.tanh_clipping = tanh_clipping
    self.multistart = multistart
    self.multisample = multisample
    self.num_starts = num_starts
    self.select_start_nodes_fn = select_start_nodes_fn
    self.improvement_method_mode = improvement_method_mode
    self.select_best = select_best
    self.store_all_logp = store_all_logp
    # initialize buffers
    self.actions = []
    self.logprobs = []

BeamSearch

BeamSearch(beam_width=None, select_best=True, **kwargs)

Bases: DecodingStrategy

Source code in rl4co/utils/decoding.py
def __init__(self, beam_width=None, select_best=True, **kwargs) -> None:
    # TODO do we really need all logp in beam search?
    kwargs["store_all_logp"] = True
    super().__init__(**kwargs)
    self.beam_width = beam_width
    self.select_best = select_best
    self.parent_beam_logprobs = None
    self.beam_path = []

pre_decoder_hook

pre_decoder_hook(
    td: TensorDict, env: RL4COEnvBase, **kwargs
)

Pre decoding hook. This method is called before the main decoding operation.

Source code in rl4co/utils/decoding.py
def pre_decoder_hook(self, td: TensorDict, env: RL4COEnvBase, **kwargs):
    if self.beam_width is None:
        self.beam_width = env.get_num_starts(td)
    assert self.beam_width > 1, "beam width must be larger than 1"

    # select start nodes. TODO: include first step in beam search as well
    if self.select_start_nodes_fn is not None:
        action = self.select_start_nodes_fn(td, env, self.beam_width)
    else:
        action = env.select_start_nodes(td, num_starts=self.beam_width)

    # Expand td to batch_size * beam_width
    td = batchify(td, self.beam_width)

    td.set("action", action)
    td = env.step(td)["next"]

    logprobs = torch.zeros_like(td["action_mask"], device=td.device)
    beam_parent = torch.zeros(logprobs.size(0), device=td.device, dtype=torch.int32)

    self.logprobs.append(logprobs)
    self.actions.append(action)
    self.parent_beam_logprobs = logprobs.gather(1, action[..., None])
    self.beam_path.append(beam_parent)

    return td, env, self.beam_width
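
A construction sketch (again assuming a TSPEnv with hypothetical constructor arguments); the per-step beam bookkeeping is handled internally by the strategy's _step:

from rl4co.envs import TSPEnv
from rl4co.utils.decoding import BeamSearch

env = TSPEnv(generator_params={"num_loc": 10})  # assumed constructor arguments
td = env.reset(batch_size=[2])

strategy = BeamSearch(beam_width=4)  # store_all_logp is forced to True internally
td, env, beam_width = strategy.pre_decoder_hook(td, env)
# td is expanded to batch_size * beam_width; decoding then proceeds with strategy.step(...)
# in a loop, and the best beam can be kept at the end when select_best=True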

get_log_likelihood

get_log_likelihood(
    logprobs,
    actions=None,
    mask=None,
    return_sum: bool = True,
)

Get log likelihood of selected actions. Note that mask is a boolean tensor where True means the value should be kept.

Parameters:

  • logprobs

    Log probabilities of actions from the model (batch_size, seq_len, action_dim).

  • actions

    Selected actions (batch_size, seq_len).

  • mask

    Action mask. 1 if feasible, 0 otherwise (so we keep if 1 as done in PyTorch).

  • return_sum (bool, default: True ) –

    Whether to return the sum of log probabilities or not. Defaults to True.

Source code in rl4co/utils/decoding.py
def get_log_likelihood(logprobs, actions=None, mask=None, return_sum: bool = True):
    """Get log likelihood of selected actions.
    Note that mask is a boolean tensor where True means the value should be kept.

    Args:
        logprobs: Log probabilities of actions from the model (batch_size, seq_len, action_dim).
        actions: Selected actions (batch_size, seq_len).
        mask: Action mask. 1 if feasible, 0 otherwise (so we keep if 1 as done in PyTorch).
        return_sum: Whether to return the sum of log probabilities or not. Defaults to True.
    """
    # Optional: select logp when logp.shape = (bs, dec_steps, N)
    if actions is not None and logprobs.dim() == 3:
        logprobs = logprobs.gather(-1, actions.unsqueeze(-1)).squeeze(-1)

    # Optional: mask out actions irrelevant to objective so they do not get reinforced
    if mask is not None:
        logprobs[~mask] = 0

    assert (
        logprobs > -1000
    ).data.all(), "Logprobs should not be -inf, check sampling procedure!"

    # Calculate log_likelihood
    if return_sum:
        return logprobs.sum(1)  # [batch]
    else:
        return logprobs  # [batch, decode_len]
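
A self-contained example on synthetic tensors, gathering the per-step log probability of each chosen action and summing over the sequence:

import torch
from rl4co.utils.decoding import get_log_likelihood

logprobs = torch.log_softmax(torch.randn(2, 5, 4), dim=-1)  # (batch_size, seq_len, action_dim)
actions = torch.randint(0, 4, (2, 5))                       # (batch_size, seq_len)
ll = get_log_likelihood(logprobs, actions)                  # summed over seq_len, shape [2]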

decode_logprobs

decode_logprobs(logprobs, mask, decode_type='sampling')

Decode log probabilities to select actions with mask. Note that mask is a boolean tensor where True means the value should be kept.

Source code in rl4co/utils/decoding.py
def decode_logprobs(logprobs, mask, decode_type="sampling"):
    """Decode log probabilities to select actions with mask.
    Note that mask is a boolean tensor where True means the value should be kept.
    """
    if "greedy" in decode_type:
        selected = DecodingStrategy.greedy(logprobs, mask)
    elif "sampling" in decode_type:
        selected = DecodingStrategy.sampling(logprobs, mask)
    else:
        assert False, "Unknown decode type: {}".format(decode_type)
    return selected
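
For one-off decoding outside the strategy classes, masked log probabilities can be decoded directly:

import torch
from rl4co.utils.decoding import decode_logprobs, process_logits

mask = torch.tensor([[True, True, False], [True, True, True]])
logprobs = process_logits(torch.randn(2, 3), mask)
greedy_actions = decode_logprobs(logprobs, mask, decode_type="greedy")
sampled_actions = decode_logprobs(logprobs, mask, decode_type="sampling")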

random_policy

random_policy(td)

Helper function to select a random action from available actions

Source code in rl4co/utils/decoding.py
def random_policy(td):
    """Helper function to select a random action from available actions"""
    action = torch.multinomial(td["action_mask"].float(), 1).squeeze(-1)
    td.set("action", action)
    return td

rollout

rollout(env, td, policy, max_steps: int = None)

Helper function to roll out a policy. Currently, TorchRL's env.rollout() does not allow stepping past environments that are already done; we need this helper because environments in a batch may complete at different steps.

Source code in rl4co/utils/decoding.py
def rollout(env, td, policy, max_steps: int = None):
    """Helper function to rollout a policy. Currently, TorchRL does not allow to step
    over envs when done with `env.rollout()`. We need this because for environments that complete at different steps.
    """

    max_steps = float("inf") if max_steps is None else max_steps
    actions = []
    steps = 0

    while not td["done"].all():
        td = policy(td)
        actions.append(td["action"])
        td = env.step(td)["next"]
        steps += 1
        if steps > max_steps:
            log.info("Max steps reached")
            break
    return (
        env.get_reward(td, torch.stack(actions, dim=1)),
        td,
        torch.stack(actions, dim=1),
    )
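
A short sketch combining the two helpers on an RL4CO environment (the TSPEnv constructor arguments are an assumption):

from rl4co.envs import TSPEnv
from rl4co.utils.decoding import random_policy, rollout

env = TSPEnv(generator_params={"num_loc": 10})  # assumed constructor arguments
td = env.reset(batch_size=[8])
reward, td, actions = rollout(env, td, random_policy)  # random feasible solutions and their rewards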

modify_logits_for_top_k_filtering

modify_logits_for_top_k_filtering(logits, top_k)

Set the logits for non-top-k values to -inf. Done out-of-place. Ref: https://github.com/togethercomputer/stripedhyena/blob/7e13f618027fea9625be1f2d2d94f9a361f6bd02/stripedhyena/sample.py#L6

Source code in rl4co/utils/decoding.py
def modify_logits_for_top_k_filtering(logits, top_k):
    """Set the logits for none top-k values to -inf. Done out-of-place.
    Ref: https://github.com/togethercomputer/stripedhyena/blob/7e13f618027fea9625be1f2d2d94f9a361f6bd02/stripedhyena/sample.py#L6
    """
    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
    return logits.masked_fill(indices_to_remove, float("-inf"))
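
For example, with top_k=2 only the two largest logits per row survive:

import torch
from rl4co.utils.decoding import modify_logits_for_top_k_filtering

logits = torch.tensor([[1.0, 3.0, 2.0, 0.5]])
filtered = modify_logits_for_top_k_filtering(logits, top_k=2)
# tensor([[-inf, 3., 2., -inf]])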

modify_logits_for_top_p_filtering

modify_logits_for_top_p_filtering(logits, top_p)

Set the logits for non-top-p values to -inf. Done out-of-place. Ref: https://github.com/togethercomputer/stripedhyena/blob/7e13f618027fea9625be1f2d2d94f9a361f6bd02/stripedhyena/sample.py#L14

Source code in rl4co/utils/decoding.py
def modify_logits_for_top_p_filtering(logits, top_p):
    """Set the logits for none top-p values to -inf. Done out-of-place.
    Ref: https://github.com/togethercomputer/stripedhyena/blob/7e13f618027fea9625be1f2d2d94f9a361f6bd02/stripedhyena/sample.py#L14
    """
    if top_p <= 0.0 or top_p >= 1.0:
        return logits

    # First sort and calculate cumulative sum of probabilities.
    sorted_logits, sorted_indices = torch.sort(logits, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

    # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)

    # Scatter sorted tensors to original indexing
    indices_to_remove = sorted_indices_to_remove.scatter(
        -1, sorted_indices, sorted_indices_to_remove
    )
    return logits.masked_fill(indices_to_remove, float("-inf"))
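
Similarly for nucleus filtering: tokens in the lower tail whose cumulative probability stays below 1 - top_p are removed:

import torch
from rl4co.utils.decoding import modify_logits_for_top_p_filtering

logits = torch.tensor([[2.0, 1.0, 0.0, -1.0]])
filtered = modify_logits_for_top_p_filtering(logits, top_p=0.9)
# tensor([[2., 1., 0., -inf]])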

process_logits

process_logits(
    logits: Tensor,
    mask: Tensor = None,
    temperature: float = 1.0,
    top_p: float = 0.0,
    top_k: int = 0,
    tanh_clipping: float = 0,
    mask_logits: bool = True,
)

Convert logits to log probabilities with additional features like temperature scaling, top-k and top-p sampling.

Note

We convert to log probabilities instead of probabilities to avoid numerical instability. This is because, roughly, softmax = exp(logits) / sum(exp(logits)) and log(softmax) = logits - log(sum(exp(logits))), and avoiding the division by the sum of exponentials can help with numerical stability. You may check the official PyTorch documentation (https://pytorch.org/docs/stable/generated/torch.nn.functional.log_softmax.html).

Parameters:

  • logits (Tensor) –

    Logits from the model (batch_size, num_actions).

  • mask (Tensor, default: None ) –

    Action mask. 1 if feasible, 0 otherwise (so we keep if 1 as done in PyTorch).

  • temperature (float, default: 1.0 ) –

    Temperature scaling. Higher values make the distribution more uniform (exploration), lower values make it more peaky (exploitation).

  • top_p (float, default: 0.0 ) –

    Top-p sampling, a.k.a. Nucleus Sampling (https://arxiv.org/abs/1904.09751). Remove tokens that have a cumulative probability less than the threshold 1 - top_p (lower tail of the distribution). If 0, do not perform.

  • top_k (int, default: 0 ) –

    Top-k sampling, i.e. restrict sampling to the top k logits. If 0, do not perform. Note that we only do filtering and do not return all the top-k logits here.

  • tanh_clipping (float, default: 0 ) –

    Tanh clipping (https://arxiv.org/abs/1611.09940).

  • mask_logits (bool, default: True ) –

    Whether to mask logits of infeasible actions.

Source code in rl4co/utils/decoding.py
def process_logits(
    logits: torch.Tensor,
    mask: torch.Tensor = None,
    temperature: float = 1.0,
    top_p: float = 0.0,
    top_k: int = 0,
    tanh_clipping: float = 0,
    mask_logits: bool = True,
):
    """Convert logits to log probabilities with additional features like temperature scaling, top-k and top-p sampling.

    Note:
        We convert to log probabilities instead of probabilities to avoid numerical instability.
        This is because, roughly, softmax = exp(logits) / sum(exp(logits)) and log(softmax) = logits - log(sum(exp(logits))),
        and avoiding the division by the sum of exponentials can help with numerical stability.
        You may check the [official PyTorch documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.log_softmax.html).

    Args:
        logits: Logits from the model (batch_size, num_actions).
        mask: Action mask. 1 if feasible, 0 otherwise (so we keep if 1 as done in PyTorch).
        temperature: Temperature scaling. Higher values make the distribution more uniform (exploration),
            lower values make it more peaky (exploitation).
        top_p: Top-p sampling, a.k.a. Nucleus Sampling (https://arxiv.org/abs/1904.09751). Remove tokens that have a cumulative probability
            less than the threshold 1 - top_p (lower tail of the distribution). If 0, do not perform.
        top_k: Top-k sampling, i.e. restrict sampling to the top k logits. If 0, do not perform. Note that we only do filtering and
            do not return all the top-k logits here.
        tanh_clipping: Tanh clipping (https://arxiv.org/abs/1611.09940).
        mask_logits: Whether to mask logits of infeasible actions.
    """

    # Tanh clipping from Bello et al. 2016
    if tanh_clipping > 0:
        logits = torch.tanh(logits) * tanh_clipping

    # In RL, we want to mask the logits to prevent the agent from selecting infeasible actions
    if mask_logits:
        assert mask is not None, "mask must be provided if mask_logits is True"
        logits[~mask] = float("-inf")

    logits = logits / temperature  # temperature scaling

    if top_k > 0:
        top_k = min(top_k, logits.size(-1))  # safety check
        logits = modify_logits_for_top_k_filtering(logits, top_k)

    if top_p > 0:
        assert top_p <= 1.0, "top-p should be in (0, 1]."
        logits = modify_logits_for_top_p_filtering(logits, top_p)

    # Compute log probabilities
    return F.log_softmax(logits, dim=-1)
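
Putting the pieces together on synthetic logits: masking, temperature scaling, and top-k filtering before the final log-softmax:

import torch
from rl4co.utils.decoding import process_logits

logits = torch.randn(4, 6)
mask = torch.ones(4, 6, dtype=torch.bool)
mask[:, 0] = False  # pretend the first action is infeasible

logprobs = process_logits(logits, mask, temperature=0.5, top_k=3)
# logprobs[:, 0] is -inf and at most 3 entries per row are finite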