2626from litdata .processing .data_processor import DataChunkRecipe , DataProcessor , DataTransformRecipe
2727from litdata .processing .readers import BaseReader
2828from litdata .processing .utilities import optimize_dns_context
29+ from litdata .streaming .dataloader import StreamingDataLoader
2930from litdata .streaming .resolver import (
3031 Dir ,
3132 _assert_dir_has_index_file ,
@@ -176,6 +177,7 @@ def map(
176177 inputs: A sequence of input to be processed by the `fn` function.
177178 Each input should contain at least a valid filepath.
178179 output_dir: The folder where the processed data should be stored.
180+ weights: Provide an associated weight to each input. This is used to balance work among workers.
179181 num_workers: The number of workers to use during processing
180182 fast_dev_run: Whether to use process only a sub part of the inputs
181183 num_nodes: When doing remote execution, the number of nodes to use. Only supported on https://lightning.ai/.
@@ -188,8 +190,14 @@ def map(
188190 batch_size: Group the inputs into batches of batch_size length.
189191
190192 """
191- if not isinstance (inputs , Sequence ):
192- raise ValueError (f"The provided inputs should be non empty sequence. Found { inputs } ." )
193+ if isinstance (inputs , StreamingDataLoader ) and batch_size is not None :
194+ raise ValueError ("When providing a streaming dataloader, pass the batch_size to the dataloader directly." )
195+
196+ if isinstance (inputs , StreamingDataLoader ) and weights is not None :
197+ raise ValueError ("When providing a streaming dataloader, weights isn't supported." )
198+
199+ if not isinstance (inputs , (Sequence , StreamingDataLoader )):
200+ raise ValueError (f"The provided inputs should be non empty sequence or a streaming dataloader. Found { inputs } ." )
193201
194202 if len (inputs ) == 0 :
195203 raise ValueError (f"The provided inputs should be non empty. Found { inputs } ." )
@@ -218,10 +226,13 @@ def map(
218226 if error_when_not_empty :
219227 _assert_dir_is_empty (_output_dir )
220228
221- input_dir = _resolve_dir (_get_input_dir (inputs ))
229+ if not isinstance (inputs , StreamingDataLoader ):
230+ input_dir = _resolve_dir (_get_input_dir (inputs ))
222231
223- if isinstance (batch_size , int ) and batch_size > 1 :
224- inputs = [inputs [pos : pos + batch_size ] for pos in range (0 , len (inputs ), batch_size )]
232+ if isinstance (batch_size , int ) and batch_size > 1 :
233+ inputs = [inputs [pos : pos + batch_size ] for pos in range (0 , len (inputs ), batch_size )]
234+ else :
235+ input_dir = Dir ()
225236
226237 data_processor = DataProcessor (
227238 input_dir = input_dir ,
@@ -247,6 +258,7 @@ def optimize(
247258 fn : Callable [[Any ], Any ],
248259 inputs : Sequence [Any ],
249260 output_dir : str ,
261+ weights : Optional [List [int ]] = None ,
250262 chunk_size : Optional [int ] = None ,
251263 chunk_bytes : Optional [Union [int , str ]] = None ,
252264 compression : Optional [str ] = None ,
@@ -267,6 +279,7 @@ def optimize(
267279 inputs: A sequence of input to be processed by the `fn` function.
268280 Each input should contain at least a valid filepath.
269281 output_dir: The folder where the processed data should be stored.
282+ weights: Provide an associated weight to each input. This is used to balance work among workers.
270283 chunk_size: The maximum number of elements to hold within a chunk.
271284 chunk_bytes: The maximum number of bytes to hold within a chunk.
272285 compression: The compression algorithm to use over the chunks.
@@ -281,8 +294,14 @@ def optimize(
281294 batch_size: Group the inputs into batches of batch_size length.
282295
283296 """
284- if not isinstance (inputs , Sequence ):
285- raise ValueError (f"The provided inputs should be non empty sequence. Found { inputs } ." )
297+ if isinstance (inputs , StreamingDataLoader ) and batch_size is not None :
298+ raise ValueError ("When providing a streaming dataloader, pass the batch_size to the dataloader directly." )
299+
300+ if isinstance (inputs , StreamingDataLoader ) and weights is not None :
301+ raise ValueError ("When providing a streaming dataloader, weights isn't supported." )
302+
303+ if not isinstance (inputs , (Sequence , StreamingDataLoader )):
304+ raise ValueError (f"The provided inputs should be non empty sequence or a streaming dataloader. Found { inputs } ." )
286305
287306 if len (inputs ) == 0 :
288307 raise ValueError (f"The provided inputs should be non empty. Found { inputs } ." )
@@ -313,10 +332,13 @@ def optimize(
313332
314333 _assert_dir_has_index_file (_output_dir )
315334
316- input_dir = _resolve_dir (_get_input_dir (inputs ))
335+ if not isinstance (inputs , StreamingDataLoader ):
336+ input_dir = _resolve_dir (_get_input_dir (inputs ))
317337
318- if isinstance (batch_size , int ) and batch_size > 1 :
319- inputs = [inputs [pos : pos + batch_size ] for pos in range (0 , len (inputs ), batch_size )]
338+ if isinstance (batch_size , int ) and batch_size > 1 :
339+ inputs = [inputs [pos : pos + batch_size ] for pos in range (0 , len (inputs ), batch_size )]
340+ else :
341+ input_dir = Dir ()
320342
321343 data_processor = DataProcessor (
322344 input_dir = input_dir ,
0 commit comments