Skip to content

Conversation

theabhirath
Copy link
Member

This PR tries to extend some of the functionality for the built-in layer Diagonal following the discussions in FluxML/Metalhead.jl#119. This makes very minimal changes to Diagonal to allow it to function without a bias and also take in a weight initialisation strategy for the weight matrix. This has the advantage that:

  1. It allows for more flexibility to implement custom layers like LayerScale; and
  2. It leaves the current working unchanged (what worked before still will without any modifications)

PR Checklist

  • Tests are added
  • Entry in NEWS.md
  • Documentation, if applicable

@theabhirath
Copy link
Member Author

PS: a weird error occurs when I write the forward pass as (a::Diagonal)(x) = a.α .* x .+ a.β (the implementation in the PR has two lines for this reason):

ERROR: MethodError: no method matching length(::Flux.Zeros)
Closest candidates are:
  length(::Union{Base.KeySet, Base.ValueIterator}) at ~/julia/usr/share/julia/base/abstractdict.jl:58
  length(::Union{LinearAlgebra.Adjoint{T, <:Union{StaticArrays.StaticVector{<:Any, T}, StaticArrays.StaticMatrix{<:Any, <:Any, T}}}, LinearAlgebra.Diagonal{T, <:StaticArrays.StaticVector{<:Any, T}}, LinearAlgebra.Hermitian{T, <:StaticArrays.StaticMatrix{<:Any, <:Any, T}}, LinearAlgebra.LowerTriangular{T, <:StaticArrays.StaticMatrix{<:Any, <:Any, T}}, LinearAlgebra.Symmetric{T, <:StaticArrays.StaticMatrix{<:Any, <:Any, T}}, LinearAlgebra.Transpose{T, <:Union{StaticArrays.StaticVector{<:Any, T}, StaticArrays.StaticMatrix{<:Any, <:Any, T}}}, LinearAlgebra.UnitLowerTriangular{T, <:StaticArrays.StaticMatrix{<:Any, <:Any, T}}, LinearAlgebra.UnitUpperTriangular{T, <:StaticArrays.StaticMatrix{<:Any, <:Any, T}}, LinearAlgebra.UpperTriangular{T, <:StaticArrays.StaticMatrix{<:Any, <:Any, T}}, StaticArrays.StaticVector{<:Any, T}, StaticArrays.StaticMatrix{<:Any, <:Any, T}, StaticArrays.StaticArray{<:Tuple, T}} where T) at ~/.julia/packages/StaticArrays/h2TVC/src/abstractarray.jl:1
  length(::Union{LinearAlgebra.Adjoint{T, S}, LinearAlgebra.Transpose{T, S}} where {T, S}) at ~/julia/usr/share/julia/stdlib/v1.8/LinearAlgebra/src/adjtrans.jl:172
  ...
Stacktrace:
 [1] _similar_shape(itr::Flux.Zeros, #unused#::Base.HasLength)
   @ Base ./array.jl:660
 [2] _collect(cont::UnitRange{Int64}, itr::Flux.Zeros, #unused#::Base.HasEltype, isz::Base.HasLength)
   @ Base ./array.jl:715
 [3] collect(itr::Flux.Zeros)
   @ Base ./array.jl:709
 [4] broadcastable(x::Flux.Zeros)
   @ Base.Broadcast ./broadcast.jl:704
 [5] broadcasted
   @ ./broadcast.jl:1302 [inlined]
 [6] (::Flux.Diagonal{Vector{Float32}, Flux.Zeros})(x::Vector{Float64})
   @ Flux ~/Code/Flux.jl/src/layers/basic.jl:207
 [7] top-level scope
   @ REPL[2]:1

1. Allows for more flexibility to implement custom layers like LayerScale
2. Leaves current working unchanged (what worked before still will without any modifications)
@codecov-commenter
Copy link

codecov-commenter commented Feb 19, 2022

Codecov Report

Merging #1881 (cdbc5c5) into master (13a65be) will increase coverage by 0.19%.
The diff coverage is 80.73%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master    #1881      +/-   ##
==========================================
+ Coverage   86.03%   86.23%   +0.19%     
==========================================
  Files          19       18       -1     
  Lines        1411     1438      +27     
==========================================
+ Hits         1214     1240      +26     
- Misses        197      198       +1     
Impacted Files Coverage Δ
src/layers/show.jl 72.72% <ø> (ø)
src/layers/stateless.jl 100.00% <ø> (ø)
src/layers/basic.jl 78.29% <55.00%> (-2.86%) ⬇️
src/functor.jl 88.33% <60.00%> (+1.66%) ⬆️
src/deprecations.jl 39.13% <69.23%> (+39.13%) ⬆️
src/utils.jl 92.25% <80.00%> (-1.99%) ⬇️
src/layers/recurrent.jl 87.80% <90.90%> (+2.26%) ⬆️
src/cuda/cudnn.jl 100.00% <100.00%> (+11.11%) ⬆️
src/layers/conv.jl 80.32% <100.00%> (-0.22%) ⬇️
src/layers/normalise.jl 83.33% <100.00%> (+0.20%) ⬆️
... and 6 more

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 13a65be...cdbc5c5. Read the comment docs.

@darsnack
Copy link
Member

darsnack commented Feb 19, 2022

Probably need to define this in zeros.jl.

Base.broadcastable(::Zeros) = Ref(Zeros())

@theabhirath
Copy link
Member Author

Probably need to define this in zeros.jl.

Base.broadcastable(::Zeros) = Ref(Zeros())

This didn't help, the error message changed to:

ERROR: MethodError: no method matching +(::Float64, ::Flux.Zeros)
Closest candidates are:
  +(::Any, ::Any, ::Any, ::Any...) at ~/julia/usr/share/julia/base/operators.jl:591
  +(::T, ::T) where T<:Union{Float16, Float32, Float64} at ~/julia/usr/share/julia/base/float.jl:377
  +(::Union{Float16, Float32, Float64}, ::BigFloat) at ~/julia/usr/share/julia/base/mpfr.jl:407
  ...
Stacktrace:
 [1] _broadcast_getindex_evalf
   @ ./broadcast.jl:670 [inlined]
 [2] _broadcast_getindex
   @ ./broadcast.jl:643 [inlined]
 [3] getindex
   @ ./broadcast.jl:597 [inlined]
 [4] copy
   @ ./broadcast.jl:899 [inlined]
 [5] materialize
   @ ./broadcast.jl:860 [inlined]
 [6] (::Flux.Diagonal{Array{Float32, 3}, Flux.Zeros})(x::Array{Float64, 3})
   @ Flux ~/Code/Flux.jl/src/layers/basic.jl:207
 [7] top-level scope
   @ REPL[2]:1

There's also broadcasting rules defined already, and astonishingly an extremely similar definition works for the Dense layer 🤔

@darsnack
Copy link
Member

There's also broadcasting rules defined already, and astonishingly an extremely similar definition works for the Dense layer

Yes, that's cause Dense only has one level of broadcasting. Zeros probably never worked for any complex nested broadcasting. This basically leaves us with making Zeros a full blown AbstractArray with dimensions or defining scalar rules as well.

Alternatively, we can consider replacing Zeros with FillArrays. What do you think @mcabbott?

@mcabbott
Copy link
Member

We can also just delete Zeros completely.

@darsnack
Copy link
Member

Are you up for reviving your efforts in that direction? I'm all for it.

@theabhirath
Copy link
Member Author

I'm a little uncertain - I'm happy to incorporate the changes but given that create_bias and Zeros seem to be a point of contention, I think I'll let this stew for a bit until there's a clear direction regarding that?

@mcabbott mcabbott mentioned this pull request Feb 19, 2022
1. Allows for more flexibility to implement custom layers like LayerScale
2. Leaves current working unchanged (what worked before still will without any modifications)
@theabhirath theabhirath changed the base branch from master to ad-overhaul March 5, 2022 17:34
@theabhirath theabhirath changed the base branch from ad-overhaul to master March 5, 2022 17:35
@theabhirath theabhirath requested review from mcabbott and darsnack March 5, 2022 17:56
@test_throws DimensionMismatch Flux.Diagonal(10)(randn(2))

@test Flux.Diagonal(2)([1 2]) == [1 2; 1 2]
@test Flux.Diagonal(2)([1,2]) == [1,2]
@test Flux.Diagonal(2)([1 2; 3 4]) == [1 2; 3 4]
@test Flux.Diagonal(2; bias = false)([1 2; 3 4]) == [1 2; 3 4]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does these tests need bias=false?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They don't really, I kept those just to check that bias=false doesn't trip anything

Co-authored-by: Michael Abbott <[email protected]>
Copy link
Member

@mcabbott mcabbott left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks fine to me.

There's a style question about whether we want spaces around keywords. Mostly Flux don't have such in docstrings. I am not sure whether Kyle's comment #1881 (comment) is suggesting to remove them? Which is what my suggestions below would do.

@ToucheSir
Copy link
Member

Per #1765 I don't think that style question is settled, and outside of installing a autoformatter to enforce it don't personally have the energy to try to litigate it.

@mcabbott
Copy link
Member

mcabbott commented Mar 5, 2022

Yes, I don't mean to litigate it. This may not have been what Kyle meant above. Approved by me with or without.

@darsnack
Copy link
Member

darsnack commented Mar 5, 2022

Yeah no spaces is what I meant, but I would be very happy with a decision to litigating this stuff! Approved by me then.

@mcabbott
Copy link
Member

mcabbott commented Mar 5, 2022

Ok, done. Can you approve though, as your red card still blocks my button? Or just merge, in fact.

Copy link
Member

@darsnack darsnack left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whoops my bad

@darsnack darsnack merged commit 9755d57 into FluxML:master Mar 5, 2022
@darsnack
Copy link
Member

darsnack commented Mar 5, 2022

Also, I meant above to say I would be happy not to litigate 😅

@theabhirath theabhirath deleted the diagonal branch March 6, 2022 00:33
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants