Skip to content

Commit e005c84

Browse files
authored
EVoC algorithm (#57)
1 parent 8ffdb33 commit e005c84

4 files changed

Lines changed: 419 additions & 1 deletion

File tree

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ uvx marimo edit --sandbox <notebook>
3131
| FastAPI + GliNER | Zero-shot entity extraction as webapp, API, or CLI. | [![Open in molab](https://marimo.io/molab-shield.svg)](https://molab.marimo.io/github/marimo-team/gallery-examples/blob/main/notebooks/library/fastapi-gliner.py) |
3232
| Chemical Space Explorer | Explore chemical space with RDKit fingerprints, t-SNE, and HDBSCAN clustering. | [![Open in molab](https://marimo.io/molab-shield.svg)](https://molab.marimo.io/github/marimo-team/gallery-examples/blob/main/notebooks/library/chemical-space-explorer.py) |
3333
| Bayesian Regression | Interactive sequential Bayesian linear regression demo. | [![Open in molab](https://marimo.io/molab-shield.svg)](https://molab.marimo.io/github/marimo-team/gallery-examples/blob/main/notebooks/algorithms/bayesian-regression-demo.py) |
34+
| Nested Clusters with EVoC | Explore Fashion MNIST with EVoC hierarchical clusters, parallel coordinates, and a treemap. | [![Open in molab](https://marimo.io/molab-shield.svg)](https://molab.marimo.io/github/marimo-team/gallery-examples/blob/main/notebooks/algorithms/evoc-fashion.py) |
3435

3536
## Research papers
3637

notebooks/algorithms/__marimo__/session/evoc-fashion.py.json

Lines changed: 191 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
# /// script
2+
# requires-python = ">=3.12"
3+
# dependencies = [
4+
# "marimo",
5+
# "polars==1.39.3",
6+
# "numpy==2.4.4",
7+
# "scikit-learn==1.8.0",
8+
# "wigglystuff==0.3.5",
9+
# "matplotlib==3.10.8",
10+
# "pandas==3.0.1",
11+
# "umap-learn==0.5.11",
12+
# "evoc==0.3.1",
13+
# ]
14+
# ///
15+
16+
import marimo
17+
18+
__generated_with = "0.23.2"
19+
app = marimo.App(width="medium")
20+
21+
22+
@app.cell
def _():
    # Import every third-party name the notebook uses once, and return them so
    # the other cells can receive them as parameters.
    import marimo as mo
    import matplotlib.pyplot as plt
    import numpy as np
    import polars as pl
    from evoc import EVoC
    from sklearn.datasets import fetch_openml
    from sklearn.decomposition import PCA
    from umap import UMAP
    from wigglystuff import ParallelCoordinates

    return EVoC, PCA, ParallelCoordinates, UMAP, fetch_openml, mo, np, pl, plt
35+
36+
37+
@app.cell(hide_code=True)
38+
def _(mo):
39+
mo.md(r"""
40+
# Nested Clusters with EVoC
41+
42+
This notebook loads the Fashion MNIST dataset, reduces the 784 pixel features
43+
down to a handful of components, and visualizes them with an interactive
44+
parallel coordinates plot. Use the brushes on each axis to filter and explore
45+
how different clothing categories separate in PCA/UMAP space.
46+
47+
But why stop there? You can also explore clustering methods like [EVoC](https://github.com/TutteInstitute/evoc) that give you a view into nested clusters. These make the parallel coordinates more interesting, and you can explore them with other widgets as well.
48+
""")
49+
return
50+
51+
52+
@app.cell
def _(fetch_openml, np):
    # Fetch Fashion-MNIST from OpenML as plain arrays (as_frame=False) and cast
    # pixels to float32 and targets to int.
    mnist = fetch_openml("Fashion-MNIST", version=1, as_frame=False, parser="auto")
    images = mnist.data.astype(np.float32)
    labels = mnist.target.astype(int)

    # Display name for each integer class id, in id order starting at 0.
    label_names = dict(
        enumerate(
            (
                "T-shirt/top",
                "Trouser",
                "Pullover",
                "Dress",
                "Coat",
                "Sandal",
                "Shirt",
                "Sneaker",
                "Bag",
                "Ankle boot",
            )
        )
    )
    return images, label_names, labels
71+
72+
73+
@app.cell
def _(
    PCA,
    UMAP,
    checkbox,
    images,
    label_names,
    labels,
    n_components_slider,
    n_samples_slider,
    np,
    pl,
):
    # Draw a fixed-seed random subset of images; its size comes from the slider.
    rng = np.random.default_rng(42)
    idx = rng.choice(len(images), size=n_samples_slider.value, replace=False)

    # The checkbox switches the reducer from PCA to UMAP. The variable is named
    # `pca` and the output columns are named "PC<k>" in both cases.
    _n_components = n_components_slider.value
    pca = (
        UMAP(n_components=_n_components)
        if checkbox.value
        else PCA(n_components=_n_components)
    )
    components = pca.fit_transform(images[idx])

    # One column per reduced dimension, followed by the readable class label.
    _component_columns = {
        f"PC{i + 1}": components[:, i] for i in range(_n_components)
    }
    _label_column = pl.Series("label", [label_names[labels[i]] for i in idx])
    df = pl.DataFrame(_component_columns).with_columns(_label_column)
    return df, idx
100+
101+
102+
@app.cell(hide_code=True)
def _(mo):
    # Controls for the sample size, the number of reduced dimensions, and the
    # PCA/UMAP switch. They are read by the dimensionality-reduction cell.
    n_samples_slider = mo.ui.slider(
        start=2500,
        stop=5000,
        step=500,
        value=2500,
        label="Number of samples",
    )
    n_components_slider = mo.ui.slider(
        start=3,
        stop=15,
        step=1,
        value=8,
        label="Components",
    )
    checkbox = mo.ui.checkbox(label="UMAP")

    # Final expression of the cell: the three controls as a list.
    [n_samples_slider, n_components_slider, checkbox]
    return checkbox, n_components_slider, n_samples_slider
111+
112+
113+
@app.cell(hide_code=True)
def _(ParallelCoordinates, df, mo):
    # Parallel-coordinates chart of the reduced components, colored by the
    # "label" column, wrapped so marimo can host it as an anywidget.
    _chart = ParallelCoordinates(df, height=500, color_by="label")
    widget = mo.ui.anywidget(_chart)
    widget
    return
118+
119+
120+
@app.cell(hide_code=True)
121+
def _(mo):
122+
mo.md(r"""
123+
## Now to EVoCe a new trick!
124+
125+
Let's now add the cluster layers to the chart. That already gives you an interesting idea of where you might be able to find clusters.
126+
""")
127+
return
128+
129+
130+
@app.cell
def _(est):
    # Show the cluster layers stored on the estimator that another cell fits.
    est.cluster_layers_
    return
134+
135+
136+
@app.cell(hide_code=True)
def _(EVoC, ParallelCoordinates, df, idx, images, mo, np):
    # Cluster the same sampled images that `df` was built from (images[idx]),
    # so the cluster labels line up row-for-row with `df`.
    est = EVoC(random_state=42)
    est.fit_predict(images[idx])

    # Add the first (up to) three cluster layers as extra axes c0, c1, c2.
    # Slicing instead of indexing [0], [1], [2] avoids an IndexError when the
    # estimator yields fewer than three layers.
    #
    # Each cluster id gets a uniform offset in [0, 1/1.2) so that points
    # sharing an id spread into a band on the axis instead of a single point.
    # A seeded generator keeps the jitter identical across re-runs, matching
    # the seeded sampling used to build `idx`.
    _rng = np.random.default_rng(42)
    _layer_columns = {
        f"c{_i}": _layer + _rng.random(_layer.shape[0]) / 1.2
        for _i, _layer in enumerate(est.cluster_layers_[:3])
    }
    pltr = df.with_columns(**_layer_columns)

    evoc_widget = mo.ui.anywidget(
        ParallelCoordinates(pltr, height=500, color_by="label")
    )
    evoc_widget
    return est, evoc_widget
150+
151+
152+
@app.cell(hide_code=True)
def _(evoc_widget, idx, images, label_names, labels, mo, np, plt):
    # Preview at most ten of the rows currently selected in the EVoC
    # parallel-coordinates widget. Slicing already copes with selections
    # shorter than ten, so no length check is needed.
    _filtered = evoc_widget.selected_indices
    _sample_idx = np.array(_filtered[:10])

    if len(_sample_idx) == 0:
        _preview = mo.md(
            "_Brush an axis above to preview up to 10 images from the selection._"
        )
    else:
        # squeeze=False always returns a 2-D axes array, so a single image
        # needs no special case.
        _fig, _axes = plt.subplots(
            1,
            len(_sample_idx),
            figsize=(2 * len(_sample_idx), 2),
            squeeze=False,
        )
        for _ax, _si in zip(_axes[0], _sample_idx):
            # _si is a row position in the sampled subset; idx maps it back to
            # the position in the full image array.
            _ax.imshow(images[idx[_si]].reshape(28, 28), cmap="gray")
            _ax.set_title(label_names[labels[idx[_si]]], fontsize=9)
            _ax.axis("off")
        _fig.tight_layout()
        _preview = _fig

    # marimo displays only the final top-level expression of a cell. An
    # expression inside if/else is not displayed, so the chosen output is
    # bound to a name and emitted here.
    _preview
    return
170+
171+
172+
@app.cell(hide_code=True)
173+
def _(mo):
174+
mo.md(r"""
175+
## Treemap
176+
177+
You can also explore this data using a treemap. That's what we do below.
178+
""")
179+
return
180+
181+
182+
@app.cell
def _(df, est, mo, pl):
    from wigglystuff import Treemap

    # Use the first (up to) three cluster layers in reverse order, so that with
    # three layers c0 = layer 2, c1 = layer 1, c2 = layer 0. Slicing instead
    # of indexing avoids an IndexError when fewer than three layers exist.
    _layers = list(est.cluster_layers_[:3])[::-1]
    _path_cols = [f"c{_i}" for _i in range(len(_layers))]

    # One row per sampled image: its cluster id at each level, a constant 1,
    # and its row position `r` (used to look the image up again later).
    treemapped = df.select(
        **dict(zip(_path_cols, _layers)),
        n=pl.lit(1),
        r=pl.row_index(),
    )

    # Tile sizes: number of rows for each combination of cluster ids.
    _agg = treemapped.group_by(_path_cols).len().sort("len", descending=True)

    treemap = mo.ui.anywidget(
        Treemap.from_dataframe(
            _agg, path_cols=_path_cols, width="100%", height=500
        )
    )
    treemap
    return treemap, treemapped
193+
194+
195+
@app.cell(hide_code=True)
def _(idx, images, label_names, labels, mo, np, plt, subset):
    # Preview at most ten images from the rows left in `subset`. The "r" column
    # holds each row's position in the sampled subset. Slicing already copes
    # with fewer than ten rows, so no length check is needed.
    _filtered = subset["r"].to_list()
    _sample_idx = np.array(_filtered[:10])

    if len(_sample_idx) == 0:
        _preview = mo.md(
            "_Hover a treemap tile to preview up to 10 images from that cluster._"
        )
    else:
        # squeeze=False always returns a 2-D axes array, so a single image
        # needs no special case.
        _fig, _axes = plt.subplots(
            1,
            len(_sample_idx),
            figsize=(2 * len(_sample_idx), 2),
            squeeze=False,
        )
        for _ax, _si in zip(_axes[0], _sample_idx):
            # idx maps the subset row position back to the full image array.
            _ax.imshow(images[idx[_si]].reshape(28, 28), cmap="gray")
            _ax.set_title(label_names[labels[idx[_si]]], fontsize=9)
            _ax.axis("off")
        _fig.tight_layout()
        _preview = _fig

    # marimo displays only the final top-level expression of a cell. An
    # expression inside if/else is not displayed, so the chosen output is
    # bound to a name and emitted here.
    _preview
    return
213+
214+
215+
@app.cell
def _(pl, treemap, treemapped):
    # Narrow the per-row table to the tile under the cursor. The first entry of
    # the hovered path is skipped; each remaining entry is matched, in order,
    # against columns c0, c1, ... after conversion to int.
    _path = treemap.hovered_path[1:]

    subset = treemapped
    for _depth, _value in enumerate(_path):
        _column = pl.col(f"c{_depth}")
        subset = subset.filter(_column == int(_value))
    return (subset,)
221+
222+
223+
if __name__ == "__main__":
224+
app.run()

scripts/create-sessions.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,9 @@ def run_single(notebook_path: str, cli_args: dict, argv: list[str] | None) -> No
6565
print(" Warning: notebook had errors during execution")
6666

6767
cell_ids = list(file_manager.app.cell_manager.cell_ids())
68-
session_data = serialize_session_view(session_view, cell_ids)
68+
session_data = serialize_session_view(
69+
session_view, cell_ids, drop_virtual_file_outputs=True
70+
)
6971

7072
# Treat ModuleNotFoundError as a hard failure — the session cache
7173
# would be useless if a dependency is missing.

0 commit comments

Comments
 (0)