Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions tests/test_protocol_on_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -1232,3 +1232,98 @@ def test_serialize_dataproto_with_empty_tensordict():
deserialized_data = pickle.loads(serialized_data)
assert len(deserialized_data.batch.keys()) == 0
assert deserialized_data.batch.batch_size == torch.Size([10])


# --- DataProtoFuture Tests ---

import ray # noqa: E402

from verl.protocol import DataProtoFuture # noqa: E402


@pytest.fixture(scope="module", autouse=True)
def setup_ray():
    """Start a shared Ray instance for all tests in this module.

    ``ignore_reinit_error=True`` makes the fixture safe when Ray was already
    initialized by an earlier module in the same process.

    Yields:
        None. Teardown shuts Ray down after the last test of the module.
    """
    ray.init(ignore_reinit_error=True)
    try:
        yield
    finally:
        # Guarantee shutdown even if pytest throws into the generator
        # (e.g. on KeyboardInterrupt during teardown), so a later module
        # does not inherit a half-torn-down Ray runtime.
        ray.shutdown()


def test_data_proto_future_chunk_even():
    """Chunking a future built from two size-4 DataProtos into 4 parts yields size-2 slices."""
    protos = [
        DataProto.from_dict({"a": torch.arange(0, 4)}),
        DataProto.from_dict({"a": torch.arange(4, 8)}),
    ]
    future = DataProtoFuture.concat([ray.put(p) for p in protos])

    # 8 total elements split 4 ways -> each chunk holds 2 consecutive values.
    chunks = future.chunk(4)
    assert len(chunks) == 4

    for idx, chunk in enumerate(chunks):
        result = chunk.get()
        assert torch.equal(result.batch["a"], torch.tensor([2 * idx, 2 * idx + 1]))


def test_data_proto_future_chunk_uneven_overlap():
    """Chunks may straddle future boundaries: 3 futures of size 4 split 2 ways -> two size-6 halves."""
    protos = [DataProto.from_dict({"a": torch.arange(4 * i, 4 * (i + 1))}) for i in range(3)]
    future = DataProtoFuture.concat([ray.put(p) for p in protos])

    halves = future.chunk(2)
    assert len(halves) == 2

    # Each half covers a contiguous run of 6 elements (0..5 and 6..11).
    for idx, half in enumerate(halves):
        result = half.get()
        assert len(result) == 6
        assert torch.equal(result.batch["a"], torch.arange(6 * idx, 6 * (idx + 1)))


def test_data_proto_future_nested_chunk():
    """chunk() composes: chunking a chunk further subdivides the underlying data."""
    refs = [
        ray.put(DataProto.from_dict({"a": torch.arange(0, 4)})),
        ray.put(DataProto.from_dict({"a": torch.arange(4, 8)})),
    ]
    future = DataProtoFuture.concat(refs)

    # Split into halves of size 4, then split the first half again into size-2 pieces.
    halves = future.chunk(2)
    quarters = halves[0].chunk(2)

    for idx, piece in enumerate(quarters):
        result = piece.get()
        assert len(result) == 2
        assert torch.equal(result.batch["a"], torch.tensor([2 * idx, 2 * idx + 1]))


def test_data_proto_future_tensordict():
    """DataProtoFuture also carries raw TensorDict payloads, and get() preserves the type."""
    tds = [
        TensorDict({"a": torch.arange(0, 4)}, batch_size=[4]),
        TensorDict({"a": torch.arange(4, 8)}, batch_size=[4]),
    ]
    future = DataProtoFuture.concat([ray.put(td) for td in tds])

    chunks = future.chunk(2)

    first = chunks[0].get()
    assert isinstance(first, TensorDict)
    assert torch.equal(first["a"], torch.arange(0, 4))

    second = chunks[1].get()
    assert torch.equal(second["a"], torch.arange(4, 8))
98 changes: 74 additions & 24 deletions verl/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
import os
import pickle
from dataclasses import dataclass, field
from typing import Any, Callable, Optional
from fractions import Fraction
from typing import Any, Optional

import numpy as np
import ray
Expand Down Expand Up @@ -214,7 +215,7 @@ def fold_batch_dim(data: "DataProto", new_batch_size):
tensor.auto_batch_size_(batch_dims=1)

for key, val in non_tensor.items():
non_tensor[key] = np.reshape(val, newshape=(new_batch_size, -1, *val.shape[1:]))
non_tensor[key] = np.reshape(val, (new_batch_size, -1, *val.shape[1:]))

return type(data)(batch=tensor, non_tensor_batch=non_tensor, meta_info=data.meta_info)

Expand All @@ -233,7 +234,7 @@ def unfold_batch_dim(data: "DataProto", batch_dims=2):
non_tensor_new = {}

for key, val in non_tensor.items():
non_tensor_new[key] = np.reshape(val, newshape=(batch_size, *val.shape[batch_dims:]))
non_tensor_new[key] = np.reshape(val, (batch_size, *val.shape[batch_dims:]))

return type(data)(batch=tensor, non_tensor_batch=non_tensor_new, meta_info=data.meta_info)

Expand Down Expand Up @@ -1173,44 +1174,95 @@ def _get_type_info(self, value):
@dataclass
class DataProtoFuture:
"""
DataProtoFuture aims to eliminate actual data fetching on driver. By doing so, the driver doesn't have to wait
DataProtoFuture aims to eliminate actual data fetching on the driver. By doing so, the driver doesn't have to wait
for data so that asynchronous execution becomes possible.
DataProtoFuture contains a list of futures from another WorkerGroup of size world_size.
- collect_fn is a Callable that reduces the list of futures to a DataProto
- dispatch_fn is a Callable that partitions the DataProto into a list of DataProto of size world_size
and then select

Potential issue: we can optimize dispatch_fn(collect_fn) such that only needed data is fetched on destination
- DataProtoFuture only supports directly passing from the output of a method to another input. You can't perform any
operation on the DataProtoFuture in driver.

DataProtoFuture contains a list of futures from another WorkerGroup of world_size. It only supports
directly passing from the output of a method to another as input. You cannot perform any operation
on the DataProtoFuture in the driver.

The basic assumption is that all futures have the same size, and the fractions form a contiguous range, meaning
that only the first and last future may be partial.

Attributes
----------
futures : list[ray.ObjectRef]
A list of Ray object references representing the distributed data chunks.
start_fraction : Fraction, optional
The starting fraction of the first future. Defaults to Fraction(0).
end_fraction : Fraction, optional
The ending fraction of the last future. Defaults to Fraction(1).
"""

collect_fn: Callable
futures: list[ray.ObjectRef]
dispatch_fn: Callable = None
start_fraction: Fraction = field(default_factory=lambda: Fraction(0))
end_fraction: Fraction = field(default_factory=lambda: Fraction(1))

@staticmethod
def concat(data: list[ray.ObjectRef]) -> "DataProtoFuture":
output = DataProtoFuture(collect_fn=DataProto.concat, futures=data)
output = DataProtoFuture(futures=data)
return output

def chunk(self, chunks: int) -> list["DataProtoFuture"]:
from functools import partial
import math
from fractions import Fraction

# Total number of futures including fractions (exact rational arithmetic).
total_futures = len(self.futures) - 1 + self.end_fraction - self.start_fraction
# Start fraction of the first future is considered as the global start fractional offset.
global_start_frac = self.start_fraction
# Number of futures per chunk including fractions.
num_futures_in_chunk = total_futures / chunks

arg_future_lst = []
for i in range(chunks):
# note that we can't directly pass i and chunks
def dispatch_fn(x, i, chunks):
return x.chunk(chunks=chunks)[i]
# Chunk's global start and end fractional offsets.
chunk_start_global_frac = global_start_frac + i * num_futures_in_chunk
chunk_end_global_frac = global_start_frac + (i + 1) * num_futures_in_chunk

start_future_idx = math.floor(chunk_start_global_frac)
start_future_frac = chunk_start_global_frac - start_future_idx

end_future_idx = math.floor(chunk_end_global_frac)
end_future_frac = chunk_end_global_frac - end_future_idx

if end_future_frac == 0 and chunk_end_global_frac > chunk_start_global_frac:
end_future_idx -= 1
end_future_frac = Fraction(1)

futures_in_chunk = self.futures[start_future_idx : end_future_idx + 1]
arg_future = DataProtoFuture(
collect_fn=self.collect_fn, dispatch_fn=partial(dispatch_fn, i=i, chunks=chunks), futures=self.futures
futures=futures_in_chunk, start_fraction=start_future_frac, end_fraction=end_future_frac
)
arg_future_lst.append(arg_future)
return arg_future_lst

def get(self):
output = ray.get(self.futures) # dp_size.
# Fetch only the needed futures natively tracked by this object
fetched_data = ray.get(self.futures)

# Calculate the integer offsets in each future, and stitch them together.
# Using exact Fraction arithmetic avoids float rounding errors when converting
# back to integer indices (e.g. Fraction(1,3) * 3 == Fraction(1) exactly).
output = []
for i, data in enumerate(fetched_data):
data_len = len(data)

if i == 0:
start_offset = int(self.start_fraction * data_len)
else:
start_offset = 0

if i == len(fetched_data) - 1:
end_offset = int(self.end_fraction * data_len)
else:
end_offset = data_len
Comment thread
yurun00 marked this conversation as resolved.

if start_offset == 0 and end_offset == data_len:
output.append(data)
else:
output.append(data[start_offset:end_offset])

for o in output:
assert isinstance(o, DataProto | TensorDict)

Expand All @@ -1221,10 +1273,8 @@ def get(self):

output = concat_tensordict(output)
else:
raise TypeError(f"Unknown type {type(o[0])} in DataProtoFuture")
raise TypeError(f"Unknown type {type(output[0])} in DataProtoFuture")

if self.dispatch_fn is not None:
output = self.dispatch_fn(output) # split in batch dim, select using dp
return output


Expand Down