Skip to content

Commit 580144b

Browse files
committed
Custom opify triton kernel until local_map functionalization is fixed
stack-info: PR: #245, branch: xmfan/stack/19
1 parent 827188d commit 580144b

File tree

2 files changed

+24
-4
lines changed

2 files changed

+24
-4
lines changed

autoparallel/_testing/models/dsv3.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,9 @@ def _fill_indices_kernel(
7575
# ==============
7676

7777

78-
def fill_indices_wrapper(
78+
# workaround until local_map functionalization is fixed: https://github.com/pytorch/pytorch/issues/167568
79+
@torch.library.custom_op("autoparallel::fill_indices_functional", mutates_args=())
80+
def fill_indices_functional(
7981
tokens_per_expert_group: torch.Tensor,
8082
start_index_values: torch.Tensor,
8183
write_offsets: torch.Tensor,
@@ -84,7 +86,7 @@ def fill_indices_wrapper(
8486
max_len: int,
8587
block_size: int = 128,
8688
max_blocks: int = 1024, # cap on total number of blocks to launch
87-
):
89+
) -> torch.Tensor:
8890
# preallocate output
8991
permuted_indices = torch.full(
9092
(max_len,), -1, dtype=torch.int32, device=tokens_per_expert_group.device
@@ -108,6 +110,22 @@ def fill_indices_wrapper(
108110
return permuted_indices
109111

110112

113+
@fill_indices_functional.register_fake
114+
def _(
115+
tokens_per_expert_group: torch.Tensor,
116+
start_index_values: torch.Tensor,
117+
write_offsets: torch.Tensor,
118+
experts_per_rank: int,
119+
num_ranks: int,
120+
max_len: int,
121+
block_size: int = 128,
122+
max_blocks: int = 1024, # cap on total number of blocks to launch
123+
) -> torch.Tensor:
124+
return torch.full(
125+
(max_len,), -1, dtype=torch.int32, device=tokens_per_expert_group.device
126+
)
127+
128+
111129
# reference
112130
def fill_indices_cpu(
113131
tokens_per_expert_group: torch.Tensor,
@@ -143,7 +161,6 @@ def fill_indices_cpu(
143161
start_index,
144162
start_index + (end_idx - write_start),
145163
dtype=torch.int32,
146-
# device=device,
147164
)
148165
write_start += length
149166
return permuted_indices
@@ -213,7 +230,7 @@ def generate_permute_indices(
213230
max_len,
214231
)
215232
else:
216-
permuted_indices = fill_indices_wrapper(
233+
permuted_indices = fill_indices_functional(
217234
tokens_per_expert_group,
218235
start_index_values,
219236
write_offsets,

examples/example_ds3_local_map.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,9 @@ def input_fn():
177177
out.backward(torch.randn_like(out))
178178
else:
179179
out = parallel_mod(*x)
180+
assert not torch.any(torch.isnan(out)), "Found NaNs in forward output"
181+
if rng_seed is not None:
182+
numerics_logger.log_forward_output(out)
180183
out.backward(torch.randn_like(out))
181184

182185
print("All good!")

0 commit comments

Comments
 (0)