30 | 30 | from numpyro.infer import ( |
31 | 31 | MCMC, |
32 | 32 | NUTS, |
| 33 | + Predictive, |
33 | 34 | ) |
34 | 35 |
35 | 36 | import gpjax as gpx |
67 | 68 | plt.scatter(x, y, label="Data", alpha=0.6) |
68 | 69 | plt.plot(x, y_clean, "k--", label="True Signal") |
69 | 70 | plt.legend() |
70 | | -plt.show() |
| 71 | +# plt.show() |
71 | 72 |
72 | 73 | # %% [markdown] |
73 | 74 | # ## Model Definition |
74 | 75 | # |
75 | 76 | # We define a GP model with a zero mean function (the linear trend is handled explicitly in the NumPyro model) and a kernel that is the product of an RBF kernel and a periodic kernel. This choice encodes our prior knowledge that the signal is locally periodic: the periodic factor captures the repeating pattern, while the RBF factor lets that pattern vary over time. |
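| | +# |
| | +# As a rough sketch of the assumed parameterisation (GPJax's exact form may differ in constants), the product kernel is |
| | +# |
| | +# $$k(x, x') = \sigma^2 \exp\left(-\frac{(x - x')^2}{2\ell_{\mathrm{rbf}}^2}\right) \exp\left(-\frac{2\sin^2(\pi (x - x') / p)}{\ell_{\mathrm{per}}^2}\right),$$ |
| | +# |
| | +# so samples repeat with period $p$ while the repeating pattern itself decorrelates over the lengthscale $\ell_{\mathrm{rbf}}$. |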
76 | 77 |
77 | 78 | # %% |
78 | | -kernel = gpx.kernels.RBF() * gpx.kernels.Periodic() |
| 79 | +# Define priors |
| 80 | +lengthscale_prior = dist.LogNormal(0.0, 1.0) |
| 81 | +variance_prior = dist.LogNormal(0.0, 1.0) |
| 82 | +period_prior = dist.LogNormal(0.0, 0.5) |
| 83 | +noise_prior = dist.LogNormal(0.0, 1.0) |
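| | +# Note: LogNormal(mu, sigma) has median exp(mu), so each prior above has its |
| | +# median at 1.0; the smaller sigma on period_prior encodes more confidence |
| | +# that the period is close to 1. |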
| 84 | + |
| 85 | +# Define Kernel with priors |
| 86 | +# We can explicitly attach priors to the parameters |
| 87 | +kernel = gpx.kernels.RBF( |
| 88 | + lengthscale=gpx.parameters.PositiveReal(1.0, prior=lengthscale_prior), |
| 89 | + variance=gpx.parameters.PositiveReal(1.0, prior=variance_prior), |
| 90 | +) * gpx.kernels.Periodic( |
| 91 | + lengthscale=gpx.parameters.PositiveReal(1.0, prior=lengthscale_prior), |
| 92 | + period=gpx.parameters.PositiveReal(1.0, prior=period_prior), |
| 93 | +) |
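| | + |
| | +# Optional sanity check (a sketch; the Gram-matrix API can vary across GPJax |
| | +# versions): evaluating the product kernel on the N training inputs should |
| | +# give an (N, N) covariance matrix. |
| | +# K = kernel.gram(x)  # uncomment to inspect, e.g. K.shape == (N, N) |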
| 94 | + |
79 | 95 | meanf = gpx.mean_functions.Zero() |
80 | 96 | prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel) |
81 | 97 |
82 | 98 | # We will use a ConjugatePosterior since we assume Gaussian noise |
83 | | -likelihood = gpx.likelihoods.Gaussian(num_datapoints=N) |
| 99 | +likelihood = gpx.likelihoods.Gaussian( |
| 100 | + num_datapoints=N, |
| 101 | + obs_stddev=gpx.parameters.NonNegativeReal(1.0, prior=noise_prior), |
| 102 | +) |
84 | 103 | posterior = prior * likelihood |
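| | +# (Conjugacy means the GP predictive used later is available in closed form; |
| | +# MCMC only needs to sample the hyperparameters, not the latent function.) |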
85 | 104 |
86 | 105 | # We initialise the model parameters. |
@@ -111,7 +130,8 @@ def model(X, Y): |
111 | 130 | # 2. Register GP parameters |
112 | 131 | # This samples a value for each of the GPJax model's parameters (from |
113 | 132 | # its attached prior) and returns the model with those values set. |
114 | | - # We can specify custom priors if needed, but we'll rely on defaults here. |
| 133 | + # We attached priors to the parameters during model definition, |
| 134 | + # so register_parameters will use those. |
115 | 135 | # register_parameters modifies the model in-place (and returns it). |
116 | 136 | # Since Numpyro re-runs this function, we are overwriting the parameters |
117 | 137 | # of the same object repeatedly, which is fine as they are completely determined |
@@ -150,67 +170,59 @@ def model(X, Y): |
150 | 170 | samples = mcmc.get_samples() |
151 | 171 |
152 | 172 |
153 | | -# Helper to get predictions |
154 | | -def predict(rng_key, sample_idx): |
155 | | - # Reconstruct model with sampled values |
156 | | - |
157 | | - # Linear part |
158 | | - slope = samples["slope"][sample_idx] |
159 | | - intercept = samples["intercept"][sample_idx] |
160 | | - trend = slope * x + intercept |
| 173 | +def predict_fn(X_new, Y_train): |
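| | +    # When this function runs under numpyro's Predictive (see below), sample |
| | +    # sites ("slope", "intercept", and the GP parameters registered inside |
| | +    # register_parameters) are substituted with posterior samples rather than |
| | +    # drawn fresh from their priors. |
| | + |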
| 174 | + # 1. Sample linear model parameters |
| 175 | + slope = numpyro.sample("slope", dist.Normal(0.0, 2.0)) |
| 176 | + intercept = numpyro.sample("intercept", dist.Normal(0.0, 2.0)) |
161 | 177 |
162 | | - # GP part |
163 | | - # We use numpyro.handlers.substitute to inject the sampled values into register_parameters |
164 | | - # to reconstruct the GP model state for this sample. |
165 | | - sample_dict = {k: v[sample_idx] for k, v in samples.items()} |
| 178 | +    # Calculate residuals of the training targets after subtracting the |
| | +    # sampled linear trend (x is the global training input) |
| 179 | +    trend_train = slope * x + intercept |
| 180 | +    residuals = Y_train - trend_train |
166 | 181 |
167 | | - with numpyro.handlers.substitute(data=sample_dict): |
168 | | - # We call register_parameters again to update the posterior object with this sample's values |
169 | | - p_posterior = register_parameters(posterior) |
| 182 | + # 2. Register GP parameters |
| 183 | + p_posterior = register_parameters(posterior) |
170 | 184 |
171 | | - # Now predict on residuals |
172 | | - residuals = y - trend |
| 185 | + # Create dataset for residuals |
173 | 186 | D_resid = gpx.Dataset(X=x, y=residuals) |
174 | 187 |
175 | | - latent_dist = p_posterior.predict(x, train_data=D_resid) |
176 | | - predictive_mean = latent_dist.mean |
177 | | - predictive_std = latent_dist.stddev() |
| 188 | + # 3. Compute latent GP distribution |
| 189 | + latent_dist = p_posterior.predict(X_new, train_data=D_resid) |
178 | 190 |
179 | | - return trend + predictive_mean, predictive_std |
| 191 | + # 4. Sample latent function values |
| 192 | + f = numpyro.sample("f", latent_dist) |
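| | +    # Reshape to a column vector to match X_new: adding a flat (N,) array |
| | +    # to an (N, 1) column would broadcast to (N, N). |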
| 193 | + f = f.reshape((-1, 1)) |
180 | 194 |
| 195 | + # 5. Compute and return total prediction |
| 196 | + total_prediction = slope * X_new + intercept + f |
| 197 | + numpyro.deterministic("y_pred", total_prediction) |
| 198 | + return total_prediction |
181 | 199 |
182 | | -# Plot |
183 | | -plt.figure(figsize=(12, 6)) |
184 | | -plt.scatter(x, y, alpha=0.5, label="Data", color="gray") |
185 | | -plt.plot(x, y_clean, "k--", label="True Signal") |
186 | 200 |
187 | | -# Compute mean prediction (using mean of samples for efficiency) |
188 | | -mean_slope = jnp.mean(samples["slope"]) |
189 | | -mean_intercept = jnp.mean(samples["intercept"]) |
190 | | -mean_trend = mean_slope * x + mean_intercept |
| 201 | +# Create predictive utility |
| 202 | +predictive = Predictive(predict_fn, posterior_samples=samples) |
191 | 203 |
192 | | -mean_samples = {k: jnp.mean(v, axis=0) for k, v in samples.items()} |
193 | | -with numpyro.handlers.substitute(data=mean_samples): |
194 | | - p_posterior_mean = register_parameters(posterior) |
| 204 | +# Generate predictions |
| 205 | +predictions = predictive(jr.key(1), X_new=x, Y_train=y) |
| 206 | +y_pred = predictions["y_pred"] |
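| | +# Predictive replays predict_fn once per posterior sample, so y_pred stacks |
| | +# one prediction per sample and has shape (num_samples, N, 1). |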
195 | 207 |
196 | | -residuals_mean = y - mean_trend |
197 | | -D_resid_mean = gpx.Dataset(X=x, y=residuals_mean) |
198 | | -latent_dist = p_posterior_mean.predict(x, train_data=D_resid_mean) |
199 | | -pred_mean = latent_dist.mean |
200 | | -pred_std = latent_dist.stddev() |
| 208 | +# Compute statistics |
| 209 | +mean_prediction = jnp.mean(y_pred, axis=0) |
| 210 | +std_prediction = jnp.std(y_pred, axis=0) |
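| | +# The 2 * std band below is an approximate 95% interval: it assumes the pooled |
| | +# samples are roughly Gaussian, and it excludes observation noise since f is |
| | +# the latent function. |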
201 | 211 |
202 | | -total_mean = mean_trend.flatten() + pred_mean.flatten() |
203 | | -std_flat = pred_std.flatten() |
| 212 | +# Plot |
| 213 | +plt.figure(figsize=(12, 6)) |
| 214 | +plt.scatter(x, y, alpha=0.5, label="Data", color="gray") |
| 215 | +plt.plot(x, y_clean, "k--", label="True Signal") |
204 | 216 |
205 | | -plt.plot(x, total_mean, "b-", label="Posterior Mean") |
| 217 | +plt.plot(x, mean_prediction, "b-", label="Posterior Mean") |
206 | 218 | plt.fill_between( |
207 | 219 | x.flatten(), |
208 | | - total_mean - 2 * std_flat, |
209 | | - total_mean + 2 * std_flat, |
| 220 | + mean_prediction.flatten() - 2 * std_prediction.flatten(), |
| 221 | + mean_prediction.flatten() + 2 * std_prediction.flatten(), |
210 | 222 | color="b", |
211 | 223 | alpha=0.2, |
212 | | - label="95% CI (GP Uncertainty)", |
| 224 | + label="95% CI (trend + GP uncertainty)", |
213 | 225 | ) |
214 | 226 |
215 | 227 | plt.legend() |
216 | | -plt.show() |
| 228 | +# plt.show() |