Skip to content

Commit 81264c2

Browse files
committed
version, guidance cache fix
1 parent 5746ee2 commit 81264c2

File tree

5 files changed

+44
-33
lines changed

5 files changed

+44
-33
lines changed

examples/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ def get_config(dataset_name: str) -> ConfigDict:
186186
# Model
187187
config.model = model = ConfigDict()
188188
model.img_size = data.img_size
189-
model.in_channels = data.n_channels
189+
model.n_channels = data.n_channels
190190
model.patch_size = 4
191191
model.channels = {"CIFAR10" : 512, "MNIST" : 128, "FLOWERS" : 512}[dataset_name]
192192
model.y_dim = {"CIFAR10" : 1, "MNIST" : 1, "FLOWERS" : 1}[dataset_name]

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "transformer-flows"
3-
version = "0.0.11"
3+
version = "0.0.12"
44
description = "Implementation of Transformer Flows (Apple ML) in JAX and Equinox."
55
readme = "README.md"
66
authors = [

src/transformer_flows/attention.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
typecheck = jaxtyped(typechecker=typechecker)
1616

17+
KQVCacheType = Literal["conditional", "unconditional"] # Guidance caches
18+
1719

1820
@typecheck
1921
def standard_attention(
@@ -245,9 +247,9 @@ def _make_autoregressive_cache(**_):
245247
else:
246248
_int = jnp.int32
247249

248-
# return jnp.empty(key_shape), jnp.empty(value_shape), jnp.zeros((), _int)
249250
initial_cache = (jnp.empty(key_shape), jnp.empty(value_shape), jnp.zeros((), _int))
250-
return dict(uncond=initial_cache, cond=initial_cache)
251+
252+
return dict(unconditional=initial_cache, conditional=initial_cache)
251253

252254
query_proj_out_size = qk_size
253255
key_proj_out_size = qk_size
@@ -317,7 +319,7 @@ def __call__(
317319
*,
318320
key: Optional[PRNGKeyArray] = None,
319321
temperature: Optional[float] = 1.,
320-
which_cache: Literal["cond", "uncond"],
322+
which_cache: KQVCacheType,
321323
inference: Optional[bool] = None,
322324
deterministic: Optional[bool] = None,
323325
process_heads: Optional[
@@ -391,20 +393,23 @@ def __call__(
391393
causal_mask_offset = index # Offset shifts attention lower-tril
392394
index = index + kv_seq_length # i -> i + 1, nudging autoregression
393395

394-
other_cache = "cond" if which_cache == "uncond" else "uncond"
395-
empty_cache = jax.tree.map(
396-
lambda x: jnp.zeros_like(x), (key_state, value_state, index)
397-
)
396+
if which_cache == "unconditional":
397+
other_cache = "conditional"
398+
else:
399+
other_cache = "unconditional"
400+
401+
# empty_cache = jax.tree.map(
402+
# lambda x: jnp.zeros_like(x), (key_state, value_state, index)
403+
# )
404+
398405
state = state.set(
399406
self.autoregressive_index,
400-
{which_cache : (key_state, value_state, index), other_cache : empty_cache}
407+
{
408+
which_cache : (key_state, value_state, index),
409+
other_cache : state.get(self.autoregressive_index)[other_cache] # empty_cache
410+
}
401411
)
402412

403-
# if sample:
404-
# state = state.set(
405-
# self.autoregressive_index, (key_state, value_state, index)
406-
# )
407-
408413
# The keys and values stack the preceding keys and values,
409414
# key-value sequence length updated; masking adopts this
410415
key_heads = key_state

src/transformer_flows/transformer_flow.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import matplotlib.pyplot as plt
2323
from tqdm.auto import trange
2424

25-
from .attention import MultiheadAttention, self_attention
25+
from .attention import MultiheadAttention, self_attention, KQVCacheType
2626

2727

2828
if os.getenv("TYPECHECK", "").lower() in ["1", "true"]:
@@ -42,8 +42,6 @@
4242

4343
NoiseType = Union[Literal["gaussian", "uniform"], None]
4444

45-
CacheType = Literal["conditional", "unconditional"] # Guidance caches
46-
4745
MaskArray = Union[
4846
Float[Array, "s s"], Int[Array, "s s"], Bool[Array, "s s"]
4947
]
@@ -123,23 +121,27 @@ def clear_and_get_results_dir(
123121
run_dir: Optional[Path] = None,
124122
clear_old: bool = False
125123
) -> Path:
124+
126125
if not exists(run_dir):
127126
run_dir = Path.cwd()
128127

129128
# Image save directories
130129
imgs_dir = run_dir / "imgs" / dataset_name.lower()
131130

131+
# Clear old ones
132132
if clear_old:
133-
rmtree(str(imgs_dir), ignore_errors=True) # Clear old ones
133+
rmtree(str(imgs_dir), ignore_errors=True)
134134

135135
if not imgs_dir.exists():
136-
137136
imgs_dir.mkdir(exist_ok=True, parents=True)
138137

139-
for _dir in ["samples", "warps", "latents"]:
140-
(imgs_dir / _dir).mkdir(exist_ok=True)
138+
# Image type directories
139+
for _dir in ["samples", "warps", "latents"]:
140+
(imgs_dir / _dir).mkdir(exist_ok=True, parents=True)
141141

142-
return imgs_dir
142+
print("Saving samples in:\n\t", imgs_dir)
143+
144+
return imgs_dir
143145

144146

145147
def count_parameters(model: eqx.Module) -> int:
@@ -289,6 +291,7 @@ def __init__(
289291
key: PRNGKeyArray
290292
):
291293
key_weight, key_bias = jr.split(key)
294+
292295
l = math.sqrt(1. / in_size)
293296
dtype = default(dtype, jnp.float32)
294297

@@ -432,7 +435,7 @@ def __call__(
432435
mask: Optional[Union[MaskArray, Literal["causal"]]],
433436
state: Optional[eqx.nn.State],
434437
*,
435-
which_cache: CacheType,
438+
which_cache: KQVCacheType,
436439
attention_temperature: Optional[float] = 1.
437440
) -> Tuple[
438441
Float[Array, "#s q"], Optional[eqx.nn.State] # Autoregression
@@ -480,6 +483,7 @@ def __init__(
480483
key: PRNGKeyArray
481484
):
482485
keys = jr.split(key, 3)
486+
483487
self.y_dim = y_dim
484488
self.conditioning_type = conditioning_type
485489

@@ -575,7 +579,7 @@ def __call__(
575579
] = None,
576580
state: Optional[eqx.nn.State] = None, # No state during forward pass
577581
*,
578-
which_cache: CacheType = "conditional",
582+
which_cache: KQVCacheType = "conditional",
579583
attention_temperature: Optional[float] = 1.
580584
) -> Union[
581585
Float[Array, "#{self.n_patches} {self.sequence_dim}"],
@@ -613,9 +617,10 @@ def __init__(
613617
sequence_length: int
614618
):
615619
self.permute = permute # Flip if true else pass
616-
assert jnp.isscalar(self.permute)
617620
self.sequence_length = sequence_length
618621

622+
assert jnp.isscalar(self.permute)
623+
619624
@property
620625
def permute_idx(self):
621626
permute = maybe_stop_grad(self.permute, stop=True)
@@ -786,7 +791,7 @@ def reverse_step(
786791
s: Int[Array, ""],
787792
state: eqx.nn.State,
788793
*,
789-
which_cache: CacheType = "conditional",
794+
which_cache: KQVCacheType = "conditional",
790795
attention_temperature: Optional[float] = 1.
791796
) -> Tuple[
792797
Float[Array, "1 {self.sequence_dim}"],
@@ -847,7 +852,7 @@ def reverse(
847852
],
848853
state: eqx.nn.State,
849854
*,
850-
which_cache: CacheType = "conditional",
855+
which_cache: KQVCacheType = "conditional",
851856
guidance: float = 0.,
852857
attention_temperature: Optional[float] = 1.0,
853858
guide_what: Optional[Literal["ab", "a", "b"]] = "ab",
@@ -942,7 +947,7 @@ class TransformerFlow(eqx.Module):
942947
@typecheck
943948
def __init__(
944949
self,
945-
in_channels: int,
950+
n_channels: int,
946951
img_size: int,
947952
patch_size: int,
948953
channels: int,
@@ -958,11 +963,11 @@ def __init__(
958963
key: PRNGKeyArray
959964
):
960965
self.img_size = img_size
961-
self.n_channels = in_channels
966+
self.n_channels = n_channels
962967

963968
self.patch_size = patch_size
964969
self.n_patches = int(img_size / patch_size) ** 2
965-
self.sequence_dim = in_channels * patch_size ** 2
970+
self.sequence_dim = n_channels * patch_size ** 2
966971
self.n_blocks = n_blocks
967972

968973
self.y_dim = y_dim
@@ -1785,6 +1790,7 @@ def filter_spikes(l: list, loss_max: float = 10.0) -> list[float]:
17851790
plt.savefig(imgs_dir / "losses.png", bbox_inches="tight")
17861791
plt.close()
17871792

1788-
save_fn(model=ema_model if use_ema else model)
1793+
if exists(save_fn):
1794+
save_fn(model=ema_model if use_ema else model)
17891795

17901796
return model

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)