Skip to content

Commit 81591aa

Browse files
committed
Backend paddle: allow enable prim to accelerate running
1 parent 18400e5 commit 81591aa

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

deepxde/backend/paddle/__init__.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,16 @@
1-
from .tensor import * # pylint: disable=redefined-builtin
1+
from .tensor import * # pylint: disable=redefined-builtin
2+
import os
3+
4+
# enable prim if specified
5+
enable_prim_value = os.getenv("PRIM")
6+
enable_prim = enable_prim_value.lower() in ['1', 'true', 'yes', 'on'] if enable_prim_value else False
7+
if enable_prim:
8+
# Mostly for compiler running with dy2st.
9+
from paddle.framework import core
10+
11+
core.set_prim_eager_enabled(True)
12+
# The following protected member access is required.
13+
# There is no alternative public API available now.
14+
# pylint: disable=protected-access
15+
core._set_prim_all_enabled(True)
16+
print("Prim mode is enabled.")

0 commit comments

Comments
 (0)