11from __future__ import annotations
22
33import os
4+ import tempfile
5+ from unittest .mock import patch
46
57import pytest
8+ from PIL import Image
69
710from app .modules .decimer import convert_image
811from app .modules .decimer import get_predicted_segments
@@ -29,34 +32,184 @@ def sample_image_path():
2932 return os .path .join (TEST_FILES_DIR , "segment_sample.png" )
3033
3134
35+ @pytest .fixture (scope = "module" )
36+ def small_image_path ():
37+ """Small image (400x300) - should trigger direct prediction"""
38+ return os .path .join (TEST_FILES_DIR , "small_molecule.png" )
39+
40+
41+ @pytest .fixture (scope = "module" )
42+ def tiny_image_path ():
43+ """Tiny image (200x150) - should trigger direct prediction"""
44+ return os .path .join (TEST_FILES_DIR , "tiny_molecule.png" )
45+
46+
47+ @pytest .fixture (scope = "module" )
48+ def caffeine_image_path ():
49+ """Caffeine image for testing"""
50+ return os .path .join (TEST_FILES_DIR , "caffeine.png" )
51+
52+
3253# Test the convert_image function
3354def test_convert_image (sample_gif_path , sample_png_path ):
3455 converted_path = convert_image (sample_gif_path )
3556 assert os .path .isfile (converted_path )
3657 assert converted_path == sample_png_path
58+ # Clean up the converted file
59+ if os .path .exists (converted_path ):
60+ os .remove (converted_path )
3761
3862
39- # Test the get_segments function
40- def test_get_segments (sample_gif_path ):
63+ # Test the get_segments function with GIF
64+ def test_get_segments_gif (sample_gif_path ):
4165 image_name , segments = get_segments (sample_gif_path )
4266 assert image_name == "segment_sample.gif"
43- assert len (segments ) > 0
67+ assert isinstance (segments , list )
68+
69+
70+ # Test the get_segments function with PNG
71+ def test_get_segments_png (sample_png_path ):
72+ image_name , segments = get_segments (sample_png_path )
73+ assert image_name == "segment_sample.png"
74+ assert isinstance (segments , list )
4475
4576
4677# Test the get_predicted_segments function
47- def test_get_predicted_segments (sample_gif_path ):
48- predicted_smiles = get_predicted_segments (sample_gif_path )
78+ @patch ("app.modules.decimer.predict_SMILES" )
79+ def test_get_predicted_segments (mock_predict_smiles , sample_png_path ):
80+ mock_predict_smiles .return_value = "CCO"
81+ predicted_smiles = get_predicted_segments (sample_png_path )
82+ assert isinstance (predicted_smiles , str )
83+ assert len (predicted_smiles ) > 0
84+
85+
86+ # Test get_predicted_segments_from_file with large image (should use segmentation)
87+ @patch ("app.modules.decimer.get_predicted_segments" )
88+ def test_get_predicted_segments_from_file_large_image (
89+ mock_get_predicted_segments , caffeine_image_path
90+ ):
91+ """Test that large images (>=500 pixels) use segmentation approach"""
92+ mock_get_predicted_segments .return_value = "CN1C=NC2=C1C(=O)N(C(=O)N2C)C"
93+
94+ with open (caffeine_image_path , "rb" ) as f :
95+ content = f .read ()
96+
97+ predicted_smiles = get_predicted_segments_from_file (content , "test_large.png" )
98+
4999 assert isinstance (predicted_smiles , str )
50100 assert len (predicted_smiles ) > 0
101+ mock_get_predicted_segments .assert_called_once ()
51102
52103
53- # Test the get_predicted_segments_from_file function
54- def test_get_predicted_segments_from_file (sample_image_path ):
55- with open (sample_image_path , "rb" ) as f :
104+ # Test get_predicted_segments_from_file with small image (should use direct prediction)
105+ @patch ("app.modules.decimer.predict_SMILES" )
106+ def test_get_predicted_segments_from_file_small_image (
107+ mock_predict_smiles , small_image_path
108+ ):
109+ """Test that small images (<500 pixels) use direct prediction"""
110+ mock_predict_smiles .return_value = "CN1C=NC2=C1C(=O)N(C(=O)N2C)C"
111+
112+ with open (small_image_path , "rb" ) as f :
56113 content = f .read ()
57- predicted_smiles = get_predicted_segments_from_file (
58- content ,
59- "caffeine.png" ,
60- )
114+
115+ predicted_smiles = get_predicted_segments_from_file (content , "test_small.png" )
116+
117+ assert isinstance (predicted_smiles , str )
118+ assert len (predicted_smiles ) > 0
119+ mock_predict_smiles .assert_called_once ()
120+
121+
122+ # Test get_predicted_segments_from_file with tiny image (should use direct prediction)
123+ @patch ("app.modules.decimer.predict_SMILES" )
124+ def test_get_predicted_segments_from_file_tiny_image (
125+ mock_predict_smiles , tiny_image_path
126+ ):
127+ """Test that tiny images (<500 pixels) use direct prediction"""
128+ mock_predict_smiles .return_value = "C1CCC1"
129+
130+ with open (tiny_image_path , "rb" ) as f :
131+ content = f .read ()
132+
133+ predicted_smiles = get_predicted_segments_from_file (content , "test_tiny.png" )
134+
61135 assert isinstance (predicted_smiles , str )
62136 assert len (predicted_smiles ) > 0
137+ mock_predict_smiles .assert_called_once ()
138+
139+
140+ # Test error handling in get_predicted_segments_from_file
141+ def test_get_predicted_segments_from_file_cleanup ():
142+ """Test that temporary files are always cleaned up, even on errors"""
143+ test_content = b"invalid image content"
144+ test_filename = "test_cleanup.png"
145+
146+ # This should fail but still clean up the file
147+ try :
148+ get_predicted_segments_from_file (test_content , test_filename )
149+ except Exception :
150+ pass # Expected to fail with invalid image content
151+
152+ # File should not exist after function completes
153+ assert not os .path .exists (test_filename )
154+
155+
156+ # Test image size detection logic
157+ def test_image_size_detection ():
158+ """Test that the image size detection works correctly"""
159+ # Create temporary images with known sizes
160+ with tempfile .NamedTemporaryFile (
161+ suffix = ".png" , delete = False
162+ ) as tmp_large , tempfile .NamedTemporaryFile (
163+ suffix = ".png" , delete = False
164+ ) as tmp_small :
165+
166+ try :
167+ # Create large image (600x600)
168+ large_img = Image .new ("RGB" , (600 , 600 ), "white" )
169+ large_img .save (tmp_large .name )
170+
171+ # Create small image (300x300)
172+ small_img = Image .new ("RGB" , (300 , 300 ), "white" )
173+ small_img .save (tmp_small .name )
174+
175+ # Test with large image content
176+ with open (tmp_large .name , "rb" ) as f :
177+ large_content = f .read ()
178+
179+ # Test with small image content
180+ with open (tmp_small .name , "rb" ) as f :
181+ small_content = f .read ()
182+
183+ # Mock the prediction functions to verify which path is taken
184+ with patch ("app.modules.decimer.predict_SMILES" ) as mock_direct , patch (
185+ "app.modules.decimer.get_predicted_segments"
186+ ) as mock_segment :
187+
188+ mock_direct .return_value = "direct_prediction"
189+ mock_segment .return_value = "segmented_prediction"
190+
191+ # Test large image uses segmentation
192+ result_large = get_predicted_segments_from_file (
193+ large_content , "test_large_600x600.png"
194+ )
195+ assert result_large == "segmented_prediction"
196+ mock_segment .assert_called ()
197+ mock_direct .assert_not_called ()
198+
199+ # Reset mocks
200+ mock_direct .reset_mock ()
201+ mock_segment .reset_mock ()
202+
203+ # Test small image uses direct prediction
204+ result_small = get_predicted_segments_from_file (
205+ small_content , "test_small_300x300.png"
206+ )
207+ assert result_small == "direct_prediction"
208+ mock_direct .assert_called ()
209+ mock_segment .assert_not_called ()
210+
211+ finally :
212+ # Clean up temporary files
213+ for tmp_file in [tmp_large .name , tmp_small .name ]:
214+ if os .path .exists (tmp_file ):
215+ os .remove (tmp_file )
0 commit comments