diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h index 5ca0f2168af6..41f5b75587fe 100644 --- a/bin/RegisterTritonDialects.h +++ b/bin/RegisterTritonDialects.h @@ -94,7 +94,6 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) { mlir::triton::registerAllocateAMDGPUSharedMemory(); mlir::triton::registerConvertTritonAMDGPUToLLVM(); mlir::triton::registerConvertBuiltinFuncToLLVM(); - mlir::triton::registerOptimizeAMDLDSUsage(); mlir::ub::registerConvertUBToLLVMInterface(registry); mlir::registerConvertNVVMToLLVMInterface(registry); diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 17ed9c2128e9..0f487b174f50 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -6370,6 +6370,11 @@ def gather_test_kernel_1d(src_ptr, idx_ptr, out_ptr, axis: tl.constexpr, src_dim ]) def test_gather(src_shape, indices_shape, axis, device): + if is_hip() and src_shape == [128, 64] and indices_shape == [256, 64]: + # This could be solved by reducing vectorization in the general swizzling algorithm. + # We will do this if any relevant workload suffers from the algorithm's large LDS consumption. + pytest.skip('Not enough LDS.') + def triton_gather(src: torch.Tensor, axis: int, indices: torch.Tensor): output = torch.empty(indices.shape, dtype=src.dtype, device=src.device) diff --git a/test/Conversion/amd/allocate_shared_memory.mlir b/test/Conversion/amd/allocate_shared_memory.mlir index 49c56c608273..0d2eae68b86e 100644 --- a/test/Conversion/amd/allocate_shared_memory.mlir +++ b/test/Conversion/amd/allocate_shared_memory.mlir @@ -1,32 +1,5 @@ // RUN: triton-opt %s -split-input-file --allocate-amdgpu-shared-memory | FileCheck %s -#blocked1 = #ttg.blocked<{sizePerThread = [8, 4], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> - -// This test checks padding based converter. -// -// Converter allocates temporary buffer, stores and reads parts or tensor in few transactions, which are named repeats. -// Size of temporary buffer is computed using the following algorithm: -// - get CTA tile shape of blocked1 layout: [8*8*4, 4*8*1] = [256, 32] -// - get CTA tile shape of blocked2 layout: [1*8*4, 1*8*1] = [32, 8] -// - compute common tile shape is [max(256, 32), max(32, 8)] = [256, 32]. -// - pad fastest dimension(same as output layout, 1 in this case) with size of memory access to reduce bank conflicts. 16 bytes in this case. -// -// Therefore total memory consuption for scratch buffer is 256*(32 * 4(size of one element) + 16(padding)) = 36864 bytes -// -// For implementation see mlir::triton::getNumScratchElemsPaddedCvt function. 
- -// CHECK: ttg.shared = 36864 : i32 -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { - -// CHECK-LABEL: @convert_layout_padded -tt.func @convert_layout_padded(%arg0: tensor<256x256xi32, #blocked1>) { - // CHECK-NEXT: allocation.offset = 0 : i32 - %0 = ttg.convert_layout %arg0 {amdgpu.use_padded_scratch_shmem} : tensor<256x256xi32, #blocked1> -> tensor<256x256xi32, #blocked2> - tt.return -} - -} // ----- diff --git a/test/Conversion/amd/convert_layout.mlir b/test/Conversion/amd/convert_layout.mlir index 2dc7b6efbd09..0da5ca0c9518 100644 --- a/test/Conversion/amd/convert_layout.mlir +++ b/test/Conversion/amd/convert_layout.mlir @@ -28,27 +28,3 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ tt.return } } - -// ----- - -#blocked0 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [16, 4], warpsPerCTA = [2, 2], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#blocked1 = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 16], warpsPerCTA = [2, 2], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { - // CHECK-LABEL: convert_layout_padding_swizzling - tt.func @convert_layout_padding_swizzling(%arg0: tensor<64x64xf32, #blocked0>, %arg1: tensor<64x64x!tt.ptr, #blocked1>) { - - // verify that following convert layout uses padded path - // see getVecAddr lambda in transferWithinBlockImpl function - - // CHECK-DAG: [[CST_0:%.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK-DAG: [[CST_5:%.*]] = llvm.mlir.constant(5 : i32) : i32 - // CHECK-DAG: [[OFFSET_0:%.*]] = llvm.lshr {{.*}}, [[CST_5]] : i32 - // CHECK: [[OFFSET_1:%.*]] = llvm.shl [[OFFSET_0]], [[CST_0]] : i32 - // CHECK: [[OFFSET_2:%.*]] = llvm.add [[OFFSET_1]], {{.*}} : i32 - // CHECK: llvm.getelementptr inbounds {{.*}}{{\[}}[[OFFSET_2]]{{\]}} - - %0 = ttg.convert_layout %arg0 {amdgpu.use_padded_scratch_shmem} : tensor<64x64xf32, #blocked0> -> tensor<64x64xf32, #blocked1> - tt.store %arg1, %0 : tensor<64x64x!tt.ptr, #blocked1> - tt.return - } -} diff --git a/test/TritonGPU/amd/optimize-lds-usage.mlir b/test/TritonGPU/amd/optimize-lds-usage.mlir deleted file mode 100644 index 77a82acadf83..000000000000 --- a/test/TritonGPU/amd/optimize-lds-usage.mlir +++ /dev/null @@ -1,170 +0,0 @@ -// RUN: triton-opt %s -split-input-file -optimize-amd-lds-usage=target-arch=gfx90a | FileCheck %s -// RUN: triton-opt %s -split-input-file -optimize-amd-lds-usage=target-arch=gfx90a -optimize-amd-lds-usage=lds-limit=32768 | FileCheck %s --check-prefix=CHECK-32KLIMIT - -// Check that optimization detects overflow of LDS and decomposes layout convert so kernel fits into LDS -// CHECK-LABEL: alloc_convert_load -// CHECK-32KLIMIT-LABEL: alloc_convert_load -// CHECK: %0 = ttg.local_alloc %arg0 : {{.*}}#blocked{{.*}}#shared -// CHECK: %1 = ttg.convert_layout %arg1 {{.*}}: {{.*}}#blocked{{.*}}#mma -// CHECK: %2 = ttg.convert_layout %1 {{.*}}: {{.*}}#mma{{.*}}#mma1 -// CHECK: %3 = ttg.local_load %0 : {{.*}}#shared{{.*}}#ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> -#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> -#mma = #ttg.amd_mfma<{version = 2, warpsPerCTA = [1, 8], instrShape = [32, 32, 8], isTransposed = false}> -#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, 
maxPhase = 16, order = [0, 1]}> -#smem = #ttg.shared_memory -module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} { - tt.func public @alloc_convert_load(%arg0: tensor<128x128xf16, #blocked>, %arg1: tensor<128x128xf32, #blocked>) { - %1 = ttg.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared, #smem> - %2 = ttg.convert_layout %arg1 : tensor<128x128xf32, #blocked> -> tensor<128x128xf32, #mma> - %3 = ttg.local_load %1 : !ttg.memdesc<128x128xf16, #shared, #smem> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> - tt.return - } -} - -// ----- - -// Check that optimization detects overflow of LDS and decomposes layout convert so kernel fits into LDS -// in case of relatively small scratch buffer -// CHECK-LABEL: alloc_convert_small_load -// CHECK-32KLIMIT-LABEL: alloc_convert_small_load -// CHECK: %0 = ttg.local_alloc %arg0 : {{.*}}#blocked{{.*}}#shared -// CHECK: %1 = ttg.convert_layout %arg1 {{.*}}: {{.*}}#blocked{{.*}}#blocked1 -// CHECK: %2 = ttg.convert_layout %1 {{.*}}: {{.*}}#blocked1{{.*}}#mma -// CHECK: %3 = ttg.local_load %0 : {{.*}}#shared{{.*}}#ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> -#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> -#mma = #ttg.amd_mfma<{version = 2, warpsPerCTA = [1, 8], instrShape = [32, 32, 8], isTransposed = false}> -#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}> -#smem = #ttg.shared_memory -module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} { - tt.func public @alloc_convert_small_load(%arg0: tensor<128x128xf16, #blocked>, %arg1: tensor<256x128xf16, #blocked>) { - %1 = ttg.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared, #smem> - %2 = ttg.convert_layout %arg1 : tensor<256x128xf16, #blocked> -> tensor<256x128xf16, #mma> - %3 = ttg.local_load %1 : !ttg.memdesc<128x128xf16, #shared, #smem> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> - tt.return - } -} - -// ----- - -// CHECK-DAG: [[$BLOCKED1:#.*]] = #ttg.blocked<{sizePerThread = [1, 8, 1], threadsPerWarp = [1, 16, 4], warpsPerCTA = [1, 1, 8], order = [0, 1, 2]}> -// CHECK-DAG: [[$MMA:#.*]] = #ttg.amd_mfma<{version = 2, warpsPerCTA = [2, 4, 1], instrShape = [32, 32, 8], isTransposed = false}> -// CHECK-DAG: [[$MMA1:#.*]] = #ttg.amd_mfma<{version = 2, warpsPerCTA = [1, 1, 8], instrShape = [32, 32, 8], isTransposed = false}> - -// Check that optimization works with 3d tensors -// in case of relatively small scratch buffer -// CHECK-LABEL: alloc_convert_3d_load -// CHECK-32KLIMIT-LABEL: alloc_convert_3d_load -// CHECK: [[V0:%.*]] = ttg.local_alloc {{.*}}[[$BLOCKED1]]{{.*}} -// CHECK: [[V1:%.*]] = ttg.convert_layout {{.*}}[[$BLOCKED1]]{{.*}}[[$MMA]] -// CHECK: [[V2:%.*]] = ttg.convert_layout [[V1]] {{.*}}: {{.*}}[[$MMA]]{{.*}}[[$MMA1]] -// CHECK: [[V3:%.*]] = ttg.local_load [[V0]] : {{.*}}#ttg.dot_op<{opIdx = 0, parent = [[$MMA1]], kWidth = 4}>> -#blocked = #ttg.blocked<{sizePerThread = [1, 8, 1], threadsPerWarp = [1, 16, 4], warpsPerCTA = [1, 1, 8], order = [0, 1, 2]}> -#mma = #ttg.amd_mfma<{version = 2, warpsPerCTA = [1, 1, 8], instrShape = [32, 32, 8], isTransposed = false}> -#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1, 2]}> -#smem = #ttg.shared_memory -module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} { - tt.func public 
@alloc_convert_3d_load(%arg0: tensor<1x128x128xf16, #blocked>, %arg1: tensor<1x256x128xf16, #blocked>) { - %1 = ttg.local_alloc %arg0 : (tensor<1x128x128xf16, #blocked>) -> !ttg.memdesc<1x128x128xf16, #shared, #smem> - %2 = ttg.convert_layout %arg1 : tensor<1x256x128xf16, #blocked> -> tensor<1x256x128xf16, #mma> - %3 = ttg.local_load %1 : !ttg.memdesc<1x128x128xf16, #shared, #smem> -> tensor<1x128x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> - tt.return - } -} - -// ----- - -// Check that optimization triggers with custom LDS limit and do not triggers with default one -// CHECK-LABEL: alloc_convert_32k_limit -// CHECK: %0 = ttg.local_alloc %arg0 : {{.*}}#blocked{{.*}}#shared -// CHECK: %1 = ttg.convert_layout %arg1 {{.*}}: {{.*}}#blocked{{.*}}#mma -// CHECK: %2 = ttg.local_load %0 : {{.*}}#shared{{.*}}#ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> -// CHECK-32KLIMIT-LABEL: alloc_convert_32k_limit -// CHECK-32KLIMIT: %0 = ttg.local_alloc %arg0 : {{.*}}#blocked{{.*}}#shared -// CHECK-32KLIMIT: %1 = ttg.convert_layout %arg1 {{.*}}: {{.*}}#blocked{{.*}}#blocked1 -// CHECK-32KLIMIT: %2 = ttg.convert_layout %1 {{.*}}: {{.*}}#blocked1{{.*}}#mma -// CHECK-32KLIMIT: %3 = ttg.local_load %0 : {{.*}}#shared{{.*}}#ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> -#blocked = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> -#mma = #ttg.amd_mfma<{version = 2, warpsPerCTA = [1, 8], instrShape = [32, 32, 8], isTransposed = false}> -#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}> -#smem = #ttg.shared_memory -module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} { - tt.func public @alloc_convert_32k_limit(%arg0: tensor<64x128xf16, #blocked>, %arg1: tensor<128x128xf16, #blocked>) { - %1 = ttg.local_alloc %arg0 : (tensor<64x128xf16, #blocked>) -> !ttg.memdesc<64x128xf16, #shared, #smem> - %2 = ttg.convert_layout %arg1 : tensor<128x128xf16, #blocked> -> tensor<128x128xf16, #mma> - %3 = ttg.local_load %1 : !ttg.memdesc<64x128xf16, #shared, #smem> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 0, kWidth = 4, parent = #mma}>> - tt.return - } -} - -// ----- - -// Check that optimization correctly handles LDS shortcut (see #mma2 -> #dotop2 conversion) -// CHECK-DAG: [[BLOCKED_1:#[a-z0-9]*]] = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> -// CHECK-DAG: [[BLOCKED_2:#[a-z0-9]*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [2, 4], order = [0, 1]}> -// CHECK-DAG: [[MMA_1:#[a-z0-9]*]] = #ttg.amd_mfma<{version = 2, warpsPerCTA = [8, 1], instrShape = [32, 32, 8], isTransposed = true}> -// CHECK-DAG: [[MMA_2:#[a-z0-9]*]] = #ttg.amd_mfma<{version = 2, warpsPerCTA = [1, 8], instrShape = [32, 32, 8], isTransposed = false}> -// CHECK-DAG: [[SHARED:#[a-z0-9]*]] = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}> - -// CHECK: tt.func public @mfma_dot_shortcut([[ARG_0:%[a-z0-9]*]]: {{.*}}, [[ARG_1:%[a-z0-9]*]]: {{.*}}, [[ARG_2:%[a-z0-9]*]]: {{.*}}) -// CHECK: [[ALLOC:%[0-9]+]] = ttg.local_alloc [[ARG_0]] : (tensor<128x128xf16, [[BLOCKED_1]]>) -> !ttg.memdesc<128x128xf16, [[SHARED]], #smem> -// CHECK: [[INTERMEDIATE_CONV:%[0-9]+]] = ttg.convert_layout [[ARG_1]] {{.*}}: tensor<256x128xf32, [[BLOCKED_1]]> -> tensor<256x128xf32, [[BLOCKED_2]]> -// CHECK: [[CONVERT_1:%[0-9]+]] = ttg.convert_layout [[INTERMEDIATE_CONV]] {{.*}}: tensor<256x128xf32, [[BLOCKED_2]]> -> 
tensor<256x128xf32, [[MMA_2]]> -// CHECK: [[CONVERT_2:%[0-9]+]] = ttg.convert_layout [[ARG_2]] {{.*}}: tensor<256x128xf16, [[MMA_1]]> -> tensor<256x128xf16, #ttg.dot_op<{opIdx = 0, parent = [[MMA_1]], kWidth = 4}>> -// CHECK: [[LOAD:%[0-9]+]] = ttg.local_load [[ALLOC]] : !ttg.memdesc<128x128xf16, [[SHARED]], #smem> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = [[MMA_2]], kWidth = 4}>> -#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> -#mma1 = #ttg.amd_mfma<{version = 2, warpsPerCTA = [1, 8], instrShape = [32, 32, 8], isTransposed = false}> -#mma2 = #ttg.amd_mfma<{version = 2, warpsPerCTA = [8, 1], instrShape = [32, 32, 8], isTransposed = true}> -#dotop1 = #ttg.dot_op<{opIdx=0, parent=#mma1, kWidth=4}> -#dotop2 = #ttg.dot_op<{opIdx=0, parent=#mma2, kWidth=4}> -#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}> -#smem = #ttg.shared_memory -module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} { - tt.func public @mfma_dot_shortcut(%arg0: tensor<128x128xf16, #blocked>, %arg1: tensor<256x128xf32, #blocked>, %arg2: tensor<256x128xf16, #mma2>) { - %alloc = ttg.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared, #smem> - %convert_1 = ttg.convert_layout %arg1 : tensor<256x128xf32, #blocked> -> tensor<256x128xf32, #mma1> - %convert_2 = ttg.convert_layout %arg2 : tensor<256x128xf16, #mma2> -> tensor<256x128xf16, #dotop2> - %load = ttg.local_load %alloc : !ttg.memdesc<128x128xf16, #shared, #smem> -> tensor<128x128xf16, #dotop1> - tt.return - } -} - -// ----- - -// Checks that optimization do not crash on 1d tensor -// CHECK-LABEL: convert_1d -// CHECK: ttg.local_alloc -// CHECK-NEXT: ttg.convert_layout -// CHECK-NEXT: ttg.local_load -#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -#mma = #ttg.amd_mfma<{version = 2, warpsPerCTA = [4, 1], instrShape = [32, 32, 8], isTransposed = true}> -#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}> -#smem = #ttg.shared_memory -module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { - tt.func public @convert_1d(%arg0: tensor<128xf32, #ttg.slice<{dim = 0, parent = #mma}>>, %arg1: tensor<128x128xf32, #mma>) { - %alloc = ttg.local_alloc %arg1 : (tensor<128x128xf32, #mma>) -> !ttg.memdesc<128x128xf32, #shared, #smem> - %1 = ttg.convert_layout %arg0 : tensor<128xf32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<128xf32, #blocked> - %load = ttg.local_load %alloc : !ttg.memdesc<128x128xf32, #shared, #smem> -> tensor<128x128xf32, #mma> - tt.return - } -} - -// ----- - -// Checks that optimization do not crash on linear encoding tensor -// CHECK-LABEL: convert_linear -// CHECK: ttg.local_alloc -// CHECK-NEXT: ttg.convert_layout -// CHECK-NEXT: ttg.convert_layout -// CHECK-NEXT: ttg.local_load -#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [2, 2], order = [1, 0]}> -#linear = #ttg.linear<{register = [[1, 0], [2, 0], [4, 0]], lane = [[0, 1], [0, 2], [0, 4], [0, 8], [8, 0], [16, 0]], warp = [[0, 16], [32, 0]], block = []}> -#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}> -#smem = #ttg.shared_memory -module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { - tt.func public @convert_linear(%arg0: tensor<128x128xf16, #blocked>, %arg1: tensor<128x128xf32, #blocked>) { - %alloc = 
ttg.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared, #smem> - %1 = ttg.convert_layout %arg1 : tensor<128x128xf32, #blocked> -> tensor<128x128xf32, #linear> - %load = ttg.local_load %alloc : !ttg.memdesc<128x128xf16, #shared, #smem> -> tensor<128x128xf16, #blocked> - tt.return - } -} diff --git a/third_party/amd/CMakeLists.txt b/third_party/amd/CMakeLists.txt index c8e9e6f0e2d2..bf5f54b17422 100644 --- a/third_party/amd/CMakeLists.txt +++ b/third_party/amd/CMakeLists.txt @@ -9,7 +9,4 @@ if(TRITON_BUILD_PYTHON_MODULE) add_triton_plugin(TritonAMD ${CMAKE_CURRENT_SOURCE_DIR}/python/triton_amd.cc LINK_LIBS TritonAMDGPUToLLVM TritonAMDGPUTransforms TritonAMDGPUDialectToLLVM) target_link_libraries(TritonAMD PRIVATE Python3::Module pybind11::headers lldCommon lldELF) endif() -if(TRITON_BUILD_UT) - add_subdirectory(unittest) -endif() add_subdirectory(test) diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index 0eacfaf3c81d..8b44a89ac392 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -284,8 +284,6 @@ def make_llir(src, metadata, options): # # If custom_lds_size = 0, pass will consider all LDS is available for one threads block, # LDS size is determined by provided arch name. - custom_lds_size = 0 - amd.passes.ttgpuir.add_optimize_lds_usage(pm, options.arch, custom_lds_size) passes.convert.add_scf_to_cf(pm) passes.gluon.add_inliner(pm) passes.convert.add_index_to_llvmir(pm) diff --git a/third_party/amd/include/Analysis/AMDGPUAllocation.h b/third_party/amd/include/Analysis/AMDGPUAllocation.h index 7a20ff349db8..78ba1e9ff122 100644 --- a/third_party/amd/include/Analysis/AMDGPUAllocation.h +++ b/third_party/amd/include/Analysis/AMDGPUAllocation.h @@ -6,57 +6,17 @@ namespace mlir::triton::AMD { -constexpr char AttrSharedMemPadded[] = "amdgpu.use_padded_scratch_shmem"; - unsigned getConvertLayoutScratchInBytes(RankedTensorType srcTy, - RankedTensorType dstTy, - bool usePadding); + RankedTensorType dstTy); unsigned AMDAllocationAnalysisScratchSizeFn(Operation *op); -// To convert a tensor from one layout to another, we need to allocate a -// temporary buffer (i.e., scratch buffer) in shared memory. The conversion may -// require multiple iterations, with each iteration involving multiple -// vectorized loads/stores. The scratch buffer has a shape (`repShape`) that -// represents the maximum size accessed in each dimension during each iteration. -// It is padded (`paddedRepShape`) to avoid bank conflicts and is accessed in a -// specific `order`. -struct ScratchConfig { - SmallVector repShape; - SmallVector paddedRepShape; - SmallVector order; - unsigned inVec; - unsigned outVec; - - ScratchConfig(SmallVector repShape, - SmallVector paddedRepShape, unsigned inVec = 1, - unsigned outVec = 1) - : repShape(repShape), paddedRepShape(paddedRepShape), inVec(inVec), - outVec(outVec) {} - - void print(llvm::raw_ostream &os) const { - os << "repShape: ["; - llvm::interleaveComma(repShape, os); - os << "]"; - os << ", paddedRepShape: ["; - llvm::interleaveComma(paddedRepShape, os); - os << "]"; - os << ", order: ["; - llvm::interleaveComma(order, os); - os << "]"; - os << ", inVec: " << inVec << ", outVec: " << outVec << "\n"; - } -}; - // For a layout conversion between `srcTy` and `dstTy`, return the vector length // that can be used for the stores to and loads from shared memory, // respectively. 
std::pair getScratchCvtInOutVecLengths(RankedTensorType srcTy, RankedTensorType dstTy); -ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy, - RankedTensorType dstTy); - } // namespace mlir::triton::AMD #endif // TRITONAMD_ANALYSIS_AMDGPU_ALLOCATION_H diff --git a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h index 22d10f135e4b..4033ef24d871 100644 --- a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h +++ b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h @@ -24,14 +24,6 @@ namespace mlir::triton { } // namespace mlir::triton namespace mlir::triton::AMD { -/// @brief Creates pass that keep LDS consumption within specified limits. -/// @param arch target architecture name, for example "gfx940" -/// @param customLDSLimit defines LDS size available for one thread block -/// zero value tells pass that whole LDS is available on a device -/// @return created pass -std::unique_ptr> -createOptimizeLDSUsagePass(StringRef arch, int32_t customLDSLimit = 0); - void runScalarizePackedFOpsPass(llvm::Function &F); } // namespace mlir::triton::AMD diff --git a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td index 68b906a4a8e4..36c18d275540 100644 --- a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td +++ b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td @@ -3,18 +3,6 @@ include "mlir/Pass/PassBase.td" -def OptimizeAMDLDSUsage : Pass<"optimize-amd-lds-usage", "mlir::ModuleOp"> { - let summary = "Minimize LDS usage"; - let constructor = "mlir::triton::AMD::createOptimizeLDSUsagePass(\"\")"; - - let options = [ - Option<"targetArch", "target-arch", "std::string", /*default*/"", - "gfx target device architecture, e.g., gfx942">, - Option<"customLDSLimit", "lds-limit", "int", /*default*/"0", - "custom limit of LDS consumption, if not provided, maximum LDS size is used">, - ]; -} - def AllocateAMDGPUSharedMemory : Pass<"allocate-amdgpu-shared-memory", "mlir::ModuleOp"> { let summary = "Add metadata for shared memory allocation"; diff --git a/third_party/amd/lib/Analysis/AMDGPUAllocation.cpp b/third_party/amd/lib/Analysis/AMDGPUAllocation.cpp index 8df56815ceb6..b194ae7cbf32 100644 --- a/third_party/amd/lib/Analysis/AMDGPUAllocation.cpp +++ b/third_party/amd/lib/Analysis/AMDGPUAllocation.cpp @@ -11,122 +11,11 @@ namespace mlir::triton::AMD { // Max shmem instruction in bits constexpr int kMaxShmemVecBitLength = 128; -unsigned getNumScratchElemsPaddedCvt(RankedTensorType srcTy, - RankedTensorType dstTy) { - auto scratchConfig = getScratchConfigForCvt(srcTy, dstTy); - return getNumScratchElements(scratchConfig.paddedRepShape); -} - -SmallVector getRepShapeForCvt(RankedTensorType srcTy, - RankedTensorType dstTy) { - Attribute srcLayout = srcTy.getEncoding(); - Attribute dstLayout = dstTy.getEncoding(); - - if (!cvtNeedsSharedMemory(srcTy, dstTy)) { - return {}; - } - - if (shouldUseDistSmem(srcLayout, dstLayout)) { - // TODO: padding to avoid bank conflicts - return convertType(gpu::getShapePerCTA(srcTy)); - } - - assert(srcLayout && dstLayout && "Unexpected layout in getRepShapeForCvt()"); - - auto srcShapePerCTA = gpu::getShapePerCTA(srcTy); - auto dstShapePerCTA = gpu::getShapePerCTA(dstTy); - auto srcShapePerCTATile = ::mlir::triton::AMD::getShapePerCTATile(srcTy); - auto dstShapePerCTATile = ::mlir::triton::AMD::getShapePerCTATile(dstTy); - - assert(srcTy.getRank() == dstTy.getRank() && - "src and dst must have the same rank"); - - unsigned rank = dstTy.getRank(); - 
SmallVector repShape(rank); - for (unsigned d = 0; d < rank; ++d) { - repShape[d] = - std::max(std::min(srcShapePerCTA[d], srcShapePerCTATile[d]), - std::min(dstShapePerCTA[d], dstShapePerCTATile[d])); - } - return repShape; -} - -std::pair -getScratchCvtInOutVecLengths(RankedTensorType srcTy, RankedTensorType dstTy) { - Attribute srcLayout = srcTy.getEncoding(); - Attribute dstLayout = dstTy.getEncoding(); - - auto srcLinAttr = gpu::toLinearEncoding(srcTy); - auto dstLinAttr = gpu::toLinearEncoding(dstTy); - auto inOrd = srcLinAttr.getOrder(); - auto outOrd = dstLinAttr.getOrder(); - - unsigned rank = srcTy.getRank(); - - unsigned srcContigPerThread = srcLinAttr.getContigPerThread()[inOrd[0]]; - unsigned dstContigPerThread = dstLinAttr.getContigPerThread()[outOrd[0]]; - unsigned innerDim = rank - 1; - unsigned inVec = outOrd[0] != innerDim ? 1 - : inOrd[0] != innerDim ? 1 - : srcContigPerThread; - unsigned outVec = outOrd[0] != innerDim ? 1 : dstContigPerThread; - - return {inVec, outVec}; -} - -ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy, - RankedTensorType dstTy) { - // Initialize vector sizes and stride - auto repShape = getRepShapeForCvt(srcTy, dstTy); - if (repShape.empty()) - return ScratchConfig({}, {}); - ScratchConfig scratchConfig(repShape, repShape); - auto rank = repShape.size(); - Attribute srcLayout = srcTy.getEncoding(); - Attribute dstLayout = dstTy.getEncoding(); - - assert(cvtNeedsSharedMemory(srcTy, dstTy)); - auto outOrd = gpu::getOrder(dstTy); - scratchConfig.order = outOrd; - - std::tie(scratchConfig.inVec, scratchConfig.outVec) = - getScratchCvtInOutVecLengths(srcTy, dstTy); - // We can't write a longer vector than the shape of shared memory. - // This shape might be smaller than the tensor shape in case we decided to - // do the conversion in multiple iterations. - unsigned contiguousShapeDim = scratchConfig.repShape[scratchConfig.order[0]]; - scratchConfig.inVec = std::min(scratchConfig.inVec, contiguousShapeDim); - scratchConfig.outVec = std::min(scratchConfig.outVec, contiguousShapeDim); - // Clamp the vector length to kMaxShmemVecBitLength / element bitwidth as this - // is the max vectorisation - auto inBitWidth = getBitwidth(srcTy); - auto outBitWidth = getBitwidth(dstTy); - scratchConfig.inVec = - std::min(scratchConfig.inVec, kMaxShmemVecBitLength / inBitWidth); - scratchConfig.outVec = - std::min(scratchConfig.outVec, kMaxShmemVecBitLength / outBitWidth); - - // No padding is required if the tensor is 1-D, or if all dimensions except - // the first accessed dimension have a size of 1. 
- if (rank <= 1 || product(repShape) == repShape[outOrd[0]]) - return scratchConfig; - - auto paddedSize = std::max(scratchConfig.inVec, scratchConfig.outVec); - scratchConfig.paddedRepShape[outOrd[0]] += paddedSize; - return scratchConfig; -} - unsigned getConvertLayoutScratchInBytes(RankedTensorType srcTy, - RankedTensorType dstTy, - bool usePadding) { + RankedTensorType dstTy) { if (!cvtNeedsSharedMemory(srcTy, dstTy)) return 0; - unsigned elems = 0; - if (usePadding) { - elems = getNumScratchElemsPaddedCvt(srcTy, dstTy); - } else { - elems = getNumScratchElemsSwizzledCvt(srcTy, dstTy); - } + unsigned elems = getNumScratchElemsSwizzledCvt(srcTy, dstTy); return elems * getBitwidth(srcTy) / 8; } @@ -135,8 +24,7 @@ unsigned AMDAllocationAnalysisScratchSizeFn(Operation *op) { if (auto cvtLayout = dyn_cast(op)) { auto srcTy = cvtLayout.getSrc().getType(); auto dstTy = cvtLayout.getType(); - return getConvertLayoutScratchInBytes(srcTy, dstTy, - op->hasAttr(AttrSharedMemPadded)); + return getConvertLayoutScratchInBytes(srcTy, dstTy); } return defaultAllocationAnalysisScratchSizeFn(op); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt b/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt index edee537e6880..28f950e9e32e 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt @@ -18,8 +18,6 @@ add_triton_library(TritonAMDGPUToLLVM Utility.cpp TargetInfo.cpp TargetUtils.cpp - OptimizeLDSUsage.cpp - OptimizeLDSUtility.cpp SPMDOpToLLVM.cpp SchedInstructions.cpp UpcastMXFPToLLVM.cpp diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp index ad7c6bccf3ad..3baad33e8e09 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -262,303 +262,6 @@ class ConvertLayoutOpPermlaneSwap return success(); } -private: - const TargetInfoBase &targetInfo; -}; - -class ConvertLayoutForcedPadding - : public ConvertOpToLLVMPattern { -public: - ConvertLayoutForcedPadding(LLVMTypeConverter &typeConverter, - const TargetInfoBase &targetInfo, - PatternBenefit benefit) - : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) { - } - - // Determine which registers are read/written in which iteration of the shmem - // transfer specified by `layout`. - SmallVector /*registers*/> - collectRegsForIter(MLIRContext *ctx, const LinearLayout &layout) const { - StringAttr kRegister = str_attr("register"); - StringAttr kLane = str_attr("lane"); - StringAttr kWarp = str_attr("warp"); - StringAttr kBlock = str_attr("block"); - StringAttr kIteration = str_attr("iteration"); - - // The choice of iteration should be determined only by the register. That - // is, it should be correct to split the register dimension into iterations. 
- assert(layout.sublayoutIsZero({kLane, kWarp, kBlock}, {kIteration})); - - LinearLayout sublayout = layout.sublayout({kRegister}, {kIteration}); - SmallVector> ret(sublayout.getOutDimSize(kIteration)); - for (int reg = 0; reg < sublayout.getInDimSize(kRegister); reg++) { - auto idx = sublayout.apply({{kRegister, reg}}); - ret[idx.begin()->second].push_back(reg); - } - return ret; - } - - SmallVector transferWithinBlockImpl(ArrayRef inVals, - triton::gpu::ConvertLayoutOp op, - const LinearLayout &srcLayout, - const LinearLayout &dstLayout, - RewriterBase &rewriter) const { - MLIRContext *ctx = op.getContext(); - auto loc = op.getLoc(); - auto b = TritonLLVMOpBuilder(loc, rewriter); - - StringAttr kRegister = str_attr("register"); - StringAttr kLane = str_attr("lane"); - StringAttr kWarp = str_attr("warp"); - StringAttr kBlock = str_attr("block"); - StringAttr kOffset = str_attr("offset"); - StringAttr kIteration = str_attr("iteration"); - - auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc); - - auto scratchConfig = triton::AMD::getScratchConfigForCvt( - op.getSrc().getType(), op.getType()); - auto tensorShapePerCTA = - convertType(triton::gpu::getShapePerCTA( - op.getSrc().getType().getEncoding(), op.getType().getShape())); - // Input dims: [offset, iteration, block] - // Output dims: dimN-1, dimN-2, ..., dim0, where N is obtained from repShape - LinearLayout sharedLayout = - triton::gpu::chooseShemLayoutForRegToRegConversion( - ctx, tensorShapePerCTA, scratchConfig.repShape, - scratchConfig.order); - - // Layout for the store from registers to shared memory. - // - // Note: If two threads in the same warp write to the same shmem offset, the - // hardware resolves that without a stall or a bank conflict. Therefore we - // don't need to avoid duplicate writes. - // Input dims: [reg, lane, warp] - // Output dims: [offset, iteration] - LinearLayout shmemStoreLayout = srcLayout.invertAndCompose(sharedLayout); - - const int shmemAllocatedNumElems = - getNumScratchElements(scratchConfig.paddedRepShape); - assert(shmemStoreLayout.getOutDimSize(kOffset) <= shmemAllocatedNumElems); - - // Layout for the load from shmem to registers. - LinearLayout shmemLoadLayout = dstLayout.invertAndCompose(sharedLayout); - - // Check that the `register` fully determines the `iteration`. That is, - // each thread does exactly the same reads and writes to shmem on each - // iteration, just with different input/output registers. - assert( - shmemStoreLayout.sublayoutIsZero({kLane, kWarp, kBlock}, {kIteration})); - assert( - shmemLoadLayout.sublayoutIsZero({kLane, kWarp, kBlock}, {kIteration})); - - // iteration -> registers - SmallVector> inRegsForIter = - collectRegsForIter(ctx, shmemStoreLayout); - SmallVector> outRegsForIter = - collectRegsForIter(ctx, shmemLoadLayout); - - Value smemBase = - LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation()); - auto sharedPtrTy = smemBase.getType(); - Type elemTy = inVals[0].getType(); - auto outSize = shmemLoadLayout.getInDimSize(kRegister); - auto iterations = sharedLayout.getInDimSize(kIteration); - assert(scratchConfig.inVec * iterations <= inVals.size()); - assert(scratchConfig.outVec * iterations <= outSize); - - // Check only one dimension has been padded. - // This means the difference between the padded shape and the original shape - // should only be in one dimension, specifically in - // `scratchConfig.order[0]`. 
- auto rank = scratchConfig.repShape.size(); - for (auto i = 0; i < rank; i++) { - if (i == scratchConfig.order[0]) { - continue; - } - assert(scratchConfig.repShape[i] == scratchConfig.paddedRepShape[i]); - } - auto paddedStride = scratchConfig.repShape[scratchConfig.order[0]]; - auto paddedSize = - scratchConfig.paddedRepShape[scratchConfig.order[0]] - paddedStride; - - // Linear layout function is split in two parts below: - // - // L(r, t, w, b) = L(0, t, w, b) xor L(r, 0, 0, 0) - // offset = regBase xor regIdx - // - // It is the same hack as what we've done in the emitIndices function to get - // around performance issues on AMD GPUs - auto getVecAddr = [&](LinearLayout &layout, Value ®Base, - int regSlice) -> Value { - auto regIdx = layout - .apply({{kRegister, regSlice}, - {kLane, 0}, - {kWarp, 0}, - {kBlock, 0}})[0] - .second; - Value offset = b.xor_(regBase, b.i32_val(regIdx)); - if (paddedSize > 0) { - assert(llvm::isPowerOf2_32(paddedStride)); - assert(llvm::isPowerOf2_32(paddedSize)); - auto rshiftVal = llvm::Log2_32(paddedStride); - auto lshiftVal = llvm::Log2_32(paddedSize); - offset = b.add( - b.shl(b.lshr(offset, b.i32_val(rshiftVal)), b.i32_val(lshiftVal)), - offset); - } - auto vecAddr = b.gep(sharedPtrTy, elemTy, smemBase, offset, - LLVM::GEPNoWrapFlags::inbounds); - return vecAddr; - }; - - auto storeBase = applyLinearLayout(loc, rewriter, shmemStoreLayout, - {{kRegister, b.i32_val(0)}, - {kLane, laneId}, - {kWarp, warpId}, - {kBlock, b.i32_val(0)}})[0] - .second; - auto loadBase = applyLinearLayout(loc, rewriter, shmemLoadLayout, - {{kRegister, b.i32_val(0)}, - {kLane, laneId}, - {kWarp, warpId}, - {kBlock, b.i32_val(0)}})[0] - .second; - // register idx -> Value - llvm::MapVector outVals; - for (int i = 0; i < iterations; i++) { - if (i != 0) - b.barrier(); - - auto &inRegs = inRegsForIter[i]; - auto &outRegs = outRegsForIter[i]; - - // When using `stmatrix`, we can store `inVec` elements even if they are - // not contiguous - auto inVec = scratchConfig.inVec; - for (int j = 0; j < inVals.size() / iterations; j += inVec) { - auto inRegSlice = inRegs[j]; - Value vecAddr = getVecAddr(shmemStoreLayout, storeBase, inRegSlice); - SmallVector inValsVec; - for (int k = 0; k < inVec; k++) - inValsVec.push_back(inVals[inRegSlice + k]); - Value valsVec = packLLVector(loc, inValsVec, rewriter); - targetInfo.storeDShared(rewriter, loc, vecAddr, std::nullopt, valsVec, - /*pred=*/b.true_val()); - } - - b.barrier(); - - for (int j = 0; j < outSize / iterations; j += scratchConfig.outVec) { - auto outRegSlice = outRegs[j]; - auto vecAddr = getVecAddr(shmemLoadLayout, loadBase, outRegSlice); - Value valsVec = - targetInfo.loadDShared(rewriter, loc, vecAddr, std::nullopt, - vec_ty(elemTy, scratchConfig.outVec), - /*pred=*/b.true_val()); - for (Value v : unpackLLVector(loc, valsVec, rewriter)) - outVals[outRegSlice++] = v; - } - } - - SmallVector outValsVec; - for (size_t i = 0; i < outVals.size(); i++) - outValsVec.push_back(outVals[i]); - return outValsVec; - } - - /// Converts ConverLayoutOp to llvm using padded pattern. - /// This pattern adds unused memory locations after every rows of tensor - /// fastest changing dimension: - /// e0 e1 e2 e3 p p \ - /// e4 e5 e6 e7 p p \ - /// ... - /// e e e e p p - /// Dimension order is chosen in order to use wide output reads. 
- /// - /// \param op operation to convert - /// \param src llvm structure containing operation input - /// \param targetInfo - /// \param typeConverter - /// \param rewriter - /// \returns llvm structure containing converted output - Value transferWithinBlockPadding(triton::gpu::ConvertLayoutOp op, Value src, - RewriterBase &rewriter) const { - MLIRContext *ctx = op.getContext(); - auto typeConverter = getTypeConverter(); - auto loc = op.getLoc(); - auto b = TritonLLVMOpBuilder(loc, rewriter); - auto srcTy = op.getSrc().getType(); - auto dstTy = op.getType(); - - // Remove the kBlock dimension from the layout as it's the identity in the - // cvt - auto srcLayout = triton::gpu::toLinearLayout(srcTy); - auto dstLayout = triton::gpu::toLinearLayout(dstTy); - - SmallVector inVals = unpackLLElements(loc, src, rewriter); - assert(!inVals.empty()); - - // We munge the input values by converting i (n<8) elements to i8 and - // pointers to i64. This is necessary because TargetInfo::loadDShared and - // storeDShared can't handle vectors of pointers or sub-byte elements. - auto elemTy = srcTy.getElementType(); - auto isSubByteInt = - elemTy.isInteger() && elemTy.getIntOrFloatBitWidth() < 8; - auto isPtr = isa(elemTy); - auto llvmElemTyOrig = typeConverter->convertType(elemTy); - if (isSubByteInt) - elemTy = IntegerType::get(elemTy.getContext(), 8); - else if (isPtr) - elemTy = IntegerType::get(elemTy.getContext(), 64); - auto llvmElemTy = typeConverter->convertType(elemTy); - - // Munge input values - for (const auto &it : llvm::enumerate(inVals)) { - if (isSubByteInt) { - inVals[it.index()] = b.zext(llvmElemTy, it.value()); - } else if (isPtr) { - inVals[it.index()] = b.ptrtoint(llvmElemTy, it.value()); - } - } - - // Pretty sure this is the identity function ATM - // It'd be better to simply call `quotient({kBlock})` and - // remove kBlock from transferWithinBlockImpl - auto srcLayoutWithinBlock = triton::gpu::getLayoutWithinBlock(srcLayout); - auto dstLayoutWithinBlock = triton::gpu::getLayoutWithinBlock(dstLayout); - SmallVector outVals = transferWithinBlockImpl( - inVals, op, srcLayoutWithinBlock, dstLayoutWithinBlock, rewriter); - - // Unmunge output values - for (const auto &it : llvm::enumerate(outVals)) { - if (isSubByteInt) { - outVals[it.index()] = b.trunc(llvmElemTyOrig, it.value()); - } else if (isPtr) { - outVals[it.index()] = b.inttoptr(llvmElemTyOrig, it.value()); - } - } - - Value result = - packLLElements(loc, typeConverter, outVals, rewriter, op.getType()); - return result; - } - - LogicalResult - matchAndRewrite(ConvertLayoutOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (!op->hasAttr(mlir::triton::AMD::AttrSharedMemPadded)) - return failure(); - auto srcType = op.getSrc().getType(); - auto dstType = op.getType(); - if (!cvtNeedsSharedMemory(srcType, dstType)) - return failure(); - - auto result = transferWithinBlockPadding(op, adaptor.getSrc(), rewriter); - rewriter.replaceOp(op, result); - return success(); - } - private: const TargetInfoBase &targetInfo; }; @@ -568,7 +271,6 @@ void mlir::triton::AMD::populateConvertLayoutOpToLLVMPatterns( LLVMTypeConverter &typeConverter, const TargetInfo &targetInfo, RewritePatternSet &patterns, PatternBenefit benefit) { patterns.add(typeConverter, targetInfo, benefit); - patterns.add(typeConverter, targetInfo, benefit); // No need to convert when ForcedSwizzling as it's already the default // lowering } diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUsage.cpp 
b/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUsage.cpp deleted file mode 100644 index 92fde8df033d..000000000000 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUsage.cpp +++ /dev/null @@ -1,300 +0,0 @@ -/* - * Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining - * a copy of this software and associated documentation files - * (the "Software"), to deal in the Software without restriction, - * including without limitation the rights to use, copy, modify, merge, - * publish, distribute, sublicense, and/or sell copies of the Software, - * and to permit persons to whom the Software is furnished to do so, - * subject to the following conditions: - * - * The above copyright notice and this permission notice shall be - * included in all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, - * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. - * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY - * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, - * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE - * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - */ -#include "Analysis/AMDGPUAllocation.h" -#include "OptimizeLDSUtility.h" -#include "TargetInfo.h" -#include "TritonAMDGPUToLLVM/Passes.h" -#include "mlir/Analysis/Liveness.h" -#include "mlir/Pass/Pass.h" -#include "triton/Analysis/Allocation.h" -#include "triton/Dialect/TritonGPU/IR/Attributes.h" -#include "triton/Dialect/TritonGPU/IR/Dialect.h" - -#define DEBUG_TYPE "optimize-amd-lds-usage" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") -#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") - -using namespace mlir; - -namespace mlir::triton { -#define GEN_PASS_DEF_OPTIMIZEAMDLDSUSAGE -#include "TritonAMDGPUToLLVM/Passes.h.inc" -} // namespace mlir::triton - -namespace { - -class OptimizeAMDLDSUsage - : public mlir::triton::impl::OptimizeAMDLDSUsageBase { - - int LDSLimit; - - // Try to reduce LDS usage of convert op by adding tmp layout in conversion: - // - // %1 = convert %0 (src layout -> dst layout) - // -> - // %1 = convert %0 (src layout -> tmp) - // %2 = convert %1 (tmp -> dst layout) - // - // The implicit LDS usage of convert op depends on src and dst layouts - // - // Consider mfma->blocked conversion as an example. - // - // tensor shape: [128, 128] - // mfma layout: warpsPerCTA = [1, 4], instrShape = [32, 32] - // blocked layout: sizePerThread = [1, 4], threadsPerWarp = [32, 2], - // warpsPerCTA = [4, 1] - // - // minimal mfma tile is: [1*32, 4*32] = [32, 128] - // minimal blocked tile is: [1*32*4, 4*2*1] = [128, 8] - // - // Roughly scratch buffer shape for conversion is: - // [max(32, 128), max(128, 16)] = [128, 128]. 
- // - // This shape could be reduces by introducing intermediate - // layout and replacing old convert operations with two new conversions: - // - // %1 = convert %0 (mfma -> blocked) - // -> - // %1 = convert %0 (mfma -> tmp) - // %2 = convert %1 (tmp -> blocked) - // - // Let's consider tmp as blocked layout: - // sizePerThread = [1, 4], threadsPerWarp = [32, 2], warpsPerCTA = [1, 4] - // Tmp layout scratch buffer has shape: [1*32*1, 4*2*4] = [32, 32] - // - // With intermediate layout we have two scratch buffers: - // - // %1 = convert %0 (mfma -> tmp): [max(32, 32), max(128, 32)] = [32, 128] - // %2 = convert %1 (tmp -> blocked): [max(32, 128), max(32, 32)] = [128, 32] - // - // Both of these buffers are 4x times smaller than original one and their live - // times do not intersect, therefore this transformation lowers LDS - // consumption. - void tryFitCvtIntoLDS(triton::gpu::ConvertLayoutOp cvtOp, int targetLDSSize) { - LDBG("Trying fit " << cvtOp << " into " << targetLDSSize << " bytes"); - OpBuilder builder(cvtOp); - - auto ctx = builder.getContext(); - auto srcType = cvtOp.getSrc().getType(); - auto dstType = cvtOp.getType(); - - auto srcEnc = - cast(srcType.getEncoding()); - auto dstEnc = - cast(dstType.getEncoding()); - - auto rank = srcType.getRank(); - - unsigned numWarps = triton::gpu::lookupNumWarps(cvtOp); - auto warpSize = triton::gpu::lookupThreadsPerWarp(builder); - - // Find all possible shapes of WarpsPerCTA by finding all possible - // factorizations of numWarps. Pick shape for which both conversions in - // decomposition use LDS less than LDSLimit and for which sum of LDS usage - // is minimal. If no such shape exists, do not decompose. - auto factorizedNumWarps = - mlir::triton::AMD::factorizePowerOf2(numWarps, rank); - // Create a list of temporary layouts - SmallVector elemsPerThread(rank, 1); - SmallVector threadsPerWarp(rank, 1); - - // Special case for rank == 1 - if (rank == 1) { - threadsPerWarp[0] = warpSize; - } else { - assert(rank > 1); - threadsPerWarp[rank - 1] = warpSize / 8; - threadsPerWarp[rank - 2] = warpSize / threadsPerWarp[rank - 1]; - } - - auto layoutCTA = triton::gpu::getCTALayout(srcEnc); - auto order = triton::gpu::getOrder(srcType); - SmallVector dummyWarpsPerCTA(rank, 1); - - auto baseFallbackLayout = triton::gpu::BlockedEncodingAttr::get( - ctx, elemsPerThread, threadsPerWarp, dummyWarpsPerCTA, order, - layoutCTA); - SmallVector tmpLayouts; - for (int i = 0; i < factorizedNumWarps.size(); i++) { - auto warpsPerCTA = factorizedNumWarps[i]; - - auto pushNotNull = [&](Attribute enc) { - if (enc) - tmpLayouts.push_back(enc); - }; - - pushNotNull(mlir::triton::AMD::createTmpLayout(srcEnc, warpsPerCTA)); - pushNotNull(mlir::triton::AMD::createTmpLayout(dstEnc, warpsPerCTA)); - pushNotNull( - mlir::triton::AMD::createTmpLayout(baseFallbackLayout, warpsPerCTA)); - } - - unsigned minLDSUsage = 2 * LDSLimit; - int minIdx = -1; - bool currentBestHasPadding = true; - - for (int i = 0; i < tmpLayouts.size(); i++) { - auto resources = mlir::triton::AMD::estimateResourcesForReplacement( - builder, cvtOp, tmpLayouts[i]); - - // Select between padded and swizzled variants of the same tmpLayout - // Prioritize swizzling: use swizzled if it fits in budget or uses less - // LDS - bool useSwizzling = (resources.LDSSwizzle < targetLDSSize) || - (resources.LDSSwizzle < resources.LDSPad); - int LDS = useSwizzling ? resources.LDSSwizzle : resources.LDSPad; - - LDBG("layout " << tmpLayouts[i] << " requires " << LDS << " bytes of LDS " - << (useSwizzling ? 
"without" : "with") << " padding"); - // Now select the best layout among all valid candidates - if (LDS < targetLDSSize) { - bool hasBetterLDS = - (currentBestHasPadding != useSwizzling) && (LDS < minLDSUsage); - bool hasBetterLayout = currentBestHasPadding && useSwizzling; - if (hasBetterLDS || hasBetterLayout) { - minLDSUsage = LDS; - minIdx = i; - currentBestHasPadding = !useSwizzling; - } - } - } - - if (minIdx == -1 || minLDSUsage > targetLDSSize) { - return; - } - assert(minIdx >= 0 && minIdx < tmpLayouts.size()); - - bool hasAttr = cvtOp->hasAttr(triton::AMD::AttrSharedMemPadded); - if (currentBestHasPadding && !hasAttr) { - cvtOp->setAttr(triton::AMD::AttrSharedMemPadded, UnitAttr::get(ctx)); - // if padded layout drops LDS usage on itself, we are done, return - if (triton::AMD::getConvertLayoutScratchInBytes( - srcType, dstType, /*usePadding*/ true) <= targetLDSSize) { - return; - } - } else if (!currentBestHasPadding && hasAttr) { - cvtOp->removeAttr(triton::AMD::AttrSharedMemPadded); - } - - auto tmpLayout = tmpLayouts[minIdx]; - auto replacementCvts = - mlir::triton::AMD::createNewConvertOps(builder, cvtOp, tmpLayout); - - cvtOp.replaceAllUsesWith(replacementCvts.second.getResult()); - cvtOp.erase(); - } - - struct LDSBottleneckOperation { - triton::gpu::ConvertLayoutOp op; - int64_t LDSSizeTarget; - }; - - // Assuming that all buffer above scratch buffer in memory space can be - // shifted down in memory, gives an optimistic estimation of memory space - // available for scratch buffer. - int64_t - computeTargetScratchBufferSize(triton::gpu::ConvertLayoutOp op, - Allocation *allocation, - ArrayRef liveBuffers) { - int totalSize = 0; - auto scratchBufferId = allocation->getBufferId(op.getOperation()); - int64_t scratchBufferSize = allocation->getAllocatedSize(scratchBufferId); - size_t totalLDSConsumption = 0; - for (auto buf : liveBuffers) { - totalLDSConsumption = std::max( - totalLDSConsumption, allocation->getAllocatedInterval(buf).end()); - } - int64_t freeRequired = totalLDSConsumption - LDSLimit; - return std::max(static_cast(0), scratchBufferSize - freeRequired); - } - - SmallVector - findLDSBottleneckLayoutConvert(ModuleAllocation &allocAnalysis, - FunctionOpInterface func) { - SmallVector candidates; - auto funcAnalysis = allocAnalysis.getFuncData(func); - auto liveBuffers = funcAnalysis->getLiveBuffers(); - - func.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void { - auto srcTy = cvtOp.getSrc().getType(); - auto dstTy = cvtOp.getResult().getType(); - if (!cvtNeedsSharedMemory(srcTy, dstTy)) - return; - auto cvtBuffer = funcAnalysis->getBufferId(cvtOp.getOperation()); - assert(cvtBuffer != Allocation::InvalidBufferId); - - auto targetScratchBufferSize = computeTargetScratchBufferSize( - cvtOp, funcAnalysis, liveBuffers[cvtOp]); - auto currentLDSConsumption = funcAnalysis->getAllocatedSize(cvtBuffer); - if (currentLDSConsumption > targetScratchBufferSize) - candidates.push_back({cvtOp, targetScratchBufferSize}); - }); - return candidates; - } - -public: - OptimizeAMDLDSUsage(StringRef targetArch, int customLDSLimit) - : OptimizeAMDLDSUsageBase() { - this->targetArch = targetArch.str(); - this->customLDSLimit = customLDSLimit; - } - - void runOnOperation() override { - ModuleOp mod = getOperation(); - - if ((this->LDSLimit = this->customLDSLimit) == 0) { - if (this->targetArch.empty()) { - mod->emitError("missing gfx* target for pass ") - << this->getName().str(); - return signalPassFailure(); - } - triton::AMD::TargetInfo targetInfo(this->targetArch.c_str()); - 
LDSLimit = targetInfo.getSharedMemorySize(); - } - - ModuleAllocation allocAnalysis( - mod, mlir::triton::AMD::AMDAllocationAnalysisScratchSizeFn); - if (allocAnalysis.getSharedMemorySize() <= LDSLimit) - return; - - auto rootFunctions = allocAnalysis.getRoots(); - for (auto rootFunc : rootFunctions) { - // Find operations with peak LDS consumption - auto candidates = findLDSBottleneckLayoutConvert(allocAnalysis, rootFunc); - // Try to transform candidate operations to fit them into LDS - for (auto candidate : candidates) - tryFitCvtIntoLDS(candidate.op, candidate.LDSSizeTarget); - } - } -}; - -} // namespace - -namespace mlir::triton::AMD { - -std::unique_ptr> -createOptimizeLDSUsagePass(StringRef targetArch, int customLDSLimit) { - return std::make_unique(targetArch, customLDSLimit); -} - -} // namespace mlir::triton::AMD diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.cpp deleted file mode 100644 index dfd4eac2c294..000000000000 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.cpp +++ /dev/null @@ -1,122 +0,0 @@ -#include "OptimizeLDSUtility.h" -#include "Analysis/AMDGPUAllocation.h" -#include "triton/Analysis/Allocation.h" -#include "triton/Dialect/Triton/IR/Utility.h" -#include "triton/Dialect/TritonGPU/IR/Attributes.h" -#include "triton/Dialect/TritonGPU/IR/Dialect.h" -#include "llvm/Support/MathExtras.h" - -namespace mlir::triton::AMD { - -static void stepFactorizationPow2(std::vector> &factors, - SmallVector &curFactor, - int restTwos, int dim) { - if (dim == curFactor.size()) { - if (restTwos == 0) - factors.push_back(curFactor); - return; - } - curFactor[dim] = 1; - for (int i = 0; i <= restTwos; ++i) { - stepFactorizationPow2(factors, curFactor, restTwos - i, dim + 1); - curFactor[dim] *= 2; - } -} - -std::vector> factorizePowerOf2(int n, int rank) { - assert(llvm::isPowerOf2_32(n)); - int x = log2(n); - std::vector> factors; - SmallVector curFactor(rank, 1); - stepFactorizationPow2(factors, curFactor, x, 0); - return factors; -} - -triton::gpu::DistributedEncodingTrait -createTmpLayout(triton::gpu::DistributedEncodingTrait layout, - ArrayRef warpsPerCTA) { - auto ctx = layout.getContext(); - if (auto src = dyn_cast(layout)) - return triton::gpu::AMDMfmaEncodingAttr::get( - ctx, src.getVersion(), warpsPerCTA, src.getInstrShape(), - src.getIsTransposed(), src.getCTALayout(), src.getTilesPerWarp(), - src.getElementBitWidth()); - if (auto src = dyn_cast(layout)) - return triton::gpu::AMDWmmaEncodingAttr::get( - ctx, src.getVersion(), src.getIsTransposed(), warpsPerCTA, - src.getCTALayout(), src.getInstrShape()); - if (auto src = dyn_cast(layout)) - return triton::gpu::BlockedEncodingAttr::get( - ctx, src.getSizePerThread(), src.getThreadsPerWarp(), warpsPerCTA, - src.getOrder(), src.getCTALayout()); - if (auto src = dyn_cast(layout)) { - auto parent = cast(src.getParent()); - parent = createTmpLayout(parent, warpsPerCTA); - if (!parent) - return {}; - return triton::gpu::DotOperandEncodingAttr::get(ctx, src.getOpIdx(), parent, - src.getKWidth()); - } - if (auto src = dyn_cast(layout)) { - auto warps = to_vector(warpsPerCTA); - warps.insert(warps.begin() + src.getDim(), 1); - auto parent = createTmpLayout(src.getParent(), warps); - if (!parent) - return {}; - return triton::gpu::SliceEncodingAttr::get(ctx, src.getDim(), parent); - } - // TODO: support linear layout if needed. 
- if (isa(layout)) - return {}; - assert(false && "Encountered unsupported layout"); - return {}; -} - -std::pair -createNewConvertOps(OpBuilder &builder, triton::gpu::ConvertLayoutOp &cvtOp, - Attribute tmpLayout) { - auto srcType = cvtOp.getSrc().getType(); - auto dstType = cvtOp.getType(); - - auto newDstType = RankedTensorType::get( - dstType.getShape(), dstType.getElementType(), dstType.getEncoding()); - RankedTensorType newSrcType = RankedTensorType::get( - srcType.getShape(), srcType.getElementType(), tmpLayout); - - auto tmpCvt = builder.create( - cvtOp.getLoc(), newSrcType, cvtOp.getSrc()); - auto newEpilogueCvt = builder.create( - cvtOp.getLoc(), newDstType, tmpCvt); - tmpCvt->setAttrs(cvtOp->getAttrs()); - newEpilogueCvt->setAttrs(cvtOp->getAttrs()); - - return std::make_pair(tmpCvt, newEpilogueCvt); -} - -Resources -estimateResourcesForReplacement(OpBuilder builder, - mlir::triton::gpu::ConvertLayoutOp cvtOp, - Attribute tmpLayout) { - Resources res; - RankedTensorType srcTy = cvtOp.getSrc().getType(); - RankedTensorType dstTy = cvtOp.getType(); - RankedTensorType intermediateTy = RankedTensorType::get( - srcTy.getShape(), srcTy.getElementType(), tmpLayout); - auto *ctx = cvtOp->getContext(); - - int tmpCvtLDS = getConvertLayoutScratchInBytes(srcTy, intermediateTy, - /*usePadding*/ true); - int tmpCvtLDSNoPad = getConvertLayoutScratchInBytes(srcTy, intermediateTy, - /*usePadding*/ false); - int newCvtLDS = getConvertLayoutScratchInBytes(intermediateTy, dstTy, - /*usePadding*/ true); - int newCvtLDSNoPad = getConvertLayoutScratchInBytes(intermediateTy, dstTy, - /*usePadding*/ false); - - res.LDSPad = std::max(tmpCvtLDS, newCvtLDS); - res.LDSSwizzle = std::max(tmpCvtLDSNoPad, newCvtLDSNoPad); - - return res; -} - -} // namespace mlir::triton::AMD diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.h b/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.h deleted file mode 100644 index a994f37b1011..000000000000 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.h +++ /dev/null @@ -1,46 +0,0 @@ -#ifndef TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_OPTIMIZELDSUTILITY_H_ -#define TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_OPTIMIZELDSUTILITY_H_ - -#include "triton/Dialect/TritonGPU/IR/Dialect.h" - -namespace mlir::triton::AMD { - -std::vector> factorizePowerOf2(int n, int rank); - -/// Copy given layout with different warpsPerCTA parameter -/// -/// \param layout original layout -/// \param warpsPerCTA new warpsPerCTA -/// \returns create layout -triton::gpu::DistributedEncodingTrait -createTmpLayout(triton::gpu::DistributedEncodingTrait layout, - ArrayRef warpsPerCTA); - -/// Creates two chained convert layout operations -/// -/// %1 = cvtOp %0 (srcLayout -> dstLayout) // original operation -/// -> -/// %2 = cvtOp %0 (srcLayout -> tmpLayout) // .first -/// %3 = cvtOp %2 (tmpLayout -> dstLayout) // .second -/// -/// \param builder -/// \param cvtOp original operation -/// \param tmpLayout -/// \returns pair of created operations -std::pair -createNewConvertOps(OpBuilder &builder, triton::gpu::ConvertLayoutOp &cvtOp, - Attribute tmpLayout); - -struct Resources { - int LDSPad; - int LDSSwizzle; -}; - -Resources -estimateResourcesForReplacement(OpBuilder builder, - mlir::triton::gpu::ConvertLayoutOp cvtOp, - Attribute tmpLayout); - -} // namespace mlir::triton::AMD - -#endif // TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_OPTIMIZELDSUTILITY_H_ diff --git a/third_party/amd/python/triton_amd.cc b/third_party/amd/python/triton_amd.cc index 
35bf24a7fe91..4184f6156237 100644 --- a/third_party/amd/python/triton_amd.cc +++ b/third_party/amd/python/triton_amd.cc @@ -61,9 +61,6 @@ void init_triton_amd_passes_ttgpuir(py::module &&m) { pm.addPass(createTritonAMDGPULowerInstructionSchedHintsPass( arch, numStages)); }); - ADD_PASS_WRAPPER_2("add_optimize_lds_usage", - mlir::triton::AMD::createOptimizeLDSUsagePass, - const std::string &, int32_t); ADD_PASS_WRAPPER_0("add_allocate_shared_memory", mlir::triton::createAllocateAMDGPUSharedMemory); ADD_PASS_OPTION_WRAPPER_3("add_accelerate_matmul", diff --git a/third_party/amd/unittest/CMakeLists.txt b/third_party/amd/unittest/CMakeLists.txt deleted file mode 100644 index bd3c0c6c01c5..000000000000 --- a/third_party/amd/unittest/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_subdirectory(Conversion) diff --git a/third_party/amd/unittest/Conversion/CMakeLists.txt b/third_party/amd/unittest/Conversion/CMakeLists.txt deleted file mode 100644 index 1902c3dfd302..000000000000 --- a/third_party/amd/unittest/Conversion/CMakeLists.txt +++ /dev/null @@ -1,14 +0,0 @@ -add_triton_ut( - NAME TestOptimizeLDS - SRCS OptimizeLDSTest.cpp - LIBS - TritonAnalysis - TritonIR - TritonGPUIR - TritonAMDGPUToLLVM - MLIRUBToLLVM - TritonAMDUtils - TritonAMDAnalysis - TritonAMDGPUTransforms - TritonAMDGPUDialectToLLVM -) diff --git a/third_party/amd/unittest/Conversion/OptimizeLDSTest.cpp b/third_party/amd/unittest/Conversion/OptimizeLDSTest.cpp deleted file mode 100644 index a9f112239ff8..000000000000 --- a/third_party/amd/unittest/Conversion/OptimizeLDSTest.cpp +++ /dev/null @@ -1,42 +0,0 @@ -//===- OptimizeLDSTest.cpp - Tests for OptimizeLDSUtility -----------------===// - -#include "third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.h" -#include -#include - -namespace mlir { - -template bool checkProdEq(ArrayRef a) { - unsigned prod = - std::reduce(a.begin(), a.end(), 1u, std::multiplies()); - return prod == P; -} - -TEST(OptimizeLDSUtility, factorizePowerOf2) { - int numwarps; - int rank; - // check rank=1 generation - numwarps = 4; - rank = 1; - auto output1 = triton::AMD::factorizePowerOf2(numwarps, rank); - ASSERT_EQ(output1.size(), 1); - ASSERT_EQ(output1[0][0], numwarps); - // check rank=2 generation - numwarps = 8; - rank = 2; - auto output2 = triton::AMD::factorizePowerOf2(numwarps, rank); - ASSERT_EQ(output2.size(), 4); - ASSERT_TRUE(std::all_of(output2.begin(), output2.end(), checkProdEq<8>)); - ASSERT_TRUE(std::all_of(output2.begin(), output2.end(), - [](auto a) { return a.size() == 2; })); - // check rank=3 generation - numwarps = 8; - rank = 3; - auto output3 = triton::AMD::factorizePowerOf2(numwarps, rank); - ASSERT_EQ(output3.size(), 10); - ASSERT_TRUE(std::all_of(output3.begin(), output3.end(), checkProdEq<8>)); - ASSERT_TRUE(std::all_of(output3.begin(), output3.end(), - [](auto a) { return a.size() == 3; })); -} - -} // namespace mlir