This repository was archived by the owner on Mar 25, 2024. It is now read-only.

add NME_SC method for clustering #1

Open · wants to merge 12 commits into max_speaker
77 changes: 77 additions & 0 deletions main.py
@@ -0,0 +1,77 @@
import os,sys,time
import argparse
from simple_diarizer.diarizer import Diarizer
import pprint

parser = argparse.ArgumentParser(
    description="Speaker diarization",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(dest='audio_name', type=str, help="Input audio file")
parser.add_argument(dest='outputfile', nargs="?", default=None, help="Optional output file")
parser.add_argument("--number_of_speakers", dest='number_of_speaker', default=None, type=int, help="Number of speakers (if known)")
parser.add_argument("--max_speakers", dest='max_speakers', default=25, type=int, help="Maximum number of speakers (if number of speaker is unknown)")
parser.add_argument("--embed_model", dest='embed_model', default="ecapa", type=str, help="Name of embedding")
parser.add_argument("--cluster_method", dest='cluster_method', default="nme-sc", type=str, help="Clustering method")
args = parser.parse_args()

diar = Diarizer(
    embed_model=args.embed_model,        # 'xvec' and 'ecapa' supported
    cluster_method=args.cluster_method,  # 'ahc', 'sc' and 'nme-sc' supported
)

WAV_FILE = args.audio_name
num_speakers = args.number_of_speaker  # argparse already yields None when the option is omitted
max_spk = args.max_speakers
output_file = args.outputfile

t0 = time.time()

segments = diar.diarize(WAV_FILE, num_speakers=num_speakers, max_speakers=max_spk, outfile=output_file)

print("Time used for processing:", time.time() - t0)

if not output_file:

    json = {}
    _segments = []
    _speakers = {}
    seg_id = 1
    spk_i = 1
    spk_i_dict = {}

    for seg in segments:

        segment = {}
        segment["seg_id"] = seg_id

        if seg['label'] not in spk_i_dict.keys():
            spk_i_dict[seg['label']] = spk_i
            spk_i += 1

        spk_id = "spk" + str(spk_i_dict[seg['label']])
        segment["spk_id"] = spk_id
        segment["seg_begin"] = round(seg['start'])
        segment["seg_end"] = round(seg['end'])

        if spk_id not in _speakers:
            _speakers[spk_id] = {}
            _speakers[spk_id]["spk_id"] = spk_id
            _speakers[spk_id]["duration"] = seg['end'] - seg['start']
            _speakers[spk_id]["nbr_seg"] = 1
        else:
            _speakers[spk_id]["duration"] += seg['end'] - seg['start']
            _speakers[spk_id]["nbr_seg"] += 1

        _segments.append(segment)
        seg_id += 1

    for spkstat in _speakers.values():
        spkstat["duration"] = round(spkstat["duration"])

    json["speakers"] = list(_speakers.values())
    json["segments"] = _segments

    pprint.pprint(json)
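For reference, a minimal sketch of driving the same API directly from Python with the new nme-sc option; "audio.wav" is a placeholder for a 16 kHz mono WAV file:

from simple_diarizer.diarizer import Diarizer

diar = Diarizer(embed_model="ecapa", cluster_method="nme-sc")

# num_speakers=None lets NME-SC estimate the speaker count, capped by max_speakers
segments = diar.diarize("audio.wav", num_speakers=None, max_speakers=25)

for seg in segments:
    print(seg["label"], round(seg["start"], 2), round(seg["end"], 2))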

6 changes: 4 additions & 2 deletions requirements.txt
@@ -1,6 +1,8 @@
ipython>=7.9.0
matplotlib>=3.5.1
# ipython>=7.9.0
# matplotlib>=3.5.1
pandas>=1.3.5
scikit-learn>=1.0.2
speechbrain>=0.5.11
torchaudio>=0.10.1
onnxruntime>=1.14.0
scipy<=1.8.1 # newer versions can cause segmentation faults
2 changes: 1 addition & 1 deletion simple_diarizer/__init__.py
@@ -1,3 +1,3 @@
import os

__version__ = os.getenv("GITHUB_REF_NAME", "latest")
__version__ = os.getenv("GITHUB_REF_NAME", "1.0.2")
54 changes: 48 additions & 6 deletions simple_diarizer/cluster.py
@@ -5,7 +5,7 @@
from scipy.ndimage import gaussian_filter
from sklearn.cluster import AgglomerativeClustering, KMeans, SpectralClustering
from sklearn.metrics import pairwise_distances

from .spectral_clustering import NME_SpectralClustering

def similarity_matrix(embeds, metric="cosine"):
    return pairwise_distances(embeds, metric=metric)
@@ -43,9 +43,7 @@ def cluster_AHC(embeds, n_clusters=None, threshold=None, metric="cosine", **kwargs):
# A lot of these methods are lifted from
# https://github.com/wq2012/SpectralCluster
##########################################


def cluster_SC(embeds, n_clusters=None, threshold=None, enhance_sim=True, **kwargs):
def cluster_SC(embeds, n_clusters=None, max_speakers=None, threshold=None, enhance_sim=True, **kwargs):
    """
    Cluster embeds using Spectral Clustering
    """
@@ -59,7 +57,7 @@ def cluster_SC(embeds, n_clusters=None, threshold=None, enhance_sim=True, **kwargs):
    if n_clusters is None:
        (eigenvalues, eigenvectors) = compute_sorted_eigenvectors(S)
        # Get number of clusters.
        k = compute_number_of_clusters(eigenvalues, 100, threshold)
        k = compute_number_of_clusters(eigenvalues, max_speakers, threshold)

        # Get spectral embeddings.
        spectral_embeddings = eigenvectors[:, :k]
@@ -82,6 +80,25 @@ def cluster_SC(embeds, n_clusters=None, threshold=None, enhance_sim=True, **kwargs):
    return cluster_model.fit_predict(S)


def cluster_NME_SC(embeds, n_clusters=None, max_speakers=None, threshold=None, enhance_sim=True, **kwargs):
    """
    Cluster embeds using NME-Spectral Clustering.

    Unlike cluster_SC, no threshold is needed when n_clusters is None:
    NME-SC estimates the number of clusters itself, bounded by max_speakers.
    """

    S = cos_similarity(embeds)

    labels = NME_SpectralClustering(
        S,
        num_clusters=n_clusters,
        max_num_clusters=max_speakers,
    )

    return labels
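As a side note, NME-SC picks the number of speakers from an eigengap criterion on the affinity matrix. A toy sketch of that idea follows; it is not the PR's spectral_clustering module, just an illustration assuming only numpy:

import numpy as np

def eigengap_num_clusters(S, max_clusters):
    """Pick k at the largest gap between consecutive sorted eigenvalues of S."""
    eigvals = np.sort(np.linalg.eigvalsh(S))[::-1]   # descending order
    upper = max(2, min(max_clusters, len(eigvals)))
    gaps = eigvals[: upper - 1] - eigvals[1:upper]
    return int(np.argmax(gaps)) + 1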


def diagonal_fill(A):
"""
Sets the diagonal elemnts of the matrix to the max of each row
@@ -134,7 +151,7 @@ def row_max_norm(A):
def sim_enhancement(A):
    func_order = [
        diagonal_fill,
        gaussian_blur,

        row_threshold_mult,
        symmetrization,
        diffusion,
@@ -144,6 +161,31 @@ def sim_enhancement(A):
        A = f(A)
    return A

def cos_similarity(x):
    """Compute the cosine similarity matrix in a CPU- and memory-efficient way

    Args:
        x (np.ndarray): embeddings, 2D array, embeddings are in rows

    Returns:
        np.ndarray: cosine similarity matrix

    """
    assert x.ndim == 2, f"x has {x.ndim} dimensions, it must be matrix"
    x = x / (np.sqrt(np.sum(np.square(x), axis=1, keepdims=True)) + 1.0e-32)
    assert np.allclose(np.ones_like(x[:, 0]), np.sum(np.square(x), axis=1))
    max_n_elm = 200000000
    step = max(max_n_elm // (x.shape[0] * x.shape[0]), 1)
    retval = np.zeros(shape=(x.shape[0], x.shape[0]), dtype=np.float64)
    x0 = np.expand_dims(x, 0)
    x1 = np.expand_dims(x, 1)
    # Accumulate the dot products over chunks of the feature dimension
    for i in range(0, x.shape[1], step):
        product = x0[:, :, i : i + step] * x1[:, :, i : i + step]
        retval += np.sum(product, axis=2, keepdims=False)
    assert np.all(retval >= -1.0001), retval
    assert np.all(retval <= 1.0001), retval
    return retval
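A quick way to sanity-check cos_similarity against a direct matrix product on small random embeddings (illustrative sketch, assumes numpy and the function above):

import numpy as np

rng = np.random.default_rng(0)
emb = rng.normal(size=(5, 192))                       # 5 fake embeddings, dimension arbitrary
norm = emb / np.linalg.norm(emb, axis=1, keepdims=True)
direct = norm @ norm.T                                # plain cosine similarity
chunked = cos_similarity(emb)                         # chunked implementation above
assert np.allclose(direct, chunked, atol=1e-6)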


def compute_affinity_matrix(X):
    """Compute the affinity matrix from data.
69 changes: 42 additions & 27 deletions simple_diarizer/diarizer.py
@@ -6,10 +6,10 @@
import pandas as pd
import torch
import torchaudio
from speechbrain.pretrained import EncoderClassifier
from speechbrain.inference.speaker import EncoderClassifier
from tqdm.autonotebook import tqdm

from .cluster import cluster_AHC, cluster_SC
from .cluster import cluster_AHC, cluster_SC, cluster_NME_SC
from .utils import check_wav_16khz_mono, convert_wavfile


@@ -25,12 +25,16 @@ def __init__(
        assert cluster_method in [
            "ahc",
            "sc",
        ], "Only ahc and sc in the supported clustering options"
            "nme-sc",
        ], "Only ahc, sc and nme-sc are supported clustering options"

        if cluster_method == "ahc":
            self.cluster = cluster_AHC
        if cluster_method == "sc":
            self.cluster = cluster_SC
        if cluster_method == "nme-sc":
            self.cluster = cluster_NME_SC

        self.vad_model, self.get_speech_ts = self.setup_VAD()

@@ -56,7 +60,7 @@ def __init__(

    def setup_VAD(self):
        model, utils = torch.hub.load(
            repo_or_dir="snakers4/silero-vad", model="silero_vad"
            repo_or_dir="snakers4/silero-vad", model="silero_vad", onnx=True
        )
        # force_reload=True)

@@ -182,6 +186,7 @@ def diarize(
        self,
        wav_file,
        num_speakers=2,
        max_speakers=None,
        threshold=None,
        silence_tolerance=0.2,
        enhance_sim=True,
@@ -194,6 +199,7 @@
        Inputs:
            wav_file (path): Path to input audio file
            num_speakers (int) or NoneType: Number of speakers to cluster to
            max_speakers (int) or NoneType: Maximum number of speakers to consider
                when num_speakers is not defined
            threshold (float) or NoneType: Threshold to cluster to if
                num_speakers is not defined
            silence_tolerance (float): Same speaker segments which are close enough together
@@ -229,10 +235,10 @@
            'cluster_labels': cluster_labels (list): cluster label for each embed in embeds
            }

        Uses AHC/SC to cluster
        Uses AHC/SC/NME-SC to cluster
        """
        recname = os.path.splitext(os.path.basename(wav_file))[0]

        if check_wav_16khz_mono(wav_file):
            signal, fs = torchaudio.load(wav_file)
        else:
@@ -249,25 +255,34 @@
print("Running VAD...")
speech_ts = self.vad(signal[0])
print("Splitting by silence found {} utterances".format(len(speech_ts)))
assert len(speech_ts) >= 1, "Couldn't find any speech during VAD"

print("Extracting embeddings...")
embeds, segments = self.recording_embeds(signal, fs, speech_ts)

print("Clustering to {} speakers...".format(num_speakers))
cluster_labels = self.cluster(
embeds,
n_clusters=num_speakers,
threshold=threshold,
enhance_sim=enhance_sim,
)

print("Cleaning up output...")
cleaned_segments = self.join_segments(cluster_labels, segments)
cleaned_segments = self.make_output_seconds(cleaned_segments, fs)
cleaned_segments = self.join_samespeaker_segments(
cleaned_segments, silence_tolerance=silence_tolerance
)
#assert len(speech_ts) >= 1, "Couldn't find any speech during VAD"

if len(speech_ts) >= 1:
print("Extracting embeddings...")
embeds, segments = self.recording_embeds(signal, fs, speech_ts)

[w,k]=embeds.shape
if w >= 2:
print('Clustering to {} speakers...'.format(num_speakers))
cluster_labels = self.cluster(embeds, n_clusters=num_speakers,max_speakers=max_speakers,
threshold=threshold, enhance_sim=enhance_sim)



cleaned_segments = self.join_segments(cluster_labels, segments)
cleaned_segments = self.make_output_seconds(cleaned_segments, fs)
cleaned_segments = self.join_samespeaker_segments(cleaned_segments,
silence_tolerance=silence_tolerance)


else:
cluster_labels =[ 1]
cleaned_segments = self.join_segments(cluster_labels, segments)
cleaned_segments = self.make_output_seconds(cleaned_segments, fs)

else:
cleaned_segments = []

print("Done!")
if outfile:
self.rttm_output(cleaned_segments, recname, outfile=outfile)
@@ -281,9 +296,9 @@ def diarize(
"cluster_labels": cluster_labels}

    @staticmethod
    def rttm_output(segments, recname, outfile=None):
    def rttm_output(segments, recname, outfile=None, channel=0):
        assert outfile, "Please specify an outfile"
        rttm_line = "SPEAKER {} 0 {} {} <NA> <NA> {} <NA> <NA>\n"
        rttm_line = "SPEAKER {} " + str(channel) + " {} {} <NA> <NA> {} <NA> <NA>\n"
        with open(outfile, "w") as fp:
            for seg in segments:
                start = seg["start"]
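For context on the new channel argument to rttm_output: the channel index is now interpolated into the RTTM template rather than being hard-coded to 0. A small illustration follows; the fill order of the remaining placeholders (recording name, start, duration, speaker label) is an assumption, since the writer loop is truncated above.

channel = 1
rttm_line = "SPEAKER {} " + str(channel) + " {} {} <NA> <NA> {} <NA> <NA>\n"
print(rttm_line.format("meeting1", 0.0, 2.5, "spk0"), end="")
# SPEAKER meeting1 1 0.0 2.5 <NA> <NA> spk0 <NA> <NA>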