Skip to content

Commit 4f8fb4f

Browse files
authored
Backend JAX supports model.predict (#1594)
1 parent b08a84e commit 4f8fb4f

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

deepxde/model.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -935,6 +935,22 @@ def op(inputs):
935935
# Clear cached Jacobians and Hessians.
936936
grad.clear()
937937
y = utils.to_numpy(y)
938+
elif backend_name == "jax":
939+
if utils.get_num_args(operator) == 2:
940+
941+
@jax.jit
942+
def op(inputs):
943+
y_fn = lambda _x: self.net.apply(self.net.params, _x)
944+
return operator(inputs, (y_fn(inputs), y_fn))
945+
946+
elif utils.get_num_args(operator) == 3:
947+
# TODO: JAX backend Implementation of Auxiliary variables.
948+
raise NotImplementedError(
949+
"Model.predict() with auxiliary variable hasn't been implemented "
950+
"for backend jax."
951+
)
952+
y = op(x)
953+
y = utils.to_numpy(y)
938954
elif backend_name == "paddle":
939955
self.net.eval()
940956
inputs = paddle.to_tensor(x, stop_gradient=False)

0 commit comments

Comments
 (0)