Skip to content

Commit 1edb832

Browse files
committed
Updated RMSprop
1 parent f279e0d commit 1edb832

File tree

1 file changed

+24
-15
lines changed

1 file changed

+24
-15
lines changed

src/xmipp/applications/scripts/rms_prop_reconstruction/rms_reconstruction.prop.py renamed to src/xmipp/applications/scripts/rms_prop_reconstruction/rms_prop_reconstruction.py

+24-15
Original file line numberDiff line numberDiff line change
@@ -29,43 +29,50 @@
2929

3030
import xmippPyModules.torch.image as image
3131

32-
33-
34-
def run(prev_volume_path: str,
32+
def run(map_volume_path: str,
3533
rec_volume_path: str,
36-
output_path: str,
34+
sigma2_volume_path: Optional[str],
35+
output_map_volume_path: str,
36+
output_sigma2_volume_path: str,
3737
gamma: float,
3838
nu: float,
3939
epsilon: float ):
4040

4141
# Read input volumes
42-
prev = image.read(prev_volume_path)
42+
map = image.read(map_volume_path)
4343
rec = image.read(rec_volume_path)
44+
if sigma2_volume_path is not None:
45+
sigma2 = image.read(sigma2_volume_path)
46+
else:
47+
sigma2 = torch.zeros_like(rec)
4448

4549
# Compute the gradient
46-
grad = rec - prev
50+
grad = rec - map
4751

4852
# Compute the magnitude
49-
sigma2 = gamma*torch.var(prev) + (1.0 - gamma)*torch.sum(grad**2)
53+
sigma2 *= gamma
54+
sigma2 += (1.0 - gamma)*torch.square(grad)
5055

5156
# Compute the gradient gain
5257
gain = nu / (torch.sqrt(sigma2) + epsilon)
5358

5459
# Compute the next volume
55-
next = prev + gain*grad
60+
map += gain*grad
5661

5762
# Write
58-
image.write(next, output_path)
59-
63+
image.write(map, output_map_volume_path)
64+
image.write(sigma2, output_sigma2_volume_path)
6065

6166

6267
if __name__ == '__main__':
6368
# Define the input
6469
parser = argparse.ArgumentParser(
6570
prog = 'RMS Prop reconstruction' )
66-
parser.add_argument('--prev', type=str, required=True)
71+
parser.add_argument('--map', type=str, required=True)
6772
parser.add_argument('--rec', type=str, required=True)
68-
parser.add_argument('-o', type=str, required=True)
73+
parser.add_argument('--sigma2', type=str)
74+
parser.add_argument('--omap', type=str, required=True)
75+
parser.add_argument('--osigma2', type=str, required=True)
6976
parser.add_argument('--gamma', type=float, default=0.9)
7077
parser.add_argument('--nu', type=float, default=0.001)
7178
parser.add_argument('--eps', type=float, default=1e-8)
@@ -75,10 +82,12 @@ def run(prev_volume_path: str,
7582

7683
# Run the program
7784
run(
78-
prev_volume_path = args.prev,
85+
map_volume_path = args.map,
7986
rec_volume_path = args.rec,
80-
output_path = args.o,
87+
sigma2_volume_path = args.sigma2,
88+
output_map_volume_path = args.omap,
89+
output_sigma2_volume_path = args.osigma2,
8190
gamma = args.gamma,
8291
nu = args.nu,
83-
epsilon = args.eps
92+
epsilon = args.eps
8493
)

0 commit comments

Comments
 (0)