Skip to content

Commit 3618144

Browse files
committed
Add tests
1 parent 48baa5c commit 3618144

3 files changed

Lines changed: 149 additions & 0 deletions

File tree

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import torch
2+
3+
from auto_cast.decoders.channels_last import ChannelsLast
4+
5+
6+
def test_channels_last_reorders_dimensions():
7+
decoder = ChannelsLast()
8+
x = torch.randn(2, 3, 4, 5, 6)
9+
10+
output = decoder(x)
11+
12+
assert output.shape == (2, 4, 5, 6, 3)
13+
# Spot-check a value to ensure permutation matches expectation
14+
assert torch.allclose(output[0, 0, 0, 0, 0], x[0, 0, 0, 0, 0])
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import torch
2+
3+
from auto_cast.encoders.permute_concat import PermuteConcat
4+
from auto_cast.types import Batch
5+
6+
7+
def _make_batch(
8+
batch_size: int = 1,
9+
t: int = 1,
10+
w: int = 2,
11+
h: int = 3,
12+
c: int = 2,
13+
const_c: int = 1,
14+
scalar_c: int = 1,
15+
) -> Batch:
16+
input_fields = torch.arange(batch_size * t * w * h * c, dtype=torch.float32)
17+
input_fields = input_fields.view(batch_size, t, w, h, c)
18+
output_fields = torch.zeros(batch_size, t, w, h, c)
19+
constant_fields = torch.ones(batch_size, w, h, const_c)
20+
constant_scalars = torch.full((batch_size, scalar_c), 5.0)
21+
return Batch(
22+
input_fields=input_fields,
23+
output_fields=output_fields,
24+
constant_scalars=constant_scalars,
25+
constant_fields=constant_fields,
26+
)
27+
28+
29+
def test_permute_concat_with_constants():
30+
encoder = PermuteConcat(with_constants=True)
31+
batch = _make_batch()
32+
33+
encoded = encoder(batch)
34+
35+
expected = batch.input_fields.permute(0, 4, 1, 2, 3)
36+
37+
base_channels = batch.input_fields.shape[-1]
38+
assert batch.constant_fields is not None
39+
assert batch.constant_scalars is not None
40+
const_channels = batch.constant_fields.shape[-1]
41+
scalar_channels = batch.constant_scalars.shape[-1]
42+
43+
assert encoded.shape == (
44+
batch.input_fields.shape[0],
45+
base_channels + const_channels + scalar_channels,
46+
batch.input_fields.shape[1],
47+
batch.input_fields.shape[2],
48+
batch.input_fields.shape[3],
49+
)
50+
51+
assert torch.allclose(encoded[:, :base_channels, ...], expected)
52+
const_slice = encoded[:, base_channels : base_channels + const_channels, ...]
53+
assert torch.allclose(const_slice, torch.ones_like(const_slice))
54+
scalar_slice = encoded[:, -scalar_channels:, ...]
55+
assert torch.allclose(scalar_slice, torch.full_like(scalar_slice, 5.0))
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import lightning as L
2+
import torch
3+
from torch import nn
4+
from torch.utils.data import DataLoader, Dataset
5+
6+
from auto_cast.decoders.channels_last import ChannelsLast
7+
from auto_cast.encoders.permute_concat import PermuteConcat
8+
from auto_cast.models.encoder_decoder import EncoderDecoder
9+
from auto_cast.models.encoder_processor_decoder import EncoderProcessorDecoder
10+
from auto_cast.nn.base import Module
11+
from auto_cast.types import Batch, Tensor
12+
13+
14+
class TinyProcessor(Module):
15+
def __init__(self) -> None:
16+
super().__init__()
17+
self.conv = nn.Conv3d(in_channels=1, out_channels=1, kernel_size=1)
18+
19+
def forward(self, x: Tensor) -> Tensor:
20+
return self.conv(x)
21+
22+
23+
def _toy_batch(
24+
batch_size: int = 2,
25+
t_in: int = 2,
26+
t_out: int | None = None,
27+
w: int = 4,
28+
h: int = 4,
29+
c: int = 1,
30+
) -> Batch:
31+
t_out = t_out or t_in
32+
input_fields = torch.randn(batch_size, t_in, w, h, c)
33+
output_fields = torch.randn(batch_size, t_out, w, h, c)
34+
return Batch(
35+
input_fields=input_fields,
36+
output_fields=output_fields,
37+
constant_scalars=None,
38+
constant_fields=None,
39+
)
40+
41+
42+
class _BatchDataset(Dataset):
43+
def __len__(self) -> int:
44+
return 2 # keep small
45+
46+
def __getitem__(self, idx: int) -> Batch:
47+
return _toy_batch(batch_size=1)
48+
49+
50+
dummy_loader = DataLoader(
51+
_BatchDataset(),
52+
batch_size=1,
53+
collate_fn=lambda items: items[0],
54+
num_workers=0,
55+
)
56+
57+
58+
def test_encoder_processor_decoder_training_step_runs():
59+
encoder = PermuteConcat(with_constants=False)
60+
decoder = ChannelsLast()
61+
loss = nn.MSELoss()
62+
encoder_decoder = EncoderDecoder(encoder=encoder, decoder=decoder, loss_func=loss)
63+
64+
processor = TinyProcessor()
65+
model = EncoderProcessorDecoder.from_encoder_processor_decoder(
66+
encoder_decoder=encoder_decoder,
67+
processor=processor,
68+
loss_func=loss,
69+
)
70+
71+
batch = _toy_batch()
72+
train_loss = model.training_step(batch, 0)
73+
74+
assert train_loss.shape == ()
75+
train_loss.backward()
76+
77+
trainer = L.Trainer(
78+
max_epochs=1, logger=False, enable_checkpointing=False, limit_train_batches=1
79+
)
80+
trainer.fit(model, train_dataloaders=dummy_loader, val_dataloaders=dummy_loader)

0 commit comments

Comments
 (0)