-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy path: naive_latent_space_decoder.py
More file actions
67 lines (50 loc) · 2.09 KB
/
naive_latent_space_decoder.py
File metadata and controls
67 lines (50 loc) · 2.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import math
from typing import Any
import torch.nn as nn
from ice_station_zebra.types import DataSpace, TensorNCHW, TensorNTCHW
from .base_decoder import BaseDecoder
class NaiveLatentSpaceDecoder(BaseDecoder):
    """
    Naive, linear decoder that takes data in a latent space and translates it to a larger output space

    Latent space:
        TensorNCHW with (batch_size, latent_channels, latent_height, latent_width)

    Output space:
        TensorNTCHW with (batch_size, n_forecast_steps, output_channels, output_height, output_width)
    """

    def __init__(
        self, *, latent_space: DataSpace, output_space: DataSpace, **kwargs: Any
    ) -> None:
        """Build the stack of upscaling layers.

        Args:
            latent_space: DataSpace describing the latent input (channels, shape)
            output_space: DataSpace describing the desired output (channels, shape)
            **kwargs: forwarded to BaseDecoder (expected to provide
                self.n_forecast_steps — TODO confirm against BaseDecoder)
        """
        super().__init__(**kwargs)

        # List of layers
        layers: list[nn.Module] = []

        # Add size-doubling transposed convolutions for as long as the result
        # stays at or below the output shape (the final exact fit is done by
        # Upsample below). Each layer (kernel 4, stride 2, padding 1) doubles
        # the spatial size and halves the channel count. Using
        # min(output) / max(latent) guarantees no output dimension is
        # overshot even for non-square shapes. max(0, ...) makes the
        # "latent already larger than output" case an explicit no-op.
        n_conv_layers = max(
            0,
            math.floor(math.log2(min(output_space.shape) / max(latent_space.shape))),
        )
        n_channels = latent_space.channels
        for _ in range(n_conv_layers):
            # Clamp to at least one channel: with few latent channels,
            # repeated halving would otherwise request a ConvTranspose2d with
            # zero output channels, which raises at construction time.
            out_channels = max(1, n_channels // 2)
            layers.append(
                nn.ConvTranspose2d(
                    n_channels, out_channels, kernel_size=4, stride=2, padding=1
                )
            )
            n_channels = out_channels

        # Resample (interpolate) to the exact desired output shape
        layers.append(nn.Upsample(output_space.shape))

        # 1x1-convolve to the desired number of output channels, one set of
        # channels per forecast step
        layers.append(
            nn.Conv2d(n_channels, output_space.channels * self.n_forecast_steps, 1)
        )

        # Unflatten the fused (time * channels) dimension into (time, channels)
        layers.append(nn.Unflatten(1, [self.n_forecast_steps, output_space.channels]))

        # Combine the layers sequentially
        self.model = nn.Sequential(*layers)

    def forward(self, x: TensorNCHW) -> TensorNTCHW:
        """
        Decode a latent tensor into a multi-step forecast.

        Args:
            x: TensorNCHW with (batch_size, latent_channels, latent_height, latent_width)

        Returns:
            TensorNTCHW with (batch_size, n_forecast_steps, output_channels, output_height, output_width)
        """
        return self.model(x)