Skip to content

Commit af126cd

Browse files
committed
add computation of matrix norm in native
1 parent 3940ffa commit af126cd

File tree

28 files changed

+189
-135
lines changed

28 files changed

+189
-135
lines changed

multik-core/src/commonMain/kotlin/org/jetbrains/kotlinx/multik/api/linalg/LinAlg.kt

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ public interface LinAlg {
2121
*/
2222
public fun <T : Number> pow(mat: MultiArray<T, D2>, n: Int): NDArray<T, D2>
2323

24-
/**
25-
* Matrix ov vector norm. The default is Frobenius norm.
26-
*/
27-
public fun <T : Number> norm(mat: MultiArray<T, D2>, p: Int = 2): Double
24+
// /**
25+
// * Matrix ov vector norm. The default is Frobenius norm.
26+
// */
27+
// public fun <T : Number> norm(mat: MultiArray<T, D2>, p: Int = 2): Double
2828
}

multik-core/src/commonMain/kotlin/org/jetbrains/kotlinx/multik/api/linalg/LinAlgEx.kt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ public interface LinAlgEx {
2121
public fun <D : Dim2> solveF(a: MultiArray<Float, D2>, b: MultiArray<Float, D>): NDArray<Float, D>
2222
public fun <T : Complex, D : Dim2> solveC(a: MultiArray<T, D2>, b: MultiArray<T, D>): NDArray<T, D>
2323

24+
public fun normF(mat: MultiArray<Float, D2>, norm: Norm = Norm.Fro): Float
25+
26+
public fun norm(mat: MultiArray<Double, D2>, norm: Norm = Norm.Fro): Double
27+
2428
public fun <T : Number> qr(mat: MultiArray<T, D2>): Pair<D2Array<Double>, D2Array<Double>>
2529
public fun qrF(mat: MultiArray<Float, D2>): Pair<D2Array<Float>, D2Array<Float>>
2630
public fun <T : Complex> qrC(mat: MultiArray<T, D2>): Pair<D2Array<T>, D2Array<T>>
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
package org.jetbrains.kotlinx.multik.api.linalg
2+
3+
public enum class Norm(public val lapackCode: Char) {
4+
Max('M'),
5+
N1('1'),
6+
Inf('I'),
7+
Fro('F')
8+
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
package org.jetbrains.kotlinx.multik.api.linalg
2+
3+
import org.jetbrains.kotlinx.multik.ndarray.data.D2
4+
import org.jetbrains.kotlinx.multik.ndarray.data.MultiArray
5+
import kotlin.jvm.JvmName
6+
7+
/**
8+
* Returns norm of float matrix
9+
*/
10+
@JvmName("normF")
11+
public fun LinAlg.norm(mat: MultiArray<Float, D2>, norm: Norm = Norm.Fro): Float = this.linAlgEx.normF(mat, norm)
12+
13+
/**
14+
* Returns norm of float matrix
15+
*/
16+
@JvmName("normD")
17+
public fun LinAlg.norm(mat: MultiArray<Double, D2>, norm: Norm = Norm.Fro): Double = this.linAlgEx.norm(mat, norm)

multik-default/src/commonMain/kotlin/org/jetbrains/kotlinx/multik/default/linalg/DefaultLinAlg.kt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,4 @@ public expect object DefaultLinAlg : LinAlg {
1515
override val linAlgEx: LinAlgEx
1616

1717
override fun <T : Number> pow(mat: MultiArray<T, D2>, n: Int): NDArray<T, D2>
18-
19-
override fun <T : Number> norm(mat: MultiArray<T, D2>, p: Int): Double
2018
}

multik-default/src/commonMain/kotlin/org/jetbrains/kotlinx/multik/default/linalg/DefaultLinAlgEx.kt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
package org.jetbrains.kotlinx.multik.default.linalg
66

77
import org.jetbrains.kotlinx.multik.api.linalg.LinAlgEx
8+
import org.jetbrains.kotlinx.multik.api.linalg.Norm
89
import org.jetbrains.kotlinx.multik.ndarray.complex.Complex
910
import org.jetbrains.kotlinx.multik.ndarray.complex.ComplexDouble
1011
import org.jetbrains.kotlinx.multik.ndarray.complex.ComplexFloat
@@ -21,6 +22,8 @@ public expect object DefaultLinAlgEx: LinAlgEx {
2122

2223
override fun <D : Dim2> solveF(a: MultiArray<Float, D2>, b: MultiArray<Float, D>): NDArray<Float, D>
2324
override fun <T : Complex, D : Dim2> solveC(a: MultiArray<T, D2>, b: MultiArray<T, D>): NDArray<T, D>
25+
override fun normF(mat: MultiArray<Float, D2>, norm: Norm): Float
26+
override fun norm(mat: MultiArray<Double, D2>, norm: Norm): Double
2427
override fun <T : Number> qr(mat: MultiArray<T, D2>): Pair<D2Array<Double>, D2Array<Double>>
2528
override fun qrF(mat: MultiArray<Float, D2>): Pair<D2Array<Float>, D2Array<Float>>
2629
override fun <T : Complex> qrC(mat: MultiArray<T, D2>): Pair<D2Array<T>, D2Array<T>>

multik-default/src/iosMain/kotlin/org.jetbrains.kotlinx.multik.default/linalg/DefaultLinAlg.kt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,4 @@ public actual object DefaultLinAlg : LinAlg {
1717
get() = DefaultLinAlgEx
1818

1919
actual override fun <T : Number> pow(mat: MultiArray<T, D2>, n: Int): NDArray<T, D2> = KELinAlg.pow(mat, n)
20-
21-
actual override fun <T : Number> norm(mat: MultiArray<T, D2>, p: Int): Double = KELinAlg.norm(mat, p)
2220
}

multik-default/src/iosMain/kotlin/org.jetbrains.kotlinx.multik.default/linalg/DefaultLinAlgEx.kt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
package org.jetbrains.kotlinx.multik.default.linalg
66

77
import org.jetbrains.kotlinx.multik.api.linalg.LinAlgEx
8+
import org.jetbrains.kotlinx.multik.api.linalg.Norm
89
import org.jetbrains.kotlinx.multik.api.linalg.dot
910
import org.jetbrains.kotlinx.multik.kotlin.linalg.KELinAlg
1011
import org.jetbrains.kotlinx.multik.kotlin.linalg.KELinAlgEx
@@ -30,6 +31,10 @@ public actual object DefaultLinAlgEx : LinAlgEx {
3031
actual override fun <T : Complex, D : Dim2> solveC(a: MultiArray<T, D2>, b: MultiArray<T, D>): NDArray<T, D> =
3132
KELinAlgEx.solveC(a, b)
3233

34+
actual override fun normF(mat: MultiArray<Float, D2>, norm: Norm): Float = KELinAlgEx.normF(mat, norm)
35+
36+
actual override fun norm(mat: MultiArray<Double, D2>, norm: Norm): Double = KELinAlgEx.norm(mat, norm)
37+
3338
actual override fun <T : Number> qr(mat: MultiArray<T, D2>): Pair<D2Array<Double>, D2Array<Double>> =
3439
KELinAlgEx.qr(mat)
3540

multik-default/src/jsMain/kotlin/org.jetbrains.kotlinx.multik.default/linalg/DefaultLinAlg.kt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,4 @@ public actual object DefaultLinAlg : LinAlg {
1717
get() = DefaultLinAlgEx
1818

1919
actual override fun <T : Number> pow(mat: MultiArray<T, D2>, n: Int): NDArray<T, D2> = KELinAlg.pow(mat, n)
20-
21-
actual override fun <T : Number> norm(mat: MultiArray<T, D2>, p: Int): Double = KELinAlg.norm(mat, p)
2220
}

multik-default/src/jsMain/kotlin/org.jetbrains.kotlinx.multik.default/linalg/DefaultLinAlgEx.kt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
package org.jetbrains.kotlinx.multik.default.linalg
66

77
import org.jetbrains.kotlinx.multik.api.linalg.LinAlgEx
8+
import org.jetbrains.kotlinx.multik.api.linalg.Norm
89
import org.jetbrains.kotlinx.multik.api.linalg.dot
910
import org.jetbrains.kotlinx.multik.kotlin.linalg.KELinAlg
1011
import org.jetbrains.kotlinx.multik.kotlin.linalg.KELinAlgEx
@@ -29,6 +30,10 @@ public actual object DefaultLinAlgEx : LinAlgEx {
2930
actual override fun <T : Complex, D : Dim2> solveC(a: MultiArray<T, D2>, b: MultiArray<T, D>): NDArray<T, D> =
3031
KELinAlgEx.solveC(a, b)
3132

33+
actual override fun normF(mat: MultiArray<Float, D2>, norm: Norm): Float = KELinAlgEx.normF(mat, norm)
34+
35+
actual override fun norm(mat: MultiArray<Double, D2>, norm: Norm): Double = KELinAlgEx.norm(mat, norm)
36+
3237
actual override fun <T : Number> qr(mat: MultiArray<T, D2>): Pair<D2Array<Double>, D2Array<Double>> =
3338
KELinAlgEx.qr(mat)
3439

0 commit comments

Comments
 (0)