Skip to content
38 changes: 36 additions & 2 deletions vllm_hpu_extension/bucketing/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,12 @@ def generate_fallback_bucket(self, batch_size, seq_len, ctx):
return (new_batch_size, new_seq_len, new_ctx)

def find_prompt_bucket(self, batch_size, seq_len, ctx=0, use_fallback=True):
target_shape = (batch_size, seq_len, ctx)
if self.initialized:
found_bucket = find_equal_or_closest_greater_config(self.prompt_buckets, (batch_size, seq_len, ctx))
if get_config().prefix_caching:
found_bucket = find_bucket_with_prefix_caching(self.prompt_buckets, target_shape, self.block_size)
else:
found_bucket = find_equal_or_closest_greater_config(self.prompt_buckets, target_shape)
if found_bucket is None:
if use_fallback:
new_bucket = self.generate_fallback_bucket(batch_size, seq_len, ctx)
Expand All @@ -154,7 +158,7 @@ def find_prompt_bucket(self, batch_size, seq_len, ctx=0, use_fallback=True):
else:
return (None, None, None)
return found_bucket
return (batch_size, seq_len, ctx)
return target_shape

def find_decode_bucket(self, batch_size, num_blocks):
if self.initialized:
Expand Down Expand Up @@ -198,3 +202,33 @@ def find_equal_or_closest_greater_config(sorted_list, target_tuple):
return sorted_list[i]
return None

def find_bucket_with_prefix_caching(prompt_buckets, target_shape, block_size):
    """Find the smallest prompt bucket that fits *target_shape* with prefix caching.

    Each bucket axis is searched independently: batch size is rounded UP to
    the nearest bucketed value, the context is rounded DOWN, and the tokens
    that no longer fit in the (smaller) context are spilled into the
    sequence dimension, which is then rounded UP.

    Args:
        prompt_buckets: iterable of ``(batch_size, seq_len, ctx)`` bucket tuples.
        target_shape: the ``(batch_size, seq_len, ctx)`` request to be padded.
        block_size: number of tokens represented by one context block.

    Returns:
        A ``(batch_size, seq_len, ctx)`` tuple assembled from the per-axis
        bucket values, or *target_shape* unchanged when some axis has no
        value that can accommodate the request (caller handles fallback).
    """
    import bisect

    batch_size, seq_len, ctx = target_shape

    # Sorted, de-duplicated bucket values per axis.
    bs_buckets = sorted({b[0] for b in prompt_buckets})
    seq_buckets = sorted({b[1] for b in prompt_buckets})
    ctx_buckets = sorted({b[2] for b in prompt_buckets})

    def find_ge(a, x):
        """Return the leftmost item of sorted list *a* that is >= x."""
        i = bisect.bisect_left(a, x)
        if i != len(a):
            return a[i]
        raise ValueError

    def find_le(a, x):
        """Return the rightmost item of sorted list *a* that is <= x."""
        i = bisect.bisect_right(a, x)
        if i:
            return a[i - 1]
        raise ValueError

    try:
        found_bs = find_ge(bs_buckets, batch_size)
        # Round context down, then absorb the displaced cached tokens into
        # the sequence length before rounding it up.
        found_ctx = find_le(ctx_buckets, ctx)
        pad_seq_len = seq_len + (ctx - found_ctx) * block_size
        found_seq = find_ge(seq_buckets, pad_seq_len)
        return (found_bs, found_seq, found_ctx)
    except ValueError:
        # No bucket on some axis can hold the request; return the shape
        # unpadded so the caller can apply its fallback path.
        return target_shape
40 changes: 26 additions & 14 deletions vllm_hpu_extension/bucketing/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,22 +144,27 @@ def generate_prompt_buckets(bs_bucket_config,
block_size,
prefix_caching,
max_num_batched_tokens=None):
_, _, bmax, _ = seq_bucket_config
batch_size_buckets = warmup_range_with_limit(bs_bucket_config)
seq_bucket_config = warmup_range_with_limit(seq_bucket_config)
_, seq_step, seq_max, limit = seq_bucket_config
bs_buckets = warmup_range_with_limit(bs_bucket_config)
seq_buckets = warmup_range_with_limit(seq_bucket_config)
context_bucket_step = max(seq_step // block_size, 1)

if prefix_caching:
buckets_3d = []
for bs in batch_size_buckets:
for b in seq_bucket_config:
max_blocks_range = (bmax - b) // block_size
for i in range(0, max_blocks_range + 2):
buckets_3d.append((bs, b, i))
context_bucket_config = (context_bucket_step, context_bucket_step, seq_max * 2 // block_size + 2, limit)
context_buckets = [0] + warmup_range_with_limit(context_bucket_config)
for bs in bs_buckets:
for seq in seq_buckets:
for i in range(len(context_buckets)):
ctx = context_buckets[i]
if ctx * block_size + seq > seq_max:
break
buckets_3d.append((bs, seq, ctx))
buckets = buckets_3d
else:
buckets = list(
itertools.product(batch_size_buckets,
seq_bucket_config, [0]))
itertools.product(bs_buckets,
seq_buckets, [0]))

if len(buckets) == 0:
msg = ("No buckets could be captured with following config "
Expand All @@ -171,10 +176,17 @@ def generate_prompt_buckets(bs_bucket_config,
filtered_buckets = buckets
if max_num_batched_tokens is not None:
# Remove buckets exceeding batch token budget
filtered_buckets = list(
filter(
lambda bucket: bucket[0] * (bucket[1] + bucket[2] * block_size) <= max_num_batched_tokens,
buckets))
if prefix_caching:
max_tokens = max_num_batched_tokens + context_bucket_step * block_size
filtered_buckets = list(
filter(
lambda bucket: bucket[0] * (bucket[1] + bucket[2] * block_size) <= max_tokens,
buckets))
else:
filtered_buckets = list(
filter(
lambda bucket: bucket[0] * (bucket[1] + bucket[2] * block_size) <= max_num_batched_tokens,
buckets))

if len(filtered_buckets) == 0:
# we can handle this if we ignore max_num_batched_tokens
Expand Down