diff --git a/multi_listener.go b/multi_listener.go index ceab170..894adc7 100644 --- a/multi_listener.go +++ b/multi_listener.go @@ -22,16 +22,29 @@ import ( ) type MultiChannelFactory func(underlay Underlay, closeCallback func()) (MultiChannel, error) +type UngroupedChannelFallback func(underlay Underlay) error type MultiListener struct { - channels map[string]MultiChannel - lock sync.Mutex - channelFactory MultiChannelFactory + channels map[string]MultiChannel + lock sync.Mutex + multiChannelFactory MultiChannelFactory + ungroupedChannelFallback UngroupedChannelFallback } func (self *MultiListener) AcceptUnderlay(underlay Underlay) { - log := pfxlog.Logger().WithField("underlayId", underlay.ConnectionId()). - WithField("underlayType", GetUnderlayType(underlay)) + isGrouped, _ := Headers(underlay.Headers()).GetBoolHeader(IsGroupedHeader) + + log := pfxlog.Logger(). + WithField("underlayId", underlay.ConnectionId()). + WithField("underlayType", GetUnderlayType(underlay)). + WithField("isGrouped", isGrouped) + + if !isGrouped { + if err := self.ungroupedChannelFallback(underlay); err != nil { + log.WithError(err).Errorf("failed to create channel") + } + return + } chId := underlay.ConnectionId() @@ -47,13 +60,13 @@ func (self *MultiListener) AcceptUnderlay(underlay Underlay) { } else { log.Info("no existing channel found for underlay") var err error - mc, err = self.channelFactory(underlay, func() { + mc, err = self.multiChannelFactory(underlay, func() { self.CloseChannel(chId) }) if mc != nil { if err != nil { - pfxlog.Logger().WithError(err).Errorf("failed to create multi-underlay channel") + log.WithError(err).Errorf("failed to create multi-underlay channel") } else { self.lock.Lock() self.channels[chId] = mc @@ -69,10 +82,11 @@ func (self *MultiListener) CloseChannel(chId string) { self.lock.Unlock() } -func NewMultiListener(channelF MultiChannelFactory) *MultiListener { +func NewMultiListener(channelF MultiChannelFactory, fallback UngroupedChannelFallback) *MultiListener { result := &MultiListener{ - channels: make(map[string]MultiChannel), - channelFactory: channelF, + channels: make(map[string]MultiChannel), + multiChannelFactory: channelF, + ungroupedChannelFallback: fallback, } return result } diff --git a/multi_test.go b/multi_test.go index 3a2c06a..788fbc2 100644 --- a/multi_test.go +++ b/multi_test.go @@ -17,6 +17,7 @@ package channel import ( + "errors" "fmt" "github.com/michaelquigley/pfxlog" "github.com/openziti/foundation/v2/goroutines" @@ -112,6 +113,8 @@ func Test_MultiUnderlayChannels(t *testing.T) { } underlayHandler := NewListenerPriorityChannel(wrapper) return newMultiChannel("listener", underlayHandler, wrapper, closeCallback) + }, func(underlay Underlay) error { + return errors.New("this implementation only accepts grouped channel") }) listenerConfig := ListenerConfig{