Skip to content

Commit 2e5b3b0

Browse files
committed
Fix: notebook synced with src, added CNNDropoutModel, cleanup config
1 parent 6f949de commit 2e5b3b0

8 files changed

Lines changed: 148 additions & 78 deletions

File tree

README.md

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,15 @@ Clasificador de pistachos (Kirmizi vs Siirt) usando Deep Learning con PyTorch Li
3030

3131
## Configuracion
3232

33-
Editar `config/configuracion.yaml`:
33+
Editar `config/configuracion.yaml` (solo hiperparametros del modelo):
3434

3535
- `model_type`: `cnn_batchnorm` o `cnn_dropout`
3636
- `learning_rate`, `batch_size`, `max_epochs`, `patience`
37-
- `image_size` (default: 28)
37+
- `dropout_rate`
3838
- `wandb_project`: nombre del proyecto en W&B
3939

40+
El resto (`output_size=2`, `image_size=28`, splits, etc.) son constantes de arquitectura en el codigo.
41+
4042
## Ejecucion local
4143

4244
```bash
@@ -48,6 +50,9 @@ python src/main.py configuracion.yaml
4850
# Sin W&B (offline)
4951
python src/main.py configuracion.yaml --no-wandb
5052

53+
# Smoke test (10% datos, 3 epocas)
54+
python src/main.py configuracion.yaml --no-wandb --data-fraction 0.1
55+
5156
# W&B Sweep (grid search)
5257
python src/main.py configuracion.yaml --sweep
5358

config/configuracion.yaml

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,11 @@
1-
# Parametros de entrenamiento
1+
# Hiperparametros del modelo (se modifican via sweep)
22
semilla: 42
33
max_epochs: 50
44
patience: 5
55
batch_size: 64
66
learning_rate: 0.001
7-
8-
# Parametros del modelo
97
model_type: "cnn_batchnorm"
10-
output_size: 2
118
dropout_rate: 0.3
129

13-
# Parametros de datos
14-
image_size: 28
15-
data_fraction: 1.0
16-
val_split: 0.15
17-
test_split: 0.15
18-
num_workers: 4
19-
data_dir: "./data"
20-
2110
# W&B
2211
wandb_project: "pistachio-mlops"

notebook/pistachio.ipynb

Lines changed: 106 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -93,24 +93,27 @@
9393
"metadata": {},
9494
"outputs": [],
9595
"source": [
96-
"import logging\n",
97-
"import os\n",
98-
"import shutil\n",
99-
"from pathlib import Path\n",
100-
"\n",
101-
"import kagglehub\n",
102-
"import lightning as pl\n",
103-
"import matplotlib.pyplot as plt\n",
104-
"import numpy as np\n",
105-
"import torch\n",
106-
"import torch.nn as nn\n",
107-
"import torch.optim as optim\n",
108-
"import torchvision.transforms as transforms\n",
109-
"from lightning import Trainer\n",
110-
"from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint\n",
111-
"from sklearn.metrics import classification_report, confusion_matrix, f1_score\n",
112-
"from torch.utils.data import DataLoader, random_split\n",
113-
"from torchvision.datasets import ImageFolder"
96+
"import logging\n",
97+
"import os\n",
98+
"import shutil\n",
99+
"from pathlib import Path\n",
100+
"\n",
101+
"import kagglehub\n",
102+
"import lightning as pl\n",
103+
"import matplotlib.pyplot as plt\n",
104+
"import numpy as np\n",
105+
"import torch\n",
106+
"import torch.nn as nn\n",
107+
"import torch.optim as optim\n",
108+
"import torchvision.transforms as transforms\n",
109+
"from lightning import Trainer\n",
110+
"from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint\n",
111+
"from sklearn.metrics import classification_report, confusion_matrix, f1_score\n",
112+
"from torch.utils.data import DataLoader, random_split\n",
113+
"from torchvision.datasets import ImageFolder\n",
114+
"\n",
115+
"logging.basicConfig(level=logging.INFO)\n",
116+
"logger = logging.getLogger(__name__)"
114117
]
115118
},
116119
{
@@ -127,17 +130,18 @@
127130
"id": "a0005",
128131
"metadata": {},
129132
"outputs": [],
130-
"source": [
133+
"source": [
131134
"class PistachioDataModule(pl.LightningDataModule):\n",
132135
" def __init__(\n",
133136
" self,\n",
134137
" batch_size=64,\n",
135-
" num_workers=8,\n",
136-
" data_dir=\"./pistachio_data\",\n",
138+
" num_workers=4,\n",
139+
" data_dir=\"./data\",\n",
137140
" data_fraction=1.0,\n",
138141
" val_split=0.15,\n",
139142
" test_split=0.15,\n",
140143
" image_size=28,\n",
144+
" seed=42,\n",
141145
" ):\n",
142146
" super().__init__()\n",
143147
" self.batch_size = batch_size\n",
@@ -147,18 +151,20 @@
147151
" self.val_split = val_split\n",
148152
" self.test_split = test_split\n",
149153
" self.image_size = image_size\n",
154+
" self.seed = seed\n",
150155
" self.dataset_path = None\n",
156+
" self.class_names = None\n",
151157
"\n",
152158
" def prepare_data(self):\n",
153159
" if not os.path.exists(self.data_dir) or not any(Path(self.data_dir).iterdir()):\n",
154-
" print(\"Downloading pistachio dataset from Kaggle...\")\n",
160+
" logger.info(\"Descargando dataset de pistachos desde Kaggle...\")\n",
155161
" raw_path = kagglehub.dataset_download(\n",
156162
" \"muratkokludataset/pistachio-image-dataset\"\n",
157163
" )\n",
158164
" self.dataset_path = os.path.join(\n",
159165
" raw_path, \"Pistachio_Image_Dataset\", \"Pistachio_Image_Dataset\"\n",
160166
" )\n",
161-
" print(f\"Dataset downloaded to: {raw_path}\")\n",
167+
" logger.info(\"Dataset descargado en: %s\", raw_path)\n",
162168
"\n",
163169
" os.makedirs(self.data_dir, exist_ok=True)\n",
164170
" src = Path(self.dataset_path)\n",
@@ -168,7 +174,7 @@
168174
" if item.is_dir() and not dest.exists():\n",
169175
" shutil.copytree(item, dest)\n",
170176
"\n",
171-
" print(f\"Dataset prepared in: {self.data_dir}\")\n",
177+
" logger.info(\"Dataset copiado a: %s\", self.data_dir)\n",
172178
"\n",
173179
" def setup(self, stage=None):\n",
174180
" transform = transforms.Compose(\n",
@@ -180,22 +186,31 @@
180186
" )\n",
181187
"\n",
182188
" full_dataset = ImageFolder(self.data_dir, transform=transform)\n",
189+
" self.class_names = full_dataset.classes\n",
183190
" dataset_size = len(full_dataset)\n",
184191
"\n",
192+
" if self.data_fraction < 1.0:\n",
193+
" subset_size = int(dataset_size * self.data_fraction)\n",
194+
" full_dataset, _ = random_split(\n",
195+
" full_dataset,\n",
196+
" [subset_size, dataset_size - subset_size],\n",
197+
" generator=torch.Generator().manual_seed(self.seed),\n",
198+
" )\n",
199+
" dataset_size = subset_size\n",
200+
"\n",
185201
" test_size = int(dataset_size * self.test_split)\n",
186202
" val_size = int(dataset_size * self.val_split)\n",
187203
" train_size = dataset_size - val_size - test_size\n",
188204
"\n",
189205
" self.train_dataset, self.val_dataset, self.test_dataset = random_split(\n",
190206
" full_dataset,\n",
191207
" [train_size, val_size, test_size],\n",
192-
" generator=torch.Generator().manual_seed(42),\n",
208+
" generator=torch.Generator().manual_seed(self.seed),\n",
193209
" )\n",
194210
"\n",
195-
" print(f\"Classes: {full_dataset.classes}\")\n",
196-
" print(f\"Train size: {len(self.train_dataset)}\")\n",
197-
" print(f\"Val size: {len(self.val_dataset)}\")\n",
198-
" print(f\"Test size: {len(self.test_dataset)}\")\n",
211+
" logger.info(\"Clases: %s\", self.class_names)\n",
212+
" logger.info(\"Train: %d, Val: %d, Test: %d\",\n",
213+
" train_size, val_size, test_size)\n",
199214
"\n",
200215
" def train_dataloader(self):\n",
201216
" return DataLoader(\n",
@@ -265,7 +280,7 @@
265280
"dm.prepare_data()\n",
266281
"dm.setup()\n",
267282
"\n",
268-
"class_names = dm.train_dataset.dataset.classes\n",
283+
"class_names = dm.class_names\n",
269284
"\n",
270285
"fig, axes = plt.subplots(2, 4, figsize=(12, 6))\n",
271286
"fig.suptitle(\"Samples\", fontsize=14)\n",
@@ -382,7 +397,7 @@
382397
" self.log(\"test_acc\", acc, on_step=False, on_epoch=True, prog_bar=True)\n",
383398
" self.test_predictions.extend(predicted.cpu().numpy())\n",
384399
" self.test_labels.extend(y_batch.cpu().numpy())\n",
385-
" return acc\n",
400+
" return {\"loss\": loss, \"acc\": acc}\n",
386401
"\n",
387402
" def on_test_epoch_start(self):\n",
388403
" self.test_predictions = []\n",
@@ -415,10 +430,10 @@
415430
"metadata": {},
416431
"outputs": [],
417432
"source": [
418-
"class CNNBatchNormModel(BaselineModel):\n",
419-
" def __init__(self, output_size=2, lr=0.001):\n",
420-
" super().__init__(output_size=output_size, lr=lr)\n",
421-
" self.save_hyperparameters()\n",
433+
"class CNNBatchNormModel(BaselineModel):\n",
434+
" def __init__(self, output_size=2, lr=0.001, dropout_rate=0.3):\n",
435+
" super().__init__(output_size=output_size, lr=lr)\n",
436+
" self.save_hyperparameters()\n",
422437
" self.features = nn.Sequential(\n",
423438
" nn.Conv2d(3, 32, kernel_size=3, padding=1, bias=False),\n",
424439
" nn.BatchNorm2d(32),\n",
@@ -440,16 +455,53 @@
440455
" nn.Linear(64 * 7 * 7, 128),\n",
441456
" nn.BatchNorm1d(128),\n",
442457
" nn.ReLU(),\n",
443-
" nn.Dropout(0.3),\n",
444-
" nn.Linear(128, output_size),\n",
445-
" )\n",
446-
"\n",
447-
" def forward(self, x):\n",
448-
" x = self.features(x)\n",
449-
" return self.classifier(x)\n",
450-
"\n",
451-
" def configure_optimizers(self):\n",
452-
" return optim.Adam(self.parameters(), lr=self.hparams.lr)"
458+
" nn.Dropout(dropout_rate),\n",
459+
" nn.Linear(128, output_size),\n",
460+
" )\n",
461+
"\n",
462+
" def forward(self, x):\n",
463+
" x = self.features(x)\n",
464+
" return self.classifier(x)\n",
465+
"\n",
466+
" def configure_optimizers(self):\n",
467+
" return optim.Adam(self.parameters(), lr=self.hparams.lr)\n",
468+
"\n",
469+
"\n",
470+
"class CNNDropoutModel(BaselineModel):\n",
471+
" \"\"\"CNN con Dropout en lugar de BatchNorm para comparacion.\"\"\"\n",
472+
" def __init__(self, output_size=2, lr=0.001, dropout_rate=0.3):\n",
473+
" super().__init__(output_size=output_size, lr=lr)\n",
474+
" self.save_hyperparameters()\n",
475+
" self.features = nn.Sequential(\n",
476+
" nn.Conv2d(3, 32, kernel_size=3, padding=1, bias=False),\n",
477+
" nn.ReLU(),\n",
478+
" nn.Dropout2d(dropout_rate),\n",
479+
" nn.Conv2d(32, 32, kernel_size=3, padding=1, bias=False),\n",
480+
" nn.ReLU(),\n",
481+
" nn.Dropout2d(dropout_rate),\n",
482+
" nn.MaxPool2d(2),\n",
483+
" nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=False),\n",
484+
" nn.ReLU(),\n",
485+
" nn.Dropout2d(dropout_rate),\n",
486+
" nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False),\n",
487+
" nn.ReLU(),\n",
488+
" nn.Dropout2d(dropout_rate),\n",
489+
" nn.MaxPool2d(2),\n",
490+
" )\n",
491+
" self.classifier = nn.Sequential(\n",
492+
" nn.Flatten(),\n",
493+
" nn.Linear(64 * 7 * 7, 128),\n",
494+
" nn.ReLU(),\n",
495+
" nn.Dropout(dropout_rate),\n",
496+
" nn.Linear(128, output_size),\n",
497+
" )\n",
498+
"\n",
499+
" def forward(self, x):\n",
500+
" x = self.features(x)\n",
501+
" return self.classifier(x)\n",
502+
"\n",
503+
" def configure_optimizers(self):\n",
504+
" return optim.Adam(self.parameters(), lr=self.hparams.lr)"
453505
]
454506
},
455507
{
@@ -707,11 +759,17 @@
707759
]
708760
}
709761
],
710-
"source": [
762+
"source": [
711763
"data_module = PistachioDataModule(data_fraction=1.0)\n",
712764
"lrs = [0.01]\n",
713-
"results = train_models(CNNBatchNormModel, lrs, data_module)\n",
714-
"print_model_results(results)"
765+
"\n",
766+
"logging.info(\"Entrenando CNNBatchNormModel...\")\n",
767+
"results_bn = train_models(CNNBatchNormModel, lrs, data_module)\n",
768+
"print_model_results(results_bn)\n",
769+
"\n",
770+
"logging.info(\"Entrenando CNNDropoutModel...\")\n",
771+
"results_drop = train_models(CNNDropoutModel, lrs, data_module)\n",
772+
"print_model_results(results_drop)"
715773
]
716774
}
717775
],

src/api_inferencia.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@
99
from src.model import get_model
1010
from src.utils import get_project_root, load_config
1111

12+
OUTPUT_SIZE = 2
13+
IMAGE_SIZE = 28
14+
CLASS_NAMES = ["Kirmizi_Pistachio", "Siirt_Pistachio"]
15+
1216

1317
@asynccontextmanager
1418
async def lifespan(app: FastAPI):
@@ -18,8 +22,8 @@ async def lifespan(app: FastAPI):
1822
model_type = parametros.get("model_type", "cnn_batchnorm")
1923
model = get_model(
2024
model_type=model_type,
21-
output_size=int(parametros["output_size"]),
22-
lr=float(parametros.get("learning_rate", 0.001)),
25+
output_size=OUTPUT_SIZE,
26+
lr=0.001,
2327
dropout_rate=float(parametros.get("dropout_rate", 0.3)),
2428
)
2529

@@ -29,8 +33,8 @@ async def lifespan(app: FastAPI):
2933
model.eval()
3034

3135
app.state.model = model
32-
app.state.class_names = ["Kirmizi_Pistachio", "Siirt_Pistachio"]
33-
app.state.image_size = int(parametros.get("image_size", 28))
36+
app.state.class_names = CLASS_NAMES
37+
app.state.image_size = IMAGE_SIZE
3438

3539
yield
3640

src/main.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,12 @@ def main():
3737
action="store_true",
3838
help="Desactivar logging a W&B",
3939
)
40+
parser.add_argument(
41+
"--data-fraction",
42+
type=float,
43+
default=1.0,
44+
help="Fraccion de datos a usar (0-1, util para smoke tests)",
45+
)
4046
args = parser.parse_args()
4147

4248
setup_logging("info")
@@ -52,7 +58,7 @@ def main():
5258
count=args.count,
5359
)
5460
else:
55-
train_model(config, wandb_log=not args.no_wandb)
61+
train_model(config, wandb_log=not args.no_wandb, data_fraction=args.data_fraction)
5662

5763

5864
if __name__ == "__main__":

src/sweep.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
"learning_rate": {"values": [0.001, 0.01]},
1717
"batch_size": {"values": [32, 64]},
1818
"model_type": {"values": ["cnn_batchnorm", "cnn_dropout"]},
19-
"image_size": {"values": [28]},
19+
"dropout_rate": {"values": [0.3]},
2020
},
2121
}
2222

0 commit comments

Comments
 (0)