Skip to content

Commit ff969fe

Browse files
authored
Monocular depth estimation - Keras 3 Migration (Only Tensorflow Backend) (#1910)
* Monocular depth estimation - Keras 3 Migration (Only Tensorflow Backend)
* trim output
* Added PyDataset
1 parent 4301f03 commit ff969fe

File tree

7 files changed

+236
-165
lines changed

7 files changed

+236
-165
lines changed

examples/vision/depth_estimation.py

Lines changed: 62 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
Title: Monocular depth estimation
33
Author: [Victor Basu](https://www.linkedin.com/in/victor-basu-520958147)
44
Date created: 2021/08/30
5-
Last modified: 2021/08/30
5+
Last modified: 2024/08/13
66
Description: Implement a depth estimation model with a convnet.
77
Accelerator: GPU
88
"""
@@ -25,17 +25,21 @@
2525
"""
2626

2727
import os
28+
29+
os.environ["KERAS_BACKEND"] = "tensorflow"
30+
2831
import sys
2932

3033
import tensorflow as tf
31-
from tensorflow.keras import layers
32-
34+
import keras
35+
from keras import layers
36+
from keras import ops
3337
import pandas as pd
3438
import numpy as np
3539
import cv2
3640
import matplotlib.pyplot as plt
3741

38-
tf.random.set_seed(123)
42+
keras.utils.set_random_seed(123)
3943

4044
"""
4145
## Downloading the dataset
@@ -52,7 +56,7 @@
5256

5357
annotation_folder = "/dataset/"
5458
if not os.path.exists(os.path.abspath(".") + annotation_folder):
55-
annotation_zip = tf.keras.utils.get_file(
59+
annotation_zip = keras.utils.get_file(
5660
"val.tar.gz",
5761
cache_subdir=os.path.abspath("."),
5862
origin="http://diode-dataset.s3.amazonaws.com/val.tar.gz",
@@ -89,7 +93,7 @@
8993

9094
HEIGHT = 256
9195
WIDTH = 256
92-
LR = 0.0002
96+
LR = 0.00001
9397
EPOCHS = 30
9498
BATCH_SIZE = 32
9599

@@ -105,8 +109,9 @@
105109
"""
106110

107111

108-
class DataGenerator(tf.keras.utils.Sequence):
112+
class DataGenerator(keras.utils.PyDataset):
109113
def __init__(self, data, batch_size=6, dim=(768, 1024), n_channels=3, shuffle=True):
114+
super().__init__()
110115
"""
111116
Initialization
112117
"""
@@ -178,7 +183,7 @@ def data_generation(self, batch):
178183
self.data["depth"][batch_id],
179184
self.data["mask"][batch_id],
180185
)
181-
186+
x, y = x.astype("float32"), y.astype("float32")
182187
return x, y
183188

184189

@@ -249,10 +254,10 @@ def __init__(
249254
super().__init__(**kwargs)
250255
self.convA = layers.Conv2D(filters, kernel_size, strides, padding)
251256
self.convB = layers.Conv2D(filters, kernel_size, strides, padding)
252-
self.reluA = layers.LeakyReLU(alpha=0.2)
253-
self.reluB = layers.LeakyReLU(alpha=0.2)
254-
self.bn2a = tf.keras.layers.BatchNormalization()
255-
self.bn2b = tf.keras.layers.BatchNormalization()
257+
self.reluA = layers.LeakyReLU(negative_slope=0.2)
258+
self.reluB = layers.LeakyReLU(negative_slope=0.2)
259+
self.bn2a = layers.BatchNormalization()
260+
self.bn2b = layers.BatchNormalization()
256261

257262
self.pool = layers.MaxPool2D((2, 2), (2, 2))
258263

@@ -278,10 +283,10 @@ def __init__(
278283
self.us = layers.UpSampling2D((2, 2))
279284
self.convA = layers.Conv2D(filters, kernel_size, strides, padding)
280285
self.convB = layers.Conv2D(filters, kernel_size, strides, padding)
281-
self.reluA = layers.LeakyReLU(alpha=0.2)
282-
self.reluB = layers.LeakyReLU(alpha=0.2)
283-
self.bn2a = tf.keras.layers.BatchNormalization()
284-
self.bn2b = tf.keras.layers.BatchNormalization()
286+
self.reluA = layers.LeakyReLU(negative_slope=0.2)
287+
self.reluB = layers.LeakyReLU(negative_slope=0.2)
288+
self.bn2a = layers.BatchNormalization()
289+
self.bn2b = layers.BatchNormalization()
285290
self.conc = layers.Concatenate()
286291

287292
def call(self, x, skip):
@@ -305,8 +310,8 @@ def __init__(
305310
super().__init__(**kwargs)
306311
self.convA = layers.Conv2D(filters, kernel_size, strides, padding)
307312
self.convB = layers.Conv2D(filters, kernel_size, strides, padding)
308-
self.reluA = layers.LeakyReLU(alpha=0.2)
309-
self.reluB = layers.LeakyReLU(alpha=0.2)
313+
self.reluA = layers.LeakyReLU(negative_slope=0.2)
314+
self.reluB = layers.LeakyReLU(negative_slope=0.2)
310315

311316
def call(self, x):
312317
x = self.convA(x)
@@ -328,13 +333,39 @@ def call(self, x):
328333
"""
329334

330335

331-
class DepthEstimationModel(tf.keras.Model):
336+
def image_gradients(image):
337+
if len(ops.shape(image)) != 4:
338+
raise ValueError(
339+
"image_gradients expects a 4D tensor "
340+
"[batch_size, h, w, d], not {}.".format(ops.shape(image))
341+
)
342+
343+
image_shape = ops.shape(image)
344+
batch_size, height, width, depth = ops.unstack(image_shape)
345+
346+
dy = image[:, 1:, :, :] - image[:, :-1, :, :]
347+
dx = image[:, :, 1:, :] - image[:, :, :-1, :]
348+
349+
# Return tensors with same size as original image by concatenating
350+
# zeros. Place the gradient [I(x+1,y) - I(x,y)] on the base pixel (x, y).
351+
shape = ops.stack([batch_size, 1, width, depth])
352+
dy = ops.concatenate([dy, ops.zeros(shape, dtype=image.dtype)], axis=1)
353+
dy = ops.reshape(dy, image_shape)
354+
355+
shape = ops.stack([batch_size, height, 1, depth])
356+
dx = ops.concatenate([dx, ops.zeros(shape, dtype=image.dtype)], axis=2)
357+
dx = ops.reshape(dx, image_shape)
358+
359+
return dy, dx
360+
361+
362+
class DepthEstimationModel(keras.Model):
332363
def __init__(self):
333364
super().__init__()
334365
self.ssim_loss_weight = 0.85
335366
self.l1_loss_weight = 0.1
336367
self.edge_loss_weight = 0.9
337-
self.loss_metric = tf.keras.metrics.Mean(name="loss")
368+
self.loss_metric = keras.metrics.Mean(name="loss")
338369
f = [16, 32, 64, 128, 256]
339370
self.downscale_blocks = [
340371
DownscaleBlock(f[0]),
@@ -353,28 +384,28 @@ def __init__(self):
353384

354385
def calculate_loss(self, target, pred):
355386
# Edges
356-
dy_true, dx_true = tf.image.image_gradients(target)
357-
dy_pred, dx_pred = tf.image.image_gradients(pred)
358-
weights_x = tf.exp(tf.reduce_mean(tf.abs(dx_true)))
359-
weights_y = tf.exp(tf.reduce_mean(tf.abs(dy_true)))
387+
dy_true, dx_true = image_gradients(target)
388+
dy_pred, dx_pred = image_gradients(pred)
389+
weights_x = ops.cast(ops.exp(ops.mean(ops.abs(dx_true))), "float32")
390+
weights_y = ops.cast(ops.exp(ops.mean(ops.abs(dy_true))), "float32")
360391

361392
# Depth smoothness
362393
smoothness_x = dx_pred * weights_x
363394
smoothness_y = dy_pred * weights_y
364395

365-
depth_smoothness_loss = tf.reduce_mean(abs(smoothness_x)) + tf.reduce_mean(
396+
depth_smoothness_loss = ops.mean(abs(smoothness_x)) + ops.mean(
366397
abs(smoothness_y)
367398
)
368399

369400
# Structural similarity (SSIM) index
370-
ssim_loss = tf.reduce_mean(
401+
ssim_loss = ops.mean(
371402
1
372403
- tf.image.ssim(
373404
target, pred, max_val=WIDTH, filter_size=7, k1=0.01**2, k2=0.03**2
374405
)
375406
)
376407
# Point-wise depth
377-
l1_loss = tf.reduce_mean(tf.abs(target - pred))
408+
l1_loss = ops.mean(ops.abs(target - pred))
378409

379410
loss = (
380411
(self.ssim_loss_weight * ssim_loss)
@@ -432,9 +463,9 @@ def call(self, x):
432463
## Model training
433464
"""
434465

435-
optimizer = tf.keras.optimizers.Adam(
466+
optimizer = keras.optimizers.SGD(
436467
learning_rate=LR,
437-
amsgrad=False,
468+
nesterov=False,
438469
)
439470
model = DepthEstimationModel()
440471
# Compile the model
@@ -491,9 +522,9 @@ def call(self, x):
491522
## References
492523
493524
The following papers go deeper into possible approaches for depth estimation.
494-
1. [Depth Prediction Without the Sensors: Leveraging Structure for Unsupervised Learning from Monocular Videos](https://arxiv.org/pdf/1811.06152v1.pdf)
525+
1. [Depth Prediction Without the Sensors: Leveraging Structure for Unsupervised Learning from Monocular Videos](https://arxiv.org/abs/1811.06152v1)
495526
2. [Digging Into Self-Supervised Monocular Depth Estimation](https://openaccess.thecvf.com/content_ICCV_2019/papers/Godard_Digging_Into_Self-Supervised_Monocular_Depth_Estimation_ICCV_2019_paper.pdf)
496-
3. [Deeper Depth Prediction with Fully Convolutional Residual Networks](https://arxiv.org/pdf/1606.00373v2.pdf)
527+
3. [Deeper Depth Prediction with Fully Convolutional Residual Networks](https://arxiv.org/abs/1606.00373v2)
497528
498529
You can also find helpful implementations in the papers with code depth estimation task.
499530
1.11 MB
Loading
84.4 KB
Loading
4.18 MB
Loading
4.52 MB
Loading

0 commit comments

Comments
 (0)