Skip to content

Commit ab48824

Browse files
committed
Some updates in code after workshop
1 parent c60d940 commit ab48824

3 files changed

Lines changed: 111 additions & 54 deletions

File tree

src/2022_06_10_analyze_rock_classification.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
1-
"""Script for analysis of rock type classification model"""
1+
"""Script for analysis of rock type classification model
2+
3+
Run tensorboard calling this in your terminal:
4+
tensorboard --logdir Reports/tensorboard_logs
5+
6+
"""
27

38
import torch
49
from pathlib import Path
@@ -21,8 +26,9 @@
2126
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
2227
fig, ax = plt.subplots(figsize=(15,15))
2328
disp.plot(cmap="viridis", ax=ax, values_format=".2f")
29+
# ax.set_xticklabels(perf["class_names"])
2430
ax.set_xlabel("Predicted rocktype")
2531
ax.set_ylabel("True rocktype")
2632
plt.tight_layout()
27-
plt.savefig("Figures/confusion_matrix_rocktypes.png", dpi=600)
28-
# plt.show()
33+
plt.savefig(Path("Figures/confusion_matrix_rocktypes.png"), dpi=600)
34+
plt.show()

src/2022_06_10_image_classification.py

Lines changed: 90 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,17 @@
88
https://www.kaggle.com/datasets/neelgajare/rocks-dataset
99
1010
TODO: Try to classify better with:
11-
- better tuning of parameters
12-
- weighting in loss function
11+
- Tuning of hyperparameters
1312
- different learning rate in backbone
1413
- other backbones
15-
- Pytorch Lightning
16-
- Better augmentation techniques. Cropping?
14+
- more complicated head-network
15+
- Pytorch Lightning or Keras implementations
16+
- Better augmentation techniques. Cropping? Filters?
1717
- Changes in dataset. Remove obvious crazy images.
18+
- Cross validation and other splits
19+
20+
Dataset needs to be structured like:
21+
ROOT > Classname > filename.jpg
1822
1923
@author: Tom F. Hansen, Georg H. Erharter
2024
"""
@@ -29,35 +33,41 @@
2933
from torch.optim import lr_scheduler
3034
from torch.optim.optimizer import Optimizer
3135
from torchvision import datasets, models, transforms
36+
import torchvision
3237

3338
import numpy as np
39+
import numpy.typing as npt
3440
from pathlib import Path
3541
from sklearn.model_selection import train_test_split
3642
from sklearn.metrics import balanced_accuracy_score
43+
from sklearn.utils import class_weight
3744
import pickle
3845
from rich.traceback import install
3946
from rich.progress import track
4047
from typing import Tuple
48+
from utility import imshow
4149

4250
# SETUPS
4351
######################################################################################
4452

4553
# presenting better error messages using rich
4654
install()
4755

48-
ROOT = Path.cwd()
4956
DATA_DIR = Path(
5057
"/mnt/c/Users/TFH/NGI/TG Machine Learning - General/2022 ML workshop series/datasets/Rocks")
5158
TEST_SIZE = 0.3
5259
NUM_WORKERS = 12
60+
# remember to place both model and data on the same device
5361
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
62+
SAVE_PERFORMANCE = True
63+
SHOW_BATCH = True
5464

5565
# Hyperparameters
5666
BATCH_SIZE = 64
5767
LR = 0.01
5868
MOMENTUM = 0.9
59-
STEP_SIZE = 3 #decay lr every xxx epoch
60-
GAMMA = 0.1 #decay factor for multiplication
69+
STEP_SIZE = 4 #decay lr every xxx epoch
70+
GAMMA = 0.3 #decay factor for multiplication
6171
NUM_EPOCHS = 10
6272

6373
# stop randomness for model comparison
@@ -73,7 +83,7 @@
7383
transforms.RandomResizedCrop(224),
7484
transforms.RandomHorizontalFlip(),
7585
transforms.ToTensor(),
76-
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
86+
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # values for transfer learning model
7787
]),
7888
'test': transforms.Compose([
7989
transforms.Resize(256),
@@ -82,11 +92,12 @@
8292
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
8393
]),
8494
}
95+
8596
# due to lazy loading I don't allocate more memory by using the same dataset here
8697
train_dataset = datasets.ImageFolder(root=DATA_DIR, transform=data_transforms["train"])
8798
test_dataset = datasets.ImageFolder(root=DATA_DIR, transform=data_transforms["test"])
8899

89-
# create splitting indices for samples - could also use SubsetRandomSampler
100+
# create splitting indices for samples
90101
num_classes = len(train_dataset.classes)
91102
indices = np.arange(len(train_dataset))
92103
labels = train_dataset.targets
@@ -96,15 +107,32 @@
96107
train_set = Subset(train_dataset,indices=train_ind)
97108
test_set = Subset(test_dataset,indices=test_ind)
98109

110+
111+
# Testing for comparison with the ants and bees dataset. You should get accuracies over 95% on that dataset.
112+
# Note that this dataset has just 245 images in the training set and still the transfer learning model works well.
113+
# You can just uncomment this code and it should run, after you have updated with your path to the dataset
114+
# DATA_DIR_TRAIN = Path("/home/tfha/datasets/hymenoptera_data/train")
115+
# DATA_DIR_VAL = Path("/home/tfha/datasets/hymenoptera_data/val")
116+
117+
# train_dataset = datasets.ImageFolder(root=DATA_DIR_TRAIN, transform=data_transforms["train"])
118+
# test_dataset = datasets.ImageFolder(root=DATA_DIR_VAL, transform=data_transforms["test"])
119+
# num_classes = len(train_dataset.classes)
120+
# train_set = train_dataset
121+
# test_set = test_dataset
122+
123+
99124
train_dataloader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True,num_workers=NUM_WORKERS)
100125
test_dataloader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False,num_workers=NUM_WORKERS)
101126

102-
# sample = iter(train_dataloader).next()
103-
# sample[0] # batch of images
104-
# sample[1] # batch of labels
105-
# image1 = sample[0][0]
106-
# image1.shape
127+
#VISUALIZE A BATCH OF DATA
128+
######################################################################################
107129

130+
if SHOW_BATCH:
131+
# Get a batch of training data
132+
inputs, classes = next(iter(train_dataloader))
133+
# Make a grid from batch
134+
out = torchvision.utils.make_grid(inputs)
135+
imshow(out, title=[train_dataset.classes[x] for x in classes])
108136

109137
# METHODS FOR TRAINING AND EVALUATION
110138
######################################################################################
@@ -114,17 +142,17 @@ def train_epoch(
114142
optimizer: Optimizer,
115143
loss_function: nn.CrossEntropyLoss,
116144
dataloader: DataLoader,
117-
):
145+
)->Tuple[float, npt.NDArray, npt.NDArray]:
118146
"""Train model for all samples in one epoch.
119147
Returning loss, labels, predictions"""
120148

121149
epoch_loss = []
122-
epoch_labels = np.array(())
123-
epoch_preds = np.array(())
150+
epoch_labels: npt.NDArray = np.array(())
151+
epoch_preds: npt.NDArray = np.array(())
124152

125153
# looping over all batches of samples
126154
for images, labels in track(dataloader,description="Training batches: "):
127-
images = images.to(device)
155+
images = images.to(device) #sending data to gpu or cpu
128156
labels = labels.to(device)
129157

130158
logits = model(images) # forward pass
@@ -148,13 +176,13 @@ def test_epoch(
148176
model: nn.Module,
149177
loss_function: nn.CrossEntropyLoss,
150178
dataloader: DataLoader,
151-
)->Tuple[float, list, list]:
179+
)->Tuple[float, npt.NDArray, npt.NDArray]:
152180
"""Test model for all samples in one epoch.
153181
Returning loss, labels, predictions"""
154182

155183
epoch_loss = []
156-
epoch_labels = np.array(())
157-
epoch_preds = np.array(())
184+
epoch_labels: npt.NDArray = np.array(())
185+
epoch_preds: npt.NDArray = np.array(())
158186

159187
# looping over all batches of samples
160188
for images, labels in track(dataloader,description="Testing batches: "):
@@ -177,8 +205,9 @@ def test_epoch(
177205
# DEFINE NETWORK, LOSSFUNCTION, OPTIMIZER, LR-SCHEDULER
178206
######################################################################################
179207
print(f"cuda is available: {torch.cuda.is_available()}. Device is {DEVICE}")
180-
model = models.resnet18(pretrained=True) #pretrained on 1000-class Imagenet database
181-
#turn of gradient update (learning) in backbone model
208+
model = models.resnet50(pretrained=True) #pretrained on 1000-class Imagenet database
209+
210+
#turn off gradient update (learning) in backbone model
182211
for param in model.parameters():
183212
param.requires_grad = False
184213

@@ -188,7 +217,12 @@ def test_epoch(
188217

189218
model.to(DEVICE)
190219

191-
loss_function = nn.CrossEntropyLoss()
220+
weights = class_weight.compute_class_weight(class_weight="balanced",
221+
classes=np.unique(train_dataset.targets),
222+
y=train_dataset.targets)
223+
weights = torch.tensor(weights).to(DEVICE)
224+
225+
loss_function = nn.CrossEntropyLoss(weight=weights.float())
192226
optimizer = optim.SGD(model.fc.parameters(), lr=LR, momentum=MOMENTUM)
193227
LR_scheduler = lr_scheduler.StepLR(optimizer, step_size=STEP_SIZE, gamma=GAMMA)
194228

@@ -199,46 +233,51 @@ def test_epoch(
199233
performance = []
200234
performance_path = Path("Reports/rock_classification_performance.pkl")
201235
tensorboard_path = Path("Reports/tensorboard_logs")
202-
writer = SummaryWriter(log_dir=tensorboard_path)
236+
writer = SummaryWriter(log_dir=tensorboard_path) # defines the Tensorboard writer
203237

204238
for epoch in range(NUM_EPOCHS):
205239
# train
206240
model.train() # sets model in training mode. Turns on gradient update
207241
loss_train, train_labels, train_predictions = train_epoch(DEVICE,model, optimizer, loss_function,train_dataloader)
208-
# acc_train = accuracy(train_predictions, train_labels, average="macro") # macro is balanced accuracy
209242
acc_train = balanced_accuracy_score(train_labels, train_predictions)
210243

211244
# test
212245
model.eval() # Freeze model weights. No model update
213246
loss_test, labels, predictions = test_epoch(DEVICE, model, loss_function, test_dataloader)
214247
acc_test = balanced_accuracy_score(labels, predictions)
215248

249+
LR_scheduler.step()
250+
216251
# report metrics
217-
print(f"Train-loss: {loss_train:.3f}. Train-acc: {acc_train:.2f}. \
218-
Test-loss: {loss_test:.3f}. Test-acc: {acc_test:.3f}")
219-
220-
# add data to Tensorboard for live reporting
221-
writer.add_scalars("Loss development",{
222-
"Loss train": loss_train,
223-
"Loss test": loss_test
224-
}, global_step=epoch)
225-
writer.add_scalars("Accuracy development",{
226-
"Accuracy train": acc_train,
227-
"Accuracy test": acc_test
228-
}, global_step=epoch)
229-
230-
# save data to pickle file every epoch. Load data for later analysis
231-
performance.append({
232-
'epoch': epoch + 1, #epoch counts from 0
233-
'train_loss': loss_train,
234-
'train_acc': acc_train,
235-
'test_loss': loss_test,
236-
'test_acc': acc_test,
237-
'test_labels':labels,
238-
'test_predictions':predictions
239-
})
240-
pickle.dump(performance, open(performance_path, 'wb'))
252+
current_lr = LR_scheduler.get_last_lr()
253+
print(f"Epoch: {epoch}. Train-loss: {loss_train:.3f}. Train-acc: {acc_train:.2f}. Test-loss: {loss_test:.3f}. Test-acc: {acc_test:.3f}. LR: {current_lr}")
254+
255+
256+
if SAVE_PERFORMANCE:
257+
# add data to Tensorboard for inspection of training development
258+
writer.add_scalars("Loss development",{
259+
"Loss train": loss_train,
260+
"Loss test": loss_test
261+
}, global_step=epoch)
262+
writer.add_scalars("Accuracy development",{
263+
"Accuracy train": acc_train,
264+
"Accuracy test": acc_test
265+
}, global_step=epoch)
266+
267+
# save data to pickle file every epoch. Load data for later analysis of results
268+
performance.append({
269+
'epoch': epoch + 1, #epoch counts from 0
270+
'train_loss': loss_train,
271+
'train_acc': acc_train,
272+
'test_loss': loss_test,
273+
'test_acc': acc_test,
274+
'test_labels':labels,
275+
'test_predictions':predictions,
276+
'class_names':train_dataset.classes
277+
})
278+
pickle.dump(performance, open(performance_path, 'wb'))
241279

242280
writer.close()
243281
# save trained model for later predictions and analysis
244-
torch.save(model, Path("Reports/rock_model.pth"))
282+
if SAVE_PERFORMANCE:
283+
torch.save(model, Path("Reports/rock_model.pth"))

src/utility.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,18 @@ def plot_features(self, X: npt.NDArray)->None:
5757
plt.show()
5858

5959

60+
def imshow(inp, title=None):
61+
"""Imshow for Tensor."""
62+
inp = inp.numpy().transpose((1, 2, 0))
63+
mean = np.array([0.485, 0.456, 0.406])
64+
std = np.array([0.229, 0.224, 0.225])
65+
inp = std * inp + mean
66+
inp = np.clip(inp, 0, 1)
67+
plt.imshow(inp)
68+
if title is not None:
69+
plt.title(title)
70+
plt.pause(0.001) # pause a bit so that plots are updated
71+
6072

6173
def examplify():
6274
"""Examplifies functionality"""

0 commit comments

Comments
 (0)