Description
Hi there! I've been having trouble learning on the continuous-action variant of SMAX (I'm exploring continuous-action MARL problems), and a deeper dive into the environment code led me to a few sections that make me think the current setup of the SMAX continuous action space has issues that might prevent anyone from learning an optimal policy.
In the SMAX class's `_decode_continuous_actions` function, there's this chunk of code (~lines 554-567):

```python
shoot_last_idx = self.continuous_action_dims.index("shoot_last_enemy")
action_idx = self.continuous_action_dims.index("do_shoot")
theta_idx = self.continuous_action_dims.index("coordinate_2")
r_idx = self.continuous_action_dims.index("coordinate_1")
shoot_last_enemy_logits = jnp.array(
    [
        jnp.log(actions[:, shoot_last_idx]),
        jnp.log(1 - actions[:, shoot_last_idx]),
    ]
)
logits = jnp.array(
    [jnp.log(actions[:, action_idx]), jnp.log(1 - actions[:, action_idx])]
)
move_or_shoot_key, shoot_last_enemy_key = jax.random.split(key)
move_or_shoot = jax.random.categorical(move_or_shoot_key, logits, axis=0)
shoot_last_enemy = jax.random.categorical(
    shoot_last_enemy_key, shoot_last_enemy_logits, axis=0
)
```
As far as I understand, this means that even though continuous-action SMAX is defined with 4 continuous action dimensions, the first two, `do_shoot` and `shoot_last_enemy`, are actually converted to discrete actions under the hood: the environment treats the continuous outputs directly as probabilities and turns them into logits. This has real implications, e.g. when you're using MAPPO and calculating log probs/entropies for policy updates, because the environment's action space is described as continuous for all 4 dimensions, while in reality the first two are treated as discrete actions.
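To make the mismatch concrete, here's a hypothetical comparison (function names and numbers are my own, not repo API): a Gaussian policy head reports the log-density of its continuous output, while the quantity that actually matters for the environment's internal draw is a Bernoulli log-probability:

```python
import math

def gaussian_log_prob(x, mu, sigma):
    # Log-density a Gaussian policy head would report for its continuous output.
    return -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma * math.sqrt(2.0 * math.pi))

def bernoulli_log_prob(p, shot_fired):
    # Log-probability of the discrete event the environment actually samples.
    return math.log(p) if shot_fired else math.log(1.0 - p)

# Policy emits do_shoot = 0.9; the env then flips Bernoulli(0.9) internally.
policy_lp = gaussian_log_prob(0.9, mu=0.9, sigma=0.1)  # what a Gaussian head reports
env_lp = bernoulli_log_prob(0.9, shot_fired=True)      # what the env's draw implies
```

The two numbers have nothing to do with each other, so any PPO-style ratio built from the Gaussian term is computed against the wrong distribution for these two dimensions.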
Beyond the issue that one might accidentally use a Gaussian for all 4 actions, and thus compute the wrong log probs/entropies, there is also the issue that the environment constructs its own Categorical distributions and does the sampling internally. This means my policy updates have no way of knowing which discrete action was actually selected, so I can't compute correct probabilities or entropies. I think this means that with, say, MAPPO, I'm always stuck using incorrect log probs/entropies during my updates. Could anyone validate this for me or double-check?
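For reference, if the environment either exposed its internal draws or let the policy do the sampling, a correct factorised log-prob for this action layout would mix Bernoulli terms for the first two dimensions with Gaussian terms for the coordinates. A hypothetical sketch (`joint_log_prob` and its arguments are my own names, not repo API):

```python
import math

def joint_log_prob(probs, outcomes, coords, mu, sigma):
    """Log-prob for a 2-Bernoulli + 2-Gaussian action layout.

    probs:    the two continuous outputs the env treats as Bernoulli params
              (shoot_last_enemy, do_shoot)
    outcomes: the two discrete draws the env makes internally -- which,
              as noted above, it currently does not return to the agent
    coords:   the sampled continuous coordinates (coordinate_1, coordinate_2)
    mu/sigma: the Gaussian policy head's parameters for those coordinates
    """
    lp = 0.0
    for p, o in zip(probs, outcomes):
        # Bernoulli dims: log-prob of the realised discrete event.
        lp += math.log(p) if o else math.log(1.0 - p)
    for x, m, s in zip(coords, mu, sigma):
        # Gaussian dims: log-density of the sampled coordinate.
        lp += -0.5 * ((x - m) / s) ** 2 - math.log(s * math.sqrt(2.0 * math.pi))
    return lp
```

Without `outcomes` being surfaced by the environment, the Bernoulli terms above can't be computed on the learner side at all.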