Skip to content

Commit 6747139

Browse files
authored
[CIR] Use zero-initializer for partial array fills (#154161)
If an array initializer list leaves eight or more elements that require zero fill, we had been generating an individual zero element for every one of them. This change instead follows the behavior of classic codegen, which creates a constant structure with the specified elements followed by a zero-initializer for the trailing zeros.
1 parent 0542355 commit 6747139

File tree

10 files changed

+241
-19
lines changed

10 files changed

+241
-19
lines changed

clang/include/clang/CIR/Dialect/IR/CIRAttrs.td

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,44 @@ def CIR_ConstVectorAttr : CIR_Attr<"ConstVector", "const_vector", [
341341
let genVerifyDecl = 1;
342342
}
343343

344+
//===----------------------------------------------------------------------===//
345+
// ConstRecordAttr
346+
//===----------------------------------------------------------------------===//
347+
348+
def CIR_ConstRecordAttr : CIR_Attr<"ConstRecord", "const_record", [
349+
TypedAttrInterface
350+
]> {
351+
let summary = "Represents a constant record";
352+
let description = [{
353+
Effectively supports "struct-like" constants. It's must be built from
354+
an `mlir::ArrayAttr` instance where each element is a typed attribute
355+
(`mlir::TypedAttribute`).
356+
357+
Example:
358+
```
359+
cir.global external @rgb2 = #cir.const_record<{0 : i8,
360+
5 : i64, #cir.null : !cir.ptr<i8>
361+
}> : !cir.record<"", i8, i64, !cir.ptr<i8>>
362+
```
363+
}];
364+
365+
let parameters = (ins AttributeSelfTypeParameter<"">:$type,
366+
"mlir::ArrayAttr":$members);
367+
368+
let builders = [
369+
AttrBuilderWithInferredContext<(ins "cir::RecordType":$type,
370+
"mlir::ArrayAttr":$members), [{
371+
return $_get(type.getContext(), type, members);
372+
}]>
373+
];
374+
375+
let assemblyFormat = [{
376+
`<` custom<RecordMembers>($members) `>`
377+
}];
378+
379+
let genVerifyDecl = 1;
380+
}
381+
344382
//===----------------------------------------------------------------------===//
345383
// ConstPtrAttr
346384
//===----------------------------------------------------------------------===//

clang/lib/CIR/CodeGen/CIRGenBuilder.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,23 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy {
6060
trailingZerosNum);
6161
}
6262

63+
cir::ConstRecordAttr getAnonConstRecord(mlir::ArrayAttr arrayAttr,
64+
bool packed = false,
65+
bool padded = false,
66+
mlir::Type ty = {}) {
67+
llvm::SmallVector<mlir::Type, 4> members;
68+
for (auto &f : arrayAttr) {
69+
auto ta = mlir::cast<mlir::TypedAttr>(f);
70+
members.push_back(ta.getType());
71+
}
72+
73+
if (!ty)
74+
ty = getAnonRecordTy(members, packed, padded);
75+
76+
auto sTy = mlir::cast<cir::RecordType>(ty);
77+
return cir::ConstRecordAttr::get(sTy, arrayAttr);
78+
}
79+
6380
std::string getUniqueAnonRecordName() { return getUniqueRecordName("anon"); }
6481

6582
std::string getUniqueRecordName(const std::string &baseName) {

clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ emitArrayConstant(CIRGenModule &cgm, mlir::Type desiredType,
285285
mlir::Type commonElementType, unsigned arrayBound,
286286
SmallVectorImpl<mlir::TypedAttr> &elements,
287287
mlir::TypedAttr filler) {
288-
const CIRGenBuilderTy &builder = cgm.getBuilder();
288+
CIRGenBuilderTy &builder = cgm.getBuilder();
289289

290290
unsigned nonzeroLength = arrayBound;
291291
if (elements.size() < nonzeroLength && builder.isNullValue(filler))
@@ -306,6 +306,33 @@ emitArrayConstant(CIRGenModule &cgm, mlir::Type desiredType,
306306
if (trailingZeroes >= 8) {
307307
assert(elements.size() >= nonzeroLength &&
308308
"missing initializer for non-zero element");
309+
310+
if (commonElementType && nonzeroLength >= 8) {
311+
// If all the elements had the same type up to the trailing zeroes and
312+
// there are eight or more nonzero elements, emit a struct of two arrays
313+
// (the nonzero data and the zeroinitializer).
314+
SmallVector<mlir::Attribute, 4> eles;
315+
eles.reserve(nonzeroLength);
316+
for (const auto &element : elements)
317+
eles.push_back(element);
318+
auto initial = cir::ConstArrayAttr::get(
319+
cir::ArrayType::get(commonElementType, nonzeroLength),
320+
mlir::ArrayAttr::get(builder.getContext(), eles));
321+
elements.resize(2);
322+
elements[0] = initial;
323+
} else {
324+
// Otherwise, emit a struct with individual elements for each nonzero
325+
// initializer, followed by a zeroinitializer array filler.
326+
elements.resize(nonzeroLength + 1);
327+
}
328+
329+
mlir::Type fillerType =
330+
commonElementType
331+
? commonElementType
332+
: mlir::cast<cir::ArrayType>(desiredType).getElementType();
333+
fillerType = cir::ArrayType::get(fillerType, trailingZeroes);
334+
elements.back() = cir::ZeroAttr::get(fillerType);
335+
commonElementType = nullptr;
309336
} else if (elements.size() != arrayBound) {
310337
elements.resize(arrayBound, filler);
311338

@@ -325,8 +352,13 @@ emitArrayConstant(CIRGenModule &cgm, mlir::Type desiredType,
325352
mlir::ArrayAttr::get(builder.getContext(), eles));
326353
}
327354

328-
cgm.errorNYI("array with different type elements");
329-
return {};
355+
SmallVector<mlir::Attribute, 4> eles;
356+
eles.reserve(elements.size());
357+
for (auto const &element : elements)
358+
eles.push_back(element);
359+
360+
auto arrAttr = mlir::ArrayAttr::get(builder.getContext(), eles);
361+
return builder.getAnonConstRecord(arrAttr, /*isPacked=*/true);
330362
}
331363

332364
//===----------------------------------------------------------------------===//

clang/lib/CIR/Dialect/IR/CIRAttrs.cpp

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,14 @@
1515
#include "mlir/IR/DialectImplementation.h"
1616
#include "llvm/ADT/TypeSwitch.h"
1717

18+
//===-----------------------------------------------------------------===//
19+
// RecordMembers
20+
//===-----------------------------------------------------------------===//
21+
22+
static void printRecordMembers(mlir::AsmPrinter &p, mlir::ArrayAttr members);
23+
static mlir::ParseResult parseRecordMembers(mlir::AsmParser &parser,
24+
mlir::ArrayAttr &members);
25+
1826
//===-----------------------------------------------------------------===//
1927
// IntLiteral
2028
//===-----------------------------------------------------------------===//
@@ -68,6 +76,61 @@ void CIRDialect::printAttribute(Attribute attr, DialectAsmPrinter &os) const {
6876
llvm_unreachable("unexpected CIR type kind");
6977
}
7078

79+
static void printRecordMembers(mlir::AsmPrinter &printer,
80+
mlir::ArrayAttr members) {
81+
printer << '{';
82+
llvm::interleaveComma(members, printer);
83+
printer << '}';
84+
}
85+
86+
static ParseResult parseRecordMembers(mlir::AsmParser &parser,
87+
mlir::ArrayAttr &members) {
88+
llvm::SmallVector<mlir::Attribute, 4> elts;
89+
90+
auto delimiter = AsmParser::Delimiter::Braces;
91+
auto result = parser.parseCommaSeparatedList(delimiter, [&]() {
92+
mlir::TypedAttr attr;
93+
if (parser.parseAttribute(attr).failed())
94+
return mlir::failure();
95+
elts.push_back(attr);
96+
return mlir::success();
97+
});
98+
99+
if (result.failed())
100+
return mlir::failure();
101+
102+
members = mlir::ArrayAttr::get(parser.getContext(), elts);
103+
return mlir::success();
104+
}
105+
106+
//===----------------------------------------------------------------------===//
107+
// ConstRecordAttr definitions
108+
//===----------------------------------------------------------------------===//
109+
110+
LogicalResult
111+
ConstRecordAttr::verify(function_ref<InFlightDiagnostic()> emitError,
112+
mlir::Type type, ArrayAttr members) {
113+
auto sTy = mlir::dyn_cast_if_present<cir::RecordType>(type);
114+
if (!sTy)
115+
return emitError() << "expected !cir.record type";
116+
117+
if (sTy.getMembers().size() != members.size())
118+
return emitError() << "number of elements must match";
119+
120+
unsigned attrIdx = 0;
121+
for (auto &member : sTy.getMembers()) {
122+
auto m = mlir::cast<mlir::TypedAttr>(members[attrIdx]);
123+
if (member != m.getType())
124+
return emitError() << "element at index " << attrIdx << " has type "
125+
<< m.getType()
126+
<< " but the expected type for this element is "
127+
<< member;
128+
attrIdx++;
129+
}
130+
131+
return success();
132+
}
133+
71134
//===----------------------------------------------------------------------===//
72135
// OptInfoAttr definitions
73136
//===----------------------------------------------------------------------===//

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -341,8 +341,8 @@ static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType,
341341
}
342342

343343
if (mlir::isa<cir::ConstArrayAttr, cir::ConstVectorAttr,
344-
cir::ConstComplexAttr, cir::GlobalViewAttr, cir::PoisonAttr>(
345-
attrType))
344+
cir::ConstComplexAttr, cir::ConstRecordAttr,
345+
cir::GlobalViewAttr, cir::PoisonAttr>(attrType))
346346
return success();
347347

348348
assert(isa<TypedAttr>(attrType) && "What else could we be looking at here?");

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -201,8 +201,8 @@ class CIRAttrToValue {
201201
mlir::Value visit(mlir::Attribute attr) {
202202
return llvm::TypeSwitch<mlir::Attribute, mlir::Value>(attr)
203203
.Case<cir::IntAttr, cir::FPAttr, cir::ConstComplexAttr,
204-
cir::ConstArrayAttr, cir::ConstVectorAttr, cir::ConstPtrAttr,
205-
cir::GlobalViewAttr, cir::ZeroAttr>(
204+
cir::ConstArrayAttr, cir::ConstRecordAttr, cir::ConstVectorAttr,
205+
cir::ConstPtrAttr, cir::GlobalViewAttr, cir::ZeroAttr>(
206206
[&](auto attrT) { return visitCirAttr(attrT); })
207207
.Default([&](auto attrT) { return mlir::Value(); });
208208
}
@@ -212,6 +212,7 @@ class CIRAttrToValue {
212212
mlir::Value visitCirAttr(cir::ConstComplexAttr complexAttr);
213213
mlir::Value visitCirAttr(cir::ConstPtrAttr ptrAttr);
214214
mlir::Value visitCirAttr(cir::ConstArrayAttr attr);
215+
mlir::Value visitCirAttr(cir::ConstRecordAttr attr);
215216
mlir::Value visitCirAttr(cir::ConstVectorAttr attr);
216217
mlir::Value visitCirAttr(cir::GlobalViewAttr attr);
217218
mlir::Value visitCirAttr(cir::ZeroAttr attr);
@@ -386,6 +387,21 @@ mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstArrayAttr attr) {
386387
return result;
387388
}
388389

390+
/// ConstRecord visitor.
391+
mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstRecordAttr constRecord) {
392+
const mlir::Type llvmTy = converter->convertType(constRecord.getType());
393+
const mlir::Location loc = parentOp->getLoc();
394+
mlir::Value result = rewriter.create<mlir::LLVM::UndefOp>(loc, llvmTy);
395+
396+
// Iteratively lower each constant element of the record.
397+
for (auto [idx, elt] : llvm::enumerate(constRecord.getMembers())) {
398+
mlir::Value init = visit(elt);
399+
result = rewriter.create<mlir::LLVM::InsertValueOp>(loc, result, init, idx);
400+
}
401+
402+
return result;
403+
}
404+
389405
/// ConstVectorAttr visitor.
390406
mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstVectorAttr attr) {
391407
const mlir::Type llvmTy = converter->convertType(attr.getType());
@@ -1286,6 +1302,11 @@ mlir::LogicalResult CIRToLLVMConstantOpLowering::matchAndRewrite(
12861302
rewriter.eraseOp(op);
12871303
return mlir::success();
12881304
}
1305+
} else if (const auto recordAttr =
1306+
mlir::dyn_cast<cir::ConstRecordAttr>(op.getValue())) {
1307+
auto initVal = lowerCirAttrAsValue(op, recordAttr, rewriter, typeConverter);
1308+
rewriter.replaceOp(op, initVal);
1309+
return mlir::success();
12891310
} else if (const auto vecTy = mlir::dyn_cast<cir::VectorType>(op.getType())) {
12901311
rewriter.replaceOp(op, lowerCirAttrAsValue(op, op.getValue(), rewriter,
12911312
getTypeConverter()));
@@ -1527,9 +1548,9 @@ CIRToLLVMGlobalOpLowering::matchAndRewriteRegionInitializedGlobal(
15271548
cir::GlobalOp op, mlir::Attribute init,
15281549
mlir::ConversionPatternRewriter &rewriter) const {
15291550
// TODO: Generalize this handling when more types are needed here.
1530-
assert(
1531-
(isa<cir::ConstArrayAttr, cir::ConstVectorAttr, cir::ConstPtrAttr,
1532-
cir::ConstComplexAttr, cir::GlobalViewAttr, cir::ZeroAttr>(init)));
1551+
assert((isa<cir::ConstArrayAttr, cir::ConstRecordAttr, cir::ConstVectorAttr,
1552+
cir::ConstPtrAttr, cir::ConstComplexAttr, cir::GlobalViewAttr,
1553+
cir::ZeroAttr>(init)));
15331554

15341555
// TODO(cir): once LLVM's dialect has proper equivalent attributes this
15351556
// should be updated. For now, we use a custom op to initialize globals
@@ -1582,8 +1603,9 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite(
15821603
return mlir::failure();
15831604
}
15841605
} else if (mlir::isa<cir::ConstArrayAttr, cir::ConstVectorAttr,
1585-
cir::ConstPtrAttr, cir::ConstComplexAttr,
1586-
cir::GlobalViewAttr, cir::ZeroAttr>(init.value())) {
1606+
cir::ConstRecordAttr, cir::ConstPtrAttr,
1607+
cir::ConstComplexAttr, cir::GlobalViewAttr,
1608+
cir::ZeroAttr>(init.value())) {
15871609
// TODO(cir): once LLVM's dialect has proper equivalent attributes this
15881610
// should be updated. For now, we use a custom op to initialize globals
15891611
// to the appropriate value.

clang/test/CIR/CodeGen/array.cpp

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,9 @@ int dd[3][2] = {{1, 2}, {3, 4}, {5, 6}};
4545
// OGCG: [i32 3, i32 4], [2 x i32] [i32 5, i32 6]]
4646

4747
int e[10] = {1, 2};
48-
// CIR: cir.global external @e = #cir.const_array<[#cir.int<1> : !s32i, #cir.int<2> : !s32i], trailing_zeros> : !cir.array<!s32i x 10>
48+
// CIR: cir.global external @e = #cir.const_record<{#cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.zero : !cir.array<!s32i x 8>}> : !rec_anon_struct
4949

50-
// LLVM: @e = global [10 x i32] [i32 1, i32 2, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0]
50+
// LLVM: @e = global <{ i32, i32, [8 x i32] }> <{ i32 1, i32 2, [8 x i32] zeroinitializer }>
5151

5252
// OGCG: @e = global <{ i32, i32, [8 x i32] }> <{ i32 1, i32 2, [8 x i32] zeroinitializer }>
5353

@@ -58,6 +58,28 @@ int f[5] = {1, 2};
5858

5959
// OGCG: @f = global [5 x i32] [i32 1, i32 2, i32 0, i32 0, i32 0]
6060

61+
int g[16] = {1, 2, 3, 4, 5, 6, 7, 8};
62+
// CIR: cir.global external @g = #cir.const_record<{
63+
// CIR-SAME: #cir.const_array<[#cir.int<1> : !s32i, #cir.int<2> : !s32i,
64+
// CIR-SAME: #cir.int<3> : !s32i, #cir.int<4> : !s32i,
65+
// CIR-SAME: #cir.int<5> : !s32i, #cir.int<6> : !s32i,
66+
// CIR-SAME: #cir.int<7> : !s32i, #cir.int<8> : !s32i]>
67+
// CIR-SAME: : !cir.array<!s32i x 8>,
68+
// CIR-SAME: #cir.zero : !cir.array<!s32i x 8>}> : !rec_anon_struct1
69+
70+
// LLVM: @g = global <{ [8 x i32], [8 x i32] }>
71+
// LLVM-SAME: <{ [8 x i32]
72+
// LLVM-SAME: [i32 1, i32 2, i32 3, i32 4,
73+
// LLVM-SAME: i32 5, i32 6, i32 7, i32 8],
74+
// LLVM-SAME: [8 x i32] zeroinitializer }>
75+
76+
// OGCG: @g = global <{ [8 x i32], [8 x i32] }>
77+
// OGCG-SAME: <{ [8 x i32]
78+
// OGCG-SAME: [i32 1, i32 2, i32 3, i32 4,
79+
// OGCG-SAME: i32 5, i32 6, i32 7, i32 8],
80+
// OGCG-SAME: [8 x i32] zeroinitializer }>
81+
82+
6183
extern int b[10];
6284
// CIR: cir.global "private" external @b : !cir.array<!s32i x 10>
6385
// LLVM: @b = external global [10 x i32]
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// RUN: cir-opt %s -verify-diagnostics -split-input-file
2+
3+
!s32i = !cir.int<s, 32>
4+
!rec_anon_struct = !cir.record<struct packed {!s32i, !s32i, !cir.array<!s32i x 8>}>
5+
6+
// expected-error @below {{expected !cir.record type}}
7+
cir.global external @e = #cir.const_record<{#cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.zero : !cir.array<!s32i x 8>}> : !cir.ptr<!rec_anon_struct>
8+
9+
// -----
10+
11+
!s32i = !cir.int<s, 32>
12+
!rec_anon_struct = !cir.record<struct packed {!s32i, !s32i, !cir.array<!s32i x 8>}>
13+
14+
// expected-error @below {{number of elements must match}}
15+
cir.global external @e = #cir.const_record<{#cir.int<1> : !s32i, #cir.zero : !cir.array<!s32i x 8>}> : !rec_anon_struct
16+
17+
// -----
18+
19+
!s32i = !cir.int<s, 32>
20+
!rec_anon_struct = !cir.record<struct packed {!s32i, !s32i, !cir.array<!s32i x 8>}>
21+
22+
// expected-error @below {{element at index 1 has type '!cir.float' but the expected type for this element is '!cir.int<s, 32>'}}
23+
cir.global external @e = #cir.const_record<{#cir.int<1> : !s32i, #cir.fp<2.000000e+00> : !cir.float, #cir.zero : !cir.array<!s32i x 8>}> : !rec_anon_struct

clang/test/CIR/IR/struct.cir

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@
1313
// CHECK-DAG: !rec_S = !cir.record<struct "S" incomplete>
1414
// CHECK-DAG: !rec_U = !cir.record<union "U" incomplete>
1515

16-
!rec_anon_struct = !cir.record<struct {!cir.array<!cir.ptr<!u8i> x 5>}>
17-
!rec_anon_struct1 = !cir.record<struct {!cir.ptr<!u8i>, !cir.ptr<!u8i>, !cir.ptr<!u8i>}>
16+
!rec_anon_struct = !cir.record<struct packed {!s32i, !s32i, !cir.array<!s32i x 8>}>
17+
!rec_anon_struct1 = !cir.record<struct {!cir.array<!cir.ptr<!u8i> x 5>}>
18+
!rec_anon_struct2 = !cir.record<struct {!cir.ptr<!u8i>, !cir.ptr<!u8i>, !cir.ptr<!u8i>}>
1819
!rec_S1 = !cir.record<struct "S1" {!s32i, !s32i}>
1920
!rec_Sc = !cir.record<struct "Sc" {!u8i, !u16i, !u32i}>
2021

@@ -42,18 +43,22 @@
4243
!rec_Node = !cir.record<struct "Node" {!cir.ptr<!cir.record<struct "Node">>}>
4344
// CHECK-DAG: !cir.record<struct "Node" {!cir.ptr<!cir.record<struct "Node">>}>
4445

46+
47+
4548
module {
4649
cir.global external @p1 = #cir.ptr<null> : !cir.ptr<!rec_S>
4750
cir.global external @p2 = #cir.ptr<null> : !cir.ptr<!rec_U>
4851
cir.global external @p3 = #cir.ptr<null> : !cir.ptr<!rec_C>
52+
cir.global external @arr = #cir.const_record<{#cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.zero : !cir.array<!s32i x 8>}> : !rec_anon_struct
4953
// CHECK: cir.global external @p1 = #cir.ptr<null> : !cir.ptr<!rec_S>
5054
// CHECK: cir.global external @p2 = #cir.ptr<null> : !cir.ptr<!rec_U>
5155
// CHECK: cir.global external @p3 = #cir.ptr<null> : !cir.ptr<!rec_C>
56+
// CHECK: cir.global external @arr = #cir.const_record<{#cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.zero : !cir.array<!s32i x 8>}> : !rec_anon_struct
5257

5358
// Dummy function to use types and force them to be printed.
5459
cir.func @useTypes(%arg0: !rec_Node,
55-
%arg1: !rec_anon_struct1,
56-
%arg2: !rec_anon_struct,
60+
%arg1: !rec_anon_struct2,
61+
%arg2: !rec_anon_struct1,
5762
%arg3: !rec_S1,
5863
%arg4: !rec_Ac,
5964
%arg5: !rec_P1,

0 commit comments

Comments
 (0)