diff --git a/tun/netstack/tun.go b/tun/netstack/tun.go index 596cfcd8a..8c6c894de 100644 --- a/tun/netstack/tun.go +++ b/tun/netstack/tun.go @@ -19,6 +19,7 @@ import ( "regexp" "strconv" "strings" + "sync" "syscall" "time" @@ -42,7 +43,10 @@ import ( type netTun struct { ep *channel.Endpoint stack *stack.Stack + notifyHandle *channel.NotificationHandle events chan tun.Event + pktMu sync.RWMutex + pktClosed bool incomingPacket chan *bufferv2.View mtu int dnsServers []netip.Addr @@ -70,7 +74,7 @@ func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, if tcpipErr != nil { return nil, nil, fmt.Errorf("could not enable TCP SACK: %v", tcpipErr) } - dev.ep.AddNotify(dev) + dev.notifyHandle = dev.ep.AddNotify(dev) tcpipErr = dev.stack.CreateNIC(1, dev.ep) if tcpipErr != nil { return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr) @@ -162,11 +166,16 @@ func (tun *netTun) WriteNotify() { view := pkt.ToView() pkt.DecRef() - tun.incomingPacket <- view + tun.pktMu.RLock() + if !tun.pktClosed { + tun.incomingPacket <- view + } + tun.pktMu.RUnlock() } func (tun *netTun) Close() error { tun.stack.RemoveNIC(1) + tun.ep.RemoveNotify(tun.notifyHandle) if tun.events != nil { close(tun.events) @@ -174,9 +183,10 @@ func (tun *netTun) Close() error { tun.ep.Close() - if tun.incomingPacket != nil { - close(tun.incomingPacket) - } + tun.pktMu.Lock() + tun.pktClosed = true + close(tun.incomingPacket) + tun.pktMu.Unlock() return nil }