Skip to content

Commit 8fc2362

Browse files
authored
[#genai] bug fixes
1 parent 19fd24b commit 8fc2362

File tree

2 files changed

+43
-6
lines changed

2 files changed

+43
-6
lines changed

genai/message.go

+38-1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,26 @@ type Message struct {
1818
done bool
1919
}
2020

21+
func NewTextMessage(text string, mimeType string) *Message {
22+
return &Message{
23+
rwer: bytes.NewBufferString(text),
24+
mimeType: mimeType,
25+
}
26+
}
27+
28+
func NewBinMessage(data []byte, mimeType string) *Message {
29+
return &Message{
30+
rwer: bytes.NewBuffer(data),
31+
mimeType: mimeType,
32+
}
33+
}
34+
func NewFileMessage(u string, mimeType string) *Message {
35+
return &Message{
36+
u: &url.URL{Path: u},
37+
mimeType: mimeType,
38+
}
39+
}
40+
2141
// Mime returns the MIME type of the message
2242

2343
func (m *Message) Mime() string {
@@ -34,6 +54,11 @@ func (m *Message) Read(p []byte) (n int, err error) {
3454
return m.rwer.Read(p)
3555
}
3656

57+
// ReadFrom implements the io.ReaderFrom interface
58+
func (m *Message) ReadFrom(r io.Reader) (n int64, err error) {
59+
return io.Copy(m.rwer, r)
60+
}
61+
3762
// SetActor sets the actor that sent the message
3863
func (m *Message) SetActor(actor Actor) {
3964
m.msgActor = actor
@@ -49,6 +74,10 @@ func (m *Message) Write(p []byte) (n int, err error) {
4974
return m.rwer.Write(p)
5075
}
5176

77+
func (m *Message) WriteTo(w io.Writer) (n int64, err error) {
78+
return io.Copy(w, m.rwer)
79+
}
80+
5281
// URL returns the URL of the message message
5382
func (b *Message) URL() *url.URL {
5483
return b.u
@@ -68,7 +97,15 @@ func (m *Message) Done() {
6897
func (m *Message) String() string {
6998

7099
switch m.mimeType {
71-
case ioutils.MimeTextPlain, ioutils.MimeTextHTML, ioutils.MimeMarkDown, ioutils.MimeTextYAML:
100+
case ioutils.MimeTextPlain,
101+
ioutils.MimeTextHTML,
102+
ioutils.MimeMarkDown,
103+
ioutils.MimeTextYAML,
104+
ioutils.MimeApplicationJSON,
105+
ioutils.MimeApplicationXML,
106+
ioutils.MimeTextXML,
107+
ioutils.MimeTextCSS,
108+
ioutils.MimeTextCSV:
72109
return m.rwer.(*bytes.Buffer).String()
73110
default:
74111
return fmt.Sprintf("{mimeType: %s, actor: %s}", m.mimeType, m.msgActor)

genai/provider.go

+5-5
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ var GetUnsupportedProviderErr = errutils.NewCustomError("unsupported provider fo
3232

3333
var Providers managers.ItemManager[Provider] = managers.NewItemManager[Provider]()
3434

35-
type StreamingHandller func(message Message, last bool)
35+
type StreamingHandller func(last bool, messages ...*Message)
3636

3737
type Model struct {
3838
Id string `json:"id" yaml:"id" bson:"id"`
@@ -68,8 +68,8 @@ type Provider interface {
6868
// This will be a non-blocking call.
6969
// The handler function will be called for each message generated by the model.
7070
// The last parameter will be true for the last message.
71-
// The exchange will also be updated with the generated messages.
72-
GenerateStream(model string, exchange Exchange, handler StreamingHandller, options Options) error
71+
// The exchange will also be updated with the generated messages once the stream is completed successfully.
72+
GenerateStream(model string, exchange Exchange, handler StreamingHandller, options *Options) error
7373
}
7474

7575
// Options represents the options for the Service
@@ -277,9 +277,9 @@ func (o *Options) GetTemperature(defaultValue float32) float32 {
277277

278278
// GetTopP retrieves the "top_p" option from the Options.
279279
// Returns the value as a float64, or the provided default value if the option does not exist.
280-
func (o *Options) GetTopP(defaultValue float64) float64 {
280+
func (o *Options) GetTopP(defaultValue float32) float32 {
281281
if o.Has(OptionTopP) {
282-
return o.GetFloat64(OptionTopP)
282+
return o.GetFloat32(OptionTopP)
283283
}
284284
return defaultValue
285285
}

0 commit comments

Comments
 (0)