
Commit ba51acf

Better scenario from-file support (#432)
## Summary

Allows scenarios to pass a YAML- or JSON-formatted synthetic config directly to `data`, and automatically handles conversion from a single data entry to a list.

```yaml
---
target: "http://vllm-deployment.kserve-e2e-perf.svc.cluster.local:8000/"
profile: concurrent
rate: [1,2,4,8]
max_seconds: 120
data:
  prompt_tokens_min: 10
  prompt_tokens_max: 8192
  prompt_tokens: 4096
  prompt_tokens_stdev: 2048
  output_tokens_min: 10
  output_tokens_max: 2048
  output_tokens: 1024
  output_tokens_stdev: 512
```

---

- [x] "I certify that all code in this PR is my own, except as noted below."

## Use of AI

- [x] Includes AI-assisted code completion
- [ ] Includes code generated by an AI application
- [ ] Includes AI-generated tests (NOTE: AI written tests should have a docstring that includes `## WRITTEN BY AI ##`)
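For reference, a minimal sketch of what the scenario above parses to (assuming it is saved as `scenario.yaml`; this is plain PyYAML, not the guidellm loader): the `data` key arrives as a single mapping, which the new validator on `BenchmarkGenerativeTextArgs` then wraps into a one-element list.

```python
# Minimal sketch, assuming the YAML above is saved as "scenario.yaml".
# This only shows what the raw file parses to; guidellm's own loading and
# validation paths are in the diffs below.
import yaml

with open("scenario.yaml") as fh:
    scenario = yaml.safe_load(fh)

print(type(scenario["data"]))  # <class 'dict'> -- a single synthetic-data config
print(scenario["rate"])        # [1, 2, 4, 8]   -- already a list, passed through unchanged
```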
2 parents 65becf0 + a888a7c commit ba51acf

File tree

3 files changed: +52 −9 lines changed


src/guidellm/__main__.py

Lines changed: 5 additions & 2 deletions
```diff
@@ -33,7 +33,7 @@
 try:
     import uvloop
 except ImportError:
-    uvloop = None  # type: ignore[assignment] # Optional dependency
+    uvloop = None  # type: ignore[assignment] # Optional dependency

 from guidellm.backends import BackendType
 from guidellm.benchmark import (
@@ -116,6 +116,7 @@ def benchmark():
 )
 @click.option(
     "--scenario",
+    "-c",
     type=cli_tools.Union(
         click.Path(
             exists=True,
@@ -392,8 +393,10 @@ def run(**kwargs):
     disable_progress = kwargs.pop("disable_progress", False)

     try:
+        # Only set CLI args that differ from click defaults
+        new_kwargs = cli_tools.set_if_not_default(click.get_current_context(), **kwargs)
         args = BenchmarkGenerativeTextArgs.create(
-            scenario=kwargs.pop("scenario", None), **kwargs
+            scenario=new_kwargs.pop("scenario", None), **new_kwargs
         )
     except ValidationError as err:
         # Translate pydantic valdation error to click argument error
```
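`cli_tools.set_if_not_default` is referenced here but its body is not part of this diff. A hedged sketch of the idea, built on Click's parameter-source tracking (the implementation below is an assumption, not guidellm's code):

```python
# Hypothetical sketch of a set_if_not_default helper -- NOT the guidellm
# implementation. It keeps only the options the user explicitly supplied,
# so values coming from a --scenario file are not clobbered by click defaults.
from typing import Any

import click
from click.core import ParameterSource


def set_if_not_default(ctx: click.Context, **kwargs: Any) -> dict[str, Any]:
    """Return only the kwargs whose values did not come from click defaults."""
    return {
        name: value
        for name, value in kwargs.items()
        if ctx.get_parameter_source(name) != ParameterSource.DEFAULT
    }
```

Whatever the real helper does, the effect visible in the diff is that only explicitly passed CLI options reach `BenchmarkGenerativeTextArgs.create`, letting the scenario file supply the rest.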

src/guidellm/benchmark/schemas.py

Lines changed: 33 additions & 6 deletions
```diff
@@ -23,7 +23,15 @@
 from typing import Any, ClassVar, Literal, TypeVar, cast

 import yaml
-from pydantic import ConfigDict, Field, computed_field, model_serializer
+from pydantic import (
+    ConfigDict,
+    Field,
+    ValidationError,
+    ValidatorFunctionWrapHandler,
+    computed_field,
+    field_validator,
+    model_serializer,
+)
 from torch.utils.data import Sampler
 from transformers import PreTrainedTokenizerBase

@@ -1142,7 +1150,8 @@ def update_estimate(
         )
         request_duration = (
             (request_end_time - request_start_time)
-            if request_end_time and request_start_time else None
+            if request_end_time and request_start_time
+            else None
         )

         # Always track concurrency
@@ -1669,7 +1678,7 @@ def compile(
         estimated_state: EstimatedBenchmarkState,
         scheduler_state: SchedulerState,
         profile: Profile,
-        requests: Iterable,
+        requests: Iterable,  # noqa: ARG003
         backend: BackendInterface,
         environment: Environment,
         strategy: SchedulingStrategy,
@@ -1818,8 +1827,6 @@ def get_default(cls: type[BenchmarkGenerativeTextArgs], field: str) -> Any:
         else:
             return factory({})  # type: ignore[call-arg] # Confirmed correct at runtime by code above

-
-
     model_config = ConfigDict(
         extra="ignore",
         use_enum_values=True,
@@ -1838,7 +1845,7 @@ def get_default(cls: type[BenchmarkGenerativeTextArgs], field: str) -> Any:
     profile: StrategyType | ProfileType | Profile = Field(
         default="sweep", description="Benchmark profile or scheduling strategy type"
     )
-    rate: float | list[float] | None = Field(
+    rate: list[float] | None = Field(
         default=None, description="Request rate(s) for rate-based scheduling"
     )
     # Backend configuration
@@ -1931,6 +1938,26 @@ def get_default(cls: type[BenchmarkGenerativeTextArgs], field: str) -> Any:
         default=None, description="Maximum global error rate (0-1) before stopping"
     )

+    @field_validator("data", "data_args", "rate", mode="wrap")
+    @classmethod
+    def single_to_list(
+        cls, value: Any, handler: ValidatorFunctionWrapHandler
+    ) -> list[Any]:
+        """
+        Ensures field is always a list.
+
+        :param value: Input value for the 'data' field
+        :return: List of data sources
+        """
+        try:
+            return handler(value)
+        except ValidationError as err:
+            # If validation fails, try wrapping the value in a list
+            if err.errors()[0]["type"] == "list_type":
+                return handler([value])
+            else:
+                raise
+
     @model_serializer
     def serialize_model(self):
         """
```

src/guidellm/data/deserializers/synthetic.py

Lines changed: 14 additions & 1 deletion
```diff
@@ -9,7 +9,7 @@
 import yaml
 from datasets import Features, IterableDataset, Value
 from faker import Faker
-from pydantic import ConfigDict, Field, model_validator
+from pydantic import ConfigDict, Field, ValidationError, model_validator
 from transformers import PreTrainedTokenizerBase

 from guidellm.data.deserializers.deserializer import (
@@ -242,6 +242,10 @@ def __call__(
         if (config := self._load_config_str(data)) is not None:
             return self(config, processor_factory, random_seed, **data_kwargs)

+        # Try to parse dict-like data directly
+        if (config := self._load_config_dict(data)) is not None:
+            return self(config, processor_factory, random_seed, **data_kwargs)
+
         if not isinstance(data, SyntheticTextDatasetConfig):
             raise DataNotSupportedError(
                 "Unsupported data for SyntheticTextDatasetDeserializer, "
@@ -266,6 +270,15 @@ def __call__(
             ),
         )

+    def _load_config_dict(self, data: Any) -> SyntheticTextDatasetConfig | None:
+        if not isinstance(data, dict | list):
+            return None
+
+        try:
+            return SyntheticTextDatasetConfig.model_validate(data)
+        except ValidationError:
+            return None
+
     def _load_config_file(self, data: Any) -> SyntheticTextDatasetConfig | None:
         if (not isinstance(data, str) and not isinstance(data, Path)) or (
             not Path(data).is_file()
```
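The same probe-and-fall-through idea in isolation, using a hypothetical `ToyConfig` stand-in for `SyntheticTextDatasetConfig` (whose full field set is not shown in this diff):

```python
# Sketch of the dict-probing pattern in _load_config_dict above. ToyConfig is a
# hypothetical stand-in; a failed validation returns None so the deserializer
# can fall through to its other strategies.
from typing import Any

from pydantic import BaseModel, ValidationError


class ToyConfig(BaseModel):
    prompt_tokens: int
    output_tokens: int


def load_config_dict(data: Any) -> ToyConfig | None:
    if not isinstance(data, dict | list):
        return None  # only dict- or list-like inputs are considered
    try:
        return ToyConfig.model_validate(data)
    except ValidationError:
        return None  # not a synthetic config; let another deserializer claim it


print(load_config_dict({"prompt_tokens": 4096, "output_tokens": 1024}))
print(load_config_dict("not a mapping"))  # None
```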
