@@ -874,15 +874,12 @@ def _load_per_tensor_weight_scale(self, shard_id: str,
874
874
elif shard_id == "w2" :
875
875
param_data [expert_id ] = loaded_weight
876
876
877
- def _load_w13_weight_scale (self ,
878
- shard_dim : int ,
879
- loaded_weight : torch .Tensor ,
880
- param : torch .Tensor ,
881
- tp_rank : int ):
877
+ def _load_w13_weight_scale (self , shard_dim : int ,
878
+ loaded_weight : torch .Tensor ,
879
+ param : torch .Tensor , tp_rank : int ):
882
880
shard_size = param .shape [shard_dim ]
883
- loaded_weight = loaded_weight .narrow (shard_dim ,
884
- shard_size * tp_rank ,
885
- shard_size )
881
+ loaded_weight = loaded_weight .narrow (shard_dim , shard_size * tp_rank ,
882
+ shard_size )
886
883
param .copy_ (loaded_weight )
887
884
888
885
def _load_model_weight_or_group_weight_scale (self ,
@@ -1135,12 +1132,10 @@ def weight_loader(self,
1135
1132
"weight_scale" in weight_name ) or "input_scale" in weight_name
1136
1133
1137
1134
if "w13_weight_scale" in weight_name :
1138
- self ._load_w13_weight_scale (
1139
- shard_dim = shard_dim ,
1140
- loaded_weight = loaded_weight ,
1141
- param = param ,
1142
- tp_rank = self .tp_rank
1143
- )
1135
+ self ._load_w13_weight_scale (shard_dim = shard_dim ,
1136
+ loaded_weight = loaded_weight ,
1137
+ param = param ,
1138
+ tp_rank = self .tp_rank )
1144
1139
elif per_tensor_conditions :
1145
1140
self ._load_per_tensor_weight_scale (
1146
1141
shard_id = shard_id ,
0 commit comments