diff --git a/mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td b/mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td index c608b61854ae..5493b7120632 100644 --- a/mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td +++ b/mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td @@ -251,7 +251,7 @@ def Rock_GeneralGemmParamsAttr : Rock_Attr<"GeneralGemmParams", [RockTuningParam let extraClassDeclaration = [{ void getPerfConfigStr(::llvm::SmallVectorImpl &perfStr) { - ("v4:" + Twine(getBlockSize()) + "," + ("v3:" + Twine(getBlockSize()) + "," + Twine(getMPerBlock()) + "," + Twine(getNPerBlock()) + "," + Twine(getKPerBlock()) + "," diff --git a/mlir/test/rocmlir-gen/emit-tuning-space.mlir b/mlir/test/rocmlir-gen/emit-tuning-space.mlir index 93b4f448c655..d7f20c93d21f 100644 --- a/mlir/test/rocmlir-gen/emit-tuning-space.mlir +++ b/mlir/test/rocmlir-gen/emit-tuning-space.mlir @@ -1,5 +1,5 @@ // RUN: rocmlir-gen -p --arch gfx1100 --operation=gemm --emit-tuning-space=full | FileCheck %s --check-prefixes=CHECK-NAVI -// CHECK-NAVI: v4:64,32,32,4,2,4,1,1,2 +// CHECK-NAVI: v3:64,32,32,4,2,4,1,1,2 // RUN: rocmlir-gen --arch gfx90a --operation=gemm -t f32 -g 1 -m 64 -k 128 -n 64 --num_cu=104 --emit-tuning-space=full | FileCheck %s --check-prefixes=CHECK-MI // CHECK-MI: v4:64,64,8,32,32,16,4,4,1,2,1,1 diff --git a/mlir/unittests/Dialect/Rock/CMakeLists.txt b/mlir/unittests/Dialect/Rock/CMakeLists.txt index 7b4e2b1383cf..74ba5e09a72b 100644 --- a/mlir/unittests/Dialect/Rock/CMakeLists.txt +++ b/mlir/unittests/Dialect/Rock/CMakeLists.txt @@ -4,6 +4,7 @@ set(ROCK_UNITTEST_SOURCES loweringUtilsTests.cpp transformMapUtilsTests.cpp InitParamsAccelTests.cpp + InitParamsNonAccelTests.cpp ) if(NOT WIN32) diff --git a/mlir/unittests/Dialect/Rock/InitParamsNonAccelTests.cpp b/mlir/unittests/Dialect/Rock/InitParamsNonAccelTests.cpp new file mode 100644 index 000000000000..58ed3a745707 --- /dev/null +++ b/mlir/unittests/Dialect/Rock/InitParamsNonAccelTests.cpp @@ -0,0 +1,84 @@ +//===- InitParamsNonAccelTests.cpp - Tests for InitParamsNonAccel +//--------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Rock/Tuning/GridwiseGemmParams.h" +#include "gtest/gtest.h" + +using namespace mlir; +using namespace mlir::rock; + +namespace { + +//===----------------------------------------------------------------------===// +// v3 perfconfig +//===----------------------------------------------------------------------===// + +TEST(V3Config, First) { + InitParamsNonAccel validParams; + bool isValidPerfConfig = validParams.deserialize("v3:64,32,32,4,2,4,1,1,2"); + + EXPECT_EQ(isValidPerfConfig, true); + EXPECT_EQ(validParams.blockSize, static_cast(64)); + EXPECT_EQ(validParams.gemmMPerBlock, 32); + EXPECT_EQ(validParams.gemmNPerBlock, 32); + EXPECT_EQ(validParams.gemmKPerBlock, 4); + EXPECT_EQ(validParams.gemmMPerThread, 2); + EXPECT_EQ(validParams.gemmNPerThread, 4); + EXPECT_EQ(validParams.splitKFactor, 1); + EXPECT_EQ(validParams.gemmScheduleVersion, 1); + EXPECT_EQ(validParams.outputSwizzle, 2); + EXPECT_EQ(validParams.getKPack(), 1); + EXPECT_EQ(validParams.getVersion(), InitParamsNonAccel::Version::V3); +} + +TEST(V3Config, Second) { + InitParamsNonAccel validParams; + bool isValidPerfConfig = validParams.deserialize("v3:128,64,32,8,4,2,3,1,2"); + + EXPECT_EQ(isValidPerfConfig, true); + EXPECT_EQ(validParams.blockSize, static_cast(128)); + EXPECT_EQ(validParams.gemmMPerBlock, 64); + EXPECT_EQ(validParams.gemmNPerBlock, 32); + EXPECT_EQ(validParams.gemmKPerBlock, 8); + EXPECT_EQ(validParams.gemmMPerThread, 4); + EXPECT_EQ(validParams.gemmNPerThread, 2); + EXPECT_EQ(validParams.splitKFactor, 3); + EXPECT_EQ(validParams.gemmScheduleVersion, 1); + EXPECT_EQ(validParams.outputSwizzle, 2); + EXPECT_EQ(validParams.getKPack(), 1); + EXPECT_EQ(validParams.getVersion(), InitParamsNonAccel::Version::V3); +} + +//===----------------------------------------------------------------------===// +// Negative Tests +//===----------------------------------------------------------------------===// + +TEST(NegativeTests, NoVersion) { + InitParamsNonAccel validParams; + bool isValidPerfConfig = + validParams.deserialize("128,64,8,64,32,4,9,2,2,0,1"); + + EXPECT_EQ(isValidPerfConfig, false); +} + +TEST(NegativeTests, WrongNumberV3) { + InitParamsNonAccel validParams; + bool isValidPerfConfig = validParams.deserialize("v3:64,32,32,4,2,4,1,1"); + + EXPECT_EQ(isValidPerfConfig, false); +} + +TEST(NegativeTests, Empty) { + InitParamsNonAccel validParams; + bool isValidPerfConfig = validParams.deserialize(""); + + EXPECT_EQ(isValidPerfConfig, false); +} + +} // end anonymous namespace