22
33from typing import Any , Union
44from typing_extensions import Self
5- from pydantic import Field , model_validator
5+ from pydantic import (
6+ Field ,
7+ model_validator ,
8+ ValidationInfo ,
9+ )
10+
611
712import numpy as np
813from scipy import ndimage
@@ -19,6 +24,108 @@ class StructureSet(PyRadPlanBaseModel):
1924 vois : list [VOI ] = Field (init = False , description = "List of VOIs in the Structure Set" )
2025 ct_image : CT = Field (init = False , description = "Reference to the CT Image" )
2126
27+ @classmethod
28+ def _process_matrad_data (cls , data : list , info : ValidationInfo ) -> dict :
29+ """Handle data coming from matRad."""
30+ # If the data is from matRad, we need to handle the beam quantities differently.
31+ # The keys are usually named with a "_beam" suffix.
32+ voi_list = []
33+ cst_data = data ["vois" ]
34+ ct = data ["ct_image" ]
35+
36+ def get_idx_list (vdata_item ):
37+ # Wrap a single array into a list if needed
38+ arr = vdata_item [3 ]
39+ # return only one scenario (3D) else: Multi-Scenario (4D)
40+ return (
41+ [arr .astype (int ).tolist ()]
42+ if not isinstance (arr , list )
43+ else [a .astype (int ).tolist () for a in arr ]
44+ )
45+
46+ def create_mask (idx ):
47+ # Create the ITK mask for a given index list.
48+ tmp_mask = np .zeros ((ct .size [2 ], ct .size [0 ], ct .size [1 ]), dtype = np .uint8 )
49+ tmp_mask .flat [np .asarray (idx ) - 1 ] = 1
50+ tmp_mask = np .swapaxes (tmp_mask , 1 , 2 )
51+ mask_image = sitk .GetImageFromArray (tmp_mask )
52+ mask_image .CopyInformation (ct .cube_hu )
53+ return mask_image
54+
55+ for vdata in cst_data :
56+ idx_list = get_idx_list (vdata )
57+ masks = [create_mask (idx ) for idx in idx_list ]
58+
59+ # For 4D, we need to join the masks. We also check here if the number of masks we could
60+ # extract matches the number of dimensions in the CT image
61+ dim = ct .cube_hu .GetDimension ()
62+ if dim == 4 :
63+ # First check if the mask is the same for all 4D scenarios
64+ if len (masks ) == 1 :
65+ masks = [masks [0 ]] * dim
66+ # Now do a sanity check that we don't have an incompatible number of masks
67+ if len (masks ) != ct .cube_hu .GetSize ()[3 ]:
68+ raise ValueError ("Incompatible number of masks for 4D CT" )
69+ masks = sitk .JoinSeries (* masks )
70+ # If it is a 3D CT, we just drop the list
71+ elif dim == 3 :
72+ masks = masks [0 ]
73+ else :
74+ raise ValueError ("Sanity Check failed -- unsupported CT dimensionality" )
75+
76+ # Check Objectives
77+ objectives = vdata [5 ] if len (vdata ) > 5 else []
78+ if not isinstance (objectives , list ):
79+ objectives = [objectives ]
80+
81+ voi = validate_voi (
82+ name = str (vdata [1 ]),
83+ voi_type = str (vdata [2 ]),
84+ mask = masks ,
85+ ct_image = ct ,
86+ objectives = objectives ,
87+ )
88+ voi_list .append (voi )
89+
90+ return {"vois" : voi_list , "ct_image" : ct }
91+
92+ @model_validator (mode = "before" )
93+ @classmethod
94+ def aggregate_dynamic_quantities (cls , data : Any , info : ValidationInfo ) -> Any :
95+ # Validate required keys.
96+ # Not needed since pydantic validation will take care of this.
97+ # But error prompt is in more detail.
98+ if data .get ("vois" ) is None :
99+ raise ValueError ("No cst provided. Please provide a cst." )
100+ if data .get ("ct_image" ) is None :
101+ raise ValueError ("No reference CT provided. Please provide a CT." )
102+
103+ data ["ct_image" ] = validate_ct (data ["ct_image" ])
104+
105+ # Handle the case where vois is supplied as a dict.
106+ if isinstance (data ["vois" ], dict ):
107+ ct_from_vois = data ["vois" ].pop ("ct_image" , data ["vois" ].pop ("ctImage" , None ))
108+ if ct_from_vois is not None and ct_from_vois != data ["ct_image" ]:
109+ raise ValueError ("CT image mismatch between StructureSet and provided CT" )
110+ return data
111+
112+ # Convert ndarray to list if needed.
113+ # needed for matRad cell array
114+ if isinstance (data ["vois" ], np .ndarray ):
115+ data ["vois" ] = data ["vois" ].tolist ()
116+
117+ # If vois is a list and not already a list of VOI dicts, process it
118+ # -> assume matrad data
119+ if (
120+ isinstance (data ["vois" ], list )
121+ and data ["vois" ]
122+ and not isinstance (data ["vois" ][0 ], dict )
123+ and not all (isinstance (i , VOI ) for i in data ["vois" ])
124+ ):
125+ data = cls ._process_matrad_data (data , info )
126+
127+ return data
128+
22129 @model_validator (mode = "after" )
23130 def check_cst (self ) -> Self :
24131 """Check if the VOIs are valid and reference the same CT."""
@@ -106,9 +213,7 @@ def patient_voxels(self, order="sitk") -> np.ndarray:
106213 return np .unique (np .concatenate (patient_indices ))
107214
108215 def patient_mask (self ) -> sitk .Image :
109- """Return the union mask of all patient contours (or the EXTERNAL
110- contour if provided).
111- """
216+ """Return the union mask of all patient contours (or the EXTERNAL contour if provided)."""
112217
113218 patient_indices = self .patient_voxels (order = "numpy" )
114219
@@ -226,109 +331,16 @@ def create_cst(
226331 A StructureSet object created from the input data or keyword arguments.
227332 """
228333
229- # Check if already a valid model
230334 if isinstance (cst_data , StructureSet ):
231- if ct is not None and cst_data .ct_image != ct :
335+ if ct and cst_data .ct_image != validate_ct ( ct ) :
232336 raise ValueError ("CT image mismatch between StructureSet and provided CT" )
233337 return cst_data
234-
235- # validate ct if present
236- if ct is not None :
237- ct = validate_ct (ct )
238-
239- # If already a model dictionary check the ct setup
240- if isinstance (cst_data , dict ):
241- ct_image = cst_data .pop ("ct_image" , cst_data .pop ("ctImage" , None ))
242- if ct_image is not None :
243- cst_data ["ct_image" ] = ct_image
244- if ct is not None and ct != ct_image :
245- raise ValueError ("CT image mismatch between StructureSet and provided CT" )
246- elif ct is not None :
247- cst_data ["ct_image" ] = ct
248- else :
249- raise ValueError ("No CT image reference provided!" )
250- return StructureSet .model_validate (cst_data )
251-
252- if cst_data is None and ct is not None : # If data is None
253- return StructureSet (ct_image = ct , ** kwargs )
254- if cst_data is None and ct is None :
255- return StructureSet (** kwargs )
256-
257- # Other methods need the CT
258- if ct is None :
259- raise ValueError ("No CT image reference provided!" )
260-
261- # Creation from an nd array (cell array matRad format)
262- if isinstance (cst_data , np .ndarray ) and cst_data .dtype == object :
263- cst_data = cst_data .to_list ()
264-
265- # a list of volume information (e.g. imported from pymatreader from matRad mat file)
266- if isinstance (cst_data , list ):
267- voi_list = []
268- for vdata in cst_data :
269- # First try to read the index lists
270- idx_list = []
271- # Only one scenario (3D CT)
272- if not isinstance (vdata [3 ], list ):
273- idx_list .append (vdata [3 ].astype (int ).tolist ())
274- # Multiple scenarios (4D CT)
275- else :
276- for vdata_scen in vdata [3 ]:
277- idx_list .append (vdata_scen .astype (int ).tolist ())
278-
279- # Now we create isimple ITK masks
280- masks = []
281- for idx in idx_list :
282- # TODO: Check index ordering
283- tmp_mask = np .zeros ((ct .size [2 ], ct .size [0 ], ct .size [1 ]), dtype = np .uint8 )
284- tmp_mask .flat [np .asarray (idx ) - 1 ] = 1
285- tmp_mask = np .swapaxes (tmp_mask , 1 , 2 )
286- mask_image = sitk .GetImageFromArray (tmp_mask )
287- mask_image .CopyInformation (ct .cube_hu )
288-
289- masks .append (mask_image )
290-
291- # For 4D, we need to join the masks. We also check here if the number of masks we could
292- # extract matches the number of dimensions in the CT image
293- if ct .cube_hu .GetDimension () == 4 :
294- # First check if the mask is the same for all 4D scenarios
295- if len (masks ) == 1 :
296- masks = [masks [0 ] for _ in range (ct .cube_hu .GetDimension ())]
297-
298- # Now do a sanity check that we don't have an incompatible number of masks
299- if len (masks ) != ct .cube_hu .GetSize ()[3 ]:
300- raise ValueError ("Incompatible number of masks for 4D CT" )
301-
302- masks = sitk .JoinSeries (* masks )
303-
304- # If it is a 3D CT, we just drop the list
305- elif ct .cube_hu .GetDimension () == 3 :
306- masks = masks [0 ]
307- else :
308- raise ValueError ("Sanity Check failed -- unsupported CT dimensionality" )
309-
310- # Check Objectives
311- if len (vdata ) > 5 :
312- objectives = vdata [5 ]
313- if not isinstance (objectives , list ):
314- objectives = [objectives ]
315- else :
316- objectives = []
317-
318- voi = validate_voi (
319- name = str (vdata [1 ]),
320- voi_type = str (vdata [2 ]),
321- mask = masks ,
322- ct_image = ct ,
323- objectives = objectives ,
324- )
325-
326- voi_list .append (voi )
327-
328- cst_dict = {"vois" : voi_list , "ct_image" : ct }
329- return StructureSet .model_validate (cst_dict )
330-
331- raise ValueError ("Invalid input data for creating a StructureSet." )
338+ elif isinstance (cst_data , dict ):
339+ cst_data .update (kwargs )
340+ cst_data ["ct_image" ] = ct
341+ else :
342+ cst_data = {"vois" : cst_data , "ct_image" : ct , ** kwargs }
343+ return StructureSet (** cst_data )
332344
333345
334346def validate_cst (
0 commit comments