Skip to content

Commit 1997a92

Browse files
antalvdbclaude
andcommitted
Fix set_epoch bug in Jupyter notebook
Add hasattr check before calling model.set_epoch() to prevent AttributeError when training baseline BERT models that don't have this method. Only BP-SOM models have set_epoch(). Bug found during Colab testing. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 651e2f5 commit 1997a92

1 file changed

Lines changed: 2 additions & 79 deletions

File tree

BPSOM_SST2_Experiment.ipynb

Lines changed: 2 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -605,84 +605,7 @@
605605
"execution_count": null,
606606
"metadata": {},
607607
"outputs": [],
608-
"source": [
609-
"def train_epoch(model, dataloader, optimizer, scheduler, device, epoch, max_epochs):\n",
610-
" \"\"\"Train for one epoch.\"\"\"\n",
611-
" model.train()\n",
612-
" model.set_epoch(epoch, max_epochs)\n",
613-
" \n",
614-
" total_loss = 0\n",
615-
" correct = 0\n",
616-
" total = 0\n",
617-
" \n",
618-
" progress_bar = tqdm(dataloader, desc=f\"Epoch {epoch+1}/{max_epochs} [Train]\")\n",
619-
" \n",
620-
" for batch in progress_bar:\n",
621-
" input_ids = batch['input_ids'].to(device)\n",
622-
" attention_mask = batch['attention_mask'].to(device)\n",
623-
" labels = batch['labels'].to(device)\n",
624-
" \n",
625-
" optimizer.zero_grad()\n",
626-
" \n",
627-
" outputs = model(\n",
628-
" input_ids=input_ids,\n",
629-
" attention_mask=attention_mask,\n",
630-
" labels=labels\n",
631-
" )\n",
632-
" \n",
633-
" loss = outputs.loss\n",
634-
" loss.backward()\n",
635-
" optimizer.step()\n",
636-
" scheduler.step()\n",
637-
" \n",
638-
" total_loss += loss.item()\n",
639-
" predictions = torch.argmax(outputs.logits, dim=-1)\n",
640-
" correct += (predictions == labels).sum().item()\n",
641-
" total += labels.size(0)\n",
642-
" \n",
643-
" progress_bar.set_postfix({\n",
644-
" 'loss': f\"{loss.item():.4f}\",\n",
645-
" 'acc': f\"{correct/total*100:.2f}%\"\n",
646-
" })\n",
647-
" \n",
648-
" return {\n",
649-
" 'loss': total_loss / len(dataloader),\n",
650-
" 'accuracy': correct / total * 100\n",
651-
" }\n",
652-
"\n",
653-
"\n",
654-
"def eval_epoch(model, dataloader, device, desc=\"Eval\"):\n",
655-
" \"\"\"Evaluate model.\"\"\"\n",
656-
" model.eval()\n",
657-
" \n",
658-
" total_loss = 0\n",
659-
" correct = 0\n",
660-
" total = 0\n",
661-
" \n",
662-
" with torch.no_grad():\n",
663-
" for batch in tqdm(dataloader, desc=desc):\n",
664-
" input_ids = batch['input_ids'].to(device)\n",
665-
" attention_mask = batch['attention_mask'].to(device)\n",
666-
" labels = batch['labels'].to(device)\n",
667-
" \n",
668-
" outputs = model(\n",
669-
" input_ids=input_ids,\n",
670-
" attention_mask=attention_mask,\n",
671-
" labels=labels\n",
672-
" )\n",
673-
" \n",
674-
" total_loss += outputs.loss.item()\n",
675-
" predictions = torch.argmax(outputs.logits, dim=-1)\n",
676-
" correct += (predictions == labels).sum().item()\n",
677-
" total += labels.size(0)\n",
678-
" \n",
679-
" return {\n",
680-
" 'loss': total_loss / len(dataloader),\n",
681-
" 'accuracy': correct / total * 100\n",
682-
" }\n",
683-
"\n",
684-
"print(\"✓ Training functions defined\")"
685-
]
608+
"source": "def train_epoch(model, dataloader, optimizer, scheduler, device, epoch, max_epochs):\n \"\"\"Train for one epoch.\"\"\"\n model.train()\n \n # Set epoch for BP-SOM models (baseline models don't have this method)\n if hasattr(model, 'set_epoch'):\n model.set_epoch(epoch, max_epochs)\n \n total_loss = 0\n correct = 0\n total = 0\n \n progress_bar = tqdm(dataloader, desc=f\"Epoch {epoch+1}/{max_epochs} [Train]\")\n \n for batch in progress_bar:\n input_ids = batch['input_ids'].to(device)\n attention_mask = batch['attention_mask'].to(device)\n labels = batch['labels'].to(device)\n \n optimizer.zero_grad()\n \n outputs = model(\n input_ids=input_ids,\n attention_mask=attention_mask,\n labels=labels\n )\n \n loss = outputs.loss\n loss.backward()\n optimizer.step()\n scheduler.step()\n \n total_loss += loss.item()\n predictions = torch.argmax(outputs.logits, dim=-1)\n correct += (predictions == labels).sum().item()\n total += labels.size(0)\n \n progress_bar.set_postfix({\n 'loss': f\"{loss.item():.4f}\",\n 'acc': f\"{correct/total*100:.2f}%\"\n })\n \n return {\n 'loss': total_loss / len(dataloader),\n 'accuracy': correct / total * 100\n}\n\n\ndef eval_epoch(model, dataloader, device, desc=\"Eval\"):\n \"\"\"Evaluate model.\"\"\"\n model.eval()\n \n total_loss = 0\n correct = 0\n total = 0\n \n with torch.no_grad():\n for batch in tqdm(dataloader, desc=desc):\n input_ids = batch['input_ids'].to(device)\n attention_mask = batch['attention_mask'].to(device)\n labels = batch['labels'].to(device)\n \n outputs = model(\n input_ids=input_ids,\n attention_mask=attention_mask,\n labels=labels\n )\n \n total_loss += outputs.loss.item()\n predictions = torch.argmax(outputs.logits, dim=-1)\n correct += (predictions == labels).sum().item()\n total += labels.size(0)\n \n return {\n 'loss': total_loss / len(dataloader),\n 'accuracy': correct / total * 100\n }\n\nprint(\"✓ Training functions defined\")"
686609
},
687610
{
688611
"cell_type": "markdown",
@@ -1170,4 +1093,4 @@
11701093
},
11711094
"nbformat": 4,
11721095
"nbformat_minor": 4
1173-
}
1096+
}

0 commit comments

Comments
 (0)