Commit 0772cda

[LoadStoreOpToLLVM] Transposed 2d load.
Signed-off-by: Lu,Chengjun <[email protected]>
1 parent 458d741 commit 0772cda

File tree: 2 files changed, +70 -655 lines


python/test/unit/intel/test_block_io.py

Lines changed: 16 additions & 5 deletions
@@ -120,8 +120,9 @@ def warps_per_cta(layout):
 @pytest.mark.parametrize("layout", layouts)
 @pytest.mark.parametrize("load_block_ptr, store_block_ptr", [(True, True), (False, False), (True, False),
                                                              (False, True)])
+@pytest.mark.parametrize("transpose", [True, False])
 @pytest.mark.skipif(not is_xpu(), reason="Block store tests are specific to the XPU backend")
-def test_block_io(M, N, dtype_str, layout, load_block_ptr, store_block_ptr, device, tmp_path: pathlib.Path):
+def test_block_io(M, N, dtype_str, layout, load_block_ptr, store_block_ptr, transpose, device, tmp_path: pathlib.Path):
 
     warps = warps_per_cta(layout)
     num_warps = int(np.prod(warps))
@@ -132,16 +133,18 @@ def test_block_io(M, N, dtype_str, layout, load_block_ptr, store_block_ptr, devi
 
     support_block_io = torch.xpu.get_device_capability()['has_subgroup_2d_block_io']
 
+    block_io = "\"column_major\"" if transpose else "\"row_major\""
+
     if load_block_ptr:
         load_ops = f"""
-            %src_ptr = tt.make_tensor_ptr %src, [%M_i64, %N_i64], [%N_i64, %c1_i64], [%c0_i32, %c0_i32] {{order = array<i32: 1, 0>}} : <tensor<{M}x{N}x{ty}, #layout>>
-            %store_val = tt.load %src_ptr {{ttig.block_io = "row_major", boundaryCheck = array<i32: 0, 1>, padding = 1 : i32}} : !tt.ptr<tensor<{M}x{N}x{ty}, #layout>>
+            %src_ptr = tt.make_tensor_ptr %src, [%M_i64, %N_i64], {"[%c1_i64, %M_i64]" if transpose else "[%N_i64, %c1_i64]"}, [%c0_i32, %c0_i32] {{order = array<i32: 1, 0>}} : <tensor<{M}x{N}x{ty}, #layout>>
+            %store_val = tt.load %src_ptr {{ttig.block_io = {block_io}, boundaryCheck = array<i32: 0, 1>, padding = 1 : i32}} : !tt.ptr<tensor<{M}x{N}x{ty}, #layout>>
         """
     else:
         load_ops = f"""
             %src_base = tt.splat %src : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
-            %src_ptr = tt.addptr %src_base, %row_major_off : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>, tensor<{M}x{N}xi32, #layout>
-            %store_val = tt.load %src_ptr {{ttig.block_io = "row_major"}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
+            %src_ptr = tt.addptr %src_base, {"%col_major_off" if transpose else "%row_major_off" } : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>, tensor<{M}x{N}xi32, #layout>
+            %store_val = tt.load %src_ptr {{ttig.block_io = {block_io}}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
         """
     if store_block_ptr:
         store_ops = f"""
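
The stride swap in `tt.make_tensor_ptr` is the heart of the change: a row-major MxN tensor has element strides (N, 1), while the transposed, column-major layout has (1, M). A minimal NumPy sketch (shapes and values are illustrative, not part of the commit) of the two layouts the f-string chooses between:

```python
import numpy as np

M, N = 4, 8
a = np.arange(M * N, dtype=np.float32).reshape(M, N)  # row-major
a_col = np.asfortranarray(a)                          # column-major copy, same logical values

# Element strides (bytes // itemsize):
print(np.array(a.strides) // a.itemsize)      # [8 1] -> [%N_i64, %c1_i64]
print(np.array(a_col.strides) // a.itemsize)  # [1 4] -> [%c1_i64, %M_i64]
assert a[2, 3] == a_col[2, 3]                 # logical indexing is unchanged
```

The `ttig.block_io` attribute then tells the lowering which of the two access patterns the pointer actually follows, so the backend can pick a transposed 2D block load for the column-major case.
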
@@ -175,6 +178,12 @@ def test_block_io(M, N, dtype_str, layout, load_block_ptr, store_block_ptr, devi
         %7 = tt.broadcast %5 : tensor<1x{N}xi32, #layout> -> tensor<{M}x{N}xi32, #layout>
         %row_major_off = arith.addi %6, %7 : tensor<{M}x{N}xi32, #layout>
 
+        %stride_M = arith.constant dense<{M}> : tensor<1x{N}xi32, #layout>
+        %col_stride = arith.muli %5, %stride_M : tensor<1x{N}xi32, #layout>
+        %8 = tt.broadcast %2 : tensor<{M}x1xi32, #layout> -> tensor<{M}x{N}xi32, #layout>
+        %9 = tt.broadcast %col_stride : tensor<1x{N}xi32, #layout> -> tensor<{M}x{N}xi32, #layout>
+        %col_major_off = arith.addi %8, %9 : tensor<{M}x{N}xi32, #layout>
+
         {load_ops}
         {store_ops}
 
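
The added IR builds `%col_major_off` as off(m, n) = m + n * M, the column-major counterpart of the existing `%row_major_off` = m * N + n: `%5` holds the broadcast column indices, `%col_stride` scales them by M, and `%8` broadcasts the row indices. The same arithmetic in NumPy (a sketch, with `rows`/`cols` standing in for `%2`/`%5` as the surrounding IR suggests):

```python
import numpy as np

M, N = 4, 8
rows = np.arange(M).reshape(M, 1)   # plays the role of %2
cols = np.arange(N).reshape(1, N)   # plays the role of %5

row_major_off = rows * N + cols     # existing %row_major_off
col_major_off = rows + cols * M     # new %col_major_off = %8 + %9

# Both are permutations of the linear offsets [0, M*N):
assert sorted(col_major_off.ravel().tolist()) == list(range(M * N))
```
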
@@ -195,6 +204,8 @@ def test_block_io(M, N, dtype_str, layout, load_block_ptr, store_block_ptr, devi
     temp_file.write_text(ir)
     kernel = triton.compile(str(temp_file))
 
+    a = a.permute(1, 0).contiguous().permute(1, 0) if transpose else a
+
     kernel[(1, 1, 1)](a, x)
     assert torch.equal(a, x)
 