Skip to content

Commit b11fbea

Browse files
Added generic MClayer with MCdense, MCconv layers with tests and an example of MC lenet5 on MNIST (#2)
* Added generic MClayer with MCdense, MCconv layers with tests and a single example of MC lenet5 on MNIST * fixed module import errors and added removed inner constructor for mclayer * Added statistics dependency and removed MCLayer inner constructor * Update src/layers/mclayers.jl reorder arguments to create conv layer Co-authored-by: Dhairya Gandhi <[email protected]> * Update src/layers/mclayers.jl Co-authored-by: Dhairya Gandhi <[email protected]> * Refactor MCLayer to allow arbitrary dropout functions by decoupling dropout and MCLayer forward pass * Added MClayer to the export list of the module Co-authored-by: Dhairya Gandhi <[email protected]>
1 parent 593321a commit b11fbea

File tree

9 files changed

+434
-20
lines changed

9 files changed

+434
-20
lines changed

.github/workflows/CI.yml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,7 @@ jobs:
1010
fail-fast: false
1111
matrix:
1212
version:
13-
- '1.0'
1413
- '1.6'
15-
- 'nightly'
1614
os:
1715
- ubuntu-latest
1816
arch:

Project.toml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,14 @@ uuid = "f38f59f8-88e0-4e11-81d3-0c37501e3a95"
33
authors = ["DwaraknathT <[email protected]> and contributors"]
44
version = "0.1.0"
55

6+
[deps]
7+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
8+
CalibrationErrors = "33913031-fe46-5864-950f-100836f47845"
9+
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
10+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
11+
ReliabilityDiagrams = "e5f51471-6270-49e4-a15a-f1cfbff4f856"
12+
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
13+
614
[compat]
715
julia = "1"
816

docs/make.jl

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,24 @@
11
using DeepUncertainty
22
using Documenter
33

4-
DocMeta.setdocmeta!(DeepUncertainty, :DocTestSetup, :(using DeepUncertainty); recursive=true)
4+
DocMeta.setdocmeta!(
5+
DeepUncertainty,
6+
:DocTestSetup,
7+
:(using DeepUncertainty);
8+
recursive = true,
9+
)
510

611
makedocs(;
7-
modules=[DeepUncertainty],
8-
authors="DwaraknathT <[email protected]> and contributors",
9-
repo="https://github.com/DwaraknathT/DeepUncertainty.jl/blob/{commit}{path}#{line}",
10-
sitename="DeepUncertainty.jl",
11-
format=Documenter.HTML(;
12-
prettyurls=get(ENV, "CI", "false") == "true",
13-
canonical="https://DwaraknathT.github.io/DeepUncertainty.jl",
14-
assets=String[],
12+
modules = [DeepUncertainty],
13+
authors = "DwaraknathT <[email protected]> and contributors",
14+
repo = "https://github.com/aced-differentiate/DeepUncertainty.jl/blob/{commit}{path}#{line}",
15+
sitename = "DeepUncertainty.jl",
16+
format = Documenter.HTML(;
17+
prettyurls = get(ENV, "CI", "false") == "true",
18+
canonical = "https://DwaraknathT.github.io/DeepUncertainty.jl",
19+
assets = String[],
1520
),
16-
pages=[
17-
"Home" => "index.md",
18-
],
21+
pages = ["Home" => "index.md"],
1922
)
2023

21-
deploydocs(;
22-
repo="github.com/DwaraknathT/DeepUncertainty.jl",
23-
)
24+
deploydocs(; repo = "github.com/aced-differentiate/DeepUncertainty.jl")

examples/mcdropout.jl

Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
using Base: AbstractFloat
2+
## Classification of MNIST dataset
3+
## with the convolutional neural network known as LeNet5.
4+
## This script also combines various
5+
## packages from the Julia ecosystem with Flux.
6+
using Flux
7+
using Flux.Data: DataLoader
8+
using Flux.Optimise: Optimiser, WeightDecay
9+
using Flux: onehotbatch, onecold, glorot_normal, label_smoothing
10+
using Flux.Losses: logitcrossentropy
11+
using Statistics, Random
12+
using Logging: with_logger
13+
using ProgressMeter: @showprogress
14+
import MLDatasets
15+
using CUDA
16+
using Formatting
17+
18+
using DeepUncertainty
19+
20+
# LeNet5 "constructor" with Monte Carlo dropout on every conv and dense layer.
# `imgsize` is read as (spatial dims..., channels): the last entry is the number
# of input channels. `nclasses` is the width of the output layer, and the dropout
# rate for all layers is taken from `args.dropout`.
function LeNet5(args; imgsize = (28, 28, 1), nclasses = 10)
    rate = args.dropout

    # Size of one spatial dimension after two (5x5 conv, 2x2 max-pool) stages.
    reduced(d) = d ÷ 4 - 3
    nfeatures = reduced(imgsize[1]) * reduced(imgsize[2]) * 16

    return Chain(
        MCConv((5, 5), imgsize[end] => 6, rate, relu),
        MaxPool((2, 2)),
        MCConv((5, 5), 6 => 16, rate, relu),
        MaxPool((2, 2)),
        flatten,
        MCDense(nfeatures, 120, rate, relu),
        MCDense(120, 84, rate, relu),
        MCDense(84, nclasses, rate),
    )
end
37+
38+
# Load MNIST as Float32, reshape the images to 28x28x1xN, one-hot encode the
# labels over the digits 0:9, and wrap both splits in `DataLoader`s that use
# `args.batchsize`. Only the training loader shuffles; both drop a trailing
# partial batch (`partial = false`).
function get_data(args)
    as_images(x) = reshape(x, 28, 28, 1, :)
    as_onehot(labels) = onehotbatch(labels, 0:9)

    train_x, train_y = MLDatasets.MNIST.traindata(Float32)
    test_x, test_y = MLDatasets.MNIST.testdata(Float32)

    train_loader = DataLoader(
        (as_images(train_x), as_onehot(train_y));
        batchsize = args.batchsize,
        shuffle = true,
        partial = false,
    )
    test_loader = DataLoader(
        (as_images(test_x), as_onehot(test_y));
        batchsize = args.batchsize,
        partial = false,
    )

    return train_loader, test_loader
end
57+
58+
loss(ŷ, y) = logitcrossentropy(ŷ, y)
59+
60+
# Count (not average) the columns where the arg-max of `preds` matches the
# arg-max of `labels`. Both inputs are moved to the CPU before decoding.
function accuracy(preds, labels)
    predicted = onecold(cpu(preds))
    expected = onecold(cpu(labels))
    return sum(predicted .== expected)
end
64+
65+
# Evaluate `model` on `loader` using `args.sample_size` stochastic forward
# passes per batch (MC dropout), then log per-sample and averaged metrics.
#
# Arguments
# - `args`: settings object; only `args.sample_size` is read here.
# - `loader`: iterable of `(x, y)` batches, `y` one-hot encoded.
# - `model`: callable returning logits with classes along dimension 1.
# - `device`: function moving arrays to the compute device (`cpu` or `gpu`).
#
# For every pass the loss, number of correct predictions and expected
# calibration error (ECE) are accumulated; the same metrics are accumulated for
# the class probabilities averaged over all passes. Returns `nothing`; results
# are reported through `@info`.
function eval_loss_accuracy(args, loader, model, device)
    nsamples = args.sample_size
    l = zeros(Float32, nsamples)
    acc = zeros(Int, nsamples)
    ece_list = zeros(Float32, nsamples)
    ntot = 0
    mean_l = 0.0f0
    mean_acc = 0
    mean_ece = 0.0
    for (x, y) in loader
        x, y = x |> device, y |> device
        # Weight per-batch averages by the real batch size rather than the
        # configured one, so the totals stay correct for a short final batch.
        nbatch = size(y)[end]
        labels = onecold(y |> cpu)
        predictions = []

        # One stochastic forward pass per MC sample
        for sample = 1:nsamples
            logits = model(x)
            probs = softmax(logits, dims = 1)
            push!(predictions, probs)
            # `loss` is a logit-based cross entropy, so it must see the raw
            # logits; passing `probs` would apply softmax twice.
            l[sample] += loss(logits, y) * nbatch
            acc[sample] += accuracy(probs, y)
            ece_list[sample] += ExpectedCalibrationError(probs |> cpu, labels) * nbatch
        end

        # Average the class probabilities over all samples
        stacked = Flux.batch(predictions)
        mean_predictions = mean(stacked, dims = ndims(stacked))
        mean_predictions = dropdims(mean_predictions, dims = ndims(mean_predictions))
        # The averaged predictions are probabilities, not logits, hence the
        # plain (non-logit) cross entropy.
        mean_l += Flux.Losses.crossentropy(mean_predictions, y) * nbatch
        mean_acc += accuracy(mean_predictions, y)
        mean_ece += ExpectedCalibrationError(mean_predictions |> cpu, labels) * nbatch
        ntot += nbatch
    end

    # Normalize the accumulated totals by the number of evaluated examples
    losses = [total / ntot |> round4 for total in l]
    acc = [a / ntot * 100 |> round4 for a in acc]
    ece_list = [e / ntot |> round4 for e in ece_list]
    mean_l = mean_l / ntot |> round4
    mean_acc = mean_acc / ntot * 100 |> round4
    mean_ece = mean_ece / ntot |> round4

    # Print the per-sample loss, accuracy and ECE
    for sample = 1:nsamples
        @info (format(
            "Sample {} Loss: {} Accuracy: {} ECE: {}",
            sample,
            losses[sample],
            acc[sample],
            ece_list[sample],
        ))
    end
    @info (format(
        "Mean Loss: {} Mean Accuracy: {} Mean ECE: {}",
        mean_l,
        mean_acc,
        mean_ece,
    ))
    @info "==========================================================="
    return nothing
end
128+
129+
## utility functions
130+
num_params(model) = sum(length, Flux.params(model))
131+
# Round `x` to four digits after the decimal point.
function round4(x)
    return round(x; digits = 4)
end
132+
133+
# arguments for the `train` function
134+
Base.@kwdef mutable struct Args
135+
η = 3e-4 # learning rate
136+
λ = 0 # L2 regularizer param, implemented as weight decay
137+
batchsize = 32 # batch size
138+
epochs = 10 # number of epochs
139+
seed = 0 # set seed > 0 for reproducibility
140+
use_cuda = true # if true use cuda (if available)
141+
infotime = 1 # report every `infotime` epochs
142+
checktime = 5 # Save the model every `checktime` epochs. Set to 0 for no checkpoints.
143+
dropout = 0.1
144+
sample_size = 10
145+
end
146+
147+
# Train the MC-dropout LeNet5 on MNIST.
#
# Keyword arguments are forwarded to `Args`; see that struct for the available
# settings and their defaults. Test metrics are reported once before training
# and then every `args.infotime` epochs.
function train(; kws...)
    args = Args(; kws...)
    args.seed > 0 && Random.seed!(args.seed)
    use_cuda = args.use_cuda && CUDA.functional()

    if use_cuda
        device = gpu
        @info "Training on GPU"
    else
        device = cpu
        @info "Training on CPU"
    end

    ## DATA
    train_loader, test_loader = get_data(args)
    @info "Dataset MNIST: $(train_loader.nobs) train and $(test_loader.nobs) test examples"

    ## MODEL AND OPTIMIZER
    model = LeNet5(args) |> device
    @info "LeNet5 model: $(num_params(model)) trainable params"

    ps = Flux.params(model)

    opt = ADAM(args.η)
    if args.λ > 0 # add weight decay, equivalent to L2 regularization
        opt = Optimiser(WeightDecay(args.λ), opt)
    end

    # `epoch` is accepted for symmetry with the call sites but is not used.
    function report(epoch)
        @info "Test metrics"
        eval_loss_accuracy(args, test_loader, model, device)
    end

    ## TRAINING
    @info "Start Training"
    report(0)
    for epoch = 1:args.epochs
        @showprogress for (x, y) in train_loader
            # Tile the batch `sample_size` times along the batch dimension so
            # a single pass evaluates several copies of every example.
            x = repeat(x, 1, 1, 1, args.sample_size)
            y = repeat(y, 1, args.sample_size)
            x, y = x |> device, y |> device
            gs = Flux.gradient(ps) do
                ŷ = model(x)
                loss(ŷ, y)
            end

            Flux.Optimise.update!(opt, ps, gs)
        end

        ## Printing and logging
        epoch % args.infotime == 0 && report(epoch)
    end
end
201+
202+
train()

src/DeepUncertainty.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
module DeepUncertainty
22

3-
# Write your package code here.
3+
# Export layers
4+
export MCLayer, MCDense, MCConv
5+
export mean_loglikelihood, brier_score, ExpectedCalibrationError, prediction_metrics
6+
7+
include("metrics.jl")
8+
include("layers/mclayers.jl")
49

510
end

src/layers/mclayers.jl

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
using Flux
2+
using Random
3+
using Test
4+
using Flux: @functor
5+
6+
"""
7+
MCLayer(layer, dropout)
8+
A generic Monte Carlo dropout layer. Takes in any "traditional" flux
9+
layer and a function that implements dropout. Performs the usual layer
10+
forward pass and then passes the activations through the given dropout function.
11+
"""
12+
struct MCLayer{L,F}
13+
layer::L
14+
dropout::F
15+
end
16+
17+
@functor MCLayer
18+
19+
"""
20+
MCDense(in, out, dropout_rate, σ=identity; bias=true, init=glorot_uniform)
21+
MCDense(layer, dropout_rate)
22+
23+
Creates a traditional dense layer with MC dropout functionality.
24+
MC Dropout simply means that dropout is activated in both train and test times
25+
26+
Reference - Dropout as a bayesian approximation - https://arxiv.org/abs/1506.02142
27+
28+
The traditional dense layer is a field in the struct MCDense, so all the
29+
arguments required for the dense layer can be provided, or the layer can
30+
be provided too. The forward pass is the affine transformation of the dense
31+
layer followed by dropout applied on the resulting activations.
32+
33+
y = dropout(σ.(W * x .+ bias), dropout_rate)
34+
35+
# Fields
36+
- `layer`: A traditional dense layer
37+
- `dropout`: A function that implements dropout
38+
39+
# Arguments
40+
- `in::Integer`: Input dimension of features
41+
- `out::Integer`: Output dimension of features
42+
- `dropout_rate::AbstractFloat`: Dropout rate
43+
- `σ::F=identity`: Activation function, defaults to identity
44+
- `init=glorot_normal`: Initialization function, defaults to glorot_normal
45+
"""
46+
function MCDense(in::Integer, out::Integer, dropout_rate, σ = identity, kwargs...)
47+
layer = Flux.Dense(in, out, σ; kwargs...)
48+
dropout = (x; k...) -> Flux.dropout(x, dropout_rate; k...)
49+
return MCLayer(layer, dropout)
50+
end
51+
52+
"""
53+
MCConv(filter, in => out, σ = identity;
54+
stride = 1, pad = 0, dilation = 1, groups = 1, [bias, weight, init])
55+
MCConv(layer, dropout_rate)
56+
57+
Creates a traditional Conv layer with MC dropout functionality.
58+
MC Dropout simply means that dropout is activated in both train and test times
59+
60+
Reference - Dropout as a bayesian approximation - https://arxiv.org/abs/1506.02142
61+
62+
The traditional conv layer is a field in the struct MCConv, so all the
63+
arguments required for the conv layer can be provided, or the layer can
64+
be provided too. The forward pass is the conv operation of the conv
65+
layer followed by dropout applied on the resulting activations.
66+
67+
y = dropout(Conv(x), dropout_rate)
68+
69+
# Fields
70+
- `layer`: A traditional conv layer
71+
- `dropout_rate::AbstractFloat`: Dropout rate
72+
73+
# Arguments
74+
- `filter::NTuple{N,Integer}`: Kernel dimensions, eg, (5, 5)
75+
- `ch::Pair{<:Integer,<:Integer}`: Input channels => output channels
76+
- `dropout_rate::AbstractFloat`: Dropout rate
77+
- `σ::F=identity`: Activation function, defaults to identity
78+
- `init=glorot_normal`: Initialization function, defaults to glorot_normal
79+
"""
80+
function MCConv(
81+
k::NTuple{N,Integer},
82+
ch::Pair{<:Integer,<:Integer},
83+
dropout_rate,
84+
σ = identity;
85+
kwargs...,
86+
) where {N}
87+
layer = Flux.Conv(k, ch, σ; kwargs...)
88+
dropout = (x; k...) -> Flux.dropout(x, dropout_rate; k...)
89+
return MCLayer(layer, dropout)
90+
end
91+
92+
function MCConv(
93+
w::AbstractArray{T,N},
94+
bias,
95+
dropout_rate,
96+
σ = identity,
97+
kwargs...,
98+
) where {T,N}
99+
layer = Flux.Conv(w, bias, σ, kwargs...)
100+
dropout = (x; k...) -> Flux.dropout(x, dropout_rate; k...)
101+
return MCLayer(layer, dropout)
102+
end
103+
104+
"""
105+
The forward pass of a MC layer: Passes the input through the
106+
usual layer first and then through a dropout layer.
107+
108+
# Arguments
109+
- `x`: Input tensors
110+
- `dropout=true`: Toggle to control dropout, it's preferred to keep
111+
dropout always on, but just in case if it's needed.
112+
"""
113+
function (mc::MCLayer)(x; dropout = true)
114+
# Layer forward pass
115+
# Dropout on activations
116+
output = mc.dropout(mc.layer(x); active = dropout)
117+
return output
118+
end

0 commit comments

Comments
 (0)