Skip to content

Commit 8ff7005

Browse files
committed
SHAC: layer norm and gradient clipping
Starting to see progress training the ant environment.
1 parent 4000c95 commit 8ff7005

File tree

4 files changed

+25
-13
lines changed

4 files changed

+25
-13
lines changed

brax/training/agents/shac/losses.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
"""Proximal policy optimization training.
15+
"""Short-Horizon Actor Critic.
1616
17-
See: https://arxiv.org/pdf/1707.06347.pdf
17+
See: https://arxiv.org/pdf/2204.07137.pdf
1818
"""
1919

2020
from typing import Any, Tuple
@@ -46,8 +46,7 @@ def compute_shac_policy_loss(
4646
reward_scaling: float = 1.0) -> Tuple[jnp.ndarray, types.Metrics]:
4747
"""Computes SHAC critic loss.
4848
49-
This implements Eq. 5 of 2204.07137. It needs to account for any episodes where
50-
the episode terminates and include the terminal values appropriately.
49+
This implements Eq. 5 of 2204.07137.
5150
5251
Args:
5352
policy_params: Policy network parameters
@@ -129,7 +128,6 @@ def compute_shac_critic_loss(
129128
params: Params,
130129
normalizer_params: Any,
131130
data: types.Transition,
132-
rng: jnp.ndarray,
133131
shac_network: shac_networks.SHACNetworks,
134132
discounting: float = 0.9,
135133
reward_scaling: float = 1.0,

brax/training/agents/shac/networks.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,12 +76,14 @@ def make_shac_networks(
7676
observation_size,
7777
preprocess_observations_fn=preprocess_observations_fn,
7878
hidden_layer_sizes=policy_hidden_layer_sizes,
79-
activation=activation)
79+
activation=activation,
80+
layer_norm=True)
8081
value_network = networks.make_value_network(
8182
observation_size,
8283
preprocess_observations_fn=preprocess_observations_fn,
8384
hidden_layer_sizes=value_hidden_layer_sizes,
84-
activation=activation)
85+
activation=activation,
86+
layer_norm=True)
8587

8688
return SHACNetworks(
8789
policy_network=policy_network,

brax/training/agents/shac/train.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,14 @@ def train(environment: envs.Env,
130130
preprocess_observations_fn=normalize)
131131
make_policy = shac_networks.make_inference_fn(shac_network)
132132

133-
policy_optimizer = optax.adam(learning_rate=actor_learning_rate)
134-
value_optimizer = optax.adam(learning_rate=critic_learning_rate)
133+
policy_optimizer = optax.chain(
134+
optax.clip(1.0),
135+
optax.adam(learning_rate=actor_learning_rate, b1=0.7, b2=0.95)
136+
)
137+
value_optimizer = optax.chain(
138+
optax.clip(1.0),
139+
optax.adam(learning_rate=critic_learning_rate, b1=0.7, b2=0.95)
140+
)
135141

136142
value_loss_fn = functools.partial(
137143
shac_losses.compute_shac_critic_loss,
@@ -184,6 +190,7 @@ def f(carry, unused_t):
184190

185191
policy_gradient_update_fn = gradients.gradient_update_fn(
186192
rollout_loss_fn, policy_optimizer, pmap_axis_name=_PMAP_AXIS_NAME, has_aux=True)
193+
policy_gradient_update_fn = jax.jit(policy_gradient_update_fn)
187194

188195
def minibatch_step(
189196
carry, data: types.Transition,
@@ -194,7 +201,6 @@ def minibatch_step(
194201
params,
195202
normalizer_params,
196203
data,
197-
key_loss,
198204
optimizer_state=optimizer_state)
199205

200206
return (optimizer_state, params, key), metrics
@@ -317,7 +323,6 @@ def training_epoch_with_timing(
317323
key_envs = jnp.reshape(key_envs,
318324
(local_devices_to_use, -1) + key_envs.shape[1:])
319325
env_state = reset_fn(key_envs)
320-
print(f'env_state: {env_state.qp.pos.shape}')
321326

322327
if not eval_env:
323328
eval_env = env

brax/training/networks.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ class MLP(linen.Module):
4141
kernel_init: Initializer = jax.nn.initializers.lecun_uniform()
4242
activate_final: bool = False
4343
bias: bool = True
44+
layer_norm: bool = True
4445

4546
@linen.compact
4647
def __call__(self, data: jnp.ndarray):
@@ -52,6 +53,8 @@ def __call__(self, data: jnp.ndarray):
5253
kernel_init=self.kernel_init,
5354
use_bias=self.bias)(
5455
hidden)
56+
if self.layer_norm:
57+
hidden = linen.LayerNorm()(hidden)
5558
if i != len(self.layer_sizes) - 1 or self.activate_final:
5659
hidden = self.activation(hidden)
5760
return hidden
@@ -86,11 +89,13 @@ def make_policy_network(
8689
preprocess_observations_fn: types.PreprocessObservationFn = types
8790
.identity_observation_preprocessor,
8891
hidden_layer_sizes: Sequence[int] = (256, 256),
89-
activation: ActivationFn = linen.relu) -> FeedForwardNetwork:
92+
activation: ActivationFn = linen.relu,
93+
layer_norm: bool = False) -> FeedForwardNetwork:
9094
"""Creates a policy network."""
9195
policy_module = MLP(
9296
layer_sizes=list(hidden_layer_sizes) + [param_size],
9397
activation=activation,
98+
layer_norm=layer_norm,
9499
kernel_init=jax.nn.initializers.lecun_uniform())
95100

96101
def apply(processor_params, policy_params, obs):
@@ -107,11 +112,13 @@ def make_value_network(
107112
preprocess_observations_fn: types.PreprocessObservationFn = types
108113
.identity_observation_preprocessor,
109114
hidden_layer_sizes: Sequence[int] = (256, 256),
110-
activation: ActivationFn = linen.relu) -> FeedForwardNetwork:
115+
activation: ActivationFn = linen.relu,
116+
layer_norm: bool = False) -> FeedForwardNetwork:
111117
"""Creates a value network."""
112118
value_module = MLP(
113119
layer_sizes=list(hidden_layer_sizes) + [1],
114120
activation=activation,
121+
layer_norm=layer_norm,
115122
kernel_init=jax.nn.initializers.lecun_uniform())
116123

117124
def apply(processor_params, policy_params, obs):

0 commit comments

Comments
 (0)