Commit b33f11b

Fix precommit formatting
Signed-off-by: Po-Han Huang <[email protected]>
1 parent 527afb2 commit b33f11b

2 files changed: +13 additions, −21 deletions


vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 9 additions & 14 deletions
@@ -874,15 +874,12 @@ def _load_per_tensor_weight_scale(self, shard_id: str,
         elif shard_id == "w2":
             param_data[expert_id] = loaded_weight
 
-    def _load_w13_weight_scale(self,
-                               shard_dim: int,
-                               loaded_weight: torch.Tensor,
-                               param: torch.Tensor,
-                               tp_rank: int):
+    def _load_w13_weight_scale(self, shard_dim: int,
+                               loaded_weight: torch.Tensor,
+                               param: torch.Tensor, tp_rank: int):
         shard_size = param.shape[shard_dim]
-        loaded_weight = loaded_weight.narrow(shard_dim,
-                                             shard_size * tp_rank,
-                                             shard_size)
+        loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
+                                             shard_size)
         param.copy_(loaded_weight)
 
     def _load_model_weight_or_group_weight_scale(self,
@@ -1135,12 +1132,10 @@ def weight_loader(self,
             "weight_scale" in weight_name) or "input_scale" in weight_name
 
         if "w13_weight_scale" in weight_name:
-            self._load_w13_weight_scale(
-                shard_dim=shard_dim,
-                loaded_weight=loaded_weight,
-                param=param,
-                tp_rank=self.tp_rank
-            )
+            self._load_w13_weight_scale(shard_dim=shard_dim,
+                                        loaded_weight=loaded_weight,
+                                        param=param,
+                                        tp_rank=self.tp_rank)
         elif per_tensor_conditions:
             self._load_per_tensor_weight_scale(
                 shard_id=shard_id,
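
For context on what the reformatted method does: each tensor-parallel rank loads only its slice of the full w13 weight-scale tensor, using torch.narrow to select the shard at its rank offset before copying it into the local parameter. Below is a minimal, standalone sketch of that pattern; the helper name load_w13_weight_scale_shard and the toy shapes are illustrative only and are not part of the vLLM API.

import torch


def load_w13_weight_scale_shard(param: torch.Tensor,
                                loaded_weight: torch.Tensor,
                                shard_dim: int, tp_rank: int) -> None:
    # Illustrative sketch: each tensor-parallel rank keeps shard_size
    # elements of the full tensor along shard_dim, starting at its own
    # rank offset, and copies that slice into its local parameter.
    shard_size = param.shape[shard_dim]
    shard = loaded_weight.narrow(shard_dim, shard_size * tp_rank, shard_size)
    param.copy_(shard)


# Toy check: a (4, 8) full tensor split across 2 ranks along dim 1.
full = torch.arange(32, dtype=torch.float32).reshape(4, 8)
local = torch.empty(4, 4)
load_w13_weight_scale_shard(local, full, shard_dim=1, tp_rank=1)
assert torch.equal(local, full[:, 4:])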

vllm/model_executor/models/llama4.py

Lines changed: 4 additions & 7 deletions
@@ -585,18 +585,15 @@ def permute_qk_weight_for_rotary(
         # Helper function to permute the weight's channels
         def permute(w: torch.Tensor, n_heads: int):
             head_dim = w.shape[0] // n_heads
-            return (
-                w.view(n_heads, head_dim // 2, 2, w.shape[1])
-                .transpose(1, 2)
-                .reshape(w.shape[0], w.shape[1])
-            )
+            return (w.view(n_heads, head_dim // 2, 2, w.shape[1]).transpose(
+                1, 2).reshape(w.shape[0], w.shape[1]))
 
         modules = name.split(".")
 
         # Permute Q/K weights and weight block scales for rotary embedding
         is_weight = modules[-1] == "weight"
-        is_nvfp4_weight_scale = (modules[-1] == "weight_scale"
-                                 and loaded_weight.dtype == torch.float8_e4m3fn)
+        is_nvfp4_weight_scale = (modules[-1] == "weight_scale" and
+                                 loaded_weight.dtype == torch.float8_e4m3fn)
 
         if is_weight or is_nvfp4_weight_scale:
             if ("wk" in modules or "k_proj" in modules):
