|
42 | 42 | "\n", |
43 | 43 | "pd.options.mode.chained_assignment = None\n", |
44 | 44 | "\n", |
45 | | - "modern_cmap = LinearSegmentedColormap.from_list(\n", |
46 | | - " \"modern_cmap\", \n", |
47 | | - " [\"#ffffff\", \"#003f5c\"], \n", |
48 | | - " N=256\n", |
49 | | - ")\n", |
| 45 | + "modern_cmap = LinearSegmentedColormap.from_list(\"modern_cmap\", [\"#ffffff\", \"#003f5c\"], N=256)\n", |
50 | 46 | "\n", |
51 | 47 | "schneider_class_names = [\n", |
52 | 48 | " \"Alductive amination\",\n", |
|
124 | 120 | } |
125 | 121 | ], |
126 | 122 | "source": [ |
127 | | - "df = pd.read_csv(\"yield_prediction_results.csv\", names=[\n", |
128 | | - " \"data_set\", \"split\", \"filename\", \"ground_truth\", \"prediction\"\n", |
129 | | - "])\n", |
| 123 | + "df = pd.read_csv(\"yield_prediction_results.csv\", names=[\"data_set\", \"split\", \"filename\", \"ground_truth\", \"prediction\"])\n", |
130 | 124 | "\n", |
131 | 125 | "df[\"error\"] = df.prediction - df.ground_truth\n", |
132 | 126 | "\n", |
|
181 | 175 | " fontfamily=font_family,\n", |
182 | 176 | " )\n", |
183 | 177 | "\n", |
| 178 | + "\n", |
184 | 179 | "def calc_r2(df, verbose=True):\n", |
185 | 180 | " result = []\n", |
186 | 181 | " result_raw = []\n", |
|
196 | 191 | " r2 = r2_score(df_tmp.ground_truth, df_tmp.prediction)\n", |
197 | 192 | " result_raw.append({\"r2\": r2, \"split\": split})\n", |
198 | 193 | " r2s.append(r2)\n", |
199 | | - " \n", |
| 194 | + "\n", |
200 | 195 | " if verbose:\n", |
201 | 196 | " print(f\"r2 mean={round(sum(r2s) / len(r2s), 5)}, r2 std={round(stdev(r2s), 5)}\")\n", |
202 | 197 | " result.append((sum(r2s) / len(r2s), stdev(r2s)))\n", |
203 | 198 | "\n", |
204 | 199 | " return (result, pd.DataFrame(result_raw))\n", |
205 | 200 | "\n", |
| 201 | + "\n", |
206 | 202 | "def scatter(df, ax, title):\n", |
207 | 203 | " sns.kdeplot(\n", |
208 | | - " data=df, \n", |
209 | | - " x=\"ground_truth\", y=\"prediction\", clip=((0, 100), (None, None)),\n", |
210 | | - " color=\"#003f5c\", levels=6, zorder=2, ax=ax\n", |
| 204 | + " data=df, x=\"ground_truth\", y=\"prediction\", clip=((0, 100), (None, None)), color=\"#003f5c\", levels=6, zorder=2, ax=ax\n", |
211 | 205 | " )\n", |
212 | 206 | "\n", |
213 | 207 | " ax.plot(\n", |
214 | | - " [0, 100], [0, 100], linewidth=2, \n", |
215 | | - " color=\"#bc5090\", linestyle=\"dashed\",\n", |
| 208 | + " [0, 100],\n", |
| 209 | + " [0, 100],\n", |
| 210 | + " linewidth=2,\n", |
| 211 | + " color=\"#bc5090\",\n", |
| 212 | + " linestyle=\"dashed\",\n", |
216 | 213 | " zorder=1,\n", |
217 | 214 | " )\n", |
218 | 215 | "\n", |
219 | | - " sns.scatterplot(\n", |
220 | | - " data=df, \n", |
221 | | - " x=\"ground_truth\", y=\"prediction\",\n", |
222 | | - " color=\"#cccccc\", linewidth=0,\n", |
223 | | - " alpha=0.125, zorder=0, ax=ax\n", |
224 | | - " )\n", |
| 216 | + " sns.scatterplot(data=df, x=\"ground_truth\", y=\"prediction\", color=\"#cccccc\", linewidth=0, alpha=0.125, zorder=0, ax=ax)\n", |
225 | 217 | "\n", |
226 | | - " ax.set(xlabel=\"Ground Truth\", ylabel='Prediction')\n", |
| 218 | + " ax.set(xlabel=\"Ground Truth\", ylabel=\"Prediction\")\n", |
227 | 219 | " ax.set_title(title)" |
228 | 220 | ] |
229 | 221 | }, |
|
329 | 321 | "splits = [98, 197, 395, 791, 1186, 1977, 2766]\n", |
330 | 322 | "\n", |
331 | 323 | "for i in range(7):\n", |
332 | | - " scatter(\n", |
333 | | - " df_buchwald_hartwig_cv[df_buchwald_hartwig_cv.split == splits[i]],\n", |
334 | | - " axs.flat[i], titles[i]\n", |
335 | | - " )\n", |
| 324 | + " scatter(df_buchwald_hartwig_cv[df_buchwald_hartwig_cv.split == splits[i]], axs.flat[i], titles[i])\n", |
336 | 325 | "\n", |
337 | 326 | "_, df_results = calc_r2(df_buchwald_hartwig_cv, verbose=False)\n", |
338 | 327 | "\n", |
339 | 328 | "sns.stripplot(\n", |
340 | | - " x=\"split\", y=\"r2\", data=df_results, linewidth=1, ax=axs.flat[7],\n", |
341 | | - " palette=[\"#003f5c\", \"#374c80\", \"#7a5195\", \"#bc5090\", \"#ef5675\", \"#ff764a\", \"#ffa600\"]\n", |
| 329 | + " x=\"split\",\n", |
| 330 | + " y=\"r2\",\n", |
| 331 | + " data=df_results,\n", |
| 332 | + " linewidth=1,\n", |
| 333 | + " ax=axs.flat[7],\n", |
| 334 | + " palette=[\"#003f5c\", \"#374c80\", \"#7a5195\", \"#bc5090\", \"#ef5675\", \"#ff764a\", \"#ffa600\"],\n", |
342 | 335 | ")\n", |
343 | 336 | "axs.flat[7].set_xticklabels([\"a\", \"b\", \"c\", \"d\", \"e\", \"f\", \"g\"])\n", |
344 | 337 | "axs.flat[7].set(xlabel=\"Split\", ylabel=\"Accuracy\")\n", |
345 | 338 | "\n", |
346 | | - "titles = [\n", |
347 | | - " \"Out-of-sample Split 1\", \"Out-of-sample Split 2\", \n", |
348 | | - " \"Out-of-sample Split 3\", \"Out-of-sample Split 4\"\n", |
349 | | - "]\n", |
| 339 | + "titles = [\"Out-of-sample Split 1\", \"Out-of-sample Split 2\", \"Out-of-sample Split 3\", \"Out-of-sample Split 4\"]\n", |
350 | 340 | "\n", |
351 | | - "splits = [\n", |
352 | | - " \"Test1-2048-3-true.pkl\", \"Test2-2048-3-true.pkl\", \n", |
353 | | - " \"Test3-2048-3-true.pkl\", \"Test4-2048-3-true.pkl\"\n", |
354 | | - "]\n", |
| 341 | + "splits = [\"Test1-2048-3-true.pkl\", \"Test2-2048-3-true.pkl\", \"Test3-2048-3-true.pkl\", \"Test4-2048-3-true.pkl\"]\n", |
355 | 342 | "\n", |
356 | 343 | "j = 0\n", |
357 | 344 | "for i in range(8, 12):\n", |
358 | | - " scatter(\n", |
359 | | - " df_buchwald_hartwig_tests[df_buchwald_hartwig_tests.split == splits[j]],\n", |
360 | | - " axs.flat[i], titles[j]\n", |
361 | | - " )\n", |
| 345 | + " scatter(df_buchwald_hartwig_tests[df_buchwald_hartwig_tests.split == splits[j]], axs.flat[i], titles[j])\n", |
362 | 346 | " j += 1\n", |
363 | 347 | "\n", |
364 | 348 | "index_subplots(axs.flat, font_size=14, y=1.17)\n", |
|
401 | 385 | "\n", |
402 | 386 | "plt_cm = []\n", |
403 | 387 | "for i in cm.classes:\n", |
404 | | - " row=[]\n", |
| 388 | + " row = []\n", |
405 | 389 | " for j in cm.classes:\n", |
406 | 390 | " row.append(cm.table[i][j])\n", |
407 | 391 | " plt_cm.append(row)\n", |
|
414 | 398 | "\n", |
415 | 399 | "\n", |
416 | 400 | "sns.heatmap(\n", |
417 | | - " plt_cm, cmap=\"RdPu\", linewidths=.1, linecolor=\"#eeeeee\", square=True, \n", |
418 | | - " cbar_kws={\"shrink\": 0.5}, norm=LogNorm(),\n", |
419 | | - " ax=ax\n", |
| 401 | + " plt_cm, cmap=\"RdPu\", linewidths=0.1, linecolor=\"#eeeeee\", square=True, cbar_kws={\"shrink\": 0.5}, norm=LogNorm(), ax=ax\n", |
420 | 402 | ")\n", |
421 | 403 | "\n", |
422 | 404 | "cax = plt.gcf().axes[-1]\n", |
|
491 | 473 | "y.extend(y_train)\n", |
492 | 474 | "y.extend(y_test)\n", |
493 | 475 | "\n", |
494 | | - "labels = {\"1\": \"Heteroatom alkylation and arylation\", \"2\": \"Acylation and related processes\", \"3\": \"C-C bond formation\", \"5\": \"Protections\", \"6\": \"Deprotections\",\n", |
495 | | - " \"7\": \"Reductions\", \"8\": \"Oxidations\", \"9\": \"Functional group interconversion (FGI)\", \"10\": \"Functional group addition (FGA)\"}\n", |
| 476 | + "labels = {\n", |
| 477 | + " \"1\": \"Heteroatom alkylation and arylation\",\n", |
| 478 | + " \"2\": \"Acylation and related processes\",\n", |
| 479 | + " \"3\": \"C-C bond formation\",\n", |
| 480 | + " \"5\": \"Protections\",\n", |
| 481 | + " \"6\": \"Deprotections\",\n", |
| 482 | + " \"7\": \"Reductions\",\n", |
| 483 | + " \"8\": \"Oxidations\",\n", |
| 484 | + " \"9\": \"Functional group interconversion (FGI)\",\n", |
| 485 | + " \"10\": \"Functional group addition (FGA)\",\n", |
| 486 | + "}\n", |
496 | 487 | "\n", |
497 | 488 | "y_values = [labels[ytem.split(\".\")[0]] for ytem in y]\n", |
498 | 489 | "\n", |
|
535 | 526 | " \"#595959\",\n", |
536 | 527 | " \"#5f9ed1\",\n", |
537 | 528 | " \"#c85300\",\n", |
538 | | - " #\"#898989\",\n", |
| 529 | + " # \"#898989\",\n", |
539 | 530 | " \"#a2c8ec\",\n", |
540 | 531 | " \"#ffbc79\",\n", |
541 | | - " \"#cfcfcf\"\n", |
| 532 | + " \"#cfcfcf\",\n", |
542 | 533 | "]\n", |
543 | 534 | "\n", |
544 | 535 | "df_tmap = pd.DataFrame({\"x\": x, \"y\": y, \"c\": y_values})\n", |
545 | 536 | "sns.scatterplot(x=\"x\", y=\"y\", hue=\"c\", data=df_tmap, s=5.0, palette=palette, ax=ax, zorder=2)\n", |
546 | 537 | "\n", |
547 | | - "legend = ax.legend(\n", |
548 | | - " loc=\"center left\", \n", |
549 | | - " bbox_to_anchor=(1, 0.5),\n", |
550 | | - " fancybox=False, \n", |
551 | | - " shadow=False, \n", |
552 | | - " frameon=False,\n", |
553 | | - " ncol=1,\n", |
554 | | - " fontsize=7\n", |
555 | | - ")\n", |
| 538 | + "legend = ax.legend(loc=\"center left\", bbox_to_anchor=(1, 0.5), fancybox=False, shadow=False, frameon=False, ncol=1, fontsize=7)\n", |
556 | 539 | "\n", |
557 | 540 | "for handle in legend.legendHandles:\n", |
558 | 541 | " handle.set_sizes([12.0])\n", |
|
0 commit comments