-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathnaive_latent_space_encoder.py
More file actions
65 lines (48 loc) · 2.04 KB
/
naive_latent_space_encoder.py
File metadata and controls
65 lines (48 loc) · 2.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import math
from typing import Any
import torch.nn as nn
from ice_station_zebra.types import DataSpace, TensorNCHW, TensorNTCHW
from .base_encoder import BaseEncoder
class NaiveLatentSpaceEncoder(BaseEncoder):
    """
    Naive, linear encoder that takes data in an input space and translates it to a smaller latent space

    Input space:
        TensorNTCHW with (batch_size, n_history_steps, input_channels, input_height, input_width)

    Latent space:
        TensorNCHW with (batch_size, latent_channels, latent_height, latent_width)
    """

    def __init__(
        self, *, input_space: DataSpace, latent_space: DataSpace, **kwargs: Any
    ) -> None:
        """Construct the encoder layers mapping input space to latent space.

        Args:
            input_space: DataSpace describing the input (name, channels, spatial shape)
            latent_space: DataSpace describing the latent target (channels, spatial shape)
            **kwargs: forwarded to BaseEncoder; presumably supplies
                n_history_steps among others — TODO confirm against BaseEncoder
        """
        super().__init__(name=input_space.name, **kwargs)

        # Construct list of layers
        layers: list[nn.Module] = []

        # Start by flattening the time and channel dimensions: NTCHW -> N(T*C)HW
        layers.append(nn.Flatten(1, 2))
        n_channels = input_space.channels * self.n_history_steps

        # Add size-reducing convolutional layers while we are larger than the
        # latent shape. Each stride-2 convolution halves the spatial extent, so
        # the number of layers is floor(log2(smallest input side / largest
        # latent side)). Calling min()/max() on the shape tuple directly
        # (instead of unpacking with *) also supports one-element shapes, and
        # max(0, ...) explicitly skips downsampling when the latent spatial
        # shape is not smaller than the input.
        n_conv_layers = max(
            0,
            math.floor(
                math.log2(min(input_space.shape) / max(latent_space.shape))
            ),
        )
        for _ in range(n_conv_layers):
            # kernel_size=4, stride=2, padding=1 halves H and W exactly for
            # even inputs; channels are doubled at each step
            layers.append(
                nn.Conv2d(
                    n_channels, 2 * n_channels, kernel_size=4, stride=2, padding=1
                )
            )
            n_channels *= 2

        # Resample (interpolate) to exactly the desired latent spatial shape
        layers.append(nn.Upsample(latent_space.shape))

        # 1x1 convolution to reach the desired number of latent channels
        layers.append(nn.Conv2d(n_channels, latent_space.channels, 1))

        # Combine the layers sequentially
        self.model = nn.Sequential(*layers)

    def forward(self, x: TensorNTCHW) -> TensorNCHW:
        """
        Apply the encoder to a batch of input tensors.

        Args:
            x: TensorNTCHW with (batch_size, n_history_steps, input_channels, input_height, input_width)

        Returns:
            TensorNCHW with (batch_size, latent_channels, latent_height, latent_width)
        """
        return self.model(x)