
Commit 6e8451c

Custom-op-ify the Triton kernel until local_map functionalization is fixed
stack-info: PR: #245, branch: xmfan/stack/19
1 parent: 827188d

File tree

2 files changed: +26 −5 lines

autoparallel/_testing/models/dsv3.py

Lines changed: 21 additions & 4 deletions
@@ -75,7 +75,9 @@ def _fill_indices_kernel(
 # ==============
 
 
-def fill_indices_wrapper(
+# workaround until local_map functionalization is fixed: https://github.com/pytorch/pytorch/issues/167568
+@torch.library.custom_op("autoparallel::fill_indices_functional", mutates_args=())
+def fill_indices_functional(
     tokens_per_expert_group: torch.Tensor,
     start_index_values: torch.Tensor,
     write_offsets: torch.Tensor,
@@ -84,7 +86,7 @@ def fill_indices_wrapper(
     max_len: int,
     block_size: int = 128,
     max_blocks: int = 1024,  # cap on total number of blocks to launch
-):
+) -> torch.Tensor:
     # preallocate output
     permuted_indices = torch.full(
         (max_len,), -1, dtype=torch.int32, device=tokens_per_expert_group.device
@@ -108,6 +110,22 @@ def fill_indices_wrapper(
     return permuted_indices
 
 
+@fill_indices_functional.register_fake
+def _(
+    tokens_per_expert_group: torch.Tensor,
+    start_index_values: torch.Tensor,
+    write_offsets: torch.Tensor,
+    experts_per_rank: int,
+    num_ranks: int,
+    max_len: int,
+    block_size: int = 128,
+    max_blocks: int = 1024,  # cap on total number of blocks to launch
+) -> torch.Tensor:
+    return torch.full(
+        (max_len,), -1, dtype=torch.int32, device=tokens_per_expert_group.device
+    )
+
+
 # reference
 def fill_indices_cpu(
     tokens_per_expert_group: torch.Tensor,
@@ -143,7 +161,6 @@ def fill_indices_cpu(
             start_index,
             start_index + (end_idx - write_start),
             dtype=torch.int32,
-            # device=device,
         )
         write_start += length
     return permuted_indices
@@ -213,7 +230,7 @@ def generate_permute_indices(
             max_len,
         )
     else:
-        permuted_indices = fill_indices_wrapper(
+        permuted_indices = fill_indices_functional(
             tokens_per_expert_group,
             start_index_values,
             write_offsets,
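
The change above follows the standard torch.library.custom_op pattern: the Triton-launching wrapper is registered as a non-mutating custom op, and register_fake supplies a shape/dtype-only implementation so the op can be traced without running the kernel. Below is a minimal, self-contained sketch of the same pattern, using a hypothetical mylib::fill_negative_ones op that is not part of this repo.

import torch


# Hypothetical example op; mirrors the pattern of the real change above but is
# not part of the autoparallel code base.
@torch.library.custom_op("mylib::fill_negative_ones", mutates_args=())
def fill_negative_ones(ref: torch.Tensor, max_len: int) -> torch.Tensor:
    # Real (eager) implementation: this is where a Triton kernel launch would go.
    return torch.full((max_len,), -1, dtype=torch.int32, device=ref.device)


@fill_negative_ones.register_fake
def _(ref: torch.Tensor, max_len: int) -> torch.Tensor:
    # Fake implementation: only produces a tensor with the right shape, dtype,
    # and device, so tracing never has to execute the real kernel.
    return torch.full((max_len,), -1, dtype=torch.int32, device=ref.device)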

examples/example_ds3_local_map.py

Lines changed: 5 additions & 1 deletion
@@ -153,7 +153,8 @@ def input_fn():
     # ) # maybe not correct value
     parallel_mod.init_weights(buffer_device=device, seed=rng_seed)
     if rng_seed is not None:
-        NumericsLogger(logs_dir).log_model_weights(parallel_mod)
+        numerics_logger = NumericsLogger(logs_dir)
+        numerics_logger.log_model_weights(parallel_mod)
 
     x = (
         torch.randint(
@@ -177,6 +178,9 @@ def input_fn():
         out.backward(torch.randn_like(out))
     else:
         out = parallel_mod(*x)
+        assert not torch.any(torch.isnan(out)), "Found NaNs in forward output"
+        if rng_seed is not None:
+            numerics_logger.log_forward_output(out)
         out.backward(torch.randn_like(out))
 
     print("All good!")
