25
25
from typing_extensions import Literal
26
26
27
27
import edsnlp .data
28
- from edsnlp .utils .batching import BatchBy , BatchFn , BatchSizeArg , batchify_fns
28
+ from edsnlp .utils .batching import BatchBy , BatchFn , BatchSizeArg , batchify , batchify_fns
29
29
from edsnlp .utils .collections import flatten , flatten_once , shuffle
30
30
from edsnlp .utils .stream_sentinels import StreamSentinel
31
31
@@ -47,25 +47,6 @@ def deep_isgeneratorfunction(x):
47
47
raise ValueError (f"{ x } does not have a __call__ or batch_process method." )
48
48
49
49
50
- class _InferType :
51
- # Singleton is important since the INFER object may be passed to
52
- # other processes, i.e. pickled, depickled, while it should
53
- # always be the same object.
54
- instance = None
55
-
56
- def __repr__ (self ):
57
- return "INFER"
58
-
59
- def __new__ (cls , * args , ** kwargs ):
60
- if cls .instance is None :
61
- cls .instance = super ().__new__ (cls )
62
- return cls .instance
63
-
64
- def __bool__ (self ):
65
- return False
66
-
67
-
68
- INFER = _InferType ()
69
50
CONTEXT = [{}]
70
51
71
52
T = TypeVar ("T" )
@@ -125,8 +106,8 @@ def __init__(
125
106
):
126
107
if batch_fn is None :
127
108
if size is None :
128
- size = INFER
129
- batch_fn = INFER
109
+ size = None
110
+ batch_fn = None
130
111
else :
131
112
batch_fn = batchify_fns ["docs" ]
132
113
self .size = size
@@ -287,12 +268,12 @@ def __init__(
287
268
reader : Optional [BaseReader ] = None ,
288
269
writer : Optional [Union [BaseWriter , BatchWriter ]] = None ,
289
270
ops : List [Any ] = [],
290
- config : Dict = {} ,
271
+ config : Optional [ Dict ] = None ,
291
272
):
292
273
self .reader = reader
293
274
self .writer = writer
294
275
self .ops : List [Op ] = ops
295
- self .config = config
276
+ self .config = config or {}
296
277
297
278
@classmethod
298
279
def validate_batching (cls , batch_size , batch_by ):
@@ -302,17 +283,12 @@ def validate_batching(cls, batch_size, batch_by):
302
283
"Cannot use both a batch_size expression and a batch_by function"
303
284
)
304
285
batch_size , batch_by = BatchSizeArg .validate (batch_size )
305
- if (
306
- batch_size is not None
307
- and batch_size is not INFER
308
- and not isinstance (batch_size , int )
309
- ):
286
+ if batch_size is not None and not isinstance (batch_size , int ):
310
287
raise ValueError (
311
288
f"Invalid batch_size (must be an integer or None): { batch_size } "
312
289
)
313
290
if (
314
291
batch_by is not None
315
- and batch_by is not INFER
316
292
and batch_by not in batchify_fns
317
293
and not callable (batch_by )
318
294
):
@@ -321,11 +297,11 @@ def validate_batching(cls, batch_size, batch_by):
321
297
322
298
@property
323
299
def batch_size (self ):
324
- return self .config .get ("batch_size" , 1 )
300
+ return self .config .get ("batch_size" , None )
325
301
326
302
@property
327
303
def batch_by (self ):
328
- return self .config .get ("batch_by" , "docs" )
304
+ return self .config .get ("batch_by" , None )
329
305
330
306
@property
331
307
def disable_implicit_parallelism (self ):
@@ -372,39 +348,36 @@ def deterministic(self):
372
348
@with_non_default_args
373
349
def set_processing (
374
350
self ,
375
- batch_size : int = INFER ,
376
- batch_by : BatchBy = "docs" ,
377
- split_into_batches_after : str = INFER ,
378
- num_cpu_workers : Optional [int ] = INFER ,
379
- num_gpu_workers : Optional [int ] = INFER ,
351
+ batch_size : Optional [ Union [ int , str ]] = None ,
352
+ batch_by : BatchBy = None ,
353
+ split_into_batches_after : str = None ,
354
+ num_cpu_workers : Optional [int ] = None ,
355
+ num_gpu_workers : Optional [int ] = None ,
380
356
disable_implicit_parallelism : bool = True ,
381
- backend : Optional [Literal ["simple" , "multiprocessing" , "mp" , "spark" ]] = INFER ,
382
- autocast : Union [bool , Any ] = INFER ,
357
+ backend : Optional [Literal ["simple" , "multiprocessing" , "mp" , "spark" ]] = None ,
358
+ autocast : Union [bool , Any ] = None ,
383
359
show_progress : bool = False ,
384
- gpu_pipe_names : Optional [List [str ]] = INFER ,
385
- process_start_method : Optional [Literal ["fork" , "spawn" ]] = INFER ,
386
- gpu_worker_devices : Optional [List [str ]] = INFER ,
387
- cpu_worker_devices : Optional [List [str ]] = INFER ,
360
+ gpu_pipe_names : Optional [List [str ]] = None ,
361
+ process_start_method : Optional [Literal ["fork" , "spawn" ]] = None ,
362
+ gpu_worker_devices : Optional [List [str ]] = None ,
363
+ cpu_worker_devices : Optional [List [str ]] = None ,
388
364
deterministic : bool = True ,
389
365
work_unit : Literal ["record" , "fragment" ] = "record" ,
390
- chunk_size : int = INFER ,
366
+ chunk_size : int = None ,
391
367
sort_chunks : bool = False ,
392
368
_non_default_args : Iterable [str ] = (),
393
369
) -> "Stream" :
394
370
"""
395
371
Parameters
396
372
----------
397
- batch_size: int
398
- Number of documents to process at a time in a GPU worker (or in the
399
- main process if no workers are used). This is the global batch size
400
- that is used for batching methods that do not provide their own
401
- batching arguments.
373
+ batch_size: Optional[Union[int, str]]
374
+ The batch size. Can also be a batching expression like
375
+ "32 docs", "1024 words", "dataset", "fragment", etc.
402
376
batch_by: BatchBy
403
377
Function to compute the batches. If set, it should take an iterable of
404
378
documents and return an iterable of batches. You can also set it to
405
379
"docs", "words" or "padded_words" to use predefined batching functions.
406
- Defaults to "docs". Only used for operations that do not provide their
407
- own batching arguments.
380
+ Defaults to "docs".
408
381
num_cpu_workers: int
409
382
Number of CPU workers. A CPU worker handles the non deep-learning components
410
383
and the preprocessing, collating and postprocessing of deep-learning
@@ -468,15 +441,15 @@ def set_processing(
468
441
"""
469
442
kwargs = {k : v for k , v in locals ().items () if k in _non_default_args }
470
443
if (
471
- kwargs .pop ("chunk_size" , INFER ) is not INFER
472
- or kwargs .pop ("sort_chunks" , INFER ) is not INFER
444
+ kwargs .pop ("chunk_size" , None ) is not None
445
+ or kwargs .pop ("sort_chunks" , None ) is not None
473
446
):
474
447
warnings .warn (
475
448
"chunk_size and sort_chunks are deprecated, use "
476
449
"map_batched(sort_fn, batch_size=chunk_size) instead." ,
477
450
VisibleDeprecationWarning ,
478
451
)
479
- if kwargs .pop ("split_into_batches_after" , INFER ) is not INFER :
452
+ if kwargs .pop ("split_into_batches_after" , None ) is not None :
480
453
warnings .warn (
481
454
"split_into_batches_after is deprecated." , VisibleDeprecationWarning
482
455
)
@@ -486,7 +459,7 @@ def set_processing(
486
459
ops = self .ops ,
487
460
config = {
488
461
** self .config ,
489
- ** {k : v for k , v in kwargs .items () if v is not INFER },
462
+ ** {k : v for k , v in kwargs .items () if v is not None },
490
463
},
491
464
)
492
465
@@ -690,8 +663,8 @@ def map_gpu(
690
663
def map_pipeline (
691
664
self ,
692
665
model : Pipeline ,
693
- batch_size : Optional [int ] = INFER ,
694
- batch_by : BatchBy = INFER ,
666
+ batch_size : Optional [Union [ int , str ]] = None ,
667
+ batch_by : BatchBy = None ,
695
668
) -> "Stream" :
696
669
"""
697
670
Maps a pipeline to the documents, i.e. adds each component of the pipeline to
@@ -974,16 +947,10 @@ def __getattr__(self, item):
974
947
def _make_stages (self , split_torch_pipes : bool ) -> List [Stage ]:
975
948
current_ops = []
976
949
stages = []
977
- self_batch_fn = batchify_fns .get (self .batch_by , self .batch_by )
978
- self_batch_size = self .batch_size
979
- assert self_batch_size is not None
980
950
981
951
ops = [copy (op ) for op in self .ops ]
982
952
983
953
for op in ops :
984
- if isinstance (op , BatchifyOp ):
985
- op .batch_fn = self_batch_fn if op .batch_fn is INFER else op .batch_fn
986
- op .size = self_batch_size if op .size is INFER else op .size
987
954
if (
988
955
isinstance (op , MapBatchesOp )
989
956
and hasattr (op .pipe , "forward" )
@@ -1005,23 +972,39 @@ def validate_ops(self, ops, update: bool = False):
1005
972
# Check batchify requirements
1006
973
requires_sentinels = set ()
1007
974
975
+ self_batch_size , self_batch_by = self .validate_batching (
976
+ self .batch_size , self .batch_by
977
+ )
978
+ if self_batch_by is None :
979
+ self_batch_by = "docs"
980
+ if self_batch_size is None :
981
+ self_batch_size = 1
982
+ self_batch_fn = batchify_fns .get (self_batch_by , self_batch_by )
983
+
1008
984
if hasattr (self .writer , "batch_fn" ) and hasattr (
1009
985
self .writer .batch_fn , "requires_sentinel"
1010
986
):
1011
987
requires_sentinels .add (self .writer .batch_fn .requires_sentinel )
1012
988
1013
- self_batch_fn = batchify_fns .get (self .batch_by , self .batch_by )
1014
989
for op in reversed (ops ):
1015
990
if isinstance (op , BatchifyOp ):
1016
- batch_fn = op .batch_fn or self_batch_fn
991
+ if op .batch_fn is None and op .size is None :
992
+ batch_size = self_batch_size
993
+ batch_fn = self_batch_fn
994
+ elif op .batch_fn is None :
995
+ batch_size = op .size
996
+ batch_fn = batchify
997
+ else :
998
+ batch_size = op .size
999
+ batch_fn = op .batch_fn
1017
1000
sentinel_mode = op .sentinel_mode or (
1018
1001
"auto"
1019
1002
if "sentinel_mode" in signature (batch_fn ).parameters
1020
1003
else None
1021
1004
)
1022
1005
if sentinel_mode == "auto" :
1023
1006
sentinel_mode = "split" if requires_sentinels else "drop"
1024
- if requires_sentinels and op . sentinel_mode == "drop" :
1007
+ if requires_sentinels and sentinel_mode == "drop" :
1025
1008
raise ValueError (
1026
1009
f"Operation { op } drops the stream sentinel values "
1027
1010
f"(markers for the end of a dataset or a dataset "
@@ -1031,10 +1014,12 @@ def validate_ops(self, ops, update: bool = False):
1031
1014
f"any upstream batching operation."
1032
1015
)
1033
1016
if update :
1017
+ op .size = batch_size
1018
+ op .batch_fn = batch_fn
1034
1019
op .sentinel_mode = sentinel_mode
1035
1020
1036
- if hasattr (batch_fn , "requires_sentinel" ):
1037
- requires_sentinels .add (batch_fn .requires_sentinel )
1021
+ if hasattr (op . batch_fn , "requires_sentinel" ):
1022
+ requires_sentinels .add (op . batch_fn .requires_sentinel )
1038
1023
1039
1024
sentinel_str = ", " .join (requires_sentinels )
1040
1025
if requires_sentinels and self .backend == "spark" :
0 commit comments