Skip to content

Commit a65015f

Browse files
committed
Fixed the other branches of the crop_resize function to support various image formats and improve bounding-box handling
1 parent 57658b3 commit a65015f

File tree

1 file changed

+63
-27
lines changed

1 file changed

+63
-27
lines changed

models/utils/detect_face.py

Lines changed: 63 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -307,36 +307,72 @@ def imresample(img, sz):
307307

308308

309309
def crop_resize(img, box, image_size):
    """Crop a square window around ``box`` and resize it to ``image_size``.

    The window is the smallest square covering the (possibly rectangular)
    box, centred on the box centre, then shifted — and, if necessary,
    shrunk — to lie fully inside the image, so the crop is always square
    and no out-of-bounds clamping/padding can distort the aspect ratio.

    Args:
        img: numpy HWC array, torch HWC or CHW tensor, or PIL Image.
        box: (x1, y1, x2, y2) in pixel coords; x2/y2 exclusive-style is
            fine too (we resize anyway).
        image_size: side length of the square output.

    Returns:
        Same family of type as ``img``: numpy array (via cv2, INTER_AREA),
        torch uint8 tensor in the input's layout (area interpolation), or
        PIL Image (bilinear) — image_size x image_size.

    Raises:
        ValueError: if a torch tensor input is not 3-dimensional.
    """
    x1, y1, x2, y2 = map(int, box)
    w = max(1, x2 - x1)
    h = max(1, y2 - y1)
    s = max(w, h)
    cx = x1 + w / 2.0
    cy = y1 + h / 2.0

    # Determine image extents (and, for torch, the memory layout).
    chw = False
    if isinstance(img, np.ndarray):
        H, W = img.shape[:2]
    elif isinstance(img, torch.Tensor):
        # accept HWC or CHW
        if img.ndim != 3:
            raise ValueError("torch img must be 3D (HWC or CHW)")
        if img.shape[0] in (1, 3, 4) and img.shape[2] not in (1, 3, 4):
            C, H, W = img.shape  # CHW
            chw = True
        else:
            H, W, C = img.shape  # HWC
    else:
        # PIL
        W, H = img.size

    # BUGFIX: if the square window is larger than the image in either
    # dimension, shifting alone cannot keep it inside; slicing would then
    # silently clamp and yield a NON-square crop (aspect distortion after
    # resize), and PIL would pad with black. Shrink the window to fit.
    s = max(1, min(s, W, H))

    # square window [x0, x0+s), [y0, y0+s), shifted to stay inside image
    x0 = int(round(cx - s / 2.0))
    y0 = int(round(cy - s / 2.0))
    x0 = min(max(0, x0), W - s)
    y0 = min(max(0, y0), H - s)
    x1n, y1n = x0 + s, y0 + s

    if isinstance(img, np.ndarray):
        crop = img[y0:y1n, x0:x1n]
        return cv2.resize(
            crop, (image_size, image_size), interpolation=cv2.INTER_AREA
        ).copy()

    if isinstance(img, torch.Tensor):
        crop = img[:, y0:y1n, x0:x1n] if chw else img[y0:y1n, x0:x1n, :]

        # Area interpolation is the closest analogue of cv2.INTER_AREA;
        # interpolate requires float input, so round-trip through float.
        import torch.nn.functional as F
        crop_f = (crop if chw else crop.permute(2, 0, 1)).unsqueeze(0).float()
        out = F.interpolate(
            crop_f, size=(image_size, image_size), mode="area"
        ).squeeze(0)
        if not chw:
            out = out.permute(1, 2, 0)
        return out.byte()

    # PIL
    return img.crop((x0, y0, x1n, y1n)).resize(
        (image_size, image_size), Image.BILINEAR
    )
340376

341377

342378
def save_img(img, path):

0 commit comments

Comments
 (0)