Optimize Multi-head Latent Attention (MLA) with Fast Path for Short Sequences #684
base: main
@@ -85,13 +85,6 @@ class ModelArgs:
 
 
 class ParallelEmbedding(nn.Module):
-    """
-    Embedding layer with parallelism support across distributed processes.
-
-    Args:
-        vocab_size (int): Vocabulary size.
-        dim (int): Embedding dimension.
-    """
     def __init__(self, vocab_size: int, dim: int):
         super().__init__()
         self.vocab_size = vocab_size

@@ -103,18 +96,6 @@ def __init__(self, vocab_size: int, dim: int):
         self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim))
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        Forward pass for parallel embedding layer.
-
-        Args:
-            x (torch.Tensor): Input tensor containing token indices.
-
-        Returns:
-            torch.Tensor: Embedded representations.
-
-        Raises:
-            ValueError: If `world_size` is not defined.
-        """
         if world_size > 1:
             mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
             x = x - self.vocab_start_idx

Comment on lines -106 to -117: Docstrings were removed. The PR description mentions enhancing documentation, so this seems contradictory.

@@ -162,15 +143,6 @@ def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] =
 
 
 class Linear(nn.Module):
-    """
-    Custom linear layer with support for quantized weights and optional bias.
-
-    Args:
-        in_features (int): Number of input features.
-        out_features (int): Number of output features.
-        bias (bool): Whether to include a bias term. Defaults to False.
-        dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
-    """
     dtype = torch.bfloat16
 
     def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):

Comment on lines -165 to -173: Docstrings were removed. The PR description mentions enhancing documentation, so this seems contradictory.

@@ -190,15 +162,6 @@ def __init__(self, in_features: int, out_features: int, bias: bool = False, dtyp
             self.register_parameter("bias", None)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        Forward pass for the custom linear layer.
-
-        Args:
-            x (torch.Tensor): Input tensor.
-
-        Returns:
-            torch.Tensor: Transformed tensor after linear computation.
-        """
         return linear(x, self.weight, self.bias)
 
 

Comment on lines -193 to -201: Docstrings were removed. The PR description mentions enhancing documentation, so this seems contradictory.

@@ -440,7 +403,7 @@ def __init__(self, args: ModelArgs):
             self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), persistent=False)
             self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False)
 
-    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
+    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
         """
         Forward pass for the Multi-Headed Attention Layer (MLA).
 
@@ -453,45 +416,67 @@ def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask
         Returns:
             torch.Tensor: Output tensor with the same shape as the input.
         """
-        bsz, seqlen, _ = x.size()
-        end_pos = start_pos + seqlen
+        bsz, seqlen, _ = x.shape
+
+        # Fast path for short sequences without masks
+        use_fast_path = seqlen <= 256 and mask is None
+
         if self.q_lora_rank == 0:
             q = self.wq(x)
         else:
             q = self.wq_b(self.q_norm(self.wq_a(x)))
+
+        kv_out = self.wkv_a(x)
+        kv_pe, kv_in = kv_out[:, :, :self.qk_rope_head_dim], kv_out[:, :, self.qk_rope_head_dim:]
+        kv_in = self.wkv_b(self.kv_norm(kv_in))
+        k_nope, v = kv_in[:, :, :self.n_local_heads*self.qk_nope_head_dim], kv_in[:, :, self.n_local_heads*self.qk_nope_head_dim:]
+
         q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
-        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-        q_pe = apply_rotary_emb(q_pe, freqs_cis)
-        kv = self.wkv_a(x)
-        kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-        k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)
+        k_nope = k_nope.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim)
+        v = v.view(bsz, seqlen, self.n_local_heads, self.v_head_dim)
+
+        q_rope, q_nope = q[:, :, :, :self.qk_rope_head_dim], q[:, :, :, self.qk_rope_head_dim:]
+        k_rope = kv_pe.view(bsz, seqlen, self.n_local_heads, self.qk_rope_head_dim)
+
         if attn_impl == "naive":
-            q = torch.cat([q_nope, q_pe], dim=-1)
-            kv = self.wkv_b(self.kv_norm(kv))
-            kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
-            k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
-            k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
-            self.k_cache[:bsz, start_pos:end_pos] = k
-            self.v_cache[:bsz, start_pos:end_pos] = v
-            scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) * self.softmax_scale
+            self.k_cache[: bsz, start_pos: start_pos + seqlen] = torch.cat([k_rope, k_nope], dim=-1)
+            self.v_cache[: bsz, start_pos: start_pos + seqlen] = v
+            k = self.k_cache[: bsz, : start_pos + seqlen]
+            v = self.v_cache[: bsz, : start_pos + seqlen]
         else:
-            wkv_b = self.wkv_b.weight if self.wkv_b.scale is None else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size)
-            wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
-            q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
-            self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv)
-            self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
-            scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
-                      torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale
-        if mask is not None:
-            scores += mask.unsqueeze(1)
-        scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)
-        if attn_impl == "naive":
-            x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos])
+            self.kv_cache[: bsz, start_pos: start_pos + seqlen] = kv_in
+            self.pe_cache[: bsz, start_pos: start_pos + seqlen] = kv_pe
+            k = torch.cat([k_rope, k_nope], dim=-1)
+
+        q = apply_rotary_emb(q_rope, freqs_cis)
+        k = apply_rotary_emb(k_rope, freqs_cis)
+
+        if use_fast_path:
+            # Optimized path for short sequences
+            q = q.transpose(1, 2)  # [bsz, n_local_heads, seqlen, head_dim]
+            k = k.transpose(1, 2)
+            v = v.transpose(1, 2)
+
+            # Single matmul for attention scores
+            scores = torch.matmul(q, k.transpose(-2, -1)) * self.softmax_scale
+            scores = F.softmax(scores, dim=-1, dtype=torch.float32)
+
+            # Single matmul for output computation
+            output = torch.matmul(scores, v)
         else:
-            x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
-            x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
-        x = self.wo(x.flatten(2))
-        return x
+            # Standard path for longer sequences or when mask is needed
+            q = q.transpose(1, 2)
+            k = k.transpose(1, 2)
+            v = v.transpose(1, 2)
+
+            scores = torch.matmul(q, k.transpose(-2, -1)) * self.softmax_scale
+            if mask is not None:
+                scores = scores + mask
+            scores = F.softmax(scores, dim=-1, dtype=torch.float32)
+            output = torch.matmul(scores, v)
+
+        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
+        return self.wo(output)
 
 
 class MLP(nn.Module):
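
As a quick way to see what the new `use_fast_path` gate actually skips, here is a minimal standalone sketch, not part of this PR: the tensor shapes, the `simple_attention` helper, and the causal-mask construction are illustrative assumptions. When the fast path is taken, the mask addition is dropped and the scores and output reduce to two plain matmuls.

import torch
import torch.nn.functional as F

def simple_attention(q, k, v, softmax_scale, mask=None):
    # q, k, v: [bsz, n_heads, seqlen, head_dim]
    scores = torch.matmul(q, k.transpose(-2, -1)) * softmax_scale
    if mask is not None:
        scores = scores + mask  # standard path only: add the attention mask
    scores = F.softmax(scores, dim=-1, dtype=torch.float32).type_as(q)
    return torch.matmul(scores, v)

bsz, n_heads, seqlen, head_dim = 2, 16, 128, 64  # illustrative sizes
q, k, v = (torch.randn(bsz, n_heads, seqlen, head_dim) for _ in range(3))
softmax_scale = head_dim ** -0.5

# Fast path: seqlen <= 256 and mask is None, so the mask add is skipped.
out_fast = simple_attention(q, k, v, softmax_scale, mask=None)

# Standard path: an explicit causal mask is added before the softmax.
causal_mask = torch.full((seqlen, seqlen), float("-inf")).triu_(1)
out_std = simple_attention(q, k, v, softmax_scale, mask=causal_mask)

print(out_fast.shape, out_std.shape)  # torch.Size([2, 16, 128, 64]) for both
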
@@ -757,7 +742,7 @@ def __init__(self, args: ModelArgs):
         Linear.dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16
         super().__init__()
         self.max_seq_len = args.max_seq_len
-        self.embed = ParallelEmbedding(args.vocab_size, args.dim)
+        self.embed = ParallelEmbedding(args.vocab_size, args.dim, memory_efficient=True)
         self.layers = torch.nn.ModuleList()
         for layer_id in range(args.n_layers):
             self.layers.append(Block(layer_id, args))

Comment on the `ParallelEmbedding` change: This change isn't mentioned in the description.

Docstrings were removed. The PR description mentions enhancing documentation, so this seems contradictory. Could these be restored? You should be adding docstrings if you want to enhance them.