diff --git a/tools/clang/unittests/HLSLExec/HlslExecTestUtils.cpp b/tools/clang/unittests/HLSLExec/HlslExecTestUtils.cpp index 10dfc63b37..d2c3cac0b2 100644 --- a/tools/clang/unittests/HLSLExec/HlslExecTestUtils.cpp +++ b/tools/clang/unittests/HLSLExec/HlslExecTestUtils.cpp @@ -751,7 +751,8 @@ void compileShader(dxc::SpecificDllLoader &DxcSupport, const char *Source, if (VerboseLogging) { hlsl_test::LogCommentFmt(L"Shader Source:"); - hlsl_test::LogCommentFmt(L"%c", Source); + hlsl_test::LogCommentFmt( + std::wstring(Source, Source + strlen(Source)).c_str()); } hlsl_test::LogCommentFmt(LogFlags.str().c_str()); diff --git a/tools/clang/unittests/HLSLExec/LinAlgTests.cpp b/tools/clang/unittests/HLSLExec/LinAlgTests.cpp index da32f553c4..b276d823dd 100644 --- a/tools/clang/unittests/HLSLExec/LinAlgTests.cpp +++ b/tools/clang/unittests/HLSLExec/LinAlgTests.cpp @@ -199,35 +199,45 @@ static bool verifyComponentBuffer(ComponentType CompType, const void *Actual, } static bool fillInputBuffer(LPCSTR Name, std::vector &Data, - ComponentType CompType, size_t NumElements) { + ComponentType CompType, size_t NumElements, + size_t StartingVal = 1, bool Increment = true) { if (_stricmp(Name, "Input") != 0) return true; switch (CompType) { - case ComponentType::F32: { - float *Ptr = reinterpret_cast(Data.data()); - for (size_t I = 0; I < NumElements; I++) - Ptr[I] = static_cast(I + 1); - return true; - } - case ComponentType::I32: { - int32_t *Ptr = reinterpret_cast(Data.data()); - for (size_t I = 0; I < NumElements; I++) - Ptr[I] = static_cast(I + 1); - return true; - } - case ComponentType::F16: { - HLSLHalf_t *Ptr = reinterpret_cast(Data.data()); - for (size_t I = 0; I < NumElements; I++) - Ptr[I] = HLSLHalf_t(static_cast(I + 1)); - return true; + case ComponentType::F32: + case ComponentType::I32: + case ComponentType::F16: + break; + default: + return false; } + + for (size_t I = 0; I < NumElements; ++I) { + size_t Value = StartingVal + (Increment ? I : 0); + switch (CompType) { + case ComponentType::F32: { + float *Ptr = reinterpret_cast(Data.data()); + Ptr[I] = static_cast(Value); + break; + } + case ComponentType::I32: { + int32_t *Ptr = reinterpret_cast(Data.data()); + Ptr[I] = static_cast(Value); + break; + } + case ComponentType::F16: { + HLSLHalf_t *Ptr = reinterpret_cast(Data.data()); + Ptr[I] = HLSLHalf_t(static_cast(Value)); + break; + } + } } - return false; + return true; } -static VariantCompType makeExpected(ComponentType CompType, MatrixDim M, +static VariantCompType makeExpectedMat(ComponentType CompType, MatrixDim M, MatrixDim N, float StartingVal, bool Increment = true, bool Transpose = false) { @@ -281,6 +291,12 @@ static VariantCompType makeExpected(ComponentType CompType, MatrixDim M, } } +static VariantCompType makeExpectedVec(ComponentType CompType, MatrixDim NumElements, + float StartingVal, + bool Increment = true) { + return makeExpectedMat(CompType, 1, NumElements, StartingVal, Increment, false); +} + class DxilConf_SM610_LinAlg { public: BEGIN_TEST_CLASS(DxilConf_SM610_LinAlg) @@ -299,14 +315,32 @@ class DxilConf_SM610_LinAlg { TEST_CLASS_SETUP(setupClass); TEST_METHOD_SETUP(setupMethod); - // Load/Store - TEST_METHOD(LoadStoreRoundtrip_Wave_16x16_F16); - - // Splat Store + // Load/Store/Accumulate Descriptor + TEST_METHOD(LoadStoreDescriptor_Wave_16x16_F16); TEST_METHOD(SplatStore_Wave_16x16_F16); + TEST_METHOD(AccumulateDescriptor_Wave_16x16_F16); + TEST_METHOD(AccumulateDescriptor_Thread_16x16_F16); // Element access TEST_METHOD(ElementAccess_Wave_16x16_F16); + TEST_METHOD(ElementSet_Wave_16x16_F16); + + // Cast/Convert + TEST_METHOD(CopyConvert_Wave_16x16_F16); + TEST_METHOD(CopyConvert_Wave_16x16_F16_Transpose); + + // Matrix Matrix Arithmetic + TEST_METHOD(MatMatMul_Wave_16x16x16_F16); + TEST_METHOD(MatMatMulAccum_Wave_16x16x16_F16); + TEST_METHOD(MatAccum_Wave_16x16_F16); + + // Matrix Vector Arithmetic + TEST_METHOD(MatVecMul_Thread_16x16_F16); + TEST_METHOD(MatVecMulAdd_Thread_16x16_F16); + TEST_METHOD(OuterProduct_Thread_16x16_F16); + + // Query Accumulator Layout + TEST_METHOD(QueryAccumLayout); private: CComPtr D3DDevice; @@ -358,7 +392,7 @@ bool DxilConf_SM610_LinAlg::setupMethod() { return D3D12SDK->createDevice(&D3DDevice, D3D_SHADER_MODEL_6_10, false); } -static const char LoadStoreShader[] = R"( +static const char LoadStoreDescriptorShader[] = R"( RWByteAddressBuffer Input : register(u0); RWByteAddressBuffer Output : register(u1); @@ -378,9 +412,9 @@ static const char LoadStoreShader[] = R"( } )"; -static void runLoadStoreRoundtrip(ID3D12Device *Device, - dxc::SpecificDllLoader &DxcSupport, - const MatrixParams &Params, bool Verbose) { +static void runLoadStoreDescriptor(ID3D12Device *Device, + dxc::SpecificDllLoader &DxcSupport, + const MatrixParams &Params, bool Verbose) { const size_t NumElements = Params.totalElements(); const size_t BufferSize = Params.totalBytes(); @@ -390,13 +424,14 @@ static void runLoadStoreRoundtrip(ID3D12Device *Device, std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str()); - compileShader(DxcSupport, LoadStoreShader, "cs_6_10", Args, Verbose); + compileShader(DxcSupport, LoadStoreDescriptorShader, "cs_6_10", Args, + Verbose); - auto Expected = makeExpected(Params.CompType, Params.M, Params.N, 1); + auto Expected = makeExpectedMat(Params.CompType, Params.M, Params.N, 1); // Construct the ShaderOp: two UAV buffers, load from one, store to other. - auto Op = createComputeOp(LoadStoreShader, "cs_6_10", "UAV(u0), UAV(u1)", - Args.c_str()); + auto Op = createComputeOp(LoadStoreDescriptorShader, "cs_6_10", + "UAV(u0), UAV(u1)", Args.c_str()); addUAVBuffer(Op.get(), "Input", BufferSize, false, "byname"); addUAVBuffer(Op.get(), "Output", BufferSize, true); addRootUAV(Op.get(), 0, "Input"); @@ -418,7 +453,7 @@ static void runLoadStoreRoundtrip(ID3D12Device *Device, Expected, NumElements, Verbose)); } -void DxilConf_SM610_LinAlg::LoadStoreRoundtrip_Wave_16x16_F16() { +void DxilConf_SM610_LinAlg::LoadStoreDescriptor_Wave_16x16_F16() { MatrixParams Params = {}; Params.CompType = ComponentType::F16; Params.M = 16; @@ -428,7 +463,7 @@ void DxilConf_SM610_LinAlg::LoadStoreRoundtrip_Wave_16x16_F16() { Params.Layout = LinalgMatrixLayout::RowMajor; Params.NumThreads = 64; Params.Enable16Bit = true; - runLoadStoreRoundtrip(D3DDevice, DxcSupport, Params, VerboseLogging); + runLoadStoreDescriptor(D3DDevice, DxcSupport, Params, VerboseLogging); } static const char SplatStoreShader[] = R"( @@ -464,7 +499,7 @@ static void runSplatStore(ID3D12Device *Device, compileShader(DxcSupport, SplatStoreShader, "cs_6_10", Args, Verbose); auto Expected = - makeExpected(Params.CompType, Params.M, Params.N, FillValue, false); + makeExpectedMat(Params.CompType, Params.M, Params.N, FillValue, false); auto Op = createComputeOp(SplatStoreShader, "cs_6_10", "UAV(u0)", Args.c_str()); @@ -493,6 +528,82 @@ void DxilConf_SM610_LinAlg::SplatStore_Wave_16x16_F16() { runSplatStore(D3DDevice, DxcSupport, Params, 42.0f, VerboseLogging); } +static const char AccumulateDescriptorShader[] = R"( + #define USE_ACC 2 + RWByteAddressBuffer Output : register(u0); + + [WaveSize(4, 64)] + [numthreads(NUMTHREADS, 1, 1)] + void main(uint threadID : SV_GroupIndex) { + if (WaveReadLaneFirst(threadID) != 0) + return; + + __builtin_LinAlgMatrix + [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE_ACC, SCOPE)]] + Mat; + __builtin_LinAlg_FillMatrix(Mat, FILL_VALUE); + __builtin_LinAlg_MatrixAccumulateToDescriptor( + Mat, Output, 0, STRIDE, LAYOUT, 128); + } +)"; + +static void runAccumulateDescriptor(ID3D12Device *Device, + dxc::SpecificDllLoader &DxcSupport, + const MatrixParams &Params, float FillValue, + bool Verbose) { + const size_t NumElements = Params.totalElements(); + const size_t BufferSize = Params.totalBytes(); + + std::stringstream ExtraDefs; + ExtraDefs << "-DFILL_VALUE=" << FillValue; + + std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str()); + + compileShader(DxcSupport, AccumulateDescriptorShader, "cs_6_10", Args, Verbose); + + auto Expected = + makeExpectedMat(Params.CompType, Params.M, Params.N, FillValue, false); + + auto Op = + createComputeOp(AccumulateDescriptorShader, "cs_6_10", "UAV(u0)", Args.c_str()); + addUAVBuffer(Op.get(), "Output", BufferSize, true); + addRootUAV(Op.get(), 0, "Output"); + + auto Result = runShaderOp(Device, DxcSupport, std::move(Op)); + + MappedData OutData; + Result->Test->GetReadBackData("Output", &OutData); + + VERIFY_IS_TRUE(verifyComponentBuffer(Params.CompType, OutData.data(), + Expected, NumElements, Verbose)); +} + +void DxilConf_SM610_LinAlg::AccumulateDescriptor_Wave_16x16_F16() { + MatrixParams Params = {}; + Params.CompType = ComponentType::F16; + Params.M = 16; + Params.N = 16; + Params.Use = MatrixUse::Accumulator; + Params.Scope = MatrixScope::Wave; + Params.Layout = LinalgMatrixLayout::RowMajor; + Params.NumThreads = 64; + Params.Enable16Bit = true; + runAccumulateDescriptor(D3DDevice, DxcSupport, Params, 42.0f, VerboseLogging); +} + +void DxilConf_SM610_LinAlg::AccumulateDescriptor_Thread_16x16_F16() { + MatrixParams Params = {}; + Params.CompType = ComponentType::F16; + Params.M = 16; + Params.N = 16; + Params.Use = MatrixUse::Accumulator; + Params.Scope = MatrixScope::Thread; + Params.Layout = LinalgMatrixLayout::RowMajor; + Params.NumThreads = 1; + Params.Enable16Bit = true; + runAccumulateDescriptor(D3DDevice, DxcSupport, Params, 42.0f, VerboseLogging); +} + static const char ElementAccessShader[] = R"( RWByteAddressBuffer Input : register(u0); RWByteAddressBuffer Output : register(u1); @@ -537,25 +648,20 @@ static void runElementAccess(ID3D12Device *Device, const MatrixParams &Params, bool Verbose) { const size_t NumElements = Params.totalElements(); const size_t NumThreads = Params.NumThreads; - const size_t InputBufSize = Params.totalBytes(); - const size_t ElementSize = elementSize(Params.CompType); - - // Output: ElementSize bytes per element - // 1 element for each mat idx - // 1 uint for each thread's length - const size_t OutputBufSize = - NumElements * ElementSize + NumThreads * sizeof(uint32_t); + const size_t MatrixSize = Params.totalBytes(); + // OutputBuf needs to fit the Matrix plus one uint per thread + const size_t OutputBufSize = MatrixSize + NumThreads * sizeof(uint32_t); std::stringstream ExtraDefs; std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str()); compileShader(DxcSupport, ElementAccessShader, "cs_6_10", Args, Verbose); - auto Expected = makeExpected(Params.CompType, Params.M, Params.N, 1); + auto Expected = makeExpectedMat(Params.CompType, Params.M, Params.N, 1); auto Op = createComputeOp(ElementAccessShader, "cs_6_10", "UAV(u0), UAV(u1)", Args.c_str()); - addUAVBuffer(Op.get(), "Input", InputBufSize, false, "byname"); + addUAVBuffer(Op.get(), "Input", MatrixSize, false, "byname"); addUAVBuffer(Op.get(), "Output", OutputBufSize, true); addRootUAV(Op.get(), 0, "Input"); addRootUAV(Op.get(), 1, "Output"); @@ -579,9 +685,8 @@ static void runElementAccess(ID3D12Device *Device, // Verify the end of the buffer is NumThreads number of lengths, whose // sum is greater than or equal to NumElements const BYTE *Out = static_cast(OutData.data()); - size_t MatrixEndOffset = NumElements * ElementSize; const uint32_t *Lengths = - reinterpret_cast(Out + MatrixEndOffset); + reinterpret_cast(Out + MatrixSize); uint32_t TotalLength = 0; for (size_t I = 0; I < NumThreads; ++I) TotalLength += Lengths[I]; @@ -602,4 +707,712 @@ void DxilConf_SM610_LinAlg::ElementAccess_Wave_16x16_F16() { runElementAccess(D3DDevice, DxcSupport, Params, VerboseLogging); } +static const char ElementSetShader[] = R"( + RWByteAddressBuffer Input : register(u0); + RWByteAddressBuffer Output : register(u1); + + [WaveSize(4, 64)] + [numthreads(NUMTHREADS, 1, 1)] + void main(uint threadID : SV_GroupIndex) { + if (WaveReadLaneFirst(threadID) != 0) + return; + + __builtin_LinAlgMatrix + [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE, SCOPE)]] + Mat; + __builtin_LinAlg_MatrixLoadFromDescriptor( + Mat, Input, 0, STRIDE, LAYOUT, 128); + + // Increment every element by 5 + for (uint I = 0; I < __builtin_LinAlg_MatrixLength(Mat); ++I) { + ELEM_TYPE Elem; + __builtin_LinAlg_MatrixGetElement(Elem, Mat, I); + Elem = Elem + 5; + __builtin_LinAlg_MatrixSetElement(Mat, Mat, I, Elem); + } + + __builtin_LinAlg_MatrixStoreToDescriptor( + Mat, Output, 0, STRIDE, LAYOUT, 128); + } +)"; + +static void runElementSet(ID3D12Device *Device, + dxc::SpecificDllLoader &DxcSupport, + const MatrixParams &Params, bool Verbose) { + const size_t NumElements = Params.totalElements(); + const size_t MatrixSize = Params.totalBytes(); + + std::stringstream ExtraDefs; + std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str()); + + compileShader(DxcSupport, ElementSetShader, "cs_6_10", Args, Verbose); + + // Start counting from 6 since each element was increased by 5 + auto Expected = makeExpectedMat(Params.CompType, Params.M, Params.N, 6); + + auto Op = createComputeOp(ElementSetShader, "cs_6_10", "UAV(u0), UAV(u1)", + Args.c_str()); + addUAVBuffer(Op.get(), "Input", MatrixSize, false, "byname"); + addUAVBuffer(Op.get(), "Output", MatrixSize, true); + addRootUAV(Op.get(), 0, "Input"); + addRootUAV(Op.get(), 1, "Output"); + + auto Result = + runShaderOp(Device, DxcSupport, std::move(Op), + [NumElements, Params](LPCSTR Name, std::vector &Data, + st::ShaderOp *) { + VERIFY_IS_TRUE(fillInputBuffer(Name, Data, Params.CompType, + NumElements), + "Saw unsupported component type"); + }); + + MappedData OutData; + Result->Test->GetReadBackData("Output", &OutData); + + // Verify the front of the buffer is a list of elements of the expected type + VERIFY_IS_TRUE(verifyComponentBuffer(Params.CompType, OutData.data(), + Expected, NumElements, Verbose)); +} + +void DxilConf_SM610_LinAlg::ElementSet_Wave_16x16_F16() { + MatrixParams Params = {}; + Params.CompType = ComponentType::F16; + Params.M = 16; + Params.N = 16; + Params.Use = MatrixUse::Accumulator; + Params.Scope = MatrixScope::Wave; + Params.Layout = LinalgMatrixLayout::RowMajor; + Params.NumThreads = 64; + Params.Enable16Bit = true; + runElementSet(D3DDevice, DxcSupport, Params, VerboseLogging); +} + +static const char CopyConvertShader[] = R"( + RWByteAddressBuffer Input : register(u0); + RWByteAddressBuffer Output : register(u1); + + [WaveSize(4, 64)] + [numthreads(NUMTHREADS, 1, 1)] + void main(uint threadID : SV_GroupIndex) { + if (WaveReadLaneFirst(threadID) != 0) + return; + + __builtin_LinAlgMatrix + [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE, SCOPE)]] + Src; + __builtin_LinAlgMatrix + [[__LinAlgMatrix_Attributes(COMP_TYPE, N_DIM, M_DIM, USE, SCOPE)]] + Dst; + + __builtin_LinAlg_MatrixLoadFromDescriptor( + Src, Input, 0, STRIDE, LAYOUT, 128); + __builtin_LinAlg_CopyConvertMatrix(Dst, Src, TRANSPOSE); + __builtin_LinAlg_MatrixStoreToDescriptor( + Dst, Output, 0, STRIDE, LAYOUT, 128); + } +)"; + +static void runCopyConvert(ID3D12Device *Device, + dxc::SpecificDllLoader &DxcSupport, + const MatrixParams &Params, bool Verbose, + bool Transpose) { + const size_t NumElements = Params.totalElements(); + const size_t BufferSize = Params.totalBytes(); + + std::stringstream ExtraDefs; + ExtraDefs << " -DTRANSPOSE=" << Transpose; + + std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str()); + + compileShader(DxcSupport, CopyConvertShader, "cs_6_10", Args, Verbose); + + auto Expected = makeExpectedMat(Params.CompType, Params.M, Params.N, 1, + /*Increment=*/true, Transpose); + + // Construct the ShaderOp: two UAV buffers, load from one, store to other. + auto Op = createComputeOp(CopyConvertShader, "cs_6_10", "UAV(u0), UAV(u1)", + Args.c_str()); + addUAVBuffer(Op.get(), "Input", BufferSize, false, "byname"); + addUAVBuffer(Op.get(), "Output", BufferSize, true); + addRootUAV(Op.get(), 0, "Input"); + addRootUAV(Op.get(), 1, "Output"); + + auto Result = + runShaderOp(Device, DxcSupport, std::move(Op), + [NumElements, Params](LPCSTR Name, std::vector &Data, + st::ShaderOp *) { + VERIFY_IS_TRUE(fillInputBuffer(Name, Data, Params.CompType, + NumElements), + "Saw unsupported component type"); + }); + + MappedData OutData; + Result->Test->GetReadBackData("Output", &OutData); + + VERIFY_IS_TRUE(verifyComponentBuffer(Params.CompType, OutData.data(), + Expected, NumElements, Verbose)); +} + +void DxilConf_SM610_LinAlg::CopyConvert_Wave_16x16_F16() { + MatrixParams Params = {}; + Params.CompType = ComponentType::F16; + Params.M = 16; + Params.N = 16; + Params.Use = MatrixUse::A; + Params.Scope = MatrixScope::Wave; + Params.Layout = LinalgMatrixLayout::RowMajor; + Params.NumThreads = 64; + Params.Enable16Bit = true; + runCopyConvert(D3DDevice, DxcSupport, Params, VerboseLogging, + /*Transpose=*/false); +} + +void DxilConf_SM610_LinAlg::CopyConvert_Wave_16x16_F16_Transpose() { + MatrixParams Params = {}; + Params.CompType = ComponentType::F16; + Params.M = 16; + Params.N = 16; + Params.Use = MatrixUse::A; + Params.Scope = MatrixScope::Wave; + Params.Layout = LinalgMatrixLayout::RowMajor; + Params.NumThreads = 64; + Params.Enable16Bit = true; + runCopyConvert(D3DDevice, DxcSupport, Params, VerboseLogging, + /*Transpose=*/true); +} + +static const char MatMatMulShader[] = R"( + #define USE_A 0 + #define USE_B 1 + #define USE_ACC 2 + + RWByteAddressBuffer Output : register(u0); + + [WaveSize(4, 64)] + [numthreads(NUMTHREADS, 1, 1)] + void main(uint threadID : SV_GroupIndex) { + if (WaveReadLaneFirst(threadID) != 0) + return; + + __builtin_LinAlgMatrix + [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, K_DIM, USE_A, SCOPE)]] + MatA; + __builtin_LinAlg_FillMatrix(MatA, A_FILL); + + __builtin_LinAlgMatrix + [[__LinAlgMatrix_Attributes(COMP_TYPE, K_DIM, N_DIM, USE_B, SCOPE)]] + MatB; + __builtin_LinAlg_FillMatrix(MatB, B_FILL); + + __builtin_LinAlgMatrix + [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE_ACC, SCOPE)]] + MatC; + __builtin_LinAlg_MatrixMatrixMultiply(MatC, MatA, MatB); + + __builtin_LinAlg_MatrixStoreToDescriptor( + MatC, Output, 0, STRIDE, LAYOUT, 128); + } +)"; + +static void runMatMatMul(ID3D12Device *Device, + dxc::SpecificDllLoader &DxcSupport, + const MatrixParams &Params, bool Verbose, MatrixDim K, + float AFill, float BFill) { + const size_t NumElements = Params.totalElements(); + const size_t BufferSize = Params.totalBytes(); + + std::stringstream ExtraDefs; + ExtraDefs << " -DK_DIM=" << K; + ExtraDefs << " -DA_FILL=" << AFill; + ExtraDefs << " -DB_FILL=" << BFill; + + std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str()); + + compileShader(DxcSupport, MatMatMulShader, "cs_6_10", Args, Verbose); + + auto Expected = makeExpectedMat(Params.CompType, Params.M, Params.N, + AFill * BFill * K, /*Increment=*/false); + + auto Op = + createComputeOp(MatMatMulShader, "cs_6_10", "UAV(u0)", Args.c_str()); + addUAVBuffer(Op.get(), "Output", BufferSize, true); + addRootUAV(Op.get(), 0, "Output"); + + auto Result = runShaderOp(Device, DxcSupport, std::move(Op)); + + MappedData OutData; + Result->Test->GetReadBackData("Output", &OutData); + + VERIFY_IS_TRUE(verifyComponentBuffer(Params.CompType, OutData.data(), + Expected, NumElements, Verbose)); +} + +void DxilConf_SM610_LinAlg::MatMatMul_Wave_16x16x16_F16() { + MatrixParams Params = {}; + Params.CompType = ComponentType::F16; + Params.M = 16; + Params.N = 16; + Params.Scope = MatrixScope::Wave; + Params.Layout = LinalgMatrixLayout::RowMajor; + Params.NumThreads = 64; + Params.Enable16Bit = true; + runMatMatMul(D3DDevice, DxcSupport, Params, VerboseLogging, /*K=*/16, + /*AFill=*/2.0f, /*BFill=*/3.0f); +} + +static const char MatMatMulAccumShader[] = R"( + #define USE_A 0 + #define USE_B 1 + #define USE_ACC 2 + + RWByteAddressBuffer Output : register(u0); + + [WaveSize(4, 64)] + [numthreads(NUMTHREADS, 1, 1)] + void main(uint threadID : SV_GroupIndex) { + if (WaveReadLaneFirst(threadID) != 0) + return; + + __builtin_LinAlgMatrix + [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, K_DIM, USE_A, SCOPE)]] + MatA; + __builtin_LinAlg_FillMatrix(MatA, A_FILL); + + __builtin_LinAlgMatrix + [[__LinAlgMatrix_Attributes(COMP_TYPE, K_DIM, N_DIM, USE_B, SCOPE)]] + MatB; + __builtin_LinAlg_FillMatrix(MatB, B_FILL); + + __builtin_LinAlgMatrix + [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE_ACC, SCOPE)]] + MatC; + __builtin_LinAlg_FillMatrix(MatC, C_FILL); + + __builtin_LinAlg_MatrixMatrixMultiplyAccumulate(MatC, MatA, MatB, MatC); + + __builtin_LinAlg_MatrixStoreToDescriptor( + MatC, Output, 0, STRIDE, LAYOUT, 128); + } +)"; + +static void runMatMatMulAccum(ID3D12Device *Device, + dxc::SpecificDllLoader &DxcSupport, + const MatrixParams &Params, bool Verbose, + MatrixDim K, float AFill, float BFill, + float CFill) { + const size_t NumElements = Params.totalElements(); + const size_t BufferSize = Params.totalBytes(); + + std::stringstream ExtraDefs; + ExtraDefs << " -DK_DIM=" << K; + ExtraDefs << " -DA_FILL=" << AFill; + ExtraDefs << " -DB_FILL=" << BFill; + ExtraDefs << " -DC_FILL=" << CFill; + + std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str()); + + compileShader(DxcSupport, MatMatMulAccumShader, "cs_6_10", Args, Verbose); + + auto Expected = makeExpectedMat(Params.CompType, Params.M, Params.N, + AFill * BFill * K + CFill, /*Increment=*/false); + + auto Op = + createComputeOp(MatMatMulAccumShader, "cs_6_10", "UAV(u0)", Args.c_str()); + addUAVBuffer(Op.get(), "Output", BufferSize, true); + addRootUAV(Op.get(), 0, "Output"); + + auto Result = runShaderOp(Device, DxcSupport, std::move(Op)); + + MappedData OutData; + Result->Test->GetReadBackData("Output", &OutData); + + VERIFY_IS_TRUE(verifyComponentBuffer(Params.CompType, OutData.data(), + Expected, NumElements, Verbose)); +} + +void DxilConf_SM610_LinAlg::MatMatMulAccum_Wave_16x16x16_F16() { + MatrixParams Params = {}; + Params.CompType = ComponentType::F16; + Params.M = 16; + Params.N = 16; + Params.Scope = MatrixScope::Wave; + Params.Layout = LinalgMatrixLayout::RowMajor; + Params.NumThreads = 64; + Params.Enable16Bit = true; + runMatMatMulAccum(D3DDevice, DxcSupport, Params, VerboseLogging, /*K=*/16, + /*AFill=*/2.0f, /*BFill=*/3.0f, /*CFill=*/4.0f); +} + +static const char MatAccumShader[] = R"( + #define USE_A 0 + #define USE_ACC 2 + + RWByteAddressBuffer Output : register(u0); + + [WaveSize(4, 64)] + [numthreads(NUMTHREADS, 1, 1)] + void main(uint threadID : SV_GroupIndex) { + if (WaveReadLaneFirst(threadID) != 0) + return; + + __builtin_LinAlgMatrix + [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE_ACC, SCOPE)]] + MatLHS; + __builtin_LinAlg_FillMatrix(MatLHS, LHS_FILL); + + __builtin_LinAlgMatrix + [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE_A, SCOPE)]] + MatRHS; + __builtin_LinAlg_FillMatrix(MatRHS, RHS_FILL); + + __builtin_LinAlg_MatrixAccumulate(MatLHS, MatLHS, MatRHS); + + __builtin_LinAlg_MatrixStoreToDescriptor( + MatLHS, Output, 0, STRIDE, LAYOUT, 128); + } +)"; + +static void runMatAccum(ID3D12Device *Device, + dxc::SpecificDllLoader &DxcSupport, + const MatrixParams &Params, bool Verbose, float LHSFill, + float RHSFill) { + const size_t NumElements = Params.totalElements(); + const size_t BufferSize = Params.totalBytes(); + + std::stringstream ExtraDefs; + ExtraDefs << " -DLHS_FILL=" << LHSFill; + ExtraDefs << " -DRHS_FILL=" << RHSFill; + + std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str()); + + compileShader(DxcSupport, MatAccumShader, "cs_6_10", Args, Verbose); + + auto Expected = makeExpectedMat(Params.CompType, Params.M, Params.N, + LHSFill + RHSFill, /*Increment=*/false); + + auto Op = createComputeOp(MatAccumShader, "cs_6_10", "UAV(u0)", Args.c_str()); + addUAVBuffer(Op.get(), "Output", BufferSize, true); + addRootUAV(Op.get(), 0, "Output"); + + auto Result = runShaderOp(Device, DxcSupport, std::move(Op)); + + MappedData OutData; + Result->Test->GetReadBackData("Output", &OutData); + + VERIFY_IS_TRUE(verifyComponentBuffer(Params.CompType, OutData.data(), + Expected, NumElements, Verbose)); +} + +void DxilConf_SM610_LinAlg::MatAccum_Wave_16x16_F16() { + MatrixParams Params = {}; + Params.CompType = ComponentType::F16; + Params.M = 16; + Params.N = 16; + Params.Scope = MatrixScope::Wave; + Params.Layout = LinalgMatrixLayout::RowMajor; + Params.NumThreads = 64; + Params.Enable16Bit = true; + runMatAccum(D3DDevice, DxcSupport, Params, VerboseLogging, + /*LHSFill=*/2.0f, /*RHSFill=*/3.0f); +} + +static const char MatVecMulShader[] = R"( + #define USE_A 0 + #define SCOPE_THREAD 0 + + RWByteAddressBuffer Input : register(u0); + RWByteAddressBuffer Output : register(u1); + + [numthreads(NUMTHREADS, 1, 1)] + void main() { + __builtin_LinAlgMatrix + [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE_A, SCOPE_THREAD)]] + Mat; + __builtin_LinAlg_FillMatrix(Mat, MAT_FILL); + + vector InVec; + for (uint I = 0; I < M_DIM; ++I) { + InVec[I] = Input.Load(I * ELEM_SIZE); + } + + vector OutVec; + __builtin_LinAlg_MatrixVectorMultiply( + OutVec, Mat, OUTPUT_SIGNED, InVec, IN_INTERP); + + for (uint I = 0; I < M_DIM; ++I) { + Output.Store(I * ELEM_SIZE, OutVec[I]); + } + } +)"; + +static void runMatVecMul(ID3D12Device *Device, + dxc::SpecificDllLoader &DxcSupport, + const MatrixParams &Params, bool Verbose, + float MatFill, bool OutputSigned, + ComponentType InputInterp) { + const size_t NumElements = Params.M; + const size_t BufferSize = elementSize(Params.CompType) * NumElements; + + std::stringstream ExtraDefs; + ExtraDefs << " -DMAT_FILL=" << MatFill; + ExtraDefs << " -DOUTPUT_SIGNED=" << OutputSigned; + ExtraDefs << " -DIN_INTERP=" << static_cast(InputInterp); + + std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str()); + + compileShader(DxcSupport, MatVecMulShader, "cs_6_10", Args, Verbose); + + auto Expected = makeExpectedVec(Params.CompType, Params.M, MatFill * Params.N, + /*Increment=*/false); + + auto Op = createComputeOp(MatVecMulShader, "cs_6_10", "UAV(u0), UAV(u1)", + Args.c_str()); + addUAVBuffer(Op.get(), "Input", BufferSize, false, "byname"); + addUAVBuffer(Op.get(), "Output", BufferSize, true); + addRootUAV(Op.get(), 0, "Input"); + addRootUAV(Op.get(), 1, "Output"); + + auto Result = runShaderOp( + Device, DxcSupport, std::move(Op), + [NumElements, Params](LPCSTR Name, std::vector &Data, + st::ShaderOp *) { + VERIFY_IS_TRUE(fillInputBuffer(Name, Data, Params.CompType, NumElements, + /*StartingVal=*/1, /*Increment=*/false), + "Saw unsupported component type"); + }); + + MappedData OutData; + Result->Test->GetReadBackData("Output", &OutData); + + VERIFY_IS_TRUE(verifyComponentBuffer(Params.CompType, OutData.data(), + Expected, NumElements, Verbose)); +} + +void DxilConf_SM610_LinAlg::MatVecMul_Thread_16x16_F16() { + MatrixParams Params = {}; + Params.CompType = ComponentType::F16; + Params.M = 16; + Params.N = 16; + Params.Scope = MatrixScope::Thread; + Params.Layout = LinalgMatrixLayout::RowMajor; + Params.NumThreads = 1; + Params.Enable16Bit = true; + runMatVecMul(D3DDevice, DxcSupport, Params, VerboseLogging, + /*MatFill=*/2.0f, /*OutputSigned=*/true, ComponentType::F16); +} + +static const char MatVecMulAddShader[] = R"( + #define USE_A 0 + #define SCOPE_THREAD 0 + + RWByteAddressBuffer Input : register(u0); + RWByteAddressBuffer Output : register(u1); + + [numthreads(NUMTHREADS, 1, 1)] + void main() { + __builtin_LinAlgMatrix + [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE_A, SCOPE_THREAD)]] + Mat; + __builtin_LinAlg_FillMatrix(Mat, MAT_FILL); + + vector InVec; + for (uint I = 0; I < M_DIM; ++I) { + InVec[I] = Input.Load(I * ELEM_SIZE); + } + + // TODO: this is just copying InVec but it should be a unique value + vector BiasVec; + for (uint I = 0; I < M_DIM; ++I) { + BiasVec[I] = Input.Load(I * ELEM_SIZE); + } + + vector OutVec; + __builtin_LinAlg_MatrixVectorMultiplyAdd( + OutVec, Mat, OUTPUT_SIGNED, InVec, IN_INTERP, BiasVec, BIAS_INTERP); + + for (uint I = 0; I < M_DIM; ++I) { + Output.Store(I * ELEM_SIZE, OutVec[I]); + } + } +)"; + +static void runMatVecMulAdd(ID3D12Device *Device, + dxc::SpecificDllLoader &DxcSupport, + const MatrixParams &Params, bool Verbose, + float MatFill, bool OutputSigned, + ComponentType InputInterp, + ComponentType BiasInterp) { + const size_t NumElements = Params.M; + const size_t BufferSize = elementSize(Params.CompType) * NumElements; + + std::stringstream ExtraDefs; + ExtraDefs << " -DMAT_FILL=" << MatFill; + ExtraDefs << " -DOUTPUT_SIGNED=" << OutputSigned; + ExtraDefs << " -DIN_INTERP=" << static_cast(InputInterp); + ExtraDefs << " -DBIAS_INTERP=" << static_cast(BiasInterp); + + std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str()); + + compileShader(DxcSupport, MatVecMulAddShader, "cs_6_10", Args, Verbose); + + auto Expected = makeExpectedVec(Params.CompType, Params.M, + MatFill * Params.N + 1, /*Increment=*/false); + + auto Op = createComputeOp(MatVecMulAddShader, "cs_6_10", "UAV(u0), UAV(u1)", + Args.c_str()); + addUAVBuffer(Op.get(), "Input", BufferSize, false, "byname"); + addUAVBuffer(Op.get(), "Output", BufferSize, true); + addRootUAV(Op.get(), 0, "Input"); + addRootUAV(Op.get(), 1, "Output"); + + auto Result = runShaderOp( + Device, DxcSupport, std::move(Op), + [NumElements, Params](LPCSTR Name, std::vector &Data, + st::ShaderOp *) { + VERIFY_IS_TRUE(fillInputBuffer(Name, Data, Params.CompType, NumElements, + /*StartingVal=*/1, /*Increment=*/false), + "Saw unsupported component type"); + }); + + MappedData OutData; + Result->Test->GetReadBackData("Output", &OutData); + + VERIFY_IS_TRUE(verifyComponentBuffer(Params.CompType, OutData.data(), + Expected, NumElements, Verbose)); +} + +void DxilConf_SM610_LinAlg::MatVecMulAdd_Thread_16x16_F16() { + MatrixParams Params = {}; + Params.CompType = ComponentType::F16; + Params.M = 16; + Params.N = 16; + Params.Scope = MatrixScope::Thread; + Params.Layout = LinalgMatrixLayout::RowMajor; + Params.NumThreads = 1; + Params.Enable16Bit = true; + runMatVecMulAdd(D3DDevice, DxcSupport, Params, VerboseLogging, + /*MatFill=*/2.0f, /*OutputSigned=*/true, ComponentType::F16, + ComponentType::F16); +} + +static const char OuterProductShader[] = R"( + #define USE_A 0 + #define SCOPE_THREAD 0 + + RWByteAddressBuffer Input : register(u0); + RWByteAddressBuffer Output : register(u1); + + [numthreads(NUMTHREADS, 1, 1)] + void main() { + vector VecA; + for (uint I = 0; I < M_DIM; ++I) { + VecA[I] = Input.Load(I * ELEM_SIZE); + } + + uint EndVecA = M_DIM * ELEM_SIZE; + + vector VecB; + for (uint I = 0; I < N_DIM; ++I) { + VecB[I] = Input.Load(EndVecA + I * ELEM_SIZE); + } + + __builtin_LinAlgMatrix + [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE_A, SCOPE_THREAD)]] + Mat; + __builtin_LinAlg_MatrixOuterProduct(Mat, VecA, VecB); + + __builtin_LinAlg_MatrixAccumulateToDescriptor( + Mat, Output, 0, STRIDE, LAYOUT, 128); + } +)"; + +static void runOuterProduct(ID3D12Device *Device, + dxc::SpecificDllLoader &DxcSupport, + const MatrixParams &Params, bool Verbose) { + const size_t NumVecElements = Params.M + Params.N; + const size_t InBuffSize = NumVecElements * elementSize(Params.CompType); + const size_t NumMatElements = Params.totalElements(); + const size_t OutBufferSize = Params.totalBytes(); + + std::string Args = buildCompilerArgs(Params); + + compileShader(DxcSupport, OuterProductShader, "cs_6_10", Args, Verbose); + + auto Expected = makeExpectedMat(Params.CompType, Params.M, Params.N, + 4, /*Increment=*/false); + + auto Op = createComputeOp(OuterProductShader, "cs_6_10", "UAV(u0), UAV(u1)", + Args.c_str()); + addUAVBuffer(Op.get(), "Input", InBuffSize, false, "byname"); + addUAVBuffer(Op.get(), "Output", OutBufferSize, true); + addRootUAV(Op.get(), 0, "Input"); + addRootUAV(Op.get(), 1, "Output"); + + auto Result = runShaderOp( + Device, DxcSupport, std::move(Op), + [NumVecElements, Params](LPCSTR Name, std::vector &Data, + st::ShaderOp *) { + VERIFY_IS_TRUE(fillInputBuffer(Name, Data, Params.CompType, NumVecElements, + /*StartingVal=*/2, /*Increment=*/false), + "Saw unsupported component type"); + }); + + MappedData OutData; + Result->Test->GetReadBackData("Output", &OutData); + + VERIFY_IS_TRUE(verifyComponentBuffer(Params.CompType, OutData.data(), + Expected, NumMatElements, Verbose)); +} + +void DxilConf_SM610_LinAlg::OuterProduct_Thread_16x16_F16() { + MatrixParams Params = {}; + Params.CompType = ComponentType::F16; + Params.M = 16; + Params.N = 16; + Params.Scope = MatrixScope::Thread; + Params.Layout = LinalgMatrixLayout::RowMajor; + Params.NumThreads = 1; + Params.Enable16Bit = true; + runOuterProduct(D3DDevice, DxcSupport, Params, VerboseLogging); +} + +static const char QueryAccumLayoutShader[] = R"( + RWByteAddressBuffer Output : register(u0); + + [numthreads(1, 1, 1)] + void main() { + uint Layout = __builtin_LinAlg_MatrixQueryAccumulatorLayout(); + Output.Store(0, Layout); + } +)"; + +static void runQueryAccumLayout(ID3D12Device *Device, + dxc::SpecificDllLoader &DxcSupport, + bool Verbose) { + std::string Args = "-HV 202x"; + size_t BufferSize = elementSize(ComponentType::I32); + + compileShader(DxcSupport, QueryAccumLayoutShader, "cs_6_10", Args, Verbose); + + auto Op = + createComputeOp(QueryAccumLayoutShader, "cs_6_10", "UAV(u0)", Args.c_str()); + addUAVBuffer(Op.get(), "Output", BufferSize, true); + addRootUAV(Op.get(), 0, "Output"); + + auto Result = runShaderOp(Device, DxcSupport, std::move(Op)); + + MappedData OutData; + Result->Test->GetReadBackData("Output", &OutData); + const uint32_t *Out = static_cast(OutData.data()); + + // Accum Layout must be A or B + VERIFY_IS_TRUE(Out[0] == static_cast(MatrixUse::A) || Out[0] == static_cast(MatrixUse::B)); + if (Verbose) + hlsl_test::LogCommentFmt(L"AccumulatorLayout = %u", Out[0]); +} + +void DxilConf_SM610_LinAlg::QueryAccumLayout() { + runQueryAccumLayout(D3DDevice, DxcSupport, VerboseLogging); +} + } // namespace LinAlg