Fix DTensor/torch.Tensor compatibility in LinearCrossEntropyLoss #2898
Conversation
Fixes pytorch#2856
When using distributed LoRA fine-tuning with custom_sharded_layers, some tensors become DTensors while others remain regular tensors. This caused a RuntimeError when computing cross-entropy loss.
The fix adds compatibility handling in LinearCrossEntropyLoss.compute_cross_entropy by checking tensor types before the linear projection. When there's a type mismatch:
- If weight is a DTensor and hidden is not: convert hidden to a DTensor
- If hidden is a DTensor and weight is not: convert hidden to a local tensor
This ensures compatibility for distributed training while maintaining normal operation for non-distributed cases. A rough sketch of this handling appears after the change list below.
Changes:
- Add DTensor type checking before the self.linear_projection call
- Handle tensor type conversion when a mismatch is detected
- Add a regression test for the issue
- No impact on non-distributed training
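For illustration, a rough standalone sketch of the handling described above (not the exact diff from this PR; the replicated placement used when lifting hidden onto the weight's mesh is an assumption made here for the example):

```python
import torch
from torch.distributed._tensor import DTensor, Replicate


def reconcile_dtensor_types(linear_projection: torch.nn.Linear, hidden_chunk: torch.Tensor):
    """Make hidden_chunk compatible with linear_projection.weight before the matmul."""
    weight = linear_projection.weight
    weight_is_dtensor = isinstance(weight, DTensor)
    hidden_is_dtensor = isinstance(hidden_chunk, DTensor)

    if weight_is_dtensor and not hidden_is_dtensor:
        # Weight is a DTensor but hidden is not: lift hidden onto the same mesh.
        # (Replicate placement is an illustrative assumption.)
        hidden_chunk = DTensor.from_local(
            hidden_chunk, weight.device_mesh, [Replicate()] * weight.device_mesh.ndim
        )
    elif hidden_is_dtensor and not weight_is_dtensor:
        # Hidden is a DTensor but weight is not: fall back to the local shard.
        hidden_chunk = hidden_chunk.to_local()
    return hidden_chunk
```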
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2898
Note: Links to docs will display an error until the docs builds have been completed.
❌ 3 New Failures, 4 Cancelled Jobs
As of commit 2527975 with merge base b22a3ae:
NEW FAILURES - The following jobs have failed:
CANCELLED JOBS - The following jobs were cancelled. Please retry:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
I've been working on some similar issues in Hugging Face PEFT and llama-cookbook, so I thought I'd jump in and see if I could resolve this. Totally open to any changes or feedback. cc @krammnic
I will do a few sanity checks and then we can proceed on this.
# This case is less likely but handle it
hidden_chunk = hidden_chunk.to_local()
except ImportError:
# DTensor not available in this PyTorch version
Do we have to worry about this? torchtune I believe only needs to support latest stable and prerelease, so DTensor should always be importable.
# When using FSDP with custom_sharded_layers, some tensors might be DTensors
# while others are regular tensors, causing compatibility issues
if hasattr(torch.distributed, '_tensor') and torch.distributed.is_initialized():
try:
Ideally blocks like this actually sit outside compute_cross_entropy (perhaps in forward), because compute_cross_entropy gets compiled, and type branching doesn't appear to play nicely with compile. Compile support here is already muddy, but calling this part outside compute_cross_entropy can't hurt.
from torch.distributed._tensor import DTensor

# For linear_projection modules, we need to check the weight parameter
if hasattr(self.linear_projection, 'weight'):
Is it possible for this to be false?
)
elif hidden_is_dtensor and not weight_is_dtensor:
# This case is less likely but handle it
hidden_chunk = hidden_chunk.to_local()
Is this correct? If hidden_chunk is a DTensor, according to the forward logic, it should be sharded on the feature dimension (bs*seq_len, feature_dim / tp_dim). Then since the weight isn't a DTensor in this branch, presumably it has the original feature dimension and shape (feature_dim, vocab_size), so the matmul shapes don't match.
I may be missing something, feel free to correct me! I don't have access to a machine to test right now.
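To make the shape concern concrete, here is a minimal single-process sketch (all dimensions are hypothetical, with tp standing for the tensor-parallel degree):

```python
import torch

bs_seq, feature_dim, vocab_size, tp = 8, 16, 32, 4

# If hidden_chunk were a DTensor sharded on the feature dimension, .to_local()
# would return only this rank's shard: (bs*seq_len, feature_dim / tp).
hidden_local = torch.randn(bs_seq, feature_dim // tp)

# A non-DTensor projection weight still carries the full feature dimension.
weight = torch.randn(vocab_size, feature_dim)

# The projection no longer lines up: inner dimensions are 4 vs. 16.
try:
    torch.nn.functional.linear(hidden_local, weight)
except RuntimeError as e:
    print(f"shape mismatch, as the reviewer suspected: {e}")
```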
Apologies if I jumped on reviewing this too soon. Let me know if anything doesn't make sense. I'll take a closer look when you're happy to proceed :)
Hi @nathan-az, thank you for reviewing this! You're totally right about the dimension mismatch; trying to convert tensors was the wrong approach. I need to find the root FSDP problem that's causing this.
Next steps for revising this PR: go to where custom_sharded_layers is used in the FSDP configuration and check whether the linear_projection layer is being excluded from wrapping. If so, update the config to make sure linear_projection is wrapped with the rest of the model.
Closing this PR in favor of #2900, which takes a better approach by validating the configuration rather than trying to convert tensor types at runtime. The validation approach addresses the root cause without the dimension mismatch and compilation issues identified in the review.
Summary
Fixes #2856 - DTensor/torch.Tensor mixed type error in Llama4 LoRA fine-tuning
Problem
When running distributed LoRA fine-tuning with custom_sharded_layers, the training fails with:
RuntimeError: aten.mm.default: got mixed torch.Tensor and DTensor, need to convert all torch.Tensor to DTensor before calling distributed operators!
This occurs because FSDP wraps some tensors as DTensors while others remain regular tensors, causing a type mismatch in the loss computation.
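For context, this kind of mixed-type failure can be reproduced outside torchtune with a minimal single-process sketch; the gloo backend, one-device mesh, and tensor shapes below are illustrative assumptions, not part of the PR:

```python
import os
import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor, Replicate
from torch.distributed.device_mesh import init_device_mesh

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)
mesh = init_device_mesh("cpu", (1,))

# Projection weight wrapped as a DTensor, hidden states left as a plain tensor.
weight = DTensor.from_local(torch.randn(32, 16), mesh, [Replicate()])
hidden = torch.randn(8, 16)

try:
    torch.matmul(hidden, weight.t())
except RuntimeError as e:
    print(e)  # expected to be the "got mixed torch.Tensor and DTensor" error

dist.destroy_process_group()
```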
Solution
Added DTensor compatibility handling in LinearCrossEntropyLoss.compute_cross_entropy() by checking tensor types before the linear projection: if the weight is a DTensor and hidden is not, hidden is converted to a DTensor; if hidden is a DTensor and the weight is not, hidden is converted to a local tensor.
Testing
Added a regression test for the issue in test_dtensor_cross_entropy.py.
Test Plan
To verify the fix:
cc @krammnic