Skip to content

Commit caf38c1

Browse files
grvsahilGaurav Sahilmaha-hajja
authored
feat: configurable cohere processors destination (#2235)
* feat: configurable cohere processors destination * resolve PR comments * fix: embed test and generated files --------- Co-authored-by: Gaurav Sahil <[email protected]> Co-authored-by: Maha M <[email protected]>
1 parent 9a41a17 commit caf38c1

File tree

11 files changed

+171
-73
lines changed

11 files changed

+171
-73
lines changed

pkg/plugin/processor/builtin/impl/cohere/command.go

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,11 @@ func (p *CommandProcessor) Process(ctx context.Context, records []opencdc.Record
146146
return append(out, sdk.ErrorRecord{Error: fmt.Errorf("failed to resolve reference %v: %w", p.config.RequestBodyRef, err)})
147147
}
148148

149-
content := fmt.Sprintf(p.config.Prompt, p.getInput(requestRef.Get()))
149+
input, err := p.getInput(requestRef.Get())
150+
if err != nil {
151+
return append(out, sdk.ErrorRecord{Error: fmt.Errorf("failed to get input: %w", err)})
152+
}
153+
content := fmt.Sprintf(p.config.Prompt, input)
150154
for {
151155
resp, err := p.client.command(ctx, content)
152156
attempt := p.backoffCfg.Attempt()
@@ -248,14 +252,16 @@ func unmarshalChatResponse(res []byte) (*ChatResponse, error) {
248252
return response, nil
249253
}
250254

251-
func (p *CommandProcessor) getInput(val any) string {
255+
func (p *CommandProcessor) getInput(val any) (string, error) {
252256
switch v := val.(type) {
253-
case opencdc.RawData:
254-
return string(v)
255-
case opencdc.StructuredData:
256-
return string(v.Bytes())
257+
case opencdc.Position:
258+
return string(v), nil
259+
case opencdc.Data:
260+
return string(v.Bytes()), nil
261+
case string:
262+
return v, nil
257263
default:
258-
return fmt.Sprintf("%v", v)
264+
return "", fmt.Errorf("unsupported type %T", v)
259265
}
260266
}
261267

@@ -266,12 +272,12 @@ func (p *CommandProcessor) setField(r *opencdc.Record, refRes *sdk.ReferenceReso
266272

267273
ref, err := refRes.Resolve(r)
268274
if err != nil {
269-
return fmt.Errorf("error reference resolver: %w", err)
275+
return fmt.Errorf("error resolving reference: %w", err)
270276
}
271277

272278
err = ref.Set(data)
273279
if err != nil {
274-
return fmt.Errorf("error reference set: %w", err)
280+
return fmt.Errorf("error setting reference: %w", err)
275281
}
276282

277283
return nil

pkg/plugin/processor/builtin/impl/cohere/command_examples_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ func ExampleCommandProcessor() {
3333
Summary: `Generate responses using Cohere's command model`,
3434
Description: `
3535
This example demonstrates how to use the Cohere command processor to generate responses for a record's ` + "`.Payload.After`" + ` field.
36-
The processor sends the input text to the Cohere API and replaces it with the model's response.`,
36+
The processor sends the input text from the configured "request.body" to the Cohere API and stores the model's response into the configured "response.body"`,
3737
Config: config.Config{
3838
commandProcessorConfigApiKey: "apikey",
3939
commandProcessorConfigPrompt: "hello",

pkg/plugin/processor/builtin/impl/cohere/embed.go

Lines changed: 55 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,9 @@ type embedProcConfig struct {
5353
// The maximum waiting time before retrying.
5454
BackoffRetryMax time.Duration `json:"backoffRetry.max" default:"5s"`
5555
// Specifies the field from which the request body should be created.
56-
InputField string `json:"inputField" validate:"regex=^\\.(Payload|Key).*" default:".Payload.After"`
56+
InputField string `json:"inputField" default:".Payload.After"`
57+
// OutputField specifies which field will the response body be saved at.
58+
OutputField string `json:"outputField" default:".Payload.After"`
5759
// MaxTextsPerRequest controls the number of texts sent in each Cohere embedding API call (max 96)
5860
MaxTextsPerRequest int `json:"maxTextsPerRequest" default:"96"`
5961
}
@@ -75,8 +77,9 @@ type embedModel interface {
7577
type EmbedProcessor struct {
7678
sdk.UnimplementedProcessor
7779

78-
inputFieldRefResolver *sdk.ReferenceResolver
79-
logger log.CtxLogger
80+
inputFieldRefResolver *sdk.ReferenceResolver
81+
outputFieldRefResolver *sdk.ReferenceResolver
82+
logger log.CtxLogger
8083

8184
config embedProcConfig
8285
backoffCfg *backoff.Backoff
@@ -109,10 +112,16 @@ func (p *EmbedProcessor) Configure(ctx context.Context, cfg config.Config) error
109112
func (p *EmbedProcessor) Open(ctx context.Context) error {
110113
inputResolver, err := sdk.NewReferenceResolver(p.config.InputField)
111114
if err != nil {
112-
return cerrors.Errorf(`failed to create a field resolver for %v parameter: %w`, p.config.InputField, err)
115+
return cerrors.Errorf("failed to create a field resolver for %v parameter: %w", p.config.InputField, err)
113116
}
114117
p.inputFieldRefResolver = &inputResolver
115118

119+
outputResolver, err := sdk.NewReferenceResolver(p.config.OutputField)
120+
if err != nil {
121+
return cerrors.Errorf("failed to create a field resolver for %v parameter: %w", p.config.OutputField, err)
122+
}
123+
p.outputFieldRefResolver = &outputResolver
124+
116125
// Initialize the client only if it hasn't been injected
117126
if p.client == nil {
118127
p.client = &embedClient{
@@ -136,12 +145,14 @@ func (p *EmbedProcessor) Specification() (sdk.Specification, error) {
136145
// parameters it expects.
137146

138147
return sdk.Specification{
139-
Name: "cohere.embed",
140-
Summary: "Conduit processor for Cohere's embed model.",
141-
Description: "Conduit processor for Cohere's embed model.",
142-
Version: "v0.1.0",
143-
Author: "Meroxa, Inc.",
144-
Parameters: embedProcConfig{}.Parameters(),
148+
Name: "cohere.embed",
149+
Summary: "Conduit processor for Cohere's embed model.",
150+
Description: "The Cohere embed processor extracts text from the configured inputField, generates embeddings " +
151+
"using Cohere's embedding model, and stores the result in the configured outputField. " +
152+
"The embeddings are compressed using the zstd algorithm for efficient storage and transmission.",
153+
Version: "v0.1.0",
154+
Author: "Meroxa, Inc.",
155+
Parameters: embedProcConfig{}.Parameters(),
145156
}, nil
146157
}
147158

@@ -174,7 +185,12 @@ func (p *EmbedProcessor) processBatch(ctx context.Context, records []opencdc.Rec
174185
if err != nil {
175186
return out, cerrors.Errorf("failed to resolve reference %v: %w", p.config.InputField, err)
176187
}
177-
embeddingInputs = append(embeddingInputs, p.getEmbeddingInput(inRef.Get()))
188+
189+
input, err := p.getInput(inRef.Get())
190+
if err != nil {
191+
return out, cerrors.Errorf("failed to get input: %w", err)
192+
}
193+
embeddingInputs = append(embeddingInputs, input)
178194
}
179195

180196
var embeddings [][]float64
@@ -231,12 +247,8 @@ func (p *EmbedProcessor) processBatch(ctx context.Context, records []opencdc.Rec
231247
return out, cerrors.Errorf("failed to compress embeddings: %w", err)
232248
}
233249

234-
// Store the embedding in .Payload.After
235-
switch record.Payload.After.(type) {
236-
case opencdc.RawData:
237-
record.Payload.After = opencdc.RawData(compressedEmbedding)
238-
case opencdc.StructuredData:
239-
record.Payload.After = opencdc.StructuredData{"embedding": compressedEmbedding}
250+
if err := p.setField(&record, p.outputFieldRefResolver, compressedEmbedding); err != nil {
251+
return out, cerrors.Errorf("failed to set output: %w", err)
240252
}
241253

242254
out = append(out, sdk.SingleRecord(record))
@@ -245,14 +257,34 @@ func (p *EmbedProcessor) processBatch(ctx context.Context, records []opencdc.Rec
245257
return out, nil
246258
}
247259

248-
func (p *EmbedProcessor) getEmbeddingInput(val any) string {
260+
func (p *EmbedProcessor) setField(r *opencdc.Record, refRes *sdk.ReferenceResolver, data any) error {
261+
if refRes == nil {
262+
return nil
263+
}
264+
265+
ref, err := refRes.Resolve(r)
266+
if err != nil {
267+
return cerrors.Errorf("error resolving reference: %w", err)
268+
}
269+
270+
err = ref.Set(data)
271+
if err != nil {
272+
return cerrors.Errorf("error setting reference: %w", err)
273+
}
274+
275+
return nil
276+
}
277+
278+
func (p *EmbedProcessor) getInput(val any) (string, error) {
249279
switch v := val.(type) {
250-
case opencdc.RawData:
251-
return string(v)
252-
case opencdc.StructuredData:
253-
return string(v.Bytes())
280+
case opencdc.Position:
281+
return string(v), nil
282+
case opencdc.Data:
283+
return string(v.Bytes()), nil
284+
case string:
285+
return v, nil
254286
default:
255-
return fmt.Sprintf("%v", v)
287+
return "", fmt.Errorf("unsupported type %T", v)
256288
}
257289
}
258290

pkg/plugin/processor/builtin/impl/cohere/embed_examples_test.go

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,39 +15,69 @@
1515
package cohere
1616

1717
import (
18+
"context"
19+
"fmt"
20+
1821
"github.com/conduitio/conduit-commons/config"
1922
"github.com/conduitio/conduit-commons/opencdc"
2023
sdk "github.com/conduitio/conduit-processor-sdk"
2124
"github.com/conduitio/conduit/pkg/foundation/log"
2225
"github.com/conduitio/conduit/pkg/plugin/processor/builtin/internal/exampleutil"
26+
"github.com/goccy/go-json"
2327
)
2428

2529
func ExampleEmbedProcessor() {
2630
p := NewEmbedProcessor(log.Nop())
2731
p.client = mockEmbedClient{}
2832

33+
embedding, err := p.client.embed(context.Background(), []string{"test input"})
34+
if err != nil {
35+
panic(fmt.Sprintf("failed to get embedding: %v", err))
36+
}
37+
if len(embedding) == 0 {
38+
panic("no embeddings found")
39+
}
40+
41+
embeddingJSON, err := json.Marshal(embedding[0])
42+
if err != nil {
43+
panic(fmt.Sprintf("failed to marshal embeddings: %v", err))
44+
}
45+
46+
// Compress the embedding using zstd
47+
compressedEmbedding, err := compressData(embeddingJSON)
48+
if err != nil {
49+
panic(fmt.Sprintf("failed to compress embeddings: %v", err))
50+
}
51+
2952
exampleutil.RunExample(p, exampleutil.Example{
3053
Summary: `Generate embeddings using Cohere's embedding model`,
3154
Description: `
3255
This example demonstrates how to use the Cohere embedding processor to generate embeddings for a record.
33-
The processor extracts text from the specified input field (default: ".Payload.After"), sends it to the Cohere API,
34-
and stores the resulting embeddings in the record's ".Payload.After" field as compressed data using the zstd algorithm.
56+
The processor extracts text from the configured "inputField" (default: ".Payload.After"), sends it to the Cohere API,
57+
and stores the resulting embeddings in the configured "outputField" as compressed data using the zstd algorithm.
3558
3659
In this example, the processor is configured with a mock client and an API key. The input record's metadata is updated
37-
to include the embedding model used ("embed-english-v2.0"). Note that the compressed embeddings cannot be directly compared
38-
in this test, so the focus is on verifying the metadata update.`,
60+
to include the embedding model used ("embed-english-v2.0").`,
3961
Config: config.Config{
40-
"apiKey": "fake-api-key",
62+
"apiKey": "fake-api-key",
63+
"inputField": ".Payload.After",
64+
"outputField": ".Payload.After",
4165
},
4266
Have: opencdc.Record{
4367
Operation: opencdc.OperationCreate,
4468
Position: opencdc.Position("pos-1"),
4569
Metadata: map[string]string{},
70+
Payload: opencdc.Change{
71+
After: opencdc.RawData("test input"),
72+
},
4673
},
4774
Want: sdk.SingleRecord{
4875
Operation: opencdc.OperationCreate,
4976
Position: opencdc.Position("pos-1"),
5077
Metadata: opencdc.Metadata{"cohere.embed.model": "embed-english-v2.0"},
78+
Payload: opencdc.Change{
79+
After: opencdc.RawData(compressedEmbedding),
80+
},
5181
},
5282
})
5383

@@ -66,7 +96,8 @@ in this test, so the focus is on verifying the metadata update.`,
6696
// "key": null,
6797
// "payload": {
6898
// "before": null,
69-
// "after": null
99+
// - "after": "test input"
100+
// + "after": "(\ufffd/\ufffd\u0004\u0000i\u0000\u0000[0.1,0.2,0.3]\ufffd^xH"
70101
// }
71102
// }
72103
}

pkg/plugin/processor/builtin/impl/cohere/embed_test.go

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -162,9 +162,10 @@ func TestEmbedProcessor_Process(t *testing.T) {
162162
wantErr string
163163
}{
164164
{
165-
name: "successful process with single record",
165+
name: "successful process single record to replace .Payload.After with result of the request",
166166
config: config.Config{
167-
embedProcConfigApiKey: "api-key",
167+
embedProcConfigApiKey: "api-key",
168+
embedProcConfigOutputField: ".Payload.After",
168169
},
169170
records: []opencdc.Record{
170171
{
@@ -201,9 +202,26 @@ func TestEmbedProcessor_Process(t *testing.T) {
201202
wantErr: "",
202203
},
203204
{
204-
name: "successful process with single record having structured data in payload",
205+
name: "failed to process single record to set new field 'response' in .Payload.After having raw data",
205206
config: config.Config{
206-
embedProcConfigApiKey: "api-key",
207+
embedProcConfigApiKey: "api-key",
208+
embedProcConfigOutputField: ".Payload.After.response",
209+
},
210+
records: []opencdc.Record{
211+
{
212+
Payload: opencdc.Change{
213+
After: opencdc.RawData("test input"),
214+
},
215+
Metadata: map[string]string{},
216+
},
217+
},
218+
wantErr: `failed to set output: error resolving reference: could not resolve field "response": .Payload.After does not contain structured data: cannot resolve reference`,
219+
},
220+
{
221+
name: "successful process single record to set new field 'response' in .Payload.After having structured data",
222+
config: config.Config{
223+
embedProcConfigApiKey: "api-key",
224+
embedProcConfigOutputField: ".Payload.After.response",
207225
},
208226
records: []opencdc.Record{
209227
{
@@ -228,7 +246,7 @@ func TestEmbedProcessor_Process(t *testing.T) {
228246
result := []sdk.ProcessedRecord{
229247
sdk.SingleRecord(opencdc.Record{
230248
Payload: opencdc.Change{
231-
After: opencdc.StructuredData{"embedding": compressedEmbedding},
249+
After: opencdc.StructuredData{"test": "testInput", "response": compressedEmbedding},
232250
},
233251
Metadata: map[string]string{
234252
EmbedModelMetadata: "embed-english-v2.0",

pkg/plugin/processor/builtin/impl/cohere/paramgen_embed_proc.go

Lines changed: 8 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)