Skip to content

Commit 36c1e6b

Browse files
committed
fix: harden provider model list response parsing
1 parent 0f98a1f commit 36c1e6b

12 files changed

Lines changed: 366 additions & 72 deletions

File tree

provider/elevenlabs/speech/speech.go

Lines changed: 44 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -74,25 +74,29 @@ func (p *Provider) SpeechModel(id string) *sdk.SpeechModel {
7474

7575
// ListModels returns the speech models exposed by this provider.
7676
func (p *Provider) ListModels(ctx context.Context) ([]*sdk.SpeechModel, error) {
77-
type modelsListResponse struct {
78-
Models []struct {
79-
ModelID string `json:"model_id"`
80-
CanDoTTS bool `json:"can_do_text_to_speech"`
81-
} `json:"models"`
82-
}
83-
84-
resp, err := utils.FetchJSON[modelsListResponse](ctx, p.httpClient, &utils.RequestOptions{
85-
Method: http.MethodGet,
86-
BaseURL: p.baseURL,
87-
Path: "/v1/models",
88-
Headers: map[string]string{"xi-api-key": p.apiKey},
89-
})
77+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, p.baseURL+"/v1/models", http.NoBody)
78+
if err != nil {
79+
return nil, fmt.Errorf("elevenlabs speech: build list models request: %w", err)
80+
}
81+
req.Header.Set("xi-api-key", p.apiKey)
82+
83+
resp, err := p.httpClient.Do(req)
9084
if err != nil {
9185
return nil, fmt.Errorf("elevenlabs speech: list models request failed: %w", err)
9286
}
87+
defer resp.Body.Close()
88+
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
89+
body, _ := io.ReadAll(resp.Body)
90+
return nil, fmt.Errorf("elevenlabs speech: unexpected status %d: %s", resp.StatusCode, string(body))
91+
}
9392

94-
models := make([]*sdk.SpeechModel, 0, len(resp.Models))
95-
for _, m := range resp.Models {
93+
rawModels, err := decodeModelsResponse(resp.Body)
94+
if err != nil {
95+
return nil, fmt.Errorf("elevenlabs speech: decode response: %w", err)
96+
}
97+
98+
models := make([]*sdk.SpeechModel, 0, len(rawModels))
99+
for _, m := range rawModels {
96100
if m.CanDoTTS && m.ModelID != "" {
97101
models = append(models, p.SpeechModel(m.ModelID))
98102
}
@@ -103,6 +107,31 @@ func (p *Provider) ListModels(ctx context.Context) ([]*sdk.SpeechModel, error) {
103107
return models, nil
104108
}
105109

110+
type elevenlabsModel struct {
111+
ModelID string `json:"model_id"`
112+
CanDoTTS bool `json:"can_do_text_to_speech"`
113+
}
114+
115+
func decodeModelsResponse(r io.Reader) ([]elevenlabsModel, error) {
116+
body, err := io.ReadAll(r)
117+
if err != nil {
118+
return nil, err
119+
}
120+
121+
var wrapped struct {
122+
Models []elevenlabsModel `json:"models"`
123+
}
124+
if err := json.Unmarshal(body, &wrapped); err == nil && len(wrapped.Models) > 0 {
125+
return wrapped.Models, nil
126+
}
127+
128+
var direct []elevenlabsModel
129+
if err := json.Unmarshal(body, &direct); err != nil {
130+
return nil, err
131+
}
132+
return direct, nil
133+
}
134+
106135
// DoSynthesize synthesizes speech and returns the complete audio bytes.
107136
// Uses the non-streaming /v1/text-to-speech/{voice_id} endpoint.
108137
func (p *Provider) DoSynthesize(ctx context.Context, params sdk.SpeechParams) (*sdk.SpeechResult, error) {

provider/elevenlabs/speech/speech_test.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,24 @@ func TestProvider_ListModels(t *testing.T) {
162162
}
163163
}
164164

165+
func TestProvider_ListModels_ArrayResponse(t *testing.T) {
166+
t.Parallel()
167+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
168+
w.Header().Set("Content-Type", "application/json")
169+
_, _ = w.Write([]byte(`[{"model_id":"eleven_multilingual_v2","can_do_text_to_speech":true},{"model_id":"scribe_v1","can_do_text_to_speech":false}]`))
170+
}))
171+
defer srv.Close()
172+
173+
p := New(WithAPIKey("key"), WithBaseURL(srv.URL))
174+
models, err := p.ListModels(context.Background())
175+
if err != nil {
176+
t.Fatalf("ListModels: %v", err)
177+
}
178+
if len(models) != 1 || models[0].ID != "eleven_multilingual_v2" {
179+
t.Fatalf("unexpected models: %+v", models)
180+
}
181+
}
182+
165183
func TestParseConfig(t *testing.T) {
166184
t.Parallel()
167185
cfg := parseConfig(map[string]any{

provider/elevenlabs/transcription/transcription.go

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -51,22 +51,29 @@ func (p *Provider) TranscriptionModel(id string) *sdk.TranscriptionModel {
5151
}
5252

5353
func (p *Provider) ListModels(ctx context.Context) ([]*sdk.TranscriptionModel, error) {
54-
type modelsListResponse struct {
55-
Models []map[string]any `json:"models"`
54+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, p.baseURL+"/v1/models", http.NoBody)
55+
if err != nil {
56+
return nil, fmt.Errorf("elevenlabs transcription: build list models request: %w", err)
5657
}
58+
req.Header.Set("xi-api-key", p.apiKey)
5759

58-
resp, err := utils.FetchJSON[modelsListResponse](ctx, p.httpClient, &utils.RequestOptions{
59-
Method: http.MethodGet,
60-
BaseURL: p.baseURL,
61-
Path: "/v1/models",
62-
Headers: map[string]string{"xi-api-key": p.apiKey},
63-
})
60+
resp, err := p.httpClient.Do(req)
6461
if err != nil {
6562
return nil, fmt.Errorf("elevenlabs transcription: list models request failed: %w", err)
6663
}
64+
defer resp.Body.Close()
65+
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
66+
body, _ := io.ReadAll(resp.Body)
67+
return nil, fmt.Errorf("elevenlabs transcription: unexpected status %d: %s", resp.StatusCode, string(body))
68+
}
69+
70+
rawModels, err := decodeTranscriptionModelsResponse(resp.Body)
71+
if err != nil {
72+
return nil, fmt.Errorf("elevenlabs transcription: decode response: %w", err)
73+
}
6774

68-
models := make([]*sdk.TranscriptionModel, 0, len(resp.Models))
69-
for _, raw := range resp.Models {
75+
models := make([]*sdk.TranscriptionModel, 0, len(rawModels))
76+
for _, raw := range rawModels {
7077
id, _ := raw["model_id"].(string)
7178
if id == "" {
7279
continue
@@ -87,6 +94,26 @@ func (p *Provider) ListModels(ctx context.Context) ([]*sdk.TranscriptionModel, e
8794
return models, nil
8895
}
8996

97+
func decodeTranscriptionModelsResponse(r io.Reader) ([]map[string]any, error) {
98+
body, err := io.ReadAll(r)
99+
if err != nil {
100+
return nil, err
101+
}
102+
103+
var wrapped struct {
104+
Models []map[string]any `json:"models"`
105+
}
106+
if err := json.Unmarshal(body, &wrapped); err == nil && len(wrapped.Models) > 0 {
107+
return wrapped.Models, nil
108+
}
109+
110+
var direct []map[string]any
111+
if err := json.Unmarshal(body, &direct); err != nil {
112+
return nil, err
113+
}
114+
return direct, nil
115+
}
116+
90117
type audioConfig struct {
91118
LanguageCode string
92119
TagAudioEvents *bool

provider/elevenlabs/transcription/transcription_test.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,23 @@ func TestProvider_ListModels(t *testing.T) {
2626
}
2727
}
2828

29+
func TestProvider_ListModels_ArrayResponse(t *testing.T) {
30+
t.Parallel()
31+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
32+
_, _ = w.Write([]byte(`[{"model_id":"scribe_v2","can_do_speech_to_text":true},{"model_id":"eleven_v3","can_do_speech_to_text":false}]`))
33+
}))
34+
defer srv.Close()
35+
36+
p := New(WithAPIKey("key"), WithBaseURL(srv.URL))
37+
models, err := p.ListModels(context.Background())
38+
if err != nil {
39+
t.Fatalf("ListModels: %v", err)
40+
}
41+
if len(models) != 1 || models[0].ID != "scribe_v2" {
42+
t.Fatalf("unexpected models: %+v", models)
43+
}
44+
}
45+
2946
func TestProvider_DoTranscribe(t *testing.T) {
3047
t.Parallel()
3148
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

provider/openai/speech/speech.go

Lines changed: 54 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -75,24 +75,30 @@ func (p *Provider) SpeechModel(id string) *sdk.SpeechModel {
7575

7676
// ListModels returns the speech models exposed by this provider.
7777
func (p *Provider) ListModels(ctx context.Context) ([]*sdk.SpeechModel, error) {
78-
type modelsListResponse struct {
79-
Data []struct {
80-
ID string `json:"id"`
81-
} `json:"data"`
78+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, p.baseURL+"/models", http.NoBody)
79+
if err != nil {
80+
return nil, fmt.Errorf("openai speech: build list models request: %w", err)
8281
}
82+
req.Header.Set("Authorization", "Bearer "+p.apiKey)
8383

84-
resp, err := utils.FetchJSON[modelsListResponse](ctx, p.httpClient, &utils.RequestOptions{
85-
Method: http.MethodGet,
86-
BaseURL: p.baseURL,
87-
Path: "/models",
88-
Headers: map[string]string{"Authorization": "Bearer " + p.apiKey},
89-
})
84+
resp, err := p.httpClient.Do(req)
9085
if err != nil {
9186
return nil, fmt.Errorf("openai speech: list models request failed: %w", err)
9287
}
88+
defer resp.Body.Close()
89+
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
90+
body, _ := io.ReadAll(resp.Body)
91+
return nil, fmt.Errorf("openai speech: unexpected status %d: %s", resp.StatusCode, string(body))
92+
}
93+
94+
rawModels, err := decodeModelIDs(resp.Body)
95+
if err != nil {
96+
return nil, fmt.Errorf("openai speech: decode response: %w", err)
97+
}
9398

94-
models := make([]*sdk.SpeechModel, 0, len(resp.Data))
95-
for _, m := range resp.Data {
99+
models := make([]*sdk.SpeechModel, 0, len(rawModels))
100+
for _, id := range rawModels {
101+
m := struct{ ID string }{ID: id}
96102
if isOpenAITTSModel(m.ID) {
97103
models = append(models, p.SpeechModel(m.ID))
98104
}
@@ -108,6 +114,42 @@ func isOpenAITTSModel(id string) bool {
108114
return strings.Contains(id, "tts") || strings.Contains(id, "audio")
109115
}
110116

117+
func decodeModelIDs(r io.Reader) ([]string, error) {
118+
body, err := io.ReadAll(r)
119+
if err != nil {
120+
return nil, err
121+
}
122+
123+
var wrapped struct {
124+
Data []struct {
125+
ID string `json:"id"`
126+
} `json:"data"`
127+
}
128+
if err := json.Unmarshal(body, &wrapped); err == nil && len(wrapped.Data) > 0 {
129+
out := make([]string, 0, len(wrapped.Data))
130+
for _, m := range wrapped.Data {
131+
if m.ID != "" {
132+
out = append(out, m.ID)
133+
}
134+
}
135+
return out, nil
136+
}
137+
138+
var direct []struct {
139+
ID string `json:"id"`
140+
}
141+
if err := json.Unmarshal(body, &direct); err != nil {
142+
return nil, err
143+
}
144+
out := make([]string, 0, len(direct))
145+
for _, m := range direct {
146+
if m.ID != "" {
147+
out = append(out, m.ID)
148+
}
149+
}
150+
return out, nil
151+
}
152+
111153
// DoSynthesize synthesizes speech and returns the complete audio bytes.
112154
func (p *Provider) DoSynthesize(ctx context.Context, params sdk.SpeechParams) (*sdk.SpeechResult, error) {
113155
cfg := parseConfig(params.Config)

provider/openai/speech/speech_test.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,24 @@ func TestProvider_ListModels(t *testing.T) {
157157
}
158158
}
159159

160+
func TestProvider_ListModels_ArrayResponse(t *testing.T) {
161+
t.Parallel()
162+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
163+
w.Header().Set("Content-Type", "application/json")
164+
_, _ = w.Write([]byte(`[{"id":"gpt-4o-mini-tts"},{"id":"tts-1"},{"id":"gpt-4.1"}]`))
165+
}))
166+
defer srv.Close()
167+
168+
p := New(WithAPIKey("key"), WithBaseURL(srv.URL))
169+
models, err := p.ListModels(context.Background())
170+
if err != nil {
171+
t.Fatalf("ListModels: %v", err)
172+
}
173+
if len(models) != 2 {
174+
t.Fatalf("len(models) = %d, want 2", len(models))
175+
}
176+
}
177+
160178
func TestParseConfig(t *testing.T) {
161179
t.Parallel()
162180
cfg := parseConfig(map[string]any{

provider/openai/transcription/transcription.go

Lines changed: 54 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -51,24 +51,30 @@ func (p *Provider) TranscriptionModel(id string) *sdk.TranscriptionModel {
5151
}
5252

5353
func (p *Provider) ListModels(ctx context.Context) ([]*sdk.TranscriptionModel, error) {
54-
type modelsListResponse struct {
55-
Data []struct {
56-
ID string `json:"id"`
57-
} `json:"data"`
54+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, p.baseURL+"/models", http.NoBody)
55+
if err != nil {
56+
return nil, fmt.Errorf("openai transcription: build list models request: %w", err)
5857
}
58+
req.Header.Set("Authorization", utils.BearerToken(p.apiKey))
5959

60-
resp, err := utils.FetchJSON[modelsListResponse](ctx, p.httpClient, &utils.RequestOptions{
61-
Method: http.MethodGet,
62-
BaseURL: p.baseURL,
63-
Path: "/models",
64-
Headers: map[string]string{"Authorization": utils.BearerToken(p.apiKey)},
65-
})
60+
resp, err := p.httpClient.Do(req)
6661
if err != nil {
6762
return nil, fmt.Errorf("openai transcription: list models request failed: %w", err)
6863
}
64+
defer resp.Body.Close()
65+
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
66+
body, _ := io.ReadAll(resp.Body)
67+
return nil, fmt.Errorf("openai transcription: unexpected status %d: %s", resp.StatusCode, string(body))
68+
}
69+
70+
rawModels, err := decodeModelIDs(resp.Body)
71+
if err != nil {
72+
return nil, fmt.Errorf("openai transcription: decode response: %w", err)
73+
}
6974

70-
models := make([]*sdk.TranscriptionModel, 0, len(resp.Data))
71-
for _, m := range resp.Data {
75+
models := make([]*sdk.TranscriptionModel, 0, len(rawModels))
76+
for _, id := range rawModels {
77+
m := struct{ ID string }{ID: id}
7278
if isTranscriptionModel(m.ID) {
7379
models = append(models, p.TranscriptionModel(m.ID))
7480
}
@@ -84,6 +90,42 @@ func isTranscriptionModel(id string) bool {
8490
return id == "whisper-1" || strings.Contains(id, "transcribe")
8591
}
8692

93+
func decodeModelIDs(r io.Reader) ([]string, error) {
94+
body, err := io.ReadAll(r)
95+
if err != nil {
96+
return nil, err
97+
}
98+
99+
var wrapped struct {
100+
Data []struct {
101+
ID string `json:"id"`
102+
} `json:"data"`
103+
}
104+
if err := json.Unmarshal(body, &wrapped); err == nil && len(wrapped.Data) > 0 {
105+
out := make([]string, 0, len(wrapped.Data))
106+
for _, m := range wrapped.Data {
107+
if m.ID != "" {
108+
out = append(out, m.ID)
109+
}
110+
}
111+
return out, nil
112+
}
113+
114+
var direct []struct {
115+
ID string `json:"id"`
116+
}
117+
if err := json.Unmarshal(body, &direct); err != nil {
118+
return nil, err
119+
}
120+
out := make([]string, 0, len(direct))
121+
for _, m := range direct {
122+
if m.ID != "" {
123+
out = append(out, m.ID)
124+
}
125+
}
126+
return out, nil
127+
}
128+
87129
func (p *Provider) DoTranscribe(ctx context.Context, params sdk.TranscriptionParams) (*sdk.TranscriptionResult, error) {
88130
cfg := parseConfig(params.Config)
89131
modelID := defaultModelID

0 commit comments

Comments
 (0)