diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml
index 1d4cb4311..f8ce9aa8b 100644
--- a/MaxText/configs/base.yml
+++ b/MaxText/configs/base.yml
@@ -120,7 +120,12 @@ save_quantized_params_path: ""
 model_call_mode: ""
 use_qwix_quantization: False # Whether to use qwix for quantization. If set to True, the model will be quantized using qwix.
 # Quantization calibration method used for weights and activations. Supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#L70-L80
-quantization_calibration_method: "absmax"
+fwd_weight_calibration_method: "absmax"
+fwd_act_calibration_method: "absmax"
+dlhs_lhs_calibration_method: "absmax"
+dlhs_rhs_calibration_method: "absmax"
+drhs_lhs_calibration_method: "absmax"
+drhs_rhs_calibration_method: "absmax"
 # Shard the range finding operation for quantization. By default this is set to number of slices.
 quantization_local_shard_count: -1
 
diff --git a/MaxText/kernels/megablox/gmm.py b/MaxText/kernels/megablox/gmm.py
index 764752002..337c4d8d2 100644
--- a/MaxText/kernels/megablox/gmm.py
+++ b/MaxText/kernels/megablox/gmm.py
@@ -33,6 +33,7 @@
 from qwix.pallas import QArray
 import qwix.pallas as qpl
+import qwix
 
 from MaxText.kernels.megablox import common
 
@@ -305,6 +306,9 @@ def _zero_uninitialized_memory(
         "interpret",
         "lhs_quantize_dtype",
         "rhs_quantize_dtype",
+        "lhs_calibration_method",
+        "rhs_calibration_method",
+        "quantization_rule",
         "use_qwix_quantization",
     ],
 )
@@ -320,6 +324,9 @@ def gmm(
     interpret: bool = False,
     lhs_quantize_dtype: Literal[jnp.int4, jnp.int8] | None = None,
     rhs_quantize_dtype: Literal[jnp.int4, jnp.int8] | None = None,
+    lhs_calibration_method: str = "absmax",
+    rhs_calibration_method: str = "absmax",
+    quantization_rule: qwix.QtRule | None = None,
     use_qwix_quantization: bool = False,
 ) -> jnp.ndarray:
   """Compute lhs[sizes[i-1]:sizes[i], :] @ rhs for each group 'i'.
@@ -604,7 +611,7 @@ def out_transform_indices(n_i, grid_id, k_i, group_metadata, group_offset):
 
   if lhs_quantize_dtype is not None:
     if use_qwix_quantization:
-      lhs = qpl.quantize(lhs, qtype=lhs_quantize_dtype, channelwise_axes=[0], scale_dtype=jnp.float32)
+      lhs = qpl.quantize(lhs, qtype=lhs_quantize_dtype, channelwise_axes=[0], scale_dtype=jnp.float32, calibration_method=lhs_calibration_method)
     else:
       lhs_quantize_bits = 4 if lhs_quantize_dtype == jnp.int4 else 8
       lhs = aqt_pl.quant(lhs, lhs_quantize_bits, lhs_contracting_axis)
@@ -613,7 +620,7 @@ def out_transform_indices(n_i, grid_id, k_i, group_metadata, group_offset):
     # Use per-channel scales for non-contracting axes, i.e., num_groups, m, n but not k.
     if use_qwix_quantization:
       rhs = qpl.quantize(
-          rhs, qtype=rhs_quantize_dtype, channelwise_axes=[0, 1 if transpose_rhs else 2], scale_dtype=jnp.float32
+          rhs, qtype=rhs_quantize_dtype, channelwise_axes=[0, 1 if transpose_rhs else 2], scale_dtype=jnp.float32, calibration_method=rhs_calibration_method
       )
     else:
       rhs_quantize_bits = 4 if rhs_quantize_dtype == jnp.int4 else 8
@@ -645,6 +652,8 @@ def out_transform_indices(n_i, grid_id, k_i, group_metadata, group_offset):
         "interpret",
         "lhs_quantize_dtype",
         "rhs_quantize_dtype",
+        "lhs_calibration_method",
+        "rhs_calibration_method",
         "use_qwix_quantization",
     ],
 )
@@ -660,6 +669,8 @@ def tgmm(
     interpret: bool = False,
     lhs_quantize_dtype=None,
     rhs_quantize_dtype=None,
+    lhs_calibration_method: str = "absmax",
+    rhs_calibration_method: str = "absmax",
     use_qwix_quantization: bool = False,
 ) -> jnp.ndarray:
   """Compute lhs[:, sizes[i-1]:sizes[i]] @ rhs[sizes[i-1]:sizes[i], :].
@@ -780,7 +791,7 @@ def _do():
         qvalue = lax.select(
             rhs_mask[...],
             rhs.qvalue[...],
-            jnp.zeros_like(rhs.qvalue, lhs.qvalue.dtype),
+            jnp.zeros_like(rhs.qvalue, rhs.qvalue.dtype),
         )
         loaded_rhs = dataclasses.replace(loaded_rhs, qvalue=qvalue)
       else:
@@ -878,11 +889,11 @@ def out_transform_indices(n_i, k_i, grid_id, group_metadata, group_offset):
   )
 
   if use_qwix_quantization and lhs_quantize_dtype is not None:
-    lhs = qpl.quantize(lhs, qtype=lhs_quantize_dtype, scale_dtype=jnp.float32)
+    lhs = qpl.quantize(lhs, qtype=lhs_quantize_dtype, scale_dtype=jnp.float32, calibration_method=lhs_calibration_method)
 
   if use_qwix_quantization and rhs_quantize_dtype is not None:
     # Use per-channel scales for non-contracting axes, i.e., num_groups, m, n but not k.
-    rhs = qpl.quantize(rhs, qtype=rhs_quantize_dtype, scale_dtype=jnp.float32)
+    rhs = qpl.quantize(rhs, qtype=rhs_quantize_dtype, scale_dtype=jnp.float32, calibration_method=rhs_calibration_method)
 
   out = call_gmm(
       group_metadata,
diff --git a/MaxText/kernels/megablox/ops.py b/MaxText/kernels/megablox/ops.py
index d9d50a3dd..a7882994e 100644
--- a/MaxText/kernels/megablox/ops.py
+++ b/MaxText/kernels/megablox/ops.py
@@ -30,15 +30,9 @@
 gmm = jax.custom_vjp(
     backend.gmm,
-    nondiff_argnums=(3, 4, 7, 8, 9, 10, 11),
+    nondiff_argnums=(3, 4, 7, 8, 9, 10, 11, 12, 13, 14),
 )
 
-def _get_current_rule(op_name: str):
-  rule = qpl.get_current_rule(op_name)
-  if rule is not None and not isinstance(rule, qwix.QtRule):
-    rule = qwix.QtRule(**dataclasses.asdict(rule))
-  return rule
-
 
 def _gmm_fwd(
     lhs: jnp.ndarray,
     rhs: jnp.ndarray | aqt_tensor.QTensor,
@@ -51,6 +45,9 @@ def _gmm_fwd(
     interpret: bool = False,
     lhs_quantize_dtype: Literal[jnp.int4, jnp.int8] | None = None,
     rhs_quantize_dtype: Literal[jnp.int4, jnp.int8] | None = None,
+    lhs_calibration_method: str = "absmax",
+    rhs_calibration_method: str = "absmax",
+    quantization_rule: qwix.QtRule | None = None,
     use_qwix_quantization: bool = False,
 ) -> tuple[
     jnp.ndarray,
@@ -65,10 +62,11 @@ def _gmm_fwd(
   """Forward function for GMM VJP."""
   if use_qwix_quantization:
     lhs_quantize_dtype, rhs_quantize_dtype = None, None
-    rule = _get_current_rule("dot_general")
-    if rule is not None:
-      lhs_quantize_dtype = rule.act_qtype
-      rhs_quantize_dtype = rule.weight_qtype
+    if quantization_rule is not None:
+      lhs_quantize_dtype = quantization_rule.act_qtype
+      rhs_quantize_dtype = quantization_rule.weight_qtype
+      lhs_calibration_method = quantization_rule.act_calibration_method
+      rhs_calibration_method = quantization_rule.weight_calibration_method
   out = backend.gmm(
       lhs,
       rhs,
@@ -81,6 +79,8 @@ def _gmm_fwd(
       interpret=interpret,
       lhs_quantize_dtype=lhs_quantize_dtype,
       rhs_quantize_dtype=rhs_quantize_dtype,
+      lhs_calibration_method=lhs_calibration_method,
+      rhs_calibration_method=rhs_calibration_method,
       use_qwix_quantization=use_qwix_quantization,
   )
   return out, (lhs, rhs, group_sizes, group_offset, rhs.shape[0])
@@ -93,6 +93,9 @@ def _gmm_bwd(
     interpret: bool,
     lhs_quantize_dtype: Literal[jnp.int4, jnp.int8] | None,
     rhs_quantize_dtype: Literal[jnp.int4, jnp.int8] | None,
+    lhs_calibration_method: str,
+    rhs_calibration_method: str,
+    quantization_rule: qwix.QtRule | None,
     use_qwix_quantization: bool,
     residual: tuple[
         jnp.ndarray,
@@ -106,14 +109,15 @@ def _gmm_bwd(
   """Backward function for throughput GMM VJP."""
   if use_qwix_quantization:
     lhs_quantize_dtype, rhs_quantize_dtype = None, None
-    rule = _get_current_rule("dot_general")
-    if rule is not None:
-      if rule.additional_qt_config is not None:
-        lhs_quantize_dtype = rule.additional_qt_config["dlhs_lhs_qtype"]
-        rhs_quantize_dtype = rule.additional_qt_config["dlhs_rhs_qtype"]
+    if quantization_rule is not None:
+      if quantization_rule.additional_qt_config is not None:
+        lhs_quantize_dtype = quantization_rule.additional_qt_config["dlhs_lhs_qtype"]
+        rhs_quantize_dtype = quantization_rule.additional_qt_config["dlhs_rhs_qtype"]
       else:
-        lhs_quantize_dtype = rule.act_qtype
-        rhs_quantize_dtype = rule.bwd_qtype
+        lhs_quantize_dtype = quantization_rule.act_qtype
+        rhs_quantize_dtype = quantization_rule.bwd_qtype
+      lhs_calibration_method = quantization_rule.additional_qt_config["dlhs_lhs_calibration_method"]
+      rhs_calibration_method = quantization_rule.additional_qt_config["dlhs_rhs_calibration_method"]
   del preferred_element_type
   lhs, rhs, group_sizes, group_offset, num_actual_groups = residual
   grad_lhs = backend.gmm(
@@ -127,18 +131,21 @@ def _gmm_bwd(
      interpret=interpret,
      lhs_quantize_dtype=lhs_quantize_dtype,
      rhs_quantize_dtype=rhs_quantize_dtype,
-     use_qwix_quantization=use_qwix_quantization,
+     lhs_calibration_method=lhs_calibration_method,
+     rhs_calibration_method=rhs_calibration_method,
+     use_qwix_quantization=use_qwix_quantization,
   )
   if use_qwix_quantization:
     lhs_quantize_dtype, rhs_quantize_dtype = None, None
-    rule = _get_current_rule("dot_general")
-    if rule is not None:
-      if rule.additional_qt_config is not None:
-        lhs_quantize_dtype = rule.additional_qt_config["drhs_lhs_qtype"]
-        rhs_quantize_dtype = rule.additional_qt_config["drhs_rhs_qtype"]
+    if quantization_rule is not None:
+      if quantization_rule.additional_qt_config is not None:
+        lhs_quantize_dtype = quantization_rule.additional_qt_config["drhs_lhs_qtype"]
+        rhs_quantize_dtype = quantization_rule.additional_qt_config["drhs_rhs_qtype"]
      else:
-        lhs_quantize_dtype = rule.bwd_qtype
-        rhs_quantize_dtype = rule.act_qtype
+        lhs_quantize_dtype = quantization_rule.bwd_qtype
+        rhs_quantize_dtype = quantization_rule.act_qtype
+      lhs_calibration_method = quantization_rule.additional_qt_config["drhs_lhs_calibration_method"]
+      rhs_calibration_method = quantization_rule.additional_qt_config["drhs_rhs_calibration_method"]
   grad_rhs = backend.tgmm(
       lhs.swapaxes(0, 1),
       grad,
@@ -150,6 +157,8 @@ def _gmm_bwd(
       interpret=interpret,
       lhs_quantize_dtype=lhs_quantize_dtype,
       rhs_quantize_dtype=rhs_quantize_dtype,
+      lhs_calibration_method=lhs_calibration_method,
+      rhs_calibration_method=rhs_calibration_method,
       use_qwix_quantization=use_qwix_quantization,
   )
 
diff --git a/MaxText/layers/moe.py b/MaxText/layers/moe.py
index 15cb54a4a..5d6e2dda5 100644
--- a/MaxText/layers/moe.py
+++ b/MaxText/layers/moe.py
@@ -736,6 +736,7 @@ def gmm(inputs, kernel, group_sizes, expert_assignments):
       quant_dg = self.quant.quant_dg
       lhs_quantize_dtype = quant_dg.fwd.dg_quantizer.lhs.numerics.get_dtype()
       rhs_quantize_dtype = quant_dg.fwd.dg_quantizer.rhs.numerics.get_dtype()
+      quantization_rule = None
       if self.config.use_qwix_quantization:
         quantization_rule = qpl.get_current_rule("dot_general")
         if quantization_rule is not None:
@@ -756,6 +757,7 @@ def gmm(inputs, kernel, group_sizes, expert_assignments):
           tiling=tiling,
           lhs_quantize_dtype=lhs_quantize_dtype,
           rhs_quantize_dtype=rhs_quantize_dtype,
+          quantization_rule=quantization_rule,
           use_qwix_quantization=self.config.use_qwix_quantization,
       )
     else:
diff --git a/MaxText/layers/quantizations.py b/MaxText/layers/quantizations.py
index b65186e8a..cab2f6d33 100644
--- a/MaxText/layers/quantizations.py
+++ b/MaxText/layers/quantizations.py
@@ -35,6 +35,7 @@
 from flax.linen import fp8_ops
 from flax.linen import initializers as flax_initializers
 import flax.linen as nn
+from flax.core import FrozenDict
 
 from MaxText.common_types import DType, Config
 from MaxText.inference.kvcache import KVQuant
@@ -667,16 +668,19 @@ def get_quantization_rule(config: Config):
           bwd_qtype=jnp.float8_e5m2,
           bwd_use_original_residuals=True,
           disable_channelwise_axes=True,  # per_tensor calibration
-          weight_calibration_method=config.quantization_calibration_method,
-          act_calibration_method=config.quantization_calibration_method,
-          bwd_calibration_method=config.quantization_calibration_method,
+          weight_calibration_method=config.fwd_weight_calibration_method,
+          act_calibration_method=config.fwd_act_calibration_method,
           op_names=("dot_general",),
-          additional_qt_config={
+          additional_qt_config=FrozenDict({
               "dlhs_lhs_qtype": jnp.float8_e5m2,
               "dlhs_rhs_qtype": jnp.float8_e4m3fn,
               "drhs_lhs_qtype": jnp.float8_e4m3fn,
               "drhs_rhs_qtype": jnp.float8_e5m2,
-          },
+              "dlhs_lhs_calibration_method": config.dlhs_lhs_calibration_method,
+              "dlhs_rhs_calibration_method": config.dlhs_rhs_calibration_method,
+              "drhs_lhs_calibration_method": config.drhs_lhs_calibration_method,
+              "drhs_rhs_calibration_method": config.drhs_rhs_calibration_method,
+          }),
       )
     case "fp8_gpu":
       return qwix.QtRule(