Skip to content

Commit db15804

Browse files
committed
cleanup new function
1 parent 8fe7448 commit db15804

File tree

1 file changed

+19
-25
lines changed

1 file changed

+19
-25
lines changed

dataset_tool.py

+19-25
Original file line numberDiff line numberDiff line change
@@ -250,8 +250,8 @@ def center_crop_wide(width, height, img):
250250
canvas[(width - height) // 2 : (width + height) // 2, :] = img
251251
return canvas
252252

253-
def resize_pad(width, height, img):
254-
# fix dims
253+
def _to_rgb(img):
254+
# add channel dim
255255
if img.ndim == 2:
256256
img = img[:, :, None]
257257
assert img.ndim == 3, f'input image has incorrect number of dimensions, required 2 (H, W) or 3 (H, W, C), got: {img.shape}'
@@ -261,34 +261,28 @@ def resize_pad(width, height, img):
261261
elif img.shape[-1] == 4:
262262
img = img[:, :, :3]
263263
assert img.shape[-1] == 3, f'input image must have 1 or 3 channels, got: {img.shape}'
264+
return img
265+
266+
def resize_pad(width, height, img):
267+
img = _to_rgb(img)
264268
# exit early
265-
h, w, c = img.shape
266-
if width == w and height == h:
269+
img_h, img_w = img.shape[:2]
270+
if width == img_w and height == img_h:
267271
return img
268-
# get scale size
269-
rh = height / h
270-
rw = width / w
271-
sh = int(round(h * min(rh, rw), 5)) # avoid precision errors
272-
sw = int(round(w * min(rh, rw), 5)) # avoid precision errors
273-
assert sh <= height and sw <= height, f'scaled shape {nw}x{nh} is not smaller than or equal to the required shape: {width}x{height} this is a bug:'
272+
# get scale size, avoiding precision errors
273+
scale_ratio = max(img_h / height, img_w / width)
274+
scale_h = int(round(img_h / scale_ratio, 5))
275+
scale_w = int(round(img_w / scale_ratio, 5))
276+
assert scale_h <= height and scale_w <= width, f'scaled shape {scale_w}x{scale_h} is not smaller than or equal to the required shape: {width}x{height} this is a bug:'
274277
# scale image
275-
img = scale(sw, sh, img)
276-
nh, nw, nc = img.shape
277-
assert nh == sh and nw == sw, f'scaled shape {nw}x{nh} does not match required scaled shape: {sw}x{sh} this is a bug!'
278+
img = scale(scale_w, scale_h, img)
278279
# pad the image if needed
279-
ph = height - sh
280-
pw = width - sw
281-
if ph != 0 or pw != 0:
282-
assert ph >= 0 and pw >= 0, f'target width={repr(width)} height={repr(width)}, pad amount: pw={repr(pw)} ph={repr(ph)}'
283-
img = np.pad(img, (
284-
(int(np.floor(ph/2)), int(np.ceil(ph/2))),
285-
(int(np.floor(pw/2)), int(np.ceil(pw/2))),
286-
(0, 0),
287-
))
280+
pad_h, pad_w = height - scale_h, width - scale_w
281+
if pad_h != 0 or pad_w != 0:
282+
pad_dims = [[np.floor(pad_h/2), np.ceil(pad_h/2)], [np.floor(pad_w/2), np.ceil(pad_w/2)], [0, 0]] # (H,W,C)
283+
img = np.pad(img, np.array(pad_dims).astype('int'))
288284
# check the shape
289-
oh, ow, oc = img.shape
290-
# print((h, w, c), h/w, '->', (nh, nw, nc), nh/nw, '->', (oh, ow, oc), oh/ow)
291-
assert oh == height and ow == width, f'output shape {ow}x{oh} does not match required shape: {width}x{height} this is a bug!'
285+
assert img.shape[0] == height and img.shape[1] == width, f'output shape {img.shape[1]}x{img.shape[0]} does not match required shape: {width}x{height} this is a bug!'
292286
return img
293287

294288
if transform is None:

0 commit comments

Comments
 (0)