30
30
31
31
gmm = jax .custom_vjp (
32
32
backend .gmm ,
33
- nondiff_argnums = (3 , 4 , 7 , 8 , 9 , 10 , 11 ),
33
+ nondiff_argnums = (3 , 4 , 7 , 8 , 9 , 10 , 11 , 12 , 13 , 14 ),
34
34
)
35
35
36
- def _get_current_rule (op_name : str ):
37
- rule = qpl .get_current_rule (op_name )
38
- if rule is not None and not isinstance (rule , qwix .QtRule ):
39
- rule = qwix .QtRule (** dataclasses .asdict (rule ))
40
- return rule
41
-
42
36
def _gmm_fwd (
43
37
lhs : jnp .ndarray ,
44
38
rhs : jnp .ndarray | aqt_tensor .QTensor ,
@@ -51,6 +45,9 @@ def _gmm_fwd(
51
45
interpret : bool = False ,
52
46
lhs_quantize_dtype : Literal [jnp .int4 , jnp .int8 ] | None = None ,
53
47
rhs_quantize_dtype : Literal [jnp .int4 , jnp .int8 ] | None = None ,
48
+ lhs_calibration_method : str = "absmax" ,
49
+ rhs_calibration_method : str = "absmax" ,
50
+ quantization_rule : qwix .QtRule | None = None ,
54
51
use_qwix_quantization : bool = False ,
55
52
) -> tuple [
56
53
jnp .ndarray ,
@@ -65,10 +62,11 @@ def _gmm_fwd(
65
62
"""Forward function for GMM VJP."""
66
63
if use_qwix_quantization :
67
64
lhs_quantize_dtype , rhs_quantize_dtype = None , None
68
- rule = _get_current_rule ("dot_general" )
69
- if rule is not None :
70
- lhs_quantize_dtype = rule .act_qtype
71
- rhs_quantize_dtype = rule .weight_qtype
65
+ if quantization_rule is not None :
66
+ lhs_quantize_dtype = quantization_rule .act_qtype
67
+ rhs_quantize_dtype = quantization_rule .weight_qtype
68
+ lhs_calibration_method = quantization_rule .act_calibration_method
69
+ rhs_calibration_method = quantization_rule .weight_calibration_method
72
70
out = backend .gmm (
73
71
lhs ,
74
72
rhs ,
@@ -81,6 +79,8 @@ def _gmm_fwd(
81
79
interpret = interpret ,
82
80
lhs_quantize_dtype = lhs_quantize_dtype ,
83
81
rhs_quantize_dtype = rhs_quantize_dtype ,
82
+ lhs_calibration_method = lhs_calibration_method ,
83
+ rhs_calibration_method = rhs_calibration_method ,
84
84
use_qwix_quantization = use_qwix_quantization ,
85
85
)
86
86
return out , (lhs , rhs , group_sizes , group_offset , rhs .shape [0 ])
@@ -93,6 +93,9 @@ def _gmm_bwd(
93
93
interpret : bool ,
94
94
lhs_quantize_dtype : Literal [jnp .int4 , jnp .int8 ] | None ,
95
95
rhs_quantize_dtype : Literal [jnp .int4 , jnp .int8 ] | None ,
96
+ lhs_calibration_method : str ,
97
+ rhs_calibration_method : str ,
98
+ quantization_rule : qwix .QtRule | None ,
96
99
use_qwix_quantization : bool ,
97
100
residual : tuple [
98
101
jnp .ndarray ,
@@ -106,14 +109,15 @@ def _gmm_bwd(
106
109
"""Backward function for throughput GMM VJP."""
107
110
if use_qwix_quantization :
108
111
lhs_quantize_dtype , rhs_quantize_dtype = None , None
109
- rule = _get_current_rule ("dot_general" )
110
- if rule is not None :
111
- if rule .additional_qt_config is not None :
112
- lhs_quantize_dtype = rule .additional_qt_config ["dlhs_lhs_qtype" ]
113
- rhs_quantize_dtype = rule .additional_qt_config ["dlhs_rhs_qtype" ]
112
+ if quantization_rule is not None :
113
+ if quantization_rule .additional_qt_config is not None :
114
+ lhs_quantize_dtype = quantization_rule .additional_qt_config ["dlhs_lhs_qtype" ]
115
+ rhs_quantize_dtype = quantization_rule .additional_qt_config ["dlhs_rhs_qtype" ]
114
116
else :
115
- lhs_quantize_dtype = rule .act_qtype
116
- rhs_quantize_dtype = rule .bwd_qtype
117
+ lhs_quantize_dtype = quantization_rule .act_qtype
118
+ rhs_quantize_dtype = quantization_rule .bwd_qtype
119
+ lhs_calibration_method = quantization_rule .additional_qt_config ["dlhs_lhs_calibration_method" ]
120
+ rhs_calibration_method = quantization_rule .additional_qt_config ["dlhs_rhs_calibration_method" ]
117
121
del preferred_element_type
118
122
lhs , rhs , group_sizes , group_offset , num_actual_groups = residual
119
123
grad_lhs = backend .gmm (
@@ -127,18 +131,21 @@ def _gmm_bwd(
127
131
interpret = interpret ,
128
132
lhs_quantize_dtype = lhs_quantize_dtype ,
129
133
rhs_quantize_dtype = rhs_quantize_dtype ,
130
- use_qwix_quantization = use_qwix_quantization ,
134
+ lhs_calibration_method = lhs_calibration_method ,
135
+ rhs_calibration_method = rhs_calibration_method ,
136
+ use_qwix_quantization = use_qwix_quantization ,
131
137
)
132
138
if use_qwix_quantization :
133
139
lhs_quantize_dtype , rhs_quantize_dtype = None , None
134
- rule = _get_current_rule ("dot_general" )
135
- if rule is not None :
136
- if rule .additional_qt_config is not None :
137
- lhs_quantize_dtype = rule .additional_qt_config ["drhs_lhs_qtype" ]
138
- rhs_quantize_dtype = rule .additional_qt_config ["drhs_rhs_qtype" ]
140
+ if quantization_rule is not None :
141
+ if quantization_rule .additional_qt_config is not None :
142
+ lhs_quantize_dtype = quantization_rule .additional_qt_config ["drhs_lhs_qtype" ]
143
+ rhs_quantize_dtype = quantization_rule .additional_qt_config ["drhs_rhs_qtype" ]
139
144
else :
140
- lhs_quantize_dtype = rule .bwd_qtype
141
- rhs_quantize_dtype = rule .act_qtype
145
+ lhs_quantize_dtype = quantization_rule .bwd_qtype
146
+ rhs_quantize_dtype = quantization_rule .act_qtype
147
+ lhs_calibration_method = quantization_rule .additional_qt_config ["drhs_lhs_calibration_method" ]
148
+ rhs_calibration_method = quantization_rule .additional_qt_config ["drhs_rhs_calibration_method" ]
142
149
grad_rhs = backend .tgmm (
143
150
lhs .swapaxes (0 , 1 ),
144
151
grad ,
@@ -150,6 +157,8 @@ def _gmm_bwd(
150
157
interpret = interpret ,
151
158
lhs_quantize_dtype = lhs_quantize_dtype ,
152
159
rhs_quantize_dtype = rhs_quantize_dtype ,
160
+ lhs_calibration_method = lhs_calibration_method ,
161
+ rhs_calibration_method = rhs_calibration_method ,
153
162
use_qwix_quantization = use_qwix_quantization ,
154
163
)
155
164
0 commit comments