Transductive Methods

These methods update the policy parameters at test time (online) to improve the solutions of the specific instances being solved.

Active Search (AS)

Classes:

  • ActiveSearch

    Active Search for Neural Combinatorial Optimization from Bello et al. (2016).

ActiveSearch

ActiveSearch(
    env,
    policy,
    dataset: Dataset | str,
    batch_size: int = 1,
    max_iters: int = 200,
    augment_size: int = 8,
    augment_dihedral: bool = True,
    num_parallel_runs: int = 1,
    max_runtime: int = 86400,
    save_path: str = None,
    optimizer: str | Optimizer | partial = "Adam",
    optimizer_kwargs: dict = {
        "lr": 0.00026,
        "weight_decay": 1e-06,
    },
    **kwargs
)

Bases: TransductiveModel

Active Search for Neural Combinatorial Optimization from Bello et al. (2016). Fine-tunes the whole policy network (encoder + decoder) on each instance separately (batch_size must be 1), restarting from the original weights for every new instance. Reference: https://arxiv.org/abs/1611.09940
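The core idea (fine-tune the policy on one instance with REINFORCE and keep the best solution ever sampled) can be illustrated without RL4CO. The sketch below is a toy, not the library's implementation: the hypothetical `toy_active_search` replaces the encoder-decoder policy with one logit per candidate tour of a 4-city instance, samples a group of tours per iteration, uses the group-mean reward as a shared baseline, and tracks the incumbent. Multi-start decoding, augmentation and the runtime limit are omitted.

```python
import itertools
import math
import random


def tour_length(coords, tour):
    """Length of the closed tour that visits `coords` in the order given by `tour`."""
    return sum(
        math.dist(coords[tour[k]], coords[tour[(k + 1) % len(tour)]])
        for k in range(len(tour))
    )


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def toy_active_search(coords, num_samples=16, max_iters=200, lr=0.5, seed=0):
    """REINFORCE fine-tuning on a single instance, keeping the best tour found.

    Hypothetical toy policy: one logit per candidate tour starting at city 0.
    """
    rng = random.Random(seed)
    candidates = [(0,) + p for p in itertools.permutations(range(1, len(coords)))]
    logits = [0.0] * len(candidates)  # stands in for the pre-trained weights
    best_reward, best_tour = -math.inf, None
    for _ in range(max_iters):
        probs = softmax(logits)
        picks = rng.choices(range(len(candidates)), weights=probs, k=num_samples)
        rewards = [-tour_length(coords, candidates[i]) for i in picks]
        # Track the incumbent: best reward and tour over all iterations
        for i, r in zip(picks, rewards):
            if r > best_reward:
                best_reward, best_tour = r, candidates[i]
        # Shared baseline: mean reward of the sampled group
        baseline = sum(rewards) / num_samples
        grad = [0.0] * len(logits)
        for i, r in zip(picks, rewards):
            advantage = r - baseline
            # d log softmax(logits)[i] / d logits[j] = 1[j == i] - probs[j]
            for j, p in enumerate(probs):
                grad[j] += advantage * ((1.0 if j == i else 0.0) - p) / num_samples
        # Gradient ascent on the advantage-weighted log-likelihood
        logits = [x + lr * g for x, g in zip(logits, grad)]
    return best_reward, best_tour, softmax(logits)


# Unit square: the optimal tours follow the perimeter (length 4)
coords = [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)]
best_reward, best_tour, probs = toy_active_search(coords)
print(best_reward, best_tour)
```

Because the logits start from the same values for every call, each instance is searched independently, which mirrors how `on_train_batch_start` reloads the original weights before every new instance.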

Parameters:

  • env

    RL4CO environment to be solved

  • policy

    policy network

  • dataset (Dataset | str) –

    dataset to be used for training

  • batch_size (int, default: 1 ) –

    batch size for training; must be 1, since the whole policy is fine-tuned on one instance at a time

  • max_iters (int, default: 200 ) –

    maximum number of iterations

  • augment_size (int, default: 8 ) –

    number of augmentations per state

  • augment_dihedral (bool, default: True ) –

    whether to use the 8-fold dihedral augmentation ("dihedral8"); if False, "symmetric" augmentation is used

  • num_parallel_runs (int, default: 1 ) –

    number of parallel runs (the augmented instance batch is repeated this many times)

  • max_runtime (int, default: 86400 ) –

    maximum runtime in seconds

  • save_path (str, default: None ) –

    path to save solution checkpoints

  • optimizer (str | Optimizer | partial, default: 'Adam' ) –

    optimizer to use for training

  • optimizer_kwargs (dict, default: {'lr': 0.00026, 'weight_decay': 1e-06} ) –

    keyword arguments for optimizer

  • **kwargs

    additional keyword arguments

Methods:

Source code in rl4co/models/zoo/active_search/search.py
def __init__(
    self,
    env,
    policy,
    dataset: Dataset | str,
    batch_size: int = 1,
    max_iters: int = 200,
    augment_size: int = 8,
    augment_dihedral: bool = True,
    num_parallel_runs: int = 1,
    max_runtime: int = 86_400,
    save_path: str = None,
    optimizer: str | torch.optim.Optimizer | partial = "Adam",
    optimizer_kwargs: dict = {"lr": 2.6e-4, "weight_decay": 1e-6},
    **kwargs,
):
    self.save_hyperparameters(logger=False)

    assert batch_size == 1, "Batch size must be 1 for active search"

    super(ActiveSearch, self).__init__(
        env,
        policy=policy,
        dataset=dataset,
        batch_size=batch_size,
        max_iters=max_iters,
        max_runtime=max_runtime,
        save_path=save_path,
        optimizer=optimizer,
        optimizer_kwargs=optimizer_kwargs,
        **kwargs,
    )

setup

setup(stage='fit')

Set up the base class and instantiate:

  • augmentation
  • instance solutions and rewards
  • original policy state dict
Source code in rl4co/models/zoo/active_search/search.py
def setup(self, stage="fit"):
    """Set up the base class and instantiate:
    - augmentation
    - instance solutions and rewards
    - original policy state dict
    """
    log.info("Setting up active search...")
    super(ActiveSearch, self).setup(stage)

    # Instantiate augmentation
    self.augmentation = StateAugmentation(
        num_augment=self.hparams.augment_size,
        augment_fn="dihedral8" if self.hparams.augment_dihedral else "symmetric",
    )

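    # Store a deep copy of the original policy state dict (requires `import copy`):
    # state_dict() returns tensors that share storage with the parameters, so
    # without a copy the optimizer updates would also overwrite this state
    self.original_policy_state = copy.deepcopy(self.policy.state_dict())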

    # Get dataset size and problem size
    dataset_size = len(self.dataset)
    _batch = next(iter(self.train_dataloader()))
    self.problem_size = self.env.reset(_batch)["action_mask"].shape[-1]
    self.instance_solutions = torch.zeros(
        dataset_size, self.problem_size * 2, dtype=int
    )
    self.instance_rewards = torch.zeros(dataset_size)

on_train_batch_start

on_train_batch_start(batch: Any, batch_idx: int)

Called before the search (i.e. training) on a new batch begins. We reload the original policy state dict and configure the optimizer.

Source code in rl4co/models/zoo/active_search/search.py
def on_train_batch_start(self, batch: Any, batch_idx: int):
    """Called before the search (i.e. training) on a new batch begins.
    We reload the original policy state dict and configure the optimizer.
    """
    self.policy.load_state_dict(self.original_policy_state)
    self.configure_optimizers(self.policy.parameters())

training_step

training_step(batch, batch_idx)

Main search loop. We use the training step to effectively adapt to a batch of instances.

Source code in rl4co/models/zoo/active_search/search.py
def training_step(self, batch, batch_idx):
    """Main search loop. We use the training step to effectively adapt to a `batch` of instances."""
    # Augment state
    batch_size = batch.shape[0]
    td_init = self.env.reset(batch)
    n_aug, n_start, n_runs = (
        self.augmentation.num_augment,
        self.env.get_num_starts(td_init),
        self.hparams.num_parallel_runs,
    )
    td_init = self.augmentation(td_init)
    td_init = batchify(td_init, n_runs)

    # Solution and reward buffer
    max_reward = torch.full((batch_size,), -float("inf"), device=batch.device)
    best_solutions = torch.zeros(
        batch_size, self.problem_size * 2, device=batch.device, dtype=int
    )

    # Init search
    t_start = time.time()
    for i in range(self.hparams.max_iters):
        # Evaluate policy with sampling multistarts (as in POMO)
        out = self.policy(
            td_init.clone(),
            env=self.env,
            decode_type="multistart_sampling",
            num_starts=n_start,
        )

        if i == 0:
            log.info(f"Initial reward: {out['reward'].max():.2f}")

        # Update best solution and reward found
        max_reward_iter = out["reward"].max()
        if max_reward_iter > max_reward:
            max_reward_idx = out["reward"].argmax()
            best_solution_iter = out["actions"][max_reward_idx]
            max_reward = max_reward_iter
            best_solutions[0, : best_solution_iter.shape[0]] = best_solution_iter

        # Compute REINFORCE loss with shared baseline
        reward = unbatchify(out["reward"], (n_runs, n_aug, n_start))
        ll = unbatchify(out["log_likelihood"], (n_runs, n_aug, n_start))
        advantage = reward - reward.mean(dim=-1, keepdim=True)
        loss = -(advantage * ll).mean()

        # Backpropagate loss
        # perform manual optimization following the Lightning routine
        # https://lightning.ai/docs/pytorch/stable/common/optimization.html
        opt = self.optimizers()
        opt.zero_grad()
        self.manual_backward(loss)
        opt.step()  # apply the update; manual optimization does not step automatically

        self.log_dict(
            {
                "loss": loss,
                "max_reward": max_reward,
                "step": i,
                "time": time.time() - t_start,
            },
            on_step=self.log_on_step,
        )

        # Stop if max runtime is exceeded
        if time.time() - t_start > self.hparams.max_runtime:
            break

    return {"max_reward": max_reward, "best_solutions": best_solutions}
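The REINFORCE loss above subtracts the mean reward over the multi-starts of the same augmented instance, so rollouts better than their siblings get a positive advantage. A minimal plain-Python sketch of that computation for one instance and one run (the function name is hypothetical; shapes are nested lists `[n_aug][n_start]`):

```python
def shared_baseline_loss(rewards, log_likelihoods):
    """REINFORCE loss for one instance with a shared (mean) baseline.

    rewards, log_likelihoods: nested lists shaped [n_aug][n_start].
    Mirrors: advantage = reward - reward.mean(dim=-1, keepdim=True)
             loss = -(advantage * ll).mean()
    """
    terms = []
    for reward_row, ll_row in zip(rewards, log_likelihoods):
        baseline = sum(reward_row) / len(reward_row)  # mean over the multi-starts
        terms.extend((r - baseline) * ll for r, ll in zip(reward_row, ll_row))
    return -sum(terms) / len(terms)


# Two starts: the shorter tour (reward -4) gets a positive advantage, so
# minimizing the loss increases its log-likelihood
loss = shared_baseline_loss([[-4.0, -6.0]], [[-1.0, -2.0]])
print(loss)
```

When all rollouts of a group have the same reward, every advantage is zero and the group contributes no gradient, which is why no learned critic is needed.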

on_train_batch_end

on_train_batch_end(
    outputs: STEP_OUTPUT, batch: Any, batch_idx: int
) -> None

We store the best solution and reward found.

Source code in rl4co/models/zoo/active_search/search.py
def on_train_batch_end(
    self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int
) -> None:
    """We store the best solution and reward found."""
    max_rewards, best_solutions = outputs["max_reward"], outputs["best_solutions"]
    self.instance_rewards[batch_idx] = max_rewards
    self.instance_solutions[batch_idx, :] = best_solutions.squeeze(
        0
    )  # only one instance
    log.info(f"Best reward: {max_rewards.mean():.2f}")

on_train_epoch_end

on_train_epoch_end() -> None

Called when the training epoch ends. The end of the epoch means that all instances have been searched, so the solutions and rewards are saved (if save_path is set) and the trainer is told to stop.

Source code in rl4co/models/zoo/active_search/search.py
def on_train_epoch_end(self) -> None:
    """Called when the training epoch ends.
    The end of the epoch means that all instances have been searched,
    so the solutions and rewards are saved and the trainer should stop.
    """
    save_path = self.hparams.save_path
    if save_path is not None:
        log.info(f"Saving solutions and rewards to {save_path}...")
        torch.save(
            {"solutions": self.instance_solutions, "rewards": self.instance_rewards},
            save_path,
        )

    # https://github.com/Lightning-AI/lightning/issues/1406
    self.trainer.should_stop = True

Efficient Active Search (EAS)

Classes:

  • EAS

    Efficient Active Search for Combinatorial Optimization Problems from Hottung et al. (2022).

  • EASEmb

    EAS with embedding adaptation

  • EASLay

    EAS with layer adaptation

EAS

EAS(
    env,
    policy,
    dataset: Dataset | str,
    use_eas_embedding: bool = True,
    use_eas_layer: bool = False,
    eas_emb_cache_keys: list[str] = ["logit_key"],
    eas_lambda: float = 0.013,
    batch_size: int = 2,
    max_iters: int = 200,
    augment_size: int = 8,
    augment_dihedral: bool = True,
    num_parallel_runs: int = 1,
    baseline: str = "multistart",
    max_runtime: int = 86400,
    save_path: str = None,
    optimizer: str | Optimizer | partial = "Adam",
    optimizer_kwargs: dict = {
        "lr": 0.0041,
        "weight_decay": 1e-06,
    },
    verbose: bool = True,
    **kwargs
)

Bases: TransductiveModel

Efficient Active Search for Combinatorial Optimization Problems from Hottung et al. (2022). Fine-tunes only a subset of parameters (the cached node embeddings, a newly added layer, or both), thus avoiding expensive re-encoding of the problem: the encoder runs once per batch and only the decoder is re-evaluated at each iteration. Reference: https://openreview.net/pdf?id=nO5caZwFwYu

Parameters:

  • env

    RL4CO environment to be solved

  • policy

    policy network

  • dataset (Dataset | str) –

    dataset to be used for training

  • use_eas_embedding (bool, default: True ) –

    whether to use EAS embedding (EASEmb)

  • use_eas_layer (bool, default: False ) –

    whether to use EAS layer (EASLay)

  • eas_emb_cache_keys (list[str], default: ['logit_key'] ) –

    keys of the cached decoder embeddings that are wrapped in nn.Parameter and optimized when use_eas_embedding is True

  • eas_lambda (float, default: 0.013 ) –

    weight of the imitation learning (IL) loss on the incumbent solution, added to the REINFORCE loss

  • batch_size (int, default: 2 ) –

    batch size for training

  • max_iters (int, default: 200 ) –

    maximum number of iterations

  • augment_size (int, default: 8 ) –

    number of augmentations per state

  • augment_dihedral (bool, default: True ) –

    whether to use the 8-fold dihedral augmentation ("dihedral8"); if False, "symmetric" augmentation is used

  • num_parallel_runs (int, default: 1 ) –

    number of parallel runs (the augmented instance batch is repeated this many times)

  • baseline (str, default: 'multistart' ) –

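    REINFORCE baseline type: "multistart" (mean over the starts of each augmentation), "symmetric" (mean over the augmentations of each start) or "full" (mean over both)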

  • max_runtime (int, default: 86400 ) –

    maximum runtime in seconds

  • save_path (str, default: None ) –

    path to save solution checkpoints

  • optimizer (str | Optimizer | partial, default: 'Adam' ) –

    optimizer to use for training

  • optimizer_kwargs (dict, default: {'lr': 0.0041, 'weight_decay': 1e-06} ) –

    keyword arguments for optimizer

  • verbose (bool, default: True ) –

    whether to print progress for each iteration
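The three `baseline` options differ only in which axis the sampled rewards are averaged over (the incumbent rollout is excluded first). A plain-Python sketch for one instance, with rewards as nested lists `[n_aug][n_start]`; the function name is hypothetical and the baselines are expanded to the full shape for readability, whereas the tensor code keeps broadcastable dimensions:

```python
def eas_baseline(group_reward, kind="multistart"):
    """Baseline values with the same shape as group_reward ([n_aug][n_start])."""
    n_aug, n_start = len(group_reward), len(group_reward[0])
    if kind == "multistart":  # mean over starts, per augmentation (dim=-1)
        row_means = [sum(row) / n_start for row in group_reward]
        return [[m] * n_start for m in row_means]
    if kind == "symmetric":  # mean over augmentations, per start (dim=-2)
        col_means = [
            sum(group_reward[a][s] for a in range(n_aug)) / n_aug for s in range(n_start)
        ]
        return [list(col_means) for _ in range(n_aug)]
    if kind == "full":  # mean over both axes
        total = sum(sum(row) for row in group_reward) / (n_aug * n_start)
        return [[total] * n_start for _ in range(n_aug)]
    raise ValueError(f"Baseline {kind} not supported.")


# Last column is the incumbent rollout; drop it, as reward[..., :-1] does
rewards_with_incumbent = [[1.0, 3.0, 9.0], [5.0, 7.0, 9.0]]
group = [row[:-1] for row in rewards_with_incumbent]
print(eas_baseline(group, "full"))
```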

Methods:

Source code in rl4co/models/zoo/eas/search.py
def __init__(
    self,
    env,
    policy,
    dataset: Dataset | str,
    use_eas_embedding: bool = True,
    use_eas_layer: bool = False,
    eas_emb_cache_keys: list[str] = ["logit_key"],
    eas_lambda: float = 0.013,
    batch_size: int = 2,
    max_iters: int = 200,
    augment_size: int = 8,
    augment_dihedral: bool = True,
    num_parallel_runs: int = 1,
    baseline: str = "multistart",
    max_runtime: int = 86_400,
    save_path: str = None,
    optimizer: str | torch.optim.Optimizer | partial = "Adam",
    optimizer_kwargs: dict = {"lr": 0.0041, "weight_decay": 1e-6},
    verbose: bool = True,
    **kwargs,
):
    self.save_hyperparameters(logger=False)

    assert (
        self.hparams.use_eas_embedding or self.hparams.use_eas_layer
    ), "At least one of `use_eas_embedding` or `use_eas_layer` must be True."

    super(EAS, self).__init__(
        env,
        policy=policy,
        dataset=dataset,
        batch_size=batch_size,
        max_iters=max_iters,
        max_runtime=max_runtime,
        save_path=save_path,
        optimizer=optimizer,
        optimizer_kwargs=optimizer_kwargs,
        **kwargs,
    )

    assert self.hparams.baseline in [
        "multistart",
        "symmetric",
        "full",
    ], f"Baseline {self.hparams.baseline} not supported."

setup

setup(stage='fit')

Set up the base class and instantiate:

  • augmentation
  • instance solutions and rewards
  • original policy state dict
Source code in rl4co/models/zoo/eas/search.py
def setup(self, stage="fit"):
    """Set up the base class and instantiate:
    - augmentation
    - instance solutions and rewards
    - original policy state dict
    """
    log.info(
        f"Setting up Efficient Active Search (EAS) with: \n"
        f"- EAS Embedding: {self.hparams.use_eas_embedding} \n"
        f"- EAS Layer: {self.hparams.use_eas_layer} \n"
    )
    super(EAS, self).setup(stage)

    # Instantiate augmentation
    self.augmentation = StateAugmentation(
        num_augment=self.hparams.augment_size,
        augment_fn="dihedral8" if self.hparams.augment_dihedral else "symmetric",
    )

    # Store original policy state dict
    self.original_policy_state = self.policy.state_dict()

    # Get problem size and initialize the solution and reward buffers
    _batch = next(iter(self.train_dataloader()))
    self.problem_size = self.env.reset(_batch)["action_mask"].shape[-1]
    self.instance_solutions = []
    self.instance_rewards = []

on_train_batch_start

on_train_batch_start(batch: Any, batch_idx: int)

Called before the search (i.e. training) on a new batch begins. We reload the original policy state dict and set all policy parameters not to require gradients. We do the rest in the training step.

Source code in rl4co/models/zoo/eas/search.py
def on_train_batch_start(self, batch: Any, batch_idx: int):
    """Called before the search (i.e. training) on a new batch begins.
    We reload the original policy state dict and set all policy parameters not to require gradients.
    We do the rest in the training step.
    """
    self.policy.load_state_dict(self.original_policy_state)

    # Set all policy parameters to not require gradients
    for param in self.policy.parameters():
        param.requires_grad = False

training_step

training_step(batch, batch_idx)

Main search loop. We use the training step to effectively adapt to a batch of instances.

Source code in rl4co/models/zoo/eas/search.py
def training_step(self, batch, batch_idx):
    """Main search loop. We use the training step to effectively adapt to a `batch` of instances."""
    # Augment state
    batch_size = batch.shape[0]
    td_init = self.env.reset(batch)
    n_aug, n_start, n_runs = (
        self.augmentation.num_augment,
        self.env.get_num_starts(td_init),
        self.hparams.num_parallel_runs,
    )
    td_init = self.augmentation(td_init)
    td_init = batchify(td_init, n_runs)
    num_instances = batch_size * n_aug * n_runs  # NOTE: no num_starts!
    # batch_r = n_runs * batch_size # effective batch size
    group_s = (
        n_start + 1
    )  # number of different rollouts per instance (+1 for incumbent solution construction)

    # Get encoder and decoder for simplicity
    encoder = self.policy.encoder
    decoder = self.policy.decoder

    # Precompute the cache of the embeddings (i.e. q,k,v and logit_key)
    embeddings, _ = encoder(td_init)
    cached_embeds = decoder._precompute_cache(embeddings)

    # Collect optimizer parameters
    opt_params = []
    if self.hparams.use_eas_layer:
        # EASLay: replace the forward of the pointer (logit) attention so that the EASLayerNet residual is added
        eas_layer = EASLayerNet(num_instances, decoder.embed_dim).to(batch.device)
        decoder.pointer.eas_layer = partial(eas_layer, decoder.pointer)
        decoder.pointer.forward = partial(
            forward_pointer_attn_eas_lay, decoder.pointer
        )
        for param in eas_layer.parameters():
            opt_params.append(param)
    if self.hparams.use_eas_embedding:
        # EASEmb: set gradient of emb_key to True
        # for all the keys, wrap the embedding in a nn.Parameter
        for key in self.hparams.eas_emb_cache_keys:
            setattr(
                cached_embeds, key, torch.nn.Parameter(getattr(cached_embeds, key))
            )
            opt_params.append(getattr(cached_embeds, key))
    decoder.forward_eas = partial(forward_eas, decoder)

    # We pass attributes saved in policy too
    def set_attr_if_exists(attr):
        if hasattr(self.policy, attr):
            setattr(decoder, attr, getattr(self.policy, attr))

    for attr in ["temperature", "tanh_clipping", "mask_logits"]:
        set_attr_if_exists(attr)

    self.configure_optimizers(opt_params)

    # Solution and reward buffer
    max_reward = torch.full((batch_size,), -float("inf"), device=batch.device)
    best_solutions = torch.zeros(
        batch_size, self.problem_size * 2, device=batch.device, dtype=int
    )  # i.e. incumbent solutions

    # Init search
    t_start = time.time()
    for iter_count in range(self.hparams.max_iters):
        # Evaluate policy with sampling multistarts passing the cached embeddings
        best_solutions_expanded = best_solutions.repeat(n_aug, 1).repeat(n_runs, 1)
        logprobs, actions, td_out, reward = decoder.forward_eas(
            td_init.clone(),
            cached_embeds=cached_embeds,
            best_solutions=best_solutions_expanded,
            iter_count=iter_count,
            env=self.env,
            decode_type="multistart_sampling",
            num_starts=n_start,
        )

        # Unbatchify to get correct dimensions
        ll = get_log_likelihood(logprobs, actions, td_out.get("mask", None))
        ll = unbatchify(ll, (n_runs * batch_size, n_aug, group_s)).squeeze()
        reward = unbatchify(reward, (n_runs * batch_size, n_aug, group_s)).squeeze()
        actions = unbatchify(actions, (n_runs * batch_size, n_aug, group_s)).squeeze()

        # Compute REINFORCE loss with shared baselines
        # compared to original EAS, we also support symmetric and full baselines
        group_reward = reward[..., :-1]  # exclude incumbent solution
        if self.hparams.baseline == "multistart":
            bl_val = group_reward.mean(dim=-1, keepdim=True)
        elif self.hparams.baseline == "symmetric":
            bl_val = group_reward.mean(dim=-2, keepdim=True)
        elif self.hparams.baseline == "full":
            bl_val = group_reward.mean(dim=-1, keepdim=True).mean(
                dim=-2, keepdim=True
            )
        else:
            raise ValueError(f"Baseline {self.hparams.baseline} not supported.")

        # REINFORCE loss
        advantage = group_reward - bl_val
        loss_rl = -(advantage * ll[..., :-1]).mean()
        # IL loss
        loss_il = -ll[..., -1].mean()
        # Total loss
        loss = loss_rl + self.hparams.eas_lambda * loss_il

        # Manual backpropagation
        opt = self.optimizers()
        opt.zero_grad()
        self.manual_backward(loss)
        opt.step()  # apply the update; manual optimization does not step automatically

        # Save best solutions and rewards
        # Get max reward for each group and instance
        max_reward = reward.max(dim=2)[0].max(dim=1)[0]

        # Reshape and rank rewards
        reward_group = reward.reshape(n_runs * batch_size, -1)
        _, top_indices = torch.topk(reward_group, k=1, dim=1)

        # Obtain best solutions found so far
        solutions = actions.reshape(n_runs * batch_size, n_aug * group_s, -1)
        best_solutions_iter = gather_by_index(solutions, top_indices, dim=1)
        best_solutions[:, : best_solutions_iter.shape[1]] = best_solutions_iter

        self.log_dict(
            {
                "loss": loss,
                "max_reward": max_reward.mean(),
                "step": iter_count,
                "time": time.time() - t_start,
            },
            on_step=self.log_on_step,
        )

        log.info(
            f"{iter_count}/{self.hparams.max_iters} | "
            f" Reward: {max_reward.mean().item():.2f} "
        )

        # Stop if max runtime is exceeded
        if time.time() - t_start > self.hparams.max_runtime:
            log.info(f"Max runtime of {self.hparams.max_runtime} seconds exceeded.")
            break

    return {"max_reward": max_reward, "best_solutions": best_solutions}
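The loss assembled in `training_step` can be written out as follows. The notation is introduced here for illustration and is not taken from RL4CO: $i$ indexes instances (including parallel runs), $a$ augmentations, $s = 1, \dots, S$ the sampled multi-starts ($S$ = `n_start`), $G$ is the set of all $(i, a, s)$, $H$ the set of all $(i, a)$, $\pi^{\mathrm{inc}}_{i,a}$ the extra rollout that follows the incumbent, $\lambda$ = `eas_lambda`, and $\phi$ the parameters being adapted (the cached embeddings in `eas_emb_cache_keys` and/or the `EASLayerNet` weights). The baseline shown is the default `"multistart"` one. At the first iteration no incumbent exists yet, so the extra rollout is an ordinary sample.

```latex
\mathcal{L}_{\mathrm{RL}} = -\frac{1}{|G|} \sum_{(i,a,s) \in G} \bigl(R_{i,a,s} - b_{i,a}\bigr) \log p_{\phi}(\pi_{i,a,s}),
\qquad b_{i,a} = \frac{1}{S} \sum_{s'=1}^{S} R_{i,a,s'}

\mathcal{L}_{\mathrm{IL}} = -\frac{1}{|H|} \sum_{(i,a) \in H} \log p_{\phi}\bigl(\pi^{\mathrm{inc}}_{i,a}\bigr),
\qquad
\mathcal{L} = \mathcal{L}_{\mathrm{RL}} + \lambda \, \mathcal{L}_{\mathrm{IL}}
```

The IL term keeps pushing probability mass toward the best solution found so far, while the RL term explores around it.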

on_train_batch_end

on_train_batch_end(
    outputs: STEP_OUTPUT, batch: Any, batch_idx: int
) -> None

We store the best solution and reward found.

Source code in rl4co/models/zoo/eas/search.py
def on_train_batch_end(
    self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int
) -> None:
    """We store the best solution and reward found."""
    max_rewards, best_solutions = outputs["max_reward"], outputs["best_solutions"]
    self.instance_solutions.append(best_solutions)
    self.instance_rewards.append(max_rewards)
    log.info(f"Best reward: {max_rewards.mean():.2f}")

on_train_epoch_end

on_train_epoch_end() -> None

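Called when the training epoch ends. Concatenates the solutions and rewards collected for every batch, saves them (if save_path is set) and stops the trainer.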

Source code in rl4co/models/zoo/eas/search.py
def on_train_epoch_end(self) -> None:
    """Called when the training epoch ends: concatenate the solutions and rewards, save them and stop the trainer."""
    save_path = self.hparams.save_path
    # concatenate solutions and rewards
    self.instance_solutions = pad_sequence(
        self.instance_solutions, batch_first=True, padding_value=0
    ).squeeze()
    self.instance_rewards = torch.cat(self.instance_rewards, dim=0).squeeze()
    if save_path is not None:
        log.info(f"Saving solutions and rewards to {save_path}...")
        torch.save(
            {"solutions": self.instance_solutions, "rewards": self.instance_rewards},
            save_path,
        )

    # https://github.com/Lightning-AI/lightning/issues/1406
    self.trainer.should_stop = True
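`pad_sequence(..., batch_first=True)` pads each list element along its first dimension. Each stored `best_solutions` tensor has shape `(batch_size, 2 * problem_size)`, so the padded dimension is the per-batch instance dimension: if the dataset size is not a multiple of `batch_size`, the last batch receives rows of zeros, which have no matching entries in the concatenated rewards. A plain-Python sketch of that behavior on nested lists (the function name is hypothetical):

```python
def pad_sequence_batch_first(sequences, padding_value=0):
    """Mimic pad_sequence(..., batch_first=True) on nested lists.

    Each element of `sequences` is a list of rows (e.g. one
    [batch_size][solution_len] block per batch); its first dimension is padded.
    """
    max_len = max(len(seq) for seq in sequences)
    row_len = len(sequences[0][0])
    return [
        [list(row) for row in seq]
        + [[padding_value] * row_len for _ in range(max_len - len(seq))]
        for seq in sequences
    ]


# Two batches of size 2 and 1: the second block gets a padding row
blocks = [[[1, 2], [3, 4]], [[5, 6]]]
print(pad_sequence_batch_first(blocks))
```

Padding rows should therefore be ignored when matching saved solutions to rewards.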

EASEmb

EASEmb(*args, **kwargs)

Bases: EAS

EAS with embedding adaptation

Source code in rl4co/models/zoo/eas/search.py
def __init__(
    self,
    *args,
    **kwargs,
):
    # Warn only if the caller explicitly asked for a conflicting configuration
    if not kwargs.get("use_eas_embedding", True) or kwargs.get(
        "use_eas_layer", False
    ):
        log.warning(
            "Setting `use_eas_embedding` to True and `use_eas_layer` to False. Use EAS base class to override."
        )
    kwargs["use_eas_embedding"] = True
    kwargs["use_eas_layer"] = False
    super(EASEmb, self).__init__(*args, **kwargs)

EASLay

EASLay(*args, **kwargs)

Bases: EAS

EAS with layer adaptation

Source code in rl4co/models/zoo/eas/search.py
def __init__(
    self,
    *args,
    **kwargs,
):
    if kwargs.get("use_eas_embedding", False) or not kwargs.get(
        "use_eas_layer", True
    ):
        log.warning(
            "Setting `use_eas_embedding` to False and `use_eas_layer` to True. Use EAS base class to override."
        )
    kwargs["use_eas_embedding"] = False
    kwargs["use_eas_layer"] = True
    super(EASLay, self).__init__(*args, **kwargs)

Functions:

forward_pointer_attn_eas_lay

forward_pointer_attn_eas_lay(
    self, query, key, value, logit_key, mask
)

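Forward pass of the pointer (logit) attention with the EAS layer added: the layer output is added as a residual to the attention heads before the output projection, and the logits are computed with a single-head scaled dot product against logit_key.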

Source code in rl4co/models/zoo/eas/decoder.py
def forward_pointer_attn_eas_lay(self, query, key, value, logit_key, mask):
    """Forward pass of the pointer (logit) attention with the EAS layer added
    as a residual on the attention heads, i.e. single-head logit attention.
    """
    # Compute inner multi-head attention with no projections.
    heads = self._inner_mha(query, key, value, mask)

    # Add residual for EAS layer if is set
    if getattr(self, "eas_layer", None) is not None:
        heads = heads + self.eas_layer(heads)

    glimpse = self.project_out(heads)

    # Batch matrix multiplication to compute logits (batch_size, num_steps, graph_size)
    # bmm is slightly faster than einsum and matmul
    logits = (
        torch.bmm(glimpse, logit_key.squeeze(1).transpose(-2, -1))
        / math.sqrt(glimpse.size(-1))
    ).squeeze(1)

    return logits
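The final step above scores every node by the dot product between the glimpse and that node's logit key, divided by the square root of the embedding size. A minimal plain-Python sketch for a single glimpse vector (the function name is hypothetical; the real code batches this with `torch.bmm`):

```python
import math


def pointer_logits(glimpse, logit_keys):
    """Compatibility of one glimpse vector with each node's logit key.

    Mirrors: bmm(glimpse, logit_key^T) / sqrt(glimpse.size(-1)).
    """
    scale = math.sqrt(len(glimpse))
    return [sum(g * k for g, k in zip(glimpse, key)) / scale for key in logit_keys]


# Embedding size 4, two nodes: the first key aligns with the glimpse
print(pointer_logits([1.0, 0.0, 1.0, 0.0], [[2.0, 0.0, 2.0, 0.0], [0.0, 1.0, 0.0, 1.0]]))
```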

forward_eas

forward_eas(
    self,
    td: TensorDict,
    cached_embeds,
    best_solutions,
    iter_count: int = 0,
    env: str | RL4COEnvBase = None,
    decode_type: str = "multistart_sampling",
    num_starts: int = None,
    mask_logits: bool = True,
    temperature: float = 1.0,
    tanh_clipping: float = 0,
    **decode_kwargs
)

Forward pass of the decoder. Given the environment state and the pre-computed embeddings, compute the logits and sample actions. From the second iteration on, the last rollout of each group follows the incumbent solution.

Parameters:

  • td (TensorDict) –

    Input TensorDict containing the environment state

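  • cached_embeds –

    Precomputed embeddings for the nodes, cached in the form of q, k, v and logit_key

  • best_solutions –

    Incumbent solutions (one row of actions per augmented instance), used to construct the last rollout of each group

  • iter_count (int, default: 0 ) –

    Current search iteration. The incumbent is injected only when iter_count > 0

  • env (str | RL4COEnvBase, default: None ) –

    Environment used for decoding. Pass an instantiated environment: its select_start_nodes, step and get_reward methods are called directly

  • decode_type (str, default: 'multistart_sampling' ) –

    Type of decoding to use. Can be one of:

    • "sampling": sample from the logits
    • "greedy": take the argmax of the logits
    • "multistart_sampling": like sampling, but with multi-start decoding
    • "multistart_greedy": like greedy, but with multi-start decoding
  • num_starts (int, default: None ) –

    Number of multi-starts to use. It must be provided; one extra rollout per instance is added for the incumbent, for num_starts + 1 rollouts in total

  • mask_logits, temperature, tanh_clipping –

    Logit processing options, used only when the corresponding decoder attribute (self.mask_logits, self.temperature, self.tanh_clipping) is None

Returns:

  • tuple – (logprobs, actions, td, rewards) for all num_starts + 1 rollouts per instance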
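The two basic decode modes differ only in how an action is picked from the masked log-probabilities; the multi-start variants additionally fix the first action of each rollout to a different start node. A plain-Python sketch of masking plus greedy and sampling selection, which is not RL4CO's `process_logits`/`decode_logprobs` (both helper names are hypothetical):

```python
import math
import random


def masked_log_softmax(logits, mask):
    """Log-probabilities with infeasible actions (mask False) set to -inf."""
    feasible = [x for x, m in zip(logits, mask) if m]
    peak = max(feasible)
    log_z = peak + math.log(sum(math.exp(x - peak) for x in feasible))
    return [x - log_z if m else -math.inf for x, m in zip(logits, mask)]


def decode(logp, decode_type, rng=random):
    """Pick one action: argmax for greedy, a random draw for sampling."""
    if decode_type == "greedy":
        return max(range(len(logp)), key=lambda i: logp[i])
    if decode_type == "sampling":
        # exp(-inf) == 0.0, so masked actions are never drawn
        weights = [math.exp(lp) for lp in logp]
        return rng.choices(range(len(logp)), weights=weights, k=1)[0]
    raise ValueError(f"Unknown decode_type: {decode_type}")


# Node 1 has the highest logit but is infeasible
logp = masked_log_softmax([2.0, 5.0, 1.0], [True, False, True])
print(decode(logp, "greedy"), decode(logp, "sampling", random.Random(0)))
```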

Source code in rl4co/models/zoo/eas/decoder.py
def forward_eas(
    self,
    td: TensorDict,
    cached_embeds,
    best_solutions,
    iter_count: int = 0,
    env: str | RL4COEnvBase = None,
    decode_type: str = "multistart_sampling",
    num_starts: int = None,
    mask_logits: bool = True,
    temperature: float = 1.0,
    tanh_clipping: float = 0,
    **decode_kwargs,
):
    """Forward pass of the decoder.
    Given the environment state and the pre-computed embeddings, compute the logits and sample actions.
    From the second iteration on, the last rollout of each group follows the incumbent solution.

    Args:
        td: Input TensorDict containing the environment state
        cached_embeds: Precomputed embeddings for the nodes, cached in the form of q, k, v and logit_key
        best_solutions: Incumbent solutions, one row of actions per augmented instance
        iter_count: Current search iteration; the incumbent is injected only if iter_count > 0
        env: Environment used for decoding. Pass an instantiated environment: its
            `select_start_nodes`, `step` and `get_reward` methods are called directly
        decode_type: Type of decoding to use. Can be one of:
            - "sampling": sample from the logits
            - "greedy": take the argmax of the logits
            - "multistart_sampling": like sampling, but with multi-start decoding
            - "multistart_greedy": like greedy, but with multi-start decoding
        num_starts: Number of multi-starts to use (must be provided); one extra rollout per
            instance is added for the incumbent, for num_starts + 1 rollouts in total
        mask_logits, temperature, tanh_clipping: Logit processing options, used only when
            the corresponding decoder attribute is None

    Returns:
        Tuple of (logprobs, actions, td, rewards)
    """
    # TODO: this could be refactored by decoding strategies

    # Collect logprobs
    logprobs = []
    actions = []

    decode_step = 0
    # Multi-start decoding: first action is chosen by ad-hoc node selection
    if num_starts > 1 or "multistart" in decode_type:
        action = env.select_start_nodes(td, num_starts + 1) % num_starts
        # Append incumbent solutions
        if iter_count > 0:
            action = unbatchify(action, num_starts + 1)
            action[:, -1] = best_solutions[:, decode_step]
            action = action.permute(1, 0).reshape(-1)

        # Expand td to batch_size * (num_starts + 1)
        td = batchify(td, num_starts + 1)

        td.set("action", action)
        td = env.step(td)["next"]
        logp = torch.zeros_like(
            td["action_mask"], device=td.device
        )  # first logprobs is 0, so p = logprobs.exp() = 1

        logprobs.append(logp)
        actions.append(action)

    # Main decoding: loop until all sequences are done
    while not td["done"].all():
        decode_step += 1
        logits, mask = self.forward(td, cached_embeds, num_starts + 1)

        logp = process_logits(
            logits,
            mask,
            temperature=self.temperature if self.temperature is not None else temperature,
            tanh_clipping=(
                self.tanh_clipping if self.tanh_clipping is not None else tanh_clipping
            ),
            mask_logits=self.mask_logits if self.mask_logits is not None else mask_logits,
        )

        # Select the indices of the next nodes in the sequences, result (batch_size) long
        action = decode_logprobs(logp, mask, decode_type=decode_type)

        if iter_count > 0:  # append incumbent solutions
            init_shp = action.shape
            action = unbatchify(action, num_starts + 1)
            action[:, -1] = best_solutions[:, decode_step]
            action = action.permute(1, 0).reshape(init_shp)

        td.set("action", action)
        td = env.step(td)["next"]

        # Collect output of step
        logprobs.append(logp)
        actions.append(action)

    logprobs, actions = torch.stack(logprobs, 1), torch.stack(actions, 1)
    rewards = env.get_reward(td, actions)
    return logprobs, actions, td, rewards
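The `unbatchify` / `permute` / `reshape` sequence in `forward_eas` overwrites, at every step, the action of the last rollout of each instance with the incumbent's action. Assuming the start-major layout that this permute implies (the batch is stacked `num_starts + 1` times, so flat index `s * batch_size + i` is rollout `s` of instance `i`), the effect can be sketched on a flat list; the function name is hypothetical:

```python
def inject_incumbent(actions, incumbent_step, num_starts):
    """Overwrite the last rollout of every instance with the incumbent's action.

    actions: flat list of length (num_starts + 1) * batch_size, start-major layout.
    incumbent_step: the incumbent solutions' actions at this step, one per instance.
    """
    batch_size = len(incumbent_step)
    assert len(actions) == (num_starts + 1) * batch_size
    out = list(actions)
    last = num_starts * batch_size  # first index of the (num_starts + 1)-th block
    out[last:last + batch_size] = incumbent_step
    return out


# 2 instances, 2 sampled starts + 1 incumbent rollout each
print(inject_incumbent([10, 11, 20, 21, 30, 31], [7, 8], num_starts=2))
```

Because the forced actions still go through `env.step`, the incumbent rollout reproduces the best solution exactly, and its log-likelihood under the current parameters is what the IL loss maximizes.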

Classes:

  • EASLayerNet

    Instantiate weights and biases for the added layer.

EASLayerNet

EASLayerNet(num_instances: int, emb_dim: int)

Bases: Module

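Instantiate weights and biases for the added layer, with a separate set of W1, b1, W2, b2 for each instance. The layer is defined as h = relu(emb W1 + b1); out = h W2 + b2. Wrapping the tensors in nn.Parameter registers them as trainable parameters (requires_grad=True). W2 and b2 are initialized to zero, so the layer initially outputs zeros and the residual connection in forward_pointer_attn_eas_lay leaves the attention unchanged.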

Parameters:

  • num_instances (int) –

    Number of instances in the dataset

  • emb_dim (int) –

    Dimension of the embedding

Methods:

  • forward

    Apply the layer to emb, the tensor argument of shape [num_instances, group_num, emb_dim]

Source code in rl4co/models/zoo/eas/nn.py
def __init__(self, num_instances: int, emb_dim: int):
    super().__init__()
    # W2 and b2 are initialized to zero, so the layer initially outputs zeros and the
    # residual connection (heads + eas_layer(heads)) acts as the identity
    self.W1 = nn.Parameter(torch.randn(num_instances, emb_dim, emb_dim))
    self.b1 = nn.Parameter(torch.randn(num_instances, 1, emb_dim))
    self.W2 = nn.Parameter(torch.zeros(num_instances, emb_dim, emb_dim))
    self.b2 = nn.Parameter(torch.zeros(num_instances, 1, emb_dim))
    torch.nn.init.xavier_uniform_(self.W1)
    torch.nn.init.xavier_uniform_(self.b1)
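Since every instance gets its own W1, b1, W2 and b2, the number of trainable parameters grows with `num_instances = batch_size * n_aug * n_runs` (as computed in `training_step`) and quadratically with the embedding size. A small arithmetic sketch; the function name is hypothetical and `embed_dim=128` below is only an example value:

```python
def eas_layer_num_params(batch_size, augment_size, num_parallel_runs, embed_dim):
    """Trainable parameters created by EASLayerNet for one training batch."""
    num_instances = batch_size * augment_size * num_parallel_runs
    per_instance = 2 * embed_dim * embed_dim + 2 * embed_dim  # W1, W2 and b1, b2
    return num_instances * per_instance


# EAS defaults (batch_size=2, augment_size=8, num_parallel_runs=1), embed_dim=128
print(eas_layer_num_params(2, 8, 1, 128))
```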

forward

forward(*args)

Apply the layer to emb, the tensor argument of shape [num_instances, group_num, emb_dim].

Source code in rl4co/models/zoo/eas/nn.py
def forward(self, *args):
    """emb: [num_instances, group_num, emb_dim]"""
    # get tensor arg (from partial instantiation)
    emb = [arg for arg in args if isinstance(arg, torch.Tensor)][0]
    h = torch.relu(torch.matmul(emb, self.W1) + self.b1.expand_as(emb))
    return torch.matmul(h, self.W2) + self.b2.expand_as(h)
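For a single instance and a single embedding vector, the layer and the residual connection used in `forward_pointer_attn_eas_lay` reduce to a few lines of plain Python. The sketch below (hypothetical helper names, lists instead of tensors) shows why zero-initialized W2 and b2 make the layer start as the identity:

```python
def matvec(v, M):
    """Row vector v (length d) times matrix M (d x d)."""
    return [sum(v[k] * M[k][j] for k in range(len(v))) for j in range(len(M[0]))]


def eas_layer(emb, W1, b1, W2, b2):
    """h = relu(emb W1 + b1); out = h W2 + b2, for one embedding vector."""
    h = [max(0.0, x + b) for x, b in zip(matvec(emb, W1), b1)]
    return [x + b for x, b in zip(matvec(h, W2), b2)]


def residual_eas(emb, W1, b1, W2, b2):
    """heads + eas_layer(heads), as in forward_pointer_attn_eas_lay."""
    return [e + o for e, o in zip(emb, eas_layer(emb, W1, b1, W2, b2))]


W1 = [[1.0, 2.0], [3.0, 4.0]]
zeros = [[0.0, 0.0], [0.0, 0.0]]
# With W2 = 0 and b2 = 0 (the initialization), the output equals the input
print(residual_eas([1.0, -1.0], W1, [5.0, 0.0], zeros, [0.0, 0.0]))
```

Only once W2 and b2 move away from zero during the search does the layer start to change the attention heads.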