|
93 | 93 | "metadata": {}, |
94 | 94 | "outputs": [], |
95 | 95 | "source": [ |
96 | | - "import logging\n", |
97 | | - "import os\n", |
98 | | - "import shutil\n", |
99 | | - "from pathlib import Path\n", |
100 | | - "\n", |
101 | | - "import kagglehub\n", |
102 | | - "import lightning as pl\n", |
103 | | - "import matplotlib.pyplot as plt\n", |
104 | | - "import numpy as np\n", |
105 | | - "import torch\n", |
106 | | - "import torch.nn as nn\n", |
107 | | - "import torch.optim as optim\n", |
108 | | - "import torchvision.transforms as transforms\n", |
109 | | - "from lightning import Trainer\n", |
110 | | - "from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint\n", |
111 | | - "from sklearn.metrics import classification_report, confusion_matrix, f1_score\n", |
112 | | - "from torch.utils.data import DataLoader, random_split\n", |
113 | | - "from torchvision.datasets import ImageFolder" |
| 96 | +"import logging\n", |
| 97 | +"import os\n", |
| 98 | +"import shutil\n", |
| 99 | +"from pathlib import Path\n", |
| 100 | +"\n", |
| 101 | +"import kagglehub\n", |
| 102 | +"import lightning as pl\n", |
| 103 | +"import matplotlib.pyplot as plt\n", |
| 104 | +"import numpy as np\n", |
| 105 | +"import torch\n", |
| 106 | +"import torch.nn as nn\n", |
| 107 | +"import torch.optim as optim\n", |
| 108 | +"import torchvision.transforms as transforms\n", |
| 109 | +"from lightning import Trainer\n", |
| 110 | +"from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint\n", |
| 111 | +"from sklearn.metrics import classification_report, confusion_matrix, f1_score\n", |
| 112 | +"from torch.utils.data import DataLoader, random_split\n", |
| 113 | +"from torchvision.datasets import ImageFolder\n", |
| 114 | +"\n", |
| 115 | +"logging.basicConfig(level=logging.INFO)\n", |
| 116 | +"logger = logging.getLogger(__name__)" |
114 | 117 | ] |
115 | 118 | }, |
116 | 119 | { |
|
127 | 130 | "id": "a0005", |
128 | 131 | "metadata": {}, |
129 | 132 | "outputs": [], |
130 | | - "source": [ |
| 133 | + "source": [ |
131 | 134 | "class PistachioDataModule(pl.LightningDataModule):\n", |
132 | 135 | " def __init__(\n", |
133 | 136 | " self,\n", |
134 | 137 | " batch_size=64,\n", |
135 | | - " num_workers=8,\n", |
136 | | - " data_dir=\"./pistachio_data\",\n", |
| 138 | + " num_workers=4,\n", |
| 139 | + " data_dir=\"./data\",\n", |
137 | 140 | " data_fraction=1.0,\n", |
138 | 141 | " val_split=0.15,\n", |
139 | 142 | " test_split=0.15,\n", |
140 | 143 | " image_size=28,\n", |
| 144 | + " seed=42,\n", |
141 | 145 | " ):\n", |
142 | 146 | " super().__init__()\n", |
143 | 147 | " self.batch_size = batch_size\n", |
|
147 | 151 | " self.val_split = val_split\n", |
148 | 152 | " self.test_split = test_split\n", |
149 | 153 | " self.image_size = image_size\n", |
| 154 | + " self.seed = seed\n", |
150 | 155 | " self.dataset_path = None\n", |
| 156 | + " self.class_names = None\n", |
151 | 157 | "\n", |
152 | 158 | " def prepare_data(self):\n", |
153 | 159 | " if not os.path.exists(self.data_dir) or not any(Path(self.data_dir).iterdir()):\n", |
154 | | - " print(\"Downloading pistachio dataset from Kaggle...\")\n", |
| 160 | + " logger.info(\"Descargando dataset de pistachos desde Kaggle...\")\n", |
155 | 161 | " raw_path = kagglehub.dataset_download(\n", |
156 | 162 | " \"muratkokludataset/pistachio-image-dataset\"\n", |
157 | 163 | " )\n", |
158 | 164 | " self.dataset_path = os.path.join(\n", |
159 | 165 | " raw_path, \"Pistachio_Image_Dataset\", \"Pistachio_Image_Dataset\"\n", |
160 | 166 | " )\n", |
161 | | - " print(f\"Dataset downloaded to: {raw_path}\")\n", |
| 167 | + " logger.info(\"Dataset descargado en: %s\", raw_path)\n", |
162 | 168 | "\n", |
163 | 169 | " os.makedirs(self.data_dir, exist_ok=True)\n", |
164 | 170 | " src = Path(self.dataset_path)\n", |
|
168 | 174 | " if item.is_dir() and not dest.exists():\n", |
169 | 175 | " shutil.copytree(item, dest)\n", |
170 | 176 | "\n", |
171 | | - " print(f\"Dataset prepared in: {self.data_dir}\")\n", |
| 177 | + " logger.info(\"Dataset copiado a: %s\", self.data_dir)\n", |
172 | 178 | "\n", |
173 | 179 | " def setup(self, stage=None):\n", |
174 | 180 | " transform = transforms.Compose(\n", |
|
180 | 186 | " )\n", |
181 | 187 | "\n", |
182 | 188 | " full_dataset = ImageFolder(self.data_dir, transform=transform)\n", |
| 189 | + " self.class_names = full_dataset.classes\n", |
183 | 190 | " dataset_size = len(full_dataset)\n", |
184 | 191 | "\n", |
| 192 | + " if self.data_fraction < 1.0:\n", |
| 193 | + " subset_size = int(dataset_size * self.data_fraction)\n", |
| 194 | + " full_dataset, _ = random_split(\n", |
| 195 | + " full_dataset,\n", |
| 196 | + " [subset_size, dataset_size - subset_size],\n", |
| 197 | + " generator=torch.Generator().manual_seed(self.seed),\n", |
| 198 | + " )\n", |
| 199 | + " dataset_size = subset_size\n", |
| 200 | + "\n", |
185 | 201 | " test_size = int(dataset_size * self.test_split)\n", |
186 | 202 | " val_size = int(dataset_size * self.val_split)\n", |
187 | 203 | " train_size = dataset_size - val_size - test_size\n", |
188 | 204 | "\n", |
189 | 205 | " self.train_dataset, self.val_dataset, self.test_dataset = random_split(\n", |
190 | 206 | " full_dataset,\n", |
191 | 207 | " [train_size, val_size, test_size],\n", |
192 | | - " generator=torch.Generator().manual_seed(42),\n", |
| 208 | + " generator=torch.Generator().manual_seed(self.seed),\n", |
193 | 209 | " )\n", |
194 | 210 | "\n", |
195 | | - " print(f\"Classes: {full_dataset.classes}\")\n", |
196 | | - " print(f\"Train size: {len(self.train_dataset)}\")\n", |
197 | | - " print(f\"Val size: {len(self.val_dataset)}\")\n", |
198 | | - " print(f\"Test size: {len(self.test_dataset)}\")\n", |
| 211 | + " logger.info(\"Clases: %s\", self.class_names)\n", |
| 212 | + " logger.info(\"Train: %d, Val: %d, Test: %d\",\n", |
| 213 | + " train_size, val_size, test_size)\n", |
199 | 214 | "\n", |
200 | 215 | " def train_dataloader(self):\n", |
201 | 216 | " return DataLoader(\n", |
|
265 | 280 | "dm.prepare_data()\n", |
266 | 281 | "dm.setup()\n", |
267 | 282 | "\n", |
268 | | - "class_names = dm.train_dataset.dataset.classes\n", |
| 283 | + "class_names = dm.class_names\n", |
269 | 284 | "\n", |
270 | 285 | "fig, axes = plt.subplots(2, 4, figsize=(12, 6))\n", |
271 | 286 | "fig.suptitle(\"Samples\", fontsize=14)\n", |
|
382 | 397 | " self.log(\"test_acc\", acc, on_step=False, on_epoch=True, prog_bar=True)\n", |
383 | 398 | " self.test_predictions.extend(predicted.cpu().numpy())\n", |
384 | 399 | " self.test_labels.extend(y_batch.cpu().numpy())\n", |
385 | | - " return acc\n", |
| 400 | + " return {\"loss\": loss, \"acc\": acc}\n", |
386 | 401 | "\n", |
387 | 402 | " def on_test_epoch_start(self):\n", |
388 | 403 | " self.test_predictions = []\n", |
|
415 | 430 | "metadata": {}, |
416 | 431 | "outputs": [], |
417 | 432 | "source": [ |
418 | | - "class CNNBatchNormModel(BaselineModel):\n", |
419 | | - " def __init__(self, output_size=2, lr=0.001):\n", |
420 | | - " super().__init__(output_size=output_size, lr=lr)\n", |
421 | | - " self.save_hyperparameters()\n", |
| 433 | +"class CNNBatchNormModel(BaselineModel):\n", |
| 434 | +" def __init__(self, output_size=2, lr=0.001, dropout_rate=0.3):\n", |
| 435 | +" super().__init__(output_size=output_size, lr=lr)\n", |
| 436 | +" self.save_hyperparameters()\n", |
422 | 437 | " self.features = nn.Sequential(\n", |
423 | 438 | " nn.Conv2d(3, 32, kernel_size=3, padding=1, bias=False),\n", |
424 | 439 | " nn.BatchNorm2d(32),\n", |
|
440 | 455 | " nn.Linear(64 * 7 * 7, 128),\n", |
441 | 456 | " nn.BatchNorm1d(128),\n", |
442 | 457 | " nn.ReLU(),\n", |
443 | | - " nn.Dropout(0.3),\n", |
444 | | - " nn.Linear(128, output_size),\n", |
445 | | - " )\n", |
446 | | - "\n", |
447 | | - " def forward(self, x):\n", |
448 | | - " x = self.features(x)\n", |
449 | | - " return self.classifier(x)\n", |
450 | | - "\n", |
451 | | - " def configure_optimizers(self):\n", |
452 | | - " return optim.Adam(self.parameters(), lr=self.hparams.lr)" |
| 458 | +" nn.Dropout(dropout_rate),\n", |
| 459 | +" nn.Linear(128, output_size),\n", |
| 460 | +" )\n", |
| 461 | +"\n", |
| 462 | +" def forward(self, x):\n", |
| 463 | +" x = self.features(x)\n", |
| 464 | +" return self.classifier(x)\n", |
| 465 | +"\n", |
| 466 | +" def configure_optimizers(self):\n", |
| 467 | +" return optim.Adam(self.parameters(), lr=self.hparams.lr)\n", |
| 468 | +"\n", |
| 469 | +"\n", |
| 470 | +"class CNNDropoutModel(BaselineModel):\n", |
| 471 | +" \"\"\"CNN con Dropout en lugar de BatchNorm para comparacion.\"\"\"\n", |
| 472 | +" def __init__(self, output_size=2, lr=0.001, dropout_rate=0.3):\n", |
| 473 | +" super().__init__(output_size=output_size, lr=lr)\n", |
| 474 | +" self.save_hyperparameters()\n", |
| 475 | +" self.features = nn.Sequential(\n", |
| 476 | +" nn.Conv2d(3, 32, kernel_size=3, padding=1, bias=False),\n", |
| 477 | +" nn.ReLU(),\n", |
| 478 | +" nn.Dropout2d(dropout_rate),\n", |
| 479 | +" nn.Conv2d(32, 32, kernel_size=3, padding=1, bias=False),\n", |
| 480 | +" nn.ReLU(),\n", |
| 481 | +" nn.Dropout2d(dropout_rate),\n", |
| 482 | +" nn.MaxPool2d(2),\n", |
| 483 | +" nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=False),\n", |
| 484 | +" nn.ReLU(),\n", |
| 485 | +" nn.Dropout2d(dropout_rate),\n", |
| 486 | +" nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False),\n", |
| 487 | +" nn.ReLU(),\n", |
| 488 | +" nn.Dropout2d(dropout_rate),\n", |
| 489 | +" nn.MaxPool2d(2),\n", |
| 490 | +" )\n", |
| 491 | +" self.classifier = nn.Sequential(\n", |
| 492 | +" nn.Flatten(),\n", |
| 493 | +" nn.Linear(64 * 7 * 7, 128),\n", |
| 494 | +" nn.ReLU(),\n", |
| 495 | +" nn.Dropout(dropout_rate),\n", |
| 496 | +" nn.Linear(128, output_size),\n", |
| 497 | +" )\n", |
| 498 | +"\n", |
| 499 | +" def forward(self, x):\n", |
| 500 | +" x = self.features(x)\n", |
| 501 | +" return self.classifier(x)\n", |
| 502 | +"\n", |
| 503 | +" def configure_optimizers(self):\n", |
| 504 | +" return optim.Adam(self.parameters(), lr=self.hparams.lr)" |
453 | 505 | ] |
454 | 506 | }, |
455 | 507 | { |
|
707 | 759 | ] |
708 | 760 | } |
709 | 761 | ], |
710 | | - "source": [ |
| 762 | + "source": [ |
711 | 763 | "data_module = PistachioDataModule(data_fraction=1.0)\n", |
712 | 764 | "lrs = [0.01]\n", |
713 | | - "results = train_models(CNNBatchNormModel, lrs, data_module)\n", |
714 | | - "print_model_results(results)" |
| 765 | + "\n", |
| 766 | + "logging.info(\"Entrenando CNNBatchNormModel...\")\n", |
| 767 | + "results_bn = train_models(CNNBatchNormModel, lrs, data_module)\n", |
| 768 | + "print_model_results(results_bn)\n", |
| 769 | + "\n", |
| 770 | + "logging.info(\"Entrenando CNNDropoutModel...\")\n", |
| 771 | + "results_drop = train_models(CNNDropoutModel, lrs, data_module)\n", |
| 772 | + "print_model_results(results_drop)" |
715 | 773 | ] |
716 | 774 | } |
717 | 775 | ], |
|
0 commit comments