@@ -99,6 +99,9 @@ def __init__(
9999 self ._distributed_env = _DistributedEnv .detect ()
100100 self ._follow_tensor_dimension = follow_tensor_dimension
101101
102+ self ._per_sample_num_bytes = 0
103+ self ._per_sample_num_items = 0
104+
102105 @property
103106 def filled (self ) -> bool :
104107 """Returns whether the caching phase is done."""
@@ -277,8 +280,9 @@ def __setitem__(self, index: int, items: Any) -> None:
277280 self .add_item (index , items )
278281
279282 def add_item (self , index : int , items : Any ) -> Optional [str ]:
280- # Track the minimum index provided to the writer
281- # Serialize the items and store an Item object.
283+ """Given an index and items, serialize the items and store an Item object in the growing
284+ `_serialized_items`."""
285+
282286 if index in self ._serialized_items :
283287 raise ValueError (f"The provided index { index } already exists in the cache." )
284288
@@ -289,23 +293,50 @@ def add_item(self, index: int, items: Any) -> Optional[str]:
289293 bytes = len (data ),
290294 dim = dim ,
291295 )
292-
293- if not self ._should_write ():
296+ if self ._min_index is None :
297+ # When processing the first item for the current chunk
298+ indexes = list (self ._serialized_items .keys ())
299+ self ._max_index = self ._min_index = indexes [0 ] if len (indexes ) == 1 else min (* indexes )
300+ self ._per_sample_num_items = self ._per_sample_num_bytes = 0
301+ if not self ._should_write ():
302+ return None
303+ elif index < self ._min_index :
304+ # reset the "temp" chunk
305+ self ._max_index = self ._min_index = index
306+ self ._per_sample_num_items = self ._per_sample_num_bytes = 0
307+ if not self ._should_write ():
308+ return None
309+ elif index == self ._max_index :
310+ if not self ._should_write ():
311+ return None
312+ else :
294313 return None
314+
295315 filepath = os .path .join (self ._cache_dir , self .get_chunk_filename ())
316+
296317 self .write_chunk ()
318+
319+ # now to reset
297320 self ._min_index = None
298321 self ._max_index = None
322+ self ._per_sample_num_bytes = 0
323+ self ._per_sample_num_items = 0
324+
299325 return filepath
300326
301327 def _should_write (self ) -> bool :
302328 # TODO: Misleading method name, it modifies `self._min_index` and `self._max_index`!
303329 if not self ._serialized_items :
304330 return False
305- indexes = list (self ._serialized_items .keys ())
306- self ._min_index = index = indexes [0 ] if len (indexes ) == 1 else min (* indexes )
307- num_bytes = 0
308- num_items = 0
331+
332+ if not isinstance (self ._max_index , int ):
333+ return False
334+
335 # We have already validated that the indexes in the interval from `min_index` to `max_index` are in `_serialized_items`.
336 # Restore num_bytes and num_items from the saved per-sample values.
337+ num_bytes = self ._per_sample_num_bytes
338+ num_items = self ._per_sample_num_items
339+ index = self ._max_index
309340 while True :
310341 item = self ._serialized_items .get (index , None )
311342 if item :
@@ -318,6 +349,9 @@ def _should_write(self) -> bool:
318349 self ._max_index = index - 1
319350 return True
320351 else :
352+ self ._per_sample_num_bytes = num_bytes
353+ self ._per_sample_num_items = num_items
354+ self ._max_index = index
321355 return False
322356
323357 def write_chunk_to_file (
0 commit comments