Skip to content

Commit 48e103f

Browse files
committed
feature: support Unbind requests
1 parent 27a3508 commit 48e103f

11 files changed

+200
-10
lines changed

README.md

+4-7
Original file line numberDiff line numberDiff line change
@@ -143,14 +143,11 @@ func searchHandler(w *gldap.ResponseWriter, r *gldap.Request) {
143143
* Modify Requests
144144
* Add Requests
145145
* Delete Requests
146+
* Unbind Requests
146147
147-
### Near-term features
148-
149-
* Unbind Requests
150-
151-
### Long-term features
152-
153-
* ???
148+
### Future features
149+
At this point, we may wait until issues are opened before planning new features
150+
given that all the basic LDAP operations are supported.
154151
155152
<hr>
156153

conn.go

+7-1
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,13 @@ func (c *conn) serveRequests() error {
111111
// BusyResponse when the limit is reached. This limit per conn
112112
// should be configurable
113113

114-
// TODO: stop serving requests when an UnbindRequest is received
114+
case r.routeOp == unbindRouteOperation:
115+
// support an optional unbind route
116+
if c.router.unbindRoute != nil {
117+
c.router.unbindRoute.handler()(w, r)
118+
}
119+
// stop serving requests when UnbindRequest is received
120+
return nil
115121

116122
// If it's a StartTLS request, then we can't dispatch it concurrently,
117123
// since the conn needs to complete it's TLS negotiation before handling

message.go

+12
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ const (
4444
modifyRequestType requestType = "modify"
4545
addRequestType requestType = "add"
4646
deleteRequestType requestType = "delete"
47+
unbindRequestType requestType = "unbind"
4748
)
4849

4950
// Message defines a common interface for all messages
@@ -115,6 +116,11 @@ type DeleteMessage struct {
115116
Controls []Control
116117
}
117118

119+
// UnbindMessage is an unbind request message
120+
type UnbindMessage struct {
121+
baseMessage
122+
}
123+
118124
// newMessage will create a new message from the packet.
119125
func newMessage(p *packet) (Message, error) {
120126
const op = "gldap.NewMessage"
@@ -130,6 +136,12 @@ func newMessage(p *packet) (Message, error) {
130136
}
131137

132138
switch reqType {
139+
case unbindRequestType:
140+
return &UnbindMessage{
141+
baseMessage: baseMessage{
142+
id: msgID,
143+
},
144+
}, nil
133145
case bindRequestType:
134146
u, pass, controls, err := p.simpleBindParameters()
135147
if err != nil {

mux.go

+26
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ type Mux struct {
1212
mu sync.Mutex
1313
routes []route
1414
defaultRoute route
15+
unbindRoute route
1516
}
1617

1718
// NewMux creates a new multiplexer.
@@ -44,6 +45,31 @@ func (m *Mux) Bind(bindFn HandlerFunc, opt ...Option) error {
4445
return nil
4546
}
4647

48+
// Unbind will register a handler for unbind requests and override the default
49+
// unbind handler. Registering an unbind handler is optional and regardless of
50+
// whether or not an unbind route is defined the server will stop serving
51+
// requests for a connection after an unbind request is received. Options
52+
// supported: WithLabel
53+
func (m *Mux) Unbind(bindFn HandlerFunc, opt ...Option) error {
54+
const op = "gldap.(Mux).Unbind"
55+
if bindFn == nil {
56+
return fmt.Errorf("%s: missing HandlerFunc: %w", op, ErrInvalidParameter)
57+
}
58+
opts := getRouteOpts(opt...)
59+
60+
r := &unbindRoute{
61+
baseRoute: &baseRoute{
62+
h: bindFn,
63+
routeOp: bindRouteOperation,
64+
label: opts.withLabel,
65+
},
66+
}
67+
m.mu.Lock()
68+
defer m.mu.Unlock()
69+
m.unbindRoute = r
70+
return nil
71+
}
72+
4773
// Search will register a handler for search requests.
4874
// Options supported: WithLabel, WithBaseDN, WithScope
4975
func (m *Mux) Search(searchFn HandlerFunc, opt ...Option) error {

mux_test.go

+41
Original file line numberDiff line numberDiff line change
@@ -146,3 +146,44 @@ func TestMux_Delete(t *testing.T) {
146146
})
147147
}
148148
}
149+
150+
func TestMux_Unbind(t *testing.T) {
151+
tests := []struct {
152+
name string
153+
mux *Mux
154+
fn HandlerFunc
155+
wantErr bool
156+
wantErrIs error
157+
wantErrContains string
158+
}{
159+
{
160+
name: "missing-fn",
161+
mux: func() *Mux { m, err := NewMux(); require.NoError(t, err); return m }(),
162+
wantErr: true,
163+
wantErrIs: ErrInvalidParameter,
164+
wantErrContains: "missing HandlerFunc",
165+
},
166+
{
167+
name: "valid",
168+
mux: func() *Mux { m, err := NewMux(); require.NoError(t, err); return m }(),
169+
fn: func(*ResponseWriter, *Request) {},
170+
},
171+
}
172+
for _, tc := range tests {
173+
t.Run(tc.name, func(t *testing.T) {
174+
assert, require := assert.New(t), require.New(t)
175+
err := tc.mux.Unbind(tc.fn)
176+
if tc.wantErr {
177+
require.Error(err)
178+
if tc.wantErrIs != nil {
179+
assert.ErrorIs(err, tc.wantErrIs)
180+
}
181+
if tc.wantErrContains != "" {
182+
assert.Contains(err.Error(), tc.wantErrContains)
183+
}
184+
return
185+
}
186+
require.NoError(err)
187+
})
188+
}
189+
}

packet.go

+4-2
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,8 @@ func (p *packet) requestType() (requestType, error) {
129129
return addRequestType, nil
130130
case ApplicationDelRequest:
131131
return deleteRequestType, nil
132+
case ApplicationUnbindRequest:
133+
return unbindRequestType, nil
132134
default:
133135
return unknownRequestType, fmt.Errorf("%s: unhandled request type %d: %w", op, requestPacket.Tag, ErrInternal)
134136
}
@@ -537,8 +539,8 @@ func (p *packet) assertApplicationRequest() error {
537539
}
538540
switch chkPacket.TagType {
539541
case ber.TypePrimitive:
540-
if chkPacket.Tag != ApplicationDelRequest {
541-
return fmt.Errorf("%s: incorrect type, primitive %q must be a delete request %q, but got %q", op, ber.TypePrimitive, ApplicationDelRequest, chkPacket.Tag)
542+
if chkPacket.Tag != ApplicationDelRequest && chkPacket.Tag != ApplicationUnbindRequest {
543+
return fmt.Errorf("%s: incorrect type, primitive %q must be a delete request %q or an unbind request %q, but got %q", op, ber.TypePrimitive, ApplicationDelRequest, ApplicationUnbindRequest, chkPacket.Tag)
542544
}
543545
case ber.TypeConstructed:
544546
default:

request.go

+13
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ func newRequest(id int, c *conn, p *packet) (*Request, error) {
6161
routeOp = addRouteOperation
6262
case *DeleteMessage:
6363
routeOp = deleteRouteOperation
64+
case *UnbindMessage:
65+
routeOp = unbindRouteOperation
6466
default:
6567
// this should be unreachable, since newMessage defaults to returning an
6668
// *ExtendedOperationMessage
@@ -256,6 +258,17 @@ func (r *Request) GetDeleteMessage() (*DeleteMessage, error) {
256258
return m, nil
257259
}
258260

261+
// GetUnbindMessage retrieves the UnbindMessage from the request, which
262+
// allows you handle the request based on the message attributes.
263+
func (r *Request) GetUnbindMessage() (*UnbindMessage, error) {
264+
const op = "gldap.(Request).GetUnbindMessage"
265+
m, ok := r.message.(*UnbindMessage)
266+
if !ok {
267+
return nil, fmt.Errorf("%s: %T not an unbind request: %w", op, r.message, ErrInvalidParameter)
268+
}
269+
return m, nil
270+
}
271+
259272
func intPtr(i int) *int {
260273
return &i
261274
}

request_test.go

+13
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,19 @@ func Test_newRequest(t *testing.T) {
7373
},
7474
},
7575
},
76+
{
77+
name: "valid-unbind",
78+
requestID: 1,
79+
conn: &conn{},
80+
packet: testUnbindRequestPacket(t,
81+
UnbindMessage{
82+
baseMessage: baseMessage{id: 1},
83+
},
84+
),
85+
wantMsg: &UnbindMessage{
86+
baseMessage: baseMessage{id: 1},
87+
},
88+
},
7689
{
7790
name: "valid-search",
7891
requestID: 1,

route.go

+7
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ const (
2929
// deleteRouteOperation is a route supporting the delete operation
3030
deleteRouteOperation routeOperation = "delete"
3131

32+
// unbindRouteOperation is a route supporting the unbind operation
33+
unbindRouteOperation routeOperation = "unbind"
34+
3235
// defaultRouteOperation is a default route which is used when there are no routes
3336
// defined for a particular operation
3437
defaultRouteOperation routeOperation = "noRoute"
@@ -73,6 +76,10 @@ type simpleBindRoute struct {
7376
authChoice AuthChoice
7477
}
7578

79+
type unbindRoute struct {
80+
*baseRoute
81+
}
82+
7683
type extendedRoute struct {
7784
*baseRoute
7885
extendedName ExtendedOperationName

testing.go

+12
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,18 @@ func testSimpleBindRequestPacket(t *testing.T, m SimpleBindMessage) *packet {
8888
}
8989
}
9090

91+
func testUnbindRequestPacket(t *testing.T, m UnbindMessage) *packet {
92+
t.Helper()
93+
94+
envelope := testRequestEnvelope(t, int(m.GetID()))
95+
pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationUnbindRequest, nil, "Unbind Request")
96+
envelope.AppendChild(pkt)
97+
98+
return &packet{
99+
Packet: envelope,
100+
}
101+
}
102+
91103
func testModifyRequestPacket(t *testing.T, m ModifyMessage) *packet {
92104
t.Helper()
93105
envelope := testRequestEnvelope(t, int(m.GetID()))

testing_e2e_test.go

+61
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"crypto/tls"
55
"crypto/x509"
66
"fmt"
7+
"sync"
78
"testing"
89

910
"github.com/go-ldap/ldap/v3"
@@ -616,3 +617,63 @@ func TestDirectory_DeleteResponse(t *testing.T) {
616617
})
617618
}
618619
}
620+
621+
func Test_Start_Unbind(t *testing.T) {
622+
t.Parallel()
623+
t.Run("unbind", func(t *testing.T) {
624+
assert, require := assert.New(t), require.New(t)
625+
port := testdirectory.FreePort(t)
626+
627+
l := hclog.New(&hclog.LoggerOptions{
628+
Name: "simple-bind-logger",
629+
Level: hclog.Error,
630+
})
631+
632+
// create a new server
633+
s, err := gldap.NewServer(gldap.WithLogger(l), gldap.WithDisablePanicRecovery())
634+
require.NoError(err)
635+
636+
// create a router and add a bind handler
637+
r, err := gldap.NewMux()
638+
require.NoError(err)
639+
640+
var got string
641+
var wg sync.WaitGroup
642+
wg.Add(1)
643+
r.Unbind(func(w *gldap.ResponseWriter, req *gldap.Request) {
644+
m, err := req.GetUnbindMessage()
645+
if err != nil {
646+
t.Fatalf("unable to get unbind msg: %s", err.Error())
647+
}
648+
if m == nil {
649+
t.Fatal("unbind msg is nil")
650+
}
651+
got = "unbind-success"
652+
wg.Done()
653+
})
654+
655+
r.Bind(func(w *gldap.ResponseWriter, r *gldap.Request) {
656+
resp := r.NewBindResponse(gldap.WithResponseCode(gldap.ResultSuccess))
657+
defer func() {
658+
_ = w.Write(resp)
659+
}()
660+
})
661+
662+
s.Router(r)
663+
go s.Run(fmt.Sprintf(":%d", port))
664+
defer s.Stop()
665+
666+
conn, err := ldap.DialURL(fmt.Sprintf("ldap://localhost:%d", port))
667+
require.NoError(err)
668+
defer conn.Close()
669+
670+
err = conn.Bind("does not", "matter")
671+
require.NoError(err)
672+
673+
err = conn.Unbind()
674+
require.NoError(err)
675+
676+
wg.Wait()
677+
assert.Equal("unbind-success", got)
678+
})
679+
}

0 commit comments

Comments
 (0)