5
5
6
6
Return `x` reshaped into an array one dimensionality higher than `x`,
7
7
where `dims` indicates in which dimension `x` is extended.
8
+ `dims` can be an integer between 1 and `ndims(x)+1`.
8
9
9
10
See also [`flatten`](@ref), [`stack`](@ref).
10
11
@@ -33,8 +34,9 @@ julia> unsqueeze(xs, dims=1)
33
34
[1, 2] [3, 4] [5, 6]
34
35
```
35
36
"""
36
- function unsqueeze (x:: AbstractArray ; dims:: Int )
37
- sz = ntuple (i -> i < dims ? size (x, i) : i == dims ? 1 : size (x, i - 1 ), ndims (x) + 1 )
37
+ function unsqueeze (x:: AbstractArray{T,N} ; dims:: Int ) where {T, N}
38
+ @assert 1 <= dims <= N + 1
39
+ sz = ntuple (i -> i < dims ? size (x, i) : i == dims ? 1 : size (x, i - 1 ), N + 1 )
38
40
return reshape (x, sz)
39
41
end
40
42
@@ -55,51 +57,6 @@ _unsqueeze(x, dims) = unsqueeze(x; dims)
55
57
56
58
Base. show_function (io:: IO , u:: Base.Fix2{typeof(_unsqueeze)} , :: Bool ) = print (io, " unsqueeze(dims=" , u. x, " )" )
57
59
58
- """
59
- stack(xs; dims)
60
-
61
- Concatenate the given array of arrays `xs` into a single array along the
62
- given dimension `dims`.
63
-
64
- See also [`stack`](@ref) and [`batch`](@ref).
65
-
66
- # Examples
67
-
68
- ```jldoctest
69
- julia> xs = [[1, 2], [3, 4], [5, 6]]
70
- 3-element Vector{Vector{Int64}}:
71
- [1, 2]
72
- [3, 4]
73
- [5, 6]
74
-
75
- julia> stack(xs, dims=1)
76
- 3×2 Matrix{Int64}:
77
- 1 2
78
- 3 4
79
- 5 6
80
-
81
- julia> stack(xs, dims=2)
82
- 2×3 Matrix{Int64}:
83
- 1 3 5
84
- 2 4 6
85
-
86
- julia> stack(xs, dims=3)
87
- 2×1×3 Array{Int64, 3}:
88
- [:, :, 1] =
89
- 1
90
- 2
91
-
92
- [:, :, 2] =
93
- 3
94
- 4
95
-
96
- [:, :, 3] =
97
- 5
98
- 6
99
- ```
100
- """
101
- stack (xs; dims:: Int ) = cat (unsqueeze .(xs; dims)... ; dims)
102
-
103
60
"""
104
61
unstack(xs; dims)
105
62
329
286
330
287
batchindex (xs, i) = (reverse (Base. tail (reverse (axes (xs))))... , i)
331
288
332
- function batch (xs:: AbstractArray{<:AbstractArray} )
333
- # Don't use stack(xs, dims=N+1), it is much slower.
334
- # Here we do reduce(vcat, xs) along with some reshapes.
335
- szxs = size (xs)
336
- @assert length (xs) > 0 " Minimum batch size is 1."
337
- szx = size (xs[1 ])
338
- @assert all (x -> size (x) == szx, xs) " All arrays must be of the same size."
339
- vxs = vec (vec .(xs))
340
- y = reduce (vcat, vxs)
341
- return reshape (y, szx... , szxs... )
342
- end
289
+ batch (xs:: AbstractArray{<:AbstractArray} ) = stack (xs)
343
290
344
291
function batch (xs:: Vector{<:Tuple} )
345
292
@assert length (xs) > 0 " Input should be non-empty"
0 commit comments