151 changes: 93 additions & 58 deletions python/tests/rdma_load_test.py
@@ -6,12 +6,12 @@

import argparse
import asyncio
import dataclasses
import os
import random
import statistics
import time


# parse up front to extract env variables.
args = None
if __name__ == "__main__":
@@ -56,6 +56,12 @@
default=10,
help="Number of warmup iterations (default: 5)",
)
parser.add_argument(
"--n-concurrent-operations",
type=int,
default=1,
help="Number of concurrent operations (default: 1)",
)

args = parser.parse_args()

@@ -72,6 +78,13 @@
from monarch.rdma import RDMABuffer


@dataclasses.dataclass
class RDMATestRequest:
buffer: RDMABuffer
shape: torch.Size
dtype: torch.dtype


class RDMATest(Actor):
def __init__(
self, device: str = "cpu", operation: str = "write", size_mb: int = 64
@@ -91,76 +104,96 @@ async def set_other_actor(self, other_actor):
self.other_actor = other_actor

@endpoint
async def send(self, is_warmup=False) -> None:
shape = int(
1024 * 1024 * self.size_mb / 4 * (0.5 * random.randint(1, 3))
) # scale the element count randomly to 50%, 100%, or 150% of the requested size

# Use the device string directly
tensor = torch.rand(shape, dtype=torch.float32, device=self.device)
size_elem = tensor.numel() * tensor.element_size()
tensor_addr = tensor.data_ptr()

# Critical validation - this should catch the null pointer issue
assert (
tensor_addr != 0
), f"CRITICAL: Tensor has null pointer! Device: {device}, Shape: {shape}"
assert size_elem > 0, f"CRITICAL: Tensor has zero size! Size: {size_elem}"

byte_view = tensor.view(torch.uint8).flatten()
# Validate byte_view too
byte_view_addr = byte_view.data_ptr()
assert (
byte_view_addr != 0
), f"CRITICAL: Byte view has null pointer! Original addr: 0x{tensor_addr:x}"
assert (
byte_view_addr == tensor_addr
), f"CRITICAL: Address mismatch! Tensor: 0x{tensor_addr:x}, ByteView: 0x{byte_view_addr:x}"

execution_start = time.time()
buffer = RDMABuffer(byte_view)
execution_end = time.time()
elapsed = execution_end - execution_start

# Store timing and size data in this actor
size_elem = torch.numel(tensor) * tensor.element_size()
if not is_warmup:
self.timing_data.append(elapsed)
self.size_data.append(size_elem)
buffer_size = buffer.size()
assert buffer_size == size_elem, f"{buffer_size=} != {size_elem=}"
async def send(self, is_warmup=False, n_concurrent_operations=1) -> None:
requests: list[RDMATestRequest] = []
for _ in range(n_concurrent_operations):
shape = int(
1024 * 1024 * self.size_mb / 4 * (0.5 * random.randint(1, 3))
) # scale the element count randomly to 50%, 100%, or 150% of the requested size

# Use the device string directly
tensor = torch.rand(shape, dtype=torch.float32, device=self.device)
size_elem = tensor.numel() * tensor.element_size()
tensor_addr = tensor.data_ptr()

# Critical validation - this should catch the null pointer issue
assert (
tensor_addr != 0
), f"CRITICAL: Tensor has null pointer! Device: {device}, Shape: {shape}"
assert size_elem > 0, f"CRITICAL: Tensor has zero size! Size: {size_elem}"

byte_view = tensor.view(torch.uint8).flatten()
# Validate byte_view too
byte_view_addr = byte_view.data_ptr()
assert (
byte_view_addr != 0
), f"CRITICAL: Byte view has null pointer! Original addr: 0x{tensor_addr:x}"
assert (
byte_view_addr == tensor_addr
), f"CRITICAL: Address mismatch! Tensor: 0x{tensor_addr:x}, ByteView: 0x{byte_view_addr:x}"

execution_start = time.time()
buffer = RDMABuffer(byte_view)
execution_end = time.time()
elapsed = execution_end - execution_start

# Store timing and size data in this actor
size_elem = torch.numel(tensor) * tensor.element_size()
if not is_warmup:
self.timing_data.append(elapsed)
self.size_data.append(size_elem)
buffer_size = buffer.size()
assert buffer_size == size_elem, f"{buffer_size=} != {size_elem=}"

requests.append(RDMATestRequest(buffer, tensor.shape, tensor.dtype))

# Call recv - timing happens there
await self.other_actor.recv.call(buffer, tensor.shape, tensor.dtype, is_warmup)
await self.other_actor.recv.call(requests, is_warmup)

# cleanup
await buffer.drop()
for req in requests:
await req.buffer.drop()

self.i += 1

@endpoint
async def recv(self, rdma_buffer, shape, dtype, is_warmup):
async def recv(self, requests, is_warmup):
# Create receiving tensor on the same device
tensor = torch.rand(shape, dtype=dtype, device=self.device)
byte_view = tensor.view(torch.uint8).flatten()
sizes = []
byte_views = []
for req in requests:
shape = req.shape
dtype = req.dtype
tensor = torch.rand(shape, dtype=dtype, device=self.device)
sizes.append(tensor.numel() * tensor.element_size())
byte_view = tensor.view(torch.uint8).flatten()
byte_views.append(byte_view)

coros = []

for i, req in enumerate(requests):
rdma_buffer = req.buffer
byte_view = byte_views[i]

async def op_coro(rdma_buffer=rdma_buffer, byte_view=byte_view):
if self.operation == "write":
await rdma_buffer.write_from(byte_view, timeout=5)
elif self.operation == "read":
await rdma_buffer.read_into(byte_view, timeout=5)
elif self.operation == "ping-pong":
if self.i % 2 == 0:
await rdma_buffer.write_from(byte_view, timeout=5)
else:
await rdma_buffer.read_into(byte_view, timeout=5)

coros.append(op_coro(rdma_buffer=rdma_buffer, byte_view=byte_view))

execution_start = time.time()

if self.operation == "write":
await rdma_buffer.write_from(byte_view, timeout=5)
elif self.operation == "read":
await rdma_buffer.read_into(byte_view, timeout=5)
elif self.operation == "ping-pong":
if self.i % 2 == 0:
await rdma_buffer.write_from(byte_view, timeout=5)
else:
await rdma_buffer.read_into(byte_view, timeout=5)

await asyncio.gather(*coros)
execution_end = time.time()
elapsed = execution_end - execution_start

# Store timing and size data in this actor
size_elem = torch.numel(tensor) * tensor.element_size()
size_elem = sum(sizes)
if not is_warmup:
self.timing_data.append(elapsed)
self.size_data.append(size_elem)
@@ -227,6 +260,7 @@ async def main(
operation: str = "write",
size_mb: int = 64,
warmup_iterations: int = 10,
n_concurrent_operations: int = 1,
):
# Adjust GPU allocation based on the device types
device_0, device_1 = devices[0], devices[1]
@@ -248,7 +282,7 @@ async def main(
await actor_0.send.call(is_warmup=True)

for i in range(iterations):
await actor_0.send.call()
await actor_0.send.call(n_concurrent_operations=n_concurrent_operations)

# Have both actors print their statistics
print("\n=== ACTOR 0 (Create Buffer) STATISTICS ===")
@@ -313,5 +347,6 @@ async def main(
args.operation,
args.size,
args.warmup_iterations,
args.n_concurrent_operations,
)
)
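
For readers skimming the diff, the core change in recv is a fan-out/fan-in pattern: one coroutine is built per in-flight request, all of them are launched together with asyncio.gather, and the wall-clock timing now brackets the whole batch rather than a single transfer. Below is a minimal standalone sketch of that pattern; fake_rdma_op and run_batch are illustrative stand-ins, not Monarch's RDMABuffer API.

import asyncio
import time

async def fake_rdma_op(delay_s: float) -> None:
    # Stand-in for rdma_buffer.write_from(...) / read_into(...).
    await asyncio.sleep(delay_s)

async def run_batch(n_concurrent_operations: int) -> float:
    # One coroutine per request, mirroring the op_coro list built in recv().
    coros = [fake_rdma_op(0.05) for _ in range(n_concurrent_operations)]
    start = time.time()
    await asyncio.gather(*coros)  # all operations are in flight at once
    return time.time() - start

# With 4 concurrent ops this prints roughly 0.05s, not 0.2s: the waits overlap.
print(asyncio.run(run_batch(4)))

Two details in the diff are worth calling out. First, op_coro binds rdma_buffer and byte_view through default arguments, which pins each closure to its own request rather than the loop variables (Python closures are late-binding; the explicit keyword arguments at the call site make the defaults redundant but harmless). Second, with the change applied, an invocation along the lines of python python/tests/rdma_load_test.py --n-concurrent-operations 4 should exercise the batched path; --n-concurrent-operations comes directly from the parser change above, while any other flags are only implied by the diff.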