
Use flashinfer softmax in top_k_top_p_sampling_from_logits #1171


Open: lgeiger wants to merge 2 commits into main

Conversation

@lgeiger commented Jun 24, 2025

📌 Description

This PR uses the softmax kernel from #1153 in top_k_top_p_sampling_from_logits instead of torch.softmax.
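
In short, only the probability computation inside the sampler switches kernels. A rough sketch of the substitution (illustrative only; the shapes and scalar top-k/top-p values below are made up):

import torch
import flashinfer

logits = torch.randn(4, 128, device="cuda")  # made-up example input

# Before: probabilities for the logits-based sampler came from the generic PyTorch kernel.
probs = torch.softmax(logits, dim=-1)

# After this PR: the fused softmax added in #1153 is used instead.
probs = flashinfer.sampling.softmax(logits, dim=-1)

# The public API is unchanged for both filter_apply_order values; callers still do:
samples = flashinfer.sampling.top_k_top_p_sampling_from_logits(
    logits, 10, 0.9, filter_apply_order="top_k_first"
)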

🔍 Related Issues

Fixes #1153 (comment)

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

@gemini-code-assist bot (Contributor) left a comment

Summary of Changes

Hello @lgeiger, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request focuses on enhancing the performance of the top_k_top_p_sampling_from_logits function by integrating flashinfer's custom softmax implementation. The primary goal is to leverage flashinfer's optimized routines for probability distribution calculations, moving away from the generic torch.softmax to potentially achieve faster and more efficient sampling within the library.

Highlights

  • Performance Optimization: The pull request replaces the standard torch.softmax function with flashinfer's optimized softmax implementation within the top_k_top_p_sampling_from_logits function. This change is applied to both the top_k_first and joint filter application orders.
  • Code Consistency: By using flashinfer's native softmax, the PR aligns the sampling logic with the library's specialized kernels, potentially leading to better performance and consistency across flashinfer's operations.

@gemini-code-assist bot (Contributor) left a comment

Code Review

This PR replaces torch.softmax with the softmax function from the flashinfer library in the top_k_top_p_sampling_from_logits function. This change aims to improve performance by utilizing the flashinfer implementation. I've suggested adding comments to clarify the reason for the change.

@lgeiger (Author) commented Jun 25, 2025

Hmm, looks like we're getting "CUDA error: an illegal memory access was encountered" errors. @xslingcn any ideas what could be going on here?

@xslingcn (Contributor)

> Hmm, looks like we're getting "CUDA error: an illegal memory access was encountered" errors. @xslingcn any ideas what could be going on here?

Let me check on this

@xslingcn (Contributor)

> Hmm, looks like we're getting "CUDA error: an illegal memory access was encountered" errors. @xslingcn any ideas what could be going on here?

Hi @lgeiger, should be fixed in #1178.

@xslingcn (Contributor)

Btw could you also update the tests?

diff --git a/tests/test_sampling.py b/tests/test_sampling.py
index b3a9bbc..6b40d27 100644
--- a/tests/test_sampling.py
+++ b/tests/test_sampling.py
@@ -340,7 +340,7 @@ def test_top_k_top_p_sampling_from_probs_logits_alignment(batch_size, vocab_size
         logits, k, p, filter_apply_order="top_k_first", generator=generator_logits
     )
     samples_ref = flashinfer.sampling.top_k_top_p_sampling_from_probs(
-        torch.softmax(logits, dim=-1),
+        flashinfer.sampling.softmax(logits, dim=-1),
         k,
         p,
         filter_apply_order="top_k_first",
@@ -369,7 +369,7 @@ def test_top_k_top_p_joint_sampling_from_logits(batch_size, vocab_size, p):
     )
 
     samples_ref = flashinfer.sampling.top_k_top_p_sampling_from_probs(
-        torch.softmax(logits, dim=-1),
+        flashinfer.sampling.softmax(logits, dim=-1),
         k,
         p,
         filter_apply_order="joint",

@lgeiger (Author) commented Jun 26, 2025

> Btw could you also update the tests?

For the unit tests you mentioned, this changes the reference. Shouldn't we keep torch.softmax for those cases?

@lgeiger (Author) commented Jun 26, 2025

@xslingcn I applied your suggested changes to the unit tests, since it looks like there are some numerical differences that otherwise make them fail. Not sure if this is expected.

@xslingcn (Contributor)

> @xslingcn I applied your suggested changes to the unit tests, since it looks like there are some numerical differences that otherwise make them fail. Not sure if this is expected.

These tests are meant to check the alignment between applying the filters directly on logits and applying them on the post-softmax probabilities. Since you have changed top_k_top_p_sampling_from_logits to use our softmax, we might want to align the reference with it. But you're right, the unit tests shouldn't fail either way; let me take a look.
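
To make the intent concrete, here is a minimal sketch of the alignment check being described, following the calls in the diff above (the shapes, seeds, and scalar k/p values are made up; as discussed, small numerical differences between the two softmax implementations can break the exact-match expectation):

import torch
import flashinfer

logits = torch.randn(4, 128, device="cuda")
k, p = 10, 0.9

# Sample once from the raw logits and once from the post-softmax probabilities,
# using identically seeded generators so the draws are comparable.
gen_logits = torch.Generator(device="cuda").manual_seed(42)
gen_probs = torch.Generator(device="cuda").manual_seed(42)

samples = flashinfer.sampling.top_k_top_p_sampling_from_logits(
    logits, k, p, filter_apply_order="top_k_first", generator=gen_logits
)
samples_ref = flashinfer.sampling.top_k_top_p_sampling_from_probs(
    flashinfer.sampling.softmax(logits, dim=-1),
    k, p, filter_apply_order="top_k_first", generator=gen_probs
)

# The alignment expectation: both paths should pick the same tokens.
assert torch.equal(samples, samples_ref)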
