Skip to content

Commit b19faf6

Browse files
authored
#2-add time embedding (#4)
1 parent 1373ffb commit b19faf6

File tree

1 file changed

+21
-0
lines changed

1 file changed

+21
-0
lines changed

model/embedding.py

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import torch
2+
from torch import nn
3+
4+
5+
class TimeEmbedding(nn.Module):
6+
7+
def __init__(self, embedding_dim: int):
8+
super().__init__()
9+
self.embedding_dim = embedding_dim
10+
self.half_embedding_dim = embedding_dim // 2
11+
12+
@torch.no_grad()
13+
def forward(self, x: torch.Tensor) -> torch.Tensor:
14+
exponents = (
15+
torch.arange(0, self.half_embedding_dim, dtype=torch.float32) / self.halv_embedding_dim
16+
).to(x.device)
17+
factors = 10_000 ** exponents
18+
19+
embedding_arguments = x[:, None].repeat(1, self.half_embedding_dim) / factors
20+
embeddings = torch.cat([torch.sin(embedding_arguments), torch.cos(embedding_arguments)], dim=-1)
21+
return embeddings

0 commit comments

Comments
 (0)