diff --git a/vllm_hpu_extension/bucketing/common.py b/vllm_hpu_extension/bucketing/common.py
index cde63fef3..6e4751cb1 100644
--- a/vllm_hpu_extension/bucketing/common.py
+++ b/vllm_hpu_extension/bucketing/common.py
@@ -141,8 +141,12 @@ def generate_fallback_bucket(self, batch_size, seq_len, ctx):
         return (new_batch_size, new_seq_len, new_ctx)
 
     def find_prompt_bucket(self, batch_size, seq_len, ctx=0, use_fallback=True):
+        target_shape = (batch_size, seq_len, ctx)
         if self.initialized:
-            found_bucket = find_equal_or_closest_greater_config(self.prompt_buckets, (batch_size, seq_len, ctx))
+            if get_config().prefix_caching:
+                found_bucket = find_bucket_with_prefix_caching(self.prompt_buckets, target_shape, self.block_size)
+            else:
+                found_bucket = find_equal_or_closest_greater_config(self.prompt_buckets, target_shape)
             if found_bucket is None:
                 if use_fallback:
                     new_bucket = self.generate_fallback_bucket(batch_size, seq_len, ctx)
@@ -154,7 +158,7 @@ def find_prompt_bucket(self, batch_size, seq_len, ctx=0, use_fallback=True):
                 else:
                     return (None, None, None)
             return found_bucket
-        return (batch_size, seq_len, ctx)
+        return target_shape
 
     def find_decode_bucket(self, batch_size, num_blocks):
         if self.initialized:
@@ -198,3 +202,33 @@ def find_equal_or_closest_greater_config(sorted_list, target_tuple):
             return sorted_list[i]
     return None
 
+def find_bucket_with_prefix_caching(prompt_buckets, target_shape, block_size):
+    batch_size, seq_len, ctx = target_shape
+    import bisect
+    def find_ge(a, x):
+        'Find leftmost item greater than or equal to x'
+        i = bisect.bisect_left(a, x)
+        if i != len(a):
+            return a[i]
+        raise ValueError
+
+    def find_le(a, x):
+        'Find rightmost value less than or equal to x'
+        i = bisect.bisect_right(a, x)
+        if i:
+            return a[i-1]
+        raise ValueError
+
+    bs_buckets = list(sorted(set([b[0] for b in prompt_buckets])))
+    seq_buckets = list(sorted(set([b[1] for b in prompt_buckets])))
+    ctx_buckets = list(sorted(set([b[2] for b in prompt_buckets])))
+
+    try:
+        found_bs = find_ge(bs_buckets, batch_size)
+        found_ctx = find_le(ctx_buckets, ctx)
+        pad_seq_len = seq_len + (ctx - found_ctx) * block_size
+        found_seq = find_ge(seq_buckets, pad_seq_len)
+        found_bucket = (found_bs, found_seq, found_ctx)
+        return found_bucket
+    except ValueError:
+        return target_shape
\ No newline at end of file
diff --git a/vllm_hpu_extension/bucketing/linear.py b/vllm_hpu_extension/bucketing/linear.py
index b0f58cda1..34d4a85df 100644
--- a/vllm_hpu_extension/bucketing/linear.py
+++ b/vllm_hpu_extension/bucketing/linear.py
@@ -144,22 +144,27 @@ def generate_prompt_buckets(bs_bucket_config,
                             seq_bucket_config,
                             block_size,
                             prefix_caching,
                             max_num_batched_tokens=None):
-    _, _, bmax, _ = seq_bucket_config
-    batch_size_buckets = warmup_range_with_limit(bs_bucket_config)
-    seq_bucket_config = warmup_range_with_limit(seq_bucket_config)
+    _, seq_step, seq_max, limit = seq_bucket_config
+    bs_buckets = warmup_range_with_limit(bs_bucket_config)
+    seq_buckets = warmup_range_with_limit(seq_bucket_config)
+    context_bucket_step = max(seq_step // block_size, 1)
     if prefix_caching:
         buckets_3d = []
-        for bs in batch_size_buckets:
-            for b in seq_bucket_config:
-                max_blocks_range = (bmax - b) // block_size
-                for i in range(0, max_blocks_range + 2):
-                    buckets_3d.append((bs, b, i))
+        context_bucket_config = (context_bucket_step, context_bucket_step, seq_max * 2 // block_size + 2, limit)
+        context_buckets = [0] + warmup_range_with_limit(context_bucket_config)
+        for bs in bs_buckets:
+            for seq in seq_buckets:
+                for i in range(len(context_buckets)):
+                    ctx = context_buckets[i]
+                    if ctx * block_size + seq > seq_max:
+                        break
+                    buckets_3d.append((bs, seq, ctx))
         buckets = buckets_3d
     else:
         buckets = list(
-            itertools.product(batch_size_buckets,
-                              seq_bucket_config, [0]))
+            itertools.product(bs_buckets,
+                              seq_buckets, [0]))
 
     if len(buckets) == 0:
         msg = ("No buckets could be captured with following config "
@@ -171,10 +176,17 @@ def generate_prompt_buckets(bs_bucket_config,
     filtered_buckets = buckets
     if max_num_batched_tokens is not None:
         # Remove buckets exceeding batch token budget
-        filtered_buckets = list(
-            filter(
-                lambda bucket: bucket[0] * (bucket[1] + bucket[2] * block_size) <= max_num_batched_tokens,
-                buckets))
+        if prefix_caching:
+            max_tokens = max_num_batched_tokens + context_bucket_step * block_size
+            filtered_buckets = list(
+                filter(
+                    lambda bucket: bucket[0] * (bucket[1] + bucket[2] * block_size) <= max_tokens,
+                    buckets))
+        else:
+            filtered_buckets = list(
+                filter(
+                    lambda bucket: bucket[0] * (bucket[1] + bucket[2] * block_size) <= max_num_batched_tokens,
+                    buckets))
 
     if len(filtered_buckets) == 0:
         # we can handle this if we ignore max_num_batched_tokens
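
Illustrative usage (not part of the patch): the sketch below re-states the lookup introduced by find_bucket_with_prefix_caching in a standalone, behaviorally equivalent form and runs it on made-up values. The batch size is rounded up to the nearest bucket, the cached context (in blocks) is rounded down, and the context blocks trimmed by that rounding are folded back into the sequence length before it is rounded up to the next sequence bucket. The bucket list, block size, and request shape are invented for the example.

    import bisect

    def find_bucket_with_prefix_caching(prompt_buckets, target_shape, block_size):
        # Mirrors the helper added in common.py: batch size rounds up,
        # context (in blocks) rounds down, and the trimmed context is
        # folded back into the sequence length before it rounds up.
        batch_size, seq_len, ctx = target_shape
        bs_buckets = sorted({b[0] for b in prompt_buckets})
        seq_buckets = sorted({b[1] for b in prompt_buckets})
        ctx_buckets = sorted({b[2] for b in prompt_buckets})
        try:
            found_bs = bs_buckets[bisect.bisect_left(bs_buckets, batch_size)]
            i = bisect.bisect_right(ctx_buckets, ctx)
            if i == 0:
                raise IndexError
            found_ctx = ctx_buckets[i - 1]
            pad_seq_len = seq_len + (ctx - found_ctx) * block_size
            found_seq = seq_buckets[bisect.bisect_left(seq_buckets, pad_seq_len)]
            return (found_bs, found_seq, found_ctx)
        except IndexError:
            # No bucket is large enough; fall back to the unpadded shape,
            # as the patched helper does on ValueError.
            return target_shape

    # Made-up warmed-up buckets: (batch_size, seq_len, ctx_blocks).
    buckets = [(1, 128, 0), (1, 128, 2), (1, 256, 0), (1, 256, 2), (1, 512, 0)]
    # Request: batch 1, 100 new tokens, 3 cached context blocks of 128 tokens.
    # ctx rounds down to 2; the leftover block (128 tokens) is folded into the
    # sequence length (100 + 128 = 228), which rounds up to the 256 bucket.
    print(find_bucket_with_prefix_caching(buckets, (1, 100, 3), 128))  # (1, 256, 2)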