1313
1414from mlperf_logging .rcp_checker .rcp_checker import RCP_Checker
1515
16+ def print_rcp_record (record ):
17+ print (f"{ record ['BS' ]} ,{ record ['RCP Mean' ]} ,{ record ['Min Epochs' ]} " )
18+
19+ # this should be a method of rcp_checker.RCP_Checker, but it's missing.
20+ # Instead we derived it from _find_min_rcp()
21+ def find_max_rcp (checker , rcp_pass_arg = 'pruned_rcps' ):
22+ '''Find RCP with the smallest batch size for a benchmark'''
23+ max_bs = - 1
24+ max_record = None
25+ rcp_data = checker ._get_rcp_data (rcp_pass_arg )
26+ for _ , record_contents in rcp_data .items ():
27+ if record_contents ['BS' ] > max_bs :
28+ max_record = record_contents
29+ max_bs = record_contents ['BS' ]
30+ return max_record
31+
32+ # this should be a method of rcp_checker.RCP_Checker, but it's missing.
33+ # Instead we derived it by extracting parts of rcp_checker.check_directory()
34+ def get_rcp_record_for_bs (bs , checker , rcp_pass_arg = 'pruned_rcps' ):
35+ rcp_record = checker ._find_rcp (bs , rcp_pass_arg )
36+ if rcp_record is None :
37+ # bs is not one of the generated sizes, so need to interpolate:
38+ rcp_max = checker ._find_bottom_max_rcp (bs , rcp_pass_arg )
39+ if rcp_max is None :
40+ raise RuntimeError ("Error: no sufficiently large RCP bs found" )
41+ rcp_min = checker ._find_top_min_rcp (bs , rcp_pass_arg )
42+ if rcp_min is None :
43+ # bs is smaller than the smallest rcp, so just use smallest rcp
44+ rcp_record = checker ._find_min_rcp (rcp_pass_arg )
45+ else :
46+ # interpolate
47+ interp_record_name , interp_record = checker ._create_interp_rcp (bs , rcp_min , rcp_max )
48+ rcp_record = interp_record
49+ return rcp_record
50+
1651def main ():
1752 parser = argparse .ArgumentParser (
1853 description = 'Parse rcps_.json file, prune, and print out rcp means and mins'
@@ -27,18 +62,32 @@ def main():
2762 parser .add_argument ('--verbose' , action = 'store_true' )
2863 parser .add_argument ('--unpruned' , action = 'store_true' ,
2964 help = 'print the unpruned rcps instead of the pruned' )
65+ parser .add_argument ('--no-header' , action = 'store_true' ,
66+ help = 'do not print the header line' )
3067 parser .add_argument ('--custom_rcps' , type = argparse .FileType ('r' ),
3168 help = 'specify an RCP json file to use' )
69+ parser .add_argument ('--interpolate' , action = 'store_true' ,
70+ help = 'generate interpolated rcp min/mean for all batch sizes' )
71+
3272
3373 args = parser .parse_args ()
34- checker = RCP_Checker (args .usage , args .version , args .benchmark , args .verbose , args .custom_rcps )
35- data = checker .pruned_rcp_data
74+ rcp_pass_arg = 'pruned_rcps'
3675 if (args .unpruned ):
37- data = checker .rcp_data
76+ rcp_pass_arg = 'full_rcps'
77+
78+ checker = RCP_Checker (args .usage , args .version , args .benchmark , args .verbose , args .custom_rcps )
79+
80+ if not args .no_header :
81+ print ("BS,Mean,Min" )
3882
39- print ("BS,Mean,Min" )
40- for key , record in data .items ():
41- print (f"{ record ['BS' ]} ,{ record ['RCP Mean' ]} ,{ record ['Min Epochs' ]} " )
83+ if not args .interpolate :
84+ data = checker ._get_rcp_data (rcp_pass_arg )
85+ for key , record in data .items ():
86+ print_rcp_record (record )
87+ else :
88+ for bs in range (checker ._find_min_rcp (rcp_pass_arg )['BS' ], find_max_rcp (checker , rcp_pass_arg )['BS' ]+ 1 ):
89+ record = get_rcp_record_for_bs (bs , checker , rcp_pass_arg )
90+ print_rcp_record (record )
4291
4392if __name__ == '__main__' :
4493 main ()
0 commit comments