Skip to content

Commit da5c7e4

Browse files
committed
write test
1 parent 78eb987 commit da5c7e4

File tree

2 files changed

+108
-4
lines changed

2 files changed

+108
-4
lines changed

tests/document_classifier/test_document_classifier.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,3 +212,92 @@ def test_main_no_specific_sdg(
212212

213213
# There is only one state per doc because the rest of the steps were mocked
214214
self.assertEqual(state_in_db[0].title, Step.DOCUMENT_CLASSIFIED_NON_SDG.value)
215+
216+
217+
@patch(
    "welearn_datastack.nodes_workflow.DocumentClassifier.document_classifier.n_classify_slices"
)
@patch(
    "welearn_datastack.nodes_workflow.DocumentClassifier.document_classifier.bi_classify_slices"
)
@patch(
    "welearn_datastack.nodes_workflow.DocumentClassifier.document_classifier.retrieve_models"
)
@patch(
    "welearn_datastack.nodes_workflow.DocumentClassifier.document_classifier.create_db_session"
)
@patch(
    "welearn_datastack.nodes_workflow.DocumentClassifier.document_classifier.retrieve_ids_from_csv"
)
def test_main_externally_classified(
    self,
    mock_retrieve_ids,
    mock_create_session,
    mock_retrieve_models,
    mock_bi_classify,
    mock_n_classify,
):
    """Run ``main()`` on a document whose details carry an ``external_sdg`` list.

    The binary classifier is mocked to answer ``True`` and the n-classifier
    is mocked to return an empty list. The test then checks that the first
    stored process state is ``DOCUMENT_CLASSIFIED_SDG`` and that the first
    stored ``Sdg`` row carries the SDG number (10) taken from the document
    details.

    Mock arguments arrive in bottom-up decorator order: the ``@patch``
    closest to the ``def`` provides the first mock parameter.
    """
    # Classifier stubs: bi-classification says "SDG", n-classification
    # yields nothing.
    mock_bi_classify.return_value = True
    mock_n_classify.return_value = []

    document_id = uuid.uuid4()

    # In-memory SQLite database with the project schema created on it.
    sqlite_engine = create_engine("sqlite://")
    session_factory = sessionmaker(sqlite_engine)
    handle_schema_with_sqlite(sqlite_engine)

    db_session = session_factory()
    Base.metadata.create_all(db_session.get_bind())

    # Wire the patched collaborators to the fixtures above.
    mock_retrieve_ids.return_value = [document_id]
    mock_create_session.return_value = db_session
    mock_retrieve_models.return_value = [Mock(lang="en", title="model_name")]

    # Fixture rows: one corpus, one document flagged with an external SDG,
    # and one slice belonging to that document.
    corpus_row = Corpus(
        id=uuid.uuid4(),
        source_name="test_corpus",
        is_fix=True,
        is_active=True,
    )
    document_row = WeLearnDocument(
        id=document_id,
        url="https://example.org",
        corpus_id=corpus_row.id,
        title="test",
        lang="en",
        full_content="test",
        description="test",
        details={"test": "test", "external_sdg": [10]},
        trace=1,
    )
    slice_row = DocumentSlice(
        id=uuid.uuid4(),
        document_id=document_row.id,
        embedding=numpy.array([1, 2, 3]),
        body="test",
        order_sequence=0,
        embedding_model_name="test",
        embedding_model_id=uuid.uuid4(),
    )

    db_session.add_all([corpus_row, document_row, slice_row])
    db_session.commit()

    document_classifier.main()

    # Only one state per doc is expected because the remaining steps are mocked
    stored_states = db_session.query(ProcessState).all()
    self.assertEqual(stored_states[0].title, Step.DOCUMENT_CLASSIFIED_SDG.value)

    # The stored SDG number must be the one given in details["external_sdg"].
    stored_sdgs = db_session.query(Sdg).all()
    self.assertEqual(stored_sdgs[0].sdg_number, 10)

welearn_datastack/nodes_workflow/DocumentClassifier/document_classifier.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import uuid
44
from itertools import groupby
55
from typing import List, Set
6-
from uuid import UUID
6+
from uuid import UUID, uuid4
77

88
from sqlalchemy.orm import Session
99

@@ -72,9 +72,10 @@ def main() -> None:
7272
sdg_docs_ids: List[UUID] = []
7373
specific_sdgs: List[Sdg] = []
7474
logger.info("Starting bi-classification")
75+
key_external_sdg = "external_sdg"
7576
slices_per_docs = sorted(slices, key=lambda x: x.document_id) # type: ignore
7677
for k, g in groupby(slices_per_docs, lambda x: x.document_id):
77-
doc_slices = list(g)
78+
doc_slices: List[DocumentSlice] = list(g) # type: ignore
7879
lang = doc_slices[0].document.lang
7980
bi_model = bi_model_by_lang.get(lang)
8081
if not bi_model:
@@ -85,8 +86,22 @@ def main() -> None:
8586
# No SDG found, process it later
8687
non_sdg_docs_ids.add(k)
8788
continue
88-
89-
doc_sdgs = n_classify_slices(doc_slices, n_model_by_lang.get(lang)) # type: ignore
89+
if key_external_sdg in doc_slices[0].document.details:
90+
logger.info(f"Document {doc_slices[0].document_id} was externally classified")
91+
doc_sdgs: List[Sdg] = []
92+
for sdg_number in doc_slices[0].document.details[key_external_sdg]:
93+
for local_slice in doc_slices:
94+
doc_sdgs.append(
95+
Sdg(
96+
slice_id=local_slice.id,
97+
sdg_number=sdg_number,
98+
id=uuid4(),
99+
bi_classifier_model_id=uuid4(),
100+
n_classifier_model_id=uuid4()
101+
)
102+
)
103+
else:
104+
doc_sdgs = n_classify_slices(doc_slices, n_model_by_lang.get(lang)) # type: ignore
90105
if not doc_sdgs:
91106
# No SDG found, process it later
92107
non_sdg_docs_ids.add(k)

0 commit comments

Comments
 (0)