-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
184 lines (147 loc) · 5.17 KB
/
utils.py
File metadata and controls
184 lines (147 loc) · 5.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
# utils.py
import os
import csv
import torch
import random
import numpy as np
import yaml
import logging
import matplotlib.pyplot as plt
def load_config(config_path):
    """
    Load and parse a YAML configuration file.

    Args:
        config_path (str): Path to the YAML config file.

    Returns:
        dict: Parsed configuration as returned by ``yaml.safe_load``
        (may be ``None`` if the file contains no YAML document).
    """
    # YAML is a Unicode format; decode explicitly as UTF-8 rather than the
    # platform's locale encoding (e.g. cp1252 on Windows), which would
    # garble or reject non-ASCII keys and values.
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    return config
def set_seed(seed):
    """
    Seed every random number generator used by the project for repeatable runs.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch's CPU
    generator, plus all CUDA devices when CUDA is available. Passing
    ``None`` leaves every generator untouched.

    Args:
        seed (int or None): Seed value, or None to skip seeding.
    """
    if seed is None:
        return
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
def save_model(model, path, epoch, loss, optimizer=None):
    """
    Save the model state and training metadata to a checkpoint file.

    Creates the parent directory if needed (a bare filename is written to
    the current working directory). Writes a dict with the keys
    ``model_state_dict``, ``epoch`` and ``loss``. When an optimizer is
    given, it also writes ``optimizer_state_dict``, the key ``load_model``
    reads to restore optimizer state.

    Args:
        model (torch.nn.Module): The model to save.
        path (str): Destination file path.
        epoch (int): Epoch at which the model is saved.
        loss (float): Validation loss at the time of saving.
        optimizer (torch.optim.Optimizer, optional): If provided, its state
            is stored under ``optimizer_state_dict``. Defaults to None.
    """
    parent = os.path.dirname(path)
    # os.path.dirname("model.pt") is '' and os.makedirs('') raises
    # FileNotFoundError, so only create a directory when there is one.
    if parent:
        os.makedirs(parent, exist_ok=True)
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'epoch': epoch,
        'loss': loss,
    }
    if optimizer is not None:
        checkpoint['optimizer_state_dict'] = optimizer.state_dict()
    torch.save(checkpoint, path)
def load_model(model, optimizer, path, device):
    """
    Load a checkpoint into ``model`` and, when possible, into ``optimizer``.

    Args:
        model (torch.nn.Module): Model to load the state into.
        optimizer (torch.optim.Optimizer or None): Optimizer to restore, or
            None to skip. Its state is restored only if the checkpoint
            contains ``optimizer_state_dict``; otherwise a warning is logged
            and the optimizer is left unchanged.
        path (str): Path to the saved checkpoint.
        device (torch.device): Device to map the stored tensors to.

    Returns:
        tuple: ``(epoch, loss)`` as stored in the checkpoint.

    Raises:
        KeyError: If the checkpoint lacks ``model_state_dict``, ``epoch``
            or ``loss``.
    """
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    if optimizer is not None:
        # Not every checkpoint carries optimizer state (save_model can write
        # checkpoints without it), so a missing key must not abort the load.
        optimizer_state = checkpoint.get('optimizer_state_dict')
        if optimizer_state is not None:
            optimizer.load_state_dict(optimizer_state)
        else:
            logging.getLogger(__name__).warning(
                "Checkpoint %s has no optimizer state; optimizer left unchanged.",
                path)
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    return epoch, loss
def get_device():
    """
    Pick the compute device for models and tensors.

    Returns:
        torch.device: ``cuda`` when a CUDA device is available, else ``cpu``.
    """
    if torch.cuda.is_available():
        return torch.device('cuda')
    return torch.device('cpu')
def get_logger(log_directory, job_name):
    """
    Create or reuse a named logger that writes to a file and to the console.

    On the first call for ``job_name`` the logger gets two INFO-level
    handlers: a file handler on ``<log_directory>/<job_name>_log.txt``,
    added first, and a stream handler to stderr. Both use the format
    ``"HH:MM:SS LEVEL: message"``. Later calls with the same ``job_name``
    attach no new handlers, even if ``log_directory`` differs.

    Args:
        log_directory (str): Directory for log files; created if missing.
        job_name (str): Logger name, also used to name the log file.

    Returns:
        logging.Logger: The configured logger.
    """
    os.makedirs(log_directory, exist_ok=True)
    job_logger = logging.getLogger(job_name)
    job_logger.setLevel(logging.INFO)
    if job_logger.handlers:
        # Already configured by an earlier call; don't stack duplicate handlers.
        return job_logger

    log_file = os.path.join(log_directory, f"{job_name}_log.txt")
    file_handler = logging.FileHandler(log_file)
    console_handler = logging.StreamHandler()
    shared_format = logging.Formatter(
        '%(asctime)s %(levelname)s: %(message)s', datefmt='%H:%M:%S')
    for handler in (file_handler, console_handler):
        handler.setLevel(logging.INFO)
        handler.setFormatter(shared_format)
        job_logger.addHandler(handler)
    return job_logger
def plot_metric_over_epochs(history, metric_name, output_path):
    """
    Plot the train and validation curves of one metric and save them as a PNG.

    Reads ``history["train_<metric_name>"]`` and
    ``history["val_<metric_name>"]``, plotting each against its zero-based
    list index. Saves the result to
    ``<output_path>/<metric_name>_over_epochs.png``.

    Args:
        history (dict): Metric history with ``train_<metric>`` and
            ``val_<metric>`` sequences.
        metric_name (str): Name of the metric to plot (e.g. ``"loss"``).
        output_path (str): Directory to save the plot in; created if missing.

    Raises:
        KeyError: If either history key is missing.
    """
    label = metric_name.capitalize()
    # Use an explicit Figure and always close it, even when plotting or
    # saving fails, so repeated calls don't leak open figures.
    fig, ax = plt.subplots()
    try:
        ax.plot(history[f"train_{metric_name}"], label=f"Train {label}")
        ax.plot(history[f"val_{metric_name}"], label=f"Val {label}")
        ax.set_xlabel("Epoch")
        ax.set_ylabel(label)
        ax.set_title(f"{label} Over Epochs")
        ax.legend()
        ax.grid(True)
        os.makedirs(output_path, exist_ok=True)
        fig.savefig(os.path.join(output_path, f"{metric_name}_over_epochs.png"))
    finally:
        plt.close(fig)
def plot_all_metrics(history, output_path, metrics=("loss",)):
    """
    Plot each requested metric over epochs, one PNG per metric.

    Args:
        history (dict): Metric history. For every name in ``metrics`` it
            must contain ``train_<name>`` and ``val_<name>`` entries.
        output_path (str): Directory to save the plots.
        metrics (iterable of str): Metric names to plot. Defaults to
            ``("loss",)``. Pass e.g. ``("loss", "accuracy", "coverage")``
            when the history tracks those as well.
    """
    for metric in metrics:
        plot_metric_over_epochs(history, metric, output_path)
def save_all_jobs_results(all_jobs_metrics, output_path, fieldnames=None):
    """
    Save the results of all jobs (hyperparameters and final averaged metrics) to a CSV.

    Each dict in ``all_jobs_metrics`` becomes one row, in ``fieldnames``
    column order. Columns missing from a row are written as empty cells. A
    row key not listed in ``fieldnames`` makes ``csv.DictWriter`` raise
    ``ValueError``. Nothing is written if ``all_jobs_metrics`` is empty.

    Args:
        all_jobs_metrics (list of dict): One dict per job.
        output_path (str): Destination CSV file, overwritten if it exists.
            Its parent directory is created if missing.
        fieldnames (list of str, optional): Column order. Defaults to the
            standard job columns (name, hyperparameters, averaged metrics).
    """
    if not all_jobs_metrics:
        print("No job metrics to save.")
        return
    if fieldnames is None:
        fieldnames = [
            'job_name', 'learning_rate', 'epochs', 'batch_size',
            'hidden_size', 'num_layers', 'bidirectional',
            'avg_train_loss', 'avg_val_loss',
            'avg_mask2_coverage', 'avg_mask2_accuracy'
        ]
    parent = os.path.dirname(output_path)
    # Match the other save helpers: create the target directory on demand.
    # A bare filename has no directory, and os.makedirs('') would raise.
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(output_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(all_jobs_metrics)