|
| 1 | +import numpy as np |
| 2 | +from flax import nnx |
| 3 | +from jax import numpy as jnp |
| 4 | +from jax import random as jr |
| 5 | + |
| 6 | +from blaxbird._src.experimental import samplers |
| 7 | +from blaxbird._src.experimental.parameterizations import RFMConfig |
| 8 | + |
| 9 | + |
| 10 | +def _forward_process(inputs, times, noise): |
| 11 | + new_shape = (-1,) + tuple(np.ones(inputs.ndim - 1, dtype=np.int32).tolist()) |
| 12 | + times = times.reshape(new_shape) |
| 13 | + inputs_t = times * inputs + (1.0 - times) * noise |
| 14 | + return inputs_t |
| 15 | + |
| 16 | + |
| 17 | +def rfm(config: RFMConfig = RFMConfig()): |
| 18 | + """Construct rectified flow matching functions. |
| 19 | +
|
| 20 | + Args: |
| 21 | + config: a FlowMatchingConfig object |
| 22 | +
|
| 23 | + Returns: |
| 24 | + returns a tuple consisting of train_step, val_step and sampling functions |
| 25 | + """ |
| 26 | + parameterization = config.parameterization |
| 27 | + |
| 28 | + def _loss_fn(model, rng_key, batch): |
| 29 | + inputs = batch["inputs"] |
| 30 | + time_key, rng_key = jr.split(rng_key) |
| 31 | + times = jr.uniform(time_key, shape=(inputs.shape[0],)) |
| 32 | + times = ( |
| 33 | + times * (parameterization.t_max - parameterization.t_eps) |
| 34 | + + parameterization.t_eps |
| 35 | + ) |
| 36 | + noise_key, rng_key = jr.split(rng_key) |
| 37 | + noise = jr.normal(noise_key, inputs.shape) |
| 38 | + inputs_t = _forward_process(inputs, times, noise) |
| 39 | + vt = model(inputs=inputs_t, times=times, context=batch.get("context")) |
| 40 | + ut = inputs - noise |
| 41 | + loss = jnp.mean(jnp.square(ut - vt)) |
| 42 | + return loss |
| 43 | + |
| 44 | + def train_step(model, rng_key, batch, **kwargs): |
| 45 | + return nnx.value_and_grad(_loss_fn)(model, rng_key, batch) |
| 46 | + |
| 47 | + def val_step(model, rng_key, batch, **kwargs): |
| 48 | + return _loss_fn(model, rng_key, batch) |
| 49 | + |
| 50 | + sampler = getattr(samplers, config.sampler + "_sample_fn")(config) |
| 51 | + return train_step, val_step, sampler |
0 commit comments