@@ -60,15 +60,16 @@ function efficientnet(config::Symbol; norm_layer = BatchNorm, stochastic_depth_p
6060end
6161
6262"""
63- EfficientNet(config::Symbol; pretrain::Bool = false, inchannels::Integer = 3,
63+ EfficientNet(config::Symbol; pretrain::Union{ Bool,String} = false, inchannels::Integer = 3,
6464 nclasses::Integer = 1000)
6565
6666Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)).
6767
6868# Arguments
6969
7070 - `config`: size of the model. Can be one of `[:b0, :b1, :b2, :b3, :b4, :b5, :b6, :b7, :b8]`.
71- - `pretrain`: set to `true` to load the pre-trained weights for ImageNet
71+ - `pretrain`: set to `true` to load the pre-trained weights for ImageNet, or provide a local path string to load a
72+ custom weights file.
7273 - `inchannels`: number of input channels.
7374 - `nclasses`: number of output classes.
7475
@@ -83,12 +84,16 @@ struct EfficientNet
8384end
8485@functor EfficientNet
8586
86- function EfficientNet (config:: Symbol ; pretrain:: Bool = false , inchannels:: Integer = 3 ,
87+ function EfficientNet (config:: Symbol ; pretrain:: Union{ Bool,String} = false , inchannels:: Integer = 3 ,
8788 nclasses:: Integer = 1000 )
8889 layers = efficientnet (config; inchannels, nclasses)
8990 model = EfficientNet (layers)
90- if pretrain
91+ if pretrain === true
9192 loadpretrain! (model, string (" efficientnet_" , config))
93+ elseif pretrain isa String
94+ isfile (pretrain) || error (" Weights file does not exist at `pretrain`" )
95+ m = load_weights_file (pretrain)
96+ Flux. loadmodel! (model, m)
9297 end
9398 return model
9499end
0 commit comments