diff --git a/realesrgan/utils.py b/realesrgan/utils.py
index 67e5232d61..dbb9a88078 100644
--- a/realesrgan/utils.py
+++ b/realesrgan/utils.py
@@ -1,9 +1,11 @@
-import cv2
 import math
-import numpy as np
 import os
 import queue
 import threading
+from typing import List, Union
+
+import cv2
+import numpy as np
 import torch
 from basicsr.utils.download_util import load_file_from_url
 from torch.nn import functional as F
@@ -88,14 +90,14 @@ def dni(self, net_a, net_b, dni_weight, key='params', loc='cpu'):
 
     def pre_process(self, img):
         """Pre-process, such as pre-pad and mod pad, so that the images can be divisible
         """
-        img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
-        self.img = img.unsqueeze(0).to(self.device)
+        img = [torch.from_numpy(np.transpose(i, (2, 0, 1))).float() for i in img]
+        self.img = [i.unsqueeze(0).to(self.device) for i in img]
         if self.half:
-            self.img = self.img.half()
+            self.img = [i.half() for i in self.img]
         # pre_pad
         if self.pre_pad != 0:
-            self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
+            self.img = [F.pad(i, (0, self.pre_pad, 0, self.pre_pad), 'reflect') for i in self.img]
         # mod pad for divisible borders
         if self.scale == 2:
             self.mod_scale = 2
@@ -103,15 +105,22 @@ def pre_process(self, img):
             self.mod_scale = 4
         if self.mod_scale is not None:
             self.mod_pad_h, self.mod_pad_w = 0, 0
-            _, _, h, w = self.img.size()
-            if (h % self.mod_scale != 0):
-                self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
-            if (w % self.mod_scale != 0):
-                self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
-            self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')
+            padded_imgs = []
+            for im in self.img:
+                _, _, h, w = im.size()
+                if (h % self.mod_scale != 0):
+                    self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
+                if (w % self.mod_scale != 0):
+                    self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
+                im = F.pad(im, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')
+                padded_imgs.append(im)
+            self.img = padded_imgs
+        # merge the per-image tensors into one batched tensor here, so that both
+        # process() and tile_process() see a single (b, c, h, w) tensor; torch.cat
+        # requires all images in a batch to share the same spatial size
+        self.img = torch.cat(self.img, dim=0)
 
     def process(self):
         # model inference
         self.output = self.model(self.img)
 
     def tile_process(self):
@@ -176,8 +185,8 @@ def tile_process(self):
 
                 # put tile into output image
                 self.output[:, :, output_start_y:output_end_y,
-                            output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
-                                                                 output_start_x_tile:output_end_x_tile]
+                            output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
+                                                                       output_start_x_tile:output_end_x_tile]
 
     def post_process(self):
         # remove extra pad
@@ -191,29 +200,37 @@ def post_process(self):
         return self.output
 
     @torch.no_grad()
-    def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'):
-        h_input, w_input = img.shape[0:2]
-        # img: numpy
-        img = img.astype(np.float32)
-        if np.max(img) > 256:  # 16-bit image
-            max_range = 65535
-            print('\tInput is a 16-bit image')
-        else:
-            max_range = 255
-        img = img / max_range
-        if len(img.shape) == 2:  # gray image
-            img_mode = 'L'
-            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
-        elif img.shape[2] == 4:  # RGBA image with alpha channel
-            img_mode = 'RGBA'
-            alpha = img[:, :, 3]
-            img = img[:, :, 0:3]
-            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
-            if alpha_upsampler == 'realesrgan':
-                alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
-        else:
-            img_mode = 'RGB'
-            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+    def enhance(self, img: Union[np.ndarray, List[np.ndarray]], outscale=None, alpha_upsampler='realesrgan'):
+        if isinstance(img, np.ndarray):  # a single image, i.e. batch size 1
+            img = [img]
+        h_input = [i.shape[0] for i in img]
+        w_input = [i.shape[1] for i in img]
+        img = [i.astype(np.float32) for i in img]
+        max_range = [65535 if np.max(i) > 256 else 255 for i in img]
+        if any(m == 65535 for m in max_range):
+            print('\tInput contains 16-bit images')
+
+        img = [i / m_range for i, m_range in zip(img, max_range)]
+
+        img_modes = []
+        alphas = [None] * len(img)  # per-image alpha channels for RGBA inputs
+        for idx, im in enumerate(img):
+            if len(im.shape) == 2:  # gray image
+                img_mode = 'L'
+                img[idx] = cv2.cvtColor(im, cv2.COLOR_GRAY2RGB)
+            elif im.shape[2] == 4:  # RGBA image with alpha channel
+                img_mode = 'RGBA'
+                alpha = im[:, :, 3]
+                im = im[:, :, 0:3]
+                img[idx] = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
+                if alpha_upsampler == 'realesrgan':
+                    alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
+                alphas[idx] = alpha
+            else:
+                img_mode = 'RGB'
+                img[idx] = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
+
+            img_modes.append(img_mode)
 
         # ------------------- process image (without the alpha channel) ------------------- #
         self.pre_process(img)
@@ -221,46 +238,53 @@ def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'):
             self.tile_process()
         else:
             self.process()
-        output_img = self.post_process()
-        output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy()
-        output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
-        if img_mode == 'L':
-            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
-
-        # ------------------- process the alpha channel if necessary ------------------- #
-        if img_mode == 'RGBA':
-            if alpha_upsampler == 'realesrgan':
-                self.pre_process(alpha)
-                if self.tile_size > 0:
-                    self.tile_process()
-                else:
-                    self.process()
-                output_alpha = self.post_process()
-                output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
-                output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
-                output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
-            else:  # use the cv2 resize for alpha channel
-                h, w = alpha.shape[0:2]
-                output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)
-
-            # merge the alpha channel
-            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
-            output_img[:, :, 3] = output_alpha
-
-        # ------------------------------ return ------------------------------ #
-        if max_range == 65535:  # 16-bit image
-            output = (output_img * 65535.0).round().astype(np.uint16)
-        else:
-            output = (output_img * 255.0).round().astype(np.uint8)
-
-        if outscale is not None and outscale != float(self.scale):
-            output = cv2.resize(
-                output, (
-                    int(w_input * outscale),
-                    int(h_input * outscale),
-                ), interpolation=cv2.INTER_LANCZOS4)
-
-        return output, img_mode
+        output_imgs = self.post_process()
+        # keep the batch dimension (b, c, h, w); squeeze() would drop it for b=1
+        output_imgs = output_imgs.data.float().cpu().clamp_(0, 1).numpy()
+        output_imgs = [o for o in output_imgs]
+
+        final_results = []
+        for output_img, img_mode, alpha, max_r, h_i, w_i in zip(output_imgs, img_modes, alphas, max_range, h_input, w_input):
+            output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
+            if img_mode == 'L':
+                output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
+
+            # ------------------- process the alpha channel if necessary ------------------- #
+            if img_mode == 'RGBA':
+                if alpha_upsampler == 'realesrgan':
+                    self.pre_process([alpha])
+                    if self.tile_size > 0:
+                        self.tile_process()
+                    else:
+                        self.process()
+                    output_alpha = self.post_process()
+                    output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
+                    output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
+                    output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
+                else:  # use the cv2 resize for alpha channel
+                    h, w = alpha.shape[0:2]
+                    output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)
+
+                # merge the alpha channel
+                output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
+                output_img[:, :, 3] = output_alpha
+
+            # ------------------------------ return ------------------------------ #
+            if max_r == 65535:  # 16-bit image
+                output = (output_img * 65535.0).round().astype(np.uint16)
+            else:
+                output = (output_img * 255.0).round().astype(np.uint8)
+
+            if outscale is not None and outscale != float(self.scale):
+                output = cv2.resize(
+                    output, (
+                        int(w_i * outscale),
+                        int(h_i * outscale),
+                    ), interpolation=cv2.INTER_LANCZOS4)
+
+            final_results.append((output, img_mode))
+        outputs, out_modes = zip(*final_results)
+        return list(outputs), list(out_modes)
 
 
 class PrefetchReader(threading.Thread):
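
For reference, a minimal usage sketch of the batched API. This is a hypothetical example, not part of the patch: it assumes an already-constructed RealESRGANer instance named upsampler, and the input paths are made up.

    import cv2

    # All images in one batch must share the same spatial size, because
    # pre_process() concatenates them along the batch dimension.
    paths = ['inputs/a.png', 'inputs/b.png']  # hypothetical paths
    imgs = [cv2.imread(p, cv2.IMREAD_UNCHANGED) for p in paths]

    # enhance() accepts a single ndarray (batch size 1) or a list of ndarrays,
    # and now returns a list of outputs and a list of image modes.
    outputs, img_modes = upsampler.enhance(imgs, outscale=4)

    for p, out in zip(paths, outputs):
        cv2.imwrite(p.replace('inputs', 'results'), out)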