11import sys
22from operator import eq , ge , gt , le , lt , ne
33
4- import mudata as md
54import pandas as pd
6-
7- ################################################################################
8- # VIASH
9- ################################################################################
5+ from mudata import read_h5ad
106
117## VIASH START
128par = {
1814 "n_genes:gt:500:rna" ,
1915 ],
2016 "prefix" : "cell" ,
17+ "output_compression" : None ,
2118}
19+ meta = {"resources_dir" : "src/utils/" }
2220## VIASH END
2321
24- ################################################################################
25- # FUNCTIONS
26- ################################################################################
22+ sys .path .append (meta ["resources_dir" ])
23+ from setup_logger import setup_logger # noqa: E402
24+ from compress_h5mu import write_h5ad_to_h5mu_with_compression # noqa: E402
25+
26+ logger = setup_logger ()
27+
28+ OPERATORS = {
29+ "lt" : {"function" : lt , "string" : "lt" , "symbol" : "<" },
30+ "gt" : {"function" : gt , "string" : "gt" , "symbol" : ">" },
31+ "le" : {"function" : le , "string" : "le" , "symbol" : "<=" },
32+ "ge" : {"function" : ge , "string" : "ge" , "symbol" : ">=" },
33+ "eq" : {"function" : eq , "string" : "eq" , "symbol" : "==" },
34+ "ne" : {"function" : ne , "string" : "ne" , "symbol" : "!=" },
35+ }
2736
2837
2938def parse_value (raw_value ):
@@ -40,47 +49,34 @@ def parse_value(raw_value):
4049 return raw_value
4150
4251
43- def parse_operator (operator_string ):
44- operators = {
45- "lt" : {"function" : lt , "string" : "lt" , "symbol" : "<" },
46- "gt" : {"function" : gt , "string" : "gt" , "symbol" : ">" },
47- "le" : {"function" : le , "string" : "le" , "symbol" : "<=" },
48- "ge" : {"function" : ge , "string" : "ge" , "symbol" : ">=" },
49- "eq" : {"function" : eq , "string" : "eq" , "symbol" : "==" },
50- "ne" : {"function" : ne , "string" : "ne" , "symbol" : "!=" },
51- }
52- if operator_string not in operators :
52+ def parse_operator (operator_string , filter_string ):
53+ if operator_string not in OPERATORS :
5354 raise ValueError (
54- "Operator must be one of 'lt', 'gt', 'le', 'ge', 'eq', or 'ne '. "
55- f"Got : { operator_string } ."
55+ f"Unknown operator ' { operator_string } ' in filter ' { filter_string } '. "
56+ f"Must be one of : { ', ' . join ( OPERATORS ) } ."
5657 )
57- return operators [operator_string ]
58+ return OPERATORS [operator_string ]
5859
5960
6061def parse_filters (raw_filters ):
61- if isinstance (raw_filters , str ):
62- raw_filters = [f for f in raw_filters .split ("," ) if f ]
63-
6462 filters = []
6563 for filter_string in raw_filters :
6664 parts = filter_string .split (":" )
6765 if len (parts ) not in {3 , 4 }:
6866 raise ValueError (
69- "Each filter must be formatted as"
70- "'<column>:<operator>:<value>:<group>' (<group> is optional)."
67+ f "Each filter must be formatted as "
68+ f "'<column>:<operator>:<value>:<group>' (<group> is optional). "
7169 f"Got: '{ filter_string } '."
7270 )
7371
74- column , operator , value = parts [0 ], parts [1 ], parts [2 ]
72+ column , operator_str , value = parts [0 ], parts [1 ], parts [2 ]
7573 group = parts [3 ] if len (parts ) == 4 else None
74+ operator = parse_operator (operator_str , filter_string )
7675
77- operator = parse_operator (operator )
78-
76+ name_parts = [p for p in (group , column , operator ["string" ], value ) if p ]
7977 filters .append (
8078 {
81- "name" : f"{ group } _{ column } _{ operator ['string' ]} _{ value } "
82- if group
83- else f"{ column } _{ operator ['string' ]} _{ value } " ,
79+ "name" : "_" .join (name_parts ),
8480 "description" : f"{ column } { operator ['symbol' ]} { value } "
8581 + (f" ({ group } )" if group else "" ),
8682 "column" : column ,
@@ -94,25 +90,22 @@ def parse_filters(raw_filters):
9490
9591
9692def create_masks (adata , filters ):
93+ missing = sorted ({f ["column" ] for f in filters } - set (adata .obs .columns ))
94+ if missing :
95+ raise KeyError (
96+ f"The following columns referenced by filters are not in .obs: { missing } "
97+ )
98+
9799 masks = {}
98100 group_masks = {}
99101 overall_mask = pd .Series (True , index = adata .obs .index )
100102
101- for filter in filters :
102- column = filter ["column" ]
103-
104- if column not in adata .obs .columns :
105- raise KeyError (f"Column '{ column } ' not found in adata.obs." )
106-
107- name = filter ["name" ]
108- operator = filter ["operator" ]
109- value = filter ["value" ]
110- group = filter ["group" ]
111-
112- mask = operator (adata .obs [column ], value )
113- masks [name ] = mask
103+ for filt in filters :
104+ mask = filt ["operator" ](adata .obs [filt ["column" ]], filt ["value" ])
105+ masks [filt ["name" ]] = mask
114106 overall_mask &= mask
115107
108+ group = filt ["group" ]
116109 if group :
117110 if group not in group_masks :
118111 group_masks [group ] = pd .Series (True , index = adata .obs .index )
@@ -122,85 +115,76 @@ def create_masks(adata, filters):
122115 group_masks = pd .DataFrame (group_masks , index = adata .obs .index )
123116 group_masks ["overall" ] = overall_mask
124117
125- return (masks , group_masks )
126-
127-
128- ################################################################################
129- # MAIN
130- ################################################################################
118+ return masks , group_masks
131119
132120
133121def main (par ):
134- print (f"====== Create cell masks (mudata v{ md .__version__ } ) ======" , flush = True )
122+ prefix = par ["prefix" ] or ""
123+ prefix_part = f"{ prefix } _" if prefix else ""
135124
136- print (f"\n >>> Reading MuData from '{ par ['input' ]} '..." , flush = True )
137- mdata = md .read_h5mu (par ["input" ])
138- print (mdata , flush = True )
139-
140- print (f"\n >>> Extracting modality '{ par ['modality' ]} '..." , flush = True )
141- if par ["modality" ] not in mdata .mod :
142- raise KeyError (
143- f"Modality '{ par ['modality' ]} ' not found in MuData. "
144- f"Available modalities: { list (mdata .mod .keys ())} "
145- )
146- adata = mdata [par ["modality" ]]
147- print (adata , flush = True )
125+ logger .info ("Reading modality '%s' from '%s'" , par ["modality" ], par ["input" ])
126+ try :
127+ adata = read_h5ad (par ["input" ], mod = par ["modality" ])
128+ except KeyError :
129+ raise ValueError (f"Modality '{ par ['modality' ]} ' not found in '{ par ['input' ]} '." )
148130
149- print ( " \n >>> Parsing filters... " , flush = True )
131+ logger . info ( " Parsing %d filter(s) " , len ( par [ "filters" ]) )
150132 filters = parse_filters (par ["filters" ])
151- print (f"Parsed { len (filters )} filters:" , flush = True )
152- for filter in filters :
153- print (f" - { filter ['name' ]} : { filter ['description' ]} " , flush = True )
133+ for filt in filters :
134+ logger .info (" - %s: %s" , filt ["name" ], filt ["description" ])
154135
155- print ( " \n >>> Creating masks..." , flush = True )
136+ logger . info ( " Creating masks" )
156137 masks , group_masks = create_masks (adata , filters )
157- print (f"Created { len (masks .columns )} individual masks" , flush = True )
158- print (masks , flush = True )
159- print (f"\n Created { len (group_masks .columns )} group masks" , flush = True )
160- print (group_masks , flush = True )
138+ logger .info (
139+ "Created %d individual mask(s) and %d group mask(s)" ,
140+ len (masks .columns ),
141+ len (group_masks .columns ),
142+ )
161143
162- print ("\n >>> Adding masks to AnnData..." , flush = True )
163- obsm_name = f"{ par ['prefix' ]} _masks"
144+ obsm_name = f"{ prefix_part } masks"
164145 adata .obsm [obsm_name ] = masks
165- print (f"Individual masks stored in obsm['{ obsm_name } ']" , flush = True )
166- print (adata .obsm [obsm_name ], flush = True )
146+ logger .info ("Stored individual masks in .obsm['%s']" , obsm_name )
167147
168148 group_mask_names = []
169149 for group in group_masks .columns :
170- if group == "overall" :
171- mask_name = f"{ par ['prefix' ]} _mask"
172- else :
173- mask_name = f"{ par ['prefix' ]} _mask_{ group } "
174-
150+ mask_suffix = "" if group == "overall" else f"_{ group } "
151+ mask_name = f"{ prefix_part } mask{ mask_suffix } "
175152 adata .obs [mask_name ] = group_masks [group ]
176153 adata .obsm [obsm_name ][group ] = group_masks [group ]
177154 group_mask_names .append (mask_name )
155+ logger .info ("Stored group masks in .obs: %s" , group_mask_names )
178156
179- print (f"\n Group masks stored in obs with prefix '{ par ['prefix' ]} _mask'" , flush = True )
180- print (adata .obs [group_mask_names ], flush = True )
181-
182- print ("\n >>> Adding filters to AnnData..." , flush = True )
183- filters_name = f"{ par ['prefix' ]} _filters"
157+ filters_name = f"{ prefix_part } filters"
184158 filters_records = [
185159 {
186- "name" : filter ["name" ],
187- "description" : filter ["description" ],
188- "column" : filter ["column" ],
189- "operator" : filter ["operator" ].__name__ ,
190- "value" : filter ["value" ],
191- "group" : filter ["group" ],
160+ "name" : filt ["name" ],
161+ "description" : filt ["description" ],
162+ "column" : filt ["column" ],
163+ "operator" : filt ["operator" ].__name__ ,
164+ "value" : filt ["value" ],
165+ "group" : filt ["group" ],
192166 }
193- for filter in filters
167+ for filt in filters
194168 ]
195- adata .uns [filters_name ] = pd .DataFrame (filters_records )
196- print (f"Filters stored in uns['{ filters_name } ']" , flush = True )
197- print (adata .uns [filters_name ], flush = True )
198-
199- print (f"\n >>> Writing output to '{ par ['output' ]} '..." , flush = True )
200- print (mdata , flush = True )
201- mdata .write_h5mu (par ["output" ])
202-
203- print ("\n >>> Done!\n " )
169+ filters_df = pd .DataFrame (filters_records )
170+ # Empty string for ungrouped filters: h5 cannot write Python None as a
171+ # string and anndata does not opt into nullable string writing by default.
172+ filters_df ["group" ] = filters_df ["group" ].fillna ("" ).astype (str )
173+ adata .uns [filters_name ] = filters_df
174+ logger .info ("Stored filter definitions in .uns['%s']" , filters_name )
175+
176+ logger .info (
177+ "Writing output to '%s' with compression '%s'" ,
178+ par ["output" ],
179+ par ["output_compression" ],
180+ )
181+ write_h5ad_to_h5mu_with_compression (
182+ output_file = par ["output" ],
183+ h5mu = par ["input" ],
184+ modality_name = par ["modality" ],
185+ modality_data = adata ,
186+ output_compression = par ["output_compression" ],
187+ )
204188
205189
206190if __name__ == "__main__" :
0 commit comments