[example] add jagged_softmax example #480


Merged · 11 commits · Aug 14, 2025
5 changes: 5 additions & 0 deletions benchmarks/run.py
@@ -101,6 +101,11 @@
"examples.layer_norm",
"layer_norm_fwd",
),
"jagged_softmax": (
"tritonbench.operators.jagged_softmax.operator",
"examples.jagged_softmax",
"jagged_softmax_tritonbench",
),
# Multiple kernel variants:
"gemm": (
"tritonbench.operators.gemm.operator",
186 changes: 186 additions & 0 deletions examples/jagged_softmax.py
@@ -0,0 +1,186 @@
"""
Jagged Softmax Example
======================

This example demonstrates how to compute the softmax over each batch of a jagged tensor using Helion.
"""

# %%
# Imports
# -------
from __future__ import annotations

import itertools

import torch

import helion
from helion._testing import run_example
import helion.language as hl


# %%
# Reference Implementation
# --------------------
def reference_jagged_softmax_pytorch(
x_data: torch.Tensor,
x_offsets: torch.Tensor,
) -> torch.Tensor:
"""
PyTorch reference implementation for jagged softmax.

Args:
x_data: 2-D tensor holding all elements
x_offsets: Offsets tensor for row indexing

Returns:
Tensor containing the per-batch softmax scores (same shape as x_data)
"""
vals = []
for i, j in itertools.pairwise(x_offsets):
y = x_data[i:j]
vals.append(torch.softmax(y, dim=0))
return torch.cat(vals, dim=0)
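
# Example (hypothetical values): with ``x_offsets = torch.tensor([0, 2, 5])``
# and ``x_data`` of shape (5, M), ``itertools.pairwise`` yields the index
# pairs (0, 2) and (2, 5), so each batch's rows receive an independent
# softmax along ``dim=0`` before being concatenated back together.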


# %%
# Jagged Softmax Kernel
# ---------------
@helion.kernel()
def jagged_softmax_kernel(
x_data: torch.Tensor,
x_offsets: torch.Tensor,
) -> torch.Tensor:
"""
Compute the per-batch softmax in a jagged tensor.

Args:
x_data: 2-D tensor of shape (total_elements, max_M) holding all elements
x_offsets: (num_rows + 1) tensor. Row i is the slice
x_data[x_offsets[i] : x_offsets[i+1], :]

Returns:
2-D tensor of shape (total_elements, max_M), containing the per-batch softmax scores.
"""
N = int(x_offsets[-1].item())
num_rows, M = x_offsets.size(0) - 1, x_data.size(1)
out = torch.zeros(N * M, dtype=x_data.dtype, device=x_data.device)

# flatten
x_flat = x_data.view(-1)

for tile_b in hl.tile(num_rows):
starts = x_offsets[tile_b]
ends = x_offsets[tile_b.index + 1]
seqlens = ends - starts
max_seqlen = seqlens.amax()

for tile_m in hl.tile(M):
block_max = hl.full([tile_b, tile_m], 0.0, dtype=x_data.dtype)
block_new_max = hl.full([tile_b, tile_m], 0.0, dtype=x_data.dtype)
block_L = hl.full([tile_b, tile_m], 0.0, dtype=x_data.dtype)

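            # First pass (online softmax): stream over the sequence in tiles,
            # keeping a running max and a rescaled running normalizer for
            # each (batch, feature) pair.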
for tile_k in hl.tile(max_seqlen):
base_indices = starts[:, None] + tile_k.index[None, :]
flat_indices = (
base_indices[:, :, None] * M + tile_m.index[None, None, :]
)
row_mask = tile_k.index[None, :] < seqlens[:, None]
combined_mask = row_mask[:, :, None] & (tile_m.index < M)[None, None, :]
x_slice = hl.load(
x_flat,
[flat_indices],
extra_mask=combined_mask,
)
slice_max = torch.where(combined_mask, x_slice, float("-inf")).amax(
dim=1
)
block_new_max = torch.maximum(block_max, slice_max)
block_L *= torch.exp(block_max - block_new_max)
block_L += torch.exp(
torch.where(
combined_mask,
x_slice - block_new_max[:, None, :],
float("-inf"),
)
).sum(dim=1)
block_max = block_new_max

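            # Second pass: re-read each tile and normalize with the final
            # running max and normalizer.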
for tile_k in hl.tile(max_seqlen):
base_indices = starts[:, None] + tile_k.index[None, :]
flat_indices = (
base_indices[:, :, None] * M + tile_m.index[None, None, :]
)
row_mask = tile_k.index[None, :] < seqlens[:, None]
combined_mask = row_mask[:, :, None] & (tile_m.index < M)[None, None, :]
x_slice = hl.load(
x_flat,
[flat_indices],
extra_mask=combined_mask,
)
block_out = (
torch.exp(x_slice - block_max[:, None, :]) / block_L[:, None, :]
)
hl.store(
out,
[flat_indices],
block_out,
extra_mask=combined_mask,
)
Contributor Author:

wasn't sure how to write this without the store, and it seems this store breaks ref eager mode testing

Contributor:

hmm that's weird - do you mind filing an issue for the ref eager mode problem? thanks!

Contributor Author:

Yep, filed one at #496


return out.reshape(N, M)


# %%
# Benchmark Wrapper
# --------------
def jagged_softmax_tritonbench(
x: torch.Tensor, B: int, M: int, seqlen: int, sparsity: float
) -> torch.Tensor:
"""
Wrapper for tritonbench that matches the expected interface.

Args:
x: Nested tensor in jagged format with shape (B, *, M)
B: Batch size (unused)
M: Number of features (unused)
seqlen: Maximum sequence length (unused)
sparsity: Sparsity factor (unused)

Returns:
Tensor of shape (N, M), where N = total number of rows in the jagged tensor
"""
return jagged_softmax_kernel(x._values, x._offsets) # pyright: ignore[reportArgumentType, reportAttributeAccessIssue]

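# For reference (assuming PyTorch's nested-tensor API with the jagged
# layout), an input like ``x`` above can be constructed along these lines:
#
#     x = torch.nested.nested_tensor(
#         [torch.randn(3, M), torch.randn(7, M)], layout=torch.jagged
#     )
#
# after which ``x._values`` is the packed (total_elements, M) tensor and
# ``x._offsets`` the (num_rows + 1) offsets vector consumed by the kernel.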

# %%
# Main Function
# -----------
def main() -> None:
"""
Main entry point for jagged softmax kernel verification.
"""
num_rows, max_cols = 512, 64
device = "cuda"

lengths = torch.randint(1, max_cols + 1, (num_rows,), device=device)
x_offsets = torch.cat(
[torch.zeros(1, dtype=torch.long, device=device), torch.cumsum(lengths, dim=0)]
)
nnz = int(x_offsets[-1])
M = 128 # number of features
x_data = torch.randn(nnz, M, dtype=torch.float32, device=device)

out_eager = reference_jagged_softmax_pytorch(x_data, x_offsets)
out_hl = jagged_softmax_kernel(x_data, x_offsets)
assert torch.allclose(out_eager, out_hl)

run_example(
lambda x, o: jagged_softmax_kernel(x, o),
lambda x, o: reference_jagged_softmax_pytorch(x, o),
(x_data, x_offsets),
)


if __name__ == "__main__":
main()
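
The two hl.tile(max_seqlen) loops in jagged_softmax_kernel implement the standard online-softmax recurrence: the first pass keeps a running max m and a running normalizer L, rescaling L whenever m grows; the second pass re-reads the data and normalizes. A minimal single-row sketch of that recurrence in plain PyTorch (function name and block size hypothetical, not part of this PR):

import torch

def online_softmax_1d(x: torch.Tensor, block: int = 16) -> torch.Tensor:
    # Running max m and running normalizer L, updated one tile at a time.
    m = torch.tensor(float("-inf"))
    L = torch.tensor(0.0)
    for start in range(0, x.numel(), block):
        xb = x[start : start + block]
        m_new = torch.maximum(m, xb.max())
        # Rescale the old normalizer to the new max, then fold in this tile.
        L = L * torch.exp(m - m_new) + torch.exp(xb - m_new).sum()
        m = m_new
    # Second pass: normalize every element against the final (m, L).
    return torch.exp(x - m) / L

x = torch.randn(100)
assert torch.allclose(online_softmax_1d(x), torch.softmax(x, dim=0), atol=1e-6)

One difference from the textbook version: the kernel seeds its running max at 0.0 rather than -inf. Because softmax is shift-invariant and max(0, true_max) keeps every exponent non-positive, this is still numerically safe; at worst it costs some underflow when every input in a batch is negative.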
144 changes: 144 additions & 0 deletions test/test_examples.expected
@@ -1014,6 +1014,150 @@ def jagged_mean_kernel(x_data: torch.Tensor, x_offsets: torch.Tensor, x_feature_
_launcher(_helion_jagged_mean_kernel, (triton.cdiv(num_rows, _BLOCK_SIZE_0),), x_offsets, x_feature_counts, x_flat, out, out.stride(0), out.stride(1), x_feature_counts.stride(0), x_flat.stride(0), x_offsets.stride(0), num_rows, max_M, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
return out

--- assertExpectedJournal(TestExamples.test_jagged_softmax)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers
from torch._inductor.runtime.triton_helpers import math as tl_math
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_jagged_softmax_kernel(x_offsets, x_flat, out, out_stride_0, x_flat_stride_0, x_offsets_stride_0, num_rows, M, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
pid_0 = tl.program_id(0)
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
mask_0 = indices_0 < num_rows
starts = tl.load(x_offsets + indices_0 * x_offsets_stride_0, mask_0, other=0)
v_0 = tl.full([], 1, tl.int32)
v_1 = indices_0 + v_0
ends = tl.load(x_offsets + v_1 * x_offsets_stride_0, mask_0, other=0)
v_2 = ends - starts
_mask_to = tl.where(mask_0, v_2, -9223372036854775808)
max_seqlen = tl.max(_mask_to, 0)
for offset_1 in tl.range(0, M.to(tl.int32), _BLOCK_SIZE_1):
indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
mask_1 = indices_1 < M
max_seqlen_copy = max_seqlen
starts_copy = starts
v_2_copy = v_2
max_seqlen_copy_0 = max_seqlen_copy
starts_copy_0 = starts_copy
v_2_copy_0 = v_2_copy
block_max = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
block_new_max = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
block_L = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
for offset_2 in tl.range(0, max_seqlen_copy_0.to(tl.int32), _BLOCK_SIZE_2):
indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
mask_2 = indices_2 < max_seqlen_copy_0
starts_copy_0_copy = starts_copy_0
v_2_copy_0_copy = v_2_copy_0
block_max_copy = block_max
block_L_copy = block_L
starts_copy_0_copy_0 = starts_copy_0_copy
v_2_copy_0_copy_0 = v_2_copy_0_copy
block_max_copy_0 = block_max_copy
block_L_copy_0 = block_L_copy
subscript = starts_copy_0_copy_0[:, None]
subscript_1 = indices_2[None, :]
v_3 = subscript_1.to(tl.int64)
v_4 = subscript + v_3
subscript_2 = v_4[:, :, None]
v_5 = subscript_2 * M
subscript_3 = indices_1[None, None, :]
v_6 = subscript_3.to(tl.int64)
v_7 = v_5 + v_6
subscript_4 = indices_2[None, :]
subscript_5 = v_2_copy_0_copy_0[:, None]
v_8 = subscript_4.to(tl.int64)
v_9 = v_8 < subscript_5
subscript_6 = v_9[:, :, None]
v_10 = M.to(tl.int32)
v_11 = indices_1 < v_10
subscript_7 = v_11[None, None, :]
v_12 = subscript_6 & subscript_7
x_slice = tl.load(x_flat + v_7 * x_flat_stride_0, mask_0[:, None, None] & mask_2[None, :, None] & mask_1[None, None, :] & v_12, other=0)
v_13 = float('-inf')
v_14 = v_13[None, None, None]
v_15 = tl.where(v_12, x_slice, v_14)
_mask_to_1 = tl.where(mask_0[:, None, None] & mask_2[None, :, None] & mask_1[None, None, :], v_15, float('-inf'))
slice_max = tl.max(_mask_to_1, 1)
block_new_max = triton_helpers.maximum(block_max_copy_0, slice_max)
v_17 = block_max_copy_0 - block_new_max
v_18 = tl_math.exp(v_17)
v_19 = block_L_copy_0 * v_18
subscript_8 = block_new_max[:, None, :]
v_20 = x_slice - subscript_8
v_21 = float('-inf')
v_22 = v_21[None, None, None]
v_23 = tl.where(v_12, v_20, v_22)
v_24 = tl_math.exp(v_23)
_mask_to_2 = tl.where(mask_0[:, None, None] & mask_2[None, :, None] & mask_1[None, None, :], v_24, 0)
sum_1 = tl.sum(_mask_to_2, 1)
block_L = v_19 + sum_1
block_max = block_new_max
for offset_3 in tl.range(0, max_seqlen_copy_0.to(tl.int32), _BLOCK_SIZE_3):
indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
mask_3 = indices_3 < max_seqlen_copy_0
starts_copy_0_copy_1 = starts_copy_0
v_2_copy_0_copy_1 = v_2_copy_0
block_max_copy_1 = block_max
block_L_copy_1 = block_L
starts_copy_0_copy_1_0 = starts_copy_0_copy_1
v_2_copy_0_copy_1_0 = v_2_copy_0_copy_1
block_max_copy_1_0 = block_max_copy_1
block_L_copy_1_0 = block_L_copy_1
subscript_9 = starts_copy_0_copy_1_0[:, None]
subscript_10 = indices_3[None, :]
v_26 = subscript_10.to(tl.int64)
v_27 = subscript_9 + v_26
subscript_11 = v_27[:, :, None]
v_28 = subscript_11 * M
subscript_12 = indices_1[None, None, :]
v_29 = subscript_12.to(tl.int64)
v_30 = v_28 + v_29
subscript_13 = indices_3[None, :]
subscript_14 = v_2_copy_0_copy_1_0[:, None]
v_31 = subscript_13.to(tl.int64)
v_32 = v_31 < subscript_14
subscript_15 = v_32[:, :, None]
v_33 = M.to(tl.int32)
v_34 = indices_1 < v_33
subscript_16 = v_34[None, None, :]
v_35 = subscript_15 & subscript_16
x_slice_1 = tl.load(x_flat + v_30 * x_flat_stride_0, mask_0[:, None, None] & mask_3[None, :, None] & mask_1[None, None, :] & v_35, other=0)
subscript_17 = block_max_copy_1_0[:, None, :]
v_36 = x_slice_1 - subscript_17
v_37 = tl_math.exp(v_36)
subscript_18 = block_L_copy_1_0[:, None, :]
v_38 = v_37 / subscript_18
tl.store(out + v_30 * out_stride_0, v_38, mask_0[:, None, None] & mask_3[None, :, None] & mask_1[None, None, :] & v_35)

def jagged_softmax_kernel(x_data: torch.Tensor, x_offsets: torch.Tensor, *, _launcher=_default_launcher):
"""
Compute the per-batch softmax in a jagged tensor.

Args:
x_data: 2-D tensor of shape (total_elements, max_M) holding all elements
x_offsets: (num_rows + 1) tensor. Row i is the slice
x_data[x_offsets[i] : x_offsets[i+1], :]

Returns:
2-D tensor of shape (total_elements, max_M), containing the per-batch softmax scores.
"""
N = int(x_offsets[-1].item())
num_rows, M = (x_offsets.size(0) - 1, x_data.size(1))
out = torch.zeros(N * M, dtype=x_data.dtype, device=x_data.device)
x_flat = x_data.view(-1)
_BLOCK_SIZE_0 = 16
_BLOCK_SIZE_1 = 8
_BLOCK_SIZE_2 = 16
_BLOCK_SIZE_3 = 16
_launcher(_helion_jagged_softmax_kernel, (triton.cdiv(num_rows, _BLOCK_SIZE_0),), x_offsets, x_flat, out, out.stride(0), x_flat.stride(0), x_offsets.stride(0), num_rows, M, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
return out.reshape(N, M)

--- assertExpectedJournal(TestExamples.test_layernorm)
from __future__ import annotations

29 changes: 29 additions & 0 deletions test/test_examples.py
@@ -629,6 +629,35 @@ def test_layernorm(self):
)
)

@skipIfRefEager("ref eager mode hits CUDA indexing error with hl.store")
def test_jagged_softmax(self):
num_rows, max_cols = 128, 64
M = 8 # number of features
lengths = torch.randint(1, max_cols + 1, (num_rows,), device=DEVICE)
x_offsets = torch.cat(
[
torch.zeros(1, dtype=torch.long, device=DEVICE),
torch.cumsum(lengths, dim=0),
]
)
nnz = int(x_offsets[-1])
x_data = torch.randn(nnz, M, dtype=torch.float32, device=DEVICE)
args = (x_data, x_offsets)

# Import and use the reference implementation
mod = import_path(EXAMPLES_DIR / "jagged_softmax.py")
expected = mod.reference_jagged_softmax_pytorch(x_data, x_offsets)

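        # block_sizes below pins the kernel's four tile sizes (_BLOCK_SIZE_0
        # through _BLOCK_SIZE_3 in the expected journal), keeping the
        # generated Triton code stable for the journal comparison.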
self.assertExpectedJournal(
check_example(
"jagged_softmax",
args,
expected,
fn_name="jagged_softmax_kernel",
block_sizes=[16, 8, 16, 16],
)
)


if __name__ == "__main__":
unittest.main()