Skip to content

Commit 15da9fa

Browse files
committed
Refactor packet tracking.
Signed-off-by: SuperQ <[email protected]>
1 parent 0caa487 commit 15da9fa

File tree

2 files changed

+121
-40
lines changed

2 files changed

+121
-40
lines changed

packet_tracking.go

+97
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
package probing
2+
3+
import (
4+
"sync"
5+
"time"
6+
7+
"github.com/google/uuid"
8+
)
9+
10+
type PacketTracker struct {
11+
currentUUID uuid.UUID
12+
packets map[uuid.UUID]PacketSequence
13+
sequence int
14+
nextSequence int
15+
timeout time.Duration
16+
timeoutCh chan *inFlightPacket
17+
18+
mutex sync.RWMutex
19+
}
20+
21+
type PacketSequence struct {
22+
packets map[uint]inFlightPacket
23+
}
24+
25+
type inFlightPacket struct {
26+
timeoutTimer *time.Timer
27+
}
28+
29+
func newPacketTracker(t time.Duration) *PacketTracker {
30+
firstUUID := uuid.New()
31+
var firstSequence = map[uuid.UUID]map[int]struct{}{}
32+
firstSequence[firstUUID] = make(map[int]struct{})
33+
34+
return &PacketTracker{
35+
packets: map[uuid.UUID]PacketSequence{},
36+
sequence: 0,
37+
timeout: t,
38+
}
39+
}
40+
41+
func (t *PacketTracker) AddPacket() int {
42+
t.mutex.Lock()
43+
defer t.mutex.Unlock()
44+
45+
if t.nextSequence > 65535 {
46+
newUUID := uuid.New()
47+
t.packets[newUUID] = PacketSequence{}
48+
t.currentUUID = newUUID
49+
t.nextSequence = 0
50+
}
51+
52+
t.sequence = t.nextSequence
53+
t.packets[t.currentUUID][t.sequence] = inFlightPacket{}
54+
// if t.timeout > 0 {
55+
// t.packets[t.currentUUID][t.sequence].timeoutTimer = time.Timer(t.timeout)
56+
// }
57+
t.nextSequence++
58+
return t.sequence
59+
}
60+
61+
// DeletePacket removes a packet from the tracker.
62+
func (t *PacketTracker) DeletePacket(u uuid.UUID, seq int) {
63+
t.mutex.Lock()
64+
defer t.mutex.Unlock()
65+
66+
if t.hasPacket(u, seq) {
67+
if t.packets[u][seq] != nil {
68+
t.packets[u][seq].timeoutTimer.Stop()
69+
}
70+
delete(t.packets[u], seq)
71+
}
72+
}
73+
74+
func (t *PacketTracker) hasPacket(u uuid.UUID, seq int) bool {
75+
_, inflight := t.packets[u][seq]
76+
return inflight
77+
}
78+
79+
// HasPacket checks the tracker to see if it's currently tracking a packet.
80+
func (t *PacketTracker) HasPacket(u uuid.UUID, seq int) bool {
81+
t.mutex.RLock()
82+
defer t.mutex.Unlock()
83+
84+
return t.hasPacket(u, seq)
85+
}
86+
87+
func (t *PacketTracker) HasUUID(u uuid.UUID) bool {
88+
_, hasUUID := t.packets[u]
89+
return hasUUID
90+
}
91+
92+
func (t *PacketTracker) CurrentUUID() uuid.UUID {
93+
t.mutex.RLock()
94+
defer t.mutex.Unlock()
95+
96+
return t.currentUUID
97+
}

ping.go

+24-40
Original file line numberDiff line numberDiff line change
@@ -87,27 +87,22 @@ var (
8787
// New returns a new Pinger struct pointer.
8888
func New(addr string) *Pinger {
8989
r := rand.New(rand.NewSource(getSeed()))
90-
firstUUID := uuid.New()
91-
var firstSequence = map[uuid.UUID]map[int]struct{}{}
92-
firstSequence[firstUUID] = make(map[int]struct{})
9390
return &Pinger{
9491
Count: -1,
9592
Interval: time.Second,
9693
RecordRtts: true,
9794
Size: timeSliceLength + trackerLength,
9895
Timeout: time.Duration(math.MaxInt64),
9996

100-
addr: addr,
101-
done: make(chan interface{}),
102-
id: r.Intn(math.MaxUint16),
103-
trackerUUIDs: []uuid.UUID{firstUUID},
104-
ipaddr: nil,
105-
ipv4: false,
106-
network: "ip",
107-
protocol: "udp",
108-
awaitingSequences: firstSequence,
109-
TTL: 64,
110-
logger: StdLogger{Logger: log.New(log.Writer(), log.Prefix(), log.Flags())},
97+
addr: addr,
98+
done: make(chan interface{}),
99+
id: r.Intn(math.MaxUint16),
100+
ipaddr: nil,
101+
ipv4: false,
102+
network: "ip",
103+
protocol: "udp",
104+
TTL: 64,
105+
logger: StdLogger{Logger: log.New(log.Writer(), log.Prefix(), log.Flags())},
111106
}
112107
}
113108

@@ -143,6 +138,9 @@ type Pinger struct {
143138
// Number of duplicate packets received
144139
PacketsRecvDuplicates int
145140

141+
// Per-packet timeout
142+
PacketTimeout time.Duration
143+
146144
// Round trip time statistics
147145
minRtt time.Duration
148146
maxRtt time.Duration
@@ -189,14 +187,11 @@ type Pinger struct {
189187
ipaddr *net.IPAddr
190188
addr string
191189

192-
// trackerUUIDs is the list of UUIDs being used for sending packets.
193-
trackerUUIDs []uuid.UUID
194-
195190
ipv4 bool
196191
id int
197192
sequence int
198-
// awaitingSequences are in-flight sequence numbers we keep track of to help remove duplicate receipts
199-
awaitingSequences map[uuid.UUID]map[int]struct{}
193+
// tracker is a PacketTrackrer of UUIDs and sequence numbers.
194+
tracker *PacketTracker
200195
// network is one of "ip", "ip4", or "ip6".
201196
network string
202197
// protocol is "icmp" or "udp".
@@ -413,6 +408,9 @@ func (p *Pinger) Run() error {
413408
if err != nil {
414409
return err
415410
}
411+
412+
p.tracker = newPacketTracker(p.PacketTimeout)
413+
416414
if conn, err = p.listen(); err != nil {
417415
return err
418416
}
@@ -615,19 +613,12 @@ func (p *Pinger) getPacketUUID(pkt []byte) (*uuid.UUID, error) {
615613
return nil, fmt.Errorf("error decoding tracking UUID: %w", err)
616614
}
617615

618-
for _, item := range p.trackerUUIDs {
619-
if item == packetUUID {
620-
return &packetUUID, nil
621-
}
616+
if p.tracker.HasUUID(packetUUID) {
617+
return &packetUUID, nil
622618
}
623619
return nil, nil
624620
}
625621

626-
// getCurrentTrackerUUID grabs the latest tracker UUID.
627-
func (p *Pinger) getCurrentTrackerUUID() uuid.UUID {
628-
return p.trackerUUIDs[len(p.trackerUUIDs)-1]
629-
}
630-
631622
func (p *Pinger) processPacket(recv *packet) error {
632623
receivedAt := time.Now()
633624
var proto int
@@ -676,15 +667,15 @@ func (p *Pinger) processPacket(recv *packet) error {
676667
inPkt.Rtt = receivedAt.Sub(timestamp)
677668
inPkt.Seq = pkt.Seq
678669
// If we've already received this sequence, ignore it.
679-
if _, inflight := p.awaitingSequences[*pktUUID][pkt.Seq]; !inflight {
670+
if !p.tracker.HasPacket(*pktUUID, pkt.Seq) {
680671
p.PacketsRecvDuplicates++
681672
if p.OnDuplicateRecv != nil {
682673
p.OnDuplicateRecv(inPkt)
683674
}
684675
return nil
685676
}
686-
// remove it from the list of sequences we're waiting for so we don't get duplicates.
687-
delete(p.awaitingSequences[*pktUUID], pkt.Seq)
677+
// Remove it from the list of sequences we're waiting for so we don't get duplicates.
678+
p.tracker.DeletePacket(*pktUUID, pkt.Seq)
688679
p.updateStatistics(inPkt)
689680
default:
690681
// Very bad, not sure how this can happen
@@ -705,7 +696,7 @@ func (p *Pinger) sendICMP(conn packetConn) error {
705696
dst = &net.UDPAddr{IP: p.ipaddr.IP, Zone: p.ipaddr.Zone}
706697
}
707698

708-
currentUUID := p.getCurrentTrackerUUID()
699+
currentUUID := p.tracker.CurrentUUID()
709700
uuidEncoded, err := currentUUID.MarshalBinary()
710701
if err != nil {
711702
return fmt.Errorf("unable to marshal UUID binary: %w", err)
@@ -753,15 +744,8 @@ func (p *Pinger) sendICMP(conn packetConn) error {
753744
handler(outPkt)
754745
}
755746
// mark this sequence as in-flight
756-
p.awaitingSequences[currentUUID][p.sequence] = struct{}{}
747+
p.sequence = p.tracker.AddPacket()
757748
p.PacketsSent++
758-
p.sequence++
759-
if p.sequence > 65535 {
760-
newUUID := uuid.New()
761-
p.trackerUUIDs = append(p.trackerUUIDs, newUUID)
762-
p.awaitingSequences[newUUID] = make(map[int]struct{})
763-
p.sequence = 0
764-
}
765749
break
766750
}
767751

0 commit comments

Comments
 (0)