Skip to content

Commit a0aca63

Browse files
committed
[CIR] Add special type for vtables
This change introduces a new type, cir.vtable, to be used with operations that return a pointer to a class' vtable.
1 parent ee785a0 commit a0aca63

File tree

15 files changed

+107
-75
lines changed

15 files changed

+107
-75
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2596,13 +2596,13 @@ def CIR_VTableAddrPointOp : CIR_Op<"vtable.address_point",[
25962596
(as specified by Itanium ABI), and `address_point.offset` (address point index) the actual address
25972597
point within that vtable.
25982598

2599-
The return type is always a `!cir.ptr<!cir.ptr<() -> i32>>`.
2599+
The return type is always a `!cir.ptr<!cir.vtable>`.
26002600

26012601
Example:
26022602
```mlir
26032603
cir.global linkonce_odr @_ZTV1B = ...
26042604
...
2605-
%3 = cir.vtable.address_point(@_ZTV1B, address_point = <index = 0, offset = 2>) : !cir.ptr<!cir.ptr<() -> i32>>
2605+
%3 = cir.vtable.address_point(@_ZTV1B, address_point = <index = 0, offset = 2>) : !cir.vtable_ptr
26062606
```
26072607
}];
26082608

clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,13 +263,21 @@ def CIR_PtrToExceptionInfoType
263263
def CIR_AnyDataMemberType : CIR_TypeBase<"::cir::DataMemberType",
264264
"data member type">;
265265

266+
//===----------------------------------------------------------------------===//
267+
// VTable type predicates
268+
//===----------------------------------------------------------------------===//
269+
270+
def CIR_AnyVTableType : CIR_TypeBase<"::cir::VTableType",
271+
"vtable type">;
272+
273+
266274
//===----------------------------------------------------------------------===//
267275
// Scalar Type predicates
268276
//===----------------------------------------------------------------------===//
269277

270278
defvar CIR_ScalarTypes = [
271279
CIR_AnyBoolType, CIR_AnyIntType, CIR_AnyFloatType, CIR_AnyPtrType,
272-
CIR_AnyDataMemberType
280+
CIR_AnyDataMemberType, CIR_AnyVTableType
273281
];
274282

275283
def CIR_AnyScalarType : AnyTypeOf<CIR_ScalarTypes, "cir scalar type"> {

clang/include/clang/CIR/Dialect/IR/CIRTypes.td

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,22 @@ def CIR_DataMemberType : CIR_Type<"DataMember", "data_member",
343343
}];
344344
}
345345

346+
//===----------------------------------------------------------------------===//
347+
// CIR_VTableType
348+
//===----------------------------------------------------------------------===//
349+
350+
def CIR_VTableType : CIR_Type<"VTable", "vtable",
351+
[DeclareTypeInterfaceMethods<DataLayoutTypeInterface>]> {
352+
353+
let summary = "CIR type that is used for pointers that point to a C++ vtable";
354+
let description = [{
355+
`cir.vtable` is a special type used as the pointee type for pointers to
356+
vtables. This avoids using arbitrary pointer types to declare vtable
357+
pointer values.
358+
}];
359+
}
360+
361+
346362
//===----------------------------------------------------------------------===//
347363
// BoolType
348364
//===----------------------------------------------------------------------===//
@@ -751,7 +767,8 @@ def CIRRecordType : Type<
751767
def CIR_AnyType : AnyTypeOf<[
752768
CIR_IntType, CIR_PointerType, CIR_DataMemberType, CIR_MethodType,
753769
CIR_BoolType, CIR_ArrayType, CIR_VectorType, CIR_FuncType, CIR_VoidType,
754-
CIR_RecordType, CIR_ExceptionType, CIR_AnyFloatType, CIR_ComplexType
770+
CIR_RecordType, CIR_ExceptionType, CIR_AnyFloatType, CIR_ComplexType,
771+
CIR_VTableType
755772
]>;
756773

757774
#endif // MLIR_CIR_DIALECT_CIR_TYPES

clang/lib/CIR/CodeGen/CIRGenBuilder.h

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -424,12 +424,8 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy {
424424
llvm_unreachable("unsupported long double format");
425425
}
426426

427-
mlir::Type getVirtualFnPtrType(bool isVarArg = false) {
428-
// FIXME: replay LLVM codegen for now, perhaps add a vtable ptr special
429-
// type so it's a bit more clear and C++ idiomatic.
430-
auto fnTy = cir::FuncType::get({}, getUInt32Ty(), isVarArg);
431-
assert(!cir::MissingFeatures::isVarArg());
432-
return getPointerTo(getPointerTo(fnTy));
427+
mlir::Type getVirtualFnPtrType() {
428+
return cir::PointerType::get(cir::VTableType::get(getContext()));
433429
}
434430

435431
cir::FuncType getFuncType(llvm::ArrayRef<mlir::Type> params, mlir::Type retTy,

clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1007,7 +1007,7 @@ CIRGenItaniumCXXABI::getVTableAddressPoint(BaseSubobject Base,
10071007
.getAddressPoint(Base);
10081008

10091009
auto &builder = CGM.getBuilder();
1010-
auto vtablePtrTy = builder.getVirtualFnPtrType(/*isVarArg=*/false);
1010+
auto vtablePtrTy = builder.getVirtualFnPtrType();
10111011

10121012
return builder.create<cir::VTableAddrPointOp>(
10131013
CGM.getLoc(VTableClass->getSourceRange()), vtablePtrTy,

clang/lib/CIR/CodeGen/CIRRecordLayoutBuilder.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -488,8 +488,6 @@ void CIRRecordLowering::accumulateVPtrs() {
488488
}
489489

490490
mlir::Type CIRRecordLowering::getVFPtrType() {
491-
// FIXME: replay LLVM codegen for now, perhaps add a vtable ptr special
492-
// type so it's a bit more clear and C++ idiomatic.
493491
return builder.getVirtualFnPtrType();
494492
}
495493

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2423,10 +2423,7 @@ LogicalResult cir::VTableAddrPointOp::verify() {
24232423
return success();
24242424

24252425
auto resultType = getAddr().getType();
2426-
auto intTy = cir::IntType::get(getContext(), 32, /*isSigned=*/false);
2427-
auto fnTy = cir::FuncType::get({}, intTy);
2428-
2429-
auto resTy = cir::PointerType::get(cir::PointerType::get(fnTy));
2426+
auto resTy = cir::PointerType::get(cir::VTableType::get(getContext()));
24302427

24312428
if (resultType != resTy)
24322429
return emitOpError("result type must be '")

clang/lib/CIR/Dialect/IR/CIRTypes.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,20 @@ DataMemberType::getABIAlignment(const ::mlir::DataLayout &dataLayout,
407407
return 8;
408408
}
409409

410+
llvm::TypeSize
411+
VTableType::getTypeSizeInBits(const ::mlir::DataLayout &dataLayout,
412+
::mlir::DataLayoutEntryListRef params) const {
413+
// FIXME: consider size differences under different ABIs
414+
return llvm::TypeSize::getFixed(64);
415+
}
416+
417+
uint64_t
418+
VTableType::getABIAlignment(const ::mlir::DataLayout &dataLayout,
419+
::mlir::DataLayoutEntryListRef params) const {
420+
// FIXME: consider alignment differences under different ABIs
421+
return 8;
422+
}
423+
410424
llvm::TypeSize
411425
ArrayType::getTypeSizeInBits(const ::mlir::DataLayout &dataLayout,
412426
::mlir::DataLayoutEntryListRef params) const {

clang/test/CIR/CodeGen/dtors.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ class B : public A
3636
};
3737

3838
// Class A
39-
// CHECK: ![[ClassA:rec_.*]] = !cir.record<class "A" {!cir.ptr<!cir.ptr<!cir.func<() -> !u32i>>>} #cir.record.decl.ast>
39+
// CHECK: ![[ClassA:rec_.*]] = !cir.record<class "A" {!cir.ptr<!cir.vtable>} #cir.record.decl.ast>
4040

4141
// Class B
4242
// CHECK: ![[ClassB:rec_.*]] = !cir.record<class "B" {![[ClassA]]}>

clang/test/CIR/CodeGen/dynamic-cast-exact.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@ struct Derived final : Base1 {};
1616
Derived *ptr_cast(Base1 *ptr) {
1717
return dynamic_cast<Derived *>(ptr);
1818
// CHECK: %[[#SRC:]] = cir.load{{.*}} %{{.+}} : !cir.ptr<!cir.ptr<!rec_Base1>>, !cir.ptr<!rec_Base1>
19-
// CHECK-NEXT: %[[#EXPECTED_VPTR:]] = cir.vtable.address_point(@_ZTV7Derived, address_point = <index = 0, offset = 2>) : !cir.ptr<!cir.ptr<!cir.func<() -> !u32i>>>
20-
// CHECK-NEXT: %[[#SRC_VPTR_PTR:]] = cir.cast(bitcast, %[[#SRC]] : !cir.ptr<!rec_Base1>), !cir.ptr<!cir.ptr<!cir.ptr<!cir.func<() -> !u32i>>>>
21-
// CHECK-NEXT: %[[#SRC_VPTR:]] = cir.load{{.*}} %[[#SRC_VPTR_PTR]] : !cir.ptr<!cir.ptr<!cir.ptr<!cir.func<() -> !u32i>>>>, !cir.ptr<!cir.ptr<!cir.func<() -> !u32i>>>
22-
// CHECK-NEXT: %[[#SUCCESS:]] = cir.cmp(eq, %[[#SRC_VPTR]], %[[#EXPECTED_VPTR]]) : !cir.ptr<!cir.ptr<!cir.func<() -> !u32i>>>, !cir.bool
19+
// CHECK-NEXT: %[[#EXPECTED_VPTR:]] = cir.vtable.address_point(@_ZTV7Derived, address_point = <index = 0, offset = 2>) : !cir.ptr<!cir.vtable>
20+
// CHECK-NEXT: %[[#SRC_VPTR_PTR:]] = cir.cast(bitcast, %[[#SRC]] : !cir.ptr<!rec_Base1>), !cir.ptr<!cir.ptr<!cir.vtable>>
21+
// CHECK-NEXT: %[[#SRC_VPTR:]] = cir.load{{.*}} %[[#SRC_VPTR_PTR]] : !cir.ptr<!cir.ptr<!cir.vtable>>, !cir.ptr<!cir.vtable>
22+
// CHECK-NEXT: %[[#SUCCESS:]] = cir.cmp(eq, %[[#SRC_VPTR]], %[[#EXPECTED_VPTR]]) : !cir.ptr<!cir.vtable>, !cir.bool
2323
// CHECK-NEXT: %{{.+}} = cir.ternary(%[[#SUCCESS]], true {
2424
// CHECK-NEXT: %[[#RES:]] = cir.cast(bitcast, %[[#SRC]] : !cir.ptr<!rec_Base1>), !cir.ptr<!rec_Derived>
2525
// CHECK-NEXT: cir.yield %[[#RES]] : !cir.ptr<!rec_Derived>
@@ -39,10 +39,10 @@ Derived *ptr_cast(Base1 *ptr) {
3939
Derived &ref_cast(Base1 &ref) {
4040
return dynamic_cast<Derived &>(ref);
4141
// CHECK: %[[#SRC:]] = cir.load{{.*}} %{{.+}} : !cir.ptr<!cir.ptr<!rec_Base1>>, !cir.ptr<!rec_Base1>
42-
// CHECK-NEXT: %[[#EXPECTED_VPTR:]] = cir.vtable.address_point(@_ZTV7Derived, address_point = <index = 0, offset = 2>) : !cir.ptr<!cir.ptr<!cir.func<() -> !u32i>>>
43-
// CHECK-NEXT: %[[#SRC_VPTR_PTR:]] = cir.cast(bitcast, %[[#SRC]] : !cir.ptr<!rec_Base1>), !cir.ptr<!cir.ptr<!cir.ptr<!cir.func<() -> !u32i>>>>
44-
// CHECK-NEXT: %[[#SRC_VPTR:]] = cir.load{{.*}} %[[#SRC_VPTR_PTR]] : !cir.ptr<!cir.ptr<!cir.ptr<!cir.func<() -> !u32i>>>>, !cir.ptr<!cir.ptr<!cir.func<() -> !u32i>>>
45-
// CHECK-NEXT: %[[#SUCCESS:]] = cir.cmp(eq, %[[#SRC_VPTR]], %[[#EXPECTED_VPTR]]) : !cir.ptr<!cir.ptr<!cir.func<() -> !u32i>>>, !cir.bool
42+
// CHECK-NEXT: %[[#EXPECTED_VPTR:]] = cir.vtable.address_point(@_ZTV7Derived, address_point = <index = 0, offset = 2>) : !cir.ptr<!cir.vtable>
43+
// CHECK-NEXT: %[[#SRC_VPTR_PTR:]] = cir.cast(bitcast, %[[#SRC]] : !cir.ptr<!rec_Base1>), !cir.ptr<!cir.ptr<!cir.vtable>>
44+
// CHECK-NEXT: %[[#SRC_VPTR:]] = cir.load{{.*}} %[[#SRC_VPTR_PTR]] : !cir.ptr<!cir.ptr<!cir.vtable>>, !cir.ptr<!cir.vtable>
45+
// CHECK-NEXT: %[[#SUCCESS:]] = cir.cmp(eq, %[[#SRC_VPTR]], %[[#EXPECTED_VPTR]]) : !cir.ptr<!cir.vtable>, !cir.bool
4646
// CHECK-NEXT: %[[#FAILED:]] = cir.unary(not, %[[#SUCCESS]]) : !cir.bool, !cir.bool
4747
// CHECK-NEXT: cir.if %[[#FAILED]] {
4848
// CHECK-NEXT: cir.call @__cxa_bad_cast() : () -> ()

0 commit comments

Comments
 (0)