Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions cmd/cdpgen/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -527,9 +527,6 @@ func (d *domainClient) %s(ctx context.Context) (%s, error) {
g.Printf(`
type %[4]s struct { rpcc.Stream }

// GetStream returns the original Stream for use with cdp.Sync.
func (c *%[4]s) GetStream() rpcc.Stream { return c.Stream }

func (c *%[4]s) Recv() (*%[3]s, error) {
event := new(%[3]s)
if err := c.RecvMsg(event); err != nil {
Expand Down
4 changes: 3 additions & 1 deletion rpcc/call.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package rpcc

import "context"
import (
"context"
)

type rpcCall struct {
Method string
Expand Down
3 changes: 3 additions & 0 deletions rpcc/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ type Stream interface {
// RecvMsg will return ErrStreamClosing once all pending messages
// have been received.
Close() error
// Sync is used internally to synchronize Streams. It is made
// public here so that Stream synchronization can be mocked.
Sync(store interface{}) (activate func(ok bool) (done func()), err error)
}

// NewStream creates a new stream that listens to notifications from the
Expand Down
84 changes: 49 additions & 35 deletions rpcc/stream_sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
// to be called before loading the next.
type syncMessageStore struct {
mu sync.Mutex
conn *Conn // Used as validation.
writers map[string]streamWriter
backlog []*message
pending bool
Expand All @@ -27,6 +28,12 @@ func (s *syncMessageStore) subscribe(method string, w streamWriter, conn *Conn)
s.mu.Lock()
defer s.mu.Unlock()

if s.conn == nil {
s.conn = conn
} else if conn != s.conn {
return nil, fmt.Errorf("rpcc: same Conn must be used")
}

if _, ok := s.writers[method]; ok {
return nil, fmt.Errorf("%s already subscribed", method)
}
Expand Down Expand Up @@ -126,57 +133,64 @@ func Sync(s ...Stream) (err error) {
}

store := newSyncMessageStore()
var swap []func(bool) func()

defer func() {
// Perform swap, mutex lock (streamClient.mu) is still active.
for _, s := range swap {
defer s(err == nil)()
}
if err != nil {
store.close()
}
}()

var conn *Conn
var swap []func()

for _, ss := range s {
sc, ok := ss.(*streamClient)
if !ok {
return fmt.Errorf("rpcc: Sync: bad Stream type: %T", ss)
}
if conn == nil {
conn = sc.conn
}
if sc.conn != conn {
return errors.New("rpcc: Sync: all Streams must share same Conn")
swapFn, err := ss.Sync(store)
if err != nil {
return err
}
swap = append(swap, swapFn)
}

// The Stream lock must be held until the
// swap has been done for all streams.
sc.mu.Lock()
defer sc.mu.Unlock()

if sc.remove == nil {
return errors.New("rpcc: Sync: Stream is closed")
}
return nil
}

// Allow store to manage messages to streamClient.
unsub, err := store.subscribe(sc.method, sc, sc.conn)
func (s *streamClient) Sync(storer interface{}) (activate func(bool) func(), err error) {
// The Stream lock must be held until the
// swap has been done for all streams.
s.mu.Lock()
defer func() {
if err != nil {
return errors.New("rpcc: Sync: " + err.Error())
s.mu.Unlock()
}
}()

// Delay listener swap until all Streams have been
// processed so that we can abort on error.
swap = append(swap, func() {
sc.remove() // Prevent direct events from Conn.
sc.remove = unsub // Remove from store on Close.
store, ok := storer.(*syncMessageStore)
if !ok {
return nil, fmt.Errorf("streamClient: Sync: bad store %T must be of type *syncMessageStore", storer)
}

// Clear stream messages to prevent sync issues.
sc.mbuf.clear()
})
if s.remove == nil {
return nil, errors.New("rpcc: Sync: Stream is closed")
}

// Perform swap, mutex lock (streamClient.mu) is still active.
for _, s := range swap {
s()
// Allow store to manage messages to streamClient.
unsub, err := store.subscribe(s.method, s, s.conn)
if err != nil {
return nil, errors.New("rpcc: Sync: " + err.Error())
}

return nil
// Delay listener swap until all Streams have been
// processed so that we can abort on error.
return func(ok bool) func() {
if ok {
s.remove() // Prevent direct events from Conn.
s.remove = unsub // Remove from store on Close.

// Clear stream messages to prevent sync issues.
s.mbuf.clear()
}
return func() { s.mu.Unlock() }
}, nil
}
4 changes: 4 additions & 0 deletions rpcc/stream_sync_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package rpcc

import (
"context"
"errors"
"strconv"
"testing"

Expand Down Expand Up @@ -78,6 +79,9 @@ type fakeStream struct{}
func (s *fakeStream) Ready() <-chan struct{} { return nil }
func (s *fakeStream) RecvMsg(m interface{}) error { return nil }
func (s *fakeStream) Close() error { return nil }
func (s *fakeStream) Sync(store interface{}) (func(bool) func(), error) {
return nil, errors.New("fake stream")
}

var (
_ Stream = (*fakeStream)(nil)
Expand Down
22 changes: 2 additions & 20 deletions sync.go
Original file line number Diff line number Diff line change
@@ -1,19 +1,9 @@
package cdp

import (
"fmt"

"github.com/mafredri/cdp/rpcc"
)

type eventClient interface {
rpcc.Stream
}

type getStreamer interface {
GetStream() rpcc.Stream
}

// Sync takes two or more event clients and sets them into synchronous operation,
// relative to each other. This operation cannot be undone. If an error is
// returned this function is no-op and the event clients will continue in
Expand All @@ -30,14 +20,6 @@ type getStreamer interface {
// order of arrival. If an event for both A and B is triggered, in that order,
// it will not be possible to receive the event from B before the event from A
// has been received.
func Sync(c ...eventClient) error {
var s []rpcc.Stream
for _, cc := range c {
cs, ok := cc.(getStreamer)
if !ok {
return fmt.Errorf("cdp: Sync: bad eventClient type: %T", cc)
}
s = append(s, cs.GetStream())
}
return rpcc.Sync(s...)
func Sync(eventClients ...rpcc.Stream) error {
return rpcc.Sync(eventClients...)
}