Skip to content

Commit b15039a

Browse files
committed
Classifies scalar-like operands and assigns them a
rank-aware scalar map (d0,…,dn) -> () during lowering.
1 parent 6738853 commit b15039a

File tree

1 file changed

+35
-9
lines changed

1 file changed

+35
-9
lines changed

mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -105,13 +105,39 @@ struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
105105
return rewriter.notifyMatchFailure(
106106
op, "requires elementwise op on ranked tensors");
107107

108-
auto rank = cast<RankedTensorType>(op->getResult(0).getType()).getRank();
109-
SmallVector<AffineMap, 3> indexingMaps(
110-
op->getNumResults() + op->getNumOperands(),
111-
rewriter.getMultiDimIdentityMap(rank));
112-
SmallVector<utils::IteratorType, 6> iteratorTypes(
108+
auto resTy = cast<RankedTensorType>(op->getResult(0).getType());
109+
auto rank = resTy.getRank();
110+
111+
// Maps: identity for tensors (rank > 0), scalar map for scalars/rank-0.
112+
AffineMap scalarMap = AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0,
113+
/*results=*/{}, rewriter.getContext());
114+
AffineMap idMap = rewriter.getMultiDimIdentityMap(rank);
115+
116+
// Create indexing maps: one per operand, one per result.
117+
SmallVector<AffineMap, 6> indexingMaps;
118+
indexingMaps.reserve(op->getNumOperands() + op->getNumResults());
119+
120+
for (Value v : op->getOperands()) {
121+
Type ty = v.getType();
122+
if (isScalarLike(ty))
123+
indexingMaps.push_back(scalarMap);
124+
else if (auto rt = dyn_cast<RankedTensorType>(ty)) {
125+
indexingMaps.push_back(idMap);
126+
} else
127+
return rewriter.notifyMatchFailure(
128+
op,
129+
"unsupported operand type (expected scalar-like or ranked tensor)");
130+
}
131+
132+
for (Value r : op->getResults()) {
133+
(void)r;
134+
indexingMaps.push_back(idMap); // results use identity map.
135+
}
136+
137+
SmallVector<utils::IteratorType, 4> iteratorTypes(
113138
rank, utils::IteratorType::parallel);
114-
auto outputs = getOrCreateOperandsMatchingResultTypes(rewriter, op);
139+
SmallVector<Value, 2> outputs =
140+
getOrCreateOperandsMatchingResultTypes(rewriter, op);
115141
rewriter.replaceOpWithNewOp<linalg::GenericOp>(
116142
op, /*resultTensorTypes=*/op->getResultTypes(),
117143
/*inputs=*/op->getOperands(),
@@ -120,14 +146,14 @@ struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
120146
/*iteratorTypes=*/iteratorTypes,
121147
/*bodyBuilder=*/
122148
[&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
123-
auto resultTypes = llvm::to_vector<6>(
149+
SmallVector<Type> resultEltTys = llvm::to_vector<6>(
124150
llvm::map_range(op->getResultTypes(), [](Type type) {
125151
return cast<TensorType>(type).getElementType();
126152
}));
127-
auto *scalarOp =
153+
Operation *scalarOp =
128154
builder.create(loc, op->getName().getIdentifier(),
129155
regionArgs.take_front(op->getNumOperands()),
130-
resultTypes, op->getAttrs());
156+
resultEltTys, op->getAttrs());
131157
linalg::YieldOp::create(builder, loc, scalarOp->getResults());
132158
});
133159
return success();

0 commit comments

Comments
 (0)