@@ -42,12 +42,12 @@ def matmul_kernel(
     x_abs_max_ref: jax.Array,  # (1, batch_block_size)
     out_ref: jax.Array,  # (batch_block_size, out_block_size)
     acc_scratch: jax.Array,  # (batch_block_size, out_block_size)
-    q_x_scratch: jax.Array,  # (batch_block_size, in_block_size)
+    x_q_scratch: jax.Array,  # (batch_block_size, in_block_size)
     x_scale_scratch: jax.Array,  # (batch_block_size, 1)
     *,
     quantize_activation: bool,
     save_acc: bool,
-    save_q_x: bool,
+    save_x_q: bool,
     batch_block_size: int,
     out_block_size: int,
     in_block_size: int,
@@ -66,13 +66,13 @@ def matmul_kernel(
   assert out_ref.shape == (batch_block_size,
                            out_block_size), "out_ref shape is not correct"
 
-  if save_q_x:
+  if save_x_q:
     assert quantize_activation
-    assert q_x_scratch is not None
+    assert x_q_scratch is not None
     assert x_scale_scratch is not None
     quant = (out_idx == 0)
   else:
-    assert q_x_scratch is None
+    assert x_q_scratch is None
     assert x_scale_scratch is None
     quant = quantize_activation
 
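
For context on the branch above: with `save_x_q` enabled, the input is quantized only once per batch block and then reused, so `quant` is tied to the first output block. A minimal sketch of the effective gating, assuming `out_idx` is the second index of the `(n_bs, n_out, n_in)` grid used below:

```python
def should_quantize(out_idx: int, quantize_activation: bool, save_x_q: bool) -> bool:
  # Hypothetical helper for illustration only (not part of the kernel).
  if save_x_q:
    # Quantize on the first out block; later out blocks reuse
    # x_q_scratch / x_scale_scratch written on that first step.
    return out_idx == 0
  # Otherwise, re-quantize (when enabled) on every grid step.
  return quantize_activation
```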
@@ -88,18 +88,18 @@ def matmul_kernel(
   def matmul_body(quant, is_first_step, is_last_step):
     if quantize_activation:
       if quant:
-        q_x_tmp, x_scale_tmp = _quantize_array(x_ref[...], x_abs_max_ref[...])
-        if save_q_x:
-          q_x_scratch[...] = q_x_tmp
+        x_q_tmp, x_scale_tmp = _quantize_array(x_ref[...], x_abs_max_ref[...])
+        if save_x_q:
+          x_q_scratch[...] = x_q_tmp
           x_scale_scratch[...] = x_scale_tmp
       else:
-        assert save_q_x
-        q_x_tmp = q_x_scratch[...]
+        assert save_x_q
+        x_q_tmp = x_q_scratch[...]
         if is_last_step:
           x_scale_tmp = x_scale_scratch[...]
 
       acc = jax.lax.dot_general(
-          q_x_tmp,
+          x_q_tmp,
           w_ref[...],
           (((1,), (1,)), ((), ())),
           preferred_element_type=jnp.int32,
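
`_quantize_array` is not part of this diff; for readers, a minimal sketch of a symmetric per-row int8 quantizer matching the shapes above (an assumption about its behavior, not the file's actual implementation):

```python
import jax.numpy as jnp

def _quantize_array(x, x_abs_max):
  # Sketch only. x: (batch_block_size, in_block_size);
  # x_abs_max: (1, batch_block_size) per the ref comments above.
  # Guard against all-zero rows before dividing.
  x_scale = (jnp.maximum(x_abs_max.T, 1e-6) / 127.0).astype(jnp.float32)
  x_q = jnp.clip(jnp.round(x / x_scale), -127, 127).astype(jnp.int8)
  return x_q, x_scale  # x_scale: (batch_block_size, 1)
```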
@@ -130,6 +130,44 @@ def _next_multiple(x, multiple):
   return ((x + multiple - 1) // multiple) * multiple
 
 
+def _get_vmem_limit(n_bs, n_out, n_in, batch_block_size, out_block_size,
+                    in_block_size, x_bytes, w_bytes, x_q_bytes, scale_bytes,
+                    out_bytes, acc_bytes, save_acc, save_x_q):
+  # Calculate in/out VMEM size.
+  x_size = batch_block_size * in_block_size * x_bytes
+  x_abs_max_val_size = batch_block_size * scale_bytes
+  w_size = out_block_size * in_block_size * w_bytes
+  scalar_size = out_block_size * scale_bytes
+  out_size = batch_block_size * out_block_size * out_bytes
+
+  vmem_in_out = x_size + x_abs_max_val_size + w_size + scalar_size + out_size
+  vmem_in_out *= 2  # Account for compute and vreg spills.
+
+  # Account for double buffering.
+  # Double buffering is used only if there are multiple blocks per in/out.
+  vmem_in_out += x_size if (n_bs > 1 or n_in > 1) else 0
+  vmem_in_out += x_abs_max_val_size if (n_bs > 1) else 0
+  vmem_in_out += w_size if (n_out > 1 or n_in > 1) else 0
+  vmem_in_out += scalar_size if (n_out > 1) else 0
+  vmem_in_out += out_size if (n_bs > 1 or n_out > 1) else 0
+
+  # Calculate scratch VMEM size.
+  acc_size = batch_block_size * out_block_size * acc_bytes
+  x_q_scratch_size = batch_block_size * in_block_size * x_q_bytes
+  x_scale_scratch_size = batch_block_size * scale_bytes
+
+  vmem_scratch = acc_size if save_acc else 0
+  vmem_scratch += x_q_scratch_size + x_scale_scratch_size if save_x_q else 0
+  vmem_scratch *= 2  # Account for compute and vreg spills.
+
+  # Add in/out and scratch VMEM size.
+  vmem_used = vmem_in_out + vmem_scratch
+  # Cap the limit at 96MB.
+  vmem_limit_bytes = min(vmem_used, 96 * 1024 * 1024)
+
+  return vmem_limit_bytes
+
+
 @functools.partial(
     jax.jit,
     static_argnames=[
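
To sanity-check the new sizing logic, a worked example with illustrative parameters (bf16 activations and outputs, int8 weights, f32 scales; the block sizes here are assumptions, not defaults from this PR):

```python
limit = _get_vmem_limit(
    n_bs=1, n_out=4, n_in=1,
    batch_block_size=256, out_block_size=1024, in_block_size=1024,
    x_bytes=2, w_bytes=1, x_q_bytes=1, scale_bytes=4,
    out_bytes=2, acc_bytes=4,
    save_acc=False,  # n_in == 1, no accumulator carried across steps
    save_x_q=True)   # quantized input cached across the 4 out blocks

# In/out: 2 * 2,102,272 = 4,204,544 bytes, plus double buffering for
# w, scalar, and out (1,048,576 + 4,096 + 524,288) = 5,781,504 bytes.
# Scratch: 2 * (262,144 + 1,024) = 526,336 bytes.
# Total: 6,307,840 bytes (~6 MB), well under the 96MB cap.
```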
@@ -196,32 +234,39 @@ def quantized_matmul_int8(
   assert x.shape[
       1] % in_block_size == 0, f"x.shape[1] ({x.shape[1]}) must be a multiple of block size ({in_block_size})"
 
-  acc_dtype = jnp.int32 if quantize_activation else x.dtype
-  vmem_to_be_transferred = 2 * (
-      batch_block_size * in_block_size * x.dtype.itemsize +
-      out_block_size * in_block_size * w.dtype.itemsize + out_block_size *
-      scalar.dtype.itemsize + batch_block_size * x_abs_max_val.dtype.itemsize +
-      batch_block_size * out_block_size * x.dtype.itemsize
-  ) + batch_block_size * out_block_size * jnp.dtype(acc_dtype).itemsize
-  # Within the kernel, it will use some extra VMEM for computation or vreg spills.
-  vmem_used = vmem_to_be_transferred * 2
-  vmem_limit_bytes = min(vmem_used * 2, 96 * 1024 * 1024)
-
   n_bs = padded_bs // batch_block_size
   n_out = padded_out_features // out_block_size
   n_in = padded_in_features // in_block_size
 
   save_acc = n_in > 1
   # Remove redundant input quantization logic by caching quantized input.
   # For best performance, only enable this behavior when a single input block is used per batch.
-  save_q_x = quantize_activation and n_in == 1 and n_out > 1
+  save_x_q = quantize_activation and n_in == 1 and n_out > 1
+
+  acc_dtype = jnp.int32 if quantize_activation else jnp.float32
+
+  vmem_limit_bytes = _get_vmem_limit(
+      n_bs=n_bs,
+      n_out=n_out,
+      n_in=n_in,
+      batch_block_size=batch_block_size,
+      out_block_size=out_block_size,
+      in_block_size=in_block_size,
+      x_bytes=x.dtype.itemsize,
+      w_bytes=w.dtype.itemsize,
+      x_q_bytes=jnp.dtype(jnp.int8).itemsize,
+      scale_bytes=jnp.dtype(jnp.float32).itemsize,
+      out_bytes=x.dtype.itemsize,
+      acc_bytes=jnp.dtype(acc_dtype).itemsize,
+      save_acc=save_acc,
+      save_x_q=save_x_q)
 
   kernel = pl.pallas_call(
       functools.partial(
           matmul_kernel,
           quantize_activation=quantize_activation,
           save_acc=save_acc,
-          save_q_x=save_q_x,
+          save_x_q=save_x_q,
           batch_block_size=batch_block_size,
           out_block_size=out_block_size,
           in_block_size=in_block_size),
@@ -243,9 +288,9 @@ def quantized_matmul_int8(
           pltpu.VMEM((batch_block_size,
                       out_block_size), acc_dtype) if save_acc else None,
           pltpu.VMEM((batch_block_size,
-                      in_block_size), jnp.int8) if save_q_x else None,
+                      in_block_size), jnp.int8) if save_x_q else None,
           pltpu.VMEM(
-              (batch_block_size, 1), jnp.float32) if save_q_x else None,
+              (batch_block_size, 1), jnp.float32) if save_x_q else None,
       ],
       grid=(n_bs, n_out, n_in),
   ),
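
A hypothetical end-to-end call, inferred from the names these hunks reference (`x`, `w`, `scalar` as per-channel weight scales, `x_abs_max_val` as per-row input abs-max); the full signature of `quantized_matmul_int8` is outside this diff, so argument names and order here are guesses:

```python
import jax.numpy as jnp

bs, in_features, out_features = 256, 1024, 4096
x = jnp.ones((bs, in_features), dtype=jnp.bfloat16)        # activations
w = jnp.ones((out_features, in_features), dtype=jnp.int8)  # int8 weights
scalar = jnp.ones((out_features,), dtype=jnp.float32)      # weight dequant scales
x_abs_max_val = jnp.max(jnp.abs(x), axis=1, keepdims=True).T  # (1, bs)

out = quantized_matmul_int8(
    x, w, scalar, x_abs_max_val,
    quantize_activation=True,
    batch_block_size=256,
    out_block_size=1024,
    in_block_size=1024)  # out: (bs, out_features) in x.dtype
```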