Skip to content

Commit a101d33

Browse files
committed
feat: reflection_padding layer added with test cases
1 parent 9f64291 commit a101d33

File tree

7 files changed

+241
-69
lines changed

7 files changed

+241
-69
lines changed

.vscode/settings.json

+2-2
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
],
1919
"[python]": {
2020
"editor.codeActionsOnSave": {
21-
"source.organizeImports": true
22-
}
21+
"source.organizeImports": "explicit"
22+
}
2323
},
2424
"python.analysis.diagnosticSeverityOverrides": {
2525
"reportMissingImports": "none"

tf_keras/layers/convolutional/BUILD

+78-67
Original file line numberDiff line numberDiff line change
@@ -215,86 +215,97 @@ py_library(
215215
)
216216

217217
py_library(
218-
name = "depthwise_conv1d",
219-
srcs = ["depthwise_conv1d.py"],
220-
srcs_version = "PY3",
221-
deps = [
222-
":base_depthwise_conv",
223-
"//:expect_tensorflow_installed",
224-
"//tf_keras/utils:engine_utils",
225-
"//tf_keras/utils:tf_utils",
226-
],
218+
name = "depthwise_conv1d",
219+
srcs = ["depthwise_conv1d.py"],
220+
srcs_version = "PY3",
221+
deps = [
222+
":base_depthwise_conv",
223+
"//:expect_tensorflow_installed",
224+
"//tf_keras/utils:engine_utils",
225+
"//tf_keras/utils:tf_utils",
226+
],
227227
)
228228

229229
py_library(
230-
name = "depthwise_conv2d",
231-
srcs = ["depthwise_conv2d.py"],
232-
srcs_version = "PY3",
233-
deps = [
234-
":base_depthwise_conv",
235-
"//tf_keras:backend",
236-
"//tf_keras/utils:engine_utils",
237-
"//tf_keras/utils:tf_utils",
238-
],
230+
name = "depthwise_conv2d",
231+
srcs = ["depthwise_conv2d.py"],
232+
srcs_version = "PY3",
233+
deps = [
234+
":base_depthwise_conv",
235+
"//tf_keras:backend",
236+
"//tf_keras/utils:engine_utils",
237+
"//tf_keras/utils:tf_utils",
238+
],
239+
)
240+
241+
py_library(
242+
name = "reflection_padding",
243+
srcs = ["reflection_padding.py"], # Adjust this to your actual file name
244+
srcs_version = "PY3",
245+
deps = [
246+
"//:expect_tensorflow_installed", # Assuming reflection_padding.py depends on TensorFlow
247+
"//tf_keras/utils:engine_utils",
248+
"//tf_keras/utils:tf_utils",
249+
],
239250
)
240251

241252
cuda_py_test(
242-
name = "conv_test",
243-
size = "medium",
244-
srcs = ["conv_test.py"],
245-
python_version = "PY3",
246-
shard_count = 8,
247-
deps = [
248-
"//:expect_absl_installed", # absl/testing:parameterized
249-
"//:expect_numpy_installed",
250-
"//:expect_tensorflow_installed",
251-
"//tf_keras",
252-
"//tf_keras/testing_infra:test_combinations",
253-
"//tf_keras/testing_infra:test_utils",
254-
],
253+
name = "conv_test",
254+
size = "medium",
255+
srcs = ["conv_test.py"],
256+
python_version = "PY3",
257+
shard_count = 8,
258+
deps = [
259+
"//:expect_absl_installed", # absl/testing:parameterized
260+
"//:expect_numpy_installed",
261+
"//:expect_tensorflow_installed",
262+
"//tf_keras",
263+
"//tf_keras/testing_infra:test_combinations",
264+
"//tf_keras/testing_infra:test_utils",
265+
],
255266
)
256267

257268
cuda_py_test(
258-
name = "conv_transpose_test",
259-
size = "medium",
260-
srcs = ["conv_transpose_test.py"],
261-
python_version = "PY3",
262-
deps = [
263-
"//:expect_absl_installed", # absl/testing:parameterized
264-
"//:expect_numpy_installed",
265-
"//:expect_tensorflow_installed",
266-
"//tf_keras",
267-
"//tf_keras/testing_infra:test_combinations",
268-
"//tf_keras/testing_infra:test_utils",
269-
],
269+
name = "conv_transpose_test",
270+
size = "medium",
271+
srcs = ["conv_transpose_test.py"],
272+
python_version = "PY3",
273+
deps = [
274+
"//:expect_absl_installed", # absl/testing:parameterized
275+
"//:expect_numpy_installed",
276+
"//:expect_tensorflow_installed",
277+
"//tf_keras",
278+
"//tf_keras/testing_infra:test_combinations",
279+
"//tf_keras/testing_infra:test_utils",
280+
],
270281
)
271282

272283
cuda_py_test(
273-
name = "depthwise_conv_test",
274-
size = "medium",
275-
srcs = ["depthwise_conv_test.py"],
276-
python_version = "PY3",
277-
shard_count = 8,
278-
deps = [
279-
"//:expect_absl_installed", # absl/testing:parameterized
280-
"//:expect_tensorflow_installed",
281-
"//tf_keras",
282-
"//tf_keras/testing_infra:test_combinations",
283-
"//tf_keras/testing_infra:test_utils",
284-
],
284+
name = "depthwise_conv_test",
285+
size = "medium",
286+
srcs = ["depthwise_conv_test.py"],
287+
python_version = "PY3",
288+
shard_count = 8,
289+
deps = [
290+
"//:expect_absl_installed", # absl/testing:parameterized
291+
"//:expect_tensorflow_installed",
292+
"//tf_keras",
293+
"//tf_keras/testing_infra:test_combinations",
294+
"//tf_keras/testing_infra:test_utils",
295+
],
285296
)
286297

287298
cuda_py_test(
288-
name = "separable_conv_test",
289-
size = "medium",
290-
srcs = ["separable_conv_test.py"],
291-
python_version = "PY3",
292-
deps = [
293-
"//:expect_absl_installed", # absl/testing:parameterized
294-
"//:expect_numpy_installed",
295-
"//:expect_tensorflow_installed",
296-
"//tf_keras",
297-
"//tf_keras/testing_infra:test_combinations",
298-
"//tf_keras/testing_infra:test_utils",
299-
],
299+
name = "separable_conv_test",
300+
size = "medium",
301+
srcs = ["separable_conv_test.py"],
302+
python_version = "PY3",
303+
deps = [
304+
"//:expect_absl_installed", # absl/testing:parameterized
305+
"//:expect_numpy_installed",
306+
"//:expect_tensorflow_installed",
307+
"//tf_keras",
308+
"//tf_keras/testing_infra:test_combinations",
309+
"//tf_keras/testing_infra:test_utils",
310+
],
300311
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import tensorflow as tf
2+
from tensorflow.keras.layers import Layer
3+
4+
5+
class BaseReflectionPadding(Layer):
6+
"""Abstract N-D reflection padding layer.
7+
8+
This layer performs reflection padding on the input tensor.
9+
10+
Args:
11+
padding: int, or tuple/list of n ints, for n > 1.
12+
If int: the same symmetric padding is applied to all spatial dimensions.
13+
If tuple/list of n ints: interpreted as n different symmetric padding values
14+
for each spatial dimension.
15+
No padding is applied to the batch or channel dimensions.
16+
17+
Raises:
18+
ValueError: If `padding` is negative or not of length 2 or more.
19+
20+
"""
21+
22+
def __init__(self, padding=(1, 1), **kwargs):
23+
super(BaseReflectionPadding, self).__init__(**kwargs)
24+
if isinstance(padding, int):
25+
self.padding = (padding, padding)
26+
elif isinstance(padding, tuple) or isinstance(padding, list):
27+
if len(padding) != self.rank:
28+
raise ValueError(
29+
f"If passing a tuple or list as padding, it must be of length {self.rank}. Received length: {len(padding)}"
30+
)
31+
self.padding = padding
32+
else:
33+
raise ValueError(
34+
f"Unsupported padding type. Expected int, tuple, or list. Received: {type(padding)}"
35+
)
36+
37+
for pad in self.padding:
38+
if pad < 0:
39+
raise ValueError("Padding cannot be negative.")
40+
41+
def compute_output_shape(self, input_shape):
42+
output_shape = list(input_shape)
43+
for i in range(1, self.rank + 1):
44+
output_shape[i] += 2 * self.padding[i - 1]
45+
return tuple(output_shape)
46+
47+
def call(self, inputs):
48+
padding_dims = [[0, 0]]
49+
for pad in self.padding:
50+
padding_dims.append([pad, pad])
51+
for _ in range(self.rank - len(self.padding)):
52+
padding_dims.append([0, 0])
53+
return tf.pad(inputs, padding_dims, mode='REFLECT')
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from base_reflection_padding import BaseReflectionPadding
2+
3+
4+
class ReflectionPadding1D(BaseReflectionPadding):
5+
"""1D reflection padding layer.
6+
7+
This layer performs reflection padding on a 1D input tensor.
8+
Inherits from BaseReflectionPadding.
9+
10+
Args:
11+
padding: int, or tuple/list of 1 int, specifying the padding width.
12+
If int: the same symmetric padding is applied to both sides.
13+
If tuple/list of 1 int: interpreted as two different symmetric padding values.
14+
"""
15+
16+
rank = 1
17+
18+
def __init__(self, padding=1, **kwargs):
19+
super(ReflectionPadding1D, self).__init__(
20+
padding=(padding,), **kwargs)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from base_reflection_padding import BaseReflectionPadding
2+
3+
4+
class ReflectionPadding2D(BaseReflectionPadding):
5+
"""2D reflection padding layer.
6+
7+
This layer performs reflection padding on a 2D input tensor.
8+
Inherits from BaseReflectionPadding.
9+
10+
Args:
11+
padding: int, or tuple/list of 2 ints, for height and width respectively.
12+
If int: the same symmetric padding is applied to both dimensions.
13+
If tuple/list of 2 ints: interpreted as two different symmetric padding values.
14+
"""
15+
16+
rank = 2
17+
18+
def __init__(self, padding=(1, 1), **kwargs):
19+
if isinstance(padding, int):
20+
padding = (padding, padding)
21+
super(ReflectionPadding2D, self).__init__(padding=padding, **kwargs)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from base_reflection_padding import BaseReflectionPadding
2+
3+
4+
class ReflectionPadding3D(BaseReflectionPadding):
5+
"""3D reflection padding layer.
6+
7+
This layer performs reflection padding on a 3D input tensor.
8+
Inherits from BaseReflectionPadding.
9+
10+
Args:
11+
padding: int, or tuple/list of 3 ints, for height, width, and depth respectively.
12+
If int: the same symmetric padding is applied to all dimensions.
13+
If tuple/list of 3 ints: interpreted as three different symmetric padding values.
14+
"""
15+
16+
rank = 3
17+
18+
def __init__(self, padding=(1, 1, 1), **kwargs):
19+
if isinstance(padding, int) and padding > 0:
20+
padding = (padding, padding, padding)
21+
super(ReflectionPadding3D, self).__init__(padding=padding, **kwargs)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import numpy as np
2+
import tensorflow as tf
3+
from absl.testing import parameterized
4+
from reflection_padding1d import ReflectionPadding1D
5+
from reflection_padding2d import ReflectionPadding2D
6+
from reflection_padding3d import ReflectionPadding3D
7+
8+
from tf_keras.testing_infra import test_combinations
9+
from tf_keras.testing_infra import test_utils
10+
11+
12+
@test_combinations.run_all_keras_modes
13+
class ReflectionPaddingTest(test_combinations.TestCase):
14+
def _run_test(self, padding_layer_cls, kwargs, input_shape, expected_output_shape):
15+
with self.cached_session():
16+
test_utils.layer_test(
17+
padding_layer_cls,
18+
kwargs=kwargs,
19+
input_shape=input_shape,
20+
expected_output_shape=expected_output_shape,
21+
)
22+
23+
@parameterized.named_parameters(
24+
("ReflectionPadding1D", ReflectionPadding1D,
25+
{"padding": 2}, (None, 5, 3), (None, 9, 3)),
26+
("ReflectionPadding2D", ReflectionPadding2D, {
27+
"padding": (2, 1)}, (None, 5, 6, 3), (None, 9, 8, 3)),
28+
("ReflectionPadding3D", ReflectionPadding3D, {
29+
"padding": (1, 2, 3)}, (None, 5, 6, 7, 3), (None, 7, 10, 13, 3)),
30+
)
31+
def test_reflection_padding(self, padding_layer_cls, kwargs, input_shape, expected_output_shape):
32+
self._run_test(padding_layer_cls, kwargs,
33+
input_shape, expected_output_shape)
34+
35+
def test_reflection_padding_dynamic_shape(self):
36+
with self.cached_session():
37+
layer = ReflectionPadding2D(padding=(2, 2))
38+
input_shape = (None, None, None, 3)
39+
inputs = tf.keras.Input(shape=input_shape)
40+
x = layer(inputs)
41+
# Won't raise error here with None values in input shape.
42+
layer(x)
43+
44+
45+
if __name__ == "__main__":
46+
tf.test.main()

0 commit comments

Comments
 (0)