diff --git a/vllm_gaudi/extension/bucketing/linear.py b/vllm_gaudi/extension/bucketing/linear.py
index abd9a745..584eec44 100644
--- a/vllm_gaudi/extension/bucketing/linear.py
+++ b/vllm_gaudi/extension/bucketing/linear.py
@@ -9,9 +9,8 @@ from vllm_gaudi.extension.runtime import get_config
 
 
-
 class LinearBucketingStrategy:
-    def get_prompt_buckets(self, max_num_prefill_seqs, block_size, 
+    def get_prompt_buckets(self, max_num_prefill_seqs, block_size,
                            max_num_batched_tokens, max_model_len):
         use_merged_prefill = get_config().merged_prefill
         prefix_caching = get_config().prefix_caching
 
@@ -33,9 +32,9 @@ def get_prompt_buckets(self, max_num_prefill_seqs, block_size,
             new_prompt_bs_bucket_cfg = prompt_bs_bucket_cfg
             new_prompt_seq_bucket_cfg = prompt_seq_bucket_cfg
             msg = ('Merged prefill is enabled!\n'
-                'Overriding prompt bucketing settings!\n'
-                f'prompt bs cfg: {prev_prompt_bs_bucket_cfg} -> {new_prompt_bs_bucket_cfg}\n'
-                f'prompt seq cfg: {prev_prompt_seq_bucket_cfg} -> {new_prompt_seq_bucket_cfg}\n')
+                   'Overriding prompt bucketing settings!\n'
+                   f'prompt bs cfg: {prev_prompt_bs_bucket_cfg} -> {new_prompt_bs_bucket_cfg}\n'
+                   f'prompt seq cfg: {prev_prompt_seq_bucket_cfg} -> {new_prompt_seq_bucket_cfg}\n')
             logger().info(msg)
 
         msg = ("Prompt bucket config (min, step, max_warmup) "
@@ -45,16 +44,16 @@ def get_prompt_buckets(self, max_num_prefill_seqs, block_size,
 
         prompt_buckets, prompt_omitted_buckets = \
             generate_prompt_buckets(
-            prompt_bs_bucket_cfg,
-            prompt_seq_bucket_cfg,
-            block_size,
-            prefix_caching,
-            max_num_batched_tokens)
+                prompt_bs_bucket_cfg,
+                prompt_seq_bucket_cfg,
+                block_size,
+                prefix_caching,
+                max_num_batched_tokens)
 
         return sorted(prompt_buckets)
 
-    def get_decode_buckets(self, max_num_seqs, block_size, 
-                           max_num_batched_tokens, max_model_len, 
+    def get_decode_buckets(self, max_num_seqs, block_size,
+                           max_num_batched_tokens, max_model_len,
                            max_blocks):
         prefix_caching = get_config().prefix_caching
 
@@ -115,8 +114,8 @@ def warmup_range(config: Tuple[int, int, int]):
               "set VLLM_SKIP_WARMUP=true")
     base = itertools.repeat(2)
     ramp_up_acc = itertools.accumulate(base, func=operator.mul, initial=bmin)
-    ramp_up_tw = itertools.takewhile(lambda x: x < bstep and x <= bmax, \
-            ramp_up_acc)
+    ramp_up_tw = itertools.takewhile(lambda x: x < bstep and x <= bmax,
+                                     ramp_up_acc)
     stable = range(bstep, bmax + 1, bstep)
     buckets = list(ramp_up_tw) + list(stable)
     return list(filter(lambda bucket: bucket >= bmin, buckets))
@@ -141,8 +140,8 @@ def generate_prompt_buckets(bs_bucket_config,
         buckets = buckets_3d
     else:
         buckets = list(
-            itertools.product(batch_size_buckets,
-                seq_bucket_config, [0]))
+            itertools.product(batch_size_buckets,
+                              seq_bucket_config, [0]))
 
     if len(buckets) == 0:
         msg = ("No buckets could be captured with following config "
@@ -156,13 +155,13 @@ def generate_prompt_buckets(bs_bucket_config,
     # Remove buckets exceeding batch token budget
     filtered_buckets = list(
        filter(
-            lambda bucket: bucket[0] * (bucket[1] + bucket[2] * block_size) <= max_num_batched_tokens, 
+            lambda bucket: bucket[0] * (bucket[1] + bucket[2] * block_size) <= max_num_batched_tokens,
             buckets))
 
     if len(filtered_buckets) == 0:
         # we can handle this if we ignore max_num_batched_tokens
         min_bucket_bs, min_bucket_seq, min_bucket_ctx = min(buckets,
-                                                            key=lambda b: (b[0] * b[1])) 
+                                                            key=lambda b: (b[0] * b[1]))
         min_reqd_budget = min_bucket_bs * (min_bucket_seq + min_bucket_ctx * block_size)
         msg = (
             "The current bucketing configuration "
@@ -192,11 +191,9 @@ def generate_decode_buckets(bs_bucket_config, blocks_bucket_config,
     bs_buckets = warmup_range(bs_bucket_config)
     use_contiguous_pa = get_config().use_contiguous_pa
     block_buckets = warmup_range(blocks_bucket_config)
-    if max_blocks not in block_buckets and use_contiguous_pa:
-        block_buckets.append(max_blocks)
     last_bucket = max_blocks
     for bs in bs_buckets:
-        max_blocks_including_max_model_len = bs * math.ceil(max_model_len / block_size) 
+        max_blocks_including_max_model_len = bs * math.ceil(max_model_len / block_size)
         for blocks in block_buckets:
             if bs > blocks:
                 # Skip a dummy case when bs > blocks, which cannot occur in real execution
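
Note for reviewers (not part of the patch): to sanity-check the bucket math touched above outside the repo, here is a minimal standalone sketch of the warmup_range logic, under a hypothetical name warmup_range_sketch. It assumes only the (bmin, bstep, bmax) tuple semantics visible in the hunk at old line 115: ramp up from bmin by powers of two while below bstep, then step linearly from bstep to bmax.

import itertools
import operator

def warmup_range_sketch(config):
    # Sketch of the warmup_range body shown in the diff:
    # exponential ramp-up (bmin, 2*bmin, 4*bmin, ...) capped at bstep/bmax,
    # followed by a linear range bstep, 2*bstep, ..., bmax.
    bmin, bstep, bmax = config
    base = itertools.repeat(2)
    ramp_up_acc = itertools.accumulate(base, func=operator.mul, initial=bmin)
    ramp_up_tw = itertools.takewhile(lambda x: x < bstep and x <= bmax,
                                     ramp_up_acc)
    stable = range(bstep, bmax + 1, bstep)
    buckets = list(ramp_up_tw) + list(stable)
    return list(filter(lambda bucket: bucket >= bmin, buckets))

print(warmup_range_sketch((1, 32, 128)))  # [1, 2, 4, 8, 16, 32, 64, 96, 128]

The token-budget filter in generate_prompt_buckets follows the same arithmetic as the lambda in the hunk at old line 156: an illustrative (bs, seq, ctx) bucket of (2, 1024, 4) with block_size=128 costs 2 * (1024 + 4 * 128) = 3072 tokens and is kept only if that does not exceed max_num_batched_tokens.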