[Qwix] Correctly plumb quantization_rule to kernel #2206

Open · wants to merge 1 commit into base: main
7 changes: 6 additions & 1 deletion MaxText/configs/base.yml
@@ -120,7 +120,12 @@ save_quantized_params_path: ""
model_call_mode: ""
use_qwix_quantization: False # Whether to use qwix for quantization. If set to True, the model will be quantized using qwix.
# Quantization calibration method used for weights and activations. Supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#L70-L80
quantization_calibration_method: "absmax"
fwd_weight_calibration_method: "absmax"
fwd_act_calibration_method: "absmax"
dlhs_lhs_calibration_method: "absmax"
dlhs_rhs_calibration_method: "absmax"
drhs_lhs_calibration_method: "absmax"
drhs_rhs_calibration_method: "absmax"
# Shard the range finding operation for quantization. By default this is set to number of slices.
quantization_local_shard_count: -1

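The single quantization_calibration_method option is split into six per-pass options. A minimal sketch of how each new key lines up with the three grouped-matmul calls, inferred from the ops.py hunks further down; the mapping table itself is illustrative and is not code in this PR.

# Illustrative mapping only (not part of the PR): which config key calibrates
# which operand of which grouped-matmul call, per the ops.py changes below.
CALIBRATION_KEY_BY_CALL = {
    # Forward gmm: lhs = activations, rhs = weights.
    ("fwd", "lhs"): "fwd_act_calibration_method",
    ("fwd", "rhs"): "fwd_weight_calibration_method",
    # Backward gmm computing grad_lhs (dlhs).
    ("dlhs", "lhs"): "dlhs_lhs_calibration_method",
    ("dlhs", "rhs"): "dlhs_rhs_calibration_method",
    # Backward tgmm computing grad_rhs (drhs).
    ("drhs", "lhs"): "drhs_lhs_calibration_method",
    ("drhs", "rhs"): "drhs_rhs_calibration_method",
}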
21 changes: 16 additions & 5 deletions MaxText/kernels/megablox/gmm.py
@@ -33,6 +33,7 @@

from qwix.pallas import QArray
import qwix.pallas as qpl
import qwix

from MaxText.kernels.megablox import common

@@ -305,6 +306,9 @@ def _zero_uninitialized_memory(
"interpret",
"lhs_quantize_dtype",
"rhs_quantize_dtype",
"lhs_calibration_method",
"rhs_calibration_method",
"quantization_rule",
"use_qwix_quantization",
],
)
@@ -320,6 +324,9 @@ def gmm(
interpret: bool = False,
lhs_quantize_dtype: Literal[jnp.int4, jnp.int8] | None = None,
rhs_quantize_dtype: Literal[jnp.int4, jnp.int8] | None = None,
lhs_calibration_method: str = "absmax",
rhs_calibration_method: str = "absmax",
quantization_rule: qwix.QtRule | None = None,
use_qwix_quantization: bool = False,
) -> jnp.ndarray:
"""Compute lhs[sizes[i-1]:sizes[i], :] @ rhs for each group 'i'.
@@ -604,7 +611,7 @@ def out_transform_indices(n_i, grid_id, k_i, group_metadata, group_offset):

if lhs_quantize_dtype is not None:
if use_qwix_quantization:
lhs = qpl.quantize(lhs, qtype=lhs_quantize_dtype, channelwise_axes=[0], scale_dtype=jnp.float32)
lhs = qpl.quantize(lhs, qtype=lhs_quantize_dtype, channelwise_axes=[0], scale_dtype=jnp.float32, calibration_method=lhs_calibration_method)
else:
lhs_quantize_bits = 4 if lhs_quantize_dtype == jnp.int4 else 8
lhs = aqt_pl.quant(lhs, lhs_quantize_bits, lhs_contracting_axis)
@@ -613,7 +620,7 @@ def out_transform_indices(n_i, grid_id, k_i, group_metadata, group_offset):
# Use per-channel scales for non-contracting axes, i.e., num_groups, m, n but not k.
if use_qwix_quantization:
rhs = qpl.quantize(
rhs, qtype=rhs_quantize_dtype, channelwise_axes=[0, 1 if transpose_rhs else 2], scale_dtype=jnp.float32
rhs, qtype=rhs_quantize_dtype, channelwise_axes=[0, 1 if transpose_rhs else 2], scale_dtype=jnp.float32, calibration_method=rhs_calibration_method
)
else:
rhs_quantize_bits = 4 if rhs_quantize_dtype == jnp.int4 else 8
@@ -645,6 +652,8 @@ def out_transform_indices(n_i, grid_id, k_i, group_metadata, group_offset):
"interpret",
"lhs_quantize_dtype",
"rhs_quantize_dtype",
"lhs_calibration_method",
"rhs_calibration_method",
"use_qwix_quantization",
],
)
@@ -660,6 +669,8 @@ def tgmm(
interpret: bool = False,
lhs_quantize_dtype=None,
rhs_quantize_dtype=None,
lhs_calibration_method: str = "absmax",
rhs_calibration_method: str = "absmax",
use_qwix_quantization: bool = False,
) -> jnp.ndarray:
"""Compute lhs[:, sizes[i-1]:sizes[i]] @ rhs[sizes[i-1]:sizes[i], :].
@@ -780,7 +791,7 @@ def _do():
qvalue = lax.select(
rhs_mask[...],
rhs.qvalue[...],
jnp.zeros_like(rhs.qvalue, lhs.qvalue.dtype),
jnp.zeros_like(rhs.qvalue, rhs.qvalue.dtype),
)
loaded_rhs = dataclasses.replace(loaded_rhs, qvalue=qvalue)
else:
@@ -878,11 +889,11 @@ def out_transform_indices(n_i, k_i, grid_id, group_metadata, group_offset):
)

if use_qwix_quantization and lhs_quantize_dtype is not None:
lhs = qpl.quantize(lhs, qtype=lhs_quantize_dtype, scale_dtype=jnp.float32)
lhs = qpl.quantize(lhs, qtype=lhs_quantize_dtype, scale_dtype=jnp.float32, calibration_method=lhs_calibration_method)

if use_qwix_quantization and rhs_quantize_dtype is not None:
# Use per-channel scales for non-contracting axes, i.e., num_groups, m, n but not k.
rhs = qpl.quantize(rhs, qtype=rhs_quantize_dtype, scale_dtype=jnp.float32)
rhs = qpl.quantize(rhs, qtype=rhs_quantize_dtype, scale_dtype=jnp.float32, calibration_method=rhs_calibration_method)

out = call_gmm(
group_metadata,
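For reference, a condensed sketch of the operand quantization gmm() now performs before invoking the Pallas kernel when the qwix path is active. The AQT fallback and the tgmm variant are elided, and the helper name is illustrative rather than code from this PR; the qpl.quantize arguments mirror the diff above.

import jax.numpy as jnp
import qwix.pallas as qpl

def _quantize_gmm_operands(lhs, rhs, *, lhs_quantize_dtype, rhs_quantize_dtype,
                           lhs_calibration_method="absmax",
                           rhs_calibration_method="absmax",
                           transpose_rhs=False):
  """Sketch of the qwix quantization step inside gmm(); None dtype skips quantization."""
  if lhs_quantize_dtype is not None:
    # lhs is [m, k]; per-channel scales along the non-contracting (row) axis.
    lhs = qpl.quantize(lhs, qtype=lhs_quantize_dtype, channelwise_axes=[0],
                       scale_dtype=jnp.float32,
                       calibration_method=lhs_calibration_method)
  if rhs_quantize_dtype is not None:
    # rhs is [num_groups, k, n]; per-channel scales on the non-contracting axes
    # (num_groups and m or n, but not k).
    rhs = qpl.quantize(rhs, qtype=rhs_quantize_dtype,
                       channelwise_axes=[0, 1 if transpose_rhs else 2],
                       scale_dtype=jnp.float32,
                       calibration_method=rhs_calibration_method)
  return lhs, rhs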
61 changes: 35 additions & 26 deletions MaxText/kernels/megablox/ops.py
@@ -30,15 +30,9 @@

gmm = jax.custom_vjp(
backend.gmm,
nondiff_argnums=(3, 4, 7, 8, 9, 10, 11),
nondiff_argnums=(3, 4, 7, 8, 9, 10, 11, 12, 13, 14),
)

def _get_current_rule(op_name: str):
rule = qpl.get_current_rule(op_name)
if rule is not None and not isinstance(rule, qwix.QtRule):
rule = qwix.QtRule(**dataclasses.asdict(rule))
return rule

def _gmm_fwd(
lhs: jnp.ndarray,
rhs: jnp.ndarray | aqt_tensor.QTensor,
@@ -51,6 +45,9 @@ def _gmm_fwd(
interpret: bool = False,
lhs_quantize_dtype: Literal[jnp.int4, jnp.int8] | None = None,
rhs_quantize_dtype: Literal[jnp.int4, jnp.int8] | None = None,
lhs_calibration_method: str = "absmax",
rhs_calibration_method: str = "absmax",
quantization_rule: qwix.QtRule | None = None,
use_qwix_quantization: bool = False,
) -> tuple[
jnp.ndarray,
@@ -65,10 +62,11 @@
"""Forward function for GMM VJP."""
if use_qwix_quantization:
lhs_quantize_dtype, rhs_quantize_dtype = None, None
rule = _get_current_rule("dot_general")
if rule is not None:
lhs_quantize_dtype = rule.act_qtype
rhs_quantize_dtype = rule.weight_qtype
if quantization_rule is not None:
lhs_quantize_dtype = quantization_rule.act_qtype
rhs_quantize_dtype = quantization_rule.weight_qtype
lhs_calibration_method = quantization_rule.act_calibration_method
rhs_calibration_method = quantization_rule.weight_calibration_method
out = backend.gmm(
lhs,
rhs,
@@ -81,6 +79,8 @@
interpret=interpret,
lhs_quantize_dtype=lhs_quantize_dtype,
rhs_quantize_dtype=rhs_quantize_dtype,
lhs_calibration_method=lhs_calibration_method,
rhs_calibration_method=rhs_calibration_method,
use_qwix_quantization=use_qwix_quantization,
)
return out, (lhs, rhs, group_sizes, group_offset, rhs.shape[0])
@@ -93,6 +93,9 @@ def _gmm_bwd(
interpret: bool,
lhs_quantize_dtype: Literal[jnp.int4, jnp.int8] | None,
rhs_quantize_dtype: Literal[jnp.int4, jnp.int8] | None,
lhs_calibration_method: str,
rhs_calibration_method: str,
quantization_rule: qwix.QtRule | None,
use_qwix_quantization: bool,
residual: tuple[
jnp.ndarray,
@@ -106,14 +109,15 @@
"""Backward function for throughput GMM VJP."""
if use_qwix_quantization:
lhs_quantize_dtype, rhs_quantize_dtype = None, None
rule = _get_current_rule("dot_general")
if rule is not None:
if rule.additional_qt_config is not None:
lhs_quantize_dtype = rule.additional_qt_config["dlhs_lhs_qtype"]
rhs_quantize_dtype = rule.additional_qt_config["dlhs_rhs_qtype"]
if quantization_rule is not None:
if quantization_rule.additional_qt_config is not None:
lhs_quantize_dtype = quantization_rule.additional_qt_config["dlhs_lhs_qtype"]
rhs_quantize_dtype = quantization_rule.additional_qt_config["dlhs_rhs_qtype"]
else:
lhs_quantize_dtype = rule.act_qtype
rhs_quantize_dtype = rule.bwd_qtype
lhs_quantize_dtype = quantization_rule.act_qtype
rhs_quantize_dtype = quantization_rule.bwd_qtype
lhs_calibration_method = quantization_rule.additional_qt_config["dlhs_lhs_calibration_method"]
rhs_calibration_method = quantization_rule.additional_qt_config["dlhs_rhs_calibration_method"]
del preferred_element_type
lhs, rhs, group_sizes, group_offset, num_actual_groups = residual
grad_lhs = backend.gmm(
@@ -127,18 +131,21 @@
interpret=interpret,
lhs_quantize_dtype=lhs_quantize_dtype,
rhs_quantize_dtype=rhs_quantize_dtype,
use_qwix_quantization=use_qwix_quantization,
lhs_calibration_method=lhs_calibration_method,
rhs_calibration_method=rhs_calibration_method,
use_qwix_quantization=use_qwix_quantization,
)
if use_qwix_quantization:
lhs_quantize_dtype, rhs_quantize_dtype = None, None
rule = _get_current_rule("dot_general")
if rule is not None:
if rule.additional_qt_config is not None:
lhs_quantize_dtype = rule.additional_qt_config["drhs_lhs_qtype"]
rhs_quantize_dtype = rule.additional_qt_config["drhs_rhs_qtype"]
if quantization_rule is not None:
if quantization_rule.additional_qt_config is not None:
lhs_quantize_dtype = quantization_rule.additional_qt_config["drhs_lhs_qtype"]
rhs_quantize_dtype = quantization_rule.additional_qt_config["drhs_rhs_qtype"]
else:
lhs_quantize_dtype = rule.bwd_qtype
rhs_quantize_dtype = rule.act_qtype
lhs_quantize_dtype = quantization_rule.bwd_qtype
rhs_quantize_dtype = quantization_rule.act_qtype
lhs_calibration_method = quantization_rule.additional_qt_config["drhs_lhs_calibration_method"]
rhs_calibration_method = quantization_rule.additional_qt_config["drhs_rhs_calibration_method"]
grad_rhs = backend.tgmm(
lhs.swapaxes(0, 1),
grad,
@@ -150,6 +157,8 @@ def _gmm_bwd(
interpret=interpret,
lhs_quantize_dtype=lhs_quantize_dtype,
rhs_quantize_dtype=rhs_quantize_dtype,
lhs_calibration_method=lhs_calibration_method,
rhs_calibration_method=rhs_calibration_method,
use_qwix_quantization=use_qwix_quantization,
)

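With the rule now carried as a non-differentiable argument (nondiff_argnums extended to cover the new calibration-method and quantization_rule parameters), the forward and backward functions read qtypes and calibration methods from the quantization_rule resolved at the call site instead of re-querying get_current_rule inside _gmm_fwd/_gmm_bwd. Below is a self-contained sketch of that plumbing pattern; Rule and scaled_matmul are hypothetical stand-ins, not MaxText or qwix code.

import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class Rule:  # hashable stand-in for qwix.QtRule
  fwd_scale: float = 1.0
  bwd_scale: float = 1.0


def scaled_matmul(x, w, rule):
  # Forward path uses the rule captured at the call site.
  scale = rule.fwd_scale if rule is not None else 1.0
  return (x * scale) @ w


scaled_matmul = jax.custom_vjp(scaled_matmul, nondiff_argnums=(2,))


def _fwd(x, w, rule):
  return scaled_matmul(x, w, rule), (x, w)


def _bwd(rule, residual, g):
  # Non-diff arguments come first in the backward signature, so the exact same
  # rule object is visible here without re-reading any ambient context.
  x, w = residual
  scale = rule.bwd_scale if rule is not None else 1.0
  return (g @ w.T) * scale, (x.T @ g) * scale


scaled_matmul.defvjp(_fwd, _bwd)

# Example usage: gradients are computed with the backward-pass scale from the rule.
x, w = jnp.ones((4, 8)), jnp.ones((8, 2))
rule = Rule(fwd_scale=1.0, bwd_scale=0.5)
dx, dw = jax.grad(lambda x, w: scaled_matmul(x, w, rule).sum(), argnums=(0, 1))(x, w)

In the PR itself the rule is a qwix.QtRule resolved once in moe.py via qpl.get_current_rule("dot_general") and forwarded unchanged to both backend.gmm and backend.tgmm.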
2 changes: 2 additions & 0 deletions MaxText/layers/moe.py
@@ -736,6 +736,7 @@ def gmm(inputs, kernel, group_sizes, expert_assignments):
quant_dg = self.quant.quant_dg
lhs_quantize_dtype = quant_dg.fwd.dg_quantizer.lhs.numerics.get_dtype()
rhs_quantize_dtype = quant_dg.fwd.dg_quantizer.rhs.numerics.get_dtype()
quantization_rule = None
if self.config.use_qwix_quantization:
quantization_rule = qpl.get_current_rule("dot_general")
if quantization_rule is not None:
@@ -756,6 +757,7 @@ def gmm(inputs, kernel, group_sizes, expert_assignments):
tiling=tiling,
lhs_quantize_dtype=lhs_quantize_dtype,
rhs_quantize_dtype=rhs_quantize_dtype,
quantization_rule=quantization_rule,
use_qwix_quantization=self.config.use_qwix_quantization,
)
else:
14 changes: 9 additions & 5 deletions MaxText/layers/quantizations.py
@@ -35,6 +35,7 @@
from flax.linen import fp8_ops
from flax.linen import initializers as flax_initializers
import flax.linen as nn
from flax.core import FrozenDict

from MaxText.common_types import DType, Config
from MaxText.inference.kvcache import KVQuant
@@ -667,16 +668,19 @@ def get_quantization_rule(config: Config):
bwd_qtype=jnp.float8_e5m2,
bwd_use_original_residuals=True,
disable_channelwise_axes=True, # per_tensor calibration
weight_calibration_method=config.quantization_calibration_method,
act_calibration_method=config.quantization_calibration_method,
bwd_calibration_method=config.quantization_calibration_method,
weight_calibration_method=config.fwd_weight_calibration_method,
act_calibration_method=config.fwd_act_calibration_method,
op_names=("dot_general",),
additional_qt_config={
additional_qt_config=FrozenDict({
"dlhs_lhs_qtype": jnp.float8_e5m2,
"dlhs_rhs_qtype": jnp.float8_e4m3fn,
"drhs_lhs_qtype": jnp.float8_e4m3fn,
"drhs_rhs_qtype": jnp.float8_e5m2,
},
"dlhs_lhs_calibration_method": config.dlhs_lhs_calibration_method,
"dlhs_rhs_calibration_method": config.dlhs_rhs_calibration_method,
"drhs_lhs_calibration_method": config.drhs_lhs_calibration_method,
"drhs_rhs_calibration_method": config.drhs_rhs_calibration_method,
}),
)
case "fp8_gpu":
return qwix.QtRule(
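Wrapping additional_qt_config in flax's FrozenDict keeps the whole rule hashable, which matters because "quantization_rule" is added to the kernel's static-argument list in gmm.py (and static jit arguments must be hashable); a plain dict nested inside the rule would break that. A minimal check, assuming only that flax.core.FrozenDict implements hashing and the builtin dict does not:

from flax.core import FrozenDict

frozen = FrozenDict({"dlhs_lhs_calibration_method": "absmax"})
print(hash(frozen))  # FrozenDict is hashable, so a QtRule holding it can be hashed too.

try:
  hash({"dlhs_lhs_calibration_method": "absmax"})
except TypeError as err:
  print(err)  # unhashable type: 'dict'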