Skip to content

Commit 9e6888d

Browse files
committed
Merge branch 'refactor/quality_cst' into 'develop'
Refactor/quality_cst See merge request e040/e0404/pyRadPlan!84
2 parents 08b4ee6 + a867414 commit 9e6888d

File tree

3 files changed

+162
-124
lines changed

3 files changed

+162
-124
lines changed

pyRadPlan/cst/_cst.py

Lines changed: 116 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,12 @@
22

33
from typing import Any, Union
44
from typing_extensions import Self
5-
from pydantic import Field, model_validator
5+
from pydantic import (
6+
Field,
7+
model_validator,
8+
ValidationInfo,
9+
)
10+
611

712
import numpy as np
813
from scipy import ndimage
@@ -19,6 +24,108 @@ class StructureSet(PyRadPlanBaseModel):
1924
vois: list[VOI] = Field(init=False, description="List of VOIs in the Structure Set")
2025
ct_image: CT = Field(init=False, description="Reference to the CT Image")
2126

27+
@classmethod
28+
def _process_matrad_data(cls, data: list, info: ValidationInfo) -> dict:
29+
"""Handle data coming from matRad."""
30+
# If the data is from matRad, we need to handle the beam quantities differently.
31+
# The keys are usually named with a "_beam" suffix.
32+
voi_list = []
33+
cst_data = data["vois"]
34+
ct = data["ct_image"]
35+
36+
def get_idx_list(vdata_item):
37+
# Wrap a single array into a list if needed
38+
arr = vdata_item[3]
39+
# return only one scenario (3D) else: Multi-Scenario (4D)
40+
return (
41+
[arr.astype(int).tolist()]
42+
if not isinstance(arr, list)
43+
else [a.astype(int).tolist() for a in arr]
44+
)
45+
46+
def create_mask(idx):
47+
# Create the ITK mask for a given index list.
48+
tmp_mask = np.zeros((ct.size[2], ct.size[0], ct.size[1]), dtype=np.uint8)
49+
tmp_mask.flat[np.asarray(idx) - 1] = 1
50+
tmp_mask = np.swapaxes(tmp_mask, 1, 2)
51+
mask_image = sitk.GetImageFromArray(tmp_mask)
52+
mask_image.CopyInformation(ct.cube_hu)
53+
return mask_image
54+
55+
for vdata in cst_data:
56+
idx_list = get_idx_list(vdata)
57+
masks = [create_mask(idx) for idx in idx_list]
58+
59+
# For 4D, we need to join the masks. We also check here if the number of masks we could
60+
# extract matches the number of dimensions in the CT image
61+
dim = ct.cube_hu.GetDimension()
62+
if dim == 4:
63+
# First check if the mask is the same for all 4D scenarios
64+
if len(masks) == 1:
65+
masks = [masks[0]] * dim
66+
# Now do a sanity check that we don't have an incompatible number of masks
67+
if len(masks) != ct.cube_hu.GetSize()[3]:
68+
raise ValueError("Incompatible number of masks for 4D CT")
69+
masks = sitk.JoinSeries(*masks)
70+
# If it is a 3D CT, we just drop the list
71+
elif dim == 3:
72+
masks = masks[0]
73+
else:
74+
raise ValueError("Sanity Check failed -- unsupported CT dimensionality")
75+
76+
# Check Objectives
77+
objectives = vdata[5] if len(vdata) > 5 else []
78+
if not isinstance(objectives, list):
79+
objectives = [objectives]
80+
81+
voi = validate_voi(
82+
name=str(vdata[1]),
83+
voi_type=str(vdata[2]),
84+
mask=masks,
85+
ct_image=ct,
86+
objectives=objectives,
87+
)
88+
voi_list.append(voi)
89+
90+
return {"vois": voi_list, "ct_image": ct}
91+
92+
@model_validator(mode="before")
93+
@classmethod
94+
def aggregate_dynamic_quantities(cls, data: Any, info: ValidationInfo) -> Any:
95+
# Validate required keys.
96+
# Not needed since pydantic validation will take care of this.
97+
# But error prompt is in more detail.
98+
if data.get("vois") is None:
99+
raise ValueError("No cst provided. Please provide a cst.")
100+
if data.get("ct_image") is None:
101+
raise ValueError("No reference CT provided. Please provide a CT.")
102+
103+
data["ct_image"] = validate_ct(data["ct_image"])
104+
105+
# Handle the case where vois is supplied as a dict.
106+
if isinstance(data["vois"], dict):
107+
ct_from_vois = data["vois"].pop("ct_image", data["vois"].pop("ctImage", None))
108+
if ct_from_vois is not None and ct_from_vois != data["ct_image"]:
109+
raise ValueError("CT image mismatch between StructureSet and provided CT")
110+
return data
111+
112+
# Convert ndarray to list if needed.
113+
# needed for matRad cell array
114+
if isinstance(data["vois"], np.ndarray):
115+
data["vois"] = data["vois"].tolist()
116+
117+
# If vois is a list and not already a list of VOI dicts, process it
118+
# -> assume matrad data
119+
if (
120+
isinstance(data["vois"], list)
121+
and data["vois"]
122+
and not isinstance(data["vois"][0], dict)
123+
and not all(isinstance(i, VOI) for i in data["vois"])
124+
):
125+
data = cls._process_matrad_data(data, info)
126+
127+
return data
128+
22129
@model_validator(mode="after")
23130
def check_cst(self) -> Self:
24131
"""Check if the VOIs are valid and reference the same CT."""
@@ -106,9 +213,7 @@ def patient_voxels(self, order="sitk") -> np.ndarray:
106213
return np.unique(np.concatenate(patient_indices))
107214

108215
def patient_mask(self) -> sitk.Image:
109-
"""Return the union mask of all patient contours (or the EXTERNAL
110-
contour if provided).
111-
"""
216+
"""Return the union mask of all patient contours (or the EXTERNAL contour if provided)."""
112217

113218
patient_indices = self.patient_voxels(order="numpy")
114219

@@ -226,109 +331,16 @@ def create_cst(
226331
A StructureSet object created from the input data or keyword arguments.
227332
"""
228333

229-
# Check if already a valid model
230334
if isinstance(cst_data, StructureSet):
231-
if ct is not None and cst_data.ct_image != ct:
335+
if ct and cst_data.ct_image != validate_ct(ct):
232336
raise ValueError("CT image mismatch between StructureSet and provided CT")
233337
return cst_data
234-
235-
# validate ct if present
236-
if ct is not None:
237-
ct = validate_ct(ct)
238-
239-
# If already a model dictionary check the ct setup
240-
if isinstance(cst_data, dict):
241-
ct_image = cst_data.pop("ct_image", cst_data.pop("ctImage", None))
242-
if ct_image is not None:
243-
cst_data["ct_image"] = ct_image
244-
if ct is not None and ct != ct_image:
245-
raise ValueError("CT image mismatch between StructureSet and provided CT")
246-
elif ct is not None:
247-
cst_data["ct_image"] = ct
248-
else:
249-
raise ValueError("No CT image reference provided!")
250-
return StructureSet.model_validate(cst_data)
251-
252-
if cst_data is None and ct is not None: # If data is None
253-
return StructureSet(ct_image=ct, **kwargs)
254-
if cst_data is None and ct is None:
255-
return StructureSet(**kwargs)
256-
257-
# Other methods need the CT
258-
if ct is None:
259-
raise ValueError("No CT image reference provided!")
260-
261-
# Creation from an nd array (cell array matRad format)
262-
if isinstance(cst_data, np.ndarray) and cst_data.dtype == object:
263-
cst_data = cst_data.to_list()
264-
265-
# a list of volume information (e.g. imported from pymatreader from matRad mat file)
266-
if isinstance(cst_data, list):
267-
voi_list = []
268-
for vdata in cst_data:
269-
# First try to read the index lists
270-
idx_list = []
271-
# Only one scenario (3D CT)
272-
if not isinstance(vdata[3], list):
273-
idx_list.append(vdata[3].astype(int).tolist())
274-
# Multiple scenarios (4D CT)
275-
else:
276-
for vdata_scen in vdata[3]:
277-
idx_list.append(vdata_scen.astype(int).tolist())
278-
279-
# Now we create isimple ITK masks
280-
masks = []
281-
for idx in idx_list:
282-
# TODO: Check index ordering
283-
tmp_mask = np.zeros((ct.size[2], ct.size[0], ct.size[1]), dtype=np.uint8)
284-
tmp_mask.flat[np.asarray(idx) - 1] = 1
285-
tmp_mask = np.swapaxes(tmp_mask, 1, 2)
286-
mask_image = sitk.GetImageFromArray(tmp_mask)
287-
mask_image.CopyInformation(ct.cube_hu)
288-
289-
masks.append(mask_image)
290-
291-
# For 4D, we need to join the masks. We also check here if the number of masks we could
292-
# extract matches the number of dimensions in the CT image
293-
if ct.cube_hu.GetDimension() == 4:
294-
# First check if the mask is the same for all 4D scenarios
295-
if len(masks) == 1:
296-
masks = [masks[0] for _ in range(ct.cube_hu.GetDimension())]
297-
298-
# Now do a sanity check that we don't have an incompatible number of masks
299-
if len(masks) != ct.cube_hu.GetSize()[3]:
300-
raise ValueError("Incompatible number of masks for 4D CT")
301-
302-
masks = sitk.JoinSeries(*masks)
303-
304-
# If it is a 3D CT, we just drop the list
305-
elif ct.cube_hu.GetDimension() == 3:
306-
masks = masks[0]
307-
else:
308-
raise ValueError("Sanity Check failed -- unsupported CT dimensionality")
309-
310-
# Check Objectives
311-
if len(vdata) > 5:
312-
objectives = vdata[5]
313-
if not isinstance(objectives, list):
314-
objectives = [objectives]
315-
else:
316-
objectives = []
317-
318-
voi = validate_voi(
319-
name=str(vdata[1]),
320-
voi_type=str(vdata[2]),
321-
mask=masks,
322-
ct_image=ct,
323-
objectives=objectives,
324-
)
325-
326-
voi_list.append(voi)
327-
328-
cst_dict = {"vois": voi_list, "ct_image": ct}
329-
return StructureSet.model_validate(cst_dict)
330-
331-
raise ValueError("Invalid input data for creating a StructureSet.")
338+
elif isinstance(cst_data, dict):
339+
cst_data.update(kwargs)
340+
cst_data["ct_image"] = ct
341+
else:
342+
cst_data = {"vois": cst_data, "ct_image": ct, **kwargs}
343+
return StructureSet(**cst_data)
332344

333345

334346
def validate_cst(

pyRadPlan/cst/_voi.py

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ class VOI(PyRadPlanBaseModel, ABC):
6161
@classmethod
6262
def validate_mask_type(cls, v: Any) -> Any:
6363
"""
64-
Validates the mask type.
64+
Validate the mask type.
6565
6666
Parameters
6767
----------
@@ -141,8 +141,7 @@ def validate_mask(self):
141141
@property
142142
def indices(self) -> np.ndarray:
143143
"""
144-
Returns the indices of the voxels in the mask using Fortran/SITK
145-
convention.
144+
Return the indices of the voxels in the mask using Fortran/SITK convention.
146145
147146
Returns
148147
-------
@@ -155,8 +154,7 @@ def indices(self) -> np.ndarray:
155154
@property
156155
def indices_numpy(self) -> np.ndarray:
157156
"""
158-
Returns the indices of the voxels in the mask using C/numpy
159-
convention.
157+
Return the indices of the voxels in the mask using C/numpy convention.
160158
161159
Returns
162160
-------
@@ -197,7 +195,7 @@ def num_of_scenarios(self) -> int:
197195

198196
def get_indices(self, order="sitk") -> np.ndarray:
199197
"""
200-
Returns the indices of the voxels in the mask.
198+
Return the indices of the voxels in the mask.
201199
202200
Parameters
203201
----------
@@ -217,7 +215,7 @@ def get_indices(self, order="sitk") -> np.ndarray:
217215

218216
def scenario_indices(self, order_type="numpy") -> Union[np.ndarray, list[np.ndarray]]:
219217
"""
220-
Returns the flattened indices of the individual scenarios.
218+
Return the flattened indices of the individual scenarios.
221219
222220
Parameters
223221
----------
@@ -249,8 +247,7 @@ def scenario_indices(self, order_type="numpy") -> Union[np.ndarray, list[np.ndar
249247

250248
def masked_ct(self, order_type="numpy") -> Union[sitk.Image, np.ndarray]:
251249
"""
252-
Returns the masked CT image, either as a numpy array or a SimpleITK
253-
image.
250+
Return the masked CT image, either as a numpy array or a SimpleITK image.
254251
255252
Parameters
256253
----------
@@ -309,8 +306,7 @@ def scenario_ct_data(self) -> Union[list[np.ndarray], np.ndarray]:
309306

310307
def to_matrad(self, context: str = "mat-file") -> Any:
311308
"""
312-
Creates an object that can be interpreted by matRad in the given
313-
context.
309+
Create an object that can be interpreted by matRad in the given context.
314310
315311
Returns
316312
-------
@@ -407,7 +403,7 @@ class OAR(VOI):
407403
@classmethod
408404
def validate_voi_type(cls, v: str) -> str:
409405
"""
410-
Validates the voi type for an OAR.
406+
Validate the voi type for an OAR.
411407
412408
Parameters
413409
----------
@@ -450,7 +446,7 @@ class Target(VOI):
450446
@classmethod
451447
def validate_voi_type(cls, v: str) -> str:
452448
"""
453-
Validates the voi type for a Target.
449+
Validate the voi type for a Target.
454450
455451
Parameters
456452
----------
@@ -493,7 +489,7 @@ class HelperVOI(VOI):
493489
@classmethod
494490
def validate_voi_type(cls, v: str) -> str:
495491
"""
496-
Validates the voi type for a HelperVOI.
492+
Validate the voi type for a HelperVOI.
497493
498494
Parameters
499495
----------
@@ -517,8 +513,7 @@ def validate_voi_type(cls, v: str) -> str:
517513

518514
class ExternalVOI(VOI):
519515
"""
520-
Represents an external contour limiting voxels to be considered for
521-
planning (EXTERNAL).
516+
Represents an external contour limiting voxels to be considered for planning (EXTERNAL).
522517
523518
Attributes
524519
----------
@@ -536,7 +531,7 @@ class ExternalVOI(VOI):
536531
@classmethod
537532
def validate_voi_type(cls, v: str) -> str:
538533
"""
539-
Validates the voi type for an EXTERNAL contour.
534+
Validate the voi type for an EXTERNAL contour.
540535
541536
Parameters
542537
----------
@@ -564,7 +559,7 @@ def validate_voi_type(cls, v: str) -> str:
564559

565560
def create_voi(data: Union[dict[str, Any], VOI, None] = None, **kwargs) -> VOI:
566561
"""
567-
Factory function to create a VOI object.
562+
Create a VOI object.
568563
569564
Parameters
570565
----------
@@ -602,7 +597,8 @@ def create_voi(data: Union[dict[str, Any], VOI, None] = None, **kwargs) -> VOI:
602597

603598
def validate_voi(data: Union[dict[str, Any], VOI, None] = None, **kwargs) -> VOI:
604599
"""
605-
Validates and creates a VOI object.
600+
Validate and create a VOI object.
601+
606602
Synonym to create_voi but should be used in validation context.
607603
608604
Parameters

0 commit comments

Comments
 (0)