@@ -125,9 +125,9 @@ def _apply_mask(
125
125
masks = []
126
126
if mask_ref is not None :
127
127
if k_in_lanes :
128
- mask = pl . load ( mask_ref , ( slice ( None ), k_slice ))
128
+ mask = mask_ref [:, k_slice ]
129
129
else :
130
- mask = pl . load ( mask_ref , ( k_slice , slice ( None )))
130
+ mask = mask_ref [ k_slice , :]
131
131
132
132
snm = jnp .where (should_not_mask , 1 , 0 )
133
133
masks .append (jnp .bitwise_or (mask , jnp .broadcast_to (snm , mask .shape )) != 0 )
@@ -156,7 +156,7 @@ def _apply_mask(
156
156
k_sequence = k_offset + jax .lax .broadcasted_iota (
157
157
jnp .int32 , (k_slice .size , bq ), 0
158
158
)
159
- q_sequence = pl . load ( q_sequence_ref , ( pl . ds ( 1 ), slice ( None ))) # [1, bq]
159
+ q_sequence = q_sequence_ref [: 1 , :] # [1, bq]
160
160
q_sequence = jnp .broadcast_to (q_sequence , (k_slice .size , bq ))
161
161
162
162
assert q_sequence .shape == k_sequence .shape
@@ -170,7 +170,7 @@ def _apply_mask(
170
170
171
171
if q_segment_ids_ref is not None :
172
172
if k_in_lanes :
173
- kv_ids = pl . load ( kv_segment_ids_ref , ( pl . ds ( 1 ) , k_slice )) # [1, k_slice]
173
+ kv_ids = kv_segment_ids_ref [: 1 , k_slice ] # [1, k_slice]
174
174
repeats , rem = divmod (kv_ids .shape [1 ], NUM_LANES )
175
175
if rem :
176
176
raise NotImplementedError (f"block_kv must be a multiple of { NUM_LANES } " )
@@ -181,9 +181,9 @@ def _apply_mask(
181
181
if rem :
182
182
raise NotImplementedError (f"block_q must be a multiple of { NUM_LANES } " )
183
183
kv_ids = pltpu .repeat (
184
- pl . load ( kv_segment_ids_ref , ( k_slice , slice ( None ))) , repeats , axis = 1
184
+ kv_segment_ids_ref [ k_slice , :] , repeats , axis = 1
185
185
) # [k_slice, bq]
186
- q_ids = pl . load ( q_segment_ids_ref , ( pl . ds ( 1 ), slice ( None ))) # [1, bq]
186
+ q_ids = q_segment_ids_ref [: 1 , :] # [1, bq]
187
187
masks .append (q_ids == kv_ids )
188
188
189
189
if masks :
@@ -228,7 +228,7 @@ def body(kv_compute_index, _):
228
228
slice_k = pl .ds (kv_compute_index * bkv_compute , bkv_compute )
229
229
230
230
q = q_ref [...]
231
- k = pl . load ( k_ref , ( slice_k , slice ( None )))
231
+ k = k_ref [ slice_k , :]
232
232
qk = jax .lax .dot_general (
233
233
q , k , NT_DIM_NUMBERS , preferred_element_type = jnp .float32
234
234
)
@@ -256,7 +256,7 @@ def body(kv_compute_index, _):
256
256
)
257
257
258
258
sv_dims = NN_DIM_NUMBERS
259
- v = pl . load ( v_ref , ( slice_k , slice ( None )))
259
+ v = v_ref [ slice_k , :]
260
260
261
261
to_float32 = lambda x : x .astype (jnp .float32 )
262
262
v = to_float32 (v )
0 commit comments