Skip to content

Commit d7a7302

Browse files
authored
Merge pull request #21 from Xreki/use_all_inputs
Support use_all_inputs when subgraph_range starts at 0.
2 parents d76ca9e + a741abe commit d7a7302

File tree

4 files changed

+84
-33
lines changed

4 files changed

+84
-33
lines changed

athena/generators/graphnet_sequence_sample_generator.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,17 +33,27 @@ class SequenceFuncDesc:
3333

3434

3535
class GraphnetSequenceSampleGenerator:
36-
def __init__(self, program_id, op_example_inputs_meta_getter):
36+
def __init__(self, program_id, program_seq_stmts, op_example_inputs_meta_getter):
3737
self.program_id = program_id
38+
self.program_seq_stmts = program_seq_stmts
3839
self.op_example_inputs_meta_getter = op_example_inputs_meta_getter
3940
self.input_spec_mode = "original"
4041

41-
def Generate(self, seq_stmts):
42-
seq_func_desc = self.MakeSequenceFuncDesc(seq_stmts)
42+
def Generate(self, subgraph_range, use_all_inputs):
43+
assert isinstance(subgraph_range, (tuple, list)) and len(subgraph_range) == 2
44+
seq_stmts = self.program_seq_stmts[subgraph_range[0] : subgraph_range[1]]
45+
seq_func_desc = self.MakeSequenceFuncDesc(
46+
seq_stmts, use_all_inputs and subgraph_range[0] == 0
47+
)
4348
return self._RenderTemplate(seq_func_desc)
4449

45-
def MakeSequenceFuncDesc(self, seq_stmts):
46-
op_id2seq_stmt = OrderedDict((stmt.op_id, stmt) for stmt in seq_stmts)
50+
def MakeSequenceFuncDesc(self, seq_stmts, use_all_inputs):
51+
if use_all_inputs:
52+
op_id2seq_stmt = OrderedDict(
53+
(stmt.op_id, stmt) for stmt in self.program_seq_stmts
54+
)
55+
else:
56+
op_id2seq_stmt = OrderedDict((stmt.op_id, stmt) for stmt in seq_stmts)
4757
ops_func_signature = OpsFuncSignature(
4858
tensor_ids=self.GetTensorIds(op_id2seq_stmt),
4959
operand_ids=self.GetOperandIds(op_id2seq_stmt),

athena/generators/template_graphnet_sequence_sample.jinja

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,11 @@ class GraphModule(paddle.nn.Layer):
120120
{%- endfor -%}
121121
{%- for tensor_id in sig.tensor_ids -%}
122122
{%- if 'parameter' not in sig.tensor_name4tensor_id(tensor_id) %}
123+
{%- set data, dtype = sig.immediate_value4int_array_member_id(tensor_id) -%}
124+
{%- if data is none %}
123125
{{"\t\t"}}{{tensor_name_converter(sig.tensor_name4tensor_id(tensor_id))}}{{","}}
124126
{%- endif %}
127+
{%- endif %}
125128
{%- endfor -%}
126129
):
127130
{%- for tensor_id in sig.tensor_ids -%}

athena/generators/template_op_example_input_meta_script.jinja

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,16 @@ def CalculateTensorMeta(tensor, meta_name):
7777
raise NotImplementedError(f"meta_name: {meta_name}")
7878

7979

80+
def IsInitialized(tensor):
81+
try:
82+
is_initialized = isinstance(tensor, paddle.Tensor) and tensor.numel() > 0
83+
    _ = tensor.shape
84+
    _ = tensor.dtype
85+
return is_initialized
86+
except Exception:
87+
return False
88+
89+
8090
def InitTensorMeta(tensor, meta_name, tensor_meta):
8191
if tensor_meta:
8292
return getattr(tensor_meta, meta_name)
@@ -88,6 +98,8 @@ def InitTensorMeta(tensor, meta_name, tensor_meta):
8898
return [InitTensorMeta(t, meta_name, tensor_meta) for t in tensor]
8999
if not hasattr(tensor, meta_name):
90100
raise NotImplementedError(f"type(tensor): {type(tensor)}, meta_name: {meta_name}")
101+
if not IsInitialized(tensor):
102+
return None
91103
kLimit = 64
92104
if tensor.numel().item() < kLimit:
93105
return None
@@ -109,6 +121,8 @@ def InitTensorData(tensor):
109121
return [InitTensorData(t) for t in tensor]
110122
if not hasattr(tensor, 'numel'):
111123
raise NotImplementedError(f"type(tensor): {type(tensor)}")
124+
if not IsInitialized(tensor):
125+
return None
112126
kLimit = 64
113127
if tensor.numel().item() >= kLimit:
114128
return None

athena/graphnet_samples.py

Lines changed: 52 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from itertools import groupby
1010
from collections import defaultdict
1111
from dataclasses import dataclass
12-
from typing import Dict
12+
from typing import Dict, List
1313

1414
from athena.generators.blocks_generator import BlocksGenerator
1515
from athena.generators.block_name_generator import BlockNameGenerator
@@ -56,12 +56,17 @@
5656
True,
5757
"Whether extend split_positions to include the head and tail of the statement sequence.",
5858
)
59+
flags.DEFINE_boolean(
60+
"use_all_inputs",
61+
False,
62+
"Whether use all inputs of the ir program.",
63+
)
5964
flags.DEFINE_boolean(
6065
"eval_mode",
6166
False,
6267
"Generate graphnet sample for eval, which only keep output tensors with maximum depth (longest chain).",
6368
)
64-
flags.DEFINE_string("tmp_dir", tempfile.gettempdir(), "tmp directory.")
69+
flags.DEFINE_string("tmp_dir", None, "tmp directory.")
6570

6671

6772
@dataclass
@@ -73,10 +78,11 @@ class GraphnetSample:
7378
input_meta: str
7479
weight_meta: str
7580
model: str
81+
subgraph_range: List[int] = None
7682

7783

7884
def ConvertOutputStringToSample(
79-
model_name, unique_name, subgraph_idx, program_id, sample_str
85+
model_name, unique_name, subgraph_idx, program_id, sample_str, subgraph_range=None
8086
):
8187
metadata = {
8288
"framework": "paddle",
@@ -94,6 +100,7 @@ def ConvertOutputStringToSample(
94100
input_meta=input_meta.strip("\n\n\n") + "\n",
95101
weight_meta=weight_meta.rstrip("\n\n\n") + "\n",
96102
model=model,
103+
subgraph_range=subgraph_range,
97104
)
98105
# PrintToTerminal(unique_name, sample_str)
99106
return sample
@@ -154,7 +161,7 @@ def __init__(
154161
example_inputs_file,
155162
op_example_inputs_file,
156163
eval_mode,
157-
tmp_dir,
164+
tmp_dir=None,
158165
):
159166
self.model_name = model_name
160167
self.programs_file = programs_file
@@ -243,50 +250,63 @@ def ExtendHeadAndTail(self, seq_stmts, split_positions, group_head_and_tail):
243250
print(f"split_positions_for_seq_stmts: {split_positions_for_seq_stmts}")
244251
return split_positions_for_seq_stmts
245252

246-
def GetOutputSampleStrings(self, split_positions, group_head_and_tail=True):
253+
def GetOutputSampleStrings(
254+
self, split_positions, group_head_and_tail=True, use_all_inputs=False
255+
):
247256
def MakeSequenceSampleGenerator(
248-
program_id, seq_stmts, op_example_inputs_meta_getter
257+
program_id, program_seq_stmts, op_example_inputs_meta_getter
249258
):
250-
generator = GraphnetSequenceSampleGenerator(
251-
program_id, op_example_inputs_meta_getter
259+
return GraphnetSequenceSampleGenerator(
260+
program_id, program_seq_stmts, op_example_inputs_meta_getter
252261
)
253-
return generator.Generate(seq_stmts)
254262

255263
print(f"origin split_positions: {split_positions}")
256264
generated_sample_strs = set()
257-
for subgraph_idx, (program_id, seq_stmts) in enumerate(
265+
for subgraph_idx, (program_id, program_seq_stmts) in enumerate(
258266
self.program_seq_stmts_list
259267
):
268+
generator = MakeSequenceSampleGenerator(
269+
program_id, program_seq_stmts, self.op_example_inputs_meta_getter
270+
)
260271
split_positions_for_seq_stmts = self.ExtendHeadAndTail(
261-
seq_stmts, split_positions, group_head_and_tail
272+
program_seq_stmts, split_positions, group_head_and_tail
262273
)
263274
for i in range(len(split_positions_for_seq_stmts) - 1):
264-
seq_stmts_slice = seq_stmts[
265-
split_positions_for_seq_stmts[i] : split_positions_for_seq_stmts[
266-
i + 1
267-
]
268-
]
269-
sample_str = MakeSequenceSampleGenerator(
270-
program_id, seq_stmts_slice, self.op_example_inputs_meta_getter
271-
)
275+
subgraph_range = split_positions_for_seq_stmts[i : i + 2]
276+
sample_str = generator.Generate(subgraph_range, use_all_inputs)
272277
if sample_str not in generated_sample_strs:
273278
generated_sample_strs.add(sample_str)
274-
stmt_hash = GetSeqStmtsHash(seq_stmts_slice)
275-
yield (subgraph_idx, program_id, stmt_hash, sample_str)
276-
277-
def __call__(self, split_positions, group_head_and_tail=True):
279+
stmt_hash = GetSeqStmtsHash(
280+
program_seq_stmts[subgraph_range[0] : subgraph_range[1]]
281+
)
282+
yield (
283+
subgraph_idx,
284+
program_id,
285+
stmt_hash,
286+
subgraph_range,
287+
sample_str,
288+
)
289+
290+
def __call__(self, split_positions, group_head_and_tail=True, use_all_inputs=False):
278291
graphnet_sample_results = []
279292
seg_counter = defaultdict(lambda: itertools.count())
280-
for _, (subgraph_idx, program_id, uid, sample_str) in enumerate(
281-
self.GetOutputSampleStrings(split_positions, group_head_and_tail)
293+
for _, (subgraph_idx, program_id, uid, subgraph_range, sample_str) in enumerate(
294+
self.GetOutputSampleStrings(
295+
split_positions, group_head_and_tail, use_all_inputs
296+
)
282297
):
283298
unique_name = f"{uid}_{next(seg_counter[uid])}"
284299
sample = ConvertOutputStringToSample(
285-
self.model_name, unique_name, subgraph_idx, program_id, sample_str
300+
self.model_name,
301+
unique_name,
302+
subgraph_idx,
303+
program_id,
304+
sample_str,
305+
subgraph_range,
286306
)
287307
graphnet_sample_results.append(sample)
288308
print(
289-
f"[SubgraphGenerator] Generate {len(graphnet_sample_results)} graphnet subgraph samples ({split_positions=}, {group_head_and_tail=})."
309+
f"[SubgraphGenerator] Generate {len(graphnet_sample_results)} graphnet subgraph samples ({split_positions=}, {group_head_and_tail=}, {use_all_inputs=})."
290310
)
291311
return graphnet_sample_results
292312

@@ -298,6 +318,7 @@ def RunGeneration(
298318
op_example_inputs,
299319
split_positions,
300320
group_head_and_tail,
321+
use_all_inputs,
301322
eval_mode,
302323
tmp_dir=None,
303324
):
@@ -313,7 +334,9 @@ def RunGeneration(
313334
eval_mode,
314335
tmp_dir,
315336
)
316-
graphnet_sample_results = generator(split_positions, group_head_and_tail)
337+
graphnet_sample_results = generator(
338+
split_positions, group_head_and_tail, use_all_inputs
339+
)
317340
return graphnet_sample_results
318341

319342

@@ -327,6 +350,7 @@ def main(argv):
327350
op_example_inputs=FLAGS.op_example_inputs,
328351
split_positions=split_positions,
329352
group_head_and_tail=FLAGS.group_head_and_tail,
353+
use_all_inputs=FLAGS.use_all_inputs,
330354
eval_mode=FLAGS.eval_mode,
331355
)
332356

0 commit comments

Comments
 (0)