Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

new lockDuplicateRequest #186

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions cmd/routedns/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,11 @@ func instantiateGroup(id string, g group, resolvers map[string]rdns.Resolver) er
gr = append(gr, resolver)
}
switch g.Type {
case "lock-duplicate-request":
if len(gr) != 1 {
return fmt.Errorf("type lock-duplicate-request only supports one resolver in '%s'", id)
}
resolvers[id] = rdns.NewlockDuplicateRequest(id, gr[0])
case "round-robin":
resolvers[id] = rdns.NewRoundRobin(id, gr...)
case "fail-rotate":
Expand Down
174 changes: 174 additions & 0 deletions lockduplicaterequest.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
package rdns

import (
"encoding/binary"
"github.com/miekg/dns"
"sync"
"sync/atomic"
"time"
)

type key struct {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the global namespace, probably best to use a name like lockDuplicateKey or similar.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK

name string
qtype uint16
ecs_ipv4 uint32
ecs_ipv6_hi uint64
ecs_ipv6_lo uint64
ecs_mask uint8
}

type value struct {
//If expiredTimeStamp is negative number , it means it will return SERVFAIL.
expiredTimeStamp int64
mu *newMutex
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why re-implement the mutex here? Can't you use the regular mutex instead and pass a *value below?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If timeout, this plugin should unlock-all same request and then return SERVFAIL to downstream.
Tradition Mutex can not support it.

}

type lockDuplicateRequest struct {
id string
resolver Resolver
m sync.Map
}

var _ Resolver = &lockDuplicateRequest{}

func NewlockDuplicateRequest(id string, resolver Resolver) *lockDuplicateRequest {
ret := &lockDuplicateRequest{id: id, resolver: resolver}
return ret
}

func (r *lockDuplicateRequest) Resolve(q *dns.Msg, ci ClientInfo) (*dns.Msg, error) {

var ecsipv4 uint32 = 0
var ecsipv6hi uint64 = 0
var ecsipv6lo uint64 = 0
var ecsmask uint8 = 0

if len(q.Question) != 1 {
return r.resolver.Resolve(q, ci)
}

edns0 := q.IsEdns0()
if edns0 != nil {
// Find the ECS option
for _, opt := range edns0.Option {
ecs, ok := opt.(*dns.EDNS0_SUBNET)
if !ok {
continue
}
switch ecs.Family {
case 1: // ip4
ecsipv4 = ldrByteToUint32(ecs.Address.To4())
ecsmask = ecs.SourceNetmask
break
case 2: // ip6
ecsipv6hi, ecsipv6lo = ldrByteToUint128(ecs.Address.To16())
ecsmask = ecs.SourceNetmask
break
}
}
}

k := key{
name: q.Question[0].Name,
qtype: q.Question[0].Qtype,
ecs_ipv4: ecsipv4,
ecs_ipv6_hi: ecsipv6hi,
ecs_ipv6_lo: ecsipv6lo,
ecs_mask: ecsmask,
}

var newValue value
parse, _ := time.ParseDuration("+5s")
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It'd be much more efficient to use a constant that is set to 5 * time.Second here rather than parsing a string in every call.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK

expired := time.Now().Add(parse)
newValue.expiredTimeStamp = expired.Unix()
newValue.mu = &newMutex{}
loaded, _ := r.m.LoadOrStore(k, newValue)
v := loaded.(value)


var returnAnswer *dns.Msg
var returnError error

v.mu.Lock()
if v.expiredTimeStamp >= 0 {
a, err := r.resolver.Resolve(q, ci)
returnAnswer = a
returnError = err
} else {
returnAnswer = nil
returnError = QueryTimeoutError{q}
}
v.mu.UnLock()

currentTimeStamp := time.Now().Unix()
expiredTimeStamp := v.expiredTimeStamp
if v.mu.Count() == 0 {
r.m.Delete(k)
} else if expiredTimeStamp >= 0 && expiredTimeStamp < currentTimeStamp {
r.m.Delete(k)
v.expiredTimeStamp = ^expiredTimeStamp
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since v is not locked here, I don't think this is thread-safe.

Regardless, instead of doing these calculations, using a channel that is closed to signal "timeout" would be better. Could even use a context with timeout in this case. Or simply add a (fixed) timeout into the value, then check it with https://pkg.go.dev/time#Time.After against the current time.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, It is not thread-safe.

I prefer timeTicker.

v.mu.UnLockAll()
}

return returnAnswer, returnError
}

func (r *lockDuplicateRequest) String() string {
return r.id
}

func ldrByteToUint128(b []byte) (uint64, uint64) {
hi := binary.BigEndian.Uint64(b[0:8])
lo := binary.BigEndian.Uint64(b[8:16])
return hi, lo
}

func ldrByteToUint32(b []byte) uint32 {
return binary.BigEndian.Uint32(b[0:4])
}

type newMutex struct {
count uint32
sync sync.Mutex
}

func (r *newMutex) Count() uint32 {
return r.count
}

func (r *newMutex) Lock() {
r.sync.Lock()
atomic.AddUint32(&(r.count), +1)
}

func (r *newMutex) UnLock() {
for {
count := atomic.LoadUint32(&(r.count))
if count > 0 {
ok := atomic.CompareAndSwapUint32(&(r.count), count, count-1)
if ok {
break
}
} else {
return
}
}
r.sync.Unlock()
}

func (r *newMutex) UnLockAll() {
for {
count := atomic.LoadUint32(&(r.count))
if count > 0 {
ok := atomic.CompareAndSwapUint32(&(r.count), count, 0)
if ok {
for count > 0 {
count--
r.sync.Unlock()
}
}
} else {
break
}
}
}