Skip to content

Commit 6338ea4

Browse files
rijobro and wyli authored
Permanent dataset hashes transforms (#4633)
* permanent dataset hashes transforms Signed-off-by: Richard Brown <[email protected]> * fixes Signed-off-by: Richard Brown <[email protected]> * documentation Signed-off-by: Richard Brown <[email protected]> * default to no hash Signed-off-by: Wenqi Li <[email protected]> Co-authored-by: Wenqi Li <[email protected]>
1 parent 398466c commit 6338ea4

File tree

2 files changed

+47
-5
lines changed

2 files changed

+47
-5
lines changed

Diff for: monai/data/dataset.py

+32-3
Original file line numberDiff line numberDiff line change
@@ -191,9 +191,10 @@ class PersistentDataset(Dataset):
191191
Note:
192192
The input data must be a list of file paths and will hash them as cache keys.
193193
194-
When loading persistent cache content, it can't guarantee the cached data matches current
195-
transform chain, so please make sure to use exactly the same non-random transforms and the
196-
args as the cache content, otherwise, it may cause unexpected errors.
194+
The filenames of the cached files also try to contain the hash of the transforms. In this
195+
fashion, `PersistentDataset` should be robust to changes in transforms. This, however, is
196+
not guaranteed, so caution should be used when modifying transforms to avoid unexpected
197+
errors. If in doubt, it is advisable to clear the cache directory.
197198
198199
"""
199200

@@ -205,6 +206,7 @@ def __init__(
205206
hash_func: Callable[..., bytes] = pickle_hashing,
206207
pickle_module: str = "pickle",
207208
pickle_protocol: int = DEFAULT_PROTOCOL,
209+
hash_transform: Optional[Callable[..., bytes]] = None,
208210
) -> None:
209211
"""
210212
Args:
@@ -232,6 +234,9 @@ def __init__(
232234
pickle_protocol: can be specified to override the default protocol, default to `2`.
233235
this arg is used by `torch.save`, for more details, please check:
234236
https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
237+
hash_transform: a callable to compute hash from the transform information when caching.
238+
This may reduce errors due to transforms changing during experiments. Default to None (no hash).
239+
Other options are `pickle_hashing` and `json_hashing` functions from `monai.data.utils`.
235240
236241
"""
237242
if not isinstance(transform, Compose):
@@ -246,6 +251,29 @@ def __init__(
246251
self.cache_dir.mkdir(parents=True, exist_ok=True)
247252
if not self.cache_dir.is_dir():
248253
raise ValueError("cache_dir must be a directory.")
254+
self.transform_hash = ""
255+
if hash_transform is not None:
256+
self.set_transform_hash(hash_transform)
257+
258+
def set_transform_hash(self, hash_xform_func):
259+
"""Get hashable transforms, and then hash them. Hashable transforms
260+
are deterministic transforms that inherit from `Transform`. We stop
261+
at the first non-deterministic transform, or first that does not
262+
inherit from MONAI's `Transform` class."""
263+
hashable_transforms = []
264+
for _tr in self.transform.flatten().transforms:
265+
if isinstance(_tr, Randomizable) or not isinstance(_tr, Transform):
266+
break
267+
hashable_transforms.append(_tr)
268+
# Try to hash. Fall back to a hash of their names
269+
try:
270+
self.transform_hash = hash_xform_func(hashable_transforms)
271+
except TypeError as te:
272+
if "is not JSON serializable" not in str(te):
273+
raise te
274+
names = "".join(tr.__class__.__name__ for tr in hashable_transforms)
275+
self.transform_hash = hash_xform_func(names)
276+
self.transform_hash = self.transform_hash.decode("utf-8")
249277

250278
def set_data(self, data: Sequence):
251279
"""
@@ -325,6 +353,7 @@ def _cachecheck(self, item_transformed):
325353
hashfile = None
326354
if self.cache_dir is not None:
327355
data_item_md5 = self.hash_func(item_transformed).decode("utf-8")
356+
data_item_md5 += self.transform_hash
328357
hashfile = self.cache_dir / f"{data_item_md5}.pt"
329358

330359
if hashfile is not None and hashfile.is_file(): # cache hit

Diff for: tests/test_persistentdataset.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from parameterized import parameterized
2020

2121
from monai.data import PersistentDataset, json_hashing
22-
from monai.transforms import Compose, LoadImaged, SimulateDelayd, Transform
22+
from monai.transforms import Compose, Flip, Identity, LoadImaged, SimulateDelayd, Transform
2323

2424
TEST_CASE_1 = [
2525
Compose(
@@ -77,7 +77,7 @@ def test_cache(self):
7777

7878
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
7979
def test_shape(self, transform, expected_shape):
80-
test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]), np.eye(4))
80+
test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4))
8181
with tempfile.TemporaryDirectory() as tempdir:
8282
nib.save(test_image, os.path.join(tempdir, "test_image1.nii.gz"))
8383
nib.save(test_image, os.path.join(tempdir, "test_label1.nii.gz"))
@@ -150,6 +150,19 @@ def test_shape(self, transform, expected_shape):
150150
self.assertEqual(dataset_postcached[0]["label"], os.path.join(tempdir, "test_label1_new.nii.gz"))
151151
self.assertEqual(dataset_postcached[1]["extra"], os.path.join(tempdir, "test_extra2_new.nii.gz"))
152152

153+
def test_different_transforms(self):
154+
"""
155+
Different instances of `PersistentDataset` with the same cache_dir,
156+
same input data, but different transforms should give different results.
157+
"""
158+
shape = (1, 10, 9, 8)
159+
im = np.arange(0, np.prod(shape)).reshape(shape)
160+
with tempfile.TemporaryDirectory() as path:
161+
im1 = PersistentDataset([im], Identity(), cache_dir=path, hash_transform=json_hashing)[0]
162+
im2 = PersistentDataset([im], Flip(1), cache_dir=path, hash_transform=json_hashing)[0]
163+
l2 = ((im1 - im2) ** 2).sum() ** 0.5
164+
self.assertTrue(l2 > 1)
165+
153166

154167
if __name__ == "__main__":
155168
unittest.main()

0 commit comments

Comments (0)