
[TORCH] Add support for aten.hinge_embedding_loss Op #4227


Open
sharavak wants to merge 7 commits into main from hinge_embedding_loss

Conversation

sharavak
Contributor

  • Decomposed the hinge_embedding_loss op into Aten ops.
  • Added test cases to the e2e suite.

This implementation addresses and closes #4222
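
For context, a small sketch of the op's unreduced semantics (added here for readability; not code from this PR): each element of the result is the input where target == 1 and max(0, margin - input) where target == -1, with the requested reduction (none/mean/sum) applied on top.

import torch

def hinge_embedding_loss_sketch(inp, tgt, margin=1.0):
    # Assumes the target only contains -1 and 1; see the review discussion
    # below for how arbitrary target values end up being handled.
    elementwise = torch.where(tgt == 1, inp, (margin - inp).clamp(min=0))
    return elementwise  # 'none' reduction; .mean()/.sum() give the others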

@sharavak sharavak force-pushed the hinge_embedding_loss branch 2 times, most recently from 9d5b896 to ddf112a on June 10, 2025 14:40
@sharavak sharavak marked this pull request as ready for review June 10, 2025 14:41
@sharavak
Contributor Author

sharavak commented Jun 10, 2025

@stellaraccident , @vivekkhandelwal1, @zjgarvey I’d be grateful if any of you could take a look at this PR. Your feedback would be greatly appreciated!

@vivekkhandelwal1 vivekkhandelwal1 self-requested a review June 16, 2025 12:13
Comment on lines +10519 to +10622
// Compute mask: target != 1
auto targetNotOne =
    rewriter.create<AtenNeScalarOp>(loc, boolType, target, one);
// If target != 1, use marginClamp; otherwise 0.
auto outputMargin = rewriter.create<AtenWhereScalarOtherOp>(
    loc, inputTy, targetNotOne, marginClamp, zero);
// Compute mask: target != -1
auto targetNotMinusOne =
    rewriter.create<AtenNeScalarOp>(loc, boolType, target, minusOne);
// If target != -1, use the original input; otherwise 0.
auto outputSelf = rewriter.create<AtenWhereScalarOtherOp>(
    loc, inputTy, targetNotMinusOne, input, zero);
// Add: outputMargin + outputSelf
auto output = rewriter.create<AtenAddTensorOp>(loc, inputTy, outputMargin,
                                               outputSelf, /*alpha=*/alpha);
Collaborator


Instead of doing all this, you can just do:

auto result = rewriter.create<AtenWhereScalarOtherOp>(
        loc, inputTy, targetNotOne, marginClamp, input);

Contributor Author

@sharavak sharavak Jun 17, 2025


@vivekkhandelwal1 Thanks for the suggestion, I did try the simplified version initially, but it caused numerical validation errors in some test cases. This happens because the target tensor can sometimes have values other than just -1 and 1.

To handle this properly and stay consistent with PyTorch's semantics, I decided to explicitly check for both target == 1 and target == -1. This way, the behavior stays correct even if the target has values other than -1 and 1.

E.g.:

import torch

input = torch.randn(2, 3)
target = torch.randn(2, 3)
torch.hinge_embedding_loss(input, target)

Output:
tensor([[1.1361, 1.0000, 1.0000],
        [1.4880, 1.1624, 1.0000]])
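
To make the difference concrete, here is a small check (an illustration added here, not code from the PR) comparing the suggested single-where form against torch's own result with reduction='none', once the target is not restricted to -1/1:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
inp = torch.randn(2, 3)
tgt = torch.randn(2, 3)  # arbitrary values, not just -1 and 1
margin = 1.0
margin_clamp = (margin - inp).clamp(min=0)

# Single-where form suggested above: marginClamp where target != 1, else input.
simplified = torch.where(tgt != 1, margin_clamp, inp)

reference = F.hinge_embedding_loss(inp, tgt, margin=margin, reduction="none")
print(torch.allclose(simplified, reference))
# False: wherever the target is neither 1 nor -1, the reference adds `input`
# on top of the clamped margin, which the single-where form drops.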

Collaborator


> This happens because the target tensor can sometimes have values other than just -1 and 1.

In what cases, and how? The definition says that the target can contain only -1 and 1.

Contributor Author

@sharavak sharavak Jul 1, 2025


Thanks for the reply @vivekkhandelwal1. I took the reference from the PyTorch native implementation: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/Loss.cpp#L182. While the official definition of hinge_embedding_loss states that the target should contain only -1 and 1, the native implementation doesn't enforce this restriction and handles arbitrary values using at::where(target != 1, ...) and at::where(target != -1, ...). So, to stay consistent with the PyTorch behaviour, I followed the same logic.
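
For reference, a rough Python transcription of that native two-where logic (a sketch added for readability; not part of the PR):

import torch

def hinge_embedding_loss_native_style(inp, tgt, margin=1.0):
    # Mirrors the two at::where calls in Loss.cpp:
    #   target != 1  -> clamp(margin - input, min=0), else 0
    #   target != -1 -> input,                        else 0
    # and their elementwise sum (the unreduced result).
    zeros = torch.zeros_like(inp)
    margin_clamp = (margin - inp).clamp(min=0)
    output_margin = torch.where(tgt != 1, margin_clamp, zeros)
    output_self = torch.where(tgt != -1, inp, zeros)
    return output_margin + output_self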

Collaborator


Okay. In that case, it would be better to add this link and the justification as a comment in the code, so that future users/contributors don't get confused.

@sharavak sharavak force-pushed the hinge_embedding_loss branch from ddf112a to 8d8c30b on June 17, 2025 17:13
@sharavak
Contributor Author

sharavak commented Jun 17, 2025

@vivekkhandelwal1 Thanks a lot for the feedback. I've updated the code.

Comment on lines 2498 to 2541
def HingeEmbeddingLossReductionMeanModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(8, 1), tu.rand(1, 1))


class HingeEmbeddingLossReductionSumModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1], torch.float32, True),
            ([-1, -1], torch.float32, True),
        ]
    )
    def forward(self, input, target):
        return torch.ops.aten.hinge_embedding_loss(input, target, reduction=2)


@register_test_case(module_factory=lambda: HingeEmbeddingLossReductionSumModule())
def HingeEmbeddingLossReductionSumModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 5), tu.rand(1, 1))


class HingeEmbeddingLossReductionNoneModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1], torch.float32, True),
            ([-1], torch.float32, True),
        ]
    )
    def forward(self, input, target):
        return torch.ops.aten.hinge_embedding_loss(input, target, margin=1.0)


@register_test_case(module_factory=lambda: HingeEmbeddingLossReductionNoneModule())
def HingeEmbeddingLossReductionNoneModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(8, 5), tu.rand(1))
Collaborator


All 3 of these tests have only a single value as a target; as a result, not all the paths of the lowering will get tested. Ideally, for testing purposes, the target tensor should contain a mix of -1 and 1 values. Also, how do you make sure that the target tensor contains only valid values?

Contributor Author

@sharavak sharavak Jul 1, 2025


Thanks for the review @vivekkhandelwal1
I will update the target tensor.

> Also, how do you make sure that the target tensor contains only valid values?

I've addressed that in my earlier comment here: #4227 (comment).
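
For illustration, one way a mixed -1/+1 target could be produced for these tests (a hypothetical sketch; the helper used in the final patch may differ), keeping the float32 dtype the annotations expect:

import torch

# Draw 0/1 values, then map them to -1/+1.
target = torch.randint(0, 2, (8, 1)).to(torch.float32) * 2 - 1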


@sharavak sharavak requested a review from vivekkhandelwal1 July 2, 2025 17:14
Successfully merging this pull request may close these issues.

[TORCH] Add support for aten.hinge_embedding_loss