Skip to content
Open
75 changes: 74 additions & 1 deletion QEfficient/base/pytorch_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,4 +177,77 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
return model, transformed


# Model classes whose fused experts.gate_up_proj weights are split into
# separate gate_proj / up_proj tensors by the SplitGateUpWeights transforms.
VLM_SPLIT_GATE_UP_WEIGHTS = {"QEffLlama4ForConditionalGeneration", "QEffLlama4ForCausalLM"}
class SplitGateUpWeightsTransformGPTOSS(PytorchTransform):
    """
    Split fused Gate+Up expert weights and copy them into the model.

    For every transformer layer inside `model`:
    • expects <PREFIX>.experts.gate_up_proj      in the *source* `sd`  — [E, H, 2I]
    • expects <PREFIX>.experts.gate_up_proj_bias in the *source* `sd`  — [E, 2I]
    • copies the de-interleaved halves into
          <PREFIX>.experts.gate_proj / gate_proj_bias  <-- Gate  [E, H, I] / [E, I]
          <PREFIX>.experts.up_proj   / up_proj_bias    <-- Up    [E, H, I] / [E, I]
    """

    @classmethod
    def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
        """
        Split each layer's fused ``gate_up_proj`` (and its bias) in-place.

        :param model: Model to transform; a no-op unless its class name is
            listed in ``VLM_SPLIT_GATE_UP_WEIGHTS``.
        :returns: Tuple of (possibly modified model, whether any layer was transformed).
        """
        transformed = False
        # NOTE(review): the original guard computed the same expression on both
        # branches of a `hasattr(model, "model")` ternary, so the condition was
        # dead — the model's own class name is what gates this transform.
        if model.__class__.__name__ not in VLM_SPLIT_GATE_UP_WEIGHTS:
            return model, transformed

        # VLM wrappers keep the text decoder under `language_model`;
        # plain causal-LM models are transformed directly.
        model_tmp = model.language_model if hasattr(model, "language_model") else model
        num_layers = len(model_tmp.model.layers)
        sd = model_tmp.state_dict()

        for layer_idx in range(num_layers):
            # ---- build the textual prefix once per layer ----------
            prefix = f"model.layers.{layer_idx}.mlp.experts."
            fused_key = prefix + "gate_up_proj"
            fused_bias_key = prefix + "gate_up_proj_bias"

            # ---- split [E,H,2I] → two [E,H,I] tensors ----------------------
            fused = sd[fused_key]  # [E, H, 2I]
            fused_bias = sd[fused_bias_key]  # [E, 2I]

            # For GptOss, gate/up are interleaved along the last dim:
            # [gate0, up0, gate1, up1, ...] — even indices are Gate, odd are Up.
            gate = fused[..., ::2]  # [E, H, I]
            up = fused[..., 1::2]  # [E, H, I]
            gate_bias = fused_bias[..., ::2]  # [E, I]
            up_bias = fused_bias[..., 1::2]  # [E, I]

            # Copy the halves into the (pre-allocated) split parameters.
            experts = model_tmp.model.layers[layer_idx].mlp.experts
            experts.gate_proj.data.copy_(gate)
            experts.up_proj.data.copy_(up)
            experts.gate_proj_bias.data.copy_(gate_bias)
            experts.up_proj_bias.data.copy_(up_bias)

            # ---- update the state-dict so load_state_dict sees the right keys
            sd[prefix + "gate_proj"] = gate
            sd[prefix + "up_proj"] = up
            sd[prefix + "gate_proj_bias"] = gate_bias
            sd[prefix + "up_proj_bias"] = up_bias

            # The fused keys are always removed once split (the original
            # `delete_fused_key` flag was unconditionally True).
            del sd[fused_key]
            del sd[fused_bias_key]

            # Lazy %-style args avoid formatting when INFO is disabled.
            logger.info(
                "[layer %02d] loaded gate_proj & up_proj from fused tensor (shape %s)",
                layer_idx,
                fused.shape,
            )
            transformed = True

        if hasattr(model, "language_model"):
            model.language_model = model_tmp
        else:
            model = model_tmp

        return model, transformed


# Model classes whose fused experts.gate_up_proj weights are split into
# separate gate_proj / up_proj tensors by the SplitGateUpWeights transforms.
VLM_SPLIT_GATE_UP_WEIGHTS = {
    "QEffLlama4ForConditionalGeneration",
    "QEffLlama4ForCausalLM",
    "QEffGptOssForCausalLM",
}
6 changes: 6 additions & 0 deletions QEfficient/transformers/models/gpt_oss/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------
Loading
Loading