Skip to content

Commit 7c7373c

Browse files
committed
Use flashinfer softmax in top_k_top_p_sampling_from_logits
1 parent 1b9ba25 commit 7c7373c

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

flashinfer/sampling.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -981,7 +981,7 @@ def top_k_top_p_sampling_from_logits(
981981
"""
982982
if filter_apply_order == "top_k_first":
983983
masked_logits = top_k_mask_logits(logits, top_k)
984-
probs = torch.softmax(masked_logits, dim=-1)
984+
probs = softmax(masked_logits)
985985
return top_p_sampling_from_probs(
986986
probs,
987987
top_p,
@@ -991,7 +991,7 @@ def top_k_top_p_sampling_from_logits(
991991
generator=generator,
992992
)
993993
elif filter_apply_order == "joint":
994-
probs = torch.softmax(logits, dim=-1)
994+
probs = softmax(logits)
995995
if check_nan:
996996
if torch.any(torch.isnan(probs)):
997997
raise ValueError("Input probs contains NaN.")

0 commit comments

Comments (0)