Skip to content

Commit e394ea2

Browse files
authored
Merge pull request #8 from kkovary/fix-int32-out-of-bounds
fix int32 out of bounds
2 parents 134ca06 + 2afcc2c commit e394ea2

25 files changed

+970
-367
lines changed

.github/workflows/tests.yml

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
name: Python Package using Conda
1+
name: Python Package using uv
22

33
on: [push]
44

@@ -9,26 +9,22 @@ jobs:
99
max-parallel: 5
1010

1111
steps:
12-
- uses: actions/checkout@v2
13-
- name: Set up Python 3.7
14-
uses: actions/setup-python@v2
12+
- uses: actions/checkout@v4
13+
- name: Install uv
14+
uses: astral-sh/setup-uv@v5
1515
with:
16-
python-version: 3.7
17-
- name: Add conda to system path
18-
run: |
19-
# $CONDA is an environment variable pointing to the root of the miniconda directory
20-
echo $CONDA/bin >> $GITHUB_PATH
16+
enable-cache: true
17+
cache-dependency-glob: "uv.lock"
18+
- name: Set up Python 3.10
19+
uses: actions/setup-python@v5
20+
with:
21+
python-version: "3.10"
2122
- name: Install dependencies
2223
run: |
23-
conda env update --file environment.yml --name base
24-
- name: Lint with flake8
24+
uv sync
25+
- name: Run pre-commit
2526
run: |
26-
conda install flake8
27-
# stop the build if there are Python syntax errors or undefined names
28-
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
29-
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
30-
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
27+
uv run pre-commit run --all-files
3128
- name: Test with tox
3229
run: |
33-
pip install -e .[testing]
34-
tox
30+
uv run tox

.pre-commit-config.yaml

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
repos:
2+
- repo: https://github.com/astral-sh/ruff-pre-commit
3+
rev: v0.3.3
4+
hooks:
5+
- id: ruff
6+
args: [--fix]
7+
- id: ruff-format
8+
9+
- repo: local
10+
hooks:
11+
- id: pytest
12+
name: pytest
13+
entry: uv run pytest
14+
language: system
15+
types: [python]
16+
pass_filenames: false
17+
always_run: true

environment.yml

Lines changed: 0 additions & 9 deletions
This file was deleted.

notebooks/01_fingerprinting.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
"import numpy as np\n",
3232
"from matplotlib import pyplot as plt\n",
3333
"from sklearn.decomposition import PCA\n",
34-
"from drfp import DrfpEncoder\n"
34+
"from drfp import DrfpEncoder"
3535
]
3636
},
3737
{
@@ -288,7 +288,7 @@
288288
"pca = PCA(n_components=2)\n",
289289
"X = pca.fit(fps).transform(fps)\n",
290290
"\n",
291-
"plt.scatter(X[:,0], X[:,1], alpha=0.8)\n",
291+
"plt.scatter(X[:, 0], X[:, 1], alpha=0.8)\n",
292292
"plt.title(\"PCA of 100 drfp-encoded reactions\")"
293293
]
294294
},

notebooks/02_model_explainability.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@
277277
}
278278
],
279279
"source": [
280-
"shap.force_plot(explainer.expected_value, shap_values[0,:], matplotlib=True)"
280+
"shap.force_plot(explainer.expected_value, shap_values[0, :], matplotlib=True)"
281281
]
282282
},
283283
{
@@ -321,7 +321,7 @@
321321
}
322322
],
323323
"source": [
324-
"shap.force_plot(explainer.expected_value, shap_values[42,:], matplotlib=True)"
324+
"shap.force_plot(explainer.expected_value, shap_values[42, :], matplotlib=True)"
325325
]
326326
},
327327
{

notebooks/03_more_model_explainability.ipynb

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,6 @@
9090
}
9191
],
9292
"source": [
93-
"\n",
9493
"%pip install theia-pypi xgboost matplotlib faerun-notebook --upgrade\n",
9594
"import pickle\n",
9695
"from pathlib import Path\n",
@@ -321,6 +320,7 @@
321320
"!pip uninstall ipywidgets -y\n",
322321
"!pip install ipywidgets==7.7.1\n",
323322
"import ipywidgets\n",
323+
"\n",
324324
"ipywidgets.version_info"
325325
]
326326
},
@@ -375,11 +375,9 @@
375375
"mapping = split[\"test\"][\"mapping\"]\n",
376376
"dataset = InferenceReactionDataset([rxn])\n",
377377
"\n",
378-
"expl = explain_regression(\n",
379-
" dataset, explainer, mapping\n",
380-
")\n",
378+
"expl = explain_regression(dataset, explainer, mapping)\n",
381379
"\n",
382-
"w = { \"reactants\": expl.reactant_weights, \"products\": expl.product_weights}\n",
380+
"w = {\"reactants\": expl.reactant_weights, \"products\": expl.product_weights}\n",
383381
"\n",
384382
"SmilesDrawer(value=[(\"Example\", rxn)], weights=[w], output=\"img\", theme=\"solarized\")"
385383
]

notebooks/0a_figures.ipynb

Lines changed: 39 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,7 @@
4242
"\n",
4343
"pd.options.mode.chained_assignment = None\n",
4444
"\n",
45-
"modern_cmap = LinearSegmentedColormap.from_list(\n",
46-
" \"modern_cmap\", \n",
47-
" [\"#ffffff\", \"#003f5c\"], \n",
48-
" N=256\n",
49-
")\n",
45+
"modern_cmap = LinearSegmentedColormap.from_list(\"modern_cmap\", [\"#ffffff\", \"#003f5c\"], N=256)\n",
5046
"\n",
5147
"schneider_class_names = [\n",
5248
" \"Alductive amination\",\n",
@@ -124,9 +120,7 @@
124120
}
125121
],
126122
"source": [
127-
"df = pd.read_csv(\"yield_prediction_results.csv\", names=[\n",
128-
" \"data_set\", \"split\", \"filename\", \"ground_truth\", \"prediction\"\n",
129-
"])\n",
123+
"df = pd.read_csv(\"yield_prediction_results.csv\", names=[\"data_set\", \"split\", \"filename\", \"ground_truth\", \"prediction\"])\n",
130124
"\n",
131125
"df[\"error\"] = df.prediction - df.ground_truth\n",
132126
"\n",
@@ -181,6 +175,7 @@
181175
" fontfamily=font_family,\n",
182176
" )\n",
183177
"\n",
178+
"\n",
184179
"def calc_r2(df, verbose=True):\n",
185180
" result = []\n",
186181
" result_raw = []\n",
@@ -196,34 +191,31 @@
196191
" r2 = r2_score(df_tmp.ground_truth, df_tmp.prediction)\n",
197192
" result_raw.append({\"r2\": r2, \"split\": split})\n",
198193
" r2s.append(r2)\n",
199-
" \n",
194+
"\n",
200195
" if verbose:\n",
201196
" print(f\"r2 mean={round(sum(r2s) / len(r2s), 5)}, r2 std={round(stdev(r2s), 5)}\")\n",
202197
" result.append((sum(r2s) / len(r2s), stdev(r2s)))\n",
203198
"\n",
204199
" return (result, pd.DataFrame(result_raw))\n",
205200
"\n",
201+
"\n",
206202
"def scatter(df, ax, title):\n",
207203
" sns.kdeplot(\n",
208-
" data=df, \n",
209-
" x=\"ground_truth\", y=\"prediction\", clip=((0, 100), (None, None)),\n",
210-
" color=\"#003f5c\", levels=6, zorder=2, ax=ax\n",
204+
" data=df, x=\"ground_truth\", y=\"prediction\", clip=((0, 100), (None, None)), color=\"#003f5c\", levels=6, zorder=2, ax=ax\n",
211205
" )\n",
212206
"\n",
213207
" ax.plot(\n",
214-
" [0, 100], [0, 100], linewidth=2, \n",
215-
" color=\"#bc5090\", linestyle=\"dashed\",\n",
208+
" [0, 100],\n",
209+
" [0, 100],\n",
210+
" linewidth=2,\n",
211+
" color=\"#bc5090\",\n",
212+
" linestyle=\"dashed\",\n",
216213
" zorder=1,\n",
217214
" )\n",
218215
"\n",
219-
" sns.scatterplot(\n",
220-
" data=df, \n",
221-
" x=\"ground_truth\", y=\"prediction\",\n",
222-
" color=\"#cccccc\", linewidth=0,\n",
223-
" alpha=0.125, zorder=0, ax=ax\n",
224-
" )\n",
216+
" sns.scatterplot(data=df, x=\"ground_truth\", y=\"prediction\", color=\"#cccccc\", linewidth=0, alpha=0.125, zorder=0, ax=ax)\n",
225217
"\n",
226-
" ax.set(xlabel=\"Ground Truth\", ylabel='Prediction')\n",
218+
" ax.set(xlabel=\"Ground Truth\", ylabel=\"Prediction\")\n",
227219
" ax.set_title(title)"
228220
]
229221
},
@@ -329,36 +321,28 @@
329321
"splits = [98, 197, 395, 791, 1186, 1977, 2766]\n",
330322
"\n",
331323
"for i in range(7):\n",
332-
" scatter(\n",
333-
" df_buchwald_hartwig_cv[df_buchwald_hartwig_cv.split == splits[i]],\n",
334-
" axs.flat[i], titles[i]\n",
335-
" )\n",
324+
" scatter(df_buchwald_hartwig_cv[df_buchwald_hartwig_cv.split == splits[i]], axs.flat[i], titles[i])\n",
336325
"\n",
337326
"_, df_results = calc_r2(df_buchwald_hartwig_cv, verbose=False)\n",
338327
"\n",
339328
"sns.stripplot(\n",
340-
" x=\"split\", y=\"r2\", data=df_results, linewidth=1, ax=axs.flat[7],\n",
341-
" palette=[\"#003f5c\", \"#374c80\", \"#7a5195\", \"#bc5090\", \"#ef5675\", \"#ff764a\", \"#ffa600\"]\n",
329+
" x=\"split\",\n",
330+
" y=\"r2\",\n",
331+
" data=df_results,\n",
332+
" linewidth=1,\n",
333+
" ax=axs.flat[7],\n",
334+
" palette=[\"#003f5c\", \"#374c80\", \"#7a5195\", \"#bc5090\", \"#ef5675\", \"#ff764a\", \"#ffa600\"],\n",
342335
")\n",
343336
"axs.flat[7].set_xticklabels([\"a\", \"b\", \"c\", \"d\", \"e\", \"f\", \"g\"])\n",
344337
"axs.flat[7].set(xlabel=\"Split\", ylabel=\"Accuracy\")\n",
345338
"\n",
346-
"titles = [\n",
347-
" \"Out-of-sample Split 1\", \"Out-of-sample Split 2\", \n",
348-
" \"Out-of-sample Split 3\", \"Out-of-sample Split 4\"\n",
349-
"]\n",
339+
"titles = [\"Out-of-sample Split 1\", \"Out-of-sample Split 2\", \"Out-of-sample Split 3\", \"Out-of-sample Split 4\"]\n",
350340
"\n",
351-
"splits = [\n",
352-
" \"Test1-2048-3-true.pkl\", \"Test2-2048-3-true.pkl\", \n",
353-
" \"Test3-2048-3-true.pkl\", \"Test4-2048-3-true.pkl\"\n",
354-
"]\n",
341+
"splits = [\"Test1-2048-3-true.pkl\", \"Test2-2048-3-true.pkl\", \"Test3-2048-3-true.pkl\", \"Test4-2048-3-true.pkl\"]\n",
355342
"\n",
356343
"j = 0\n",
357344
"for i in range(8, 12):\n",
358-
" scatter(\n",
359-
" df_buchwald_hartwig_tests[df_buchwald_hartwig_tests.split == splits[j]],\n",
360-
" axs.flat[i], titles[j]\n",
361-
" )\n",
345+
" scatter(df_buchwald_hartwig_tests[df_buchwald_hartwig_tests.split == splits[j]], axs.flat[i], titles[j])\n",
362346
" j += 1\n",
363347
"\n",
364348
"index_subplots(axs.flat, font_size=14, y=1.17)\n",
@@ -401,7 +385,7 @@
401385
"\n",
402386
"plt_cm = []\n",
403387
"for i in cm.classes:\n",
404-
" row=[]\n",
388+
" row = []\n",
405389
" for j in cm.classes:\n",
406390
" row.append(cm.table[i][j])\n",
407391
" plt_cm.append(row)\n",
@@ -414,9 +398,7 @@
414398
"\n",
415399
"\n",
416400
"sns.heatmap(\n",
417-
" plt_cm, cmap=\"RdPu\", linewidths=.1, linecolor=\"#eeeeee\", square=True, \n",
418-
" cbar_kws={\"shrink\": 0.5}, norm=LogNorm(),\n",
419-
" ax=ax\n",
401+
" plt_cm, cmap=\"RdPu\", linewidths=0.1, linecolor=\"#eeeeee\", square=True, cbar_kws={\"shrink\": 0.5}, norm=LogNorm(), ax=ax\n",
420402
")\n",
421403
"\n",
422404
"cax = plt.gcf().axes[-1]\n",
@@ -491,8 +473,17 @@
491473
"y.extend(y_train)\n",
492474
"y.extend(y_test)\n",
493475
"\n",
494-
"labels = {\"1\": \"Heteroatom alkylation and arylation\", \"2\": \"Acylation and related processes\", \"3\": \"C-C bond formation\", \"5\": \"Protections\", \"6\": \"Deprotections\",\n",
495-
" \"7\": \"Reductions\", \"8\": \"Oxidations\", \"9\": \"Functional group interconversion (FGI)\", \"10\": \"Functional group addition (FGA)\"}\n",
476+
"labels = {\n",
477+
" \"1\": \"Heteroatom alkylation and arylation\",\n",
478+
" \"2\": \"Acylation and related processes\",\n",
479+
" \"3\": \"C-C bond formation\",\n",
480+
" \"5\": \"Protections\",\n",
481+
" \"6\": \"Deprotections\",\n",
482+
" \"7\": \"Reductions\",\n",
483+
" \"8\": \"Oxidations\",\n",
484+
" \"9\": \"Functional group interconversion (FGI)\",\n",
485+
" \"10\": \"Functional group addition (FGA)\",\n",
486+
"}\n",
496487
"\n",
497488
"y_values = [labels[ytem.split(\".\")[0]] for ytem in y]\n",
498489
"\n",
@@ -535,24 +526,16 @@
535526
" \"#595959\",\n",
536527
" \"#5f9ed1\",\n",
537528
" \"#c85300\",\n",
538-
" #\"#898989\",\n",
529+
" # \"#898989\",\n",
539530
" \"#a2c8ec\",\n",
540531
" \"#ffbc79\",\n",
541-
" \"#cfcfcf\"\n",
532+
" \"#cfcfcf\",\n",
542533
"]\n",
543534
"\n",
544535
"df_tmap = pd.DataFrame({\"x\": x, \"y\": y, \"c\": y_values})\n",
545536
"sns.scatterplot(x=\"x\", y=\"y\", hue=\"c\", data=df_tmap, s=5.0, palette=palette, ax=ax, zorder=2)\n",
546537
"\n",
547-
"legend = ax.legend(\n",
548-
" loc=\"center left\", \n",
549-
" bbox_to_anchor=(1, 0.5),\n",
550-
" fancybox=False, \n",
551-
" shadow=False, \n",
552-
" frameon=False,\n",
553-
" ncol=1,\n",
554-
" fontsize=7\n",
555-
")\n",
538+
"legend = ax.legend(loc=\"center left\", bbox_to_anchor=(1, 0.5), fancybox=False, shadow=False, frameon=False, ncol=1, fontsize=7)\n",
556539
"\n",
557540
"for handle in legend.legendHandles:\n",
558541
" handle.set_sizes([12.0])\n",

0 commit comments

Comments (0)