from collections import namedtuple
from contextlib import contextmanager
from copy import deepcopy
+from enum import Enum
from functools import reduce
from typing import TYPE_CHECKING, Any, Callable, Tuple, Union

from ...symbolic_shape.operators import SYMBOLIC_BINARY_OPS, SYMBOLIC_UNARY_OPS
from ...utils import (
    ENV_SOT_ALLOW_DYNAMIC_SHAPE,
+    NUMPY_API_SUPPORTED_DICT,
    NameGenerator,
    SIRToCodeMap,
    SotUndefinedVar,

    GlobalVariable,
    ListVariable,
    NullVariable,
+    NumpyArrayVariable,
    PaddleLayerVariable,
    ParameterVariable,
    SymbolicVariable,

if TYPE_CHECKING:
    import types

+    GraphNodeVariableType: TypeAlias = Union[
+        TensorVariable, SymbolicVariable, NumpyArrayVariable
+    ]
+

    CompileGraphResult: TypeAlias = Tuple[
        Callable[..., Any],
            OrderedSet[Union[TensorVariable, SymbolicVariable]],
        ],
    ]
+GraphNodeVariableClasses = (
+    TensorVariable,
+    SymbolicVariable,
+    NumpyArrayVariable,
+)


def convert_to_meta(inputs: Any):
@@ -116,7 +128,7 @@ def convert_to_meta(inputs: Any):
    """

    def func(x):
-        if isinstance(x, (TensorVariable, SymbolicVariable)):
+        if isinstance(x, GraphNodeVariableClasses):
            return x.meta
        if isinstance(x, VariableBase):
            return x.get_py_value()
@@ -131,7 +143,7 @@ def convert_to_symbol(inputs: Any):
    """

    def func(x):
-        if isinstance(x, (TensorVariable, SymbolicVariable)):
+        if isinstance(x, GraphNodeVariableClasses):
            return x.get_symbol()
        if isinstance(x, VariableBase):
            return x.get_py_value()
@@ -155,7 +167,7 @@ def record_symbols(SIR, *args, **kwargs):
    non_params = set()

    def fn(value):
-        if isinstance(value, (TensorVariable, SymbolicVariable)):
+        if isinstance(value, GraphNodeVariableClasses):
            symbol_meta_map[value.get_symbol()] = value.meta
        if isinstance(value, ParameterVariable):
            params.add(value.get_symbol())
@@ -190,6 +202,12 @@ def func(x):
    return map_variables(func, inputs, restore_variable=True)


+class APIType(Enum):
+    PADDLE = 0
+    SYMBOLIC = 1
+    NUMPY = 2
+
+
class VariableLoader:
    def __init__(self, store_var_info, pycode_gen):
        self._store_var_info = store_var_info
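
The `APIType` enum above replaces the boolean `is_symbolic_var` flag that `symbolic_call` previously took, so each call site now names the API family it records, and the numpy path gets its own member. A minimal sketch of the intent; `describe` is a hypothetical helper, not part of the patch:

def describe(api_type: APIType) -> str:
    # Names the three recording paths this patch distinguishes.
    return {
        APIType.PADDLE: "paddle API / layer / method call",
        APIType.SYMBOLIC: "symbolic shape operator (SYMBOLIC_UNARY_OPS / SYMBOLIC_BINARY_OPS)",
        APIType.NUMPY: "numpy API registered in NUMPY_API_SUPPORTED_DICT",
    }[api_type]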
@@ -541,7 +559,34 @@ def message_handler(*args, **kwargs):
            InferMetaCache(),
            self.sir_builder.call_API,
            func,
-            False,
+            APIType.PADDLE,
+            *args,
+            **kwargs,
+        )
+
+    def call_numpy_api(
+        self,
+        func: Callable[..., Any],
+        *args: VariableBase,
+        **kwargs: VariableBase,
+    ):
+        """
+        Record Numpy API to SIR
+
+        Args:
+            func: numpy api
+        """
+        assert func in NUMPY_API_SUPPORTED_DICT.values()
+        log(3, f"call numpy.api : {func.__name__}", "\n")
+
+        def message_handler(*args, **kwargs):
+            return f"Call numpy api error: {func.__name__}, may be not a operator api?"
+
+        return inner_error_default_handler(self.symbolic_call, message_handler)(
+            InferMetaCache(),
+            self.sir_builder.call_API,
+            func,
+            APIType.NUMPY,
            *args,
            **kwargs,
        )
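
`call_numpy_api` mirrors `call_paddle_api` but only accepts callables registered in `NUMPY_API_SUPPORTED_DICT` and tags the recorded statement with `APIType.NUMPY`. A minimal sketch of that guard, assuming the dict maps names to numpy callables (its real contents live in `...utils` and are not shown in this diff):

import numpy as np

SUPPORTED = {"add": np.add, "abs": np.abs}  # hypothetical stand-in for NUMPY_API_SUPPORTED_DICT

def check_supported(func):
    # Same shape as the assertion inside call_numpy_api.
    assert func in SUPPORTED.values(), f"{func.__name__} is not a registered numpy api"

check_supported(np.add)  # passes; an unregistered callable would raise AssertionError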
@@ -562,7 +607,7 @@ def message_handler(*args, **kwargs):
            InferMetaCache(),
            self.sir_builder.call_API,
            op,
-            True,
+            APIType.SYMBOLIC,
            *args,
            **kwargs,
        )
@@ -584,7 +629,7 @@ def message_handler(*args, **kwargs):
            InferMetaCache(),
            self.sir_builder.call_METHOD,
            method_name,
-            False,
+            APIType.PADDLE,
            *args,
            **kwargs,
        )
@@ -619,7 +664,7 @@ def message_handler(*args, **kwargs):
            return f"Call paddle layer error: {layer}, may be not a valid paddle layer?"

        return inner_error_default_handler(self.symbolic_call, message_handler)(
-            infer_meta_fn, compute_fn, layer, False, *args, **kwargs
+            infer_meta_fn, compute_fn, layer, APIType.PADDLE, *args, **kwargs
        )

    def call_ast(
@@ -653,7 +698,7 @@ def message_handler(*args, **kwargs):
            ast_infer_meta,
            compute_fn,
            static_function,
-            False,
+            APIType.PADDLE,
            *args,
            **kwargs,
        )
@@ -662,7 +707,7 @@
        return None

    def symbolic_call(
-        self, infer_meta_fn, compute_fn, func, is_symbolic_var, *args, **kwargs
+        self, infer_meta_fn, compute_fn, func, api_type, *args, **kwargs
    ):
        """
        Using infer_meta_fn and compute_fn convert func to symbolic function.
@@ -763,11 +808,14 @@ def try_infer_meta_fn(args, kwargs) -> Any:

        log(3, f" inputs : {inputs_symbols}", "\n")

-        if is_symbolic_var:
+        if api_type == APIType.SYMBOLIC:
            var_cls = SymbolicVariable
            tracker = SymbolicOperationTracker(
                list(args) + list(kwargs.values()), func
            )
+        elif api_type == APIType.NUMPY:
+            var_cls = NumpyArrayVariable
+            tracker = DummyTracker(list(args) + list(kwargs.values()))
        else:
            var_cls = TensorVariable
            tracker = DummyTracker(list(args) + list(kwargs.values()))
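
Restated as a table, the branch above selects the output wrapper and its tracker per API family; `DISPATCH` is a hypothetical name and the classes are the module's own, but the (variable class, tracker class) pairs come straight from the diff:

DISPATCH = {
    APIType.SYMBOLIC: (SymbolicVariable, SymbolicOperationTracker),
    APIType.NUMPY: (NumpyArrayVariable, DummyTracker),
    APIType.PADDLE: (TensorVariable, DummyTracker),
}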
@@ -807,7 +855,7 @@ def try_infer_meta_fn(args, kwargs) -> Any:
            stmt_stacks,
        )  # symbolic only contain symbols.
        self._put_inner(outputs)
-        if is_symbolic_var:
+        if api_type == APIType.SYMBOLIC:
            # compute_fn should be call_method
            tracker = SymbolicOperationTracker(
                list(args) + list(kwargs.values()), func
@@ -892,13 +940,13 @@ def remove_global_guarded_variable(self, variable: VariableBase):

    def _find_tensor_inputs(
        self, input_names: list[str]
-    ) -> OrderedSet[TensorVariable | SymbolicVariable]:
-        inputs: OrderedSet[TensorVariable | SymbolicVariable] = OrderedSet()
+    ) -> OrderedSet[GraphNodeVariableType]:
+        inputs: OrderedSet[GraphNodeVariableType] = OrderedSet()
        for name in input_names:
            found = False
            for variable in self.input_variables:
                if (
-                    isinstance(variable, (TensorVariable, SymbolicVariable))
+                    isinstance(variable, GraphNodeVariableClasses)
                    and variable.get_symbol().name == name
                ):
                    inputs.add(variable)
@@ -908,30 +956,37 @@ def _find_tensor_inputs(
        assert len(inputs) == len(input_names), "Number of inputs not match."
        return inputs

-    def gen_load_inputs(
-        self, inputs: OrderedSet[TensorVariable | SymbolicVariable]
-    ):
+    def gen_load_inputs(self, inputs: OrderedSet[GraphNodeVariableType]):
        for input_var in inputs:
-            # For SymbolicVariable, we use paddle.full([], value, "int64")
-            # to convert it to a Tensor
            if isinstance(input_var, SymbolicVariable):
+                # For SymbolicVariable, we use paddle.full([], value, "int64")
+                # to convert it to a Tensor
                self.pycode_gen.gen_load_object(
                    paddle.full,
                    "___paddle_full",
                )
                self.pycode_gen.gen_build_list(0)
-            input_var.tracker.gen_instructions(self.pycode_gen)
-            if isinstance(input_var, SymbolicVariable):
+                input_var.tracker.gen_instructions(self.pycode_gen)
                self.pycode_gen.gen_load_const("int64")
                self.pycode_gen.gen_call_function(3)
+            elif isinstance(input_var, NumpyArrayVariable):
+                # For NumpyArrayVariable, we use paddle.to_tensor(value) to convert it to a Tensor
+                self.pycode_gen.gen_load_object(
+                    paddle.to_tensor,
+                    "___paddle_to_tensor",
+                )
+                input_var.tracker.gen_instructions(self.pycode_gen)
+                self.pycode_gen.gen_call_function(1)
+            else:
+                input_var.tracker.gen_instructions(self.pycode_gen)

    @staticmethod
    def _is_graph_output(
        var,
-    ) -> TypeGuard[TensorVariable | SymbolicVariable]:
+    ) -> TypeGuard[GraphNodeVariableType]:
        return isinstance(
            var.tracker, (DummyTracker, SymbolicOperationTracker)
-        ) and isinstance(var, (TensorVariable, SymbolicVariable))
+        ) and isinstance(var, GraphNodeVariableClasses)

    @staticmethod
    def _collect_related_dummy_tensor(var):
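
At runtime, the instructions emitted by `gen_load_inputs` above amount to the following conversions before the compiled function is invoked; this is a sketch with placeholder values, not the generated bytecode itself:

import numpy as np
import paddle

value = 3            # placeholder for the traced value of a SymbolicVariable
arr = np.ones([2])   # placeholder for the traced value of a NumpyArrayVariable

symbolic_as_tensor = paddle.full([], value, "int64")  # 0-D int64 Tensor
numpy_as_tensor = paddle.to_tensor(arr)               # numpy array converted to a Tensor
# Plain TensorVariable inputs are loaded as-is (the final else branch).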
@@ -949,17 +1004,15 @@ def _collect_related_dummy_tensor(var):

    def _find_tensor_outputs(
        self, outputs: list[VariableBase]
-    ) -> OrderedSet[TensorVariable | SymbolicVariable]:
+    ) -> OrderedSet[GraphNodeVariableType]:
        """
        Return all TensorVariable. find TensorVariables participating in networking from the output Variables

        Args:
            outputs: output variables
        """

-        output_tensors: OrderedSet[TensorVariable | SymbolicVariable] = (
-            OrderedSet()
-        )
+        output_tensors: OrderedSet[GraphNodeVariableType] = OrderedSet()
        # Find Tensor Variables from outputs.
        for output in outputs:
            if isinstance(