Commit bb362a0

Performance improvement for processing (#146)
1 parent 62aeb58 commit bb362a0
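
Summary: `BinaryWriter.add_item` used to call `_should_write()` on every insertion, and `_should_write()` re-walked the contiguous run of serialized items from the minimum index each time, recomputing the byte and item counts from scratch. This commit caches that scan state in `_per_sample_num_bytes`, `_per_sample_num_items`, and `_max_index`, so each call resumes where the previous scan stopped, and skips the check entirely when the incoming index cannot extend the current run.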

File tree

src/litdata/streaming/writer.py
tests/streaming/test_writer.py

2 files changed: +68 −8 lines

src/litdata/streaming/writer.py

Lines changed: 42 additions & 8 deletions
```diff
@@ -99,6 +99,9 @@ def __init__(
         self._distributed_env = _DistributedEnv.detect()
         self._follow_tensor_dimension = follow_tensor_dimension
 
+        self._per_sample_num_bytes = 0
+        self._per_sample_num_items = 0
+
     @property
     def filled(self) -> bool:
         """Returns whether the caching phase is done."""
@@ -277,8 +280,9 @@ def __setitem__(self, index: int, items: Any) -> None:
         self.add_item(index, items)
 
     def add_item(self, index: int, items: Any) -> Optional[str]:
-        # Track the minimum index provided to the writer
-        # Serialize the items and store an Item object.
+        """Serialize the items for the given index and store an Item object in the
+        growing `_serialized_items` dictionary."""
+
         if index in self._serialized_items:
             raise ValueError(f"The provided index {index} already exists in the cache.")
 
@@ -289,23 +293,50 @@ def add_item(self, index: int, items: Any) -> Optional[str]:
             bytes=len(data),
             dim=dim,
         )
-
-        if not self._should_write():
+        if self._min_index is None:
+            # When processing the first item for the current chunk
+            indexes = list(self._serialized_items.keys())
+            self._max_index = self._min_index = indexes[0] if len(indexes) == 1 else min(*indexes)
+            self._per_sample_num_items = self._per_sample_num_bytes = 0
+            if not self._should_write():
+                return None
+        elif index < self._min_index:
+            # A smaller index arrived: reset the "temp" chunk
+            self._max_index = self._min_index = index
+            self._per_sample_num_items = self._per_sample_num_bytes = 0
+            if not self._should_write():
+                return None
+        elif index == self._max_index:
+            if not self._should_write():
+                return None
+        else:
             return None
+
         filepath = os.path.join(self._cache_dir, self.get_chunk_filename())
+
         self.write_chunk()
+
+        # Reset the per-chunk state
         self._min_index = None
         self._max_index = None
+        self._per_sample_num_bytes = 0
+        self._per_sample_num_items = 0
+
         return filepath
 
     def _should_write(self) -> bool:
         # TODO: Misleading method name, it modifies `self._min_index` and `self._max_index`!
         if not self._serialized_items:
             return False
-        indexes = list(self._serialized_items.keys())
-        self._min_index = index = indexes[0] if len(indexes) == 1 else min(*indexes)
-        num_bytes = 0
-        num_items = 0
+
+        if not isinstance(self._max_index, int):
+            return False
+
+        # The indexes in the interval `min_index` to `max_index` have already been
+        # validated to be in `_serialized_items`, so resume from the cached counters.
+        num_bytes = self._per_sample_num_bytes
+        num_items = self._per_sample_num_items
+        index = self._max_index
         while True:
             item = self._serialized_items.get(index, None)
             if item:
@@ -318,6 +349,9 @@ def _should_write(self) -> bool:
                     self._max_index = index - 1
                     return True
             else:
+                self._per_sample_num_bytes = num_bytes
+                self._per_sample_num_items = num_items
+                self._max_index = index
                 return False
 
     def write_chunk_to_file(
```
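
To make the new control flow easier to follow, here is a minimal, self-contained sketch of the caching pattern the diff above introduces. It is not litdata's implementation: `ToyWriter`, `add`, and `_maybe_flush` are hypothetical names, only the item-count budget is modelled (the real `BinaryWriter` also flushes on a byte budget), and "flushing" just returns the index range that would become a chunk.

```python
from typing import Dict, Optional


class ToyWriter:
    """Hypothetical, simplified stand-in for BinaryWriter's chunk accounting."""

    def __init__(self, chunk_size: int) -> None:
        self.chunk_size = chunk_size
        self.items: Dict[int, bytes] = {}  # parked serialized samples, keyed by index
        self.min_index: Optional[int] = None  # first index of the current run
        self.max_index: Optional[int] = None  # where the previous scan stopped
        self.num_items = 0  # cached count of items already scanned

    def add(self, index: int, payload: bytes) -> Optional[range]:
        if index in self.items:
            raise ValueError(f"The provided index {index} already exists.")
        self.items[index] = payload
        if self.min_index is None:
            # First item for the current chunk: start a new run.
            self.min_index = self.max_index = min(self.items)
            self.num_items = 0
        elif index < self.min_index:
            # A smaller index arrived late: restart the run from it.
            self.min_index = self.max_index = index
            self.num_items = 0
        elif index != self.max_index:
            # Cannot extend the contiguous run; leave the item parked.
            return None
        return self._maybe_flush()

    def _maybe_flush(self) -> Optional[range]:
        # Resume at the cached position instead of re-walking the whole
        # run from `min_index` on every call (the old behaviour).
        index, num_items = self.max_index, self.num_items
        while index in self.items:
            num_items += 1
            index += 1
            if num_items >= self.chunk_size:
                chunk = range(self.min_index, index)
                for i in chunk:  # a real writer would serialize these to disk
                    del self.items[i]
                self.min_index = self.max_index = None
                self.num_items = 0
                return chunk
        # Not enough for a full chunk yet: remember how far we scanned.
        self.max_index, self.num_items = index, num_items
        return None
```

With the old scan, each `add_item` call was O(n) in the length of the pending run, so filling a chunk was quadratic in its size; the cached counters make it amortized O(1) per item, which is presumably where the performance improvement in this commit comes from.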

tests/streaming/test_writer.py

Lines changed: 26 additions & 0 deletions
```diff
@@ -226,3 +226,29 @@ def test_writer_human_format(tmpdir):
 
     binary_writer = BinaryWriter(tmpdir, chunk_bytes="64MB")
     assert binary_writer._chunk_bytes == 64000000
+
+
+def test_writer_unordered_indexes(tmpdir):
+    cache_dir = os.path.join(tmpdir, "chunks")
+    os.makedirs(cache_dir, exist_ok=True)
+
+    binary_writer = BinaryWriter(cache_dir, chunk_size=5)
+
+    arr = [2, 3, 1, 4, 6, 5, 7, 8, 11, 9, 10, 12]
+
+    for i in arr:
+        binary_writer[i] = i - 1
+
+    binary_writer.done()
+    binary_writer.merge()
+
+    reader = BinaryReader(cache_dir)
+    for i in range(12):
+        assert i == reader.read(ChunkedIndex(i, chunk_index=i // 5))
+
+    with open(os.path.join(cache_dir, "index.json")) as f:
+        data = json.load(f)
+
+    assert data["chunks"][0]["chunk_size"] == 5
+    assert data["chunks"][1]["chunk_size"] == 5
+    assert data["chunks"][2]["chunk_size"] == 2
```
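
The test feeds the writer out-of-order indexes and then checks that, after `done()` and `merge()`, the samples read back in index order and pack into sequential chunks of sizes 5, 5, and 2. Driving the hypothetical `ToyWriter` sketch above with the same arrival order reproduces those boundaries:

```python
writer = ToyWriter(chunk_size=5)
for i in [2, 3, 1, 4, 6, 5, 7, 8, 11, 9, 10, 12]:
    chunk = writer.add(i, payload=b"x")
    if chunk is not None:
        print(list(chunk))
# Prints [1, 2, 3, 4, 5], then [6, 7, 8, 9, 10]. The two parked items
# {11, 12} mirror the final partial chunk that `done()` writes, i.e. the
# chunk with chunk_size == 2 asserted above.
```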
