Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 35 additions & 7 deletions client/grpc/dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,25 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"runtime"
"time"

"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"

"github.com/netbirdio/netbird/util/embeddedroots"
)

// ErrConnectionShutdown indicates that the connection entered shutdown state before becoming ready
var ErrConnectionShutdown = errors.New("connection shutdown before ready")

// Backoff returns a backoff configuration for gRPC calls
func Backoff(ctx context.Context) backoff.BackOff {
b := backoff.NewExponentialBackOff()
Expand All @@ -25,6 +31,26 @@ func Backoff(ctx context.Context) backoff.BackOff {
return backoff.WithContext(b, ctx)
}

// waitForConnectionReady blocks until the connection becomes ready or fails.
// Returns an error if the connection times out, is cancelled, or enters shutdown state.
func waitForConnectionReady(ctx context.Context, conn *grpc.ClientConn) error {
conn.Connect()

state := conn.GetState()
for state != connectivity.Ready && state != connectivity.Shutdown {
if !conn.WaitForStateChange(ctx, state) {
return fmt.Errorf("wait state change from %s: %w", state, ctx.Err())
}
state = conn.GetState()
}

if state == connectivity.Shutdown {
return ErrConnectionShutdown
}

return nil
}

// CreateConnection creates a gRPC client connection with the appropriate transport options.
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
Expand All @@ -42,22 +68,24 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone
}))
}

connCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()

conn, err := grpc.DialContext(
connCtx,
conn, err := grpc.NewClient(
addr,
transportOption,
WithCustomDialer(tlsEnabled, component),
grpc.WithBlock(),
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 30 * time.Second,
Timeout: 10 * time.Second,
}),
)
if err != nil {
log.Printf("DialContext error: %v", err)
return nil, fmt.Errorf("new client: %w", err)
}

ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()

if err := waitForConnectionReady(ctx, conn); err != nil {
_ = conn.Close()
return nil, err
}

Expand Down
3 changes: 1 addition & 2 deletions client/grpc/dialer_generic.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (
nbnet "github.com/netbirdio/netbird/client/net"
)

func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption {
func WithCustomDialer(_ bool, _ string) grpc.DialOption {
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
if runtime.GOOS == "linux" {
currentUser, err := user.Current()
Expand All @@ -36,7 +36,6 @@ func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption {

conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
if err != nil {
log.Errorf("Failed to dial: %s", err)
return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err)
}
return conn, nil
Expand Down
12 changes: 7 additions & 5 deletions client/net/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@ type Conn struct {
ID hooks.ConnectionID
}

// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection
// Close overrides the net.Conn Close method to execute all registered hooks before closing the connection.
// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection.
func (c *Conn) Close() error {
return closeConn(c.ID, c.Conn)
}
Expand All @@ -29,21 +28,24 @@ type TCPConn struct {
ID hooks.ConnectionID
}

// Close overrides the net.TCPConn Close method to execute all registered hooks before closing the connection.
// Close overrides the net.TCPConn Close method to execute all registered hooks after closing the connection.
func (c *TCPConn) Close() error {
return closeConn(c.ID, c.TCPConn)
}

// closeConn is a helper function to close connections and execute close hooks.
func closeConn(id hooks.ConnectionID, conn io.Closer) error {
err := conn.Close()
cleanupConnID(id)
return err
}

// cleanupConnID executes close hooks for a connection ID.
func cleanupConnID(id hooks.ConnectionID) {
closeHooks := hooks.GetCloseHooks()
for _, hook := range closeHooks {
if err := hook(id); err != nil {
log.Errorf("Error executing close hook: %v", err)
}
}

return err
}
1 change: 0 additions & 1 deletion client/net/dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ func DialTCP(network string, laddr, raddr *net.TCPAddr) (transport.TCPConn, erro
}
return &TCPConn{TCPConn: tcpConn, ID: c.ID}, nil
}

if err := conn.Close(); err != nil {
log.Errorf("failed to close connection: %v", err)
}
Expand Down
3 changes: 2 additions & 1 deletion client/net/dialer_dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.

conn, err := d.Dialer.DialContext(ctx, network, address)
if err != nil {
cleanupConnID(connID)
return nil, fmt.Errorf("d.Dialer.DialContext: %w", err)
}

Expand Down Expand Up @@ -64,7 +65,7 @@ func callDialerHooks(ctx context.Context, connID hooks.ConnectionID, address str

ips, err := resolver.LookupIPAddr(ctx, host)
if err != nil {
return fmt.Errorf("failed to resolve address %s: %w", address, err)
return fmt.Errorf("resolve address %s: %w", address, err)
}

log.Debugf("Dialer resolved IPs for %s: %v", address, ips)
Expand Down
4 changes: 2 additions & 2 deletions client/net/listener_listen.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
return c.PacketConn.WriteTo(b, addr)
}

// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection.
// Close overrides the net.PacketConn Close method to execute all registered hooks after closing the connection.
func (c *PacketConn) Close() error {
defer c.seenAddrs.Clear()
return closeConn(c.ID, c.PacketConn)
Expand All @@ -69,7 +69,7 @@ func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
return c.UDPConn.WriteTo(b, addr)
}

// Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection.
// Close overrides the net.UDPConn Close method to execute all registered hooks after closing the connection.
func (c *UDPConn) Close() error {
defer c.seenAddrs.Clear()
return closeConn(c.ID, c.UDPConn)
Expand Down
3 changes: 1 addition & 2 deletions shared/management/client/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,7 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE
var err error
conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.ManagementComponent)
if err != nil {
log.Printf("createConnection error: %v", err)
return err
return fmt.Errorf("create connection: %w", err)
}
return nil
}
Expand Down
3 changes: 1 addition & 2 deletions shared/signal/client/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,7 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo
var err error
conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.SignalComponent)
if err != nil {
log.Printf("createConnection error: %v", err)
return err
return fmt.Errorf("create connection: %w", err)
}
return nil
}
Expand Down
Loading