Skip to content

Commit 92b50ca

Browse files
committed
Use flashinfer softmax in unittests
1 parent 7c7373c commit 92b50ca

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

tests/test_sampling.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,7 @@ def test_top_k_top_p_sampling_from_probs_logits_alignment(batch_size, vocab_size
348348
logits, k, p, filter_apply_order="top_k_first", generator=generator_logits
349349
)
350350
samples_ref = flashinfer.sampling.top_k_top_p_sampling_from_probs(
351-
torch.softmax(logits, dim=-1),
351+
flashinfer.sampling.softmax(logits),
352352
k,
353353
p,
354354
filter_apply_order="top_k_first",
@@ -377,7 +377,7 @@ def test_top_k_top_p_joint_sampling_from_logits(batch_size, vocab_size, p):
377377
)
378378

379379
samples_ref = flashinfer.sampling.top_k_top_p_sampling_from_probs(
380-
torch.softmax(logits, dim=-1),
380+
flashinfer.sampling.softmax(logits),
381381
k,
382382
p,
383383
filter_apply_order="joint",

0 commit comments

Comments
 (0)