Skip to content

Commit 862737b

Browse files
committed
add pow support for jax
1 parent b185a24 commit 862737b

File tree

4 files changed

+12
-4
lines changed

4 files changed

+12
-4
lines changed

pinnicle/physics/physics.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from ..parameter import PhysicsParameter
44
from . import EquationBase
55
import itertools
6-
from ..utils import slice_column, jacobian
6+
from ..utils import slice_column, jacobian, ppow
77

88
class Physics:
99
""" All the physics in used as constraint in the PINN
@@ -80,7 +80,7 @@ def vel_mag(self, nn_input_var, nn_output_var, X):
8080
vid = self.output_var.index('v')
8181
u = slice_column(nn_output_var, uid)
8282
v = slice_column(nn_output_var, vid)
83-
vel = bkd.pow((bkd.square(u) + bkd.square(v) + 1.0e-30), 0.5)
83+
vel = ppow((bkd.square(u) + bkd.square(v) + 1.0e-30), 0.5)
8484
return vel
8585

8686
def surf_x(self, nn_input_var, nn_output_var, X):

pinnicle/utils/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from .helper import *
2-
from .backends_specified import slice_column, jacobian, slice_function_jax, matmul
2+
from .backends_specified import slice_column, jacobian, slice_function_jax, matmul, ppow
33
from .history import History
44
from .data_misfit import get
55
from .plotting import plot_solutions, plot_dict_data, plot_data, plot_nn, tripcolor_similarity, tripcolor_residuals, diffplot, resplot, plot_tracks
66
from .plotmodel import *
7-
from .data_interpolation import *
7+
from .data_interpolation import *

pinnicle/utils/backends_specified.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,11 @@ def matmul_jax(A, B):
4545
slice_column = slice_column_tf
4646
jacobian = jacobian_tf
4747
matmul = matmul_tf
48+
ppow = bkd.pow
4849
elif backend_name == "jax":
4950
slice_column = slice_column_jax
5051
jacobian = jacobian_jax
5152
matmul = matmul_jax
53+
ppow = jax.numpy.pow
5254
else:
5355
raise ValueError(f"{backend_name} is not supported by PINNICLE")

tests/test_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,12 @@ def test_slice_column():
8383
assert c.shape == (1,)
8484
assert c[0] == 2
8585

86+
def test_ppow():
87+
a = backend.as_tensor([2.0])
88+
c = pinnicle.utils.backends_specified.ppow(a, 2.0)
89+
assert c == 4.0
90+
91+
8692
def test_interpfrombedmachine():
8793
x = np.array([300025,301025,302025])
8894
y = np.array([-2579975, -2578975, -2577975])

0 commit comments

Comments
 (0)