import configparser
import torch
from transformers import AutoImageProcessor, AutoModel
from qdrant_client import QdrantClient, models
from collections import Counter
from PIL import Image
from ollama import generate
import streamlit as st
import requests
from io import BytesIO
##############################################
#                Config Setup                #
##############################################
st.set_page_config(
    page_title="Medical Assistant for Dermatoscopic Images",
    layout="wide",
    initial_sidebar_state="expanded",
    page_icon="https://storage.googleapis.com/demo-skin-cancer/qdrant-logo.png",
)
config = configparser.ConfigParser()
config.read("config.ini")
# Qdrant
qdrant_cloud_url = config["qdrant"]["cloud_url"]
qdrant_api = config["secrets"]["api_key"]
client = QdrantClient(url=qdrant_cloud_url, api_key=qdrant_api)
collection_name = "melanoma_main_collection"
# Ollama
OLLAMA_MODEL = "deepseek-llm"
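# Expected config.ini layout (an assumption inferred from the keys read above;
# the actual file is not part of this source):
#
#   [qdrant]
#   cloud_url = https://YOUR-CLUSTER-URL.cloud.qdrant.io:6333
#
#   [secrets]
#   api_key = YOUR-QDRANT-API-KEY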
##############################################
#            Model and Processor             #
##############################################
@st.cache_resource
def load_model_and_processor():
    """Load the DINOv2 image processor and model once per Streamlit session."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    processor = AutoImageProcessor.from_pretrained("facebook/dinov2-large")
    model = AutoModel.from_pretrained("facebook/dinov2-large").to(device)
    return device, processor, model

device, processor, model = load_model_and_processor()
##############################################
#         Embedding Extraction (ViT)         #
##############################################
def get_embeddings_query(image):
    """Extract the CLS embedding, average-pooled patch embedding, and raw patch embeddings."""
    image = image.convert("RGB")  # ensure 3 channels (e.g., for RGBA PNG uploads)
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs).last_hidden_state
    outputs = outputs.cpu().numpy()
    patches = outputs[0, 1:, :]    # all patch tokens (everything but the CLS token)
    cls_output = outputs[0, 0, :]  # the CLS token
    pooled_patches = patches.mean(axis=0)
    del outputs, inputs
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return cls_output, pooled_patches, patches
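# Shape note: for facebook/dinov2-large the hidden size is 1024, so cls_output
# and pooled_patches are 1024-d vectors. With the processor's default 224x224
# resize and a patch size of 14, `patches` is a (256, 1024) array; the exact
# patch count depends on the preprocessing settings.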
##############################################
#           Qdrant Search Helpers            #
##############################################
def search_knn(image_embedding, k=10):
    """Search the Qdrant collection using the single 'cls' named vector."""
    try:
        result = client.query_points(
            collection_name=collection_name,
            query=image_embedding,
            using="cls",
            limit=k,
            with_payload=True,
        ).points
        return result
    except Exception:
        st.error("Error: Unable to retrieve data from the knowledge base. Please check your connection and try again.")
        return []
def search_rerank_knn(image_embedding, patch_embeddings, prefetch_k=20, k=10):
    """
    Two-stage retrieval:
      1) Prefetch candidates with the 'pooled_patches' vector.
      2) Rerank them with the patch-level 'patches' embeddings.
    """
    try:
        result = client.query_points(
            collection_name=collection_name,
            prefetch=models.Prefetch(
                query=image_embedding,
                using="pooled_patches",
                limit=prefetch_k,
            ),
            query=patch_embeddings,
            using="patches",
            limit=k,
            with_payload=True,
        ).points
        return result
    except Exception:
        st.error("Error: Unable to retrieve data from the knowledge base. Please check your connection and try again.")
        return []
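# Both helpers above assume the collection exposes three named vectors:
# "cls", "pooled_patches", and a "patches" multivector compared with MaxSim.
# The function below is a minimal, uncalled sketch of that setup; the vector
# size and distance metric are assumptions, not confirmed by this app.
def _create_collection_sketch():
    client.create_collection(
        collection_name=collection_name,
        vectors_config={
            "cls": models.VectorParams(size=1024, distance=models.Distance.COSINE),
            "pooled_patches": models.VectorParams(size=1024, distance=models.Distance.COSINE),
            "patches": models.VectorParams(
                size=1024,
                distance=models.Distance.COSINE,
                multivector_config=models.MultiVectorConfig(
                    comparator=models.MultiVectorComparator.MAX_SIM
                ),
            ),
        },
    )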
##############################################
#             KNN Classification             #
##############################################
def KNN_classifier_dx_with_rerank(image_embedding, patch_embeddings, k=10):
    """
    Two-stage KNN classification: retrieve the k nearest neighbors via the
    rerank search, then take a majority vote over their 'dx' labels.
    """
    points = search_rerank_knn(image_embedding, patch_embeddings, k * 2, k)
    if not points:
        return None, 0
    diagnoses = [p.payload.get("dx", "") for p in points]
    counter = Counter(diagnoses)
    two_most_common = counter.most_common(2)
    # On a tie between the top two classes, widen the neighborhood (up to a maximum)
    if len(two_most_common) > 1 and two_most_common[0][1] == two_most_common[1][1] and k <= 50:
        return KNN_classifier_dx_with_rerank(image_embedding, patch_embeddings, k + 5)
    dx_class = two_most_common[0][0]
    # Divide by the number of points actually returned, which may be fewer than k
    dx_confidence = two_most_common[0][1] / len(points)
    return dx_class, dx_confidence
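# Worked example: with k=10 and neighbor labels ['nv']*6 + ['mel']*4, the vote
# returns dx_class='nv' with dx_confidence=0.60; a 5-5 tie would instead
# re-run the search with k=15.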
##############################################
#          Simple RAG via DeepSeek           #
##############################################
def rag_deepseek(dx_class, sureness):
    """
    Send a simple prompt to deepseek-llm via Ollama and stream the response.
    """
    prompt = f"""
You are an AI medical assistant designed to help medical professionals with skin lesion categorization.
You're providing information to assist them in the diagnosis of skin lesions.

## Context:
- The user is a doctor who has uploaded an image of a skin lesion to classify.
- The app uses vector search to find similar images from a large collection of multi-source dermatoscopic images of common pigmented skin lesions.

## Task Context
The classifier result for the uploaded image is:
- Diagnosis: {dx_class}
- Confidence Score: {sureness:.2%}

## Task:
Generate a structured and informative piece of text that includes:
1. A disclaimer emphasizing that this AI tool is only suitable as an assistance to a medical expert.
2. A clear presentation of the diagnosis and confidence score.
3. A brief but precise explanation of the diagnosed condition, ensuring it is medically relevant.

## Output Formatting:
- Do not use a personal tone, greetings, or a chatty manner.
- The disclaimer must appear at the beginning of the response.
- The diagnosis and confidence score must be clearly stated.
- The response should be formatted in markdown for readability.
- Use a neutral, professional tone.

## Information about the uploaded image classification to a doctor:
"""
    try:
        chunks = generate(model=OLLAMA_MODEL, prompt=prompt, options={"seed": 42, "temperature": 0}, stream=True)
        for chunk in chunks:
            yield chunk["response"]
    except Exception:
        yield "AI assistant error. Please check that the Ollama server is running."
##############################################
#             Streamlit Frontend             #
##############################################
st.title("Medical Assistant for Dermatoscopic Images")
st.markdown("### Upload an image to get an AI medical assistant diagnosis or retrieve similar cases from the database.")
# Initialize session state variables
if 'image' not in st.session_state:
    st.session_state.image = None
if 'diagnosis_text' not in st.session_state:
    st.session_state.diagnosis_text = None
if 'similar_images' not in st.session_state:
    st.session_state.similar_images = None

# Upload Section: radio selector for the image source
image_source = st.radio(
    "Select image source:",
    ["Upload Image", "Enter Image Public URL", "Choose Example Images"],
    horizontal=True,
)
# OPTION 1: Upload image
if image_source == "Upload Image":
    image_file = st.file_uploader("Choose an image", type=["jpg", "jpeg", "png"])
    if image_file:
        # Clear previous results when a new image is uploaded
        st.session_state.diagnosis_text = None
        st.session_state.similar_images = None
        st.session_state.image = Image.open(image_file)
# OPTION 2: Image from URL
elif image_source == "Enter Image Public URL":
    image_url = st.text_input("Enter the public URL of an image:", "")
    if image_url:
        try:
            response = requests.get(image_url, timeout=10)
            response.raise_for_status()
            # Clear previous results when a new image is loaded from a URL
            st.session_state.diagnosis_text = None
            st.session_state.similar_images = None
            st.session_state.image = Image.open(BytesIO(response.content))
        except Exception as e:
            st.error(f"Error loading image from URL: {e}")
# OPTION 3: Example images
elif image_source == "Choose Example Images":
    st.markdown("### Select an example:")
    # Example images with their corresponding diagnoses
    example_images = [
        {"url": "https://storage.googleapis.com/demo-skin-cancer/test_ISIC_0024431_HAM_0002134.jpeg",
         "diagnosis": "Basal cell carcinoma"},
        {"url": "https://storage.googleapis.com/demo-skin-cancer/test_ISIC_0031918_HAM_0003141.jpeg",
         "diagnosis": "Actinic keratosis"},
        {"url": "https://storage.googleapis.com/demo-skin-cancer/test_ISIC_0025901_HAM_0005562.jpeg",
         "diagnosis": "Melanocytic nevus"},
    ]

    # Cache example images so they are downloaded only once
    @st.cache_data
    def load_example_image(url):
        try:
            response = requests.get(url, timeout=10)
            response.raise_for_status()
            return Image.open(BytesIO(response.content))
        except Exception as e:
            st.error(f"Error loading example image: {e}")
            return None

    # Display the example images in a row
    cols = st.columns(len(example_images))
    for i, (col, img_data) in enumerate(zip(cols, example_images)):
        with col:
            example_img = load_example_image(img_data["url"])
            if example_img:
                st.image(example_img, width=400)
                st.write(f"**Diagnosis:** {img_data['diagnosis']}")
                if st.button(f"Select Image {i+1}", key=f"example_{i}"):
                    # Clear previous results when an example image is selected
                    st.session_state.similar_images = None
                    st.session_state.diagnosis_text = None
                    st.session_state.image = example_img
# If an image is selected (from any source), proceed with the analysis
if st.session_state.image:
    # Get embeddings for the selected image
    with st.spinner("Extracting image features..."):
        cls_output, pooled_patches, patches = get_embeddings_query(st.session_state.image)

    st.markdown("---")
    # Show the last selected image
    st.subheader("Selected Image")
    st.image(st.session_state.image, width=400)

    # ---------------- Diagnosis Section ----------------
    st.subheader("AI Assistant's Diagnosis:")
    st.markdown(
        "Click **Run Diagnosis** to get a skin lesion diagnosis from the AI assistant:"
    )
    # Two-column layout: left for the button, right for the streaming output
    col_diag_button, col_diag_output = st.columns([1, 2])
    with col_diag_button:
        run_diag = st.button("Run Diagnosis", key="diag_button")
        st.image("https://storage.googleapis.com/demo-skin-cancer/DeepSeek_RAG_logo.png", width=300)
    with col_diag_output:
        diag_output = st.empty()  # Reserved placeholder for the diagnosis result

    # Display the existing diagnosis if one is available
    if st.session_state.diagnosis_text:
        diag_output.markdown(st.session_state.diagnosis_text)

    if run_diag:
        diag_output.empty()
        with st.spinner("Processing image for diagnosis..."):
            melanoma_class, sureness = KNN_classifier_dx_with_rerank(pooled_patches, patches)
        if melanoma_class:
            diag_text = ""
            # Stream the diagnosis text into the fixed output container
            with st.spinner("Generating diagnosis..."):
                for chunk_text in rag_deepseek(melanoma_class, sureness):
                    diag_text += chunk_text
                    diag_output.markdown(diag_text)
            # Store the generated diagnosis text in session state
            st.session_state.diagnosis_text = diag_text
        else:
            st.error("Classification error. Please try again.")
            if st.button("Retry"):
                st.rerun()

    st.markdown("---")
    # ---------------- Similar Cases Section ----------------
    st.subheader("Similar Cases")
    st.markdown(
        "Click **Show Top-5 Similar Images** to view cases from the database that are similar to the selected image."
    )

    def display_images():
        main_col, legend_col = st.columns([2, 1])
        with main_col:
            st.subheader("Retrieved Information:")
            for point in st.session_state.similar_images:
                with st.container():
                    image_col, info_col = st.columns([1, 1])
                    with image_col:
                        st.image(point.payload["url"], width=400)
                    with info_col:
                        # Guard against a missing 'age' field (int('N/A') would raise)
                        age = point.payload.get('age')
                        st.markdown(f"""
| **Attribute** | **Value** |
|----------------|-----------|
| **Age:** | {int(age) if age is not None else 'N/A'} |
| **Sex:** | {point.payload.get('sex', 'N/A')} |
| **Diagnosis:** | {point.payload.get('dx', 'N/A')} |
| **Type:** | {point.payload.get('dx_type', 'N/A')} |
| **Localization:** | {point.payload.get('localization', 'N/A')} |
""")
        # Legend column on the right
        with legend_col:
            st.subheader("Legend")
            st.markdown("""
#####
##### Diagnosis Confirmation Types:
- **histo**: Histopathology confirmed
- **follow_up**: Follow-up confirmed
- **consensus**: Expert consensus
- **confocal**: Confocal microscopy confirmed
#####
##### Localizations:
- **face**: Face area
- **trunk**: Torso area
- **scalp**: Head/scalp area
- **acral**: Hands/feet
- **back**: Back area
- **abdomen**: Stomach area
- **chest**: Chest area
- **upper extremity**: Arms area
- **lower extremity**: Legs area
- **neck**: Neck area
- **genital**: Genital area
- **ear**: Ear area
- **foot**: Foot area
- **hand**: Hand area
- **unknown**: Unknown area
""")
    run_top_similar = st.button("Show Top-5 Similar Images", key="similar_button")
    if st.session_state.similar_images and not run_top_similar:
        display_images()
    if run_top_similar:
        # Clear previously displayed images when the button is pressed
        st.session_state.similar_images = None
        with st.spinner("Searching for similar cases..."):
            similar_images = search_knn(cls_output, k=5)
        if similar_images:
            st.session_state.similar_images = similar_images
            display_images()
        else:
            st.error("Search error. Please try again.")
            if st.button("Retry"):
                st.rerun()
else:
    st.info("Please select an image using one of the methods above to begin.")