Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.lz4 filter=lfs diff=lfs merge=lfs -text
*.mds filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
# Audio files - uncompressed
*.pcm filter=lfs diff=lfs merge=lfs -text
*.sam filter=lfs diff=lfs merge=lfs -text
*.raw filter=lfs diff=lfs merge=lfs -text
# Audio files - compressed
*.aac filter=lfs diff=lfs merge=lfs -text
*.flac filter=lfs diff=lfs merge=lfs -text
*.mp3 filter=lfs diff=lfs merge=lfs -text
*.ogg filter=lfs diff=lfs merge=lfs -text
*.wav filter=lfs diff=lfs merge=lfs -text
# Image files - uncompressed
*.bmp filter=lfs diff=lfs merge=lfs -text
*.gif filter=lfs diff=lfs merge=lfs -text
*.png filter=lfs diff=lfs merge=lfs -text
*.tiff filter=lfs diff=lfs merge=lfs -text
# Image files - compressed
*.jpg filter=lfs diff=lfs merge=lfs -text
*.jpeg filter=lfs diff=lfs merge=lfs -text
*.webp filter=lfs diff=lfs merge=lfs -text
# Video files - compressed
*.mp4 filter=lfs diff=lfs merge=lfs -text
*.webm filter=lfs diff=lfs merge=lfs -text
10 changes: 9 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,12 @@ pretrained_ckpts
*.safetensors
preprocessed_dataset
wandb
logs
logs
*.pkl
*.pt
tokenize_dataset/*
inet/**
cosmos_ckpt/**
wandb/**
*.jpg
*.jpeg
Empty file added cosmos/__init__.py
Empty file.
128 changes: 128 additions & 0 deletions cosmos/image_lib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A library for image tokenizers inference."""

import numpy as np
import torch
from typing import Any

from cosmos.utils import (
load_model,
load_encoder_model,
load_decoder_model,
numpy2tensor,
pad_image_batch,
tensor2numpy,
unpad_image_batch,
)


class ImageTokenizer(torch.nn.Module):
def __init__(
self,
checkpoint: str = None,
checkpoint_enc: str = None,
checkpoint_dec: str = None,
tokenizer_config: dict[str, Any] = None,
device: str = "cuda",
dtype: str = "bfloat16",
) -> None:
super().__init__()
self._device = device
self._dtype = getattr(torch, dtype)
self._full_model = (
load_model(checkpoint, tokenizer_config, device).to(self._dtype)
if checkpoint is not None
else None
)
self._enc_model = (
load_encoder_model(checkpoint_enc, tokenizer_config, device).to(self._dtype)
if checkpoint_enc is not None
else None
)
self._dec_model = (
load_decoder_model(checkpoint_dec, tokenizer_config, device).to(self._dtype)
if checkpoint_dec is not None
else None
)

@torch.no_grad()
def autoencode(self, input_tensor: torch.Tensor) -> torch.Tensor:
"""Reconstrcuts a batch of image tensors after embedding into a latent.

Args:
input_tensor: The input image Bx3xHxW layout, range [-1..1].
Returns:
The reconstructed tensor, layout Bx3xHxW, range [-1..1].
"""
if self._full_model is not None:
output_tensor = self._full_model(input_tensor)
output_tensor = (
output_tensor[0] if isinstance(output_tensor, tuple) else output_tensor
)
else:
output_latent = self.encode(input_tensor)[0]
output_tensor = self.decode(output_latent)
return output_tensor

@torch.no_grad()
def decode(self, input_latent: torch.Tensor) -> torch.Tensor:
"""Decodes an image from a provided latent embedding.

Args:
input_latent: The continuous latent Bx16xhxw for CI,
or the discrete indices Bxhxw for DI.
Returns:
The output tensor in Bx3xHxW, range [-1..1].
"""
return self._dec_model(input_latent)

@torch.no_grad()
def encode(self, input_tensor: torch.Tensor) -> tuple[torch.Tensor]:
"""Encodes an image into a latent embedding or code.

Args:
input_tensor: The input tensor Bx3xHxW layout, range [-1..1].
Returns:
For continuous image (CI) tokenizer, the tuple contains:
- The latent embedding, Bx16x(h)x(w), where the compression
rate is (H/h x W/w), and channel dimension of 16.
For discrete image (DI) tokenizer, the tuple contains:
- The indices, Bx(h)x(w), from a codebook of size 64K, which
corresponds to FSQ levels of (8,8,8,5,5,5).
- The discrete code, Bx6x(h)x(w), where the compression rate is
again (H/h x W/w), and channel dimension of 6.
"""
output_latent = self._enc_model(input_tensor)
if isinstance(output_latent, torch.Tensor):
return output_latent
return output_latent[:-1]

@torch.no_grad()
def forward(self, image: np.ndarray) -> np.ndarray:
"""Reconstructs an image using a pre-trained tokenizer.

Args:
image: The input image BxHxWxC layout, range [0..255].
Returns:
The reconstructed image in range [0..255], layout BxHxWxC.
"""
padded_input_image, crop_region = pad_image_batch(image)
input_tensor = numpy2tensor(
padded_input_image, dtype=self._dtype, device=self._device
)
output_tensor = self.autoencode(input_tensor)
padded_output_image = tensor2numpy(output_tensor)
return unpad_image_batch(padded_output_image, crop_region)
63 changes: 63 additions & 0 deletions cosmos/modules/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum

from cosmos.modules.distributions import (
GaussianDistribution,
IdentityDistribution,
)
from cosmos.modules.layers2d import Decoder, Encoder
from cosmos.modules.layers3d import (
DecoderBase,
DecoderFactorized,
EncoderBase,
EncoderFactorized,
)
from cosmos.modules.quantizers import (
FSQuantizer,
LFQuantizer,
ResidualFSQuantizer,
VectorQuantizer,
)


class EncoderType(Enum):
    """Registry mapping encoder-type names to 2D encoder classes."""
    Default = Encoder


class DecoderType(Enum):
    """Registry mapping decoder-type names to 2D decoder classes."""
    Default = Decoder


class Encoder3DType(Enum):
    """Registry of 3D encoder variants (see cosmos.modules.layers3d)."""
    BASE = EncoderBase
    FACTORIZED = EncoderFactorized


class Decoder3DType(Enum):
    """Registry of 3D decoder variants (see cosmos.modules.layers3d)."""
    BASE = DecoderBase
    FACTORIZED = DecoderFactorized


class ContinuousFormulation(Enum):
    """Latent formulation for continuous tokenizers: sampled (VAE) or
    pass-through (AE) — see cosmos.modules.distributions."""
    VAE = GaussianDistribution
    AE = IdentityDistribution


class DiscreteQuantizer(Enum):
    """Registry mapping quantizer names to discrete quantizer classes."""
    VQ = VectorQuantizer
    LFQ = LFQuantizer
    FSQ = FSQuantizer
    RESFSQ = ResidualFSQuantizer
41 changes: 41 additions & 0 deletions cosmos/modules/distributions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The distribution modes to use for continuous image tokenizers."""

import torch


class IdentityDistribution(torch.nn.Module):
    """Pass-through "distribution" head for plain autoencoders.

    Returns its input unchanged, together with a pair of zero tensors that
    stand in for the (mean, logvar) statistics, so callers can treat it
    interchangeably with GaussianDistribution.
    """

    def __init__(self):
        super().__init__()

    def forward(self, parameters):
        # Fresh zero tensors each call, standing in for (mean, logvar).
        dummy_stats = (torch.tensor([0.0]), torch.tensor([0.0]))
        return parameters, dummy_stats


class GaussianDistribution(torch.nn.Module):
    """Diagonal Gaussian head parameterized by concatenated mean/log-variance.

    ``forward`` expects the mean in the first half of channel dim 1 and the
    log-variance in the second half; the log-variance is clamped to
    [min_logvar, max_logvar] before sampling.
    """

    def __init__(self, min_logvar: float = -30.0, max_logvar: float = 20.0):
        super().__init__()
        # Clamp bounds keep exp(logvar) numerically safe.
        self.min_logvar = min_logvar
        self.max_logvar = max_logvar

    def sample(self, mean, logvar):
        """Draws one reparameterized sample: mean + exp(logvar / 2) * eps."""
        noise = torch.randn_like(mean)
        return mean + torch.exp(0.5 * logvar) * noise

    def forward(self, parameters):
        """Splits `parameters` into (mean, logvar), clamps, and samples.

        Returns:
            A tuple of (sample, (mean, clamped_logvar)).
        """
        mean, raw_logvar = parameters.chunk(2, dim=1)
        clamped_logvar = raw_logvar.clamp(self.min_logvar, self.max_logvar)
        return self.sample(mean, clamped_logvar), (mean, clamped_logvar)
Loading