
[Pytorch] Pytorch only schedulers #534


Merged
merged 41 commits into main from remove-numpy on Sep 27, 2022
Commits
479e0d2
pytorch only schedulers
kashif Sep 16, 2022
3cc8796
fix style
kashif Sep 16, 2022
0b1fb7c
Merge branch 'main' into remove-numpy
kashif Sep 18, 2022
c1bb957
remove match_shape
kashif Sep 18, 2022
6818e9a
pytorch only ddpm
kashif Sep 18, 2022
66d48b8
remove SchedulerMixin
kashif Sep 18, 2022
f656c69
remove numpy from karras_ve
kashif Sep 18, 2022
66febf1
fix types
kashif Sep 18, 2022
d6cce91
remove numpy from lms_discrete
kashif Sep 18, 2022
1bb1716
remove numpy from pndm
kashif Sep 18, 2022
bde8899
fix typo
kashif Sep 18, 2022
ceb0b8a
remove mixin and numpy from sde_vp and ve
kashif Sep 18, 2022
7aa375f
remove remaining tensor_format
kashif Sep 18, 2022
5eb3763
fix style
kashif Sep 18, 2022
a56ed1f
sigmas has to be torch tensor
kashif Sep 18, 2022
7a86f0c
removed set_format in readme
kashif Sep 18, 2022
ee0185c
remove set format from docs
kashif Sep 18, 2022
7e214e9
remove set_format from pipelines
kashif Sep 18, 2022
5a608dc
update tests
kashif Sep 18, 2022
0cf700a
Merge branch 'main' into remove-numpy
kashif Sep 18, 2022
9014aee
fix typo
kashif Sep 18, 2022
03dd287
continue to use mixin
kashif Sep 19, 2022
7822977
fix imports
kashif Sep 19, 2022
4c91e36
removed unsed imports
kashif Sep 19, 2022
32c3d57
match shape instead of assuming image shapes
kashif Sep 19, 2022
3f484f3
remove import typo
kashif Sep 19, 2022
1bf4227
Merge branch 'main' into remove-numpy
kashif Sep 19, 2022
5de5c5c
Merge branch 'main' into remove-numpy
kashif Sep 20, 2022
d94787b
Merge branch 'main' into remove-numpy
kashif Sep 23, 2022
1b5fbc0
update call to add_noise
kashif Sep 23, 2022
0c8e83f
use math instead of numpy
kashif Sep 23, 2022
a4e80d8
fix t_index
kashif Sep 23, 2022
6485645
removed commented out numpy tests
kashif Sep 23, 2022
f0b7aba
timesteps needs to be discrete
kashif Sep 24, 2022
54a4aea
Merge branch 'main' into remove-numpy
kashif Sep 25, 2022
2cb7a33
cast timesteps to int in flax scheduler too
kashif Sep 26, 2022
9cacd3f
fix device mismatch issue
kashif Sep 26, 2022
98458cb
Merge branch 'main' into remove-numpy
kashif Sep 27, 2022
9bda654
Merge branch 'main' of https://github.com/huggingface/diffusers into …
patrickvonplaten Sep 27, 2022
4e70c10
small fix
patrickvonplaten Sep 27, 2022
b2dbd79
Update src/diffusers/schedulers/scheduling_pndm.py
patrickvonplaten Sep 27, 2022
3 changes: 1 addition & 2 deletions docs/source/api/schedulers.mdx
@@ -44,8 +44,7 @@ To this end, the design of schedulers is such that:
The core API for any new scheduler must follow a limited structure.
- Schedulers should provide one or more `def step(...)` functions that should be called to update the generated sample iteratively.
- Schedulers should provide a `set_timesteps(...)` method that configures the parameters of a schedule function for a specific inference task.
- Schedulers should be framework-agnostic, but provide a simple functionality to convert the scheduler into a specific framework, such as PyTorch
with a `set_format(...)` method.
- Schedulers should be framework-specific.

The base class [`SchedulerMixin`] implements low level utilities used by multiple schedulers.
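Editor's note (not part of the diff): a minimal sketch of how this framework-specific API is driven after the change. The `google/ddpm-cifar10-32` checkpoint is an assumption; any `UNet2DModel` works.

```python
import torch
from diffusers import DDIMScheduler, UNet2DModel

unet = UNet2DModel.from_pretrained("google/ddpm-cifar10-32")   # assumed checkpoint
scheduler = DDIMScheduler(num_train_timesteps=1000)            # no tensor_format argument anymore

scheduler.set_timesteps(50)                 # configure the inference schedule
sample = torch.randn(1, 3, 32, 32)          # start from pure noise

for t in scheduler.timesteps:
    with torch.no_grad():
        noise_pred = unet(sample, t).sample                      # predict the noise residual
    sample = scheduler.step(noise_pred, t, sample).prev_sample   # one iterative denoising update
```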

2 changes: 1 addition & 1 deletion examples/community/clip_guided_stable_diffusion.py
@@ -274,7 +274,7 @@ def __call__(
# the model input needs to be scaled to match the continuous ODE formulation in K-LMS
latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)

# # predict the noise residual
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

# perform classifier free guidance
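Editor's note: the `sigma` used above is indexed from the K-LMS scheduler's `sigmas` table, which after this PR is a torch tensor rather than a numpy array. A hedged, self-contained sketch of that scaling; the beta configuration and latent shapes are assumptions.

```python
import torch
from diffusers import LMSDiscreteScheduler

scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
scheduler.set_timesteps(50)

latents = torch.randn(1, 4, 64, 64) * scheduler.sigmas[0]   # K-LMS starts from sigma-scaled noise
t_index = 0
sigma = scheduler.sigmas[t_index]                            # a torch scalar after this PR
latent_model_input = latents / ((sigma**2 + 1) ** 0.5)       # scale to match the continuous ODE formulation
```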
5 changes: 4 additions & 1 deletion examples/textual_inversion/textual_inversion.py
@@ -424,7 +424,10 @@ def main():

# TODO (patil-suraj): load scheduler using args
noise_scheduler = DDPMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, tensor_format="pt"
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
num_train_timesteps=1000,
)

train_dataset = TextualInversionDataset(
@@ -59,7 +59,7 @@ def main(args):
"UpBlock2D",
),
)
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, tensor_format="pt")
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
optimizer = torch.optim.AdamW(
model.parameters(),
lr=args.learning_rate,
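Editor's note (a hedged sketch, not part of the diff): the `noise_scheduler`, `model`, and `optimizer` built above typically feed a training step like the one below, in which the UNet is trained to predict the added noise. Shapes are placeholders and the loss formulation is an assumption about this script.

```python
import torch
import torch.nn.functional as F

clean_images = torch.randn(8, 3, 64, 64)   # stand-in for a training batch
noise = torch.randn_like(clean_images)
# timesteps must now be a discrete (integer) torch tensor
timesteps = torch.randint(
    0, noise_scheduler.config.num_train_timesteps, (clean_images.shape[0],), dtype=torch.long
)

noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)
noise_pred = model(noisy_images, timesteps).sample   # UNet predicts the added noise
loss = F.mse_loss(noise_pred, noise)

loss.backward()
optimizer.step()
optimizer.zero_grad()
```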
1 change: 0 additions & 1 deletion src/diffusers/pipelines/ddim/pipeline_ddim.py
@@ -35,7 +35,6 @@ class DDIMPipeline(DiffusionPipeline):

def __init__(self, unet, scheduler):
super().__init__()
scheduler = scheduler.set_format("pt")
self.register_modules(unet=unet, scheduler=scheduler)

@torch.no_grad()
1 change: 0 additions & 1 deletion src/diffusers/pipelines/ddpm/pipeline_ddpm.py
@@ -35,7 +35,6 @@ class DDPMPipeline(DiffusionPipeline):

def __init__(self, unet, scheduler):
super().__init__()
scheduler = scheduler.set_format("pt")
Contributor (review comment): nice!

self.register_modules(unet=unet, scheduler=scheduler)

@torch.no_grad()
@@ -45,7 +45,6 @@ def __init__(
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
):
super().__init__()
scheduler = scheduler.set_format("pt")
self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)

@torch.no_grad()
@@ -23,7 +23,6 @@ class LDMPipeline(DiffusionPipeline):

def __init__(self, vqvae: VQModel, unet: UNet2DModel, scheduler: DDIMScheduler):
super().__init__()
scheduler = scheduler.set_format("pt")
self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler)

@torch.no_grad()
1 change: 0 additions & 1 deletion src/diffusers/pipelines/pndm/pipeline_pndm.py
@@ -39,7 +39,6 @@ class PNDMPipeline(DiffusionPipeline):

def __init__(self, unet: UNet2DModel, scheduler: PNDMScheduler):
super().__init__()
scheduler = scheduler.set_format("pt")
self.register_modules(unet=unet, scheduler=scheduler)

@torch.no_grad()
@@ -57,7 +57,6 @@ def __init__(
feature_extractor: CLIPFeatureExtractor,
):
super().__init__()
scheduler = scheduler.set_format("pt")

if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
warnings.warn(
@@ -69,7 +69,6 @@ def __init__(
feature_extractor: CLIPFeatureExtractor,
):
super().__init__()
scheduler = scheduler.set_format("pt")

if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
warnings.warn(
@@ -83,7 +83,6 @@ def __init__(
feature_extractor: CLIPFeatureExtractor,
):
super().__init__()
scheduler = scheduler.set_format("pt")
logger.info("`StableDiffusionInpaintPipeline` is experimental and will very likely change in the future.")

if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
@@ -320,11 +319,11 @@ def __call__(
if isinstance(self.scheduler, LMSDiscreteScheduler):
latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample
# masking
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor(t_index))
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.LongTensor([t_index]))
else:
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# masking
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t)
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.LongTensor([t]))

latents = (init_latents_proper * mask) + (latents * (1 - mask))

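Editor's note on the hunk above: since the schedulers are torch-only, `add_noise` expects its `timesteps` argument as a torch integer tensor, hence the `torch.LongTensor([t])` wrapping. A compact, hedged sketch of the masked re-noising step; the scheduler choice and shapes are assumptions.

```python
import torch
from diffusers import DDIMScheduler

scheduler = DDIMScheduler(num_train_timesteps=1000)

init_latents_orig = torch.randn(1, 4, 64, 64)   # latents of the original image
latents = torch.randn_like(init_latents_orig)   # current denoising state
latents_noise = torch.randn_like(init_latents_orig)
mask = torch.ones(1, 1, 64, 64)                 # 1 keeps the original content, 0 lets it be inpainted

t = 500                                         # current timestep in the schedule
# re-noise the original latents to the current noise level; timesteps must be a torch integer tensor
init_latents_proper = scheduler.add_noise(init_latents_orig, latents_noise, torch.LongTensor([t]))
latents = (init_latents_proper * mask) + (latents * (1 - mask))
```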
@@ -35,7 +35,6 @@ def __init__(
feature_extractor: CLIPFeatureExtractor,
):
super().__init__()
scheduler = scheduler.set_format("np")
self.register_modules(
vae_decoder=vae_decoder,
text_encoder=text_encoder,
@@ -29,7 +29,6 @@ class KarrasVePipeline(DiffusionPipeline):

def __init__(self, unet: UNet2DModel, scheduler: KarrasVeScheduler):
super().__init__()
scheduler = scheduler.set_format("pt")
self.register_modules(unet=unet, scheduler=scheduler)

@torch.no_grad()
3 changes: 1 addition & 2 deletions src/diffusers/schedulers/README.md
@@ -8,8 +8,7 @@

- Schedulers should provide one or more `def step(...)` functions that should be called iteratively to unroll the diffusion loop during
the forward pass.
- Schedulers should be framework-agnostic, but provide a simple functionality to convert the scheduler into a specific framework, such as PyTorch
with a `set_format(...)` method.
- Schedulers should be framework specific.

## Examples

59 changes: 28 additions & 31 deletions src/diffusers/schedulers/scheduling_ddim.py
@@ -46,7 +46,7 @@ class DDIMSchedulerOutput(BaseOutput):
pred_original_sample: Optional[torch.FloatTensor] = None


def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -72,7 +72,7 @@ def alpha_bar(time_step):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return np.array(betas, dtype=np.float32)
return torch.tensor(betas)
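Editor's note: a hedged, self-contained sketch of what this helper computes, using the cosine `alpha_bar` behind the `squaredcos_cap_v2` schedule (treated here as an assumption); the cumulative product of `1 - beta` recovers `alpha_bar` at the grid points.

```python
import math
import torch

def alpha_bar(t):  # assumed: the cosine schedule used for "squaredcos_cap_v2"
    return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

T = 1000
betas = torch.tensor(
    [min(1 - alpha_bar((i + 1) / T) / alpha_bar(i / T), 0.999) for i in range(T)]
)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
# away from the clipped tail, the discretization tracks the continuous curve closely
print(abs(alphas_cumprod[499].item() - alpha_bar(500 / T)))   # small (~1e-4)
```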


class DDIMScheduler(SchedulerMixin, ConfigMixin):
@@ -106,7 +106,6 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
an offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion.
tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.

"""

@@ -121,36 +120,34 @@ def __init__(
clip_sample: bool = True,
set_alpha_to_one: bool = True,
steps_offset: int = 0,
tensor_format: str = "pt",
):
if trained_betas is not None:
self.betas = np.asarray(trained_betas)
self.betas = torch.from_numpy(trained_betas)
if beta_schedule == "linear":
self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
self.betas = (
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
)
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

self.alphas = 1.0 - self.betas
self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

# At every step in ddim, we are looking into the previous alphas_cumprod
# For the final step, there is no previous alphas_cumprod because we are already at 0
# `set_alpha_to_one` decides whether we set this parameter simply to one or
# whether we use the final alpha of the "non-previous" one.
self.final_alpha_cumprod = np.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

# setable values
self.num_inference_steps = None
self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()

self.tensor_format = tensor_format
self.set_format(tensor_format=tensor_format)
self.timesteps = np.arange(0, num_train_timesteps)[::-1]

def _get_variance(self, timestep, prev_timestep):
alpha_prod_t = self.alphas_cumprod[timestep]
@@ -186,15 +183,14 @@ def set_timesteps(self, num_inference_steps: int, **kwargs):
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
self.timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy()
self.timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1]
self.timesteps += offset
self.set_format(tensor_format=self.tensor_format)
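Editor's note: a worked sketch of the integer timestep grid built above, assuming an offset of 0.

```python
import numpy as np

num_train_timesteps, num_inference_steps, offset = 1000, 50, 0
step_ratio = num_train_timesteps // num_inference_steps                        # 20
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1] + offset
print(timesteps[:3], timesteps[-3:])   # [980 960 940] ... [40 20 0]
```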

def step(
self,
model_output: Union[torch.FloatTensor, np.ndarray],
model_output: torch.FloatTensor,
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
sample: torch.FloatTensor,
eta: float = 0.0,
use_clipped_model_output: bool = False,
generator=None,
@@ -205,9 +201,9 @@ def step(
process from the learned model outputs (most often the predicted noise).

Args:
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor` or `np.ndarray`):
sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process.
eta (`float`): weight of noise for added noise in diffusion step.
use_clipped_model_output (`bool`): TODO
@@ -251,7 +247,7 @@ def step(

# 4. Clip "predicted x_0"
if self.config.clip_sample:
pred_original_sample = self.clip(pred_original_sample, -1, 1)
pred_original_sample = torch.clamp(pred_original_sample, -1, 1)

# 5. compute variance: "sigma_t(η)" -> see formula (16)
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
@@ -273,9 +269,6 @@ def step(
noise = torch.randn(model_output.shape, generator=generator).to(device)
variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise

if not torch.is_tensor(model_output):
variance = variance.numpy()

prev_sample = prev_sample + variance

if not return_dict:
@@ -285,16 +278,20 @@ def step(

def add_noise(
self,
original_samples: Union[torch.FloatTensor, np.ndarray],
noise: Union[torch.FloatTensor, np.ndarray],
timesteps: Union[torch.IntTensor, np.ndarray],
) -> Union[torch.FloatTensor, np.ndarray]:
if self.tensor_format == "pt":
timesteps = timesteps.to(self.alphas_cumprod.device)
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.IntTensor,
) -> torch.FloatTensor:
timesteps = timesteps.to(self.alphas_cumprod.device)
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
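Editor's note: a self-contained sketch of the flatten/unsqueeze broadcasting pattern used above, which replaces the old `match_shape` helper; it stretches one scalar per batch element over an image-shaped tensor. Shapes and the alpha-bar values are placeholders.

```python
import torch

alphas_cumprod = torch.linspace(0.999, 0.01, 1000)   # stand-in for the real alpha-bar table
timesteps = torch.tensor([10, 500, 900])              # one timestep per sample in the batch
original_samples = torch.randn(3, 3, 32, 32)
noise = torch.randn_like(original_samples)

sqrt_alpha_prod = (alphas_cumprod[timesteps] ** 0.5).flatten()          # shape (3,)
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
    sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)                      # -> shape (3, 1, 1, 1)

sqrt_one_minus_alpha_prod = ((1 - alphas_cumprod[timesteps]) ** 0.5).flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
    sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
print(noisy_samples.shape)   # torch.Size([3, 3, 32, 32])
```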