Skip to content

Commit caf547b

Browse files
authored
[SOT][NumPy] Complete the basic procedure (#72154)
1 parent 881b7ba commit caf547b

File tree

17 files changed

+651
-93
lines changed

17 files changed

+651
-93
lines changed

paddle/fluid/pybind/jit.cc

+7
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,13 @@ void BindGuard(pybind11::module *m) {
122122
"NumPyArrayValueMatchGuard",
123123
R"DOC(NumPyArrayValueMatchGuard Class.)DOC")
124124
.def(py::init<const py::object &>(), py::arg("array"));
125+
py::class_<NumPyArrayShapeMatchGuard,
126+
GuardBase,
127+
std::shared_ptr<NumPyArrayShapeMatchGuard>>(
128+
*m,
129+
"NumPyArrayShapeMatchGuard",
130+
R"DOC(NumPyArrayShapeMatchGuard Class.)DOC")
131+
.def(py::init<const std::vector<py::object> &>(), py::arg("shape"));
125132
py::class_<WeakRefMatchGuard, GuardBase, std::shared_ptr<WeakRefMatchGuard>>(
126133
*m, "WeakRefMatchGuard", R"DOC(WeakRefMatchGuard Class.)DOC")
127134
.def(py::init<const py::object &>(), py::arg("func"));

paddle/fluid/pybind/sot/guards.cc

+18
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,24 @@ bool NumPyArrayValueMatchGuard::check(PyObject* value) {
193193
.cast<bool>();
194194
}
195195

196+
// Returns true when `value` is a NumPy ndarray whose rank equals the number of
// recorded axes and whose extent matches on every axis that carries an
// expected value. Axes recorded without a value accept any extent.
// Returns false for a null pointer or for any object that is not an ndarray.
bool NumPyArrayShapeMatchGuard::check(PyObject* value) {
  // reinterpret_borrow performs no type checking, so the object must be
  // validated first; otherwise ndim()/shape() would read a non-array object
  // as if it were an ndarray.
  if (value == nullptr || !py::isinstance<py::array>(py::handle(value))) {
    return false;
  }
  const py::array array = py::reinterpret_borrow<py::array>(value);

  // Compare ranks in size_t to avoid narrowing the rank to int.
  const size_t ndim = static_cast<size_t>(array.ndim());
  if (ndim != expected_.size()) {
    return false;
  }

  const auto* shape = array.shape();
  for (size_t i = 0; i < ndim; ++i) {
    // Only axes with a recorded extent are constrained.
    if (expected_[i].has_value() &&
        static_cast<int64_t>(shape[i]) != expected_[i].value()) {
      return false;
    }
  }
  return true;
}
213+
196214
bool WeakRefMatchGuard::check(PyObject* value) {
197215
if (value == nullptr || expected_ == nullptr || Py_IsNone(expected_)) {
198216
return false;

paddle/fluid/pybind/sot/guards.h

+24
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,30 @@ class NumPyArrayValueMatchGuard : public GuardBase {
253253
PyObject* expected_;
254254
};
255255

256+
class NumPyArrayShapeMatchGuard : public GuardBase {
257+
public:
258+
explicit NumPyArrayShapeMatchGuard(
259+
const std::vector<std::optional<int64_t>>& shape)
260+
: expected_(shape) {}
261+
262+
explicit NumPyArrayShapeMatchGuard(const std::vector<py::object>& shape) {
263+
expected_.resize(shape.size());
264+
for (size_t i = 0; i < shape.size(); ++i) {
265+
if (py::isinstance<py::int_>(shape[i]) && shape[i].cast<int64_t>() > 0) {
266+
expected_[i] = std::make_optional(shape[i].cast<int64_t>());
267+
}
268+
}
269+
}
270+
271+
bool check(PyObject* value) override;
272+
std::string get_guard_name() const override {
273+
return "NumPyArrayShapeMatchGuard";
274+
}
275+
276+
private:
277+
std::vector<std::optional<int64_t>> expected_;
278+
};
279+
256280
class WeakRefMatchGuard : public GuardBase {
257281
public:
258282
explicit WeakRefMatchGuard(const py::object& obj) {

python/paddle/jit/sot/infer_meta.py

+27-1
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,12 @@
1515

1616
import copy
1717
from functools import cached_property
18-
from typing import TypeVar
18+
from typing import TYPE_CHECKING, Any, TypeVar
1919

2020
import paddle
2121
from paddle.amp.auto_cast import amp_state
2222
from paddle.base.data_feeder import convert_dtype
23+
from paddle.base.framework import convert_np_dtype_to_dtype_
2324
from paddle.base.unique_name import (
2425
UniqueNameGenerator,
2526
guard as UniqueNameGuard,
@@ -46,6 +47,9 @@
4647
meta_str,
4748
)
4849

50+
if TYPE_CHECKING:
51+
import numpy.typing as npt
52+
4953
DynamicSymbolT = TypeVar("DynamicSymbolT")
5054
SOT_INFER_META_INNER_VAR = "___SOT_INFER_META_INNER_VAR"
5155

@@ -226,6 +230,28 @@ def from_value(value) -> MetaInfo:
226230
dist_info=dist_info,
227231
)
228232

233+
@staticmethod
def from_numpy(
    nparray: npt.NDArray[Any], *, dynamic_axes: list[int] | None = None
):
    """Build a MetaInfo describing a NumPy array.

    The dtype is obtained by passing ``nparray.dtype`` through
    ``convert_np_dtype_to_dtype_``. Each axis whose index (counted from 0,
    as produced by ``enumerate``) appears in ``dynamic_axes`` is
    represented by a fresh ``SymbolicInt``; every other axis keeps its
    concrete size from ``nparray.shape``.

    Args:
        nparray: The array to describe.
        dynamic_axes: Indices of axes to mark as symbolic. ``None`` or an
            empty list marks no axis.

    Returns:
        A MetaInfo built with ``stop_gradient`` set to True.
    """
    converted_dtype = convert_np_dtype_to_dtype_(nparray.dtype)
    symbolic_axes = dynamic_axes if dynamic_axes else []

    meta_shape = []
    for axis, extent in enumerate(nparray.shape):
        if axis in symbolic_axes:
            meta_shape.append(SymbolicInt())
        else:
            meta_shape.append(extent)

    return MetaInfo(
        meta_shape,
        converted_dtype,
        True,  # stop_gradient
        None,
        None,  # persistable
        None,
        None,
        None,
        dist_info=None,
    )
254+
229255
def is_inner_var(self):
230256
return self.name == SOT_INFER_META_INNER_VAR
231257

python/paddle/jit/sot/opcode_translator/executor/function_graph.py

+80-27
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from collections import namedtuple
2222
from contextlib import contextmanager
2323
from copy import deepcopy
24+
from enum import Enum
2425
from functools import reduce
2526
from typing import TYPE_CHECKING, Any, Callable, Tuple, Union
2627

@@ -49,6 +50,7 @@
4950
from ...symbolic_shape.operators import SYMBOLIC_BINARY_OPS, SYMBOLIC_UNARY_OPS
5051
from ...utils import (
5152
ENV_SOT_ALLOW_DYNAMIC_SHAPE,
53+
NUMPY_API_SUPPORTED_DICT,
5254
NameGenerator,
5355
SIRToCodeMap,
5456
SotUndefinedVar,
@@ -86,6 +88,7 @@
8688
GlobalVariable,
8789
ListVariable,
8890
NullVariable,
91+
NumpyArrayVariable,
8992
PaddleLayerVariable,
9093
ParameterVariable,
9194
SymbolicVariable,
@@ -99,6 +102,10 @@
99102
if TYPE_CHECKING:
100103
import types
101104

105+
GraphNodeVariableType: TypeAlias = Union[
106+
TensorVariable, SymbolicVariable, NumpyArrayVariable
107+
]
108+
102109

103110
CompileGraphResult: TypeAlias = Tuple[
104111
Callable[..., Any],
@@ -108,6 +115,11 @@
108115
OrderedSet[Union[TensorVariable, SymbolicVariable]],
109116
],
110117
]
118+
GraphNodeVariableClasses = (
119+
TensorVariable,
120+
SymbolicVariable,
121+
NumpyArrayVariable,
122+
)
111123

112124

113125
def convert_to_meta(inputs: Any):
@@ -116,7 +128,7 @@ def convert_to_meta(inputs: Any):
116128
"""
117129

118130
def func(x):
119-
if isinstance(x, (TensorVariable, SymbolicVariable)):
131+
if isinstance(x, GraphNodeVariableClasses):
120132
return x.meta
121133
if isinstance(x, VariableBase):
122134
return x.get_py_value()
@@ -131,7 +143,7 @@ def convert_to_symbol(inputs: Any):
131143
"""
132144

133145
def func(x):
134-
if isinstance(x, (TensorVariable, SymbolicVariable)):
146+
if isinstance(x, GraphNodeVariableClasses):
135147
return x.get_symbol()
136148
if isinstance(x, VariableBase):
137149
return x.get_py_value()
@@ -155,7 +167,7 @@ def record_symbols(SIR, *args, **kwargs):
155167
non_params = set()
156168

157169
def fn(value):
158-
if isinstance(value, (TensorVariable, SymbolicVariable)):
170+
if isinstance(value, GraphNodeVariableClasses):
159171
symbol_meta_map[value.get_symbol()] = value.meta
160172
if isinstance(value, ParameterVariable):
161173
params.add(value.get_symbol())
@@ -190,6 +202,12 @@ def func(x):
190202
return map_variables(func, inputs, restore_variable=True)
191203

192204

205+
class APIType(Enum):
    """Kind of API being recorded by a symbolic call.

    ``symbolic_call`` receives one of these members as ``api_type`` and uses
    it to pick the output variable class: ``SYMBOLIC`` selects
    ``SymbolicVariable``, ``NUMPY`` selects ``NumpyArrayVariable``, and any
    other member selects ``TensorVariable``.
    """

    # Paddle APIs, methods, layers and AST-converted functions.
    PADDLE = 0
    # Operations on symbolic values.
    SYMBOLIC = 1
    # NumPy APIs.
    NUMPY = 2
209+
210+
193211
class VariableLoader:
194212
def __init__(self, store_var_info, pycode_gen):
195213
self._store_var_info = store_var_info
@@ -541,7 +559,34 @@ def message_handler(*args, **kwargs):
541559
InferMetaCache(),
542560
self.sir_builder.call_API,
543561
func,
544-
False,
562+
APIType.PADDLE,
563+
*args,
564+
**kwargs,
565+
)
566+
567+
def call_numpy_api(
    self,
    func: Callable[..., Any],
    *args: VariableBase,
    **kwargs: VariableBase,
):
    """
    Record a NumPy API call to SIR.

    Args:
        func: The NumPy API to record. It must be one of the values of
            ``NUMPY_API_SUPPORTED_DICT``.
        *args: Positional arguments forwarded to ``symbolic_call``.
        **kwargs: Keyword arguments forwarded to ``symbolic_call``.

    Returns:
        Whatever the wrapped ``symbolic_call`` returns.
    """
    assert func in NUMPY_API_SUPPORTED_DICT.values()
    log(3, f"call numpy.api : {func.__name__}", "\n")

    def message_handler(*args, **kwargs):
        return f"Call numpy api error: {func.__name__}, may be not a operator api?"

    # Wrap symbolic_call so that failures are reported through
    # message_handler, then invoke it tagged as a NumPy call.
    guarded_symbolic_call = inner_error_default_handler(
        self.symbolic_call, message_handler
    )
    return guarded_symbolic_call(
        InferMetaCache(),
        self.sir_builder.call_API,
        func,
        APIType.NUMPY,
        *args,
        **kwargs,
    )
@@ -562,7 +607,7 @@ def message_handler(*args, **kwargs):
562607
InferMetaCache(),
563608
self.sir_builder.call_API,
564609
op,
565-
True,
610+
APIType.SYMBOLIC,
566611
*args,
567612
**kwargs,
568613
)
@@ -584,7 +629,7 @@ def message_handler(*args, **kwargs):
584629
InferMetaCache(),
585630
self.sir_builder.call_METHOD,
586631
method_name,
587-
False,
632+
APIType.PADDLE,
588633
*args,
589634
**kwargs,
590635
)
@@ -619,7 +664,7 @@ def message_handler(*args, **kwargs):
619664
return f"Call paddle layer error: {layer}, may be not a valid paddle layer?"
620665

621666
return inner_error_default_handler(self.symbolic_call, message_handler)(
622-
infer_meta_fn, compute_fn, layer, False, *args, **kwargs
667+
infer_meta_fn, compute_fn, layer, APIType.PADDLE, *args, **kwargs
623668
)
624669

625670
def call_ast(
@@ -653,7 +698,7 @@ def message_handler(*args, **kwargs):
653698
ast_infer_meta,
654699
compute_fn,
655700
static_function,
656-
False,
701+
APIType.PADDLE,
657702
*args,
658703
**kwargs,
659704
)
@@ -662,7 +707,7 @@ def message_handler(*args, **kwargs):
662707
return None
663708

664709
def symbolic_call(
665-
self, infer_meta_fn, compute_fn, func, is_symbolic_var, *args, **kwargs
710+
self, infer_meta_fn, compute_fn, func, api_type, *args, **kwargs
666711
):
667712
"""
668713
Using infer_meta_fn and compute_fn convert func to symbolic function.
@@ -763,11 +808,14 @@ def try_infer_meta_fn(args, kwargs) -> Any:
763808

764809
log(3, f" inputs : {inputs_symbols}", "\n")
765810

766-
if is_symbolic_var:
811+
if api_type == APIType.SYMBOLIC:
767812
var_cls = SymbolicVariable
768813
tracker = SymbolicOperationTracker(
769814
list(args) + list(kwargs.values()), func
770815
)
816+
elif api_type == APIType.NUMPY:
817+
var_cls = NumpyArrayVariable
818+
tracker = DummyTracker(list(args) + list(kwargs.values()))
771819
else:
772820
var_cls = TensorVariable
773821
tracker = DummyTracker(list(args) + list(kwargs.values()))
@@ -807,7 +855,7 @@ def try_infer_meta_fn(args, kwargs) -> Any:
807855
stmt_stacks,
808856
) # symbolic only contain symbols.
809857
self._put_inner(outputs)
810-
if is_symbolic_var:
858+
if api_type == APIType.SYMBOLIC:
811859
# compute_fn should be call_method
812860
tracker = SymbolicOperationTracker(
813861
list(args) + list(kwargs.values()), func
@@ -892,13 +940,13 @@ def remove_global_guarded_variable(self, variable: VariableBase):
892940

893941
def _find_tensor_inputs(
894942
self, input_names: list[str]
895-
) -> OrderedSet[TensorVariable | SymbolicVariable]:
896-
inputs: OrderedSet[TensorVariable | SymbolicVariable] = OrderedSet()
943+
) -> OrderedSet[GraphNodeVariableType]:
944+
inputs: OrderedSet[GraphNodeVariableType] = OrderedSet()
897945
for name in input_names:
898946
found = False
899947
for variable in self.input_variables:
900948
if (
901-
isinstance(variable, (TensorVariable, SymbolicVariable))
949+
isinstance(variable, GraphNodeVariableClasses)
902950
and variable.get_symbol().name == name
903951
):
904952
inputs.add(variable)
@@ -908,30 +956,37 @@ def _find_tensor_inputs(
908956
assert len(inputs) == len(input_names), "Number of inputs not match."
909957
return inputs
910958

911-
def gen_load_inputs(
912-
self, inputs: OrderedSet[TensorVariable | SymbolicVariable]
913-
):
959+
def gen_load_inputs(self, inputs: OrderedSet[GraphNodeVariableType]):
914960
for input_var in inputs:
915-
# For SymbolicVariable, we use paddle.full([], value, "int64")
916-
# to convert it to a Tensor
917961
if isinstance(input_var, SymbolicVariable):
962+
# For SymbolicVariable, we use paddle.full([], value, "int64")
963+
# to convert it to a Tensor
918964
self.pycode_gen.gen_load_object(
919965
paddle.full,
920966
"___paddle_full",
921967
)
922968
self.pycode_gen.gen_build_list(0)
923-
input_var.tracker.gen_instructions(self.pycode_gen)
924-
if isinstance(input_var, SymbolicVariable):
969+
input_var.tracker.gen_instructions(self.pycode_gen)
925970
self.pycode_gen.gen_load_const("int64")
926971
self.pycode_gen.gen_call_function(3)
972+
elif isinstance(input_var, NumpyArrayVariable):
973+
# For NumpyArrayVariable, we use paddle.to_tensor(value) to convert it to a Tensor
974+
self.pycode_gen.gen_load_object(
975+
paddle.to_tensor,
976+
"___paddle_to_tensor",
977+
)
978+
input_var.tracker.gen_instructions(self.pycode_gen)
979+
self.pycode_gen.gen_call_function(1)
980+
else:
981+
input_var.tracker.gen_instructions(self.pycode_gen)
927982

928983
@staticmethod
929984
def _is_graph_output(
930985
var,
931-
) -> TypeGuard[TensorVariable | SymbolicVariable]:
986+
) -> TypeGuard[GraphNodeVariableType]:
932987
return isinstance(
933988
var.tracker, (DummyTracker, SymbolicOperationTracker)
934-
) and isinstance(var, (TensorVariable, SymbolicVariable))
989+
) and isinstance(var, GraphNodeVariableClasses)
935990

936991
@staticmethod
937992
def _collect_related_dummy_tensor(var):
@@ -949,17 +1004,15 @@ def _collect_related_dummy_tensor(var):
9491004

9501005
def _find_tensor_outputs(
9511006
self, outputs: list[VariableBase]
952-
) -> OrderedSet[TensorVariable | SymbolicVariable]:
1007+
) -> OrderedSet[GraphNodeVariableType]:
9531008
"""
9541009
Return all TensorVariable. find TensorVariables participating in networking from the output Variables
9551010
9561011
Args:
9571012
outputs: output variables
9581013
"""
9591014

960-
output_tensors: OrderedSet[TensorVariable | SymbolicVariable] = (
961-
OrderedSet()
962-
)
1015+
output_tensors: OrderedSet[GraphNodeVariableType] = OrderedSet()
9631016
# Find Tensor Variables from outputs.
9641017
for output in outputs:
9651018
if isinstance(

0 commit comments

Comments
 (0)