Skip to content

Commit b232cae

Browse files
authored
Merge branch 'main' into update-ui-prod-version
2 parents bd03aa9 + 9d9392b commit b232cae

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

51 files changed

+1060
-2269
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ profile = "black"
160160

161161

162162
[tool.mypy]
163-
files = ["src/guidellm", "tests"]
163+
files = ["src/guidellm"]
164164
python_version = '3.10'
165165
warn_redundant_casts = true
166166
warn_unused_ignores = false

src/guidellm/__main__.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
try:
3434
import uvloop
3535
except ImportError:
36-
uvloop = None # type: ignore[assignment] # Optional dependency
36+
uvloop = None # type: ignore[assignment] # Optional dependency
3737

3838
from guidellm.backends import BackendType
3939
from guidellm.benchmark import (
@@ -116,6 +116,7 @@ def benchmark():
116116
)
117117
@click.option(
118118
"--scenario",
119+
"-c",
119120
type=cli_tools.Union(
120121
click.Path(
121122
exists=True,
@@ -156,8 +157,9 @@ def benchmark():
156157
)
157158
@click.option(
158159
"--rate",
159-
type=float,
160-
multiple=True,
160+
type=str,
161+
callback=cli_tools.parse_list_floats,
162+
multiple=False,
161163
default=BenchmarkGenerativeTextArgs.get_default("rate"),
162164
help=(
163165
"Benchmark rate(s) to test. Meaning depends on profile: "
@@ -383,16 +385,18 @@ def run(**kwargs):
383385
kwargs.get("data_args"), default=[], simplify_single=False
384386
)
385387
kwargs["rate"] = cli_tools.format_list_arg(
386-
kwargs.get("rate"), default=None, simplify_single=True
388+
kwargs.get("rate"), default=None, simplify_single=False
387389
)
388390

389391
disable_console_outputs = kwargs.pop("disable_console_outputs", False)
390392
display_scheduler_stats = kwargs.pop("display_scheduler_stats", False)
391393
disable_progress = kwargs.pop("disable_progress", False)
392394

393395
try:
396+
# Only set CLI args that differ from click defaults
397+
new_kwargs = cli_tools.set_if_not_default(click.get_current_context(), **kwargs)
394398
args = BenchmarkGenerativeTextArgs.create(
395-
scenario=kwargs.pop("scenario", None), **kwargs
399+
scenario=new_kwargs.pop("scenario", None), **new_kwargs
396400
)
397401
except ValidationError as err:
398402
# Translate pydantic valdation error to click argument error

src/guidellm/benchmark/benchmarker.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import uuid
1414
from abc import ABC
1515
from collections.abc import AsyncIterator, Iterable
16-
from typing import Generic
16+
from typing import Any, Generic
1717

1818
from guidellm.benchmark.profile import Profile
1919
from guidellm.benchmark.progress import BenchmarkerProgress
@@ -57,6 +57,7 @@ async def run(
5757
backend: BackendInterface[RequestT, ResponseT],
5858
profile: Profile,
5959
environment: Environment,
60+
data: list[Any],
6061
progress: BenchmarkerProgress[BenchmarkT] | None = None,
6162
sample_requests: int | None = 20,
6263
warmup: float | None = None,
@@ -149,6 +150,7 @@ async def run(
149150
environment=environment,
150151
strategy=strategy,
151152
constraints=constraints,
153+
data=data,
152154
)
153155
if progress:
154156
await progress.on_benchmark_complete(benchmark)

src/guidellm/benchmark/entrypoints.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,7 @@ async def benchmark_generative_text(
436436
backend=backend,
437437
profile=profile,
438438
environment=NonDistributedEnvironment(),
439+
data=args.data,
439440
progress=progress,
440441
sample_requests=args.sample_requests,
441442
warmup=args.warmup,

src/guidellm/benchmark/output.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -649,6 +649,8 @@ def _get_benchmark_status_metrics_stats(
649649
status_dist_summary: StatusDistributionSummary = getattr(
650650
benchmark.metrics, metric
651651
)
652+
if not hasattr(status_dist_summary, status):
653+
return [], []
652654
dist_summary: DistributionSummary = getattr(status_dist_summary, status)
653655

654656
headers = [
@@ -688,7 +690,7 @@ def _get_benchmark_extras_headers_and_values(
688690
values: list[str] = [
689691
benchmark.benchmarker.profile.model_dump_json(),
690692
json.dumps(benchmark.benchmarker.backend),
691-
json.dumps(benchmark.benchmarker.requests["attributes"]["data"]),
693+
json.dumps(benchmark.benchmarker.requests["data"]),
692694
]
693695

694696
if len(headers) != len(values):

src/guidellm/benchmark/schemas.py

Lines changed: 53 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,17 @@
2323
from typing import Any, ClassVar, Literal, TypeVar, cast
2424

2525
import yaml
26-
from pydantic import ConfigDict, Field, computed_field, model_serializer
26+
from pydantic import (
27+
AliasChoices,
28+
AliasGenerator,
29+
ConfigDict,
30+
Field,
31+
ValidationError,
32+
ValidatorFunctionWrapHandler,
33+
computed_field,
34+
field_validator,
35+
model_serializer,
36+
)
2737
from torch.utils.data import Sampler
2838
from transformers import PreTrainedTokenizerBase
2939

@@ -1142,7 +1152,8 @@ def update_estimate(
11421152
)
11431153
request_duration = (
11441154
(request_end_time - request_start_time)
1145-
if request_end_time and request_start_time else None
1155+
if request_end_time and request_start_time
1156+
else None
11461157
)
11471158

11481159
# Always track concurrency
@@ -1669,11 +1680,12 @@ def compile(
16691680
estimated_state: EstimatedBenchmarkState,
16701681
scheduler_state: SchedulerState,
16711682
profile: Profile,
1672-
requests: Iterable,
1683+
requests: Iterable, # noqa: ARG003
16731684
backend: BackendInterface,
16741685
environment: Environment,
16751686
strategy: SchedulingStrategy,
16761687
constraints: dict[str, dict[str, Any]],
1688+
data: list[Any],
16771689
) -> GenerativeBenchmark:
16781690
"""
16791691
Compile final generative benchmark from accumulated state.
@@ -1702,7 +1714,7 @@ def compile(
17021714
),
17031715
benchmarker=BenchmarkerDict(
17041716
profile=profile,
1705-
requests=InfoMixin.extract_from_obj(requests),
1717+
requests={"data": data},
17061718
backend=backend.info,
17071719
environment=environment.info,
17081720
),
@@ -1786,9 +1798,8 @@ def create(
17861798
scenario_data = scenario_data["args"]
17871799
constructor_kwargs.update(scenario_data)
17881800

1789-
for key, value in kwargs.items():
1790-
if value != cls.get_default(key):
1791-
constructor_kwargs[key] = value
1801+
# Apply overrides from kwargs
1802+
constructor_kwargs.update(kwargs)
17921803

17931804
return cls.model_validate(constructor_kwargs)
17941805

@@ -1817,13 +1828,19 @@ def get_default(cls: type[BenchmarkGenerativeTextArgs], field: str) -> Any:
18171828
else:
18181829
return factory({}) # type: ignore[call-arg] # Confirmed correct at runtime by code above
18191830

1820-
1821-
18221831
model_config = ConfigDict(
18231832
extra="ignore",
18241833
use_enum_values=True,
18251834
from_attributes=True,
18261835
arbitrary_types_allowed=True,
1836+
validate_by_alias=True,
1837+
validate_by_name=True,
1838+
alias_generator=AliasGenerator(
1839+
# Support field names with hyphens
1840+
validation_alias=lambda field_name: AliasChoices(
1841+
field_name, field_name.replace("_", "-")
1842+
),
1843+
),
18271844
)
18281845

18291846
# Required
@@ -1837,7 +1854,7 @@ def get_default(cls: type[BenchmarkGenerativeTextArgs], field: str) -> Any:
18371854
profile: StrategyType | ProfileType | Profile = Field(
18381855
default="sweep", description="Benchmark profile or scheduling strategy type"
18391856
)
1840-
rate: float | list[float] | None = Field(
1857+
rate: list[float] | None = Field(
18411858
default=None, description="Request rate(s) for rate-based scheduling"
18421859
)
18431860
# Backend configuration
@@ -1870,6 +1887,12 @@ def get_default(cls: type[BenchmarkGenerativeTextArgs], field: str) -> Any:
18701887
data_request_formatter: DatasetPreprocessor | dict[str, str] | str = Field(
18711888
default="chat_completions",
18721889
description="Request formatting preprocessor or template name",
1890+
validation_alias=AliasChoices(
1891+
"data_request_formatter",
1892+
"data-request-formatter",
1893+
"request_type",
1894+
"request-type",
1895+
),
18731896
)
18741897
data_collator: Callable | Literal["generative"] | None = Field(
18751898
default="generative", description="Data collator for batch processing"
@@ -1930,6 +1953,26 @@ def get_default(cls: type[BenchmarkGenerativeTextArgs], field: str) -> Any:
19301953
default=None, description="Maximum global error rate (0-1) before stopping"
19311954
)
19321955

1956+
@field_validator("data", "data_args", "rate", mode="wrap")
1957+
@classmethod
1958+
def single_to_list(
1959+
cls, value: Any, handler: ValidatorFunctionWrapHandler
1960+
) -> list[Any]:
1961+
"""
1962+
Ensures field is always a list.
1963+
1964+
:param value: Input value for the 'data' field
1965+
:return: List of data sources
1966+
"""
1967+
try:
1968+
return handler(value)
1969+
except ValidationError as err:
1970+
# If validation fails, try wrapping the value in a list
1971+
if err.errors()[0]["type"] == "list_type":
1972+
return handler([value])
1973+
else:
1974+
raise
1975+
19331976
@model_serializer
19341977
def serialize_model(self):
19351978
"""
Lines changed: 80 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
from __future__ import annotations
22

3-
import contextlib
43
from collections.abc import Callable
54
from typing import Any, Protocol, Union, runtime_checkable
65

7-
from datasets import Dataset, IterableDataset
6+
from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
87
from transformers import PreTrainedTokenizerBase
98

109
from guidellm.data.utils import resolve_dataset_split
@@ -29,7 +28,7 @@ def __call__(
2928
processor_factory: Callable[[], PreTrainedTokenizerBase],
3029
random_seed: int,
3130
**data_kwargs: dict[str, Any],
32-
) -> dict[str, list]: ...
31+
) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict: ...
3332

3433

3534
class DatasetDeserializerFactory(
@@ -47,51 +46,16 @@ def deserialize(
4746
remove_columns: list[str] | None = None,
4847
**data_kwargs: dict[str, Any],
4948
) -> Dataset | IterableDataset:
50-
dataset = None
49+
dataset: Dataset
5150

5251
if type_ is None:
53-
errors = []
54-
# Note: There is no priority order for the deserializers, so all deserializers
55-
# must be mutually exclusive to ensure deterministic behavior.
56-
for name, deserializer in cls.registry.items():
57-
deserializer_fn: DatasetDeserializer = (
58-
deserializer() if isinstance(deserializer, type) else deserializer
59-
)
60-
61-
try:
62-
with contextlib.suppress(DataNotSupportedError):
63-
dataset = deserializer_fn(
64-
data=data,
65-
processor_factory=processor_factory,
66-
random_seed=random_seed,
67-
**data_kwargs,
68-
)
69-
except Exception as e:
70-
errors.append(e)
71-
72-
if dataset is not None:
73-
break # Found one that works. Continuing could overwrite it.
74-
75-
if dataset is None and len(errors) > 0:
76-
raise DataNotSupportedError(f"data deserialization failed; {len(errors)} errors occurred while "
77-
f"attempting to deserialize data {data}: {errors}")
78-
79-
elif deserializer := cls.get_registered_object(type_) is not None:
80-
deserializer_fn: DatasetDeserializer = (
81-
deserializer() if isinstance(deserializer, type) else deserializer
52+
dataset = cls._deserialize_with_registered_deserializers(
53+
data, processor_factory, random_seed, **data_kwargs
8254
)
8355

84-
dataset = deserializer_fn(
85-
data=data,
86-
processor_factory=processor_factory,
87-
random_seed=random_seed,
88-
**data_kwargs,
89-
)
90-
91-
if dataset is None:
92-
raise DataNotSupportedError(
93-
f"No suitable deserializer found for data {data} "
94-
f"with kwargs {data_kwargs} and deserializer type {type_}."
56+
else:
57+
dataset = cls._deserialize_with_specified_deserializer(
58+
data, type_, processor_factory, random_seed, **data_kwargs
9559
)
9660

9761
if resolve_split:
@@ -107,3 +71,75 @@ def deserialize(
10771
dataset = dataset.remove_columns(remove_columns)
10872

10973
return dataset
74+
75+
@classmethod
76+
def _deserialize_with_registered_deserializers(
77+
cls,
78+
data: Any,
79+
processor_factory: Callable[[], PreTrainedTokenizerBase],
80+
random_seed: int = 42,
81+
**data_kwargs: dict[str, Any],
82+
) -> Dataset:
83+
if cls.registry is None:
84+
raise RuntimeError("registry is None; cannot deserialize dataset")
85+
dataset: Dataset | None = None
86+
87+
errors: dict[str, Exception] = {}
88+
# Note: There is no priority order for the deserializers, so all deserializers
89+
# must be mutually exclusive to ensure deterministic behavior.
90+
for _name, deserializer in cls.registry.items():
91+
deserializer_fn: DatasetDeserializer = (
92+
deserializer() if isinstance(deserializer, type) else deserializer
93+
)
94+
95+
try:
96+
dataset = deserializer_fn(
97+
data=data,
98+
processor_factory=processor_factory,
99+
random_seed=random_seed,
100+
**data_kwargs,
101+
)
102+
except Exception as e: # noqa: BLE001 # The exceptions are saved.
103+
errors[_name] = e
104+
105+
if dataset is not None:
106+
return dataset # Success
107+
108+
if len(errors) > 0:
109+
err_msgs = ""
110+
111+
def sort_key(item):
112+
return (isinstance(item[1], DataNotSupportedError), item[0])
113+
114+
for key, err in sorted(errors.items(), key=sort_key):
115+
err_msgs += f"\n - Deserializer '{key}': ({type(err).__name__}) {err}"
116+
raise ValueError(
117+
"Data deserialization failed, likely because the input doesn't "
118+
f"match any of the input formats. See the {len(errors)} error(s) that "
119+
f"occurred while attempting to deserialize the data {data}:{err_msgs}"
120+
)
121+
return dataset
122+
123+
@classmethod
124+
def _deserialize_with_specified_deserializer(
125+
cls,
126+
data: Any,
127+
type_: str,
128+
processor_factory: Callable[[], PreTrainedTokenizerBase],
129+
random_seed: int = 42,
130+
**data_kwargs: dict[str, Any],
131+
) -> Dataset:
132+
deserializer_from_type = cls.get_registered_object(type_)
133+
if deserializer_from_type is None:
134+
raise ValueError(f"Deserializer type '{type_}' is not registered.")
135+
if isinstance(deserializer_from_type, type):
136+
deserializer_fn = deserializer_from_type()
137+
else:
138+
deserializer_fn = deserializer_from_type
139+
140+
return deserializer_fn(
141+
data=data,
142+
processor_factory=processor_factory,
143+
random_seed=random_seed,
144+
**data_kwargs,
145+
)

0 commit comments

Comments
 (0)