Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Sources/WhisperKit/Core/Configurations.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ open class WhisperKitConfig {
public var modelRepo: String?
/// Token for downloading models from repo (if required)
public var modelToken: String?
/// Endpoint for downloading models and tokenizers
public var endpoint: String

/// Folder to store models
public var modelFolder: String?
Expand Down Expand Up @@ -50,6 +52,7 @@ open class WhisperKitConfig {
downloadBase: URL? = nil,
modelRepo: String? = nil,
modelToken: String? = nil,
endpoint: String = Constants.defaultRemoteEndpoint,
modelFolder: String? = nil,
tokenizerFolder: URL? = nil,
computeOptions: ModelComputeOptions? = nil,
Expand All @@ -72,6 +75,7 @@ open class WhisperKitConfig {
self.downloadBase = downloadBase
self.modelRepo = modelRepo
self.modelToken = modelToken
self.endpoint = endpoint
self.modelFolder = modelFolder
self.tokenizerFolder = tokenizerFolder
self.computeOptions = computeOptions
Expand Down
26 changes: 22 additions & 4 deletions Sources/WhisperKit/Core/WhisperKit.swift
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ open class WhisperKit {
public var modelFolder: URL?
public var tokenizerFolder: URL?
public private(set) var useBackgroundDownloadSession: Bool
public private(set) var endpoint: String

/// Callbacks
public var segmentDiscoveryCallback: SegmentDiscoveryCallback?
Expand All @@ -68,6 +69,7 @@ open class WhisperKit {
voiceActivityDetector = config.voiceActivityDetector
tokenizerFolder = config.tokenizerFolder ?? config.downloadBase
useBackgroundDownloadSession = config.useBackgroundDownloadSession
endpoint = config.endpoint
currentTimings = TranscriptionTimings()
Logging.shared.logLevel = config.verbose ? config.logLevel : .none

Expand All @@ -77,7 +79,8 @@ open class WhisperKit {
modelRepo: config.modelRepo,
modelToken: config.modelToken,
modelFolder: config.modelFolder,
download: config.download
download: config.download,
endpoint: config.endpoint
)

if let prewarm = config.prewarm, prewarm {
Expand Down Expand Up @@ -110,12 +113,14 @@ open class WhisperKit {
prewarm: Bool? = nil,
load: Bool? = nil,
download: Bool = true,
useBackgroundDownloadSession: Bool = false
useBackgroundDownloadSession: Bool = false,
endpoint: String = Constants.defaultRemoteEndpoint
) async throws {
let config = WhisperKitConfig(
model: model,
downloadBase: downloadBase,
modelRepo: modelRepo,
endpoint: endpoint,
modelFolder: modelFolder,
tokenizerFolder: tokenizerFolder,
computeOptions: computeOptions,
Expand Down Expand Up @@ -296,6 +301,18 @@ open class WhisperKit {
}
}

/// Fetches the tokenizer files for a given model variant from a remote endpoint.
/// - Parameters:
///   - variant: The model variant whose tokenizer repository should be downloaded.
///   - downloadBase: Optional base folder for downloaded files; `nil` uses the Hub default.
///   - endpoint: Remote endpoint to download from; defaults to `Constants.defaultRemoteEndpoint`.
/// - Returns: The local folder URL containing the downloaded tokenizer files.
public static func downloadTokenizer(
    for variant: ModelVariant,
    downloadBase: URL? = nil,
    endpoint: String = Constants.defaultRemoteEndpoint
) async throws -> URL {
    // Resolve the Hub repo name for this variant, then snapshot every file in it.
    let repoName = ModelUtilities.tokenizerNameForVariant(variant)
    let hub = HubApi(downloadBase: downloadBase, endpoint: endpoint)
    return try await hub.snapshot(from: Hub.Repo(id: repoName, type: .models), matching: ["*"])
}

/// Sets up the model folder either from a local path or by downloading from a repository.
open func setupModels(
model: String?,
Expand All @@ -305,7 +322,7 @@ open class WhisperKit {
modelFolder: String?,
download: Bool,
remoteConfigName: String = Constants.defaultRemoteConfigName,
endpoint: String = Constants.defaultRemoteEndpoint
endpoint: String
) async throws {
// If a local model folder is provided, use it; otherwise, download the model
if let folder = modelFolder {
Expand Down Expand Up @@ -485,7 +502,8 @@ open class WhisperKit {
for: modelVariant,
tokenizerFolder: tokenizerFolder,
additionalSearchPaths: additionalSearchPaths,
useBackgroundSession: useBackgroundDownloadSession
useBackgroundSession: useBackgroundDownloadSession,
endpoint: self.endpoint
)
currentTimings.tokenizerLoadTime = CFAbsoluteTimeGetCurrent() - tokenizerLoadStart

Expand Down
8 changes: 5 additions & 3 deletions Sources/WhisperKit/Utilities/ModelUtilities.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ public struct ModelUtilities {
for pretrained: ModelVariant,
tokenizerFolder: URL? = nil,
additionalSearchPaths: [URL] = [],
useBackgroundSession: Bool = false
useBackgroundSession: Bool = false,
endpoint: String
Comment thread
lemo366 marked this conversation as resolved.
Outdated
) async throws -> WhisperTokenizer {
let tokenizerName = tokenizerNameForVariant(pretrained)
let hubApi = HubApi(downloadBase: tokenizerFolder, useBackgroundSession: useBackgroundSession)
Expand Down Expand Up @@ -267,9 +268,10 @@ public struct ModelUtilities {
public func loadTokenizer(
for pretrained: ModelVariant,
tokenizerFolder: URL? = nil,
useBackgroundSession: Bool = false
useBackgroundSession: Bool = false,
endpoint: String
) async throws -> WhisperTokenizer {
return try await ModelUtilities.loadTokenizer(for: pretrained, tokenizerFolder: tokenizerFolder, useBackgroundSession: useBackgroundSession)
return try await ModelUtilities.loadTokenizer(for: pretrained, tokenizerFolder: tokenizerFolder, useBackgroundSession: useBackgroundSession, endpoint: endpoint)
}

@available(*, deprecated, message: "Subject to removal in a future version. Use ModelUtilities.modelSupport(for:from:) -> ModelSupport instead.")
Expand Down
54 changes: 31 additions & 23 deletions Tests/WhisperKitTests/UnitTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import Tokenizers
import XCTest

final class UnitTests: XCTestCase {
let endpoint = Constants.defaultRemoteEndpoint

override func setUp() async throws {
Logging.shared.logLevel = .debug
}
Expand Down Expand Up @@ -718,7 +720,7 @@ final class UnitTests: XCTestCase {
"Failed to load the model"
)
textDecoder.tokenizer = try await XCTUnwrapAsync(
await ModelUtilities.loadTokenizer(for: .tiny),
await ModelUtilities.loadTokenizer(for: .tiny, endpoint: endpoint),
"Failed to load the tokenizer"
)

Expand Down Expand Up @@ -752,7 +754,7 @@ final class UnitTests: XCTestCase {
let textDecoder = TextDecoder()
let modelPath = try await URL(filePath: tinyModelPath()).appending(path: "TextDecoder.mlmodelc")
try await textDecoder.loadModel(at: modelPath, computeUnits: ModelComputeOptions().textDecoderCompute)
textDecoder.tokenizer = try await ModelUtilities.loadTokenizer(for: .tiny)
textDecoder.tokenizer = try await ModelUtilities.loadTokenizer(for: .tiny, endpoint: endpoint)

let tokenSampler = GreedyTokenSampler(temperature: 0, eotToken: textDecoder.tokenizer!.specialTokens.endToken, decodingOptions: decodingOptions)

Expand All @@ -776,7 +778,7 @@ final class UnitTests: XCTestCase {
let textDecoder = TextDecoder()
let modelPath = try await URL(filePath: tinyModelPath()).appending(path: "TextDecoder.mlmodelc")
try await textDecoder.loadModel(at: modelPath, computeUnits: ModelComputeOptions().textDecoderCompute)
textDecoder.tokenizer = try await ModelUtilities.loadTokenizer(for: .tiny)
textDecoder.tokenizer = try await ModelUtilities.loadTokenizer(for: .tiny, endpoint: endpoint)

let tokenSampler = GreedyTokenSampler(temperature: 0, eotToken: textDecoder.tokenizer!.specialTokens.endToken, decodingOptions: decodingOptions)

Expand Down Expand Up @@ -950,15 +952,17 @@ final class UnitTests: XCTestCase {
let repoTokenizer = try await ModelUtilities.loadTokenizer(
for: .tiny,
tokenizerFolder: tempDir,
useBackgroundSession: false
useBackgroundSession: false,
endpoint: endpoint
) as! WhisperTokenizerWrapper
XCTAssertEqual(repoTokenizer.tokenizerFolder?.path, repoStyleDir.path, "tokenizerFolder should exactly match repoStyleDir path")

// Check that direct loading from the top level of the directory is prioritized when no repo style exists (tokenizer already exists here)
let tokenizer = try await ModelUtilities.loadTokenizer(
for: .tiny,
tokenizerFolder: repoStyleDir,
useBackgroundSession: false
useBackgroundSession: false,
endpoint: endpoint
) as! WhisperTokenizerWrapper
XCTAssertEqual(tokenizer.tokenizerFolder?.path, repoStyleDir.path, "tokenizerFolder should exactly match repoStyleDir path")

Expand All @@ -974,7 +978,8 @@ final class UnitTests: XCTestCase {
let tokenizerAtTopLevel = try await ModelUtilities.loadTokenizer(
for: .tiny,
tokenizerFolder: tempDir,
useBackgroundSession: false
useBackgroundSession: false,
endpoint: endpoint
) as! WhisperTokenizerWrapper
XCTAssertEqual(tokenizerAtTopLevel.tokenizerFolder?.path, repoStyleDir.path, "tokenizerFolder should exactly match repoStyleDir path")
}
Expand All @@ -986,7 +991,8 @@ final class UnitTests: XCTestCase {
let tokenizer = try await ModelUtilities.loadTokenizer(
for: variant,
tokenizerFolder: nil,
useBackgroundSession: false
useBackgroundSession: false,
endpoint: endpoint
) as! WhisperTokenizerWrapper
XCTAssertNotNil(tokenizer, "Should load tokenizer for variant \(variant)")
XCTAssertTrue(tokenizer.tokenizerFolder!.path.contains(expectedName), "Tokenizer folder should contain \(expectedName)")
Expand Down Expand Up @@ -1015,7 +1021,8 @@ final class UnitTests: XCTestCase {

// Load a tokenizer that should match
let tokenizer = try await ModelUtilities.loadTokenizer(
for: .tiny
for: .tiny,
endpoint: endpoint
) as! WhisperTokenizerWrapper

// Verify tokenizer location
Expand Down Expand Up @@ -1066,7 +1073,8 @@ final class UnitTests: XCTestCase {
let tokenizer = try await ModelUtilities.loadTokenizer(
for: .tiny,
tokenizerFolder: corruptedTokenizerDir,
useBackgroundSession: false
useBackgroundSession: false,
endpoint: endpoint
)
XCTAssertNotNil(tokenizer, "Should successfully fall back to Hub when local loading fails")
} catch {
Expand Down Expand Up @@ -1227,11 +1235,11 @@ final class UnitTests: XCTestCase {
let tokenText = "<|startoftranscript|>"

let textDecoder = TextDecoder()
textDecoder.tokenizer = try await ModelUtilities.loadTokenizer(for: .tiny)
textDecoder.tokenizer = try await ModelUtilities.loadTokenizer(for: .tiny, endpoint: endpoint)
let encodedToken = try XCTUnwrap(textDecoder.tokenizer?.convertTokenToId(tokenText))
let decodedToken = try XCTUnwrap(textDecoder.tokenizer?.decode(tokens: [encodedToken]))

textDecoder.tokenizer = try await ModelUtilities.loadTokenizer(for: .largev3)
textDecoder.tokenizer = try await ModelUtilities.loadTokenizer(for: .largev3, endpoint: endpoint)
let encodedTokenLarge = try XCTUnwrap(textDecoder.tokenizer?.convertTokenToId(tokenText))
let decodedTokenLarge = try XCTUnwrap(textDecoder.tokenizer?.decode(tokens: [encodedTokenLarge]))

Expand All @@ -1246,11 +1254,11 @@ final class UnitTests: XCTestCase {
// This token index changes with v3
let tokenTextShifted = "<|0.00|>"

textDecoder.tokenizer = try await ModelUtilities.loadTokenizer(for: .tiny)
textDecoder.tokenizer = try await ModelUtilities.loadTokenizer(for: .tiny, endpoint: endpoint)
let encodedTokenShifted = try XCTUnwrap(textDecoder.tokenizer?.convertTokenToId(tokenTextShifted))
let decodedTokenShifted = try XCTUnwrap(textDecoder.tokenizer?.decode(tokens: [encodedTokenShifted]))

textDecoder.tokenizer = try await ModelUtilities.loadTokenizer(for: .largev3)
textDecoder.tokenizer = try await ModelUtilities.loadTokenizer(for: .largev3, endpoint: endpoint)
let encodedTokenLargeShifted = try XCTUnwrap(textDecoder.tokenizer?.convertTokenToId(tokenTextShifted))
let decodedTokenLargeShifted = try XCTUnwrap(textDecoder.tokenizer?.decode(tokens: [encodedTokenLargeShifted]))

Expand All @@ -1265,7 +1273,7 @@ final class UnitTests: XCTestCase {
func testTokenizerOutput() async throws {
let tokenInputs = [50364, 400, 370, 452, 7177, 6280, 1029, 406, 437, 428, 1941, 393, 360, 337, 291, 1029, 437, 291, 393, 360, 337, 428, 1941, 13, 50889]

let tokenizer = try await ModelUtilities.loadTokenizer(for: .largev3)
let tokenizer = try await ModelUtilities.loadTokenizer(for: .largev3, endpoint: endpoint)
let decodedText = tokenizer.decode(tokens: tokenInputs)

XCTAssertNotNil(decodedText)
Expand Down Expand Up @@ -1300,7 +1308,7 @@ final class UnitTests: XCTestCase {
}

func testSplitToWordTokens() async throws {
let tokenizer = try await ModelUtilities.loadTokenizer(for: .tiny)
let tokenizer = try await ModelUtilities.loadTokenizer(for: .tiny, endpoint: endpoint)

// Hello, world! This is a test, isn't it?
let tokenIds = [50364, 2425, 11, 1002, 0, 50414, 50414, 639, 307, 257, 220, 31636, 11, 1943, 380, 309, 30, 50257]
Expand All @@ -1317,7 +1325,7 @@ final class UnitTests: XCTestCase {
}

func testSplitToWordTokensSpanish() async throws {
let tokenizer = try await ModelUtilities.loadTokenizer(for: .tiny)
let tokenizer = try await ModelUtilities.loadTokenizer(for: .tiny, endpoint: endpoint)

// ¡Hola Mundo! Esta es una prueba, ¿no?
let tokenIds = [50363, 24364, 48529, 376, 6043, 0, 20547, 785, 2002, 48241, 11, 3841, 1771, 30, 50257]
Expand All @@ -1334,7 +1342,7 @@ final class UnitTests: XCTestCase {
}

func testSplitToWordTokensJapanese() async throws {
let tokenizer = try await ModelUtilities.loadTokenizer(for: .tiny)
let tokenizer = try await ModelUtilities.loadTokenizer(for: .tiny, endpoint: endpoint)

// こんにちは、世界!これはテストですよね?
let tokenIds = [50364, 38088, 1231, 24486, 171, 120, 223, 25212, 22985, 40498, 4767, 30346, 171, 120, 253, 50257]
Expand Down Expand Up @@ -2416,7 +2424,7 @@ final class UnitTests: XCTestCase {
}
}

let tokenizer = try await ModelUtilities.loadTokenizer(for: .tiny)
let tokenizer = try await ModelUtilities.loadTokenizer(for: .tiny, endpoint: endpoint)

let wordTokenIds = [400, 370, 452, 7177, 6280, 11, 1029, 406, 437, 428, 1941, 393, 360, 337, 291, 11, 1029, 437, 291, 393, 360, 337, 428, 1941, 13]
let result = try SegmentSeeker().findAlignment(
Expand Down Expand Up @@ -2780,7 +2788,7 @@ final class UnitTests: XCTestCase {
lastSpeechTimestamp: 0,
constrainedMedianDuration: constrainedMedianDuration,
maxDuration: maxDuration,
tokenizer: try! await ModelUtilities.loadTokenizer(for: .tiny)
tokenizer: try! await ModelUtilities.loadTokenizer(for: .tiny, endpoint: endpoint)
)

let updatedWords = updatedSegments.compactMap { $0.words }.flatMap { $0 }
Expand Down Expand Up @@ -2877,7 +2885,7 @@ final class UnitTests: XCTestCase {
lastSpeechTimestamp: 0,
constrainedMedianDuration: constrainedMedianDuration,
maxDuration: maxDuration,
tokenizer: try! await ModelUtilities.loadTokenizer(for: .tiny)
tokenizer: try! await ModelUtilities.loadTokenizer(for: .tiny, endpoint: endpoint)
)

let updatedWords = updatedSegments.first!.words!
Expand Down Expand Up @@ -3102,7 +3110,7 @@ final class UnitTests: XCTestCase {
let customFilter = PlusOneFilter()
let decoder = TextDecoder()
decoder.logitsFilters = [customFilter]
let tokenizer = try await ModelUtilities.loadTokenizer(for: .tiny)
let tokenizer = try await ModelUtilities.loadTokenizer(for: .tiny, endpoint: endpoint)

let options = DecodingOptions(
withoutTimestamps: true,
Expand All @@ -3124,7 +3132,7 @@ final class UnitTests: XCTestCase {

func testCreateLogitsFiltersWithSuppressBlank() async throws {
let decoder = TextDecoder()
let tokenizer = try await ModelUtilities.loadTokenizer(for: .tiny)
let tokenizer = try await ModelUtilities.loadTokenizer(for: .tiny, endpoint: endpoint)

let options = DecodingOptions(
withoutTimestamps: true,
Expand Down Expand Up @@ -3171,7 +3179,7 @@ final class UnitTests: XCTestCase {

func testCreateLogitsFiltersWithTimestamps() async throws {
let decoder = TextDecoder()
let tokenizer = try await ModelUtilities.loadTokenizer(for: .tiny)
let tokenizer = try await ModelUtilities.loadTokenizer(for: .tiny, endpoint: endpoint)

let options = DecodingOptions(
withoutTimestamps: false,
Expand Down