Skip to content

Commit d1ebeae

Browse files
authored
Limit dtype promotion in Constmant mean (#573)
* test: Add failing test for Constant mean function dtype preservation (#523) * fix: Preserve dtype in Constant mean function (#523) * format: run formatter on mean tests
1 parent c5e1fd7 commit d1ebeae

File tree

2 files changed

+18
-3
lines changed

2 files changed

+18
-3
lines changed

gpjax/mean_functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,9 +147,9 @@ def __call__(self, x: Num[Array, "N D"]) -> Float[Array, "N O"]:
147147
Float[Array, "1"]: The evaluated mean function.
148148
"""
149149
if isinstance(self.constant, Parameter):
150-
return jnp.ones((x.shape[0], 1)) * self.constant.value
150+
return jnp.ones((x.shape[0], 1), dtype=x.dtype) * self.constant.value
151151
else:
152-
return jnp.ones((x.shape[0], 1)) * self.constant
152+
return jnp.ones((x.shape[0], 1), dtype=x.dtype) * self.constant
153153

154154

155155
class Zero(Constant):

tests/test_mean_functions.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,10 @@
3838
Constant,
3939
Zero,
4040
)
41-
from gpjax.parameters import Parameter
41+
from gpjax.parameters import (
42+
Parameter,
43+
Real,
44+
)
4245

4346

4447
def test_abstract() -> None:
@@ -323,3 +326,15 @@ def test_zero_mean_function_uses_raw_value():
323326
result = meanf(x)
324327
expected = jnp.array([[0.0], [0.0], [0.0]])
325328
assert jnp.allclose(result, expected)
329+
330+
331+
@pytest.mark.parametrize("dtype", [jnp.float32, jnp.float64])
332+
@pytest.mark.parametrize("partype", [Real, jnp.array])
333+
def test_constant_dtype_preservation(dtype, partype):
334+
"""Test that Constant mean function preserves dtype of the constant."""
335+
x = jnp.arange(5, dtype=dtype).reshape(-1, 1)
336+
constant = partype(jnp.array(3.0, dtype=dtype))
337+
mean_fn = Constant(constant)
338+
mean = mean_fn(x)
339+
340+
assert mean.dtype == dtype

0 commit comments

Comments
 (0)