|
     TopKWeightAndReduceNoOP)
 from vllm.model_executor.layers.fused_moe.utils import (
     _resize_cache, moe_kernel_quantize_input, per_token_group_quant_fp8)
+from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
+    calculate_tile_tokens_dim)
 from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
     dequant_mxfp4)
 from vllm.platforms import current_platform
|
@@ -1065,22 +1067,6 @@ def inplace_fused_experts_fake(
|
 )


-def next_positive_power_of_2(x: int) -> int:
-    if x < 1:
-        return 1
-    return 1 << (x - 1).bit_length()
-
-
-def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
-    # Guess tokens per expert assuming perfect expert distribution first.
-    num_tokens_per_expert = (num_tokens * top_k) // num_experts
-    # And pad the number to the next power of 2.
-    tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
-    # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
-    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
-    return tile_tokens_dim
-
-
 def flashinfer_fused_moe_blockscale_fp8(
         routing_logits: torch.Tensor,
         routing_bias: torch.Tensor,
|
@@ -1128,8 +1114,8 @@ def flashinfer_fused_moe_blockscale_fp8(
|
         local_expert_offset=expert_offset,
         local_num_experts=local_num_experts,
         routed_scaling_factor=routed_scaling,
-        tile_tokens_dim=_get_tile_tokens_dim(x.shape[0], top_k,
-                                             global_num_experts),
+        tile_tokens_dim=calculate_tile_tokens_dim(x.shape[0], top_k,
+                                                  global_num_experts),
         routing_method_type=2,  # DeepSeek-styled routing method
         use_shuffled_weight=False,
     )
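
The helper logic deleted above is not lost: the import hunk at the top pulls in calculate_tile_tokens_dim from vllm.model_executor.layers.quantization.utils.flashinfer_utils, so both FlashInfer MoE paths can share it. A minimal sketch of that shared helper, assuming it keeps the same logic as the removed _get_tile_tokens_dim / next_positive_power_of_2 pair:

def calculate_tile_tokens_dim(num_tokens: int, top_k: int,
                              num_experts: int) -> int:
    # Sketch of the relocated helper (assumed to mirror the code removed above).
    # Guess tokens per expert assuming perfect expert distribution first.
    num_tokens_per_expert = (num_tokens * top_k) // num_experts
    # Pad to the next positive power of 2.
    if num_tokens_per_expert < 1:
        tile_tokens_dim = 1
    else:
        tile_tokens_dim = 1 << (num_tokens_per_expert - 1).bit_length()
    # Cap to 8-64 tokens per CTA tile, the range supported by the kernel.
    return min(max(tile_tokens_dim, 8), 64)

For example, 100 tokens routed with top_k=8 over 256 experts gives 3 tokens per expert, padded to 4 and then clamped up to the kernel minimum of 8.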
|
@@ -1164,6 +1150,97 @@ def flashinfer_fused_moe_blockscale_fp8_fake(
|
 )


+def flashinfer_fused_moe_per_tensor_scale_fp8(
+        routing_logits: torch.Tensor,
+        routing_bias: Optional[torch.Tensor],
+        hidden_states: torch.Tensor,
+        input_scale: torch.Tensor,
+        gemm1_weights: torch.Tensor,
+        gemm1_weights_scale: torch.Tensor,
+        activation_scale: torch.Tensor,
+        gemm2_weights: torch.Tensor,
+        gemm2_weights_scale: torch.Tensor,
+        num_experts: int,
+        top_k: int,
+        num_expert_group: Optional[int],
+        topk_group: Optional[int],
+        intermediate_size: int,
+        local_expert_offset: int,
+        local_num_experts: int,
+        use_routing_scales_on_input: bool,
+        routing_method_type: int,
+        routed_scaling_factor: float = 1.0) -> torch.Tensor:
+    num_expert_group = num_expert_group if num_expert_group is not None else 0
+    topk_group = topk_group if topk_group is not None else 0
+
+    quant_hidden_states, input_scale = moe_kernel_quantize_input(
+        hidden_states,
+        input_scale,
+        quant_dtype=torch.float8_e4m3fn,
+        per_act_token_quant=False)
+
+    output1_scales_scalar = gemm1_weights_scale * input_scale * (
+        1.0 / activation_scale)
+    output1_scales_gate_scalar = gemm1_weights_scale * input_scale
+    output2_scales_scalar = activation_scale * gemm2_weights_scale
+
+    from vllm.utils.flashinfer import (
+        flashinfer_trtllm_fp8_per_tensor_scale_moe)
+    return flashinfer_trtllm_fp8_per_tensor_scale_moe(
+        routing_logits=routing_logits,
+        routing_bias=routing_bias,
+        hidden_states=quant_hidden_states,
+        gemm1_weights=gemm1_weights,
+        output1_scales_scalar=output1_scales_scalar,
+        output1_scales_gate_scalar=output1_scales_gate_scalar,
+        gemm2_weights=gemm2_weights,
+        output2_scales_scalar=output2_scales_scalar,
+        num_experts=num_experts,
+        top_k=top_k,
+        n_group=num_expert_group,
+        topk_group=topk_group,
+        intermediate_size=intermediate_size,
+        local_expert_offset=local_expert_offset,
+        local_num_experts=local_num_experts,
+        routed_scaling_factor=routed_scaling_factor,
+        use_routing_scales_on_input=use_routing_scales_on_input,
+        tile_tokens_dim=calculate_tile_tokens_dim(hidden_states.shape[0],
+                                                  top_k, num_experts),
+        routing_method_type=routing_method_type)
+
+
+def flashinfer_fused_moe_per_tensor_scale_fp8_fake(
+        routing_logits: torch.Tensor,
+        routing_bias: torch.Tensor,
+        hidden_states: torch.Tensor,
+        gemm1_weights: torch.Tensor,
+        output1_scales_scalar: torch.Tensor,
+        output1_scales_gate_scalar: torch.Tensor,
+        gemm2_weights: torch.Tensor,
+        output2_scales_scalar: torch.Tensor,
+        num_experts: int,
+        top_k: int,
+        num_expert_group: int,
+        topk_group: int,
+        intermediate_size: int,
+        local_expert_offset: int,
+        local_num_experts: int,
+        routed_scaling_factor: float = 1.0,
+        use_routing_scales_on_input: bool = False,
+        tile_tokens_dim: int = 8,
+        routing_method_type: int = 0) -> torch.Tensor:
+    pass
+
+
+direct_register_custom_op(
+    op_name="flashinfer_fused_moe_per_tensor_scale_fp8",
+    op_func=flashinfer_fused_moe_per_tensor_scale_fp8,
+    mutates_args=["hidden_states"],
+    fake_impl=flashinfer_fused_moe_per_tensor_scale_fp8_fake,
+    tags=(torch.Tag.needs_fixed_stride_order, ),
+)
+
+
 def outplace_fused_experts(
         hidden_states: torch.Tensor,
         w1: torch.Tensor,
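
Once registered, the new op is reachable through PyTorch's custom-op registry rather than by direct import. The sketch below shows one way a caller might exercise it, assuming direct_register_custom_op exposes ops under the vllm namespace (torch.ops.vllm.*) as for the other fused-MoE ops in this file; the sizes, dtypes, weight layouts, and routing_method_type value are illustrative assumptions only, not values taken from this change, and actually running it requires a FlashInfer-enabled CUDA build.

import torch

# Hypothetical sizes: 4 tokens, hidden 1024, 8 experts, top-2, intermediate 2048.
num_tokens, hidden, n_experts, inter = 4, 1024, 8, 2048
dev = "cuda"
one = torch.ones((), dtype=torch.float32, device=dev)  # per-tensor scale placeholder

out = torch.ops.vllm.flashinfer_fused_moe_per_tensor_scale_fp8(
    routing_logits=torch.randn(num_tokens, n_experts, device=dev),
    routing_bias=None,
    hidden_states=torch.randn(num_tokens, hidden, dtype=torch.bfloat16, device=dev),
    input_scale=one,
    # Fused gate/up and down projection weights, assumed to use
    # [E, 2*inter, hidden] and [E, hidden, inter] layouts, pre-quantized to FP8.
    gemm1_weights=torch.randn(n_experts, 2 * inter, hidden,
                              device=dev).to(torch.float8_e4m3fn),
    gemm1_weights_scale=one,
    activation_scale=one,
    gemm2_weights=torch.randn(n_experts, hidden, inter,
                              device=dev).to(torch.float8_e4m3fn),
    gemm2_weights_scale=one,
    num_experts=n_experts,
    top_k=2,
    num_expert_group=None,
    topk_group=None,
    intermediate_size=inter,
    local_expert_offset=0,
    local_num_experts=n_experts,
    use_routing_scales_on_input=False,
    routing_method_type=2,  # model-dependent; 2 is DeepSeek-styled per the blockscale path
)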
|
|