diff --git a/flashinfer/comm/mnnvl.py b/flashinfer/comm/mnnvl.py
index 12aec978ec..2d280a68e8 100644
--- a/flashinfer/comm/mnnvl.py
+++ b/flashinfer/comm/mnnvl.py
@@ -155,6 +155,9 @@ def Get_size(self) -> int: ...
     @abstractmethod
     def allgather(self, data: int) -> List[int]: ...
 
+    @abstractmethod
+    def barrier(self) -> None: ...
+
     @abstractmethod
     def Split(self, color: int, key: int) -> "CommBackend": ...
 
@@ -209,6 +212,9 @@ def Get_size(self) -> int:
     def allgather(self, data: int) -> List[int]:
         return self._mpicomm.allgather(data)
 
+    def barrier(self) -> None:
+        self._mpicomm.Barrier()
+
     def Split(self, color: int, key: int) -> CommBackend:
         self._mpicomm = self._mpicomm.Split(color, key)
         return MPIBackend()  # Returns new adapter
@@ -555,6 +561,7 @@ def __init__(
         group_rank: int,
         device_idx: int,
         is_multi_node: bool = True,
+        comm_backend_for_handle_transfer: Optional[CommBackend] = None,
     ):
 
         cu_device = checkCudaErrors(cuda.cuDeviceGet(device_idx))
@@ -631,7 +638,7 @@ def __init__(
                     "[McastDeviceMemory] Device does not support fabric handle."
                 )
 
-            self._alloc_mn_mcast_mem(buf_size)
+            self._alloc_mn_mcast_mem(buf_size, comm_backend_for_handle_transfer)
         else:
             # For single-node NVLS, would need to implement _alloc_nvls_mcast_mem
             raise NotImplementedError("Single-node NVLS allocation not implemented yet")
@@ -753,7 +760,9 @@ def get_world_size(self) -> int:
         """Get the total number of devices in the group"""
         return self.group_size
 
-    def _alloc_mn_mcast_mem(self, buf_size: int):
+    def _alloc_mn_mcast_mem(
+        self, buf_size: int, comm_backend_for_handle_transfer: Any = None
+    ):
         """Allocate multi-node multicast memory using MNNVL"""
 
         # Verify CUDA context
@@ -766,10 +775,10 @@ def _alloc_mn_mcast_mem(self, buf_size: int):
             )
         except Exception as e:
             print(f"Error checking CUDA context: {e}")
-
-        # Get MPI communicator
-        comm = MpiComm()
-
+        if comm_backend_for_handle_transfer is None:
+            comm = MpiComm()
+        else:
+            comm = comm_backend_for_handle_transfer
 
         # Set up allocation properties
         handle_type = cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC
@@ -969,6 +978,7 @@ def __init__(
         group_rank: int,
         device: torch.device,
         mn_nvlink: bool = True,
+        comm_backend_for_handle_transfer: Optional[CommBackend] = None,
     ):
         """
        Constructor for McastGpuBuffer.
@@ -979,9 +989,15 @@ def __init__(
             group_rank: The rank of the local process within the group
             device: The CUDA device for buffer allocation
             mn_nvlink: Flag indicating if multi-node NVLink is used
+            comm_backend_for_handle_transfer: Communication backend for handle transfer
         """
         self.mcast_device_memory = McastDeviceMemory(
-            buf_size, group_size, group_rank, device.index, mn_nvlink
+            buf_size,
+            group_size,
+            group_rank,
+            device.index,
+            mn_nvlink,
+            comm_backend_for_handle_transfer,
         )
         self.buf_size = buf_size
         self.local_device = device
diff --git a/flashinfer/comm/trtllm_mnnvl_ar.py b/flashinfer/comm/trtllm_mnnvl_ar.py
index 76aedee260..84a9c150de 100644
--- a/flashinfer/comm/trtllm_mnnvl_ar.py
+++ b/flashinfer/comm/trtllm_mnnvl_ar.py
@@ -15,7 +15,7 @@
 from ..jit import gen_trtllm_mnnvl_comm_module
 from ..utils import register_custom_op
 
-from .mnnvl import McastGPUBuffer
+from .mnnvl import McastGPUBuffer, CommBackend
 
 
 def mpi_barrier():
@@ -122,7 +122,10 @@ def trtllm_mnnvl_rmsnorm(
 
 
 def get_allreduce_mnnvl_workspace(
-    mapping: Mapping, dtype: torch.dtype, buffer_size_in_bytes: Optional[int] = None
+    mapping: Mapping,
+    dtype: torch.dtype,
+    comm_backend_for_handle_transfer: Optional[CommBackend] = None,
+    buffer_size_in_bytes: Optional[int] = None,
 ) -> Tuple[McastGPUBuffer, torch.Tensor, int]:
     """Get workspace buffers needed for multi-node NVLink all-reduce operation.
 
@@ -138,6 +141,7 @@
     Args:
         mapping: Tensor parallel mapping configuration containing rank info
         dtype: Data type of the tensors being reduced
+        comm_backend_for_handle_transfer: Optional communication backend for handle transfer and synchronization
         buffer_size_in_bytes: Optional buffer size. Practically, assign this to 3 * 2 * dtype.itemsize * hidden_dim * max_tokens
 
     Returns:
@@ -167,6 +171,7 @@
         mapping.tp_rank,
         torch.device("cuda", mapping.local_rank),
         mapping.is_multi_node() or force_mn,
+        comm_backend_for_handle_transfer=comm_backend_for_handle_transfer,
     )
 
     # Initialize the unicast buffer with -0.0
@@ -174,7 +179,10 @@ def get_allreduce_mnnvl_workspace(
 
     # CPU barrier since we assume this should not be called in cuda graph
     torch.cuda.synchronize()
-    mpi_barrier()
+    if comm_backend_for_handle_transfer is None:
+        mpi_barrier()
+    else:
+        comm_backend_for_handle_transfer.barrier()
 
     # This is a buffer to maintain the state of this allreduce Op
     # [Buffer_ptr, Clear_ptr, Buffer_size, num_tokens_prev, atomic access counter]
diff --git a/tests/comm/test_trtllm_mnnvl_allreduce.py b/tests/comm/test_trtllm_mnnvl_allreduce.py
index abb3795019..e7274c46f0 100644
--- a/tests/comm/test_trtllm_mnnvl_allreduce.py
+++ b/tests/comm/test_trtllm_mnnvl_allreduce.py
@@ -1,5 +1,5 @@
 # Check torch version:
-from typing import Tuple
+from typing import Tuple, Optional
 
 import pytest
 import torch
@@ -7,6 +7,7 @@
 import flashinfer.comm.trtllm_mnnvl_ar as trtllm_mnnvl_ar
 from flashinfer.comm.mapping import Mapping
+from flashinfer.comm.mnnvl import CommBackend, MpiComm
 
 # Use flashinfer.norm.rmsnorm as reference implementation.
 from flashinfer.norm import rmsnorm
 
@@ -28,6 +29,7 @@ def row_linear_residual_norm_fusion_forward(
     unicast_ptr: int,
     max_num_elements_mnnvl: int,
     buffer_flags_mnnvl: torch.Tensor,
+    comm_backend_for_handle_transfer: Optional[CommBackend] = None,
 ):
     x = x.cuda()
     residual = residual.cuda()
@@ -36,8 +38,11 @@ def row_linear_residual_norm_fusion_forward(
 
     tensor_parallel_size = mapping.tp_size
     tensor_parallel_rank = mapping.tp_rank
-
-    MPI.COMM_WORLD.barrier()
+    if comm_backend_for_handle_transfer is None:
+        comm = MpiComm()
+    else:
+        comm = comm_backend_for_handle_transfer
+    comm.barrier()
 
     def func(
         input,
diff --git a/tests/comm/test_trtllm_mnnvl_allreduce_custom_comm.py b/tests/comm/test_trtllm_mnnvl_allreduce_custom_comm.py
new file mode 100644
index 0000000000..60933cf89b
--- /dev/null
+++ b/tests/comm/test_trtllm_mnnvl_allreduce_custom_comm.py
@@ -0,0 +1,263 @@
+# Check torch version:
+from typing import Any, Tuple
+
+import multiprocessing as mp
+import socket
+import pytest
+import torch
+import torch.distributed as dist
+
+import flashinfer.comm.trtllm_mnnvl_ar as trtllm_mnnvl_ar
+from flashinfer.comm.mapping import Mapping
+from flashinfer.comm.mnnvl import CommBackend
+
+import pynvml
+
+pynvml.nvmlInit()
+
+
+class CustomCommunicator(CommBackend):
+    def __init__(self, group):
+        self._group = group
+
+    def Get_rank(self) -> int:
+        return dist.get_rank(self._group)
+
+    def Get_size(self) -> int:
+        return dist.get_world_size(self._group)
+
+    def allgather(self, data: int | bytes):
+        device = f"cuda:{torch.cuda.current_device()}"
+        if isinstance(data, int):
+            local_tensor = torch.tensor([data], device=device, dtype=torch.int32)
+            world_size = self.Get_size()
+            gathered = [torch.zeros_like(local_tensor) for _ in range(world_size)]
+
+            dist.all_gather(gathered, local_tensor, group=self._group)
+            return [int(x.item()) for x in gathered]
+
+        elif isinstance(data, bytes):
+            # all_gather_object fills a pre-sized output list in place.
+            world_size = self.Get_size()
+            gathered = [None] * world_size
+            dist.all_gather_object(gathered, data, group=self._group)
+            return gathered
+        else:
+            raise TypeError(f"Unsupported type for allgather: {type(data)}")
+
+    def bcast(self, data, root: int = 0):
+        """
+        Broadcast a picklable Python object from `root` to all ranks.
+        Uses torch.distributed.broadcast_object_list under the hood.
+
+        Returns the broadcasted object on every rank.
+        """
+        obj_list = [data]
+        # broadcast_object_list mutates obj_list in-place
+        dist.broadcast_object_list(obj_list, src=root, group=self._group)
+        return obj_list[0]
+
+    def barrier(self):
+        """
+        Synchronize all ranks in this communicator.
+ """ + dist.barrier(group=self._group) + + def Split(self, color: int, key: int) -> "CustomCommunicator": + return self + + +def get_open_port() -> int: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + except OSError: + with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s: + s.bind(("::1", 0)) + return s.getsockname()[1] + + +def multi_process_parallel( + world_size: int, dtype: torch.dtype, test_target: Any, target_args: tuple = () +) -> None: + mp.set_start_method("spawn", force=True) + + procs = [] + distributed_init_port = get_open_port() + for i in range(world_size): + proc_args = (world_size, i, dtype, distributed_init_port) + target_args + proc = mp.Process(target=test_target, args=proc_args, name=f"Worker-{i}") + proc.start() + procs.append(proc) + + for i in range(world_size): + procs[i].join() + assert procs[i].exitcode == 0, ( + f"Process {i} failed with exit code {procs[i].exitcode}" + ) + + +def _run_mnnvl_ar(world_size, rank, dtype, distributed_init_port, seq_len, hidden_size): + # Set CUDA device based on rank + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + distributed_init_method = f"tcp://localhost:{distributed_init_port}" + dist.init_process_group( + backend="nccl", + init_method=distributed_init_method, + rank=rank, + world_size=world_size, + ) + group = dist.group.WORLD + torch.cuda.set_device(rank) + comm = CustomCommunicator(group) + mapping = Mapping( + world_size=world_size, + rank=rank, + gpus_per_node=world_size, + tp_size=world_size, + ) + + if mapping.local_rank == 0: + print( + f"[Node {mapping.node_rank}] Running MNNVL AllReduce test with {world_size} ranks" + ) + print( + f"[Node {mapping.node_rank}] Rank {rank} using GPU {torch.cuda.current_device()}" + ) + + tensor_parallel_size = world_size + eps = 1e-5 + torch.manual_seed(42) + + # Track if this rank failed + rank_failed = False + failure_message = "" + + try: + # Get workspace buffers using MPI rank - allocate once per seq_lens list and reuse within the list + # This workspace is sized for the maximum expected sequence length and can be reused within each list + # Each parameterized list gets its own fresh workspace allocation + explicit_workspace_bytes = 3 * 2 * dtype.itemsize * hidden_size * seq_len + mcast_buffer_mnnvl, buffer_flags_mnnvl, max_num_elements_mnnvl = ( + trtllm_mnnvl_ar.get_allreduce_mnnvl_workspace( + mapping, dtype, comm, explicit_workspace_bytes + ) + ) + + multicast_ptr = mcast_buffer_mnnvl.get_multicast_ptr() + buffer_ptrs_dev = mcast_buffer_mnnvl.get_buffer_ptrs_dev() + unicast_ptr = mcast_buffer_mnnvl.mcast_device_memory.get_unicast_ptr( + mapping.tp_rank + ) + + # Test each sequence length with the same workspace (reusing allocated buffers within this list) + if rank == 0: + print( + f"Testing seq_len={seq_len}, hidden_size={hidden_size}, dtype={dtype}" + ) + + # Generate test data (same on all ranks due to same seed) + x_full = torch.randn( + (tensor_parallel_size, seq_len, hidden_size), + dtype=dtype, + device=torch.device("cuda"), + ) + residual = torch.randn( + (seq_len, hidden_size), dtype=dtype, device=torch.device("cuda") + ) + norm_weight = torch.randn( + (hidden_size,), dtype=dtype, device=torch.device("cuda") + ) + + # Each rank gets its slice of the input + x = x_full[rank, :, :] + + # Compute reference output based on fusion mode + reference_output: Tuple[torch.Tensor, ...] 
+
+        # Non-fused case: Only AllReduce
+        allreduce_result = torch.sum(x_full, dim=0)  # AllReduce result
+        reference_output = (allreduce_result,)
+
+        # Run the allreduce through the shared test helper, passing the custom communicator
+        from .test_trtllm_mnnvl_allreduce import row_linear_residual_norm_fusion_forward
+
+        row_linear_residual_norm_fusion_forward(
+            x,
+            residual,
+            norm_weight,
+            eps,
+            hidden_size,
+            dtype,
+            mapping,
+            False,
+            reference_output,
+            multicast_ptr,
+            buffer_ptrs_dev,
+            unicast_ptr,
+            max_num_elements_mnnvl,
+            buffer_flags_mnnvl,
+            comm,
+        )
+
+        # Synchronize all ranks before reporting the result
+        comm.barrier()
+
+        print(f"PASSED[rank={rank}]: seq_len={seq_len}, dtype={dtype}")
+
+    except Exception as e:
+        rank_failed = True
+        failure_message = (
+            f"FAILED[rank={rank}]: seq_len={seq_len}, dtype={dtype} failed: {e}"
+        )
+        print(failure_message)
+        # Gather failure status from all ranks
+        all_failures = comm.allgather(rank_failed)
+
+        # If any rank failed, fail the test
+        if any(all_failures):
+            failed_ranks = [i for i, failed in enumerate(all_failures) if failed]
+            if rank == 0:
+                print(f"Test failed on ranks: {failed_ranks}")
+
+            # Fail the test on all ranks
+            pytest.fail(f"Test failed on ranks {failed_ranks}")
+        comm.barrier()
+
+    finally:
+        # Ensure this rank's workspace is released
+        if "mcast_buffer_mnnvl" in locals():
+            del mcast_buffer_mnnvl
+
+        # Final synchronization across all ranks
+        comm.barrier()
+
+
+# Entry point: spawns one worker process per rank and runs the allreduce test above.
+@pytest.mark.parametrize("world_size", [2, 4])
+def test_mnnvl_allreduce_custom_communicator(
+    monkeypatch,
+    world_size,
+):
+    monkeypatch.setenv("TRTLLM_FORCE_MNNVL_AR", "1")  # force multi-node allreduce.
+    seq_len = 24
+    dtype = torch.bfloat16
+    hidden_size = 2048
+
+    available_gpus = torch.cuda.device_count()
+    if world_size > available_gpus:
+        raise ValueError(
+            f"world_size {world_size} is greater than available_gpus {available_gpus}"
+        )
+    print(f"Running test for world_size={world_size}")
+    multi_process_parallel(
+        world_size,
+        dtype,
+        _run_mnnvl_ar,
+        target_args=(seq_len, hidden_size),
+    )
+    print(f"custom mnnvl allreduce world_size = {world_size}: OK")
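
Usage sketch (illustrative, not part of the patch): the snippet below shows how a caller might pass a torch.distributed-backed adapter through the new comm_backend_for_handle_transfer argument instead of relying on MPI, mirroring the CustomCommunicator defined in the new test. It assumes torch.distributed has already been initialized with one CUDA device per rank; the class name TorchDistBackend is hypothetical.

# Illustrative adapter: exchanges MNNVL memory handles over torch.distributed.
# Assumes dist.init_process_group(...) and torch.cuda.set_device(rank) were
# already called; TorchDistBackend is a hypothetical name, not part of this PR.
import torch
import torch.distributed as dist

import flashinfer.comm.trtllm_mnnvl_ar as trtllm_mnnvl_ar
from flashinfer.comm.mapping import Mapping
from flashinfer.comm.mnnvl import CommBackend


class TorchDistBackend(CommBackend):
    def __init__(self, group=None):
        self._group = group if group is not None else dist.group.WORLD

    def Get_rank(self) -> int:
        return dist.get_rank(self._group)

    def Get_size(self) -> int:
        return dist.get_world_size(self._group)

    def allgather(self, data):
        # all_gather_object handles both ints and the fabric-handle bytes.
        out = [None] * self.Get_size()
        dist.all_gather_object(out, data, group=self._group)
        return out

    def bcast(self, data, root: int = 0):
        obj = [data]
        dist.broadcast_object_list(obj, src=root, group=self._group)
        return obj[0]

    def barrier(self) -> None:
        dist.barrier(group=self._group)

    def Split(self, color: int, key: int) -> "TorchDistBackend":
        # Single-group case: no real split is needed for this sketch.
        return self


rank, world_size = dist.get_rank(), dist.get_world_size()
mapping = Mapping(
    world_size=world_size, rank=rank, gpus_per_node=world_size, tp_size=world_size
)

# Handle exchange and the post-allocation barrier go through the custom backend
# whenever comm_backend_for_handle_transfer is provided; otherwise MPI is used.
mcast_buf, buffer_flags, max_elems = trtllm_mnnvl_ar.get_allreduce_mnnvl_workspace(
    mapping,
    torch.bfloat16,
    comm_backend_for_handle_transfer=TorchDistBackend(),
)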