Skip to content

Commit 0f98a1f

Browse files
committed
Add transcription providers and speech model discovery
1 parent ddf8b1c commit 0f98a1f

42 files changed

Lines changed: 1832 additions & 36 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

provider/alibabacloud/speech/speech.go

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,14 @@ package speech
55

66
import (
77
"context"
8+
"fmt"
89
"strings"
910

1011
sdk "github.com/memohai/twilight-ai/sdk"
1112
)
1213

1314
const (
14-
defaultModelID = "cosyvoice-tts"
15+
defaultModelID = "cosyvoice-v1"
1516
defaultBaseURL = "wss://dashscope.aliyuncs.com/api-ws/v1/inference/"
1617
defaultModel = "cosyvoice-v1"
1718
defaultFormat = "mp3"
@@ -51,14 +52,22 @@ func New(opts ...Option) *Provider {
5152
// SpeechModel creates a SpeechModel bound to this provider.
5253
func (p *Provider) SpeechModel(id string) *sdk.SpeechModel {
5354
if id == "" {
54-
id = defaultModelID
55+
id = defaultModel
5556
}
5657
return &sdk.SpeechModel{ID: id, Provider: p}
5758
}
5859

60+
// ListModels returns the speech models exposed by this provider.
61+
func (p *Provider) ListModels(context.Context) ([]*sdk.SpeechModel, error) {
62+
return nil, fmt.Errorf("alibabacloud speech: provider does not expose a remote models discovery API")
63+
}
64+
5965
// DoSynthesize synthesizes speech and returns the complete audio bytes.
6066
func (p *Provider) DoSynthesize(ctx context.Context, params sdk.SpeechParams) (*sdk.SpeechResult, error) {
6167
cfg := parseConfig(params.Config)
68+
if params.Model != nil && params.Model.ID != "" {
69+
cfg.Model = params.Model.ID
70+
}
6271

6372
audio, err := p.client.synthesize(ctx, params.Text, &cfg)
6473
if err != nil {
@@ -73,6 +82,9 @@ func (p *Provider) DoSynthesize(ctx context.Context, params sdk.SpeechParams) (*
7382
// DoStream synthesizes speech and returns a streaming result.
7483
func (p *Provider) DoStream(ctx context.Context, params sdk.SpeechParams) (*sdk.SpeechStreamResult, error) {
7584
cfg := parseConfig(params.Config)
85+
if params.Model != nil && params.Model.ID != "" {
86+
cfg.Model = params.Model.ID
87+
}
7688

7789
ch, errCh := p.client.stream(ctx, params.Text, &cfg)
7890
return sdk.NewSpeechStreamResult(ch, contentTypeForFormat(cfg.Format), errCh), nil

provider/alibabacloud/speech/speech_test.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,19 @@ func TestProvider_SpeechModel(t *testing.T) {
163163
}
164164
}
165165

166+
func TestProvider_ListModels(t *testing.T) {
167+
t.Parallel()
168+
p := New()
169+
170+
models, err := p.ListModels(context.Background())
171+
if err == nil {
172+
t.Fatal("expected unsupported error")
173+
}
174+
if len(models) != 0 {
175+
t.Fatalf("len(models) = %d, want 0", len(models))
176+
}
177+
}
178+
166179
func TestParseConfig(t *testing.T) {
167180
t.Parallel()
168181
cfg := parseConfig(map[string]any{

provider/anthropic/messages/messages.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -763,7 +763,9 @@ type streamingBlock struct {
763763

764764
func generateID() string {
765765
b := make([]byte, 12)
766-
rand.Read(b)
766+
if _, err := rand.Read(b); err != nil {
767+
panic("anthropic: generateID entropy failure: " + err.Error())
768+
}
767769
return fmt.Sprintf("toolu_%x", b)
768770
}
769771

provider/deepgram/speech/speech.go

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,14 +66,22 @@ func New(opts ...Option) *Provider {
6666
// SpeechModel creates a SpeechModel bound to this provider.
6767
func (p *Provider) SpeechModel(id string) *sdk.SpeechModel {
6868
if id == "" {
69-
id = defaultModelID
69+
id = defaultVoiceModel
7070
}
7171
return &sdk.SpeechModel{ID: id, Provider: p}
7272
}
7373

74+
// ListModels returns the speech models exposed by this provider.
75+
func (p *Provider) ListModels(context.Context) ([]*sdk.SpeechModel, error) {
76+
return nil, fmt.Errorf("deepgram speech: provider does not expose a remote models discovery API in this SDK")
77+
}
78+
7479
// DoSynthesize synthesizes speech and returns the complete audio bytes.
7580
func (p *Provider) DoSynthesize(ctx context.Context, params sdk.SpeechParams) (*sdk.SpeechResult, error) {
7681
cfg := parseConfig(params.Config)
82+
if params.Model != nil && params.Model.ID != "" {
83+
cfg.Model = params.Model.ID
84+
}
7785

7886
body, err := p.doRequest(ctx, params.Text, cfg)
7987
if err != nil {
@@ -94,6 +102,9 @@ func (p *Provider) DoSynthesize(ctx context.Context, params sdk.SpeechParams) (*
94102
// DoStream synthesizes speech and returns a streaming result backed by chunked HTTP body.
95103
func (p *Provider) DoStream(ctx context.Context, params sdk.SpeechParams) (*sdk.SpeechStreamResult, error) {
96104
cfg := parseConfig(params.Config)
105+
if params.Model != nil && params.Model.ID != "" {
106+
cfg.Model = params.Model.ID
107+
}
97108

98109
body, err := p.doRequest(ctx, params.Text, cfg)
99110
if err != nil {

provider/deepgram/speech/speech_test.go

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -108,13 +108,26 @@ func TestProvider_DoSynthesize_ConnectionFailure(t *testing.T) {
108108
func TestProvider_SpeechModel(t *testing.T) {
109109
t.Parallel()
110110
p := New()
111-
m := p.SpeechModel("deepgram-tts")
112-
if m.ID != "deepgram-tts" {
111+
m := p.SpeechModel("aura-2-orpheus-en")
112+
if m.ID != "aura-2-orpheus-en" {
113113
t.Errorf("ID = %q", m.ID)
114114
}
115115
m2 := p.SpeechModel("")
116-
if m2.ID != defaultModelID {
117-
t.Errorf("default ID = %q, want %q", m2.ID, defaultModelID)
116+
if m2.ID != defaultVoiceModel {
117+
t.Errorf("default ID = %q, want %q", m2.ID, defaultVoiceModel)
118+
}
119+
}
120+
121+
func TestProvider_ListModels(t *testing.T) {
122+
t.Parallel()
123+
p := New()
124+
125+
models, err := p.ListModels(context.Background())
126+
if err == nil {
127+
t.Fatal("expected unsupported error")
128+
}
129+
if len(models) != 0 {
130+
t.Fatalf("len(models) = %d, want 0", len(models))
118131
}
119132
}
120133

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
package transcription
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"encoding/json"
7+
"fmt"
8+
"io"
9+
"net/http"
10+
"net/url"
11+
"strings"
12+
13+
sdk "github.com/memohai/twilight-ai/sdk"
14+
)
15+
16+
const (
17+
defaultModelID = "nova-3"
18+
defaultBaseURL = "https://api.deepgram.com"
19+
)
20+
21+
type Option func(*Provider)
22+
23+
func WithAPIKey(key string) Option { return func(p *Provider) { p.apiKey = key } }
24+
func WithBaseURL(baseURL string) Option {
25+
return func(p *Provider) { p.baseURL = strings.TrimRight(baseURL, "/") }
26+
}
27+
func WithHTTPClient(hc *http.Client) Option { return func(p *Provider) { p.httpClient = hc } }
28+
29+
type Provider struct {
30+
apiKey string
31+
baseURL string
32+
httpClient *http.Client
33+
}
34+
35+
func New(opts ...Option) *Provider {
36+
p := &Provider{baseURL: defaultBaseURL, httpClient: &http.Client{}}
37+
for _, opt := range opts {
38+
opt(p)
39+
}
40+
return p
41+
}
42+
43+
func (p *Provider) TranscriptionModel(id string) *sdk.TranscriptionModel {
44+
if id == "" {
45+
id = defaultModelID
46+
}
47+
return &sdk.TranscriptionModel{ID: id, Provider: p}
48+
}
49+
50+
func (p *Provider) ListModels(context.Context) ([]*sdk.TranscriptionModel, error) {
51+
return nil, fmt.Errorf("deepgram transcription: provider does not expose a remote models discovery API in this SDK")
52+
}
53+
54+
type audioConfig struct {
55+
Language string
56+
SmartFormat bool
57+
DetectLang bool
58+
Diarize bool
59+
Punctuate bool
60+
}
61+
62+
func parseConfig(cfg map[string]any) audioConfig {
63+
ac := audioConfig{SmartFormat: true, Punctuate: true}
64+
if cfg == nil {
65+
return ac
66+
}
67+
if v, ok := cfg["language"].(string); ok && v != "" {
68+
ac.Language = v
69+
}
70+
if v, ok := cfg["smart_format"].(bool); ok {
71+
ac.SmartFormat = v
72+
}
73+
if v, ok := cfg["detect_language"].(bool); ok {
74+
ac.DetectLang = v
75+
}
76+
if v, ok := cfg["diarize"].(bool); ok {
77+
ac.Diarize = v
78+
}
79+
if v, ok := cfg["punctuate"].(bool); ok {
80+
ac.Punctuate = v
81+
}
82+
return ac
83+
}
84+
85+
func (p *Provider) DoTranscribe(ctx context.Context, params sdk.TranscriptionParams) (*sdk.TranscriptionResult, error) {
86+
cfg := parseConfig(params.Config)
87+
modelID := defaultModelID
88+
if params.Model != nil && params.Model.ID != "" {
89+
modelID = params.Model.ID
90+
}
91+
92+
u, err := url.Parse(p.baseURL + "/v1/listen")
93+
if err != nil {
94+
return nil, fmt.Errorf("deepgram transcription: parse URL: %w", err)
95+
}
96+
q := u.Query()
97+
q.Set("model", modelID)
98+
if cfg.Language != "" {
99+
q.Set("language", cfg.Language)
100+
}
101+
if cfg.SmartFormat {
102+
q.Set("smart_format", "true")
103+
}
104+
if cfg.DetectLang {
105+
q.Set("detect_language", "true")
106+
}
107+
if cfg.Diarize {
108+
q.Set("diarize", "true")
109+
}
110+
if cfg.Punctuate {
111+
q.Set("punctuate", "true")
112+
}
113+
u.RawQuery = q.Encode()
114+
115+
req, err := http.NewRequestWithContext(ctx, http.MethodPost, u.String(), bytes.NewReader(params.Audio))
116+
if err != nil {
117+
return nil, fmt.Errorf("deepgram transcription: build request: %w", err)
118+
}
119+
if params.ContentType != "" {
120+
req.Header.Set("Content-Type", params.ContentType)
121+
} else {
122+
req.Header.Set("Content-Type", "audio/wav")
123+
}
124+
req.Header.Set("Authorization", "Token "+p.apiKey)
125+
126+
resp, err := p.httpClient.Do(req)
127+
if err != nil {
128+
return nil, fmt.Errorf("deepgram transcription: request failed: %w", err)
129+
}
130+
defer resp.Body.Close()
131+
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
132+
body, _ := io.ReadAll(resp.Body)
133+
return nil, fmt.Errorf("deepgram transcription: unexpected status %d: %s", resp.StatusCode, string(body))
134+
}
135+
136+
var payload struct {
137+
Results struct {
138+
Channels []struct {
139+
DetectedLanguage string `json:"detected_language"`
140+
Alternatives []struct {
141+
Transcript string `json:"transcript"`
142+
Words []struct {
143+
Word string `json:"word"`
144+
Start float64 `json:"start"`
145+
End float64 `json:"end"`
146+
Speaker int `json:"speaker"`
147+
} `json:"words"`
148+
} `json:"alternatives"`
149+
} `json:"channels"`
150+
} `json:"results"`
151+
Metadata struct {
152+
Duration float64 `json:"duration"`
153+
} `json:"metadata"`
154+
}
155+
if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
156+
return nil, fmt.Errorf("deepgram transcription: decode response: %w", err)
157+
}
158+
if len(payload.Results.Channels) == 0 || len(payload.Results.Channels[0].Alternatives) == 0 {
159+
return nil, fmt.Errorf("deepgram transcription: empty transcript in response")
160+
}
161+
alt := payload.Results.Channels[0].Alternatives[0]
162+
out := &sdk.TranscriptionResult{
163+
Text: alt.Transcript,
164+
Language: payload.Results.Channels[0].DetectedLanguage,
165+
DurationSeconds: payload.Metadata.Duration,
166+
}
167+
if len(alt.Words) > 0 {
168+
out.Words = make([]sdk.TranscriptionWord, 0, len(alt.Words))
169+
for _, w := range alt.Words {
170+
out.Words = append(out.Words, sdk.TranscriptionWord{
171+
Text: w.Word,
172+
Start: w.Start,
173+
End: w.End,
174+
SpeakerID: fmt.Sprintf("speaker_%d", w.Speaker),
175+
})
176+
}
177+
}
178+
return out, nil
179+
}
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package transcription
2+
3+
import (
4+
"context"
5+
"net/http"
6+
"net/http/httptest"
7+
"testing"
8+
9+
sdk "github.com/memohai/twilight-ai/sdk"
10+
)
11+
12+
func TestProvider_ListModels(t *testing.T) {
13+
t.Parallel()
14+
p := New()
15+
models, err := p.ListModels(context.Background())
16+
if err == nil {
17+
t.Fatal("expected unsupported error")
18+
}
19+
if len(models) != 0 {
20+
t.Fatalf("len(models) = %d, want 0", len(models))
21+
}
22+
}
23+
24+
func TestProvider_DoTranscribe(t *testing.T) {
25+
t.Parallel()
26+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
27+
if r.URL.Path != "/v1/listen" {
28+
t.Fatalf("path = %s", r.URL.Path)
29+
}
30+
_, _ = w.Write([]byte(`{"metadata":{"duration":1.5},"results":{"channels":[{"detected_language":"en","alternatives":[{"transcript":"hello from deepgram","words":[{"word":"hello","start":0,"end":0.3,"speaker":0}]}]}]}}`))
31+
}))
32+
defer srv.Close()
33+
34+
p := New(WithAPIKey("key"), WithBaseURL(srv.URL))
35+
result, err := p.DoTranscribe(context.Background(), sdk.TranscriptionParams{
36+
Model: p.TranscriptionModel("nova-3"),
37+
Audio: []byte("audio"),
38+
Filename: "test.wav",
39+
ContentType: "audio/wav",
40+
})
41+
if err != nil {
42+
t.Fatalf("DoTranscribe: %v", err)
43+
}
44+
if result.Text != "hello from deepgram" {
45+
t.Fatalf("text = %q", result.Text)
46+
}
47+
}

provider/edge/speech/speech.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package speech
22

33
import (
44
"context"
5+
"fmt"
56
"strings"
67

78
sdk "github.com/memohai/twilight-ai/sdk"
@@ -47,6 +48,11 @@ func (p *Provider) SpeechModel(id string) *sdk.SpeechModel {
4748
return &sdk.SpeechModel{ID: id, Provider: p}
4849
}
4950

51+
// ListModels returns the speech models exposed by this provider.
52+
func (p *Provider) ListModels(context.Context) ([]*sdk.SpeechModel, error) {
53+
return nil, fmt.Errorf("edge speech: provider does not expose a remote models discovery API")
54+
}
55+
5056
// DoSynthesize synthesizes speech and returns the complete audio.
5157
func (p *Provider) DoSynthesize(ctx context.Context, params sdk.SpeechParams) (*sdk.SpeechResult, error) {
5258
cfg := parseConfig(params.Config)

0 commit comments

Comments
 (0)