
Add torch.aten.as_strided to tensor.extract_slice conversion #4268

Closed · wants to merge 2 commits
97 changes: 96 additions & 1 deletion lib/Conversion/TorchToTensor/TorchToTensor.cpp
@@ -16,11 +16,13 @@
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
using namespace mlir::torch::TorchConversion;

namespace {

@@ -138,6 +140,97 @@ class ConvertAtenTensorOpPattern : public OpConversionPattern<AtenTensorOp> {
}
};

class ConvertAtenAsStridedOp : public OpConversionPattern<AtenAsStridedOp> {
public:
using OpConversionPattern<AtenAsStridedOp>::OpConversionPattern;
using OpAdaptor = typename AtenAsStridedOp::Adaptor;
LogicalResult
matchAndRewrite(AtenAsStridedOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// In some cases AtenAsStridedOp is equivalent to a Tensor ExtractSliceOp.
// We will try to match those cases here.
auto inputShape =
cast<RankedTensorType>(adaptor.getSelf().getType()).getShape();
auto outputShape =
cast<BaseTensorType>(op.getResult().getType()).getSizes();
auto resultTy =
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));

// If the output shape is strictly larger than the input shape in any
// dimension, then this AtenAsStridedOp is not equivalent to a slice.
for (uint64_t i = 0; i < outputShape.size(); ++i) {
if (outputShape[i] > inputShape[i])
return failure();
}

// Calculate what the strides attribute should be if the input tensor is
// contiguous.
SmallVector<int64_t> contiguousStrides(inputShape.size(), 1);
for (int i = inputShape.size() - 2; i >= 0; --i) {
contiguousStrides[i] = contiguousStrides[i + 1] * inputShape[i + 1];
}

SmallVector<Value> outSizeValues, opStridesValues;
if (!getListConstructElements(adaptor.getStride(), opStridesValues))
return op.emitError(
"unimplemented: the tensor list is not from list construct");

if (!getListConstructElements(adaptor.getSize(), outSizeValues))
return op.emitError(
"unimplemented: the tensor list is not from list construct");

// Get storage offset
int64_t offset;
if (!matchPattern(op.getStorageOffset(), m_TorchConstantInt(&offset)))
offset = 0;

APInt size;
SmallVector<int64_t> outSize(inputShape.size(), 0);
for (uint64_t i = 0; i < outSizeValues.size(); ++i) {
if (!matchPattern(outSizeValues[i], m_Op<TorchConversion::FromI64Op>(
m_ConstantInt(&size))) ||
!size.isSignedIntN(64))
return failure();
outSize[i] = size.getSExtValue();
}
APInt stride;
SmallVector<int64_t> opStrides(inputShape.size(), 0);
for (uint64_t i = 0; i < opStridesValues.size(); ++i) {
if (!matchPattern(opStridesValues[i], m_Op<TorchConversion::FromI64Op>(
m_ConstantInt(&stride))) ||
!stride.isSignedIntN(64))
return failure();
opStrides[i] = stride.getSExtValue();
}

// Slice dims are the dims where the input and output shapes are not equal.
SmallVector<int64_t> sliceDims;
for (uint64_t i = 0; i < inputShape.size(); ++i) {
if (outSize[i] != inputShape[i])
sliceDims.push_back(i);
}

// If there are no slice dims, then the AtenAsStridedOp is equivalent to the
// input tensor.
if (sliceDims.empty()) {
rewriter.replaceOp(op, adaptor.getSelf());
return success();
}

SmallVector<int64_t> sliceOffsets(inputShape.size(), 0);
SmallVector<int64_t> sliceStrides(opStrides.size(), 1);
for (auto dim : sliceDims) {
sliceOffsets[dim] = offset / contiguousStrides[dim];
sliceStrides[dim] = opStrides[dim] / contiguousStrides[dim];
}

rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
op, resultTy, adaptor.getSelf(), ValueRange(), ValueRange(),
ValueRange(), sliceOffsets, outSize, sliceStrides);
return success();
}
};

class ConvertTorchToTensor
: public ConvertTorchToTensorBase<ConvertTorchToTensor> {
public:
@@ -153,14 +246,16 @@ class ConvertTorchToTensor
target.addIllegalOp<Torch::AtenItemOp>();
target.addIllegalOp<Torch::AtenTensorOp>();
target.addIllegalOp<Torch::Aten_ShapeAsTensorOp>();
target.addIllegalOp<Torch::AtenAsStridedOp>();

TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });
TorchConversion::setupBackendTypeConversion(target, typeConverter);

RewritePatternSet patterns(context);
patterns.add<ConvertAtenShapeToTensorPatternOp, ConvertAtenItemOp,
ConvertAtenTensorOpPattern>(typeConverter, context);
ConvertAtenTensorOpPattern, ConvertAtenAsStridedOp>(
typeConverter, context);

if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
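The new pattern boils down to a bit of index arithmetic: compute the strides the input would have if it were contiguous, treat every dimension whose requested size differs from the input size as a sliced dimension, and derive that dimension's slice offset and step by dividing the storage offset and the op's stride by the contiguous stride. A rough Python sketch of that arithmetic, not part of this PR and assuming a contiguous input whose rank matches the requested size list:

```python
import torch

def as_strided_slice_params(input_shape, out_size, op_strides, offset):
    """Mirror of the arithmetic in ConvertAtenAsStridedOp (illustrative only)."""
    rank = len(input_shape)
    # Strides the input would have if it were contiguous.
    contiguous = [1] * rank
    for i in range(rank - 2, -1, -1):
        contiguous[i] = contiguous[i + 1] * input_shape[i + 1]
    # Dimensions whose requested size differs from the input size get sliced.
    slice_dims = [i for i in range(rank) if out_size[i] != input_shape[i]]
    offsets = [0] * rank
    steps = [1] * rank
    for d in slice_dims:
        offsets[d] = offset // contiguous[d]
        steps[d] = op_strides[d] // contiguous[d]
    return offsets, list(out_size), steps

# Example: take every other row of a contiguous 4x6 tensor, starting at row 1.
x = torch.arange(24, dtype=torch.float32).reshape(4, 6)
offsets, sizes, steps = as_strided_slice_params([4, 6], [2, 6], [12, 1], 6)
print(offsets, sizes, steps)  # [1, 0] [2, 6] [2, 1]
# The equivalent strided slice reads the same elements as the as_strided call.
assert torch.equal(x[1::2, :], torch.ops.aten.as_strided(x, (2, 6), (12, 1), 6))
```

In this example, as_strided with size (2, 6), stride (12, 1), and storage offset 6 on a contiguous 4x6 tensor reads rows 1 and 3, i.e. the slice x[1::2, :], which is exactly what offsets [1, 0], sizes [2, 6], and steps [2, 1] describe.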
23 changes: 23 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -6730,3 +6730,26 @@ def forward(self, x):
@register_test_case(module_factory=lambda: Aten_AssertScalar())
def Aten_AssertScalar_basic(module, tu: TestUtils):
module.forward(torch.tensor(4))


# ==============================================================================


class AtenAsStridedModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1, -1], torch.float32, True),
]
)
def forward(self, x):
return torch.ops.aten.as_strided(x, (2, 2), (3, 3), 1)


@register_test_case(module_factory=lambda: AtenAsStridedModule())
def AsStridedModule_basic(module, tu: TestUtils):
module.forward(torch.randn(25, 1, 1))
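The e2e test above feeds random data through a genuinely strided view (size (2, 2), stride (3, 3), storage offset 1) rather than a contiguous slice. As a standalone sanity check of which storage elements that call reads, not part of the test suite, using arange so the indices are visible in the values:

```python
import torch

x = torch.arange(25, dtype=torch.float32)
y = torch.ops.aten.as_strided(x, (2, 2), (3, 3), 1)
# Element (i, j) of the view lives at flat storage index 1 + 3*i + 3*j.
print(y)  # tensor([[1., 4.], [4., 7.]])
```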
27 changes: 26 additions & 1 deletion test/Conversion/TorchToTensor/torch_to_tensor.mlir
@@ -1,8 +1,33 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-tensor | FileCheck %s
// RUN: torch-mlir-opt <%s -split-input-file -convert-torch-to-tensor | FileCheck %s

// CHECK-LABEL: func.func @test_shape
func.func @test_shape(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3],si64> {
// CHECK: %[[SHAPE:.+]] = arith.constant dense<[3, 4, 5]> : tensor<3xi64>
%0 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3],si64>
return %0 : !torch.vtensor<[3],si64>
}

// -----

// CHECK-LABEL: func.func @test_as_strided
func.func @test_as_strided(%arg0: !torch.vtensor<[1,128,1024,192],f32>) -> !torch.vtensor<[1,128,1024,128],f32> {
%c0_i64 = arith.constant 0 : i64
%int0 = torch_c.from_i64 %c0_i64
%c1_i64 = arith.constant 1 : i64
%int1 = torch_c.from_i64 %c1_i64
%c128_i64 = arith.constant 128 : i64
%int128 = torch_c.from_i64 %c128_i64
%c192_i64 = arith.constant 192 : i64
%int192 = torch_c.from_i64 %c192_i64
%c1024_i64 = arith.constant 1024 : i64
%int1024 = torch_c.from_i64 %c1024_i64
%c24576_i64 = arith.constant 24576 : i64
%int24576 = torch_c.from_i64 %c24576_i64
%c25165824_i64 = arith.constant 25165824 : i64
%int25165824 = torch_c.from_i64 %c25165824_i64
%0 = torch.prim.ListConstruct %int1, %int128, %int1024, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%1 = torch.prim.ListConstruct %int25165824, %int192, %int24576, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[RESULT:.+]] = tensor.extract_slice %0[0, 0, 0, 0] [1, 128, 1024, 128] [1, 1, 1, 1] : tensor<1x128x1024x192xf32> to tensor<1x128x1024x128xf32>
%2 = torch.aten.as_strided %arg0, %0, %1, %int0 : !torch.vtensor<[1,128,1024,192],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,128,1024,128],f32>
return %2 : !torch.vtensor<[1,128,1024,128],f32>
}
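As a quick check of the CHECK line above (arithmetic worked out here, not stated in the PR): for the 1x128x1024x192 input the contiguous strides are [25165824, 196608, 192, 1], and only the last dimension shrinks (192 to 128), so with a storage offset of 0 the pattern produces offsets [0, 0, 0, 0], sizes [1, 128, 1024, 128], and steps [1, 1, 1, 1], which is exactly the tensor.extract_slice the test expects.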