@@ -191,9 +191,10 @@ class PersistentDataset(Dataset):
191
191
Note:
192
192
The input data must be a list of file paths and will hash them as cache keys.
193
193
194
- When loading persistent cache content, it can't guarantee the cached data matches current
195
- transform chain, so please make sure to use exactly the same non-random transforms and the
196
- args as the cache content, otherwise, it may cause unexpected errors.
194
+ The filenames of the cached files also try to contain the hash of the transforms. In this
195
+ fashion, `PersistentDataset` should be robust to changes in transforms. This, however, is
196
+ not guaranteed, so caution should be used when modifying transforms to avoid unexpected
197
+ errors. If in doubt, it is advisable to clear the cache directory.
197
198
198
199
"""
199
200
@@ -205,6 +206,7 @@ def __init__(
205
206
hash_func : Callable [..., bytes ] = pickle_hashing ,
206
207
pickle_module : str = "pickle" ,
207
208
pickle_protocol : int = DEFAULT_PROTOCOL ,
209
+ hash_transform : Optional [Callable [..., bytes ]] = None ,
208
210
) -> None :
209
211
"""
210
212
Args:
@@ -232,6 +234,9 @@ def __init__(
232
234
pickle_protocol: can be specified to override the default protocol, default to `2`.
233
235
this arg is used by `torch.save`, for more details, please check:
234
236
https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
237
+ hash_transform: a callable to compute hash from the transform information when caching.
238
+ This may reduce errors due to transforms changing during experiments. Default to None (no hash).
239
+ Other options are `pickle_hashing` and `json_hashing` functions from `monai.data.utils`.
235
240
236
241
"""
237
242
if not isinstance (transform , Compose ):
@@ -246,6 +251,29 @@ def __init__(
246
251
self .cache_dir .mkdir (parents = True , exist_ok = True )
247
252
if not self .cache_dir .is_dir ():
248
253
raise ValueError ("cache_dir must be a directory." )
254
+ self .transform_hash = ""
255
+ if hash_transform is not None :
256
+ self .set_transform_hash (hash_transform )
257
+
258
+ def set_transform_hash (self , hash_xform_func ):
259
+ """Get hashable transforms, and then hash them. Hashable transforms
260
+ are deterministic transforms that inherit from `Transform`. We stop
261
+ at the first non-deterministic transform, or first that does not
262
+ inherit from MONAI's `Transform` class."""
263
+ hashable_transforms = []
264
+ for _tr in self .transform .flatten ().transforms :
265
+ if isinstance (_tr , Randomizable ) or not isinstance (_tr , Transform ):
266
+ break
267
+ hashable_transforms .append (_tr )
268
+ # Try to hash. Fall back to a hash of their names
269
+ try :
270
+ self .transform_hash = hash_xform_func (hashable_transforms )
271
+ except TypeError as te :
272
+ if "is not JSON serializable" not in str (te ):
273
+ raise te
274
+ names = "" .join (tr .__class__ .__name__ for tr in hashable_transforms )
275
+ self .transform_hash = hash_xform_func (names )
276
+ self .transform_hash = self .transform_hash .decode ("utf-8" )
249
277
250
278
def set_data (self , data : Sequence ):
251
279
"""
@@ -325,6 +353,7 @@ def _cachecheck(self, item_transformed):
325
353
hashfile = None
326
354
if self .cache_dir is not None :
327
355
data_item_md5 = self .hash_func (item_transformed ).decode ("utf-8" )
356
+ data_item_md5 += self .transform_hash
328
357
hashfile = self .cache_dir / f"{ data_item_md5 } .pt"
329
358
330
359
if hashfile is not None and hashfile .is_file (): # cache hit
0 commit comments