|
| 1 | +package api |
| 2 | + |
| 3 | +import ( |
| 4 | + "crypto/tls" |
| 5 | + "os" |
| 6 | + "path/filepath" |
| 7 | + "sync" |
| 8 | + "time" |
| 9 | + |
| 10 | + "github.com/fsnotify/fsnotify" |
| 11 | + |
| 12 | + "zotregistry.dev/zot/v2/pkg/log" |
| 13 | +) |
| 14 | + |
| 15 | +const ( |
| 16 | + // certCheckCacheDuration is the minimum time between file stat checks when fsnotify is unavailable. |
| 17 | + // This prevents excessive file system calls during high TLS handshake rates. |
| 18 | + certCheckCacheDuration = 1 * time.Second |
| 19 | +) |
| 20 | + |
| 21 | +// CertReloader handles automatic reloading of TLS certificates without downtime. |
| 22 | +// It monitors certificate and key files for changes and reloads them dynamically |
| 23 | +// using a GetCertificate callback in tls.Config. |
| 24 | +type CertReloader struct { |
| 25 | + certMu sync.RWMutex |
| 26 | + cert *tls.Certificate |
| 27 | + certPath string |
| 28 | + keyPath string |
| 29 | + certMod time.Time |
| 30 | + keyMod time.Time |
| 31 | + log log.Logger |
| 32 | + watcher *fsnotify.Watcher |
| 33 | + reloadMu sync.Mutex // Prevents concurrent reload operations |
| 34 | + lastCheck time.Time |
| 35 | + checkCache time.Duration // Minimum time between file stat checks |
| 36 | + stopWatcher chan struct{} |
| 37 | +} |
| 38 | + |
| 39 | +// NewCertReloader creates a new certificate reloader and loads the initial certificate. |
| 40 | +// It starts an fsnotify watcher to monitor certificate file changes. |
| 41 | +func NewCertReloader(certPath, keyPath string, logger log.Logger) (*CertReloader, error) { |
| 42 | + reloader := &CertReloader{ |
| 43 | + certPath: certPath, |
| 44 | + keyPath: keyPath, |
| 45 | + log: logger, |
| 46 | + checkCache: certCheckCacheDuration, |
| 47 | + stopWatcher: make(chan struct{}), |
| 48 | + } |
| 49 | + |
| 50 | + if err := reloader.reload(); err != nil { |
| 51 | + return nil, err |
| 52 | + } |
| 53 | + |
| 54 | + // Start fsnotify watcher in background |
| 55 | + if err := reloader.startWatcher(); err != nil { |
| 56 | + // Log warning but don't fail - we'll fall back to periodic checking |
| 57 | + logger.Warn().Err(err).Msg("failed to start fsnotify watcher, falling back to periodic checking") |
| 58 | + } |
| 59 | + |
| 60 | + return reloader, nil |
| 61 | +} |
| 62 | + |
| 63 | +// Close stops the file watcher and releases resources. |
| 64 | +func (cr *CertReloader) Close() error { |
| 65 | + if cr.stopWatcher != nil { |
| 66 | + close(cr.stopWatcher) |
| 67 | + } |
| 68 | + |
| 69 | + if cr.watcher != nil { |
| 70 | + return cr.watcher.Close() |
| 71 | + } |
| 72 | + |
| 73 | + return nil |
| 74 | +} |
| 75 | + |
| 76 | +// startWatcher initializes the fsnotify watcher for certificate files. |
| 77 | +func (cr *CertReloader) startWatcher() error { |
| 78 | + watcher, err := fsnotify.NewWatcher() |
| 79 | + if err != nil { |
| 80 | + return err |
| 81 | + } |
| 82 | + |
| 83 | + cr.watcher = watcher |
| 84 | + |
| 85 | + // Watch the directory containing the certificate files |
| 86 | + // This is more reliable than watching files directly, especially for atomic file updates |
| 87 | + certDir := filepath.Dir(cr.certPath) |
| 88 | + keyDir := filepath.Dir(cr.keyPath) |
| 89 | + |
| 90 | + if err := watcher.Add(certDir); err != nil { |
| 91 | + return err |
| 92 | + } |
| 93 | + |
| 94 | + // If cert and key are in different directories, watch both |
| 95 | + if certDir != keyDir { |
| 96 | + if err := watcher.Add(keyDir); err != nil { |
| 97 | + return err |
| 98 | + } |
| 99 | + } |
| 100 | + |
| 101 | + // Start goroutine to handle file system events |
| 102 | + go cr.watchLoop() |
| 103 | + |
| 104 | + return nil |
| 105 | +} |
| 106 | + |
| 107 | +// watchLoop handles file system events from fsnotify. |
| 108 | +func (cr *CertReloader) watchLoop() { |
| 109 | + for { |
| 110 | + select { |
| 111 | + case <-cr.stopWatcher: |
| 112 | + return |
| 113 | + case event, ok := <-cr.watcher.Events: |
| 114 | + if !ok { |
| 115 | + return |
| 116 | + } |
| 117 | + |
| 118 | + // Check if the event is for our certificate or key files |
| 119 | + if event.Name == cr.certPath || event.Name == cr.keyPath { |
| 120 | + // Only process write and create events |
| 121 | + if event.Op&(fsnotify.Write|fsnotify.Create) != 0 { |
| 122 | + cr.log.Debug().Str("file", event.Name).Str("op", event.Op.String()). |
| 123 | + Msg("certificate file change detected") |
| 124 | + |
| 125 | + // Try to reload the certificate |
| 126 | + cr.tryReload() |
| 127 | + } |
| 128 | + } |
| 129 | + case err, ok := <-cr.watcher.Errors: |
| 130 | + if !ok { |
| 131 | + return |
| 132 | + } |
| 133 | + cr.log.Warn().Err(err).Msg("fsnotify watcher error") |
| 134 | + } |
| 135 | + } |
| 136 | +} |
| 137 | + |
| 138 | +// tryReload attempts to reload certificates with proper concurrency control. |
| 139 | +func (cr *CertReloader) tryReload() { |
| 140 | + // Use mutex to ensure only one reload happens at a time |
| 141 | + // This prevents race condition where multiple goroutines detect changes simultaneously |
| 142 | + cr.reloadMu.Lock() |
| 143 | + defer cr.reloadMu.Unlock() |
| 144 | + |
| 145 | + if err := cr.reload(); err != nil { |
| 146 | + cr.log.Warn().Err(err).Str("cert", cr.certPath).Str("key", cr.keyPath). |
| 147 | + Msg("failed to reload TLS certificates") |
| 148 | + } else { |
| 149 | + cr.log.Info().Str("cert", cr.certPath).Str("key", cr.keyPath). |
| 150 | + Msg("TLS certificates reloaded successfully") |
| 151 | + } |
| 152 | +} |
| 153 | + |
| 154 | +// reload loads the certificate and key from disk and updates the internal certificate. |
| 155 | +func (cr *CertReloader) reload() error { |
| 156 | + // Get file modification times |
| 157 | + certInfo, err := os.Stat(cr.certPath) |
| 158 | + if err != nil { |
| 159 | + return err |
| 160 | + } |
| 161 | + |
| 162 | + keyInfo, err := os.Stat(cr.keyPath) |
| 163 | + if err != nil { |
| 164 | + return err |
| 165 | + } |
| 166 | + |
| 167 | + certMod := certInfo.ModTime() |
| 168 | + keyMod := keyInfo.ModTime() |
| 169 | + |
| 170 | + // Load the certificate |
| 171 | + newCert, err := tls.LoadX509KeyPair(cr.certPath, cr.keyPath) |
| 172 | + if err != nil { |
| 173 | + return err |
| 174 | + } |
| 175 | + |
| 176 | + // Update the certificate and modification times |
| 177 | + cr.certMu.Lock() |
| 178 | + defer cr.certMu.Unlock() |
| 179 | + |
| 180 | + cr.cert = &newCert |
| 181 | + cr.certMod = certMod |
| 182 | + cr.keyMod = keyMod |
| 183 | + |
| 184 | + return nil |
| 185 | +} |
| 186 | + |
| 187 | +// maybeReload checks if the certificate files have been modified and reloads them if necessary. |
| 188 | +// This is used as a fallback when fsnotify is not available or fails. |
| 189 | +// Uses time-based caching to avoid excessive file system calls. |
| 190 | +func (cr *CertReloader) maybeReload() error { |
| 191 | + // Use time-based cache to reduce frequency of stat calls |
| 192 | + cr.certMu.RLock() |
| 193 | + if time.Since(cr.lastCheck) < cr.checkCache { |
| 194 | + // Recently checked, skip stat calls |
| 195 | + cr.certMu.RUnlock() |
| 196 | + |
| 197 | + return nil |
| 198 | + } |
| 199 | + cr.certMu.RUnlock() |
| 200 | + |
| 201 | + // Update last check time |
| 202 | + cr.certMu.Lock() |
| 203 | + cr.lastCheck = time.Now() |
| 204 | + cr.certMu.Unlock() |
| 205 | + |
| 206 | + // Check cert file modification time |
| 207 | + certInfo, err := os.Stat(cr.certPath) |
| 208 | + if err != nil { |
| 209 | + return err |
| 210 | + } |
| 211 | + |
| 212 | + keyInfo, err := os.Stat(cr.keyPath) |
| 213 | + if err != nil { |
| 214 | + return err |
| 215 | + } |
| 216 | + |
| 217 | + certMod := certInfo.ModTime() |
| 218 | + keyMod := keyInfo.ModTime() |
| 219 | + |
| 220 | + // Check if files have been modified |
| 221 | + cr.certMu.RLock() |
| 222 | + needsReload := certMod.After(cr.certMod) || keyMod.After(cr.keyMod) |
| 223 | + cr.certMu.RUnlock() |
| 224 | + |
| 225 | + if needsReload { |
| 226 | + // Use reloadMu to prevent concurrent reload operations |
| 227 | + cr.reloadMu.Lock() |
| 228 | + defer cr.reloadMu.Unlock() |
| 229 | + |
| 230 | + // Double-check after acquiring lock - another goroutine might have already reloaded |
| 231 | + cr.certMu.RLock() |
| 232 | + stillNeedsReload := certMod.After(cr.certMod) || keyMod.After(cr.keyMod) |
| 233 | + cr.certMu.RUnlock() |
| 234 | + |
| 235 | + if stillNeedsReload { |
| 236 | + if err := cr.reload(); err != nil { |
| 237 | + cr.log.Warn().Err(err).Str("cert", cr.certPath).Str("key", cr.keyPath). |
| 238 | + Msg("failed to reload TLS certificates") |
| 239 | + |
| 240 | + return err |
| 241 | + } |
| 242 | + |
| 243 | + cr.log.Info().Str("cert", cr.certPath).Str("key", cr.keyPath). |
| 244 | + Msg("TLS certificates reloaded successfully") |
| 245 | + } |
| 246 | + } |
| 247 | + |
| 248 | + return nil |
| 249 | +} |
| 250 | + |
| 251 | +// GetCertificateFunc returns a function that can be used as tls.Config.GetCertificate. |
| 252 | +// This function checks for certificate updates on each TLS handshake and reloads if necessary. |
| 253 | +// If fsnotify watcher is active, this only performs time-cached checks as a fallback. |
| 254 | +func (cr *CertReloader) GetCertificateFunc() func(*tls.ClientHelloInfo) (*tls.Certificate, error) { |
| 255 | + return func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) { |
| 256 | + // Try to reload the certificate if it has changed |
| 257 | + // This is a fallback mechanism when fsnotify is not available |
| 258 | + // Errors are logged but ignored to maintain availability with existing certificate |
| 259 | + _ = cr.maybeReload() |
| 260 | + |
| 261 | + cr.certMu.RLock() |
| 262 | + defer cr.certMu.RUnlock() |
| 263 | + |
| 264 | + return cr.cert, nil |
| 265 | + } |
| 266 | +} |
0 commit comments