@@ -50,10 +50,7 @@ class CT(PyRadPlanBaseModel, ABC):
5050
5151 @model_validator (mode = "before" )
5252 @classmethod
53- def validate_cube_hu (
54- cls ,
55- data : Any ,
56- ) -> Any :
53+ def validate_cube_hu (cls , data : Any ) -> Any :
5754 """
5855 Validate and convert input data to SimpleITK image format.
5956
@@ -77,94 +74,161 @@ def validate_cube_hu(
7774 ValueError
7875 If the HU cube is not present in the input dictionary.
7976 """
80-
77+ # Return CT objects unchanged
8178 if isinstance (data , CT ):
8279 return data
8380
84- if isinstance (data , dict ): # check if data is dict
85- # Check if the cube is contained in the data (either as field name or its alias)
86- cube_hu = data .get ("cube_hu" , None )
87- if cube_hu is None :
88- cube_fieldname = cls .model_fields .get ("cube_hu" ).validation_alias
89- cube_hu = data .pop (cube_fieldname , None )
90-
91- if cube_hu is None :
92- raise ValueError ("HU cube not present in dictionary." )
93-
94- if isinstance (cube_hu , sitk .Image ) and "origin" not in data :
95- data ["origin" ] = cube_hu .GetOrigin ()
96-
97- if isinstance (cube_hu , np .ndarray ):
98- # Now we do check if it is matRad data
99- if data .get ("cubeDim" , None ) is not None :
100- permute = True
101- else :
102- permute = False
103-
104- # We have an array of multiple CT scenarios as numpy arrays
105- if cube_hu .dtype == object :
106- num_cubes = len (cube_hu )
107-
108- # If there are multiple cubes, we need to store them in a list
109- # for later use
110-
111- if permute :
112- ct_scenarios = [
113- sitk .GetImageFromArray (np .transpose (cube_hu [i ], (2 , 0 , 1 )), False )
114- for i in range (num_cubes )
115- ]
116- else :
117- ct_scenarios = [
118- sitk .GetImageFromArray (cube_hu [i ], False ) for i in range (num_cubes )
119- ]
120-
121- cube_hu = sitk .JoinSeries (ct_scenarios )
122- else :
123- if permute :
124- cube_hu = sitk .GetImageFromArray (np .transpose (cube_hu , (2 , 0 , 1 )), False )
125- else :
126- cube_hu = sitk .GetImageFromArray (cube_hu , False )
127- num_cubes = 1
128-
129- # Reassign the updated cube_hu to the data dictionary
130- data ["cube_hu" ] = cube_hu
131-
132- # From here on we want an sitk.Image
133- if not isinstance (data ["cube_hu" ], sitk .Image ):
134- raise ValueError (f"Unsupported format of HU cube: { type (data ['cube_hu' ])} " )
135-
136- is4d = cube_hu .GetDimension () == 4
137- # Now we parse the rest of the data if additional data is available in the dictionary
138-
139- if "direction" in data :
140- data ["cube_hu" ].SetDirection (data ["direction" ])
141-
142- if "resolution" in data and all (key in data ["resolution" ] for key in ("x" , "y" , "z" )):
143- resolution = data ["resolution" ]
144- if cube_hu .GetDimension () == 4 :
145- cube_hu .SetSpacing ([resolution ["x" ], resolution ["y" ], resolution ["z" ], 1.0 ])
146- else :
147- cube_hu .SetSpacing ([resolution ["x" ], resolution ["y" ], resolution ["z" ]])
148-
149- # TODO: Either set up a Grid or check for x/y/z to do the data management
150- if "origin" in data :
151- data ["cube_hu" ].SetOrigin (data ["origin" ])
152-
153- else :
154- data ["cube_hu" ].SetOrigin (
155- - np .array (data ["cube_hu" ].GetSize ())
156- / 2.0
157- * np .array (data ["cube_hu" ].GetSpacing ())
158- )
159- if is4d :
160- data ["cube_hu" ].SetOrigin (np .append (data ["cube_hu" ].GetOrigin (), [0.0 ]))
161- # elif all(key in data for key in ("x", "y", "z")):
162- # origin = np.array([data["x"][0], data["y"][0], data["z"][0]], dtype=float)
163- # if is4d:
164- # origin = np.append(origin, [0.0])
165- # data["cube_hu"].SetOrigin(origin)
81+ # Process dictionary input
82+ if not isinstance (data , dict ):
83+ return data
84+
85+ # Extract cube_hu from data using field name or alias
86+ cube_hu = cls ._extract_cube_hu (data )
87+
88+ # Auto-extract origin from SimpleITK images if not provided
89+ cls ._extract_origin_from_sitk_image (data , cube_hu )
90+
91+ # Make sure that we have a SimpleITK image
92+ if isinstance (cube_hu , np .ndarray ):
93+ cube_hu = cls ._convert_numpy_to_sitk (data , cube_hu )
94+
95+ # Update data with processed cube_hu
96+ data ["cube_hu" ] = cube_hu
97+
98+ # Validate final format (if some other type than numpy/sitk is provided)
99+ cls ._validate_sitk_format (data )
100+
101+ # Apply image properties
102+ cls ._apply_image_properties (data , cube_hu )
103+
166104 return data
167105
106+ @classmethod
107+ def _extract_cube_hu (cls , data : dict ) -> Any :
108+ """Extract cube_hu from data dictionary using field name or alias."""
109+ cube_hu = data .get ("cube_hu" , None )
110+ if cube_hu is None :
111+ cube_fieldname = cls .model_fields .get ("cube_hu" ).validation_alias
112+ cube_hu = data .pop (cube_fieldname , None )
113+
114+ if cube_hu is None :
115+ raise ValueError ("HU cube not present in dictionary." )
116+
117+ return cube_hu
118+
119+ @classmethod
120+ def _extract_origin_from_sitk_image (cls , data : dict , cube_hu : Any ) -> None :
121+ """Extract origin from SimpleITK image if not already provided."""
122+ if isinstance (cube_hu , sitk .Image ) and "origin" not in data :
123+ data ["origin" ] = cube_hu .GetOrigin ()
124+
125+ @classmethod
126+ def _convert_numpy_to_sitk (cls , data : dict , cube_hu : np .ndarray ) -> sitk .Image :
127+ """Convert numpy array(s) to SimpleITK image(s)."""
128+ # Determine if data needs axis permutation (matRad format)
129+ should_permute = data .get ("cubeDim" , None ) is not None
130+
131+ # Handle multiple CT scenarios (object array)
132+ if cube_hu .dtype == object :
133+ return cls ._convert_multiple_scenarios (cube_hu , should_permute )
134+
135+ # Handle single CT scenario
136+ return cls ._convert_single_scenario (cube_hu , should_permute )
137+
138+ @classmethod
139+ def _convert_multiple_scenarios (cls , cube_hu : np .ndarray , should_permute : bool ) -> sitk .Image :
140+ """Convert multiple CT scenarios to joined SimpleITK image."""
141+ num_cubes = len (cube_hu )
142+
143+ if should_permute :
144+ ct_scenarios = [
145+ sitk .GetImageFromArray (np .transpose (cube_hu [i ], (2 , 0 , 1 )), False )
146+ for i in range (num_cubes )
147+ ]
148+ else :
149+ ct_scenarios = [sitk .GetImageFromArray (cube_hu [i ], False ) for i in range (num_cubes )]
150+
151+ return sitk .JoinSeries (ct_scenarios )
152+
153+ @classmethod
154+ def _convert_single_scenario (cls , cube_hu : np .ndarray , should_permute : bool ) -> sitk .Image :
155+ """Convert single CT scenario to SimpleITK image."""
156+ if should_permute :
157+ return sitk .GetImageFromArray (np .transpose (cube_hu , (2 , 0 , 1 )), False )
158+ else :
159+ return sitk .GetImageFromArray (cube_hu , False )
160+
161+ @classmethod
162+ def _validate_sitk_format (cls , data : dict ) -> None :
163+ """Validate that cube_hu is a SimpleITK image."""
164+ if not isinstance (data ["cube_hu" ], sitk .Image ):
165+ raise ValueError (f"Unsupported format of HU cube: { type (data ['cube_hu' ])} " )
166+
167+ @classmethod
168+ def _apply_image_properties (cls , data : dict , cube_hu : sitk .Image ) -> None :
169+ """Apply direction, spacing, and origin properties to the SimpleITK image."""
170+ is4d = cube_hu .GetDimension () == 4
171+
172+ # Apply direction
173+ if "direction" in data :
174+ data ["cube_hu" ].SetDirection (data ["direction" ])
175+
176+ # Apply spacing from resolution
177+ cls ._apply_spacing (data , cube_hu , is4d )
178+
179+ # Apply origin (with three different strategies)
180+ cls ._apply_origin (data , cube_hu , is4d )
181+
182+ @classmethod
183+ def _apply_spacing (cls , data : dict , cube_hu : sitk .Image , is4d : bool ) -> None :
184+ """Apply spacing from resolution data."""
185+ if "resolution" not in data :
186+ return
187+
188+ resolution = data ["resolution" ]
189+ if not all (key in resolution for key in ("x" , "y" , "z" )):
190+ return
191+
192+ if is4d :
193+ spacing = [resolution ["x" ], resolution ["y" ], resolution ["z" ], 1.0 ]
194+ else :
195+ spacing = [resolution ["x" ], resolution ["y" ], resolution ["z" ]]
196+
197+ cube_hu .SetSpacing (spacing )
198+
199+ @classmethod
200+ def _apply_origin (cls , data : dict , cube_hu : sitk .Image , is4d : bool ) -> None :
201+ """Apply origin to the SimpleITK image using three different strategies."""
202+ if "origin" in data :
203+ # Strategy 1: Use explicitly provided origin
204+ data ["cube_hu" ].SetOrigin (data ["origin" ])
205+ elif all (key in data for key in ("x" , "y" , "z" )):
206+ # Strategy 2: Calculate origin from x, y, z coordinate vectors
207+ cls ._apply_origin_from_coordinate_vectors (data , is4d )
208+ else :
209+ # Strategy 3: Calculate centered origin based on image geometry
210+ cls ._apply_centered_origin (data , is4d )
211+
212+ @classmethod
213+ def _apply_origin_from_coordinate_vectors (cls , data : dict , is4d : bool ) -> None :
214+ """Calculate and apply origin from x, y, z coordinate vectors."""
215+ origin = np .array ([data ["x" ][0 ], data ["y" ][0 ], data ["z" ][0 ]], dtype = float )
216+ if is4d :
217+ origin = np .append (origin , [0.0 ])
218+ data ["cube_hu" ].SetOrigin (origin )
219+
220+ @classmethod
221+ def _apply_centered_origin (cls , data : dict , is4d : bool ) -> None :
222+ """Calculate and apply centered origin based on image geometry."""
223+ centered_origin = (
224+ - np .array (data ["cube_hu" ].GetSize ()) / 2.0 * np .array (data ["cube_hu" ].GetSpacing ())
225+ )
226+
227+ if is4d :
228+ centered_origin = np .append (centered_origin , [0.0 ])
229+
230+ data ["cube_hu" ].SetOrigin (centered_origin )
231+
168232 @computed_field
169233 @property
170234 def resolution (self ) -> dict :
0 commit comments