Skip to content

REINFORCE

REINFORCE

REINFORCE(
    env: RL4COEnvBase,
    policy: Module,
    baseline: REINFORCEBaseline | str = "rollout",
    baseline_kwargs: dict = {},
    reward_scale: str = None,
    **kwargs
)

Bases: RL4COLitModule

REINFORCE algorithm, also known as policy gradients. See superclass RL4COLitModule for more details.

Parameters:

  • env (RL4COEnvBase) –

    Environment to use for the algorithm

  • policy (Module) –

    Policy to use for the algorithm

  • baseline (REINFORCEBaseline | str, default: 'rollout' ) –

    REINFORCE baseline

  • baseline_kwargs (dict, default: {} ) –

    Keyword arguments for baseline. Ignored if baseline is not a string

  • **kwargs

    Keyword arguments passed to the superclass

Methods:

Source code in rl4co/models/rl/reinforce/reinforce.py
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
def __init__(
    self,
    env: RL4COEnvBase,
    policy: nn.Module,
    baseline: REINFORCEBaseline | str = "rollout",
    baseline_kwargs: dict = {},
    reward_scale: str = None,
    **kwargs,
):
    """Set up the superclass state, the baseline and the advantage scaler.

    Args:
        env: Environment forwarded to the superclass constructor
        policy: Policy forwarded to the superclass constructor
        baseline: Baseline instance, or a name resolved with `get_reinforce_baseline`
        baseline_kwargs: Keyword arguments used when `baseline` is given by name.
            A warning is logged if they are non-empty and `baseline` is not a string
        reward_scale: Value handed to `RewardScaler` to build the advantage scaler
        **kwargs: Extra keyword arguments forwarded to the superclass constructor
    """
    super().__init__(env, policy, **kwargs)

    # Called before anything below touches the constructor arguments
    self.save_hyperparameters(logger=False)

    if baseline == "critic":
        log.warning(
            "Using critic as baseline. If you want more granular support, use the A2C module instead."
        )

    if isinstance(baseline, str):
        # Baseline requested by name: build it from the name and its kwargs
        self.baseline = get_reinforce_baseline(baseline, **baseline_kwargs)
    else:
        # Ready-made baseline object: the kwargs have nowhere to go
        if baseline_kwargs != {}:
            log.warning("baseline_kwargs is ignored when baseline is not a string")
        self.baseline = baseline
    self.advantage_scaler = RewardScaler(reward_scale)

calculate_loss

calculate_loss(
    td: TensorDict,
    batch: TensorDict,
    policy_out: dict,
    reward: Optional[Tensor] = None,
    log_likelihood: Optional[Tensor] = None,
)

Calculate loss for REINFORCE algorithm.

Parameters:

  • td (TensorDict) –

    TensorDict containing the current state of the environment

  • batch (TensorDict) –

    Batch of data. This is used to get the extra loss terms, e.g., REINFORCE baseline

  • policy_out (dict) –

    Output of the policy network

  • reward (Optional[Tensor], default: None ) –

    Reward tensor. If None, it is taken from policy_out

  • log_likelihood (Optional[Tensor], default: None ) –

    Log-likelihood tensor. If None, it is taken from policy_out

Source code in rl4co/models/rl/reinforce/reinforce.py
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
def calculate_loss(
    self,
    td: TensorDict,
    batch: TensorDict,
    policy_out: dict,
    reward: Optional[torch.Tensor] = None,
    log_likelihood: Optional[torch.Tensor] = None,
):
    """Compute the REINFORCE loss and store it, with its components, in `policy_out`.

    Args:
        td: TensorDict containing the current state of the environment
        batch: Batch of data. Its optional `"extra"` entry, when present, is used
            directly as the baseline value instead of evaluating the baseline
        policy_out: Output of the policy network; updated in place and returned
        reward: Reward tensor. If None, it is taken from `policy_out`
        log_likelihood: Log-likelihood tensor. If None, it is taken from `policy_out`

    Returns:
        `policy_out`, extended with the keys `loss`, `reinforce_loss`, `bl_loss`
        and `bl_val`.
    """
    # Precomputed baseline values may travel with the batch under "extra"
    precomputed = batch.get("extra", None)

    if reward is None:
        reward = policy_out["reward"]
    if log_likelihood is None:
        log_likelihood = policy_out["log_likelihood"]

    # Baseline value and its loss term
    if precomputed is None:
        bl_val, bl_loss = self.baseline.eval(td, reward, self.env)
    else:
        # Values supplied with the batch carry no extra loss term
        bl_val, bl_loss = precomputed, 0

    # advantage = reward - baseline, then scaled
    scaled_advantage = self.advantage_scaler(reward - bl_val)
    reinforce_loss = -(scaled_advantage * log_likelihood).mean()

    policy_out["loss"] = reinforce_loss + bl_loss
    policy_out["reinforce_loss"] = reinforce_loss
    policy_out["bl_loss"] = bl_loss
    policy_out["bl_val"] = bl_val
    return policy_out

on_train_epoch_end

on_train_epoch_end()

Callback for end of training epoch: we evaluate the baseline

Source code in rl4co/models/rl/reinforce/reinforce.py
127
128
129
130
131
132
133
134
135
136
137
138
def on_train_epoch_end(self):
    """End-of-training-epoch hook: run the baseline's epoch callback, then the superclass hook."""
    run_callback = self.baseline.epoch_callback
    current_policy = self.policy
    callback_kwargs = dict(
        env=self.env,
        batch_size=self.val_batch_size,
        device=get_lightning_device(self),
        epoch=self.current_epoch,
        dataset_size=self.data_cfg["val_data_size"],
    )
    run_callback(current_policy, **callback_kwargs)
    # The superclass hook must still run so that the dataset gets reset
    super().on_train_epoch_end()

wrap_dataset

wrap_dataset(dataset)

Wrap dataset from baseline evaluation. Used in greedy rollout baseline

Source code in rl4co/models/rl/reinforce/reinforce.py
140
141
142
143
144
145
146
147
def wrap_dataset(self, dataset):
    """Delegate dataset wrapping to the baseline, passing the env, validation batch size and device."""
    wrap = self.baseline.wrap_dataset
    env = self.env
    batch_size = self.val_batch_size
    device = get_lightning_device(self)
    return wrap(dataset, env, batch_size=batch_size, device=device)

set_decode_type_multistart

set_decode_type_multistart(phase: str)

Set decode type to multistart for train, val and test in policy. For example, if the decode type is greedy, it will be set to multistart_greedy.

Parameters:

  • phase (str) –

    Phase to set decode type for. Must be one of train, val or test.

Source code in rl4co/models/rl/reinforce/reinforce.py
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
def set_decode_type_multistart(self, phase: str):
    """Prefix the policy's decode type for `phase` with `multistart_`.

    For example, `greedy` becomes `multistart_greedy`. A decode type that already
    contains `multistart` is left untouched, and a decode type of None is reported
    through the logger without modifying the policy.

    Args:
        phase: Phase to set decode type for. Must be one of `train`, `val` or `test`.
    """
    name = f"{phase}_decode_type"
    current = getattr(self.policy, name)
    if current is None:
        # Nothing to prefix: report and leave the policy unchanged
        log.error(f"Decode type for {phase} is None. Cannot prepend `multistart_`.")
        return
    if "multistart" not in current:
        setattr(self.policy, name, f"multistart_{current}")

load_from_checkpoint classmethod

load_from_checkpoint(
    checkpoint_path: _PATH | IO,
    map_location: _MAP_LOCATION_TYPE = None,
    hparams_file: Optional[_PATH] = None,
    strict: bool = False,
    load_baseline: bool = True,
    **kwargs: Any
) -> Self

Load model from checkpoint.

Note

This is a modified version of load_from_checkpoint from pytorch_lightning.core.saving. It deals with matching keys for the baseline by first running setup

Source code in rl4co/models/rl/reinforce/reinforce.py
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
@classmethod
def load_from_checkpoint(
    cls,
    checkpoint_path: _PATH | IO,
    map_location: _MAP_LOCATION_TYPE = None,
    hparams_file: Optional[_PATH] = None,
    strict: bool = False,
    load_baseline: bool = True,
    **kwargs: Any,
) -> Self:
    """Load model from checkpoint.

    Note:
        This is a modified version of `load_from_checkpoint` from `pytorch_lightning.core.saving`.
        It deals with matching keys for the baseline by first running setup.

    Args:
        checkpoint_path: Checkpoint location, passed to `_load_from_checkpoint` and `torch.load`
        map_location: Passed through to `_load_from_checkpoint` and `torch.load`
        hparams_file: Passed through to `_load_from_checkpoint`
        strict: Always forced to False; a warning is logged if True is requested
        load_baseline: If True, run `setup()` and `post_setup_hook()` on the loaded
            model and then load the baseline parameters from the checkpoint
        **kwargs: Extra keyword arguments passed through to `_load_from_checkpoint`

    Returns:
        The loaded model.
    """

    if strict:
        log.warning("Setting strict=False for loading model from checkpoint.")
        strict = False

    # Do not use strict
    loaded = _load_from_checkpoint(
        cls,
        checkpoint_path,
        map_location,
        hparams_file,
        strict,
        **kwargs,
    )

    # Load baseline state dict
    if load_baseline:
        # The baseline has to be set up first so that its keys can be matched
        loaded.setup()
        loaded.post_setup_hook()
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints
        state_dict = torch.load(checkpoint_path, map_location=map_location)[
            "state_dict"
        ]
        # Keep only the entries that live under the `baseline.` submodule and strip
        # exactly that leading prefix. A plain substring test would also select
        # unrelated keys that merely contain the word "baseline".
        prefix = "baseline."
        baseline_state = {
            key[len(prefix) :]: value
            for key, value in state_dict.items()
            if key.startswith(prefix)
        }
        loaded.baseline.load_state_dict(baseline_state)

    return cast(Self, loaded)

REINFORCEBaseline

REINFORCEBaseline(*args, **kw)

Bases: Module

Base class for REINFORCE baselines

Methods:

  • wrap_dataset

    Wrap dataset with baseline-specific functionality

  • eval

    Evaluate baseline

  • epoch_callback

    Callback at the end of each epoch

  • setup

    To be called before training during setup phase

Source code in rl4co/models/rl/reinforce/baselines.py
22
23
24
def __init__(self, *args, **kw):
    """Run the superclass initializer; any arguments are accepted and ignored."""
    super().__init__()

wrap_dataset

wrap_dataset(dataset: Dataset, *args, **kw)

Wrap dataset with baseline-specific functionality

Source code in rl4co/models/rl/reinforce/baselines.py
26
27
28
def wrap_dataset(self, dataset: Dataset, *args, **kw):
    """Return `dataset` unchanged; extra arguments are accepted and ignored here."""
    unchanged = dataset
    return unchanged

eval abstractmethod

eval(
    td: TensorDict,
    reward: Tensor,
    env: RL4COEnvBase = None,
    **kwargs
)

Evaluate baseline

Source code in rl4co/models/rl/reinforce/baselines.py
30
31
32
33
34
35
@abc.abstractmethod
def eval(
    self, td: TensorDict, reward: torch.Tensor, env: RL4COEnvBase = None, **kwargs
):
    """Evaluate baseline (abstract).

    This base version always raises.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()

epoch_callback

epoch_callback(*args, **kw)

Callback at the end of each epoch. For example, update baseline parameters and obtain baseline values.

Source code in rl4co/models/rl/reinforce/baselines.py
37
38
39
40
41
def epoch_callback(self, *args, **kw):
    """Hook run at the end of each epoch; a no-op in this base version.

    Intended for things such as updating baseline parameters and obtaining
    baseline values.
    """
    return None

setup

setup(*args, **kw)

To be called before training during the setup phase. This follows PyTorch Lightning's setup() convention.

Source code in rl4co/models/rl/reinforce/baselines.py
43
44
45
46
47
def setup(self, *args, **kw):
    """To be called before training during setup phase
    This follow PyTorch Lightning's setup() convention
    """
    pass

NoBaseline

NoBaseline(*args, **kw)

Bases: REINFORCEBaseline

No baseline: return 0 for both the baseline value and the baseline loss

Methods:

  • eval

    Evaluate baseline

Source code in rl4co/models/rl/reinforce/baselines.py
22
23
24
def __init__(self, *args, **kw):
    """Run the superclass initializer; any arguments are accepted and ignored."""
    super().__init__()

eval

eval(td, reward, env=None)

Evaluate baseline

Source code in rl4co/models/rl/reinforce/baselines.py
53
54
def eval(self, td, reward, env=None):
    """Return zero for both the baseline value and the baseline loss."""
    bl_val, bl_loss = 0, 0
    return bl_val, bl_loss

SharedBaseline

SharedBaseline(*args, **kw)

Bases: REINFORCEBaseline

Shared baseline: return mean of reward as baseline

Methods:

  • eval

    Evaluate baseline

Source code in rl4co/models/rl/reinforce/baselines.py
22
23
24
def __init__(self, *args, **kw):
    """Run the superclass initializer; any arguments are accepted and ignored."""
    super().__init__()

eval

eval(td, reward, env=None, on_dim=1)

Evaluate baseline

Source code in rl4co/models/rl/reinforce/baselines.py
60
61
def eval(self, td, reward, env=None, on_dim=1):  # e.g. [batch, pomo, ...]
    """Return the mean of `reward` over `on_dim` (dimension kept) and a zero loss."""
    shared_mean = reward.mean(dim=on_dim, keepdims=True)
    return shared_mean, 0

ExponentialBaseline

ExponentialBaseline(beta=0.8, **kw)

Bases: REINFORCEBaseline

Exponential baseline: return exponential moving average of reward as baseline

Parameters:

  • beta

    Beta value for the exponential moving average

Methods:

  • eval

    Evaluate baseline

Source code in rl4co/models/rl/reinforce/baselines.py
71
72
73
74
75
def __init__(self, beta=0.8, **kw):
    """Create an exponential-moving-average baseline.

    Args:
        beta: Beta value for the exponential moving average
        **kw: Accepted for interface compatibility and ignored
    """
    # Zero-argument super() so the direct parent's initializer runs; naming the
    # parent class in super(...) would start the lookup after it and skip it.
    super().__init__()

    self.beta = beta
    # Running average; stays None until the first call to `eval`
    self.v = None

eval

eval(td, reward, env=None)

Evaluate baseline

Source code in rl4co/models/rl/reinforce/baselines.py
77
78
79
80
81
82
83
def eval(self, td, reward, env=None):
    """Update the running average with the mean of `reward` and return it.

    The first call stores the plain mean; later calls blend the stored value
    and the new mean with weights `beta` and `1 - beta`.

    Returns:
        The detached running average and a zero loss.
    """
    running = reward.mean()
    if self.v is not None:
        running = self.beta * self.v + (1.0 - self.beta) * running
    # Detached: gradients must never flow through the baseline value
    self.v = running.detach()
    return self.v, 0

MeanBaseline

MeanBaseline(*args, **kw)

Bases: REINFORCEBaseline

Mean baseline: return mean of reward as baseline

Source code in rl4co/models/rl/reinforce/baselines.py
22
23
24
def __init__(self, *args, **kw):
    """Run the superclass initializer; any arguments are accepted and ignored."""
    super().__init__()

WarmupBaseline

WarmupBaseline(
    baseline, n_epochs=1, warmup_exp_beta=0.8, **kw
)

Bases: REINFORCEBaseline

Warmup baseline: return convex combination of baseline and exponential baseline

Parameters:

  • baseline

    Baseline to use after warmup

  • n_epochs

    Number of epochs to warmup

  • warmup_exp_beta

    Beta value for the exponential baseline during warmup

Methods:

  • wrap_dataset

    Wrap dataset with baseline-specific functionality

  • setup

    To be called before training during setup phase

  • eval

    Evaluate baseline

  • epoch_callback

    Callback at the end of each epoch

Source code in rl4co/models/rl/reinforce/baselines.py
102
103
104
105
106
107
108
109
def __init__(self, baseline, n_epochs=1, warmup_exp_beta=0.8, **kw):
    """Wrap `baseline` with an exponential-baseline warmup phase.

    Args:
        baseline: Baseline to use after warmup
        n_epochs: Number of epochs to warmup; must be positive
        warmup_exp_beta: Beta value for the exponential baseline during warmup
        **kw: Accepted for interface compatibility and ignored
    """
    # Zero-argument super() so the direct parent's initializer runs; naming the
    # parent class in super(...) would start the lookup after it and skip it.
    super().__init__()

    self.baseline = baseline
    assert n_epochs > 0, "n_epochs to warmup must be positive"
    self.warmup_baseline = ExponentialBaseline(warmup_exp_beta)
    # Mixing weight of `baseline`: 0 means only the warmup baseline is used
    self.alpha = 0
    self.n_epochs = n_epochs

wrap_dataset

wrap_dataset(dataset, *args, **kw)

Wrap dataset with baseline-specific functionality

Source code in rl4co/models/rl/reinforce/baselines.py
111
112
113
114
def wrap_dataset(self, dataset, *args, **kw):
    """Wrap the dataset with the inner baseline once `alpha > 0`, else with the warmup baseline."""
    active = self.baseline if self.alpha > 0 else self.warmup_baseline
    return active.wrap_dataset(dataset, *args, **kw)

setup

setup(*args, **kw)

To be called before training during the setup phase. This follows PyTorch Lightning's setup() convention.

Source code in rl4co/models/rl/reinforce/baselines.py
116
117
def setup(self, *args, **kw):
    self.baseline.setup(*args, **kw)

eval

eval(td, reward, env=None)

Evaluate baseline

Source code in rl4co/models/rl/reinforce/baselines.py
119
120
121
122
123
124
125
126
127
128
129
130
def eval(self, td, reward, env=None):
    """Evaluate the warmup mixture of the inner and the exponential baseline.

    With `alpha == 1` only the inner baseline is evaluated, with `alpha == 0`
    only the warmup baseline. Otherwise both are evaluated and their values
    and losses are combined with weights `alpha` and `1 - alpha`.
    """
    alpha = self.alpha
    if alpha == 1:
        return self.baseline.eval(td, reward, env)
    if alpha == 0:
        return self.warmup_baseline.eval(td, reward, env)

    inner_val, inner_loss = self.baseline.eval(td, reward, env)
    warm_val, warm_loss = self.warmup_baseline.eval(td, reward, env)
    # Convex combination of the two baselines, applied to value and loss alike
    mixed_val = alpha * inner_val + (1 - alpha) * warm_val
    mixed_loss = alpha * inner_loss + (1 - alpha) * warm_loss
    return (mixed_val, mixed_loss)

epoch_callback

epoch_callback(*args, **kw)

Callback at the end of each epoch. For example, update baseline parameters and obtain baseline values.

Source code in rl4co/models/rl/reinforce/baselines.py
132
133
134
135
136
137
def epoch_callback(self, *args, **kw):
    """Run the inner baseline's epoch callback, then advance the warmup weight.

    Reads the epoch from the `epoch` keyword argument. While that epoch is below
    `n_epochs`, `alpha` is set to `(epoch + 1) / n_epochs`.
    """
    # The inner baseline's callback is needed too (also after the first epoch,
    # if it has not been used yet)
    self.baseline.epoch_callback(*args, **kw)
    epoch = kw["epoch"]
    if epoch < self.n_epochs:
        self.alpha = (epoch + 1) / float(self.n_epochs)
        log.info("Set warmup alpha = {}".format(self.alpha))

CriticBaseline

CriticBaseline(critic: CriticNetwork = None, **unused_kw)

Bases: REINFORCEBaseline

Critic baseline: use critic network as baseline

Parameters:

  • critic (CriticNetwork, default: None ) –

    Critic network to use as baseline. If None, create a new critic network based on the environment

Methods:

  • setup

    To be called before training during setup phase

  • eval

    Evaluate baseline

Source code in rl4co/models/rl/reinforce/baselines.py
147
148
149
def __init__(self, critic: CriticNetwork = None, **unused_kw):
    """Store the critic network; it may be None and be created later in `setup`.

    Args:
        critic: Critic network to use as baseline
        **unused_kw: Accepted and ignored
    """
    super().__init__()
    self.critic = critic

setup

setup(policy, env, **kwargs)

To be called before training during the setup phase. This follows PyTorch Lightning's setup() convention.

Source code in rl4co/models/rl/reinforce/baselines.py
151
152
153
154
def setup(self, policy, env, **kwargs):
    if self.critic is None:
        log.info("Critic not found. Creating critic network for {}".format(env.name))
        self.critic = create_critic_from_actor(policy)

eval

eval(x, c, env=None)

Evaluate baseline

Source code in rl4co/models/rl/reinforce/baselines.py
156
157
158
159
def eval(self, x, c, env=None):
    """Evaluate the critic on `x`.

    Returns:
        The detached critic value (last dimension squeezed) and the MSE loss
        between the critic value and the detached target `c`.
    """
    value = self.critic(x).squeeze(-1)
    critic_loss = F.mse_loss(value, c.detach())
    # The actor must not backprop through the baseline; only the loss keeps gradients
    return value.detach(), critic_loss

RolloutBaseline

RolloutBaseline(bl_alpha=0.05, **kw)

Bases: REINFORCEBaseline

Rollout baseline: use greedy rollout as baseline

Parameters:

  • bl_alpha

    Alpha value for the baseline T-test

Methods:

  • setup

    To be called before training during setup phase

  • eval

    Evaluate rollout baseline

  • epoch_callback

    Challenges the current baseline with the policy and replaces the baseline policy if it is improved

  • rollout

    Rollout the policy on the given dataset

  • wrap_dataset

    Wrap the dataset in a baseline dataset

Source code in rl4co/models/rl/reinforce/baselines.py
169
170
171
def __init__(self, bl_alpha=0.05, **kw):
    """Store the significance level used when challenging the baseline.

    Args:
        bl_alpha: Alpha value for the baseline T-test
        **kw: Accepted and ignored
    """
    super().__init__()
    self.bl_alpha = bl_alpha

setup

setup(*args, **kw)

To be called before training during the setup phase. This follows PyTorch Lightning's setup() convention.

Source code in rl4co/models/rl/reinforce/baselines.py
173
174
def setup(self, *args, **kw):
    self._update_policy(*args, **kw)

eval

eval(td, reward, env)

Evaluate rollout baseline

Warning

This is not differentiable and should only be used for evaluation. Also, it is recommended to use the rollout method directly instead of this method.

Source code in rl4co/models/rl/reinforce/baselines.py
191
192
193
194
195
196
197
198
199
200
def eval(self, td, reward, env):
    """Evaluate rollout baseline by running the baseline policy on `td`.

    Warning:
        This is not differentiable and should only be used for evaluation.
        Also, it is recommended to use the `rollout` method directly instead of this method.

    Returns:
        The `"reward"` entry of the baseline policy's output and a zero loss.
        The incoming `reward` argument is not used.
    """
    with torch.inference_mode():
        baseline_out = self.policy(td, env)
        baseline_reward = baseline_out["reward"]
    return baseline_reward, 0

epoch_callback

epoch_callback(
    policy,
    env,
    batch_size=64,
    device="cpu",
    epoch=None,
    dataset_size=None,
)

Challenges the current baseline with the policy and replaces the baseline policy if it is improved

Source code in rl4co/models/rl/reinforce/baselines.py
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
def epoch_callback(
    self, policy, env, batch_size=64, device="cpu", epoch=None, dataset_size=None
):
    """Challenge the current baseline with `policy` and replace it if the candidate is improved.

    The candidate is rolled out on the baseline's dataset. If its mean value
    exceeds the stored baseline mean, a paired t-test against the stored
    baseline values decides, at level `bl_alpha`, whether the baseline policy
    is updated.
    """
    log.info("Evaluating candidate policy on evaluation dataset")
    candidate_vals = self.rollout(policy, env, batch_size, device).cpu().numpy()
    candidate_mean = candidate_vals.mean()

    log.info(
        "Candidate mean: {:.3f}, Baseline mean: {:.3f}".format(
            candidate_mean, self.mean
        )
    )
    if not candidate_mean - self.mean > 0:
        # Candidate is not better on average: keep the current baseline
        return

    # Calc p value with inverse logic (costs)
    t_stat, p_two_sided = ttest_rel(-candidate_vals, -self.bl_vals)

    one_sided_p = p_two_sided / 2  # one-sided
    assert t_stat < 0, "T-statistic should be negative"
    log.info("p-value: {:.3f}".format(one_sided_p))
    if one_sided_p < self.bl_alpha:
        log.info("Updating baseline")
        self._update_policy(policy, env, batch_size, device, dataset_size)

rollout

rollout(
    policy, env, batch_size=64, device="cpu", dataset=None
)

Rollout the policy on the given dataset

Source code in rl4co/models/rl/reinforce/baselines.py
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
def rollout(self, policy, env, batch_size=64, device="cpu", dataset=None):
    """Greedily roll out `policy` over a dataset and return the concatenated rewards.

    Args:
        policy: Policy to evaluate; it is switched to eval mode and moved to `device`
        env: Environment used to reset each batch and passed to the policy
        batch_size: Batch size of the data loader
        device: Device the policy and the batches are moved to
        dataset: Dataset to roll out on. If None, the baseline's own dataset is used

    Returns:
        Rewards of all batches concatenated along dimension 0.
    """
    # Fall back to the baseline's own dataset when none is given
    if dataset is None:
        dataset = self.dataset

    policy.eval()
    policy = policy.to(device)

    loader = DataLoader(dataset, batch_size=batch_size, collate_fn=dataset.collate_fn)

    collected = []
    for batch in loader:
        with torch.inference_mode():
            state = env.reset(batch.to(device))
            collected.append(policy(state, env, decode_type="greedy")["reward"])
    return torch.cat(collected, 0)

wrap_dataset

wrap_dataset(
    dataset, env, batch_size=64, device="cpu", **kw
)

Wrap the dataset in a baseline dataset

Note

This is an alternative to eval that does not require the policy to be passed at every call but just once. Values are added to the dataset. This also allows for larger batch sizes since we evaluate the policy without gradients.

Source code in rl4co/models/rl/reinforce/baselines.py
245
246
247
248
249
250
251
252
253
254
255
256
257
258
def wrap_dataset(self, dataset, env, batch_size=64, device="cpu", **kw):
    """Wrap the dataset in a baseline dataset

    Note:
        This is an alternative to `eval` that does not require the policy to be passed
        at every call but just once. The baseline policy's rollout rewards are added to
        the dataset under the `"extra"` key. This also allows for larger batch sizes
        since we evaluate the policy without gradients.
    """
    baseline_rewards = self.rollout(
        self.policy, env, batch_size, device, dataset=dataset
    )
    baseline_rewards = baseline_rewards.detach().cpu()
    return dataset.add_key("extra", baseline_rewards)

get_reinforce_baseline

get_reinforce_baseline(name, **kw)

Get a REINFORCE baseline by name. The rollout baseline defaults to a warmup baseline with one epoch of exponential baseline followed by the greedy rollout.

Source code in rl4co/models/rl/reinforce/baselines.py
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
def get_reinforce_baseline(name, **kw):
    """Get a REINFORCE baseline by name.

    The rollout baseline defaults to a warmup baseline with one epoch of
    exponential baseline followed by the greedy rollout.

    Args:
        name: Baseline name. `"warmup"` and `"rollout"` are handled here; any other
            name is looked up in `REINFORCE_BASELINES_REGISTRY`. None means `"no"`
        **kw: Keyword arguments for the baseline. For `"warmup"`, the optional
            `baseline` entry (a name or a baseline instance, default `"rollout"`)
            selects the inner baseline. For `"rollout"`, the entries `n_epochs`,
            `exp_beta` and `bl_alpha` are read

    Returns:
        The constructed baseline.

    Raises:
        ValueError: If `name` is not a known baseline.
    """
    if name == "warmup":
        inner_baseline = kw.get("baseline", "rollout")
        if not isinstance(inner_baseline, REINFORCEBaseline):
            inner_baseline = get_reinforce_baseline(inner_baseline, **kw)
        # The inner baseline is passed positionally below, so `baseline` must not
        # be forwarded as a keyword too: that raises "multiple values for argument"
        warmup_kw = {key: value for key, value in kw.items() if key != "baseline"}
        return WarmupBaseline(inner_baseline, **warmup_kw)
    elif name == "rollout":
        warmup_epochs = kw.get("n_epochs", 1)
        warmup_exp_beta = kw.get("exp_beta", 0.8)
        bl_alpha = kw.get("bl_alpha", 0.05)
        return WarmupBaseline(
            RolloutBaseline(bl_alpha=bl_alpha), warmup_epochs, warmup_exp_beta
        )

    if name is None:
        name = "no"  # default to no baseline
    baseline_cls = REINFORCE_BASELINES_REGISTRY.get(name, None)
    if baseline_cls is None:
        # Report the requested name; the looked-up class is always None here
        raise ValueError(
            f"Unknown baseline {name}. Available baselines: {REINFORCE_BASELINES_REGISTRY.keys()}"
        )
    return baseline_cls(**kw)