all: sync with master

This commit is contained in:
Ainar Garipov
2024-01-30 18:43:51 +03:00
parent f6ad64bf69
commit b01c10b73e
196 changed files with 3190 additions and 1790 deletions

View File

@@ -89,12 +89,8 @@ type Config struct {
// servers are not responding.
FallbackDNS []string `yaml:"fallback_dns"`
// AllServers, if true, parallel queries to all configured upstream servers
// are enabled.
AllServers bool `yaml:"all_servers"`
// FastestAddr, if true, use Fastest Address algorithm.
FastestAddr bool `yaml:"fastest_addr"`
// UpstreamMode determines the logic through which upstreams will be used.
UpstreamMode UpstreamMode `yaml:"upstream_mode"`
// FastestTimeout replaces the default timeout for dialing IP addresses
// when FastestAddr is true.
@@ -114,11 +110,10 @@ type Config struct {
// BlockedHosts is the list of hosts that should be blocked.
BlockedHosts []string `yaml:"blocked_hosts"`
// TrustedProxies is the list of IP addresses and CIDR networks to detect
// proxy servers addresses the DoH requests from which should be handled.
// The value of nil or an empty slice for this field makes Proxy not trust
// any address.
TrustedProxies []string `yaml:"trusted_proxies"`
// TrustedProxies is the list of CIDR networks with proxy servers addresses
// from which the DoH requests should be handled. The value of nil or an
// empty slice for this field makes Proxy not trust any address.
TrustedProxies []netutil.Prefix `yaml:"trusted_proxies"`
// DNS cache settings
@@ -154,7 +149,7 @@ type Config struct {
// MaxGoroutines is the max number of parallel goroutines for processing
// incoming requests.
MaxGoroutines uint32 `yaml:"max_goroutines"`
MaxGoroutines uint `yaml:"max_goroutines"`
// HandleDDR, if true, handle DDR requests
HandleDDR bool `yaml:"handle_ddr"`
@@ -294,9 +289,21 @@ type ServerConfig struct {
ServePlainDNS bool
}
// UpstreamMode is a enumeration of upstream mode representations. See
// [proxy.UpstreamModeType].
type UpstreamMode string
const (
UpstreamModeLoadBalance UpstreamMode = "load_balance"
UpstreamModeParallel UpstreamMode = "parallel"
UpstreamModeFastestAddr UpstreamMode = "fastest_addr"
)
// newProxyConfig creates and validates configuration for the main proxy.
func (s *Server) newProxyConfig() (conf *proxy.Config, err error) {
srvConf := s.conf
trustedPrefixes := netutil.UnembedPrefixes(srvConf.TrustedProxies)
conf = &proxy.Config{
HTTP3: srvConf.ServeHTTP3,
Ratelimit: int(srvConf.Ratelimit),
@@ -304,7 +311,7 @@ func (s *Server) newProxyConfig() (conf *proxy.Config, err error) {
RatelimitSubnetLenIPv6: srvConf.RatelimitSubnetLenIPv6,
RatelimitWhitelist: srvConf.RatelimitWhitelist,
RefuseAny: srvConf.RefuseAny,
TrustedProxies: srvConf.TrustedProxies,
TrustedProxies: netutil.SliceSubnetSet(trustedPrefixes),
CacheMinTTL: srvConf.CacheMinTTL,
CacheMaxTTL: srvConf.CacheMaxTTL,
CacheOptimistic: srvConf.CacheOptimistic,
@@ -313,7 +320,7 @@ func (s *Server) newProxyConfig() (conf *proxy.Config, err error) {
RequestHandler: s.handleDNSRequest,
HTTPSServerName: aghhttp.UserAgent(),
EnableEDNSClientSubnet: srvConf.EDNSClientSubnet.Enabled,
MaxGoroutines: int(srvConf.MaxGoroutines),
MaxGoroutines: srvConf.MaxGoroutines,
UseDNS64: srvConf.UseDNS64,
DNS64Prefs: srvConf.DNS64Prefixes,
}
@@ -323,18 +330,11 @@ func (s *Server) newProxyConfig() (conf *proxy.Config, err error) {
conf.EDNSAddr = net.IP(srvConf.EDNSClientSubnet.CustomIP.AsSlice())
}
if srvConf.CacheSize != 0 {
conf.CacheEnabled = true
conf.CacheSizeBytes = int(srvConf.CacheSize)
err = setProxyUpstreamMode(conf, srvConf.UpstreamMode, srvConf.FastestTimeout.Duration)
if err != nil {
return nil, fmt.Errorf("upstream mode: %w", err)
}
setProxyUpstreamMode(
conf,
srvConf.AllServers,
srvConf.FastestAddr,
srvConf.FastestTimeout.Duration,
)
conf.BogusNXDomain, err = parseBogusNXDOMAIN(srvConf.BogusNXDomain)
if err != nil {
return nil, fmt.Errorf("bogus_nxdomain: %w", err)
@@ -361,6 +361,37 @@ func (s *Server) newProxyConfig() (conf *proxy.Config, err error) {
return nil, errors.Error("no default upstream servers configured")
}
conf, err = prepareCacheConfig(conf,
srvConf.CacheSize,
srvConf.CacheMinTTL,
srvConf.CacheMaxTTL,
)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return nil, err
}
return conf, nil
}
// prepareCacheConfig prepares the cache configuration and returns an error if
// there is one.
func prepareCacheConfig(
conf *proxy.Config,
size uint32,
minTTL uint32,
maxTTL uint32,
) (prepared *proxy.Config, err error) {
if size != 0 {
conf.CacheEnabled = true
conf.CacheSizeBytes = int(size)
}
err = validateCacheTTL(minTTL, maxTTL)
if err != nil {
return nil, fmt.Errorf("validating cache ttl: %w", err)
}
return conf, nil
}
@@ -739,3 +770,19 @@ func (s *Server) enableProtectionAfterPause() {
log.Info("dns: protection is restarted after pause")
}
// validateCacheTTL returns an error if the configuration of the cache TTL
// invalid.
//
// TODO(s.chzhen): Move to dnsproxy.
func validateCacheTTL(minTTL, maxTTL uint32) (err error) {
if minTTL == 0 && maxTTL == 0 {
return nil
}
if maxTTL > 0 && minTTL > maxTTL {
return errors.Error("cache_ttl_min must be less than or equal to cache_ttl_max")
}
return nil
}

View File

@@ -290,6 +290,7 @@ func TestServer_HandleDNSRequest_dns64(t *testing.T) {
TCPListenAddrs: []*net.TCPAddr{{}},
UseDNS64: true,
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
ServePlainDNS: true,

View File

@@ -81,6 +81,7 @@ type DHCP interface {
Enabled() (ok bool)
}
// SystemResolvers is an interface for accessing the OS-provided resolvers.
type SystemResolvers interface {
// Addrs returns the list of system resolvers' addresses.
Addrs() (addrs []netip.AddrPort)
@@ -142,7 +143,7 @@ type Server struct {
// PTR resolving.
sysResolvers SystemResolvers
// etcHosts contains the data from the system's hosts files.
// etcHosts contains the current data from the system's hosts files.
etcHosts upstream.Resolver
// bootstrap is the resolver for upstreams' hostnames.
@@ -239,6 +240,11 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
p.Anonymizer = aghnet.NewIPMut(nil)
}
var etcHosts upstream.Resolver
if p.EtcHosts != nil {
etcHosts = upstream.NewHostsResolver(p.EtcHosts)
}
s = &Server{
dnsFilter: p.DNSFilter,
dhcpServer: p.DHCPServer,
@@ -247,6 +253,7 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
privateNets: p.PrivateNets,
// TODO(e.burkov): Use some case-insensitive string comparison.
localDomainSuffix: strings.ToLower(localDomainSuffix),
etcHosts: etcHosts,
recDetector: newRecursionDetector(recursionTTL, cachedRecurrentReqNum),
clientIDCache: cache.New(cache.Config{
EnableLRU: true,
@@ -257,9 +264,6 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
ServePlainDNS: true,
},
}
if p.EtcHosts != nil {
s.etcHosts = p.EtcHosts
}
s.sysResolvers, err = sysresolv.NewSystemResolvers(nil, defaultPlainDNSPort)
if err != nil {
@@ -307,7 +311,7 @@ func (s *Server) WriteDiskConfig(c *Config) {
c.AllowedClients = stringutil.CloneSlice(sc.AllowedClients)
c.DisallowedClients = stringutil.CloneSlice(sc.DisallowedClients)
c.BlockedHosts = stringutil.CloneSlice(sc.BlockedHosts)
c.TrustedProxies = stringutil.CloneSlice(sc.TrustedProxies)
c.TrustedProxies = slices.Clone(sc.TrustedProxies)
c.UpstreamDNS = stringutil.CloneSlice(sc.UpstreamDNS)
}
@@ -386,7 +390,7 @@ func (s *Server) Exchange(ip netip.Addr) (host string, ttl time.Duration, err er
var resolver *proxy.Proxy
var errMsg string
if s.privateNets.Contains(ip.AsSlice()) {
if s.privateNets.Contains(ip) {
if !s.conf.UsePrivateRDNS {
return "", 0, nil
}
@@ -466,13 +470,15 @@ func (s *Server) startLocked() error {
return err
}
// setupLocalResolvers initializes the resolvers for local addresses. It
// assumes s.serverLock is locked or the Server not running.
func (s *Server) setupLocalResolvers(boot upstream.Resolver) (err error) {
// prepareLocalResolvers initializes the local upstreams configuration using
// boot as bootstrap. It assumes that s.serverLock is locked or s not running.
func (s *Server) prepareLocalResolvers(
boot upstream.Resolver,
) (uc *proxy.UpstreamConfig, err error) {
set, err := s.conf.ourAddrsSet()
if err != nil {
// Don't wrap the error because it's informative enough as is.
return err
return nil, err
}
resolvers := s.conf.LocalPTRResolvers
@@ -489,29 +495,46 @@ func (s *Server) setupLocalResolvers(boot upstream.Resolver) (err error) {
log.Debug("dnsforward: upstreams to resolve ptr for local addresses: %v", resolvers)
uc, err := s.prepareUpstreamConfig(resolvers, nil, &upstream.Options{
uc, err = s.prepareUpstreamConfig(resolvers, nil, &upstream.Options{
Bootstrap: boot,
Timeout: defaultLocalTimeout,
// TODO(e.burkov): Should we verify server's certificates?
PreferIPv6: s.conf.BootstrapPreferIPv6,
})
if err != nil {
return fmt.Errorf("preparing private upstreams: %w", err)
return nil, fmt.Errorf("preparing private upstreams: %w", err)
}
if confNeedsFiltering {
err = filterOutAddrs(uc, set)
if err != nil {
return fmt.Errorf("filtering private upstreams: %w", err)
return nil, fmt.Errorf("filtering private upstreams: %w", err)
}
}
return uc, nil
}
// setupLocalResolvers initializes and sets the resolvers for local addresses.
// It assumes s.serverLock is locked or s not running.
func (s *Server) setupLocalResolvers(boot upstream.Resolver) (err error) {
uc, err := s.prepareLocalResolvers(boot)
if err != nil {
// Don't wrap the error because it's informative enough as is.
return err
}
s.localResolvers = &proxy.Proxy{
Config: proxy.Config{
UpstreamConfig: uc,
},
}
err = s.localResolvers.Init()
if err != nil {
return fmt.Errorf("initializing proxy: %w", err)
}
// TODO(e.burkov): Should we also consider the DNS64 usage?
if s.conf.UsePrivateRDNS &&
// Only set the upstream config if there are any upstreams. It's safe
@@ -697,15 +720,13 @@ func (s *Server) prepareInternalProxy() (err error) {
CacheEnabled: true,
CacheSizeBytes: 4096,
UpstreamConfig: srvConf.UpstreamConfig,
MaxGoroutines: int(s.conf.MaxGoroutines),
MaxGoroutines: s.conf.MaxGoroutines,
}
setProxyUpstreamMode(
conf,
srvConf.AllServers,
srvConf.FastestAddr,
srvConf.FastestTimeout.Duration,
)
err = setProxyUpstreamMode(conf, srvConf.UpstreamMode, srvConf.FastestTimeout.Duration)
if err != nil {
return fmt.Errorf("invalid upstream mode: %w", err)
}
// TODO(a.garipov): Make a proper constructor for proxy.Proxy.
p := &proxy.Proxy{

View File

@@ -177,6 +177,7 @@ func createTestTLS(t *testing.T, tlsConf TLSConfig) (s *Server, certPem []byte)
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
ServePlainDNS: true,
@@ -305,6 +306,7 @@ func TestServer(t *testing.T) {
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
ServePlainDNS: true,
@@ -344,6 +346,7 @@ func TestServer_timeout(t *testing.T) {
srvConf := &ServerConfig{
UpstreamTimeout: testTimeout,
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
ServePlainDNS: true,
@@ -362,6 +365,7 @@ func TestServer_timeout(t *testing.T) {
s, err := NewServer(DNSCreateParams{DNSFilter: createTestDNSFilter(t)})
require.NoError(t, err)
s.conf.Config.UpstreamMode = UpstreamModeLoadBalance
s.conf.Config.EDNSClientSubnet = &EDNSClientSubnet{
Enabled: false,
}
@@ -379,6 +383,7 @@ func TestServer_Prepare_fallbacks(t *testing.T) {
"#tls://1.1.1.1",
"8.8.8.8",
},
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
ServePlainDNS: true,
@@ -401,6 +406,7 @@ func TestServerWithProtectionDisabled(t *testing.T) {
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
ServePlainDNS: true,
@@ -478,7 +484,8 @@ func TestServerRace(t *testing.T) {
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
Config: Config{
UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"},
UpstreamMode: UpstreamModeLoadBalance,
UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"},
},
ConfigModified: func() {},
ServePlainDNS: true,
@@ -531,6 +538,7 @@ func TestSafeSearch(t *testing.T) {
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false,
},
@@ -547,32 +555,41 @@ func TestSafeSearch(t *testing.T) {
googleIP, _ := aghtest.HostToIPs("forcesafesearch.google.com")
testCases := []struct {
host string
want netip.Addr
host string
want netip.Addr
wantCNAME string
}{{
host: "yandex.com.",
want: yandexIP,
host: "yandex.com.",
want: yandexIP,
wantCNAME: "",
}, {
host: "yandex.by.",
want: yandexIP,
host: "yandex.by.",
want: yandexIP,
wantCNAME: "",
}, {
host: "yandex.kz.",
want: yandexIP,
host: "yandex.kz.",
want: yandexIP,
wantCNAME: "",
}, {
host: "yandex.ru.",
want: yandexIP,
host: "yandex.ru.",
want: yandexIP,
wantCNAME: "",
}, {
host: "www.google.com.",
want: googleIP,
host: "www.google.com.",
want: googleIP,
wantCNAME: "forcesafesearch.google.com.",
}, {
host: "www.google.com.af.",
want: googleIP,
host: "www.google.com.af.",
want: googleIP,
wantCNAME: "forcesafesearch.google.com.",
}, {
host: "www.google.be.",
want: googleIP,
host: "www.google.be.",
want: googleIP,
wantCNAME: "forcesafesearch.google.com.",
}, {
host: "www.google.by.",
want: googleIP,
host: "www.google.by.",
want: googleIP,
wantCNAME: "forcesafesearch.google.com.",
}}
for _, tc := range testCases {
@@ -582,7 +599,18 @@ func TestSafeSearch(t *testing.T) {
var reply *dns.Msg
reply, _, err = client.Exchange(req, addr)
require.NoErrorf(t, err, "couldn't talk to server %s: %s", addr, err)
assertResponse(t, reply, tc.want)
if tc.wantCNAME != "" {
require.Len(t, reply.Answer, 2)
cname := testutil.RequireTypeAssert[*dns.CNAME](t, reply.Answer[0])
assert.Equal(t, tc.wantCNAME, cname.Target)
} else {
require.Len(t, reply.Answer, 1)
}
a := testutil.RequireTypeAssert[*dns.A](t, reply.Answer[len(reply.Answer)-1])
assert.Equal(t, net.IP(tc.want.AsSlice()), a.A)
})
}
}
@@ -594,6 +622,7 @@ func TestInvalidRequest(t *testing.T) {
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false,
},
@@ -623,6 +652,7 @@ func TestBlockedRequest(t *testing.T) {
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false,
},
@@ -658,7 +688,8 @@ func TestServerCustomClientUpstream(t *testing.T) {
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
Config: Config{
CacheSize: defaultCacheSize,
CacheSize: defaultCacheSize,
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false,
},
@@ -736,6 +767,7 @@ func TestBlockCNAMEProtectionEnabled(t *testing.T) {
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false,
},
@@ -769,6 +801,7 @@ func TestBlockCNAME(t *testing.T) {
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false,
},
@@ -844,6 +877,7 @@ func TestClientRulesForCNAMEMatching(t *testing.T) {
FilterHandler: func(_ netip.Addr, _ string, settings *filtering.Settings) {
settings.FilteringEnabled = false
},
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false,
},
@@ -889,6 +923,7 @@ func TestNullBlockedRequest(t *testing.T) {
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false,
},
@@ -954,7 +989,8 @@ func TestBlockedCustomIP(t *testing.T) {
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
Config: Config{
UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"},
UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"},
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false,
},
@@ -1007,6 +1043,7 @@ func TestBlockedByHosts(t *testing.T) {
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false,
},
@@ -1058,6 +1095,7 @@ func TestBlockedBySafeBrowsing(t *testing.T) {
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false,
},
@@ -1116,7 +1154,8 @@ func TestRewrite(t *testing.T) {
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
Config: Config{
UpstreamDNS: []string{"8.8.8.8:53"},
UpstreamDNS: []string{"8.8.8.8:53"},
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false,
},
@@ -1245,6 +1284,7 @@ func TestPTRResponseFromDHCPLeases(t *testing.T) {
s.conf.TCPListenAddrs = []*net.TCPAddr{{}}
s.conf.UpstreamDNS = []string{"127.0.0.1:53"}
s.conf.Config.EDNSClientSubnet = &EDNSClientSubnet{Enabled: false}
s.conf.Config.UpstreamMode = UpstreamModeLoadBalance
err = s.Prepare(&s.conf)
require.NoError(t, err)
@@ -1327,6 +1367,7 @@ func TestPTRResponseFromHosts(t *testing.T) {
s.conf.TCPListenAddrs = []*net.TCPAddr{{}}
s.conf.UpstreamDNS = []string{"127.0.0.1:53"}
s.conf.Config.EDNSClientSubnet = &EDNSClientSubnet{Enabled: false}
s.conf.Config.UpstreamMode = UpstreamModeLoadBalance
err = s.Prepare(&s.conf)
require.NoError(t, err)
@@ -1506,9 +1547,9 @@ func TestServer_Exchange(t *testing.T) {
},
},
}
srv.conf.UsePrivateRDNS = true
srv.privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed)
require.NoError(t, srv.internalProxy.Init())
testCases := []struct {
req netip.Addr
@@ -1584,6 +1625,7 @@ func TestServer_Exchange(t *testing.T) {
srv.localResolvers = &proxy.Proxy{
Config: pcfg,
}
require.NoError(t, srv.localResolvers.Init())
t.Run(tc.name, func(t *testing.T) {
host, ttl, eerr := srv.Exchange(tc.req)

View File

@@ -38,6 +38,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
BlockingMode: filtering.BlockingModeDefault,
}, ServerConfig{
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
ServePlainDNS: true,

View File

@@ -95,7 +95,7 @@ func (s *Server) filterDNSRequest(dctx *dnsContext) (res *filtering.Result, err
dctx.origQuestion = q
req.Question[0].Name = dns.Fqdn(res.CanonName)
case res.Reason == filtering.Rewritten:
pctx.Res = s.filterRewritten(req, host, res, q.Qtype)
pctx.Res = s.getCNAMEWithIPs(req, res.IPList, res.CanonName)
case res.Reason.In(filtering.RewrittenRule, filtering.RewrittenAutoHosts):
if err = s.filterDNSRewrite(req, res, pctx); err != nil {
return nil, err
@@ -105,37 +105,6 @@ func (s *Server) filterDNSRequest(dctx *dnsContext) (res *filtering.Result, err
return res, err
}
// filterRewritten handles DNS rewrite filters. It returns a DNS response with
// the data from the filtering result. All parameters must not be nil.
func (s *Server) filterRewritten(
req *dns.Msg,
host string,
res *filtering.Result,
qt uint16,
) (resp *dns.Msg) {
resp = s.makeResponse(req)
name := host
if len(res.CanonName) != 0 {
resp.Answer = append(resp.Answer, s.genAnswerCNAME(req, res.CanonName))
name = res.CanonName
}
for _, ip := range res.IPList {
switch qt {
case dns.TypeA:
a := s.genAnswerA(req, ip)
a.Hdr.Name = dns.Fqdn(name)
resp.Answer = append(resp.Answer, a)
case dns.TypeAAAA:
a := s.genAnswerAAAA(req, ip)
a.Hdr.Name = dns.Fqdn(name)
resp.Answer = append(resp.Answer, a)
}
}
return resp
}
// checkHostRules checks the host against filters. It is safe for concurrent
// use.
func (s *Server) checkHostRules(

View File

@@ -31,6 +31,7 @@ func TestHandleDNSRequest_handleDNSRequest(t *testing.T) {
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{
Enabled: false,
},

View File

@@ -70,7 +70,7 @@ type jsonDNSConfig struct {
DisableIPv6 *bool `json:"disable_ipv6"`
// UpstreamMode defines the way DNS requests are constructed.
UpstreamMode *string `json:"upstream_mode"`
UpstreamMode *jsonUpstreamMode `json:"upstream_mode"`
// BlockedResponseTTL is the TTL for blocked responses.
BlockedResponseTTL *uint32 `json:"blocked_response_ttl"`
@@ -114,6 +114,21 @@ type jsonDNSConfig struct {
DefaultLocalPTRUpstreams []string `json:"default_local_ptr_upstreams,omitempty"`
}
// jsonUpstreamMode is a enumeration of upstream modes.
type jsonUpstreamMode string
const (
// jsonUpstreamModeEmpty is the default value on frontend, it is used as
// jsonUpstreamModeLoadBalance mode.
//
// Deprecated: Use jsonUpstreamModeLoadBalance instead.
jsonUpstreamModeEmpty jsonUpstreamMode = ""
jsonUpstreamModeLoadBalance jsonUpstreamMode = "load_balance"
jsonUpstreamModeParallel jsonUpstreamMode = "parallel"
jsonUpstreamModeFastestAddr jsonUpstreamMode = "fastest_addr"
)
func (s *Server) getDNSConfig() (c *jsonDNSConfig) {
protectionEnabled, protectionDisabledUntil := s.UpdatedProtectionStatus()
@@ -145,11 +160,16 @@ func (s *Server) getDNSConfig() (c *jsonDNSConfig) {
usePrivateRDNS := s.conf.UsePrivateRDNS
localPTRUpstreams := stringutil.CloneSliceOrEmpty(s.conf.LocalPTRResolvers)
var upstreamMode string
if s.conf.FastestAddr {
upstreamMode = "fastest_addr"
} else if s.conf.AllServers {
upstreamMode = "parallel"
var upstreamMode jsonUpstreamMode
switch s.conf.UpstreamMode {
case UpstreamModeLoadBalance:
// TODO(d.kolyshev): Support jsonUpstreamModeLoadBalance on frontend instead
// of jsonUpstreamModeEmpty.
upstreamMode = jsonUpstreamModeEmpty
case UpstreamModeParallel:
upstreamMode = jsonUpstreamModeParallel
case UpstreamModeFastestAddr:
upstreamMode = jsonUpstreamModeFastestAddr
}
defPTRUps, err := s.defaultLocalPTRUpstreams()
@@ -222,18 +242,22 @@ func (req *jsonDNSConfig) checkBlockingMode() (err error) {
return validateBlockingMode(*req.BlockingMode, req.BlockingIPv4, req.BlockingIPv6)
}
// checkUpstreamsMode returns an error if the upstream mode is invalid.
func (req *jsonDNSConfig) checkUpstreamsMode() (err error) {
// checkUpstreamMode returns an error if the upstream mode is invalid.
func (req *jsonDNSConfig) checkUpstreamMode() (err error) {
if req.UpstreamMode == nil {
return nil
}
mode := *req.UpstreamMode
if ok := slices.Contains([]string{"", "fastest_addr", "parallel"}, mode); !ok {
return fmt.Errorf("upstream_mode: incorrect value %q", mode)
switch um := *req.UpstreamMode; um {
case
jsonUpstreamModeEmpty,
jsonUpstreamModeLoadBalance,
jsonUpstreamModeParallel,
jsonUpstreamModeFastestAddr:
return nil
default:
return fmt.Errorf("upstream_mode: incorrect value %q", um)
}
return nil
}
// checkBootstrap returns an error if any bootstrap address is invalid.
@@ -243,17 +267,22 @@ func (req *jsonDNSConfig) checkBootstrap() (err error) {
}
var b string
defer func() { err = errors.Annotate(err, "checking bootstrap %s: invalid address: %w", b) }()
defer func() { err = errors.Annotate(err, "checking bootstrap %s: %w", b) }()
for _, b = range *req.Bootstraps {
if b == "" {
return errors.Error("empty")
}
if _, err = upstream.NewUpstreamResolver(b, nil); err != nil {
var resolver *upstream.UpstreamResolver
if resolver, err = upstream.NewUpstreamResolver(b, nil); err != nil {
// Don't wrap the error because it's informative enough as is.
return err
}
if err = resolver.Close(); err != nil {
return fmt.Errorf("closing %s: %w", b, err)
}
}
return nil
@@ -297,7 +326,7 @@ func (req *jsonDNSConfig) validate(privateNets netutil.SubnetSet) (err error) {
return err
}
err = req.checkUpstreamsMode()
err = req.checkUpstreamMode()
if err != nil {
// Don't wrap the error since it's informative enough as is.
return err
@@ -354,15 +383,12 @@ func (req *jsonDNSConfig) checkCacheTTL() (err error) {
if req.CacheMinTTL != nil {
minTTL = *req.CacheMinTTL
}
if req.CacheMaxTTL != nil {
maxTTL = *req.CacheMaxTTL
}
if minTTL <= maxTTL {
return nil
}
return errors.Error("cache_ttl_min must be less or equal than cache_ttl_max")
return validateCacheTTL(minTTL, maxTTL)
}
// checkRatelimitSubnetMaskLen returns an error if the length of the subnet mask
@@ -446,8 +472,9 @@ func (s *Server) setConfig(dc *jsonDNSConfig) (shouldRestart bool) {
}
if dc.UpstreamMode != nil {
s.conf.AllServers = *dc.UpstreamMode == "parallel"
s.conf.FastestAddr = *dc.UpstreamMode == "fastest_addr"
s.conf.UpstreamMode = mustParseUpstreamMode(*dc.UpstreamMode)
} else {
s.conf.UpstreamMode = UpstreamModeLoadBalance
}
if dc.EDNSCSUseCustom != nil && *dc.EDNSCSUseCustom {
@@ -460,6 +487,22 @@ func (s *Server) setConfig(dc *jsonDNSConfig) (shouldRestart bool) {
return s.setConfigRestartable(dc)
}
// mustParseUpstreamMode returns an upstream mode parsed from jsonUpstreamMode.
// Panics in case of invalid value.
func mustParseUpstreamMode(mode jsonUpstreamMode) (um UpstreamMode) {
switch mode {
case jsonUpstreamModeEmpty, jsonUpstreamModeLoadBalance:
return UpstreamModeLoadBalance
case jsonUpstreamModeParallel:
return UpstreamModeParallel
case jsonUpstreamModeFastestAddr:
return UpstreamModeFastestAddr
default:
// Should never happen, since the value should be validated.
panic(fmt.Errorf("unexpected upstream mode: %q", mode))
}
}
// setIfNotNil sets the value pointed at by currentPtr to the value pointed at
// by newPtr if newPtr is not nil. currentPtr must not be nil.
func setIfNotNil[T any](currentPtr, newPtr *T) (hasSet bool) {

View File

@@ -20,6 +20,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/httphdr"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
@@ -76,6 +77,7 @@ func TestDNSForwardHTTP_handleGetConfig(t *testing.T) {
FallbackDNS: []string{"9.9.9.10"},
RatelimitSubnetLenIPv4: 24,
RatelimitSubnetLenIPv6: 56,
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
ConfigModified: func() {},
@@ -102,7 +104,7 @@ func TestDNSForwardHTTP_handleGetConfig(t *testing.T) {
}, {
conf: func() ServerConfig {
conf := defaultConf
conf.FastestAddr = true
conf.UpstreamMode = UpstreamModeFastestAddr
return conf
},
@@ -110,7 +112,7 @@ func TestDNSForwardHTTP_handleGetConfig(t *testing.T) {
}, {
conf: func() ServerConfig {
conf := defaultConf
conf.AllServers = true
conf.UpstreamMode = UpstreamModeParallel
return conf
},
@@ -156,6 +158,7 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"},
RatelimitSubnetLenIPv4: 24,
RatelimitSubnetLenIPv6: 56,
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
ConfigModified: func() {},
@@ -224,11 +227,11 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
`upstream servers: validating upstream "!!!": not an ip:port`,
}, {
name: "bootstraps_bad",
wantSet: `validating dns config: checking bootstrap a: invalid address: not a bootstrap: ` +
`ParseAddr("a"): unable to parse IP`,
wantSet: `validating dns config: checking bootstrap a: not a bootstrap: ParseAddr("a"): ` +
`unable to parse IP`,
}, {
name: "cache_bad_ttl",
wantSet: `validating dns config: cache_ttl_min must be less or equal than cache_ttl_max`,
wantSet: `validating dns config: cache_ttl_min must be less than or equal to cache_ttl_max`,
}, {
name: "upstream_mode_bad",
wantSet: `validating dns config: upstream_mode: incorrect value "somethingelse"`,
@@ -522,11 +525,12 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) {
TCPListenAddrs: []*net.TCPAddr{{}},
UpstreamTimeout: upsTimeout,
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
ServePlainDNS: true,
}, nil)
srv.etcHosts = hc
srv.etcHosts = upstream.NewHostsResolver(hc)
startDeferStop(t, srv)
testCases := []struct {

View File

@@ -66,12 +66,46 @@ func (s *Server) genDNSFilterMessage(
// If Safe Search generated the necessary IP addresses, use them.
// Otherwise, if there were no errors, there are no addresses for the
// requested IP version, so produce a NODATA response.
return s.genResponseWithIPs(req, ipsFromRules(res.Rules))
return s.getCNAMEWithIPs(req, ipsFromRules(res.Rules), res.CanonName)
default:
return s.genForBlockingMode(req, ipsFromRules(res.Rules))
}
}
// getCNAMEWithIPs generates a filtered response to req for with CNAME record
// and provided ips.
func (s *Server) getCNAMEWithIPs(req *dns.Msg, ips []netip.Addr, cname string) (resp *dns.Msg) {
resp = s.makeResponse(req)
originalName := req.Question[0].Name
var ans []dns.RR
if cname != "" {
ans = append(ans, s.genAnswerCNAME(req, cname))
// The given IPs actually are resolved for this cname.
req.Question[0].Name = dns.Fqdn(cname)
defer func() { req.Question[0].Name = originalName }()
}
switch req.Question[0].Qtype {
case dns.TypeA:
ans = append(ans, s.genAnswersWithIPv4s(req, ips)...)
case dns.TypeAAAA:
for _, ip := range ips {
if ip.Is6() {
ans = append(ans, s.genAnswerAAAA(req, ip))
}
}
default:
// Go on and return an empty response.
}
resp.Answer = ans
return resp
}
// genForBlockingMode generates a filtered response to req based on the server's
// blocking mode.
func (s *Server) genForBlockingMode(req *dns.Msg, ips []netip.Addr) (resp *dns.Msg) {

View File

@@ -36,11 +36,8 @@ type dnsContext struct {
// unreversedReqIP stores an IP address obtained from a PTR request if it
// was parsed successfully and belongs to one of the locally served IP
// ranges. It is also filled with unmapped version of the address if it's
// within DNS64 prefixes.
//
// TODO(e.burkov): Use netip.Addr when we switch to netip more fully.
unreversedReqIP net.IP
// ranges.
unreversedReqIP netip.Addr
// err is the error returned from a processing function.
err error
@@ -350,7 +347,7 @@ func (s *Server) processDetermineLocal(dctx *dnsContext) (rc resultCode) {
rc = resultCodeSuccess
dctx.isLocalClient = s.privateNets.Contains(dctx.proxyCtx.Addr.Addr().AsSlice())
dctx.isLocalClient = s.privateNets.Contains(dctx.proxyCtx.Addr.Addr())
return rc
}
@@ -491,14 +488,7 @@ func extractARPASubnet(domain string) (pref netip.Prefix, err error) {
}
}
var subnet *net.IPNet
subnet, err = netutil.SubnetFromReversedAddr(domain[idx:])
if err != nil {
// Don't wrap the error since it's informative enough as is.
return netip.Prefix{}, err
}
return netutil.IPNetToPrefixNoMapped(subnet)
return netutil.PrefixFromReversedAddr(domain[idx:])
}
// processRestrictLocal responds with NXDOMAIN to PTR requests for IP addresses
@@ -532,8 +522,7 @@ func (s *Server) processRestrictLocal(dctx *dnsContext) (rc resultCode) {
// assume that all the DHCP leases we give are locally served or at least
// shouldn't be accessible externally.
subnetAddr := subnet.Addr()
addrData := subnetAddr.AsSlice()
if !s.privateNets.Contains(addrData) {
if !s.privateNets.Contains(subnetAddr) {
return resultCodeSuccess
}
@@ -548,7 +537,7 @@ func (s *Server) processRestrictLocal(dctx *dnsContext) (rc resultCode) {
}
// Do not perform unreversing ever again.
dctx.unreversedReqIP = addrData
dctx.unreversedReqIP = subnetAddr
// There is no need to filter request from external addresses since this
// code is only executed when the request is for locally served ARPA
@@ -573,16 +562,8 @@ func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) {
return resultCodeSuccess
}
ip := dctx.unreversedReqIP
if ip == nil {
return resultCodeSuccess
}
// TODO(a.garipov): Remove once we switch to [netip.Addr] more fully.
ipAddr, err := netutil.IPToAddrNoMapped(ip)
if err != nil {
log.Debug("dnsforward: bad reverse ip %v from dhcp: %s", ip, err)
ipAddr := dctx.unreversedReqIP
if ipAddr == (netip.Addr{}) {
return resultCodeSuccess
}
@@ -591,7 +572,7 @@ func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) {
return resultCodeSuccess
}
log.Debug("dnsforward: dhcp client %s is %q", ip, host)
log.Debug("dnsforward: dhcp client %s is %q", ipAddr, host)
req := pctx.Req
resp := s.makeResponse(req)
@@ -624,7 +605,7 @@ func (s *Server) processLocalPTR(dctx *dnsContext) (rc resultCode) {
}
ip := dctx.unreversedReqIP
if ip == nil {
if ip == (netip.Addr{}) {
return resultCodeSuccess
}
@@ -639,8 +620,7 @@ func (s *Server) processLocalPTR(dctx *dnsContext) (rc resultCode) {
// Generate the server failure if the private upstream configuration
// is empty.
//
// TODO(e.burkov): Get rid of this crutch once the local resolvers
// logic is moved to the dnsproxy completely.
// This is a crutch, see TODO at [Server.localResolvers].
if errors.Is(err, upstream.ErrNoUpstreams) {
pctx.Res = s.genServerFailure(pctx.Req)

View File

@@ -79,6 +79,7 @@ func TestServer_ProcessInitial(t *testing.T) {
c := ServerConfig{
Config: Config{
AAAADisabled: tc.aaaaDisabled,
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
ServePlainDNS: true,
@@ -179,6 +180,7 @@ func TestServer_ProcessFilteringAfterResponse(t *testing.T) {
c := ServerConfig{
Config: Config{
AAAADisabled: tc.aaaaDisabled,
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
ServePlainDNS: true,
@@ -694,6 +696,7 @@ func TestServer_ProcessRestrictLocal(t *testing.T) {
// TODO(s.chzhen): Add tests where EDNSClientSubnet.Enabled is true.
// Improve Config declaration for tests.
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
ServePlainDNS: true,
@@ -770,6 +773,7 @@ func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) {
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
ServePlainDNS: true,
@@ -791,7 +795,7 @@ func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) {
}
dnsCtx = &dnsContext{
proxyCtx: proxyCtx,
unreversedReqIP: net.IP{192, 168, 1, 1},
unreversedReqIP: netip.MustParseAddr("192.168.1.1"),
}
s.conf.UsePrivateRDNS = use
}

View File

@@ -25,10 +25,10 @@ func (s *Server) processQueryLogsAndStats(dctx *dnsContext) (rc resultCode) {
ip := pctx.Addr.Addr().AsSlice()
s.anonymizer.Load()(ip)
ipStr := net.IP(ip).String()
log.Debug("dnsforward: client ip for stats and querylog: %s", ip)
log.Debug("dnsforward: client ip for stats and querylog: %s", ipStr)
ipStr := pctx.Addr.Addr().String()
ids := []string{ipStr, dctx.clientID}
qt, cl := q.Qtype, q.Qclass

View File

@@ -17,6 +17,7 @@ func TestGenAnswerHTTPS_andSVCB(t *testing.T) {
BlockingMode: filtering.BlockingModeDefault,
}, ServerConfig{
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
ServePlainDNS: true,

View File

@@ -136,18 +136,22 @@ func UpstreamHTTPVersions(http3 bool) (v []upstream.HTTPVersion) {
// based on provided parameters.
func setProxyUpstreamMode(
conf *proxy.Config,
allServers bool,
fastestAddr bool,
upstreamMode UpstreamMode,
fastestTimeout time.Duration,
) {
if allServers {
) (err error) {
switch upstreamMode {
case UpstreamModeParallel:
conf.UpstreamMode = proxy.UModeParallel
} else if fastestAddr {
case UpstreamModeFastestAddr:
conf.UpstreamMode = proxy.UModeFastestAddr
conf.FastestPingTimeout = fastestTimeout
} else {
case UpstreamModeLoadBalance:
conf.UpstreamMode = proxy.UModeLoadBalance
default:
return fmt.Errorf("unexpected value %q", upstreamMode)
}
return nil
}
// createBootstrap returns a bootstrap resolver based on the configuration of s.
@@ -176,7 +180,7 @@ func (s *Server) createBootstrap(
var parallel upstream.ParallelResolver
for _, b := range boots {
parallel = append(parallel, b)
parallel = append(parallel, upstream.NewCachingResolver(b))
}
if s.etcHosts != nil {
@@ -294,7 +298,7 @@ func ValidateUpstreamsPrivate(upstreams []string, privateNets netutil.SubnetSet)
continue
}
if !privateNets.Contains(subnet.Addr().AsSlice()) {
if !privateNets.Contains(subnet.Addr()) {
errs = append(
errs,
fmt.Errorf("arpa domain %q should point to a locally-served network", domain),