Skip to content

Commit a315365

Browse files
committed
add reshape functions to backends
1 parent 0a64f06 commit a315365

File tree

5 files changed

+16
-0
lines changed

5 files changed

+16
-0
lines changed

scoringrules/backend/base.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,3 +310,7 @@ def cov(self, x: "Array", rowvar: bool, bias: bool) -> "Array":
310310
@abc.abstractmethod
311311
def det(self, x: "Array") -> "Array":
312312
"""Return the determinant of a matrix."""
313+
314+
@abc.abstractmethod
315+
def reshape(self, x: "Array", shape: int | tuple[int, ...]) -> "Array":
316+
"""Reshape an array to a new ``shape``."""

scoringrules/backend/jax.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,9 @@ def cov(self, x: "Array", rowvar: bool = True, bias: bool = False) -> "Array":
280280
def det(self, x: "Array") -> "Array":
281281
return jnp.linalg.det(x)
282282

283+
def reshape(self, x: "Array", shape: int | tuple[int, ...]) -> "Array":
284+
return jnp.reshape(x, shape)
285+
283286

284287
if __name__ == "__main__":
285288
B = JaxBackend()

scoringrules/backend/numpy.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,9 @@ def cov(self, x: "NDArray", rowvar: bool = True, bias: bool = False) -> "NDArray
276276
def det(self, x: "NDArray") -> "NDArray":
277277
return np.linalg.det(x)
278278

279+
def reshape(self, x: "NDArray", shape: int | tuple[int, ...]) -> "NDArray":
280+
return np.reshape(x, shape=shape)
281+
279282

280283
class NumbaBackend(NumpyBackend):
281284
"""Numba backend."""

scoringrules/backend/tensorflow.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,9 @@ def cov(self, x: "Tensor", rowvar: bool = True, bias: bool = False) -> "Tensor":
321321
def det(self, x: "Tensor") -> "Tensor":
322322
return tf.linalg.det(x)
323323

324+
def reshape(self, x: "Tensor", shape: int | tuple[int, ...]) -> "Tensor":
325+
return tf.reshape(x, shape)
326+
324327

325328
if __name__ == "__main__":
326329
B = TensorflowBackend()

scoringrules/backend/torch.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,3 +300,6 @@ def cov(self, x: "Tensor", rowvar: bool = True, bias: bool = False) -> "Tensor":
300300

301301
def det(self, x: "Tensor") -> "Tensor":
302302
return torch.linalg.det(x)
303+
304+
def reshape(self, x: "Tensor", shape: int | tuple[int, ...]) -> "Tensor":
305+
return torch.reshape(x, shape)

0 commit comments

Comments
 (0)