Skip to content

Commit 7dc5ff1

Browse files
committed
Refactor dataset handling and enhance evaluation: update NatureFromFolder class for test dataset size, add evaluate_and_plot function in test.py, and improve config settings.
1 parent 588c5d2 commit 7dc5ff1

File tree

5 files changed

+109
-24
lines changed

5 files changed

+109
-24
lines changed

data/results.csv

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@ experiment,time,psnr,ssim,loss_critic,loss_generator
22
7_first_experiment,00:16:50,11.103890439714435,0.34202152,1.099349385527171,0.1359923942584425
33
8_model_compile,00:16:49,19.3652867520302,0.4844975471496582,-1.2957792903326415,0.0911528160242729
44
9_transform,00:17:01,1.7978309320343853,-0.0438352711498737,-7.376533726493926,0.7896081656897993
5-
10_normal,00:16:51,8.219669060046096,0.31684741377830505,-31.28607681788266,1.967616661362452
6-
14_multistep,04:07:37,23.004348053166062,0.5242958068847656,-2.5751562796301553,-0.035972834620931246
7-
15_cosine,04:02:15,21.85281842397048,0.5211549997329712,-1.9651767669255815,0.23522654886764408
8-
15_cosine_warm_restarts,04:06:25,22.68296964537476,0.5351459980010986,-1.6454665754569295,-0.03318920495223519
9-
17_step,04:06:31,14.040152634682375,0.2926912009716034,,
10-
18_gan_custom,04:05:33,23.52549501689562,0.5291714072227478,-2.1092974134947307,-0.043675970456390975
11-
19_exponential,04:13:36,22.58134529896099,0.5127521753311157,-1.2576447154311854,0.07773792746407023
12-
20_lambda,04:07:58,22.653632716710888,0.5084766745567322,-1.1908818194046753,0.12637249500213202
5+
10_normal,00:16:51,8.219669060046096,0.316847413778305,-31.28607681788266,1.967616661362452
6+
14_multistep,04:07:37,23.004348053166066,0.5242958068847656,-2.5751562796301557,-0.0359728346209312
7+
15_cosine,04:02:15,21.85281842397048,0.5211549997329712,-1.9651767669255813,0.235226548867644
8+
15_cosine_warm_restarts,04:06:25,22.68296964537476,0.5351459980010986,-1.6454665754569295,-0.0331892049522351
9+
17_step,04:06:31,14.040152634682377,0.2926912009716034,,
10+
18_gan_custom,04:05:33,23.52549501689562,0.5291714072227478,-2.1092974134947307,-0.0436759704563909
11+
19_exponential,04:13:36,22.58134529896099,0.5127521753311157,-1.2576447154311854,0.0777379274640702
12+
20_lambda,04:07:58,22.653632716710888,0.5084766745567322,-1.1908818194046753,0.126372495002132
13+
21_Custom_new_dataset,12:40:16,30.987861233820492,0.6934245824813843,-0.8939516761237561,0.09604286165589182

src/test.py

Lines changed: 83 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,94 @@
11
import torch
22
from torch import optim
3+
import matplotlib.pyplot as plt
4+
import numpy as np
5+
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
6+
from tqdm import tqdm
7+
38

49
from src.esrgan.model import Generator
510
from src.utils import config
611
from src.utils.utils import load_checkpoint, plot_examples
712
from src.utils.utils import seed_torch
13+
from src.utils.data_loaders import get_loaders
14+
15+
import matplotlib.pyplot as plt
16+
import numpy as np
17+
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
18+
19+
20+
def evaluate_and_plot(test_loader, gen, num_samples=3):
    """Evaluate a super-resolution generator on a test set and plot samples.

    Runs ``gen`` over every batch in ``test_loader``, computes per-image PSNR
    and SSIM against the ground-truth high-resolution images, prints the
    averages, and saves/shows a grid of the first ``num_samples`` examples
    (LR / SR / HR side by side).

    Args:
        test_loader: iterable yielding ``(low_res, high_res)`` batch tensors.
            # assumes tensors are (B, C, H, W) with values in [0, 1] — TODO confirm
        gen: generator network mapping low-res to high-res tensors.
        num_samples: number of examples to include in the saved figure.
    """
    psnr_list = []
    ssim_list = []
    selected_samples = []

    loop = tqdm(test_loader, desc="Evaluating", leave=False)
    with torch.no_grad():
        for idx, (low_res, high_res) in enumerate(loop):
            low_res = low_res.to(config.DEVICE)
            high_res = high_res.to(config.DEVICE)

            # Generate the super-resolved images.
            fake_high_res = gen(low_res)

            # Compute PSNR and SSIM per image in the batch.
            for i in range(low_res.shape[0]):
                sr_img = fake_high_res[i].cpu().numpy().transpose(1, 2, 0)
                hr_img = high_res[i].cpu().numpy().transpose(1, 2, 0)
                lr_img = low_res[i].cpu().numpy().transpose(1, 2, 0)

                # Clamp to [0, 1] so the metrics' data_range assumption holds.
                sr_img = np.clip(sr_img, 0, 1)
                hr_img = np.clip(hr_img, 0, 1)
                lr_img = np.clip(lr_img, 0, 1)

                psnr = peak_signal_noise_ratio(hr_img, sr_img, data_range=1.0)
                ssim = structural_similarity(
                    hr_img, sr_img, channel_axis=2, data_range=1.0
                )
                psnr_list.append(psnr)
                ssim_list.append(ssim)

                # Keep the first num_samples examples for plotting.
                if len(selected_samples) < num_samples:
                    selected_samples.append((lr_img, hr_img, sr_img, psnr, ssim))

    # Report the averages over the whole test set.
    print(f"Average PSNR: {np.mean(psnr_list):.4f}")
    print(f"Average SSIM: {np.mean(ssim_list):.4f}")

    # squeeze=False keeps axes 2-D even when num_samples == 1, so the
    # axes[i, j] indexing below never raises IndexError.
    fig, axes = plt.subplots(
        num_samples, 3, figsize=(10, 3 * num_samples), squeeze=False
    )
    # Fixed a typo in the title: the SSIM value was missing its ": " separator.
    fig.suptitle(
        f"Super-Resolution Evaluation Samples\nPSNR: {np.mean(psnr_list):.4f} SSIM: {np.mean(ssim_list):.4f}",
        fontsize=16,
    )
    for i, (lr, hr, sr, psnr, ssim) in enumerate(selected_samples):
        axes[i, 0].imshow(lr)
        axes[i, 0].set_title("LR")
        axes[i, 0].axis("off")

        axes[i, 1].imshow(sr)
        axes[i, 1].set_title(f"SR PSNR: {psnr:.2f} SSIM: {ssim:.4f}")
        axes[i, 1].axis("off")

        axes[i, 2].imshow(hr)
        axes[i, 2].set_title("HR")
        axes[i, 2].axis("off")

    plt.tight_layout()
    plt.savefig(
        f"{config.SAVE_PATH}/evaluation_samples.png", dpi=300, bbox_inches="tight"
    )
    plt.show()
886

987
def test():
    """Run the trained generator over the held-out test split.

    Seeds all RNGs for reproducibility, builds the test loader, restores the
    generator weights from the checkpoint, and delegates metric computation
    and plotting to :func:`evaluate_and_plot`.
    """
    seed_torch(config.SEED)

    # Only the test split is needed here; train/val loaders are discarded.
    _, _, test_loader = get_loaders()

    generator = Generator(in_channels=3).to(config.DEVICE)
    generator.eval()
    load_checkpoint(config.CHECKPOINT_GEN, generator)

    print("Testing the model...")
    evaluate_and_plot(test_loader, generator, num_samples=12)

src/utils/config.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
LAMBDA_GP = 10
1414
NUM_WORKERS = 12
1515
NUM_EPOCHS = 25
16+
TEST_SIZE = 1000
1617

1718
BATCH_SIZE = 32
1819
HIGH_RES = 128
@@ -30,8 +31,12 @@
3031
LOG_DIR = "logs"
3132
DATA_SET_NAME_DIR = "esrgan_dataset"
3233
EXPERIMENT = VERSION + "_" + NAME
34+
35+
3336
def extension(mode_name: str) -> str:
    """Return the checkpoint filename for *mode_name*.

    The name is the mode prefix followed by the module-level EXPERIMENT and
    VERSION tags plus a ``.pth`` suffix.
    """
    suffix = EXPERIMENT + VERSION + ".pth"
    return mode_name + suffix
38+
39+
3540
SAVE_PATH = f"{LOG_DIR}/{DATA_SET_NAME_DIR}/{EXPERIMENT}"
3641
DATA = "data"
3742
CACHE_DIR = DATA + "/cache"

src/utils/dataset_from_folder.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def __init__(self, type_="train", train_ratio=0.7, val_ratio=0.1):
1515
self.class_names = os.listdir(self.root_dir)
1616
print(self.class_names[:10])
1717

18-
for index, name in enumerate(self.class_names):
18+
for name in self.class_names:
1919
files = os.path.join(self.root_dir, name)
2020
self.data.append(files)
2121
print(self.data[:10])
@@ -32,7 +32,10 @@ def __init__(self, type_="train", train_ratio=0.7, val_ratio=0.1):
3232
elif type_ == "val":
3333
self.data = self.data[train_end:val_end]
3434
elif type_ == "test":
35-
self.data = self.data[val_end:]
35+
test_size = total - val_end
36+
if test_size >= cfg.TEST_SIZE:
37+
test_size = cfg.TEST_SIZE
38+
self.data = self.data[val_end : val_end + test_size]
3639
else:
3740
raise ValueError("type_ must be 'train', 'val', or 'test'")
3841

@@ -67,4 +70,4 @@ def test():
6770

6871

6972
if __name__ == "__main__":
70-
test()
73+
test()

src/utils/utils.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def gradient_penalty(discriminator, real, fake):
2727
retain_graph=True,
2828
)[0]
2929
# Her görüntüyü vektörleştirir
30-
gradient = gradient.flatten(start_dim=1)
30+
gradient = gradient.flatten(start_dim=1)
3131
# Her örneğin gradyan uzunluğunu hesaplar
3232
gradient_norm = gradient.norm(2, dim=1)
3333
# Normun 1’den sapmasına ceza verir
@@ -44,17 +44,18 @@ def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
4444
torch.save(checkpoint, filename)
4545

4646

47-
def load_checkpoint(checkpoint_file, model, optimizer, lr):
47+
def load_checkpoint(checkpoint_file, model, optimizer=None, lr=None):
    """Load model (and optionally optimizer) state from a checkpoint file.

    Args:
        checkpoint_file: path to a checkpoint saved with ``save_checkpoint``
            (a dict with "state_dict" and "optimizer" entries).
        model: module whose weights are restored in place.
        optimizer: optional optimizer whose state is restored; only used when
            ``lr`` is also given (inference callers pass neither).
        lr: learning rate to force on every param group after restoring,
            overriding the stale value stored in the checkpoint.
    """
    print("=> Loading checkpoint")
    checkpoint = torch.load(checkpoint_file, map_location=cfg.DEVICE)
    model.load_state_dict(checkpoint["state_dict"])

    # If we don't reset the lr it will just keep the learning rate of the old
    # checkpoint, which leads to many hours of debugging.
    # Guard on BOTH arguments: checking only `lr` would crash with
    # AttributeError if a caller passed lr without an optimizer.
    if optimizer is not None and lr is not None:
        optimizer.load_state_dict(checkpoint["optimizer"])
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr
5859

5960

6061
def plot_examples(low_res_folder, gen, ex):

0 commit comments

Comments
 (0)