Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions autoparallel/_testing/models/dsv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,9 @@ def _fill_indices_kernel(
# ==============


def fill_indices_wrapper(
# workaround until local_map functionalization is fixed: https://github.com/pytorch/pytorch/issues/167568
@torch.library.custom_op("autoparallel::fill_indices_functional", mutates_args=())
def fill_indices_functional(
tokens_per_expert_group: torch.Tensor,
start_index_values: torch.Tensor,
write_offsets: torch.Tensor,
Expand All @@ -84,7 +86,7 @@ def fill_indices_wrapper(
max_len: int,
block_size: int = 128,
max_blocks: int = 1024, # cap on total number of blocks to launch
):
) -> torch.Tensor:
# preallocate output
permuted_indices = torch.full(
(max_len,), -1, dtype=torch.int32, device=tokens_per_expert_group.device
Expand All @@ -108,6 +110,22 @@ def fill_indices_wrapper(
return permuted_indices


@fill_indices_functional.register_fake
def _(
tokens_per_expert_group: torch.Tensor,
start_index_values: torch.Tensor,
write_offsets: torch.Tensor,
experts_per_rank: int,
num_ranks: int,
max_len: int,
block_size: int = 128,
max_blocks: int = 1024, # cap on total number of blocks to launch
) -> torch.Tensor:
return torch.full(
(max_len,), -1, dtype=torch.int32, device=tokens_per_expert_group.device
)


# reference
def fill_indices_cpu(
tokens_per_expert_group: torch.Tensor,
Expand Down Expand Up @@ -143,7 +161,6 @@ def fill_indices_cpu(
start_index,
start_index + (end_idx - write_start),
dtype=torch.int32,
# device=device,
)
write_start += length
return permuted_indices
Expand Down Expand Up @@ -213,7 +230,7 @@ def generate_permute_indices(
max_len,
)
else:
permuted_indices = fill_indices_wrapper(
permuted_indices = fill_indices_functional(
tokens_per_expert_group,
start_index_values,
write_offsets,
Expand Down
6 changes: 5 additions & 1 deletion examples/example_ds3_local_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,8 @@ def input_fn():
# ) # maybe not correct value
parallel_mod.init_weights(buffer_device=device, seed=rng_seed)
if rng_seed is not None:
NumericsLogger(logs_dir).log_model_weights(parallel_mod)
numerics_logger = NumericsLogger(logs_dir)
numerics_logger.log_model_weights(parallel_mod)

x = (
torch.randint(
Expand All @@ -177,6 +178,9 @@ def input_fn():
out.backward(torch.randn_like(out))
else:
out = parallel_mod(*x)
assert not torch.any(torch.isnan(out)), "Found NaNs in forward output"
if rng_seed is not None:
numerics_logger.log_forward_output(out)
out.backward(torch.randn_like(out))

print("All good!")
Expand Down