-
Notifications
You must be signed in to change notification settings - Fork 22
Expand file tree
/
Copy pathvits.cpp
More file actions
120 lines (77 loc) · 2.79 KB
/
vits.cpp
File metadata and controls
120 lines (77 loc) · 2.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
#include "vits.h"
// Intersperse a zero "pad" token before each character ID (and one leading
// zero), widening every ID from int32 to int64 for the model input.
// Output layout: [0, id_0, 0, id_1, 0, ..., id_{n-1}, 0]
std::vector<int64_t> VITS::ZeroPadVec(const std::vector<int32_t> &InIDs)
{
    std::vector<int64_t> Padded;
    Padded.reserve(InIDs.size() * 2);
    Padded.push_back(0);
    for (const int32_t ID : InIDs)
    {
        Padded.push_back(static_cast<int64_t>(ID));
        Padded.push_back(0);
    }
    // The trailing zero is already appended by the loop above, so no extra
    // terminator is needed here.
    return Padded;
}
// Default constructor: intentionally empty — the TorchScript model is
// loaded lazily via Initialize(), not at construction time.
VITS::VITS()
{
}
bool VITS::Initialize(const std::string &SavedModelFolder, ETTSRepo::Enum InTTSRepo)
{
try {
// Deserialize the ScriptModule from a file using torch::jit::load().
Model = torch::jit::load(SavedModelFolder);
}
catch (const c10::Error& e) {
return false;
}
CurrentRepo = InTTSRepo;
return true;
}
// Run TTS inference on a tokenized utterance.
// @param InputIDs   Character/phoneme IDs for the text.
// @param ArgsFloat  ArgsFloat[0] = length scale; any further entries are
//                   TorchMoji hidden-state values appended by the caller.
// @param ArgsInt    Currently unused; kept for interface parity with other
//                   TTS backends.
// @param SpeakerID  Speaker index, or -1 if the model is single-speaker.
// @param EmotionID  Emotion index, or -1 if unused.
// @return Audio tensor [frames]; also stores the attention map in
//         this->Attention as [1, y, x].
TFTensor<float> VITS::DoInference(const std::vector<int32_t> &InputIDs, const std::vector<float> &ArgsFloat, const std::vector<int32_t> ArgsInt, int32_t SpeakerID, int32_t EmotionID)
{
    // Inference only — without this, autograd bookkeeping made memory
    // consumption roughly 4x higher.
    torch::NoGradGuard no_grad;

    // TorchMoji hidden states, when present, are packed after the length
    // scale in ArgsFloat.
    const bool UsesTorchMoji = ArgsFloat.size() > 1;

    // Guard against an empty ArgsFloat: fall back to a neutral length scale
    // of 1.0 instead of reading ArgsFloat[0] out of bounds (previously UB).
    const float LengthScale = ArgsFloat.empty() ? 1.0f : ArgsFloat[0];

    std::vector<int64_t> PaddedIDs;
    // Our current TorchMoji-enabled models don't use zero interspersion,
    // so their IDs pass through unchanged (widened to int64).
    if (UsesTorchMoji)
        PaddedIDs.assign(InputIDs.begin(), InputIDs.end());
    else
        PaddedIDs = ZeroPadVec(InputIDs);

    std::vector<int64_t> inLen = { (int64_t)PaddedIDs.size() };

    // Inference tensors never require gradients.
    torch::TensorOptions Opts = torch::TensorOptions().requires_grad(false);

    auto InIDS = torch::tensor(PaddedIDs, Opts).unsqueeze(0); // [1, T]
    auto InLens = torch::tensor(inLen, Opts);
    auto InLenScale = torch::tensor({ LengthScale }, Opts);

    std::vector<torch::jit::IValue> inputs{ InIDS, InLens, InLenScale };

    // Optional conditioning inputs; -1 means "not used by this model".
    if (SpeakerID != -1){
        auto InSpkid = torch::tensor({ SpeakerID }, Opts);
        inputs.push_back(InSpkid);
    }
    if (EmotionID != -1){
        auto InEmid = torch::tensor({ EmotionID }, Opts);
        inputs.push_back(InEmid);
    }

    // TorchMoji embedding: everything in ArgsFloat after the length scale.
    if (UsesTorchMoji){
        std::vector<float> TMHidden(ArgsFloat.begin() + 1, ArgsFloat.end());
        auto InMoji = torch::tensor(TMHidden, Opts).unsqueeze(0);
        inputs.push_back(InMoji);
    }

    // Invoke the exported "infer_ts" method; it returns a tuple (audio, att).
    c10::IValue Output = Model.get_method("infer_ts")(inputs);
    auto OutputT = Output.toTuple();

    // Audio: [1, frames] -> [frames]
    auto AuTens = OutputT.get()->elements()[0].toTensor().squeeze();

    // Attention: [1, 1, x, y] -> [x, y] -> [y, x] -> [1, y, x]
    auto AttTens = OutputT.get()->elements()[1].toTensor().squeeze().transpose(0, 1).unsqueeze(0);

    Attention = VoxUtil::CopyTensor<float>(AttTens);
    return VoxUtil::CopyTensor<float>(AuTens);
}