@@ -59,6 +59,27 @@ type GSSAPIWithMICConfig struct {
59
59
Server GSSAPIServer
60
60
}
61
61
62
+ // SendAuthBanner implements [ServerPreAuthConn].
63
+ func (s * connection ) SendAuthBanner (msg string ) error {
64
+ return s .transport .writePacket (Marshal (& userAuthBannerMsg {
65
+ Message : msg ,
66
+ }))
67
+ }
68
+
69
+ func (* connection ) unexportedMethodForFutureProofing () {}
70
+
71
+ // ServerPreAuthConn is the interface available on an incoming server
72
+ // connection before authentication has completed.
73
+ type ServerPreAuthConn interface {
74
+ unexportedMethodForFutureProofing () // permits growing ServerPreAuthConn safely later, ala testing.TB
75
+
76
+ ConnMetadata
77
+
78
+ // SendAuthBanner sends a banner message to the client.
79
+ // It returns an error once the authentication phase has ended.
80
+ SendAuthBanner (string ) error
81
+ }
82
+
62
83
// ServerConfig holds server specific configuration data.
63
84
type ServerConfig struct {
64
85
// Config contains configuration shared between client and server.
@@ -118,6 +139,12 @@ type ServerConfig struct {
118
139
// attempts.
119
140
AuthLogCallback func (conn ConnMetadata , method string , err error )
120
141
142
+ // PreAuthConnCallback, if non-nil, is called upon receiving a new connection
143
+ // before any authentication has started. The provided ServerPreAuthConn
144
+ // can be used at any time before authentication is complete, including
145
+ // after this callback has returned.
146
+ PreAuthConnCallback func (ServerPreAuthConn )
147
+
121
148
// ServerVersion is the version identification string to announce in
122
149
// the public handshake.
123
150
// If empty, a reasonable default is used.
@@ -488,14 +515,18 @@ func (b *BannerError) Error() string {
488
515
}
489
516
490
517
func (s * connection ) serverAuthenticate (config * ServerConfig ) (* Permissions , error ) {
518
+ if config .PreAuthConnCallback != nil {
519
+ config .PreAuthConnCallback (s )
520
+ }
521
+
491
522
sessionID := s .transport .getSessionID ()
492
523
var cache pubKeyCache
493
524
var perms * Permissions
494
525
495
526
authFailures := 0
496
527
noneAuthCount := 0
497
528
var authErrs []error
498
- var displayedBanner bool
529
+ var calledBannerCallback bool
499
530
partialSuccessReturned := false
500
531
// Set the initial authentication callbacks from the config. They can be
501
532
// changed if a PartialSuccessError is returned.
@@ -542,14 +573,10 @@ userAuthLoop:
542
573
543
574
s .user = userAuthReq .User
544
575
545
- if ! displayedBanner && config .BannerCallback != nil {
546
- displayedBanner = true
547
- msg := config .BannerCallback (s )
548
- if msg != "" {
549
- bannerMsg := & userAuthBannerMsg {
550
- Message : msg ,
551
- }
552
- if err := s .transport .writePacket (Marshal (bannerMsg )); err != nil {
576
+ if ! calledBannerCallback && config .BannerCallback != nil {
577
+ calledBannerCallback = true
578
+ if msg := config .BannerCallback (s ); msg != "" {
579
+ if err := s .SendAuthBanner (msg ); err != nil {
553
580
return nil , err
554
581
}
555
582
}
@@ -762,10 +789,7 @@ userAuthLoop:
762
789
var bannerErr * BannerError
763
790
if errors .As (authErr , & bannerErr ) {
764
791
if bannerErr .Message != "" {
765
- bannerMsg := & userAuthBannerMsg {
766
- Message : bannerErr .Message ,
767
- }
768
- if err := s .transport .writePacket (Marshal (bannerMsg )); err != nil {
792
+ if err := s .SendAuthBanner (bannerErr .Message ); err != nil {
769
793
return nil , err
770
794
}
771
795
}
0 commit comments