Skip to content

Commit 1d5bd28

Browse files
authored
[Matrix][Clang][HLSL] Move MaxMatrixDimension to a LangOpt (#163307)
fixes #160190 fixes #116710 This change just makes MaxMatrixDimension configurable by language mode. It was previously introduced in 94b4311 when there was not a need to make dimensions configurable. Current testing to this effect exists in: - clang/test/Sema/matrix-type-builtins.c - clang/test/SemaCXX/matrix-type-builtins.cpp - clang/test/SemaHLSL/BuiltIns/matrix-basic_types-errors.hlsl New Tests to confirm configurability by language mode: - clang/unittests/Frontend/CompilerInvocationTest.cpp I considered adding a driver flag to `clang/include/clang/Driver/Options.td` but HLSL matrix max dim is always 4 so we don't need this configurable beyond that size for our use case.
1 parent 22a2a82 commit 1d5bd28

File tree

10 files changed

+65
-21
lines changed

10 files changed

+65
-21
lines changed

clang/include/clang/AST/TypeBase.h

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4378,8 +4378,6 @@ class ConstantMatrixType final : public MatrixType {
43784378
unsigned NumRows;
43794379
unsigned NumColumns;
43804380

4381-
static constexpr unsigned MaxElementsPerDimension = (1 << 20) - 1;
4382-
43834381
ConstantMatrixType(QualType MatrixElementType, unsigned NRows,
43844382
unsigned NColumns, QualType CanonElementType);
43854383

@@ -4398,16 +4396,6 @@ class ConstantMatrixType final : public MatrixType {
43984396
return getNumRows() * getNumColumns();
43994397
}
44004398

4401-
/// Returns true if \p NumElements is a valid matrix dimension.
4402-
static constexpr bool isDimensionValid(size_t NumElements) {
4403-
return NumElements > 0 && NumElements <= MaxElementsPerDimension;
4404-
}
4405-
4406-
/// Returns the maximum number of elements per dimension.
4407-
static constexpr unsigned getMaxElementsPerDimension() {
4408-
return MaxElementsPerDimension;
4409-
}
4410-
44114399
void Profile(llvm::FoldingSetNodeID &ID) {
44124400
Profile(ID, getElementType(), getNumRows(), getNumColumns(),
44134401
getTypeClass());

clang/include/clang/Basic/LangOptions.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,7 @@ ENUM_LANGOPT(RegisterStaticDestructors, RegisterStaticDestructorsKind, 2,
433433
LANGOPT(RegCall4, 1, 0, NotCompatible, "Set __regcall4 as a default calling convention to respect __regcall ABI v.4")
434434

435435
LANGOPT(MatrixTypes, 1, 0, NotCompatible, "Enable or disable the builtin matrix type")
436+
VALUE_LANGOPT(MaxMatrixDimension, 32, (1 << 20) - 1, NotCompatible, "maximum allowed matrix dimension")
436437

437438
LANGOPT(CXXAssumptions, 1, 1, NotCompatible, "Enable or disable codegen and compile-time checks for C++23's [[assume]] attribute")
438439

clang/lib/AST/ASTContext.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4712,8 +4712,8 @@ QualType ASTContext::getConstantMatrixType(QualType ElementTy, unsigned NumRows,
47124712

47134713
assert(MatrixType::isValidElementType(ElementTy) &&
47144714
"need a valid element type");
4715-
assert(ConstantMatrixType::isDimensionValid(NumRows) &&
4716-
ConstantMatrixType::isDimensionValid(NumColumns) &&
4715+
assert(NumRows > 0 && NumRows <= LangOpts.MaxMatrixDimension &&
4716+
NumColumns > 0 && NumColumns <= LangOpts.MaxMatrixDimension &&
47174717
"need valid matrix dimensions");
47184718
void *InsertPos = nullptr;
47194719
if (ConstantMatrixType *MTP = MatrixTypes.FindNodeOrInsertPos(ID, InsertPos))

clang/lib/Basic/LangOptions.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,8 +132,12 @@ void LangOptions::setLangDefaults(LangOptions &Opts, Language Lang,
132132
Opts.NamedLoops = Std.isC2y();
133133

134134
Opts.HLSL = Lang == Language::HLSL;
135-
if (Opts.HLSL && Opts.IncludeDefaultHeader)
136-
Includes.push_back("hlsl.h");
135+
if (Opts.HLSL) {
136+
if (Opts.IncludeDefaultHeader)
137+
Includes.push_back("hlsl.h");
138+
// Set maximum matrix dimension to 4 for HLSL
139+
Opts.MaxMatrixDimension = 4;
140+
}
137141

138142
// Set OpenCL Version.
139143
Opts.OpenCL = Std.isOpenCL();

clang/lib/Sema/HLSLExternalSemaSource.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,8 @@ void HLSLExternalSemaSource::defineHLSLMatrixAlias() {
159159
SourceLocation(), ColsParam));
160160
TemplateParams.emplace_back(ColsParam);
161161

162-
const unsigned MaxMatDim = 4;
162+
const unsigned MaxMatDim = SemaPtr->getLangOpts().MaxMatrixDimension;
163+
163164
auto *MaxRow = IntegerLiteral::Create(
164165
AST, llvm::APInt(AST.getIntWidth(AST.IntTy), MaxMatDim), AST.IntTy,
165166
SourceLocation());

clang/lib/Sema/SemaChecking.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16239,9 +16239,9 @@ getAndVerifyMatrixDimension(Expr *Expr, StringRef Name, Sema &S) {
1623916239
return {};
1624016240
}
1624116241
uint64_t Dim = Value->getZExtValue();
16242-
if (!ConstantMatrixType::isDimensionValid(Dim)) {
16242+
if (Dim == 0 || Dim > S.Context.getLangOpts().MaxMatrixDimension) {
1624316243
S.Diag(Expr->getBeginLoc(), diag::err_builtin_matrix_invalid_dimension)
16244-
<< Name << ConstantMatrixType::getMaxElementsPerDimension();
16244+
<< Name << S.Context.getLangOpts().MaxMatrixDimension;
1624516245
return {};
1624616246
}
1624716247
return Dim;

clang/lib/Sema/SemaType.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2517,12 +2517,18 @@ QualType Sema::BuildMatrixType(QualType ElementTy, Expr *NumRows, Expr *NumCols,
25172517
Diag(AttrLoc, diag::err_attribute_zero_size) << "matrix" << ColRange;
25182518
return QualType();
25192519
}
2520-
if (!ConstantMatrixType::isDimensionValid(MatrixRows)) {
2520+
if (MatrixRows > Context.getLangOpts().MaxMatrixDimension &&
2521+
MatrixColumns > Context.getLangOpts().MaxMatrixDimension) {
2522+
Diag(AttrLoc, diag::err_attribute_size_too_large)
2523+
<< RowRange << ColRange << "matrix row and column";
2524+
return QualType();
2525+
}
2526+
if (MatrixRows > Context.getLangOpts().MaxMatrixDimension) {
25212527
Diag(AttrLoc, diag::err_attribute_size_too_large)
25222528
<< RowRange << "matrix row";
25232529
return QualType();
25242530
}
2525-
if (!ConstantMatrixType::isDimensionValid(MatrixColumns)) {
2531+
if (MatrixColumns > Context.getLangOpts().MaxMatrixDimension) {
25262532
Diag(AttrLoc, diag::err_attribute_size_too_large)
25272533
<< ColRange << "matrix column";
25282534
return QualType();

clang/test/SemaCXX/matrix-type.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ void matrix_var_dimensions(int Rows, unsigned Columns, char C) {
1414
using matrix7_t = int __attribute__((matrix_type(1, 0))); // expected-error{{zero matrix size}}
1515
using matrix7_t = int __attribute__((matrix_type(char, 0))); // expected-error{{expected '(' for function-style cast or type construction}}
1616
using matrix8_t = int __attribute__((matrix_type(1048576, 1))); // expected-error{{matrix row size too large}}
17+
using matrix8_t = int __attribute__((matrix_type(1048576, 1048576))); // expected-error{{matrix row and column size too large}}
1718
}
1819

1920
struct S1 {};

clang/test/SemaHLSL/BuiltIns/matrix-basic_types-errors.hlsl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,26 @@ uint16_t4x4 mat2;
1010
matrix<int, 5, 5> mat3;
1111
// expected-error@-1 {{constraints not satisfied for alias template 'matrix' [with element = int, rows_count = 5, cols_count = 5]}}
1212
// expected-note@* {{because '5 <= 4' (5 <= 4) evaluated to false}}
13+
14+
using float8x4 = __attribute__((matrix_type(8,4))) float;
15+
// expected-error@-1 {{matrix row size too large}}
16+
17+
using float4x8 = __attribute__((matrix_type(4,8))) float;
18+
// expected-error@-1 {{matrix column size too large}}
19+
20+
using float8x8 = __attribute__((matrix_type(8,8))) float;
21+
// expected-error@-1 {{matrix row and column size too large}}
22+
23+
using floatNeg1x4 = __attribute__((matrix_type(-1,4))) float;
24+
// expected-error@-1 {{matrix row size too large}}
25+
using float4xNeg1 = __attribute__((matrix_type(4,-1))) float;
26+
// expected-error@-1 {{matrix column size too large}}
27+
using floatNeg1xNeg1 = __attribute__((matrix_type(-1,-1))) float;
28+
// expected-error@-1 {{matrix row and column size too large}}
29+
30+
using float0x4 = __attribute__((matrix_type(0,4))) float;
31+
// expected-error@-1 {{zero matrix size}}
32+
using float4x0 = __attribute__((matrix_type(4,0))) float;
33+
// expected-error@-1 {{zero matrix size}}
34+
using float0x0 = __attribute__((matrix_type(0,0))) float;
35+
// expected-error@-1 {{zero matrix size}}

clang/unittests/Frontend/CompilerInvocationTest.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -732,6 +732,26 @@ TEST_F(CommandLineTest, ConditionalParsingIfTrueFlagPresent) {
732732
ASSERT_THAT(GeneratedArgs, Contains(StrEq("-sycl-std=2017")));
733733
}
734734

735+
TEST_F(CommandLineTest, ConditionalParsingIfHLSLFlagPresent) {
736+
const char *Args[] = {"-xhlsl"};
737+
738+
CompilerInvocation::CreateFromArgs(Invocation, Args, *Diags);
739+
740+
ASSERT_EQ(Invocation.getLangOpts().MaxMatrixDimension, 4u);
741+
742+
Invocation.generateCC1CommandLine(GeneratedArgs, *this);
743+
}
744+
745+
TEST_F(CommandLineTest, ConditionalParsingIfHLSLFlagNotPresent) {
746+
const char *Args[] = {""};
747+
748+
CompilerInvocation::CreateFromArgs(Invocation, Args, *Diags);
749+
750+
ASSERT_EQ(Invocation.getLangOpts().MaxMatrixDimension, 1048575u);
751+
752+
Invocation.generateCC1CommandLine(GeneratedArgs, *this);
753+
}
754+
735755
// Wide integer option.
736756

737757
TEST_F(CommandLineTest, WideIntegerHighValue) {

0 commit comments

Comments
 (0)