Skip to content

Commit 7b64b30

Browse files
committed
Added dependents and dependent map functionality in tensormap
1 parent 2550929 commit 7b64b30

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

ml4h/ml4ht_integration/tensor_map.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,13 +71,23 @@ def __call__(self, path: str) -> Batch:
7171
with h5py.File(path, 'r') as hd5:
7272
in_batch = {}
7373
for tm in self.tensor_maps_in:
74+
if tm.tensor_from_file and tm.dependent_map:
75+
if isinstance(tm.dependent_map, list):
76+
dependents = {dep.name: dep for dep in tm.dependent_map}
77+
else:
78+
dependents = {tm.dependent_map.name: tm.dependent_map}
79+
7480
in_batch[tm.input_name()] = tm.postprocess_tensor(
7581
tm.tensor_from_file(tm, hd5, dependents),
7682
augment=self.augment, hd5=hd5,
7783
)
7884
out_batch = {}
7985
for tm in self.tensor_maps_out:
80-
# TODO: Check for dependents here
86+
if tm.tensor_from_file and tm.dependent_map:
87+
if isinstance(tm.dependent_map, list):
88+
dependents = {dep.name: dep for dep in tm.dependent_map}
89+
else:
90+
dependents = {tm.dependent_map.name: tm.dependent_map}
8191
out_batch[tm.output_name()] = tm.postprocess_tensor(
8292
tm.tensor_from_file(tm, hd5, dependents),
8393
augment=self.augment, hd5=hd5,

0 commit comments

Comments
 (0)