
Commit 4c3e6db

updated cifar10 example, mixer model typing

1 parent: b0ca239

4 files changed: +176 -42 lines


data/cifar10.py (2 additions, 2 deletions)

```diff
@@ -57,15 +57,15 @@ def cifar10(path: str, key: Key, *, in_memory: bool = True) -> ScalerDataset:
         target_transform=transforms.Lambda(lambda x: x.float())
     )
 
-
     if in_memory:
         Xt, At = convert_torch_to_in_memory(train_dataset)
         Xv, Av = convert_torch_to_in_memory(valid_dataset)
 
         At = At.astype(jnp.float32)
         Av = Av.astype(jnp.float32)
 
-        process_fn = Scaler(x_min=Xt.min(), x_max=Xt.max())
+        # process_fn = Scaler(x_min=Xt.min(), x_max=Xt.max())
+        process_fn = Normer(x_mean=Xt.mean(), x_std=Xt.std())
 
         train_dataloader = InMemoryDataLoader(
             X=Xt, A=At, process_fn=process_fn, key=key_train)
```
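
The change above swaps min-max scaling (Scaler) for mean/std standardisation (Normer), with statistics taken from the training images. As a rough illustration only, a minimal sketch of what such a process_fn could look like; the real Normer class is not part of this diff, so the method names below are assumptions:

```python
import equinox as eqx
from jaxtyping import Array, Float


class Normer(eqx.Module):
    # Hypothetical sketch of a mean/std process_fn, mirroring the
    # Scaler(x_min=..., x_max=...) interface seen in the diff.
    x_mean: float
    x_std: float

    def forward(self, x: Float[Array, "..."]) -> Float[Array, "..."]:
        # Standardise to zero mean and unit variance before training.
        return (x - self.x_mean) / self.x_std

    def reverse(self, x: Float[Array, "..."]) -> Float[Array, "..."]:
        # Undo the standardisation, e.g. when plotting generated samples.
        return x * self.x_std + self.x_mean
```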

examples/cifar10.ipynb (152 additions, 28 deletions)

Large diff not rendered.

sbgm/_train.py (2 additions, 2 deletions)

```diff
@@ -466,7 +466,7 @@ def train_from_config(
     # Plot losses etc
     plot_metrics(train_losses, valid_losses, step, exp_dir)
 
-    return model
+    return ema_model if config.use_ema else model
 
 
 def train(
@@ -726,4 +726,4 @@ def train(
     # Plot losses etc
     plot_metrics(train_losses, valid_losses, step, exp_dir)
 
-    return model
+    return ema_model if use_ema else model
```
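
Both training entry points now return the EMA weights instead of the raw model whenever EMA is enabled. The EMA update itself is not shown in this diff; for context, a minimal sketch of the usual exponential-moving-average step in Equinox (the function name and decay value are illustrative):

```python
import jax
import equinox as eqx


def ema_update(model: eqx.Module, ema_model: eqx.Module, decay: float = 0.999) -> eqx.Module:
    # Separate the float array leaves (weights) from the static structure.
    params, static = eqx.partition(model, eqx.is_inexact_array)
    ema_params, _ = eqx.partition(ema_model, eqx.is_inexact_array)
    # Move each EMA leaf a small step towards the current weights.
    new_ema_params = jax.tree_util.tree_map(
        lambda e, p: decay * e + (1.0 - decay) * p, ema_params, params
    )
    return eqx.combine(new_ema_params, static)
```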

sbgm/models/_mixer.py (20 additions, 10 deletions)

```diff
@@ -1,10 +1,11 @@
-from typing import Sequence, Optional, Union
+from typing import Sequence, Optional, Callable
 import jax
 import jax.numpy as jnp
 import jax.random as jr
 import einops
 import equinox as eqx
-from jaxtyping import Key, Array
+from jaxtyping import Key, Array, Float, jaxtyped
+from beartype import beartype as typechecker
 
 
 class AdaLayerNorm(eqx.Module):
@@ -134,7 +135,11 @@ class Mixer2d(eqx.Module):
     t1: float
     embedding_dim: int
     final_activation: callable
+    img_size: Sequence[int]
+    q_dim: int
+    a_dim: int
 
+    @jaxtyped(typechecker=typechecker)
     def __init__(
         self,
         img_size: Sequence[int],
@@ -145,11 +150,11 @@ def __init__(
         num_blocks: int,
         t1: float,
         embedding_dim: int = 8,
-        final_activation: Optional[Union[callable, str]] = None,
+        final_activation: Optional[Callable | str] = None,
         q_dim: Optional[int] = None,
         a_dim: Optional[int] = None,
         *,
-        key: Key
+        key: Key[jnp.ndarray, "..."]
     ):
         """
         A 2D MLP Mixer model.
@@ -207,6 +212,10 @@ def __init__(
         _input_size = input_size + q_dim if q_dim is not None else input_size
         _context_dim = embedding_dim + a_dim if a_dim is not None else embedding_dim
 
+        self.img_size = img_size
+        self.q_dim = q_dim
+        self.a_dim = a_dim
+
         self.conv_in = eqx.nn.Conv2d(
             _input_size,
             hidden_size,
@@ -237,15 +246,16 @@ def __init__(
         self.embedding_dim = embedding_dim
         self.final_activation = get_activation_fn(final_activation)
 
+    @jaxtyped(typechecker=typechecker)
     def __call__(
         self,
-        t: Union[float, Array],
-        y: Array,
-        q: Optional[Array] = None,
-        a: Optional[Array] = None,
+        t: float | Float[Array, ""],
+        y: Float[Array, "..."],
+        q: Optional[Float[Array, "{self.q_dim} ..."]] = None,
+        a: Optional[Float[Array, "{self.a_dim}"]] = None,
         *,
-        key: Optional[Key] = None
-    ) -> Array:
+        key: Optional[Key[jnp.ndarray, "..."]] = None
+    ) -> Float[Array, "..."]:
         _, height, width = y.shape
         t = jnp.atleast_1d(t / self.t1)
         t = get_timestep_embedding(t, embedding_dim=self.embedding_dim)
```
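
The typing changes above replace bare Array annotations with jaxtyping shape specifications, enforced at runtime by beartype via the @jaxtyped decorator; storing img_size, q_dim and a_dim as fields lets expressions like "{self.q_dim}" in the shape strings be resolved against the instance at call time. A small self-contained sketch of the same pattern (the scale function is illustrative, not from the repo):

```python
import jax.numpy as jnp
from jaxtyping import Array, Float, jaxtyped
from beartype import beartype as typechecker


@jaxtyped(typechecker=typechecker)
def scale(t: float | Float[Array, ""], y: Float[Array, "c h w"]) -> Float[Array, "c h w"]:
    # Dtypes and shapes in the annotations are validated on every call,
    # and the symbols c, h, w must bind consistently across arguments.
    return t * y


y = jnp.ones((3, 32, 32))
scale(0.5, y)                     # OK: scalar t, (c, h, w) image
# scale(0.5, jnp.ones((32,)))     # raises at call time: wrong rank for y
```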
