1010import os
1111from datetime import datetime
1212from pathlib import Path
13- from typing import TYPE_CHECKING , Dict , List , Optional , Union
13+ from typing import TYPE_CHECKING , Any , Callable , Dict , List , Optional , Union
1414
1515from loguru import logger
1616from torch .utils .data import DataLoader
@@ -258,8 +258,15 @@ def oneshot(
258258 preprocessing_num_workers : Optional [int ] = None ,
259259 min_tokens_per_module : Optional [float ] = None ,
260260 moe_calibrate_all_experts : bool = True ,
261+ pipeline : str = "independent" ,
262+ tracing_ignore : Optional [List [str ]] = None ,
263+ raw_kwargs : Optional [Dict [str , Any ]] = None ,
264+ preprocessing_func : Optional [Callable ] = None ,
265+ max_train_samples : Optional [int ] = None ,
266+ remove_columns : Optional [List [str ]] = None ,
267+ dvc_data_repository : Optional [str ] = None ,
261268 quantization_aware_calibration : bool = True ,
262- # Miscellaneous arguments
269+ sequential_targets : Optional [ List [ str ]] = None ,
263270 output_dir : Optional [str ] = None ,
264271 log_dir : Optional [str ] = None ,
265272 ** kwargs ,
@@ -328,6 +335,16 @@ def oneshot(
328335 during forward pass in calibration. When False, quantization is disabled
329336 during forward pass in calibration. Default is set to True.
330337
338+ :param pipeline: The pipeline configuration to use for calibration. Options include
339+ 'independent', 'sequential', or 'layer_sequential'.
340+ :param tracing_ignore: List of module names to ignore during tracing.
341+ :param raw_kwargs: Dictionary of raw keyword arguments passed to the function.
342+ :param preprocessing_func: Optional callable for preprocessing the dataset.
343+ :param max_train_samples: Maximum number of training samples to use.
344+ :param remove_columns: List of column names to remove from the dataset.
345+ :param dvc_data_repository: Path to the DVC data repository, if applicable.
346+ :param sequential_targets: List of sequential targets for calibration.
347+
331348 # Miscellaneous arguments
332349 :param output_dir: Path to save the output model after calibration.
333350 Nothing is saved if None.
@@ -337,10 +354,18 @@ def oneshot(
337354 :return: The calibrated PreTrainedModel
338355 """
339356
340- # pass all args directly into Oneshot
357+ if sequential_targets and pipeline == "independent" :
358+ raise ValueError (
359+ "Invalid configuration: "
360+ "sequential_targets' cannot be used with 'independent' pipeline. "
361+ "Please use 'sequential' or 'layer_sequential' pipeline when specifying "
362+ "sequential_targets."
363+ )
364+
341365 local_args = {
342366 k : v for k , v in locals ().items () if k not in ("local_args" , "kwargs" )
343367 }
368+
344369 one_shot = Oneshot (** local_args , ** kwargs )
345370 one_shot ()
346371
0 commit comments