@@ -75,7 +75,9 @@ def _fill_indices_kernel(
7575# ==============
7676
7777
78- def fill_indices_wrapper (
78+ # workaround until local_map functionalization is fixed: https://github.com/pytorch/pytorch/issues/167568
79+ @torch .library .custom_op ("autoparallel::fill_indices_functional" , mutates_args = ())
80+ def fill_indices_functional (
7981 tokens_per_expert_group : torch .Tensor ,
8082 start_index_values : torch .Tensor ,
8183 write_offsets : torch .Tensor ,
@@ -84,7 +86,7 @@ def fill_indices_wrapper(
8486 max_len : int ,
8587 block_size : int = 128 ,
8688 max_blocks : int = 1024 , # cap on total number of blocks to launch
87- ):
89+ ) -> torch . Tensor :
8890 # preallocate output
8991 permuted_indices = torch .full (
9092 (max_len ,), - 1 , dtype = torch .int32 , device = tokens_per_expert_group .device
@@ -108,6 +110,22 @@ def fill_indices_wrapper(
108110 return permuted_indices
109111
110112
@fill_indices_functional.register_fake
def _(
    tokens_per_expert_group: torch.Tensor,
    start_index_values: torch.Tensor,
    write_offsets: torch.Tensor,
    experts_per_rank: int,
    num_ranks: int,
    max_len: int,
    block_size: int = 128,
    max_blocks: int = 1024,  # cap on total number of blocks to launch
) -> torch.Tensor:
    """Fake implementation registered on ``fill_indices_functional``.

    Builds the output placeholder without running any kernel: a 1-D
    ``torch.int32`` tensor of length ``max_len``, filled with -1 and placed
    on the same device as ``tokens_per_expert_group``.

    Only ``max_len`` and ``tokens_per_expert_group`` are read here; the
    remaining parameters are accepted but unused.

    Returns:
        The ``(max_len,)`` int32 tensor described above.
    """
    out_shape = (max_len,)
    out_device = tokens_per_expert_group.device
    return torch.full(out_shape, -1, dtype=torch.int32, device=out_device)


127+
128+
111129# reference
112130def fill_indices_cpu (
113131 tokens_per_expert_group : torch .Tensor ,
@@ -143,7 +161,6 @@ def fill_indices_cpu(
143161 start_index ,
144162 start_index + (end_idx - write_start ),
145163 dtype = torch .int32 ,
146- # device=device,
147164 )
148165 write_start += length
149166 return permuted_indices
@@ -213,7 +230,7 @@ def generate_permute_indices(
213230 max_len ,
214231 )
215232 else :
216- permuted_indices = fill_indices_wrapper (
233+ permuted_indices = fill_indices_functional (
217234 tokens_per_expert_group ,
218235 start_index_values ,
219236 write_offsets ,
0 commit comments