@@ -71,13 +71,23 @@ def __call__(self, path: str) -> Batch:
7171 with h5py .File (path , 'r' ) as hd5 :
7272 in_batch = {}
7373 for tm in self .tensor_maps_in :
74+ if tm .tensor_from_file and tm .dependent_map :
75+ if isinstance (tm .dependent_map , list ):
76+ dependents = {dep .name : dep for dep in tm .dependent_map }
77+ else :
78+ dependents = {tm .dependent_map .name : tm .dependent_map }
79+
7480 in_batch [tm .input_name ()] = tm .postprocess_tensor (
7581 tm .tensor_from_file (tm , hd5 , dependents ),
7682 augment = self .augment , hd5 = hd5 ,
7783 )
7884 out_batch = {}
7985 for tm in self .tensor_maps_out :
80- # TODO: Check for dependents here
86+ if tm .tensor_from_file and tm .dependent_map :
87+ if isinstance (tm .dependent_map , list ):
88+ dependents = {dep .name : dep for dep in tm .dependent_map }
89+ else :
90+ dependents = {tm .dependent_map .name : tm .dependent_map }
8191 out_batch [tm .output_name ()] = tm .postprocess_tensor (
8292 tm .tensor_from_file (tm , hd5 , dependents ),
8393 augment = self .augment , hd5 = hd5 ,
0 commit comments