diff --git a/test.py b/test.py index ab914c8..9ba09cc 100644 --- a/test.py +++ b/test.py @@ -8,6 +8,7 @@ """ import argparse +import csv import numpy as np import nibabel as nib import os @@ -188,6 +189,11 @@ def main(): else: # Fetch existing models in the current directory. The models are assumed to be named "best_metric_model*.pth" path_models = [f for f in os.listdir('.') if os.path.isfile(f) and f.startswith('best_metric_model')] + + # Check that if -u is specified, there is more than one model state file + if args.uncertainty and len(path_models) == 1: + raise ValueError("The -u flag is only valid if there is more than one model state file.") + # Load the trained 2D U-Net models models = [] for path_model in path_models: @@ -218,7 +224,9 @@ def main(): with torch.no_grad(): segmented_slices = [] segmented_slices_std = [] - for data in tqdm(dataloader, desc=f"Segment image", unit="image"): + prediction_std_list = [] + + for index, data in enumerate(tqdm(dataloader, desc="Segment image", unit="image")): image = data["image"].to(device) # TODO: parametrize values below roi_size = (192, 192) @@ -248,6 +256,11 @@ def main(): # segmented_slice_ensemble = np.mean(segmented_slice_ensemble_all, axis=0) if args.uncertainty: segmented_slice_ensemble_std = np.std(segmented_slice_ensemble_all, axis=0) + # Calculate the average across non-null values + non_null_values = np.where(np.isnan(segmented_slice_ensemble_all), 0, segmented_slice_ensemble_all) + prediction_std = np.mean(non_null_values) + # print(f"Slice #{index}, Prediction STD: {prediction_std}") + prediction_std_list.append([index, prediction_std]) # Add the segmented slice to the list of segmented slices segmented_slices.append(segmented_slice_ensemble_all_aggregated) if args.uncertainty: @@ -276,6 +289,13 @@ def main(): print(f"Done! Output file: {fname_out}") + # Write the data to the CSV file + output_file = 'prediction_std.csv' + with open(output_file, 'w', newline='') as csvfile: + writer = csv.writer(csvfile) + writer.writerow(['Index', 'Prediction STD']) # Write the header + writer.writerows(prediction_std_list) # Write the data rows + if __name__ == "__main__": main()