Skip to content

Commit da4ac69

Browse files
committed
Add numeric check
1 parent e9fd3fc commit da4ac69

File tree

1 file changed

+34
-12
lines changed
  • benchmarks/numeric_check

1 file changed

+34
-12
lines changed

benchmarks/numeric_check/run.py

+34-12
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,66 @@
11
import argparse
2-
import pickle
32
import os
4-
import torch
3+
import pickle
54

65
from typing import Any
7-
from pathlib import Path
86

9-
def find_pkl_files(path: Path):
10-
abs_path = path.absolute()
11-
return [f for f in os.listdir(abs_path) if os.path.isfile(os.path.join(abs_path, f)) and f.endswith(".pkl")]
7+
import torch
8+
9+
10+
def find_pkl_files(path: str):
11+
abs_path = os.path.abspath(path)
12+
return [
13+
f
14+
for f in os.listdir(abs_path)
15+
if os.path.isfile(os.path.join(abs_path, f)) and f.endswith(".pkl")
16+
]
17+
1218

1319
def common_pkl_files(pkl_files_a, pkl_files_b):
1420
set_a = set(pkl_files_a)
1521
set_b = set(pkl_files_b)
1622
return list(set_a.intersection(set_b)), list(set_a - set_b), list(set_b - set_a)
1723

24+
1825
def check_tensor_numeric(a, b):
19-
assert isinstance(a, tuple), "Out A must be a tuple."
20-
assert isinstance(b, tuple), "Out B must be a tuple."
21-
assert len(a) == len(b), f"A and B must be equal length, but len_a={len(a)}, len_b={len(b)}"
26+
if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
27+
torch.testing.assert_close(a, b)
28+
return
29+
assert isinstance(a, list) or isinstance(
30+
a, tuple
31+
), f"Out A must be a tuple or list, get type {type(a)}."
32+
assert isinstance(b, list) or isinstance(
33+
b, tuple
34+
), f"Out B must be a tuple or list, get type {type(b)}."
35+
assert len(a) == len(
36+
b
37+
), f"A and B must be equal length, but len_a={len(a)}, len_b={len(b)}"
2238
for i in range(len(a)):
2339
tensor_a = a[i]
2440
tensor_b = b[i]
2541
torch.testing.assert_close(tensor_a, tensor_b)
2642

43+
2744
def load_data_from_pickle_file(pickle_file_path) -> Any:
2845
with open(pickle_file_path, "rb") as pfp:
2946
data = pickle.load(pfp)
3047
return data
3148

49+
3250
if __name__ == "__main__":
3351
parser = argparse.ArgumentParser()
3452
parser.add_argument("--a", help="Side A of the output.")
3553
parser.add_argument("--b", help="Side B of the output.")
3654
args = parser.parse_args()
3755
pkl_files_a = find_pkl_files(args.a)
3856
pkl_files_b = find_pkl_files(args.b)
39-
common_files, a_only_files, b_only_files = common_pkl_files(pkl_files_a, pkl_files_b)
40-
for common_file in common_files:
57+
common_files, a_only_files, b_only_files = common_pkl_files(
58+
pkl_files_a, pkl_files_b
59+
)
60+
for common_file in sorted(common_files):
61+
print(f"checking {common_file} ...", end="")
4162
data_a = load_data_from_pickle_file(os.path.join(args.a, common_file))
4263
data_b = load_data_from_pickle_file(os.path.join(args.b, common_file))
4364
check_tensor_numeric(data_a, data_b)
44-
print("A and B sides numeric match.")
65+
print("OK")
66+
print("A and B numerically match.")

0 commit comments

Comments
 (0)