Commit 1bd7442

handle GQA in convert_state_dict
1 parent e6dd6ef commit 1bd7442

src/neuronx_distributed_inference/models/phi3/modeling_phi3.py

Lines changed: 8 additions & 3 deletions
@@ -130,11 +130,16 @@ def convert_state_dict_to_neuron(phi3_state_dict, cfg: InferenceConfig):
         ].clone().detach()
 
         # Get the fused QKV weight
-        fused_weight = phi3_state_dict[f"layers.{l}.self_attn.qkv_proj.weight"].clone().detach()
+        fused_attn = phi3_state_dict[f"layers.{l}.self_attn.qkv_proj.weight"].clone().detach()
         fused_gate_up = phi3_state_dict[f"layers.{l}.mlp.gate_up_proj.weight"].clone().detach()
-
+        # Potentially handle GQA
+        if cfg.num_attention_heads > cfg.num_key_value_heads:
+            q_features = cfg.hidden_size
+            q_weight = fused_attn[:q_features]
+            k_weight, v_weight = torch.chunk(fused_attn[q_features:], 2, dim=0)
         # Split the fused weight into Q, K, and V using torch.chunk
-        q_weight, k_weight, v_weight = torch.chunk(fused_weight, 3, dim=0)
+        else:
+            q_weight, k_weight, v_weight = torch.chunk(fused_attn, 3, dim=0)
         gate, up = torch.chunk(fused_gate_up, 2, dim=0)
 
         # Add the split weights to the state dict
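
For context: with grouped-query attention (GQA) the fused qkv_proj weight is not three equal slices, because K and V have fewer heads than Q. The Q projection occupies the first hidden_size output rows and the remainder splits evenly between K and V, so the old torch.chunk(fused_weight, 3, dim=0) would mis-split the tensor whenever num_key_value_heads < num_attention_heads. The snippet below is a minimal standalone sketch of the same splitting logic; the helper name split_fused_qkv and the example shapes (32 query heads, 8 KV heads, head_dim 128) are illustrative assumptions, not part of the commit.

import torch

def split_fused_qkv(fused_attn, hidden_size, num_attention_heads, num_key_value_heads):
    # Mirrors the GQA-aware split in the diff: Q takes the first hidden_size rows;
    # the remaining rows hold K and V, each of size
    # num_key_value_heads * (hidden_size // num_attention_heads).
    if num_attention_heads > num_key_value_heads:
        q_weight = fused_attn[:hidden_size]
        k_weight, v_weight = torch.chunk(fused_attn[hidden_size:], 2, dim=0)
    else:
        # Plain multi-head attention: Q, K and V are equally sized chunks.
        q_weight, k_weight, v_weight = torch.chunk(fused_attn, 3, dim=0)
    return q_weight, k_weight, v_weight

# Hypothetical GQA shapes: hidden_size 4096, 32 query heads, 8 KV heads, head_dim 128.
fused = torch.randn(4096 + 2 * 8 * 128, 4096)
q, k, v = split_fused_qkv(fused, hidden_size=4096, num_attention_heads=32, num_key_value_heads=8)
print(q.shape, k.shape, v.shape)  # torch.Size([4096, 4096]) torch.Size([1024, 4096]) torch.Size([1024, 4096])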
