Skip to content

Commit 59312d5

Browse files
committed
Add validation for empty dataset and enhance oneshot function parameters
Signed-off-by: Arka Sanka <[email protected]>

Refactor oneshot function parameters to use Optional types and enhance documentation

Signed-off-by: Arka Sanka <[email protected]>
1 parent c254c19 commit 59312d5

File tree

3 files changed

+70
-9
lines changed

3 files changed

+70
-9
lines changed

src/llmcompressor/datasets/utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,11 @@ def format_calibration_data(
149149
f"the provided dataset only has {safe_calibration_samples}. "
150150
)
151151

152+
if safe_calibration_samples == 0:
153+
raise ValueError(
154+
"Dataset is empty. Cannot create a calibration dataloader with 0 samples."
155+
)
156+
152157
if do_shuffle:
153158
tokenized_dataset = tokenized_dataset.shuffle()
154159
tokenized_calibration = tokenized_dataset.select(range(safe_calibration_samples))

src/llmcompressor/entrypoints/oneshot.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import os
1111
from datetime import datetime
1212
from pathlib import Path
13-
from typing import TYPE_CHECKING, Dict, List, Optional, Union
13+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
1414

1515
from loguru import logger
1616
from torch.utils.data import DataLoader
@@ -258,8 +258,15 @@ def oneshot(
258258
preprocessing_num_workers: Optional[int] = None,
259259
min_tokens_per_module: Optional[float] = None,
260260
moe_calibrate_all_experts: bool = True,
261+
pipeline: str = "independent",
262+
tracing_ignore: Optional[List[str]] = None,
263+
raw_kwargs: Optional[Dict[str, Any]] = None,
264+
preprocessing_func: Optional[Callable] = None,
265+
max_train_samples: Optional[int] = None,
266+
remove_columns: Optional[List[str]] = None,
267+
dvc_data_repository: Optional[str] = None,
261268
quantization_aware_calibration: bool = True,
262-
# Miscellaneous arguments
269+
sequential_targets: Optional[List[str]] = None,
263270
output_dir: Optional[str] = None,
264271
log_dir: Optional[str] = None,
265272
**kwargs,
@@ -328,6 +335,16 @@ def oneshot(
328335
during forward pass in calibration. When False, quantization is disabled
329336
during forward pass in calibration. Default is set to True.
330337
338+
:param pipeline: The pipeline configuration to use for calibration. Options include
339+
'independent', 'sequential', or 'layer_sequential'.
340+
:param tracing_ignore: List of module names to ignore during tracing.
341+
:param raw_kwargs: Dictionary of raw keyword arguments passed to the function.
342+
:param preprocessing_func: Optional callable for preprocessing the dataset.
343+
:param max_train_samples: Maximum number of training samples to use.
344+
:param remove_columns: List of column names to remove from the dataset.
345+
:param dvc_data_repository: Path to the DVC data repository, if applicable.
346+
:param sequential_targets: List of sequential targets for calibration.
347+
331348
# Miscellaneous arguments
332349
:param output_dir: Path to save the output model after calibration.
333350
Nothing is saved if None.
@@ -337,10 +354,18 @@ def oneshot(
337354
:return: The calibrated PreTrainedModel
338355
"""
339356

340-
# pass all args directly into Oneshot
357+
if sequential_targets and pipeline == "independent":
358+
raise ValueError(
359+
"Invalid configuration: "
360+
"sequential_targets' cannot be used with 'independent' pipeline. "
361+
"Please use 'sequential' or 'layer_sequential' pipeline when specifying "
362+
"sequential_targets."
363+
)
364+
341365
local_args = {
342366
k: v for k, v in locals().items() if k not in ("local_args", "kwargs")
343367
}
368+
344369
one_shot = Oneshot(**local_args, **kwargs)
345370
one_shot()
346371

tests/llmcompressor/transformers/oneshot/test_api_inputs.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
1+
import logging
2+
13
import pytest
24
from transformers import AutoModelForCausalLM, AutoTokenizer
35

46
from llmcompressor import oneshot
57
from tests.llmcompressor.transformers.oneshot.dataset_processing import get_data_utils
68
from tests.testing_utils import parse_params
79

10+
logging.basicConfig(level=logging.INFO)
11+
logger = logging.getLogger(__name__)
12+
813
CONFIGS_DIRECTORY = "tests/llmcompressor/transformers/oneshot/oneshot_configs"
914

1015
# TODO: Seems better to mark test type (smoke, sanity, regression) as a marker as
@@ -42,15 +47,41 @@ def wrapped_preprocess_func(sample):
4247
dataset_config_name=config.get("dataset_config_name"),
4348
)
4449

50+
args["pipeline"] = config.get("pipeline", "independent")
51+
args["sequential_targets"] = config.get("sequential_targets", None)
52+
args["tracing_ignore"] = config.get("tracing_ignore", [])
53+
args["raw_kwargs"] = config.get("raw_kwargs", {})
54+
args["preprocessing_func"] = config.get("preprocessing_func", lambda x: x)
55+
args["max_train_samples"] = config.get("max_train_samples", 50)
56+
args["remove_columns"] = config.get("remove_columns", None)
57+
args["dvc_data_repository"] = config.get("dvc_data_repository", None)
58+
args["splits"] = config.get("splits", {"calibration": "train[:50]"})
59+
args["log_dir"] = config.get("log_dir", "sparse_logs")
60+
4561
return args
4662

4763

4864
@pytest.mark.smoke
4965
@pytest.mark.integration
5066
def test_one_shot_inputs(one_shot_args, tmp_path):
51-
oneshot(
52-
**one_shot_args,
53-
output_dir=tmp_path,
54-
num_calibration_samples=10,
55-
pad_to_max_length=False,
56-
)
67+
logger.info(f"Dataset type: {type(one_shot_args.get('dataset'))}")
68+
if isinstance(one_shot_args.get("dataset"), str):
69+
logger.info(f"Dataset name: {one_shot_args.get('dataset')}")
70+
logger.info(f"Dataset config: {one_shot_args.get('dataset_config_name')}")
71+
try:
72+
# Call oneshot with all parameters as flat arguments
73+
oneshot(
74+
**one_shot_args,
75+
output_dir=tmp_path,
76+
num_calibration_samples=10,
77+
pad_to_max_length=False,
78+
)
79+
80+
except ValueError as e:
81+
if "num_samples should be a positive integer value" in str(
82+
e
83+
) or "Dataset is empty. Cannot create a calibration dataloader" in str(e):
84+
logger.warning(f"Dataset is empty: {one_shot_args.get('dataset')}")
85+
pytest.skip(f"Dataset is empty: {one_shot_args.get('dataset')}")
86+
else:
87+
raise # Re-raise other ValueError exceptions

0 commit comments

Comments
 (0)