Skip to content

Commit 9a79988

Browse files
committed
add norm for jax backend
1 parent 18400e5 commit 9a79988

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

deepxde/backend/jax/tensor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,10 @@ def reduce_max(input_tensor):
165165
return jnp.max(input_tensor)
166166

167167

168+
def norm(tensor, ord=None, axis=None, keepdims=False):
169+
return jnp.linalg.norm(tensor, ord=ord, dim=axis, keepdim=keepdims)
170+
171+
168172
def zeros(shape, dtype):
169173
return jnp.zeros(shape, dtype=dtype)
170174

0 commit comments

Comments
 (0)