diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 00000000..6d9e4c7d --- /dev/null +++ b/.editorconfig @@ -0,0 +1,5 @@ +root = true + +[*.js] +indent_style = space +indent_size = 2 \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/---bug-report.md b/.github_deactivated/ISSUE_TEMPLATE/---bug-report.md similarity index 100% rename from .github/ISSUE_TEMPLATE/---bug-report.md rename to .github_deactivated/ISSUE_TEMPLATE/---bug-report.md diff --git a/.github/ISSUE_TEMPLATE/---feature-request.md b/.github_deactivated/ISSUE_TEMPLATE/---feature-request.md similarity index 100% rename from .github/ISSUE_TEMPLATE/---feature-request.md rename to .github_deactivated/ISSUE_TEMPLATE/---feature-request.md diff --git a/.github/ISSUE_TEMPLATE/---question-clarification.md b/.github_deactivated/ISSUE_TEMPLATE/---question-clarification.md similarity index 100% rename from .github/ISSUE_TEMPLATE/---question-clarification.md rename to .github_deactivated/ISSUE_TEMPLATE/---question-clarification.md diff --git a/.github/ISSUE_TEMPLATE/tech-issue.md b/.github_deactivated/ISSUE_TEMPLATE/tech-issue.md similarity index 100% rename from .github/ISSUE_TEMPLATE/tech-issue.md rename to .github_deactivated/ISSUE_TEMPLATE/tech-issue.md diff --git a/.github/workflows/close-stale-prs.yml b/.github_deactivated/workflows/close-stale-prs.yml similarity index 100% rename from .github/workflows/close-stale-prs.yml rename to .github_deactivated/workflows/close-stale-prs.yml diff --git a/.github/workflows/codeql-analysis.yml b/.github_deactivated/workflows/codeql-analysis.yml similarity index 100% rename from .github/workflows/codeql-analysis.yml rename to .github_deactivated/workflows/codeql-analysis.yml diff --git a/.github/workflows/compatibility_tests.yml b/.github_deactivated/workflows/compatibility_tests.yml similarity index 100% rename from .github/workflows/compatibility_tests.yml rename to .github_deactivated/workflows/compatibility_tests.yml diff --git a/.github/workflows/lint-pr-commit-message.yml b/.github_deactivated/workflows/lint-pr-commit-message.yml similarity index 100% rename from .github/workflows/lint-pr-commit-message.yml rename to .github_deactivated/workflows/lint-pr-commit-message.yml diff --git a/.github/workflows/push_tests.yml b/.github_deactivated/workflows/push_tests.yml similarity index 100% rename from .github/workflows/push_tests.yml rename to .github_deactivated/workflows/push_tests.yml diff --git a/Makefile b/Makefile index c2b04eb6..80b52210 100644 --- a/Makefile +++ b/Makefile @@ -79,6 +79,16 @@ smoke-test: smoke-test-annotations: cd client && $(MAKE) smoke-test-annotations +# STARTING SERVER AND FRONTEND + +.PHONY: start +start: start-frontend-noblock start-server + +.PHONY: start-frontend-noblock +start-frontend-noblock: + @echo "Starting frontend..." + @cd client && nohup make start-frontend & + # FORMATTING CODE .PHONY: fmt diff --git a/README.md b/README.md index 76c9f864..67247203 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,23 @@ +# Moritz notes + +Read this to get started (install& get startend) +https://github.com/chanzuckerberg/cellxgene/blob/main/dev_docs/developer_guidelines.md + +## Installation + +### Workaround \[webpack-cli] HookWebpackError: error:0308010C:digital envelope routines::unsupported + +Run `export NODE_OPTIONS=--openssl-legacy-provider` before `make build-for-server-dev` + +# General + + _an interactive explorer for single-cell transcriptomics data_ [![DOI](https://zenodo.org/badge/105615409.svg)](https://zenodo.org/badge/latestdoi/105615409) [![PyPI](https://img.shields.io/pypi/v/cellxgene)](https://pypi.org/project/cellxgene/) [![PyPI - Downloads](https://img.shields.io/pypi/dm/cellxgene)](https://pypistats.org/packages/cellxgene) [![GitHub last commit](https://img.shields.io/github/last-commit/chanzuckerberg/cellxgene)](https://github.com/chanzuckerberg/cellxgene/pulse) + [![Push Tests](https://github.com/chanzuckerberg/cellxgene/workflows/Push%20Tests/badge.svg)](https://github.com/chanzuckerberg/cellxgene/actions?query=workflow%3A%22Push+Tests%22) [![Compatibility Tests](https://github.com/chanzuckerberg/cellxgene/workflows/Compatibility%20Tests/badge.svg)](https://github.com/chanzuckerberg/cellxgene/actions?query=workflow%3A%22Compatibility+Tests%22) ![Code Coverage](https://codecov.io/gh/chanzuckerberg/cellxgene/branch/main/graph/badge.svg) diff --git a/client/.husky/pre-commit b/client/.husky/pre-commit index 5a724e7e..9ff0bca3 100755 --- a/client/.husky/pre-commit +++ b/client/.husky/pre-commit @@ -1,5 +1,6 @@ #!/bin/sh +exit 0 . "$(dirname "$0")/_/husky.sh" cd client -npx --no-install lint-staged --config "./configuration/lint-staged/lint-staged.config.js" \ No newline at end of file +npx --no-install lint-staged --config "./configuration/lint-staged/lint-staged.config.js" diff --git a/client/configuration/webpack/webpack.config.dev.js b/client/configuration/webpack/webpack.config.dev.js index 03e7246d..ead287fb 100644 --- a/client/configuration/webpack/webpack.config.dev.js +++ b/client/configuration/webpack/webpack.config.dev.js @@ -33,7 +33,7 @@ const devConfig = { options: { name: "static/assets/[name].[ext]", // (thuang): This is needed to make sure @font url path is '/static/assets/' - publicPath: "..", + publicPath: "", }, }, ], diff --git a/client/configuration/webpack/webpack.config.prod.js b/client/configuration/webpack/webpack.config.prod.js index ee34773b..6b1da2ea 100644 --- a/client/configuration/webpack/webpack.config.prod.js +++ b/client/configuration/webpack/webpack.config.prod.js @@ -47,8 +47,8 @@ const prodConfig = { include: [nodeModules, fonts, images], options: { name: "static/assets/[name]-[contenthash].[ext]", - // (thuang): This is needed to make sure @font url path is '../static/assets/' - publicPath: "..", + // (thuang): This is needed to make sure @font url path is '../static/assets/' <- not for me + publicPath: "", }, }, ], diff --git a/client/favicon.png b/client/favicon.png deleted file mode 100644 index 58f43344..00000000 Binary files a/client/favicon.png and /dev/null differ diff --git a/client/favicon.png b/client/favicon.png new file mode 120000 index 00000000..085f119d --- /dev/null +++ b/client/favicon.png @@ -0,0 +1 @@ +src/images/icon_cellwhisperer.png \ No newline at end of file diff --git a/client/src/actions/annotation.js b/client/src/actions/annotation.js index ad2f0db5..4f5d6bf1 100644 --- a/client/src/actions/annotation.js +++ b/client/src/actions/annotation.js @@ -5,9 +5,66 @@ import difference from "lodash.difference"; import pako from "pako"; import * as globals from "../globals"; import { MatrixFBS, AnnotationsHelpers } from "../util/stateManager"; +import { isTypedArray } from "../util/typeHelpers"; const { isUserAnnotation } = AnnotationsHelpers; +export const annotationCreateContinuousAction = + (newContinuousName, values) => async (dispatch, getState) => { + /* + Add a new user-created continuous to the obs annotations. + + Arguments: + newContinuousName - string name for the continuous. + continuousToDuplicate - obs continuous to use for initial values, or null. + */ + const { annoMatrix: prevAnnoMatrix, obsCrossfilter: prevObsCrossfilter } = + getState(); + if (!prevAnnoMatrix || !prevObsCrossfilter) return; + const { schema } = prevAnnoMatrix; + + /* name must be a string, non-zero length */ + if (typeof newContinuousName !== "string" || newContinuousName.length === 0) + throw new Error("user annotations require string name"); + + if (!isTypedArray(values) || values.length === 0) + // TODO check for correct length + throw new Error( + `Provided values are of wrong format or length ${typeof values}, ${ + values.length + }` + ); + + /* ensure the name isn't already in use! */ + if (schema.annotations.obsByName[newContinuousName]) + throw new Error("name collision on annotation continuous create"); + + const newSchema = { + name: newContinuousName, + type: "float32", + writable: false, + }; + + const obsCrossfilter = prevObsCrossfilter.addObsColumn( + newSchema, + values.constructor, + values + ); + + // TODO this is probably a noop (and should be removed) + dispatch({ + type: "annotation: create continuous", + data: newContinuousName, + annoMatrix: obsCrossfilter.annoMatrix, + obsCrossfilter, + }); + + dispatch({ + type: "color by continuous metadata", + colorAccessor: newContinuousName, + }); + }; + export const annotationCreateCategoryAction = (newCategoryName, categoryToDuplicate) => async (dispatch, getState) => { /* diff --git a/client/src/actions/index.js b/client/src/actions/index.js index 5c5f495d..eeebd985 100644 --- a/client/src/actions/index.js +++ b/client/src/actions/index.js @@ -11,6 +11,7 @@ import * as annoActions from "./annotation"; import * as viewActions from "./viewStack"; import * as embActions from "./embedding"; import * as genesetActions from "./geneset"; +import * as llmEmbeddingsActions from "./llmEmbeddings"; function setGlobalConfig(config) { /** @@ -236,6 +237,7 @@ function fetchJson(pathAndQuery) { } export default { + fetchJson, doInitialDataLoad, requestDifferentialExpression, requestSingleGeneExpressionCountsForColoringPOST, @@ -256,6 +258,8 @@ export default { clipAction: viewActions.clipAction, subsetAction: viewActions.subsetAction, resetSubsetAction: viewActions.resetSubsetAction, + annotationCreateContinuousAction: + annoActions.annotationCreateContinuousAction, annotationCreateCategoryAction: annoActions.annotationCreateCategoryAction, annotationRenameCategoryAction: annoActions.annotationRenameCategoryAction, annotationDeleteCategoryAction: annoActions.annotationDeleteCategoryAction, @@ -272,4 +276,8 @@ export default { genesetDelete: genesetActions.genesetDelete, genesetAddGenes: genesetActions.genesetAddGenes, genesetDeleteGenes: genesetActions.genesetDeleteGenes, + requestEmbeddingLLMWithText: llmEmbeddingsActions.requestEmbeddingLLMWithText, + requestEmbeddingLLMWithCells: + llmEmbeddingsActions.requestEmbeddingLLMWithCells, + startChatRequest: llmEmbeddingsActions.startChatRequest, }; diff --git a/client/src/actions/llmEmbeddings.js b/client/src/actions/llmEmbeddings.js new file mode 100644 index 00000000..76597a22 --- /dev/null +++ b/client/src/actions/llmEmbeddings.js @@ -0,0 +1,206 @@ +import * as globals from "../globals"; +import { annotationCreateContinuousAction } from "./annotation"; +import { matrixFBSToDataframe } from "../util/stateManager/matrix"; + +/* + LLM embedding querying +*/ +export const requestEmbeddingLLMWithCells = + /* + Send a request to the LLM embedding model with text + */ + (cellSelection) => async (dispatch) => { + dispatch({ + type: "request to embedding model started", + }); + try { + // Legal values are null, Array or TypedArray. Null is initial state. + if (!cellSelection) cellSelection = []; + + // These lines ensure that we convert any TypedArray to an Array. + // This is necessary because JSON.stringify() does some very strange + // things with TypedArrays (they are marshalled to JSON objects, rather + // than being marshalled as a JSON array). + cellSelection = Array.isArray(cellSelection) + ? cellSelection + : Array.from(cellSelection); + + const res = await fetch( + `${globals.API.prefix}${globals.API.version}llmembs/obs`, + { + method: "POST", + headers: new Headers({ + Accept: "application/json", + "Content-Type": "application/json", + }), + body: JSON.stringify({ + cellSelection: { filter: { obs: { index: cellSelection } } }, + }), + credentials: "include", + } + ); + + if (!res.ok || res.headers.get("Content-Type") !== "application/json") { + return dispatch({ + type: "request llm embeddings error", + error: new Error( + `Unexpected response ${res.status} ${ + res.statusText + } ${res.headers.get("Content-Type")}}` + ), + }); + } + + const response = await res.json(); + return dispatch({ + type: "embedding model text response from cells", + data: response, + }); + } catch (error) { + return dispatch({ + type: "request llm embeddings error", + error, + }); + } + }; + +export const requestEmbeddingLLMWithText = + /* + Send a request to the LLM embedding model with text + */ + (text) => async (dispatch) => { + dispatch({ + type: "request to embedding model started", + }); + try { + const res = await fetch( + `${globals.API.prefix}${globals.API.version}llmembs/text`, + { + method: "POST", + headers: new Headers({ + Accept: "application/octet-stream", + "Content-Type": "application/json", + }), + body: JSON.stringify({ + text, + }), + credentials: "include", + } + ); + + if ( + !res.ok || + res.headers.get("Content-Type") !== "application/octet-stream" + ) { + return dispatch({ + type: "request llm embeddings error", + error: new Error( + `Unexpected response ${res.status} ${ + res.statusText + } ${res.headers.get("Content-Type")}}` + ), + }); + } + + const buffer = await res.arrayBuffer(); + const dataframe = matrixFBSToDataframe(buffer); + const col = dataframe.__columns[0]; + + const annotationName = dataframe.colIndex.getLabel(0); + + dispatch({ + type: "embedding model annotation response from text", + }); + + return dispatch(annotationCreateContinuousAction(annotationName, col)); + } catch (error) { + return dispatch({ + type: "request llm embeddings error", + error, + }); + } + }; + + +/* + Action creator to interact with the http_bot endpoint +*/ +export const startChatRequest = (messages, prompt, cellSelection) => async (dispatch) => { + let newMessages = messages.concat({from: "human", value: prompt}); + dispatch({ type: "chat request start", newMessages }); + + try { + if (!cellSelection) cellSelection = []; + + // These lines ensure that we convert any TypedArray to an Array. + // This is necessary because JSON.stringify() does some very strange + // things with TypedArrays (they are marshalled to JSON objects, rather + // than being marshalled as a JSON array). + cellSelection = Array.isArray(cellSelection) + ? cellSelection + : Array.from(cellSelection); + + const pload = { + messages: newMessages, // TODO might need to add to first message + cellSelection: { filter: { obs: { index: cellSelection } } }, + }; + + const response = await fetch(`${globals.API.prefix}${globals.API.version}llmembs/chat`, { + method: 'POST', + headers: new Headers({ + // Accept: "application/json", + 'Content-Type': 'application/json', + }), + body: JSON.stringify(pload), + }); + + if (!response.ok) { + throw new Error('Failed to get response from the model'); + } + + // NOTE: The canonical way to solve this would probably be to use EventStreams. But it should also be possible with fetch as below + // Stream the response (assuming the API sends back chunked responses) + const reader = response.body.getReader(); + let chunksAll = new Uint8Array(0); + let receivedLength = 0; // length at the moment + while(true) { + const { done, value } = await reader.read(); + + if (done) { + break; + } + + let temp = new Uint8Array(receivedLength + value.length); + temp.set(chunksAll, 0); // copy the old data + temp.set(value, receivedLength); // append the new chunk + chunksAll = temp; // reassign the extended array + receivedLength += value.length; + + // get the last chunk + + // Assuming chunksAll is the Uint8Array containing the data + let lastZeroIndex = chunksAll.lastIndexOf(0); + + if (lastZeroIndex == -1) { + continue; + } + let secondLastZeroIndex = chunksAll.lastIndexOf(0, lastZeroIndex - 1); + // if secondLastZeroIndex is -1 (only 1 zero), go from the start + let lastChunk = chunksAll.slice(secondLastZeroIndex+1, lastZeroIndex); + + // Decode into a string + let result = new TextDecoder("utf-8").decode(lastChunk); + + // Parse the JSON (assuming the final string is a JSON object) + const data = JSON.parse(result); + + // trim away the '' string: + data.text = data.text.replace("", ""); + + dispatch({ type: "chat request success", payload: data.text }); + } + + } catch (error) { + dispatch({ type: "chat request failure", payload: error.message }); + } +}; diff --git a/client/src/annoMatrix/loader.js b/client/src/annoMatrix/loader.js index 7eb89065..ffa05443 100644 --- a/client/src/annoMatrix/loader.js +++ b/client/src/annoMatrix/loader.js @@ -105,7 +105,8 @@ export default class AnnoMatrixLoader extends AnnoMatrix { * a primitive type, including null or undefined. If an array, it must be of same size as nObs and same type as Ctor */ - colSchema.writable = true; + + if (colSchema.writable === undefined) colSchema.writable = true; const colName = colSchema.name; if ( _getColumnSchema(this.schema, "obs", colName) || @@ -126,10 +127,11 @@ export default class AnnoMatrixLoader extends AnnoMatrix { data = new Ctor(this.nObs).fill(value); } newAnnoMatrix._cache.obs = this._cache.obs.withCol(colName, data); - normalizeWritableCategoricalSchema( - colSchema, - newAnnoMatrix._cache.obs.col(colName) - ); + if (colSchema.type === "categorical" || colSchema.type === undefined) + normalizeWritableCategoricalSchema( + colSchema, + newAnnoMatrix._cache.obs.col(colName) + ); newAnnoMatrix.schema = addObsAnnoColumn(this.schema, colName, colSchema); return newAnnoMatrix; } diff --git a/client/src/components/chatSidebar/index.js b/client/src/components/chatSidebar/index.js new file mode 100644 index 00000000..6682c615 --- /dev/null +++ b/client/src/components/chatSidebar/index.js @@ -0,0 +1,181 @@ +import React from "react"; +import { connect } from "react-redux"; +import { Button, InputGroup } from "@blueprintjs/core"; +import actions from "../../actions"; + +function renderList(items) { + return ( + + ); +} + +@connect((state) => ({ + ...state.llmEmbeddings, + obsCrossfilter: state.obsCrossfilter, +})) +class ChatSideBar extends React.Component { + constructor(props) { + super(props); + this.state = { + inputText: "", + conversationSample: null, + }; + this.messagesEndRef = React.createRef(); // Create a ref for the messages container + } + + handleInputChange = (e) => { + this.setState({ inputText: e.target.value }); + }; + + findCellsClick = () => { + const { dispatch } = this.props; + const { inputText } = this.state; + dispatch(actions.requestEmbeddingLLMWithText(inputText)); + }; + + chatSelectedClick = () => { + const { dispatch, obsCrossfilter, messages } = this.props; + const { inputText, conversationSample } = this.state; + // Dispatch the action to send the message + let submitMessages = messages; + + // Test if conversationSample changed + if (JSON.stringify(conversationSample) !== JSON.stringify(obsCrossfilter.allSelectedLabels())) { + submitMessages = []; + this.setState({ conversationSample: obsCrossfilter.allSelectedLabels() }); + } + dispatch(actions.startChatRequest(submitMessages, inputText, obsCrossfilter.allSelectedLabels())); + this.setState({ inputText: "" }); // Clear the input after sending + }; + + componentDidUpdate(prevProps) { + if (prevProps.messages !== this.props.messages) { + this.scrollToBottom(); + } + } + + scrollToBottom = () => { + if (this.messagesEndRef.current) { + this.messagesEndRef.current.scrollTop = this.messagesEndRef.current.scrollHeight; + } + }; + + render() { + const { messages, loading, obsCrossfilter } = this.props; + const { inputText, conversationSample } = this.state; + + return ( +
+
+ {typeof messages === "string" ? ( + messages + ) : ( +
+ {messages.map((message) => ( +
+ {message.value} +
+ ))} +
+ )} +
+
+ { + if (e.key === "Enter" && obsCrossfilter.countSelected() > 0 && inputText) { + this.chatSelectedClick(); + } + }} + /> +
+
+ + +
+
+ ); + } +} + +export default ChatSideBar; diff --git a/client/src/components/framework/logo.js b/client/src/components/framework/logo.js index 3b7d3881..4fb0fac2 100644 --- a/client/src/components/framework/logo.js +++ b/client/src/components/framework/logo.js @@ -1,16 +1,8 @@ import React from "react"; -import icon from "../../images/icon.png"; const Logo = (props) => { - const { size } = props; - return ( - CELLxGENE Annotate Logo - ); + const { size, alt, src } = props; + return {alt}; }; export default Logo; diff --git a/client/src/components/leftSidebar/topLeftLogoAndTitle.js b/client/src/components/leftSidebar/topLeftLogoAndTitle.js index 5de741a2..9a9fda87 100644 --- a/client/src/components/leftSidebar/topLeftLogoAndTitle.js +++ b/client/src/components/leftSidebar/topLeftLogoAndTitle.js @@ -6,6 +6,9 @@ import Logo from "../framework/logo"; import Truncate from "../util/truncate"; import InformationMenu from "./infoMenu"; +import cxgIcon from "../../images/icon_cxg.png"; +import cellwhispererIcon from "../../images/icon_cellwhisperer.png"; + const DATASET_TITLE_FONT_SIZE = 14; @connect((state) => { @@ -48,31 +51,49 @@ class LeftSideBar extends React.Component { }} >
- - - cell +
+ + cell + + × + + gene + +
+
+ + - × + CellWhisperer - gene - +
({ @@ -8,6 +9,7 @@ import * as globals from "../../globals"; scatterplotYYaccessor: state.controls.scatterplotYYaccessor, })) class RightSidebar extends React.Component { + // Bar should capitalized... render() { return (
+
); } diff --git a/client/src/globals.js b/client/src/globals.js index 05845540..25540312 100644 --- a/client/src/globals.js +++ b/client/src/globals.js @@ -74,8 +74,8 @@ export const graphWidth = 700; export const graphHeight = 700; export const scatterplotMarginLeft = 11; -export const rightSidebarWidth = 365; -export const leftSidebarWidth = 365; +export const rightSidebarWidth = 515; +export const leftSidebarWidth = 325; export const leftSidebarSectionHeading = { fontSize: 18, textTransform: "uppercase", @@ -115,6 +115,7 @@ if (typeof window !== "undefined" && window.CELLXGENE && window.CELLXGENE.API) { // prefix: "http://api.clustering.czi.technology/api/", // prefix: "http://tabulamuris.cxg.czi.technology/api/", // prefix: "http://api-staging.clustering.czi.technology/api/", + // prefix: `http://s0-n11.hpc.meduniwien.ac.at:${CXG_SERVER_PORT}/api/`, prefix: `http://localhost:${CXG_SERVER_PORT}/api/`, version: "v0.2/", }; diff --git a/client/src/images/icon.png b/client/src/images/icon.png deleted file mode 100644 index dbc81eb8..00000000 Binary files a/client/src/images/icon.png and /dev/null differ diff --git a/client/src/images/icon_cellwhisperer.png b/client/src/images/icon_cellwhisperer.png new file mode 100644 index 00000000..92699f91 Binary files /dev/null and b/client/src/images/icon_cellwhisperer.png differ diff --git a/client/src/reducers/index.js b/client/src/reducers/index.js index 29ace171..b72bb699 100644 --- a/client/src/reducers/index.js +++ b/client/src/reducers/index.js @@ -19,6 +19,7 @@ import genesetsUI from "./genesetsUI"; import autosave from "./autosave"; import centroidLabels from "./centroidLabels"; import pointDialation from "./pointDilation"; +import llmEmbeddings from "./llmEmbeddings"; import { gcMiddleware as annoMatrixGC } from "../annoMatrix"; import undoableConfig from "./undoableConfig"; @@ -41,6 +42,7 @@ const Reducer = undoable( ["centroidLabels", centroidLabels], ["pointDilation", pointDialation], ["autosave", autosave], + ["llmEmbeddings", llmEmbeddings], // TODO might need to go before continouosselection or annotations ]), [ "annoMatrix", @@ -55,6 +57,7 @@ const Reducer = undoable( "centroidLabels", "genesets", "annotations", + "llmEmbeddings", ], undoableConfig ); diff --git a/client/src/reducers/llmEmbeddings.js b/client/src/reducers/llmEmbeddings.js new file mode 100644 index 00000000..4a462d29 --- /dev/null +++ b/client/src/reducers/llmEmbeddings.js @@ -0,0 +1,81 @@ +/* + Reducers for LLMEmbedding. +*/ +const LLMEmbedding = ( + state = { + messages: "The LLM still primarily hallucinates. Please use its results with a good portion of scepticism (you will see yourself).", + loading: false, + }, + action +) => { + switch (action.type) { + case "request to embedding model started": { + return { + ...state, + loading: true, + }; + } + case "embedding model annotation response from text": { + return { + ...state, + loading: false, + }; + } + + case "embedding model text response from cells": { + return { + ...state, + messages: action.data, + loading: false, + }; + } + case "request llm embeddings error": { + return { + ...state, + messages: `ERROR: ${action.error}`, + loading: false, + }; + } + + case "chat reset": { + return { + ...state, + // Add an empty message to the end of the list of messages + messages: [], + // error: null, + }; + } + + case "chat request start": { + return { + ...state, + loading: true, + // Add an empty message to the end of the list of messages + messages: action.newMessages.concat({ value: "", from: "gpt" }), + // error: null, + }; + } + case "chat request success": { + return { + ...state, + // Replace the last entry of messages + messages: state.messages.slice(0, -1).concat({ value: action.payload, from: "gpt" }), + loading: false, + // error: null, + }; + } + case "chat request failure": { + return { + ...state, + messages: state.messages.slice(0, -1).concat({ value: `ERROR: ${action.payload}`, from: "gpt" }), + loading: false, + // error: action.payload, // Error message + }; + } + + default: + return state; + } +}; + +export default LLMEmbedding; diff --git a/client/src/reducers/undoableConfig.js b/client/src/reducers/undoableConfig.js index d18270ad..dfe59838 100644 --- a/client/src/reducers/undoableConfig.js +++ b/client/src/reducers/undoableConfig.js @@ -105,6 +105,8 @@ const saveOnActions = new Set([ "annotation: delete label", "annotation: category edited", + "annotation: create continuous", + /* geneset component action */ "geneset: create", "geneset: delete", diff --git a/common.mk b/common.mk index e331c579..340f729f 100644 --- a/common.mk +++ b/common.mk @@ -20,11 +20,12 @@ export CXG_SERVER_PORT := $(call env_or_else_default,CXG_SERVER_PORT) export CXG_CLIENT_PORT := $(call env_or_else_default,CXG_CLIENT_PORT) export CXG_OPTIONS := $(call env_or_else_default,CXG_OPTIONS) export DATASET := $(call full_path,$(call env_or_else_default,DATASET)) +export MODEL := $(call env_or_else_default,MODEL) export JEST_ENV := $(call env_or_else_default,JEST_ENV) .PHONY: start-server start-server: - cellxgene launch -p $(CXG_SERVER_PORT) $(CXG_OPTIONS) $(DATASET) + cellxgene launch -p $(CXG_SERVER_PORT) $(CXG_OPTIONS) $(DATASET) $(MODEL) # copy the client assets to a location known to the server # $(1) is the source of the client assets diff --git a/docs/cellxgene-favicon.png b/docs/cellxgene-favicon.png deleted file mode 100644 index 58f43344..00000000 Binary files a/docs/cellxgene-favicon.png and /dev/null differ diff --git a/docs/cellxgene-favicon.png b/docs/cellxgene-favicon.png new file mode 120000 index 00000000..f19d66a0 --- /dev/null +++ b/docs/cellxgene-favicon.png @@ -0,0 +1 @@ +/home/moritz/Projects/cellwhisperer/modules/cellxgene/client/src/images/icon_cellwhisperer.png \ No newline at end of file diff --git a/environment.default.json b/environment.default.json index c5b32e0d..81d303d7 100644 --- a/environment.default.json +++ b/environment.default.json @@ -1,8 +1,9 @@ { "CXG_CLIENT_PORT": 3000, - "CXG_OPTIONS": "--debug", + "CXG_OPTIONS": "--debug --host 0.0.0.0 --max-category-items 500", "CXG_SERVER_PORT": 5005, - "DATASET": "example-dataset/pbmc3k.h5ad", + "DATASET": "~/cellwhisperer/results/daniel/03jujd8s/cellxgene.h5ad", + "MODEL": "~/cellwhisperer/results/models/jointemb/03jujd8s.ckpt", "DEBUG": "debug", "DEV": "dev", "JEST_ENV": "prod", diff --git a/server/app/app.py b/server/app/app.py index 61f9bc84..d088f851 100644 --- a/server/app/app.py +++ b/server/app/app.py @@ -167,6 +167,27 @@ def get(self, data_adaptor): return common_rest.layout_obs_get(request, data_adaptor) +class LLMEmbeddingsObsAPI(Resource): + @cache_control(no_store=True) + @rest_get_data_adaptor + def post(self, data_adaptor): + return common_rest.llm_embeddings_obs_post(request, data_adaptor) + + +class LLMEmbeddingsTextAPI(Resource): + @cache_control(no_store=True) + @rest_get_data_adaptor + def post(self, data_adaptor): + return common_rest.llm_embeddings_text_post(request, data_adaptor) + + +class LLMEmbeddingsChatAPI(Resource): + @cache_control(no_store=True) + @rest_get_data_adaptor + def post(self, data_adaptor): + return common_rest.llm_embeddings_chat_post(request, data_adaptor) + + class GenesetsAPI(Resource): @cache_control(public=True, max_age=ONE_WEEK) @rest_get_data_adaptor @@ -222,6 +243,9 @@ def add_resource(resource, url): # Computation routes add_resource(DiffExpObsAPI, "/diffexp/obs") add_resource(LayoutObsAPI, "/layout/obs") + add_resource(LLMEmbeddingsObsAPI, "/llmembs/obs") + add_resource(LLMEmbeddingsTextAPI, "/llmembs/text") + add_resource(LLMEmbeddingsChatAPI, "/llmembs/chat") return api diff --git a/server/cli/launch.py b/server/cli/launch.py index 7fed726a..09e17428 100644 --- a/server/cli/launch.py +++ b/server/cli/launch.py @@ -98,6 +98,13 @@ def config_args(func): show_default=False, help="Disable on-demand differential expression.", ) + @click.option( + "--disable-llmembs", + is_flag=True, + default=not DEFAULT_CONFIG.dataset_config.llmembs__enable, + show_default=False, + help="Disable on-demand LLM-Embeddings services.", + ) @click.option( "--embedding", "-e", @@ -222,6 +229,7 @@ def launch_args(func): @dataset_args @server_args @click.argument("datapath", required=False, metavar="") + @click.argument("modelpath", required=False, metavar="") @click.option( "--open", "-o", @@ -303,6 +311,7 @@ def _before_adding_routes(app, app_config): @launch_args def launch( datapath, + modelpath, verbose, debug, open_browser, @@ -324,6 +333,7 @@ def launch( disable_gene_sets_save, backed, disable_diffexp, + disable_llmembs, config_file, dump_default_config, x_approximate_distribution, @@ -385,6 +395,8 @@ def launch( embeddings__names=embedding, diffexp__enable=not disable_diffexp, diffexp__lfc_cutoff=diffexp_lfc_cutoff, + llmembs__enable=not disable_llmembs, + llmembs__model_checkpoint=modelpath, X_approximate_distribution=x_approximate_distribution, ) diff --git a/server/common/compute/cellwhisperer_wrapper.py b/server/common/compute/cellwhisperer_wrapper.py new file mode 100644 index 00000000..02e9a1e1 --- /dev/null +++ b/server/common/compute/cellwhisperer_wrapper.py @@ -0,0 +1,284 @@ +import logging + +import os +import json +import pandas as pd +import numpy as np +from typing import List + +import requests +import pickle +import torch + +from cellwhisperer.utils.inference import score_transcriptomes_vs_texts, rank_terms_by_score, prepare_terms +import torch +from cellwhisperer.utils.model_io import load_cellwhisperer_model +from . import llava_utils, llava_conversation + +default_conversation = llava_conversation.conv_mistral_instruct + +logger = logging.getLogger(__name__) + + +class CellWhispererWrapper: + def __init__(self, model_path_or_url: str): + """ + Load the model from the given path or use it via the given URL + """ + if os.path.exists(model_path_or_url): + logging.info("Loading LLM embedding model...") + self.pl_model, self.tokenizer, self.transcriptome_processor = load_cellwhisperer_model( + model_path_or_url, cache=True + ) + logging.info("Loading done") + self.logit_scale = self.pl_model.model.discriminator.temperature.exp() + else: + self.pl_model = None + self.api_url = model_path_or_url + # load logit_scale via API + response = requests.get(self.api_url + "/logit_scale") + self.logit_scale = float(response.content) + + def preprocess_data(self, adaptor): + """ + Preprocess data for LLM embeddings, making sure that subsequent API requests run fast. + If things are cached already (through frozenmodel and/or the adaptor) this will be fast + + adaptor: Access to the adata object + """ + logging.info("Preprocessing data for LLM embeddings, making sure it's fast") + return # just for testing + + # Make sure that all the zero-shot class terms are embedded + mask = np.zeros(adaptor.data.shape[0], dtype=bool) + mask[0] = True # Generate mask with single element + self.llm_obs_to_text(adaptor, mask=mask) + + # Embed all cells + self.llm_text_to_annotations(adaptor, text="test") + + response = requests.post(self.api_url + "/store_cache") + + def llm_obs_to_text(self, adaptor, mask): + """ + Embed the given cells into the LLM space and return their average similarity to different keywords as formatted text. + Keyword types used for comparison are: (i) selected enrichR terms (see cellwhisperer.validation.zero_shot.functions.write_enrichr_terms_to_json) \ + and (ii) cell type annotations (currently all values in adata.obs.columns). For more info, see cellwhisperer.validation.zero_shot.functions. + :param adaptor: DataAdaptor instance + :param mask: + :return: dictionary {text: } + """ + var_index_col_name = adaptor.get_schema()["annotations"]["var"]["index"] + obs_index_col_name = adaptor.get_schema()["annotations"]["obs"]["index"] + + if "transcriptome_embeds" in adaptor.data.obsm: + transcriptome_embeds = torch.from_numpy(adaptor.data.obsm["transcriptome_embeds"][mask]) + # transcriptomes = transcriptomes.to(self.pl_model.model.device) + else: + # Provide raw read counts, which will be processed by the model + try: + transcriptomes = adaptor.data[mask].to_memory(copy=True) + except MemoryError: + raise + + transcriptomes.var.index = adaptor.data.var[var_index_col_name].astype(str) + transcriptomes.obs.index = adaptor.data.obs.loc[mask, obs_index_col_name].astype(str) + transcriptome_embeds = self.pl_model.embed_transcriptomes(transcriptomes) + + # Get all categorical columns (too extensive and doesn't make sense) + # obs_cols = [c for c, t in adaptor.data.obs.dtypes.items() if isinstance(t, CategoricalDtype)] + # additional_text_dict = { + # obs_col: adaptor.data.obs[obs_col].astype(str).unique().tolist() for obs_col in obs_cols + # } + terms = adaptor.data.uns["terms"] + + terms_df = prepare_terms(terms) # additional_text_dict + text_embeds = self._embed_texts(terms_df["term"].to_list()) + + scores, _ = score_transcriptomes_vs_texts( + transcriptome_input=transcriptome_embeds, + text_list_or_text_embeds=text_embeds, + logit_scale=self.logit_scale, + average_mode="embeddings", + score_norm_method="zscore", + ) # n_text * 1 + + similarity_scores_df = rank_terms_by_score(scores, terms_df) + + top_5_entries = ( + similarity_scores_df.query("logits > 0.0") # drop negatives + .groupby("library") + .apply(lambda x: x.nlargest(5, "logits")) + .reset_index(drop=True) + ) + + # Combine the term names with the scores (logits) + top_5_entries["labels"] = top_5_entries["term"] + " (" + top_5_entries["logits"].astype(str) + ")" + + # Combine the term names with the scores (logits) + top_5_entries["labels"] = top_5_entries["term"] + " (" + top_5_entries["logits"].astype(str) + ")" + + # Group by 'library' and create a list of 'labels' + grouped = top_5_entries.groupby("library") + + # Find the maximum logits value for each group + max_logits_per_group = grouped["logits"].max() + + # Sort the groups by the maximum logits value in descending order + sorted_groups = max_logits_per_group.sort_values(ascending=False) + + # Generate the final object to return, sorted by strongest hits on the library-level + structured_text = [ + { + "library": library, + "keywords": grouped.get_group(library)["labels"].tolist(), + } + for library in sorted_groups.index + ] + return structured_text + + def llm_text_to_annotations(self, adaptor, text) -> pd.Series: + """ + Embed the given text into the LLM space and return the similarity of each cell to the text. The similarity will be used as new cell-level annotation + """ + # Converts an obs index of "0", "1", ... to "TTTGCATGAGAGGC-1", ... + obs_index_col_name = adaptor.get_schema()["annotations"]["obs"]["index"] + var_index_col_name = adaptor.get_schema()["annotations"]["var"]["index"] + + if "transcriptome_embeds" in adaptor.data.obsm: + transcriptome_embeds = torch.from_numpy(adaptor.data.obsm["transcriptome_embeds"]) + # transcriptomes = transcriptomes.to(self.pl_model.model.device) + else: + assert self.pl_model is not None, "Model is not loaded, so embeddings need to be preprocessed in advance" + # Provide raw read counts, which will be processed by the model + transcriptomes = adaptor.data.to_memory(copy=True) # NOTE copy is slow! + transcriptomes.var.index = adaptor.data.var[var_index_col_name] + transcriptomes.obs.index = adaptor.data.obs[obs_index_col_name].astype(str) + transcriptome_embeds = self.pl_model.embed_transcriptomes(transcriptomes) + + texts = text.split("MINUS") + assert len(texts) in [1, 2], "At max. one MINUS sign allowed" + text_embeds = self._embed_texts(texts) + + scores, _ = score_transcriptomes_vs_texts( + transcriptome_input=transcriptome_embeds, + text_list_or_text_embeds=text_embeds, + logit_scale=self.logit_scale, + average_mode=None, + batch_size=64, + score_norm_method=None, + grouping_keys=adaptor.data.obs[obs_index_col_name].astype(str).values, + ) + + if len(text_embeds) == 2: + scores = scores[0] - scores[1] + else: + scores = scores[0] + + return pd.Series(scores.cpu().detach()) + + def _embed_texts(self, texts: List[str]): + if self.pl_model is None: + # Serialize your input data + # Send the POST request with the json-list of texts + response = requests.post(self.api_url + "/text_embedding", json=texts) + + # Check if the request was successful + if response.status_code == 200: + # Deserialize the response data + text_embeds = torch.from_numpy(pickle.loads(response.content)) + else: + logging.warning(f"Request failed with status code {response.status_code}, {response.content}") + raise RuntimeError(f"Request to model API failed: {response.status_code}") + else: + assert self.pl_model is not None, "Model is not loaded, but querying API for text embedding failed as well" + text_embeds = self.pl_model.embed_texts(texts) + + return text_embeds + + def llm_chat(self, adaptor, messages, mask): + # Extract necessary information from the request + transcriptome_embeds = adaptor.data.obsm["transcriptome_embeds"][mask].mean(axis=0).tolist() + + transcriptomes = adaptor.data.X[mask] + if transcriptomes.shape[0] > 10000: + logging.warning("Too many cells to process, sampling 10k cells") + transcriptomes = transcriptomes[np.random.choice(transcriptomes.shape[0], 10000, replace=False)] + mean_transcriptome = transcriptomes.mean(axis=0).A1 + + # Compute top genes + try: + mean_normalized_transcriptome = ( + mean_transcriptome - adaptor.data.var["log1p_normalizer"].to_numpy() + ) # normalize in logspace via difference + except KeyError: + logging.warning("No log1p_normalizer found in var. Using unnormalized log1ps to compute top genes") + mean_normalized_transcriptome = mean_transcriptome + + n_top_genes = 20 # TODO needs to become config + top_genes = ( + pd.Series(data=mean_normalized_transcriptome, index=adaptor.data.var.gene_name) + .nlargest(n_top_genes) + .index.tolist() + ) + + # Initialize the conversation + state = default_conversation.copy() + + # TODO consider including both normalized and unnormalized genes. Why? A reviewer might check whether the genes are amongst the top expressed ones + + INIT_MESSAGES = [ + { + "from": "human", + "value": f"Help me analyzing this sample of cells. Always respond in proper english sentences and in a tone of uncertainty. Start by listing the top {n_top_genes} genes.", + }, + { + "from": "gpt", + "value": f"Sure, It looks like the top normalized genes are {', '.join(top_genes)}.", + }, + ] + + for i, message in enumerate(INIT_MESSAGES + messages): + if i == 0: + assert message["from"] == "human" + llava_utils.add_text(state, message["value"], transcriptome_embeds, "Transcriptome") + else: + role = {"human": state.roles[0], "gpt": state.roles[1]}[message["from"]] + state.append_message(role, message["value"]) + state.append_message(state.roles[1], None) + + state.offset = 2 + + # TODO need to make CONTROLLER_URL flexible in there + for chunk in llava_utils.http_bot( + state, "Mistral-7B-Instruct-v0.2__03jujd8s", temperature=0.2, top_p=0.7, max_new_tokens=512 + ): + yield json.dumps({"text": chunk}).encode() + b"\x00" + + def manual_request(): + """ + unused in favor of function borrowed from llava + + """ + # Construct the payload for the worker + pload = { + "model": model, + "prompt": prompt, + "temperature": temperature, + "top_p": top_p, + "max_new_tokens": max_new_tokens, + "images": [transcriptome_embeds], + } + + # Get the worker address from the controller + worker_addr_response = requests.post(f"{controller_url}/get_worker_address", json={"model": model}) + worker_addr = worker_addr_response.json()["address"] + print(worker_addr) + + # Stream the response + with requests.post( + f"{worker_addr}/worker_generate_stream", headers={"User-Agent": "LLaVA Client"}, json=pload, stream=True + ) as r: + for chunk in r.iter_lines(delimiter=b"\x00"): + if chunk: + yield chunk + b"\x00" diff --git a/server/common/compute/llava_conversation.py b/server/common/compute/llava_conversation.py new file mode 100644 index 00000000..548cfed4 --- /dev/null +++ b/server/common/compute/llava_conversation.py @@ -0,0 +1,400 @@ +# copied from LLaVA/llava/conversation.py +import dataclasses +from enum import auto, Enum +from typing import List, Tuple +import base64 +from io import BytesIO +from PIL import Image + + +class SeparatorStyle(Enum): + """Different separator style.""" + + SINGLE = auto() + TWO = auto() + MPT = auto() + PLAIN = auto() + LLAMA_2 = auto() + + +@dataclasses.dataclass +class Conversation: + """A class that keeps all conversation history.""" + + system: str + roles: List[str] + messages: List[List[str]] + offset: int + sep_style: SeparatorStyle = SeparatorStyle.SINGLE + sep: str = "###" + sep2: str = None + version: str = "Unknown" + + skip_next: bool = False + + def get_prompt(self): + messages = self.messages + if len(messages) > 0 and type(messages[0][1]) is tuple: + messages = self.messages.copy() + init_role, init_msg = messages[0].copy() + init_msg = init_msg[0].replace("", "").strip() + if "mmtag" in self.version: + messages[0] = (init_role, init_msg) + messages.insert(0, (self.roles[0], "")) + messages.insert(1, (self.roles[1], "Received.")) + else: + messages[0] = (init_role, "\n" + init_msg) + + if self.sep_style == SeparatorStyle.SINGLE: + ret = self.system + self.sep + for role, message in messages: + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + ": " + message + self.sep + else: + ret += role + ":" + elif self.sep_style == SeparatorStyle.TWO: + seps = [self.sep, self.sep2] + ret = self.system + seps[0] + for i, (role, message) in enumerate(messages): + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + ": " + message + seps[i % 2] + else: + ret += role + ":" + elif self.sep_style == SeparatorStyle.MPT: + ret = self.system + self.sep + for role, message in messages: + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + message + self.sep + else: + ret += role + elif self.sep_style == SeparatorStyle.LLAMA_2: + wrap_sys = lambda msg: f"<>\n{msg}\n<>\n\n" if len(msg) > 0 else msg + wrap_inst = lambda msg: f"[INST] {msg} [/INST]" + ret = "" + + for i, (role, message) in enumerate(messages): + if i == 0: + assert message, "first message should not be none" + assert role == self.roles[0], "first message should come from user" + if message: + if type(message) is tuple: + message, _, _ = message + if i == 0: + message = wrap_sys(self.system) + message + if i % 2 == 0: + message = wrap_inst(message) + ret += self.sep + message + else: + ret += " " + message + " " + self.sep2 + else: + ret += "" + ret = ret.lstrip(self.sep) + elif self.sep_style == SeparatorStyle.PLAIN: + seps = [self.sep, self.sep2] + ret = self.system + for i, (role, message) in enumerate(messages): + if message: + if type(message) is tuple: + message, _, _ = message + ret += message + seps[i % 2] + else: + ret += "" + else: + raise ValueError(f"Invalid style: {self.sep_style}") + + return ret + + def append_message(self, role, message): + self.messages.append([role, message]) + + def process_image(self, image, image_process_mode, return_pil=False, image_format="PNG", max_len=1344, min_len=672): + if image_process_mode == "Transcriptome": + return image + if image_process_mode == "Pad": + + def expand2square(pil_img, background_color=(122, 116, 104)): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + + image = expand2square(image) + elif image_process_mode in ["Default", "Crop"]: + pass + elif image_process_mode == "Resize": + image = image.resize((336, 336)) + else: + raise ValueError(f"Invalid image_process_mode: {image_process_mode}") + if max(image.size) > max_len: + max_hw, min_hw = max(image.size), min(image.size) + aspect_ratio = max_hw / min_hw + shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw)) + longest_edge = int(shortest_edge * aspect_ratio) + W, H = image.size + if H > W: + H, W = longest_edge, shortest_edge + else: + H, W = shortest_edge, longest_edge + image = image.resize((W, H)) + if return_pil: + return image + else: + buffered = BytesIO() + image.save(buffered, format=image_format) + img_b64_str = base64.b64encode(buffered.getvalue()).decode() + return img_b64_str + + def get_images(self, return_pil=False): + images = [] + for i, (role, msg) in enumerate(self.messages[self.offset :]): + if i % 2 == 0: + if type(msg) is tuple: + msg, image, image_process_mode = msg + image = self.process_image(image, image_process_mode, return_pil=return_pil) + images.append(image) + return images + + # def to_gradio_chatbot(self): + # ret = [] + # for i, (role, msg) in enumerate(self.messages[self.offset :]): + # if i % 2 == 0: + # if type(msg) is tuple: + # msg, image, image_process_mode = msg + # img_b64_str = self.process_image(image, "Default", return_pil=False, image_format="JPEG") + # img_str = f'user upload image' + # msg = img_str + msg.replace("", "").strip() + # ret.append([msg, None]) + # else: + # ret.append([msg, None]) + # else: + # ret[-1][-1] = msg + # return ret + + def copy(self): + return Conversation( + system=self.system, + roles=self.roles, + messages=[[x, y] for x, y in self.messages], + offset=self.offset, + sep_style=self.sep_style, + sep=self.sep, + sep2=self.sep2, + version=self.version, + ) + + def dict(self): + if len(self.get_images()) > 0: + return { + "system": self.system, + "roles": self.roles, + "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages], + "offset": self.offset, + "sep": self.sep, + "sep2": self.sep2, + } + return { + "system": self.system, + "roles": self.roles, + "messages": self.messages, + "offset": self.offset, + "sep": self.sep, + "sep2": self.sep2, + } + + +conv_vicuna_v0 = Conversation( + system="A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.", + roles=("Human", "Assistant"), + messages=( + ("Human", "What are the key differences between renewable and non-renewable energy sources?"), + ( + "Assistant", + "Renewable energy sources are those that can be replenished naturally in a relatively " + "short amount of time, such as solar, wind, hydro, geothermal, and biomass. " + "Non-renewable energy sources, on the other hand, are finite and will eventually be " + "depleted, such as coal, oil, and natural gas. Here are some key differences between " + "renewable and non-renewable energy sources:\n" + "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable " + "energy sources are finite and will eventually run out.\n" + "2. Environmental impact: Renewable energy sources have a much lower environmental impact " + "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, " + "and other negative effects.\n" + "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically " + "have lower operational costs than non-renewable sources.\n" + "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote " + "locations than non-renewable sources.\n" + "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different " + "situations and needs, while non-renewable sources are more rigid and inflexible.\n" + "6. Sustainability: Renewable energy sources are more sustainable over the long term, while " + "non-renewable sources are not, and their depletion can lead to economic and social instability.\n", + ), + ), + offset=2, + sep_style=SeparatorStyle.SINGLE, + sep="###", +) + +conv_vicuna_v1 = Conversation( + system="A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions.", + roles=("USER", "ASSISTANT"), + version="v1", + messages=(), + offset=0, + sep_style=SeparatorStyle.TWO, + sep=" ", + sep2="", +) + +conv_llama_2 = Conversation( + system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. + +If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""", + roles=("USER", "ASSISTANT"), + version="llama_v2", + messages=(), + offset=0, + sep_style=SeparatorStyle.LLAMA_2, + sep="", + sep2="", +) + +conv_llava_llama_2 = Conversation( + system="You are a helpful language and vision assistant. " + "You are able to understand the visual content that the user provides, " + "and assist the user with a variety of tasks using natural language.", + roles=("USER", "ASSISTANT"), + version="llama_v2", + messages=(), + offset=0, + sep_style=SeparatorStyle.LLAMA_2, + sep="", + sep2="", +) + +conv_mpt = Conversation( + system="""<|im_start|>system +A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""", + roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), + version="mpt", + messages=(), + offset=0, + sep_style=SeparatorStyle.MPT, + sep="<|im_end|>", +) + +conv_llava_plain = Conversation( + system="", + roles=("", ""), + messages=(), + offset=0, + sep_style=SeparatorStyle.PLAIN, + sep="\n", +) + +conv_llava_v0 = Conversation( + system="A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.", + roles=("Human", "Assistant"), + messages=(), + offset=0, + sep_style=SeparatorStyle.SINGLE, + sep="###", +) + +conv_llava_v0_mmtag = Conversation( + system="A chat between a curious user and an artificial intelligence assistant. " + "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language." + "The visual content will be provided with the following format: visual content.", + roles=("Human", "Assistant"), + messages=(), + offset=0, + sep_style=SeparatorStyle.SINGLE, + sep="###", + version="v0_mmtag", +) + +conv_llava_v1 = Conversation( + system="A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.", + roles=("USER", "ASSISTANT"), + version="v1", + messages=(), + offset=0, + sep_style=SeparatorStyle.TWO, + sep=" ", + sep2="", +) + +conv_llava_v1_mmtag = Conversation( + system="A chat between a curious user and an artificial intelligence assistant. " + "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language." + "The visual content will be provided with the following format: visual content.", + roles=("USER", "ASSISTANT"), + messages=(), + offset=0, + sep_style=SeparatorStyle.TWO, + sep=" ", + sep2="", + version="v1_mmtag", +) + +conv_mistral_instruct = Conversation( + system="", + roles=("USER", "ASSISTANT"), + version="llama_v2", + messages=(), + offset=0, + sep_style=SeparatorStyle.LLAMA_2, + sep="", + sep2="", +) + +conv_chatml_direct = Conversation( + system="""<|im_start|>system +Answer the questions.""", + roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), + version="mpt", + messages=(), + offset=0, + sep_style=SeparatorStyle.MPT, + sep="<|im_end|>", +) + +default_conversation = conv_mistral_instruct +conv_templates = { + "default": conv_vicuna_v0, + "v0": conv_vicuna_v0, + "v1": conv_vicuna_v1, + "vicuna_v1": conv_vicuna_v1, + "llama_2": conv_llama_2, + "mistral_instruct": conv_mistral_instruct, + "chatml_direct": conv_chatml_direct, + "mistral_direct": conv_chatml_direct, + "plain": conv_llava_plain, + "v0_plain": conv_llava_plain, + "llava_v0": conv_llava_v0, + "v0_mmtag": conv_llava_v0_mmtag, + "llava_v1": conv_llava_v1, + "v1_mmtag": conv_llava_v1_mmtag, + "llava_llama_2": conv_llava_llama_2, + "mpt": conv_mpt, +} + + +if __name__ == "__main__": + print(default_conversation.get_prompt()) diff --git a/server/common/compute/llava_utils.py b/server/common/compute/llava_utils.py new file mode 100644 index 00000000..cbee45ac --- /dev/null +++ b/server/common/compute/llava_utils.py @@ -0,0 +1,222 @@ +# Derived from LLaVA/llava/serve/gradio_web_server.py and + +from .llava_conversation import default_conversation, conv_templates, SeparatorStyle + +import datetime +import json +import os +import time +import hashlib + +import requests +import logging +from pathlib import Path + + +# logger = build_logger("gradio_web_server", "gradio_web_server.log") +logger = logging.getLogger("llava_utils") + +CONTROLLER_URL = "http://cellwhisperer_llava_controller:10000" +# CONTROLLER_URL = "http://localhost:10000" +LOGDIR = Path("logs/") +LOGDIR.mkdir(exist_ok=True) + +server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" + +headers = {"User-Agent": "LLaVA Client"} + + +def get_conv_log_filename(): + t = datetime.datetime.now() + return LOGDIR / f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json" + + +# def get_model_list(): +# ret = requests.post(CONTROLLER_URL + "/refresh_all_workers") +# assert ret.status_code == 200 +# ret = requests.post(CONTROLLER_URL + "/list_models") +# models = ret.json()["models"] +# models.sort(key=lambda x: priority.get(x, x)) +# logger.info(f"Models: {models}") +# return models + + +get_window_url_params = """ +function() { + const params = new URLSearchParams(window.location.search); + url_params = Object.fromEntries(params); + console.log(url_params); + return url_params; + } +""" + + +def vote_last_response(state, vote_type, model_selector, request): + with open(get_conv_log_filename(), "a") as fout: + data = { + "tstamp": round(time.time(), 4), + "type": vote_type, + "model": model_selector, + "state": state.dict(), + "ip": request.client.host, + } + fout.write(json.dumps(data) + "\n") + + +def upvote_last_response(state, model_selector, request): + logger.info(f"upvote. ip: {request.client.host}") + vote_last_response(state, "upvote", model_selector, request) + + +def downvote_last_response(state, model_selector, request): + logger.info(f"downvote. ip: {request.client.host}") + vote_last_response(state, "downvote", model_selector, request) + + +def flag_last_response(state, model_selector, request): + logger.info(f"flag. ip: {request.client.host}") + vote_last_response(state, "flag", model_selector, request) + + +def regenerate(state, image_process_mode, request): + logger.info(f"regenerate. ip: {request.client.host}") + state.messages[-1][-1] = None + prev_human_msg = state.messages[-2] + if type(prev_human_msg[1]) in (tuple, list): + prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode) + state.skip_next = False + + +def clear_history(request): + logger.info(f"clear_history. ip: {request.client.host}") + state = default_conversation.copy() + return state + + +def add_text(state, text, image=None, image_process_mode="Transcriptome"): + if len(text) <= 0 and image is None: + raise ValueError("No input") + + text = text[:1536] # Hard cut-off + if image is not None: + text = text[:1200] # Hard cut-off for images + if "" not in text: + # text = '' + text + text = text + "\n" + text = (text, image, image_process_mode) + state.append_message(state.roles[0], text) + + +def http_bot(state, model_selector, temperature, top_p, max_new_tokens): + start_tstamp = time.time() + model_name = model_selector + + # For the first user-provided message, cut away the preamble + # if len(state.messages) == state.offset + 2: + # # First round of conversation + # if "llava" in model_name.lower() or "mistral" in model_name.lower() or "mixtral" in model_name.lower(): + # if "llama-2" in model_name.lower(): + # template_name = "llava_llama_2" + # elif "mistral" in model_name.lower() or "mixtral" in model_name.lower(): + # if "orca" in model_name.lower(): + # template_name = "mistral_orca" + # elif "hermes" in model_name.lower(): + # template_name = "chatml_direct" + # else: + # template_name = "mistral_instruct" + # elif "llava-v1.6-34b" in model_name.lower(): + # template_name = "chatml_direct" + # elif "v1" in model_name.lower(): + # if "mmtag" in model_name.lower(): + # template_name = "v1_mmtag" + # elif "plain" in model_name.lower() and "finetune" not in model_name.lower(): + # template_name = "v1_mmtag" + # else: + # template_name = "llava_v1" + # elif "mpt" in model_name.lower(): + # template_name = "mpt" + # else: + # if "mmtag" in model_name.lower(): + # template_name = "v0_mmtag" + # elif "plain" in model_name.lower() and "finetune" not in model_name.lower(): + # template_name = "v0_mmtag" + # else: + # template_name = "llava_v0" + # elif "mpt" in model_name: + # template_name = "mpt_text" + # elif "llama-2" in model_name: + # template_name = "llama_2" + # else: + # template_name = "vicuna_v1" + # new_state = conv_templates[template_name].copy() + # new_state.append_message(new_state.roles[0], state.messages[-2][1]) + # new_state.append_message(new_state.roles[1], None) + # state = new_state + + # Query worker address + ret = requests.post(CONTROLLER_URL + "/get_worker_address", json={"model": model_name}) + worker_addr = ret.json()["address"] + logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}") + + # No available worker + if worker_addr == "": + yield server_error_msg + return + + # Construct prompt + prompt = state.get_prompt() + + all_images = state.get_images(return_pil=True) + + # Make requests + pload = { + "model": model_name, + "prompt": prompt, + "temperature": float(temperature), + "top_p": float(top_p), + "max_new_tokens": min(int(max_new_tokens), 1536), + "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2, + "images": f"List of {len(state.get_images())} images (transcriptomes)", + } + logger.info(f"==== request ====\n{pload}") + + pload["images"] = state.get_images() + + try: + # Stream output + response = requests.post( + worker_addr + "/worker_generate_stream", headers=headers, json=pload, stream=True, timeout=10 + ) + for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): + if chunk: + data = json.loads(chunk.decode()) + if data["error_code"] == 0: + output = data["text"][len(prompt) :].strip() + yield output + else: + yield data["error_code"] + return + time.sleep(0.03) + else: + yield output + except requests.exceptions.RequestException as e: + yield server_error_msg + f" ({e})" + return + + finish_tstamp = time.time() + logger.info(f"{output}") + + state.messages[-1][-1] = output + with open(get_conv_log_filename(), "a") as fout: + data = { + "tstamp": round(finish_tstamp, 4), + "type": "chat", + "model": model_name, + "start": round(start_tstamp, 4), + "finish": round(finish_tstamp, 4), + "messages": state.messages, + "state": state.dict(), + # "images": all_image_hash, + # "ip": request.client.host, + } + fout.write(json.dumps(data) + "\n") diff --git a/server/common/config/client_config.py b/server/common/config/client_config.py index bd741ec3..f46f5f85 100644 --- a/server/common/config/client_config.py +++ b/server/common/config/client_config.py @@ -39,6 +39,7 @@ def get_client_config(app_config, data_adaptor): "diffexp_lfc_cutoff": dataset_config.diffexp__lfc_cutoff, "backed": server_config.adaptor__anndata_adaptor__backed, "disable-diffexp": not dataset_config.diffexp__enable, + "disable-llmembs": not dataset_config.llmembs__enable, "annotations": False, "annotations_file": None, "annotations_dir": None, diff --git a/server/common/config/dataset_config.py b/server/common/config/dataset_config.py index 59307aa3..c0f1acd4 100644 --- a/server/common/config/dataset_config.py +++ b/server/common/config/dataset_config.py @@ -38,6 +38,9 @@ def __init__(self, tag, app_config, default_config): self.diffexp__lfc_cutoff = default_config["diffexp"]["lfc_cutoff"] self.diffexp__top_n = default_config["diffexp"]["top_n"] + self.llmembs__enable = default_config["llmembs"]["enable"] + self.llmembs__model_checkpoint = default_config["llmembs"]["model_checkpoint"] + self.X_approximate_distribution = default_config["X_approximate_distribution"] except KeyError as e: @@ -52,6 +55,7 @@ def complete_config(self, context): self.handle_user_annotations(context) self.handle_embeddings() self.handle_diffexp(context) + self.handle_llmembs(context) self.handle_X_approximate_distribution() def get_data_adaptor(self): @@ -181,6 +185,10 @@ def handle_diffexp(self, context): "CAUTION: due to the size of your dataset, " "running differential expression may take longer or fail." ) + def handle_llmembs(self, context): + self.validate_correct_type_of_configuration_attribute("llmembs__enable", bool) + self.validate_correct_type_of_configuration_attribute("llmembs__model_checkpoint", str) + def handle_X_approximate_distribution(self): self.validate_correct_type_of_configuration_attribute("X_approximate_distribution", str) if self.X_approximate_distribution not in ["auto", "normal", "count"]: diff --git a/server/common/rest.py b/server/common/rest.py index 78fe467c..8935408b 100644 --- a/server/common/rest.py +++ b/server/common/rest.py @@ -5,7 +5,7 @@ import zlib import json -from flask import make_response, jsonify, current_app, abort +from flask import make_response, jsonify, current_app, abort, stream_with_context from urllib.parse import unquote from server.common.config.client_config import get_client_config @@ -397,3 +397,118 @@ def summarize_var_post(request, data_adaptor): key = request.args.get("key", default=None) return summarize_var_helper(request, data_adaptor, key, request.get_data()) + + +def llm_embeddings_text_post(request, data_adaptor): + """ + Given a text description, return a cell annotation + """ + try: + llm_embeddings_text_post.counter + except AttributeError: + llm_embeddings_text_post.counter = 0 + + if not data_adaptor.dataset_config.llmembs__enable: + return abort(HTTPStatus.NOT_IMPLEMENTED) + + args = request.get_json() + try: + text = args.get("text") + + if text is None: + return abort_and_log(HTTPStatus.BAD_REQUEST, "missing required parameter text") + + except (KeyError, TypeError) as e: + return abort_and_log(HTTPStatus.BAD_REQUEST, str(e), include_exc_info=True) + + try: + labels = data_adaptor.compute_llmembs_text_to_annotations(text) + + # compute a string-like hash and take the first 5 characters + + annotation_name = f"{llm_embeddings_text_post.counter}_{text.replace(' ', '_')[:20]}" + llm_embeddings_text_post.counter += 1 + + labels.name = annotation_name + index_name = data_adaptor.parameters.get("obs_names") + labels.index = data_adaptor.data.obs[index_name] + + if request.accept_mimetypes.accept_json: + return make_response(jsonify(labels.to_dict()), HTTPStatus.OK) + + fbs = data_adaptor.annotation_to_fbs_matrix( + Axis.OBS, [annotation_name], labels.to_frame() + ) # same as calling encode_matrix_fbs directly + + return make_response(fbs, HTTPStatus.OK, {"Content-Type": "application/octet-stream"}) + + except (ValueError, DisabledFeatureError, FilterError, ExceedsLimitError) as e: + return abort_and_log(HTTPStatus.BAD_REQUEST, str(e), include_exc_info=True) + except JSONEncodingValueError: + # JSON encoding failure, usually due to bad data. Just let it ripple up + # to default exception handler. + current_app.logger.warning(JSON_NaN_to_num_warning_msg) + raise + + +def llm_embeddings_obs_post(request, data_adaptor): + """ + Given a set of cells, return a text description for them + """ + if not data_adaptor.dataset_config.llmembs__enable: + return abort(HTTPStatus.NOT_IMPLEMENTED) + + args = request.get_json() + try: + selection_filter = args.get("cellSelection", {"filter": {}})["filter"] + + if selection_filter is None: + return abort_and_log(HTTPStatus.BAD_REQUEST, "missing required parameter set1") + if Axis.VAR in selection_filter: + return abort_and_log(HTTPStatus.BAD_REQUEST, "var axis filter not enabled") + + except (KeyError, TypeError) as e: + return abort_and_log(HTTPStatus.BAD_REQUEST, str(e), include_exc_info=True) + + try: + model_result = data_adaptor.llmembs_obs_to_text(selection_filter) + return make_response(model_result, HTTPStatus.OK, {"Content-Type": "application/json"}) + except (ValueError, DisabledFeatureError, FilterError, ExceedsLimitError) as e: + return abort_and_log(HTTPStatus.BAD_REQUEST, str(e), include_exc_info=True) + except JSONEncodingValueError: + # JSON encoding failure, usually due to bad data. Just let it ripple up + # to default exception handler. + current_app.logger.warning(JSON_NaN_to_num_warning_msg) + raise + + +def llm_embeddings_chat_post(request, data_adaptor): + if not data_adaptor.dataset_config.llmembs__enable: + return abort(HTTPStatus.NOT_IMPLEMENTED) + + args = request.get_json() + try: + selection_filter = args.get("cellSelection", {"filter": {}})["filter"] + + # TODO more filters may be appropriate + if selection_filter is None: + return abort_and_log(HTTPStatus.BAD_REQUEST, "missing required parameter cellSelection") + + except (KeyError, TypeError) as e: + return abort_and_log(HTTPStatus.BAD_REQUEST, str(e), include_exc_info=True) + + try: + chat_generator = data_adaptor.establish_llmembs_chat(args, selection_filter) + + response = make_response( + stream_with_context(chat_generator), HTTPStatus.OK, {"Content-Type": "application/json"} + ) + response.headers["Content-Encoding"] = "identity" # Explicitly set to 'identity' to indicate no compression + return response + except (ValueError, DisabledFeatureError, FilterError, ExceedsLimitError) as e: + return abort_and_log(HTTPStatus.BAD_REQUEST, str(e), include_exc_info=True) + except JSONEncodingValueError: + # JSON encoding failure, usually due to bad data. Just let it ripple up + # to default exception handler. + current_app.logger.warning(JSON_NaN_to_num_warning_msg) + raise diff --git a/server/data_anndata/anndata_adaptor.py b/server/data_anndata/anndata_adaptor.py index 31e5398b..0df5b238 100644 --- a/server/data_anndata/anndata_adaptor.py +++ b/server/data_anndata/anndata_adaptor.py @@ -1,3 +1,4 @@ +import logging import warnings import anndata @@ -7,15 +8,17 @@ from scipy import sparse import server.common.compute.diffexp_generic as diffexp_generic +from server.common.compute.cellwhisperer_wrapper import CellWhispererWrapper import server.common.compute.estimate_distribution as estimate_distribution from server.common.colors import convert_anndata_category_colors_to_cxg_category_colors from server.common.constants import Axis, MAX_LAYOUTS, XApproximateDistribution from server.common.corpora import corpora_get_props_from_anndata -from server.common.errors import PrepareError, DatasetAccessError +from server.common.errors import PrepareError, DatasetAccessError, FilterError from server.common.utils.type_conversion_utils import get_schema_type_hint_of_array from server.data_common.data_adaptor import DataAdaptor from server.common.fbs.matrix import encode_matrix_fbs + anndata_version = version.parse(str(anndata.__version__)).release @@ -32,6 +35,9 @@ def __init__(self, data_locator, app_config=None, dataset_config=None): self.X_approximate_distribution = None self._load_data(data_locator) self._validate_and_initialize() + self.cellwhisperer = CellWhispererWrapper(self.dataset_config.llmembs__model_checkpoint) + + self.cellwhisperer.preprocess_data(self) # required to cache all the keywords def cleanup(self): pass @@ -333,6 +339,35 @@ def compute_diffexp_ttest(self, maskA, maskB, top_n=None, lfc_cutoff=None): lfc_cutoff = self.dataset_config.diffexp__lfc_cutoff return diffexp_generic.diffexp_ttest(self, maskA, maskB, top_n, lfc_cutoff) + def compute_llmembs_obs_to_text(self, mask): + return self.cellwhisperer.llm_obs_to_text(self, mask) + + def compute_llmembs_text_to_annotations(self, text): + """ + Computes an LLM embedding for each cell and compares it to the embedding of the text and returns the distance + + :param text: the text to embed + :return: pandas Series of cell embeddings + """ + return self.cellwhisperer.llm_text_to_annotations(self, text=text) + + def establish_llmembs_chat(self, data, obs_filter): + """ + Computes the mean expression of each gene in the dataset for the specified observations and runs the + embedding LLM to generate a text + :param obs_filter: filter: dictionary with filter params for set of observations (cells) + :return: generator on text + """ + if Axis.VAR in obs_filter: + raise FilterError("Observation filters may not contain variable conditions") + try: + shape = self.get_shape() + obs_mask = self._axis_filter_to_mask(Axis.OBS, obs_filter["obs"], shape[0]) + except (KeyError, IndexError): + raise FilterError("Error parsing filter") + + return self.cellwhisperer.llm_chat(self, data["messages"], obs_mask) + def get_colors(self): return convert_anndata_category_colors_to_cxg_category_colors(self.data) diff --git a/server/data_common/data_adaptor.py b/server/data_common/data_adaptor.py index f750c550..6873119b 100644 --- a/server/data_common/data_adaptor.py +++ b/server/data_common/data_adaptor.py @@ -3,6 +3,7 @@ import numpy as np import pandas as pd from scipy import sparse +from server.common.rest import llm_embeddings_obs_post from server_timing import Timing as ServerTiming from server.common.config.app_config import AppConfig @@ -339,6 +340,18 @@ def diffexp_topN(self, obsFilterA, obsFilterB, top_n=None): def compute_diffexp_ttest(self, maskA, maskB, top_n, lfc_cutoff): pass + @abstractmethod + def compute_llmembs_obs_to_text(self, mask): + pass + + @abstractmethod + def compute_llmembs_text_to_annotations(self, text): + pass + + @abstractmethod + def establish_llmembs_chat(self, data, obs_filter): + pass + @staticmethod def normalize_embedding(embedding): """Normalize embedding layout to meet client assumptions. @@ -421,3 +434,28 @@ def summarize_var(self, method, filter, query_hash): col_idx = pd.Index([query_hash]) return encode_matrix_fbs(mean, col_idx=col_idx, row_idx=None) + + def llmembs_obs_to_text(self, obsFilter): + """ + Computes the mean expression of each gene in the dataset for the specified observations and runs the + embedding LLM to generate a text + TODO this function might be a bit redundant (i.e. why not directly call compute_llmbembs_obs_to_text from rest.py?) + + + :param obsFilter: filter: dictionary with filter params for set of observations (cells) + :return: top N genes and corresponding stats + """ + if Axis.VAR in obsFilter: + raise FilterError("Observation filters may not contain variable conditions") + try: + shape = self.get_shape() + obs_mask = self._axis_filter_to_mask(Axis.OBS, obsFilter["obs"], shape[0]) + except (KeyError, IndexError): + raise FilterError("Error parsing filter") + + result = self.compute_llmembs_obs_to_text(mask=obs_mask) + + try: + return jsonify_strict(result) + except ValueError: + raise JSONEncodingValueError("Error encoding LLM Embeddings text result to JSON") diff --git a/server/default_config.py b/server/default_config.py index 6e8527f0..147fece4 100644 --- a/server/default_config.py +++ b/server/default_config.py @@ -67,6 +67,9 @@ lfc_cutoff: 0.01 top_n: 10 + llmembs: + enable: true + model_checkpoint: ~/cellwhisperer/results/models/jointemb/03jujd8s.ckpt X_approximate_distribution: auto external: diff --git a/server/requirements.txt b/server/requirements.txt index b383c9aa..9bae79dc 100644 --- a/server/requirements.txt +++ b/server/requirements.txt @@ -19,4 +19,4 @@ pandas<2.0.0 PyYAML>=5.4 # CVE-2020-14343 requests>=2.22.0 s3fs==0.4.2 -scipy>=1.4 \ No newline at end of file +scipy>=1.4 diff --git a/test/unit/compute/test_llmembs.py b/test/unit/compute/test_llmembs.py new file mode 100644 index 00000000..3db3c4c3 --- /dev/null +++ b/test/unit/compute/test_llmembs.py @@ -0,0 +1,68 @@ +import unittest + +import numpy as np + +from server.common.compute import llm_embeddings +from server.data_common.matrix_loader import MatrixDataLoader +from test.unit import app_config +from test import PROJECT_ROOT + + +class LLMEmbsTest(unittest.TestCase): + """Tests the llmembs returns the expected results for one test case, using the h5ad + adaptor types and different algorithms.""" + + def load_dataset(self, path, extra_server_config={}, extra_dataset_config={}): + config = app_config(path, extra_server_config=extra_server_config, extra_dataset_config=extra_dataset_config) + loader = MatrixDataLoader(path) + adaptor = loader.open(config) + return adaptor + + def get_mask(self, adaptor, start, stride): + """Simple function to return a mask or rows""" + rows = adaptor.get_shape()[0] + sel = list(range(start, rows, stride)) + mask = np.zeros(rows, dtype=bool) + mask[sel] = True + return mask + + def compare_llmembs_results(self, results, expects): + self.assertEqual(len(results), len(expects)) + for result, expect in zip(results, expects): + self.assertEqual(result[0], expect[0]) + self.assertTrue(np.isclose(result[1], expect[1], 1e-6, 1e-4)) + self.assertTrue(np.isclose(result[2], expect[2], 1e-6, 1e-4)) + self.assertTrue(np.isclose(result[3], expect[3], 1e-6, 1e-4)) + + def check_results(self, results): + """Checks the results for a specific set of rows selections""" + + self.assertIn("text", results) + self.assertIsInstance(results["text"], str) + # expects = [] + + # self.compare_llmembs_results(results, expects) + + def test_anndata_default(self): + """Test an anndata adaptor with its default llmembs algorithm (llmembs_generic)""" + adaptor = self.load_dataset(f"{PROJECT_ROOT}/example-dataset/pbmc3k.h5ad") + mask = self.get_mask(adaptor, 1, 10) + results = adaptor.compute_llmembs_obs_to_text(mask) + self.check_results(results) + + +def test_h5ad_default(self): + """Test a h5ad adaptor with its default llmembs algorithm (llmembs_cxg)""" + adaptor = self.load_dataset(f"{PROJECT_ROOT}/example-dataset/pbmc3k.h5ad") + mask = self.get_mask(adaptor, 1, 10) + + # run it through the adaptor + results = adaptor.compute_llmembs_obs_to_text(mask) + self.check_results(results) + + # run it directly + results = llm_embeddings.llm_obs_to_text(adaptor, mask) + self.check_results(results) + + +# TODO test also compute_llmembs_text_to_annotations