Skip to content

Commit cf5deac

Browse files
authored
Fixes (#5)
1 parent cd1c65a commit cf5deac

File tree

10 files changed

+457
-117
lines changed

10 files changed

+457
-117
lines changed

README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@ Using `blaxbird` one can
1515
- distribute data and model weights over multiple processes or GPUs,
1616
- define hooks that are periodically called during training.
1717

18-
In addition, `blaxbird` offers high-quality implementation of common neural network modules and algorithms, such as:
18+
In addition, `blaxbird` offers high-quality implementations of common neural network modules and algorithms, such as:
1919

20-
- MLP, Diffusion Transformer,
21-
- Flow Matching and Denoising Score Matching (EDM schedules) with Euler and Heun samplers,
22-
- Consistency Distillation/Matching.
20+
- MLPs, DiTs, UNets,
21+
- Flow Matching and Denoising Score Matching (EDM schedules) models with Euler and Heun samplers,
22+
- Consistency Distillation/Matching models.
2323

2424
## Example
2525

blaxbird/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""blaxbird: A high-level API for building and training Flax NNX models."""
22

3-
__version__ = "0.1.0"
3+
__version__ = "0.1.1"
44

55
from blaxbird._src.checkpointer import get_default_checkpointer
66
from blaxbird._src.trainer import train_fn
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import numpy as np
2+
from flax import nnx
3+
from jax import numpy as jnp
4+
from jax import random as jr
5+
6+
from blaxbird._src.experimental import samplers
7+
from blaxbird._src.experimental.parameterizations import RFMConfig
8+
9+
10+
def _forward_process(inputs, times, noise):
11+
new_shape = (-1,) + tuple(np.ones(inputs.ndim - 1, dtype=np.int32).tolist())
12+
times = times.reshape(new_shape)
13+
inputs_t = times * inputs + (1.0 - times) * noise
14+
return inputs_t
15+
16+
17+
def rfm(config: RFMConfig | None = None):
  """Construct rectified flow matching functions.

  Args:
    config: an RFMConfig object; if None, a default-constructed RFMConfig
      is used

  Returns:
    returns a tuple consisting of train_step, val_step and sampling functions
  """
  # NOTE: do not use `config: RFMConfig = RFMConfig()` as the signature
  # default — that instance would be created once at import time and shared
  # (mutably) across every call. Use a None sentinel instead.
  config = RFMConfig() if config is None else config
  parameterization = config.parameterization

  def _loss_fn(model, rng_key, batch):
    # Rectified flow matching loss: regress the model's velocity field
    # onto the straight-line velocity from noise to data.
    inputs = batch["inputs"]
    time_key, rng_key = jr.split(rng_key)
    # Sample times uniformly in [t_eps, t_max] to stay clear of the
    # degenerate endpoints of the interpolation path.
    times = jr.uniform(time_key, shape=(inputs.shape[0],))
    times = (
      times * (parameterization.t_max - parameterization.t_eps)
      + parameterization.t_eps
    )
    noise_key, rng_key = jr.split(rng_key)
    noise = jr.normal(noise_key, inputs.shape)
    inputs_t = _forward_process(inputs, times, noise)
    # `context` is optional conditioning; batch.get returns None if absent.
    vt = model(inputs=inputs_t, times=times, context=batch.get("context"))
    # Target velocity of the rectified (straight-line) flow.
    ut = inputs - noise
    loss = jnp.mean(jnp.square(ut - vt))
    return loss

  def train_step(model, rng_key, batch, **kwargs):
    # Returns (loss, grads) for the optimizer update.
    return nnx.value_and_grad(_loss_fn)(model, rng_key, batch)

  def val_step(model, rng_key, batch, **kwargs):
    # Validation: loss only, no gradients.
    return _loss_fn(model, rng_key, batch)

  # Resolve the sampler by name, e.g. "euler" -> samplers.euler_sample_fn.
  sampler = getattr(samplers, config.sampler + "_sample_fn")(config)
  return train_step, val_step, sampler

blaxbird/_src/experimental/edm.py

Lines changed: 1 addition & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,74 +1,10 @@
1-
import dataclasses
2-
31
import numpy as np
42
from flax import nnx
53
from jax import numpy as jnp
64
from jax import random as jr
75

86
from blaxbird._src.experimental import samplers
9-
10-
11-
@dataclasses.dataclass
12-
class EDMParameterization:
13-
n_sampling_steps: int = 25
14-
sigma_min: float = 0.002
15-
sigma_max: float = 80.0
16-
rho: float = 7.0
17-
sigma_data: float = 0.5
18-
P_mean: float = -1.2
19-
P_std: float = 1.2
20-
S_churn: float = 40
21-
S_min: float = 0.05
22-
S_max: float = 50
23-
S_noise: float = 1.003
24-
25-
def sigma(self, eps):
26-
return jnp.exp(eps * self.P_std + self.P_mean)
27-
28-
def loss_weight(self, sigma):
29-
return (jnp.square(sigma) + jnp.square(self.sigma_data)) / jnp.square(
30-
sigma * self.sigma_data
31-
)
32-
33-
def skip_scaling(self, sigma):
34-
return self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
35-
36-
def out_scaling(self, sigma):
37-
return sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5
38-
39-
def in_scaling(self, sigma):
40-
return 1 / (sigma**2 + self.sigma_data**2) ** 0.5
41-
42-
def noise_conditioning(self, sigma):
43-
return 0.25 * jnp.log(sigma)
44-
45-
def sampling_sigmas(self, num_steps):
46-
rho_inv = 1 / self.rho
47-
step_idxs = jnp.arange(num_steps, dtype=jnp.float32)
48-
sigmas = (
49-
self.sigma_max**rho_inv
50-
+ step_idxs
51-
/ (num_steps - 1)
52-
* (self.sigma_min**rho_inv - self.sigma_max**rho_inv)
53-
) ** self.rho
54-
return jnp.concatenate([sigmas, jnp.zeros_like(sigmas[:1])])
55-
56-
def sigma_hat(self, sigma, num_steps):
57-
gamma = (
58-
jnp.minimum(self.S_churn / num_steps, 2**0.5 - 1)
59-
if self.S_min <= sigma <= self.S_max
60-
else 0
61-
)
62-
return sigma + gamma * sigma
63-
64-
65-
@dataclasses.dataclass
66-
class EDMConfig:
67-
n_sampling_steps: int = 25
68-
sampler: str = "heun"
69-
parameterization: EDMParameterization = dataclasses.field(
70-
default_factory=EDMParameterization
71-
)
7+
from blaxbird._src.experimental.parameterizations import EDMConfig
728

739

7410
def edm(config: EDMConfig):

blaxbird/_src/experimental/nn/mlp.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66

77
class MLP(nnx.Module):
8-
# ruff: noqa: PLR0913, ANN204, ANN101
98
def __init__(
109
self,
1110
in_features: int,

0 commit comments

Comments
 (0)