Skip to content

RGBA with proper hue/saturation and upscaling #31

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 79 additions & 24 deletions scripts/postprocessing_pixelization.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@

from pixelization.models.networks import define_G
import pixelization.models.c2pGen
import gdown

import colorsys

pixelize_code = [
233356.8125, -27387.5918, -32866.8008, 126575.0312, -181590.0156,
Expand Down Expand Up @@ -107,20 +108,19 @@ def load(self):

missing = False

models = (
(path_pixelart_vgg19, "https://drive.google.com/uc?id=1VRYKQOsNlE1w1LXje3yTRU5THN2MGdMM"),
(path_160_net_G_A, "https://drive.google.com/uc?id=1i_8xL3stbLWNF4kdQJ50ZhnRFhSDh3Az"),
(path_alias_net, "https://drive.google.com/uc?id=17f2rKnZOpnO9ATwRXgqLz5u5AZsyDvq_"),
)
if not os.path.exists(path_pixelart_vgg19):
print(f"Missing {path_pixelart_vgg19} - download it from https://drive.google.com/uc?id=1VRYKQOsNlE1w1LXje3yTRU5THN2MGdMM")
missing = True

for path, url in models:
if not os.path.exists(path):
gdown.download(url, path)
if not os.path.exists(path_160_net_G_A):
print(f"Missing {path_160_net_G_A} - download it from https://drive.google.com/uc?id=1i_8xL3stbLWNF4kdQJ50ZhnRFhSDh3Az")
missing = True

if not os.path.exists(path):
missing = True
if not os.path.exists(path_alias_net):
print(f"Missing {path_alias_net} - download it from https://drive.google.com/uc?id=17f2rKnZOpnO9ATwRXgqLz5u5AZsyDvq_")
missing = True

assert not missing, f'Missing checkpoints for pixelization - see console for download links. Download checkpoints manually and place them in {path_checkpoints}.'
assert not missing, 'Missing checkpoints for pixelization - see console for doqwnload links.'

with torch.no_grad():
self.G_A_net = define_G(3, 3, 64, "c2pGen", "instance", False, "normal", 0.02, [0])
Expand All @@ -136,7 +136,6 @@ def load(self):
alias_state["module." + str(p)] = alias_state.pop(p)
self.alias_net.load_state_dict(alias_state)


def process(img):
ow, oh = img.size

Expand All @@ -150,22 +149,71 @@ def process(img):

img = img.crop((left, top, right, bottom))

trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# Split the RGBA image into RGB and alpha channels
img_rgba = img.convert('RGBA')
r, g, b, a = img_rgba.split()

return trans(img)[None, :, :, :]
# Convert RGB to tensor and normalize
rgb_img = Image.merge('RGB', (r, g, b))
trans_rgb = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
rgb_tensor = trans_rgb(rgb_img)

# Convert alpha channel to tensor (scale from 0-255 to 0-1)
alpha_tensor = transforms.ToTensor()(a)[None, :, :] # Add an extra dimension for batch size

def to_image(tensor, pixel_size, upscale_after):
return rgb_tensor[None, :, :, :], alpha_tensor

def to_image(tensor, alpha_tensor, pixel_size, upscale_after, original_img, copy_hue, copy_sat):
img = tensor.data[0].cpu().float().numpy()
img = (np.transpose(img, (1, 2, 0)) + 1) / 2.0 * 255.0
img = img.astype(np.uint8)
img = Image.fromarray(img)
img = img.resize((img.size[0]//4, img.size[1]//4), resample=Image.Resampling.NEAREST)
width = img.size[0] // 4
height = img.size[1] // 4
img = img.resize((width, height), resample=Image.Resampling.NEAREST)

# Resize the alpha channel to match the new dimensions
alpha_img = alpha_tensor.data[0].cpu().numpy()
alpha_img = (alpha_img * 255).astype(np.uint8)
alpha_img = Image.fromarray(alpha_img.squeeze(), mode='L')
alpha_img = alpha_img.resize((width, height), resample=Image.Resampling.NEAREST)

if copy_hue or copy_sat:
original_img = original_img.resize((width, height), resample=Image.Resampling.NEAREST)
img = color_image(img, original_img, copy_hue, copy_sat)


# Merge the processed RGB image with the alpha channel
img.putalpha(alpha_img)

if upscale_after:
img = img.resize((img.size[0]*pixel_size, img.size[1]*pixel_size), resample=Image.Resampling.NEAREST)

return img

def color_image(img, original_img, copy_hue, copy_sat):
img = img.convert("RGB")
original_img = original_img.convert("RGB")

colored_img = Image.new("RGB", img.size)

for x in range(img.width):
for y in range(img.height):
pixel = original_img.getpixel((x, y))
r, g, b = pixel
original_h, original_s, original_v = colorsys.rgb_to_hsv(r / 255, g / 255, b / 255)

pixel = img.getpixel((x, y))
r, g, b = pixel
h, s, v = colorsys.rgb_to_hsv(r / 255, g / 255, b / 255)

r, g, b = colorsys.hsv_to_rgb(original_h if copy_hue else h, original_s if copy_sat else s, v)
colored_img.putpixel((x, y), (int(r * 255), int(g * 255), int(b * 255)))

return colored_img

class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
name = "Pixelization"
Expand All @@ -175,16 +223,21 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
def ui(self):
with ui_components.InputAccordion(False, label="Pixelize") as enable:
with gr.Row():
upscale_after = gr.Checkbox(False, label="Keep resolution")
upscale_after = gr.Checkbox(False, label="Keep resolution")
copy_hue = gr.Checkbox(False, label="Restore hue")
copy_sat = gr.Checkbox(False, label="Restore saturation")
with gr.Column():
pixel_size = gr.Slider(minimum=1, maximum=16, step=1, label="Pixel size", value=4, elem_id="pixelization_pixel_size")

return {
"enable": enable,
"upscale_after": upscale_after,
"pixel_size": pixel_size,
"copy_hue": copy_hue,
"copy_sat": copy_sat,
}

def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, upscale_after, pixel_size):
def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, upscale_after, pixel_size, copy_hue, copy_sat):
if not enable:
return

Expand All @@ -196,20 +249,22 @@ def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, upscale

self.model.to(devices.device)

pp.image = pp.image.resize((pp.image.width * 4 // pixel_size, pp.image.height * 4 // pixel_size)).convert('RGB')
pp.image = pp.image.resize((pp.image.width * 4 // pixel_size, pp.image.height * 4 // pixel_size))
original_img = pp.image.copy()

with torch.no_grad():
in_t = process(pp.image).to(devices.device)
in_t, alpha_t = process(pp.image)
in_t = in_t.to(devices.device)
alpha_t = alpha_t.to(devices.device)

feature = self.model.G_A_net.module.RGBEnc(in_t)
code = torch.asarray(pixelize_code, device=devices.device).reshape((1, 256, 1, 1))
code = torch.tensor(pixelize_code, device=devices.device).reshape((1, 256, 1, 1))
adain_params = self.model.G_A_net.module.MLP(code)
images = self.model.G_A_net.module.RGBDec(feature, adain_params)
out_t = self.model.alias_net(images)

pp.image = to_image(out_t, pixel_size=pixel_size, upscale_after=upscale_after)
pp.image = to_image(out_t, alpha_t, pixel_size=pixel_size, upscale_after=upscale_after, original_img=original_img, copy_hue=copy_hue, copy_sat=copy_sat)

self.model.to(devices.cpu)

pp.info["Pixelization pixel size"] = pixel_size