Skip to content

Commit ef5a330

Browse files
authored
Add formatting (#2)
1 parent 79a2891 commit ef5a330

55 files changed

Lines changed: 2742 additions & 2740 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.github/workflows/black.yml

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
name: Black Code Style Check
2+
3+
on:
4+
push:
5+
branches: [main]
6+
pull_request:
7+
branches: [main]
8+
9+
jobs:
10+
black:
11+
runs-on: ubuntu-latest
12+
steps:
13+
- name: Checkout code
14+
uses: actions/checkout@v4
15+
16+
- name: Set up Python
17+
uses: actions/setup-python@v5
18+
with:
19+
python-version: '3.12'
20+
21+
- name: Install black
22+
run: pip install black
23+
24+
- name: Run black style check
25+
run: black --check .
26+

gbm/compute_qoi.py

Lines changed: 48 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
11
################################################################################
2-
#
2+
#
33
# This script is intended to compute the quanities of interest at a single visit
44
# given the state data from the forward propagation.
5-
#
5+
#
66
# An example call to this script is:
77
# mpirun -np <num_procs> python3 compute_qoi.py
88
# --mesh /path/to/mesh/
99
# --datafile /path/to/datafile/
1010
# --refnii /path/to/reference.nii
1111
# --roinii /path/to/roi.nii
1212
# --outdir /path/to/store/results/
13-
#
13+
#
1414
# NOTE: Rasterization must be done in serial.
15-
#
15+
#
1616
################################################################################
1717

1818
import os
@@ -36,129 +36,130 @@
3636
from dt4co.utils.data_utils import nifti2Function, computeCarryingCapacity, rasterizeFunction, niftiPointwiseObservationOp
3737
from dt4co.qoi import computeDice, computeTTV, computeVoxelCCC, computeTTC, computeVoxelDice
3838

39-
def main(args)->None:
39+
40+
def main(args) -> None:
4041
############################################################
4142
# 0. General setup.
4243
############################################################
43-
SEP = "\n"+"#"*80+"\n"
44-
44+
SEP = "\n" + "#" * 80 + "\n"
45+
4546
# Logging.
4647
dl.set_log_level(dl.LogLevel.WARNING)
47-
48+
4849
# MPI setup.
4950
COMM = MPI.COMM_WORLD
5051
rank = COMM.rank
5152
nproc = COMM.size
52-
53+
5354
# Paths for data.
5455
MESH_FPATH = args.mesh
5556
OUT_DIR = args.outdir
56-
57+
5758
SAMPLE_TYPE = args.sample_type
5859
THRESHOLD = args.threshold
59-
60+
6061
SAMPLES_FILE = args.samples
6162
NSAMPLES = args.nsamples
6263
VIDX = args.visit # what visit to compute QoIs for.
6364
VIDX = -1 if VIDX is None else VIDX # if None, use the last visit.
64-
65+
6566
# make directory for QoI data (rasterized NIfTIs)
6667
QOI_DIR = os.path.join(OUT_DIR, f"qoi_{SAMPLE_TYPE}")
6768
os.makedirs(QOI_DIR, exist_ok=True)
68-
69+
6970
# Load in the patient data, get the information for the visit.
7071
pinfo = PatientData(args.pinfo, args.pdir)
7172
REF_NII = pinfo.get_visit(VIDX).tumor
7273
ROI_NII = pinfo.get_visit(VIDX).roi
73-
74+
7475
# Experiment setup.
7576
factory = ExperimentFactory(pinfo)
7677
exp = factory.get_experiment(args.experiment_type)
7778
root_print(COMM, f"Using experiment: {args.experiment_type}")
7879
root_print(COMM, f"Experiment instance: {type(exp)}")
79-
80+
8081
# which quantities of interest to compute.
8182
DO_DICE = args.dice
8283
DO_TTV = args.ttv
8384
DO_VOXQOI = args.vox
8485
DO_TTC = args.ttc
85-
86-
# ------------------------------------------------------------
86+
87+
# ------------------------------------------------------------
8788
# Set up the experiment.
8889
# ------------------------------------------------------------
8990
root_print(COMM, SEP)
9091
root_print(COMM, f"Loading in the mesh...")
9192
root_print(COMM, SEP)
92-
93+
9394
mesh, _ = exp.setupMesh(COMM, MESH_FPATH, zoff=False)
9495
# Set up variational spaces for state and parameter.
9596
Vh = exp.setupFunctionSpaces(mesh)
96-
97+
9798
# ------------------------------------------------------------
9899
# Read back states from file.
99100
# ------------------------------------------------------------
100-
101+
101102
root_print(COMM, f"Reading back data...")
102103
root_print(COMM, SEP)
103-
104+
104105
helpfun = dl.Function(Vh[hp.STATE])
105106
umv = hp.MultiVector(helpfun.vector(), NSAMPLES)
106107
start = time.perf_counter()
107108
read_mv_from_h5(COMM, umv, Vh[hp.STATE], SAMPLES_FILE, name="state")
108109
ttime = time.perf_counter() - start
109110
root_print(COMM, f"Time to read data: {ttime:.2f} seconds.")
110-
111+
111112
# ------------------------------------------------------------
112113
# Compute QoIs.
113114
# ------------------------------------------------------------
114-
115+
115116
reffun = dl.Function(Vh[hp.STATE])
116117
nifti2Function(REF_NII, reffun, Vh[hp.STATE])
117-
118+
118119
if DO_DICE:
119120
root_print(COMM, "Computing Dice coefficient...")
120121
start = time.perf_counter()
121122
DICE = np.zeros((NSAMPLES, 1))
122123
for i in range(NSAMPLES):
123124
helpfun.vector().zero()
124-
helpfun.vector().axpy(1., umv[i])
125+
helpfun.vector().axpy(1.0, umv[i])
125126
DICE[i] = computeDice(helpfun, reffun, threshold=THRESHOLD)
126127
end = time.perf_counter()
127128
root_print(COMM, f"Time to compute Dice coefficient for all data: {end-start:.2f} seconds.")
128-
129+
129130
if rank == 0:
130131
np.save(os.path.join(OUT_DIR, f"{SAMPLE_TYPE}_dice.npy"), DICE)
131-
132+
132133
if DO_VOXQOI:
133134
start = time.perf_counter()
134135
root_print(COMM, "Computing voxel-wise CCC and DICE...")
135136
CCC = np.zeros((NSAMPLES, 1))
136137
voxDICE = np.zeros((NSAMPLES, 1))
137-
138+
138139
obsOp = niftiPointwiseObservationOp(REF_NII, Vh[hp.STATE])
139-
140+
140141
for i in range(NSAMPLES):
141142
RASTER_FILE = os.path.join(QOI_DIR, f"raster_{i:06d}.nii")
142143
helpfun.vector().zero()
143-
helpfun.vector().axpy(1., umv[i])
144-
144+
helpfun.vector().axpy(1.0, umv[i])
145+
145146
# only rasterize if necessary.
146147
if not os.path.isfile(RASTER_FILE):
147148
rasterizeFunction(helpfun, Vh[hp.STATE], REF_NII, RASTER_FILE, obsOp=obsOp)
148-
149+
149150
if rank == 0:
150151
# need to run these in serial so that the NIfTI read is not parallelized.
151152
# need to add to array when in parallel mode.
152153
CCC[i] = computeVoxelCCC(RASTER_FILE, REF_NII, ROI_NII)
153154
voxDICE[i] = computeVoxelDice(RASTER_FILE, REF_NII, threshold=THRESHOLD)
154-
155+
155156
end = time.perf_counter()
156157
root_print(COMM, f"Time to compute CCC for all data: {end-start:.2f} seconds.")
157-
158+
158159
if rank == 0:
159160
np.save(os.path.join(OUT_DIR, f"{SAMPLE_TYPE}_ccc.npy"), CCC)
160161
np.save(os.path.join(OUT_DIR, f"{SAMPLE_TYPE}_voxdice.npy"), voxDICE)
161-
162+
162163
if DO_TTC:
163164
root_print(COMM, "Computing TTC...")
164165
start = time.perf_counter()
@@ -167,62 +168,62 @@ def main(args)->None:
167168
TTC = np.zeros((NSAMPLES, 1))
168169
for i in range(NSAMPLES):
169170
helpfun.vector().zero()
170-
helpfun.vector().axpy(1., umv[i])
171+
helpfun.vector().axpy(1.0, umv[i])
171172
TTC[i] = computeTTC(helpfun, carry_cap, threshold=THRESHOLD)
172173
end = time.perf_counter()
173174
root_print(COMM, f"Time to compute TTC for all data: {end-start:.2f} seconds.")
174-
175+
175176
if rank == 0:
176177
np.save(os.path.join(OUT_DIR, f"{SAMPLE_TYPE}_ttc.npy"), TTC)
177178
np.save(os.path.join(OUT_DIR, f"ttc_true.npy"), TTC_true)
178-
179+
179180
if DO_TTV:
180181
start = time.perf_counter()
181182
root_print(COMM, "Computing normalized tumor volume...")
182183
TTV = np.zeros((NSAMPLES, 1))
183184
TTV_true = computeTTV(reffun, threshold=THRESHOLD) # true value
184185
for i in range(NSAMPLES):
185186
helpfun.vector().zero()
186-
helpfun.vector().axpy(1., umv[i])
187+
helpfun.vector().axpy(1.0, umv[i])
187188
TTV[i] = computeTTV(helpfun, threshold=THRESHOLD)
188189
end = time.perf_counter()
189190
root_print(COMM, f"Time to compute TTV for all data: {end-start:.2f} seconds.")
190-
191+
191192
if rank == 0:
192193
np.save(os.path.join(OUT_DIR, f"{SAMPLE_TYPE}_ttv.npy"), TTV)
193194
np.save(os.path.join(OUT_DIR, f"ttv_true.npy"), TTV_true)
194195

195196

196197
if __name__ == "__main__":
197198
parser = argparse.ArgumentParser(description="Compute quantities of interest for HGG data.")
198-
199+
199200
# data inputs.
200201
parser.add_argument("--mesh", type=str, required=True, help="Path to the mesh file.")
201202
parser.add_argument("--pdir", type=str, help="Path to the patient data directory.")
202203
parser.add_argument("--pinfo", type=str, help="Path to the patient information file.")
203204
parser.add_argument("--visit", type=int, default=None, help="Visit to compute QoIs for (use zero indexing).")
204205
parser.add_argument("--experiment_type", type=str, required=True, choices=["rd", "rdtx", "pwrdtx"], help="Type of experiment to run.")
205-
206+
206207
# propagation inputs.
207208
parser.add_argument("--samples", type=str, required=True, help="Full path to the samples data file.")
208209
parser.add_argument("--nsamples", type=int, required=True, help="Number of samples to use.")
209210
parser.add_argument("--sample_type", type=str, required=True, choices=["prior", "la_post", "mcmc", "map", "prior_mean"], help="Type of samples file.")
210211
parser.add_argument("--threshold", type=float, required=True, help="Threshold for observable tumor.")
211-
212+
212213
# Output options.
213214
parser.add_argument("--outdir", type=str, required=True, help="Output directory.")
214-
215+
215216
# Input options.
216217
parser.add_argument("--thresh", type=float, default=0.2, help="Threshold value for the state.")
217218
parser.add_argument("--zoff", type=float, default=None, help="Z-offset for 2D meshes.")
218-
219+
219220
# Which quantities of interest to compute.
220221
parser.add_argument("--dice", action=argparse.BooleanOptionalAction, default=True, help="Compute the Dice coefficient.")
221222
parser.add_argument("--ttv", action=argparse.BooleanOptionalAction, default=True, help="Compute the total tumor volume.")
222223
parser.add_argument("--vox", action=argparse.BooleanOptionalAction, default=True, help="Compute the voxel-wise QoIs.")
223224
parser.add_argument("--ttc", action=argparse.BooleanOptionalAction, default=True, help="Compute the time to threshold.")
224-
225+
225226
# Parse the arguments, strip CLI args for PETSc.
226227
args, other = parser.parse_known_args()
227-
228+
228229
main(args)

0 commit comments

Comments
 (0)