Skip to content

Commit 439e467

Browse files
matthias-springermikolaj-pirog
authored andcommitted
[mlir][IR] Add implicit conversion operator to TypedValue (llvm#164621)
Allow implicit conversion from `TypedValue<B>` to `TypedValue<A>` if `B` is assignable to `A`. Example: ```c++ TypedValue<MemRefType> val; TypedValue<ShapedType> shapedVal = val; // this is now valid ```
1 parent f826fd4 commit 439e467

File tree

3 files changed

+19
-11
lines changed

3 files changed

+19
-11
lines changed

mlir/include/mlir/IR/Value.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,9 +433,19 @@ inline unsigned OpResultImpl::getResultNumber() const {
433433
template <typename Ty>
434434
struct TypedValue : Value {
435435
using Value::Value;
436+
using ValueType = Ty;
436437

437438
static bool classof(Value value) { return llvm::isa<Ty>(value.getType()); }
438439

440+
/// TypedValue<B> can implicitly convert to TypedValue<A> if B is assignable
441+
/// to A.
442+
template <typename ToTy,
443+
typename = typename std::enable_if<std::is_assignable<
444+
typename ToTy::ValueType &, Ty>::value>::type>
445+
operator ToTy() const {
446+
return llvm::cast<ToTy>(*this);
447+
}
448+
439449
/// Return the known Type
440450
Ty getType() const { return llvm::cast<Ty>(Value::getType()); }
441451
void setType(Ty ty) { Value::setType(ty); }

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1751,11 +1751,11 @@ OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
17511751
}
17521752

17531753
TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getSourcePtr() {
1754-
return cast<TypedValue<PtrLikeTypeInterface>>(getSource());
1754+
return getSource();
17551755
}
17561756

17571757
TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getTargetPtr() {
1758-
return cast<TypedValue<PtrLikeTypeInterface>>(getDest());
1758+
return getDest();
17591759
}
17601760

17611761
bool MemorySpaceCastOp::isValidMemorySpaceCast(PtrLikeTypeInterface tgt,

mlir/lib/Dialect/Shard/Transforms/Partition.cpp

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,10 @@ splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
6969
Sharding sourceSharding,
7070
TypedValue<ShapedType> sourceShard, GridOp grid,
7171
int64_t splitTensorAxis, GridAxis splitGridAxis) {
72-
TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
72+
TypedValue<ShapedType> targetShard =
7373
AllSliceOp::create(builder, sourceShard, grid,
7474
ArrayRef<GridAxis>(splitGridAxis), splitTensorAxis)
75-
.getResult());
75+
.getResult();
7676
Sharding targetSharding = targetShardingInSplitLastAxis(
7777
builder.getContext(), sourceSharding, splitTensorAxis, splitGridAxis);
7878
return {targetShard, targetSharding};
@@ -204,9 +204,8 @@ static std::tuple<TypedValue<ShapedType>, Sharding> unsplitLastAxisInResharding(
204204
APInt(64, splitTensorAxis));
205205
ShapedType targetShape =
206206
shardShapedType(sourceUnshardedShape, grid, targetSharding);
207-
TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
208-
tensor::CastOp::create(builder, targetShape, allGatherResult)
209-
.getResult());
207+
TypedValue<ShapedType> targetShard =
208+
tensor::CastOp::create(builder, targetShape, allGatherResult).getResult();
210209
return {targetShard, targetSharding};
211210
}
212211

@@ -336,8 +335,8 @@ moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
336335
APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
337336
ShapedType targetShape =
338337
shardShapedType(sourceUnshardedShape, grid, targetSharding);
339-
TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
340-
tensor::CastOp::create(builder, targetShape, allToAllResult).getResult());
338+
TypedValue<ShapedType> targetShard =
339+
tensor::CastOp::create(builder, targetShape, allToAllResult).getResult();
341340
return {targetShard, targetSharding};
342341
}
343342

@@ -510,8 +509,7 @@ TypedValue<ShapedType> reshard(OpBuilder &builder, GridOp grid, ShardOp source,
510509
auto targetSharding = target.getSharding();
511510
ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
512511
return reshard(implicitLocOpBuilder, grid, sourceSharding, targetSharding,
513-
cast<TypedValue<ShapedType>>(source.getSrc()),
514-
sourceShardValue);
512+
source.getSrc(), sourceShardValue);
515513
}
516514

517515
TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,

0 commit comments

Comments
 (0)