1010import os
1111from datetime import datetime
1212from pathlib import Path
13- from typing import TYPE_CHECKING , Dict , List , Optional , Union
13+ from typing import TYPE_CHECKING , Any , Callable , Dict , List , Optional , Union
1414
1515from loguru import logger
1616from torch .utils .data import DataLoader
@@ -253,8 +253,15 @@ def oneshot(
253253 preprocessing_num_workers : Optional [int ] = None ,
254254 min_tokens_per_module : Optional [float ] = None ,
255255 calibrate_moe_context : bool = False ,
256+ pipeline : str = "independent" ,
257+ tracing_ignore : Optional [List [str ]] = None ,
258+ raw_kwargs : Optional [Dict [str , Any ]] = None ,
259+ preprocessing_func : Optional [Callable ] = None ,
260+ max_train_samples : Optional [int ] = None ,
261+ remove_columns : Optional [List [str ]] = None ,
262+ dvc_data_repository : Optional [str ] = None ,
256263 quantization_aware_calibration : bool = True ,
257- # Miscellaneous arguments
264+ sequential_targets : Optional [ List [ str ]] = None ,
258265 output_dir : Optional [str ] = None ,
259266 log_dir : Optional [str ] = None ,
260267 ** kwargs ,
@@ -324,6 +331,16 @@ def oneshot(
324331 during forward pass in calibration. When False, quantization is disabled
325332 during forward pass in calibration. Default is set to True.
326333
334+ :param pipeline: The pipeline configuration to use for calibration. Options include
335+ 'independent', 'sequential', or 'layer_sequential'.
336+ :param tracing_ignore: List of module names to ignore during tracing.
337+ :param raw_kwargs: Dictionary of raw keyword arguments passed to the function.
338+ :param preprocessing_func: Optional callable for preprocessing the dataset.
339+ :param max_train_samples: Maximum number of training samples to use.
340+ :param remove_columns: List of column names to remove from the dataset.
341+ :param dvc_data_repository: Path to the DVC data repository, if applicable.
342+ :param sequential_targets: List of sequential targets for calibration.
343+
327344 # Miscellaneous arguments
328345 :param output_dir: Path to save the output model after calibration.
329346 Nothing is saved if None.
@@ -333,10 +350,22 @@ def oneshot(
333350 :return: The calibrated PreTrainedModel
334351 """
335352
353+ if sequential_targets and pipeline == "independent" :
354+ raise ValueError (
355+ "Invalid configuration: "
356+ "sequential_targets' cannot be used with 'independent' pipeline. "
357+ "Please use 'sequential' or 'layer_sequential' pipeline when specifying "
358+ "sequential_targets."
359+ )
360+
336361 # pass all args directly into Oneshot
362+ if raw_kwargs is None :
363+ raw_kwargs = {}
364+
337365 local_args = {
338366 k : v for k , v in locals ().items () if k not in ("local_args" , "kwargs" )
339367 }
368+
340369 one_shot = Oneshot (** local_args , ** kwargs )
341370 one_shot ()
342371
0 commit comments