Skip to content

Commit 09ec8f8

Browse files
committed
Address PR comments
Fix `map_location`. Replace `annotated_only` with `factor_filter`.
1 parent 4dbd7cf commit 09ec8f8

13 files changed

Lines changed: 66 additions & 50 deletions

docs/notebooks/tcga_brca_bulk_multiomics.ipynb

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -543,7 +543,7 @@
543543
"id": "e6314dd0",
544544
"metadata": {},
545545
"source": [
546-
"We confirm the optimiser converged, then look at how variance is partitioned across the informed pathway factors. We pass `annotated_only=True` to drop the single dense factor, which absorbs the dominant axis of variation; each informed factor explains a smaller, pathway-specific slice."
546+
"We confirm the optimiser converged, then look at how variance is partitioned across the informed pathway factors. We pass a factor-name predicate that keeps factors present in the Hallmark annotation mask, dropping the single dense factor that absorbs the dominant axis of variation; each informed factor explains a smaller, pathway-specific slice."
547547
]
548548
},
549549
{
@@ -611,7 +611,11 @@
611611
}
612612
],
613613
"source": [
614-
"mfl.pl.variance_explained(model, annotated_only=True, figsize=(7, 12))"
614+
"mfl.pl.variance_explained(\n",
615+
" model,\n",
616+
" factor_filter=lambda factor: factor in mdata[\"rna\"].varm[\"gene_set_mask\"].columns,\n",
617+
" figsize=(7, 12),\n",
618+
")"
615619
]
616620
},
617621
{

src/mofaflex/_core/mofaflex.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,7 @@ def load(cls, path: str | Path, map_location=None) -> "MOFAFLEX":
537537

538538
if map_location is not None:
539539
state["train_opts"]["device"] = map_location
540+
map_location = state["train_opts"]["device"]
540541

541542
model = cls.__new__(cls)
542543
model._train_loss_elbo = state["train_loss_elbo"]

src/mofaflex/_core/terms/mofaflex.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def __init__(
115115

116116
def _results_to_df(
117117
self,
118-
results: Mapping[str, np.ndarray],
118+
results: Mapping[str, np.ndarray | pd.DataFrame],
119119
axis: Literal[0, 1],
120120
ordered: bool = False,
121121
factors_subset: slice = slice(None),
@@ -124,14 +124,18 @@ def _results_to_df(
124124
ret = {}
125125
for name, res in results.items():
126126
fnames = factor_names
127+
index = (
128+
res.index
129+
if isinstance(res, pd.DataFrame)
130+
else (self._sample_names[name] if axis == 0 else self._feature_names[name])
131+
)
132+
values = res.to_numpy() if isinstance(res, pd.DataFrame) else res
127133
if ordered:
128134
factor_order = self.factor_order[factors_subset].copy()
129135
factor_order[np.argsort(factor_order)] = np.arange(len(factor_order))
130-
res = res[:, factor_order]
136+
values = values[:, factor_order]
131137
fnames = fnames[factor_order]
132-
ret[name] = pd.DataFrame(
133-
res, index=self._sample_names[name] if axis == 0 else self._feature_names[name], columns=fnames
134-
)
138+
ret[name] = pd.DataFrame(values, index=index, columns=fnames)
135139
return ret
136140

137141
def _wrap_api_method(self, axis: Literal[0, 1], prior: Prior, api: PriorDynamicAPI):
@@ -785,9 +789,7 @@ def _load(
785789
)
786790

787791
self._prior_api_properties = {}
788-
# map_location may be None (e.g. MOFAFLEX.load without an explicit device); fall back to the
789-
# default device so methods that need it (e.g. PCGSE in get_significant_annotations) still work.
790-
self._device = default_torch_device(map_location)
792+
self._device = map_location
791793
self._init_api()
792794

793795
def _get_postprocessed_factors(

src/mofaflex/pl/_plotting.py

Lines changed: 8 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from collections.abc import Mapping, Sequence
3+
from collections.abc import Callable, Mapping, Sequence
44
from contextlib import suppress
55
from functools import partial
66
from typing import TYPE_CHECKING, Literal
@@ -464,7 +464,7 @@ def variance_explained(
464464
model: MOFAFLEX,
465465
group_by: Literal["group", "view"] = "group",
466466
term: str | None = None,
467-
annotated_only: bool = False,
467+
factor_filter: Callable[[str], bool] | None = None,
468468
figsize: tuple[float, float] | None = None,
469469
) -> p9.ggplot:
470470
"""Plot the fraction of variance explained per factor in each group and view.
@@ -475,9 +475,7 @@ def variance_explained(
475475
term: The name of the additive term to plot the variance explained for. If `None` and the model has only one additive term,
476476
will plot the fraction of variance explained per factor for this term. If `None` and the model has multiple terms, will
477477
plot the fraction of variance explained per term.
478-
annotated_only: If `True`, only plot the factors that are informed by prior annotations (e.g. the gene-set factors of an
479-
:class:`~mofaflex.priors.InformedHorseshoe` prior), dropping the uninformed dense factors. Requires the model to
480-
have an informed prior.
478+
factor_filter: Predicate applied to factor names. Only factors for which the predicate returns `True` are plotted.
481479
figsize: Figure size in inches.
482480
"""
483481
if group_by == "group":
@@ -491,19 +489,12 @@ def variance_explained(
491489
figsize = (len(model.group_names) * 3, 5)
492490

493491
byterm = term is None and model.n_terms > 1
494-
if annotated_only and byterm:
495-
raise ValueError("`annotated_only` is only supported at the factor level. Specify a single `term`.")
492+
if factor_filter is not None and byterm:
493+
raise ValueError("`factor_filter` is only supported at the factor level. Specify a single `term`.")
496494
col = "term" if byterm else "component"
497495
df_r2 = model.get_r2("byterm" if byterm else "term", ordered=True, term=term)
498-
if annotated_only:
499-
try:
500-
annotations = model.get_annotations()
501-
except AttributeError as e:
502-
raise ValueError(
503-
"`annotated_only=True` requires a model with an informed prior providing annotations."
504-
) from e
505-
informed_factors = set().union(*(annot.columns for annot in annotations.values()))
506-
df_r2 = df_r2[df_r2[col].isin(informed_factors)]
496+
if factor_filter is not None:
497+
df_r2 = df_r2[df_r2[col].map(factor_filter)]
507498
df_r2 = df_r2.assign(factor=lambda x: pd.Categorical(x[col], categories=x[col].unique()))
508499
heatmap = (
509500
p9.ggplot(df_r2, p9.aes(x=x, y="factor", fill="R2"))
@@ -1144,15 +1135,7 @@ def top_weights(
11441135
and (df.groupby("factor", observed=True)["feature"].aggregate(lambda x: x.duplicated().sum()) > 0).any()
11451136
):
11461137
df = df.assign(feature=lambda x: x.feature.str + "_" + x.view.str)
1147-
# A feature can be among the top features of several factors. With a single global feature category, those
1148-
# shared features share one y position, scrambling the per-facet ordering. Make the category unique per facet
1149-
# (factor) so each facet is ordered by its own |weight|, and strip the suffix again for the axis labels.
1150-
_sep = "\x1f"
1151-
df = df.assign(
1152-
feature=lambda x: pd.Categorical(
1153-
keys := x.feature.astype(str) + _sep + x.factor.astype(str), categories=keys.unique()
1154-
)
1155-
)
1138+
df = df.assign(feature=lambda x: pd.Categorical(x.feature, categories=x.feature.unique()))
11561139

11571140
aes_kwargs = {}
11581141
if have_annot:
@@ -1163,7 +1146,6 @@ def top_weights(
11631146
+ p9.geom_segment()
11641147
+ p9.geom_point(size=5, stroke=0)
11651148
+ p9.scale_shape_manual(values=("$\\oplus$", "$\\ominus$"), breaks=(True, False), guide=None)
1166-
+ p9.scale_y_discrete(labels=lambda breaks: [b.rsplit(_sep, 1)[0] for b in breaks])
11671149
+ _weights_inferred_color_scale
11681150
+ p9.scale_x_continuous(expand=(0, 0, 0.05, 0))
11691151
+ p9.labs(x="| Weight |", y="", color="")
-493 Bytes
Loading
-15 Bytes
Loading
249 Bytes
Loading
-209 Bytes
Loading
157 Bytes
Loading
-578 Bytes
Loading

0 commit comments

Comments
 (0)