-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathvalidation.py
More file actions
156 lines (121 loc) · 4.16 KB
/
validation.py
File metadata and controls
156 lines (121 loc) · 4.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
"""
Validation Script for Far-side Magnetogram Generation.
Kim, Park et al. (2019), Nature Astronomy, 3, 397
https://doi.org/10.1038/s41550-019-0711-5
This script loads a trained checkpoint and evaluates on the validation set.
Usage:
python validation.py --checkpoint ./checkpoints/checkpoint_best.pth
python validation.py --checkpoint ./checkpoints/checkpoint_best.pth --data_dir ./data
"""
from pathlib import Path
from typing import Dict
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from config import TrainConfig
from dataset import ValidationDataset
from networks import Generator
def evaluate(
    model: nn.Module,
    val_loader: DataLoader,
    device: torch.device,
) -> Dict[str, float]:
    """
    Run evaluation on the validation set.

    Args:
        model: Generator model.
        val_loader: Validation data loader yielding (euv, magnetogram) pairs.
        device: Device to run evaluation on.

    Returns:
        Dictionary with keys ``l1_loss``, ``mse_loss``, ``rmse`` (all
        per-sample averages) and ``num_samples`` (total samples seen).

    Raises:
        ValueError: If the validation loader yields no samples (previously
            this raised an opaque ``ZeroDivisionError``).
    """
    model.eval()
    criterion_l1 = nn.L1Loss()
    criterion_mse = nn.MSELoss()

    total_l1 = 0.0
    total_mse = 0.0
    total_samples = 0

    with torch.no_grad():
        for euv, magnetogram in val_loader:
            euv = euv.to(device)
            magnetogram = magnetogram.to(device)
            output = model(euv)

            batch_size = euv.size(0)
            # Weight each batch loss by its size so the final averages are
            # per-sample even when the last batch is smaller.
            total_l1 += criterion_l1(output, magnetogram).item() * batch_size
            total_mse += criterion_mse(output, magnetogram).item() * batch_size
            total_samples += batch_size

    if total_samples == 0:
        raise ValueError("Validation loader yielded no samples.")

    avg_l1 = total_l1 / total_samples
    avg_mse = total_mse / total_samples
    return {
        "l1_loss": avg_l1,
        "mse_loss": avg_mse,
        # Cast: np.sqrt returns np.float64, but the contract promises float.
        "rmse": float(np.sqrt(avg_mse)),
        "num_samples": total_samples,
    }
def main() -> None:
    """Entry point: load a checkpoint, score it on the validation split, report."""
    cfg = TrainConfig.from_args()

    # Fail fast with a clear message before touching the GPU or data.
    ckpt_file = Path(cfg.checkpoint_path)
    if not ckpt_file.exists():
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_file}")
    print(f"Loading checkpoint: {ckpt_file}")

    run_device = torch.device(cfg.device)
    ckpt = torch.load(ckpt_file, map_location=run_device)

    # Rebuild the generator with the same architecture hyperparameters
    # it was trained with, then restore its weights.
    net = Generator(
        in_channels=cfg.in_channels,
        out_channels=cfg.out_channels,
        base_features=cfg.ngf,
    )
    net.load_state_dict(ckpt["generator_state_dict"])
    net = net.to(run_device)
    print(f"Loaded checkpoint from epoch {ckpt['epoch'] + 1}")

    # Validation split lives in a fixed "valid" subdirectory of data_dir.
    dataset = ValidationDataset(
        data_dir=f"{cfg.data_dir}/valid",
        data_range=cfg.data_range,
    )
    loader = DataLoader(
        dataset,
        batch_size=cfg.batch_size,
        shuffle=False,
        num_workers=cfg.num_workers,
        pin_memory=True,
    )

    print(f"Validation samples: {len(dataset)}")
    print(f"Device: {run_device}")
    print("-" * 50)

    scores = evaluate(net, loader, run_device)

    print("Validation Results:")
    print(f"  L1 Loss: {scores['l1_loss']:.6f}")
    print(f"  MSE Loss: {scores['mse_loss']:.6f}")
    print(f"  RMSE: {scores['rmse']:.6f}")
    print(f"  Samples: {scores['num_samples']}")
    print("-" * 50)

    # Persist a small plain-text report next to the checkpoints.
    report_dir = Path(cfg.save_dir)
    report_dir.mkdir(parents=True, exist_ok=True)
    report_path = report_dir / "validation_results.txt"
    report_lines = [
        "Validation Results\n",
        "=" * 40 + "\n",
        f"Checkpoint: {ckpt_file}\n",
        f"Epoch: {ckpt['epoch'] + 1}\n",
        f"Samples: {scores['num_samples']}\n",
        "-" * 40 + "\n",
        f"L1 Loss: {scores['l1_loss']:.6f}\n",
        f"MSE Loss: {scores['mse_loss']:.6f}\n",
        f"RMSE: {scores['rmse']:.6f}\n",
    ]
    with open(report_path, "w") as f:
        f.writelines(report_lines)
    print(f"Results saved to: {report_path}")
# Run validation only when executed as a script (not on import).
if __name__ == "__main__":
    main()