Skip to content

Commit 9c4f1e8

Browse files
committed
Add save GIF animation feature
1 parent f016f9c commit 9c4f1e8

File tree

4 files changed

+26
-2
lines changed

4 files changed

+26
-2
lines changed

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,9 @@ FLAGS
168168
--theta_hidden=THETA_INITIAL
169169
Default: 30.0
170170
Hyperparameter describing the frequency of the color space. Only applies to the hidden layers of the network.
171+
--save_gif=SAVE_GIF
172+
Default: False
173+
Wether or not to save a GIF animation of the generation procedure. Only works if save_progress is set to True.
171174
```
172175
173176
### Priming

deep_daze/cli.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def train(
3333
create_story=False,
3434
story_start_words=5,
3535
story_words_per_epoch=5,
36+
save_gif=False
3637
):
3738
"""
3839
:param text: (required) A phrase less than 77 characters which you would like to visualize.
@@ -62,6 +63,7 @@ def train(
6263
:param create_story: Creates a story by optimizing each epoch on a new sliding-window of the input words. If this is enabled, much longer texts than 77 chars can be used. Requires save_progress to visualize the transitions of the story.
6364
:param story_start_words: Only used if create_story is True. How many words to optimize on for the first epoch.
6465
:param story_words_per_epoch: Only used if create_story is True. How many words to add to the optimization goal per epoch after the first one.
66+
:param save_gif: Only used if save_progress is True. Saves a GIF animation of the generation procedure using the saved frames.
6567
"""
6668
# Don't instantiate imagine if the user just wants help.
6769
if any("--help" in arg for arg in sys.argv):
@@ -95,7 +97,8 @@ def train(
9597
saturate_bound=saturate_bound,
9698
create_story=create_story,
9799
story_start_words=story_start_words,
98-
story_words_per_epoch=story_words_per_epoch
100+
story_words_per_epoch=story_words_per_epoch,
101+
save_gif=save_gif
99102
)
100103

101104
print('Starting up...')

deep_daze/deep_daze.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,13 @@
1313
from torch_optimizer import DiffGrad, AdamP
1414

1515
from PIL import Image
16+
from imageio import imread, mimsave
1617
import torchvision.transforms as T
1718
#from torchvision.utils import save_image
1819

1920
from tqdm import trange, tqdm
2021

21-
from deep_daze.clip import load, tokenize
22+
from clip import load, tokenize
2223

2324
assert torch.cuda.is_available(), 'CUDA must be available in order to use Deep Daze'
2425

@@ -223,6 +224,7 @@ def __init__(
223224
create_story=False,
224225
story_start_words=5,
225226
story_words_per_epoch=5,
227+
save_gif=False
226228
):
227229

228230
super().__init__()
@@ -293,6 +295,8 @@ def __init__(
293295

294296
image_tensor = self.clip_img_transform(image)[None, ...].cuda()
295297
self.start_image = image_tensor
298+
299+
self.save_gif = save_gif
296300

297301
def create_clip_encoding(self, text=None, img=None, encoding=None):
298302
self.text = text
@@ -410,6 +414,15 @@ def save_image(self, epoch, iteration, img=None):
410414

411415
tqdm.write(f'image updated at "./{str(self.filename)}"')
412416

417+
def generate_gif(self):
418+
images = []
419+
for file_name in sorted(os.listdir('./')):
420+
if file_name.startswith(self.textpath) and file_name != f'{self.textpath}.jpg':
421+
images.append(imread(os.path.join('./', file_name)))
422+
423+
mimsave(f'{self.textpath}.gif', images)
424+
print(f'Generated image generation animation at ./{self.textpath}.gif')
425+
413426
def forward(self):
414427
if exists(self.start_image):
415428
tqdm.write('Preparing with initial image...')
@@ -453,3 +466,7 @@ def forward(self):
453466
self.clip_encoding = self.update_story_encoding(epoch, i)
454467

455468
self.save_image(epoch, i) # one final save at end
469+
470+
if self.save_gif and self.save_progress:
471+
self.generate_gif()
472+

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
'einops>=0.3',
3131
'fire',
3232
'ftfy',
33+
'imageio>=2.9.0',
3334
'siren-pytorch>=0.0.8',
3435
'torch>=1.7.1',
3536
'torch_optimizer',

0 commit comments

Comments
 (0)