@@ -172,31 +172,40 @@ function Base.show(io::IO, l::Dense)
172172end
173173
"""
    Diagonal(size::Integer...; bias=true, init=ones32)
    Diagonal(scale::AbstractArray, [bias])

Construct an element-wise linear layer. Calling the layer on `x` evaluates

    y = scale .* x .+ bias

and applies no activation function.

When built from `size`, the learnable `scale` is `init(size...)` (with
`init=ones32` unless overridden) and the learnable bias is `zeros32(size...)`.
Pass `bias=false` to disable the trainable bias, or supply the array(s)
directly through the second method.

Used by [`LayerNorm`](@ref).
"""
struct Diagonal{A<:AbstractArray, B}
  scale::A
  bias::B

  # The bias argument is routed through `create_bias`, so the stored field's
  # type is whatever that helper returns for the given `bias`; the type
  # parameters are taken from the actual values rather than from the caller.
  function Diagonal(W::M, bias = true) where {M<:AbstractArray}
    resolved = create_bias(W, bias, size(W)...)
    return new{M, typeof(resolved)}(W, resolved)
  end
end
191198
# Size-based convenience constructor: build the scale array with `init`
# and forward it, together with the `bias` flag, to the array constructor.
function Diagonal(sz::Integer...; bias = true, init = ones32)
  scale = init(sz...)
  return Diagonal(scale, bias)
end
193200
# NOTE(review): presumably registers `scale` and `bias` for Functors-based
# traversal (parameter collection, device movement) — confirm against Functors.jl.
@functor Diagonal

# Forward pass: element-wise scale followed by element-wise shift, fused into
# a single broadcast. `@.` dots the `*` and `+` calls only; the field accesses
# `layer.scale` and `layer.bias` are left as they are.
function (layer::Diagonal)(x)
  return @. layer.scale * x + layer.bias
end
197204
# Compact display: `Diagonal(d1, d2, ...)` built from the size of `scale`,
# with `; bias=false` added before the closing parenthesis when the stored
# bias compares equal to `false`.
function Base.show(io::IO, l::Diagonal)
  dims = join(size(l.scale), ", ")
  suffix = l.bias == false ? "; bias=false" : ""
  print(io, "Diagonal(", dims, suffix, ")")
end
201210
202211"""
0 commit comments