|
1 | | -// Copyright © 2022 Kaleido, Inc. |
| 1 | +// Copyright © 2023 Kaleido, Inc. |
2 | 2 | // |
3 | 3 | // SPDX-License-Identifier: Apache-2.0 |
4 | 4 | // |
|
17 | 17 | package wsclient |
18 | 18 |
|
19 | 19 | import ( |
| 20 | + "crypto/rand" |
| 21 | + "crypto/rsa" |
| 22 | + "crypto/tls" |
| 23 | + "crypto/x509" |
| 24 | + "crypto/x509/pkix" |
| 25 | + "encoding/pem" |
20 | 26 | "fmt" |
| 27 | + "math/big" |
| 28 | + "net" |
21 | 29 | "net/http" |
22 | 30 | "net/http/httptest" |
| 31 | + "os" |
| 32 | + "testing" |
| 33 | + "time" |
23 | 34 |
|
24 | 35 | "github.com/gorilla/websocket" |
| 36 | + "github.com/stretchr/testify/assert" |
25 | 37 | ) |
26 | 38 |
|
| 39 | +// GenerateTLSCertificates creates a key pair for server and client auth |
| 40 | +func GenerateTLSCertficates(t *testing.T) (publicKeyFile *os.File, privateKeyFile *os.File) { |
| 41 | + // Create an X509 certificate pair |
| 42 | + privatekey, _ := rsa.GenerateKey(rand.Reader, 2048) |
| 43 | + publickey := &privatekey.PublicKey |
| 44 | + var privateKeyBytes = x509.MarshalPKCS1PrivateKey(privatekey) |
| 45 | + privateKeyFile, _ = os.CreateTemp("", "key.pem") |
| 46 | + privateKeyBlock := &pem.Block{Type: "RSA PRIVATE KEY", Bytes: privateKeyBytes} |
| 47 | + err := pem.Encode(privateKeyFile, privateKeyBlock) |
| 48 | + assert.NoError(t, err) |
| 49 | + serialNumber, _ := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) |
| 50 | + x509Template := &x509.Certificate{ |
| 51 | + SerialNumber: serialNumber, |
| 52 | + Subject: pkix.Name{ |
| 53 | + Organization: []string{"Unit Tests"}, |
| 54 | + }, |
| 55 | + NotBefore: time.Now(), |
| 56 | + NotAfter: time.Now().Add(1000 * time.Second), |
| 57 | + KeyUsage: x509.KeyUsageDigitalSignature, |
| 58 | + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, |
| 59 | + BasicConstraintsValid: true, |
| 60 | + IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1)}, |
| 61 | + } |
| 62 | + derBytes, err := x509.CreateCertificate(rand.Reader, x509Template, x509Template, publickey, privatekey) |
| 63 | + assert.NoError(t, err) |
| 64 | + publicKeyFile, _ = os.CreateTemp("", "cert.pem") |
| 65 | + err = pem.Encode(publicKeyFile, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) |
| 66 | + assert.NoError(t, err) |
| 67 | + |
| 68 | + return publicKeyFile, privateKeyFile |
| 69 | +} |
| 70 | + |
| 71 | +// NewTestTLSWSServer creates a little test server for packages (including wsclient itself) to use in unit tests |
| 72 | +// and secured with mTLS by passing in a key pair |
| 73 | +func NewTestTLSWSServer(testReq func(req *http.Request), publicKeyFile *os.File, privateKeyFile *os.File) (toServer, fromServer chan string, url string, done func(), err error) { |
| 74 | + upgrader := &websocket.Upgrader{WriteBufferSize: 1024, ReadBufferSize: 1024} |
| 75 | + toServer = make(chan string, 1) |
| 76 | + fromServer = make(chan string, 1) |
| 77 | + sendDone := make(chan struct{}) |
| 78 | + receiveDone := make(chan struct{}) |
| 79 | + connected := false |
| 80 | + |
| 81 | + handlerFunc := http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { |
| 82 | + if testReq != nil { |
| 83 | + testReq(req) |
| 84 | + } |
| 85 | + if connected { |
| 86 | + // test server only handles one open connection, as it only has one set of channels |
| 87 | + res.WriteHeader(409) |
| 88 | + return |
| 89 | + } |
| 90 | + ws, _ := upgrader.Upgrade(res, req, http.Header{}) |
| 91 | + go func() { |
| 92 | + defer close(receiveDone) |
| 93 | + for { |
| 94 | + _, data, err := ws.ReadMessage() |
| 95 | + if err != nil { |
| 96 | + return |
| 97 | + } |
| 98 | + toServer <- string(data) |
| 99 | + } |
| 100 | + }() |
| 101 | + go func() { |
| 102 | + defer close(sendDone) |
| 103 | + defer ws.Close() |
| 104 | + for data := range fromServer { |
| 105 | + _ = ws.WriteMessage(websocket.TextMessage, []byte(data)) |
| 106 | + } |
| 107 | + }() |
| 108 | + connected = true |
| 109 | + }) |
| 110 | + |
| 111 | + svr := httptest.NewUnstartedServer(handlerFunc) |
| 112 | + |
| 113 | + cert, err := tls.LoadX509KeyPair(publicKeyFile.Name(), privateKeyFile.Name()) |
| 114 | + if err != nil { |
| 115 | + return toServer, fromServer, "", nil, err |
| 116 | + } |
| 117 | + rootCAs := x509.NewCertPool() |
| 118 | + caPEM, _ := os.ReadFile(publicKeyFile.Name()) |
| 119 | + rootCAs.AppendCertsFromPEM(caPEM) |
| 120 | + svr.TLS = &tls.Config{ |
| 121 | + MinVersion: tls.VersionTLS12, |
| 122 | + ClientAuth: tls.RequireAndVerifyClientCert, |
| 123 | + Certificates: []tls.Certificate{cert}, |
| 124 | + ClientCAs: rootCAs, |
| 125 | + } |
| 126 | + svr.StartTLS() |
| 127 | + addr := svr.Listener.Addr() |
| 128 | + |
| 129 | + return toServer, fromServer, fmt.Sprintf("wss://%s", addr), func() { |
| 130 | + close(fromServer) |
| 131 | + svr.Close() |
| 132 | + if connected { |
| 133 | + <-sendDone |
| 134 | + <-receiveDone |
| 135 | + } |
| 136 | + }, nil |
| 137 | +} |
| 138 | + |
27 | 139 | // NewTestWSServer creates a little test server for packages (including wsclient itself) to use in unit tests |
28 | 140 | func NewTestWSServer(testReq func(req *http.Request)) (toServer, fromServer chan string, url string, done func()) { |
29 | 141 | upgrader := &websocket.Upgrader{WriteBufferSize: 1024, ReadBufferSize: 1024} |
|
0 commit comments