Skip to content

Conversation

@caiohamamura
Copy link
Contributor

A very common operation we need to perform when using tensors is the matrix multiplication. This is just a generic cast of %*% operator to torch_matmul.

#' @export
`%*%.torch_tensor` <- function(e1, e2) {
  if (!is_torch_tensor(e2)) {
    e2 <- torch_tensor(e2, device = e1$device)
  }
  torch_matmul(e1, e2)
}

@dfalbel
Copy link
Member

dfalbel commented Nov 11, 2025

Muito obrigado @caiohamamura !

We probably want something like @rawNamespace if (getRversion() >= "4.3.0") S3method("%*%", torch_tensor) instead of #' @export to support older R versions.

@caiohamamura
Copy link
Contributor Author

caiohamamura commented Nov 12, 2025

Do you want me to make that change? Then maybe that should also be done for every method then? Because all other operators are using only #' @export too.

@dfalbel
Copy link
Member

dfalbel commented Nov 12, 2025

I can do it! I don't think we need for all methods. %*% is a special case because it has been recenttly introduced as an S3 generic in R 4.3.

https://cran.r-project.org/doc/manuals/r-release/NEWS.html

The matrix multiply operator %*% is now an S3 generic, belonging to new group generic matrixOps. From Tomasz Kalinowski's contribution in PR#18483.

@dfalbel dfalbel merged commit 37d5114 into mlverse:main Nov 12, 2025
5 of 6 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants