-
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmain.go
More file actions
133 lines (115 loc) · 2.97 KB
/
Copy pathmain.go
File metadata and controls
133 lines (115 loc) · 2.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
// Text-generation demonstrates text generation with provider selection.
// It can run with OpenAI, Gemini, or both providers in parallel to compare results.
//
// Usage:
//
//	go run examples/text-generation/main.go
//	go run examples/text-generation/main.go -openai
//	go run examples/text-generation/main.go -gemini
//	go run examples/text-generation/main.go -openai -gemini -debug
package main
import (
"context"
"flag"
"fmt"
"log"
"os"
"sync"
"log/slog"
"github.com/montanaflynn/grail"
"github.com/montanaflynn/grail/providers/gemini"
"github.com/montanaflynn/grail/providers/openai"
)
// main parses the provider flags, runs every selected provider concurrently,
// and reports each provider's outcome as it arrives: generated text is printed
// to standard output, while errors and empty responses go through the standard
// logger.
func main() {
	useOpenAI := flag.Bool("openai", false, "use OpenAI provider")
	useGemini := flag.Bool("gemini", false, "use Gemini provider")
	verbose := flag.Bool("debug", false, "enable debug logging")
	flag.Parse()

	// Info level unless -debug was given.
	handlerOpts := &slog.HandlerOptions{Level: slog.LevelInfo}
	if *verbose {
		handlerOpts.Level = slog.LevelDebug
	}
	logger := slog.New(slog.NewTextHandler(os.Stdout, handlerOpts))

	ctx := context.Background()

	// outcome carries one provider's result back to the printing loop.
	type outcome struct {
		provider string
		text     string
		err      error
	}

	// target pairs a provider name with the environment variable holding its key.
	type target struct {
		name   string
		envKey string
	}

	var targets []target
	// Gemini runs when requested explicitly, and also as the default when
	// -openai was not given.
	if *useGemini || !*useOpenAI {
		targets = append(targets, target{name: "gemini", envKey: "GEMINI_API_KEY"})
	}
	if *useOpenAI {
		targets = append(targets, target{name: "openai", envKey: "OPENAI_API_KEY"})
	}

	// At most two providers ever send, so a buffer of 2 lets every goroutine
	// finish its send without waiting for the reader.
	outcomes := make(chan outcome, 2)

	var pending sync.WaitGroup
	pending.Add(len(targets))
	for _, tg := range targets {
		go func(tg target) {
			defer pending.Done()
			text, err := generateWithProvider(ctx, logger, tg.name, tg.envKey)
			outcomes <- outcome{provider: tg.name, text: text, err: err}
		}(tg)
	}

	// Close the channel once all providers have reported so the loop below ends.
	go func() {
		pending.Wait()
		close(outcomes)
	}()

	for out := range outcomes {
		switch {
		case out.err != nil:
			log.Printf("%s: generate text error: %v", out.provider, out.err)
		case out.text == "":
			log.Printf("%s: empty text response", out.provider)
		default:
			fmt.Printf("[%s] %s\n", out.provider, out.text)
		}
	}
}
// generateWithProvider constructs the provider named by providerName, using
// the API key read from the environment variable envKey, wraps it in a grail
// client configured with logger, and returns the result of generateText for
// that client.
//
// Recognized provider names are "gemini" and "openai"; any other name yields
// an "unknown provider" error. A failure from the provider constructor is
// wrapped with the provider name.
func generateWithProvider(ctx context.Context, logger *slog.Logger, providerName, envKey string) (string, error) {
	var (
		backend grail.Provider
		initErr error
	)

	apiKey := os.Getenv(envKey)
	switch providerName {
	case "openai":
		backend, initErr = openai.New(openai.WithAPIKey(apiKey))
	case "gemini":
		backend, initErr = gemini.New(ctx, gemini.WithAPIKey(apiKey))
	default:
		return "", fmt.Errorf("unknown provider %q", providerName)
	}
	if initErr != nil {
		return "", fmt.Errorf("new %s provider: %w", providerName, initErr)
	}

	return generateText(ctx, grail.NewClient(backend, grail.WithLogger(logger)))
}
// generateText sends one fixed text prompt to client, requesting text output,
// and returns the text taken from the response.
//
// An error from client.Generate is returned unchanged together with an empty
// string.
func generateText(ctx context.Context, client grail.Client) (string, error) {
	req := grail.Request{
		Inputs: []grail.Input{
			grail.InputText("Explain how AI works in a few words"),
		},
		Output: grail.OutputText(),
	}

	resp, err := client.Generate(ctx, req)
	if err != nil {
		return "", err
	}

	// The second result of Text is discarded here.
	// NOTE(review): confirm what it reports (its type is not visible in this
	// file) and whether it should be surfaced to the caller.
	out, _ := resp.Text()
	return out, nil
}