269 changes: 235 additions & 34 deletions dia/layers.py
@@ -128,6 +128,12 @@ def forward(self, inputs: torch.Tensor, position: torch.Tensor):
second_part = second_half * cos + first_half * sin
return torch.cat((first_part.to(self.compute_dtype), second_part.to(self.compute_dtype)), dim=-1)

def apply_rope(self, inputs: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor):
first_half, second_half = torch.chunk(inputs.to(torch.float32), 2, dim=-1)
first_part = first_half * cos - second_half * sin
second_part = second_half * cos + first_half * sin
return torch.cat((first_part.to(self.compute_dtype), second_part.to(self.compute_dtype)), dim=-1)
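
The new apply_rope helper takes precomputed sin/cos tables instead of deriving them from positions on every call, so one sin/cos pair can be shared between the query and key projections (as the new SelfAttention.forward does below). A minimal sketch of the intended equivalence, assuming a RotaryEmbedding instance rope, a tensor x of shape (B, T, N, H), and integer positions of shape (B, T); the check itself is illustrative, not part of this PR:

    import torch

    position = positions.unsqueeze(-1).unsqueeze(-1)          # (B, T, 1, 1)
    sinusoid_inp = position / rope.timescale                   # broadcasts over half the head dim
    sin, cos = torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)

    out_old = rope(x, position=positions)                      # original per-call path
    out_new = rope.apply_rope(x, sin, cos)                      # new path with cached tables
    assert torch.allclose(out_old.float(), out_new.float(), atol=1e-3)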


def custom_scaled_dot_product_attention(
query: torch.Tensor,
@@ -182,7 +188,164 @@ def custom_scaled_dot_product_attention(
return output


class Attention(nn.Module):
class CrossAttention(nn.Module):
"""Cross-Attention using DenseGeneral."""

def __init__(
self,
config: DiaConfig,
q_embed_dim: int,
kv_embed_dim: int,
num_query_heads: int,
num_kv_heads: int,
head_dim: int,
compute_dtype: torch.dtype,
out_embed_dim: int | None = None,
):
super().__init__()
self.num_query_heads = num_query_heads
self.num_kv_heads = num_kv_heads
self.head_dim = head_dim
self.output_dim = out_embed_dim if out_embed_dim is not None else q_embed_dim
self.projected_query_dim = num_query_heads * head_dim
if num_query_heads % num_kv_heads != 0:
raise ValueError(f"num_query_heads ({num_query_heads}) must be divisible by num_kv_heads ({num_kv_heads})")
self.num_gqa_groups = num_query_heads // num_kv_heads

# --- Projection Layers using DenseGeneral ---
self.q_proj = DenseGeneral(
in_shapes=(q_embed_dim,),
out_features=(num_query_heads, head_dim),
axis=(-1,),
weight_dtype=compute_dtype,
)
self.k_proj = DenseGeneral(
in_shapes=(kv_embed_dim,),
out_features=(num_kv_heads, head_dim),
axis=(-1,),
weight_dtype=compute_dtype,
)
self.v_proj = DenseGeneral(
in_shapes=(kv_embed_dim,),
out_features=(num_kv_heads, head_dim),
axis=(-1,),
weight_dtype=compute_dtype,
)
self.o_proj = DenseGeneral(
in_shapes=(num_query_heads, head_dim),
out_features=(self.output_dim,),
axis=(-2, -1),
weight_dtype=compute_dtype,
)

# --- Rotary Embedding ---
self.rotary_emb = RotaryEmbedding(
embedding_dims=self.head_dim,
min_timescale=config.model.rope_min_timescale,
max_timescale=config.model.rope_max_timescale,
dtype=compute_dtype,
)

def forward(
self,
Xq: torch.Tensor, # (B, T, D) T = 1 in AR generation
q_positions: torch.Tensor, # (B, T)
kv_positions: torch.Tensor | None = None, # (B, S)
        attn_mask: torch.Tensor | None = None,  # Validity mask over encoder key positions
        cache: KVCache | None = None,  # Precomputed cross-attention KV cache (read unconditionally below)
is_causal: bool = False,
    ) -> torch.Tensor:
        """
        Performs cross-attention over precomputed key/value states, applying RoPE to the queries.

        Args:
            Xq: Query tensor (B, T, D). T=1 during single-step decoding.
            q_positions: Positions for queries (B, T).
            kv_positions: Positions for keys/values (B, S). Defaults to q_positions; otherwise unused here, since K/V are read from the cache.
            attn_mask: Attention mask, applied when is_causal is False.
            cache: KVCache holding the precomputed cross-attention keys/values.
            is_causal: If True, rely on causal masking inside the attention kernel instead of attn_mask.

        Returns:
            output: The attention output tensor (B, T, output_dim).
        """
if kv_positions is None:
kv_positions = q_positions
original_dtype = Xq.dtype

Xq_BxTxNxH = self.q_proj(Xq)
Xq_BxTxNxH = self.rotary_emb(Xq_BxTxNxH, position=q_positions)
Xq_BxNxTxH = Xq_BxTxNxH.transpose(1, 2)

        # Cross-attention keys/values are precomputed from the encoder output and read from the cache.
        attn_k, attn_v = cache.k, cache.v

# Use custom attention for MPS backend, otherwise use optimized PyTorch function
is_mps = Xq.device.type == "mps" and torch.backends.mps.is_available()
if is_mps:
attn_output = custom_scaled_dot_product_attention(
query=Xq_BxNxTxH,
key=attn_k,
value=attn_v,
attn_mask=attn_mask if not is_causal else None,
scale=1.0,
is_causal=is_causal,
num_gqa_groups=self.num_gqa_groups,
)
else:
attn_output = F.scaled_dot_product_attention(
Xq_BxNxTxH,
attn_k,
attn_v,
attn_mask=attn_mask if not is_causal else None,
scale=1.0,
enable_gqa=self.num_gqa_groups > 1,
is_causal=is_causal,
)

attn_output = attn_output.transpose(1, 2).contiguous() # (B, T, N, H)
output = self.o_proj(attn_output)

return output.to(original_dtype)
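
Compared with the old unified Attention class, CrossAttention.forward no longer receives an Xkv tensor at all: the encoder-derived keys and values are expected to already live in cache.k / cache.v, so each decoding step only projects and rotates the query before calling scaled_dot_product_attention (with the custom kernel as an MPS fallback). A hedged sketch of how such a cache might be prefilled once after the encoder runs; the helper name is an assumption, and the body mirrors what the removed per-step Xkv branch used to do:

    def build_cross_attn_kv(cross_attn: CrossAttention, enc_out, enc_positions):
        k = cross_attn.k_proj(enc_out)                         # (B, S, K, H)
        v = cross_attn.v_proj(enc_out)                         # (B, S, K, H)
        k = cross_attn.rotary_emb(k, position=enc_positions)   # rotate keys once with encoder positions
        return k.transpose(1, 2), v.transpose(1, 2)            # (B, K, S, H) each, stored as cache.k / cache.v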


class FusedQKV(nn.Module):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = False,
num_q_heads: int = 1,
q_head_dim: int = 1,
num_kv_heads: int = 1,
kv_head_dim: int = 1,
):
super().__init__()
self.num_q_heads = num_q_heads
self.q_head_dim = q_head_dim
self.num_kv_heads = num_kv_heads
self.kv_head_dim = kv_head_dim
self.q_output_dim = num_q_heads * q_head_dim
self.kv_output_dim = num_kv_heads * kv_head_dim
self.linear = nn.Linear(in_features, out_features, bias=bias)

    def forward(self, inputs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
x = self.linear(inputs)

q, k, v = x.split([self.q_output_dim, self.kv_output_dim, self.kv_output_dim], dim=-1)

q = q.reshape(q.shape[:-1] + (self.num_q_heads, self.q_head_dim))
k = k.reshape(k.shape[:-1] + (self.num_kv_heads, self.kv_head_dim))
v = v.reshape(v.shape[:-1] + (self.num_kv_heads, self.kv_head_dim))

return q, k, v
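
FusedQKV runs a single nn.Linear over the input and then splits and reshapes the result into per-head Q, K and V tensors, replacing three separate projections with one matmul. A small shape check with toy dimensions chosen purely for illustration:

    import torch

    # d_model=16, 4 query heads and 2 KV heads, head_dim=8 (illustrative sizes only)
    fused = FusedQKV(16, 4 * 8 + 2 * (2 * 8), num_q_heads=4, q_head_dim=8,
                     num_kv_heads=2, kv_head_dim=8)
    x = torch.randn(2, 5, 16)             # (B, T, D)
    q, k, v = fused(x)
    print(q.shape, k.shape, v.shape)      # (2, 5, 4, 8) (2, 5, 2, 8) (2, 5, 2, 8)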


class SelfAttention(nn.Module):
"""Attention using DenseGeneral."""

def __init__(
@@ -207,6 +370,8 @@ def __init__(
if num_query_heads % num_kv_heads != 0:
raise ValueError(f"num_query_heads ({num_query_heads}) must be divisible by num_kv_heads ({num_kv_heads})")
self.num_gqa_groups = num_query_heads // num_kv_heads
self.kv_embed_dim = kv_embed_dim
self.q_embed_dim = q_embed_dim

# --- Projection Layers using DenseGeneral ---
self.q_proj = DenseGeneral(
@@ -242,10 +407,44 @@ def __init__(
dtype=compute_dtype,
)

self.is_fused_qkv = False

def get_linear_weight(self, dense: DenseGeneral):
W_dg = dense.weight.data

out_features = 1
input_features = 1
for dim in dense.out_features:
out_features *= dim
for dim in dense.in_shapes:
input_features *= dim

W_dg_reshaped_for_linear_T = W_dg.reshape(input_features, out_features)
linear_weight = W_dg_reshaped_for_linear_T.transpose(0, 1).contiguous()
return linear_weight

def patch_fused_qkv(self):
q_proj_weight = self.get_linear_weight(self.q_proj)
k_proj_weight = self.get_linear_weight(self.k_proj)
v_proj_weight = self.get_linear_weight(self.v_proj)

self.qkv = FusedQKV(
self.kv_embed_dim,
(self.num_query_heads * self.head_dim + 2 * (self.num_kv_heads * self.head_dim)),
bias=False,
num_q_heads=self.num_query_heads,
q_head_dim=self.head_dim,
num_kv_heads=self.num_kv_heads,
kv_head_dim=self.head_dim,
)
self.qkv.linear.weight.data = torch.cat([q_proj_weight, k_proj_weight, v_proj_weight], dim=0)

# print(f"qkv.weight.shape: {self.qkv.linear.weight.shape}")
self.is_fused_qkv = True
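
get_linear_weight flattens a DenseGeneral weight (stored input-dimensions-first) into a 2-D matrix and transposes it into the (out_features, in_features) layout nn.Linear expects; patch_fused_qkv then stacks the converted Q, K and V matrices along dim 0 into one fused projection and flips is_fused_qkv so forward takes the fused path. It is meant to be called once after weights are loaded. A hedged usage sketch; the module-walking loop is an assumption about how a caller might apply the patch, not code from this PR:

    # Hypothetically fuse QKV in every self-attention block after loading the checkpoint.
    for module in model.modules():
        if isinstance(module, SelfAttention):
            module.patch_fused_qkv()   # builds module.qkv from the q_proj/k_proj/v_proj weights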

def forward(
self,
Xq: torch.Tensor, # (B, T, D) T = 1 in AR generation
Xkv: torch.Tensor, # (B, S, E) S = 1 in AR generation
X: torch.Tensor, # (B, T, D) T = 1 in AR generation
q_positions: torch.Tensor, # (B, T)
kv_positions: torch.Tensor | None = None, # (B, S)
attn_mask: torch.Tensor | None = None, # None in Decoder Self Attention, Valid mask in Others
@@ -273,36 +472,43 @@ def forward(
"""
if kv_positions is None:
kv_positions = q_positions
original_dtype = Xq.dtype

Xq_BxTxNxH = self.q_proj(Xq)
Xq_BxTxNxH = self.rotary_emb(Xq_BxTxNxH, position=q_positions)
original_dtype = X.dtype

if self.is_fused_qkv:
Xq_BxTxNxH, Xk_BxSxKxH, Xv_BxSxKxH = self.qkv(X)
else:
Xq_BxTxNxH = self.q_proj(X)
Xk_BxSxKxH = self.k_proj(X)
Xv_BxSxKxH = self.v_proj(X)

position = q_positions.unsqueeze(-1).unsqueeze(-1)
sinusoid_inp = position / self.rotary_emb.timescale
sin = torch.sin(sinusoid_inp)
cos = torch.cos(sinusoid_inp)

Xq_BxTxNxH = self.rotary_emb.apply_rope(Xq_BxTxNxH, sin, cos)
Xk_BxSxKxH = self.rotary_emb.apply_rope(Xk_BxSxKxH, sin, cos)

Xq_BxNxTxH = Xq_BxTxNxH.transpose(1, 2)

attn_k: torch.Tensor | None = None
attn_v: torch.Tensor | None = None

if self.is_cross_attn:
attn_k, attn_v = cache.k, cache.v
Xk_BxKxSxH = Xk_BxSxKxH.transpose(1, 2) # (B, K, S, H)
Xv_BxKxSxH = Xv_BxSxKxH.transpose(1, 2) # (B, K, S, H)

if cache is None:
attn_k = Xk_BxKxSxH
attn_v = Xv_BxKxSxH
elif prefill:
attn_k, attn_v = Xk_BxKxSxH, Xv_BxKxSxH
cache.prefill(attn_k, attn_v)
else:
Xk_BxSxKxH = self.k_proj(Xkv) # (B, S, K, H)
Xv_BxSxKxH = self.v_proj(Xkv) # (B, S, K, H)
Xk_BxSxKxH = self.rotary_emb(Xk_BxSxKxH, position=kv_positions) # (B, S, K, H)

Xk_BxKxSxH = Xk_BxSxKxH.transpose(1, 2) # (B, K, S, H)
Xv_BxKxSxH = Xv_BxSxKxH.transpose(1, 2) # (B, K, S, H)

if cache is None:
attn_k = Xk_BxKxSxH
attn_v = Xv_BxKxSxH
elif prefill:
attn_k, attn_v = Xk_BxKxSxH, Xv_BxKxSxH
cache.prefill(attn_k, attn_v)
else:
attn_k, attn_v = cache.update(Xk_BxKxSxH, Xv_BxKxSxH, current_idx)
attn_k, attn_v = cache.update(Xk_BxKxSxH, Xv_BxKxSxH, current_idx)

# Use custom attention for MPS backend, otherwise use optimized PyTorch function
is_mps = Xq.device.type == "mps" and torch.backends.mps.is_available()
is_mps = Xv_BxSxKxH.device.type == "mps" and torch.backends.mps.is_available()
if is_mps:
attn_output = custom_scaled_dot_product_attention(
query=Xq_BxNxTxH,
@@ -346,7 +552,7 @@ def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
eps=model_config.normalization_layer_epsilon,
dtype=torch.float32,
)
self.self_attention = Attention(
self.self_attention = SelfAttention(
config,
q_embed_dim=embed_dim,
kv_embed_dim=embed_dim,
@@ -373,8 +579,7 @@ def forward(
x_norm = self.pre_sa_norm(x).to(self.compute_dtype)

sa_out = self.self_attention(
Xq=x_norm,
Xkv=x_norm,
X=x_norm,
q_positions=state.positions,
kv_positions=state.positions,
attn_mask=state.attn_mask,
Expand Down Expand Up @@ -456,7 +661,7 @@ def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
)

# Self-Attention (GQA) with Causal Masking
self.self_attention = Attention(
self.self_attention = SelfAttention(
config,
q_embed_dim=dec_embed_dim,
kv_embed_dim=dec_embed_dim,
@@ -468,15 +673,14 @@ def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
out_embed_dim=dec_embed_dim,
)
# Cross-Attention (MHA)
self.cross_attention = Attention(
self.cross_attention = CrossAttention(
config=config,
q_embed_dim=dec_embed_dim,
kv_embed_dim=enc_embed_dim, # Note kv_embed_dim
num_query_heads=dec_config.cross_query_heads,
num_kv_heads=dec_config.cross_query_heads,
head_dim=dec_config.cross_head_dim,
compute_dtype=compute_dtype,
is_cross_attn=True,
out_embed_dim=dec_embed_dim,
)
# MLP
@@ -501,8 +705,7 @@ def forward(
self_attn_mask = state.casual_attn_mask[None, None, current_idx]

sa_out = self.self_attention(
Xq=x_norm, # (2, 1, D)
Xkv=x_norm, # (2, 1, D)
X=x_norm, # (2, 1, D)
q_positions=state.dec_positions, # (2, 1)
kv_positions=state.dec_positions, # (2, 1)
attn_mask=self_attn_mask,
Expand All @@ -518,11 +721,9 @@ def forward(
x_norm = self.pre_ca_norm(x).to(self.compute_dtype)
ca_out = self.cross_attention(
Xq=x_norm,
Xkv=state.enc_out,
q_positions=state.dec_positions,
kv_positions=state.enc_positions,
cache=cross_attn_cache,
current_idx=current_idx,
)
x = residual + ca_out

10 changes: 8 additions & 2 deletions dia/model.py
@@ -339,6 +339,7 @@ def _prepare_generation(
self,
text: torch.Tensor,
audio_prompts: list[torch.Tensor | None],
max_tokens: int | None = None,
):
"""Initializes the model state for generation.

@@ -371,7 +372,12 @@ def _prepare_generation(
encoder_out, enc_state.positions, enc_state.padding_mask
)
dec_state = DecoderInferenceState.new(
self.config, enc_state, encoder_out, dec_cross_attn_cache, self.compute_dtype
self.config,
enc_state,
encoder_out,
dec_cross_attn_cache,
self.compute_dtype,
max_generation_length=max_tokens,
)
prefill, prefill_steps = self._prepare_audio_prompt(audio_prompts)

@@ -663,7 +669,7 @@ def generate(
text = [self._encode_text(text)]
text = self._pad_text_input(text)

dec_state, dec_output = self._prepare_generation(text, audio_prompt)
dec_state, dec_output = self._prepare_generation(text, audio_prompt, max_tokens=max_tokens)
dec_step = min(dec_output.prefill_steps) - 1
current_idx = torch.tensor([dec_step], device=self.device)
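
On the model side, max_tokens is now threaded from generate() through _prepare_generation() into DecoderInferenceState.new(..., max_generation_length=max_tokens), so the decoder state (and presumably its KV-cache allocation) is sized to the requested generation length rather than a fixed configured maximum. A hedged usage sketch; the loader call follows the project's documented usage and is an assumption here, not part of this diff:

    from dia.model import Dia

    model = Dia.from_pretrained("nari-labs/Dia-1.6B")
    audio = model.generate("[S1] Hello there. [S2] Hi!", max_tokens=512)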
