Skip to content

Commit 695a0b5

Browse files
authored
Enhanced Deep Residual Networks for single-image super-resolution - Keras 3 migration (Only Tensorflow Backend) (#1920)
* Keras 3 migration * trim output
1 parent 3117146 commit 695a0b5

File tree

9 files changed

+146
-311
lines changed

9 files changed

+146
-311
lines changed

examples/vision/edsr.py

Lines changed: 33 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
Title: Enhanced Deep Residual Networks for single-image super-resolution
33
Author: Gitesh Chawda
44
Date created: 2022/04/07
5-
Last modified: 2022/04/07
5+
Last modified: 2024/08/27
66
Description: Training an EDSR model on the DIV2K Dataset.
77
Accelerator: GPU
88
"""
@@ -40,14 +40,18 @@
4040
"""
4141
## Imports
4242
"""
43+
import os
44+
45+
os.environ["KERAS_BACKEND"] = "tensorflow"
4346

4447
import numpy as np
4548
import tensorflow as tf
4649
import tensorflow_datasets as tfds
4750
import matplotlib.pyplot as plt
4851

49-
from tensorflow import keras
50-
from tensorflow.keras import layers
52+
import keras
53+
from keras import layers
54+
from keras import ops
5155

5256
AUTOTUNE = tf.data.AUTOTUNE
5357

@@ -81,15 +85,15 @@ def flip_left_right(lowres_img, highres_img):
8185
"""Flips Images to left and right."""
8286

8387
# Outputs random values from a uniform distribution in between 0 to 1
84-
rn = tf.random.uniform(shape=(), maxval=1)
88+
rn = keras.random.uniform(shape=(), maxval=1)
8589
# If rn is less than 0.5 it returns original lowres_img and highres_img
8690
# If rn is greater than 0.5 it returns flipped image
87-
return tf.cond(
91+
return ops.cond(
8892
rn < 0.5,
8993
lambda: (lowres_img, highres_img),
9094
lambda: (
91-
tf.image.flip_left_right(lowres_img),
92-
tf.image.flip_left_right(highres_img),
95+
ops.flip(lowres_img),
96+
ops.flip(highres_img),
9397
),
9498
)
9599

@@ -98,7 +102,9 @@ def random_rotate(lowres_img, highres_img):
98102
"""Rotates Images by 90 degrees."""
99103

100104
# Outputs random values from uniform distribution in between 0 to 4
101-
rn = tf.random.uniform(shape=(), maxval=4, dtype=tf.int32)
105+
rn = ops.cast(
106+
keras.random.uniform(shape=(), maxval=4, dtype="float32"), dtype="int32"
107+
)
102108
# Here rn signifies number of times the image(s) are rotated by 90 degrees
103109
return tf.image.rot90(lowres_img, rn), tf.image.rot90(highres_img, rn)
104110

@@ -110,13 +116,19 @@ def random_crop(lowres_img, highres_img, hr_crop_size=96, scale=4):
110116
high resolution images: 96x96
111117
"""
112118
lowres_crop_size = hr_crop_size // scale # 96//4=24
113-
lowres_img_shape = tf.shape(lowres_img)[:2] # (height,width)
119+
lowres_img_shape = ops.shape(lowres_img)[:2] # (height,width)
114120

115-
lowres_width = tf.random.uniform(
116-
shape=(), maxval=lowres_img_shape[1] - lowres_crop_size + 1, dtype=tf.int32
121+
lowres_width = ops.cast(
122+
keras.random.uniform(
123+
shape=(), maxval=lowres_img_shape[1] - lowres_crop_size + 1, dtype="float32"
124+
),
125+
dtype="int32",
117126
)
118-
lowres_height = tf.random.uniform(
119-
shape=(), maxval=lowres_img_shape[0] - lowres_crop_size + 1, dtype=tf.int32
127+
lowres_height = ops.cast(
128+
keras.random.uniform(
129+
shape=(), maxval=lowres_img_shape[0] - lowres_crop_size + 1, dtype="float32"
130+
),
131+
dtype="int32",
120132
)
121133

122134
highres_width = lowres_width * scale
@@ -218,7 +230,7 @@ def PSNR(super_resolution, high_resolution):
218230
"""
219231

220232

221-
class EDSRModel(tf.keras.Model):
233+
class EDSRModel(keras.Model):
222234
def train_step(self, data):
223235
# Unpack the data. Its structure depends on your model and
224236
# on what you pass to `fit()`.
@@ -242,16 +254,16 @@ def train_step(self, data):
242254

243255
def predict_step(self, x):
244256
# Adding dummy dimension using tf.expand_dims and converting to float32 using tf.cast
245-
x = tf.cast(tf.expand_dims(x, axis=0), tf.float32)
257+
x = ops.cast(tf.expand_dims(x, axis=0), dtype="float32")
246258
# Passing low resolution image to model
247259
super_resolution_img = self(x, training=False)
248260
# Clips the tensor from min(0) to max(255)
249-
super_resolution_img = tf.clip_by_value(super_resolution_img, 0, 255)
261+
super_resolution_img = ops.clip(super_resolution_img, 0, 255)
250262
# Rounds the values of a tensor to the nearest integer
251-
super_resolution_img = tf.round(super_resolution_img)
263+
super_resolution_img = ops.round(super_resolution_img)
252264
# Removes dimensions of size 1 from the shape of a tensor and converting to uint8
253-
super_resolution_img = tf.squeeze(
254-
tf.cast(super_resolution_img, tf.uint8), axis=0
265+
super_resolution_img = ops.squeeze(
266+
ops.cast(super_resolution_img, dtype="uint8"), axis=0
255267
)
256268
return super_resolution_img
257269

@@ -267,9 +279,9 @@ def ResBlock(inputs):
267279
# Upsampling Block
268280
def Upsampling(inputs, factor=2, **kwargs):
269281
x = layers.Conv2D(64 * (factor**2), 3, padding="same", **kwargs)(inputs)
270-
x = tf.nn.depth_to_space(x, block_size=factor)
282+
x = layers.Lambda(lambda x: tf.nn.depth_to_space(x, block_size=factor))(x)
271283
x = layers.Conv2D(64 * (factor**2), 3, padding="same", **kwargs)(x)
272-
x = tf.nn.depth_to_space(x, block_size=factor)
284+
x = layers.Lambda(lambda x: tf.nn.depth_to_space(x, block_size=factor))(x)
273285
return x
274286

275287

546 KB
Loading
-268 KB
Loading
56.5 KB
Loading
116 KB
Loading
46.5 KB
Loading
230 KB
Loading

examples/vision/ipynb/edsr.ipynb

Lines changed: 51 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
"\n",
1111
"**Author:** Gitesh Chawda<br>\n",
1212
"**Date created:** 2022/04/07<br>\n",
13-
"**Last modified:** 2022/04/07<br>\n",
13+
"**Last modified:** 2024/08/27<br>\n",
1414
"**Description:** Training an EDSR model on the DIV2K Dataset."
1515
]
1616
},
@@ -39,7 +39,7 @@
3939
"you can do super-resolution using an ESPCN Model. According to the survey paper, EDSR is one of the top-five\n",
4040
"best-performing super-resolution methods based on PSNR scores. However, it has more\n",
4141
"parameters and requires more computational power than other approaches.\n",
42-
"It has a PSNR value (≈34db) that is slightly higher than ESPCN (≈32db).\n",
42+
"It has a PSNR value (\u224834db) that is slightly higher than ESPCN (\u224832db).\n",
4343
"As per the survey paper, EDSR performs better than ESPCN.\n",
4444
"\n",
4545
"Paper:\n",
@@ -60,19 +60,24 @@
6060
},
6161
{
6262
"cell_type": "code",
63-
"execution_count": null,
63+
"execution_count": 0,
6464
"metadata": {
6565
"colab_type": "code"
6666
},
6767
"outputs": [],
6868
"source": [
69+
"import os\n",
70+
"\n",
71+
"os.environ[\"KERAS_BACKEND\"] = \"tensorflow\"\n",
72+
"\n",
6973
"import numpy as np\n",
7074
"import tensorflow as tf\n",
7175
"import tensorflow_datasets as tfds\n",
7276
"import matplotlib.pyplot as plt\n",
7377
"\n",
74-
"from tensorflow import keras\n",
75-
"from tensorflow.keras import layers\n",
78+
"import keras\n",
79+
"from keras import layers\n",
80+
"from keras import ops\n",
7681
"\n",
7782
"AUTOTUNE = tf.data.AUTOTUNE"
7883
]
@@ -93,7 +98,7 @@
9398
},
9499
{
95100
"cell_type": "code",
96-
"execution_count": null,
101+
"execution_count": 0,
97102
"metadata": {
98103
"colab_type": "code"
99104
},
@@ -123,7 +128,7 @@
123128
},
124129
{
125130
"cell_type": "code",
126-
"execution_count": null,
131+
"execution_count": 0,
127132
"metadata": {
128133
"colab_type": "code"
129134
},
@@ -134,15 +139,15 @@
134139
" \"\"\"Flips Images to left and right.\"\"\"\n",
135140
"\n",
136141
" # Outputs random values from a uniform distribution in between 0 to 1\n",
137-
" rn = tf.random.uniform(shape=(), maxval=1)\n",
142+
" rn = keras.random.uniform(shape=(), maxval=1)\n",
138143
" # If rn is less than 0.5 it returns original lowres_img and highres_img\n",
139144
" # If rn is greater than 0.5 it returns flipped image\n",
140-
" return tf.cond(\n",
145+
" return ops.cond(\n",
141146
" rn < 0.5,\n",
142147
" lambda: (lowres_img, highres_img),\n",
143148
" lambda: (\n",
144-
" tf.image.flip_left_right(lowres_img),\n",
145-
" tf.image.flip_left_right(highres_img),\n",
149+
" ops.flip(lowres_img),\n",
150+
" ops.flip(highres_img),\n",
146151
" ),\n",
147152
" )\n",
148153
"\n",
@@ -151,7 +156,9 @@
151156
" \"\"\"Rotates Images by 90 degrees.\"\"\"\n",
152157
"\n",
153158
" # Outputs random values from uniform distribution in between 0 to 4\n",
154-
" rn = tf.random.uniform(shape=(), maxval=4, dtype=tf.int32)\n",
159+
" rn = ops.cast(\n",
160+
" keras.random.uniform(shape=(), maxval=4, dtype=\"float32\"), dtype=\"int32\"\n",
161+
" )\n",
155162
" # Here rn signifies number of times the image(s) are rotated by 90 degrees\n",
156163
" return tf.image.rot90(lowres_img, rn), tf.image.rot90(highres_img, rn)\n",
157164
"\n",
@@ -163,13 +170,19 @@
163170
" high resolution images: 96x96\n",
164171
" \"\"\"\n",
165172
" lowres_crop_size = hr_crop_size // scale # 96//4=24\n",
166-
" lowres_img_shape = tf.shape(lowres_img)[:2] # (height,width)\n",
173+
" lowres_img_shape = ops.shape(lowres_img)[:2] # (height,width)\n",
167174
"\n",
168-
" lowres_width = tf.random.uniform(\n",
169-
" shape=(), maxval=lowres_img_shape[1] - lowres_crop_size + 1, dtype=tf.int32\n",
175+
" lowres_width = ops.cast(\n",
176+
" keras.random.uniform(\n",
177+
" shape=(), maxval=lowres_img_shape[1] - lowres_crop_size + 1, dtype=\"float32\"\n",
178+
" ),\n",
179+
" dtype=\"int32\",\n",
170180
" )\n",
171-
" lowres_height = tf.random.uniform(\n",
172-
" shape=(), maxval=lowres_img_shape[0] - lowres_crop_size + 1, dtype=tf.int32\n",
181+
" lowres_height = ops.cast(\n",
182+
" keras.random.uniform(\n",
183+
" shape=(), maxval=lowres_img_shape[0] - lowres_crop_size + 1, dtype=\"float32\"\n",
184+
" ),\n",
185+
" dtype=\"int32\",\n",
173186
" )\n",
174187
"\n",
175188
" highres_width = lowres_width * scale\n",
@@ -184,7 +197,8 @@
184197
" highres_width : highres_width + hr_crop_size,\n",
185198
" ] # 96x96\n",
186199
"\n",
187-
" return lowres_img_cropped, highres_img_cropped\n"
200+
" return lowres_img_cropped, highres_img_cropped\n",
201+
""
188202
]
189203
},
190204
{
@@ -202,15 +216,14 @@
202216
},
203217
{
204218
"cell_type": "code",
205-
"execution_count": null,
219+
"execution_count": 0,
206220
"metadata": {
207221
"colab_type": "code"
208222
},
209223
"outputs": [],
210224
"source": [
211225
"\n",
212226
"def dataset_object(dataset_cache, training=True):\n",
213-
"\n",
214227
" ds = dataset_cache\n",
215228
" ds = ds.map(\n",
216229
" lambda lowres, highres: random_crop(lowres, highres, scale=4),\n",
@@ -248,7 +261,7 @@
248261
},
249262
{
250263
"cell_type": "code",
251-
"execution_count": null,
264+
"execution_count": 0,
252265
"metadata": {
253266
"colab_type": "code"
254267
},
@@ -277,7 +290,8 @@
277290
" \"\"\"Compute the peak signal-to-noise ratio, measures quality of image.\"\"\"\n",
278291
" # Max value of pixel is 255\n",
279292
" psnr_value = tf.image.psnr(high_resolution, super_resolution, max_val=255)[0]\n",
280-
" return psnr_value\n"
293+
" return psnr_value\n",
294+
""
281295
]
282296
},
283297
{
@@ -305,14 +319,14 @@
305319
},
306320
{
307321
"cell_type": "code",
308-
"execution_count": null,
322+
"execution_count": 0,
309323
"metadata": {
310324
"colab_type": "code"
311325
},
312326
"outputs": [],
313327
"source": [
314328
"\n",
315-
"class EDSRModel(tf.keras.Model):\n",
329+
"class EDSRModel(keras.Model):\n",
316330
" def train_step(self, data):\n",
317331
" # Unpack the data. Its structure depends on your model and\n",
318332
" # on what you pass to `fit()`.\n",
@@ -336,16 +350,16 @@
336350
"\n",
337351
" def predict_step(self, x):\n",
338352
" # Adding dummy dimension using tf.expand_dims and converting to float32 using tf.cast\n",
339-
" x = tf.cast(tf.expand_dims(x, axis=0), tf.float32)\n",
353+
" x = ops.cast(tf.expand_dims(x, axis=0), dtype=\"float32\")\n",
340354
" # Passing low resolution image to model\n",
341355
" super_resolution_img = self(x, training=False)\n",
342356
" # Clips the tensor from min(0) to max(255)\n",
343-
" super_resolution_img = tf.clip_by_value(super_resolution_img, 0, 255)\n",
357+
" super_resolution_img = ops.clip(super_resolution_img, 0, 255)\n",
344358
" # Rounds the values of a tensor to the nearest integer\n",
345-
" super_resolution_img = tf.round(super_resolution_img)\n",
359+
" super_resolution_img = ops.round(super_resolution_img)\n",
346360
" # Removes dimensions of size 1 from the shape of a tensor and converting to uint8\n",
347-
" super_resolution_img = tf.squeeze(\n",
348-
" tf.cast(super_resolution_img, tf.uint8), axis=0\n",
361+
" super_resolution_img = ops.squeeze(\n",
362+
" ops.cast(super_resolution_img, dtype=\"uint8\"), axis=0\n",
349363
" )\n",
350364
" return super_resolution_img\n",
351365
"\n",
@@ -360,10 +374,10 @@
360374
"\n",
361375
"# Upsampling Block\n",
362376
"def Upsampling(inputs, factor=2, **kwargs):\n",
363-
" x = layers.Conv2D(64 * (factor ** 2), 3, padding=\"same\", **kwargs)(inputs)\n",
364-
" x = tf.nn.depth_to_space(x, block_size=factor)\n",
365-
" x = layers.Conv2D(64 * (factor ** 2), 3, padding=\"same\", **kwargs)(x)\n",
366-
" x = tf.nn.depth_to_space(x, block_size=factor)\n",
377+
" x = layers.Conv2D(64 * (factor**2), 3, padding=\"same\", **kwargs)(inputs)\n",
378+
" x = layers.Lambda(lambda x: tf.nn.depth_to_space(x, block_size=factor))(x)\n",
379+
" x = layers.Conv2D(64 * (factor**2), 3, padding=\"same\", **kwargs)(x)\n",
380+
" x = layers.Lambda(lambda x: tf.nn.depth_to_space(x, block_size=factor))(x)\n",
367381
" return x\n",
368382
"\n",
369383
"\n",
@@ -402,7 +416,7 @@
402416
},
403417
{
404418
"cell_type": "code",
405-
"execution_count": null,
419+
"execution_count": 0,
406420
"metadata": {
407421
"colab_type": "code"
408422
},
@@ -431,7 +445,7 @@
431445
},
432446
{
433447
"cell_type": "code",
434-
"execution_count": null,
448+
"execution_count": 0,
435449
"metadata": {
436450
"colab_type": "code"
437451
},
@@ -473,7 +487,7 @@
473487
"\n",
474488
"| Trained Model | Demo |\n",
475489
"| :--: | :--: |\n",
476-
"| [![Generic badge](https://img.shields.io/badge/🤗%20Model-EDSR-red.svg)](https://huggingface.co/keras-io/EDSR) | [![Generic badge](https://img.shields.io/badge/🤗%20Spaces-EDSR-red.svg)](https://huggingface.co/spaces/keras-io/EDSR) |"
490+
"| [![Generic badge](https://img.shields.io/badge/\ud83e\udd17%20Model-EDSR-red.svg)](https://huggingface.co/keras-io/EDSR) | [![Generic badge](https://img.shields.io/badge/\ud83e\udd17%20Spaces-EDSR-red.svg)](https://huggingface.co/spaces/keras-io/EDSR) |"
477491
]
478492
}
479493
],
@@ -506,4 +520,4 @@
506520
},
507521
"nbformat": 4,
508522
"nbformat_minor": 0
509-
}
523+
}

0 commit comments

Comments
 (0)