Skip to content

Commit c0c4674

Browse files
committed
Align parameter tagging with Flax conventions
1 parent d1ebeae commit c0c4674

File tree

4 files changed

+11
-3
lines changed

4 files changed

+11
-3
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,3 +153,4 @@ node_modules/
153153

154154
docs/api
155155
docs/_examples
156+
local_libs/

examples/regression.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@
217217
# this, we use our defined `posterior` and `likelihood` at our test inputs to obtain
218218
# the predictive distribution as a `Distrax` multivariate Gaussian upon which `mean`
219219
# and `stddev` can be used to extract the predictive mean and standard deviatation.
220-
#
220+
#
221221
# We are only concerned here about the variance between the test points and themselves, so
222222
# we can just copute the diagonal version of the covariance. We enforce this by using
223223
# `return_covariance_type = "diagonal"` in the `predict` call.

gpjax/parameters.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,14 @@ def __init__(self, value: T, tag: ParameterTag, **kwargs):
7777
_check_is_arraylike(value)
7878

7979
super().__init__(value=jnp.asarray(value), **kwargs)
80-
self.tag = tag
80+
81+
# nnx.Variable metadata must be set via set_metadata (direct setattr is disallowed).
82+
self.set_metadata(tag=tag)
83+
84+
@property
85+
def tag(self) -> ParameterTag:
86+
"""Return the parameter's constraint tag."""
87+
return self.metadata.get("tag", "real")
8188

8289

8390
class NonNegativeReal(Parameter[T]):

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ help = "Check code formatting and style"
140140

141141
# Testing tasks
142142
[tool.poe.tasks.test]
143-
cmd = "pytest . -v -n 8 --beartype-packages='gpjax'"
143+
cmd = "pytest tests -v -n 8 --beartype-packages='gpjax'"
144144
help = "Run tests with pytest"
145145

146146
[tool.poe.tasks.coverage]

0 commit comments

Comments
 (0)