Skip to content

Commit b9003bf

Browse files
authored
Added a histogram option for plotting. (cctbx#1154)
Authored-by: David Mittan-Moreau <dwmoreau@lbl.gov>
1 parent 691c594 commit b9003bf

1 file changed

Lines changed: 89 additions & 19 deletions

File tree

xfel/small_cell/command_line/cake_plot.py

Lines changed: 89 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,46 @@
1919
from dials.util.options import ArgumentParser
2020
from dxtbx.model.experiment_list import ExperimentList
2121
import matplotlib.pyplot as plt
22+
from matplotlib.colors import LogNorm
2223
from matplotlib.ticker import FuncFormatter
2324
import numpy as np
2425
import sys
2526

2627
phil_str = """
2728
mp {
28-
method = *serial mpi multiprocessing
29+
method = *multiprocessing mpi
2930
.type = choice
30-
.help = Parallelization method
31+
.help = Parallelization method. When method=multiprocessing and nproc=1, runs serially.
3132
nproc = 1
3233
.type = int
33-
.help = Number of processes for multiprocessing method
34+
.help = Number of processes. nproc=1 (default) runs serially with no pool overhead.
35+
}
36+
plot {
37+
method = *histogram scatter
38+
.type = choice
39+
.help = Plotting method: 2D histogram (default) or scatter plot
40+
d_min = None
41+
.type = float
42+
.help = Low-resolution cutoff in Angstroms (largest d-spacing shown). Controls left x-axis limit.
43+
d_max = None
44+
.type = float
45+
.help = High-resolution cutoff in Angstroms (smallest d-spacing shown). Controls right x-axis limit.
46+
scatter {
47+
spotsize = 0.5
48+
.type = float
49+
.help = Marker size for scatter plot points (matplotlib s parameter)
50+
alpha = 0.5
51+
.type = float
52+
.help = Transparency of scatter plot points (0=transparent, 1=opaque)
53+
}
54+
histogram {
55+
n_bins_radial = 100
56+
.type = int
57+
.help = Number of bins along the x-axis (1/d, radial direction)
58+
n_bins_azimuthal = 100
59+
.type = int
60+
.help = Number of bins along the y-axis (azimuthal angle direction)
61+
}
3462
}
3563
"""
3664

@@ -117,7 +145,7 @@ def extract_panel_data(experiments, reflections, params=None):
117145
118146
Returns a dict ``panel_id -> {'d': list/array, 'azi': list/array}``.
119147
"""
120-
if params is None or params.mp.method == 'serial':
148+
if params is None or (params.mp.method == 'multiprocessing' and params.mp.nproc == 1):
121149
return _extract_panel_data_serial(experiments, reflections)
122150
elif params.mp.method == 'mpi':
123151
return _extract_panel_data_mpi(experiments, reflections)
@@ -184,31 +212,73 @@ def _extract_panel_data_mp(experiments, reflections, nproc):
184212
return merge_panel_data(all_results)
185213

186214

187-
def plot_panel_data(panel_data):
215+
def plot_panel_data(panel_data, params=None):
188216
"""Create the cake plot.
189217
190218
``panel_data`` is the dict returned by ``extract_panel_data``.
191219
The figure is displayed interactively.
192220
"""
193-
cmap = plt.get_cmap('tab20')
221+
plot_params = params.plot if params is not None else None
222+
method = plot_params.method if plot_params is not None else 'histogram'
223+
d_min = plot_params.d_min if plot_params is not None else None
224+
d_max = plot_params.d_max if plot_params is not None else None
225+
194226
fig, ax = plt.subplots(figsize=(6, 3))
195-
for i, (panel_id, data) in enumerate(sorted(panel_data.items())):
196-
d_arr = np.array(data['d'])
197-
azi_arr = np.array(data['azi'])
198-
mask = d_arr > 0
199-
if not mask.any():
200-
continue
201-
x = 1.0 / d_arr[mask] # plotted values: 1/d (1/Å)
202-
y = azi_arr[mask]
203-
ax.scatter(x, y, s=0.5, alpha=0.5, color=cmap(i % 20), label=f'Panel {panel_id}')
204-
ax.set_ylabel('Azimuthal Angle (deg)')
205-
ax.set_xlabel('Resolution (Å)')
206-
# Format x‑axis to show resolution instead of 1/d
227+
207228
def resolution_formatter(x, pos):
208229
if x == 0:
209230
return '-'
210231
return f"{1/x:.2f}"
232+
233+
if method == 'histogram':
234+
# Combine all panels into a single 2D histogram
235+
all_d = np.concatenate([np.array(data['d']) for data in panel_data.values()])
236+
all_azi = np.concatenate([np.array(data['azi']) for data in panel_data.values()])
237+
mask = all_d > 0
238+
if d_min is not None:
239+
mask &= all_d <= d_min
240+
if d_max is not None:
241+
mask &= all_d >= d_max
242+
x = 1.0 / all_d[mask]
243+
y = all_azi[mask]
244+
n_bins_r = plot_params.histogram.n_bins_radial if plot_params is not None else 100
245+
n_bins_a = plot_params.histogram.n_bins_azimuthal if plot_params is not None else 100
246+
cmap = plt.get_cmap('binary').copy()
247+
h = ax.hist2d(x, y, bins=[n_bins_r, n_bins_a], cmap=cmap, norm=LogNorm(vmin=1))
248+
fig.colorbar(h[3], ax=ax, label='Counts')
249+
else:
250+
# Scatter mode: per-panel coloring
251+
cmap = plt.get_cmap('tab20')
252+
spotsize = plot_params.scatter.spotsize if plot_params is not None else 0.5
253+
alpha = plot_params.scatter.alpha if plot_params is not None else 0.5
254+
for i, (panel_id, data) in enumerate(sorted(panel_data.items())):
255+
d_arr = np.array(data['d'])
256+
azi_arr = np.array(data['azi'])
257+
mask = d_arr > 0
258+
if d_min is not None:
259+
mask &= d_arr <= d_min
260+
if d_max is not None:
261+
mask &= d_arr >= d_max
262+
if not mask.any():
263+
continue
264+
x = 1.0 / d_arr[mask]
265+
y = azi_arr[mask]
266+
ax.scatter(x, y, s=spotsize, alpha=alpha, color=cmap(i % 20), label=f'Panel {panel_id}')
267+
268+
ax.set_ylabel('Azimuthal Angle (deg)')
269+
ax.set_xlabel('Resolution (Å)')
211270
ax.xaxis.set_major_formatter(FuncFormatter(resolution_formatter))
271+
272+
# Apply resolution limits: d_min (low-res) → left x-limit, d_max (high-res) → right x-limit
273+
xlim_left = 1.0 / d_min if d_min is not None else None
274+
xlim_right = 1.0 / d_max if d_max is not None else None
275+
if xlim_left is not None or xlim_right is not None:
276+
current = ax.get_xlim()
277+
ax.set_xlim(
278+
xlim_left if xlim_left is not None else current[0],
279+
xlim_right if xlim_right is not None else current[1],
280+
)
281+
212282
fig.tight_layout()
213283
plt.show()
214284

@@ -273,7 +343,7 @@ def run(args=None):
273343

274344
panel_data = extract_panel_data(experiments, reflections, params)
275345
if is_rank0 and panel_data is not None:
276-
plot_panel_data(panel_data)
346+
plot_panel_data(panel_data, params)
277347

278348
if __name__ == '__main__':
279349
run()

0 commit comments

Comments
 (0)