Skip to content

Commit 7faee76

Browse files
authored
Merge pull request #10 from jpata/jp_20251107_clustering
Clustering studies, baseline reco clustering - add notebook to visualize cluster information - in CLDHits retrieve Pandora cluster indices as a baseline clustering to compare our ML against - use dvc to keep track of datasets
2 parents 42266d2 + 52d9a89 commit 7faee76

File tree

10 files changed

+430
-8
lines changed

10 files changed

+430
-8
lines changed

.dvc/.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
/config.local
2+
/tmp
3+
/cache

.dvc/config

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
[core]
2+
remote = cern-jpata
3+
4+
['remote "cern-jpata"']
5+
url = https://jpata.web.cern.ch/dvc/particlemind

.gitignore

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,2 @@
1-
data/
21
*.ipynb_checkpoints
32
*.pyc

README_Tallinn.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
```
2+
./scripts/run_ee.sh dvc config --local cache.dir /scratch/persistent/$USER/dvc-cache
3+
./scripts/run_ee.sh dvc config --local cache.type symlink
4+
./scripts/run_ee.sh dvc fetch
5+
./scripts/run_ee.sh dvc pull
6+
./scripts/run_ee.sh jupyter notebook
7+
```

data/.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
/p8_ee_tt_ecm365

data/p8_ee_tt_ecm365.dvc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
outs:
2+
- md5: 1e061a23483ff95a45c620fdd23db85d.dir
3+
size: 39000735175
4+
nfiles: 1000
5+
hash: md5
6+
path: p8_ee_tt_ecm365

notebooks/clustering_studies.ipynb

Lines changed: 322 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,322 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"id": "7084fc8e-60bb-4341-af6f-92eb97ebc42c",
7+
"metadata": {},
8+
"outputs": [],
9+
"source": [
10+
"root_file_path = \"../data/p8_ee_tt_ecm365/root\"\n",
11+
"parquet_file_path = \"../data/p8_ee_tt_ecm365/parquet\"\n",
12+
"module_path = \"../\""
13+
]
14+
},
15+
{
16+
"cell_type": "code",
17+
"execution_count": null,
18+
"id": "6f153459-a73f-4126-8029-5ebf597ba1ba",
19+
"metadata": {},
20+
"outputs": [],
21+
"source": [
22+
"import matplotlib.pyplot as plt\n",
23+
"import numpy as np\n",
24+
"import awkward as ak"
25+
]
26+
},
27+
{
28+
"cell_type": "code",
29+
"execution_count": null,
30+
"id": "7480d773-180e-4709-9229-7e91facbdcb3",
31+
"metadata": {},
32+
"outputs": [],
33+
"source": [
34+
"import sys\n",
35+
"\n",
36+
"sys.path.append(module_path)"
37+
]
38+
},
39+
{
40+
"cell_type": "code",
41+
"execution_count": null,
42+
"id": "7c392fc8-50fe-4429-8f73-a20d6bf98717",
43+
"metadata": {},
44+
"outputs": [],
45+
"source": [
46+
"from src.datasets.CLDHits import CLDHits"
47+
]
48+
},
49+
{
50+
"cell_type": "code",
51+
"execution_count": null,
52+
"id": "a06ee555-0b52-49ce-97c8-3586fb8081be",
53+
"metadata": {},
54+
"outputs": [],
55+
"source": [
56+
"dataset_train = CLDHits(parquet_file_path, \"train\")"
57+
]
58+
},
59+
{
60+
"cell_type": "code",
61+
"execution_count": null,
62+
"id": "4eb29375-8ca1-455c-bb09-e5c2d1d5c748",
63+
"metadata": {},
64+
"outputs": [],
65+
"source": [
66+
"elems = []\n",
67+
"for elem in dataset_train:\n",
68+
" unique_labels, contiguous_labels = np.unique(elem[\"hit_labels\"], return_inverse=True)\n",
69+
" elem[\"hit_labels_contiguous\"] = contiguous_labels\n",
70+
" elems.append(elem)\n",
71+
" if len(elems) >= 100:\n",
72+
" break\n",
73+
"\n",
74+
"elems = [[ak.from_iter(elem)] for elem in elems]\n",
75+
"elems = ak.concatenate(elems, axis=0)"
76+
]
77+
},
78+
{
79+
"cell_type": "code",
80+
"execution_count": null,
81+
"id": "44a55aa6-20a6-4aa4-b91a-44731dcd971d",
82+
"metadata": {},
83+
"outputs": [],
84+
"source": [
85+
"# Contiguous labels are 0-based (np.unique inverse indices), so the cluster count is max + 1.\n",
"plt.hist(ak.max(elems[\"hit_labels_contiguous\"], axis=1) + 1, bins=np.linspace(0, 400, 41))\n",
86+
"plt.xlabel(\"Clusters per event\")\n",
87+
"plt.ylabel(\"Event count\")"
88+
]
89+
},
90+
{
91+
"cell_type": "code",
92+
"execution_count": null,
93+
"id": "b7c2fbe8-a243-4a07-aab0-3f8b7576cc66",
94+
"metadata": {},
95+
"outputs": [],
96+
"source": [
97+
"hit_labels_c_f = ak.flatten(elems[\"hit_labels_contiguous\"])\n",
98+
"calo_hit_features_f = ak.flatten(elems[\"calo_hit_features\"])"
99+
]
100+
},
101+
{
102+
"cell_type": "code",
103+
"execution_count": null,
104+
"id": "4f52c7c6-1dd1-436a-bdeb-6a8754862512",
105+
"metadata": {},
106+
"outputs": [],
107+
"source": [
108+
"plt.hist(calo_hit_features_f[:, 0], np.linspace(-5000, 5000, 100), histtype=\"step\", lw=2, label=\"x\")\n",
109+
"plt.hist(calo_hit_features_f[:, 1], np.linspace(-5000, 5000, 100), histtype=\"step\", lw=2, label=\"y\")\n",
110+
"plt.hist(calo_hit_features_f[:, 2], np.linspace(-5000, 5000, 100), histtype=\"step\", lw=2, label=\"z\")\n",
111+
"plt.xlabel(\"Hit position (mm)\")\n",
112+
"plt.ylabel(\"Hit count\")\n",
113+
"plt.legend()"
114+
]
115+
},
116+
{
117+
"cell_type": "code",
118+
"execution_count": null,
119+
"id": "2c3cd37c-3626-4884-8712-f5e0e8e02111",
120+
"metadata": {},
121+
"outputs": [],
122+
"source": [
123+
"plt.hist(10 * calo_hit_features_f[:, 3], np.logspace(-3, 1, 100))\n",
124+
"plt.xscale(\"log\")\n",
125+
"plt.xlabel(\"Hit energy (GeV)\")\n",
126+
"plt.ylabel(\"Hit count\")"
127+
]
128+
},
129+
{
130+
"cell_type": "code",
131+
"execution_count": null,
132+
"id": "3cbc7257-df54-400a-a0ca-34b2c3ecb295",
133+
"metadata": {},
134+
"outputs": [],
135+
"source": [
136+
"len(elems)"
137+
]
138+
},
139+
{
140+
"cell_type": "code",
141+
"execution_count": null,
142+
"id": "5927be66-b729-427c-9e01-5df7e83b7487",
143+
"metadata": {},
144+
"outputs": [],
145+
"source": [
146+
"all_cluster_std_x = []\n",
147+
"all_cluster_std_y = []\n",
148+
"all_cluster_std_z = []\n",
149+
"all_cluster_sum_e = []\n",
150+
"all_cluster_hit_count = []\n",
151+
"all_cluster_id = []\n",
152+
"\n",
153+
"for ielem in range(5):\n",
154+
" print(ielem)\n",
155+
" elem = elems[ielem]\n",
156+
" cluster_ids = np.unique(elem[\"hit_labels_contiguous\"])\n",
157+
" cluster_std_x = []\n",
158+
" cluster_std_y = []\n",
159+
" cluster_std_z = []\n",
160+
" cluster_sum_e = []\n",
161+
" cluster_hit_count = []\n",
162+
" cluster_id = []\n",
163+
" for clid in cluster_ids:\n",
164+
" cl_mask = elem[\"hit_labels_contiguous\"] == clid\n",
165+
" std_x = np.std(elem[\"calo_hit_features\"][:, 0][cl_mask])\n",
166+
" std_y = np.std(elem[\"calo_hit_features\"][:, 1][cl_mask])\n",
167+
" std_z = np.std(elem[\"calo_hit_features\"][:, 2][cl_mask])\n",
168+
" sum_e = np.sum(elem[\"calo_hit_features\"][:, 3][cl_mask])\n",
169+
" hit_count = np.sum(cl_mask)\n",
170+
"\n",
171+
" cluster_std_x.append(std_x)\n",
172+
" cluster_std_y.append(std_y)\n",
173+
" cluster_std_z.append(std_z)\n",
174+
" cluster_sum_e.append(sum_e)\n",
175+
" cluster_hit_count.append(hit_count)\n",
176+
" cluster_id.append(clid)\n",
177+
"\n",
178+
" all_cluster_std_x.append(cluster_std_x)\n",
179+
" all_cluster_std_y.append(cluster_std_y)\n",
180+
" all_cluster_std_z.append(cluster_std_z)\n",
181+
" all_cluster_sum_e.append(cluster_sum_e)\n",
182+
" all_cluster_hit_count.append(cluster_hit_count)\n",
183+
" all_cluster_id.append(cluster_id)\n",
184+
"\n",
185+
"\n",
186+
"all_cluster_std_x = ak.Array(all_cluster_std_x)\n",
187+
"all_cluster_std_y = ak.Array(all_cluster_std_y)\n",
188+
"all_cluster_std_z = ak.Array(all_cluster_std_z)\n",
189+
"all_cluster_sum_e = ak.Array(all_cluster_sum_e)\n",
190+
"all_cluster_hit_count = ak.Array(all_cluster_hit_count)\n",
191+
"all_cluster_id = ak.Array(all_cluster_id)"
192+
]
193+
},
194+
{
195+
"cell_type": "code",
196+
"execution_count": null,
197+
"id": "005eadee-b7b1-4f55-a7f0-0c8965d3ab7c",
198+
"metadata": {},
199+
"outputs": [],
200+
"source": [
201+
"plt.hist2d(\n",
202+
" ak.to_numpy(ak.flatten(all_cluster_hit_count[all_cluster_hit_count > 5])),\n",
203+
" ak.to_numpy(ak.flatten(all_cluster_std_x[all_cluster_hit_count > 5])),\n",
204+
" bins=(np.logspace(0, 3, 100), np.logspace(-2, 4, 100)),\n",
205+
")\n",
206+
"plt.xscale(\"log\")\n",
207+
"plt.yscale(\"log\")\n",
208+
"plt.xlabel(\"Hits per cluster\")\n",
209+
"plt.ylabel(\"Hit pos x stddev\")"
210+
]
211+
},
212+
{
213+
"cell_type": "code",
214+
"execution_count": null,
215+
"id": "7b088f67-c4e4-454c-9244-0bbbc3c6500f",
216+
"metadata": {},
217+
"outputs": [],
218+
"source": [
219+
"plt.hist2d(\n",
220+
" ak.to_numpy(ak.flatten(ak.Array(all_cluster_hit_count))),\n",
221+
" ak.to_numpy(ak.flatten(ak.Array(all_cluster_std_y))),\n",
222+
" bins=(np.logspace(0, 3, 100), np.logspace(-2, 4, 100)),\n",
223+
")\n",
224+
"plt.xscale(\"log\")\n",
225+
"plt.yscale(\"log\")\n",
226+
"plt.xlabel(\"Hits per cluster\")\n",
227+
"plt.ylabel(\"Hit pos y stddev\")"
228+
]
229+
},
230+
{
231+
"cell_type": "code",
232+
"execution_count": null,
233+
"id": "5004224b-ef89-4183-b6d5-f3774dbef0bc",
234+
"metadata": {},
235+
"outputs": [],
236+
"source": [
237+
"plt.hist2d(\n",
238+
" ak.to_numpy(ak.flatten(ak.Array(all_cluster_hit_count[all_cluster_hit_count > 5]))),\n",
239+
" ak.to_numpy(ak.flatten(ak.Array(all_cluster_std_z[all_cluster_hit_count > 5]))),\n",
240+
" bins=(np.logspace(0, 3, 100), np.logspace(-2, 4, 100)),\n",
241+
")\n",
242+
"plt.xscale(\"log\")\n",
243+
"plt.yscale(\"log\")\n",
244+
"plt.xlabel(\"Hits per cluster\")\n",
245+
"plt.ylabel(\"Hit pos z stddev\")"
246+
]
247+
},
248+
{
249+
"cell_type": "code",
250+
"execution_count": null,
251+
"id": "b314c040-84e7-4970-9b3d-42144ca5ef4b",
252+
"metadata": {},
253+
"outputs": [],
254+
"source": [
255+
"plt.figure(figsize=(5, 5))\n",
256+
"plt.hist2d(\n",
257+
" ak.to_numpy(ak.flatten(ak.Array(all_cluster_hit_count))),\n",
258+
" ak.to_numpy(ak.flatten(ak.Array(all_cluster_sum_e))),\n",
259+
" bins=(np.logspace(0, 3, 100), np.logspace(-2, 3, 100)),\n",
260+
")\n",
261+
"plt.xscale(\"log\")\n",
262+
"plt.yscale(\"log\")\n",
263+
"plt.xlabel(\"Hits per cluster\")\n",
264+
"plt.ylabel(\"Sum energy per cluster\")"
265+
]
266+
},
267+
{
268+
"cell_type": "code",
269+
"execution_count": null,
270+
"id": "2410866d-9c17-4c3b-b182-c67c5692506d",
271+
"metadata": {},
272+
"outputs": [],
273+
"source": [
274+
"plt.hist(ak.flatten(all_cluster_hit_count), bins=np.linspace(0, 1500, 100))\n",
275+
"plt.yscale(\"log\")\n",
276+
"plt.xlabel(\"Number of hits per cluster\")\n",
277+
"plt.ylabel(\"Cluster count\")"
278+
]
279+
},
280+
{
281+
"cell_type": "code",
282+
"execution_count": null,
283+
"id": "11acb17c-255a-4d01-9371-f01b6a378520",
284+
"metadata": {},
285+
"outputs": [],
286+
"source": [
287+
"fig, axs = plt.subplots(3, 3, figsize=(10, 10))\n",
288+
"axs = axs.flatten()\n",
289+
"for ielem in range(9):\n",
290+
" plt.sca(axs[ielem])\n",
291+
" elem = elems[ielem]\n",
292+
"\n",
293+
" unique_labels, contiguous_labels = np.unique(elem[\"hit_labels\"], return_inverse=True)\n",
294+
" cmap = plt.get_cmap(\"viridis\")\n",
295+
" distinct_colors = cmap(np.linspace(0, 1, len(unique_labels)))\n",
296+
"\n",
297+
" plt.scatter(\n",
298+
" elem[\"calo_hit_features\"][:, 0],\n",
299+
" elem[\"calo_hit_features\"][:, 1],\n",
300+
" s=np.clip(100 * elem[\"calo_hit_features\"][:, 3], 0.1, 10),\n",
301+
" c=distinct_colors[contiguous_labels],\n",
302+
" )\n",
303+
" plt.xlim(-6000, 6000)\n",
304+
" plt.ylim(-6000, 6000)\n",
305+
" plt.title(\n",
306+
" \"$N_{{hit}}$={}, $N_{{cl}}$={}\".format(len(elem[\"calo_hit_features\"]), len(unique_labels))\n",
307+
" )\n",
308+
" plt.xticks([])\n",
309+
" plt.yticks([])"
310+
]
311+
}
312+
],
313+
"metadata": {
314+
"kernelspec": {
315+
"display_name": "python3",
316+
"language": "python",
317+
"name": "python3"
318+
}
319+
},
320+
"nbformat": 4,
321+
"nbformat_minor": 5
322+
}

0 commit comments

Comments
 (0)