Skip to content

Add models for kinect-wsj dataset #574

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added .DS_Store
Binary file not shown.
3 changes: 3 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"python.formatting.provider": "yapf"
}
238 changes: 238 additions & 0 deletions asteroid/engine/system_kinect_wsj.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
import torch
import pytorch_lightning as pl
from torch.optim.lr_scheduler import ReduceLROnPlateau
from asteroid_filterbanks.transforms import mag

from ..utils import flatten_dict


class System(pl.LightningModule):
    """Base class for deep learning systems.

    Contains a model, an optimizer, a loss function, training and validation
    dataloaders and learning rate scheduler.

    Note that by default, any PyTorch-Lightning hooks are *not* passed to the model.
    If you want to use Lightning hooks, add the hooks to a subclass::

        class MySystem(System):
            def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
                return self.model.on_train_batch_start(batch, batch_idx, dataloader_idx)

    Args:
        model (torch.nn.Module): Instance of model.
        optimizer (torch.optim.Optimizer): Instance or list of optimizers.
        loss_func (callable): Loss function with signature
            (est_targets, targets).
        train_loader (torch.utils.data.DataLoader): Training dataloader.
        val_loader (torch.utils.data.DataLoader): Validation dataloader.
        scheduler (torch.optim.lr_scheduler._LRScheduler): Instance, or list
            of learning rate schedulers. Also supports dict or list of dict as
            ``{"interval": "step", "scheduler": sched}`` where ``interval=="step"``
            for step-wise schedulers and ``interval=="epoch"`` for classical ones.
        config: Anything to be saved with the checkpoints during training.
            The config dictionary to re-instantiate the run for example.
        mask_mixture (bool): Stored as ``self.mask_mixture``. It is not read
            by the default ``common_step``; it is available to subclasses
            that implement a mask-based step (e.g. on top of ``unpack_data``).

    .. note:: By default, ``training_step`` (used by ``pytorch-lightning`` in the
        training loop) and ``validation_step`` (used for the validation loop)
        share ``common_step``. If you want different behavior for the training
        loop and the validation loop, overwrite both ``training_step`` and
        ``validation_step`` instead.

    For more info on its methods, properties and hooks, have a look at lightning's docs:
    https://pytorch-lightning.readthedocs.io/en/stable/lightning_module.html#lightningmodule-api
    """

    # Metric name given to schedulers that need a monitored quantity.
    default_monitor: str = "val_loss"

    def __init__(
        self,
        model,
        optimizer,
        loss_func,
        train_loader,
        val_loader=None,
        scheduler=None,
        config=None,
        mask_mixture=True,
    ):
        super().__init__()
        self.model = model
        self.optimizer = optimizer
        self.loss_func = loss_func
        self.mask_mixture = mask_mixture
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.scheduler = scheduler
        self.config = {} if config is None else config
        # Save lightning's AttributeDict under self.hparams
        self.save_hyperparameters(self.config_to_hparams(self.config))

    def forward(self, *args, **kwargs):
        """Applies forward pass of the model.

        Returns:
            :class:`torch.Tensor`
        """
        return self.model(*args, **kwargs)

    def common_step(self, batch, batch_nb, train=True):
        """Common forward step between training and validation.

        The function of this method is to unpack the data given by the loader,
        forward the batch through the model and compute the loss.
        Pytorch-lightning handles all the rest.

        Args:
            batch: the object returned by the loader. It must unpack into
                exactly three items; only the first two (inputs and targets)
                are used here.
            batch_nb (int): The number of the batch in the epoch.
            train (bool): Whether in training mode. Needed only if the training
                and validation steps are fundamentally different, otherwise,
                pytorch-lightning handles the usual differences.

        Returns:
            :class:`torch.Tensor` : The loss value on this batch.

        .. note::
            This is typically the method to overwrite when subclassing
            ``System``. If the training and validation steps are somehow
            different (except for ``loss.backward()`` and ``optimizer.step()``),
            the argument ``train`` can be used to switch behavior.
            Otherwise, ``training_step`` and ``validation_step`` can be overwritten.
        """
        # The third item of the batch is not needed by this loss.
        inputs, targets, _ = batch
        # Take the first channels (index 0 along the last axis).
        inputs = inputs[..., 0]
        targets = targets[..., 0]
        est_targets = self(inputs)
        return self.loss_func(est_targets, targets)

    def training_step(self, batch, batch_nb):
        """Pass data through the model and compute the loss.

        Backprop is **not** performed (meaning PL will do it for you).

        Args:
            batch: the object returned by the loader (a list of torch.Tensor
                in most cases) but can be something else.
            batch_nb (int): The number of the batch in the epoch.

        Returns:
            torch.Tensor, the value of the loss.
        """
        loss = self.common_step(batch, batch_nb, train=True)
        self.log("loss", loss, logger=True)
        return loss

    def validation_step(self, batch, batch_nb):
        """Compute the validation loss and log it as ``val_loss``.

        Args:
            batch: the object returned by the loader (a list of torch.Tensor
                in most cases) but can be something else.
            batch_nb (int): The number of the batch in the epoch.
        """
        loss = self.common_step(batch, batch_nb, train=False)
        self.log("val_loss", loss, on_epoch=True, prog_bar=True)

    def on_validation_epoch_end(self):
        """Log hp_metric to tensorboard for hparams selection.

        Nothing is logged when ``val_loss`` has not been recorded yet or when
        the trainer has no logger attached.
        """
        hp_metric = self.trainer.callback_metrics.get("val_loss", None)
        logger = self.trainer.logger
        # The logger is None when the Trainer was created with logger=False.
        if hp_metric is None or logger is None:
            return
        logger.log_metrics({"hp_metric": hp_metric}, step=self.trainer.global_step)

    def configure_optimizers(self):
        """Initialize optimizers, batch-wise and epoch-wise schedulers.

        Returns:
            ``self.optimizer`` alone when no scheduler was given, otherwise a
            ``([optimizer], schedulers)`` pair. Scheduler dicts are completed
            in place with ``monitor``, ``frequency`` and ``interval`` defaults.
        """
        if self.scheduler is None:
            return self.optimizer

        if not isinstance(self.scheduler, (list, tuple)):
            self.scheduler = [self.scheduler]  # support multiple schedulers

        epoch_schedulers = []
        for sched in self.scheduler:
            if not isinstance(sched, dict):
                if isinstance(sched, ReduceLROnPlateau):
                    # ReduceLROnPlateau needs a metric to watch.
                    sched = {"scheduler": sched, "monitor": self.default_monitor}
                epoch_schedulers.append(sched)
            else:
                sched.setdefault("monitor", self.default_monitor)
                sched.setdefault("frequency", 1)
                # A dict without "interval" used to raise KeyError; fall back
                # to the epoch-wise interval, which is also Lightning's default.
                sched.setdefault("interval", "epoch")
                # Backward compat
                if sched["interval"] == "batch":
                    sched["interval"] = "step"
                assert sched["interval"] in [
                    "epoch",
                    "step",
                ], "Scheduler interval should be either step or epoch"
                epoch_schedulers.append(sched)
        return [self.optimizer], epoch_schedulers

    def train_dataloader(self):
        """Training dataloader"""
        return self.train_loader

    def val_dataloader(self):
        """Validation dataloader"""
        return self.val_loader

    def on_save_checkpoint(self, checkpoint):
        """Overwrite if you want to save more things in the checkpoint."""
        checkpoint["training_config"] = self.config
        return checkpoint

    def unpack_data(self, batch, EPS=1e-8):
        """Turn a ``(mix, sources, noise)`` batch into mask targets.

        Not called by the default ``common_step``; provided for mask-based
        training steps.

        Args:
            batch: Three-item batch ``(mix, sources, noise)``. Index 0 of the
                last axis of each item is kept (the first channel).
            EPS (float): Added to the mask denominator to avoid dividing by 0.

        Returns:
            tuple: ``(mix, binary_mask, real_mask)`` where ``real_mask`` is the
            ratio of each source magnitude to the summed source-plus-noise
            magnitude, and ``binary_mask`` is ``real_mask.argmax(1)``.
        """
        mix, sources, noise = batch
        # Take only the first channel
        mix = mix[..., 0]
        sources = sources[..., 0]
        noise = noise[..., 0]
        noise = noise.unsqueeze(1)
        # Compute magnitude spectrograms and IRM
        src_mag_spec = mag(self.model.encoder(sources))
        noise_mag_spec = mag(self.model.encoder(noise))
        noise_mag_spec = noise_mag_spec.unsqueeze(1)
        real_mask = src_mag_spec / (
            noise_mag_spec + src_mag_spec.sum(1, keepdim=True) + EPS
        )
        # Get the src idx having the maximum energy
        binary_mask = real_mask.argmax(1)
        return mix, binary_mask, real_mask

    @staticmethod
    def config_to_hparams(dic):
        """Sanitizes the config dict to be handled correctly by torch
        SummaryWriter. It flattens the config dict, converts ``None`` to
        ``"None"`` and any list and tuple into torch.Tensors.

        Args:
            dic (dict): Dictionary to be transformed.

        Returns:
            dict: Transformed dictionary.
        """
        dic = flatten_dict(dic)
        # Only values of existing keys are replaced, so the dict size is
        # unchanged while iterating.
        for k, v in dic.items():
            if v is None:
                dic[k] = str(v)
            elif isinstance(v, (list, tuple)):
                dic[k] = torch.tensor(v)
        return dic
Binary file added egs/.DS_Store
Binary file not shown.
Binary file added egs/kinect-wsj/.DS_Store
Binary file not shown.
Binary file added egs/kinect-wsj/ConvTasNet/.DS_Store
Binary file not shown.
2 changes: 2 additions & 0 deletions egs/kinect-wsj/ConvTasNet/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Results
Coming soon
143 changes: 143 additions & 0 deletions egs/kinect-wsj/ConvTasNet/eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
import os
import random
import soundfile as sf
import torch
import yaml
import json
import argparse
import pandas as pd
from tqdm import tqdm
from pprint import pprint

from asteroid.metrics import get_metrics
from asteroid.losses import PITLossWrapper, pairwise_neg_sisdr
from asteroid import ConvTasNet
from asteroid.utils import tensors_to_device
from asteroid.dsp.normalization import normalize_estimates
from asteroid.data import KinectWsjMixDataset

# Command-line interface of the evaluation script. Options are declared in
# the order they should appear in ``--help``.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--test_dir", required=True, type=str, help="Test directory including the csv files"
)
parser.add_argument("--n_src", default=2, type=int)
parser.add_argument(
    "--out_dir",
    required=True,
    type=str,
    help="Directory in exp_dir where the eval results will be stored",
)
parser.add_argument(
    "--use_gpu", default=0, type=int, help="Whether to use the GPU for model execution"
)
parser.add_argument("--exp_dir", help="Experiment root", default="exp/tmp")
parser.add_argument(
    "--n_save_ex",
    default=10,
    type=int,
    help="Number of audio examples to save, -1 means all",
)

# Metrics computed for every utterance and averaged in the final summary.
compute_metrics = ["si_sdr", "sdr", "sir", "sar", "stoi"]


def main(conf):
    """Evaluate a trained ConvTasNet on a Kinect-WSJ test set.

    Loads ``best_model.pth`` from the experiment directory, separates every
    test mixture, computes the metrics listed in ``compute_metrics`` and
    writes ``all_metrics.csv``, ``final_metrics.json`` and a random subset of
    audio examples under ``<exp_dir>/<out_dir>``.

    Args:
        conf (dict): Evaluation configuration. Reads the keys ``exp_dir``,
            ``out_dir``, ``test_dir``, ``n_src``, ``use_gpu``, ``n_save_ex``
            and ``sample_rate``. ``n_save_ex == -1`` is replaced in place by
            the size of the test set.
    """
    model_path = os.path.join(conf["exp_dir"], "best_model.pth")
    model = ConvTasNet.from_pretrained(model_path)
    # Handle device placement
    if conf["use_gpu"]:
        model.cuda()
    model_device = next(model.parameters()).device
    # segment=None: uses all segment length
    test_set = KinectWsjMixDataset(conf["test_dir"], n_src=conf["n_src"], segment=None)
    # Used to reorder sources only
    loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from="pw_mtx")

    eval_save_dir = os.path.join(conf["exp_dir"], conf["out_dir"])
    ex_save_dir = os.path.join(eval_save_dir, "examples/")
    # Create the output folder up front: when no example is saved, nothing
    # else creates it before the metric files are written.
    os.makedirs(eval_save_dir, exist_ok=True)

    # Randomly choose the indexes of sentences to save.
    if conf["n_save_ex"] == -1:
        conf["n_save_ex"] = len(test_set)
    # random.sample raises if asked for more items than available.
    n_save_ex = min(conf["n_save_ex"], len(test_set))
    # A set gives O(1) membership tests inside the evaluation loop.
    save_idx = set(random.sample(range(len(test_set)), n_save_ex))

    series_list = []
    # Scoped no_grad: gradient tracking is restored when the loop is done.
    with torch.no_grad():
        for idx in tqdm(range(len(test_set))):
            # Forward the network on the mixture.
            mix, sources, _ = tensors_to_device(test_set[idx], device=model_device)
            # Keep index 0 of the last axis (first channel).
            mix = mix[..., 0]
            sources = sources[..., 0]
            est_sources = model.separate(mix[None, None])
            loss, reordered_sources = loss_func(
                est_sources, sources[None], return_est=True
            )
            mix_np = mix[None].cpu().data.numpy()
            sources_np = sources.cpu().data.numpy()
            est_sources_np = reordered_sources.squeeze(0).cpu().data.numpy()
            # For each utterance, we get a dictionary with the mixture path,
            # the input and output metrics
            utt_metrics = get_metrics(
                mix_np,
                sources_np,
                est_sources_np,
                sample_rate=conf["sample_rate"],
                metrics_list=compute_metrics,
            )
            utt_metrics["mix_path"] = test_set.mix[idx][0]
            series_list.append(pd.Series(utt_metrics))

            # Save some examples in a folder. Wav files and metrics as text.
            if idx not in save_idx:
                continue
            # Normalized estimates are only needed for the saved examples.
            est_sources_np_normalized = normalize_estimates(est_sources_np, mix_np)
            local_save_dir = os.path.join(ex_save_dir, "ex_{}/".format(idx))
            os.makedirs(local_save_dir, exist_ok=True)
            sf.write(
                os.path.join(local_save_dir, "mixture.wav"),
                mix_np[0],
                conf["sample_rate"],
            )
            # Loop over the sources and estimates
            for src_idx, src in enumerate(sources_np):
                sf.write(
                    os.path.join(local_save_dir, "s{}.wav".format(src_idx)),
                    src,
                    conf["sample_rate"],
                )
            for src_idx, est_src in enumerate(est_sources_np_normalized):
                sf.write(
                    os.path.join(local_save_dir, "s{}_estimate.wav".format(src_idx)),
                    est_src,
                    conf["sample_rate"],
                )
            # Write local metrics to the example folder.
            with open(os.path.join(local_save_dir, "metrics.json"), "w") as f:
                json.dump(utt_metrics, f, indent=0)

    # Save all metrics to the experiment folder.
    all_metrics_df = pd.DataFrame(series_list)
    all_metrics_df.to_csv(os.path.join(eval_save_dir, "all_metrics.csv"))

    # Print and save summary metrics
    final_results = {}
    for metric_name in compute_metrics:
        input_metric_name = "input_" + metric_name
        # Improvement of the estimate over the unprocessed mixture.
        ldf = all_metrics_df[metric_name] - all_metrics_df[input_metric_name]
        final_results[metric_name] = all_metrics_df[metric_name].mean()
        final_results[metric_name + "_imp"] = ldf.mean()

    print("Overall metrics :")
    pprint(final_results)

    with open(os.path.join(eval_save_dir, "final_metrics.json"), "w") as f:
        json.dump(final_results, f, indent=0)


if __name__ == "__main__":
    # Script entry point: merge the CLI options with the training
    # configuration stored in the experiment folder, then run the evaluation.
    args = parser.parse_args()
    # Load training config
    with open(os.path.join(args.exp_dir, "conf.yml")) as conf_file:
        train_conf = yaml.safe_load(conf_file)
    # CLI options first, then the values taken from the training run.
    arg_dic = dict(
        vars(args),
        sample_rate=train_conf["data"]["sample_rate"],
        train_conf=train_conf,
    )

    main(arg_dic)
Loading