[Bug] An error in MaskPPO training #81

@Yangxiaojun1230

Description

System Info
Describe the characteristic of your environment:

Describe how the library was installed: pip
sb3-contrib=='1.5.1a9'
Python: 3.8.13
Stable-Baselines3: 1.5.1a9
PyTorch: 1.11.0+cu102
GPU Enabled: False
Numpy: 1.22.3
Gym: 0.21.0

My training code is below:
model = MaskablePPO("MultiInputPolicy", env, gamma=0.4, seed=32, verbose=0)
model.learn(300000)
My action space is spaces.Discrete(). The problem seems to be in the torch distribution's __init__(): the input logits contain invalid values. The error does not occur at a fixed training step; it shows up at an unpredictable point during training.
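To narrow down whether the logits really contain invalid values, one option is to scan a batch for NaN or infinite entries before it reaches the distribution. The following is only a minimal sketch: `find_non_finite` is a hypothetical helper, not part of sb3-contrib or PyTorch, and it assumes the logits have already been converted to nested Python lists (for a tensor, `logits.tolist()` would do that).

```python
import math
from typing import List, Sequence, Tuple


def find_non_finite(logits: Sequence[Sequence[float]]) -> List[Tuple[int, int, float]]:
    """Return (row, column, value) for every NaN or infinite entry.

    Hypothetical debugging helper; `logits` is a batch given as nested lists.
    """
    bad = []
    for i, row in enumerate(logits):
        for j, value in enumerate(row):
            # math.isfinite is False for NaN, +inf and -inf
            if not math.isfinite(value):
                bad.append((i, j, value))
    return bad


if __name__ == "__main__":
    # Made-up batch for illustration: the second row has two bad entries.
    batch = [[0.1, -2.0, 3.5], [float("nan"), 0.0, float("inf")]]
    for row, col, value in find_non_finite(batch):
        print(f"row {row}, column {col}: {value}")
```

An empty result would suggest the logits themselves are finite and that the failure comes from somewhere else, for example from the probabilities derived from them.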

File ~/anaconda3/envs/stable_base/lib/python3.8/site-packages/sb3_contrib/ppo_mask/ppo_mask.py:579, in MaskablePPO.learn(self, total_timesteps, callback, log_interval, eval_env, eval_freq, n_eval_episodes, tb_log_name, eval_log_path, reset_num_timesteps, use_masking)
576 self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
577 self.logger.dump(step=self.num_timesteps)
--> 579 self.train()
581 callback.on_training_end()
583 return self

File ~/anaconda3/envs/stable_base/lib/python3.8/site-packages/sb3_contrib/ppo_mask/ppo_mask.py:439, in MaskablePPO.train(self)
435 if isinstance(self.action_space, spaces.Discrete):
436 # Convert discrete action from float to long
437 actions = rollout_data.actions.long().flatten()
--> 439 values, log_prob, entropy = self.policy.evaluate_actions(
440 rollout_data.observations,
441 actions,
442 action_masks=rollout_data.action_masks,
443 )
445 values = values.flatten()
446 # Normalize advantage

File ~/anaconda3/envs/stable_base/lib/python3.8/site-packages/sb3_contrib/common/maskable/policies.py:280, in MaskableActorCriticPolicy.evaluate_actions(self, obs, actions, action_masks)
278 distribution = self._get_action_dist_from_latent(latent_pi)
279 if action_masks is not None:
--> 280 distribution.apply_masking(action_masks)
281 log_prob = distribution.log_prob(actions)
282 values = self.value_net(latent_vf)

File ~/anaconda3/envs/stable_base/lib/python3.8/site-packages/sb3_contrib/common/maskable/distributions.py:152, in MaskableCategoricalDistribution.apply_masking(self, masks)
150 def apply_masking(self, masks: Optional[np.ndarray]) -> None:
151 assert self.distribution is not None, "Must set distribution parameters"
--> 152 self.distribution.apply_masking(masks)

File ~/anaconda3/envs/stable_base/lib/python3.8/site-packages/sb3_contrib/common/maskable/distributions.py:62, in MaskableCategorical.apply_masking(self, masks)
59 logits = self._original_logits
61 # Reinitialize with updated logits
---> 62 super().__init__(logits=logits)
64 # self.probs may already be cached, so we must force an update
65 self.probs = logits_to_probs(self.logits)

File ~/anaconda3/envs/stable_base/lib/python3.8/site-packages/torch/distributions/categorical.py:64, in Categorical.__init__(self, probs, logits, validate_args)
62 self._num_events = self._param.size()[-1]
63 batch_shape = self._param.size()[:-1] if self._param.ndimension() > 1 else torch.Size()
---> 64 super(Categorical, self).__init__(batch_shape, validate_args=validate_args)

File ~/anaconda3/envs/stable_base/lib/python3.8/site-packages/torch/distributions/distribution.py:55, in Distribution.__init__(self, batch_shape, event_shape, validate_args)
53 valid = constraint.check(value)
54 if not valid.all():
---> 55 raise ValueError(
56 f"Expected parameter {param} "
57 f"({type(value).__name__} of shape {tuple(value.shape)}) "
58 f"of distribution {repr(self)} "
59 f"to satisfy the constraint {repr(constraint)}, "
60 f"but found invalid values:\n{value}"
61 )
62 super(Distribution, self).__init__()

ValueError: Expected parameter probs (Tensor of shape (64, 400)) of distribution MaskableCategorical(probs: torch.Size([64, 400]), logits: torch.Size([64, 400])) to satisfy the constraint Simplex(), but found invalid values:
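For reference, the Simplex() constraint named in the error requires each row of probs to be non-negative and to sum to 1 within a small tolerance. The sketch below re-implements that rule in plain Python so the offending rows can be listed with a reason. It is an approximation under stated assumptions: `simplex_violations` is a hypothetical helper, the tolerance of 1e-6 is assumed to match what PyTorch uses, and the sum here is computed in double precision, so a row that fails in float32 inside PyTorch may be reported differently.

```python
import math
from typing import List, Sequence

# Assumption: same tolerance as PyTorch's Simplex constraint.
TOLERANCE = 1e-6


def simplex_violations(probs: Sequence[Sequence[float]], tol: float = TOLERANCE) -> List[str]:
    """Describe, row by row, why a batch of probabilities is not on the simplex.

    Hypothetical debugging helper; returns an empty list if every row passes.
    """
    problems = []
    for i, row in enumerate(probs):
        if any(math.isnan(p) for p in row):
            problems.append(f"row {i}: contains NaN")
        elif any(p < 0 for p in row):
            problems.append(f"row {i}: contains a negative entry")
        elif not abs(math.fsum(row) - 1.0) < tol:
            # Also catches rows with infinite entries, whose sum is not finite.
            problems.append(f"row {i}: sums to {math.fsum(row)!r}")
    return problems


if __name__ == "__main__":
    # Made-up rows for illustration, not values from the failing run.
    batch = [[0.5, 0.5], [float("nan"), 1.0], [-0.1, 1.1], [0.5, 0.4]]
    for line in simplex_violations(batch):
        print(line)
```

With a tensor, something like `simplex_violations(probs.tolist())` could be called on the (64, 400) batch from the error message to see whether the failing rows contain NaN or merely miss the sum tolerance.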

Labels: help wanted (help from contributors is needed)
