11################################################################################
2- #
2+ #
33# This script is intended to compute the quanities of interest at a single visit
44# given the state data from the forward propagation.
5- #
5+ #
66# An example call to this script is:
77# mpirun -np <num_procs> python3 compute_qoi.py
88# --mesh /path/to/mesh/
99# --datafile /path/to/datafile/
1010# --refnii /path/to/reference.nii
1111# --roinii /path/to/roi.nii
1212# --outdir /path/to/store/results/
13- #
13+ #
1414# NOTE: Rasterization must be done in serial.
15- #
15+ #
1616################################################################################
1717
1818import os
3636from dt4co .utils .data_utils import nifti2Function , computeCarryingCapacity , rasterizeFunction , niftiPointwiseObservationOp
3737from dt4co .qoi import computeDice , computeTTV , computeVoxelCCC , computeTTC , computeVoxelDice
3838
39- def main (args )-> None :
39+
40+ def main (args ) -> None :
4041 ############################################################
4142 # 0. General setup.
4243 ############################################################
43- SEP = "\n " + "#" * 80 + "\n "
44-
44+ SEP = "\n " + "#" * 80 + "\n "
45+
4546 # Logging.
4647 dl .set_log_level (dl .LogLevel .WARNING )
47-
48+
4849 # MPI setup.
4950 COMM = MPI .COMM_WORLD
5051 rank = COMM .rank
5152 nproc = COMM .size
52-
53+
5354 # Paths for data.
5455 MESH_FPATH = args .mesh
5556 OUT_DIR = args .outdir
56-
57+
5758 SAMPLE_TYPE = args .sample_type
5859 THRESHOLD = args .threshold
59-
60+
6061 SAMPLES_FILE = args .samples
6162 NSAMPLES = args .nsamples
6263 VIDX = args .visit # what visit to compute QoIs for.
6364 VIDX = - 1 if VIDX is None else VIDX # if None, use the last visit.
64-
65+
6566 # make directory for QoI data (rasterized NIfTIs)
6667 QOI_DIR = os .path .join (OUT_DIR , f"qoi_{ SAMPLE_TYPE } " )
6768 os .makedirs (QOI_DIR , exist_ok = True )
68-
69+
6970 # Load in the patient data, get the information for the visit.
7071 pinfo = PatientData (args .pinfo , args .pdir )
7172 REF_NII = pinfo .get_visit (VIDX ).tumor
7273 ROI_NII = pinfo .get_visit (VIDX ).roi
73-
74+
7475 # Experiment setup.
7576 factory = ExperimentFactory (pinfo )
7677 exp = factory .get_experiment (args .experiment_type )
7778 root_print (COMM , f"Using experiment: { args .experiment_type } " )
7879 root_print (COMM , f"Experiment instance: { type (exp )} " )
79-
80+
8081 # which quantities of interest to compute.
8182 DO_DICE = args .dice
8283 DO_TTV = args .ttv
8384 DO_VOXQOI = args .vox
8485 DO_TTC = args .ttc
85-
86- # ------------------------------------------------------------
86+
87+ # ------------------------------------------------------------
8788 # Set up the experiment.
8889 # ------------------------------------------------------------
8990 root_print (COMM , SEP )
9091 root_print (COMM , f"Loading in the mesh..." )
9192 root_print (COMM , SEP )
92-
93+
9394 mesh , _ = exp .setupMesh (COMM , MESH_FPATH , zoff = False )
9495 # Set up variational spaces for state and parameter.
9596 Vh = exp .setupFunctionSpaces (mesh )
96-
97+
9798 # ------------------------------------------------------------
9899 # Read back states from file.
99100 # ------------------------------------------------------------
100-
101+
101102 root_print (COMM , f"Reading back data..." )
102103 root_print (COMM , SEP )
103-
104+
104105 helpfun = dl .Function (Vh [hp .STATE ])
105106 umv = hp .MultiVector (helpfun .vector (), NSAMPLES )
106107 start = time .perf_counter ()
107108 read_mv_from_h5 (COMM , umv , Vh [hp .STATE ], SAMPLES_FILE , name = "state" )
108109 ttime = time .perf_counter () - start
109110 root_print (COMM , f"Time to read data: { ttime :.2f} seconds." )
110-
111+
111112 # ------------------------------------------------------------
112113 # Compute QoIs.
113114 # ------------------------------------------------------------
114-
115+
115116 reffun = dl .Function (Vh [hp .STATE ])
116117 nifti2Function (REF_NII , reffun , Vh [hp .STATE ])
117-
118+
118119 if DO_DICE :
119120 root_print (COMM , "Computing Dice coefficient..." )
120121 start = time .perf_counter ()
121122 DICE = np .zeros ((NSAMPLES , 1 ))
122123 for i in range (NSAMPLES ):
123124 helpfun .vector ().zero ()
124- helpfun .vector ().axpy (1. , umv [i ])
125+ helpfun .vector ().axpy (1.0 , umv [i ])
125126 DICE [i ] = computeDice (helpfun , reffun , threshold = THRESHOLD )
126127 end = time .perf_counter ()
127128 root_print (COMM , f"Time to compute Dice coefficient for all data: { end - start :.2f} seconds." )
128-
129+
129130 if rank == 0 :
130131 np .save (os .path .join (OUT_DIR , f"{ SAMPLE_TYPE } _dice.npy" ), DICE )
131-
132+
132133 if DO_VOXQOI :
133134 start = time .perf_counter ()
134135 root_print (COMM , "Computing voxel-wise CCC and DICE..." )
135136 CCC = np .zeros ((NSAMPLES , 1 ))
136137 voxDICE = np .zeros ((NSAMPLES , 1 ))
137-
138+
138139 obsOp = niftiPointwiseObservationOp (REF_NII , Vh [hp .STATE ])
139-
140+
140141 for i in range (NSAMPLES ):
141142 RASTER_FILE = os .path .join (QOI_DIR , f"raster_{ i :06d} .nii" )
142143 helpfun .vector ().zero ()
143- helpfun .vector ().axpy (1. , umv [i ])
144-
144+ helpfun .vector ().axpy (1.0 , umv [i ])
145+
145146 # only rasterize if necessary.
146147 if not os .path .isfile (RASTER_FILE ):
147148 rasterizeFunction (helpfun , Vh [hp .STATE ], REF_NII , RASTER_FILE , obsOp = obsOp )
148-
149+
149150 if rank == 0 :
150151 # need to run these in serial so that the NIfTI read is not parallelized.
151152 # need to add to array when in parallel mode.
152153 CCC [i ] = computeVoxelCCC (RASTER_FILE , REF_NII , ROI_NII )
153154 voxDICE [i ] = computeVoxelDice (RASTER_FILE , REF_NII , threshold = THRESHOLD )
154-
155+
155156 end = time .perf_counter ()
156157 root_print (COMM , f"Time to compute CCC for all data: { end - start :.2f} seconds." )
157-
158+
158159 if rank == 0 :
159160 np .save (os .path .join (OUT_DIR , f"{ SAMPLE_TYPE } _ccc.npy" ), CCC )
160161 np .save (os .path .join (OUT_DIR , f"{ SAMPLE_TYPE } _voxdice.npy" ), voxDICE )
161-
162+
162163 if DO_TTC :
163164 root_print (COMM , "Computing TTC..." )
164165 start = time .perf_counter ()
@@ -167,62 +168,62 @@ def main(args)->None:
167168 TTC = np .zeros ((NSAMPLES , 1 ))
168169 for i in range (NSAMPLES ):
169170 helpfun .vector ().zero ()
170- helpfun .vector ().axpy (1. , umv [i ])
171+ helpfun .vector ().axpy (1.0 , umv [i ])
171172 TTC [i ] = computeTTC (helpfun , carry_cap , threshold = THRESHOLD )
172173 end = time .perf_counter ()
173174 root_print (COMM , f"Time to compute TTC for all data: { end - start :.2f} seconds." )
174-
175+
175176 if rank == 0 :
176177 np .save (os .path .join (OUT_DIR , f"{ SAMPLE_TYPE } _ttc.npy" ), TTC )
177178 np .save (os .path .join (OUT_DIR , f"ttc_true.npy" ), TTC_true )
178-
179+
179180 if DO_TTV :
180181 start = time .perf_counter ()
181182 root_print (COMM , "Computing normalized tumor volume..." )
182183 TTV = np .zeros ((NSAMPLES , 1 ))
183184 TTV_true = computeTTV (reffun , threshold = THRESHOLD ) # true value
184185 for i in range (NSAMPLES ):
185186 helpfun .vector ().zero ()
186- helpfun .vector ().axpy (1. , umv [i ])
187+ helpfun .vector ().axpy (1.0 , umv [i ])
187188 TTV [i ] = computeTTV (helpfun , threshold = THRESHOLD )
188189 end = time .perf_counter ()
189190 root_print (COMM , f"Time to compute TTV for all data: { end - start :.2f} seconds." )
190-
191+
191192 if rank == 0 :
192193 np .save (os .path .join (OUT_DIR , f"{ SAMPLE_TYPE } _ttv.npy" ), TTV )
193194 np .save (os .path .join (OUT_DIR , f"ttv_true.npy" ), TTV_true )
194195
195196
196197if __name__ == "__main__" :
197198 parser = argparse .ArgumentParser (description = "Compute quantities of interest for HGG data." )
198-
199+
199200 # data inputs.
200201 parser .add_argument ("--mesh" , type = str , required = True , help = "Path to the mesh file." )
201202 parser .add_argument ("--pdir" , type = str , help = "Path to the patient data directory." )
202203 parser .add_argument ("--pinfo" , type = str , help = "Path to the patient information file." )
203204 parser .add_argument ("--visit" , type = int , default = None , help = "Visit to compute QoIs for (use zero indexing)." )
204205 parser .add_argument ("--experiment_type" , type = str , required = True , choices = ["rd" , "rdtx" , "pwrdtx" ], help = "Type of experiment to run." )
205-
206+
206207 # propagation inputs.
207208 parser .add_argument ("--samples" , type = str , required = True , help = "Full path to the samples data file." )
208209 parser .add_argument ("--nsamples" , type = int , required = True , help = "Number of samples to use." )
209210 parser .add_argument ("--sample_type" , type = str , required = True , choices = ["prior" , "la_post" , "mcmc" , "map" , "prior_mean" ], help = "Type of samples file." )
210211 parser .add_argument ("--threshold" , type = float , required = True , help = "Threshold for observable tumor." )
211-
212+
212213 # Output options.
213214 parser .add_argument ("--outdir" , type = str , required = True , help = "Output directory." )
214-
215+
215216 # Input options.
216217 parser .add_argument ("--thresh" , type = float , default = 0.2 , help = "Threshold value for the state." )
217218 parser .add_argument ("--zoff" , type = float , default = None , help = "Z-offset for 2D meshes." )
218-
219+
219220 # Which quantities of interest to compute.
220221 parser .add_argument ("--dice" , action = argparse .BooleanOptionalAction , default = True , help = "Compute the Dice coefficient." )
221222 parser .add_argument ("--ttv" , action = argparse .BooleanOptionalAction , default = True , help = "Compute the total tumor volume." )
222223 parser .add_argument ("--vox" , action = argparse .BooleanOptionalAction , default = True , help = "Compute the voxel-wise QoIs." )
223224 parser .add_argument ("--ttc" , action = argparse .BooleanOptionalAction , default = True , help = "Compute the time to threshold." )
224-
225+
225226 # Parse the arguments, strip CLI args for PETSc.
226227 args , other = parser .parse_known_args ()
227-
228+
228229 main (args )
0 commit comments