Skip to content

Commit 08b4ee6

Browse files
committed
Merge branch 'refactor/quality_ct' into 'develop'
Refactor/quality_ct See merge request e040/e0404/pyRadPlan!85
2 parents 6245fb5 + 0e026b9 commit 08b4ee6

File tree

1 file changed

+151
-87
lines changed

1 file changed

+151
-87
lines changed

pyRadPlan/ct/_ct.py

Lines changed: 151 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,7 @@ class CT(PyRadPlanBaseModel, ABC):
5050

5151
@model_validator(mode="before")
5252
@classmethod
53-
def validate_cube_hu(
54-
cls,
55-
data: Any,
56-
) -> Any:
53+
def validate_cube_hu(cls, data: Any) -> Any:
5754
"""
5855
Validate and convert input data to SimpleITK image format.
5956
@@ -77,94 +74,161 @@ def validate_cube_hu(
7774
ValueError
7875
If the HU cube is not present in the input dictionary.
7976
"""
80-
77+
# Return CT objects unchanged
8178
if isinstance(data, CT):
8279
return data
8380

84-
if isinstance(data, dict): # check if data is dict
85-
# Check if the cube is contained in the data (either as field name or its alias)
86-
cube_hu = data.get("cube_hu", None)
87-
if cube_hu is None:
88-
cube_fieldname = cls.model_fields.get("cube_hu").validation_alias
89-
cube_hu = data.pop(cube_fieldname, None)
90-
91-
if cube_hu is None:
92-
raise ValueError("HU cube not present in dictionary.")
93-
94-
if isinstance(cube_hu, sitk.Image) and "origin" not in data:
95-
data["origin"] = cube_hu.GetOrigin()
96-
97-
if isinstance(cube_hu, np.ndarray):
98-
# Now we do check if it is matRad data
99-
if data.get("cubeDim", None) is not None:
100-
permute = True
101-
else:
102-
permute = False
103-
104-
# We have an array of multiple CT scenarios as numpy arrays
105-
if cube_hu.dtype == object:
106-
num_cubes = len(cube_hu)
107-
108-
# If there are multiple cubes, we need to store them in a list
109-
# for later use
110-
111-
if permute:
112-
ct_scenarios = [
113-
sitk.GetImageFromArray(np.transpose(cube_hu[i], (2, 0, 1)), False)
114-
for i in range(num_cubes)
115-
]
116-
else:
117-
ct_scenarios = [
118-
sitk.GetImageFromArray(cube_hu[i], False) for i in range(num_cubes)
119-
]
120-
121-
cube_hu = sitk.JoinSeries(ct_scenarios)
122-
else:
123-
if permute:
124-
cube_hu = sitk.GetImageFromArray(np.transpose(cube_hu, (2, 0, 1)), False)
125-
else:
126-
cube_hu = sitk.GetImageFromArray(cube_hu, False)
127-
num_cubes = 1
128-
129-
# Reassign the updated cube_hu to the data dictionary
130-
data["cube_hu"] = cube_hu
131-
132-
# From here on we want an sitk.Image
133-
if not isinstance(data["cube_hu"], sitk.Image):
134-
raise ValueError(f"Unsupported format of HU cube: {type(data['cube_hu'])}")
135-
136-
is4d = cube_hu.GetDimension() == 4
137-
# Now we parse the rest of the data if additional data is available in the dictionary
138-
139-
if "direction" in data:
140-
data["cube_hu"].SetDirection(data["direction"])
141-
142-
if "resolution" in data and all(key in data["resolution"] for key in ("x", "y", "z")):
143-
resolution = data["resolution"]
144-
if cube_hu.GetDimension() == 4:
145-
cube_hu.SetSpacing([resolution["x"], resolution["y"], resolution["z"], 1.0])
146-
else:
147-
cube_hu.SetSpacing([resolution["x"], resolution["y"], resolution["z"]])
148-
149-
# TODO: Either set up a Grid or check for x/y/z to do the data management
150-
if "origin" in data:
151-
data["cube_hu"].SetOrigin(data["origin"])
152-
153-
else:
154-
data["cube_hu"].SetOrigin(
155-
-np.array(data["cube_hu"].GetSize())
156-
/ 2.0
157-
* np.array(data["cube_hu"].GetSpacing())
158-
)
159-
if is4d:
160-
data["cube_hu"].SetOrigin(np.append(data["cube_hu"].GetOrigin(), [0.0]))
161-
# elif all(key in data for key in ("x", "y", "z")):
162-
# origin = np.array([data["x"][0], data["y"][0], data["z"][0]], dtype=float)
163-
# if is4d:
164-
# origin = np.append(origin, [0.0])
165-
# data["cube_hu"].SetOrigin(origin)
81+
# Process dictionary input
82+
if not isinstance(data, dict):
83+
return data
84+
85+
# Extract cube_hu from data using field name or alias
86+
cube_hu = cls._extract_cube_hu(data)
87+
88+
# Auto-extract origin from SimpleITK images if not provided
89+
cls._extract_origin_from_sitk_image(data, cube_hu)
90+
91+
# Make sure that we have a SimpleITK image
92+
if isinstance(cube_hu, np.ndarray):
93+
cube_hu = cls._convert_numpy_to_sitk(data, cube_hu)
94+
95+
# Update data with processed cube_hu
96+
data["cube_hu"] = cube_hu
97+
98+
# Validate final format (if some other type than numpy/sitk is provided)
99+
cls._validate_sitk_format(data)
100+
101+
# Apply image properties
102+
cls._apply_image_properties(data, cube_hu)
103+
166104
return data
167105

106+
@classmethod
107+
def _extract_cube_hu(cls, data: dict) -> Any:
108+
"""Extract cube_hu from data dictionary using field name or alias."""
109+
cube_hu = data.get("cube_hu", None)
110+
if cube_hu is None:
111+
cube_fieldname = cls.model_fields.get("cube_hu").validation_alias
112+
cube_hu = data.pop(cube_fieldname, None)
113+
114+
if cube_hu is None:
115+
raise ValueError("HU cube not present in dictionary.")
116+
117+
return cube_hu
118+
119+
@classmethod
120+
def _extract_origin_from_sitk_image(cls, data: dict, cube_hu: Any) -> None:
121+
"""Extract origin from SimpleITK image if not already provided."""
122+
if isinstance(cube_hu, sitk.Image) and "origin" not in data:
123+
data["origin"] = cube_hu.GetOrigin()
124+
125+
@classmethod
126+
def _convert_numpy_to_sitk(cls, data: dict, cube_hu: np.ndarray) -> sitk.Image:
127+
"""Convert numpy array(s) to SimpleITK image(s)."""
128+
# Determine if data needs axis permutation (matRad format)
129+
should_permute = data.get("cubeDim", None) is not None
130+
131+
# Handle multiple CT scenarios (object array)
132+
if cube_hu.dtype == object:
133+
return cls._convert_multiple_scenarios(cube_hu, should_permute)
134+
135+
# Handle single CT scenario
136+
return cls._convert_single_scenario(cube_hu, should_permute)
137+
138+
@classmethod
139+
def _convert_multiple_scenarios(cls, cube_hu: np.ndarray, should_permute: bool) -> sitk.Image:
140+
"""Convert multiple CT scenarios to joined SimpleITK image."""
141+
num_cubes = len(cube_hu)
142+
143+
if should_permute:
144+
ct_scenarios = [
145+
sitk.GetImageFromArray(np.transpose(cube_hu[i], (2, 0, 1)), False)
146+
for i in range(num_cubes)
147+
]
148+
else:
149+
ct_scenarios = [sitk.GetImageFromArray(cube_hu[i], False) for i in range(num_cubes)]
150+
151+
return sitk.JoinSeries(ct_scenarios)
152+
153+
@classmethod
154+
def _convert_single_scenario(cls, cube_hu: np.ndarray, should_permute: bool) -> sitk.Image:
155+
"""Convert single CT scenario to SimpleITK image."""
156+
if should_permute:
157+
return sitk.GetImageFromArray(np.transpose(cube_hu, (2, 0, 1)), False)
158+
else:
159+
return sitk.GetImageFromArray(cube_hu, False)
160+
161+
@classmethod
162+
def _validate_sitk_format(cls, data: dict) -> None:
163+
"""Validate that cube_hu is a SimpleITK image."""
164+
if not isinstance(data["cube_hu"], sitk.Image):
165+
raise ValueError(f"Unsupported format of HU cube: {type(data['cube_hu'])}")
166+
167+
@classmethod
168+
def _apply_image_properties(cls, data: dict, cube_hu: sitk.Image) -> None:
169+
"""Apply direction, spacing, and origin properties to the SimpleITK image."""
170+
is4d = cube_hu.GetDimension() == 4
171+
172+
# Apply direction
173+
if "direction" in data:
174+
data["cube_hu"].SetDirection(data["direction"])
175+
176+
# Apply spacing from resolution
177+
cls._apply_spacing(data, cube_hu, is4d)
178+
179+
# Apply origin (with three different strategies)
180+
cls._apply_origin(data, cube_hu, is4d)
181+
182+
@classmethod
183+
def _apply_spacing(cls, data: dict, cube_hu: sitk.Image, is4d: bool) -> None:
184+
"""Apply spacing from resolution data."""
185+
if "resolution" not in data:
186+
return
187+
188+
resolution = data["resolution"]
189+
if not all(key in resolution for key in ("x", "y", "z")):
190+
return
191+
192+
if is4d:
193+
spacing = [resolution["x"], resolution["y"], resolution["z"], 1.0]
194+
else:
195+
spacing = [resolution["x"], resolution["y"], resolution["z"]]
196+
197+
cube_hu.SetSpacing(spacing)
198+
199+
@classmethod
200+
def _apply_origin(cls, data: dict, cube_hu: sitk.Image, is4d: bool) -> None:
201+
"""Apply origin to the SimpleITK image using three different strategies."""
202+
if "origin" in data:
203+
# Strategy 1: Use explicitly provided origin
204+
data["cube_hu"].SetOrigin(data["origin"])
205+
elif all(key in data for key in ("x", "y", "z")):
206+
# Strategy 2: Calculate origin from x, y, z coordinate vectors
207+
cls._apply_origin_from_coordinate_vectors(data, is4d)
208+
else:
209+
# Strategy 3: Calculate centered origin based on image geometry
210+
cls._apply_centered_origin(data, is4d)
211+
212+
@classmethod
213+
def _apply_origin_from_coordinate_vectors(cls, data: dict, is4d: bool) -> None:
214+
"""Calculate and apply origin from x, y, z coordinate vectors."""
215+
origin = np.array([data["x"][0], data["y"][0], data["z"][0]], dtype=float)
216+
if is4d:
217+
origin = np.append(origin, [0.0])
218+
data["cube_hu"].SetOrigin(origin)
219+
220+
@classmethod
221+
def _apply_centered_origin(cls, data: dict, is4d: bool) -> None:
222+
"""Calculate and apply centered origin based on image geometry."""
223+
centered_origin = (
224+
-np.array(data["cube_hu"].GetSize()) / 2.0 * np.array(data["cube_hu"].GetSpacing())
225+
)
226+
227+
if is4d:
228+
centered_origin = np.append(centered_origin, [0.0])
229+
230+
data["cube_hu"].SetOrigin(centered_origin)
231+
168232
@computed_field
169233
@property
170234
def resolution(self) -> dict:

0 commit comments

Comments
 (0)