@@ -117,12 +117,12 @@ def build_vgg(self):
117
117
118
118
def build_generator (self ):
119
119
120
- def residual_block (layer_input ):
120
+ def residual_block (layer_input , filters ):
121
121
"""Residual block described in paper"""
122
- d = Conv2D (64 , kernel_size = 3 , strides = 1 , padding = 'same' )(layer_input )
122
+ d = Conv2D (filters , kernel_size = 3 , strides = 1 , padding = 'same' )(layer_input )
123
123
d = Activation ('relu' )(d )
124
124
d = BatchNormalization (momentum = 0.8 )(d )
125
- d = Conv2D (64 , kernel_size = 3 , strides = 1 , padding = 'same' )(d )
125
+ d = Conv2D (filters , kernel_size = 3 , strides = 1 , padding = 'same' )(d )
126
126
d = BatchNormalization (momentum = 0.8 )(d )
127
127
d = Add ()([d , layer_input ])
128
128
return d
@@ -144,7 +144,7 @@ def deconv2d(layer_input):
144
144
# Propogate through residual blocks
145
145
r = residual_block (c1 )
146
146
for _ in range (self .n_residual_blocks - 1 ):
147
- r = residual_block (r )
147
+ r = residual_block (r , self . gf )
148
148
149
149
# Post-residual block
150
150
c2 = Conv2D (64 , kernel_size = 3 , strides = 1 , padding = 'same' )(r )
0 commit comments