Skip to content

Commit e822453

Browse files
authored
Use context from ClientHello during GetCertificate (#249)
* Use context from ClientHello during GetCertificate (see #247)
* Avoid recursive ops during on-demand issuance
1 parent 5bca6d1 commit e822453

File tree

1 file changed

+32
-33
lines changed

1 file changed

+32
-33
lines changed

handshake.go

+32-33
Original file line number | Diff line number | Diff line change
@@ -46,8 +46,7 @@ import (
4646
// GetCertificate will run in a new context, use GetCertificateWithContext to provide
4747
// a context.
4848
func (cfg *Config) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
49-
ctx := context.TODO() // TODO: get a proper context? from somewhere...
50-
return cfg.GetCertificateWithContext(ctx, clientHello)
49+
return cfg.GetCertificateWithContext(clientHello.Context(), clientHello)
5150
}
5251

5352
func (cfg *Config) GetCertificateWithContext(ctx context.Context, clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
@@ -276,15 +275,15 @@ func (cfg *Config) getCertDuringHandshake(ctx context.Context, hello *tls.Client
276275
name := cfg.getNameFromClientHello(hello)
277276

278277
// By this point, we need to load or obtain a certificate. If a swarm of requests comes in for the same
279-
// domain, avoid pounding manager or storage thousands of times simultaneously. We do a similar sync
278+
// domain, avoid pounding manager or storage thousands of times simultaneously. We use a similar sync
280279
// strategy for obtaining certificate during handshake.
281280
certLoadWaitChansMu.Lock()
282281
wait, ok := certLoadWaitChans[name]
283282
if ok {
284283
// another goroutine is already loading the cert; just wait and we'll get it from the in-memory cache
285284
certLoadWaitChansMu.Unlock()
286285

287-
timeout := time.NewTimer(2 * time.Minute) // TODO: have Caddy use the context param to establish a timeout
286+
timeout := time.NewTimer(2 * time.Minute)
288287
select {
289288
case <-timeout.C:
290289
return Certificate{}, fmt.Errorf("timed out waiting to load certificate for %s", name)
@@ -480,6 +479,9 @@ func (cfg *Config) obtainOnDemandCertificate(ctx context.Context, hello *tls.Cli
480479
// wait for it to finish obtaining the cert and then we'll use it.
481480
obtainCertWaitChansMu.Unlock()
482481

482+
log.Debug("new certificate is needed, but is already being obtained; waiting for that issuance to complete",
483+
zap.String("subject", name))
484+
483485
// TODO: see if we can get a proper context in here, for true cancellation
484486
timeout := time.NewTimer(2 * time.Minute)
485487
select {
@@ -489,7 +491,9 @@ func (cfg *Config) obtainOnDemandCertificate(ctx context.Context, hello *tls.Cli
489491
timeout.Stop()
490492
}
491493

492-
return cfg.loadCertFromStorage(ctx, log, hello)
494+
// it should now be loaded in the cache, ready to go; if not,
495+
// the goroutine in charge of that probably had an error
496+
return cfg.getCertDuringHandshake(ctx, hello, false)
493497
}
494498

495499
// looks like it's up to us to do all the work and obtain the cert.
@@ -507,28 +511,28 @@ func (cfg *Config) obtainOnDemandCertificate(ctx context.Context, hello *tls.Cli
507511

508512
log.Info("obtaining new certificate", zap.String("server_name", name))
509513

510-
// TODO: we are only adding a timeout because we don't know if the context passed in is actually cancelable...
514+
// set a timeout so we don't inadvertently hold a client handshake open too long
511515
// (timeout duration is based on https://caddy.community/t/zerossl-dns-challenge-failing-often-route53-plugin/13822/24?u=matt)
512516
var cancel context.CancelFunc
513517
ctx, cancel = context.WithTimeout(ctx, 180*time.Second)
514518
defer cancel()
515519

516-
// Obtain the certificate
520+
// obtain the certificate (this puts it in storage) and if successful,
521+
// load it from storage so we and any other waiting goroutine can use it
522+
var cert Certificate
517523
err := cfg.ObtainCertAsync(ctx, name)
524+
if err == nil {
525+
// load from storage while others wait to make the op as atomic as possible
526+
cert, err = cfg.loadCertFromStorage(ctx, log, hello)
527+
if err != nil {
528+
log.Error("loading newly-obtained certificate from storage", zap.String("server_name", name), zap.Error(err))
529+
}
530+
}
518531

519-
// immediately unblock anyone waiting for it; doing this in
520-
// a defer would risk deadlock because of the recursive call
521-
// to getCertDuringHandshake below when we return!
532+
// immediately unblock anyone waiting for it
522533
unblockWaiters()
523534

524-
if err != nil {
525-
// shucks; failed to solve challenge on-demand
526-
return Certificate{}, err
527-
}
528-
529-
// success; certificate was just placed on disk, so
530-
// we need only restart serving the certificate
531-
return cfg.loadCertFromStorage(ctx, log, hello)
535+
return cert, err
532536
}
533537

534538
// handshakeMaintenance performs a check on cert for expiration and OCSP validity.
@@ -611,7 +615,7 @@ func (cfg *Config) handshakeMaintenance(ctx context.Context, hello *tls.ClientHe
611615
//
612616
// This function is safe for use by multiple concurrent goroutines.
613617
func (cfg *Config) renewDynamicCertificate(ctx context.Context, hello *tls.ClientHelloInfo, currentCert Certificate) (Certificate, error) {
614-
log := cfg.Logger.Named("on_demand")
618+
log := logWithRemote(cfg.Logger.Named("on_demand"), hello)
615619

616620
name := cfg.getNameFromClientHello(hello)
617621
timeLeft := time.Until(expiresAt(currentCert.Leaf))
@@ -651,7 +655,9 @@ func (cfg *Config) renewDynamicCertificate(ctx context.Context, hello *tls.Clien
651655
timeout.Stop()
652656
}
653657

654-
return cfg.loadCertFromStorage(ctx, log, hello)
658+
// it should now be loaded in the cache, ready to go; if not,
659+
// the goroutine in charge of that probably had an error
660+
return cfg.getCertDuringHandshake(ctx, hello, false)
655661
}
656662

657663
// looks like it's up to us to do all the work and renew the cert
@@ -703,16 +709,8 @@ func (cfg *Config) renewDynamicCertificate(ctx context.Context, hello *tls.Clien
703709
} else {
704710
err = cfg.RenewCertAsync(ctx, name, false)
705711
if err == nil {
706-
// even though the recursive nature of the dynamic cert loading
707-
// would just call this function anyway, we do it here to
708-
// make the replacement as atomic as possible.
709-
newCert, err = cfg.CacheManagedCertificate(ctx, name)
710-
if err != nil {
711-
log.Error("loading renewed certificate", zap.String("server_name", name), zap.Error(err))
712-
} else {
713-
// replace the old certificate with the new one
714-
cfg.certCache.replaceCertificate(currentCert, newCert)
715-
}
712+
// load from storage while in lock to make the replacement as atomic as possible
713+
newCert, err = cfg.reloadManagedCertificate(ctx, currentCert)
716714
}
717715
}
718716

@@ -722,11 +720,10 @@ func (cfg *Config) renewDynamicCertificate(ctx context.Context, hello *tls.Clien
722720
unblockWaiters()
723721

724722
if err != nil {
725-
log.Error("renewing and reloading certificate", zap.Error(err))
726-
return newCert, err
723+
log.Error("renewing and reloading certificate", zap.String("server_name", name), zap.Error(err))
727724
}
728725

729-
return cfg.loadCertFromStorage(ctx, log, hello)
726+
return newCert, err
730727
}
731728

732729
// if the certificate hasn't expired, we can serve what we have and renew in the background
@@ -872,6 +869,8 @@ var (
872869
obtainCertWaitChans = make(map[string]chan struct{})
873870
obtainCertWaitChansMu sync.Mutex
874871
)
872+
873+
// TODO: this lockset should probably be per-cache
875874
var (
876875
certLoadWaitChans = make(map[string]chan struct{})
877876
certLoadWaitChansMu sync.Mutex

0 commit comments

Comments
 (0)