
Conversation


@jscaldwell55 commented Aug 2, 2025

Summary

Fixes #2856 - DTensor/torch.Tensor mixed type error in Llama4 LoRA fine-tuning

Problem

When running distributed LoRA fine-tuning with custom_sharded_layers, the training fails with:
RuntimeError: aten.mm.default: got mixed torch.Tensor and DTensor, need to convert all torch.Tensor to DTensor before calling distributed operators!

This occurs because FSDP wraps some tensors as DTensors while others remain regular tensors, causing a type mismatch in the loss computation.

Solution

Added DTensor compatibility handling in LinearCrossEntropyLoss.compute_cross_entropy() by checking tensor types before the linear projection (see the sketch after this list):

  • If weight is DTensor and hidden is not: convert hidden to DTensor to match
  • If hidden is DTensor and weight is not: convert hidden to local tensor
  • No-op for matching types or non-distributed training
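
In code, the check looks roughly like the following. This is a minimal sketch of the idea above, not the exact torchtune implementation; the helper name _match_tensor_types is made up for illustration.

import torch
from torch.distributed._tensor import DTensor


def _match_tensor_types(hidden_chunk: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """Convert hidden_chunk so its tensor type matches weight before the matmul."""
    weight_is_dtensor = isinstance(weight, DTensor)
    hidden_is_dtensor = isinstance(hidden_chunk, DTensor)

    if weight_is_dtensor and not hidden_is_dtensor:
        # Wrap the activations as a (replicated) DTensor on the weight's mesh.
        hidden_chunk = DTensor.from_local(hidden_chunk, device_mesh=weight.device_mesh)
    elif hidden_is_dtensor and not weight_is_dtensor:
        # Fall back to the local shard so both operands are plain tensors.
        hidden_chunk = hidden_chunk.to_local()
    return hidden_chunk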

Testing

  • Added regression test in test_dtensor_cross_entropy.py
  • Verified Python syntax and imports work correctly
  • No impact on non-distributed training paths

Test Plan

To verify the fix:

# Run the regression test
pytest tests/torchtune/modules/loss/test_dtensor_cross_entropy.py -v

# Test distributed LoRA training
tune run --nproc_per_node 8 lora_finetune_distributed --config llama4/scout_17B_16E_lora
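
For reference, a hedged sketch of what the regression test could cover. The file name matches the PR, but this body is illustrative and uses a single-rank gloo group so it can run without multiple GPUs.

import os
import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor
from torch.distributed.device_mesh import init_device_mesh


def test_hidden_is_promoted_to_dtensor_when_weight_is_dtensor():
    # Single-process gloo setup so a DTensor can be constructed in a unit test.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)
    try:
        mesh = init_device_mesh("cpu", (1,))
        weight = DTensor.from_local(torch.randn(32, 16), device_mesh=mesh)
        hidden = torch.randn(4, 16)
        # Mirror the conversion described in the PR: promote hidden to a DTensor
        # on the weight's mesh so the matmul sees a single tensor type.
        if isinstance(weight, DTensor) and not isinstance(hidden, DTensor):
            hidden = DTensor.from_local(hidden, device_mesh=weight.device_mesh)
        assert isinstance(hidden, DTensor)
    finally:
        dist.destroy_process_group()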

cc @krammnic

Fixes pytorch#2856

When using distributed LoRA fine-tuning with custom_sharded_layers,
some tensors become DTensors while others remain regular tensors.
This caused a RuntimeError when computing cross-entropy loss.

The fix adds compatibility handling in LinearCrossEntropyLoss.compute_cross_entropy
by checking tensor types before the linear projection. When there's a type mismatch:
- If weight is DTensor and hidden is not: convert hidden to DTensor
- If hidden is DTensor and weight is not: convert hidden to local tensor

This ensures compatibility for distributed training while maintaining
normal operation for non-distributed cases.

Changes:
- Add DTensor type checking before self.linear_projection call
- Handle tensor type conversion when mismatch detected
- Add regression test for the issue
- No impact on non-distributed training

pytorch-bot bot commented Aug 2, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2898

Note: Links to docs will display an error until the docs builds have been completed.

❌ 3 New Failures, 4 Cancelled Jobs

As of commit 2527975 with merge base b22a3ae:

NEW FAILURES - The following jobs have failed:

CANCELLED JOBS - The following jobs were cancelled. Please retry:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

meta-cla bot added the "CLA Signed" label on Aug 2, 2025
@jscaldwell55 (Author) commented:

I've been working on some similar issues in Hugging Face PEFT and llama-cookbook, so I thought I'd jump in and see if I could resolve this. Totally open to any changes or feedback.

cc @krammnic

@krammnic (Collaborator) commented Aug 3, 2025

I will do a few sanity checks and then we can proceed with this.

# This case is less likely but handle it
hidden_chunk = hidden_chunk.to_local()
except ImportError:
# DTensor not available in this PyTorch version
Collaborator:

Do we have to worry about this? I believe torchtune only needs to support the latest stable and prerelease, so DTensor should always be importable.
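
Illustrative only: on the supported PyTorch versions, the guard above could be dropped in favor of an unconditional import.

# No try/except ImportError needed on supported PyTorch versions.
from torch.distributed._tensor import DTensor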

# When using FSDP with custom_sharded_layers, some tensors might be DTensors
# while others are regular tensors, causing compatibility issues
if hasattr(torch.distributed, '_tensor') and torch.distributed.is_initialized():
try:
Collaborator:

Ideally blocks like this actually sit outside compute_cross_entropy (perhaps in forward), because compute_cross_entropy gets compiled, and type branching doesn't appear to play nicely with compile.

Compile support here is already muddy, but moving this part outside compute_cross_entropy can't hurt.
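
A rough sketch of that suggestion, assuming a simplified class skeleton rather than the actual torchtune LinearCrossEntropyLoss:

import torch
import torch.nn.functional as F
from torch.distributed._tensor import DTensor


class LinearCrossEntropyLoss(torch.nn.Module):
    def __init__(self, embed_dim: int, vocab_size: int):
        super().__init__()
        self.linear_projection = torch.nn.Linear(embed_dim, vocab_size, bias=False)

    def compute_cross_entropy(self, hidden: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # Branch-free body: safe to wrap with torch.compile.
        logits = self.linear_projection(hidden)
        return F.cross_entropy(logits, targets)

    def forward(self, hidden: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        weight = self.linear_projection.weight
        # Type branching lives here, outside the compiled method, as suggested.
        if isinstance(weight, DTensor) and not isinstance(hidden, DTensor):
            hidden = DTensor.from_local(hidden, device_mesh=weight.device_mesh)
        return self.compute_cross_entropy(hidden, targets)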

from torch.distributed._tensor import DTensor

# For linear_projection modules, we need to check the weight parameter
if hasattr(self.linear_projection, 'weight'):
Collaborator:

Is it possible for this to be false?

)
elif hidden_is_dtensor and not weight_is_dtensor:
# This case is less likely but handle it
hidden_chunk = hidden_chunk.to_local()
@nathan-az (Collaborator) commented Aug 4, 2025:

Is this correct? If hidden_chunk is a DTensor, according to the forward logic, it should be sharded on the feature dimension (bs*seq_len, feature_dim / tp_dim). Then since the weight isn't a DTensor in this branch, presumably it has the original feature dimension and shape (feature_dim, vocab_size), so the matmul shapes don't match.

I may be missing something, feel free to correct me! I don't have access to a machine to test right now.
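
To make the concern concrete, here is a small illustrative shape check with hypothetical sizes:

# If hidden is sharded on the feature dim across tp=2 ranks, its local shard is
# (bs*seq_len, feature_dim // 2), while an unsharded weight is (feature_dim,
# vocab_size), so hidden.to_local() @ weight would not match.
bs_seq, feature_dim, vocab_size, tp = 8, 16, 32, 2
local_hidden_shape = (bs_seq, feature_dim // tp)   # (8, 8) after to_local()
weight_shape = (feature_dim, vocab_size)           # (16, 32) unsharded
assert local_hidden_shape[1] != weight_shape[0]    # the matmul would fail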

@nathan-az (Collaborator):

Apologies if I jumped on reviewing this too soon. Let me know if anything doesn't make sense. I'll take a closer look when you're happy to proceed :)

@jscaldwell55 (Author) commented Aug 4, 2025

Hi @nathan-az, thank you for reviewing this! You're totally right about the dimension mismatch; trying to convert tensors was the wrong approach. I need to find the root FSDP problem that's causing this.

Next steps for revising this PR:

Go to where custom_sharded_layers is used in the FSDP configuration and check whether the linear_projection layer is being excluded from wrapping. If so, update the config so that linear_projection is wrapped with the rest of the model.

@jscaldwell55 (Author):

Closing this PR in favor of #2900, which takes a better approach by validating the configuration rather than trying to convert tensor types at runtime. The validation approach addresses the root cause without the dimension mismatch and compilation issues identified in the review.
