@@ -275,18 +275,18 @@ function resnet_stages(get_layers, block_repeats::AbstractVector{<:Integer}, con
275
275
# Construct the blocks for each stage
276
276
blocks = map (1 : nblocks) do block_idx
277
277
branches = get_layers (stage_idx, block_idx)
278
- return ( length (branches) == 1 ) ? only (branches) :
278
+ return length (branches) == 1 ? only (branches) :
279
279
Parallel (connection, branches... )
280
280
end
281
281
push! (stages, Chain (blocks... ))
282
282
end
283
283
return Chain (stages... )
284
284
end
285
285
286
- function resnet (img_dims, stem, builders , block_repeats:: AbstractVector{<:Integer} ,
286
+ function resnet (img_dims, stem, get_layers , block_repeats:: AbstractVector{<:Integer} ,
287
287
connection, classifier_fn)
288
288
# Build stages of the ResNet
289
- stage_blocks = resnet_stages (builders , block_repeats, connection)
289
+ stage_blocks = resnet_stages (get_layers , block_repeats, connection)
290
290
backbone = Chain (stem, stage_blocks)
291
291
# Add classifier to the backbone
292
292
nfeaturemaps = Flux. outputsize (backbone, img_dims; padbatch = true )[3 ]
@@ -308,17 +308,19 @@ function resnet(block_type, block_repeats::AbstractVector{<:Integer},
308
308
if block_type == basicblock
309
309
@assert cardinality== 1 " Cardinality must be 1 for `basicblock`"
310
310
@assert base_width== 64 " Base width must be 64 for `basicblock`"
311
- builder = basicblock_builder (block_repeats; inplanes, reduction_factor,
312
- activation, norm_layer, revnorm, attn_fn,
313
- drop_block_rate, drop_path_rate,
314
- stride_fn = resnet_stride, planes_fn = resnet_planes,
315
- downsample_tuple = downsample_opt, kwargs... )
311
+ get_layers = basicblock_builder (block_repeats; inplanes, reduction_factor,
312
+ activation, norm_layer, revnorm, attn_fn,
313
+ drop_block_rate, drop_path_rate,
314
+ stride_fn = resnet_stride,
315
+ planes_fn = resnet_planes,
316
+ downsample_tuple = downsample_opt, kwargs... )
316
317
elseif block_type == bottleneck
317
- builder = bottleneck_builder (block_repeats; inplanes, cardinality,
318
- base_width, reduction_factor, activation, norm_layer,
319
- revnorm, attn_fn, drop_block_rate, drop_path_rate,
320
- stride_fn = resnet_stride, planes_fn = resnet_planes,
321
- downsample_tuple = downsample_opt, kwargs... )
318
+ get_layers = bottleneck_builder (block_repeats; inplanes, cardinality, base_width,
319
+ reduction_factor, activation, norm_layer, revnorm,
320
+ attn_fn, drop_block_rate, drop_path_rate,
321
+ stride_fn = resnet_stride,
322
+ planes_fn = resnet_planes,
323
+ downsample_tuple = downsample_opt, kwargs... )
322
324
elseif block_type == bottle2neck
323
325
@assert isnothing (drop_block_rate)
324
326
" DropBlock not supported for `bottle2neck`. Set `drop_block_rate` to nothing"
@@ -337,8 +339,8 @@ function resnet(block_type, block_repeats::AbstractVector{<:Integer},
337
339
end
338
340
classifier_fn = nfeatures -> create_classifier (nfeatures, nclasses; dropout_rate,
339
341
pool_layer, use_conv)
340
- return resnet ((imsize... , inchannels), stem, fill (builder, length ( block_repeats)) ,
341
- block_repeats, connection$ activation, classifier_fn)
342
+ return resnet ((imsize... , inchannels), stem, get_layers, block_repeats,
343
+ connection$ activation, classifier_fn)
342
344
end
343
345
function resnet (block_fn, block_repeats, downsample_opt:: Symbol = :B ; kwargs... )
344
346
return resnet (block_fn, block_repeats, RESNET_SHORTCUTS[downsample_opt]; kwargs... )
0 commit comments