diff --git a/src/guidellm/__main__.py b/src/guidellm/__main__.py index e75f5d25..1faaaafa 100644 --- a/src/guidellm/__main__.py +++ b/src/guidellm/__main__.py @@ -33,7 +33,7 @@ try: import uvloop except ImportError: - uvloop = None # type: ignore[assignment] # Optional dependency + uvloop = None # type: ignore[assignment] # Optional dependency from guidellm.backends import BackendType from guidellm.benchmark import ( @@ -116,6 +116,7 @@ def benchmark(): ) @click.option( "--scenario", + "-c", type=cli_tools.Union( click.Path( exists=True, @@ -392,8 +393,10 @@ def run(**kwargs): disable_progress = kwargs.pop("disable_progress", False) try: + # Only set CLI args that differ from click defaults + new_kwargs = cli_tools.set_if_not_default(click.get_current_context(), **kwargs) args = BenchmarkGenerativeTextArgs.create( - scenario=kwargs.pop("scenario", None), **kwargs + scenario=new_kwargs.pop("scenario", None), **new_kwargs ) except ValidationError as err: # Translate pydantic valdation error to click argument error diff --git a/src/guidellm/benchmark/schemas.py b/src/guidellm/benchmark/schemas.py index 8ddbc458..b523d75d 100644 --- a/src/guidellm/benchmark/schemas.py +++ b/src/guidellm/benchmark/schemas.py @@ -23,7 +23,15 @@ from typing import Any, ClassVar, Literal, TypeVar, cast import yaml -from pydantic import ConfigDict, Field, computed_field, model_serializer +from pydantic import ( + ConfigDict, + Field, + ValidationError, + ValidatorFunctionWrapHandler, + computed_field, + field_validator, + model_serializer, +) from torch.utils.data import Sampler from transformers import PreTrainedTokenizerBase @@ -1142,7 +1150,8 @@ def update_estimate( ) request_duration = ( (request_end_time - request_start_time) - if request_end_time and request_start_time else None + if request_end_time and request_start_time + else None ) # Always track concurrency @@ -1669,7 +1678,7 @@ def compile( estimated_state: EstimatedBenchmarkState, scheduler_state: SchedulerState, profile: 
Profile, - requests: Iterable, + requests: Iterable, # noqa: ARG003 backend: BackendInterface, environment: Environment, strategy: SchedulingStrategy, @@ -1818,8 +1827,6 @@ def get_default(cls: type[BenchmarkGenerativeTextArgs], field: str) -> Any: else: return factory({}) # type: ignore[call-arg] # Confirmed correct at runtime by code above - - model_config = ConfigDict( extra="ignore", use_enum_values=True, @@ -1838,7 +1845,7 @@ def get_default(cls: type[BenchmarkGenerativeTextArgs], field: str) -> Any: profile: StrategyType | ProfileType | Profile = Field( default="sweep", description="Benchmark profile or scheduling strategy type" ) - rate: float | list[float] | None = Field( + rate: list[float] | None = Field( default=None, description="Request rate(s) for rate-based scheduling" ) # Backend configuration @@ -1931,6 +1938,26 @@ def get_default(cls: type[BenchmarkGenerativeTextArgs], field: str) -> Any: default=None, description="Maximum global error rate (0-1) before stopping" ) + @field_validator("data", "data_args", "rate", mode="wrap") + @classmethod + def single_to_list( + cls, value: Any, handler: ValidatorFunctionWrapHandler + ) -> list[Any]: + """ + Ensures field is always a list. 
+ + :param value: Input value for the field + :return: Value coerced to a list + """ + try: + return handler(value) + except ValidationError as err: + # If validation fails, try wrapping the value in a list + if err.errors()[0]["type"] == "list_type": + return handler([value]) + else: + raise + @model_serializer def serialize_model(self): """ diff --git a/src/guidellm/data/deserializers/synthetic.py b/src/guidellm/data/deserializers/synthetic.py index f1184e9e..e1df911a 100644 --- a/src/guidellm/data/deserializers/synthetic.py +++ b/src/guidellm/data/deserializers/synthetic.py @@ -9,7 +9,7 @@ import yaml from datasets import Features, IterableDataset, Value from faker import Faker -from pydantic import ConfigDict, Field, model_validator +from pydantic import ConfigDict, Field, ValidationError, model_validator from transformers import PreTrainedTokenizerBase from guidellm.data.deserializers.deserializer import ( @@ -242,6 +242,10 @@ def __call__( if (config := self._load_config_str(data)) is not None: return self(config, processor_factory, random_seed, **data_kwargs) + # Try to parse dict-like data directly + if (config := self._load_config_dict(data)) is not None: + return self(config, processor_factory, random_seed, **data_kwargs) + if not isinstance(data, SyntheticTextDatasetConfig): raise DataNotSupportedError( "Unsupported data for SyntheticTextDatasetDeserializer, " @@ -266,6 +270,15 @@ def __call__( ), ) + def _load_config_dict(self, data: Any) -> SyntheticTextDatasetConfig | None: + if not isinstance(data, dict | list): + return None + + try: + return SyntheticTextDatasetConfig.model_validate(data) + except ValidationError: + return None + def _load_config_file(self, data: Any) -> SyntheticTextDatasetConfig | None: if (not isinstance(data, str) and not isinstance(data, Path)) or ( not Path(data).is_file()