Skip to content

Commit 85860ba

Browse files
committed
wrap kotlin random
1 parent aa534d8 commit 85860ba

File tree

5 files changed

+243
-5
lines changed

5 files changed

+243
-5
lines changed

multik-core/src/commonMain/kotlin/org/jetbrains/kotlinx/multik/api/CreateNDArray.kt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ public inline fun <reified T : Any> Multik.ones(
9797
): NDArray<T, DN> =
9898
ones(intArrayOf(dim1, dim2, dim3, dim4) + dims, DataType.ofKClass(T::class))
9999

100+
@Suppress("IMPLICIT_CAST_TO_ANY", "UNCHECKED_CAST")
100101
public fun <T, D : Dimension> Multik.ones(dims: IntArray, dtype: DataType): NDArray<T, D> {
101102
val dim = dimensionOf<D>(dims.size)
102103
requireDimension(dim, dims.size)
@@ -1561,6 +1562,7 @@ public inline fun <reified T : Number> List<List<List<List<T>>>>.toNDArray(): D4
15611562
public inline fun <reified T : Complex> List<List<List<List<T>>>>.toNDArray(): D4Array<T> = Multik.ndarray(this)
15621563

15631564
@PublishedApi
1565+
@Suppress("UNCHECKED_CAST")
15641566
internal fun <T> Iterable<T>.toCommonNDArray(dtype: DataType): D1Array<T> {
15651567
if (this is Collection<T>)
15661568
return ndarrayCommon(this, intArrayOf(this.size), D1, dtype) as D1Array<T>
@@ -1573,6 +1575,7 @@ internal fun <T> Iterable<T>.toCommonNDArray(dtype: DataType): D1Array<T> {
15731575
}
15741576

15751577
@PublishedApi
1578+
@Suppress("UNCHECKED_CAST")
15761579
internal inline fun <reified T: Any> Iterable<T>.toCommonNDArray(): D1Array<T> {
15771580
val dtype: DataType = DataType.ofKClass(T::class)
15781581
if (this is Collection<T>)
Lines changed: 228 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
/*
2+
* Copyright 2020-2022 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3+
*/
4+
5+
package org.jetbrains.kotlinx.multik.api
6+
7+
import org.jetbrains.kotlinx.multik.ndarray.data.*
8+
import kotlin.jvm.JvmName
9+
import kotlin.random.Random
10+
11+
/**
12+
* Returns a vector of the specified size filled with random numbers uniformly distributed for:
13+
* Int - [Int.MIN_VALUE, Int.MAX_VALUE)
14+
* Long - [Long.MIN_VALUE, Long.MAX_VALUE)
15+
* Float - [0f, 1f)
16+
* Double - [0.0, 1.0)
17+
*/
18+
public inline fun <reified T : Number> Multik.rand(dim0: Int): D1Array<T> {
19+
require(dim0 > 0) { "Dimension must be positive." }
20+
val dtype = DataType.ofKClass(T::class)
21+
val frand: () -> T = fRand(dtype)
22+
val data = initMemoryView(dim0, dtype) { frand() }
23+
return D1Array(data, shape = intArrayOf(dim0), dim = D1)
24+
}
25+
26+
/**
27+
* Returns a matrix of the specified shape filled with random numbers uniformly distributed for:
28+
* Int - [Int.MIN_VALUE, Int.MAX_VALUE)
29+
* Long - [Long.MIN_VALUE, Long.MAX_VALUE)
30+
* Float - [0f, 1f)
31+
* Double - [0.0, 1.0)
32+
*/
33+
public inline fun <reified T : Number> Multik.rand(dim0: Int, dim1: Int): D2Array<T> {
34+
val dtype = DataType.ofKClass(T::class)
35+
val shape = intArrayOf(dim0, dim1)
36+
for (i in shape.indices) {
37+
require(shape[i] > 0) { "Dimension $i must be positive." }
38+
}
39+
val frand: () -> T = fRand(dtype)
40+
val data = initMemoryView(dim0 * dim1, dtype) { frand() }
41+
return D2Array(data, shape = shape, dim = D2)
42+
}
43+
44+
/**
45+
* Returns an NDArray of the specified shape filled with random numbers uniformly distributed for:
46+
* Int - [Int.MIN_VALUE, Int.MAX_VALUE)
47+
* Long - [Long.MIN_VALUE, Long.MAX_VALUE)
48+
* Float - [0f, 1f)
49+
* Double - [0.0, 1.0)
50+
*/
51+
public inline fun <reified T : Number> Multik.rand(dim0: Int, dim1: Int, dim2: Int): D3Array<T> {
52+
val dtype = DataType.ofKClass(T::class)
53+
val shape = intArrayOf(dim0, dim1, dim2)
54+
for (i in shape.indices) {
55+
require(shape[i] > 0) { "Dimension $i must be positive." }
56+
}
57+
val frand: () -> T = fRand(dtype)
58+
val data = initMemoryView(dim0 * dim1 * dim2, dtype) { frand() }
59+
return D3Array(data, shape = shape, dim = D3)
60+
}
61+
62+
/**
63+
* Returns an NDArray of the specified shape filled with random numbers uniformly distributed for:
64+
* Int - [Int.MIN_VALUE, Int.MAX_VALUE)
65+
* Long - [Long.MIN_VALUE, Long.MAX_VALUE)
66+
* Float - [0f, 1f)
67+
* Double - [0.0, 1.0)
68+
*/
69+
public inline fun <reified T : Number> Multik.rand(dim0: Int, dim1: Int, dim2: Int, dim3: Int): D4Array<T> {
70+
val dtype = DataType.ofKClass(T::class)
71+
val shape = intArrayOf(dim0, dim1, dim2, dim3)
72+
for (i in shape.indices) {
73+
require(shape[i] > 0) { "Dimension $i must be positive." }
74+
}
75+
val frand: () -> T = fRand(dtype)
76+
val data = initMemoryView(dim0 * dim1 * dim2 * dim3, dtype) { frand() }
77+
return D4Array(data, shape = shape, dim = D4)
78+
}
79+
80+
/**
81+
* Returns an NDArray of the specified shape filled with random numbers uniformly distributed for:
82+
* Int - [Int.MIN_VALUE, Int.MAX_VALUE)
83+
* Long - [Long.MIN_VALUE, Long.MAX_VALUE)
84+
* Float - [0f, 1f)
85+
* Double - [0.0, 1.0)
86+
*/
87+
public inline fun <reified T : Number> Multik.rand(
88+
dim0: Int, dim1: Int, dim2: Int, dim3: Int, vararg dims: Int
89+
): NDArray<T, DN> {
90+
return rand(intArrayOf(dim0, dim1, dim2, dim3, *dims))
91+
}
92+
93+
/**
94+
* Returns an NDArray of the specified shape filled with random numbers uniformly distributed for:
95+
* Int - [Int.MIN_VALUE, Int.MAX_VALUE)
96+
* Long - [Long.MIN_VALUE, Long.MAX_VALUE)
97+
* Float - [0f, 1f)
98+
* Double - [0.0, 1.0)
99+
*/
100+
public inline fun <reified T : Number, reified D : Dimension> Multik.rand(shape: IntArray): NDArray<T, D> {
101+
val dtype = DataType.ofKClass(T::class)
102+
val dim = dimensionClassOf<D>(shape.size)
103+
requireDimension(dim, shape.size)
104+
for (i in shape.indices) {
105+
require(shape[i] > 0) { "Dimension $i must be positive." }
106+
}
107+
val size = shape.fold(1, Int::times)
108+
val frand: () -> T = fRand(dtype)
109+
val data = initMemoryView(size, dtype) { frand() }
110+
return NDArray(data, shape = shape, dim = dim)
111+
}
112+
113+
114+
/**
115+
* Returns an NDArray of the specified shape filled with number uniformly distributed between [[from], [until])
116+
*/
117+
@JvmName("randWithVarArg")
118+
public inline fun <reified T : Number, reified D : Dimension> Multik.rand(
119+
from: T, until: T, vararg dims: Int
120+
): NDArray<T, D> =
121+
Multik.rand(from, until, dims)
122+
123+
/**
124+
* Returns an NDArray of the specified shape filled with number uniformly distributed between [[from], [until])
125+
*
126+
* Note: Float generation is inefficient.
127+
*/
128+
@JvmName("randWithShape")
129+
public inline fun <reified T : Number, reified D : Dimension> Multik.rand(
130+
from: T, until: T, dims: IntArray
131+
): NDArray<T, D> {
132+
val dtype = DataType.ofKClass(T::class)
133+
val dim = dimensionClassOf<D>(dims.size)
134+
requireDimension(dim, dims.size)
135+
for (i in dims.indices) {
136+
require(dims[i] > 0) { "Dimension $i must be positive." }
137+
}
138+
val size = dims.fold(1, Int::times)
139+
val data = randData(from, until, size, dtype)
140+
return NDArray(data, shape = dims, dim = dim)
141+
}
142+
143+
/**
144+
* Returns an NDArray of the specified shape filled with number uniformly distributed between [[from], [until])
145+
* with the specified [seed].
146+
*
147+
* Note: Float generation is inefficient.
148+
*/
149+
@JvmName("randSeedVarArg")
150+
public inline fun <reified T : Number, reified D : Dimension> Multik.rand(
151+
seed: Int, from: T, until: T, vararg dims: Int
152+
): NDArray<T, D> = Multik.rand(Random(seed), from, until, dims)
153+
154+
/**
155+
* Returns an NDArray of the specified shape filled with number uniformly distributed between [[from], [until])
156+
* with the specified [seed].
157+
*
158+
* Note: Float generation is inefficient.
159+
*/
160+
@JvmName("randSeedShape")
161+
public inline fun <reified T : Number, reified D : Dimension> Multik.rand(
162+
seed: Int, from: T, until: T, dims: IntArray
163+
): NDArray<T, D> = Multik.rand(Random(seed), from, until, dims)
164+
165+
/**
166+
* Returns an NDArray of the specified shape filled with number uniformly distributed between [[from], [until])
167+
* with the specified [gen].
168+
*
169+
* Note: Float generation is inefficient.
170+
*/
171+
@JvmName("randGenVarArg")
172+
public inline fun <reified T : Number, reified D : Dimension> Multik.rand(
173+
gen: Random, from: T, until: T, vararg dims: Int
174+
): NDArray<T, D> = Multik.rand(gen, from, until, dims)
175+
176+
/**
177+
* Returns an NDArray of the specified shape filled with number uniformly distributed between [[from], [until])
178+
* with the specified [gen].
179+
*
180+
* Note: Float generation is inefficient.
181+
*/
182+
@JvmName("randGenShape")
183+
public inline fun <reified T : Number, reified D : Dimension> Multik.rand(
184+
gen: Random, from: T, until: T, dims: IntArray
185+
): NDArray<T, D> {
186+
val dtype = DataType.ofKClass(T::class)
187+
val dim = dimensionClassOf<D>(dims.size)
188+
requireDimension(dim, dims.size)
189+
for (i in dims.indices) {
190+
require(dims[i] > 0) { "Dimension $i must be positive." }
191+
}
192+
val size = dims.fold(1, Int::times)
193+
val data = randData(from, until, size, dtype, gen)
194+
return NDArray(data, shape = dims, dim = dim)
195+
}
196+
197+
@PublishedApi
198+
@Suppress("UNCHECKED_CAST")
199+
internal inline fun <T : Number> fRand(dtype: DataType): () -> T {
200+
return when (dtype) {
201+
DataType.IntDataType -> { { Random.nextInt() } }
202+
DataType.LongDataType -> { { Random.nextLong() } }
203+
DataType.FloatDataType -> { { Random.nextFloat() } }
204+
DataType.DoubleDataType -> { { Random.nextDouble() } }
205+
else -> throw UnsupportedOperationException("Other types are not currently supported")
206+
} as () -> T
207+
}
208+
209+
@PublishedApi
210+
@Suppress("UNCHECKED_CAST")
211+
internal inline fun <T : Number> randData(
212+
from: T, until: T, size: Int, dtype: DataType, gen: Random? = null
213+
): MemoryView<T> {
214+
var f = 0.0
215+
var u = 0.0
216+
val random = gen ?: Random.Default
217+
if (from is Float && until is Float) {
218+
f = from.toDouble()
219+
u = until.toDouble()
220+
}
221+
return when {
222+
from is Int && until is Int -> initMemoryView(size, dtype) { random.nextInt(from, until) }
223+
from is Long && until is Long -> initMemoryView(size, dtype) { random.nextLong(from, until) }
224+
from is Float && until is Float -> initMemoryView(size, dtype) { random.nextDouble(f, u).toFloat() }
225+
from is Double && until is Double -> initMemoryView(size, dtype) { random.nextDouble(from, until) }
226+
else -> throw UnsupportedOperationException()
227+
} as MemoryView<T>
228+
}

multik-core/src/commonMain/kotlin/org/jetbrains/kotlinx/multik/ndarray/operations/Inplace.kt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,7 @@ public class Abs<T : Number> : Exp<T>() {
328328
}
329329
}
330330

331+
@Suppress("IMPLICIT_CAST_TO_ANY", "UNCHECKED_CAST")
331332
private inline operator fun <T> Number.plus(other: T): T {
332333
return when {
333334
this is Double && other is Double -> this + other
@@ -340,6 +341,7 @@ private inline operator fun <T> Number.plus(other: T): T {
340341
} as T
341342
}
342343

344+
@Suppress("IMPLICIT_CAST_TO_ANY", "UNCHECKED_CAST")
343345
private inline operator fun <T> Complex.plus(other: T): T {
344346
return when {
345347
this is ComplexFloat && other is ComplexFloat -> this + other
@@ -348,6 +350,7 @@ private inline operator fun <T> Complex.plus(other: T): T {
348350
} as T
349351
}
350352

353+
@Suppress("IMPLICIT_CAST_TO_ANY", "UNCHECKED_CAST")
351354
private inline operator fun <T> Number.minus(other: T): T {
352355
return when {
353356
this is Double && other is Double -> this - other
@@ -360,6 +363,7 @@ private inline operator fun <T> Number.minus(other: T): T {
360363
} as T
361364
}
362365

366+
@Suppress("IMPLICIT_CAST_TO_ANY", "UNCHECKED_CAST")
363367
private inline operator fun <T> Complex.minus(other: T): T {
364368
return when {
365369
this is ComplexFloat && other is ComplexFloat -> this - other
@@ -368,6 +372,7 @@ private inline operator fun <T> Complex.minus(other: T): T {
368372
} as T
369373
}
370374

375+
@Suppress("IMPLICIT_CAST_TO_ANY", "UNCHECKED_CAST")
371376
private inline operator fun <T> Number.times(other: T): T {
372377
return when {
373378
this is Double && other is Double -> this * other
@@ -380,6 +385,7 @@ private inline operator fun <T> Number.times(other: T): T {
380385
} as T
381386
}
382387

388+
@Suppress("IMPLICIT_CAST_TO_ANY", "UNCHECKED_CAST")
383389
private inline operator fun <T> Complex.times(other: T): T {
384390
return when {
385391
this is ComplexFloat && other is ComplexFloat -> this * other
@@ -388,6 +394,7 @@ private inline operator fun <T> Complex.times(other: T): T {
388394
} as T
389395
}
390396

397+
@Suppress("IMPLICIT_CAST_TO_ANY", "UNCHECKED_CAST")
391398
private inline operator fun <T> Number.div(other: T): T {
392399
return when {
393400
this is Double && other is Double -> this / other
@@ -400,6 +407,7 @@ private inline operator fun <T> Number.div(other: T): T {
400407
} as T
401408
}
402409

410+
@Suppress("IMPLICIT_CAST_TO_ANY", "UNCHECKED_CAST")
403411
private inline operator fun <T> Complex.div(other: T): T {
404412
return when {
405413
this is ComplexFloat && other is ComplexFloat -> this / other

multik-core/src/commonMain/kotlin/org/jetbrains/kotlinx/multik/ndarray/operations/IteratingNDArray.kt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1748,7 +1748,6 @@ internal fun mapCapacity(size: Int): Int {
17481748
}
17491749
}
17501750

1751-
@Suppress("NOTHING_TO_INLINE")
17521751
@PublishedApi
17531752
internal inline fun checkIndexOverflow(index: Int): Int {
17541753
if (index < 0) throw ArithmeticException("Index overflow has happened.")

multik-core/src/commonTest/kotlin/org/jetbrains/kotlinx/multik/ndarray/complex/ComplexMultiArrayTest.kt

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ class ComplexMultiArrayTest {
1414
val complex = mk.d2arrayIndices(3, 3) { i, j -> ComplexFloat(i, j) }
1515
val real = complex.re
1616
val im = complex.im
17-
val expectedReal = mk.d2arrayIndices(3, 3) { i, j -> i.toFloat() }
18-
val expectedIm = mk.d2arrayIndices(3, 3) { i, j -> j.toFloat() }
17+
val expectedReal = mk.d2arrayIndices(3, 3) { i, _ -> i.toFloat() }
18+
val expectedIm = mk.d2arrayIndices(3, 3) { _, j -> j.toFloat() }
1919

2020
assertEquals(complex.shape, real.shape)
2121
assertEquals(complex.shape, im.shape)
@@ -28,8 +28,8 @@ class ComplexMultiArrayTest {
2828
val complex = mk.d2arrayIndices(3, 3) { i, j -> ComplexDouble(i, j) }
2929
val real = complex.re
3030
val im = complex.im
31-
val expectedReal = mk.d2arrayIndices(3, 3) { i, j -> i.toDouble() }
32-
val expectedIm = mk.d2arrayIndices(3, 3) { i, j -> j.toDouble() }
31+
val expectedReal = mk.d2arrayIndices(3, 3) { i, _ -> i.toDouble() }
32+
val expectedIm = mk.d2arrayIndices(3, 3) { _, j -> j.toDouble() }
3333

3434
assertEquals(complex.shape, real.shape)
3535
assertEquals(complex.shape, im.shape)

0 commit comments

Comments
 (0)