deepxde/config.py (23 changes: 21 additions & 2 deletions)

@@ -40,6 +40,8 @@

 # Default float type
 real = Real(32)
+# Using mixed precision
+mixed = False
 # Random seed
 random_seed = None
 if backend_name == "jax":
@@ -71,11 +73,14 @@ def default_float():
 def set_default_float(value):
     """Sets the default float type.

-    The default floating point type is 'float32'.
+    The default floating point type is 'float32'. Mixed precision uses the method in the paper:
+    `J. Hayford, J. Goldman-Wetzler, E. Wang, & L. Lu. Speeding up and reducing memory usage for scientific machine learning via mixed precision.
+    Computer Methods in Applied Mechanics and Engineering, 428, 117093, 2024 <https://doi.org/10.1016/j.cma.2024.117093>`_.

     Args:
-        value (String): 'float16', 'float32', or 'float64'.
+        value (String): 'float16', 'float32', 'float64', or 'mixed' (mixed precision).
     """
+    global mixed
     if value == "float16":
         print("Set the default float type to float16")
         real.set_float16()
@@ -85,6 +90,20 @@ def set_default_float(value):
     elif value == "float64":
         print("Set the default float type to float64")
         real.set_float64()
+    elif value == "mixed":
+        print("Set the float type to mixed precision of float16 and float32")
+        mixed = True
+        if backend_name == "tensorflow":
+            real.set_float16()
+            tf.keras.mixed_precision.set_global_policy("mixed_float16")
+            return  # don't try to set it again below
+        if backend_name == "pytorch":
+            # Use float16 during the forward and backward passes, but store in float32
+            real.set_float32()
+        else:
+            raise ValueError(
+                f"{backend_name} backend does not currently support mixed precision."
+            )
     else:
         raise ValueError(f"{value} not supported in deepXDE")
     if backend_name in ["tensorflow.compat.v1", "tensorflow"]:
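Not part of the diff, but for context: after this change, enabling mixed precision from user code is a single call before building the model, as the docs/user/faq.rst addition below confirms. A minimal usage sketch:

```python
import deepxde as dde

# Switch DeepXDE to mixed precision: float16 compute with float32 storage.
# On the tensorflow backend this also sets Keras' "mixed_float16" policy;
# other backends raise a ValueError.
dde.config.set_default_float("mixed")
```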
deepxde/model.py (6 changes: 5 additions & 1 deletion)

@@ -374,7 +374,11 @@ def closure():
                 total_loss.backward()
                 return total_loss

-            self.opt.step(closure)
+            def closure_mixed():
+                with torch.autocast(device_type=torch.get_default_device().type, dtype=torch.float16):
+                    return closure()
+
+            self.opt.step(closure if not config.mixed else closure_mixed)
             if self.lr_scheduler is not None:
                 self.lr_scheduler.step()

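Also not part of the diff: for readers unfamiliar with the pattern, `torch.autocast` runs the wrapped region with lower-precision kernels while the parameters themselves stay float32, which is why the pytorch branch in config.py above keeps `real` at float32. A self-contained sketch of the same closure-wrapping idea, with a hypothetical toy model and data; it uses bfloat16 on CPU so it runs anywhere, whereas the PR uses float16 on the default device:

```python
import torch

model = torch.nn.Linear(4, 1)               # parameters are stored in float32
opt = torch.optim.LBFGS(model.parameters())
x, y = torch.randn(32, 4), torch.randn(32, 1)

def closure():
    opt.zero_grad()
    loss = ((model(x) - y) ** 2).mean()
    loss.backward()
    return loss

def closure_mixed():
    # Forward and backward run under autocast in reduced precision;
    # the optimizer still updates the float32 master weights.
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        return closure()

opt.step(closure_mixed)
```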
docs/user/faq.rst (2 changes: 2 additions & 0 deletions)

@@ -10,6 +10,8 @@ General usage
   | **A**: `#5`_
 - | **Q**: By default, DeepXDE uses ``float32``. How can I use ``float64``?
   | **A**: `#28`_
+- | **Q**: How can I use mixed precision training?
+  | **A**: Use ``dde.config.set_default_float("mixed")`` with the ``tensorflow`` or ``pytorch`` backends. See `this paper <https://doi.org/10.1016/j.cma.2024.117093>`_ for more information.
 - | **Q**: I want to set the global random seeds.
   | **A**: `#353`_
 - | **Q**: GPU.