@@ -1807,7 +1807,7 @@ defmodule EXLA.Defn.ExprTest do
18071807      indices  =  Nx . tensor ( [ [ 0 ] ] ) 
18081808      updates  =  Nx . tensor ( [ 1 ] ) 
18091809
1810-       assert_equal ( indexed_add ( target ,  indices ,  updates ) ,  Nx . tensor ( [ 1 ] ,  type:  { :s ,  64 } ) ) 
1810+       assert_equal ( indexed_add ( target ,  indices ,  updates ) ,  Nx . tensor ( [ 1 ] ,  type:  { :s ,  32 } ) ) 
18111811
18121812      target  =  Nx . tensor ( [ 0 ] ) 
18131813      indices  =  Nx . tensor ( [ [ 0 ] ] ) 
@@ -1879,7 +1879,7 @@ defmodule EXLA.Defn.ExprTest do
18791879      indices  =  Nx . tensor ( [ [ 0 ] ] ) 
18801880      updates  =  Nx . tensor ( [ 1 ] ) 
18811881
1882-       assert_equal ( indexed_put ( target ,  indices ,  updates ) ,  Nx . tensor ( [ 1 ] ,  type:  { :s ,  64 } ) ) 
1882+       assert_equal ( indexed_put ( target ,  indices ,  updates ) ,  Nx . tensor ( [ 1 ] ,  type:  { :s ,  32 } ) ) 
18831883
18841884      target  =  Nx . tensor ( [ 0 ] ) 
18851885      indices  =  Nx . tensor ( [ [ 0 ] ] ) 
@@ -1963,7 +1963,7 @@ defmodule EXLA.Defn.ExprTest do
19631963    test  "computes the sum across types"  do 
19641964      assert_equal ( Nx . tensor ( [ 1 ,  2 ,  3 ] )  |>  sum ( ) ,  Nx . tensor ( 6 ) ) 
19651965      assert_equal ( Nx . tensor ( [ 1 ,  2 ,  3 ] ,  type:  { :s ,  8 } )  |>  sum ( ) ,  Nx . tensor ( 6 ) ) 
1966-       assert_equal ( Nx . tensor ( [ 1 ,  2 ,  3 ] ,  type:  { :u ,  8 } )  |>  sum ( ) ,  Nx . tensor ( 6 ,  type:  { :u ,  64 } ) ) 
1966+       assert_equal ( Nx . tensor ( [ 1 ,  2 ,  3 ] ,  type:  { :u ,  8 } )  |>  sum ( ) ,  Nx . tensor ( 6 ,  type:  { :u ,  32 } ) ) 
19671967      assert_equal ( Nx . tensor ( [ 1.0 ,  2.0 ,  3.0 ] )  |>  sum ( ) ,  Nx . tensor ( 6.0 ) ) 
19681968
19691969      assert_equal ( 
@@ -1986,9 +1986,9 @@ defmodule EXLA.Defn.ExprTest do
19861986    defn  sum_equal ( t ) ,  do:  Nx . sum ( Nx . equal ( t ,  1.0 ) ) 
19871987
19881988    test  "does not overflow"  do 
1989-       assert_equal ( sum_equal ( Nx . tensor ( 1 ) ) ,  Nx . tensor ( 1 ,  type:  { :u ,  64 } ) ) 
1990-       assert_equal ( sum_equal ( Nx . tensor ( [ 1 ,  1 ,  1 ] ) ) ,  Nx . tensor ( 3 ,  type:  { :u ,  64 } ) ) 
1991-       assert_equal ( sum_equal ( Nx . tensor ( [ 1 ,  2 ,  3 ] ) ) ,  Nx . tensor ( 1 ,  type:  { :u ,  64 } ) ) 
1989+       assert_equal ( sum_equal ( Nx . tensor ( 1 ) ) ,  Nx . tensor ( 1 ,  type:  { :u ,  32 } ) ) 
1990+       assert_equal ( sum_equal ( Nx . tensor ( [ 1 ,  1 ,  1 ] ) ) ,  Nx . tensor ( 3 ,  type:  { :u ,  32 } ) ) 
1991+       assert_equal ( sum_equal ( Nx . tensor ( [ 1 ,  2 ,  3 ] ) ) ,  Nx . tensor ( 1 ,  type:  { :u ,  32 } ) ) 
19921992    end 
19931993
19941994    defn  sum_keep ( t ) ,  do:  Nx . sum ( t ,  keep_axes:  true ) 
@@ -2011,7 +2011,7 @@ defmodule EXLA.Defn.ExprTest do
20112011    test  "computes the product across types"  do 
20122012      assert_equal ( Nx . tensor ( [ 1 ,  2 ,  3 ] )  |>  product ( ) ,  Nx . tensor ( 6 ) ) 
20132013      assert_equal ( Nx . tensor ( [ 1 ,  2 ,  3 ] ,  type:  { :s ,  8 } )  |>  product ( ) ,  Nx . tensor ( 6 ) ) 
2014-       assert_equal ( Nx . tensor ( [ 1 ,  2 ,  3 ] ,  type:  { :u ,  8 } )  |>  product ( ) ,  Nx . tensor ( 6 ,  type:  { :u ,  64 } ) ) 
2014+       assert_equal ( Nx . tensor ( [ 1 ,  2 ,  3 ] ,  type:  { :u ,  8 } )  |>  product ( ) ,  Nx . tensor ( 6 ,  type:  { :u ,  32 } ) ) 
20152015      assert_equal ( Nx . tensor ( [ 1.0 ,  2.0 ,  3.0 ] )  |>  product ( ) ,  Nx . tensor ( 6.0 ) ) 
20162016
20172017      assert_equal ( 
@@ -2034,9 +2034,9 @@ defmodule EXLA.Defn.ExprTest do
20342034    defn  product_equal ( t ) ,  do:  Nx . product ( Nx . equal ( t ,  1.0 ) ) 
20352035
20362036    test  "does not overflow"  do 
2037-       assert_equal ( product_equal ( Nx . tensor ( 1 ) ) ,  Nx . tensor ( 1 ,  type:  { :u ,  64 } ) ) 
2038-       assert_equal ( product_equal ( Nx . tensor ( [ 1 ,  1 ,  1 ] ) ) ,  Nx . tensor ( 1 ,  type:  { :u ,  64 } ) ) 
2039-       assert_equal ( product_equal ( Nx . tensor ( [ 1 ,  2 ,  3 ] ) ) ,  Nx . tensor ( 0 ,  type:  { :u ,  64 } ) ) 
2037+       assert_equal ( product_equal ( Nx . tensor ( 1 ) ) ,  Nx . tensor ( 1 ,  type:  { :u ,  32 } ) ) 
2038+       assert_equal ( product_equal ( Nx . tensor ( [ 1 ,  1 ,  1 ] ) ) ,  Nx . tensor ( 1 ,  type:  { :u ,  32 } ) ) 
2039+       assert_equal ( product_equal ( Nx . tensor ( [ 1 ,  2 ,  3 ] ) ) ,  Nx . tensor ( 0 ,  type:  { :u ,  32 } ) ) 
20402040    end 
20412041
20422042    defn  product_keep ( t ) ,  do:  Nx . product ( t ,  keep_axes:  true ) 
@@ -2416,12 +2416,12 @@ defmodule EXLA.Defn.ExprTest do
24162416        window_max2 ( Nx . tensor ( [ [ [ 1 ,  2 ,  3 ] ,  [ 4 ,  5 ,  6 ] ] ,  [ [ 1 ,  2 ,  3 ] ,  [ 4 ,  5 ,  6 ] ] ] ) ) , 
24172417        Nx . tensor ( [ 
24182418          [ 
2419-             [ - 9_223_372_036_854_775_808 ,  - 9_223_372_036_854_775_808 ] , 
2420-             [ - 9_223_372_036_854_775_808 ,  6 ] 
2419+             [ - 2_147_483_648 ,  - 2_147_483_648 ] , 
2420+             [ - 2_147_483_648 ,  6 ] 
24212421          ] , 
24222422          [ 
2423-             [ - 9_223_372_036_854_775_808 ,  - 9_223_372_036_854_775_808 ] , 
2424-             [ - 9_223_372_036_854_775_808 ,  6 ] 
2423+             [ - 2_147_483_648 ,  - 2_147_483_648 ] , 
2424+             [ - 2_147_483_648 ,  6 ] 
24252425          ] 
24262426        ] ) 
24272427      ) 
@@ -2482,12 +2482,12 @@ defmodule EXLA.Defn.ExprTest do
24822482        window_min2 ( Nx . tensor ( [ [ [ 1 ,  2 ,  3 ] ,  [ 4 ,  5 ,  6 ] ] ,  [ [ 1 ,  2 ,  3 ] ,  [ 4 ,  5 ,  6 ] ] ] ) ) , 
24832483        Nx . tensor ( [ 
24842484          [ 
2485-             [ 9_223_372_036_854_775_807 ,   9_223_372_036_854_775_807 ] , 
2486-             [ 9_223_372_036_854_775_807 ,  3 ] 
2485+             [ 2_147_483_647 ,   2_147_483_647 ] , 
2486+             [ 2_147_483_647 ,  3 ] 
24872487          ] , 
24882488          [ 
2489-             [ 9_223_372_036_854_775_807 ,   9_223_372_036_854_775_807 ] , 
2490-             [ 9_223_372_036_854_775_807 ,  3 ] 
2489+             [ 2_147_483_647 ,   2_147_483_647 ] , 
2490+             [ 2_147_483_647 ,  3 ] 
24912491          ] 
24922492        ] ) 
24932493      ) 
0 commit comments