Skip to content

Commit 467f478

Browse files
committed
refactor: skip plot generation for unavailable datasets
1 parent 734215f commit 467f478

1 file changed

Lines changed: 52 additions & 88 deletions

File tree

src/gwkokab/analysis/report/template_report.ipynb

Lines changed: 52 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
"outputs": [],
6161
"source": [
6262
"import glasbey\n",
63+
"import h5py\n",
6364
"import numpy as np\n",
6465
"from plotly import graph_objects as go\n",
6566
"from plotly.subplots import make_subplots\n",
@@ -80,8 +81,6 @@
8081
"source": [
8182
"def n_chains_from_hdf5(sampler_name: str, filename: str) -> int:\n",
8283
" if sampler_name == \"numpyro\":\n",
83-
" import h5py\n",
84-
"\n",
8584
" with h5py.File(filename, \"r\") as f:\n",
8685
" return len(f[\"/chains\"].keys())\n",
8786
" return int(read_attrs_from_hdf5(filename, \"sampler_cfg\")[\"n_chains\"])\n",
@@ -289,10 +288,11 @@
289288
"outputs": [],
290289
"source": [
291290
"if SAMPLER_NAME == \"flowMC\":\n",
292-
" global_acc_train = read_from_hdf5(inference_data_file, \"/acceptances/global/train\")\n",
293-
" global_acc_prod = read_from_hdf5(inference_data_file, \"/acceptances/global/prod\")\n",
294-
" local_acc_train = read_from_hdf5(inference_data_file, \"/acceptances/local/train\")\n",
295-
" local_acc_prod = read_from_hdf5(inference_data_file, \"/acceptances/local/prod\")\n",
291+
" with h5py.File(inference_data_file, \"r\") as f:\n",
292+
" global_acc_train = read_from_hdf5(f, \"/acceptances/global/train\")\n",
293+
" global_acc_prod = read_from_hdf5(f, \"/acceptances/global/prod\")\n",
294+
" local_acc_train = read_from_hdf5(f, \"/acceptances/local/train\")\n",
295+
" local_acc_prod = read_from_hdf5(f, \"/acceptances/local/prod\")\n",
296296
"\n",
297297
" color_global, color_local = glasbey.create_palette(\n",
298298
" palette_size=2, colorblind_safe=True\n",
@@ -425,43 +425,35 @@
425425
" fig.show()"
426426
]
427427
},
428-
{
429-
"cell_type": "markdown",
430-
"id": "cf7393f2",
431-
"metadata": {},
432-
"source": [
433-
"# Training Chains\n",
434-
"\n",
435-
"<div style=\"border: 1px solid #00f; background-color: #eef; padding: 10px;\">\n",
436-
" <strong>Note:</strong> If a sampler other than flowMC is used, no plots will be generated in this section.\n",
437-
"</div>"
438-
]
439-
},
440428
{
441429
"cell_type": "code",
442430
"execution_count": null,
443-
"id": "0a6c9dff",
431+
"id": "dc096bb9",
444432
"metadata": {},
445433
"outputs": [],
446434
"source": [
447-
"if SAMPLER_NAME == \"flowMC\":\n",
448-
" TRAINING_CHAINS = np.stack(\n",
449-
" [\n",
450-
" read_from_hdf5(inference_data_file, f\"/chains/train/chain_{i}/positions\")\n",
451-
" for i in range(N_CHAINS)\n",
452-
" ],\n",
453-
" axis=1,\n",
454-
" )\n",
435+
"def auxiliary_chains_plot(datapath: str, output_filename: str) -> None:\n",
436+
" try:\n",
437+
" with h5py.File(inference_data_file, \"r\") as f:\n",
438+
" chains = np.stack(\n",
439+
" [read_from_hdf5(f, datapath.format(i=i)) for i in range(N_CHAINS)],\n",
440+
" axis=1,\n",
441+
" )\n",
442+
" except Exception:\n",
443+
" return\n",
444+
"\n",
455445
" fig = make_subplots(\n",
456446
" rows=n_dims,\n",
457447
" cols=1,\n",
458448
" shared_xaxes=True,\n",
459449
" vertical_spacing=vertical_spacing,\n",
460450
" )\n",
461451
"\n",
452+
" _n_samples = chains.shape[0]\n",
453+
"\n",
462454
" for i in range(n_dims):\n",
463455
" row = i + 1\n",
464-
" data = TRAINING_CHAINS[..., i]\n",
456+
" data = chains[..., i]\n",
465457
"\n",
466458
" for c in range(N_CHAINS):\n",
467459
" show_legend = i == 0\n",
@@ -492,17 +484,41 @@
492484
" row = i + 1\n",
493485
"\n",
494486
" fig.update_yaxes(title_text=LABELS[i], **grid_style, row=row, col=1)\n",
495-
" fig.update_xaxes(range=[0, n_samples_per_chain], **grid_style, row=row, col=1)\n",
487+
" fig.update_xaxes(range=[0, _n_samples], **grid_style, row=row, col=1)\n",
496488
"\n",
497489
" fig.update_xaxes(title_text=\"Iteration\", row=n_dims, col=1)\n",
498490
"\n",
499-
" fig.write_html(\n",
500-
" \"figs/training_trace_plots.html\", include_plotlyjs=\"cdn\", full_html=True\n",
501-
" )\n",
491+
" fig.write_html(output_filename, include_plotlyjs=\"cdn\", full_html=True)\n",
502492
"\n",
503493
" fig.show()"
504494
]
505495
},
496+
{
497+
"cell_type": "markdown",
498+
"id": "cf7393f2",
499+
"metadata": {},
500+
"source": [
501+
"# Training Chains\n",
502+
"\n",
503+
"<div style=\"border: 1px solid #00f; background-color: #eef; padding: 10px;\">\n",
504+
" <strong>Note:</strong> If a sampler other than flowMC is used, or if the training chains cannot be read from the inference data file, no plots will be generated in this section.\n",
505+
"</div>"
506+
]
507+
},
508+
{
509+
"cell_type": "code",
510+
"execution_count": null,
511+
"id": "0a6c9dff",
512+
"metadata": {},
513+
"outputs": [],
514+
"source": [
515+
"if SAMPLER_NAME == \"flowMC\":\n",
516+
" auxiliary_chains_plot(\n",
517+
" datapath=\"/chains/train/chain_{i}/positions\",\n",
518+
" output_filename=\"figs/training_trace_plots.html\",\n",
519+
" )"
520+
]
521+
},
506522
{
507523
"cell_type": "markdown",
508524
"id": "f9367fe1",
@@ -523,62 +539,10 @@
523539
"outputs": [],
524540
"source": [
525541
"if SAMPLER_NAME == \"flowMC\":\n",
526-
" PRODUCTION_CHAINS = np.stack(\n",
527-
" [\n",
528-
" read_from_hdf5(inference_data_file, f\"/chains/prod/chain_{i}/positions\")\n",
529-
" for i in range(N_CHAINS)\n",
530-
" ],\n",
531-
" axis=1,\n",
532-
" )\n",
533-
" fig = make_subplots(\n",
534-
" rows=n_dims,\n",
535-
" cols=1,\n",
536-
" shared_xaxes=True,\n",
537-
" vertical_spacing=vertical_spacing,\n",
538-
" )\n",
539-
"\n",
540-
" for i in range(n_dims):\n",
541-
" row = i + 1\n",
542-
" data = PRODUCTION_CHAINS[..., i]\n",
543-
"\n",
544-
" for c in range(N_CHAINS):\n",
545-
" show_legend = i == 0\n",
546-
"\n",
547-
" fig.add_trace(\n",
548-
" go.Scatter(\n",
549-
" y=data[:, c],\n",
550-
" mode=\"lines\",\n",
551-
" line=dict(color=colors_n_chains[c], width=1.5),\n",
552-
" name=f\"Chain {c}\",\n",
553-
" legendgroup=f\"chain_{c}\",\n",
554-
" showlegend=show_legend,\n",
555-
" ),\n",
556-
" row=row,\n",
557-
" col=1,\n",
558-
" )\n",
559-
"\n",
560-
" height = max(250, n_dims * 180)\n",
561-
"\n",
562-
" fig.update_layout(\n",
563-
" height=height,\n",
564-
" plot_bgcolor=\"white\",\n",
565-
" margin=dict(l=80, r=60, t=40, b=60),\n",
566-
" legend=dict(orientation=\"h\", yanchor=\"bottom\", y=1.02, xanchor=\"left\", x=0),\n",
567-
" )\n",
568-
"\n",
569-
" for i in range(n_dims):\n",
570-
" row = i + 1\n",
571-
"\n",
572-
" fig.update_yaxes(title_text=LABELS[i], **grid_style, row=row, col=1)\n",
573-
" fig.update_xaxes(range=[0, n_samples_per_chain], **grid_style, row=row, col=1)\n",
574-
"\n",
575-
" fig.update_xaxes(title_text=\"Iteration\", row=n_dims, col=1)\n",
576-
"\n",
577-
" fig.write_html(\n",
578-
" \"figs/production_trace_plots.html\", include_plotlyjs=\"cdn\", full_html=True\n",
579-
" )\n",
580-
"\n",
581-
" fig.show()"
542+
" auxiliary_chains_plot(\n",
543+
" datapath=\"/chains/prod/chain_{i}/positions\",\n",
544+
" output_filename=\"figs/production_trace_plots.html\",\n",
545+
" )"
582546
]
583547
},
584548
{

0 commit comments

Comments
 (0)