@@ -307,36 +307,72 @@ def imresample(img, sz):
307307
308308
def _square_window(box, width, height):
    """Return ``(x0, y0, side)`` of a square crop window centred on ``box``.

    The window side is the longer side of ``box``, capped so that it never
    exceeds the image; the window is then shifted so that it lies entirely
    inside ``[0, width) x [0, height)``.
    """
    x1, y1, x2, y2 = map(int, box)
    # Degenerate or inverted boxes still yield at least a 1-pixel extent.
    box_w = max(1, x2 - x1)
    box_h = max(1, y2 - y1)

    # Cap the side at the image size: otherwise slicing would silently
    # truncate the window and the crop would no longer be square.
    side = min(max(box_w, box_h), width, height)

    # Integer centring; avoids round()'s half-to-even behaviour, which made
    # the offset depend on the parity of the box coordinates.
    x0 = x1 - (side - box_w) // 2
    y0 = y1 - (side - box_h) // 2

    # Shift (never shrink) the window back inside the image.
    x0 = min(max(0, x0), width - side)
    y0 = min(max(0, y0), height - side)
    return x0, y0, side


def crop_resize(img, box, image_size):
    """Crop a square region around ``box`` and resize it to ``image_size``.

    Args:
        img: numpy array (HW or HWC), torch tensor (3D, HWC or CHW) or a
            PIL Image. A tensor is treated as CHW when its first dimension
            is 1, 3 or 4 and its last dimension is not; otherwise as HWC.
        box: ``(x1, y1, x2, y2)`` in pixel coordinates; values are
            truncated to int.
        image_size: side length in pixels of the square output.

    Returns:
        The resized crop, of the same kind as ``img`` (tensors keep their
        HWC/CHW layout and are returned as uint8).

    Raises:
        ValueError: if a tensor ``img`` is not 3D, or if the image has a
            zero height or width.
    """
    is_array = isinstance(img, np.ndarray)
    is_tensor = not is_array and isinstance(img, torch.Tensor)
    chw = False

    if is_array:
        height, width = img.shape[:2]
    elif is_tensor:
        if img.ndim != 3:
            raise ValueError("torch img must be 3D (HWC or CHW)")
        # Layout heuristic: a channel-like leading dim and a non-channel-like
        # trailing dim means CHW; anything else is read as HWC.
        chw = img.shape[0] in (1, 3, 4) and img.shape[2] not in (1, 3, 4)
        height, width = img.shape[1:] if chw else img.shape[:2]
    else:
        # PIL reports (width, height).
        width, height = img.size

    if height < 1 or width < 1:
        raise ValueError("img must have non-zero height and width")

    x0, y0, side = _square_window(box, width, height)
    x_end, y_end = x0 + side, y0 + side
    out_size = (image_size, image_size)

    if is_array:
        crop = img[y0:y_end, x0:x_end]
        return cv2.resize(crop, out_size, interpolation=cv2.INTER_AREA).copy()

    if is_tensor:
        if chw:
            crop = img[:, y0:y_end, x0:x_end]
        else:
            # interpolate() expects channels first.
            crop = img[y0:y_end, x0:x_end, :].permute(2, 0, 1)
        resized = torch.nn.functional.interpolate(
            crop.unsqueeze(0).float(), size=out_size, mode="area"
        ).squeeze(0)
        if not chw:
            resized = resized.permute(1, 2, 0)
        # Round and clamp before the uint8 cast: a bare .byte() truncates
        # fractional averages and wraps values outside [0, 255].
        return resized.round().clamp(0, 255).byte()

    # PIL
    return img.crop((x0, y0, x_end, y_end)).resize(out_size, Image.BILINEAR)
340376
341377
342378def save_img (img , path ):
0 commit comments