@@ -1051,11 +1051,10 @@ def forward(
1051
1051
]
1052
1052
num_actual_toks = attn_metadata .num_actual_tokens
1053
1053
if k_pe is None and not self .running_in_graph :
1054
- if not self .torchair_graph_enabled :
1055
- kv_c , k_pe = self .kv_a_proj_with_mqa (
1056
- hidden_states_or_kv_c_normed )[0 ].split (
1057
- [self .kv_lora_rank , self .qk_rope_head_dim ], dim = - 1 )
1058
- kv_c_normed = self .kv_a_layernorm (kv_c .contiguous ())
1054
+ kv_c , k_pe = self .kv_a_proj_with_mqa (
1055
+ hidden_states_or_kv_c_normed )[0 ].split (
1056
+ [self .kv_lora_rank , self .qk_rope_head_dim ], dim = - 1 )
1057
+ kv_c_normed = self .kv_a_layernorm (kv_c .contiguous ())
1059
1058
else :
1060
1059
kv_c_normed = hidden_states_or_kv_c_normed
1061
1060
assert attn_metadata .num_decodes is not None and \
@@ -1074,12 +1073,13 @@ def forward(
1074
1073
if not self .running_in_graph :
1075
1074
hidden_states_or_q_c = hidden_states_or_q_c [:num_actual_toks , ...]
1076
1075
prefill_hs_or_q_c = hidden_states_or_q_c [num_decode_tokens :]
1077
- if not self .torchair_graph_enabled :
1078
- decode_hs_or_q_c = hidden_states_or_q_c [:num_decode_tokens ]
1079
- k_pe = k_pe [:num_actual_toks , ...]
1080
- k_pe = k_pe .unsqueeze (1 )
1081
- decode_k_pe = k_pe [:num_decode_tokens ]
1082
- prefill_k_pe = k_pe [num_decode_tokens :]
1076
+ decode_hs_or_q_c = hidden_states_or_q_c [:num_decode_tokens ]
1077
+ prefill_hs = hidden_states_or_kv_c_normed [num_decode_tokens :]
1078
+ # if not self.torchair_graph_enabled:
1079
+ k_pe = k_pe [:num_actual_toks , ...]
1080
+ k_pe = k_pe .unsqueeze (1 )
1081
+ decode_k_pe = k_pe [:num_decode_tokens ]
1082
+ prefill_k_pe = k_pe [num_decode_tokens :]
1083
1083
else :
1084
1084
decode_hs_or_q_c = hidden_states_or_q_c
1085
1085
if has_decode :
@@ -1153,11 +1153,11 @@ def forward(
1153
1153
1154
1154
prefill_q_pe = self .rope_single (prefill_q_pe , cos , sin )
1155
1155
prefill_k_pe , prefill_k_nope = self .exec_kv_prefill (
1156
- hidden_states_or_kv_c_normed , cos , sin , kv_cache ,
1157
- attn_metadata .slot_mapping )
1156
+ prefill_hs , cos , sin , kv_cache ,
1157
+ attn_metadata .slot_mapping [ num_decode_tokens :] )
1158
1158
1159
1159
kv_c_normed = prefill_k_nope [:num_actual_toks , ...]
1160
- prefill_k_c_normed = prefill_k_nope [ num_decode_tokens :]
1160
+ prefill_k_c_normed = prefill_k_nope
1161
1161
prefill_k_pe = prefill_k_pe .view (num_tokens , self .num_kv_heads ,
1162
1162
- 1 )
1163
1163
prefill_q = torch .cat ([prefill_q_nope , prefill_q_pe ], dim = - 1 )
0 commit comments