@@ -275,18 +275,18 @@ function resnet_stages(get_layers, block_repeats::AbstractVector{<:Integer}, con
275275 # Construct the blocks for each stage
276276 blocks = map(1 : nblocks) do block_idx
277277 branches = get_layers(stage_idx, block_idx)
278- return ( length(branches) == 1 ) ? only(branches) :
278+ return length(branches) == 1 ? only(branches) :
279279 Parallel(connection, branches... )
280280 end
281281 push!(stages, Chain(blocks... ))
282282 end
283283 return Chain(stages... )
284284end
285285
286- function resnet(img_dims, stem, builders , block_repeats:: AbstractVector{<:Integer} ,
286+ function resnet(img_dims, stem, get_layers , block_repeats:: AbstractVector{<:Integer} ,
287287 connection, classifier_fn)
288288 # Build stages of the ResNet
289- stage_blocks = resnet_stages(builders , block_repeats, connection)
289+ stage_blocks = resnet_stages(get_layers , block_repeats, connection)
290290 backbone = Chain(stem, stage_blocks)
291291 # Add classifier to the backbone
292292 nfeaturemaps = Flux. outputsize(backbone, img_dims; padbatch = true )[3 ]
@@ -308,17 +308,19 @@ function resnet(block_type, block_repeats::AbstractVector{<:Integer},
308308 if block_type == basicblock
309309 @assert cardinality== 1 " Cardinality must be 1 for `basicblock`"
310310 @assert base_width== 64 " Base width must be 64 for `basicblock`"
311- builder = basicblock_builder(block_repeats; inplanes, reduction_factor,
312- activation, norm_layer, revnorm, attn_fn,
313- drop_block_rate, drop_path_rate,
314- stride_fn = resnet_stride, planes_fn = resnet_planes,
315- downsample_tuple = downsample_opt, kwargs... )
311+ get_layers = basicblock_builder(block_repeats; inplanes, reduction_factor,
312+ activation, norm_layer, revnorm, attn_fn,
313+ drop_block_rate, drop_path_rate,
314+ stride_fn = resnet_stride,
315+ planes_fn = resnet_planes,
316+ downsample_tuple = downsample_opt, kwargs... )
316317 elseif block_type == bottleneck
317- builder = bottleneck_builder(block_repeats; inplanes, cardinality,
318- base_width, reduction_factor, activation, norm_layer,
319- revnorm, attn_fn, drop_block_rate, drop_path_rate,
320- stride_fn = resnet_stride, planes_fn = resnet_planes,
321- downsample_tuple = downsample_opt, kwargs... )
318+ get_layers = bottleneck_builder(block_repeats; inplanes, cardinality, base_width,
319+ reduction_factor, activation, norm_layer, revnorm,
320+ attn_fn, drop_block_rate, drop_path_rate,
321+ stride_fn = resnet_stride,
322+ planes_fn = resnet_planes,
323+ downsample_tuple = downsample_opt, kwargs... )
322324 elseif block_type == bottle2neck
323325 @assert isnothing(drop_block_rate)
324326 " DropBlock not supported for `bottle2neck`. Set `drop_block_rate` to nothing"
@@ -337,8 +339,8 @@ function resnet(block_type, block_repeats::AbstractVector{<:Integer},
337339 end
338340 classifier_fn = nfeatures -> create_classifier(nfeatures, nclasses; dropout_rate,
339341 pool_layer, use_conv)
340- return resnet((imsize... , inchannels), stem, fill(builder, length( block_repeats)) ,
341- block_repeats, connection$ activation, classifier_fn)
342+ return resnet((imsize... , inchannels), stem, get_layers, block_repeats,
343+ connection$ activation, classifier_fn)
342344end
343345function resnet(block_fn, block_repeats, downsample_opt:: Symbol = :B; kwargs... )
344346 return resnet(block_fn, block_repeats, RESNET_SHORTCUTS[downsample_opt]; kwargs... )
0 commit comments