Skip to content

Commit 3f21624

Browse files
committed
add optimization of ACM certificate creation. Closes #452
It will now attempt to utilize any existing cert that can satisfy the domain, and will otherwise create one with a wildcard so that future domains will not require an additional cert
1 parent c82d529 commit 3f21624

File tree

5 files changed

+197
-19
lines changed

5 files changed

+197
-19
lines changed

internal/util/util.go

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -351,9 +351,7 @@ func Md5(s string) string {
351351
return hex.EncodeToString(h.Sum(nil))
352352
}
353353

354-
// Domain returns the effective domain. For example
355-
// the string "api.example.com" becomes "example.com",
356-
// while "api.example.co.uk" becomes "example.co.uk".
354+
// Domain returns the effective domain (TLD plus one).
357355
func Domain(s string) string {
358356
d, err := publicsuffix.EffectiveTLDPlusOne(s)
359357
if err != nil {
@@ -363,6 +361,41 @@ func Domain(s string) string {
363361
return d
364362
}
365363

364+
// CertDomainNames returns the certificate domain name
365+
// and alternative names for a requested domain.
366+
func CertDomainNames(s string) []string {
367+
// effective domain
368+
if Domain(s) == s {
369+
return []string{s, "*." + s}
370+
}
371+
372+
// subdomain
373+
return []string{RemoveSubdomains(s, 1), "*." + RemoveSubdomains(s, 1)}
374+
}
375+
376+
// IsWildcardDomain returns true if the domain is a wildcard.
377+
func IsWildcardDomain(s string) bool {
378+
return strings.HasPrefix(s, "*.")
379+
}
380+
381+
// WildcardMatches returns true if wildcard is a wildcard domain
382+
// and it satisfies the given domain.
383+
func WildcardMatches(wildcard, domain string) bool {
384+
if !IsWildcardDomain(wildcard) {
385+
return false
386+
}
387+
388+
w := RemoveSubdomains(wildcard, 1)
389+
d := RemoveSubdomains(domain, 1)
390+
return w == d
391+
}
392+
393+
// RemoveSubdomains returns the domain without the n left-most subdomain(s).
394+
func RemoveSubdomains(s string, n int) string {
395+
domains := strings.Split(s, ".")
396+
return strings.Join(domains[n:], ".")
397+
}
398+
366399
// ParseSections returns INI style sections from r.
367400
func ParseSections(r io.Reader) (sections []string, err error) {
368401
s := bufio.NewScanner(r)
@@ -378,3 +411,16 @@ func ParseSections(r io.Reader) (sections []string, err error) {
378411

379412
return
380413
}
414+
415+
// UniqueStrings returns a string slice of unique values.
416+
func UniqueStrings(s []string) (v []string) {
417+
m := make(map[string]struct{})
418+
for _, val := range s {
419+
_, ok := m[val]
420+
if !ok {
421+
v = append(v, val)
422+
m[val] = struct{}{}
423+
}
424+
}
425+
return
426+
}

internal/util/util_test.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,19 @@ func TestDomain(t *testing.T) {
8585
assert.Equal(t, "example.co.uk", Domain("v1.api.example.co.uk"))
8686
}
8787

88+
func TestCertDomainNames(t *testing.T) {
89+
assert.Equal(t, []string{"example.com", "*.example.com"}, CertDomainNames("example.com"))
90+
assert.Equal(t, []string{"example.com", "*.example.com"}, CertDomainNames("api.example.com"))
91+
assert.Equal(t, []string{"api.example.com", "*.api.example.com"}, CertDomainNames("v1.api.example.com"))
92+
}
93+
94+
func TestWildcardMatches(t *testing.T) {
95+
assert.True(t, WildcardMatches("*.api.example.com", "v1.api.example.com"))
96+
assert.True(t, WildcardMatches("*.example.com", "api.example.com"))
97+
assert.False(t, WildcardMatches("example.com", "api.example.com"))
98+
assert.False(t, WildcardMatches("*.api.example.com", "api.example.com"))
99+
}
100+
88101
func TestParseSections(t *testing.T) {
89102
r := strings.NewReader(`[personal]
90103
aws_access_key_id = personal_key

platform/lambda/lambda.go

Lines changed: 76 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"io/ioutil"
99
"os"
1010
"strings"
11+
"sync"
1112
"time"
1213

1314
"github.com/apex/log"
@@ -318,35 +319,34 @@ func (p *Platform) getHostedZone() (zones []*route53.HostedZone, err error) {
318319
func (p *Platform) createCerts() error {
319320
s := session.New(aws.NewConfig().WithRegion("us-east-1"))
320321
a := acm.New(s)
322+
var domains []string
321323

322324
// existing certs
323-
res, err := a.ListCertificates(&acm.ListCertificatesInput{
324-
MaxItems: aws.Int64(1000),
325-
})
326-
325+
log.Debug("fetching existing certs")
326+
certs, err := getCerts(a)
327327
if err != nil {
328-
return errors.Wrap(err, "listing")
328+
return errors.Wrap(err, "fetching certs")
329329
}
330330

331-
var domains []string
332-
333331
// request certs
334332
for _, s := range p.config.Stages.List() {
335333
if s == nil {
336334
continue
337335
}
338336

337+
certDomains := util.CertDomainNames(s.Domain)
338+
339339
// see if the cert exists
340340
log.Debugf("looking up cert for %s", s.Domain)
341-
arn := getCert(res.CertificateSummaryList, s.Domain)
341+
arn := getCert(certs, s.Domain)
342342
if arn != "" {
343343
log.Debugf("found cert for %s: %s", s.Domain, arn)
344344
s.Cert = arn
345345
continue
346346
}
347347

348348
option := acm.DomainValidationOption{
349-
DomainName: &s.Domain,
349+
DomainName: aws.String(certDomains[0]),
350350
ValidationDomain: aws.String(util.Domain(s.Domain)),
351351
}
352352

@@ -356,15 +356,16 @@ func (p *Platform) createCerts() error {
356356

357357
// request the cert
358358
res, err := a.RequestCertificate(&acm.RequestCertificateInput{
359-
DomainName: &s.Domain,
359+
DomainName: aws.String(certDomains[0]),
360360
DomainValidationOptions: options,
361+
SubjectAlternativeNames: aws.StringSlice(certDomains[1:]),
361362
})
362363

363364
if err != nil {
364-
return errors.Wrapf(err, "requesting cert for %s", s.Domain)
365+
return errors.Wrapf(err, "requesting cert for %v", certDomains)
365366
}
366367

367-
domains = append(domains, s.Domain)
368+
domains = append(domains, certDomains[0])
368369
s.Cert = *res.CertificateArn
369370
}
370371

@@ -379,7 +380,7 @@ func (p *Platform) createCerts() error {
379380

380381
// wait for approval
381382
for range time.Tick(4 * time.Second) {
382-
res, err = a.ListCertificates(&acm.ListCertificatesInput{
383+
res, err := a.ListCertificates(&acm.ListCertificatesInput{
383384
MaxItems: aws.Int64(1000),
384385
CertificateStatuses: aws.StringSlice([]string{acm.CertificateStatusPendingValidation}),
385386
})
@@ -787,12 +788,72 @@ func toEnv(env config.Environment, stage string) *lambda.Environment {
787788
}
788789
}
789790

790-
// getCert returns the ARN if the cert is present.
791-
func getCert(certs []*acm.CertificateSummary, domain string) string {
791+
// getCerts returns the certificates available.
792+
func getCerts(a *acm.ACM) (certs []*acm.CertificateDetail, err error) {
793+
var g errgroup.Group
794+
var mu sync.Mutex
795+
796+
res, err := a.ListCertificates(&acm.ListCertificatesInput{
797+
MaxItems: aws.Int64(1000),
798+
})
799+
800+
if err != nil {
801+
return nil, errors.Wrap(err, "listing")
802+
}
803+
804+
for _, c := range res.CertificateSummaryList {
805+
c := c
806+
g.Go(func() error {
807+
res, err := a.DescribeCertificate(&acm.DescribeCertificateInput{
808+
CertificateArn: c.CertificateArn,
809+
})
810+
811+
if err != nil {
812+
return errors.Wrap(err, "describing")
813+
}
814+
815+
mu.Lock()
816+
certs = append(certs, res.Certificate)
817+
mu.Unlock()
818+
return nil
819+
})
820+
}
821+
822+
err = g.Wait()
823+
return
824+
}
825+
826+
// getCert returns the ARN of a certificate with can satisfy domain,
827+
// favoring more specific certificates, then falling back on wildcards.
828+
func getCert(certs []*acm.CertificateDetail, domain string) string {
829+
// exact domain
792830
for _, c := range certs {
793831
if *c.DomainName == domain {
794832
return *c.CertificateArn
795833
}
796834
}
835+
836+
// exact alt
837+
for _, c := range certs {
838+
for _, a := range c.SubjectAlternativeNames {
839+
if *a == domain {
840+
return *c.CertificateArn
841+
}
842+
}
843+
}
844+
845+
// wildcards
846+
for _, c := range certs {
847+
if util.WildcardMatches(*c.DomainName, domain) {
848+
return *c.CertificateArn
849+
}
850+
851+
for _, a := range c.SubjectAlternativeNames {
852+
if util.WildcardMatches(*a, domain) {
853+
return *c.CertificateArn
854+
}
855+
}
856+
}
857+
797858
return ""
798859
}

platform/lambda/lambda_test.go

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
package lambda
2+
3+
import (
4+
"testing"
5+
6+
"github.com/aws/aws-sdk-go/aws"
7+
"github.com/aws/aws-sdk-go/service/acm"
8+
"github.com/tj/assert"
9+
)
10+
11+
func TestGetCert(t *testing.T) {
12+
certs := []*acm.CertificateDetail{
13+
{
14+
DomainName: aws.String("example.com"),
15+
CertificateArn: aws.String("arn:example.com"),
16+
SubjectAlternativeNames: aws.StringSlice([]string{
17+
"*.example.com",
18+
}),
19+
},
20+
{
21+
DomainName: aws.String("*.apex.sh"),
22+
CertificateArn: aws.String("arn:*.apex.sh"),
23+
},
24+
{
25+
DomainName: aws.String("api.example.com"),
26+
CertificateArn: aws.String("arn:api.example.com"),
27+
SubjectAlternativeNames: aws.StringSlice([]string{
28+
"*.api.example.com",
29+
"something.example.com",
30+
}),
31+
},
32+
}
33+
34+
arn := getCert(certs, "example.com")
35+
assert.Equal(t, "arn:example.com", arn)
36+
37+
arn = getCert(certs, "www.example.com")
38+
assert.Equal(t, "arn:example.com", arn)
39+
40+
arn = getCert(certs, "api.example.com")
41+
assert.Equal(t, "arn:api.example.com", arn)
42+
43+
arn = getCert(certs, "apex.sh")
44+
assert.Empty(t, arn)
45+
46+
arn = getCert(certs, "api.apex.sh")
47+
assert.Equal(t, "arn:*.apex.sh", arn)
48+
49+
arn = getCert(certs, "v1.api.example.com")
50+
assert.Equal(t, "arn:api.example.com", arn)
51+
52+
arn = getCert(certs, "something.example.com")
53+
assert.Equal(t, "arn:api.example.com", arn)
54+
55+
arn = getCert(certs, "staging.v1.api.example.com")
56+
assert.Empty(t, arn)
57+
}

reporter/text/text.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,11 +203,12 @@ func (r *reporter) Start() {
203203
}
204204
fmt.Printf("\n")
205205
case "platform.certs.create":
206-
domains := e.Fields["domains"].([]string)
206+
domains := util.UniqueStrings(e.Fields["domains"].([]string))
207207
r.log("domains", "Check your email to approve the certificate")
208208
r.pending("confirm", strings.Join(domains, ", "))
209209
case "platform.certs.create.complete":
210210
r.complete("confirm", "complete", e.Duration("duration"))
211+
fmt.Printf("\n")
211212
case "metrics", "metrics.complete":
212213
fmt.Printf("\n")
213214
case "metrics.value":

0 commit comments

Comments
 (0)