forked from keras-team/keras-cv
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
197 lines (170 loc) · 6.93 KB
/
utils.py
File metadata and controls
197 lines (170 loc) · 6.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions for working with bounding boxes."""
from keras_cv.src import bounding_box
from keras_cv.src.api_export import keras_cv_export
from keras_cv.src.backend import ops
from keras_cv.src.bounding_box.formats import XYWH
@keras_cv_export("keras_cv.bounding_box.is_relative")
def is_relative(bounding_box_format):
    """A util to check if a bounding box format uses relative coordinates.

    The format name is matched case-insensitively, both when validating it
    and when checking for the `rel` prefix, so e.g. `"REL_XYXY"` is reported
    as relative.

    Args:
        bounding_box_format: name of a bounding box format. Its lower-cased
            form must be a key of `bounding_box.converters.TO_XYXY_CONVERTERS`.

    Returns:
        `True` if the (lower-cased) format name starts with `"rel"`,
        otherwise `False`.

    Raises:
        ValueError: if the format is not a supported bounding box format.
    """
    # Normalize once so validation and the prefix check agree. Previously
    # "REL_XYXY" passed validation but was reported as not relative.
    normalized_format = bounding_box_format.lower()
    if normalized_format not in bounding_box.converters.TO_XYXY_CONVERTERS:
        raise ValueError(
            "`is_relative()` received an unsupported format for the argument "
            f"`bounding_box_format`. `bounding_box_format` should be one of "
            f"{bounding_box.converters.TO_XYXY_CONVERTERS.keys()}. "
            f"Got bounding_box_format={bounding_box_format}"
        )
    return normalized_format.startswith("rel")
@keras_cv_export("keras_cv.bounding_box.as_relative")
def as_relative(bounding_box_format):
    """Return the relative-coordinate counterpart of a bounding box format.

    Formats already reported as relative by `is_relative()` are passed
    through untouched; any other format gets a `"rel_"` prefix.

    Args:
        bounding_box_format: name of a supported bounding box format
            (validated by `is_relative()`, which raises `ValueError` for
            unsupported names).

    Returns:
        The relative format name.
    """
    return (
        bounding_box_format
        if is_relative(bounding_box_format)
        else f"rel_{bounding_box_format}"
    )
def _relative_area(boxes, bounding_box_format):
    """Compute the area of each box in relative (`rel_xywh`) units.

    Args:
        boxes: box tensor in `bounding_box_format`.
        bounding_box_format: format of `boxes`; converted to `"rel_xywh"`
            without `images`/`image_shape`.

    Returns:
        Tensor of per-box areas. Boxes whose width or height is not strictly
        positive get an area of `0.0`.
    """
    rel_xywh_boxes = bounding_box.convert_format(
        boxes, source=bounding_box_format, target="rel_xywh"
    )
    box_widths = rel_xywh_boxes[..., XYWH.WIDTH]
    box_heights = rel_xywh_boxes[..., XYWH.HEIGHT]
    # A shear can fully invert a box, leaving a non-positive width/height;
    # such boxes must not contribute a (possibly positive) product.
    has_positive_extent = ops.logical_and(box_widths > 0, box_heights > 0)
    return ops.where(has_positive_extent, box_widths * box_heights, 0.0)
@keras_cv_export("keras_cv.bounding_box.clip_to_image")
def clip_to_image(
    bounding_boxes, bounding_box_format, images=None, image_shape=None
):
    """Clips bounding boxes to image boundaries.

    `clip_to_image()` clips bounding boxes that have coordinates out of bounds
    of an image down to the boundaries of the image. This is done by converting
    the bounding boxes to the `rel_xyxy` format, then clipping them to the
    `[0, 1]` range. Bounding boxes that end up with a zero area have all their
    coordinates replaced with `-1` and their class ID set to -1, indicating
    that there is no object present in them. Boxes that contain NaN after
    conversion back to `bounding_box_format` also get class ID -1.

    Args:
        bounding_boxes: dictionary with `"boxes"` and `"classes"` entries.
            Any other entries are carried over to the result unchanged.
        bounding_box_format: the KerasCV bounding box format the bounding boxes
            are in.
        images: optional image tensor, batched (rank 4) or unbatched (rank 3)
            to match `"boxes"`. It is forwarded to
            `bounding_box.convert_format` (presumably to resolve the image
            size for relative conversion).
        image_shape: optional shape of the images, forwarded to
            `bounding_box.convert_format`.

    Returns:
        A shallow copy of `bounding_boxes` whose `"boxes"` and `"classes"`
        entries are replaced with the clipped boxes and updated classes.

    Raises:
        ValueError: from `_format_inputs` if the ranks of the boxes/images are
            unsupported or their batchedness does not match.
    """
    boxes, classes = bounding_boxes["boxes"], bounding_boxes["classes"]
    boxes = bounding_box.convert_format(
        boxes,
        source=bounding_box_format,
        target="rel_xyxy",
        images=images,
        image_shape=image_shape,
    )
    boxes, classes, images, squeeze = _format_inputs(boxes, classes, images)
    # In rel_xyxy every coordinate shares the [0, 1] image boundary, so one
    # element-wise clip covers x1, y1, x2 and y2 at once.
    clipped_bounding_boxes = ops.clip(boxes, 0, 1)
    areas = _relative_area(
        clipped_bounding_boxes, bounding_box_format="rel_xyxy"
    )
    clipped_bounding_boxes = bounding_box.convert_format(
        clipped_bounding_boxes,
        source="rel_xyxy",
        target=bounding_box_format,
        images=images,
        image_shape=image_shape,
    )
    # Boxes that collapsed to zero area are marked invalid: coordinates
    # become -1 and the class becomes -1.
    clipped_bounding_boxes = ops.where(
        ops.expand_dims(areas > 0.0, axis=-1), clipped_bounding_boxes, -1.0
    )
    classes = ops.where(areas > 0.0, classes, -1)
    nan_indices = ops.any(ops.isnan(clipped_bounding_boxes), axis=-1)
    classes = ops.where(nan_indices, -1, classes)
    clipped_bounding_boxes, classes = _format_outputs(
        clipped_bounding_boxes, classes, squeeze
    )
    result = bounding_boxes.copy()
    result["boxes"] = clipped_bounding_boxes
    result["classes"] = classes
    return result
# TODO (tanzhenyu): merge with clip_to_image
def _clip_boxes(boxes, box_format, image_shape):
"""Clip boxes to the boundaries of the image shape"""
if boxes.shape[-1] != 4:
raise ValueError(
"boxes.shape[-1] is {:d}, but must be 4.".format(boxes.shape[-1])
)
if isinstance(image_shape, list) or isinstance(image_shape, tuple):
height, width, _ = image_shape
max_length = ops.stack([height, width, height, width], axis=-1)
else:
image_shape = ops.cast(image_shape, dtype=boxes.dtype)
height = image_shape[0]
width = image_shape[1]
max_length = ops.stack([height, width, height, width], axis=-1)
clipped_boxes = ops.maximum(ops.minimum(boxes, max_length), 0.0)
return clipped_boxes
def _format_inputs(boxes, classes, images):
boxes_rank = len(boxes.shape)
if boxes_rank > 3:
raise ValueError(
"Expected len(boxes.shape)=2, or len(boxes.shape)=3, got "
f"len(boxes.shape)={boxes_rank}"
)
boxes_includes_batch = boxes_rank == 3
# Determine if images needs an expand_dims() call
if images is not None:
images_rank = len(images.shape)
if images_rank > 4:
raise ValueError(
"Expected len(images.shape)=2, or len(images.shape)=3, got "
f"len(images.shape)={images_rank}"
)
images_include_batch = images_rank == 4
if boxes_includes_batch != images_include_batch:
raise ValueError(
"clip_to_image() expects both boxes and images to be batched, "
"or both boxes and images to be unbatched. Received "
f"len(boxes.shape)={boxes_rank}, "
f"len(images.shape)={images_rank}. Expected either "
"len(boxes.shape)=2 AND len(images.shape)=3, or "
"len(boxes.shape)=3 AND len(images.shape)=4."
)
if not images_include_batch:
images = ops.expand_dims(images, axis=0)
if not boxes_includes_batch:
return (
ops.expand_dims(boxes, axis=0),
ops.expand_dims(classes, axis=0),
images,
True,
)
return boxes, classes, images, False
def _format_outputs(boxes, classes, squeeze):
if squeeze:
return ops.squeeze(boxes, axis=0), ops.squeeze(classes, axis=0)
return boxes, classes