-
Notifications
You must be signed in to change notification settings - Fork 69
/
Copy pathlockduplicaterequest.go
174 lines (149 loc) · 3.27 KB
/
lockduplicaterequest.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
package rdns
import (
"encoding/binary"
"github.com/miekg/dns"
"sync"
"sync/atomic"
"time"
)
// key identifies a group of queries that Resolve serializes together: the
// same question name and type, and the same EDNS Client Subnet (ECS) address
// and source prefix length. ECS fields that do not apply stay zero.
type key struct {
	name  string
	qtype uint16

	// ECS address: IPv4 as one big-endian word, or IPv6 split into two
	// big-endian 64-bit halves, followed by the source prefix length.
	ecs_ipv4                 uint32
	ecs_ipv6_hi, ecs_ipv6_lo uint64
	ecs_mask                 uint8
}
// value is the state shared by the queries that map to the same key in
// lockDuplicateRequest.m.
type value struct {
	// expiredTimeStamp is a Unix time in seconds; Resolve sets it to five
	// seconds after the entry is created. Resolve marks the entry as expired
	// by replacing it with its bitwise complement, which is negative. A
	// negative value makes Resolve return QueryTimeoutError instead of
	// forwarding the query (per the original note, this is meant to surface
	// as SERVFAIL — presumably converted by the caller; verify).
	expiredTimeStamp int64
	// mu serializes the queries that share this entry.
	mu *newMutex
}
// lockDuplicateRequest wraps another Resolver and serializes concurrent
// queries that share a key, so identical queries are forwarded to the
// wrapped resolver one at a time instead of all at once.
type lockDuplicateRequest struct {
	id       string   // identifier returned by String
	resolver Resolver // resolver that queries are forwarded to
	m        sync.Map // maps each key to the shared state of its queries
}

// Compile-time check that *lockDuplicateRequest implements Resolver.
var _ Resolver = (*lockDuplicateRequest)(nil)
// NewlockDuplicateRequest returns a resolver that forwards queries to
// resolver while serializing concurrent queries with the same key. id is the
// value reported by String.
func NewlockDuplicateRequest(id string, resolver Resolver) *lockDuplicateRequest {
	return &lockDuplicateRequest{
		id:       id,
		resolver: resolver,
	}
}
// Resolve forwards q to the wrapped resolver.
//
// Queries with exactly one question are grouped by ldrKeyFromQuery. Queries
// in the same group are forwarded one at a time, each waiting on the group's
// mutex for the previous one to finish. Queries with zero or several
// questions are forwarded directly.
//
// A group gets a deadline five seconds after it is created. If a query
// finishes after that deadline while others are still queued, the group is
// removed from the map and marked expired. Queued queries then return
// QueryTimeoutError instead of being forwarded, and new queries start a
// fresh group.
func (r *lockDuplicateRequest) Resolve(q *dns.Msg, ci ClientInfo) (*dns.Msg, error) {
	// groupTimeout bounds how long a group keeps queueing queries.
	const groupTimeout = 5 * time.Second

	if len(q.Question) != 1 {
		return r.resolver.Resolve(q, ci)
	}
	k := ldrKeyFromQuery(q)

	// Store the group by pointer so the expiry mark written below is seen by
	// every goroutine that loaded this group. When the struct was stored by
	// value, each goroutine marked only its private copy and the mark was
	// lost. expiredTimeStamp is the first field of a heap-allocated value,
	// so it is 64-bit aligned, as sync/atomic requires on 32-bit platforms.
	loaded, _ := r.m.LoadOrStore(k, &value{
		expiredTimeStamp: time.Now().Add(groupTimeout).Unix(),
		mu:               &newMutex{},
	})
	v := loaded.(*value)

	var (
		answer *dns.Msg
		err    error
	)
	v.mu.Lock()
	// The expiry mark is written by a goroutine that no longer holds v.mu,
	// so every access to expiredTimeStamp is atomic.
	if atomic.LoadInt64(&v.expiredTimeStamp) >= 0 {
		answer, err = r.resolver.Resolve(q, ci)
	} else {
		err = QueryTimeoutError{q}
	}
	v.mu.UnLock()

	deadline := atomic.LoadInt64(&v.expiredTimeStamp)
	switch {
	case v.mu.Count() == 0:
		// Count is expected to report goroutines still waiting on or holding
		// v.mu, so zero means no other query is using this group.
		// NOTE: Delete is unconditional. If k already maps to a newer group,
		// that group is dropped from the map. Its members keep their own
		// pointer, so this only costs deduplication, not correctness.
		r.m.Delete(k)
	case deadline >= 0 && deadline < time.Now().Unix():
		// Others are still queued but the deadline has passed. The CAS lets
		// exactly one goroutine mark the group (the complement of a
		// non-negative stamp is negative) and release the queue.
		if atomic.CompareAndSwapInt64(&v.expiredTimeStamp, deadline, ^deadline) {
			r.m.Delete(k)
			v.mu.UnLockAll()
		}
	}
	return answer, err
}

// ldrKeyFromQuery builds the grouping key for a single-question query. The
// key holds the question name and type, plus the address and source prefix
// length of any EDNS Client Subnet option. An ECS option whose address cannot
// be represented in its declared family is ignored instead of causing a
// panic. If several ECS options are present, later ones overwrite earlier
// ones.
func ldrKeyFromQuery(q *dns.Msg) key {
	k := key{
		name:  q.Question[0].Name,
		qtype: q.Question[0].Qtype,
	}
	opt := q.IsEdns0()
	if opt == nil {
		return k
	}
	for _, o := range opt.Option {
		ecs, ok := o.(*dns.EDNS0_SUBNET)
		if !ok {
			continue
		}
		switch ecs.Family {
		case 1: // IPv4
			if ip := ecs.Address.To4(); ip != nil {
				k.ecs_ipv4 = ldrByteToUint32(ip)
				k.ecs_mask = ecs.SourceNetmask
			}
		case 2: // IPv6
			if ip := ecs.Address.To16(); ip != nil {
				k.ecs_ipv6_hi, k.ecs_ipv6_lo = ldrByteToUint128(ip)
				k.ecs_mask = ecs.SourceNetmask
			}
		}
	}
	return k
}
// String implements fmt.Stringer by reporting the identifier the resolver
// was constructed with.
func (r *lockDuplicateRequest) String() string { return r.id }
// ldrByteToUint128 splits a 16-byte big-endian value (such as an IPv6
// address) into its high and low 64-bit halves. Like binary.BigEndian, it
// panics if b is too short.
func ldrByteToUint128(b []byte) (hi, lo uint64) {
	upper, lower := b[:8], b[8:16]
	hi = binary.BigEndian.Uint64(upper)
	lo = binary.BigEndian.Uint64(lower)
	return hi, lo
}
// ldrByteToUint32 reads the first four bytes of b (such as an IPv4 address)
// as a big-endian uint32. Like binary.BigEndian, it panics if b is too short.
func ldrByteToUint32(b []byte) uint32 {
	word := b[:4]
	return binary.BigEndian.Uint32(word)
}
// newMutex is a mutual-exclusion lock that counts the goroutines using it and
// can be permanently released so that every current and future waiter gets
// through.
//
// Lock/UnLock behave like sync.Mutex until UnLockAll is called. After that,
// Lock never blocks and the mutex no longer provides mutual exclusion.
// The zero value is an unlocked mutex ready for use.
type newMutex struct {
	// count is the number of goroutines between Lock and UnLock, whether
	// still waiting or holding the lock. It is accessed only atomically.
	count uint32

	initOnce    sync.Once
	releaseOnce sync.Once
	token       chan struct{} // capacity 1; a buffered token means "locked"
	released    chan struct{} // closed by UnLockAll
}

// init lazily allocates the channels so the zero value is usable.
func (m *newMutex) init() {
	m.initOnce.Do(func() {
		m.token = make(chan struct{}, 1)
		m.released = make(chan struct{})
	})
}

// Count reports how many goroutines are waiting in Lock or holding the lock.
func (m *newMutex) Count() uint32 {
	return atomic.LoadUint32(&m.count)
}

// Lock blocks until the lock is acquired or UnLockAll has been called. The
// caller is counted before it starts waiting, so Count includes waiters.
func (m *newMutex) Lock() {
	m.init()
	atomic.AddUint32(&m.count, 1)
	select {
	case m.token <- struct{}{}:
	case <-m.released:
	}
}

// UnLock releases a lock taken with Lock. A call with no outstanding Lock is
// a no-op rather than a fatal "unlock of unlocked mutex".
func (m *newMutex) UnLock() {
	for {
		c := atomic.LoadUint32(&m.count)
		if c == 0 {
			return
		}
		if atomic.CompareAndSwapUint32(&m.count, c, c-1) {
			break
		}
	}
	m.init()
	select {
	case <-m.token:
	default:
		// The caller got through via UnLockAll without taking the token, or
		// another released caller already returned it.
	}
}

// UnLockAll permanently releases the mutex. Goroutines blocked in Lock
// return immediately, and so do all later calls to Lock. Unlike unlocking a
// sync.Mutex repeatedly, this never panics, and it is safe to call more than
// once.
func (m *newMutex) UnLockAll() {
	m.init()
	m.releaseOnce.Do(func() { close(m.released) })
}