4 changes: 2 additions & 2 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -1548,8 +1548,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
         auto dty = dataTy.getDtype();
         Value scalar;
         if (FloatType fpTy = dyn_cast<FloatType>(dty)) {
-          auto inf =
-              APFloat::getInf(fpTy.getFloatSemantics(), /*Negative=*/true);
+          auto inf = APFloat::getLargest(fpTy.getFloatSemantics(),
+                                         /*Negative=*/true);
           scalar = rewriter.create<Torch::ConstantFloatOp>(
               binder.getLoc(), rewriter.getType<Torch::FloatType>(),
               rewriter.getFloatAttr(rewriter.getF64Type(),
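The same substitution recurs in every file below: the initial value for max (and min) reductions and pooling switches from an IEEE infinity to the largest finite value of the element type. As a quick illustration of what the two LLVM APFloat factory calls produce for f32 semantics, here is a small standalone sketch; it is not part of the patch and only assumes LLVM's APFloat API:

```cpp
// Standalone illustration, not part of this PR: compare the old and new
// initial values for an f32 max-reduction/pooling identity.
#include "llvm/ADT/APFloat.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  const llvm::fltSemantics &sem = llvm::APFloat::IEEEsingle();

  // Old value: negative infinity, the constant the old CHECK lines
  // spelled as 0xFF800000 (f32 -inf).
  llvm::APFloat negInf = llvm::APFloat::getInf(sem, /*Negative=*/true);

  // New value: the most negative finite f32, i.e. -FLT_MAX
  // (-3.40282347E+38), the constant the updated CHECK lines expect.
  llvm::APFloat negLargest = llvm::APFloat::getLargest(sem, /*Negative=*/true);

  llvm::outs() << "getInf:     " << negInf.convertToFloat() << "\n"
               << "getLargest: " << negLargest.convertToFloat() << "\n";
  return 0;
}
```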
14 changes: 7 additions & 7 deletions lib/Conversion/TorchToLinalg/Pooling.cpp
@@ -441,9 +441,9 @@ class ConvertAtenMaxPoolOp : public OpConversionPattern<OpTy> {
   Value self = adaptor.getSelf();
   Type elementType = cast<RankedTensorType>(self.getType()).getElementType();
   TypedAttr smallestFPValueAttr = rewriter.getFloatAttr(
-      elementType,
-      APFloat::getInf(cast<mlir::FloatType>(elementType).getFloatSemantics(),
-                      /*Negative=*/true));
+      elementType, APFloat::getLargest(
+                       cast<mlir::FloatType>(elementType).getFloatSemantics(),
+                       /*Negative=*/true));
   Value initValue =
       rewriter.create<arith::ConstantOp>(op->getLoc(), smallestFPValueAttr);

@@ -693,7 +693,7 @@ class ConvertAtenMaxPoolOp : public OpConversionPattern<OpTy> {
     if (auto fpty = dyn_cast<mlir::FloatType>(elementType)) {
       smallestValueAttr = rewriter.getFloatAttr(
           elementType,
-          APFloat::getInf(fpty.getFloatSemantics(), /*Negative=*/true));
+          APFloat::getLargest(fpty.getFloatSemantics(), /*Negative=*/true));
     } else if (auto intTy = dyn_cast<mlir::IntegerType>(elementType)) {
       int64_t bw = intTy.getIntOrFloatBitWidth();
       smallestValueAttr = rewriter.getIntegerAttr(
@@ -1379,9 +1379,9 @@ class AdaptiveMaxPoolingHelper : public AdaptivePoolingHelper {
         typeConverter->convertType(op.getResult1().getType()));
     Type auxTensorElementType = auxTensorType.getElementType();
     auto smallestFPValueAttr = rewriter.getFloatAttr(
-        elementType,
-        APFloat::getInf(cast<mlir::FloatType>(elementType).getFloatSemantics(),
-                        /*Negative=*/true));
+        elementType, APFloat::getLargest(
+                         cast<mlir::FloatType>(elementType).getFloatSemantics(),
+                         /*Negative=*/true));
     buffVal = rewriter.create<arith::ConstantOp>(loc, elementType,
                                                  smallestFPValueAttr);
     auxTensor = rewriter.create<tensor::EmptyOp>(
6 changes: 3 additions & 3 deletions lib/Conversion/TorchToLinalg/Reduction.cpp
@@ -117,7 +117,7 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern<OpTy> {
       fillValue = rewriter.create<arith::ConstantOp>(
           loc, rewriter.getFloatAttr(
                    inElementType,
-                   APFloat::getInf(
+                   APFloat::getLargest(
                        cast<mlir::FloatType>(inElementType).getFloatSemantics(),
                        /*Negative=*/isMax)));
     } else if (!isUnsigned) {

Comment on lines +120 to 121

Contributor:
It'll be good to add testpoints for torch.aten.min.dim and torch.aten.max.dim to basic.mlir to lock this down.

Also, does TorchToTosa not support this and pooling with padding?

Collaborator:
By testpoints do you mean e2e tests or lit tests?
@@ -302,7 +302,7 @@ static Value createInitElementForReduceOp(OpBuilder &b, Location loc,
     return b.create<arith::ConstantOp>(
         loc, b.getFloatAttr(
                  elementType,
-                 APFloat::getInf(
+                 APFloat::getLargest(
                      cast<mlir::FloatType>(elementType).getFloatSemantics(),
                      /*Negative=*/true)));
   else if (isa<mlir::IntegerType>(elementType) &&
@@ -318,7 +318,7 @@ static Value createInitElementForReduceOp(OpBuilder &b, Location loc,
     return b.create<arith::ConstantOp>(
         loc, b.getFloatAttr(
                  elementType,
-                 APFloat::getInf(
+                 APFloat::getLargest(
                      cast<mlir::FloatType>(elementType).getFloatSemantics(),
                      /*Negative=*/false)));
   else if (isa<mlir::IntegerType>(elementType) &&
12 changes: 6 additions & 6 deletions lib/Conversion/TorchToStablehlo/Reduction.cpp
@@ -62,9 +62,9 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
   if (isa<AtenAmaxOp, AtenMaxOp, AtenMaxDimOp, AtenArgmaxOp>(op)) {
     if (isa<mlir::FloatType>(elementTy)) {
       constAttr = DenseElementsAttr::get(
-          constType,
-          {APFloat::getInf(cast<mlir::FloatType>(elementTy).getFloatSemantics(),
-                           /*negative=*/true)});
+          constType, {APFloat::getLargest(
+                         cast<mlir::FloatType>(elementTy).getFloatSemantics(),
+                         /*negative=*/true)});
     } else if (isa<mlir::IntegerType>(elementTy)) {
       constAttr = DenseElementsAttr::get(
           constType,
@@ -75,9 +75,9 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
   if (isa<AtenAminOp, AtenMinOp, AtenMinDimOp, AtenArgminOp>(op)) {
     if (isa<mlir::FloatType>(elementTy)) {
       constAttr = DenseElementsAttr::get(
-          constType,
-          {APFloat::getInf(cast<mlir::FloatType>(elementTy).getFloatSemantics(),
-                           /*negative=*/false)});
+          constType, {APFloat::getLargest(
+                         cast<mlir::FloatType>(elementTy).getFloatSemantics(),
+                         /*negative=*/false)});
     } else if (isa<mlir::IntegerType>(elementTy)) {
       constAttr = DenseElementsAttr::get(
           constType,
4 changes: 2 additions & 2 deletions lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
@@ -2072,15 +2072,15 @@ class ConvertAtenKthvalueOp : public OpConversionPattern<AtenKthvalueOp> {
           loc,
           rewriter.getFloatAttr(
               inputElementType,
-              APFloat::getInf(
+              APFloat::getLargest(
                   cast<mlir::FloatType>(inputElementType).getFloatSemantics(),
                   /*Negative=*/false)));
       // min float for linalg generic op tensor
       fillValLinalgFindMax = rewriter.create<arith::ConstantOp>(
           loc,
           rewriter.getFloatAttr(
               inputElementType,
-              APFloat::getInf(
+              APFloat::getLargest(
                   cast<mlir::FloatType>(inputElementType).getFloatSemantics(),
                   /*Negative=*/true)));
     } else if (!isUnsigned) {
4 changes: 2 additions & 2 deletions test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
@@ -805,13 +805,13 @@ func.func @test_selu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,

 // CHECK-LABEL: func.func @test_reduce_max_empty_set_fp
 func.func @test_reduce_max_empty_set_fp(%arg0: !torch.vtensor<[2,0,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  // CHECK-DAG: %[[INF:.+]] = torch.constant.float 0xFFF0000000000000
+  // CHECK-DAG: %[[NEGMAX:.+]] = torch.constant.float -3.4028234663852886E+38
   // CHECK-DAG: %[[INT2:.+]] = torch.constant.int 2
   // CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1
   // CHECK-DAG: %[[INT4:.+]] = torch.constant.int 4
   // CHECK-DAG: %[[NONE:.+]] = torch.constant.none
   // CHECK-DAG: %[[LIST:.+]] = torch.prim.ListConstruct %[[INT2]], %[[INT1]], %[[INT4]]
-  // CHECK-DAG: %[[FULL:.+]] = torch.aten.full %[[LIST]], %[[INF]], %[[NONE]], %[[NONE]], %[[NONE]]
+  // CHECK-DAG: %[[FULL:.+]] = torch.aten.full %[[LIST]], %[[NEGMAX]], %[[NONE]], %[[NONE]], %[[NONE]]
   // CHECK: return %[[FULL]]
   %0 = torch.operator "onnx.ReduceMax"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[2,0,4],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],f32>
   return %0 : !torch.vtensor<[2,1,4],f32>
6 changes: 3 additions & 3 deletions test/Conversion/TorchToLinalg/pooling.mlir
@@ -7,7 +7,7 @@ func.func @forward_max_pool1d(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vten
   %int3 = torch.constant.int 3
   %int4 = torch.constant.int 4
   %false = torch.constant.bool false
-  // CHECK: %[[NEUTRAL:.*]] = arith.constant 0xFF800000 : f32
+  // CHECK: %[[NEUTRAL:.*]] = arith.constant -3.40282347E+38 : f32
   // CHECK: %[[PADDED:.*]] = tensor.pad %{{.*}} low[0, 0, 3] high[0, 0, 3]
   // CHECK: %[[OUT:.*]] = linalg.fill ins(%[[NEUTRAL]] : f32) outs(%{{.*}} : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
   // CHECK: %[[INIT:.*]] = tensor.empty() : tensor<1xf32>
@@ -33,7 +33,7 @@ func.func @forward_max_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vt
   %int7 = torch.constant.int 7
   %int8 = torch.constant.int 8
   %false = torch.constant.bool false
-  // CHECK: %[[NEUTRAL:.*]] = arith.constant 0xFF800000 : f32
+  // CHECK: %[[NEUTRAL:.*]] = arith.constant -3.40282347E+38 : f32
   // CHECK: %[[PADDED:.*]] = tensor.pad %{{.*}} low[0, 0, 5, 6] high[0, 0, 5, 6]
   // CHECK: %[[OUT:.*]] = linalg.fill ins(%[[NEUTRAL]] : f32) outs(%{{.*}} : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
   // CHECK: %[[INIT:.*]] = tensor.empty() : tensor<1x2xf32>
@@ -106,7 +106,7 @@ func.func @forward_max_pool3d(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> !torch.

   %4 = torch.aten.max_pool3d %arg0, %kernel_size, %stride, %padding, %dilation, %false : !torch.vtensor<[?,?,?,?,?],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[?,?,?,?,?],f32>

-  // CHECK: %[[MIN_VALUE:.*]] = arith.constant 0xFF800000 : f32
+  // CHECK: %[[MIN_VALUE:.*]] = arith.constant -3.40282347E+38 : f32
   // CHECK: %[[PADDED_INPUT_TENSOR:.*]] = tensor.pad %{{.*}} low[0, 0, 4, 4, 4] high[0, 0, 4, 4, 4] {
   // CHECK-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index):
   // CHECK-NEXT: tensor.yield %[[MIN_VALUE:.*]] : f32
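A side note on why the two test files spell the same constant differently: the ONNX lowering above builds the fill value as an f64 FloatAttr (rewriter.getF64Type()), so -FLT_MAX appears in simple_ops_q_to_z.mlir with double precision (-3.4028234663852886E+38), while the linalg pooling tests keep the attribute at f32 and print -3.40282347E+38. A minimal sketch of that widening, again not part of the patch and only assuming LLVM's APFloat API:

```cpp
// Standalone sketch, not part of this PR: widen the negative-largest f32
// value to f64, mirroring the rewriter.getF64Type() FloatAttr used in the
// TorchOnnxToTorch lowering.
#include "llvm/ADT/APFloat.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  llvm::APFloat v =
      llvm::APFloat::getLargest(llvm::APFloat::IEEEsingle(), /*Negative=*/true);

  // As f32 this is the -3.40282347E+38 value expected by pooling.mlir.
  float asF32 = v.convertToFloat();

  // Converting to IEEE double is exact; the same number then carries the
  // extra digits seen in simple_ops_q_to_z.mlir (-3.4028234663852886E+38).
  bool losesInfo = false;
  v.convert(llvm::APFloat::IEEEdouble(), llvm::APFloat::rmNearestTiesToEven,
            &losesInfo);
  double asF64 = v.convertToDouble();

  llvm::outs() << asF32 << " -> " << asF64
               << (losesInfo ? " (lossy)" : " (exact)") << "\n";
  return 0;
}
```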