Skip to content

Commit db8b53e

Browse files
committed
fix slice test and offset for reshape
1 parent e0420a2 commit db8b53e

File tree

3 files changed

+31
-22
lines changed

3 files changed

+31
-22
lines changed

multik-core/src/commonMain/kotlin/org/jetbrains/kotlinx/multik/ndarray/data/NDArray.kt

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -109,12 +109,13 @@ public class NDArray<T, D : Dimension> constructor(
109109

110110
// TODO(get rid of copying)
111111
val newData = if (consistent) this.data else this.deepCopy().data
112-
val newBase = if (consistent) base ?: this else null
112+
val newBase = if (consistent) this.base ?: this else null
113+
val newOffset = if (consistent) this.offset else 0
113114

114115
return if (this.dim.d == 1 && this.shape.first() == dim1) {
115116
this as D1Array<T>
116117
} else {
117-
D1Array(newData, this.offset, intArrayOf(dim1), dim = D1, base = newBase)
118+
D1Array(newData, newOffset, intArrayOf(dim1), dim = D1, base = newBase)
118119
}
119120
}
120121

@@ -125,12 +126,13 @@ public class NDArray<T, D : Dimension> constructor(
125126

126127
// TODO(get rid of copying)
127128
val newData = if (consistent) this.data else this.deepCopy().data
128-
val newBase = if (consistent) base ?: this else null
129+
val newBase = if (consistent) this.base ?: this else null
130+
val newOffset = if (consistent) this.offset else 0
129131

130132
return if (this.shape.contentEquals(newShape)) {
131133
this as D2Array<T>
132134
} else {
133-
D2Array(newData, this.offset, newShape, dim = D2, base = newBase)
135+
D2Array(newData, newOffset, newShape, dim = D2, base = newBase)
134136
}
135137
}
136138

@@ -142,11 +144,12 @@ public class NDArray<T, D : Dimension> constructor(
142144
// TODO(get rid of copying)
143145
val newData = if (consistent) this.data else this.deepCopy().data
144146
val newBase = if (consistent) base ?: this else null
147+
val newOffset = if (consistent) this.offset else 0
145148

146149
return if (this.shape.contentEquals(newShape)) {
147150
this as D3Array<T>
148151
} else {
149-
D3Array(newData, this.offset, newShape, dim = D3, base = newBase)
152+
D3Array(newData, newOffset, newShape, dim = D3, base = newBase)
150153
}
151154
}
152155

@@ -157,12 +160,13 @@ public class NDArray<T, D : Dimension> constructor(
157160

158161
// TODO(get rid of copying)
159162
val newData = if (consistent) this.data else this.deepCopy().data
160-
val newBase = if (consistent) base ?: this else null
163+
val newBase = if (consistent) this.base ?: this else null
164+
val newOffset = if (consistent) this.offset else 0
161165

162166
return if (this.shape.contentEquals(newShape)) {
163167
this as D4Array<T>
164168
} else {
165-
D4Array(newData, this.offset, newShape, dim = D4, base = newBase)
169+
D4Array(newData, newOffset, newShape, dim = D4, base = newBase)
166170
}
167171
}
168172

@@ -175,12 +179,13 @@ public class NDArray<T, D : Dimension> constructor(
175179

176180
// TODO(get rid of copying)
177181
val newData = if (consistent) this.data else this.deepCopy().data
178-
val newBase = if (consistent) base ?: this else null
182+
val newBase = if (consistent) this.base ?: this else null
183+
val newOffset = if (consistent) this.offset else 0
179184

180185
return if (this.shape.contentEquals(newShape)) {
181186
this as NDArray<T, DN>
182187
} else {
183-
NDArray(newData, this.offset, newShape, dim = DN(newShape.size), base = newBase)
188+
NDArray(newData, newOffset, newShape, dim = DN(newShape.size), base = newBase)
184189
}
185190
}
186191

@@ -223,9 +228,10 @@ public class NDArray<T, D : Dimension> constructor(
223228
}
224229
// TODO(get rid of copying)
225230
val newData = if (consistent) this.data else this.deepCopy().data
226-
val newBase = if (consistent) base ?: this else null
231+
val newBase = if (consistent) this.base ?: this else null
232+
val newOffset = if (consistent) this.offset else 0
227233

228-
return NDArray(newData, this.offset, newShape.toIntArray(), dim = DN(newShape.size), base = newBase)
234+
return NDArray(newData, newOffset, newShape.toIntArray(), dim = DN(newShape.size), base = newBase)
229235
}
230236

231237
override infix fun cat(other: MultiArray<T, D>): NDArray<T, D> =

multik-core/src/commonMain/kotlin/org/jetbrains/kotlinx/multik/ndarray/operations/Transformation.kt

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -310,8 +310,9 @@ public fun <T> MultiArray<T, D1>.expandDims(axis: Int): MultiArray<T, D2> {
310310
val newShape = shape.toMutableList().apply { add(axis, 1) }.toIntArray()
311311
// TODO(get rid of copying)
312312
val newData = if (consistent) this.data else this.deepCopy().data
313-
val newBase = if (consistent) base ?: this else null
314-
return D2Array(newData, this.offset, newShape, dim = D2, base = newBase)
313+
val newBase = if (consistent) this.base ?: this else null
314+
val newOffset = if (consistent) this.offset else 0
315+
return D2Array(newData, newOffset, newShape, dim = D2, base = newBase)
315316
}
316317

317318
/**
@@ -322,8 +323,9 @@ public fun <T> MultiArray<T, D2>.expandDims(axis: Int): MultiArray<T, D3> {
322323
val newShape = shape.toMutableList().apply { add(axis, 1) }.toIntArray()
323324
// TODO(get rid of copying)
324325
val newData = if (consistent) this.data else this.deepCopy().data
325-
val newBase = if (consistent) base ?: this else null
326-
return D3Array(newData, this.offset, newShape, dim = D3, base = newBase)
326+
val newBase = if (consistent) this.base ?: this else null
327+
val newOffset = if (consistent) this.offset else 0
328+
return D3Array(newData, newOffset, newShape, dim = D3, base = newBase)
327329
}
328330

329331
/**
@@ -334,8 +336,9 @@ public fun <T> MultiArray<T, D3>.expandDims(axis: Int): MultiArray<T, D4> {
334336
val newShape = shape.toMutableList().apply { add(axis, 1) }.toIntArray()
335337
// TODO(get rid of copying)
336338
val newData = if (consistent) this.data else this.deepCopy().data
337-
val newBase = if (consistent) base ?: this else null
338-
return D4Array(newData, this.offset, newShape, dim = D4, base = newBase)
339+
val newBase = if (consistent) this.base ?: this else null
340+
val newOffset = if (consistent) this.offset else 0
341+
return D4Array(newData, newOffset, newShape, dim = D4, base = newBase)
339342
}
340343

341344
/**

multik-core/src/commonTest/kotlin/org/jetbrains/kotlinx/multik/ndarray/data/SliceTest.kt

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2020-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
2+
* Copyright 2020-2022 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
33
*/
44

55
package org.jetbrains.kotlinx.multik.ndarray.data
@@ -71,15 +71,15 @@ class SliceTest {
7171
val a2 = a.reshape(3, 2)
7272
val b2 = b.reshape(4, 1)
7373
assertSame(a, a2.base)
74-
assertSame(a, b2.base)
74+
assertSame(null, b2.base)
7575

7676
val d1 = b2.squeeze()
7777
val d2 = d1.unsqueeze()
78-
assertSame(a, d1.base)
79-
assertSame(a, d2.base)
78+
assertSame(b2, d1.base)
79+
assertSame(b2, d2.base)
8080

8181
val e = b2.transpose()
82-
assertSame(a, e.base)
82+
assertSame(b2, e.base)
8383

8484
val f = a2[1]
8585
assertSame(a, f.base)

0 commit comments

Comments
 (0)