From 78ae90aae4fcdacdad13c6c1f37fe717801a9d8c Mon Sep 17 00:00:00 2001
From: Zihao Ye
Date: Fri, 6 Jun 2025 18:04:38 -0400
Subject: [PATCH 1/5] upd

---
 csrc/nvshmem_binding.cu | 126 ++++++++++++++++++++++++++++++++++++++++
 flashinfer/comm.py      |  16 +++++
 flashinfer/jit/env.py   |  14 +++++
 setup.py                |   2 +-
 tests/test_nvshmem.py   |   9 +++
 5 files changed, 166 insertions(+), 1 deletion(-)
 create mode 100644 csrc/nvshmem_binding.cu
 create mode 100644 tests/test_nvshmem.py

diff --git a/csrc/nvshmem_binding.cu b/csrc/nvshmem_binding.cu
new file mode 100644
index 0000000000..c50e5bbc65
--- /dev/null
+++ b/csrc/nvshmem_binding.cu
@@ -0,0 +1,126 @@
+/*
+ * Copyright (C) 2025 Perplexity AI
+ */
+#include <cuda_runtime.h>
+#include <nvshmem.h>
+#include <nvshmemx.h>
+#include <torch/library.h>
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/util/accumulate.h>
+
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <vector>
+
+#define CUDACHECK(cmd)                                                                      \
+  do {                                                                                      \
+    cudaError_t e = cmd;                                                                    \
+    if (e != cudaSuccess) {                                                                 \
+      printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \
+      exit(EXIT_FAILURE);                                                                   \
+    }                                                                                       \
+  } while (0)
+
+template <typename T>
+T* mallocZeroBuffer(size_t size) {
+  T* ptr;
+  CUDACHECK(cudaMalloc(&ptr, size * sizeof(T)));
+  cudaMemset(ptr, 0, size * sizeof(T));
+  return ptr;
+}
+
+inline int get_sm_count() {
+  int device;
+  CUDACHECK(cudaGetDevice(&device));
+  int numSMs;
+  CUDACHECK(cudaDeviceGetAttribute(&numSMs, cudaDevAttrMultiProcessorCount, device));
+
+  return numSMs;
+}
+
+#define NVSHMEMCHECK(stmt)                                                                    \
+  do {                                                                                        \
+    int result = (stmt);                                                                      \
+    if (NVSHMEMX_SUCCESS != result) {                                                         \
+      fprintf(stderr, "[%s:%d] nvshmem failed with error %d \n", __FILE__, __LINE__, result); \
+      exit(-1);                                                                               \
+    }                                                                                         \
+  } while (0)
+
+namespace {
+
+at::Tensor get_unique_id() {
+  nvshmemx_uniqueid_t uid = NVSHMEMX_UNIQUEID_INITIALIZER;
+  nvshmemx_get_uniqueid(&uid);
+  return at::from_blob(&uid, sizeof(uid), at::kByte).clone();
+}
+
+int64_t unique_id_size() { return sizeof(nvshmemx_uniqueid_t); }
+
+int64_t init(at::Tensor uid, int64_t rank, int64_t world_size) {
+  TORCH_CHECK(uid.device().is_cpu(), "uid must be a CPU tensor");
+  TORCH_CHECK(uid.scalar_type() == at::kByte, "uid must be a byte tensor");
+  TORCH_CHECK(uid.numel() == sizeof(nvshmemx_uniqueid_t), "Invalid unique id size (expected ",
+              sizeof(nvshmemx_uniqueid_t), ", got ", uid.numel(), ")");
+  nvshmemx_uniqueid_t id;
+  std::memcpy(&id, uid.data_ptr(), sizeof(id));
+  nvshmemx_init_attr_t attr = NVSHMEMX_INIT_ATTR_INITIALIZER;
+  nvshmemx_set_attr_uniqueid_args(rank, world_size, &id, &attr);
+  return nvshmemx_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID, &attr);
+}
+
+void finalize() { nvshmem_finalize(); }
+
+int64_t my_pe() { return nvshmem_my_pe(); }
+
+int64_t n_pes() { return nvshmem_n_pes(); }
+
+at::Tensor malloc_tensor(const std::vector<int64_t>& shape, c10::ScalarType dtype,
+                         const c10::Device& device) {
+  size_t size = c10::elementSize(dtype) * c10::multiply_integers(shape);
+  void* ptr = nvshmem_malloc(size);
+  if (ptr == nullptr) {
+    AT_ERROR("nvshmem_malloc failed. size: ", size);
+  }
+  return at::from_blob(
+      ptr, shape, [](void* ptr) { nvshmem_free(ptr); },
+      at::TensorOptions().dtype(dtype).device(device));
+}
+
+void barrier_all() { nvshmem_barrier_all(); }
+
+void barrier_all_on_current_stream() {
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  nvshmemx_barrier_all_on_stream(stream);
+}
+
+void alltoall(at::Tensor dest, at::Tensor source) {
+  TORCH_CHECK(dest.is_contiguous(), "dest must be contiguous");
+  TORCH_CHECK(source.is_contiguous(), "source must be contiguous");
+
+  size_t nbytes = dest.numel() * dest.itemsize() / dest.size(0);
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  NVSHMEMCHECK(nvshmemx_alltoallmem_on_stream(NVSHMEM_TEAM_WORLD, (uint8_t*)dest.data_ptr(),
+                                              (uint8_t*)source.data_ptr(), nbytes, stream));
+}
+
+void fake_alltoall(at::Tensor dest, at::Tensor source) {}
+
+TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
+  m.def("nvshmem_get_unique_id", &get_unique_id);
+  m.def("nvshmem_unique_id_size", &unique_id_size);
+  m.def("nvshmem_init", &init);
+  m.def("nvshmem_finalize", &finalize);
+  m.def("nvshmem_my_pe", &my_pe);
+  m.def("nvshmem_n_pes", &n_pes);
+  m.def("nvshmem_malloc", &malloc_tensor);
+  m.def("nvshmem_barrier_all", &barrier_all);
+  m.def("nvshmem_barrier_all_on_current_stream", &barrier_all_on_current_stream);
+  m.def("nvshmem_alltoall(Tensor! dest, Tensor src) -> ()");
+  m.impl("nvshmem_alltoall", c10::kCUDA, &alltoall);
+  m.impl("nvshmem_alltoall", c10::kMeta, &fake_alltoall);
+};
+
+} // namespace
diff --git a/flashinfer/comm.py b/flashinfer/comm.py
index 37642d4dda..1a15831fc4 100644
--- a/flashinfer/comm.py
+++ b/flashinfer/comm.py
@@ -410,6 +410,22 @@ def trtllm_custom_all_reduce(
     )
 
 
+def gen_nvshmem_module() -> JitSpec:
+    return gen_jit_spec(
+        "nvshmem",
+        [jit_env.FLASHINFER_CSRC_DIR / "nvshmem_binding.cu"],
+        extra_include_paths=[jit_env.get_nvshmem_include_dir()],
+        extra_ldflags=[f"-L{jit_env.get_nvshmem_lib_dir()}", "-lnvshmem"],
+    )
+
+
+@functools.cache
+def get_nvshmem_module():
+    module = gen_nvshmem_module().build_and_load()
+
+    return module
+
+
 def init_custom_ar(
     ipc_tensors: List[int], rank_data: torch.Tensor, rank: int, full_nvlink: bool
 ) -> int:
diff --git a/flashinfer/jit/env.py b/flashinfer/jit/env.py
index efb0c1679c..9d9158875c 100644
--- a/flashinfer/jit/env.py
+++ b/flashinfer/jit/env.py
@@ -63,3 +63,17 @@ def _get_workspace_dir_name() -> pathlib.Path:
     _package_root / "data" / "cutlass" / "tools" / "util" / "include",
 ]
 SPDLOG_INCLUDE_DIR = _package_root / "data" / "spdlog" / "include"
+
+
+def get_nvshmem_include_dir():
+    import nvidia.nvshmem
+
+    path = pathlib.Path(nvidia.nvshmem.__path__[0]) / "include"
+    return path
+
+
+def get_nvshmem_lib_dir():
+    import nvidia.nvshmem
+
+    path = pathlib.Path(nvidia.nvshmem.__path__[0]) / "lib"
+    return path
diff --git a/setup.py b/setup.py
index 2fd1f24d94..87389c40e2 100644
--- a/setup.py
+++ b/setup.py
@@ -51,7 +51,7 @@ def generate_build_meta(aot_build_meta: dict) -> None:
 
 ext_modules = []
 cmdclass = {}
-install_requires = ["numpy", "torch", "ninja", "requests"]
+install_requires = ["numpy", "torch", "ninja", "requests", "nvidia-nvshmem-cu12"]
 generate_build_meta({})
 
 if enable_aot:
diff --git a/tests/test_nvshmem.py b/tests/test_nvshmem.py
new file mode 100644
index 0000000000..0185670b5e
--- /dev/null
+++ b/tests/test_nvshmem.py
@@ -0,0 +1,9 @@
+import flashinfer.comm as comm
+
+
+def test_nvshmem():
+    comm.get_nvshmem_module()
+
+
+if __name__ == "__main__":
+    test_nvshmem()

From 1d73960751179a6b32a33b59d93f6174c68f5fbf Mon Sep 17 00:00:00 2001
From: Zihao Ye
Date: Thu, 19 Jun 2025 20:41:48 +0000
Subject: [PATCH 2/5] upd

---
 flashinfer/comm.py        |  6 +++++-
 flashinfer/jit/core.py    |  4 ++++
 flashinfer/jit/cpp_ext.py | 42 +++++++++++++++++++++++++++++++++++----
 flashinfer/jit/env.py     | 14 +++++++------
 4 files changed, 55 insertions(+), 11 deletions(-)

diff --git a/flashinfer/comm.py b/flashinfer/comm.py
index 1a15831fc4..cf89c6860b 100644
--- a/flashinfer/comm.py
+++ b/flashinfer/comm.py
@@ -415,7 +415,11 @@ def gen_nvshmem_module() -> JitSpec:
         "nvshmem",
         [jit_env.FLASHINFER_CSRC_DIR / "nvshmem_binding.cu"],
         extra_include_paths=[jit_env.get_nvshmem_include_dir()],
-        extra_ldflags=[f"-L{jit_env.get_nvshmem_lib_dir()}", "-lnvshmem"],
+        extra_ldflags=[
+            f"-L{jit_env.get_nvshmem_lib_dir()}",
+            "-lnvshmem",
+        ],
+        needs_device_linking=True,
     )
 
 
diff --git a/flashinfer/jit/core.py b/flashinfer/jit/core.py
index 83931ed4bb..55667f2f84 100644
--- a/flashinfer/jit/core.py
+++ b/flashinfer/jit/core.py
@@ -72,6 +72,7 @@ class JitSpec:
     extra_ldflags: Optional[List[str]]
     extra_include_dirs: Optional[List[Path]]
     is_class: bool = False
+    needs_device_linking: bool = False
 
     @property
     def ninja_path(self) -> Path:
@@ -100,6 +101,7 @@ def write_ninja(self) -> None:
             extra_cuda_cflags=self.extra_cuda_cflags,
             extra_ldflags=self.extra_ldflags,
             extra_include_dirs=self.extra_include_dirs,
+            needs_device_linking=self.needs_device_linking,
         )
         write_if_different(ninja_path, content)
 
@@ -131,6 +133,7 @@ def gen_jit_spec(
     extra_cuda_cflags: Optional[List[str]] = None,
     extra_ldflags: Optional[List[str]] = None,
     extra_include_paths: Optional[List[Union[str, Path]]] = None,
+    needs_device_linking: bool = False,
 ) -> JitSpec:
     check_cuda_arch()
     verbose = os.environ.get("FLASHINFER_JIT_VERBOSE", "0") == "1"
@@ -173,6 +176,7 @@ def gen_jit_spec(
         extra_cuda_cflags=cuda_cflags,
         extra_ldflags=extra_ldflags,
         extra_include_dirs=extra_include_paths,
+        needs_device_linking=needs_device_linking,
     )
     spec.write_ninja()
     return spec
diff --git a/flashinfer/jit/cpp_ext.py b/flashinfer/jit/cpp_ext.py
index ff1d697f3f..e9d9a6eeb7 100644
--- a/flashinfer/jit/cpp_ext.py
+++ b/flashinfer/jit/cpp_ext.py
@@ -37,6 +37,7 @@ def generate_ninja_build_for_op(
     extra_cuda_cflags: Optional[List[str]],
     extra_ldflags: Optional[List[str]],
     extra_include_dirs: Optional[List[Path]],
+    needs_device_linking: bool = False,
 ) -> str:
     system_includes = [
         sysconfig.get_path("include"),
@@ -93,6 +94,20 @@ def generate_ninja_build_for_op(
         "-L$cuda_home/lib64",
         "-lcudart",
     ]
+
+    env_extra_ldflags = os.environ.get("FLASHINFER_EXTRA_LDFLAGS")
+    if env_extra_ldflags:
+        try:
+            import shlex
+
+            ldflags += shlex.split(env_extra_ldflags)
+        except ValueError as e:
+            print(
+                f"Warning: Could not parse FLASHINFER_EXTRA_LDFLAGS with shlex: {e}. Falling back to simple split.",
+                file=sys.stderr,
+            )
+            ldflags += env_extra_ldflags.split()
+
     if extra_ldflags is not None:
         ldflags += extra_ldflags
 
@@ -125,12 +140,28 @@ def generate_ninja_build_for_op(
         " depfile = $out.d",
         " deps = gcc",
         "",
-        "rule link",
-        " command = $cxx $in $ldflags -o $out",
-        "",
     ]
 
+    # Add nvcc linking rule for device code
+    if needs_device_linking:
+        lines.extend(
+            [
+                "rule nvcc_link",
+                " command = $nvcc -shared $in $ldflags -o $out",
+                "",
+            ]
+        )
+    else:
+        lines.extend(
+            [
+                "rule link",
+                " command = $cxx $in $ldflags -o $out",
+                "",
+            ]
+        )
+
     objects = []
+    cuda_objects = []
     for source in sources:
         is_cuda = source.suffix == ".cu"
         object_suffix = ".cuda.o" if is_cuda else ".o"
@@ -138,10 +169,13 @@ def generate_ninja_build_for_op(
         obj_name = source.with_suffix(object_suffix).name
         obj = f"$name/{obj_name}"
         objects.append(obj)
+        if is_cuda:
+            cuda_objects.append(obj)
         lines.append(f"build {obj}: {cmd} {source.resolve()}")
 
     lines.append("")
-    lines.append("build $name/$name.so: link " + " ".join(objects))
+    link_rule = "nvcc_link" if needs_device_linking else "link"
+    lines.append(f"build $name/$name.so: {link_rule} " + " ".join(objects))
     lines.append("default $name/$name.so")
     lines.append("")
 
diff --git a/flashinfer/jit/env.py b/flashinfer/jit/env.py
index 9d9158875c..e1307e67d0 100644
--- a/flashinfer/jit/env.py
+++ b/flashinfer/jit/env.py
@@ -66,14 +66,16 @@ def _get_workspace_dir_name() -> pathlib.Path:
 
 
 def get_nvshmem_include_dir():
-    import nvidia.nvshmem
+    # import nvidia.nvshmem
 
-    path = pathlib.Path(nvidia.nvshmem.__path__[0]) / "include"
-    return path
+    # path = pathlib.Path(nvidia.nvshmem.__path__[0]) / "include"
+    # return path
+    return "/opt/nvshmem-3.2.5/include"
 
 
 def get_nvshmem_lib_dir():
-    import nvidia.nvshmem
+    # import nvidia.nvshmem
 
-    path = pathlib.Path(nvidia.nvshmem.__path__[0]) / "lib"
-    return path
+    # path = pathlib.Path(nvidia.nvshmem.__path__[0]) / "lib"
+    # return path
+    return "/opt/nvshmem-3.2.5/lib"

From b96115e6696338814f694cafcce011a33bf9c143 Mon Sep 17 00:00:00 2001
From: Zihao Ye
Date: Thu, 19 Jun 2025 20:53:49 +0000
Subject: [PATCH 3/5] upd

---
 flashinfer/comm.py    | 10 +++++++++-
 flashinfer/jit/env.py | 14 ++++++--------
 2 files changed, 15 insertions(+), 9 deletions(-)

diff --git a/flashinfer/comm.py b/flashinfer/comm.py
index cf89c6860b..24248207ab 100644
--- a/flashinfer/comm.py
+++ b/flashinfer/comm.py
@@ -417,7 +417,7 @@ def gen_nvshmem_module() -> JitSpec:
         extra_include_paths=[jit_env.get_nvshmem_include_dir()],
         extra_ldflags=[
             f"-L{jit_env.get_nvshmem_lib_dir()}",
-            "-lnvshmem",
+            "-lnvshmem_device",
         ],
         needs_device_linking=True,
     )
@@ -425,6 +425,14 @@ def gen_nvshmem_module() -> JitSpec:
 
 @functools.cache
 def get_nvshmem_module():
+    from pathlib import Path
+
+    import nvidia.nvshmem
+
+    ctypes.CDLL(
+        Path(nvidia.nvshmem.__path__[0]) / "lib" / "libnvshmem_host.so.3",
+        mode=ctypes.RTLD_GLOBAL,
+    )
     module = gen_nvshmem_module().build_and_load()
 
     return module
diff --git a/flashinfer/jit/env.py b/flashinfer/jit/env.py
index e1307e67d0..9d9158875c 100644
--- a/flashinfer/jit/env.py
+++ b/flashinfer/jit/env.py
@@ -66,16 +66,14 @@ def _get_workspace_dir_name() -> pathlib.Path:
 
 
 def get_nvshmem_include_dir():
-    # import nvidia.nvshmem
+    import nvidia.nvshmem
 
-    # path = pathlib.Path(nvidia.nvshmem.__path__[0]) / "include"
-    # return path
-    return "/opt/nvshmem-3.2.5/include"
+    path = pathlib.Path(nvidia.nvshmem.__path__[0]) / "include"
+    return path
 
 
 def get_nvshmem_lib_dir():
-    # import nvidia.nvshmem
+    import nvidia.nvshmem
 
-    # path = pathlib.Path(nvidia.nvshmem.__path__[0]) / "lib"
-    # return path
-    return "/opt/nvshmem-3.2.5/lib"
+    path = pathlib.Path(nvidia.nvshmem.__path__[0]) / "lib"
+    return path

From 4835f36b4c4758343e4ac36ed65239526dd8e4f2 Mon Sep 17 00:00:00 2001
From: Zihao Ye
Date: Thu, 19 Jun 2025 20:55:34 +0000
Subject: [PATCH 4/5] upd

---
 flashinfer/jit/cpp_ext.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/flashinfer/jit/cpp_ext.py b/flashinfer/jit/cpp_ext.py
index e9d9a6eeb7..d650ef3fba 100644
--- a/flashinfer/jit/cpp_ext.py
+++ b/flashinfer/jit/cpp_ext.py
@@ -161,7 +161,6 @@ def generate_ninja_build_for_op(
         )
 
     objects = []
-    cuda_objects = []
     for source in sources:
         is_cuda = source.suffix == ".cu"
         object_suffix = ".cuda.o" if is_cuda else ".o"
@@ -169,8 +168,6 @@ def generate_ninja_build_for_op(
         obj_name = source.with_suffix(object_suffix).name
         obj = f"$name/{obj_name}"
         objects.append(obj)
-        if is_cuda:
-            cuda_objects.append(obj)
         lines.append(f"build {obj}: {cmd} {source.resolve()}")
 
     lines.append("")

From 0cd00bba1ff1550c0d189d1c2e8b63f99578aec5 Mon Sep 17 00:00:00 2001
From: Amir Samani
Date: Mon, 23 Jun 2025 16:37:09 -0700
Subject: [PATCH 5/5] add cute multimem allreduce kernel

---
 csrc/nvshmem_binding.cu          |   9 +
 flashinfer/comm.py               | 281 ++++++++++++++++++++++++++++++-
 flashinfer/cute_utils.py         | 152 +++++++++++++++++
 setup.py                         |   2 +-
 tests/test_multimem_allreduce.py | 112 ++++++++++++
 5 files changed, 554 insertions(+), 2 deletions(-)
 create mode 100644 flashinfer/cute_utils.py
 create mode 100644 tests/test_multimem_allreduce.py

diff --git a/csrc/nvshmem_binding.cu b/csrc/nvshmem_binding.cu
index c50e5bbc65..e455c4d498 100644
--- a/csrc/nvshmem_binding.cu
+++ b/csrc/nvshmem_binding.cu
@@ -89,6 +89,14 @@ at::Tensor malloc_tensor(const std::vector<int64_t>& shape, c10::ScalarType dtyp
       at::TensorOptions().dtype(dtype).device(device));
 }
 
+int64_t multicast_ptr(at::Tensor tensor) {
+  void* mc_ptr = nvshmemx_mc_ptr(NVSHMEM_TEAM_WORLD, (void*)tensor.data_ptr());
+  if (mc_ptr == nullptr) {
+    AT_ERROR("nvshmemx_mc_ptr failed.");
+  }
+  return reinterpret_cast<int64_t>(mc_ptr);
+}
+
 void barrier_all() { nvshmem_barrier_all(); }
 
 void barrier_all_on_current_stream() {
@@ -121,6 +129,7 @@ TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
   m.def("nvshmem_alltoall(Tensor! dest, Tensor src) -> ()");
   m.impl("nvshmem_alltoall", c10::kCUDA, &alltoall);
   m.impl("nvshmem_alltoall", c10::kMeta, &fake_alltoall);
+  m.def("nvshmem_multicast_ptr", &multicast_ptr);
 };
 
 } // namespace
diff --git a/flashinfer/comm.py b/flashinfer/comm.py
index fb0e695e06..c30caa090c 100644
--- a/flashinfer/comm.py
+++ b/flashinfer/comm.py
@@ -18,12 +18,26 @@
 import functools
 from dataclasses import dataclass
 from types import SimpleNamespace
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Sequence
 
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
 
+import cutlass
+import cutlass.utils as utils
+import cutlass.utils.blackwell_helpers as sm100_utils
+import cutlass.cute as cute
+import cutlass.torch as cutlass_torch
+import cutlass._mlir.ir as ir
+from cutlass.cute.runtime import from_dlpack
+from cutlass.cute.typing import Pointer, Int32
+from cutlass.cutlass_dsl import dsl_user_op
+from cutlass.base_dsl.dsl import extract_mlir_values
+from cutlass._mlir.dialects import scf
+from cutlass._mlir.dialects import llvm
+
+from .cute_utils import as_tensor, signal_multimem, wait_loop, multimem_ld_reduce, multimem_st
 from .jit import JitSpec
 from .jit import env as jit_env
 from .jit import gen_jit_spec, sm100a_nvcc_flags
@@ -1098,3 +1112,268 @@ def trtllm_moe_allreduce_fusion(
         quant_out=quant_out,
         scale_out=scale_out,
     )
+
+
+class MultimemAllReduce:
+
+    def __init__(
+        self,
+        local_rank: int,
+        world_size: int,
+        max_buffer_elements: int,
+        dtype: torch.dtype,
+        device: torch.device,
+        group: Optional[ProcessGroup] = None,
+        should_init: bool = True,
+    ):
+        self.local_rank = local_rank
+        self.world_size = world_size
+        self.dtype = dtype
+        self.device = device
+        self.max_buffer_elements = max_buffer_elements
+        self.group = group
+        self.nvshmem_module = get_nvshmem_module()
+
+        # TODO(asamani): make this configurable
+        tensor_dtype = cutlass.Float16
+        self.num_warps_per_block = 8  # for fp16 and bf16
+        self.num_ctas = 128  # max is the number of sms
+        tensor_size = max_buffer_elements
+        numel_per_thread = 8
+
+        max_work_per_iter = (
+            numel_per_thread * self.num_warps_per_block * 32 * self.num_ctas * world_size
+        )
+        self.num_iter_per_thread = tensor_size // max_work_per_iter
+
+        # TODO(asamani): check if this configuration works with impl
+        self.should_init = should_init
+        if self.should_init:
+            self.init_nvshmem()
+
+        # assert PE and world size match
+        my_pe = self.nvshmem_module.nvshmem_my_pe()
+        n_pes = self.nvshmem_module.nvshmem_n_pes()
+        if my_pe != local_rank:
+            print(f"WARNING: Rank {local_rank}: PE mismatch! Expected PE {local_rank}, got PE {my_pe}", flush=True)
+        if n_pes != world_size:
+            print(f"WARNING: Rank {local_rank}: World size mismatch! Expected {world_size}, got {n_pes}", flush=True)
+
+        # allocate memory in nvshmem symm heap
+        # self.symm_buffer_input = self.nvshmem_module.nvshmem_malloc(
+        #     [max_buffer_elements],
+        #     self.dtype,
+        #     self.device,
+        # )
+        # self.symm_buffer_output = self.nvshmem_module.nvshmem_malloc(
+        #     [max_buffer_elements],
+        #     self.dtype,
+        #     self.device,
+        # )
+        self.cute_tensor_barrier_start, self.nvshmem_tensor_barrier_start, self.tensor_mc_memref_barrier_start = self.create_barrier_flags([self.num_ctas], cutlass.Int32, device)
+
+        self.cute_tensor_barrier_end, self.nvshmem_tensor_barrier_end, self.tensor_mc_memref_barrier_end = self.create_barrier_flags([self.num_ctas], cutlass.Int32, device)
+
+        self.input_tensor, self.input_torch, self.input_mc_memref = self.create_and_permute_tensor(
+            [tensor_size],
+            dtype=tensor_dtype,
+            device=device,
+            is_mc=True,
+            fill_value=None,
+        )
+        self.output_tensor, self.output_torch, self.output_mc_memref = self.create_and_permute_tensor(
+            [tensor_size],
+            dtype=tensor_dtype,
+            device=device,
+            is_mc=True,
+            fill_value=0,
+        )
+
+        torch.distributed.barrier(self.group)
+
+    def init_nvshmem(self):
+        if self.local_rank == 0:
+            uid = self.nvshmem_module.nvshmem_get_unique_id()
+        else:
+            uid = torch.zeros(self.nvshmem_module.nvshmem_unique_id_size(), dtype=torch.uint8, device="cpu")
+        torch.distributed.broadcast(uid, src=0)
+        torch.distributed.barrier(self.group)
+        init_status = self.nvshmem_module.nvshmem_init(uid, self.local_rank, self.world_size)
+        torch.cuda.synchronize()
+        if init_status != 0:
+            raise RuntimeError("Failed to initialize nvshmem")
+
+    @cute.kernel
+    def allreduce_kernel(
+        self,
+        gInput: cute.Tensor,
+        gOutput: cute.Tensor,
+        gInputMC: cute.Tensor,
+        gOutputMC: cute.Tensor,
+        total_elements: cutlass.Constexpr[int],
+        rank_id: cutlass.Constexpr[int],
+        red_elements: cutlass.Constexpr[int],
+        num_iter_per_thread: cutlass.Constexpr[int],
+        cute_tensor_barrier_start: cute.Tensor,
+        cute_tensor_barrier_end: cute.Tensor,
+        tensor_mc_memref_barrier_start: cute.Tensor,
+        tensor_mc_memref_barrier_end: cute.Tensor,
+    ):
+        tidx, _, _ = cute.arch.thread_idx()
+        bidx, _, _ = cute.arch.block_idx()
+        bdim, _, _ = cute.arch.block_dim()
+        gdim, _, _ = cute.arch.grid_dim()
+
+        thread_idx = bidx * bdim + tidx
+        # sync before start
+        # cute.arch.sync_warp()
+        if tidx == 0:  # run once per block
+            flag_start_mc = tensor_mc_memref_barrier_start.iterator + bidx
+            flag_start_uc = cute_tensor_barrier_start.iterator + bidx
+            signal_multimem(flag_start_mc, is_relaxed=True)
+            wait_loop(flag_start_uc, Int32(self.world_size), is_relaxed=True)
+        # reduction loop
+        cute.arch.sync_threads()
+        # cute.arch.sync_warp()
+        for i in cutlass.range_dynamic(num_iter_per_thread):
+            global_offset = rank_id * gdim * bdim * red_elements * num_iter_per_thread
+            device_offset = thread_idx * red_elements
+            iter_offset = i * red_elements * bdim * gdim
+            offset = global_offset + device_offset + iter_offset
+            elem_coords = (offset,)
+            idx = cute.crd2idx(elem_coords, gOutputMC.layout)
+            mc_ptr_inp = gInputMC.iterator + idx
+            mc_ptr = gOutputMC.iterator + idx
+            x, y, z, w = multimem_ld_reduce(mc_ptr_inp)
+            multimem_st(mc_ptr, x, y, z, w)
+
+        # sync before exiting the kernel
+        cute.arch.sync_threads()
+        # cute.arch.sync_warp()
+        if tidx == 0:  # run once per block
+            flag_end_mc = tensor_mc_memref_barrier_end.iterator + bidx
+            flag_end_uc = cute_tensor_barrier_end.iterator + bidx
+            signal_multimem(flag_end_mc, is_relaxed=False)
+            wait_loop(flag_end_uc, Int32(self.world_size), is_relaxed=False)
+
+    @cute.jit
+    def all_reduce_jitted(
+        self,
+        gInput: cute.Tensor,
+        gOutput: cute.Tensor,
+        gInputMC: cute.Tensor,
+        gOutputMC: cute.Tensor,
+        num_warps_per_block: cutlass.Constexpr[int],
+        num_ctas: cutlass.Constexpr[int],
+        num_iter_per_thread: cutlass.Constexpr[int],
+        rank_id: cutlass.Constexpr[int],
+        num_ranks: cutlass.Constexpr[int],
+        cute_tensor_barrier_start: cute.Tensor,
+        cute_tensor_barrier_end: cute.Tensor,
+        tensor_mc_memref_barrier_start: cute.Tensor,
+        tensor_mc_memref_barrier_end: cute.Tensor,
+    ):
+        total_elements = cute.size(gInput)
+        numel_per_thread = 8  # for fp16 and bf16
+        kernel = self.allreduce_kernel(
+            gInput,
+            gOutput,
+            gInputMC,
+            gOutputMC,
+            total_elements,
+            rank_id,
+            numel_per_thread,
+            num_iter_per_thread,
+            cute_tensor_barrier_start,
+            cute_tensor_barrier_end,
+            tensor_mc_memref_barrier_start,
+            tensor_mc_memref_barrier_end,
+        )
+        kernel.launch(grid=(num_ctas, 1, 1), block=(num_warps_per_block * 32, 1, 1))
+
+    def all_reduce(self, inp: torch.Tensor, out: torch.Tensor) -> None:
+        numel = inp.numel()
+        input_buffer = self.input_torch.narrow(0, 0, numel)
+        output_buffer = self.output_torch.narrow(0, 0, numel)
+        input_buffer.copy_(inp)
+        # TODO(asamani): optimize and only perform all reduce for the current data
+        self.all_reduce_jitted(
+            self.input_tensor,
+            self.output_tensor,
+            self.input_mc_memref,
+            self.output_mc_memref,
+            self.num_warps_per_block,
+            self.num_ctas,
+            self.num_iter_per_thread,
+            self.local_rank,
+            self.world_size,
+            self.cute_tensor_barrier_start,
+            self.cute_tensor_barrier_end,
+            self.tensor_mc_memref_barrier_start,
+            self.tensor_mc_memref_barrier_end,
+        )
+        out.copy_(output_buffer)
+
+    def create_barrier_flags(
+        self,
+        shape: Sequence[int],
+        dtype: torch.dtype,
+        device: torch.device
+    ):
+        torch_dtype = (
+            cutlass_torch.dtype(dtype)
+            if dtype not in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}
+            else torch.uint8
+        )
+        nvshmem_tensor = self.nvshmem_module.nvshmem_malloc(shape, torch_dtype, device)
+        nvshmem_tensor.fill_(0)
+        tensor_mc = self.nvshmem_module.nvshmem_multicast_ptr(nvshmem_tensor)
+        cute_tensor = from_dlpack(nvshmem_tensor)
+        cute_tensor.element_type = dtype
+        cute_tensor.mark_layout_dynamic()
+        tensor_mc_memref = from_dlpack(
+            as_tensor(tensor_mc, nvshmem_tensor.shape, nvshmem_tensor.dtype),
+        )
+        # NOTE(asamani): check if should align to 16 bytes?
+        tensor_mc_memref.mark_layout_dynamic()
+        return cute_tensor, nvshmem_tensor, tensor_mc_memref
+
+    def create_and_permute_tensor(
+        self,
+        shape: Sequence[int],
+        dtype: torch.dtype,
+        device: torch.device,
+        is_mc: bool = False,
+        fill_value: Optional[float] = None,
+    ):
+        torch_dtype = (
+            cutlass_torch.dtype(dtype)
+            if dtype not in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}
+            else torch.uint8
+        )
+        torch_tensor = self.nvshmem_module.nvshmem_malloc(shape, torch_dtype, device)
+        if fill_value is not None:
+            torch_tensor.fill_(fill_value)
+        else:
+            torch_tensor.copy_(torch.randint(1, 16, shape, dtype=torch_dtype, device=device))
+        tensor_mc = None
+        if is_mc:
+            mc_ptr = self.nvshmem_module.nvshmem_multicast_ptr(torch_tensor)
+            tensor_mc = from_dlpack(as_tensor(mc_ptr,
+                                              torch_tensor.shape,
+                                              torch_tensor.dtype),
+                                    assumed_align=16)
+            tensor_mc = tensor_mc.mark_layout_dynamic()
+        cute_tensor = from_dlpack(torch_tensor, assumed_align=16)
+        cute_tensor.element_type = dtype
+        cute_tensor.mark_layout_dynamic()
+        # TODO(asamani): if fp8 handle differently
+        return cute_tensor, torch_tensor, tensor_mc
+
+    def shutdown(self):
+        del self.input_tensor, self.input_torch, self.input_mc_memref
+        del self.output_tensor, self.output_torch, self.output_mc_memref
+        torch.distributed.barrier(self.group)
+        torch.cuda.synchronize()
+        if self.should_init:
+            self.nvshmem_module.nvshmem_finalize()
\ No newline at end of file
diff --git a/flashinfer/cute_utils.py b/flashinfer/cute_utils.py
new file mode 100644
index 0000000000..3017068185
--- /dev/null
+++ b/flashinfer/cute_utils.py
@@ -0,0 +1,152 @@
+"""
+Copyright (c) 2023 by FlashInfer team.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+""" + +import os +from functools import partial +import subprocess +import ctypes +from math import prod + +import torch +import torch.distributed as dist + +import cutlass +import cutlass.utils as utils +import cutlass.utils.blackwell_helpers as sm100_utils +import cutlass.cute as cute +import cutlass.torch as cutlass_torch +import cutlass._mlir.ir as ir +from cutlass.cute.runtime import from_dlpack +from cutlass.cute.typing import Pointer, Int32 +from cutlass.cutlass_dsl import dsl_user_op +from cutlass.base_dsl.dsl import extract_mlir_values +from cutlass._mlir.dialects import scf +from cutlass._mlir.dialects import llvm +from typing import Sequence + +def as_tensor(pointer, shape, torch_type): + if torch_type.itemsize == 1: + cytype = ctypes.c_uint8 + elif torch_type.itemsize == 2: + cytype = ctypes.c_uint16 + elif torch_type.itemsize == 4: + cytype = ctypes.c_uint32 + elif torch_type.itemsize == 8: + cytype = ctypes.c_uint64 + else: + raise ValueError(f'Unsupported torch dtype: {torch_type}') + cpointer = ctypes.cast(pointer, ctypes.POINTER(cytype)) + arr = (cpointer._type_ * prod(shape)).from_address( + ctypes.addressof(cpointer.contents)) + return torch.frombuffer(arr, dtype=torch_type).view(*shape) + +@dsl_user_op +def multimem_ld_reduce( + mc_ptr: Pointer, + *, + loc=None, + ip=None, +): + # ld reduce 8x f16 elts + mc_ptr_int = mc_ptr.toint(loc=loc, ip=ip).ir_value() + i32 = ir.IntegerType.get_signless(32) + return_struct = llvm.inline_asm( + ir.Type.parse("!llvm.struct<(i32,i32,i32,i32)>"), + [mc_ptr_int], + "multimem.ld_reduce.relaxed.sys.global.add.acc::f32.v4.f16x2 {$0, $1, $2, $3}, [$4];", + "=r,=r,=r,=r,l", + has_side_effects=True, + asm_dialect=0, + ) + return_regs = [ + llvm.extractvalue(i32, return_struct, [i]) for i in range(4) + ] + return return_regs[0], return_regs[1], return_regs[2], return_regs[3] + + +@dsl_user_op +def multimem_st( + mc_ptr: Pointer, + x: Int32, + y: Int32, + z: Int32, + w: Int32, + *, + loc=None, + ip=None, +): + # st 8x f16 elts + mc_ptr_int = mc_ptr.toint(loc=loc, ip=ip).ir_value() + i32 = ir.IntegerType.get_signless(32) + llvm.inline_asm( + i32, + [mc_ptr_int, x, y, z, w], + "multimem.st.relaxed.sys.global.v4.f32 [$1], {$2, $3, $4, $5};", + "=r,l,r,r,r,r", + has_side_effects=True, + asm_dialect=0, + ) + +@dsl_user_op +def signal_multimem( + flag_mc, + is_relaxed=False, + *, + loc=None, + ip=None, +): + mode = "relaxed" if is_relaxed else "release" + flag_ptr_int = flag_mc.toint().ir_value() + llvm.inline_asm( + None, + [flag_ptr_int], + f""" + {{ + multimem.red.{mode}.sys.global.add.u32 [$0], 1; + fence.proxy.alias; + }}""", + "l", + has_side_effects=True, + asm_dialect=0, + ) + +@dsl_user_op +def wait_loop( + flag, + num_ranks, + is_relaxed=False, + *, + loc=None, + ip=None, +): + mode = "relaxed" if is_relaxed else "acquire" + flag_ptr_int = flag.toint().ir_value() + llvm.inline_asm( + None, + [flag_ptr_int, num_ranks.ir_value()], + f""" + {{ + .reg .u32 %tmp32_<1>; + .reg .pred %p<1>; + wait_signal: + atom.global.sys.{mode}.cas.b32 %tmp32_0, [$0], $1, 0; + setp.eq.u32 %p0, %tmp32_0, $1; + @!%p0 bra wait_signal; + }}""", + "l,r", + has_side_effects=True, + asm_dialect=0, + ) diff --git a/setup.py b/setup.py index 87389c40e2..4634e49920 100644 --- a/setup.py +++ b/setup.py @@ -51,7 +51,7 @@ def generate_build_meta(aot_build_meta: dict) -> None: ext_modules = [] cmdclass = {} -install_requires = ["numpy", "torch", "ninja", "requests", "nvidia-nvshmem-cu12"] +install_requires = ["numpy", "torch", "ninja", "requests", "nvidia-nvshmem-cu12", 
"nvidia-cutlass-dsl"] generate_build_meta({}) if enable_aot: diff --git a/tests/test_multimem_allreduce.py b/tests/test_multimem_allreduce.py new file mode 100644 index 0000000000..f6aec898e0 --- /dev/null +++ b/tests/test_multimem_allreduce.py @@ -0,0 +1,112 @@ +import logging +import multiprocessing as mp +import os +import socket +from typing import Any + +import pytest +import torch +import torch.distributed as dist + +import flashinfer.comm as comm + +logger = logging.getLogger(__name__) + + +def _run_correctness_worker(world_size, rank, distributed_init_port): + assert rank >= 0 + torch.cuda.set_device(rank) + device = torch.device("cuda", rank) + distributed_init_method = f"tcp://localhost:{distributed_init_port}" + dist.init_process_group( + backend="cpu:gloo,cuda:nccl", + rank=rank, + world_size=world_size, + device_id=device, + init_method=distributed_init_method, + ) + group = dist.group.WORLD + num_ranks = torch.distributed.get_world_size() + rank_id = torch.distributed.get_rank() + + batch_sizes = [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096] + max_batch_size = 4096 + hidden_dim = 8192 + test_loop = 10 + tensor_dtype = torch.bfloat16 + #tensor_dtype = torch.float16 + mnnvl_allreduce = comm.MultimemAllReduce( + rank_id, + num_ranks, + max_batch_size*hidden_dim, + tensor_dtype, + device, + group, + ) + + try: + for batch_size in batch_sizes: + for _ in range(test_loop): + tensor_size = batch_size * hidden_dim + inp1 = torch.full([tensor_size], rank_id, dtype=tensor_dtype, device=device) + inp1_ref = inp1.clone() + out1 = torch.empty_like(inp1) + mnnvl_allreduce.all_reduce(inp1, out1) + torch.distributed.all_reduce(inp1_ref, group=group) + torch.cuda.synchronize() + torch.testing.assert_close(out1, inp1_ref) + torch.distributed.barrier(group) + except Exception as e: + print(f"Rank {rank_id}: Exception during test: {e}") + raise + finally: + torch.distributed.barrier(group) + mnnvl_allreduce.shutdown() + torch.distributed.destroy_process_group(group) + + +def get_open_port() -> int: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + except OSError: + with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s: + s.bind(("::1", 0)) + return s.getsockname()[1] + + +def multi_process_parallel( + world_size: int, test_target: Any, target_args: tuple = () +) -> None: + mp.set_start_method("spawn", force=True) + + procs = [] + distributed_init_port = get_open_port() + for i in range(world_size): + proc_args = (world_size, i, distributed_init_port) + target_args + proc = mp.Process(target=test_target, args=proc_args, name=f"Worker-{i}") + proc.start() + procs.append(proc) + + for i in range(world_size): + procs[i].join() + assert ( + procs[i].exitcode == 0 + ), f"Process {i} failed with exit code {procs[i].exitcode}" + + +@pytest.mark.parametrize("world_size", [8]) +def test_multimem_allreduce(world_size): + available_gpus = torch.cuda.device_count() + if world_size > available_gpus: + raise ValueError( + f"world_size {world_size} is greater than available_gpus {available_gpus}" + ) + print(f"Running test for world_size={world_size}") + multi_process_parallel( + world_size, + _run_correctness_worker, + target_args=(), + ) + print(f"MNNVL allreduce tp = {world_size}: OK") \ No newline at end of file