@@ -142,7 +142,7 @@ def test_similarity_unseen_docs(self):
142
142
model .build_vocab (corpus )
143
143
self .assertTrue (model .docvecs .similarity_unseen_docs (model , rome_str , rome_str ) > model .docvecs .similarity_unseen_docs (model , rome_str , car_str ))
144
144
145
- def model_sanity (self , model ):
145
+ def model_sanity (self , model , keep_training = True ):
146
146
"""Any non-trivial model on DocsLeeCorpus can pass these sanity checks"""
147
147
fire1 = 0 # doc 0 sydney fires
148
148
fire2 = 8 # doc 8 sydney fires
@@ -179,6 +179,12 @@ def model_sanity(self, model):
179
179
# fire docs should be closer than fire-tennis
180
180
self .assertTrue (model .docvecs .similarity (fire1 , fire2 ) > model .docvecs .similarity (fire1 , tennis1 ))
181
181
182
+ # keep training after save
183
+ if keep_training :
184
+ model .save (testfile ())
185
+ loaded = doc2vec .Doc2Vec .load (testfile ())
186
+ loaded .train (sentences )
187
+
182
188
def test_training (self ):
183
189
"""Test doc2vec training."""
184
190
corpus = DocsLeeCorpus ()
@@ -316,10 +322,10 @@ def test_delete_temporary_training_data(self):
316
322
model .delete_temporary_training_data (keep_doctags_vectors = True , keep_inference = True )
317
323
self .assertTrue (model .docvecs and hasattr (model .docvecs , 'doctag_syn0' ))
318
324
self .assertTrue (hasattr (model , 'syn1' ))
319
- self .model_sanity (model )
325
+ self .model_sanity (model , keep_training = False )
320
326
model = doc2vec .Doc2Vec (list_corpus , dm = 1 , dm_mean = 1 , size = 24 , window = 4 , hs = 0 , negative = 1 , alpha = 0.05 , min_count = 2 , iter = 20 )
321
327
model .delete_temporary_training_data (keep_doctags_vectors = True , keep_inference = True )
322
- self .model_sanity (model )
328
+ self .model_sanity (model , keep_training = False )
323
329
self .assertTrue (hasattr (model , 'syn1neg' ))
324
330
325
331
@log_capture ()
0 commit comments