@@ -754,10 +754,21 @@ def infer_multimodal_multitask(args):
754754 next (generate_test ) # this prints end of epoch info
755755 logging .info (f"Inference on { stats ['count' ]} tensors finished. Inference TSV file at: { inference_tsv } " )
756756 break
757- prediction = model .predict (input_data , verbose = 0 )
758- if len (no_fail_tmaps_out ) == 1 :
759- prediction = [prediction ]
760- predictions_dict = {name : pred for name , pred in zip (model .output_names , prediction )}
757+ predictions = model .predict (input_data , verbose = 0 )
758+ if isinstance (predictions , dict ):
759+ predictions_dict = predictions
760+
761+ elif isinstance (predictions , (list , tuple )):
762+ # Map outputs by model.output_names (strings)
763+ predictions_dict = {
764+ str (name ): pred for name , pred in zip (model .output_names , predictions )
765+ }
766+ else :
767+ # Single tensor output
768+ predictions_dict = {
769+ model .output_names [0 ]: predictions
770+ }
771+
761772 sample_id = os .path .basename (tensor_paths [0 ]).replace (TENSOR_EXT , '' )
762773 csv_row = [sample_id ]
763774 if tsv_style_is_genetics :
@@ -779,7 +790,7 @@ def infer_multimodal_multitask(args):
779790 actual = output_data [otm .output_name ()][0 ][i ]
780791 csv_row .append ("NA" if np .isnan (actual ) else str (actual ))
781792 except (IndexError , KeyError ):
782- logging .warning (f'Error in infer at { otm .name } item { i } key { k } with cm: { otm .channel_map } y is { y .shape } y is { y } ' )
793+ logging .warning (f'index error at { otm .name } item { i } key { k } with cm: { otm .channel_map } y is { y .shape } y is { y } ' )
783794 elif otm .is_survival_curve ():
784795 intervals = otm .shape [- 1 ] // 2
785796 days_per_bin = 1 + otm .days_window // intervals
0 commit comments