Commit d9b9169

guidance, uv
1 parent efaed8b commit d9b9169

File tree

6 files changed: +2598 -72 lines changed


.gitignore

Lines changed: 2 additions & 1 deletion
@@ -9,4 +9,5 @@ imgs/
 data/
 __pycache__/
 tests/
-grfs.py
+grfs.py
+guidance.py

.python-version

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+3.11

attention.py

Lines changed: 22 additions & 10 deletions
@@ -1,7 +1,7 @@
 import functools as ft
 import math
 import warnings
-from typing import Callable, Literal, Optional, Tuple, Union
+from typing import Callable, Literal, Optional, Tuple, Union, Dict
 
 import equinox as eqx
 import jax
@@ -153,11 +153,15 @@ class MultiheadAttention(eqx.Module):
     value_proj: Linear
     output_proj: Linear
     dropout: Dropout
+
     autoregressive_index: StateIndex[
-        Tuple[
-            Float[Array, "S H QK"] | Float[Array, "S QK"],
-            Float[Array, "S H VO"] | Float[Array, "S VO"],
-            Int[Array, ""],
+        Dict[
+            str,
+            Tuple[
+                Float[Array, "S H QK"] | Float[Array, "S QK"],
+                Float[Array, "S H VO"] | Float[Array, "S VO"],
+                Int[Array, ""],
+            ]
         ]
     ]
 
@@ -241,9 +245,9 @@ def _make_autoregressive_cache(**_):
             else:
                 _int = jnp.int32
 
-            return jnp.empty(key_shape), jnp.empty(value_shape), jnp.zeros((), _int)
-            # initial_cache = (jnp.empty(key_shape), jnp.empty(value_shape), jnp.zeros((), _int))
-            # return dict(uncond=initial_cache, cond=initial_cache)
+            # return jnp.empty(key_shape), jnp.empty(value_shape), jnp.zeros((), _int)
+            initial_cache = (jnp.empty(key_shape), jnp.empty(value_shape), jnp.zeros((), _int))
+            return dict(uncond=initial_cache, cond=initial_cache)
 
         query_proj_out_size = qk_size
         key_proj_out_size = qk_size
@@ -312,6 +316,8 @@ def __call__(
         state: Optional[State] = None,
         *,
         key: Optional[PRNGKeyArray] = None,
+        temperature: Optional[float] = 1.,
+        which_cache: Literal["cond", "uncond"],
         inference: Optional[bool] = None,
         deterministic: Optional[bool] = None,
         process_heads: Optional[
@@ -372,7 +378,7 @@ def __call__(
         if state is None:
             causal_mask_offset = 0
         else:
-            key_state, value_state, index = state.get(self.autoregressive_index)
+            key_state, value_state, index = state.get(self.autoregressive_index)[which_cache]
 
             # If the index is larger than state length, it will wrap around and start from zero
             key_state = lax.dynamic_update_slice_in_dim(
@@ -385,8 +391,13 @@ def __call__(
             causal_mask_offset = index  # Offset shifts attention lower-tril
             index = index + kv_seq_length  # i -> i + 1, nudging autoregression
 
+            other_cache = "cond" if which_cache == "uncond" else "uncond"
+            empty_cache = jax.tree.map(
+                lambda x: jnp.zeros_like(x), (key_state, value_state, index)
+            )
             state = state.set(
-                self.autoregressive_index, (key_state, value_state, index)
+                self.autoregressive_index,
+                {which_cache: (key_state, value_state, index), other_cache: empty_cache}
             )
 
             # if sample:
@@ -429,6 +440,7 @@ def __call__(
             self.dropout,
             inference,
             attn_bias=self.attn_bias,
+            scale_factor=self.scale_factor if temperature is None else temperature,
             keys=keys,
         )
 
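The attention.py changes above split the autoregressive KV cache into separate "cond" and "uncond" entries and let a temperature override the attention scale factor, which is the plumbing for classifier-free guidance during autoregressive sampling. Below is a minimal sketch of how a sampling step might drive the two caches. It assumes a hypothetical wrapper model that threads which_cache and temperature down to its MultiheadAttention layers and returns (logits, updated_state); the names sample_step, cond, null_cond, and guidance_scale are illustrative and not part of this commit.

def sample_step(model, state, tokens, cond, null_cond,
                *, guidance_scale=2.0, temperature=1.0, key=None):
    # Conditional branch: reads and writes the "cond" KV cache.
    cond_logits, state = model(
        tokens, cond, state=state, which_cache="cond",
        temperature=temperature, key=key,
    )
    # Unconditional branch: same tokens with a null conditioning signal,
    # reading and writing the "uncond" KV cache.
    uncond_logits, state = model(
        tokens, null_cond, state=state, which_cache="uncond",
        temperature=temperature, key=key,
    )
    # Classifier-free guidance: push the prediction away from the
    # unconditional branch by guidance_scale.
    logits = uncond_logits + guidance_scale * (cond_logits - uncond_logits)
    return logits, state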

pyproject.toml

Lines changed: 8 additions & 1 deletion
@@ -35,11 +35,18 @@ dependencies = [
     "tqdm",
     "datasets",
     "optax",
+    "ipykernel>=6.29.5",
+    "pip>=25.1.1",
 ]
 
 [build-system]
 requires = ["hatchling"]
 build-backend = "hatchling.build"
 
 [tool.hatch.build]
-include = ["*"]
+include = ["*"]
+
+[tool.uv.workspace]
+members = [
+    "transformer_flows",
+]
