Commit 29ae4c7

Optimize w8a8 kernel vmem limit (#9508)
1 parent 0a1594a commit 29ae4c7

File tree

1 file changed: +71 -26 lines changed

torch_xla/experimental/pallas_kernels/quantized_matmul_kernel.py

Lines changed: 71 additions & 26 deletions
@@ -42,12 +42,12 @@ def matmul_kernel(
     x_abs_max_ref: jax.Array,  # (1, batch_block_size)
     out_ref: jax.Array,  # (batch_block_size, out_block_size)
     acc_scratch: jax.Array,  # (batch_block_size, out_block_size)
-    q_x_scratch: jax.Array,  # (batch_block_size, in_block_size)
+    x_q_scratch: jax.Array,  # (batch_block_size, in_block_size)
     x_scale_scratch: jax.Array,  # (batch_block_size, 1)
     *,
     quantize_activation: bool,
     save_acc: bool,
-    save_q_x: bool,
+    save_x_q: bool,
     batch_block_size: int,
     out_block_size: int,
     in_block_size: int,
@@ -66,13 +66,13 @@ def matmul_kernel(
   assert out_ref.shape == (batch_block_size,
                            out_block_size), "out_ref shape is not correct"

-  if save_q_x:
+  if save_x_q:
     assert quantize_activation
-    assert q_x_scratch is not None
+    assert x_q_scratch is not None
     assert x_scale_scratch is not None
     quant = (out_idx == 0)
   else:
-    assert q_x_scratch is None
+    assert x_q_scratch is None
     assert x_scale_scratch is None
     quant = quantize_activation

@@ -88,18 +88,18 @@ def matmul_kernel(
   def matmul_body(quant, is_first_step, is_last_step):
     if quantize_activation:
       if quant:
-        q_x_tmp, x_scale_tmp = _quantize_array(x_ref[...], x_abs_max_ref[...])
-        if save_q_x:
-          q_x_scratch[...] = q_x_tmp
+        x_q_tmp, x_scale_tmp = _quantize_array(x_ref[...], x_abs_max_ref[...])
+        if save_x_q:
+          x_q_scratch[...] = x_q_tmp
           x_scale_scratch[...] = x_scale_tmp
       else:
-        assert save_q_x
-        q_x_tmp = q_x_scratch[...]
+        assert save_x_q
+        x_q_tmp = x_q_scratch[...]
         if is_last_step:
           x_scale_tmp = x_scale_scratch[...]

       acc = jax.lax.dot_general(
-          q_x_tmp,
+          x_q_tmp,
           w_ref[...],
           (((1,), (1,)), ((), ())),
           preferred_element_type=jnp.int32,
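
For context, the contraction in this hunk is a plain int8 x int8 matmul with int32 accumulation, followed by rescaling with the activation and weight scales. A minimal standalone JAX sketch of the same idea (outside Pallas; quantize_rows and int8_matmul are hypothetical illustration names standing in for the kernel's _quantize_array and matmul_body, not code from this file):

import jax
import jax.numpy as jnp


def quantize_rows(x):
  # Hypothetical symmetric per-row int8 quantizer, standing in for the
  # kernel's _quantize_array helper.
  abs_max = jnp.max(jnp.abs(x), axis=1, keepdims=True)  # (batch, 1)
  scale = abs_max / 127.0
  return jnp.round(x / scale).astype(jnp.int8), scale


def int8_matmul(x, w_q, w_scale):
  # x: (batch, in) float, w_q: (out, in) int8, w_scale: (out,) float32.
  x_q, x_scale = quantize_rows(x)
  # Contract the shared "in" dimension of both operands and accumulate in
  # int32, mirroring the (((1,), (1,)), ((), ())) dimension numbers above.
  acc = jax.lax.dot_general(
      x_q, w_q, (((1,), (1,)), ((), ())), preferred_element_type=jnp.int32)
  # Undo both quantization scales to get a float result of shape (batch, out).
  return acc.astype(jnp.float32) * x_scale * w_scale[None, :]


x = jax.random.normal(jax.random.PRNGKey(0), (8, 128))
w = jax.random.normal(jax.random.PRNGKey(1), (16, 128))
w_scale = jnp.max(jnp.abs(w), axis=1) / 127.0
w_q = jnp.round(w / w_scale[:, None]).astype(jnp.int8)
print(int8_matmul(x, w_q, w_scale).shape)  # (8, 16)
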
@@ -130,6 +130,44 @@ def _next_multiple(x, multiple):
   return ((x + multiple - 1) // multiple) * multiple


+def _get_vmem_limit(n_bs, n_out, n_in, batch_block_size, out_block_size,
+                    in_block_size, x_bytes, w_bytes, x_q_bytes, scale_bytes,
+                    out_bytes, acc_bytes, save_acc, save_x_q):
+  # Calculate in/out VMEM size.
+  x_size = batch_block_size * in_block_size * x_bytes
+  x_abs_max_val_size = batch_block_size * scale_bytes
+  w_size = out_block_size * in_block_size * w_bytes
+  scalar_size = out_block_size * scale_bytes
+  out_size = batch_block_size * out_block_size * out_bytes
+
+  vmem_in_out = x_size + x_abs_max_val_size + w_size + scalar_size + out_size
+  vmem_in_out *= 2  # Account for compute and vreg spills.
+
+  # Account for double buffering.
+  # Double buffering is used only if there are multiple blocks per in/out.
+  vmem_in_out += x_size if (n_bs > 1 or n_in > 1) else 0
+  vmem_in_out += x_abs_max_val_size if (n_bs > 1) else 0
+  vmem_in_out += w_size if (n_out > 1 or n_in > 1) else 0
+  vmem_in_out += scalar_size if (n_out > 1) else 0
+  vmem_in_out += out_size if (n_bs > 1 or n_out > 1) else 0
+
+  # Calculate scratch VMEM size.
+  acc_size = batch_block_size * out_block_size * acc_bytes
+  x_q_scratch_size = batch_block_size * in_block_size * x_q_bytes
+  x_scale_scratch_size = batch_block_size * scale_bytes
+
+  vmem_scratch = acc_size if save_acc else 0
+  vmem_scratch += x_q_scratch_size + x_scale_scratch_size if save_x_q else 0
+  vmem_scratch *= 2  # Account for compute and vreg spills.
+
+  # Add in/out and scratch VMEM size.
+  vmem_used = vmem_in_out + vmem_scratch
+  # Specify upper limit as 96MB.
+  vmem_limit_bytes = min(vmem_used, 96 * 1024 * 1024)
+
+  return vmem_limit_bytes
+
+
 @functools.partial(
     jax.jit,
     static_argnames=[
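
As a rough sanity check of the formula added above, here is the same arithmetic restated standalone and evaluated for one hypothetical configuration (block sizes, grid counts, and dtypes are chosen for illustration, not taken from the commit):

# Hypothetical configuration: bf16 activations, int8 weights, fp32 scales.
batch_block_size, out_block_size, in_block_size = 256, 512, 1024
n_bs, n_out, n_in = 4, 8, 1          # several batch/out blocks, one in block
x_bytes, w_bytes = 2, 1
x_q_bytes, scale_bytes = 1, 4
out_bytes, acc_bytes = 2, 4
save_acc = n_in > 1                  # False here
save_x_q = n_in == 1 and n_out > 1   # True here (with quantize_activation)

# In/out blocks resident in VMEM.
x_size = batch_block_size * in_block_size * x_bytes
x_abs_max_val_size = batch_block_size * scale_bytes
w_size = out_block_size * in_block_size * w_bytes
scalar_size = out_block_size * scale_bytes
out_size = batch_block_size * out_block_size * out_bytes

vmem_in_out = (x_size + x_abs_max_val_size + w_size + scalar_size + out_size) * 2
# Double buffering only for operands with more than one block along the grid.
vmem_in_out += x_size if (n_bs > 1 or n_in > 1) else 0
vmem_in_out += x_abs_max_val_size if n_bs > 1 else 0
vmem_in_out += w_size if (n_out > 1 or n_in > 1) else 0
vmem_in_out += scalar_size if n_out > 1 else 0
vmem_in_out += out_size if (n_bs > 1 or n_out > 1) else 0

# Scratch buffers.
acc_size = batch_block_size * out_block_size * acc_bytes
x_q_scratch_size = batch_block_size * in_block_size * x_q_bytes
x_scale_scratch_size = batch_block_size * scale_bytes
vmem_scratch = (acc_size if save_acc else 0)
vmem_scratch += (x_q_scratch_size + x_scale_scratch_size) if save_x_q else 0
vmem_scratch *= 2

vmem_limit_bytes = min(vmem_in_out + vmem_scratch, 96 * 1024 * 1024)
print(vmem_limit_bytes)  # 4467712 bytes (about 4.26 MiB), well under the 96 MB cap
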
@@ -196,32 +196,39 @@ def quantized_matmul_int8(
   assert x.shape[
       1] % in_block_size == 0, f"x.shape[1] ({x.shape[1]}) must be a multiple of block size ({in_block_size})"

-  acc_dtype = jnp.int32 if quantize_activation else x.dtype
-  vmem_to_be_transferred = 2 * (
-      batch_block_size * in_block_size * x.dtype.itemsize +
-      out_block_size * in_block_size * w.dtype.itemsize + out_block_size *
-      scalar.dtype.itemsize + batch_block_size * x_abs_max_val.dtype.itemsize +
-      batch_block_size * out_block_size * x.dtype.itemsize
-  ) + batch_block_size * out_block_size * jnp.dtype(acc_dtype).itemsize
-  # Within the kernel, it will use some extra VMEM for computation or vreg spills.
-  vmem_used = vmem_to_be_transferred * 2
-  vmem_limit_bytes = min(vmem_used * 2, 96 * 1024 * 1024)
-
   n_bs = padded_bs // batch_block_size
   n_out = padded_out_features // out_block_size
   n_in = padded_in_features // in_block_size

   save_acc = n_in > 1
   # Remove redundant input quantization logic by caching quantized input.
   # For best performance, only enable this behavior when single input block is used per batch.
-  save_q_x = quantize_activation and n_in == 1 and n_out > 1
+  save_x_q = quantize_activation and n_in == 1 and n_out > 1
+
+  acc_dtype = jnp.int32 if quantize_activation else jnp.float32
+
+  vmem_limit_bytes = _get_vmem_limit(
+      n_bs=n_bs,
+      n_out=n_out,
+      n_in=n_in,
+      batch_block_size=batch_block_size,
+      out_block_size=out_block_size,
+      in_block_size=in_block_size,
+      x_bytes=x.dtype.itemsize,
+      w_bytes=w.dtype.itemsize,
+      x_q_bytes=jnp.dtype(jnp.int8).itemsize,
+      scale_bytes=jnp.dtype(jnp.float32).itemsize,
+      out_bytes=x.dtype.itemsize,
+      acc_bytes=jnp.dtype(acc_dtype).itemsize,
+      save_acc=save_acc,
+      save_x_q=save_x_q)

   kernel = pl.pallas_call(
       functools.partial(
           matmul_kernel,
           quantize_activation=quantize_activation,
           save_acc=save_acc,
-          save_q_x=save_q_x,
+          save_x_q=save_x_q,
           batch_block_size=batch_block_size,
           out_block_size=out_block_size,
           in_block_size=in_block_size),
@@ -243,9 +288,9 @@ def quantized_matmul_int8(
           pltpu.VMEM((batch_block_size,
                       out_block_size), acc_dtype) if save_acc else None,
           pltpu.VMEM((batch_block_size,
-                      in_block_size), jnp.int8) if save_q_x else None,
+                      in_block_size), jnp.int8) if save_x_q else None,
           pltpu.VMEM(
-              (batch_block_size, 1), jnp.float32) if save_q_x else None,
+              (batch_block_size, 1), jnp.float32) if save_x_q else None,
       ],
       grid=(n_bs, n_out, n_in),
   ),
