1111
1212
1313def _rolling_average_object (parameters : PtychoParameters , new ):
14- if parameters .object_options .preconditioner is None :
15- parameters .object_options .preconditioner = new
16- else :
17- parameters .object_options .preconditioner = 0.5 * (
18- new + parameters .object_options .preconditioner
19- )
14+
15+ # if parameters.object_options.preconditioner is None:
16+ # parameters.object_options.preconditioner = new
17+ # else:
18+ # parameters.object_options.preconditioner = 0.5 * (
19+ # new + parameters.object_options.preconditioner
20+ # )
21+
22+ parameters .object_options .preconditioner = new
23+
2024 return parameters
2125
2226
2327def _rolling_average_probe (parameters : PtychoParameters , new ):
24- if parameters .probe_options .preconditioner is None :
25- parameters .probe_options .preconditioner = new
26- else :
27- parameters .probe_options .preconditioner = 0.5 * (
28- new + parameters .probe_options .preconditioner
29- )
28+
29+ # if parameters.probe_options.preconditioner is None:
30+ # parameters.probe_options.preconditioner = new
31+ # else:
32+ # parameters.probe_options.preconditioner = 0.5 * (
33+ # new + parameters.probe_options.preconditioner
34+ # )
35+
36+ parameters .probe_options .preconditioner = new
3037 return parameters
3138
3239
@@ -55,24 +62,31 @@ def make_certain_args_constant(
5562 lo : int ,
5663 hi : int ,
5764 ) -> None :
65+
5866 nonlocal psi_update_denominator
5967
6068 probe_amp = _probe_amp_sum (parameters .probe )[:, 0 ]
69+
6170 psi_update_denominator [0 ] = operator .diffraction .patch .adj (
6271 patches = probe_amp ,
6372 images = psi_update_denominator [0 ],
6473 positions = parameters .scan [lo :hi ],
6574 )
6675
6776 probe1 = parameters .probe [:, 0 ]
77+
6878 for i in range (1 , len (parameters .psi )):
79+
6980 probe1 = operator .diffraction .diffraction .fwd (
7081 probe = probe1 ,
7182 scan = parameters .scan [lo :hi ],
7283 psi = parameters .psi [i - 1 ],
7384 )
85+
7486 probe1 = operator .diffraction .propagation .fwd (probe1 )
87+
7588 probe_amp = _probe_amp_sum (probe1 )
89+
7690 psi_update_denominator [i ] = operator .diffraction .patch .adj (
7791 patches = probe_amp ,
7892 images = psi_update_denominator [i ],
@@ -107,7 +121,7 @@ def _probe_preconditioner(
107121) -> npt .NDArray :
108122
109123 probe_update_denominator = cp .zeros (
110- shape = parameters .probe .shape [- 2 :],
124+ shape = ( parameters .psi . shape [ 0 ], * parameters . probe .shape [- 2 :] ) ,
111125 dtype = parameters .probe .dtype ,
112126 )
113127
@@ -116,16 +130,31 @@ def make_certain_args_constant(
116130 lo : int ,
117131 hi : int ,
118132 ) -> None :
133+
119134 nonlocal probe_update_denominator
120135
121- # FIXME: Only use the first slice for the probe preconditioner
122- patches = operator .diffraction .patch .fwd (
123- images = parameters .psi [0 ],
124- positions = parameters .scan [lo :hi ],
125- patch_width = parameters .probe .shape [- 1 ],
126- )
127- probe_update_denominator [...] += _patch_amp_sum (patches )
128- assert probe_update_denominator .ndim == 2
136+ for i in range (0 , len (parameters .psi )):
137+
138+ patches = operator .diffraction .patch .fwd (
139+ images = parameters .psi [ i , ... ],
140+ positions = parameters .scan [lo :hi ],
141+ patch_width = parameters .probe .shape [- 1 ],
142+ )
143+
144+ probe_update_denominator [ i , ...] += _patch_amp_sum (patches )
145+
146+
147+ assert probe_update_denominator .ndim == 3
148+
149+ # patches = operator.diffraction.patch.fwd(
150+ # images=parameters.psi[0],
151+ # positions=parameters.scan[lo:hi],
152+ # patch_width=parameters.probe.shape[-1],
153+ # )
154+
155+ # probe_update_denominator[...] += _patch_amp_sum(patches)
156+
157+ # assert probe_update_denominator.ndim == 2
129158
130159 tike .communicators .stream .stream_and_modify2 (
131160 f = make_certain_args_constant ,
0 commit comments