Skip to content

Commit b53aad3

Browse files
authored
Merge pull request #1939 from ANTsX/aat_vector_output
ENH: antsApplyTransforms vector output compatibility
2 parents e99379c + 9dcdd68 commit b53aad3

File tree

2 files changed

+182
-11
lines changed

2 files changed

+182
-11
lines changed

Examples/antsApplyTransforms.cxx

Lines changed: 41 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,9 @@ antsApplyTransforms(itk::ants::CommandLineParser::Pointer & parser, unsigned int
180180
using OutputMultiChannelImageType = itk::VectorImage<OutputPixelType, Dimension>;
181181
using OutputFiveDimensionalImageType = itk::Image<OutputPixelType, 5>;
182182
using OutputVectorType = itk::Vector<OutputPixelType, Dimension>;
183-
using OutputDisplacementFieldType = itk::Image<OutputVectorType, Dimension>;
183+
using OutputDisplacementFieldType = itk::Image<OutputVectorType, Dimension>; // used to cast displacement field pixel type
184+
// define a type for vector output, can't write OutputDisplacementFieldType as an ANTsPy pointer
185+
using OutputVectorImageType = itk::VectorImage<OutputPixelType, Dimension>;
184186

185187
using RegistrationHelperType = typename ants::RegistrationHelper<T, Dimension>;
186188
using AffineTransformType = typename RegistrationHelperType::AffineTransformType;
@@ -262,7 +264,9 @@ antsApplyTransforms(itk::ants::CommandLineParser::Pointer & parser, unsigned int
262264
{
263265
std::cout << "Input multichannel image: " << inputOption->GetFunction(0)->GetName() << std::endl;
264266
}
265-
ReadImage<MultiChannelImageType>(multiChannelImage, (inputOption->GetFunction(0)->GetName()).c_str());
267+
// Call read vector image to read in as itk::VectorImage, for compatibility with ANTsPy. Then handle internally
268+
// as a time series, then output as a vector later
269+
ReadVectorImage<MultiChannelImageType>(multiChannelImage, (inputOption->GetFunction(0)->GetName()).c_str());
266270
timeSeriesImage =
267271
ConvertMultiChannelImageToTimeSeriesImage<MultiChannelImageType, TimeSeriesImageType>(multiChannelImage);
268272
}
@@ -322,6 +326,23 @@ antsApplyTransforms(itk::ants::CommandLineParser::Pointer & parser, unsigned int
322326
}
323327
ReadTensorImage<TensorImageType>(tensorImage, (inputOption->GetFunction(0)->GetName()).c_str(), true, defaultValue);
324328
}
329+
else if ((inputImageType == 1 || inputImageType == 6) && inputOption && inputOption->GetNumberOfFunctions())
330+
{
331+
if (verbose)
332+
{
333+
std::cout << "Input vector image: " << inputOption->GetFunction(0)->GetName() << std::endl;
334+
}
335+
// use MultiChannelImageType to read in vector image as itk::VectorImage. Internally, the displacement field type
336+
// is an itk::Image with itk::Vector pixels, which is different. We read as VectorImage and then convert for compatibility
337+
// with antspy.
338+
typename MultiChannelImageType::Pointer tmpVectorImage = nullptr;
339+
340+
ReadVectorImage<MultiChannelImageType>(tmpVectorImage, (inputOption->GetFunction(0)->GetName()).c_str());
341+
342+
// this is an image of DisplacementFieldType, which may be an actual displacement field or generic spatial vectors in
343+
// physical space (inputImageType == 6) or spatial vectors in index space (inputImageType == 1)
344+
vectorImage = ConvertVectorImageToDisplacementField<MultiChannelImageType, DisplacementFieldType>(tmpVectorImage);
345+
}
325346
else if (inputImageType == 0 && inputOption && inputOption->GetNumberOfFunctions())
326347
{
327348
const std::string inputFN = inputOption->GetFunction(0)->GetName();
@@ -380,14 +401,6 @@ antsApplyTransforms(itk::ants::CommandLineParser::Pointer & parser, unsigned int
380401
ReadImage<ImageType>(image, inputFN.c_str());
381402
inputImages.push_back(image);
382403
}
383-
else if ((inputImageType == 1 || inputImageType == 6) && inputOption && inputOption->GetNumberOfFunctions())
384-
{
385-
if (verbose)
386-
{
387-
std::cout << "Input vector image: " << inputOption->GetFunction(0)->GetName() << std::endl;
388-
}
389-
ReadImage<DisplacementFieldType>(vectorImage, (inputOption->GetFunction(0)->GetName()).c_str());
390-
}
391404
else if (outputOption && outputOption->GetNumberOfFunctions())
392405
{
393406
if (outputOption->GetFunction(0)->GetNumberOfParameters() > 1 &&
@@ -408,6 +421,16 @@ antsApplyTransforms(itk::ants::CommandLineParser::Pointer & parser, unsigned int
408421
}
409422
}
410423
}
424+
else
425+
{
426+
if (verbose)
427+
{
428+
std::cerr << "No input or output specified, run without args for usage" << std::endl;
429+
}
430+
return EXIT_FAILURE;
431+
}
432+
433+
411434
/**
412435
* Reference image option
413436
*/
@@ -819,7 +842,14 @@ antsApplyTransforms(itk::ants::CommandLineParser::Pointer & parser, unsigned int
819842
caster->SetInput(reorienter->GetOutput());
820843
caster->Update();
821844

822-
ANTs::WriteImage<OutputDisplacementFieldType>(caster->GetOutput(), (outputFileName).c_str());
845+
typename OutputDisplacementFieldType::Pointer outField = caster->GetOutput();
846+
847+
// Convert OutputDisplacementFieldType to OutputVectorImageType for compatibility with ANTsPy
848+
// the file on disk is the same for both, but the pointer casting for antspy needs an itk::VectorImage
849+
typename OutputVectorImageType::Pointer outVec =
850+
ConvertDisplacementFieldToVectorImage<OutputDisplacementFieldType, OutputVectorImageType>(outField);
851+
852+
ANTs::WriteImage<OutputVectorImageType>(outVec, (outputFileName).c_str());
823853
}
824854
else if (inputImageType == 2)
825855
{

Utilities/ReadWriteData.h

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,66 @@ ReadImage(char * fn)
289289
return target;
290290
}
291291

292+
293+
template <typename TImageType>
294+
void
295+
ReadVectorImage(itk::SmartPointer<TImageType> & target, const char * file)
296+
{
297+
typedef TImageType ImageType;
298+
typedef itk::ImageFileReader<ImageType> FileSourceType;
299+
300+
typename FileSourceType::Pointer reffilter = nullptr;
301+
302+
if (FileIsPointer(file))
303+
{
304+
void * ptr;
305+
sscanf(file, "%p", (void **)&ptr);
306+
using Scalar = typename TImageType::PixelType::ComponentType;
307+
using VecImageType = itk::VectorImage<Scalar, TImageType::ImageDimension>;
308+
auto vecImagePtr = *(static_cast<typename VecImageType::Pointer *>(ptr));
309+
using CastFilterType = itk::CastImageFilter<VecImageType, TImageType>;
310+
typename CastFilterType::Pointer caster = CastFilterType::New();
311+
caster->SetInput(vecImagePtr);
312+
caster->Update();
313+
target = caster->GetOutput();
314+
target->DisconnectPipeline();
315+
}
316+
else
317+
{
318+
// Read the image files begin
319+
if (!ANTSFileExists(std::string(file)))
320+
{
321+
std::cerr << " file " << std::string(file) << " does not exist . " << std::endl;
322+
target = nullptr;
323+
return;
324+
}
325+
if (!ANTSFileIsImage(file))
326+
{
327+
std::cerr << " file " << std::string(file) << " is not recognized as a supported image format . " << std::endl;
328+
target = nullptr;
329+
return;
330+
}
331+
332+
reffilter = FileSourceType::New();
333+
reffilter->SetFileName(file);
334+
try
335+
{
336+
reffilter->Update();
337+
}
338+
catch (const itk::ExceptionObject & e)
339+
{
340+
std::cerr << "Exception caught during reference file reading " << std::endl;
341+
std::cerr << e << " file " << file << std::endl;
342+
target = nullptr;
343+
return;
344+
}
345+
346+
target = reffilter->GetOutput();
347+
}
348+
349+
}
350+
351+
292352
template <typename ImageType>
293353
typename ImageType::Pointer
294354
ReadTensorImage(char * fn, bool takelog = true, double backgroundMD = 0.0)
@@ -752,6 +812,87 @@ WriteDisplacementField2(TField * field, std::string filename, std::string app)
752812
return;
753813
}
754814

815+
/** Convert a displacement field to another vector type */
816+
template <typename DisplacementFieldType, typename VectorImageType>
817+
typename VectorImageType::Pointer
818+
ConvertDisplacementFieldToVectorImage(DisplacementFieldType * displacementField)
819+
{
820+
enum
821+
{
822+
ImageDimension = DisplacementFieldType::ImageDimension
823+
};
824+
825+
typename VectorImageType::Pointer vectorImage = VectorImageType::New();
826+
827+
vectorImage->SetRegions(displacementField->GetLargestPossibleRegion());
828+
vectorImage->SetSpacing(displacementField->GetSpacing());
829+
vectorImage->SetOrigin(displacementField->GetOrigin());
830+
vectorImage->SetDirection(displacementField->GetDirection());
831+
vectorImage->SetNumberOfComponentsPerPixel(ImageDimension);
832+
vectorImage->AllocateInitialized();
833+
834+
itk::ImageRegionIteratorWithIndex<VectorImageType> It(vectorImage,
835+
vectorImage->GetLargestPossibleRegion());
836+
837+
for (It.GoToBegin(); !It.IsAtEnd(); ++It)
838+
{
839+
typename DisplacementFieldType::IndexType index = It.GetIndex();
840+
841+
typename DisplacementFieldType::PixelType dispVoxel = displacementField->GetPixel(index);
842+
typename VectorImageType::PixelType vectorVoxel;
843+
vectorVoxel.SetSize(ImageDimension);
844+
845+
for (itk::SizeValueType n = 0; n < ImageDimension; n++)
846+
{
847+
vectorVoxel[n] = dispVoxel[n];
848+
}
849+
It.Set(vectorVoxel);
850+
}
851+
852+
return vectorImage;
853+
854+
}
855+
856+
/** Convert a vector image to a displacement field */
857+
template <typename VectorImageType, typename DisplacementFieldType>
858+
typename DisplacementFieldType::Pointer
859+
ConvertVectorImageToDisplacementField(VectorImageType * vectorImage)
860+
{
861+
enum
862+
{
863+
ImageDimension = VectorImageType::ImageDimension
864+
};
865+
866+
typename DisplacementFieldType::Pointer displacementField = DisplacementFieldType::New();
867+
868+
displacementField->SetRegions(vectorImage->GetLargestPossibleRegion());
869+
displacementField->SetSpacing(vectorImage->GetSpacing());
870+
displacementField->SetOrigin(vectorImage->GetOrigin());
871+
displacementField->SetDirection(vectorImage->GetDirection());
872+
displacementField->AllocateInitialized();
873+
874+
itk::ImageRegionIteratorWithIndex<DisplacementFieldType> It(displacementField,
875+
displacementField->GetLargestPossibleRegion());
876+
877+
for (It.GoToBegin(); !It.IsAtEnd(); ++It)
878+
{
879+
typename DisplacementFieldType::IndexType index = It.GetIndex();
880+
typename DisplacementFieldType::PixelType dispVoxel = It.Get();
881+
typename VectorImageType::PixelType vectorVoxel = vectorImage->GetPixel(index);
882+
883+
for (itk::SizeValueType n = 0; n < ImageDimension; n++)
884+
{
885+
dispVoxel[n] = vectorVoxel[n];
886+
}
887+
It.Set(dispVoxel);
888+
}
889+
890+
return displacementField;
891+
892+
}
893+
894+
895+
755896
template <typename TTimeSeriesImageType, typename MultiChannelImageType>
756897
typename MultiChannelImageType::Pointer
757898
ConvertTimeSeriesImageToMultiChannelImage(TTimeSeriesImageType * timeSeriesImage)

0 commit comments

Comments
 (0)