|
5 | 5 | package ssh
|
6 | 6 |
|
7 | 7 | import (
|
| 8 | + "errors" |
8 | 9 | "io"
|
9 | 10 | "net"
|
| 11 | + "strings" |
10 | 12 | "sync/atomic"
|
11 | 13 | "testing"
|
12 | 14 | "time"
|
@@ -62,6 +64,133 @@ func TestClientAuthRestrictedPublicKeyAlgos(t *testing.T) {
|
62 | 64 | }
|
63 | 65 | }
|
64 | 66 |
|
| 67 | +func TestMaxAuthTriesNoneMethod(t *testing.T) { |
| 68 | + username := "testuser" |
| 69 | + serverConfig := &ServerConfig{ |
| 70 | + MaxAuthTries: 2, |
| 71 | + PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) { |
| 72 | + if conn.User() == username && string(password) == clientPassword { |
| 73 | + return nil, nil |
| 74 | + } |
| 75 | + return nil, errors.New("invalid credentials") |
| 76 | + }, |
| 77 | + } |
| 78 | + c1, c2, err := netPipe() |
| 79 | + if err != nil { |
| 80 | + t.Fatalf("netPipe: %v", err) |
| 81 | + } |
| 82 | + defer c1.Close() |
| 83 | + defer c2.Close() |
| 84 | + |
| 85 | + var serverAuthErrors []error |
| 86 | + |
| 87 | + serverConfig.AddHostKey(testSigners["rsa"]) |
| 88 | + serverConfig.AuthLogCallback = func(conn ConnMetadata, method string, err error) { |
| 89 | + serverAuthErrors = append(serverAuthErrors, err) |
| 90 | + } |
| 91 | + go newServer(c1, serverConfig) |
| 92 | + |
| 93 | + clientConfig := ClientConfig{ |
| 94 | + User: username, |
| 95 | + HostKeyCallback: InsecureIgnoreHostKey(), |
| 96 | + } |
| 97 | + clientConfig.SetDefaults() |
| 98 | + // Our client will send 'none' auth only once, so we need to send the |
| 99 | + // requests manually. |
| 100 | + c := &connection{ |
| 101 | + sshConn: sshConn{ |
| 102 | + conn: c2, |
| 103 | + user: username, |
| 104 | + clientVersion: []byte(packageVersion), |
| 105 | + }, |
| 106 | + } |
| 107 | + c.serverVersion, err = exchangeVersions(c.sshConn.conn, c.clientVersion) |
| 108 | + if err != nil { |
| 109 | + t.Fatalf("unable to exchange version: %v", err) |
| 110 | + } |
| 111 | + c.transport = newClientTransport( |
| 112 | + newTransport(c.sshConn.conn, clientConfig.Rand, true /* is client */), |
| 113 | + c.clientVersion, c.serverVersion, &clientConfig, "", c.sshConn.RemoteAddr()) |
| 114 | + if err := c.transport.waitSession(); err != nil { |
| 115 | + t.Fatalf("unable to wait session: %v", err) |
| 116 | + } |
| 117 | + c.sessionID = c.transport.getSessionID() |
| 118 | + if err := c.transport.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth})); err != nil { |
| 119 | + t.Fatalf("unable to send ssh-userauth message: %v", err) |
| 120 | + } |
| 121 | + packet, err := c.transport.readPacket() |
| 122 | + if err != nil { |
| 123 | + t.Fatal(err) |
| 124 | + } |
| 125 | + if len(packet) > 0 && packet[0] == msgExtInfo { |
| 126 | + packet, err = c.transport.readPacket() |
| 127 | + if err != nil { |
| 128 | + t.Fatal(err) |
| 129 | + } |
| 130 | + } |
| 131 | + var serviceAccept serviceAcceptMsg |
| 132 | + if err := Unmarshal(packet, &serviceAccept); err != nil { |
| 133 | + t.Fatal(err) |
| 134 | + } |
| 135 | + for i := 0; i <= serverConfig.MaxAuthTries; i++ { |
| 136 | + auth := new(noneAuth) |
| 137 | + _, _, err := auth.auth(c.sessionID, clientConfig.User, c.transport, clientConfig.Rand, nil) |
| 138 | + if i < serverConfig.MaxAuthTries { |
| 139 | + if err != nil { |
| 140 | + t.Fatal(err) |
| 141 | + } |
| 142 | + continue |
| 143 | + } |
| 144 | + if err == nil { |
| 145 | + t.Fatal("client: got no error") |
| 146 | + } else if !strings.Contains(err.Error(), "too many authentication failures") { |
| 147 | + t.Fatalf("client: got unexpected error: %v", err) |
| 148 | + } |
| 149 | + } |
| 150 | + if len(serverAuthErrors) != 3 { |
| 151 | + t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors) |
| 152 | + } |
| 153 | + for _, err := range serverAuthErrors { |
| 154 | + if !errors.Is(err, ErrNoAuth) { |
| 155 | + t.Errorf("go error: %v; want: %v", err, ErrNoAuth) |
| 156 | + } |
| 157 | + } |
| 158 | +} |
| 159 | + |
| 160 | +func TestMaxAuthTriesFirstNoneAuthErrorIgnored(t *testing.T) { |
| 161 | + username := "testuser" |
| 162 | + serverConfig := &ServerConfig{ |
| 163 | + MaxAuthTries: 1, |
| 164 | + PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) { |
| 165 | + if conn.User() == username && string(password) == clientPassword { |
| 166 | + return nil, nil |
| 167 | + } |
| 168 | + return nil, errors.New("invalid credentials") |
| 169 | + }, |
| 170 | + } |
| 171 | + clientConfig := &ClientConfig{ |
| 172 | + User: username, |
| 173 | + Auth: []AuthMethod{ |
| 174 | + Password(clientPassword), |
| 175 | + }, |
| 176 | + HostKeyCallback: InsecureIgnoreHostKey(), |
| 177 | + } |
| 178 | + |
| 179 | + serverAuthErrors, err := doClientServerAuth(t, serverConfig, clientConfig) |
| 180 | + if err != nil { |
| 181 | + t.Fatalf("client login error: %s", err) |
| 182 | + } |
| 183 | + if len(serverAuthErrors) != 2 { |
| 184 | + t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors) |
| 185 | + } |
| 186 | + if !errors.Is(serverAuthErrors[0], ErrNoAuth) { |
| 187 | + t.Errorf("go error: %v; want: %v", serverAuthErrors[0], ErrNoAuth) |
| 188 | + } |
| 189 | + if serverAuthErrors[1] != nil { |
| 190 | + t.Errorf("unexpected error: %v", serverAuthErrors[1]) |
| 191 | + } |
| 192 | +} |
| 193 | + |
65 | 194 | func TestNewServerConnValidationErrors(t *testing.T) {
|
66 | 195 | serverConf := &ServerConfig{
|
67 | 196 | PublicKeyAuthAlgorithms: []string{CertAlgoRSAv01},
|
|
0 commit comments