Skip to content

Commit aa534d8

Browse files
committed
add naive impl of pow
1 parent e6f2e08 commit aa534d8

File tree

7 files changed

+100
-9
lines changed

7 files changed

+100
-9
lines changed

multik-default/src/commonMain/kotlin/org/jetbrains/kotlinx/multik/default/linalg/DefaultLinAlg.kt

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,15 @@ package org.jetbrains.kotlinx.multik.default.linalg
66

77
import org.jetbrains.kotlinx.multik.api.linalg.LinAlg
88
import org.jetbrains.kotlinx.multik.api.linalg.LinAlgEx
9-
import org.jetbrains.kotlinx.multik.kotlin.linalg.KELinAlg
109
import org.jetbrains.kotlinx.multik.ndarray.data.D2
1110
import org.jetbrains.kotlinx.multik.ndarray.data.MultiArray
1211
import org.jetbrains.kotlinx.multik.ndarray.data.NDArray
1312

14-
public object DefaultLinAlg : LinAlg {
13+
public expect object DefaultLinAlg : LinAlg {
1514

1615
override val linAlgEx: LinAlgEx
17-
get() = DefaultLinAlgEx
1816

19-
override fun <T : Number> pow(mat: MultiArray<T, D2>, n: Int): NDArray<T, D2> {
20-
return KELinAlg.pow(mat, n)
21-
}
17+
override fun <T : Number> pow(mat: MultiArray<T, D2>, n: Int): NDArray<T, D2>
2218

23-
override fun <T : Number> norm(mat: MultiArray<T, D2>, p: Int): Double = KELinAlg.norm(mat, p)
19+
override fun <T : Number> norm(mat: MultiArray<T, D2>, p: Int): Double
2420
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
/*
2+
* Copyright 2020-2022 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3+
*/
4+
5+
package org.jetbrains.kotlinx.multik.default.linalg
6+
7+
import org.jetbrains.kotlinx.multik.api.linalg.LinAlg
8+
import org.jetbrains.kotlinx.multik.api.linalg.LinAlgEx
9+
import org.jetbrains.kotlinx.multik.kotlin.linalg.KELinAlg
10+
import org.jetbrains.kotlinx.multik.ndarray.data.D2
11+
import org.jetbrains.kotlinx.multik.ndarray.data.MultiArray
12+
import org.jetbrains.kotlinx.multik.ndarray.data.NDArray
13+
14+
public actual object DefaultLinAlg : LinAlg {
15+
16+
actual override val linAlgEx: LinAlgEx
17+
get() = DefaultLinAlgEx
18+
19+
actual override fun <T : Number> pow(mat: MultiArray<T, D2>, n: Int): NDArray<T, D2> = KELinAlg.pow(mat, n)
20+
21+
actual override fun <T : Number> norm(mat: MultiArray<T, D2>, p: Int): Double = KELinAlg.norm(mat, p)
22+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
/*
2+
* Copyright 2020-2022 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3+
*/
4+
5+
package org.jetbrains.kotlinx.multik.default.linalg
6+
7+
import org.jetbrains.kotlinx.multik.api.linalg.LinAlg
8+
import org.jetbrains.kotlinx.multik.api.linalg.LinAlgEx
9+
import org.jetbrains.kotlinx.multik.kotlin.linalg.KELinAlg
10+
import org.jetbrains.kotlinx.multik.ndarray.data.D2
11+
import org.jetbrains.kotlinx.multik.ndarray.data.MultiArray
12+
import org.jetbrains.kotlinx.multik.ndarray.data.NDArray
13+
14+
public actual object DefaultLinAlg : LinAlg {
15+
16+
actual override val linAlgEx: LinAlgEx
17+
get() = DefaultLinAlgEx
18+
19+
actual override fun <T : Number> pow(mat: MultiArray<T, D2>, n: Int): NDArray<T, D2> = KELinAlg.pow(mat, n)
20+
21+
actual override fun <T : Number> norm(mat: MultiArray<T, D2>, p: Int): Double = KELinAlg.norm(mat, p)
22+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
/*
2+
* Copyright 2020-2022 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3+
*/
4+
5+
package org.jetbrains.kotlinx.multik.default.linalg
6+
7+
import org.jetbrains.kotlinx.multik.api.linalg.LinAlg
8+
import org.jetbrains.kotlinx.multik.api.linalg.LinAlgEx
9+
import org.jetbrains.kotlinx.multik.kotlin.linalg.KELinAlg
10+
import org.jetbrains.kotlinx.multik.ndarray.data.D2
11+
import org.jetbrains.kotlinx.multik.ndarray.data.MultiArray
12+
import org.jetbrains.kotlinx.multik.ndarray.data.NDArray
13+
import org.jetbrains.kotlinx.multik.openblas.linalg.NativeLinAlg
14+
15+
public actual object DefaultLinAlg : LinAlg {
16+
actual override val linAlgEx: LinAlgEx
17+
get() = DefaultLinAlgEx
18+
19+
actual override fun <T : Number> pow(mat: MultiArray<T, D2>, n: Int): NDArray<T, D2> = NativeLinAlg.pow(mat, n)
20+
21+
actual override fun <T : Number> norm(mat: MultiArray<T, D2>, p: Int): Double = KELinAlg.norm(mat)
22+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
/*
2+
* Copyright 2020-2022 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3+
*/
4+
5+
package org.jetbrains.kotlinx.multik.default.linalg
6+
7+
import org.jetbrains.kotlinx.multik.api.linalg.LinAlg
8+
import org.jetbrains.kotlinx.multik.api.linalg.LinAlgEx
9+
import org.jetbrains.kotlinx.multik.kotlin.linalg.KELinAlg
10+
import org.jetbrains.kotlinx.multik.ndarray.data.D2
11+
import org.jetbrains.kotlinx.multik.ndarray.data.MultiArray
12+
import org.jetbrains.kotlinx.multik.ndarray.data.NDArray
13+
import org.jetbrains.kotlinx.multik.openblas.linalg.NativeLinAlg
14+
15+
public actual object DefaultLinAlg : LinAlg {
16+
actual override val linAlgEx: LinAlgEx
17+
get() = DefaultLinAlgEx
18+
19+
actual override fun <T : Number> pow(mat: MultiArray<T, D2>, n: Int): NDArray<T, D2> = NativeLinAlg.pow(mat, n)
20+
21+
actual override fun <T : Number> norm(mat: MultiArray<T, D2>, p: Int): Double = KELinAlg.norm(mat)
22+
}

multik-openblas/src/commonMain/kotlin/org/jetbrains/kotlinx/multik/openblas/linalg/NativeLinAlg.kt

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,14 @@ public object NativeLinAlg : LinAlg {
1818
get() = NativeLinAlgEx
1919

2020
override fun <T : Number> pow(mat: MultiArray<T, D2>, n: Int): NDArray<T, D2> {
21-
TODO("Not yet implemented")
21+
requireSquare(mat.shape)
22+
if (n == 0) return mk.identity(mat.shape[0], mat.dtype)
23+
return if (n % 2 == 0) {
24+
val tmp = pow(mat, n / 2)
25+
NativeLinAlgEx.dotMM(tmp, tmp)
26+
} else {
27+
NativeLinAlgEx.dotMM(mat, pow(mat, n - 1))
28+
}
2229
}
2330

2431
override fun <T : Number> norm(mat: MultiArray<T, D2>, p: Int): Double {

multik-openblas/src/commonMain/kotlin/org/jetbrains/kotlinx/multik/openblas/linalg/NativeLinAlgEx.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,6 @@ private fun requireDotShape(aShape: IntArray, bShape: IntArray) = require(aShape
357357
}
358358

359359

360-
private fun requireSquare(shape: IntArray) = require(shape[0] == shape[1]) {
360+
internal fun requireSquare(shape: IntArray) = require(shape[0] == shape[1]) {
361361
"Ndarray must be square: shape = ${shape.joinToString(",", "(", ")")}"
362362
}

0 commit comments

Comments
 (0)