
Commit 1b9ba25

xslingcn and yzh119 authored
bugfix: softmax NaN results caused by large -inf masks (#1178)
## 📌 Description

For inputs with many `-inf` masks (e.g. top-k masked logits), our current softmax kernel produces all-`NaN` results when a thread sees a slice of input made up entirely of `-inf`s. This PR adds checks to fix it.

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

---------

Co-authored-by: yzh119 <[email protected]>
1 parent f70b66d commit 1b9ba25
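To make the failure mode concrete, here is a minimal Python sketch (illustrative only, not FlashInfer code): once a thread's slice of logits is entirely `-inf`, the slice maximum is also `-inf`, so the max-shifted exponentials evaluate `exp(-inf - (-inf))`, which is `NaN`, and that `NaN` poisons the block reduction.

```python
import math

# Hypothetical slice seen by one thread after heavy -inf masking
# (e.g. top-k masked logits): every element is -inf.
neg_inf = float("-inf")
slice_logits = [neg_inf] * 4

block_max = max(slice_logits)                    # -inf
shifted = [x - block_max for x in slice_logits]  # (-inf) - (-inf) == nan
exps = [math.exp(s) for s in shifted]            # exp(nan) == nan
print(sum(exps))                                 # nan -> poisons the softmax denominator
```

The fix below guards this case: when the block maximum is `-inf`, the block contains no finite logits and is skipped when updating the running max/denominator.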

File tree

- include/flashinfer/sampling.cuh
- tests/test_sampling.py

2 files changed: +50 -47 lines changed

include/flashinfer/sampling.cuh

Lines changed: 40 additions & 45 deletions
@@ -358,33 +358,31 @@ __global__ void OnlineSoftmaxFusedKernel(DType* logits, DType* output, DType* te
     __syncthreads();
     block_max = temp_storage.shared_state.max_val;

-    float thread_sum = 0.0f;
+    // if block_max is -inf, then this block contains all -inf values, so we can skip updating
+    if (!isinf(block_max)) {
+      float thread_sum = 0.0f;
 #pragma unroll
-    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
-      thread_sum += __expf(logits_vec[j] - block_max);  // e^(-inf) is safe to add
-    }
+      for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+        thread_sum += __expf(logits_vec[j] - block_max);
+      }

-    float block_sum =
-        cub::BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce).Sum(thread_sum);
-    if (tx == 0) {
-      temp_storage.shared_state.denominator = block_sum;
-    }
-    __syncthreads();
-    block_sum = temp_storage.shared_state.denominator;
+      float block_sum =
+          cub::BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce).Sum(thread_sum);
+      __syncthreads();

-    if (tx == 0) {
-      float new_max = max(running_max, block_max);
-      running_denominator = running_denominator * __expf(running_max - new_max) +
-                            block_sum * __expf(block_max - new_max);
-      running_max = new_max;
+      if (tx == 0) {
+        float new_max = max(running_max, block_max);
+        running_denominator = running_denominator * __expf(running_max - new_max) +
+                              block_sum * __expf(block_max - new_max);
+        running_max = new_max;

-      temp_storage.shared_state.max_val = running_max;
-      temp_storage.shared_state.denominator = running_denominator;
+        temp_storage.shared_state.max_val = running_max;
+        temp_storage.shared_state.denominator = running_denominator;
+      }
+      __syncthreads();
+      running_max = temp_storage.shared_state.max_val;
+      running_denominator = temp_storage.shared_state.denominator;
     }
-    __syncthreads();
-
-    running_max = temp_storage.shared_state.max_val;
-    running_denominator = temp_storage.shared_state.denominator;
   }

   const float final_max = running_max;
@@ -476,34 +474,31 @@ __global__ void OnlineSoftmaxMapKernel(DType* logits, PartialSoftmaxResult* part
     __syncthreads();
     block_max = temp_storage.shared_state.max_val;

-    float thread_sum = 0.0f;
+    // if block_max is -inf, then this block contains all -inf values, so we can skip updating
+    if (!isinf(block_max)) {
+      float thread_sum = 0.0f;
 #pragma unroll
-    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
-      thread_sum += __expf(logits_vec[j] - block_max);
-    }
-
-    float block_sum =
-        cub::BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce).Sum(thread_sum);
+      for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+        thread_sum += __expf(logits_vec[j] - block_max);
+      }

-    if (tx == 0) {
-      temp_storage.shared_state.denominator = block_sum;
-    }
-    __syncthreads();
-    block_sum = temp_storage.shared_state.denominator;
+      float block_sum =
+          cub::BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce).Sum(thread_sum);
+      __syncthreads();

-    if (tx == 0) {
-      float new_max = max(running_max, block_max);
-      running_denominator = running_denominator * __expf(running_max - new_max) +
-                            block_sum * __expf(block_max - new_max);
-      running_max = new_max;
+      if (tx == 0) {
+        float new_max = max(running_max, block_max);
+        running_denominator = running_denominator * __expf(running_max - new_max) +
+                              block_sum * __expf(block_max - new_max);
+        running_max = new_max;

-      temp_storage.shared_state.max_val = running_max;
-      temp_storage.shared_state.denominator = running_denominator;
+        temp_storage.shared_state.max_val = running_max;
+        temp_storage.shared_state.denominator = running_denominator;
+      }
+      __syncthreads();
+      running_max = temp_storage.shared_state.max_val;
+      running_denominator = temp_storage.shared_state.denominator;
     }
-    __syncthreads();
-
-    running_max = temp_storage.shared_state.max_val;
-    running_denominator = temp_storage.shared_state.denominator;
   }

   if (tx == 0) {
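For readers who want the guarded merge outside CUDA, the following NumPy sketch shows the same online-softmax update as a serial loop (an illustration only, not the library's implementation). The key point matches the change in both kernels: a chunk whose maximum is `-inf` contains no finite logits and contributes nothing, so the running `(max, denominator)` pair is left untouched instead of being updated through `exp(-inf - (-inf))`.

```python
import numpy as np

def online_softmax(logits: np.ndarray, chunk: int = 4) -> np.ndarray:
    """Serial sketch of the chunked online softmax used by the kernels above."""
    running_max, running_den = -np.inf, 0.0
    for start in range(0, logits.size, chunk):
        block = logits[start:start + chunk]
        block_max = block.max()
        # Mirror of the fix: an all -inf chunk has no finite logits, so skip the
        # update instead of evaluating exp(-inf - (-inf)) == nan.
        if np.isinf(block_max):
            continue
        block_sum = np.exp(block - block_max).sum()
        new_max = max(running_max, block_max)
        running_den = (running_den * np.exp(running_max - new_max)
                       + block_sum * np.exp(block_max - new_max))
        running_max = new_max
    return np.exp(logits - running_max) / running_den

# One chunk ([-inf] * 4) is entirely masked and exercises the guard.
x = np.array([1.0, 2.0, 0.5] + [-np.inf] * 8 + [3.0])
ref = np.exp(x - x.max()) / np.exp(x - x.max()).sum()
assert np.allclose(online_softmax(x), ref)
```

A side effect visible in the diff: the rewritten kernels only consume `block_sum` on thread 0 (which holds the CUB reduction result), so the earlier shared-memory broadcast of the per-block denominator is no longer needed.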

tests/test_sampling.py

Lines changed: 10 additions & 2 deletions
@@ -50,9 +50,17 @@ def gumbel_noise(shape, device):
 )
 @pytest.mark.parametrize("temperature", [1.0, 0.5, 0.1])
 @pytest.mark.parametrize("temperature_arr", [True, False])
-def test_softmax(batch_size, vocab_size, distribution, temperature, temperature_arr):
+@pytest.mark.parametrize("neg_inf_input", [True, False])
+def test_softmax(
+    batch_size, vocab_size, distribution, temperature, temperature_arr, neg_inf_input
+):
     torch.manual_seed(42)
     logits = distribution((batch_size, vocab_size), "cuda:0")
+    if neg_inf_input:
+        # assign random logits to -inf
+        num_inf = torch.randint(0, logits.numel() - 1, (), device=logits.device).item()
+        inf_idx = torch.randperm(logits.numel(), device=logits.device)[:num_inf]
+        logits.view(-1).index_fill_(0, inf_idx, float("-inf"))

     if temperature_arr:
         temperature_arr = torch.full((batch_size,), temperature, device="cuda:0")
@@ -64,7 +72,7 @@ def test_softmax(batch_size, vocab_size, distribution, temperature, temperature_

     probs_ref = torch.softmax(logits_scaled, dim=-1)

-    assert torch.allclose(probs, probs_ref, atol=1e-3)
+    assert torch.allclose(probs, probs_ref, atol=1e-5)


 @pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
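As a usage-level illustration of the "top-k masked logits" scenario from the description (plain PyTorch only, independent of FlashInfer's API; the shapes and `k` are arbitrary), top-k masking produces exactly the heavily `-inf` rows that the new `neg_inf_input` parametrization exercises, and `torch.softmax` remains the finite reference the test compares against:

```python
import torch

# Top-k masking: everything outside the k largest logits becomes -inf,
# so most of each row is -inf (the case that used to produce NaNs).
batch_size, vocab_size, k = 4, 128, 8
logits = torch.randn(batch_size, vocab_size)
topk_val, topk_idx = torch.topk(logits, k, dim=-1)
masked = torch.full_like(logits, float("-inf")).scatter(-1, topk_idx, topk_val)

# The reference used in test_softmax: finite because each row keeps k finite logits.
probs_ref = torch.softmax(masked, dim=-1)
assert torch.isfinite(probs_ref).all()
assert torch.allclose(probs_ref.sum(-1), torch.ones(batch_size))
```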
