Skip to content

Commit b8db5be

Browse files
committed
add mixed precision support to deepxde
1 parent 0643941 commit b8db5be

File tree

3 files changed

+31
-4
lines changed

3 files changed

+31
-4
lines changed

deepxde/config.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,19 @@ def set_default_float(value):
7979
if value == "float16":
8080
print("Set the default float type to float16")
8181
real.set_float16()
82+
elif value == "mixed":
83+
print("Set training policy to mixed")
84+
real.set_mixed()
85+
if backend_name == "tensorflow":
86+
real.set_float16()
87+
tf.keras.mixed_precision.set_global_policy("mixed_float16")
88+
elif backend_name == "pytorch":
89+
# we cast to float16 during the passes in the training loop, but store in float32
90+
real.set_float32()
91+
else:
92+
raise ValueError(
93+
f"{backend_name} backend does not currently support mixed precision in deepXDE"
94+
)
8295
elif value == "float32":
8396
print("Set the default float type to float32")
8497
real.set_float32()

deepxde/model.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -353,10 +353,20 @@ def outputs_losses_test(inputs, targets, auxiliary_vars):
353353

354354
def train_step(inputs, targets, auxiliary_vars):
355355
def closure():
356-
losses = outputs_losses_train(inputs, targets, auxiliary_vars)[1]
357-
total_loss = torch.sum(losses)
358-
self.opt.zero_grad()
359-
total_loss.backward()
356+
if config.real.mixed:
357+
with torch.autocast(device_type="cuda", dtype=torch.float16):
358+
losses = outputs_losses_train(inputs, targets, auxiliary_vars)[
359+
1
360+
]
361+
total_loss = torch.sum(losses)
362+
# we do the backprop in float16
363+
self.opt.zero_grad()
364+
total_loss.backward()
365+
else:
366+
losses = outputs_losses_train(inputs, targets, auxiliary_vars)[1]
367+
total_loss = torch.sum(losses)
368+
self.opt.zero_grad()
369+
total_loss.backward()
360370
return total_loss
361371

362372
self.opt.step(closure)

deepxde/real.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ class Real:
77
def __init__(self, precision):
88
self.precision = None
99
self.reals = None
10+
self.mixed = False
1011
if precision == 16:
1112
self.set_float16()
1213
elif precision == 32:
@@ -17,6 +18,9 @@ def __init__(self, precision):
1718
def __call__(self, package):
1819
return self.reals[package]
1920

21+
def set_mixed(self):
22+
self.mixed = True
23+
2024
def set_float16(self):
2125
self.precision = 16
2226
self.reals = {np: np.float16, bkd.lib: bkd.float16}

0 commit comments

Comments
 (0)