diff --git a/vllm/compilation/activation_quant_fusion.py b/vllm/compilation/activation_quant_fusion.py
index ce4e50a2b02d..4dd2127d25dd 100644
--- a/vllm/compilation/activation_quant_fusion.py
+++ b/vllm/compilation/activation_quant_fusion.py
@@ -38,6 +38,32 @@ def silu_mul_replacement_static(result: torch.Tensor,
     return at[1]
 
 
+def silu_mul_mxfp4_gemm_pattern(result: torch.Tensor,
+        result_silu_mul: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
+        scale: torch.Tensor):
+    at1 = auto_functionalized(torch.ops._C.silu_and_mul.default,
+                              result=result_silu_mul,
+                              input=input)
+    at2 = auto_functionalized(torch.ops.vllm.gemm_with_dynamic_quant.default,
+                              result=result,
+                              x=at1[1],
+                              weight=weight,
+                              weight_scale=scale,
+                              x_scales=None)
+    return at2[1]
+
+
+def silu_mul_mxfp4_gemm_replacement(result: torch.Tensor,
+        result_silu_mul: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
+        scale: torch.Tensor):
+    at = auto_functionalized(torch.ops.vllm.silu_and_mul_mxfp4_gemm.default,
+                             result=result,
+                             x=input,
+                             weight=weight,
+                             weight_scale=scale)
+    return at[1]
+
+
 def empty_bf16(*args, **kwargs):
     return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda")
 
@@ -51,6 +77,10 @@ def empty_fp32(*args, **kwargs):
     return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda")
 
 
+def empty_fp4(*args, **kwargs):
+    return torch.empty(*args, **kwargs, dtype=torch.uint8, device="cuda")
+
+
 class ActivationQuantFusionPass(VllmInductorPass):
     """
     This pass fuses a pre-defined set of custom ops into fused ops.
@@ -76,6 +106,17 @@ def __init__(self, config: VllmConfig):
         register_replacement(silu_mul_pattern_static,
                              silu_mul_replacement_static, inputs, fwd_only,
                              self.patterns)
+
+        inputs = [
+            empty_bf16(32, 32),  # result
+            empty_bf16(32, 32),  # result_silu_mul
+            empty_bf16(32, 32),  # input
+            empty_fp4(32, 32),  # weight
+            empty_fp4(32, 1),  # scale
+        ]
+        register_replacement(silu_mul_mxfp4_gemm_pattern,
+                             silu_mul_mxfp4_gemm_replacement, inputs, fwd_only,
+                             self.patterns)
 
     def __call__(self, graph: torch.fx.Graph):
         self.begin()
diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py
index 94c0698eb50c..a88b8f779b75 100644
--- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py
+++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py
@@ -18,74 +18,94 @@ from aiter.ops.shuffle import shuffle_weight
     from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
     from aiter.ops.triton.quant import dynamic_mxfp4_quant
+    from aiter.ops.triton.activation import act_mul_and_mxfp4_quant
 
     from vllm.utils import direct_register_custom_op
 
     if envs.VLLM_TRITON_FP4_GEMM_USE_ASM:
         from aiter import gemm_a4w4, per_1x32_f4_quant_hip
 
     def gemm_with_dynamic_quant(
+        result: torch.Tensor,
         x: torch.Tensor,
         weight: torch.Tensor,
         weight_scale: torch.Tensor,
         x_scales: torch.Tensor = None,
-        out_dtype: Optional[torch.dtype] = torch.bfloat16,
-    ) -> torch.Tensor:
-        M = x.shape[0]
+        out_dtype: Optional[torch.dtype] = torch.bfloat16
+    ) -> None:
         if envs.VLLM_TRITON_FP4_GEMM_USE_ASM:
+            M = x.shape[0]
             if x_scales is None:
                 # use hip quant kernel for performance
                 x_q, x_s = per_1x32_f4_quant_hip(x, shuffle=True)
             else:
                 x_q = x
                 x_s = x_scales
-
             # 32 alignment is enough for dim0 padding of output for
             # gemm_a4w4 kernel
             y = torch.empty((M + 31) // 32 * 32,
                             weight.shape[0],
                             device=x_q.device,
                             dtype=out_dtype)
-
             gemm_a4w4(x_q, weight, x_s, weight_scale.view(x_s.dtype), y,
                       bpreshuffle=True)
-            return y[:M]
+            result.copy_(y[:M])
         else:
             if x_scales is None:
                 x_q, x_s = dynamic_mxfp4_quant(x)
             else:
                 x_q = x
                 x_s = x_scales
-            y = torch.empty(x_q.shape[0],
-                            weight.shape[0],
-                            device=x_q.device,
-                            dtype=out_dtype)
-
-            gemm_afp4wfp4(x_q, weight, x_s, weight_scale.T, out_dtype, y)
-            return y
+            gemm_afp4wfp4(x_q, weight, x_s, weight_scale.T, out_dtype, result)
 
     def gemm_with_dynamic_quant_fake(
+        result: torch.Tensor,
         x: torch.Tensor,
         weight: torch.Tensor,
         weight_scale: torch.Tensor,
         x_scales: torch.Tensor = None,
-        out_dtype: Optional[torch.dtype] = torch.bfloat16,
-    ) -> torch.Tensor:
-        return torch.empty((*x.shape[:-1], weight.shape[0]),
-                           dtype=out_dtype,
-                           device=x.device)
+        out_dtype: Optional[torch.dtype] = torch.bfloat16
+    ) -> None:
+        return
+
     direct_register_custom_op(
         op_name="gemm_with_dynamic_quant",
         op_func=gemm_with_dynamic_quant,
-        mutates_args=[],
+        mutates_args=['result'],
         fake_impl=gemm_with_dynamic_quant_fake,
         dispatch_key=current_platform.dispatch_key,
     )
 
+    def silu_and_mul_mxfp4_gemm(
+        result: torch.Tensor,
+        x: torch.Tensor,
+        weight: torch.Tensor,
+        weight_scale: torch.Tensor,
+        out_dtype: Optional[torch.dtype] = torch.bfloat16
+    ) -> None:
+        x_fp4, blockscale_e8m0 = act_mul_and_mxfp4_quant(x, 'silu')
+        gemm_with_dynamic_quant(result, x_fp4, weight, weight_scale, blockscale_e8m0, out_dtype)
+
+    def silu_and_mul_mxfp4_gemm_fake(
+        result: torch.Tensor,
+        x: torch.Tensor,
+        weight: torch.Tensor,
+        weight_scale: torch.Tensor,
+        out_dtype: Optional[torch.dtype] = torch.bfloat16
+    ) -> None:
+        return
+
+    direct_register_custom_op(
+        op_name="silu_and_mul_mxfp4_gemm",
+        op_func=silu_and_mul_mxfp4_gemm,
+        mutates_args=['result'],
+        fake_impl=silu_and_mul_mxfp4_gemm_fake,
+        dispatch_key=current_platform.dispatch_key,
+    )
+
 except ImportError:
     dynamic_mxfp4_quant = gemm_afp4wfp4 = None
@@ -225,5 +245,7 @@ def apply_weights(self,
             return F.linear(x, dq_w, bias)
         else:
-            return torch.ops.vllm.gemm_with_dynamic_quant(
-                x, layer.weight, layer.weight_scale, x_quant_scales, self.out_dtype)
+            result = torch.empty((*x.shape[:-1], layer.weight.shape[0]), dtype=self.out_dtype, device=x.device)
+            torch.ops.vllm.gemm_with_dynamic_quant(
+                result, x, layer.weight, layer.weight_scale, x_quant_scales, self.out_dtype)
+            return result
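
Note (not part of the patch): both custom ops in this diff are moved to a destination-passing style, where the caller preallocates the output buffer and the op declares it in `mutates_args=['result']`. That is what lets `auto_functionalized(...)` wrap them in the inductor pattern above and return the written buffer as `at[1]`. Below is a minimal sketch of that registration style, using a hypothetical `scaled_copy_example` op; the op name and toy computation are illustrative assumptions, not from the PR, and running it requires a GPU matching `current_platform.dispatch_key`.

```python
import torch

from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op


def scaled_copy(result: torch.Tensor, x: torch.Tensor, scale: float) -> None:
    # Write into the caller-provided buffer instead of returning a new tensor.
    result.copy_(x * scale)


def scaled_copy_fake(result: torch.Tensor, x: torch.Tensor, scale: float) -> None:
    # Fake impl for tracing: the out-buffer already carries shape/dtype metadata.
    return


direct_register_custom_op(
    op_name="scaled_copy_example",
    op_func=scaled_copy,
    mutates_args=["result"],
    fake_impl=scaled_copy_fake,
    dispatch_key=current_platform.dispatch_key,
)

# Caller allocates the output, mirroring the apply_weights change above.
x = torch.randn(4, 4, device="cuda", dtype=torch.bfloat16)
out = torch.empty_like(x)
torch.ops.vllm.scaled_copy_example(out, x, 2.0)
```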