diff --git a/vllm/model_executor/layers/quantization/rtn.py b/vllm/model_executor/layers/quantization/rtn.py
index 3d5dd6c49743..c041d2fd0ba4 100644
--- a/vllm/model_executor/layers/quantization/rtn.py
+++ b/vllm/model_executor/layers/quantization/rtn.py
@@ -6,21 +6,16 @@
 from collections.abc import Callable
 from typing import Any, Optional

+import numpy as np
 import torch
-import torch.nn.functional as F
 from torch.nn.parameter import Parameter

 from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe import (
-    FusedMoE,
-    FusedMoEConfig,
-    FusedMoEMethodBase,
-)
 from vllm.model_executor.layers.fused_moe.config import (
+    FusedMoEConfig,
     FusedMoEQuantConfig,
-    int4_w4a16_moe_quant_config,
-    int8_w8a16_moe_quant_config,
 )
+from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase
 from vllm.model_executor.layers.linear import (
     LinearBase,
     LinearMethodBase,
@@ -31,6 +26,12 @@
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from vllm.model_executor.layers.quantization.utils import replace_parameter
+from vllm.model_executor.layers.quantization.utils.marlin_utils import (
+    apply_rtn_marlin_linear,
+    marlin_make_workspace_new,
+)
+from vllm.scalar_type import scalar_types

 logger = init_logger(__name__)
 """By default, use 8 bit as target precision, but it can be
@@ -41,6 +42,9 @@
 overridden by setting the RTN_GROUP_SIZE envvar
 """
 GROUP_SIZE = os.getenv("RTN_GROUP_SIZE", "128")
+"""Global Marlin workspace shared by all modules
+"""
+workspace = None


 class RTNConfig(QuantizationConfig):
@@ -60,6 +64,10 @@ def __init__(
                 f"supported for RTN, but got {self.weight_bits} bits."
             )

+        self.quant_type = (
+            scalar_types.uint8b128 if self.weight_bits == 8 else scalar_types.uint4b8
+        )
+
     def __repr__(self) -> str:
         return (
             f"RTNConfig(weight_bits={self.weight_bits}, group_size={self.group_size})"
         )
@@ -221,7 +229,15 @@ def create_weights(
         layer.output_size_per_partition = output_size_per_partition

     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        fix_weights(layer, "weight")
+        """Repack weights and scales for Marlin kernels."""
+        weight_bits = self.quant_config.weight_bits
+
+        weight, scale = repack_weights(layer.weight, layer.scale, weight_bits)
+
+        replace_parameter(layer, "weight", weight)
+        replace_parameter(layer, "scale", scale)
+
+        init_workspace(layer.weight.device)

     def apply(
         self,
@@ -229,16 +245,16 @@
         x: torch.Tensor,
         bias: torch.Tensor | None = None,
     ) -> torch.Tensor:
-        qweight = layer.weight
-        scale = layer.scale
-
-        weight = rtn_dequantize(qweight, scale)
-        out = F.linear(x, weight)
-        del weight
-        if bias is not None:
-            out.add_(bias)
-
-        return out
+        return apply_rtn_marlin_linear(
+            input=x,
+            weight=layer.weight,
+            weight_scale=layer.scale,
+            workspace=workspace,
+            quant_type=self.quant_config.quant_type,
+            output_size_per_partition=layer.output_size_per_partition,
+            input_size_per_partition=layer.input_size_per_partition,
+            bias=bias,
+        )


 class RTNMoEMethod(FusedMoEMethodBase):
@@ -315,28 +331,27 @@
         set_weight_attrs(w2_weight, extra_weight_attrs)

     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        """Repack weights and scales for Marlin kernels."""
         weight_bits = self.quant_config.weight_bits
-        fix_weights(layer, "w13_weight", weight_bits == 4)
-        fix_weights(layer, "w2_weight", weight_bits == 4)
+
+        w13_weight, w13_scale = repack_weights(
+            layer.w13_weight, layer.w13_scale, weight_bits
+        )
+        replace_parameter(layer, "w13_weight", w13_weight)
+        replace_parameter(layer, "w13_scale", w13_scale)
+
+        w2_weight, w2_scale = repack_weights(
+            layer.w2_weight, layer.w2_scale, weight_bits
+        )
+        replace_parameter(layer, "w2_weight", w2_weight)
+        replace_parameter(layer, "w2_scale", w2_scale)
+
+        init_workspace(layer.w13_weight.device)

     def get_fused_moe_quant_config(
         self, layer: torch.nn.Module
     ) -> FusedMoEQuantConfig | None:
-        weight_bits = self.quant_config.weight_bits
-        group_size = self.quant_config.group_size
-        assert weight_bits == 4 or weight_bits == 8
-        config_builder = (
-            int4_w4a16_moe_quant_config
-            if weight_bits == 4
-            else int8_w8a16_moe_quant_config
-        )
-        return config_builder(
-            w1_scale=layer.w13_scale,
-            w2_scale=layer.w2_scale,
-            w1_zp=None,
-            w2_zp=None,
-            block_shape=[0, group_size],
-        )
+        return None

     def apply(
         self,
@@ -366,8 +381,6 @@
         if enable_eplb:
             raise NotImplementedError("EPLB not supported for `RTNMoEMethod` yet.")

-        from vllm.model_executor.layers.fused_moe import fused_experts
-
         topk_weights, topk_ids, _ = FusedMoE.select_experts(
             hidden_states=x,
             router_logits=router_logits,
@@ -383,18 +396,22 @@
             indices_type=self.topk_indices_dtype,
         )

-        return fused_experts(
+        return torch.ops.vllm.fused_marlin_moe(
             x,
             layer.w13_weight,
             layer.w2_weight,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
-            inplace=True,
-            activation=activation,
+            getattr(layer, "w13_bias", None),
+            getattr(layer, "w2_bias", None),
+            layer.w13_scale,
+            layer.w2_scale,
+            router_logits,
+            topk_weights,
+            topk_ids,
+            quant_type_id=self.quant_config.quant_type.id,
             apply_router_weight_on_input=apply_router_weight_on_input,
             global_num_experts=global_num_experts,
             expert_map=expert_map,
-            quant_config=self.moe_quant_config,
+            workspace=workspace,
         )


@@ -504,18 +521,133 @@ def rtn_dequantize(tensor: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
     return input_deq


-def fix_weights(layer: torch.nn.Module, param_name: str, reshape: bool = False):
-    """torch.compile does not know how to deal with a Parameter subclass
-    (aka RTNParameter). As we don't really need RTNParameters for the
-    forward pass, we replace them with equivalent instances of Parameters.
+def _get_perms():
+    perm = []
+    for i in range(32):
+        perm1 = []
+        col = i // 4
+        for block in [0, 1]:
+            for row in [
+                2 * (i % 4),
+                2 * (i % 4) + 1,
+                2 * (i % 4 + 4),
+                2 * (i % 4 + 4) + 1,
+            ]:
+                perm1.append(16 * row + col + 8 * block)
+        for j in range(4):
+            perm.extend([p + 256 * j for p in perm1])
+
+    perm_arr = np.array(perm)
+    interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
+    perm_arr = perm_arr.reshape((-1, 8))[:, interleave].ravel()
+    perm_tensor = torch.from_numpy(perm_arr)
+    scale_perm = []
+    for i in range(8):
+        scale_perm.extend([i + 8 * j for j in range(8)])
+    scale_perm_single = []
+    for i in range(4):
+        scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
+    return perm_tensor, scale_perm, scale_perm_single
+
+
+_perm, _scale_perm, _scale_perm_single = _get_perms()
+
+
+def pack_for_marlin(weight, scale, qbits):
+    batch = weight.shape[0]
+
+    n = weight.size(1)
+    k = weight.size(2)
+    groupsize = k // scale.size(2)
+
+    tile = 16
+    s = scale.permute(0, 2, 1)  # transpose
+    w = weight.permute(0, 2, 1)  # transpose
+    if groupsize != k:
+        w = w.reshape((batch, -1, groupsize, n))
+        w = w.permute(0, 2, 1, 3)
+        w = w.reshape((batch, groupsize, -1))
+        s = s.reshape((batch, 1, -1))
+
+    if groupsize != k:
+        w = w.reshape((batch, groupsize, -1, n))
+        w = w.permute(0, 2, 1, 3)
+        w = w.reshape((batch, k, n)).contiguous()
+        s = s.reshape((batch, -1, len(_scale_perm)))[:, :, _scale_perm]
+    else:
+        s = s.reshape((batch, -1, len(_scale_perm_single)))[:, :, _scale_perm_single]
+    s = s.reshape((batch, -1, n)).contiguous()
+    w = w.reshape((batch, k // tile, tile, n // tile, tile))
+    w = w.permute((0, 1, 3, 2, 4))
+    w = w.reshape((batch, k // tile, n * tile))
+    res = w
+    res = res.reshape((batch, -1, _perm.numel()))[:, :, _perm].reshape(res.shape)
+    if qbits == 4:
+        q = torch.zeros(
+            (batch, res.shape[1], res.shape[2] // 2), dtype=torch.int8, device=w.device
+        )
+        for i in range(2):
+            q |= res[:, :, i::2] << 4 * i
+        q = q.reshape(batch, -1, n).contiguous()
+    else:
+        q = res.clone()
+        q[:, :, 2::8] = res[:, :, 4::8]
+        q[:, :, 3::8] = res[:, :, 5::8]
+        q[:, :, 4::8] = res[:, :, 2::8]
+        q[:, :, 5::8] = res[:, :, 3::8]
+        q = q.reshape(batch, -1, n).to(torch.int8).contiguous()
+
+    return q, s
+
+
+def repack_8bit_into_32bit(input):
+    output = torch.zeros(
+        (input.shape[0], input.shape[1], input.shape[2] // 4),
+        dtype=torch.int32,
+        device=input.device,
+    )
+    for i in range(4):
+        output |= (input[:, :, i::4] & 0xFF).to(torch.int32) << 8 * i
+
+    return output
+
+
+def repack_weights(qweight, scale, weight_bits):
+    batch_present = len(qweight.shape) == 3
+    if not batch_present:
+        qweight = qweight.unsqueeze(0)
+        scale = scale.unsqueeze(0)
+
+    if weight_bits == 4:
+        """Unpack two 4-bit values from each byte.
+        """
+        qweight_unpacked = torch.empty(
+            (qweight.shape[0], qweight.shape[1] * 2, qweight.shape[2]),
+            dtype=torch.uint8,
+            device=qweight.device,
+        )
+        for i in range(2):
+            qweight_unpacked[:, :, i::2] = ((qweight << 4 * (1 - i)) >> 4).reshape(
+                qweight.shape[0], qweight.shape[1] * 2, qweight.shape[2] // 2
+            )
+    else:
+        qweight_unpacked = qweight
+
+    qweight_packed, scale_packed = pack_for_marlin(qweight_unpacked, scale, weight_bits)
+    """Marlin kernels expect tensors in int32 format in a certain shape
     """
-    old_weight = getattr(layer, param_name)
-    assert isinstance(old_weight, RTNParameter)
-    data = old_weight.data.data
+    qweight_repacked = repack_8bit_into_32bit(qweight_packed.to(torch.uint8))
+    qweight_reshaped = qweight_repacked.reshape(
+        qweight.shape[0], qweight.shape[2] // 16, -1
+    )
+    if not batch_present:
+        qweight_reshaped = qweight_reshaped.squeeze(0)
+        scale_packed = scale_packed.squeeze(0)
+
+    return qweight_reshaped, scale_packed

-    delattr(layer, param_name)
-    if reshape:
-        data = data.reshape(old_weight.shape[0], old_weight.shape[1] * 2, -1)
-    new_weight = Parameter(data=data, requires_grad=False)
-    layer.register_parameter(param_name, new_weight)
+
+def init_workspace(device):
+    global workspace
+    if workspace is None:
+        workspace = marlin_make_workspace_new(device, 4)
diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py
index fd6b581d2b90..071fb4ba1686 100644
--- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py
@@ -528,3 +528,48 @@ def apply_awq_marlin_linear(
     )

     return output.reshape(out_shape)
+
+
+def apply_rtn_marlin_linear(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    workspace: torch.Tensor,
+    quant_type: ScalarType,
+    output_size_per_partition: int,
+    input_size_per_partition: int,
+    bias: torch.Tensor | None = None,
+    use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
+) -> torch.Tensor:
+    reshaped_x = input.reshape(-1, input.shape[-1])
+    out_shape = input.shape[:-1] + (output_size_per_partition,)
+
+    use_atomic_add = should_use_atomic_add_reduce(
+        m=reshaped_x.size(0),
+        n=output_size_per_partition,
+        k=reshaped_x.size(1),
+        device=input.device,
+        dtype=input.dtype,
+    )
+
+    output = ops.gptq_marlin_gemm(
+        reshaped_x,
+        None,
+        weight,
+        bias,
+        weight_scale,
+        None,
+        None,
+        None,
+        None,
+        workspace,
+        quant_type,
+        size_m=reshaped_x.shape[0],
+        size_n=output_size_per_partition,
+        size_k=input_size_per_partition,
+        use_atomic_add=use_atomic_add,
+        use_fp32_reduce=use_fp32_reduce,
+        is_zp_float=False,
+    )
+
+    return output.reshape(out_shape)