Skip to content

Conversation

@sahas3
Copy link
Member

@sahas3 sahas3 commented Nov 3, 2025

This change replaces usage of non-finite value inf with finite value realmax for init value of various max/min operations -- no change in semantics of the ops.

@sahas3 sahas3 requested a review from zjgarvey November 3, 2025 02:00
@sahas3 sahas3 changed the title [linalg] : Use -realmax instead of -inf for MaxPool init. [linalg] : Use (-)realmax instead of (-)inf to avoid usage of non-finites. Nov 3, 2025
Comment on lines +120 to 121
APFloat::getLargest(
cast<mlir::FloatType>(inElementType).getFloatSemantics(),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It'll be good to add testpoints for torch.aten.min.dim and torch.aten.max.dim test to basic.mlir to lock this down.

Also, does TorchToTosa not support this and pooling with padding?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By testpoints do you mean e2e tests or lit tests?

Copy link
Collaborator

@zjgarvey zjgarvey left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure how much this practically matters, but min(x, inf) is only the same as min(x, realmax) if x is finite, so this change will possibly result in incorrect results when the input tensors have non-finites. E.g., if you had an e2e test for torch.min and the input was a splat torch.inf tensor, then pytorch would return inf and not realmax.

Comment on lines +120 to 121
APFloat::getLargest(
cast<mlir::FloatType>(inElementType).getFloatSemantics(),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By testpoints do you mean e2e tests or lit tests?

@sahas3
Copy link
Member Author

sahas3 commented Nov 6, 2025

I'm not sure how much this practically matters, but min(x, inf) is only the same as min(x, realmax) if x is finite, so this change will possibly result in incorrect results when the input tensors have non-finites. E.g., if you had an e2e test for torch.min and the input was a splat torch.inf tensor, then pytorch would return inf and not realmax.

Great point, I hadn't thought of that.

It appears that without this change, if I pass a tensor of all -inf to torch.MaxPool2d the result in the linalg-on-tensors path will be a tensor of all -inf, but in the tosa path it is -realmax. This is due to https://github.com/llvm/llvm-project/blob/d2f75f2fe3261264bb4692368aab64aaafb30f08/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp#L750
After this change linalg-on-tensors also produces tensor of all -realmax as you've mentioned.
cc-ing @sjarus just to notify this behavior for the tosa path.

It seems wrong to me to knowingly introduce this discrepancy with torch even though the scenario doesn't make sense in practice. One alternative is to introduce a flag, such as SupportNonFinites, through the pass-pipeline (and probably fx.export_and_import) which by default will be True and preserve the current behavior of linalg-on-tensors path. If set to false it'll replace inf with realmax and deviate from torch's behavior. The motivation for us is to honor this behavior https://www.mathworks.com/help/coder/ug/run-time-error-checks-1.html for our product.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants