Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions examples/ssl_exporter.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ modules:
prober: tcp
tls_config:
server_name: example.com
tcp_mysql_starttls:
prober: tcp
tcp:
starttls: mysql
tls_config:
insecure_skip_verify: true
tcp_client_auth:
prober: tcp
tls_config:
Expand Down
97 changes: 92 additions & 5 deletions prober/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,12 @@ type queryResponse struct {
send string
sendBytes []byte
expectBytes []byte
expectFn func(buffer []byte, bytes int) error
}

var (
// BUFFSIZE is default size in bytes for generic buffer
BUFFSIZE = 8 * 1024
// These are the protocols for which I had servers readily available to test
// against. There are plenty of other protocols that should be added here in
// the future.
Expand Down Expand Up @@ -111,6 +114,76 @@ var (
expect: "OK",
},
},
"mysql": []queryResponse{
queryResponse{
expectFn: func(buffer []byte, bytes int) error {
if bytes == 0 {
return fmt.Errorf("read 0 bytes from MySQL server")
} else if bytes < 21 {
// Packet length[3], Packet number[1] + minimum payload[17]
return fmt.Errorf("MySQL packet too short. Expected length > 21, got %d", bytes)
} else if bytes != (4 + int(buffer[0]) + (int(buffer[1]) << 8) + (int(buffer[2]) << 16)) {
// Packet length[3], Packet number[1] + minimum payload[17]
return fmt.Errorf(
"MySQL packet length does not match. Got %d, expected %d",
bytes,
4+int(buffer[0])+(int(buffer[1])<<8)+(int(buffer[2])<<16),
)
} else if buffer[4] != 0xA {
// protocol version[1]
return fmt.Errorf("Only MySQL protocol version 10 (0xA) is supported. Got %x", buffer[4])
}
position := 5
// server version[string+NULL]
for ; ; position++ {
if position >= bytes {
return fmt.Errorf("Cannot confirm MySQL version")
} else if buffer[position] == 0 {
break
}
}
position++
// make sure we have at least 15 bytes left in the packet
if position+15 > bytes {
return fmt.Errorf("MySQL server handshake packet is broken")
}

position += 12 // skip over conn id[4] + SALT[8]
if buffer[position] != 0 { // verify filler
return fmt.Errorf(
"MySQL packet is broken. Expected null at %d position, got %x",
position,
buffer[position],
)
}
position++

// capability flags[2]
// !((packet[pos] + (packet[pos + 1] << 8)) & ssl_flg)
if (int(buffer[position]) + (int(buffer[position+1])<<8)&0x800) == 0 {
return fmt.Errorf("MySQL server does not support SSL")
}
return nil
},
},
queryResponse{
sendBytes: []byte{
/* payload_length, sequence_id */
0x20, 0x00, 0x00, 0x01,
/* payload */
/* capability flags, CLIENT_SSL always set */
0x85, 0xae, 0x7f, 0x00,
/* max-packet size */
0x00, 0x00, 0x00, 0x01,
/* character set */
0x21,
/* string[23] reserved (all [0]) */
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
},
},
},
"postgres": []queryResponse{
queryResponse{
sendBytes: []byte{0x00, 0x00, 0x00, 0x08, 0x04, 0xd2, 0x16, 0x2f},
Expand All @@ -137,13 +210,13 @@ var (
func startTLS(logger log.Logger, conn net.Conn, proto string) error {
var err error

qr, ok := startTLSqueryResponses[proto]
actions, ok := startTLSqueryResponses[proto]
if !ok {
return fmt.Errorf("STARTTLS is not supported for %s", proto)
}

scanner := bufio.NewScanner(conn)
for _, qr := range qr {
for _, qr := range actions {
if qr.expect != "" {
var match bool
for scanner.Scan() {
Expand Down Expand Up @@ -171,12 +244,26 @@ func startTLS(logger log.Logger, conn net.Conn, proto string) error {
return nil
}
level.Debug(logger).Log("msg", fmt.Sprintf("read bytes: %x", buffer))
if bytes.Compare(buffer, qr.expectBytes) != 0 {
return fmt.Errorf("read bytes %x didn't match with expected bytes %x", buffer, qr.expectBytes)
} else {
if len(qr.expectBytes) > 0 {
if bytes.Compare(buffer, qr.expectBytes) != 0 {
return fmt.Errorf("read bytes %x didn't match with expected bytes %x", buffer, qr.expectBytes)
}
level.Debug(logger).Log("msg", fmt.Sprintf("expected bytes %x matched with read bytes %x", qr.expectBytes, buffer))
}
}
if qr.expectFn != nil {
buffer := make([]byte, BUFFSIZE)
bytes, err := conn.Read(buffer)
if err != nil {
return nil
}
level.Debug(logger).Log("msg", fmt.Sprintf("read bytes: %x", buffer))

if err := qr.expectFn(buffer, bytes); err != nil {
return err
}
level.Debug(logger).Log("msg", fmt.Sprintf("expected function for %s matched with read bytes %x", proto, buffer))
}
if qr.send != "" {
level.Debug(logger).Log("msg", fmt.Sprintf("sending line: %s", qr.send))
if _, err := fmt.Fprintf(conn, "%s\r\n", qr.send); err != nil {
Expand Down
39 changes: 39 additions & 0 deletions prober/tcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,45 @@ func TestProbeTCPStartTLSPostgreSQL(t *testing.T) {
checkTLSVersionMetrics("TLS 1.3", registry, t)
}

// TestProbeTCPStartTLSMySQL tests STARTTLS against a mock MySQL server
func TestProbeTCPStartTLSMySQL(t *testing.T) {
server, certPEM, _, caFile, teardown, err := test.SetupTCPServer()
if err != nil {
t.Fatalf(err.Error())
}
defer teardown()

server.StartMySQL()
defer server.Close()

module := config.Module{
TCP: config.TCPProbe{
StartTLS: "mysql",
},
TLSConfig: config.TLSConfig{
CAFile: caFile,
InsecureSkipVerify: false,
},
}

registry := prometheus.NewRegistry()

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

if err := ProbeTCP(ctx, newTestLogger(), server.Listener.Addr().String(), module, registry); err != nil {
t.Fatalf("error: %s", err)
}

cert, err := newCertificate(certPEM)
if err != nil {
t.Fatal(err)
}
checkCertificateMetrics(cert, registry, t)
checkOCSPMetrics([]byte{}, registry, t)
checkTLSVersionMetrics("TLS 1.3", registry, t)
}

// TestProbeTCPTimeout tests that the TCP probe respects the timeout in the
// context
func TestProbeTCPTimeout(t *testing.T) {
Expand Down
61 changes: 61 additions & 0 deletions test/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,67 @@ func (t *TCPServer) StartPostgreSQL() {
}()
}

// StartMySQL starts a listener that negotiates a TLS connection with a MySQL
// client using STARTTLS
func (t *TCPServer) StartMySQL() {
go func() {
conn, err := t.Listener.Accept()
if err != nil {
panic(fmt.Sprintf("Error accepting on socket: %s", err))
}
defer conn.Close()

// Packet extracted using tcpdump from a real MySQL server
sslResponseMessage := []byte{
0x54, 0x00, 0x00, 0x00, 0x0a, 0x35, 0x2e, 0x37, 0x2e, 0x33, 0x31, 0x2d, 0x33, 0x34,
0x2d, 0x35, 0x37, 0x2d, 0x6c, 0x6f, 0x67, 0x00, 0x92, 0x4d, 0x00, 0x00, 0x64, 0x41, 0x72, 0x79,
0x10, 0x07, 0x50, 0x18, 0x00, 0xff, 0xff, 0x2d, 0x02, 0x00, 0xff, 0xc1, 0x15, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x76, 0x2c, 0x54, 0x1e, 0x51, 0x5d, 0x06, 0x6c, 0x56,
0x44, 0x49, 0x7b, 0x00, 0x6d, 0x79, 0x73, 0x71, 0x6c, 0x5f, 0x6e, 0x61, 0x74, 0x69, 0x76, 0x65,
0x5f, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x00,
}

if _, err := conn.Write(sslResponseMessage); err != nil {
panic("Error writing initial response to client")
}

sslRequestMessage := []byte{
/* payload_length, sequence_id */
0x20, 0x00, 0x00, 0x01,
/* payload */
/* capability flags, CLIENT_SSL always set */
0x85, 0xae, 0x7f, 0x00,
/* max-packet size */
0x00, 0x00, 0x00, 0x01,
/* character set */
0x21,
/* string[23] reserved (all [0]) */
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
}

buffer := make([]byte, len(sslRequestMessage))

_, err = io.ReadFull(conn, buffer)
if err != nil {
panic("Error reading input from client")
}

if bytes.Compare(buffer, sslRequestMessage) != 0 {
panic(fmt.Sprintf("Error in dialog. No %x received", buffer))
}

tlsConn := tls.Server(conn, t.TLS)
if err := tlsConn.Handshake(); err != nil {
level.Error(t.logger).Log("msg", err)
}
defer tlsConn.Close()

t.stopCh <- struct{}{}
}()
}

// Close stops the server and closes the listener
func (t *TCPServer) Close() {
<-t.stopCh
Expand Down