From b26c69de1bc5deac7d43644dcfa6cfa0687ea33b Mon Sep 17 00:00:00 2001 From: Shu Wang Date: Thu, 6 Nov 2025 19:03:57 +0000 Subject: [PATCH 1/2] Add custom communicator for trtllm_mnnvl_ar Upd --- flashinfer/comm/mnnvl.py | 12 +- flashinfer/comm/trtllm_mnnvl_ar.py | 13 +- ...test_trtllm_mnnvl_allreduce_custom_comm.py | 336 ++++++++++++++++++ 3 files changed, 351 insertions(+), 10 deletions(-) create mode 100644 tests/comm/test_trtllm_mnnvl_allreduce_custom_comm.py diff --git a/flashinfer/comm/mnnvl.py b/flashinfer/comm/mnnvl.py index 12aec978ec..ad6fb9dfd9 100644 --- a/flashinfer/comm/mnnvl.py +++ b/flashinfer/comm/mnnvl.py @@ -547,7 +547,6 @@ def supports_mnnvl() -> bool: class McastDeviceMemory: """Python port of McastDeviceMemory from TensorRT-LLM""" - def __init__( self, buf_size: int, @@ -555,6 +554,7 @@ def __init__( group_rank: int, device_idx: int, is_multi_node: bool = True, + comm: Optional[CommBackend] = None, ): cu_device = checkCudaErrors(cuda.cuDeviceGet(device_idx)) @@ -631,7 +631,7 @@ def __init__( "[McastDeviceMemory] Device does not support fabric handle." ) - self._alloc_mn_mcast_mem(buf_size) + self._alloc_mn_mcast_mem(buf_size, comm) else: # For single-node NVLS, would need to implement _alloc_nvls_mcast_mem raise NotImplementedError("Single-node NVLS allocation not implemented yet") @@ -753,7 +753,7 @@ def get_world_size(self) -> int: """Get the total number of devices in the group""" return self.group_size - def _alloc_mn_mcast_mem(self, buf_size: int): + def _alloc_mn_mcast_mem(self, buf_size: int, comm: Any=MpiComm()): """Allocate multi-node multicast memory using MNNVL""" # Verify CUDA context @@ -767,9 +767,6 @@ def _alloc_mn_mcast_mem(self, buf_size: int): except Exception as e: print(f"Error checking CUDA context: {e}") - # Get MPI communicator - comm = MpiComm() - # Set up allocation properties handle_type = cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC @@ -969,6 +966,7 @@ def __init__( group_rank: int, device: torch.device, mn_nvlink: bool = True, + comm: Optional[CommBackend] = None, ): """ Constructor for McastGpuBuffer. @@ -981,7 +979,7 @@ def __init__( mn_nvlink: Flag indicating if multi-node NVLink is used """ self.mcast_device_memory = McastDeviceMemory( - buf_size, group_size, group_rank, device.index, mn_nvlink + buf_size, group_size, group_rank, device.index, mn_nvlink, comm ) self.buf_size = buf_size self.local_device = device diff --git a/flashinfer/comm/trtllm_mnnvl_ar.py b/flashinfer/comm/trtllm_mnnvl_ar.py index 76aedee260..ef36000b38 100644 --- a/flashinfer/comm/trtllm_mnnvl_ar.py +++ b/flashinfer/comm/trtllm_mnnvl_ar.py @@ -15,7 +15,7 @@ from ..jit import gen_trtllm_mnnvl_comm_module from ..utils import register_custom_op -from .mnnvl import McastGPUBuffer +from .mnnvl import (McastGPUBuffer, CommBackend) def mpi_barrier(): @@ -122,7 +122,9 @@ def trtllm_mnnvl_rmsnorm( def get_allreduce_mnnvl_workspace( - mapping: Mapping, dtype: torch.dtype, buffer_size_in_bytes: Optional[int] = None + mapping: Mapping, dtype: torch.dtype, + buffer_size_in_bytes: Optional[int] = None, + comm: Optional[CommBackend] = None, ) -> Tuple[McastGPUBuffer, torch.Tensor, int]: """Get workspace buffers needed for multi-node NVLink all-reduce operation. @@ -139,6 +141,7 @@ def get_allreduce_mnnvl_workspace( mapping: Tensor parallel mapping configuration containing rank info dtype: Data type of the tensors being reduced buffer_size_in_bytes: Optional buffer size. 
Practically, assign this to 3 * 2 * dtype.itemsize * hidden_dim * max_tokens + comm: Optional communication backend for multi-node synchronization Returns: Tuple containing: @@ -167,6 +170,7 @@ def get_allreduce_mnnvl_workspace( mapping.tp_rank, torch.device("cuda", mapping.local_rank), mapping.is_multi_node() or force_mn, + comm=comm, ) # Initialize the unicast buffer with -0.0 @@ -174,7 +178,10 @@ def get_allreduce_mnnvl_workspace( # CPU barrier since we assume this should not be called in cuda graph torch.cuda.synchronize() - mpi_barrier() + if comm: + comm.barrier() + else: + mpi_barrier() # This is a buffer to maintain the state of this allreduce Op # [Buffer_ptr, Clear_ptr, Buffer_size, num_tokens_prev, atomic access counter] diff --git a/tests/comm/test_trtllm_mnnvl_allreduce_custom_comm.py b/tests/comm/test_trtllm_mnnvl_allreduce_custom_comm.py new file mode 100644 index 0000000000..e664a12b6d --- /dev/null +++ b/tests/comm/test_trtllm_mnnvl_allreduce_custom_comm.py @@ -0,0 +1,336 @@ +# Check torch version: +from typing import Any, Tuple + +import multiprocessing as mp +import socket +import pytest +import torch +import torch.distributed as dist +from mpi4py import MPI # Added MPI import + +import flashinfer.comm.trtllm_mnnvl_ar as trtllm_mnnvl_ar +from flashinfer.comm.mapping import Mapping + +# Use flashinfer.norm.rmsnorm as reference implementation. +from flashinfer.norm import rmsnorm +from flashinfer.comm.mnnvl import CommBackend as CommBackend + +import pynvml + +pynvml.nvmlInit() + +class CustomCommunicator(CommBackend): + def __init__(self, group): + self._group = group + + def Get_rank(self) -> int: + return dist.get_rank(self._group) + + def Get_size(self) -> int: + return dist.get_world_size(self._group) + + def allgather(self, data: int | bytes): + device = f"cuda:{torch.cuda.current_device()}" + if isinstance(data, int): + local_tensor = torch.tensor([data], device=device, dtype=torch.int32) + world_size = self.Get_size() + gathered = [torch.zeros_like(local_tensor) for _ in range(world_size)] + + dist.all_gather(gathered, local_tensor, group=self._group) + return [int(x.item()) for x in gathered] + + elif isinstance(data, bytes): + local_tensor = torch.ByteTensor(list(data)).unsqueeze(0).to(device) + world_size = self.Get_size() + gathered = [data] * self.Get_size() + dist.all_gather_object(gathered, data, group=self._group) + return gathered + else: + raise TypeError(f"Unsupported type for allgather: {type(data)}") + + def bcast(self, data, root: int = 0): + """ + Broadcast a picklable Python object from `root` to all ranks. + Uses torch.distributed.broadcast_object_list under the hood. + + Returns the broadcasted object on every rank. + """ + obj_list = [data] + # broadcast_object_list mutates obj_list in-place + dist.broadcast_object_list(obj_list, src=root, group=self._group) + return obj_list[0] + + def barrier(self): + """ + Synchronize all ranks in this communicator. 
+ """ + dist.barrier(group=self._group) + + def Split(self, color: int, key: int) -> "CustomCommunicator": + return self + +def get_open_port() -> int: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + except OSError: + with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s: + s.bind(("::1", 0)) + return s.getsockname()[1] + +def multi_process_parallel( + world_size: int, dtype: torch.dtype, test_target: Any, target_args: tuple = () +) -> None: + mp.set_start_method("spawn", force=True) + + procs = [] + distributed_init_port = get_open_port() + for i in range(world_size): + proc_args = (world_size, i, dtype, distributed_init_port) + target_args + proc = mp.Process(target=test_target, args=proc_args, name=f"Worker-{i}") + proc.start() + procs.append(proc) + + for i in range(world_size): + procs[i].join() + assert procs[i].exitcode == 0, ( + f"Process {i} failed with exit code {procs[i].exitcode}" + ) + +@torch.inference_mode() +def row_linear_residual_norm_forward( + x: torch.Tensor, + residual: torch.Tensor, + norm_weight: torch.Tensor, + eps: float, + hidden_size: int, + dtype: torch.dtype, + mapping: Mapping, + reference_output: tuple[torch.Tensor, ...], + multicast_ptr: int, + buffer_ptrs_dev: int, + unicast_ptr: int, + max_num_elements_mnnvl: int, + buffer_flags_mnnvl: torch.Tensor, +): + x = x.cuda() + residual = residual.cuda() + norm_weight = norm_weight.cuda() + reference_output = tuple(t.cuda() for t in reference_output) + + tensor_parallel_size = mapping.tp_size + tensor_parallel_rank = mapping.tp_rank + + def func( + input, + residual, + norm_weight, + eps, + multicast_ptr, + buffer_ptrs_dev, + unicast_ptr, + max_num_elements_mnnvl, + ): + # For both fused and unfused cases: + shape = input.shape + + assert max_num_elements_mnnvl % hidden_size == 0 + + input = input.view(-1, shape[-1]) + + buffer_M = max_num_elements_mnnvl // hidden_size + output = torch.empty_like(input) + + trtllm_mnnvl_ar.trtllm_mnnvl_all_reduce( + input, + multicast_ptr, + buffer_ptrs_dev, + buffer_M, + buffer_flags_mnnvl, + tensor_parallel_size, + tensor_parallel_rank, + True, # wait_for_results + False, # launch_with_pdl + output, # Need to provide output tensor since we are writing them out. 
+ ) + return (output.view(shape),) + + output = func( + x.clone(), + residual.clone(), + norm_weight, + eps, + multicast_ptr, + buffer_ptrs_dev, + unicast_ptr, + max_num_elements_mnnvl, + ) + + assert output[0].shape == reference_output[0].shape + + if tensor_parallel_rank == 0: + print("output[0] (first 10 values):", output[0].flatten()[:10]) + print( + "reference_output[0] (first 10 values):", + reference_output[0].flatten()[:10], + ) + torch.testing.assert_close( + output[0], + reference_output[0], + rtol=0.05, + atol=0.15, + ) + +def _run_mnnvl_ar(world_size, rank, dtype, distributed_init_port, seq_len, hidden_size): + # Set CUDA device based on rank + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + distributed_init_method = f"tcp://localhost:{distributed_init_port}" + dist.init_process_group( + backend="nccl", + init_method=distributed_init_method, + rank=rank, + world_size=world_size, + ) + group = dist.group.WORLD + torch.cuda.set_device(rank) + comm = CustomCommunicator(group) + mapping = Mapping( + world_size=world_size, + rank=rank, + gpus_per_node=world_size, + tp_size=world_size, + ) + + if mapping.local_rank == 0: + print( + f"[Node {mapping.node_rank}] Running MNNVL AllReduce test with {world_size} ranks" + ) + print( + f"[Node {mapping.node_rank}] Rank {rank} using GPU {torch.cuda.current_device()}" + ) + + tensor_parallel_size = world_size + eps = 1e-5 + torch.manual_seed(42) + + # Track if this rank failed + rank_failed = False + failure_message = "" + + try: + # Get workspace buffers using MPI rank - allocate once per seq_lens list and reuse within the list + # This workspace is sized for the maximum expected sequence length and can be reused within each list + # Each parameterized list gets its own fresh workspace allocation + mcast_buffer_mnnvl, buffer_flags_mnnvl, max_num_elements_mnnvl = ( + trtllm_mnnvl_ar.get_allreduce_mnnvl_workspace(mapping, dtype, comm) + ) + + multicast_ptr = mcast_buffer_mnnvl.get_multicast_ptr() + buffer_ptrs_dev = mcast_buffer_mnnvl.get_buffer_ptrs_dev() + unicast_ptr = mcast_buffer_mnnvl.mcast_device_memory.get_unicast_ptr( + mapping.tp_rank + ) + + # Test each sequence length with the same workspace (reusing allocated buffers within this list) + if rank == 0: + print( + f"Testing seq_len={seq_len}, hidden_size={hidden_size}, dtype={dtype}" + ) + + # Generate test data (same on all ranks due to same seed) + x_full = torch.randn( + (tensor_parallel_size, seq_len, hidden_size), + dtype=dtype, + device=torch.device("cuda"), + ) + residual = torch.randn( + (seq_len, hidden_size), dtype=dtype, device=torch.device("cuda") + ) + norm_weight = torch.randn( + (hidden_size,), dtype=dtype, device=torch.device("cuda") + ) + + # Each rank gets its slice of the input + x = x_full[rank, :, :] + + # Compute reference output based on fusion mode + reference_output: Tuple[torch.Tensor, ...] 
= None + + # Non-fused case: Only AllReduce + allreduce_result = torch.sum(x_full, dim=0) # AllReduce result + reference_output = (allreduce_result,) + + # Run the test with the same workspace + row_linear_residual_norm_forward( + x, + residual, + norm_weight, + eps, + hidden_size, + dtype, + mapping, + reference_output, + multicast_ptr, + buffer_ptrs_dev, + unicast_ptr, + max_num_elements_mnnvl, + buffer_flags_mnnvl, + ) + + # Synchronize before next test + comm.barrier() + + print( + f"PASSED[rank={rank}]: seq_len={seq_len}, dtype={dtype}" + ) + + except Exception as e: + rank_failed = True + failure_message = f"FAILED[rank={rank}]: seq_lens={seq_len}, dtype={dtype} failed: {e}" + print(failure_message) + # Gather failure status from all ranks + all_failures = MPI.COMM_WORLD.allgather(rank_failed) + + # If any rank failed, fail the test + if any(all_failures): + failed_ranks = [i for i, failed in enumerate(all_failures) if failed] + if rank == 0: + print(f"Test failed on ranks: {failed_ranks}") + + # Fail the test on all ranks + pytest.fail(f"Test failed on ranks {failed_ranks}") + comm.barrier() + + finally: + # Ensure cleanup happens for this list's workspace + if "mcast_buffer_mnnvl" in locals(): + del mcast_buffer_mnnvl + + # Final synchronization and check for failures across all ranks + comm.barrier() + +"""Main test function that runs on each MPI rank""" +@pytest.mark.parametrize("world_size", [2, 4]) +def test_mnnvl_allreduce_custom_communicator( + monkeypatch, world_size, +): + monkeypatch.setenv("TRTLLM_FORCE_MNNVL_AR", "1") # force multi-node allreduce. + seq_len = 24 + dtype = torch.bfloat16 + hidden_size = 2048 + + available_gpus = torch.cuda.device_count() + if world_size > available_gpus: + raise ValueError( + f"world_size {world_size} is greater than available_gpus {available_gpus}" + ) + print(f"Running test for world_size={world_size}") + multi_process_parallel( + world_size, + dtype, + _run_mnnvl_ar, + target_args=(seq_len, hidden_size), + ) + print(f"custom mnnvl allreduce world_size = {world_size}: OK") From 602adfe6e934c90dad908a2d80cf9c9a25ed2feb Mon Sep 17 00:00:00 2001 From: "Shu Wang." Date: Fri, 14 Nov 2025 20:39:34 +0000 Subject: [PATCH 2/2] Upd --- flashinfer/comm/mnnvl.py | 30 ++++- flashinfer/comm/trtllm_mnnvl_ar.py | 17 +-- tests/comm/test_trtllm_mnnvl_allreduce.py | 11 +- ...test_trtllm_mnnvl_allreduce_custom_comm.py | 123 ++++-------------- 4 files changed, 66 insertions(+), 115 deletions(-) diff --git a/flashinfer/comm/mnnvl.py b/flashinfer/comm/mnnvl.py index ad6fb9dfd9..2d280a68e8 100644 --- a/flashinfer/comm/mnnvl.py +++ b/flashinfer/comm/mnnvl.py @@ -155,6 +155,9 @@ def Get_size(self) -> int: ... @abstractmethod def allgather(self, data: int) -> List[int]: ... + @abstractmethod + def barrier(self) -> None: ... + @abstractmethod def Split(self, color: int, key: int) -> "CommBackend": ... 
@@ -209,6 +212,9 @@ def Get_size(self) -> int: def allgather(self, data: int) -> List[int]: return self._mpicomm.allgather(data) + def barrier(self): + self._mpicomm.Barrier() + def Split(self, color: int, key: int) -> CommBackend: self._mpicomm = self._mpicomm.Split(color, key) return MPIBackend() # Returns new adapter @@ -547,6 +553,7 @@ def supports_mnnvl() -> bool: class McastDeviceMemory: """Python port of McastDeviceMemory from TensorRT-LLM""" + def __init__( self, buf_size: int, @@ -554,7 +561,7 @@ def __init__( group_rank: int, device_idx: int, is_multi_node: bool = True, - comm: Optional[CommBackend] = None, + comm_backend_for_handle_transfer: Optional[CommBackend] = None, ): cu_device = checkCudaErrors(cuda.cuDeviceGet(device_idx)) @@ -631,7 +638,7 @@ def __init__( "[McastDeviceMemory] Device does not support fabric handle." ) - self._alloc_mn_mcast_mem(buf_size, comm) + self._alloc_mn_mcast_mem(buf_size, comm_backend_for_handle_transfer) else: # For single-node NVLS, would need to implement _alloc_nvls_mcast_mem raise NotImplementedError("Single-node NVLS allocation not implemented yet") @@ -753,7 +760,9 @@ def get_world_size(self) -> int: """Get the total number of devices in the group""" return self.group_size - def _alloc_mn_mcast_mem(self, buf_size: int, comm: Any=MpiComm()): + def _alloc_mn_mcast_mem( + self, buf_size: int, comm_backend_for_handle_transfer: Any = None + ): """Allocate multi-node multicast memory using MNNVL""" # Verify CUDA context @@ -766,7 +775,10 @@ def _alloc_mn_mcast_mem(self, buf_size: int, comm: Any=MpiComm()): ) except Exception as e: print(f"Error checking CUDA context: {e}") - + if comm_backend_for_handle_transfer is None: + comm = MpiComm() + else: + comm = comm_backend_for_handle_transfer # Set up allocation properties handle_type = cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC @@ -966,7 +978,7 @@ def __init__( group_rank: int, device: torch.device, mn_nvlink: bool = True, - comm: Optional[CommBackend] = None, + comm_backend_for_handle_transfer: Optional[CommBackend] = None, ): """ Constructor for McastGpuBuffer. @@ -977,9 +989,15 @@ def __init__( group_rank: The rank of the local process within the group device: The CUDA device for buffer allocation mn_nvlink: Flag indicating if multi-node NVLink is used + comm_backend_for_handle_transfer: Communication backend for handle transfer """ self.mcast_device_memory = McastDeviceMemory( - buf_size, group_size, group_rank, device.index, mn_nvlink, comm + buf_size, + group_size, + group_rank, + device.index, + mn_nvlink, + comm_backend_for_handle_transfer, ) self.buf_size = buf_size self.local_device = device diff --git a/flashinfer/comm/trtllm_mnnvl_ar.py b/flashinfer/comm/trtllm_mnnvl_ar.py index ef36000b38..84a9c150de 100644 --- a/flashinfer/comm/trtllm_mnnvl_ar.py +++ b/flashinfer/comm/trtllm_mnnvl_ar.py @@ -15,7 +15,7 @@ from ..jit import gen_trtllm_mnnvl_comm_module from ..utils import register_custom_op -from .mnnvl import (McastGPUBuffer, CommBackend) +from .mnnvl import McastGPUBuffer, CommBackend def mpi_barrier(): @@ -122,9 +122,10 @@ def trtllm_mnnvl_rmsnorm( def get_allreduce_mnnvl_workspace( - mapping: Mapping, dtype: torch.dtype, + mapping: Mapping, + dtype: torch.dtype, + comm_backend_for_handle_transfer: Optional[CommBackend] = None, buffer_size_in_bytes: Optional[int] = None, - comm: Optional[CommBackend] = None, ) -> Tuple[McastGPUBuffer, torch.Tensor, int]: """Get workspace buffers needed for multi-node NVLink all-reduce operation. 
@@ -140,8 +141,8 @@ def get_allreduce_mnnvl_workspace( Args: mapping: Tensor parallel mapping configuration containing rank info dtype: Data type of the tensors being reduced - buffer_size_in_bytes: Optional buffer size. Practically, assign this to 3 * 2 * dtype.itemsize * hidden_dim * max_tokens comm: Optional communication backend for multi-node synchronization + buffer_size_in_bytes: Optional buffer size. Practically, assign this to 3 * 2 * dtype.itemsize * hidden_dim * max_tokens Returns: Tuple containing: @@ -170,7 +171,7 @@ def get_allreduce_mnnvl_workspace( mapping.tp_rank, torch.device("cuda", mapping.local_rank), mapping.is_multi_node() or force_mn, - comm=comm, + comm_backend_for_handle_transfer=comm_backend_for_handle_transfer, ) # Initialize the unicast buffer with -0.0 @@ -178,10 +179,10 @@ def get_allreduce_mnnvl_workspace( # CPU barrier since we assume this should not be called in cuda graph torch.cuda.synchronize() - if comm: - comm.barrier() - else: + if comm_backend_for_handle_transfer is None: mpi_barrier() + else: + comm_backend_for_handle_transfer.barrier() # This is a buffer to maintain the state of this allreduce Op # [Buffer_ptr, Clear_ptr, Buffer_size, num_tokens_prev, atomic access counter] diff --git a/tests/comm/test_trtllm_mnnvl_allreduce.py b/tests/comm/test_trtllm_mnnvl_allreduce.py index abb3795019..e7274c46f0 100644 --- a/tests/comm/test_trtllm_mnnvl_allreduce.py +++ b/tests/comm/test_trtllm_mnnvl_allreduce.py @@ -1,5 +1,5 @@ # Check torch version: -from typing import Tuple +from typing import Tuple, Optional import pytest import torch @@ -7,6 +7,7 @@ import flashinfer.comm.trtllm_mnnvl_ar as trtllm_mnnvl_ar from flashinfer.comm.mapping import Mapping +from flashinfer.comm.mnnvl import CommBackend, MpiComm # Use flashinfer.norm.rmsnorm as reference implementation. from flashinfer.norm import rmsnorm @@ -28,6 +29,7 @@ def row_linear_residual_norm_fusion_forward( unicast_ptr: int, max_num_elements_mnnvl: int, buffer_flags_mnnvl: torch.Tensor, + comm_backend_for_handle_transfer: Optional[CommBackend] = None, ): x = x.cuda() residual = residual.cuda() @@ -36,8 +38,11 @@ def row_linear_residual_norm_fusion_forward( tensor_parallel_size = mapping.tp_size tensor_parallel_rank = mapping.tp_rank - - MPI.COMM_WORLD.barrier() + if comm_backend_for_handle_transfer is None: + comm = MpiComm() + else: + comm = comm_backend_for_handle_transfer + comm.barrier() def func( input, diff --git a/tests/comm/test_trtllm_mnnvl_allreduce_custom_comm.py b/tests/comm/test_trtllm_mnnvl_allreduce_custom_comm.py index e664a12b6d..60933cf89b 100644 --- a/tests/comm/test_trtllm_mnnvl_allreduce_custom_comm.py +++ b/tests/comm/test_trtllm_mnnvl_allreduce_custom_comm.py @@ -6,19 +6,16 @@ import pytest import torch import torch.distributed as dist -from mpi4py import MPI # Added MPI import import flashinfer.comm.trtllm_mnnvl_ar as trtllm_mnnvl_ar from flashinfer.comm.mapping import Mapping - -# Use flashinfer.norm.rmsnorm as reference implementation. -from flashinfer.norm import rmsnorm from flashinfer.comm.mnnvl import CommBackend as CommBackend import pynvml pynvml.nvmlInit() + class CustomCommunicator(CommBackend): def __init__(self, group): self._group = group @@ -59,7 +56,7 @@ def bcast(self, data, root: int = 0): # broadcast_object_list mutates obj_list in-place dist.broadcast_object_list(obj_list, src=root, group=self._group) return obj_list[0] - + def barrier(self): """ Synchronize all ranks in this communicator. 
@@ -69,6 +66,7 @@ def barrier(self): def Split(self, color: int, key: int) -> "CustomCommunicator": return self + def get_open_port() -> int: try: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: @@ -78,7 +76,8 @@ def get_open_port() -> int: with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s: s.bind(("::1", 0)) return s.getsockname()[1] - + + def multi_process_parallel( world_size: int, dtype: torch.dtype, test_target: Any, target_args: tuple = () ) -> None: @@ -98,89 +97,6 @@ def multi_process_parallel( f"Process {i} failed with exit code {procs[i].exitcode}" ) -@torch.inference_mode() -def row_linear_residual_norm_forward( - x: torch.Tensor, - residual: torch.Tensor, - norm_weight: torch.Tensor, - eps: float, - hidden_size: int, - dtype: torch.dtype, - mapping: Mapping, - reference_output: tuple[torch.Tensor, ...], - multicast_ptr: int, - buffer_ptrs_dev: int, - unicast_ptr: int, - max_num_elements_mnnvl: int, - buffer_flags_mnnvl: torch.Tensor, -): - x = x.cuda() - residual = residual.cuda() - norm_weight = norm_weight.cuda() - reference_output = tuple(t.cuda() for t in reference_output) - - tensor_parallel_size = mapping.tp_size - tensor_parallel_rank = mapping.tp_rank - - def func( - input, - residual, - norm_weight, - eps, - multicast_ptr, - buffer_ptrs_dev, - unicast_ptr, - max_num_elements_mnnvl, - ): - # For both fused and unfused cases: - shape = input.shape - - assert max_num_elements_mnnvl % hidden_size == 0 - - input = input.view(-1, shape[-1]) - - buffer_M = max_num_elements_mnnvl // hidden_size - output = torch.empty_like(input) - - trtllm_mnnvl_ar.trtllm_mnnvl_all_reduce( - input, - multicast_ptr, - buffer_ptrs_dev, - buffer_M, - buffer_flags_mnnvl, - tensor_parallel_size, - tensor_parallel_rank, - True, # wait_for_results - False, # launch_with_pdl - output, # Need to provide output tensor since we are writing them out. 
- ) - return (output.view(shape),) - - output = func( - x.clone(), - residual.clone(), - norm_weight, - eps, - multicast_ptr, - buffer_ptrs_dev, - unicast_ptr, - max_num_elements_mnnvl, - ) - - assert output[0].shape == reference_output[0].shape - - if tensor_parallel_rank == 0: - print("output[0] (first 10 values):", output[0].flatten()[:10]) - print( - "reference_output[0] (first 10 values):", - reference_output[0].flatten()[:10], - ) - torch.testing.assert_close( - output[0], - reference_output[0], - rtol=0.05, - atol=0.15, - ) def _run_mnnvl_ar(world_size, rank, dtype, distributed_init_port, seq_len, hidden_size): # Set CUDA device based on rank @@ -223,8 +139,11 @@ def _run_mnnvl_ar(world_size, rank, dtype, distributed_init_port, seq_len, hidde # Get workspace buffers using MPI rank - allocate once per seq_lens list and reuse within the list # This workspace is sized for the maximum expected sequence length and can be reused within each list # Each parameterized list gets its own fresh workspace allocation + explicit_workspace_bytes = 3 * 2 * dtype.itemsize * hidden_size * seq_len mcast_buffer_mnnvl, buffer_flags_mnnvl, max_num_elements_mnnvl = ( - trtllm_mnnvl_ar.get_allreduce_mnnvl_workspace(mapping, dtype, comm) + trtllm_mnnvl_ar.get_allreduce_mnnvl_workspace( + mapping, dtype, comm, explicit_workspace_bytes + ) ) multicast_ptr = mcast_buffer_mnnvl.get_multicast_ptr() @@ -263,7 +182,9 @@ def _run_mnnvl_ar(world_size, rank, dtype, distributed_init_port, seq_len, hidde reference_output = (allreduce_result,) # Run the test with the same workspace - row_linear_residual_norm_forward( + from .test_trtllm_mnnvl_allreduce import row_linear_residual_norm_fusion_forward + + row_linear_residual_norm_fusion_forward( x, residual, norm_weight, @@ -271,27 +192,29 @@ def _run_mnnvl_ar(world_size, rank, dtype, distributed_init_port, seq_len, hidde hidden_size, dtype, mapping, + False, reference_output, multicast_ptr, buffer_ptrs_dev, unicast_ptr, max_num_elements_mnnvl, buffer_flags_mnnvl, + comm, ) # Synchronize before next test comm.barrier() - print( - f"PASSED[rank={rank}]: seq_len={seq_len}, dtype={dtype}" - ) + print(f"PASSED[rank={rank}]: seq_len={seq_len}, dtype={dtype}") except Exception as e: rank_failed = True - failure_message = f"FAILED[rank={rank}]: seq_lens={seq_len}, dtype={dtype} failed: {e}" + failure_message = ( + f"FAILED[rank={rank}]: seq_lens={seq_len}, dtype={dtype} failed: {e}" + ) print(failure_message) # Gather failure status from all ranks - all_failures = MPI.COMM_WORLD.allgather(rank_failed) + all_failures = comm.allgather(rank_failed) # If any rank failed, fail the test if any(all_failures): @@ -302,7 +225,7 @@ def _run_mnnvl_ar(world_size, rank, dtype, distributed_init_port, seq_len, hidde # Fail the test on all ranks pytest.fail(f"Test failed on ranks {failed_ranks}") comm.barrier() - + finally: # Ensure cleanup happens for this list's workspace if "mcast_buffer_mnnvl" in locals(): @@ -311,10 +234,14 @@ def _run_mnnvl_ar(world_size, rank, dtype, distributed_init_port, seq_len, hidde # Final synchronization and check for failures across all ranks comm.barrier() + """Main test function that runs on each MPI rank""" + + @pytest.mark.parametrize("world_size", [2, 4]) def test_mnnvl_allreduce_custom_communicator( - monkeypatch, world_size, + monkeypatch, + world_size, ): monkeypatch.setenv("TRTLLM_FORCE_MNNVL_AR", "1") # force multi-node allreduce. seq_len = 24