Skip to content

Commit a8ea4be

Browse files
bradfitzgopherbot
authored andcommitted
ssh: add ServerConfig.PreAuthConnCallback, ServerPreAuthConn (banner) interface
Fixes golang/go#68688 Change-Id: Id5f72b32c61c9383a26ec182339486a432c7cdf5 Reviewed-on: https://go-review.googlesource.com/c/crypto/+/613856 LUCI-TryBot-Result: Go LUCI <[email protected]> Auto-Submit: Nicola Murino <[email protected]> Reviewed-by: Jonathan Amsterdam <[email protected]> Reviewed-by: Nicola Murino <[email protected]> Reviewed-by: Roland Shoemaker <[email protected]>
1 parent 71d3a4c commit a8ea4be

File tree

3 files changed

+135
-15
lines changed

3 files changed

+135
-15
lines changed

ssh/handshake.go

+12-2
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ type handshakeTransport struct {
8080
pendingPackets [][]byte // Used when a key exchange is in progress.
8181
writePacketsLeft uint32
8282
writeBytesLeft int64
83+
userAuthComplete bool // whether the user authentication phase is complete
8384

8485
// If the read loop wants to schedule a kex, it pings this
8586
// channel, and the write loop will send out a kex
@@ -552,16 +553,25 @@ func (t *handshakeTransport) sendKexInit() error {
552553
return nil
553554
}
554555

556+
var errSendBannerPhase = errors.New("ssh: SendAuthBanner outside of authentication phase")
557+
555558
func (t *handshakeTransport) writePacket(p []byte) error {
559+
t.mu.Lock()
560+
defer t.mu.Unlock()
561+
556562
switch p[0] {
557563
case msgKexInit:
558564
return errors.New("ssh: only handshakeTransport can send kexInit")
559565
case msgNewKeys:
560566
return errors.New("ssh: only handshakeTransport can send newKeys")
567+
case msgUserAuthBanner:
568+
if t.userAuthComplete {
569+
return errSendBannerPhase
570+
}
571+
case msgUserAuthSuccess:
572+
t.userAuthComplete = true
561573
}
562574

563-
t.mu.Lock()
564-
defer t.mu.Unlock()
565575
if t.writeError != nil {
566576
return t.writeError
567577
}

ssh/server.go

+37-13
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,27 @@ type GSSAPIWithMICConfig struct {
5959
Server GSSAPIServer
6060
}
6161

62+
// SendAuthBanner implements [ServerPreAuthConn].
63+
func (s *connection) SendAuthBanner(msg string) error {
64+
return s.transport.writePacket(Marshal(&userAuthBannerMsg{
65+
Message: msg,
66+
}))
67+
}
68+
69+
func (*connection) unexportedMethodForFutureProofing() {}
70+
71+
// ServerPreAuthConn is the interface available on an incoming server
72+
// connection before authentication has completed.
73+
type ServerPreAuthConn interface {
74+
unexportedMethodForFutureProofing() // permits growing ServerPreAuthConn safely later, ala testing.TB
75+
76+
ConnMetadata
77+
78+
// SendAuthBanner sends a banner message to the client.
79+
// It returns an error once the authentication phase has ended.
80+
SendAuthBanner(string) error
81+
}
82+
6283
// ServerConfig holds server specific configuration data.
6384
type ServerConfig struct {
6485
// Config contains configuration shared between client and server.
@@ -118,6 +139,12 @@ type ServerConfig struct {
118139
// attempts.
119140
AuthLogCallback func(conn ConnMetadata, method string, err error)
120141

142+
// PreAuthConnCallback, if non-nil, is called upon receiving a new connection
143+
// before any authentication has started. The provided ServerPreAuthConn
144+
// can be used at any time before authentication is complete, including
145+
// after this callback has returned.
146+
PreAuthConnCallback func(ServerPreAuthConn)
147+
121148
// ServerVersion is the version identification string to announce in
122149
// the public handshake.
123150
// If empty, a reasonable default is used.
@@ -488,14 +515,18 @@ func (b *BannerError) Error() string {
488515
}
489516

490517
func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, error) {
518+
if config.PreAuthConnCallback != nil {
519+
config.PreAuthConnCallback(s)
520+
}
521+
491522
sessionID := s.transport.getSessionID()
492523
var cache pubKeyCache
493524
var perms *Permissions
494525

495526
authFailures := 0
496527
noneAuthCount := 0
497528
var authErrs []error
498-
var displayedBanner bool
529+
var calledBannerCallback bool
499530
partialSuccessReturned := false
500531
// Set the initial authentication callbacks from the config. They can be
501532
// changed if a PartialSuccessError is returned.
@@ -542,14 +573,10 @@ userAuthLoop:
542573

543574
s.user = userAuthReq.User
544575

545-
if !displayedBanner && config.BannerCallback != nil {
546-
displayedBanner = true
547-
msg := config.BannerCallback(s)
548-
if msg != "" {
549-
bannerMsg := &userAuthBannerMsg{
550-
Message: msg,
551-
}
552-
if err := s.transport.writePacket(Marshal(bannerMsg)); err != nil {
576+
if !calledBannerCallback && config.BannerCallback != nil {
577+
calledBannerCallback = true
578+
if msg := config.BannerCallback(s); msg != "" {
579+
if err := s.SendAuthBanner(msg); err != nil {
553580
return nil, err
554581
}
555582
}
@@ -762,10 +789,7 @@ userAuthLoop:
762789
var bannerErr *BannerError
763790
if errors.As(authErr, &bannerErr) {
764791
if bannerErr.Message != "" {
765-
bannerMsg := &userAuthBannerMsg{
766-
Message: bannerErr.Message,
767-
}
768-
if err := s.transport.writePacket(Marshal(bannerMsg)); err != nil {
792+
if err := s.SendAuthBanner(bannerErr.Message); err != nil {
769793
return nil, err
770794
}
771795
}

ssh/server_test.go

+86
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,92 @@ func TestPublicKeyCallbackLastSeen(t *testing.T) {
348348
}
349349
}
350350

351+
func TestPreAuthConnAndBanners(t *testing.T) {
352+
testDone := make(chan struct{})
353+
defer close(testDone)
354+
355+
authConnc := make(chan ServerPreAuthConn, 1)
356+
serverConfig := &ServerConfig{
357+
PreAuthConnCallback: func(c ServerPreAuthConn) {
358+
t.Logf("got ServerPreAuthConn: %v", c)
359+
authConnc <- c // for use later in the test
360+
for _, s := range []string{"hello1", "hello2"} {
361+
if err := c.SendAuthBanner(s); err != nil {
362+
t.Errorf("failed to send banner %q: %v", s, err)
363+
}
364+
}
365+
// Now start a goroutine to spam SendAuthBanner in hopes
366+
// of hitting a race.
367+
go func() {
368+
for {
369+
select {
370+
case <-testDone:
371+
return
372+
default:
373+
if err := c.SendAuthBanner("attempted-race"); err != nil && err != errSendBannerPhase {
374+
t.Errorf("unexpected error from SendAuthBanner: %v", err)
375+
}
376+
time.Sleep(5 * time.Millisecond)
377+
}
378+
}
379+
}()
380+
},
381+
NoClientAuth: true,
382+
NoClientAuthCallback: func(ConnMetadata) (*Permissions, error) {
383+
t.Logf("got NoClientAuthCallback")
384+
return &Permissions{}, nil
385+
},
386+
}
387+
serverConfig.AddHostKey(testSigners["rsa"])
388+
389+
var banners []string
390+
clientConfig := &ClientConfig{
391+
User: "test",
392+
HostKeyCallback: InsecureIgnoreHostKey(),
393+
BannerCallback: func(msg string) error {
394+
if msg != "attempted-race" {
395+
banners = append(banners, msg)
396+
}
397+
return nil
398+
},
399+
}
400+
401+
c1, c2, err := netPipe()
402+
if err != nil {
403+
t.Fatalf("netPipe: %v", err)
404+
}
405+
defer c1.Close()
406+
defer c2.Close()
407+
go newServer(c1, serverConfig)
408+
c, _, _, err := NewClientConn(c2, "", clientConfig)
409+
if err != nil {
410+
t.Fatalf("client connection failed: %v", err)
411+
}
412+
defer c.Close()
413+
414+
wantBanners := []string{
415+
"hello1",
416+
"hello2",
417+
}
418+
if !reflect.DeepEqual(banners, wantBanners) {
419+
t.Errorf("got banners:\n%q\nwant banners:\n%q", banners, wantBanners)
420+
}
421+
422+
// Now that we're authenticated, verify that use of SendBanner
423+
// is an error.
424+
var bc ServerPreAuthConn
425+
select {
426+
case bc = <-authConnc:
427+
default:
428+
t.Fatal("expected ServerPreAuthConn")
429+
}
430+
if err := bc.SendAuthBanner("wrong-phase"); err == nil {
431+
t.Error("unexpected success of SendAuthBanner after authentication")
432+
} else if err != errSendBannerPhase {
433+
t.Errorf("unexpected error: %v; want %v", err, errSendBannerPhase)
434+
}
435+
}
436+
351437
type markerConn struct {
352438
closed uint32
353439
used uint32

0 commit comments

Comments
 (0)