
Commit b5d15a1

test: add some unittests
1 parent e660cce commit b5d15a1

4 files changed, +149 -139 lines changed


cellseg_models_pytorch/modules/conv_base.py

Lines changed: 6 additions & 138 deletions
@@ -31,7 +31,7 @@ def __init__(
         bias: bool = False,
         attention: str = None,
         preattend: bool = False,
-        **kwargs
+        **kwargs,
     ) -> None:
         """Conv-block (basic) parent class.

@@ -143,7 +143,7 @@ def __init__(
         bias: bool = False,
         attention: str = None,
         preattend: bool = False,
-        **kwargs
+        **kwargs,
     ) -> None:
         """Bottleneck conv block parent-class.

@@ -305,7 +305,7 @@ def __init__(
         kernel_size: int = 3,
         attention: str = None,
         preattend: bool = False,
-        **kwargs
+        **kwargs,
     ) -> None:
         """Depthwise separable conv block parent class.

@@ -427,7 +427,7 @@ def __init__(
         kernel_size: int = 3,
         attention: str = None,
         preattend: bool = False,
-        **kwargs
+        **kwargs,
     ) -> None:
         """Mobile inverted bottleneck conv parent-class.

@@ -578,7 +578,7 @@ def __init__(
         activation: str = "relu",
         attention: str = None,
         preattend: bool = False,
-        **kwargs
+        **kwargs,
     ) -> None:
         """Fused mobile inverted conv block parent-class.

@@ -714,7 +714,7 @@ def __init__(
         kernel_size: int = 3,
         attention: str = None,
         preattend: bool = False,
-        **kwargs
+        **kwargs,
     ) -> None:
         """Dense block of the HoVer-Net.

@@ -830,135 +830,3 @@ def forward_features_preact(self, x: torch.Tensor) -> torch.Tensor:
         x = self.conv(x)

         return x
-
-
-class BasicConvOld(nn.Module):
-    def __init__(
-        self,
-        in_channels: int,
-        out_channels: int,
-        same_padding: bool = True,
-        normalization: str = "bn",
-        activation: str = "relu",
-        convolution: str = "conv",
-        preactivate: bool = False,
-        kernel_size=3,
-        groups: int = 1,
-        bias: bool = False,
-        attention: str = None,
-        preattend: bool = False,
-        **kwargs
-    ) -> None:
-        """Conv-block (basic) parent class.
-
-        Parameters
-        ----------
-        in_channels : int
-            Number of input channels.
-        out_channels : int
-            Number of output channels.
-        same_padding : bool, default=True
-            if True, performs same-covolution.
-        normalization : str, default="bn":
-            Normalization method.
-            One of: "bn", "bcn", "gn", "in", "ln", None
-        activation : str, default="relu"
-            Activation method.
-            One of: "mish", "swish", "relu", "relu6", "rrelu", "selu",
-            "celu", "gelu", "glu", "tanh", "sigmoid", "silu", "prelu",
-            "leaky-relu", "elu", "hardshrink", "tanhshrink", "hardsigmoid"
-        convolution : str, default="conv"
-            The convolution method. One of: "conv", "wsconv", "scaled_wsconv"
-        preactivate : bool, default=False
-            If True, normalization will be applied before convolution.
-        kernel_size : int, default=3
-            The size of the convolution kernel.
-        groups : int, default=1
-            Number of groups the kernels are divided into. If `groups == 1`
-            normal convolution is applied. If `groups = in_channels`
-            depthwise convolution is applied.
-        bias : bool, default=False,
-            Include bias term in the convolution.
-        attention : str, default=None
-            Attention method. One of: "se", "scse", "gc", "eca", "msca", None
-        preattend : bool, default=False
-            If True, Attention is applied at the beginning of forward pass.
-        """
-        super().__init__()
-        self.conv_choice = convolution
-        self.out_channels = out_channels
-        self.preattend = preattend
-        self.preactivate = preactivate
-
-        # set norm channel number for preactivation or normal
-        norm_channels = in_channels if preactivate else self.out_channels
-
-        # set padding. Works if dilation or stride are not adjusted
-        padding = (kernel_size - 1) // 2 if same_padding else 0
-
-        self.conv = Conv(
-            name=self.conv_choice,
-            in_channels=in_channels,
-            out_channels=out_channels,
-            kernel_size=kernel_size,
-            groups=groups,
-            padding=padding,
-            bias=bias,
-        )
-
-        self.norm = Norm(normalization, num_features=norm_channels)
-        self.act = Activation(activation)
-
-        # set attention channels
-        att_channels = in_channels if preattend else self.out_channels
-        self.att = Attention(attention, in_channels=att_channels)
-
-        self.downsample = None
-        if in_channels != out_channels:
-            self.downsample = nn.Sequential(
-                Conv(
-                    self.conv_choice,
-                    in_channels=in_channels,
-                    out_channels=out_channels,
-                    bias=False,
-                    kernel_size=1,
-                    padding=0,
-                ),
-                Norm(normalization, num_features=out_channels),
-            )
-
-    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
-        """Forward pass."""
-        identity = x
-        if self.downsample is not None:
-            identity = self.downsample(x)
-
-        x = self.att(x)
-
-        # residual
-        x = self.conv(x)
-        x = self.norm(x)
-
-        x += identity
-        x = self.act(x)
-
-        return x
-
-    def forward_features_preact(self, x: torch.Tensor) -> torch.Tensor:
-        """Forward pass with pre-activation."""
-        identity = x
-        if self.downsample is not None:
-            identity = self.downsample(x)
-
-        # pre-attention
-        x = self.att(x)
-
-        # preact residual
-        x = self.norm(x)
-        x = self.act(x)
-        x = self.conv(x)
-
-        x += identity
-        x = self.act(x)
-
-        return x
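Apart from the trailing commas after **kwargs, the conv_base.py diff only deletes the unused BasicConvOld class, which implemented a residual conv block with two forward orderings (forward_features and forward_features_preact). A minimal sketch of those two orderings, using plain torch.nn layers in place of the repo's Conv/Norm/Activation factories and omitting the attention and downsample branches (an illustration, not the library's API):

# Sketch of the two residual orderings from the removed BasicConvOld,
# with plain torch.nn layers standing in for the repo's Conv/Norm/Activation.
import torch
import torch.nn as nn

conv = nn.Conv2d(16, 16, kernel_size=3, padding=1, bias=False)
norm = nn.BatchNorm2d(16)
act = nn.ReLU()

x = torch.rand(1, 16, 8, 8)

# forward_features: conv -> norm -> add identity -> activation
post = act(norm(conv(x)) + x)

# forward_features_preact: norm -> activation -> conv -> add identity -> activation
pre = act(conv(act(norm(x))) + x)

assert post.shape == pre.shape == x.shape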
Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
+import pytest
+import torch
+from torch.autograd import gradcheck
+from cellseg_models_pytorch.modules.act.swish import Swish
+from cellseg_models_pytorch.modules.act.mish import Mish
+from cellseg_models_pytorch.modules.act.gated_gelu import GEGLU
+
+@pytest.mark.parametrize("dim_in, dim_out", [
+    (64, 128),
+    (128, 256),
+])
+def test_geglu(dim_in, dim_out):
+    # Create the GEGLU layer
+    geglu = GEGLU(dim_in=dim_in, dim_out=dim_out)
+
+    # Create a random input tensor with the specified shape
+    x = torch.rand((1, 32, dim_in))
+
+    # Forward pass
+    output = geglu(x)
+
+    # Check the output shape
+    assert output.shape == (1, 32, dim_out)
+
+    # Check the output type
+    assert isinstance(output, torch.Tensor)
+
+
+@pytest.mark.parametrize("batch_size, num_features", [
+    (1, 10),
+    (2, 20),
+])
+def test_mish_fwdbwd(batch_size, num_features):
+    # Create the Mish layer
+    mish_layer = Mish()
+
+    # Create a random input tensor with the specified shape
+    x = torch.randn(batch_size, num_features, requires_grad=True)
+
+    # Forward pass
+    output = mish_layer(x)
+
+    # Check the output shape
+    assert output.shape == x.shape
+
+    # Check the output type
+    assert isinstance(output, torch.Tensor)
+
+    # Backward pass
+    output.sum().backward()
+    assert x.grad is not None
+
+    # Gradient check
+    # assert gradcheck(mish, (x,), eps=1e-6, atol=1e-4)
+
+
+@pytest.mark.parametrize("batch_size, num_features", [
+    (1, 10),
+    (2, 20),
+])
+def test_swish_fwdbwd(batch_size, num_features):
+    # Create the Swish layer
+    swish_layer = Swish()
+
+    # Create a random input tensor with the specified shape
+    x = torch.randn(batch_size, num_features, requires_grad=True)
+
+    # Forward pass
+    output = swish_layer(x)
+
+    # Check the output shape
+    assert output.shape == x.shape
+
+    # Check the output type
+    assert isinstance(output, torch.Tensor)
+
+    # Backward pass
+    output.sum().backward()
+    assert x.grad is not None
+
+    # Gradient check
+    # assert gradcheck(swish, (x,))
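The gradcheck asserts are left commented out, and as written they reference lowercase mish/swish names that the tests never define. A hedged sketch of how the check could be enabled, assuming the activation modules accept float64 inputs (gradcheck needs double precision):

# Sketch only: enabling the commented-out gradient checks.
# torch.autograd.gradcheck expects double-precision inputs with requires_grad=True,
# and the callable under test here is the instantiated layer itself.
import torch
from torch.autograd import gradcheck

from cellseg_models_pytorch.modules.act.mish import Mish
from cellseg_models_pytorch.modules.act.swish import Swish

x = torch.randn(2, 10, dtype=torch.double, requires_grad=True)

assert gradcheck(Mish(), (x,), eps=1e-6, atol=1e-4)
assert gradcheck(Swish(), (x,), eps=1e-6, atol=1e-4)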

cellseg_models_pytorch/modules/tests/test_conv_blocks.py

Lines changed: 29 additions & 0 deletions
@@ -2,6 +2,7 @@
 import torch

 from cellseg_models_pytorch.modules import ConvBlock
+from cellseg_models_pytorch.modules.mlp import ConvMlp


 @pytest.mark.parametrize(
@@ -43,3 +44,31 @@ def test_conv_block_fwdbwd(

     assert output.shape == torch.Size([1, out_channels, 16, 16])
     assert output.dtype == input.dtype
+
+
+@pytest.mark.parametrize("in_channels, out_channels, input_shape", [
+    (32, 16, (1, 32, 32, 32)),
+    (32, None, (2, 32, 32, 32)),
+])
+def test_convmlp(in_channels, out_channels, input_shape):
+    conv_mlp = ConvMlp(
+        in_channels=in_channels,
+        out_channels=out_channels,
+        mlp_ratio=1
+    )
+
+    if out_channels is None:
+        out_channels = in_channels
+
+    # Create a random input tensor with the specified shape
+    x = torch.rand(input_shape).float()
+
+    # Forward pass
+    output = conv_mlp(x)
+
+    # Check the output shape
+    expected_out_channels = in_channels if out_channels is None else out_channels
+    assert output.shape == (input_shape[0], expected_out_channels, input_shape[2], input_shape[3])
+
+    # Check the output type
+    assert isinstance(output, torch.Tensor)
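Going by the assertions in test_convmlp, ConvMlp keeps the spatial dimensions of its input and falls back to in_channels when out_channels is None. A short usage sketch under those assumptions:

# Usage sketch inferred from test_convmlp above; the out_channels=None
# fallback to in_channels is an assumption drawn from the test, not from docs.
import torch
from cellseg_models_pytorch.modules.mlp import ConvMlp

mlp = ConvMlp(in_channels=32, out_channels=None, mlp_ratio=1)
y = mlp(torch.rand(2, 32, 32, 32))
assert y.shape == (2, 32, 32, 32)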

cellseg_models_pytorch/modules/tests/test_miscmodules.py

Lines changed: 32 additions & 1 deletion
@@ -7,7 +7,7 @@
     StyleBlock,
     StyleReshape,
 )
-
+from cellseg_models_pytorch.modules.patch_embeddings import PatchEmbed

 @pytest.mark.parametrize("in_channels", [32, 16])
 @pytest.mark.parametrize("out_channels", [16, 32])
@@ -37,3 +37,34 @@ def test_stylefwdbwd(in_channels, style_channels, out_channels)
     out.mean().backward()

     assert out.shape[1] == out_channels
+
+
+@pytest.mark.parametrize("in_channels, patch_size, head_dim, num_heads, input_shape", [
+    (3, 8, 32, 4, (1, 3, 128, 128)),
+    (3, 4, 16, 2, (1, 3, 64, 64)),
+])
+def test_patch_embed(in_channels, patch_size, head_dim, num_heads, input_shape):
+    # Create the PatchEmbed layer
+    patch_embed = PatchEmbed(
+        in_channels=in_channels,
+        patch_size=patch_size,
+        head_dim=head_dim,
+        num_heads=num_heads
+    )
+
+    # Create a random input tensor with the specified shape
+    x = torch.randn(input_shape)
+
+    # Forward pass
+    output = patch_embed(x)
+
+    # Calculate expected output shape
+    B, _, H, W = input_shape
+    expected_seq_len = (H // patch_size) * (W // patch_size)
+    expected_proj_dim = head_dim * num_heads
+
+    # Check the output shape
+    assert output.shape == (B, expected_seq_len, expected_proj_dim)
+
+    # Check the output type
+    assert isinstance(output, torch.Tensor)
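Worked through for the first parameter set, the expected output shape follows directly from the patch grid and the projection width computed in the test:

# Worked example for (patch_size=8, head_dim=32, num_heads=4, input 1x3x128x128):
# a 128x128 image cut into 8x8 patches gives a 16x16 grid, i.e. 256 tokens,
# and each token is projected to head_dim * num_heads = 128 features.
B, H, W = 1, 128, 128
patch_size, head_dim, num_heads = 8, 32, 4

seq_len = (H // patch_size) * (W // patch_size)  # 16 * 16 = 256
proj_dim = head_dim * num_heads                  # 32 * 4 = 128

assert (B, seq_len, proj_dim) == (1, 256, 128)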
