Skip to content

Commit 7c7f42f

Browse files
committed
Optimize code for composition
1 parent 468f927 commit 7c7f42f

File tree

5 files changed

+47
-7
lines changed

5 files changed

+47
-7
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,15 @@ Evolutionary = "86b6b26d-c046-49b6-aa0b-5f0f74682bd6"
1010
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
1111
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1212
ThreadPools = "b189fb0b-2eb5-4ed4-bc0c-d34c51242431"
13+
Unrolled = "9602ed7d-8fef-5bc8-8597-8f21381861e8"
1314

1415
[compat]
1516
ConstraintDomains = "0.2"
1617
Dictionaries = "0.3"
1718
Evolutionary = "0.9"
1819
OrderedCollections = "1"
1920
ThreadPools = "2"
21+
Unrolled = "0.1"
2022
julia = "1.6"
2123

2224
[extras]

src/CompositionalNetworks.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ using Evolutionary
77
using OrderedCollections
88
using Random
99
using ThreadPools
10+
using Unrolled
1011

1112
# Exports utilities
1213
export hamming

src/icn.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,19 @@ function _compose(icn::ICN)
123123
end
124124

125125
l = length(funcs[1])
126-
composition = (x; param=nothing, dom_size) -> fill(x, l) .|> map(f -> (y -> f(y; param=param)), funcs[1]) |> funcs[2][1] |> funcs[3][1] |> (y -> funcs[4][1](y; param=param, dom_size=dom_size, nvars=length(x)))
126+
127+
composition = (x; X=zeros(length(x), l), param=nothing, dom_size) -> if l == 1
128+
x |> (y -> funcs[1][1](y; param)) |> funcs[3][1] |>
129+
(y -> funcs[4][1](y; param, dom_size, nvars=length(x)))
130+
else
131+
fill!(@view(X[1:length(x), :]), 0.0)
132+
tr_in(Tuple(funcs[1]), X, x, param)
133+
for i in 1:length(x)
134+
X[i,1] = funcs[2][1](@view X[i,:])
135+
end
136+
funcs[3][1](@view X[:, 1]) |> (y -> funcs[4][1](y; param, dom_size, nvars=length(x)))
137+
end
138+
127139
return composition, symbols
128140
end
129141

src/learn.jl

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -111,13 +111,24 @@ function compose_to_string(symbols, name)
111111
ag = reduce_symbols(symbols[3], ", ", false; prefix=CN * "ag_")
112112
co = reduce_symbols(symbols[4], ", ", false; prefix=CN * "co_")
113113

114-
julia_string = """
115-
function $name(x; param=nothing, dom_size)
116-
fill(x, $tr_length) .|> map(f -> (y -> f(y; param=param)), $tr) |> $ar |> $ag |> (y -> $co(y; param=param, dom_size=dom_size, nvars=length(x)))
114+
return if tr_length == 1
115+
"""
116+
function $name(x; X = zeros(length(x), $tr_length), param=nothing, dom_size)
117+
x |> (y -> $tr[1](y; param)) |> $ag |> (y -> $co(y; param, dom_size, nvars=length(x)))
118+
end
119+
"""
120+
else
121+
"""
122+
function $name(x; X = zeros(length(x), $tr_length), param=nothing, dom_size)
123+
fill!(@view(X[1:length(x), :]), 0.0)
124+
$(CN)tr_in(Tuple($tr), X, x, param)
125+
for i in 1:length(x)
126+
X[i,1] = $ar(@view X[i,:])
127+
end
128+
return $ag(@view X[:, 1]) |> (y -> $co(y; param, dom_size, nvars=length(x)))
129+
end
130+
"""
117131
end
118-
"""
119-
120-
return julia_string
121132
end
122133

123134
"""

src/utils.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,17 @@ end
6161
function incsert!(d::Dictionary, ind)
6262
set!(d, ind, isassigned(d, ind) ? d[ind] + 1 : 1)
6363
end
64+
65+
@unroll function tr_in(tr, X, x, param)
66+
@unroll for i in 1:length(tr)
67+
X[:,i] = tr[i](x; param)
68+
end
69+
end
70+
71+
# TODO: look for a length limit that make it slow or space-comsuming
72+
# TODO: handle SMatrix
73+
# @unroll function ar_in(ar, X, x)
74+
# @unroll for i in 1:length(x)
75+
# X[i, 1] = ar(@view X[i, :])
76+
# end
77+
# end

0 commit comments

Comments
 (0)