Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"""

import argparse
import csv
import numpy as np
import nibabel as nib
import os
Expand Down Expand Up @@ -188,6 +189,11 @@ def main():
else:
# Fetch existing models in the current directory. The models are assumed to be named "best_metric_model*.pth"
path_models = [f for f in os.listdir('.') if os.path.isfile(f) and f.startswith('best_metric_model')]

# Check that if -u is specified, there is more than one model state file
if args.uncertainty and len(path_models) == 1:
raise ValueError("The -u flag is only valid if there is more than one model state file.")

# Load the trained 2D U-Net models
models = []
for path_model in path_models:
Expand Down Expand Up @@ -218,7 +224,9 @@ def main():
with torch.no_grad():
segmented_slices = []
segmented_slices_std = []
for data in tqdm(dataloader, desc=f"Segment image", unit="image"):
prediction_std_list = []

for index, data in enumerate(tqdm(dataloader, desc="Segment image", unit="image")):
image = data["image"].to(device)
# TODO: parametrize values below
roi_size = (192, 192)
Expand Down Expand Up @@ -248,6 +256,11 @@ def main():
# segmented_slice_ensemble = np.mean(segmented_slice_ensemble_all, axis=0)
if args.uncertainty:
segmented_slice_ensemble_std = np.std(segmented_slice_ensemble_all, axis=0)
# Calculate the average across non-null values
non_null_values = np.where(np.isnan(segmented_slice_ensemble_all), 0, segmented_slice_ensemble_all)
prediction_std = np.mean(non_null_values)
# print(f"Slice #{index}, Prediction STD: {prediction_std}")
prediction_std_list.append([index, prediction_std])
# Add the segmented slice to the list of segmented slices
segmented_slices.append(segmented_slice_ensemble_all_aggregated)
if args.uncertainty:
Expand Down Expand Up @@ -276,6 +289,13 @@ def main():

print(f"Done! Output file: {fname_out}")

# Write the data to the CSV file
output_file = 'prediction_std.csv'
with open(output_file, 'w', newline='') as csvfile:
writer = csv.writer(csvfile)
writer.writerow(['Index', 'Prediction STD']) # Write the header
writer.writerows(prediction_std_list) # Write the data rows


if __name__ == "__main__":
main()
Expand Down