We read every piece of feedback and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 65a7588, commit e24fd96 (Copy full SHA for e24fd96)
src/transformers_neuronx/hlo.py
@@ -3817,11 +3817,12 @@ def speculative_mask(
3817
# Compare ratio of probabilities at locations to a random sample
3818
ratio = divide(target_probs, draft_probs) # shape: (k, batch_size)
3819
ratio = clamp(ratio, maximum=1.0)
3820
+ ratio = cast(ratio, ratio.scribe.f32)
3821
if deterministic_threshold:
3822
random = full_like(ratio, deterministic_threshold)
3823
else:
3824
random = random_uniform(ratio.dtype, ratio.sizes)
- accepted_mask = less_equal(random, ratio) # shape: (k, batch_size)
3825
+ accepted_mask = less(random, ratio) # shape: (k, batch_size)
3826
3827
# Mask out all tokens past the accepted token
3828
accepted_mask = cast(accepted_mask, s32)
0 commit comments