|
60 | 60 | "outputs": [], |
61 | 61 | "source": [ |
62 | 62 | "import glasbey\n", |
| 63 | + "import h5py\n", |
63 | 64 | "import numpy as np\n", |
64 | 65 | "from plotly import graph_objects as go\n", |
65 | 66 | "from plotly.subplots import make_subplots\n", |
|
80 | 81 | "source": [ |
81 | 82 | "def n_chains_from_hdf5(sampler_name: str, filename: str) -> int:\n", |
82 | 83 | " if sampler_name == \"numpyro\":\n", |
83 | | - " import h5py\n", |
84 | | - "\n", |
85 | 84 | " with h5py.File(filename, \"r\") as f:\n", |
86 | 85 | " return len(f[\"/chains\"].keys())\n", |
87 | 86 | " return int(read_attrs_from_hdf5(filename, \"sampler_cfg\")[\"n_chains\"])\n", |
|
289 | 288 | "outputs": [], |
290 | 289 | "source": [ |
291 | 290 | "if SAMPLER_NAME == \"flowMC\":\n", |
292 | | - " global_acc_train = read_from_hdf5(inference_data_file, \"/acceptances/global/train\")\n", |
293 | | - " global_acc_prod = read_from_hdf5(inference_data_file, \"/acceptances/global/prod\")\n", |
294 | | - " local_acc_train = read_from_hdf5(inference_data_file, \"/acceptances/local/train\")\n", |
295 | | - " local_acc_prod = read_from_hdf5(inference_data_file, \"/acceptances/local/prod\")\n", |
| 291 | + " with h5py.File(inference_data_file, \"r\") as f:\n", |
| 292 | + " global_acc_train = read_from_hdf5(f, \"/acceptances/global/train\")\n", |
| 293 | + " global_acc_prod = read_from_hdf5(f, \"/acceptances/global/prod\")\n", |
| 294 | + " local_acc_train = read_from_hdf5(f, \"/acceptances/local/train\")\n", |
| 295 | + " local_acc_prod = read_from_hdf5(f, \"/acceptances/local/prod\")\n", |
296 | 296 | "\n", |
297 | 297 | " color_global, color_local = glasbey.create_palette(\n", |
298 | 298 | " palette_size=2, colorblind_safe=True\n", |
|
425 | 425 | " fig.show()" |
426 | 426 | ] |
427 | 427 | }, |
428 | | - { |
429 | | - "cell_type": "markdown", |
430 | | - "id": "cf7393f2", |
431 | | - "metadata": {}, |
432 | | - "source": [ |
433 | | - "# Training Chains\n", |
434 | | - "\n", |
435 | | - "<div style=\"border: 1px solid #00f; background-color: #eef; padding: 10px;\">\n", |
436 | | - " <strong>Note:</strong> If a sampler other than flowMC is used, no plots will be generated in this section.\n", |
437 | | - "</div>" |
438 | | - ] |
439 | | - }, |
440 | 428 | { |
441 | 429 | "cell_type": "code", |
442 | 430 | "execution_count": null, |
443 | | - "id": "0a6c9dff", |
| 431 | + "id": "dc096bb9", |
444 | 432 | "metadata": {}, |
445 | 433 | "outputs": [], |
446 | 434 | "source": [ |
447 | | - "if SAMPLER_NAME == \"flowMC\":\n", |
448 | | - " TRAINING_CHAINS = np.stack(\n", |
449 | | - " [\n", |
450 | | - " read_from_hdf5(inference_data_file, f\"/chains/train/chain_{i}/positions\")\n", |
451 | | - " for i in range(N_CHAINS)\n", |
452 | | - " ],\n", |
453 | | - " axis=1,\n", |
454 | | - " )\n", |
| 435 | + "def auxiliary_chains_plot(datapath: str, output_filename: str) -> None:\n", |
| 436 | + " try:\n", |
| 437 | + " with h5py.File(inference_data_file, \"r\") as f:\n", |
| 438 | + " chains = np.stack(\n", |
| 439 | + " [read_from_hdf5(f, datapath.format(i=i)) for i in range(N_CHAINS)],\n", |
| 440 | + " axis=1,\n", |
| 441 | + " )\n", |
| 442 | + " except Exception:\n", |
| 443 | + " return\n", |
| 444 | + "\n", |
455 | 445 | " fig = make_subplots(\n", |
456 | 446 | " rows=n_dims,\n", |
457 | 447 | " cols=1,\n", |
458 | 448 | " shared_xaxes=True,\n", |
459 | 449 | " vertical_spacing=vertical_spacing,\n", |
460 | 450 | " )\n", |
461 | 451 | "\n", |
| 452 | + " _n_samples = chains.shape[0]\n", |
| 453 | + "\n", |
462 | 454 | " for i in range(n_dims):\n", |
463 | 455 | " row = i + 1\n", |
464 | | - " data = TRAINING_CHAINS[..., i]\n", |
| 456 | + " data = chains[..., i]\n", |
465 | 457 | "\n", |
466 | 458 | " for c in range(N_CHAINS):\n", |
467 | 459 | " show_legend = i == 0\n", |
|
492 | 484 | " row = i + 1\n", |
493 | 485 | "\n", |
494 | 486 | " fig.update_yaxes(title_text=LABELS[i], **grid_style, row=row, col=1)\n", |
495 | | - " fig.update_xaxes(range=[0, n_samples_per_chain], **grid_style, row=row, col=1)\n", |
| 487 | + " fig.update_xaxes(range=[0, _n_samples], **grid_style, row=row, col=1)\n", |
496 | 488 | "\n", |
497 | 489 | " fig.update_xaxes(title_text=\"Iteration\", row=n_dims, col=1)\n", |
498 | 490 | "\n", |
499 | | - " fig.write_html(\n", |
500 | | - " \"figs/training_trace_plots.html\", include_plotlyjs=\"cdn\", full_html=True\n", |
501 | | - " )\n", |
| 491 | + " fig.write_html(output_filename, include_plotlyjs=\"cdn\", full_html=True)\n", |
502 | 492 | "\n", |
503 | 493 | " fig.show()" |
504 | 494 | ] |
505 | 495 | }, |
| 496 | + { |
| 497 | + "cell_type": "markdown", |
| 498 | + "id": "cf7393f2", |
| 499 | + "metadata": {}, |
| 500 | + "source": [ |
| 501 | + "# Training Chains\n", |
| 502 | + "\n", |
| 503 | + "<div style=\"border: 1px solid #00f; background-color: #eef; padding: 10px;\">\n", |
| 504 | + " <strong>Note:</strong> If a sampler other than flowMC is used, no plots will be generated in this section.\n", |
| 505 | + "</div>" |
| 506 | + ] |
| 507 | + }, |
| 508 | + { |
| 509 | + "cell_type": "code", |
| 510 | + "execution_count": null, |
| 511 | + "id": "0a6c9dff", |
| 512 | + "metadata": {}, |
| 513 | + "outputs": [], |
| 514 | + "source": [ |
| 515 | + "if SAMPLER_NAME == \"flowMC\":\n", |
| 516 | + " auxiliary_chains_plot(\n", |
| 517 | + " datapath=\"/chains/train/chain_{i}/positions\",\n", |
| 518 | + " output_filename=\"figs/training_trace_plots.html\",\n", |
| 519 | + " )" |
| 520 | + ] |
| 521 | + }, |
506 | 522 | { |
507 | 523 | "cell_type": "markdown", |
508 | 524 | "id": "f9367fe1", |
|
523 | 539 | "outputs": [], |
524 | 540 | "source": [ |
525 | 541 | "if SAMPLER_NAME == \"flowMC\":\n", |
526 | | - " PRODUCTION_CHAINS = np.stack(\n", |
527 | | - " [\n", |
528 | | - " read_from_hdf5(inference_data_file, f\"/chains/prod/chain_{i}/positions\")\n", |
529 | | - " for i in range(N_CHAINS)\n", |
530 | | - " ],\n", |
531 | | - " axis=1,\n", |
532 | | - " )\n", |
533 | | - " fig = make_subplots(\n", |
534 | | - " rows=n_dims,\n", |
535 | | - " cols=1,\n", |
536 | | - " shared_xaxes=True,\n", |
537 | | - " vertical_spacing=vertical_spacing,\n", |
538 | | - " )\n", |
539 | | - "\n", |
540 | | - " for i in range(n_dims):\n", |
541 | | - " row = i + 1\n", |
542 | | - " data = PRODUCTION_CHAINS[..., i]\n", |
543 | | - "\n", |
544 | | - " for c in range(N_CHAINS):\n", |
545 | | - " show_legend = i == 0\n", |
546 | | - "\n", |
547 | | - " fig.add_trace(\n", |
548 | | - " go.Scatter(\n", |
549 | | - " y=data[:, c],\n", |
550 | | - " mode=\"lines\",\n", |
551 | | - " line=dict(color=colors_n_chains[c], width=1.5),\n", |
552 | | - " name=f\"Chain {c}\",\n", |
553 | | - " legendgroup=f\"chain_{c}\",\n", |
554 | | - " showlegend=show_legend,\n", |
555 | | - " ),\n", |
556 | | - " row=row,\n", |
557 | | - " col=1,\n", |
558 | | - " )\n", |
559 | | - "\n", |
560 | | - " height = max(250, n_dims * 180)\n", |
561 | | - "\n", |
562 | | - " fig.update_layout(\n", |
563 | | - " height=height,\n", |
564 | | - " plot_bgcolor=\"white\",\n", |
565 | | - " margin=dict(l=80, r=60, t=40, b=60),\n", |
566 | | - " legend=dict(orientation=\"h\", yanchor=\"bottom\", y=1.02, xanchor=\"left\", x=0),\n", |
567 | | - " )\n", |
568 | | - "\n", |
569 | | - " for i in range(n_dims):\n", |
570 | | - " row = i + 1\n", |
571 | | - "\n", |
572 | | - " fig.update_yaxes(title_text=LABELS[i], **grid_style, row=row, col=1)\n", |
573 | | - " fig.update_xaxes(range=[0, n_samples_per_chain], **grid_style, row=row, col=1)\n", |
574 | | - "\n", |
575 | | - " fig.update_xaxes(title_text=\"Iteration\", row=n_dims, col=1)\n", |
576 | | - "\n", |
577 | | - " fig.write_html(\n", |
578 | | - " \"figs/production_trace_plots.html\", include_plotlyjs=\"cdn\", full_html=True\n", |
579 | | - " )\n", |
580 | | - "\n", |
581 | | - " fig.show()" |
| 542 | + " auxiliary_chains_plot(\n", |
| 543 | + " datapath=\"/chains/prod/chain_{i}/positions\",\n", |
| 544 | + " output_filename=\"figs/production_trace_plots.html\",\n", |
| 545 | + " )" |
582 | 546 | ] |
583 | 547 | }, |
584 | 548 | { |
|
0 commit comments