@@ -9,9 +9,8 @@
 from vllm_gaudi.extension.runtime import get_config
 
 
-
 class LinearBucketingStrategy:
-    def get_prompt_buckets(self, max_num_prefill_seqs, block_size,
+    def get_prompt_buckets(self, max_num_prefill_seqs, block_size,
                            max_num_batched_tokens, max_model_len):
         use_merged_prefill = get_config().merged_prefill
         prefix_caching = get_config().prefix_caching
@@ -33,9 +32,9 @@ def get_prompt_buckets(self, max_num_prefill_seqs, block_size,
         new_prompt_bs_bucket_cfg = prompt_bs_bucket_cfg
         new_prompt_seq_bucket_cfg = prompt_seq_bucket_cfg
         msg = ('Merged prefill is enabled!\n'
-               'Overriding prompt bucketing settings!\n'
-               f'prompt bs cfg: {prev_prompt_bs_bucket_cfg} -> {new_prompt_bs_bucket_cfg}\n'
-               f'prompt seq cfg: {prev_prompt_seq_bucket_cfg} -> {new_prompt_seq_bucket_cfg}\n')
+               'Overriding prompt bucketing settings!\n'
+               f'prompt bs cfg: {prev_prompt_bs_bucket_cfg} -> {new_prompt_bs_bucket_cfg}\n'
+               f'prompt seq cfg: {prev_prompt_seq_bucket_cfg} -> {new_prompt_seq_bucket_cfg}\n')
         logger().info(msg)
 
         msg = ("Prompt bucket config (min, step, max_warmup) "
@@ -45,16 +44,16 @@ def get_prompt_buckets(self, max_num_prefill_seqs, block_size,
 
         prompt_buckets, prompt_omitted_buckets = \
             generate_prompt_buckets(
-                prompt_bs_bucket_cfg,
-                prompt_seq_bucket_cfg,
-                block_size,
-                prefix_caching,
-                max_num_batched_tokens)
+                prompt_bs_bucket_cfg,
+                prompt_seq_bucket_cfg,
+                block_size,
+                prefix_caching,
+                max_num_batched_tokens)
 
         return sorted(prompt_buckets)
 
-    def get_decode_buckets(self, max_num_seqs, block_size,
-                           max_num_batched_tokens, max_model_len,
+    def get_decode_buckets(self, max_num_seqs, block_size,
+                           max_num_batched_tokens, max_model_len,
                            max_blocks):
         prefix_caching = get_config().prefix_caching
 
60
59
@@ -115,8 +114,8 @@ def warmup_range(config: Tuple[int, int, int]):
                           "set VLLM_SKIP_WARMUP=true")
     base = itertools.repeat(2)
     ramp_up_acc = itertools.accumulate(base, func=operator.mul, initial=bmin)
-    ramp_up_tw = itertools.takewhile(lambda x: x < bstep and x <= bmax, \
-                                     ramp_up_acc)
+    ramp_up_tw = itertools.takewhile(lambda x: x < bstep and x <= bmax,
+                                     ramp_up_acc)
     stable = range(bstep, bmax + 1, bstep)
     buckets = list(ramp_up_tw) + list(stable)
     return list(filter(lambda bucket: bucket >= bmin, buckets))
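The warmup_range helper touched in this hunk builds a bucket dimension as an exponential ramp-up (bmin, 2*bmin, 4*bmin, ...) up to the step, followed by linear steps of bstep up to bmax. A minimal standalone sketch of that pattern follows; the function name and the example config values are illustrative, not taken from this file:

import itertools
import operator

def warmup_range_sketch(config):
    # config is (bmin, bstep, bmax): exponential ramp-up below bstep, then linear steps of bstep.
    bmin, bstep, bmax = config
    base = itertools.repeat(2)
    ramp_up_acc = itertools.accumulate(base, func=operator.mul, initial=bmin)
    ramp_up_tw = itertools.takewhile(lambda x: x < bstep and x <= bmax, ramp_up_acc)
    stable = range(bstep, bmax + 1, bstep)
    return [b for b in list(ramp_up_tw) + list(stable) if b >= bmin]

# e.g. warmup_range_sketch((1, 32, 128)) -> [1, 2, 4, 8, 16, 32, 64, 96, 128]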
@@ -141,8 +140,8 @@ def generate_prompt_buckets(bs_bucket_config,
         buckets = buckets_3d
     else:
         buckets = list(
-            itertools.product(batch_size_buckets,
-                              seq_bucket_config, [0]))
+            itertools.product(batch_size_buckets,
+                              seq_bucket_config, [0]))
 
     if len(buckets) == 0:
         msg = ("No buckets could be captured with following config "
@@ -156,13 +155,13 @@ def generate_prompt_buckets(bs_bucket_config,
     # Remove buckets exceeding batch token budget
     filtered_buckets = list(
         filter(
-            lambda bucket: bucket[0] * (bucket[1] + bucket[2] * block_size) <= max_num_batched_tokens,
+            lambda bucket: bucket[0] * (bucket[1] + bucket[2] * block_size) <= max_num_batched_tokens,
             buckets))
 
     if len(filtered_buckets) == 0:
         # we can handle this if we ignore max_num_batched_tokens
         min_bucket_bs, min_bucket_seq, min_bucket_ctx = min(buckets,
-                                                            key=lambda b: (b[0] * b[1]))
+                                                            key=lambda b: (b[0] * b[1]))
         min_reqd_budget = min_bucket_bs * (min_bucket_seq + min_bucket_ctx * block_size)
         msg = (
             "The current bucketing configuration "
@@ -192,11 +191,9 @@ def generate_decode_buckets(bs_bucket_config, blocks_bucket_config,
     bs_buckets = warmup_range(bs_bucket_config)
     use_contiguous_pa = get_config().use_contiguous_pa
     block_buckets = warmup_range(blocks_bucket_config)
-    if max_blocks not in block_buckets and use_contiguous_pa:
-        block_buckets.append(max_blocks)
     last_bucket = max_blocks
     for bs in bs_buckets:
-        max_blocks_including_max_model_len = bs * math.ceil(max_model_len / block_size)
+        max_blocks_including_max_model_len = bs * math.ceil(max_model_len / block_size)
         for blocks in block_buckets:
             if bs > blocks:
                 # Skip a dummy case when bs > blocks, which cannot occur in real execution
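For reference, the decode-bucket loop in this final hunk computes a per-batch-size upper bound on usable blocks from max_model_len and skips impossible bs > blocks combinations. A rough standalone sketch of those two checks; the helper name is hypothetical and the way the bound is applied is an assumption, since the hunk is cut off before that point:

import math

def prune_decode_buckets_sketch(bs_buckets, block_buckets, block_size, max_model_len):
    kept = []
    for bs in bs_buckets:
        # Upper bound on blocks a batch of `bs` sequences can occupy at max_model_len.
        max_blocks_including_max_model_len = bs * math.ceil(max_model_len / block_size)
        for blocks in block_buckets:
            if bs > blocks:
                # Skip a dummy case when bs > blocks, which cannot occur in real execution.
                continue
            if blocks <= max_blocks_including_max_model_len:  # assumed use of the bound
                kept.append((bs, blocks))
    return kept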