Skip to content

Commit 18dc7a6

Browse files
committed
Improve 3D UMAP GIF: dark theme, elevation oscillation, higher quality
- Dark background (#0D1117) matches GitHub dark mode - Gentle elevation oscillation (20-35°) adds depth to the rotation - Larger point size, depth shading, rare populations rendered on top - Legend repositioned to lower right, dark-themed - Random seed fixed for reproducibility - Optimised with gifsicle: 1.8MB (was 0.7MB at low quality)
1 parent d27a46b commit 18dc7a6

2 files changed

Lines changed: 38 additions & 29 deletions

File tree

docs/umap_3d_rotation.gif

1.09 MB
Loading

scripts/make_3d_umap_gif.py

Lines changed: 38 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import scanpy as sc
44
import matplotlib.pyplot as plt
5+
import numpy as np
56
import imageio.v3 as iio
67
from pathlib import Path
78
from io import BytesIO
@@ -10,57 +11,64 @@
1011
RESULTS_DIR = Path("results")
1112
DOCS_DIR = Path("docs")
1213

13-
1414
N_FRAMES = 120
1515
FPS = 24
1616

1717

1818
def compute_3d_umap(adata):
19-
"""Compute 3D UMAP embedding."""
20-
sc.tl.umap(adata, n_components=3)
19+
"""Compute 3D UMAP embedding with fixed seed."""
20+
sc.tl.umap(adata, n_components=3, random_state=42)
2121
return adata.obsm["X_umap"]
2222

2323

24-
def render_frame(coords, cell_types, azim, elev=25):
25-
"""Render a single frame of the 3D UMAP at a given azimuth angle."""
26-
fig = plt.figure(figsize=(8, 6), facecolor="white")
27-
ax = fig.add_subplot(111, projection="3d", facecolor="white")
24+
def render_frame(coords, cell_types, azim, elev):
25+
"""Render a single frame of the 3D UMAP."""
26+
fig = plt.figure(figsize=(7, 7), facecolor="#0D1117")
27+
ax = fig.add_subplot(111, projection="3d", facecolor="#0D1117")
28+
29+
# Sort cell types so smaller populations render on top
30+
type_order = cell_types.value_counts().index[::-1]
2831

29-
for ct in cell_types.cat.categories:
32+
for ct in type_order:
3033
mask = cell_types == ct
34+
colour = PALETTE.get(ct, "#AAAAAA")
3135
ax.scatter(
3236
coords[mask, 0], coords[mask, 1], coords[mask, 2],
33-
c=PALETTE.get(ct, "#AAAAAA"),
34-
s=6, alpha=0.8, label=ct, edgecolors="none",
37+
c=colour, s=10, alpha=0.85, label=ct,
38+
edgecolors="none", rasterized=True, depthshade=True,
3539
)
3640

3741
ax.view_init(elev=elev, azim=azim)
42+
43+
# Clean axes — no ticks, no panes, no grid
3844
ax.set_xticks([])
3945
ax.set_yticks([])
4046
ax.set_zticks([])
4147
ax.xaxis.pane.fill = False
4248
ax.yaxis.pane.fill = False
4349
ax.zaxis.pane.fill = False
44-
ax.xaxis.pane.set_edgecolor("#EEEEEE")
45-
ax.yaxis.pane.set_edgecolor("#EEEEEE")
46-
ax.zaxis.pane.set_edgecolor("#EEEEEE")
47-
ax.xaxis.line.set_color("#CCCCCC")
48-
ax.yaxis.line.set_color("#CCCCCC")
49-
ax.zaxis.line.set_color("#CCCCCC")
50-
ax.grid(True, alpha=0.15)
51-
52-
ax.legend(
53-
loc="upper left", fontsize=7, framealpha=0.7,
54-
facecolor="white", edgecolor="#DDDDDD",
55-
markerscale=3,
50+
ax.xaxis.pane.set_edgecolor("#0D1117")
51+
ax.yaxis.pane.set_edgecolor("#0D1117")
52+
ax.zaxis.pane.set_edgecolor("#0D1117")
53+
ax.xaxis.line.set_color("#0D1117")
54+
ax.yaxis.line.set_color("#0D1117")
55+
ax.zaxis.line.set_color("#0D1117")
56+
ax.grid(False)
57+
58+
# Legend — bottom right, out of the way
59+
legend = ax.legend(
60+
loc="lower right", fontsize=8, framealpha=0.85,
61+
facecolor="#161B22", edgecolor="#30363D",
62+
labelcolor="white", markerscale=3,
5663
)
64+
legend.get_frame().set_linewidth(0.5)
5765

58-
ax.set_title("3D UMAP — PBMC Immune Cell Profiling", color="#222222",
59-
fontsize=14, fontweight="bold", pad=10)
66+
ax.set_title("PBMC Immune Cell Profiling — 3D UMAP",
67+
color="white", fontsize=13, fontweight="bold", pad=15)
6068

6169
buf = BytesIO()
62-
fig.savefig(buf, format="png", dpi=100, bbox_inches="tight",
63-
facecolor="white", edgecolor="none")
70+
fig.savefig(buf, format="png", dpi=120, bbox_inches="tight",
71+
facecolor="#0D1117", edgecolor="none", pad_inches=0.3)
6472
plt.close(fig)
6573
buf.seek(0)
6674
return iio.imread(buf)
@@ -71,7 +79,6 @@ def main():
7179
adata = sc.read_h5ad(in_path)
7280
print(f"Loaded {in_path}")
7381

74-
# Need to recompute neighbor graph since preprocessed data is subset
7582
if "neighbors" not in adata.uns:
7683
sc.pp.neighbors(adata, n_neighbors=15, n_pcs=40)
7784

@@ -83,9 +90,11 @@ def main():
8390
frames = []
8491
for i in range(N_FRAMES):
8592
azim = (i / N_FRAMES) * 360
86-
frame = render_frame(coords, cell_types, azim)
93+
# Gentle elevation oscillation: 20° to 35° and back
94+
elev = 27.5 + 7.5 * np.sin(2 * np.pi * i / N_FRAMES)
95+
frame = render_frame(coords, cell_types, azim, elev)
8796
frames.append(frame)
88-
if (i + 1) % 30 == 0:
97+
if (i + 1) % 45 == 0:
8998
print(f" {i + 1}/{N_FRAMES} frames")
9099

91100
DOCS_DIR.mkdir(exist_ok=True)

0 commit comments

Comments
 (0)