diff --git a/channel.go b/channel.go index 9c88136..d691159 100644 --- a/channel.go +++ b/channel.go @@ -41,8 +41,8 @@ type Channel interface { type MultiChannel interface { Channel + UnderlayAcceptor DialUnderlay(factory GroupedUnderlayFactory, underlayType string) - AcceptUnderlay(underlay Underlay) bool GetUnderlayCountsByType() map[string]int GetUnderlayHandler() UnderlayHandler } @@ -174,7 +174,7 @@ type DialUnderlayFactory interface { } type GroupedUnderlayFactory interface { - CreateGroupedUnderlay(groupId string, underlayType string, timeout time.Duration) (Underlay, error) + CreateGroupedUnderlay(groupId string, groupedSecret []byte, underlayType string, timeout time.Duration) (Underlay, error) DialFailed(channel MultiChannel, underlayType string, attempt int) } diff --git a/classic_listener.go b/classic_listener.go index 65d79e6..6919535 100644 --- a/classic_listener.go +++ b/classic_listener.go @@ -225,6 +225,13 @@ func (self *classicListener) acceptConnection(peer transport.Conn) { } } + if isGrouped { + if secret := hello.Headers[GroupSecretHeader]; len(secret) == 0 { + newSecret := uuid.New() + hello.Headers[GroupSecretHeader] = newSecret[:] + } + } + impl.init(hello.IdToken, connectionId, hello.Headers) if err = self.ackHello(impl, request, true, ""); err == nil { @@ -285,6 +292,9 @@ func (self *classicListener) ackHello(impl classicUnderlay, request *Message, su if underlayType, _ := request.GetStringHeader(TypeHeader); underlayType != "" { response.PutStringHeader(TypeHeader, underlayType) } + if groupSecret := request.Headers[GroupSecretHeader]; len(groupSecret) > 0 { + response.Headers[GroupSecretHeader] = groupSecret + } response.sequence = HelloSequence diff --git a/message.go b/message.go index 4cb3461..15f0ac8 100644 --- a/message.go +++ b/message.go @@ -48,6 +48,7 @@ const ( TypeHeader = 7 IdHeader = 8 IsGroupedHeader = 9 + GroupSecretHeader = 10 // Headers in the range 128-255 inclusive will be reflected when creating replies ReflectedHeaderBitMask = 1 << 7 diff --git a/multi.go b/multi.go index f12d741..8bb79d4 100644 --- a/multi.go +++ b/multi.go @@ -17,6 +17,7 @@ package channel import ( + "bytes" "crypto/x509" "errors" "fmt" @@ -83,6 +84,7 @@ type multiChannelImpl struct { underlayHandler UnderlayHandler userData interface{} replyCounter uint32 + groupSecret []byte lock sync.Mutex underlays concurrenz.CopyOnWriteSlice[Underlay] @@ -111,6 +113,12 @@ func NewMultiChannel(config *MultiChannelConfig) (MultiChannel, error) { impl.headers.Store(config.Underlay.Headers()) impl.underlays.Append(config.Underlay) + if groupSecret := config.Underlay.Headers()[GroupSecretHeader]; len(groupSecret) == 0 { + return nil, errors.New("no group secret header found for multi channel") + } else { + impl.groupSecret = groupSecret + } + if err := bind(config.BindHandler, impl); err != nil { for _, u := range impl.underlays.Value() { if closeErr := u.Close(); closeErr != nil { @@ -130,15 +138,23 @@ func NewMultiChannel(config *MultiChannelConfig) (MultiChannel, error) { return impl, nil } -func (self *multiChannelImpl) AcceptUnderlay(underlay Underlay) bool { +func (self *multiChannelImpl) AcceptUnderlay(underlay Underlay) error { self.lock.Lock() defer self.lock.Unlock() + groupSecret := underlay.Headers()[GroupSecretHeader] + if !bytes.Equal(groupSecret, self.groupSecret) { + if err := underlay.Close(); err != nil { + pfxlog.ContextLogger(self.Label()).WithError(err).Error("error closing underlay") + } + return fmt.Errorf("new underlay for '%s' not accepted: incorrect group secret", self.ConnectionId()) + } + if self.IsClosed() { if err := underlay.Close(); err != nil { pfxlog.ContextLogger(self.Label()).WithError(err).Error("error closing underlay") } - return false + return fmt.Errorf("new underlay for '%s' not accepted: multi-channel is closed", self.ConnectionId()) } self.certs.Store(underlay.Certificates()) @@ -149,7 +165,7 @@ func (self *multiChannelImpl) AcceptUnderlay(underlay Underlay) bool { self.underlayHandler.HandleUnderlayAccepted(self, underlay) - return true + return nil } func (self *multiChannelImpl) startMultiplex(underlay Underlay) { @@ -499,9 +515,12 @@ func (self *multiChannelImpl) DialUnderlay(factory GroupedUnderlayFactory, under dialTimeout = DefaultConnectTimeout } - underlay, err := factory.CreateGroupedUnderlay(self.ConnectionId(), underlayType, dialTimeout) + underlay, err := factory.CreateGroupedUnderlay(self.ConnectionId(), self.groupSecret, underlayType, dialTimeout) if err == nil { - self.AcceptUnderlay(underlay) + if err = self.AcceptUnderlay(underlay); err != nil { + log.WithError(err).Error("dial of new underlay failed") + factory.DialFailed(self, underlayType, attempt) + } return } else { factory.DialFailed(self, underlayType, attempt) diff --git a/multi_listener.go b/multi_listener.go index 92e33a3..ceab170 100644 --- a/multi_listener.go +++ b/multi_listener.go @@ -41,7 +41,9 @@ func (self *MultiListener) AcceptUnderlay(underlay Underlay) { if ok { log.Info("found existing channel for underlay") - mc.AcceptUnderlay(underlay) + if err := mc.AcceptUnderlay(underlay); err != nil { + log.WithError(err).Error("error accepting underlay") + } } else { log.Info("no existing channel found for underlay") var err error diff --git a/multi_test.go b/multi_test.go index 43184b8..3a2c06a 100644 --- a/multi_test.go +++ b/multi_test.go @@ -329,10 +329,11 @@ func (self *dialPriorityChannel) DialFailed(_ MultiChannel, _ string, attempt in time.Sleep(delay) } -func (self *dialPriorityChannel) CreateGroupedUnderlay(groupId string, underlayType string, timeout time.Duration) (Underlay, error) { +func (self *dialPriorityChannel) CreateGroupedUnderlay(groupId string, groupSecret []byte, underlayType string, timeout time.Duration) (Underlay, error) { underlay, err := self.dialer.CreateWithHeaders(timeout, map[int32][]byte{ TypeHeader: []byte(underlayType), ConnectionIdHeader: []byte(groupId), + GroupSecretHeader: groupSecret, IsGroupedHeader: {1}, }) if err != nil { diff --git a/version b/version index 5186d07..7d5c902 100644 --- a/version +++ b/version @@ -1 +1 @@ -4.0 +4.1