Skip to content

Commit 7b70b97

Browse files
committed
use torch.as_tensor to preserve gradients
1 parent 0c1738a commit 7b70b97

File tree

2 files changed

+5
-6
lines changed

2 files changed

+5
-6
lines changed

scoringrules/backend/registry.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,9 @@ def __getitem__(self, __key: str) -> ArrayBackend:
5959
"""Get a backend from the registry."""
6060
try:
6161
return super().__getitem__(__key)
62-
except KeyError as err:
63-
raise BackendNotRegistered(
64-
f"The backend '{__key}' is not registered. "
65-
f"You can register it with scoringrules.register_backend('{__key}')"
66-
) from err
62+
except KeyError:
63+
self.register_backend(__key)
64+
return super().__getitem__(__key)
6765

6866
def set_active(self, backend: str):
6967
self._active = backend

scoringrules/backend/torch.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ def asarray(
3232
*,
3333
dtype: Dtype | None = None,
3434
) -> "Tensor":
35-
return torch.asarray(obj, dtype=dtype)
35+
# torch.asarray(obj) would cancel gradients!
36+
return torch.as_tensor(obj, dtype=dtype)
3637

3738
def broadcast_arrays(self, *arrays: "Tensor") -> tuple["Tensor", ...]:
3839
return torch.broadcast_tensors(*arrays)

0 commit comments

Comments
 (0)