Skip to content

Commit 06d0869

Browse files
feat(plot): add functions to plot distributions of particle data
Add `ozzy.plot.hist` and `ozzy.plot.hist_proj` to easily plot density distributions (histograms) of particle data, taking advantage of the seaborn functions [`seaborn.histplot`](https://seaborn.pydata.org/generated/seaborn.histplot.html) and [`seaborn.jointplot`](https://seaborn.pydata.org/generated/seaborn.jointplot.html). Previously it would have been necessary to bin the data first, and then plot, e.g.: ```python import ozzy as oz import ozzy.plot as oplt # A particle data Dataset ds = oz.Dataset(..., pic_data_type="part") ds_ps = ds.ozzy.get_phase_space(["p2", "x2"]) ds_ps["rho"].plot() ``` While now the following code is enough: ```python import ozzy as oz import ozzy.plot as oplt ds = oz.Dataset(..., pic_data_type='part') oplt.hist(ds, x="x2", y="p2") ```
2 parents a156af7 + 7c39207 commit 06d0869

File tree

2 files changed

+217
-2
lines changed

2 files changed

+217
-2
lines changed

docs/mkdocs.yml

+1
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ plugins:
148148
- https://pandas.pydata.org/docs/objects.inv
149149
- https://docs.h5py.org/en/stable/objects.inv
150150
- https://matplotlib.org/stable/objects.inv
151+
- https://seaborn.pydata.org/objects.inv
151152

152153
- git-revision-date-localized:
153154
enabled: true #!ENV [CI, false]

src/ozzy/plot.py

+216-2
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,6 @@ def _cmap_exists(name):
193193
"font.serif": ["Noto Serif", "Source Serif 4", "serif"],
194194
"font.sans-serif": ["Arial", "Helvetica", "sans"],
195195
"text.usetex": False,
196-
"axes.grid": False,
197196
"axes.prop_cycle": plt.cycler("color", color_wheel),
198197
"grid.color": ".9",
199198
"axes.linewidth": "0.75",
@@ -210,10 +209,13 @@ def _cmap_exists(name):
210209
"savefig.transparent": True,
211210
"savefig.dpi": "300",
212211
"savefig.bbox": "tight",
212+
"xtick.bottom": True, # draw ticks on the bottom side
213+
"ytick.left": True, # draw ticks on the left side
214+
"axes.edgecolor": "black",
213215
}
214216

215217
sns.set_theme(
216-
style="ticks",
218+
style="whitegrid",
217219
font="serif",
218220
rc=ozparams,
219221
)
@@ -1093,3 +1095,215 @@ def imovie(
10931095
)
10941096

10951097
return hvobj
1098+
1099+
1100+
def hist(
1101+
do: xr.Dataset | xr.DataArray,
1102+
x: str | None = None,
1103+
y: str | None = None,
1104+
weight_var: str | None = "q",
1105+
bins: str | int | Iterable = "auto",
1106+
cmap: str | None = "cmc.bamako",
1107+
cbar: bool = False,
1108+
**kwargs,
1109+
) -> mpl.axes.Axes:
1110+
"""Create a weighted histogram plot using [`seaborn.histplot`][seaborn.histplot].
1111+
1112+
Parameters
1113+
----------
1114+
do : xarray.Dataset | xarray.DataArray
1115+
Input Dataset or DataArray to plot
1116+
x : str | None
1117+
Variable name for x-axis
1118+
y : str | None
1119+
Variable name for y-axis
1120+
weight_var : str | None
1121+
Variable name to use as weights
1122+
bins : str | int | Iterable
1123+
Generic bin parameter passed to [`seaborn.histplot`][seaborn.histplot]. It can be `'auto'`, the number of bins, or the breaks of the bins. Defaults to `200` for weighted data or to an automatically calculated number for unweighted data.
1124+
cmap : str | None
1125+
Colormap name. Uses `'cmc.bamako'` or the `ozzy.plot` sequential default
1126+
cbar : bool
1127+
Whether to display colorbar
1128+
**kwargs
1129+
Additional keyword arguments passed to [`seaborn.histplot()`][seaborn.histplot]
1130+
1131+
Returns
1132+
-------
1133+
matplotlib.axes.Axes
1134+
The plot axes object
1135+
1136+
Examples
1137+
--------
1138+
???+ example "Basic histogram"
1139+
```python
1140+
import ozzy as oz
1141+
import ozzy.plot as oplt
1142+
ds = oz.Dataset(...)
1143+
ax = oplt.hist(ds, x='p2')
1144+
```
1145+
1146+
???+ example "2D histogram with colorbar"
1147+
```python
1148+
import ozzy as oz
1149+
import ozzy.plot as oplt
1150+
ds = oz.Dataset(...)
1151+
ax = oplt.hist(ds, x='x2', y='p2', cbar=True)
1152+
```
1153+
"""
1154+
if cmap is None:
1155+
cmap = xr.get_options()["cmap_sequential"]
1156+
1157+
cmap_opts = {}
1158+
if (x is not None) and (y is not None):
1159+
cmap_opts["cmap"] = cmap
1160+
1161+
if (weight_var is not None) and (bins == "auto"):
1162+
bins = 200
1163+
1164+
ax = sns.histplot(
1165+
do.to_dataframe(),
1166+
x=x,
1167+
y=y,
1168+
weights=weight_var,
1169+
bins=bins,
1170+
cbar=cbar,
1171+
**cmap_opts,
1172+
**kwargs,
1173+
)
1174+
1175+
if x is not None:
1176+
if "long_name" in do[x].attrs:
1177+
xlab = do[x].attrs["long_name"]
1178+
else:
1179+
xlab = x
1180+
1181+
if "units" in do[x].attrs:
1182+
xun = f" [{do[x].attrs['units']}]"
1183+
else:
1184+
xun = ""
1185+
1186+
ax.set_xlabel(xlab + xun)
1187+
1188+
if y is not None:
1189+
if "long_name" in do[y].attrs:
1190+
ylab = do[y].attrs["long_name"]
1191+
else:
1192+
ylab = y
1193+
1194+
if "units" in do[y].attrs:
1195+
yun = f" [{do[y].attrs['units']}]"
1196+
else:
1197+
yun = ""
1198+
1199+
ax.set_ylabel(ylab + yun)
1200+
1201+
return ax
1202+
1203+
1204+
def hist_proj(
1205+
do: xr.Dataset | xr.DataArray,
1206+
x: str,
1207+
y: str,
1208+
weight_var: str | None = "q",
1209+
bins: str | int | Iterable = "auto",
1210+
cmap: str | None = "cmc.bamako",
1211+
space: float = 0,
1212+
refline: bool = False,
1213+
refline_kwargs: dict = {"x": 0, "y": 0, "linewidth": 1.0, "alpha": 0.5},
1214+
**kwargs,
1215+
) -> sns.JointGrid:
1216+
"""Create a 2D histogram plot with projected distributions using [`seaborn.jointplot(kind="hist")`][seaborn.jointplot].
1217+
1218+
Parameters
1219+
----------
1220+
do : xarray.Dataset | xarray.DataArray
1221+
Input Dataset or DataArray to plot
1222+
x : str
1223+
Variable name for x-axis
1224+
y : str
1225+
Variable name for y-axis
1226+
weight_var : str | None
1227+
Variable name to use as weights
1228+
bins : str | int | Iterable
1229+
Generic bin parameter passed to [`seaborn.histplot`][seaborn.histplot]. It can be `'auto'`, the number of bins, or the breaks of the bins. Defaults to `200` for weighted data or to an automatically calculated number for unweighted data.
1230+
cmap : str | None
1231+
Colormap name. Uses `'cmc.bamako'` or the `ozzy.plot` sequential default
1232+
space : float
1233+
Space between 2D plot and marginal projection plots
1234+
refline : bool
1235+
Whether to add reference lines (see [`seaborn.JointGrid.refline`][seaborn.JointGrid.refline])
1236+
refline_kwargs : dict
1237+
Keyword arguments for reference lines (see [`seaborn.JointGrid.refline`][seaborn.JointGrid.refline])
1238+
**kwargs
1239+
Additional keyword arguments passed to [`seaborn.jointplot()`][seaborn.jointplot]
1240+
1241+
Returns
1242+
-------
1243+
seaborn.JointGrid
1244+
The joint grid plot object
1245+
1246+
Examples
1247+
--------
1248+
???+ example "2D histogram with projected distributions"
1249+
```python
1250+
import ozzy as oz
1251+
import ozzy.plot as oplt
1252+
ds = oz.Dataset(...)
1253+
jg = oplt.hist_proj(ds, x='x2', y='p2')
1254+
```
1255+
1256+
???+ example "2D histogram with projected distributions and reference lines"
1257+
```python
1258+
import ozzy as oz
1259+
import ozzy.plot as oplt
1260+
ds = oz.Dataset(...)
1261+
jg = oplt.hist_proj(ds, x='x2', y='p2',
1262+
refline=True,
1263+
refline_kwargs={'x': 0, 'y': 0})
1264+
```
1265+
"""
1266+
if cmap is None:
1267+
cmap = xr.get_options()["cmap_sequential"]
1268+
1269+
if (weight_var is not None) and (bins == "auto"):
1270+
bins = 200
1271+
1272+
jg = sns.jointplot(
1273+
do.to_dataframe(),
1274+
x=x,
1275+
y=y,
1276+
weights=weight_var,
1277+
bins=bins,
1278+
space=space,
1279+
cmap=cmap,
1280+
kind="hist",
1281+
color=mpl.colormaps[cmap](
1282+
0.0
1283+
), # choose the lower bound of the color scale as the color for the projected bins
1284+
**kwargs,
1285+
)
1286+
1287+
if refline:
1288+
jg.refline(**refline_kwargs)
1289+
1290+
lab = {}
1291+
un = {}
1292+
for var in [x, y]:
1293+
if "long_name" in do[var].attrs:
1294+
lab[var] = do[var].attrs["long_name"]
1295+
else:
1296+
lab[var] = var
1297+
if "units" in do[var].attrs:
1298+
un[var] = f" [{do[var].attrs['units']}]"
1299+
else:
1300+
un[var] = ""
1301+
1302+
jg.set_axis_labels(
1303+
xlabel=lab[x] + un[x],
1304+
ylabel=lab[y] + un[y],
1305+
)
1306+
jg.ax_marg_x.grid(False)
1307+
jg.ax_marg_y.grid(False)
1308+
1309+
return jg

0 commit comments

Comments
 (0)