Skip to content

Commit 2e53c78

Browse files
committed
refactor: move each provider into its own package
Reasoning for this is: 1. Users can import only those they want, and the Go compiler won't compile the external library the user don't need. 2. This simplify the API and makes it follow the Go conventions better: * `ai.NewOpenAiProvider` -> `openai.New` * `ai.WithOpenAiAPIKey` -> `openai.WithAPIKey` * etc.
1 parent 66fbbb0 commit 2e53c78

8 files changed

Lines changed: 259 additions & 259 deletions

File tree

examples/agent/main.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@ import (
66
"os"
77

88
"github.com/charmbracelet/ai"
9-
"github.com/charmbracelet/ai/providers"
9+
"github.com/charmbracelet/ai/providers/openai"
1010
)
1111

1212
func main() {
13-
provider := providers.NewOpenAiProvider(
14-
providers.WithOpenAiAPIKey(os.Getenv("OPENAI_API_KEY")),
13+
provider := openai.New(
14+
openai.WithAPIKey(os.Getenv("OPENAI_API_KEY")),
1515
)
1616
model, err := provider.LanguageModel("gpt-4o")
1717
if err != nil {

examples/simple/main.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@ import (
66
"os"
77

88
"github.com/charmbracelet/ai"
9-
"github.com/charmbracelet/ai/providers"
9+
"github.com/charmbracelet/ai/providers/anthropic"
1010
)
1111

1212
func main() {
13-
provider := providers.NewAnthropicProvider(providers.WithAnthropicAPIKey(os.Getenv("ANTHROPIC_API_KEY")))
13+
provider := anthropic.New(anthropic.WithAPIKey(os.Getenv("ANTHROPIC_API_KEY")))
1414
model, err := provider.LanguageModel("claude-sonnet-4-20250514")
1515
if err != nil {
1616
fmt.Println(err)

examples/stream/main.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@ import (
77
"os"
88

99
"github.com/charmbracelet/ai"
10-
"github.com/charmbracelet/ai/providers"
10+
"github.com/charmbracelet/ai/providers/openai"
1111
)
1212

1313
func main() {
14-
provider := providers.NewOpenAiProvider(providers.WithOpenAiAPIKey(os.Getenv("OPENAI_API_KEY")))
14+
provider := openai.New(openai.WithAPIKey(os.Getenv("OPENAI_API_KEY")))
1515
model, err := provider.LanguageModel("gpt-4o")
1616
if err != nil {
1717
fmt.Println(err)

examples/streaming-agent-simple/main.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import (
66
"os"
77

88
"github.com/charmbracelet/ai"
9-
"github.com/charmbracelet/ai/providers"
9+
"github.com/charmbracelet/ai/providers/openai"
1010
)
1111

1212
func main() {
@@ -18,8 +18,8 @@ func main() {
1818
}
1919

2020
// Create provider and model
21-
provider := providers.NewOpenAiProvider(
22-
providers.WithOpenAiAPIKey(apiKey),
21+
provider := openai.New(
22+
openai.WithAPIKey(apiKey),
2323
)
2424
model, err := provider.LanguageModel("gpt-4o-mini")
2525
if err != nil {

examples/streaming-agent/main.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import (
77
"strings"
88

99
"github.com/charmbracelet/ai"
10-
"github.com/charmbracelet/ai/providers"
10+
"github.com/charmbracelet/ai/providers/anthropic"
1111
)
1212

1313
func main() {
@@ -24,7 +24,7 @@ func main() {
2424
fmt.Println()
2525

2626
// Create OpenAI provider and model
27-
provider := providers.NewAnthropicProvider(providers.WithAnthropicAPIKey(os.Getenv("ANTHROPIC_API_KEY")))
27+
provider := anthropic.New(anthropic.WithAPIKey(os.Getenv("ANTHROPIC_API_KEY")))
2828
model, err := provider.LanguageModel("claude-sonnet-4-20250514")
2929
if err != nil {
3030
fmt.Println(err)
Lines changed: 56 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package providers
1+
package anthropic
22

33
import (
44
"context"
@@ -16,46 +16,46 @@ import (
1616
"github.com/charmbracelet/ai"
1717
)
1818

19-
type AnthropicProviderOptions struct {
20-
SendReasoning *bool `json:"send_reasoning,omitempty"`
21-
Thinking *AnthropicThinkingProviderOption `json:"thinking,omitempty"`
22-
DisableParallelToolUse *bool `json:"disable_parallel_tool_use,omitempty"`
19+
type ProviderOptions struct {
20+
SendReasoning *bool `json:"send_reasoning,omitempty"`
21+
Thinking *ThinkingProviderOption `json:"thinking,omitempty"`
22+
DisableParallelToolUse *bool `json:"disable_parallel_tool_use,omitempty"`
2323
}
2424

25-
type AnthropicThinkingProviderOption struct {
25+
type ThinkingProviderOption struct {
2626
BudgetTokens int64 `json:"budget_tokens"`
2727
}
2828

29-
type AnthropicReasoningMetadata struct {
29+
type ReasoningMetadata struct {
3030
Signature string `json:"signature"`
3131
RedactedData string `json:"redacted_data"`
3232
}
3333

34-
type AnthropicCacheControlProviderOptions struct {
34+
type CacheControlProviderOptions struct {
3535
Type string `json:"type"`
3636
}
37-
type AnthropicFilePartProviderOptions struct {
37+
type FilePartProviderOptions struct {
3838
EnableCitations bool `json:"enable_citations"`
3939
Title string `json:"title"`
4040
Context string `json:"context"`
4141
}
4242

43-
type anthropicProviderOptions struct {
43+
type options struct {
4444
baseURL string
4545
apiKey string
4646
name string
4747
headers map[string]string
4848
client option.HTTPClient
4949
}
5050

51-
type anthropicProvider struct {
52-
options anthropicProviderOptions
51+
type provider struct {
52+
options options
5353
}
5454

55-
type AnthropicOption = func(*anthropicProviderOptions)
55+
type Option = func(*options)
5656

57-
func NewAnthropicProvider(opts ...AnthropicOption) ai.Provider {
58-
options := anthropicProviderOptions{
57+
func New(opts ...Option) ai.Provider {
58+
options := options{
5959
headers: map[string]string{},
6060
}
6161
for _, o := range opts {
@@ -69,42 +69,42 @@ func NewAnthropicProvider(opts ...AnthropicOption) ai.Provider {
6969
options.name = "anthropic"
7070
}
7171

72-
return &anthropicProvider{
72+
return &provider{
7373
options: options,
7474
}
7575
}
7676

77-
func WithAnthropicBaseURL(baseURL string) AnthropicOption {
78-
return func(o *anthropicProviderOptions) {
77+
func WithBaseURL(baseURL string) Option {
78+
return func(o *options) {
7979
o.baseURL = baseURL
8080
}
8181
}
8282

83-
func WithAnthropicAPIKey(apiKey string) AnthropicOption {
84-
return func(o *anthropicProviderOptions) {
83+
func WithAPIKey(apiKey string) Option {
84+
return func(o *options) {
8585
o.apiKey = apiKey
8686
}
8787
}
8888

89-
func WithAnthropicName(name string) AnthropicOption {
90-
return func(o *anthropicProviderOptions) {
89+
func WithName(name string) Option {
90+
return func(o *options) {
9191
o.name = name
9292
}
9393
}
9494

95-
func WithAnthropicHeaders(headers map[string]string) AnthropicOption {
96-
return func(o *anthropicProviderOptions) {
95+
func WithHeaders(headers map[string]string) Option {
96+
return func(o *options) {
9797
maps.Copy(o.headers, headers)
9898
}
9999
}
100100

101-
func WithAnthropicHTTPClient(client option.HTTPClient) AnthropicOption {
102-
return func(o *anthropicProviderOptions) {
101+
func WithHTTPClient(client option.HTTPClient) Option {
102+
return func(o *options) {
103103
o.client = client
104104
}
105105
}
106106

107-
func (a *anthropicProvider) LanguageModel(modelID string) (ai.LanguageModel, error) {
107+
func (a *provider) LanguageModel(modelID string) (ai.LanguageModel, error) {
108108
anthropicClientOptions := []option.RequestOption{}
109109
if a.options.apiKey != "" {
110110
anthropicClientOptions = append(anthropicClientOptions, option.WithAPIKey(a.options.apiKey))
@@ -120,34 +120,34 @@ func (a *anthropicProvider) LanguageModel(modelID string) (ai.LanguageModel, err
120120
if a.options.client != nil {
121121
anthropicClientOptions = append(anthropicClientOptions, option.WithHTTPClient(a.options.client))
122122
}
123-
return anthropicLanguageModel{
124-
modelID: modelID,
125-
provider: fmt.Sprintf("%s.messages", a.options.name),
126-
providerOptions: a.options,
127-
client: anthropic.NewClient(anthropicClientOptions...),
123+
return languageModel{
124+
modelID: modelID,
125+
provider: fmt.Sprintf("%s.messages", a.options.name),
126+
options: a.options,
127+
client: anthropic.NewClient(anthropicClientOptions...),
128128
}, nil
129129
}
130130

131-
type anthropicLanguageModel struct {
132-
provider string
133-
modelID string
134-
client anthropic.Client
135-
providerOptions anthropicProviderOptions
131+
type languageModel struct {
132+
provider string
133+
modelID string
134+
client anthropic.Client
135+
options options
136136
}
137137

138138
// Model implements ai.LanguageModel.
139-
func (a anthropicLanguageModel) Model() string {
139+
func (a languageModel) Model() string {
140140
return a.modelID
141141
}
142142

143143
// Provider implements ai.LanguageModel.
144-
func (a anthropicLanguageModel) Provider() string {
144+
func (a languageModel) Provider() string {
145145
return a.provider
146146
}
147147

148-
func (a anthropicLanguageModel) prepareParams(call ai.Call) (*anthropic.MessageNewParams, []ai.CallWarning, error) {
148+
func (a languageModel) prepareParams(call ai.Call) (*anthropic.MessageNewParams, []ai.CallWarning, error) {
149149
params := &anthropic.MessageNewParams{}
150-
providerOptions := &AnthropicProviderOptions{}
150+
providerOptions := &ProviderOptions{}
151151
if v, ok := call.ProviderOptions["anthropic"]; ok {
152152
err := ai.ParseOptions(v, providerOptions)
153153
if err != nil {
@@ -158,7 +158,7 @@ func (a anthropicLanguageModel) prepareParams(call ai.Call) (*anthropic.MessageN
158158
if providerOptions.SendReasoning != nil {
159159
sendReasoning = *providerOptions.SendReasoning
160160
}
161-
systemBlocks, messages, warnings := toAnthropicPrompt(call.Prompt, sendReasoning)
161+
systemBlocks, messages, warnings := toPrompt(call.Prompt, sendReasoning)
162162

163163
if call.FrequencyPenalty != nil {
164164
warnings = append(warnings, ai.CallWarning{
@@ -235,7 +235,7 @@ func (a anthropicLanguageModel) prepareParams(call ai.Call) (*anthropic.MessageN
235235
if providerOptions.DisableParallelToolUse != nil {
236236
disableParallelToolUse = *providerOptions.DisableParallelToolUse
237237
}
238-
tools, toolChoice, toolWarnings := toAnthropicTools(call.Tools, call.ToolChoice, disableParallelToolUse)
238+
tools, toolChoice, toolWarnings := toTools(call.Tools, call.ToolChoice, disableParallelToolUse)
239239
params.Tools = tools
240240
if toolChoice != nil {
241241
params.ToolChoice = *toolChoice
@@ -246,19 +246,19 @@ func (a anthropicLanguageModel) prepareParams(call ai.Call) (*anthropic.MessageN
246246
return params, warnings, nil
247247
}
248248

249-
func getCacheControl(providerOptions ai.ProviderOptions) *AnthropicCacheControlProviderOptions {
249+
func getCacheControl(providerOptions ai.ProviderOptions) *CacheControlProviderOptions {
250250
if anthropicOptions, ok := providerOptions["anthropic"]; ok {
251251
if cacheControl, ok := anthropicOptions["cache_control"]; ok {
252252
if cc, ok := cacheControl.(map[string]any); ok {
253-
cacheControlOption := &AnthropicCacheControlProviderOptions{}
253+
cacheControlOption := &CacheControlProviderOptions{}
254254
err := ai.ParseOptions(cc, cacheControlOption)
255255
if err != nil {
256256
return cacheControlOption
257257
}
258258
}
259259
} else if cacheControl, ok := anthropicOptions["cacheControl"]; ok {
260260
if cc, ok := cacheControl.(map[string]any); ok {
261-
cacheControlOption := &AnthropicCacheControlProviderOptions{}
261+
cacheControlOption := &CacheControlProviderOptions{}
262262
err := ai.ParseOptions(cc, cacheControlOption)
263263
if err != nil {
264264
return cacheControlOption
@@ -269,9 +269,9 @@ func getCacheControl(providerOptions ai.ProviderOptions) *AnthropicCacheControlP
269269
return nil
270270
}
271271

272-
func getReasoningMetadata(providerOptions ai.ProviderOptions) *AnthropicReasoningMetadata {
272+
func getReasoningMetadata(providerOptions ai.ProviderOptions) *ReasoningMetadata {
273273
if anthropicOptions, ok := providerOptions["anthropic"]; ok {
274-
reasoningMetadata := &AnthropicReasoningMetadata{}
274+
reasoningMetadata := &ReasoningMetadata{}
275275
err := ai.ParseOptions(anthropicOptions, reasoningMetadata)
276276
if err != nil {
277277
return reasoningMetadata
@@ -333,7 +333,7 @@ func groupIntoBlocks(prompt ai.Prompt) []*messageBlock {
333333
return blocks
334334
}
335335

336-
func toAnthropicTools(tools []ai.Tool, toolChoice *ai.ToolChoice, disableParallelToolCalls bool) (anthropicTools []anthropic.ToolUnionParam, anthropicToolChoice *anthropic.ToolChoiceUnionParam, warnings []ai.CallWarning) {
336+
func toTools(tools []ai.Tool, toolChoice *ai.ToolChoice, disableParallelToolCalls bool) (anthropicTools []anthropic.ToolUnionParam, anthropicToolChoice *anthropic.ToolChoiceUnionParam, warnings []ai.CallWarning) {
337337
for _, tool := range tools {
338338
if tool.GetType() == ai.ToolTypeFunction {
339339
ft, ok := tool.(ai.FunctionTool)
@@ -414,7 +414,7 @@ func toAnthropicTools(tools []ai.Tool, toolChoice *ai.ToolChoice, disableParalle
414414
return anthropicTools, anthropicToolChoice, warnings
415415
}
416416

417-
func toAnthropicPrompt(prompt ai.Prompt, sendReasoningData bool) ([]anthropic.TextBlockParam, []anthropic.MessageParam, []ai.CallWarning) {
417+
func toPrompt(prompt ai.Prompt, sendReasoningData bool) ([]anthropic.TextBlockParam, []anthropic.MessageParam, []ai.CallWarning) {
418418
var systemBlocks []anthropic.TextBlockParam
419419
var messages []anthropic.MessageParam
420420
var warnings []ai.CallWarning
@@ -638,7 +638,7 @@ func toAnthropicPrompt(prompt ai.Prompt, sendReasoningData bool) ([]anthropic.Te
638638
return systemBlocks, messages, warnings
639639
}
640640

641-
func (o anthropicLanguageModel) handleError(err error) error {
641+
func (o languageModel) handleError(err error) error {
642642
var apiErr *anthropic.Error
643643
if errors.As(err, &apiErr) {
644644
requestDump := apiErr.DumpRequest(true)
@@ -662,7 +662,7 @@ func (o anthropicLanguageModel) handleError(err error) error {
662662
return err
663663
}
664664

665-
func mapAnthropicFinishReason(finishReason string) ai.FinishReason {
665+
func mapFinishReason(finishReason string) ai.FinishReason {
666666
switch finishReason {
667667
case "end", "stop_sequence":
668668
return ai.FinishReasonStop
@@ -676,7 +676,7 @@ func mapAnthropicFinishReason(finishReason string) ai.FinishReason {
676676
}
677677

678678
// Generate implements ai.LanguageModel.
679-
func (a anthropicLanguageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
679+
func (a languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
680680
params, warnings, err := a.prepareParams(call)
681681
if err != nil {
682682
return nil, err
@@ -746,7 +746,7 @@ func (a anthropicLanguageModel) Generate(ctx context.Context, call ai.Call) (*ai
746746
CacheCreationTokens: response.Usage.CacheCreationInputTokens,
747747
CacheReadTokens: response.Usage.CacheReadInputTokens,
748748
},
749-
FinishReason: mapAnthropicFinishReason(string(response.StopReason)),
749+
FinishReason: mapFinishReason(string(response.StopReason)),
750750
ProviderMetadata: ai.ProviderMetadata{
751751
"anthropic": make(map[string]any),
752752
},
@@ -755,7 +755,7 @@ func (a anthropicLanguageModel) Generate(ctx context.Context, call ai.Call) (*ai
755755
}
756756

757757
// Stream implements ai.LanguageModel.
758-
func (a anthropicLanguageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) {
758+
func (a languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) {
759759
params, warnings, err := a.prepareParams(call)
760760
if err != nil {
761761
return nil, err
@@ -904,7 +904,7 @@ func (a anthropicLanguageModel) Stream(ctx context.Context, call ai.Call) (ai.St
904904
yield(ai.StreamPart{
905905
Type: ai.StreamPartTypeFinish,
906906
ID: acc.ID,
907-
FinishReason: mapAnthropicFinishReason(string(acc.StopReason)),
907+
FinishReason: mapFinishReason(string(acc.StopReason)),
908908
Usage: ai.Usage{
909909
InputTokens: acc.Usage.InputTokens,
910910
OutputTokens: acc.Usage.OutputTokens,

0 commit comments

Comments
 (0)