|
17 | 17 | " Trainer,\n", |
18 | 18 | " AutoModelForSequenceClassification,\n", |
19 | 19 | " TrainingArguments,\n", |
20 | | - " EarlyStoppingCallback\n", |
| 20 | + " EarlyStoppingCallback,\n", |
21 | 21 | ")\n", |
22 | 22 | "from datasets import Dataset as HFDataset\n", |
23 | 23 | "from evaluate import load as load_metric\n", |
|
28 | 28 | "\n", |
29 | 29 | "\n", |
30 | 30 | "from huggingface_hub.utils import disable_progress_bars\n", |
| 31 | + "\n", |
31 | 32 | "disable_progress_bars()\n", |
32 | 33 | "\n", |
33 | 34 | "import os\n", |
| 35 | + "\n", |
34 | 36 | "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", |
35 | 37 | "os.environ[\"TRANSFORMERS_NO_ADVISORY_WARNINGS\"] = \"true\"" |
36 | 38 | ] |
|
57 | 59 | "outputs": [], |
58 | 60 | "source": [ |
59 | 61 | "ita = pd.read_csv(\"dataset/augmented_it.csv\")\n", |
60 | | - "#For the moment, it works only on italian\n", |
| 62 | + "# For the moment, it works only on Italian\n", |
61 | 63 | "dataset = pd.concat([ita])\n", |
62 | | - "dataset['bio'] = dataset['bio'].fillna('')" |
| 64 | + "dataset[\"bio\"] = dataset[\"bio\"].fillna(\"\")" |
63 | 65 | ] |
64 | 66 | }, |
65 | 67 | { |
|
243 | 245 | "metadata": {}, |
244 | 246 | "outputs": [], |
245 | 247 | "source": [ |
246 | | - "pre_train_df, pre_test_df = train_test_split(\n", |
247 | | - " dataset,\n", |
248 | | - " test_size=0.3,\n", |
249 | | - " stratify=dataset[\"lgbt\"],\n", |
250 | | - " random_state=42\n", |
251 | | - ")" |
| 248 | + "pre_train_df, pre_test_df = train_test_split(dataset, test_size=0.3, stratify=dataset[\"lgbt\"], random_state=42)" |
252 | 249 | ] |
253 | 250 | }, |
254 | 251 | { |
|
283 | 280 | " batch[\"bio\"],\n", |
284 | 281 | " truncation=True,\n", |
285 | 282 | " padding=\"max_length\",\n", |
286 | | - " max_length=128, # Lunghezza massima per i testi\n", |
| 283 | + " max_length=128, # Lunghezza massima per i testi\n", |
287 | 284 | " )" |
288 | 285 | ] |
289 | 286 | }, |
|
645 | 642 | "source": [ |
646 | 643 | "class DualEncoderForSequenceClassification(PreTrainedModel):\n", |
647 | 644 | " config_class = AutoConfig\n", |
| 645 | + "\n", |
648 | 646 | " def __init__(self, config):\n", |
649 | 647 | " super().__init__(config)\n", |
650 | 648 | " self.num_labels = config.num_labels\n", |
|
654 | 652 | "\n", |
655 | 653 | " # Gating layer: it weights the two source of informations for the final classification\n", |
656 | 654 | " self.gate_layer = nn.Sequential(\n", |
657 | | - " nn.Linear(hidden_size * 2, hidden_size),\n", |
658 | | - " nn.Tanh(),\n", |
659 | | - " nn.Linear(hidden_size, hidden_size),\n", |
660 | | - " nn.Sigmoid()\n", |
| 655 | + " nn.Linear(hidden_size * 2, hidden_size), nn.Tanh(), nn.Linear(hidden_size, hidden_size), nn.Sigmoid()\n", |
661 | 656 | " )\n", |
662 | 657 | "\n", |
663 | 658 | " # Final classifier\n", |
|
686 | 681 | " return_dict=return_dict,\n", |
687 | 682 | " )\n", |
688 | 683 | "\n", |
689 | | - " #Obtain last hidden state of the bert model\n", |
| 684 | + " # Obtain last hidden state of the bert model\n", |
690 | 685 | " h_text = outputs_text.last_hidden_state[:, 0]\n", |
691 | 686 | " h_bio = outputs_bio.last_hidden_state[:, 0]\n", |
692 | 687 | "\n", |
693 | | - " #Concat the hidden states\n", |
| 688 | + " # Concat the hidden states\n", |
694 | 689 | " combined = torch.cat((h_text, h_bio), dim=-1)\n", |
695 | | - " #Gate the informations\n", |
696 | | - " gate = self.gate_layer(combined) \n", |
| 690 | + " # Gate the information\n", |
| 691 | + " gate = self.gate_layer(combined)\n", |
697 | 692 | " h_final = gate * h_text + (1 - gate) * h_bio\n", |
698 | 693 | "\n", |
699 | 694 | " # Classification\n", |
|
703 | 698 | " # Loss computation\n", |
704 | 699 | " loss = None\n", |
705 | 700 | " if labels is not None:\n", |
706 | | - " if hasattr(self.config, 'class_weights') and self.config.class_weights is not None:\n", |
| 701 | + " if hasattr(self.config, \"class_weights\") and self.config.class_weights is not None:\n", |
707 | 702 | " class_weights = torch.tensor(self.config.class_weights, device=self.device)\n", |
708 | 703 | " else:\n", |
709 | 704 | " class_weights = None\n", |
|
727 | 722 | "metadata": {}, |
728 | 723 | "outputs": [], |
729 | 724 | "source": [ |
730 | | - "train_df, test_df = train_test_split(\n", |
731 | | - " dataset,\n", |
732 | | - " test_size=0.3,\n", |
733 | | - " stratify=dataset[\"label\"],\n", |
734 | | - " random_state=42\n", |
735 | | - ")" |
| 725 | + "train_df, test_df = train_test_split(dataset, test_size=0.3, stratify=dataset[\"label\"], random_state=42)" |
736 | 726 | ] |
737 | 727 | }, |
738 | 728 | { |
|
864 | 854 | "source": [ |
865 | 855 | "del trainer.model, trainer\n", |
866 | 856 | "import gc\n", |
| 857 | + "\n", |
867 | 858 | "gc.collect()\n", |
868 | 859 | "torch.cuda.empty_cache()" |
869 | 860 | ] |
|
1009 | 1000 | " save_strategy=\"epoch\",\n", |
1010 | 1001 | " learning_rate=2e-5,\n", |
1011 | 1002 | " per_device_train_batch_size=16,\n", |
1012 | | - " #gradient_accumulation_steps=2, #Since higly unbalanced, this should provide also negative examples\n", |
| 1003 | + " # gradient_accumulation_steps=2, #Since highly unbalanced, this should also provide negative examples\n", |
1013 | 1004 | " per_device_eval_batch_size=4,\n", |
1014 | 1005 | " num_train_epochs=8,\n", |
1015 | 1006 | " weight_decay=0.1,\n", |
|
1018 | 1009 | " logging_dir=\"./logs_weighted\",\n", |
1019 | 1010 | " logging_steps=50,\n", |
1020 | 1011 | " save_total_limit=2,\n", |
1021 | | - " #label_smoothing_factor=0.1,\n", |
| 1012 | + " # label_smoothing_factor=0.1,\n", |
1022 | 1013 | ")" |
1023 | 1014 | ] |
1024 | 1015 | }, |
|
1053 | 1044 | " eval_dataset=test_ds,\n", |
1054 | 1045 | " tokenizer=tokenizer,\n", |
1055 | 1046 | " compute_metrics=compute_metrics,\n", |
1056 | | - " callbacks=[EarlyStoppingCallback(early_stopping_patience=3)] \n", |
| 1047 | + " callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],\n", |
1057 | 1048 | ")" |
1058 | 1049 | ] |
1059 | 1050 | }, |
|
1346 | 1337 | "true_labels = predictions_output.label_ids\n", |
1347 | 1338 | "predicted_labels = np.argmax(logits, axis=-1)\n", |
1348 | 1339 | "\n", |
1349 | | - "# Compute the confidence of the model \n", |
| 1340 | + "# Compute the confidence of the model\n", |
1350 | 1341 | "probabilities = softmax(logits, axis=1)\n", |
1351 | 1342 | "confidence_scores = np.max(probabilities, axis=1)\n", |
1352 | 1343 | "\n", |
1353 | 1344 | "results_df = test_df.copy()\n", |
1354 | 1345 | "\n", |
1355 | | - "results_df['predicted_label'] = predicted_labels\n", |
1356 | | - "results_df['true_label'] = true_labels\n", |
1357 | | - "results_df['confidence'] = confidence_scores\n", |
| 1346 | + "results_df[\"predicted_label\"] = predicted_labels\n", |
| 1347 | + "results_df[\"true_label\"] = true_labels\n", |
| 1348 | + "results_df[\"confidence\"] = confidence_scores\n", |
1358 | 1349 | "\n", |
1359 | | - "results_df['is_correct'] = (results_df['true_label'] == results_df['predicted_label'])\n", |
| 1350 | + "results_df[\"is_correct\"] = results_df[\"true_label\"] == results_df[\"predicted_label\"]\n", |
1360 | 1351 | "\n", |
1361 | 1352 | "output_filename = \"error_analysis_results.csv\"\n", |
1362 | | - "results_df.to_csv(output_filename, index=False, encoding='utf-8-sig')\n", |
| 1353 | + "results_df.to_csv(output_filename, index=False, encoding=\"utf-8-sig\")\n", |
1363 | 1354 | "\n", |
1364 | 1355 | "print(f\"\\nRisultati salvati in '{output_filename}'\")\n", |
1365 | 1356 | "print(\"\\nAnteprima del DataFrame con i risultati:\")\n", |
|
1471 | 1462 | ], |
1472 | 1463 | "source": [ |
1473 | 1464 | "# Shows only errors\n", |
1474 | | - "errors_df = results_df[results_df['is_correct'] == False].copy()\n", |
| 1465 | + "errors_df = results_df[results_df[\"is_correct\"] == False].copy()\n", |
1475 | 1466 | "\n", |
1476 | 1467 | "print(f\"{len(errors_df)}/{len(results_df)} errors.\")\n", |
1477 | 1468 | "\n", |
| 1469 | + "\n", |
1478 | 1470 | "def print_error_details(dataframe):\n", |
1479 | 1471 | " if dataframe.empty:\n", |
1480 | 1472 | " return\n", |
1481 | | - " \n", |
| 1473 | + "\n", |
1482 | 1474 | " for index, row in dataframe.iterrows():\n", |
1483 | 1475 | " print(\"-\" * 50)\n", |
1484 | 1476 | " print(f\"Confidence: {row['confidence']:.2%}\")\n", |
|
1487 | 1479 | " print(f\"Text:\\n\\\"{row['text']}\\\"\")\n", |
1488 | 1480 | " print(\"-\" * 50 + \"\\n\")\n", |
1489 | 1481 | "\n", |
| 1482 | + "\n", |
1490 | 1483 | "# Errors with high confidence\n", |
1491 | 1484 | "N = 5\n", |
1492 | | - "high_confidence_errors = errors_df.sort_values(by='confidence', ascending=False).head(N)\n", |
| 1485 | + "high_confidence_errors = errors_df.sort_values(by=\"confidence\", ascending=False).head(N)\n", |
1493 | 1486 | "\n", |
1494 | | - "print(\"\\n\" + \"=\"*20 + \" TOP 5 ERRORS WITH HIGH CONFIDENCE \" + \"=\"*20)\n", |
| 1487 | + "print(\"\\n\" + \"=\" * 20 + \" TOP 5 ERRORS WITH HIGH CONFIDENCE \" + \"=\" * 20)\n", |
1495 | 1488 | "print_error_details(high_confidence_errors)\n", |
1496 | 1489 | "\n", |
1497 | 1490 | "\n", |
1498 | | - "low_confidence_errors = errors_df.sort_values(by='confidence', ascending=True).head(N)\n", |
1499 | | - "print(\"\\n\" + \"=\"*20 + \" TOP 5 ERRORS LOW CONFIDENCE \" + \"=\"*20)\n", |
1500 | | - "print_error_details(low_confidence_errors)\n" |
| 1491 | + "low_confidence_errors = errors_df.sort_values(by=\"confidence\", ascending=True).head(N)\n", |
| 1492 | + "print(\"\\n\" + \"=\" * 20 + \" TOP 5 ERRORS LOW CONFIDENCE \" + \"=\" * 20)\n", |
| 1493 | + "print_error_details(low_confidence_errors)" |
1501 | 1494 | ] |
1502 | 1495 | }, |
1503 | 1496 | { |
|
1524 | 1517 | "\n", |
1525 | 1518 | "cm = confusion_matrix(true_labels, predicted_labels)\n", |
1526 | 1519 | "\n", |
1527 | | - "class_labels = ['Non-Reclamatory', 'Reclamatory']\n", |
| 1520 | + "class_labels = [\"Non-Reclamatory\", \"Reclamatory\"]\n", |
1528 | 1521 | "\n", |
1529 | 1522 | "plt.figure(figsize=(8, 6))\n", |
1530 | | - "sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', \n", |
1531 | | - " xticklabels=class_labels, yticklabels=class_labels)\n", |
| 1523 | + "sns.heatmap(cm, annot=True, fmt=\"d\", cmap=\"Blues\", xticklabels=class_labels, yticklabels=class_labels)\n", |
1532 | 1524 | "\n", |
1533 | | - "plt.title('Confusion matrix')\n", |
1534 | | - "plt.ylabel('True Label')\n", |
1535 | | - "plt.xlabel('Predicted Label')\n", |
| 1525 | + "plt.title(\"Confusion matrix\")\n", |
| 1526 | + "plt.ylabel(\"True Label\")\n", |
| 1527 | + "plt.xlabel(\"Predicted Label\")\n", |
1536 | 1528 | "\n", |
1537 | 1529 | "plt.show()" |
1538 | 1530 | ] |
|
0 commit comments