Skip to content

Commit 4a75c0b

Browse files
xuzhao9facebook-github-bot
authored andcommitted
Add numeric check tools (#170)
Summary: Add a small tool to check the numeric of two exported runs Pull Request resolved: #170 Reviewed By: adamomainz Differential Revision: D70496932 Pulled By: xuzhao9 fbshipit-source-id: 2013d90d0866005756f5d0022906aca767728fe5
1 parent ebec32e commit 4a75c0b

File tree

1 file changed

+66
-0
lines changed
  • benchmarks/numeric_check

1 file changed

+66
-0
lines changed

benchmarks/numeric_check/run.py

+66
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import argparse
2+
import os
3+
import pickle
4+
5+
from typing import Any
6+
7+
import torch
8+
9+
10+
def find_pkl_files(path: str):
11+
abs_path = os.path.abspath(path)
12+
return [
13+
f
14+
for f in os.listdir(abs_path)
15+
if os.path.isfile(os.path.join(abs_path, f)) and f.endswith(".pkl")
16+
]
17+
18+
19+
def common_pkl_files(pkl_files_a, pkl_files_b):
20+
set_a = set(pkl_files_a)
21+
set_b = set(pkl_files_b)
22+
return list(set_a.intersection(set_b)), list(set_a - set_b), list(set_b - set_a)
23+
24+
25+
def check_tensor_numeric(a, b):
26+
if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
27+
torch.testing.assert_close(a, b)
28+
return
29+
assert isinstance(a, list) or isinstance(
30+
a, tuple
31+
), f"Out A must be a tuple or list, get type {type(a)}."
32+
assert isinstance(b, list) or isinstance(
33+
b, tuple
34+
), f"Out B must be a tuple or list, get type {type(b)}."
35+
assert len(a) == len(
36+
b
37+
), f"A and B must be equal length, but len_a={len(a)}, len_b={len(b)}"
38+
for i in range(len(a)):
39+
tensor_a = a[i]
40+
tensor_b = b[i]
41+
torch.testing.assert_close(tensor_a, tensor_b)
42+
43+
44+
def load_data_from_pickle_file(pickle_file_path) -> Any:
45+
with open(pickle_file_path, "rb") as pfp:
46+
data = pickle.load(pfp)
47+
return data
48+
49+
50+
if __name__ == "__main__":
51+
parser = argparse.ArgumentParser()
52+
parser.add_argument("--a", help="Side A of the output.")
53+
parser.add_argument("--b", help="Side B of the output.")
54+
args = parser.parse_args()
55+
pkl_files_a = find_pkl_files(args.a)
56+
pkl_files_b = find_pkl_files(args.b)
57+
common_files, a_only_files, b_only_files = common_pkl_files(
58+
pkl_files_a, pkl_files_b
59+
)
60+
for common_file in sorted(common_files):
61+
print(f"checking {common_file} ...", end="")
62+
data_a = load_data_from_pickle_file(os.path.join(args.a, common_file))
63+
data_b = load_data_from_pickle_file(os.path.join(args.b, common_file))
64+
check_tensor_numeric(data_a, data_b)
65+
print("OK")
66+
print("A and B numerically match.")

0 commit comments

Comments
 (0)