@@ -83,29 +83,32 @@ def compute_shac_policy_loss(
   truncation = data.extras['state_extras']['truncation']
   termination = (1 - data.discount) * (1 - truncation)
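+  # data.discount is 0 at terminal steps and truncation flags episode-length
+  # cutoffs, so termination is 1 only for true (non-truncated) terminations.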
 
-  horizon = rewards.shape[0]
+  # Append terminal values to get [v1, ..., v_t+1]
+  values_t_plus_1 = jnp.concatenate(
+      [values[1:], jnp.expand_dims(terminal_values, 0)], axis=0)
 
+  # jax implementation of https://github.com/NVlabs/DiffRL/blob/a4c0dd1696d3c3b885ce85a3cb64370b580cb913/algorithms/shac.py#L227
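+  # The scan carries, per environment, the running discount gam and the
+  # discounted reward sum rew_acc accumulated since the last episode boundary.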
   def sum_step(carry, target_t):
-    gam, acc = carry
-    reward, v, truncation, termination = target_t
-    acc = acc + jnp.where(truncation + termination, gam * v, gam * reward)
+    gam, rew_acc = carry
+    reward, v, termination = target_t
+
+    # Reset gam and rew_acc for done envs; otherwise update them.
+    rew_acc = jnp.where(termination, 0, rew_acc + gam * reward)
     gam = jnp.where(termination, 1.0, gam * discounting)
-    return (gam, acc), (acc)
 
-  acc = terminal_values * (discounting ** horizon) * (1 - termination[-1]) * (1 - truncation[-1])
-  jax.debug.print('acc shape: {x}', x=acc.shape)
-  gam = jnp.ones_like(terminal_values)
-  (_, acc), (temp) = jax.lax.scan(sum_step, (gam, acc),
-                                  (rewards, values, truncation, termination))
+    return (gam, rew_acc), (gam, rew_acc)
 
-  policy_loss = -jnp.mean(acc) / horizon
+  rew_acc = jnp.zeros_like(terminal_values)
+  gam = jnp.ones_like(terminal_values)
+  (gam, last_rew_acc), (gam_acc, rew_acc) = jax.lax.scan(
+      sum_step, (gam, rew_acc), (rewards, values, termination))
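+  # The scan also stacks (gam_acc, rew_acc) per step: the running discount and
+  # reward sum at every step, used below for episodes that end mid-rollout.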
 
-  # inspect the data for one of the rollouts
-  jax.debug.print('obs={o}, obs_next={n}, values={v}, reward={r}, truncation={t}, terminal={s}',
-                  v=values[:, 0], o=data.observation[:, 0], r=data.reward[:, 0],
-                  t=truncation[:, 0], s=termination[:, 0], n=data.next_observation[:, 0])
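+  # The policy loss is the negative short-horizon return, averaged over
+  # horizon and batch: accumulated rewards plus the discounted terminal value,
+  # with episodes that end mid-rollout handled separately below.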
+  policy_loss = jnp.sum(-last_rew_acc - gam * terminal_values)
+  # For trials that are truncated (i.e. hit the episode length), include the
+  # value of the terminal state; otherwise the trial was cut short and should
+  # receive no additional reward.
+  policy_loss = policy_loss + jnp.sum(
+      (-rew_acc - gam_acc * jnp.where(truncation, values_t_plus_1, 0)) * termination)
+  policy_loss = policy_loss / values.shape[0] / values.shape[1]
 
-  jax.debug.print('loss={l}, r={r}', l=policy_loss, r=temp[:, 0])
 
   # Entropy reward
   policy_logits = policy_apply(normalizer_params, policy_params,
@@ -122,68 +125,6 @@ def sum_step(carry, target_t):
   }
 
 
-
-def compute_target_values(truncation: jnp.ndarray,
-                          termination: jnp.ndarray,
-                          rewards: jnp.ndarray,
-                          values: jnp.ndarray,
-                          bootstrap_value: jnp.ndarray,
-                          discount: float = 0.99,
-                          lambda_: float = 0.95,
-                          td_lambda=True):
-  """Calculates the target values.
-
-  This implements Eq. 7 of 2204.07137
-  https://github.com/NVlabs/DiffRL/blob/main/algorithms/shac.py#L349
-
-  Args:
-    truncation: A float32 tensor of shape [T, B] with truncation signal.
-    termination: A float32 tensor of shape [T, B] with termination signal.
-    rewards: A float32 tensor of shape [T, B] containing rewards generated by
-      following the behaviour policy.
-    values: A float32 tensor of shape [T, B] with the value function estimates
-      wrt. the target policy.
-    bootstrap_value: A float32 of shape [B] with the value function estimate at
-      time T.
-    discount: TD discount.
-
-  Returns:
-    A float32 tensor of shape [T, B].
-  """
-  truncation_mask = 1 - truncation
-  # Append bootstrapped value to get [v1, ..., v_t+1]
-  values_t_plus_1 = jnp.concatenate(
-      [values[1:], jnp.expand_dims(bootstrap_value, 0)], axis=0)
-
-  if td_lambda:
-
-    def compute_v_st(carry, target_t):
-      Ai, Bi, lam = carry
-      reward, truncation_mask, vtp1, termination = target_t
-      # TODO: should figure out how to handle termination
-
-      lam = lam * lambda_ * (1 - termination) + termination
-      Ai = (1 - termination) * (lam * discount * Ai + discount * vtp1 +
-                                (1. - lam) / (1. - lambda_) * reward)
-      Bi = discount * (vtp1 * termination + Bi * (1.0 - termination)) + reward
-      vs = (1.0 - lambda_) * Ai + lam * Bi
-
-      return (Ai, Bi, lam), (vs)
-
-    Ai = jnp.ones_like(bootstrap_value)
-    Bi = jnp.zeros_like(bootstrap_value)
-    lam = jnp.ones_like(bootstrap_value)
-
-    (_, _, _), (vs) = jax.lax.scan(compute_v_st, (Ai, Bi, lam),
-                                   (rewards, truncation_mask, values_t_plus_1,
-                                    termination),
-                                   length=int(truncation_mask.shape[0]),
-                                   reverse=True)
-
-  else:
-    vs = rewards + discount * values_t_plus_1
-
-  return jax.lax.stop_gradient(vs)
-
-
 def compute_shac_critic_loss(
     params: Params,
     normalizer_params: Any,
@@ -192,9 +133,13 @@ def compute_shac_critic_loss(
     shac_network: shac_networks.SHACNetworks,
     discounting: float = 0.9,
     reward_scaling: float = 1.0,
-    lambda_: float = 0.95) -> Tuple[jnp.ndarray, types.Metrics]:
+    lambda_: float = 0.95,
+    td_lambda: bool = True) -> Tuple[jnp.ndarray, types.Metrics]:
   """Computes SHAC critic loss.
 
+  This implements Eq. 7 of 2204.07137
+  https://github.com/NVlabs/DiffRL/blob/main/algorithms/shac.py#L349
+
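+  As a sketch, assuming the standard TD-lambda target (notation adapted):
+
+    G_t^(k)    = sum_{l=0}^{k-1} gamma^l r_{t+l} + gamma^k V(s_{t+k})
+    V_tgt(s_t) = (1 - lambda) * sum_{k>=1} lambda^(k-1) G_t^(k)
+
+  with the sum truncated at the horizon and the final k-step return absorbing
+  the remaining lambda mass.
+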
   Args:
     params: Value network parameters,
     normalizer_params: Parameters of the normalizer.
@@ -207,8 +152,7 @@ def compute_shac_critic_loss(
     discounting: discounting,
     reward_scaling: reward multiplier.
     lambda_: Lambda for TD value updates
-    clipping_epsilon: Policy loss clipping epsilon
-    normalize_advantage: whether to normalize advantage estimate
+    td_lambda: whether to use a TD-Lambda value target
 
   Returns:
     A tuple (loss, metrics)
@@ -218,25 +162,47 @@ def compute_shac_critic_loss(
 
   data = jax.tree_util.tree_map(lambda x: jnp.swapaxes(x, 0, 1), data)
 
-  baseline = value_apply(normalizer_params, params, data.observation)
-  bootstrap_value = value_apply(normalizer_params, params, data.next_observation[-1])
+  values = value_apply(normalizer_params, params, data.observation)
+  terminal_value = value_apply(normalizer_params, params, data.next_observation[-1])
 
   rewards = data.reward * reward_scaling
   truncation = data.extras['state_extras']['truncation']
   termination = (1 - data.discount) * (1 - truncation)
 
-  vs = compute_target_values(
-      truncation=truncation,
-      termination=termination,
-      rewards=rewards,
-      values=baseline,
-      bootstrap_value=bootstrap_value,
-      discount=discounting,
-      lambda_=lambda_)
+  # Append terminal values to get [v1, ..., v_t+1]
+  values_t_plus_1 = jnp.concatenate(
+      [values[1:], jnp.expand_dims(terminal_value, 0)], axis=0)
+
+  # compute target values
+  if td_lambda:
+
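+    # Reverse-scan form of the TD-lambda target, following the DiffRL
+    # implementation linked above: Bi accumulates the discounted bootstrapped
+    # return, Ai the lambda-weighted correction, and vs mixes the two.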
+    def compute_v_st(carry, target_t):
+      Ai, Bi, lam = carry
+      reward, vtp1, termination = target_t
+
+      reward = reward * termination
+
+      lam = lam * lambda_ * (1 - termination) + termination
+      Ai = (1 - termination) * (lam * discounting * Ai + discounting * vtp1 +
+                                (1. - lam) / (1. - lambda_) * reward)
+      Bi = discounting * (vtp1 * termination + Bi * (1.0 - termination)) + reward
+      vs = (1.0 - lambda_) * Ai + lam * Bi
+
+      return (Ai, Bi, lam), (vs)
+
+    Ai = jnp.ones_like(terminal_value)
+    Bi = jnp.zeros_like(terminal_value)
+    lam = jnp.ones_like(terminal_value)
+    (_, _, _), (vs) = jax.lax.scan(compute_v_st, (Ai, Bi, lam),
+                                   (rewards, values_t_plus_1, termination),
+                                   length=int(termination.shape[0]),
+                                   reverse=True)
+
+  else:
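+    # One-step TD target: r_t + gamma * V(s_{t+1}).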
+    vs = rewards + discounting * values_t_plus_1
 
-  v_error = vs - baseline
-  v_loss = jnp.mean(v_error * v_error) * 0.5 * 0.5
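+  # Regress the critic on fixed targets; stop_gradient keeps the target out
+  # of the differentiation path.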
+  target_values = jax.lax.stop_gradient(vs)
 
+  v_loss = jnp.mean((target_values - values) ** 2)
 
   total_loss = v_loss
   return total_loss, {