-
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_resnet18_xray.py
More file actions
119 lines (98 loc) · 3.56 KB
/
train_resnet18_xray.py
File metadata and controls
119 lines (98 loc) · 3.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
"""ResNet-18 fine-tuning script for X-ray classification."""
import os
from dataclasses import dataclass
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from tqdm.auto import tqdm
@dataclass
class Config:
data_dir: str = "./xray_data"
batch_size: int = 32
num_workers: int = 4
num_epochs: int = 5
lr: float = 1e-4
weight_decay: float = 1e-4
num_classes: int = 2
device: str = "cuda" if torch.cuda.is_available() else "cpu"
def get_dataloaders(cfg: Config):
transform_train = transforms.Compose(
[
transforms.Resize((224, 224)),
transforms.RandomHorizontalFlip(),
transforms.RandomRotation(10),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225]),
]
)
transform_val = transforms.Compose(
[
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225]),
]
)
train_ds = datasets.ImageFolder(os.path.join(cfg.data_dir, "train"), transform=transform_train)
val_ds = datasets.ImageFolder(os.path.join(cfg.data_dir, "val"), transform=transform_val)
train_loader = DataLoader(
train_ds,
batch_size=cfg.batch_size,
shuffle=True,
num_workers=cfg.num_workers,
pin_memory=True,
)
val_loader = DataLoader(
val_ds,
batch_size=cfg.batch_size,
shuffle=False,
num_workers=cfg.num_workers,
pin_memory=True,
)
return train_loader, val_loader
def build_model(num_classes: int) -> nn.Module:
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
in_features = model.fc.in_features
model.fc = nn.Linear(in_features, num_classes)
return model
def train(cfg: Config) -> None:
device = torch.device(cfg.device)
train_loader, val_loader = get_dataloaders(cfg)
model = build_model(cfg.num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg.num_epochs)
best_acc = 0.0
for epoch in range(1, cfg.num_epochs + 1):
model.train()
pbar = tqdm(train_loader, desc=f"Epoch {epoch}")
for images, labels in pbar:
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
pbar.set_postfix(loss=f"{loss.item():.4f}")
scheduler.step()
model.eval()
correct, total = 0, 0
with torch.no_grad():
for images, labels in val_loader:
images, labels = images.to(device), labels.to(device)
outputs = model(images)
preds = outputs.argmax(dim=1)
correct += (preds == labels).sum().item()
total += labels.size(0)
acc = correct / total
print(f"Validation accuracy: {acc:.4f}")
if acc > best_acc:
best_acc = acc
os.makedirs("checkpoints", exist_ok=True)
torch.save(model.state_dict(), "checkpoints/resnet18_xray_best.pt")
print("Saved new best model.")
if __name__ == "__main__":
cfg = Config()
train(cfg)