Commit 2038fb9

fix: support batch expression everywhere
1 parent 62a609a commit 2038fb9
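
The change lets batch sizes be given as expressions such as "10 words" wherever batching is configured. A minimal usage sketch, assuming the edsnlp.data.from_iterable reader and standard Stream iteration (the texts are illustrative only):

import edsnlp

nlp = edsnlp.blank("eds")
texts = ["Patient admitted for chest pain.", "No relevant history."]

stream = edsnlp.data.from_iterable(texts)   # assumed reader; any Stream works
stream = stream.map_pipeline(nlp)
stream = stream.set_processing(
    num_cpu_workers=2,
    # batching expression, standing in for batch_size=10, batch_by="words"
    batch_size="10 words",
)
docs = list(stream)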

File tree (2 files changed: +66 additions, -79 deletions)

edsnlp/core/stream.py
tests/data/test_stream.py

edsnlp/core/stream.py

Lines changed: 53 additions & 68 deletions
@@ -25,7 +25,7 @@
 from typing_extensions import Literal
 
 import edsnlp.data
-from edsnlp.utils.batching import BatchBy, BatchFn, BatchSizeArg, batchify_fns
+from edsnlp.utils.batching import BatchBy, BatchFn, BatchSizeArg, batchify, batchify_fns
 from edsnlp.utils.collections import flatten, flatten_once, shuffle
 from edsnlp.utils.stream_sentinels import StreamSentinel
 
@@ -47,25 +47,6 @@ def deep_isgeneratorfunction(x):
     raise ValueError(f"{x} does not have a __call__ or batch_process method.")
 
 
-class _InferType:
-    # Singleton is important since the INFER object may be passed to
-    # other processes, i.e. pickled, depickled, while it should
-    # always be the same object.
-    instance = None
-
-    def __repr__(self):
-        return "INFER"
-
-    def __new__(cls, *args, **kwargs):
-        if cls.instance is None:
-            cls.instance = super().__new__(cls)
-        return cls.instance
-
-    def __bool__(self):
-        return False
-
-
-INFER = _InferType()
 CONTEXT = [{}]
 
 T = TypeVar("T")
@@ -125,8 +106,8 @@ def __init__(
     ):
         if batch_fn is None:
             if size is None:
-                size = INFER
-                batch_fn = INFER
+                size = None
+                batch_fn = None
             else:
                 batch_fn = batchify_fns["docs"]
         self.size = size
@@ -287,12 +268,12 @@ def __init__(
         reader: Optional[BaseReader] = None,
         writer: Optional[Union[BaseWriter, BatchWriter]] = None,
         ops: List[Any] = [],
-        config: Dict = {},
+        config: Optional[Dict] = None,
     ):
         self.reader = reader
         self.writer = writer
         self.ops: List[Op] = ops
-        self.config = config
+        self.config = config or {}
 
     @classmethod
     def validate_batching(cls, batch_size, batch_by):
@@ -302,17 +283,12 @@ def validate_batching(cls, batch_size, batch_by):
                 "Cannot use both a batch_size expression and a batch_by function"
             )
         batch_size, batch_by = BatchSizeArg.validate(batch_size)
-        if (
-            batch_size is not None
-            and batch_size is not INFER
-            and not isinstance(batch_size, int)
-        ):
+        if batch_size is not None and not isinstance(batch_size, int):
             raise ValueError(
                 f"Invalid batch_size (must be an integer or None): {batch_size}"
             )
         if (
             batch_by is not None
-            and batch_by is not INFER
             and batch_by not in batchify_fns
             and not callable(batch_by)
         ):
@@ -321,11 +297,11 @@ def validate_batching(cls, batch_size, batch_by):
 
     @property
     def batch_size(self):
-        return self.config.get("batch_size", 1)
+        return self.config.get("batch_size", None)
 
     @property
    def batch_by(self):
-        return self.config.get("batch_by", "docs")
+        return self.config.get("batch_by", None)
 
     @property
     def disable_implicit_parallelism(self):
@@ -372,39 +348,36 @@ def deterministic(self):
     @with_non_default_args
     def set_processing(
         self,
-        batch_size: int = INFER,
-        batch_by: BatchBy = "docs",
-        split_into_batches_after: str = INFER,
-        num_cpu_workers: Optional[int] = INFER,
-        num_gpu_workers: Optional[int] = INFER,
+        batch_size: Optional[Union[int, str]] = None,
+        batch_by: BatchBy = None,
+        split_into_batches_after: str = None,
+        num_cpu_workers: Optional[int] = None,
+        num_gpu_workers: Optional[int] = None,
         disable_implicit_parallelism: bool = True,
-        backend: Optional[Literal["simple", "multiprocessing", "mp", "spark"]] = INFER,
-        autocast: Union[bool, Any] = INFER,
+        backend: Optional[Literal["simple", "multiprocessing", "mp", "spark"]] = None,
+        autocast: Union[bool, Any] = None,
         show_progress: bool = False,
-        gpu_pipe_names: Optional[List[str]] = INFER,
-        process_start_method: Optional[Literal["fork", "spawn"]] = INFER,
-        gpu_worker_devices: Optional[List[str]] = INFER,
-        cpu_worker_devices: Optional[List[str]] = INFER,
+        gpu_pipe_names: Optional[List[str]] = None,
+        process_start_method: Optional[Literal["fork", "spawn"]] = None,
+        gpu_worker_devices: Optional[List[str]] = None,
+        cpu_worker_devices: Optional[List[str]] = None,
         deterministic: bool = True,
         work_unit: Literal["record", "fragment"] = "record",
-        chunk_size: int = INFER,
+        chunk_size: int = None,
         sort_chunks: bool = False,
         _non_default_args: Iterable[str] = (),
     ) -> "Stream":
         """
         Parameters
         ----------
-        batch_size: int
-            Number of documents to process at a time in a GPU worker (or in the
-            main process if no workers are used). This is the global batch size
-            that is used for batching methods that do not provide their own
-            batching arguments.
+        batch_size: Optional[Union[int, str]]
+            The batch size. Can also be a batching expression like
+            "32 docs", "1024 words", "dataset", "fragment", etc.
         batch_by: BatchBy
             Function to compute the batches. If set, it should take an iterable of
             documents and return an iterable of batches. You can also set it to
             "docs", "words" or "padded_words" to use predefined batching functions.
-            Defaults to "docs". Only used for operations that do not provide their
-            own batching arguments.
+            Defaults to "docs".
         num_cpu_workers: int
             Number of CPU workers. A CPU worker handles the non deep-learning components
             and the preprocessing, collating and postprocessing of deep-learning
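
The updated docstring implies that an integer size plus a batch_by name and a single expression string are interchangeable. An illustrative pair of calls (my reading of the docstring, not a quoted guarantee):

# Both configure 32-document batches; the second uses the expression form.
stream_a = stream.set_processing(batch_size=32, batch_by="docs")
stream_b = stream.set_processing(batch_size="32 docs")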
@@ -468,15 +441,15 @@ def set_processing(
         """
         kwargs = {k: v for k, v in locals().items() if k in _non_default_args}
         if (
-            kwargs.pop("chunk_size", INFER) is not INFER
-            or kwargs.pop("sort_chunks", INFER) is not INFER
+            kwargs.pop("chunk_size", None) is not None
+            or kwargs.pop("sort_chunks", None) is not None
         ):
             warnings.warn(
                 "chunk_size and sort_chunks are deprecated, use "
                 "map_batched(sort_fn, batch_size=chunk_size) instead.",
                 VisibleDeprecationWarning,
             )
-        if kwargs.pop("split_into_batches_after", INFER) is not INFER:
+        if kwargs.pop("split_into_batches_after", None) is not None:
             warnings.warn(
                 "split_into_batches_after is deprecated.", VisibleDeprecationWarning
             )
@@ -486,7 +459,7 @@ def set_processing(
             ops=self.ops,
             config={
                 **self.config,
-                **{k: v for k, v in kwargs.items() if v is not INFER},
+                **{k: v for k, v in kwargs.items() if v is not None},
             },
         )
 
@@ -690,8 +663,8 @@ def map_gpu(
     def map_pipeline(
         self,
         model: Pipeline,
-        batch_size: Optional[int] = INFER,
-        batch_by: BatchBy = INFER,
+        batch_size: Optional[Union[int, str]] = None,
+        batch_by: BatchBy = None,
     ) -> "Stream":
         """
         Maps a pipeline to the documents, i.e. adds each component of the pipeline to
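
Per the new signature, map_pipeline also accepts the expression form for its own batch_size; stream-level settings are only used when the operation leaves batching unset (see validate_ops further down). A short illustrative call, assuming the nlp and stream objects from the sketch above:

# ~1024-word batches for this pipeline only; other ops keep their own settings
stream = stream.map_pipeline(nlp, batch_size="1024 words")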
@@ -974,16 +947,10 @@ def __getattr__(self, item):
     def _make_stages(self, split_torch_pipes: bool) -> List[Stage]:
         current_ops = []
         stages = []
-        self_batch_fn = batchify_fns.get(self.batch_by, self.batch_by)
-        self_batch_size = self.batch_size
-        assert self_batch_size is not None
 
         ops = [copy(op) for op in self.ops]
 
         for op in ops:
-            if isinstance(op, BatchifyOp):
-                op.batch_fn = self_batch_fn if op.batch_fn is INFER else op.batch_fn
-                op.size = self_batch_size if op.size is INFER else op.size
             if (
                 isinstance(op, MapBatchesOp)
                 and hasattr(op.pipe, "forward")
@@ -1005,23 +972,39 @@ def validate_ops(self, ops, update: bool = False):
         # Check batchify requirements
         requires_sentinels = set()
 
+        self_batch_size, self_batch_by = self.validate_batching(
+            self.batch_size, self.batch_by
+        )
+        if self_batch_by is None:
+            self_batch_by = "docs"
+        if self_batch_size is None:
+            self_batch_size = 1
+        self_batch_fn = batchify_fns.get(self_batch_by, self_batch_by)
+
         if hasattr(self.writer, "batch_fn") and hasattr(
             self.writer.batch_fn, "requires_sentinel"
         ):
             requires_sentinels.add(self.writer.batch_fn.requires_sentinel)
 
-        self_batch_fn = batchify_fns.get(self.batch_by, self.batch_by)
         for op in reversed(ops):
             if isinstance(op, BatchifyOp):
-                batch_fn = op.batch_fn or self_batch_fn
+                if op.batch_fn is None and op.size is None:
+                    batch_size = self_batch_size
+                    batch_fn = self_batch_fn
+                elif op.batch_fn is None:
+                    batch_size = op.size
+                    batch_fn = batchify
+                else:
+                    batch_size = op.size
+                    batch_fn = op.batch_fn
                 sentinel_mode = op.sentinel_mode or (
                     "auto"
                     if "sentinel_mode" in signature(batch_fn).parameters
                     else None
                 )
                 if sentinel_mode == "auto":
                     sentinel_mode = "split" if requires_sentinels else "drop"
-                if requires_sentinels and op.sentinel_mode == "drop":
+                if requires_sentinels and sentinel_mode == "drop":
                     raise ValueError(
                         f"Operation {op} drops the stream sentinel values "
                         f"(markers for the end of a dataset or a dataset "
@@ -1031,10 +1014,12 @@ def validate_ops(self, ops, update: bool = False):
                         f"any upstream batching operation."
                     )
                 if update:
+                    op.size = batch_size
+                    op.batch_fn = batch_fn
                     op.sentinel_mode = sentinel_mode
 
-                if hasattr(batch_fn, "requires_sentinel"):
-                    requires_sentinels.add(batch_fn.requires_sentinel)
+                if hasattr(op.batch_fn, "requires_sentinel"):
+                    requires_sentinels.add(op.batch_fn.requires_sentinel)
 
         sentinel_str = ", ".join(requires_sentinels)
         if requires_sentinels and self.backend == "spark":
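
For each BatchifyOp, validate_ops now resolves batching with a three-way fallback: the operation's own settings win, a bare size falls back to plain count batching, and stream-level defaults ("docs", size 1) apply only when the operation sets nothing. A condensed, illustrative restatement of those branches (assuming batchify from edsnlp.utils.batching is the fixed-count batching helper):

from edsnlp.utils.batching import batchify, batchify_fns  # as in the import change above

def resolve_batching(op_size, op_batch_fn, stream_size, stream_by):
    """Mirror of the fallback added in validate_ops (illustrative only)."""
    if op_batch_fn is None and op_size is None:
        # Nothing set on the op: use stream-level settings, defaulting to 1 "docs".
        by = stream_by if stream_by is not None else "docs"
        size = stream_size if stream_size is not None else 1
        return size, batchify_fns.get(by, by)
    if op_batch_fn is None:
        # Only a size was given: fall back to plain count-based batching.
        return op_size, batchify
    # The op fully specifies its batching.
    return op_size, op_batch_fn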

tests/data/test_stream.py

Lines changed: 13 additions & 11 deletions
@@ -58,19 +58,22 @@ def forward(batch):
     assert set(res.tolist()) == {i * 2 for i in range(15)}
 
 
+# fmt: off
 @pytest.mark.parametrize(
-    "sort,num_cpu_workers,batch_by,expected",
+    "sort,num_cpu_workers,batch_kwargs,expected",
     [
-        (False, 1, "words", [3, 1, 3, 1, 3, 1]),
-        (False, 1, "padded_words", [2, 1, 1, 2, 1, 1, 2, 1, 1]),
-        (False, 1, "docs", [10, 2]),
-        (False, 2, "words", [2, 1, 2, 1, 2, 1, 1, 1, 1]),
-        (False, 2, "padded_words", [2, 1, 2, 1, 2, 1, 1, 1, 1]),
-        (False, 2, "docs", [6, 6]),
-        (True, 2, "padded_words", [3, 3, 2, 1, 1, 1, 1]),
+        (False, 1, {"batch_size": 10, "batch_by": "words"}, [3, 1, 3, 1, 3, 1]),  # noqa: E501
+        (False, 1, {"batch_size": 10, "batch_by": "padded_words"}, [2, 1, 1, 2, 1, 1, 2, 1, 1]),  # noqa: E501
+        (False, 1, {"batch_size": 10, "batch_by": "docs"}, [10, 2]),  # noqa: E501
+        (False, 2, {"batch_size": 10, "batch_by": "words"}, [2, 1, 2, 1, 2, 1, 1, 1, 1]),  # noqa: E501
+        (False, 2, {"batch_size": 10, "batch_by": "padded_words"}, [2, 1, 2, 1, 2, 1, 1, 1, 1]),  # noqa: E501
+        (False, 2, {"batch_size": 10, "batch_by": "docs"}, [6, 6]),  # noqa: E501
+        (True, 2, {"batch_size": 10, "batch_by": "padded_words"}, [3, 3, 2, 1, 1, 1, 1]),  # noqa: E501
+        (False, 2, {"batch_size": "10 words"}, [2, 1, 2, 1, 2, 1, 1, 1, 1]),  # noqa: E501
     ],
 )
-def test_map_with_batching(sort, num_cpu_workers, batch_by, expected):
+# fmt: on
+def test_map_with_batching(sort, num_cpu_workers, batch_kwargs, expected):
     nlp = edsnlp.blank("eds")
     nlp.add_pipe(
         "eds.matcher",
@@ -94,8 +97,7 @@ def test_map_with_batching(sort, num_cpu_workers, batch_by, expected):
     stream = stream.map_batches(len)
     stream = stream.set_processing(
         num_cpu_workers=num_cpu_workers,
-        batch_size=10,
-        batch_by=batch_by,
+        **batch_kwargs,
         chunk_size=1000,  # deprecated
         split_into_batches_after="matcher",
         show_progress=True,
