Skip to content

Commit 68497e5

Browse files
Merge pull request #15 from matchms/add_tfidf_to_umap
Add tfidf option to umap
2 parents 5fc9e6c + dd1f1fb commit 68497e5

File tree

2 files changed

+96
-30
lines changed

2 files changed

+96
-30
lines changed

chemap/plotting/chem_space_umap.py

Lines changed: 52 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,11 @@
33
import numpy as np
44
import pandas as pd
55
from chemap import FingerprintConfig, compute_fingerprints
6-
from chemap.fingerprint_conversions import fingerprints_to_csr
6+
from chemap.fingerprint_conversions import (
7+
fingerprints_to_csr,
8+
fingerprints_to_tfidf,
9+
idf_normalized,
10+
)
711
from chemap.metrics import (
812
tanimoto_distance_dense,
913
tanimoto_distance_sparse,
@@ -53,7 +57,7 @@ def create_chem_space_umap(
5357
fpgen: Optional[Any] = None,
5458
fingerprint_config: Optional[FingerprintConfig] = None,
5559
show_progress: bool = True,
56-
log_count: bool = False,
60+
scaling: str = None,
5761
# UMAP (CPU / umap-learn)
5862
n_neighbors: int = 100,
5963
min_dist: float = 0.25,
@@ -80,9 +84,9 @@ def create_chem_space_umap(
8084
FingerprintConfig(count=True, folded=False, invalid_policy="raise")
8185
show_progress:
8286
Forwarded to compute_fingerprints.
83-
log_count:
84-
If True, apply np.log1p to counts (works for sparse CSR and dense arrays).
85-
(For binary fingerprints this is harmless)
87+
scaling:
88+
Define scaling for count fingerprints. Default is None, which means no scaling.
89+
Can be set to "log" for log1p scaling, or to "tfidf" for TF-IDF scaling of bits.
8690
n_neighbors, min_dist, umap_random_state:
8791
Standard UMAP parameters.
8892
n_jobs:
@@ -137,14 +141,20 @@ def create_chem_space_umap(
137141

138142
if not fingerprint_config.folded:
139143
# Convert to CSR matrix
140-
fps_csr = fingerprints_to_csr(fingerprints).X
144+
if scaling == "tfidf":
145+
fps_csr = fingerprints_to_tfidf(fingerprints).X
146+
else:
147+
fps_csr = fingerprints_to_csr(fingerprints).X
141148

142-
if log_count:
143-
# Works well for count fingerprints ( for binary it's essentially unchanged).
144-
fps_csr = _log1p_csr_inplace(fps_csr)
149+
if scaling == "log":
150+
fps_csr = _log1p_csr_inplace(fps_csr)
145151

146152
coords = reducer.fit_transform(fps_csr)
147153
else:
154+
if scaling == "log":
155+
fingerprints = np.log1p(fingerprints)
156+
elif scaling == "tfidf":
157+
fingerprints *= idf_normalized((fingerprints > 0).sum(axis=0), fingerprints.shape[0])
148158
coords = reducer.fit_transform(fingerprints)
149159

150160
df[x_col] = coords[:, 0]
@@ -163,13 +173,39 @@ def create_chem_space_umap_gpu(
163173
fpgen: Optional[Any] = None,
164174
fingerprint_config: Optional[FingerprintConfig] = None,
165175
show_progress: bool = True,
166-
log_count: bool = False,
176+
scaling: str = None,
167177
# UMAP (GPU / cuML)
168178
n_neighbors: int = 100,
169179
min_dist: float = 0.25,
170180
) -> pd.DataFrame:
171181
"""Compute fingerprints and create 2D UMAP coordinates using cuML (GPU).
172182
183+
Parameters
184+
----------
185+
data:
186+
Input dataframe containing a SMILES column.
187+
col_smiles:
188+
Name of the SMILES column.
189+
inplace:
190+
If True, write x/y columns into `data` and return it. Else returns a copy.
191+
x_col, y_col:
192+
Output coordinate column names.
193+
fpgen:
194+
RDKit fingerprint generator. Defaults to Morgan radius=9, fpSize=4096.
195+
fingerprint_config:
196+
FingerprintConfig for chemap.compute_fingerprints. Defaults to:
197+
FingerprintConfig(count=True, folded=False, invalid_policy="raise")
198+
show_progress:
199+
Forwarded to compute_fingerprints.
200+
scaling:
201+
Define scaling for count fingerprints. Default is None, which means no scaling.
202+
Can be set to "log" for log1p scaling, or to "tfidf" for TF-IDF scaling of bits.
203+
n_neighbors, min_dist, umap_random_state:
204+
Standard UMAP parameters.
205+
n_jobs:
206+
Passed to umap-learn UMAP for parallelism. Ignores random_state when n_jobs != 1.
207+
Default -1 uses all CPUs.
208+
173209
Notes
174210
-----
175211
- cuML UMAP here is fixed to metric="cosine"
@@ -222,12 +258,12 @@ def create_chem_space_umap_gpu(
222258
)
223259

224260
# Reduce memory footprint (works well for count fingerprints)
225-
if not log_count:
226-
# stays integer-like
227-
fps = fingerprints.astype(np.int8, copy=False)
261+
if scaling == "log":
262+
fingerprints = np.log1p(fingerprints).astype(np.float32, copy=False)
263+
elif scaling == "tfidf":
264+
fingerprints *= idf_normalized((fingerprints > 0).sum(axis=0), fingerprints.shape[0])
228265
else:
229-
# log1p returns float
230-
fps = np.log1p(fingerprints).astype(np.float32, copy=False)
266+
fingerprints = fingerprints.astype(np.int8, copy=False)
231267

232268
umap_model = cuUMAP(
233269
n_neighbors=int(n_neighbors),
@@ -238,7 +274,7 @@ def create_chem_space_umap_gpu(
238274
n_components=2,
239275
)
240276

241-
coords = umap_model.fit_transform(fps)
277+
coords = umap_model.fit_transform(fingerprints)
242278

243279
# cuML may return cupy/cudf-backed arrays; np.asarray makes it safe for pandas columns.
244280
coords_np = np.asarray(coords)

chemap/plotting/scatter_plots.py

Lines changed: 44 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ class ScatterStyle:
3636
alpha: float = 0.25
3737
linewidths: float = 0.0
3838

39+
display_legend: bool = True
40+
legend_outside: bool = False
41+
3942
legend_title: Optional[str] = None
4043
legend_loc: str = "lower left"
4144
legend_frameon: bool = False
@@ -132,21 +135,40 @@ def scatter_plot_base(
132135
ax.set_xlabel("")
133136
ax.set_ylabel("")
134137

135-
legend_title = style.legend_title if style.legend_title is not None else label_col
136-
handles = _build_legend_handles(
137-
legend_labels,
138-
palette,
139-
markersize=style.legend_markersize,
140-
alpha=style.legend_alpha,
141-
)
138+
# ---- legend (optional + outside option) ----
139+
if style.display_legend:
140+
legend_title = style.legend_title if style.legend_title is not None else label_col
141+
handles = _build_legend_handles(
142+
legend_labels,
143+
palette,
144+
markersize=style.legend_markersize,
145+
alpha=style.legend_alpha,
146+
)
142147

143-
ax.legend(
144-
handles=handles,
145-
title=legend_title,
146-
loc=style.legend_loc,
147-
frameon=style.legend_frameon,
148-
ncol=style.legend_ncol,
149-
)
148+
if style.legend_outside:
149+
# Put legend outside right; loc controls anchor point of legend box itself.
150+
ax.legend(
151+
handles=handles,
152+
title=legend_title,
153+
loc="center left",
154+
bbox_to_anchor=(1.02, 0.5),
155+
frameon=style.legend_frameon,
156+
ncol=style.legend_ncol,
157+
borderaxespad=0.0,
158+
)
159+
# Leave room on the right so legend isn't clipped
160+
fig.tight_layout(rect=(0, 0, 0.85, 1))
161+
else:
162+
ax.legend(
163+
handles=handles,
164+
title=legend_title,
165+
loc=style.legend_loc,
166+
frameon=style.legend_frameon,
167+
ncol=style.legend_ncol,
168+
)
169+
fig.tight_layout()
170+
else:
171+
fig.tight_layout()
150172

151173
fig.tight_layout()
152174
return fig, ax
@@ -174,6 +196,8 @@ def scatter_plot_all_classes(
174196
s: float = 5.0,
175197
alpha: float = 0.25,
176198
linewidths: float = 0.0,
199+
display_legend: bool = True,
200+
legend_outside: bool = False,
177201
legend_title: Optional[str] = None,
178202
legend_loc: str = "lower left",
179203
legend_frameon: bool = False,
@@ -243,6 +267,8 @@ def scatter_plot_all_classes(
243267
s=s,
244268
alpha=alpha,
245269
linewidths=linewidths,
270+
display_legend=display_legend,
271+
legend_outside=legend_outside,
246272
legend_title=legend_title if legend_title is not None else subclass_col,
247273
legend_loc=legend_loc,
248274
legend_frameon=legend_frameon,
@@ -300,6 +326,8 @@ def scatter_plot_hierarchical_labels(
300326
s: float = 2.0,
301327
alpha: float = 0.2,
302328
linewidths: float = 0.0,
329+
display_legend: bool = True,
330+
legend_outside: bool = False,
303331
legend_title: str = "Class / Superclass",
304332
legend_loc: str = "lower left",
305333
legend_frameon: bool = False,
@@ -398,6 +426,8 @@ def scatter_plot_hierarchical_labels(
398426
s=s,
399427
alpha=alpha,
400428
linewidths=linewidths,
429+
display_legend=display_legend,
430+
legend_outside=legend_outside,
401431
legend_title=legend_title,
402432
legend_loc=legend_loc,
403433
legend_frameon=legend_frameon,

0 commit comments

Comments
 (0)