Skip to content

Commit 9755d57

Browse files
Authored: Merge pull request #1881 from theabhirath/diagonal
2 parents 5e05de7 + 0b06bcb commit 9755d57

File tree

2 files changed

+24
-14
lines changed

2 files changed

+24
-14
lines changed

src/layers/basic.jl

Lines changed: 20 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -172,31 +172,40 @@ function Base.show(io::IO, l::Dense)
172172
end
173173

174174
"""
175-
Diagonal(α, β)
176-
Diagonal(size::Integer...)
175+
Diagonal(size::Integer...; bias=true, init=ones32)
176+
Diagonal(scale::AbstractArray, [bias])
177177
178178
Create an element-wise linear layer, which performs
179179
180-
y = α .* x .+ β
180+
y = scale .* x .+ bias
181181
182-
The learnable arrays are initialised `α = ones(Float32, size)` and
183-
`β = zeros(Float32, size)`.
182+
with no activation function.
183+
184+
The learnable scale & bias are initialised `init(size...)` and `zeros32(size...)`,
185+
with `init=ones32` by default. You may specify the function `init`,
186+
turn off trainable bias with `bias=false`, or provide the array(s) explicitly.
184187
185188
Used by [`LayerNorm`](@ref).
186189
"""
187-
struct Diagonal{T}
188-
α::T
189-
β::T
190+
struct Diagonal{A<:AbstractArray, B}
191+
scale::A
192+
bias::B
193+
function Diagonal(W::M, bias = true) where M<:AbstractArray
194+
b = create_bias(W, bias, size(W)...)
195+
new{M, typeof(b)}(W, b)
196+
end
190197
end
191198

192-
Diagonal(sz::Integer...) = Diagonal(ones32(sz...), zeros32(sz...))
199+
Diagonal(sz::Integer...; bias = true, init = ones32) = Diagonal(init(sz...), bias)
193200

194201
@functor Diagonal
195202

196-
(a::Diagonal)(x) = a.α .* x .+ a.β
203+
(a::Diagonal)(x) = a.scale .* x .+ a.bias
197204

198205
function Base.show(io::IO, l::Diagonal)
199-
print(io, "Diagonal(", join(size(l.α), ", "), ")")
206+
print(io, "Diagonal(", join(size(l.scale), ", "))
207+
l.bias == false && print(io, "; bias=false")
208+
print(io, ")")
200209
end
201210

202211
"""

test/layers/basic.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,16 +91,17 @@ import Flux: activations
9191
@test length(Flux.Diagonal(10)(randn(10))) == 10
9292
@test length(Flux.Diagonal(10)(1)) == 10
9393
@test length(Flux.Diagonal(10)(randn(1))) == 10
94+
@test length(Flux.Diagonal(10; bias = false)(randn(10))) == 10
9495
@test_throws DimensionMismatch Flux.Diagonal(10)(randn(2))
9596

9697
@test Flux.Diagonal(2)([1 2]) == [1 2; 1 2]
9798
@test Flux.Diagonal(2)([1,2]) == [1,2]
98-
@test Flux.Diagonal(2)([1 2; 3 4]) == [1 2; 3 4]
99+
@test Flux.Diagonal(2; bias = false)([1 2; 3 4]) == [1 2; 3 4]
99100

100101
@test Flux.Diagonal(2)(rand(2,3,4)) |> size == (2, 3, 4)
101102
@test Flux.Diagonal(2,3)(rand(2,3,4)) |> size == (2, 3, 4)
102-
@test Flux.Diagonal(2,3,4)(rand(2,3,4)) |> size == (2, 3, 4)
103-
@test Flux.Diagonal(2,3)(rand(2,1,4)) |> size == (2, 3, 4)
103+
@test Flux.Diagonal(2, 3, 4; bias = false)(rand(2,3,4)) |> size == (2, 3, 4)
104+
@test Flux.Diagonal(2, 3; bias = false)(rand(2,1,4)) |> size == (2, 3, 4)
104105
end
105106

106107
@testset "Maxout" begin

0 commit comments

Comments (0)