4040
4141# Default float type
4242real = Real (32 )
43+ # Using mixed precision
44+ mixed = False
4345# Random seed
4446random_seed = None
4547if backend_name == "jax" :
@@ -71,11 +73,14 @@ def default_float():
7173def set_default_float (value ):
7274 """Sets the default float type.
7375
74- The default floating point type is 'float32'.
76+ The default floating point type is 'float32'. Mixed precision uses the method in the paper:
77+ `J. Hayford, J. Goldman-Wetzler, E. Wang, & L. Lu. Speeding up and reducing memory usage for scientific machine learning via mixed precision.
78+ Computer Methods in Applied Mechanics and Engineering, 428, 117093, 2024 <https://doi.org/10.1016/j.cma.2024.117093>`_.
7579
7680 Args:
77- value (String): 'float16', 'float32', or 'float64' .
81+ value (String): 'float16', 'float32', 'float64', or 'mixed' (mixed precision) .
7882 """
83+ global mixed
7984 if value == "float16" :
8085 print ("Set the default float type to float16" )
8186 real .set_float16 ()
@@ -85,6 +90,20 @@ def set_default_float(value):
8590 elif value == "float64" :
8691 print ("Set the default float type to float64" )
8792 real .set_float64 ()
93+ elif value == "mixed" :
94+ print ("Set the float type to mixed precision of float16 and float32" )
95+ mixed = True
96+ if backend_name == "tensorflow" :
97+ real .set_float16 ()
98+ tf .keras .mixed_precision .set_global_policy ("mixed_float16" )
99+ return # don't try to set it again below
100+ if backend_name == "pytorch" :
101+ # Use float16 during the forward and backward passes, but store in float32
102+ real .set_float32 ()
103+ else :
104+ raise ValueError (
105+ f"{ backend_name } backend does not currently support mixed precision."
106+ )
88107 else :
89108 raise ValueError (f"{ value } not supported in deepXDE" )
90109 if backend_name in ["tensorflow.compat.v1" , "tensorflow" ]:
0 commit comments