
Commit 4000c95

SHAC: add target network
1 parent 44830aa commit 4000c95

File tree: 2 files changed (+69, -95 lines)

brax/training/agents/shac/losses.py

Lines changed: 59 additions & 93 deletions
@@ -83,29 +83,32 @@ def compute_shac_policy_loss(
   truncation = data.extras['state_extras']['truncation']
   termination = (1 - data.discount) * (1 - truncation)
 
-  horizon = rewards.shape[0]
+  # Append terminal values to get [v1, ..., v_t+1]
+  values_t_plus_1 = jnp.concatenate(
+      [values[1:], jnp.expand_dims(terminal_values, 0)], axis=0)
 
+  # jax implementation of https://github.com/NVlabs/DiffRL/blob/a4c0dd1696d3c3b885ce85a3cb64370b580cb913/algorithms/shac.py#L227
   def sum_step(carry, target_t):
-    gam, acc = carry
-    reward, v, truncation, termination = target_t
-    acc = acc + jnp.where(truncation + termination, gam * v, gam * reward)
+    gam, rew_acc = carry
+    reward, v, termination = target_t
+
+    # clean up gamma and rew_acc for done envs, otherwise update
+    rew_acc = jnp.where(termination, 0, rew_acc + gam * reward)
     gam = jnp.where(termination, 1.0, gam * discounting)
-    return (gam, acc), (acc)
 
-  acc = terminal_values * (discounting ** horizon) * (1 - termination[-1]) * (1 - truncation[-1])
-  jax.debug.print('acc shape: {x}', x=acc.shape)
-  gam = jnp.ones_like(terminal_values)
-  (_, acc), (temp) = jax.lax.scan(sum_step, (gam, acc),
-                                  (rewards, values, truncation, termination))
+    return (gam, rew_acc), (gam, rew_acc)
 
-  policy_loss = -jnp.mean(acc) / horizon
+  rew_acc = jnp.zeros_like(terminal_values)
+  gam = jnp.ones_like(terminal_values)
+  (gam, last_rew_acc), (gam_acc, rew_acc) = jax.lax.scan(sum_step, (gam, rew_acc),
+                                                         (rewards, values, termination))
 
-  # inspect the data for one of the rollouts
-  jax.debug.print('obs={o}, obs_next={n}, values={v}, reward={r}, truncation={t}, terminal={s}',
-                  v=values[:, 0], o=data.observation[:, 0], r=data.reward[:, 0],
-                  t=truncation[:, 0], s=termination[:, 0], n=data.next_observation[:, 0])
+  policy_loss = jnp.sum(-last_rew_acc - gam * terminal_values)
+  # for trials that are truncated (i.e. hit the episode length), include the reward for the
+  # terminal state; otherwise the trial was aborted and should receive zero additional reward
+  policy_loss = policy_loss + jnp.sum((-rew_acc - gam_acc * jnp.where(truncation, values_t_plus_1, 0)) * termination)
+  policy_loss = policy_loss / values.shape[0] / values.shape[1]
 
-  jax.debug.print('loss={l}, r={r}', l=policy_loss, r=temp[:, 0])
 
   # Entropy reward
   policy_logits = policy_apply(normalizer_params, policy_params,
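
The hunk above replaces the earlier debug-print implementation of the policy objective with a forward jax.lax.scan that accumulates discounted rewards per environment and bootstraps each rollout with the critic's value at the horizon, resetting the accumulator and the running discount whenever an episode terminates. Below is a minimal standalone sketch of that accumulation; it mirrors the tensors in the diff (rewards, values, termination, truncation as [T, B] arrays, terminal_values as [B]), but the function name and default discount are illustrative, not part of the commit.

import jax
import jax.numpy as jnp

def short_horizon_policy_loss(rewards, values, termination, truncation,
                              terminal_values, discounting=0.99):
  # [v_1, ..., v_{T+1}]: critic values shifted one step, bootstrapped at the horizon.
  values_t_plus_1 = jnp.concatenate(
      [values[1:], jnp.expand_dims(terminal_values, 0)], axis=0)

  def sum_step(carry, xs):
    gam, rew_acc = carry
    reward, termination_t = xs
    # Accumulate discounted reward; reset accumulator and discount for done envs.
    rew_acc = jnp.where(termination_t, 0.0, rew_acc + gam * reward)
    gam = jnp.where(termination_t, 1.0, gam * discounting)
    return (gam, rew_acc), (gam, rew_acc)

  gam = jnp.ones_like(terminal_values)
  rew_acc = jnp.zeros_like(terminal_values)
  (gam, last_rew_acc), (gam_acc, rew_acc) = jax.lax.scan(
      sum_step, (gam, rew_acc), (rewards, termination))

  # Surviving trajectories: accumulated reward plus discounted terminal value.
  loss = jnp.sum(-last_rew_acc - gam * terminal_values)
  # Trajectories that ended mid-rollout contribute their partial sums; truncated
  # (time-limit) steps also bootstrap with the next-state value.
  loss = loss + jnp.sum(
      (-rew_acc - gam_acc * jnp.where(truncation, values_t_plus_1, 0.0)) * termination)
  return loss / (values.shape[0] * values.shape[1])

Because the rollout is differentiated through the simulator, gradients reach the policy through both the reward terms and the bootstrapped values; the termination masking keeps finished environments from leaking discounted reward across episode boundaries.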
@@ -122,68 +125,6 @@ def sum_step(carry, target_t):
   }
 
 
-
-def compute_target_values(truncation: jnp.ndarray,
-                          termination: jnp.ndarray,
-                          rewards: jnp.ndarray,
-                          values: jnp.ndarray,
-                          bootstrap_value: jnp.ndarray,
-                          discount: float = 0.99,
-                          lambda_: float = 0.95,
-                          td_lambda=True):
-  """Calculates the target values.
-
-  This implements Eq. 7 of 2204.07137
-  https://github.com/NVlabs/DiffRL/blob/main/algorithms/shac.py#L349
-
-  Args:
-    truncation: A float32 tensor of shape [T, B] with truncation signal.
-    termination: A float32 tensor of shape [T, B] with termination signal.
-    rewards: A float32 tensor of shape [T, B] containing rewards generated by
-      following the behaviour policy.
-    values: A float32 tensor of shape [T, B] with the value function estimates
-      wrt. the target policy.
-    bootstrap_value: A float32 of shape [B] with the value function estimate at
-      time T.
-    discount: TD discount.
-
-  Returns:
-    A float32 tensor of shape [T, B].
-  """
-  truncation_mask = 1 - truncation
-  # Append bootstrapped value to get [v1, ..., v_t+1]
-  values_t_plus_1 = jnp.concatenate(
-      [values[1:], jnp.expand_dims(bootstrap_value, 0)], axis=0)
-
-  if td_lambda:
-
-    def compute_v_st(carry, target_t):
-      Ai, Bi, lam = carry
-      reward, truncation_mask, vtp1, termination = target_t
-      # TODO: should figure out how to handle termination
-
-      lam = lam * lambda_ * (1 - termination) + termination
-      Ai = (1 - termination) * (lam * discount * Ai + discount * vtp1 + (1. - lam) / (1. - lambda_) * reward)
-      Bi = discount * (vtp1 * termination + Bi * (1.0 - termination)) + reward
-      vs = (1.0 - lambda_) * Ai + lam * Bi
-
-      return (Ai, Bi, lam), (vs)
-
-    Ai = jnp.ones_like(bootstrap_value)
-    Bi = jnp.zeros_like(bootstrap_value)
-    lam = jnp.ones_like(bootstrap_value)
-
-    (_, _, _), (vs) = jax.lax.scan(compute_v_st, (Ai, Bi, lam),
-                                   (rewards, truncation_mask, values_t_plus_1, termination),
-                                   length=int(truncation_mask.shape[0]),
-                                   reverse=True)
-
-  else:
-    vs = rewards + discount * values_t_plus_1
-
-  return jax.lax.stop_gradient(vs)
-
-
 def compute_shac_critic_loss(
     params: Params,
     normalizer_params: Any,
@@ -192,9 +133,13 @@ def compute_shac_critic_loss(
     shac_network: shac_networks.SHACNetworks,
     discounting: float = 0.9,
     reward_scaling: float = 1.0,
-    lambda_: float = 0.95) -> Tuple[jnp.ndarray, types.Metrics]:
+    lambda_: float = 0.95,
+    td_lambda: bool = True) -> Tuple[jnp.ndarray, types.Metrics]:
   """Computes SHAC critic loss.
 
+  This implements Eq. 7 of 2204.07137
+  https://github.com/NVlabs/DiffRL/blob/main/algorithms/shac.py#L349
+
   Args:
     params: Value network parameters,
     normalizer_params: Parameters of the normalizer.
@@ -207,8 +152,7 @@ def compute_shac_critic_loss(
     discounting: discounting,
     reward_scaling: reward multiplier.
     lambda_: Lambda for TD value updates
-    clipping_epsilon: Policy loss clipping epsilon
-    normalize_advantage: whether to normalize advantage estimate
+    td_lambda: whether to use a TD-Lambda value target
 
   Returns:
     A tuple (loss, metrics)
@@ -218,25 +162,47 @@ def compute_shac_critic_loss(
 
   data = jax.tree_util.tree_map(lambda x: jnp.swapaxes(x, 0, 1), data)
 
-  baseline = value_apply(normalizer_params, params, data.observation)
-  bootstrap_value = value_apply(normalizer_params, params, data.next_observation[-1])
+  values = value_apply(normalizer_params, params, data.observation)
+  terminal_value = value_apply(normalizer_params, params, data.next_observation[-1])
 
   rewards = data.reward * reward_scaling
   truncation = data.extras['state_extras']['truncation']
   termination = (1 - data.discount) * (1 - truncation)
 
-  vs = compute_target_values(
-      truncation=truncation,
-      termination=termination,
-      rewards=rewards,
-      values=baseline,
-      bootstrap_value=bootstrap_value,
-      discount=discounting,
-      lambda_=lambda_)
+  # Append terminal values to get [v1, ..., v_t+1]
+  values_t_plus_1 = jnp.concatenate(
+      [values[1:], jnp.expand_dims(terminal_value, 0)], axis=0)
+
+  # compute target values
+  if td_lambda:
+
+    def compute_v_st(carry, target_t):
+      Ai, Bi, lam = carry
+      reward, vtp1, termination = target_t
+
+      reward = reward * termination
+
+      lam = lam * lambda_ * (1 - termination) + termination
+      Ai = (1 - termination) * (lam * discounting * Ai + discounting * vtp1 + (1. - lam) / (1. - lambda_) * reward)
+      Bi = discounting * (vtp1 * termination + Bi * (1.0 - termination)) + reward
+      vs = (1.0 - lambda_) * Ai + lam * Bi
+
+      return (Ai, Bi, lam), (vs)
+
+    Ai = jnp.ones_like(terminal_value)
+    Bi = jnp.zeros_like(terminal_value)
+    lam = jnp.ones_like(terminal_value)
+    (_, _, _), (vs) = jax.lax.scan(compute_v_st, (Ai, Bi, lam),
+                                   (rewards, values_t_plus_1, termination),
+                                   length=int(termination.shape[0]),
+                                   reverse=True)
+
+  else:
+    vs = rewards + discounting * values_t_plus_1
 
-  v_error = vs - baseline
-  v_loss = jnp.mean(v_error * v_error) * 0.5 * 0.5
+  target_values = jax.lax.stop_gradient(vs)
 
+  v_loss = jnp.mean((target_values - values) ** 2)
 
   total_loss = v_loss
   return total_loss, {
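
On the critic side, the commit folds the former compute_target_values helper into compute_shac_critic_loss: TD(λ) targets (Eq. 7 of arXiv:2204.07137) are built with a reverse jax.lax.scan and the value network is regressed onto them with a mean-squared error after jax.lax.stop_gradient. A condensed sketch of that recursion follows, under the same shape assumptions as above; the helper name td_lambda_targets is illustrative, and the diff's extra reward = reward * termination masking is omitted here.

import jax
import jax.numpy as jnp

def td_lambda_targets(rewards, values, termination, terminal_value,
                      discounting=0.99, lambda_=0.95):
  # [v_1, ..., v_{T+1}]: one-step-ahead values, bootstrapped with the terminal value.
  values_t_plus_1 = jnp.concatenate(
      [values[1:], jnp.expand_dims(terminal_value, 0)], axis=0)

  def compute_v_st(carry, xs):
    Ai, Bi, lam = carry
    reward, vtp1, termination_t = xs
    # Running lambda weight; reset to 1 whenever an episode terminates.
    lam = lam * lambda_ * (1 - termination_t) + termination_t
    Ai = (1 - termination_t) * (lam * discounting * Ai + discounting * vtp1 +
                                (1. - lam) / (1. - lambda_) * reward)
    Bi = discounting * (vtp1 * termination_t + Bi * (1.0 - termination_t)) + reward
    vs = (1.0 - lambda_) * Ai + lam * Bi
    return (Ai, Bi, lam), vs

  init = (jnp.ones_like(terminal_value),   # Ai
          jnp.zeros_like(terminal_value),  # Bi
          jnp.ones_like(terminal_value))   # lam
  _, vs = jax.lax.scan(compute_v_st, init,
                       (rewards, values_t_plus_1, termination),
                       reverse=True)
  # Targets are treated as constants when regressing the critic.
  return jax.lax.stop_gradient(vs)

The critic loss is then jnp.mean((targets - values) ** 2); as lambda_ goes to zero the recursion collapses to the one-step target rewards + discounting * values_t_plus_1, matching the td_lambda=False branch in the diff.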

brax/training/agents/shac/train.py

Lines changed: 10 additions & 2 deletions
@@ -53,6 +53,7 @@ class TrainingState:
   policy_params: Params
   value_optimizer_state: optax.OptState
   value_params: Params
+  target_value_params: Params
   normalizer_params: running_statistics.RunningStatisticsState
   env_steps: jnp.ndarray
 
@@ -80,6 +81,7 @@ def train(environment: envs.Env,
           num_evals: int = 1,
           normalize_observations: bool = False,
           reward_scaling: float = 1.,
+          tau: float = 0.005,  # this is 1 - alpha from the original paper
           lambda_: float = .95,
           deterministic_eval: bool = False,
           network_factory: types.NetworkFactory[
@@ -222,7 +224,7 @@ def training_step(
   key_sgd, key_generate_unroll, new_key = jax.random.split(key, 3)
 
   (policy_loss, (state, data, policy_metrics)), policy_params, policy_optimizer_state = policy_gradient_update_fn(
-      training_state.policy_params, training_state.value_params,
+      training_state.policy_params, training_state.target_value_params,
       training_state.normalizer_params, state, key_generate_unroll,
       optimizer_state=training_state.policy_optimizer_state)
 
@@ -238,13 +240,18 @@
       (training_state.value_optimizer_state, training_state.value_params, key_sgd), (),
       length=num_updates_per_batch)
 
+  target_value_params = jax.tree_util.tree_map(
+      lambda x, y: x * (1 - tau) + y * tau, training_state.target_value_params,
+      value_params)
+
   metrics.update(policy_metrics)
 
   new_training_state = TrainingState(
       policy_optimizer_state=policy_optimizer_state,
      policy_params=policy_params,
      value_optimizer_state=value_optimizer_state,
      value_params=value_params,
+     target_value_params=target_value_params,
      normalizer_params=training_state.normalizer_params,
      env_steps=training_state.env_steps + env_step_per_training_step)
   return (new_training_state, state, new_key), metrics
@@ -298,6 +305,7 @@ def training_epoch_with_timing(
      policy_params=policy_init_params,
      value_optimizer_state=value_optimizer.init(value_init_params),
      value_params=value_init_params,
+     target_value_params=value_init_params,
      normalizer_params=running_statistics.init_state(
          specs.Array((env.observation_size,), jnp.float32)),
      env_steps=0)
@@ -329,7 +337,7 @@
   if process_id == 0 and num_evals > 1:
     metrics = evaluator.run_evaluation(
         _unpmap(
-            (training_state.normalizer_params, training_state.params.policy)),
+            (training_state.normalizer_params, training_state.policy_params)),
         training_metrics={})
     logging.info(metrics)
     progress_fn(0, metrics)
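
train.py now carries a second copy of the critic parameters (target_value_params), initializes it to the critic's initial weights, nudges it toward the trained critic by a Polyak step of size tau after each round of value updates, and feeds it to the policy gradient instead of the freshly updated critic. A minimal sketch of that soft update on an arbitrary parameter pytree; the helper name soft_update and the toy dictionaries are illustrative, not part of the commit.

import jax
import jax.numpy as jnp

def soft_update(target_params, online_params, tau=0.005):
  # Move every leaf of the target pytree a fraction tau toward the online params.
  return jax.tree_util.tree_map(
      lambda t, o: t * (1 - tau) + o * tau, target_params, online_params)

# Toy example: with tau=0.5 the target moves halfway toward the online params.
target = {'w': jnp.zeros(3)}
online = {'w': jnp.ones(3)}
target = soft_update(target, online, tau=0.5)  # {'w': [0.5, 0.5, 0.5]}

optax.incremental_update applies the same exponential-moving-average rule if a library helper is preferred; either way, the policy loss bootstraps against these slower-moving parameters, which is the target network the commit title refers to.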
