Skip to content

Commit 189f430

Browse files
committed
Add a quota handler callback
This PR adds a quota handler callback function which, if specified, is called by the server just before making an allocation for a user. The handler should return a single bool: if true then the allocation request can proceed, otherwise the request is rejected with a 486 (Allocation Quota Reached) error.
1 parent af7dfe3 commit 189f430

File tree

5 files changed

+74
-0
lines changed

5 files changed

+74
-0
lines changed

internal/server/server.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ type Request struct {
2929
// User Configuration
3030
AuthHandler func(username string, realm string, srcAddr net.Addr) (key []byte, ok bool)
3131

32+
// Quota Handler
33+
QuotaHandler func(username string, realm string, srcAddr net.Addr) (ok bool)
34+
3235
Log logging.LeveledLogger
3336
Realm string
3437
ChannelBindTimeout time.Duration

internal/server/turn.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,13 @@ func handleAllocateRequest(req Request, stunMsg *stun.Message) error { //nolint:
157157
// server is free to define this allocation quota any way it wishes,
158158
// but SHOULD define it based on the username used to authenticate
159159
// the request, and not on the client's transport address.
160+
if req.QuotaHandler != nil && !req.QuotaHandler(usernameAttr.String(), realmAttr.String(), req.SrcAddr) {
161+
quotaReachedMsg := buildMsg(stunMsg.TransactionID,
162+
stun.NewType(stun.MethodAllocate, stun.ClassErrorResponse),
163+
&stun.ErrorCodeAttribute{Code: stun.CodeAllocQuotaReached})
164+
165+
return buildAndSend(req.Conn, req.SrcAddr, quotaReachedMsg...)
166+
}
160167

161168
// 8. Also at any point, the server MAY choose to reject the request
162169
// with a 300 (Try Alternate) error if it wishes to redirect the

server.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ const (
2424
type Server struct {
2525
log logging.LeveledLogger
2626
authHandler AuthHandler
27+
quotaHandler QuotaHandler
2728
realm string
2829
channelBindTimeout time.Duration
2930
nonceHash *server.NonceHash
@@ -59,6 +60,7 @@ func NewServer(config ServerConfig) (*Server, error) { //nolint:gocognit,cyclop
5960
server := &Server{
6061
log: loggerFactory.NewLogger("turn"),
6162
authHandler: config.AuthHandler,
63+
quotaHandler: config.QuotaHandler,
6264
realm: config.Realm,
6365
channelBindTimeout: config.ChannelBindTimeout,
6466
packetConnConfigs: config.PacketConnConfigs,
@@ -231,6 +233,7 @@ func (s *Server) readLoop(conn net.PacketConn, allocationManager *allocation.Man
231233
Buff: buf[:n],
232234
Log: s.log,
233235
AuthHandler: s.authHandler,
236+
QuotaHandler: s.quotaHandler,
234237
Realm: s.realm,
235238
AllocationManager: allocationManager,
236239
ChannelBindTimeout: s.channelBindTimeout,

server_config.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,11 @@ func GenerateAuthKey(username, realm, password string) []byte {
113113
// allocation's lifecycle.
114114
type EventHandler = allocation.EventHandler
115115

116+
// QuotaHandler is a callback allows allocations to be rejected when a per-user quota is
117+
// exceeded. If the callback returns true the allocation request is accepted, otherwise it is
118+
// rejected and a 486 (Allocation Quota Reached) error is returned to the user.
119+
type QuotaHandler func(username, realm string, srcAddr net.Addr) (ok bool)
120+
116121
// ServerConfig configures the Pion TURN Server.
117122
type ServerConfig struct {
118123
// PacketConnConfigs and ListenerConfigs are a list of all the turn listeners
@@ -130,6 +135,10 @@ type ServerConfig struct {
130135
// allowing users to customize Pion TURN with custom behavior
131136
AuthHandler AuthHandler
132137

138+
// QuotaHandler is a callback used to reject new allocations when a
139+
// per-user quota is exceeded.
140+
QuotaHandler QuotaHandler
141+
133142
// EventHandlers is a set of callbacks for tracking allocation lifecycle.
134143
EventHandler EventHandler
135144

server_test.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1011,6 +1011,58 @@ func TestSTUNOnly(t *testing.T) {
10111011
assert.NoError(t, conn.Close())
10121012
}
10131013

1014+
func TestQuotaReached(t *testing.T) {
1015+
serverAddr, err := net.ResolveUDPAddr("udp4", "0.0.0.0:3478")
1016+
assert.NoError(t, err)
1017+
1018+
serverConn, err := net.ListenPacket(serverAddr.Network(), serverAddr.String())
1019+
assert.NoError(t, err)
1020+
1021+
defer serverConn.Close() //nolint:errcheck
1022+
1023+
credMap := map[string][]byte{"user": GenerateAuthKey("user", "pion.ly", "pass")}
1024+
server, err := NewServer(ServerConfig{
1025+
AuthHandler: func(username, _ string, _ net.Addr) (key []byte, ok bool) {
1026+
if pw, ok := credMap[username]; ok {
1027+
return pw, true
1028+
}
1029+
return nil, false //nolint:nlreturn
1030+
},
1031+
QuotaHandler: func(_, _ string, _ net.Addr) (ok bool) { return false },
1032+
Realm: "pion.ly",
1033+
PacketConnConfigs: []PacketConnConfig{{
1034+
PacketConn: serverConn,
1035+
RelayAddressGenerator: &RelayAddressGeneratorStatic{
1036+
RelayAddress: net.ParseIP("127.0.0.1"),
1037+
Address: "0.0.0.0",
1038+
},
1039+
}},
1040+
LoggerFactory: logging.NewDefaultLoggerFactory(),
1041+
})
1042+
assert.NoError(t, err)
1043+
1044+
defer server.Close() //nolint:errcheck
1045+
1046+
conn, err := net.ListenPacket("udp4", "0.0.0.0:0")
1047+
assert.NoError(t, err)
1048+
1049+
client, err := NewClient(&ClientConfig{
1050+
Conn: conn,
1051+
STUNServerAddr: "127.0.0.1:3478",
1052+
TURNServerAddr: "127.0.0.1:3478",
1053+
Username: "user",
1054+
Password: "pass",
1055+
Realm: "pion.ly",
1056+
LoggerFactory: logging.NewDefaultLoggerFactory(),
1057+
})
1058+
assert.NoError(t, err)
1059+
assert.NoError(t, client.Listen())
1060+
defer client.Close()
1061+
1062+
_, err = client.Allocate()
1063+
assert.Equal(t, err.Error(), "Allocate error response (error 486: )")
1064+
}
1065+
10141066
func RunBenchmarkServer(b *testing.B, clientNum int) { //nolint:cyclop
10151067
b.Helper()
10161068

0 commit comments

Comments
 (0)