@@ -516,6 +516,7 @@ def train_supervised(*kargs, **kwargs):
516
516
'model' : "supervised"
517
517
})
518
518
519
+ callback = kwargs .pop ("callback" , None )
519
520
arg_names = ['input' , 'lr' , 'dim' , 'ws' , 'epoch' , 'minCount' ,
520
521
'minCountLabel' , 'minn' , 'maxn' , 'neg' , 'wordNgrams' , 'loss' , 'bucket' ,
521
522
'thread' , 'lrUpdateRate' , 't' , 'label' , 'verbose' , 'pretrainedVectors' ,
@@ -525,7 +526,10 @@ def train_supervised(*kargs, **kwargs):
525
526
supervised_default )
526
527
a = _build_args (args , manually_set_args )
527
528
ft = _FastText (args = a )
528
- fasttext .train (ft .f , a )
529
+ if callback :
530
+ fasttext .train_with_callback (ft .f , a , callback )
531
+ else :
532
+ fasttext .train (ft .f , a )
529
533
ft .set_args (ft .f .getArgs ())
530
534
return ft
531
535
@@ -544,13 +548,18 @@ def train_unsupervised(*kargs, **kwargs):
544
548
dataset pulled by the example script word-vector-example.sh, which is
545
549
part of the fastText repository.
546
550
"""
551
+ callback = kwargs .pop ("callback" , None )
547
552
arg_names = ['input' , 'model' , 'lr' , 'dim' , 'ws' , 'epoch' , 'minCount' ,
548
553
'minCountLabel' , 'minn' , 'maxn' , 'neg' , 'wordNgrams' , 'loss' , 'bucket' ,
549
554
'thread' , 'lrUpdateRate' , 't' , 'label' , 'verbose' , 'pretrainedVectors' ]
550
555
args , manually_set_args = read_args (kargs , kwargs , arg_names ,
551
556
unsupervised_default )
552
557
a = _build_args (args , manually_set_args )
553
558
ft = _FastText (args = a )
559
+ if callback :
560
+ fasttext .train_with_callback (ft .f , a , callback )
561
+ else :
562
+ fasttext .train (ft .f , a )
554
563
fasttext .train (ft .f , a )
555
564
ft .set_args (ft .f .getArgs ())
556
565
return ft
0 commit comments