Skip to content

Commit 0d0be6a

Browse files
committed
Add set_backend() utility.
1 parent b97338e commit 0d0be6a

File tree

1 file changed

+42
-0
lines changed

1 file changed

+42
-0
lines changed

keras/utils/backend_utils.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1+
import copy
2+
import importlib
3+
import os
14
import sys
25

36
from keras import backend as backend_module
7+
from keras.api_export import keras_export
48
from keras.backend.common import global_state
59

610

@@ -82,3 +86,41 @@ def __getattr__(self, name):
8286
from keras import backend as numpy_backend
8387

8488
return getattr(numpy_backend, name)
89+
90+
91+
@keras_export("keras.config.set_backend")
92+
def set_backend(backend):
93+
"""Reload the backend (and the Keras package).
94+
95+
Example:
96+
97+
```python
98+
keras.config.set_backend("jax")
99+
```
100+
101+
Note that this will **NOT** convert the type of any already
102+
instantiated objects, except for the `keras` module itself.
103+
Thus, any layers / tensors / etc. already created will no
104+
longer be usable without errors. It is strongly recommended **not**
105+
to keep around **any** Keras-originated objects instances created
106+
before calling `set_backend()`.
107+
"""
108+
os.environ["KERAS_BACKEND"] = backend
109+
# Clear module cache.
110+
loaded_modules = [
111+
key for key in sys.modules.keys() if key.startswith("keras")
112+
]
113+
for key in loaded_modules:
114+
del sys.modules[key]
115+
# Reimport Keras with the new backend (set via KERAS_BACKEND).
116+
import keras
117+
118+
# Finally: refresh all imported Keras submodules.
119+
globs = copy.copy(globals())
120+
for key, value in globs.items():
121+
if value.__class__ == keras.__class__:
122+
if str(value).startswith("<module 'keras."):
123+
module_name = str(value)
124+
module_name = module_name[module_name.find("'") + 1 :]
125+
module_name = module_name[: module_name.find("'")]
126+
globals()[key] = importlib.import_module(module_name)

0 commit comments

Comments
 (0)