 from typing_extensions import Literal

 import edsnlp.data
-from edsnlp.utils.batching import BatchBy, BatchFn, BatchSizeArg, batchify_fns
+from edsnlp.utils.batching import BatchBy, BatchFn, BatchSizeArg, batchify, batchify_fns
 from edsnlp.utils.collections import flatten, flatten_once, shuffle
 from edsnlp.utils.stream_sentinels import StreamSentinel

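The newly imported `batchify` is the plain count-based batching helper. As a rough sketch of what such a helper does (illustration only, under the name `count_batchify`; the real implementation in `edsnlp.utils.batching` also handles extra options such as stream sentinels):

```python
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def count_batchify(items: Iterable[T], batch_size: int) -> Iterator[List[T]]:
    """Hypothetical re-implementation: group items into lists of `batch_size`."""
    batch: List[T] = []
    for item in items:
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:  # trailing, possibly smaller, batch
        yield batch
```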
@@ -47,25 +47,6 @@ def deep_isgeneratorfunction(x):
     raise ValueError(f"{x} does not have a __call__ or batch_process method.")


-class _InferType:
-    # Singleton is important since the INFER object may be passed to
-    # other processes, i.e. pickled, depickled, while it should
-    # always be the same object.
-    instance = None
-
-    def __repr__(self):
-        return "INFER"
-
-    def __new__(cls, *args, **kwargs):
-        if cls.instance is None:
-            cls.instance = super().__new__(cls)
-        return cls.instance
-
-    def __bool__(self):
-        return False
-
-
-INFER = _InferType()
 CONTEXT = [{}]

 T = TypeVar("T")
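The removed comment explains why `_InferType` had to be a singleton: default unpickling recreates an object through `cls.__new__(cls)`, so overriding `__new__` keeps identity checks against `INFER` true even after a round-trip to another process. A minimal demonstration of that property, and of why plain `None` needs no such machinery:

```python
import pickle


class _InferType:
    instance = None

    def __new__(cls, *args, **kwargs):
        if cls.instance is None:
            cls.instance = super().__new__(cls)
        return cls.instance


INFER = _InferType()

# Unpickling calls _InferType.__new__, which returns the same object,
# so `x is INFER` survives process boundaries.
assert pickle.loads(pickle.dumps(INFER)) is INFER

# None is already a process-wide singleton, which is what makes the
# whole class removable once None becomes the "unset" marker.
assert pickle.loads(pickle.dumps(None)) is None
```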
@@ -125,8 +106,8 @@ def __init__(
     ):
         if batch_fn is None:
             if size is None:
-                size = INFER
-                batch_fn = INFER
+                size = None
+                batch_fn = None
             else:
                 batch_fn = batchify_fns["docs"]
         self.size = size
@@ -302,17 +283,13 @@ def validate_batching(cls, batch_size, batch_by):
302283 "Cannot use both a batch_size expression and a batch_by function"
303284 )
304285 batch_size , batch_by = BatchSizeArg .validate (batch_size )
305- if (
306- batch_size is not None
307- and batch_size is not INFER
308- and not isinstance (batch_size , int )
309- ):
286+ if batch_size is not None and not isinstance (batch_size , int ):
310287 raise ValueError (
311288 f"Invalid batch_size (must be an integer or None): { batch_size } "
312289 )
313290 if (
314291 batch_by is not None
315- and batch_by is not INFER
292+ and batch_by is not None
316293 and batch_by not in batchify_fns
317294 and not callable (batch_by )
318295 ):
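`BatchSizeArg.validate` is what turns a batching expression into the `(batch_size, batch_by)` pair unpacked above. A hypothetical sketch of that parsing, for illustration only (the function name and regex are assumptions; the real logic lives in `edsnlp.utils.batching`):

```python
import re
from typing import Optional, Tuple, Union


def parse_batch_size(
    value: Union[int, str, None],
) -> Tuple[Optional[int], Optional[str]]:
    """Hypothetical parser: "1024 words" -> (1024, "words")."""
    if value is None or isinstance(value, int):
        return value, None  # a bare integer counts docs
    match = re.fullmatch(r"(\d+)\s*(\w+)", value.strip())
    if match:
        return int(match.group(1)), match.group(2)
    return None, value  # e.g. "dataset" or "fragment": one batch per unit


# parse_batch_size("32 docs")    == (32, "docs")
# parse_batch_size("1024 words") == (1024, "words")
# parse_batch_size("dataset")    == (None, "dataset")
```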
@@ -321,11 +298,11 @@ def validate_batching(cls, batch_size, batch_by):

     @property
     def batch_size(self):
-        return self.config.get("batch_size", 1)
+        return self.config.get("batch_size", None)

     @property
     def batch_by(self):
-        return self.config.get("batch_by", "docs")
+        return self.config.get("batch_by", None)

     @property
     def disable_implicit_parallelism(self):
@@ -372,39 +349,36 @@ def deterministic(self):
     @with_non_default_args
     def set_processing(
         self,
-        batch_size: int = INFER,
-        batch_by: BatchBy = "docs",
-        split_into_batches_after: str = INFER,
-        num_cpu_workers: Optional[int] = INFER,
-        num_gpu_workers: Optional[int] = INFER,
+        batch_size: Optional[Union[int, str]] = None,
+        batch_by: BatchBy = None,
+        split_into_batches_after: str = None,
+        num_cpu_workers: Optional[int] = None,
+        num_gpu_workers: Optional[int] = None,
         disable_implicit_parallelism: bool = True,
-        backend: Optional[Literal["simple", "multiprocessing", "mp", "spark"]] = INFER,
-        autocast: Union[bool, Any] = INFER,
+        backend: Optional[Literal["simple", "multiprocessing", "mp", "spark"]] = None,
+        autocast: Union[bool, Any] = None,
         show_progress: bool = False,
-        gpu_pipe_names: Optional[List[str]] = INFER,
-        process_start_method: Optional[Literal["fork", "spawn"]] = INFER,
-        gpu_worker_devices: Optional[List[str]] = INFER,
-        cpu_worker_devices: Optional[List[str]] = INFER,
+        gpu_pipe_names: Optional[List[str]] = None,
+        process_start_method: Optional[Literal["fork", "spawn"]] = None,
+        gpu_worker_devices: Optional[List[str]] = None,
+        cpu_worker_devices: Optional[List[str]] = None,
         deterministic: bool = True,
         work_unit: Literal["record", "fragment"] = "record",
-        chunk_size: int = INFER,
+        chunk_size: int = None,
         sort_chunks: bool = False,
         _non_default_args: Iterable[str] = (),
     ) -> "Stream":
         """
         Parameters
         ----------
-        batch_size: int
-            Number of documents to process at a time in a GPU worker (or in the
-            main process if no workers are used). This is the global batch size
-            that is used for batching methods that do not provide their own
-            batching arguments.
+        batch_size: Optional[Union[int, str]]
+            The batch size. Can also be a batching expression like
+            "32 docs", "1024 words", "dataset", "fragment", etc.
         batch_by: BatchBy
             Function to compute the batches. If set, it should take an iterable of
             documents and return an iterable of batches. You can also set it to
             "docs", "words" or "padded_words" to use predefined batching functions.
-            Defaults to "docs". Only used for operations that do not provide their
-            own batching arguments.
+            Defaults to "docs".
         num_cpu_workers: int
             Number of CPU workers. A CPU worker handles the non deep-learning components
             and the preprocessing, collating and postprocessing of deep-learning
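With every option now defaulting to `None`, a call only pins down what it overrides. A usage sketch, assuming a stream built with the `edsnlp.data.from_iterable` reader (adjust to your own reader and pipeline):

```python
import edsnlp

stream = edsnlp.data.from_iterable(["first text", "second text"])
stream = stream.set_processing(
    batch_size="1024 words",  # batching expression, validated by BatchSizeArg
    num_cpu_workers=4,
    backend="multiprocessing",
)
```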
@@ -468,15 +442,15 @@ def set_processing(
         """
         kwargs = {k: v for k, v in locals().items() if k in _non_default_args}
         if (
-            kwargs.pop("chunk_size", INFER) is not INFER
-            or kwargs.pop("sort_chunks", INFER) is not INFER
+            kwargs.pop("chunk_size", None) is not None
+            or kwargs.pop("sort_chunks", None) is not None
         ):
             warnings.warn(
                 "chunk_size and sort_chunks are deprecated, use "
                 "map_batched(sort_fn, batch_size=chunk_size) instead.",
                 VisibleDeprecationWarning,
             )
-        if kwargs.pop("split_into_batches_after", INFER) is not INFER:
+        if kwargs.pop("split_into_batches_after", None) is not None:
             warnings.warn(
                 "split_into_batches_after is deprecated.", VisibleDeprecationWarning
             )
@@ -486,7 +460,7 @@ def set_processing(
             ops=self.ops,
             config={
                 **self.config,
-                **{k: v for k, v in kwargs.items() if v is not INFER},
+                **{k: v for k, v in kwargs.items() if v is not None},
             },
         )

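The deprecation message above names the replacement. A migration sketch, where `sort_fn` is a hypothetical stand-in for whatever per-chunk sorting the caller relied on, for example ordering docs by length:

```python
# Before (deprecated):
#     stream.set_processing(chunk_size=128, sort_chunks=True)

# After:
def sort_fn(batch):
    # Hypothetical chunk-level sort, e.g. by document length.
    return sorted(batch, key=lambda doc: len(doc.text))

stream = stream.map_batched(sort_fn, batch_size=128)
```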
@@ -690,8 +664,8 @@ def map_gpu(
     def map_pipeline(
         self,
         model: Pipeline,
-        batch_size: Optional[int] = INFER,
-        batch_by: BatchBy = INFER,
+        batch_size: Optional[Union[int, str]] = None,
+        batch_by: BatchBy = None,
     ) -> "Stream":
         """
         Maps a pipeline to the documents, i.e. adds each component of the pipeline to
@@ -974,16 +948,10 @@ def __getattr__(self, item):
     def _make_stages(self, split_torch_pipes: bool) -> List[Stage]:
         current_ops = []
         stages = []
-        self_batch_fn = batchify_fns.get(self.batch_by, self.batch_by)
-        self_batch_size = self.batch_size
-        assert self_batch_size is not None

         ops = [copy(op) for op in self.ops]

         for op in ops:
-            if isinstance(op, BatchifyOp):
-                op.batch_fn = self_batch_fn if op.batch_fn is INFER else op.batch_fn
-                op.size = self_batch_size if op.size is INFER else op.size
             if (
                 isinstance(op, MapBatchesOp)
                 and hasattr(op.pipe, "forward")
@@ -1005,23 +973,39 @@ def validate_ops(self, ops, update: bool = False):
         # Check batchify requirements
         requires_sentinels = set()

+        self_batch_size, self_batch_by = self.validate_batching(
+            self.batch_size, self.batch_by
+        )
+        if self_batch_by is None:
+            self_batch_by = "docs"
+        if self_batch_size is None:
+            self_batch_size = 1
+        self_batch_fn = batchify_fns.get(self_batch_by, self_batch_by)
+
         if hasattr(self.writer, "batch_fn") and hasattr(
             self.writer.batch_fn, "requires_sentinel"
         ):
             requires_sentinels.add(self.writer.batch_fn.requires_sentinel)

-        self_batch_fn = batchify_fns.get(self.batch_by, self.batch_by)
         for op in reversed(ops):
             if isinstance(op, BatchifyOp):
-                batch_fn = op.batch_fn or self_batch_fn
+                if op.batch_fn is None and op.size is None:
+                    batch_size = self_batch_size
+                    batch_fn = self_batch_fn
+                elif op.batch_fn is None:
+                    batch_size = op.size
+                    batch_fn = batchify
+                else:
+                    batch_size = op.size
+                    batch_fn = op.batch_fn
                 sentinel_mode = op.sentinel_mode or (
                     "auto"
                     if "sentinel_mode" in signature(batch_fn).parameters
                     else None
                 )
                 if sentinel_mode == "auto":
                     sentinel_mode = "split" if requires_sentinels else "drop"
-                if requires_sentinels and op.sentinel_mode == "drop":
+                if requires_sentinels and sentinel_mode == "drop":
                     raise ValueError(
                         f"Operation {op} drops the stream sentinel values "
                         f"(markers for the end of a dataset or a dataset "
@@ -1031,10 +1015,12 @@ def validate_ops(self, ops, update: bool = False):
10311015 f"any upstream batching operation."
10321016 )
10331017 if update :
1018+ op .size = batch_size
1019+ op .batch_fn = batch_fn
10341020 op .sentinel_mode = sentinel_mode
10351021
1036- if hasattr (batch_fn , "requires_sentinel" ):
1037- requires_sentinels .add (batch_fn .requires_sentinel )
1022+ if hasattr (op . batch_fn , "requires_sentinel" ):
1023+ requires_sentinels .add (op . batch_fn .requires_sentinel )
10381024
10391025 sentinel_str = ", " .join (requires_sentinels )
10401026 if requires_sentinels and self .backend == "spark" :
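Restating the fallback chain introduced in `validate_ops` as a standalone helper (illustration only, not part of edsnlp): per-op settings win, a bare `size` falls back to plain count-based `batchify`, and an op with neither inherits the stream-level defaults resolved at the top of the method:

```python
from edsnlp.utils.batching import batchify


def resolve_batching(op_size, op_batch_fn, stream_size, stream_batch_fn):
    # Mirrors the elif ladder in validate_ops.
    if op_batch_fn is None and op_size is None:
        return stream_size, stream_batch_fn  # inherit stream defaults
    if op_batch_fn is None:
        return op_size, batchify             # bare size: count-based batching
    return op_size, op_batch_fn              # fully specified on the op
```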