Skip to content

Commit 37d5114

Browse files
Add matrix multiplication operator for torch_tensor (#1379)
* Add matrix multiplication operator for torch_tensor * Add conditional namespace * add test case --------- Co-authored-by: Daniel Falbel <[email protected]>
1 parent 31701c5 commit 37d5114

File tree

3 files changed

+43
-0
lines changed

3 files changed

+43
-0
lines changed

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -877,6 +877,7 @@ export(with_enable_grad)
877877
export(with_no_grad)
878878
export(with_torch_manual_seed)
879879
export(yield)
880+
if (getRversion() >= "4.3.0") S3method("%*%", torch_tensor)
880881
importFrom(Rcpp,sourceCpp)
881882
importFrom(bit64,as.integer64)
882883
importFrom(callr,r)

R/operators.R

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,14 @@
175175
torch_logical_not(x)
176176
}
177177

178+
#' @rawNamespace if (getRversion() >= "4.3.0") S3method("%*%", torch_tensor)
179+
`%*%.torch_tensor` <- function(e1, e2) {
180+
if (!is_torch_tensor(e2)) {
181+
e2 <- torch_tensor(e2, device = e1$device)
182+
}
183+
torch_matmul(e1, e2)
184+
}
185+
178186
#' @export
179187
dim.torch_tensor <- function(x) {
180188
cpp_tensor_dim(x$ptr)

tests/testthat/test-operators.R

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,22 @@ test_that("! works", {
121121
expect_equal_to_r(!y, c(TRUE, FALSE))
122122
})
123123

124+
test_that("%*% works", {
125+
if (getRversion() < "4.3.0") {
126+
skip("%*% S3 operator for torch_tensor requires R >= 4.3.0")
127+
}
128+
129+
lhs <- torch_tensor(matrix(c(1, 2, 3, 4), nrow = 2))
130+
rhs_tensor <- torch_tensor(matrix(c(0, 1, 1, 0), nrow = 2))
131+
rhs_matrix <- matrix(c(2, 0, 1, 2), nrow = 2)
132+
133+
lhs_r <- as.matrix(as_array(lhs))
134+
rhs_tensor_r <- as.matrix(as_array(rhs_tensor))
135+
136+
expect_equal_to_r(lhs %*% rhs_tensor, lhs_r %*% rhs_tensor_r)
137+
expect_equal_to_r(lhs %*% rhs_matrix, lhs_r %*% rhs_matrix)
138+
})
139+
124140
test_that("dim works", {
125141
x <- torch_randn(c(2, 2))
126142
expect_equal(dim(x), c(2, 2))
@@ -357,6 +373,24 @@ test_that("| works", {
357373
expect_equal_to_r(x | 0, c(FALSE, TRUE))
358374
})
359375

376+
test_that("%*% works on cuda", {
377+
if (getRversion() < "4.3.0") {
378+
skip("%*% S3 operator for torch_tensor requires R >= 4.3.0")
379+
}
380+
381+
skip_if_cuda_not_available()
382+
383+
lhs <- torch_tensor(matrix(c(1, 2, 3, 4), nrow = 2), device = "cuda")
384+
rhs_tensor <- torch_tensor(matrix(c(0, 1, 1, 0), nrow = 2), device = "cuda")
385+
rhs_matrix <- matrix(c(2, 0, 1, 2), nrow = 2)
386+
387+
lhs_r <- as.matrix(as_array(lhs$cpu()))
388+
rhs_tensor_r <- as.matrix(as_array(rhs_tensor$cpu()))
389+
390+
expect_equal_to_r(lhs %*% rhs_tensor, lhs_r %*% rhs_tensor_r)
391+
expect_equal_to_r(lhs %*% rhs_matrix, lhs_r %*% rhs_matrix)
392+
})
393+
360394
test_that("mean works", {
361395
x <- c(1, 2, 3, 4)
362396

0 commit comments

Comments
 (0)