@@ -125,9 +125,9 @@ def _apply_mask(
125125  masks  =  []
126126  if  mask_ref  is  not None :
127127    if  k_in_lanes :
128-       mask  =  pl . load ( mask_ref , ( slice ( None ),  k_slice )) 
128+       mask  =  mask_ref [:,  k_slice ] 
129129    else :
130-       mask  =  pl . load ( mask_ref , ( k_slice , slice ( None ))) 
130+       mask  =  mask_ref [ k_slice , :] 
131131
132132    snm  =  jnp .where (should_not_mask , 1 , 0 )
133133    masks .append (jnp .bitwise_or (mask , jnp .broadcast_to (snm , mask .shape )) !=  0 )
@@ -156,7 +156,7 @@ def _apply_mask(
156156      k_sequence  =  k_offset  +  jax .lax .broadcasted_iota (
157157          jnp .int32 , (k_slice .size , bq ), 0 
158158      )
159-       q_sequence  =  pl . load ( q_sequence_ref , ( pl . ds ( 1 ),  slice ( None )))   # [1, bq] 
159+       q_sequence  =  q_sequence_ref [: 1 , :]   # [1, bq] 
160160      q_sequence  =  jnp .broadcast_to (q_sequence , (k_slice .size , bq ))
161161
162162    assert  q_sequence .shape  ==  k_sequence .shape 
@@ -170,7 +170,7 @@ def _apply_mask(
170170
171171  if  q_segment_ids_ref  is  not None :
172172    if  k_in_lanes :
173-       kv_ids  =  pl . load ( kv_segment_ids_ref , ( pl . ds ( 1 ) , k_slice ))   # [1, k_slice] 
173+       kv_ids  =  kv_segment_ids_ref [: 1 , k_slice ]   # [1, k_slice] 
174174      repeats , rem  =  divmod (kv_ids .shape [1 ], NUM_LANES )
175175      if  rem :
176176        raise  NotImplementedError (f"block_kv must be a multiple of { NUM_LANES }")
@@ -181,9 +181,9 @@ def _apply_mask(
181181      if  rem :
182182        raise  NotImplementedError (f"block_q must be a multiple of { NUM_LANES }")
183183      kv_ids  =  pltpu .repeat (
184-           pl . load ( kv_segment_ids_ref , ( k_slice , slice ( None ))) , repeats , axis = 1 
184+           kv_segment_ids_ref [ k_slice , :] , repeats , axis = 1 
185185      )  # [k_slice, bq] 
186-       q_ids  =  pl . load ( q_segment_ids_ref , ( pl . ds ( 1 ),  slice ( None )))   # [1, bq] 
186+       q_ids  =  q_segment_ids_ref [: 1 , :]   # [1, bq] 
187187    masks .append (q_ids  ==  kv_ids )
188188
189189  if  masks :
@@ -228,7 +228,7 @@ def body(kv_compute_index, _):
228228    slice_k  =  pl .ds (kv_compute_index  *  bkv_compute , bkv_compute )
229229
230230    q  =  q_ref [...]
231-     k  =  pl . load ( k_ref , ( slice_k , slice ( None ))) 
231+     k  =  k_ref [ slice_k , :] 
232232    qk  =  jax .lax .dot_general (
233233        q , k , NT_DIM_NUMBERS , preferred_element_type = jnp .float32 
234234    )
@@ -256,7 +256,7 @@ def body(kv_compute_index, _):
256256    )
257257
258258    sv_dims  =  NN_DIM_NUMBERS 
259-     v  =  pl . load ( v_ref , ( slice_k , slice ( None ))) 
259+     v  =  v_ref [ slice_k , :] 
260260
261261    to_float32  =  lambda  x : x .astype (jnp .float32 )
262262    v  =  to_float32 (v )
0 commit comments