MODEL_NAME = "Mistral-7B-Instruct-v0.2__cellwhisperer_clip_v1"


-def gene_score_contributions(
-    transcriptome_input: torch.Tensor,
-    text_list_or_text_embeds: Union[List[str], torch.Tensor],
-    logit_scale: float,
-    score_norm_method: str = None,
-) -> pd.Series:
-    """
-    Just a dummy for testing
-    """
-    return pd.Series(
-        {
-            "Gene 1": 0.1,
-            "Gene 2": -0.1,
-        }
-    )
-
-
class CellWhispererWrapper:
    def __init__(self, model_path_or_url: str):
        """
@@ -85,6 +68,8 @@ def preprocess_data(self, adaptor):

    def llm_obs_to_text(self, adaptor, mask):
        """
+        Currently unused in favor of the more advanced chat functionality, but still functional.
+
        Embed the given cells into the LLM space and return their average similarity to different keywords as formatted text.
        Keyword types used for comparison are: (i) selected enrichR terms (see cellwhisperer.validation.zero_shot.functions.write_enrichr_terms_to_json) \
        and (ii) cell type annotations (currently all values in adata.obs.columns). For more info, see cellwhisperer.validation.zero_shot.functions.
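
For context on the docstring above: the keyword comparison boils down to averaging cosine similarities between the selected cells' embeddings and embedded keyword texts, then formatting the ranked result. A minimal standalone sketch of that idea, with hypothetical names (`keyword_similarity_text`, `cell_embeds`, `keyword_embeds` are not from this PR):

```python
# Hypothetical sketch of the keyword-similarity idea; not CellWhisperer's implementation.
from typing import List

import torch
import torch.nn.functional as F


def keyword_similarity_text(
    cell_embeds: torch.Tensor,     # (n_cells, d) embeddings of the selected cells
    keyword_embeds: torch.Tensor,  # (n_keywords, d) embeddings of the keyword texts
    keywords: List[str],
) -> str:
    """Return the mean cell-to-keyword cosine similarity, formatted as text."""
    cells = F.normalize(cell_embeds, dim=-1)
    kws = F.normalize(keyword_embeds, dim=-1)
    mean_sims = (cells @ kws.T).mean(dim=0)  # average over the selected cells
    ranked = sorted(zip(keywords, mean_sims.tolist()), key=lambda kv: -kv[1])
    return "\n".join(f"{kw}: {sim:.3f}" for kw, sim in ranked)
```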
@@ -231,7 +216,7 @@ def _prepare_messages(self, adaptor, messages, mask):
        codes = np.concatenate([top_genes_df[col].cat.codes.values for col in top_genes_df.columns])
        counts = np.bincount(codes, minlength=len(top_genes_df["Top_1"].cat.categories))
        category_counts = pd.Series(counts, index=top_genes_df[top_genes_df.columns[0]].cat.categories)
-        n_top_genes = 50  # TODO number of top genes to list needs to become configurable
+        n_top_genes = 50  # NOTE: number of top genes to list should be configurable
        top_genes = category_counts.sort_values(ascending=False).index[:n_top_genes].to_list()

        # Initialize the conversation
@@ -250,7 +235,7 @@ def _prepare_messages(self, adaptor, messages, mask):
        ]
        state.offset = 2

-        # TODO the transcriptome is added too late. consider changing
+        # NOTE: the transcriptome is added too late; consider changing this

        for i, message in enumerate(messages):
            if i == 0:
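
The `codes`/`bincount` block in the hunk above tallies how often each gene appears across the per-cell `Top_N` columns. A toy, self-contained illustration of the same pattern (the example data and column names are made up; like the real code, it assumes all columns share one categorical dtype):

```python
# Toy illustration of counting gene occurrences across Top_N categorical columns.
import numpy as np
import pandas as pd

top_genes_df = pd.DataFrame({
    "Top_1": ["CD3D", "CD19", "CD3D"],
    "Top_2": ["CD8A", "MS4A1", "CD8A"],
    "Top_3": ["CD19", "CD3D", "NKG7"],
})
# Give every column the same categorical dtype so the integer codes are comparable.
shared_dtype = pd.CategoricalDtype(categories=pd.unique(top_genes_df.values.ravel()))
top_genes_df = top_genes_df.astype(shared_dtype)

codes = np.concatenate([top_genes_df[col].cat.codes.values for col in top_genes_df.columns])
counts = np.bincount(codes, minlength=len(top_genes_df["Top_1"].cat.categories))
category_counts = pd.Series(counts, index=top_genes_df["Top_1"].cat.categories)
print(category_counts.sort_values(ascending=False))  # CD3D appears most often (3 times)
```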
@@ -275,14 +260,14 @@ def llm_chat(self, adaptor, messages, mask, temperature):

        state.append_message(state.roles[1], None)

-        # TODO need to make CONTROLLER_URL flexible in there
        for chunk in llava_utils.http_bot(state, MODEL_NAME, temperature, top_p=0.7, max_new_tokens=512, log=True):
            yield json.dumps({"text": chunk}).encode() + b"\x00"

    def gene_score_contributions(self, adaptor, prompt, mask) -> pd.Series:
        """
        Which genes increase or decrease the prompt similarity in the selected cells?
        """
+        raise NotImplementedError("Analysis showed that this is not working as expected")

        var_index_col_name = adaptor.get_schema()["annotations"]["var"]["index"]
        obs_index_col_name = adaptor.get_schema()["annotations"]["obs"]["index"]
@@ -296,7 +281,7 @@ def gene_score_contributions(self, adaptor, prompt, mask) -> pd.Series:

        text_embeds = self._embed_texts([prompt])

-        gene_contribs: pd.Series = gene_score_contributions(
+        gene_contribs: pd.Series = gene_score_contributions(  # NOTE: not implemented
            transcriptome_input=transcriptomes,
            text_list_or_text_embeds=text_embeds,
            logit_scale=self.logit_scale,
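
The PR disables `gene_score_contributions` because, as the new exception states, the analysis did not work as expected. For readers curious what such a contribution score could look like in principle, here is a generic gradient-saliency sketch (placeholder encoder and names, unrelated to the removed implementation): attribute the CLIP-style cell-prompt similarity back to individual genes via input gradients.

```python
# Generic gradient-saliency sketch; the encoder is a placeholder, not CellWhisperer's model.
from typing import Callable, List

import pandas as pd
import torch
import torch.nn.functional as F


def gradient_gene_contributions(
    transcriptome: torch.Tensor,                                     # (n_genes,) expression vector
    text_embed: torch.Tensor,                                        # (d,) embedded prompt
    transcriptome_encoder: Callable[[torch.Tensor], torch.Tensor],   # (n_genes,) -> (d,), placeholder
    gene_names: List[str],
) -> pd.Series:
    x = transcriptome.clone().requires_grad_(True)
    sim = F.cosine_similarity(transcriptome_encoder(x), text_embed, dim=-1)
    sim.backward()
    contribs = (x.grad * x).detach()  # gradient-times-input attribution heuristic
    return pd.Series(contribs.numpy(), index=gene_names).sort_values(ascending=False)
```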