-
-
Notifications
You must be signed in to change notification settings - Fork 614
Extending Diagonal
#1881
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Extending Diagonal
#1881
Conversation
PS: a weird error occurs when I write the forward pass as ERROR: MethodError: no method matching length(::Flux.Zeros)
Closest candidates are:
length(::Union{Base.KeySet, Base.ValueIterator}) at ~/julia/usr/share/julia/base/abstractdict.jl:58
length(::Union{LinearAlgebra.Adjoint{T, <:Union{StaticArrays.StaticVector{<:Any, T}, StaticArrays.StaticMatrix{<:Any, <:Any, T}}}, LinearAlgebra.Diagonal{T, <:StaticArrays.StaticVector{<:Any, T}}, LinearAlgebra.Hermitian{T, <:StaticArrays.StaticMatrix{<:Any, <:Any, T}}, LinearAlgebra.LowerTriangular{T, <:StaticArrays.StaticMatrix{<:Any, <:Any, T}}, LinearAlgebra.Symmetric{T, <:StaticArrays.StaticMatrix{<:Any, <:Any, T}}, LinearAlgebra.Transpose{T, <:Union{StaticArrays.StaticVector{<:Any, T}, StaticArrays.StaticMatrix{<:Any, <:Any, T}}}, LinearAlgebra.UnitLowerTriangular{T, <:StaticArrays.StaticMatrix{<:Any, <:Any, T}}, LinearAlgebra.UnitUpperTriangular{T, <:StaticArrays.StaticMatrix{<:Any, <:Any, T}}, LinearAlgebra.UpperTriangular{T, <:StaticArrays.StaticMatrix{<:Any, <:Any, T}}, StaticArrays.StaticVector{<:Any, T}, StaticArrays.StaticMatrix{<:Any, <:Any, T}, StaticArrays.StaticArray{<:Tuple, T}} where T) at ~/.julia/packages/StaticArrays/h2TVC/src/abstractarray.jl:1
length(::Union{LinearAlgebra.Adjoint{T, S}, LinearAlgebra.Transpose{T, S}} where {T, S}) at ~/julia/usr/share/julia/stdlib/v1.8/LinearAlgebra/src/adjtrans.jl:172
...
Stacktrace:
[1] _similar_shape(itr::Flux.Zeros, #unused#::Base.HasLength)
@ Base ./array.jl:660
[2] _collect(cont::UnitRange{Int64}, itr::Flux.Zeros, #unused#::Base.HasEltype, isz::Base.HasLength)
@ Base ./array.jl:715
[3] collect(itr::Flux.Zeros)
@ Base ./array.jl:709
[4] broadcastable(x::Flux.Zeros)
@ Base.Broadcast ./broadcast.jl:704
[5] broadcasted
@ ./broadcast.jl:1302 [inlined]
[6] (::Flux.Diagonal{Vector{Float32}, Flux.Zeros})(x::Vector{Float64})
@ Flux ~/Code/Flux.jl/src/layers/basic.jl:207
[7] top-level scope
@ REPL[2]:1 |
1. Allows for more flexibility to implement custom layers like LayerScale 2. Leaves current working unchanged (what worked before still will without any modifications)
Codecov Report
@@ Coverage Diff @@
## master #1881 +/- ##
==========================================
+ Coverage 86.03% 86.23% +0.19%
==========================================
Files 19 18 -1
Lines 1411 1438 +27
==========================================
+ Hits 1214 1240 +26
- Misses 197 198 +1
Continue to review full report at Codecov.
|
Probably need to define this in Base.broadcastable(::Zeros) = Ref(Zeros()) |
This didn't help, the error message changed to: ERROR: MethodError: no method matching +(::Float64, ::Flux.Zeros)
Closest candidates are:
+(::Any, ::Any, ::Any, ::Any...) at ~/julia/usr/share/julia/base/operators.jl:591
+(::T, ::T) where T<:Union{Float16, Float32, Float64} at ~/julia/usr/share/julia/base/float.jl:377
+(::Union{Float16, Float32, Float64}, ::BigFloat) at ~/julia/usr/share/julia/base/mpfr.jl:407
...
Stacktrace:
[1] _broadcast_getindex_evalf
@ ./broadcast.jl:670 [inlined]
[2] _broadcast_getindex
@ ./broadcast.jl:643 [inlined]
[3] getindex
@ ./broadcast.jl:597 [inlined]
[4] copy
@ ./broadcast.jl:899 [inlined]
[5] materialize
@ ./broadcast.jl:860 [inlined]
[6] (::Flux.Diagonal{Array{Float32, 3}, Flux.Zeros})(x::Array{Float64, 3})
@ Flux ~/Code/Flux.jl/src/layers/basic.jl:207
[7] top-level scope
@ REPL[2]:1 There's also broadcasting rules defined already, and astonishingly an extremely similar definition works for the Dense layer 🤔 |
Yes, that's cause Alternatively, we can consider replacing |
We can also just delete Zeros completely. |
Are you up for reviving your efforts in that direction? I'm all for it. |
I'm a little uncertain - I'm happy to incorporate the changes but given that |
1. Allows for more flexibility to implement custom layers like LayerScale 2. Leaves current working unchanged (what worked before still will without any modifications)
@test_throws DimensionMismatch Flux.Diagonal(10)(randn(2)) | ||
|
||
@test Flux.Diagonal(2)([1 2]) == [1 2; 1 2] | ||
@test Flux.Diagonal(2)([1,2]) == [1,2] | ||
@test Flux.Diagonal(2)([1 2; 3 4]) == [1 2; 3 4] | ||
@test Flux.Diagonal(2; bias = false)([1 2; 3 4]) == [1 2; 3 4] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why does these tests need bias=false
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
They don't really, I kept those just to check that bias=false
doesn't trip anything
Co-authored-by: Michael Abbott <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks fine to me.
There's a style question about whether we want spaces around keywords. Mostly Flux don't have such in docstrings. I am not sure whether Kyle's comment #1881 (comment) is suggesting to remove them? Which is what my suggestions below would do.
Per #1765 I don't think that style question is settled, and outside of installing a autoformatter to enforce it don't personally have the energy to try to litigate it. |
Yes, I don't mean to litigate it. This may not have been what Kyle meant above. Approved by me with or without. |
Yeah no spaces is what I meant, but I would be very happy with a decision to litigating this stuff! Approved by me then. |
Ok, done. Can you approve though, as your red card still blocks my button? Or just merge, in fact. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Whoops my bad
Also, I meant above to say I would be happy not to litigate 😅 |
This PR tries to extend some of the functionality for the built-in layer
Diagonal
following the discussions in FluxML/Metalhead.jl#119. This makes very minimal changes to Diagonal to allow it to function without a bias and also take in a weight initialisation strategy for the weight matrix. This has the advantage that:PR Checklist