File tree Expand file tree Collapse file tree 1 file changed +16
-0
lines changed Expand file tree Collapse file tree 1 file changed +16
-0
lines changed Original file line number Diff line number Diff line change @@ -935,6 +935,22 @@ def op(inputs):
935935 # Clear cached Jacobians and Hessians.
936936 grad .clear ()
937937 y = utils .to_numpy (y )
938+ elif backend_name == "jax" :
939+ if utils .get_num_args (operator ) == 2 :
940+
941+ @jax .jit
942+ def op (inputs ):
943+ y_fn = lambda _x : self .net .apply (self .net .params , _x )
944+ return operator (inputs , (y_fn (inputs ), y_fn ))
945+
946+ elif utils .get_num_args (operator ) == 3 :
947+ # TODO: JAX backend Implementation of Auxiliary variables.
948+ raise NotImplementedError (
949+ "Model.predict() with auxiliary variable hasn't been implemented "
950+ "for backend jax."
951+ )
952+ y = op (x )
953+ y = utils .to_numpy (y )
938954 elif backend_name == "paddle" :
939955 self .net .eval ()
940956 inputs = paddle .to_tensor (x , stop_gradient = False )
You can’t perform that action at this time.
0 commit comments