
Training


Training is handled by the Trainer class; you can find its source code here: Trainer

Training is divided into epochs, where each epoch collects:

  • N self-play episodes, sampling minibatches from experience replay memory every M episodes
  • K episodes against any number of baselines (optional)

Trainers will continue to train indefinitely until they are terminated.
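
To make the epoch structure concrete, here is a minimal sketch of the loop described above. The function and parameter names (collect_self_play_episode, training_step, evaluate_vs_baseline, and the callables passed in) are hypothetical placeholders for this example, not the Trainer's actual API.

```python
def run_epoch(model, replay_memory, baselines, config,
              collect_self_play_episode, training_step, evaluate_vs_baseline):
    """Sketch of one epoch: N self-play episodes with periodic training steps,
    then evaluation episodes against each baseline."""
    for episode_idx in range(1, config["episodes_per_epoch"] + 1):
        episode = collect_self_play_episode(model)   # uses the current model weights
        replay_memory.add(episode)

        # Every M collected episodes, run a training step (if memory has enough samples).
        if episode_idx % config["episodes_per_minibatch"] == 0 and replay_memory.ready():
            training_step(model, replay_memory)

    # Optional: K evaluation episodes against each configured baseline.
    for baseline in baselines:
        evaluate_vs_baseline(model, baseline)        # collects K episodes; see test_config
```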

Self-Play Episodes

N self-play episodes are collected each epoch. These episodes are collected using the current trained model, and learning happens online, meaning that the model may change during the course of a collected episode as training steps are completed. The episodes_per_epoch parameter controls the number of episodes, while parallel_envs controls the number of environments collecting episodes in parallel. All environments step at once in parallel, and any terminated episodes are then added to replay memory.

After an episode is completed, the relevant information (observations, MCTS visit counts, etc.) is collected and stored in Replay Memory, which is sampled from during training steps. Replay Memory can be configured to store a maximum number of samples, as well as to require a minimum number of samples before any training steps are allowed to execute. Once replay memory has enough samples, each epoch completes N self-play episodes and one training step for every M episodes collected.

Replay memory may also be configured to sample episodes first, and then sample steps within each sampled episode. This helps mitigate the effect of environments where the number of steps per episode can vary greatly, which would otherwise cause longer episodes to be over-represented in training minibatches.
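
For illustration, here is a minimal sketch of the two sampling modes described above (episode-first sampling versus a flat queue of observation/result pairs). The class and method names are assumptions for this example, not the library's actual API.

```python
import random

class ReplayMemorySketch:
    """Illustrative only: episode-first vs. flat sampling of (observation, target) pairs."""

    def __init__(self, max_size, min_size, sample_games):
        self.episodes = []                # each episode is a list of (observation, targets) steps
        self.max_size = max_size
        self.min_size = min_size
        self.sample_games = sample_games  # mirrors replay_memory_sample_games

    def add(self, episode_steps):
        self.episodes.append(episode_steps)
        while sum(len(e) for e in self.episodes) > self.max_size:
            self.episodes.pop(0)          # drop the oldest episode once the buffer is full

    def ready(self):
        # Minimum-size check counts episodes or steps depending on the sampling mode.
        if self.sample_games:
            return len(self.episodes) >= self.min_size
        return sum(len(e) for e in self.episodes) >= self.min_size

    def sample(self, minibatch_size):
        if self.sample_games:
            # Sample an episode first, then a step within it, so long episodes
            # are not over-represented in minibatches.
            return [random.choice(random.choice(self.episodes))
                    for _ in range(minibatch_size)]
        # Otherwise sample uniformly over all stored steps.
        flat = [step for episode in self.episodes for step in episode]
        return random.sample(flat, minibatch_size)
```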

Training Step

After M episodes are collected and added to replay memory, a training step is executed. Each training step samples a minibatch of size minibatch_size from replay memory.
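
A minimal sketch of what a single training step could look like, assuming a PyTorch model, an SGD optimizer with momentum, and an ExponentialLR scheduler as described in the parameters below. The compute_loss helper is the loss sketched after the formula in the next paragraphs; all names here are illustrative, not the Trainer's actual code.

```python
def training_step(model, optimizer, scheduler, replay_memory, config, compute_loss):
    # Sample a minibatch of (observation, search policy, outcome) targets,
    # assumed here to already be stacked into tensors.
    obs, pi, z = replay_memory.sample(config["minibatch_size"])

    log_p, v = model(obs)                 # two-headed evaluator: policy + value
    loss = compute_loss(log_p, v, pi, z,
                        config["policy_factor"], config["c_reg"],
                        model.parameters())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()                      # lr decay via ExponentialLR (step frequency is an assumption)
```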

A training step's loss is calculated as value loss + policy loss + L2 regularization, or:

$l = (z-v)^2 - c \, \pi^T \log{p} + r||\theta||^2$

where the value loss is given as the mean-squared error of the episode outcome $z$ and the predicted outcome $v$, and the policy loss is given as the cross-entropy of the search policy $\pi$ and the predicted policy vector $p$. A policy factor $c$ is also included, which can be used to re-balance the relative weight of the two losses. Finally, L2-regularization over the model parameters $\theta$, controlled by $r$, is applied to prevent overfitting.

It's worth noting that all supported trainable algorithms are two-headed evaluators -- they output a policy p and an evaluation v.
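
As a concrete illustration, here is a sketch of how this loss might be computed for a two-headed network in PyTorch. The tensor names and shapes are assumptions for the example, not the Trainer's actual implementation; in practice the L2 term is often handled by the optimizer's weight decay rather than added explicitly.

```python
import torch
import torch.nn.functional as F

def compute_loss(log_p, v, pi, z, policy_factor, c_reg, parameters):
    """Sketch of the combined loss: value MSE + weighted policy cross-entropy + L2.

    log_p: (batch, actions) log-probabilities from the policy head
    v:     (batch,) predicted outcomes from the value head
    pi:    (batch, actions) MCTS visit-count targets
    z:     (batch,) episode outcomes
    """
    value_loss = F.mse_loss(v, z)
    policy_loss = -(pi * log_p).sum(dim=1).mean()    # cross-entropy against the search policy
    l2 = sum((p ** 2).sum() for p in parameters)     # explicit L2 penalty over model parameters
    return value_loss + policy_factor * policy_loss + c_reg * l2
```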

Evaluation/Test Step

After collecting N self-play episodes, K evaluation episodes are collected against each configured baseline. See Baselines for more information on baselines. After evaluation episodes are collected, metrics on training progress and evaluation statistics are logged (and plotted if running in an interactive Python notebook). Self-play episodes that were in progress when evaluation began are resumed after evaluation completes.

Checkpoints

After an evaluation step, a checkpoint will be saved, which contains:

  • Model/Optimizer parameters
  • History (loss, evaluation data, other metrics)
  • Training config
  • Env config

These checkpoints can be loaded so that you can resume training right where you left off without having to remember the original configuration. Note that checkpoints do not save replay memory, so it must be repopulated when loading from a checkpoint. A checkpoint can be loaded using the --checkpoint flag, or by providing a checkpoint in a trainable algorithm's config (please use --checkpoint for training).
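
For illustration, a minimal sketch of what saving and loading a checkpoint like this might look like with PyTorch. The key names and helper functions are assumptions for the example, not the project's actual checkpoint format.

```python
import torch

def save_checkpoint(path, model, optimizer, history, train_config, env_config):
    # Bundle everything listed above into a single file.
    torch.save({
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "history": history,            # loss, evaluation data, other metrics
        "train_config": train_config,
        "env_config": env_config,
    }, path)

def load_checkpoint(path, model, optimizer):
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint["model_state"])
    optimizer.load_state_dict(checkpoint["optimizer_state"])
    # Replay memory is not stored, so it must be repopulated before training resumes.
    return checkpoint["history"], checkpoint["train_config"], checkpoint["env_config"]
```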

Training Parameters

  • algo_config: hyperparameters/configuration for the trained algorithm, see algorithm wiki page for more details
  • episodes_per_epoch: number of self-play episodes (N above) to collect per epoch
  • episodes_per_minibatch: number of episodes (M above) that must be collected before running a training step
  • minibatch_size: number of samples per minibatch
  • learning_rate: optimizer learning rate
  • momentum: SGD momentum
  • c_reg: L2-regularization factor
  • lr_decay_gamma: learning rate decay applied via torch.optim.lr_scheduler.ExponentialLR
  • parallel_envs: number of envs to collect self-play episodes from in parallel
  • policy_factor: $c$ in the loss above, used to bias the overall loss towards/against the policy component.
  • replay_memory_sample_games: if True, samples an episode from memory first, then samples observation/result pairs within the sampled episode. If False, samples from a flat queue of observation/result pairs.
  • replay_memory_min_size: minimum number of samples/episodes that must be acquired before training steps execute (important: this behaves differently depending on what replay_memory_sample_games is set to). If replay memory does not meet this threshold, training steps are skipped.
  • replay_memory_max_size: maximum size of replay memory queue/buffer
  • test_config: configuration for Evaluation step, see Evaluation & Testing

Example Configurations
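
A hedged example of what a training configuration using the parameters above might look like, expressed as a Python dict. The values shown are illustrative assumptions only, not recommended defaults, and the actual config format may differ.

```python
# Illustrative values only; the real config format and defaults may differ.
example_train_config = {
    "algo_config": {},                    # algorithm-specific hyperparameters
    "episodes_per_epoch": 100,            # N self-play episodes per epoch
    "episodes_per_minibatch": 1,          # M: train after every collected episode
    "minibatch_size": 256,
    "learning_rate": 0.01,
    "momentum": 0.9,
    "c_reg": 1e-4,
    "lr_decay_gamma": 0.99,
    "parallel_envs": 64,
    "policy_factor": 1.0,
    "replay_memory_sample_games": True,   # sample episodes first, then steps within them
    "replay_memory_min_size": 1000,
    "replay_memory_max_size": 100000,
    "test_config": {},                    # see Evaluation & Testing
}
```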
