# encoding.py
import torch
import torch.nn as nn
from diffusers.models.embeddings import (
    get_1d_sincos_pos_embed_from_grid,
    get_timestep_embedding,
)


class PositionalEncoding(nn.Module):
    """Fixed 1D sinusoidal positional encoding, precomputed up to max_len."""

    def __init__(self, d_model: int, max_len: int):
        super().__init__()
        positions = torch.arange(max_len, dtype=torch.float32)
        # Build the [1, max_len, d_model] sinusoidal table once and register
        # it as a buffer: non-trainable, but saved and moved with the module.
        pe = (
            get_1d_sincos_pos_embed_from_grid(
                embed_dim=d_model, pos=positions, output_type='pt'
            )
            .unsqueeze(0)
            .float()
        )
        self.register_buffer('pe', pe)

    def forward(self, seq_len: int):
        # Slice out the first seq_len positions: [1, seq_len, d_model].
        return self.pe[:, :seq_len]


class TimestepEncoding(nn.Module):
    """Maps diffusion timesteps to embeddings: sinusoidal features + MLP."""

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
        )
        self.frequency_embedding_size = frequency_embedding_size

    def forward(self, t):
        # t: [B] or [B, 1]; drop a trailing singleton dim so that
        # get_timestep_embedding receives the 1-D tensor it expects.
        if t.dim() > 1 and t.size(-1) == 1:
            t = t.squeeze(-1)
        emb = get_timestep_embedding(
            timesteps=t,
            embedding_dim=self.frequency_embedding_size,
            scale=1,
        )
        return self.mlp(emb)  # [B, hidden_size]
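

# A minimal smoke test, not part of the original file: a sketch illustrating
# the expected input/output shapes of both modules, using assumed toy sizes
# (d_model=64, max_len=128, a batch of 4 integer timesteps in [0, 1000)).
if __name__ == "__main__":
    pos_enc = PositionalEncoding(d_model=64, max_len=128)
    print(pos_enc(seq_len=10).shape)  # torch.Size([1, 10, 64])

    t_enc = TimestepEncoding(hidden_size=64)
    t = torch.randint(0, 1000, (4,))
    print(t_enc(t).shape)  # torch.Size([4, 64])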