Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions ice_station_zebra/models/common/bottleneckblock.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Type

import torch.nn as nn
from torch import Tensor

Expand All @@ -9,18 +11,19 @@ def __init__(
out_channels: int,
*,
filter_size: int,
activation: Type[nn.Module] = nn.ReLU,
Comment thread
marianovitasari20 marked this conversation as resolved.
Outdated
) -> None:
super().__init__()

self.model = nn.Sequential(
nn.Conv2d(
in_channels, out_channels, kernel_size=filter_size, padding="same"
),
nn.ReLU(inplace=True),
activation(),
nn.Conv2d(
out_channels, out_channels, kernel_size=filter_size, padding="same"
),
nn.ReLU(inplace=True),
activation(),
nn.BatchNorm2d(num_features=out_channels),
)

Expand Down
9 changes: 6 additions & 3 deletions ice_station_zebra/models/common/convblock.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Type

import torch.nn as nn
from torch import Tensor

Expand All @@ -10,18 +12,19 @@ def __init__(
*,
filter_size: int,
final: bool = False,
activation: Type[nn.Module] = nn.ReLU,
) -> None:
super().__init__()

layers = [
nn.Conv2d(
in_channels, out_channels, kernel_size=filter_size, padding="same"
),
nn.ReLU(inplace=True),
activation(),
nn.Conv2d(
out_channels, out_channels, kernel_size=filter_size, padding="same"
),
nn.ReLU(inplace=True),
activation(),
]
if final:
layers += [
Expand All @@ -31,7 +34,7 @@ def __init__(
kernel_size=filter_size,
padding="same",
),
nn.ReLU(inplace=True),
activation(),
]

else:
Expand Down
11 changes: 9 additions & 2 deletions ice_station_zebra/models/common/upconvblock.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
from typing import Type

import torch.nn as nn
from torch import Tensor


class UpconvBlock(nn.Module):
def __init__(self, in_channels: int, out_channels: int) -> None:
def __init__(
self,
in_channels: int,
out_channels: int,
activation: Type[nn.Module] = nn.ReLU,
) -> None:
super().__init__()

self.model = nn.Sequential(
nn.Upsample(scale_factor=2, mode="nearest"),
nn.Conv2d(in_channels, out_channels, kernel_size=2, padding="same"),
nn.ReLU(inplace=True),
activation(),
)

def forward(self, x: Tensor) -> Tensor:
Expand Down
Loading