Skip to content

Commit 75219a5

Browse files
committed
Revamp compare_nc_files to fully use xarray
This commit changes the current method for slicing variables in favor of xarray build in functions. It also addresses a reviewer comment about spacing and in building helper fields only for transposed output.
1 parent 753f368 commit 75219a5

File tree

2 files changed

+83
-117
lines changed

2 files changed

+83
-117
lines changed

components/eamxx/scripts/compare_nc_files.py

Lines changed: 65 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -68,125 +68,87 @@ def compare_variables(self):
6868

6969
success = True
7070
if self._compare == None or self._compare == []:
71+
self._compare = []
7172
# If compare is an empty list, compare all variables
7273
print(f"Specific comparison variables not provided,\n"
7374
f"will compare ALL variables in \n"
7475
f"{self._src_file}\n"
7576
f"with\n"
7677
f"{self._tgt_file}\n")
7778
for var in ds_src.variables:
78-
lvar = ds_src[var]
7979
if var not in ds_tgt.variables:
8080
print (f" Comparison failed! Variable not found.\n"
8181
f" - var name: {var}\n"
8282
f" - file name: {self._tgt_file}")
8383
success = False
8484
continue
85-
if self._allow_transpose:
86-
rvar = ds_tgt[var].transpose(*lvar.dims)
87-
else:
88-
rvar = ds_tgt[var]
89-
expect (lvar.dims == rvar.dims,
90-
f" Error!, variables names match, but dimensions do not.\n"
91-
f" - var name: {var}\n"
92-
f" - src dimensions: {lvar.dims}\n"
93-
f" - tgt dimensions: {rvar.dims}\n")
94-
success = self.are_equal(lvar.data,rvar.data)
95-
96-
else:
97-
for expr in self._compare:
98-
# Split the expression, to get the output var name
99-
tokens = expr.split('=')
100-
expect(len(tokens)==2,"Error! Compare variables with 'lhs=rhs' syntax.")
101-
102-
lhs = tokens[0]
103-
rhs = tokens[1]
104-
105-
lname, ldims = self.get_name_and_dims(lhs)
106-
rname, rdims = self.get_name_and_dims(rhs)
107-
108-
if lname not in ds_src.variables:
109-
print (f" Comparison failed! Variable not found.\n"
110-
f" - var name: {lname}\n"
111-
f" - file name: {self._src_file}")
112-
success = False
113-
continue
114-
if rname not in ds_tgt.variables:
115-
print (f" Comparison failed! Variable not found.\n"
116-
f" - var name: {rname}\n"
117-
f" - file name: {self._tgt_file}")
118-
success = False
119-
continue
120-
lvar = ds_src.variables[lname];
121-
rvar = ds_tgt.variables[rname];
122-
123-
lvar_rank = len(lvar.dims)
124-
rvar_rank = len(rvar.dims)
125-
126-
expect (len(ldims)==0 or len(ldims)==lvar_rank,
127-
f"Invalid slice specification for {lname}.\n"
128-
f" input request: ({','.join(ldims)})\n"
129-
f" variable rank: {lvar_rank}")
130-
expect (len(rdims)==0 or len(rdims)==rvar_rank,
131-
f"Invalid slice specification for {rname}.\n"
132-
f" input request: ({','.join(rdims)})\n"
133-
f" variable rank: {rvar_rank}")
134-
135-
136-
lslices = [[idim,slice] for idim,slice in enumerate(ldims) if slice!=":"]
137-
rslices = [[idim,slice] for idim,slice in enumerate(rdims) if slice!=":"]
138-
139-
lrank = lvar_rank - len(lslices)
140-
rrank = rvar_rank - len(rslices)
141-
142-
if lrank!=rrank:
143-
print (f" Comparison failed. Rank mismatch.\n"
144-
f" - input comparison: {expr}\n"
145-
f" - upon slicing, rank({lname}) = {lrank}\n"
146-
f" - upon slicing, rank({rname}) = {rrank}")
147-
success = False
148-
continue
149-
150-
lvals = self.slice_variable(lvar,lvar.data[:],lslices)
151-
rvals = self.slice_variable(rvar,rvar.data[:],rslices)
152-
153-
success = self.are_equal(lvals,rvals)
154-
85+
self._compare.append(var+"="+var)
86+
87+
for expr in self._compare:
88+
# Split the expression, to get the output var name
89+
tokens = expr.split('=')
90+
expect(len(tokens)==2,"Error! Compare variables with 'lhs=rhs' syntax.")
91+
92+
lhs = tokens[0]
93+
rhs = tokens[1]
94+
95+
lname, ldims = self.get_name_and_dims(lhs)
96+
rname, rdims = self.get_name_and_dims(rhs)
97+
98+
if lname not in ds_src.variables:
99+
print (f" Comparison failed! Variable not found.\n"
100+
f" - var name: {lname}\n"
101+
f" - file name: {self._src_file}")
102+
success = False
103+
continue
104+
if rname not in ds_tgt.variables:
105+
print (f" Comparison failed! Variable not found.\n"
106+
f" - var name: {rname}\n"
107+
f" - file name: {self._tgt_file}")
108+
success = False
109+
continue
110+
lvar = ds_src[lname];
111+
rvar = ds_tgt[rname];
112+
113+
lvar_rank = len(lvar.dims)
114+
rvar_rank = len(rvar.dims)
115+
116+
expect (len(ldims)==0 or len(ldims)==lvar_rank,
117+
f"Invalid slice specification for {lname}.\n"
118+
f" input request: ({','.join(ldims)})\n"
119+
f" variable rank: {lvar_rank}")
120+
expect (len(rdims)==0 or len(rdims)==rvar_rank,
121+
f"Invalid slice specification for {rname}.\n"
122+
f" input request: ({','.join(rdims)})\n"
123+
f" variable rank: {rvar_rank}")
124+
125+
lslices = {lvar.dims[idim]:int(slice)-1 for idim,slice in enumerate(ldims) if slice!=":"}
126+
rslices = {rvar.dims[idim]:int(slice)-1 for idim,slice in enumerate(rdims) if slice!=":"}
127+
lvar_sliced = lvar.sel(lslices)
128+
rvar_sliced = rvar.sel(rslices)
129+
expect (set(lvar_sliced.dims) == set(rvar_sliced.dims),
130+
f"Error, even when sliced these two elements do not share the same dimensionsn\n"
131+
f" - left var name : {lname}\n"
132+
f" - right var name : {rname}\n"
133+
f" - left dimensions : {lvar_sliced.dims}\n"
134+
f" - right dimensions: {rvar_sliced.dims}\n")
135+
136+
if self._allow_transpose:
137+
rvar_sliced = rvar_sliced.transpose(*lvar_sliced.dims)
138+
139+
equal = (lvar_sliced.data==rvar_sliced.data).all()
140+
if not equal:
141+
rse = np.sqrt((lvar_sliced.data-rvar_sliced.data)**2)
142+
nonmatch_count = np.count_nonzero(rse)
143+
print (f" Comparison failed. Values differ at {nonmatch_count} out of {rse.size} locations.\n"
144+
f" - input comparison: {expr}\n"
145+
f' - max L2 error, {rse.max()}\n'
146+
f' - max L2 location, [{",".join(map(str,(np.array(np.unravel_index(rse.argmax(),rse.shape))+1).tolist()))}]\n'
147+
f' - dimensions, {lvar_sliced.dims}')
148+
success = False
155149

156150
return success
157151

158-
###########################################################################
159-
def are_equal(self,lvals,rvals):
160-
###########################################################################
161-
162-
if not np.array_equal(lvals,rvals):
163-
# print (f"lvals: {lvals}")
164-
# print (f"rvals: {rvals}")
165-
item = np.argwhere(lvals!=rvals)[0]
166-
rval = self.slice_variable(rvar,rvals,
167-
[[idim,slice] for idim,slice in enumerate(item)])
168-
lval = self.slice_variable(lvar,lvals,
169-
[[idim,slice] for idim,slice in enumerate(item)])
170-
loc = ",".join([str(i+1) for i in item])
171-
print (f" Comparison failed. Values differ.\n"
172-
f" - input comparison: {expr}\n"
173-
f' - upon slicing, {lname}({loc}) = {lval}\n'
174-
f' - upon slicing, {rname}({loc}) = {rval}')
175-
return False
176-
return True
177-
178-
###########################################################################
179-
def slice_variable(self,var,vals,slices):
180-
###########################################################################
181-
182-
if len(slices)==0:
183-
return vals
184-
185-
idim, slice_idx = slices.pop(-1)
186-
vals = vals.take(int(slice_idx)-1,axis=int(idim))
187-
188-
return self.slice_variable(var,vals,slices)
189-
190152
###########################################################################
191153
def run(self):
192154
###########################################################################

components/eamxx/src/share/io/scorpio_output.cpp

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -346,16 +346,18 @@ void AtmosphereOutput::init()
346346

347347
// Initialize a helper_field for each unique layout. This can be used for operations
348348
// such as writing transposed output.
349-
if (m_helper_fields.find(layout.to_string()) == m_helper_fields.end()) {
350-
// We can add a new helper field for this layout
351-
const auto helper_layout = m_transpose ? layout.transpose() : layout;
352-
const std::string helper_name = "helper_"+helper_layout.to_string();
353-
using namespace ekat::units;
354-
FieldIdentifier fid_helper(helper_name,helper_layout,Units::invalid(),fid.get_grid_name());
355-
Field helper(fid_helper);
356-
helper.get_header().get_alloc_properties().request_allocation();
357-
helper.allocate_view();
358-
m_helper_fields[layout.to_string()] = helper;
349+
if (m_transpose) {
350+
const std::string helper_name = "transposed_"+helper_layout.to_string();
351+
if (m_helper_fields.find(helper_name == m_helper_fields.end()) {
352+
// We can add a new helper field for this layout
353+
const auto helper_layout = m_transpose ? layout.transpose() : layout;
354+
using namespace ekat::units;
355+
FieldIdentifier fid_helper(helper_name,helper_layout,Units::invalid(),fid.get_grid_name());
356+
Field helper(fid_helper);
357+
helper.get_header().get_alloc_properties().request_allocation();
358+
helper.allocate_view();
359+
m_helper_fields[helper_name] = helper;
360+
}
359361
}
360362

361363
// Now check that all the dims of this field are already set to be registered.
@@ -516,12 +518,13 @@ run (const std::string& filename,
516518
auto func_start = std::chrono::steady_clock::now();
517519
if (m_transpose) {
518520
const auto& fl = count.get_header().get_identifier().get_layout().to_string();
519-
auto& temp = m_helper_fields.at(fl);
521+
const std::string helper_name = "transposed_"+fl;
522+
auto& temp = m_helper_fields.at(helper_name);
520523
transpose(count,temp);
521524
scorpio::write_var(filename,count.name(),temp.get_internal_view_data<int,Host>());
522-
} else {
525+
} else {
523526
scorpio::write_var(filename,count.name(),count.get_internal_view_data<int,Host>());
524-
}
527+
}
525528
auto func_finish = std::chrono::steady_clock::now();
526529
auto duration_loc = std::chrono::duration_cast<std::chrono::milliseconds>(func_finish - func_start);
527530
duration_write += duration_loc.count();
@@ -601,7 +604,8 @@ run (const std::string& filename,
601604
auto func_start = std::chrono::steady_clock::now();
602605
if (m_transpose) {
603606
const auto& fl = f_out.get_header().get_identifier().get_layout().to_string();
604-
auto& temp = m_helper_fields.at(fl);
607+
const std::string helper_name = "transposed_"+fl;
608+
auto& temp = m_helper_fields.at(helper_name);
605609
transpose(f_out,temp);
606610
scorpio::write_var(filename,field_name,temp.get_internal_view_data<Real,Host>());
607611
} else {

0 commit comments

Comments
 (0)