
Commit 31c4c2f

[Bugfix] fix ragged attention kernel auto-tuning table key (#9497)
1 parent: e82631e

File tree

2 files changed: +14 −10 lines


torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py

Lines changed: 11 additions & 6 deletions
@@ -25,6 +25,7 @@
 from jax.experimental import pallas as pl
 from jax.experimental.pallas import tpu as pltpu
 import jax.numpy as jnp
+import logging

 DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max)
 # The page size is too small. We only have 32 SREGs in TC. If the pages
@@ -1421,12 +1422,12 @@ def simplify_key(key):
   return (
       jnp.dtype(q_dtype).name,
       jnp.dtype(kv_dtype).name,
-      next_power_of_2(num_q_heads_per_blk),
-      next_power_of_2(num_kv_heads_per_blk),
+      num_q_heads_per_blk,
+      num_kv_heads_per_blk,
       (head_dim + 127) // 128 * 128,
       next_power_of_2(page_size),
       next_power_of_2(max_num_batched_tokens),
-      next_power_of_2(page_size * pages_per_seq),
+      page_size * pages_per_seq,
   )
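Why the key change matters, as a minimal standalone sketch: rounding the head counts (and the total context length) up to a power of two at lookup time can produce a key that no longer matches a tuning-table entry recorded with the exact values, so the tuned block sizes were silently skipped. The next_power_of_2 helper and the table entry below are stand-ins with assumed behavior, not the kernel's code.

# Hedged sketch, not the kernel's implementation.
def next_power_of_2(x: int) -> int:
  # Assumed behavior of the kernel's helper of the same name.
  return 1 if x <= 0 else 1 << (x - 1).bit_length()

# Hypothetical tuning-table entry keyed on exact head counts
# (12 q heads, 2 kv heads).
tuned = {("bfloat16", "bfloat16", 12, 2): (128, 64)}

old_key = ("bfloat16", "bfloat16", next_power_of_2(12), next_power_of_2(2))
new_key = ("bfloat16", "bfloat16", 12, 2)

print(old_key in tuned)  # False -- 12 rounds up to 16, so the lookup misses
print(new_key in tuned)  # True  -- exact head counts match the entry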

@@ -1472,7 +1473,7 @@ def get_tuned_block_sizes(
       max_num_batched_tokens,
       pages_per_seq,
   )
-  key = simplify_key(key)
+  simplified_key = simplify_key(key)
   device_name = get_device_name()

   # Default block sizes.
@@ -1500,8 +1501,12 @@ def compute_actual_vmem_bytes(num_kv_pages_per_blk):
     # OOM in vmem
     bkv, bq = (32, 32)
   elif device_name in TUNED_BLOCK_SIZES:
-    if key in TUNED_BLOCK_SIZES[device_name]:
-      bkv, bq = TUNED_BLOCK_SIZES[device_name][key]
+    if simplified_key in TUNED_BLOCK_SIZES[device_name]:
+      bkv, bq = TUNED_BLOCK_SIZES[device_name][simplified_key]
+    else:
+      logging.warning(
+          f"simplified_key({simplified_key}) is not in the ragged attention kernel's tuning table! The key before simplification is {key}"
+      )
   return (min(pages_per_seq, bkv), min(max_num_batched_tokens, bq))
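The lookup now warns on a table miss and falls back to the default block sizes instead of failing silently. A self-contained sketch of that pattern follows; the device name, key shape, and default values are illustrative assumptions, not the kernel's actual table.

import logging

# Illustrative stand-ins; the real defaults and TUNED_BLOCK_SIZES live in
# ragged_paged_attention_v2.py.
DEFAULT_BKV_BQ = (128, 32)
TUNED_BLOCK_SIZES = {
    "TPU v5e": {
        ("bfloat16", "bfloat16", 8, 1, 128, 16, 512, 2048): (256, 64),
    },
}

def lookup_block_sizes(device_name, simplified_key, key):
  bkv, bq = DEFAULT_BKV_BQ
  table = TUNED_BLOCK_SIZES.get(device_name, {})
  if simplified_key in table:
    bkv, bq = table[simplified_key]
  else:
    # Same fallback-plus-warning behavior the commit adds: keep the defaults
    # but make the table miss visible in the logs.
    logging.warning(
        "simplified_key(%s) is not in the ragged attention kernel's tuning "
        "table! The key before simplification is %s", simplified_key, key)
  return bkv, bq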

torchax/torchax/interop.py

Lines changed: 3 additions & 4 deletions
@@ -239,8 +239,7 @@ def j2t_autograd(fn, call_jax=call_jax):

   @wraps(fn)
   def inner(*args, **kwargs):
-    from jax.tree_util import tree_flatten, tree_unflatten
-    from jax.util import safe_zip
+    from jax.tree_util import tree_flatten

     class JaxFun(torch.autograd.Function):

@@ -275,8 +274,8 @@ def backward(ctx, *grad_out):
         # The subsequent gradients correspond to flat_inputs.
         # We need to put a None for inputs that did not require gradients.
         final_grads = [None]
-        for needs_grad, grad in safe_zip(ctx.needs_input_grad[1:],
-                                         input_grads_structured):
+        for needs_grad, grad in zip(
+            ctx.needs_input_grad[1:], input_grads_structured, strict=True):
           final_grads.append(grad if needs_grad else None)

         return tuple(final_grads)
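The switch from jax.util.safe_zip to the builtin zip(..., strict=True) (Python 3.10+) keeps the same length check without relying on the jax.util helper. A small illustration with made-up values:

# zip(..., strict=True) raises ValueError when the iterables differ in
# length, matching the check that jax.util.safe_zip performed.
needs_input_grad = (True, False, True)
input_grads = ("g0", None, "g2")

pairs = list(zip(needs_input_grad, input_grads, strict=True))  # OK: lengths match

try:
  list(zip(needs_input_grad, input_grads[:2], strict=True))
except ValueError as err:
  print(err)  # zip() argument 2 is shorter than argument 1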
