Skip to content

Commit 832fcea

Browse files
committed
Support DataClassInstanceVariable
1 parent 433be34 commit 832fcea

File tree

6 files changed

+273
-3
lines changed

6 files changed

+273
-3
lines changed

python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -671,6 +671,7 @@ def step(self, instr: Instruction):
671671
assert opname != "CALL", "CALL should fused with PRECALL"
672672
with EventGuard(f"{opname}", event_level=2):
673673
try:
674+
# Dispatch to the handler method named after the opcode.
674675
return getattr(self, opname)(instr) # run single step.
675676
except SotCapturedException as e:
676677
self.handle_exception(e)

python/paddle/jit/sot/opcode_translator/executor/side_effects.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,15 @@
1717
from typing import TYPE_CHECKING, Any, NamedTuple, TypeVar
1818

1919
if TYPE_CHECKING:
20+
from collections.abc import Callable
21+
22+
from typing_extensions import TypeAlias
23+
2024
from .mutable_data import DataGetter, MutableData
2125
from .pycode_generator import PyCodeGen
2226
from .variables import VariableBase
2327

28+
IdGetter: TypeAlias = Callable[[Any], int]
2429
MutableDataT = TypeVar("MutableDataT", bound=MutableData)
2530

2631

@@ -51,8 +56,9 @@ def get_proxy(
5156
proxy_type: type[MutableDataT],
5257
data: Any,
5358
getter: DataGetter,
59+
id_getter: IdGetter = id,
5460
) -> MutableDataT:
55-
data_id = id(data)
61+
data_id = id_getter(data)
5662
if data_id not in self.data_id_to_proxy:
5763
self.data_id_to_proxy[data_id] = proxy_type(data, getter)
5864
return self.data_id_to_proxy[data_id] # type: ignore

python/paddle/jit/sot/opcode_translator/executor/variable_dispatch.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import inspect
1919
import math
2020
import operator
21+
from dataclasses import fields
2122
from functools import partial, reduce
2223
from typing import TYPE_CHECKING
2324

@@ -67,6 +68,7 @@
6768
CallableVariable,
6869
ConstantVariable,
6970
ContainerVariable,
71+
DataClassInstanceVariable,
7072
DictVariable,
7173
EnumerateVariable,
7274
EnumVariable,
@@ -1556,6 +1558,31 @@ def exception_variable_equal(left: ExceptionVariable, right: ExceptionVariable):
15561558
lambda left, right: exception_variable_equal(left, right),
15571559
)
15581560

1561+
1562+
@Dispatcher.register_decorator(operator.eq)
def dataclass_instance_eq(
    lhs: DataClassInstanceVariable, rhs: DataClassInstanceVariable
):
    """Simulate ``lhs == rhs`` for two traced dataclass instances.

    Instances of different dataclass types compare unequal. Otherwise the
    result is the conjunction of the per-field ``operator.eq`` dispatches.

    Args:
        lhs: Left-hand dataclass instance variable.
        rhs: Right-hand dataclass instance variable.

    Returns:
        A ``ConstantVariable`` on ``lhs.graph`` holding the comparison result,
        tracked by a ``DummyTracker`` over both operands.
    """
    dataclass_type = lhs.get_py_type()
    if dataclass_type != rhs.get_py_type():
        return ConstantVariable(False, lhs.graph, DummyTracker([lhs, rhs]))

    # NOTE(review): ``all`` relies on the truthiness of whatever each
    # per-field dispatch returns — confirm this is well defined for every
    # variable kind a field may hold (e.g. tensors).
    return ConstantVariable(
        all(
            Dispatcher.call(
                operator.eq, lhs.getattr(field.name), rhs.getattr(field.name)
            )
            for field in fields(dataclass_type)
            # Fields declared with ``compare=False`` are excluded from the
            # dataclass-generated ``__eq__``, so they are skipped here too.
            if field.compare
        ),
        lhs.graph,
        DummyTracker([lhs, rhs]),
    )
1579+
1580+
1581+
@Dispatcher.register_decorator(operator.ne)
def dataclass_instance_ne(
    lhs: DataClassInstanceVariable, rhs: DataClassInstanceVariable
):
    """Simulate ``lhs != rhs`` for two traced dataclass instances.

    Implemented as the negation of the ``operator.eq`` dispatch on the same
    operands.

    The parameters are annotated with ``DataClassInstanceVariable`` so that
    they agree with ``dataclass_instance_eq``.
    NOTE(review): ``Dispatcher.register_decorator`` presumably derives the
    dispatch signature from these annotations — confirm against its
    implementation.
    """
    return Dispatcher.call(operator.eq, lhs, rhs).bool_not()
1584+
1585+
15591586
Dispatcher.register(
15601587
operator.eq,
15611588
("EnumVariable", "EnumVariable"),

python/paddle/jit/sot/opcode_translator/executor/variables/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from .basic import ( # noqa: F401
2222
CellVariable,
2323
ConstantVariable,
24+
DataClassInstanceVariable,
2425
DataVariable,
2526
DygraphTracerVariable,
2627
EnumVariable,

python/paddle/jit/sot/opcode_translator/executor/variables/basic.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import operator
1818
import sys
1919
import types
20+
from dataclasses import dataclass, is_dataclass
2021
from enum import Enum
2122
from functools import cached_property, reduce
2223
from typing import TYPE_CHECKING, Any
@@ -2396,3 +2397,133 @@ def get_enum_class_id(cls, enum_class: type[Enum]):
23962397
id = len(same_name_enums)
23972398
same_name_enums.append(enum_class)
23982399
return id
2400+
2401+
2402+
class DataClassInstanceVariable(VariableBase):
    """Variable wrapping an instance of a ``@dataclass``-decorated class.

    Attribute reads and writes go through a ``MutableDictLikeData`` proxy
    obtained from ``graph.side_effects`` and keyed by ``id(value)``.
    """

    # Registry used by ``get_class_id``: class name -> dataclass types seen
    # under that name, in first-seen order.
    known_dataclasses = {}

    def __init__(
        self,
        value,
        graph: FunctionGraph,
        tracker: Tracker,
    ):
        """Wrap the dataclass instance ``value``.

        Args:
            value: The dataclass instance being traced.
            graph: The function graph this variable belongs to.
            tracker: Tracker describing where ``value`` came from.
        """
        super().__init__(graph=graph, tracker=tracker)
        self.dataclass_type = type(value)
        # NOTE(review): ``var_dict`` is the instance's live ``__dict__``, so
        # the update below also stores a ``__post_init__`` entry (``None``
        # when the class defines none) on the wrapped object itself — confirm
        # this mutation is intended.
        var_dict = value.__dict__
        var_dict.update(
            {'__post_init__': getattr(value, '__post_init__', None)}
        )
        self.proxy = self.graph.side_effects.get_proxy(
            MutableDictLikeData,
            var_dict,
            self.proxy_getter,
            # Key the proxy by the instance rather than by its ``__dict__``.
            id_getter=lambda _: id(value),
        )

    @cached_property
    def attr_proxy(self):
        """Return the same proxy object stored in ``self.proxy``."""
        return self.proxy

    def getattr(self, name: str, default=None):
        """Return the proxied value stored under ``name``.

        NOTE(review): ``default`` is accepted but never used; a missing key is
        reported however ``self.proxy.get`` reports it — confirm.
        """
        return self.proxy.get(name)

    def setattr(self, key: str, value):
        """Record ``key = value`` on the proxy and register the side effect.

        Returns:
            A ``ConstantVariable`` wrapping ``None``.
        """
        self.proxy.set(key, value)
        self.graph.side_effects.record_proxy_variable(self)
        return ConstantVariable.wrap_literal(None, self.graph)

    def proxy_getter(self, proxy: MutableDictLikeData, key: Any):
        """Build the variable for ``key`` from the proxy's original data.

        Returns ``MutableDictLikeData.Empty()`` when ``key`` is absent;
        otherwise wraps the stored value, tracked as an attribute read on
        this variable.
        """
        if key not in proxy.original_data:
            return MutableDictLikeData.Empty()
        return VariableFactory.from_value(
            proxy.original_data[key],
            self.graph,
            tracker=GetAttrTracker(self, key, changed=proxy.has_changed),
        )

    def get_py_type(self):
        """Return the dataclass type of the wrapped instance."""
        return self.dataclass_type

    def get_py_value(self, allow_tensor=False):
        """Rebuild a dataclass instance from the current proxied values.

        Only fields accepted by the generated ``__init__`` are forwarded.
        The proxy's backing dict also holds the synthetic ``'__post_init__'``
        entry added in ``__init__``, which the dataclass constructor would
        reject as an unexpected keyword argument.

        Args:
            allow_tensor: Forwarded to each attribute's ``get_py_value``.

        Returns:
            A new instance of ``self.dataclass_type``.
        """
        # Local import: this module's top-level dataclasses import only
        # brings in ``dataclass`` and ``is_dataclass``.
        from dataclasses import fields

        init_field_names = {
            field.name for field in fields(self.dataclass_type) if field.init
        }
        return self.dataclass_type(
            **{
                key: value.get_py_value(allow_tensor)
                for key, value in self.proxy.get_all().items()
                if key in init_field_names
            }
        )

    def _guarded_attr_variables(self):
        """Return the proxied attribute variables that need a guard.

        Skips ``MutableDictLikeData.Empty`` placeholders and variables whose
        tracker reports that no guard is needed.
        """
        return [
            item
            for item in self.proxy.reproduce(0).values()
            if not isinstance(item, MutableDictLikeData.Empty)
            and item.tracker.need_guard()
        ]

    @check_faster_guard
    def make_faster_guard(self) -> list[paddle.framework.core.GuardNodeBase]:
        """Return a type-match guard followed by the attribute guards."""
        expr_node = self.tracker.guard_tree_expr_node()
        type_guard = paddle.framework.core.GuardNode(
            paddle.framework.core.TypeMatchGuard(self.get_py_type()),
            [expr_node],
        )
        return reduce(
            operator.add,
            [[type_guard]]
            + [
                item.make_faster_guard()
                for item in self._guarded_attr_variables()
            ],
        )

    @check_guard
    def make_stringified_guard(self) -> list[StringifiedExpression]:
        """Return an ``isinstance`` guard followed by the attribute guards.

        The dataclass type is exposed to the guard expression as a free
        variable named ``__<ClassName>_<id>``, where ``<id>`` comes from
        ``get_class_id``.
        """
        data_class = self.get_py_type()
        class_name = data_class.__name__
        data_class_id = DataClassInstanceVariable.get_class_id(data_class)
        extern_var_name = f"__{class_name}_{data_class_id}"
        frame_value_tracer = self.tracker.trace_value_from_frame()
        type_guard = FasterStringifiedExpression(
            f"isinstance({{}}, {extern_var_name})",
            paddle.framework.core.InstanceCheckGuard(self.get_py_type()),
            [frame_value_tracer],
            union_free_vars(
                frame_value_tracer.free_vars,
                {f"{extern_var_name}": self.get_py_type()},
            ),
        )
        return reduce(
            operator.add,
            [[type_guard]]
            + [
                item.make_stringified_guard()
                for item in self._guarded_attr_variables()
            ],
        )

    @VariableFactory.register_from_value()
    def from_value(value: Any, graph: FunctionGraph, tracker: Tracker):
        """Wrap ``value`` if it is a dataclass instance, else return ``None``.

        ``is_dataclass`` is also true for dataclass types themselves, so
        types are excluded explicitly.
        """
        if is_dataclass(value) and not isinstance(value, type):
            var = DataClassInstanceVariable(value, graph=graph, tracker=tracker)
            return var
        return None

    @classmethod
    def get_class_id(cls, data_class: type):
        """Return the index of ``data_class`` among same-named dataclasses.

        The first class seen under a given ``__name__`` gets index 0, the
        next distinct class with that name gets 1, and so on. A class seen
        before gets its existing index back.
        """
        same_name_dataclasses = (
            DataClassInstanceVariable.known_dataclasses.setdefault(
                data_class.__name__, []
            )
        )
        for index, known_class in enumerate(same_name_dataclasses):
            if data_class == known_class:
                return index
        same_name_dataclasses.append(data_class)
        return len(same_name_dataclasses) - 1

test/sot/test_dataclass.py

Lines changed: 106 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,20 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from __future__ import annotations
16+
1517
import unittest
16-
from dataclasses import dataclass
18+
from dataclasses import dataclass, field
19+
from enum import IntEnum
20+
from typing import Callable
1721

18-
from test_case_base import TestCaseBase
22+
from test_case_base import (
23+
TestCaseBase,
24+
test_instruction_translator_cache_context,
25+
)
1926

2027
import paddle
28+
from paddle.jit.sot.psdb import check_no_breakgraph
2129
from paddle.jit.sot.utils import strict_mode_guard
2230

2331

@@ -54,5 +62,101 @@ def test_dtype_reconstruct_with_post_init(self):
5462
self.assert_results(return_dataclass_with_post_init, x)
5563

5664

65+
class DataType(IntEnum):
    """Integer enum used as a field value in the ``DataMeta`` fixture."""

    FLOAT32 = 1
    FLOAT64 = 2
    INT32 = 3
    INT64 = 4
70+
71+
72+
@dataclass
class DataMeta:
    """Dataclass fixture with tensor, enum, list, int and callable fields."""

    x: paddle.Tensor
    y: paddle.Tensor = None
    z: DataType = DataType.FLOAT32
    m: list[list[paddle.Tensor]] = field(default_factory=list)
    n: int = 0
    f: Callable[[DataMeta], list] = None

    def __post_init__(self):
        # Runs right after the generated ``__init__``: bumps ``x`` by one.
        self.x += 1
83+
84+
85+
# Traced function under test: compares two dataclass instances with ``==``.
@check_no_breakgraph
def is_eq(data: DataMeta, data2: DataMeta):
    return data == data2
88+
89+
90+
# Traced function under test: reads two attributes and adds them.
@check_no_breakgraph
def get_attr(data: DataMeta):
    return data.x + data.y
93+
94+
95+
# Traced function under test: overwrites ``data.x``, reads it back, then
# restores the original value before returning.
@check_no_breakgraph
def set_attr(data: DataMeta):
    ori_x = data.x
    data.x = data.x + data.n
    res = data.x
    # Put the original attribute back so ``data`` ends with the ``x`` it
    # started with.
    data.x = ori_x
    return res
102+
103+
104+
# Traced function under test: calls the callable stored in ``data.f`` with
# the instance itself.
@check_no_breakgraph
def callable_attr(data: DataMeta):
    return data.f(data)
107+
108+
109+
class TestDataClassInstance(TestCaseBase):
    """Tests for tracing functions that take dataclass instances."""

    def test_guard(self):
        # Checks how many times ``is_eq`` is translated as its second
        # argument changes.
        d1 = Data(x=paddle.randn([1]))
        dm1 = DataMeta(x=paddle.randn([1]))
        dm2 = DataMeta(x=paddle.randn([1]))
        dm3 = DataMeta(x=paddle.zeros([1]))
        dm4 = DataMeta(x=paddle.randn([1]), z=DataType.INT32)
        dm5 = DataMeta(x=paddle.randn([1]), n=1)
        with test_instruction_translator_cache_context() as ctx:
            self.assertEqual(ctx.translate_count, 0)
            # First call: one translation expected.
            self.assert_results(is_eq, dm1, dm2)
            self.assertEqual(ctx.translate_count, 1)
            # Same arguments again: count is expected to stay at 1.
            self.assert_results(is_eq, dm1, dm2)
            self.assertEqual(ctx.translate_count, 1)
            # ``d1`` is a ``Data`` rather than a ``DataMeta``: one more
            # translation expected.
            self.assert_results(is_eq, dm1, d1)
            self.assertEqual(ctx.translate_count, 2)
            # ``dm3`` differs from ``dm2`` only in how ``x`` was created:
            # count is expected to stay at 2.
            self.assert_results(is_eq, dm1, dm3)
            self.assertEqual(ctx.translate_count, 2)
            # ``dm4`` has a different enum value in ``z``: count goes to 3.
            self.assert_results(is_eq, dm1, dm4)
            self.assertEqual(ctx.translate_count, 3)
            # ``dm5`` has a different int in ``n``: count goes to 4.
            self.assert_results(is_eq, dm1, dm5)
            self.assertEqual(ctx.translate_count, 4)

    def test_get_attr(self):
        dm = DataMeta(x=paddle.randn([1, 2]), y=paddle.randn([1]))
        self.assert_results(get_attr, dm)

    def test_set_attr(self):
        dm = DataMeta(x=paddle.ones([1, 2]), n=2)
        self.assert_results(set_attr, dm)

    def test_callable_attr(self):

        # Local function stored in the ``f`` field and invoked by
        # ``callable_attr``.
        def process_func(data: DataMeta):
            return data.x.shape

        dm = DataMeta(x=paddle.randn([1, 2]), f=process_func)
        self.assert_results(callable_attr, dm)

    def test_eq(self):
        dm1 = DataMeta(x=paddle.randn([1]))
        dm2 = DataMeta(x=paddle.randn([1]))
        dm3 = DataMeta(x=paddle.zeros([1]))
        dm4 = DataMeta(x=paddle.randn([1]), z=DataType.INT32)
        self.assert_results(is_eq, dm1, dm2)
        self.assert_results(is_eq, dm1, dm3)
        self.assert_results(is_eq, dm1, dm4)
        # TODO(wangmingkai): operator.eq with args UserDefinedFunctionVariable
        # dm5 = DataMeta(x= paddle.randn([1]), f=lambda _: [])
        # self.assert_results(is_eq, dm1, dm5)
159+
160+
57161
if __name__ == "__main__":
58162
unittest.main()

0 commit comments

Comments
 (0)