Commit f5c6d08

Fix test_spmd test (#372)
1 parent b438b9c commit f5c6d08
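
This commit copies the private _histogram helper from torch_xla.experimental.custom_kernel into the Mixtral Gmm class, removing the dependency on that private import, and adds an explicit attention_kernel entry to the Mixtral config used in test_spmd.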

2 files changed: +21 -2 lines

torchprime/torch_xla_models/model/mixtral/model.py

Lines changed: 20 additions & 2 deletions
@@ -333,6 +333,24 @@ def _eager_gmm_backward(grad_output, lhs, rhs, group_sizes):
       start += size
     return torch.cat(grad_lhs), torch.stack(grad_rhs)
 
+  @staticmethod
+  def _histogram(input: torch.Tensor, min: int, max: int) -> torch.Tensor:
+    """
+    Compute the histogram of an int32 tensor. The bin edges are defined by the min and max values, with step = 1.
+    """
+    assert input.dtype == torch.int32, "input must be of torch.int32 dtype."
+    assert min <= max, "min must be less than or equal to max."
+
+    def searchsorted(
+      sorted_sequence: torch.Tensor, values_to_search: torch.Tensor
+    ) -> torch.Tensor:
+      return (sorted_sequence.unsqueeze(1) == values_to_search).sum(dim=1)
+
+    bin_edges = torch.linspace(min, max, max - min + 1, dtype=input.dtype).to(
+      input.device
+    )
+    return searchsorted(bin_edges, input).to(torch.int32)
+
   @staticmethod
   @xp.trace_me("gmm_forward")
   def forward(
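
The new Gmm._histogram replaces the private helper previously imported from torch_xla.experimental.custom_kernel (see the next hunk). Because the bin edges are the consecutive integers from min to max with step 1, the inner searchsorted reduces to counting exact matches of each edge against the input. A minimal standalone sketch of that counting trick, with hypothetical names, assuming int32 inputs:

    import torch

    # Sketch only: counts occurrences of each integer in [lo, hi], mirroring
    # the equality-based "searchsorted" above. Names here are hypothetical.
    def histogram_sketch(values: torch.Tensor, lo: int, hi: int) -> torch.Tensor:
      bin_edges = torch.arange(lo, hi + 1, dtype=values.dtype)
      # Broadcast to [num_bins, num_values]; summing over values counts matches.
      return (bin_edges.unsqueeze(1) == values).sum(dim=1).to(torch.int32)

    # Six tokens routed to three experts (ids 0..2):
    expert_ids = torch.tensor([0, 2, 1, 0, 0, 2], dtype=torch.int32)
    print(histogram_sketch(expert_ids, 0, 2))  # tensor([3, 1, 2], dtype=torch.int32)
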
@@ -352,7 +370,7 @@ def forward(
       w2: [num_experts, ffn_dim, hidden_size]
       w3: [num_experts, hidden_size, ffn_dim]
     """
-    from torch_xla.experimental.custom_kernel import _histogram, gmm
+    from torch_xla.experimental.custom_kernel import gmm
 
     device = hidden_states.device
     if device == torch.device("cpu"):
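
With the helper now defined on Gmm itself, only gmm still needs to be imported from torch_xla's experimental module.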
@@ -397,7 +415,7 @@ def forward(
     ).repeat_interleave(k)[hidden_states_order]
     hidden_states_sorted = hidden_states[hidden_states_indices]
 
-    group_sizes = _histogram(top_flat.to(torch.int32), 0, num_experts - 1)
+    group_sizes = Gmm._histogram(top_flat.to(torch.int32), 0, num_experts - 1)
     gmm1 = gmm(hidden_states_sorted, w1, group_sizes, tiling=(512, 1024, 1024))
     gmm3 = gmm(hidden_states_sorted, w3, group_sizes, tiling=(512, 1024, 1024))
     silu = F.silu(gmm1)
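
For context, group_sizes tells the grouped matmul how many consecutive rows of the expert-sorted token matrix belong to each expert. A hedged eager-mode sketch of that contract (not the Pallas gmm kernel itself; names hypothetical):

    import torch

    # Eager reference for the grouped-matmul contract: lhs rows are pre-sorted
    # by expert, group_sizes[i] consecutive rows belong to expert i, and each
    # expert multiplies its slice by its own weight matrix rhs[i].
    def eager_gmm_sketch(lhs, rhs, group_sizes):
      outputs, start = [], 0
      for i, size in enumerate(group_sizes.tolist()):
        outputs.append(lhs[start : start + size] @ rhs[i])
        start += size
      return torch.cat(outputs)

    tokens = torch.randn(6, 8)        # 6 expert-sorted tokens, hidden_size=8
    w1 = torch.randn(3, 8, 16)        # 3 experts, ffn_dim=16
    group_sizes = torch.tensor([3, 1, 2], dtype=torch.int32)
    print(eager_gmm_sketch(tokens, w1, group_sizes).shape)  # torch.Size([6, 16])
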

torchprime/torch_xla_models/tests/test_spmd.py

Lines changed: 1 addition & 0 deletions
@@ -341,6 +341,7 @@ def test_mixtral_config_sharding_against_fsdp_v2(self):
       "router_aux_loss_coef": 0.02,
       "attention_bias": False,
       "attention_dropout": 0.0,
+      "attention_kernel": "splash_attention",
       "flash_attention": True,
       "moe_implementation": "gmm",
     }
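
The one-line fix adds the attention_kernel key to the Mixtral test config, presumably because the sharding test now expects the config to carry it. As a purely hypothetical illustration of how a model might dispatch on such a key (the real torchprime selection logic is not part of this diff):

    # Hypothetical sketch of dispatching on the new config key; the real
    # torchprime selection logic is not shown in this diff.
    def select_attention(config: dict) -> str:
      if config.get("attention_kernel") == "splash_attention":
        return "splash"
      return "flash" if config.get("flash_attention") else "eager"

    cfg = {"attention_kernel": "splash_attention", "flash_attention": True}
    print(select_attention(cfg))  # splash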
