-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrun_experiment.py
More file actions
245 lines (203 loc) · 9.43 KB
/
run_experiment.py
File metadata and controls
245 lines (203 loc) · 9.43 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
import argparse
import yaml
import mlflow
import pandas as pd
import torch
import deepxde as dde
from importlib import import_module
import random
import numpy as np
from src.utils.analysis import fft1D
# Global numeric precision for the whole experiment: new torch tensors default
# to float64, and DeepXDE's real-number precision is set to float64 via
# `dde.config.real.set_float64()`. These run at import time, before any model
# or data object is built.
torch.set_default_dtype(torch.float64)
dde.config.real.set_float64()
def setup_device(config):
    """Choose the torch device the experiment runs on.

    The requested device is read from ``config['execution']['device']`` and
    defaults to ``'gpu'``. The CPU is chosen when the request is ``'cpu'``
    (case-insensitive) or when CUDA is unavailable; otherwise the first CUDA
    device, ``cuda:0``, is used. The choice is printed.

    Args:
        config: Parsed experiment configuration dict.

    Returns:
        torch.device: The selected device.
    """
    requested = config.get('execution', {}).get('device', 'gpu')
    wants_gpu = requested.lower() != 'cpu' and torch.cuda.is_available()
    if not wants_gpu:
        print("--- Running on CPU ---")
        return torch.device("cpu")
    print(f"--- Running on GPU: {torch.cuda.get_device_name(0)} ---")
    return torch.device("cuda:0")
# --- Helper to dynamically load classes ---
def load_class(module_path, class_name):
    """Import ``module_path`` and return the attribute named ``class_name``.

    Used to resolve the problem, model and trainer classes named in the config.

    Args:
        module_path: Dotted module path, e.g. ``'src.problems'``.
        class_name: Attribute to fetch from the imported module.

    Returns:
        The looked-up attribute (normally a class).

    Raises:
        ImportError: If the module cannot be imported.
        AttributeError: If the module has no attribute ``class_name``.
    """
    return getattr(import_module(module_path), class_name)
def set_random_seed(seed):
    """Seed every RNG the experiment uses, for reproducibility.

    Seeds Python's ``random`` module, NumPy's global RNG and torch (the CPU
    generator, plus every CUDA device when CUDA is available). It also puts
    cuDNN into deterministic, non-benchmarking mode.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for run-to-run determinism. The attribute
    # is `torch.backends.cudnn`; a misspelled name raises AttributeError.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    print(f"--- Set all random seeds to {seed} ---")
def _load_ic_into_net(net, problem, config):
    """Compute an FFT of the initial condition and hand it to ``net.load_ic_data``."""
    print("Model has 'load_ic_data' method. Performing pre-training IC analysis.")
    # Sample the initial condition on a uniform grid of 2048 points over [x_min, x_max].
    ic_grid_x = np.linspace(problem.x_min, problem.x_max, 2048).reshape(-1, 1)
    ic_func = problem.get_initial_condition_func()
    ic_grid_u = ic_func(ic_grid_x).flatten()
    # Number of fixed frequencies to extract (config default: 50).
    n_freq_fixed = config.get('model', {}).get('n_freq_fixed', 50)
    k_vals, a_vals, b_vals = fft1D(ic_grid_u, ic_grid_x.flatten(), n_freq_fixed)
    net.load_ic_data(k_vals, a_vals, b_vals)


def _build_pde_data(problem, config):
    """Create the DeepXDE ``TimePDE`` data object for ``problem``.

    If the problem provides ``get_test_data``, no reference solution is
    attached; test data is evaluated separately after training. Otherwise
    ``problem.analytical_solution`` is passed as ``solution`` together with
    ``num_test`` test points.
    """
    data_cfg = config['data']
    kwargs = {
        'num_domain': data_cfg['num_domain'],
        'num_boundary': data_cfg['num_boundary'],
        'num_initial': data_cfg['num_initial'],
    }
    if not hasattr(problem, 'get_test_data'):
        print("Using analytical solution for test data.")
        kwargs['solution'] = problem.analytical_solution
        kwargs['num_test'] = data_cfg['num_test']
    return dde.data.TimePDE(problem.geomtime, problem.pde, problem.get_ics_bcs(), **kwargs)


def _log_final_results(problem, dde_model, device, log_points_np, run_name):
    """Evaluate the trained model against reference values and plot the results.

    Uses ``problem.get_test_data()`` when it exists and returns non-None data.
    In that case the final relative L2 error is also printed and logged to
    MLflow, so this must run inside an active MLflow run. Otherwise the
    predictions are compared with ``problem.analytical_solution`` evaluated
    on ``log_points_np``.
    """
    from src.utils.plotting import create_final_solution_plots

    X_test, y_test = None, None
    if hasattr(problem, 'get_test_data'):
        X_test, y_test = problem.get_test_data()

    if X_test is not None and y_test is not None:
        final_preds = dde_model.predict(X_test)
        final_l2_error = dde.metrics.l2_relative_error(y_test, final_preds)
        print(f"Final L2 Relative Error: {final_l2_error:.6f}")
        mlflow.log_metric("final_l2_relative_error", final_l2_error)
        if y_test.shape[1] > 1:
            # Multi-output: compare the amplitude sqrt(c0^2 + c1^2) of the first
            # two channels (presumably real/imaginary parts -- TODO confirm).
            model_vals = np.sqrt(final_preds[:, 0]**2 + final_preds[:, 1]**2)
            truth_vals = np.sqrt(y_test[:, 0]**2 + y_test[:, 1]**2)
        else:
            model_vals, truth_vals = final_preds, y_test
        points = X_test
        plot_amplitude = problem.get_plot_amplitude()
    else:
        # analytical_solution is called with a torch tensor on `device`, and
        # its result is moved back to the CPU.
        log_points_tensor = torch.from_numpy(log_points_np).to(device)
        model_vals = dde_model.predict(log_points_np)
        true_vals_tensor = problem.analytical_solution(log_points_tensor)
        truth_vals = true_vals_tensor.cpu().numpy()
        points = log_points_np
        # NOTE(review): this branch reads `problem.plot_amplitude`, while the
        # test-data branch calls `problem.get_plot_amplitude()`. Confirm that
        # both exist on every problem class.
        plot_amplitude = problem.plot_amplitude

    df = pd.DataFrame({
        'x': points[:, 0],
        'time': points[:, 1],
        'model': model_vals.flatten(),
        'ground_truth': truth_vals.flatten(),
        'difference': (model_vals - truth_vals).flatten(),
    })
    create_final_solution_plots(df, run_name, plot_amplitude)


def main(config_path, seed):
    """Run one PINN experiment described by a YAML config and log it to MLflow.

    Steps: load the config; pick a device; build the problem, network and
    trainer from the class names in the config; build the DeepXDE data and
    model; train inside an MLflow run, logging params, callbacks and training
    time; then evaluate, log and plot the final solution.

    Args:
        config_path: Path to the experiment YAML config.
        seed: Optional integer. When given, it is appended to the MLflow run
            name and logged as ``random_seed``. NOTE(review): the seeding
            call below is commented out, so ``seed`` currently only labels
            the run and does not make it reproducible.
    """
    import time

    # 1. Load configuration.
    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)
    # if seed is not None:
    #     set_random_seed(seed)
    device = setup_device(config)

    # OVERFITTING EXPERIMENT (disabled)
    # Set the sine wave frequency to a value depending on the seed
    # if seed is not None:
    #     sine_freq = 4 * np.pi * (seed % 10 + 1)
    #     config['problem']['ic_params']['sum_sine_components'] = [(5.0, sine_freq)]
    #     # alpha is decreasing with seed
    #     config['problem']['alpha'] = 0.1 / ((seed % 10 + 1)*20)

    # 2. Resolve the problem, model and trainer classes named in the config.
    ProblemClass = load_class('src.problems', config['problem']['name'])
    ModelClass = load_class('src.models.all_models', config['model']['name'])
    TrainerClass = load_class('src.trainers.standard_trainer', config['trainer']['name'])

    # 3. Instantiate components. A missing `output_transform` key means "no transform".
    problem = ProblemClass(config)
    if config['model'].get('output_transform'):
        net = ModelClass(config, output_transform=problem.get_output_transform()).to(device)
    else:
        net = ModelClass(config).to(device)
    trainer = TrainerClass(config)

    # Models exposing `load_ic_data` are given an FFT of the initial condition before training.
    if hasattr(net, 'load_ic_data'):
        _load_ic_into_net(net, problem, config)

    # 4. DeepXDE data object and model.
    data = _build_pde_data(problem, config)
    dde_model = dde.Model(data, net)

    # 5. MLflow: get or create the experiment, then name the run (with the seed if given).
    mlflow.set_tracking_uri(config['mlflow_tracking_uri'])
    experiment = mlflow.get_experiment_by_name(config['experiment_name'])
    if not experiment:
        mlflow.create_experiment(name=config['experiment_name'])
        experiment = mlflow.get_experiment_by_name(config['experiment_name'])
    base_run_name = config['run_name']
    run_name_with_seed = f"{base_run_name}_seed_{seed}" if seed is not None else base_run_name

    with mlflow.start_run(experiment_id=experiment.experiment_id, run_name=run_name_with_seed):
        mlflow.log_params(config)
        if seed is not None:
            mlflow.log_param("random_seed", seed)
        mlflow.log_param("num_model_parameters", sum(p.numel() for p in net.parameters()))

        # Callbacks. GradientImbalanceLogger(log_every=300) from the same module
        # can be appended to `callbacks` for gradient diagnostics.
        from src.utils.callbacks import MLFlowMetricsLogger, PredictionLogger, ModelParameterLogger
        log_every = 300
        log_points_np = problem.geomtime.uniform_points(200 * 200)
        # Logs losses and L2 error.
        metrics_logger = MLFlowMetricsLogger(log_every=log_every)
        # Logs prediction history (for animations later).
        prediction_logger = PredictionLogger(
            log_points=log_points_np,
            log_every=log_every,
            run_name=config['run_name'],
        )
        # Logs learned model parameters.
        param_logger = ModelParameterLogger(
            net=net,
            log_every=log_every,
            run_name=config['run_name'],
        )
        pde_resampler = dde.callbacks.PDEPointResampler(period=log_every)
        callbacks = [pde_resampler, metrics_logger, prediction_logger, param_logger]

        # 6. Train. Timing uses CUDA events on GPU and wall-clock time on CPU,
        # so training time is logged on both.
        on_cuda = device.type == 'cuda'
        if on_cuda:
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            start_event.record()
        else:
            start_wall = time.perf_counter()
        trainer.train(dde_model, callbacks=callbacks)
        if on_cuda:
            end_event.record()
            torch.cuda.synchronize()
            elapsed_time = start_event.elapsed_time(end_event) / 1000.0  # ms -> s
        else:
            elapsed_time = time.perf_counter() - start_wall
        mlflow.log_metric("training_time_seconds", elapsed_time)

        # 7. Final logging and visualization.
        print("--- Final Logging and Visualization ---")
        net.log_specific_params(mlflow)
        _log_final_results(problem, dde_model, device, log_points_np, config['run_name'])
if __name__ == "__main__":
    # Command-line entry point: parse --config (required) and --seed (optional),
    # then run a single experiment.
    cli_parser = argparse.ArgumentParser()
    cli_parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="Path to the experiment config YAML file.",
    )
    cli_parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Random seed for the experiment.",
    )
    cli_args = cli_parser.parse_args()
    main(cli_args.config, cli_args.seed)