@@ -336,3 +336,50 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
336
336
tt.return
337
337
}
338
338
}
339
+
340
+ // -----
341
+
342
+ // COM: Test coalescing on blocked pointers: loop result used by tt.reduce
343
+
344
+ #blocked = #triton_gpu.blocked <{sizePerThread = [1 , 1 ], threadsPerWarp = [1 , 32 ], warpsPerCTA = [4 , 4 ], order = [1 , 0 ]}>
345
+ #blocked1 = #triton_gpu.blocked <{sizePerThread = [1 , 1 , 1 ], threadsPerWarp = [1 , 1 , 32 ], warpsPerCTA = [1 , 4 , 4 ], order = [2 , 1 , 0 ]}>
346
+ module attributes {" triton_gpu.num-ctas" = 1 : i32 , " triton_gpu.num-warps" = 16 : i32 , " triton_gpu.threads-per-warp" = 32 : i32 } {
347
+ // CHECK-DAG: [[BLOCKED_LAYOUT:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 4], order = [1, 0]}>
348
+ // CHECK-DAG: [[BLOCKED_LAYOUT1:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 32, 1], warpsPerCTA = [1, 1, 16], order = [0, 1, 2]}>
349
+ // CHECK-DAG: [[BLOCKED_LAYOUT2:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 1, 32], warpsPerCTA = [1, 4, 4], order = [2, 1, 0]}>
350
+ // CHECK: @triton_red_fused_mul_sum_0
351
+ tt.func public @triton_red_fused_mul_sum_0 (%arg0: !tt.ptr <f32 > {tt.divisibility = 16 : i32 }) {
352
+ %c128_i32 = arith.constant 128 : i32
353
+ %cst_0 = arith.constant dense <0.000000e+00 > : tensor <32 x128 xf32 , #blocked >
354
+ %c0_i32 = arith.constant 0 : i32
355
+ %c262144_i64 = arith.constant 262144 : i64
356
+ %c1_i64 = arith.constant 1 : i64
357
+ %c512_i64 = arith.constant 512 : i64
358
+ %c32_i32 = arith.constant 32 : i32
359
+ %c512_i32 = arith.constant 512 : i32
360
+ %0 = tt.get_program_id x : i32
361
+ %1 = arith.muli %0 , %c32_i32 : i32
362
+ %2 = tt.make_range {end = 128 : i32 , start = 0 : i32 } : tensor <128 xi32 , #triton_gpu.slice <{dim = 0 , parent = #blocked }>>
363
+ %3 = tt.expand_dims %2 {axis = 0 : i32 } : tensor <128 xi32 , #triton_gpu.slice <{dim = 0 , parent = #blocked }>> -> tensor <1 x128 xi32 , #blocked >
364
+ %4 = arith.divsi %1 , %c512_i32 : i32
365
+ %5 = arith.remsi %1 , %c512_i32 : i32
366
+ // CHECK: [[PTR1:%.*]] = tt.make_tensor_ptr %arg0, {{.*}} : <tensor<1x32x128xf32, [[BLOCKED_LAYOUT1]]>>
367
+ %6 = tt.make_tensor_ptr %arg0 , [%c512_i64 , %c512_i64 , %c512_i64 ], [%c1_i64 , %c512_i64 , %c262144_i64 ], [%4 , %5 , %c0_i32 ] {order = array<i32 : 2 , 1 , 0 >} : <tensor <1 x32 x128 xf32 , #blocked1 >>
368
+ // CHECK: [[RES:%.*]]:2 = scf.for {{.*}} iter_args([[ARG1:%.*]] = [[PTR1]], [[ARG2:%.*]] = {{.*}}) -> (!tt.ptr<tensor<1x32x128xf32, [[BLOCKED_LAYOUT1]]>>, tensor<32x128xf32, [[BLOCKED_LAYOUT]]>)
369
+ %8:2 = scf.for %arg5 = %c0_i32 to %c512_i32 step %c128_i32 iter_args (%arg6 = %6 , %arg8 = %cst_0 ) -> (!tt.ptr <tensor <1 x32 x128 xf32 , #blocked1 >>, tensor <32 x128 xf32 , #blocked >) : i32 {
370
+ // CHECK: [[LOAD:%.*]] = tt.load [[ARG1]] evictionPolicy = evict_last {boundaryCheck = array<i32: 2>, padding = 1 : i32} : !tt.ptr<tensor<1x32x128xf32, [[BLOCKED_LAYOUT1]]>>
371
+ // CHECK-NEXT: triton_gpu.convert_layout [[LOAD]] : tensor<1x32x128xf32, [[BLOCKED_LAYOUT1]]> -> tensor<1x32x128xf32, [[BLOCKED_LAYOUT2]]>
372
+ %17 = tt.load %arg6 evictionPolicy = evict_last {boundaryCheck = array<i32 : 2 >, padding = 1 : i32 } : !tt.ptr <tensor <1 x32 x128 xf32 , #blocked1 >>
373
+ // CHECK: scf.yield [[ARG1]], [[ARG2]] : !tt.ptr<tensor<1x32x128xf32, [[BLOCKED_LAYOUT1]]>>, tensor<32x128xf32, [[BLOCKED_LAYOUT]]>
374
+ scf.yield %arg6 , %arg8 : !tt.ptr <tensor <1 x32 x128 xf32 , #blocked1 >>, tensor <32 x128 xf32 , #blocked >
375
+ }
376
+ // CHECK: = "tt.reduce"([[RES]]#1) <{axis = 1 : i32}> ({
377
+ // CHECK }) : (tensor<32x128xf32, [[BLOCKED_LAYOUT]]) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = [[BLOCKED_LAYOUT]]}>>
378
+ %9 = " tt.reduce" (%8#1 ) <{axis = 1 : i32 }> ({
379
+ ^bb0 (%arg5: f32 , %arg6: f32 ):
380
+ %14 = arith.addf %arg5 , %arg6 : f32
381
+ tt.reduce.return %14 : f32
382
+ }) : (tensor <32 x128 xf32 , #blocked >) -> tensor <32 xf32 , #triton_gpu.slice <{dim = 1 , parent = #blocked }>>
383
+ tt.return
384
+ }
385
+ }
0 commit comments