Skip to content

Commit 2ea8c0e

Browse files
committed
Fix saturation detection and harden load generator (Clean)
- Fix a 1-second misalignment in the saturation detection logic
- Add a robust test suite for saturation detection
1 parent 599d379 commit 2ea8c0e

File tree

7 files changed

+833
-52
lines changed

7 files changed

+833
-52
lines changed

docs/config.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,6 @@ load:
114114
timeout: 60 # Length of time to run load to determine saturation
115115
num_stages: 5 # Number of stages to generate
116116
stage_duration: 180 # Duration of each generated stage
117-
saturation_percentile: 95 # Percentile of sampled rates to select as saturation point
118117
```
119118
120119
### Model Server

examples/vllm/config-sweep.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ load:
66
timeout: 60 # Length of time to run load to determine saturation
77
num_stages: 10 # Number of stages to generate
88
stage_duration: 60 # Duration of each generated stage
9-
saturation_percentile: 95 # Percentile of sampled rates to select as saturation point
109
api:
1110
type: completion
1211
server:

inference_perf/config.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,6 @@ class SweepConfig(BaseModel):
178178
timeout: float = 60
179179
num_stages: int = 5
180180
stage_duration: int = 180
181-
saturation_percentile: float = 95
182181

183182

184183
class MultiLoRAConfig(BaseModel):

inference_perf/loadgen/load_generator.py

Lines changed: 170 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -413,79 +413,200 @@ async def preprocess(
413413
cancel_signal: SyncEvent,
414414
) -> None:
415415
"""
416-
Runs a preliminary load test to automatically determine the server's saturation point
417-
and generate a suitable series of load stages for the main benchmark.
416+
Runs a Stepped Load (Ramping) test to determine the server's saturation point.
418417
419-
An aggregator task samples the active requests and then the burn down rate is
420-
calculated from the samples. Saturation is derived from a percentile of the
421-
sampled burn down rates.
418+
Iterates through increasing request rates (steps). For each step:
419+
1. Runs for a specific duration.
420+
2. Measures the actual throughput (finished requests / time).
421+
3. Checks if the system is saturated (Actual Throughput < Target Rate * tolerance).
422+
423+
If saturation is detected, the previous successful rate is used as the saturation point.
422424
"""
423-
logger.info("Running preprocessing stage")
424-
results: List[Tuple[float, int]] = []
425+
logger.info("Running preprocessing stage (Stepped Load)")
425426

426427
if self.sweep_config is None:
427428
raise Exception("sweep_config cannot be none")
428429

429-
# Aggregator collects timestamped value of active_requests throughout the preprocessing
430-
async def aggregator() -> None:
430+
# Adaptive Saturation Detection (Exponential Search + Binary Refinement)
431+
# Phase 1: Probe at moderate rate
432+
step_duration = 10
433+
if self.sweep_config.stage_duration > 10:
434+
step_duration = self.sweep_config.stage_duration
435+
436+
throughput_tolerance = 0.90
437+
max_rate_cap = 50000.0
438+
439+
samples: list[tuple[float, int]] = []
440+
441+
# Start aggregator for this step
442+
async def aggregator(results: List[Tuple[float, int]] = samples) -> None:
431443
while True:
432-
results.append((time.perf_counter(), active_requests_counter.value))
444+
results.append((time.perf_counter(), finished_requests_counter.value))
433445
await sleep(0.5)
434446

435447
aggregator_task = create_task(aggregator())
436448

437-
stage_id = -1
438-
duration = 5
439-
rate = self.sweep_config.num_requests / duration
440-
timeout = self.sweep_config.timeout
441-
start_time = time.perf_counter()
442-
await self.run_stage(
443-
stage_id,
444-
rate,
445-
duration,
446-
request_queue,
447-
active_requests_counter,
448-
finished_requests_counter,
449-
request_phase,
450-
timeout=timeout,
451-
cancel_signal=cancel_signal,
452-
)
449+
lower_bound = 0.0
450+
upper_bound = 0.0
451+
saturation_point = 0.0
452+
found_saturation = False
453+
454+
async def measure_rate(target_rate: float) -> tuple[bool, float]:
455+
nonlocal samples
456+
samples.clear()
457+
with finished_requests_counter.get_lock():
458+
finished_requests_counter.value = 0
459+
460+
step_start_time = time.perf_counter()
461+
462+
# Ensure enough duration for low rates to get at least 2 requests
463+
current_step_duration = step_duration
464+
if target_rate > 0 and (2.0 / target_rate) > current_step_duration:
465+
current_step_duration = 2.0 / target_rate
466+
# Cap at some reasonable max to avoid waiting forever on 0.0001 QPS
467+
if current_step_duration > 60:
468+
current_step_duration = 60
469+
470+
logger.info(f"Preprocessing Step: {target_rate:.2f} QPS for {current_step_duration}s")
471+
472+
await self.run_stage(
473+
-1,
474+
target_rate,
475+
current_step_duration,
476+
request_queue,
477+
active_requests_counter,
478+
finished_requests_counter,
479+
request_phase,
480+
timeout=timeout if (timeout := self.sweep_config.timeout) > current_step_duration else current_step_duration + 5,
481+
cancel_signal=cancel_signal,
482+
)
483+
484+
# Wait for aggregator to catch up (ensures final sample is collected)
485+
await sleep(0.6)
486+
487+
# Analysis
488+
warmup_delay = 1.0
489+
valid_results = [
490+
(t, c)
491+
for t, c in samples
492+
if t >= step_start_time + warmup_delay and t < step_start_time + current_step_duration + warmup_delay
493+
]
494+
495+
if len(valid_results) < 2:
496+
logger.warning(f"Step {target_rate}: Insufficient samples. Using total/duration fallback.")
497+
# Fallback: Total finished / Duration
498+
with finished_requests_counter.get_lock():
499+
total = finished_requests_counter.value
500+
measured = total / current_step_duration
501+
else:
502+
timestamps = [r[0] for r in valid_results]
503+
counts = [r[1] for r in valid_results]
504+
try:
505+
measured = (counts[-1] - counts[0]) / (timestamps[-1] - timestamps[0])
506+
except ZeroDivisionError:
507+
measured = 0.0
508+
509+
logger.info(f"Step {target_rate}: Target {target_rate:.2f}, Measured {measured:.2f}")
510+
511+
is_saturated = measured < target_rate * throughput_tolerance
512+
if is_saturated:
513+
logger.warning(f"Saturation detected! Measured {measured:.2f} < {throughput_tolerance*100}% of {target_rate:.2f}")
514+
515+
return is_saturated, measured
516+
517+
# Execution Logic
518+
519+
# 1. Probe
520+
is_sat, measured = await measure_rate(1.0)
521+
522+
lower_bound = 0.1 # A very low safe rate
523+
upper_bound = 1.0
524+
525+
if is_sat:
526+
# Saturated at 1.0. Search DOWN.
527+
# Check 0.1
528+
is_sat_low, measured_low = await measure_rate(0.1)
529+
if is_sat_low:
530+
# Even 0.1 is saturated.
531+
logger.warning("System saturated even at 0.1 QPS. Using minimal capacity.")
532+
saturation_point = measured_low
533+
found_saturation = True
534+
lower_bound = 0.0 # effectively
535+
upper_bound = 0.1
536+
else:
537+
# 0.1 is Safe, 1.0 is Saturated.
538+
lower_bound = 0.1
539+
upper_bound = 1.0
540+
found_saturation = True
541+
else:
542+
# 1.0 is Safe. Search UP.
543+
lower_bound = 1.0
544+
current_rate = 2.0
545+
found_upper = False
546+
547+
while current_rate <= max_rate_cap:
548+
is_sat, measured = await measure_rate(current_rate)
549+
if is_sat:
550+
upper_bound = current_rate
551+
found_upper = True
552+
found_saturation = True
553+
break
554+
else:
555+
lower_bound = current_rate
556+
current_rate *= 2.0
557+
558+
if not found_upper:
559+
logger.info("Hit max rate cap without saturation.")
560+
saturation_point = lower_bound # or measured max
561+
upper_bound = lower_bound # No upper bound found implies linear/max
562+
563+
if found_saturation and lower_bound < upper_bound:
564+
# Binary Search Refinement
565+
# We have [lower, upper]. Do 3 steps.
566+
for _ in range(3):
567+
mid = (lower_bound + upper_bound) / 2
568+
is_sat, measured = await measure_rate(mid)
569+
if is_sat:
570+
upper_bound = mid
571+
else:
572+
lower_bound = mid
573+
574+
saturation_point = lower_bound # Conservative estimate
575+
576+
if not found_saturation:
577+
# If we never found saturation (e.g. max cap reached), use lower_bound (max safe)
578+
saturation_point = lower_bound
453579

454580
aggregator_task.cancel()
455581
try:
456582
await aggregator_task
457583
except CancelledError:
458584
pass
459585

460-
# Ensure that we don't calculate saturation based on the post-timeout drain
461-
results = [(timestamp, requests) for timestamp, requests in results if timestamp < start_time + timeout]
462-
# Calculate the sampled QPS by interval between the samples
463-
rates = [
464-
abs((current_requests - previous_requests) / (current_timestamp - previous_timestamp))
465-
for (current_timestamp, current_requests), (previous_timestamp, previous_requests) in zip(
466-
results[1:], results[:-1], strict=True
467-
)
468-
if current_requests - previous_requests < 0
469-
]
586+
logger.info(f"Saturation point estimated at {saturation_point:0.2f} QPS.")
470587

471-
if len(rates) <= 1:
472-
raise Exception(
473-
"Loadgen preprocessing failed to gather enough samples to determine saturation, try increasing the num_requests or timeout"
474-
)
588+
def generateRates(target_request_rate: float, size: int, gen_type: StageGenType) -> List[float]:
475589

476-
# Generate new stages
477-
logger.debug(f"Determining saturation from rates: {[f'{rate:0.2f}' for rate in sorted(rates)]}")
478-
saturation_point = float(np.percentile(rates, self.sweep_config.saturation_percentile))
479-
logger.info(f"Saturation point estimated at {saturation_point:0.2f} concurrent requests.")
590+
# Calculate start_rate based on target_request_rate and size to ensure proper scaling
591+
# for both low and high target rates.
592+
start_rate = target_request_rate / size
480593

481-
def generateRates(target_request_rate: float, size: int, gen_type: StageGenType) -> List[float]:
482594
if gen_type == StageGenType.GEOM:
483-
return [float(round(1 + target_request_rate - rr, 2)) for rr in np.geomspace(target_request_rate, 1, num=size)]
595+
# Avoid log(0) or similar issues if target is low, but usually target > 1
596+
return [
597+
float(round(start_rate + target_request_rate - rr, 2))
598+
for rr in np.geomspace(target_request_rate, start_rate, num=size)
599+
]
484600
elif gen_type == StageGenType.LINEAR:
485-
return [float(round(r, 2)) for r in np.linspace(1, target_request_rate, size)]
601+
return [float(round(r, 2)) for r in np.linspace(start_rate, target_request_rate, size)]
602+
603+
# Regenerate stages based on found saturation
604+
# If we found saturation, we typically want stages leading up to it
605+
if saturation_point <= 0:
606+
raise Exception("Loadgen preprocessing failed to determine a valid saturation point.")
486607

487-
rates = generateRates(saturation_point, self.sweep_config.num_stages, self.sweep_config.type)
488-
self.stages = [StandardLoadStage(rate=r, duration=self.sweep_config.stage_duration) for r in rates]
608+
gen_rates = generateRates(saturation_point * 1.8, self.sweep_config.num_stages, self.sweep_config.type)
609+
self.stages = [StandardLoadStage(rate=r, duration=self.sweep_config.stage_duration) for r in gen_rates]
489610
logger.info(f"Generated load stages: {[s.rate for s in self.stages]}")
490611

491612
async def mp_run(self, client: ModelServerClient) -> None:

0 commit comments

Comments (0)