Commit e24fd96

Yi-Hsiang (Sean) Lai authored and aws-yishanm committed
Use f32 rng for SD token selection
GitOrigin-RevId: 210123e945a509d7087f478a322cbea7974b3542
1 parent 65a7588 commit e24fd96

1 file changed, +2 -1 lines changed

src/transformers_neuronx/hlo.py

Lines changed: 2 additions & 1 deletion
@@ -3817,11 +3817,12 @@ def speculative_mask(
     # Compare ratio of probabilities at locations to a random sample
     ratio = divide(target_probs, draft_probs) # shape: (k, batch_size)
     ratio = clamp(ratio, maximum=1.0)
+    ratio = cast(ratio, ratio.scribe.f32)
     if deterministic_threshold:
         random = full_like(ratio, deterministic_threshold)
     else:
         random = random_uniform(ratio.dtype, ratio.sizes)
-    accepted_mask = less_equal(random, ratio) # shape: (k, batch_size)
+    accepted_mask = less(random, ratio) # shape: (k, batch_size)
 
     # Mask out all tokens past the accepted token
     accepted_mask = cast(accepted_mask, s32)
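
The commit message does not spell out why the comparison changed from `less_equal` to `less`. One plausible reading is that a uniform sample can be exactly 0.0, in which case `random <= ratio` accepts a drafted token even when its target probability is zero. Low-precision dtypes have few representable values in [0, 1), which makes such exact ties more likely and may be why the ratio is cast to f32 before sampling. The plain-Python sketch below illustrates the boundary behavior only; `accept_tokens` is a hypothetical helper that works on flat lists, not the HLO code in `hlo.py`.

```python
def accept_tokens(target_probs, draft_probs, uniform_samples, strict=True):
    """Return a 0/1 acceptance mask for drafted tokens.

    Hypothetical stand-in for the comparison in speculative_mask.
    strict=True mirrors less(random, ratio); strict=False mirrors
    less_equal(random, ratio).
    """
    mask = []
    for p_target, p_draft, u in zip(target_probs, draft_probs, uniform_samples):
        ratio = min(p_target / p_draft, 1.0)  # same role as clamp(ratio, maximum=1.0)
        accepted = (u < ratio) if strict else (u <= ratio)
        mask.append(int(accepted))
    return mask


# Illustrative values (assumptions, not taken from the commit):
# token 0 has zero target probability and the sampler returns exactly 0.0,
# token 1 ties exactly with its ratio of 0.5,
# token 2 has a ratio above 1.0 that is clamped to 1.0.
target = [0.0, 0.25, 0.9]
draft = [0.5, 0.5, 0.45]
samples = [0.0, 0.5, 0.99]

strict_mask = accept_tokens(target, draft, samples, strict=True)
loose_mask = accept_tokens(target, draft, samples, strict=False)
print("less:      ", strict_mask)
print("less_equal:", loose_mask)
```

With the strict comparison, the zero-probability token and the exact tie are both rejected; with `less_equal` they are accepted. Only the first case is clearly wrong, since a token the target model assigns zero probability should never be accepted.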
