Skip to content

Commit 0b8fa67

Browse files
authored
Merge pull request #332 from a4894z/main
multislice rPIE ready for main branch after testing Daniel's reworkng of multiGPU workflow
2 parents 3def223 + 10e0f0b commit 0b8fa67

File tree

6 files changed

+304
-64
lines changed

6 files changed

+304
-64
lines changed

src/tike/operators/cupy/fresnelspectprop.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,8 @@ def _create_fresnel_spectrum_propagator(
123123
xgrid = ( 0.5 + self.xp.linspace( ( -0.5 * N[1] ), ( 0.5 * N[1] - 1 ), num = N[1] )) / N[1]
124124
ygrid = ( 0.5 + self.xp.linspace( ( -0.5 * N[0] ), ( 0.5 * N[0] - 1 ), num = N[0] )) / N[0]
125125

126-
kx = 2 * self.xp.pi * N[0] * xgrid / probe_FOV[ 0 ]
127-
ky = 2 * self.xp.pi * N[1] * ygrid / probe_FOV[ 1 ]
126+
kx = 2 * self.xp.pi * N[1] * xgrid / probe_FOV[ 1 ]
127+
ky = 2 * self.xp.pi * N[0] * ygrid / probe_FOV[ 0 ]
128128

129129
Kx, Ky = self.xp.meshgrid(kx, ky, indexing='xy')
130130

src/tike/operators/cupy/multislice.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,11 @@ def fwd(
7373
psi: npt.NDArray[np.csingle],
7474
**kwargs,
7575
) -> npt.NDArray[np.csingle]:
76+
7677
"""Please see help(Multislice) for more info."""
78+
7779
assert psi.ndim == 3
80+
7881
exitwave = self.diffraction.fwd(
7982
psi=psi[0],
8083
scan=scan,
@@ -88,6 +91,56 @@ def fwd(
8891
)
8992
return exitwave
9093

94+
95+
96+
97+
def fwd_return_intermediate_probes(
98+
self,
99+
probe: npt.NDArray[np.csingle],
100+
scan: npt.NDArray[np.single],
101+
psi: npt.NDArray[np.csingle],
102+
**kwargs,
103+
) -> npt.NDArray[np.csingle]:
104+
105+
"""Please see help(Multislice) for more info."""
106+
107+
assert psi.ndim == 3
108+
109+
110+
# exitwave = self.diffraction.fwd(
111+
# psi=psi[0],
112+
# scan=scan,
113+
# probe=probe,
114+
# )
115+
116+
# for s in range(1, len(psi)):
117+
# exitwave = self.diffraction.fwd(
118+
# psi=psi[s],
119+
# scan=scan,
120+
# probe=self.propagation.fwd(exitwave),
121+
# )
122+
123+
# return exitwave
124+
125+
multislice_probes = self.xp.zeros( ( psi.shape[0], scan.shape[-2], *probe.shape[-3:] ), dtype=probe.dtype )
126+
multislice_probes[ 0, ... ] = probe[..., 0, :, :, :]
127+
128+
for tt in range(0, len(psi)) :
129+
130+
multislice_exwv = self.diffraction.fwd(
131+
psi = psi[ tt, ... ],
132+
scan = scan,
133+
probe = multislice_probes[ tt, ... ],
134+
)
135+
136+
if tt == ( psi.shape[0] - 1 ) :
137+
break
138+
139+
multislice_probes[ tt + 1, ... ] = self.propagation.fwd( nearplane = multislice_exwv, )
140+
141+
return multislice_exwv, multislice_probes
142+
143+
91144
def adj(
92145
self,
93146
nearplane: npt.NDArray[np.csingle],

src/tike/operators/cupy/ptycho.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,31 @@ def fwd(
128128
overwrite=True,
129129
)[..., None, :, :, :]
130130

131+
def fwd_return_intermediate_probes(
132+
self,
133+
probe: npt.NDArray[np.csingle],
134+
scan: npt.NDArray[np.single],
135+
psi: npt.NDArray[np.csingle],
136+
**kwargs,
137+
) -> npt.NDArray[np.csingle]:
138+
139+
"""Please see help(Ptycho) for more info."""
140+
141+
# return self.propagation.fwd(
142+
# self.diffraction.fwd(
143+
# psi=psi,
144+
# scan=scan,
145+
# probe=probe[..., 0, :, :, :],
146+
# ),
147+
# overwrite=True,
148+
# )[..., None, :, :, :]
149+
150+
multislice_exwv, multislice_probes = self.diffraction.fwd_return_intermediate_probes( psi=psi, scan=scan, probe=probe, )
151+
152+
multislice_farfield = self.propagation.fwd( nearplane = multislice_exwv, overwrite=True, )[..., None, :, :, :]
153+
154+
return multislice_farfield, multislice_probes
155+
131156
def adj(
132157
self,
133158
farplane: npt.NDArray[np.csingle],

src/tike/ptycho/ptycho.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,7 @@ def __init__(
336336
"automatic psi initialization is not synchronized "
337337
"across processes.")
338338
else:
339-
mpi = tike.communicators.NoMPIComm
339+
mpi = tike.communicators.NoMPIComm # isinstance(mpi, tike.communicators.NoMPIComm) # THIS ALWAYS RETURNS FALSE?
340340

341341
self.data: typing.List[npt.ArrayLike] = [data]
342342
self.parameters: typing.List[solvers.PtychoParameters] = [
@@ -354,7 +354,8 @@ def __init__(
354354
probe_FOV_lengths=parameters.probe_options.probe_FOV_lengths,
355355
multislice_propagation_distance=parameters.object_options.multislice_propagation_distance,
356356
)
357-
self.comm = tike.communicators.Comm(num_gpu, mpi)
357+
358+
self.comm = tike.communicators.Comm(num_gpu, mpi)
358359

359360
def __enter__(self):
360361
self.device.__enter__()

src/tike/ptycho/solvers/_preconditioner.py

Lines changed: 50 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -11,22 +11,29 @@
1111

1212

1313
def _rolling_average_object(parameters: PtychoParameters, new):
14-
if parameters.object_options.preconditioner is None:
15-
parameters.object_options.preconditioner = new
16-
else:
17-
parameters.object_options.preconditioner = 0.5 * (
18-
new + parameters.object_options.preconditioner
19-
)
14+
15+
# if parameters.object_options.preconditioner is None:
16+
# parameters.object_options.preconditioner = new
17+
# else:
18+
# parameters.object_options.preconditioner = 0.5 * (
19+
# new + parameters.object_options.preconditioner
20+
# )
21+
22+
parameters.object_options.preconditioner = new
23+
2024
return parameters
2125

2226

2327
def _rolling_average_probe(parameters: PtychoParameters, new):
24-
if parameters.probe_options.preconditioner is None:
25-
parameters.probe_options.preconditioner = new
26-
else:
27-
parameters.probe_options.preconditioner = 0.5 * (
28-
new + parameters.probe_options.preconditioner
29-
)
28+
29+
# if parameters.probe_options.preconditioner is None:
30+
# parameters.probe_options.preconditioner = new
31+
# else:
32+
# parameters.probe_options.preconditioner = 0.5 * (
33+
# new + parameters.probe_options.preconditioner
34+
# )
35+
36+
parameters.probe_options.preconditioner = new
3037
return parameters
3138

3239

@@ -55,24 +62,31 @@ def make_certain_args_constant(
5562
lo: int,
5663
hi: int,
5764
) -> None:
65+
5866
nonlocal psi_update_denominator
5967

6068
probe_amp = _probe_amp_sum(parameters.probe)[:, 0]
69+
6170
psi_update_denominator[0] = operator.diffraction.patch.adj(
6271
patches=probe_amp,
6372
images=psi_update_denominator[0],
6473
positions=parameters.scan[lo:hi],
6574
)
6675

6776
probe1 = parameters.probe[:, 0]
77+
6878
for i in range(1, len(parameters.psi)):
79+
6980
probe1 = operator.diffraction.diffraction.fwd(
7081
probe=probe1,
7182
scan=parameters.scan[lo:hi],
7283
psi=parameters.psi[i-1],
7384
)
85+
7486
probe1 = operator.diffraction.propagation.fwd(probe1)
87+
7588
probe_amp = _probe_amp_sum(probe1)
89+
7690
psi_update_denominator[i] = operator.diffraction.patch.adj(
7791
patches=probe_amp,
7892
images=psi_update_denominator[i],
@@ -107,7 +121,7 @@ def _probe_preconditioner(
107121
) -> npt.NDArray:
108122

109123
probe_update_denominator = cp.zeros(
110-
shape=parameters.probe.shape[-2:],
124+
shape=( parameters.psi.shape[0], *parameters.probe.shape[-2:] ),
111125
dtype=parameters.probe.dtype,
112126
)
113127

@@ -116,16 +130,31 @@ def make_certain_args_constant(
116130
lo: int,
117131
hi: int,
118132
) -> None:
133+
119134
nonlocal probe_update_denominator
120135

121-
# FIXME: Only use the first slice for the probe preconditioner
122-
patches = operator.diffraction.patch.fwd(
123-
images=parameters.psi[0],
124-
positions=parameters.scan[lo:hi],
125-
patch_width=parameters.probe.shape[-1],
126-
)
127-
probe_update_denominator[...] += _patch_amp_sum(patches)
128-
assert probe_update_denominator.ndim == 2
136+
for i in range(0, len(parameters.psi)):
137+
138+
patches = operator.diffraction.patch.fwd(
139+
images=parameters.psi[ i, ... ],
140+
positions=parameters.scan[lo:hi],
141+
patch_width=parameters.probe.shape[-1],
142+
)
143+
144+
probe_update_denominator[ i, ...] += _patch_amp_sum(patches)
145+
146+
147+
assert probe_update_denominator.ndim == 3
148+
149+
# patches = operator.diffraction.patch.fwd(
150+
# images=parameters.psi[0],
151+
# positions=parameters.scan[lo:hi],
152+
# patch_width=parameters.probe.shape[-1],
153+
# )
154+
155+
# probe_update_denominator[...] += _patch_amp_sum(patches)
156+
157+
# assert probe_update_denominator.ndim == 2
129158

130159
tike.communicators.stream.stream_and_modify2(
131160
f=make_certain_args_constant,

0 commit comments

Comments
 (0)