Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 48 additions & 16 deletions providers/os/connection/ssh/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ package ssh
import (
"bytes"
"context"
"errors"
"io"
"net"
"os"
Expand All @@ -16,6 +15,7 @@ import (
"time"

awsconf "github.com/aws/aws-sdk-go-v2/config"
"github.com/cockroachdb/errors"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

probably should just use "errors"

"github.com/kevinburke/ssh_config"
"github.com/mitchellh/go-homedir"
rawsftp "github.com/pkg/sftp"
Expand All @@ -33,7 +33,6 @@ import (
"go.mondoo.com/cnquery/v11/providers/os/connection/ssh/signers"
"go.mondoo.com/cnquery/v11/utils/multierr"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
"golang.org/x/crypto/ssh/knownhosts"
)

Expand Down Expand Up @@ -493,27 +492,60 @@ func establishClientConnection(pCfg *inventory.Config, hostKeyCallback ssh.HostK
}
}

log.Debug().Int("methods", len(authMethods)).Str("user", user).Msg("connect to remote ssh")
conn, err := ssh.Dial("tcp", pCfg.Host+":"+strconv.Itoa(int(pCfg.Port)), &ssh.ClientConfig{
addr := pCfg.Host + ":" + strconv.Itoa(int(pCfg.Port))
sshClientConfig := &ssh.ClientConfig{
User: user,
Auth: authMethods,
HostKeyCallback: hostKeyCallback,
})
}

supportsHybrid, err := serverSupportsHybridKEX(addr)
if err == nil && supportsHybrid {
// force the Key Exchange Algorithm to a compatible one
sshClientConfig.Config = ssh.Config{
KeyExchanges: []string{
"curve25519-sha256",
"curve25519-sha256@libssh.org",
"ecdh-sha2-nistp256",
"diffie-hellman-group14-sha1",
},
}
}

log.Debug().
Int("methods", len(authMethods)).
Str("user", user).
Bool("hybrid_key_exchange", supportsHybrid).
Msg("connect to remote ssh")
conn, err := ssh.Dial("tcp", addr, sshClientConfig)
return conn, closer, err
}

// hasAgentLoadedKey returns if the ssh agent has loaded the key file
// This may not be 100% accurate. The key can be stored in multiple locations with the
// same fingerprint. We cannot determine the fingerprint without decoding the encrypted
// key, `ssh-keygen -lf /Users/chartmann/.ssh/id_rsa` seems to use the ssh agent to
// determine the fingerprint without prompting for the password
func hasAgentLoadedKey(list []*agent.Key, filename string) bool {
for i := range list {
if list[i].Comment == filename {
return true
}
// Detects if the remote server offers hybrid PQ KEX algorithms
func serverSupportsHybridKEX(addr string) (bool, error) {
conn, err := net.DialTimeout("tcp", addr, 5*time.Second)
if err != nil {
log.Debug().Err(err).Msg("fail to verify KEX algorithms")
return false, err
}
return false
defer conn.Close()

// Read the server's version string
var buf [256]byte
n, err := conn.Read(buf[:])
if err != nil {
return false, errors.Wrap(err, "failed to read banner")
}
banner := string(buf[:n])

// We'll stop here. Full KEXINIT parsing requires building a custom packet reader.
// For now, assume OpenSSH 9.9+ includes PQ KEX unless we detect otherwise.
if strings.Contains(banner, "OpenSSH_9.9") {
// Naively assume 9.9+ offers hybrid KEX
return true, nil
}

return false, nil
}

// prepareConnection determines the auth methods required for a ssh connection and also prepares any other
Expand Down
65 changes: 65 additions & 0 deletions providers/os/connection/ssh/ssh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
package ssh

import (
"net"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.mondoo.com/cnquery/v11/providers-sdk/v1/inventory"
"go.mondoo.com/cnquery/v11/providers/os/connection/shared"
)
Expand Down Expand Up @@ -37,3 +39,66 @@ func TestSSHAuthError(t *testing.T) {
// local testing without ssh agent
err.Error() == "no authentication method defined")
}

// helper to start a fake SSH server with a custom banner
func startMockSSHServer(t *testing.T, banner string) (addr string, closeFn func()) {
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.Nil(t, err)

go func() {
conn, err := ln.Accept()
if err != nil {
return
}
defer conn.Close()
// simulate SSH banner
_, _ = conn.Write([]byte(banner + "\r\n"))
}()

return ln.Addr().String(), func() { ln.Close() }
}

func TestServerSupportsHybridKEX(t *testing.T) {
tests := []struct {
name string
banner string
expectHybrid bool
}{
{
name: "OpenSSH 9.9 detected",
banner: "SSH-2.0-OpenSSH_9.9",
expectHybrid: true,
},
{
name: "OpenSSH 9.7 (no hybrid)",
banner: "SSH-2.0-OpenSSH_9.7",
expectHybrid: false,
},
{
name: "Non-OpenSSH server",
banner: "SSH-2.0-CustomSSH_1.0",
expectHybrid: false,
},
{
name: "Malformed banner",
banner: "garbage",
expectHybrid: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
addr, shutdown := startMockSSHServer(t, tt.banner)
defer shutdown()

got, err := serverSupportsHybridKEX(addr)
require.Nil(t, err)
assert.Equal(t, tt.expectHybrid, got)
})
}
}

func TestServerSupportsHybridKEX_ServerUnreachable(t *testing.T) {
_, err := serverSupportsHybridKEX("127.0.0.1:9")
require.NotNil(t, err)
}
Loading