Skip to content

Commit 1bd1702

Browse files
committed
Do not add max_blocks as a bucket in linear bucketing with contiguous PA (use_contiguous_pa)
1 parent e3dd6a6 commit 1bd1702

File tree

1 file changed

+18
-21
lines changed

1 file changed

+18
-21
lines changed

vllm_gaudi/extension/bucketing/linear.py

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,8 @@
99
from vllm_gaudi.extension.runtime import get_config
1010

1111

12-
1312
class LinearBucketingStrategy:
14-
def get_prompt_buckets(self, max_num_prefill_seqs, block_size,
13+
def get_prompt_buckets(self, max_num_prefill_seqs, block_size,
1514
max_num_batched_tokens, max_model_len):
1615
use_merged_prefill = get_config().merged_prefill
1716
prefix_caching = get_config().prefix_caching
@@ -33,9 +32,9 @@ def get_prompt_buckets(self, max_num_prefill_seqs, block_size,
3332
new_prompt_bs_bucket_cfg = prompt_bs_bucket_cfg
3433
new_prompt_seq_bucket_cfg = prompt_seq_bucket_cfg
3534
msg = ('Merged prefill is enabled!\n'
36-
'Overriding prompt bucketing settings!\n'
37-
f'prompt bs cfg: {prev_prompt_bs_bucket_cfg} -> {new_prompt_bs_bucket_cfg}\n'
38-
f'prompt seq cfg: {prev_prompt_seq_bucket_cfg} -> {new_prompt_seq_bucket_cfg}\n')
35+
'Overriding prompt bucketing settings!\n'
36+
f'prompt bs cfg: {prev_prompt_bs_bucket_cfg} -> {new_prompt_bs_bucket_cfg}\n'
37+
f'prompt seq cfg: {prev_prompt_seq_bucket_cfg} -> {new_prompt_seq_bucket_cfg}\n')
3938
logger().info(msg)
4039

4140
msg = ("Prompt bucket config (min, step, max_warmup) "
@@ -45,16 +44,16 @@ def get_prompt_buckets(self, max_num_prefill_seqs, block_size,
4544

4645
prompt_buckets, prompt_omitted_buckets = \
4746
generate_prompt_buckets(
48-
prompt_bs_bucket_cfg,
49-
prompt_seq_bucket_cfg,
50-
block_size,
51-
prefix_caching,
52-
max_num_batched_tokens)
47+
prompt_bs_bucket_cfg,
48+
prompt_seq_bucket_cfg,
49+
block_size,
50+
prefix_caching,
51+
max_num_batched_tokens)
5352

5453
return sorted(prompt_buckets)
5554

56-
def get_decode_buckets(self, max_num_seqs, block_size,
57-
max_num_batched_tokens, max_model_len,
55+
def get_decode_buckets(self, max_num_seqs, block_size,
56+
max_num_batched_tokens, max_model_len,
5857
max_blocks):
5958
prefix_caching = get_config().prefix_caching
6059

@@ -115,8 +114,8 @@ def warmup_range(config: Tuple[int, int, int]):
115114
"set VLLM_SKIP_WARMUP=true")
116115
base = itertools.repeat(2)
117116
ramp_up_acc = itertools.accumulate(base, func=operator.mul, initial=bmin)
118-
ramp_up_tw = itertools.takewhile(lambda x: x < bstep and x <= bmax, \
119-
ramp_up_acc)
117+
ramp_up_tw = itertools.takewhile(lambda x: x < bstep and x <= bmax,
118+
ramp_up_acc)
120119
stable = range(bstep, bmax + 1, bstep)
121120
buckets = list(ramp_up_tw) + list(stable)
122121
return list(filter(lambda bucket: bucket >= bmin, buckets))
@@ -141,8 +140,8 @@ def generate_prompt_buckets(bs_bucket_config,
141140
buckets = buckets_3d
142141
else:
143142
buckets = list(
144-
itertools.product(batch_size_buckets,
145-
seq_bucket_config, [0]))
143+
itertools.product(batch_size_buckets,
144+
seq_bucket_config, [0]))
146145

147146
if len(buckets) == 0:
148147
msg = ("No buckets could be captured with following config "
@@ -156,13 +155,13 @@ def generate_prompt_buckets(bs_bucket_config,
156155
# Remove buckets exceeding batch token budget
157156
filtered_buckets = list(
158157
filter(
159-
lambda bucket: bucket[0] * (bucket[1] + bucket[2] * block_size) <= max_num_batched_tokens,
158+
lambda bucket: bucket[0] * (bucket[1] + bucket[2] * block_size) <= max_num_batched_tokens,
160159
buckets))
161160

162161
if len(filtered_buckets) == 0:
163162
# we can handle this if we ignore max_num_batched_tokens
164163
min_bucket_bs, min_bucket_seq, min_bucket_ctx = min(buckets,
165-
key=lambda b: (b[0] * b[1]))
164+
key=lambda b: (b[0] * b[1]))
166165
min_reqd_budget = min_bucket_bs * (min_bucket_seq + min_bucket_ctx * block_size)
167166
msg = (
168167
"The current bucketing configuration "
@@ -192,11 +191,9 @@ def generate_decode_buckets(bs_bucket_config, blocks_bucket_config,
192191
bs_buckets = warmup_range(bs_bucket_config)
193192
use_contiguous_pa = get_config().use_contiguous_pa
194193
block_buckets = warmup_range(blocks_bucket_config)
195-
if max_blocks not in block_buckets and use_contiguous_pa:
196-
block_buckets.append(max_blocks)
197194
last_bucket = max_blocks
198195
for bs in bs_buckets:
199-
max_blocks_including_max_model_len = bs * math.ceil(max_model_len / block_size)
196+
max_blocks_including_max_model_len = bs * math.ceil(max_model_len / block_size)
200197
for blocks in block_buckets:
201198
if bs > blocks:
202199
# Skip a dummy case when bs > blocks, which cannot occur in real execution

0 commit comments

Comments (0)