Skip to content

Commit 3334a9c

Browse files
authored
Add support for word-level audio transcription timestamp granularity (#733)
* Add support for audio transcription timestamp_granularities word
* Fixup multiple timestamp granularities
1 parent c9953a7 commit 3334a9c

File tree

3 files changed

+35
-6
lines changed

3 files changed

+35
-6
lines changed

audio.go

+26-5
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,14 @@ const (
2727
AudioResponseFormatVTT AudioResponseFormat = "vtt"
2828
)
2929

30+
// TranscriptionTimestampGranularity is the string type of the entries in
// AudioRequest.TimestampGranularities. Each entry is written to the multipart
// form as a "timestamp_granularities[]" field.
type TranscriptionTimestampGranularity string
31+
32+
// Defined TranscriptionTimestampGranularity values.
const (
33+
// TranscriptionTimestampGranularityWord has the wire value "word".
TranscriptionTimestampGranularityWord TranscriptionTimestampGranularity = "word"
34+
// TranscriptionTimestampGranularitySegment has the wire value "segment".
TranscriptionTimestampGranularitySegment TranscriptionTimestampGranularity = "segment"
35+
)
36+
3037
// AudioRequest represents a request structure for audio API.
31-
// ResponseFormat is not supported for now. We only return JSON text, which may be sufficient.
3238
type AudioRequest struct {
3339
Model string
3440

@@ -38,10 +44,11 @@ type AudioRequest struct {
3844
// Reader is an optional io.Reader when you do not want to use an existing file.
3945
Reader io.Reader
4046

41-
Prompt string // For translation, it should be in English
42-
Temperature float32
43-
Language string // For translation, just do not use it. It seems "en" works, not confirmed...
44-
Format AudioResponseFormat
47+
Prompt string
48+
Temperature float32
49+
Language string // Only for transcription.
50+
Format AudioResponseFormat
51+
TimestampGranularities []TranscriptionTimestampGranularity // Only for transcription.
4552
}
4653

4754
// AudioResponse represents a response structure for audio API.
@@ -62,6 +69,11 @@ type AudioResponse struct {
6269
NoSpeechProb float64 `json:"no_speech_prob"`
6370
Transient bool `json:"transient"`
6471
} `json:"segments"`
72+
Words []struct {
73+
Word string `json:"word"`
74+
Start float64 `json:"start"`
75+
End float64 `json:"end"`
76+
} `json:"words"`
6577
Text string `json:"text"`
6678

6779
httpHeader
@@ -179,6 +191,15 @@ func audioMultipartForm(request AudioRequest, b utils.FormBuilder) error {
179191
}
180192
}
181193

194+
if len(request.TimestampGranularities) > 0 {
195+
for _, tg := range request.TimestampGranularities {
196+
err = b.WriteField("timestamp_granularities[]", string(tg))
197+
if err != nil {
198+
return fmt.Errorf("writing timestamp_granularities[]: %w", err)
199+
}
200+
}
201+
}
202+
182203
// Close the multipart writer
183204
return b.Close()
184205
}

audio_api_test.go

+4
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,10 @@ func TestAudioWithOptionalArgs(t *testing.T) {
105105
Temperature: 0.5,
106106
Language: "zh",
107107
Format: openai.AudioResponseFormatSRT,
108+
TimestampGranularities: []openai.TranscriptionTimestampGranularity{
109+
openai.TranscriptionTimestampGranularitySegment,
110+
openai.TranscriptionTimestampGranularityWord,
111+
},
108112
}
109113
_, err := tc.createFn(ctx, req)
110114
checks.NoError(t, err, "audio API error")

audio_test.go

+5-1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@ func TestAudioWithFailingFormBuilder(t *testing.T) {
2424
Temperature: 0.5,
2525
Language: "en",
2626
Format: AudioResponseFormatSRT,
27+
TimestampGranularities: []TranscriptionTimestampGranularity{
28+
TranscriptionTimestampGranularitySegment,
29+
TranscriptionTimestampGranularityWord,
30+
},
2731
}
2832

2933
mockFailedErr := fmt.Errorf("mock form builder fail")
@@ -47,7 +51,7 @@ func TestAudioWithFailingFormBuilder(t *testing.T) {
4751
return nil
4852
}
4953

50-
failOn := []string{"model", "prompt", "temperature", "language", "response_format"}
54+
failOn := []string{"model", "prompt", "temperature", "language", "response_format", "timestamp_granularities[]"}
5155
for _, failingField := range failOn {
5256
failForField = failingField
5357
mockFailedErr = fmt.Errorf("mock form builder fail on field %s", failingField)

0 commit comments

Comments
 (0)