Skip to content

Commit 6345afe

Browse files
committed
Update with tensorboard logging and black formatting for .py and .ipynb
1 parent 6abe498 commit 6345afe

10 files changed

Lines changed: 1711 additions & 666 deletions

.python-version

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
3.10

confusionmatrix.png

-2.05 KB
Loading

dataset_augmentation_esp.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,7 @@
99
### Setup ###
1010
console = Console()
1111
load_dotenv()
12-
client = OpenAI(
13-
api_key=os.environ.get('DEEPSEEK_API_KEY'),
14-
base_url="https://api.deepseek.com"
15-
)
12+
client = OpenAI(api_key=os.environ.get("DEEPSEEK_API_KEY"), base_url="https://api.deepseek.com")
1613

1714
system_prompt = """Eres un anotador para una tarea de clasificación. En la entrada recibirás la biografía de un usuario de Twitter y algunos de sus tuits.
1815
Tu tarea es decidir si el usuario en cuestión forma parte o no de la comunidad LGBT.
@@ -36,31 +33,28 @@
3633
iterable = zip(ita["text"], ita["bio"])
3734

3835
for text, bio in track(iterable, description="[cyan]Processing entries...[/cyan]", total=len(ita)):
39-
36+
4037
bio = "" if str(bio) == "nan" else bio
4138
user_message = f'"{text}" - "{bio}"'
42-
39+
4340
try:
4441
response = client.chat.completions.create(
4542
model="deepseek-chat",
46-
messages=[
47-
{"role": "system", "content": system_prompt},
48-
{"role": "user", "content": user_message}
49-
],
50-
stream=False
43+
messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": user_message}],
44+
stream=False,
5145
)
5246
answer = response.choices[0].message.content.strip()
5347
except Exception as e:
5448
console.print(f"[bold red]An error occurred: {e}\nRetrying in 2 minutes...[/bold red]")
5549
answer = "error"
5650
time.sleep(120)
57-
51+
5852
if answer in ["1", "0"]:
5953
answer = int(answer)
6054
else:
6155
# Mark ambiguous/error responses
6256
answer = 0.5
63-
57+
6458
lgbt.append(answer)
6559
time.sleep(0.5)
6660

@@ -69,4 +63,4 @@
6963
console.print("\n[bold yellow]Saving augmented dataset...[/bold yellow]")
7064
ita.to_csv("augmented_es.csv", index=False)
7165
console.print("[bold green]:white_check_mark: File saved as [cyan]augmented_es.csv[/cyan][/bold green]")
72-
console.print(ita.head())
66+
console.print(ita.head())

dataset_augmentation_ita.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,7 @@
99
### Setup ###
1010
console = Console()
1111
load_dotenv()
12-
client = OpenAI(
13-
api_key=os.environ.get('DEEPSEEK_API_KEY'),
14-
base_url="https://api.deepseek.com"
15-
)
12+
client = OpenAI(api_key=os.environ.get("DEEPSEEK_API_KEY"), base_url="https://api.deepseek.com")
1613

1714
system_prompt = """Sei un annotatore per un task di classificazione. In input riceverai la bio di un utente Twitter e alcuni suoi tweet.
1815
Il tuo compito è decidere se l'utente in questione fa parte o meno della comunità LGBT.
@@ -36,31 +33,28 @@
3633
iterable = zip(ita["text"], ita["bio"])
3734

3835
for text, bio in track(iterable, description="[cyan]Processing entries...[/cyan]", total=len(ita)):
39-
36+
4037
bio = "" if str(bio) == "nan" else bio
4138
user_message = f'"{text}" - "{bio}"'
42-
39+
4340
try:
4441
response = client.chat.completions.create(
4542
model="deepseek-chat",
46-
messages=[
47-
{"role": "system", "content": system_prompt},
48-
{"role": "user", "content": user_message}
49-
],
50-
stream=False
43+
messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": user_message}],
44+
stream=False,
5145
)
5246
answer = response.choices[0].message.content.strip()
5347
except Exception as e:
5448
console.print(f"[bold red]An error occurred: {e}\nRetrying in 2 minutes...[/bold red]")
5549
answer = "error"
5650
time.sleep(120)
57-
51+
5852
if answer in ["1", "0"]:
5953
answer = int(answer)
6054
else:
6155
# Mark ambiguous/error responses
6256
answer = 0.5
63-
57+
6458
lgbt.append(answer)
6559
time.sleep(0.5)
6660

@@ -69,4 +63,4 @@
6963
console.print("\n[bold yellow]Saving augmented dataset...[/bold yellow]")
7064
ita.to_csv("augmented_it.csv", index=False)
7165
console.print("[bold green]:white_check_mark: File saved as [cyan]augmented_it.csv[/cyan][/bold green]")
72-
console.print(ita.head())
66+
console.print(ita.head())

double_pipeline.ipynb

Lines changed: 39 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
" Trainer,\n",
1818
" AutoModelForSequenceClassification,\n",
1919
" TrainingArguments,\n",
20-
" EarlyStoppingCallback\n",
20+
" EarlyStoppingCallback,\n",
2121
")\n",
2222
"from datasets import Dataset as HFDataset\n",
2323
"from evaluate import load as load_metric\n",
@@ -28,9 +28,11 @@
2828
"\n",
2929
"\n",
3030
"from huggingface_hub.utils import disable_progress_bars\n",
31+
"\n",
3132
"disable_progress_bars()\n",
3233
"\n",
3334
"import os\n",
35+
"\n",
3436
"os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
3537
"os.environ[\"TRANSFORMERS_NO_ADVISORY_WARNINGS\"] = \"true\""
3638
]
@@ -57,9 +59,9 @@
5759
"outputs": [],
5860
"source": [
5961
"ita = pd.read_csv(\"dataset/augmented_it.csv\")\n",
60-
"#For the moment, it works only on italian\n",
62+
"# For the moment, it works only on italian\n",
6163
"dataset = pd.concat([ita])\n",
62-
"dataset['bio'] = dataset['bio'].fillna('')"
64+
"dataset[\"bio\"] = dataset[\"bio\"].fillna(\"\")"
6365
]
6466
},
6567
{
@@ -243,12 +245,7 @@
243245
"metadata": {},
244246
"outputs": [],
245247
"source": [
246-
"pre_train_df, pre_test_df = train_test_split(\n",
247-
" dataset,\n",
248-
" test_size=0.3,\n",
249-
" stratify=dataset[\"lgbt\"],\n",
250-
" random_state=42\n",
251-
")"
248+
"pre_train_df, pre_test_df = train_test_split(dataset, test_size=0.3, stratify=dataset[\"lgbt\"], random_state=42)"
252249
]
253250
},
254251
{
@@ -283,7 +280,7 @@
283280
" batch[\"bio\"],\n",
284281
" truncation=True,\n",
285282
" padding=\"max_length\",\n",
286-
" max_length=128, # Lunghezza massima per i testi\n",
283+
" max_length=128, # Lunghezza massima per i testi\n",
287284
" )"
288285
]
289286
},
@@ -645,6 +642,7 @@
645642
"source": [
646643
"class DualEncoderForSequenceClassification(PreTrainedModel):\n",
647644
" config_class = AutoConfig\n",
645+
"\n",
648646
" def __init__(self, config):\n",
649647
" super().__init__(config)\n",
650648
" self.num_labels = config.num_labels\n",
@@ -654,10 +652,7 @@
654652
"\n",
655653
" # Gating layer: it weights the two source of informations for the final classification\n",
656654
" self.gate_layer = nn.Sequential(\n",
657-
" nn.Linear(hidden_size * 2, hidden_size),\n",
658-
" nn.Tanh(),\n",
659-
" nn.Linear(hidden_size, hidden_size),\n",
660-
" nn.Sigmoid()\n",
655+
" nn.Linear(hidden_size * 2, hidden_size), nn.Tanh(), nn.Linear(hidden_size, hidden_size), nn.Sigmoid()\n",
661656
" )\n",
662657
"\n",
663658
" # Final classifier\n",
@@ -686,14 +681,14 @@
686681
" return_dict=return_dict,\n",
687682
" )\n",
688683
"\n",
689-
" #Obtain last hidden state of the bert model\n",
684+
" # Obtain last hidden state of the bert model\n",
690685
" h_text = outputs_text.last_hidden_state[:, 0]\n",
691686
" h_bio = outputs_bio.last_hidden_state[:, 0]\n",
692687
"\n",
693-
" #Concat the hidden states\n",
688+
" # Concat the hidden states\n",
694689
" combined = torch.cat((h_text, h_bio), dim=-1)\n",
695-
" #Gate the informations\n",
696-
" gate = self.gate_layer(combined) \n",
690+
" # Gate the informations\n",
691+
" gate = self.gate_layer(combined)\n",
697692
" h_final = gate * h_text + (1 - gate) * h_bio\n",
698693
"\n",
699694
" # Classification\n",
@@ -703,7 +698,7 @@
703698
" # Loss computation\n",
704699
" loss = None\n",
705700
" if labels is not None:\n",
706-
" if hasattr(self.config, 'class_weights') and self.config.class_weights is not None:\n",
701+
" if hasattr(self.config, \"class_weights\") and self.config.class_weights is not None:\n",
707702
" class_weights = torch.tensor(self.config.class_weights, device=self.device)\n",
708703
" else:\n",
709704
" class_weights = None\n",
@@ -727,12 +722,7 @@
727722
"metadata": {},
728723
"outputs": [],
729724
"source": [
730-
"train_df, test_df = train_test_split(\n",
731-
" dataset,\n",
732-
" test_size=0.3,\n",
733-
" stratify=dataset[\"label\"],\n",
734-
" random_state=42\n",
735-
")"
725+
"train_df, test_df = train_test_split(dataset, test_size=0.3, stratify=dataset[\"label\"], random_state=42)"
736726
]
737727
},
738728
{
@@ -864,6 +854,7 @@
864854
"source": [
865855
"del trainer.model, trainer\n",
866856
"import gc\n",
857+
"\n",
867858
"gc.collect()\n",
868859
"torch.cuda.empty_cache()"
869860
]
@@ -1009,7 +1000,7 @@
10091000
" save_strategy=\"epoch\",\n",
10101001
" learning_rate=2e-5,\n",
10111002
" per_device_train_batch_size=16,\n",
1012-
" #gradient_accumulation_steps=2, #Since higly unbalanced, this should provide also negative examples\n",
1003+
" # gradient_accumulation_steps=2, #Since higly unbalanced, this should provide also negative examples\n",
10131004
" per_device_eval_batch_size=4,\n",
10141005
" num_train_epochs=8,\n",
10151006
" weight_decay=0.1,\n",
@@ -1018,7 +1009,7 @@
10181009
" logging_dir=\"./logs_weighted\",\n",
10191010
" logging_steps=50,\n",
10201011
" save_total_limit=2,\n",
1021-
" #label_smoothing_factor=0.1,\n",
1012+
" # label_smoothing_factor=0.1,\n",
10221013
")"
10231014
]
10241015
},
@@ -1053,7 +1044,7 @@
10531044
" eval_dataset=test_ds,\n",
10541045
" tokenizer=tokenizer,\n",
10551046
" compute_metrics=compute_metrics,\n",
1056-
" callbacks=[EarlyStoppingCallback(early_stopping_patience=3)] \n",
1047+
" callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],\n",
10571048
")"
10581049
]
10591050
},
@@ -1346,20 +1337,20 @@
13461337
"true_labels = predictions_output.label_ids\n",
13471338
"predicted_labels = np.argmax(logits, axis=-1)\n",
13481339
"\n",
1349-
"# Compute the confidence of the model \n",
1340+
"# Compute the confidence of the model\n",
13501341
"probabilities = softmax(logits, axis=1)\n",
13511342
"confidence_scores = np.max(probabilities, axis=1)\n",
13521343
"\n",
13531344
"results_df = test_df.copy()\n",
13541345
"\n",
1355-
"results_df['predicted_label'] = predicted_labels\n",
1356-
"results_df['true_label'] = true_labels\n",
1357-
"results_df['confidence'] = confidence_scores\n",
1346+
"results_df[\"predicted_label\"] = predicted_labels\n",
1347+
"results_df[\"true_label\"] = true_labels\n",
1348+
"results_df[\"confidence\"] = confidence_scores\n",
13581349
"\n",
1359-
"results_df['is_correct'] = (results_df['true_label'] == results_df['predicted_label'])\n",
1350+
"results_df[\"is_correct\"] = results_df[\"true_label\"] == results_df[\"predicted_label\"]\n",
13601351
"\n",
13611352
"output_filename = \"error_analysis_results.csv\"\n",
1362-
"results_df.to_csv(output_filename, index=False, encoding='utf-8-sig')\n",
1353+
"results_df.to_csv(output_filename, index=False, encoding=\"utf-8-sig\")\n",
13631354
"\n",
13641355
"print(f\"\\nRisultati salvati in '{output_filename}'\")\n",
13651356
"print(\"\\nAnteprima del DataFrame con i risultati:\")\n",
@@ -1471,14 +1462,15 @@
14711462
],
14721463
"source": [
14731464
"# Shows only errors\n",
1474-
"errors_df = results_df[results_df['is_correct'] == False].copy()\n",
1465+
"errors_df = results_df[results_df[\"is_correct\"] == False].copy()\n",
14751466
"\n",
14761467
"print(f\"{len(errors_df)}/{len(results_df)} errors.\")\n",
14771468
"\n",
1469+
"\n",
14781470
"def print_error_details(dataframe):\n",
14791471
" if dataframe.empty:\n",
14801472
" return\n",
1481-
" \n",
1473+
"\n",
14821474
" for index, row in dataframe.iterrows():\n",
14831475
" print(\"-\" * 50)\n",
14841476
" print(f\"Confidence: {row['confidence']:.2%}\")\n",
@@ -1487,17 +1479,18 @@
14871479
" print(f\"Text:\\n\\\"{row['text']}\\\"\")\n",
14881480
" print(\"-\" * 50 + \"\\n\")\n",
14891481
"\n",
1482+
"\n",
14901483
"# Errors with high confidence\n",
14911484
"N = 5\n",
1492-
"high_confidence_errors = errors_df.sort_values(by='confidence', ascending=False).head(N)\n",
1485+
"high_confidence_errors = errors_df.sort_values(by=\"confidence\", ascending=False).head(N)\n",
14931486
"\n",
1494-
"print(\"\\n\" + \"=\"*20 + \" TOP 5 ERRORS WITH HIGH CONFIDENCE \" + \"=\"*20)\n",
1487+
"print(\"\\n\" + \"=\" * 20 + \" TOP 5 ERRORS WITH HIGH CONFIDENCE \" + \"=\" * 20)\n",
14951488
"print_error_details(high_confidence_errors)\n",
14961489
"\n",
14971490
"\n",
1498-
"low_confidence_errors = errors_df.sort_values(by='confidence', ascending=True).head(N)\n",
1499-
"print(\"\\n\" + \"=\"*20 + \" TOP 5 ERRORS LOW CONFIDENCE \" + \"=\"*20)\n",
1500-
"print_error_details(low_confidence_errors)\n"
1491+
"low_confidence_errors = errors_df.sort_values(by=\"confidence\", ascending=True).head(N)\n",
1492+
"print(\"\\n\" + \"=\" * 20 + \" TOP 5 ERRORS LOW CONFIDENCE \" + \"=\" * 20)\n",
1493+
"print_error_details(low_confidence_errors)"
15011494
]
15021495
},
15031496
{
@@ -1524,15 +1517,14 @@
15241517
"\n",
15251518
"cm = confusion_matrix(true_labels, predicted_labels)\n",
15261519
"\n",
1527-
"class_labels = ['Non-Reclamatory', 'Reclamatory']\n",
1520+
"class_labels = [\"Non-Reclamatory\", \"Reclamatory\"]\n",
15281521
"\n",
15291522
"plt.figure(figsize=(8, 6))\n",
1530-
"sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', \n",
1531-
" xticklabels=class_labels, yticklabels=class_labels)\n",
1523+
"sns.heatmap(cm, annot=True, fmt=\"d\", cmap=\"Blues\", xticklabels=class_labels, yticklabels=class_labels)\n",
15321524
"\n",
1533-
"plt.title('Confusion matrix')\n",
1534-
"plt.ylabel('True Label')\n",
1535-
"plt.xlabel('Predicted Label')\n",
1525+
"plt.title(\"Confusion matrix\")\n",
1526+
"plt.ylabel(\"True Label\")\n",
1527+
"plt.xlabel(\"Predicted Label\")\n",
15361528
"\n",
15371529
"plt.show()"
15381530
]

0 commit comments

Comments
 (0)