
Add torch.aten.topk to tm_tensor.topk conversion #4241

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open · wants to merge 1 commit into base: main

16 changes: 8 additions & 8 deletions include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td
@@ -355,19 +355,19 @@ def TMTensor_TopkOp : TMTensor_Op<"topk",
 A Top-K operation for N-D tensors. Reduces the target dimension from the input
 size N down to K elements based on the supplied binary region.

-Accepts an N-D tensor input consisting of values and an optioanl N-D tensor
+Accepts an N-D tensor input consisting of values and an optional N-D tensor
 for indices of those values (i32 type). If input indices aren't provided, the
 index mapping is inferred based on the k dim. Both input values/indices
-tensors and output values/indicies tensors must have the same shape. Top-K is
+tensors and output values/indices tensors must have the same shape. Top-K is
 computed along the target dimension (from dimension()). Returns two output
-tensors of values and the indicies of Top-K results. The output dimensions
-must match the input save for the dimension that is reduced to K results.
+tensors of values and the indices of Top-K results. The output dimensions
+must match the input except for the dimension that is reduced to K results.

-Region accepts lhs=[next N input] and rhs=[exiting K output] and yeilds an
+Region accepts lhs=[next N input] and rhs=[exiting K output] and yields an
 i1. If true, the two values are swapped:
-- For Top-K compoarision: >
-- For Min-K comparision: <
-Note: when the two values are equal, the first occurence is always selected.
+- For Top-K comparison: >
+- For Min-K comparison: <
+Note: when the two values are equal, the first occurrence is always selected.
 }];

 let arguments = (ins Variadic<AnyShaped>:$inputs,
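For reference, a minimal standalone use of the op (a hand-written sketch based on the description above and the tests added below, not part of the patch): selecting the 2 smallest elements (Min-K) of a 1-D tensor along dimension 0.

%result:2 = tm_tensor.topk dimension(0)
            ins(%input : tensor<6xf32>)
            outs(%init_values, %init_indices : tensor<2xf32>, tensor<2xi64>) {
^bb0(%lhs: f32, %rhs: f32):
  // Yielding true swaps the pair; olt keeps the smaller (Min-K) values.
  %keep = arith.cmpf olt, %lhs, %rhs : f32
  tm_tensor.yield %keep : i1
} -> tensor<2xf32>, tensor<2xi64>
// %result#0 holds the K values, %result#1 the corresponding indices.
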
112 changes: 112 additions & 0 deletions lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
@@ -2008,6 +2008,116 @@ class ConvertAtenScaledDotProductAttentionOp
};
} // namespace

namespace {
class ConvertAtenTopkOp : public OpConversionPattern<AtenTopkOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenTopkOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value input = adaptor.getSelf();
auto inputType = cast<RankedTensorType>(input.getType());
unsigned inputRank = inputType.getRank();
Type inputElementType = inputType.getElementType();

auto indicesType = cast<RankedTensorType>(
getTypeConverter()->convertType(op.getIndices().getType()));

// Get dim and bail out unless it is a constant int.
int64_t dim;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return rewriter.notifyMatchFailure(
op, "unimplemented: only constant dim value is supported");

// Normalize a negative dim to its positive form and check that it is in
// the valid range.
dim = toPositiveDim(dim, inputRank);
if (!isValidDim(dim, inputRank)) {
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
}

bool largest;
if (!matchPattern(op.getLargest(), m_TorchConstantBool(&largest)))
return rewriter.notifyMatchFailure(
op, "unimplemented: only constant largest value is supported");

bool sorted;
if (!matchPattern(op.getSorted(), m_TorchConstantBool(&sorted)))
return rewriter.notifyMatchFailure(
op, "unimplemented: only constant sorted value is supported");
if (sorted)
return rewriter.notifyMatchFailure(
op, "unimplemented: only unsorted topk is supported.");

SmallVector<Value> tmTensorTopkInputs({input});

SmallVector<OpFoldResult> outputDimSizes =
tensor::getMixedSizes(rewriter, loc, adaptor.getSelf());
int64_t k;
// TODO: why does a constant k not fold? We should not need to handle this
// case here.
if (matchPattern(op.getK(), m_TorchConstantInt(&k))) {
outputDimSizes[dim] = rewriter.getI64IntegerAttr(k);
} else {
Value kVal = typeConverter->materializeTargetConversion(
rewriter, loc, typeConverter->convertType(op.getK().getType()),
op.getK());
Value kIdx = rewriter.create<arith::IndexCastOp>(
loc, rewriter.getIndexType(), kVal);
outputDimSizes[dim] = kIdx;
}

Value emptyTensorOutputValues = rewriter.create<mlir::tensor::EmptyOp>(
loc, outputDimSizes, inputElementType);
// Fill the initial output values tensor with the comparison's identity:
// the smallest representable value when selecting the largest elements,
// the largest representable value when selecting the smallest, so that
// any real input value replaces it.
TypedAttr infAttr;
if (auto intType = dyn_cast<IntegerType>(inputElementType)) {
APInt fillVal;
if (largest) {
fillVal = APInt::getSignedMinValue(intType.getWidth());
} else {
fillVal = APInt::getSignedMaxValue(intType.getWidth());
}
infAttr = rewriter.getIntegerAttr(intType, fillVal);
} else {
auto fillVal = APFloat::getInf(
cast<mlir::FloatType>(inputElementType).getFloatSemantics(),
/*Negative=*/largest);
infAttr = rewriter.getFloatAttr(inputElementType, fillVal);
}
Value inf = rewriter.create<arith::ConstantOp>(loc, infAttr);
Value infTensor =
rewriter.create<linalg::FillOp>(loc, inf, emptyTensorOutputValues)
.result();

Value emptyTensorOutputIndices = rewriter.create<mlir::tensor::EmptyOp>(
loc, outputDimSizes, indicesType.getElementType());

SmallVector<Value> tmTensorTopkOutputs(
{infTensor, emptyTensorOutputIndices});

SmallVector<Type> tmTensorTopkElementTypes(
{inputElementType, inputElementType});

FailureOr<SmallVector<Value>> tmTensorTopkResults;
{
OpBuilder::InsertionGuard guard(rewriter);
tmTensorTopkResults = createTMTensorTopkOp(
rewriter, loc, tmTensorTopkInputs, tmTensorTopkOutputs,
tmTensorTopkElementTypes, dim, /*isMinK=*/!largest);
}
if (failed(tmTensorTopkResults))
return tmTensorTopkResults;

rewriter.replaceOp(op, tmTensorTopkResults.value());

return success();
}
};

} // namespace

namespace {
class ConvertAtenKthvalueOp : public OpConversionPattern<AtenKthvalueOp> {
public:
@@ -2513,6 +2623,8 @@ class ConvertTorchToTMTensor
target.addIllegalOp<AtenScatterAddOp>();
patterns.add<ConvertAtenScatterOp<AtenScatterAddOp>>(typeConverter,
context);
target.addIllegalOp<AtenTopkOp>();
patterns.add<ConvertAtenTopkOp>(typeConverter, context);
target.addIllegalOp<AtenKthvalueOp>();
patterns.add<ConvertAtenKthvalueOp>(typeConverter, context);

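The comparator region built by the shared createTMTensorTopkOp helper is pinned down by the FileCheck tests below; as a sketch, for a float input the generated body reduces to the following, with isMinK = !largest selecting the comparison direction:

^bb0(%lhs: f32, %rhs: f32):
  // largest = true  -> isMinK = false -> arith.cmpf ogt (keep greater values)
  // largest = false -> isMinK = true  -> arith.cmpf olt (keep smaller values)
  %swap = arith.cmpf ogt, %lhs, %rhs : f32
  tm_tensor.yield %swap : i1
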
65 changes: 65 additions & 0 deletions test/Conversion/TorchToTMTensor/topk.mlir
@@ -0,0 +1,65 @@
// RUN: torch-mlir-opt %s \
// RUN: --convert-torch-to-tmtensor --canonicalize --split-input-file \
// RUN: | FileCheck %s

// CHECK-LABEL: func.func @test_topk_static_dims
// CHECK-SAME: %[[t:.*]]: !torch.vtensor<[2,5],f32>)
// CHECK-DAG: %[[positive_inf:.*]] = arith.constant 0x7F800000 : f32
// CHECK-DAG: %[[t_as_tensor:.*]] = torch_c.to_builtin_tensor %[[t]] : !torch.vtensor<[2,5],f32> -> tensor<2x5xf32>
// CHECK: %[[empty_as_t:.*]] = tensor.empty() : tensor<2x3xf32>
// CHECK: %[[pos_inf_like_t:.*]] = linalg.fill ins(%[[positive_inf]] : f32)
// CHECK-SAME: outs(%[[empty_as_t]] : tensor<2x3xf32>) -> tensor<2x3xf32>
// CHECK: %[[empty_indices:.*]] = tensor.empty() : tensor<2x3xi64>
// CHECK: %[[topk_result:.*]]:2 = tm_tensor.topk dimension(1) ins(%[[t_as_tensor]] : tensor<2x5xf32>)
// CHECK-SAME: outs(%[[pos_inf_like_t]], %[[empty_indices]] : tensor<2x3xf32>, tensor<2x3xi64>) {
// CHECK: ^bb0(%[[lhs:.*]]: f32, %[[rhs:.*]]: f32):
// CHECK: %[[cmpf_res:.*]] = arith.cmpf olt, %[[lhs]], %[[rhs]] : f32
// CHECK: tm_tensor.yield %[[cmpf_res]] : i1
// CHECK: } -> tensor<2x3xf32>, tensor<2x3xi64>
// CHECK-DAG: %[[indices:.*]] = torch_c.from_builtin_tensor %[[topk_result]]#1 : tensor<2x3xi64> -> !torch.vtensor<[2,3],si64>
// CHECK-DAG: %[[values:.*]] = torch_c.from_builtin_tensor %[[topk_result]]#0 : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32>
// CHECK: return %[[values]], %[[indices]] : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],si64>
func.func @test_topk_static_dims(%t: !torch.vtensor<[2,5],f32>) ->
(!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],si64>) {
%k = torch.constant.int 3
%dim = torch.constant.int 1
%largest = torch.constant.bool false
%sorted = torch.constant.bool false
%values, %indices = torch.aten.topk %t, %k, %dim, %largest, %sorted :
!torch.vtensor<[2,5],f32>, !torch.int, !torch.int, !torch.bool, !torch.bool ->
!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],si64>
return %values, %indices : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],si64>
}

// -----

// CHECK-LABEL: func.func @test_topk_with_dynamic_dim_and_k
// CHECK-SAME: %[[t:.*]]: !torch.vtensor<[?,8,32],f32>,
// CHECK-SAME: %[[k:.*]]: !torch.int
// CHECK-DAG: %[[negative_inf:.*]] = arith.constant 0xFF800000 : f32
// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
// CHECK: %[[t_as_tensor:.*]] = torch_c.to_builtin_tensor %[[t]] : !torch.vtensor<[?,8,32],f32> -> tensor<?x8x32xf32>
// CHECK: %[[dim_0_size:.*]] = tensor.dim %[[t_as_tensor]], %[[c0]] : tensor<?x8x32xf32>
// CHECK: %[[k_as_i64:.*]] = torch_c.to_i64 %[[k]]
// CHECK: %[[k_as_index:.*]] = arith.index_cast %[[k_as_i64]] : i64 to index
// CHECK: %[[empty_like_t:.*]] = tensor.empty(%[[dim_0_size]], %[[k_as_index]]) : tensor<?x8x?xf32>
// CHECK: %[[neg_inf_like_t:.*]] = linalg.fill ins(%[[negative_inf]] : f32) outs(%[[empty_like_t]] : tensor<?x8x?xf32>) -> tensor<?x8x?xf32>
// CHECK: %[[empty_indices:.*]] = tensor.empty(%[[dim_0_size]], %[[k_as_index]]) : tensor<?x8x?xi64>
// CHECK: %[[topk_result:.*]]:2 = tm_tensor.topk dimension(2) ins(%[[t_as_tensor]] : tensor<?x8x32xf32>) outs(%[[neg_inf_like_t]], %[[empty_indices]] : tensor<?x8x?xf32>, tensor<?x8x?xi64>) {
// CHECK: ^bb0(%[[lhs:.*]]: f32, %[[rhs:.*]]: f32):
// CHECK: %[[cmp_res:.*]] = arith.cmpf ogt, %[[lhs]], %[[rhs]] : f32
// CHECK: tm_tensor.yield %[[cmp_res]] : i1
// CHECK: } -> tensor<?x8x?xf32>, tensor<?x8x?xi64>
// CHECK-DAG: %[[indices:.*]] = torch_c.from_builtin_tensor %[[topk_result]]#1 : tensor<?x8x?xi64> -> !torch.vtensor<[?,8,?],si64>
// CHECK-DAG: %[[values:.*]] = torch_c.from_builtin_tensor %[[topk_result]]#0 : tensor<?x8x?xf32> -> !torch.vtensor<[?,8,?],f32>
// CHECK: return %[[values]], %[[indices]] : !torch.vtensor<[?,8,?],f32>, !torch.vtensor<[?,8,?],si64>
func.func @test_topk_with_dynamic_dim_and_k(%t: !torch.vtensor<[?,8,32],f32>, %k: !torch.int) ->
(!torch.vtensor<[?,8,?],f32>, !torch.vtensor<[?,8,?],si64>) {
%dim = torch.constant.int -1
%largest = torch.constant.bool true
%sorted = torch.constant.bool false
%values, %indices = torch.aten.topk %t, %k, %dim, %largest, %sorted :
!torch.vtensor<[?,8,32],f32>, !torch.int, !torch.int, !torch.bool, !torch.bool ->
!torch.vtensor<[?,8,?],f32>, !torch.vtensor<[?,8,?],si64>
return %values, %indices : !torch.vtensor<[?,8,?],f32>, !torch.vtensor<[?,8,?],si64>
}