@@ -80,11 +80,11 @@ def _init_training_state(
     alpha_optimizer: optax.GradientTransformation,
     policy_optimizer: optax.GradientTransformation,
     q_optimizer: optax.GradientTransformation,
-    initial_alpha: float = 0.0,
+    initial_alpha: float = 1.0,
 ) -> TrainingState:
   """Inits the training state and replicates it over devices."""
   key_policy, key_q = jax.random.split(key)
-  log_alpha = jnp.asarray(initial_alpha, dtype=jnp.float32)
+  log_alpha = jnp.asarray(jnp.log(initial_alpha), dtype=jnp.float32)
   alpha_optimizer_state = alpha_optimizer.init(log_alpha)

   policy_params = sac_network.policy_network.init(key_policy)
@@ -162,10 +162,6 @@ def train(
     num_envs: the number of parallel environments to use for rollouts
       NOTE: `num_envs` must be divisible by the total number of chips since each
         chip gets `num_envs // total_number_of_chips` environments to roll out
-      NOTE: `batch_size * num_minibatches` must be divisible by `num_envs` since
-        data generated by `num_envs` parallel envs gets used for gradient
-        updates over `num_minibatches` of data, where each minibatch has a
-        leading dimension of `batch_size`
     num_eval_envs: the number of envs to use for evluation. Each env will run 1
       episode, and all envs run in parallel during eval.
     learning_rate: learning rate for SAC loss
@@ -178,7 +174,7 @@ def train(
     max_devices_per_host: maximum number of chips to use per host process
     reward_scaling: float scaling for reward
     tau: interpolation factor in polyak averaging for target networks
-    intial_alpha: initial value for the temperature parameter alpha
+    initial_alpha: initial value for the temperature parameter α
     min_replay_size: the minimum number of samples in the replay buffer before
       starting training. This is used to prefill the replay buffer with random
       samples before training starts