@@ -212,3 +212,92 @@ def test_main_no_specific_sdg(
212212
213213 # There is only one state by doc because the rest of steps were mocked
214214 self .assertEqual (state_in_db [0 ].title , Step .DOCUMENT_CLASSIFIED_NON_SDG .value )
215+
216+
@patch(
    "welearn_datastack.nodes_workflow.DocumentClassifier.document_classifier.n_classify_slices"
)
@patch(
    "welearn_datastack.nodes_workflow.DocumentClassifier.document_classifier.bi_classify_slices"
)
@patch(
    "welearn_datastack.nodes_workflow.DocumentClassifier.document_classifier.retrieve_models"
)
@patch(
    "welearn_datastack.nodes_workflow.DocumentClassifier.document_classifier.create_db_session"
)
@patch(
    "welearn_datastack.nodes_workflow.DocumentClassifier.document_classifier.retrieve_ids_from_csv"
)
def test_main_externally_classified(
    self,
    mock_retrieve_ids,
    mock_create_session,
    mock_retrieve_models,
    mock_bi_classify,
    mock_n_classify,
):
    """Run ``document_classifier.main()`` on a document that has ``external_sdg``.

    The stored document carries ``"external_sdg": [10]`` in its ``details``.
    The n-classifier mock returns an empty list, yet the test expects the
    document to end up in the ``DOCUMENT_CLASSIFIED_SDG`` state with an
    ``Sdg`` row numbered 10.
    NOTE(review): presumably ``main()`` takes the SDG from ``external_sdg``
    rather than from the classifiers -- confirm against the implementation.

    Mock parameters are listed in reverse decorator order, because ``@patch``
    decorators are applied bottom-up.
    """
    # In-memory SQLite database holding the full project schema.
    engine = create_engine("sqlite://")
    session_factory = sessionmaker(engine)
    handle_schema_with_sqlite(engine)

    db_session = session_factory()
    Base.metadata.create_all(db_session.get_bind())

    document_id = uuid.uuid4()

    # Patched collaborators: ids to process, DB session, models, classifiers.
    mock_retrieve_ids.return_value = [document_id]
    mock_create_session.return_value = db_session
    mock_retrieve_models.return_value = [Mock(lang="en", title="model_name")]
    mock_bi_classify.return_value = True
    mock_n_classify.return_value = []

    # Fixtures: one corpus, one document tagged with an external SDG, one slice.
    corpus = Corpus(
        id=uuid.uuid4(),
        source_name="test_corpus",
        is_fix=True,
        is_active=True,
    )
    document = WeLearnDocument(
        id=document_id,
        url="https://example.org",
        corpus_id=corpus.id,
        title="test",
        lang="en",
        full_content="test",
        description="test",
        details={"test": "test", "external_sdg": [10]},
        trace=1,
    )
    document_slice = DocumentSlice(
        id=uuid.uuid4(),
        document_id=document.id,
        embedding=numpy.array([1, 2, 3]),
        body="test",
        order_sequence=0,
        embedding_model_name="test",
        embedding_model_id=uuid.uuid4(),
    )

    for record in (corpus, document, document_slice):
        db_session.add(record)
    db_session.commit()

    document_classifier.main()

    # Only the first state is checked: the other workflow steps are mocked,
    # so a single state per document is expected.
    recorded_states = db_session.query(ProcessState).all()
    self.assertEqual(
        recorded_states[0].title, Step.DOCUMENT_CLASSIFIED_SDG.value
    )

    # The SDG stored for the document must be the externally supplied one.
    recorded_sdgs = db_session.query(Sdg).all()
    self.assertEqual(recorded_sdgs[0].sdg_number, 10)
0 commit comments