Skip to content

Commit 5e801e2

Browse files
authored
Merge pull request #1 from FluxML/master
Rebase
2 parents 418b316 + cdb445c commit 5e801e2

File tree

2 files changed

+19
-3
lines changed

2 files changed

+19
-3
lines changed

src/utils.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ function kaiming_uniform(rng::AbstractRNG, dims...; gain = √2)
135135
end
136136

137137
kaiming_uniform(dims...; kwargs...) = kaiming_uniform(Random.GLOBAL_RNG, dims...; kwargs...)
138-
kaiming_uniform(rng::AbstractRNG; kwargs...) = (dims...; kwargs...) -> kaiming_uniform(rng, dims...; kwargs...)
138+
kaiming_uniform(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> kaiming_uniform(rng, dims...; init_kwargs..., kwargs...)
139139

140140
"""
141141
kaiming_normal([rng=GLOBAL_RNG], dims...; gain = √2)
@@ -172,7 +172,7 @@ function kaiming_normal(rng::AbstractRNG, dims...; gain = √2f0)
172172
end
173173

174174
kaiming_normal(dims...; kwargs...) = kaiming_normal(Random.GLOBAL_RNG, dims...; kwargs...)
175-
kaiming_normal(rng::AbstractRNG; kwargs...) = (dims...; kwargs...) -> kaiming_normal(rng, dims...; kwargs...)
175+
kaiming_normal(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> kaiming_normal(rng, dims...; init_kwargs..., kwargs...)
176176

177177
"""
178178
orthogonal([rng=GLOBAL_RNG], dims...; gain = 1)
@@ -275,7 +275,7 @@ function sparse_init(rng::AbstractRNG, dims...; sparsity, std = 0.01)
275275
end
276276

277277
sparse_init(dims...; kwargs...) = sparse_init(Random.GLOBAL_RNG, dims...; kwargs...)
278-
sparse_init(rng::AbstractRNG; kwargs...) = (dims...; kwargs...) -> sparse_init(rng, dims...; kwargs...)
278+
sparse_init(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> sparse_init(rng, dims...; init_kwargs..., kwargs...)
279279

280280
ones(T::Type, dims...) = Base.ones(T, dims...)
281281
zeros(T::Type, dims...) = Base.zeros(T, dims...)

test/utils.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,22 @@ end
134134
@test eltype(v) == Float32
135135
end
136136
end
137+
138+
@testset "partial_application" begin
139+
big = 1e9
140+
141+
partial_ku = kaiming_uniform(gain=big)
142+
@test maximum(partial_ku(8, 8)) > big / 2
143+
@test maximum(partial_ku(8, 8, gain=1)) < big / 2
144+
145+
partial_kn = kaiming_normal(gain=big)
146+
@test maximum(partial_kn(8, 8)) > big / 2
147+
@test maximum(partial_kn(8, 8, gain=1)) < big / 2
148+
149+
partial_si = sparse_init(sparsity=1)
150+
@test maximum(partial_si(8, 8)) == 0
151+
@test maximum(partial_si(8, 8, sparsity=0)) > 0
152+
end
137153
end
138154

139155
@testset "Params" begin

0 commit comments

Comments
 (0)