Skip to content

figure out adding new algorithms / how to smooth #35

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,15 @@ tmp/
*~.nib
local.properties
.settings/
python/examples/scripts/C1-YeastTNA1_1516_conv_RG_26oC_003_256xcropSub100.tif
python/examples/scripts/C1-YeastTNA1_1516_conv_RG_26oC_003sub100.tif
python/examples/scripts/gpsf_3D_1514_a3_001_WF-sub105crop128.tif
python/examples/scripts/gpsf_3D_1514_a3_001_WF-sub105crop64.tif
python/examples/scripts/FlowDecLog.txt
python/examples/scripts/resultC1-YeastTNA1_1516_conv_RG_26oC_003_256xcropSub100.tifgpsf_3D_1514_a3_001_WF-sub105crop64.tif100iterations.tif
python/examples/scripts/resultC1-YeastTNA1_1516_conv_RG_26oC_003_256xcropSub100.tifgpsf_3D_1514_a3_001_WF-sub105crop64.tif50iterations.tif
python/examples/scripts/FlowDecLogGold50.txt
python/examples/scripts/resultC1-YeastTNA1_1516_conv_RG_26oC_003_256xcropSub100.tifgpsf_3D_1514_a3_001_WF-sub105crop64.tif250iterations.tif
python/examples/scripts/resultC1-YeastTNA1_1516_conv_RG_26oC_003_256xcropSub100.tifgpsf_3D_1514_a3_001_WF-sub105crop64.tif4iterations.tif
python/examples/scripts/resultC1-YeastTNA1_1516_conv_RG_26oC_003_256xcropSub100.tifgpsf_3D_1514_a3_001_WF-sub105crop64.tif5iterations.tif
python/examples/scripts/resultC1-YeastTNA1_1516_conv_RG_26oC_003_256xcropSub100.tifgpsf_3D_1514_a3_001_WF-sub105crop64.tif10iterations.tif
47 changes: 47 additions & 0 deletions python/examples/scripts/gaussKernTest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import tensorflow as tf
import sys


def gaussian_kernel(size: int, mean: float, std: float):
d = tf.distributions.Normal(mean, std)
vals = d.prob(tf.range(size, dtype = tf.float32))
gauss_kernel = tf.einsum('i,j,k->ijk',
vals,
vals,
vals)
# normalise gauss kernel to sum = 1
kerngausnorm = gauss_kernel / tf.reduce_mean(gauss_kernel)
return kerngausnorm

def squareIt(tensor):
tensorpow2 = tf.multiply(tensor, tensor)
return tensorpow2

sess = tf.compat.v1.Session()
with sess.as_default():
tensor = gaussian_kernel(5, 2.0, 1.0)
tensor2 = gaussian_kernel(5, 3.0, 1.0)
tensor = tf.dtypes.cast(tensor, tf.float32)
tensor2 = tf.dtypes.cast(tensor2, tf.float32)
tensorSquared = squareIt(tensor)
tensorPlusTen = tensor + 10.0
# expand its dimensionality to fit into conv3d, input and filter have different dimension orders, filter has no batch but in and out channels as 2 last dims
tensor_expand = tf.expand_dims(tensor, 0)
tensor_expand = tf.expand_dims(tensor_expand, -1)
tensor_filter = tf.expand_dims(tensor2, -1)
tensor_filter = tf.expand_dims(tensor_filter, -1)
# why does tf.nn.conv3d output a tensor containing only 1 value???eg [[[[[]35234.325]]]] is it the sum of the real 3D image output?
tensorConvolved = tf.compat.v1.nn.conv3d(tensor_expand, filter=tensor_filter, strides=[1,1,1,1,1], padding="VALID", data_format='NDHWC')
print_op = tf.print("3D Gauss Kernel? :\n", tensor, " Sum = ", tf.reduce_sum(tensor), " Max = ", tf.reduce_max(tensor), "\n",
"3D Gauss Kernel squared? :\n", tensorSquared, " Sum = ", tf.reduce_sum(tensorSquared), " Max = ", tf.reduce_max(tensorSquared), "\n",
"3D Gauss Kernel convolved? :\n", tensorConvolved, " Sum = ",
tf.reduce_sum(tensorConvolved), " Max = ", tf.reduce_max(tensorConvolved), " DimsShape = ", tf.shape(tensorConvolved), "\n",
" Plus10 :\n", tensorPlusTen,
output_stream=sys.stdout)
with tf.control_dependencies([print_op]):
the_tensor = tensor * 1.0
the_tensor2 = tensor2 * 1.0
the_tensorpow2 = tensorSquared * 1.0
the_tensorConvolved = tensorConvolved * 1.0
sess.run(the_tensorConvolved)

146 changes: 146 additions & 0 deletions python/examples/scripts/simpleFlowDecAlgoConvergenceTest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
# Dan White March 2019-2020
# simpleFlowDecAlgoConvergenceTest.py
# A simple script using flowdec, testing convergence of the algorithm on test data

# flowdec pip package install/update requirements for anaconda on win10 64 bit:

# install the pip package for flowdec with GPU support as per the flowdec install instructions:
# github.com/hammerlab/flowdec/blob/master/README.md but......


# 22 april 2020 - what works today :
# flowdec uses tensorflow which on nvidda GPU uses CUDA so need to install the stuff here
# https://www.tensorflow.org/install/gpu but.....those instructions seem to install incompatible stuff... so try
# things were installed in this order and the script works and used the GPU
#cuda toolkit 10.0
#cudnn-10.0 7.6.34.38 installed into C:/tools/
#pip install flowdec
# not pip install flowdec[tf_gpu
# ommitting the tf_gpu option, so it leaves tensorflow-gpu uninstalled, becasue by default
# by now it installs v2.1 of tensorflow which doesnt seem to work for flowdec?
# pip install tensorflow-gpu==1.14.0 (2.0 might work...? maybe needs higher cuda version)
# Need windows env variables pointing to cuda stuff: CUDA and CUPTI and another related library cuDNN
# SET PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.0\bin;%PATH%
# SET PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.0\extras\CUPTI\libx64;%PATH%
# SET PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.0\include;%PATH%
# SET PATH=C:\tools\cuda\bin;%PATH%
# some of these don't seem to be sticky???

# for time benchmarking
import time
# for logging to text file
import sys
#import logging

startImports = time.process_time()
from skimage.external.tifffile import imsave, imread
from skimage.metrics import structural_similarity as ssim
#from tensorflow.image import ssim as ssim
from flowdec import data as fd_data
from flowdec import restoration as fd_restoration
importsTime = (time.process_time() - startImports)

# this seems to return 0 seconds, so maybe TF math is already imported by flowdec?
startTFMathImport = time.process_time()
from tensorflow import math as tfmath
importTFMathTime = (time.process_time() - startTFMathImport)

# Load test image from same dir as we execute in
rawBig = 'C1-YeastTNA1_1516_conv_RG_26oC_003sub100.tif'
rawSmall = 'C1-YeastTNA1_1516_conv_RG_26oC_003_256xcropSub100.tif'
rawImg = rawSmall

raw = imread(rawImg)

# Load psf kernel image from same dir
# A cropped 64x64 PSF to reduce memory use, 21 z slices, 0.125 nm spacing in z
PSFsmall = 'gpsf_3D_1514_a3_001_WF-sub105crop64.tif'
# The same PSF but not cropped so much, 128x128, probably uses much more memory.
PSFbigger = 'gpsf_3D_1514_a3_001_WF-sub105crop128.tif'
# choose the psf to use from the options above.
PSF = PSFsmall
print (PSF)
kernel = imread(PSF)

#base number of iterations - RL converges slowly so need tens of iterations or maybe hundreds.
base_iter = 15


# Create an observer function to monitor convergence,
# where the first argument is the current state of the
# image as it is being deconvolved, i is the current iteration number (1-based),
# third is the padded current guess image, and fourth is the value of covergence metric R,
# and the remaining arguments (at TOW) only include the uncropped image result which
# is generally not useful for anything other than development or debugging
# imgs = []
def observer(decon_crop, i, decon, conv1, kerngaussh, *args):
#normalise the raw data so its sum is 1
rawNorm = raw / raw.sum()
#the sum should be 1
sumRaw = rawNorm.sum()
#imgs.append(decon_crop)
if i % 5 == 0:
sumBlurredModel = decon_crop.sum()
# compute convergence residuals between raw image and blurred model image
# with current test data these get bigger not smaller with iterations... weird..
convergenceResiduals = abs(sumRaw - sumBlurredModel)
convergenceR = convergenceResiduals / sumRaw
# let's try structural similarity (as it's supposed to be better then mean square error)
# which indeed seems to track convergence of this dataset in a way that seems to match the image results.
#SSIM in skimage:
structSim = ssim(rawNorm, decon_crop, data_range=decon_crop.max() - decon_crop.min())
#SSIM in TensorFlow should be faster: imported ft.image.ssim as ssim but this doesnt work... fix if too slow otherwize.
# structSim = ssim(raw, decon_crop, max_val=1.0, filter_size=11, filter_sigma=1.0, k1=0.01, k2=0.03)
print('Iter,{},RawSum,{:.3f},DeconSum,{:.3f},DeconMax,{:.3f},DeconStDev,{:.3f},SumResiduals,{:.3f},ConvergeR,{:.16f},SSIM,{:.3f},KerGMax,{:.3f}'.format(
i, sumRaw, sumBlurredModel, decon_crop.max(), decon_crop.std(), convergenceResiduals.max(), convergenceR.max(), structSim.max(), kerngaussh.max()))

# Run the deconvolution process and note that deconvolution initialization is best kept separate from
# execution since the "initialize" operation corresponds to creating a TensorFlow graph, which is a
# relatively expensive operation and should not be repeated across multiple executions

# initialize the TF graph for the deconvolution settings in use for certain sized input and psf images
# works for doing the same input data multiple times with different iteractions
# should work for doing different input data with same sizes of image and psf,
# eg a time series split into tiff 1 file per time point????
startAlgoinit = time.process_time()
# Run algorithm with observer function to track concvergence
algo = fd_restoration.RichardsonLucyDeconvolver(raw.ndim, observer_fn=observer).initialize()
# Run algorithm without observer function - much faster obvs.
#algo = fd_restoration.RichardsonLucyDeconvolver(raw.ndim).initialize()
TFinitTime = (time.process_time() - startAlgoinit)

# run the deconvolution itself
# in a loop making different numbers of iterations, multiples of base value of n_iter
multiRunFactor = 1
timingListIter = []
timingListTime = []

#send std output to a log file
sys.stdout = open('FlowDecLogGold' + str(base_iter) + 'multi' + str(multiRunFactor) + 'GAUSS.txt', 'w')
#logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)

for i in range(1, multiRunFactor+1):
niter = (base_iter*i)
# start measuring time
startDec = time.process_time()
res = algo.run(fd_data.Acquisition(data=raw, kernel=kernel), niter=(niter)).data
# measure time here includes only the deconvolution, no file saving
DecTime = (time.process_time() - startDec)
# save the result # using skimage.external.tifffile.imsave
resultFileName = ('result' + rawImg + PSF + str(niter) + 'iterations.tif')
imsave(resultFileName, res)
# measure time here includes file saving
#DecTime = (time.process_time() - startDec)
print('Saved result image TIFF file ' + resultFileName)
print(str(DecTime) + ' is how many sec ' + str(niter) + ' iterations took.')
timingListIter.append(niter)
timingListTime.append(DecTime)

#benchmarking data output
print (str(importsTime) + ' seconds to import flowdec, TF and CUDA libs etc.')
print (str(importTFMathTime) + ' seconds to import TF Math')
print (str(TFinitTime) + ' sec is the tensorFlow initialisation time')
print ('Pairs of values of iterations done vs time in seconds')
print (timingListIter)
print (timingListTime)
print('Done')
77 changes: 69 additions & 8 deletions python/flowdec/restoration.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,27 +289,88 @@ def _build_tf_graph(self):
def cond(i, decon):
return i <= niter

def conv(data, kernel_fft):
return tf.math.real(fft_rev(fft_fwd(tf.cast(data, self.fft_dtype)) * kernel_fft))

def body(i, decon):
# Richardson-Lucy Iteration - logic taken largely from a combination of
def conv(inputData, kernel_fft):
return tf.math.real(fft_rev(fft_fwd(tf.cast(inputData, self.fft_dtype)) * kernel_fft))

def gaussian_kernel(size: int,
mean: float,
std: float,
):
"""Makes 3D gaussian Kernel for convolution."""

d = tf.distributions.Normal(mean, std)

vals = d.prob(tf.range(start = -size, limit = size + 1, dtype = tf.float32))

gauss_kernel = tf.einsum('i,j,k->ijk',
vals,
vals,
vals)
# return the kernel normalised to sum =1
return gauss_kernel / tf.reduce_sum(gauss_kernel)

gaussKernel = gaussian_kernel(9, 1.0, 7.0)
# Expand dimensions of `gauss_kernel` for `tf.nn.conv3d` signature.
gaussKernel = gaussKernel[:, :, :, tf.newaxis, tf.newaxis, tf.newaxis]

def body(i, decon,):
'''# Richardson-Lucy Iteration - logic taken largely from a combination of
# the scikit-image (real domain) and DeconvolutionLab2 implementations (complex domain)
# conv1 is the current model blurred with the PSF
conv1 = conv(decon, kern_fft)

# High-pass filter to avoid division by very small numbers (see DeconvolutionLab2)
blur1 = tf.where(conv1 < self.epsilon, tf.zeros_like(datat), datat / conv1, name='blur1')
blur1 = tf.where(conv1 < self.epsilon, tf.zeros_like(datat), datat / conv1, name='blur1')

# conv2 is the blurred model convolved with the flipped PSF
conv2 = conv(blur1, kern_fft_conj)

# Positivity constraint on result for iteration
decon = tf.maximum(decon * conv2, 0.)
decon = tf.maximum(decon * conv2, 0.)
'''

# Gold algorithm, ratio method, simpler then RL, doesnt use flipped OTF
# conv1 is the current model blurred with the PSF
conv1 = conv(decon, kern_fft)

# High-pass filter to avoid division by very small numbers (see DeconvolutionLab2)?
# we wont do it here as we will use the delta parameter in denom and numerrator of division to get blur2
# as per Stephan Ludwig et al 2019
# should normalise blur2 and decon each time because numbers get big and we risk overflow when multiplying in next step
conv1norm = conv1 / (tf.math.reduce_sum(conv1))
datatNorm = datat / (tf.math.reduce_sum(datat))
# this value seems to work well fo rthe images that are normalised to sum of 1
deltaParam = 1e-4
ratio = (datatNorm + deltaParam) / (conv1norm + deltaParam)
#blur1 = tf.where(conv1 < self.epsilon, tf.zeros_like(datat), datat / conv1, name='blur1')
#ratioNorm = ratio / (tf.math.reduce_sum(ratio))
#deconNorm = decon / (tf.math.reduce_sum(decon))
# decon is the normalised blurred model multiplied by the model
# Positivity constraint on result for iteration
decon = tf.maximum(decon * ratio, 0.)
# Smooth the intermediate result image with Gaussian of sigma 1 every 5th iteration
# to control noise buildup that Gold method is succeptible to.
# Use tf.nn.conv3d to convolve a Gaussian kernel with an image:
# Make Gaussian Kernel with desired specs using gaussian_kernel function defined above
if i % 5 == 0:
# Convolve decon with gauss kernel.
tf.nn.conv3d(decon, filter=gaussKernel, strides=[1, 1, 1, 1, 1], padding="SAME")
# normalise the result so the sum of the data is 1
decon = decon / (tf.math.reduce_sum(decon))

# TODO - Smoothing every 5 iterations with gaussian or wiener.
# TODO rescale back to input data sum intensity - probably need to adjust deltaParam too.

# If given an "observer", pass the current image restoration and iteration counter to it
if self.observer_fn is not None:
# Remove any cropping that may have been added as this is usually not desirable in observers
decon_crop = unpad_around_center(decon, tf.shape(datah))
_, i, decon = tf_observer([decon_crop, i, decon], self.observer_fn)
# normalise the result so the sum of the data is 1
decon_crop = decon_crop / (tf.math.reduce_sum(decon_crop))
# we can use these captured observed tensors to evaluate eg convergence
# in eg. the observer function used.
_, i, decon, conv1 = tf_observer(
[decon_crop, i, decon, conv1], self.observer_fn)

return i + 1, decon

Expand Down