
Commit a5e6ef6

add mixed precision support to deepxde

1 parent e104988

File tree: 3 files changed, +25 −4 lines

deepxde/config.py
deepxde/model.py
deepxde/real.py

deepxde/config.py

Lines changed: 16 additions & 2 deletions
@@ -71,10 +71,10 @@ def default_float():
 def set_default_float(value):
     """Sets the default float type.

-    The default floating point type is 'float32'.
+    The default floating point type is 'float32'. Mixed precision uses the method in the paper: `J. Hayford, J. Goldman-Wetzler, E. Wang, & L. Lu. Speeding up and reducing memory usage for scientific machine learning via mixed precision. Computer Methods in Applied Mechanics and Engineering, 428, 117093, 2024 <https://doi.org/10.1016/j.cma.2024.117093>`_.

     Args:
-        value (String): 'float16', 'float32', or 'float64'.
+        value (String): 'float16', 'float32', 'float64', or 'mixed' (mixed precision).
     """
     if value == "float16":
         print("Set the default float type to float16")
@@ -85,6 +85,20 @@ def set_default_float(value):
     elif value == "float64":
         print("Set the default float type to float64")
         real.set_float64()
+    elif value == "mixed":
+        print("Set the float type to mixed precision of float16 and float32")
+        real.set_mixed()
+        if backend_name == "tensorflow":
+            real.set_float16()
+            tf.keras.mixed_precision.set_global_policy("mixed_float16")
+            return  # don't try to set it again below
+        if backend_name == "pytorch":
+            # Use float16 during the forward and backward passes, but store in float32
+            real.set_float32()
+        else:
+            raise ValueError(
+                f"{backend_name} backend does not currently support mixed precision."
+            )
     else:
         raise ValueError(f"{value} not supported in deepXDE")
     if backend_name in ["tensorflow.compat.v1", "tensorflow"]:
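For the TensorFlow branch, the heavy lifting is done by the Keras global policy set above. The following standalone sketch (illustrative only, not part of the commit) shows the effect the "mixed_float16" policy has on a Keras layer: computation happens in float16 while the weights stay in float32.

import tensorflow as tf

# Under the "mixed_float16" policy, layers compute in float16 but keep their
# variables (weights) in float32, which is what the branch above relies on.
tf.keras.mixed_precision.set_global_policy("mixed_float16")

layer = tf.keras.layers.Dense(4)
layer.build(input_shape=(None, 3))

print(layer.compute_dtype)       # float16: activations and matmuls
print(layer.variables[0].dtype)  # float32: stored weights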

deepxde/model.py

Lines changed: 5 additions & 2 deletions
@@ -372,9 +372,12 @@ def closure():
                 total_loss = torch.sum(losses)
                 self.opt.zero_grad()
                 total_loss.backward()
-                return total_loss

-            self.opt.step(closure)
+            def closure_mixed():
+                with torch.autocast(device_type="cuda", dtype=torch.float16):
+                    closure()
+
+            self.opt.step(closure if not config.real.mixed else closure_mixed)
             if self.lr_scheduler is not None:
                 self.lr_scheduler.step()
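The branch on config.real.mixed is the only change the PyTorch training loop needs: the existing closure is reused, just executed under autocast. Below is a minimal self-contained sketch of that pattern (not DeepXDE code; the linear model, data, and L-BFGS optimizer are placeholders, and a CUDA device is assumed because the commit hard-codes device_type="cuda").

import torch

model = torch.nn.Linear(3, 1).cuda()         # parameters are stored in float32
opt = torch.optim.LBFGS(model.parameters())
x = torch.randn(16, 3, device="cuda")
y = torch.randn(16, 1, device="cuda")

def closure():
    opt.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    return loss

def closure_mixed():
    # Forward/backward math runs in float16 under autocast; the parameter
    # gradients and the L-BFGS update stay in float32 because the parameters
    # themselves are float32.
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        return closure()

opt.step(closure_mixed)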

deepxde/real.py

Lines changed: 4 additions & 0 deletions
@@ -7,6 +7,7 @@ class Real:
     def __init__(self, precision):
         self.precision = None
         self.reals = None
+        self.mixed = False
         if precision == 16:
             self.set_float16()
         elif precision == 32:
@@ -28,3 +29,6 @@ def set_float32(self):
     def set_float64(self):
         self.precision = 64
         self.reals = {np: np.float64, bkd.lib: bkd.float64}
+
+    def set_mixed(self):
+        self.mixed = True
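Taken together, the three files let a user opt in with a single call. A minimal usage sketch (assuming the usual dde.config entry point and an already-selected tensorflow or pytorch backend):

import deepxde as dde

# With the tensorflow backend this installs the Keras "mixed_float16" policy;
# with the pytorch backend storage stays float32 and the training closure is
# wrapped in torch.autocast; any other backend raises a ValueError.
dde.config.set_default_float("mixed")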
