22Title: Monocular depth estimation
33Author: [Victor Basu](https://www.linkedin.com/in/victor-basu-520958147)
44Date created: 2021/08/30
5- Last modified: 2021/08/30
5+ Last modified: 2024/08/13
66Description: Implement a depth estimation model with a convnet.
77Accelerator: GPU
88"""
2525"""
2626
2727import os
28+
29+ os .environ ["KERAS_BACKEND" ] = "tensorflow"
30+
2831import sys
2932
3033import tensorflow as tf
31- from tensorflow .keras import layers
32-
34+ import keras
35+ from keras import layers
36+ from keras import ops
3337import pandas as pd
3438import numpy as np
3539import cv2
3640import matplotlib .pyplot as plt
3741
38- tf . random . set_seed (123 )
42+ keras . utils . set_random_seed (123 )
3943
4044"""
4145## Downloading the dataset
5256
5357annotation_folder = "/dataset/"
5458if not os .path .exists (os .path .abspath ("." ) + annotation_folder ):
55- annotation_zip = tf . keras .utils .get_file (
59+ annotation_zip = keras .utils .get_file (
5660 "val.tar.gz" ,
5761 cache_subdir = os .path .abspath ("." ),
5862 origin = "http://diode-dataset.s3.amazonaws.com/val.tar.gz" ,
8993
9094HEIGHT = 256
9195WIDTH = 256
92- LR = 0.0002
96+ LR = 0.00001
9397EPOCHS = 30
9498BATCH_SIZE = 32
9599
105109"""
106110
107111
108- class DataGenerator (tf . keras .utils .Sequence ):
112+ class DataGenerator (keras .utils .PyDataset ):
109113 def __init__ (self , data , batch_size = 6 , dim = (768 , 1024 ), n_channels = 3 , shuffle = True ):
114+ super ().__init__ ()
110115 """
111116 Initialization
112117 """
@@ -178,7 +183,7 @@ def data_generation(self, batch):
178183 self .data ["depth" ][batch_id ],
179184 self .data ["mask" ][batch_id ],
180185 )
181-
186+ x , y = x . astype ( "float32" ), y . astype ( "float32" )
182187 return x , y
183188
184189
@@ -249,10 +254,10 @@ def __init__(
249254 super ().__init__ (** kwargs )
250255 self .convA = layers .Conv2D (filters , kernel_size , strides , padding )
251256 self .convB = layers .Conv2D (filters , kernel_size , strides , padding )
252- self .reluA = layers .LeakyReLU (alpha = 0.2 )
253- self .reluB = layers .LeakyReLU (alpha = 0.2 )
254- self .bn2a = tf . keras . layers .BatchNormalization ()
255- self .bn2b = tf . keras . layers .BatchNormalization ()
257+ self .reluA = layers .LeakyReLU (negative_slope = 0.2 )
258+ self .reluB = layers .LeakyReLU (negative_slope = 0.2 )
259+ self .bn2a = layers .BatchNormalization ()
260+ self .bn2b = layers .BatchNormalization ()
256261
257262 self .pool = layers .MaxPool2D ((2 , 2 ), (2 , 2 ))
258263
@@ -278,10 +283,10 @@ def __init__(
278283 self .us = layers .UpSampling2D ((2 , 2 ))
279284 self .convA = layers .Conv2D (filters , kernel_size , strides , padding )
280285 self .convB = layers .Conv2D (filters , kernel_size , strides , padding )
281- self .reluA = layers .LeakyReLU (alpha = 0.2 )
282- self .reluB = layers .LeakyReLU (alpha = 0.2 )
283- self .bn2a = tf . keras . layers .BatchNormalization ()
284- self .bn2b = tf . keras . layers .BatchNormalization ()
286+ self .reluA = layers .LeakyReLU (negative_slope = 0.2 )
287+ self .reluB = layers .LeakyReLU (negative_slope = 0.2 )
288+ self .bn2a = layers .BatchNormalization ()
289+ self .bn2b = layers .BatchNormalization ()
285290 self .conc = layers .Concatenate ()
286291
287292 def call (self , x , skip ):
@@ -305,8 +310,8 @@ def __init__(
305310 super ().__init__ (** kwargs )
306311 self .convA = layers .Conv2D (filters , kernel_size , strides , padding )
307312 self .convB = layers .Conv2D (filters , kernel_size , strides , padding )
308- self .reluA = layers .LeakyReLU (alpha = 0.2 )
309- self .reluB = layers .LeakyReLU (alpha = 0.2 )
313+ self .reluA = layers .LeakyReLU (negative_slope = 0.2 )
314+ self .reluB = layers .LeakyReLU (negative_slope = 0.2 )
310315
311316 def call (self , x ):
312317 x = self .convA (x )
@@ -328,13 +333,39 @@ def call(self, x):
328333"""
329334
330335
331- class DepthEstimationModel (tf .keras .Model ):
def image_gradients(image):
    """Return forward-difference gradients of a batch of images.

    Args:
        image: A rank-4 tensor. The error message below describes the
            expected layout as `[batch_size, h, w, d]`.

    Returns:
        A `(dy, dx)` tuple of tensors with the same shape as `image`.
        `dy[:, i]` holds `image[:, i + 1] - image[:, i]` and `dx[:, :, j]`
        holds `image[:, :, j + 1] - image[:, :, j]`. The last row of `dy`
        and the last column of `dx` are zero.

    Raises:
        ValueError: If `image` is not rank 4.
    """
    image_shape = ops.shape(image)
    if len(image_shape) != 4:
        raise ValueError(
            "image_gradients expects a 4D tensor "
            "[batch_size, h, w, d], not {}.".format(image_shape)
        )

    dy = image[:, 1:, :, :] - image[:, :-1, :, :]
    dx = image[:, :, 1:, :] - image[:, :, :-1, :]

    # Return tensors with the same size as the original image by appending
    # zeros, so the gradient [I(x+1, y) - I(x, y)] sits on the base pixel
    # (x, y). Padding one element at the end of the differenced axis yields
    # the same result as concatenating an explicit zeros tensor, without
    # having to rebuild the target shape from its components.
    dy = ops.pad(dy, [[0, 0], [0, 1], [0, 0], [0, 0]])
    dx = ops.pad(dx, [[0, 0], [0, 0], [0, 1], [0, 0]])

    return dy, dx
360+
361+
362+ class DepthEstimationModel (keras .Model ):
332363 def __init__ (self ):
333364 super ().__init__ ()
334365 self .ssim_loss_weight = 0.85
335366 self .l1_loss_weight = 0.1
336367 self .edge_loss_weight = 0.9
337- self .loss_metric = tf . keras .metrics .Mean (name = "loss" )
368+ self .loss_metric = keras .metrics .Mean (name = "loss" )
338369 f = [16 , 32 , 64 , 128 , 256 ]
339370 self .downscale_blocks = [
340371 DownscaleBlock (f [0 ]),
@@ -353,28 +384,28 @@ def __init__(self):
353384
354385 def calculate_loss (self , target , pred ):
355386 # Edges
356- dy_true , dx_true = tf . image . image_gradients (target )
357- dy_pred , dx_pred = tf . image . image_gradients (pred )
358- weights_x = tf . exp (tf . reduce_mean ( tf .abs (dx_true )))
359- weights_y = tf . exp (tf . reduce_mean ( tf .abs (dy_true )))
387+ dy_true , dx_true = image_gradients (target )
388+ dy_pred , dx_pred = image_gradients (pred )
389+ weights_x = ops . cast ( ops . exp (ops . mean ( ops .abs (dx_true ))), "float32" )
390+ weights_y = ops . cast ( ops . exp (ops . mean ( ops .abs (dy_true ))), "float32" )
360391
361392 # Depth smoothness
362393 smoothness_x = dx_pred * weights_x
363394 smoothness_y = dy_pred * weights_y
364395
365- depth_smoothness_loss = tf . reduce_mean (abs (smoothness_x )) + tf . reduce_mean (
396+ depth_smoothness_loss = ops . mean (abs (smoothness_x )) + ops . mean (
366397 abs (smoothness_y )
367398 )
368399
369400 # Structural similarity (SSIM) index
370- ssim_loss = tf . reduce_mean (
401+ ssim_loss = ops . mean (
371402 1
372403 - tf .image .ssim (
373404 target , pred , max_val = WIDTH , filter_size = 7 , k1 = 0.01 ** 2 , k2 = 0.03 ** 2
374405 )
375406 )
376407 # Point-wise depth
377- l1_loss = tf . reduce_mean ( tf .abs (target - pred ))
408+ l1_loss = ops . mean ( ops .abs (target - pred ))
378409
379410 loss = (
380411 (self .ssim_loss_weight * ssim_loss )
@@ -432,9 +463,9 @@ def call(self, x):
432463## Model training
433464"""
434465
435- optimizer = tf . keras .optimizers .Adam (
466+ optimizer = keras .optimizers .SGD (
436467 learning_rate = LR ,
437- amsgrad = False ,
468+ nesterov = False ,
438469)
439470model = DepthEstimationModel ()
440471# Compile the model
@@ -491,9 +522,9 @@ def call(self, x):
491522## References
492523
493524The following papers go deeper into possible approaches for depth estimation.
494- 1. [Depth Prediction Without the Sensors: Leveraging Structure for Unsupervised Learning from Monocular Videos](https://arxiv.org/pdf/1811.06152v1.pdf)
525+ 1. [Depth Prediction Without the Sensors: Leveraging Structure for Unsupervised Learning from Monocular Videos](https://arxiv.org/abs/1811.06152v1)
4955262. [Digging Into Self-Supervised Monocular Depth Estimation](https://openaccess.thecvf.com/content_ICCV_2019/papers/Godard_Digging_Into_Self-Supervised_Monocular_Depth_Estimation_ICCV_2019_paper.pdf)
496- 3. [Deeper Depth Prediction with Fully Convolutional Residual Networks](https://arxiv.org/pdf/1606.00373v2.pdf)
527+ 3. [Deeper Depth Prediction with Fully Convolutional Residual Networks](https://arxiv.org/abs/1606.00373v2)
497528
498529You can also find helpful implementations in the papers with code depth estimation task.
499530
0 commit comments