Skip to content
This repository was archived by the owner on Dec 3, 2019. It is now read-only.

Add functions to httparchive.go to restrict certificate SANs #4651

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 196 additions & 3 deletions web_page_replay_go/src/httparchive.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,28 +8,61 @@ package main
import (
"bufio"
"bytes"
"crypto"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"fmt"
"io"
"io/ioutil"
"log"
"net"
"net/http"
"os"
"os/exec"
"path/filepath"
"strings"
"time"

"github.com/catapult-project/catapult/web_page_replay_go/src/webpagereplay"
"./webpagereplay"
"github.com/urfave/cli"
)

const usage = "%s [ls|cat|edit|merge|add|addAll] [options] archive_file [output_file] [url]"

type CertConfig struct {
// Flags common to all commands.
certFile, keyFile string
}

func (certCfg *CertConfig) Flags() []cli.Flag {
return []cli.Flag{
cli.StringFlag{
Name: "https_cert_file",
Value: "wpr_cert.pem",
Usage: "File containing a PEM-encoded X509 certificate to use with SSL.",
Destination: &certCfg.certFile,
},
cli.StringFlag{
Name: "https_key_file",
Value: "wpr_key.pem",
Usage: "File containing a PEM-encoded private key to use with SSL.",
Destination: &certCfg.keyFile,
},
}
}

type Config struct {
method, host, fullPath string
decodeResponseBody, skipExisting, overwriteExisting bool
certConfig CertConfig
root_cert tls.Certificate
}

func (cfg *Config) DefaultFlags() []cli.Flag {
return []cli.Flag{
return append(cfg.certConfig.Flags(),

cli.StringFlag{
Name: "command",
Value: "",
Expand All @@ -53,7 +86,8 @@ func (cfg *Config) DefaultFlags() []cli.Flag {
Usage: "Decode/encode response body according to Content-Encoding header.",
Destination: &cfg.decodeResponseBody,
},
}
)

}

func (cfg *Config) AddFlags() []cli.Flag {
Expand Down Expand Up @@ -272,6 +306,146 @@ func addAll(cfg *Config, archive *webpagereplay.Archive, outfile string, inputFi
return writeArchive(archive, outfile)
}

func restrictSSLCertSANs(cfg *Config, archive *webpagereplay.Archive, outfile string) error {

var ipMap = make(map[string][]string)

//Find hosts present in the requests collection & assigns certificates and host ips for these
requestHostsDict := make(map[string]string)
for requestHost := range archive.Requests {
if _, ok := requestHostsDict[requestHost]; !ok {
requestHostsDict[requestHost] = requestHost
}
}
for k := range requestHostsDict {
dialer := &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}
conn, err := tls.DialWithDialer(dialer, "tcp", fmt.Sprintf("%s:443", requestHostsDict[k]), &tls.Config{
NextProtos: []string{"h2", "http/1.1"},
})
if err == nil {
_, ok := archive.RemoteAddresses[requestHostsDict[k]]
if !ok {
fakecert, err := x509.ParseCertificate(cfg.root_cert.Certificate[0])
if err == nil {
currCert, er := CreateDomainRestrictedCert([]string{requestHostsDict[k]}, conn.ConnectionState().PeerCertificates[0], fakecert, cfg.root_cert.PrivateKey)
if er == nil {
if archive.RemoteAddresses == nil {
archive.RemoteAddresses = make(map[string]string)
}
archive.RemoteAddresses[requestHostsDict[k]] = conn.RemoteAddr().String()

if _, ok := archive.NegotiatedProtocol[requestHostsDict[k]]; !ok {
archive.NegotiatedProtocol[requestHostsDict[k]] = conn.ConnectionState().NegotiatedProtocol
}
if _, ok := archive.Certs[requestHostsDict[k]]; !ok {
archive.Certs[requestHostsDict[k]] = currCert
}
}
}
}
}
}

for host, ip := range archive.RemoteAddresses {
if h, ok := ipMap[ip]; ok {
h = append(h, host)
ipMap[ip] = h
} else {
ipMap[ip] = []string{host}
}
}

for ip := range ipMap {
for i := range ipMap[ip] {
currentSANList := []string{ipMap[ip][i]}
currentCert, err := x509.ParseCertificate(archive.Certs[ipMap[ip][i]])
//derBytes, negotiatedProtocol, ip, err := archive.FindHostTLSConfig(i)
if err != nil {
return err
}
for j := range ipMap[ip] {
if j != i {
certValidationErr := currentCert.VerifyHostname(ipMap[ip][j])
if certValidationErr == nil {
currentSANList = append(currentSANList, ipMap[ip][j])
}
}
}

fakecert, e := x509.ParseCertificate(cfg.root_cert.Certificate[0])

if e != nil {
println(fmt.Sprintf("New Cert DNS : %v", e))
}

newCert, err := CreateDomainRestrictedCert(currentSANList, currentCert, fakecert, cfg.root_cert.PrivateKey)
newCertParsed, er := x509.ParseCertificate(newCert)
if er == nil {
println(fmt.Sprintf("New Cert CN: %s", newCertParsed.Subject.CommonName))
println(fmt.Sprintf("Old Cert CN: %s", currentCert.Subject.CommonName))

for str := range newCertParsed.DNSNames {
println(fmt.Sprintf("IP: %s New Cert DNS : %s", ip, newCertParsed.DNSNames[str]))
}
for str := range currentCert.DNSNames {
println(fmt.Sprintf("IP: %s Old Cert DNS : %s", ip, currentCert.DNSNames[str]))
}
}

archive.Certs[ipMap[ip][i]] = newCert

}
}

return writeArchive(archive, outfile)
}

//Mints a restricted certificate that is only valid for the SANs in the certificateSAN parameter
func CreateDomainRestrictedCert(certificateSANs []string, rootCert *x509.Certificate, fakecert *x509.Certificate, rootKey crypto.PrivateKey) ([]byte, error) {
template := x509.Certificate{

SerialNumber: rootCert.SerialNumber,

Subject: pkix.Name{

CommonName: certificateSANs[0],
},

Issuer: fakecert.Subject,

NotBefore: time.Now(),

NotAfter: time.Now().Add(time.Hour * 24 * 180),

KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCRLSign,

ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},

BasicConstraintsValid: true,

IsCA: true,

AuthorityKeyId: rootCert.AuthorityKeyId,

CRLDistributionPoints: rootCert.CRLDistributionPoints,

IssuingCertificateURL: rootCert.IssuingCertificateURL,

DNSNames: certificateSANs,

PublicKey: rootCert.PublicKey,
}

derBytes, err := x509.CreateCertificate(rand.Reader, &template, fakecert, fakecert.PublicKey, rootKey)

return derBytes, err

}

// compressResponse compresses resp.Body in place according to resp's Content-Encoding header.
func compressResponse(resp *http.Response) error {
ce := strings.ToLower(resp.Header.Get("Content-Encoding"))
Expand Down Expand Up @@ -310,6 +484,15 @@ func main() {
if len(c.Args()) != wantArgs {
return fmt.Errorf("Expected %d arguments but got %d", wantArgs, len(c.Args()))
}
cfg.certConfig.certFile = "wpr_cert.pem"
cfg.certConfig.keyFile = "wpr_key.pem"
log.Printf("Loading cert from %v\n", cfg.certConfig.certFile)
log.Printf("Loading key from %v\n", cfg.certConfig.keyFile)
var err error
cfg.root_cert, err = tls.LoadX509KeyPair(cfg.certConfig.certFile, cfg.certConfig.keyFile)
if err != nil {
return fmt.Errorf("error opening cert or key files: %v", err)
}
return nil
}
}
Expand Down Expand Up @@ -387,6 +570,16 @@ func main() {
return addAll(cfg, loadArchiveOrDie(c, 0), c.Args().Get(1), c.Args().Get(2))
},
},
cli.Command{
Name: "restrictSSLCertSANs",
Usage: "Transforms the certificates in the archives to only those SANs that were served from the IP address",
ArgsUsage: "input_archive output_archive urls_file",
Flags: cfg.AddFlags(),
Before: checkArgs("restrictSSLCertSANs", 2),
Action: func(c *cli.Context) error {
return restrictSSLCertSANs(cfg, loadArchiveOrDie(c, 0), c.Args().Get(1))
},
},
}
app.Usage = "HTTP Archive Utils"
app.UsageText = fmt.Sprintf(usage, progName)
Expand Down
17 changes: 10 additions & 7 deletions web_page_replay_go/src/webpagereplay/archive.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@ type RequestMatch struct {
}

func (requestMatch *RequestMatch) SetMatch(
match *ArchivedRequest,
request *http.Request,
response *http.Response,
ratio float64) {
match *ArchivedRequest,
request *http.Request,
response *http.Response,
ratio float64) {
requestMatch.Match = match
requestMatch.Request = request
requestMatch.Response = response
Expand Down Expand Up @@ -103,6 +103,8 @@ type Archive struct {
// Maps host string to the negotiated protocol. eg. "http/1.1" or "h2"
// If absent, will default to "http/1.1".
NegotiatedProtocol map[string]string
// Maps the remote IPs for the hosts, will be used with transforming the certificate SANS
RemoteAddresses map[string]string
// The time seed that was used to initialize deterministic.js.
DeterministicTimeSeedMs int64
// When an incoming request matches multiple recorded responses, whether to
Expand Down Expand Up @@ -254,9 +256,9 @@ func (a *Archive) FindRequest(req *http.Request) (*http.Request, *http.Response,
// Given an incoming request and a set of matches in the archive, identify the best match,
// based on request headers.
func (a *Archive) findBestMatchInArchivedRequestSet(
incomingReq *http.Request,
archivedReqs []*ArchivedRequest) (
*http.Request, *http.Response, error) {
incomingReq *http.Request,
archivedReqs []*ArchivedRequest) (
*http.Request, *http.Response, error) {
scheme := incomingReq.URL.Scheme

if len(archivedReqs) == 0 {
Expand Down Expand Up @@ -496,6 +498,7 @@ func (a *WritableArchive) RecordTlsConfig(host string, der_bytes []byte, negotia
a.NegotiatedProtocol = make(map[string]string)
}
a.NegotiatedProtocol[host] = negotiatedProtocol

}

// Close flushes the the archive and closes the output file.
Expand Down