2 changes: 1 addition & 1 deletion src/lib.rs
@@ -12,7 +12,7 @@ pub mod training;
pub fn setup_logging() {
let filter = match env::var("RUST_LOG") {
Ok(_) => EnvFilter::from_env("RUST_LOG"),
_ => EnvFilter::new("xd_tts=info,app=info,trainer=info"),
_ => EnvFilter::new("xd_tts=debug,app=info,trainer=info"),
};
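    // With RUST_LOG unset, xd_tts now defaults to debug (app and trainer stay
    // at info); setting RUST_LOG replaces the whole filter.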

let fmt = tracing_subscriber::fmt::Layer::default();
212 changes: 94 additions & 118 deletions src/tacotron2/mod.rs
@@ -4,9 +4,12 @@ use griffin_lim::mel::create_mel_filter_bank;
use griffin_lim::GriffinLim;
use ndarray::Array2;
use ndarray::{concatenate, prelude::*};
use ort::{inputs, CPUExecutionProvider, GraphOptimizationLevel, Session, Tensor};
use tract_onnx::prelude::*;
use tract_onnx::tract_hir::infer::InferenceOp;
use std::path::Path;
use std::str::FromStr;
use std::sync::Arc;
use std::rc::Rc;
use tracing::{debug, info};

// Mel parameters:
@@ -65,51 +68,53 @@ fn sigmoid(x: f32) -> f32 {
}
}

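// `SimplePlan` is tract's executable plan over a graph. Parameterising it with
// `InferenceFact` and `Box<dyn InferenceOp>` means the un-typed inference graph
// is run as loaded by `into_runnable()` below, which (presumably the point
// here) keeps dynamic dimensions such as the phoneme sequence length intact.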
type Model = SimplePlan<InferenceFact, Box<dyn InferenceOp>, Graph<InferenceFact, Box<dyn InferenceOp>>>;

// Downloaded from `https://developer.nvidia.com/joc-tacotron2-fp32-pyt-20190306` and used
// `export_tacotron2_onnx.py` in https://github.com/NVIDIA/DeepLearningExamples
pub struct Tacotron2 {
encoder: Session,
decoder: Session,
postnet: Session,
encoder: Model,
decoder: Model,
postnet: Model,
phoneme_ids: Vec<Unit>,
}

struct DecoderState {
decoder_input: Array2<f32>,
attention_hidden: Array2<f32>,
attention_cell: Array2<f32>,
decoder_hidden: Array2<f32>,
decoder_cell: Array2<f32>,
attention_weights: Array2<f32>,
attention_weights_cum: Array2<f32>,
attention_context: Array2<f32>,
// memory: CowArray<f32, Ix3>,
// processed_memory: CowArray<f32, Ix3>,
mask: Array2<bool>,
decoder_input: TValue,
attention_hidden: TValue,
attention_cell: TValue,
decoder_hidden: TValue,
decoder_cell: TValue,
attention_weights: TValue,
attention_weights_cum: TValue,
attention_context: TValue,
mask: TValue,
}

impl DecoderState {
fn new(memory: &ArrayViewD<f32>, unpadded_len: usize) -> Self {
fn new(memory: &Tensor, unpadded_len: usize) -> anyhow::Result<Self> {
let bs = memory.shape()[0];
let seq_len = memory.shape()[1];
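// These sizes follow NVIDIA's reference Tacotron2 hyperparameters and must
// agree with the exported ONNX graphs.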
let attention_rnn_dim = 1024;
let decoder_rnn_dim = 1024;
let encoder_embedding_dim = 512;
let n_mel_channels = 80;

let attention_hidden = Array2::zeros((bs, attention_rnn_dim));
let attention_cell = Array2::zeros((bs, attention_rnn_dim));
let decoder_hidden = Array2::zeros((bs, decoder_rnn_dim));
let decoder_cell = Array2::zeros((bs, decoder_rnn_dim));
let attention_weights = Array2::zeros((bs, seq_len));
let attention_weights_cum = Array2::zeros((bs, seq_len));
let attention_context = Array2::zeros((bs, encoder_embedding_dim));
let decoder_input = Array2::zeros((bs, n_mel_channels));
let attention_hidden = TValue::from_const(Arc::new(Tensor::zero::<f32>(&[bs, attention_rnn_dim])?));
let attention_cell = attention_hidden.clone();
let decoder_hidden = TValue::from_const(Arc::new(Tensor::zero::<f32>(&[bs, decoder_rnn_dim])?));
let decoder_cell = decoder_hidden.clone();
let attention_weights = TValue::from_const(Arc::new(Tensor::zero::<f32>(&[bs, seq_len])?));
let attention_weights_cum = attention_weights.clone();
let attention_context = TValue::from_const(Arc::new(Tensor::zero::<f32>(&[bs, encoder_embedding_dim])?));
let decoder_input = TValue::from_const(Arc::new(Tensor::zero::<f32>(&[bs, n_mel_channels])?));
// This is only really needed for batched inputs
let mut mask = Array2::from_elem((1, seq_len), false);
mask.slice_mut(s![.., unpadded_len..]).fill(true);
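// `true` marks the padded tail beyond `unpadded_len`, so attention ignores it.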

Self {
let mask = TValue::from_const(Arc::new(mask.into()));

Ok(Self {
attention_hidden,
attention_cell,
decoder_hidden,
@@ -119,37 +124,30 @@ impl DecoderState {
attention_context,
decoder_input,
mask,
}
})
}
}

impl Tacotron2 {
pub fn load(path: impl AsRef<Path>) -> anyhow::Result<Self> {
// ort calls into a C++ library which has its own global initialisation that
// needs to be run. Fortunately, it can be called multiple times so we don't
// have to fiddle around to make it safer.
ort::init()
.with_name("xd_tts")
.with_execution_providers(&[CPUExecutionProvider::default().build()])
.commit()?;

// Load all the networks. Context is added to the error so we can easily tell
// which network messes things up

let encoder = Session::builder()?
.with_optimization_level(GraphOptimizationLevel::Level1)?
.with_model_from_file(path.as_ref().join("encoder.onnx"))
.context("converting encoder to runnable model")?;

let decoder = Session::builder()?
.with_optimization_level(GraphOptimizationLevel::Level1)?
.with_model_from_file(path.as_ref().join("decoder_iter.onnx"))
.context("converting decoder_iter to runnable model")?;

let postnet = Session::builder()?
.with_optimization_level(GraphOptimizationLevel::Level1)?
.with_model_from_file(path.as_ref().join("postnet.onnx"))
.context("converting postnet to runnable model")?;

let encoder = tract_onnx::onnx()
.model_for_path(path.as_ref().join("encoder.onnx"))
.context("loading encoder onnx")?
.into_runnable()
.context("creating runnable encoder")?;

let decoder = tract_onnx::onnx()
.model_for_path(path.as_ref().join("decoder_iter.onnx"))
.context("loading decoder onnx")?
.into_runnable()
.context("creating runnable decoder")?;

let postnet = tract_onnx::onnx()
.model_for_path(path.as_ref().join("postnet.onnx"))
.context("loading postnet onnx")?
.into_runnable()
.context("creating runnable postnet")?;

Ok(Self {
encoder,
@@ -161,40 +159,44 @@ impl Tacotron2 {

fn run_decoder(
&self,
memory: &Array<f32, IxDyn>,
processed_memory: &Array<f32, IxDyn>,
state: &mut DecoderState,
memory: TValue,
processed_memory: TValue,
state: DecoderState,
) -> anyhow::Result<Array2<f32>> {
let gate_threshold = 0.6;
let max_decoder_steps = 1000;
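// The decoder emits a stop-token logit each step; once its sigmoid crosses
// `gate_threshold` the utterance is considered finished, with
// `max_decoder_steps` as a hard cap if the gate never fires.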

let mut inputs = inputs![
"decoder_input" => state.decoder_input.view(),
"attention_hidden" => state.attention_hidden.view(),
"attention_cell" => state.attention_cell.view(),
"decoder_hidden" => state.decoder_hidden.view(),
"decoder_cell" => state.decoder_cell.view(),
"attention_weights" => state.attention_weights.view(),
"attention_weights_cum" => state.attention_weights_cum.view(),
"attention_context" => state.attention_context.view(),
"memory" => memory.view(),
"processed_memory" => processed_memory.view(),
"mask" => state.mask.view()
]?;
let mut inputs = tvec![
state.decoder_input,
state.attention_hidden,
state.attention_cell,
state.decoder_hidden,
state.decoder_cell,
state.attention_weights,
state.attention_weights_cum,
state.attention_context,
memory.clone(),
processed_memory.clone(),
state.mask.clone(),
];
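// Unlike ort's named inputs, tract takes inputs positionally, so this order
// must match the input order baked into decoder_iter.onnx.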
// Concat the spectrogram etc

let mut mel_spec = Array2::zeros((0, 0));

// Because we always break out of this, we could use `loop`.
for i in 0..max_decoder_steps {
debug!("Decoder iter: {}", i);
// Run one decoder iteration
let mut infer = self.decoder.run(inputs)?;

let gate_prediction = &infer["gate_prediction"].extract_tensor::<f32>()?;
let mel_output = &infer["decoder_output"].extract_tensor::<f32>()?;
let mel_output = mel_output.view().clone().into_dimensionality()?;
let gate_prediction = infer.remove(1);
let gate_prediction = *gate_prediction.to_scalar::<f32>()?;
let mel_output = &infer[0];
let mel_output = mel_output.to_array_view::<f32>()?
.clone()
.into_dimensionality()?;

debug!("Gate: {}", gate_prediction.view()[[0, 0]]);
debug!("Gate: {}", gate_prediction);

if i == 0 {
mel_spec = mel_output.to_owned();
@@ -203,7 +205,7 @@ impl Tacotron2 {
.context("Joining decoder iter output")?;
}

if sigmoid(gate_prediction.view()[[0, 0]]) > gate_threshold
if sigmoid(gate_prediction) > gate_threshold
|| i + 1 == max_decoder_steps
{
debug!("Stopping after {} steps", i);
@@ -212,47 +214,20 @@ impl Tacotron2 {
// Prepare the inputs for the next run. We could guard this with a condition,
// but since the values are moved into the inference call it's hard to do that
// and keep the borrow checker happy, so the condition was moved above with the
// break.
inputs = inputs![
"memory" => memory.view(),
"processed_memory" => processed_memory.view(),
"mask" => state.mask.view(),
]?;
inputs.insert("decoder_input", infer.remove("decoder_output").unwrap());
inputs.insert(
"attention_hidden",
infer.remove("out_attention_hidden").unwrap(),
);
inputs.insert(
"attention_cell",
infer.remove("out_attention_cell").unwrap(),
);
inputs.insert(
"decoder_hidden",
infer.remove("out_decoder_hidden").unwrap(),
);
inputs.insert("decoder_cell", infer.remove("out_decoder_cell").unwrap());
inputs.insert(
"attention_weights",
infer.remove("out_attention_weights").unwrap(),
);
inputs.insert(
"attention_weights_cum",
infer.remove("out_attention_weights_cum").unwrap(),
);
inputs.insert(
"attention_context",
infer.remove("out_attention_context").unwrap(),
);
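// With the gate output removed above, the remaining outputs (decoder_output,
// out_attention_hidden, ...) line up one-to-one with the next iteration's
// inputs, which is what makes reusing `infer` wholesale valid.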
inputs = infer;
inputs.push(memory.clone());
inputs.push(processed_memory.clone());
inputs.push(state.mask.clone());
}

// We have to transpose it and add a batch dimension to get the
// (1, n_mels, steps) shape the postnet expects.
let mel_spec = mel_spec.t().insert_axis(Axis(0));
let mel_spec = mel_spec.t().insert_axis(Axis(0)).into_owned();

let post = self.postnet.run(inputs![mel_spec.view()]?)?;
let mel_spec = TValue::Var(Rc::new(mel_spec.into()));
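// Assumption: `TValue::Var` (Rc-backed, run-local) lets tract consume the
// buffer in place, whereas the `Const` values above (Arc-backed) are cheap to
// clone and reuse across decoder steps.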
let post = self.postnet.run(tvec![mel_spec])?;

let post = post["mel_outputs_postnet"]
.extract_tensor::<f32>()?
.view()
let post = post[0]
.to_array_view::<f32>()?
.clone()
.remove_axis(Axis(0))
.into_dimensionality()?
@@ -262,6 +237,7 @@ impl Tacotron2 {
}

fn infer_chunk(&self, mut phonemes: Vec<i64>) -> anyhow::Result<Array2<f32>> {
debug!("Running {} phonemes pre padding", phonemes.len());
let units_len = phonemes.len();
assert!(units_len <= 100);

@@ -274,25 +250,25 @@ impl Tacotron2 {
}

// Run encoder
debug!("{:?}", phonemes.len());
let plen = arr1(&[phonemes.len() as i64]);
let plen = Tensor::from_shape(&[1], &[phonemes.len() as i64])?;
let plen = TValue::from_const(Arc::new(plen));
let phonemes =
Array2::from_shape_vec((1, phonemes.len()), phonemes).context("invalid dimensions")?;
let phonemes = TValue::from_const(Arc::new(phonemes.into()));

let encoder_outputs = self.encoder.run(inputs![phonemes, plen]?)?;
debug!("Starting encoder inference");
let mut encoder_outputs = self.encoder.run(tvec![phonemes, plen])?;
debug!("Finished encoder inference");
assert_eq!(encoder_outputs.len(), 3);

// The outputs in order are: memory, processed_memory, lens.
let memory: Tensor<f32> = encoder_outputs[0].extract_tensor()?;
let processed_memory: Tensor<f32> = encoder_outputs[1].extract_tensor()?;

let mut decoder_state = DecoderState::new(&memory.view(), units_len);
let memory = encoder_outputs.remove(0);
let processed_memory = encoder_outputs.remove(0);

let memory = memory.view().to_owned();
let processed_memory = processed_memory.view().to_owned();
let decoder_state = DecoderState::new(&memory, units_len)?;

self.run_decoder(&memory, &processed_memory, &mut decoder_state)
self.run_decoder(memory, processed_memory, decoder_state)
}
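
// A minimal usage sketch, assuming phoneme ids already mapped to i64 and a
// model directory holding the encoder/decoder_iter/postnet ONNX files named in
// `load`; `synthesise_chunk` is a hypothetical helper, not part of this PR:
fn synthesise_chunk(dir: &Path, phoneme_ids: Vec<i64>) -> anyhow::Result<Array2<f32>> {
    let tts = Tacotron2::load(dir)?;
    tts.infer_chunk(phoneme_ids) // at most 100 units per chunk
}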

/// Runs inference on the units returning a mel-spectrogram