Skip to content

Commit 29e6939

Browse files
committed
fix reshaping for slicing arrays
1 parent 3299370 commit 29e6939

File tree

3 files changed

+48
-21
lines changed

3 files changed

+48
-21
lines changed

gradle.properties

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
kotlin.code.style=official
2-
multik_version=0.2.0-dev-3
2+
multik_version=0.2.0-dev-4
33

44
# Kotlin
55
systemProp.kotlin_version=1.7.10

multik-core/src/commonMain/kotlin/org/jetbrains/kotlinx/multik/ndarray/data/NDArray.kt

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,14 @@ public class NDArray<T, D : Dimension> constructor(
107107
requirePositiveShape(dim1)
108108
require(dim1 == size) { "Cannot reshape array of size $size into a new shape ($dim1)" }
109109

110+
// TODO(get rid of copying)
111+
val newData = if (consistent) this.data else this.deepCopy().data
112+
val newBase = if (consistent) base ?: this else null
113+
110114
return if (this.dim.d == 1 && this.shape.first() == dim1) {
111115
this as D1Array<T>
112116
} else {
113-
D1Array(this.data, this.offset, intArrayOf(dim1), dim = D1, base = base ?: this)
117+
D1Array(newData, this.offset, intArrayOf(dim1), dim = D1, base = newBase)
114118
}
115119
}
116120

@@ -119,10 +123,14 @@ public class NDArray<T, D : Dimension> constructor(
119123
newShape.forEach { requirePositiveShape(it) }
120124
require(dim1 * dim2 == size) { "Cannot reshape array of size $size into a new shape ($dim1, $dim2)" }
121125

126+
// TODO(get rid of copying)
127+
val newData = if (consistent) this.data else this.deepCopy().data
128+
val newBase = if (consistent) base ?: this else null
129+
122130
return if (this.shape.contentEquals(newShape)) {
123131
this as D2Array<T>
124132
} else {
125-
D2Array(this.data, this.offset, newShape, dim = D2, base = base ?: this)
133+
D2Array(newData, this.offset, newShape, dim = D2, base = newBase)
126134
}
127135
}
128136

@@ -131,10 +139,14 @@ public class NDArray<T, D : Dimension> constructor(
131139
newShape.forEach { requirePositiveShape(it) }
132140
require(dim1 * dim2 * dim3 == size) { "Cannot reshape array of size $size into a new shape ($dim1, $dim2, $dim3)" }
133141

142+
// TODO(get rid of copying)
143+
val newData = if (consistent) this.data else this.deepCopy().data
144+
val newBase = if (consistent) base ?: this else null
145+
134146
return if (this.shape.contentEquals(newShape)) {
135147
this as D3Array<T>
136148
} else {
137-
D3Array(this.data, this.offset, newShape, dim = D3, base = base ?: this)
149+
D3Array(newData, this.offset, newShape, dim = D3, base = newBase)
138150
}
139151
}
140152

@@ -143,10 +155,14 @@ public class NDArray<T, D : Dimension> constructor(
143155
newShape.forEach { requirePositiveShape(it) }
144156
require(dim1 * dim2 * dim3 * dim4 == size) { "Cannot reshape array of size $size into a new shape ($dim1, $dim2, $dim3, $dim4)" }
145157

158+
// TODO(get rid of copying)
159+
val newData = if (consistent) this.data else this.deepCopy().data
160+
val newBase = if (consistent) base ?: this else null
161+
146162
return if (this.shape.contentEquals(newShape)) {
147163
this as D4Array<T>
148164
} else {
149-
D4Array(this.data, this.offset, newShape, dim = D4, base = base ?: this)
165+
D4Array(newData, this.offset, newShape, dim = D4, base = newBase)
150166
}
151167
}
152168

@@ -157,10 +173,14 @@ public class NDArray<T, D : Dimension> constructor(
157173
"Cannot reshape array of size $size into a new shape ${newShape.joinToString(prefix = "(", postfix = ")")}"
158174
}
159175

176+
// TODO(get rid of copying)
177+
val newData = if (consistent) this.data else this.deepCopy().data
178+
val newBase = if (consistent) base ?: this else null
179+
160180
return if (this.shape.contentEquals(newShape)) {
161181
this as NDArray<T, DN>
162182
} else {
163-
NDArray(this.data, this.offset, newShape, dim = DN(newShape.size), base = base ?: this)
183+
NDArray(newData, this.offset, newShape, dim = DN(newShape.size), base = newBase)
164184
}
165185
}
166186

@@ -201,13 +221,11 @@ public class NDArray<T, D : Dimension> constructor(
201221
for (axis in axes.sorted()) {
202222
newShape.add(axis, 1)
203223
}
204-
return NDArray(
205-
this.data,
206-
this.offset,
207-
newShape.toIntArray(),
208-
dim = DN(newShape.size),
209-
base = base ?: this
210-
)
224+
// TODO(get rid of copying)
225+
val newData = if (consistent) this.data else this.deepCopy().data
226+
val newBase = if (consistent) base ?: this else null
227+
228+
return NDArray(newData, this.offset, newShape.toIntArray(), dim = DN(newShape.size), base = newBase)
211229
}
212230

213231
override infix fun cat(other: MultiArray<T, D>): NDArray<T, D> =

multik-core/src/commonMain/kotlin/org/jetbrains/kotlinx/multik/ndarray/operations/Transformation.kt

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -306,39 +306,48 @@ public fun <T, D : Dimension> MultiArray<T, D>.clip(min: T, max: T): NDArray<T,
306306
* Returns a ndarray with an expanded shape.
307307
*/
308308
@JvmName("expandDimsD1")
309-
public fun <T, D : Dimension> MultiArray<T, D1>.expandDims(axis: Int): MultiArray<T, D2> {
309+
public fun <T> MultiArray<T, D1>.expandDims(axis: Int): MultiArray<T, D2> {
310310
val newShape = shape.toMutableList().apply { add(axis, 1) }.toIntArray()
311-
return D2Array(this.data, this.offset, newShape, dim = D2, base = base ?: this)
311+
// TODO(get rid of copying)
312+
val newData = if (consistent) this.data else this.deepCopy().data
313+
val newBase = if (consistent) base ?: this else null
314+
return D2Array(newData, this.offset, newShape, dim = D2, base = newBase)
312315
}
313316

314317
/**
315318
* Returns a ndarray with an expanded shape.
316319
*/
317320
@JvmName("expandDimsD2")
318-
public fun <T, D : Dimension> MultiArray<T, D2>.expandDims(axis: Int): MultiArray<T, D3> {
321+
public fun <T> MultiArray<T, D2>.expandDims(axis: Int): MultiArray<T, D3> {
319322
val newShape = shape.toMutableList().apply { add(axis, 1) }.toIntArray()
320-
return D3Array(this.data, this.offset, newShape, dim = D3, base = base ?: this)
323+
// TODO(get rid of copying)
324+
val newData = if (consistent) this.data else this.deepCopy().data
325+
val newBase = if (consistent) base ?: this else null
326+
return D3Array(newData, this.offset, newShape, dim = D3, base = newBase)
321327
}
322328

323329
/**
324330
* Returns a ndarray with an expanded shape.
325331
*/
326332
@JvmName("expandDimsD3")
327-
public fun <T, D : Dimension> MultiArray<T, D3>.expandDims(axis: Int): MultiArray<T, D4> {
333+
public fun <T> MultiArray<T, D3>.expandDims(axis: Int): MultiArray<T, D4> {
328334
val newShape = shape.toMutableList().apply { add(axis, 1) }.toIntArray()
329-
return D4Array(this.data, this.offset, newShape, dim = D4, base = base ?: this)
335+
// TODO(get rid of copying)
336+
val newData = if (consistent) this.data else this.deepCopy().data
337+
val newBase = if (consistent) base ?: this else null
338+
return D4Array(newData, this.offset, newShape, dim = D4, base = newBase)
330339
}
331340

332341
/**
333342
* Returns a ndarray with an expanded shape.
334343
*/
335344
@JvmName("expandDimsD4")
336-
public fun <T, D : Dimension> MultiArray<T, D4>.expandDims(axis: Int): MultiArray<T, DN> = this.unsqueeze()
345+
public fun <T> MultiArray<T, D4>.expandDims(axis: Int): MultiArray<T, DN> = this.unsqueeze()
337346

338347
/**
339348
* Returns a ndarray with an expanded shape.
340349
*
341350
* @see MultiArray.unsqueeze
342351
*/
343352
@JvmName("expandDimsDN")
344-
public fun <T, D : Dimension> MultiArray<T, D>.expandDims(vararg axes: Int): MultiArray<T, DN> = this.unsqueeze()
353+
public fun <T, D : Dimension> MultiArray<T, D>.expandNDims(vararg axes: Int): MultiArray<T, DN> = this.unsqueeze()

0 commit comments

Comments
 (0)