diff --git a/vllm_gaudi/extension/bucketing/common.py b/vllm_gaudi/extension/bucketing/common.py
index 89c03baf..a7a1d566 100644
--- a/vllm_gaudi/extension/bucketing/common.py
+++ b/vllm_gaudi/extension/bucketing/common.py
@@ -55,7 +55,8 @@ def generate_prompt_buckets(self):
                 max_num_prefill_seqs = self.max_num_prefill_seqs,
                 block_size = self.block_size,
                 max_num_batched_tokens = self.max_num_batched_tokens,
-                max_model_len = self.max_model_len)
+                max_model_len = self.max_model_len,
+                max_num_blocks = self.num_hpu_blocks)
             self.log_generate_info(True)
         else:
             logger().info("Bucketing is off - skipping prompt buckets generation")
diff --git a/vllm_gaudi/extension/bucketing/exponential.py b/vllm_gaudi/extension/bucketing/exponential.py
index 34dcaab4..29826707 100644
--- a/vllm_gaudi/extension/bucketing/exponential.py
+++ b/vllm_gaudi/extension/bucketing/exponential.py
@@ -28,16 +28,21 @@ def check_for_user_flags(self, phase):
 
     def get_prompt_buckets(self, max_num_prefill_seqs, block_size,
-                           max_num_batched_tokens, max_model_len):
+                           max_num_batched_tokens, max_model_len, max_num_blocks):
         self.check_for_user_flags('prompt')
-        use_merged_prefill = get_config().merged_prefill
+        use_merged_prefill = get_config().merged_prefill
         prefix_caching = get_config().prefix_caching
 
-        max_prompt_seq = max_model_len
+        # NOTE(kzawora): v1 requires chunked prefill,
+        # and we assume it is not going to be supported in v0 hpu code
+        enable_chunked_prefill = get_config().engine_version == 'v1'
+        # NOTE(kzawora): Chunked prefill scenarios will never exceed upper boundary of max_num_batched_tokens, regardless of max_model_len
+        max_prompt_seq = max_model_len if not enable_chunked_prefill else max_num_batched_tokens
 
         # cfgs shape: [min, step, max, limit]
         prompt_bs_limit = math.ceil(math.log2(max_num_prefill_seqs)) + 1
         prompt_bs_bucket_cfg = [1, 2, max_num_prefill_seqs, prompt_bs_limit]
         max_prompt_seq_limit = math.ceil(math.log2(max_prompt_seq)) + 1
+
         prompt_seq_bucket_cfg = [block_size, block_size, max_prompt_seq, max_prompt_seq_limit]
 
         if use_merged_prefill:
@@ -54,8 +59,10 @@ def get_prompt_buckets(self, max_num_prefill_seqs, block_size,
                                                  prompt_seq_bucket_cfg,
                                                  block_size,
                                                  prefix_caching,
+                                                 enable_chunked_prefill,
                                                  max_num_batched_tokens,
-                                                 max_model_len)
+                                                 max_model_len,
+                                                 max_num_blocks)
 
         return sorted(prompt_buckets)
 
@@ -89,8 +96,10 @@ def generate_prompt_buckets(bs_bucket_config,
                             seq_bucket_config,
                             block_size,
                             prefix_caching,
+                            enable_chunked_prefill,
                             max_num_batched_tokens=None,
-                            max_model_len=None):
+                            max_model_len=None,
+                            max_num_blocks=None):
     _, _, bmax, _ = seq_bucket_config
     batch_size_buckets = warmup_range_with_limit(bs_bucket_config)
     long_context = False
@@ -103,15 +112,15 @@ def generate_prompt_buckets(bs_bucket_config,
     for bs in batch_size_buckets:
         for b in seq_bucket_config:
             buckets_3d.append((bs, b, 0))
-            max_blocks_range = (bmax - b) // block_size
+            max_blocks_range = (bmax - b) // block_size if not max_num_blocks else max_num_blocks
             if max_blocks_range == 0:
                 continue
             else:
                 num_buckets_3d = math.ceil(math.log2(max_blocks_range)) + 1
 
-            for i in range(num_buckets_3d):
+            for i in range(1, num_buckets_3d + 1):
                 power_unpadded = 1 * np.float_power(
-                    max_blocks_range, (1 / float(num_buckets_3d - 1)) * i)
+                    max_blocks_range, (1 / float(num_buckets_3d)) * i)
                 new_bucket = math.ceil(power_unpadded)
                 buckets_3d.append((bs, b, new_bucket))
 
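To make the spacing change in the last hunk concrete, here is a standalone sketch (not part of the patch; `max_blocks_range = 96` is an illustrative stand-in for the value derived from `bmax` and `block_size`, or, after this change, for `max_num_blocks`):

```python
import math
import numpy as np

max_blocks_range = 96  # illustrative stand-in for max_num_blocks (num_hpu_blocks)
n = math.ceil(math.log2(max_blocks_range)) + 1  # 8 buckets for the block dimension

# Old schedule: exponents i/(n-1) for i in 0..n-1. Always emits a degenerate
# 1-block bucket at i == 0, and raises ZeroDivisionError when n == 1
# (i.e. when max_blocks_range == 1).
old = [math.ceil(np.float_power(max_blocks_range, i / (n - 1))) for i in range(n)]

# New schedule: exponents i/n for i in 1..n. Drops the 1-block bucket,
# handles n == 1, and still lands exactly on max_blocks_range at i == n.
new = [math.ceil(np.float_power(max_blocks_range, i / n)) for i in range(1, n + 1)]

print(old)  # [1, 2, 4, 8, 14, 27, 51, 96]
print(new)  # [2, 4, 6, 10, 18, 31, 55, 96]
```

Both schedules stay geometric; the new one shifts density toward usable shapes instead of spending a slot on a 1-block bucket.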
@@ -131,10 +140,36 @@ def generate_prompt_buckets(bs_bucket_config,
     filtered_buckets = buckets
     if max_num_batched_tokens is not None and max_model_len is not None:
         # Remove buckets exceeding batch token budget
-        filtered_buckets = list(
-            filter(
-                lambda bucket: bucket[0] * (bucket[1] + bucket[2] * block_size) <= max_num_batched_tokens \
-                and bucket[1] <= max_model_len, buckets))
+        if not enable_chunked_prefill:
+            filtered_buckets = list(
+                filter(
+                    lambda bucket: bucket[0] * (bucket[1] + bucket[2] * block_size) <= max_num_batched_tokens \
+                    and bucket[1] <= max_model_len, buckets))
+        else:
+            def filter_fn(bucket):
+                # NOTE(kzawora): Chunked prefill scenarios will never exceed upper boundary of max_num_batched_tokens, regardless of max_model_len
+                _, seq, block = bucket
+                is_seq_in_bounds = seq <= max_num_batched_tokens
+                is_block_in_bounds = block <= max_num_blocks
+                # New logic: allow all buckets up to and including the first that exceeds max_model_len, then filter the rest
+                return is_seq_in_bounds and is_block_in_bounds
+            # Find the first bucket that exceeds max_model_len
+            # For each (bs, seq), keep all buckets that do not exceed model len, and the first that does
+            from collections import defaultdict
+            first_exceed_seen = defaultdict(bool)
+            def keep_bucket(idx_bucket):
+                _, bucket = idx_bucket
+                bs, seq, block = bucket
+                exceeds = (seq + block * block_size) > max_model_len
+                key = (bs, seq)
+                if not exceeds:
+                    return filter_fn(bucket)
+                elif not first_exceed_seen[key] and filter_fn(bucket):
+                    first_exceed_seen[key] = True
+                    return True
+                else:
+                    return False
+            filtered_buckets = list(map(lambda x: x[1], filter(keep_bucket, enumerate(buckets))))
 
     if len(filtered_buckets) == 0:
         # we can handle this if we ignore max_num_batched_tokens
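The keep/drop rule above is easier to see on toy data. Below is a behaviorally equivalent condensation of `filter_fn` plus `keep_bucket` (a sketch, not the patch's code; `block_size`, the budgets, and the bucket shapes are made-up illustrative values):

```python
from collections import defaultdict

block_size = 128
max_num_batched_tokens = 2048
max_model_len = 1024
max_num_blocks = 16

# toy buckets: batch size 1, padded seq 256, growing context-block dimension
buckets = [(1, 256, blk) for blk in (0, 2, 4, 8, 16, 32)]

first_exceed_seen = defaultdict(bool)

def keep_bucket(idx_bucket):
    _, (bs, seq, block) = idx_bucket
    # hard bounds (filter_fn): seq is capped by the token budget,
    # blocks by the HPU block pool
    if seq > max_num_batched_tokens or block > max_num_blocks:
        return False
    # soft bound: keep every bucket within max_model_len, plus the first
    # one per (bs, seq) that crosses it, so a prompt whose total context
    # ends just past a bucket boundary still has a shape to land on
    if (seq + block * block_size) <= max_model_len:
        return True
    if not first_exceed_seen[(bs, seq)]:
        first_exceed_seen[(bs, seq)] = True
        return True
    return False

kept = [b for _, b in filter(keep_bucket, enumerate(buckets))]
print(kept)  # [(1, 256, 0), (1, 256, 2), (1, 256, 4), (1, 256, 8)]
```

Blocks 0/2/4 fit inside `max_model_len`, block 8 is the first shape to cross it and is kept as the escape hatch, block 16 crosses it again and is dropped, and block 32 fails the hard `max_num_blocks` bound.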
diff --git a/vllm_gaudi/extension/bucketing/linear.py b/vllm_gaudi/extension/bucketing/linear.py
index c70e04f3..ea9134c7 100644
--- a/vllm_gaudi/extension/bucketing/linear.py
+++ b/vllm_gaudi/extension/bucketing/linear.py
@@ -11,9 +11,10 @@ class LinearBucketingStrategy:
 
     def get_prompt_buckets(self, max_num_prefill_seqs, block_size,
-                           max_num_batched_tokens, max_model_len):
+                           max_num_batched_tokens, max_model_len, max_num_blocks):
         use_merged_prefill = get_config().merged_prefill
         prefix_caching = get_config().prefix_caching
+        chunked_prefill = get_config().engine_version == 'v1'
 
         max_prompt_seq = max_model_len
 
@@ -50,7 +51,10 @@ def get_prompt_buckets(self, max_num_prefill_seqs, block_size,
                                                  prompt_seq_bucket_cfg,
                                                  block_size,
                                                  prefix_caching,
-                                                 max_num_batched_tokens)
+                                                 chunked_prefill,
+                                                 max_num_batched_tokens,
+                                                 max_model_len,
+                                                 max_num_blocks)
 
         return sorted(prompt_buckets)
 
@@ -129,7 +133,10 @@ def generate_prompt_buckets(bs_bucket_config,
                             seq_bucket_config,
                             block_size,
                             prefix_caching,
-                            max_num_batched_tokens=None):
+                            enable_chunked_prefill,
+                            max_num_batched_tokens=None,
+                            max_model_len=None,
+                            max_num_blocks=None):
     _, _, bmax = seq_bucket_config
     batch_size_buckets = warmup_range(bs_bucket_config)
     seq_bucket_config = warmup_range(seq_bucket_config)
@@ -157,10 +164,37 @@ def generate_prompt_buckets(bs_bucket_config,
     filtered_buckets = buckets
     if max_num_batched_tokens is not None:
         # Remove buckets exceeding batch token budget
-        filtered_buckets = list(
-            filter(
-                lambda bucket: bucket[0] * (bucket[1] + bucket[2] * block_size) <= max_num_batched_tokens,
-                buckets))
+        if not enable_chunked_prefill:
+            filtered_buckets = list(
+                filter(
+                    lambda bucket: bucket[0] * (bucket[1] + bucket[2] * block_size) <= max_num_batched_tokens,
+                    buckets))
+        else:
+            def filter_fn(bucket):
+                # NOTE(kzawora): Chunked prefill scenarios will never exceed upper boundary of max_num_batched_tokens, regardless of max_model_len
+                _, seq, block = bucket
+                is_seq_in_bounds = seq <= max_num_batched_tokens
+                is_block_in_bounds = block <= max_num_blocks
+                # New logic: allow all buckets up to and including the first that exceeds max_model_len, then filter the rest
+                return is_seq_in_bounds and is_block_in_bounds
+            # Find the first bucket that exceeds max_model_len
+            # For each (bs, seq), keep all buckets that do not exceed model len, and the first that does
+            from collections import defaultdict
+            first_exceed_seen = defaultdict(bool)
+            def keep_bucket(idx_bucket):
+                _, bucket = idx_bucket
+                bs, seq, block = bucket
+                exceeds = (seq + block * block_size) > max_model_len
+                key = (bs, seq)
+                if not exceeds:
+                    return filter_fn(bucket)
+                elif not first_exceed_seen[key] and filter_fn(bucket):
+                    first_exceed_seen[key] = True
+                    return True
+                else:
+                    return False
+            filtered_buckets = list(map(lambda x: x[1], filter(keep_bucket, enumerate(buckets))))
+
     if len(filtered_buckets) == 0:
         # we can handle this if we ignore max_num_batched_tokens
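One asymmetry worth noting: `linear.py` keeps `max_prompt_seq = max_model_len` and only threads the chunked-prefill flag into filtering, whereas `exponential.py` also caps the seq bucket range at `max_num_batched_tokens`. The sketch below (illustrative values and stubbed wiring, not the plugin's real call path) shows why that cap is sound under chunked prefill: one scheduled chunk never exceeds the batch token budget, and context already in the KV cache is carried by the block dimension instead.

```python
block_size = 128
max_num_batched_tokens = 2048
max_model_len = 32768

# without chunked prefill (v0), the padded seq dimension of prompt buckets
# must be able to reach the full context length
max_prompt_seq_v0 = max_model_len           # 32768

# with chunked prefill (v1), one prefill step processes at most the token
# budget, so the seq dimension stops there
max_prompt_seq_v1 = max_num_batched_tokens  # 2048

# a long prompt is then covered by (bs, seq, blocks) shapes; e.g. the last
# chunk of a 10_000-token prompt processed in 2048-token chunks:
processed = 8192                            # tokens already in the KV cache
chunk = 10_000 - processed                  # 1808 new tokens in this step
bucket = (1, 2048, processed // block_size) # chunk pads to seq=2048, 64 blocks
print(max_prompt_seq_v0, max_prompt_seq_v1, bucket)  # 32768 2048 (1, 2048, 64)
```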