Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 85 additions & 29 deletions helion/autotuner/base_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Literal
from typing import NoReturn
from typing import cast
from unittest.mock import patch
Expand Down Expand Up @@ -377,7 +378,14 @@ def start_precompile_and_check_for_hangs(

def parallel_benchmark(
self, configs: list[Config], *, desc: str = "Benchmarking"
) -> list[tuple[Config, Callable[..., object], float]]:
) -> list[
tuple[
Config,
Callable[..., object],
float,
Literal["ok", "error", "timeout"],
]
]:
"""
Benchmark multiple configurations in parallel.

Expand All @@ -389,35 +397,55 @@ def parallel_benchmark(
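A list of (config, compiled function, performance, status) tuples, one per input config; performance is inf when precompilation failed, and status is "ok", "error", or "timeout".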
"""
fns = [self.kernel.compile_config(c, allow_print=False) for c in configs]
precompile_status: list[Literal["ok", "error", "timeout"]]
if self.settings.autotune_precompile:
futures = [
*starmap(
self.start_precompile_and_check_for_hangs,
zip(configs, fns, strict=True),
)
]
is_workings = PrecompileFuture.wait_for_all(
[
*starmap(
self.start_precompile_and_check_for_hangs,
zip(configs, fns, strict=True),
)
],
futures,
desc=f"{desc} precompiling"
if self.settings.autotune_progress_bar
else None,
)
precompile_status = []
for future, ok in zip(futures, is_workings, strict=True):
reason = future.failure_reason
if ok:
precompile_status.append("ok")
elif reason == "timeout":
precompile_status.append("timeout")
else:
precompile_status.append("error")
else:
is_workings = [True] * len(configs)
results = []
precompile_status = ["ok"] * len(configs)
results: list[
tuple[
Config, Callable[..., object], float, Literal["ok", "error", "timeout"]
]
] = []

# Render a progress bar only when the user requested it.
iterator = iter_with_progress(
zip(configs, fns, is_workings, strict=True),
zip(configs, fns, is_workings, precompile_status, strict=True),
total=len(configs),
description=f"{desc} exploring neighbors",
enabled=self.settings.autotune_progress_bar,
)
for config, fn, is_working in iterator:
for config, fn, is_working, reason in iterator:
status: Literal["ok", "error", "timeout"]
if is_working:
# benchmark one-by-one to avoid noisy results
results.append((config, fn, self.benchmark_function(config, fn)))
perf = self.benchmark_function(config, fn)
status = "ok" if math.isfinite(perf) else "error"
results.append((config, fn, perf, status))
else:
results.append((config, fn, inf))
status = "timeout" if reason == "timeout" else "error"
results.append((config, fn, inf, status))
return results

def autotune(self, *, skip_cache: bool = False) -> Config:
Expand Down Expand Up @@ -486,6 +514,7 @@ class PopulationMember:
perfs: list[float]
flat_values: FlatConfig
config: Config
status: Literal["ok", "error", "timeout", "unknown"] = "unknown"

@property
def perf(self) -> float:
Expand Down Expand Up @@ -581,7 +610,8 @@ def benchmark_flat(self, flat_values: FlatConfig) -> PopulationMember:
"""
config = self.config_gen.unflatten(flat_values)
fn, perf = self.benchmark(config)
return PopulationMember(fn, [perf], flat_values, config)
status: Literal["ok", "error"] = "ok" if math.isfinite(perf) else "error"
return PopulationMember(fn, [perf], flat_values, config, status=status)

def parallel_benchmark_flat(
self, to_check: list[FlatConfig]
Expand Down Expand Up @@ -622,14 +652,15 @@ def parallel_benchmark_population(
members: The list of population members to benchmark.
desc: Description for the progress bar.
"""
for member, (config_out, fn, perf) in zip(
for member, (config_out, fn, perf, status) in zip(
members,
self.parallel_benchmark([m.config for m in members], desc=desc),
strict=True,
):
assert config_out is member.config
member.perfs.append(perf)
member.fn = fn
member.status = status
return members

def compare(self, a: PopulationMember, b: PopulationMember) -> int:
Expand Down Expand Up @@ -730,23 +761,39 @@ def population_statistics(population: list[PopulationMember]) -> str:
A string summarizing the performance of the population.
"""
population = sorted(population, key=performance)
if math.isinf(population[-1].perf):
working = [x for x in population if not math.isinf(x.perf)]
if len(working) == 0:
raise exc.NoConfigFound
return (
f"failed={len(population) - len(working)} "
f"min={working[0].perf:.4f} "
f"mid={working[len(working) // 2].perf:.4f} "
f"max={working[-1].perf:.4f} "
f"best={population[0].config!s}"
status_counts: collections.Counter[str] = collections.Counter()
working: list[PopulationMember] = []
for member in population:
status = member.status
if math.isfinite(member.perf):
working.append(member)
if status not in {"ok", "error", "timeout"}:
status = "ok"
else:
if status not in {"error", "timeout"}:
status = "error"
if status == "timeout":
status_counts["timeout"] += 1
elif status == "error":
status_counts["error"] += 1
else:
status_counts["ok"] += 1
if len(working) == 0:
raise exc.NoConfigFound
parts: list[str] = []
for label in ("error", "timeout", "ok"):
count = status_counts.get(label, 0)
if count:
parts.append(f"{label}={count}")
parts.extend(
(
f"min={working[0].perf:.4f}",
f"mid={working[len(working) // 2].perf:.4f}",
f"max={working[-1].perf:.4f}",
f"best={population[0].config!s}",
)
return (
f"min={population[0].perf:.4f} "
f"mid={population[len(population) // 2].perf:.4f} "
f"max={population[-1].perf:.4f} "
f"best={population[0].config!s}"
)
return " ".join(parts)


@dataclasses.dataclass
Expand Down Expand Up @@ -777,6 +824,7 @@ class PrecompileFuture:
_result_received: bool = False
remote_error: RemoteError | None = None
_remote_error_handled: bool = False
failure_reason: Literal["ok", "error", "timeout"] | None = None

@property
def elapsed(self) -> float:
Expand Down Expand Up @@ -834,6 +882,7 @@ def skip(search: BaseSearch, config: Config, ok: bool) -> PrecompileFuture:
_result_received=True,
remote_error=None,
_remote_error_handled=True,
failure_reason="ok" if ok else "error",
)

def __call__(self) -> bool:
Expand Down Expand Up @@ -892,6 +941,8 @@ def wait_for_all(
result = []
for f in futures:
assert f.ok is not None
if f.failure_reason is None:
f.failure_reason = "ok" if f.ok else "error"
result.append(f.ok)
return result

Expand Down Expand Up @@ -945,6 +996,10 @@ def _mark_complete(self) -> bool:
self.ok = process.exitcode == 0
self._recv_result(block=True)
self._handle_remote_error(raise_on_raise=False)
if self.ok:
self.failure_reason = "ok"
elif self.failure_reason is None:
self.failure_reason = "error"
return self.ok
process.terminate()
process.join(10)
Expand All @@ -960,6 +1015,7 @@ def _mark_complete(self) -> bool:
self.search.log.warning(msg)

self.ok = False
self.failure_reason = "timeout"
self._recv_result(block=False)
self._handle_remote_error(raise_on_raise=False)
return False
Expand Down
2 changes: 1 addition & 1 deletion helion/autotuner/finite_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(
def _autotune(self) -> Config:
best_config = None
best_time = float("inf")
for config, _fn, time in self.parallel_benchmark(
for config, _fn, time, _status in self.parallel_benchmark(
self.configs, desc="Benchmarking"
):
if time < best_time:
Expand Down
Loading