Skip to content

Commit 28d33d7

Browse files
authored
Merge pull request #34 from CUQI-DTU/setup_ST_experiment
Setup st experiment
2 parents bfb05c1 + 3c2bfe3 commit 28d33d7

File tree

5 files changed

+1575
-267
lines changed

5 files changed

+1575
-267
lines changed

code/ear_aqueducts/advection_diffusion_inference.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@
137137

138138
#%% STEP 3: Create output directory
139139
#----------------------------------
140-
parent_dir = 'results_jan6/'+args.version
140+
parent_dir = 'results_feb7/'+args.version
141141
dir_name = parent_dir +'/output'+tag
142142
if not os.path.exists(dir_name):
143143
os.makedirs(dir_name)
@@ -290,7 +290,7 @@
290290
callback_obj = Callback(
291291
dir_name=dir_name,
292292
exact_x=exact_x,
293-
exact_data=exact_data.reshape(G_cont2D.fun_shape),
293+
exact_data=exact_data.reshape(G_cont2D.fun_shape) if exact_data is not None else None,
294294
data=data.reshape(G_cont2D.fun_shape),
295295
args=args,
296296
locations=diff_locations if args.data_grad else locations,

code/ear_aqueducts/advection_diffusion_inference_utils.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,6 @@ def read_data_files(args):
329329
real_std_data = std_file[CA_ST_std_list].values.T
330330
if args.data_grad:
331331
print('real_data shape: ', real_data.shape)
332-
print('real data:\n' , real_data)
333332
real_data_diff = np.zeros((real_data.shape[0]-1, real_data.shape[1]))
334333

335334
for i in range(real_data.shape[0]-1):
@@ -775,7 +774,7 @@ def sample_the_posterior(sampler, posterior, G_c, args, callback=None):
775774

776775
return posterior_samples_burnthin, my_sampler
777776

778-
def plot_time_series(times, locations, data, plot_legend=True, plot_type='over_time', d3_alpha=0, marker=None, linestyle='-', colormap=None):
777+
def plot_time_series(times, locations, data, plot_legend=True, plot_type='over_time', d3_alpha=0, marker=None, linestyle='-', colormap=None, y_log=False, plot_against=None):
779778
# Plot data
780779
# plot type can be 'over_time' or 'over_location' or 'surface'
781780
if colormap is None:
@@ -819,10 +818,31 @@ def plot_time_series(times, locations, data, plot_legend=True, plot_type='over_t
819818
lines = None
820819
legends = None
821820
# rotate the plot
821+
822+
elif plot_type == 'against_data':
823+
if plot_against is None:
824+
raise Exception('plot_against must be provided when plot_type is "against_data"')
825+
if len(plot_against) != len(data):
826+
raise Exception('plot_against must have the same length as data')
827+
legends = ['loc = '+"{:.2f}".format(obs) for obs in locations]
828+
lines = []
829+
for i in range(len(locations)):
830+
lines.append(plt.scatter(plot_against[i,:], data[i,:], color=color[i%len(color)],marker=marker))
831+
832+
if plot_legend:
833+
plt.legend(lines, legends)
834+
plt.xlabel('data')
835+
plt.ylabel('reconstruction')
836+
if y_log:
837+
plt.xscale('log')
838+
822839

823840

824841
else:
825842
raise Exception('Unsupported plot type')
843+
# set y scale to log
844+
if y_log:
845+
plt.yscale('log')
826846

827847
return lines, legends
828848

0 commit comments

Comments
 (0)