Skip to content

Commit 3b60094

Browse files
committed
transform warning
1 parent db40b9c commit 3b60094

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

gpjax/parameters.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def unpack(self):
3737
def initialise(model, key: PRNGKeyType = None, **kwargs) -> ParameterState:
3838
"""Initialise the stateful parameters of any GPJax object. This function also returns the trainability status of each parameter and set of bijectors that allow parameters to be constrained and unconstrained."""
3939
if key is None:
40-
warn("No PRNGKey specified. Defaulting to seed 123.")
40+
warn("No PRNGKey specified. Defaulting to seed 123.", UserWarning, stacklevel=2)
4141
key = jr.PRNGKey(123)
4242
params = model._initialise_params(key)
4343
if kwargs:
@@ -183,6 +183,12 @@ def transform(params: tp.Dict, transform_map: tp.Dict) -> tp.Dict:
183183
Returns:
184184
tp.Dict: A transformed parameter set.s The dictionary is equal in structure to the input params dictionary.
185185
"""
186+
warn(
187+
"`transform` will be deprecated in a future release. As of v0.5.0, please use `constrain`"
188+
" or `unconstrain` instead.",
189+
DeprecationWarning,
190+
stacklevel=2,
191+
)
186192
return jax.tree_util.tree_map(
187193
lambda param, trans: trans(param), params, transform_map
188194
)

0 commit comments

Comments
 (0)