diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..1ef325f --- /dev/null +++ b/.gitattributes @@ -0,0 +1,59 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.lz4 filter=lfs diff=lfs merge=lfs -text +*.mds filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text +# Audio files - uncompressed +*.pcm filter=lfs diff=lfs merge=lfs -text +*.sam filter=lfs diff=lfs merge=lfs -text +*.raw filter=lfs diff=lfs merge=lfs -text +# Audio files - compressed +*.aac filter=lfs diff=lfs 
merge=lfs -text +*.flac filter=lfs diff=lfs merge=lfs -text +*.mp3 filter=lfs diff=lfs merge=lfs -text +*.ogg filter=lfs diff=lfs merge=lfs -text +*.wav filter=lfs diff=lfs merge=lfs -text +# Image files - uncompressed +*.bmp filter=lfs diff=lfs merge=lfs -text +*.gif filter=lfs diff=lfs merge=lfs -text +*.png filter=lfs diff=lfs merge=lfs -text +*.tiff filter=lfs diff=lfs merge=lfs -text +# Image files - compressed +*.jpg filter=lfs diff=lfs merge=lfs -text +*.jpeg filter=lfs diff=lfs merge=lfs -text +*.webp filter=lfs diff=lfs merge=lfs -text +# Video files - compressed +*.mp4 filter=lfs diff=lfs merge=lfs -text +*.webm filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore index f62bdab..0c73160 100644 --- a/.gitignore +++ b/.gitignore @@ -7,4 +7,12 @@ pretrained_ckpts *.safetensors preprocessed_dataset wandb -logs \ No newline at end of file +logs +*.pkl +*.pt +tokenize_dataset/* +inet/** +cosmos_ckpt/** +wandb/** +*.jpg +*.jpeg diff --git a/cosmos/__init__.py b/cosmos/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cosmos/image_lib.py b/cosmos/image_lib.py new file mode 100644 index 0000000..95c11db --- /dev/null +++ b/cosmos/image_lib.py @@ -0,0 +1,128 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""A library for image tokenizers inference.""" + +import numpy as np +import torch +from typing import Any + +from cosmos.utils import ( + load_model, + load_encoder_model, + load_decoder_model, + numpy2tensor, + pad_image_batch, + tensor2numpy, + unpad_image_batch, +) + + +class ImageTokenizer(torch.nn.Module): + def __init__( + self, + checkpoint: str = None, + checkpoint_enc: str = None, + checkpoint_dec: str = None, + tokenizer_config: dict[str, Any] = None, + device: str = "cuda", + dtype: str = "bfloat16", + ) -> None: + super().__init__() + self._device = device + self._dtype = getattr(torch, dtype) + self._full_model = ( + load_model(checkpoint, tokenizer_config, device).to(self._dtype) + if checkpoint is not None + else None + ) + self._enc_model = ( + load_encoder_model(checkpoint_enc, tokenizer_config, device).to(self._dtype) + if checkpoint_enc is not None + else None + ) + self._dec_model = ( + load_decoder_model(checkpoint_dec, tokenizer_config, device).to(self._dtype) + if checkpoint_dec is not None + else None + ) + + @torch.no_grad() + def autoencode(self, input_tensor: torch.Tensor) -> torch.Tensor: + """Reconstructs a batch of image tensors after embedding into a latent. + + Args: + input_tensor: The input image Bx3xHxW layout, range [-1..1]. + Returns: + The reconstructed tensor, layout Bx3xHxW, range [-1..1]. + """ + if self._full_model is not None: + output_tensor = self._full_model(input_tensor) + output_tensor = ( + output_tensor[0] if isinstance(output_tensor, tuple) else output_tensor + ) + else: + output_latent = self.encode(input_tensor)[0] + output_tensor = self.decode(output_latent) + return output_tensor + + @torch.no_grad() + def decode(self, input_latent: torch.Tensor) -> torch.Tensor: + """Decodes an image from a provided latent embedding. + + Args: + input_latent: The continuous latent Bx16xhxw for CI, + or the discrete indices Bxhxw for DI. + Returns: + The output tensor in Bx3xHxW, range [-1..1].
+ """ + return self._dec_model(input_latent) + + @torch.no_grad() + def encode(self, input_tensor: torch.Tensor) -> tuple[torch.Tensor]: + """Encodes an image into a latent embedding or code. + + Args: + input_tensor: The input tensor Bx3xHxW layout, range [-1..1]. + Returns: + For continuous image (CI) tokenizer, the tuple contains: + - The latent embedding, Bx16x(h)x(w), where the compression + rate is (H/h x W/w), and channel dimension of 16. + For discrete image (DI) tokenizer, the tuple contains: + - The indices, Bx(h)x(w), from a codebook of size 64K, which + corresponds to FSQ levels of (8,8,8,5,5,5). + - The discrete code, Bx6x(h)x(w), where the compression rate is + again (H/h x W/w), and channel dimension of 6. + """ + output_latent = self._enc_model(input_tensor) + if isinstance(output_latent, torch.Tensor): + return output_latent + return output_latent[:-1] + + @torch.no_grad() + def forward(self, image: np.ndarray) -> np.ndarray: + """Reconstructs an image using a pre-trained tokenizer. + + Args: + image: The input image BxHxWxC layout, range [0..255]. + Returns: + The reconstructed image in range [0..255], layout BxHxWxC. + """ + padded_input_image, crop_region = pad_image_batch(image) + input_tensor = numpy2tensor( + padded_input_image, dtype=self._dtype, device=self._device + ) + output_tensor = self.autoencode(input_tensor) + padded_output_image = tensor2numpy(output_tensor) + return unpad_image_batch(padded_output_image, crop_region) \ No newline at end of file diff --git a/cosmos/modules/__init__.py b/cosmos/modules/__init__.py new file mode 100644 index 0000000..4ad4af1 --- /dev/null +++ b/cosmos/modules/__init__.py @@ -0,0 +1,63 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from enum import Enum + +from cosmos.modules.distributions import ( + GaussianDistribution, + IdentityDistribution, +) +from cosmos.modules.layers2d import Decoder, Encoder +from cosmos.modules.layers3d import ( + DecoderBase, + DecoderFactorized, + EncoderBase, + EncoderFactorized, +) +from cosmos.modules.quantizers import ( + FSQuantizer, + LFQuantizer, + ResidualFSQuantizer, + VectorQuantizer, +) + + +class EncoderType(Enum): + Default = Encoder + + +class DecoderType(Enum): + Default = Decoder + + +class Encoder3DType(Enum): + BASE = EncoderBase + FACTORIZED = EncoderFactorized + + +class Decoder3DType(Enum): + BASE = DecoderBase + FACTORIZED = DecoderFactorized + + +class ContinuousFormulation(Enum): + VAE = GaussianDistribution + AE = IdentityDistribution + + +class DiscreteQuantizer(Enum): + VQ = VectorQuantizer + LFQ = LFQuantizer + FSQ = FSQuantizer + RESFSQ = ResidualFSQuantizer \ No newline at end of file diff --git a/cosmos/modules/distributions.py b/cosmos/modules/distributions.py new file mode 100644 index 0000000..bf900a6 --- /dev/null +++ b/cosmos/modules/distributions.py @@ -0,0 +1,41 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""The distribution modes to use for continuous image tokenizers.""" + +import torch + + +class IdentityDistribution(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, parameters): + return parameters, (torch.tensor([0.0]), torch.tensor([0.0])) + + +class GaussianDistribution(torch.nn.Module): + def __init__(self, min_logvar: float = -30.0, max_logvar: float = 20.0): + super().__init__() + self.min_logvar = min_logvar + self.max_logvar = max_logvar + + def sample(self, mean, logvar): + std = torch.exp(0.5 * logvar) + return mean + std * torch.randn_like(mean) + + def forward(self, parameters): + mean, logvar = torch.chunk(parameters, 2, dim=1) + logvar = torch.clamp(logvar, self.min_logvar, self.max_logvar) + return self.sample(mean, logvar), (mean, logvar) \ No newline at end of file diff --git a/cosmos/modules/layers2d.py b/cosmos/modules/layers2d.py new file mode 100644 index 0000000..fdb36b3 --- /dev/null +++ b/cosmos/modules/layers2d.py @@ -0,0 +1,368 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""The model definition for Continuous 2D layers + +Adapted from: https://github.com/CompVis/stable-diffusion/blob/ +21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/modules/diffusionmodules/model.py + +[Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors] +https://github.com/CompVis/stable-diffusion/blob/ +21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/LICENSE +""" + +import math + +import numpy as np + +# pytorch_diffusion + derived encoder decoder +import torch +import torch.nn as nn +import torch.nn.functional as F +from loguru import logger as logging + +from cosmos.modules.patching import Patcher, UnPatcher +from cosmos.modules.utils import Normalize, nonlinearity + + +class Upsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.conv = nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.repeat_interleave(2, dim=2).repeat_interleave(2, dim=3) + return self.conv(x) + + +class Downsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.conv = nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=2, padding=0 + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + pad = (0, 1, 0, 1) + x = F.pad(x, pad, mode="constant", value=0) + return self.conv(x) + + +class ResnetBlock(nn.Module): + def __init__( + self, + *, + in_channels: int, + out_channels: int = None, + dropout: float, + **kwargs, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if 
out_channels is None else out_channels + + self.norm1 = Normalize(in_channels) + self.conv1 = nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + self.norm2 = Normalize(out_channels) + self.dropout = nn.Dropout(dropout) + self.conv2 = nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + self.nin_shortcut = ( + nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + if in_channels != out_channels + else nn.Identity() + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + x = self.nin_shortcut(x) + + return x + h + + +class AttnBlock(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + + self.norm = Normalize(in_channels) + self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # TODO (freda): Consider reusing implementations in Attn `imaginaire`, + # since that one is going to be based on TransformerEngine's attn op, + # which could ease CP implementations.
+ h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) + q = q.permute(0, 2, 1) + k = k.reshape(b, c, h * w) + w_ = torch.bmm(q, k) + w_ = w_ * (int(c) ** (-0.5)) + w_ = F.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) + h_ = torch.bmm(v, w_) + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + return x + h_ + + +class Encoder(nn.Module): + def __init__( + self, + in_channels: int, + channels: int, + channels_mult: list[int], + num_res_blocks: int, + attn_resolutions: list[int], + dropout: float, + resolution: int, + z_channels: int, + spatial_compression: int, + **ignore_kwargs, + ): + super().__init__() + self.num_resolutions = len(channels_mult) + self.num_res_blocks = num_res_blocks + + # Patcher. + patch_size = ignore_kwargs.get("patch_size", 1) + self.patcher = Patcher( + patch_size, ignore_kwargs.get("patch_method", "rearrange") + ) + in_channels = in_channels * patch_size * patch_size + + # calculate the number of downsample operations + self.num_downsamples = int(math.log2(spatial_compression)) - int( + math.log2(patch_size) + ) + assert ( + self.num_downsamples <= self.num_resolutions + ), f"we can only downsample {self.num_resolutions} times at most" + + # downsampling + self.conv_in = torch.nn.Conv2d( + in_channels, channels, kernel_size=3, stride=1, padding=1 + ) + + curr_res = resolution // patch_size + in_ch_mult = (1,) + tuple(channels_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = channels * in_ch_mult[i_level] + block_out = channels * channels_mult[i_level] + for _ in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + 
attn.append(AttnBlock(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level < self.num_downsamples: + down.downsample = Downsample(block_in) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, out_channels=block_in, dropout=dropout + ) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, out_channels=block_in, dropout=dropout + ) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, z_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.patcher(x) + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level < self.num_downsamples: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + out_channels: int, + channels: int, + channels_mult: list[int], + num_res_blocks: int, + attn_resolutions: int, + dropout: float, + resolution: int, + z_channels: int, + spatial_compression: int, + **ignore_kwargs, + ): + super().__init__() + self.num_resolutions = len(channels_mult) + self.num_res_blocks = num_res_blocks + + # UnPatcher. 
+ patch_size = ignore_kwargs.get("patch_size", 1) + self.unpatcher = UnPatcher( + patch_size, ignore_kwargs.get("patch_method", "rearrange") + ) + out_ch = out_channels * patch_size * patch_size + + # calculate the number of upsample operations + self.num_upsamples = int(math.log2(spatial_compression)) - int( + math.log2(patch_size) + ) + assert ( + self.num_upsamples <= self.num_resolutions + ), f"we can only upsample {self.num_resolutions} times at most" + + block_in = channels * channels_mult[self.num_resolutions - 1] + curr_res = (resolution // patch_size) // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + logging.info( + "Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape) + ) + ) + + # z to block_in + self.conv_in = torch.nn.Conv2d( + z_channels, block_in, kernel_size=3, stride=1, padding=1 + ) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, out_channels=block_in, dropout=dropout + ) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, out_channels=block_in, dropout=dropout + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = channels * channels_mult[i_level] + for _ in range(self.num_res_blocks + 1): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level >= (self.num_resolutions - self.num_upsamples): + up.upsample = Upsample(block_in) + curr_res = curr_res * 2 + self.up.insert(0, up) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_ch, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, z: torch.Tensor) -> 
torch.Tensor: + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level >= (self.num_resolutions - self.num_upsamples): + h = self.up[i_level].upsample(h) + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + h = self.unpatcher(h) + return h \ No newline at end of file diff --git a/cosmos/modules/layers3d.py b/cosmos/modules/layers3d.py new file mode 100644 index 0000000..36c807e --- /dev/null +++ b/cosmos/modules/layers3d.py @@ -0,0 +1,1040 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""The model definition for 3D layers + +Adapted from: https://github.com/lucidrains/magvit2-pytorch/blob/ +9f49074179c912736e617d61b32be367eb5f993a/magvit2_pytorch/magvit2_pytorch.py#L889 + +[MIT License Copyright (c) 2023 Phil Wang] +https://github.com/lucidrains/magvit2-pytorch/blob/ +9f49074179c912736e617d61b32be367eb5f993a/LICENSE +""" +import math +from typing import Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from loguru import logger as logging + +from cosmos.modules.patching import ( + Patcher, + Patcher3D, + UnPatcher, + UnPatcher3D, +) +from cosmos.modules.utils import ( + CausalNormalize, + batch2space, + batch2time, + cast_tuple, + is_odd, + nonlinearity, + replication_pad, + space2batch, + time2batch, +) + +_LEGACY_NUM_GROUPS = 32 + + +class CausalConv3d(nn.Module): + def __init__( + self, + chan_in: int = 1, + chan_out: int = 1, + kernel_size: Union[int, Tuple[int, int, int]] = 3, + pad_mode: str = "constant", + **kwargs, + ): + super().__init__() + kernel_size = cast_tuple(kernel_size, 3) + + time_kernel_size, height_kernel_size, width_kernel_size = kernel_size + + assert is_odd(height_kernel_size) and is_odd(width_kernel_size) + + dilation = kwargs.pop("dilation", 1) + stride = kwargs.pop("stride", 1) + time_stride = kwargs.pop("time_stride", 1) + time_dilation = kwargs.pop("time_dilation", 1) + padding = kwargs.pop("padding", 1) + + self.pad_mode = pad_mode + time_pad = time_dilation * (time_kernel_size - 1) + (1 - time_stride) + self.time_pad = time_pad + + self.spatial_pad = (padding, padding, padding, padding) + + stride = (time_stride, stride, stride) + dilation = (time_dilation, dilation, dilation) + self.conv3d = nn.Conv3d( + chan_in, + chan_out, + kernel_size, + stride=stride, + dilation=dilation, + **kwargs, + ) + + def _replication_pad(self, x: torch.Tensor) -> torch.Tensor: + x_prev = x[:, :, :1, ...].repeat(1, 1, self.time_pad, 1, 1) + x = torch.cat([x_prev, x], dim=2) + 
padding = self.spatial_pad + (0, 0) + return F.pad(x, padding, mode=self.pad_mode, value=0.0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self._replication_pad(x) + return self.conv3d(x) + + +class CausalUpsample3d(nn.Module): + def __init__(self, in_channels: int) -> None: + super().__init__() + self.conv = CausalConv3d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.repeat_interleave(2, dim=3).repeat_interleave(2, dim=4) + time_factor = 1.0 + 1.0 * (x.shape[2] > 1) + if isinstance(time_factor, torch.Tensor): + time_factor = time_factor.item() + x = x.repeat_interleave(int(time_factor), dim=2) + # TODO(freda): Check if this causes temporal inconsistency. + # Should reverse the order of the following two ops, + # better perf and better temporal smoothness. + x = self.conv(x) + return x[..., int(time_factor - 1) :, :, :] + + +class CausalDownsample3d(nn.Module): + def __init__(self, in_channels: int) -> None: + super().__init__() + self.conv = CausalConv3d( + in_channels, + in_channels, + kernel_size=3, + stride=2, + time_stride=2, + padding=0, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + pad = (0, 1, 0, 1, 0, 0) + x = F.pad(x, pad, mode="constant", value=0) + x = replication_pad(x) + x = self.conv(x) + return x + + +class CausalHybridUpsample3d(nn.Module): + def __init__( + self, + in_channels: int, + spatial_up: bool = True, + temporal_up: bool = True, + **kwargs, + ) -> None: + super().__init__() + self.conv1 = CausalConv3d( + in_channels, + in_channels, + kernel_size=(3, 1, 1), + stride=1, + time_stride=1, + padding=0, + ) + self.conv2 = CausalConv3d( + in_channels, + in_channels, + kernel_size=(1, 3, 3), + stride=1, + time_stride=1, + padding=1, + ) + self.conv3 = CausalConv3d( + in_channels, + in_channels, + kernel_size=1, + stride=1, + time_stride=1, + padding=0, + ) + self.spatial_up = spatial_up + self.temporal_up = temporal_up + + def
forward(self, x: torch.Tensor) -> torch.Tensor: + if not self.spatial_up and not self.temporal_up: + return x + + # hybrid upsample temporally. + if self.temporal_up: + time_factor = 1.0 + 1.0 * (x.shape[2] > 1) + if isinstance(time_factor, torch.Tensor): + time_factor = time_factor.item() + x = x.repeat_interleave(int(time_factor), dim=2) + x = x[..., int(time_factor - 1) :, :, :] + x = self.conv1(x) + x + + # hybrid upsample spatially. + if self.spatial_up: + x = x.repeat_interleave(2, dim=3).repeat_interleave(2, dim=4) + x = self.conv2(x) + x + + # final 1x1x1 conv. + x = self.conv3(x) + return x + + +class CausalHybridDownsample3d(nn.Module): + def __init__( + self, + in_channels: int, + spatial_down: bool = True, + temporal_down: bool = True, + **kwargs, + ) -> None: + super().__init__() + self.conv1 = CausalConv3d( + in_channels, + in_channels, + kernel_size=(1, 3, 3), + stride=2, + time_stride=1, + padding=0, + ) + self.conv2 = CausalConv3d( + in_channels, + in_channels, + kernel_size=(3, 1, 1), + stride=1, + time_stride=2, + padding=0, + ) + self.conv3 = CausalConv3d( + in_channels, + in_channels, + kernel_size=1, + stride=1, + time_stride=1, + padding=0, + ) + self.spatial_down = spatial_down + self.temporal_down = temporal_down + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if not self.spatial_down and not self.temporal_down: + return x + + # hybrid downsample spatially. + if self.spatial_down: + pad = (0, 1, 0, 1, 0, 0) + x = F.pad(x, pad, mode="constant", value=0) + x1 = self.conv1(x) + x2 = F.avg_pool3d(x, kernel_size=(1, 2, 2), stride=(1, 2, 2)) + x = x1 + x2 + + # hybrid downsample temporally. + if self.temporal_down: + x = replication_pad(x) + x1 = self.conv2(x) + x2 = F.avg_pool3d(x, kernel_size=(2, 1, 1), stride=(2, 1, 1)) + x = x1 + x2 + + # final 1x1x1 conv. 
+ x = self.conv3(x) + return x + + +class CausalResnetBlock3d(nn.Module): + def __init__( + self, + *, + in_channels: int, + out_channels: int = None, + dropout: float, + num_groups: int, + ) -> None: + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + + self.norm1 = CausalNormalize(in_channels, num_groups=num_groups) + self.conv1 = CausalConv3d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + self.norm2 = CausalNormalize(out_channels, num_groups=num_groups) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = CausalConv3d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + self.nin_shortcut = ( + CausalConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + if in_channels != out_channels + else nn.Identity() + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + x = self.nin_shortcut(x) + + return x + h + + +class CausalResnetBlockFactorized3d(nn.Module): + def __init__( + self, + *, + in_channels: int, + out_channels: int = None, + dropout: float, + num_groups: int, + ) -> None: + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + + self.norm1 = CausalNormalize(in_channels, num_groups=1) + self.conv1 = nn.Sequential( + CausalConv3d( + in_channels, + out_channels, + kernel_size=(1, 3, 3), + stride=1, + padding=1, + ), + CausalConv3d( + out_channels, + out_channels, + kernel_size=(3, 1, 1), + stride=1, + padding=0, + ), + ) + self.norm2 = CausalNormalize(out_channels, num_groups=num_groups) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = nn.Sequential( + CausalConv3d( + out_channels, + out_channels, + kernel_size=(1, 3, 3), + stride=1, + padding=1, + ), + CausalConv3d( + out_channels, + 
out_channels, + kernel_size=(3, 1, 1), + stride=1, + padding=0, + ), + ) + self.nin_shortcut = ( + CausalConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + if in_channels != out_channels + else nn.Identity() + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + x = self.nin_shortcut(x) + + return x + h + + +class CausalAttnBlock(nn.Module): + def __init__(self, in_channels: int, num_groups: int) -> None: + super().__init__() + + self.norm = CausalNormalize(in_channels, num_groups=num_groups) + self.q = CausalConv3d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.k = CausalConv3d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = CausalConv3d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = CausalConv3d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + q, batch_size = time2batch(q) + k, batch_size = time2batch(k) + v, batch_size = time2batch(v) + + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) + q = q.permute(0, 2, 1) + k = k.reshape(b, c, h * w) + w_ = torch.bmm(q, k) + w_ = w_ * (int(c) ** (-0.5)) + w_ = F.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) + h_ = torch.bmm(v, w_) + h_ = h_.reshape(b, c, h, w) + + h_ = batch2time(h_, batch_size) + h_ = self.proj_out(h_) + return x + h_ + + +class CausalTemporalAttnBlock(nn.Module): + def __init__(self, in_channels: int, num_groups: int) -> None: + super().__init__() + + self.norm = CausalNormalize(in_channels, num_groups=num_groups) + self.q = CausalConv3d( + in_channels, in_channels, kernel_size=1, stride=1, 
padding=0 + ) + self.k = CausalConv3d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = CausalConv3d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = CausalConv3d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + q, batch_size, height = space2batch(q) + k, _, _ = space2batch(k) + v, _, _ = space2batch(v) + + bhw, c, t = q.shape + q = q.permute(0, 2, 1) # (bhw, t, c) + k = k.permute(0, 2, 1) # (bhw, t, c) + v = v.permute(0, 2, 1) # (bhw, t, c) + + w_ = torch.bmm(q, k.permute(0, 2, 1)) # (bhw, t, t) + w_ = w_ * (int(c) ** (-0.5)) + + # Apply causal mask + mask = torch.tril(torch.ones_like(w_)) + w_ = w_.masked_fill(mask == 0, float("-inf")) + w_ = F.softmax(w_, dim=2) + + # attend to values + h_ = torch.bmm(w_, v) # (bhw, t, c) + h_ = h_.permute(0, 2, 1).reshape(bhw, c, t) # (bhw, c, t) + + h_ = batch2space(h_, batch_size, height) + h_ = self.proj_out(h_) + return x + h_ + + +class EncoderBase(nn.Module): + def __init__( + self, + in_channels: int, + channels: int, + channels_mult: list[int], + num_res_blocks: int, + attn_resolutions: list[int], + dropout: float, + resolution: int, + z_channels: int, + **ignore_kwargs, + ) -> None: + super().__init__() + self.num_resolutions = len(channels_mult) + self.num_res_blocks = num_res_blocks + + # Patcher. + patch_size = ignore_kwargs.get("patch_size", 1) + self.patcher = Patcher( + patch_size, ignore_kwargs.get("patch_method", "rearrange") + ) + in_channels = in_channels * patch_size * patch_size + + # downsampling + self.conv_in = CausalConv3d( + in_channels, channels, kernel_size=3, stride=1, padding=1 + ) + + # num of groups for GroupNorm, num_groups=1 for LayerNorm. 
+ num_groups = ignore_kwargs.get("num_groups", _LEGACY_NUM_GROUPS) + curr_res = resolution // patch_size + in_ch_mult = (1,) + tuple(channels_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = channels * in_ch_mult[i_level] + block_out = channels * channels_mult[i_level] + for _ in range(self.num_res_blocks): + block.append( + CausalResnetBlock3d( + in_channels=block_in, + out_channels=block_out, + dropout=dropout, + num_groups=num_groups, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(CausalAttnBlock(block_in, num_groups=num_groups)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = CausalDownsample3d(block_in) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = CausalResnetBlock3d( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + num_groups=num_groups, + ) + self.mid.attn_1 = CausalAttnBlock(block_in, num_groups=num_groups) + self.mid.block_2 = CausalResnetBlock3d( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + num_groups=num_groups, + ) + + # end + self.norm_out = CausalNormalize(block_in, num_groups=num_groups) + self.conv_out = CausalConv3d( + block_in, z_channels, kernel_size=3, stride=1, padding=1 + ) + + def patcher3d(self, x: torch.Tensor) -> torch.Tensor: + x, batch_size = time2batch(x) + x = self.patcher(x) + x = batch2time(x, batch_size) + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.patcher3d(x) + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != 
self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + else: + # temporal downsample (last level) + time_factor = 1 + 1 * (hs[-1].shape[2] > 1) + if isinstance(time_factor, torch.Tensor): + time_factor = time_factor.item() + hs[-1] = replication_pad(hs[-1]) + hs.append( + F.avg_pool3d( + hs[-1], + kernel_size=[time_factor, 1, 1], + stride=[2, 1, 1], + ) + ) + + # middle + h = hs[-1] + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class DecoderBase(nn.Module): + def __init__( + self, + out_channels: int, + channels: int, + channels_mult: list[int], + num_res_blocks: int, + attn_resolutions: list[int], + dropout: float, + resolution: int, + z_channels: int, + **ignore_kwargs, + ): + super().__init__() + self.num_resolutions = len(channels_mult) + self.num_res_blocks = num_res_blocks + + # UnPatcher. + patch_size = ignore_kwargs.get("patch_size", 1) + self.unpatcher = UnPatcher( + patch_size, ignore_kwargs.get("patch_method", "rearrange") + ) + out_ch = out_channels * patch_size * patch_size + + block_in = channels * channels_mult[self.num_resolutions - 1] + curr_res = (resolution // patch_size) // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + logging.info( + "Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape) + ) + ) + + # z to block_in + self.conv_in = CausalConv3d( + z_channels, block_in, kernel_size=3, stride=1, padding=1 + ) + + # num of groups for GroupNorm, num_groups=1 for LayerNorm. 
+ num_groups = ignore_kwargs.get("num_groups", _LEGACY_NUM_GROUPS) + + # middle + self.mid = nn.Module() + self.mid.block_1 = CausalResnetBlock3d( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + num_groups=num_groups, + ) + self.mid.attn_1 = CausalAttnBlock(block_in, num_groups=num_groups) + self.mid.block_2 = CausalResnetBlock3d( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + num_groups=num_groups, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = channels * channels_mult[i_level] + for _ in range(self.num_res_blocks + 1): + block.append( + CausalResnetBlock3d( + in_channels=block_in, + out_channels=block_out, + dropout=dropout, + num_groups=num_groups, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(CausalAttnBlock(block_in, num_groups=num_groups)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = CausalUpsample3d(block_in) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = CausalNormalize(block_in, num_groups=num_groups) + self.conv_out = CausalConv3d( + block_in, out_ch, kernel_size=3, stride=1, padding=1 + ) + + def unpatcher3d(self, x: torch.Tensor) -> torch.Tensor: + x, batch_size = time2batch(x) + x = self.unpatcher(x) + x = batch2time(x, batch_size) + + return x + + def forward(self, z): + h = self.conv_in(z) + + # middle block. + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # decoder blocks. 
+ for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + else: + # temporal upsample (last level) + time_factor = 1.0 + 1.0 * (h.shape[2] > 1) + if isinstance(time_factor, torch.Tensor): + time_factor = time_factor.item() + h = h.repeat_interleave(int(time_factor), dim=2) + h = h[..., int(time_factor - 1) :, :, :] + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + h = self.unpatcher3d(h) + return h + + +class EncoderFactorized(nn.Module): + def __init__( + self, + in_channels: int, + channels: int, + channels_mult: list[int], + num_res_blocks: int, + attn_resolutions: list[int], + dropout: float, + resolution: int, + z_channels: int, + spatial_compression: int = 16, + temporal_compression: int = 8, + **ignore_kwargs, + ) -> None: + super().__init__() + self.num_resolutions = len(channels_mult) + self.num_res_blocks = num_res_blocks + + # Patcher. 
+ patch_size = ignore_kwargs.get("patch_size", 1) + self.patcher3d = Patcher3D( + patch_size, ignore_kwargs.get("patch_method", "rearrange") + ) + in_channels = in_channels * patch_size * patch_size * patch_size + + # calculate the number of downsample operations + self.num_spatial_downs = int(math.log2(spatial_compression)) - int( + math.log2(patch_size) + ) + assert ( + self.num_spatial_downs <= self.num_resolutions + ), f"Spatially downsample {self.num_resolutions} times at most" + + self.num_temporal_downs = int(math.log2(temporal_compression)) - int( + math.log2(patch_size) + ) + assert ( + self.num_temporal_downs <= self.num_resolutions + ), f"Temporally downsample {self.num_resolutions} times at most" + + # downsampling + self.conv_in = nn.Sequential( + CausalConv3d( + in_channels, + channels, + kernel_size=(1, 3, 3), + stride=1, + padding=1, + ), + CausalConv3d( + channels, channels, kernel_size=(3, 1, 1), stride=1, padding=0 + ), + ) + + curr_res = resolution // patch_size + in_ch_mult = (1,) + tuple(channels_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = channels * in_ch_mult[i_level] + block_out = channels * channels_mult[i_level] + for _ in range(self.num_res_blocks): + block.append( + CausalResnetBlockFactorized3d( + in_channels=block_in, + out_channels=block_out, + dropout=dropout, + num_groups=1, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append( + nn.Sequential( + CausalAttnBlock(block_in, num_groups=1), + CausalTemporalAttnBlock(block_in, num_groups=1), + ) + ) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + spatial_down = i_level < self.num_spatial_downs + temporal_down = i_level < self.num_temporal_downs + down.downsample = CausalHybridDownsample3d( + block_in, + spatial_down=spatial_down, + temporal_down=temporal_down, + ) + 
curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = CausalResnetBlockFactorized3d( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + num_groups=1, + ) + self.mid.attn_1 = nn.Sequential( + CausalAttnBlock(block_in, num_groups=1), + CausalTemporalAttnBlock(block_in, num_groups=1), + ) + self.mid.block_2 = CausalResnetBlockFactorized3d( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + num_groups=1, + ) + + # end + self.norm_out = CausalNormalize(block_in, num_groups=1) + self.conv_out = nn.Sequential( + CausalConv3d( + block_in, z_channels, kernel_size=(1, 3, 3), stride=1, padding=1 + ), + CausalConv3d( + z_channels, + z_channels, + kernel_size=(3, 1, 1), + stride=1, + padding=0, + ), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.patcher3d(x) + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class DecoderFactorized(nn.Module): + def __init__( + self, + out_channels: int, + channels: int, + channels_mult: list[int], + num_res_blocks: int, + attn_resolutions: list[int], + dropout: float, + resolution: int, + z_channels: int, + spatial_compression: int = 16, + temporal_compression: int = 8, + **ignore_kwargs, + ): + super().__init__() + self.num_resolutions = len(channels_mult) + self.num_res_blocks = num_res_blocks + + # UnPatcher. 
+ patch_size = ignore_kwargs.get("patch_size", 1) + self.unpatcher3d = UnPatcher3D( + patch_size, ignore_kwargs.get("patch_method", "rearrange") + ) + out_ch = out_channels * patch_size * patch_size * patch_size + + # calculate the number of upsample operations + self.num_spatial_ups = int(math.log2(spatial_compression)) - int( + math.log2(patch_size) + ) + assert ( + self.num_spatial_ups <= self.num_resolutions + ), f"Spatially upsample {self.num_resolutions} times at most" + self.num_temporal_ups = int(math.log2(temporal_compression)) - int( + math.log2(patch_size) + ) + assert ( + self.num_temporal_ups <= self.num_resolutions + ), f"Temporally upsample {self.num_resolutions} times at most" + + block_in = channels * channels_mult[self.num_resolutions - 1] + curr_res = (resolution // patch_size) // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + logging.info( + "Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape) + ) + ) + + # z to block_in + self.conv_in = nn.Sequential( + CausalConv3d( + z_channels, block_in, kernel_size=(1, 3, 3), stride=1, padding=1 + ), + CausalConv3d( + block_in, block_in, kernel_size=(3, 1, 1), stride=1, padding=0 + ), + ) + + # middle + self.mid = nn.Module() + self.mid.block_1 = CausalResnetBlockFactorized3d( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + num_groups=1, + ) + self.mid.attn_1 = nn.Sequential( + CausalAttnBlock(block_in, num_groups=1), + CausalTemporalAttnBlock(block_in, num_groups=1), + ) + self.mid.block_2 = CausalResnetBlockFactorized3d( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + num_groups=1, + ) + + legacy_mode = ignore_kwargs.get("legacy_mode", False) + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = channels * channels_mult[i_level] + for _ in range(self.num_res_blocks + 1): + 
block.append( + CausalResnetBlockFactorized3d( + in_channels=block_in, + out_channels=block_out, + dropout=dropout, + num_groups=1, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append( + nn.Sequential( + CausalAttnBlock(block_in, num_groups=1), + CausalTemporalAttnBlock(block_in, num_groups=1), + ) + ) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + # The layer index for temporal/spatial downsampling performed + # in the encoder should correspond to the layer index in + # reverse order where upsampling is performed in the decoder. + # If you've a pre-trained model, you can simply finetune. + i_level_reverse = self.num_resolutions - i_level - 1 + if legacy_mode: + temporal_up = i_level_reverse < self.num_temporal_ups + else: + temporal_up = 0 < i_level_reverse < self.num_temporal_ups + 1 + spatial_up = temporal_up or ( + i_level_reverse < self.num_spatial_ups + and self.num_spatial_ups > self.num_temporal_ups + ) + up.upsample = CausalHybridUpsample3d( + block_in, spatial_up=spatial_up, temporal_up=temporal_up + ) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = CausalNormalize(block_in, num_groups=1) + self.conv_out = nn.Sequential( + CausalConv3d(block_in, out_ch, kernel_size=(1, 3, 3), stride=1, padding=1), + CausalConv3d(out_ch, out_ch, kernel_size=(3, 1, 1), stride=1, padding=0), + ) + + def forward(self, z): + h = self.conv_in(z) + + # middle block. + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # decoder blocks. 
+ for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + h = self.unpatcher3d(h) + return h \ No newline at end of file diff --git a/cosmos/modules/patching.py b/cosmos/modules/patching.py new file mode 100644 index 0000000..97a7333 --- /dev/null +++ b/cosmos/modules/patching.py @@ -0,0 +1,356 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""The patcher and unpatcher implementation for 2D and 3D data. + +The idea of Haar wavelet is to compute LL, LH, HL, HH component as two 1D convolutions. +One on the rows and one on the columns. +For example, in 1D signal, we have [a, b], then the low-freq compoenent is [a + b] / 2 and high-freq is [a - b] / 2. +We can use a 1D convolution with kernel [1, 1] and stride 2 to represent the L component. +For H component, we can use a 1D convolution with kernel [1, -1] and stride 2. +Although in principle, we typically only do additional Haar wavelet over the LL component. But here we do it for all + as we need to support downsampling for more than 2x. 
+For example, 4x downsampling can be done by 2x Haar and additional 2x Haar, and the shape would be. + [3, 256, 256] -> [12, 128, 128] -> [48, 64, 64] +""" + +import torch +import torch.nn.functional as F +from einops import rearrange + +_WAVELETS = { + "haar": torch.tensor([0.7071067811865476, 0.7071067811865476]), + "rearrange": torch.tensor([1.0, 1.0]), +} +_PERSISTENT = False + + +class Patcher(torch.nn.Module): + """A module to convert image tensors into patches using torch operations. + + The main difference from `class Patching` is that this module implements + all operations using torch, rather than python or numpy, for efficiency purpose. + + It's bit-wise identical to the Patching module outputs, with the added + benefit of being torch.jit scriptable. + """ + + def __init__(self, patch_size=1, patch_method="haar"): + super().__init__() + self.patch_size = patch_size + self.patch_method = patch_method + self.register_buffer( + "wavelets", _WAVELETS[patch_method], persistent=_PERSISTENT + ) + self.range = range(int(torch.log2(torch.tensor(self.patch_size)).item())) + self.register_buffer( + "_arange", + torch.arange(_WAVELETS[patch_method].shape[0]), + persistent=_PERSISTENT, + ) + for param in self.parameters(): + param.requires_grad = False + + def forward(self, x): + if self.patch_method == "haar": + return self._haar(x) + elif self.patch_method == "rearrange": + return self._arrange(x) + else: + raise ValueError("Unknown patch method: " + self.patch_method) + + def _dwt(self, x, mode="reflect", rescale=False): + dtype = x.dtype + h = self.wavelets + + n = h.shape[0] + g = x.shape[1] + hl = h.flip(0).reshape(1, 1, -1).repeat(g, 1, 1) + hh = (h * ((-1) ** self._arange)).reshape(1, 1, -1).repeat(g, 1, 1) + hh = hh.to(dtype=dtype) + hl = hl.to(dtype=dtype) + + x = F.pad(x, pad=(n - 2, n - 1, n - 2, n - 1), mode=mode).to(dtype) + xl = F.conv2d(x, hl.unsqueeze(2), groups=g, stride=(1, 2)) + xh = F.conv2d(x, hh.unsqueeze(2), groups=g, stride=(1, 2)) + xll = 
F.conv2d(xl, hl.unsqueeze(3), groups=g, stride=(2, 1)) + xlh = F.conv2d(xl, hh.unsqueeze(3), groups=g, stride=(2, 1)) + xhl = F.conv2d(xh, hl.unsqueeze(3), groups=g, stride=(2, 1)) + xhh = F.conv2d(xh, hh.unsqueeze(3), groups=g, stride=(2, 1)) + + out = torch.cat([xll, xlh, xhl, xhh], dim=1) + if rescale: + out = out / 2 + return out + + def _haar(self, x): + for _ in self.range: + x = self._dwt(x, rescale=True) + return x + + def _arrange(self, x): + x = rearrange( + x, + "b c (h p1) (w p2) -> b (c p1 p2) h w", + p1=self.patch_size, + p2=self.patch_size, + ).contiguous() + return x + + +class Patcher3D(Patcher): + """A 3D discrete wavelet transform for video data, expects 5D tensor, i.e. a batch of videos.""" + + def __init__(self, patch_size=1, patch_method="haar"): + super().__init__(patch_method=patch_method, patch_size=patch_size) + self.register_buffer( + "patch_size_buffer", + patch_size * torch.ones([1], dtype=torch.int32), + persistent=_PERSISTENT, + ) + + def _dwt(self, x, wavelet, mode="reflect", rescale=False): + dtype = x.dtype + h = self.wavelets + + n = h.shape[0] + g = x.shape[1] + hl = h.flip(0).reshape(1, 1, -1).repeat(g, 1, 1) + hh = (h * ((-1) ** self._arange)).reshape(1, 1, -1).repeat(g, 1, 1) + hh = hh.to(dtype=dtype) + hl = hl.to(dtype=dtype) + + # Handles temporal axis. + x = F.pad( + x, pad=(max(0, n - 2), n - 1, n - 2, n - 1, n - 2, n - 1), mode=mode + ).to(dtype) + xl = F.conv3d(x, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)) + xh = F.conv3d(x, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)) + + # Handles spatial axes. 
+ xll = F.conv3d(xl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + xlh = F.conv3d(xl, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + xhl = F.conv3d(xh, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + xhh = F.conv3d(xh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + + xlll = F.conv3d(xll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xllh = F.conv3d(xll, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xlhl = F.conv3d(xlh, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xlhh = F.conv3d(xlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xhll = F.conv3d(xhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xhlh = F.conv3d(xhl, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xhhl = F.conv3d(xhh, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xhhh = F.conv3d(xhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + + out = torch.cat([xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh], dim=1) + if rescale: + out = out / (2 * torch.sqrt(torch.tensor(2.0))) + return out + + def _haar(self, x): + xi, xv = torch.split(x, [1, x.shape[2] - 1], dim=2) + x = torch.cat([xi.repeat_interleave(self.patch_size, dim=2), xv], dim=2) + for _ in self.range: + x = self._dwt(x, "haar", rescale=True) + return x + + def _arrange(self, x): + xi, xv = torch.split(x, [1, x.shape[2] - 1], dim=2) + x = torch.cat([xi.repeat_interleave(self.patch_size, dim=2), xv], dim=2) + x = rearrange( + x, + "b c (t p1) (h p2) (w p3) -> b (c p1 p2 p3) t h w", + p1=self.patch_size, + p2=self.patch_size, + p3=self.patch_size, + ).contiguous() + return x + + +class UnPatcher(torch.nn.Module): + """A module to convert patches into image tensorsusing torch operations. + + The main difference from `class Unpatching` is that this module implements + all operations using torch, rather than python or numpy, for efficiency purpose. 
+ + It's bit-wise identical to the Unpatching module outputs, with the added + benefit of being torch.jit scriptable. + """ + + def __init__(self, patch_size=1, patch_method="haar"): + super().__init__() + self.patch_size = patch_size + self.patch_method = patch_method + self.register_buffer( + "wavelets", _WAVELETS[patch_method], persistent=_PERSISTENT + ) + self.range = range(int(torch.log2(torch.tensor(self.patch_size)).item())) + self.register_buffer( + "_arange", + torch.arange(_WAVELETS[patch_method].shape[0]), + persistent=_PERSISTENT, + ) + for param in self.parameters(): + param.requires_grad = False + + def forward(self, x): + if self.patch_method == "haar": + return self._ihaar(x) + elif self.patch_method == "rearrange": + return self._iarrange(x) + else: + raise ValueError("Unknown patch method: " + self.patch_method) + + def _idwt(self, x, wavelet="haar", mode="reflect", rescale=False): + dtype = x.dtype + h = self.wavelets + n = h.shape[0] + + g = x.shape[1] // 4 + hl = h.flip([0]).reshape(1, 1, -1).repeat([g, 1, 1]) + hh = (h * ((-1) ** self._arange)).reshape(1, 1, -1).repeat(g, 1, 1) + hh = hh.to(dtype=dtype) + hl = hl.to(dtype=dtype) + + xll, xlh, xhl, xhh = torch.chunk(x.to(dtype), 4, dim=1) + + # Inverse transform. 
+ yl = torch.nn.functional.conv_transpose2d( + xll, hl.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0) + ) + yl += torch.nn.functional.conv_transpose2d( + xlh, hh.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0) + ) + yh = torch.nn.functional.conv_transpose2d( + xhl, hl.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0) + ) + yh += torch.nn.functional.conv_transpose2d( + xhh, hh.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0) + ) + y = torch.nn.functional.conv_transpose2d( + yl, hl.unsqueeze(2), groups=g, stride=(1, 2), padding=(0, n - 2) + ) + y += torch.nn.functional.conv_transpose2d( + yh, hh.unsqueeze(2), groups=g, stride=(1, 2), padding=(0, n - 2) + ) + + if rescale: + y = y * 2 + return y + + def _ihaar(self, x): + for _ in self.range: + x = self._idwt(x, "haar", rescale=True) + return x + + def _iarrange(self, x): + x = rearrange( + x, + "b (c p1 p2) h w -> b c (h p1) (w p2)", + p1=self.patch_size, + p2=self.patch_size, + ) + return x + + +class UnPatcher3D(UnPatcher): + """A 3D inverse discrete wavelet transform for video wavelet decompositions.""" + + def __init__(self, patch_size=1, patch_method="haar"): + super().__init__(patch_method=patch_method, patch_size=patch_size) + + def _idwt(self, x, wavelet="haar", mode="reflect", rescale=False): + dtype = x.dtype + h = self.wavelets + n = h.shape[0] + + g = x.shape[1] // 8 # split into 8 spatio-temporal filtered tesnors. + hl = h.flip([0]).reshape(1, 1, -1).repeat([g, 1, 1]) + hh = (h * ((-1) ** self._arange)).reshape(1, 1, -1).repeat(g, 1, 1) + hl = hl.to(dtype=dtype) + hh = hh.to(dtype=dtype) + + xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh = torch.chunk(x, 8, dim=1) + + # Height height transposed convolutions. 
+ xll = F.conv_transpose3d( + xlll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2) + ) + xll += F.conv_transpose3d( + xllh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2) + ) + + xlh = F.conv_transpose3d( + xlhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2) + ) + xlh += F.conv_transpose3d( + xlhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2) + ) + + xhl = F.conv_transpose3d( + xhll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2) + ) + xhl += F.conv_transpose3d( + xhlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2) + ) + + xhh = F.conv_transpose3d( + xhhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2) + ) + xhh += F.conv_transpose3d( + xhhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2) + ) + + # Handles width transposed convolutions. + xl = F.conv_transpose3d( + xll, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1) + ) + xl += F.conv_transpose3d( + xlh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1) + ) + xh = F.conv_transpose3d( + xhl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1) + ) + xh += F.conv_transpose3d( + xhh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1) + ) + + # Handles time axis transposed convolutions. + x = F.conv_transpose3d( + xl, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1) + ) + x += F.conv_transpose3d( + xh, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1) + ) + + if rescale: + x = x * (2 * torch.sqrt(torch.tensor(2.0))) + return x + + def _ihaar(self, x): + for _ in self.range: + x = self._idwt(x, "haar", rescale=True) + x = x[:, :, self.patch_size - 1 :, ...] + return x + + def _iarrange(self, x): + x = rearrange( + x, + "b (c p1 p2 p3) t h w -> b c (t p1) (h p2) (w p3)", + p1=self.patch_size, + p2=self.patch_size, + p3=self.patch_size, + ) + x = x[:, :, self.patch_size - 1 :, ...] 
+ return x \ No newline at end of file diff --git a/cosmos/modules/quantizers.py b/cosmos/modules/quantizers.py new file mode 100644 index 0000000..145ce88 --- /dev/null +++ b/cosmos/modules/quantizers.py @@ -0,0 +1,546 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Quantizers for discrete image and video tokenization.""" + +from typing import Optional + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import reduce +from loguru import logger as logging + +from cosmos.modules.utils import ( + default, + entropy, + pack_one, + rearrange, + round_ste, + unpack_one, +) + + +class ResidualFSQuantizer(nn.Module): + """Residual Finite Scalar Quantization + + Follows Algorithm 1. 
in https://arxiv.org/pdf/2107.03312.pdf + """ + + def __init__(self, levels: list[int], num_quantizers: int, **ignore_kwargs): + super().__init__() + self.dtype = ignore_kwargs.get("dtype", torch.float32) + self.layers = nn.ModuleList( + [FSQuantizer(levels=levels) for _ in range(num_quantizers)] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + indices_stack = [] + residual = x + quantized_out = 0 + loss_out = 0 + for i, layer in enumerate(self.layers): + quant_indices, z, loss = layer(residual) + indices_stack.append(quant_indices) + residual = residual - z.detach() + quantized_out = quantized_out + z + loss_out = loss_out + loss + self.residual = residual + indices = torch.stack(indices_stack, dim=1) + return indices, quantized_out.to(self.dtype), loss_out.to(self.dtype) + + def indices_to_codes(self, indices_stack: torch.Tensor) -> torch.Tensor: + quantized_out = 0 + for layer, indices in zip(self.layers, indices_stack.transpose(0, 1)): + quantized_out += layer.indices_to_codes(indices) + return quantized_out + + +class FSQuantizer(nn.Module): + """Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505 + + Code adapted from Jax version in Appendix A.1. 
+ + Adapted from: https://github.com/lucidrains/vector-quantize-pytorch/blob/9502a1f447876d53fd37685b226bf28f250dc4a3/ + vector_quantize_pytorch/finite_scalar_quantization.py + [Copyright (c) 2020 Phil Wang] + https://github.com/lucidrains/vector-quantize-pytorch/blob/9502a1f447876d53fd37685b226bf28f250dc4a3/LICENSE + """ + + def __init__( + self, + levels: list[int], + dim: Optional[int] = None, + num_codebooks=1, + keep_num_codebooks_dim: Optional[bool] = None, + scale: Optional[float] = None, + **ignore_kwargs, + ): + super().__init__() + self.dtype = ignore_kwargs.get("dtype", torch.bfloat16) + _levels = torch.tensor(levels, dtype=torch.int32) + self.register_buffer("_levels", _levels, persistent=False) + + _basis = torch.cumprod( + torch.tensor([1] + levels[:-1]), dim=0, dtype=torch.int32 + ) + self.register_buffer("_basis", _basis, persistent=False) + + self.scale = scale + + codebook_dim = len(levels) + self.codebook_dim = codebook_dim + + effective_codebook_dim = codebook_dim * num_codebooks + self.num_codebooks = num_codebooks + self.effective_codebook_dim = effective_codebook_dim + + keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1) + assert not (num_codebooks > 1 and not keep_num_codebooks_dim) + self.keep_num_codebooks_dim = keep_num_codebooks_dim + + self.dim = default(dim, len(_levels) * num_codebooks) + + has_projections = self.dim != effective_codebook_dim + self.project_in = ( + nn.Linear(self.dim, effective_codebook_dim) + if has_projections + else nn.Identity() + ) + self.project_out = ( + nn.Linear(effective_codebook_dim, self.dim) + if has_projections + else nn.Identity() + ) + self.has_projections = has_projections + + self.codebook_size = self._levels.prod().item() + + implicit_codebook = self.indices_to_codes( + torch.arange(self.codebook_size), project_out=False + ) + self.register_buffer("implicit_codebook", implicit_codebook, persistent=False) + + def bound(self, z: torch.Tensor, eps: float = 1e-3) -> 
torch.Tensor: + """Bound `z`, an array of shape (..., d).""" + half_l = (self._levels - 1) * (1 + eps) / 2 + offset = torch.where(self._levels % 2 == 0, 0.5, 0.0) + shift = (offset / half_l).atanh() + return (z + shift).tanh() * half_l - offset + + def quantize(self, z: torch.Tensor) -> torch.Tensor: + """Quantizes z, returns quantized zhat, same shape as z.""" + quantized = round_ste(self.bound(z)) + half_width = self._levels // 2 # Renormalize to [-1, 1]. + return quantized / half_width + + def _scale_and_shift(self, zhat_normalized: torch.Tensor) -> torch.Tensor: + half_width = self._levels // 2 + return (zhat_normalized * half_width) + half_width + + def _scale_and_shift_inverse(self, zhat: torch.Tensor) -> torch.Tensor: + half_width = self._levels // 2 + return (zhat - half_width) / half_width + + def codes_to_indices(self, zhat: torch.Tensor) -> torch.Tensor: + """Converts a `code` to an index in the codebook.""" + assert zhat.shape[-1] == self.codebook_dim + zhat = self._scale_and_shift(zhat).float() + return (zhat * self._basis).sum(dim=-1).to(torch.int32) + + def indices_to_codes(self, indices: torch.Tensor, project_out=True) -> torch.Tensor: + """Inverse of `codes_to_indices`.""" + is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim)) + indices = rearrange(indices, "... -> ... 1") + codes_non_centered = (indices // self._basis) % self._levels + codes = self._scale_and_shift_inverse(codes_non_centered) + + if self.keep_num_codebooks_dim: + codes = rearrange(codes, "... c d -> ... (c d)") + + if project_out: + codes = self.project_out(codes) + + if is_img_or_video: + codes = rearrange(codes, "b ... 
class VectorQuantizer(nn.Module):
    """Improved version over VectorQuantizer. Mostly
    avoids costly matrix multiplications and allows for post-hoc remapping of indices.

    Adapted from: https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/
    taming/modules/vqvae/quantize.py

    [Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer]
    https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/License.txt

    Args:
        num_embeddings: Number of codebook entries.
        embedding_dim: Size of each codebook vector.
        beta: Weight applied to one of the two quantization loss terms
            (which one depends on ``legacy``).
        remap: Optional path to a ``.npy`` file listing the codebook indices
            that are in use; indices are remapped onto that list.
        unknown_index: How indices missing from ``remap`` are handled:
            ``"random"``, ``"extra"`` (adds one extra index) or a fixed value.
        sane_index_shape: If True, indices are returned as (batch, h, w).
        legacy: Selects which loss term is multiplied by ``beta``.
        use_norm: If True, L2-normalize latents and codes along the last axis
            before computing the losses.
        **ignore_kwargs: Only ``dtype`` is read (default ``torch.float32``).
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        beta: float = 0.25,
        remap: Optional[str] = None,
        unknown_index: str = "random",
        sane_index_shape: bool = False,
        legacy: bool = True,
        use_norm=False,
        **ignore_kwargs,
    ):
        super().__init__()
        self.n_e = num_embeddings
        self.e_dim = embedding_dim
        self.beta = beta
        self.legacy = legacy
        self.norm = lambda x: F.normalize(x, dim=-1) if use_norm else x

        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

        self.remap = remap
        if self.remap is not None:
            self.register_buffer("used", torch.tensor(np.load(self.remap)))
            self.re_embed = self.used.shape[0]
            self.unknown_index = unknown_index
            if self.unknown_index == "extra":
                # Reserve one additional index for everything not in `used`.
                self.unknown_index = self.re_embed
                self.re_embed = self.re_embed + 1
            print(
                f"Remapping {self.n_e} indices to {self.re_embed} indices. "
                f"Using {self.unknown_index} for unknown indices."
            )
        else:
            self.re_embed = num_embeddings

        self.sane_index_shape = sane_index_shape
        self.dtype = ignore_kwargs.get("dtype", torch.float32)

    def remap_to_used(self, inds):
        """Map raw codebook indices to their position in ``self.used``.

        ``inds`` must have a leading batch axis (at least 2-D). Indices that
        are not present in ``self.used`` become a random index or
        ``self.unknown_index``, depending on the configuration.
        """
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        match = (inds[:, :, None] == used[None, None, ...]).long()
        new = match.argmax(-1)
        unknown = match.sum(2) < 1
        if self.unknown_index == "random":
            new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(
                device=new.device
            )
        else:
            new[unknown] = self.unknown_index
        return new.reshape(ishape)

    def unmap_to_all(self, inds):
        """Inverse of ``remap_to_used``: positions in ``self.used`` -> raw indices.

        NOTE: when an extra token is configured, out-of-range entries are set
        to zero in place on the reshaped ``inds``.
        """
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        if self.re_embed > self.used.shape[0]:  # extra token
            inds[inds >= self.used.shape[0]] = 0  # simply set to zero
        back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
        return back.reshape(ishape)

    def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
        """Quantize ``z`` (b, c, h, w) to its nearest codebook vectors.

        Returns:
            ``(z_q, loss, info)`` where ``z_q`` has the shape of ``z``,
            ``loss`` has shape (b, 1, 1, 1) and ``info`` is the tuple
            ``(indices, None, commit_loss, beta * emb_loss, perplexity)``.
            ``indices`` is flat unless ``sane_index_shape`` is set, in which
            case it is (b, h, w); it is remapped when ``remap`` is configured.
        """
        assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
        assert rescale_logits is False, "Only for interface compatible with Gumbel"
        assert return_logits is False, "Only for interface compatible with Gumbel"
        # (b, c, h, w) -> (b, h, w, c), then one row per spatial position.
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.view(-1, self.e_dim)
        codebook = self.embedding.weight

        # Squared distances via (z - e)^2 = z^2 + e^2 - 2 z.e
        d = (
            torch.sum(z_flattened**2, dim=1, keepdim=True)
            + torch.sum(codebook**2, dim=1)
            - 2 * torch.einsum("bd,dn->bn", z_flattened, codebook.t())
        )

        encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_indices.shape[0], self.n_e, device=z.device)
        encodings.scatter_(1, encoding_indices, 1)
        z_q = torch.matmul(encodings, codebook).view(z.shape)
        min_encodings = None

        z_q, z = self.norm(z_q), self.norm(z)

        # compute loss for embedding
        commit_loss = torch.mean((z_q - z.detach()) ** 2, dim=[1, 2, 3], keepdim=True)
        emb_loss = torch.mean((z_q.detach() - z) ** 2, dim=[1, 2, 3], keepdim=True)
        if not self.legacy:
            loss = self.beta * emb_loss + commit_loss
        else:
            loss = emb_loss + self.beta * commit_loss

        # preserve gradients (straight-through estimator)
        z_q = z + (z_q - z).detach()
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        # reshape back to match original input shape: (b, h, w, c) -> (b, c, h, w)
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        # Always defined, so `sane_index_shape` works without `remap`.
        min_encoding_indices = encoding_indices.squeeze(1)
        if self.remap is not None:
            # remap_to_used needs a batch axis; flatten again afterwards.
            batched = min_encoding_indices.reshape(z.shape[0], -1)
            min_encoding_indices = self.remap_to_used(batched).reshape(-1)

        if self.sane_index_shape:
            min_encoding_indices = min_encoding_indices.reshape(
                z_q.shape[0], z_q.shape[2], z_q.shape[3]
            )

        # TODO: return (indices, z_q, loss)
        return (
            z_q,
            loss,
            (
                min_encoding_indices,
                min_encodings,
                commit_loss.mean().detach(),
                self.beta * emb_loss.mean().detach(),
                perplexity.mean().detach(),
            ),
        )

    def get_codebook_entry(self, indices, shape):
        """Look up codebook vectors for ``indices``.

        Args:
            indices: Codebook indices (remapped ones if ``remap`` is set).
            shape: (batch, height, width, channel) of the desired output
                before it is permuted to (batch, channel, height, width),
                or None to return the raw embedding lookup.
        """
        # shape specifying (batch, height, width, channel)
        if self.remap is not None:
            indices = indices.reshape(shape[0], -1)  # add batch axis
            indices = self.unmap_to_all(indices)
            indices = indices.reshape(-1)  # flatten again

        # get quantized latent vectors
        z_q = self.embedding(indices)

        if shape is not None:
            z_q = z_q.view(shape)
            # reshape back to match original input shape
            z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q
class LFQuantizer(nn.Module):
    """Lookup-Free Quantization

    Adapted from: https://github.com/lucidrains/vector-quantize-pytorch/blob/9502a1f447876d53fd37685b226bf28f250dc4a3/
    vector_quantize_pytorch/lookup_free_quantization.py
    [Copyright (c) 2020 Phil Wang]
    https://github.com/lucidrains/vector-quantize-pytorch/blob/9502a1f447876d53fd37685b226bf28f250dc4a3/LICENSE
    """

    def __init__(
        self,
        *,
        codebook_size: int,
        codebook_dim: int,
        embed_dim: Optional[int] = None,  # if None, use codebook_dim
        entropy_loss_weight=0.1,
        commitment_loss_weight=0.25,
        default_temp: float = 0.01,
        entropy_loss: bool = False,
        **ignore_kwargs,
    ):
        """Lookup-Free Quantization

        Args:
            codebook_size (int): The number of entries in the codebook.
            codebook_dim (int): The number of bits in each code.
            embed_dim (Optional[int], optional): The dimension of the input embedding. Defaults to None.
            entropy_loss_weight (float, optional): Weight for the entropy auxiliary loss. Defaults to 0.1.
            commitment_loss_weight (float, optional): Weight for commitment loss. Defaults to 0.25.
            default_temp (float, optional): The temperature to use. Defaults to 0.01.
            entropy_loss (bool, optional): Flag for entropy loss. Defaults to False.
        """
        super().__init__()
        self.entropy_loss = entropy_loss
        self.codebook_dim = codebook_dim
        self.default_temp = default_temp
        self.entrop_loss_weight = entropy_loss_weight
        self.commitment_loss_weight = commitment_loss_weight
        # forward() splits the feature axis into `num_codebooks` groups and then
        # drops a size-1 codebook axis from the indices, so a single codebook is
        # the only supported layout. Previously this attribute was never set and
        # forward() raised AttributeError.
        self.num_codebooks = 1
        embed_dim = embed_dim or codebook_dim

        has_projections = embed_dim != codebook_dim
        self.project_in = (
            nn.Linear(embed_dim, codebook_dim) if has_projections else nn.Identity()
        )
        self.project_out = (
            nn.Linear(codebook_dim, embed_dim) if has_projections else nn.Identity()
        )
        logging.info(
            f"LFQ: has_projections={has_projections}, dim_in={embed_dim}, codebook_dim={codebook_dim}"
        )

        self.dtype = ignore_kwargs.get("dtype", torch.float32)

        if entropy_loss:
            assert (
                2**codebook_dim == codebook_size
            ), "codebook size must be 2 ** codebook_dim"
            self.codebook_size = codebook_size

            # Bit weights, most significant bit first: [2^(d-1), ..., 2, 1].
            self.register_buffer(
                "mask",
                2 ** torch.arange(codebook_dim - 1, -1, -1),
                persistent=False,
            )
            self.register_buffer("zero", torch.tensor(0.0), persistent=False)

            # Enumerate every code as a vector of -1/+1 bits.
            all_codes = torch.arange(codebook_size)
            bits = ((all_codes[..., None].int() & self.mask) != 0).float()
            codebook = 2 * bits - 1.0

            self.register_buffer(
                "codebook", codebook, persistent=False
            )  # [codebook_size, codebook_dim]

    def forward(self, z: torch.Tensor, temp: float = None) -> tuple:
        """Quantize ``z`` (b, d, ...) to -1/+1 per feature.

        Returns:
            ``(z_q, loss, info)``. ``z_q`` is laid out like ``z`` and ``loss``
            is reshaped to (b, 1, 1, 1). With ``entropy_loss`` enabled,
            ``info`` is ``(indices, commit, entropy_aux, per_sample_entropy,
            codebook_entropy)`` (weighted, detached); otherwise it is the
            weighted, detached commitment loss alone.
        """
        temp = temp or self.default_temp

        z = rearrange(z, "b d ... -> b ... d")
        z, ps = pack_one(z, "b * d")
        z = self.project_in(z)

        # split out number of codebooks
        z = rearrange(z, "b n (c d) -> b n c d", c=self.num_codebooks)

        # quantization
        original_input = z

        codebook_value = torch.ones_like(z)
        z_q = torch.where(z > 0, codebook_value, -codebook_value)

        # preserve gradients
        z_q = z + (z_q - z).detach()

        # commit loss
        commit_loss = ((original_input - z_q.detach()) ** 2).mean(dim=[1, 2, 3])

        z_q = rearrange(z_q, "b n c d -> b n (c d)")
        z_q = self.project_out(z_q)

        # reshape
        z_q = unpack_one(z_q, ps, "b * d")
        z_q = rearrange(z_q, "b ... d -> b d ...")

        loss = self.commitment_loss_weight * commit_loss

        # entropy loss (eq-5)
        if self.entropy_loss:
            # indices: read the sign bits as a binary number
            indices = reduce((z > 0).int() * self.mask.int(), "b n c d -> b n c", "sum")
            indices = unpack_one(indices, ps, "b * c")
            indices = rearrange(indices, "... 1 -> ...")

            distance = -2 * torch.einsum(
                "... i d, j d -> ... i j",
                original_input,
                self.codebook.to(original_input.dtype),
            )
            prob = (-distance / temp).softmax(dim=-1)
            per_sample_entropy = entropy(prob).mean(dim=[1, 2])
            avg_prob = reduce(prob, "... c d -> c d", "mean")
            codebook_entropy = entropy(avg_prob).mean()
            entropy_aux_loss = per_sample_entropy - codebook_entropy

            loss += self.entrop_loss_weight * entropy_aux_loss

            # TODO: return (indices, z_q, loss)
            return (
                z_q,
                loss.unsqueeze(1).unsqueeze(1).unsqueeze(1),
                (
                    indices,
                    self.commitment_loss_weight * commit_loss.mean().detach(),
                    self.entrop_loss_weight * entropy_aux_loss.mean().detach(),
                    self.entrop_loss_weight * per_sample_entropy.mean().detach(),
                    self.entrop_loss_weight * codebook_entropy.mean().detach(),
                ),
            )
        else:
            return (
                z_q,
                loss.unsqueeze(1).unsqueeze(1).unsqueeze(1),
                self.commitment_loss_weight * commit_loss.mean().detach(),
            )


class InvQuantizerJit(nn.Module):
    """Use for decoder_jit to trace quantizer in discrete tokenizer"""

    def __init__(self, quantizer):
        super().__init__()
        # The wrapped quantizer must provide `indices_to_codes` and `dtype`.
        self.quantizer = quantizer

    def forward(self, indices: torch.Tensor):
        """Turn token indices into codes, cast to the quantizer's dtype."""
        codes = self.quantizer.indices_to_codes(indices)
        return codes.to(self.quantizer.dtype)
def time2batch(x: torch.Tensor) -> tuple[torch.Tensor, int]:
    """Fold time into batch: (b, c, t, h, w) -> (b*t, c, h, w), plus ``b``."""
    batch_size = x.shape[0]
    return rearrange(x, "b c t h w -> (b t) c h w"), batch_size


def batch2time(x: torch.Tensor, batch_size: int) -> torch.Tensor:
    """Inverse of ``time2batch``: (b*t, c, h, w) -> (b, c, t, h, w)."""
    return rearrange(x, "(b t) c h w -> b c t h w", b=batch_size)


def space2batch(x: torch.Tensor) -> tuple[torch.Tensor, int, int]:
    """Fold space into batch: (b, c, t, h, w) -> (b*h*w, c, t).

    Returns the folded tensor together with the batch size and the height,
    both of which ``batch2space`` needs to undo the fold.
    """
    batch_size, height = x.shape[0], x.shape[-2]
    return rearrange(x, "b c t h w -> (b h w) c t"), batch_size, height


def batch2space(x: torch.Tensor, batch_size: int, height: int) -> torch.Tensor:
    """Inverse of ``space2batch``: (b*h*w, c, t) -> (b, c, t, h, w)."""
    return rearrange(x, "(b h w) c t -> b c t h w", b=batch_size, h=height)


def cast_tuple(t: Any, length: int = 1) -> tuple:
    """Return ``t`` unchanged if it is a tuple, else ``t`` repeated ``length`` times."""
    return t if isinstance(t, tuple) else ((t,) * length)


def replication_pad(x):
    """Prepend a copy of the first slice along dim 2 (one extra leading frame)."""
    return torch.cat([x[:, :, :1, ...], x], dim=2)


def divisible_by(num: int, den: int) -> bool:
    """Return True when ``num`` is an exact multiple of ``den``."""
    return (num % den) == 0


def is_odd(n: int) -> bool:
    """Return True when ``n`` is not divisible by two."""
    return not divisible_by(n, 2)


def nonlinearity(x):
    """Swish / SiLU activation: ``x * sigmoid(x)``."""
    return x * torch.sigmoid(x)


def Normalize(in_channels, num_groups=32):
    """Build an affine ``GroupNorm`` over ``in_channels`` with eps=1e-6."""
    return torch.nn.GroupNorm(
        num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
    )


class CausalNormalize(torch.nn.Module):
    """GroupNorm that is applied per frame when ``num_groups == 1``."""

    def __init__(self, in_channels, num_groups=1):
        super().__init__()
        self.norm = torch.nn.GroupNorm(
            num_groups=num_groups,
            num_channels=in_channels,
            eps=1e-6,
            affine=True,
        )
        self.num_groups = num_groups

    def forward(self, x):
        # if num_groups !=1, we apply a spatio-temporal groupnorm for backward compatibility purpose.
        # All new models should use num_groups=1, otherwise causality is not guaranteed.
        if self.num_groups == 1:
            # Normalize each frame on its own by folding time into the batch axis.
            x, batch_size = time2batch(x)
            return batch2time(self.norm(x), batch_size)
        return self.norm(x)


def exists(v):
    """Return True when ``v`` is not None."""
    return v is not None


def default(*args):
    """Return the first argument that is not None, or None if there is none."""
    for arg in args:
        if exists(arg):
            return arg
    return None


def pack_one(t, pattern):
    """``einops.pack`` for a single tensor; returns ``(packed, packed_shapes)``."""
    return pack([t], pattern)


def unpack_one(t, ps, pattern):
    """Undo ``pack_one``: unpack with shapes ``ps`` and return the single tensor."""
    return unpack(t, ps, pattern)[0]


def round_ste(z: torch.Tensor) -> torch.Tensor:
    """Round with straight through gradients."""
    zhat = z.round()
    # Forward value is zhat; the gradient flows through z unchanged.
    return z + (zhat - z).detach()


def log(t, eps=1e-5):
    """Natural log of ``t`` with values clamped to at least ``eps``."""
    return t.clamp(min=eps).log()


def entropy(prob):
    """Shannon entropy (in nats) over the last axis of ``prob``."""
    return (-prob * log(prob)).sum(dim=-1)
class TokenizerConfigs(Enum):
    """Default config dict for each tokenizer kind, keyed by its short name."""

    CI = continuous_image_dict
    DI = discrete_image_dict
    CV = continuous_video_dict
    DV = discrete_video_dict


class TokenizerModels(Enum):
    """Tokenizer network class for each tokenizer kind.

    NOTE(review): only CI and DI are listed here, while ``TokenizerConfigs``
    also has CV and DV; looking up a video name in this enum raises KeyError.
    """

    CI = ContinuousImageTokenizer
    DI = DiscreteImageTokenizer
continuous_image = dict(
    # The attention resolution for res blocks.
    attn_resolutions=[32],
    # The base number of channels.
    channels=128,
    # The channel multiplier for each resolution.
    channels_mult=[2, 4, 4],
    dropout=0.0,
    in_channels=3,
    # The spatial compression ratio.
    spatial_compression=16,
    # The number of layers in each res block.
    num_res_blocks=2,
    out_channels=3,
    resolution=1024,
    patch_size=4,
    patch_method="haar",
    # The output latent dimension (channels).
    latent_channels=16,
    # The encoder output channels just before sampling.
    # Which is also the decoder's input channels.
    z_channels=16,
    # A factor over the z_channels, to get the total channels the encoder should output.
    # For a VAE for instance, we want to output the mean and variance, so we need 2 * z_channels.
    z_factor=1,
    name="CI",
    # What formulation to use, either "AE" or "VAE".
    # NOTE(review): the original comment said VAE was chosen because the
    # pre-trained checkpoints use a VAE formulation, but the value set here is
    # AE (with z_factor=1) -- confirm which one is intended.
    formulation=ContinuousFormulation.AE.name,
    # Specify type of encoder ["Default", "LiteVAE"]
    encoder=EncoderType.Default.name,
    # Specify type of decoder ["Default"]
    decoder=DecoderType.Default.name,
)

discrete_image = dict(
    # The attention resolution for res blocks.
    attn_resolutions=[32],
    # The base number of channels.
    channels=128,
    # The channel multiplier for each resolution.
    channels_mult=[2, 4, 4],
    dropout=0.0,
    in_channels=3,
    # The spatial compression ratio.
    spatial_compression=16,
    # The number of layers in each res block.
    num_res_blocks=2,
    out_channels=3,
    resolution=1024,
    patch_size=4,
    patch_method="haar",
    # The encoder output channels just before sampling.
    z_channels=256,
    # A factor over the z_channels, to get the total channels the encoder should output.
    # for discrete tokenization, often we directly use the vector, so z_factor=1.
    z_factor=1,
    # The quantizer of choice, VQ, LFQ, FSQ, or ResFSQ.
    quantizer=DiscreteQuantizer.FSQ.name,
    # The embedding dimension post-quantization, which is also the input channels of the decoder.
    # It equals the number of entries in `levels` below.
    embedding_dim=6,
    # The number of levels per dimension for finite scalar quantization (FSQ).
    levels=[8, 8, 8, 5, 5, 5],
    # The number of quantizers to use for residual finite scalar quantization.
    num_quantizers=4,
    name="DI",
    # Specify type of encoder ["Default", "LiteVAE"]
    encoder=EncoderType.Default.name,
    # Specify type of decoder ["Default"]
    decoder=DecoderType.Default.name,
)

# Video configs: same keys as the image configs above, plus `num_groups`,
# `legacy_mode` and `temporal_compression`.
continuous_video = dict(
    attn_resolutions=[32],
    channels=128,
    channels_mult=[2, 4, 4],
    dropout=0.0,
    in_channels=3,
    num_res_blocks=2,
    out_channels=3,
    resolution=1024,
    patch_size=4,
    patch_method="haar",
    latent_channels=16,
    z_channels=16,
    z_factor=1,
    num_groups=1,
    legacy_mode=False,
    spatial_compression=8,
    temporal_compression=8,
    formulation=ContinuousFormulation.AE.name,
    encoder=Encoder3DType.FACTORIZED.name,
    decoder=Decoder3DType.FACTORIZED.name,
    name="CV",
)

discrete_video = dict(
    attn_resolutions=[32],
    channels=128,
    channels_mult=[2, 4, 4],
    dropout=0.0,
    in_channels=3,
    num_res_blocks=2,
    out_channels=3,
    resolution=1024,
    patch_size=4,
    patch_method="haar",
    z_channels=16,
    z_factor=1,
    num_groups=1,
    legacy_mode=False,
    spatial_compression=16,
    temporal_compression=8,
    quantizer=DiscreteQuantizer.FSQ.name,
    embedding_dim=6,
    levels=[8, 8, 8, 5, 5, 5],
    encoder=Encoder3DType.FACTORIZED.name,
    decoder=Decoder3DType.FACTORIZED.name,
    name="DV",
)
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""The continuous image tokenizer with VAE or AE formulation for 2D data.""" + +from collections import OrderedDict, namedtuple + +import torch +from loguru import logger as logging +from torch import nn + +from cosmos.modules import ( + ContinuousFormulation, + DecoderType, + EncoderType, +) + +NetworkEval = namedtuple("NetworkEval", ["reconstructions", "posteriors", "latent"]) + + +class ContinuousImageTokenizer(nn.Module): + def __init__( + self, z_channels: int, z_factor: int, latent_channels: int, **kwargs + ) -> None: + super().__init__() + self.name = kwargs.get("name", "ContinuousImageTokenizer") + self.latent_channels = latent_channels + + encoder_name = kwargs.get("encoder", EncoderType.Default.name) + self.encoder = EncoderType[encoder_name].value( + z_channels=z_factor * z_channels, **kwargs + ) + + decoder_name = kwargs.get("decoder", DecoderType.Default.name) + self.decoder = DecoderType[decoder_name].value(z_channels=z_channels, **kwargs) + + self.quant_conv = torch.nn.Conv2d( + z_factor * z_channels, z_factor * latent_channels, 1 + ) + self.post_quant_conv = torch.nn.Conv2d(latent_channels, z_channels, 1) + + formulation_name = kwargs.get("formulation", ContinuousFormulation.AE.name) + self.distribution = ContinuousFormulation[formulation_name].value() + logging.info( + f"{self.name} based on {formulation_name} formulation, with {kwargs}." 
+ ) + + num_parameters = sum(param.numel() for param in self.parameters()) + logging.info(f"model={self.name}, num_parameters={num_parameters:,}") + logging.info( + f"z_channels={z_channels}, latent_channels={self.latent_channels}." + ) + + def encoder_jit(self): + return nn.Sequential( + OrderedDict( + [ + ("encoder", self.encoder), + ("quant_conv", self.quant_conv), + ("distribution", self.distribution), + ] + ) + ) + + def decoder_jit(self): + return nn.Sequential( + OrderedDict( + [ + ("post_quant_conv", self.post_quant_conv), + ("decoder", self.decoder), + ] + ) + ) + + def last_decoder_layer(self): + return self.decoder.conv_out + + def encode(self, x): + h = self.encoder(x) + moments = self.quant_conv(h) + return self.distribution(moments) + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input) -> dict[str, torch.Tensor] | NetworkEval: + latent, posteriors = self.encode(input) + dec = self.decode(latent) + if self.training: + return dict(reconstructions=dec, posteriors=posteriors, latent=latent) + return NetworkEval(reconstructions=dec, posteriors=posteriors, latent=latent) \ No newline at end of file diff --git a/cosmos/networks/discrete_image.py b/cosmos/networks/discrete_image.py new file mode 100644 index 0000000..3072576 --- /dev/null +++ b/cosmos/networks/discrete_image.py @@ -0,0 +1,129 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
NetworkEval = namedtuple("NetworkEval", ["reconstructions", "quant_loss", "quant_info"])


class DiscreteImageTokenizer(nn.Module):
    """Image autoencoder with a discrete (quantized) latent.

    The encoder output goes through a 1x1 ``quant_conv`` and the configured
    quantizer; decoding applies a 1x1 ``post_quant_conv`` and the decoder.

    NOTE: ``forward`` unpacks the quantizer result as
    ``(quant_info, quant_codes, quant_loss)``, so the configured quantizer
    must return its outputs in that order.
    """

    def __init__(self, z_channels: int, embedding_dim: int, **kwargs) -> None:
        super().__init__()
        self.name = kwargs.get("name", "DiscreteImageTokenizer")
        self.embedding_dim = embedding_dim

        encoder_name = kwargs.get("encoder", EncoderType.Default.name)
        self.encoder = EncoderType[encoder_name].value(z_channels=z_channels, **kwargs)

        decoder_name = kwargs.get("decoder", DecoderType.Default.name)
        self.decoder = DecoderType[decoder_name].value(z_channels=z_channels, **kwargs)
        self.quant_conv = nn.Conv2d(z_channels, embedding_dim, 1)
        self.post_quant_conv = nn.Conv2d(embedding_dim, z_channels, 1)

        quantizer_name = kwargs.get("quantizer", DiscreteQuantizer.RESFSQ.name)
        # Keyword arguments each quantizer cannot be built without.
        required_kwargs = {
            DiscreteQuantizer.VQ.name: ("num_embeddings",),
            DiscreteQuantizer.LFQ.name: ("codebook_size", "codebook_dim"),
            DiscreteQuantizer.FSQ.name: ("levels",),
            DiscreteQuantizer.RESFSQ.name: ("levels", "num_quantizers"),
        }
        for key in required_kwargs.get(quantizer_name, ()):
            assert (
                key in kwargs
            ), f"`{key}` must be provided for {quantizer_name}."
        if quantizer_name == DiscreteQuantizer.VQ.name:
            # `embedding_dim` is a named parameter here, so it has to be put
            # back into kwargs before they are forwarded to the quantizer.
            kwargs.update(dict(embedding_dim=embedding_dim))
        self.quantizer = DiscreteQuantizer[quantizer_name].value(**kwargs)
        logging.info(f"{self.name} based on {quantizer_name}-VAE, with {kwargs}.")

        num_parameters = sum(param.numel() for param in self.parameters())
        logging.info(f"model={self.name}, num_parameters={num_parameters:,}")
        logging.info(f"z_channels={z_channels}, embedding_dim={self.embedding_dim}.")

    def to(self, *args, **kwargs):
        """Move/cast the module and keep ``quantizer.dtype`` in sync.

        The target dtype is taken from the ``dtype`` keyword, or from a
        positional ``torch.dtype`` / tensor argument. When the call carries no
        dtype at all, the historical default of ``torch.bfloat16`` is kept.
        """
        dtype = kwargs.get("dtype")
        if dtype is None:
            for arg in args:
                if isinstance(arg, torch.dtype):
                    dtype = arg
                    break
                if isinstance(arg, torch.Tensor):
                    dtype = arg.dtype
                    break
        if dtype is None:
            dtype = torch.bfloat16
        setattr(self.quantizer, "dtype", dtype)
        return super(DiscreteImageTokenizer, self).to(*args, **kwargs)

    def encoder_jit(self):
        """Return the encode path as a single ``nn.Sequential``."""
        return nn.Sequential(
            OrderedDict(
                [
                    ("encoder", self.encoder),
                    ("quant_conv", self.quant_conv),
                    ("quantizer", self.quantizer),
                ]
            )
        )

    def decoder_jit(self):
        """Return the decode path (indices -> image) as an ``nn.Sequential``."""
        return nn.Sequential(
            OrderedDict(
                [
                    ("inv_quant", InvQuantizerJit(self.quantizer)),
                    ("post_quant_conv", self.post_quant_conv),
                    ("decoder", self.decoder),
                ]
            )
        )

    def last_decoder_layer(self):
        """Return the decoder's output convolution."""
        return self.decoder.conv_out

    def encode(self, x):
        """Encode ``x`` and return the quantizer's output unchanged."""
        h = self.encoder(x)
        h = self.quant_conv(h)
        return self.quantizer(h)

    def decode(self, quant):
        """Decode quantized codes back to image space."""
        quant = self.post_quant_conv(quant)
        return self.decoder(quant)

    def decode_code(self, code_b):
        """Decode token indices by first converting them to codes."""
        quant_b = self.quantizer.indices_to_codes(code_b)
        quant_b = self.post_quant_conv(quant_b)
        return self.decoder(quant_b)

    def forward(self, input):
        """Run encode + decode.

        Returns a dict in training mode and a ``NetworkEval`` namedtuple
        otherwise, both with ``reconstructions``, ``quant_loss``, ``quant_info``.
        """
        quant_info, quant_codes, quant_loss = self.encode(input)
        reconstructions = self.decode(quant_codes)
        if self.training:
            return dict(
                reconstructions=reconstructions,
                quant_loss=quant_loss,
                quant_info=quant_info,
            )
        return NetworkEval(
            reconstructions=reconstructions,
            quant_loss=quant_loss,
            quant_info=quant_info,
        )
diff --git a/cosmos/utils.py b/cosmos/utils.py new file mode 100644 index 0000000..523c051 --- /dev/null +++ b/cosmos/utils.py @@ -0,0 +1,408 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utility functions for the inference libraries.""" + +import os +from glob import glob +from typing import Any + +import mediapy as media +import numpy as np +import torch +from PIL import Image + +from cosmos.networks import TokenizerModels + +_DTYPE, _DEVICE = torch.bfloat16, "cuda" +_UINT8_MAX_F = float(torch.iinfo(torch.uint8).max) +_SPATIAL_ALIGN = 16 +_TEMPORAL_ALIGN = 8 + + +def load_model( + jit_filepath: str = None, + tokenizer_config: dict[str, Any] = None, + device: str = "cuda", +) -> torch.nn.Module | torch.jit.ScriptModule: + """Loads a torch.nn.Module from a filepath. + + Args: + jit_filepath: The filepath to the JIT-compiled model. + device: The device to load the model onto, default=cuda. + Returns: + The JIT compiled model loaded to device and on eval mode. 
+ """ + if tokenizer_config is None: + return load_jit_model(jit_filepath, device) + full_model, ckpts = _load_pytorch_model(jit_filepath, tokenizer_config, device) + full_model.load_state_dict(ckpts.state_dict(), strict=False) + return full_model.eval().to(device) + + +def load_encoder_model( + jit_filepath: str = None, + tokenizer_config: dict[str, Any] = None, + device: str = "cuda", +) -> torch.nn.Module | torch.jit.ScriptModule: + """Loads a torch.nn.Module from a filepath. + + Args: + jit_filepath: The filepath to the JIT-compiled model. + device: The device to load the model onto, default=cuda. + Returns: + The JIT compiled model loaded to device and on eval mode. + """ + if tokenizer_config is None: + return load_jit_model(jit_filepath, device) + full_model, ckpts = _load_pytorch_model(jit_filepath, tokenizer_config, device) + encoder_model = full_model.encoder_jit() + encoder_model.load_state_dict(ckpts.state_dict(), strict=False) + return encoder_model.eval().to(device) + + +def load_decoder_model( + jit_filepath: str = None, + tokenizer_config: dict[str, Any] = None, + device: str = "cuda", +) -> torch.nn.Module | torch.jit.ScriptModule: + """Loads a torch.nn.Module from a filepath. + + Args: + jit_filepath: The filepath to the JIT-compiled model. + device: The device to load the model onto, default=cuda. + Returns: + The JIT compiled model loaded to device and on eval mode. + """ + if tokenizer_config is None: + return load_jit_model(jit_filepath, device) + full_model, ckpts = _load_pytorch_model(jit_filepath, tokenizer_config, device) + decoder_model = full_model.decoder_jit() + decoder_model.load_state_dict(ckpts.state_dict(), strict=False) + return decoder_model.eval().to(device) + + +def _load_pytorch_model( + jit_filepath: str = None, tokenizer_config: str = None, device: str = "cuda" +) -> torch.nn.Module: + """Loads a torch.nn.Module from a filepath. + + Args: + jit_filepath: The filepath to the JIT-compiled model. 
+ device: The device to load the model onto, default=cuda. + Returns: + The JIT compiled model loaded to device and on eval mode. + """ + tokenizer_name = tokenizer_config["name"] + model = TokenizerModels[tokenizer_name].value(**tokenizer_config) + ckpts = torch.jit.load(jit_filepath) + return model, ckpts + + +def load_jit_model( + jit_filepath: str = None, device: str = "cuda" +) -> torch.jit.ScriptModule: + """Loads a torch.jit.ScriptModule from a filepath. + + Args: + jit_filepath: The filepath to the JIT-compiled model. + device: The device to load the model onto, default=cuda. + Returns: + The JIT compiled model loaded to device and on eval mode. + """ + model = torch.jit.load(jit_filepath) + return model.eval().to(device) + + +def save_jit_model( + model: torch.jit.ScriptModule | torch.jit.RecursiveScriptModule = None, + jit_filepath: str = None, +) -> None: + """Saves a torch.jit.ScriptModule or torch.jit.RecursiveScriptModule to file. + + Args: + model: JIT compiled model loaded onto `config.checkpoint.jit.device`. + jit_filepath: The filepath to the JIT-compiled model. + """ + torch.jit.save(model, jit_filepath) + + +def get_filepaths(input_pattern) -> list[str]: + """Returns a list of filepaths from a pattern.""" + filepaths = sorted(glob(str(input_pattern))) + return list(set(filepaths)) + + +def get_output_filepath(filepath: str, output_dir: str = None) -> str: + """Returns the output filepath for the given input filepath.""" + output_dir = output_dir or f"{os.path.dirname(filepath)}/reconstructions" + output_filepath = f"{output_dir}/{os.path.basename(filepath)}" + os.makedirs(output_dir, exist_ok=True) + return output_filepath + + +def read_image(filepath: str) -> np.ndarray: + """Reads an image from a filepath. + + Args: + filepath: The filepath to the image. + + Returns: + The image as a numpy array, layout HxWxC, range [0..255], uint8 dtype. 
+ """ + image = media.read_image(filepath) + # convert the grey scale image to RGB + # since our tokenizers always assume 3-channel RGB image + if image.ndim == 2: + image = np.stack([image] * 3, axis=-1) + # convert RGBA to RGB + if image.shape[-1] == 4: + image = image[..., :3] + return image + + +def read_video(filepath: str) -> np.ndarray: + """Reads a video from a filepath. + + Args: + filepath: The filepath to the video. + Returns: + The video as a numpy array, layout TxHxWxC, range [0..255], uint8 dtype. + """ + video = media.read_video(filepath) + # convert the grey scale frame to RGB + # since our tokenizers always assume 3-channel video + if video.ndim == 3: + video = np.stack([video] * 3, axis=-1) + # convert RGBA to RGB + if video.shape[-1] == 4: + video = video[..., :3] + return video + + +def resize_image(image: np.ndarray, short_size: int = None) -> np.ndarray: + """Resizes an image to have the short side of `short_size`. + + Args: + image: The image to resize, layout HxWxC, of any range. + short_size: The size of the short side. + Returns: + The resized image. + """ + if short_size is None: + return image + height, width = image.shape[-3:-1] + if height <= width: + height_new, width_new = short_size, int(width * short_size / height + 0.5) + width_new = width_new if width_new % 2 == 0 else width_new + 1 + else: + height_new, width_new = ( + int(height * short_size / width + 0.5), + short_size, + ) + height_new = height_new if height_new % 2 == 0 else height_new + 1 + return media.resize_image(image, shape=(height_new, width_new)) + + +def resize_video(video: np.ndarray, short_size: int = None) -> np.ndarray: + """Resizes a video to have the short side of `short_size`. + + Args: + video: The video to resize, layout TxHxWxC, of any range. + short_size: The size of the short side. + Returns: + The resized video. 
+ """ + if short_size is None: + return video + height, width = video.shape[-3:-1] + if height <= width: + height_new, width_new = short_size, int(width * short_size / height + 0.5) + width_new = width_new if width_new % 2 == 0 else width_new + 1 + else: + height_new, width_new = ( + int(height * short_size / width + 0.5), + short_size, + ) + height_new = height_new if height_new % 2 == 0 else height_new + 1 + return media.resize_video(video, shape=(height_new, width_new)) + + +def write_image(filepath: str, image: np.ndarray): + """Writes an image to a filepath.""" + return media.write_image(filepath, image) + + +def write_video(filepath: str, video: np.ndarray, fps: int = 24) -> None: + """Writes a video to a filepath.""" + return media.write_video(filepath, video, fps=fps) + + +def numpy2tensor( + input_image: np.ndarray, + dtype: torch.dtype = _DTYPE, + device: str = _DEVICE, + range_min: int = -1, +) -> torch.Tensor: + """Converts image(dtype=np.uint8) to `dtype` in range [0..255]. + + Args: + input_image: A batch of images in range [0..255], BxHxWx3 layout. + Returns: + A torch.Tensor of layout Bx3xHxW in range [-1..1], dtype. + """ + ndim = input_image.ndim + indices = list(range(1, ndim))[-1:] + list(range(1, ndim))[:-1] + image = input_image.transpose((0,) + tuple(indices)) / _UINT8_MAX_F + if range_min == -1: + image = 2.0 * image - 1.0 + return torch.from_numpy(image).to(dtype).to(device) + + +def tensor2numpy(input_tensor: torch.Tensor, range_min: int = -1) -> np.ndarray: + """Converts tensor in [-1,1] to image(dtype=np.uint8) in range [0..255]. + + Args: + input_tensor: Input image tensor of Bx3xHxW layout, range [-1..1]. + Returns: + A numpy image of layout BxHxWx3, range [0..255], uint8 dtype. 
+ """ + if range_min == -1: + input_tensor = (input_tensor.float() + 1.0) / 2.0 + ndim = input_tensor.ndim + output_image = input_tensor.clamp(0, 1).cpu().numpy() + output_image = output_image.transpose((0,) + tuple(range(2, ndim)) + (1,)) + return (output_image * _UINT8_MAX_F + 0.5).astype(np.uint8) + + +def pad_image_batch( + batch: np.ndarray, spatial_align: int = _SPATIAL_ALIGN +) -> tuple[np.ndarray, list[int]]: + """Pads a batch of images to be divisible by `spatial_align`. + + Args: + batch: The batch of images to pad, layout BxHxWx3, in any range. + align: The alignment to pad to. + Returns: + The padded batch and the crop region. + """ + height, width = batch.shape[1:3] + align = spatial_align + height_to_pad = (align - height % align) if height % align != 0 else 0 + width_to_pad = (align - width % align) if width % align != 0 else 0 + + crop_region = [ + height_to_pad >> 1, + width_to_pad >> 1, + height + (height_to_pad >> 1), + width + (width_to_pad >> 1), + ] + batch = np.pad( + batch, + ( + (0, 0), + (height_to_pad >> 1, height_to_pad - (height_to_pad >> 1)), + (width_to_pad >> 1, width_to_pad - (width_to_pad >> 1)), + (0, 0), + ), + mode="constant", + ) + return batch, crop_region + + +def pad_video_batch( + batch: np.ndarray, + temporal_align: int = _TEMPORAL_ALIGN, + spatial_align: int = _SPATIAL_ALIGN, +) -> tuple[np.ndarray, list[int]]: + """Pads a batch of videos to be divisible by `temporal_align` or `spatial_align`. + + Zero pad spatially. Reflection pad temporally to handle causality better. + Args: + batch: The batch of videos to pad., layout BxFxHxWx3, in any range. + align: The alignment to pad to. + Returns: + The padded batch and the crop region. 
+ """ + num_frames, height, width = batch.shape[-4:-1] + align = spatial_align + height_to_pad = (align - height % align) if height % align != 0 else 0 + width_to_pad = (align - width % align) if width % align != 0 else 0 + + align = temporal_align + frames_to_pad = ( + (align - (num_frames - 1) % align) if (num_frames - 1) % align != 0 else 0 + ) + + crop_region = [ + frames_to_pad >> 1, + height_to_pad >> 1, + width_to_pad >> 1, + num_frames + (frames_to_pad >> 1), + height + (height_to_pad >> 1), + width + (width_to_pad >> 1), + ] + batch = np.pad( + batch, + ( + (0, 0), + (0, 0), + (height_to_pad >> 1, height_to_pad - (height_to_pad >> 1)), + (width_to_pad >> 1, width_to_pad - (width_to_pad >> 1)), + (0, 0), + ), + mode="constant", + ) + batch = np.pad( + batch, + ( + (0, 0), + (frames_to_pad >> 1, frames_to_pad - (frames_to_pad >> 1)), + (0, 0), + (0, 0), + (0, 0), + ), + mode="edge", + ) + return batch, crop_region + + +def unpad_video_batch(batch: np.ndarray, crop_region: list[int]) -> np.ndarray: + """Unpads video with `crop_region`. + + Args: + batch: A batch of numpy videos, layout BxFxHxWxC. + crop_region: [f1,y1,x1,f2,y2,x2] first, top, left, last, bot, right crop indices. + + Returns: + np.ndarray: Cropped numpy video, layout BxFxHxWxC. + """ + assert len(crop_region) == 6, "crop_region should be len of 6." + f1, y1, x1, f2, y2, x2 = crop_region + return batch[..., f1:f2, y1:y2, x1:x2, :] + + +def unpad_image_batch(batch: np.ndarray, crop_region: list[int]) -> np.ndarray: + """Unpads image with `crop_region`. + + Args: + batch: A batch of numpy images, layout BxHxWxC. + crop_region: [y1,x1,y2,x2] top, left, bot, right crop indices. + + Returns: + np.ndarray: Cropped numpy image, layout BxHxWxC. + """ + assert len(crop_region) == 4, "crop_region should be len of 4." 
+ y1, x1, y2, x2 = crop_region + return batch[..., y1:y2, x1:x2, :] \ No newline at end of file diff --git a/dit.py b/dit.py new file mode 100644 index 0000000..894d168 --- /dev/null +++ b/dit.py @@ -0,0 +1,343 @@ +import torch +import torch.nn as nn +import numpy as np +import math +from timm.models.vision_transformer import PatchEmbed, Attention, Mlp +import click + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + +class TimestepEmbedder(nn.Module): + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + + +class LabelEmbedder(nn.Module): + def __init__(self, num_classes, hidden_size, dropout_prob): + super().__init__() + use_cfg_embedding = dropout_prob > 0 + self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) + self.num_classes = num_classes + self.dropout_prob = dropout_prob + + def token_drop(self, labels, force_drop_ids=None): + if force_drop_ids is None: + drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob + else: + drop_ids = force_drop_ids == 1 + labels = 
torch.where(drop_ids, self.num_classes, labels) + return labels + + def forward(self, labels, train, force_drop_ids=None): + use_dropout = self.dropout_prob > 0 + if (train and use_dropout) or (force_drop_ids is not None): + labels = self.token_drop(labels, force_drop_ids) + embeddings = self.embedding_table(labels) + return embeddings + + +class SiTBlock(nn.Module): + """ + A single Transformer block that optionally accepts a skip tensor. + If skip=True, we learn a linear projection over the concatenation of the current features x and skip. + """ + def __init__( + self, + hidden_size, + num_heads, + mlp_ratio=4.0, + qkv_bias=True, + norm_layer=nn.LayerNorm, + act_layer=nn.GELU, + skip=False, + use_checkpoint=False, + ): + super().__init__() + self.norm1 = norm_layer(hidden_size, eps=1e-6) + self.attn = Attention( + dim=hidden_size, + num_heads=num_heads, + qkv_bias=qkv_bias, + ) + self.norm2 = norm_layer(hidden_size, eps=1e-6) + + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp = Mlp( + in_features=hidden_size, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=0.0, + ) + + # For injecting time or label embeddings (AdaLayerNorm style) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + # Skip connection logic + self.skip_linear = nn.Linear(2 * hidden_size, hidden_size) if skip else None + self.use_checkpoint = use_checkpoint + + def forward(self, x, c, skip=None): + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, c, skip) + else: + return self._forward(x, c, skip) + + def _forward(self, x, c, skip=None): + # If skip_linear exists, we do "concat + linear" just like the paper + if self.skip_linear is not None and skip is not None: + x = self.skip_linear(torch.cat([x, skip], dim=-1)) + + # AdaLayerNorm modulations from c + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) + + # --- Attention 
path --- + x_attn_normed = modulate(self.norm1(x), shift_msa, scale_msa) + x_attn = self.attn(x_attn_normed) + x = x + gate_msa.unsqueeze(1) * x_attn + + # --- MLP path --- + x_mlp_normed = modulate(self.norm2(x), shift_mlp, scale_mlp) + x_mlp = self.mlp(x_mlp_normed) + x = x + gate_mlp.unsqueeze(1) * x_mlp + + return x + + +class FinalLayer(nn.Module): + def __init__(self, hidden_size, patch_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, 2 * hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + +class SiT(nn.Module): + """ + A UViT-like refactor of your SiT model: + - Split 'depth' into in-blocks, a single mid-block, and out-blocks + - Skip-connections from in-block outputs to out-blocks + """ + def __init__( + self, + input_size=32, + patch_size=2, + in_channels=4, + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4.0, + class_dropout_prob=0.1, + num_classes=1000, + learn_sigma=False + ): + super().__init__() + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels * 2 if learn_sigma else in_channels + + # Number of 'in blocks' and 'out blocks' (like U-Net's encoder/decoder) + # We'll reserve 1 block for the 'mid_block' in the center. 
+ in_depth = depth // 2 + out_depth = depth - in_depth - 1 + + # Patch embedding + self.x_embedder = PatchEmbed( + img_size=input_size, + patch_size=patch_size, + in_chans=in_channels, + embed_dim=hidden_size + ) + # Timestep + label + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob) + + # Precompute positional embeddings + self.num_patches = self.x_embedder.num_patches + self.pos_embed = nn.Parameter( + torch.zeros(1, self.num_patches, hidden_size), + requires_grad=False + ) + + # In-blocks (encoder) + self.in_blocks = nn.ModuleList([ + SiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, skip=False) + for _ in range(in_depth) + ]) + # Mid-block + self.mid_block = SiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, skip=False) + # Out-blocks (decoder), each with skip=True + self.out_blocks = nn.ModuleList([ + SiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, skip=True) + for _ in range(out_depth) + ]) + + # Final prediction layer + self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels) + + # Initialize + self.initialize_weights() + + def initialize_weights(self): + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + # Initialize (and freeze) pos_embed by sin-cos embedding + pos_embed = get_2d_sincos_pos_embed( + self.pos_embed.shape[-1], + int(self.num_patches**0.5) + ) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + + # Initialize label embedding table + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + + # Initialize TimestepEmbedder MLP + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in all blocks + for block in list(self.in_blocks) + [self.mid_block] 
+ list(self.out_blocks): + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out final layer + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def unpatchify(self, x): + c = self.out_channels + p = self.x_embedder.patch_size[0] + h = w = int(x.shape[1] ** 0.5) + assert h * w == x.shape[1] + + x = x.view(x.shape[0], h, w, p, p, c) + x = x.permute(0, 5, 1, 3, 2, 4) # N, c, h, p, w, p + imgs = x.reshape(x.shape[0], c, h * p, w * p) + return imgs + + def forward(self, x, t, y): + # 1. Patch Embedding + Pos Embedding + x = self.x_embedder(x) + self.pos_embed # (N, T, D) + # 2. Timestep Embedding + t = self.t_embedder(t) # (N, D) + # 3. Label Embedding (with optional dropout for classifier-free guidance) + y = self.y_embedder(y, self.training) # (N, D) + c = t + y # Combined condition embedding + + # ============ Encoder (in_blocks) ============ + skips = [] + for blk in self.in_blocks: + x = blk(x, c) # no skip yet + skips.append(x) + + # ============ Mid-block ============ + x = self.mid_block(x, c) + + # ============ Decoder (out_blocks) ============ + # pop from the 'skips' list to feed into out-blocks + for blk in self.out_blocks: + skip_x = skips.pop() # last in-block output + x = blk(x, c, skip=skip_x) + + # ============ Final Prediction + Unpatchify ============ + x = self.final_layer(x, c) + x = self.unpatchify(x) + if self.learn_sigma: + x, _ = x.chunk(2, dim=1) + return x + + def forward_with_cfg(self, x, t, y, cfg_scale): + """ + Classifier-free guidance pass. Similar to your original logic. 
+ """ + half = x[: len(x) // 2] + combined = torch.cat([half, half], dim=0) + model_out = self.forward(combined, t, y) + eps, rest = model_out[:, :3], model_out[:, 3:] + cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) + half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) + eps = torch.cat([half_eps, half_eps], dim=0) + return torch.cat([eps, rest], dim=1) + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + emb = np.concatenate([get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]), get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) ], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + assert embed_dim % 2 == 0 + omega = 1. 
/ (10000 ** ((np.arange(embed_dim // 2, dtype=np.float64)) / (embed_dim / 2.))) + out = np.einsum('m,d->md', pos.reshape(-1), omega) + emb = np.concatenate([np.sin(out), np.cos(out) ], axis=1) + return emb + +################################################################################# +# SiT Configs # +################################################################################# + +SiT_models = { + 'SiT-XL/2': lambda **kwargs:SiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs) , 'SiT-XL/4': lambda **kwargs:SiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs) , 'SiT-XL/8': lambda **kwargs:SiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs) , + 'SiT-L/2': lambda **kwargs:SiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs) , 'SiT-L/4': lambda **kwargs:SiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs) , 'SiT-L/8': lambda **kwargs:SiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs) , + 'SiT-B/2': lambda **kwargs:SiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs) , 'SiT-B/4': lambda **kwargs:SiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs) , 'SiT-B/8': lambda **kwargs:SiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs) , + 'SiT-S/2': lambda **kwargs:SiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs) , 'SiT-S/4': lambda **kwargs:SiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs) , 'SiT-S/8': lambda **kwargs:SiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs) , + 'SiT-T/8': lambda **kwargs:SiT(depth=6, hidden_size=192, patch_size=8, num_heads=6, **kwargs) +} diff --git a/eval.py b/eval.py new file mode 100644 index 0000000..96ac372 --- /dev/null +++ b/eval.py @@ -0,0 +1,95 @@ +import torch +from tqdm import tqdm +from sklearn.metrics.pairwise import polynomial_kernel +import numpy as np +import torch.distributed as dist +def compute_mmd(feat_real, 
feat_gen, n_subsets=100, subset_size=1000, **kernel_args): + m = min(feat_real.shape[0], feat_gen.shape[0]) + subset_size = min(subset_size, m) + mmds = np.zeros(n_subsets) + choice = np.random.choice + + # with range(n_subsets) as bar: + for i in range(n_subsets): + g = feat_real[choice(len(feat_real), subset_size, replace=False)] + r = feat_gen[choice(len(feat_gen), subset_size, replace=False)] + o = compute_polynomial_mmd(g, r, **kernel_args) + mmds[i] = o + # bar.set_postfix({'mean': mmds[:i+1].mean()}) + return mmds + + +def compute_polynomial_mmd(feat_r, feat_gen, degree=3, gamma=None, coef0=1): + # use k(x, y) = (gamma + coef0)^degree + # default gamma is 1 / dim + X = feat_r + Y = feat_gen + + K_XX = polynomial_kernel(X, degree=degree, gamma=gamma, coef0=coef0) + K_YY = polynomial_kernel(Y, degree=degree, gamma=gamma, coef0=coef0) + K_XY = polynomial_kernel(X, Y, degree=degree, gamma=gamma, coef0=coef0) + + return _mmd2_and_variance(K_XX, K_XY, K_YY) + + +def _mmd2_and_variance(K_XX, K_XY, K_YY): + # based on https://github.com/dougalsutherland/opt-mmd/blob/master/two_sample/mmd.py + # but changed to not compute the full kernel matrix at once + m = K_XX.shape[0] + assert K_XX.shape == (m, m) + assert K_XY.shape == (m, m) + assert K_YY.shape == (m, m) + + diag_X = np.diagonal(K_XX) + diag_Y = np.diagonal(K_YY) + + Kt_XX_sums = K_XX.sum(axis=1) - diag_X + Kt_YY_sums = K_YY.sum(axis=1) - diag_Y + K_XY_sums_0 = K_XY.sum(axis=0) + + Kt_XX_sum = Kt_XX_sums.sum() + Kt_YY_sum = Kt_YY_sums.sum() + K_XY_sum = K_XY_sums_0.sum() + + mmd2 = (Kt_XX_sum + Kt_YY_sum) / (m * (m-1)) + mmd2 -= 2 * K_XY_sum / (m * m) + return mmd2 + +class Eval: + def __init__(self): + # initialize dinov2 model + from transformers import AutoModel, AutoImageProcessor + local_rank = torch.distributed.get_rank() + if local_rank == 0: + self.processor = AutoImageProcessor.from_pretrained('facebook/dinov2-large') + self.model = AutoModel.from_pretrained('facebook/dinov2-large').bfloat16() + 
dist.barrier() + if local_rank != 0: + self.processor = AutoImageProcessor.from_pretrained('facebook/dinov2-large') + self.model = AutoModel.from_pretrained('facebook/dinov2-large').bfloat16() + dist.barrier() + self.model.to(f"cuda:{local_rank}") + self.load_precomputed_features() + + def load_precomputed_features(self, dataset_path="./inet/dinov2_inet_feats.pt"): + self.precomputed_features = torch.load(dataset_path).float().cpu().numpy() + + @torch.no_grad() + def eval(self, val_images): + """ + val_images: Tensor of shape (N, 3, 256, 256) , from 0 to 255, in uint8 + """ + # preprocess the images + inputs = self.processor(images=val_images, return_tensors="pt").to(self.model.device) + # forward pass + outputs = self.model(**inputs) + # get the embeddings + embeddings = outputs.pooler_output + # normalize the embeddings + embeddings = embeddings / embeddings.norm(dim=1, keepdim=True) + embeddings = embeddings.float().cpu().numpy() + + # compute the mmd + mmd = compute_mmd(self.precomputed_features, embeddings) + mmd = max(0, mmd.mean()) + return mmd diff --git a/generate_gpt.py b/generate_gpt.py index 178eb15..27bd8ec 100644 --- a/generate_gpt.py +++ b/generate_gpt.py @@ -4,7 +4,7 @@ import numpy as np from PIL import Image import os -from cosmos_tokenizer.image_lib import ImageTokenizer +from cosmos.image_lib import ImageTokenizer @click.command() diff --git a/setup.sh b/setup.sh new file mode 100644 index 0000000..cb675fd --- /dev/null +++ b/setup.sh @@ -0,0 +1,6 @@ +mkdir inet +mkdir cosmos_ckpt +wget https://huggingface.co/nvidia/Cosmos-Tokenizer-CI8x8/resolve/main/decoder.jit -O cosmos_ckpt/decoder.jit +wget https://huggingface.co/datasets/fal/cosmos-imagenet/resolve/main/imagenet_ci8x8.safetensors -O inet/imagenet_ci8x8.safetensors +wget https://huggingface.co/datasets/fal/cosmos-imagenet/resolve/main/imagenet_ci8x8_val.safetensors -O inet/imagenet_ci8x8_val.safetensors +wget 
https://huggingface.co/datasets/ramimmo/dinov2.inet/resolve/main/dinov2_inet_feats.pt -O inet/dinov2_inet_feats.pt \ No newline at end of file diff --git a/sweep.sh b/sweep.sh deleted file mode 100644 index 4ac8f64..0000000 --- a/sweep.sh +++ /dev/null @@ -1,8 +0,0 @@ -log_lrs=(-13 -12 -11 -10 -9) -vres_list=(True False) -for vres in "${vres_list[@]}"; do - for log_lr in "${log_lrs[@]}"; do - lr=$(python -c "print(2 ** $log_lr)") - torchrun --nnode=1 --nproc_per_node=8 train_gpt.py --run_name "layer48_sweep_gpt_${lr}_${vres}" --learning_rate $lr --vres $vres - done -done diff --git a/train_classifier.py b/train_classifier.py new file mode 100644 index 0000000..74f4ebc --- /dev/null +++ b/train_classifier.py @@ -0,0 +1,96 @@ +import torch +import torch.nn as nn +from safetensors.torch import safe_open + +with safe_open("tokenize_dataset/imagenet_ci8x8.safetensors", framework="pt") as f: + data = f.get_tensor("latents") + labels = f.get_tensor("labels").long() + +print(data.shape) # 1281167, 16, 32, 32 +print(labels.shape) # 1281167 + +data = data.reshape(data.shape[0], -1) # Flatten to (N, 16*32*32) +gen = torch.Generator() +gen.manual_seed(42) +perm = torch.randperm(data.size(0), generator=gen) +data = data[perm] +labels = labels[perm] + +train_size = int(0.95 * len(data)) + +train_data = data[:train_size] +train_labels = labels[:train_size] +val_data = data[train_size:] +val_labels = labels[train_size:] + + +class LinearRegression(nn.Module): + def __init__(self, input_dim, output_dim): + super().__init__() + self.linear = nn.Sequential( + nn.Linear(input_dim, 1024), + nn.ReLU(), + nn.Linear(1024, 1024), + nn.ReLU(), + nn.Linear(1024, output_dim), + ) + + def forward(self, x): + return self.linear(x) + + +input_dim = 16 * 32 * 32 # Flattened input dimension +output_dim = 1000 # ImageNet has 1000 classes +model = LinearRegression(input_dim, output_dim) +criterion = nn.CrossEntropyLoss() +optimizer = torch.optim.Adam(model.parameters(), lr=0.001) + +num_epochs = 10 
+batch_size = 4096 +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +model = model.to(device) + +for epoch in range(num_epochs): + model.train() + total_loss = 0 + num_batches = 0 + + for i in range(0, len(train_data), batch_size): + batch_data = ( + train_data[i : i + batch_size].to(torch.float32).to(device) * 16.0 / 255.0 + ) + + batch_labels = train_labels[i : i + batch_size].to(device) + + optimizer.zero_grad() + outputs = model(batch_data) + loss = criterion(outputs, batch_labels) + loss.backward() + optimizer.step() + + total_loss += loss.item() + num_batches += 1 + + avg_loss = total_loss / num_batches + + model.eval() + correct = 0 + total = 0 + with torch.no_grad(): + for i in range(0, len(val_data), batch_size): + batch_data = ( + val_data[i : i + batch_size].to(torch.float32).to(device) * 16.0 / 255.0 + ) + batch_labels = val_labels[i : i + batch_size].to(device) + + outputs = model(batch_data) + _, predicted = torch.max(outputs.data, 1) + total += batch_labels.size(0) + correct += (predicted == batch_labels).sum().item() + + accuracy = 100 * correct / total + print( + f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, Validation Accuracy: {accuracy:.2f}%" + ) + +print("Training finished!") diff --git a/train_diffusion.py b/train_diffusion.py new file mode 100644 index 0000000..81a0458 --- /dev/null +++ b/train_diffusion.py @@ -0,0 +1,434 @@ +import os +from torch.nn.utils import clip_grad_norm_ + +from transport import create_transport +from transport.transport import Sampler +from dit import SiT_models +from eval import Eval +import click +import torch +import torch.nn.functional as F +import numpy as np +import random +import torch.distributed as dist +from torch.backends.cuda import ( + enable_cudnn_sdp, + enable_flash_sdp, + enable_math_sdp, + enable_mem_efficient_sdp, +) +from torchvision.utils import make_grid +import copy +from torch.nn.attention import SDPBackend, sdpa_kernel +from datetime import datetime +from PIL 
@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    """
    Move every parameter of ``ema_model`` towards the same-named parameter of ``model``.

    Each EMA parameter is updated in place as
    ``ema = decay * ema + (1 - decay) * param``. Parameters are matched by
    name, so a name present in ``model`` but missing from ``ema_model``
    raises ``KeyError``. Only parameters are touched, not buffers.
    """
    shadow_by_name = dict(ema_model.named_parameters())
    blend = 1 - decay

    # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
    for name, live_param in model.named_parameters():
        shadow = shadow_by_name[name]
        shadow.mul_(decay)
        shadow.add_(live_param.data, alpha=blend)
########################################################################## + val_per_gpu_batch_size = per_gpu_batch_size * 2 + dist.init_process_group(backend="nccl") + + device = f"cuda:{ddp_local_rank}" + torch.cuda.set_device(device) + master_process = (ddp_rank == 0) + + + grad_accum_steps = int(global_batch_size // (per_gpu_batch_size * ddp_world_size)) + date_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + run_id = f"{date_time}_{run_name}" + + if master_process: + print(f"Global batch size: {global_batch_size}") + print(f"Per GPU batch size: {per_gpu_batch_size}") + print(f"Gradient accumulation steps: {grad_accum_steps}") + print(f"Effective batch size per step: {per_gpu_batch_size * ddp_world_size}") + + wandb.init( + project="imagegpt", + name=run_name, + config={ + "global_batch_size": global_batch_size, + "per_gpu_batch_size": per_gpu_batch_size, + "grad_accum_steps": grad_accum_steps, + "num_iterations": num_iterations, + "learning_rate": learning_rate, + "sample_every": sample_every, + "val_every": val_every, + "kdd_every": kdd_every, + "save_every": save_every, + "cfg_scale": cfg_scale, + "uncond_prob": uncond_prob + }, + ) + wandb.run.log_code(".") + + # Allow tf32 for speed + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + torch.backends.cudnn.benchmark = True + + model = SiT_models['SiT-XL/2'](in_channels=NUM_CHANNELS).to(memory_format=torch.channels_last) # From your code + model = torch.compile(model) + # ema model + ema = copy.deepcopy(model) + model = model.to(device) + ema = ema.to(device) + requires_grad(ema, False) + # ema + if init_ckpt is not None and master_process: + print(f"Loading checkpoint from {init_ckpt}") + if init_ckpt is not None: + checkpoint = torch.load(init_ckpt, map_location="cpu") + model.load_state_dict(checkpoint["model"]) + ema.load_state_dict(checkpoint["ema"]) + + random_tensor = torch.ones(1000, 1000, device=device) * ddp_rank + dist.all_reduce(random_tensor, 
op=dist.ReduceOp.SUM) + if master_process: + print(f"Rank {ddp_rank} has value {random_tensor[0, 0].item()}\n") + + # Wrap in DDP + model = DDP(model, device_ids=[ddp_local_rank], find_unused_parameters=False) + transport = create_transport( + "Linear", + "velocity", + "velocity", + 1e-3, + 1e-3 + ) # default: velocity; + transport_sampler = Sampler(transport) + + # make x_embedder & final_layer lr / 4 , rest lr, optimizer + optimizer = torch.optim.AdamW( + model.parameters(), + lr=learning_rate, + weight_decay=0, + ) + + + + enable_cudnn_sdp(True) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + + train_dataset = IMAGENET(is_train=True) + val_dataset = IMAGENET(is_train=False) + + train_sampler = torch.utils.data.DistributedSampler( + train_dataset, num_replicas=ddp_world_size, rank=ddp_rank, shuffle=True, seed=global_seed + ) + val_sampler = torch.utils.data.DistributedSampler( + val_dataset, num_replicas=ddp_world_size, rank=ddp_rank, shuffle=False + ) + + train_loader = DataLoader( + train_dataset, + batch_size=per_gpu_batch_size, + sampler=train_sampler, + num_workers=4, + pin_memory=True, + drop_last=False, + persistent_workers=True, + ) + + val_loader = DataLoader( + val_dataset, + batch_size=val_per_gpu_batch_size, + sampler=val_sampler, + num_workers=4, + pin_memory=True, + drop_last=False, + ) + + + + evaluator = Eval() # Uses DINOv2 for MMD + _, decode_fn = cosmos_vae(device=device) # for latents decode + + + seed_for_rank = global_seed + ddp_rank + # find appropriate # of samples that's more than 2000 and divisible by world size + num_kdd_samples = ((2000 // ddp_world_size + 1)) * ddp_world_size + num_kdd_samples_per_rank = num_kdd_samples // ddp_world_size + fixed_class_ids = torch.randint(0, 1000, (num_kdd_samples_per_rank,), generator=torch.Generator().manual_seed(seed_for_rank)).to(device) + ########################################################################## + # Helper Functions # + 
########################################################################## + + ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16) + + + @torch.no_grad() + def do_validation(): + """ + Compute a simple validation loss across the entire val_loader. + """ + model.eval() + val_losses = [] + + for val_latents, val_labels in val_loader: + val_latents = val_latents.to(device, non_blocking=True).to(memory_format=torch.channels_last) + val_labels = val_labels.to(device, non_blocking=True) + + # Scale latents by SCALING_FACTOR for stable diffusion training + data_val = val_latents * SCALING_FACTOR + + with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16): + model_kwargs = dict(y=val_labels) + loss_dict = transport.training_losses(model, data_val, model_kwargs) + loss = loss_dict["loss"].mean() + val_losses.append(loss.item()) + + val_loss = np.mean(val_losses) + model.train() + return val_loss + + @torch.no_grad() + def do_ema_sample(num_samples): + """ + Sample images using the EMA model with CFG. + Returns decoded images in uint8 format [0, 255]. 
+ """ + ema.eval() + z = torch.randn(num_samples, 16, 32, 32, device=device, + generator=torch.Generator(device=device).manual_seed(seed_for_rank)) + y = fixed_class_ids[:num_samples] + + all_imgs = [] + + range_fn_imgs = (lambda *args, **kwargs: trange(*args, **kwargs, position=1)) if master_process else range + sample_fn = transport_sampler.sample_ode() + + + for i in range_fn_imgs(0, num_samples, val_per_gpu_batch_size): + z_i = z[i:i+val_per_gpu_batch_size] + y_i = y[i:i+val_per_gpu_batch_size] + b_i = z_i.size(0) + + z_i = torch.cat([z_i, z_i], dim=0).to(memory_format=torch.channels_last) + ynull = torch.zeros_like(y_i) + 1000 + y_i = torch.cat([y_i, ynull], dim=0) + model_kwargs = dict(y=y_i, cfg_scale=cfg_scale) + model_fn = ema.forward_with_cfg + # with torch.amp.autocast(dtype=torch.bfloat16, device_type="cuda"): + + + + with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION, + SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]): + z_i = sample_fn(z_i, model_fn, **model_kwargs)[-1] + z_i, _ = z_i.chunk(2, dim=0) # Remove null class samples + z_i = z_i / SCALING_FACTOR + imgs = decode_fn(z_i) + imgs = (imgs.clamp(-1,1) + 1) * 127.5 + imgs = imgs.type(torch.uint8) + all_imgs.append(imgs) + + return torch.cat(all_imgs, dim=0) + + @torch.no_grad() + def do_sample_grid(step): + samples = do_ema_sample(per_gpu_batch_size) # (PGPU, 3, 256, 256) + all_samples = torch.zeros((global_batch_size, 3, 256, 256), device=device, dtype=samples.dtype) + dist.all_gather_into_tensor(all_samples, samples) + # all_samples = all_samples.permute(0, 2, 3, 1) + x = make_grid(all_samples, nrow=int(np.sqrt(global_batch_size))) + x = x.permute(1, 2, 0) + if master_process: + sample = Image.fromarray(x.cpu().numpy()) + sample.save("sample.jpg", quality=50) + sample.save("sample_hq.jpg", quality=95) + wandb.log({f"samples": wandb.Image("./sample.jpg"),f"samples_hq": wandb.Image("./sample_hq.jpg")}, step=step) + dist.barrier() + + @torch.no_grad() + def 
do_kdd_evaluation(): + imgs = do_ema_sample(num_kdd_samples_per_rank) + mmd = evaluator.eval(imgs) + return mmd + + ########################################################################## + # Training Loop # + ########################################################################## + + update_ema(ema, model.module, 0.0) + model.train() + ema.eval() + train_iter = iter(train_loader) + optimizer.zero_grad(set_to_none=True) + + pbar = tqdm(range(num_iterations), desc="Training", position=0) if master_process else range(num_iterations) + running_loss = [] + for step in pbar: + epoch = step // len(train_loader) + train_sampler.set_epoch(epoch) + + # Gradient Accumulation + for micro_step in range(grad_accum_steps): + try: + latents, labels = next(train_iter) + except StopIteration: + train_iter = iter(train_loader) + latents, labels = next(train_iter) + + latents = latents.to(device, non_blocking=True).to(memory_format=torch.channels_last) + labels = labels.to(device, non_blocking=True) + with torch.no_grad(): + latents = latents * SCALING_FACTOR + + with ctx: + model_kwargs = dict(y=labels) + loss_dict = transport.training_losses(model, latents, model_kwargs) + loss = loss_dict["loss"].mean() + running_loss.append(loss.item()) + + + loss.backward() + + clip_grad_norm_(model.parameters(), max_norm=1.0) # or any suitable value + optimizer.step() + # ema beta should be 0 for first 10k steps and then annealed to 0.999 within 100k steps + ema_beta = 0.0 if (step < 10000) else 0.999 + update_ema(ema, model.module, ema_beta) + optimizer.zero_grad(set_to_none=True) + + # Logging + if master_process and step % 10 == 0: + wandb.log({ + "train/loss": np.mean(running_loss), + }, step=step) + running_loss = [] + # Validation + if step % val_every == 0: + val_loss = do_validation() + if master_process: + wandb.log({"val/loss": val_loss}, step=step) + + + + # KDD Evaluation + if step % kdd_every == 0: + kdd = do_kdd_evaluation() + if master_process: + # print(f"step: {step}, 
def create_transport(
    path_type='Linear',
    prediction="velocity",
    loss_weight=None,
    train_eps=None,
    sample_eps=None,
):
    """Build a ``Transport`` object from string options.

    **Note**: model prediction defaults to velocity.

    Args:
        path_type: one of ``"Linear"``, ``"GVP"`` or ``"VP"``; any other
            value raises ``KeyError``.
        prediction: ``"noise"`` or ``"score"``; every other value selects
            velocity prediction.
        loss_weight: ``"velocity"`` or ``"likelihood"``; every other value
            (including ``None``) selects no loss weighting.
        train_eps: small epsilon for avoiding instability during training.
            ``None`` selects the per-path default.
        sample_eps: small epsilon for avoiding instability during sampling.
            ``None`` selects the per-path default.

    Returns:
        The configured ``Transport`` instance.

    For velocity prediction on the Linear or GVP path both epsilons are
    forced to 0 regardless of what was passed.
    """

    if prediction == "noise":
        model_type = ModelType.NOISE
    elif prediction == "score":
        model_type = ModelType.SCORE
    else:
        model_type = ModelType.VELOCITY

    if loss_weight == "velocity":
        loss_type = WeightType.VELOCITY
    elif loss_weight == "likelihood":
        loss_type = WeightType.LIKELIHOOD
    else:
        loss_type = WeightType.NONE

    path_choice = {
        "Linear": PathType.LINEAR,
        "GVP": PathType.GVP,
        "VP": PathType.VP,
    }

    path_type = path_choice[path_type]

    # Each epsilon falls back to its own default. The previous code tested
    # ``train_eps`` when defaulting ``sample_eps``; since ``train_eps`` had
    # just been filled in, an omitted ``sample_eps`` stayed ``None``.
    if path_type in [PathType.VP]:
        train_eps = 1e-5 if train_eps is None else train_eps
        sample_eps = 1e-3 if sample_eps is None else sample_eps
    elif path_type in [PathType.GVP, PathType.LINEAR] and model_type != ModelType.VELOCITY:
        train_eps = 1e-3 if train_eps is None else train_eps
        sample_eps = 1e-3 if sample_eps is None else sample_eps
    else:  # velocity & [GVP, LINEAR] is stable everywhere
        train_eps = 0
        sample_eps = 0

    # create flow state
    state = Transport(
        model_type=model_type,
        path_type=path_type,
        loss_type=loss_type,
        train_eps=train_eps,
        sample_eps=sample_eps,
    )

    return state
class ode:
    """Thin wrapper that integrates a drift function with ``odeint`` on a fixed time grid."""

    def __init__(
        self,
        drift,
        *,
        t0,
        t1,
        sampler_type,
        num_steps,
        atol,
        rtol,
    ):
        assert t0 < t1, "ODE sampler has to be in forward time"

        self.sampler_type = sampler_type
        self.atol = atol
        self.rtol = rtol
        self.drift = drift
        # ``num_steps`` evenly spaced time points from t0 to t1 (both included).
        self.t = th.linspace(t0, t1, num_steps)

    def sample(self, x, model, **model_kwargs):
        """Integrate the drift starting from ``x`` and return what ``odeint`` returns.

        ``x`` may be a tensor or a tuple of tensors. For a tuple, the device
        and batch size are read from its first element and one tolerance
        entry is supplied per element.
        """
        state_is_tuple = isinstance(x, tuple)
        device = (x[0] if state_is_tuple else x).device

        def _fn(t, x):
            # Broadcast the scalar solver time to one entry per batch item.
            reference = x[0] if isinstance(x, tuple) else x
            t = th.ones(reference.size(0)).to(device) * t
            return self.drift(x, t, model, **model_kwargs)

        num_components = len(x) if state_is_tuple else 1
        atol = [self.atol] * num_components
        rtol = [self.rtol] * num_components
        t = self.t.to(device)

        return odeint(
            _fn,
            x,
            t,
            method=self.sampler_type,
            atol=atol,
            rtol=rtol
        )
def expand_t_like_x(t, x):
    """Function to reshape time t to broadcastable dimension of x

    Args:
        t: [batch_dim,], time vector
        x: [batch_dim,...], data point

    Returns:
        ``t`` viewed with its leading size kept and one singleton axis added
        for every non-batch axis of ``x``.
    """
    dims = [1] * (len(x.size()) - 1)
    t = t.view(t.size(0), *dims)
    return t


#################### Coupling Plans ####################

class ICPlan:
    """Linear Coupling Plan

    Interpolates ``x_t = alpha_t * x1 + sigma_t * x0`` with ``alpha_t = t``
    and ``sigma_t = 1 - t``. Subclasses override the coefficient methods.
    """

    def __init__(self, sigma=0.0):
        self.sigma = sigma

    def compute_alpha_t(self, t):
        """Compute the data coefficient along the path

        Returns:
            ``(alpha_t, d_alpha_t)``; for this plan ``(t, 1)``.
        """
        return t, 1

    def compute_sigma_t(self, t):
        """Compute the noise coefficient along the path

        Returns:
            ``(sigma_t, d_sigma_t)``; for this plan ``(1 - t, -1)``.
        """
        return 1 - t, -1

    def compute_d_alpha_alpha_ratio_t(self, t):
        """Compute the ratio between d_alpha and alpha (``1 / t`` here)"""
        return 1 / t

    def compute_drift(self, x, t):
        """We always output sde according to score parametrization;

        Returns:
            ``(-alpha_ratio * x, alpha_ratio * sigma_t**2 - sigma_t * d_sigma_t)``
            with ``alpha_ratio = d_alpha_t / alpha_t``.
        """
        t = expand_t_like_x(t, x)
        alpha_ratio = self.compute_d_alpha_alpha_ratio_t(t)
        sigma_t, d_sigma_t = self.compute_sigma_t(t)
        drift = alpha_ratio * x
        diffusion = alpha_ratio * (sigma_t ** 2) - sigma_t * d_sigma_t

        return -drift, diffusion

    def compute_diffusion(self, x, t, form="constant", norm=1.0):
        """Compute the diffusion term of the SDE

        Args:
            x: [batch_dim, ...], data point
            t: [batch_dim,], time vector
            form: str, form of the diffusion term. One of ``"constant"``,
                ``"SBDM"``, ``"sigma"``, ``"linear"``, ``"decreasing"``,
                ``"increasing-decreasing"`` (the historical misspelling
                ``"inccreasing-decreasing"`` is still accepted).
            norm: float, norm of the diffusion term

        Raises:
            NotImplementedError: if ``form`` is not one of the names above.
        """
        t = expand_t_like_x(t, x)

        # Each form is wrapped in a callable so only the requested one is
        # evaluated; previously every form (including the 1/t drift term)
        # was computed on each call just to pick a single entry.
        choices = {
            "constant": lambda: norm,
            "SBDM": lambda: norm * self.compute_drift(x, t)[1],
            "sigma": lambda: norm * self.compute_sigma_t(t)[0],
            "linear": lambda: norm * (1 - t),
            "decreasing": lambda: 0.25 * (norm * th.cos(np.pi * t) + 1) ** 2,
            "inccreasing-decreasing": lambda: norm * th.sin(np.pi * t) ** 2,
        }
        # Correctly spelled alias; the misspelled key stays for existing callers.
        choices["increasing-decreasing"] = choices["inccreasing-decreasing"]

        if form not in choices:
            raise NotImplementedError(f"Diffusion form {form} not implemented")

        return choices[form]()

    def get_score_from_velocity(self, velocity, x, t):
        """Wrapper function: transform velocity prediction model to score

        Args:
            velocity: [batch_dim, ...] shaped tensor; velocity model output
            x: [batch_dim, ...] shaped tensor; x_t data point
            t: [batch_dim,] time tensor
        """
        t = expand_t_like_x(t, x)
        alpha_t, d_alpha_t = self.compute_alpha_t(t)
        sigma_t, d_sigma_t = self.compute_sigma_t(t)
        mean = x
        reverse_alpha_ratio = alpha_t / d_alpha_t
        var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t
        score = (reverse_alpha_ratio * velocity - mean) / var
        return score

    def get_noise_from_velocity(self, velocity, x, t):
        """Wrapper function: transform velocity prediction model to denoiser

        Args:
            velocity: [batch_dim, ...] shaped tensor; velocity model output
            x: [batch_dim, ...] shaped tensor; x_t data point
            t: [batch_dim,] time tensor
        """
        t = expand_t_like_x(t, x)
        alpha_t, d_alpha_t = self.compute_alpha_t(t)
        sigma_t, d_sigma_t = self.compute_sigma_t(t)
        mean = x
        reverse_alpha_ratio = alpha_t / d_alpha_t
        var = reverse_alpha_ratio * d_sigma_t - sigma_t
        noise = (reverse_alpha_ratio * velocity - mean) / var
        return noise

    def get_velocity_from_score(self, score, x, t):
        """Wrapper function: transform score prediction model to velocity

        Args:
            score: [batch_dim, ...] shaped tensor; score model output
            x: [batch_dim, ...] shaped tensor; x_t data point
            t: [batch_dim,] time tensor
        """
        t = expand_t_like_x(t, x)
        drift, var = self.compute_drift(x, t)
        velocity = var * score - drift
        return velocity

    def compute_mu_t(self, t, x0, x1):
        """Compute the mean of time-dependent density p_t"""
        t = expand_t_like_x(t, x1)
        alpha_t, _ = self.compute_alpha_t(t)
        sigma_t, _ = self.compute_sigma_t(t)
        return alpha_t * x1 + sigma_t * x0

    def compute_xt(self, t, x0, x1):
        """Return x_t for the given (t, x0, x1).

        No randomness is drawn here: the result is exactly ``compute_mu_t``.
        """
        xt = self.compute_mu_t(t, x0, x1)
        return xt

    def compute_ut(self, t, x0, x1, xt):
        """Compute the vector field corresponding to p_t

        ``xt`` is accepted for interface symmetry but not used.
        """
        t = expand_t_like_x(t, x1)
        _, d_alpha_t = self.compute_alpha_t(t)
        _, d_sigma_t = self.compute_sigma_t(t)
        return d_alpha_t * x1 + d_sigma_t * x0

    def plan(self, t, x0, x1):
        """Return ``(t, xt, ut)``: the unchanged time, the interpolant and its target velocity."""
        xt = self.compute_xt(t, x0, x1)
        ut = self.compute_ut(t, x0, x1, xt)
        return t, xt, ut
class GVPCPlan(ICPlan):
    """Coupling plan with trigonometric coefficients: alpha_t = sin(pi t / 2), sigma_t = cos(pi t / 2)."""

    def __init__(self, sigma=0.0):
        super().__init__(sigma)

    def compute_alpha_t(self, t):
        """Return ``(alpha_t, d_alpha_t)``, the coefficient of x1 and its time derivative."""
        angle = t * np.pi / 2
        half_pi = np.pi / 2
        return th.sin(angle), half_pi * th.cos(angle)

    def compute_sigma_t(self, t):
        """Return ``(sigma_t, d_sigma_t)``, the coefficient of x0 and its time derivative."""
        angle = t * np.pi / 2
        half_pi = np.pi / 2
        return th.cos(angle), -half_pi * th.sin(angle)

    def compute_d_alpha_alpha_ratio_t(self, t):
        """Return d_alpha_t / alpha_t in the closed form pi / (2 tan(pi t / 2))."""
        angle = t * np.pi / 2
        return np.pi / (2 * th.tan(angle))
class Transport:
    """Combines a coupling path with a model parameterization.

    Holds the path sampler selected by ``path_type`` together with the model
    output type, loss weighting and the train/sample epsilons, and exposes
    the training loss plus factories for the ODE drift and the score.
    """

    def __init__(
        self,
        *,
        model_type,
        path_type,
        loss_type,
        train_eps,
        sample_eps,
    ):
        """Store the configuration and instantiate the path sampler.

        Args:
            model_type: ``ModelType`` member describing what the model predicts.
            path_type: ``PathType`` member selecting the coupling plan class.
            loss_type: ``WeightType`` member selecting the loss weighting.
            train_eps: time-interval epsilon used during training.
            sample_eps: time-interval epsilon used during sampling.
        """
        path_options = {
            PathType.LINEAR: path.ICPlan,
            PathType.GVP: path.GVPCPlan,
            PathType.VP: path.VPCPlan,
        }

        self.loss_type = loss_type
        self.model_type = model_type
        self.path_sampler = path_options[path_type]()
        self.train_eps = train_eps
        self.sample_eps = sample_eps

    def prior_logp(self, z):
        '''
        Standard multivariate normal prior
        Assume z is batched
        '''
        shape = th.tensor(z.size())
        # Number of non-batch elements per sample.
        N = th.prod(shape[1:])

        def _fn(x):
            return -N / 2. * np.log(2 * np.pi) - th.sum(x ** 2) / 2.

        return th.vmap(_fn)(z)

    def check_interval(
        self,
        train_eps,
        sample_eps,
        *,
        diffusion_form="SBDM",
        sde=False,
        reverse=False,
        eval=False,
        last_step_size=0.0,
    ):
        """Return the time interval ``(t0, t1)`` to train or sample on.

        Starts from ``(0, 1)`` and trims the ends by an epsilon
        (``sample_eps`` when ``eval`` is true, otherwise ``train_eps``) or by
        ``last_step_size`` depending on the path class, the model type and
        whether an SDE is being solved. With ``reverse`` the pair is mapped to
        ``(1 - t0, 1 - t1)``.
        """
        t0 = 0
        t1 = 1
        eps = train_eps if not eval else sample_eps
        # Exact type checks on purpose: VPCPlan and GVPCPlan subclass ICPlan,
        # so ``isinstance`` would not distinguish the branches.
        if type(self.path_sampler) in [path.VPCPlan]:

            t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size

        elif (type(self.path_sampler) in [path.ICPlan, path.GVPCPlan]) and (
            self.model_type != ModelType.VELOCITY or sde
        ):  # avoid numerical issue by taking a first semi-implicit step

            t0 = eps if (diffusion_form == "SBDM" and sde) or self.model_type != ModelType.VELOCITY else 0
            t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size

        if reverse:
            t0, t1 = 1 - t0, 1 - t1

        return t0, t1

    def sample(self, x1):
        """Sampling x0 & t based on shape of x1 (if needed)

        Args:
            x1 - data point; [batch, *dim]

        Returns:
            ``(t, x0, x1)`` where ``x0`` is Gaussian noise shaped like ``x1``
            and ``t`` holds one uniform draw per batch item from the training
            interval given by ``check_interval``.
        """

        x0 = th.randn_like(x1)
        t0, t1 = self.check_interval(self.train_eps, self.sample_eps)
        t = th.rand((x1.shape[0],)) * (t1 - t0) + t0
        # Match x1's dtype and device.
        t = t.to(x1)
        return t, x0, x1

    def training_losses(
        self,
        model,
        x1,
        model_kwargs=None
    ):
        """Loss for training the score model

        Args:
            - model: backbone model; could be score, noise, or velocity
            - x1: datapoint
            - model_kwargs: additional arguments for the model

        Returns:
            dict with ``'pred'`` (the raw model output) and ``'loss'``.

        Raises:
            NotImplementedError: for a non-velocity model with an unknown
                ``loss_type``.
        """
        if model_kwargs is None:
            model_kwargs = {}

        t, x0, x1 = self.sample(x1)
        t, xt, ut = self.path_sampler.plan(t, x0, x1)
        model_output = model(xt, t, **model_kwargs)
        B, *_, C = xt.shape
        # The model must return a tensor shaped exactly like its input.
        assert model_output.size() == (B, *xt.size()[1:-1], C)

        terms = {}
        terms['pred'] = model_output
        if self.model_type == ModelType.VELOCITY:
            terms['loss'] = mean_flat(((model_output - ut) ** 2))
        else:
            _, drift_var = self.path_sampler.compute_drift(xt, t)
            sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, xt))
            if self.loss_type in [WeightType.VELOCITY]:
                weight = (drift_var / sigma_t) ** 2
            elif self.loss_type in [WeightType.LIKELIHOOD]:
                weight = drift_var / (sigma_t ** 2)
            elif self.loss_type in [WeightType.NONE]:
                weight = 1
            else:
                raise NotImplementedError()

            if self.model_type == ModelType.NOISE:
                terms['loss'] = mean_flat(weight * ((model_output - x0) ** 2))
            else:
                terms['loss'] = mean_flat(weight * ((model_output * sigma_t + x0) ** 2))

        return terms

    def get_drift(
        self
    ):
        """member function for obtaining the drift of the probability flow ODE

        Returns:
            ``body_fn(x, t, model, **model_kwargs)`` which evaluates the drift
            for the configured model type and asserts its shape equals ``x``'s.
        """
        def score_ode(x, t, model, **model_kwargs):
            drift_mean, drift_var = self.path_sampler.compute_drift(x, t)
            model_output = model(x, t, **model_kwargs)
            return (-drift_mean + drift_var * model_output)  # by change of variable

        def noise_ode(x, t, model, **model_kwargs):
            drift_mean, drift_var = self.path_sampler.compute_drift(x, t)
            sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))
            model_output = model(x, t, **model_kwargs)
            # Convert the noise prediction to a score before reusing the score form.
            score = model_output / -sigma_t
            return (-drift_mean + drift_var * score)

        def velocity_ode(x, t, model, **model_kwargs):
            model_output = model(x, t, **model_kwargs)
            return model_output

        if self.model_type == ModelType.NOISE:
            drift_fn = noise_ode
        elif self.model_type == ModelType.SCORE:
            drift_fn = score_ode
        else:
            drift_fn = velocity_ode

        def body_fn(x, t, model, **model_kwargs):
            model_output = drift_fn(x, t, model, **model_kwargs)
            assert model_output.shape == x.shape, "Output shape from ODE solver must match input shape"
            return model_output

        return body_fn

    def get_score(
        self,
    ):
        """member function for obtaining score of
            x_t = alpha_t * x + sigma_t * eps

        Returns:
            ``score_fn(x, t, model, **kwargs)`` converting the model output to
            a score according to the configured model type.

        Raises:
            NotImplementedError: if ``model_type`` is not a known member.
        """
        if self.model_type == ModelType.NOISE:
            def score_fn(x, t, model, **kwargs):
                return model(x, t, **kwargs) / -self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))[0]
        elif self.model_type == ModelType.SCORE:
            def score_fn(x, t, model, **kwargs):
                return model(x, t, **kwargs)
        elif self.model_type == ModelType.VELOCITY:
            def score_fn(x, t, model, **kwargs):
                return self.path_sampler.get_score_from_velocity(model(x, t, **kwargs), x, t)
        else:
            raise NotImplementedError()

        return score_fn
last_step_size, + ): + """Get the last step function of the SDE solver""" + + if last_step is None: + last_step_fn = \ + lambda x, t, model, **model_kwargs: \ + x + elif last_step == "Mean": + last_step_fn = \ + lambda x, t, model, **model_kwargs: \ + x + sde_drift(x, t, model, **model_kwargs) * last_step_size + elif last_step == "Tweedie": + alpha = self.transport.path_sampler.compute_alpha_t # simple aliasing; the original name was too long + sigma = self.transport.path_sampler.compute_sigma_t + last_step_fn = \ + lambda x, t, model, **model_kwargs: \ + x / alpha(t)[0][0] + (sigma(t)[0][0] ** 2) / alpha(t)[0][0] * self.score(x, t, model, **model_kwargs) + elif last_step == "Euler": + last_step_fn = \ + lambda x, t, model, **model_kwargs: \ + x + self.drift(x, t, model, **model_kwargs) * last_step_size + else: + raise NotImplementedError() + + return last_step_fn + + def sample_sde( + self, + *, + sampling_method="Euler", + diffusion_form="SBDM", + diffusion_norm=1.0, + last_step="Mean", + last_step_size=0.04, + num_steps=250, + ): + """returns a sampling function with given SDE settings + Args: + - sampling_method: type of sampler used in solving the SDE; default to be Euler-Maruyama + - diffusion_form: function form of diffusion coefficient; default to be matching SBDM + - diffusion_norm: function magnitude of diffusion coefficient; default to 1 + - last_step: type of the last step; default to identity + - last_step_size: size of the last step; default to match the stride of 250 steps over [0,1] + - num_steps: total integration step of SDE + """ + + if last_step is None: + last_step_size = 0.0 + + sde_drift, sde_diffusion = self.__get_sde_diffusion_and_drift( + diffusion_form=diffusion_form, + diffusion_norm=diffusion_norm, + ) + + t0, t1 = self.transport.check_interval( + self.transport.train_eps, + self.transport.sample_eps, + diffusion_form=diffusion_form, + sde=True, + eval=True, + reverse=False, + last_step_size=last_step_size, + ) + + _sde = sde( + 
sde_drift, + sde_diffusion, + t0=t0, + t1=t1, + num_steps=num_steps, + sampler_type=sampling_method + ) + + last_step_fn = self.__get_last_step(sde_drift, last_step=last_step, last_step_size=last_step_size) + + + def _sample(init, model, **model_kwargs): + xs = _sde.sample(init, model, **model_kwargs) + ts = th.ones(init.size(0), device=init.device) * t1 + x = last_step_fn(xs[-1], ts, model, **model_kwargs) + xs.append(x) + + assert len(xs) == num_steps, "Samples does not match the number of steps" + + return xs + + return _sample + + def sample_ode( + self, + *, + sampling_method="dopri5", + num_steps=50, + atol=1e-6, + rtol=1e-3, + reverse=False, + ): + """returns a sampling function with given ODE settings + Args: + - sampling_method: type of sampler used in solving the ODE; default to be Dopri5 + - num_steps: + - fixed solver (Euler, Heun): the actual number of integration steps performed + - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation + - atol: absolute error tolerance for the solver + - rtol: relative error tolerance for the solver + - reverse: whether solving the ODE in reverse (data to noise); default to False + """ + if reverse: + drift = lambda x, t, model, **kwargs: self.drift(x, th.ones_like(t) * (1 - t), model, **kwargs) + else: + drift = self.drift + + t0, t1 = self.transport.check_interval( + self.transport.train_eps, + self.transport.sample_eps, + sde=False, + eval=True, + reverse=reverse, + last_step_size=0.0, + ) + + _ode = ode( + drift=drift, + t0=t0, + t1=t1, + sampler_type=sampling_method, + num_steps=num_steps, + atol=atol, + rtol=rtol, + ) + + return _ode.sample + + def sample_ode_likelihood( + self, + *, + sampling_method="dopri5", + num_steps=50, + atol=1e-6, + rtol=1e-3, + ): + + """returns a sampling function for calculating likelihood with given ODE settings + Args: + - sampling_method: type of sampler used in solving the ODE; default to be Dopri5 + - num_steps: + - fixed solver 
(Euler, Heun): the actual number of integration steps performed + - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation + - atol: absolute error tolerance for the solver + - rtol: relative error tolerance for the solver + """ + def _likelihood_drift(x, t, model, **model_kwargs): + x, _ = x + eps = th.randint(2, x.size(), dtype=th.float, device=x.device) * 2 - 1 + t = th.ones_like(t) * (1 - t) + with th.enable_grad(): + x.requires_grad = True + grad = th.autograd.grad(th.sum(self.drift(x, t, model, **model_kwargs) * eps), x)[0] + logp_grad = th.sum(grad * eps, dim=tuple(range(1, len(x.size())))) + drift = self.drift(x, t, model, **model_kwargs) + return (-drift, logp_grad) + + t0, t1 = self.transport.check_interval( + self.transport.train_eps, + self.transport.sample_eps, + sde=False, + eval=True, + reverse=False, + last_step_size=0.0, + ) + + _ode = ode( + drift=_likelihood_drift, + t0=t0, + t1=t1, + sampler_type=sampling_method, + num_steps=num_steps, + atol=atol, + rtol=rtol, + ) + + def _sample_fn(x, model, **model_kwargs): + init_logp = th.zeros(x.size(0)).to(x) + input = (x, init_logp) + drift, delta_logp = _ode.sample(input, model, **model_kwargs) + drift, delta_logp = drift[-1], delta_logp[-1] + prior_logp = self.transport.prior_logp(drift) + logp = prior_logp - delta_logp + return logp, drift + + return _sample_fn \ No newline at end of file diff --git a/transport/utils.py b/transport/utils.py new file mode 100644 index 0000000..4464603 --- /dev/null +++ b/transport/utils.py @@ -0,0 +1,29 @@ +import torch as th + +class EasyDict: + + def __init__(self, sub_dict): + for k, v in sub_dict.items(): + setattr(self, k, v) + + def __getitem__(self, key): + return getattr(self, key) + +def mean_flat(x): + """ + Take the mean over all non-batch dimensions. 
+ """ + return th.mean(x, dim=list(range(1, len(x.size())))) + +def log_state(state): + result = [] + + sorted_state = dict(sorted(state.items())) + for key, value in sorted_state.items(): + # Check if the value is an instance of a class + if "