|
25 | 25 | from jax.experimental import pallas as pl
|
26 | 26 | from jax.experimental.pallas import tpu as pltpu
|
27 | 27 | import jax.numpy as jnp
|
| 28 | +import logging |
28 | 29 |
|
29 | 30 | DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max)
|
30 | 31 | # The page size is too small. We only have 32 SREGs in TC. If the pages
|
@@ -1421,12 +1422,12 @@ def simplify_key(key):
|
1421 | 1422 | return (
|
1422 | 1423 | jnp.dtype(q_dtype).name,
|
1423 | 1424 | jnp.dtype(kv_dtype).name,
|
1424 |
| - next_power_of_2(num_q_heads_per_blk), |
1425 |
| - next_power_of_2(num_kv_heads_per_blk), |
| 1425 | + num_q_heads_per_blk, |
| 1426 | + num_kv_heads_per_blk, |
1426 | 1427 | (head_dim + 127) // 128 * 128,
|
1427 | 1428 | next_power_of_2(page_size),
|
1428 | 1429 | next_power_of_2(max_num_batched_tokens),
|
1429 |
| - next_power_of_2(page_size * pages_per_seq), |
| 1430 | + page_size * pages_per_seq, |
1430 | 1431 | )
|
1431 | 1432 |
|
1432 | 1433 |
|
@@ -1472,7 +1473,7 @@ def get_tuned_block_sizes(
|
1472 | 1473 | max_num_batched_tokens,
|
1473 | 1474 | pages_per_seq,
|
1474 | 1475 | )
|
1475 |
| - key = simplify_key(key) |
| 1476 | + simplified_key = simplify_key(key) |
1476 | 1477 | device_name = get_device_name()
|
1477 | 1478 |
|
1478 | 1479 | # Default block sizes.
|
@@ -1500,8 +1501,12 @@ def compute_actual_vmem_bytes(num_kv_pages_per_blk):
|
1500 | 1501 | # OOM in vmem
|
1501 | 1502 | bkv, bq = (32, 32)
|
1502 | 1503 | elif device_name in TUNED_BLOCK_SIZES:
|
1503 |
| - if key in TUNED_BLOCK_SIZES[device_name]: |
1504 |
| - bkv, bq = TUNED_BLOCK_SIZES[device_name][key] |
| 1504 | + if simplified_key in TUNED_BLOCK_SIZES[device_name]: |
| 1505 | + bkv, bq = TUNED_BLOCK_SIZES[device_name][simplified_key] |
| 1506 | + else: |
| 1507 | + logging.warning( |
| 1508 | + f"simplified_key({simplified_key}) is not in ragged attention kernel's tuning table!, the key before simpilification is {key}" |
| 1509 | + ) |
1505 | 1510 | return (min(pages_per_seq, bkv), min(max_num_batched_tokens, bq))
|
1506 | 1511 |
|
1507 | 1512 |
|
|
0 commit comments