Skip to content

Train and Evaluation

Train

run

run(cfg: DictConfig) -> Tuple[dict, dict]

Trains the model. Can additionally evaluate on a test set, using the best weights obtained during training. This method is wrapped in the optional @task_wrapper decorator, which controls the behavior on failure. This is useful for multiruns, saving info about the crash, etc.

Parameters:

  • cfg (DictConfig) –

    Configuration composed by Hydra.

Returns: Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.

Source code in rl4co/tasks/train.py
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
@utils.task_wrapper
def run(cfg: DictConfig) -> Tuple[dict, dict]:
    """Trains the model. Can additionally evaluate on a test set, using the best weights obtained
    during training.
    This method is wrapped in the optional @task_wrapper decorator, which controls the behavior on
    failure. Useful for multiruns, saving info about the crash, etc.

    Args:
        cfg (DictConfig): Configuration composed by Hydra.
    Returns:
        Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
    """

    # set seed for random number generators in pytorch, numpy and python.random.
    # Compare against None so that an explicit `seed: 0` is honored instead of skipped.
    if cfg.get("seed") is not None:
        L.seed_everything(cfg.seed, workers=True)

    # We instantiate the environment separately and then pass it to the model
    log.info(f"Instantiating environment <{cfg.env._target_}>")
    env = hydra.utils.instantiate(cfg.env)

    # The environment instance is passed as the first positional argument of the model target
    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(cfg.model, env)

    log.info("Instantiating callbacks...")
    callbacks: List[Callback] = utils.instantiate_callbacks(cfg.get("callbacks"))

    log.info("Instantiating loggers...")
    logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger"), model)

    log.info("Instantiating trainer...")
    trainer: RL4COTrainer = hydra.utils.instantiate(
        cfg.trainer,
        callbacks=callbacks,
        logger=logger,
    )

    object_dict = {
        "cfg": cfg,
        "model": model,
        "callbacks": callbacks,
        "logger": logger,
        "trainer": trainer,
    }

    if logger:
        log.info("Logging hyperparameters!")
        utils.log_hyperparameters(object_dict)

    if cfg.get("compile", False):
        log.info("Compiling model!")
        model = torch.compile(model)

    # Default so that the metric merge below also works when training is disabled
    train_metrics = {}

    if cfg.get("train"):
        log.info("Starting training!")
        trainer.fit(model=model, ckpt_path=cfg.get("ckpt_path"))

        train_metrics = trainer.callback_metrics

    if cfg.get("test"):
        log.info("Starting testing!")
        # `checkpoint_callback` can be None when no checkpoint callback is configured;
        # treat that the same as "no best checkpoint recorded".
        checkpoint_callback = trainer.checkpoint_callback
        ckpt_path = checkpoint_callback.best_model_path if checkpoint_callback is not None else ""
        if ckpt_path == "":
            log.warning("Best ckpt not found! Using current weights for testing...")
            ckpt_path = None
        trainer.test(model=model, ckpt_path=ckpt_path)
        log.info(f"Best ckpt path: {ckpt_path}")

    test_metrics = trainer.callback_metrics

    # merge train and test metrics (test values win on key collisions)
    metric_dict = {**train_metrics, **test_metrics}

    return metric_dict, object_dict

Evaluate

EvalBase

EvalBase(env, progress=True, **kwargs)

Base class for evaluation

Parameters:

  • env

    Environment

  • progress

    Whether to show progress bar

  • **kwargs

    Additional arguments (to be implemented in subclasses)

Source code in rl4co/tasks/eval.py
29
30
31
32
def __init__(self, env, progress=True, **kwargs):
    """Store the state shared by every evaluator.

    Args:
        env: Environment the policy is evaluated on.
        progress: Whether to show a progress bar.
        **kwargs: Extra options; handed to ``check_unused_kwargs`` (presumably so that
            unrecognized options can be reported — confirm against its definition).
    """
    check_unused_kwargs(self, kwargs)
    self.env, self.progress = env, progress

GreedyEval

GreedyEval(env, **kwargs)

Bases: EvalBase

Evaluates the policy using greedy decoding and single trajectory

Source code in rl4co/tasks/eval.py
93
94
95
def __init__(self, env, **kwargs):
    """Greedy, single-trajectory evaluator.

    Args:
        env: Environment the policy is evaluated on.
        **kwargs: All keyword arguments are first passed to ``check_unused_kwargs``;
            only ``progress`` (default ``True``) is forwarded to the base class.
    """
    check_unused_kwargs(self, kwargs)
    show_progress = kwargs.get("progress", True)
    super().__init__(env, show_progress)

AugmentationEval

AugmentationEval(
    env,
    num_augment=8,
    force_dihedral_8=False,
    feats=None,
    **kwargs
)

Bases: EvalBase

Evaluates the policy via N state augmentations. force_dihedral_8 forces the use of 8 augmentations (rotations and flips) as in POMO (see https://en.wikipedia.org/wiki/Examples_of_groups#dihedral_group_of_order_8).

Parameters:

  • num_augment (int, default: 8 ) –

    Number of state augmentations

  • force_dihedral_8 (bool, default: False ) –

    Whether to force the use of 8 augmentations

Source code in rl4co/tasks/eval.py
120
121
122
123
124
125
126
127
def __init__(self, env, num_augment=8, force_dihedral_8=False, feats=None, **kwargs):
    """Evaluator that scores the policy over several augmented copies of each state.

    Args:
        env: Environment the policy is evaluated on.
        num_augment: Number of state augmentations.
        force_dihedral_8: If True, use the "dihedral8" augmentation function instead of
            "symmetric".
        feats: Forwarded unchanged to ``StateAugmentation``.
        **kwargs: Passed to ``check_unused_kwargs``; ``progress`` (default ``True``) is
            forwarded to the base class.
    """
    check_unused_kwargs(self, kwargs)
    super().__init__(env, kwargs.get("progress", True))

    if force_dihedral_8:
        augment_fn = "dihedral8"
    else:
        augment_fn = "symmetric"
    self.augmentation = StateAugmentation(
        num_augment=num_augment, augment_fn=augment_fn, feats=feats
    )

SamplingEval

SamplingEval(
    env,
    samples,
    softmax_temp=None,
    select_best=True,
    temperature=1.0,
    top_p=0.0,
    top_k=0,
    **kwargs
)

Bases: EvalBase

Evaluates the policy via N samples from the policy

Parameters:

  • samples (int) –

    Number of samples to take

  • softmax_temp (float, default: None ) –

    Temperature for softmax sampling. The higher the temperature, the more random the sampling

Source code in rl4co/tasks/eval.py
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
def __init__(
    self,
    env,
    samples,
    softmax_temp=None,
    select_best=True,
    temperature=1.0,
    top_p=0.0,
    top_k=0,
    **kwargs,
):
    """Evaluator that draws several samples from the policy per instance.

    Args:
        env: Environment the policy is evaluated on.
        samples: Number of samples to take.
        softmax_temp: Temperature for softmax sampling; higher means more random.
        select_best: Stored as-is; presumably keeps only the best sample — confirm in
            the evaluation code.
        temperature: Stored as-is (default 1.0); how it interacts with ``softmax_temp``
            is not visible here.
        top_p: Stored as-is; presumably a nucleus-sampling threshold (0.0 = disabled?).
        top_k: Stored as-is; presumably a top-k cutoff (0 = disabled?).
        **kwargs: Passed to ``check_unused_kwargs``; ``progress`` (default ``True``) is
            forwarded to the base class.
    """
    check_unused_kwargs(self, kwargs)
    super().__init__(env, kwargs.get("progress", True))

    self.samples, self.softmax_temp = samples, softmax_temp
    self.temperature, self.select_best = temperature, select_best
    self.top_p, self.top_k = top_p, top_k

GreedyMultiStartEval

GreedyMultiStartEval(env, num_starts=None, **kwargs)

Bases: EvalBase

Evaluates the policy via num_starts greedy multi-start samples from the policy.

Parameters:

  • num_starts (int, default: None ) –

    Number of greedy multistarts to use

Source code in rl4co/tasks/eval.py
213
214
215
216
217
218
def __init__(self, env, num_starts=None, **kwargs):
    """Evaluator using ``num_starts`` greedy multi-start rollouts.

    Args:
        env: Environment the policy is evaluated on.
        num_starts: Number of greedy multistarts; required despite the ``None`` default.
        **kwargs: Passed to ``check_unused_kwargs``; ``progress`` (default ``True``) is
            forwarded to the base class.

    Raises:
        AssertionError: If ``num_starts`` is None.
    """
    check_unused_kwargs(self, kwargs)
    show_progress = kwargs.get("progress", True)
    super().__init__(env, show_progress)

    assert num_starts is not None, "Must specify num_starts"
    self.num_starts = num_starts

GreedyMultiStartAugmentEval

GreedyMultiStartAugmentEval(
    env,
    num_starts=None,
    num_augment=8,
    force_dihedral_8=False,
    feats=None,
    **kwargs
)

Bases: EvalBase

Evaluates the policy via num_starts samples from the policy and num_augment augmentations of each sample. force_dihedral_8 forces the use of 8 augmentations (rotations and flips) as in POMO (see https://en.wikipedia.org/wiki/Examples_of_groups#dihedral_group_of_order_8).

Parameters:

  • num_starts

    Number of greedy multistart samples

  • num_augment

    Number of augmentations per sample

  • force_dihedral_8

    If True, force the use of 8 augmentations (rotations and flips) as in POMO

Source code in rl4co/tasks/eval.py
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
def __init__(
    self,
    env,
    num_starts=None,
    num_augment=8,
    force_dihedral_8=False,
    feats=None,
    **kwargs,
):
    """Evaluator combining greedy multi-start rollouts with state augmentation.

    Args:
        env: Environment the policy is evaluated on.
        num_starts: Number of greedy multistart samples; required despite the ``None`` default.
        num_augment: Number of augmentations per sample.
        force_dihedral_8: If True, use the "dihedral8" augmentation function; only allowed
            together with ``num_augment == 8``.
        feats: Forwarded unchanged to ``StateAugmentation``.
        **kwargs: Passed to ``check_unused_kwargs``; ``progress`` (default ``True``) is
            forwarded to the base class.

    Raises:
        AssertionError: If ``num_starts`` is None, or if ``force_dihedral_8`` is set while
            ``num_augment != 8``.
    """
    check_unused_kwargs(self, kwargs)
    show_progress = kwargs.get("progress", True)
    super().__init__(env, show_progress)

    assert num_starts is not None, "Must specify num_starts"
    self.num_starts = num_starts

    # Equivalent to `not (num_augment != 8 and force_dihedral_8)` by De Morgan
    assert num_augment == 8 or not force_dihedral_8, "Cannot force dihedral 8 when num_augment != 8"
    augment_fn = "dihedral8" if force_dihedral_8 else "symmetric"
    self.augmentation = StateAugmentation(
        num_augment=num_augment, augment_fn=augment_fn, feats=feats
    )

get_automatic_batch_size

get_automatic_batch_size(
    eval_fn, start_batch_size=8192, max_batch_size=4096
)

Automatically reduces the batch size based on the eval function

Parameters:

  • eval_fn

    The eval function

  • start_batch_size

    The starting batch size. This should be the theoretical maximum batch size

  • max_batch_size

    The maximum batch size. This is the practical maximum batch size

Source code in rl4co/tasks/eval.py
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
def get_automatic_batch_size(eval_fn, start_batch_size=8192, max_batch_size=4096):
    """Automatically reduces the batch size based on the eval function.

    The starting batch size is divided by every per-instance multiplier that ``eval_fn``
    exposes: ``num_starts // 10``, ``num_augment``, and ``samples``. An attribute is
    ignored when it is missing or None. The result is capped at ``max_batch_size`` and
    rounded down to a power of two.

    Args:
        eval_fn: The eval function (evaluator object).
        start_batch_size: The starting batch size. This should be the theoretical maximum batch size.
        max_batch_size: The maximum batch size. This is the practical maximum batch size.

    Returns:
        int: The chosen batch size, a power of two that is at least 1.
    """
    batch_size = start_batch_size
    effective_ratio = 1

    # Collect the divisors. Each one is clamped to >= 1 so that small values
    # (e.g. num_starts < 10, where num_starts // 10 == 0) cannot cause a division by zero.
    divisors = []
    num_starts = getattr(eval_fn, "num_starts", None)
    if num_starts is not None:
        # num_starts is scaled down by 10 before being used as a divisor
        divisors.append(max(1, num_starts // 10))
    for attr in ("num_augment", "samples"):
        value = getattr(eval_fn, attr, None)
        if value is not None:
            divisors.append(max(1, value))

    for divisor in divisors:
        batch_size //= divisor
        effective_ratio *= divisor

    # Clamp to [1, max_batch_size]: np.log2(0) is -inf and int(-inf) raises OverflowError
    batch_size = max(1, min(batch_size, max_batch_size))
    # get closest (lower) integer power of 2
    batch_size = 2 ** int(np.log2(batch_size))

    print(f"Effective batch size: {batch_size} (ratio: {effective_ratio})")

    return batch_size