Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ type Channel interface {

type MultiChannel interface {
Channel
UnderlayAcceptor
DialUnderlay(factory GroupedUnderlayFactory, underlayType string)
AcceptUnderlay(underlay Underlay) bool
GetUnderlayCountsByType() map[string]int
GetUnderlayHandler() UnderlayHandler
}
Expand Down Expand Up @@ -174,7 +174,7 @@ type DialUnderlayFactory interface {
}

type GroupedUnderlayFactory interface {
CreateGroupedUnderlay(groupId string, underlayType string, timeout time.Duration) (Underlay, error)
CreateGroupedUnderlay(groupId string, groupedSecret []byte, underlayType string, timeout time.Duration) (Underlay, error)
DialFailed(channel MultiChannel, underlayType string, attempt int)
}

Expand Down
10 changes: 10 additions & 0 deletions classic_listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,13 @@ func (self *classicListener) acceptConnection(peer transport.Conn) {
}
}

if isGrouped {
if secret := hello.Headers[GroupSecretHeader]; len(secret) == 0 {
newSecret := uuid.New()
hello.Headers[GroupSecretHeader] = newSecret[:]
}
}

impl.init(hello.IdToken, connectionId, hello.Headers)

if err = self.ackHello(impl, request, true, ""); err == nil {
Expand Down Expand Up @@ -285,6 +292,9 @@ func (self *classicListener) ackHello(impl classicUnderlay, request *Message, su
if underlayType, _ := request.GetStringHeader(TypeHeader); underlayType != "" {
response.PutStringHeader(TypeHeader, underlayType)
}
if groupSecret := request.Headers[GroupSecretHeader]; len(groupSecret) > 0 {
response.Headers[GroupSecretHeader] = groupSecret
}

response.sequence = HelloSequence

Expand Down
1 change: 1 addition & 0 deletions message.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ const (
TypeHeader = 7
IdHeader = 8
IsGroupedHeader = 9
GroupSecretHeader = 10

// Headers in the range 128-255 inclusive will be reflected when creating replies
ReflectedHeaderBitMask = 1 << 7
Expand Down
29 changes: 24 additions & 5 deletions multi.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package channel

import (
"bytes"
"crypto/x509"
"errors"
"fmt"
Expand Down Expand Up @@ -83,6 +84,7 @@ type multiChannelImpl struct {
underlayHandler UnderlayHandler
userData interface{}
replyCounter uint32
groupSecret []byte

lock sync.Mutex
underlays concurrenz.CopyOnWriteSlice[Underlay]
Expand Down Expand Up @@ -111,6 +113,12 @@ func NewMultiChannel(config *MultiChannelConfig) (MultiChannel, error) {
impl.headers.Store(config.Underlay.Headers())
impl.underlays.Append(config.Underlay)

if groupSecret := config.Underlay.Headers()[GroupSecretHeader]; len(groupSecret) == 0 {
return nil, errors.New("no group secret header found for multi channel")
} else {
impl.groupSecret = groupSecret
}

if err := bind(config.BindHandler, impl); err != nil {
for _, u := range impl.underlays.Value() {
if closeErr := u.Close(); closeErr != nil {
Expand All @@ -130,15 +138,23 @@ func NewMultiChannel(config *MultiChannelConfig) (MultiChannel, error) {
return impl, nil
}

func (self *multiChannelImpl) AcceptUnderlay(underlay Underlay) bool {
func (self *multiChannelImpl) AcceptUnderlay(underlay Underlay) error {
self.lock.Lock()
defer self.lock.Unlock()

groupSecret := underlay.Headers()[GroupSecretHeader]
if !bytes.Equal(groupSecret, self.groupSecret) {
if err := underlay.Close(); err != nil {
pfxlog.ContextLogger(self.Label()).WithError(err).Error("error closing underlay")
}
return fmt.Errorf("new underlay for '%s' not accepted: incorrect group secret", self.ConnectionId())
}

if self.IsClosed() {
if err := underlay.Close(); err != nil {
pfxlog.ContextLogger(self.Label()).WithError(err).Error("error closing underlay")
}
return false
return fmt.Errorf("new underlay for '%s' not accepted: multi-channel is closed", self.ConnectionId())
}

self.certs.Store(underlay.Certificates())
Expand All @@ -149,7 +165,7 @@ func (self *multiChannelImpl) AcceptUnderlay(underlay Underlay) bool {

self.underlayHandler.HandleUnderlayAccepted(self, underlay)

return true
return nil
}

func (self *multiChannelImpl) startMultiplex(underlay Underlay) {
Expand Down Expand Up @@ -499,9 +515,12 @@ func (self *multiChannelImpl) DialUnderlay(factory GroupedUnderlayFactory, under
dialTimeout = DefaultConnectTimeout
}

underlay, err := factory.CreateGroupedUnderlay(self.ConnectionId(), underlayType, dialTimeout)
underlay, err := factory.CreateGroupedUnderlay(self.ConnectionId(), self.groupSecret, underlayType, dialTimeout)
if err == nil {
self.AcceptUnderlay(underlay)
if err = self.AcceptUnderlay(underlay); err != nil {
log.WithError(err).Error("dial of new underlay failed")
factory.DialFailed(self, underlayType, attempt)
}
return
} else {
factory.DialFailed(self, underlayType, attempt)
Expand Down
4 changes: 3 additions & 1 deletion multi_listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ func (self *MultiListener) AcceptUnderlay(underlay Underlay) {

if ok {
log.Info("found existing channel for underlay")
mc.AcceptUnderlay(underlay)
if err := mc.AcceptUnderlay(underlay); err != nil {
log.WithError(err).Error("error accepting underlay")
}
} else {
log.Info("no existing channel found for underlay")
var err error
Expand Down
3 changes: 2 additions & 1 deletion multi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -329,10 +329,11 @@ func (self *dialPriorityChannel) DialFailed(_ MultiChannel, _ string, attempt in
time.Sleep(delay)
}

func (self *dialPriorityChannel) CreateGroupedUnderlay(groupId string, underlayType string, timeout time.Duration) (Underlay, error) {
func (self *dialPriorityChannel) CreateGroupedUnderlay(groupId string, groupSecret []byte, underlayType string, timeout time.Duration) (Underlay, error) {
underlay, err := self.dialer.CreateWithHeaders(timeout, map[int32][]byte{
TypeHeader: []byte(underlayType),
ConnectionIdHeader: []byte(groupId),
GroupSecretHeader: groupSecret,
IsGroupedHeader: {1},
})
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion version
Original file line number Diff line number Diff line change
@@ -1 +1 @@
4.0
4.1
Loading