Decoding Strategies

Classes:

  • DecodingStrategy
  • Greedy
  • Sampling
  • Evaluate
  • BeamSearch

Functions:

  • get_log_likelihood
  • decode_logprobs
  • random_policy
  • rollout
  • modify_logits_for_top_k_filtering
  • modify_logits_for_top_p_filtering
  • process_logits

DecodingStrategy

DecodingStrategy(
    temperature: float = 1.0,
    top_p: float = 0.0,
    top_k: int = 0,
    mask_logits: bool = True,
    tanh_clipping: float = 0,
    num_samples: Optional[int] = None,
    multisample: bool = False,
    num_starts: Optional[int] = None,
    multistart: bool = False,
    select_start_nodes_fn: Optional[callable] = None,
    improvement_method_mode: bool = False,
    select_best: bool = False,
    store_all_logp: bool = False,
    **kwargs
)

Base class for decoding strategies. Subclasses should implement the `_step` method. Includes hooks for operations before and after the main decoding loop.

Parameters:

  • temperature (float, default: 1.0 ) –

    Temperature scaling. Higher values make the distribution more uniform (exploration), lower values make it more peaky (exploitation). Defaults to 1.0.

  • top_p (float, default: 0.0 ) –

    Top-p sampling, a.k.a. Nucleus Sampling (https://arxiv.org/abs/1904.09751). Defaults to 0.0.

  • top_k (int, default: 0 ) –

    Top-k sampling, i.e. restrict sampling to the top k logits. If 0, do not perform. Defaults to 0.

  • mask_logits (bool, default: True ) –

    Whether to mask logits of infeasible actions. Defaults to True.

  • tanh_clipping (float, default: 0 ) –

    Tanh clipping (https://arxiv.org/abs/1611.09940). Defaults to 0.

  • multisample (bool, default: False ) –

    Whether to use sampling decoding. Defaults to False.

  • num_samples (Optional[int], default: None ) –

    Number of samples to evaluate during decoding. Defaults to None.

  • num_starts (Optional[int], default: None ) –

    Number of starts for multistart decoding. Defaults to None.

  • multistart (bool, default: False ) –

    Whether to use multistart decoding. Defaults to False.

  • select_start_nodes_fn (Optional[callable], default: None ) –

    Function to select start nodes for multistart decoding. Defaults to None.

  • improvement_method_mode (bool, default: False ) –

    Whether to use improvement method mode. Defaults to False.

  • select_best (bool, default: False ) –

    Whether to select the best action or return all. Defaults to False.

  • store_all_logp (bool, default: False ) –

    Whether to store all log probabilities. Defaults to False. If True, logprobs will be stored for all actions. Note that this will increase memory usage.

Methods:

  • pre_decoder_hook

    Pre decoding hook. This method is called before the main decoding operation.

  • step

    Main decoding operation. This method should be called in a loop until all sequences are done.

  • greedy

    Select the action with the highest probability.

  • sampling

    Sample an action with a multinomial distribution given by the log probabilities.

Source code in rl4co/utils/decoding.py
def __init__(
    self,
    temperature: float = 1.0,
    top_p: float = 0.0,
    top_k: int = 0,
    mask_logits: bool = True,
    tanh_clipping: float = 0,
    num_samples: Optional[int] = None,
    multisample: bool = False,
    num_starts: Optional[int] = None,
    multistart: bool = False,
    select_start_nodes_fn: Optional[callable] = None,
    improvement_method_mode: bool = False,
    select_best: bool = False,
    store_all_logp: bool = False,
    **kwargs,
) -> None:
    self.temperature = temperature
    self.top_p = top_p
    self.top_k = top_k
    self.mask_logits = mask_logits
    self.tanh_clipping = tanh_clipping
    # check if multistart (POMO) and multisample flags
    assert not (
        multistart and multisample
    ), "Using both multistart and multisample is not supported"
    if num_samples and num_starts:
        assert not (
            num_samples > 1 and num_starts > 1
        ), f"num_samples={num_samples} and num_starts={num_starts} are both > 1"
    if num_samples is not None:
        multisample = True if num_samples > 1 else False
    if num_starts is not None:
        multistart = True if num_starts > 1 else False
    self.multistart = multistart
    self.multisample = multisample
    # num_starts is used for both multistart and multisample;
    # its purpose is to run multiple rollouts of the same instance in parallel
    self.num_starts = num_starts if multistart else num_samples

    self.select_start_nodes_fn = select_start_nodes_fn
    self.improvement_method_mode = improvement_method_mode
    self.select_best = select_best
    self.store_all_logp = store_all_logp
    # initialize buffers
    self.actions = []
    self.logprobs = []
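
As a concrete illustration, here is a minimal sketch of how a subclass might implement `_step` (the class name `MyGreedy` is hypothetical; the contract of returning the log probabilities, the selected action, and the TensorDict follows the `step` method documented below):

class MyGreedy(DecodingStrategy):
    name = "greedy"

    def _step(self, logprobs, mask, td, action=None, **kwargs):
        # Pick the feasible action with the highest log probability
        selected = self.greedy(logprobs, mask)
        return logprobs, selected, td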

pre_decoder_hook

pre_decoder_hook(
    td: TensorDict, env: RL4COEnvBase, action: Tensor = None
)

Pre decoding hook. This method is called before the main decoding operation.

Source code in rl4co/utils/decoding.py
def pre_decoder_hook(
    self, td: TensorDict, env: RL4COEnvBase, action: torch.Tensor = None
):
    """Pre decoding hook. This method is called before the main decoding operation."""

    # Multi-start decoding. If num_starts is None, we use the number of actions in the action mask
    if self.multistart or self.multisample:
        if self.num_starts is None:
            self.num_starts = env.get_num_starts(td)
            if self.multisample:
                log.warn(
                    f"num_starts is not provided for sampling, using num_starts={self.num_starts}"
                )
    else:
        if self.num_starts is not None:
            if self.num_starts >= 1:
                log.warn(
                    f"num_starts={self.num_starts} is ignored for decode_type={self.name}"
                )

        self.num_starts = 0

    # Multi-start decoding: first action is chosen by ad-hoc node selection
    if self.num_starts >= 1:
        if self.multistart:
            if action is None:  # if action is provided, we use it as the first action
                if self.select_start_nodes_fn is not None:
                    action = self.select_start_nodes_fn(td, env, self.num_starts)
                else:
                    action = env.select_start_nodes(td, num_starts=self.num_starts)

            # Expand td to batch_size * num_starts
            td = batchify(td, self.num_starts)

            td.set("action", action)
            td = env.step(td)["next"]
            # first logprobs is 0, so p = logprobs.exp() = 1
            if self.store_all_logp:
                logprobs = torch.zeros_like(td["action_mask"])  # [B, N]
            else:
                logprobs = torch.zeros_like(action, device=td.device)  # [B]

            self.logprobs.append(logprobs)
            self.actions.append(action)
        else:
            # Expand td to batch_size * num_samples
            td = batchify(td, self.num_starts)

    return td, env, self.num_starts
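
For intuition, the batch expansion that `batchify` performs can be pictured with plain tensors. A toy sketch (the tiling layout is an assumption about `batchify`, shown only to convey the idea of running several rollouts of each instance in parallel):

import torch

x = torch.tensor([[0.1, 0.2], [0.3, 0.4]])  # batch of 2 instances
num_starts = 3
# tile the batch so each instance is rolled out num_starts times: shape [6, 2]
expanded = x.repeat(num_starts, 1)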

step

step(
    logits: Tensor,
    mask: Tensor,
    td: TensorDict = None,
    action: Tensor = None,
    **kwargs
) -> TensorDict

Main decoding operation. This method should be called in a loop until all sequences are done.

Parameters:

  • logits (Tensor) –

    Logits from the model.

  • mask (Tensor) –

    Action mask. 1 if feasible, 0 otherwise (following the PyTorch convention that 1/True means keep).

  • td (TensorDict, default: None ) –

    TensorDict containing the current state of the environment.

  • action (Tensor, default: None ) –

    Optional action to use, e.g. for evaluating log probabilities.

Source code in rl4co/utils/decoding.py
def step(
    self,
    logits: torch.Tensor,
    mask: torch.Tensor,
    td: TensorDict = None,
    action: torch.Tensor = None,
    **kwargs,
) -> TensorDict:
    """Main decoding operation. This method should be called in a loop until all sequences are done.

    Args:
        logits: Logits from the model.
        mask: Action mask. 1 if feasible, 0 otherwise (following the PyTorch convention that 1/True means keep).
        td: TensorDict containing the current state of the environment.
        action: Optional action to use, e.g. for evaluating log probabilities.
    """
    if not self.mask_logits:  # ignore the action mask entirely when mask_logits is False
        mask = None

    logprobs = process_logits(
        logits,
        mask,
        temperature=self.temperature,
        top_p=self.top_p,
        top_k=self.top_k,
        tanh_clipping=self.tanh_clipping,
        mask_logits=self.mask_logits,
    )
    logprobs, selected_action, td = self._step(
        logprobs, mask, td, action=action, **kwargs
    )

    # directly return for improvement methods, since the action for improvement methods is finalized in its own policy
    if self.improvement_method_mode:
        return logprobs, selected_action
    # for others
    if not self.store_all_logp:
        logprobs = gather_by_index(logprobs, selected_action, dim=1)
    td.set("action", selected_action)
    self.actions.append(selected_action)
    self.logprobs.append(logprobs)
    return td
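
A hedged sketch of the decoding loop that `step` is designed for; `td`, `env`, and `policy_logits_fn` are placeholders rather than rl4co API:

strategy = Sampling(temperature=0.7)
td, env, num_starts = strategy.pre_decoder_hook(td, env)
while not td["done"].all():
    logits, mask = policy_logits_fn(td)   # model forward pass (placeholder)
    td = strategy.step(logits, mask, td)  # sets td["action"], stores logprobs
    td = env.step(td)["next"]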

greedy staticmethod

greedy(logprobs, mask=None)

Select the action with the highest probability.

Source code in rl4co/utils/decoding.py
@staticmethod
def greedy(logprobs, mask=None):
    """Select the action with the highest probability."""
    # [BS], [BS]
    selected = logprobs.argmax(dim=-1)
    if mask is not None:
        assert (
            not (~mask).gather(1, selected.unsqueeze(-1)).data.any()
        ), "infeasible action selected"

    return selected
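
A quick worked example of `greedy` on a batch of one:

import torch

logprobs = torch.log(torch.tensor([[0.1, 0.6, 0.3]]))
mask = torch.tensor([[True, True, False]])  # action 2 is infeasible
selected = DecodingStrategy.greedy(logprobs, mask)  # tensor([1])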

sampling staticmethod

sampling(logprobs, mask=None)

Sample an action with a multinomial distribution given by the log probabilities.

Source code in rl4co/utils/decoding.py
@staticmethod
def sampling(logprobs, mask=None):
    """Sample an action with a multinomial distribution given by the log probabilities."""
    probs = logprobs.exp()
    selected = torch.multinomial(probs, 1).squeeze(1)

    if mask is not None:
        while (~mask).gather(1, selected.unsqueeze(-1)).data.any():
            log.info("Sampled bad values, resampling!")
            selected = probs.multinomial(1).squeeze(1)
        assert (
            not (~mask).gather(1, selected.unsqueeze(-1)).data.any()
        ), "infeasible action selected"

    return selected
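
And a matching example for `sampling`, where actions are drawn in proportion to exp(logprobs):

import torch

torch.manual_seed(0)
logprobs = torch.log(torch.tensor([[0.05, 0.15, 0.80]]))
selected = DecodingStrategy.sampling(logprobs)  # most often tensor([2])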

Greedy

Greedy(
    temperature: float = 1.0,
    top_p: float = 0.0,
    top_k: int = 0,
    mask_logits: bool = True,
    tanh_clipping: float = 0,
    num_samples: Optional[int] = None,
    multisample: bool = False,
    num_starts: Optional[int] = None,
    multistart: bool = False,
    select_start_nodes_fn: Optional[callable] = None,
    improvement_method_mode: bool = False,
    select_best: bool = False,
    store_all_logp: bool = False,
    **kwargs
)

Bases: DecodingStrategy

Sampling

Sampling(
    temperature: float = 1.0,
    top_p: float = 0.0,
    top_k: int = 0,
    mask_logits: bool = True,
    tanh_clipping: float = 0,
    num_samples: Optional[int] = None,
    multisample: bool = False,
    num_starts: Optional[int] = None,
    multistart: bool = False,
    select_start_nodes_fn: Optional[callable] = None,
    improvement_method_mode: bool = False,
    select_best: bool = False,
    store_all_logp: bool = False,
    **kwargs
)

Bases: DecodingStrategy

Evaluate

Evaluate(
    temperature: float = 1.0,
    top_p: float = 0.0,
    top_k: int = 0,
    mask_logits: bool = True,
    tanh_clipping: float = 0,
    num_samples: Optional[int] = None,
    multisample: bool = False,
    num_starts: Optional[int] = None,
    multistart: bool = False,
    select_start_nodes_fn: Optional[callable] = None,
    improvement_method_mode: bool = False,
    select_best: bool = False,
    store_all_logp: bool = False,
    **kwargs
)

Bases: DecodingStrategy

BeamSearch

BeamSearch(beam_width=None, select_best=True, **kwargs)

Bases: DecodingStrategy

Methods:

  • pre_decoder_hook

    Pre decoding hook. This method is called before the main decoding operation.

Source code in rl4co/utils/decoding.py
def __init__(self, beam_width=None, select_best=True, **kwargs) -> None:
    # TODO do we really need all logp in beam search?
    kwargs["store_all_logp"] = True
    super().__init__(**kwargs)
    self.beam_width = beam_width
    self.select_best = select_best
    self.parent_beam_logprobs = None
    self.beam_path = []

pre_decoder_hook

pre_decoder_hook(
    td: TensorDict, env: RL4COEnvBase, **kwargs
)

Pre decoding hook. This method is called before the main decoding operation.

Source code in rl4co/utils/decoding.py
def pre_decoder_hook(self, td: TensorDict, env: RL4COEnvBase, **kwargs):
    if self.beam_width is None:
        self.beam_width = env.get_num_starts(td)
    assert self.beam_width > 1, "beam width must be larger than 1"

    # select start nodes. TODO: include first step in beam search as well
    if self.select_start_nodes_fn is not None:
        action = self.select_start_nodes_fn(td, env, self.beam_width)
    else:
        action = env.select_start_nodes(td, num_starts=self.beam_width)

    # Expand td to batch_size * beam_width
    td = batchify(td, self.beam_width)

    td.set("action", action)
    td = env.step(td)["next"]

    logprobs = torch.zeros_like(td["action_mask"], device=td.device)
    beam_parent = torch.zeros(logprobs.size(0), device=td.device, dtype=torch.int32)

    self.logprobs.append(logprobs)
    self.actions.append(action)
    self.parent_beam_logprobs = logprobs.gather(1, action[..., None])
    self.beam_path.append(beam_parent)

    return td, env, self.beam_width
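
A hedged usage sketch (again with `td` and `env` as placeholders): construct the strategy and let the hook expand the batch to beam_width copies per instance before the main decoding loop:

strategy = BeamSearch(beam_width=4)
td, env, beam_width = strategy.pre_decoder_hook(td, env)
# td now holds batch_size * 4 rows; decoding proceeds via strategy.step(...)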

get_log_likelihood

get_log_likelihood(
    logprobs,
    actions=None,
    mask=None,
    return_sum: bool = True,
)

Get log likelihood of selected actions. Note that mask is a boolean tensor where True means the value should be kept.

Parameters:

  • logprobs

    Log probabilities of actions from the model (batch_size, seq_len, action_dim).

  • actions

    Selected actions (batch_size, seq_len).

  • mask

    Action mask. 1 if feasible, 0 otherwise (following the PyTorch convention that 1/True means keep).

  • return_sum (bool, default: True ) –

    Whether to return the sum of log probabilities or not. Defaults to True.

Source code in rl4co/utils/decoding.py
def get_log_likelihood(logprobs, actions=None, mask=None, return_sum: bool = True):
    """Get log likelihood of selected actions.
    Note that mask is a boolean tensor where True means the value should be kept.

    Args:
        logprobs: Log probabilities of actions from the model (batch_size, seq_len, action_dim).
        actions: Selected actions (batch_size, seq_len).
        mask: Action mask. 1 if feasible, 0 otherwise (following the PyTorch convention that 1/True means keep).
        return_sum: Whether to return the sum of log probabilities or not. Defaults to True.
    """
    # Optional: select logp when logp.shape = (bs, dec_steps, N)
    if actions is not None and logprobs.dim() == 3:
        logprobs = logprobs.gather(-1, actions.unsqueeze(-1)).squeeze(-1)

    # Optional: mask out actions irrelevant to objective so they do not get reinforced
    if mask is not None:
        logprobs[~mask] = 0

    assert (
        logprobs > -1000
    ).data.all(), "Logprobs should not be -inf, check sampling procedure!"

    # Calculate log_likelihood
    if return_sum:
        return logprobs.sum(1)  # [batch]
    else:
        return logprobs  # [batch, decode_len]
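
A small worked example: with per-step log probabilities already gathered (2-D input), the function simply sums over the sequence dimension:

import torch

logprobs = torch.log(torch.tensor([[0.5, 0.25]]))  # [batch=1, steps=2]
ll = get_log_likelihood(logprobs)  # log(0.5) + log(0.25) ≈ -2.079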

decode_logprobs

decode_logprobs(logprobs, mask, decode_type='sampling')

Decode log probabilities to select actions with mask. Note that mask is a boolean tensor where True means the value should be kept.

Source code in rl4co/utils/decoding.py
def decode_logprobs(logprobs, mask, decode_type="sampling"):
    """Decode log probabilities to select actions with mask.
    Note that mask is a boolean tensor where True means the value should be kept.
    """
    if "greedy" in decode_type:
        selected = DecodingStrategy.greedy(logprobs, mask)
    elif "sampling" in decode_type:
        selected = DecodingStrategy.sampling(logprobs, mask)
    else:
        assert False, "Unknown decode type: {}".format(decode_type)
    return selected
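
For example:

import torch

logprobs = torch.log(torch.tensor([[0.2, 0.8]]))
mask = torch.tensor([[True, True]])
selected = decode_logprobs(logprobs, mask, decode_type="greedy")  # tensor([1])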

random_policy

random_policy(td)

Helper function to select a random action from available actions

Source code in rl4co/utils/decoding.py
def random_policy(td):
    """Helper function to select a random action from available actions"""
    action = torch.multinomial(td["action_mask"].float(), 1).squeeze(-1)
    td.set("action", action)
    return td

rollout

rollout(env, td, policy, max_steps: int = None)

Helper function to roll out a policy. We need this because TorchRL's env.rollout() currently does not allow stepping past environments that are already done, and environments in a batch may complete at different steps.

Source code in rl4co/utils/decoding.py
def rollout(env, td, policy, max_steps: int = None):
    """Helper function to rollout a policy. Currently, TorchRL does not allow to step
    over envs when done with `env.rollout()`. We need this because for environments that complete at different steps.
    """

    max_steps = float("inf") if max_steps is None else max_steps
    actions = []
    steps = 0

    while not td["done"].all():
        td = policy(td)
        actions.append(td["action"])
        td = env.step(td)["next"]
        steps += 1
        if steps > max_steps:
            log.info("Max steps reached")
            break
    return (
        env.get_reward(td, torch.stack(actions, dim=1)),
        td,
        torch.stack(actions, dim=1),
    )
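
A hedged end-to-end sketch combining `random_policy` with `rollout`; `TSPEnv` and its constructor arguments are illustrative and may differ across rl4co versions:

from rl4co.envs import TSPEnv

env = TSPEnv(generator_params={"num_loc": 10})  # assumed constructor signature
td = env.reset(batch_size=[4])
reward, td, actions = rollout(env, td, random_policy)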

modify_logits_for_top_k_filtering

modify_logits_for_top_k_filtering(logits, top_k)

Set the logits for non-top-k values to -inf. Done out-of-place. Ref: https://github.com/togethercomputer/stripedhyena/blob/7e13f618027fea9625be1f2d2d94f9a361f6bd02/stripedhyena/sample.py#L6

Source code in rl4co/utils/decoding.py
def modify_logits_for_top_k_filtering(logits, top_k):
    """Set the logits for none top-k values to -inf. Done out-of-place.
    Ref: https://github.com/togethercomputer/stripedhyena/blob/7e13f618027fea9625be1f2d2d94f9a361f6bd02/stripedhyena/sample.py#L6
    """
    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
    return logits.masked_fill(indices_to_remove, float("-inf"))
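
A quick worked example: with top_k=2, only the two largest logits survive:

import torch

logits = torch.tensor([[2.0, 0.5, 1.0, -1.0]])
filtered = modify_logits_for_top_k_filtering(logits, top_k=2)
# tensor([[2., -inf, 1., -inf]])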

modify_logits_for_top_p_filtering

modify_logits_for_top_p_filtering(logits, top_p)

Set the logits for non-top-p values to -inf. Done out-of-place. Ref: https://github.com/togethercomputer/stripedhyena/blob/7e13f618027fea9625be1f2d2d94f9a361f6bd02/stripedhyena/sample.py#L14

Source code in rl4co/utils/decoding.py
def modify_logits_for_top_p_filtering(logits, top_p):
    """Set the logits for none top-p values to -inf. Done out-of-place.
    Ref: https://github.com/togethercomputer/stripedhyena/blob/7e13f618027fea9625be1f2d2d94f9a361f6bd02/stripedhyena/sample.py#L14
    """
    if top_p <= 0.0 or top_p >= 1.0:
        return logits

    # First sort and calculate cumulative sum of probabilities.
    sorted_logits, sorted_indices = torch.sort(logits, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

    # Remove tokens in the lower tail whose cumulative probability is <= 1 - top_p
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)

    # Scatter sorted tensors to original indexing
    indices_to_remove = sorted_indices_to_remove.scatter(
        -1, sorted_indices, sorted_indices_to_remove
    )
    return logits.masked_fill(indices_to_remove, float("-inf"))
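
A quick worked example: with probabilities [0.1, 0.2, 0.7] and top_p=0.9, the lower tail with cumulative probability <= 1 - 0.9 = 0.1 is removed:

import torch

logits = torch.log(torch.tensor([[0.1, 0.2, 0.7]]))
filtered = modify_logits_for_top_p_filtering(logits, top_p=0.9)
# token 0 (probability 0.1) becomes -inf; tokens 1 and 2 keep their logits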

process_logits

process_logits(
    logits: Tensor,
    mask: Tensor = None,
    temperature: float = 1.0,
    top_p: float = 0.0,
    top_k: int = 0,
    tanh_clipping: float = 0,
    mask_logits: bool = True,
)

Convert logits to log probabilities with additional features like temperature scaling, top-k and top-p sampling.

Note

We convert to log probabilities instead of probabilities to avoid numerical instability. This is because, roughly, softmax = exp(logits) / sum(exp(logits)) and log(softmax) = logits - log(sum(exp(logits))), and avoiding the division by the sum of exponentials can help with numerical stability. You may check the official PyTorch documentation.

Parameters:

  • logits (Tensor) –

    Logits from the model (batch_size, num_actions).

  • mask (Tensor, default: None ) –

    Action mask. 1 if feasible, 0 otherwise (following the PyTorch convention that 1/True means keep).

  • temperature (float, default: 1.0 ) –

    Temperature scaling. Higher values make the distribution more uniform (exploration), lower values make it more peaky (exploitation).

  • top_p (float, default: 0.0 ) –

    Top-p sampling, a.k.a. Nucleus Sampling (https://arxiv.org/abs/1904.09751). Remove tokens that have a cumulative probability less than the threshold 1 - top_p (lower tail of the distribution). If 0, do not perform.

  • top_k (int, default: 0 ) –

    Top-k sampling, i.e. restrict sampling to the top k logits. If 0, do not perform. Note that we only do filtering and do not return all the top-k logits here.

  • tanh_clipping (float, default: 0 ) –

    Tanh clipping (https://arxiv.org/abs/1611.09940).

  • mask_logits (bool, default: True ) –

    Whether to mask logits of infeasible actions.

Source code in rl4co/utils/decoding.py
def process_logits(
    logits: torch.Tensor,
    mask: torch.Tensor = None,
    temperature: float = 1.0,
    top_p: float = 0.0,
    top_k: int = 0,
    tanh_clipping: float = 0,
    mask_logits: bool = True,
):
    """Convert logits to log probabilities with additional features like temperature scaling, top-k and top-p sampling.

    Note:
        We convert to log probabilities instead of probabilities to avoid numerical instability.
        This is because, roughly, softmax = exp(logits) / sum(exp(logits)) and log(softmax) = logits - log(sum(exp(logits))),
        and avoiding the division by the sum of exponentials can help with numerical stability.
        You may check the [official PyTorch documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.log_softmax.html).

    Args:
        logits: Logits from the model (batch_size, num_actions).
        mask: Action mask. 1 if feasible, 0 otherwise (following the PyTorch convention that 1/True means keep).
        temperature: Temperature scaling. Higher values make the distribution more uniform (exploration),
            lower values make it more peaky (exploitation).
        top_p: Top-p sampling, a.k.a. Nucleus Sampling (https://arxiv.org/abs/1904.09751). Remove tokens that have a cumulative probability
            less than the threshold 1 - top_p (lower tail of the distribution). If 0, do not perform.
        top_k: Top-k sampling, i.e. restrict sampling to the top k logits. If 0, do not perform. Note that we only do filtering and
            do not return all the top-k logits here.
        tanh_clipping: Tanh clipping (https://arxiv.org/abs/1611.09940).
        mask_logits: Whether to mask logits of infeasible actions.
    """

    # Tanh clipping from Bello et al. 2016
    if tanh_clipping > 0:
        logits = torch.tanh(logits) * tanh_clipping

    # In RL, we want to mask the logits to prevent the agent from selecting infeasible actions
    if mask_logits:
        assert mask is not None, "mask must be provided if mask_logits is True"
        logits[~mask] = float("-inf")

    logits = logits / temperature  # temperature scaling

    if top_k > 0:
        top_k = min(top_k, logits.size(-1))  # safety check
        logits = modify_logits_for_top_k_filtering(logits, top_k)

    if top_p > 0:
        assert top_p <= 1.0, "top-p should be in (0, 1]."
        logits = modify_logits_for_top_p_filtering(logits, top_p)

    # Compute log probabilities
    return F.log_softmax(logits, dim=-1)
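
Putting it together, a short end-to-end example: mask an infeasible action, sharpen with temperature, restrict to the top-2 logits, then sample:

import torch

logits = torch.tensor([[1.0, 3.0, 2.0, 0.5]])
mask = torch.tensor([[True, True, True, False]])  # action 3 infeasible
logprobs = process_logits(logits, mask, temperature=0.5, top_k=2)
# only actions 1 and 2 keep finite log probability; sample among them:
action = logprobs.exp().multinomial(1)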