|
1 | 1 | import argparse
|
2 |
| -import pickle |
3 | 2 | import os
|
4 |
| -import torch |
| 3 | +import pickle |
5 | 4 |
|
6 | 5 | from typing import Any
|
7 |
| -from pathlib import Path |
8 | 6 |
|
9 |
| -def find_pkl_files(path: Path): |
10 |
| - abs_path = path.absolute() |
11 |
| - return [f for f in os.listdir(abs_path) if os.path.isfile(os.path.join(abs_path, f)) and f.endswith(".pkl")] |
| 7 | +import torch |
| 8 | + |
| 9 | + |
| 10 | +def find_pkl_files(path: str): |
| 11 | + abs_path = os.path.abspath(path) |
| 12 | + return [ |
| 13 | + f |
| 14 | + for f in os.listdir(abs_path) |
| 15 | + if os.path.isfile(os.path.join(abs_path, f)) and f.endswith(".pkl") |
| 16 | + ] |
| 17 | + |
12 | 18 |
|
13 | 19 | def common_pkl_files(pkl_files_a, pkl_files_b):
|
14 | 20 | set_a = set(pkl_files_a)
|
15 | 21 | set_b = set(pkl_files_b)
|
16 | 22 | return list(set_a.intersection(set_b)), list(set_a - set_b), list(set_b - set_a)
|
17 | 23 |
|
| 24 | + |
18 | 25 | def check_tensor_numeric(a, b):
|
19 |
| - assert isinstance(a, tuple), "Out A must be a tuple." |
20 |
| - assert isinstance(b, tuple), "Out B must be a tuple." |
21 |
| - assert len(a) == len(b), f"A and B must be equal length, but len_a={len(a)}, len_b={len(b)}" |
| 26 | + if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor): |
| 27 | + torch.testing.assert_close(a, b) |
| 28 | + return |
| 29 | + assert isinstance(a, list) or isinstance( |
| 30 | + a, tuple |
| 31 | + ), f"Out A must be a tuple or list, get type {type(a)}." |
| 32 | + assert isinstance(b, list) or isinstance( |
| 33 | + b, tuple |
| 34 | + ), f"Out B must be a tuple or list, get type {type(b)}." |
| 35 | + assert len(a) == len( |
| 36 | + b |
| 37 | + ), f"A and B must be equal length, but len_a={len(a)}, len_b={len(b)}" |
22 | 38 | for i in range(len(a)):
|
23 | 39 | tensor_a = a[i]
|
24 | 40 | tensor_b = b[i]
|
25 | 41 | torch.testing.assert_close(tensor_a, tensor_b)
|
26 | 42 |
|
| 43 | + |
27 | 44 | def load_data_from_pickle_file(pickle_file_path) -> Any:
|
28 | 45 | with open(pickle_file_path, "rb") as pfp:
|
29 | 46 | data = pickle.load(pfp)
|
30 | 47 | return data
|
31 | 48 |
|
| 49 | + |
32 | 50 | if __name__ == "__main__":
|
33 | 51 | parser = argparse.ArgumentParser()
|
34 | 52 | parser.add_argument("--a", help="Side A of the output.")
|
35 | 53 | parser.add_argument("--b", help="Side B of the output.")
|
36 | 54 | args = parser.parse_args()
|
37 | 55 | pkl_files_a = find_pkl_files(args.a)
|
38 | 56 | pkl_files_b = find_pkl_files(args.b)
|
39 |
| - common_files, a_only_files, b_only_files = common_pkl_files(pkl_files_a, pkl_files_b) |
40 |
| - for common_file in common_files: |
| 57 | + common_files, a_only_files, b_only_files = common_pkl_files( |
| 58 | + pkl_files_a, pkl_files_b |
| 59 | + ) |
| 60 | + for common_file in sorted(common_files): |
| 61 | + print(f"checking {common_file} ...", end="") |
41 | 62 | data_a = load_data_from_pickle_file(os.path.join(args.a, common_file))
|
42 | 63 | data_b = load_data_from_pickle_file(os.path.join(args.b, common_file))
|
43 | 64 | check_tensor_numeric(data_a, data_b)
|
44 |
| - print("A and B sides numeric match.") |
| 65 | + print("OK") |
| 66 | + print("A and B numerically match.") |
0 commit comments