Skip to content

Commit af7347f

Browse files
Merge pull request #610 from Steinbeck-Lab/development
fix: decimer to work on low res images
2 parents b97c584 + d02472d commit af7347f

File tree

9 files changed

+196
-42
lines changed

9 files changed

+196
-42
lines changed

.github/workflows/test.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ jobs:
4646
pip3 install --no-cache-dir -r requirements.txt
4747
pip install git+https://github.com/Kohulan/DECIMER-Image-Segmentation.git@bbox --no-deps
4848
pip3 install --no-deps decimer
49-
pip3 install --no-deps STOUT-pypi==2.0.5
5049
pip install flake8 pytest
5150
pip install pytest-cov
5251
wget -O surge "https://github.com/StructureGenerator/surge/releases/download/v1.0/surge-linux-v1.0"

Dockerfile

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,11 @@ RUN conda install -c conda-forge python=${PYTHON_VERSION} sqlite --force-reinsta
3131
pip3 install --no-cache-dir -r requirements.txt && \
3232
# Install specific packages without dependencies
3333
pip3 install --no-cache-dir --no-deps \
34-
decimer-segmentation==1.1.3 \
35-
decimer==2.3.0 \
34+
git+https://github.com/Kohulan/DECIMER-Image-Segmentation.git@bbox \
35+
decimer==2.7.1 \
3636
chembl_structure_pipeline
3737

3838

3939
COPY ./app ./app
4040

41-
CMD uvicorn app.main:app --host 0.0.0.0 --port 80 --workers ${WORKERS}
41+
CMD ["sh", "-c", "uvicorn app.main:app --host 0.0.0.0 --port 80 --workers ${WORKERS}"]

Dockerfile.lite

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,4 @@ RUN conda install -c conda-forge python=${PYTHON_VERSION} sqlite --force-reinsta
3535

3636
COPY ./app /code/app
3737

38-
CMD uvicorn app.main:app --host 0.0.0.0 --port 80 --workers ${WORKERS}
38+
CMD ["sh", "-c", "uvicorn app.main:app --host 0.0.0.0 --port 80 --workers ${WORKERS}"]

app/modules/decimer.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -106,21 +106,38 @@ def get_predicted_segments(path: str) -> str:
106106
return ".".join(smiles_predicted)
107107

108108

109-
def get_predicted_segments_from_file(content: any, filename: str) -> tuple:
110-
"""Takes an image file path and returns a set of paths and image names of.
109+
def get_predicted_segments_from_file(content: any, filename: str) -> str:
110+
"""Takes an image file content and filename, saves it temporarily, and returns SMILES prediction.
111111
112-
segmented images.
112+
If the image dimensions are below 500 pixels, uses predict_SMILES directly.
113+
Otherwise, uses segmentation approach.
113114
114115
Args:
115-
input_path (str): the path of an image.
116+
content (any): The image file content.
117+
filename (str): The filename to save the content to.
116118
117119
Returns:
118-
image_name (str): image file name.
119-
segments (list): a set of segmented images.
120+
str: Predicted SMILES string.
120121
"""
121122

123+
# Write the content to file and ensure it's closed
122124
with open(filename, "wb") as f:
123125
f.write(content)
124-
smiles = get_predicted_segments(filename)
125-
os.remove(filename)
126+
127+
try:
128+
# Check image dimensions
129+
img = Image.open(filename)
130+
width, height = img.size
131+
img.close() # Close the image to free resources
132+
133+
# If image is small (below 500 pixels in either dimension), use direct prediction
134+
if width < 500 or height < 500:
135+
smiles = predict_SMILES(filename)
136+
else:
137+
smiles = get_predicted_segments(filename)
138+
126139
return smiles
140+
finally:
141+
# Ensure the temporary file is always removed
142+
if os.path.exists(filename):
143+
os.remove(filename)

requirements_lite.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ prometheus-fastapi-instrumentator
1414
pystow>=0.4.9
1515
python-multipart
1616
selfies>=2.1.1
17-
tensorflow==2.12.0
17+
tensorflow==2.15.1
18+
Keras-Preprocessing==1.1.2
1819
unicodedata2==15.0.0
1920
websockets==10.4
2021
mapchiral

tests/small_molecule.png

2.84 KB
Loading

tests/test_decimer.py

Lines changed: 165 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
from __future__ import annotations
22

33
import os
4+
import tempfile
5+
from unittest.mock import patch
46

57
import pytest
8+
from PIL import Image
69

710
from app.modules.decimer import convert_image
811
from app.modules.decimer import get_predicted_segments
@@ -29,34 +32,184 @@ def sample_image_path():
2932
return os.path.join(TEST_FILES_DIR, "segment_sample.png")
3033

3134

35+
@pytest.fixture(scope="module")
36+
def small_image_path():
37+
"""Small image (400x300) - should trigger direct prediction"""
38+
return os.path.join(TEST_FILES_DIR, "small_molecule.png")
39+
40+
41+
@pytest.fixture(scope="module")
42+
def tiny_image_path():
43+
"""Tiny image (200x150) - should trigger direct prediction"""
44+
return os.path.join(TEST_FILES_DIR, "tiny_molecule.png")
45+
46+
47+
@pytest.fixture(scope="module")
48+
def caffeine_image_path():
49+
"""Caffeine image for testing"""
50+
return os.path.join(TEST_FILES_DIR, "caffeine.png")
51+
52+
3253
# Test the convert_image function
3354
def test_convert_image(sample_gif_path, sample_png_path):
3455
converted_path = convert_image(sample_gif_path)
3556
assert os.path.isfile(converted_path)
3657
assert converted_path == sample_png_path
58+
# Clean up the converted file
59+
if os.path.exists(converted_path):
60+
os.remove(converted_path)
3761

3862

39-
# Test the get_segments function
40-
def test_get_segments(sample_gif_path):
63+
# Test the get_segments function with GIF
64+
def test_get_segments_gif(sample_gif_path):
4165
image_name, segments = get_segments(sample_gif_path)
4266
assert image_name == "segment_sample.gif"
43-
assert len(segments) > 0
67+
assert isinstance(segments, list)
68+
69+
70+
# Test the get_segments function with PNG
71+
def test_get_segments_png(sample_png_path):
72+
image_name, segments = get_segments(sample_png_path)
73+
assert image_name == "segment_sample.png"
74+
assert isinstance(segments, list)
4475

4576

4677
# Test the get_predicted_segments function
47-
def test_get_predicted_segments(sample_gif_path):
48-
predicted_smiles = get_predicted_segments(sample_gif_path)
78+
@patch("app.modules.decimer.predict_SMILES")
79+
def test_get_predicted_segments(mock_predict_smiles, sample_png_path):
80+
mock_predict_smiles.return_value = "CCO"
81+
predicted_smiles = get_predicted_segments(sample_png_path)
82+
assert isinstance(predicted_smiles, str)
83+
assert len(predicted_smiles) > 0
84+
85+
86+
# Test get_predicted_segments_from_file with large image (should use segmentation)
87+
@patch("app.modules.decimer.get_predicted_segments")
88+
def test_get_predicted_segments_from_file_large_image(
89+
mock_get_predicted_segments, caffeine_image_path
90+
):
91+
"""Test that large images (>=500 pixels) use segmentation approach"""
92+
mock_get_predicted_segments.return_value = "CN1C=NC2=C1C(=O)N(C(=O)N2C)C"
93+
94+
with open(caffeine_image_path, "rb") as f:
95+
content = f.read()
96+
97+
predicted_smiles = get_predicted_segments_from_file(content, "test_large.png")
98+
4999
assert isinstance(predicted_smiles, str)
50100
assert len(predicted_smiles) > 0
101+
mock_get_predicted_segments.assert_called_once()
51102

52103

53-
# Test the get_predicted_segments_from_file function
54-
def test_get_predicted_segments_from_file(sample_image_path):
55-
with open(sample_image_path, "rb") as f:
104+
# Test get_predicted_segments_from_file with small image (should use direct prediction)
105+
@patch("app.modules.decimer.predict_SMILES")
106+
def test_get_predicted_segments_from_file_small_image(
107+
mock_predict_smiles, small_image_path
108+
):
109+
"""Test that small images (<500 pixels) use direct prediction"""
110+
mock_predict_smiles.return_value = "CN1C=NC2=C1C(=O)N(C(=O)N2C)C"
111+
112+
with open(small_image_path, "rb") as f:
56113
content = f.read()
57-
predicted_smiles = get_predicted_segments_from_file(
58-
content,
59-
"caffeine.png",
60-
)
114+
115+
predicted_smiles = get_predicted_segments_from_file(content, "test_small.png")
116+
117+
assert isinstance(predicted_smiles, str)
118+
assert len(predicted_smiles) > 0
119+
mock_predict_smiles.assert_called_once()
120+
121+
122+
# Test get_predicted_segments_from_file with tiny image (should use direct prediction)
123+
@patch("app.modules.decimer.predict_SMILES")
124+
def test_get_predicted_segments_from_file_tiny_image(
125+
mock_predict_smiles, tiny_image_path
126+
):
127+
"""Test that tiny images (<500 pixels) use direct prediction"""
128+
mock_predict_smiles.return_value = "C1CCC1"
129+
130+
with open(tiny_image_path, "rb") as f:
131+
content = f.read()
132+
133+
predicted_smiles = get_predicted_segments_from_file(content, "test_tiny.png")
134+
61135
assert isinstance(predicted_smiles, str)
62136
assert len(predicted_smiles) > 0
137+
mock_predict_smiles.assert_called_once()
138+
139+
140+
# Test error handling in get_predicted_segments_from_file
141+
def test_get_predicted_segments_from_file_cleanup():
142+
"""Test that temporary files are always cleaned up, even on errors"""
143+
test_content = b"invalid image content"
144+
test_filename = "test_cleanup.png"
145+
146+
# This should fail but still clean up the file
147+
try:
148+
get_predicted_segments_from_file(test_content, test_filename)
149+
except Exception:
150+
pass # Expected to fail with invalid image content
151+
152+
# File should not exist after function completes
153+
assert not os.path.exists(test_filename)
154+
155+
156+
# Test image size detection logic
157+
def test_image_size_detection():
158+
"""Test that the image size detection works correctly"""
159+
# Create temporary images with known sizes
160+
with tempfile.NamedTemporaryFile(
161+
suffix=".png", delete=False
162+
) as tmp_large, tempfile.NamedTemporaryFile(
163+
suffix=".png", delete=False
164+
) as tmp_small:
165+
166+
try:
167+
# Create large image (600x600)
168+
large_img = Image.new("RGB", (600, 600), "white")
169+
large_img.save(tmp_large.name)
170+
171+
# Create small image (300x300)
172+
small_img = Image.new("RGB", (300, 300), "white")
173+
small_img.save(tmp_small.name)
174+
175+
# Test with large image content
176+
with open(tmp_large.name, "rb") as f:
177+
large_content = f.read()
178+
179+
# Test with small image content
180+
with open(tmp_small.name, "rb") as f:
181+
small_content = f.read()
182+
183+
# Mock the prediction functions to verify which path is taken
184+
with patch("app.modules.decimer.predict_SMILES") as mock_direct, patch(
185+
"app.modules.decimer.get_predicted_segments"
186+
) as mock_segment:
187+
188+
mock_direct.return_value = "direct_prediction"
189+
mock_segment.return_value = "segmented_prediction"
190+
191+
# Test large image uses segmentation
192+
result_large = get_predicted_segments_from_file(
193+
large_content, "test_large_600x600.png"
194+
)
195+
assert result_large == "segmented_prediction"
196+
mock_segment.assert_called()
197+
mock_direct.assert_not_called()
198+
199+
# Reset mocks
200+
mock_direct.reset_mock()
201+
mock_segment.reset_mock()
202+
203+
# Test small image uses direct prediction
204+
result_small = get_predicted_segments_from_file(
205+
small_content, "test_small_300x300.png"
206+
)
207+
assert result_small == "direct_prediction"
208+
mock_direct.assert_called()
209+
mock_segment.assert_not_called()
210+
211+
finally:
212+
# Clean up temporary files
213+
for tmp_file in [tmp_large.name, tmp_small.name]:
214+
if os.path.exists(tmp_file):
215+
os.remove(tmp_file)

tests/test_deeplearningtools.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,29 +3,13 @@
33
import pytest
44
from DECIMER import predict_SMILES
55
from rdkit import Chem
6-
from STOUT import translate_forward
7-
from STOUT import translate_reverse
86

97

108
@pytest.fixture
119
def test_smiles():
1210
return "CN1C(=O)C2=C(N=CN2C)N(C)C1=O"
1311

1412

15-
def test_smilestoiupac(test_smiles):
16-
smiles = test_smiles
17-
expected_result = "1,3,7-trimethylpurine-2,6-dione"
18-
actual_result = translate_forward(smiles)
19-
assert expected_result == actual_result
20-
21-
22-
def test_iupactosmiles(test_smiles):
23-
iupac_name = "1,3,7-trimethylpurine-2,6-dione"
24-
expected_result = "CN1C=NC2=C1C(=O)N(C)C(=O)N2C"
25-
actual_result = translate_reverse(iupac_name)
26-
assert expected_result == actual_result
27-
28-
2913
def test_imagetosmiles(test_smiles):
3014
img_path = "tests/caffeine.png"
3115
expected_result = test_smiles

tests/tiny_molecule.png

44 KB
Loading

0 commit comments

Comments
 (0)