1010import os
1111from datetime import datetime
1212from pathlib import Path
13- from typing import TYPE_CHECKING , Dict , List , Optional , Union
13+ from typing import TYPE_CHECKING , Any , Callable , Dict , List , Optional , Union
1414
1515from loguru import logger
1616from torch .utils .data import DataLoader
@@ -259,8 +259,15 @@ def oneshot(
259259 preprocessing_num_workers : Optional [int ] = None ,
260260 min_tokens_per_module : Optional [float ] = None ,
261261 moe_calibrate_all_experts : bool = True ,
262+ pipeline : str = "independent" ,
263+ tracing_ignore : Optional [List [str ]] = None ,
264+ raw_kwargs : Optional [Dict [str , Any ]] = None ,
265+ preprocessing_func : Optional [Callable ] = None ,
266+ max_train_samples : Optional [int ] = None ,
267+ remove_columns : Optional [List [str ]] = None ,
268+ dvc_data_repository : Optional [str ] = None ,
262269 quantization_aware_calibration : bool = True ,
263- # Miscellaneous arguments
270+ sequential_targets : Optional [ List [ str ]] = None ,
264271 output_dir : Optional [str ] = None ,
265272 log_dir : Optional [str ] = None ,
266273 ** kwargs ,
@@ -331,6 +338,16 @@ def oneshot(
331338 during forward pass in calibration. When False, quantization is disabled
332339 during forward pass in calibration. Default is set to True.
333340
341+ :param pipeline: The pipeline configuration to use for calibration. Options include
342+ 'independent', 'sequential', or 'layer_sequential'.
343+ :param tracing_ignore: List of module names to ignore during tracing.
344+ :param raw_kwargs: Dictionary of raw keyword arguments passed to the function.
345+ :param preprocessing_func: Optional callable for preprocessing the dataset.
346+ :param max_train_samples: Maximum number of training samples to use.
347+ :param remove_columns: List of column names to remove from the dataset.
348+ :param dvc_data_repository: Path to the DVC data repository, if applicable.
349+ :param sequential_targets: List of sequential targets for calibration.
350+
334351 # Miscellaneous arguments
335352 :param output_dir: Path to save the output model after calibration.
336353 Nothing is saved if None.
@@ -340,10 +357,18 @@ def oneshot(
340357 :return: The calibrated PreTrainedModel
341358 """
342359
343- # pass all args directly into Oneshot
360+ if sequential_targets and pipeline == "independent" :
361+ raise ValueError (
362+ "Invalid configuration: "
363+ "sequential_targets' cannot be used with 'independent' pipeline. "
364+ "Please use 'sequential' or 'layer_sequential' pipeline when specifying "
365+ "sequential_targets."
366+ )
367+
344368 local_args = {
345369 k : v for k , v in locals ().items () if k not in ("local_args" , "kwargs" )
346370 }
371+
347372 one_shot = Oneshot (** local_args , ** kwargs )
348373 one_shot ()
349374
0 commit comments