@@ -68,125 +68,87 @@ def compare_variables(self):
6868
6969 success = True
7070 if self ._compare == None or self ._compare == []:
71+ self ._compare = []
7172 # If compare is an empty list, compare all variables
7273 print (f"Specific comparison variables not provided,\n "
7374 f"will compare ALL variables in \n "
7475 f"{ self ._src_file } \n "
7576 f"with\n "
7677 f"{ self ._tgt_file } \n " )
7778 for var in ds_src .variables :
78- lvar = ds_src [var ]
7979 if var not in ds_tgt .variables :
8080 print (f" Comparison failed! Variable not found.\n "
8181 f" - var name: { var } \n "
8282 f" - file name: { self ._tgt_file } " )
8383 success = False
8484 continue
85- if self ._allow_transpose :
86- rvar = ds_tgt [var ].transpose (* lvar .dims )
87- else :
88- rvar = ds_tgt [var ]
89- expect (lvar .dims == rvar .dims ,
90- f" Error!, variables names match, but dimensions do not.\n "
91- f" - var name: { var } \n "
92- f" - src dimensions: { lvar .dims } \n "
93- f" - tgt dimensions: { rvar .dims } \n " )
94- success = self .are_equal (lvar .data ,rvar .data )
95-
96- else :
97- for expr in self ._compare :
98- # Split the expression, to get the output var name
99- tokens = expr .split ('=' )
100- expect (len (tokens )== 2 ,"Error! Compare variables with 'lhs=rhs' syntax." )
101-
102- lhs = tokens [0 ]
103- rhs = tokens [1 ]
104-
105- lname , ldims = self .get_name_and_dims (lhs )
106- rname , rdims = self .get_name_and_dims (rhs )
107-
108- if lname not in ds_src .variables :
109- print (f" Comparison failed! Variable not found.\n "
110- f" - var name: { lname } \n "
111- f" - file name: { self ._src_file } " )
112- success = False
113- continue
114- if rname not in ds_tgt .variables :
115- print (f" Comparison failed! Variable not found.\n "
116- f" - var name: { rname } \n "
117- f" - file name: { self ._tgt_file } " )
118- success = False
119- continue
120- lvar = ds_src .variables [lname ];
121- rvar = ds_tgt .variables [rname ];
122-
123- lvar_rank = len (lvar .dims )
124- rvar_rank = len (rvar .dims )
125-
126- expect (len (ldims )== 0 or len (ldims )== lvar_rank ,
127- f"Invalid slice specification for { lname } .\n "
128- f" input request: ({ ',' .join (ldims )} )\n "
129- f" variable rank: { lvar_rank } " )
130- expect (len (rdims )== 0 or len (rdims )== rvar_rank ,
131- f"Invalid slice specification for { rname } .\n "
132- f" input request: ({ ',' .join (rdims )} )\n "
133- f" variable rank: { rvar_rank } " )
134-
135-
136- lslices = [[idim ,slice ] for idim ,slice in enumerate (ldims ) if slice != ":" ]
137- rslices = [[idim ,slice ] for idim ,slice in enumerate (rdims ) if slice != ":" ]
138-
139- lrank = lvar_rank - len (lslices )
140- rrank = rvar_rank - len (rslices )
141-
142- if lrank != rrank :
143- print (f" Comparison failed. Rank mismatch.\n "
144- f" - input comparison: { expr } \n "
145- f" - upon slicing, rank({ lname } ) = { lrank } \n "
146- f" - upon slicing, rank({ rname } ) = { rrank } " )
147- success = False
148- continue
149-
150- lvals = self .slice_variable (lvar ,lvar .data [:],lslices )
151- rvals = self .slice_variable (rvar ,rvar .data [:],rslices )
152-
153- success = self .are_equal (lvals ,rvals )
154-
85+ self ._compare .append (var + "=" + var )
86+
87+ for expr in self ._compare :
88+ # Split the expression, to get the output var name
89+ tokens = expr .split ('=' )
90+ expect (len (tokens )== 2 ,"Error! Compare variables with 'lhs=rhs' syntax." )
91+
92+ lhs = tokens [0 ]
93+ rhs = tokens [1 ]
94+
95+ lname , ldims = self .get_name_and_dims (lhs )
96+ rname , rdims = self .get_name_and_dims (rhs )
97+
98+ if lname not in ds_src .variables :
99+ print (f" Comparison failed! Variable not found.\n "
100+ f" - var name: { lname } \n "
101+ f" - file name: { self ._src_file } " )
102+ success = False
103+ continue
104+ if rname not in ds_tgt .variables :
105+ print (f" Comparison failed! Variable not found.\n "
106+ f" - var name: { rname } \n "
107+ f" - file name: { self ._tgt_file } " )
108+ success = False
109+ continue
110+ lvar = ds_src [lname ];
111+ rvar = ds_tgt [rname ];
112+
113+ lvar_rank = len (lvar .dims )
114+ rvar_rank = len (rvar .dims )
115+
116+ expect (len (ldims )== 0 or len (ldims )== lvar_rank ,
117+ f"Invalid slice specification for { lname } .\n "
118+ f" input request: ({ ',' .join (ldims )} )\n "
119+ f" variable rank: { lvar_rank } " )
120+ expect (len (rdims )== 0 or len (rdims )== rvar_rank ,
121+ f"Invalid slice specification for { rname } .\n "
122+ f" input request: ({ ',' .join (rdims )} )\n "
123+ f" variable rank: { rvar_rank } " )
124+
125+ lslices = {lvar .dims [idim ]:int (slice )- 1 for idim ,slice in enumerate (ldims ) if slice != ":" }
126+ rslices = {rvar .dims [idim ]:int (slice )- 1 for idim ,slice in enumerate (rdims ) if slice != ":" }
127+ lvar_sliced = lvar .sel (lslices )
128+ rvar_sliced = rvar .sel (rslices )
129+ expect (set (lvar_sliced .dims ) == set (rvar_sliced .dims ),
130+ f"Error, even when sliced these two elements do not share the same dimensionsn\n "
131+ f" - left var name : { lname } \n "
132+ f" - right var name : { rname } \n "
133+ f" - left dimensions : { lvar_sliced .dims } \n "
134+ f" - right dimensions: { rvar_sliced .dims } \n " )
135+
136+ if self ._allow_transpose :
137+ rvar_sliced = rvar_sliced .transpose (* lvar_sliced .dims )
138+
139+ equal = (lvar_sliced .data == rvar_sliced .data ).all ()
140+ if not equal :
141+ rse = np .sqrt ((lvar_sliced .data - rvar_sliced .data )** 2 )
142+ nonmatch_count = np .count_nonzero (rse )
143+ print (f" Comparison failed. Values differ at { nonmatch_count } out of { rse .size } locations.\n "
144+ f" - input comparison: { expr } \n "
145+ f' - max L2 error, { rse .max ()} \n '
146+ f' - max L2 location, [{ "," .join (map (str ,(np .array (np .unravel_index (rse .argmax (),rse .shape ))+ 1 ).tolist ()))} ]\n '
147+ f' - dimensions, { lvar_sliced .dims } ' )
148+ success = False
155149
156150 return success
157151
158- ###########################################################################
159- def are_equal (self ,lvals ,rvals ):
160- ###########################################################################
161-
162- if not np .array_equal (lvals ,rvals ):
163- # print (f"lvals: {lvals}")
164- # print (f"rvals: {rvals}")
165- item = np .argwhere (lvals != rvals )[0 ]
166- rval = self .slice_variable (rvar ,rvals ,
167- [[idim ,slice ] for idim ,slice in enumerate (item )])
168- lval = self .slice_variable (lvar ,lvals ,
169- [[idim ,slice ] for idim ,slice in enumerate (item )])
170- loc = "," .join ([str (i + 1 ) for i in item ])
171- print (f" Comparison failed. Values differ.\n "
172- f" - input comparison: { expr } \n "
173- f' - upon slicing, { lname } ({ loc } ) = { lval } \n '
174- f' - upon slicing, { rname } ({ loc } ) = { rval } ' )
175- return False
176- return True
177-
178- ###########################################################################
179- def slice_variable (self ,var ,vals ,slices ):
180- ###########################################################################
181-
182- if len (slices )== 0 :
183- return vals
184-
185- idim , slice_idx = slices .pop (- 1 )
186- vals = vals .take (int (slice_idx )- 1 ,axis = int (idim ))
187-
188- return self .slice_variable (var ,vals ,slices )
189-
190152 ###########################################################################
191153 def run (self ):
192154 ###########################################################################
0 commit comments