
Commit 1a0704e

Grain Team authored and copybara-github committed
Internal
PiperOrigin-RevId: 871476435
1 parent e05a18d commit 1a0704e

2 files changed: +186 −47 lines


grain/_src/python/dataset/transformations/prefetch.py

Lines changed: 170 additions & 31 deletions
@@ -19,6 +19,7 @@
 from collections.abc import Iterator, Sequence
 import copy
 import functools
+import math
 from multiprocessing import queues
 import queue
 import threading
@@ -69,6 +70,10 @@ def _initialize_prefetch_stats(
   )


+def _is_batch_iter_pushdown_experiment_enabled() -> bool:
+  return False
+
+
 @dataset_stats.trace_input_pipeline_prefetch
 def _getitem(
     stats: dataset_stats.Stats, parent: dataset.MapDataset[T], index: int
@@ -77,6 +82,16 @@ def _getitem(
   return stats.record_bytes_consumed(parent[index])


+@dataset_stats.trace_input_pipeline_prefetch
+def _getitems(
+    stats: dataset_stats.Stats,
+    parent: dataset.MapDataset[T],
+    indices: list[int],
+) -> list[T]:
+  """Helper to record the memory usage of the elements before prefetching."""
+  return [stats.record_bytes_consumed(x) for x in parent._getitems(indices)]  # pylint: disable=protected-access
+
+
 @typing.runtime_checkable
 class SupportsInPlaceSlicing(Protocol):
   """Datasets that support mutation by setting the processed data slice."""
@@ -114,6 +129,10 @@ def __str__(self) -> str:
     )

   def __iter__(self) -> dataset.DatasetIterator[T]:
+    if _is_batch_iter_pushdown_experiment_enabled():
+      return _BatchedPrefetchDatasetIterator(
+          self._parent, self._read_options, self._allow_nones
+      )
     return PrefetchDatasetIterator(
         self._parent, self._read_options, self._allow_nones
     )
@@ -141,14 +160,15 @@ def __init__(
     self._read_options = read_options
     self._next_returned_index = 0
     self._next_buffered_index = 0
+    # Buffer of (future, batch_size) tuples for prefetched elements.
     self._buffer = collections.deque()
     self._lock = threading.Lock()
-    self._prefetch_buffer_size = (
+    self._target_buffer_size = (
         read_options.prefetch_buffer_size if read_options.num_threads > 0 else 0
     )
     self._num_threads = read_options.num_threads
     self._allow_nones = allow_nones
-    if self._prefetch_buffer_size > 0:
+    if self._target_buffer_size > 0:
       self._executor = futures.ThreadPoolExecutor(
           self._num_threads, thread_name_prefix="grain-prefetch"
       )
@@ -194,25 +214,27 @@ def __next__(self) -> T:
       if self._next_returned_index == self._dataset_length:
         break
       with self._lock, timer:
-        if self._prefetch_buffer_size > 0:
+        if self._target_buffer_size > 0:
           if not self._buffer:
             # Fill the buffer on the first iteration.
             self._fill_buffer()
-          element = self._buffer.popleft()
+          future, _ = self._buffer.popleft()
           # Prefetch elements until the buffer is full again.
           self._fill_buffer()
-          element = element.result()
+          element = future.result()
         else:
           # In case prefetch buffer size was decreased, we still want to consume
           # the already prefetched elements.
           if self._buffer:
-            element = self._buffer.popleft().result()
+            future, _ = self._buffer.popleft()
+            element = future.result()
           else:
             element = self._stats.record_bytes_consumed(
                 self._map_parent[self._next_returned_index]
             )
             self._next_buffered_index += 1
         self._next_returned_index += 1
+
       return_element = self._allow_nones or element is not None
       self._threshold_checker.check(return_element)
       if return_element:
@@ -224,23 +246,26 @@ def __next__(self) -> T:
   def get_state(self):
     return {"next_index": self._next_returned_index}

+  def _set_state_helper(self, state):
+    self._next_returned_index = state["next_index"]
+    self._next_buffered_index = self._next_returned_index
+    if (
+        self._next_returned_index < 0
+        or self._next_returned_index > self._dataset_length
+    ):
+      raise IndexError(
+          f"Checkpoint `next_index` {self._next_returned_index} is out of"
+          f" range for dataset of length {self._dataset_length}."
+      )
+    if self._target_buffer_size > 0:
+      # Cancel all pending futures in the buffer.
+      while self._buffer:
+        future, _ = self._buffer.popleft()
+        future.cancel()
+
   def set_state(self, state):
     with self._lock:
-      self._next_returned_index = state["next_index"]
-      self._next_buffered_index = self._next_returned_index
-      if (
-          self._next_returned_index < 0
-          or self._next_returned_index > self._dataset_length
-      ):
-        raise IndexError(
-            f"Checkpoint `next_index` {self._next_returned_index} is out of"
-            f" range for dataset of length {self._dataset_length}."
-        )
-      if self._prefetch_buffer_size > 0:
-        # Cancel all pending futures in the buffer.
-        while self._buffer:
-          future = self._buffer.popleft()
-          future.cancel()
+      self._set_state_helper(state)

   def _get_next_index(self) -> int:
     return self._next_returned_index
@@ -254,12 +279,12 @@ def __str__(self) -> str:
         f" allow_nones={self._allow_nones})"
     )

-  def set_prefetch_buffer_size(self, buffer_size: int):
-    self._prefetch_buffer_size = buffer_size
+  def set_target_buffer_size(self, buffer_size: int):
+    self._target_buffer_size = buffer_size
     # The executor is created in the constructor only if the prefetch buffer
     # size is greater than 0. If the user changes the prefetch buffer size, we
     # need to create or destroy the executor accordingly.
-    if self._prefetch_buffer_size > 0 and not hasattr(self, "_executor"):
+    if self._target_buffer_size > 0 and not hasattr(self, "_executor"):
       if self._num_threads == 0:
         raise ValueError(
             "num_threads must be greater than 0 when prefetch buffer size is"
@@ -268,7 +293,7 @@ def set_prefetch_buffer_size(self, buffer_size: int):
       self._executor = futures.ThreadPoolExecutor(
           self._num_threads, thread_name_prefix="grain-prefetch"
       )
-    elif self._prefetch_buffer_size == 0 and hasattr(self, "_executor"):
+    elif self._target_buffer_size == 0 and hasattr(self, "_executor"):
       self._executor.shutdown()
       delattr(self, "_executor")

@@ -292,21 +317,22 @@ def set_num_threads(self, num_threads: int) -> None:

   def _fill_buffer(self):
     while (
-        len(self._buffer) < self._prefetch_buffer_size
+        len(self._buffer) < self._target_buffer_size
         and self._next_buffered_index < self._dataset_length
     ):
       # Note that we trigger creation of `_stats` in this (single) thread, it is
       # important because the stats initialization is not thread-safe.
-      self._buffer.append(
+      self._buffer.append((
           self._executor.submit(
               functools.partial(_getitem, self._stats, self._map_parent),
               self._next_buffered_index,
-          )
-      )
+          ),
+          1,  # batch_size = 1 when batch pushdown is not used.
+      ))
       self._next_buffered_index += 1

   def start_prefetch(self):
-    if self._prefetch_buffer_size > 0:
+    if self._target_buffer_size > 0:
       self._fill_buffer()

   def close(self) -> None:
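As a rough standalone illustration (assumed names, not code from this commit): after this change the unbatched `_fill_buffer` stores `(future, batch_size)` tuples with `batch_size` fixed at 1, so the buffer layout matches what the batched subclass added below expects.

import collections
from concurrent import futures

executor = futures.ThreadPoolExecutor(max_workers=2)
buffer = collections.deque()
for index in range(4):
  # Each entry pairs a future with the number of elements it will yield (1 here).
  buffer.append((executor.submit(lambda i: i * i, index), 1))

future, batch_size = buffer.popleft()
assert batch_size == 1
print(future.result())  # 0
executor.shutdown()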
@@ -319,10 +345,122 @@ def close(self) -> None:
       self._executor.shutdown(wait=False)
       # Cancel all pending futures in the buffer.
       while self._buffer:
-        future = self._buffer.popleft()
+        future, _ = self._buffer.popleft()
         future.cancel()


+class _BatchedPrefetchDatasetIterator(PrefetchDatasetIterator[T]):
+  """Iterator that performs prefetching in batches using a thread pool."""
+
+  def __init__(
+      self,
+      parent: dataset.MapDataset[T],
+      read_options: grain_options.ReadOptions,
+      allow_nones: bool,
+  ):
+    super().__init__(parent, read_options, allow_nones)
+    # The number of elements to prefetch in each batch.
+    self._batch_pushdown_size = (
+        int(math.ceil(self._target_buffer_size / self._num_threads))
+        if self._target_buffer_size > 0
+        else 1
+    )
+    # Queue of elements from the most recently completed prefetch batch.
+    self._current_batch = collections.deque()
+    # Total count of elements across all pending futures in _buffer.
+    self._total_buffered_count = 0
+
+  @dataset_stats.record_next_duration_if_output
+  @dataset_stats.trace_input_pipeline_next(
+      stage_category=dataset_stats.IPL_CAT_PREFETCH
+  )
+  def __next__(self) -> T:
+    self._assert_not_closed()
+    # The time recorded here is the time spent in prefetch node to return an
+    # element, including the time spent in parent node.
+    timer = dataset_stats.Timer()
+    # We loop here to skip all None elements (in case the underlying dataset
+    # is sparse), if self._allow_nones = False, else we return Nones too.
+    while True:
+      with self._lock, timer:
+        if not self._current_batch:
+          if self._next_returned_index == self._dataset_length:
+            break
+          if self._target_buffer_size > 0:
+            if not self._buffer:
+              # Fill the buffer on the first iteration.
+              self._fill_buffer()
+            future, batch_size = self._buffer.popleft()
+            self._total_buffered_count -= batch_size
+            # Prefetch elements until the buffer is full again.
+            self._fill_buffer()
+            batch = future.result()
+            self._current_batch.extend(batch)
+          else:
+            # In case prefetch buffer size was decreased, we still want to
+            # consume the already prefetched elements.
+            if self._buffer:
+              future, batch_size = self._buffer.popleft()
+              batch = future.result()
+              self._total_buffered_count -= batch_size
+              self._current_batch.extend(batch)
+            else:
+              element = self._stats.record_bytes_consumed(
+                  self._map_parent[self._next_returned_index]
+              )
+              self._next_buffered_index += 1
+              self._current_batch.append(element)
+
+        element = self._current_batch.popleft()
+        self._next_returned_index += 1
+
+      return_element = self._allow_nones or element is not None
+      self._threshold_checker.check(return_element)
+      if return_element:
+        with self._stats.record_self_time(offset_ns=timer.value()):
+          element = self._stats.record_bytes_produced(element)
+          return self._stats.record_output_spec(element)
+    raise StopIteration
+
+  def _set_state_helper(self, state):
+    super()._set_state_helper(state)
+    self._current_batch.clear()
+    self._total_buffered_count = 0
+
+  def _fill_buffer(self):
+    while (
+        self._total_buffered_count < self._target_buffer_size
+        and self._next_buffered_index < self._dataset_length
+    ):
+      batch_size = min(
+          self._batch_pushdown_size,
+          self._dataset_length - self._next_buffered_index,
+          self._target_buffer_size - self._total_buffered_count,
+      )
+      indices = list(
+          range(
+              self._next_buffered_index, self._next_buffered_index + batch_size
+          )
+      )
+      # Note that we trigger creation of `_stats` in this (single) thread, it is
+      # important because the stats initialization is not thread-safe.
+      self._buffer.append((
+          self._executor.submit(
+              functools.partial(_getitems, self._stats, self._map_parent),
+              indices,
+          ),
+          batch_size,
+      ))
+      self._next_buffered_index += batch_size
+      self._total_buffered_count += batch_size
+
+  def __str__(self) -> str:
+    return (
+        f"_BatchedPrefetchDatasetIterator(read_options={self._read_options},"
+        f" allow_nones={self._allow_nones})"
+    )
+
+
 def _set_slice_iter_dataset(
     ds: dataset.IterDataset,
     sl: slice,
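A small self-contained sketch of the batch-sizing arithmetic above (the numbers are illustrative assumptions, not values from the commit): with a target buffer of 500 elements and 16 threads, `_batch_pushdown_size` is ceil(500 / 16) = 32, so `_fill_buffer` submits 15 futures of 32 indices and one final future of 20.

import math


def plan_batches(
    target_buffer_size: int,
    num_threads: int,
    dataset_length: int,
    next_buffered_index: int = 0,
) -> list[int]:
  """Mirrors the _fill_buffer sizing logic: returns per-future batch sizes."""
  batch_pushdown_size = (
      int(math.ceil(target_buffer_size / num_threads))
      if target_buffer_size > 0
      else 1
  )
  sizes = []
  total_buffered = 0
  while total_buffered < target_buffer_size and next_buffered_index < dataset_length:
    batch_size = min(
        batch_pushdown_size,
        dataset_length - next_buffered_index,
        target_buffer_size - total_buffered,
    )
    sizes.append(batch_size)
    next_buffered_index += batch_size
    total_buffered += batch_size
  return sizes


# Prints 15 batches of 32 plus a final batch of 20 (sums to 500).
print(plan_batches(target_buffer_size=500, num_threads=16, dataset_length=10_000))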
@@ -729,6 +867,7 @@ def is_prefetch_iterator(it: dataset.DatasetIterator) -> bool:
       it,
       (
           PrefetchDatasetIterator,
+          _BatchedPrefetchDatasetIterator,
           ThreadPrefetchDatasetIterator,
           interleave.InterleaveDatasetIterator,
       ),
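A rough usage sketch (not from this commit), assuming Grain's public `MapDataset` / `ReadOptions` API: iteration below goes through the prefetch iterator touched here, while the batched variant stays off because `_is_batch_iter_pushdown_experiment_enabled()` returns False in this change.

import grain

ds = grain.MapDataset.range(1_000).map(lambda x: x * x)
it = iter(
    ds.to_iter_dataset(
        read_options=grain.ReadOptions(num_threads=16, prefetch_buffer_size=500)
    )
)
print(next(it))  # 0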
