
Commit 1718fcf

Add GitHub CI workflows. (#5)

* Move notebook to notebooks folder, additional pre-commit hooks
* Run pre-commit hooks on previous changes
* Add GitHub CI workflows
* Update CI action versions
* Add nbdev to requirements
* Update pre-commit workflow

1 parent 9085894 commit 1718fcf

File tree

11 files changed: +231 -347 lines

.github/workflows/pre-commit.yaml

Lines changed: 22 additions & 0 deletions (new file)

name: Run pre-commit hooks

on:
  pull_request:
    branches: [ main ]
  push:
    branches: [ main ]

jobs:
  pre-commit:
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v3
      - uses: actions/setup-python@v4
        with:
          python-version: 3.11
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install nbdev
      - uses: pre-commit/[email protected]

.pre-commit-config.yaml

Lines changed: 8 additions & 0 deletions

@@ -16,3 +16,11 @@ repos:
     hooks:
       - id: nbdev_clean
         args: [--fname=notebooks, --clear_all]
+
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v6.0.0
+    hooks:
+      - id: check-json
+      - id: check-yaml
+      - id: trailing-whitespace
+      - id: end-of-file-fixer

README.md

Lines changed: 1 addition & 1 deletion

@@ -40,4 +40,4 @@ Self-supervised learning on HEP events.
 └── dataset2
     ├── raw
     └── processed
-```
+```

notebooks/data_preprocessing.ipynb

Lines changed: 116 additions & 0 deletions (new file)

{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c99e37a-fe0e-424e-968e-e79e937a4498",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import awkward as ak\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5f4f372-d217-4629-8773-4c7809a222ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "file_path = \"/pscratch/sd/r/rmastand/particlemind/data/p8_ee_tt_ecm365_parquetfiles/reco_p8_ee_tt_ecm365_60000.parquet\"\n",
    "\n",
    "data = ak.from_parquet(file_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b615189-929d-4cf6-99f9-81b9dda7e69e",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(data[\"calo_hit_features\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91e31f81-9f2e-4746-91dd-7fab50dc65e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_features = []\n",
    "\n",
    "for event_i in range(100):\n",
    "    calo_hit_features = data[\"calo_hit_features\"][event_i]\n",
    "\n",
    "    calo_hit_features = np.column_stack(\n",
    "        (\n",
    "            calo_hit_features[\"position.x\"].to_numpy() / 1e4,\n",
    "            calo_hit_features[\"position.y\"].to_numpy() / 1e4,\n",
    "            calo_hit_features[\"position.z\"].to_numpy() / 1e4,\n",
    "            np.log(calo_hit_features[\"energy\"].to_numpy() * 1e2) / 10,\n",
    "            calo_hit_features[\"type\"].to_numpy(),\n",
    "            calo_hit_features[\"subdetector\"].to_numpy(),\n",
    "        )\n",
    "    )\n",
    "    all_features.append(calo_hit_features)\n",
    "\n",
    "all_features = np.concatenate(all_features)\n",
    "print(all_features.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5ee3b74-8519-4f29-9a31-70f33010e9f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "labels = [\"x\", \"y\", \"z\", \"energy\", \"type\", \"subdetector\"]\n",
    "\n",
    "for i in range(6):\n",
    "\n",
    "    plt.figure()\n",
    "    plt.hist(all_features[:, i], bins=100, histtype=\"step\")\n",
    "    plt.xlabel(labels[i])\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d070aead-4b9f-459c-bd56-df74b360a07d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22264792-bca3-4797-99cd-4eb2de20ff8b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca4c11af-b40e-4676-9ef0-6535c75afc08",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
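For quick reference outside the notebook JSON, the preprocessing in the cells above (positions scaled by 1e4, log-scaled energy, plus the raw type and subdetector fields) reads as the standalone sketch below. This is illustrative only: the field names come from the notebook, and the parquet path is a placeholder.

# Illustrative sketch only: this mirrors the preprocessing in the notebook cells above.
# The field names ("position.x", "energy", "type", "subdetector") come from the notebook;
# the parquet path below is a placeholder, not a real location.
import awkward as ak
import numpy as np


def preprocess_calo_hits(data, n_events=100):
    """Stack scaled calo-hit features for the first n_events events into one (N, 6) array."""
    all_features = []
    for event_i in range(n_events):
        hits = data["calo_hit_features"][event_i]
        features = np.column_stack(
            (
                hits["position.x"].to_numpy() / 1e4,
                hits["position.y"].to_numpy() / 1e4,
                hits["position.z"].to_numpy() / 1e4,
                np.log(hits["energy"].to_numpy() * 1e2) / 10,
                hits["type"].to_numpy(),
                hits["subdetector"].to_numpy(),
            )
        )
        all_features.append(features)
    return np.concatenate(all_features)


if __name__ == "__main__":
    data = ak.from_parquet("reco_p8_ee_tt_ecm365_60000.parquet")  # placeholder path
    print(preprocess_calo_hits(data).shape)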

requirements.txt

Lines changed: 1 addition & 0 deletions

@@ -20,3 +20,4 @@ pre-commit
 torch==2.5.1
 torchvision==0.20.1
 torchaudio==2.5.1
+nbdev

src/data_preprocessing.ipynb

Lines changed: 0 additions & 218 deletions
This file was deleted.

src/datasets/CLDHits.py

Lines changed: 7 additions & 7 deletions

@@ -53,7 +53,9 @@ def inverse_standardize_calo_hit_features(calo_hit_features):
 
 
 class CLDHits(IterableDataset):
-    def __init__(self, folder_path, split, nsamples=None, shuffle_files=False, train_fraction=0.8, nfiles=-1, by_event=True):
+    def __init__(
+        self, folder_path, split, nsamples=None, shuffle_files=False, train_fraction=0.8, nfiles=-1, by_event=True
+    ):
         """
         Initialize the dataset by storing the paths to all parquet files in the specified folder.
 
@@ -101,7 +103,7 @@ def __iter__(self):
         worker_info = torch.utils.data.get_worker_info()
         if worker_info is None:
             # Single-process data loading
-            files_to_process = self.parquet_files[:self.nfiles]
+            files_to_process = self.parquet_files[: self.nfiles]
             logger.info(f"Processing {len(files_to_process)} files in single-process mode.")
 
         else:
@@ -157,8 +159,6 @@ def __iter__(self):
                 self.sample_counter += 1
 
                 yield {
-                    "hit_labels": hit_labels[i:i+1], # Shape (1,) or (1, label_dim)
-                    "calo_hit_features": calo_hit_features[i:i+1], # Shape (1, num_features)
-                }
-
-
+                    "hit_labels": hit_labels[i : i + 1],  # Shape (1,) or (1, label_dim)
+                    "calo_hit_features": calo_hit_features[i : i + 1],  # Shape (1, num_features)
+                }
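For orientation, a minimal usage sketch of the dataset class touched above. The import path, folder path, and reliance on PyTorch's default collation are assumptions for illustration, not part of this commit.

# Hypothetical usage sketch for the CLDHits dataset shown above; not part of this commit.
# The import path, folder path, and use of PyTorch's default collation are assumptions.
from torch.utils.data import DataLoader

from src.datasets.CLDHits import CLDHits  # assumed import path

dataset = CLDHits(folder_path="data/calo_hits_parquet", split="train", nfiles=1)  # placeholder folder
loader = DataLoader(dataset, batch_size=32)  # default collation stacks the (1, num_features) items

for batch in loader:
    # keys follow the yield in __iter__: "hit_labels" and "calo_hit_features"
    print(batch["calo_hit_features"].shape)
    break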

src/datasets/utils.py

Lines changed: 3 additions & 2 deletions

@@ -1,6 +1,7 @@
 import numpy as np
 import torch
 
+
 # adapted from from: https://github.com/jpata/particleflow/blob/a3a08fe1e687987c661faad00fd5526e733be014/mlpf/model/PFDataset.py#L163
 class Collater:
     """
@@ -31,9 +32,9 @@ def __call__(self, inputs):
         )
 
         # get mask
-        axis_sum = torch.sum(torch.abs(ret["calo_hit_features"]), dim = 2)
+        axis_sum = torch.sum(torch.abs(ret["calo_hit_features"]), dim=2)
         ret["calo_hit_mask"] = torch.where(axis_sum > 0, 1.0, 0.0)
-
+
         return ret
 
     # per-particle quantities need to be padded across events of different size
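The only functional line in this file is the mask computation that black re-spaced (dim=2). As a standalone illustration with made-up shapes, zero-padded hit rows receive mask 0.0 and real hits 1.0:

# Standalone illustration of the mask logic in Collater.__call__; shapes here are made up.
import torch

calo_hit_features = torch.zeros(2, 4, 6)  # (batch, padded_hits, features)
calo_hit_features[0, :3] = torch.randn(3, 6)  # first event: 3 real hits, 1 padded row
calo_hit_features[1, :1] = torch.randn(1, 6)  # second event: 1 real hit, 3 padded rows

axis_sum = torch.sum(torch.abs(calo_hit_features), dim=2)  # (batch, padded_hits)
calo_hit_mask = torch.where(axis_sum > 0, 1.0, 0.0)  # 1.0 for real hits, 0.0 for padding
print(calo_hit_mask)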

src/models/vqvae.py

Lines changed: 14 additions & 37 deletions

@@ -177,10 +177,7 @@ def __init__(
         self.output_dim = output_dim
 
         self.blocks = nn.ModuleList(
-            [
-                NormformerBlock(input_dim=hidden_dim, mlp_dim=hidden_dim, num_heads=num_heads)
-                for _ in range(num_blocks)
-            ]
+            [NormformerBlock(input_dim=hidden_dim, mlp_dim=hidden_dim, num_heads=num_heads) for _ in range(num_blocks)]
         )
         self.project_out = nn.Linear(hidden_dim, output_dim)
 
@@ -336,8 +333,8 @@ class VQVAELightning(L.LightningModule):
 
     def __init__(
         self,
-        optimizer_kwargs = {},
-        #scheduler_kwargs = {},
+        optimizer_kwargs={},
+        # scheduler_kwargs = {},
         model_kwargs={},
         model_type="Transformer",
         **kwargs,
@@ -373,8 +370,7 @@ def __init__(
         self.val_mask = []
 
     def configure_optimizers(self):
-        optimizer = torch.optim.AdamW(
-            self.model.parameters(), **self.optimizer_kwargs)
+        optimizer = torch.optim.AdamW(self.model.parameters(), **self.optimizer_kwargs)
         """
         if self.lr_scheduler:
             return {
@@ -387,7 +383,7 @@ def configure_optimizers(self):
            }
        """
         return optimizer
-
+
     def forward(self, x_particle, mask_particle):
         x_particle_reco, vq_out = self.model(x_particle, mask=mask_particle)
         return x_particle_reco, vq_out
@@ -428,24 +424,22 @@ def on_train_start(self) -> None:
            self.trainer.datamodule.hparams.dataset_kwargs_common.feature_dict
        )
        """
+
     def on_train_epoch_start(self):
         logger.info(f"Epoch {self.trainer.current_epoch} starting.")
         self.epoch_train_start_time = time.time()  # start timing the epoch
 
     def on_train_epoch_end(self):
         self.epoch_train_end_time = time.time()
-        self.epoch_train_duration_minutes = (
-            self.epoch_train_end_time - self.epoch_train_start_time
-        ) / 60
+        self.epoch_train_duration_minutes = (self.epoch_train_end_time - self.epoch_train_start_time) / 60
         self.log(
             "epoch_train_duration_minutes",
             self.epoch_train_duration_minutes,
             on_epoch=True,
             prog_bar=False,
         )
         logger.info(
-            f"Epoch {self.trainer.current_epoch} finished in"
-            f" {self.epoch_train_duration_minutes:.1f} minutes."
+            f"Epoch {self.trainer.current_epoch} finished in" f" {self.epoch_train_duration_minutes:.1f} minutes."
         )
 
     def on_train_end(self):
@@ -492,9 +486,7 @@ def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: i
                 saveas=plot_filename,
             )
             if comet_logger is not None:
-                comet_logger.log_image(
-                    plot_filename, name=plot_filename.split("/")[-1], step=curr_step
-                )
+                comet_logger.log_image(plot_filename, name=plot_filename.split("/")[-1], step=curr_step)
 
         return loss
 
@@ -569,9 +561,7 @@ def tokenize_ak_array(self, ak_arr, pp_dict, batch_size=256, pad_length=128, hid
        tokens = np_to_ak(codes, names=["token"], mask=mask)["token"]
        return tokens
 
-    def reconstruct_ak_tokens(
-        self, tokens_ak, pp_dict, batch_size=256, pad_length=128, hide_pbar=False
-    ):
+    def reconstruct_ak_tokens(self, tokens_ak, pp_dict, batch_size=256, pad_length=128, hide_pbar=False):
         """Reconstruct tokenized awkward array.
 
         Parameters
@@ -635,9 +625,7 @@ def reconstruct_ak_tokens(
             if hasattr(self.model, "latent_projection_out"):
                 x_reco_batch = self.model.latent_projection_out(z_q) * mask_batch.unsqueeze(-1)
                 x_reco_batch = self.model.decoder_normformer(x_reco_batch, mask=mask_batch)
-                x_reco_batch = self.model.output_projection(
-                    x_reco_batch
-                ) * mask_batch.unsqueeze(-1)
+                x_reco_batch = self.model.output_projection(x_reco_batch) * mask_batch.unsqueeze(-1)
             elif hasattr(self.model, "decoder"):
                 x_reco_batch = self.model.decoder(z_q)
             else:
@@ -665,8 +653,6 @@ def on_test_epoch_end(self):
         self.test_labels_concat = np.concatenate(self.test_labels)
         self.test_code_idx_concat = np.concatenate(self.test_code_idx)
 
-
-
 
 def plot_model(model, samples, device="cuda", n_examples_to_plot=200, masks=None, saveas=None):
     """Visualize the model.
@@ -833,22 +819,13 @@ def plot_model(model, samples, device="cuda", n_examples_to_plot=200, masks=None
         ax.set_yscale("log")
         print(idx)
         ax.set_title(
-            "Codebook histogram\n(Each entry corresponds to one sample\nbeing associated with that"
-            " codebook entry)",
+            "Codebook histogram\n(Each entry corresponds to one sample\nbeing associated with that" " codebook entry)",
             fontsize=8,
         )
 
     # make empty axes invisible
     def is_axes_empty(ax):
-        return not (
-            ax.lines
-            or ax.patches
-            or ax.collections
-            or ax.images
-            or ax.texts
-            or ax.artists
-            or ax.tables
-        )
+        return not (ax.lines or ax.patches or ax.collections or ax.images or ax.texts or ax.artists or ax.tables)
 
     for ax in axarr.flatten():
         if is_axes_empty(ax):
@@ -882,4 +859,4 @@ def plot_loss(loss_history, lr_history, moving_average=100):
     ax2.set_ylabel("Learning Rate")
 
     fig.tight_layout()
-    plt.show()
+    plt.show()
