Skip to content

Commit 6bb252b

Browse files
committed
add license
1 parent d3d341b commit 6bb252b

File tree

1 file changed

+16
-5
lines changed

1 file changed

+16
-5
lines changed

ml4h/recipes.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -754,10 +754,21 @@ def infer_multimodal_multitask(args):
754754
next(generate_test) # this prints end of epoch info
755755
logging.info(f"Inference on {stats['count']} tensors finished. Inference TSV file at: {inference_tsv}")
756756
break
757-
prediction = model.predict(input_data, verbose=0)
758-
if len(no_fail_tmaps_out) == 1:
759-
prediction = [prediction]
760-
predictions_dict = {name: pred for name, pred in zip(model.output_names, prediction)}
757+
predictions = model.predict(input_data, verbose=0)
758+
if isinstance(predictions, dict):
759+
predictions_dict = predictions
760+
761+
elif isinstance(predictions, (list, tuple)):
762+
# Map outputs by model.output_names (strings)
763+
predictions_dict = {
764+
str(name): pred for name, pred in zip(model.output_names, predictions)
765+
}
766+
else:
767+
# Single tensor output
768+
predictions_dict = {
769+
model.output_names[0]: predictions
770+
}
771+
761772
sample_id = os.path.basename(tensor_paths[0]).replace(TENSOR_EXT, '')
762773
csv_row = [sample_id]
763774
if tsv_style_is_genetics:
@@ -779,7 +790,7 @@ def infer_multimodal_multitask(args):
779790
actual = output_data[otm.output_name()][0][i]
780791
csv_row.append("NA" if np.isnan(actual) else str(actual))
781792
except (IndexError, KeyError):
782-
logging.warning(f'Error in infer at {otm.name} item {i} key {k} with cm: {otm.channel_map} y is {y.shape} y is {y}')
793+
logging.warning(f'index error at {otm.name} item {i} key {k} with cm: {otm.channel_map} y is {y.shape} y is {y}')
783794
elif otm.is_survival_curve():
784795
intervals = otm.shape[-1] // 2
785796
days_per_bin = 1 + otm.days_window // intervals

0 commit comments

Comments
 (0)