Skip to content

Commit 3919c0b

Browse files
sgreenburyradka-jmarjanfamilicispragueContiPaolo
committed
Add test_identity
Co-authored-by: Radka Jersakova <r.jersakova@gmail.com> Co-authored-by: Marjan Famili <marjanfamili@users.noreply.github.com> Co-authored-by: Christopher Iliffe Sprague <cisprague@users.noreply.github.com> Co-authored-by: Paolo Conti <ContiPaolo@users.noreply.github.com>
1 parent b2ef15a commit 3919c0b

3 files changed

Lines changed: 35 additions & 23 deletions

File tree

tests/conftest.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,28 @@
99
from auto_cast.types import Batch, EncodedBatch
1010

1111

12+
def _make_batch(
13+
batch_size: int = 1,
14+
t: int = 1,
15+
w: int = 2,
16+
h: int = 3,
17+
c: int = 2,
18+
const_c: int = 1,
19+
scalar_c: int = 1,
20+
) -> Batch:
21+
input_fields = torch.arange(batch_size * t * w * h * c, dtype=torch.float32)
22+
input_fields = input_fields.view(batch_size, t, w, h, c)
23+
output_fields = torch.zeros(batch_size, t, w, h, c)
24+
constant_fields = torch.ones(batch_size, w, h, const_c)
25+
constant_scalars = torch.full((batch_size, scalar_c), 5.0)
26+
return Batch(
27+
input_fields=input_fields,
28+
output_fields=output_fields,
29+
constant_scalars=constant_scalars,
30+
constant_fields=constant_fields,
31+
)
32+
33+
1234
def assert_output_valid(output: Tensor, expected_shape: tuple, name: str = "Output"):
1335
"""Assert output has expected shape and contains no NaN values."""
1436
assert output.shape == expected_shape, (

tests/encoders/test_identity.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import torch
2+
from conftest import _make_batch
3+
4+
from auto_cast.encoders.identity import IdentityEncoder
5+
6+
7+
def test_identity():
8+
batch = _make_batch()
9+
encoder = IdentityEncoder()
10+
encoded_batch = encoder.encode_batch(batch)
11+
assert torch.allclose(encoded_batch.encoded_inputs, batch.input_fields)
12+
assert torch.allclose(encoded_batch.encoded_output_fields, batch.output_fields)

tests/encoders/test_permute_concat.py

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,8 @@
11
import torch
2+
from conftest import _make_batch
23
from einops import rearrange
34

45
from auto_cast.encoders.permute_concat import PermuteConcat
5-
from auto_cast.types import Batch
6-
7-
8-
def _make_batch(
9-
batch_size: int = 1,
10-
t: int = 1,
11-
w: int = 2,
12-
h: int = 3,
13-
c: int = 2,
14-
const_c: int = 1,
15-
scalar_c: int = 1,
16-
) -> Batch:
17-
input_fields = torch.arange(batch_size * t * w * h * c, dtype=torch.float32)
18-
input_fields = input_fields.view(batch_size, t, w, h, c)
19-
output_fields = torch.zeros(batch_size, t, w, h, c)
20-
constant_fields = torch.ones(batch_size, w, h, const_c)
21-
constant_scalars = torch.full((batch_size, scalar_c), 5.0)
22-
return Batch(
23-
input_fields=input_fields,
24-
output_fields=output_fields,
25-
constant_scalars=constant_scalars,
26-
constant_fields=constant_fields,
27-
)
286

297

308
def test_permute_concat_with_constants():

0 commit comments

Comments
 (0)