Skip to content

Commit 6f23e31

Browse files
authored
Allow repeat-like patterns (#51)
* allow broadcast over indices not present on RHS * partial fix for function on left, e.g. rand(3)[i] = i * a test * fix tests, docs
1 parent df7e8be commit 6f23e31

File tree

9 files changed

+83
-38
lines changed

9 files changed

+83
-38
lines changed

docs/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
33
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
44
Einsum = "b7d42ee7-0b51-5a75-98ca-779d3107e4c0"
5+
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
56
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
67
Strided = "5e0ebb24-38b0-5f93-81fe-25c709ecae67"
78

docs/make.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11

22
using Documenter
3-
using TensorCast
3+
using TensorCast, OffsetArrays
44

55
makedocs(
66
sitename = "TensorCast",
7-
modules = [TensorCast],
7+
modules = [TensorCast, OffsetArrays],
88
pages = [
99
"Home" => "index.md",
1010
"Basics" => "basics.md",

docs/src/basics.md

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ Combining this with slicing from `M[:,j]` is a convenient way to perform `mapsli
8383

8484
```jldoctest mylabel
8585
julia> @cast C[i,j] := cumsum(M[:,j])[i]
86-
3×4 stack(::Vector{Vector{Int64}}) with eltype Int64:
86+
3×4 lazystack(::Vector{Vector{Int64}}) with eltype Int64:
8787
1 4 7 10
8888
3 9 15 21
8989
6 15 24 33
@@ -205,28 +205,23 @@ as `size(A)` is known:
205205
julia> @cast A[i,j] = 10 * collect(1:12)[i⊗j];
206206
```
207207

208+
## Repeating
209+
208210
If the right hand side is independent of an index, then the same result is repeated.
209211
The range of the index must still be known:
210212

211213
```jldoctest mylabel
212214
julia> @cast R[r,(n,c)] := M[r,c]^2 (n in 1:3)
213-
ERROR: LoadError: index n appears only on the left
214-
@cast R[r, (n, c)] := M[r, c] ^ 2 n in 1:3
215-
@ Main none:1
216-
Stacktrace:
217-
[1] checkallseen(canon::Vector{Any}, store::NamedTuple{(:dict, :assert, :mustassert, :seen, :need, :top, :main), Tuple{Dict{Any, Any}, Vector{Any}, Vector{Any}, Vector{Any}, Vector{Any}, Vector{Any}, Vector{Any}}}, call::TensorCast.CallInfo)
218-
@ TensorCast ~/.julia/dev/TensorCast/src/macro.jl:1466
219-
[2] _macro(exone::Expr, extwo::Expr, exthree::Nothing; call::TensorCast.CallInfo, dict::Dict{Any, Any})
220-
@ TensorCast ~/.julia/dev/TensorCast/src/macro.jl:199
221-
[3] var"@cast"(__source__::LineNumberNode, __module__::Module, exs::Vararg{Any})
222-
@ TensorCast ~/.julia/dev/TensorCast/src/macro.jl:74
223-
in expression starting at none:1
215+
3×12 Matrix{Int64}:
216+
1 1 1 16 16 16 49 49 49 100 100 100
217+
4 4 4 25 25 25 64 64 64 121 121 121
218+
9 9 9 36 36 36 81 81 81 144 144 144
224219
225220
julia> R == repeat(M .^ 2, inner=(1,3))
226221
true
227222
228-
julia> @cast R[r,(c,n)] = M[r,c] # repeat(M, outer=(1,3)), uses size(R)
229-
3×12 Array{Int64,2}:
223+
julia> @cast similar(R)[r,(c,n)] = M[r,c] # repeat(M, outer=(1,3)), uses size(R)
224+
3×12 Matrix{Int64}:
230225
1 4 7 10 1 4 7 10 1 4 7 10
231226
2 5 8 11 2 5 8 11 2 5 8 11
232227
3 6 9 12 3 6 9 12 3 6 9 12

docs/src/index.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ New features in 0.4:
2323
- Indices can appear outside of indexing: `@cast A[i,j] = i+j` translates to `A .= axes(A,1) .+ axes(A,2)'`
2424
- The ternary operator `? :` can appear on the right, and will be broadcast correctly.
2525
- All operations should now support [OffsetArrays.jl](https://github.com/JuliaArrays/OffsetArrays.jl).
26+
- You can `repeat` by broadcasting over indices not appearing on the right, such as `@cast r[i,(k,j)] = m[i,j]`
2627

2728
## Pages
2829

docs/src/options.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,12 @@ M = rand(1:99, 3,4)
2222
@cast S[k][i] := M[i,k] lazy=false # the same
2323
```
2424

25-
The default way of un-slicing is `reduce(hcat, ...)`, which creates a new array.
26-
But there are other options, controlled by keywords after the expression:
25+
The default way of un-slicing uses [LazyStack.jl](https://github.com/mcabbott/LazyStack.jl) to create a view.
26+
The keyword `lazy=false` after the expression will turn this off, to make a solid array using Base's code:
2727

2828
```julia
29-
@cast A[i,k] := S[k][i] lazy=false # A = reduce(hcat, B)
3029
@cast A[i,k] := S[k][i] # A = LazyStack.stack(B)
30+
@cast A[i,k] := S[k][i] lazy=false # A = reduce(hcat, B)
3131

3232
size(A) == (3, 4) # true
3333
```

src/TensorCast.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ module Fast # shield non-macro code from @optlevel 1
3232
export sliceview, slicecopy, copy_glue, glue!, iscodesorted, countcolons
3333

3434
include("view.jl")
35-
export diagview, mul!, rview, star
35+
export diagview, mul!, rview, star, onlyfirst
3636

3737
include("static.jl")
3838
export static_slice, static_glue

src/macro.jl

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ function _macro(exone, extwo=nothing, exthree=nothing; call::CallInfo=CallInfo()
187187

188188
# Third pass to standardise & then glue, postwalk sees A[i] before A[i][j]
189189
right3 = MacroTools.postwalk(x -> standardglue(x, canon, store, call), right2)
190+
right3 = checkallseen(right3, canon, store, call)
190191

191192
if !(:matmul in call.flags)
192193
# Then finally broadcasting if necc (or just permutedims etc. if not):
@@ -196,8 +197,6 @@ function _macro(exone, extwo=nothing, exthree=nothing; call::CallInfo=CallInfo()
196197
right4 = matmultarget(right3, canon, parsed, store, call)
197198
end
198199

199-
checkallseen(canon, store, call) # this must be run before inplaceoutput()
200-
201200
# Return to LHS, build up what it requested:
202201
if :inplace in call.flags
203202
rightlist = inplaceoutput(right4, canon, parsed, store, call)
@@ -1344,6 +1343,25 @@ end
13441343
# return sortperm(xi) == sortperm(yi)
13451344
# end
13461345

1346+
"""
1347+
checkallseen(rhs, ...)
1348+
1349+
Now not just a check, but also inserts trivial broadcasting, if needed,
1350+
over indices omitted from the rhs.
1351+
"""
1352+
function checkallseen(ex, canon, store, call)
1353+
right = setdiff(store.seen, canon)
1354+
length(right) > 0 && throw(MacroError("index $(right[1]) appears only on the right", call))
1355+
left = setdiff(canon, unique!(store.seen))
1356+
# length(left) > 0 && throw(MacroError("index $(left[1]) appears only on the left", call))
1357+
if isempty(left)
1358+
ex
1359+
else
1360+
fake = map(i -> recursemacro(i, canon, store, call), left)
1361+
:( TensorCast.onlyfirst($ex, $(fake...)) )
1362+
end
1363+
end
1364+
13471365
"""
13481366
needview!([:, 3, A]) # true, need view(A, :,3,:)
13491367
needview!([:, :_, :]) # false, can use rview(A, :,1,:)
@@ -1467,14 +1485,6 @@ function checknorepeats(flat, call=nothing, msg=nothing)
14671485
end
14681486
end
14691487

1470-
function checkallseen(canon, store, call)
1471-
left = setdiff(canon, unique!(store.seen))
1472-
length(left) > 0 && throw(MacroError("index $(left[1]) appears only on the left", call))
1473-
right = setdiff(store.seen, canon)
1474-
length(right) > 0 && throw(MacroError("index $(right[1]) appears only on the right", call))
1475-
end
1476-
1477-
# this may never be necessary with checkallseen?
14781488
function findcheck(i::Symbol, flat::Vector, call=nothing, msg=nothing)
14791489
msg == nothing && (msg = " in " * string(:( [$(flat...)] )))
14801490
res = findfirst(isequal(i), flat)
@@ -1590,11 +1600,8 @@ function inplaceoutput(ex, canon, parsed, store::NamedTuple, call::CallInfo)
15901600
newleft = standardise(parsed.left, store, call)
15911601
@capture(newleft, zed_[ijk__]) || throw(MacroError("failed to parse LHS correctly, $(parsed.left) -> $newleft"))
15921602

1593-
if !(zed isa Symbol) # then standardise did something!
1603+
if newleft != parsed.left # then standardise did something!
15941604
push!(call.flags, :showfinal)
1595-
Zsym = gensym(:reverse)
1596-
push!(out, :( local $Zsym = $zed ) )
1597-
zed = Zsym
15981605
end
15991606
end
16001607

@@ -1619,7 +1626,12 @@ function inplaceoutput(ex, canon, parsed, store::NamedTuple, call::CallInfo)
16191626
end
16201627

16211628
if :showfinal in call.flags
1622-
push!(out, parsed.name)
1629+
A = parsed.name
1630+
if A isa Symbol || @capture(A, AA_.ff_)
1631+
else
1632+
A = Symbol(A,"_val") # exact same symbol is used by standardise()
1633+
end
1634+
push!(out, A)
16231635
end
16241636

16251637
return out

src/view.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,15 @@ mul!(Z::AbstractArray{T,0}, A,B) where {T} = copyto!(Z, A * B)
2828
"""
2929
star(x,y,...)
3030
31-
Used for multiplying axes now, not sizes.
31+
Used for multiplying axes now, always producing a `OneTo` whose length
32+
is the product of the given ranges.
3233
"""
3334
star(x, y) = Base.OneTo(length(x) * length(y))
3435
star(x,y,zs...) = star(star(x,y), zs...)
36+
37+
"""
38+
onlyfirst(x, ys...) = x
39+
40+
Used with arguments which are there just to set the shape of broadcasting.
41+
"""
42+
onlyfirst(x, ys...) = x

test/four.jl

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,26 @@ end
8181
@test E ≈ sum((1:2) ./ (1:4)')
8282
end
8383

84+
@testset "repeats" begin
85+
@test [1,1,1] == @cast _[i] := 1 (i in 1:3)
86+
@test [1,1,1] == @cast rand(3)[i] = 1
87+
88+
M = reshape(1:12, 3,4)
89+
repeat(M, inner=(1,3)) == @cast _[i, (r,j)] := M[i,j] r in 1:3
90+
repeat(M, outer=(1,3)) == @cast _[i, (j,r)] := M[i,j] r in 1:3
91+
92+
@cast T[i,r,j] := M[i,j] r in 1:3
93+
@test T[:,3,:] == M
94+
95+
@test 100 .+ sum(M, dims=2) == @cast _[i,k] := sum(M[i,:])+100 k in 1:1
96+
@test 4 .+ sum(M, dims=2) == @reduce _[i,k] := sum(j) M[i,j]+1 k in 1:1
97+
@test repeat(sum(M, dims=2), inner=(1,2)) == @reduce _[i,k] := sum(j) M[i,j] k in 1:2
98+
@test repeat(M, inner=(1,2)) == @cast rand(3,8)[i,(k,j)] = M[i,j]
99+
100+
@test_throws Exception @macroexpand @cast _[i] := A[i,j] # j doesn't appear on left
101+
# @test_throws Exception @macroexpand @cast _[i] := i+j (i in 1:2, j in 1:3)
102+
end
103+
84104
@testset "offset handling" begin
85105
using OffsetArrays
86106
@cast A[i,j] := i+10j (i in 0:1, j in 7:15)
@@ -91,6 +111,13 @@ end
91111
@cast B[(i,k),j] := A[i,(j,k)] k in 1:3
92112
@test axes(B) === (Base.OneTo(6), Base.OneTo(3))
93113

114+
@cast E[(j,i),k] := A[i,j]^2 (k in 10:11) # RHS indep k
115+
@test E[:,11] == vec(A') .^ 2
116+
117+
# dropdims
118+
@cast F[i,k,j] := A[i,j] k in 5:5 # a trivial dimension not indexed from 1
119+
@test A == @cast _[i,j] := F[i,_,j]
120+
94121
# reduction
95122
@reduce C[_,j] := sum(i) A[i,j]
96123
@test axes(C) == (1:1, 7:15)
@@ -156,8 +183,9 @@ end
156183
@test y == @cast _[i,j,k] := tuple(y[i,:,k]...)[j]
157184
end
158185

186+
using TensorCast: MacroError, _macro, CallInfo
159187
@testset "bugs" begin
160188
# scatter not handled, previously ignored -- https://github.com/mcabbott/TensorCast.jl/issues/49
161-
@test_throws TensorCast.MacroError @pretty @reduce v[ixs[j]] := mean(i) vs[i][j]
162-
@test_throws TensorCast.MacroError @pretty @cast v[i][ixs[j]] := vs[j][i]
189+
@test_throws MacroError _macro(:( v[ixs[j]] := mean(i) ),:( vs[i][j] ), call=CallInfo(:reduce))
190+
@test_throws MacroError _macro(:( v[i][ixs[j]] := vs[j][i] ), call=CallInfo())
163191
end

0 commit comments

Comments
 (0)