Skip to content

Commit ebd9fa4

Browse files
Copybara Servicecopybara-github
authored andcommitted
Add connections and buffers flush support in polling and streaming clients.
PiperOrigin-RevId: 913782444
1 parent 3c6cac3 commit ebd9fa4

8 files changed

Lines changed: 487 additions & 20 deletions

File tree

fleetspeak/src/client/client.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,3 +310,23 @@ func (c *Client) Stop() {
310310
done <- struct{}{}
311311
log.Info("Messages have been drained.")
312312
}
313+
314+
// FlushServices triggers all services that support it to flush their data.
315+
func (c *Client) FlushServices(ctx context.Context) {
316+
log.InfoContextf(ctx, "Client: FlushServices called")
317+
c.sc.FlushServices(ctx)
318+
}
319+
320+
// ForceResetAndFlushCommunicator programmatically and forcefully tears down the existing
321+
// network connection, establishes a fresh one, and blocks until the outbox is
322+
// successfully drained or the context expires.
323+
func (c *Client) ForceResetAndFlushCommunicator(ctx context.Context) error {
324+
log.InfoContextf(ctx, "ForceResetAndFlushCommunicator triggered")
325+
if c.com == nil {
326+
return fmt.Errorf("communicator not set")
327+
}
328+
log.InfoContextf(ctx, "Calling Reset")
329+
c.com.Reset()
330+
log.InfoContextf(ctx, "Reset finished, calling Flush")
331+
return c.com.Flush(ctx)
332+
}

fleetspeak/src/client/comms/comms.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,16 @@ type Communicator interface {
3939
Start() error // Tells the communicator to start sending and receiving messages.
4040
Stop() // Tells the communicator to stop sending and receiving messages.
4141

42+
// Reset forcefully tears down any underlying network connections and
43+
// immediately initiates establishing a fresh connection (or polling,
44+
// depending on the implementation) to resume communication.
45+
Reset()
46+
47+
// Flush blocks until the communicator has successfully established a
48+
// fresh connection and flushed all pending messages, or until the provided
49+
// context expires.
50+
Flush(ctx context.Context) error
51+
4252
// GetFileIfModified attempts to retrieve a file from a server, if it
4353
// has been modified since modSince. If it has not been modified, it
4454
// returns nil. Otherwise, it returns a ReadCloser for the file's data

fleetspeak/src/client/https/polling.go

Lines changed: 71 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,11 @@ type Communicator struct {
5858
clientCertificateHeader string
5959

6060
certBytes []byte
61+
62+
wakeUp chan struct{}
63+
pollComplete chan error
64+
mu sync.Mutex
65+
pollCancel context.CancelFunc
6166
}
6267

6368
// Setup implements comms.Communicator.
@@ -108,6 +113,8 @@ func (c *Communicator) configure() error {
108113
}
109114
c.ctx, c.done = context.WithCancel(context.Background())
110115
c.clientCertificateHeader = si.ClientCertificateHeader
116+
c.wakeUp = make(chan struct{}, 1)
117+
c.pollComplete = make(chan error, 1)
111118
c.certBytes = certBytes
112119
return nil
113120
}
@@ -127,6 +134,39 @@ func (c *Communicator) Stop() {
127134
c.wd.Stop()
128135
}
129136

137+
// Reset implements comms.Communicator.
138+
func (c *Communicator) Reset() {
139+
log.Infof("Reset called")
140+
c.mu.Lock()
141+
if c.pollCancel != nil {
142+
c.pollCancel()
143+
}
144+
c.mu.Unlock()
145+
c.hc.Transport.(*http.Transport).CloseIdleConnections()
146+
// Drain pollComplete to ensure Flush waits for a new poll.
147+
select {
148+
case <-c.pollComplete:
149+
default:
150+
}
151+
select {
152+
case c.wakeUp <- struct{}{}:
153+
default:
154+
}
155+
}
156+
157+
// Flush implements comms.Communicator.
158+
func (c *Communicator) Flush(ctx context.Context) error {
159+
log.InfoContextf(ctx, "Flush called")
160+
for {
161+
select {
162+
case <-ctx.Done():
163+
return ctx.Err()
164+
case err := <-c.pollComplete:
165+
return err
166+
}
167+
}
168+
}
169+
130170
// processingLoop polls the server according to the configured policies while
131171
// the communicator is active.
132172
//
@@ -152,11 +192,19 @@ func (c *Communicator) processingLoop() {
152192
// and updates the variables defined above. In case of failure it also sleeps
153193
// for the MinFailureDelay.
154194
poll := func() {
195+
var err error
196+
defer func() {
197+
select {
198+
case c.pollComplete <- err:
199+
default:
200+
}
201+
}()
155202
c.wd.Reset()
156203
if c.cctx.CurrentID() != c.id {
157204
c.configure()
158205
}
159-
active, err := c.poll(toSend)
206+
var active bool
207+
active, err = c.poll(toSend)
160208
if err != nil {
161209
log.Warningf("Failure during polling: %v", err)
162210
for _, m := range toSend {
@@ -175,6 +223,8 @@ func (c *Communicator) processingLoop() {
175223
case <-t.C:
176224
case <-c.ctx.Done():
177225
t.Stop()
226+
case <-c.wakeUp:
227+
t.Stop()
178228
}
179229
return
180230
}
@@ -255,6 +305,9 @@ func (c *Communicator) processingLoop() {
255305
return
256306
case <-t.C:
257307
poll()
308+
case <-c.wakeUp:
309+
t.Stop()
310+
poll()
258311
case m := <-c.cctx.Outbox():
259312
t.Stop()
260313
toSend = append(toSend, m)
@@ -324,12 +377,16 @@ func (c *Communicator) pollHost(host string, data []byte) (*fspb.ContactData, er
324377

325378
u := url.URL{Scheme: "https", Host: host, Path: "/message"}
326379

380+
c.mu.Lock()
381+
compression := c.conf.GetCompression()
382+
c.mu.Unlock()
383+
327384
body := &bytes.Buffer{}
328-
if c.conf.GetCompression() == fspb.CompressionAlgorithm_COMPRESSION_NONE {
385+
if compression == fspb.CompressionAlgorithm_COMPRESSION_NONE {
329386
// Shortcut to prevent unnecessary copying of data
330387
body = bytes.NewBuffer(data)
331388
} else {
332-
bw := CompressingWriter(body, c.conf.GetCompression())
389+
bw := CompressingWriter(body, compression)
333390
bw.Write(data)
334391
bw.Close()
335392
}
@@ -339,6 +396,17 @@ func (c *Communicator) pollHost(host string, data []byte) (*fspb.ContactData, er
339396
if sendErr != nil {
340397
return nil, sendErr
341398
}
399+
var reqCtx context.Context
400+
c.mu.Lock()
401+
reqCtx, c.pollCancel = context.WithCancel(c.ctx)
402+
c.mu.Unlock()
403+
defer func() {
404+
c.mu.Lock()
405+
c.pollCancel()
406+
c.pollCancel = nil
407+
c.mu.Unlock()
408+
}()
409+
req = req.WithContext(reqCtx)
342410
SetContentEncoding(req.Header, c.conf.GetCompression())
343411
if c.clientCertificateHeader != "" {
344412
bc := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: c.certBytes})

fleetspeak/src/client/https/polling_test.go

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -733,3 +733,146 @@ func TestCertificateRevoked(t *testing.T) {
733733
tl.Close()
734734
cl.Stop()
735735
}
736+
737+
func TestReset(t *testing.T) {
738+
var c Communicator
739+
conf := config.Configuration{
740+
Servers: []string{"localhost:1234"},
741+
CommunicatorConfig: &clpb.CommunicatorConfig{
742+
MaxPollDelaySeconds: 10,
743+
MinFailureDelaySeconds: 10,
744+
},
745+
}
746+
dialCalls := make(chan struct{}, 10)
747+
c.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
748+
dialCalls <- struct{}{}
749+
<-ctx.Done()
750+
return nil, ctx.Err()
751+
}
752+
753+
cl, err := client.New(
754+
conf,
755+
client.Components{
756+
Communicator: &c,
757+
})
758+
if err != nil {
759+
t.Fatalf("unable to create client: %v", err)
760+
}
761+
defer cl.Stop()
762+
763+
// Wait for first dial attempt
764+
select {
765+
case <-dialCalls:
766+
case <-time.After(5 * time.Second):
767+
t.Fatal("Timed out waiting for first dial attempt")
768+
}
769+
770+
// Call Reset
771+
time.Sleep(100 * time.Millisecond)
772+
c.Reset()
773+
774+
// Wait for second dial attempt (should be immediate, not waiting 10s)
775+
select {
776+
case <-dialCalls:
777+
case <-time.After(2 * time.Second):
778+
t.Fatal("Timed out waiting for second dial attempt after Reset")
779+
}
780+
}
781+
782+
func TestFlush(t *testing.T) {
783+
// Create a local https server for the client to talk to.
784+
pemCert, pemKey, err := common_util.ServerCert()
785+
if err != nil {
786+
t.Fatal(err)
787+
}
788+
cp, err := tls.X509KeyPair(pemCert, pemKey)
789+
if err != nil {
790+
t.Fatal(err)
791+
}
792+
ad, err := net.ResolveTCPAddr("tcp", "localhost:0")
793+
if err != nil {
794+
t.Fatal(err)
795+
}
796+
tl, err := net.ListenTCP("tcp", ad)
797+
if err != nil {
798+
t.Fatal(err)
799+
}
800+
addr := tl.Addr().String()
801+
802+
pollCount := int32(0)
803+
mux := http.NewServeMux()
804+
mux.HandleFunc("/message", func(res http.ResponseWriter, req *http.Request) {
805+
atomic.AddInt32(&pollCount, 1)
806+
cd := fspb.ContactData{
807+
SequencingNonce: 42,
808+
}
809+
buf, err := proto.Marshal(&cd)
810+
if err != nil {
811+
t.Fatalf("unable to marshal ContactData: %v", err)
812+
}
813+
res.Header().Set("Content-Type", "application/octet-stream")
814+
res.WriteHeader(http.StatusOK)
815+
res.Write(buf)
816+
})
817+
818+
server := http.Server{
819+
Addr: addr,
820+
Handler: mux,
821+
TLSConfig: &tls.Config{
822+
ClientAuth: tls.RequireAnyClientCert,
823+
Certificates: []tls.Certificate{cp},
824+
NextProtos: []string{"h2"},
825+
},
826+
}
827+
l := tls.NewListener(tl, server.TLSConfig)
828+
go server.Serve(l)
829+
defer tl.Close()
830+
831+
var c Communicator
832+
conf := config.Configuration{
833+
TrustedCerts: x509.NewCertPool(),
834+
Servers: []string{addr},
835+
CommunicatorConfig: &clpb.CommunicatorConfig{
836+
MaxPollDelaySeconds: 10,
837+
MinFailureDelaySeconds: 10,
838+
},
839+
}
840+
if !conf.TrustedCerts.AppendCertsFromPEM(pemCert) {
841+
t.Fatal("unable to add server cert to pool")
842+
}
843+
844+
cl, err := client.New(
845+
conf,
846+
client.Components{
847+
Communicator: &c,
848+
})
849+
if err != nil {
850+
t.Fatalf("unable to create client: %v", err)
851+
}
852+
defer cl.Stop()
853+
854+
// Wait for first poll
855+
for atomic.LoadInt32(&pollCount) == 0 {
856+
time.Sleep(100 * time.Millisecond)
857+
}
858+
859+
// Call Flush in a goroutine
860+
flushDone := make(chan error, 1)
861+
go func() {
862+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
863+
defer cancel()
864+
flushDone <- c.Flush(ctx)
865+
}()
866+
867+
// Trigger a poll via wakeUp (indirectly via Reset)
868+
c.Reset()
869+
870+
select {
871+
case err := <-flushDone:
872+
if err != nil {
873+
t.Errorf("Flush failed: %v", err)
874+
}
875+
case <-time.After(6 * time.Second):
876+
t.Fatal("Timed out waiting for Flush")
877+
}
878+
}

0 commit comments

Comments
 (0)