
Commit 118992b

correctly plumb quantization_rule to kernel

1 parent: 5d600de

File tree: 5 files changed, +68 -37 lines

MaxText/configs/base.yml
MaxText/kernels/megablox/gmm.py
MaxText/kernels/megablox/ops.py
MaxText/layers/moe.py
MaxText/layers/quantizations.py

MaxText/configs/base.yml
Lines changed: 6 additions & 1 deletion

@@ -120,7 +120,12 @@ save_quantized_params_path: ""
 model_call_mode: ""
 use_qwix_quantization: False # Whether to use qwix for quantization. If set to True, the model will be quantized using qwix.
 # Quantization calibration method used for weights and activations. Supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#L70-L80
-quantization_calibration_method: "absmax"
+fwd_weight_calibration_method: "absmax"
+fwd_act_calibration_method: "absmax"
+dlhs_lhs_calibration_method: "absmax"
+dlhs_rhs_calibration_method: "absmax"
+drhs_lhs_calibration_method: "absmax"
+drhs_rhs_calibration_method: "absmax"
 # Shard the range finding operation for quantization. By default this is set to number of slices.
 quantization_local_shard_count: -1
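This hunk splits the single `quantization_calibration_method` knob into six per-pass knobs: forward weights and activations (`fwd_*`), and the two operands of each backward matmul (`dlhs_*` for the activation gradient, `drhs_*` for the weight gradient), all defaulting to "absmax". For intuition only, a minimal sketch of what an absmax calibration computes (the general technique, not qwix's implementation):

```python
import jax.numpy as jnp

def absmax_scale(x: jnp.ndarray, qmax: float) -> jnp.ndarray:
  """Per-tensor absmax calibration: pick the scale so the largest |x| maps to qmax."""
  return jnp.max(jnp.abs(x)) / qmax

x = jnp.array([-3.0, 0.5, 2.0])
scale = absmax_scale(x, qmax=127.0)          # symmetric int8 range
x_q = jnp.round(x / scale).astype(jnp.int8)  # quantize
x_dq = x_q.astype(jnp.float32) * scale       # dequantize for comparison
```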

MaxText/kernels/megablox/gmm.py
Lines changed: 16 additions & 5 deletions

@@ -33,6 +33,7 @@
 
 from qwix.pallas import QArray
 import qwix.pallas as qpl
+import qwix
 
 from MaxText.kernels.megablox import common
 
@@ -305,6 +306,9 @@ def _zero_uninitialized_memory(
         "interpret",
         "lhs_quantize_dtype",
         "rhs_quantize_dtype",
+        "lhs_calibration_method",
+        "rhs_calibration_method",
+        "quantization_rule",
         "use_qwix_quantization",
     ],
 )
@@ -320,6 +324,9 @@ def gmm(
     interpret: bool = False,
     lhs_quantize_dtype: Literal[jnp.int4, jnp.int8] | None = None,
     rhs_quantize_dtype: Literal[jnp.int4, jnp.int8] | None = None,
+    lhs_calibration_method: str = "absmax",
+    rhs_calibration_method: str = "absmax",
+    quantization_rule: qwix.QtRule | None = None,
     use_qwix_quantization: bool = False,
 ) -> jnp.ndarray:
   """Compute lhs[sizes[i-1]:sizes[i], :] @ rhs for each group 'i'.
@@ -604,7 +611,7 @@ def out_transform_indices(n_i, grid_id, k_i, group_metadata, group_offset):
 
   if lhs_quantize_dtype is not None:
     if use_qwix_quantization:
-      lhs = qpl.quantize(lhs, qtype=lhs_quantize_dtype, channelwise_axes=[0], scale_dtype=jnp.float32)
+      lhs = qpl.quantize(lhs, qtype=lhs_quantize_dtype, channelwise_axes=[0], scale_dtype=jnp.float32, calibration_method=lhs_calibration_method)
     else:
       lhs_quantize_bits = 4 if lhs_quantize_dtype == jnp.int4 else 8
       lhs = aqt_pl.quant(lhs, lhs_quantize_bits, lhs_contracting_axis)
@@ -613,7 +620,7 @@ def out_transform_indices(n_i, grid_id, k_i, group_metadata, group_offset):
   # Use per-channel scales for non-contracting axes, i.e., num_groups, m, n but not k.
   if use_qwix_quantization:
     rhs = qpl.quantize(
-        rhs, qtype=rhs_quantize_dtype, channelwise_axes=[0, 1 if transpose_rhs else 2], scale_dtype=jnp.float32
+        rhs, qtype=rhs_quantize_dtype, channelwise_axes=[0, 1 if transpose_rhs else 2], scale_dtype=jnp.float32, calibration_method=rhs_calibration_method
     )
   else:
     rhs_quantize_bits = 4 if rhs_quantize_dtype == jnp.int4 else 8
@@ -645,6 +652,8 @@ def out_transform_indices(n_i, grid_id, k_i, group_metadata, group_offset):
         "interpret",
         "lhs_quantize_dtype",
         "rhs_quantize_dtype",
+        "lhs_calibration_method",
+        "rhs_calibration_method",
         "use_qwix_quantization",
     ],
 )
@@ -660,6 +669,8 @@ def tgmm(
     interpret: bool = False,
     lhs_quantize_dtype=None,
     rhs_quantize_dtype=None,
+    lhs_calibration_method: str = "absmax",
+    rhs_calibration_method: str = "absmax",
     use_qwix_quantization: bool = False,
 ) -> jnp.ndarray:
   """Compute lhs[:, sizes[i-1]:sizes[i]] @ rhs[sizes[i-1]:sizes[i], :].
@@ -780,7 +791,7 @@ def _do():
         qvalue = lax.select(
             rhs_mask[...],
             rhs.qvalue[...],
-            jnp.zeros_like(rhs.qvalue, lhs.qvalue.dtype),
+            jnp.zeros_like(rhs.qvalue, rhs.qvalue.dtype),
         )
         loaded_rhs = dataclasses.replace(loaded_rhs, qvalue=qvalue)
       else:
@@ -878,11 +889,11 @@ def out_transform_indices(n_i, k_i, grid_id, group_metadata, group_offset):
   )
 
   if use_qwix_quantization and lhs_quantize_dtype is not None:
-    lhs = qpl.quantize(lhs, qtype=lhs_quantize_dtype, scale_dtype=jnp.float32)
+    lhs = qpl.quantize(lhs, qtype=lhs_quantize_dtype, scale_dtype=jnp.float32, calibration_method=lhs_calibration_method)
 
   if use_qwix_quantization and rhs_quantize_dtype is not None:
     # Use per-channel scales for non-contracting axes, i.e., num_groups, m, n but not k.
-    rhs = qpl.quantize(rhs, qtype=rhs_quantize_dtype, scale_dtype=jnp.float32)
+    rhs = qpl.quantize(rhs, qtype=rhs_quantize_dtype, scale_dtype=jnp.float32, calibration_method=rhs_calibration_method)
 
   out = call_gmm(
       group_metadata,
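In summary: `gmm` and `tgmm` gain per-operand `*_calibration_method` parameters (plus `quantization_rule` on `gmm`), added to the kernels' static-argument lists and forwarded into `qpl.quantize`. The hunk at old line 783 is an independent dtype fix: the zero fill for masked `rhs` blocks now uses `rhs.qvalue`'s dtype rather than `lhs.qvalue`'s, so both `lax.select` branches agree. A minimal sketch of the forward-path quantize call with toy shapes (argument values mirror the diff; assumes `qwix` is installed and, as in `gmm` itself, that `qpl.quantize` is applied to a plain array before the Pallas call):

```python
import jax.numpy as jnp
import qwix.pallas as qpl

lhs = jnp.ones((16, 64), jnp.bfloat16)  # toy (m, k) operand; shapes illustrative
# Mirrors the forward-pass call in gmm: int8 values, per-channel scales on the
# non-contracting axis 0, fp32 scales, and the calibration method plumbed in.
lhs_q = qpl.quantize(
    lhs,
    qtype=jnp.int8,
    channelwise_axes=[0],
    scale_dtype=jnp.float32,
    calibration_method="absmax",
)
```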

MaxText/kernels/megablox/ops.py
Lines changed: 35 additions & 26 deletions

@@ -30,15 +30,9 @@
 
 gmm = jax.custom_vjp(
     backend.gmm,
-    nondiff_argnums=(3, 4, 7, 8, 9, 10, 11),
+    nondiff_argnums=(3, 4, 7, 8, 9, 10, 11, 12, 13, 14),
 )
 
-def _get_current_rule(op_name: str):
-  rule = qpl.get_current_rule(op_name)
-  if rule is not None and not isinstance(rule, qwix.QtRule):
-    rule = qwix.QtRule(**dataclasses.asdict(rule))
-  return rule
-
 def _gmm_fwd(
     lhs: jnp.ndarray,
     rhs: jnp.ndarray | aqt_tensor.QTensor,
@@ -51,6 +45,9 @@ def _gmm_fwd(
     interpret: bool = False,
     lhs_quantize_dtype: Literal[jnp.int4, jnp.int8] | None = None,
     rhs_quantize_dtype: Literal[jnp.int4, jnp.int8] | None = None,
+    lhs_calibration_method: str = "absmax",
+    rhs_calibration_method: str = "absmax",
+    quantization_rule: qwix.QtRule | None = None,
     use_qwix_quantization: bool = False,
 ) -> tuple[
     jnp.ndarray,
@@ -65,10 +62,11 @@ def _gmm_fwd(
   """Forward function for GMM VJP."""
   if use_qwix_quantization:
     lhs_quantize_dtype, rhs_quantize_dtype = None, None
-    rule = _get_current_rule("dot_general")
-    if rule is not None:
-      lhs_quantize_dtype = rule.act_qtype
-      rhs_quantize_dtype = rule.weight_qtype
+    if quantization_rule is not None:
+      lhs_quantize_dtype = quantization_rule.act_qtype
+      rhs_quantize_dtype = quantization_rule.weight_qtype
+      lhs_calibration_method = quantization_rule.act_calibration_method
+      rhs_calibration_method = quantization_rule.weight_calibration_method
   out = backend.gmm(
       lhs,
       rhs,
@@ -81,6 +79,8 @@ def _gmm_fwd(
       interpret=interpret,
       lhs_quantize_dtype=lhs_quantize_dtype,
       rhs_quantize_dtype=rhs_quantize_dtype,
+      lhs_calibration_method=lhs_calibration_method,
+      rhs_calibration_method=rhs_calibration_method,
       use_qwix_quantization=use_qwix_quantization,
   )
   return out, (lhs, rhs, group_sizes, group_offset, rhs.shape[0])
@@ -93,6 +93,9 @@ def _gmm_bwd(
     interpret: bool,
     lhs_quantize_dtype: Literal[jnp.int4, jnp.int8] | None,
     rhs_quantize_dtype: Literal[jnp.int4, jnp.int8] | None,
+    lhs_calibration_method: str,
+    rhs_calibration_method: str,
+    quantization_rule: qwix.QtRule | None,
     use_qwix_quantization: bool,
     residual: tuple[
         jnp.ndarray,
@@ -106,14 +109,15 @@ def _gmm_bwd(
   """Backward function for throughput GMM VJP."""
   if use_qwix_quantization:
     lhs_quantize_dtype, rhs_quantize_dtype = None, None
-    rule = _get_current_rule("dot_general")
-    if rule is not None:
-      if rule.additional_qt_config is not None:
-        lhs_quantize_dtype = rule.additional_qt_config["dlhs_lhs_qtype"]
-        rhs_quantize_dtype = rule.additional_qt_config["dlhs_rhs_qtype"]
+    if quantization_rule is not None:
+      if quantization_rule.additional_qt_config is not None:
+        lhs_quantize_dtype = quantization_rule.additional_qt_config["dlhs_lhs_qtype"]
+        rhs_quantize_dtype = quantization_rule.additional_qt_config["dlhs_rhs_qtype"]
       else:
-        lhs_quantize_dtype = rule.act_qtype
-        rhs_quantize_dtype = rule.bwd_qtype
+        lhs_quantize_dtype = quantization_rule.act_qtype
+        rhs_quantize_dtype = quantization_rule.bwd_qtype
+      lhs_calibration_method = quantization_rule.additional_qt_config["dlhs_lhs_calibration_method"]
+      rhs_calibration_method = quantization_rule.additional_qt_config["dlhs_rhs_calibration_method"]
   del preferred_element_type
   lhs, rhs, group_sizes, group_offset, num_actual_groups = residual
   grad_lhs = backend.gmm(
@@ -127,18 +131,21 @@ def _gmm_bwd(
       interpret=interpret,
       lhs_quantize_dtype=lhs_quantize_dtype,
       rhs_quantize_dtype=rhs_quantize_dtype,
-      use_qwix_quantization=use_qwix_quantization,
+      lhs_calibration_method=lhs_calibration_method,
+      rhs_calibration_method=rhs_calibration_method,
+      use_qwix_quantization=use_qwix_quantization,
   )
   if use_qwix_quantization:
     lhs_quantize_dtype, rhs_quantize_dtype = None, None
-    rule = _get_current_rule("dot_general")
-    if rule is not None:
-      if rule.additional_qt_config is not None:
-        lhs_quantize_dtype = rule.additional_qt_config["drhs_lhs_qtype"]
-        rhs_quantize_dtype = rule.additional_qt_config["drhs_rhs_qtype"]
+    if quantization_rule is not None:
+      if quantization_rule.additional_qt_config is not None:
+        lhs_quantize_dtype = quantization_rule.additional_qt_config["drhs_lhs_qtype"]
+        rhs_quantize_dtype = quantization_rule.additional_qt_config["drhs_rhs_qtype"]
      else:
-        lhs_quantize_dtype = rule.bwd_qtype
-        rhs_quantize_dtype = rule.act_qtype
+        lhs_quantize_dtype = quantization_rule.bwd_qtype
+        rhs_quantize_dtype = quantization_rule.act_qtype
+      lhs_calibration_method = quantization_rule.additional_qt_config["drhs_lhs_calibration_method"]
+      rhs_calibration_method = quantization_rule.additional_qt_config["drhs_rhs_calibration_method"]
   grad_rhs = backend.tgmm(
       lhs.swapaxes(0, 1),
       grad,
@@ -150,6 +157,8 @@ def _gmm_bwd(
       interpret=interpret,
       lhs_quantize_dtype=lhs_quantize_dtype,
       rhs_quantize_dtype=rhs_quantize_dtype,
+      lhs_calibration_method=lhs_calibration_method,
+      rhs_calibration_method=rhs_calibration_method,
       use_qwix_quantization=use_qwix_quantization,
   )
 
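In summary: `backend.gmm` is differentiated via `jax.custom_vjp`, so the three new parameters (positions 12 through 14) join `nondiff_argnums`, and the `_get_current_rule` helper, which re-resolved the rule from the ambient qwix context on every call, is deleted: the rule now arrives explicitly as `quantization_rule`, and the fwd/bwd functions read qtypes and calibration methods straight off it. A minimal sketch of the `custom_vjp` pattern, with a hypothetical `scaled_matmul` standing in for `gmm`:

```python
import jax
import jax.numpy as jnp

def scaled_matmul(x, w, calibration_method="absmax"):
  del calibration_method  # configuration only; nothing to differentiate
  return x @ w

# Position 2 (the string) is declared non-differentiable, like the
# calibration methods and quantization_rule in the real gmm wrapper.
scaled_matmul = jax.custom_vjp(scaled_matmul, nondiff_argnums=(2,))

def _fwd(x, w, calibration_method):
  del calibration_method
  return x @ w, (x, w)  # primal output plus residuals

def _bwd(calibration_method, residual, g):
  del calibration_method  # nondiff args arrive first, before residuals
  x, w = residual
  return g @ w.T, x.T @ g  # cotangents for the differentiable args only

scaled_matmul.defvjp(_fwd, _bwd)

# Usage: gradients flow to x and w; the string rides along as a static value.
x, w = jnp.ones((2, 3)), jnp.ones((3, 4))
gx, gw = jax.grad(lambda x, w: scaled_matmul(x, w, "absmax").sum(), argnums=(0, 1))(x, w)
```

Note that non-differentiable arguments are handed to the bwd rule first, which is how `_gmm_bwd` receives `quantization_rule` ahead of the residuals; they must also be hashable, which is what motivates the `FrozenDict` change in `quantizations.py` below.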

MaxText/layers/moe.py
Lines changed: 2 additions & 0 deletions

@@ -736,6 +736,7 @@ def gmm(inputs, kernel, group_sizes, expert_assignments):
       quant_dg = self.quant.quant_dg
       lhs_quantize_dtype = quant_dg.fwd.dg_quantizer.lhs.numerics.get_dtype()
       rhs_quantize_dtype = quant_dg.fwd.dg_quantizer.rhs.numerics.get_dtype()
+    quantization_rule = None
     if self.config.use_qwix_quantization:
       quantization_rule = qpl.get_current_rule("dot_general")
       if quantization_rule is not None:
@@ -756,6 +757,7 @@ def gmm(inputs, kernel, group_sizes, expert_assignments):
           tiling=tiling,
           lhs_quantize_dtype=lhs_quantize_dtype,
           rhs_quantize_dtype=rhs_quantize_dtype,
+          quantization_rule=quantization_rule,
           use_qwix_quantization=self.config.use_qwix_quantization,
       )
     else:
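The call site is where the plumbing begins: the active rule is resolved once with `qpl.get_current_rule("dot_general")`, while the qwix quantization context is in scope, and handed to the kernel explicitly instead of being re-resolved inside the VJP. A condensed sketch of the pattern (surrounding MoE code omitted; `config` and `gmm_fn` are stand-ins for the names in `moe.py`):

```python
import qwix.pallas as qpl

def call_grouped_matmul(inputs, kernel, group_sizes, config, gmm_fn):
  """Sketch: resolve the rule once, then pass it down (gmm_fn ~ megablox gmm)."""
  quantization_rule = None
  if config.use_qwix_quantization:
    # Only valid where the qwix context is active, hence resolved here
    # rather than inside the kernel's fwd/bwd functions.
    quantization_rule = qpl.get_current_rule("dot_general")
  return gmm_fn(
      inputs,
      kernel,
      group_sizes,
      quantization_rule=quantization_rule,
      use_qwix_quantization=config.use_qwix_quantization,
  )
```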

MaxText/layers/quantizations.py
Lines changed: 9 additions & 5 deletions

@@ -35,6 +35,7 @@
 from flax.linen import fp8_ops
 from flax.linen import initializers as flax_initializers
 import flax.linen as nn
+from flax.core import FrozenDict
 
 from MaxText.common_types import DType, Config
 from MaxText.inference.kvcache import KVQuant
@@ -667,16 +668,19 @@ def get_quantization_rule(config: Config):
           bwd_qtype=jnp.float8_e5m2,
           bwd_use_original_residuals=True,
           disable_channelwise_axes=True,  # per_tensor calibration
-          weight_calibration_method=config.quantization_calibration_method,
-          act_calibration_method=config.quantization_calibration_method,
-          bwd_calibration_method=config.quantization_calibration_method,
+          weight_calibration_method=config.fwd_weight_calibration_method,
+          act_calibration_method=config.fwd_act_calibration_method,
           op_names=("dot_general",),
-          additional_qt_config={
+          additional_qt_config=FrozenDict({
              "dlhs_lhs_qtype": jnp.float8_e5m2,
              "dlhs_rhs_qtype": jnp.float8_e4m3fn,
              "drhs_lhs_qtype": jnp.float8_e4m3fn,
              "drhs_rhs_qtype": jnp.float8_e5m2,
-          },
+              "dlhs_lhs_calibration_method": config.dlhs_lhs_calibration_method,
+              "dlhs_rhs_calibration_method": config.dlhs_rhs_calibration_method,
+              "drhs_lhs_calibration_method": config.drhs_lhs_calibration_method,
+              "drhs_rhs_calibration_method": config.drhs_rhs_calibration_method,
+          }),
       )
     case "fp8_gpu":
       return qwix.QtRule(
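Two things change in the rule construction. First, the per-pass calibration methods from `base.yml` are wired in: the forward methods go to the `weight_calibration_method` and `act_calibration_method` fields, while the blanket `bwd_calibration_method` line is dropped in favor of four per-operand entries in `additional_qt_config`, where `_gmm_bwd` looks them up. Second, `additional_qt_config` becomes a `FrozenDict`: the rule now travels through `nondiff_argnums`, where JAX requires static arguments to be hashable, and a plain `dict` is not. A hedged sketch of the resulting fp8 rule (the forward qtypes are assumptions, since they fall outside the visible hunk; "absmax" stands in for the config values):

```python
import jax.numpy as jnp
import qwix
from flax.core import FrozenDict

rule = qwix.QtRule(
    op_names=("dot_general",),
    weight_qtype=jnp.float8_e4m3fn,  # assumed; not shown in the hunk
    act_qtype=jnp.float8_e4m3fn,     # assumed; not shown in the hunk
    bwd_qtype=jnp.float8_e5m2,
    weight_calibration_method="absmax",  # config.fwd_weight_calibration_method
    act_calibration_method="absmax",     # config.fwd_act_calibration_method
    additional_qt_config=FrozenDict({    # hashable, so it can ride nondiff_argnums
        "dlhs_lhs_qtype": jnp.float8_e5m2,
        "dlhs_rhs_qtype": jnp.float8_e4m3fn,
        "drhs_lhs_qtype": jnp.float8_e4m3fn,
        "drhs_rhs_qtype": jnp.float8_e5m2,
        "dlhs_lhs_calibration_method": "absmax",  # config.dlhs_lhs_calibration_method
        "dlhs_rhs_calibration_method": "absmax",  # config.dlhs_rhs_calibration_method
        "drhs_lhs_calibration_method": "absmax",  # config.drhs_lhs_calibration_method
        "drhs_rhs_calibration_method": "absmax",  # config.drhs_rhs_calibration_method
    }),
)
```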
