
Commit 5746ee2

Commit message: version
1 parent 44f27e9 commit 5746ee2

6 files changed: 1356 additions & 22 deletions


.gitignore

Lines changed: 3 additions & 1 deletion
@@ -11,4 +11,6 @@ __pycache__/
 tests/
 grfs.py
 guidance.py
-test.yml
+test.yml
+.ipynb_checkpoints/
+grfs.ipynb

README.md

Lines changed: 0 additions & 2 deletions
@@ -29,8 +29,6 @@ pip install transformer-flows
 uv run --all-extras python examples/main.py
 ```
 
-```
-
 #### Samples
 
 I haven't optimised anything here (the authors mention varying the variance of noise used to dequantise the images), nor have I trained for very long. You can see slight artifacts due to the dequantisation noise.

examples/main.py

Lines changed: 2 additions & 5 deletions
@@ -1,4 +1,3 @@
-
 import math
 from dataclasses import dataclass
 from pathlib import Path

@@ -264,9 +263,7 @@ def get_config(dataset_name: str) -> ConfigDict:
 
     key_model, key_train, key_data = jr.split(key, 3)
 
-    model, _ = eqx.nn.make_with_state(TransformerFlow)(**config.model, key=key_model)
-
-    get_state_fn = partial(get_sample_state, config=config, key=key_model)
+    model, state = eqx.nn.make_with_state(TransformerFlow)(**config.model, key=key_model)
 
     sharding, replicated_sharding = get_shardings()
 

@@ -287,6 +284,7 @@ def get_config(dataset_name: str) -> ConfigDict:
     dataset,
     # Model
     model,
+    state,
     eps_sigma=config.train.eps_sigma,
     noise_type=config.train.noise_type,
     # Data

@@ -316,7 +314,6 @@ def get_config(dataset_name: str) -> ConfigDict:
     # Other
     cmap=config.train.cmap,
     policy=policy,
-    get_state_fn=get_state_fn,
     sharding=sharding,
     replicated_sharding=replicated_sharding,
     imgs_dir=imgs_dir,
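
The change above drops the `get_state_fn` indirection: `eqx.nn.make_with_state` already returns the model together with its `eqx.nn.State`, and that state object is now passed straight into `train`. A minimal sketch of the Equinox pattern, using a made-up `Counter` module that is not part of this repository:

```python
# Minimal sketch of the eqx.nn.make_with_state pattern adopted above.
# The Counter module is illustrative only, not from this repository.
import equinox as eqx
import jax.numpy as jnp


class Counter(eqx.Module):
    index: eqx.nn.StateIndex

    def __init__(self):
        # StateIndex registers a piece of mutable state (here, a step count).
        self.index = eqx.nn.StateIndex(jnp.array(0))

    def __call__(self, state: eqx.nn.State):
        count = state.get(self.index)
        return count, state.set(self.index, count + 1)


# Construction returns (module, initial state); the state is then threaded
# explicitly through every call, just as `state` is now threaded into train().
counter, state = eqx.nn.make_with_state(Counter)()
value, state = counter(state)
```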

pyproject.toml

Lines changed: 6 additions & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "transformer-flows"
-version = "0.0.10"
+version = "0.0.11"
 description = "Implementation of Transformer Flows (Apple ML) in JAX and Equinox."
 readme = "README.md"
 authors = [

@@ -27,3 +27,8 @@ examples = [
 [build-system]
 requires = ["hatchling"]
 build-backend = "hatchling.build"
+
+[dependency-groups]
+dev = [
+    "jupyter>=1.1.1",
+]

src/transformer_flows/transformer_flow.py

Lines changed: 27 additions & 13 deletions
@@ -1,3 +1,4 @@
+import os
 import math
 import dataclasses
 from pathlib import Path

@@ -24,8 +25,10 @@
 from .attention import MultiheadAttention, self_attention
 
 
-typecheck = jaxtyped(typechecker=typechecker)
-
+if os.getenv("TYPECHECK", "").lower() in ["1", "true"]:
+    typecheck = jaxtyped(typechecker=typechecker)
+else:
+    typecheck = lambda _: _
 
 MetricsDict = dict[
     str, Union[Scalar, Float[Array, "..."]]
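
Runtime shape/type checking is now opt-in via the `TYPECHECK` environment variable. A sketch of how the toggle behaves, assuming beartype as the `typechecker` (its import sits outside this hunk) and a made-up `add` function:

```python
# Sketch of the env-gated typecheck decorator above. Assumes beartype as the
# typechecker (the real import is outside this hunk); `add` is illustrative only.
# Enable checking with e.g. `TYPECHECK=1 python examples/main.py`.
import os

import jax.numpy as jnp
from beartype import beartype as typechecker
from jaxtyping import Array, Float, jaxtyped

if os.getenv("TYPECHECK", "").lower() in ["1", "true"]:
    typecheck = jaxtyped(typechecker=typechecker)
else:
    typecheck = lambda fn: fn  # identity decorator: no runtime checks


@typecheck
def add(x: Float[Array, "n"], y: Float[Array, "n"]) -> Float[Array, "n"]:
    return x + y


add(jnp.ones(3), jnp.ones(3))          # always fine
# add(jnp.ones(3), jnp.ones((3, 3)))   # raises only when TYPECHECK is enabled
```
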
@@ -39,6 +42,8 @@
 
 NoiseType = Union[Literal["gaussian", "uniform"], None]
 
+CacheType = Literal["conditional", "unconditional"] # Guidance caches
+
 MaskArray = Union[
     Float[Array, "s s"], Int[Array, "s s"], Bool[Array, "s s"]
 ]

@@ -47,7 +52,7 @@
     Union[Float[Array, "..."], Int[Array, "..."]] # Flattened regardless
 ]
 
-OptState = PyTree | optax.OptState
+OptState = Union[PyTree, optax.OptState]
 
 
 def exists(v):
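
The `Union` spelling presumably keeps the alias usable on Python versions below 3.10, where the `X | Y` union syntax is unavailable at runtime (the project's minimum Python version is not shown in this diff). A tiny illustration with arbitrary types:

```python
from typing import Union

# `int | str` only evaluates at runtime on Python >= 3.10;
# Union[int, str] means the same thing and works on older interpreters too.
IntOrStr = Union[int, str]
```
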
@@ -193,8 +198,10 @@ def shard_batch(
     Tuple[Float[Array, "n ..."], Float[Array, "n ..."]],
     Float[Array, "n ..."]
 ]:
+
     if sharding:
         batch = eqx.filter_shard(batch, sharding)
+
     return batch
 
 

@@ -205,13 +212,19 @@ def shard_model(
 ) -> Union[eqx.Module, Tuple[eqx.Module, OptState]]:
     if sharding:
         model = eqx.filter_shard(model, sharding)
+
         if opt_state:
+
             opt_state = eqx.filter_shard(opt_state, sharding)
+
             return model, opt_state
+
         return model
     else:
         if opt_state:
+
             return model, opt_state
+
         return model
 
 
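These helpers rely on `eqx.filter_shard` to place the array leaves of the model (and optionally the optimiser state) on a sharding while leaving static fields alone. A rough sketch of that call; the mesh, axis name, and MLP below are assumptions for illustration, not this repository's configuration:

```python
# Rough sketch of eqx.filter_shard as used in shard_model above; the mesh,
# axis name, and MLP are assumptions, not this repository's setup.
import jax
import numpy as np
import equinox as eqx
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
replicated = NamedSharding(mesh, PartitionSpec())  # replicate across all devices

model = eqx.nn.MLP(in_size=4, out_size=2, width_size=32, depth=2, key=jax.random.key(0))
model = eqx.filter_shard(model, replicated)  # array leaves moved, static fields untouched
```
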
@@ -419,7 +432,7 @@ def __call__(
     mask: Optional[Union[MaskArray, Literal["causal"]]],
     state: Optional[eqx.nn.State],
     *,
-    which_cache: Literal["cond", "uncond"],
+    which_cache: CacheType,
     attention_temperature: Optional[float] = 1.
 ) -> Tuple[
     Float[Array, "#s q"], Optional[eqx.nn.State] # Autoregression

@@ -562,7 +575,7 @@ def __call__(
     ] = None,
     state: Optional[eqx.nn.State] = None, # No state during forward pass
     *,
-    which_cache: Literal["cond", "uncond"] = "cond",
+    which_cache: CacheType = "conditional",
     attention_temperature: Optional[float] = 1.
 ) -> Union[
     Float[Array, "#{self.n_patches} {self.sequence_dim}"],

@@ -773,7 +786,7 @@ def reverse_step(
     s: Int[Array, ""],
     state: eqx.nn.State,
     *,
-    which_cache: Literal["cond", "uncond"] = "cond",
+    which_cache: CacheType = "conditional",
     attention_temperature: Optional[float] = 1.
 ) -> Tuple[
     Float[Array, "1 {self.sequence_dim}"],

@@ -834,7 +847,7 @@ def reverse(
     ],
     state: eqx.nn.State,
     *,
-    which_cache: Literal["cond", "uncond"] = "cond",
+    which_cache: CacheType = "conditional",
     guidance: float = 0.,
     attention_temperature: Optional[float] = 1.0,
     guide_what: Optional[Literal["ab", "a", "b"]] = "ab",
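
For context, the two cache names track the conditional and unconditional passes used when sampling with guidance (`guidance: float = 0.` above). The repository's exact guidance rule, and what `guide_what` selects, are not shown in this diff; the sketch below is only the generic classifier-free-guidance blend, with placeholder names:

```python
import jax.numpy as jnp


def guided(pred_cond: jnp.ndarray, pred_uncond: jnp.ndarray, guidance: float) -> jnp.ndarray:
    # Generic classifier-free guidance blend (placeholder names, not this repo's API).
    # guidance == 0.0 recovers the purely conditional prediction.
    return pred_cond + guidance * (pred_cond - pred_uncond)
```
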
@@ -875,7 +888,7 @@ def _autoregression_step(
     pos_embed,
     s,
     state=state,
-    which_cache="uncond",
+    which_cache="unconditional",
     attention_temperature=attention_temperature,
 )
 

@@ -1155,6 +1168,7 @@ def reverse(
 
 def _block_step(z_s_sequence, params__state):
     z, s, sequence = z_s_sequence
+
     params, state = params__state
     block = eqx.combine(params, struct)
 

@@ -1173,6 +1187,7 @@ def _block_step(z_s_sequence, params__state):
     return (z, s + 1, sequence), None
 
 sequence = jnp.zeros((self.n_blocks + 1, self.n_channels, self.img_size, self.img_size), dtype=z.dtype)
+
 sequence = sequence.at[0].set(self.unpatchify(z))
 
 (z, _, sequence), _ = jax.lax.scan(
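
The `_block_step`/`jax.lax.scan` code above iterates the flow blocks while writing each intermediate image into the preallocated `sequence` buffer. A toy version of that carry-and-record pattern, with made-up names and a trivial stand-in for a block:

```python
# Toy version of the scan-with-recorded-intermediates pattern above.
# Names and the per-step computation are made up for illustration.
import jax
import jax.numpy as jnp


def step(carry, block_scale):
    z, s, sequence = carry
    z = z * block_scale                    # stand-in for one flow block's reverse pass
    sequence = sequence.at[s + 1].set(z)   # record the intermediate result
    return (z, s + 1, sequence), None


n_blocks, dim = 4, 8
z0 = jnp.ones((dim,))
sequence = jnp.zeros((n_blocks + 1, dim)).at[0].set(z0)
scales = jnp.linspace(0.9, 1.2, n_blocks)  # stacked per-block "parameters"

(z_final, _, sequence), _ = jax.lax.scan(step, (z0, 0, sequence), scales)
```
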
@@ -1512,6 +1527,7 @@ def train(
     dataset: Dataset,
     # Model
     model: TransformerFlow,
+    state: eqx.nn.State,
     eps_sigma: Optional[float],
     noise_type: NoiseType,
     # Data

@@ -1539,7 +1555,6 @@ def train(
     n_sample: Optional[int] = 4,
     n_warps: Optional[int] = 1,
     denoise_samples: bool = False,
-    get_state_fn: Callable[[None], eqx.nn.State] = None,
     cmap: Optional[str] = None,
     # Sharding: data and model
     sharding: Optional[NamedSharding] = None,

@@ -1550,7 +1565,7 @@ def train(
 
     print("n_params={:.3E}".format(count_parameters(model)))
 
-    key_data, valid_key, sample_key, *loader_keys = jr.split(key, 5)
+    valid_key, sample_key, *loader_keys = jr.split(key, 4)
 
     # Optimiser & scheduler
     n_steps_per_epoch = int((dataset.x_train.shape[0] + dataset.x_valid.shape[0]) / batch_size)

@@ -1691,7 +1706,7 @@ def train(
     ema_model if use_ema else model,
     z,
     y,
-    state=get_state_fn(),
+    state=state,
     guidance=guidance,
     denoise_samples=denoise_samples,
     sharding=sharding,

@@ -1722,7 +1737,7 @@ def train(
     ema_model if use_ema else model,
     z,
     y,
-    state=get_state_fn(),
+    state=state,
     return_sequence=True,
     denoise_samples=denoise_samples,
     sharding=sharding,

@@ -1744,7 +1759,6 @@ def train(
     plt.savefig(imgs_dir / "warps/warps_{:05d}.png".format(i), bbox_inches="tight")
     plt.close()
 
-
     # Losses and metrics
     if i > 0:
 