Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions Sources/TTSKit/Utilities/Sampling.swift
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,14 @@ public class GreedyTokenSampler: TokenSampling, @unchecked Sendable {
let probsArray = await topKProbs.toFloatArray()
let idxArray = await topKIndices.toIntArray()
let probSum = probsArray.reduce(0, +)
// Numerical underflow at low temperature combined with a small topK
// (e.g. temperature 0.10 with topK 15 over long-form generation) can
// round every top-k probability to zero. Float.random(in: 0..<0) traps
// because the range is empty, so fall back to greedy instead
// (the highest-probability token, which topK returns first).
// ref: https://github.com/argmaxinc/argmax-oss-swift/issues/450
guard probSum > 0 else {
return idxArray.first.map(Int32.init) ?? Int32(vocabSize - 1)
}
let randomValue = Float.random(in: 0..<probSum, using: &rng)
var cumulativeSum: Float = 0
for (i, probability) in probsArray.enumerated() {
Expand All @@ -245,6 +253,10 @@ public class GreedyTokenSampler: TokenSampling, @unchecked Sendable {
return idxArray.last.map(Int32.init) ?? Int32(vocabSize - 1)
} else {
let probsArray = await probs.toFloatArray()
let probSum = probsArray.reduce(0, +)
guard probSum > 0 else {
return Int32(vocabSize - 1)
}
let randomValue = Float.random(in: 0..<1, using: &rng)
var cumulativeSum: Float = 0
for (i, probability) in probsArray.enumerated() {
Expand Down