Skip to content

Commit aac198b

Browse files
authored
Add better parameter filtering (#530)
* Add better parameter filtering * Clean up conditional * Bump version
1 parent 35dabd7 commit aac198b

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+4429
-784
lines changed

docs/scripts/gen_examples.py

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,11 @@
1515

1616

1717
def process_file(file: Path, out_file: Path | None = None, execute: bool = False):
18-
"""Converts a python file to markdown using jupytext and nbconvert."""
18+
"""Converts a python file to markdown using jupytext and nbconvert.
19+
20+
Raises:
21+
subprocess.CalledProcessError: If the conversion fails.
22+
"""
1923

2024
out_dir = out_file.parent
2125
command = f"cd {out_dir.as_posix()} && "
@@ -30,7 +34,11 @@ def process_file(file: Path, out_file: Path | None = None, execute: bool = False
3034
else:
3135
command += f"jupytext --to markdown {file} --output {out_file}"
3236

33-
subprocess.run(command, shell=True, check=False)
37+
result = subprocess.run(command, shell=True, check=False, capture_output=True, text=True)
38+
if result.returncode != 0:
39+
error_msg = f"Failed to process {file.name}: {result.stderr}"
40+
print(error_msg)
41+
raise subprocess.CalledProcessError(result.returncode, command, output=result.stdout, stderr=result.stderr)
3442

3543

3644
def is_modified(file: Path, out_file: Path):
@@ -63,30 +71,51 @@ def main(args):
6371

6472
print(files)
6573

74+
# Track failures
75+
failures = []
76+
6677
# process files in parallel
6778
if args.parallel:
6879
with ThreadPoolExecutor(max_workers=args.max_workers) as executor:
69-
futures = []
80+
futures = {}
7081
for file in files:
7182
out_file = out_dir / f"{file.stem}.md"
72-
futures.append(
73-
executor.submit(
74-
process_file, file, out_file=out_file, execute=args.execute
75-
)
83+
future = executor.submit(
84+
process_file, file, out_file=out_file, execute=args.execute
7685
)
86+
futures[future] = file
7787

7888
for future in as_completed(futures):
89+
file = futures[future]
7990
try:
8091
future.result()
92+
print(f"Successfully processed: {file.name}")
8193
except Exception as e:
82-
print(f"Error processing file: {e}")
94+
print(f"Error processing {file.name}: {e}")
95+
failures.append((file, e))
8396
else:
8497
for file in files:
8598
out_file = out_dir / f"{file.stem}.md"
86-
process_file(file, out_file=out_file, execute=args.execute)
99+
try:
100+
process_file(file, out_file=out_file, execute=args.execute)
101+
print(f"Successfully processed: {file.name}")
102+
except Exception as e:
103+
print(f"Error processing {file.name}: {e}")
104+
failures.append((file, e))
105+
106+
# Report failures and exit with error code if any failed
107+
if failures:
108+
print(f"\n{len(failures)} file(s) failed to process:")
109+
for file, error in failures:
110+
print(f" - {file.name}")
111+
return 1 # Return non-zero exit code
112+
else:
113+
print(f"\nAll {len(files)} file(s) processed successfully!")
114+
return 0
87115

88116

89117
if __name__ == "__main__":
118+
import sys
90119
project_root = Path(__file__).parents[2]
91120

92121
parser = ArgumentParser()
@@ -99,4 +128,5 @@ def main(args):
99128
parser.add_argument("--parallel", type=bool, default=False)
100129
args = parser.parse_args()
101130

102-
main(args)
131+
exit_code = main(args)
132+
sys.exit(exit_code)

examples/backend.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,19 @@
1+
# ---
2+
# jupyter:
3+
# jupytext:
4+
# cell_metadata_filter: -all
5+
# custom_cell_magics: kql
6+
# text_representation:
7+
# extension: .py
8+
# format_name: percent
9+
# format_version: '1.3'
10+
# jupytext_version: 1.11.2
11+
# kernelspec:
12+
# display_name: .venv
13+
# language: python
14+
# name: python3
15+
# ---
16+
117
# %% [markdown]
218
# # Backend Module Design
319
#
@@ -116,7 +132,7 @@
116132
# the parameter's value using a tree map operation.
117133

118134
# %%
119-
print(constant_param._tag)
135+
print(constant_param.tag)
120136

121137
# %% [markdown]
122138
# For most users, you will not need to worry about this as we provide a set of default
@@ -126,7 +142,7 @@
126142
# see how you can define your own bijectors and parameter types.
127143

128144
# %%
129-
print(DEFAULT_BIJECTION[constant_param._tag])
145+
print(DEFAULT_BIJECTION[constant_param.tag])
130146

131147
# %% [markdown]
132148
# We see here that the Softplus bijector is specified as the default for strictly
@@ -229,7 +245,6 @@
229245
# altering the way in which we invoke `nnx.split`.
230246

231247
# %%
232-
233248
graphdef, positive_reals, other_params = nnx.split(posterior, PositiveReal, ...)
234249
print(positive_reals)
235250

examples/barycentres.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050

5151
with install_import_hook("gpjax", "beartype.beartype"):
5252
import gpjax as gpx
53+
from gpjax.parameters import Parameter
5354

5455

5556
key = jr.key(123)
@@ -179,6 +180,7 @@ def fit_gp(x: jax.Array, y: jax.Array) -> npd.MultivariateNormal:
179180
model=posterior,
180181
objective=nmll,
181182
train_data=D,
183+
trainable=Parameter,
182184
)
183185
latent_dist = opt_posterior.predict(xtest, train_data=D)
184186
return opt_posterior.likelihood(latent_dist)

examples/classification.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,18 @@
4040
import optax as ox
4141

4242
from examples.utils import use_mpl_style
43-
from gpjax.linalg import lower_cholesky, PSD, solve
43+
from gpjax.linalg import (
44+
PSD,
45+
lower_cholesky,
46+
solve,
47+
)
4448

4549
config.update("jax_enable_x64", True)
4650

4751

4852
with install_import_hook("gpjax", "beartype.beartype"):
4953
import gpjax as gpx
54+
from gpjax.parameters import Parameter
5055

5156

5257
identity_matrix = jnp.eye
@@ -119,7 +124,6 @@
119124

120125
# %%
121126
optimiser = ox.adam(learning_rate=0.01)
122-
123127
opt_posterior, history = gpx.fit(
124128
model=posterior,
125129
# we use the negative lpd as we are minimising
@@ -128,6 +132,7 @@
128132
optim=ox.adamw(learning_rate=0.01),
129133
num_iters=1000,
130134
key=key,
135+
trainable=Parameter, # train all parameters (default behavior)
131136
)
132137

133138
# %% [markdown]
@@ -224,7 +229,7 @@
224229

225230
# Negative Hessian, H = -∇²p_tilde(y|f):
226231
graphdef, params, *static_state = nnx.split(
227-
opt_posterior, gpx.parameters.Parameter, ...
232+
opt_posterior, Parameter, ...
228233
)
229234

230235

examples/collapsed_vi.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
# extension: .py
88
# format_name: percent
99
# format_version: '1.3'
10-
# jupytext_version: 1.11.2
10+
# jupytext_version: 1.17.3
1111
# kernelspec:
1212
# display_name: .venv
1313
# language: python
@@ -45,6 +45,7 @@
4545

4646
with install_import_hook("gpjax", "beartype.beartype"):
4747
import gpjax as gpx
48+
from gpjax.parameters import Parameter
4849

4950

5051
# set the default style for plotting
@@ -137,6 +138,7 @@
137138
# _optimise_ their location such that the evidence lower bound is maximised.
138139

139140
# %%
141+
# Use the enhanced fit API with trainable parameter filtering
140142
opt_posterior, history = gpx.fit(
141143
model=q,
142144
# we want want to minimize the *negative* ELBO
@@ -145,6 +147,7 @@
145147
optim=ox.adamw(learning_rate=1e-2),
146148
num_iters=500,
147149
key=key,
150+
trainable=Parameter,
148151
)
149152

150153
# %%

examples/constructing_new_kernels.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,14 @@
4141
from gpjax.parameters import (
4242
DEFAULT_BIJECTION,
4343
PositiveReal,
44-
Static,
4544
)
4645

4746
config.update("jax_enable_x64", True)
4847

4948

5049
with install_import_hook("gpjax", "beartype.beartype"):
5150
import gpjax as gpx
51+
from gpjax.parameters import Parameter
5252

5353

5454
# set the default style for plotting
@@ -249,7 +249,7 @@ def log_abs_det_jacobian(self, x, y, intermediates=None):
249249

250250

251251
class Polar(gpx.kernels.AbstractKernel):
252-
period: Static
252+
period: float
253253
tau: PositiveReal
254254

255255
def __init__(
@@ -260,13 +260,13 @@ def __init__(
260260
n_dims: int | None = None,
261261
):
262262
super().__init__(active_dims, n_dims, DenseKernelComputation())
263-
self.period = Static(jnp.array(period))
263+
self.period = jnp.array(period)
264264
self.tau = PositiveReal(jnp.array(tau), tag="polar")
265265

266266
def __call__(
267267
self, x: Float[Array, "1 D"], y: Float[Array, "1 D"]
268268
) -> Float[Array, "1"]:
269-
c = self.period.value / 2.0
269+
c = self.period / 2.0
270270
t = angular_distance(x, y, c)
271271
K = (1 + self.tau.value * t / c) * jnp.clip(
272272
1 - t / c, 0, jnp.inf
@@ -315,6 +315,7 @@ def __call__(
315315
model=circular_posterior,
316316
objective=lambda p, d: -gpx.objectives.conjugate_mll(p, d),
317317
train_data=D,
318+
trainable=Parameter,
318319
)
319320

320321
# %% [markdown]

examples/deep_kernels.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@
6060
with install_import_hook("gpjax", "beartype.beartype"):
6161
import gpjax as gpx
6262
from gpjax.kernels.base import AbstractKernel
63+
from gpjax.parameters import (
64+
Parameter,
65+
)
6366

6467

6568
# set the default style for plotting
@@ -219,13 +222,18 @@ def __call__(self, x: jax.Array) -> jax.Array:
219222
ox.adamw(learning_rate=schedule),
220223
)
221224

225+
# Train all parameters (default behavior with trainable=Parameter)
226+
# Alternative options for selective training:
227+
# - trainable=PositiveReal # only train positive parameters
228+
# - trainable=lambda module, path, value: 'kernel' in path # only kernel params
222229
opt_posterior, history = gpx.fit(
223230
model=posterior,
224231
objective=lambda p, d: -gpx.objectives.conjugate_mll(p, d),
225232
train_data=D,
226233
optim=optimiser,
227234
num_iters=800,
228235
key=key,
236+
trainable=Parameter, # explicitly specify trainable filter (default)
229237
)
230238

231239
# %% [markdown]

examples/graph_kernels.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343

4444
with install_import_hook("gpjax", "beartype.beartype"):
4545
import gpjax as gpx
46+
from gpjax.parameters import Parameter
4647

4748

4849
# set the default style for plotting
@@ -179,6 +180,7 @@
179180
model=posterior,
180181
objective=lambda p, d: -gpx.objectives.conjugate_mll(p, d),
181182
train_data=D,
183+
trainable=Parameter,
182184
)
183185

184186
# %% [markdown]

examples/intro_to_kernels.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,14 @@
3737
from sklearn.preprocessing import StandardScaler
3838

3939
from examples.utils import use_mpl_style
40-
from gpjax.parameters import Static
4140
from gpjax.typing import Array
4241

4342
config.update("jax_enable_x64", True)
4443

4544

4645
with install_import_hook("gpjax", "beartype.beartype"):
4746
import gpjax as gpx
47+
from gpjax.parameters import Parameter
4848

4949

5050
key = jr.key(42)
@@ -264,7 +264,7 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]: # noqa: F821
264264
prior = gpx.gps.Prior(mean_function=mean, kernel=kernel)
265265

266266
likelihood = gpx.likelihoods.Gaussian(
267-
num_datapoints=D.n, obs_stddev=Static(jnp.array(1e-3))
267+
num_datapoints=D.n, obs_stddev=jnp.array(1e-3)
268268
) # Our function is noise-free, so we set the observation noise's standard deviation to a very small value
269269

270270
no_opt_posterior = prior * likelihood
@@ -281,6 +281,7 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]: # noqa: F821
281281
model=no_opt_posterior,
282282
objective=lambda p, d: -gpx.objectives.conjugate_mll(p, d),
283283
train_data=D,
284+
trainable=Parameter,
284285
)
285286

286287

@@ -546,13 +547,17 @@ def loss(posterior, data):
546547
return -gpx.objectives.conjugate_mll(posterior, data)
547548

548549

550+
# Optimize all parameters. Alternative filtering strategies available:
551+
# - trainable=gpx.PositiveReal: train only positive parameters
552+
# - custom filters for specific parameter subsets
549553
opt_posterior, history = gpx.fit(
550554
model=posterior,
551555
objective=loss,
552556
train_data=D,
553557
optim=ox.adamw(learning_rate=1e-2),
554558
num_iters=500,
555559
key=key,
560+
trainable=Parameter, # train all parameters (default)
556561
)
557562

558563

0 commit comments

Comments
 (0)