Skip to content

Commit a22e423

Browse files
committed
feat(clients): add sse subscribe and turn follow
1 parent 6da58dc commit a22e423

17 files changed

Lines changed: 3621 additions & 3 deletions

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ bin/
88
*.dll
99
*.so
1010
*.dylib
11+
!clients/rust/src/bin/
12+
!clients/rust/src/bin/**
1113

1214
# Rust
1315
target/
Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
// Copyright 2025 StrongDM Inc
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package main
5+
6+
import (
7+
"context"
8+
"encoding/json"
9+
"flag"
10+
"fmt"
11+
"os"
12+
"os/signal"
13+
"syscall"
14+
15+
cxdb "github.com/strongdm/ai-cxdb/clients/go"
16+
"github.com/strongdm/ai-cxdb/clients/go/types"
17+
)
18+
19+
type eventOutput struct {
20+
Kind string `json:"kind"`
21+
Type string `json:"type"`
22+
Data json.RawMessage `json:"data"`
23+
}
24+
25+
type turnOutput struct {
26+
Kind string `json:"kind"`
27+
ContextID uint64 `json:"context_id"`
28+
TurnID uint64 `json:"turn_id"`
29+
Depth uint32 `json:"depth"`
30+
DeclaredTypeID string `json:"declared_type_id,omitempty"`
31+
DeclaredTypeVer uint32 `json:"declared_type_version,omitempty"`
32+
Item *types.ConversationItem `json:"item,omitempty"`
33+
DecodeError string `json:"decode_error,omitempty"`
34+
}
35+
36+
func main() {
37+
var (
38+
eventsURL string
39+
binAddr string
40+
follow bool
41+
useTLS bool
42+
clientTag string
43+
maxEvents int
44+
maxTurns int
45+
maxErrors int
46+
)
47+
48+
flag.StringVar(&eventsURL, "cxdb-events-url", "", "CXDB SSE events URL (required)")
49+
flag.StringVar(&binAddr, "cxdb-bin-addr", "", "CXDB binary address (required for --follow-turns)")
50+
flag.BoolVar(&follow, "follow-turns", false, "Follow turns via binary protocol")
51+
flag.BoolVar(&useTLS, "tls", false, "Use TLS for binary protocol")
52+
flag.StringVar(&clientTag, "client-tag", "", "Optional client tag for binary protocol")
53+
flag.IntVar(&maxEvents, "max-events", 0, "Stop after N SSE events (0 = no limit)")
54+
flag.IntVar(&maxTurns, "max-turns", 0, "Stop after N decoded turns (0 = no limit)")
55+
flag.IntVar(&maxErrors, "max-errors", 0, "Stop after N errors (0 = no limit)")
56+
flag.Parse()
57+
58+
if eventsURL == "" {
59+
fmt.Fprintln(os.Stderr, "--cxdb-events-url is required")
60+
os.Exit(2)
61+
}
62+
if follow && binAddr == "" {
63+
fmt.Fprintln(os.Stderr, "--cxdb-bin-addr is required when --follow-turns is set")
64+
os.Exit(2)
65+
}
66+
67+
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
68+
defer cancel()
69+
70+
events, errs := cxdb.SubscribeEvents(ctx, eventsURL)
71+
72+
var client *cxdb.Client
73+
if follow {
74+
var err error
75+
if useTLS {
76+
client, err = cxdb.DialTLS(binAddr, cxdb.WithClientTag(clientTag))
77+
} else {
78+
client, err = cxdb.Dial(binAddr, cxdb.WithClientTag(clientTag))
79+
}
80+
if err != nil {
81+
fmt.Fprintf(os.Stderr, "dial cxdb: %v\n", err)
82+
os.Exit(1)
83+
}
84+
defer client.Close()
85+
}
86+
87+
var eventOut <-chan cxdb.Event = events
88+
var followEvents <-chan cxdb.Event
89+
90+
if follow {
91+
teeOut := make(chan cxdb.Event, 128)
92+
teeFollow := make(chan cxdb.Event, 128)
93+
followEvents = teeFollow
94+
eventOut = teeOut
95+
96+
go func() {
97+
defer close(teeOut)
98+
defer close(teeFollow)
99+
for ev := range events {
100+
select {
101+
case <-ctx.Done():
102+
return
103+
case teeOut <- ev:
104+
}
105+
select {
106+
case <-ctx.Done():
107+
return
108+
case teeFollow <- ev:
109+
}
110+
}
111+
}()
112+
113+
turns, turnErrs := cxdb.FollowTurns(ctx, followEvents, client)
114+
errorCount := consume(ctx, cancel, eventOut, errs, turnErrs, turns, maxEvents, maxTurns, maxErrors)
115+
if maxErrors > 0 && errorCount >= maxErrors {
116+
os.Exit(1)
117+
}
118+
return
119+
}
120+
121+
errorCount := consume(ctx, cancel, eventOut, errs, nil, nil, maxEvents, maxTurns, maxErrors)
122+
if maxErrors > 0 && errorCount >= maxErrors {
123+
os.Exit(1)
124+
}
125+
}
126+
127+
func consume(
128+
ctx context.Context,
129+
cancel context.CancelFunc,
130+
events <-chan cxdb.Event,
131+
errs <-chan error,
132+
turnErrs <-chan error,
133+
turns <-chan cxdb.FollowTurn,
134+
maxEvents int,
135+
maxTurns int,
136+
maxErrors int,
137+
) int {
138+
eventCount := 0
139+
turnCount := 0
140+
errorCount := 0
141+
142+
stopIfDone := func() {
143+
stopOnEvents := maxEvents > 0
144+
stopOnTurns := maxTurns > 0
145+
stopOnErrors := maxErrors > 0
146+
if stopOnErrors && errorCount >= maxErrors {
147+
cancel()
148+
return
149+
}
150+
if (stopOnEvents && eventCount >= maxEvents) || (stopOnTurns && turnCount >= maxTurns) {
151+
if !stopOnEvents || eventCount >= maxEvents {
152+
if !stopOnTurns || turnCount >= maxTurns {
153+
cancel()
154+
}
155+
}
156+
}
157+
}
158+
159+
for {
160+
select {
161+
case <-ctx.Done():
162+
return errorCount
163+
case ev, ok := <-events:
164+
if !ok {
165+
events = nil
166+
break
167+
}
168+
printEvent(ev)
169+
eventCount++
170+
stopIfDone()
171+
case err, ok := <-errs:
172+
if !ok {
173+
errs = nil
174+
break
175+
}
176+
if err != nil {
177+
fmt.Fprintf(os.Stderr, "subscribe error: %v\n", err)
178+
errorCount++
179+
stopIfDone()
180+
}
181+
case err, ok := <-turnErrs:
182+
if !ok {
183+
turnErrs = nil
184+
break
185+
}
186+
if err != nil {
187+
fmt.Fprintf(os.Stderr, "follow error: %v\n", err)
188+
errorCount++
189+
stopIfDone()
190+
}
191+
case turn, ok := <-turns:
192+
if !ok {
193+
turns = nil
194+
break
195+
}
196+
printTurn(turn)
197+
turnCount++
198+
stopIfDone()
199+
}
200+
201+
if events == nil && errs == nil && turns == nil && turnErrs == nil {
202+
return errorCount
203+
}
204+
}
205+
}
206+
207+
func printEvent(ev cxdb.Event) {
208+
out := eventOutput{Kind: "event", Type: ev.Type, Data: ev.Data}
209+
data, err := json.Marshal(out)
210+
if err != nil {
211+
fmt.Fprintf(os.Stderr, "encode event: %v\n", err)
212+
return
213+
}
214+
fmt.Fprintln(os.Stdout, string(data))
215+
}
216+
217+
func printTurn(turn cxdb.FollowTurn) {
218+
result := turnOutput{
219+
Kind: "turn",
220+
ContextID: turn.ContextID,
221+
TurnID: turn.Turn.TurnID,
222+
Depth: turn.Turn.Depth,
223+
DeclaredTypeID: turn.Turn.TypeID,
224+
DeclaredTypeVer: turn.Turn.TypeVersion,
225+
}
226+
227+
if turn.Turn.Encoding != cxdb.EncodingMsgpack {
228+
result.DecodeError = "unsupported encoding"
229+
} else if turn.Turn.Compression != cxdb.CompressionNone {
230+
result.DecodeError = "unsupported compression"
231+
} else {
232+
var item types.ConversationItem
233+
if err := cxdb.DecodeMsgpackInto(turn.Turn.Payload, &item); err != nil {
234+
result.DecodeError = err.Error()
235+
} else {
236+
result.Item = &item
237+
}
238+
}
239+
240+
data, err := json.Marshal(result)
241+
if err != nil {
242+
fmt.Fprintf(os.Stderr, "encode turn: %v\n", err)
243+
return
244+
}
245+
fmt.Fprintln(os.Stdout, string(data))
246+
}

0 commit comments

Comments
 (0)