Skip to content

Commit 3dbe9e4

Browse files
mjanuszcopybara-github
authored andcommitted
Add more unit tests.
PiperOrigin-RevId: 874047196
1 parent a6d118f commit 3dbe9e4

File tree

5 files changed

+148
-4
lines changed

5 files changed

+148
-4
lines changed

ffn/input/load_data_test.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Copyright 2024 Google Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
16+
from absl.testing import absltest
17+
18+
19+
if __name__ == '__main__':
20+
absltest.main()

ffn/input/volume.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,14 @@
1818
import copy
1919
import dataclasses
2020
import functools as ft
21+
import hashlib
2122
from typing import Any, Callable, Sequence, TypeVar
2223

2324
from absl import logging
2425
import array_record
2526
from connectomics.common import array
2627
from connectomics.common import bounding_box
28+
from connectomics.common import box_generator
2729
from connectomics.common import io_utils
2830
from ffn.input import segmentation
2931
from ffn.training import augmentation
@@ -193,11 +195,14 @@ def _load_data(
193195
ret = dict(ex)
194196
dtype_remap = {tf.uint64: tf.int64}
195197

198+
ret['coord'] = tf.reshape(ex['coord'], [1, 3])
199+
ret['volname'] = tf.reshape(ex['volname'], [1])
200+
196201
for name, vol in config.volumes.items():
197202
if vol.oob_mask:
198203
ret[name] = inputs.make_oob_mask(
199-
ex['coord'],
200-
ex['volname'],
204+
ret['coord'],
205+
ret['volname'],
201206
shape=vol.load_shape,
202207
volinfo_map_string=get_path_str(vol.paths),
203208
)
@@ -312,7 +317,7 @@ def _filter_coordinates_by_bbox(
312317
) -> tf.Tensor:
313318
ret = tf.numpy_function(
314319
lambda c, v: _coord_in_bboxes_np(c, v, bboxes),
315-
[item['coord'][0], item['volname'][0]],
320+
[item['coord'][0], tf.reshape(item['volname'], [-1])[0]],
316321
tf.bool,
317322
)
318323
ret.set_shape([])

ffn/input/volume_test.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Copyright 2024 Google Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
16+
from absl.testing import absltest
17+
18+
19+
if __name__ == '__main__':
20+
absltest.main()

ffn/training/inputs.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -785,8 +785,10 @@ def filter_oob(
785785
use_bboxes: bool = True
786786
) -> tf.Tensor:
787787
radius = np.floor_divide(patch_size, 2)
788+
coord = tf.reshape(item['coord'], [1, 3])
789+
volname = tf.reshape(item['volname'], [1])
788790
return coordinates_in_bounds(
789-
item['coord'], item['volname'], radius, volinfo_map_string,
791+
coord, volname, radius, volinfo_map_string,
790792
use_bboxes
791793
)
792794

ffn/training/inputs_test.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# Copyright 2026 Google Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
16+
import functools as ft
17+
import os
18+
19+
20+
from absl.testing import absltest
21+
from connectomics.common import bounding_box
22+
from connectomics.common import tuples
23+
from connectomics.volume import metadata
24+
from ffn.training import inputs
25+
import numpy as np
26+
import tensorflow.compat.v1 as tf
27+
28+
tf.disable_eager_execution()
29+
30+
31+
class _FilterOobTestBase:
32+
33+
_volinfo_map_string: str
34+
35+
def test_filter_oob_in_bounds(self):
36+
coord = tf.constant([[50, 50, 50]], dtype=tf.int64)
37+
volname = tf.constant(['testvol'], dtype=tf.string)
38+
item = {'coord': coord, 'volname': volname}
39+
result = inputs.filter_oob(
40+
item, self._volinfo_map_string, patch_size=[10, 10, 10]
41+
)
42+
with tf.Session() as sess:
43+
self.assertTrue(sess.run(result))
44+
45+
def test_filter_oob_out_of_bounds(self):
46+
coord = tf.constant([[0, 0, 0]], dtype=tf.int64)
47+
volname = tf.constant(['testvol'], dtype=tf.string)
48+
item = {'coord': coord, 'volname': volname}
49+
result = inputs.filter_oob(
50+
item, self._volinfo_map_string, patch_size=[10, 10, 10]
51+
)
52+
with tf.Session() as sess:
53+
self.assertFalse(sess.run(result))
54+
55+
def test_filter_oob_in_dataset_filter(self):
56+
ds = tf.data.Dataset.from_tensors({
57+
'coord': tf.constant([[50, 50, 50]], dtype=tf.int64),
58+
'volname': tf.constant(['testvol'], dtype=tf.string),
59+
})
60+
ds = ds.filter(
61+
ft.partial(
62+
inputs.filter_oob,
63+
volinfo_map_string=self._volinfo_map_string,
64+
patch_size=[10, 10, 10],
65+
)
66+
)
67+
iterator = tf.data.make_one_shot_iterator(ds)
68+
item = iterator.get_next()
69+
with tf.Session() as sess:
70+
result = sess.run(item)
71+
np.testing.assert_array_equal(result['coord'], [[50, 50, 50]])
72+
73+
74+
class FilterOobMetadataJsonTest(_FilterOobTestBase, absltest.TestCase):
75+
76+
def setUp(self):
77+
super().setUp()
78+
tf.reset_default_graph()
79+
80+
self._tmpdir = self.create_tempdir().full_path
81+
meta = metadata.VolumeMetadata(
82+
path='none',
83+
volume_size=tuples.XYZ(100, 100, 100),
84+
pixel_size=tuples.XYZ(8, 8, 30),
85+
bounding_boxes=[
86+
bounding_box.BoundingBox(start=(0, 0, 0), size=(100, 100, 100))
87+
],
88+
)
89+
self._metadata_path = os.path.join(self._tmpdir, 'metadata.json')
90+
with open(self._metadata_path, 'w') as f:
91+
f.write(meta.to_json())
92+
93+
self._volinfo_map_string = f'testvol:{self._metadata_path}'
94+
95+
96+
if __name__ == '__main__':
97+
absltest.main()

0 commit comments

Comments
 (0)