Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion Sources/SpeakerKit/DiarizationResult.swift
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,17 @@ public struct DiarizationResult: Sendable {
public let frameRate: Float
public private(set) var segments: [SpeakerSegment]
public var timings: (any DiarizationTimings)?
public var speakerCentroidEmbeddings: [Int: [Float]]
Comment thread
leecrossley marked this conversation as resolved.
Outdated

/// Pyannote init: builds segments from binary speaker activity matrix
init(binaryMatrix: [[Int]], diarizationFrameRate: Float) {
init(binaryMatrix: [[Int]], diarizationFrameRate: Float, speakerCentroidEmbeddings: [Int: [Float]] = [:]) {
self.binaryMatrix = binaryMatrix
self.frameRate = diarizationFrameRate
self.speakerCount = binaryMatrix.count
self.totalFrames = speakerCount > 0 ? binaryMatrix[0].count : 0
self.segments = []
self.timings = nil
self.speakerCentroidEmbeddings = speakerCentroidEmbeddings

self.updateSegments(minActiveOffset: 0.0)
}
Expand All @@ -51,6 +53,7 @@ public struct DiarizationResult: Sendable {
self.frameRate = frameRate
self.segments = segments
self.timings = timings
self.speakerCentroidEmbeddings = [:]
}

public mutating func updateSegments(minActiveOffset: Float) {
Expand Down
23 changes: 22 additions & 1 deletion Sources/SpeakerKit/Pyannote/PyannoteDiarizer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,27 @@ actor PyannoteDiarizerActor {
return DiarizationResult(binaryMatrix: [], diarizationFrameRate: diarizationFrameRate)
}

var centroidSums: [Int: [Float]] = [:]
var centroidCounts: [Int: Int] = [:]
for emb in speakerEmbeddings {
Comment thread
leecrossley marked this conversation as resolved.
Outdated
guard emb.clusterId >= 0, !emb.embedding.isEmpty else { continue }
if var existing = centroidSums[emb.clusterId] {
for i in 0..<emb.embedding.count {
existing[i] += emb.embedding[i]
}
centroidSums[emb.clusterId] = existing
centroidCounts[emb.clusterId] = centroidCounts[emb.clusterId]! + 1
} else {
centroidSums[emb.clusterId] = emb.embedding
centroidCounts[emb.clusterId] = 1
}
}
var centroidEmbeddings: [Int: [Float]] = [:]
for (clusterId, sum) in centroidSums {
let count = Float(centroidCounts[clusterId]!)
centroidEmbeddings[clusterId] = sum.map { $0 / count }
}

let speakerCount = (speakerEmbeddings.map { $0.clusterId }.max() ?? 0) + 1
let chunkLength = SpeakerSegmenterModel.chunkLengthInSeconds
let maxChunks = config.segmenterModel.maxChunks(for: originalLength)
Expand Down Expand Up @@ -360,7 +381,7 @@ actor PyannoteDiarizerActor {
}
}

return DiarizationResult(binaryMatrix: binaryDiarization, diarizationFrameRate: diarizationFrameRate)
return DiarizationResult(binaryMatrix: binaryDiarization, diarizationFrameRate: diarizationFrameRate, speakerCentroidEmbeddings: centroidEmbeddings)
}

func diarize(audioArray: [Float], options: (any DiarizationOptions)?, progressCallback: (@Sendable (Progress) -> Void)?) async throws -> DiarizationResult {
Expand Down
4 changes: 2 additions & 2 deletions Sources/SpeakerKit/Pyannote/SpeakerEmbedderModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import CoreML
import ArgmaxCore

struct SpeakerEmbedding {
let embedding: [Float]
public struct SpeakerEmbedding {
Comment thread
leecrossley marked this conversation as resolved.
Outdated
public let embedding: [Float]
let pldaEmbedding: [Float]?
let activeFrames: [Float]
let windowIndex: Int
Expand Down