Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 83 additions & 27 deletions src/xmipp/applications/scripts/volume_consensus/volume_consensus.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,13 @@
*
**************************************************************************
"""
from os.path import splitext
import math
import os
import itertools
import numpy as np
import pywt
import pywt.data

from scipy.ndimage import zoom
from xmipp_base import XmippScript
import xmippLib
Expand All @@ -54,51 +56,105 @@ def run(self):
outVolFn = self.getParam('-o')
self.computeVolumeConsensus(inputFile, outVolFn)

def resize(self, image, dim):
imageFt = np.fft.rfftn(image)
resultFt = np.zeros(dim[:-1] + (dim[-1]//2+1,), dtype=imageFt.dtype)

copyExtent = np.minimum(image.shape, dim) // 2
srcCornerStart = image.shape-copyExtent
dstCornerStart = dim-copyExtent
for corners in itertools.product(range(2), repeat=len(dim)-1):
corners = np.array(corners + (0, ))
srcStart = np.where(corners, srcCornerStart, 0)
srcEnd = srcStart + copyExtent
dstStart = np.where(corners, dstCornerStart, 0)
dstEnd = dstStart + copyExtent
srcSlices = [slice(s, e) for s, e in zip(srcStart, srcEnd)]
dstSlices = [slice(s, e) for s, e in zip(dstStart, dstEnd)]
resultFt[tuple(dstSlices)] = imageFt[tuple(srcSlices)]

return np.fft.irfftn(resultFt)

def computeVolumeConsensus(self, inputFile, outVolFn, wavelet='sym11'):
outputWt = None
outputMin = None
xdim2 = None
xdimOrig = None
image = xmippLib.Image()
with open(inputFile) as f:
for line in f:
fileName = line.split()[0]
if fileName.endswith('.mrc'):
fileName += ':mrc'
V = xmippLib.Image(fileName)
vol = V.getData()

image.read(fileName)
volume = image.getData()

if xdimOrig is None:
xdimOrig = vol.shape[0]
xdim2 = 2**(math.ceil(math.log(xdimOrig, 2))) # Next power of 2
ydimOrig = vol.shape[1]
ydim2 = 2 ** (math.ceil(math.log(ydimOrig, 2))) # Next power of 2
zdimOrig = vol.shape[2]
zdim2 = 2 ** (math.ceil(math.log(zdimOrig, 2))) # Next power of 2
xdimOrig = volume.shape[0]
xdim2 = 2**(math.ceil(math.log2(xdimOrig))) # Next power of 2
ydimOrig = volume.shape[1]
ydim2 = 2**(math.ceil(math.log2(ydimOrig))) # Next power of 2
zdimOrig = volume.shape[2]
zdim2 = 2**(math.ceil(math.log2(zdimOrig))) # Next power of 2

if xdimOrig!=xdim2 or ydimOrig!=ydim2 or zdimOrig!=zdim2:
vol = zoom(vol, (xdim2/xdimOrig,ydim2/ydimOrig,zdim2/zdimOrig))
nlevel = pywt.swt_max_level(len(vol))
wt = pywt.swtn(vol, wavelet, nlevel, 0)
#volume = zoom(volume, (xdim2/xdimOrig,ydim2/ydimOrig,zdim2/zdimOrig))
volume = self.resize(volume, (zdim2, ydim2, xdim2))

nlevel = pywt.dwtn_max_level(volume.shape, wavelet=wavelet)
wt = pywt.wavedecn(
data=volume,
wavelet=wavelet,
level=nlevel
)

if outputWt == None:
outputWt = wt
outputMin = wt[0]['aaa']*0
outputMin = np.zeros_like(volume)
else:
for level in range(0, nlevel):
diff = np.abs(np.abs(wt[0]) - np.abs(outputWt[0]))
diff = self.resize(diff, outputMin.shape)
np.maximum(
diff, outputMin,
out=outputMin
)
outputWt[0] = np.where(
np.abs(wt[0]) > np.abs(outputWt[0]),
wt[0], outputWt[0]
)

for level in range(1, nlevel+1):
wtLevel = wt[level]
outputWtLevel = outputWt[level]
for key in wtLevel:
outputWtLevel[key] = np.where(np.abs(outputWtLevel[key]) > np.abs(wtLevel[key]),
outputWtLevel[key], wtLevel[key])
diff = np.abs(np.abs(outputWtLevel[key]) - np.abs(wtLevel[key]))
outputMin = np.where(outputMin > diff, outputMin, diff)
for detail in wtLevel:
wtLevelDetail = wtLevel[detail]
outputWtLevelDetail = outputWtLevel[detail]

diff = np.abs(np.abs(wtLevelDetail) - np.abs(outputWtLevelDetail))
diff = self.resize(diff, outputMin.shape)
np.maximum(
diff, outputMin,
out=outputMin
)

outputWtLevelDetail[...] = np.where(
np.abs(wtLevelDetail) > np.abs(outputWtLevelDetail),
wtLevelDetail, outputWtLevelDetail
)


f.close()
consensus = pywt.iswtn(outputWt, wavelet)
consensus = pywt.waverecn(outputWt, wavelet)
if xdimOrig!=xdim2 or ydimOrig!=ydim2 or zdimOrig!=zdim2:
consensus = self.resize(consensus, (zdimOrig, ydimOrig, xdimOrig))
image.setData(consensus)
image.write(outVolFn)
if xdimOrig!=xdim2 or ydimOrig!=ydim2 or zdimOrig!=zdim2:
consensus = zoom(consensus, (xdimOrig/xdim2,ydimOrig/ydim2,zdimOrig/zdim2))
V = xmippLib.Image()
V.setData(consensus)
V.write(outVolFn)
V.setData(outputMin)
outVolFn2 = splitext(outVolFn)[0] + '_diff.mrc'
V.write(outVolFn2)
outputMin = self.resize(outputMin, (zdimOrig, ydimOrig, xdimOrig))
image.setData(outputMin)
outVolFn2 = os.path.splitext(outVolFn)[0] + '_diff.mrc'
image.write(outVolFn2)


if __name__=="__main__":
Expand Down