diff --git a/k_diffusion/sampling.py b/k_diffusion/sampling.py
index b5a1c39b..66664874 100644
--- a/k_diffusion/sampling.py
+++ b/k_diffusion/sampling.py
@@ -1,24 +1,35 @@
 import math
+from functools import partial
 
 from scipy import integrate
 import torch
+from torch import Tensor
 from torchdiffeq import odeint
 from tqdm.auto import trange, tqdm
+from typing import Optional, Callable, TypeAlias
 
 from . import utils
 
 
+TensorOperator: TypeAlias = Callable[[Tensor], Tensor]
+
+def make_quantizer(quanta: Tensor) -> TensorOperator:
+    """Returns a unary operator which accepts a single-element Tensor and rounds its element to the nearest element in `quanta`."""
+    return partial(utils.quantize, quanta)
+
+
 def append_zero(x):
     return torch.cat([x, x.new_zeros([1])])
 
 
-def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'):
+def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu', concat_zero=True):
     """Constructs the noise schedule of Karras et al. (2022)."""
-    ramp = torch.linspace(0, 1, n)
+    ramp = torch.linspace(0, 1, n if concat_zero else n + 1)
     min_inv_rho = sigma_min ** (1 / rho)
     max_inv_rho = sigma_max ** (1 / rho)
     sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
-    return append_zero(sigmas).to(device)
+    sigmas = sigmas.to(device)
+    return append_zero(sigmas) if concat_zero else sigmas
 
 
 def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'):
@@ -48,7 +59,7 @@ def get_ancestral_step(sigma_from, sigma_to):
 
 
 @torch.no_grad()
-def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
+def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., decorate_sigma_hat: Optional[TensorOperator] = None):
     """Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
     extra_args = {} if extra_args is None else extra_args
     s_in = x.new_ones([x.shape[0]])
@@ -56,6 +67,7 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None,
         gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
         eps = torch.randn_like(x) * s_noise
         sigma_hat = sigmas[i] * (gamma + 1)
+        sigma_hat = decorate_sigma_hat(sigma_hat) if callable(decorate_sigma_hat) else sigma_hat
         if gamma > 0:
             x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
         denoised = model(x, sigma_hat * s_in, **extra_args)
@@ -87,7 +99,7 @@ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
 
 
 @torch.no_grad()
-def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
+def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., decorate_sigma_hat: Optional[TensorOperator] = None):
     """Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
     extra_args = {} if extra_args is None else extra_args
     s_in = x.new_ones([x.shape[0]])
@@ -95,6 +107,7 @@ def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None,
         gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
         eps = torch.randn_like(x) * s_noise
         sigma_hat = sigmas[i] * (gamma + 1)
+        sigma_hat = decorate_sigma_hat(sigma_hat) if callable(decorate_sigma_hat) else sigma_hat
         if gamma > 0:
             x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
         denoised = model(x, sigma_hat * s_in, **extra_args)
@@ -116,7 +129,7 @@ def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None,
 
 
 @torch.no_grad()
-def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
+def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., decorate_sigma_hat: Optional[TensorOperator] = None):
     """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022)."""
     extra_args = {} if extra_args is None else extra_args
     s_in = x.new_ones([x.shape[0]])
@@ -124,6 +137,7 @@ def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None,
         gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
         eps = torch.randn_like(x) * s_noise
         sigma_hat = sigmas[i] * (gamma + 1)
+        sigma_hat = decorate_sigma_hat(sigma_hat) if callable(decorate_sigma_hat) else sigma_hat
         if gamma > 0:
             x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
         denoised = model(x, sigma_hat * s_in, **extra_args)
diff --git a/k_diffusion/utils.py b/k_diffusion/utils.py
index 8d700c2e..1cb6f950 100644
--- a/k_diffusion/utils.py
+++ b/k_diffusion/utils.py
@@ -7,8 +7,9 @@
 import warnings
 
 import torch
-from torch import optim
+from torch import optim, Tensor
 from torchvision.transforms import functional as TF
+from typing import Union
 
 
 def from_pil_image(x):
@@ -249,3 +250,8 @@ def rand_log_uniform(shape, min_value, max_value, device='cpu', dtype=torch.floa
     min_value = math.log(min_value)
     max_value = math.log(max_value)
     return (torch.rand(shape, device=device, dtype=dtype) * (max_value - min_value) + min_value).exp()
+
+
+def quantize(quanta: Tensor, candidate: Union[int, float, Tensor]) -> Tensor:
+    """Rounds `candidate` to the nearest element in `quanta`."""
+    return quanta[torch.argmin((quanta - candidate).abs(), dim=0)]
\ No newline at end of file
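For reviewers, here is a sketch of how the pieces are meant to compose; it is not part of the diff. Everything model-specific below is an illustrative stand-in: `sigma_grid`, the toy denoiser, the latent shape, and the step count are assumptions, while `get_sigmas_karras(..., concat_zero=...)`, `make_quantizer`, and the samplers' `decorate_sigma_hat` hook are exactly the APIs this patch adds. The flow is: build a schedule that omits the untrained zero sigma, then snap every `sigma_hat` the sampler computes onto the discrete grid before the model is called.

```python
import torch
from k_diffusion import sampling

# Illustrative stand-in for a model's trained noise levels (1-D, ascending).
sigma_grid = torch.linspace(0.0292, 14.6146, 1000)

# Karras schedule spanning the same range; concat_zero=False omits the final
# zero sigma so that every sigma the sampler visits can be quantized.
sigmas = sampling.get_sigmas_karras(
    n=22,
    sigma_min=sigma_grid[0].item(),
    sigma_max=sigma_grid[-1].item(),
    concat_zero=False,
)

# Unary operator that rounds a sigma to the nearest element of the grid.
quantize_sigma = sampling.make_quantizer(sigma_grid)

def toy_denoiser(x, sigma, **extra_args):
    # Placeholder so the sketch runs end to end; a real k-diffusion model
    # wrapper would return its denoised prediction here.
    return torch.zeros_like(x)

x = torch.randn([1, 3, 64, 64]) * sigmas[0]
samples = sampling.sample_heun(toy_denoiser, x, sigmas, disable=True,
                               decorate_sigma_hat=quantize_sigma)
```

Note that the hook runs unconditionally, before the `if gamma > 0` branch, so sigmas are snapped to the grid even on plain zero-churn steps.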