Skip to content
244 changes: 188 additions & 56 deletions vllm/model_executor/layers/quantization/rtn.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,16 @@
from collections.abc import Callable
from typing import Any, Optional

import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
FusedMoEConfig,
FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
int4_w4a16_moe_quant_config,
int8_w8a16_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase
from vllm.model_executor.layers.linear import (
LinearBase,
LinearMethodBase,
Expand All @@ -31,6 +26,12 @@
QuantizationConfig,
QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_rtn_marlin_linear,
marlin_make_workspace_new,
)
from vllm.scalar_type import scalar_types

logger = init_logger(__name__)
"""By default, use 8 bit as target precision, but it can be
Expand All @@ -41,6 +42,9 @@
overridden by setting the RTN_GROUP_SIZE envvar
"""
GROUP_SIZE = os.getenv("RTN_GROUP_SIZE", "128")
"""Global Marlin workspace shared by all modules
"""
workspace = None


class RTNConfig(QuantizationConfig):
Expand All @@ -60,6 +64,10 @@ def __init__(
f"supported for RTN, but got {self.weight_bits} bits."
)

self.quant_type = (
scalar_types.uint8b128 if self.weight_bits == 8 else scalar_types.uint4b8
)

def __repr__(self) -> str:
return (
f"RTNConfig(weight_bits={self.weight_bits}, group_size={self.group_size})"
Expand Down Expand Up @@ -221,24 +229,32 @@ def create_weights(
layer.output_size_per_partition = output_size_per_partition

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
fix_weights(layer, "weight")
"""Repack weights and scales for Marlin kernels."""
weight_bits = self.quant_config.weight_bits

weight, scale = repack_weights(layer.weight, layer.scale, weight_bits)

replace_parameter(layer, "weight", weight)
replace_parameter(layer, "scale", scale)

init_workspace(layer.weight.device)

def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
qweight = layer.weight
scale = layer.scale

weight = rtn_dequantize(qweight, scale)
out = F.linear(x, weight)
del weight
if bias is not None:
out.add_(bias)

return out
return apply_rtn_marlin_linear(
input=x,
weight=layer.weight,
weight_scale=layer.scale,
workspace=workspace,
quant_type=self.quant_config.quant_type,
output_size_per_partition=layer.output_size_per_partition,
input_size_per_partition=layer.input_size_per_partition,
bias=bias,
)


class RTNMoEMethod(FusedMoEMethodBase):
Expand Down Expand Up @@ -315,28 +331,27 @@ def create_weights(
set_weight_attrs(w2_weight, extra_weight_attrs)

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
"""Repack weights and scales for Marlin kernels."""
weight_bits = self.quant_config.weight_bits
fix_weights(layer, "w13_weight", weight_bits == 4)
fix_weights(layer, "w2_weight", weight_bits == 4)

w13_weight, w13_scale = repack_weights(
layer.w13_weight, layer.w13_scale, weight_bits
)
replace_parameter(layer, "w13_weight", w13_weight)
replace_parameter(layer, "w13_scale", w13_scale)

w2_weight, w2_scale = repack_weights(
layer.w2_weight, layer.w2_scale, weight_bits
)
replace_parameter(layer, "w2_weight", w2_weight)
replace_parameter(layer, "w2_scale", w2_scale)

init_workspace(layer.w13_weight.device)

def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
weight_bits = self.quant_config.weight_bits
group_size = self.quant_config.group_size
assert weight_bits == 4 or weight_bits == 8
config_builder = (
int4_w4a16_moe_quant_config
if weight_bits == 4
else int8_w8a16_moe_quant_config
)
return config_builder(
w1_scale=layer.w13_scale,
w2_scale=layer.w2_scale,
w1_zp=None,
w2_zp=None,
block_shape=[0, group_size],
)
return None

def apply(
self,
Expand Down Expand Up @@ -366,8 +381,6 @@ def apply(
if enable_eplb:
raise NotImplementedError("EPLB not supported for `RTNMoEMethod` yet.")

from vllm.model_executor.layers.fused_moe import fused_experts

topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
Expand All @@ -383,18 +396,22 @@ def apply(
indices_type=self.topk_indices_dtype,
)

return fused_experts(
return torch.ops.vllm.fused_marlin_moe(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
getattr(layer, "w13_bias", None),
getattr(layer, "w2_bias", None),
layer.w13_scale,
layer.w2_scale,
router_logits,
topk_weights,
topk_ids,
quant_type_id=self.quant_config.quant_type.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
quant_config=self.moe_quant_config,
workspace=workspace,
)


Expand Down Expand Up @@ -504,18 +521,133 @@ def rtn_dequantize(tensor: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
return input_deq


def fix_weights(layer: torch.nn.Module, param_name: str, reshape: bool = False):
"""torch.compile does not know how to deal with a Parameter subclass
(aka RTNParameter). As we don't really need RTNParameters for the
forward pass, we replace them with equivalent instances of Parameters.
def _get_perms():
perm = []
for i in range(32):
perm1 = []
col = i // 4
for block in [0, 1]:
for row in [
2 * (i % 4),
2 * (i % 4) + 1,
2 * (i % 4 + 4),
2 * (i % 4 + 4) + 1,
]:
perm1.append(16 * row + col + 8 * block)
for j in range(4):
perm.extend([p + 256 * j for p in perm1])

perm_arr = np.array(perm)
interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
perm_arr = perm_arr.reshape((-1, 8))[:, interleave].ravel()
perm_tensor = torch.from_numpy(perm_arr)
scale_perm = []
for i in range(8):
scale_perm.extend([i + 8 * j for j in range(8)])
scale_perm_single = []
for i in range(4):
scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
return perm_tensor, scale_perm, scale_perm_single


_perm, _scale_perm, _scale_perm_single = _get_perms()


def pack_for_marlin(weight, scale, qbits):
batch = weight.shape[0]

n = weight.size(1)
k = weight.size(2)
groupsize = k // scale.size(2)

tile = 16
s = scale.permute(0, 2, 1) # transpose
w = weight.permute(0, 2, 1) # transpose
if groupsize != k:
w = w.reshape((batch, -1, groupsize, n))
w = w.permute(0, 2, 1, 3)
w = w.reshape((batch, groupsize, -1))
s = s.reshape((batch, 1, -1))

if groupsize != k:
w = w.reshape((batch, groupsize, -1, n))
w = w.permute(0, 2, 1, 3)
w = w.reshape((batch, k, n)).contiguous()
s = s.reshape((batch, -1, len(_scale_perm)))[:, :, _scale_perm]
else:
s = s.reshape((batch, -1, len(_scale_perm_single)))[:, :, _scale_perm_single]
s = s.reshape((batch, -1, n)).contiguous()
w = w.reshape((batch, k // tile, tile, n // tile, tile))
w = w.permute((0, 1, 3, 2, 4))
w = w.reshape((batch, k // tile, n * tile))
res = w
res = res.reshape((batch, -1, _perm.numel()))[:, :, _perm].reshape(res.shape)
if qbits == 4:
q = torch.zeros(
(batch, res.shape[1], res.shape[2] // 2), dtype=torch.int8, device=w.device
)
for i in range(2):
q |= res[:, :, i::2] << 4 * i
q = q.reshape(batch, -1, n).contiguous()
else:
q = res.clone()
q[:, :, 2::8] = res[:, :, 4::8]
q[:, :, 3::8] = res[:, :, 5::8]
q[:, :, 4::8] = res[:, :, 2::8]
q[:, :, 5::8] = res[:, :, 3::8]
q = q.reshape(batch, -1, n).to(torch.int8).contiguous()

return q, s


def repack_8bit_into_32bit(input):
output = torch.zeros(
(input.shape[0], input.shape[1], input.shape[2] // 4),
dtype=torch.int32,
device=input.device,
)
for i in range(4):
output |= (input[:, :, i::4] & 0xFF).to(torch.int32) << 8 * i

return output


def repack_weights(qweight, scale, weight_bits):
batch_present = len(qweight.shape) == 3
if not batch_present:
qweight = qweight.unsqueeze(0)
scale = scale.unsqueeze(0)

if weight_bits == 4:
"""Unpack two 4-bit values from each byte.
"""
qweight_unpacked = torch.empty(
(qweight.shape[0], qweight.shape[1] * 2, qweight.shape[2]),
dtype=torch.uint8,
device=qweight.device,
)
for i in range(2):
qweight_unpacked[:, :, i::2] = ((qweight << 4 * (1 - i)) >> 4).reshape(
qweight.shape[0], qweight.shape[1] * 2, qweight.shape[2] // 2
)
else:
qweight_unpacked = qweight

qweight_packed, scale_packed = pack_for_marlin(qweight_unpacked, scale, weight_bits)
"""Marlin kernels expect tensors in int32 format in a certain shape
"""
old_weight = getattr(layer, param_name)
assert isinstance(old_weight, RTNParameter)
data = old_weight.data.data
qweight_repacked = repack_8bit_into_32bit(qweight_packed.to(torch.uint8))
qweight_reshaped = qweight_repacked.reshape(
qweight.shape[0], qweight.shape[2] // 16, -1
)
if not batch_present:
qweight_reshaped = qweight_reshaped.squeeze(0)
scale_packed = scale_packed.squeeze(0)

return qweight_reshaped, scale_packed

delattr(layer, param_name)

if reshape:
data = data.reshape(old_weight.shape[0], old_weight.shape[1] * 2, -1)
new_weight = Parameter(data=data, requires_grad=False)
layer.register_parameter(param_name, new_weight)
def init_workspace(device):
global workspace
if workspace is None:
workspace = marlin_make_workspace_new(device, 4)
45 changes: 45 additions & 0 deletions vllm/model_executor/layers/quantization/utils/marlin_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,3 +528,48 @@ def apply_awq_marlin_linear(
)

return output.reshape(out_shape)


def apply_rtn_marlin_linear(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
workspace: torch.Tensor,
quant_type: ScalarType,
output_size_per_partition: int,
input_size_per_partition: int,
bias: torch.Tensor | None = None,
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
) -> torch.Tensor:
reshaped_x = input.reshape(-1, input.shape[-1])
out_shape = input.shape[:-1] + (output_size_per_partition,)

use_atomic_add = should_use_atomic_add_reduce(
m=reshaped_x.size(0),
n=output_size_per_partition,
k=reshaped_x.size(1),
device=input.device,
dtype=input.dtype,
)

output = ops.gptq_marlin_gemm(
reshaped_x,
None,
weight,
bias,
weight_scale,
None,
None,
None,
None,
workspace,
quant_type,
size_m=reshaped_x.shape[0],
size_n=output_size_per_partition,
size_k=input_size_per_partition,
use_atomic_add=use_atomic_add,
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False,
)

return output.reshape(out_shape)