-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplots.py
More file actions
225 lines (190 loc) · 8.36 KB
/
Copy pathplots.py
File metadata and controls
225 lines (190 loc) · 8.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
"""
Visualization module for Bayesian Real Estate Intelligence.
Generates publication-quality plots using ArviZ and matplotlib.
Plot functions return (fig, ax) or (fig, axes) for composability, except
plot_posterior_comparison and plot_anomaly_scores, which return only the figure.
"""
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns
import arviz as az
# Module-wide matplotlib defaults, applied once when this module is imported.
# Each plt.rc call sets the "<group>.<name>" rcParams for one group.
plt.rc("figure", facecolor="white")
plt.rc("axes", facecolor="#f8f9fa", grid=True, titlesize=13, labelsize=11)
plt.rc("grid", alpha=0.3)
plt.rc("font", size=11)
# Hex colour used for each listing portal. Plotting code in this module looks
# portals up with .get(..., "#666"), so unlisted portals fall back to grey.
PORTAL_COLORS = dict(
    habitaclia="#2196F3",
    fotocasa="#FF9800",
    milanuncios="#4CAF50",
    idealista="#E91E63",
)
def plot_shrinkage(shrinkage_df: pd.DataFrame, save_path: str = None):
    """Visualize partial pooling: portal estimates shrunk toward the group mean.

    Draws one horizontal error bar per row of ``shrinkage_df`` (marker at
    ``alpha_mean``, whiskers reaching ``alpha_hdi_low`` and ``alpha_hdi_high``)
    plus a dashed vertical line at the group mean.

    Args:
        shrinkage_df: One row per portal with columns ``portal``,
            ``alpha_mean``, ``alpha_hdi_low``, ``alpha_hdi_high`` and
            ``group_mu_alpha``. The group mean is read from the first row.
        save_path: If truthy, the figure is also saved to this path.

    Returns:
        The ``(fig, ax)`` pair.

    Raises:
        IndexError: If ``shrinkage_df`` has no rows (no first row to read
            the group mean from).
    """
    fig, ax = plt.subplots(figsize=(10, 5))
    group_mean = shrinkage_df["group_mu_alpha"].iloc[0]

    # Use the row's *position* for the y coordinate, not its index label:
    # the tick positions below are range(len(df)), so a filtered, sorted or
    # non-integer index would otherwise misplace (or fail to draw) the bars.
    for y_pos, (_, row) in enumerate(shrinkage_df.iterrows()):
        color = PORTAL_COLORS.get(row["portal"], "#666")
        ax.errorbar(
            row["alpha_mean"], y_pos,
            xerr=[[row["alpha_mean"] - row["alpha_hdi_low"]],
                  [row["alpha_hdi_high"] - row["alpha_mean"]]],
            fmt="o", color=color, capsize=5, markersize=8, linewidth=2,
            label=row["portal"],
        )

    ax.axvline(group_mean, color="red", linestyle="--", alpha=0.7, label="Group mean")
    ax.set_yticks(range(len(shrinkage_df)))
    ax.set_yticklabels(shrinkage_df["portal"])
    ax.set_xlabel("Intercept (log-price, standardized)")
    ax.set_title("Hierarchical Shrinkage: Portal Intercepts toward Group Mean")
    ax.legend(loc="lower right")
    fig.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
    return fig, ax
def plot_posterior_comparison(trace, var_names, title="Posterior Distributions",
                              save_path=None):
    """Side-by-side posterior density plots.

    Args:
        trace: Inference result passed straight to ``az.plot_posterior``.
        var_names: Variable name(s) to plot, forwarded as ``var_names``.
        title: Figure-level title.
        save_path: If truthy, the figure is also saved to this path.

    Returns:
        The matplotlib figure holding the posterior panels (figure only,
        not a ``(fig, ax)`` pair).
    """
    axes = az.plot_posterior(trace, var_names=var_names, figsize=(14, 4))
    # With a single panel ArviZ hands back a bare Axes instead of an array,
    # which has no .flatten(); np.ravel copes with both shapes.
    first_ax = np.ravel(axes)[0]
    fig = first_ax.get_figure()
    fig.suptitle(title, fontsize=14, y=1.02)
    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
    return fig
def plot_spatial_surface(grid_data: dict, df: pd.DataFrame, save_path: str = None):
    """Draw the GP price surface and its uncertainty as two contour panels.

    The left panel shows ``grid_data["f_mean"]``, the right panel
    ``grid_data["f_std"]``; both are contoured over ``grid_data["lon_grid"]``
    and ``grid_data["lat_grid"]`` with the ``df["lon"]`` / ``df["lat"]``
    points scattered on top.

    Args:
        grid_data: Mapping with keys ``lon_grid``, ``lat_grid``, ``f_mean``
            and ``f_std``.
        df: Frame providing ``lon`` and ``lat`` columns for the overlay.
        save_path: If truthy, the figure is also saved to this path.

    Returns:
        ``(fig, axes)`` where ``axes`` holds the two panels.
    """
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))

    # (grid_data key, colormap, panel title, colorbar label), left to right.
    panel_specs = (
        ("f_mean", "RdYlGn_r", "Predicted Price Surface (mean)",
         "Log-price (standardized)"),
        ("f_std", "Oranges", "Prediction Uncertainty (std)",
         "Std deviation"),
    )

    for panel, (field, cmap_name, heading, bar_label) in zip(axes, panel_specs):
        filled = panel.contourf(
            grid_data["lon_grid"], grid_data["lat_grid"],
            grid_data[field], levels=20, cmap=cmap_name,
        )
        # Overlay the observation locations above the filled contours.
        panel.scatter(df["lon"], df["lat"], c="black", s=5, alpha=0.3, zorder=5)
        panel.set_title(heading)
        panel.set_xlabel("Longitude")
        panel.set_ylabel("Latitude")
        plt.colorbar(filled, ax=panel, label=bar_label)

    fig.suptitle("Gaussian Process Spatial Price Model", fontsize=14)
    fig.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
    return fig, axes
def plot_anomaly_scores(scores_df: pd.DataFrame, save_path: str = None):
    """Visualize anomaly detection results on a 2x2 grid.

    Panels: (1) histogram of ``p_anomaly`` with a line at 0.5,
    (2) ``price`` against ``size_m2`` coloured by ``p_anomaly``,
    (3) number of flagged listings per ``portal``,
    (4) ``residual`` histograms for non-flagged and flagged rows.

    Args:
        scores_df: Frame with columns ``p_anomaly``, ``size_m2``, ``price``,
            ``is_flagged``, ``portal`` and ``residual``. ``is_flagged`` is
            coerced with ``astype(bool)``, so 0/1 flags work as well as
            booleans.
        save_path: If truthy, the figure is also saved to this path.

    Returns:
        The matplotlib figure (figure only, not a ``(fig, ax)`` pair).
    """
    fig = plt.figure(figsize=(14, 8))
    gs = gridspec.GridSpec(2, 2, figure=fig, hspace=0.35, wspace=0.3)

    # Coerce once: indexing with a 0/1 integer column would be treated as a
    # column selection, and `~` on ints/objects is a bitwise NOT (-1/-2)
    # rather than a logical negation.
    flagged_mask = scores_df["is_flagged"].astype(bool)
    normal_mask = ~flagged_mask
    flagged = scores_df[flagged_mask]
    has_flagged = len(flagged) > 0

    # 1. Anomaly score distribution
    ax1 = fig.add_subplot(gs[0, 0])
    ax1.hist(scores_df["p_anomaly"], bins=50, color="#2196F3", edgecolor="white", alpha=0.8)
    ax1.axvline(0.5, color="red", linestyle="--", alpha=0.8, label="Decision boundary")
    ax1.set_xlabel("P(anomaly)")
    ax1.set_ylabel("Count")
    ax1.set_title("Anomaly Score Distribution")
    ax1.legend()

    # 2. Price vs size colored by anomaly score
    ax2 = fig.add_subplot(gs[0, 1])
    scatter = ax2.scatter(
        scores_df["size_m2"], scores_df["price"],
        c=scores_df["p_anomaly"], cmap="RdYlGn_r", s=20, alpha=0.7,
    )
    plt.colorbar(scatter, ax=ax2, label="P(anomaly)")
    ax2.set_xlabel("Size (m²)")
    ax2.set_ylabel("Price (EUR)")
    ax2.set_title("Listings by Anomaly Probability")

    # 3. Top anomalies by portal
    ax3 = fig.add_subplot(gs[1, 0])
    if has_flagged:
        portal_counts = flagged["portal"].value_counts()
        colors = [PORTAL_COLORS.get(p, "#666") for p in portal_counts.index]
        ax3.barh(portal_counts.index, portal_counts.values, color=colors)
        ax3.set_xlabel("Flagged listings")
    else:
        # Axes coordinates keep the message centred whatever the data limits.
        ax3.text(0.5, 0.5, "No anomalies detected", ha="center", va="center",
                 transform=ax3.transAxes)
    # Title the panel in both cases so the empty state is still identifiable.
    ax3.set_title("Anomalies by Portal")

    # 4. Residual distribution with mixture
    ax4 = fig.add_subplot(gs[1, 1])
    # Skip empty groups: a density-normalised histogram of no data is undefined.
    if normal_mask.any():
        ax4.hist(scores_df.loc[normal_mask, "residual"], bins=40, alpha=0.6,
                 color="#4CAF50", label="Normal", density=True)
    if has_flagged:
        ax4.hist(scores_df.loc[flagged_mask, "residual"], bins=20, alpha=0.6,
                 color="#E91E63", label="Anomaly", density=True)
    ax4.set_xlabel("Price residual")
    ax4.set_title("Residual Distribution by Component")
    if ax4.get_legend_handles_labels()[0]:
        ax4.legend()

    fig.suptitle("Bayesian Anomaly Detection Results", fontsize=14)
    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
    return fig
def plot_ppc(trace: az.InferenceData, model_name: str = "", save_path: str = None,
             num_pp_samples: int = 200):
    """Posterior predictive check: overlay observed data with model-generated data.

    Draws two ``az.plot_ppc`` panels side by side: a KDE overlay and a
    cumulative-distribution comparison.

    Args:
        trace: Inference data passed to ``az.plot_ppc``.
        model_name: Appended to each panel title after an em dash; omitted
            entirely when empty.
        save_path: If truthy, the figure is also saved to this path.
        num_pp_samples: Upper bound on the number of predictive samples to
            draw per panel (default 200, the value previously hard-coded).
            ``None`` is forwarded unchanged.

    Returns:
        ``(fig, axes)`` where ``axes`` holds the two panels.
    """
    # ArviZ rejects a request for more predictive samples than the trace
    # holds, so cap the request at chain * draw when that group is present.
    # NOTE(review): assumes posterior_predictive carries "chain"/"draw" dims.
    predictive = getattr(trace, "posterior_predictive", None)
    if predictive is not None and num_pp_samples is not None:
        available = predictive.sizes.get("chain", 1) * predictive.sizes.get("draw", 1)
        num_pp_samples = min(num_pp_samples, available)

    # Avoid a dangling " — " in the titles when no model name is supplied.
    suffix = f" — {model_name}" if model_name else ""

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    panels = (
        ("kde", f"PPC Density Overlay{suffix}"),          # 1. KDE overlay
        ("cumulative", f"PPC Cumulative{suffix}"),        # 2. Cumulative comparison
    )
    for panel_ax, (kind, heading) in zip(axes, panels):
        az.plot_ppc(trace, kind=kind, num_pp_samples=num_pp_samples, ax=panel_ax,
                    colors=["#2196F3", "#E91E63", "#333333"])
        panel_ax.set_title(heading)
        panel_ax.set_xlabel("Observed value (standardized)")

    fig.suptitle("Posterior Predictive Check: Can the model reproduce the data?",
                 fontsize=13)
    fig.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
    return fig, axes
def plot_model_comparison_summary(models: dict, save_path: str = None):
    """Visual summary of model comparison and tractability.

    Left panel: inference time per model, NUTS bars always and ADVI bars
    when at least one model reports a positive ``advi_time``.
    Right panel: median bulk ESS per second of NUTS time, colour-coded
    against the 50 and 20 ESS/s reference lines.

    Args:
        models: Mapping of display name to model object. Each object may
            expose ``nuts_time``, ``advi_time`` and ``trace``; missing or
            ``None`` values are treated as "not available" and plotted as 0.
        save_path: If truthy, the figure is also saved to this path.

    Returns:
        ``(fig, axes)`` where ``axes`` holds the two panels.
    """
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    names = list(models.keys())
    # `or 0` also covers attributes that exist but are still None.
    nuts_times = [getattr(m, "nuts_time", 0) or 0 for m in models.values()]
    advi_times = [getattr(m, "advi_time", 0) or 0 for m in models.values()]

    x = np.arange(len(names))
    width = 0.35
    axes[0].bar(x - width/2, nuts_times, width, label="NUTS (MCMC)", color="#2196F3")
    if any(t > 0 for t in advi_times):
        axes[0].bar(x + width/2, advi_times, width, label="ADVI (VI)", color="#FF9800")
    axes[0].set_xticks(x)
    axes[0].set_xticklabels(names, rotation=15)
    axes[0].set_ylabel("Time (seconds)")
    axes[0].set_title("Inference Time: MCMC vs VI")
    axes[0].legend()

    # ESS/s comparison
    ess_rates = []
    for m in models.values():
        trace = getattr(m, "trace", None)
        elapsed = getattr(m, "nuts_time", None)
        # A rate only makes sense with a trace and a positive sampling time;
        # a zero or None time would otherwise raise on the division.
        if trace is not None and elapsed is not None and elapsed > 0:
            summary = az.summary(trace, kind="diagnostics")
            ess_rates.append(float(summary["ess_bulk"].median() / elapsed))
        else:
            ess_rates.append(0.0)

    colors = ["#4CAF50" if e > 50 else "#FF9800" if e > 20 else "#E91E63" for e in ess_rates]
    axes[1].bar(names, ess_rates, color=colors)
    axes[1].set_ylabel("ESS / second")
    axes[1].set_title("Sampling Efficiency")
    axes[1].axhline(50, color="green", linestyle="--", alpha=0.5, label="Good (>50)")
    axes[1].axhline(20, color="orange", linestyle="--", alpha=0.5, label="Marginal (>20)")
    axes[1].legend()

    fig.suptitle("Tractability Analysis", fontsize=14)
    fig.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
    return fig, axes