Skip to content

Commit 52d9a89

Browse files
committed
precommit
1 parent 9c8a585 commit 52d9a89

File tree

4 files changed

+44
-72
lines changed

4 files changed

+44
-72
lines changed

notebooks/clustering_studies.ipynb

Lines changed: 31 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
"outputs": [],
3333
"source": [
3434
"import sys\n",
35+
"\n",
3536
"sys.path.append(module_path)"
3637
]
3738
},
@@ -67,7 +68,7 @@
6768
" unique_labels, contiguous_labels = np.unique(elem[\"hit_labels\"], return_inverse=True)\n",
6869
" elem[\"hit_labels_contiguous\"] = contiguous_labels\n",
6970
" elems.append(elem)\n",
70-
" if len(elems)>=100:\n",
71+
" if len(elems) >= 100:\n",
7172
" break\n",
7273
"\n",
7374
"elems = [[ak.from_iter(elem)] for elem in elems]\n",
@@ -81,7 +82,7 @@
8182
"metadata": {},
8283
"outputs": [],
8384
"source": [
84-
"plt.hist(ak.max(elems[\"hit_labels_contiguous\"], axis=1), bins=np.linspace(0,400,41));\n",
85+
"plt.hist(ak.max(elems[\"hit_labels_contiguous\"], axis=1), bins=np.linspace(0, 400, 41))\n",
8586
"plt.xlabel(\"Clusters per event\")\n",
8687
"plt.ylabel(\"Event count\")"
8788
]
@@ -104,9 +105,9 @@
104105
"metadata": {},
105106
"outputs": [],
106107
"source": [
107-
"plt.hist(calo_hit_features_f[:, 0], np.linspace(-5000,5000,100), histtype=\"step\", lw=2, label=\"x\")\n",
108-
"plt.hist(calo_hit_features_f[:, 1], np.linspace(-5000,5000,100), histtype=\"step\", lw=2, label=\"y\")\n",
109-
"plt.hist(calo_hit_features_f[:, 2], np.linspace(-5000,5000,100), histtype=\"step\", lw=2, label=\"z\");\n",
108+
"plt.hist(calo_hit_features_f[:, 0], np.linspace(-5000, 5000, 100), histtype=\"step\", lw=2, label=\"x\")\n",
109+
"plt.hist(calo_hit_features_f[:, 1], np.linspace(-5000, 5000, 100), histtype=\"step\", lw=2, label=\"y\")\n",
110+
"plt.hist(calo_hit_features_f[:, 2], np.linspace(-5000, 5000, 100), histtype=\"step\", lw=2, label=\"z\")\n",
110111
"plt.xlabel(\"Hit position (mm)\")\n",
111112
"plt.ylabel(\"Hit count\")\n",
112113
"plt.legend()"
@@ -119,7 +120,7 @@
119120
"metadata": {},
120121
"outputs": [],
121122
"source": [
122-
"plt.hist(10*calo_hit_features_f[:, 3], np.logspace(-3,1,100))\n",
123+
"plt.hist(10 * calo_hit_features_f[:, 3], np.logspace(-3, 1, 100))\n",
123124
"plt.xscale(\"log\")\n",
124125
"plt.xlabel(\"Hit energy (GeV)\")\n",
125126
"plt.ylabel(\"Hit count\")"
@@ -160,20 +161,20 @@
160161
" cluster_hit_count = []\n",
161162
" cluster_id = []\n",
162163
" for clid in cluster_ids:\n",
163-
" cl_mask = elem[\"hit_labels_contiguous\"]==clid\n",
164+
" cl_mask = elem[\"hit_labels_contiguous\"] == clid\n",
164165
" std_x = np.std(elem[\"calo_hit_features\"][:, 0][cl_mask])\n",
165166
" std_y = np.std(elem[\"calo_hit_features\"][:, 1][cl_mask])\n",
166167
" std_z = np.std(elem[\"calo_hit_features\"][:, 2][cl_mask])\n",
167168
" sum_e = np.sum(elem[\"calo_hit_features\"][:, 3][cl_mask])\n",
168169
" hit_count = np.sum(cl_mask)\n",
169-
" \n",
170+
"\n",
170171
" cluster_std_x.append(std_x)\n",
171172
" cluster_std_y.append(std_y)\n",
172173
" cluster_std_z.append(std_z)\n",
173174
" cluster_sum_e.append(sum_e)\n",
174175
" cluster_hit_count.append(hit_count)\n",
175176
" cluster_id.append(clid)\n",
176-
" \n",
177+
"\n",
177178
" all_cluster_std_x.append(cluster_std_x)\n",
178179
" all_cluster_std_y.append(cluster_std_y)\n",
179180
" all_cluster_std_z.append(cluster_std_z)\n",
@@ -198,9 +199,9 @@
198199
"outputs": [],
199200
"source": [
200201
"plt.hist2d(\n",
201-
" ak.to_numpy(ak.flatten(all_cluster_hit_count[all_cluster_hit_count>5])),\n",
202-
" ak.to_numpy(ak.flatten(all_cluster_std_x[all_cluster_hit_count>5])),\n",
203-
" bins=(np.logspace(0,3,100), np.logspace(-2,4,100))\n",
202+
" ak.to_numpy(ak.flatten(all_cluster_hit_count[all_cluster_hit_count > 5])),\n",
203+
" ak.to_numpy(ak.flatten(all_cluster_std_x[all_cluster_hit_count > 5])),\n",
204+
" bins=(np.logspace(0, 3, 100), np.logspace(-2, 4, 100)),\n",
204205
")\n",
205206
"plt.xscale(\"log\")\n",
206207
"plt.yscale(\"log\")\n",
@@ -218,7 +219,7 @@
218219
"plt.hist2d(\n",
219220
" ak.to_numpy(ak.flatten(ak.Array(all_cluster_hit_count))),\n",
220221
" ak.to_numpy(ak.flatten(ak.Array(all_cluster_std_y))),\n",
221-
" bins=(np.logspace(0,3,100), np.logspace(-2,4,100))\n",
222+
" bins=(np.logspace(0, 3, 100), np.logspace(-2, 4, 100)),\n",
222223
")\n",
223224
"plt.xscale(\"log\")\n",
224225
"plt.yscale(\"log\")\n",
@@ -234,9 +235,9 @@
234235
"outputs": [],
235236
"source": [
236237
"plt.hist2d(\n",
237-
" ak.to_numpy(ak.flatten(ak.Array(all_cluster_hit_count[all_cluster_hit_count>5]))),\n",
238-
" ak.to_numpy(ak.flatten(ak.Array(all_cluster_std_z[all_cluster_hit_count>5]))),\n",
239-
" bins=(np.logspace(0,3,100), np.logspace(-2,4,100))\n",
238+
" ak.to_numpy(ak.flatten(ak.Array(all_cluster_hit_count[all_cluster_hit_count > 5]))),\n",
239+
" ak.to_numpy(ak.flatten(ak.Array(all_cluster_std_z[all_cluster_hit_count > 5]))),\n",
240+
" bins=(np.logspace(0, 3, 100), np.logspace(-2, 4, 100)),\n",
240241
")\n",
241242
"plt.xscale(\"log\")\n",
242243
"plt.yscale(\"log\")\n",
@@ -251,11 +252,11 @@
251252
"metadata": {},
252253
"outputs": [],
253254
"source": [
254-
"plt.figure(figsize=(5,5))\n",
255+
"plt.figure(figsize=(5, 5))\n",
255256
"plt.hist2d(\n",
256257
" ak.to_numpy(ak.flatten(ak.Array(all_cluster_hit_count))),\n",
257258
" ak.to_numpy(ak.flatten(ak.Array(all_cluster_sum_e))),\n",
258-
" bins=(np.logspace(0,3,100), np.logspace(-2,3,100))\n",
259+
" bins=(np.logspace(0, 3, 100), np.logspace(-2, 3, 100)),\n",
259260
")\n",
260261
"plt.xscale(\"log\")\n",
261262
"plt.yscale(\"log\")\n",
@@ -270,7 +271,7 @@
270271
"metadata": {},
271272
"outputs": [],
272273
"source": [
273-
"plt.hist(ak.flatten(all_cluster_hit_count), bins=np.linspace(0,1500,100));\n",
274+
"plt.hist(ak.flatten(all_cluster_hit_count), bins=np.linspace(0, 1500, 100))\n",
274275
"plt.yscale(\"log\")\n",
275276
"plt.xlabel(\"Number of hits per cluster\")\n",
276277
"plt.ylabel(\"Cluster count\")"
@@ -283,46 +284,37 @@
283284
"metadata": {},
284285
"outputs": [],
285286
"source": [
286-
"fig, axs = plt.subplots(3,3, figsize=(10,10))\n",
287+
"fig, axs = plt.subplots(3, 3, figsize=(10, 10))\n",
287288
"axs = axs.flatten()\n",
288289
"for ielem in range(9):\n",
289290
" plt.sca(axs[ielem])\n",
290291
" elem = elems[ielem]\n",
291-
" \n",
292+
"\n",
292293
" unique_labels, contiguous_labels = np.unique(elem[\"hit_labels\"], return_inverse=True)\n",
293-
" cmap = plt.get_cmap('viridis')\n",
294+
" cmap = plt.get_cmap(\"viridis\")\n",
294295
" distinct_colors = cmap(np.linspace(0, 1, len(unique_labels)))\n",
295-
" \n",
296+
"\n",
296297
" plt.scatter(\n",
297298
" elem[\"calo_hit_features\"][:, 0],\n",
298299
" elem[\"calo_hit_features\"][:, 1],\n",
299-
" s=np.clip(100*elem[\"calo_hit_features\"][:, 3], 0.1, 10),\n",
300-
" c=distinct_colors[contiguous_labels])\n",
300+
" s=np.clip(100 * elem[\"calo_hit_features\"][:, 3], 0.1, 10),\n",
301+
" c=distinct_colors[contiguous_labels],\n",
302+
" )\n",
301303
" plt.xlim(-6000, 6000)\n",
302304
" plt.ylim(-6000, 6000)\n",
303-
" plt.title(\"$N_{{hit}}$={}, $N_{{cl}}$={}\".format(len(elem[\"calo_hit_features\"]), len(np.unique(elem[\"hit_labels\"]))))\n",
305+
" plt.title(\n",
306+
" \"$N_{{hit}}$={}, $N_{{cl}}$={}\".format(len(elem[\"calo_hit_features\"]), len(np.unique(elem[\"hit_labels\"])))\n",
307+
" )\n",
304308
" plt.xticks([])\n",
305309
" plt.yticks([])"
306310
]
307311
}
308312
],
309313
"metadata": {
310314
"kernelspec": {
311-
"display_name": "Python 3 (ipykernel)",
315+
"display_name": "python3",
312316
"language": "python",
313317
"name": "python3"
314-
},
315-
"language_info": {
316-
"codemirror_mode": {
317-
"name": "ipython",
318-
"version": 3
319-
},
320-
"file_extension": ".py",
321-
"mimetype": "text/x-python",
322-
"name": "python",
323-
"nbconvert_exporter": "python",
324-
"pygments_lexer": "ipython3",
325-
"version": "3.11.13"
326318
}
327319
},
328320
"nbformat": 4,

notebooks/data_preprocessing.ipynb

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -106,21 +106,9 @@
106106
],
107107
"metadata": {
108108
"kernelspec": {
109-
"display_name": "Python 3 (ipykernel)",
109+
"display_name": "python3",
110110
"language": "python",
111111
"name": "python3"
112-
},
113-
"language_info": {
114-
"codemirror_mode": {
115-
"name": "ipython",
116-
"version": 3
117-
},
118-
"file_extension": ".py",
119-
"mimetype": "text/x-python",
120-
"name": "python",
121-
"nbconvert_exporter": "python",
122-
"pygments_lexer": "ipython3",
123-
"version": "3.11.13"
124112
}
125113
},
126114
"nbformat": 4,

notebooks/debug-cld-processing.ipynb

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1204,7 +1204,7 @@
12041204
" event_data1[\"cluster_to_cluster_hit_matrix\"][event_i][\"hit_idx\"],\n",
12051205
" event_data1[\"cluster_to_cluster_hit_matrix\"][event_i][\"cluster_idx\"],\n",
12061206
" event_data1[\"cluster_to_cluster_hit_matrix\"][event_i][\"weight\"],\n",
1207-
" max_hits = np.max(hit_idx)+1\n",
1207+
" max_hits=np.max(hit_idx) + 1,\n",
12081208
")\n",
12091209
"\n",
12101210
"# Extract calorimeter hit positions (x, y, z)\n",
@@ -1217,7 +1217,9 @@
12171217
")\n",
12181218
"\n",
12191219
"\n",
1220-
"def plot_calo_hits_colored_by_genparticle(hit_labels, calo_hit_positions, title=\"Calorimeter hits colored by genparticle\"):\n",
1220+
"def plot_calo_hits_colored_by_genparticle(\n",
1221+
" hit_labels, calo_hit_positions, title=\"Calorimeter hits colored by genparticle\"\n",
1222+
"):\n",
12211223
" # Assign unique colors to each genparticle ID\n",
12221224
" unique_ids = np.unique(hit_labels)\n",
12231225
" colors = plt.cm.tab10(np.linspace(0, 1, len(unique_ids)))\n",
@@ -1233,7 +1235,7 @@
12331235
"\n",
12341236
" random_color_map = {gen_id: random_color() for gen_id in unique_ids}\n",
12351237
" random_color_map[-1] = \"rgba(0,0,0)\"\n",
1236-
" \n",
1238+
"\n",
12371239
" # Create traces for each genparticle ID\n",
12381240
" traces = []\n",
12391241
" for gen_id in unique_ids:\n",
@@ -1369,12 +1371,12 @@
13691371
" hit_labels = get_hit_labels(\n",
13701372
" hit_idx, gen_idx, weights\n",
13711373
" ) # This could be moved to the pre-processing step if needed\n",
1372-
" \n",
1374+
"\n",
13731375
" hit_labels2 = get_hit_labels(\n",
13741376
" cluster_to_cluster_hit_matrix[\"hit_idx\"],\n",
13751377
" cluster_to_cluster_hit_matrix[\"cluster_idx\"],\n",
13761378
" cluster_to_cluster_hit_matrix[\"weight\"],\n",
1377-
" max_hits = np.max(hit_idx)+1\n",
1379+
" max_hits=np.max(hit_idx) + 1,\n",
13781380
" )\n",
13791381
"\n",
13801382
" yield {\n",
@@ -1610,7 +1612,9 @@
16101612
"metadata": {},
16111613
"outputs": [],
16121614
"source": [
1613-
"plot_calo_hits_colored_by_genparticle(hit_labels_pandora, calo_hit_positions, \"Calorimeter hits colored by Pandora cluster\")"
1615+
"plot_calo_hits_colored_by_genparticle(\n",
1616+
" hit_labels_pandora, calo_hit_positions, \"Calorimeter hits colored by Pandora cluster\"\n",
1617+
")"
16141618
]
16151619
},
16161620
{
@@ -1623,21 +1627,9 @@
16231627
],
16241628
"metadata": {
16251629
"kernelspec": {
1626-
"display_name": "Python 3 (ipykernel)",
1630+
"display_name": "python3",
16271631
"language": "python",
16281632
"name": "python3"
1629-
},
1630-
"language_info": {
1631-
"codemirror_mode": {
1632-
"name": "ipython",
1633-
"version": 3
1634-
},
1635-
"file_extension": ".py",
1636-
"mimetype": "text/x-python",
1637-
"name": "python",
1638-
"nbconvert_exporter": "python",
1639-
"pygments_lexer": "ipython3",
1640-
"version": "3.11.13"
16411633
}
16421634
},
16431635
"nbformat": 4,

src/datasets/CLDHits.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def __iter__(self):
148148
cluster_to_cluster_hit_matrix["hit_idx"],
149149
cluster_to_cluster_hit_matrix["cluster_idx"],
150150
cluster_to_cluster_hit_matrix["weight"],
151-
max_hits = np.max(hit_idx)+1
151+
max_hits=np.max(hit_idx) + 1,
152152
)
153153

154154
if self.by_event:

0 commit comments

Comments
 (0)