File tree Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Original file line number Diff line number Diff line change @@ -199,7 +199,7 @@ def apply_mlp(hidden_states: torch.Tensor,
199
199
group_list = torch .cat (
200
200
[group_list [:1 ], torch .diff (group_list , dim = 0 )])
201
201
group_list_type = 1
202
- bias1 = w1_scale_bias
202
+ bias1 = [ w1_scale_bias ] if is_torchair else w1_scale_bias
203
203
bias2 = [w2_scale_bias ]
204
204
# TODO w4a8 scene: dynamic acquisition of dtype in the future
205
205
_output_dtype = torch .bfloat16
@@ -219,7 +219,7 @@ def apply_mlp(hidden_states: torch.Tensor,
219
219
x = [hidden_states ],
220
220
weight = [w1 ],
221
221
scale = [w1_scale .to (w2_scale .dtype )],
222
- bias = [ bias1 ] if isinstance ( bias1 , torch . Tensor ) else bias1 ,
222
+ bias = bias1 ,
223
223
per_token_scale = [pertoken_scale ],
224
224
split_item = 2 ,
225
225
group_list_type = group_list_type ,
You can’t perform that action at this time.
0 commit comments