29
29
30
30
import xmippPyModules .torch .image as image
31
31
32
-
33
-
34
- def run (prev_volume_path : str ,
32
+ def run (map_volume_path : str ,
35
33
rec_volume_path : str ,
36
- output_path : str ,
34
+ sigma2_volume_path : Optional [str ],
35
+ output_map_volume_path : str ,
36
+ output_sigma2_volume_path : str ,
37
37
gamma : float ,
38
38
nu : float ,
39
39
epsilon : float ):
40
40
41
41
# Read input volumes
42
- prev = image .read (prev_volume_path )
42
+ map = image .read (map_volume_path )
43
43
rec = image .read (rec_volume_path )
44
+ if sigma2_volume_path is not None :
45
+ sigma2 = image .read (sigma2_volume_path )
46
+ else :
47
+ sigma2 = torch .zeros_like (rec )
44
48
45
49
# Compute the gradient
46
- grad = rec - prev
50
+ grad = rec - map
47
51
48
52
# Compute the magnitude
49
- sigma2 = gamma * torch .var (prev ) + (1.0 - gamma )* torch .sum (grad ** 2 )
53
+ sigma2 *= gamma
54
+ sigma2 += (1.0 - gamma )* torch .square (grad )
50
55
51
56
# Compute the gradient gain
52
57
gain = nu / (torch .sqrt (sigma2 ) + epsilon )
53
58
54
59
# Compute the next volume
55
- next = prev + gain * grad
60
+ map += gain * grad
56
61
57
62
# Write
58
- image .write (next , output_path )
59
-
63
+ image .write (map , output_map_volume_path )
64
+ image . write ( sigma2 , output_sigma2_volume_path )
60
65
61
66
62
67
if __name__ == '__main__' :
63
68
# Define the input
64
69
parser = argparse .ArgumentParser (
65
70
prog = 'RMS Prop reconstruction' )
66
- parser .add_argument ('--prev ' , type = str , required = True )
71
+ parser .add_argument ('--map ' , type = str , required = True )
67
72
parser .add_argument ('--rec' , type = str , required = True )
68
- parser .add_argument ('-o' , type = str , required = True )
73
+ parser .add_argument ('--sigma2' , type = str )
74
+ parser .add_argument ('--omap' , type = str , required = True )
75
+ parser .add_argument ('--osigma2' , type = str , required = True )
69
76
parser .add_argument ('--gamma' , type = float , default = 0.9 )
70
77
parser .add_argument ('--nu' , type = float , default = 0.001 )
71
78
parser .add_argument ('--eps' , type = float , default = 1e-8 )
@@ -75,10 +82,12 @@ def run(prev_volume_path: str,
75
82
76
83
# Run the program
77
84
run (
78
- prev_volume_path = args .prev ,
85
+ map_volume_path = args .map ,
79
86
rec_volume_path = args .rec ,
80
- output_path = args .o ,
87
+ sigma2_volume_path = args .sigma2 ,
88
+ output_map_volume_path = args .omap ,
89
+ output_sigma2_volume_path = args .osigma2 ,
81
90
gamma = args .gamma ,
82
91
nu = args .nu ,
83
- epsilon = args .eps
92
+ epsilon = args .eps
84
93
)
0 commit comments