99from itertools import groupby
1010from collections import defaultdict
1111from dataclasses import dataclass
12- from typing import Dict
12+ from typing import Dict , List
1313
1414from athena .generators .blocks_generator import BlocksGenerator
1515from athena .generators .block_name_generator import BlockNameGenerator
5656 True ,
5757 "Whether extend split_positions to include the head and tail of the statement sequence." ,
5858)
59+ flags .DEFINE_boolean (
60+ "use_all_inputs" ,
61+ False ,
62+ "Whether use all inputs of the ir program." ,
63+ )
5964flags .DEFINE_boolean (
6065 "eval_mode" ,
6166 False ,
6267 "Generate graphnet sample for eval, which only keep output tensors with maximum depth (longest chain)." ,
6368)
64- flags .DEFINE_string ("tmp_dir" , tempfile . gettempdir () , "tmp directory." )
69+ flags .DEFINE_string ("tmp_dir" , None , "tmp directory." )
6570
6671
6772@dataclass
@@ -73,10 +78,11 @@ class GraphnetSample:
7378 input_meta : str
7479 weight_meta : str
7580 model : str
81+ subgraph_range : List [int ] = None
7682
7783
7884def ConvertOutputStringToSample (
79- model_name , unique_name , subgraph_idx , program_id , sample_str
85+ model_name , unique_name , subgraph_idx , program_id , sample_str , subgraph_range = None
8086):
8187 metadata = {
8288 "framework" : "paddle" ,
@@ -94,6 +100,7 @@ def ConvertOutputStringToSample(
94100 input_meta = input_meta .strip ("\n \n \n " ) + "\n " ,
95101 weight_meta = weight_meta .rstrip ("\n \n \n " ) + "\n " ,
96102 model = model ,
103+ subgraph_range = subgraph_range ,
97104 )
98105 # PrintToTerminal(unique_name, sample_str)
99106 return sample
@@ -154,7 +161,7 @@ def __init__(
154161 example_inputs_file ,
155162 op_example_inputs_file ,
156163 eval_mode ,
157- tmp_dir ,
164+ tmp_dir = None ,
158165 ):
159166 self .model_name = model_name
160167 self .programs_file = programs_file
@@ -243,50 +250,63 @@ def ExtendHeadAndTail(self, seq_stmts, split_positions, group_head_and_tail):
243250 print (f"split_positions_for_seq_stmts: { split_positions_for_seq_stmts } " )
244251 return split_positions_for_seq_stmts
245252
246- def GetOutputSampleStrings (self , split_positions , group_head_and_tail = True ):
253+ def GetOutputSampleStrings (
254+ self , split_positions , group_head_and_tail = True , use_all_inputs = False
255+ ):
247256 def MakeSequenceSampleGenerator (
248- program_id , seq_stmts , op_example_inputs_meta_getter
257+ program_id , program_seq_stmts , op_example_inputs_meta_getter
249258 ):
250- generator = GraphnetSequenceSampleGenerator (
251- program_id , op_example_inputs_meta_getter
259+ return GraphnetSequenceSampleGenerator (
260+ program_id , program_seq_stmts , op_example_inputs_meta_getter
252261 )
253- return generator .Generate (seq_stmts )
254262
255263 print (f"origin split_positions: { split_positions } " )
256264 generated_sample_strs = set ()
257- for subgraph_idx , (program_id , seq_stmts ) in enumerate (
265+ for subgraph_idx , (program_id , program_seq_stmts ) in enumerate (
258266 self .program_seq_stmts_list
259267 ):
268+ generator = MakeSequenceSampleGenerator (
269+ program_id , program_seq_stmts , self .op_example_inputs_meta_getter
270+ )
260271 split_positions_for_seq_stmts = self .ExtendHeadAndTail (
261- seq_stmts , split_positions , group_head_and_tail
272+ program_seq_stmts , split_positions , group_head_and_tail
262273 )
263274 for i in range (len (split_positions_for_seq_stmts ) - 1 ):
264- seq_stmts_slice = seq_stmts [
265- split_positions_for_seq_stmts [i ] : split_positions_for_seq_stmts [
266- i + 1
267- ]
268- ]
269- sample_str = MakeSequenceSampleGenerator (
270- program_id , seq_stmts_slice , self .op_example_inputs_meta_getter
271- )
275+ subgraph_range = split_positions_for_seq_stmts [i : i + 2 ]
276+ sample_str = generator .Generate (subgraph_range , use_all_inputs )
272277 if sample_str not in generated_sample_strs :
273278 generated_sample_strs .add (sample_str )
274- stmt_hash = GetSeqStmtsHash (seq_stmts_slice )
275- yield (subgraph_idx , program_id , stmt_hash , sample_str )
276-
277- def __call__ (self , split_positions , group_head_and_tail = True ):
279+ stmt_hash = GetSeqStmtsHash (
280+ program_seq_stmts [subgraph_range [0 ] : subgraph_range [1 ]]
281+ )
282+ yield (
283+ subgraph_idx ,
284+ program_id ,
285+ stmt_hash ,
286+ subgraph_range ,
287+ sample_str ,
288+ )
289+
290+ def __call__ (self , split_positions , group_head_and_tail = True , use_all_inputs = False ):
278291 graphnet_sample_results = []
279292 seg_counter = defaultdict (lambda : itertools .count ())
280- for _ , (subgraph_idx , program_id , uid , sample_str ) in enumerate (
281- self .GetOutputSampleStrings (split_positions , group_head_and_tail )
293+ for _ , (subgraph_idx , program_id , uid , subgraph_range , sample_str ) in enumerate (
294+ self .GetOutputSampleStrings (
295+ split_positions , group_head_and_tail , use_all_inputs
296+ )
282297 ):
283298 unique_name = f"{ uid } _{ next (seg_counter [uid ])} "
284299 sample = ConvertOutputStringToSample (
285- self .model_name , unique_name , subgraph_idx , program_id , sample_str
300+ self .model_name ,
301+ unique_name ,
302+ subgraph_idx ,
303+ program_id ,
304+ sample_str ,
305+ subgraph_range ,
286306 )
287307 graphnet_sample_results .append (sample )
288308 print (
289- f"[SubgraphGenerator] Generate { len (graphnet_sample_results )} graphnet subgraph samples ({ split_positions = } , { group_head_and_tail = } )."
309+ f"[SubgraphGenerator] Generate { len (graphnet_sample_results )} graphnet subgraph samples ({ split_positions = } , { group_head_and_tail = } , { use_all_inputs = } )."
290310 )
291311 return graphnet_sample_results
292312
@@ -298,6 +318,7 @@ def RunGeneration(
298318 op_example_inputs ,
299319 split_positions ,
300320 group_head_and_tail ,
321+ use_all_inputs ,
301322 eval_mode ,
302323 tmp_dir = None ,
303324):
@@ -313,7 +334,9 @@ def RunGeneration(
313334 eval_mode ,
314335 tmp_dir ,
315336 )
316- graphnet_sample_results = generator (split_positions , group_head_and_tail )
337+ graphnet_sample_results = generator (
338+ split_positions , group_head_and_tail , use_all_inputs
339+ )
317340 return graphnet_sample_results
318341
319342
@@ -327,6 +350,7 @@ def main(argv):
327350 op_example_inputs = FLAGS .op_example_inputs ,
328351 split_positions = split_positions ,
329352 group_head_and_tail = FLAGS .group_head_and_tail ,
353+ use_all_inputs = FLAGS .use_all_inputs ,
330354 eval_mode = FLAGS .eval_mode ,
331355 )
332356
0 commit comments