Skip to content

Commit 670abfb

Browse files
committed
feat(attention): implement lightweight MultiheadAttention using PyTorch SDPA
- Uses PyTorch scaled_dot_product_attention with Flash Attention dispatch - More efficient than naive attention; allows for longer sequences (larger T) or bigger batches before OOM
1 parent f84b06e commit 670abfb

3 files changed

Lines changed: 130 additions & 4 deletions

File tree

nanotabpfn/attention.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
# attention.py
2+
import torch
3+
import torch.nn as nn
4+
import torch.nn.functional as F
5+
from typing import Tuple
6+
7+
8+
class MultiheadAttention(nn.Module):
    """
    Minimal Multi-Head Attention using PyTorch's scaled_dot_product_attention (SDPA).

    SDPA picks a backend automatically at call time:
    - On CUDA with half-precision inputs (fp16 / bf16) and a supported head
      dimension it can use the **Flash Attention** kernel.
    - Otherwise (e.g. fp32 inputs or CPU) it falls back to the
      memory-efficient or math kernel.
    The exact dispatch conditions depend on the installed PyTorch version;
    see the SDPA documentation.

    No attention mask, dropout or causal masking is applied, and attention
    weights are never materialised (the second return value is always None).

    Tensor shape notation:
    B = Batch size
    T = Sequence length
    E = Embedding dimension
    H = Number of attention heads
    D = Per-head dimension (D = E / H)

    Parameters
    ----------
    embed_dim : int
        Input/output embedding size (E).
    num_heads : int
        Number of attention heads (H). Must be positive and divide embed_dim.
    batch_first : bool, default True
        If True, input/output is (B, T, E). If False, (T, B, E).
    bias : bool, default True
        Include bias terms in the q/k/v/out projections.
    device, dtype : Optional
        Device and dtype forwarded to the four ``nn.Linear`` projections.

    Raises
    ------
    ValueError
        If ``num_heads`` is not positive or does not divide ``embed_dim``.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        batch_first: bool = True,
        bias: bool = True,
        device: torch.device = None,
        dtype: torch.dtype = None,
    ):
        super().__init__()
        # Explicit raises instead of `assert`: assertions are removed when
        # Python runs with -O, which would let invalid configs through.
        if num_heads <= 0:
            raise ValueError(f"num_heads must be a positive integer, got {num_heads}")
        if embed_dim % num_heads != 0:
            raise ValueError("embed_dim must be divisible by num_heads")

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.batch_first = batch_first

        factory_kwargs = {"device": device, "dtype": dtype}
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)

    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape (B, T, E) -> (B, H, T, D) so each head attends independently."""
        batch, seq_len, _ = x.shape
        return x.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape (B, H, T, D) -> (B, T, E), concatenating the heads."""
        batch, _, seq_len, _ = x.shape
        # transpose yields a non-contiguous tensor; make it contiguous before view
        return x.transpose(1, 2).contiguous().view(batch, seq_len, self.embed_dim)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> Tuple[torch.Tensor, None]:
        """
        Compute multi-head attention.

        Uses PyTorch's scaled_dot_product_attention (SDPA), which
        automatically dispatches to the **Flash Attention kernel** when available.
        Query and key/value may have different sequence lengths.

        Args
        ----
        query : Tensor
            (B, Tq, E) if batch_first else (Tq, B, E)
        key : Tensor
            (B, Tk, E) if batch_first else (Tk, B, E)
        value : Tensor
            (B, Tk, E) if batch_first else (Tk, B, E)

        Returns
        -------
        attn_output : Tensor
            Same layout as input (batch_first preserved).
        None :
            Placeholder for attention weights (not computed).
        """
        if not self.batch_first:
            # convert (T, B, E) -> (B, T, E)
            query = query.transpose(0, 1)
            key = key.transpose(0, 1)
            value = value.transpose(0, 1)

        # Linear projections followed by the per-head split
        q = self._split_heads(self.q_proj(query))  # (B, H, Tq, D)
        k = self._split_heads(self.k_proj(key))    # (B, H, Tk, D)
        v = self._split_heads(self.v_proj(value))  # (B, H, Tk, D)

        # SDPA: Flash Attention efficiency when available
        attn = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=False,
        )  # (B, H, Tq, D)

        out = self.out_proj(self._merge_heads(attn))  # (B, Tq, E)

        if not self.batch_first:
            # convert back (B, T, E) -> (T, B, E)
            out = out.transpose(0, 1)
        # None placeholder for attention weights (not computed)
        return out, None

nanotabpfn/interface.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ def init_model_from_state_dict_file(file_path):
2424
embedding_size = state_dict['feature_encoder.linear_layer.weight'].shape[0]
2525
mlp_hidden_size = state_dict['decoder.linear1.weight'].shape[0]
2626
num_outputs = state_dict['decoder.linear2.weight'].shape[0]
27-
num_layers = sum('self_attn_between_datapoints.in_proj_weight' in k for k in state_dict)
28-
num_heads = state_dict['transformer_encoder.transformer_blocks.0.self_attn_between_datapoints.in_proj_weight'].shape[1]//64
27+
num_layers = sum('self_attn_between_datapoints.q_proj.weight' in k for k in state_dict)
28+
num_heads = state_dict['transformer_encoder.transformer_blocks.0.self_attn_between_datapoints.q_proj.weight'].shape[1]//64
2929
model = NanoTabPFNModel(
3030
num_attention_heads=num_heads,
3131
embedding_size=embedding_size,

nanotabpfn/model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import torch
22
from torch import nn
33
import torch.nn.functional as F
4-
from torch.nn.modules.transformer import MultiheadAttention, Linear, LayerNorm
5-
from typing import Tuple, Union
4+
from torch.nn.modules.transformer import Linear, LayerNorm
5+
from .attention import MultiheadAttention
6+
from typing import Tuple
67

78

89
class NanoTabPFNModel(nn.Module):

0 commit comments

Comments
 (0)