22from keras import layers
33from typing import List , Tuple
44
5+
56def ResidualBlock (width : int ) -> layers .Layer :
67 def apply (x : layers .Layer ) -> layers .Layer :
7-
88 input_width = x .shape [3 ]
99 residual = x if input_width == width else layers .Conv2D (width , kernel_size = 1 )(x )
10-
10+
1111 x = layers .LayerNormalization (axis = - 1 , center = True , scale = True )(x )
1212 x = layers .Conv2D (
1313 width , kernel_size = 3 , padding = "same" , activation = keras .activations .swish
@@ -18,9 +18,11 @@ def apply(x: layers.Layer) -> layers.Layer:
1818
1919 return apply
2020
21+
2122def DownBlock (width : int , block_depth : int ) -> layers .Layer :
22- def apply (x : Tuple [layers .Layer , List [layers .Layer ]]) -> Tuple [layers .Layer , List [layers .Layer ]]:
23-
23+ def apply (
24+ x : Tuple [layers .Layer , List [layers .Layer ]],
25+ ) -> Tuple [layers .Layer , List [layers .Layer ]]:
2426 x , skips = x
2527 for _ in range (block_depth ):
2628 x = ResidualBlock (width )(x )
@@ -30,26 +32,30 @@ def apply(x: Tuple[layers.Layer, List[layers.Layer]]) -> Tuple[layers.Layer, Lis
3032
3133 return apply
3234
35+
3336def UpBlock (width : int , block_depth : int ) -> layers .Layer :
3437 def apply (x : Tuple [layers .Layer , List [layers .Layer ]]) -> layers .Layer :
35-
3638 x , skips = x
3739 x = layers .UpSampling2D (size = 2 , interpolation = "bilinear" )(x )
3840 for _ in range (block_depth ):
3941 x = layers .Concatenate ()([x , skips .pop ()])
4042 x = ResidualBlock (width )(x )
4143 return x
44+
4245 return apply
4346
44- def get_model (image_height : int ,
45- image_width : int ,
46- input_frames : int ,
47- output_frames : int ,
48- down_widths : List [int ] = [64 , 128 , 256 ],
49- up_widths : List [int ] = [256 , 128 , 64 ],
50- block_depth : int = 2 ) -> keras .Model :
47+
48+ def get_model (
49+ image_height : int ,
50+ image_width : int ,
51+ input_frames : int ,
52+ output_frames : int ,
53+ down_widths : List [int ] = [64 , 128 , 256 ],
54+ up_widths : List [int ] = [256 , 128 , 64 ],
55+ block_depth : int = 2 ,
56+ ) -> keras .Model :
5157 """Builds the U-Net like model with residual blocks and skip connections."""
52-
58+
5359 inputs = keras .Input (shape = (image_height , image_width , input_frames ))
5460 x = layers .Conv2D (down_widths [0 ], kernel_size = 1 )(inputs )
5561
@@ -64,5 +70,5 @@ def get_model(image_height: int,
6470 x = UpBlock (width , block_depth )([x , skips ])
6571
6672 outputs = layers .Conv2D (output_frames , kernel_size = 1 , kernel_initializer = "zeros" )(x )
67-
73+
6874 return keras .Model (inputs , outputs , name = "residual_unet" )
0 commit comments