Skip to content

Commit bd396ab

Browse files
committed
[fix:] fix test on windows
1 parent 42c9188 commit bd396ab

File tree

3 files changed

+40
-22
lines changed

3 files changed

+40
-22
lines changed
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
package org.jetbrains.kotlinx.multik.jni
2+
3+
import org.jetbrains.kotlinx.multik.ndarray.data.Dimension
4+
import org.jetbrains.kotlinx.multik.ndarray.data.NDArray
5+
import org.jetbrains.kotlinx.multik.ndarray.operations.map
6+
import java.math.BigDecimal
7+
import java.math.RoundingMode
8+
9+
fun <D: Dimension> roundDouble(ndarray: NDArray<Double, D>): NDArray<Double, D> =
10+
ndarray.map { BigDecimal(it).setScale(2, RoundingMode.HALF_UP).toDouble() }
11+
12+
fun <D: Dimension> roundFloat(ndarray: NDArray<Float, D>): NDArray<Float, D> =
13+
ndarray.map { BigDecimal(it.toDouble()).setScale(2, RoundingMode.HALF_UP).toFloat() }

multik-native/src/test/kotlin/org/jetbrains/kotlinx/multik/jni/linalg/NativeLinAlgTest.kt

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ import org.jetbrains.kotlinx.multik.api.mk
44
import org.jetbrains.kotlinx.multik.api.ndarray
55
import org.jetbrains.kotlinx.multik.jni.Loader
66
import org.jetbrains.kotlinx.multik.jni.NativeLinAlg
7+
import org.jetbrains.kotlinx.multik.jni.roundDouble
8+
import org.jetbrains.kotlinx.multik.jni.roundFloat
79
import kotlin.test.BeforeTest
810
import kotlin.test.Test
911
import kotlin.test.assertEquals
@@ -18,10 +20,10 @@ class NativeLinAlgTest {
1820
@Test
1921
fun `matrix-matrix dot test D`() {
2022
val expected = mk.ndarray(
21-
mk[mk[1.0718999999999999, 0.6181, 0.46080000000000004, 0.48109999999999997],
22-
mk[0.8210999999999999, 0.7162, 0.79, 0.8199],
23-
mk[0.5287999999999999, 0.48339999999999994, 0.5342, 0.5082],
24-
mk[1.0353, 0.758, 0.7114, 0.6647]]
23+
mk[mk[1.07, 0.62, 0.46, 0.48],
24+
mk[0.82, 0.72, 0.79, 0.82],
25+
mk[0.53, 0.48, 0.53, 0.51],
26+
mk[1.04, 0.76, 0.71, 0.66]]
2527
)
2628
val matrix1 = mk.ndarray(
2729
mk[mk[0.22, 0.9, 0.27],
@@ -36,16 +38,16 @@ class NativeLinAlgTest {
3638
)
3739

3840
val actual = NativeLinAlg.dot(matrix1, matrix2)
39-
assertEquals(expected, actual)
41+
assertEquals(expected, roundDouble(actual))
4042
}
4143

4244
@Test
4345
fun `matrix-matrix dot test F`() {
4446
val expected = mk.ndarray(
45-
mk[mk[1.0719f, 0.6181f, 0.4608f, 0.4811f],
46-
mk[0.8211f, 0.7162f, 0.79f, 0.8199f],
47-
mk[0.52879995f, 0.48339996f, 0.5342f, 0.5082f],
48-
mk[1.0353f, 0.758f, 0.71140003f, 0.6647f]]
47+
mk[mk[1.07f, 0.62f, 0.46f, 0.48f],
48+
mk[0.82f, 0.72f, 0.79f, 0.82f],
49+
mk[0.53f, 0.48f, 0.53f, 0.51f],
50+
mk[1.04f, 0.76f, 0.71f, 0.66f]]
4951
)
5052
val matrix1 = mk.ndarray(
5153
mk[mk[0.22f, 0.9f, 0.27f],
@@ -60,12 +62,12 @@ class NativeLinAlgTest {
6062
)
6163

6264
val actual = NativeLinAlg.dot(matrix1, matrix2)
63-
assertEquals(expected, actual)
65+
assertEquals(expected, roundFloat(actual))
6466
}
6567

6668
@Test
6769
fun `matrix-vector dot test D`() {
68-
val expected = mk.ndarray(mk[0.8006, 0.663, 0.5771])
70+
val expected = mk.ndarray(mk[0.80, 0.66, 0.58])
6971

7072
val matrix = mk.ndarray(
7173
mk[mk[0.22, 0.9, 0.27],
@@ -75,12 +77,12 @@ class NativeLinAlgTest {
7577
val vector = mk.ndarray(mk[0.08, 0.63, 0.8])
7678

7779
val actual = NativeLinAlg.dot(matrix, vector)
78-
assertEquals(expected, actual)
80+
assertEquals(expected, roundDouble(actual))
7981
}
8082

8183
@Test
8284
fun `matrix-vector dot test F`() {
83-
val expected = mk.ndarray(mk[0.8006f, 0.663f, 0.5771f])
85+
val expected = mk.ndarray(mk[0.80f, 0.66f, 0.58f])
8486

8587
val matrix = mk.ndarray(
8688
mk[mk[0.22f, 0.9f, 0.27f],
@@ -90,6 +92,6 @@ class NativeLinAlgTest {
9092
val vector = mk.ndarray(mk[0.08f, 0.63f, 0.8f])
9193

9294
val actual = NativeLinAlg.dot(matrix, vector)
93-
assertEquals(expected, actual)
95+
assertEquals(expected, roundFloat(actual))
9496
}
9597
}

multik-native/src/test/kotlin/org/jetbrains/kotlinx/multik/jni/math/NativeMathTest.kt

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,11 @@ import org.jetbrains.kotlinx.multik.api.mk
44
import org.jetbrains.kotlinx.multik.api.ndarray
55
import org.jetbrains.kotlinx.multik.jni.Loader
66
import org.jetbrains.kotlinx.multik.jni.NativeMath
7+
import org.jetbrains.kotlinx.multik.jni.roundDouble
78
import org.jetbrains.kotlinx.multik.ndarray.data.D2
89
import org.jetbrains.kotlinx.multik.ndarray.data.NDArray
10+
import java.math.BigDecimal
11+
import java.math.RoundingMode
912
import kotlin.test.BeforeTest
1013
import kotlin.test.Test
1114
import kotlin.test.assertEquals
@@ -31,26 +34,26 @@ class NativeMathTest {
3134

3235
@Test
3336
fun expTest() {
34-
val expected = mk.ndarray(mk[mk[298.8673439626328, 2.9446796774374633], mk[9136.200570869447, 34.813315827573724]])
35-
assertEquals(expected, NativeMath.exp(ndarray))
37+
val expected = mk.ndarray(mk[mk[298.87, 2.94], mk[9136.2, 34.81]])
38+
assertEquals(expected, roundDouble(NativeMath.exp(ndarray)))
3639
}
3740

3841
@Test
3942
fun logTest() {
40-
val expected = mk.ndarray(mk[mk[1.7404661413782472, 0.07696108087255739], mk[2.2104697915378937, 1.2669475900552918]])
41-
assertEquals(expected, NativeMath.log(ndarray))
43+
val expected = mk.ndarray(mk[mk[1.74, 0.08], mk[2.21, 1.27]])
44+
assertEquals(expected, roundDouble(NativeMath.log(ndarray)))
4245
}
4346

4447
@Test
4548
fun sinTest() {
46-
val expected = mk.ndarray(mk[mk[-0.5506857018064566, 0.8819578271121656], mk[0.300081485531831, -0.39714812352401446]])
47-
assertEquals(expected, NativeMath.sin(ndarray))
49+
val expected = mk.ndarray(mk[mk[-0.55, 0.88], mk[0.3, -0.4]])
50+
assertEquals(expected, roundDouble(NativeMath.sin(ndarray)))
4851
}
4952

5053
@Test
5154
fun cosTest() {
52-
val expected = mk.ndarray(mk[mk[0.8347126798042128, 0.47132832632421673], mk[-0.9539135715781643, -0.9177545249037752]])
53-
assertEquals(expected, NativeMath.cos(ndarray))
55+
val expected = mk.ndarray(mk[mk[0.83, 0.47], mk[-0.95, -0.92]])
56+
assertEquals(expected, roundDouble(NativeMath.cos(ndarray)))
5457
}
5558

5659
@Test

0 commit comments

Comments
 (0)