1919from collections .abc import Iterator , Sequence
2020import copy
2121import functools
22+ import math
2223from multiprocessing import queues
2324import queue
2425import threading
@@ -69,6 +70,10 @@ def _initialize_prefetch_stats(
6970 )
7071
7172
73+ def _is_batch_iter_pushdown_experiment_enabled () -> bool :
74+ return False
75+
76+
7277@dataset_stats .trace_input_pipeline_prefetch
7378def _getitem (
7479 stats : dataset_stats .Stats , parent : dataset .MapDataset [T ], index : int
@@ -77,6 +82,16 @@ def _getitem(
7782 return stats .record_bytes_consumed (parent [index ])
7883
7984
@dataset_stats.trace_input_pipeline_prefetch
def _getitems(
    stats: dataset_stats.Stats,
    parent: dataset.MapDataset[T],
    indices: list[int],
) -> list[T]:
  """Helper to record the memory usage of the elements before prefetching."""
  elements = parent._getitems(indices)  # pylint: disable=protected-access
  return [stats.record_bytes_consumed(element) for element in elements]
94+
8095@typing .runtime_checkable
8196class SupportsInPlaceSlicing (Protocol ):
8297 """Datasets that support mutation by setting the processed data slice."""
@@ -114,6 +129,10 @@ def __str__(self) -> str:
114129 )
115130
116131 def __iter__ (self ) -> dataset .DatasetIterator [T ]:
132+ if _is_batch_iter_pushdown_experiment_enabled ():
133+ return _BatchedPrefetchDatasetIterator (
134+ self ._parent , self ._read_options , self ._allow_nones
135+ )
117136 return PrefetchDatasetIterator (
118137 self ._parent , self ._read_options , self ._allow_nones
119138 )
@@ -141,14 +160,15 @@ def __init__(
141160 self ._read_options = read_options
142161 self ._next_returned_index = 0
143162 self ._next_buffered_index = 0
163+ # Buffer of (future, batch_size) tuples for prefetched elements.
144164 self ._buffer = collections .deque ()
145165 self ._lock = threading .Lock ()
146- self ._prefetch_buffer_size = (
166+ self ._target_buffer_size = (
147167 read_options .prefetch_buffer_size if read_options .num_threads > 0 else 0
148168 )
149169 self ._num_threads = read_options .num_threads
150170 self ._allow_nones = allow_nones
151- if self ._prefetch_buffer_size > 0 :
171+ if self ._target_buffer_size > 0 :
152172 self ._executor = futures .ThreadPoolExecutor (
153173 self ._num_threads , thread_name_prefix = "grain-prefetch"
154174 )
@@ -194,25 +214,27 @@ def __next__(self) -> T:
194214 if self ._next_returned_index == self ._dataset_length :
195215 break
196216 with self ._lock , timer :
197- if self ._prefetch_buffer_size > 0 :
217+ if self ._target_buffer_size > 0 :
198218 if not self ._buffer :
199219 # Fill the buffer on the first iteration.
200220 self ._fill_buffer ()
201- element = self ._buffer .popleft ()
221+ future , _ = self ._buffer .popleft ()
202222 # Prefetch elements until the buffer is full again.
203223 self ._fill_buffer ()
204- element = element .result ()
224+ element = future .result ()
205225 else :
206226 # In case prefetch buffer size was decreased, we still want to consume
207227 # the already prefetched elements.
208228 if self ._buffer :
209- element = self ._buffer .popleft ().result ()
229+ future , _ = self ._buffer .popleft ()
230+ element = future .result ()
210231 else :
211232 element = self ._stats .record_bytes_consumed (
212233 self ._map_parent [self ._next_returned_index ]
213234 )
214235 self ._next_buffered_index += 1
215236 self ._next_returned_index += 1
237+
216238 return_element = self ._allow_nones or element is not None
217239 self ._threshold_checker .check (return_element )
218240 if return_element :
@@ -224,23 +246,26 @@ def __next__(self) -> T:
224246 def get_state (self ):
225247 return {"next_index" : self ._next_returned_index }
226248
249+ def _set_state_helper (self , state ):
250+ self ._next_returned_index = state ["next_index" ]
251+ self ._next_buffered_index = self ._next_returned_index
252+ if (
253+ self ._next_returned_index < 0
254+ or self ._next_returned_index > self ._dataset_length
255+ ):
256+ raise IndexError (
257+ f"Checkpoint `next_index` { self ._next_returned_index } is out of"
258+ f" range for dataset of length { self ._dataset_length } ."
259+ )
260+ if self ._target_buffer_size > 0 :
261+ # Cancel all pending futures in the buffer.
262+ while self ._buffer :
263+ future , _ = self ._buffer .popleft ()
264+ future .cancel ()
265+
227266 def set_state (self , state ):
228267 with self ._lock :
229- self ._next_returned_index = state ["next_index" ]
230- self ._next_buffered_index = self ._next_returned_index
231- if (
232- self ._next_returned_index < 0
233- or self ._next_returned_index > self ._dataset_length
234- ):
235- raise IndexError (
236- f"Checkpoint `next_index` { self ._next_returned_index } is out of"
237- f" range for dataset of length { self ._dataset_length } ."
238- )
239- if self ._prefetch_buffer_size > 0 :
240- # Cancel all pending futures in the buffer.
241- while self ._buffer :
242- future = self ._buffer .popleft ()
243- future .cancel ()
268+ self ._set_state_helper (state )
244269
245270 def _get_next_index (self ) -> int :
246271 return self ._next_returned_index
@@ -254,12 +279,12 @@ def __str__(self) -> str:
254279 f" allow_nones={ self ._allow_nones } )"
255280 )
256281
257- def set_prefetch_buffer_size (self , buffer_size : int ):
258- self ._prefetch_buffer_size = buffer_size
282+ def set_target_buffer_size (self , buffer_size : int ):
283+ self ._target_buffer_size = buffer_size
259284 # The executor is created in the constructor only if the prefetch buffer
260285 # size is greater than 0. If the user changes the prefetch buffer size, we
261286 # need to create or destroy the executor accordingly.
262- if self ._prefetch_buffer_size > 0 and not hasattr (self , "_executor" ):
287+ if self ._target_buffer_size > 0 and not hasattr (self , "_executor" ):
263288 if self ._num_threads == 0 :
264289 raise ValueError (
265290 "num_threads must be greater than 0 when prefetch buffer size is"
@@ -268,7 +293,7 @@ def set_prefetch_buffer_size(self, buffer_size: int):
268293 self ._executor = futures .ThreadPoolExecutor (
269294 self ._num_threads , thread_name_prefix = "grain-prefetch"
270295 )
271- elif self ._prefetch_buffer_size == 0 and hasattr (self , "_executor" ):
296+ elif self ._target_buffer_size == 0 and hasattr (self , "_executor" ):
272297 self ._executor .shutdown ()
273298 delattr (self , "_executor" )
274299
@@ -292,21 +317,22 @@ def set_num_threads(self, num_threads: int) -> None:
292317
293318 def _fill_buffer (self ):
294319 while (
295- len (self ._buffer ) < self ._prefetch_buffer_size
320+ len (self ._buffer ) < self ._target_buffer_size
296321 and self ._next_buffered_index < self ._dataset_length
297322 ):
298323 # Note that we trigger creation of `_stats` in this (single) thread, it is
299324 # important because the stats initialization is not thread-safe.
300- self ._buffer .append (
325+ self ._buffer .append ((
301326 self ._executor .submit (
302327 functools .partial (_getitem , self ._stats , self ._map_parent ),
303328 self ._next_buffered_index ,
304- )
305- )
329+ ),
330+ 1 , # batch_size = 1 when batch pushdown is not used.
331+ ))
306332 self ._next_buffered_index += 1
307333
308334 def start_prefetch (self ):
309- if self ._prefetch_buffer_size > 0 :
335+ if self ._target_buffer_size > 0 :
310336 self ._fill_buffer ()
311337
312338 def close (self ) -> None :
@@ -319,10 +345,122 @@ def close(self) -> None:
319345 self ._executor .shutdown (wait = False )
320346 # Cancel all pending futures in the buffer.
321347 while self ._buffer :
322- future = self ._buffer .popleft ()
348+ future , _ = self ._buffer .popleft ()
323349 future .cancel ()
324350
325351
class _BatchedPrefetchDatasetIterator(PrefetchDatasetIterator[T]):
  """Iterator that performs prefetching in batches using a thread pool.

  Unlike `PrefetchDatasetIterator`, which submits one `_getitem` call per
  element, this iterator submits `_getitems` calls covering a batch of
  indices, amortizing per-element scheduling overhead across the batch.
  Buffer entries are `(future, batch_size)` tuples where the future resolves
  to a list of elements.
  """

  def __init__(
      self,
      parent: dataset.MapDataset[T],
      read_options: grain_options.ReadOptions,
      allow_nones: bool,
  ):
    super().__init__(parent, read_options, allow_nones)
    # The number of elements to prefetch in each batch: the target buffer
    # size spread evenly over the threads, rounded up (at least 1).
    self._batch_pushdown_size = (
        int(math.ceil(self._target_buffer_size / self._num_threads))
        if self._target_buffer_size > 0
        else 1
    )
    # Queue of elements from the most recently completed prefetch batch.
    self._current_batch = collections.deque()
    # Total count of elements across all pending futures in _buffer.
    self._total_buffered_count = 0

  @dataset_stats.record_next_duration_if_output
  @dataset_stats.trace_input_pipeline_next(
      stage_category=dataset_stats.IPL_CAT_PREFETCH
  )
  def __next__(self) -> T:
    """Returns the next element, draining one prefetched batch at a time."""
    self._assert_not_closed()
    # The time recorded here is the time spent in prefetch node to return an
    # element, including the time spent in parent node.
    timer = dataset_stats.Timer()
    # We loop here to skip all None elements (in case the underlying dataset
    # is sparse), if self._allow_nones = False, else we return Nones too.
    while True:
      with self._lock, timer:
        # Only fetch a new batch once the current one is exhausted.
        if not self._current_batch:
          if self._next_returned_index == self._dataset_length:
            break
          if self._target_buffer_size > 0:
            if not self._buffer:
              # Fill the buffer on the first iteration.
              self._fill_buffer()
            future, batch_size = self._buffer.popleft()
            self._total_buffered_count -= batch_size
            # Prefetch elements until the buffer is full again.
            self._fill_buffer()
            batch = future.result()
            self._current_batch.extend(batch)
          else:
            # In case prefetch buffer size was decreased, we still want to
            # consume the already prefetched elements.
            if self._buffer:
              future, batch_size = self._buffer.popleft()
              batch = future.result()
              self._total_buffered_count -= batch_size
              self._current_batch.extend(batch)
            else:
              # Prefetching disabled and nothing buffered: fetch the element
              # synchronously on the consumer thread.
              element = self._stats.record_bytes_consumed(
                  self._map_parent[self._next_returned_index]
              )
              self._next_buffered_index += 1
              self._current_batch.append(element)

        element = self._current_batch.popleft()
        self._next_returned_index += 1

      return_element = self._allow_nones or element is not None
      self._threshold_checker.check(return_element)
      if return_element:
        with self._stats.record_self_time(offset_ns=timer.value()):
          element = self._stats.record_bytes_produced(element)
          return self._stats.record_output_spec(element)
    raise StopIteration

  def _set_state_helper(self, state):
    """Restores position and discards any batched/buffered elements."""
    super()._set_state_helper(state)
    # Elements already decoded into the current batch belong to the old
    # position; drop them along with the pending-future accounting.
    self._current_batch.clear()
    self._total_buffered_count = 0

  def _fill_buffer(self):
    """Submits batched fetches until `_target_buffer_size` elements pend."""
    while (
        self._total_buffered_count < self._target_buffer_size
        and self._next_buffered_index < self._dataset_length
    ):
      # Clamp the batch to both the remaining dataset and the remaining
      # buffer budget so we never over-schedule.
      batch_size = min(
          self._batch_pushdown_size,
          self._dataset_length - self._next_buffered_index,
          self._target_buffer_size - self._total_buffered_count,
      )
      indices = list(
          range(
              self._next_buffered_index, self._next_buffered_index + batch_size
          )
      )
      # Note that we trigger creation of `_stats` in this (single) thread, it is
      # important because the stats initialization is not thread-safe.
      self._buffer.append((
          self._executor.submit(
              functools.partial(_getitems, self._stats, self._map_parent),
              indices,
          ),
          batch_size,
      ))
      self._next_buffered_index += batch_size
      self._total_buffered_count += batch_size

  def __str__(self) -> str:
    return (
        f"_BatchedPrefetchDatasetIterator(read_options={self._read_options},"
        f" allow_nones={self._allow_nones})"
    )
462+
463+
326464def _set_slice_iter_dataset (
327465 ds : dataset .IterDataset ,
328466 sl : slice ,
@@ -729,6 +867,7 @@ def is_prefetch_iterator(it: dataset.DatasetIterator) -> bool:
729867 it ,
730868 (
731869 PrefetchDatasetIterator ,
870+ _BatchedPrefetchDatasetIterator ,
732871 ThreadPrefetchDatasetIterator ,
733872 interleave .InterleaveDatasetIterator ,
734873 ),
0 commit comments