We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
softmax
top_k_top_p_sampling_from_logits
1 parent 1b9ba25 commit 7c7373c — Copy full SHA for 7c7373c
flashinfer/sampling.py
@@ -981,7 +981,7 @@ def top_k_top_p_sampling_from_logits(
981
"""
982
if filter_apply_order == "top_k_first":
983
masked_logits = top_k_mask_logits(logits, top_k)
984
- probs = torch.softmax(masked_logits, dim=-1)
+ probs = softmax(masked_logits)
985
return top_p_sampling_from_probs(
986
probs,
987
top_p,
@@ -991,7 +991,7 @@ def top_k_top_p_sampling_from_logits(
991
generator=generator,
992
)
993
elif filter_apply_order == "joint":
994
- probs = torch.softmax(logits, dim=-1)
+ probs = softmax(logits)
995
if check_nan:
996
if torch.any(torch.isnan(probs)):
997
raise ValueError("Input probs contains NaN.")
0 commit comments