Skip to content

Commit f1de758

Browse files
authored
Merge pull request #17 from moriyoshi/prepared-certificates
Support for prepared certificates
2 parents 51a9eee + fee58d1 commit f1de758

File tree

9 files changed

+498
-66
lines changed

9 files changed

+498
-66
lines changed

certcache.go

+150-40
Original file line numberDiff line numberDiff line change
@@ -56,32 +56,80 @@ type CertCache struct {
5656

5757
const certificateFileName = "cert.pem"
5858
const certificateBlockName = "CERTIFICATE"
59+
const privateKeyFileName = "key.pem"
60+
const privateKeyBlockName = "PRIVATE KEY"
5961

6062
func buildKeyString(hosts []string) string {
6163
key := strings.Join(hosts, ";")
6264
return key
6365
}
6466

6567
func (c *CertCache) writeCertificate(key string, cert *tls.Certificate) (err error) {
66-
leadingDirs, ok := c.buildPathToCachedCert(key)
68+
leadingDir, ok := c.buildPathToCachedCert(key)
6769
if !ok {
6870
return
6971
}
70-
err = os.MkdirAll(leadingDirs, os.FileMode(0777))
72+
73+
leadingDirTmp := leadingDir + "$tmp$"
74+
err = os.MkdirAll(leadingDirTmp, os.FileMode(0700))
7175
if err != nil {
7276
return
7377
}
74-
path := filepath.Join(leadingDirs, certificateFileName)
75-
w, err := os.OpenFile(path, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, os.FileMode(0666))
78+
defer func() {
79+
if err != nil {
80+
os.RemoveAll(leadingDirTmp)
81+
}
82+
}()
83+
84+
certFilePath := filepath.Join(leadingDirTmp, certificateFileName)
85+
privKeyFilePath := filepath.Join(leadingDirTmp, privateKeyFileName)
86+
87+
err = func(certFilePath string) (err error) {
88+
w, err := os.OpenFile(certFilePath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, os.FileMode(0600))
89+
if err != nil {
90+
return
91+
}
92+
defer w.Close()
93+
94+
for _, x509Cert := range cert.Certificate {
95+
err = pem.Encode(w, &pem.Block{Type: certificateBlockName, Bytes: x509Cert})
96+
if err != nil {
97+
return
98+
}
99+
_, err = w.Write([]byte{'\n'})
100+
if err != nil {
101+
return
102+
}
103+
}
104+
return
105+
}(certFilePath)
76106
if err != nil {
77107
return
78108
}
79-
defer w.Close()
80-
err = pem.Encode(w, &pem.Block{Type: certificateBlockName, Bytes: cert.Certificate[0]})
109+
110+
err = func(privKeyFilePath string) (err error) {
111+
privKeyBytes, err := x509.MarshalPKCS8PrivateKey(cert.PrivateKey)
112+
if err != nil {
113+
return
114+
}
115+
116+
w, err := os.OpenFile(privKeyFilePath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, os.FileMode(0666))
117+
if err != nil {
118+
return
119+
}
120+
defer w.Close()
121+
err = pem.Encode(w, &pem.Block{Type: privateKeyBlockName, Bytes: privKeyBytes})
122+
if err != nil {
123+
return
124+
}
125+
return
126+
}(privKeyFilePath)
81127
if err != nil {
82128
return
83129
}
84-
return nil
130+
131+
err = os.Rename(leadingDirTmp, leadingDir)
132+
return
85133
}
86134

87135
func (c *CertCache) buildPathToCachedCert(key string) (string, bool) {
@@ -92,66 +140,121 @@ func (c *CertCache) buildPathToCachedCert(key string) (string, bool) {
92140
}
93141

94142
func (c *CertCache) readAndValidateCertificate(key string, hosts []string, now time.Time) (*tls.Certificate, error) {
95-
leadingDirs, ok := c.buildPathToCachedCert(key)
143+
leadingDir, ok := c.buildPathToCachedCert(key)
96144
if !ok {
97145
return nil, nil
98146
}
99-
path := filepath.Join(leadingDirs, certificateFileName)
100-
pemBytes, err := ioutil.ReadFile(path)
101-
if err != nil {
102-
return nil, err
147+
148+
if _, err := os.Stat(leadingDir); os.IsNotExist(err) {
149+
return nil, nil
103150
}
104-
certDerBytes := []byte(nil)
105-
for {
106-
var pemBlock *pem.Block
107-
pemBlock, pemBytes = pem.Decode(pemBytes)
108-
if pemBlock == nil {
109-
break
151+
152+
var x509Cert *x509.Certificate
153+
var certDerBytes [][]byte
154+
{
155+
certFilePath := filepath.Join(leadingDir, certificateFileName)
156+
pemBytes, err := ioutil.ReadFile(certFilePath)
157+
if err != nil {
158+
return nil, err
110159
}
111-
if pemBlock.Type == certificateBlockName {
112-
certDerBytes = pemBlock.Bytes
113-
break
160+
161+
for {
162+
var pemBlock *pem.Block
163+
pemBlock, pemBytes = pem.Decode(pemBytes)
164+
if pemBlock == nil {
165+
break
166+
}
167+
if pemBlock.Type == certificateBlockName {
168+
certDerBytes = append(certDerBytes, pemBlock.Bytes)
169+
}
170+
}
171+
if len(certDerBytes) == 0 {
172+
return nil, errors.Errorf("no valid certificate contained in %s", certFilePath)
173+
}
174+
175+
x509Cert, err = x509.ParseCertificate(certDerBytes[0])
176+
if err != nil {
177+
return nil, errors.Wrapf(err, "invalid certificate found in %s", certFilePath)
178+
}
179+
if len(certDerBytes) == 1 && c.issuerCert != nil {
180+
err = x509Cert.CheckSignatureFrom(c.issuerCert)
181+
if err != nil {
182+
return nil, errors.Wrapf(err, "invalid certificate found in %s", certFilePath)
183+
}
184+
}
185+
186+
if !now.Before(x509Cert.NotAfter) {
187+
return nil, errors.Errorf("ceritificate no longer valid (not after: %s, now: %s)", x509Cert.NotAfter.Local().Format(time.RFC1123), now.Local().Format(time.RFC1123))
114188
}
115189
}
116-
if certDerBytes == nil {
117-
return nil, errors.Errorf("no valid certificate contained in %s", path)
118-
}
119-
x509Cert, err := x509.ParseCertificate(certDerBytes)
120-
if err != nil {
121-
return nil, errors.Wrapf(err, "invalid certificate found in %s", path)
122-
}
123-
x509Cert.RawIssuer = c.issuerCert.Raw
124-
err = x509Cert.CheckSignatureFrom(c.issuerCert)
125-
if err != nil {
126-
return nil, errors.Wrapf(err, "invalid certificate found in %s", path)
190+
191+
var privKey crypto.PrivateKey
192+
{
193+
privKeyFilePath := filepath.Join(leadingDir, privateKeyFileName)
194+
pemBytes, err := ioutil.ReadFile(privKeyFilePath)
195+
if err != nil {
196+
if !os.IsNotExist(err) {
197+
return nil, err
198+
}
199+
}
200+
if err == nil {
201+
b, _ := pem.Decode(pemBytes)
202+
privKey, err = x509.ParsePKCS8PrivateKey(b.Bytes)
203+
if err != nil {
204+
return nil, errors.Wrapf(err, "failed to parse private key %s", privKeyFilePath)
205+
}
206+
} else {
207+
if c != nil {
208+
privKey = c.privateKey
209+
}
210+
err = nil
211+
}
127212
}
128-
if !now.Before(x509Cert.NotAfter) {
129-
return nil, errors.Errorf("ceritificate no longer valid (not after: %s, now: %s)", x509Cert.NotAfter.Local().Format(time.RFC1123), now.Local().Format(time.RFC1123))
213+
214+
if privKey == nil {
215+
return nil, errors.Errorf("no private key is available (cache is broken)")
130216
}
131217

132218
outer:
133219
for _, a := range hosts {
220+
dnsNameMatched := false
134221
for _, b := range x509Cert.DNSNames {
135-
if a == b {
222+
if wildMatch(b, a) {
223+
dnsNameMatched = true
136224
break outer
137225
}
138226
}
139-
return nil, errors.Errorf("certificate does not cover the host name %s", a)
227+
if !dnsNameMatched {
228+
dnsNameMatched = wildMatch(x509Cert.Subject.CommonName, a)
229+
}
230+
if !dnsNameMatched {
231+
return nil, errors.Errorf("certificate does not cover the host name %s", a)
232+
}
140233
}
141234

142235
return &tls.Certificate{
143-
Certificate: [][]byte{certDerBytes, c.issuerCert.Raw},
144-
PrivateKey: c.privateKey,
236+
Certificate: certDerBytes,
237+
PrivateKey: privKey,
145238
}, nil
146239
}
147240

241+
func (c *CertCache) evict(key string) error {
242+
c.Logger.Debugf("evicting cache foe %s", key)
243+
leadingDir, ok := c.buildPathToCachedCert(key)
244+
if !ok {
245+
return nil
246+
}
247+
return os.RemoveAll(leadingDir)
248+
}
249+
148250
func (c *CertCache) readCertificate(key string, hosts []string, now time.Time) (cert *tls.Certificate, err error) {
149251
cert, err = c.readAndValidateCertificate(
150252
key,
151253
hosts,
152254
now,
153255
)
154256
if err != nil {
257+
c.evict(key)
155258
c.Logger.Warn(err.Error())
156259
err = nil
157260
}
@@ -161,20 +264,27 @@ func (c *CertCache) readCertificate(key string, hosts []string, now time.Time) (
161264
func (c *CertCache) Put(hosts []string, cert *tls.Certificate) error {
162265
key := buildKeyString(hosts)
163266
c.certs[key] = cert
164-
return c.writeCertificate(key, cert)
267+
err := c.writeCertificate(key, cert)
268+
if err != nil {
269+
c.Logger.Warn(err.Error())
270+
err = nil
271+
}
272+
return err
165273
}
166274

167275
func (c *CertCache) Get(hosts []string, now time.Time) (cert *tls.Certificate, err error) {
168276
key := buildKeyString(hosts)
169277
cert, ok := c.certs[key]
170278
if !ok {
171-
c.Logger.Debug("Certificate not found in in-process cache")
279+
c.Logger.Debug("certificate not found in in-process cache")
172280
cert, err = c.readCertificate(key, hosts, now)
173281
if err != nil {
174282
return
175283
}
176284
if cert != nil {
177285
c.certs[key] = cert
286+
} else {
287+
c.Logger.Debug("certificate not found in cache directory")
178288
}
179289
}
180290
return

config.go

+33-1
Original file line numberDiff line numberDiff line change
@@ -70,13 +70,20 @@ type PerHostConfig struct {
7070
Patterns []Pattern
7171
}
7272

73+
type PreparedCertificate struct {
74+
Pattern *regexp.Regexp
75+
TLSCertificate *tls.Certificate
76+
Certificate *x509.Certificate
77+
}
78+
7379
type MITMConfig struct {
7480
ServerTLSConfigTemplate *tls.Config
7581
ClientTLSConfigTemplate *tls.Config
7682
SigningCertificateKeyPair struct {
7783
Certificate *x509.Certificate
7884
PrivateKey crypto.PrivateKey
7985
}
86+
Prepared []PreparedCertificate
8087
CacheDirectory string
8188
DisableCache bool
8289
}
@@ -140,7 +147,7 @@ func (ctx *ConfigReaderContext) extractPerHostConfigs(deref dereference) (perHos
140147
"hosts", func(urlStr string, hostMap dereference) error {
141148
url, err := url.Parse(urlStr)
142149
if err != nil {
143-
return errors.Wrapf(err, "invalid value for URL (%s) under %s", urlStr)
150+
return errors.Wrapf(err, "invalid value for URL (%s)", urlStr)
144151
}
145152
if url.Path != "" {
146153
return errors.Errorf("path may not be present: %s", urlStr)
@@ -735,6 +742,7 @@ func (ctx *ConfigReaderContext) extractTLSConfig(deref dereference, client bool)
735742
}
736743

737744
func (ctx *ConfigReaderContext) extractMITMConfig(deref dereference) (retval MITMConfig, err error) {
745+
retval.ServerTLSConfigTemplate = new(tls.Config)
738746
retval.ClientTLSConfigTemplate = new(tls.Config)
739747
err = deref.multi(
740748
"tls", func(deref dereference) error {
@@ -758,6 +766,29 @@ func (ctx *ConfigReaderContext) extractMITMConfig(deref dereference) (retval MIT
758766
retval.SigningCertificateKeyPair.PrivateKey = tlsCert.PrivateKey
759767
return nil
760768
},
769+
"prepared", func(_ int, deref dereference) error {
770+
visited := false
771+
return deref.iterateHomogeneousValuedMap(yamlMapType, func(hostPattern string, deref dereference) error {
772+
if visited {
773+
return errors.Errorf("extra item exists")
774+
}
775+
visited = true
776+
hostPatternRegexp, err := regexp.Compile(hostPattern)
777+
if err != nil {
778+
return errors.Errorf("invalid regexp %s", hostPattern)
779+
}
780+
tlsCert, cert, err := ctx.extractCertPrivateKeyPairs(deref)
781+
if err != nil {
782+
return err
783+
}
784+
retval.Prepared = append(retval.Prepared, PreparedCertificate{
785+
Pattern: hostPatternRegexp,
786+
TLSCertificate: &tlsCert,
787+
Certificate: cert,
788+
})
789+
return nil
790+
})
791+
},
761792
"cache_directory", func(cacheDirectory string) error {
762793
retval.CacheDirectory = cacheDirectory
763794
return nil
@@ -862,6 +893,7 @@ func loadConfig(yamlFile string, progname string) (*Config, error) {
862893
return nil, errors.Wrapf(err, "failed to load %s", yamlFile)
863894
}
864895
ctx := &ConfigReaderContext{
896+
filename: yamlFile,
865897
warn: func(msg string) {
866898
fmt.Fprintf(os.Stderr, "%s: %s\n", progname, msg)
867899
},

deref.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -778,7 +778,7 @@ func (deref dereference) homogeneousMapValue(typ reflect.Type) (interface{}, err
778778
}
779779
mapVal, ok := deref.value.(map[interface{}]interface{})
780780
if !ok {
781-
return nil, deref.errorf("mapping of %s expected, got %T", deref.value)
781+
return nil, deref.errorf("mapping of %s expected, got %T", typ.String(), deref.value)
782782
}
783783
hMapVal := reflect.MakeMap(reflect.MapOf(emptyInterfaceType, typ))
784784
for k, v := range mapVal {

example.yml

+8
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,14 @@ tls:
1515
ca:
1616
cert: testca.rsa.crt.pem
1717
key: testca.rsa.key.pem
18+
# MITM with prepared certificates
19+
prepared:
20+
- ^local\\.my-domain\\.example\\.com$:
21+
cert: certs/my-domain-cert.pem
22+
key: certs/my-domain-key.pem
23+
- .*:
24+
cert: real-certs/fallback-cert.pem
25+
key: real-certs/fallback-key.pem
1826

1927
# response filters
2028
response_filters:

go.mod

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ require (
66
github.com/pkg/errors v0.9.1
77
github.com/shibukawa/configdir v0.0.0-20170330084843-e180dbdc8da0
88
github.com/sirupsen/logrus v1.3.0
9+
github.com/stretchr/testify v1.2.2
910
golang.org/x/net v0.0.0-20190125091013-d26f9f9a57f3
1011
golang.org/x/text v0.3.0 // indirect
1112
gopkg.in/yaml.v2 v2.3.0

httpx/transport.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -1720,7 +1720,7 @@ type connectMethodKey struct {
17201720

17211721
func (k connectMethodKey) String() string {
17221722
// Only used by tests.
1723-
return fmt.Sprintf("%s|%s|%s|%p", k.proxy, k.scheme, k.addr, k.tlsConfigAddr)
1723+
return fmt.Sprintf("%s|%s|%s|%d", k.proxy, k.scheme, k.addr, k.tlsConfigAddr)
17241724
}
17251725

17261726
// persistConn wraps a connection, usually a persistent one

0 commit comments

Comments
 (0)