
Commit 06414cf

[MLIR][TORCH] Add E2E support for aten.as_strided op
This commit adds e2e support for the aten.as_strided op by decomposing it into a series of other Torch operations.

Signed-off-by: Vivek Khandelwal <[email protected]>
1 parent 8cee8ed commit 06414cf
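
To make the decomposition concrete, here is a minimal eager-PyTorch sketch of the same algorithm (the helper name as_strided_decomposed is illustrative, not from the commit, and the sketch assumes a contiguous input so that flattening matches the underlying storage):

import torch

def as_strided_decomposed(x, size, stride, storage_offset=0):
    # Illustrative sketch: build a broadcasted grid of flat-buffer
    # offsets, one arange per output dimension, then gather from the
    # flattened input.
    flat = x.reshape(-1)
    idx = None
    for dim, s in enumerate(size):
        view_shape = [1] * len(size)
        view_shape[dim] = -1  # this dim varies; all others broadcast
        arange = torch.arange(s).view(view_shape) * stride[dim]
        idx = arange if idx is None else idx + arange
    final_indices = idx.reshape(-1) + storage_offset
    return flat[final_indices].view(size)

x = torch.randn(4, 5, 6)
assert torch.equal(
    as_strided_decomposed(x, (2, 2), (3, 3), 1),
    torch.as_strided(x, (2, 2), (3, 3), 1),
)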

File tree

4 files changed: +243 -0 lines changed


lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 193 additions & 0 deletions
@@ -12427,6 +12427,198 @@ class DecomposeAtenRoundDecimalsOp
 };
 } // namespace
 
+namespace {
+class DecomposeAtenAsStridedOp : public OpRewritePattern<AtenAsStridedOp> {
+public:
+  using OpRewritePattern<AtenAsStridedOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenAsStridedOp op,
+                                PatternRewriter &rewriter) const override {
+
+    // The `aten.as_strided` operation is decomposed into a series of
+    // operations that compute the indices based on the provided sizes and
+    // strides, and then index into the flattened input tensor as follows:
+
+    // input_flat = input.view(-1)
+    //
+    // for dim, s in enumerate(self.size):
+    //     arange = torch.arange(s)
+    //     view_shape = []
+    //     for i in range(len(self.size)):
+    //         if i == dim:
+    //             view_shape.append(-1)
+    //         else:
+    //             view_shape.append(1)
+    //     arange = arange.view(view_shape)
+    //     if dim == 0:
+    //         idx = arange * self.stride[dim]
+    //     else:
+    //         idx = idx + arange * self.stride[dim]
+    //
+    // # Flatten indices and add offset
+    // final_indices = idx.reshape(-1) + self.storage_offset
+    //
+    // # Index the flattened input tensor
+    // output = input_flat[final_indices]
+    //
+    // # Reshape to desired output size
+    // return output.view(self.size)
+
+    Location loc = op.getLoc();
+    MLIRContext *context = op->getContext();
+    Value input = op.getSelf();
+    auto inputType = dyn_cast<BaseTensorType>(input.getType());
+
+    if (!inputType || !inputType.hasSizes() || !inputType.areAllSizesKnown())
+      return rewriter.notifyMatchFailure(op, "input must have known sizes");
+
+    SmallVector<int64_t> sizesInts;
+    if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(sizesInts)))
+      return rewriter.notifyMatchFailure(
+          op, "sizes must be a list of constant ints");
+
+    SmallVector<int64_t> stridesInts;
+    if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(stridesInts)))
+      return rewriter.notifyMatchFailure(
+          op, "strides must be a list of constant ints");
+
+    int64_t storageOffset = 0;
+    if (!isa<Torch::NoneType>(op.getStorageOffset().getType())) {
+      if (!matchPattern(op.getStorageOffset(),
+                        m_TorchConstantInt(&storageOffset)))
+        return rewriter.notifyMatchFailure(
+            op, "storage_offset must be a constant integer");
+    }
+
+    ArrayRef<int64_t> inputSizes = inputType.getSizes();
+    int64_t inputRank = inputSizes.size();
+    int64_t resultRank = sizesInts.size();
+
+    Value cstZero =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
+    if (inputRank > 1) {
+      // If the input is not a 1-D tensor, flatten it to a 1-D tensor
+      // before applying the strided indexing.
+      int64_t flattenedInputSize = 1;
+      for (int64_t size : inputSizes)
+        flattenedInputSize *= size;
+
+      auto flattenedInputTy =
+          cast<BaseTensorType>(inputType.getWithSizesAndDtype(
+              {flattenedInputSize}, inputType.getOptionalDtype()));
+
+      Value end = rewriter.create<ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(inputRank - 1));
+      input = rewriter.create<AtenFlattenUsingIntsOp>(loc, flattenedInputTy,
+                                                      input, cstZero, end);
+    }
+
+    Value cstOne =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
+    Value cstMinusOne =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(-1));
+
+    SmallVector<int64_t> viewShapeInts(resultRank, 1);
+    SmallVector<Value> viewShapeListElems(resultRank, cstOne);
+
+    auto si64Type = IntegerType::get(context, 64, IntegerType::Signed);
+    Value finalIndices;
+    for (unsigned dim = 0; dim < sizesInts.size(); dim++) {
+      int64_t size = sizesInts[dim];
+      Value cstNone = rewriter.create<ConstantNoneOp>(loc);
+      Value end = rewriter.create<ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(size));
+
+      auto arangeType =
+          ValueTensorType::get(context, llvm::ArrayRef(size), si64Type);
+      Value index = rewriter.create<Torch::AtenArangeOp>(
+          loc, arangeType, end, cstNone, cstNone, cstNone, cstNone);
+
+      // Set the current dimension to -1 for broadcasting.
+      viewShapeInts[dim] = -1;
+      viewShapeListElems[dim] = cstMinusOne;
+
+      Value viewShapeList = rewriter.create<Torch::PrimListConstructOp>(
+          loc, Torch::ListType::get(Torch::IntType::get(context)),
+          viewShapeListElems);
+
+      auto viewType = ValueTensorType::get(
+          context, llvm::ArrayRef(viewShapeInts), si64Type);
+      index = rewriter.create<AtenViewOp>(loc, viewType, index, viewShapeList);
+
+      // Multiply the index with the stride for the current dimension.
+      Value cstStride = rewriter.create<ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(stridesInts[dim]));
+      index = rewriter.create<AtenMulScalarOp>(loc, viewType, index, cstStride);
+
+      // Reset the current dimension to 1 for the next iteration.
+      viewShapeInts[dim] = 1;
+      viewShapeListElems[dim] = cstOne;
+
+      if (dim == 0) {
+        finalIndices = index;
+        continue;
+      }
+
+      // Calculate the common shape for the broadcast.
+      SmallVector<int64_t> broadcastShape;
+      SmallVector<Value> broadcastShapeValue;
+      computeBroadcastShape(rewriter, loc, finalIndices, index, broadcastShape,
+                            broadcastShapeValue);
+      Type broadcastType = ValueTensorType::get(
+          context, llvm::ArrayRef(broadcastShape), si64Type);
+
+      finalIndices = rewriter.create<AtenAddTensorOp>(
+          loc, broadcastType, finalIndices, index, cstOne);
+    }
+
+    int64_t flattenedResultSize = 1;
+    for (int64_t size : sizesInts)
+      flattenedResultSize *= size;
+
+    // Flatten the indices and add the storage offset.
+    finalIndices = rewriter.create<AtenFlattenUsingIntsOp>(
+        loc,
+        ValueTensorType::get(context, llvm::ArrayRef(flattenedResultSize),
+                             si64Type),
+        finalIndices, cstZero, cstMinusOne); // -1 means flatten all
+
+    if (storageOffset != 0) {
+      Value cstStorageOffset = rewriter.create<ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(storageOffset));
+      finalIndices = rewriter.create<AtenAddScalarOp>(
+          loc, finalIndices.getType(), finalIndices, cstStorageOffset, cstOne);
+    }
+
+    // Index the flattened input tensor.
+    Type listElemType =
+        inputType.getWithSizesAndDtype(/*optionalSizes=*/std::nullopt,
+                                       /*optionalDtype=*/nullptr);
+    Value indicesList = rewriter.create<Torch::PrimListConstructOp>(
+        loc, Torch::ListType::get(listElemType),
+        SmallVector<Value>{finalIndices});
+
+    auto flattenedResultTy =
+        ValueTensorType::get(context, llvm::ArrayRef(flattenedResultSize),
+                             inputType.getOptionalDtype());
+    Value result = rewriter.create<AtenIndexTensorOp>(loc, flattenedResultTy,
+                                                      input, indicesList);
+
+    // Reshape the result to the desired output size.
+    SmallVector<Value> sizesIntsValues;
+    for (int64_t size : sizesInts) {
+      sizesIntsValues.push_back(rewriter.create<ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(size)));
+    }
+    Value resultSizeList = rewriter.create<Torch::PrimListConstructOp>(
+        loc, Torch::ListType::get(Torch::IntType::get(context)),
+        sizesIntsValues);
+    result =
+        rewriter.create<AtenViewOp>(loc, op.getType(), result, resultSizeList);
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeComplexOpsPass
     : public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@@ -12750,6 +12942,7 @@ class DecomposeComplexOpsPass
     patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAten_AssertScalarOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenRoundDecimalsOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenAsStridedOp>(patterns);
 
     GreedyRewriteConfig config;
     config.setUseTopDownTraversal(true);
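
A note on the indexing trick in DecomposeAtenAsStridedOp above: rather than materializing a full meshgrid of indices, each per-dimension arange is viewed with -1 in its own slot and 1 everywhere else, so the running AtenAddTensorOp broadcasts the partial grids together (with the common result shape supplied by computeBroadcastShape). A small illustrative PyTorch analogue, for size=(3, 4) and stride=(2, 5):

import torch

# Each arange occupies its own axis; the elementwise add broadcasts
# the (3, 1) and (1, 4) tensors into the full (3, 4) offset grid.
rows = torch.arange(3).view(-1, 1) * 2  # scaled by the stride of dim 0
cols = torch.arange(4).view(1, -1) * 5  # scaled by the stride of dim 1
offsets = rows + cols  # offsets[i, j] == 2*i + 5*j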

lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

Lines changed: 1 addition & 0 deletions
@@ -589,6 +589,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<AtenLogaddexpOp>();
   target.addIllegalOp<AtenLogaddexp2Op>();
   target.addIllegalOp<AtenKlDivOp>();
+  target.addIllegalOp<AtenAsStridedOp>();
 
   for (auto &opName : backendLegalOpsSet) {
     target.addLegalOp(

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 4 additions & 0 deletions
@@ -982,6 +982,8 @@
     "NativeGroupNormModule_basic",
     "AvgPool2dCeilModeFullDimIndivisibleByStrideModule_basic",
     "MaxPool2dCeilModeFullDimIndivisibleByStrideModule_basic",
+    "AtenAsStridedModule_basic",
+    "AtenAsStridedNoStorageOffsetModule_basic",
 }
 
 FX_IMPORTER_STABLEHLO_CRASHING_SET = {
@@ -3949,6 +3951,8 @@
     "UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic",
     "ReplicationPad1dModule_2DInput_basic",
     "ReplicationPad1dModule_3DInput_basic",
+    "AtenAsStridedModule_basic",
+    "AtenAsStridedNoStorageOffsetModule_basic",
 }
 
 ONNX_TOSA_CRASHING_SET = {

projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py

Lines changed: 45 additions & 0 deletions
@@ -6730,3 +6730,48 @@ def forward(self, x):
 @register_test_case(module_factory=lambda: Aten_AssertScalar())
 def Aten_AssertScalar_basic(module, tu: TestUtils):
     module.forward(torch.tensor(4))
+
+
+# ==============================================================================
+
+
+class AtenAsStridedModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([4, 5, 6], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return torch.ops.aten.as_strided(
+            x, size=(2, 2), stride=(3, 3), storage_offset=1
+        )
+
+
+@register_test_case(module_factory=lambda: AtenAsStridedModule())
+def AtenAsStridedModule_basic(module, tu: TestUtils):
+    module.forward(torch.randn(4, 5, 6))
+
+
+class AtenAsStridedNoStorageOffsetModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([12, 13], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return torch.ops.aten.as_strided(x, size=(3, 4), stride=(2, 5))
+
+
+@register_test_case(module_factory=lambda: AtenAsStridedNoStorageOffsetModule())
+def AtenAsStridedNoStorageOffsetModule_basic(module, tu: TestUtils):
+    module.forward(torch.randn(12, 13))
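
As a sanity check on what these e2e tests exercise, the first module's expected output can be reproduced by hand from the flattened buffer; a sketch assuming a contiguous input:

import torch

x = torch.randn(4, 5, 6)
y = torch.ops.aten.as_strided(x, size=(2, 2), stride=(3, 3), storage_offset=1)

# Element (i, j) of the strided view lives at flat index 1 + 3*i + 3*j.
flat = x.reshape(-1)
idx = torch.arange(2).view(-1, 1) * 3 + torch.arange(2).view(1, -1) * 3 + 1
assert torch.equal(y, flat[idx.reshape(-1)].view(2, 2))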
