@@ -89,7 +89,7 @@ def compute_shac_policy_loss(
8989 # jax implementation of https://github.com/NVlabs/DiffRL/blob/a4c0dd1696d3c3b885ce85a3cb64370b580cb913/algorithms/shac.py#L227
9090 def sum_step (carry , target_t ):
9191 gam , rew_acc = carry
92- reward , v , termination = target_t
92+ reward , termination = target_t
9393
9494 # clean up gamma and rew_acc for done envs, otherwise update
9595 rew_acc = jnp .where (termination , 0 , rew_acc + gam * reward )
@@ -100,7 +100,7 @@ def sum_step(carry, target_t):
100100 rew_acc = jnp .zeros_like (terminal_values )
101101 gam = jnp .ones_like (terminal_values )
102102 (gam , last_rew_acc ), (gam_acc , rew_acc ) = jax .lax .scan (sum_step , (gam , rew_acc ),
103- (rewards , values , termination ))
103+ (rewards , termination ))
104104
105105 policy_loss = jnp .sum (- last_rew_acc - gam * terminal_values )
106106 # for trials that are truncated (i.e. hit the episode length) include reward for
@@ -118,7 +118,6 @@ def sum_step(carry, target_t):
118118 total_loss = policy_loss + entropy_loss
119119
120120 return total_loss , {
121- 'total_loss' : total_loss ,
122121 'policy_loss' : policy_loss ,
123122 'entropy_loss' : entropy_loss
124123 }
0 commit comments