
[Cleanup] Miscellaneous Refactors #1607


Merged

merged 7 commits into main from validation_pp_fix on Aug 22, 2025

Conversation


@wesleytruong wesleytruong commented Aug 20, 2025

This PR makes several miscellaneous refactors to clean up torchtitan before release.

Changes:

  • Sets each of the model_parts to eval mode in the Validator class to support PP (Bug fix); see the sketch below
  • Refactor checkpoint.enable_checkpoint -> checkpoint.enable (Refactor)
  • Refactor validation.enabled -> validation.enable (Refactor)
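
For clarity, here is a minimal sketch of the first change (illustrative only; it assumes the validator receives the pipeline model chunks as model_parts and is not the exact torchtitan code):

    import torch.nn as nn

    def validate(model_parts: list[nn.Module]) -> None:
        # With pipeline parallelism a rank can own several model chunks, so
        # every part must be switched to eval mode, not just model_parts[0].
        for model in model_parts:
            model.eval()

        # ... run the validation loop over the validation dataloader ...

        # Restore training mode on every part before the next training step.
        for model in model_parts:
            model.train()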

@meta-cla meta-cla bot added the CLA Signed label Aug 20, 2025
@@ -82,8 +82,9 @@ def validate(
         step: int,
     ) -> None:
         # Set model to eval mode
+        for model in model_parts:
+            model.eval()
         model = model_parts[0]
Contributor

Why keep this? I think we only need it in the non-PP case.

@@ -174,7 +175,8 @@ def validate(
module.reshard()
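
For context, the code around this hunk roughly follows the pattern below (a sketch assuming FSDP2-style modules whose reshard() frees the unsharded parameters; not the exact torchtitan code):

    import torch.nn as nn
    from torch.distributed.fsdp import FSDPModule

    def reshard_after_validation(model_parts: list[nn.Module]) -> None:
        # Explicitly reshard FSDP-managed modules after validation so their
        # unsharded parameters are freed. The next training step unshards
        # them again, which is the extra overhead discussed below.
        for model in model_parts:
            for module in model.modules():
                if isinstance(module, FSDPModule):
                    module.reshard()
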
Contributor

After you switch the order of checkpoint and validate, I think this reshard part is not necessary and creates extra overhead -- we are doing validate (reshard) -> next train (unshard), which we could've avoided?

Could you check if it works well when checkpoint and validate happen on the same step? If so, we can remove this code.

Contributor Author

Sure, in my test it looks like removing this reshard doesn't affect memory usage or loss, so I think it should be fine to remove.

(Screenshots comparing memory and loss curves with reshard vs. without reshard, captured Aug 20, 2025)

@tianyu-l tianyu-l requested a review from ebsmothers August 20, 2025 20:29
-        model_args, job_config.model.hf_assets_path
+        model_args,
+        job_config.model.hf_assets_path
+        if job_config.checkpoint.enable_checkpoint
Contributor

@wwwjn wwwjn Aug 21, 2025

Why do we need this change here? hf_assets_path can be used when loading HF weights as well, even if it's not otherwise used, and we could leave the checkpointing-related logic in checkpoint.py.

Contributor Author

Sorry, you're right. I was trying to do some cleanup to suppress this warning when the user doesn't intend to save in HF format, since in that case we wouldn't need a model.safetensors.index.json, but this would interfere with loading from HF via hf_assets_path. One other idea to suppress it would be to move this error to the checkpointer, but then the warning would be difficult to override, as we may want to do in Flux: https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/flux/model/state_dict_adapter.py#L53-L57

Contributor

I think it's OK to show the warning to the user. When would we want to suppress the warning? I can only think of the case where a model has only one safetensors file, and thus no model.safetensors.index.json.

If a model checkpoint has multiple safetensors files, then both saving and loading need to check that the model.safetensors.index.json file exists.

Contributor

@tianyu-l tianyu-l Aug 21, 2025

"When would we want to suppress the warning?"

E.g. when checkpointing is not enabled at all.

"If a model checkpoint has multiple safetensors files, then both saving and loading need to check that the model.safetensors.index.json file exists."

I believe only saving requires it. For loading it should be optional, but I'm not sure if it helps load faster. cc @ankitageorge to confirm.

Contributor

Ya only saving requires it

@tianyu-l tianyu-l added the release blocking label Aug 21, 2025
job_config.model.hf_assets_path
if job_config.checkpoint.enable
and job_config.checkpoint.last_save_in_hf
else None,
Contributor

hmmm even for load you need the path? https://github.com/pytorch/torchtitan/blob/main/torchtitan/components/checkpoint.py#L543

I think there are two ways:

  1. always leave the warning there as is
  2. only pass the hf_assets_path in when checkpoint.enable=True

I feel 2 is not clean enough (as in we don't differentiate save vs. load), so I'm OK with 1.
cc @wwwjn for your opinion.

Contributor

@wwwjn wwwjn Aug 21, 2025

I'm OK with 1 as well. We can leave the warning and let the user be aware of this.

@wesleytruong wesleytruong changed the title from "[Validation] fix setting all model_parts to eval mode" to "[Cleanup] Miscellaneous Refactors" on Aug 22, 2025
@@ -398,13 +398,13 @@ class Parallelism:

 @dataclass
 class Checkpoint:
-    enable_checkpoint: bool = False
+    enable: bool = False
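
In other words, anything that previously read checkpoint.enable_checkpoint (and likewise validation.enabled) now reads checkpoint.enable (and validation.enable). A toy illustration of the renamed field, not taken from the trainer:

    from dataclasses import dataclass

    @dataclass
    class Checkpoint:
        enable: bool = False  # previously named enable_checkpoint

    # Toy usage: call sites that read job_config.checkpoint.enable_checkpoint
    # now read job_config.checkpoint.enable.
    checkpoint_cfg = Checkpoint(enable=True)
    if checkpoint_cfg.enable:
        print("checkpointing is on")
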
Contributor

Nice, I was thinking of doing this as well.

model_args,
(
job_config.model.hf_assets_path
if job_config.checkpoint.enable
Contributor

Could you please move the use of checkpoint information out of the trainer, if possible? We deliberately hide job_config.checkpoint inside Checkpointer and let the trainer always call the checkpointer API. This usage looks okay, but it would be good to hide it.
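
One way the suggestion could look (purely illustrative; the class shape and property name are hypothetical, not torchtitan's actual Checkpointer API):

    from dataclasses import dataclass

    @dataclass
    class CheckpointConfig:
        enable: bool = False
        last_save_in_hf: bool = False

    class Checkpointer:
        def __init__(self, config: CheckpointConfig):
            self._config = config

        @property
        def exports_to_hf(self) -> bool:
            # Encapsulates checkpoint.enable and checkpoint.last_save_in_hf so
            # the trainer asks the checkpointer instead of reading the config.
            return self._config.enable and self._config.last_save_in_hf

    # Trainer side (sketch):
    #   hf_assets_path = job_config.model.hf_assets_path if checkpointer.exports_to_hf else None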

Contributor

@tianyu-l tianyu-l left a comment

LGTM, nice refactor!

@wesleytruong wesleytruong merged commit cd337db into main Aug 22, 2025
10 checks passed
@tianyu-l tianyu-l deleted the validation_pp_fix branch August 22, 2025 21:37
alfuyao1986 pushed a commit to AMD-AIG-AIMA/torchtitan-amd that referenced this pull request Aug 23, 2025
Labels
  • CLA Signed: This label is managed by the Meta Open Source bot.
  • release blocking: Issues that are blocking the milestone / release completion

5 participants