Skip to content

Commit e81258e

Browse files
authored
Backend JAX: Add norm (#1950)
1 parent 179c1b9 commit e81258e

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

deepxde/backend/jax/tensor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,10 @@ def reduce_max(input_tensor):
165165
return jnp.max(input_tensor)
166166

167167

168+
def norm(tensor, ord=None, axis=None, keepdims=False):
169+
return jnp.linalg.norm(tensor, ord=ord, axis=axis, keepdims=keepdims)
170+
171+
168172
def zeros(shape, dtype):
169173
return jnp.zeros(shape, dtype=dtype)
170174

0 commit comments

Comments
 (0)