39 changes: 18 additions & 21 deletions vllm_gaudi/extension/bucketing/linear.py
@@ -9,9 +9,8 @@
 from vllm_gaudi.extension.runtime import get_config
 
-
 class LinearBucketingStrategy:
     def get_prompt_buckets(self, max_num_prefill_seqs, block_size,
                            max_num_batched_tokens, max_model_len):
         use_merged_prefill = get_config().merged_prefill
         prefix_caching = get_config().prefix_caching
@@ -33,9 +32,9 @@ def get_prompt_buckets(self, max_num_prefill_seqs, block_size,
             new_prompt_bs_bucket_cfg = prompt_bs_bucket_cfg
             new_prompt_seq_bucket_cfg = prompt_seq_bucket_cfg
             msg = ('Merged prefill is enabled!\n'
                    'Overriding prompt bucketing settings!\n'
                    f'prompt bs cfg: {prev_prompt_bs_bucket_cfg} -> {new_prompt_bs_bucket_cfg}\n'
                    f'prompt seq cfg: {prev_prompt_seq_bucket_cfg} -> {new_prompt_seq_bucket_cfg}\n')
             logger().info(msg)
 
         msg = ("Prompt bucket config (min, step, max_warmup) "
@@ -45,16 +44,16 @@ def get_prompt_buckets(self, max_num_prefill_seqs, block_size,
 
         prompt_buckets, prompt_omitted_buckets = \
             generate_prompt_buckets(
                 prompt_bs_bucket_cfg,
                 prompt_seq_bucket_cfg,
                 block_size,
                 prefix_caching,
                 max_num_batched_tokens)
 
         return sorted(prompt_buckets)
 
     def get_decode_buckets(self, max_num_seqs, block_size,
                            max_num_batched_tokens, max_model_len,
                            max_blocks):
         prefix_caching = get_config().prefix_caching
 
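For orientation, a hedged usage sketch of the two entry points above. Only the method signatures come from this diff; the bare constructor call and every argument value are illustrative assumptions:

    # Hypothetical invocation; values are illustrative, not library defaults.
    strategy = LinearBucketingStrategy()
    prompt_buckets = strategy.get_prompt_buckets(
        max_num_prefill_seqs=16, block_size=128,
        max_num_batched_tokens=8192, max_model_len=4096)
    decode_buckets = strategy.get_decode_buckets(
        max_num_seqs=128, block_size=128,
        max_num_batched_tokens=8192, max_model_len=4096,
        max_blocks=1024)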
@@ -115,8 +114,8 @@ def warmup_range(config: Tuple[int, int, int]):
                          "set VLLM_SKIP_WARMUP=true")
     base = itertools.repeat(2)
     ramp_up_acc = itertools.accumulate(base, func=operator.mul, initial=bmin)
-    ramp_up_tw = itertools.takewhile(lambda x: x < bstep and x <= bmax, \
-        ramp_up_acc)
+    ramp_up_tw = itertools.takewhile(lambda x: x < bstep and x <= bmax,
+                                     ramp_up_acc)
     stable = range(bstep, bmax + 1, bstep)
     buckets = list(ramp_up_tw) + list(stable)
     return list(filter(lambda bucket: bucket >= bmin, buckets))
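As context for the hunk above: warmup_range expands a (min, step, max) tuple into one bucket axis, doubling from min while still below step, then stepping linearly by step up to max. A self-contained sketch of that behavior, with a hypothetical config value:

    import itertools
    import operator

    def warmup_range_sketch(config):
        bmin, bstep, bmax = config
        # Exponential ramp-up: bmin, 2*bmin, 4*bmin, ... while still below bstep
        ramp_up_acc = itertools.accumulate(itertools.repeat(2), func=operator.mul, initial=bmin)
        ramp_up_tw = itertools.takewhile(lambda x: x < bstep and x <= bmax, ramp_up_acc)
        # Stable region: linear steps of bstep up to and including bmax
        stable = range(bstep, bmax + 1, bstep)
        return [b for b in list(ramp_up_tw) + list(stable) if b >= bmin]

    print(warmup_range_sketch((1, 32, 128)))  # [1, 2, 4, 8, 16, 32, 64, 96, 128]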
@@ -141,8 +140,8 @@ def generate_prompt_buckets(bs_bucket_config,
         buckets = buckets_3d
     else:
         buckets = list(
             itertools.product(batch_size_buckets,
                               seq_bucket_config, [0]))
 
     if len(buckets) == 0:
         msg = ("No buckets could be captured with following config "
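In the branch shown above (prefix caching off), prompt buckets are the plain cross product of the batch-size and sequence-length axes, with the context-blocks dimension pinned to 0. A small sketch; the axis values are made up:

    import itertools

    batch_size_buckets = [1, 2, 4]       # hypothetical warmup_range output
    seq_bucket_config = [128, 256, 512]  # hypothetical warmup_range output

    buckets = list(itertools.product(batch_size_buckets, seq_bucket_config, [0]))
    # 9 triples of (batch_size, query_len, num_context_blocks), e.g.
    # (1, 128, 0), (1, 256, 0), ..., (4, 512, 0)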
@@ -156,13 +155,13 @@ def generate_prompt_buckets(bs_bucket_config,
     # Remove buckets exceeding batch token budget
     filtered_buckets = list(
         filter(
             lambda bucket: bucket[0] * (bucket[1] + bucket[2] * block_size) <= max_num_batched_tokens,
             buckets))
 
     if len(filtered_buckets) == 0:
         # we can handle this if we ignore max_num_batched_tokens
         min_bucket_bs, min_bucket_seq, min_bucket_ctx = min(buckets,
                                                             key=lambda b: (b[0] * b[1]))
         min_reqd_budget = min_bucket_bs * (min_bucket_seq + min_bucket_ctx * block_size)
         msg = (
             "The current bucketing configuration "
@@ -192,11 +191,9 @@ def generate_decode_buckets(bs_bucket_config, blocks_bucket_config,
     bs_buckets = warmup_range(bs_bucket_config)
     use_contiguous_pa = get_config().use_contiguous_pa
     block_buckets = warmup_range(blocks_bucket_config)
-    if max_blocks not in block_buckets and use_contiguous_pa:
-        block_buckets.append(max_blocks)
     last_bucket = max_blocks
     for bs in bs_buckets:
         max_blocks_including_max_model_len = bs * math.ceil(max_model_len / block_size)
         for blocks in block_buckets:
             if bs > blocks:
                 # Skip a dummy case when bs > blocks, which cannot occur in real execution
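A hedged sketch of the loop above. Skipping bs > blocks is sound because each sequence occupies at least one KV-cache block, so a batch of bs sequences can never fit in fewer than bs blocks. Capping blocks at max_blocks_including_max_model_len is an assumption inferred from the variable name; only its computation is visible in this hunk:

    import math

    block_size, max_model_len = 128, 2048   # hypothetical values
    bs_buckets, block_buckets = [1, 4, 16], [8, 64, 256]

    decode_buckets = []
    for bs in bs_buckets:
        # Most blocks a batch of bs sequences can use at max_model_len
        max_blocks_including_max_model_len = bs * math.ceil(max_model_len / block_size)
        for blocks in block_buckets:
            if bs > blocks:
                continue  # dummy case: fewer blocks than sequences
            if blocks > max_blocks_including_max_model_len:
                continue  # assumed cap, see note above
            decode_buckets.append((bs, blocks))

    print(decode_buckets)  # [(1, 8), (4, 8), (4, 64), (16, 64), (16, 256)]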