Skip to content

Commit d484bf5

Browse files
refactoring cellwhisperer_wrapper
1 parent f937305 commit d484bf5

File tree

1 file changed

+6
-21
lines changed

1 file changed

+6
-21
lines changed

server/common/compute/cellwhisperer_wrapper.py

Lines changed: 6 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -27,23 +27,6 @@
2727
MODEL_NAME = "Mistral-7B-Instruct-v0.2__cellwhisperer_clip_v1"
2828

2929

30-
def gene_score_contributions(
31-
transcriptome_input: torch.Tensor,
32-
text_list_or_text_embeds: Union[List[str], torch.Tensor],
33-
logit_scale: float,
34-
score_norm_method: str = None,
35-
) -> pd.Series:
36-
"""
37-
Just a dummy for testing
38-
"""
39-
return pd.Series(
40-
{
41-
"Gene 1": 0.1,
42-
"Gene 2": -0.1,
43-
}
44-
)
45-
46-
4730
class CellWhispererWrapper:
4831
def __init__(self, model_path_or_url: str):
4932
"""
@@ -85,6 +68,8 @@ def preprocess_data(self, adaptor):
8568

8669
def llm_obs_to_text(self, adaptor, mask):
8770
"""
71+
Currently unused, in favor of the more advanced chat functionality, but still functional
72+
8873
Embed the given cells into the LLM space and return their average similarity to different keywords as formatted text.
8974
Keyword types used for comparison are: (i) selected enrichR terms (see cellwhisperer.validation.zero_shot.functions.write_enrichr_terms_to_json) \
9075
and (ii) cell type annotations (currently all values in adata.obs.columns). For more info, see cellwhisperer.validation.zero_shot.functions.
@@ -231,7 +216,7 @@ def _prepare_messages(self, adaptor, messages, mask):
231216
codes = np.concatenate([top_genes_df[col].cat.codes.values for col in top_genes_df.columns])
232217
counts = np.bincount(codes, minlength=len(top_genes_df["Top_1"].cat.categories))
233218
category_counts = pd.Series(counts, index=top_genes_df[top_genes_df.columns[0]].cat.categories)
234-
n_top_genes = 50 # TODO number of top genes to list needs to become configurable
219+
n_top_genes = 50 # NOTE number of top genes to list should be configurable
235220
top_genes = category_counts.sort_values(ascending=False).index[:n_top_genes].to_list()
236221

237222
# Initialize the conversation
@@ -250,7 +235,7 @@ def _prepare_messages(self, adaptor, messages, mask):
250235
]
251236
state.offset = 2
252237

253-
# TODO the transcriptome is added too late. consider changing
238+
# NOTE: the transcriptome is added too late. consider changing
254239

255240
for i, message in enumerate(messages):
256241
if i == 0:
@@ -275,14 +260,14 @@ def llm_chat(self, adaptor, messages, mask, temperature):
275260

276261
state.append_message(state.roles[1], None)
277262

278-
# TODO need to make CONTROLLER_URL flexible in there
279263
for chunk in llava_utils.http_bot(state, MODEL_NAME, temperature, top_p=0.7, max_new_tokens=512, log=True):
280264
yield json.dumps({"text": chunk}).encode() + b"\x00"
281265

282266
def gene_score_contributions(self, adaptor, prompt, mask) -> pd.Series:
283267
"""
284268
Which genes increase or decrease the prompt-similiarity in the selected cells?
285269
"""
270+
raise NotImplementedError("Analysis showed that this is not working as expected")
286271

287272
var_index_col_name = adaptor.get_schema()["annotations"]["var"]["index"]
288273
obs_index_col_name = adaptor.get_schema()["annotations"]["obs"]["index"]
@@ -296,7 +281,7 @@ def gene_score_contributions(self, adaptor, prompt, mask) -> pd.Series:
296281

297282
text_embeds = self._embed_texts([prompt])
298283

299-
gene_contribs: pd.Series = gene_score_contributions(
284+
gene_contribs: pd.Series = gene_score_contributions( # NOTE: note implemented
300285
transcriptome_input=transcriptomes,
301286
text_list_or_text_embeds=text_embeds,
302287
logit_scale=self.logit_scale,

0 commit comments

Comments
 (0)