Conversation
quentinll left a comment
Thank you for this nice contribution! Do you already have results using TDMPC2 in some environments?
stable_worldmodel/wm/tdmpc2.py
Outdated
    return G + discount * (1 - termination) * conservative_q

@torch.no_grad()
def _plan(self, obs_dict, goal_dict, step_idxs, eval_mode=False):
It would be nice to use the MPPI algorithm already implemented in stable-worldmodel. I believe, one way would be to use the policy API defined here https://github.com/galilai-group/stable-worldmodel/blob/main/stable_worldmodel/policy.py
I tested it on pushT and managed to get about 85% success rate predicting 50 steps ahead of a random state in a trajectory 100 times.
About the solver: td-mpc2 uses some trajectories generated directly by the Actor network, rather than by pure Gaussian noise. From their paper:
To accelerate convergence of planning, a fraction of action sequences originate from the policy prior π, and we warm-start planning by initializing (μ, σ) as the solution to the previous decision step shifted by 1.
From what I understand, your standard MPPI does not update the variance of the distribution, only the mean, and all of the trajectories are randomly generated. I tested with the implemented MPPI, and for me at least it hallucinated much more, producing OOD actions.
I used your FeedForwardPolicy to work with it.
I can change it if you want; I just tried to stay as close as possible to their implementation.
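To make the hybrid strategy from the paper concrete, here is a rough sketch of MPPI with a policy-prior fraction and warm start. All names (`dynamics`, `reward_fn`, `actor`, `mppi_with_prior`) and hyperparameters are illustrative placeholders, not the repo's or TD-MPC2's actual API:

```python
import torch

def mppi_with_prior(dynamics, reward_fn, actor, z0, prev_mean,
                    horizon=12, n_samples=512, n_pi=24, n_iters=6,
                    temperature=0.5, action_dim=2):
    """One planning step mixing Gaussian samples with actor rollouts.

    `dynamics(z, a)` and `reward_fn(z, a)` are assumed batched; `actor(z)`
    returns an action in [-1, 1]. This is a sketch, not TD-MPC2's code.
    """
    # Warm start: shift the previous solution by one step.
    mean = torch.cat([prev_mean[1:], torch.zeros(1, action_dim)])
    std = torch.full((horizon, action_dim), 1.0)

    # A fraction of candidate trajectories come from the policy prior.
    pi_actions = torch.empty(horizon, n_pi, action_dim)
    z = z0.expand(n_pi, -1)
    for t in range(horizon):
        pi_actions[t] = actor(z)
        z = dynamics(z, pi_actions[t])

    for _ in range(n_iters):
        noise = torch.randn(horizon, n_samples, action_dim)
        actions = (mean.unsqueeze(1) + std.unsqueeze(1) * noise).clamp(-1, 1)
        actions = torch.cat([actions, pi_actions], dim=1)  # mix in the prior

        # Evaluate the return of every candidate action sequence.
        z = z0.expand(actions.shape[1], -1)
        G = torch.zeros(actions.shape[1])
        for t in range(horizon):
            G = G + reward_fn(z, actions[t])
            z = dynamics(z, actions[t])

        # Update both the mean AND the std, unlike a fixed-variance MPPI.
        w = torch.softmax(G / temperature, dim=0)
        mean = (w[None, :, None] * actions).sum(dim=1)
        std = ((w[None, :, None] * (actions - mean.unsqueeze(1)) ** 2)
               .sum(dim=1)).sqrt().clamp_min(0.05)
    return mean, std
```

The key differences from a plain MPPI are visible in the last two lines of the loop (the variance is refit from the elite-weighted samples) and in the `torch.cat` that injects actor rollouts into the candidate set.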
Our MPPI implementation already supports starting from an initial guess, so it is possible to warm-start it using the output of an actor network. For now, FeedForwardPolicy does not support hybrid strategies like TD-MPC. I believe we should add a new policy class that allows running MPC starting from an actor's guess.
Ok, I created a TDMPC policy that uses the actor for the warm start. I tested on pushT again and it's working for me.
Do you need something else for this model? @quentinll
Hey Luiz, thanks again for this very nice contribution. This seems close to what I had in mind in terms of implementation; I will have a closer look ASAP. Which script are you using for evaluation? In the meantime, it would be nice if you could do additional evaluations on other environments such as tworoom and OGBench cube. What do you think?
Sure, I will download the data and test on tworoom and cube. I will let you know when I have the results.
I did quite a few changes on how the policy uses the actor to warm start the solver. I removed the TDMPCPolicy and extended the WorldModelPolicy to support this. Could you let me know if you are still getting the same results? In theory the behavior should be identical.
Assumptions:
- Continuous Control: The algorithm assumes continuous action spaces.
- Action Bounds: Actions are strictly assumed to be normalized to the range [-1.0, 1.0].
  The actor network and MPPI planner enforce this bound via Tanh and clamping.
Is this still accurate?
Yes, tdmpc2 uses a squashed Gaussian (tanh at the output), so actions are bounded to (-1, 1). Normalizing to this range is required for the two-hot value distribution and symlog reward scaling to be correct.
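For readers unfamiliar with the squashed Gaussian, a minimal sketch of sampling from one looks roughly like this (illustrative only, not the PR's actual actor code; the clamp range on `log_std` is an assumption):

```python
import math
import torch

def squashed_gaussian_sample(mu, log_std):
    """Sample an action from a tanh-squashed Gaussian.

    The tanh keeps every sampled action strictly inside (-1, 1), which
    is what lets downstream components assume normalized actions.
    """
    log_std = log_std.clamp(-10, 2)       # keep std in a sane range (assumed bounds)
    eps = torch.randn_like(mu)
    pre_tanh = mu + eps * log_std.exp()   # reparameterized Gaussian sample
    action = torch.tanh(pre_tanh)         # squash into (-1, 1)
    # Gaussian log-density plus the tanh change-of-variables correction.
    log_prob = (-0.5 * eps.pow(2) - log_std - 0.5 * math.log(2 * math.pi)).sum(-1)
    log_prob = log_prob - torch.log(1 - action.pow(2) + 1e-6).sum(-1)
    return action, log_prob
```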
if self.use_pixels:
    self.cnn = nn.Sequential(
        nn.Conv2d(6, 32, 7, stride=2),
        nn.Mish(),
        nn.Conv2d(32, 32, 5, stride=2),
        nn.Mish(),
        nn.Conv2d(32, 32, 3, stride=2),
        nn.Mish(),
        nn.Conv2d(32, 32, 3, stride=1),
        nn.Mish(),
        nn.Flatten(),
    )
We usually use 224x224 images; would that be an issue?
It's fine, the CNN works regardless of image size, and the output dimension is configured accordingly. I tested different dimensions too.
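One common way to make the head dimension follow the image size is to infer it with a single dummy forward pass; `cnn_output_dim` below is a hypothetical helper for illustration, not necessarily what the PR does:

```python
import torch
from torch import nn

def cnn_output_dim(cnn: nn.Module, image_hw: tuple, in_ch: int = 6) -> int:
    """Infer the flattened feature size for an arbitrary input resolution.

    Runs one dummy forward pass, so the same conv stack can feed an MLP
    head at 64x64 or 224x224 without hand-computed shapes.
    """
    with torch.no_grad():
        dummy = torch.zeros(1, in_ch, *image_hw)
        return cnn(dummy).shape[-1]
```

For the conv stack quoted above, this gives 512 features at 64x64 and 18432 at 224x224, so only the first linear layer of the head needs to change with resolution.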
stable_worldmodel/wm/tdmpc2.py
Outdated
    continue  # Handled by primary backbone

in_dim = cfg.extra_dims[key] * 2
self.extra_encoders[key] = nn.Sequential(
why should extra encoders necessarily instantiate these networks?
You're right, I changed it.
self.reward = mlp(
    self.latent_dim + cfg.action_dim, cfg.wm.mlp_dim, cfg.wm.num_bins
)
self.pi = mlp(self.latent_dim, cfg.wm.mlp_dim, 2 * cfg.action_dim)
can you add some comments to help people understand what pi is, etc.? (I know it's the policy, but it might be confusing for newcomers)
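As a sketch of the kind of comments being requested, the two heads could be annotated along these lines (the `mlp` stand-in and the dimensions here are illustrative, not the repo's actual helper or config values):

```python
from torch import nn

def mlp(in_dim, hidden, out_dim):
    # Minimal stand-in for the repo's mlp() helper (illustrative).
    return nn.Sequential(nn.Linear(in_dim, hidden), nn.Mish(),
                         nn.Linear(hidden, out_dim))

latent_dim, action_dim, mlp_dim, num_bins = 64, 2, 128, 101

# reward: maps a (latent, action) pair to a discrete (two-hot)
# distribution over `num_bins` reward bins.
reward = mlp(latent_dim + action_dim, mlp_dim, num_bins)

# pi: the policy prior. Outputs 2 * action_dim values, interpreted as
# the mean and log-std of a tanh-squashed Gaussian over actions.
pi = mlp(latent_dim, mlp_dim, 2 * action_dim)
```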
stable_worldmodel/wm/tdmpc2.py
Outdated
for p in self.target_qs.parameters():
    p.requires_grad = False

def encode(self, obs_dict, goal_dict):
why does the encoder have the goal_dict?
cf. DINO-WM, you can simply have encode(dict)
stable_worldmodel/wm/tdmpc2.py
Outdated
def forward(self, z, action):
    """
    Predicts the next latent state and expected reward given the current latent state and action.
wouldn't it be clearer if it was called predict?
I used forward because it's PyTorch's convention, so the model can be invoked as model().
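To illustrate the convention being referenced: `model(z, a)` dispatches through `nn.Module.__call__` to `forward`, so hooks and other Module machinery still apply. The toy class below is a sketch, not the PR's dynamics model:

```python
import torch
from torch import nn

class TinyDynamics(nn.Module):
    """Toy dynamics model; names and sizes are illustrative."""
    def __init__(self, latent_dim=8, action_dim=2):
        super().__init__()
        self.net = nn.Linear(latent_dim + action_dim, latent_dim)

    def forward(self, z, action):
        # Called via model(z, action), per PyTorch convention.
        return self.net(torch.cat([z, action], dim=-1))

model = TinyDynamics()
z_next = model(torch.zeros(1, 8), torch.zeros(1, 2))  # dispatches to forward()
```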
stable_worldmodel/wm/tdmpc2.py
Outdated
if key != 'pixels':
    if obs.ndim >= 3:
        obs = obs[..., -1, :]
    if goal.ndim >= 3:
        goal = goal[..., -1, :]
redundant code, have a look at the DINO-WM implementation
Ok, I think I improved it
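One way to collapse the duplicated obs/goal branches from the quoted diff is a small helper applied to both tensors; this is an illustrative refactor, not necessarily the change that was made:

```python
import torch

def last_frame(x: torch.Tensor) -> torch.Tensor:
    """Keep only the most recent frame of a stacked low-dim observation.

    For tensors shaped (..., T, D) this drops the time axis; tensors
    that are already (..., D) pass through unchanged.
    """
    return x[..., -1, :] if x.ndim >= 3 else x
```

With this in place, the branch in the diff reduces to `obs, goal = last_frame(obs), last_frame(goal)`.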
…y; add new protocols in protocols.py
…Policy to incorporate a warm_start from an actionable model
…port warm-start actions
…rt prefix actions
…olvers and update related logic
…sts to ensure correct tensor shapes and values
This pull request introduces a new training pipeline for the TD-MPC2 world model and policy. The main changes include adding a comprehensive configuration file for TD-MPC2, implementing a new training script with data preprocessing and model management, and registering the TD-MPC2 model in the world model package. These updates enable flexible, modular training of TD-MPC2 using Hydra and PyTorch Lightning, and support multiple observation modalities.
TD-MPC2 Training Pipeline Integration:
- tdmpc2.yaml, specifying all hyperparameters for data loading, model architecture, planning, optimization, and logging for the TD-MPC2 algorithm.
- tdmpc2.py, a new training script.

Model Registration:
- The TD-MPC2 model is registered in the stable_worldmodel.wm package by updating the __init__.py file, making it available for import and use throughout the codebase.