diff --git a/include/dxc/Support/HLSLOptions.h b/include/dxc/Support/HLSLOptions.h index 31ca3d1c14..ac1b311e1f 100644 --- a/include/dxc/Support/HLSLOptions.h +++ b/include/dxc/Support/HLSLOptions.h @@ -103,6 +103,7 @@ class DxcDefines { struct RewriterOpts { bool Unchanged = false; // OPT_rw_unchanged + bool ConsistentBindings = false; // OPT_rw_consistent_bindings bool SkipFunctionBody = false; // OPT_rw_skip_function_body bool SkipStatic = false; // OPT_rw_skip_static bool GlobalExternByDefault = false; // OPT_rw_global_extern_by_default diff --git a/include/dxc/Support/HLSLOptions.td b/include/dxc/Support/HLSLOptions.td index 58f6bdfbf3..fb597b8460 100644 --- a/include/dxc/Support/HLSLOptions.td +++ b/include/dxc/Support/HLSLOptions.td @@ -563,6 +563,8 @@ def nologo : Flag<["-", "/"], "nologo">, Group, Flags<[DriverOpt def rw_unchanged : Flag<["-", "/"], "unchanged">, Group, Flags<[RewriteOption]>, HelpText<"Rewrite HLSL, without changes.">; +def rw_consistent_bindings : Flag<["-", "/"], "consistent-bindings">, Group, Flags<[RewriteOption]>, + HelpText<"Generate bindings for registers that aren't fully qualified (to have consistent bindings).">; def rw_skip_function_body : Flag<["-", "/"], "skip-fn-body">, Group, Flags<[RewriteOption]>, HelpText<"Translate function definitions to declarations">; def rw_skip_static : Flag<["-", "/"], "skip-static">, Group, Flags<[RewriteOption]>, diff --git a/lib/DxcSupport/HLSLOptions.cpp b/lib/DxcSupport/HLSLOptions.cpp index eb071eb0a6..c733b2a633 100644 --- a/lib/DxcSupport/HLSLOptions.cpp +++ b/lib/DxcSupport/HLSLOptions.cpp @@ -1349,6 +1349,8 @@ int ReadDxcOpts(const OptTable *optionTable, unsigned flagsToInclude, // Rewriter Options if (flagsToInclude & hlsl::options::RewriteOption) { opts.RWOpt.Unchanged = Args.hasFlag(OPT_rw_unchanged, OPT_INVALID, false); + opts.RWOpt.ConsistentBindings = + Args.hasFlag(OPT_rw_consistent_bindings, OPT_INVALID, false); opts.RWOpt.SkipFunctionBody = Args.hasFlag(OPT_rw_skip_function_body, OPT_INVALID, false); opts.RWOpt.SkipStatic = diff --git a/tools/clang/include/clang/AST/HlslTypes.h b/tools/clang/include/clang/AST/HlslTypes.h index 43c1effdb8..2caa408b08 100644 --- a/tools/clang/include/clang/AST/HlslTypes.h +++ b/tools/clang/include/clang/AST/HlslTypes.h @@ -497,6 +497,7 @@ bool IsHLSLNumericOrAggregateOfNumericType(clang::QualType type); bool IsHLSLCopyableAnnotatableRecord(clang::QualType QT); bool IsHLSLBuiltinRayAttributeStruct(clang::QualType QT); bool IsHLSLAggregateType(clang::QualType type); +hlsl::DXIL::ResourceClass GetHLSLResourceClass(clang::QualType type); clang::QualType GetHLSLResourceResultType(clang::QualType type); clang::QualType GetHLSLNodeIOResultType(clang::ASTContext &astContext, clang::QualType type); diff --git a/tools/clang/lib/AST/HlslTypes.cpp b/tools/clang/lib/AST/HlslTypes.cpp index 00c18a81a9..9748e13069 100644 --- a/tools/clang/lib/AST/HlslTypes.cpp +++ b/tools/clang/lib/AST/HlslTypes.cpp @@ -524,6 +524,14 @@ bool IsHLSLResourceType(clang::QualType type) { return false; } +hlsl::DXIL::ResourceClass GetHLSLResourceClass(clang::QualType type) { + + if (HLSLResourceAttr *attr = getAttr(type)) + return attr->getResClass(); + + return hlsl::DXIL::ResourceClass::Invalid; +} + bool IsHLSLHitObjectType(QualType type) { return nullptr != getAttr(type); } diff --git a/tools/clang/test/HLSL/rewriter/consistent_bindings.hlsl b/tools/clang/test/HLSL/rewriter/consistent_bindings.hlsl new file mode 100644 index 0000000000..e12eed0789 --- /dev/null +++ b/tools/clang/test/HLSL/rewriter/consistent_bindings.hlsl @@ -0,0 +1,48 @@ +//UAV + +RWByteAddressBuffer output1; +RWByteAddressBuffer output2; +RWByteAddressBuffer output3 : register(u0); +RWByteAddressBuffer output4 : register(space1); +RWByteAddressBuffer output5 : SEMA; +RWByteAddressBuffer output6; +RWByteAddressBuffer output7 : register(u1); +RWByteAddressBuffer output8[12] : register(u3); +RWByteAddressBuffer output9[12]; +RWByteAddressBuffer output10[33] : register(space1); +RWByteAddressBuffer output11[33] : register(space2); +RWByteAddressBuffer output12[33] : register(u0, space2); + +//SRV + +StructuredBuffer test; +ByteAddressBuffer input13 : SEMA; +ByteAddressBuffer input14; +ByteAddressBuffer input15 : register(t0); +ByteAddressBuffer input16[12] : register(t3); +ByteAddressBuffer input17[2] : register(space1); +ByteAddressBuffer input18[12] : register(t1, space1); +ByteAddressBuffer input19[3] : register(space1); +ByteAddressBuffer input20 : register(space1); + +//Sampler + +SamplerState sampler0; +SamplerState sampler1; +SamplerState sampler2 : register(s0); +SamplerState sampler3 : register(space1); +SamplerState sampler4 : register(s0, space1); + +//CBV + +cbuffer test : register(b0) { float a; }; +cbuffer test2 { float b; }; +cbuffer test3 : register(space1) { float c; }; +cbuffer test4 : register(space1) { float d; }; + +float e; + +[numthreads(16, 16, 1)] +void main(uint id : SV_DispatchThreadID) { + output2.Store(id * 4, 1); //Only use 1 output, but this won't result into output2 receiving wrong bindings +} diff --git a/tools/clang/test/HLSL/rewriter/correct_rewrites/consistent_bindings_gold.hlsl b/tools/clang/test/HLSL/rewriter/correct_rewrites/consistent_bindings_gold.hlsl new file mode 100644 index 0000000000..4c2810e9f4 --- /dev/null +++ b/tools/clang/test/HLSL/rewriter/correct_rewrites/consistent_bindings_gold.hlsl @@ -0,0 +1,49 @@ +RWByteAddressBuffer output1 : register(u2); +RWByteAddressBuffer output2 : register(u15); +RWByteAddressBuffer output3 : register(u0); +RWByteAddressBuffer output4 : register(u0, space1); +RWByteAddressBuffer output5 : SEMA : register(u16); +RWByteAddressBuffer output6 : register(u17); +RWByteAddressBuffer output7 : register(u1); +RWByteAddressBuffer output8[12] : register(u3); +RWByteAddressBuffer output9[12] : register(u18); +RWByteAddressBuffer output10[33] : register(u1, space1); +RWByteAddressBuffer output11[33] : register(u33, space2); +RWByteAddressBuffer output12[33] : register(u0, space2); +StructuredBuffer test : register(t1); +ByteAddressBuffer input13 : SEMA : register(t2); +ByteAddressBuffer input14 : register(t15); +ByteAddressBuffer input15 : register(t0); +ByteAddressBuffer input16[12] : register(t3); +ByteAddressBuffer input17[2] : register(t13, space1); +ByteAddressBuffer input18[12] : register(t1, space1); +ByteAddressBuffer input19[3] : register(t15, space1); +ByteAddressBuffer input20 : register(t0, space1); +SamplerState sampler0 : register(s1); +SamplerState sampler1 : register(s2); +SamplerState sampler2 : register(s0); +SamplerState sampler3 : register(s1, space1); +SamplerState sampler4 : register(s0, space1); +cbuffer test : register(b0) { + const float a; +} +; +cbuffer test2 : register(b1) { + const float b; +} +; +cbuffer test3 : register(b0, space1) { + const float c; +} +; +cbuffer test4 : register(b1, space1) { + const float d; +} +; +const float e; +[numthreads(16, 16, 1)] +void main(uint id : SV_DispatchThreadID) { + output2.Store(id * 4, 1); +} + + diff --git a/tools/clang/tools/libclang/dxcrewriteunused.cpp b/tools/clang/tools/libclang/dxcrewriteunused.cpp index c29854077b..1c1699196c 100644 --- a/tools/clang/tools/libclang/dxcrewriteunused.cpp +++ b/tools/clang/tools/libclang/dxcrewriteunused.cpp @@ -11,6 +11,7 @@ #include "clang/AST/ASTConsumer.h" #include "clang/AST/ASTContext.h" +#include "clang/AST/HlslTypes.h" #include "clang/AST/RecursiveASTVisitor.h" #include "clang/Basic/Diagnostic.h" #include "clang/Basic/FileManager.h" @@ -1032,6 +1033,247 @@ static void RemoveStaticDecls(DeclContext &Ctx) { } } +struct ResourceKey { + + uint32_t space; + DXIL::ResourceClass resourceClass; + + bool operator==(const ResourceKey &other) const { + return space == other.space && resourceClass == other.resourceClass; + } +}; + +namespace llvm { +template <> struct DenseMapInfo { + static inline ResourceKey getEmptyKey() { + return {~0u, DXIL::ResourceClass::Invalid}; + } + static inline ResourceKey getTombstoneKey() { + return {~0u - 1, DXIL::ResourceClass::Invalid}; + } + static unsigned getHashValue(const ResourceKey &K) { + return llvm::hash_combine(K.space, uint32_t(K.resourceClass)); + } + static bool isEqual(const ResourceKey &LHS, const ResourceKey &RHS) { + return LHS.space == RHS.space && LHS.resourceClass == RHS.resourceClass; + } +}; +} // namespace llvm + +using RegisterRange = std::pair; //(startReg, count) +using RegisterMap = + llvm::DenseMap>; + +struct UnresolvedRegister { + hlsl::DXIL::ResourceClass cls; + uint32_t arraySize; + RegisterAssignment *reg; + NamedDecl *ND; +}; + +using UnresolvedRegisters = llvm::SmallVector; + +// Find gap in register list and fill it + +uint32_t FillNextRegister(llvm::SmallVector &ranges, + uint32_t arraySize) { + + if (ranges.empty()) { + ranges.push_back({0, arraySize}); + return 0; + } + + size_t i = 0, j = ranges.size(); + size_t curr = 0; + + for (; i < j; ++i) { + + const RegisterRange &range = ranges[i]; + + if (range.first - curr >= arraySize) { + ranges.insert(ranges.begin() + i, RegisterRange{curr, arraySize}); + return curr; + } + + curr = range.first + range.second; + } + + ranges.emplace_back(RegisterRange{curr, arraySize}); + return curr; +} + +// Insert in the right place (keep sorted) + +void FillRegisterAt(llvm::SmallVector &ranges, + uint32_t registerNr, uint32_t arraySize, + clang::DiagnosticsEngine &diags, + const SourceLocation &location) { + + size_t i = 0, j = ranges.size(); + + for (; i < j; ++i) { + + const RegisterRange &range = ranges[i]; + + if (range.first > registerNr) { + + if (registerNr + arraySize > range.first) { + diags.Report(location, diag::err_hlsl_register_semantics_conflicting); + return; + } + + ranges.insert(ranges.begin() + i, RegisterRange{registerNr, arraySize}); + break; + } + + if (range.first + range.second > registerNr) { + diags.Report(location, diag::err_hlsl_register_semantics_conflicting); + return; + } + } + + if (i == j) + ranges.emplace_back(RegisterRange{registerNr, arraySize}); +} + +static void RegisterBinding(NamedDecl *ND, + UnresolvedRegisters &unresolvedRegisters, + RegisterMap &map, hlsl::DXIL::ResourceClass cls, + uint32_t arraySize, clang::DiagnosticsEngine &Diags, + uint32_t autoBindingSpace) { + + const ArrayRef &UA = ND->getUnusualAnnotations(); + + bool qualified = false; + RegisterAssignment *reg = nullptr; + + for (auto It = UA.begin(), E = UA.end(); It != E; ++It) { + + if ((*It)->getKind() != hlsl::UnusualAnnotation::UA_RegisterAssignment) + continue; + + reg = cast(*It); + + if (!reg->RegisterType) // Unqualified register assignment + break; + + uint32_t space = reg->RegisterSpace.hasValue() + ? reg->RegisterSpace.getValue() + : autoBindingSpace; + + qualified = true; + FillRegisterAt(map[ResourceKey{space, cls}], reg->RegisterNumber, arraySize, + Diags, ND->getLocation()); + break; + } + + if (!qualified) + unresolvedRegisters.emplace_back( + UnresolvedRegister{cls, arraySize, reg, ND}); +} + +static void GenerateConsistentBindings(DeclContext &Ctx, + uint32_t autoBindingSpace) { + + clang::DiagnosticsEngine &Diags = Ctx.getParentASTContext().getDiagnostics(); + + RegisterMap map; + UnresolvedRegisters unresolvedRegisters; + + // Fill up map with fully qualified registers to avoid colliding with them + // later + + for (auto it = Ctx.decls_begin(); it != Ctx.decls_end(); ++it) { + + // CBuffer has special logic, since it's not technically + + if (HLSLBufferDecl *CBuffer = dyn_cast(*it)) { + RegisterBinding(CBuffer, unresolvedRegisters, map, + hlsl::DXIL::ResourceClass::CBuffer, 1, Diags, + autoBindingSpace); + continue; + } + + ValueDecl *VD = dyn_cast(*it); + + if (!VD) + continue; + + std::string test = VD->getName(); + + uint32_t arraySize = 1; + QualType type = VD->getType(); + + while (const ConstantArrayType *arr = dyn_cast(type)) { + arraySize *= arr->getSize().getZExtValue(); + type = arr->getElementType(); + } + + if (!IsHLSLResourceType(type)) + continue; + + RegisterBinding(VD, unresolvedRegisters, map, GetHLSLResourceClass(type), + arraySize, Diags, autoBindingSpace); + } + + // Resolve unresolved registers (while avoiding collisions) + + for (const UnresolvedRegister ® : unresolvedRegisters) { + + uint32_t arraySize = reg.arraySize; + hlsl::DXIL::ResourceClass resClass = reg.cls; + + char prefix = 't'; + + switch (resClass) { + + case DXIL::ResourceClass::Sampler: + prefix = 's'; + break; + + case DXIL::ResourceClass::CBuffer: + prefix = 'b'; + break; + + case DXIL::ResourceClass::UAV: + prefix = 'u'; + break; + } + + uint32_t space = + reg.reg ? reg.reg->RegisterSpace.getValue() : autoBindingSpace; + + uint32_t registerNr = + FillNextRegister(map[ResourceKey{space, resClass}], arraySize); + + if (reg.reg) { + reg.reg->RegisterType = prefix; + reg.reg->RegisterNumber = registerNr; + reg.reg->setIsValid(true); + } else { + hlsl::RegisterAssignment + r; // Keep space empty to ensure space overrides still work fine + r.RegisterNumber = registerNr; + r.RegisterType = prefix; + r.setIsValid(true); + + llvm::SmallVector annotations; + + const ArrayRef &UA = + reg.ND->getUnusualAnnotations(); + + for (auto It = UA.begin(), E = UA.end(); It != E; ++It) + annotations.emplace_back(*It); + + annotations.push_back(::new (Ctx.getParentASTContext()) + hlsl::RegisterAssignment(r)); + + reg.ND->setUnusualAnnotations(UnusualAnnotation::CopyToASTContextArray( + Ctx.getParentASTContext(), annotations.data(), annotations.size())); + } + } +} + static void GlobalVariableAsExternByDefault(DeclContext &Ctx) { for (auto it = Ctx.decls_begin(); it != Ctx.decls_end();) { auto cur = it++; @@ -1065,6 +1307,10 @@ static HRESULT DoSimpleReWrite(DxcLangExtensionsHelper *pHelper, TranslationUnitDecl *tu = astHelper.tu; + if (opts.RWOpt.ConsistentBindings) { + GenerateConsistentBindings(*tu, opts.AutoBindingSpace); + } + if (opts.RWOpt.SkipStatic && opts.RWOpt.SkipFunctionBody) { // Remove static functions and globals. RemoveStaticDecls(*tu); @@ -1083,7 +1329,7 @@ static HRESULT DoSimpleReWrite(DxcLangExtensionsHelper *pHelper, opts.RWOpt.RemoveUnusedFunctions, w); if (FAILED(hr)) return hr; - } else { + } else if (!opts.RWOpt.ConsistentBindings) { o << "// Rewrite unchanged result:\n"; } diff --git a/tools/clang/unittests/HLSL/RewriterTest.cpp b/tools/clang/unittests/HLSL/RewriterTest.cpp index 613c8561a3..7adfd3f88e 100644 --- a/tools/clang/unittests/HLSL/RewriterTest.cpp +++ b/tools/clang/unittests/HLSL/RewriterTest.cpp @@ -102,6 +102,7 @@ class RewriterTest : public ::testing::Test { TEST_METHOD(RunExtractUniforms) TEST_METHOD(RunGlobalsUsedInMethod) TEST_METHOD(RunRewriterFails) + TEST_METHOD(GenerateConsistentBindings) dxc::DxcDllSupport m_dllSupport; CComPtr m_pIncludeHandler; @@ -126,16 +127,19 @@ class RewriterTest : public ::testing::Test { ppBlob)); } - VerifyResult CheckVerifies(LPCWSTR path, LPCWSTR goldPath) { + VerifyResult CheckVerifies(LPCWSTR path, LPCWSTR goldPath, + const llvm::SmallVector &args = { + L"-HV", L"2016"}) { CComPtr pRewriter; VERIFY_SUCCEEDED(CreateRewriter(&pRewriter)); - return CheckVerifies(pRewriter, path, goldPath); + return CheckVerifies(pRewriter, path, goldPath, args); } - VerifyResult CheckVerifies(IDxcRewriter *pRewriter, LPCWSTR path, - LPCWSTR goldPath) { + VerifyResult + CheckVerifies(IDxcRewriter *pRewriter, LPCWSTR path, LPCWSTR goldPath, + const llvm::SmallVector &args = {L"-HV", L"2016"}) { CComPtr pRewriteResult; - RewriteCompareGold(path, goldPath, &pRewriteResult, pRewriter); + RewriteCompareGold(path, goldPath, &pRewriteResult, pRewriter, args); VerifyResult toReturn; @@ -165,9 +169,11 @@ class RewriterTest : public ::testing::Test { return S_OK; } - VerifyResult CheckVerifiesHLSL(LPCWSTR name, LPCWSTR goldName) { + VerifyResult CheckVerifiesHLSL(LPCWSTR name, LPCWSTR goldName, + const llvm::SmallVector &args = { + L"-HV", L"2016"}) { return CheckVerifies(GetPathToHlslDataFile(name).c_str(), - GetPathToHlslDataFile(goldName).c_str()); + GetPathToHlslDataFile(goldName).c_str(), args); } struct FileWithBlob { @@ -210,7 +216,8 @@ class RewriterTest : public ::testing::Test { void RewriteCompareGold(LPCWSTR path, LPCWSTR goldPath, IDxcOperationResult **ppResult, - IDxcRewriter *rewriter) { + IDxcRewriter *rewriter, + const llvm::SmallVector &args = {}) { // Get the source text from a file FileWithBlob source(m_dllSupport, path); @@ -218,14 +225,12 @@ class RewriterTest : public ::testing::Test { DxcDefine myDefines[myDefinesCount] = { {L"myDefine", L"2"}, {L"myDefine3", L"1994"}, {L"myDefine4", nullptr}}; - LPCWSTR args[] = {L"-HV", L"2016"}; - CComPtr rewriter2; VERIFY_SUCCEEDED(rewriter->QueryInterface(&rewriter2)); // Run rewrite unchanged on the source code VERIFY_SUCCEEDED(rewriter2->RewriteWithOptions( - source.BlobEncoding, path, args, _countof(args), myDefines, - myDefinesCount, nullptr, ppResult)); + source.BlobEncoding, path, (LPCWSTR *)args.data(), + (uint32_t)args.size(), myDefines, myDefinesCount, nullptr, ppResult)); // check for compilation errors HRESULT hrStatus; @@ -462,6 +467,13 @@ TEST_F(RewriterTest, RunSpirv) { VERIFY_IS_TRUE(strResult.find("namespace vk") == std::string::npos); } +TEST_F(RewriterTest, GenerateConsistentBindings) { + CheckVerifiesHLSL( + L"rewriter\\consistent_bindings.hlsl", + L"rewriter\\correct_rewrites\\consistent_bindings_gold.hlsl", + {L"-HV", L"2016", L"-consistent-bindings"}); +} + TEST_F(RewriterTest, RunStructMethods) { CheckVerifiesHLSL(L"rewriter\\struct-methods.hlsl", L"rewriter\\correct_rewrites\\struct-methods_gold.hlsl");