
Commit c88e5a9

correctly plumb quantization_rule to kernel
1 parent dd7b6f7 commit c88e5a9

5 files changed: +39 -32 lines changed

MaxText/configs/base.yml

Lines changed: 6 additions & 1 deletion
@@ -120,7 +120,12 @@ save_quantized_params_path: ""
 model_call_mode: ""
 use_qwix_quantization: False # Whether to use qwix for quantization. If set to True, the model will be quantized using qwix.
 # Quantization calibration method used for weights and activations. Supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#L70-L80
-quantization_calibration_method: "absmax"
+fwd_weight_calibration_method: "absmax"
+fwd_act_calibration_method: "absmax"
+dlhs_lhs_calibration_method: "absmax"
+dlhs_rhs_calibration_method: "absmax"
+drhs_lhs_calibration_method: "absmax"
+drhs_rhs_calibration_method: "absmax"
 # Shard the range finding operation for quantization. By default this is set to number of slices.
 quantization_local_shard_count: -1
 
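The single quantization_calibration_method knob is split into six per-pass knobs: the two fwd_* keys calibrate the forward weights and activations, while the dlhs_*/drhs_* keys cover the left- and right-hand operands of the two backward matmuls (gradient w.r.t. lhs and gradient w.r.t. rhs). A minimal sketch of how these keys feed the qwix rule, mirroring the quantizations.py hunk later in this commit; build_rule is a hypothetical standalone helper and only calibration-related fields are shown:

# Hypothetical helper, not part of this commit: shows how the six new base.yml
# keys map onto the qwix rule fields used in MaxText/layers/quantizations.py below.
import qwix
from flax.core import FrozenDict

def build_rule(config):
  return qwix.QtRule(
      op_names=("dot_general",),
      # forward matmul
      weight_calibration_method=config.fwd_weight_calibration_method,
      act_calibration_method=config.fwd_act_calibration_method,
      # backward matmuls, keyed per operand inside additional_qt_config
      additional_qt_config=FrozenDict({
          "dlhs_lhs_calibration_method": config.dlhs_lhs_calibration_method,
          "dlhs_rhs_calibration_method": config.dlhs_rhs_calibration_method,
          "drhs_lhs_calibration_method": config.drhs_lhs_calibration_method,
          "drhs_rhs_calibration_method": config.drhs_rhs_calibration_method,
      }),
  )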

MaxText/kernels/megablox/gmm.py

Lines changed: 4 additions & 1 deletion
@@ -33,6 +33,7 @@
 
 from qwix.pallas import QArray
 import qwix.pallas as qpl
+import qwix
 
 from MaxText.kernels.megablox import common
 
@@ -305,6 +306,7 @@ def _zero_uninitialized_memory(
         "interpret",
         "lhs_quantize_dtype",
         "rhs_quantize_dtype",
+        "quantization_rule",
         "use_qwix_quantization",
     ],
 )
@@ -320,6 +322,7 @@ def gmm(
     interpret: bool = False,
     lhs_quantize_dtype: Literal[jnp.int4, jnp.int8] | None = None,
     rhs_quantize_dtype: Literal[jnp.int4, jnp.int8] | None = None,
+    quantization_rule: qwix.QtRule | None = None,
     use_qwix_quantization: bool = False,
 ) -> jnp.ndarray:
   """Compute lhs[sizes[i-1]:sizes[i], :] @ rhs for each group 'i'.
@@ -780,7 +783,7 @@ def _do():
         qvalue = lax.select(
             rhs_mask[...],
             rhs.qvalue[...],
-            jnp.zeros_like(rhs.qvalue, lhs.qvalue.dtype),
+            jnp.zeros_like(rhs.qvalue, rhs.qvalue.dtype),
         )
         loaded_rhs = dataclasses.replace(loaded_rhs, qvalue=qvalue)
       else:
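Two things change in the kernel wrapper. First, quantization_rule becomes an explicit parameter of gmm and is added to the wrapper's static argument list (apparently the jit static_argnames list that already holds interpret and the quantize dtypes), so the kernel is specialized per rule instead of reading ambient qwix state. Second, the zero fill for masked-out rhs rows now uses rhs.qvalue.dtype rather than lhs.qvalue.dtype; both branches of lax.select must share the dtype of the tensor being masked, which matters once lhs and rhs are quantized to different dtypes. A tiny standalone illustration of that constraint (not MaxText code):

# Standalone illustration: lax.select rejects operands with mismatched dtypes,
# so the zero fill must match the dtype of rhs.qvalue, not of lhs.qvalue.
import jax.numpy as jnp
from jax import lax

rhs_q = jnp.zeros((4, 4), dtype=jnp.float8_e4m3fn)   # stand-in for quantized rhs values
mask = jnp.ones((4, 4), dtype=bool)

ok = lax.select(mask, rhs_q, jnp.zeros_like(rhs_q, rhs_q.dtype))        # dtypes agree
# lax.select(mask, rhs_q, jnp.zeros((4, 4), dtype=jnp.float8_e5m2))     # raises: mismatched dtypes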

MaxText/kernels/megablox/ops.py

Lines changed: 18 additions & 25 deletions
@@ -30,15 +30,9 @@
 
 gmm = jax.custom_vjp(
     backend.gmm,
-    nondiff_argnums=(3, 4, 7, 8, 9, 10, 11),
+    nondiff_argnums=(3, 4, 7, 8, 9, 10, 11, 12),
 )
 
-def _get_current_rule(op_name: str):
-  rule = qpl.get_current_rule(op_name)
-  if rule is not None and not isinstance(rule, qwix.QtRule):
-    rule = qwix.QtRule(**dataclasses.asdict(rule))
-  return rule
-
 def _gmm_fwd(
     lhs: jnp.ndarray,
     rhs: jnp.ndarray | aqt_tensor.QTensor,
@@ -51,6 +45,7 @@ def _gmm_fwd(
     interpret: bool = False,
     lhs_quantize_dtype: Literal[jnp.int4, jnp.int8] | None = None,
     rhs_quantize_dtype: Literal[jnp.int4, jnp.int8] | None = None,
+    quantization_rule: qwix.QtRule | None = None,
     use_qwix_quantization: bool = False,
 ) -> tuple[
     jnp.ndarray,
@@ -65,10 +60,9 @@ def _gmm_fwd(
   """Forward function for GMM VJP."""
   if use_qwix_quantization:
     lhs_quantize_dtype, rhs_quantize_dtype = None, None
-    rule = _get_current_rule("dot_general")
-    if rule is not None:
-      lhs_quantize_dtype = rule.act_qtype
-      rhs_quantize_dtype = rule.weight_qtype
+    if quantization_rule is not None:
+      lhs_quantize_dtype = quantization_rule.act_qtype
+      rhs_quantize_dtype = quantization_rule.weight_qtype
   out = backend.gmm(
       lhs,
       rhs,
@@ -93,6 +87,7 @@ def _gmm_bwd(
     interpret: bool,
     lhs_quantize_dtype: Literal[jnp.int4, jnp.int8] | None,
     rhs_quantize_dtype: Literal[jnp.int4, jnp.int8] | None,
+    quantization_rule: qwix.QtRule | None,
     use_qwix_quantization: bool,
     residual: tuple[
         jnp.ndarray,
@@ -106,14 +101,13 @@
   """Backward function for throughput GMM VJP."""
   if use_qwix_quantization:
     lhs_quantize_dtype, rhs_quantize_dtype = None, None
-    rule = _get_current_rule("dot_general")
-    if rule is not None:
-      if rule.additional_qt_config is not None:
-        lhs_quantize_dtype = rule.additional_qt_config["dlhs_lhs_qtype"]
-        rhs_quantize_dtype = rule.additional_qt_config["dlhs_rhs_qtype"]
+    if quantization_rule is not None:
+      if quantization_rule.additional_qt_config is not None:
+        lhs_quantize_dtype = quantization_rule.additional_qt_config["dlhs_lhs_qtype"]
+        rhs_quantize_dtype = quantization_rule.additional_qt_config["dlhs_rhs_qtype"]
       else:
-        lhs_quantize_dtype = rule.act_qtype
-        rhs_quantize_dtype = rule.bwd_qtype
+        lhs_quantize_dtype = quantization_rule.act_qtype
+        rhs_quantize_dtype = quantization_rule.bwd_qtype
   del preferred_element_type
   lhs, rhs, group_sizes, group_offset, num_actual_groups = residual
   grad_lhs = backend.gmm(
@@ -131,14 +125,13 @@
   )
   if use_qwix_quantization:
     lhs_quantize_dtype, rhs_quantize_dtype = None, None
-    rule = _get_current_rule("dot_general")
-    if rule is not None:
-      if rule.additional_qt_config is not None:
-        lhs_quantize_dtype = rule.additional_qt_config["drhs_lhs_qtype"]
-        rhs_quantize_dtype = rule.additional_qt_config["drhs_rhs_qtype"]
+    if quantization_rule is not None:
+      if quantization_rule.additional_qt_config is not None:
+        lhs_quantize_dtype = quantization_rule.additional_qt_config["drhs_lhs_qtype"]
+        rhs_quantize_dtype = quantization_rule.additional_qt_config["drhs_rhs_qtype"]
       else:
-        lhs_quantize_dtype = rule.bwd_qtype
-        rhs_quantize_dtype = rule.act_qtype
+        lhs_quantize_dtype = quantization_rule.bwd_qtype
+        rhs_quantize_dtype = quantization_rule.act_qtype
   grad_rhs = backend.tgmm(
       lhs.swapaxes(0, 1),
       grad,
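Previously _gmm_fwd and _gmm_bwd re-derived the rule themselves via the now-deleted _get_current_rule helper; the rule is now handed in by the caller as one more non-differentiable argument (nondiff_argnums grows to include index 12), so the forward pass and both backward matmuls are guaranteed to see the same rule object. The dtype selection itself is unchanged; a compact restatement of it as a hypothetical helper (not part of the commit):

def _pick_quantize_dtypes(rule, pass_name):
  # Hypothetical helper restating the selection above; pass_name is "fwd",
  # "dlhs" (grad w.r.t. lhs) or "drhs" (grad w.r.t. rhs).
  if rule is None:
    return None, None
  if pass_name == "fwd":
    return rule.act_qtype, rule.weight_qtype
  cfg = rule.additional_qt_config
  if cfg is not None:
    return cfg[pass_name + "_lhs_qtype"], cfg[pass_name + "_rhs_qtype"]
  # Fallbacks mirror _gmm_bwd when no additional_qt_config is set.
  if pass_name == "dlhs":
    return rule.act_qtype, rule.bwd_qtype
  return rule.bwd_qtype, rule.act_qtype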

MaxText/layers/moe.py

Lines changed: 2 additions & 0 deletions
@@ -736,6 +736,7 @@ def gmm(inputs, kernel, group_sizes, expert_assignments):
         quant_dg = self.quant.quant_dg
         lhs_quantize_dtype = quant_dg.fwd.dg_quantizer.lhs.numerics.get_dtype()
         rhs_quantize_dtype = quant_dg.fwd.dg_quantizer.rhs.numerics.get_dtype()
+      quantization_rule = None
       if self.config.use_qwix_quantization:
         quantization_rule = qpl.get_current_rule("dot_general")
         if quantization_rule is not None:
@@ -756,6 +757,7 @@ def gmm(inputs, kernel, group_sizes, expert_assignments):
           tiling=tiling,
           lhs_quantize_dtype=lhs_quantize_dtype,
           rhs_quantize_dtype=rhs_quantize_dtype,
+          quantization_rule=quantization_rule,
           use_qwix_quantization=self.config.use_qwix_quantization,
       )
     else:
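The call site initializes quantization_rule to None on every path, fetches the active rule via qpl.get_current_rule("dot_general") only when use_qwix_quantization is on, and forwards it through the new keyword. A minimal sketch of the pattern; names are illustrative, unrelated arguments are elided, and gmm stands for the megablox kernel entry point shown above:

# Illustrative call-site sketch, not the exact MaxText code.
import qwix.pallas as qpl

def call_gmm(config, lhs, rhs, group_sizes, tiling, lhs_quantize_dtype, rhs_quantize_dtype):
  quantization_rule = None            # always defined, even on non-qwix paths
  if config.use_qwix_quantization:
    # Presumably returns the rule registered for this op when tracing under
    # qwix's quantization interception, and None otherwise.
    quantization_rule = qpl.get_current_rule("dot_general")
  return gmm(
      lhs, rhs, group_sizes,
      tiling=tiling,
      lhs_quantize_dtype=lhs_quantize_dtype,
      rhs_quantize_dtype=rhs_quantize_dtype,
      quantization_rule=quantization_rule,   # new keyword plumbed into the kernel
      use_qwix_quantization=config.use_qwix_quantization,
  )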

MaxText/layers/quantizations.py

Lines changed: 9 additions & 5 deletions
@@ -35,6 +35,7 @@
 from flax.linen import fp8_ops
 from flax.linen import initializers as flax_initializers
 import flax.linen as nn
+from flax.core import FrozenDict
 
 from MaxText.common_types import DType, Config
 from MaxText.inference.kvcache import KVQuant
@@ -667,16 +668,19 @@ def get_quantization_rule(config: Config):
         bwd_qtype=jnp.float8_e5m2,
         bwd_use_original_residuals=True,
         disable_channelwise_axes=True,  # per_tensor calibration
-        weight_calibration_method=config.quantization_calibration_method,
-        act_calibration_method=config.quantization_calibration_method,
-        bwd_calibration_method=config.quantization_calibration_method,
+        weight_calibration_method=config.fwd_weight_calibration_method,
+        act_calibration_method=config.fwd_act_calibration_method,
         op_names=("dot_general",),
-        additional_qt_config={
+        additional_qt_config=FrozenDict({
             "dlhs_lhs_qtype": jnp.float8_e5m2,
             "dlhs_rhs_qtype": jnp.float8_e4m3fn,
             "drhs_lhs_qtype": jnp.float8_e4m3fn,
             "drhs_rhs_qtype": jnp.float8_e5m2,
-        },
+            "dlhs_lhs_calibration_method": config.dlhs_lhs_calibration_method,
+            "dlhs_rhs_calibration_method": config.dlhs_rhs_calibration_method,
+            "drhs_lhs_calibration_method": config.drhs_lhs_calibration_method,
+            "drhs_rhs_calibration_method": config.drhs_rhs_calibration_method,
+        }),
       )
     case "fp8_gpu":
       return qwix.QtRule(
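Two related changes here: the per-operand backward calibration methods move into additional_qt_config (replacing the single bwd_calibration_method), and the config dict is wrapped in FrozenDict. The wrapping is presumably needed because the rule now travels into the gmm kernel as a static argument (see gmm.py above), and static arguments must be hashable; a plain dict is not, while FrozenDict is:

# Quick check of the hashability difference (assumption about the motivation,
# based on "quantization_rule" being added to the static argument list in gmm.py).
from flax.core import FrozenDict

hash(FrozenDict({"dlhs_lhs_qtype": "float8_e5m2"}))   # works: FrozenDict is hashable
try:
  hash({"dlhs_lhs_qtype": "float8_e5m2"})
except TypeError as e:
  print(e)                                            # unhashable type: 'dict'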
