12 changes: 5 additions & 7 deletions flashinfer/comm/mnnvl.py
@@ -547,14 +547,14 @@ def supports_mnnvl() -> bool:

 class McastDeviceMemory:
     """Python port of McastDeviceMemory from TensorRT-LLM"""

     def __init__(
         self,
         buf_size: int,
         group_size: int,
         group_rank: int,
         device_idx: int,
         is_multi_node: bool = True,
+        comm: Optional[CommBackend] = None,

Collaborator:
How about calling it comm_backend or communicator?
cc @nvmbreughe in case you have any preference.

Contributor:
comm_backend sounds good, or even more explicit: comm_backend_for_handle_transfer. Besides the name, I would also list some options, e.g. MpiComm().

     ):
         cu_device = checkCudaErrors(cuda.cuDeviceGet(device_idx))

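For the options mentioned in the review thread above, a minimal sketch (editor's illustration, not part of the PR) of a duck-typed backend. Only barrier() is visible in this diff; a real CommBackend would also need whatever MpiComm provides for the fabric-handle exchange inside _alloc_mn_mcast_mem, which is not shown in these hunks.

import torch.distributed as dist


class TorchDistCommBackend:
    """Hypothetical stand-in for the proposed comm / comm_backend argument."""

    def __init__(self, group=None):
        # None means the default process group; assumes init_process_group()
        # has already been called by the launcher.
        self._group = group

    def barrier(self) -> None:
        # Synchronize all ranks, mirroring what mpi_barrier() does on the MPI path.
        dist.barrier(group=self._group)
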
@@ -631,7 +631,7 @@ def __init__(
                     "[McastDeviceMemory] Device does not support fabric handle."
                 )

-            self._alloc_mn_mcast_mem(buf_size)
+            self._alloc_mn_mcast_mem(buf_size, comm)
         else:
             # For single-node NVLS, would need to implement _alloc_nvls_mcast_mem
             raise NotImplementedError("Single-node NVLS allocation not implemented yet")
@@ -753,7 +753,7 @@ def get_world_size(self) -> int:
         """Get the total number of devices in the group"""
         return self.group_size

-    def _alloc_mn_mcast_mem(self, buf_size: int):
+    def _alloc_mn_mcast_mem(self, buf_size: int, comm: Any=MpiComm()):
         """Allocate multi-node multicast memory using MNNVL"""

         # Verify CUDA context
@@ -767,9 +767,6 @@ def _alloc_mn_mcast_mem(self, buf_size: int):
         except Exception as e:
             print(f"Error checking CUDA context: {e}")

-        # Get MPI communicator
-        comm = MpiComm()
-
         # Set up allocation properties
         handle_type = cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC

@@ -969,6 +966,7 @@ def __init__(
         group_rank: int,
         device: torch.device,
         mn_nvlink: bool = True,
+        comm: Optional[CommBackend] = None,
     ):
         """
         Constructor for McastGpuBuffer.
@@ -981,7 +979,7 @@ def __init__(
             mn_nvlink: Flag indicating if multi-node NVLink is used
         """
         self.mcast_device_memory = McastDeviceMemory(
-            buf_size, group_size, group_rank, device.index, mn_nvlink
+            buf_size, group_size, group_rank, device.index, mn_nvlink, comm
         )
         self.buf_size = buf_size
         self.local_device = device
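A usage sketch for the constructor change above (editor's illustration, not from the PR). The sizes and ranks are placeholders; the keyword names mn_nvlink and comm follow the hunks above, and MpiComm is the backend the module previously constructed internally.

import torch

from flashinfer.comm.mnnvl import McastGPUBuffer, MpiComm


def make_mcast_buffer(buf_size: int, group_size: int, group_rank: int, local_rank: int):
    # Pass the communicator explicitly instead of relying on the MpiComm()
    # that _alloc_mn_mcast_mem used to create on its own.
    return McastGPUBuffer(
        buf_size,
        group_size,
        group_rank,
        torch.device("cuda", local_rank),
        mn_nvlink=True,   # multi-node NVLink path, as in the hunks above
        comm=MpiComm(),   # or any other CommBackend implementation
    )
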
11 changes: 8 additions & 3 deletions flashinfer/comm/trtllm_mnnvl_ar.py
@@ -15,7 +15,7 @@

 from ..jit import gen_trtllm_mnnvl_comm_module
 from ..utils import register_custom_op
-from .mnnvl import McastGPUBuffer
+from .mnnvl import (McastGPUBuffer, CommBackend)


 def mpi_barrier():
@@ -122,7 +122,8 @@ def trtllm_mnnvl_rmsnorm(


 def get_allreduce_mnnvl_workspace(
-    mapping: Mapping, dtype: torch.dtype
+    mapping: Mapping, dtype: torch.dtype,
+    comm: Optional[CommBackend] = None,
 ) -> Tuple[McastGPUBuffer, torch.Tensor, int]:
     """Get workspace buffers needed for multi-node NVLink all-reduce operation.

@@ -164,14 +165,18 @@ def get_allreduce_mnnvl_workspace(
         mapping.tp_rank,
         torch.device("cuda", mapping.local_rank),
         mapping.is_multi_node() or force_mn,
+        comm=comm,
     )

     # Initialize the unicast buffer with -0.0
     mcast_buffer.lamport_initialize(mapping.tp_rank, dtype)

     # CPU barrier since we assume this should not be called in cuda graph
     torch.cuda.synchronize()
-    mpi_barrier()
+    if comm:
+        comm.barrier()
+    else:
+        mpi_barrier()

     # This is a buffer to maintain the state of this allreduce Op
     # [Buffer_ptr, Clear_ptr, Buffer_size, num_tokens_prev, atomic access counter]
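A call-site sketch for the new keyword (editor's illustration, not from the PR). Construction of the Mapping object is framework-specific and left to the caller; only the comm= keyword, the Tuple[McastGPUBuffer, torch.Tensor, int] return annotation, and the barrier() fallback are taken from the diff.

from typing import Optional

import torch

from flashinfer.comm.mnnvl import CommBackend
from flashinfer.comm.trtllm_mnnvl_ar import get_allreduce_mnnvl_workspace


def build_ar_workspace(mapping, dtype: torch.dtype = torch.bfloat16,
                       comm: Optional[CommBackend] = None):
    # comm=None keeps the existing mpi_barrier() path; otherwise the CPU barrier
    # goes through comm.barrier() and comm is forwarded to McastGPUBuffer.
    return get_allreduce_mnnvl_workspace(mapping, dtype, comm=comm)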