@@ -1196,6 +1196,26 @@ def _preprocess_inputs(self, inputs: Dict[str, Any]):
 
         return inputs
 
+    def _postprocess_inputs(self, inputs: Dict[str, Any]):
+        """
+        Postprocess the inputs to make sure the model forward doesn't change
+        them. It is only used in CUDA graph capture, because other cases
+        prepare new inputs before the model forward.
+        """
+        if self.enable_spec_decode and not self._disable_overlap_scheduler:
+            if inputs['attn_metadata'].kv_cache_manager is not None:
+                num_seqs = inputs['attn_metadata'].num_seqs
+                num_ctx_requests = inputs['attn_metadata'].num_contexts
+                num_gen_requests = inputs['attn_metadata'].num_generations
+                num_ctx_tokens = inputs['attn_metadata'].num_ctx_tokens
+                previous_batch_tokens = inputs['input_ids'].shape[
+                    0] - num_ctx_tokens
+                inputs['position_ids'][0, num_ctx_tokens:] -= (
+                    self.previous_pos_id_offsets_cuda[:previous_batch_tokens])
+                inputs['attn_metadata'].kv_lens_cuda[
+                    num_ctx_requests:num_seqs] -= (
+                        self.previous_kv_lens_offsets_cuda[:num_gen_requests])
+
     def _get_all_rank_num_tokens(self, attn_metadata: AttentionMetadata):
         if self.enable_attention_dp:
             return list(self.dist.tp_allgather(attn_metadata.num_tokens))
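Note: the new _postprocess_inputs is the inverse of the += updates that _preprocess_inputs applies to position_ids and kv_lens_cuda for overlap-scheduled speculative decoding. Because CUDA graph capture replays the forward on the same persistent buffers, any in-place shift applied before the captured forward has to be subtracted back afterwards, or the offsets would accumulate across warm-up and capture runs. Below is a minimal sketch of this apply/revert pattern; the class and method names (OffsetState, apply, revert) are illustrative stand-ins, not the engine's API, and the buffers mirror previous_pos_id_offsets_cuda.

    # Illustrative sketch only: OffsetState/apply/revert are hypothetical names.
    import torch

    class OffsetState:
        def __init__(self, max_tokens: int, device: str = 'cpu'):
            # Persistent offset buffer reused across batches, analogous to the
            # previous_*_offsets_cuda tensors (which live on the GPU).
            self.pos_offsets = torch.zeros(max_tokens, dtype=torch.int64,
                                           device=device)

        def apply(self, inputs: dict, num_ctx_tokens: int) -> None:
            # Mirror of _preprocess_inputs: shift generation-token positions
            # in place before the forward.
            n = inputs['position_ids'].shape[1] - num_ctx_tokens
            inputs['position_ids'][0, num_ctx_tokens:] += self.pos_offsets[:n]

        def revert(self, inputs: dict, num_ctx_tokens: int) -> None:
            # Mirror of _postprocess_inputs: subtract the same offsets so the
            # captured input buffers end up exactly as they were before.
            n = inputs['position_ids'].shape[1] - num_ctx_tokens
            inputs['position_ids'][0, num_ctx_tokens:] -= self.pos_offsets[:n]

Keeping apply and revert balanced around every forward run during capture is what makes repeated graph replays on the same buffers safe.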
@@ -2298,8 +2318,12 @@ def capture_forward_fn(inputs: Dict[str, Any]):
                     gather_ids=gather_ids,
                     gather_context_logits=gather_context_logits)
 
+            def capture_postprocess_fn(inputs: Dict[str, Any]):
+                self._postprocess_inputs(inputs)
+
             self.cuda_graph_runner.capture(batch_size,
-                                           capture_forward_fn, inputs)
+                                           capture_forward_fn, inputs,
+                                           capture_postprocess_fn)
 
             # here we don't need to use context since cuda graph capture didn't run kernel.
             # maybe we need a cleaner way to do this.
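Here the runner's capture call gains an extra callback that restores the input buffers after each forward run during capture. The runner-side change isn't shown in this hunk, so the following is only a plausible shape for such an API, sketched with the public torch.cuda.CUDAGraph capture flow; capture_with_postprocess and its parameters are assumptions, not the actual cuda_graph_runner signature.

    # Hypothetical sketch of a capture helper that accepts a postprocess hook.
    from typing import Any, Callable, Dict, Optional
    import torch

    def capture_with_postprocess(
            forward_fn: Callable[[Dict[str, Any]], Any],
            inputs: Dict[str, Any],
            postprocess_fn: Optional[Callable[[Dict[str, Any]], None]] = None,
    ) -> torch.cuda.CUDAGraph:
        # Warm up on a side stream (required before CUDA graph capture),
        # reverting any in-place input mutation after each run.
        stream = torch.cuda.Stream()
        stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(stream):
            for _ in range(2):
                forward_fn(inputs)
                if postprocess_fn is not None:
                    postprocess_fn(inputs)
        torch.cuda.current_stream().wait_stream(stream)

        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            forward_fn(inputs)
        # The capture itself also ran the mutating forward once, so restore
        # the buffers before the graph is replayed with fresh inputs.
        if postprocess_fn is not None:
            postprocess_fn(inputs)
        return graph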