22 changes: 22 additions & 0 deletions .github/workflows/pre-commit.yaml
@@ -0,0 +1,22 @@
name: Run pre-commit hooks

on:
pull_request:
branches: [ main ]
push:
branches: [ main ]

jobs:
pre-commit:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: 3.11
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install nbdev
- uses: pre-commit/[email protected]
8 changes: 8 additions & 0 deletions .pre-commit-config.yaml
@@ -16,3 +16,11 @@ repos:
hooks:
- id: nbdev_clean
args: [--fname=notebooks, --clear_all]

- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v6.0.0
hooks:
- id: check-json
- id: check-yaml
- id: trailing-whitespace
- id: end-of-file-fixer
2 changes: 1 addition & 1 deletion README.md
@@ -40,4 +40,4 @@ Self-supervised learning on HEP events.
└── dataset2
├── raw
└── processed
```
```
116 changes: 116 additions & 0 deletions notebooks/data_preprocessing.ipynb
@@ -0,0 +1,116 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "4c99e37a-fe0e-424e-968e-e79e937a4498",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import awkward as ak\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a5f4f372-d217-4629-8773-4c7809a222ce",
"metadata": {},
"outputs": [],
"source": [
"file_path = \"/pscratch/sd/r/rmastand/particlemind/data/p8_ee_tt_ecm365_parquetfiles/reco_p8_ee_tt_ecm365_60000.parquet\"\n",
"\n",
"data = ak.from_parquet(file_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3b615189-929d-4cf6-99f9-81b9dda7e69e",
"metadata": {},
"outputs": [],
"source": [
"print(data[\"calo_hit_features\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "91e31f81-9f2e-4746-91dd-7fab50dc65e1",
"metadata": {},
"outputs": [],
"source": [
"all_features = []\n",
"\n",
"for event_i in range(100):\n",
" calo_hit_features = data[\"calo_hit_features\"][event_i]\n",
"\n",
" calo_hit_features = np.column_stack(\n",
" (\n",
" calo_hit_features[\"position.x\"].to_numpy() / 1e4,\n",
" calo_hit_features[\"position.y\"].to_numpy() / 1e4,\n",
" calo_hit_features[\"position.z\"].to_numpy() / 1e4,\n",
" np.log(calo_hit_features[\"energy\"].to_numpy() * 1e2) / 10,\n",
" calo_hit_features[\"type\"].to_numpy(),\n",
" calo_hit_features[\"subdetector\"].to_numpy(),\n",
" )\n",
" )\n",
" all_features.append(calo_hit_features)\n",
"\n",
"all_features = np.concatenate(all_features)\n",
"print(all_features.shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c5ee3b74-8519-4f29-9a31-70f33010e9f8",
"metadata": {},
"outputs": [],
"source": [
"labels = [\"x\", \"y\", \"z\", \"energy\", \"type\", \"subdetector\"]\n",
"\n",
"for i in range(6):\n",
"\n",
" plt.figure()\n",
" plt.hist(all_features[:, i], bins=100, histtype=\"step\")\n",
" plt.xlabel(labels[i])\n",
" plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d070aead-4b9f-459c-bd56-df74b360a07d",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "22264792-bca3-4797-99cd-4eb2de20ff8b",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "ca4c11af-b40e-4676-9ef0-6535c75afc08",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "python3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
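Note on the scaling in the preprocessing cell above: positions are divided by 1e4 and energies are mapped through log(E * 1e2) / 10 before plotting. Below is a minimal sketch of that transform and its inverse, assuming the constants match what inverse_standardize_calo_hit_features in src/datasets/CLDHits.py undoes; the function names here are illustrative, not the repo's API.

```python
import numpy as np

def standardize_calo_hits(pos, energy):
    """pos: (N, 3) hit positions; energy: (N,) hit energies.

    Constants (1e4, 1e2, 10) are taken from the notebook cell above.
    """
    pos_scaled = pos / 1e4
    energy_scaled = np.log(energy * 1e2) / 10
    return pos_scaled, energy_scaled

def inverse_standardize_calo_hits(pos_scaled, energy_scaled):
    """Undo the scaling: exp(10 * e') / 1e2 inverts log(e * 1e2) / 10."""
    pos = pos_scaled * 1e4
    energy = np.exp(energy_scaled * 10) / 1e2
    return pos, energy
```

The log on the energy compresses the long tail of calorimeter-hit energies, which keeps the histograms plotted above readable.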
1 change: 1 addition & 0 deletions requirements.txt
@@ -20,3 +20,4 @@ pre-commit
torch==2.5.1
torchvision==0.20.1
torchaudio==2.5.1
nbdev
218 changes: 0 additions & 218 deletions src/data_preprocessing.ipynb

This file was deleted.

14 changes: 7 additions & 7 deletions src/datasets/CLDHits.py
@@ -53,7 +53,9 @@ def inverse_standardize_calo_hit_features(calo_hit_features):


class CLDHits(IterableDataset):
def __init__(self, folder_path, split, nsamples=None, shuffle_files=False, train_fraction=0.8, nfiles=-1, by_event=True):
def __init__(
self, folder_path, split, nsamples=None, shuffle_files=False, train_fraction=0.8, nfiles=-1, by_event=True
):
"""
Initialize the dataset by storing the paths to all parquet files in the specified folder.

@@ -101,7 +103,7 @@ def __iter__(self):
worker_info = torch.utils.data.get_worker_info()
if worker_info is None:
# Single-process data loading
files_to_process = self.parquet_files[:self.nfiles]
files_to_process = self.parquet_files[: self.nfiles]
logger.info(f"Processing {len(files_to_process)} files in single-process mode.")

else:
@@ -157,8 +159,6 @@ def __iter__(self):
self.sample_counter += 1

yield {
"hit_labels": hit_labels[i:i+1], # Shape (1,) or (1, label_dim)
"calo_hit_features": calo_hit_features[i:i+1], # Shape (1, num_features)
}


"hit_labels": hit_labels[i : i + 1], # Shape (1,) or (1, label_dim)
"calo_hit_features": calo_hit_features[i : i + 1], # Shape (1, num_features)
}
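The reformatted __iter__ above keeps the dataset's file-sharding behavior: in single-process loading it consumes self.parquet_files[: self.nfiles], and in multi-worker loading (the else branch is collapsed in this diff) the file list is split across workers. A minimal sketch of that pattern follows; shard_files_for_worker is a hypothetical helper, the round-robin split is an assumption, and only torch.utils.data.get_worker_info is the real API.

```python
import torch

def shard_files_for_worker(parquet_files, nfiles=-1):
    """Illustrative sketch of the sharding in CLDHits.__iter__.

    Single-process loading takes the first nfiles files (all files if
    nfiles < 0); the multi-worker branch is not shown in the diff, so
    the round-robin assignment below is an assumption, not the repo's code.
    """
    files = parquet_files if nfiles < 0 else parquet_files[:nfiles]
    info = torch.utils.data.get_worker_info()
    if info is None:
        return files  # single-process data loading
    # One common choice: worker k takes files k, k + num_workers, ...
    return files[info.id :: info.num_workers]
```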
5 changes: 3 additions & 2 deletions src/datasets/utils.py
@@ -1,6 +1,7 @@
import numpy as np
import torch


# adapted from: https://github.com/jpata/particleflow/blob/a3a08fe1e687987c661faad00fd5526e733be014/mlpf/model/PFDataset.py#L163
class Collater:
"""
Expand Down Expand Up @@ -31,9 +32,9 @@ def __call__(self, inputs):
)

# get mask
axis_sum = torch.sum(torch.abs(ret["calo_hit_features"]), dim = 2)
axis_sum = torch.sum(torch.abs(ret["calo_hit_features"]), dim=2)
ret["calo_hit_mask"] = torch.where(axis_sum > 0, 1.0, 0.0)

return ret

# per-particle quantities need to be padded across events of different size
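The dim=2 reduction in Collater.__call__ above derives the padding mask from the padded features themselves: a hit position counts as real if any of its features is nonzero. A minimal sketch, under the assumption that events are zero-padded to a common length (pad_sequence stands in for whatever padding the collater actually uses):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Two events with 3 and 5 hits, 6 features each (x, y, z, energy, type, subdetector).
events = [torch.randn(3, 6), torch.randn(5, 6)]
calo_hit_features = pad_sequence(events, batch_first=True)  # (2, 5, 6), zero-padded

# Mask as in Collater.__call__: a hit is real if any feature is nonzero.
axis_sum = torch.sum(torch.abs(calo_hit_features), dim=2)  # (2, 5)
calo_hit_mask = torch.where(axis_sum > 0, 1.0, 0.0)        # 1.0 = hit, 0.0 = padding
```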
51 changes: 14 additions & 37 deletions src/models/vqvae.py
@@ -177,10 +177,7 @@ def __init__(
self.output_dim = output_dim

self.blocks = nn.ModuleList(
[
NormformerBlock(input_dim=hidden_dim, mlp_dim=hidden_dim, num_heads=num_heads)
for _ in range(num_blocks)
]
[NormformerBlock(input_dim=hidden_dim, mlp_dim=hidden_dim, num_heads=num_heads) for _ in range(num_blocks)]
)
self.project_out = nn.Linear(hidden_dim, output_dim)

@@ -336,8 +333,8 @@ class VQVAELightning(L.LightningModule):

def __init__(
self,
optimizer_kwargs = {},
#scheduler_kwargs = {},
optimizer_kwargs={},
# scheduler_kwargs = {},
model_kwargs={},
model_type="Transformer",
**kwargs,
@@ -373,8 +370,7 @@ def __init__(
self.val_mask = []

def configure_optimizers(self):
optimizer = torch.optim.AdamW(
self.model.parameters(), **self.optimizer_kwargs)
optimizer = torch.optim.AdamW(self.model.parameters(), **self.optimizer_kwargs)
"""
if self.lr_scheduler:
return {
@@ -387,7 +383,7 @@ def configure_optimizers(self):
}
"""
return optimizer

def forward(self, x_particle, mask_particle):
x_particle_reco, vq_out = self.model(x_particle, mask=mask_particle)
return x_particle_reco, vq_out
@@ -428,24 +424,22 @@ def on_train_start(self) -> None:
self.trainer.datamodule.hparams.dataset_kwargs_common.feature_dict
)
"""

def on_train_epoch_start(self):
logger.info(f"Epoch {self.trainer.current_epoch} starting.")
self.epoch_train_start_time = time.time() # start timing the epoch

def on_train_epoch_end(self):
self.epoch_train_end_time = time.time()
self.epoch_train_duration_minutes = (
self.epoch_train_end_time - self.epoch_train_start_time
) / 60
self.epoch_train_duration_minutes = (self.epoch_train_end_time - self.epoch_train_start_time) / 60
self.log(
"epoch_train_duration_minutes",
self.epoch_train_duration_minutes,
on_epoch=True,
prog_bar=False,
)
logger.info(
f"Epoch {self.trainer.current_epoch} finished in"
f" {self.epoch_train_duration_minutes:.1f} minutes."
f"Epoch {self.trainer.current_epoch} finished in" f" {self.epoch_train_duration_minutes:.1f} minutes."
)

def on_train_end(self):
@@ -492,9 +486,7 @@ def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: i
saveas=plot_filename,
)
if comet_logger is not None:
comet_logger.log_image(
plot_filename, name=plot_filename.split("/")[-1], step=curr_step
)
comet_logger.log_image(plot_filename, name=plot_filename.split("/")[-1], step=curr_step)

return loss

@@ -569,9 +561,7 @@ def tokenize_ak_array(self, ak_arr, pp_dict, batch_size=256, pad_length=128, hid
tokens = np_to_ak(codes, names=["token"], mask=mask)["token"]
return tokens

def reconstruct_ak_tokens(
self, tokens_ak, pp_dict, batch_size=256, pad_length=128, hide_pbar=False
):
def reconstruct_ak_tokens(self, tokens_ak, pp_dict, batch_size=256, pad_length=128, hide_pbar=False):
"""Reconstruct tokenized awkward array.

Parameters
@@ -635,9 +625,7 @@ def reconstruct_ak_tokens(
if hasattr(self.model, "latent_projection_out"):
x_reco_batch = self.model.latent_projection_out(z_q) * mask_batch.unsqueeze(-1)
x_reco_batch = self.model.decoder_normformer(x_reco_batch, mask=mask_batch)
x_reco_batch = self.model.output_projection(
x_reco_batch
) * mask_batch.unsqueeze(-1)
x_reco_batch = self.model.output_projection(x_reco_batch) * mask_batch.unsqueeze(-1)
elif hasattr(self.model, "decoder"):
x_reco_batch = self.model.decoder(z_q)
else:
@@ -665,8 +653,6 @@ def on_test_epoch_end(self):
self.test_labels_concat = np.concatenate(self.test_labels)
self.test_code_idx_concat = np.concatenate(self.test_code_idx)




def plot_model(model, samples, device="cuda", n_examples_to_plot=200, masks=None, saveas=None):
"""Visualize the model.
@@ -833,22 +819,13 @@ def plot_model(model, samples, device="cuda", n_examples_to_plot=200, masks=None
ax.set_yscale("log")
print(idx)
ax.set_title(
"Codebook histogram\n(Each entry corresponds to one sample\nbeing associated with that"
" codebook entry)",
"Codebook histogram\n(Each entry corresponds to one sample\nbeing associated with that" " codebook entry)",
fontsize=8,
)

# make empty axes invisible
def is_axes_empty(ax):
return not (
ax.lines
or ax.patches
or ax.collections
or ax.images
or ax.texts
or ax.artists
or ax.tables
)
return not (ax.lines or ax.patches or ax.collections or ax.images or ax.texts or ax.artists or ax.tables)

for ax in axarr.flatten():
if is_axes_empty(ax):
@@ -882,4 +859,4 @@ def plot_loss(loss_history, lr_history, moving_average=100):
ax2.set_ylabel("Learning Rate")

fig.tight_layout()
plt.show()
plt.show()
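For context on the z_q and code_idx tensors that reconstruct_ak_tokens and on_test_epoch_end handle above, here is a minimal sketch of the standard VQ-VAE nearest-codebook lookup; all shapes and variable names are illustrative assumptions, not the repo's exact API.

```python
import torch

# Hypothetical shapes: 512 codes of dimension 32, a batch of 4 events
# padded to 128 hits. Only the nearest-neighbour lookup itself is the
# standard VQ-VAE operation; none of these names come from the repo.
codebook = torch.randn(512, 32)  # (num_codes, latent_dim)
z = torch.randn(4, 128, 32)      # encoder output (batch, pad_length, latent_dim)

dists = torch.cdist(z.reshape(-1, 32), codebook)  # L2 distance to every code
code_idx = dists.argmin(dim=1)                    # token index per latent vector
z_q = codebook[code_idx].reshape(z.shape)         # quantized latents fed to the decoder
```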