Skip to content

Commit 07dcc4a

Browse files
committed
SHAC: tweak layer norm
1 parent 8ff7005 commit 07dcc4a

File tree

5 files changed: +18 additions, −12 deletions

brax/training/agents/ppo/networks.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,8 @@ def make_ppo_networks(
6666
.identity_observation_preprocessor,
6767
policy_hidden_layer_sizes: Sequence[int] = (32,) * 4,
6868
value_hidden_layer_sizes: Sequence[int] = (256,) * 5,
69-
activation: networks.ActivationFn = linen.swish) -> PPONetworks:
69+
activation: networks.ActivationFn = linen.swish,
70+
layer_norm: bool = False) -> PPONetworks:
7071
"""Make PPO networks with preprocessor."""
7172
parametric_action_distribution = distribution.NormalTanhDistribution(
7273
event_size=action_size)
@@ -75,12 +76,14 @@ def make_ppo_networks(
7576
observation_size,
7677
preprocess_observations_fn=preprocess_observations_fn,
7778
hidden_layer_sizes=policy_hidden_layer_sizes,
78-
activation=activation)
79+
activation=activation,
80+
layer_norm=layer_norm)
7981
value_network = networks.make_value_network(
8082
observation_size,
8183
preprocess_observations_fn=preprocess_observations_fn,
8284
hidden_layer_sizes=value_hidden_layer_sizes,
83-
activation=activation)
85+
activation=activation,
86+
layer_norm=layer_norm)
8487

8588
return PPONetworks(
8689
policy_network=policy_network,

brax/training/agents/shac/losses.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def compute_shac_policy_loss(
8989
# jax implementation of https://github.com/NVlabs/DiffRL/blob/a4c0dd1696d3c3b885ce85a3cb64370b580cb913/algorithms/shac.py#L227
9090
def sum_step(carry, target_t):
9191
gam, rew_acc = carry
92-
reward, v, termination = target_t
92+
reward, termination = target_t
9393

9494
# clean up gamma and rew_acc for done envs, otherwise update
9595
rew_acc = jnp.where(termination, 0, rew_acc + gam * reward)
@@ -100,7 +100,7 @@ def sum_step(carry, target_t):
100100
rew_acc = jnp.zeros_like(terminal_values)
101101
gam = jnp.ones_like(terminal_values)
102102
(gam, last_rew_acc), (gam_acc, rew_acc) = jax.lax.scan(sum_step, (gam, rew_acc),
103-
(rewards, values, termination))
103+
(rewards, termination))
104104

105105
policy_loss = jnp.sum(-last_rew_acc - gam * terminal_values)
106106
# for trials that are truncated (i.e. hit the episode length) include reward for
@@ -118,7 +118,6 @@ def sum_step(carry, target_t):
118118
total_loss = policy_loss + entropy_loss
119119

120120
return total_loss, {
121-
'total_loss': total_loss,
122121
'policy_loss': policy_loss,
123122
'entropy_loss': entropy_loss
124123
}

brax/training/agents/shac/networks.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,8 @@ def make_shac_networks(
6767
.identity_observation_preprocessor,
6868
policy_hidden_layer_sizes: Sequence[int] = (32,) * 4,
6969
value_hidden_layer_sizes: Sequence[int] = (256,) * 5,
70-
activation: networks.ActivationFn = linen.swish) -> SHACNetworks:
70+
activation: networks.ActivationFn = linen.elu,
71+
layer_norm: bool = True) -> SHACNetworks:
7172
"""Make SHAC networks with preprocessor."""
7273
parametric_action_distribution = distribution.NormalTanhDistribution(
7374
event_size=action_size)
@@ -77,13 +78,13 @@ def make_shac_networks(
7778
preprocess_observations_fn=preprocess_observations_fn,
7879
hidden_layer_sizes=policy_hidden_layer_sizes,
7980
activation=activation,
80-
layer_norm=True)
81+
layer_norm=layer_norm)
8182
value_network = networks.make_value_network(
8283
observation_size,
8384
preprocess_observations_fn=preprocess_observations_fn,
8485
hidden_layer_sizes=value_hidden_layer_sizes,
8586
activation=activation,
86-
layer_norm=True)
87+
layer_norm=layer_norm)
8788

8889
return SHACNetworks(
8990
policy_network=policy_network,

brax/training/agents/shac/train.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ def train(environment: envs.Env,
8383
reward_scaling: float = 1.,
8484
tau: float = 0.005, # this is 1-alpha from the original paper
8585
lambda_: float = .95,
86+
td_lambda: bool = True,
8687
deterministic_eval: bool = False,
8788
network_factory: types.NetworkFactory[
8889
shac_networks.SHACNetworks] = shac_networks.make_shac_networks,
@@ -144,7 +145,8 @@ def train(environment: envs.Env,
144145
shac_network=shac_network,
145146
discounting=discounting,
146147
reward_scaling=reward_scaling,
147-
lambda_=lambda_)
148+
lambda_=lambda_,
149+
td_lambda=td_lambda)
148150

149151
value_gradient_update_fn = gradients.gradient_update_fn(
150152
value_loss_fn, value_optimizer, pmap_axis_name=_PMAP_AXIS_NAME, has_aux=True)

brax/training/networks.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,13 @@ def __call__(self, data: jnp.ndarray):
5151
hidden_size,
5252
name=f'hidden_{i}',
5353
kernel_init=self.kernel_init,
54+
dtype=get_dtype(self.half_precision),
5455
use_bias=self.bias)(
5556
hidden)
56-
if self.layer_norm:
57-
hidden = linen.LayerNorm()(hidden)
5857
if i != len(self.layer_sizes) - 1 or self.activate_final:
5958
hidden = self.activation(hidden)
59+
if self.layer_norm:
60+
hidden = linen.LayerNorm()(hidden)
6061
return hidden
6162

6263

Comments (0)