Skip to content
Open
78 changes: 63 additions & 15 deletions QEfficient/base/pytorch_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,61 +120,109 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:

class SplitGateUpWeightsTransform(PytorchTransform):
"""
split fused Gate+Up weights and copy into the model
Split fused Gate+Up weights and copy into the model.
Handles both standard MoE models and GptOss models.

For every transformer layer inside `model`:
• expects <PREFIX>.experts.gate_up_proj in the *source* `sd`
• copies halves into
<PREFIX>.experts.gate_proj <-- Gate [E,H,I]
<PREFIX>.experts.up_proj <-- Up [E,H,I]
• expects <PREFIX>.experts.gate_up_proj in the *source* `sd`
• copies halves into
<PREFIX>.experts.gate_proj <-- Gate [E,H,I]
<PREFIX>.experts.up_proj <-- Up [E,H,I]

Handles both interleaved weights (GptOss) and concatenated weights (standard MoE).
Also handles bias terms when present.
"""

@classmethod
def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
transformed = False
model_class = model.__class__.__name__ if hasattr(model, "model") else model.__class__.__name__

if model_class not in VLM_SPLIT_GATE_UP_WEIGHTS:
return model, transformed

model_tmp = model.language_model if hasattr(model, "language_model") else model

num_layers = len(model_tmp.model.layers)
delete_fused_key = True
sd = model_tmp.state_dict()

for layer_idx in range(num_layers):
# Determine if this is a GptOss model or standard MoE model
is_gpt_oss = hasattr(model_tmp.model.layers[layer_idx], "mlp")

# ---- build the textual prefix once per layer ----------
prefix = f"model.layers.{layer_idx}.feed_forward.experts."
if is_gpt_oss:
prefix = f"model.layers.{layer_idx}.mlp.experts."
experts = model_tmp.model.layers[layer_idx].mlp.experts
else:
prefix = f"model.layers.{layer_idx}.feed_forward.experts."
experts = model_tmp.model.layers[layer_idx].feed_forward.experts

fused_key = prefix + "gate_up_proj"
gate_key = prefix + "gate_proj"
up_key = prefix + "up_proj"

# ---- split [E,H,2I] → two [E,H,I] tensors ----------------------
fused = sd[fused_key] # [E, H, 2I] (no .weight here)
# Check if we have bias terms (GptOss case)
has_bias = fused_key + "_bias" in sd
if has_bias:
fused_bias_key = fused_key + "_bias"
gate_bias_key = gate_key + "_bias"
up_bias_key = up_key + "_bias"

# ---- split weights based on model type ----------------------
fused = sd[fused_key] # [E, H, 2I]
E, H, two_I = fused.shape
ffn_dim = two_I // 2
gate, up = fused.split(ffn_dim, dim=-1) # views – no copy

experts = model_tmp.model.layers[layer_idx].feed_forward.experts
if is_gpt_oss:
# For GptOss, gate/up are interleaved: [gate0, up0, gate1, up1, ...]
gate = fused[..., ::2] # [E, H, I] - even indices
up = fused[..., 1::2] # [E, H, I] - odd indices
else:
# For standard MoE, gate/up are concatenated: [gate, up]
ffn_dim = two_I // 2
gate, up = fused.split(ffn_dim, dim=-1) # views – no copy

# Copy weights to model
experts.gate_proj.data.copy_(gate)
experts.up_proj.data.copy_(up)

# Handle bias if present
if has_bias:
fused_bias = sd[fused_bias_key] # [E, 2I]

if is_gpt_oss:
gate_bias = fused_bias[..., ::2] # [E, I] - even indices
up_bias = fused_bias[..., 1::2] # [E, I] - odd indices
else:
ffn_dim = fused_bias.shape[-1] // 2
gate_bias, up_bias = fused_bias.split(ffn_dim, dim=-1)

experts.gate_proj_bias.data.copy_(gate_bias)
experts.up_proj_bias.data.copy_(up_bias)

# ---- update the state-dict so load_state_dict sees the right keys
sd[gate_key] = gate
sd[up_key] = up

if has_bias:
sd[gate_bias_key] = gate_bias
sd[up_bias_key] = up_bias

# Delete fused keys
if delete_fused_key:
del sd[fused_key]
if has_bias:
del sd[fused_bias_key]

logger.info(f"[layer {layer_idx:02d}] loaded gate_proj & up_proj from fused tensor (shape {fused.shape})")
logger.info(f"[layer {layer_idx:02d}] loaded gate_proj & up_proj from fused tensor (shape {fused.shape})")
transformed = True

if hasattr(model, "language_model"):
model.language_model = model_tmp
else:
model = model_tmp

return model, transformed


VLM_SPLIT_GATE_UP_WEIGHTS = {"QEffLlama4ForConditionalGeneration", "QEffLlama4ForCausalLM"}
# Keep the existing list of supported models
VLM_SPLIT_GATE_UP_WEIGHTS = {"QEffLlama4ForConditionalGeneration", "QEffLlama4ForCausalLM", "QEffGptOssForCausalLM"}
Loading
Loading