605605 "execution_count" : null ,
606606 "metadata" : {},
607607 "outputs" : [],
608- "source" : [
609- " def train_epoch(model, dataloader, optimizer, scheduler, device, epoch, max_epochs):\n " ,
610- " \"\"\" Train for one epoch.\"\"\"\n " ,
611- " model.train()\n " ,
612- " model.set_epoch(epoch, max_epochs)\n " ,
613- " \n " ,
614- " total_loss = 0\n " ,
615- " correct = 0\n " ,
616- " total = 0\n " ,
617- " \n " ,
618- " progress_bar = tqdm(dataloader, desc=f\" Epoch {epoch+1}/{max_epochs} [Train]\" )\n " ,
619- " \n " ,
620- " for batch in progress_bar:\n " ,
621- " input_ids = batch['input_ids'].to(device)\n " ,
622- " attention_mask = batch['attention_mask'].to(device)\n " ,
623- " labels = batch['labels'].to(device)\n " ,
624- " \n " ,
625- " optimizer.zero_grad()\n " ,
626- " \n " ,
627- " outputs = model(\n " ,
628- " input_ids=input_ids,\n " ,
629- " attention_mask=attention_mask,\n " ,
630- " labels=labels\n " ,
631- " )\n " ,
632- " \n " ,
633- " loss = outputs.loss\n " ,
634- " loss.backward()\n " ,
635- " optimizer.step()\n " ,
636- " scheduler.step()\n " ,
637- " \n " ,
638- " total_loss += loss.item()\n " ,
639- " predictions = torch.argmax(outputs.logits, dim=-1)\n " ,
640- " correct += (predictions == labels).sum().item()\n " ,
641- " total += labels.size(0)\n " ,
642- " \n " ,
643- " progress_bar.set_postfix({\n " ,
644- " 'loss': f\" {loss.item():.4f}\" ,\n " ,
645- " 'acc': f\" {correct/total*100:.2f}%\"\n " ,
646- " })\n " ,
647- " \n " ,
648- " return {\n " ,
649- " 'loss': total_loss / len(dataloader),\n " ,
650- " 'accuracy': correct / total * 100\n " ,
651- " }\n " ,
652- " \n " ,
653- " \n " ,
654- " def eval_epoch(model, dataloader, device, desc=\" Eval\" ):\n " ,
655- " \"\"\" Evaluate model.\"\"\"\n " ,
656- " model.eval()\n " ,
657- " \n " ,
658- " total_loss = 0\n " ,
659- " correct = 0\n " ,
660- " total = 0\n " ,
661- " \n " ,
662- " with torch.no_grad():\n " ,
663- " for batch in tqdm(dataloader, desc=desc):\n " ,
664- " input_ids = batch['input_ids'].to(device)\n " ,
665- " attention_mask = batch['attention_mask'].to(device)\n " ,
666- " labels = batch['labels'].to(device)\n " ,
667- " \n " ,
668- " outputs = model(\n " ,
669- " input_ids=input_ids,\n " ,
670- " attention_mask=attention_mask,\n " ,
671- " labels=labels\n " ,
672- " )\n " ,
673- " \n " ,
674- " total_loss += outputs.loss.item()\n " ,
675- " predictions = torch.argmax(outputs.logits, dim=-1)\n " ,
676- " correct += (predictions == labels).sum().item()\n " ,
677- " total += labels.size(0)\n " ,
678- " \n " ,
679- " return {\n " ,
680- " 'loss': total_loss / len(dataloader),\n " ,
681- " 'accuracy': correct / total * 100\n " ,
682- " }\n " ,
683- " \n " ,
684- " print(\" ✓ Training functions defined\" )"
685- ]
608+ "source": "def train_epoch(model, dataloader, optimizer, scheduler, device, epoch, max_epochs):\n \"\"\"Train for one epoch.\"\"\"\n model.train()\n \n # Set epoch for BP-SOM models (baseline models don't have this method)\n if hasattr(model, 'set_epoch'):\n model.set_epoch(epoch, max_epochs)\n \n total_loss = 0\n correct = 0\n total = 0\n \n progress_bar = tqdm(dataloader, desc=f\"Epoch {epoch+1}/{max_epochs} [Train]\")\n \n for batch in progress_bar:\n input_ids = batch['input_ids'].to(device)\n attention_mask = batch['attention_mask'].to(device)\n labels = batch['labels'].to(device)\n \n optimizer.zero_grad()\n \n outputs = model(\n input_ids=input_ids,\n attention_mask=attention_mask,\n labels=labels\n )\n \n loss = outputs.loss\n loss.backward()\n optimizer.step()\n scheduler.step()\n \n total_loss += loss.item()\n predictions = torch.argmax(outputs.logits, dim=-1)\n correct += (predictions == labels).sum().item()\n total += labels.size(0)\n \n progress_bar.set_postfix({\n 'loss': f\"{loss.item():.4f}\",\n 'acc': f\"{correct/total*100:.2f}%\"\n })\n \n return {\n 'loss': total_loss / len(dataloader),\n 'accuracy': correct / total * 100\n}\n\n\ndef eval_epoch(model, dataloader, device, desc=\"Eval\"):\n \"\"\"Evaluate model.\"\"\"\n model.eval()\n \n total_loss = 0\n correct = 0\n total = 0\n \n with torch.no_grad():\n for batch in tqdm(dataloader, desc=desc):\n input_ids = batch['input_ids'].to(device)\n attention_mask = batch['attention_mask'].to(device)\n labels = batch['labels'].to(device)\n \n outputs = model(\n input_ids=input_ids,\n attention_mask=attention_mask,\n labels=labels\n )\n \n total_loss += outputs.loss.item()\n predictions = torch.argmax(outputs.logits, dim=-1)\n correct += (predictions == labels).sum().item()\n total += labels.size(0)\n \n return {\n 'loss': total_loss / len(dataloader),\n 'accuracy': correct / total * 100\n }\n\nprint(\"✓ Training functions defined\")"
686609 },
687610 {
688611 "cell_type" : " markdown" ,
11701093 },
11711094 "nbformat" : 4 ,
11721095 "nbformat_minor" : 4
1173- }
1096+ }
0 commit comments