1 change: 0 additions & 1 deletion bin/RegisterTritonDialects.h
@@ -94,7 +94,6 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
   mlir::triton::registerAllocateAMDGPUSharedMemory();
   mlir::triton::registerConvertTritonAMDGPUToLLVM();
   mlir::triton::registerConvertBuiltinFuncToLLVM();
-  mlir::triton::registerOptimizeAMDLDSUsage();

   mlir::ub::registerConvertUBToLLVMInterface(registry);
   mlir::registerConvertNVVMToLLVMInterface(registry);
5 changes: 5 additions & 0 deletions python/test/unit/language/test_core.py
@@ -6370,6 +6370,11 @@ def gather_test_kernel_1d(src_ptr, idx_ptr, out_ptr, axis: tl.constexpr, src_dim
 ])
 def test_gather(src_shape, indices_shape, axis, device):

+    if is_hip() and src_shape == [128, 64] and indices_shape == [256, 64]:
+        # This could be solved by reducing vectorization in the general swizzling algorithm.
+        # We will do this if any relevant workload suffers from the algorithm's large LDS consumption.
+        pytest.skip('Not enough LDS.')
+
     def triton_gather(src: torch.Tensor, axis: int, indices: torch.Tensor):
         output = torch.empty(indices.shape, dtype=src.dtype, device=src.device)

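For context, the behavior this test validates can be mirrored with plain torch.gather; a minimal reference sketch (shapes taken from the skipped HIP case, function name illustrative):

import torch

def reference_gather(src: torch.Tensor, axis: int, indices: torch.Tensor) -> torch.Tensor:
    # torch.gather selects src values along `axis` using `indices` and
    # returns a tensor with the shape of `indices`.
    return torch.gather(src, axis, indices)

src = torch.randn(128, 64)
idx = torch.randint(0, 128, (256, 64))
out = reference_gather(src, 0, idx)
assert out.shape == idx.shape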
27 changes: 0 additions & 27 deletions test/Conversion/amd/allocate_shared_memory.mlir
@@ -1,32 +1,5 @@
 // RUN: triton-opt %s -split-input-file --allocate-amdgpu-shared-memory | FileCheck %s

-#blocked1 = #ttg.blocked<{sizePerThread = [8, 4], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
-#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
-
-// This test checks the padding-based converter.
-//
-// The converter allocates a temporary buffer and stores and reads parts of the tensor in a few transactions, which are named repeats.
-// The size of the temporary buffer is computed using the following algorithm:
-// - get the CTA tile shape of the blocked1 layout: [8*8*4, 4*8*1] = [256, 32]
-// - get the CTA tile shape of the blocked2 layout: [1*8*4, 1*8*1] = [32, 8]
-// - compute the common tile shape: [max(256, 32), max(32, 8)] = [256, 32]
-// - pad the fastest dimension (the same as the output layout's, 1 in this case) with the size of one memory access to reduce bank conflicts; 16 bytes in this case.
-//
-// Therefore the total memory consumption of the scratch buffer is 256*(32 * 4 (size of one element) + 16 (padding)) = 36864 bytes.
-//
-// For the implementation see the mlir::triton::getNumScratchElemsPaddedCvt function.
-
-// CHECK: ttg.shared = 36864 : i32
-module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
-
-// CHECK-LABEL: @convert_layout_padded
-tt.func @convert_layout_padded(%arg0: tensor<256x256xi32, #blocked1>) {
-  // CHECK-NEXT: allocation.offset = 0 : i32
-  %0 = ttg.convert_layout %arg0 {amdgpu.use_padded_scratch_shmem} : tensor<256x256xi32, #blocked1> -> tensor<256x256xi32, #blocked2>
-  tt.return
-}
-
-}

 // -----

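The arithmetic in the deleted comment above can be checked directly; a quick sketch of the scratch-size computation (values copied from the test, variable names illustrative):

# CTA tile shape per dimension = sizePerThread * threadsPerWarp * warpsPerCTA.
blocked1_tile = (8 * 8 * 4, 4 * 8 * 1)  # (256, 32)
blocked2_tile = (1 * 8 * 4, 1 * 8 * 1)  # (32, 8)

# The common tile is the elementwise max of the two CTA tiles.
common_tile = tuple(max(a, b) for a, b in zip(blocked1_tile, blocked2_tile))  # (256, 32)

elem_bytes = 4   # i32 elements
pad_bytes = 16   # one memory access of padding on the fastest dimension

scratch_bytes = common_tile[0] * (common_tile[1] * elem_bytes + pad_bytes)
assert scratch_bytes == 36864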
24 changes: 0 additions & 24 deletions test/Conversion/amd/convert_layout.mlir
@@ -28,27 +28,3 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
     tt.return
   }
 }
-
-// -----
-
-#blocked0 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [16, 4], warpsPerCTA = [2, 2], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
-#blocked1 = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 16], warpsPerCTA = [2, 2], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
-module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
-// CHECK-LABEL: convert_layout_padding_swizzling
-tt.func @convert_layout_padding_swizzling(%arg0: tensor<64x64xf32, #blocked0>, %arg1: tensor<64x64x!tt.ptr<f32>, #blocked1>) {
-
-  // verify that the following convert_layout uses the padded path
-  // see the getVecAddr lambda in the transferWithinBlockImpl function
-
-  // CHECK-DAG: [[CST_0:%.*]] = llvm.mlir.constant(0 : i32) : i32
-  // CHECK-DAG: [[CST_5:%.*]] = llvm.mlir.constant(5 : i32) : i32
-  // CHECK-DAG: [[OFFSET_0:%.*]] = llvm.lshr {{.*}}, [[CST_5]] : i32
-  // CHECK: [[OFFSET_1:%.*]] = llvm.shl [[OFFSET_0]], [[CST_0]] : i32
-  // CHECK: [[OFFSET_2:%.*]] = llvm.add [[OFFSET_1]], {{.*}} : i32
-  // CHECK: llvm.getelementptr inbounds {{.*}}{{\[}}[[OFFSET_2]]{{\]}}
-
-  %0 = ttg.convert_layout %arg0 {amdgpu.use_padded_scratch_shmem} : tensor<64x64xf32, #blocked0> -> tensor<64x64xf32, #blocked1>
-  tt.store %arg1, %0 : tensor<64x64x!tt.ptr<f32>, #blocked1>
-  tt.return
-}
-}
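The lshr/shl/add sequence in the deleted CHECK lines implements padded addressing; a sketch of the same computation (constants taken from the CHECK lines, function name illustrative):

def padded_offset(linear_offset: int, log2_row: int = 5, log2_pad: int = 0) -> int:
    # After every 2**log2_row elements, 2**log2_pad padding elements are
    # skipped: offset + ((offset >> log2_row) << log2_pad).
    return linear_offset + ((linear_offset >> log2_row) << log2_pad)

assert padded_offset(31) == 31  # still inside the first 32-element row
assert padded_offset(32) == 33  # one pad element inserted per row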
170 changes: 0 additions & 170 deletions test/TritonGPU/amd/optimize-lds-usage.mlir

This file was deleted.

3 changes: 0 additions & 3 deletions third_party/amd/CMakeLists.txt
@@ -9,7 +9,4 @@ if(TRITON_BUILD_PYTHON_MODULE)
   add_triton_plugin(TritonAMD ${CMAKE_CURRENT_SOURCE_DIR}/python/triton_amd.cc LINK_LIBS TritonAMDGPUToLLVM TritonAMDGPUTransforms TritonAMDGPUDialectToLLVM)
   target_link_libraries(TritonAMD PRIVATE Python3::Module pybind11::headers lldCommon lldELF)
 endif()
-if(TRITON_BUILD_UT)
-  add_subdirectory(unittest)
-endif()
 add_subdirectory(test)
2 changes: 0 additions & 2 deletions third_party/amd/backend/compiler.py
@@ -284,8 +284,6 @@ def make_llir(src, metadata, options):
         #
         # If custom_lds_size = 0, the pass assumes the whole LDS is available to one thread block;
         # the LDS size is then determined by the provided arch name.
-        custom_lds_size = 0
-        amd.passes.ttgpuir.add_optimize_lds_usage(pm, options.arch, custom_lds_size)
         passes.convert.add_scf_to_cf(pm)
         passes.gluon.add_inliner(pm)
         passes.convert.add_index_to_llvmir(pm)
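As a rough budget check (assuming 64 KiB of LDS per workgroup on gfx942/CDNA3; that figure is an assumption, not taken from this diff), the padded scratch buffer from the test above fits without the removed pass:

LDS_BYTES = 64 * 1024   # assumed per-workgroup LDS on gfx942 (CDNA3)
scratch_bytes = 36864   # padded scratch size from allocate_shared_memory.mlir
assert scratch_bytes <= LDS_BYTES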
42 changes: 1 addition & 41 deletions third_party/amd/include/Analysis/AMDGPUAllocation.h
@@ -6,57 +6,17 @@

 namespace mlir::triton::AMD {

-constexpr char AttrSharedMemPadded[] = "amdgpu.use_padded_scratch_shmem";

 unsigned getConvertLayoutScratchInBytes(RankedTensorType srcTy,
-                                        RankedTensorType dstTy,
-                                        bool usePadding);
+                                        RankedTensorType dstTy);

 unsigned AMDAllocationAnalysisScratchSizeFn(Operation *op);

-// To convert a tensor from one layout to another, we need to allocate a
-// temporary buffer (i.e., scratch buffer) in shared memory. The conversion may
-// require multiple iterations, with each iteration involving multiple
-// vectorized loads/stores. The scratch buffer has a shape (`repShape`) that
-// represents the maximum size accessed in each dimension during each iteration.
-// It is padded (`paddedRepShape`) to avoid bank conflicts and is accessed in a
-// specific `order`.
-struct ScratchConfig {
-  SmallVector<unsigned> repShape;
-  SmallVector<unsigned> paddedRepShape;
-  SmallVector<unsigned> order;
-  unsigned inVec;
-  unsigned outVec;

-  ScratchConfig(SmallVector<unsigned> repShape,
-                SmallVector<unsigned> paddedRepShape, unsigned inVec = 1,
-                unsigned outVec = 1)
-      : repShape(repShape), paddedRepShape(paddedRepShape), inVec(inVec),
-        outVec(outVec) {}

-  void print(llvm::raw_ostream &os) const {
-    os << "repShape: [";
-    llvm::interleaveComma(repShape, os);
-    os << "]";
-    os << ", paddedRepShape: [";
-    llvm::interleaveComma(paddedRepShape, os);
-    os << "]";
-    os << ", order: [";
-    llvm::interleaveComma(order, os);
-    os << "]";
-    os << ", inVec: " << inVec << ", outVec: " << outVec << "\n";
-  }
-};

-// For a layout conversion between `srcTy` and `dstTy`, return the vector
-// lengths that can be used for the stores to and loads from shared memory,
-// respectively.
-std::pair</*inVec*/ unsigned, /*outVec*/ unsigned>
-getScratchCvtInOutVecLengths(RankedTensorType srcTy, RankedTensorType dstTy);

-ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
-                                     RankedTensorType dstTy);

 } // namespace mlir::triton::AMD

 #endif // TRITONAMD_ANALYSIS_AMDGPU_ALLOCATION_H
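The deleted header comment says the fastest-varying dimension (order[0]) is padded to avoid bank conflicts; a sketch of how repShape would become paddedRepShape under that rule (names mirror the deleted struct fields; the pad size is illustrative):

def padded_rep_shape(rep_shape, order, pad_elems):
    # Pad the fastest-varying dimension (order[0]) by pad_elems elements,
    # mirroring the repShape -> paddedRepShape relationship above.
    padded = list(rep_shape)
    padded[order[0]] += pad_elems
    return padded

# 16 bytes of padding on 4-byte elements = 4 extra elements per row.
assert padded_rep_shape([256, 32], order=[1, 0], pad_elems=4) == [256, 36]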
8 changes: 0 additions & 8 deletions third_party/amd/include/TritonAMDGPUToLLVM/Passes.h
@@ -24,14 +24,6 @@
 } // namespace mlir::triton

 namespace mlir::triton::AMD {
-/// @brief Creates a pass that keeps LDS consumption within the specified limit.
-/// @param arch target architecture name, for example "gfx940"
-/// @param customLDSLimit LDS size available to one thread block;
-/// a value of zero tells the pass that the whole LDS of the device is available
-/// @return the created pass
-std::unique_ptr<OperationPass<ModuleOp>>
-createOptimizeLDSUsagePass(StringRef arch, int32_t customLDSLimit = 0);
-
 void runScalarizePackedFOpsPass(llvm::Function &F);

 } // namespace mlir::triton::AMD