From 73bbf5cbdf1ac113403d354a5ab692c794c5a042 Mon Sep 17 00:00:00 2001 From: Tom White Date: Tue, 24 Sep 2024 15:05:23 +0100 Subject: [PATCH] Fix outer result dtype (#582) --- cubed/array_api/linalg.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/cubed/array_api/linalg.py b/cubed/array_api/linalg.py index 91b25394..8ca57877 100644 --- a/cubed/array_api/linalg.py +++ b/cubed/array_api/linalg.py @@ -3,6 +3,7 @@ from cubed.array_api.array_object import Array # These functions are in both the main and linalg namespaces +from cubed.array_api.data_type_functions import result_type from cubed.array_api.linear_algebra_functions import ( # noqa: F401 matmul, matrix_transpose, @@ -15,7 +16,9 @@ def outer(x1, x2, /): - return blockwise(nxp.linalg.outer, "ij", x1, "i", x2, "j", dtype=x1.dtype) + return blockwise( + nxp.linalg.outer, "ij", x1, "i", x2, "j", dtype=result_type(x1, x2) + ) class QRResult(NamedTuple):