-
Notifications
You must be signed in to change notification settings - Fork 84
Add op and test for chunk_local_cumsum_scalar #199
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
1StepForever
wants to merge
5
commits into
sgl-project:main
Choose a base branch
from
1StepForever:www/chunk_cumsum
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,166 @@ | ||
| ''' | ||
| 1. make block larger to accelerate | ||
| 2. cumsum at 0-axis by transpose, (later could be removed as currently the none 0-axis cumsum | ||
| is unrolled by for loop and low-efficient) | ||
| ''' | ||
|
|
||
| from inspect import signature | ||
| from typing import Optional | ||
|
|
||
| import torch | ||
| import torch.nn.functional as F | ||
| import triton | ||
| import triton.language as tl | ||
|
|
||
| from sgl_kernel_npu.utils.index import prepare_chunk_indices | ||
|
|
||
|
|
||
| @triton.heuristics( | ||
| { | ||
| "HAS_SCALE": lambda args: args["scale"] is not None, | ||
| "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, | ||
| } | ||
| ) | ||
| @triton.jit(do_not_specialize=["T"]) | ||
| def chunk_local_cumsum_scalar_kernel( | ||
| s, | ||
| o, | ||
| scale, | ||
| cu_seqlens, | ||
| chunk_indices, | ||
| T, | ||
| B: tl.constexpr, | ||
| H: tl.constexpr, | ||
| BLOCK_T: tl.constexpr, | ||
| REVERSE: tl.constexpr, | ||
| HAS_SCALE: tl.constexpr, | ||
| IS_VARLEN: tl.constexpr, | ||
| HEAD_FIRST: tl.constexpr, | ||
| CHUNK_SIZE: tl.constexpr=64, | ||
| ): | ||
| """ | ||
| Computes a chunk-wise cumulative sum (optionally reversed) over input tensor `s` and writes the result to `o`. | ||
| This kernel operates on sequences that may be either fixed-length (batched) or variable-length (packed).. | ||
|
|
||
| The layout of the input/output tensors depends on the `HEAD_FIRST` flag: | ||
| - If `HEAD_FIRST=True`: tensors are shaped `(B, H, T)` | ||
| - If `HEAD_FIRST=False`: tensors are shaped `(B, T, H)` | ||
|
|
||
| For variable-length sequences (`IS_VARLEN=True`), sequence boundaries are defined by `cu_seqlens`, | ||
| and valid computation blocks are specified via `chunk_indices`. | ||
|
|
||
| Args: | ||
| s (tl.pointer): Input tensor pointer. Shape depends on `HEAD_FIRST` and batching mode. | ||
| o (tl.pointer): Output tensor pointer. Same shape and layout as `s`. | ||
| scale (float or None): Optional scalar multiplier applied to the output if `HAS_SCALE=True`. | ||
| cu_seqlens (tl.pointer or None): Cumulative sequence lengths for variable-length batching. | ||
| Required if `IS_VARLEN=True`. | ||
| chunk_indices (tl.pointer or None): Pairs of (sequence_id, block_id) indicating which | ||
| sequence and which time-block to process. | ||
| Only used when `IS_VARLEN=True`. | ||
| T (int): Total sequence length per batch (for fixed-length) or max sequence length (for varlen). | ||
| B (tl.constexpr): Batch size (number of sequences). | ||
| H (tl.constexpr): Number of heads or feature dimension. | ||
| BLOCK_T (tl.constexpr): Number of time steps processed per kernel launch per batch item. | ||
| REVERSE (tl.constexpr): If True, computes reverse cumulative sum within each chunk. | ||
| HAS_SCALE (tl.constexpr): If True, applies `scale` to the output. | ||
| IS_VARLEN (tl.constexpr): If True, uses packed variable-length layout via `cu_seqlens`. | ||
| HEAD_FIRST (tl.constexpr): Controls tensor memory layout (head-first vs time-first). | ||
| CHUNK_SIZE (tl.constexpr, optional): Size of each local chunk for cumsum. Default: 64. | ||
|
|
||
| Notes: | ||
| - The kernel assumes `BLOCK_T` is divisible into chunks of `CHUNK_SIZE` (padding handled internally). | ||
| - Boundary checks are applied during load/store to avoid out-of-bounds access. | ||
| - All computations are performed in fp32 for numerical stability, then cast back to input dtype. | ||
|
|
||
| - reverse cumsum requires T is multiple of CHUNK_SIZE, same as orig code. | ||
| """ | ||
| tl.static_assert(BLOCK_T % CHUNK_SIZE == 0) | ||
|
|
||
| i_block, i_b = tl.program_id(0), tl.program_id(1) | ||
| N_CHUNKS: tl.constexpr = BLOCK_T // CHUNK_SIZE | ||
|
|
||
| if IS_VARLEN: | ||
| i_s, i_block = tl.load(chunk_indices + i_block * 2).to(tl.int32), tl.load( | ||
| chunk_indices + i_block * 2 + 1 | ||
| ).to(tl.int32) | ||
|
|
||
| bos, eos = tl.load(cu_seqlens + i_s).to(tl.int32), tl.load( | ||
| cu_seqlens + i_s + 1 | ||
| ).to(tl.int32) | ||
| T = eos - bos | ||
| else: | ||
| bos, eos = i_b * T, i_b * T + T | ||
|
|
||
|
|
||
| if HEAD_FIRST: | ||
| ptr_s = tl.make_block_ptr(s + bos * H, (H, T), (T, 1), (0, i_block * BLOCK_T), (H, BLOCK_T), (1, 0)) | ||
| ptr_o = tl.make_block_ptr(o + bos * H, (H, T), (T, 1), (0, i_block * BLOCK_T), (H, BLOCK_T), (1, 0)) | ||
| b_s = tl.load(ptr_s, boundary_check=(0, 1)).to(tl.float32) | ||
|
|
||
| b_s = tl.reshape(b_s, (H, N_CHUNKS, CHUNK_SIZE)) | ||
1StepForever marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| b_s = tl.trans(b_s, (2, 0, 1)) | ||
| b_o = tl.cumsum(b_s, axis=0) | ||
| if REVERSE: | ||
| b_z = tl.sum(b_s, axis=0) | ||
| b_o = -b_o + b_z[None, :, :] + b_s | ||
| if HAS_SCALE: | ||
| b_o *= scale | ||
| b_o = tl.trans(b_o, (1, 2, 0)) | ||
| b_o = tl.reshape(b_o, (H, BLOCK_T)) | ||
| tl.store(ptr_o, b_o.to(s.dtype.element_ty), boundary_check=(0, 1)) | ||
|
|
||
| else: | ||
| ptr_s = tl.make_block_ptr(s + bos * H, (T, H), (H, 1), (i_block * BLOCK_T, 0), (BLOCK_T, H), (1, 0)) | ||
| ptr_o = tl.make_block_ptr(o + bos * H, (T, H), (H, 1), (i_block * BLOCK_T, 0), (BLOCK_T, H), (1, 0)) | ||
| b_s = tl.load(ptr_s, boundary_check=(0, 1)).to(tl.float32) | ||
| b_s = tl.reshape(b_s, (N_CHUNKS, CHUNK_SIZE, H)) | ||
| b_s = tl.trans(b_s, (1, 0, 2)) | ||
| b_o = tl.cumsum(b_s, axis=0) | ||
| if REVERSE: | ||
| b_z = tl.sum(b_s, axis=0) | ||
| b_o = -b_o + b_z[None, :, :]+ b_s | ||
| if HAS_SCALE: | ||
| b_o *= scale | ||
| b_o = tl.trans(b_o, (1, 0, 2)) | ||
| b_o = tl.reshape(b_o, (BLOCK_T, H)) | ||
| tl.store(ptr_o, b_o.to(s.dtype.element_ty), boundary_check=(0, 1)) | ||
| return | ||
|
|
||
|
|
||
| def chunk_local_cumsum_scalar_npu( | ||
| g: torch.Tensor, | ||
| chunk_size: int, | ||
| reverse: bool = False, | ||
| scale: float = None, | ||
| cu_seqlens: Optional[torch.Tensor] = None, | ||
| head_first: bool = False, | ||
| output_dtype: Optional[torch.dtype] = torch.float, | ||
| ) -> torch.Tensor: | ||
| if head_first: | ||
| B, H, T = g.shape | ||
| else: | ||
| B, T, H = g.shape | ||
| assert chunk_size == 2 ** ( | ||
| chunk_size.bit_length() - 1 | ||
| ), "chunk_size must be a power of 2" | ||
| OPTIM_BLOCK_SIZE = triton.next_power_of_2((2 ** 18) // (H * chunk_size)) | ||
| block_indices = prepare_chunk_indices(cu_seqlens, chunk_size=OPTIM_BLOCK_SIZE) if cu_seqlens is not None else None | ||
| num_blocks = len(block_indices) if cu_seqlens is not None else triton.cdiv(T, OPTIM_BLOCK_SIZE) | ||
| g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype) | ||
| grid = (num_blocks, B) | ||
| chunk_local_cumsum_scalar_kernel[grid]( | ||
| s=g_org, | ||
| o=g, | ||
| scale=scale, | ||
| cu_seqlens=cu_seqlens, | ||
| chunk_indices=block_indices, | ||
| T=T, | ||
| B=B, | ||
| H=H, | ||
| BLOCK_T = OPTIM_BLOCK_SIZE, | ||
| CHUNK_SIZE =chunk_size, | ||
| HEAD_FIRST=head_first, | ||
| REVERSE=reverse, | ||
| ) | ||
| return g | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,19 @@ | ||
| import torch | ||
| import torch.nn.functional as F | ||
| import triton | ||
|
|
||
|
|
||
| def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor: | ||
| return cu_seqlens[1:] - cu_seqlens[:-1] | ||
|
|
||
| def prepare_chunk_indices( | ||
| cu_seqlens: torch.LongTensor, chunk_size: int | ||
| ) -> torch.LongTensor: | ||
| num_chunks = triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist() | ||
| indices = torch.cat( | ||
| [ | ||
| torch.arange(n) | ||
| for n in num_chunks | ||
| ] | ||
| ) | ||
| return torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(cu_seqlens) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.