Skip to content

Commit 8126469

Browse files
authored
Merge pull request #416 from matthew-frank/matthew-frank/view-rcp-interpolation-also
Add abililty to view interpolation to rcp_viewer.py
2 parents 782506c + 65e69a4 commit 8126469

1 file changed

Lines changed: 55 additions & 6 deletions

File tree

mlperf_logging/rcp_checker/visualization_scripts/rcp_viewer.py

Lines changed: 55 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,41 @@
1313

1414
from mlperf_logging.rcp_checker.rcp_checker import RCP_Checker
1515

16+
def print_rcp_record(record):
17+
print(f"{record['BS']},{record['RCP Mean']},{record['Min Epochs']}")
18+
19+
# this should be a method of rcp_checker.RCP_Checker, but it's missing.
20+
# Instead we derived it from _find_min_rcp()
21+
def find_max_rcp(checker, rcp_pass_arg='pruned_rcps'):
22+
'''Find RCP with the smallest batch size for a benchmark'''
23+
max_bs = -1
24+
max_record = None
25+
rcp_data = checker._get_rcp_data(rcp_pass_arg)
26+
for _, record_contents in rcp_data.items():
27+
if record_contents['BS'] > max_bs:
28+
max_record = record_contents
29+
max_bs = record_contents['BS']
30+
return max_record
31+
32+
# this should be a method of rcp_checker.RCP_Checker, but it's missing.
33+
# Instead we derived it by extracting parts of rcp_checker.check_directory()
34+
def get_rcp_record_for_bs(bs, checker, rcp_pass_arg='pruned_rcps'):
35+
rcp_record = checker._find_rcp(bs, rcp_pass_arg)
36+
if rcp_record is None:
37+
# bs is not one of the generated sizes, so need to interpolate:
38+
rcp_max = checker._find_bottom_max_rcp(bs, rcp_pass_arg)
39+
if rcp_max is None:
40+
raise RuntimeError("Error: no sufficiently large RCP bs found")
41+
rcp_min = checker._find_top_min_rcp(bs, rcp_pass_arg)
42+
if rcp_min is None:
43+
# bs is smaller than the smallest rcp, so just use smallest rcp
44+
rcp_record = checker._find_min_rcp(rcp_pass_arg)
45+
else:
46+
# interpolate
47+
interp_record_name, interp_record = checker._create_interp_rcp(bs, rcp_min, rcp_max)
48+
rcp_record = interp_record
49+
return rcp_record
50+
1651
def main():
1752
parser = argparse.ArgumentParser(
1853
description='Parse rcps_.json file, prune, and print out rcp means and mins'
@@ -27,18 +62,32 @@ def main():
2762
parser.add_argument('--verbose', action='store_true')
2863
parser.add_argument('--unpruned', action='store_true',
2964
help='print the unpruned rcps instead of the pruned')
65+
parser.add_argument('--no-header', action='store_true',
66+
help='do not print the header line')
3067
parser.add_argument('--custom_rcps', type=argparse.FileType('r'),
3168
help='specify an RCP json file to use')
69+
parser.add_argument('--interpolate', action='store_true',
70+
help='generate interpolated rcp min/mean for all batch sizes')
71+
3272

3373
args = parser.parse_args()
34-
checker=RCP_Checker(args.usage, args.version, args.benchmark, args.verbose, args.custom_rcps)
35-
data=checker.pruned_rcp_data
74+
rcp_pass_arg='pruned_rcps'
3675
if (args.unpruned):
37-
data=checker.rcp_data
76+
rcp_pass_arg='full_rcps'
77+
78+
checker=RCP_Checker(args.usage, args.version, args.benchmark, args.verbose, args.custom_rcps)
79+
80+
if not args.no_header:
81+
print("BS,Mean,Min")
3882

39-
print("BS,Mean,Min")
40-
for key, record in data.items():
41-
print(f"{record['BS']},{record['RCP Mean']},{record['Min Epochs']}")
83+
if not args.interpolate:
84+
data=checker._get_rcp_data(rcp_pass_arg)
85+
for key, record in data.items():
86+
print_rcp_record(record)
87+
else:
88+
for bs in range(checker._find_min_rcp(rcp_pass_arg)['BS'], find_max_rcp(checker, rcp_pass_arg)['BS']+1):
89+
record = get_rcp_record_for_bs(bs, checker, rcp_pass_arg)
90+
print_rcp_record(record)
4291

4392
if __name__ == '__main__':
4493
main()

0 commit comments

Comments
 (0)