Skip to content

Commit 2236965

Browse files
authored
Backend TensorFlow: Fix POD basis dtype (#1620)
1 parent f403119 commit 2236965

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

deepxde/nn/tensorflow/deeponet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,7 @@ def __init__(
486486
regularization=None,
487487
):
488488
super().__init__()
489-
self.pod_basis = tf.convert_to_tensor(pod_basis, dtype=tf.float32)
489+
self.pod_basis = tf.convert_to_tensor(pod_basis, dtype=config.real(tf))
490490
if isinstance(activation, dict):
491491
activation_branch = activation["branch"]
492492
self.activation_trunk = activations.get(activation["trunk"])

0 commit comments

Comments
 (0)