|
| 1 | +#pragma once |
| 2 | +#include "animaRicianToGaussianImageFilter.h" |
| 3 | +#include <animaChiDistribution.h> |
| 4 | +#include <animaVectorOperations.h> |
| 5 | + |
| 6 | +#include <itkConstNeighborhoodIterator.h> |
| 7 | +#include <itkImageRegionConstIterator.h> |
| 8 | +#include <itkImageRegionIterator.h> |
| 9 | +#include <itkGaussianOperator.h> |
| 10 | + |
| 11 | +namespace anima |
| 12 | +{ |
| 13 | + |
| 14 | +template <class ComponentType,unsigned int ImageDimension> |
| 15 | +void |
| 16 | +RicianToGaussianImageFilter<ComponentType,ImageDimension> |
| 17 | +::BeforeThreadedGenerateData(void) |
| 18 | +{ |
| 19 | + // Compute spatial weights of neighbors beforehand |
| 20 | + typename InputImageType::SpacingType spacing = this->GetInput()->GetSpacing(); |
| 21 | + typedef itk::GaussianOperator<double,ImageDimension> GaussianOperatorType; |
| 22 | + std::vector<GaussianOperatorType> gaussianKernels(ImageDimension); |
| 23 | + |
| 24 | + for (unsigned int i = 0;i < ImageDimension;++i) |
| 25 | + { |
| 26 | + unsigned int reverse_i = ImageDimension - i - 1; |
| 27 | + double stddev = m_Sigma / spacing[i]; |
| 28 | + gaussianKernels[reverse_i].SetDirection(i); |
| 29 | + gaussianKernels[reverse_i].SetVariance(stddev * stddev); |
| 30 | + gaussianKernels[reverse_i].SetMaximumError(1e-3); |
| 31 | + gaussianKernels[reverse_i].CreateDirectional(); |
| 32 | + gaussianKernels[reverse_i].ScaleCoefficients(1.0e4); |
| 33 | + m_Radius[i] = gaussianKernels[reverse_i].GetRadius(i); |
| 34 | + } |
| 35 | + |
| 36 | + m_NeighborWeights.clear(); |
| 37 | + for (unsigned int i = 0;i < gaussianKernels[0].Size();++i) |
| 38 | + { |
| 39 | + for (unsigned int j = 0;j < gaussianKernels[1].Size();++j) |
| 40 | + { |
| 41 | + for (unsigned int k = 0;k < gaussianKernels[2].Size();++k) |
| 42 | + { |
| 43 | + double weight = gaussianKernels[0][i] * gaussianKernels[1][j] * gaussianKernels[2][k]; |
| 44 | + m_NeighborWeights.push_back(weight); |
| 45 | + } |
| 46 | + } |
| 47 | + } |
| 48 | + |
| 49 | + // Initialize thread containers for global scale estimation |
| 50 | + unsigned int numThreads = this->GetNumberOfThreads(); |
| 51 | + m_ThreadScaleSamples.resize(numThreads); |
| 52 | + this->Superclass::BeforeThreadedGenerateData(); |
| 53 | +} |
| 54 | + |
| 55 | +template <class ComponentType,unsigned int ImageDimension> |
| 56 | +void |
| 57 | +RicianToGaussianImageFilter<ComponentType,ImageDimension> |
| 58 | +::ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, |
| 59 | + itk::ThreadIdType threadId) |
| 60 | +{ |
| 61 | + typedef itk::ConstNeighborhoodIterator<InputImageType> InputIteratorType; |
| 62 | + typedef itk::ConstNeighborhoodIterator<MaskImageType> MaskIteratorType; |
| 63 | + typedef itk::ImageRegionConstIterator<InputImageType> InputSimpleIteratorType; |
| 64 | + typedef itk::ImageRegionIterator<OutputImageType> OutputIteratorType; |
| 65 | + |
| 66 | + InputIteratorType inputItr(m_Radius, this->GetInput(), outputRegionForThread); |
| 67 | + unsigned int neighborhoodSize = inputItr.Size(); |
| 68 | + |
| 69 | + MaskIteratorType maskItr(m_Radius, this->GetComputationMask(), outputRegionForThread); |
| 70 | + |
| 71 | + InputSimpleIteratorType meanItr, varItr; |
| 72 | + if (m_MeanImage) |
| 73 | + meanItr = InputSimpleIteratorType(m_MeanImage, outputRegionForThread); |
| 74 | + if (m_VarianceImage) |
| 75 | + varItr = InputSimpleIteratorType(m_VarianceImage, outputRegionForThread); |
| 76 | + |
| 77 | + OutputIteratorType locationItr(this->GetOutput(0), outputRegionForThread); |
| 78 | + OutputIteratorType scaleItr(this->GetOutput(1), outputRegionForThread); |
| 79 | + OutputIteratorType signalItr(this->GetOutput(2), outputRegionForThread); |
| 80 | + |
| 81 | + std::vector<double> samples, weights; |
| 82 | + bool isInBounds; |
| 83 | + typename InputImageType::IndexType currentIndex, neighborIndex; |
| 84 | + typename InputImageType::PointType currentPoint, neighborPoint; |
| 85 | + |
| 86 | + while (!maskItr.IsAtEnd()) |
| 87 | + { |
| 88 | + // Discard voxels outside of the brain |
| 89 | + if (maskItr.GetCenterPixel() == 0) |
| 90 | + { |
| 91 | + locationItr.Set(0); |
| 92 | + scaleItr.Set(0); |
| 93 | + signalItr.Set(0); |
| 94 | + |
| 95 | + ++inputItr; |
| 96 | + ++maskItr; |
| 97 | + if (m_MeanImage) |
| 98 | + ++meanItr; |
| 99 | + if (m_VarianceImage) |
| 100 | + ++varItr; |
| 101 | + ++locationItr; |
| 102 | + ++scaleItr; |
| 103 | + ++signalItr; |
| 104 | + |
| 105 | + continue; |
| 106 | + } |
| 107 | + |
| 108 | + // Get signal at current central voxel |
| 109 | + double inputSignal = inputItr.GetCenterPixel(); |
| 110 | + |
| 111 | + // Rice-corrupted signals should all be positive |
| 112 | + if (inputSignal <= 0) |
| 113 | + { |
| 114 | + locationItr.Set(0); |
| 115 | + scaleItr.Set(0); |
| 116 | + signalItr.Set(0); |
| 117 | + |
| 118 | + ++inputItr; |
| 119 | + ++maskItr; |
| 120 | + if (m_MeanImage) |
| 121 | + ++meanItr; |
| 122 | + if (m_VarianceImage) |
| 123 | + ++varItr; |
| 124 | + ++locationItr; |
| 125 | + ++scaleItr; |
| 126 | + ++signalItr; |
| 127 | + |
| 128 | + continue; |
| 129 | + } |
| 130 | + |
| 131 | + // Estimation of location and scale |
| 132 | + double location = 0; |
| 133 | + double scale = m_Scale; |
| 134 | + |
| 135 | + if (m_MeanImage && m_VarianceImage) |
| 136 | + { |
| 137 | + // If mean and variance images are available, use them instead of neighborhood |
| 138 | + double sigmaValue = std::sqrt(varItr.Get()); |
| 139 | + double rValue = meanItr.Get() / sigmaValue; |
| 140 | + double k1Value = 0; |
| 141 | + double thetaValue = anima::FixedPointFinder(rValue, 1, k1Value); |
| 142 | + k1Value = anima::KummerFunction(-thetaValue * thetaValue / 2.0, -0.5, 1); |
| 143 | + |
| 144 | + scale = sigmaValue / std::sqrt(anima::XiFunction(thetaValue * thetaValue, 1, k1Value, M_PI / 2.0)); |
| 145 | + location = thetaValue * scale; |
| 146 | + } |
| 147 | + else |
| 148 | + { |
| 149 | + // Use neighbors to create samples |
| 150 | + samples.clear(); |
| 151 | + weights.clear(); |
| 152 | + |
| 153 | + for (unsigned int i = 0; i < neighborhoodSize; ++i) |
| 154 | + { |
| 155 | + double tmpVal = static_cast<double>(inputItr.GetPixel(i, isInBounds)); |
| 156 | + |
| 157 | + if (isInBounds && !std::isnan(tmpVal) && std::isfinite(tmpVal)) |
| 158 | + { |
| 159 | + if (maskItr.GetPixel(i) != maskItr.GetCenterPixel()) |
| 160 | + continue; |
| 161 | + |
| 162 | + double weight = m_NeighborWeights[i]; |
| 163 | + |
| 164 | + if (weight < m_Epsilon) |
| 165 | + continue; |
| 166 | + |
| 167 | + samples.push_back(tmpVal); |
| 168 | + weights.push_back(weight); |
| 169 | + } |
| 170 | + } |
| 171 | + |
| 172 | + if (samples.size() == 1) |
| 173 | + location = inputSignal; |
| 174 | + else |
| 175 | + anima::GetRiceParameters(samples, weights, location, scale); |
| 176 | + } |
| 177 | + |
| 178 | + // Transform Rice signal in Gaussian signal |
| 179 | + double outputSignal = inputSignal; |
| 180 | + double snrValue = location / scale; |
| 181 | + |
| 182 | + if (!std::isfinite(snrValue) || std::isnan(snrValue)) |
| 183 | + { |
| 184 | + std::cout << snrValue << " " << location << " " << scale << std::endl; |
| 185 | + itkExceptionMacro("Estimated SNR is invalid"); |
| 186 | + } |
| 187 | + |
| 188 | + if (snrValue <= 0) |
| 189 | + outputSignal = 0.0; |
| 190 | + else if (snrValue <= m_Epsilon) |
| 191 | + { |
| 192 | + double unifSignal = boost::math::cdf(m_RayleighDistribution, inputSignal / scale); |
| 193 | + |
| 194 | + if (unifSignal >= 1.0 - m_Alpha || unifSignal <= m_Alpha) |
| 195 | + unifSignal = boost::math::cdf(m_RayleighDistribution, snrValue); |
| 196 | + |
| 197 | + outputSignal = scale * boost::math::quantile(m_NormalDistribution, unifSignal); |
| 198 | + } |
| 199 | + else if (snrValue <= 600) // if SNR if > 600 keep signal as is, else... |
| 200 | + { |
| 201 | + double unifSignal = anima::GetRiceCDF(inputSignal, location, scale); |
| 202 | + |
| 203 | + if (unifSignal >= 1.0 - m_Alpha || unifSignal <= m_Alpha) |
| 204 | + unifSignal = anima::GetRiceCDF(location, location, scale); |
| 205 | + |
| 206 | + outputSignal = location + scale * boost::math::quantile(m_NormalDistribution, unifSignal); |
| 207 | + } |
| 208 | + |
| 209 | + m_ThreadScaleSamples[threadId].push_back(scale); |
| 210 | + |
| 211 | + locationItr.Set(static_cast<OutputPixelType>(location)); |
| 212 | + scaleItr.Set(static_cast<OutputPixelType>(scale)); |
| 213 | + signalItr.Set(static_cast<OutputPixelType>(outputSignal)); |
| 214 | + |
| 215 | + ++inputItr; |
| 216 | + ++maskItr; |
| 217 | + if (m_MeanImage) |
| 218 | + ++meanItr; |
| 219 | + if (m_VarianceImage) |
| 220 | + ++varItr; |
| 221 | + ++locationItr; |
| 222 | + ++scaleItr; |
| 223 | + ++signalItr; |
| 224 | + } |
| 225 | +} |
| 226 | + |
| 227 | +template <class ComponentType,unsigned int ImageDimension> |
| 228 | +void |
| 229 | +RicianToGaussianImageFilter<ComponentType,ImageDimension> |
| 230 | +::AfterThreadedGenerateData(void) |
| 231 | +{ |
| 232 | + if (m_Scale == 0) |
| 233 | + { |
| 234 | + std::vector<double> scaleSamples; |
| 235 | + for (unsigned int i = 0;i < this->GetNumberOfThreads();++i) |
| 236 | + scaleSamples.insert(scaleSamples.end(), m_ThreadScaleSamples[i].begin(), m_ThreadScaleSamples[i].end()); |
| 237 | + |
| 238 | + m_Scale = anima::GetMedian(scaleSamples); |
| 239 | + } |
| 240 | + |
| 241 | + m_ThreadScaleSamples.clear(); |
| 242 | + m_NeighborWeights.clear(); |
| 243 | + |
| 244 | + this->Superclass::AfterThreadedGenerateData(); |
| 245 | +} |
| 246 | + |
| 247 | +} // end of namespace anima |
0 commit comments