Skip to content

Dataloader #287

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 57 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
57 commits
Select commit Hold shift + click to select a range
9280efd
complete logging
Oct 8, 2021
0f33625
change default checkpoint
Oct 8, 2021
a9c372e
fix missing key error
Oct 8, 2021
46c5caf
add path argument
Oct 8, 2021
2f6a8f7
add path argument
Oct 8, 2021
3aff2fe
add path argument
Oct 8, 2021
f7181ed
add path argument
Oct 8, 2021
8232e53
add comment
Oct 10, 2021
08fafaf
skimage version conflict solved
Oct 10, 2021
8b17ffd
add pre-download backbone weight
Oct 11, 2021
ccdc8c8
init dataloader and cleaner code
Oct 12, 2021
2b065bb
add option to turn tqdm off
Oct 12, 2021
bc2d5a5
clean argument
Oct 12, 2021
db77740
add pretty print args
Oct 12, 2021
6af8c7c
adding input type restriction
Oct 12, 2021
9044413
fix typo
Oct 12, 2021
f5153ee
fix typo
Oct 12, 2021
428f49a
create all directories beforehand
Oct 12, 2021
9133282
create all directories beforehand
Oct 12, 2021
9cc06f6
Merge branch 'dataloader' of https://github.com/nessessence/stylegan2…
Oct 12, 2021
447ca4a
ascii tqdm
Oct 12, 2021
71cc336
continue projection
Oct 12, 2021
2d48ee0
minor change
Oct 12, 2021
877e709
minor change
Oct 12, 2021
67c1167
index range
Oct 12, 2021
8726e92
index range
Oct 13, 2021
5e25dc0
index range
Oct 13, 2021
a56b567
index range
Oct 13, 2021
416a8b5
data parallel
Oct 13, 2021
e769e25
data parallel
Oct 13, 2021
c6b9903
data parallel
Oct 13, 2021
22db08f
data parallel
Oct 13, 2021
6a37345
data parallel
Oct 13, 2021
84b1325
change percept to another cuda device
Oct 13, 2021
32dd62d
minor change
Oct 13, 2021
f3d9398
minor change
Oct 13, 2021
7bb6a20
use dataparallel from original lpips
Oct 13, 2021
daecc4d
fix continue project
Oct 13, 2021
ef5c2d3
add default_device_idx for lpips
Oct 13, 2021
924d29c
minor change
Oct 13, 2021
6c3bf26
set the first index of Dataparallel
Oct 13, 2021
35d20db
set the first index of Dataparallel
Oct 13, 2021
12a668a
print default cuda
Oct 13, 2021
e209c2f
print default cuda
Oct 13, 2021
63d6b70
print default cuda
Oct 13, 2021
9d2c1c3
print device img_gen, imgs
Oct 13, 2021
3aef4ab
change to manually DP
Oct 13, 2021
9dbe8be
manaully set DP and lpips_cuda works!
Oct 13, 2021
af9e815
active both DP
Oct 13, 2021
4180f66
active both DP
Oct 13, 2021
9e23c7a
active both DP
Oct 13, 2021
792a158
back to normal manual
Oct 13, 2021
e97b090
sort index_range
Oct 13, 2021
c98fc8d
tqdm off
Oct 13, 2021
6177fa0
drop_last = false
Oct 13, 2021
079dc70
minor change
Oct 14, 2021
0b00ff4
interpolate
Oct 24, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,5 @@ dmypy.json
wandb/
*.lmdb/
*.pkl
550000.pt
sample/
Binary file added 48_projected.pt
Binary file not shown.
7 changes: 7 additions & 0 deletions convert_weight.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,14 @@ def fill_statedict(state_dict, vars, size, n_mlp):
img_concat = torch.cat((img_tf, img_pt, img_diff), dim=0)

print(img_diff.abs().max())



utils.save_image(
img_concat, name + ".png", nrow=n_sample, normalize=True, range=(-1, 1)
)

print(f"save sample images {name}.png")
print("converting weight complete!!!")


32 changes: 32 additions & 0 deletions datasets/custom_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from torch.utils.data import Dataset
import os
from PIL import Image
from natsort import natsorted

class CustomDataSet(Dataset):
def __init__(self, main_dir, transform,completed_images=[],index_range=None):
self.main_dir = main_dir
self.transform = transform
all_imgs = os.listdir(main_dir)
all_imgs = natsorted(all_imgs)
if index_range != None:
print(f"applying index range {index_range}")
start_index = index_range[0]; end_index = index_range[1]
all_imgs = all_imgs[start_index: end_index ]
if len(completed_images) > 0:
print(f"Continue projection from previous process that have finished projecting {len(completed_images)} images")
if index_range != None: print(f"total in-range completion: {len(set(all_imgs).intersection(set(all_imgs)))} images")
all_imgs = [img for img in all_imgs if img not in completed_images]
if index_range != None: print(f"total images to be process: {len(all_imgs)} images, within range [{index_range}]")
else: print(f"total images to be process: {len(all_imgs)} images")
self.total_imgs = natsorted(all_imgs)

def __len__(self):
return len(self.total_imgs)

def __getitem__(self, idx):
fname = self.total_imgs[idx]
img_loc = os.path.join(self.main_dir, self.total_imgs[idx])
image = Image.open(img_loc).convert("RGB")
img = self.transform(image)
return fname,img
23 changes: 18 additions & 5 deletions generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
from torchvision import utils
from model import Generator
from tqdm import tqdm
import os.path as osp
import os


def generate(args, g_ema, device, mean_latent):
def generate(args, g_ema, device, mean_latent,output_path):

with torch.no_grad():
g_ema.eval()
Expand All @@ -19,7 +21,7 @@ def generate(args, g_ema, device, mean_latent):

utils.save_image(
sample,
f"sample/{str(i).zfill(6)}.png",
osp.join(output_path,f"{str(i).zfill(6)}.png"),
nrow=1,
normalize=True,
range=(-1, 1),
Expand Down Expand Up @@ -53,7 +55,7 @@ def generate(args, g_ema, device, mean_latent):
parser.add_argument(
"--ckpt",
type=str,
default="stylegan2-ffhq-config-f.pt",
default="checkpoint/stylegan2-ffhq-config-f.pt",
help="path to the model checkpoint",
)
parser.add_argument(
Expand All @@ -63,8 +65,19 @@ def generate(args, g_ema, device, mean_latent):
help="channel multiplier of the generator. config-f = 2, else = 1",
)

parser.add_argument(
'-o',
"--output_path",
type=str,
default="sample",
help="root path save sample",
)
args = parser.parse_args()

args.output_path = osp.join( args.output_path,f"{osp.basename(args.ckpt).split('.')[0]}_{args.size}")
os.makedirs(args.output_path,exist_ok=True)


args.latent = 512
args.n_mlp = 8

Expand All @@ -73,12 +86,12 @@ def generate(args, g_ema, device, mean_latent):
).to(device)
checkpoint = torch.load(args.ckpt)

g_ema.load_state_dict(checkpoint["g_ema"])
g_ema.load_state_dict(checkpoint["g_ema"], strict=False)

if args.truncation < 1:
with torch.no_grad():
mean_latent = g_ema.mean_latent(args.truncation_mean)
else:
mean_latent = None

generate(args, g_ema, device, mean_latent)
generate(args, g_ema, device, mean_latent,args.output_path)
5,785 changes: 5,785 additions & 0 deletions interpolate.ipynb

Large diffs are not rendered by default.

146 changes: 146 additions & 0 deletions interpolate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
import argparse
import os
import os.path as osp
import torch
from torchvision import transforms
from torch.backends import cudnn
os.environ['CUDA_VISIBLE_DEVICES'] = '0' # new add by gu
torch.cuda.set_device(torch.device('cuda',0))
cudnn.benchmark = True
from datasets.custom_dataset import CustomDataSet
from model import Generator
import torch.nn.functional as F
from collections import defaultdict


import math
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image




def init_styleGAN(img_size=256,ckpt='./checkpoint/550000.pt'):
g_ema = Generator(img_size, 512, 8)
g_ema.load_state_dict(torch.load(ckpt)["g_ema"], strict=False)
g_ema.eval()
g_ema.to('cuda')
return g_ema



def cos(a, b):
a = a.view(-1)
b = b.view(-1)
a = F.normalize(a, dim=0)
b = F.normalize(b, dim=0)
return (a * b).sum()

def spherical_interpolation(x0, x1, alpha):
theta = torch.acos(cos(x0, x1)) #torch.arccos(cos(x0, x1))
a = torch.sin((1-alpha)*theta) / torch.sin(theta) * x0
b = torch.sin(alpha*theta) / torch.sin(theta) * x1
return a + b

def sqrt_interpolation(x0, x1, alpha):
return ((1-alpha) * x0 + (alpha) * x1) / math.sqrt(alpha ** 2 + (1-alpha) ** 2)

def linear_interpolation(x0, x1, alpha):
return ((1-alpha) * x0 + (alpha) * x1)






def make_image(tensor):
return (
tensor.detach()
.clamp_(min=-1, max=1)
.add(1)
.div_(2)
.mul(255)
.type(torch.uint8)
.permute(0, 2, 3, 1)
.to("cpu")
.numpy()
)
def post_processing(img_gen):
channel, height, width = img_gen.shape
if height > 256:
factor = height // 256
img_gen = img_gen.reshape(channel, height // factor, factor, width // factor, factor)
img_gen = img_gen.mean([3, 5])

return img_gen
def pair_interpolate(noises1,noises2,latent1,latent2,alphas,w_plus=False):
global g_ema
interpolated_imgs = []
for alpha in alphas:
interpolated_noises = []
for noise1, noise2 in zip(noises1,noises2):
interpolated_noises += [spherical_interpolation(noise1, noise2, alpha)]
if w_plus:
interpolated_latent = torch.stack([linear_interpolation(latent1_a,latent2_b,alpha) for latent1_a,latent2_b in zip(latent1,latent2) ])
else: interpolated_latent = linear_interpolation(latent1,latent2,alpha)
# print(interpolated_latent.shape,interpolated_latent[None, :].shape )
interpolated_img, _ = g_ema([interpolated_latent[None, :]], input_is_latent=True, noise=interpolated_noises)
interpolated_img = make_image(interpolated_img)
interpolated_imgs += [interpolated_img[0]]
return interpolated_imgs
def make_pair_interpolate(latent_pairs,alphas,w_plus=False):
interpolated_pair_imgs = defaultdict(lambda: list())
for i in latent_pairs:
interpolated_pair_imgs[i] = pair_interpolate(latent_pairs[i]['a']['noise'],latent_pairs[i]['b']['noise'],latent_pairs[i]['a']['latent'],latent_pairs[i]['b']['latent'],alphas=alphas,w_plus=w_plus)
return interpolated_pair_imgs


def save_images(interpolated_pair_imgs,alphas,output_path='./output/',w_plus=False):
w_flag = "w_plus" if w_plus else "w"
output_path = osp.join(output_path,w_flag,f"{len(alphas)}alpha")
os.makedirs(output_path,exist_ok=True)
for i in interpolated_pair_imgs:
pair_path = osp.join(output_path,f"pair{i}")
os.makedirs(pair_path,exist_ok=True)
for j,(img,alpha) in enumerate(zip(interpolated_pair_imgs[i],alphas)):
img_path = osp.join(pair_path,f"{j}_alpha{alpha}_.png")
pil_img = Image.fromarray(img)
pil_img.save(img_path)
print(f"saved images to {output_path}")

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Image projector to the generator latent spaces")
parser.add_argument("-lp","--latent_path", default="./projected_output/interpolate_pairs/projected_W_1000step_256_550000/projected_latent_dict")
parser.add_argument("--device", default='cuda',choices=['cuda','cpu'])
parser.add_argument("-ni", "--num_interpolate",type=int,default=7,help='alpha, frequency of interpolation between values 0-1')
parser.add_argument("--w_plus",action="store_true",help="allow to use distinct latent codes to each layers",)
parser.add_argument("-o","--output_path",default='output/celeb_pairs/')
parser.add_argument("-s","--img_size",type=int,default=256)
parser.add_argument("-ckpt","--checkpoint",default='./checkpoint/550000.pt')



args = parser.parse_args()


g_ema = init_styleGAN(args.img_size,args.checkpoint)

alphas = np.linspace(0, 1, args.num_interpolate)

latent_path = args.latent_path
latent_files = os.listdir(latent_path)
latent_pairs = defaultdict(lambda: defaultdict(str))
for latent_file in latent_files:
img_id, img_tag = latent_file.split("_")[:2]
with open(osp.join(latent_path,latent_file),'rb') as f:
latent_pairs[img_id][img_tag] = torch.load(f)
print(f"there are {len(latent_files)} latent files")

print("start interpolation")
interpolated_pair_imgs = make_pair_interpolate(latent_pairs,alphas=alphas,w_plus=args.w_plus)
print("finish interpolation")
save_images(interpolated_pair_imgs,alphas=alphas,w_plus=args.w_plus,output_path=args.output_path)



7 changes: 4 additions & 3 deletions lpips/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,23 @@
from __future__ import print_function

import numpy as np
from skimage.measure import compare_ssim
# from skimage.measure import compare_ssim
from skimage.metrics import structural_similarity as compare_ssim
import torch
from torch.autograd import Variable

from lpips import dist_model

class PerceptualLoss(torch.nn.Module):
def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0]): # VGG using our perceptually-learned weights (LPIPS metric)
def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0],default_device_idx=0): # VGG using our perceptually-learned weights (LPIPS metric)
# def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss
super(PerceptualLoss, self).__init__()
print('Setting up Perceptual loss...')
self.use_gpu = use_gpu
self.spatial = spatial
self.gpu_ids = gpu_ids
self.model = dist_model.DistModel()
self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial, gpu_ids=gpu_ids)
self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial, gpu_ids=gpu_ids,default_device_idx=default_device_idx)
print('...[%s] initialized'%self.model.name())
print('...Done')

Expand Down
3 changes: 2 additions & 1 deletion lpips/dist_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def name(self):

def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, model_path=None,
use_gpu=True, printNet=False, spatial=False,
is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=[0]):
is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=[0], default_device_idx=0):
'''
INPUTS
model - ['net-lin'] for linearly calibrated network
Expand Down Expand Up @@ -96,6 +96,7 @@ def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=Fa
self.net.eval()

if(use_gpu):
# default_device_id = gpu_ids[default_device_idx] if len(gpu_ids) > 1 else 0
self.net.to(gpu_ids[0])
self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids)
if(self.is_train):
Expand Down
10 changes: 6 additions & 4 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ def __init__(
self.scale = 1 / math.sqrt(fan_in)
self.padding = kernel_size // 2

# weight of conv2d
self.weight = nn.Parameter(
torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
)
Expand Down Expand Up @@ -255,11 +256,11 @@ def forward(self, input, style):

return out

style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
weight = self.scale * self.weight * style
style = self.modulation(style).view(batch, 1, in_channel, 1, 1) # I'm not sure, but I guess this is "A" in the paper, s = A(latent_w) : laten_w is stlye
weight = self.scale * self.weight * style # modulated_w = s * w

if self.demodulate:
demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) # torch.rsqrt(x) = 1/sqr()
weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)

weight = weight.view(
Expand Down Expand Up @@ -507,8 +508,9 @@ def forward(
noise=None,
randomize_noise=True,
):
# style or latent here is "w" in the paper (the one that processed from noise through a network )
if not input_is_latent:
styles = [self.style(s) for s in styles]
styles = [self.style(s) for s in styles]

if noise is None:
if randomize_noise:
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added output/celeb_pairs/w/5alpha/pair1/2_alpha0.5_.png
Binary file added output/celeb_pairs/w/5alpha/pair1/4_alpha1.0_.png
Binary file added output/celeb_pairs/w/5alpha/pair6/0_alpha0.0_.png
Binary file added output/celeb_pairs/w/7alpha/pair1/3_alpha0.5_.png
Binary file added output/celeb_pairs/w/7alpha/pair1/6_alpha1.0_.png
Binary file added output/celeb_pairs/w/7alpha/pair6/0_alpha0.0_.png
Binary file added output/celeb_pairs/w/9alpha/pair1/4_alpha0.5_.png
Binary file added output/celeb_pairs/w/9alpha/pair1/8_alpha1.0_.png
Binary file added output/celeb_pairs/w/9alpha/pair6/0_alpha0.0_.png
Loading