Skip to content

Commit cfd9356

Browse files
committed
🚧 WIP ✨ Drop in mri_robust_template
1 parent c11c88f commit cfd9356

File tree

9 files changed

+387
-141
lines changed

9 files changed

+387
-141
lines changed

CPAC/longitudinal/robust_template.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,115 @@
1616
# You should have received a copy of the GNU Lesser General Public
1717
# License along with C-PAC. If not, see <https://www.gnu.org/licenses/>.
1818
"""Create longitudinal template using ``mri_robust_template``."""
19+
20+
import os
21+
from typing import cast, Literal
22+
23+
from nipype.interfaces.base import (
24+
File,
25+
InputMultiPath,
26+
isdefined,
27+
OutputMultiPath,
28+
traits,
29+
)
30+
from nipype.interfaces.freesurfer import longitudinal
31+
32+
from CPAC.pipeline import nipype_pipeline_engine as pe
33+
from CPAC.utils.configuration import Configuration
34+
35+
36+
class RobustTemplateInputSpec(longitudinal.RobustTemplateInputSpec): # noqa: D101
37+
affine = traits.Bool(default_value=False, desc="compute 12 DOF registration")
38+
mapmov = traits.Either(
39+
InputMultiPath(File(exists=False)),
40+
traits.Bool,
41+
argstr="--mapmov %s",
42+
desc="output images: map and resample each input to template",
43+
)
44+
maxit = traits.Int(
45+
argstr="--maxit %d",
46+
mandatory=False,
47+
desc="iterate max # times (if #tp>2 default 6, else 5 for 2tp reg.)",
48+
)
49+
50+
51+
class RobustTemplateOutputSpec(longitudinal.RobustTemplateOutputSpec): # noqa: D101
52+
mapmov = OutputMultiPath(
53+
File(exists=True),
54+
desc="each input mapped and resampled to longitudinal template",
55+
)
56+
57+
58+
class RobustTemplate(longitudinal.RobustTemplate): # noqa: D101
59+
# STATEMENT OF CHANGES:
60+
# This class is derived from sources licensed under the Apache-2.0 terms,
61+
# and this class has been changed.
62+
63+
# CHANGES:
64+
# * Added handling for `affind`, `mapmov` and `maxit`
65+
66+
# ORIGINAL WORK'S ATTRIBUTION NOTICE:
67+
# Copyright (c) 2009-2016, Nipype developers
68+
69+
# Licensed under the Apache License, Version 2.0 (the "License");
70+
# you may not use this file except in compliance with the License.
71+
# You may obtain a copy of the License at
72+
73+
# http://www.apache.org/licenses/LICENSE-2.0
74+
75+
# Unless required by applicable law or agreed to in writing, software
76+
# distributed under the License is distributed on an "AS IS" BASIS,
77+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
78+
# See the License for the specific language governing permissions and
79+
# limitations under the License.
80+
81+
# Prior to release 0.12, Nipype was licensed under a BSD license.
82+
83+
# Modifications copyright (C) 2024 C-PAC Developers
84+
input_spec = RobustTemplateInputSpec
85+
output_spec = RobustTemplateOutputSpec
86+
87+
def _format_arg(self, name, spec, value):
88+
if name == "average_metric":
89+
# return enumeration value
90+
return spec.argstr % {"mean": 0, "median": 1}[value]
91+
if name in ("mapmov", "transform_outputs", "scaled_intensity_outputs"):
92+
value = self._list_outputs()[name]
93+
return super()._format_arg(name, spec, value)
94+
95+
def _list_outputs(self):
96+
""":py:meth:`~nipype.interfaces.freesurfer.RobustTemplate._list_outputs` + `mapmov`."""
97+
outputs = self.output_spec().get()
98+
outputs["out_file"] = os.path.abspath(self.inputs.out_file)
99+
n_files = len(self.inputs.in_files)
100+
fmt = "{}{:02d}.{}" if n_files > 9 else "{}{:d}.{}" # noqa: PLR2004
101+
for key, prefix, ext in [
102+
("transform_outputs", "tp", "lta"),
103+
("scaled_intensity_outputs", "is", "txt"),
104+
("mapmov", "space-longitudinal", "nii.gz"),
105+
]:
106+
if isdefined(getattr(self.inputs, key)):
107+
fnames = getattr(self.inputs, key)
108+
if fnames is True:
109+
fnames = [fmt.format(prefix, i + 1, ext) for i in range(n_files)]
110+
outputs[key] = [os.path.abspath(x) for x in fnames]
111+
return outputs
112+
113+
114+
def mri_robust_template(name: str, cfg: Configuration) -> pe.Node:
115+
"""Return a Node to run `mri_robust_template` with common options."""
116+
node = pe.Node(RobustTemplate(), name=name)
117+
node.set_input("mapmov", True)
118+
node.set_input("transform_outputs", True)
119+
node.set_input(
120+
"average_metric", cfg["longitudinal_template_generation", "average_method"]
121+
)
122+
node.set_input("affine", cfg["longitudinal_template_generation", "dof"] == 12) # noqa: PLR2004
123+
max_iter = cast(
124+
int | Literal["default"], cfg["longitudinal_template_generation", "max_iter"]
125+
)
126+
if isinstance(max_iter, int):
127+
node.set_input("maxit", max_iter)
128+
node.set_input("auto_detect_sensitivity", True)
129+
130+
return node

0 commit comments

Comments
 (0)