Skip to content

Commit cbc913b

Browse files
authored
Update dall-e openai support (#198)
1 parent 4bdb995 commit cbc913b

4 files changed

Lines changed: 111 additions & 68 deletions

File tree

examples/llm/openai/thread/main.go

Lines changed: 31 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -3,72 +3,69 @@ package main
33
import (
44
"context"
55
"fmt"
6+
"strings"
67

78
"github.com/henomis/lingoose/llm/openai"
89
"github.com/henomis/lingoose/thread"
10+
"github.com/henomis/lingoose/transformer"
911
)
1012

11-
type Answer struct {
12-
Answer string `json:"answer" jsonschema:"description=the pirate answer"`
13+
type Image struct {
14+
Description string `json:"description" jsonschema:"description=the description of the image that should be created"`
1315
}
1416

15-
func getAnswer(a Answer) string {
16-
return "🦜 ☠️ " + a.Answer
17+
func crateImage(i Image) string {
18+
d := transformer.NewDallE().WithImageSize(transformer.DallEImageSize512x512)
19+
imageURL, err := d.Transform(context.Background(), i.Description)
20+
if err != nil {
21+
return fmt.Errorf("error creating image: %w", err).Error()
22+
}
23+
24+
fmt.Println("Image created with url:", imageURL)
25+
26+
return imageURL.(string)
1727
}
1828

1929
func newStr(str string) *string {
2030
return &str
2131
}
2232

2333
func main() {
24-
openaillm := openai.New()
25-
openaillm.WithToolChoice(newStr("getPirateAnswer"))
34+
openaillm := openai.New().WithModel(openai.GPT4o)
35+
openaillm.WithToolChoice(newStr("auto"))
2636
err := openaillm.BindFunction(
27-
getAnswer,
28-
"getPirateAnswer",
29-
"use this function to get the pirate answer",
37+
crateImage,
38+
"createImage",
39+
"use this function to create an image from a description",
3040
)
3141
if err != nil {
3242
panic(err)
3343
}
3444

3545
t := thread.New().AddMessage(
3646
thread.NewUserMessage().AddContent(
37-
thread.NewTextContent("Hello, I'm a user"),
38-
).AddContent(
39-
thread.NewTextContent("Can you greet me?"),
40-
),
41-
).AddMessage(
42-
thread.NewUserMessage().AddContent(
43-
thread.NewTextContent("please greet me as a pirate."),
47+
thread.NewTextContent("Please, create an image that inspires you"),
4448
),
4549
)
4650

47-
fmt.Println(t)
48-
4951
err = openaillm.Generate(context.Background(), t)
5052
if err != nil {
5153
panic(err)
5254
}
5355

54-
t.AddMessage(thread.NewUserMessage().AddContent(
55-
thread.NewTextContent("now translate to italian as a poem"),
56-
))
56+
if t.LastMessage().Role == thread.RoleTool {
57+
t.AddMessage(thread.NewUserMessage().AddContent(
58+
thread.NewImageContentFromURL(
59+
strings.ReplaceAll(t.LastMessage().Contents[0].AsToolResponseData().Result, `"`, ""),
60+
),
61+
).AddContent(
62+
thread.NewTextContent("can you describe the image?"),
63+
))
5764

58-
fmt.Println(t)
59-
// disable functions
60-
openaillm.WithToolChoice(nil)
61-
openaillm.WithStream(true, func(a string) {
62-
if a == openai.EOS {
63-
fmt.Printf("\n")
64-
return
65+
err = openaillm.Generate(context.Background(), t)
66+
if err != nil {
67+
panic(err)
6568
}
66-
fmt.Printf("%s", a)
67-
})
68-
69-
err = openaillm.Generate(context.Background(), t)
70-
if err != nil {
71-
panic(err)
7269
}
7370

7471
fmt.Println(t)

examples/transformer/dalle/main.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,19 @@ package main
22

33
import (
44
"context"
5+
"fmt"
56

67
"github.com/henomis/lingoose/transformer"
78
)
89

910
func main() {
1011

11-
d := transformer.NewDallE().WithImageSize(transformer.DallEImageSize1024).AsFile("test.png")
12+
d := transformer.NewDallE().WithImageSize(transformer.DallEImageSize1024x1024)
1213

13-
_, err := d.Transform(context.Background(), "a goose working with pipelines")
14+
imageURL, err := d.Transform(context.Background(), "a goose working with pipelines")
1415
if err != nil {
1516
panic(err)
1617
}
18+
19+
fmt.Println("Image created:", imageURL)
1720
}

llm/openai/function.go

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
package openai
22

33
import (
4+
"bytes"
45
"encoding/json"
56
"errors"
67
"fmt"
78
"reflect"
9+
"strings"
810

911
"github.com/invopop/jsonschema"
1012
"github.com/sashabaranov/go-openai"
@@ -187,11 +189,14 @@ func callFnWithArgumentAsJSON(fn interface{}, argumentAsJSON string) (string, er
187189

188190
// Marshal the function result to JSON
189191
if len(result) > 0 {
190-
jsonResultData, errMarshal := json.Marshal(result[0].Interface())
191-
if errMarshal != nil {
192-
return "", fmt.Errorf("error marshaling result: %w", errMarshal)
192+
var resultBytes bytes.Buffer
193+
enc := json.NewEncoder(&resultBytes)
194+
enc.SetEscapeHTML(false)
195+
err = enc.Encode(result[0].Interface())
196+
if err != nil {
197+
return "", fmt.Errorf("error marshaling result: %w", err)
193198
}
194-
return string(jsonResultData), nil
199+
return strings.TrimSpace(resultBytes.String()), nil
195200
}
196201

197202
return "", nil

transformer/dall-e.go

Lines changed: 66 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@ type DallEImageOutput any
1717
type DallEImageSize string
1818

1919
const (
20-
DallEImageSize256 DallEImageSize = openai.CreateImageSize256x256
21-
DallEImageSize512 DallEImageSize = openai.CreateImageSize512x512
22-
DallEImageSize1024 DallEImageSize = openai.CreateImageSize1024x1024
20+
DallEImageSize256x256 DallEImageSize = openai.CreateImageSize256x256
21+
DallEImageSize512x512 DallEImageSize = openai.CreateImageSize512x512
22+
DallEImageSize1024x1024 DallEImageSize = openai.CreateImageSize1024x1024
23+
DallEImageSize1792x1024 DallEImageSize = openai.CreateImageSize1792x1024
24+
DallEImageSize1024x1792 DallEImageSize = openai.CreateImageSize1024x1792
2325
)
2426

2527
type DallEImageFormat string
@@ -30,19 +32,45 @@ const (
3032
DallEImageFormatImage DallEImageFormat = "image"
3133
)
3234

35+
type DallEModel string
36+
37+
const (
38+
DallEModel2 DallEModel = openai.CreateImageModelDallE2
39+
DallEModel3 DallEModel = openai.CreateImageModelDallE3
40+
)
41+
42+
type DallEImageQuality string
43+
44+
const (
45+
DallEImageQualityHD DallEImageQuality = openai.CreateImageQualityHD
46+
DallEImageQualityStandard DallEImageQuality = openai.CreateImageQualityStandard
47+
)
48+
49+
type DallEImageStyle string
50+
51+
const (
52+
DallEImageStyleVivid DallEImageStyle = openai.CreateImageStyleVivid
53+
DallEImageStyleNatural DallEImageStyle = openai.CreateImageStyleNatural
54+
)
55+
3356
type DallE struct {
3457
openAIClient *openai.Client
58+
model DallEModel
3559
imageSize DallEImageSize
3660
imageFormat DallEImageFormat
37-
imageFile string
61+
imageStyle DallEImageStyle
62+
imageQuality DallEImageQuality
3863
}
3964

4065
func NewDallE() *DallE {
4166
openAIKey := os.Getenv("OPENAI_API_KEY")
4267
return &DallE{
4368
openAIClient: openai.NewClient(openAIKey),
44-
imageSize: DallEImageSize256,
69+
model: DallEModel2,
70+
imageSize: DallEImageSize256x256,
4571
imageFormat: DallEImageFormatURL,
72+
imageStyle: DallEImageStyleNatural,
73+
imageQuality: DallEImageQualityStandard,
4674
}
4775
}
4876

@@ -56,73 +84,83 @@ func (d *DallE) WithImageSize(imageSize DallEImageSize) *DallE {
5684
return d
5785
}
5886

59-
func (d *DallE) AsURL() *DallE {
60-
d.imageFormat = DallEImageFormatURL
87+
func (d *DallE) WithImageStyle(imageStyle DallEImageStyle) *DallE {
88+
d.imageStyle = imageStyle
6189
return d
6290
}
6391

64-
func (d *DallE) AsFile(path string) *DallE {
65-
d.imageFormat = DallEImageFormatFile
66-
d.imageFile = path
92+
func (d *DallE) WithImageQuality(imageQuality DallEImageQuality) *DallE {
93+
d.imageQuality = imageQuality
6794
return d
6895
}
6996

70-
func (d *DallE) AsImage() *DallE {
71-
d.imageFormat = DallEImageFormatImage
97+
func (d *DallE) WithModel(model DallEModel) *DallE {
98+
d.model = model
99+
return d
100+
}
101+
102+
func (d *DallE) WithImageFormat(imageFormat DallEImageFormat) *DallE {
103+
d.imageFormat = imageFormat
72104
return d
73105
}
74106

75107
func (d *DallE) Transform(ctx context.Context, input string) (any, error) {
76108
switch d.imageFormat {
77109
case DallEImageFormatURL:
78-
return d.transformToURL(ctx, input)
110+
return d.TransformAsURL(ctx, input)
79111
case DallEImageFormatFile:
80-
return d.transformToFile(ctx, input)
112+
return d.TransformAsFile(ctx, input, nil)
81113
case DallEImageFormatImage:
82-
return d.transformToImage(ctx, input)
114+
return d.TransformToImage(ctx, input)
83115
default:
84116
return "", fmt.Errorf("unknown image format: %s", d.imageFormat)
85117
}
86118
}
87119

88-
func (d *DallE) transformToURL(ctx context.Context, input string) (any, error) {
120+
func (d *DallE) TransformAsURL(ctx context.Context, input string) (string, error) {
89121
reqURL := openai.ImageRequest{
90122
Prompt: input,
123+
Model: string(d.model),
91124
Size: string(d.imageSize),
125+
Quality: string(d.imageQuality),
126+
Style: string(d.imageStyle),
92127
ResponseFormat: openai.CreateImageResponseFormatURL,
93128
N: 1,
94129
}
95130

96131
respURL, err := d.openAIClient.CreateImage(ctx, reqURL)
97132
if err != nil {
98-
return nil, err
133+
return "", err
99134
}
100135

101136
return respURL.Data[0].URL, nil
102137
}
103138

104-
func (d *DallE) transformToFile(ctx context.Context, input string) (any, error) {
105-
imgData, err := d.transformToImage(ctx, input)
139+
func (d *DallE) TransformAsFile(ctx context.Context, input string, file *os.File) (string, error) {
140+
imgData, err := d.TransformToImage(ctx, input)
106141
if err != nil {
107-
return nil, err
142+
return "", err
108143
}
109144

110-
file, err := os.Create(d.imageFile)
111-
if err != nil {
112-
return nil, err
145+
if file == nil {
146+
// create a temporary file
147+
file, err = os.CreateTemp("", "dall-e-*.png")
148+
if err != nil {
149+
return "", err
150+
}
113151
}
152+
114153
defer file.Close()
115154

116-
err = png.Encode(file, imgData.(image.Image))
155+
err = png.Encode(file, imgData)
117156
if err != nil {
118-
return nil, err
157+
return "", err
119158
}
120159

121-
var output interface{}
122-
return output, nil
160+
return file.Name(), nil
123161
}
124162

125-
func (d *DallE) transformToImage(ctx context.Context, input string) (any, error) {
163+
func (d *DallE) TransformToImage(ctx context.Context, input string) (image.Image, error) {
126164
reqBase64 := openai.ImageRequest{
127165
Prompt: input,
128166
Size: string(d.imageSize),

0 commit comments

Comments
 (0)