2727*
2828**************************************************************************
2929"""
30- from os .path import splitext
3130import math
31+ import os
32+ import itertools
3233import numpy as np
3334import pywt
3435import pywt .data
36+
3537from scipy .ndimage import zoom
3638from xmipp_base import XmippScript
3739import xmippLib
@@ -54,51 +56,105 @@ def run(self):
5456 outVolFn = self .getParam ('-o' )
5557 self .computeVolumeConsensus (inputFile , outVolFn )
5658
59+ def resize (self , image , dim ):
60+ imageFt = np .fft .rfftn (image )
61+ resultFt = np .zeros (dim [:- 1 ] + (dim [- 1 ]// 2 + 1 ,), dtype = imageFt .dtype )
62+
63+ copyExtent = np .minimum (image .shape , dim ) // 2
64+ srcCornerStart = image .shape - copyExtent
65+ dstCornerStart = dim - copyExtent
66+ for corners in itertools .product (range (2 ), repeat = len (dim )- 1 ):
67+ corners = np .array (corners + (0 , ))
68+ srcStart = np .where (corners , srcCornerStart , 0 )
69+ srcEnd = srcStart + copyExtent
70+ dstStart = np .where (corners , dstCornerStart , 0 )
71+ dstEnd = dstStart + copyExtent
72+ srcSlices = [slice (s , e ) for s , e in zip (srcStart , srcEnd )]
73+ dstSlices = [slice (s , e ) for s , e in zip (dstStart , dstEnd )]
74+ resultFt [tuple (dstSlices )] = imageFt [tuple (srcSlices )]
75+
76+ return np .fft .irfftn (resultFt )
77+
5778 def computeVolumeConsensus (self , inputFile , outVolFn , wavelet = 'sym11' ):
5879 outputWt = None
5980 outputMin = None
6081 xdim2 = None
6182 xdimOrig = None
83+ image = xmippLib .Image ()
6284 with open (inputFile ) as f :
6385 for line in f :
6486 fileName = line .split ()[0 ]
6587 if fileName .endswith ('.mrc' ):
6688 fileName += ':mrc'
67- V = xmippLib .Image (fileName )
68- vol = V .getData ()
89+
90+ image .read (fileName )
91+ volume = image .getData ()
92+
6993 if xdimOrig is None :
70- xdimOrig = vol .shape [0 ]
71- xdim2 = 2 ** (math .ceil (math .log (xdimOrig , 2 ))) # Next power of 2
72- ydimOrig = vol .shape [1 ]
73- ydim2 = 2 ** (math .ceil (math .log (ydimOrig , 2 ))) # Next power of 2
74- zdimOrig = vol .shape [2 ]
75- zdim2 = 2 ** (math .ceil (math .log (zdimOrig , 2 ))) # Next power of 2
94+ xdimOrig = volume .shape [0 ]
95+ xdim2 = 2 ** (math .ceil (math .log2 (xdimOrig ))) # Next power of 2
96+ ydimOrig = volume .shape [1 ]
97+ ydim2 = 2 ** (math .ceil (math .log2 (ydimOrig ))) # Next power of 2
98+ zdimOrig = volume .shape [2 ]
99+ zdim2 = 2 ** (math .ceil (math .log2 (zdimOrig ))) # Next power of 2
100+
76101 if xdimOrig != xdim2 or ydimOrig != ydim2 or zdimOrig != zdim2 :
77- vol = zoom (vol , (xdim2 / xdimOrig ,ydim2 / ydimOrig ,zdim2 / zdimOrig ))
78- nlevel = pywt .swt_max_level (len (vol ))
79- wt = pywt .swtn (vol , wavelet , nlevel , 0 )
102+ #volume = zoom(volume, (xdim2/xdimOrig,ydim2/ydimOrig,zdim2/zdimOrig))
103+ volume = self .resize (volume , (zdim2 , ydim2 , xdim2 ))
104+
105+ nlevel = pywt .dwtn_max_level (volume .shape , wavelet = wavelet )
106+ wt = pywt .wavedecn (
107+ data = volume ,
108+ wavelet = wavelet ,
109+ level = nlevel
110+ )
111+
80112 if outputWt == None :
81113 outputWt = wt
82- outputMin = wt [ 0 ][ 'aaa' ] * 0
114+ outputMin = np . zeros_like ( volume )
83115 else :
84- for level in range (0 , nlevel ):
116+ diff = np .abs (np .abs (wt [0 ]) - np .abs (outputWt [0 ]))
117+ diff = self .resize (diff , outputMin .shape )
118+ np .maximum (
119+ diff , outputMin ,
120+ out = outputMin
121+ )
122+ outputWt [0 ] = np .where (
123+ np .abs (wt [0 ]) > np .abs (outputWt [0 ]),
124+ wt [0 ], outputWt [0 ]
125+ )
126+
127+ for level in range (1 , nlevel + 1 ):
85128 wtLevel = wt [level ]
86129 outputWtLevel = outputWt [level ]
87- for key in wtLevel :
88- outputWtLevel [key ] = np .where (np .abs (outputWtLevel [key ]) > np .abs (wtLevel [key ]),
89- outputWtLevel [key ], wtLevel [key ])
90- diff = np .abs (np .abs (outputWtLevel [key ]) - np .abs (wtLevel [key ]))
91- outputMin = np .where (outputMin > diff , outputMin , diff )
130+ for detail in wtLevel :
131+ wtLevelDetail = wtLevel [detail ]
132+ outputWtLevelDetail = outputWtLevel [detail ]
133+
134+ diff = np .abs (np .abs (wtLevelDetail ) - np .abs (outputWtLevelDetail ))
135+ diff = self .resize (diff , outputMin .shape )
136+ np .maximum (
137+ diff , outputMin ,
138+ out = outputMin
139+ )
140+
141+ outputWtLevelDetail [...] = np .where (
142+ np .abs (wtLevelDetail ) > np .abs (outputWtLevelDetail ),
143+ wtLevelDetail , outputWtLevelDetail
144+ )
145+
146+
92147 f .close ()
93- consensus = pywt .iswtn (outputWt , wavelet )
148+ consensus = pywt .waverecn (outputWt , wavelet )
149+ if xdimOrig != xdim2 or ydimOrig != ydim2 or zdimOrig != zdim2 :
150+ consensus = self .resize (consensus , (zdimOrig , ydimOrig , xdimOrig ))
151+ image .setData (consensus )
152+ image .write (outVolFn )
94153 if xdimOrig != xdim2 or ydimOrig != ydim2 or zdimOrig != zdim2 :
95- consensus = zoom (consensus , (xdimOrig / xdim2 ,ydimOrig / ydim2 ,zdimOrig / zdim2 ))
96- V = xmippLib .Image ()
97- V .setData (consensus )
98- V .write (outVolFn )
99- V .setData (outputMin )
100- outVolFn2 = splitext (outVolFn )[0 ] + '_diff.mrc'
101- V .write (outVolFn2 )
154+ outputMin = self .resize (outputMin , (zdimOrig , ydimOrig , xdimOrig ))
155+ image .setData (outputMin )
156+ outVolFn2 = os .path .splitext (outVolFn )[0 ] + '_diff.mrc'
157+ image .write (outVolFn2 )
102158
103159
104160if __name__ == "__main__" :
0 commit comments