
Commit 799c4de

Update README, start from perlin noise for discriminator_synthesis.py.
1 parent 0bfa8e1 commit 799c4de

File tree

3 files changed: +60 -29 lines changed

README.md

+19 -7
@@ -10,6 +10,24 @@ capabilities (but hopefully not its complexity!).
 
 This repository adds the following (not yet the complete list):
 
+* Dataset tool
+    * Add `--center-crop-tall`: add vertical black bars to the sides instead, in the same vein as the horizontal bars in
+      `--center-crop-wide`.
+    * Grayscale images in the dataset are converted to `RGB`.
+    * If the dataset tool encounters an error, print it along with the offending image, but continue with the rest of the dataset
+      ([pull #39](https://github.com/NVlabs/stylegan3/pull/39) from [Andreas Jansson](https://github.com/andreasjansson)).
+    * *TODO*: Add multi-crop, as used in [Earth View](https://github.com/PDillis/earthview#multi-crop---data_augmentpy).
+* Training
+    * `--mirrory`: Added vertical mirroring for doubling the dataset size
+    * `--gamma`: If no R1 regularization is provided, the heuristic formula from [StyleGAN2](https://github.com/NVlabs/stylegan2) will be used.
+    * `--augpipe`: Now available to use is [StyleGAN2-ADA's](https://github.com/NVlabs/stylegan2-ada-pytorch) full list of augmentation pipelines, e.g., `blit`, `geom`, `bgc`, `bgcfnc`, etc.
+    * `--img-snap`: When to save snapshot images, so now it's independent of when the model is saved
+    * `--snap-res`: The resolution of the snapshots, depending on your screen resolution, or how many images you wish to see per tick. Available resolutions: `1080p`, `4k`, and `8k`.
+    * `--resume-kimg`: Starting number of `kimg`, useful when continuing training a previous run
+    * `--outdir`: Automatically set as `training-runs`
+    * `--metrics`: Now set by default to `None`, so there's no need to worry about this one
+    * `--resume`: All available pre-trained models from NVIDIA can be found with a simple dictionary, depending on the `--cfg` used.
+      For example, if `--cfg=stylegan3-r`, then to transfer learn from FFHQU at 1024 resolution, set `--resume=ffhqu1024`. Full list available [here](https://github.com/PDillis/stylegan3-fun/blob/0bfa8e108487b50d6ecb73718c60497f063d8c17/train.py#L297).
 * Interpolation videos
     * [Random interpolation](https://youtu.be/DNfocO1IOUE)
     * Style-mixing
@@ -21,19 +39,13 @@ This repository adds the following (not yet the complete list):
     * Additional losses to use for better projection (e.g., using VGG16 or [CLIP](https://github.com/openai/CLIP))
 * [Discriminator Synthesis](https://arxiv.org/abs/2111.02175) (official code)
     * Generate a static image or a [video](https://youtu.be/hEJKWL2VQTE) with a feedback loop
+    * Start from a random image (`random` or `perlin`, using [Mathieu Duchesneau's implementation](https://github.com/duchesneaumathieu/pyperlin)) or from an existing one
 * Expansion on GUI/`visualizer.py`
     * Added the rest of the affine transformations
 * General model and code additions
     * No longer necessary to specify `--outdir` when running the code, as the output directory will be automatically generated
     * [Better sampling?](https://arxiv.org/abs/2110.08009) (TODO)
     * StyleGAN3: anchor the latent space for easier to follow interpolations
-* Dataset tool
-    * Add `--center-crop-tall`: add vertical black bars to the sides instead, in the same vein as the horizontal bars in
-      `--center-crop-wide`.
-    * Grayscale images in the dataset are converted to `RGB`.
-    * If the dataset tool encounters an error, print it along the offending image, but continue with the rest of the dataset
-      ([pull #39](https://github.com/NVlabs/stylegan3/pull/39) from [Andreas Jansson](https://github.com/andreasjansson)).
-    * *TODO*: Add multi-crop, as used in [Earth View](https://github.com/PDillis/earthview#multi-crop---data_augmentpy).
 
 ***TODO:*** Finish documentation for better user experience, add videos/images, code samples.
 
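As an aside, the `--center-crop-tall` behaviour listed above boils down to padding a portrait image with black bars on the left and right so it ends up square, mirroring what `--center-crop-wide` does with horizontal bars. A minimal PIL sketch of that idea (illustrative only, not the dataset tool's actual code; the function name is made up):

```python
from PIL import Image

def pad_tall_to_square(img: Image.Image, size: int = 1024) -> Image.Image:
    """Resize a tall image so its height equals `size`, then paste it centered
    on a black square canvas, i.e. add vertical black bars on both sides."""
    w, h = img.size
    new_w = max(1, round(w * size / h))          # keep the aspect ratio, height -> size
    img = img.resize((new_w, size), Image.LANCZOS)
    canvas = Image.new('RGB', (size, size))      # black background
    canvas.paste(img, ((size - new_w) // 2, 0))  # center horizontally
    return canvas
```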

discriminator_synthesis.py

+40 -22
@@ -1,5 +1,4 @@
 import torch
-import torch.nn as nn
 from torch.autograd import Variable
 from torchvision import transforms
 
@@ -11,27 +10,18 @@
 except ImportError:
     raise ImportError('ffmpeg-python not found! Install it via "pip install ffmpeg-python"')
 
-try:
-    import skvideo.io
-except ImportError:
-    raise ImportError('scikit-video not found! Install it via "pip install scikit-video"')
-
 import scipy.ndimage as nd
 import numpy as np
 
 import os
 import click
 from typing import Union, Tuple, Optional, List
-from collections import OrderedDict
 from tqdm import tqdm
 
 from torch_utils import gen_utils
 import dnnlib
 import legacy
-from network_features import VGG16FeaturesNVIDIA, DiscriminatorFeatures
-
-os.environ['PYGAME_HIDE_SUPPORT_PROMPT'] = 'hide'
-import moviepy.editor
+from network_features import DiscriminatorFeatures
 
 
 # ----------------------------------------------------------------------------
@@ -78,6 +68,7 @@ def parse_layers(s: str) -> List[str]:
 # DeepDream code; modified from Erik Linder-Norén's repository: https://github.com/eriklindernoren/PyTorch-Deep-Dream
 
 def get_image(seed: int = 0,
+              image_noise: str = 'random',
               starting_image: Union[str, os.PathLike] = None,
               image_size: int = 1024) -> Tuple[PIL.Image.Image, str]:
     """Set the random seed (NumPy + PyTorch), as well as get an image from a path or generate a random one with the seed"""
@@ -88,8 +79,24 @@ def get_image(seed: int = 0,
     if starting_image is not None:
         image = Image.open(starting_image).convert('RGB').resize((image_size, image_size), Image.LANCZOS)
     else:
-        starting_image = f'random_image-seed_{seed}.jpg'
-        image = Image.fromarray(rnd.randint(0, 255, (image_size, image_size, 3), dtype='uint8'))
+        if image_noise == 'random':
+            starting_image = f'random_image-seed_{seed}.jpg'
+            image = Image.fromarray(rnd.randint(0, 255, (image_size, image_size, 3), dtype='uint8'))
+        elif image_noise == 'perlin':
+            try:
+                # Graciously using Mathieu Duchesneau's implementation: https://github.com/duchesneaumathieu/pyperlin
+                from pyperlin import FractalPerlin2D
+                starting_image = f'perlin_image-seed_{seed}.jpg'
+                shape = (3, image_size, image_size)
+                resolutions = [(2**i, 2**i) for i in range(1, 6+1)]  # for lacunarity = 2.0  # TODO: set as CLI variable
+                factors = [0.5**i for i in range(6)]  # for persistence = 0.5  # TODO: set as CLI variable
+                g_cuda = torch.Generator(device='cuda')
+                rgb = FractalPerlin2D(shape, resolutions, factors, generator=g_cuda)().cpu().numpy()
+                rgb = (255 * (rgb + 1) / 2).astype(np.uint8)  # [-1.0, 1.0] => [0, 255]
+                image = Image.fromarray(np.stack(rgb, axis=2), 'RGB')
+
+            except ImportError:
+                raise ImportError('pyperlin not found! Install it via "pip install pyperlin"')
 
     return image, starting_image
 
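For reference, the Perlin branch added above amounts to the following standalone sketch (assuming `pyperlin` is installed; the resolutions/factors mirror the hard-coded lacunarity = 2.0 and persistence = 0.5 from the diff, and the CPU fallback is an assumption, since the diff itself always uses CUDA):

```python
import numpy as np
import torch
from PIL import Image
from pyperlin import FractalPerlin2D  # https://github.com/duchesneaumathieu/pyperlin

image_size = 1024
device = 'cuda' if torch.cuda.is_available() else 'cpu'  # assumption: CPU also works; the diff hard-codes 'cuda'

shape = (3, image_size, image_size)                # three "batch" entries, one per RGB channel
resolutions = [(2**i, 2**i) for i in range(1, 7)]  # lacunarity = 2.0
factors = [0.5**i for i in range(6)]               # persistence = 0.5
generator = torch.Generator(device=device)

noise = FractalPerlin2D(shape, resolutions, factors, generator=generator)()  # values in [-1.0, 1.0]
rgb = (255 * (noise.cpu().numpy() + 1) / 2).astype(np.uint8)                 # -> [0, 255]
Image.fromarray(np.stack(rgb, axis=2), 'RGB').save('perlin_start.jpg')
```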
@@ -232,6 +239,7 @@ def style_transfer_discriminator():
 @click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
 # Synthesis options
 @click.option('--seed', type=int, help='Random seed to use', default=0)
+@click.option('--random-image-noise', '-noise', 'image_noise', type=click.Choice(['random', 'perlin']), default='random', show_default=True)
 @click.option('--starting-image', type=str, help='Path to image to start from', default=None)
 @click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)', default=None)
 @click.option('--lr', 'learning_rate', type=float, help='Learning rate', default=1e-2, show_default=True)
@@ -251,6 +259,7 @@ def discriminator_dream(
         ctx: click.Context,
         network_pkl: Union[str, os.PathLike],
         seed: int,
+        image_noise: str,
         starting_image: Union[str, os.PathLike],
         class_idx: Optional[int],  # TODO: conditional model
         learning_rate: float,
@@ -281,7 +290,8 @@ def discriminator_dream(
     available_layers = get_available_layers(max_resolution=model.get_block_resolutions()[0])
 
     # Get the image and image name
-    image, starting_image = get_image(seed=seed, starting_image=starting_image, image_size=model_resolution)
+    image, starting_image = get_image(seed=seed, image_noise=image_noise,
+                                      starting_image=starting_image, image_size=model_resolution)
 
     # Make the run dir in the specified output directory
     desc = 'discriminator-dream-all_layers'
@@ -296,6 +306,7 @@ def discriminator_dream(
         'network_pkl': network_pkl,
         'synthesis_options': {
             'seed': seed,
+            'random_image_noise': image_noise,
             'starting_image': starting_image,
             'class_idx': class_idx,
             'learning_rate': learning_rate,
@@ -378,12 +389,13 @@ def discriminator_dream(
 @click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
 # Synthesis options
 @click.option('--seed', type=int, help='Random seed to use', default=0, show_default=True)
+@click.option('--random-image-noise', 'image_noise', type=click.Choice(['random', 'perlin']), default='random', show_default=True)
 @click.option('--starting-image', type=str, help='Path to image to start from', default=None)
 @click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)', default=None)
 @click.option('--lr', 'learning_rate', type=float, help='Learning rate', default=5e-3, show_default=True)
 @click.option('--iterations', '-it', type=click.IntRange(min=1), help='Number of gradient ascent steps per octave', default=10, show_default=True)
 # Layer options
-@click.option('--layers', type=parse_layers, help='Layers of the Discriminator to use as the features. If None, will default to the output of D.', default=['b16_conv1'], show_default=True)
+@click.option('--layers', type=parse_layers, help='Layers of the Discriminator to use as the features. If None, will default to the output of D.', default='b16_conv0', show_default=True)
 @click.option('--normed', 'norm_model_layers', is_flag=True, help='Add flag to divide the features of each layer of D by its number of elements')
 @click.option('--sqrt-normed', 'sqrt_norm_model_layers', is_flag=True, help='Add flag to divide the features of each layer of D by the square root of its number of elements')
 # Octaves options
@@ -407,16 +419,17 @@ def discriminator_dream_zoom(
         ctx: click.Context,
         network_pkl: Union[str, os.PathLike],
         seed: int,
-        starting_image: Union[str, os.PathLike],
+        image_noise: Optional[str],
+        starting_image: Optional[Union[str, os.PathLike]],
         class_idx: Optional[int],  # TODO: conditional model
         learning_rate: float,
         iterations: int,
         layers: List[str],
-        norm_model_layers: bool,
-        sqrt_norm_model_layers: bool,
+        norm_model_layers: Optional[bool],
+        sqrt_norm_model_layers: Optional[bool],
         num_octaves: int,
         octave_scale: float,
-        unzoom_octave: bool,
+        unzoom_octave: Optional[bool],
         pixel_zoom: int,
         rotation_deg: float,
         translate_x: int,
@@ -442,7 +455,8 @@ def discriminator_dream_zoom(
     model = DiscriminatorFeatures(D).requires_grad_(False).to(device)
 
     # Get the image and image name
-    image, starting_image = get_image(seed=seed, starting_image=starting_image, image_size=model_resolution)
+    image, starting_image = get_image(seed=seed, image_noise=image_noise,
+                                      starting_image=starting_image, image_size=model_resolution)
 
     # Make the run dir in the specified output directory
     desc = 'discriminator-dream-zoom'
@@ -454,6 +468,7 @@ def discriminator_dream_zoom(
         'network_pkl': network_pkl,
         'synthesis_options': {
             'seed': seed,
+            'random_image_noise': image_noise,
             'starting_image': starting_image,
             'class_idx': class_idx,
             'learning_rate': learning_rate,
@@ -516,12 +531,15 @@ def discriminator_dream_zoom(
 
     # Save the final video
     print('Saving video...')
-    stream = ffmpeg.input(os.path.join(run_dir, 'dreamed_*.jpg'), pattern_type='glob', framerate=fps)
+    if os.name == 'nt':  # No glob pattern for Windows
+        stream = ffmpeg.input(os.path.join(run_dir, f'dreamed_%0{n_digits}d.jpg'), framerate=fps)
+    else:
+        stream = ffmpeg.input(os.path.join(run_dir, 'dreamed_*.jpg'), pattern_type='glob', framerate=fps)
     stream = ffmpeg.output(stream, os.path.join(run_dir, 'dream-zoom.mp4'), crf=20, pix_fmt='yuv420p')
     ffmpeg.run(stream, capture_stdout=True, capture_stderr=True)  # I dislike ffmpeg's console logs, so I turn them off
 
+    # Save the reversed video apart from the original one, so the user can compare both
     if reverse_video:
-        # Save the reversed video apart from the original one, so the user can compare both
         stream = ffmpeg.input(os.path.join(run_dir, 'dream-zoom.mp4'))
         stream = stream.video.filter('reverse')
         stream = ffmpeg.output(stream, os.path.join(run_dir, 'dream-zoom_reversed.mp4'), crf=20, pix_fmt='yuv420p')
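Note that the Windows branch above assumes the dream frames were written with zero-padded, sequential names so ffmpeg's printf-style pattern can find them; `n_digits` is defined elsewhere in the file and is not part of this diff. A hedged sketch of the naming scheme it expects (values and the padding rule are illustrative):

```python
import os

run_dir = 'out/discriminator-dream-zoom'  # illustrative path
total_frames = 250                        # illustrative frame count
n_digits = len(str(total_frames))         # assumption: enough digits to zero-pad every frame index

# Frames saved like this are matched by f'dreamed_%0{n_digits}d.jpg' on Windows
# and by the 'dreamed_*.jpg' glob pattern everywhere else.
frame_names = [os.path.join(run_dir, f'dreamed_{i:0{n_digits}d}.jpg') for i in range(total_frames)]
print(frame_names[0], frame_names[-1])    # e.g. .../dreamed_000.jpg and .../dreamed_249.jpg
```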

torch_utils/gen_utils.py

+1 -0
@@ -238,6 +238,7 @@ def interpolate(
     if smooth:
         # Smooth out the interpolation with a polynomial of order 3 (cubic function f)
         # Constructed f by setting f'(0) = f'(1) = 0, and f(0) = 0, f(1) = 1 => f(t) = -2t^3+3t^2 = t^2 (3-2t)
+        # NOTE: I've merely rediscovered the Smoothstep function S_1(x): https://en.wikipedia.org/wiki/Smoothstep
         t_array = t_array ** 2 * (3 - 2 * t_array)  # One line thanks to NumPy arrays
     # TODO: this might be possible to optimize by using the fact they're numpy arrays, but haven't found a nice way yet
     funcs_dict = {'linear': lerp, 'spherical': slerp}
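The `NOTE` added above identifies the easing curve as smoothstep; here is a tiny standalone check (not part of the commit) of the endpoint conditions the comment derives, f(0) = 0, f(1) = 1, f'(0) = f'(1) = 0, for f(t) = 3t^2 - 2t^3:

```python
import numpy as np

def smoothstep(t):
    """S_1(t) = 3t^2 - 2t^3 = t^2 (3 - 2t), the cubic used above to ease the interpolation steps."""
    return t ** 2 * (3 - 2 * t)

def smoothstep_derivative(t):
    """f'(t) = 6t - 6t^2 = 6t(1 - t), which vanishes at both endpoints."""
    return 6 * t * (1 - t)

t = np.array([0.0, 1.0])
assert np.allclose(smoothstep(t), [0.0, 1.0])             # f(0) = 0, f(1) = 1
assert np.allclose(smoothstep_derivative(t), [0.0, 0.0])  # f'(0) = f'(1) = 0
```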
