Skip to content

Commit 03b33b8

Browse files
authored
Add interactive installation of Paddle (#1212)
1 parent 5bd9fa4 commit 03b33b8

File tree

2 files changed

+222
-16
lines changed

2 files changed

+222
-16
lines changed

deepxde/backend/__init__.py

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from . import backend
99
from .set_default_backend import set_default_backend
10-
from .utils import verify_backend
10+
from .utils import interactive_install_paddle, verify_backend
1111

1212
_enabled_apis = set()
1313

@@ -47,15 +47,6 @@ def backend_message(backend_name):
4747

4848

4949
def load_backend(mod_name):
50-
if mod_name not in [
51-
"tensorflow.compat.v1",
52-
"tensorflow",
53-
"pytorch",
54-
"jax",
55-
"paddle",
56-
]:
57-
raise NotImplementedError("Unsupported backend: %s" % mod_name)
58-
5950
backend_message(mod_name)
6051
mod = importlib.import_module(".%s" % mod_name.replace(".", "_"), __name__)
6152
thismod = sys.modules[__name__]
@@ -110,12 +101,10 @@ def get_preferred_backend():
110101
return backend_name
111102

112103
# No backend selected
113-
print(
114-
"DeepXDE backend not selected. Use tensorflow.compat.v1.",
115-
file=sys.stderr,
116-
)
117-
set_default_backend("tensorflow.compat.v1")
118-
return "tensorflow.compat.v1"
104+
print("No backend selected.")
105+
interactive_install_paddle()
106+
set_default_backend("paddle")
107+
return "paddle"
119108

120109

121110
load_backend(get_preferred_backend())

deepxde/backend/utils.py

Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
import os
2+
import sys
3+
4+
5+
# Verify if the backend is available/importable
16
def import_tensorflow_compat_v1():
27
# pylint: disable=import-outside-toplevel
38
try:
@@ -75,3 +80,215 @@ def verify_backend(backend_name):
7580
raise RuntimeError(
7681
f"Backend is set as {backend_name}, but '{backend_name}' failed to import."
7782
)
83+
84+
85+
# Ask user if install paddle and install it
86+
def run_install(command):
87+
"""Send command to terminal and print it.
88+
89+
Args:
90+
command (str): command to be sent to terminal.
91+
"""
92+
print("Install command:", command)
93+
installed = os.system(command)
94+
if installed == 0:
95+
print("Paddle installed successfully!\n", file=sys.stderr, flush=True)
96+
else:
97+
sys.exit(
98+
"Paddle installed failed!\n"
99+
"Please visit https://www.paddlepaddle.org.cn/en for help and install it manually, "
100+
"or use another backend."
101+
)
102+
103+
104+
def get_platform():
105+
"""Get user's platform.
106+
107+
Returns:
108+
platform (str): "windows", "linux" or "darwin"
109+
"""
110+
if sys.platform in ["win32", "cygwin"]:
111+
return "windows"
112+
if sys.platform in ["linux", "linux2"]:
113+
return "linux"
114+
if sys.platform == "darwin":
115+
return "darwin"
116+
sys.exit(
117+
f"Your system {sys.platform} is not supported by Paddle. Paddle installation stopped.\n"
118+
"Please use another backend."
119+
)
120+
121+
122+
def get_cuda(platform):
123+
"""Check whether cuda is avaliable and get its version.
124+
125+
Returns:
126+
cuda_verion (str) or None
127+
"""
128+
if platform == "linux":
129+
cuda_list = [101, 102, 110, 111, 112, 116, 117, 118]
130+
elif platform == "windows":
131+
cuda_list = [101, 102, 110, 111, 112, 113, 114, 115, 116, 117, 118]
132+
nvcc_text = os.popen("nvcc -V").read()
133+
if nvcc_text != "":
134+
cuda_version = nvcc_text.split("Cuda compilation tools, release ")[-1].split(
135+
","
136+
)[0]
137+
version = int(float(cuda_version) * 10)
138+
if version not in cuda_list:
139+
cuda_list_str = [str(i / 10) for i in cuda_list]
140+
msg_cl = "/".join(cuda_list_str)
141+
print(
142+
f"Your CUDA version is {cuda_version},",
143+
f"but Paddle only supports CUDA {msg_cl} for {platform} now.",
144+
file=sys.stderr,
145+
flush=True,
146+
)
147+
else:
148+
return cuda_version
149+
150+
return None
151+
152+
153+
def get_rocm():
154+
"""Check whether ROCm4.0 is avaliable.
155+
156+
Returns:
157+
bool
158+
"""
159+
roc_text1 = os.popen("/opt/rocm/bin/rocminfo").read()
160+
roc_text2 = os.popen("/opt/rocm/opencl/bin/clinfo").read()
161+
if roc_text1 != "" and roc_text2 != "":
162+
return True
163+
164+
print("There is no avaliable ROCm4.0.", file=sys.stderr, flush=True)
165+
return False
166+
167+
168+
def check_avx(platform):
169+
"""Check whether avx is supported."""
170+
avx_text1 = avx_text2 = ""
171+
if platform == "darwin":
172+
avx_text1 = os.popen("sysctl machdep.cpu.features | grep -i avx").read()
173+
avx_text2 = os.popen("sysctl machdep.cpu.leaf7_features | grep -i avx").read()
174+
elif platform == "linux":
175+
avx_text1 = os.popen("cat /proc/cpuinfo | grep -i avx").read()
176+
elif platform == "windows":
177+
return
178+
179+
if avx_text1 == "" and avx_text2 == "":
180+
sys.exit(
181+
"Your machine doesn't support AVX, which is required by PaddlePaddle (develop version). "
182+
"Paddle installation stopped.\n"
183+
"Please use another backend."
184+
)
185+
186+
187+
def get_python_executable():
188+
"""Get user's python executable.
189+
190+
Returns:
191+
str: python exection path
192+
"""
193+
return sys.executable
194+
195+
196+
def generate_cmd(py_exec, platform, cuda_version=None, has_rocm=False):
197+
"""Generate command.
198+
199+
Args:
200+
py_exec (str): python executable path.
201+
platform (str): User's platform.
202+
cuda_version (str): Whether cuda is avaliable and its version if it is.
203+
has_rocm (bool): Whether ROCm4.0 has been installed.
204+
"""
205+
if platform == "darwin":
206+
print(
207+
"Paddle can only be installed in macOS with CPU version now. ",
208+
"Installing CPU version...",
209+
file=sys.stderr,
210+
flush=True,
211+
)
212+
cmd = "{}{}{}".format(
213+
py_exec,
214+
" -m pip install paddlepaddle==0.0.0 -f ",
215+
"https://www.paddlepaddle.org.cn/whl/mac/cpu/develop.html",
216+
)
217+
return cmd
218+
219+
if cuda_version is not None:
220+
print(f"Installing CUDA {cuda_version} version...", file=sys.stderr, flush=True)
221+
cmd = "{}{}{}{}{}{}".format(
222+
py_exec,
223+
" -m pip install paddlepaddle-gpu==0.0.0.post",
224+
int(float(cuda_version) * 10),
225+
" -f https://www.paddlepaddle.org.cn/whl/",
226+
platform,
227+
"/gpu/develop.html",
228+
)
229+
return cmd
230+
231+
if platform == "linux" and has_rocm:
232+
print("Installing ROCm4.0 version...", file=sys.stderr, flush=True)
233+
cmd = "{}{}{}".format(
234+
py_exec,
235+
" -m pip install --pre paddlepaddle-rocm -f ",
236+
"https://www.paddlepaddle.org.cn/whl/rocm/develop.html",
237+
)
238+
return cmd
239+
240+
print("Installing CPU version...", file=sys.stderr, flush=True)
241+
cmd = "{}{}".format(
242+
py_exec,
243+
" -m pip install paddlepaddle==0.0.0 -f https://www.paddlepaddle.org.cn/whl/",
244+
)
245+
if platform == "windows":
246+
cmd += "windows/cpu-mkl-avx/develop.html"
247+
elif platform == "linux":
248+
cmd += "linux/cpu-mkl/develop.html"
249+
return cmd
250+
251+
252+
def install_paddle():
253+
"""Generate command and install paddle."""
254+
# get user's platform
255+
platform = get_platform()
256+
# check avx
257+
check_avx(platform)
258+
# check python version
259+
py_exec = get_python_executable()
260+
261+
# get user's device and generate cmd
262+
if platform == "darwin":
263+
cmd = generate_cmd(py_exec, platform)
264+
else:
265+
cuda_version = get_cuda(platform)
266+
has_rocm = get_rocm() if platform == "linux" and cuda_version is None else False
267+
cmd = generate_cmd(py_exec, platform, cuda_version, has_rocm)
268+
269+
# run command
270+
run_install(cmd)
271+
272+
273+
def interactive_install_paddle():
274+
"""Ask the user for installing paddle."""
275+
try:
276+
notice = "Do you want to install the recommended backend Paddle (y/n): "
277+
msg = input(notice)
278+
except EOFError:
279+
msg = "n"
280+
281+
cnt = 0
282+
while cnt < 3:
283+
if msg == "y":
284+
install_paddle()
285+
return
286+
if msg == "n":
287+
break
288+
cnt += 1
289+
msg = input("Please enter correctly (y/n): ")
290+
291+
sys.exit(
292+
"No available backend found.\n"
293+
"Please manually install a backend, and run DeepXDE again."
294+
)

0 commit comments

Comments
 (0)