Skip to content

Commit c0809b4

Browse files
authored
Merge pull request #178 from JaxGaussianProcesses/init_params
` init_params` revamp, remove test from `./gpjax`
2 parents 3bbc8cb + 6f69f22 commit c0809b4

19 files changed

+218
-476
lines changed

docs/_api.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ process objects.
3030
.. autoclass:: AbstractPrior
3131
:members:
3232
:special-members: __call__
33-
:private-members: _initialise_params
33+
:private-members: init_params
3434
:exclude-members: from_tuple, replace, to_tuple
3535

3636
.. autoclass:: AbstractPosterior

examples/graph_kernels.pct.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@
8585
kernel = jk.GraphKernel(laplacian=L)
8686
prior = gpx.Prior(kernel=kernel)
8787

88-
true_params = prior._initialise_params(key)
88+
true_params = prior.init_params(key)
8989
true_params["kernel"] = {
9090
"lengthscale": jnp.array(2.3),
9191
"variance": jnp.array(3.2),
@@ -101,7 +101,7 @@
101101
kernel.compute_engine.gram
102102

103103
# %%
104-
kernel.gram(params=kernel._initialise_params(key), inputs=x)
104+
kernel.gram(params=kernel.init_params(key), inputs=x)
105105

106106
# %% [markdown]
107107
#

examples/haiku.pct.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,12 +107,16 @@ def __call__(
107107

108108
def initialise(self, dummy_x: Float[Array, "1 D"], key: jr.KeyArray) -> None:
109109
nn_params = self.network.init(rng=key, x=dummy_x)
110-
base_kernel_params = self.base_kernel._initialise_params(key)
110+
base_kernel_params = self.base_kernel.init_params(key)
111111
self._params = {**nn_params, **base_kernel_params}
112112

113-
def _initialise_params(self, key: jr.KeyArray) -> Dict:
113+
def init_params(self, key: jr.KeyArray) -> Dict:
114114
return self._params
115115

116+
# This is depreciated. Can be removed once JaxKern is updated.
117+
def _initialise_params(self, key: jr.KeyArray) -> Dict:
118+
return self.init_params(key)
119+
116120

117121
# %% [markdown]
118122
# ### Defining a network

examples/kernels.pct.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
# format_version: '1.3'
1010
# jupytext_version: 1.11.2
1111
# kernelspec:
12-
# display_name: Python 3.9.7 ('gpjax')
12+
# display_name: base
1313
# language: python
1414
# name: python3
1515
# ---
@@ -97,7 +97,7 @@
9797

9898
# %%
9999
print(f"ARD: {slice_kernel.ard}")
100-
print(f"Lengthscales: {slice_kernel._initialise_params(key)['lengthscale']}")
100+
print(f"Lengthscales: {slice_kernel.init_params(key)['lengthscale']}")
101101

102102
# %% [markdown]
103103
# We'll now simulate some data and evaluate the kernel on the previously selected input dimensions.
@@ -107,7 +107,7 @@
107107
x_matrix = jr.normal(key, shape=(50, 5))
108108

109109
# Default parameter dictionary
110-
params = slice_kernel._initialise_params(key)
110+
params = slice_kernel.init_params(key)
111111

112112
# Compute the Gram matrix
113113
K = slice_kernel.gram(params, x_matrix)
@@ -127,9 +127,9 @@
127127
sum_k = k1 + k2
128128

129129
fig, ax = plt.subplots(ncols=3, figsize=(20, 5))
130-
im0 = ax[0].matshow(k1.gram(k1._initialise_params(key), x).to_dense())
131-
im1 = ax[1].matshow(k2.gram(k2._initialise_params(key), x).to_dense())
132-
im2 = ax[2].matshow(sum_k.gram(sum_k._initialise_params(key), x).to_dense())
130+
im0 = ax[0].matshow(k1.gram(k1.init_params(key), x).to_dense())
131+
im1 = ax[1].matshow(k2.gram(k2.init_params(key), x).to_dense())
132+
im2 = ax[2].matshow(sum_k.gram(sum_k.init_params(key), x).to_dense())
133133

134134
fig.colorbar(im0, ax=ax[0])
135135
fig.colorbar(im1, ax=ax[1])
@@ -144,10 +144,10 @@
144144
prod_k = k1 * k2 * k3
145145

146146
fig, ax = plt.subplots(ncols=4, figsize=(20, 5))
147-
im0 = ax[0].matshow(k1.gram(k1._initialise_params(key), x).to_dense())
148-
im1 = ax[1].matshow(k2.gram(k2._initialise_params(key), x).to_dense())
149-
im2 = ax[2].matshow(k3.gram(k3._initialise_params(key), x).to_dense())
150-
im3 = ax[3].matshow(prod_k.gram(prod_k._initialise_params(key), x).to_dense())
147+
im0 = ax[0].matshow(k1.gram(k1.init_params(key), x).to_dense())
148+
im1 = ax[1].matshow(k2.gram(k2.init_params(key), x).to_dense())
149+
im2 = ax[2].matshow(k3.gram(k3.init_params(key), x).to_dense())
150+
im3 = ax[3].matshow(prod_k.gram(prod_k.init_params(key), x).to_dense())
151151

152152
fig.colorbar(im0, ax=ax[0])
153153
fig.colorbar(im1, ax=ax[1])
@@ -218,9 +218,13 @@ def __call__(
218218
K = (1 + tau * t / self.c) * jnp.clip(1 - t / self.c, 0, jnp.inf) ** tau
219219
return K.squeeze()
220220

221-
def _initialise_params(self, key: jr.PRNGKey) -> dict:
221+
def init_params(self, key: jr.KeyArray) -> dict:
222222
return {"tau": jnp.array([4.0])}
223223

224+
# This is depreciated. Can be removed once JaxKern is updated.
225+
def _initialise_params(self, key: jr.KeyArray) -> Dict:
226+
return self.init_params(key)
227+
224228

225229
# %% [markdown]
226230
# We unpack this now to make better sense of it. In the kernel's `__init__`

gpjax/config.py

Lines changed: 26 additions & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -13,138 +13,33 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515

16-
import jax
17-
import distrax as dx
18-
import jax.numpy as jnp
19-
import jax.random as jr
20-
import tensorflow_probability.substrates.jax.bijectors as tfb
21-
from ml_collections import ConfigDict
2216

23-
__config = None
17+
import deprecation
2418

25-
Identity = dx.Lambda(forward=lambda x: x, inverse=lambda x: x)
26-
Softplus = dx.Lambda(
27-
forward=lambda x: jnp.log(1 + jnp.exp(x)),
28-
inverse=lambda x: jnp.log(jnp.exp(x) - 1.0),
19+
depreciate = deprecation.deprecated(
20+
deprecated_in="0.5.6",
21+
removed_in="0.6.0",
22+
details="Use method from jaxutils.config instead.",
2923
)
3024

31-
32-
def reset_global_config() -> None:
33-
global __config
34-
__config = get_default_config()
35-
36-
37-
def get_global_config() -> ConfigDict:
38-
"""Get the global config file used within GPJax.
39-
40-
Returns:
41-
ConfigDict: A `ConfigDict` describing parameter transforms and default values.
42-
"""
43-
global __config
44-
45-
if __config is None:
46-
__config = get_default_config()
47-
return __config
48-
49-
# If the global config is available, check if the x64 state has changed
50-
x64_state = jax.config.x64_enabled
51-
52-
# If the x64 state has not changed, return the existing global config
53-
if x64_state is __config.x64_state:
54-
return __config
55-
56-
# If the x64 state has changed, return the updated global config
57-
update_x64_sensitive_settings()
58-
return __config
59-
60-
61-
def update_x64_sensitive_settings() -> None:
62-
"""Update the global config if x64 state changes."""
63-
global __config
64-
65-
# Update the x64 state
66-
x64_state = jax.config.x64_enabled
67-
__config.x64_state = x64_state
68-
69-
# Update the x64 sensitive bijectors
70-
FillScaleTriL = dx.Chain(
71-
[
72-
tfb.FillScaleTriL(diag_shift=jnp.array(__config.jitter)),
73-
]
74-
)
75-
76-
transformations = __config.transformations
77-
transformations.triangular_transform = FillScaleTriL
78-
79-
80-
def get_default_config() -> ConfigDict:
81-
"""Construct and return the default config file.
82-
83-
Returns:
84-
ConfigDict: A `ConfigDict` describing parameter transforms and default values.
85-
"""
86-
87-
config = ConfigDict(type_safe=False)
88-
config.key = jr.PRNGKey(123)
89-
90-
# Set the x64 state
91-
config.x64_state = jax.config.x64_enabled
92-
93-
# Covariance matrix stabilising jitter
94-
config.jitter = 1e-6
95-
96-
FillScaleTriL = dx.Chain(
97-
[
98-
tfb.FillScaleTriL(diag_shift=jnp.array(config.jitter)),
99-
]
100-
)
101-
102-
# Default bijections
103-
config.transformations = transformations = ConfigDict()
104-
transformations.positive_transform = Softplus
105-
transformations.identity_transform = Identity
106-
transformations.triangular_transform = FillScaleTriL
107-
108-
# Default parameter transforms
109-
transformations.alpha = "positive_transform"
110-
transformations.lengthscale = "positive_transform"
111-
transformations.variance = "positive_transform"
112-
transformations.smoothness = "positive_transform"
113-
transformations.shift = "positive_transform"
114-
transformations.obs_noise = "positive_transform"
115-
transformations.latent = "identity_transform"
116-
transformations.basis_fns = "identity_transform"
117-
transformations.offset = "identity_transform"
118-
transformations.inducing_inputs = "identity_transform"
119-
transformations.variational_mean = "identity_transform"
120-
transformations.variational_root_covariance = "triangular_transform"
121-
transformations.natural_vector = "identity_transform"
122-
transformations.natural_matrix = "identity_transform"
123-
transformations.expectation_vector = "identity_transform"
124-
transformations.expectation_matrix = "identity_transform"
125-
126-
return config
127-
128-
129-
# This function is created for testing purposes only
130-
def get_global_config_if_exists() -> ConfigDict:
131-
"""Get the global config file used within GPJax if it is available.
132-
133-
Returns:
134-
ConfigDict: A `ConfigDict` describing parameter transforms and default values.
135-
"""
136-
global __config
137-
return __config
138-
139-
140-
def add_parameter(param_name: str, bijection: dx.Bijector) -> None:
141-
"""Add a parameter and its corresponding transform to GPJax's config file.
142-
143-
Args:
144-
param_name (str): The name of the parameter that is to be added.
145-
bijection (dx.Bijector): The bijection that should be used to unconstrain the parameter's value.
146-
"""
147-
lookup_name = f"{param_name}_transform"
148-
get_global_config()
149-
__config.transformations[lookup_name] = bijection
150-
__config.transformations[param_name] = lookup_name
25+
from jaxutils import config
26+
27+
Identity = config.Identity
28+
Softplus = config.Softplus
29+
reset_global_config = depreciate(config.reset_global_config)
30+
get_global_config = depreciate(config.get_global_config)
31+
get_default_config = depreciate(config.get_default_config)
32+
update_x64_sensitive_settings = depreciate(config.update_x64_sensitive_settings)
33+
get_global_config_if_exists = depreciate(config.get_global_config_if_exists)
34+
add_parameter = depreciate(config.add_parameter)
35+
36+
__all__ = [
37+
"Identity",
38+
"Softplus",
39+
"reset_global_config",
40+
"get_global_config",
41+
"get_default_config",
42+
"update_x64_sensitive_settings",
43+
"get_global_config_if_exists",
44+
"set_global_config",
45+
]

0 commit comments

Comments
 (0)