
TD-MPC2 Implementation #159

Open

luizfacury wants to merge 13 commits into galilai-group:main from luizfacury:tdmpc2

Conversation

@luizfacury (Contributor)

This pull request introduces a new training pipeline for the TD-MPC2 world model and policy. The main changes include adding a comprehensive configuration file for TD-MPC2, implementing a new training script with data preprocessing and model management, and registering the TD-MPC2 model in the world model package. These updates enable flexible, modular training of TD-MPC2 using Hydra and PyTorch Lightning, and support multiple observation modalities.

TD-MPC2 Training Pipeline Integration:

  • Added a new Hydra configuration file tdmpc2.yaml specifying all hyperparameters for data loading, model architecture, planning, optimization, and logging for the TD-MPC2 algorithm.
  • Implemented tdmpc2.py, a new training script that:
    • Dynamically builds datasets and preprocessing transforms for multiple modalities (e.g., pixels, state).
    • Defines the TD-MPC2 forward pass, loss computation, and policy update logic.
    • Sets up PyTorch Lightning training, including a callback for periodic model checkpointing.
    • Integrates with the stable world model and stable pretraining libraries for modular training and data management.

Model Registration:

  • Registered the TD-MPC2 model and module in the stable_worldmodel.wm package by updating the __init__.py file, making it available for import and use throughout the codebase.

@luizfacury luizfacury changed the title tdmpc2 implementation TD-MPC2 Implementation Mar 13, 2026
@quentinll (Collaborator) left a comment

Thank you for this nice contribution! Do you already have results using TDMPC2 in some environments?

return G + discount * (1 - termination) * conservative_q

@torch.no_grad()
def _plan(self, obs_dict, goal_dict, step_idxs, eval_mode=False):
Collaborator:

It would be nice to use the MPPI algorithm already implemented in stable-worldmodel. I believe one way would be to use the policy API defined here: https://github.com/galilai-group/stable-worldmodel/blob/main/stable_worldmodel/policy.py

Contributor Author @luizfacury · Mar 14, 2026:

I tested it on PushT and got about an 85% success rate predicting 50 steps ahead from a random state in a trajectory, repeated 100 times.

About the solver: TD-MPC2 draws some of its candidate trajectories directly from the actor network, rather than from pure Gaussian noise. From their paper:

"To accelerate convergence of planning, a fraction of action sequences originate from the policy prior π, and we warm-start planning by initializing (μ, σ) as the solution to the previous decision step shifted by 1."

From what I understand, your standard MPPI updates only the mean of the distribution, not the variance, and all of the trajectories are randomly generated. I tested using the implemented MPPI, and for me at least it hallucinated much more with OOD actions.
I used your FeedForwardPolicy to work with it.
I can change it if you want; I just tried to stay as close as possible to their implementation.
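For readers following along, the hybrid sampling scheme described above (policy-prior candidates mixed with Gaussian samples, shift-by-one warm start, and updating both the mean and the std of the sampling distribution) can be sketched roughly as follows. This is a minimal NumPy illustration, not the PR's actual code; all names and hyperparameters are made up:

```python
import numpy as np

def tdmpc2_mppi_step(score_fn, policy_samples, prev_mean, horizon, act_dim,
                     n_samples=256, n_elites=32, n_iters=6, temperature=0.5):
    """One planning step: mix policy-prior trajectories with Gaussian samples
    and update BOTH mean and std from the elite set (unlike a plain MPPI that
    only re-weights the mean around a fixed variance)."""
    # Warm start: shift the previous solution by one step, repeat the tail.
    mean = np.concatenate([prev_mean[1:], prev_mean[-1:]], axis=0)
    std = np.full((horizon, act_dim), 0.5)

    for _ in range(n_iters):
        noise = np.random.randn(n_samples, horizon, act_dim)
        actions = np.clip(mean + std * noise, -1.0, 1.0)
        # A fraction of candidates comes from the actor (policy prior),
        # assumed already bounded to [-1, 1].
        actions = np.concatenate([actions, policy_samples], axis=0)

        returns = score_fn(actions)                    # shape: (num_candidates,)
        elite_idx = np.argsort(returns)[-n_elites:]
        elite_actions = actions[elite_idx]
        elite_returns = returns[elite_idx]

        # Softmax-weighted update of mean AND std over the elites.
        w = np.exp(temperature * (elite_returns - elite_returns.max()))
        w /= w.sum()
        mean = (w[:, None, None] * elite_actions).sum(axis=0)
        std = np.sqrt((w[:, None, None] * (elite_actions - mean) ** 2).sum(axis=0))
        std = np.clip(std, 0.05, 2.0)

    return mean, std
```

The key differences from a mean-only MPPI are the concatenated policy-prior samples and the re-estimated std, which lets the search contract around good action sequences.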

Collaborator:

Our MPPI implementation already supports starting from an initial guess (see init_action_distrib), so it is possible to warm-start it using the output of an actor network. For now, FeedForwardPolicy does not support hybrid strategies like TD-MPC. I believe we should add a new policy class that allows using MPC starting from an actor's guess.

Contributor Author:

Ok, I created a TDMPC policy that warm-starts the planner from the actor. I tested on PushT again and it's working for me.

Contributor Author:

Do you need something else for this model? @quentinll

Collaborator:

Hey Luiz, thanks again for this very nice contribution. This seems close to what I had in mind in terms of implementation; I will have a closer look ASAP. Which script are you using for evaluation? In the meantime, it would be nice if you could run additional evaluations on other environments such as tworoom and OGBench cube. What do you think?

Contributor Author:

Sure, I will download the data and test on tworoom and cube. I will let you know when I have the results.

Collaborator:

I made quite a few changes to how the policy uses the actor to warm-start the solver. I removed the TDMPCPolicy and extended the WorldModelPolicy to support this. Could you let me know if you are still getting the same results? In theory the behavior should be identical.

Assumptions:
- Continuous Control: The algorithm assumes continuous action spaces.
- Action Bounds: Actions are strictly assumed to be normalized to the range [-1.0, 1.0]. The actor network and MPPI planner enforce this bound via Tanh and clamping.
Collaborator:

Is this still accurate?

Contributor Author:

Yes, TD-MPC2 uses a squashed Gaussian (tanh at the output), so actions are bounded to (-1, 1). Normalizing to this range is required for the two-hot value distribution and symlog reward scaling to be correct.
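For context, the symlog scaling and two-hot encoding mentioned above can be sketched as follows (illustrative helper names and grid bounds, not the PR's code):

```python
import math

def symlog(x):
    """Symmetric log squashing for rewards/values: sign(x) * ln(1 + |x|)."""
    return math.copysign(math.log1p(abs(x)), x)

def symexp(x):
    """Inverse of symlog: sign(x) * (exp(|x|) - 1)."""
    return math.copysign(math.expm1(abs(x)), x)

def two_hot(x, vmin=-10.0, vmax=10.0, num_bins=101):
    """Encode a scalar as weight spread over the two adjacent bins of a fixed
    grid. This only makes sense when x lands inside [vmin, vmax], which is why
    a known bounded range for values/rewards matters."""
    x = min(max(x, vmin), vmax)
    bin_size = (vmax - vmin) / (num_bins - 1)
    idx = (x - vmin) / bin_size
    lo = int(math.floor(idx))
    hi = min(lo + 1, num_bins - 1)
    w_hi = idx - math.floor(idx)
    probs = [0.0] * num_bins
    probs[lo] += 1.0 - w_hi
    probs[hi] += w_hi
    return probs
```

For example, with the default grid a target of 0.25 puts weight 0.75 on the bin at 0.2 and 0.25 on the bin at 0.4; the decoded value is the probability-weighted sum of bin centers.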

Comment on lines +153 to +164
if self.use_pixels:
    self.cnn = nn.Sequential(
        nn.Conv2d(6, 32, 7, stride=2),
        nn.Mish(),
        nn.Conv2d(32, 32, 5, stride=2),
        nn.Mish(),
        nn.Conv2d(32, 32, 3, stride=2),
        nn.Mish(),
        nn.Conv2d(32, 32, 3, stride=1),
        nn.Mish(),
        nn.Flatten(),
    )
Collaborator:

We usually use 224x224 images; would that be an issue?

Contributor Author:

It's fine: the CNN works regardless of image size, and the output dimension is configured accordingly. I tested different dimensions too.
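For reference, the flattened output dimension of the encoder quoted above follows from standard conv arithmetic; a quick sketch (assuming no padding, matching the snippet, with illustrative helper names):

```python
def conv_out(size, kernel, stride, padding=0):
    """Spatial output size of one Conv2d layer: floor((size + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1

def encoder_flat_dim(h, w, channels=32):
    """Flattened feature count after the four conv layers in the snippet
    (kernels 7/5/3/3, strides 2/2/2/1, no padding)."""
    for k, s in [(7, 2), (5, 2), (3, 2), (3, 1)]:
        h, w = conv_out(h, k, s), conv_out(w, k, s)
    return channels * h * w
```

With 224x224 inputs this yields a 24x24x32 feature map (18432 features); with 64x64 it yields 4x4x32 (512), so the downstream projection just needs to be sized from this value.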

continue # Handled by primary backbone

in_dim = cfg.extra_dims[key] * 2
self.extra_encoders[key] = nn.Sequential(
Collaborator:

Why should the extra encoders necessarily instantiate these networks?

Contributor Author:

You're right, I changed it.

self.reward = mlp(
    self.latent_dim + cfg.action_dim, cfg.wm.mlp_dim, cfg.wm.num_bins
)
self.pi = mlp(self.latent_dim, cfg.wm.mlp_dim, 2 * cfg.action_dim)
Collaborator:

Can you add a comment to help people understand what pi is, etc.? (I know it's the policy, but it might be confusing for newcomers.)

Contributor Author:

Added.
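For newcomers wondering about the 2 * cfg.action_dim output of the pi head above: in a squashed-Gaussian actor it is typically split into a mean and a clamped log-std, with tanh bounding the sampled action. A hedged NumPy sketch (illustrative names and clamp values, not the PR's code):

```python
import numpy as np

def sample_squashed_gaussian(pi_out, log_std_min=-10.0, log_std_max=2.0, rng=None):
    """Interpret a 2 * action_dim policy-head output: first half is the
    Gaussian mean, second half the log-std (clamped for stability).
    The sampled action is squashed into (-1, 1) with tanh."""
    rng = rng or np.random.default_rng()
    mean, log_std = np.split(np.asarray(pi_out, dtype=float), 2, axis=-1)
    log_std = np.clip(log_std, log_std_min, log_std_max)
    eps = rng.standard_normal(mean.shape)
    pre_tanh = mean + np.exp(log_std) * eps
    return np.tanh(pre_tanh)  # bounded action in (-1, 1)
```

The tanh at the end is what guarantees the [-1, 1] action bound discussed earlier in this thread.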

for p in self.target_qs.parameters():
    p.requires_grad = False

def encode(self, obs_dict, goal_dict):
Collaborator:

Why does the encoder take goal_dict?

Cf. DINO-WM: you can simply have encode(dict).

Contributor Author:

Fixed


def forward(self, z, action):
"""
Predicts the next latent state and expected reward given the current latent state and action.
Collaborator:

Wouldn't it be clearer if it were called predict?

Contributor Author:

I used forward because it's PyTorch's convention: it lets you call the module directly as model().

Comment on lines +353 to +357
if key != 'pixels':
    if obs.ndim >= 3:
        obs = obs[..., -1, :]
    if goal.ndim >= 3:
        goal = goal[..., -1, :]
Collaborator:

Redundant code; have a look at the DINO-WM implementation.

Contributor Author:

Ok, I think I improved it.
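One plausible deduplication of the branches quoted above is to collapse them into a single helper applied to both obs and goal (an illustrative sketch, not necessarily the final PR code):

```python
import numpy as np

def last_step(x):
    """Collapse an optional time dimension: if a non-pixel observation has
    shape (..., T, D), keep only the most recent step; otherwise pass through."""
    return x[..., -1, :] if x.ndim >= 3 else x
```

Both obs and goal then go through last_step once, instead of duplicating the ndim check per tensor.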
