src/llmcompressor/modeling/deepseek_v3.py (35 additions, 13 deletions)
@@ -10,12 +10,18 @@ class DeepseekV3MoECalibrate(torch.nn.Module):
    Patched DeepseekV3MoE which sends all tokens to all experts for calibration
    """

    def __init__(
        self,
        config: DeepseekV3Config,
        original: OriginalDeepseekV3MoE,
        calibrate_all_experts: bool,
    ):
        super().__init__()
        self.config = config
        self.experts = original.experts
        self.gate = original.gate
        self.shared_experts = original.shared_experts
        self.calibrate_all_experts = calibrate_all_experts

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residuals = hidden_states
@@ -30,24 +36,40 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        )
        expert_mask = expert_mask.permute(2, 0, 1)

        for expert_idx, expert in enumerate(self.experts):
            token_indices, weight_indices = torch.where(expert_mask[expert_idx])
            has_tokens = token_indices.numel() > 0

            if self.calibrate_all_experts:
                expert_input = hidden_states
                expert_output = expert(expert_input)

                if has_tokens:
                    expert_weights = topk_weights[token_indices, weight_indices]
                    routed_output = expert_output[
                        token_indices
                    ] * expert_weights.unsqueeze(-1)
                    final_hidden_states.index_add_(0, token_indices, routed_output)
            else:
                # Normal MoE: only process tokens routed to this expert
                if has_tokens:
                    expert_input = hidden_states[token_indices]
                    expert_output = expert(expert_input)
                    expert_weights = topk_weights[token_indices, weight_indices]
                    routed_output = expert_output * expert_weights.unsqueeze(-1)
                    final_hidden_states.index_add_(0, token_indices, routed_output)
        # End MoE

        hidden_states = final_hidden_states.type(hidden_states.dtype).view(*orig_shape)
        hidden_states = hidden_states + self.shared_experts(residuals)
        return hidden_states


def replace(
    config: DeepseekV3Config,
    module: OriginalDeepseekV3MoE,
    calibrate_all_experts: bool,
):
    return DeepseekV3MoECalibrate(
        config=config, original=module, calibrate_all_experts=calibrate_all_experts
    )
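
Aside (not part of the patch): a quick way to sanity-check the calibrate_all_experts path. Because each expert is a deterministic row-wise map, gathering rows from a full-batch forward must match running the expert on only the routed tokens. The Linear stand-in and sizes below are illustrative.

import torch

torch.manual_seed(0)
expert = torch.nn.Linear(8, 8)           # stand-in for a DeepseekV3 expert MLP
hidden_states = torch.randn(16, 8)       # 16 flattened tokens
token_indices = torch.tensor([1, 4, 9])  # tokens routed to this expert

# calibrate_all_experts=True path: run everything, then gather the routed rows
full_then_gather = expert(hidden_states)[token_indices]
# normal MoE path: gather the routed rows, then run the expert
gather_then_run = expert(hidden_states[token_indices])

assert torch.allclose(full_then_gather, gather_then_run, atol=1e-6)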
src/llmcompressor/modeling/llama4.py (42 additions, 22 deletions)
@@ -1,8 +1,6 @@
from typing import Tuple

import torch
from transformers.models.llama4.configuration_llama4 import (
    Llama4Config,
    Llama4TextConfig,
@@ -17,39 +15,57 @@


class SequentialLlama4TextMoe(torch.nn.Module):
    def __init__(
        self,
        config: Llama4TextConfig,
        original: Llama4TextMoe,
        calibrate_all_experts: bool,
    ):
        super().__init__()
        self.top_k = config.num_experts_per_tok
        self.hidden_dim = config.hidden_size
        self.num_experts = config.num_local_experts

        self.experts = SequentialLlama4TextExperts(config, original.experts)
        self.router = original.router
        self.shared_expert = original.shared_expert
        self.calibrate_all_experts = calibrate_all_experts

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        hidden_states = hidden_states.reshape(-1, self.hidden_dim)
        router_outputs = self.router(hidden_states)

        # support transformers 4.53 and greater
        if isinstance(router_outputs, tuple):
            router_scores, router_logits = router_outputs
        else:
            router_logits = router_outputs
            router_top_value, router_indices = torch.topk(
                router_logits, self.top_k, dim=1
            )
            # keep scores as (num_tokens, num_experts) so the expert loop below
            # can index them the same way as the tuple branch
            router_scores = torch.full_like(router_logits, float("-inf")).scatter_(
                1, router_indices, router_top_value
            )
            router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype)

        out = self.shared_expert(hidden_states)
        for expert_index in range(self.num_experts):
            top_token_mask = router_scores[:, expert_index] > 0

            if self.calibrate_all_experts:
                # Run all tokens for calibration
                expert_out = self.experts[expert_index](hidden_states)[top_token_mask]
            else:
                expert_out = self.experts[expert_index](hidden_states[top_token_mask])

            # Only top-k tokens contribute to final output
            if top_token_mask.any():
                expert_score = router_scores[top_token_mask, expert_index].unsqueeze(-1)
                out[top_token_mask] += expert_out * expert_score

        return out, router_scores


class SequentialLlama4TextExperts(torch.nn.ModuleList):
@@ -72,5 +88,9 @@ def __init__(self, config: Llama4TextConfig, original: Llama4TextExperts):
        self[i].down_proj.weight.data = down.t().clone().contiguous()


def replace(config: Llama4Config, module: Llama4TextMoe, calibrate_all_experts: bool):
    return SequentialLlama4TextMoe(
        config=config.get_text_config(),
        original=module,
        calibrate_all_experts=calibrate_all_experts,
    )
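
Aside (illustrative, standalone): the pre-4.54 fallback builds router scores by scattering the top-k logits into a tensor of -inf and applying sigmoid. Since torch.sigmoid(-inf) is exactly 0, `scores > 0` recovers the per-token top-k mask that the expert loop relies on. The shapes below are made up.

import torch

torch.manual_seed(0)
num_tokens, num_experts, top_k = 4, 8, 2
router_logits = torch.randn(num_tokens, num_experts)

top_value, top_idx = torch.topk(router_logits, top_k, dim=1)
router_scores = torch.full_like(router_logits, float("-inf")).scatter_(
    1, top_idx, top_value
)
router_scores = torch.sigmoid(router_scores)

top_token_mask = router_scores > 0                 # (num_tokens, num_experts)
assert top_token_mask.sum(dim=1).eq(top_k).all()   # exactly top_k experts per token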
src/llmcompressor/modeling/prepare.py (23 additions, 7 deletions)
@@ -1,3 +1,4 @@
import tqdm
from compressed_tensors.utils import replace_module
from transformers import PreTrainedModel

@@ -15,11 +16,18 @@
}


def replace_modules_for_calibration(
    model: PreTrainedModel,
    calibrate_all_experts: bool = True,
) -> PreTrainedModel:
    for name, module in tqdm.tqdm(list(model.named_modules())):
        cls_name = module.__class__.__name__
        if cls_name in replacements:
            new_module = replacements[cls_name](
                config=model.config,
                module=module,
                calibrate_all_experts=calibrate_all_experts,
            )
            replace_module(model, name, new_module)

    return model
@@ -28,7 +36,7 @@ def replace_modules_for_calibration(model: PreTrainedModel) -> PreTrainedModel:
# ------------------- module replacements; during calibration --------------------


def update_qwen3_moe(model, stack, calibrate_all_experts):
    for module in model.modules():
        cls_name = module.__class__.__name__
        if cls_name == "Qwen3MoeDecoderLayer":
@@ -37,7 +45,11 @@ def update_qwen3_moe(model, stack):
                patch_attr(
                    module,
                    "mlp",
                    replace_Qwen3MoE(
                        config=model.config,
                        module=module.mlp,
                        calibrate_all_experts=calibrate_all_experts,
                    ),
                )
            )

@@ -47,9 +59,13 @@ def update_qwen3_moe(model, stack):
}


def moe_calibration_context(
    model: PreTrainedModel,
    stack,
    calibrate_all_experts: bool = False,
):
    # Temporarily updates the MoE modules within the context.
    # Once the context exits, parameter updates persist.
    cls_name = model.__class__.__name__
    if cls_name in moe_context:
        moe_context.get(cls_name)(model, stack, calibrate_all_experts)
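
Aside: a minimal usage sketch of the context-based path, assuming the public import path mirrors this file's location; the model id is illustrative.

from contextlib import ExitStack

from transformers import AutoModelForCausalLM

from llmcompressor.modeling.prepare import moe_calibration_context

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-30B-A3B")  # illustrative

with ExitStack() as stack:
    # patched MoE modules are entered into the stack and restored on exit;
    # parameter updates made during calibration persist
    moe_calibration_context(model, stack, calibrate_all_experts=True)
    # ... run calibration forward passes here ...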
src/llmcompressor/modeling/qwen3_moe.py (28 additions, 18 deletions)
@@ -23,14 +23,17 @@

class Qwen3MoeSparseMoeBlock(torch.nn.Module):
    def __init__(
        self,
        config: Qwen3MoeConfig,
        original: OriginalQwen3MoeSparseMoeBlock,
        calibrate_all_experts: bool,
    ):
        super().__init__()
        self.num_experts = config.num_experts
        self.top_k = config.num_experts_per_tok
        self.norm_topk_prob = config.norm_topk_prob

        self.calibrate_all_experts = calibrate_all_experts
        self.gate = original.gate
        self.experts = original.experts

@@ -50,6 +53,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # we cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)

        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim),
            dtype=hidden_states.dtype,
@@ -62,26 +66,32 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
            selected_experts, num_classes=self.num_experts
        ).permute(2, 1, 0)

        for expert_idx, expert_layer in enumerate(self.experts):
            idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))

            if self.calibrate_all_experts:
                expert_out = expert_layer(hidden_states)[top_x]
            else:
                expert_out = expert_layer(hidden_states[top_x])

            # only tokens routed to this expert contribute to the final output
            if len(top_x) > 0:
                current_hidden_states = expert_out * routing_weights[top_x, idx, None]
                final_hidden_states.index_add_(
                    0, top_x, current_hidden_states.to(hidden_states.dtype)
                )

        final_hidden_states = final_hidden_states.reshape(
            batch_size, sequence_length, hidden_dim
        )
        return final_hidden_states, router_logits


def replace(
    config: Qwen3MoeConfig,
    module: OriginalQwen3MoeSparseMoeBlock,
    calibrate_all_experts: bool,
):
    return Qwen3MoeSparseMoeBlock(
        config=config, original=module, calibrate_all_experts=calibrate_all_experts
    )
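
Aside (standalone, toy shapes): how the one-hot mask turns router picks into per-expert token indices. torch.where over expert_mask[expert_idx] yields the top-k slot (idx) and the token position (top_x) for every token routed to that expert, which is exactly what routing_weights[top_x, idx] looks up.

import torch

num_tokens, top_k, num_experts = 6, 2, 4
selected_experts = torch.randint(num_experts, (num_tokens, top_k))

expert_mask = torch.nn.functional.one_hot(
    selected_experts, num_classes=num_experts
).permute(2, 1, 0)  # (num_experts, top_k, num_tokens)

expert_idx = 0
idx, top_x = torch.where(expert_mask[expert_idx])
# top_x: token positions routed to expert 0
# idx:   which top-k slot selected it, used to index routing_weights[top_x, idx]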