Skip to content

Commit ab76ce4

Browse files
committed
address review feedback on filter/create_cell_masks
1 parent aabc5a5 commit ab76ce4

3 files changed

Lines changed: 325 additions & 154 deletions

File tree

src/filter/create_cell_masks/config.vsh.yaml

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ argument_groups:
3434
Individual cell masks are stored as a `DataFrame` in `.obsm` and
3535
summary masks for each filter group are stored as boolean columns in
3636
`.obs`.
37+
__merge__: [., /src/base/h5_compression_argument.yaml]
3738

3839
- name: Parameters
3940
arguments:
@@ -51,10 +52,10 @@ argument_groups:
5152
type: string
5253
required: true
5354
multiple: true
54-
multiple_sep: ","
55-
example: "total_counts:gt:100:rna,cell_volume:lt:10"
55+
multiple_sep: ";"
56+
example: "total_counts:gt:100:rna;cell_volume:lt:10"
5657
summary: |
57-
A comma-separated set of filters to create cell masks in format
58+
A set of filters to create cell masks in format
5859
`<column>:<operator>:<value>:<group>`
5960
description: |
6061
A set of filters to create cell masks. Each filter should be specified
@@ -72,31 +73,36 @@ argument_groups:
7273
selecting cells with `total_counts` greater than 100. The created mask
7374
will be named `rna_total_counts`.
7475
75-
Individual filters are separated by commas.
76+
Individual filters are separated by semicolons.
7677
- name: --prefix
7778
type: string
78-
default: "cell"
79+
required: false
80+
example: "cell"
7981
summary: |
80-
A prefix to use for naming the cell masks `DataFrame` and `.obs`
81-
columns
82+
An optional prefix to use for naming the cell masks `DataFrame` and
83+
`.obs` columns
8284
description: |
83-
A prefix to use for naming the cell masks `DataFrame` and `.obs`
84-
columns.The `DataFrame` containing individual cell masks in `.obsm`
85-
will be named `<prefix>_masks` (e.g., `cell_masks`) and the summary
86-
columns in `.obs` will be named `<prefix>_mask_<group>` (e.g.,
87-
`cell_mask_rna`).
85+
An optional prefix to use for naming the cell masks `DataFrame` and
86+
`.obs` columns. When set, the `DataFrame` containing individual cell
87+
masks in `.obsm` will be named `<prefix>_masks` (e.g., `cell_masks`)
88+
and the summary columns in `.obs` will be named
89+
`<prefix>_mask_<group>` (e.g., `cell_mask_rna`). When omitted, the
90+
names are `masks` and `mask_<group>` (the overall summary column is
91+
simply `mask`).
8892
8993
resources:
9094
- type: python_script
9195
path: script.py
96+
- path: /src/utils/setup_logger.py
97+
- path: /src/utils/compress_h5mu.py
9298

9399
test_resources:
94100
- type: python_script
95101
path: test.py
96102

97103
engines:
98104
- type: docker
99-
image: python:3.12-slim
105+
image: python:3.13-slim
100106
setup:
101107
- type: apt
102108
packages:

src/filter/create_cell_masks/script.py

Lines changed: 87 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,8 @@
11
import sys
22
from operator import eq, ge, gt, le, lt, ne
33

4-
import mudata as md
54
import pandas as pd
6-
7-
################################################################################
8-
# VIASH
9-
################################################################################
5+
from mudata import read_h5ad
106

117
## VIASH START
128
par = {
@@ -18,12 +14,25 @@
1814
"n_genes:gt:500:rna",
1915
],
2016
"prefix": "cell",
17+
"output_compression": None,
2118
}
19+
meta = {"resources_dir": "src/utils/"}
2220
## VIASH END
2321

24-
################################################################################
25-
# FUNCTIONS
26-
################################################################################
22+
sys.path.append(meta["resources_dir"])
23+
from setup_logger import setup_logger # noqa: E402
24+
from compress_h5mu import write_h5ad_to_h5mu_with_compression # noqa: E402
25+
26+
logger = setup_logger()
27+
28+
OPERATORS = {
29+
"lt": {"function": lt, "string": "lt", "symbol": "<"},
30+
"gt": {"function": gt, "string": "gt", "symbol": ">"},
31+
"le": {"function": le, "string": "le", "symbol": "<="},
32+
"ge": {"function": ge, "string": "ge", "symbol": ">="},
33+
"eq": {"function": eq, "string": "eq", "symbol": "=="},
34+
"ne": {"function": ne, "string": "ne", "symbol": "!="},
35+
}
2736

2837

2938
def parse_value(raw_value):
@@ -40,47 +49,34 @@ def parse_value(raw_value):
4049
return raw_value
4150

4251

43-
def parse_operator(operator_string):
44-
operators = {
45-
"lt": {"function": lt, "string": "lt", "symbol": "<"},
46-
"gt": {"function": gt, "string": "gt", "symbol": ">"},
47-
"le": {"function": le, "string": "le", "symbol": "<="},
48-
"ge": {"function": ge, "string": "ge", "symbol": ">="},
49-
"eq": {"function": eq, "string": "eq", "symbol": "=="},
50-
"ne": {"function": ne, "string": "ne", "symbol": "!="},
51-
}
52-
if operator_string not in operators:
52+
def parse_operator(operator_string, filter_string):
53+
if operator_string not in OPERATORS:
5354
raise ValueError(
54-
"Operator must be one of 'lt', 'gt', 'le', 'ge', 'eq', or 'ne'. "
55-
f"Got: {operator_string}."
55+
f"Unknown operator '{operator_string}' in filter '{filter_string}'. "
56+
f"Must be one of: {', '.join(OPERATORS)}."
5657
)
57-
return operators[operator_string]
58+
return OPERATORS[operator_string]
5859

5960

6061
def parse_filters(raw_filters):
61-
if isinstance(raw_filters, str):
62-
raw_filters = [f for f in raw_filters.split(",") if f]
63-
6462
filters = []
6563
for filter_string in raw_filters:
6664
parts = filter_string.split(":")
6765
if len(parts) not in {3, 4}:
6866
raise ValueError(
69-
"Each filter must be formatted as"
70-
"'<column>:<operator>:<value>:<group>' (<group> is optional)."
67+
f"Each filter must be formatted as "
68+
f"'<column>:<operator>:<value>:<group>' (<group> is optional). "
7169
f"Got: '{filter_string}'."
7270
)
7371

74-
column, operator, value = parts[0], parts[1], parts[2]
72+
column, operator_str, value = parts[0], parts[1], parts[2]
7573
group = parts[3] if len(parts) == 4 else None
74+
operator = parse_operator(operator_str, filter_string)
7675

77-
operator = parse_operator(operator)
78-
76+
name_parts = [p for p in (group, column, operator["string"], value) if p]
7977
filters.append(
8078
{
81-
"name": f"{group}_{column}_{operator['string']}_{value}"
82-
if group
83-
else f"{column}_{operator['string']}_{value}",
79+
"name": "_".join(name_parts),
8480
"description": f"{column} {operator['symbol']} {value}"
8581
+ (f" ({group})" if group else ""),
8682
"column": column,
@@ -94,25 +90,22 @@ def parse_filters(raw_filters):
9490

9591

9692
def create_masks(adata, filters):
93+
missing = sorted({f["column"] for f in filters} - set(adata.obs.columns))
94+
if missing:
95+
raise KeyError(
96+
f"The following columns referenced by filters are not in .obs: {missing}"
97+
)
98+
9799
masks = {}
98100
group_masks = {}
99101
overall_mask = pd.Series(True, index=adata.obs.index)
100102

101-
for filter in filters:
102-
column = filter["column"]
103-
104-
if column not in adata.obs.columns:
105-
raise KeyError(f"Column '{column}' not found in adata.obs.")
106-
107-
name = filter["name"]
108-
operator = filter["operator"]
109-
value = filter["value"]
110-
group = filter["group"]
111-
112-
mask = operator(adata.obs[column], value)
113-
masks[name] = mask
103+
for filt in filters:
104+
mask = filt["operator"](adata.obs[filt["column"]], filt["value"])
105+
masks[filt["name"]] = mask
114106
overall_mask &= mask
115107

108+
group = filt["group"]
116109
if group:
117110
if group not in group_masks:
118111
group_masks[group] = pd.Series(True, index=adata.obs.index)
@@ -122,85 +115,76 @@ def create_masks(adata, filters):
122115
group_masks = pd.DataFrame(group_masks, index=adata.obs.index)
123116
group_masks["overall"] = overall_mask
124117

125-
return (masks, group_masks)
126-
127-
128-
################################################################################
129-
# MAIN
130-
################################################################################
118+
return masks, group_masks
131119

132120

133121
def main(par):
134-
print(f"====== Create cell masks (mudata v{md.__version__}) ======", flush=True)
122+
prefix = par["prefix"] or ""
123+
prefix_part = f"{prefix}_" if prefix else ""
135124

136-
print(f"\n>>> Reading MuData from '{par['input']}'...", flush=True)
137-
mdata = md.read_h5mu(par["input"])
138-
print(mdata, flush=True)
139-
140-
print(f"\n>>> Extracting modality '{par['modality']}'...", flush=True)
141-
if par["modality"] not in mdata.mod:
142-
raise KeyError(
143-
f"Modality '{par['modality']}' not found in MuData. "
144-
f"Available modalities: {list(mdata.mod.keys())}"
145-
)
146-
adata = mdata[par["modality"]]
147-
print(adata, flush=True)
125+
logger.info("Reading modality '%s' from '%s'", par["modality"], par["input"])
126+
try:
127+
adata = read_h5ad(par["input"], mod=par["modality"])
128+
except KeyError:
129+
raise ValueError(f"Modality '{par['modality']}' not found in '{par['input']}'.")
148130

149-
print("\n>>> Parsing filters...", flush=True)
131+
logger.info("Parsing %d filter(s)", len(par["filters"]))
150132
filters = parse_filters(par["filters"])
151-
print(f"Parsed {len(filters)} filters:", flush=True)
152-
for filter in filters:
153-
print(f" - {filter['name']}: {filter['description']}", flush=True)
133+
for filt in filters:
134+
logger.info(" - %s: %s", filt["name"], filt["description"])
154135

155-
print("\n>>> Creating masks...", flush=True)
136+
logger.info("Creating masks")
156137
masks, group_masks = create_masks(adata, filters)
157-
print(f"Created {len(masks.columns)} individual masks", flush=True)
158-
print(masks, flush=True)
159-
print(f"\nCreated {len(group_masks.columns)} group masks", flush=True)
160-
print(group_masks, flush=True)
138+
logger.info(
139+
"Created %d individual mask(s) and %d group mask(s)",
140+
len(masks.columns),
141+
len(group_masks.columns),
142+
)
161143

162-
print("\n>>> Adding masks to AnnData...", flush=True)
163-
obsm_name = f"{par['prefix']}_masks"
144+
obsm_name = f"{prefix_part}masks"
164145
adata.obsm[obsm_name] = masks
165-
print(f"Individual masks stored in obsm['{obsm_name}']", flush=True)
166-
print(adata.obsm[obsm_name], flush=True)
146+
logger.info("Stored individual masks in .obsm['%s']", obsm_name)
167147

168148
group_mask_names = []
169149
for group in group_masks.columns:
170-
if group == "overall":
171-
mask_name = f"{par['prefix']}_mask"
172-
else:
173-
mask_name = f"{par['prefix']}_mask_{group}"
174-
150+
mask_suffix = "" if group == "overall" else f"_{group}"
151+
mask_name = f"{prefix_part}mask{mask_suffix}"
175152
adata.obs[mask_name] = group_masks[group]
176153
adata.obsm[obsm_name][group] = group_masks[group]
177154
group_mask_names.append(mask_name)
155+
logger.info("Stored group masks in .obs: %s", group_mask_names)
178156

179-
print(f"\nGroup masks stored in obs with prefix '{par['prefix']}_mask'", flush=True)
180-
print(adata.obs[group_mask_names], flush=True)
181-
182-
print("\n>>> Adding filters to AnnData...", flush=True)
183-
filters_name = f"{par['prefix']}_filters"
157+
filters_name = f"{prefix_part}filters"
184158
filters_records = [
185159
{
186-
"name": filter["name"],
187-
"description": filter["description"],
188-
"column": filter["column"],
189-
"operator": filter["operator"].__name__,
190-
"value": filter["value"],
191-
"group": filter["group"],
160+
"name": filt["name"],
161+
"description": filt["description"],
162+
"column": filt["column"],
163+
"operator": filt["operator"].__name__,
164+
"value": filt["value"],
165+
"group": filt["group"],
192166
}
193-
for filter in filters
167+
for filt in filters
194168
]
195-
adata.uns[filters_name] = pd.DataFrame(filters_records)
196-
print(f"Filters stored in uns['{filters_name}']", flush=True)
197-
print(adata.uns[filters_name], flush=True)
198-
199-
print(f"\n>>> Writing output to '{par['output']}'...", flush=True)
200-
print(mdata, flush=True)
201-
mdata.write_h5mu(par["output"])
202-
203-
print("\n>>> Done!\n")
169+
filters_df = pd.DataFrame(filters_records)
170+
# Empty string for ungrouped filters: h5 cannot write Python None as a
171+
# string and anndata does not opt into nullable string writing by default.
172+
filters_df["group"] = filters_df["group"].fillna("").astype(str)
173+
adata.uns[filters_name] = filters_df
174+
logger.info("Stored filter definitions in .uns['%s']", filters_name)
175+
176+
logger.info(
177+
"Writing output to '%s' with compression '%s'",
178+
par["output"],
179+
par["output_compression"],
180+
)
181+
write_h5ad_to_h5mu_with_compression(
182+
output_file=par["output"],
183+
h5mu=par["input"],
184+
modality_name=par["modality"],
185+
modality_data=adata,
186+
output_compression=par["output_compression"],
187+
)
204188

205189

206190
if __name__ == "__main__":

0 commit comments

Comments
 (0)