Skip to content
This repository was archived by the owner on Mar 19, 2024. It is now read-only.

Commit 1c1a7a8

Browse files
PreetGitsnehagaur01
authored andcommitted
Adding python callback during training
1 parent 0622aad commit 1c1a7a8

File tree

2 files changed

+18
-2
lines changed

2 files changed

+18
-2
lines changed

python/fasttext_module/fasttext/FastText.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,7 @@ def train_supervised(*kargs, **kwargs):
516516
'model': "supervised"
517517
})
518518

519+
callback = kwargs.pop("callback", None)
519520
arg_names = ['input', 'lr', 'dim', 'ws', 'epoch', 'minCount',
520521
'minCountLabel', 'minn', 'maxn', 'neg', 'wordNgrams', 'loss', 'bucket',
521522
'thread', 'lrUpdateRate', 't', 'label', 'verbose', 'pretrainedVectors',
@@ -525,7 +526,10 @@ def train_supervised(*kargs, **kwargs):
525526
supervised_default)
526527
a = _build_args(args, manually_set_args)
527528
ft = _FastText(args=a)
528-
fasttext.train(ft.f, a)
529+
if callback:
530+
fasttext.train_with_callback(ft.f, a, callback)
531+
else:
532+
fasttext.train(ft.f, a)
529533
ft.set_args(ft.f.getArgs())
530534
return ft
531535

@@ -544,14 +548,18 @@ def train_unsupervised(*kargs, **kwargs):
544548
dataset pulled by the example script word-vector-example.sh, which is
545549
part of the fastText repository.
546550
"""
551+
callback = kwargs.pop("callback", None)
547552
arg_names = ['input', 'model', 'lr', 'dim', 'ws', 'epoch', 'minCount',
548553
'minCountLabel', 'minn', 'maxn', 'neg', 'wordNgrams', 'loss', 'bucket',
549554
'thread', 'lrUpdateRate', 't', 'label', 'verbose', 'pretrainedVectors']
550555
args, manually_set_args = read_args(kargs, kwargs, arg_names,
551556
unsupervised_default)
552557
a = _build_args(args, manually_set_args)
553558
ft = _FastText(args=a)
554-
fasttext.train(ft.f, a)
559+
if callback:
560+
fasttext.train_with_callback(ft.f, a, callback)
561+
else:
562+
fasttext.train(ft.f, a)
555563
ft.set_args(ft.f.getArgs())
556564
return ft
557565

python/fasttext_module/fasttext/pybind/fasttext_pybind.cc

+8
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <fasttext.h>
1313
#include <pybind11/numpy.h>
1414
#include <pybind11/pybind11.h>
15+
#include <pybind11/functional.h>
1516
#include <pybind11/stl.h>
1617
#include <real.h>
1718
#include <vector.h>
@@ -166,6 +167,13 @@ PYBIND11_MODULE(fasttext_pybind, m) {
166167
}
167168
},
168169
py::call_guard<py::gil_scoped_release>());
170+
171+
m.def(
172+
"train_with_callback",
173+
[](fasttext::FastText& ft, fasttext::Args& a, fasttext::FastText::TrainCallback& c) {
174+
ft.train(a, c);
175+
},
176+
py::call_guard<py::gil_scoped_release>());
169177

170178
py::class_<fasttext::Vector>(m, "Vector", py::buffer_protocol())
171179
.def(py::init<ssize_t>())

0 commit comments

Comments
 (0)