Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions internal/aghtest/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@ package aghtest

import (
"context"
"crypto/tls"
"crypto/x509"
"net/http"
"net/netip"
"time"

"github.com/AdguardTeam/AdGuardHome/internal/agh"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
nextagh "github.com/AdguardTeam/AdGuardHome/internal/next/agh"
"github.com/AdguardTeam/AdGuardHome/internal/rdns"
"github.com/AdguardTeam/AdGuardHome/internal/whois"
Expand Down Expand Up @@ -198,3 +201,26 @@ var _ aghhttp.Registrar = (*Registrar)(nil)
func (m *Registrar) Register(method, path string, h http.HandlerFunc) {
m.OnRegister(method, path, h)
}

// TLSConfigProvider is a fake [aghtls.TLSConfigProvider] implementation for
// tests.
// TODO(m.kazantsev): Use in tests.
type TLSConfigProvider struct {
OnTLSConfig func() (conf *tls.Config)
OnRootCAs func() (cert *x509.CertPool)
}

// type check
var _ aghtls.TLSConfigProvider = (*TLSConfigProvider)(nil)

// TLSConfig implements the [aghtls.TLSConfigProvider] interface for
// *TLSConfigProvider.
func (t *TLSConfigProvider) TLSConfig() (conf *tls.Config) {
return t.OnTLSConfig()
}

// RootCAs implements the [aghtls.TLSConfigProvider] interface for
// *TLSConfigProvider.
func (t *TLSConfigProvider) RootCAs() (pool *x509.CertPool) {
return t.OnRootCAs()
}
39 changes: 39 additions & 0 deletions internal/aghtls/configprovider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package aghtls

import (
"crypto/tls"
"crypto/x509"
)

// TLSConfigProvider provides TLS configuration to consumers. Implementations
// must be safe for concurrent use.
//
// TODO(m.kazantsev): Merge with the Manager interface.
// TODO(m.kazantsev): Add at least one real implementation.
type TLSConfigProvider interface {
// TLSConfig returns a clone of the current TLS configuration. conf
// provides its certificates via GetConfigForClient method.
TLSConfig() (conf *tls.Config)

// RootCAs returns the current root CA pool.
RootCAs() (root *x509.CertPool)
}

// type check
var _ TLSConfigProvider = EmptyTLSConfigProvider{}

// EmptyTLSConfigProvider is the implementation of the [TLSConfigProvider]
// interface that does nothing.
type EmptyTLSConfigProvider struct{}

// TLSConfig implements the [TLSConfigProvider] interface for
// EmptyTLSConfigProvider. It always returns nil.
func (EmptyTLSConfigProvider) TLSConfig() (conf *tls.Config) {
return nil
}

// RootCAs implements the [TLSConfigProvider] interface for
// EmptyTLSConfigProvider. It always returns nil.
func (EmptyTLSConfigProvider) RootCAs() (root *x509.CertPool) {
return nil
}
21 changes: 15 additions & 6 deletions internal/dnsforward/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,12 @@ func (conf *ServerConfig) loadUpstreams(

upstreams = stringutil.SplitTrimmed(string(data), "\n")

l.DebugContext(ctx, "got upstreams", "number", len(upstreams), "filename", conf.UpstreamDNSFileName)
l.DebugContext(
ctx,
"got upstreams",
"number", len(upstreams),
"filename", conf.UpstreamDNSFileName,
)

return stringutil.FilterOut(upstreams, aghnet.IsCommentOrEmpty), nil
}
Expand Down Expand Up @@ -652,7 +657,10 @@ func filterOutAddrs(upsConf *proxy.UpstreamConfig, set addrPortSet) (err error)

// ourAddrsSet returns an addrPortSet that contains all the configured listening
// addresses. l must not be nil.
func (conf *ServerConfig) ourAddrsSet(ctx context.Context, l *slog.Logger) (m addrPortSet, err error) {
func (conf *ServerConfig) ourAddrsSet(
ctx context.Context,
l *slog.Logger,
) (m addrPortSet, err error) {
addrs, unspecPorts := conf.collectDNSAddrs()
switch {
case addrs.Len() == 0:
Expand Down Expand Up @@ -781,8 +789,9 @@ func anyNameMatches(dnsNames []string, sni string) (ok bool) {
return false
}

// Called by 'tls' package when Client Hello is received
// If the server name (from SNI) supplied by client is incorrect - we terminate the ongoing TLS handshake.
// onGetCertificate is called by [tls] package when Client Hello is received. If
// the server name (from SNI) supplied by client is incorrect - we terminate the
// ongoing TLS handshake.
func (s *Server) onGetCertificate(ch *tls.ClientHelloInfo) (*tls.Certificate, error) {
if s.conf.TLSConf.StrictSNICheck && !anyNameMatches(s.dnsNames, ch.ServerName) {
// TODO(s.chzhen): Pass context.
Expand All @@ -798,8 +807,8 @@ func (s *Server) onGetCertificate(ch *tls.ClientHelloInfo) (*tls.Certificate, er
return s.conf.TLSConf.Cert, nil
}

// preparePlain prepares the plain-DNS configuration for the DNS proxy.
// preparePlain assumes that prepareTLS has already been called.
// preparePlain prepares the plain-DNS configuration for the DNS proxy. The
// method assumes that prepareTLS has already been called.
func (s *Server) preparePlain(ctx context.Context, proxyConf *proxy.Config) (err error) {
if s.conf.ServePlainDNS {
proxyConf.UDPListenAddr = s.conf.UDPListenAddrs
Expand Down
9 changes: 7 additions & 2 deletions internal/dnsforward/configvalidator.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,12 @@ func newUpstreamConfigValidator(
// collectErrResults parses err and returns parsing results containing the
// original upstream configuration line and the corresponding error. err can be
// nil. l must not be nil.
func collectErrResults(ctx context.Context, l *slog.Logger, lines []string, err error) (results []*parseResult) {
func collectErrResults(
ctx context.Context,
l *slog.Logger,
lines []string,
err error,
) (results []*parseResult) {
if err == nil {
return nil
}
Expand Down Expand Up @@ -132,7 +137,7 @@ func collectErrResults(ctx context.Context, l *slog.Logger, lines []string, err
}

// insertConfResults parses conf and inserts the upstream result into results.
// It can insert multiple results as well as none.
// It can insert multiple results as well as none. conf must not be nil.
func insertConfResults(conf *proxy.UpstreamConfig, results map[string]*upstreamResult) {
insertListResults(conf.Upstreams, results, false)

Expand Down
4 changes: 2 additions & 2 deletions internal/home/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -884,8 +884,8 @@ func (c *configuration) write(
}

if tlsMgr != nil {
tlsConf := tlsMgr.config()
config.TLS = *tlsConf
extTLSConf := tlsMgr.extendedTLSConfig()
config.TLS = *extTLSConf
}

if globalContext.stats != nil {
Expand Down
16 changes: 11 additions & 5 deletions internal/home/controlupdate.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,10 @@ type versionResponse struct {
Disabled bool `json:"disabled"`
}

// maxPrivilegedPort is the maximum port number. This only applies to Unix, as
// on Windows, [aghnet.CanBindPrivilegedPorts] always returns `true`, `nil`.
const maxPrivilegedPort = 1024

// setAllowedToAutoUpdate sets CanAutoUpdate to true if AdGuard Home is actually
// allowed to perform an automatic update by the OS. l and tlsMgr must not be
// nil.
Expand All @@ -191,9 +195,9 @@ func (vr *versionResponse) setAllowedToAutoUpdate(
}

canUpdate := true
if tlsConfUsesPrivilegedPorts(tlsMgr.config()) ||
config.HTTPConfig.Address.Port() < 1024 ||
config.DNS.Port < 1024 {
if tlsConfUsesPrivilegedPorts(tlsMgr.extendedTLSConfig()) ||
config.HTTPConfig.Address.Port() < maxPrivilegedPort ||
config.DNS.Port < maxPrivilegedPort {
canUpdate, err = aghnet.CanBindPrivilegedPorts(ctx, l)
if err != nil {
return fmt.Errorf("checking ability to bind privileged ports: %w", err)
Expand All @@ -206,9 +210,11 @@ func (vr *versionResponse) setAllowedToAutoUpdate(
}

// tlsConfUsesPrivilegedPorts returns true if the provided TLS configuration
// indicates that privileged ports are used.
// indicates that privileged ports are used. c must be valid
func tlsConfUsesPrivilegedPorts(c *tlsConfigSettings) (ok bool) {
return c.Enabled && (c.PortHTTPS < 1024 || c.PortDNSOverTLS < 1024 || c.PortDNSOverQUIC < 1024)
return c.Enabled && (c.PortHTTPS < maxPrivilegedPort ||
c.PortDNSOverTLS < maxPrivilegedPort ||
c.PortDNSOverQUIC < maxPrivilegedPort)
}

// finishUpdate completes an update procedure. It is intended to be used as a
Expand Down
69 changes: 36 additions & 33 deletions internal/home/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ func initDNSServer(
dnsConf, err := newServerConfig(
&config.DNS,
config.Clients.Sources,
tlsMgr.config(),
tlsMgr.extendedTLSConfig(),
config.HTTPConfig.DoH,
tlsMgr,
httpReg,
Expand Down Expand Up @@ -212,7 +212,8 @@ func parseSubnetSet(nets []netutil.Prefix) (s netutil.SubnetSet) {
}
}

func isRunning() bool {
// isRunning checks whether the DNS server is running.
func isRunning() (ok bool) {
return globalContext.dnsServer != nil && globalContext.dnsServer.IsRunning()
}

Expand Down Expand Up @@ -262,7 +263,7 @@ func ipsToUDPAddrs(ips []netip.Addr, port uint16) (udpAddrs []*net.UDPAddr) {
func newServerConfig(
dnsConf *dnsConfig,
clientSrcConf *clientSourcesConfig,
tlsConf *tlsConfigSettings,
extTLSConf *tlsConfigSettings,
dohConf *doHConfig,
tlsMgr *tlsManager,
httpReg aghhttp.Registrar,
Expand All @@ -274,7 +275,7 @@ func newServerConfig(
fwdConf := dnsConf.Config
fwdConf.ClientsContainer = clientsContainer

intTLSConf, err := newDNSTLSConfig(tlsConf, hosts, dohConf.InsecureEnabled)
intTLSConf, err := newDNSTLSConfig(extTLSConf, hosts, dohConf.InsecureEnabled)
if err != nil {
return nil, fmt.Errorf("constructing tls config: %w", err)
}
Expand Down Expand Up @@ -322,43 +323,43 @@ func newServerConfig(
}

// newDNSTLSConfig converts values from the configuration file into the internal
// TLS settings for the DNS server. conf must not be nil.
// TLS settings for the DNS server. extTLSConf must not be nil.
func newDNSTLSConfig(
conf *tlsConfigSettings,
extTLSConf *tlsConfigSettings,
addrs []netip.Addr,
allowUnencryptedDoH bool,
) (dnsConf *dnsforward.TLSConfig, err error) {
if !conf.Enabled {
if !extTLSConf.Enabled {
return &dnsforward.TLSConfig{}, nil
}

// TODO(e.burkov): Add tracking for DNSCrypt configuration file changes to
// the [aghtls.Manager].
dnsCryptConf, err := newDNSCryptConfig(conf, addrs)
dnsCryptConf, err := newDNSCryptConfig(extTLSConf, addrs)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return nil, err
}

dnsConf = &dnsforward.TLSConfig{
DNSCryptConf: dnsCryptConf,
ServerName: conf.ServerName,
StrictSNICheck: conf.StrictSNICheck,
ServerName: extTLSConf.ServerName,
StrictSNICheck: extTLSConf.StrictSNICheck,
}

if conf.PortHTTPS != 0 {
dnsConf.HTTPSListenAddrs = ipsToAddrPorts(addrs, conf.PortHTTPS)
if extTLSConf.PortHTTPS != 0 {
dnsConf.HTTPSListenAddrs = ipsToAddrPorts(addrs, extTLSConf.PortHTTPS)
}

if conf.PortDNSOverTLS != 0 {
dnsConf.TLSListenAddrs = ipsToTCPAddrs(addrs, conf.PortDNSOverTLS)
if extTLSConf.PortDNSOverTLS != 0 {
dnsConf.TLSListenAddrs = ipsToTCPAddrs(addrs, extTLSConf.PortDNSOverTLS)
}

if conf.PortDNSOverQUIC != 0 {
dnsConf.QUICListenAddrs = ipsToUDPAddrs(addrs, conf.PortDNSOverQUIC)
if extTLSConf.PortDNSOverQUIC != 0 {
dnsConf.QUICListenAddrs = ipsToUDPAddrs(addrs, extTLSConf.PortDNSOverQUIC)
}

cert, err := tls.X509KeyPair(conf.CertificateChainData, conf.PrivateKeyData)
cert, err := tls.X509KeyPair(extTLSConf.CertificateChainData, extTLSConf.PrivateKeyData)
if err != nil {
err = fmt.Errorf("parsing tls key pair: %w", err)
if allowUnencryptedDoH || dnsCryptConf != nil {
Expand All @@ -378,20 +379,20 @@ func newDNSTLSConfig(
}

// newDNSCryptConfig converts values from the configuration file into the
// internal DNSCrypt settings for the DNS server. conf must not be nil.
// internal DNSCrypt settings for the DNS server. extTLSConf must not be nil.
func newDNSCryptConfig(
conf *tlsConfigSettings,
extTLSConf *tlsConfigSettings,
addrs []netip.Addr,
) (dnsCryptConf *dnsforward.DNSCryptConfig, err error) {
if conf.PortDNSCrypt == 0 {
if extTLSConf.PortDNSCrypt == 0 {
return nil, nil
}

if conf.DNSCryptConfigFile == "" {
if extTLSConf.DNSCryptConfigFile == "" {
return nil, fmt.Errorf("dnscrypt_config_file: %w", errors.ErrEmptyValue)
}

f, err := os.Open(conf.DNSCryptConfigFile)
f, err := os.Open(extTLSConf.DNSCryptConfigFile)
if err != nil {
return nil, fmt.Errorf("opening dnscrypt config: %w", err)
}
Expand All @@ -410,8 +411,8 @@ func newDNSCryptConfig(

return &dnsforward.DNSCryptConfig{
ResolverCert: cert,
UDPListenAddrs: ipsToUDPAddrs(addrs, conf.PortDNSCrypt),
TCPListenAddrs: ipsToTCPAddrs(addrs, conf.PortDNSCrypt),
UDPListenAddrs: ipsToUDPAddrs(addrs, extTLSConf.PortDNSCrypt),
TCPListenAddrs: ipsToTCPAddrs(addrs, extTLSConf.PortDNSCrypt),
ProviderName: rc.ProviderName,
}, nil
}
Expand All @@ -426,16 +427,16 @@ type dnsEncryption struct {
// getDNSEncryption returns the TLS encryption addresses that AdGuard Home
// listens on. tlsMgr must not be nil.
func getDNSEncryption(tlsMgr *tlsManager) (de dnsEncryption) {
tlsConf := tlsMgr.config()
extTLSConf := tlsMgr.extendedTLSConfig()

if !tlsConf.Enabled || len(tlsConf.ServerName) == 0 {
if !extTLSConf.Enabled || extTLSConf.ServerName == "" {
return dnsEncryption{}
}

hostname := tlsConf.ServerName
if tlsConf.PortHTTPS != 0 {
hostname := extTLSConf.ServerName
if extTLSConf.PortHTTPS != 0 {
addr := hostname
if p := tlsConf.PortHTTPS; p != defaultPortHTTPS {
if p := extTLSConf.PortHTTPS; p != defaultPortHTTPS {
addr = netutil.JoinHostPort(addr, p)
}

Expand All @@ -446,14 +447,14 @@ func getDNSEncryption(tlsMgr *tlsManager) (de dnsEncryption) {
}).String()
}

if p := tlsConf.PortDNSOverTLS; p != 0 {
if p := extTLSConf.PortDNSOverTLS; p != 0 {
de.tls = (&url.URL{
Scheme: "tls",
Host: netutil.JoinHostPort(hostname, p),
}).String()
}

if p := tlsConf.PortDNSOverQUIC; p != 0 {
if p := extTLSConf.PortDNSOverQUIC; p != 0 {
de.quic = (&url.URL{
Scheme: "quic",
Host: netutil.JoinHostPort(hostname, p),
Expand All @@ -463,7 +464,9 @@ func getDNSEncryption(tlsMgr *tlsManager) (de dnsEncryption) {
return de
}

func startDNSServer() error {
// startDNSServer starts the DNS server, clients container, filters, stats and
// the query log.
func startDNSServer() (err error) {
config.RLock()
defer config.RUnlock()

Expand All @@ -475,7 +478,7 @@ func startDNSServer() error {

// TODO(s.chzhen): Pass context.
ctx := context.TODO()
err := globalContext.clients.Start(ctx)
err = globalContext.clients.Start(ctx)
if err != nil {
return fmt.Errorf("starting clients container: %w", err)
}
Expand Down
Loading
Loading