diff --git a/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td b/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td
index c47eaabf7364..39ee4cfc1c59 100644
--- a/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td
+++ b/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td
@@ -355,19 +355,33 @@ def TMTensor_TopkOp : TMTensor_Op<"topk",
     A Top-K operation for N-D tensors. Reduces the target dimension from the
     input size N down to K elements based on the supplied binary region.
 
-    Accepts an N-D tensor input consisting of values and an optioanl N-D tensor
+    Accepts an N-D tensor input consisting of values and an optional N-D tensor
     for indices of those values (i32 type). If input indices aren't provided,
     the index mapping is inferred based on the k dim. Both input values/indices
-    tensors and output values/indicies tensors must have the same shape. Top-K is
+    tensors and output values/indices tensors must have the same shape. Top-K is
     computed along the target dimension (from dimension()). Returns two output
-    tensors of values and the indicies of Top-K results. The output dimensions
-    must match the input save for the dimension that is reduced to K results.
+    tensors of values and the indices of Top-K results. The output dimensions
+    must match the input except for the dimension that is reduced to K results.
 
-    Region accepts lhs=[next N input] and rhs=[exiting K output] and yeilds an
+    Region accepts lhs=[next N input] and rhs=[existing K output] and yields an
     i1. If true, the two values are swapped:
-      - For Top-K compoarision: >
-      - For Min-K comparision: <
-    Note: when the two values are equal, the first occurence is always selected.
+      - For Top-K comparison: >
+      - For Min-K comparison: <
+    Note: when the two values are equal, the first occurrence is always selected.
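+
+    Example (illustrative; mirrors the IR produced by the Torch-to-TMTensor
+    lowering added in this patch): a Top-3 selection along dimension 1 of a
+    2x5 input, keeping the largest values by yielding `>` in the region:
+
+    ```mlir
+    %0:2 = tm_tensor.topk dimension(1)
+        ins(%input : tensor<2x5xf32>)
+        outs(%out_values, %out_indices : tensor<2x3xf32>, tensor<2x3xi64>) {
+    ^bb0(%lhs: f32, %rhs: f32):
+      %cmp = arith.cmpf ogt, %lhs, %rhs : f32
+      tm_tensor.yield %cmp : i1
+    } -> tensor<2x3xf32>, tensor<2x3xi64>
+    ```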
 }];
 
   let arguments = (ins Variadic<AnyShaped>:$inputs,
diff --git a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
index 2c091b7e58b9..7940a82dd046 100644
--- a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
+++ b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
@@ -2008,6 +2008,119 @@ class ConvertAtenScaledDotProductAttentionOp
 };
 } // namespace
 
+namespace {
+class ConvertAtenTopkOp : public OpConversionPattern<AtenTopkOp> {
+public:
+  using OpConversionPattern<AtenTopkOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenTopkOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value input = adaptor.getSelf();
+    auto inputType = cast<RankedTensorType>(input.getType());
+    unsigned inputRank = inputType.getRank();
+    Type inputElementType = inputType.getElementType();
+
+    auto indicesType = cast<RankedTensorType>(
+        getTypeConverter()->convertType(op.getIndices().getType()));
+
+    // Get dim and check that it is a constant int.
+    int64_t dim;
+    if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: only constant dim value is supported");
+
+    // Make dim positive if negative, and check that it is in the valid range.
+    dim = toPositiveDim(dim, inputRank);
+    if (!isValidDim(dim, inputRank)) {
+      return rewriter.notifyMatchFailure(op, "dim is statically invalid");
+    }
+
+    bool largest;
+    if (!matchPattern(op.getLargest(), m_TorchConstantBool(&largest)))
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: only constant largest value is supported");
+
+    bool sorted;
+    if (!matchPattern(op.getSorted(), m_TorchConstantBool(&sorted)))
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: only constant sorted value is supported");
+    if (sorted)
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: only unsorted topk is supported.");
+
+    SmallVector<Value> tmTensorTopkInputs({input});
+
+    SmallVector<OpFoldResult> outputDimSizes =
+        tensor::getMixedSizes(rewriter, loc, adaptor.getSelf());
+    int64_t k;
+    // TODO: why doesn't a constant k fold? We should not need to deal with
+    // this here.
+    if (matchPattern(op.getK(), m_TorchConstantInt(&k))) {
+      outputDimSizes[dim] = rewriter.getI64IntegerAttr(k);
+    } else {
+      Value kVal = typeConverter->materializeTargetConversion(
+          rewriter, loc, typeConverter->convertType(op.getK().getType()),
+          op.getK());
+      Value kIdx = rewriter.create<arith::IndexCastOp>(
+          loc, rewriter.getIndexType(), kVal);
+      outputDimSizes[dim] = kIdx;
+    }
+
+    Value emptyTensorOutputValues = rewriter.create<tensor::EmptyOp>(
+        loc, outputDimSizes, inputElementType);
+    // Fill the initial output values tensor based on largest
+    // (ascending or descending selection).
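+    // For Top-K (largest) the sentinel is -inf (INT_MIN for integer types),
+    // so any real input value displaces it; for Min-K it is +inf (INT_MAX),
+    // for the symmetric reason.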
+    TypedAttr infAttr;
+    if (auto intType = dyn_cast<mlir::IntegerType>(inputElementType)) {
+      APInt fillVal;
+      if (largest) {
+        fillVal = APInt::getSignedMinValue(intType.getWidth());
+      } else {
+        fillVal = APInt::getSignedMaxValue(intType.getWidth());
+      }
+      infAttr = rewriter.getIntegerAttr(intType, fillVal);
+    } else {
+      auto fillVal = APFloat::getInf(
+          cast<mlir::FloatType>(inputElementType).getFloatSemantics(),
+          /*Negative=*/largest);
+      infAttr = rewriter.getFloatAttr(inputElementType, fillVal);
+    }
+    Value inf = rewriter.create<arith::ConstantOp>(loc, infAttr);
+    Value infTensor =
+        rewriter.create<linalg::FillOp>(loc, inf, emptyTensorOutputValues)
+            .result();
+
+    Value emptyTensorOutputIndices = rewriter.create<tensor::EmptyOp>(
+        loc, outputDimSizes, indicesType.getElementType());
+
+    SmallVector<Value> tmTensorTopkOutputs(
+        {infTensor, emptyTensorOutputIndices});
+
+    SmallVector<Type> tmTensorTopkElementTypes(
+        {inputElementType, inputElementType});
+
+    FailureOr<SmallVector<Value>> tmTensorTopkResults;
+    {
+      OpBuilder::InsertionGuard guard(rewriter);
+      tmTensorTopkResults = createTMTensorTopkOp(
+          rewriter, loc, tmTensorTopkInputs, tmTensorTopkOutputs,
+          tmTensorTopkElementTypes, dim, /*isMinK=*/!largest);
+    }
+    if (failed(tmTensorTopkResults))
+      return tmTensorTopkResults;
+
+    rewriter.replaceOp(op, tmTensorTopkResults.value());
+
+    return success();
+  }
+};
+
+} // namespace
+
 namespace {
 class ConvertAtenKthvalueOp : public OpConversionPattern<AtenKthvalueOp> {
 public:
@@ -2513,6 +2626,8 @@ class ConvertTorchToTMTensor
     target.addIllegalOp<AtenScaledDotProductAttentionOp>();
     patterns.add<ConvertAtenScaledDotProductAttentionOp>(typeConverter,
                                                          context);
+    target.addIllegalOp<AtenTopkOp>();
+    patterns.add<ConvertAtenTopkOp>(typeConverter, context);
     target.addIllegalOp<AtenKthvalueOp>();
     patterns.add<ConvertAtenKthvalueOp>(typeConverter, context);
diff --git a/test/Conversion/TorchToTMTensor/topk.mlir b/test/Conversion/TorchToTMTensor/topk.mlir
new file mode 100644
index 000000000000..319650c40c16
--- /dev/null
+++ b/test/Conversion/TorchToTMTensor/topk.mlir
@@ -0,0 +1,68 @@
+// RUN: torch-mlir-opt %s \
+// RUN: --convert-torch-to-tmtensor --canonicalize --split-input-file \
+// RUN: | FileCheck %s
+
+// CHECK-LABEL: func.func @test_topk_static_dims
+// CHECK-SAME: %[[t:.*]]: !torch.vtensor<[2,5],f32>)
+// CHECK-DAG: %[[positive_inf:.*]] = arith.constant 0x7F800000 : f32
+// CHECK-DAG: %[[t_as_tensor:.*]] = torch_c.to_builtin_tensor %[[t]] : !torch.vtensor<[2,5],f32> -> tensor<2x5xf32>
+// CHECK: %[[empty_as_t:.*]] = tensor.empty() : tensor<2x3xf32>
+// CHECK: %[[pos_inf_like_t:.*]] = linalg.fill ins(%[[positive_inf]] : f32)
+// CHECK-SAME: outs(%[[empty_as_t]] : tensor<2x3xf32>) -> tensor<2x3xf32>
+// CHECK: %[[empty_indices:.*]] = tensor.empty() : tensor<2x3xi64>
+// CHECK: %[[topk_result:.*]]:2 = tm_tensor.topk dimension(1) ins(%[[t_as_tensor]] : tensor<2x5xf32>)
+// CHECK-SAME: outs(%[[pos_inf_like_t]], %[[empty_indices]] : tensor<2x3xf32>, tensor<2x3xi64>) {
+// CHECK: ^bb0(%[[lhs:.*]]: f32, %[[rhs:.*]]: f32):
+// CHECK: %[[cmpf_res:.*]] = arith.cmpf olt, %[[lhs]], %[[rhs]] : f32
+// CHECK: tm_tensor.yield %[[cmpf_res]] : i1
+// CHECK: } -> tensor<2x3xf32>, tensor<2x3xi64>
+// CHECK-DAG: %[[indices:.*]] = torch_c.from_builtin_tensor %[[topk_result]]#1 : tensor<2x3xi64> -> !torch.vtensor<[2,3],si64>
+// CHECK-DAG: %[[values:.*]] = torch_c.from_builtin_tensor %[[topk_result]]#0 : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32>
+// CHECK: return %[[values]], %[[indices]] : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],si64>
+func.func @test_topk_static_dims(%t: !torch.vtensor<[2,5],f32>) ->
+    (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],si64>) {
+  %k = torch.constant.int 3
+  %dim = torch.constant.int 1
+  %largest = torch.constant.bool false
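+  // With %largest == false this lowers to a min-k selection: the values
+  // output is seeded with +inf and the region compares with arith.cmpf olt.
+  // Note that %sorted == true is rejected by the conversion as unimplemented.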
+  %sorted = torch.constant.bool false
+  %values, %indices = torch.aten.topk %t, %k, %dim, %largest, %sorted :
+      !torch.vtensor<[2,5],f32>, !torch.int, !torch.int, !torch.bool, !torch.bool ->
+      !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],si64>
+  return %values, %indices : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],si64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_topk_with_dynamic_dim_and_k
+// CHECK-SAME: %[[t:.*]]: !torch.vtensor<[?,8,32],f32>,
+// CHECK-SAME: %[[k:.*]]: !torch.int
+// CHECK-DAG: %[[negative_inf:.*]] = arith.constant 0xFF800000 : f32
+// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+// CHECK: %[[t_as_tensor:.*]] = torch_c.to_builtin_tensor %[[t]] : !torch.vtensor<[?,8,32],f32> -> tensor<?x8x32xf32>
+// CHECK: %[[dim_0_size:.*]] = tensor.dim %[[t_as_tensor]], %[[c0]] : tensor<?x8x32xf32>
+// CHECK: %[[k_as_i64:.*]] = torch_c.to_i64 %[[k]]
+// CHECK: %[[k_as_index:.*]] = arith.index_cast %[[k_as_i64]] : i64 to index
+// CHECK: %[[empty_like_t:.*]] = tensor.empty(%[[dim_0_size]], %[[k_as_index]]) : tensor<?x8x?xf32>
+// CHECK: %[[neg_inf_like_t:.*]] = linalg.fill ins(%[[negative_inf]] : f32) outs(%[[empty_like_t]] : tensor<?x8x?xf32>) -> tensor<?x8x?xf32>
+// CHECK: %[[empty_indices:.*]] = tensor.empty(%[[dim_0_size]], %[[k_as_index]]) : tensor<?x8x?xi64>
+// CHECK: %[[topk_result:.*]]:2 = tm_tensor.topk dimension(2) ins(%[[t_as_tensor]] : tensor<?x8x32xf32>) outs(%[[neg_inf_like_t]], %[[empty_indices]] : tensor<?x8x?xf32>, tensor<?x8x?xi64>) {
+// CHECK: ^bb0(%[[lhs:.*]]: f32, %[[rhs:.*]]: f32):
+// CHECK: %[[cmp_res:.*]] = arith.cmpf ogt, %[[lhs]], %[[rhs]] : f32
+// CHECK: tm_tensor.yield %[[cmp_res]] : i1
+// CHECK: } -> tensor<?x8x?xf32>, tensor<?x8x?xi64>
+// CHECK-DAG: %[[indices:.*]] = torch_c.from_builtin_tensor %[[topk_result]]#1 : tensor<?x8x?xi64> -> !torch.vtensor<[?,8,?],si64>
+// CHECK-DAG: %[[values:.*]] = torch_c.from_builtin_tensor %[[topk_result]]#0 : tensor<?x8x?xf32> -> !torch.vtensor<[?,8,?],f32>
+// CHECK: return %[[values]], %[[indices]] : !torch.vtensor<[?,8,?],f32>, !torch.vtensor<[?,8,?],si64>
+func.func @test_topk_with_dynamic_dim_and_k(%t: !torch.vtensor<[?,8,32],f32>, %k: !torch.int) ->
+    (!torch.vtensor<[?,8,?],f32>, !torch.vtensor<[?,8,?],si64>) {
+  %dim = torch.constant.int -1
+  %largest = torch.constant.bool true
+  %sorted = torch.constant.bool false
+  %values, %indices = torch.aten.topk %t, %k, %dim, %largest, %sorted :
+      !torch.vtensor<[?,8,32],f32>, !torch.int, !torch.int, !torch.bool, !torch.bool ->
+      !torch.vtensor<[?,8,?],f32>, !torch.vtensor<[?,8,?],si64>
+  return %values, %indices : !torch.vtensor<[?,8,?],f32>, !torch.vtensor<[?,8,?],si64>
+}