3 changes: 2 additions & 1 deletion vllm_gaudi/extension/bucketing/common.py
@@ -55,7 +55,8 @@ def generate_prompt_buckets(self):
max_num_prefill_seqs = self.max_num_prefill_seqs,
block_size = self.block_size,
max_num_batched_tokens = self.max_num_batched_tokens,
max_model_len = self.max_model_len)
max_model_len = self.max_model_len,
max_num_blocks = self.num_hpu_blocks)
self.log_generate_info(True)
else:
logger().info("Bucketing is off - skipping prompt buckets generation")
59 changes: 47 additions & 12 deletions vllm_gaudi/extension/bucketing/exponential.py
@@ -28,16 +28,21 @@ def check_for_user_flags(self, phase):


def get_prompt_buckets(self, max_num_prefill_seqs, block_size,
max_num_batched_tokens, max_model_len):
max_num_batched_tokens, max_model_len, max_num_blocks):
self.check_for_user_flags('prompt')
        use_merged_prefill = get_config().merged_prefill
prefix_caching = get_config().prefix_caching
max_prompt_seq = max_model_len
# NOTE(kzawora): v1 requires chunked prefill,
# and we assume it is not going to be supported in v0 hpu code
enable_chunked_prefill = get_config().engine_version == 'v1'
# NOTE(kzawora): Chunked prefill scenarios will never exceed upper boundary of max_num_batched_tokens, regardless of max_model_len
max_prompt_seq = max_model_len if not enable_chunked_prefill else max_num_batched_tokens

# cfgs shape: [min, step, max, limit]
prompt_bs_limit = math.ceil(math.log2(max_num_prefill_seqs)) + 1
prompt_bs_bucket_cfg = [1, 2, max_num_prefill_seqs, prompt_bs_limit]
max_prompt_seq_limit = math.ceil(math.log2(max_prompt_seq)) + 1

prompt_seq_bucket_cfg = [block_size, block_size, max_prompt_seq, max_prompt_seq_limit]

if use_merged_prefill:
@@ -54,8 +59,10 @@ def get_prompt_buckets(self, max_num_prefill_seqs, block_size,
prompt_seq_bucket_cfg,
block_size,
prefix_caching,
enable_chunked_prefill,
max_num_batched_tokens,
max_model_len)
max_model_len,
max_num_blocks)

return sorted(prompt_buckets)
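As a quick sanity check of the configs built above, here is a minimal standalone sketch (with assumed values for block_size, max_num_prefill_seqs, max_model_len and max_num_batched_tokens) showing how the v1 chunked-prefill path caps max_prompt_seq at max_num_batched_tokens, and how the log2-based bucket-count limits come out:

```python
import math

# Assumed example values; only the formulas mirror the code above.
block_size = 128
max_num_prefill_seqs = 16
max_model_len = 4096
max_num_batched_tokens = 2048

for enable_chunked_prefill in (False, True):  # v0 vs v1
    max_prompt_seq = max_model_len if not enable_chunked_prefill else max_num_batched_tokens
    prompt_bs_limit = math.ceil(math.log2(max_num_prefill_seqs)) + 1
    prompt_bs_bucket_cfg = [1, 2, max_num_prefill_seqs, prompt_bs_limit]
    max_prompt_seq_limit = math.ceil(math.log2(max_prompt_seq)) + 1
    prompt_seq_bucket_cfg = [block_size, block_size, max_prompt_seq, max_prompt_seq_limit]
    print(enable_chunked_prefill, prompt_bs_bucket_cfg, prompt_seq_bucket_cfg)

# False [1, 2, 16, 5] [128, 128, 4096, 13]
# True  [1, 2, 16, 5] [128, 128, 2048, 12]
```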

@@ -89,8 +96,10 @@ def generate_prompt_buckets(bs_bucket_config,
seq_bucket_config,
block_size,
prefix_caching,
enable_chunked_prefill,
max_num_batched_tokens=None,
max_model_len=None):
max_model_len=None,
max_num_blocks=None):
_, _, bmax, _ = seq_bucket_config
batch_size_buckets = warmup_range_with_limit(bs_bucket_config)
long_context = False
@@ -103,15 +112,15 @@ def generate_prompt_buckets(bs_bucket_config,
for bs in batch_size_buckets:
for b in seq_bucket_config:
buckets_3d.append((bs, b, 0))
max_blocks_range = (bmax - b) // block_size
max_blocks_range = (bmax - b) // block_size if not max_num_blocks else max_num_blocks
if max_blocks_range == 0:
continue
else:
num_buckets_3d = math.ceil(math.log2(max_blocks_range)) + 1

for i in range(num_buckets_3d):
for i in range(1, num_buckets_3d + 1):
power_unpadded = 1 * np.float_power(
max_blocks_range, (1 / float(num_buckets_3d - 1)) * i)
max_blocks_range, (1 / float(num_buckets_3d)) * i)
new_bucket = math.ceil(power_unpadded)
buckets_3d.append((bs, b, new_bucket))
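To make the revised spacing concrete: with the loop now running from 1 to num_buckets_3d and dividing by num_buckets_3d, the smallest generated block bucket is max_blocks_range ** (1 / num_buckets_3d) rather than 1, while the explicit (bs, b, 0) bucket is still appended separately above. A minimal sketch with an assumed max_blocks_range of 128:

```python
import math
import numpy as np

max_blocks_range = 128                                        # assumed example value
num_buckets_3d = math.ceil(math.log2(max_blocks_range)) + 1   # 8

block_buckets = [
    math.ceil(np.float_power(max_blocks_range, i / num_buckets_3d))
    for i in range(1, num_buckets_3d + 1)
]
print(block_buckets)  # [2, 4, 7, 12, 21, 39, 70, 128]
```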

@@ -131,10 +140,36 @@ def generate_prompt_buckets(bs_bucket_config,
filtered_buckets = buckets
if max_num_batched_tokens is not None and max_model_len is not None:
# Remove buckets exceeding batch token budget
filtered_buckets = list(
filter(
lambda bucket: bucket[0] * (bucket[1] + bucket[2] * block_size) <= max_num_batched_tokens \
and bucket[1] <= max_model_len, buckets))
if not enable_chunked_prefill:
filtered_buckets = list(
filter(
lambda bucket: bucket[0] * (bucket[1] + bucket[2] * block_size) <= max_num_batched_tokens \
and bucket[1] <= max_model_len, buckets))
else:
def filter_fn(bucket):
# NOTE(kzawora): Chunked prefill scenarios will never exceed upper boundary of max_num_batched_tokens, regardless of max_model_len
_, seq, block = bucket
is_seq_in_bounds = seq <= max_num_batched_tokens
is_block_in_bounds = block <= max_num_blocks
# New logic: allow all buckets up to and including the first that exceeds max_model_len, then filter the rest
return is_seq_in_bounds and is_block_in_bounds
# Find the first bucket that exceeds max_model_len
# For each (bs, seq), keep all buckets that do not exceed model len, and the first that does
from collections import defaultdict
first_exceed_seen = defaultdict(bool)
def keep_bucket(idx_bucket):
_, bucket = idx_bucket
bs, seq, block = bucket
exceeds = (seq + block * block_size) > max_model_len
key = (bs, seq)
if not exceeds:
return filter_fn(bucket)
elif not first_exceed_seen[key] and filter_fn(bucket):
first_exceed_seen[key] = True
return True
else:
return False
filtered_buckets = list(map(lambda x: x[1], filter(keep_bucket, enumerate(buckets))))

if len(filtered_buckets) == 0:
# we can handle this if we ignore max_num_batched_tokens
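Since the keep-first-exceeding behaviour is a bit subtle, here is a self-contained toy run of the same keep_bucket logic (with assumed limits and a hand-picked bucket list, not the output of the strategies above): for every (bs, seq) pair it keeps all buckets whose padded length stays within max_model_len, plus the first exceeding bucket it encounters, and drops the rest.

```python
from collections import defaultdict

# Assumed example limits; the helpers mirror the keep_bucket logic above.
block_size = 128
max_model_len = 512
max_num_batched_tokens = 2048
max_num_blocks = 32

buckets = [
    (1, 128, 0), (1, 128, 2), (1, 128, 3), (1, 128, 4), (1, 128, 6),
    (1, 256, 0), (1, 256, 2), (1, 256, 4), (1, 256, 8),
]

def filter_fn(bucket):
    _, seq, block = bucket
    return seq <= max_num_batched_tokens and block <= max_num_blocks

first_exceed_seen = defaultdict(bool)

def keep_bucket(idx_bucket):
    _, bucket = idx_bucket
    bs, seq, block = bucket
    exceeds = (seq + block * block_size) > max_model_len
    if not exceeds:
        return filter_fn(bucket)
    if not first_exceed_seen[(bs, seq)] and filter_fn(bucket):
        first_exceed_seen[(bs, seq)] = True
        return True
    return False

kept = [b for _, b in filter(keep_bucket, enumerate(buckets))]
print(kept)
# (1, 128, 4) and (1, 256, 4) are the first buckets past max_model_len for their
# (bs, seq) pairs, so they are kept; (1, 128, 6) and (1, 256, 8) are dropped.
```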
48 changes: 41 additions & 7 deletions vllm_gaudi/extension/bucketing/linear.py
@@ -11,9 +11,10 @@

class LinearBucketingStrategy:
def get_prompt_buckets(self, max_num_prefill_seqs, block_size,
max_num_batched_tokens, max_model_len):
max_num_batched_tokens, max_model_len, max_num_blocks):
use_merged_prefill = get_config().merged_prefill
prefix_caching = get_config().prefix_caching
chunked_prefill = get_config().engine_version == 'v1'

max_prompt_seq = max_model_len

@@ -50,7 +51,10 @@ def get_prompt_buckets(self, max_num_prefill_seqs, block_size,
prompt_seq_bucket_cfg,
block_size,
prefix_caching,
max_num_batched_tokens)
chunked_prefill,
max_num_batched_tokens,
max_model_len,
max_num_blocks)

return sorted(prompt_buckets)

@@ -129,7 +133,10 @@ def generate_prompt_buckets(bs_bucket_config,
seq_bucket_config,
block_size,
prefix_caching,
max_num_batched_tokens=None):
enable_chunked_prefill,
max_num_batched_tokens=None,
max_model_len=None,
max_num_blocks=None):
_, _, bmax = seq_bucket_config
batch_size_buckets = warmup_range(bs_bucket_config)
seq_bucket_config = warmup_range(seq_bucket_config)
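warmup_range itself is not part of this diff; purely as a mental model of how a three-element [min, step, max] config can expand into linear buckets, here is an assumed ramp-up-then-step sketch (an illustration, not necessarily this repo's warmup_range):

```python
def assumed_warmup_range(cfg):
    """Hypothetical [min, step, max] expansion: double from bmin up to bstep,
    then grow by bstep up to bmax. Assumption for illustration only."""
    bmin, bstep, bmax = cfg
    values = []
    current = bmin
    while current < bstep and current < bmax:
        values.append(current)
        current *= 2
    current = bstep
    while current <= bmax:
        values.append(current)
        current += bstep
    return values

print(assumed_warmup_range([1, 32, 128]))  # [1, 2, 4, 8, 16, 32, 64, 96, 128]
```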
@@ -157,10 +164,37 @@ def generate_prompt_buckets(bs_bucket_config,
filtered_buckets = buckets
if max_num_batched_tokens is not None:
# Remove buckets exceeding batch token budget
filtered_buckets = list(
filter(
lambda bucket: bucket[0] * (bucket[1] + bucket[2] * block_size) <= max_num_batched_tokens,
buckets))
if not enable_chunked_prefill:
filtered_buckets = list(
filter(
lambda bucket: bucket[0] * (bucket[1] + bucket[2] * block_size) <= max_num_batched_tokens,
buckets))
else:
def filter_fn(bucket):
# NOTE(kzawora): Chunked prefill scenarios will never exceed upper boundary of max_num_batched_tokens, regardless of max_model_len
_, seq, block = bucket
is_seq_in_bounds = seq <= max_num_batched_tokens
is_block_in_bounds = block <= max_num_blocks
# New logic: allow all buckets up to and including the first that exceeds max_model_len, then filter the rest
return is_seq_in_bounds and is_block_in_bounds
# Find the first bucket that exceeds max_model_len
# For each (bs, seq), keep all buckets that do not exceed model len, and the first that does
from collections import defaultdict
first_exceed_seen = defaultdict(bool)
def keep_bucket(idx_bucket):
_, bucket = idx_bucket
bs, seq, block = bucket
exceeds = (seq + block * block_size) > max_model_len
key = (bs, seq)
if not exceeds:
return filter_fn(bucket)
elif not first_exceed_seen[key] and filter_fn(bucket):
first_exceed_seen[key] = True
return True
else:
return False
filtered_buckets = list(map(lambda x: x[1], filter(keep_bucket, enumerate(buckets))))


if len(filtered_buckets) == 0:
# we can handle this if we ignore max_num_batched_tokens