Train and Evaluation
Train
Functions:

- run – Trains the model. Can additionally evaluate on a test set, using the best weights obtained during training.

run
Trains the model. Can additionally evaluate on a test set, using the best weights obtained during training. This method is wrapped in an optional @task_wrapper decorator that controls behavior on failure, which is useful for multiruns, saving info about the crash, etc.
Parameters:

- cfg (DictConfig) – Configuration composed by Hydra.

Returns:

- Tuple[dict, dict] – Dict with metrics and dict with all instantiated objects.
Source code in rl4co/tasks/train.py, lines 21–95.
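A minimal entry-point sketch for invoking run from your own script. The config_path and config_name below are assumptions; point them at a Hydra config tree that composes the options run expects. The return value follows the Tuple[dict, dict] contract documented above.

```python
# Sketch of a Hydra entry point for `run`. The config_path and config_name
# here are assumptions -- point them at your own Hydra config tree.
import hydra
from omegaconf import DictConfig

from rl4co.tasks.train import run


@hydra.main(version_base="1.3", config_path="configs", config_name="main")
def main(cfg: DictConfig) -> None:
    # run() trains the model and returns (metrics dict, instantiated objects dict)
    metric_dict, object_dict = run(cfg)
    print(metric_dict)


if __name__ == "__main__":
    main()
```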
Evaluate
Classes:

- EvalBase – Base class for evaluation
- GreedyEval – Evaluates the policy using greedy decoding and a single trajectory
- AugmentationEval – Evaluates the policy via N state augmentations
- SamplingEval – Evaluates the policy via N samples from the policy
- GreedyMultiStartEval – Evaluates the policy via num_starts greedy multistart samples from the policy
- GreedyMultiStartAugmentEval – Evaluates the policy via num_starts samples from the policy and num_augment augmentations of each sample

Functions:

- get_automatic_batch_size – Automatically reduces the batch size based on the eval function
EvalBase
EvalBase(env, progress=True, **kwargs)
Base class for evaluation

Parameters:

- env – Environment
- progress (default: True) – Whether to show the progress bar
- **kwargs – Additional arguments (to be implemented in subclasses)
Source code in rl4co/tasks/eval.py, lines 29–32.
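Subclasses supply the actual decoding logic. The sketch below shows the general shape a custom evaluator might take; the _inner per-batch hook and the policy call signature are assumptions inferred from how the bundled subclasses appear to be structured, so check rl4co/tasks/eval.py for the exact interface before relying on them.

```python
# Hypothetical subclass sketch; `_inner` and the policy call signature are
# assumptions -- verify the actual hook names in rl4co/tasks/eval.py.
from rl4co.tasks.eval import EvalBase


class SingleSampleEval(EvalBase):
    """Toy evaluator: one stochastic rollout per instance."""

    name = "single_sample"  # assumed naming convention of the bundled evaluators

    def __init__(self, env, **kwargs):
        super().__init__(env, kwargs.get("progress", True))

    def _inner(self, policy, td):
        # Assumed per-batch hook: decode once and return rewards and actions
        out = policy(td.clone(), decode_type="sampling", return_actions=True)
        return out["reward"], out["actions"]
```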
GreedyEval
GreedyEval(env, **kwargs)
Bases: EvalBase
Evaluates the policy using greedy decoding and a single trajectory
Source code in rl4co/tasks/eval.py, lines 93–95.
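A usage sketch: build an environment and a policy, wrap a test set in a dataloader, and call the evaluator. The dataset helper and the evaluator's (policy, dataloader) call signature are assumptions; only the constructor shown above is taken from this page.

```python
# Usage sketch; the dataset helper and the evaluator call signature
# (policy, dataloader) are assumptions -- see rl4co/tasks/eval.py.
from torch.utils.data import DataLoader

from rl4co.envs import TSPEnv
from rl4co.models import AttentionModel
from rl4co.tasks.eval import GreedyEval

env = TSPEnv(generator_params={"num_loc": 50})
model = AttentionModel(env)  # in practice, load trained weights instead

dataset = env.dataset(batch_size=[1_000])  # assumed dataset helper
dataloader = DataLoader(dataset, batch_size=256, collate_fn=dataset.collate_fn)

evaluator = GreedyEval(env)
results = evaluator(model.policy, dataloader)  # assumed to return rewards/actions
```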
AugmentationEval
AugmentationEval(
env,
num_augment=8,
force_dihedral_8=False,
feats=None,
**kwargs
)
Bases: EvalBase
Evaluates the policy via N state augmentations. force_dihedral_8 forces the use of 8 augmentations (rotations and flips) as in POMO: https://en.wikipedia.org/wiki/Examples_of_groups#dihedral_group_of_order_8
Parameters:

- num_augment (int, default: 8) – Number of state augmentations
- force_dihedral_8 (bool, default: False) – Whether to force the use of 8 augmentations
Source code in rl4co/tasks/eval.py, lines 119–126.
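A brief instantiation sketch. Forcing the dihedral-8 augmentations (4 rotations × 2 reflections) reproduces the POMO instance-augmentation setup; the evaluator call convention is hedged as before.

```python
from rl4co.envs import TSPEnv
from rl4co.tasks.eval import AugmentationEval

env = TSPEnv(generator_params={"num_loc": 50})
# force_dihedral_8=True pins the augmentations to the 8 rotation/flip
# symmetries of the square (the dihedral group of order 8), as in POMO
evaluator = AugmentationEval(env, force_dihedral_8=True)
# results = evaluator(model.policy, dataloader)  # assumed call convention, as above
```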
SamplingEval
SamplingEval(
env,
samples,
softmax_temp=None,
select_best=True,
temperature=1.0,
top_p=0.0,
top_k=0,
**kwargs
)
Bases: EvalBase
Evaluates the policy via N samples from the policy
Parameters:

- samples (int) – Number of samples to take
- softmax_temp (float, default: None) – Temperature for softmax sampling. The higher the temperature, the more random the sampling
Source code in rl4co/tasks/eval.py, lines 160–179.
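A brief instantiation sketch; samples and softmax_temp come from the parameter list above, while the meaning of select_best is an assumption read off the constructor signature.

```python
from rl4co.envs import TSPEnv
from rl4co.tasks.eval import SamplingEval

env = TSPEnv(generator_params={"num_loc": 50})
# 100 stochastic rollouts per instance; a higher softmax_temp flattens the
# action distribution, making sampling more random. select_best=True
# presumably keeps only the best rollout per instance (assumption).
evaluator = SamplingEval(env, samples=100, softmax_temp=1.0, select_best=True)
# results = evaluator(model.policy, dataloader)  # assumed call convention, as above
```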
GreedyMultiStartEval
GreedyMultiStartEval(env, num_starts=None, **kwargs)
Bases: EvalBase
Evaluates the policy via num_starts greedy multistart samples from the policy
Parameters:

- num_starts (int, default: None) – Number of greedy multistarts to use
Source code in rl4co/tasks/eval.py, lines 211–216.
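A brief sketch. Choosing num_starts equal to the number of nodes (one greedy rollout from each possible first action, in POMO style) is a common convention; what num_starts=None defaults to is not stated on this page.

```python
from rl4co.envs import TSPEnv
from rl4co.tasks.eval import GreedyMultiStartEval

env = TSPEnv(generator_params={"num_loc": 50})
# One greedy rollout per start; for a 50-node TSP, one start per node
# mirrors the POMO multistart convention
evaluator = GreedyMultiStartEval(env, num_starts=50)
# results = evaluator(model.policy, dataloader)  # assumed call convention, as above
```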
GreedyMultiStartAugmentEval
GreedyMultiStartAugmentEval(
env,
num_starts=None,
num_augment=8,
force_dihedral_8=False,
feats=None,
**kwargs
)
Bases: EvalBase
Evaluates the policy via num_starts samples from the policy and num_augment augmentations of each sample. force_dihedral_8 forces the use of 8 augmentations (rotations and flips) as in POMO: https://en.wikipedia.org/wiki/Examples_of_groups#dihedral_group_of_order_8
Parameters:

- num_starts – Number of greedy multistart samples
- num_augment – Number of augmentations per sample
- force_dihedral_8 – If True, force the use of 8 augmentations (rotations and flips) as in POMO
Source code in rl4co/tasks/eval.py, lines 252–273.
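A brief sketch; note that the rollout budget is multiplicative in the two settings.

```python
from rl4co.envs import TSPEnv
from rl4co.tasks.eval import GreedyMultiStartAugmentEval

env = TSPEnv(generator_params={"num_loc": 50})
# 50 starts x 8 dihedral augmentations = 400 greedy rollouts per instance
evaluator = GreedyMultiStartAugmentEval(env, num_starts=50, force_dihedral_8=True)
# results = evaluator(model.policy, dataloader)  # assumed call convention, as above
```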
get_automatic_batch_size
get_automatic_batch_size(
eval_fn, start_batch_size=8192, max_batch_size=4096
)
Automatically reduces the batch size based on the eval function
Parameters:

- eval_fn – The eval function
- start_batch_size (default: 8192) – The starting batch size. This should be the theoretical maximum batch size
- max_batch_size (default: 4096) – The maximum batch size. This is the practical maximum batch size
Source code in rl4co/tasks/eval.py, lines 304–332.
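A usage sketch, assuming the evaluator instance itself is what gets passed as eval_fn (the page does not pin this down). The idea is that more expensive eval methods, e.g. many samples or augmentations, end up with a proportionally smaller batch size.

```python
from rl4co.envs import TSPEnv
from rl4co.tasks.eval import SamplingEval, get_automatic_batch_size

env = TSPEnv(generator_params={"num_loc": 50})
evaluator = SamplingEval(env, samples=100)
# Assumed usage: scale the batch size down from start_batch_size according
# to how memory-hungry the eval function is, capped at max_batch_size
batch_size = get_automatic_batch_size(
    evaluator, start_batch_size=8192, max_batch_size=4096
)
print(batch_size)
```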