Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dia/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ class DiaConfig(BaseModel, frozen=True):
version: str = Field(default="1.0")
model: ModelConfig
# TODO: remove training. this is just for backward compatibility
training: TrainingConfig
training: TrainingConfig | None = Field(default=None)
data: DataConfig

def save(self, path: str) -> None:
Expand Down
48 changes: 21 additions & 27 deletions dia/layers.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import PyTorchModelHubMixin
from torch import Tensor
from torch.nn import RMSNorm

from huggingface_hub import PyTorchModelHubMixin

from .config import DiaConfig
from .state import DecoderInferenceState, EncoderInferenceState, KVCache

Expand Down Expand Up @@ -48,7 +47,6 @@ def __init__(

factory_kwargs = {"device": device, "dtype": weight_dtype}
self.weight = nn.Parameter(torch.empty(self.kernel_shape, **factory_kwargs))
self.register_parameter("bias", None)

def forward(self, inputs: Tensor) -> Tensor:
norm_axis = _normalize_axes(self.axis, inputs.ndim)
Expand Down Expand Up @@ -112,31 +110,23 @@ def __init__(
self.embedding_dims = embedding_dims
self.min_timescale = min_timescale
self.max_timescale = max_timescale
self.dtype = dtype
self.compute_dtype = dtype

half_embedding_dim = embedding_dims // 2
fraction = (2.0 * torch.arange(0, half_embedding_dim)) / embedding_dims
self.register_buffer(
"timescale",
self.min_timescale * (self.max_timescale / self.min_timescale) ** fraction,
persistent=False,
)

def extra_repr(self) -> str:
s = f"{self.timescale.shape}"
return s
timescale = (self.min_timescale * (self.max_timescale / self.min_timescale) ** fraction).to(torch.float32)
self.register_buffer("timescale", timescale, persistent=False)

def forward(self, inputs: torch.Tensor, position: torch.Tensor):
"""Applies RoPE."""
position = position.unsqueeze(-1).unsqueeze(-1)
timescale = self.timescale.to(inputs.device)
sinusoid_inp = position / timescale
sin = torch.sin(sinusoid_inp).to(inputs.dtype)
cos = torch.cos(sinusoid_inp).to(inputs.dtype)
first_half, second_half = torch.chunk(inputs, 2, dim=-1)
sinusoid_inp = position / self.timescale
sin = torch.sin(sinusoid_inp)
cos = torch.cos(sinusoid_inp)
first_half, second_half = torch.chunk(inputs.to(torch.float32), 2, dim=-1)
first_part = first_half * cos - second_half * sin
second_part = second_half * cos + first_half * sin
return torch.cat((first_part, second_part), dim=-1)
return torch.cat((first_part.to(self.compute_dtype), second_part.to(self.compute_dtype)), dim=-1)


class Attention(nn.Module):
Expand Down Expand Up @@ -283,6 +273,7 @@ def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
model_config = config.model
enc_config = config.model.encoder
embed_dim = enc_config.n_embd
self.compute_dtype = compute_dtype

self.pre_sa_norm = RMSNorm(
embed_dim,
Expand Down Expand Up @@ -313,7 +304,8 @@ def forward(
state: EncoderInferenceState,
) -> torch.Tensor:
residual = x
x_norm = self.pre_sa_norm(x)
x_norm = self.pre_sa_norm(x).to(self.compute_dtype)

sa_out = self.self_attention(
Xq=x_norm,
Xkv=x_norm,
Expand All @@ -324,7 +316,7 @@ def forward(
x = residual + sa_out

residual = x
x_norm = self.post_sa_norm(x)
x_norm = self.post_sa_norm(x).to(self.compute_dtype)
mlp_out = self.mlp(x_norm)
x = residual + mlp_out

Expand All @@ -339,6 +331,7 @@ def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
self.config = config
model_config = config.model
enc_config = config.model.encoder
self.compute_dtype = compute_dtype

self.embedding = nn.Embedding(
model_config.src_vocab_size,
Expand All @@ -362,7 +355,7 @@ def forward(
for layer in self.layers:
x = layer(x, state)

x = self.norm(x)
x = self.norm(x).to(self.compute_dtype)
return x


Expand All @@ -377,6 +370,7 @@ def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
enc_config = config.model.encoder
dec_embed_dim = dec_config.n_embd
enc_embed_dim = enc_config.n_embd
self.compute_dtype = compute_dtype

# Norms
self.pre_sa_norm = RMSNorm(
Expand Down Expand Up @@ -435,7 +429,7 @@ def forward(
prefill: bool = False,
) -> torch.Tensor:
residual = x
x_norm = self.pre_sa_norm(x)
x_norm = self.pre_sa_norm(x).to(self.compute_dtype)

sa_out = self.self_attention(
Xq=x_norm, # (2, 1, D)
Expand All @@ -451,7 +445,7 @@ def forward(
x = residual + sa_out

residual = x
x_norm = self.pre_ca_norm(x)
x_norm = self.pre_ca_norm(x).to(self.compute_dtype)
ca_out = self.cross_attention(
Xq=x_norm,
Xkv=state.enc_out,
Expand All @@ -463,7 +457,7 @@ def forward(
x = residual + ca_out

residual = x
x_norm = self.pre_mlp_norm(x)
x_norm = self.pre_mlp_norm(x).to(self.compute_dtype)
mlp_out = self.mlp(x_norm)
x = residual + mlp_out

Expand Down Expand Up @@ -616,8 +610,8 @@ class DiaModel(
license="apache-2.0",
coders={
DiaConfig: (
lambda x: x.dict(),
lambda data: DiaConfig.model_validate(**data),
lambda x: x.model_dump(),
lambda data: DiaConfig.model_validate(data),
),
},
):
Expand Down
4 changes: 3 additions & 1 deletion dia/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,9 @@ def from_pretrained(
FileNotFoundError: If config or checkpoint download/loading fails.
RuntimeError: If there is an error loading the checkpoint.
"""
loaded_model = DiaModel.from_pretrained(model_name)
if isinstance(compute_dtype, str):
compute_dtype = ComputeDtype(compute_dtype)
loaded_model = DiaModel.from_pretrained(model_name, compute_dtype=compute_dtype.to_dtype())
config = loaded_model.config
dia = cls(config, compute_dtype, device)

Expand Down
8 changes: 6 additions & 2 deletions dia/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ def new(cls, config: DiaConfig, cond_src: torch.Tensor) -> "EncoderInferenceStat
"""Creates EncoderInferenceState from DiaConfig and a device."""
device = cond_src.device

positions = torch.arange(config.data.text_length, device=device).to(torch.long).unsqueeze(0).expand(2, -1)
positions = (
torch.arange(config.data.text_length, dtype=torch.float32, device=device).unsqueeze(0).expand(2, -1)
)
padding_mask = (cond_src != config.data.text_pad_value).to(device).expand(2, -1)
attn_mask = create_attn_mask(padding_mask, padding_mask, device, is_causal=False)

Expand Down Expand Up @@ -162,7 +164,9 @@ def new(
def prepare_step(self, step_from: int, step_to: int | None = None) -> None:
if step_to is None:
step_to = step_from + 1
self.dec_positions = torch.arange(step_from, step_to, device=self.device).unsqueeze(0).expand(2, -1)
self.dec_positions = (
torch.arange(step_from, step_to, dtype=torch.float32, device=self.device).unsqueeze(0).expand(2, -1)
)


@dataclass
Expand Down
2 changes: 1 addition & 1 deletion example/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@

text = "[S1] Dia is an open weights text to dialogue model. [S2] You get full control over scripts and voices. [S1] Wow. Amazing. (laughs) [S2] Try it now on Git hub or Hugging Face."

output = model.generate(text, use_torch_compile=True, verbose=True)
output = model.generate(text, use_torch_compile=False, verbose=True)

model.save_audio("simple.mp3", output)
9 changes: 8 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@ dependencies = [
"huggingface-hub>=0.30.2",
"numpy>=2.2.4",
"pydantic>=2.11.3",
"safetensors>=0.5.3",
"soundfile>=0.13.1",
"torch>=2.6.0",
"torch==2.6.0",
"torchaudio>=2.6.0",
"triton>=3.2.0 ; sys_platform == 'linux'",
"triton-windows>=3.2.0.post18 ; sys_platform == 'win32'",
Expand Down Expand Up @@ -57,3 +58,9 @@ torchaudio = [
name = "pytorch-cu126"
url = "https://download.pytorch.org/whl/cu126"
explicit = true

[dependency-groups]
dev = [
"ninja>=1.11.1.4",
"packaging>=25.0",
]
Loading