Skip to content

Commit 586edce

Browse files
committed
GPU_ALLOW_GROWTH enabled in TF and TF v1
1 parent d1a896d commit 586edce

File tree

3 files changed

+7
-3
lines changed

3 files changed

+7
-3
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
from __future__ import absolute_import
22

3+
import os
34
from distutils.version import LooseVersion
45

56
import tensorflow as tf
67

78

9+
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"
10+
811
if LooseVersion(tf.__version__) < LooseVersion("2.2.0"):
912
raise RuntimeError("DeepXDE requires tensorflow>=2.2.0.")

deepxde/backend/tensorflow_compat_v1/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
from __future__ import absolute_import
22

3+
import os
34
from distutils.version import LooseVersion
45

56
import tensorflow as tf
67

78

9+
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"
10+
811
if LooseVersion(tf.__version__) < LooseVersion("2.2.0"):
912
raise RuntimeError("DeepXDE requires tensorflow>=2.2.0.")
1013

deepxde/model.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -200,9 +200,7 @@ def predict(self, x, operator=None, callbacks=None):
200200
def _open_tfsession(self):
201201
if self.sess is not None:
202202
return
203-
tfconfig = tf.ConfigProto()
204-
tfconfig.gpu_options.allow_growth = True
205-
self.sess = tf.Session(config=tfconfig)
203+
self.sess = tf.Session()
206204
self.saver = tf.train.Saver(max_to_keep=None)
207205
self.train_state.set_tfsession(self.sess)
208206

0 commit comments

Comments
 (0)