Skip to content

Commit ab7f75c

Browse files
committed
convert alpha to log value
1 parent da25b3e commit ab7f75c

File tree

1 file changed

+3
-7
lines changed

1 file changed

+3
-7
lines changed

brax/training/agents/sac/train.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -80,11 +80,11 @@ def _init_training_state(
8080
alpha_optimizer: optax.GradientTransformation,
8181
policy_optimizer: optax.GradientTransformation,
8282
q_optimizer: optax.GradientTransformation,
83-
initial_alpha: float = 0.0,
83+
initial_alpha: float = 1.0,
8484
) -> TrainingState:
8585
"""Inits the training state and replicates it over devices."""
8686
key_policy, key_q = jax.random.split(key)
87-
log_alpha = jnp.asarray(initial_alpha, dtype=jnp.float32)
87+
log_alpha = jnp.asarray(jnp.log(initial_alpha), dtype=jnp.float32)
8888
alpha_optimizer_state = alpha_optimizer.init(log_alpha)
8989

9090
policy_params = sac_network.policy_network.init(key_policy)
@@ -162,10 +162,6 @@ def train(
162162
num_envs: the number of parallel environments to use for rollouts
163163
NOTE: `num_envs` must be divisible by the total number of chips since each
164164
chip gets `num_envs // total_number_of_chips` environments to roll out
165-
NOTE: `batch_size * num_minibatches` must be divisible by `num_envs` since
166-
data generated by `num_envs` parallel envs gets used for gradient
167-
updates over `num_minibatches` of data, where each minibatch has a
168-
leading dimension of `batch_size`
169165
num_eval_envs: the number of envs to use for evaluation. Each env will run 1
170166
episode, and all envs run in parallel during eval.
171167
learning_rate: learning rate for SAC loss
@@ -178,7 +174,7 @@ def train(
178174
max_devices_per_host: maximum number of chips to use per host process
179175
reward_scaling: float scaling for reward
180176
tau: interpolation factor in polyak averaging for target networks
181-
intial_alpha: initial value for the temperature parameter alpha
177+
initial_alpha: initial value for the temperature parameter α
182178
min_replay_size: the minimum number of samples in the replay buffer before
183179
starting training. This is used to prefill the replay buffer with random
184180
samples before training starts

0 commit comments

Comments
 (0)