Skip to content

Commit eff8dc1

Browse files
authored
fix(audio): fix audioTextResponse decode (#638)
* fix(audio): fix audioTextResponse decode * test(audio): add audioTextResponse decode test * test(audio): simplify code
1 parent 4ce03a9 commit eff8dc1

File tree

2 files changed

+47
-11
lines changed

2 files changed

+47
-11
lines changed

client.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -193,10 +193,14 @@ func decodeResponse(body io.Reader, v any) error {
193193
return nil
194194
}
195195

196-
if result, ok := v.(*string); ok {
197-
return decodeString(body, result)
196+
switch o := v.(type) {
197+
case *string:
198+
return decodeString(body, o)
199+
case *audioTextResponse:
200+
return decodeString(body, &o.Text)
201+
default:
202+
return json.NewDecoder(body).Decode(v)
198203
}
199-
return json.NewDecoder(body).Decode(v)
200204
}
201205

202206
func decodeString(body io.Reader, output *string) error {

client_test.go

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@ import (
77
"fmt"
88
"io"
99
"net/http"
10+
"reflect"
1011
"testing"
1112

1213
"github.com/sashabaranov/go-openai/internal/test"
14+
"github.com/sashabaranov/go-openai/internal/test/checks"
1315
)
1416

1517
var errTestRequestBuilderFailed = errors.New("test request builder failed")
@@ -43,38 +45,68 @@ func TestDecodeResponse(t *testing.T) {
4345
testCases := []struct {
4446
name string
4547
value interface{}
48+
expected interface{}
4649
body io.Reader
4750
hasError bool
4851
}{
4952
{
50-
name: "nil input",
51-
value: nil,
52-
body: bytes.NewReader([]byte("")),
53+
name: "nil input",
54+
value: nil,
55+
body: bytes.NewReader([]byte("")),
56+
expected: nil,
5357
},
5458
{
55-
name: "string input",
56-
value: &stringInput,
57-
body: bytes.NewReader([]byte("test")),
59+
name: "string input",
60+
value: &stringInput,
61+
body: bytes.NewReader([]byte("test")),
62+
expected: "test",
5863
},
5964
{
6065
name: "map input",
6166
value: &map[string]interface{}{},
6267
body: bytes.NewReader([]byte(`{"test": "test"}`)),
68+
expected: map[string]interface{}{
69+
"test": "test",
70+
},
6371
},
6472
{
6573
name: "reader return error",
6674
value: &stringInput,
6775
body: &errorReader{err: errors.New("dummy")},
6876
hasError: true,
6977
},
78+
{
79+
name: "audio text input",
80+
value: &audioTextResponse{},
81+
body: bytes.NewReader([]byte("test")),
82+
expected: audioTextResponse{
83+
Text: "test",
84+
},
85+
},
86+
}
87+
88+
assertEqual := func(t *testing.T, expected, actual interface{}) {
89+
t.Helper()
90+
if expected == actual {
91+
return
92+
}
93+
v := reflect.ValueOf(actual).Elem().Interface()
94+
if !reflect.DeepEqual(v, expected) {
95+
t.Fatalf("Unexpected value: %v, expected: %v", v, expected)
96+
}
7097
}
7198

7299
for _, tc := range testCases {
73100
t.Run(tc.name, func(t *testing.T) {
74101
err := decodeResponse(tc.body, tc.value)
75-
if (err != nil) != tc.hasError {
76-
t.Errorf("Unexpected error: %v", err)
102+
if tc.hasError {
103+
checks.HasError(t, err, "Unexpected nil error")
104+
return
105+
}
106+
if err != nil {
107+
t.Fatalf("Unexpected error: %v", err)
77108
}
109+
assertEqual(t, tc.expected, tc.value)
78110
})
79111
}
80112
}

0 commit comments

Comments
 (0)