Skip to content

Commit 8854071

Browse files
authored
Adds a better command line interface to gauge comparisons (#746)
* Adds a better command line interface to gauge comparison Also adds some minor changes to imports and a todo regarding switching to pathlib * Add short flag versions
1 parent de24cc8 commit 8854071

File tree

1 file changed

+68
-66
lines changed

1 file changed

+68
-66
lines changed

src/pyclaw/gauges.py

Lines changed: 68 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,13 @@
22

33
"""Module contains class definitions related to dealing with gauge data"""
44

5+
# TODO: switch to pathlib for all file handling
6+
from pathlib import Path
57
import os
6-
import sys
7-
import numpy
8+
import argparse
9+
import warnings
10+
11+
import numpy as np
812

913
# See if pandas is available
1014
try:
@@ -75,7 +79,7 @@ def __init__(self, gauge_id=None, path=None, use_pandas=False):
7579
def read(self, gauge_id, path=None, use_pandas=False):
7680
r"""Read the gauge file at path into this object
7781
78-
Read in the gauge with gauge id `gauge_id` located at `path`.
82+
Read in the gauge with gauge id `gauge_id` located at `path`.
7983
8084
If `use_pandas` is `True` then `q` will be a Pandas frame.
8185
(not yet implemented)
@@ -137,10 +141,10 @@ def read(self, gauge_id, path=None, use_pandas=False):
137141

138142
#data = line.split()
139143
#nvals = 1 + len(data)-6 # should agree with num_eqn
140-
144+
141145
# Check to see if there is also data in the .txt file,
142146
# otherwise perhaps it's in a binary .bin file.
143-
147+
144148
line = gauge_file.readline()
145149
if len(line) > 0:
146150
if 'binary32' in line:
@@ -172,7 +176,7 @@ def read(self, gauge_id, path=None, use_pandas=False):
172176

173177
if file_format == 'ascii':
174178
# data follows header in .txt file:
175-
data = numpy.loadtxt(gauge_path, comments="#")
179+
data = np.loadtxt(gauge_path, comments="#")
176180
if data.ndim == 1:
177181
# only one line in gauge file, expand to 2d array
178182
data = data.reshape((1,len(data)))
@@ -184,28 +188,27 @@ def read(self, gauge_id, path=None, use_pandas=False):
184188
if not os.path.isfile(gauge_path):
185189
msg = 'No data in .txt file and did not find ' \
186190
+ '\n binary file %s' % gauge_path
187-
import warnings
188191
warnings.warn(msg)
189192
return
190193

191194
if file_format in ['binary','binary64']:
192-
data = numpy.fromfile(gauge_path, dtype=numpy.float64)
195+
data = np.fromfile(gauge_path, dtype=np.float64)
193196
elif file_format == 'binary32':
194-
data = numpy.fromfile(gauge_path, dtype=numpy.float32)
197+
data = np.fromfile(gauge_path, dtype=np.float32)
195198

196199
# assume rows are: level, t, q[0:num_eqn]
197200
nrows = 2 + num_eqn
198-
assert numpy.mod(len(data),nrows) == 0, \
201+
assert np.mod(len(data), nrows) == 0, \
199202
'*** unexpected number of values in gauge file' \
200203
+ '\n*** expected nrows = %i rows' % nrows
201204
ncols = int(len(data)/nrows)
202205
data = data.reshape((nrows,ncols), order='F').T
203206

204-
self.level = data[:, 0].astype(numpy.int64)
207+
self.level = data[:, 0].astype(np.int64)
205208
self.t = data[:, 1]
206209
self.q = data[:, 2:].transpose()
207210

208-
211+
209212
if num_eqn != self.q.shape[0]:
210213
raise ValueError("Number of fields in gauge file does not match",
211214
"recorded number in header.")
@@ -215,11 +218,11 @@ def read(self, gauge_id, path=None, use_pandas=False):
215218
if self.gtype == 'lagrangian':
216219
# for lagrangian gauges x,y paths are stored in q in place of u,v
217220
# note: only implemented in 2d geoclaw for now!
218-
self.particle_path = numpy.vstack((self.t, self.q[1,:],
221+
self.particle_path = np.vstack((self.t, self.q[1,:],
219222
self.q[2,:])).T
220223
else:
221224
# set the lagrangian path to a single fixed location at t=t0:
222-
self.particle_path = numpy.array([[self.t[0], self.location[0],
225+
self.particle_path = np.array([[self.t[0], self.location[0],
223226
self.location[1]]])
224227

225228

@@ -333,11 +336,15 @@ def compare_gauges(paths, gauge_id, fields='all'):
333336
axes.set_ylabel("q[%s, :]" % n)
334337
axes.legend()
335338

336-
axes = fig.add_subplot(len(fields), 2, 2 * i + 2)
337-
axes.plot(gauges[0].t,
338-
numpy.abs(gauges[0].q[n, :] - gauges[1].q[n, :]), 'r')
339-
axes.set_xlabel("t")
340-
axes.set_ylabel("$|q_{old}[%s, :] - q_{new}[%s, :]|$" % (n, n))
339+
if gauges[0].t.shape == gauges[1].t.shape:
340+
axes = fig.add_subplot(len(fields), 2, 2 * i + 2)
341+
axes.plot(gauges[0].t,
342+
np.abs(gauges[0].q[n, :] - gauges[1].q[n, :]), 'r')
343+
axes.set_xlabel("t")
344+
axes.set_ylabel("$|q_{old}[%s, :] - q_{new}[%s, :]|$" % (n, n))
345+
else:
346+
warnings.warn("Gauge time series do not match, skipping direct " +
347+
"comparison.")
341348

342349
return fig
343350

@@ -363,17 +370,17 @@ def convert_gauges(path, output_path=None):
363370
output_path = os.getcwd()
364371

365372
# Load old data
366-
data = numpy.loadtxt(path)
367-
old_ids = numpy.asarray(data[:, 0], dtype=int)
368-
unique_ids = numpy.asarray(list(set(old_ids)))
373+
data = np.loadtxt(path)
374+
old_ids = np.asarray(data[:, 0], dtype=int)
375+
unique_ids = np.asarray(list(set(old_ids)))
369376

370377
# Create new gauges and compare
371378
for gauge_id in unique_ids:
372-
gauge_indices = numpy.nonzero(old_ids == gauge_id)[0]
379+
gauge_indices = np.nonzero(old_ids == gauge_id)[0]
373380
new_gauge = GaugeSolution()
374381
new_gauge.id = gauge_id
375-
new_gauge.location = (numpy.nan, numpy.nan)
376-
new_gauge.level = numpy.asarray(data[gauge_indices, 1], dtype=int)
382+
new_gauge.location = (np.nan, np.nan)
383+
new_gauge.level = np.asarray(data[gauge_indices, 1], dtype=int)
377384
new_gauge.t = data[gauge_indices, 2]
378385
new_gauge.q = data[gauge_indices, 3:].transpose()
379386
new_gauge.write(output_path)
@@ -386,7 +393,7 @@ def convert_gauges(path, output_path=None):
386393

387394

388395
def compare_old_gauges(old_path, new_path, gauge_id, plot=False, abs_tol=1e-14,
389-
rel_tol=0.0,
396+
rel_tol=0.0,
390397
verbose=False):
391398
r"""Compare old gauge data at `path` to new gauge data at same path
392399
@@ -413,9 +420,9 @@ def compare_old_gauges(old_path, new_path, gauge_id, plot=False, abs_tol=1e-14,
413420
"""
414421

415422
# Load old gauge data
416-
data = numpy.loadtxt(old_path)
417-
old_ids = numpy.asarray(data[:, 0], dtype=int)
418-
gauge_indices = numpy.nonzero(old_ids == gauge_id)[0]
423+
data = np.loadtxt(old_path)
424+
old_ids = np.asarray(data[:, 0], dtype=int)
425+
gauge_indices = np.nonzero(old_ids == gauge_id)[0]
419426
q = data[gauge_indices, 3:]
420427

421428
# Load new data
@@ -425,9 +432,9 @@ def compare_old_gauges(old_path, new_path, gauge_id, plot=False, abs_tol=1e-14,
425432
if verbose:
426433
print("Comparison of gauge %s:" % gauge_id)
427434
print(r" ||\Delta q||_2 = ",
428-
numpy.linalg.norm(q - gauge.q.transpose(), ord=2))
435+
np.linalg.norm(q - gauge.q.transpose(), ord=2))
429436
print(r" arg(||\Delta q||_\infty = ",
430-
numpy.argmax(q - gauge.q.transpose()))
437+
np.argmax(q - gauge.q.transpose()))
431438

432439
if plot:
433440
import matplotlib.pyplot as plt
@@ -439,7 +446,7 @@ def compare_old_gauges(old_path, new_path, gauge_id, plot=False, abs_tol=1e-14,
439446
axes.set_xlabel("t (s)")
440447
axes.set_ylabel("q(%s, :)" % i)
441448

442-
return numpy.allclose(q, gauge.q.transpose(), rtol=rel_tol, atol=abs_tol)
449+
return np.allclose(q, gauge.q.transpose(), rtol=rel_tol, atol=abs_tol)
443450

444451

445452
def check_old_gauge_data(path, gauge_id, new_gauge_path="./regression_data"):
@@ -451,27 +458,27 @@ def check_old_gauge_data(path, gauge_id, new_gauge_path="./regression_data"):
451458
:Input:
452459
- *path* (string) - Path to old gauge data file
453460
- *gauge_id* (int) - Gauge id to compare
454-
- *new_gauge_path* (path) - Path to directory containing new gauge files,
461+
- *new_gauge_path* (path) - Path to directory containing new gauge files,
455462
defaults to './regression_data'.
456463
457464
:Output:
458-
- (figure) Returns a matplotlib figure object plotting the differences in
465+
- (figure) Returns a matplotlib figure object plotting the differences in
459466
time.
460467
"""
461468

462469
import matplotlib.pyplot as plt
463470

464471
# Load old gauge data
465-
data = numpy.loadtxt(path)
466-
old_ids = numpy.asarray(data[:, 0], dtype=int)
467-
gauge_indices = numpy.nonzero(old_ids == gauge_id)[0]
472+
data = np.loadtxt(path)
473+
old_ids = np.asarray(data[:, 0], dtype=int)
474+
gauge_indices = np.nonzero(old_ids == gauge_id)[0]
468475
q = data[gauge_indices, 3:]
469476

470477
# Load new data
471478
gauge = GaugeSolution(gauge_id, new_gauge_path)
472479

473-
print(numpy.linalg.norm(q - gauge.q.transpose(), ord=2))
474-
print(numpy.argmax(q - gauge.q.transpose()))
480+
print(np.linalg.norm(q - gauge.q.transpose(), ord=2))
481+
print(np.argmax(q - gauge.q.transpose()))
475482

476483
fig = plt.figure()
477484
for i in range(gauge.q.shape[0]):
@@ -481,34 +488,29 @@ def check_old_gauge_data(path, gauge_id, new_gauge_path="./regression_data"):
481488

482489
return fig
483490

484-
485491
if __name__ == "__main__":
486-
487492
import matplotlib.pyplot as plt
493+
parser = argparse.ArgumentParser()
494+
parser.add_argument('path', type=Path, nargs=2,
495+
help="Paths to gauge data directories")
496+
parser.add_argument('--id', '-n', type=int, nargs='+', default=None,
497+
help="Gauge IDs to plot, " +
498+
"defaults to all found in first path")
499+
parser.add_argument('--fields', '-f', type=int, nargs='+', default=None,
500+
help="Fields to plot, defaults to all")
501+
args = parser.parse_args()
502+
503+
if args.id:
504+
gauge_ids = args.id
505+
else:
506+
gauge_ids = [int(str(path.name[5:10]))
507+
for path in args.path[0].glob("gauge*.txt")]
508+
if args.fields:
509+
fields = args.fields
510+
else:
511+
fields = 'all'
512+
513+
for gauge_id in gauge_ids:
514+
fig = compare_gauges(args.path, gauge_id, fields)
488515

489-
help_msg = \
490-
"""gauges.py path1 path2 [gauge_id] [fields...]
491-
492-
Plots a comparison between the gauges at path1 and path2 with gauge_id and the
493-
fields specified. Only one gauge_id can be specified at a time but a number of
494-
fields can be specified including 'all'.
495-
"""
496-
497-
fields = 'all'
498-
gauge_id = 1
499-
if len(sys.argv) < 3:
500-
print(help_msg)
501-
sys.exit(0)
502-
elif len(sys.argv) >= 3:
503-
paths = [str(sys.argv[1]), str(sys.argv[2])]
504-
if len(sys.argv) > 3:
505-
gauge_id = int(sys.argv[3])
506-
if len(sys.argv) > 4:
507-
if sys.argv[4].lower() == 'all':
508-
fields = 'all'
509-
else:
510-
fields = [int(field for field in sys.argv[4:])]
511-
512-
513-
fig = compare_gauges(paths, gauge_id, fields)
514516
plt.show()

0 commit comments

Comments
 (0)