Skip to content

Commit 83067fc

Browse files
authored
Merge pull request #20 from WqyJh/feat-transcribe
feat: add tests & examples for transcription
2 parents 3000a4f + e5d3abb commit 83067fc

File tree

10 files changed

+434
-27
lines changed

10 files changed

+434
-27
lines changed

.github/workflows/go-test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ jobs:
1010
runs-on: ubuntu-latest
1111
strategy:
1212
matrix:
13-
go-version: ["1.19", "1.20", "1.22", "1.23"]
13+
go-version: ["1.20", "1.22", "1.23", "1.24"]
1414

1515
steps:
1616
- uses: actions/checkout@v4

api.go

Lines changed: 71 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"context"
66
"encoding/json"
77
"fmt"
8+
"io"
89
"net/http"
910
)
1011

@@ -29,6 +30,63 @@ type CreateSessionResponse struct {
2930
ClientSecret ClientSecret `json:"client_secret"`
3031
}
3132

33+
// CreateTranscriptionSessionRequest is the request for creating a transcription session.
34+
type CreateTranscriptionSessionRequest struct {
35+
// The set of items to include in the transcription.
36+
Include []string `json:"include,omitempty"`
37+
// The format of input audio. Options are "pcm16", "g711_ulaw", or "g711_alaw".
38+
InputAudioFormat AudioFormat `json:"input_audio_format,omitempty"`
39+
// Configuration for input audio noise reduction.
40+
InputAudioNoiseReduction *InputAudioNoiseReduction `json:"input_audio_noise_reduction,omitempty"`
41+
// Configuration for input audio transcription.
42+
InputAudioTranscription *InputAudioTranscription `json:"input_audio_transcription,omitempty"`
43+
44+
// Attention: Keep this field empty! It's shocking that this field is documented but not supported.
45+
// You may get error of "Unknown parameter: 'modalities'." if this field is not empty.
46+
// Issue reported: https://community.openai.com/t/unknown-parameter-modalities-when-creating-transcriptionsessions/1150141/6
47+
// Docs: https://platform.openai.com/docs/api-reference/realtime-sessions/create-transcription#realtime-sessions-create-transcription-modalities
48+
// The set of modalities the model can respond with. To disable audio, set this to ["text"].
49+
Modalities []Modality `json:"modalities,omitempty"`
50+
51+
// Configuration for turn detection.
52+
TurnDetection *ClientTurnDetection `json:"turn_detection,omitempty"`
53+
}
54+
55+
// CreateTranscriptionSessionResponse is the response from creating a transcription session.
56+
type CreateTranscriptionSessionResponse struct {
57+
// The unique ID of the session.
58+
ID string `json:"id"`
59+
// The object type, must be "realtime.transcription_session".
60+
Object string `json:"object"`
61+
// The format of input audio.
62+
InputAudioFormat AudioFormat `json:"input_audio_format,omitempty"`
63+
// Configuration of the transcription model.
64+
InputAudioTranscription *InputAudioTranscription `json:"input_audio_transcription,omitempty"`
65+
// The set of modalities.
66+
Modalities []Modality `json:"modalities,omitempty"`
67+
// Configuration for turn detection.
68+
TurnDetection *ServerTurnDetection `json:"turn_detection,omitempty"`
69+
// Ephemeral key returned by the API.
70+
ClientSecret ClientSecret `json:"client_secret"`
71+
}
72+
73+
type OpenAIError struct {
74+
StatusCode int `json:"-"`
75+
Message string `json:"message"`
76+
Type string `json:"type"`
77+
Param string `json:"param"`
78+
Code any `json:"code"`
79+
}
80+
81+
type ErrorResponse struct { //nolint:errname // this is a http error response
82+
StatusCode int `json:"-"`
83+
OpenAIError `json:"error"`
84+
}
85+
86+
func (e *ErrorResponse) Error() string {
87+
return e.OpenAIError.Message
88+
}
89+
3290
type httpOption struct {
3391
client *http.Client
3492
headers http.Header
@@ -83,12 +141,23 @@ func HTTPDo[Q any, R any](ctx context.Context, url string, req *Q, opts ...HTTPO
83141
}
84142
defer response.Body.Close()
85143

144+
data, err = io.ReadAll(response.Body)
145+
if err != nil {
146+
return nil, fmt.Errorf("failed to read response body: %w", err)
147+
}
148+
86149
if response.StatusCode != http.StatusOK {
87-
return nil, fmt.Errorf("http status code: %d", response.StatusCode)
150+
var errResp ErrorResponse
151+
err = json.Unmarshal(data, &errResp)
152+
if err != nil {
153+
return nil, fmt.Errorf("http status code: %d, error: %s", response.StatusCode, string(data))
154+
}
155+
errResp.StatusCode = response.StatusCode
156+
return nil, &errResp
88157
}
89158

90159
var resp R
91-
err = json.NewDecoder(response.Body).Decode(&resp)
160+
err = json.Unmarshal(data, &resp)
92161
if err != nil {
93162
return nil, fmt.Errorf("failed to decode response: %w", err)
94163
}

api_integration_test.go

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,5 +28,51 @@ func TestCreateSession(t *testing.T) {
2828
require.NoError(t, err)
2929
require.NotEmpty(t, session.ClientSecret.Value)
3030
require.NotZero(t, session.ClientSecret.ExpiresAt)
31+
require.Equal(t, openairt.GPT4oRealtimePreview20241217, session.Model)
32+
require.Equal(t, "You are a friendly assistant.", session.Instructions)
3133
t.Logf("session: %+v", session)
3234
}
35+
36+
func TestCreateTranscriptionSession(t *testing.T) {
37+
key := os.Getenv("OPENAI_API_KEY")
38+
if key == "" {
39+
t.Skip("OPENAI_API_KEY is not set")
40+
}
41+
client := openairt.NewClient(key)
42+
session, err := client.CreateTranscriptionSession(context.Background(), &openairt.CreateTranscriptionSessionRequest{
43+
InputAudioFormat: openairt.AudioFormatPcm16,
44+
InputAudioTranscription: &openairt.InputAudioTranscription{
45+
Model: openairt.GPT4oTranscribe,
46+
Language: "en",
47+
},
48+
InputAudioNoiseReduction: &openairt.InputAudioNoiseReduction{
49+
Type: openairt.NearFieldNoiseReduction,
50+
},
51+
// Attention: Keep this field empty! It's shocking that this field is documented but not supported.
52+
// Modalities: []openairt.Modality{
53+
// openairt.ModalityText,
54+
// },
55+
TurnDetection: &openairt.ClientTurnDetection{
56+
Type: openairt.ClientTurnDetectionTypeServerVad,
57+
TurnDetectionParams: openairt.TurnDetectionParams{
58+
Threshold: 0.6,
59+
PrefixPaddingMs: 300,
60+
SilenceDurationMs: 500,
61+
},
62+
},
63+
Include: []string{},
64+
})
65+
require.NoError(t, err)
66+
require.NotEmpty(t, session.ClientSecret.Value)
67+
require.NotZero(t, session.ClientSecret.ExpiresAt)
68+
require.Equal(t, "realtime.transcription_session", session.Object)
69+
require.Equal(t, openairt.AudioFormatPcm16, session.InputAudioFormat)
70+
require.Equal(t, openairt.GPT4oTranscribe, session.InputAudioTranscription.Model)
71+
require.Equal(t, "en", session.InputAudioTranscription.Language)
72+
require.Equal(t, openairt.ServerTurnDetectionTypeServerVad, session.TurnDetection.Type)
73+
require.InEpsilon(t, 0.6, session.TurnDetection.Threshold, 0.0001)
74+
require.Equal(t, 300, session.TurnDetection.PrefixPaddingMs)
75+
require.Equal(t, 500, session.TurnDetection.SilenceDurationMs)
76+
require.Empty(t, session.Modalities)
77+
t.Logf("transcription session: %+v", session)
78+
}

api_test.go

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,3 +95,109 @@ func TestCreateSessionResponse(t *testing.T) {
9595
require.NoError(t, err)
9696
require.Equal(t, expected, actual)
9797
}
98+
99+
func TestCreateTranscriptionSessionRequest(t *testing.T) {
100+
data := `{
101+
"input_audio_format": "pcm16",
102+
"input_audio_transcription": {
103+
"model": "gpt-4o-transcribe",
104+
"language": "en",
105+
"prompt": "This is a test transcription"
106+
},
107+
"input_audio_noise_reduction": {
108+
"type": "near_field"
109+
},
110+
"modalities": ["text"],
111+
"turn_detection": {
112+
"type": "server_vad",
113+
"threshold": 0.6,
114+
"prefix_padding_ms": 300,
115+
"silence_duration_ms": 500
116+
},
117+
"include": ["item.input_audio_transcription.logprobs"]
118+
}`
119+
expected := openairt.CreateTranscriptionSessionRequest{
120+
InputAudioFormat: openairt.AudioFormatPcm16,
121+
InputAudioTranscription: &openairt.InputAudioTranscription{
122+
Model: "gpt-4o-transcribe",
123+
Language: "en",
124+
Prompt: "This is a test transcription",
125+
},
126+
InputAudioNoiseReduction: &openairt.InputAudioNoiseReduction{
127+
Type: "near_field",
128+
},
129+
Modalities: []openairt.Modality{
130+
openairt.ModalityText,
131+
},
132+
TurnDetection: &openairt.ClientTurnDetection{
133+
Type: openairt.ClientTurnDetectionTypeServerVad,
134+
TurnDetectionParams: openairt.TurnDetectionParams{
135+
Threshold: 0.6,
136+
PrefixPaddingMs: 300,
137+
SilenceDurationMs: 500,
138+
},
139+
},
140+
Include: []string{"item.input_audio_transcription.logprobs"},
141+
}
142+
143+
var actual openairt.CreateTranscriptionSessionRequest
144+
err := json.Unmarshal([]byte(data), &actual)
145+
require.NoError(t, err)
146+
require.Equal(t, expected, actual)
147+
148+
actualBytes, err := json.Marshal(actual)
149+
require.NoError(t, err)
150+
jsontools.RequireJSONEq(t, data, string(actualBytes))
151+
}
152+
153+
func TestCreateTranscriptionSessionResponse(t *testing.T) {
154+
data := `{
155+
"id": "trans_123456",
156+
"object": "realtime.transcription_session",
157+
"input_audio_format": "pcm16",
158+
"input_audio_transcription": {
159+
"model": "gpt-4o-transcribe",
160+
"language": "en"
161+
},
162+
"modalities": ["text"],
163+
"turn_detection": {
164+
"type": "server_vad",
165+
"threshold": 0.6,
166+
"prefix_padding_ms": 300,
167+
"silence_duration_ms": 500
168+
},
169+
"client_secret": {
170+
"value": "ek_trans_abc123",
171+
"expires_at": 1234567890
172+
}
173+
}`
174+
expected := openairt.CreateTranscriptionSessionResponse{
175+
ID: "trans_123456",
176+
Object: "realtime.transcription_session",
177+
InputAudioFormat: openairt.AudioFormatPcm16,
178+
InputAudioTranscription: &openairt.InputAudioTranscription{
179+
Model: "gpt-4o-transcribe",
180+
Language: "en",
181+
},
182+
Modalities: []openairt.Modality{
183+
openairt.ModalityText,
184+
},
185+
TurnDetection: &openairt.ServerTurnDetection{
186+
Type: openairt.ServerTurnDetectionTypeServerVad,
187+
TurnDetectionParams: openairt.TurnDetectionParams{
188+
Threshold: 0.6,
189+
PrefixPaddingMs: 300,
190+
SilenceDurationMs: 500,
191+
},
192+
},
193+
ClientSecret: openairt.ClientSecret{
194+
Value: "ek_trans_abc123",
195+
ExpiresAt: 1234567890,
196+
},
197+
}
198+
199+
var actual openairt.CreateTranscriptionSessionResponse
200+
err := json.Unmarshal([]byte(data), &actual)
201+
require.NoError(t, err)
202+
require.Equal(t, expected, actual)
203+
}

client.go

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ package openairt
22

33
import (
44
"context"
5-
"fmt"
5+
"errors"
66
"net/http"
77
"net/url"
88
)
@@ -15,6 +15,13 @@ const (
1515
GPT4oMiniRealtimePreview20241217 = "gpt-4o-mini-realtime-preview-2024-12-17"
1616
)
1717

18+
// Transcription models.
19+
const (
20+
GPT4oTranscribe = "gpt-4o-transcribe"
21+
GPT4oMiniTranscribe = "gpt-4o-mini-transcribe"
22+
Whisper1 = "whisper-1"
23+
)
24+
1825
// Client is OpenAI Realtime API client.
1926
type Client struct {
2027
config ClientConfig
@@ -74,7 +81,7 @@ func WithModel(model string) ConnectOption {
7481
}
7582
}
7683

77-
// Set transcription intent instead of model
84+
// Set transcription intent instead of model.
7885
func WithIntent() ConnectOption {
7986
return func(opts *connectOption) {
8087
opts.intent = "transcription"
@@ -113,10 +120,10 @@ func (c *Client) Connect(ctx context.Context, opts ...ConnectOption) (*Conn, err
113120

114121
// get url by model
115122
var url string
116-
if connectOpts.intent == "" {
123+
if connectOpts.intent == "" { //nolint:gocritic // if conditions would be determined in order
117124
url = c.getURL(connectOpts.model)
118125
} else if c.config.APIType != APITypeOpenAI {
119-
return nil, fmt.Errorf("Azure API type with intent set not implemented");
126+
return nil, errors.New("intent not supported for Azure API type")
120127
} else {
121128
url = c.config.BaseURL + "?" + "intent=" + connectOpts.intent
122129
}
@@ -152,3 +159,15 @@ func (c *Client) CreateSession(ctx context.Context, req *CreateSessionRequest) (
152159
WithHeaders(c.getAPIHeaders()),
153160
)
154161
}
162+
163+
// CreateTranscriptionSession creates a new transcription session.
164+
func (c *Client) CreateTranscriptionSession(ctx context.Context, req *CreateTranscriptionSessionRequest) (*CreateTranscriptionSessionResponse, error) {
165+
return HTTPDo[CreateTranscriptionSessionRequest, CreateTranscriptionSessionResponse](
166+
ctx,
167+
c.config.APIBaseURL+"/realtime/transcription_sessions",
168+
req,
169+
WithClient(c.config.HTTPClient),
170+
WithMethod(http.MethodPost),
171+
WithHeaders(c.getAPIHeaders()),
172+
)
173+
}

client_event.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ type ClientSession struct {
5252
Temperature *float32 `json:"temperature,omitempty"`
5353
// Maximum number of output tokens for a single assistant response, inclusive of tool calls. Provide an integer between 1 and 4096 to limit output tokens, or "inf" for the maximum available tokens for a given model. Defaults to "inf".
5454
MaxOutputTokens IntOrInf `json:"max_response_output_tokens,omitempty"`
55+
// Configuration for input audio noise reduction. This can be set to null to turn off. Noise reduction filters audio added to the input audio buffer before it is sent to VAD and the model. Filtering the audio can improve VAD and turn detection accuracy (reducing false positives) and model performance by improving perception of the input audio.
56+
InputAudioNoiseReduction *InputAudioNoiseReduction `json:"input_audio_noise_reduction,omitempty"`
5557
}
5658

5759
// SessionUpdateEvent is the event for session update.

examples/go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,4 @@ require (
1919
golang.org/x/sys v0.25.0 // indirect
2020
)
2121

22-
// replace github.com/WqyJh/go-openai-realtime => ../
22+
replace github.com/WqyJh/go-openai-realtime => ../

0 commit comments

Comments
 (0)