Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 72 additions & 2 deletions Sources/SpeakerKit/DiarizationResult.swift
Original file line number Diff line number Diff line change
Expand Up @@ -31,26 +31,53 @@ public struct DiarizationResult: Sendable {
/// Speaker segments for this result. Externally read-only: the generic initialiser stores the
/// segments it is given, while the Pyannote initialiser starts from `[]` and then calls
/// `updateSegments(minActiveOffset:)`.
public private(set) var segments: [SpeakerSegment]
/// Optional timing information. The Pyannote initialiser leaves it `nil`; the generic
/// initialiser accepts a value, and callers may assign it afterwards.
public var timings: (any DiarizationTimings)?

/// Per-speaker centroid embeddings keyed by `speakerId`, in the raw speaker-embedder output
/// space (unnormalised, pre-PLDA). Useful for linking the same speaker across independent
/// `diarize(...)` calls without re-running the embedder.
///
/// Each centroid is the arithmetic mean of the final per-window embeddings assigned to that
/// `speakerId` after clustering and cluster reassignment, so the centroid reflects the
/// speaker's actual membership in this result.
///
/// Compare centroids with cosine distance via `centroidCosineDistance(between:_:)` or
/// `nearestSpeakerCentroid(to:)`, matching the convention used by
/// `MathOps.cosineDistanceMatrix` elsewhere in SpeakerKit. SpeakerKit does not define a
/// universal "same speaker" threshold for comparing centroids across independent runs;
/// callers should calibrate that policy for their model, audio, and application.
///
/// This field is populated by the Pyannote backend (`PyannoteDiarizer`). Other backends
/// conforming to `Diarizer` may leave it as `[:]` if they do not expose per-cluster centroids.
/// Both initialisers default it to `[:]` when the argument is omitted.
public private(set) var speakerCentroidEmbeddings: [Int: [Float]]
Comment thread
leecrossley marked this conversation as resolved.

/// Pyannote init: builds segments from a binary speaker-activity matrix.
///
/// - Parameters:
///   - binaryMatrix: One row per speaker. The row count becomes `speakerCount` and the
///     length of the first row becomes `totalFrames` (`0` when there are no rows).
///   - diarizationFrameRate: Stored as `frameRate`.
///   - speakerCentroidEmbeddings: Per-speaker centroids keyed by `speakerId`. Defaults to
///     `[:]`, so call sites that predate this parameter keep compiling.
init(binaryMatrix: [[Int]], diarizationFrameRate: Float, speakerCentroidEmbeddings: [Int: [Float]] = [:]) {
    self.binaryMatrix = binaryMatrix
    self.frameRate = diarizationFrameRate
    self.speakerCount = binaryMatrix.count
    self.totalFrames = binaryMatrix.first?.count ?? 0
    self.segments = []
    self.timings = nil
    self.speakerCentroidEmbeddings = speakerCentroidEmbeddings

    // Every stored property is initialised above, so the mutating helper may run now.
    self.updateSegments(minActiveOffset: 0.0)
}

/// Generic init: for engines that produce segments directly.
///
/// No activity matrix is involved, so `binaryMatrix` is left empty and the supplied
/// `segments` are stored as given.
///
/// - Parameters:
///   - speakerCount: Number of speakers in the result.
///   - totalFrames: Number of diarization frames covered by the result.
///   - frameRate: Diarization frame rate.
///   - segments: Segments produced by the engine.
///   - timings: Optional timing information; defaults to `nil`.
///   - speakerCentroidEmbeddings: Per-speaker centroids keyed by `speakerId`. Defaults to
///     `[:]` for engines that do not expose centroids.
public init(
    speakerCount: Int,
    totalFrames: Int,
    frameRate: Float,
    segments: [SpeakerSegment],
    timings: (any DiarizationTimings)? = nil,
    speakerCentroidEmbeddings: [Int: [Float]] = [:]
) {
    self.binaryMatrix = []
    self.speakerCount = speakerCount
    self.totalFrames = totalFrames
    self.frameRate = frameRate
    self.segments = segments
    self.timings = timings
    self.speakerCentroidEmbeddings = speakerCentroidEmbeddings
}

public mutating func updateSegments(minActiveOffset: Float) {
Expand Down Expand Up @@ -101,6 +128,49 @@ public struct DiarizationResult: Sendable {
self.segments = segments.sorted { $0.startFrame < $1.startFrame }
}

// MARK: - Speaker Centroid Comparison

/// Cosine distance in `[0.0, 2.0]` between two speaker centroids from this result.
///
/// Delegates to `MathOps.cosineDistance(_:_:)`, matching the convention used by
/// `MathOps.cosineDistanceMatrix` elsewhere in SpeakerKit. The result is clamped to
/// `[0, 2]` to absorb floating-point error near the extremes. A distance of `0` means
/// identical direction, `1` means orthogonal vectors (no directional similarity), and
/// `2` means opposite direction.
///
/// - Parameters:
///   - a: `speakerId` of the first centroid.
///   - b: `speakerId` of the second centroid.
/// - Returns: `nil` if either `speakerId` is absent from
///   ``speakerCentroidEmbeddings``, the centroids have different dimensions, or either
///   vector is empty. Zero-magnitude centroids (unreachable in real diarization runs)
///   yield `MathOps.cosineDistance`'s sentinel of `1.0`.
public func centroidCosineDistance(between a: Int, _ b: Int) -> Float? {
    // Equal counts plus a non-empty `lhs` imply `rhs` is non-empty too.
    guard let lhs = speakerCentroidEmbeddings[a],
          let rhs = speakerCentroidEmbeddings[b],
          lhs.count == rhs.count, !lhs.isEmpty else { return nil }
    return MathOps.cosineDistance(lhs, rhs)
}

/// Nearest centroid in this result to an external speaker embedding.
///
/// This is a pure nearest-neighbour lookup over ``speakerCentroidEmbeddings``. It does not
/// apply a same-speaker threshold; callers should interpret the returned distance according
/// to their own calibration.
///
/// Centroids whose dimension differs from `embedding` are ignored. When two centroids are
/// at exactly the same distance, the one with the lower `speakerId` is returned so the
/// result does not depend on dictionary iteration order.
///
/// - Parameter embedding: The external embedding to compare against the stored centroids.
/// - Returns: The nearest compatible centroid, or `nil` when `embedding` is empty, no
///   centroid exists, or all stored centroids have different dimensions.
public func nearestSpeakerCentroid(to embedding: [Float]) -> (speakerId: Int, distance: Float)? {
    guard !embedding.isEmpty else { return nil }

    var nearest: (speakerId: Int, distance: Float)?
    for (speakerId, centroid) in speakerCentroidEmbeddings where centroid.count == embedding.count {
        let distance = MathOps.cosineDistance(embedding, centroid)

        // The first compatible centroid always becomes the initial candidate.
        guard let best = nearest else {
            nearest = (speakerId, distance)
            continue
        }

        // Dictionary order is unspecified, so exact ties are broken on the lower speakerId
        // to keep the answer reproducible across runs.
        let isCloser = distance < best.distance
        let winsTie = distance == best.distance && speakerId < best.speakerId
        if isCloser || winsTie {
            nearest = (speakerId, distance)
        }
    }

    return nearest
}

// MARK: - Speaker Info Matching

public func addSpeakerInfo(to transcription: [TranscriptionResult], strategy: SpeakerInfoStrategy = SpeakerInfoStrategy.subsegment) -> [[SpeakerSegment]] {
Expand Down
8 changes: 6 additions & 2 deletions Sources/SpeakerKit/Pyannote/PyannoteDiarizer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ actor PyannoteDiarizerActor {
progressObj.completedUnitCount = 80
progressCallback?(progressObj)
var diarizationResult = postProcess(speakerEmbeddings: clusteringResult.speakerEmbeddings,
speakerCentroids: clusteringResult.speakerCentroids,
originalLength: audioLength,
useExclusiveReconciliation: resolvedOptions.useExclusiveReconciliation)
timings.numberOfSpeakers = diarizationResult.speakerCount
Expand All @@ -268,7 +269,10 @@ actor PyannoteDiarizerActor {
return diarizationResult
}

private func postProcess(speakerEmbeddings: [SpeakerEmbedding], originalLength: Int, useExclusiveReconciliation: Bool) -> DiarizationResult {
private func postProcess(speakerEmbeddings: [SpeakerEmbedding],
speakerCentroids: [Int: [Float]],
originalLength: Int,
useExclusiveReconciliation: Bool) -> DiarizationResult {
let startTime = CFAbsoluteTimeGetCurrent()
defer {
let totalTime = (CFAbsoluteTimeGetCurrent() - startTime) * 1_000
Expand Down Expand Up @@ -360,7 +364,7 @@ actor PyannoteDiarizerActor {
}
}

return DiarizationResult(binaryMatrix: binaryDiarization, diarizationFrameRate: diarizationFrameRate)
return DiarizationResult(binaryMatrix: binaryDiarization, diarizationFrameRate: diarizationFrameRate, speakerCentroidEmbeddings: speakerCentroids)
}

func diarize(audioArray: [Float], options: (any DiarizationOptions)?, progressCallback: (@Sendable (Progress) -> Void)?) async throws -> DiarizationResult {
Expand Down
5 changes: 4 additions & 1 deletion Sources/SpeakerKit/Pyannote/SpeakerClustering.swift
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,14 @@ struct VBxClusteringConfig: Sendable {
/// Output of a clustering pass over speaker embeddings.
struct ClusteringResult {
    /// Cluster id per embedding, in the same order as `speakerEmbeddings`.
    let clusterIndices: [Int]
    /// The embeddings that were clustered.
    let speakerEmbeddings: [SpeakerEmbedding]
    /// Centroid per cluster id. Empty when the clusterer does not provide centroids.
    let speakerCentroids: [Int: [Float]]

    /// - Parameters:
    ///   - clusterIndices: Cluster id per embedding.
    ///   - speakerEmbeddings: The clustered embeddings.
    ///   - speakerCentroids: Centroids keyed by cluster id. Defaults to `[:]` so call sites
    ///     that predate this parameter keep compiling.
    init(clusterIndices: [Int],
         speakerEmbeddings: [SpeakerEmbedding],
         speakerCentroids: [Int: [Float]] = [:]) {
        self.clusterIndices = clusterIndices
        self.speakerEmbeddings = speakerEmbeddings
        self.speakerCentroids = speakerCentroids
    }
}

Expand Down
54 changes: 41 additions & 13 deletions Sources/SpeakerKit/Pyannote/VBxClustering.swift
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,24 @@ actor VBxClustering: Clusterer {

_speakerEmbeddings.sort { ($0.windowIndex, $0.speakerIndex) < ($1.windowIndex, $1.speakerIndex) }

let (clusters, _) = cluster(embeddings: _speakerEmbeddings, config: config)
let (clusters, _, centroids) = cluster(embeddings: _speakerEmbeddings, config: config)

for (clusterIndex, clusterId) in clusters.enumerated() {
_speakerEmbeddings[clusterIndex].clusterId = clusterId
}

// Key centroids by the final clusterId so each entry matches its speakerId membership.
let distinctClusterIds = Set(clusters.filter { $0 >= 0 })
var centroidMap: [Int: [Float]] = [:]
centroidMap.reserveCapacity(distinctClusterIds.count)
for clusterId in distinctClusterIds where clusterId < centroids.count {
centroidMap[clusterId] = centroids[clusterId]
}

return ClusteringResult(
clusterIndices: clusters,
speakerEmbeddings: _speakerEmbeddings
speakerEmbeddings: _speakerEmbeddings,
speakerCentroids: centroidMap
)
}

Expand All @@ -45,9 +54,10 @@ actor VBxClustering: Clusterer {
func cluster(
embeddings: [SpeakerEmbedding],
config: VBxClusteringConfig
) -> (clusters: [Int], linkageMatrix: [[Float]]) {
) -> (clusters: [Int], linkageMatrix: [[Float]], centroids: [[Float]]) {
let trainableEmbeddings = embeddings.filter { $0.nonOverlappedFrameRatio > config.minActiveRatio }
let embeddingsFloats = trainableEmbeddings.map { $0.embedding }
let allEmbeddingsFloats = embeddings.map { $0.embedding }

let pldaEmbeddingsFloats = trainableEmbeddings.map { $0.pldaEmbedding ?? [] }

Expand Down Expand Up @@ -88,6 +98,7 @@ actor VBxClustering: Clusterer {

let clusterAssignments = speakerWeights.isEmpty ? clusters : MathOps.argmax(speakerWeights, axis: 0)
var centroids = calculateCentroids(speakerWeights: speakerWeights, embeddings: embeddingsFloats)
Comment thread
leecrossley marked this conversation as resolved.
// These centroids seed cluster reassignment; returned centroids are recomputed below.

let autoSpeakerCount = centroids.count
Logging.debug("VBx clustering completed with \(autoSpeakerCount) speakers")
Expand All @@ -97,23 +108,29 @@ actor VBxClustering: Clusterer {
if let requestedSpeakers = config.numSpeakers, autoSpeakerCount != requestedSpeakers {
Logging.debug("K-Means correction: VBx gave \(autoSpeakerCount) speakers, requested \(requestedSpeakers)")
let kAssignments = ClusterAlgorithms.kMeans(embeddings: embeddingsNormalized, clusterCount: requestedSpeakers)
centroids = centroidsFromAssignments(assignments: kAssignments, embeddings: embeddingsFloats, k: requestedSpeakers)
centroids = centroidsFromAssignments(
assignments: kAssignments,
embeddings: embeddingsFloats,
clusterCount: requestedSpeakers
)
}

if !centroids.isEmpty {
let allEmbeddingsFloats = embeddings.map { $0.embedding }
clusters = clusterReassignment(embeddings: allEmbeddingsFloats, centroids: centroids)
Logging.debug("Cluster reassignment completed")
} else {
// clusterAssignments covers only trainableEmbeddings (T ≤ N). Derive centroids
// from those AHC assignments so clusterReassignment can cover all N embeddings to match path above.
let numClusters = (clusterAssignments.max() ?? -1) + 1
let fallbackCentroids = numClusters > 0
? centroidsFromAssignments(assignments: clusterAssignments, embeddings: embeddingsFloats, k: numClusters)
? centroidsFromAssignments(
assignments: clusterAssignments,
embeddings: embeddingsFloats,
clusterCount: numClusters
)
: []

if !fallbackCentroids.isEmpty {
let allEmbeddingsFloats = embeddings.map { $0.embedding }
clusters = clusterReassignment(embeddings: allEmbeddingsFloats, centroids: fallbackCentroids)
Logging.debug("Cluster reassignment from AHC fallback completed")
} else {
Expand All @@ -122,7 +139,18 @@ actor VBxClustering: Clusterer {
}
}

return (clusters, linkageMatrix)
// Returned centroids are the arithmetic mean of embeddings under the final assignment,
// uniform across all paths (VBx weighted, kMeans correction, AHC fallback).
let numFinalClusters = (clusters.max() ?? -1) + 1
let finalCentroids: [[Float]] = numFinalClusters > 0
? centroidsFromAssignments(
assignments: clusters,
embeddings: allEmbeddingsFloats,
clusterCount: numFinalClusters
)
: []

return (clusters, linkageMatrix, finalCentroids)
}

// MARK: - Internal Methods
Expand Down Expand Up @@ -176,24 +204,24 @@ actor VBxClustering: Clusterer {
return clusterIndices
}

/// Arithmetic-mean centroid of the embeddings assigned to each cluster.
///
/// `assignments[i]` is the cluster index of `embeddings[i]`; the two arrays are walked in
/// lockstep and any surplus entries in the longer one are ignored.
///
/// - Parameters:
///   - assignments: Cluster index per embedding. Entries outside `0..<clusterCount`
///     (for example a negative "unassigned" marker) are skipped instead of trapping.
///   - embeddings: Embedding vectors. The dimension is taken from the first vector; only
///     the first `dim` components of each row are used, and rows shorter than `dim` are skipped.
///   - clusterCount: Number of centroids to produce.
/// - Returns: `clusterCount` centroids of dimension `dim`. A cluster with no members yields
///   an all-zero vector. Returns `[]` when `clusterCount` is not positive, `embeddings` is
///   empty, or the first embedding is empty.
func centroidsFromAssignments(assignments: [Int], embeddings: [[Float]], clusterCount: Int) -> [[Float]] {
    guard clusterCount > 0, let dim = embeddings.first?.count, dim > 0 else { return [] }

    var sums = Array(repeating: Array(repeating: Float(0), count: dim), count: clusterCount)
    var counts = Array(repeating: 0, count: clusterCount)

    for (assignment, embedding) in zip(assignments, embeddings) {
        // Out-of-range ids and short rows would index past the buffers, so skip them.
        guard sums.indices.contains(assignment), embedding.count >= dim else { continue }
        counts[assignment] += 1
        for d in 0..<dim { sums[assignment][d] += embedding[d] }
    }

    return zip(sums, counts).map { sum, count -> [Float] in
        // Empty clusters keep their zero vector rather than dividing by zero.
        guard count > 0 else { return sum }
        return sum.map { $0 / Float(count) }
    }
}

private func calculateCentroids(speakerWeights: [[Float]], embeddings: [[Float]]) -> [[Float]] {
func calculateCentroids(speakerWeights: [[Float]], embeddings: [[Float]]) -> [[Float]] {
guard !speakerWeights.isEmpty, !embeddings.isEmpty, !embeddings[0].isEmpty else {
return []
}
Expand Down
Loading