
PPO

PPO(
    env: RL4COEnvBase,
    policy: Module,
    critic: CriticNetwork = None,
    critic_kwargs: dict = {},
    clip_range: float = 0.2,
    ppo_epochs: int = 2,
    mini_batch_size: int | float = 0.25,
    vf_lambda: float = 0.5,
    entropy_lambda: float = 0.0,
    normalize_adv: bool = False,
    max_grad_norm: float = 0.5,
    metrics: dict = {
        "train": [
            "reward",
            "loss",
            "surrogate_loss",
            "value_loss",
            "entropy",
        ]
    },
    **kwargs
)

Bases: RL4COLitModule

An implementation of the Proximal Policy Optimization (PPO) algorithm (https://arxiv.org/abs/1707.06347), modified for autoregressive decoding schemes.

Unlike the original PPO algorithm, this implementation does not treat autoregressive decoding steps as MDP transitions. Many Neural Combinatorial Optimization (NCO) studies model decoding steps as transitions in a solution-construction MDP. We instead treat autoregressive solution construction as an algorithmic choice that makes CO solution generation tractable. This choice aligns with the Attention Model (AM) (https://openreview.net/forum?id=ByxBFsRqYm), which treats the whole decoding process as a single-step MDP (Equation 9 of the AM paper).

Modeling autoregressive decoding steps as a single-step MDP introduces significant changes to the PPO implementation, including:

  • Generalized Advantage Estimation (GAE) (https://arxiv.org/abs/1506.02438) is not applicable since we are dealing with a single-step MDP.
  • The definition of policy entropy can differ from the one that is commonly implemented.
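A minimal sketch of why GAE adds nothing in the single-step case: the function below is the standard GAE recursion written in plain Python (the name gae and the toy numbers are illustrative, not part of RL4CO). When the episode has exactly one step and its successor is terminal (value 0), the only TD residual is \(R - V(x)\). GAE therefore reduces to reward minus critic value for any \(\gamma\) and \(\lambda\).

```python
from typing import List


def gae(rewards: List[float], values: List[float], gamma: float = 1.0, lam: float = 0.95) -> List[float]:
    """Generalized Advantage Estimation for a single episode.

    values[t] is the critic's estimate V(s_t); the state after the last step
    is terminal, so its value is taken to be 0.
    """
    advantages = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        next_value = values[t + 1] if t + 1 < len(values) else 0.0
        delta = rewards[t] + gamma * next_value - values[t]  # TD residual
        running = delta + gamma * lam * running              # discounted sum of residuals
        advantages[t] = running
    return advantages


# Single-step MDP: the whole solution is one action with reward R(a)
# (e.g. a negative tour length), and the critic predicts V(x).
reward, value = -5.0, -6.5
print(gae([reward], [value]))  # [1.5]
print(reward - value)          # 1.5
```

Changing gamma or lam does not affect the single-step result, since there is no later residual for them to weight.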

The commonly implemented definition of policy entropy is the entropy of the policy distribution, given by:

\[H(\pi(x_t)) = - \sum_{a_t \in A_t} \pi(a_t|x_t) \log \pi(a_t|x_t)\]

where \(x_t\) represents the given state at step \(t\), \(A_t\) is the set of all (admissible) actions at step \(t\), and \(a_t\) is the action taken at step \(t\).
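The per-step entropy can be sketched in plain Python for a softmax policy over a masked action set. This is an illustration under assumptions: step_entropy is a hypothetical helper, the logits stand in for the decoder's scores, and a boolean mask marks the admissible actions \(A_t\) (for example, unvisited nodes).

```python
import math
from typing import List


def step_entropy(logits: List[float], mask: List[bool]) -> float:
    """H(pi(x_t)) for a softmax policy restricted to admissible actions.

    Inadmissible actions (mask False) get zero probability and are left out
    of the sum, matching the sum over A_t in the formula.
    """
    feasible = [l for l, ok in zip(logits, mask) if ok]
    m = max(feasible)  # subtract the max for numerical stability
    z = sum(math.exp(l - m) for l in feasible)
    log_probs = [l - m - math.log(z) for l in feasible]
    return -sum(math.exp(lp) * lp for lp in log_probs)


# Three candidate nodes; the second one is already visited (inadmissible),
# so its large logit has no effect and the two remaining nodes are equally likely.
print(step_entropy([0.0, 5.0, 0.0], [True, False, True]))  # approximately log(2)
```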

If we interpret autoregressive decoding steps as transition steps of an MDP, the entropy for the entire decoding process can be defined as the sum of entropies for each decoding step:

\[H(\pi) = \sum_t H(\pi(x_t))\]
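Under the MDP interpretation, the entropy is accumulated step by step along the decoded trajectory. The sketch below assumes a hypothetical toy policy (step_probs: a softmax over fixed node scores, restricted to unvisited nodes), not an RL4CO model; each state \(x_t\) is the partial solution decoded so far.

```python
import math
from typing import Dict, List, Sequence


def step_probs(scores: Sequence[float], visited: Sequence[int]) -> Dict[int, float]:
    """Hypothetical toy policy: softmax over the scores of not-yet-visited nodes."""
    feasible = [i for i in range(len(scores)) if i not in visited]
    z = sum(math.exp(scores[i]) for i in feasible)
    return {i: math.exp(scores[i]) / z for i in feasible}


def trajectory_entropy(scores: Sequence[float], actions: List[int]) -> float:
    """MDP view: sum over decoding steps t of H(pi(x_t)) along one trajectory."""
    total = 0.0
    for t in range(len(actions)):
        probs = step_probs(scores, actions[:t])  # state x_t = partial solution
        total += -sum(p * math.log(p) for p in probs.values())
    return total


scores = [0.0, 1.0, 3.0]
print(trajectory_entropy(scores, [2, 1, 0]))
print(trajectory_entropy(scores, [0, 1, 2]))
```

Because each \(x_t\) depends on earlier actions, two trajectories decoded for the same instance generally give different sums; the two printed values above differ at the second step.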

However, if we consider autoregressive decoding steps as an algorithmic choice, the entropy for the entire decoding process is defined as:

\[H(\pi) = - \sum_{a \in A} \pi(a|x) \log \pi(a|x)\]

where \(x\) represents the given CO problem instance, and \(A\) is the set of all feasible solutions.
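For very small instances, this solution-level definition can be evaluated exactly by enumerating \(A\). The sketch below assumes a hypothetical toy policy (softmax over fixed node scores, restricted to unvisited nodes) and treats a solution as a visiting order of all nodes. Under these assumptions \(\pi(a|x)\) is the product of the per-step probabilities of \(a\), and \(|A| = n!\).

```python
import itertools
import math
from typing import Dict, Sequence, Tuple


def step_probs(scores: Sequence[float], visited: Sequence[int]) -> Dict[int, float]:
    """Hypothetical toy policy: softmax over the scores of not-yet-visited nodes."""
    feasible = [i for i in range(len(scores)) if i not in visited]
    z = sum(math.exp(scores[i]) for i in feasible)
    return {i: math.exp(scores[i]) / z for i in feasible}


def solution_distribution(scores: Sequence[float]) -> Dict[Tuple[int, ...], float]:
    """pi(a|x) for every complete solution a (here: every permutation of the nodes)."""
    dist = {}
    for perm in itertools.permutations(range(len(scores))):
        p = 1.0
        for t, node in enumerate(perm):
            p *= step_probs(scores, perm[:t])[node]  # chain rule over decoding steps
        dist[perm] = p
    return dist


def solution_entropy(scores: Sequence[float]) -> float:
    """H(pi) = -sum_{a in A} pi(a|x) log pi(a|x), by brute-force enumeration."""
    return -sum(p * math.log(p) for p in solution_distribution(scores).values())


scores = [0.0, 1.0, 3.0]
print(round(sum(solution_distribution(scores).values()), 6))  # 1.0
print(solution_entropy(scores))
```

The factorial growth of \(|A|\) is what makes this exact computation intractable beyond a handful of nodes.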

Because computing the entropy of the policy distribution over all feasible solutions is intractable, we approximate it with Monte Carlo sampling: the entropy is computed over solutions generated by the policy itself, which serves as a proxy for the second definition of entropy above.
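One standard way to realize such a Monte Carlo proxy is the estimator \(H(\pi) = \mathbb{E}_{a \sim \pi}[-\log \pi(a|x)] \approx -\frac{1}{N}\sum_{i=1}^{N} \log \pi(a_i|x)\), with each \(a_i\) sampled from the policy. The sketch below implements it for a hypothetical toy policy (softmax over fixed node scores, restricted to unvisited nodes). It illustrates the estimator only and is not necessarily the exact computation in rl4co/models/rl/ppo/ppo.py.

```python
import math
import random
from typing import Dict, List, Sequence, Tuple


def step_probs(scores: Sequence[float], visited: Sequence[int]) -> Dict[int, float]:
    """Hypothetical toy policy: softmax over the scores of not-yet-visited nodes."""
    feasible = [i for i in range(len(scores)) if i not in visited]
    z = sum(math.exp(scores[i]) for i in feasible)
    return {i: math.exp(scores[i]) / z for i in feasible}


def sample_solution(scores: Sequence[float], rng: random.Random) -> Tuple[List[int], float]:
    """Decode one solution autoregressively; return it with its log-likelihood log pi(a|x)."""
    actions: List[int] = []
    log_likelihood = 0.0
    for _ in range(len(scores)):
        probs = step_probs(scores, actions)
        nodes = list(probs)
        node = rng.choices(nodes, weights=[probs[i] for i in nodes])[0]
        log_likelihood += math.log(probs[node])
        actions.append(node)
    return actions, log_likelihood


def mc_entropy(scores: Sequence[float], num_samples: int, seed: int = 0) -> float:
    """Monte Carlo estimate of H(pi): mean of -log pi(a_i|x) over sampled solutions."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_samples):
        _, ll = sample_solution(scores, rng)
        total += -ll
    return total / num_samples


print(mc_entropy([0.0, 1.0, 3.0], num_samples=10_000))
```

The estimate is unbiased, its variance shrinks as \(1/N\), and its cost grows with the number of samples rather than with \(|A|\).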

Our modeling of decoding steps and our implementation of PPO align with recent work in the Natural Language Processing (NLP) community, specifically Reinforcement Learning from Human Feedback (RLHF) (e.g., https://github.com/lucidrains/PaLM-rlhf-pytorch).
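The constructor parameters clip_range, vf_lambda, entropy_lambda and normalize_adv map naturally onto the standard clipped PPO objective. The sketch below combines them for the single-step setting described above, where the advantage is simply \(R(a) - V(x)\). It uses plain Python lists for clarity: ppo_loss is a hypothetical helper, the squared-error value loss and the 1e-8 epsilon are assumptions, and the actual module operates on batched tensors with gradients.

```python
import math
from statistics import fmean, pstdev
from typing import Dict, List


def ppo_loss(
    new_log_likelihood: List[float],  # log pi_theta(a|x) under the current policy
    old_log_likelihood: List[float],  # log pi_theta_old(a|x) recorded when a was sampled
    rewards: List[float],             # R(a) for each sampled solution
    values: List[float],              # critic predictions V(x)
    entropy: List[float],             # per-instance entropy estimates
    clip_range: float = 0.2,
    vf_lambda: float = 0.5,
    entropy_lambda: float = 0.0,
    normalize_adv: bool = False,
) -> Dict[str, float]:
    # Single-step MDP: the advantage is reward minus critic value (no GAE).
    adv = [r - v for r, v in zip(rewards, values)]
    if normalize_adv:
        mu, sigma = fmean(adv), pstdev(adv)
        adv = [(a - mu) / (sigma + 1e-8) for a in adv]

    surrogate = []
    for new, old, a in zip(new_log_likelihood, old_log_likelihood, adv):
        ratio = math.exp(new - old)  # pi_theta(a|x) / pi_theta_old(a|x)
        clipped = min(max(ratio, 1 - clip_range), 1 + clip_range)
        surrogate.append(-min(ratio * a, clipped * a))  # pessimistic (clipped) bound

    surrogate_loss = fmean(surrogate)
    value_loss = fmean([(v - r) ** 2 for v, r in zip(values, rewards)])  # assumption: squared error
    ent = fmean(entropy)
    loss = surrogate_loss + vf_lambda * value_loss - entropy_lambda * ent
    return {"loss": loss, "surrogate_loss": surrogate_loss, "value_loss": value_loss, "entropy": ent}


out = ppo_loss(
    new_log_likelihood=[-1.2, -0.7],
    old_log_likelihood=[-1.0, -0.9],
    rewards=[-5.0, -6.0],
    values=[-5.5, -5.5],
    entropy=[2.3, 2.1],
)
print(out)
```

The returned keys mirror the default training metrics listed in the signature; a positive entropy_lambda rewards higher entropy by subtracting it from the loss.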

Source code in rl4co/models/rl/ppo/ppo.py
def __init__(
    self,
    env: RL4COEnvBase,
    policy: nn.Module,
    critic: CriticNetwork = None,
    critic_kwargs: dict = {},
    clip_range: float = 0.2,  # epsilon of PPO
    ppo_epochs: int = 2,  # inner epoch, K
    mini_batch_size: int | float = 0.25,  # fraction of the batch (float) or absolute size (int)
    vf_lambda: float = 0.5,  # lambda of Value function fitting
    entropy_lambda: float = 0.0,  # lambda of entropy bonus
    normalize_adv: bool = False,  # whether to normalize advantage
    max_grad_norm: float = 0.5,  # max gradient norm
    metrics: dict = {
        "train": ["reward", "loss", "surrogate_loss", "value_loss", "entropy"],
    },
    **kwargs,
):
    super().__init__(env, policy, metrics=metrics, **kwargs)
    self.automatic_optimization = False  # PPO uses custom optimization routine

    if critic is None:
        log.info("Creating critic network for {}".format(env.name))
        critic = create_critic_from_actor(policy, **critic_kwargs)
    self.critic = critic

    if isinstance(mini_batch_size, float) and (
        mini_batch_size <= 0 or mini_batch_size > 1
    ):
        default_mini_batch_fraction = 0.25
        log.warning(
            f"mini_batch_size must be an integer or a float in the range (0, 1], got {mini_batch_size}. Setting mini_batch_size to {default_mini_batch_fraction}."
        )
        mini_batch_size = default_mini_batch_fraction

    if isinstance(mini_batch_size, int) and (mini_batch_size <= 0):
        default_mini_batch_size = 128
        log.warning(
            f"mini_batch_size must be an integer or a float in the range (0, 1], got {mini_batch_size}. Setting mini_batch_size to {default_mini_batch_size}."
        )
        mini_batch_size = default_mini_batch_size

    self.ppo_cfg = {
        "clip_range": clip_range,
        "ppo_epochs": ppo_epochs,
        "mini_batch_size": mini_batch_size,
        "vf_lambda": vf_lambda,
        "entropy_lambda": entropy_lambda,
        "normalize_adv": normalize_adv,
        "max_grad_norm": max_grad_norm,
    }
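The validation above accepts mini_batch_size either as a fraction of the batch (a float in (0, 1]) or as an absolute number of samples (a positive int). How the value is consumed is not shown in this excerpt. The sketch below assumes the fraction is multiplied by the batch size and that each of the ppo_epochs inner epochs makes one shuffled pass over the batch. resolve_mini_batch_size and mini_batches are hypothetical helpers, and clamping an int to the batch size is an additional assumption.

```python
import random
from typing import List, Union


def resolve_mini_batch_size(batch_size: int, mini_batch_size: Union[int, float]) -> int:
    """Hypothetical helper: turn mini_batch_size into a number of samples.

    Mirrors the checks in __init__: an out-of-range float falls back to 0.25,
    a non-positive int falls back to 128.
    """
    if isinstance(mini_batch_size, float):
        if not 0 < mini_batch_size <= 1:
            mini_batch_size = 0.25
        return max(1, int(batch_size * mini_batch_size))
    if mini_batch_size <= 0:
        mini_batch_size = 128
    return min(mini_batch_size, batch_size)  # assumption: never larger than the batch


def mini_batches(batch_size: int, mini_batch_size: Union[int, float], ppo_epochs: int, seed: int = 0) -> List[List[int]]:
    """Index lists for the inner loop: ppo_epochs shuffled passes over the batch (assumed)."""
    size = resolve_mini_batch_size(batch_size, mini_batch_size)
    rng = random.Random(seed)
    out: List[List[int]] = []
    for _ in range(ppo_epochs):
        perm = list(range(batch_size))
        rng.shuffle(perm)
        out.extend(perm[i:i + size] for i in range(0, batch_size, size))
    return out


print(resolve_mini_batch_size(512, 0.25))  # 128
print(resolve_mini_batch_size(512, 64))    # 64
print(resolve_mini_batch_size(512, 1.5))   # 128 (invalid fraction falls back to 0.25)
print(len(mini_batches(512, 0.25, ppo_epochs=2)))  # 8 = 2 epochs x 4 mini-batches
```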

on_train_epoch_end

on_train_epoch_end()

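Steps the learning-rate scheduler at the end of each training epoch. Only MultiStepLR is currently supported (ToDo: add support for other schedulers).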

Source code in rl4co/models/rl/ppo/ppo.py
def on_train_epoch_end(self):
    """
    ToDo: Add support for other schedulers.
    """

    sch = self.lr_schedulers()

    # If the selected scheduler is a MultiStepLR scheduler.
    if isinstance(sch, torch.optim.lr_scheduler.MultiStepLR):
        sch.step()
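Because this hook only steps a MultiStepLR scheduler, the learning rate follows a piecewise-constant schedule. The pure-Python sketch below reproduces the rule PyTorch documents for MultiStepLR (multiply the rate by gamma each time a milestone is reached) without depending on torch. multistep_lr is a hypothetical helper, the milestones and rates are illustrative, and it assumes step() is called exactly once per training epoch, as in the method above.

```python
from bisect import bisect_right
from typing import List


def multistep_lr(base_lr: float, milestones: List[int], gamma: float, epochs_done: int) -> float:
    """Learning rate after epochs_done calls to MultiStepLR.step().

    The rate is decayed by gamma once for every milestone already reached.
    """
    return base_lr * gamma ** bisect_right(sorted(milestones), epochs_done)


for epoch in range(6):
    print(epoch, multistep_lr(1e-3, milestones=[2, 4], gamma=0.1, epochs_done=epoch))
```

Since the module sets automatic_optimization = False, PyTorch Lightning does not step schedulers on its own, which is presumably why this hook calls sch.step() explicitly.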