Skip to content

Commit 61cda6b

Browse files
authored
support loading kaldi model in python (#3976)
* support load kaldi model in python * add some component * split one file to multi component wrap files * fix some bugs and add test mdl * add testmode func in batchnorm pybind * change StatsSum StatsSumsq to Mean Var * make const
1 parent a634a5c commit 61cda6b

21 files changed

+574
-0
lines changed

src/nnet3/nnet-convolutional-component.h

+2
Original file line numberDiff line numberDiff line change
@@ -553,6 +553,8 @@ class TdnnComponent: public UpdatableComponent {
553553

554554
CuMatrixBase<BaseFloat> &LinearParams() { return linear_params_; }
555555

556+
// Read-only access to linear_params_, typed as CuMatrix rather than
// CuMatrixBase.  Note the lower-case 'p' in the name: it is distinct from the
// non-const LinearParams() accessor declared just above.
const CuMatrix<BaseFloat> &Linearparams() const { return linear_params_; }
557+
556558
// This allows you to resize the vector in order to add a bias where
557559
// there previously was none-- obviously this should be done carefully.
558560
CuVector<BaseFloat> &BiasParams() { return bias_params_; }

src/nnet3/nnet-normalize-component.cc

+18
Original file line numberDiff line numberDiff line change
@@ -641,6 +641,24 @@ void BatchNormComponent::Write(std::ostream &os, bool binary) const {
641641
WriteToken(os, binary, "</BatchNormComponent>");
642642
}
643643

644+
// Returns a copy of stats_sum_ divided by count_.  When count_ is zero the
// copy is returned unscaled, so no division by zero takes place.
CuVector<BaseFloat> BatchNormComponent::Mean() const {
  CuVector<BaseFloat> result(stats_sum_);
  if (count_ == 0)
    return result;
  result.Scale(1.0 / count_);
  return result;
}
651+
652+
// Returns the variance implied by the accumulated statistics, computed as
//   stats_sumsq_ / count_  -  Mean() .* Mean().
// When count_ is zero a plain copy of stats_sumsq_ is returned and no
// division takes place.
CuVector<BaseFloat> BatchNormComponent::Var() const {
  CuVector<BaseFloat> var(stats_sumsq_);
  if (count_ != 0) {
    // Reuse Mean() rather than repeating the stats_sum_ / count_ computation.
    const CuVector<BaseFloat> mean = Mean();
    var.Scale(1.0 / count_);
    var.AddVecVec(-1.0, mean, mean, 1.0);
    // E[x^2] - E[x]^2 can come out slightly negative because of round-off;
    // floor at zero so that callers never see a negative variance.
    var.ApplyFloor(0.0);
  }
  return var;
}
661+
644662
void BatchNormComponent::Scale(BaseFloat scale) {
645663
if (scale == 0) {
646664
count_ = 0.0;

src/nnet3/nnet-normalize-component.h

+5
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,11 @@ class BatchNormComponent: public Component {
224224
const CuVector<BaseFloat> &Offset() const { return offset_; }
225225
const CuVector<BaseFloat> &Scale() const { return scale_; }
226226

227+
// Returns stats_sum_ scaled by 1/count_ (an unscaled copy if count_ is zero).
CuVector<BaseFloat> Mean() const;
228+
// Returns stats_sumsq_/count_ minus the element-wise square of
// stats_sum_/count_ (an unscaled copy of stats_sumsq_ if count_ is zero).
CuVector<BaseFloat> Var() const;
229+
// Returns count_, the value Mean() and Var() use to normalize the stats.
double Count() const { return count_; }
230+
// Returns the component's epsilon_ setting.
BaseFloat Eps() const { return epsilon_; }
231+
227232
private:
228233

229234
struct Memo {

src/nnet3/nnet-simple-component.h

+1
Original file line numberDiff line numberDiff line change
@@ -971,6 +971,7 @@ class LinearComponent: public UpdatableComponent {
971971
BaseFloat OrthonormalConstraint() const { return orthonormal_constraint_; }
972972
CuMatrixBase<BaseFloat> &Params() { return params_; }
973973
const CuMatrixBase<BaseFloat> &Params() const { return params_; }
974+
// Read-only access to params_, typed as CuMatrix rather than CuMatrixBase
// (compare the const Params() overload directly above).
const CuMatrix<BaseFloat> &Params2() const { return params_; }
974975
private:
975976

976977
// disallow assignment operator.

src/pybind/Makefile

100644100755
+5
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,12 @@ matrix/sparse_matrix_pybind.cc \
9090
nnet3/nnet3_pybind.cc \
9191
nnet3/nnet_chain_example_pybind.cc \
9292
nnet3/nnet_common_pybind.cc \
93+
nnet3/nnet_component_itf_pybind.cc \
94+
nnet3/nnet_convolutional_component_pybind.cc \
9395
nnet3/nnet_example_pybind.cc \
96+
nnet3/nnet_nnet_pybind.cc \
97+
nnet3/nnet_normalize_component_pybind.cc \
98+
nnet3/nnet_simple_component_pybind.cc \
9499
tests/test_dlpack_subvector.cc \
95100
util/kaldi_holder_pybind.cc \
96101
util/kaldi_io_pybind.cc \

src/pybind/kaldi/io_util.py

100644100755
+16
Original file line numberDiff line numberDiff line change
@@ -82,3 +82,19 @@ def read_transition_model(rxfilename):
8282
ki.Close()
8383

8484
return trans_model
85+
86+
87+
def read_nnet3_model(rxfilename):
    '''Read an nnet3 model from an rxfilename.

    Args:
        rxfilename: The rxfilename handed to ``kaldi_pybind.Input.Open``.

    Returns:
        A ``kaldi_pybind.nnet3.Nnet`` filled in by ``Nnet.Read``.

    Raises:
        FileNotOpenException: If ``rxfilename`` could not be opened.
    '''
    ki = kaldi_pybind.Input()
    is_opened, is_binary = ki.Open(rxfilename, read_header=True)
    if not is_opened:
        raise FileNotOpenException('Failed to open {}'.format(rxfilename))

    try:
        nnet = kaldi_pybind.nnet3.Nnet()
        nnet.Read(ki.Stream(), is_binary)
    finally:
        # Close the input even when Read() raises, so the underlying
        # stream is not left open on a failed load.
        ki.Close()

    return nnet

src/pybind/nnet3/Makefile

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11

22
test:
33
python3 ./nnet_chain_example_pybind_test.py
4+
python3 ./nnet_nnet_pybind_test.py
45

src/pybind/nnet3/final.mdl

322 KB
Binary file not shown.

src/pybind/nnet3/nnet3_pybind.cc

+11
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
// Copyright 2019 Mobvoi AI Lab, Beijing, China
44
// (author: Fangjun Kuang, Yaguang Hu, Jian Wang)
5+
// Copyright 2020 JD AI, Beijing, China (author: Lu Fan)
56

67
// See ../../../COPYING for clarification regarding multiple authors
78
//
@@ -22,12 +23,22 @@
2223

2324
#include "nnet3/nnet_chain_example_pybind.h"
2425
#include "nnet3/nnet_common_pybind.h"
26+
#include "nnet3/nnet_component_itf_pybind.h"
27+
#include "nnet3/nnet_convolutional_component_pybind.h"
2528
#include "nnet3/nnet_example_pybind.h"
29+
#include "nnet3/nnet_nnet_pybind.h"
30+
#include "nnet3/nnet_normalize_component_pybind.h"
31+
#include "nnet3/nnet_simple_component_pybind.h"
2632

2733
// Creates the "nnet3" submodule under _m and registers all nnet3 bindings
// on it.
void pybind_nnet3(py::module& _m) {
  py::module nnet3_module =
      _m.def_submodule("nnet3", "nnet3 pybind for Kaldi");

  // The call order is deliberately left untouched.
  // NOTE(review): pybind_nnet_component_itf presumably has to run before the
  // wrappers that name Component as a base class -- confirm before reordering.
  pybind_nnet_common(nnet3_module);
  pybind_nnet_component_itf(nnet3_module);
  pybind_nnet_convolutional_component(nnet3_module);
  pybind_nnet_example(nnet3_module);
  pybind_nnet_chain_example(nnet3_module);
  pybind_nnet_nnet(nnet3_module);
  pybind_nnet_normalize_component(nnet3_module);
  pybind_nnet_simple_component(nnet3_module);
}

src/pybind/nnet3/nnet_chain_example_pybind_test.py

100755100644
File mode changed.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// pybind/nnet3/nnet_component_itf_pybind.cc
2+
3+
// Copyright 2020 JD AI, Beijing, China (author: Lu Fan)
4+
5+
// See ../../../COPYING for clarification regarding multiple authors
6+
//
7+
// Licensed under the Apache License, Version 2.0 (the "License");
8+
// you may not use this file except in compliance with the License.
9+
// You may obtain a copy of the License at
10+
//
11+
// http://www.apache.org/licenses/LICENSE-2.0
12+
//
13+
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15+
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16+
// MERCHANTABLITY OR NON-INFRINGEMENT.
17+
// See the Apache 2 License for the specific language governing permissions and
18+
// limitations under the License.
19+
20+
#include "nnet3/nnet_component_itf_pybind.h"
21+
22+
#include "nnet3/nnet-component-itf.h"
23+
24+
using namespace kaldi::nnet3;
25+
26+
// Registers kaldi::nnet3::Component on module m as the Python class
// "Component", exposing Type(), Info() and the static factory
// NewComponentOfType().
void pybind_nnet_component_itf(py::module& m) {
  py::class_<Component> component_class(
      m, "Component", "Abstract base-class for neural-net components.");

  component_class.def(
      "Type", &Component::Type,
      "Returns a string such as \"SigmoidComponent\", describing the "
      "type of the object.");

  component_class.def(
      "Info", &Component::Info,
      "Returns some text-form information about this component, for "
      "diagnostics. Starts with the type of the component. E.g. "
      "\"SigmoidComponent dim=900\", although most components will have "
      "much more info.");

  // Python takes ownership of the pointer returned by the factory.
  component_class.def_static("NewComponentOfType",
                             &Component::NewComponentOfType,
                             py::return_value_policy::take_ownership);
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// pybind/nnet3/nnet_component_itf_pybind.h
2+
3+
// Copyright 2020 JD AI, Beijing, China (author: Lu Fan)
4+
5+
// See ../../../COPYING for clarification regarding multiple authors
6+
//
7+
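// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at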
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
12+
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
13+
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
14+
// MERCHANTABLITY OR NON-INFRINGEMENT.
15+
// See the Apache 2 License for the specific language governing permissions and
16+
// limitations under the License.
17+
18+
#ifndef KALDI_PYBIND_NNET3_NNET_COMPONENT_ITF_PYBIND_H_
19+
#define KALDI_PYBIND_NNET3_NNET_COMPONENT_ITF_PYBIND_H_
20+
21+
#include "pybind/kaldi_pybind.h"
22+
23+
void pybind_nnet_component_itf(py::module& m);
24+
25+
#endif // KALDI_PYBIND_NNET3_NNET_COMPONENT_ITF_PYBIND_H_
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// pybind/nnet3/nnet_convolutional_component_pybind.cc
2+
3+
// Copyright 2020 JD AI, Beijing, China (author: Lu Fan)
4+
5+
// See ../../../COPYING for clarification regarding multiple authors
6+
//
7+
// Licensed under the Apache License, Version 2.0 (the "License");
8+
// you may not use this file except in compliance with the License.
9+
// You may obtain a copy of the License at
10+
//
11+
// http://www.apache.org/licenses/LICENSE-2.0
12+
//
13+
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15+
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16+
// MERCHANTABLITY OR NON-INFRINGEMENT.
17+
// See the Apache 2 License for the specific language governing permissions and
18+
// limitations under the License.
19+
20+
#include "nnet3/nnet_convolutional_component_pybind.h"
21+
22+
#include "nnet3/nnet-convolutional-component.h"
23+
24+
using namespace kaldi::nnet3;
25+
26+
// Registers kaldi::nnet3::TdnnComponent on module m as the Python class
// "TdnnComponent" (derived from Component) and exposes its linear and bias
// parameters.
void pybind_nnet_convolutional_component(py::module& m) {
  using TC = kaldi::nnet3::TdnnComponent;
  py::class_<TC, Component>(m, "TdnnComponent")
      // Both accessors hand out references to members of the component.
      // reference_internal ties the returned object's lifetime to the
      // component wrapper, so the parameters cannot outlive their owner on
      // the Python side (plain `reference` gives no such guarantee).
      .def("LinearParams", &TC::Linearparams,
           py::return_value_policy::reference_internal)
      .def("BiasParams", &TC::BiasParams,
           py::return_value_policy::reference_internal);
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// pybind/nnet3/nnet_convolutional_component_pybind.h
2+
3+
// Copyright 2020 JD AI, Beijing, China (author: Lu Fan)
4+
5+
// See ../../../COPYING for clarification regarding multiple authors
6+
//
7+
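// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at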
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
12+
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
13+
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
14+
// MERCHANTABLITY OR NON-INFRINGEMENT.
15+
// See the Apache 2 License for the specific language governing permissions and
16+
// limitations under the License.
17+
18+
#ifndef KALDI_PYBIND_NNET3_NNET_CONVOLUTIONAL_COMPONENT_PYBIND_H_
19+
#define KALDI_PYBIND_NNET3_NNET_CONVOLUTIONAL_COMPONENT_PYBIND_H_
20+
21+
#include "pybind/kaldi_pybind.h"
22+
23+
void pybind_nnet_convolutional_component(py::module& m);
24+
25+
#endif // KALDI_PYBIND_NNET3_NNET_CONVOLUTIONAL_COMPONENT_PYBIND_H_

src/pybind/nnet3/nnet_nnet_pybind.cc

+52
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
// pybind/nnet3/nnet_nnet_pybind.cc
2+
3+
// Copyright 2020 JD AI, Beijing, China (author: Lu Fan)
4+
5+
// See ../../../COPYING for clarification regarding multiple authors
6+
//
7+
// Licensed under the Apache License, Version 2.0 (the "License");
8+
// you may not use this file except in compliance with the License.
9+
// You may obtain a copy of the License at
10+
//
11+
// http://www.apache.org/licenses/LICENSE-2.0
12+
//
13+
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15+
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16+
// MERCHANTABLITY OR NON-INFRINGEMENT.
17+
// See the Apache 2 License for the specific language governing permissions and
18+
// limitations under the License.
19+
20+
#include "nnet3/nnet_nnet_pybind.h"
21+
22+
#include "nnet3/nnet-nnet.h"
23+
24+
using namespace kaldi;
25+
using namespace kaldi::nnet3;
26+
27+
// Registers kaldi::nnet3::Nnet on module m as the Python class "Nnet", with
// a default constructor, Read(), and read-only inspection methods.
void pybind_nnet_nnet(py::module& m) {
  using PyClass = kaldi::nnet3::Nnet;
  // The class docstring describes the class itself; the previous text was
  // the comment of the ReadConfig() member function, which is not bound here.
  auto nnet = py::class_<PyClass>(
      m, "Nnet",
      "Wrapper for kaldi::nnet3::Nnet. Use Read() to load a model and the "
      "remaining methods to inspect its components.");
  nnet.def(py::init<>())
      .def("Read", &PyClass::Read, py::arg("is"), py::arg("binary"))
      .def("GetComponentNames", &PyClass::GetComponentNames,
           "returns vector of component names (needed by some parsing code, "
           "for instance).",
           py::return_value_policy::reference)
      .def("GetComponentName", &PyClass::GetComponentName,
           py::arg("component_index"))
      .def("Info", &PyClass::Info,
           "returns some human-readable information about the network, "
           "mostly for debugging purposes. Also see function NnetInfo() in "
           "nnet-utils.h, which prints out more extensive information.")
      .def("NumComponents", &PyClass::NumComponents)
      .def("NumNodes", &PyClass::NumNodes)
      // Select the non-const overload with a named cast.  The returned
      // pointer refers to a component inside the Nnet, so reference_internal
      // keeps the Nnet alive for as long as Python holds the component.
      .def("GetComponent",
           static_cast<Component *(PyClass::*)(int32)>(&PyClass::GetComponent),
           py::arg("c"), py::return_value_policy::reference_internal);
}

src/pybind/nnet3/nnet_nnet_pybind.h

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// pybind/nnet3/nnet_nnet_pybind.h
2+
3+
// Copyright 2020 JD AI, Beijing, China (author: Lu Fan)
4+
5+
// See ../../../COPYING for clarification regarding multiple authors
6+
//
7+
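// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at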
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
12+
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
13+
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
14+
// MERCHANTABLITY OR NON-INFRINGEMENT.
15+
// See the Apache 2 License for the specific language governing permissions and
16+
// limitations under the License.
17+
18+
#ifndef KALDI_PYBIND_NNET3_NNET_NNET_PYBIND_H_
19+
#define KALDI_PYBIND_NNET3_NNET_NNET_PYBIND_H_
20+
21+
#include "pybind/kaldi_pybind.h"
22+
23+
void pybind_nnet_nnet(py::module& m);
24+
25+
#endif // KALDI_PYBIND_NNET3_NNET_NNET_PYBIND_H_

0 commit comments

Comments
 (0)