11import pytest
2+ from unittest .mock import patch , MagicMock
23from app .modules .classyfire import classify , result
3- import asyncio
44
55
66@pytest .fixture
@@ -14,33 +14,69 @@ def invalid_smiles():
1414
1515
1616@pytest .mark .asyncio
17- async def test_valid_classyfire (valid_smiles ):
17+ @patch ("app.modules.classyfire.requests.post" )
18+ @patch ("app.modules.classyfire.requests.get" )
19+ async def test_valid_classyfire (mock_get , mock_post , valid_smiles ):
20+ # Mock the initial classification request
21+ mock_post_response = MagicMock ()
22+ mock_post_response .json .return_value = {
23+ "id" : "12345" ,
24+ "query_type" : "STRUCTURE" ,
25+ "query_input" : valid_smiles ,
26+ }
27+ mock_post_response .raise_for_status .return_value = None
28+ mock_post .return_value = mock_post_response
29+
30+ # Mock the result retrieval request
31+ mock_get_response = MagicMock ()
32+ mock_get_response .json .return_value = {
33+ "id" : "12345" ,
34+ "classification_status" : "Done" ,
35+ "entities" : [{"class" : {"name" : "Imidazopyrimidines" }}],
36+ }
37+ mock_get_response .raise_for_status .return_value = None
38+ mock_get .return_value = mock_get_response
39+
1840 result_ = await classify (valid_smiles )
1941 assert result_ ["query_type" ] == "STRUCTURE"
2042 id_ = result_ ["id" ]
2143
22- while True :
23- classified = await result (id_ )
24- if classified ["classification_status" ] == "Done" :
25- break
26- await asyncio .sleep (2 )
27-
44+ classified = await result (id_ )
2845 assert classified ["classification_status" ] == "Done"
2946 assert classified ["entities" ][0 ]["class" ]["name" ] == "Imidazopyrimidines"
3047
3148
3249@pytest .mark .asyncio
33- async def test_invalid_classyfire (invalid_smiles ):
50+ @patch ("app.modules.classyfire.requests.post" )
51+ @patch ("app.modules.classyfire.requests.get" )
52+ async def test_invalid_classyfire (mock_get , mock_post , invalid_smiles ):
53+ # Mock the initial classification request
54+ mock_post_response = MagicMock ()
55+ mock_post_response .json .return_value = {
56+ "id" : "12346" ,
57+ "query_type" : "STRUCTURE" ,
58+ "query_input" : invalid_smiles ,
59+ }
60+ mock_post_response .raise_for_status .return_value = None
61+ mock_post .return_value = mock_post_response
62+
63+ # Mock the result retrieval request
64+ mock_get_response = MagicMock ()
65+ mock_get_response .json .return_value = {
66+ "id" : "12346" ,
67+ "classification_status" : "Done" ,
68+ "invalid_entities" : [
69+ {"report" : ["Cannot process the input SMILES string, please check again" ]}
70+ ],
71+ }
72+ mock_get_response .raise_for_status .return_value = None
73+ mock_get .return_value = mock_get_response
74+
3475 result_ = await classify (invalid_smiles )
3576 assert result_ ["query_input" ] == "invalid_smiles"
3677 id_ = result_ ["id" ]
3778
38- while True :
39- classified = await result (id_ )
40- if classified ["classification_status" ] == "Done" :
41- break
42- await asyncio .sleep (2 )
43-
79+ classified = await result (id_ )
4480 assert classified ["classification_status" ] == "Done"
4581 assert (
4682 classified ["invalid_entities" ][0 ]["report" ][0 ]
0 commit comments