[Cleanup] Miscellaneous Refactors #1607
Conversation
torchtitan/components/validate.py
Outdated
```
@@ -82,8 +82,9 @@ def validate(
    step: int,
) -> None:
    # Set model to eval mode
    for model in model_parts:
        model.eval()
    model = model_parts[0]
```
Why keep this? I think we only need this in the non-PP case.
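For context, a minimal sketch of the distinction being discussed, assuming model_parts is a list of nn.Module stages (more than one entry only under pipeline parallelism):

```python
import torch.nn as nn

# Sketch only: under PP each rank may own several model stages, so every
# part must be switched to eval mode. model_parts[0] alone only covers
# the non-PP case, where the list has a single entry.
def set_eval_mode(model_parts: list[nn.Module]) -> None:
    for model in model_parts:
        model.eval()
```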
torchtitan/components/validate.py
Outdated
```
@@ -174,7 +175,8 @@ def validate(
    module.reshard()
```
After you switch the order of checkpoint and validate, I think this reshard part is not necessary and creates extra overhead -- we are doing validate (reshard) -> next train (unshard), which we could have avoided.
Could you check whether it works well when checkpoint and validate happen on the same step? If so, we can remove this code.
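To illustrate the overhead concern, a hedged sketch of the reshard pass being discussed; the loop structure is assumed, though FSDP2 modules wrapped with fully_shard do expose a reshard() method:

```python
import torch.nn as nn

# Sketch of a post-validation reshard sweep. If the next training step
# immediately unshards the same parameters, this reshard -> unshard
# round trip is avoidable overhead.
def reshard_after_validation(model_parts: list[nn.Module]) -> None:
    for model in model_parts:
        for module in model.modules():
            if hasattr(module, "reshard"):  # only FSDP2-wrapped modules
                module.reshard()
```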
torchtitan/train.py
Outdated
```diff
-    model_args, job_config.model.hf_assets_path
+    model_args,
+    job_config.model.hf_assets_path
+    if job_config.checkpoint.enable_checkpoint
```
Why do we need this change here? hf_assets_path can be built during loading of HF weights as well, even if it's not used, and we could leave the checkpointing-related logic in checkpoint.py.
Sorry, you're right. I was trying to do some cleanup in order to suppress this warning when the user does not intend to save in HF format, since in that case we wouldn't need a model.safetensors.index.json, but this would interfere with loading from HF via the hf_assets_path. One other idea to suppress this would be to move this error to the checkpointer, but then it would make this warning difficult to overload, as we may want to do in Flux: https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/flux/model/state_dict_adapter.py#L53-L57
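For reference, the overload pattern mentioned above (a model-specific adapter replacing the base check, as in the linked Flux state_dict_adapter) might look roughly like this; the class and method names are illustrative, not the actual torchtitan API:

```python
import logging

logger = logging.getLogger(__name__)

# Illustrative only: a base adapter warns when HF assets are missing,
# and a model-specific subclass overrides the behavior.
class BaseStateDictAdapter:
    def check_hf_assets(self, hf_assets_path: str | None) -> None:
        if hf_assets_path is None:
            logger.warning(
                "hf_assets_path not set; model.safetensors.index.json "
                "will be unavailable when saving in HF format."
            )

class FluxStateDictAdapter(BaseStateDictAdapter):
    def check_hf_assets(self, hf_assets_path: str | None) -> None:
        # Hypothetical override: skip or replace the base warning.
        pass
```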
I think it's OK to show the warning to the user. When do we want to suppress the warning? I can only think of the case where a model has just one safetensors file, without the model.safetensors.index.json.
If a model checkpoint has multiple safetensors files, then both saving and loading need to check that the model.safetensors.index.json file exists.
> When do we want to suppress the warning?

E.g. when checkpointing is not enabled at all.

> If a model checkpoint has multiple safetensors files, then both saving and loading need to check that the model.safetensors.index.json file exists.

I believe only saving requires it. For loading it should be optional, but I'm not sure if it helps load faster. cc @ankitageorge to confirm.
Ya, only saving requires it.
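As a sketch of this conclusion (only saving in HF format needs the index file), a check along these lines could live in the checkpointer; the helper name and error message are hypothetical:

```python
import os

# Hypothetical helper: model.safetensors.index.json is required when
# saving in HF format, but loading can proceed without it.
def check_hf_index(hf_assets_path: str, saving: bool) -> None:
    index_file = os.path.join(hf_assets_path, "model.safetensors.index.json")
    if saving and not os.path.isfile(index_file):
        raise FileNotFoundError(f"Saving in HF format requires {index_file}")
```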
Force-pushed from a4962f3 to 19bd2df.
torchtitan/train.py
Outdated
```
    job_config.model.hf_assets_path
    if job_config.checkpoint.enable
    and job_config.checkpoint.last_save_in_hf
    else None,
```
Hmm, even for load you need the path? https://github.com/pytorch/torchtitan/blob/main/torchtitan/components/checkpoint.py#L543
I think there are two ways:
- always leave the warning there as is
- only pass the hf_assets_path in when checkpoint.enable=True

Since 2 is not clean enough (as in it doesn't differentiate save vs. load), I'm OK with 1.
cc @wwwjn for your opinion.
I'm OK with 1 as well. We can leave the warning and let the user be aware of it.
Force-pushed from a4dca2d to ff95c80.
```diff
@@ -398,13 +398,13 @@ class Parallelism:

 @dataclass
 class Checkpoint:
-    enable_checkpoint: bool = False
+    enable: bool = False
```
Nice, I was thinking of doing this as well.
torchtitan/train.py
Outdated
```
    model_args,
    (
        job_config.model.hf_assets_path
        if job_config.checkpoint.enable
```
Could you please move the use of checkpoint information out of the trainer, if possible? We deliberately hide job_config.checkpoint inside Checkpointer and let the trainer always call the checkpointer API. This usage looks okay, but it would be good to hide it if possible.
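A minimal sketch of the suggested encapsulation, with hypothetical names: the trainer would query the checkpointer instead of reading job_config.checkpoint directly.

```python
# Hypothetical sketch: the Checkpointer owns its config section and
# exposes a small query API, keeping job_config.checkpoint out of train.py.
class Checkpointer:
    def __init__(self, checkpoint_config) -> None:
        self._config = checkpoint_config

    @property
    def saves_in_hf(self) -> bool:
        # Trainer-facing query; raw config fields stay private here.
        return self._config.enable and self._config.last_save_in_hf

# In the trainer, instead of inspecting job_config.checkpoint:
#   hf_assets_path = (
#       job_config.model.hf_assets_path if checkpointer.saves_in_hf else None
#   )
```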
LGTM, nice refactor!
This reverts commit cd337db.
This PR makes several miscellaneous refactors to clean up torchtitan before release.

Changes:
- Set all model_parts to eval mode in the Validator class to support PP (Bug fix)
- checkpoint.enable_checkpoint -> checkpoint.enable (Refactor)
- validation.enabled -> validation.enable (Refactor)