Skip to content

Commit 8ac67fb

Browse files
authored
feat: transcribe and tts model discovery (#8)
1 parent dbedfe3 commit 8ac67fb

53 files changed

Lines changed: 2786 additions & 93 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

docs/providers.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,8 @@ model := provider.ChatModel("gpt-4o-mini")
9696
| Option | Default | Description |
9797
|--------|---------|-------------|
9898
| `WithAPIKey(key)` | `""` | API key sent as `Authorization: Bearer <key>` |
99+
| `WithBedrockRegion(region)` | disabled | Use AWS SigV4 with the default AWS credential chain for Bedrock's OpenAI-compatible endpoint |
100+
| `WithBedrockCredentials(region, accessKeyID, secretAccessKey, sessionToken)` | disabled | Use AWS SigV4 with static AWS credentials for Bedrock |
99101
| `WithBaseURL(url)` | `https://api.openai.com/v1` | Base URL for API requests |
100102
| `WithHTTPClient(client)` | `&http.Client{}` | Custom HTTP client (for proxies, timeouts, etc.) |
101103

@@ -130,6 +132,12 @@ provider := completions.New(
130132
completions.WithBaseURL("https://your-resource.openai.azure.com/openai/deployments/gpt-4o"),
131133
)
132134

135+
// Amazon Bedrock OpenAI-compatible endpoint (AWS credentials)
136+
provider := completions.New(
137+
completions.WithBedrockRegion("us-east-1"),
138+
completions.WithBaseURL("https://bedrock-mantle.us-east-1.api.aws/v1"),
139+
)
140+
133141
// Local (Ollama, vLLM, etc.)
134142
provider := completions.New(
135143
completions.WithBaseURL("http://localhost:11434/v1"),
@@ -198,6 +206,8 @@ model := provider.ChatModel("gpt-4o-mini")
198206
| Option | Default | Description |
199207
|--------|---------|-------------|
200208
| `WithAPIKey(key)` | `""` | API key sent as `Authorization: Bearer <key>` |
209+
| `WithBedrockRegion(region)` | disabled | Use AWS SigV4 with the default AWS credential chain for Bedrock's OpenAI-compatible endpoint |
210+
| `WithBedrockCredentials(region, accessKeyID, secretAccessKey, sessionToken)` | disabled | Use AWS SigV4 with static AWS credentials for Bedrock |
201211
| `WithBaseURL(url)` | `https://api.openai.com/v1` | Base URL for API requests |
202212
| `WithHTTPClient(client)` | `&http.Client{}` | Custom HTTP client |
203213

@@ -218,6 +228,8 @@ provider := responses.New(
218228
responses.WithAPIKey("sk-or-v1-..."),
219229
responses.WithBaseURL("https://openrouter.ai/api/v1"),
220230
)
231+
232+
// Amazon Bedrock's OpenAI-compatible Responses endpoint also works with the same provider when you configure a Bedrock Mantle base URL and either WithBedrockRegion(...) or WithBedrockCredentials(...).
221233
model := provider.ChatModel("openai/o4-mini")
222234
```
223235
@@ -777,6 +789,8 @@ vec, err := sdk.Embed(ctx, "Hello world",
777789
| Option | Default | Description |
778790
|--------|---------|-------------|
779791
| `WithAPIKey(key)` | `""` | API key sent as `Authorization: Bearer <key>` |
792+
| `WithBedrockRegion(region)` | disabled | Use AWS SigV4 with the default AWS credential chain for Bedrock's OpenAI-compatible endpoint |
793+
| `WithBedrockCredentials(region, accessKeyID, secretAccessKey, sessionToken)` | disabled | Use AWS SigV4 with static AWS credentials for Bedrock |
780794
| `WithBaseURL(url)` | `https://api.openai.com/v1` | Base URL for API requests |
781795
| `WithHTTPClient(client)` | `&http.Client{}` | Custom HTTP client |
782796

go.mod

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,27 @@ module github.com/memohai/twilight-ai
33
go 1.25.7
44

55
require (
6+
github.com/aws/aws-sdk-go-v2 v1.41.5
7+
github.com/aws/aws-sdk-go-v2/config v1.32.14
8+
github.com/aws/aws-sdk-go-v2/credentials v1.19.14
69
github.com/google/jsonschema-go v0.4.2
710
github.com/google/uuid v1.6.0
811
github.com/gorilla/websocket v1.5.3
912
github.com/modelcontextprotocol/go-sdk v1.5.0
1013
)
1114

1215
require (
16+
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21 // indirect
17+
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21 // indirect
18+
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21 // indirect
19+
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6 // indirect
20+
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 // indirect
21+
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21 // indirect
22+
github.com/aws/aws-sdk-go-v2/service/signin v1.0.9 // indirect
23+
github.com/aws/aws-sdk-go-v2/service/sso v1.30.15 // indirect
24+
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19 // indirect
25+
github.com/aws/aws-sdk-go-v2/service/sts v1.41.10 // indirect
26+
github.com/aws/smithy-go v1.24.2 // indirect
1327
github.com/segmentio/asm v1.2.1 // indirect
1428
github.com/segmentio/encoding v0.5.4 // indirect
1529
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect

go.sum

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,31 @@
1+
github.com/aws/aws-sdk-go-v2 v1.41.5 h1:dj5kopbwUsVUVFgO4Fi5BIT3t4WyqIDjGKCangnV/yY=
2+
github.com/aws/aws-sdk-go-v2 v1.41.5/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o=
3+
github.com/aws/aws-sdk-go-v2/config v1.32.14 h1:opVIRo/ZbbI8OIqSOKmpFaY7IwfFUOCCXBsUpJOwDdI=
4+
github.com/aws/aws-sdk-go-v2/config v1.32.14/go.mod h1:U4/V0uKxh0Tl5sxmCBZ3AecYny4UNlVmObYjKuuaiOo=
5+
github.com/aws/aws-sdk-go-v2/credentials v1.19.14 h1:n+UcGWAIZHkXzYt87uMFBv/l8THYELoX6gVcUvgl6fI=
6+
github.com/aws/aws-sdk-go-v2/credentials v1.19.14/go.mod h1:cJKuyWB59Mqi0jM3nFYQRmnHVQIcgoxjEMAbLkpr62w=
7+
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21 h1:NUS3K4BTDArQqNu2ih7yeDLaS3bmHD0YndtA6UP884g=
8+
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21/go.mod h1:YWNWJQNjKigKY1RHVJCuupeWDrrHjRqHm0N9rdrWzYI=
9+
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21 h1:Rgg6wvjjtX8bNHcvi9OnXWwcE0a2vGpbwmtICOsvcf4=
10+
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21/go.mod h1:A/kJFst/nm//cyqonihbdpQZwiUhhzpqTsdbhDdRF9c=
11+
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21 h1:PEgGVtPoB6NTpPrBgqSE5hE/o47Ij9qk/SEZFbUOe9A=
12+
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21/go.mod h1:p+hz+PRAYlY3zcpJhPwXlLC4C+kqn70WIHwnzAfs6ps=
13+
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6 h1:qYQ4pzQ2Oz6WpQ8T3HvGHnZydA72MnLuFK9tJwmrbHw=
14+
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6/go.mod h1:O3h0IK87yXci+kg6flUKzJnWeziQUKciKrLjcatSNcY=
15+
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 h1:5EniKhLZe4xzL7a+fU3C2tfUN4nWIqlLesfrjkuPFTY=
16+
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7/go.mod h1:x0nZssQ3qZSnIcePWLvcoFisRXJzcTVvYpAAdYX8+GI=
17+
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21 h1:c31//R3xgIJMSC8S6hEVq+38DcvUlgFY0FM6mSI5oto=
18+
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21/go.mod h1:r6+pf23ouCB718FUxaqzZdbpYFyDtehyZcmP5KL9FkA=
19+
github.com/aws/aws-sdk-go-v2/service/signin v1.0.9 h1:QKZH0S178gCmFEgst8hN0mCX1KxLgHBKKY/CLqwP8lg=
20+
github.com/aws/aws-sdk-go-v2/service/signin v1.0.9/go.mod h1:7yuQJoT+OoH8aqIxw9vwF+8KpvLZ8AWmvmUWHsGQZvI=
21+
github.com/aws/aws-sdk-go-v2/service/sso v1.30.15 h1:lFd1+ZSEYJZYvv9d6kXzhkZu07si3f+GQ1AaYwa2LUM=
22+
github.com/aws/aws-sdk-go-v2/service/sso v1.30.15/go.mod h1:WSvS1NLr7JaPunCXqpJnWk1Bjo7IxzZXrZi1QQCkuqM=
23+
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19 h1:dzztQ1YmfPrxdrOiuZRMF6fuOwWlWpD2StNLTceKpys=
24+
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19/go.mod h1:YO8TrYtFdl5w/4vmjL8zaBSsiNp3w0L1FfKVKenZT7w=
25+
github.com/aws/aws-sdk-go-v2/service/sts v1.41.10 h1:p8ogvvLugcR/zLBXTXrTkj0RYBUdErbMnAFFp12Lm/U=
26+
github.com/aws/aws-sdk-go-v2/service/sts v1.41.10/go.mod h1:60dv0eZJfeVXfbT1tFJinbHrDfSJ2GZl4Q//OSSNAVw=
27+
github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng=
28+
github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
129
github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY=
230
github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
331
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=

internal/utils/aws.go

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
package utils
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"crypto/sha256"
7+
"encoding/hex"
8+
"fmt"
9+
"io"
10+
"net/http"
11+
"sync"
12+
"time"
13+
14+
"github.com/aws/aws-sdk-go-v2/aws"
15+
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
16+
"github.com/aws/aws-sdk-go-v2/config"
17+
"github.com/aws/aws-sdk-go-v2/credentials"
18+
)
19+
20+
const (
21+
emptySHA256 = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
22+
23+
bedrockService = "bedrock"
24+
)
25+
26+
type lazyCredentialsProvider struct {
27+
region string
28+
29+
once sync.Once
30+
cp aws.CredentialsProvider
31+
err error
32+
}
33+
34+
func NewBedrockDefaultCredentialsPreparer(region string) func(*http.Request) error {
35+
return NewSigV4Preparer(bedrockService, region, &lazyCredentialsProvider{region: region})
36+
}
37+
38+
func NewBedrockStaticCredentialsPreparer(region, accessKeyID, secretAccessKey, sessionToken string) func(*http.Request) error {
39+
cp := aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(accessKeyID, secretAccessKey, sessionToken))
40+
return NewSigV4Preparer(bedrockService, region, cp)
41+
}
42+
43+
func NewSigV4Preparer(service, region string, cp aws.CredentialsProvider) func(*http.Request) error {
44+
signer := v4.NewSigner()
45+
46+
return func(req *http.Request) error {
47+
payloadHash, err := requestPayloadHash(req)
48+
if err != nil {
49+
return err
50+
}
51+
52+
creds, err := cp.Retrieve(req.Context())
53+
if err != nil {
54+
return fmt.Errorf("retrieve AWS credentials: %w", err)
55+
}
56+
57+
if err := signer.SignHTTP(req.Context(), creds, req, payloadHash, service, region, time.Now().UTC()); err != nil {
58+
return fmt.Errorf("sign AWS request: %w", err)
59+
}
60+
61+
return nil
62+
}
63+
}
64+
65+
func (p *lazyCredentialsProvider) Retrieve(ctx context.Context) (aws.Credentials, error) {
66+
p.once.Do(func() {
67+
cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(p.region))
68+
if err != nil {
69+
p.err = err
70+
return
71+
}
72+
p.cp = cfg.Credentials
73+
})
74+
if p.err != nil {
75+
return aws.Credentials{}, p.err
76+
}
77+
return p.cp.Retrieve(ctx)
78+
}
79+
80+
func requestPayloadHash(req *http.Request) (string, error) {
81+
if req.Body == nil {
82+
return emptySHA256, nil
83+
}
84+
85+
if req.GetBody != nil {
86+
body, err := req.GetBody()
87+
if err != nil {
88+
return "", fmt.Errorf("clone request body: %w", err)
89+
}
90+
defer body.Close()
91+
return hashReader(body)
92+
}
93+
94+
data, err := io.ReadAll(req.Body)
95+
if err != nil {
96+
return "", fmt.Errorf("read request body: %w", err)
97+
}
98+
99+
req.Body = io.NopCloser(bytes.NewReader(data))
100+
req.GetBody = func() (io.ReadCloser, error) {
101+
return io.NopCloser(bytes.NewReader(data)), nil
102+
}
103+
104+
return hashBytes(data), nil
105+
}
106+
107+
// hashReader consumes r to EOF and returns the lowercase hex SHA-256 digest of
// everything read. A read failure is returned wrapped.
func hashReader(r io.Reader) (string, error) {
	digest := sha256.New()

	_, err := io.Copy(digest, r)
	if err != nil {
		return "", fmt.Errorf("hash request body: %w", err)
	}

	sum := digest.Sum(nil)
	return hex.EncodeToString(sum), nil
}
114+
115+
// hashBytes returns the lowercase hex SHA-256 digest of data.
func hashBytes(data []byte) string {
	digest := sha256.Sum256(data)
	encoded := make([]byte, hex.EncodedLen(len(digest)))
	hex.Encode(encoded, digest[:])
	return string(encoded)
}

internal/utils/fetch.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ type RequestOptions struct {
1717
Headers map[string]string
1818
Query map[string]string
1919
Body any
20+
Prepare func(*http.Request) error
2021
}
2122

2223
type APIError struct {
@@ -86,6 +87,12 @@ func BuildRequest(ctx context.Context, opts *RequestOptions) (*http.Request, err
8687
req.Header.Set(k, v)
8788
}
8889

90+
if opts.Prepare != nil {
91+
if err := opts.Prepare(req); err != nil {
92+
return nil, fmt.Errorf("prepare request: %w", err)
93+
}
94+
}
95+
8996
return req, nil
9097
}
9198

provider/alibabacloud/speech/speech.go

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,14 @@ package speech
55

66
import (
77
"context"
8+
"fmt"
89
"strings"
910

1011
sdk "github.com/memohai/twilight-ai/sdk"
1112
)
1213

1314
const (
14-
defaultModelID = "cosyvoice-tts"
15+
defaultModelID = "cosyvoice-v1"
1516
defaultBaseURL = "wss://dashscope.aliyuncs.com/api-ws/v1/inference/"
1617
defaultModel = "cosyvoice-v1"
1718
defaultFormat = "mp3"
@@ -51,14 +52,22 @@ func New(opts ...Option) *Provider {
5152
// SpeechModel creates a SpeechModel bound to this provider.
5253
func (p *Provider) SpeechModel(id string) *sdk.SpeechModel {
5354
if id == "" {
54-
id = defaultModelID
55+
id = defaultModel
5556
}
5657
return &sdk.SpeechModel{ID: id, Provider: p}
5758
}
5859

60+
// ListModels returns the speech models exposed by this provider.
61+
func (p *Provider) ListModels(context.Context) ([]*sdk.SpeechModel, error) {
62+
return nil, fmt.Errorf("alibabacloud speech: provider does not expose a remote models discovery API")
63+
}
64+
5965
// DoSynthesize synthesizes speech and returns the complete audio bytes.
6066
func (p *Provider) DoSynthesize(ctx context.Context, params sdk.SpeechParams) (*sdk.SpeechResult, error) {
6167
cfg := parseConfig(params.Config)
68+
if params.Model != nil && params.Model.ID != "" {
69+
cfg.Model = params.Model.ID
70+
}
6271

6372
audio, err := p.client.synthesize(ctx, params.Text, &cfg)
6473
if err != nil {
@@ -73,6 +82,9 @@ func (p *Provider) DoSynthesize(ctx context.Context, params sdk.SpeechParams) (*
7382
// DoStream synthesizes speech and returns a streaming result.
7483
func (p *Provider) DoStream(ctx context.Context, params sdk.SpeechParams) (*sdk.SpeechStreamResult, error) {
7584
cfg := parseConfig(params.Config)
85+
if params.Model != nil && params.Model.ID != "" {
86+
cfg.Model = params.Model.ID
87+
}
7688

7789
ch, errCh := p.client.stream(ctx, params.Text, &cfg)
7890
return sdk.NewSpeechStreamResult(ch, contentTypeForFormat(cfg.Format), errCh), nil

provider/alibabacloud/speech/speech_test.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,19 @@ func TestProvider_SpeechModel(t *testing.T) {
163163
}
164164
}
165165

166+
func TestProvider_ListModels(t *testing.T) {
167+
t.Parallel()
168+
p := New()
169+
170+
models, err := p.ListModels(context.Background())
171+
if err == nil {
172+
t.Fatal("expected unsupported error")
173+
}
174+
if len(models) != 0 {
175+
t.Fatalf("len(models) = %d, want 0", len(models))
176+
}
177+
}
178+
166179
func TestParseConfig(t *testing.T) {
167180
t.Parallel()
168181
cfg := parseConfig(map[string]any{

provider/anthropic/messages/messages.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -763,7 +763,9 @@ type streamingBlock struct {
763763

764764
func generateID() string {
765765
b := make([]byte, 12)
766-
rand.Read(b)
766+
if _, err := rand.Read(b); err != nil {
767+
panic("anthropic: generateID entropy failure: " + err.Error())
768+
}
767769
return fmt.Sprintf("toolu_%x", b)
768770
}
769771

provider/deepgram/speech/speech.go

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,14 +66,22 @@ func New(opts ...Option) *Provider {
6666
// SpeechModel creates a SpeechModel bound to this provider.
6767
func (p *Provider) SpeechModel(id string) *sdk.SpeechModel {
6868
if id == "" {
69-
id = defaultModelID
69+
id = defaultVoiceModel
7070
}
7171
return &sdk.SpeechModel{ID: id, Provider: p}
7272
}
7373

74+
// ListModels returns the speech models exposed by this provider.
75+
func (p *Provider) ListModels(context.Context) ([]*sdk.SpeechModel, error) {
76+
return nil, fmt.Errorf("deepgram speech: provider does not expose a remote models discovery API in this SDK")
77+
}
78+
7479
// DoSynthesize synthesizes speech and returns the complete audio bytes.
7580
func (p *Provider) DoSynthesize(ctx context.Context, params sdk.SpeechParams) (*sdk.SpeechResult, error) {
7681
cfg := parseConfig(params.Config)
82+
if params.Model != nil && params.Model.ID != "" {
83+
cfg.Model = params.Model.ID
84+
}
7785

7886
body, err := p.doRequest(ctx, params.Text, cfg)
7987
if err != nil {
@@ -94,6 +102,9 @@ func (p *Provider) DoSynthesize(ctx context.Context, params sdk.SpeechParams) (*
94102
// DoStream synthesizes speech and returns a streaming result backed by chunked HTTP body.
95103
func (p *Provider) DoStream(ctx context.Context, params sdk.SpeechParams) (*sdk.SpeechStreamResult, error) {
96104
cfg := parseConfig(params.Config)
105+
if params.Model != nil && params.Model.ID != "" {
106+
cfg.Model = params.Model.ID
107+
}
97108

98109
body, err := p.doRequest(ctx, params.Text, cfg)
99110
if err != nil {

provider/deepgram/speech/speech_test.go

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -108,13 +108,26 @@ func TestProvider_DoSynthesize_ConnectionFailure(t *testing.T) {
108108
func TestProvider_SpeechModel(t *testing.T) {
109109
t.Parallel()
110110
p := New()
111-
m := p.SpeechModel("deepgram-tts")
112-
if m.ID != "deepgram-tts" {
111+
m := p.SpeechModel("aura-2-orpheus-en")
112+
if m.ID != "aura-2-orpheus-en" {
113113
t.Errorf("ID = %q", m.ID)
114114
}
115115
m2 := p.SpeechModel("")
116-
if m2.ID != defaultModelID {
117-
t.Errorf("default ID = %q, want %q", m2.ID, defaultModelID)
116+
if m2.ID != defaultVoiceModel {
117+
t.Errorf("default ID = %q, want %q", m2.ID, defaultVoiceModel)
118+
}
119+
}
120+
121+
func TestProvider_ListModels(t *testing.T) {
122+
t.Parallel()
123+
p := New()
124+
125+
models, err := p.ListModels(context.Background())
126+
if err == nil {
127+
t.Fatal("expected unsupported error")
128+
}
129+
if len(models) != 0 {
130+
t.Fatalf("len(models) = %d, want 0", len(models))
118131
}
119132
}
120133

0 commit comments

Comments
 (0)