Skip to content

Commit 2e2f87a

Browse files
committed
allow for config driven creation of imagen off the bat, also researcher must always give list of image resolutions this time around
1 parent fa3fde7 commit 2e2f87a

File tree

3 files changed

+60
-16
lines changed

3 files changed

+60
-16
lines changed

imagen_pytorch/configs.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import json
2+
from pydantic import BaseModel, validator, root_validator
3+
from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
4+
from imagen_pytorch.imagen_pytorch import Imagen, Unet
5+
6+
# helper functions
7+
8+
def exists(val):
9+
return val is not None
10+
11+
def default(val, d):
12+
return val if exists(val) else d
13+
14+
def ListOrTuple(inner_type):
15+
return Union[List[inner_type], Tuple[inner_type]]
16+
17+
# imagen pydantic classes
18+
19+
class UnetConfig(BaseModel):
20+
dim: int
21+
dim_mults: ListOrTuple(int)
22+
text_embed_dim: int = 1024
23+
cond_dim: int = None
24+
channels: int = 3
25+
attn_dim_head: int = 32
26+
attn_heads: int = 16
27+
28+
class Config:
29+
extra = "allow"
30+
31+
class ImagenConfig(BaseModel):
32+
unets: ListOrTuple(UnetConfig)
33+
image_sizes: ListOrTuple(int)
34+
channels: int = 3
35+
timesteps: int = 1000
36+
loss_type: str = 'l2'
37+
beta_schedule: str = 'cosine'
38+
learned_variance: bool = True
39+
cond_drop_prob: float = 0.5
40+
41+
@validator('image_sizes')
42+
def check_image_sizes(cls, image_sizes, values):
43+
unets = values.get('unets')
44+
if len(image_sizes) != len(unets):
45+
raise ValueError(f'image sizes length {len(image_sizes)} must be equivalent to the number of unets {len(unets)}')
46+
return image_sizes
47+
48+
def create(self):
49+
decoder_kwargs = self.dict()
50+
unet_configs = decoder_kwargs.pop('unets')
51+
unets = [Unet(**config) for config in unet_configs]
52+
return Imagen(unets, **decoder_kwargs)
53+
54+
class Config:
55+
extra = "allow"

imagen_pytorch/imagen_pytorch.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -620,7 +620,7 @@ def __init__(
620620
self,
621621
dim,
622622
*,
623-
image_embed_dim = None,
623+
image_embed_dim = 1024,
624624
text_embed_dim = 512,
625625
cond_dim = None,
626626
num_image_tokens = 4,
@@ -940,14 +940,13 @@ def __init__(
940940
self,
941941
unets,
942942
*,
943+
image_sizes, # for cascading ddpm, image size at each stage
943944
text_encoder_name = 't5-small',
944-
image_size = None,
945945
channels = 3,
946946
timesteps = 1000,
947947
cond_drop_prob = 0.1,
948948
loss_type = 'l2',
949949
beta_schedule = 'cosine',
950-
image_sizes = None, # for cascading ddpm, image size at each stage
951950
random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
952951
lowres_sample_noise_level = 0.2, # in the paper, they present a new trick where they noise the lowres conditioning image, and at sample time, fix it to a certain level (0.1 or 0.3) - the unets are also made to be conditioned on this noise level
953952
condition_on_text = True,
@@ -966,14 +965,6 @@ def __init__(
966965
self.condition_on_text = condition_on_text
967966
self.unconditional = not condition_on_text
968967

969-
# determine image size, with image_size and image_sizes taking precedence
970-
971-
if exists(image_size) or exists(image_sizes):
972-
assert exists(image_size) ^ exists(image_sizes), 'only one of image_size or image_sizes must be given'
973-
image_size = default(image_size, lambda: image_sizes[-1])
974-
else:
975-
raise Error('either image_size or image sizes must be given to imagen')
976-
977968
# channels
978969

979970
self.channels = channels
@@ -1016,11 +1007,8 @@ def __init__(
10161007

10171008
# unet image sizes
10181009

1019-
image_sizes = default(image_sizes, (image_size,))
1020-
image_sizes = tuple(sorted(set(image_sizes)))
1021-
10221010
assert len(self.unets) == len(image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {image_sizes}'
1023-
self.image_sizes = image_sizes
1011+
self.image_sizes = cast_tuple(image_sizes)
10241012
self.sample_channels = cast_tuple(self.channels, len(image_sizes))
10251013

10261014
# random crop sizes (for super-resoluting unets at the end of cascade?)

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'imagen-pytorch',
55
packages = find_packages(exclude=[]),
6-
version = '0.0.12',
6+
version = '0.0.15',
77
license='MIT',
88
description = 'Imagen - unprecedented photorealism × deep level of language understanding',
99
author = 'Phil Wang',
@@ -22,6 +22,7 @@
2222
'einops-exts',
2323
'kornia',
2424
'numpy',
25+
'pydantic',
2526
'resize-right',
2627
'torch>=1.6',
2728
'torchvision',

0 commit comments

Comments
 (0)