Commit fc3eff4

rm Flux.Zeros, take N+1

1 parent 57ef5c0

11 files changed: +54, −151 lines

src/Flux.jl

Lines changed: 0 additions & 1 deletion
@@ -36,7 +36,6 @@ using CUDA
 const use_cuda = Ref{Union{Nothing,Bool}}(nothing)
 
 include("utils.jl")
-include("zeros.jl")
 include("onehot.jl")
 include("functor.jl")

src/deprecations.jl

Lines changed: 8 additions & 0 deletions
@@ -19,6 +19,14 @@ zeros32(::Type, dims...) = throw(ArgumentError("Flux.zeros32 is always Float32,
 
 @deprecate frequencies(xs) group_counts(xs)
 
+struct Zeros
+  function Zeros()
+    Base.depwarn("Flux.Zeros is no more, has ceased to be, is bereft of life, is an ex-boondoggle... please use bias=false instead", :Zeros)
+    false
+  end
+end
+Zeros(args...) = Zeros() # was used both Dense(10, 2, initb = Zeros) and Dense(rand(2,10), Zeros())
+
 # Channel notation: Changed to match Conv, but very softly deprecated!
 # Perhaps change to @deprecate for v0.14, but there is no plan to remove these.
 Dense(in::Integer, out::Integer, σ = identity; kw...) =
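
For anyone upgrading, a minimal sketch of how this shim behaves in practice (assuming the change as finally merged; the weight shape is illustrative):

using Flux

# The legacy spelling warns once, then falls through to `false`,
# which the constructors already accept as "no trainable bias":
d = Dense(rand(2, 10), Flux.Zeros())  # depwarn; equivalent to Dense(rand(2, 10), false)
@assert d.bias === false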

src/layers/basic.jl

Lines changed: 5 additions & 1 deletion
@@ -167,7 +167,7 @@ end
 function Base.show(io::IO, l::Dense)
   print(io, "Dense(", size(l.weight, 2), " => ", size(l.weight, 1))
   l.σ == identity || print(io, ", ", l.σ)
-  l.bias == Zeros() && print(io, "; bias=false")
+  l.bias == false && print(io, "; bias=false")
   print(io, ")")
 end

@@ -394,7 +394,11 @@ function Base.show(io::IO, l::Bilinear)
     print(io, "Bilinear((", size(l.weight, 2), ", ", size(l.weight, 3), ") => ", size(l.weight, 1))
   end
   l.σ == identity || print(io, ", ", l.σ)
+<<<<<<< HEAD
   l.bias == Flux.Zeros() && print(io, "; bias=false")
+=======
+  l.bias === false && print(io, ", bias=false")
+>>>>>>> 1ef2cd377 (rm Flux.Zeros, take N+1)
   print(io, ")")
 end
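
Note that the Bilinear hunk above still carries unresolved conflict markers from the rebase, so the sketch below assumes the change as finally merged. A quick way to see the updated Dense printing (`repr` hits exactly the two-argument `show` method changed above):

using Flux

d = Dense(3 => 2; bias=false)
repr(d)  # "Dense(3 => 2; bias=false)"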

src/layers/conv.jl

Lines changed: 11 additions & 11 deletions
@@ -6,6 +6,10 @@ _paddims(x::Tuple, y::Tuple) = (x..., y[(end - (length(y) - length(x) - 1)):end]
 expand(N, i::Tuple) = i
 expand(N, i::Integer) = ntuple(_ -> i, N)
 
+conv_reshape_bias(c) = c.bias isa AbstractVector ?
+  reshape(c.bias, map(_->1, c.stride)..., :, 1) :
+  c.bias
+
 """
     SamePad()
 
@@ -61,8 +65,8 @@ Then:
 
 Keywords to control initialization of the layer:
 * `init` - Function used to generate initial weights. Defaults to `glorot_uniform`.
-* `bias` - Initial bias is zero by default, this can be disabled entirely by setting it to
-  `false`, or another vector explicitly as `bias = randn(Float32, out)`.
+* `bias` - The initial bias vector is all zero by default. Trainable bias can be disabled entirely
+  by setting this to `false`, or another vector can be provided such as `bias = randn(Float32, out)`.
 
 See also [`ConvTranspose`](@ref), [`DepthwiseConv`](@ref), [`CrossCor`](@ref).
 
@@ -159,10 +163,9 @@ end
 @functor Conv
 
 function (c::Conv)(x::AbstractArray)
-  b = reshape(c.bias, map(_->1, c.stride)..., :, 1)
   σ = NNlib.fast_act(c.σ, x)
   cdims = DenseConvDims(x, c.weight; stride = c.stride, padding = c.pad, dilation = c.dilation, groups = c.groups)
-  σ.(conv(x, c.weight, cdims) .+ b)
+  σ.(conv(x, c.weight, cdims) .+ conv_reshape_bias(c))
 end
 
 _channels_in(l ::Conv) = size(l.weight, ndims(l.weight)-1) * l.groups

@@ -183,7 +186,7 @@ function _print_conv_opt(io::IO, l)
   if hasproperty(l, :groups)
     (l.groups == 1) || print(io, ", groups=", l.groups)
   end
-  (l.bias isa Zeros) && print(io, ", bias=false")
+  (l.bias === false) && print(io, ", bias=false")
 end
 
 """

@@ -277,10 +280,9 @@ end
 @nograd conv_transpose_dims
 
 function (c::ConvTranspose)(x::AbstractArray)
-  b = reshape(c.bias, map(_->1, c.stride)..., :, 1)
   σ = NNlib.fast_act(c.σ, x)
   cdims = conv_transpose_dims(c, x)
-  σ.(∇conv_data(x, c.weight, cdims) .+ b)
+  σ.(∇conv_data(x, c.weight, cdims) .+ conv_reshape_bias(c))
 end
 
 function Base.show(io::IO, l::ConvTranspose)

@@ -372,10 +374,9 @@ depthwiseconvfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
                     init = glorot_uniform) where N = init(filter..., div(ch[2], ch[1]), ch[1])
 
 function (c::DepthwiseConv)(x)
-  b = reshape(c.bias, map(_->1, c.stride)..., :, 1)
   σ = NNlib.fast_act(c.σ, x)
   cdims = DepthwiseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
-  σ.(depthwiseconv(x, c.weight, cdims) .+ b)
+  σ.(depthwiseconv(x, c.weight, cdims) .+ conv_reshape_bias(c))
 end
 
 function Base.show(io::IO, l::DepthwiseConv)

@@ -453,10 +454,9 @@ function crosscor(x, w, ddims::DenseConvDims)
 end
 
 function (c::CrossCor)(x::AbstractArray)
-  b = reshape(c.bias, map(_->1, c.stride)..., :, 1)
   σ = NNlib.fast_act(c.σ, x)
   cdims = DenseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
-  σ.(crosscor(x, c.weight, cdims) .+ b)
+  σ.(crosscor(x, c.weight, cdims) .+ conv_reshape_bias(c))
 end
 
 function Base.show(io::IO, l::CrossCor)
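
The new helper centralizes the bias reshape that each forward pass previously inlined. A hedged sketch of what it returns (layer sizes illustrative; `conv_reshape_bias` is internal, hence the `Flux.` prefix):

using Flux

c = Conv((3, 3), 4 => 8)          # 2-d conv: c.stride isa NTuple{2}, c.bias a length-8 vector
size(Flux.conv_reshape_bias(c))   # (1, 1, 8, 1): broadcasts over width, height, and batch

c2 = Conv((3, 3), 4 => 8; bias=false)
Flux.conv_reshape_bias(c2)        # false, passed through: `.+ false` broadcasts as zero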

src/utils.jl

Lines changed: 3 additions & 3 deletions
@@ -441,17 +441,17 @@ rand32(dims...) = Base.rand(Float32, dims...)
 randn32(dims...) = Base.randn(Float32, dims...)
 
 """
-    create_bias(weights, bias, length)
+    create_bias(weights, bias, size...)
 
 Return a bias parameter for a layer, based on the value given
 to the constructor's keyword `bias=bias`.
 
 * `bias == true` creates a zero vector, of the same type as weights.
-* `bias == false` returns `Zeros()`, a special struct which exists only to encode the absence of bias.
+* `bias == false` returns `false` now, which is understood by AD to be non-differentiable.
 * `bias::AbstractArray` uses the array provided, provided it has the correct size and eltype. If the type is wrong, it will be converted.
 """
 function create_bias(weights::AbstractArray, bias::Bool, dims::Integer...)
-  bias ? fill!(similar(weights, dims...), 0) : Zeros()
+  bias ? fill!(similar(weights, dims...), 0) : false
 end
 function create_bias(weights::AbstractArray, bias::AbstractArray, dims::Integer...)
   size(bias) == dims || throw(DimensionMismatch("expected bias of size $(dims), got size $(size(bias))"))
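
A hedged sketch of the three branches documented above (values illustrative; `create_bias` is internal, so call it as `Flux.create_bias`):

using Flux

W = randn(Float32, 2, 10)

Flux.create_bias(W, true, 2)      # 2-element zero Vector{Float32}, same storage type as W
Flux.create_bias(W, false, 2)     # false: no bias, treated by AD as non-differentiable
Flux.create_bias(W, ones(2), 2)   # the given array, converted to eltype(W) per the docstring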

src/zeros.jl

Lines changed: 0 additions & 52 deletions
This file was deleted.

test/cuda/layers.jl

Lines changed: 2 additions & 2 deletions
@@ -155,8 +155,8 @@ end
   end
 end
 
-@testset "Dense with Zeros bias" begin
-  l = Dense(ones(Float32, 4, 3), Flux.Zeros()) |> gpu
+@testset "Dense without bias" begin
+  l = Dense(ones(Float32, 4, 3), false) |> gpu
   ip = zeros(Float32, 3, 7) |> gpu
 
   @test sum(l(ip)) ≈ 0.f0

test/layers/basic.jl

Lines changed: 1 addition & 1 deletion
@@ -175,7 +175,7 @@ import Flux: activations
   @test b1.σ == identity
 
   b2 = Flux.Bilinear(randn(3,4,5), false)
-  @test b2.bias == Flux.Zeros()
+  @test b2.bias === false
 
   b3 = Flux.Bilinear(randn(Float16, 3,4,5), true, tanh)
   @test b3.σ == tanh

test/layers/conv.jl

Lines changed: 1 addition & 1 deletion
@@ -273,7 +273,7 @@ end
 
 @testset "constructors: $fun" for fun in [Conv, CrossCor, ConvTranspose, DepthwiseConv]
   @test fun(rand(2,3,4)).bias isa Vector{Float64}
-  @test fun(rand(2,3,4,5), false).bias isa Flux.Zeros
+  @test fun(rand(2,3,4,5), false).bias === false
   if fun == Conv
     @test fun(rand(2,3,4,5,6), rand(6)).bias isa Vector{Float64}
     @test_skip fun(rand(2,3,4,5,6), 1:6).bias isa Vector{Float64}

test/optimise.jl

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@ using Random
             Nesterov(), RMSProp(), Momentum()]
     Random.seed!(42)
     w′ = randn(10, 10)
-    b = Flux.Zeros()
+    b = false
     loss(x) = Flux.Losses.mse(w*x, w′*x .+ b)
     for t = 1: 10^5
       θ = params([w′, b])
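
Why `false` is a drop-in replacement here, in a hedged illustration (not part of the test suite): Julia's `false` broadcasts as an additive zero, and `params` collects no trainable state from it.

using Flux

x = rand(Float32, 3)
x .+ false == x                     # true: adding false adds zero

ps = Flux.params([randn(10, 10), false])
length(ps)                          # 1: the Bool contributes no parameters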
