From 9935befad9c3d3d54205100b3d0b7f6d956db315 Mon Sep 17 00:00:00 2001 From: Arka Sanka Date: Wed, 22 Oct 2025 00:51:44 +0530 Subject: [PATCH] Add validation for empty dataset and enhance oneshot function parameters Signed-off-by: Arka Sanka --- src/llmcompressor/datasets/utils.py | 6 +++ src/llmcompressor/entrypoints/oneshot.py | 32 ++++++++++++-- .../transformers/oneshot/test_api_inputs.py | 43 ++++++++++++++++--- 3 files changed, 72 insertions(+), 9 deletions(-) diff --git a/src/llmcompressor/datasets/utils.py b/src/llmcompressor/datasets/utils.py index 8aef67fcd4..f4df70a047 100644 --- a/src/llmcompressor/datasets/utils.py +++ b/src/llmcompressor/datasets/utils.py @@ -149,6 +149,12 @@ def format_calibration_data( f"the provided dataset only has {safe_calibration_samples}. " ) + if safe_calibration_samples == 0: + logger.error("Dataset is empty. Cannot create a calibration dataloader.") + raise ValueError( + "Dataset is empty. Cannot create a calibration dataloader with 0 samples." + ) + if do_shuffle: tokenized_dataset = tokenized_dataset.shuffle() tokenized_calibration = tokenized_dataset.select(range(safe_calibration_samples)) diff --git a/src/llmcompressor/entrypoints/oneshot.py b/src/llmcompressor/entrypoints/oneshot.py index 66c320d1b3..0aaf539812 100644 --- a/src/llmcompressor/entrypoints/oneshot.py +++ b/src/llmcompressor/entrypoints/oneshot.py @@ -12,7 +12,7 @@ import os from datetime import datetime from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Callable from loguru import logger from torch.utils.data import DataLoader @@ -260,8 +260,16 @@ def oneshot( preprocessing_num_workers: int | None = None, min_tokens_per_module: float | None = None, moe_calibrate_all_experts: bool = True, + pipeline: str = "independent", + tracing_ignore: list[str] | None = None, + raw_kwargs: dict[str, Any] | None = None, + preprocessing_func: Callable | None = None, + max_train_samples: int | None = None, + remove_columns: 
list[str] | None = None, + dvc_data_repository: str | None = None, quantization_aware_calibration: bool = True, - # Miscellaneous arguments + sequential_targets: list[str] | None = None, + # Miscellaneous arguments output_dir: str | None = None, log_dir: str | None = None, **kwargs, @@ -331,6 +339,16 @@ def oneshot( during forward pass in calibration. When False, quantization is disabled during forward pass in calibration. Default is set to True. + :param pipeline: The pipeline configuration to use for calibration. Options include + 'independent', 'sequential', or 'layer_sequential'. + :param tracing_ignore: List of module names to ignore during tracing. + :param raw_kwargs: Dictionary of raw keyword arguments passed to the function. + :param preprocessing_func: Optional callable for preprocessing the dataset. + :param max_train_samples: Maximum number of training samples to use. + :param remove_columns: List of column names to remove from the dataset. + :param dvc_data_repository: Path to the DVC data repository, if applicable. + :param sequential_targets: List of sequential targets for calibration. + # Miscellaneous arguments :param output_dir: Path to save the output model after calibration. Nothing is saved if None. @@ -340,10 +358,18 @@ def oneshot( :return: The calibrated PreTrainedModel """ - # pass all args directly into Oneshot + if sequential_targets and pipeline == "independent": + raise ValueError( + "Invalid configuration: " + "'sequential_targets' cannot be used with 'independent' pipeline. " + "Please use 'sequential' or 'layer_sequential' pipeline when specifying " + "sequential_targets."
+ ) + local_args = { k: v for k, v in locals().items() if k not in ("local_args", "kwargs") } + one_shot = Oneshot(**local_args, **kwargs) one_shot() diff --git a/tests/llmcompressor/transformers/oneshot/test_api_inputs.py b/tests/llmcompressor/transformers/oneshot/test_api_inputs.py index 665a765689..b1fb3783f4 100644 --- a/tests/llmcompressor/transformers/oneshot/test_api_inputs.py +++ b/tests/llmcompressor/transformers/oneshot/test_api_inputs.py @@ -1,3 +1,5 @@ +import logging + import pytest from transformers import AutoModelForCausalLM, AutoTokenizer @@ -5,6 +7,9 @@ from tests.llmcompressor.transformers.oneshot.dataset_processing import get_data_utils from tests.testing_utils import parse_params +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + CONFIGS_DIRECTORY = "tests/llmcompressor/transformers/oneshot/oneshot_configs" # TODO: Seems better to mark test type (smoke, sanity, regression) as a marker as @@ -42,15 +47,41 @@ def wrapped_preprocess_func(sample): dataset_config_name=config.get("dataset_config_name"), ) + args["pipeline"] = config.get("pipeline", "independent") + args["sequential_targets"] = config.get("sequential_targets", None) + args["tracing_ignore"] = config.get("tracing_ignore", []) + args["raw_kwargs"] = config.get("raw_kwargs", {}) + args["preprocessing_func"] = config.get("preprocessing_func", lambda x: x) + args["max_train_samples"] = config.get("max_train_samples", 50) + args["remove_columns"] = config.get("remove_columns", None) + args["dvc_data_repository"] = config.get("dvc_data_repository", None) + args["splits"] = config.get("splits", {"calibration": "train[:50]"}) + args["log_dir"] = config.get("log_dir", "sparse_logs") + return args @pytest.mark.smoke @pytest.mark.integration def test_one_shot_inputs(one_shot_args, tmp_path): - oneshot( - **one_shot_args, - output_dir=tmp_path, - num_calibration_samples=10, - pad_to_max_length=False, - ) + logger.info(f"Dataset type: 
{type(one_shot_args.get('dataset'))}") + if isinstance(one_shot_args.get("dataset"), str): + logger.info(f"Dataset name: {one_shot_args.get('dataset')}") + logger.info(f"Dataset config: {one_shot_args.get('dataset_config_name')}") + try: + # Call oneshot with all parameters as flat arguments + oneshot( + **one_shot_args, + output_dir=tmp_path, + num_calibration_samples=10, + pad_to_max_length=False, + ) + + except ValueError as e: + if "num_samples should be a positive integer value" in str( + e + ) or "Dataset is empty. Cannot create a calibration dataloader" in str(e): + logger.warning(f"Dataset is empty: {one_shot_args.get('dataset')}") + pytest.skip(f"Dataset is empty: {one_shot_args.get('dataset')}") + else: + raise # Re-raise other ValueError exceptions