Skip to content

Commit e0420a2

Browse files
authored
Merge pull request #119 from Kotlin/v0.2.0
Fix reshape bug and add stat for native
2 parents 85860ba + 03cb303 commit e0420a2

File tree

33 files changed

+652
-61
lines changed

33 files changed

+652
-61
lines changed

gradle.properties

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
kotlin.code.style=official
2-
multik_version=0.2.0-dev-3
2+
multik_version=0.2.0-dev-4
33

44
# Kotlin
55
systemProp.kotlin_version=1.7.10

multik-core/build.gradle.kts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,13 +70,13 @@ kotlin {
7070
implementation("org.apache.commons:commons-csv:$common_csv_version")
7171
}
7272
}
73-
val nativeCommonMain by creating {
73+
val nativeMain by creating {
7474
dependsOn(commonMain)
7575
}
7676
names.forEach { n ->
7777
if (n.contains("X64Main") || n.contains("Arm64Main")){
7878
this@sourceSets.getByName(n).apply{
79-
dependsOn(nativeCommonMain)
79+
dependsOn(nativeMain)
8080
}
8181
}
8282
}

multik-core/src/commonMain/kotlin/org/jetbrains/kotlinx/multik/api/Engine.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package org.jetbrains.kotlinx.multik.api
66

77
import org.jetbrains.kotlinx.multik.api.linalg.LinAlg
88
import org.jetbrains.kotlinx.multik.api.math.Math
9+
import org.jetbrains.kotlinx.multik.api.stat.Statistics
910

1011
public sealed class EngineType(public val name: String)
1112

multik-core/src/commonMain/kotlin/org/jetbrains/kotlinx/multik/api/Multik.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import org.jetbrains.kotlinx.multik.api.Multik.math
1111
import org.jetbrains.kotlinx.multik.api.Multik.stat
1212
import org.jetbrains.kotlinx.multik.api.linalg.LinAlg
1313
import org.jetbrains.kotlinx.multik.api.math.Math
14+
import org.jetbrains.kotlinx.multik.api.stat.Statistics
1415

1516
/**
1617
* Abbreviated name for [Multik].
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
/*
2+
* Copyright 2020-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3+
*/
4+
5+
package org.jetbrains.kotlinx.multik.api
6+
7+
import org.jetbrains.kotlinx.multik.ndarray.data.*
8+
9+
/**
10+
* Returns an D2Array from Array<ByteArray>.
11+
*/
12+
public fun Multik.ndarray(args: Array<ByteArray>): D2Array<Byte> {
13+
val dim0 = args.size
14+
val dim1 = args[0].size
15+
require(args.all { dim1 == it.size }) { "Arrays must be the same size." }
16+
17+
val array = ByteArray(dim0 * dim1)
18+
var index = 0
19+
for (i in 0 until dim0) {
20+
for (j in 0 until dim1) {
21+
array[index++] = args[i][j]
22+
}
23+
}
24+
val data = MemoryViewByteArray(array)
25+
return D2Array(data, shape = intArrayOf(dim0, dim1), dim = D2)
26+
}
27+
28+
/**
29+
* Returns an D2Array from Array<ShortArray>.
30+
*/
31+
public fun Multik.ndarray(args: Array<ShortArray>): D2Array<Short> {
32+
val dim0 = args.size
33+
val dim1 = args[0].size
34+
require(args.all { dim1 == it.size }) { "Arrays must be the same size." }
35+
36+
val array = ShortArray(dim0 * dim1)
37+
var index = 0
38+
for (i in 0 until dim0) {
39+
for (j in 0 until dim1) {
40+
array[index++] = args[i][j]
41+
}
42+
}
43+
val data = MemoryViewShortArray(array)
44+
return D2Array(data, shape = intArrayOf(dim0, dim1), dim = D2)
45+
}
46+
47+
/**
48+
* Returns an D2Array from Array<IntArray>.
49+
*/
50+
public fun Multik.ndarray(args: Array<IntArray>): D2Array<Int> {
51+
val dim0 = args.size
52+
val dim1 = args[0].size
53+
require(args.all { dim1 == it.size }) { "Arrays must be the same size." }
54+
55+
val array = IntArray(dim0 * dim1)
56+
var index = 0
57+
for (i in 0 until dim0) {
58+
for (j in 0 until dim1) {
59+
array[index++] = args[i][j]
60+
}
61+
}
62+
val data = MemoryViewIntArray(array)
63+
return D2Array(data, shape = intArrayOf(dim0, dim1), dim = D2)
64+
}
65+
66+
/**
67+
* Returns an D2Array from Array<LongArray>.
68+
*/
69+
public fun Multik.ndarray(args: Array<LongArray>): D2Array<Long> {
70+
val dim0 = args.size
71+
val dim1 = args[0].size
72+
require(args.all { dim1 == it.size }) { "Arrays must be the same size." }
73+
74+
val array = LongArray(dim0 * dim1)
75+
var index = 0
76+
for (i in 0 until dim0) {
77+
for (j in 0 until dim1) {
78+
array[index++] = args[i][j]
79+
}
80+
}
81+
val data = MemoryViewLongArray(array)
82+
return D2Array(data, shape = intArrayOf(dim0, dim1), dim = D2)
83+
}
84+
85+
/**
86+
* Returns an D2Array from Array<FloatArray>.
87+
*/
88+
public fun Multik.ndarray(args: Array<FloatArray>): D2Array<Float> {
89+
val dim0 = args.size
90+
val dim1 = args[0].size
91+
require(args.all { dim1 == it.size }) { "Arrays must be the same size." }
92+
93+
val array = FloatArray(dim0 * dim1)
94+
var index = 0
95+
for (i in 0 until dim0) {
96+
for (j in 0 until dim1) {
97+
array[index++] = args[i][j]
98+
}
99+
}
100+
val data = MemoryViewFloatArray(array)
101+
return D2Array(data, shape = intArrayOf(dim0, dim1), dim = D2)
102+
}
103+
104+
/**
105+
* Returns an D2Array from Array<DoubleArray>.
106+
*/
107+
public fun Multik.ndarray(args: Array<DoubleArray>): D2Array<Double> {
108+
val dim0 = args.size
109+
val dim1 = args[0].size
110+
require(args.all { dim1 == it.size }) { "Arrays must be the same size." }
111+
112+
val array = DoubleArray(dim0 * dim1)
113+
var index = 0
114+
for (i in 0 until dim0) {
115+
for (j in 0 until dim1) {
116+
array[index++] = args[i][j]
117+
}
118+
}
119+
val data = MemoryViewDoubleArray(array)
120+
return D2Array(data, shape = intArrayOf(dim0, dim1), dim = D2)
121+
}
122+
123+
/**
124+
* Returns an D2Array.
125+
*/
126+
public fun Array<ByteArray>.toNDArray(): D2Array<Byte> = Multik.ndarray(this)
127+
128+
/**
129+
* Returns an D2Array.
130+
*/
131+
public fun Array<ShortArray>.toNDArray(): D2Array<Short> = Multik.ndarray(this)
132+
133+
/**
134+
* Returns an D2Array.
135+
*/
136+
public fun Array<IntArray>.toNDArray(): D2Array<Int> = Multik.ndarray(this)
137+
138+
/**
139+
* Returns an D2Array.
140+
*/
141+
public fun Array<LongArray>.toNDArray(): D2Array<Long> = Multik.ndarray(this)
142+
143+
/**
144+
* Returns an D2Array.
145+
*/
146+
public fun Array<FloatArray>.toNDArray(): D2Array<Float> = Multik.ndarray(this)
147+
148+
/**
149+
* Returns an D2Array.
150+
*/
151+
public fun Array<DoubleArray>.toNDArray(): D2Array<Double> = Multik.ndarray(this)
152+
153+

multik-core/src/commonMain/kotlin/org/jetbrains/kotlinx/multik/api/Statistics.kt renamed to multik-core/src/commonMain/kotlin/org/jetbrains/kotlinx/multik/api/stat/Statistics.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
/*
2-
* Copyright 2020-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
2+
* Copyright 2020-2022 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
33
*/
44

5-
package org.jetbrains.kotlinx.multik.api
5+
package org.jetbrains.kotlinx.multik.api.stat
66

77
import org.jetbrains.kotlinx.multik.ndarray.complex.ComplexDouble
88
import org.jetbrains.kotlinx.multik.ndarray.complex.ComplexFloat

multik-core/src/commonMain/kotlin/org/jetbrains/kotlinx/multik/ndarray/data/MultiArrays.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2020-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
2+
* Copyright 2020-2022 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
33
*/
44

55
package org.jetbrains.kotlinx.multik.ndarray.data
@@ -460,7 +460,7 @@ public operator fun <T> MultiArray<T, D4>.get(ind1: Int, ind2: Int, ind3: Closed
460460
slice(mapOf(0 to ind1.r, 1 to ind2.r, 2 to ind3.toSlice(), 3 to ind4.toSlice()))
461461

462462
@JvmName("get44")
463-
public operator fun <T> MultiArray<T, D4>.get(ind1: Int, ind2: ClosedRange<Int>, ind3: Slice, ind4: Int): MultiArray<T, D2> =
463+
public operator fun <T> MultiArray<T, D4>.get(ind1: Int, ind2: ClosedRange<Int>, ind3: ClosedRange<Int>, ind4: Int): MultiArray<T, D2> =
464464
slice(mapOf(0 to ind1.r, 1 to ind2.toSlice(), 2 to ind3.toSlice(), 3 to ind4.r))
465465

466466
@JvmName("get45")

multik-core/src/commonMain/kotlin/org/jetbrains/kotlinx/multik/ndarray/data/NDArray.kt

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,14 @@ public class NDArray<T, D : Dimension> constructor(
107107
requirePositiveShape(dim1)
108108
require(dim1 == size) { "Cannot reshape array of size $size into a new shape ($dim1)" }
109109

110+
// TODO(get rid of copying)
111+
val newData = if (consistent) this.data else this.deepCopy().data
112+
val newBase = if (consistent) base ?: this else null
113+
110114
return if (this.dim.d == 1 && this.shape.first() == dim1) {
111115
this as D1Array<T>
112116
} else {
113-
D1Array(this.data, this.offset, intArrayOf(dim1), dim = D1, base = base ?: this)
117+
D1Array(newData, this.offset, intArrayOf(dim1), dim = D1, base = newBase)
114118
}
115119
}
116120

@@ -119,10 +123,14 @@ public class NDArray<T, D : Dimension> constructor(
119123
newShape.forEach { requirePositiveShape(it) }
120124
require(dim1 * dim2 == size) { "Cannot reshape array of size $size into a new shape ($dim1, $dim2)" }
121125

126+
// TODO(get rid of copying)
127+
val newData = if (consistent) this.data else this.deepCopy().data
128+
val newBase = if (consistent) base ?: this else null
129+
122130
return if (this.shape.contentEquals(newShape)) {
123131
this as D2Array<T>
124132
} else {
125-
D2Array(this.data, this.offset, newShape, dim = D2, base = base ?: this)
133+
D2Array(newData, this.offset, newShape, dim = D2, base = newBase)
126134
}
127135
}
128136

@@ -131,10 +139,14 @@ public class NDArray<T, D : Dimension> constructor(
131139
newShape.forEach { requirePositiveShape(it) }
132140
require(dim1 * dim2 * dim3 == size) { "Cannot reshape array of size $size into a new shape ($dim1, $dim2, $dim3)" }
133141

142+
// TODO(get rid of copying)
143+
val newData = if (consistent) this.data else this.deepCopy().data
144+
val newBase = if (consistent) base ?: this else null
145+
134146
return if (this.shape.contentEquals(newShape)) {
135147
this as D3Array<T>
136148
} else {
137-
D3Array(this.data, this.offset, newShape, dim = D3, base = base ?: this)
149+
D3Array(newData, this.offset, newShape, dim = D3, base = newBase)
138150
}
139151
}
140152

@@ -143,10 +155,14 @@ public class NDArray<T, D : Dimension> constructor(
143155
newShape.forEach { requirePositiveShape(it) }
144156
require(dim1 * dim2 * dim3 * dim4 == size) { "Cannot reshape array of size $size into a new shape ($dim1, $dim2, $dim3, $dim4)" }
145157

158+
// TODO(get rid of copying)
159+
val newData = if (consistent) this.data else this.deepCopy().data
160+
val newBase = if (consistent) base ?: this else null
161+
146162
return if (this.shape.contentEquals(newShape)) {
147163
this as D4Array<T>
148164
} else {
149-
D4Array(this.data, this.offset, newShape, dim = D4, base = base ?: this)
165+
D4Array(newData, this.offset, newShape, dim = D4, base = newBase)
150166
}
151167
}
152168

@@ -157,10 +173,14 @@ public class NDArray<T, D : Dimension> constructor(
157173
"Cannot reshape array of size $size into a new shape ${newShape.joinToString(prefix = "(", postfix = ")")}"
158174
}
159175

176+
// TODO(get rid of copying)
177+
val newData = if (consistent) this.data else this.deepCopy().data
178+
val newBase = if (consistent) base ?: this else null
179+
160180
return if (this.shape.contentEquals(newShape)) {
161181
this as NDArray<T, DN>
162182
} else {
163-
NDArray(this.data, this.offset, newShape, dim = DN(newShape.size), base = base ?: this)
183+
NDArray(newData, this.offset, newShape, dim = DN(newShape.size), base = newBase)
164184
}
165185
}
166186

@@ -201,13 +221,11 @@ public class NDArray<T, D : Dimension> constructor(
201221
for (axis in axes.sorted()) {
202222
newShape.add(axis, 1)
203223
}
204-
return NDArray(
205-
this.data,
206-
this.offset,
207-
newShape.toIntArray(),
208-
dim = DN(newShape.size),
209-
base = base ?: this
210-
)
224+
// TODO(get rid of copying)
225+
val newData = if (consistent) this.data else this.deepCopy().data
226+
val newBase = if (consistent) base ?: this else null
227+
228+
return NDArray(newData, this.offset, newShape.toIntArray(), dim = DN(newShape.size), base = newBase)
211229
}
212230

213231
override infix fun cat(other: MultiArray<T, D>): NDArray<T, D> =

0 commit comments

Comments
 (0)