@@ -120,8 +120,9 @@ def warps_per_cta(layout):
 @pytest.mark.parametrize("layout", layouts)
 @pytest.mark.parametrize("load_block_ptr, store_block_ptr", [(True, True), (False, False), (True, False),
                                                              (False, True)])
+@pytest.mark.parametrize("transpose", [True, False])
 @pytest.mark.skipif(not is_xpu(), reason="Block store tests are specific to the XPU backend")
-def test_block_io(M, N, dtype_str, layout, load_block_ptr, store_block_ptr, device, tmp_path: pathlib.Path):
+def test_block_io(M, N, dtype_str, layout, load_block_ptr, store_block_ptr, transpose, device, tmp_path: pathlib.Path):
 
     warps = warps_per_cta(layout)
     num_warps = int(np.prod(warps))
@@ -132,16 +133,18 @@ def test_block_io(M, N, dtype_str, layout, load_block_ptr, store_block_ptr, devi
 
     support_block_io = torch.xpu.get_device_capability()['has_subgroup_2d_block_io']
 
+    block_io = "\"column_major\"" if transpose else "\"row_major\""
+
     if load_block_ptr:
         load_ops = f"""
-        %src_ptr = tt.make_tensor_ptr %src, [%M_i64, %N_i64], [%N_i64, %c1_i64], [%c0_i32, %c0_i32] {{order = array<i32: 1, 0>}} : <tensor<{M}x{N}x{ty}, #layout>>
-        %store_val = tt.load %src_ptr {{ttig.block_io = "row_major", boundaryCheck = array<i32: 0, 1>, padding = 1 : i32}} : !tt.ptr<tensor<{M}x{N}x{ty}, #layout>>
+        %src_ptr = tt.make_tensor_ptr %src, [%M_i64, %N_i64], {"[%c1_i64, %M_i64]" if transpose else "[%N_i64, %c1_i64]"}, [%c0_i32, %c0_i32] {{order = array<i32: 1, 0>}} : <tensor<{M}x{N}x{ty}, #layout>>
+        %store_val = tt.load %src_ptr {{ttig.block_io = {block_io}, boundaryCheck = array<i32: 0, 1>, padding = 1 : i32}} : !tt.ptr<tensor<{M}x{N}x{ty}, #layout>>
         """
     else:
         load_ops = f"""
         %src_base = tt.splat %src : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
-        %src_ptr = tt.addptr %src_base, %row_major_off : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>, tensor<{M}x{N}xi32, #layout>
-        %store_val = tt.load %src_ptr {{ttig.block_io = "row_major"}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
+        %src_ptr = tt.addptr %src_base, {"%col_major_off" if transpose else "%row_major_off"} : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>, tensor<{M}x{N}xi32, #layout>
+        %store_val = tt.load %src_ptr {{ttig.block_io = {block_io}}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
         """
     if store_block_ptr:
         store_ops = f"""
@@ -175,6 +178,12 @@ def test_block_io(M, N, dtype_str, layout, load_block_ptr, store_block_ptr, devi
         %7 = tt.broadcast %5 : tensor<1x{N}xi32, #layout> -> tensor<{M}x{N}xi32, #layout>
         %row_major_off = arith.addi %6, %7 : tensor<{M}x{N}xi32, #layout>
 
+        %stride_M = arith.constant dense<{M}> : tensor<1x{N}xi32, #layout>
+        %col_stride = arith.muli %5, %stride_M : tensor<1x{N}xi32, #layout>
+        %8 = tt.broadcast %2 : tensor<{M}x1xi32, #layout> -> tensor<{M}x{N}xi32, #layout>
+        %9 = tt.broadcast %col_stride : tensor<1x{N}xi32, #layout> -> tensor<{M}x{N}xi32, #layout>
+        %col_major_off = arith.addi %8, %9 : tensor<{M}x{N}xi32, #layout>
+
         {load_ops}
         {store_ops}
 
@@ -195,6 +204,8 @@ def test_block_io(M, N, dtype_str, layout, load_block_ptr, store_block_ptr, devi
     temp_file.write_text(ir)
     kernel = triton.compile(str(temp_file))
 
+    a = a.permute(1, 0).contiguous().permute(1, 0) if transpose else a
+
     kernel[(1, 1, 1)](a, x)
     assert torch.equal(a, x)
 
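A note on the new `%col_major_off` arithmetic: for element `(r, c)` of an `M`x`N` tensor, the flat offset is `r*N + c` under row-major storage and `r + c*M` under column-major storage, which is what the added `arith.muli`/`arith.addi` sequence builds. A minimal NumPy sketch of the two formulas (illustration only, not part of the commit; the shape is arbitrary):

```python
import numpy as np

# Arbitrary shape for the check.
M, N = 4, 8
r, c = np.indices((M, N))
row_major_off = r * N + c  # mirrors %row_major_off
col_major_off = r + c * M  # mirrors %col_major_off (%col_stride = col * M)

vals = np.arange(M * N)
# Row-major storage: element (r, c) lives at flat index r*N + c.
assert (vals[row_major_off] == vals.reshape(M, N)).all()
# Column-major storage: element (r, c) lives at flat index r + c*M.
assert (vals[col_major_off] == vals.reshape(N, M).T).all()
```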
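Likewise, the `a.permute(1, 0).contiguous().permute(1, 0)` round trip keeps the logical `M`x`N` values while rewriting the storage to column-major, i.e. strides `(1, M)`, matching the `[%c1_i64, %M_i64]` strides passed to `tt.make_tensor_ptr` on the transpose path. A small standalone sketch (assumes only `torch`):

```python
import torch

M, N = 4, 8
a = torch.arange(M * N, dtype=torch.float32).reshape(M, N)
b = a.permute(1, 0).contiguous().permute(1, 0)  # same values, column-major storage

assert torch.equal(a, b)     # logically identical
assert a.stride() == (N, 1)  # row-major: stepping one row skips N elements
assert b.stride() == (1, M)  # column-major: stepping one column skips M elements
```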