@@ -333,6 +333,24 @@ def _eager_gmm_backward(grad_output, lhs, rhs, group_sizes):
333333 start += size
334334 return torch .cat (grad_lhs ), torch .stack (grad_rhs )
335335
336+ @staticmethod
337+ def _histogram (input : torch .Tensor , min : int , max : int ) -> torch .Tensor :
338+ """
339+ Compute the histogram of a int32 tensor. The bin edges are defined by the min and max values, with step = 1.
340+ """
341+ assert input .dtype == torch .int32 , "input must be of torch.int32 dtype."
342+ assert min <= max , "min must be less than or equal to max."
343+
344+ def searchsorted (
345+ sorted_sequence : torch .Tensor , values_to_search : torch .Tensor
346+ ) -> torch .Tensor :
347+ return (sorted_sequence .unsqueeze (1 ) == values_to_search ).sum (dim = 1 )
348+
349+ bin_edges = torch .linspace (min , max , max - min + 1 , dtype = input .dtype ).to (
350+ input .device
351+ )
352+ return searchsorted (bin_edges , input ).to (torch .int32 )
353+
336354 @staticmethod
337355 @xp .trace_me ("gmm_forward" )
338356 def forward (
@@ -352,7 +370,7 @@ def forward(
352370 w2: [num_experts, ffn_dim, hidden_size]
353371 w3: [num_experts, hidden_size, ffn_dim]
354372 """
355- from torch_xla .experimental .custom_kernel import _histogram , gmm
373+ from torch_xla .experimental .custom_kernel import gmm
356374
357375 device = hidden_states .device
358376 if device == torch .device ("cpu" ):
@@ -397,7 +415,7 @@ def forward(
397415 ).repeat_interleave (k )[hidden_states_order ]
398416 hidden_states_sorted = hidden_states [hidden_states_indices ]
399417
400- group_sizes = _histogram (top_flat .to (torch .int32 ), 0 , num_experts - 1 )
418+ group_sizes = Gmm . _histogram (top_flat .to (torch .int32 ), 0 , num_experts - 1 )
401419 gmm1 = gmm (hidden_states_sorted , w1 , group_sizes , tiling = (512 , 1024 , 1024 ))
402420 gmm3 = gmm (hidden_states_sorted , w3 , group_sizes , tiling = (512 , 1024 , 1024 ))
403421 silu = F .silu (gmm1 )
0 commit comments