Skip to content

Commit 9b1ce70

Browse files
authored
Merge pull request NVlabs#3 from oneiroid/master
Added support for custom dlatent_avg for truncation
2 parents c6bd591 + 1d429da commit 9b1ce70

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

encode_images.py

+3
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def main():
2323
parser.add_argument('--data_dir', default='data', help='Directory for storing optional models')
2424
parser.add_argument('--mask_dir', default='masks', help='Directory for storing optional masks')
2525
parser.add_argument('--load_last', default='', help='Start with embeddings from directory')
26+
parser.add_argument('--dlatent_avg', default='', help='Use dlatent from file specified here for truncation instead of dlatent_avg from Gs')
2627
parser.add_argument('--model_url', default='https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ', help='Fetch a StyleGAN model to train on from this URL') # karras2019stylegan-ffhq-1024x1024.pkl
2728
parser.add_argument('--model_res', default=1024, help='The dimension of images in the StyleGAN model', type=int)
2829
parser.add_argument('--batch_size', default=1, help='Batch size for generator and perceptual model', type=int)
@@ -90,6 +91,8 @@ def main():
9091
generator_network, discriminator_network, Gs_network = pickle.load(f)
9192

9293
generator = Generator(Gs_network, args.batch_size, clipping_threshold=args.clipping_threshold, tiled_dlatent=args.tile_dlatents, model_res=args.model_res, randomize_noise=args.randomize_noise)
94+
if (args.dlatent_avg != ''):
95+
generator.set_dlatent_avg(np.load(args.dlatent_avg))
9396

9497
perc_model = None
9598
if (args.use_lpips_loss > 0.00000001):

encoder/generator_model.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ def __init__(self, model, batch_size, clipping_threshold=2, tiled_dlatent=False,
4444
partial(create_stub, batch_size=batch_size)],
4545
structure='fixed')
4646

47-
self.dlatent_avg = model.get_var('dlatent_avg')
47+
self.dlatent_avg_def = model.get_var('dlatent_avg')
48+
self.reset_dlatent_avg()
4849
self.sess = tf.get_default_session()
4950
self.graph = tf.get_default_graph()
5051

@@ -93,6 +94,12 @@ def get_dlatents(self):
9394
def get_dlatent_avg(self):
9495
return self.dlatent_avg
9596

97+
def set_dlatent_avg(self, dlatent_avg):
98+
self.dlatent_avg = dlatent_avg
99+
100+
def reset_dlatent_avg(self):
101+
self.dlatent_avg = self.dlatent_avg_def
102+
96103
def generate_images(self, dlatents=None):
97104
if dlatents:
98105
self.set_dlatents(dlatents)

0 commit comments

Comments
 (0)