all: sync with master; upd chlog

This commit is contained in:
Ainar Garipov
2023-07-03 14:10:40 +03:00
parent cadb765b7d
commit b22b16d98c
140 changed files with 6739 additions and 2521 deletions

View File

@@ -56,15 +56,20 @@ func (rm *requestMatcher) MatchRequest(
) (res *urlfilter.DNSResult, ok bool) {
switch req.DNSType {
case dns.TypeA, dns.TypeAAAA, dns.TypePTR:
log.Debug("%s: handling the request for %s", hostsContainerPrefix, req.Hostname)
log.Debug(
"%s: handling %s request for %s",
hostsContainerPrefix,
dns.Type(req.DNSType),
req.Hostname,
)
rm.stateLock.RLock()
defer rm.stateLock.RUnlock()
return rm.engine.MatchRequest(req)
default:
return nil, false
}
rm.stateLock.RLock()
defer rm.stateLock.RUnlock()
return rm.engine.MatchRequest(req)
}
// Translate returns the source hosts-syntax rule for the generated dnsrewrite
@@ -96,6 +101,8 @@ const hostsContainerPrefix = "hosts container"
// HostsContainer stores the relevant hosts database provided by the OS and
// processes both A/AAAA and PTR DNS requests for those.
//
// TODO(e.burkov): Improve API and move to golibs.
type HostsContainer struct {
// requestMatcher matches the requests and translates the rules. It's
// embedded to implement MatchRequest and Translate for *HostsContainer.

View File

@@ -25,11 +25,8 @@ func (s *bitSet) isSet(n uint64) (ok bool) {
var word uint64
word, ok = s.words[wordIdx]
if !ok {
return false
}
return word&(1<<bitIdx) != 0
return ok && word&(1<<bitIdx) != 0
}
// set sets or unsets a bit.

View File

@@ -249,31 +249,30 @@ func (c *dhcpConn) buildEtherPkt(payload []byte, peer *dhcpUnicastAddr) (pkt []b
func (s *v4Server) send(peer net.Addr, conn net.PacketConn, req, resp *dhcpv4.DHCPv4) {
switch giaddr, ciaddr, mtype := req.GatewayIPAddr, req.ClientIPAddr, resp.MessageType(); {
case giaddr != nil && !giaddr.IsUnspecified():
// Send any return messages to the server port on the BOOTP
// relay agent whose address appears in giaddr.
// Send any return messages to the server port on the BOOTP relay agent
// whose address appears in giaddr.
peer = &net.UDPAddr{
IP: giaddr,
Port: dhcpv4.ServerPort,
}
if mtype == dhcpv4.MessageTypeNak {
// Set the broadcast bit in the DHCPNAK, so that the relay agent
// broadcasts it to the client, because the client may not have
// a correct network address or subnet mask, and the client may not
// be answering ARP requests.
// broadcasts it to the client, because the client may not have a
// correct network address or subnet mask, and the client may not be
// answering ARP requests.
resp.SetBroadcast()
}
case mtype == dhcpv4.MessageTypeNak:
// Broadcast any DHCPNAK messages to 0xffffffff.
case ciaddr != nil && !ciaddr.IsUnspecified():
// Unicast DHCPOFFER and DHCPACK messages to the address in
// ciaddr.
// Unicast DHCPOFFER and DHCPACK messages to the address in ciaddr.
peer = &net.UDPAddr{
IP: ciaddr,
Port: dhcpv4.ClientPort,
}
case !req.IsBroadcast() && req.ClientHWAddr != nil:
// Unicast DHCPOFFER and DHCPACK messages to the client's
// hardware address and yiaddr.
// Unicast DHCPOFFER and DHCPACK messages to the client's hardware
// address and yiaddr.
peer = &dhcpUnicastAddr{
Addr: raw.Addr{HardwareAddr: req.ClientHWAddr},
yiaddr: resp.YourIPAddr,

View File

@@ -247,31 +247,30 @@ func (c *dhcpConn) buildEtherPkt(payload []byte, peer *dhcpUnicastAddr) (pkt []b
func (s *v4Server) send(peer net.Addr, conn net.PacketConn, req, resp *dhcpv4.DHCPv4) {
switch giaddr, ciaddr, mtype := req.GatewayIPAddr, req.ClientIPAddr, resp.MessageType(); {
case giaddr != nil && !giaddr.IsUnspecified():
// Send any return messages to the server port on the BOOTP
// relay agent whose address appears in giaddr.
// Send any return messages to the server port on the BOOTP relay agent
// whose address appears in giaddr.
peer = &net.UDPAddr{
IP: giaddr,
Port: dhcpv4.ServerPort,
}
if mtype == dhcpv4.MessageTypeNak {
// Set the broadcast bit in the DHCPNAK, so that the relay agent
// broadcasts it to the client, because the client may not have
// a correct network address or subnet mask, and the client may not
// be answering ARP requests.
// broadcasts it to the client, because the client may not have a
// correct network address or subnet mask, and the client may not be
// answering ARP requests.
resp.SetBroadcast()
}
case mtype == dhcpv4.MessageTypeNak:
// Broadcast any DHCPNAK messages to 0xffffffff.
case ciaddr != nil && !ciaddr.IsUnspecified():
// Unicast DHCPOFFER and DHCPACK messages to the address in
// ciaddr.
// Unicast DHCPOFFER and DHCPACK messages to the address in ciaddr.
peer = &net.UDPAddr{
IP: ciaddr,
Port: dhcpv4.ClientPort,
}
case !req.IsBroadcast() && req.ClientHWAddr != nil:
// Unicast DHCPOFFER and DHCPACK messages to the client's
// hardware address and yiaddr.
// Unicast DHCPOFFER and DHCPACK messages to the client's hardware
// address and yiaddr.
peer = &dhcpUnicastAddr{
Addr: packet.Addr{HardwareAddr: req.ClientHWAddr},
yiaddr: resp.YourIPAddr,

View File

@@ -28,8 +28,9 @@ const (
defaultBackoff time.Duration = 500 * time.Millisecond
)
// Lease contains the necessary information about a DHCP lease. It's used in
// various places. So don't change it without good reason.
// Lease contains the necessary information about a DHCP lease. It's used as is
// in the database, so don't change it until it's absolutely necessary, see
// [dataVersion].
type Lease struct {
// Expiry is the expiration time of the lease.
Expiry time.Time `json:"expires"`
@@ -41,8 +42,6 @@ type Lease struct {
HWAddr net.HardwareAddr `json:"mac"`
// IP is the IP address leased to the client.
//
// TODO(a.garipov): Migrate leases.db.
IP netip.Addr `json:"ip"`
// IsStatic defines if the lease is static.

View File

@@ -51,6 +51,9 @@ func migrateDB(conf *ServerConfig) (err error) {
oldLeasesPath := filepath.Join(conf.WorkDir, dbFilename)
dataDirPath := filepath.Join(conf.DataDir, dataFilename)
// #nosec G304 -- Trust this path, since it's taken from the old file name
// relative to the working directory and should generally be considered
// safe.
file, err := os.Open(oldLeasesPath)
if errors.Is(err, os.ErrNotExist) {
// Nothing to migrate.

View File

@@ -200,7 +200,7 @@ func createICMPv6RAPacket(params icmpv6RA) (data []byte, err error) {
func (ra *raCtx) Init() (err error) {
ra.stop.Store(0)
ra.conn = nil
if !(ra.raAllowSLAAC || ra.raSLAACOnly) {
if !ra.raAllowSLAAC && !ra.raSLAACOnly {
return nil
}

View File

@@ -0,0 +1,86 @@
package dhcpsvc
import (
"net/netip"
"time"
"github.com/google/gopacket/layers"
)
// Config is the configuration for the DHCP service.
type Config struct {
// Interfaces stores configurations of DHCP server specific for the network
// interface identified by its name.
Interfaces map[string]*InterfaceConfig
// LocalDomainName is the top-level domain name to use for resolving DHCP
// clients' hostnames.
LocalDomainName string
// ICMPTimeout is the timeout for checking another DHCP server's presence.
ICMPTimeout time.Duration
// Enabled is the state of the service, whether it is enabled or not.
Enabled bool
}
// InterfaceConfig is the configuration of a single DHCP interface.
type InterfaceConfig struct {
// IPv4 is the configuration of DHCP protocol for IPv4.
IPv4 *IPv4Config
// IPv6 is the configuration of DHCP protocol for IPv6.
IPv6 *IPv6Config
}
// IPv4Config is the interface-specific configuration for DHCPv4.
type IPv4Config struct {
// GatewayIP is the IPv4 address of the network's gateway. It is used as
// the default gateway for DHCP clients and also used in calculating the
// network-specific broadcast address.
GatewayIP netip.Addr
// SubnetMask is the IPv4 subnet mask of the network. It should be a valid
// IPv4 subnet mask (i.e. all 1s followed by all 0s).
SubnetMask netip.Addr
// RangeStart is the first address in the range to assign to DHCP clients.
RangeStart netip.Addr
// RangeEnd is the last address in the range to assign to DHCP clients.
RangeEnd netip.Addr
// Options is the list of DHCP options to send to DHCP clients.
Options layers.DHCPOptions
// LeaseDuration is the TTL of a DHCP lease.
LeaseDuration time.Duration
// Enabled is the state of the DHCPv4 service, whether it is enabled or not
// on the specific interface.
Enabled bool
}
// IPv6Config is the interface-specific configuration for DHCPv6.
type IPv6Config struct {
// RangeStart is the first address in the range to assign to DHCP clients.
RangeStart netip.Addr
// Options is the list of DHCP options to send to DHCP clients.
Options layers.DHCPOptions
// LeaseDuration is the TTL of a DHCP lease.
LeaseDuration time.Duration
// RASlaacOnly defines whether the DHCP clients should only use SLAAC for
// address assignment.
RASLAACOnly bool
// RAAllowSlaac defines whether the DHCP clients may use SLAAC for address
// assignment.
RAAllowSLAAC bool
// Enabled is the state of the DHCPv6 service, whether it is enabled or not
// on the specific interface.
Enabled bool
}

120
internal/dhcpsvc/dhcpsvc.go Normal file
View File

@@ -0,0 +1,120 @@
// Package dhcpsvc contains the AdGuard Home DHCP service.
//
// TODO(e.burkov): Add tests.
package dhcpsvc
import (
"context"
"net"
"net/netip"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
)
// Lease is a DHCP lease.
//
// TODO(e.burkov): Consider it to [agh], since it also may be needed in
// [websvc]. Also think of implementing iterating methods with appropriate
// signatures.
type Lease struct {
// IP is the IP address leased to the client.
IP netip.Addr
// Expiry is the expiration time of the lease.
Expiry time.Time
// Hostname of the client.
Hostname string
// HWAddr is the physical hardware address (MAC address).
HWAddr net.HardwareAddr
// IsStatic defines if the lease is static.
IsStatic bool
}
type Interface interface {
agh.ServiceWithConfig[*Config]
// Enabled returns true if DHCP provides information about clients.
Enabled() (ok bool)
// HostByIP returns the hostname of the DHCP client with the given IP
// address. The address will be netip.Addr{} if there is no such client,
// due to an assumption that a DHCP client must always have an IP address.
HostByIP(ip netip.Addr) (host string)
// MACByIP returns the MAC address for the given IP address leased. It
// returns nil if there is no such client, due to an assumption that a DHCP
// client must always have a MAC address.
MACByIP(ip netip.Addr) (mac net.HardwareAddr)
// IPByHost returns the IP address of the DHCP client with the given
// hostname. The hostname will be an empty string if there is no such
// client, due to an assumption that a DHCP client must always have a
// hostname, either set by the client or assigned automatically.
IPByHost(host string) (ip netip.Addr)
// Leases returns all the DHCP leases.
Leases() (leases []*Lease)
// AddLease adds a new DHCP lease. It returns an error if the lease is
// invalid or already exists.
AddLease(l *Lease) (err error)
// EditLease changes an existing DHCP lease. It returns an error if there
// is no lease equal to old or if new is invalid or already exists.
EditLease(old, new *Lease) (err error)
// RemoveLease removes an existing DHCP lease. It returns an error if there
// is no lease equal to l.
RemoveLease(l *Lease) (err error)
// Reset removes all the DHCP leases.
Reset() (err error)
}
// Empty is an [Interface] implementation that does nothing.
type Empty struct{}
// type check
var _ Interface = Empty{}
// Start implements the [Service] interface for Empty.
func (Empty) Start() (err error) { return nil }
// Shutdown implements the [Service] interface for Empty.
func (Empty) Shutdown(_ context.Context) (err error) { return nil }
var _ agh.ServiceWithConfig[*Config] = Empty{}
// Config implements the [ServiceWithConfig] interface for Empty.
func (Empty) Config() (conf *Config) { return nil }
// Enabled implements the [Interface] interface for Empty.
func (Empty) Enabled() (ok bool) { return false }
// HostByIP implements the [Interface] interface for Empty.
func (Empty) HostByIP(_ netip.Addr) (host string) { return "" }
// MACByIP implements the [Interface] interface for Empty.
func (Empty) MACByIP(_ netip.Addr) (mac net.HardwareAddr) { return nil }
// IPByHost implements the [Interface] interface for Empty.
func (Empty) IPByHost(_ string) (ip netip.Addr) { return netip.Addr{} }
// Leases implements the [Interface] interface for Empty.
func (Empty) Leases() (leases []*Lease) { return nil }
// AddLease implements the [Interface] interface for Empty.
func (Empty) AddLease(_ *Lease) (err error) { return nil }
// EditLease implements the [Interface] interface for Empty.
func (Empty) EditLease(_, _ *Lease) (err error) { return nil }
// RemoveLease implements the [Interface] interface for Empty.
func (Empty) RemoveLease(_ *Lease) (err error) { return nil }
// Reset implements the [Interface] interface for Empty.
func (Empty) Reset() (err error) { return nil }

View File

@@ -15,7 +15,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
@@ -436,102 +435,6 @@ func (s *Server) initDefaultSettings() {
}
}
// UpstreamHTTPVersions returns the HTTP versions for upstream configuration
// depending on configuration.
func UpstreamHTTPVersions(http3 bool) (v []upstream.HTTPVersion) {
if !http3 {
return upstream.DefaultHTTPVersions
}
return []upstream.HTTPVersion{
upstream.HTTPVersion3,
upstream.HTTPVersion2,
upstream.HTTPVersion11,
}
}
// prepareUpstreamSettings - prepares upstream DNS server settings
func (s *Server) prepareUpstreamSettings() error {
// We're setting a customized set of RootCAs. The reason is that Go default
// mechanism of loading TLS roots does not always work properly on some
// routers so we're loading roots manually and pass it here.
//
// See [aghtls.SystemRootCAs].
upstream.RootCAs = s.conf.TLSv12Roots
upstream.CipherSuites = s.conf.TLSCiphers
// Load upstreams either from the file, or from the settings
var upstreams []string
if s.conf.UpstreamDNSFileName != "" {
data, err := os.ReadFile(s.conf.UpstreamDNSFileName)
if err != nil {
return fmt.Errorf("reading upstream from file: %w", err)
}
upstreams = stringutil.SplitTrimmed(string(data), "\n")
log.Debug("dns: using %d upstream servers from file %s", len(upstreams), s.conf.UpstreamDNSFileName)
} else {
upstreams = s.conf.UpstreamDNS
}
httpVersions := UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams)
upstreams = stringutil.FilterOut(upstreams, IsCommentOrEmpty)
upstreamConfig, err := proxy.ParseUpstreamsConfig(
upstreams,
&upstream.Options{
Bootstrap: s.conf.BootstrapDNS,
Timeout: s.conf.UpstreamTimeout,
HTTPVersions: httpVersions,
PreferIPv6: s.conf.BootstrapPreferIPv6,
},
)
if err != nil {
return fmt.Errorf("parsing upstream config: %w", err)
}
if len(upstreamConfig.Upstreams) == 0 {
log.Info("warning: no default upstream servers specified, using %v", defaultDNS)
var uc *proxy.UpstreamConfig
uc, err = proxy.ParseUpstreamsConfig(
defaultDNS,
&upstream.Options{
Bootstrap: s.conf.BootstrapDNS,
Timeout: s.conf.UpstreamTimeout,
HTTPVersions: httpVersions,
PreferIPv6: s.conf.BootstrapPreferIPv6,
},
)
if err != nil {
return fmt.Errorf("parsing default upstreams: %w", err)
}
upstreamConfig.Upstreams = uc.Upstreams
}
s.conf.UpstreamConfig = upstreamConfig
return nil
}
// setProxyUpstreamMode sets the upstream mode and related settings in conf
// based on provided parameters.
func setProxyUpstreamMode(
conf *proxy.Config,
allServers bool,
fastestAddr bool,
fastestTimeout time.Duration,
) {
if allServers {
conf.UpstreamMode = proxy.UModeParallel
} else if fastestAddr {
conf.UpstreamMode = proxy.UModeFastestAddr
conf.FastestPingTimeout = fastestTimeout
} else {
conf.UpstreamMode = proxy.UModeLoadBalance
}
}
// prepareIpsetListSettings reads and prepares the ipset configuration either
// from a file or from the data in the configuration file.
func (s *Server) prepareIpsetListSettings() (err error) {
@@ -540,6 +443,7 @@ func (s *Server) prepareIpsetListSettings() (err error) {
return s.ipset.init(s.conf.IpsetList)
}
// #nosec G304 -- Trust the path explicitly given by the user.
data, err := os.ReadFile(fn)
if err != nil {
return err

View File

@@ -145,10 +145,13 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, pctx *proxy.DNSContext) error
// processRecursion checks the incoming request and halts its handling by
// answering NXDOMAIN if s has tried to resolve it recently.
func (s *Server) processRecursion(dctx *dnsContext) (rc resultCode) {
log.Debug("dnsforward: started processing recursion")
defer log.Debug("dnsforward: finished processing recursion")
pctx := dctx.proxyCtx
if msg := pctx.Req; msg != nil && s.recDetector.check(*msg) {
log.Debug("recursion detected resolving %q", msg.Question[0].Name)
log.Debug("dnsforward: recursion detected resolving %q", msg.Question[0].Name)
pctx.Res = s.genNXDomain(pctx.Req)
return resultCodeFinish
@@ -158,10 +161,13 @@ func (s *Server) processRecursion(dctx *dnsContext) (rc resultCode) {
}
// processInitial terminates the following processing for some requests if
// needed and enriches the ctx with some client-specific information.
// needed and enriches dctx with some client-specific information.
//
// TODO(e.burkov): Decompose into less general processors.
func (s *Server) processInitial(dctx *dnsContext) (rc resultCode) {
log.Debug("dnsforward: started processing initial")
defer log.Debug("dnsforward: finished processing initial")
pctx := dctx.proxyCtx
q := pctx.Req.Question[0]
qt := q.Qtype
@@ -282,6 +288,9 @@ func (s *Server) onDHCPLeaseChanged(flags int) {
//
// See https://www.ietf.org/archive/id/draft-ietf-add-ddr-10.html.
func (s *Server) processDDRQuery(dctx *dnsContext) (rc resultCode) {
log.Debug("dnsforward: started processing ddr")
defer log.Debug("dnsforward: finished processing ddr")
if !s.conf.HandleDDR {
return resultCodeSuccess
}
@@ -375,6 +384,9 @@ func (s *Server) makeDDRResponse(req *dns.Msg) (resp *dns.Msg) {
// processDetermineLocal determines if the client's IP address is from locally
// served network and saves the result into the context.
func (s *Server) processDetermineLocal(dctx *dnsContext) (rc resultCode) {
log.Debug("dnsforward: started processing local detection")
defer log.Debug("dnsforward: finished processing local detection")
rc = resultCodeSuccess
var ip net.IP
@@ -405,6 +417,9 @@ func (s *Server) dhcpHostToIP(host string) (ip netip.Addr, ok bool) {
//
// TODO(a.garipov): Adapt to AAAA as well.
func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) {
log.Debug("dnsforward: started processing dhcp hosts")
defer log.Debug("dnsforward: finished processing dhcp hosts")
pctx := dctx.proxyCtx
req := pctx.Req
q := req.Question[0]
@@ -544,6 +559,9 @@ func extractARPASubnet(domain string) (pref netip.Prefix, err error) {
// processRestrictLocal responds with NXDOMAIN to PTR requests for IP addresses
// in locally served network from external clients.
func (s *Server) processRestrictLocal(dctx *dnsContext) (rc resultCode) {
log.Debug("dnsforward: started processing local restriction")
defer log.Debug("dnsforward: finished processing local restriction")
pctx := dctx.proxyCtx
req := pctx.Req
q := req.Question[0]
@@ -613,6 +631,9 @@ func (s *Server) ipToDHCPHost(ip netip.Addr) (host string, ok bool) {
// processDHCPAddrs responds to PTR requests if the target IP is leased by the
// DHCP server.
func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) {
log.Debug("dnsforward: started processing dhcp addrs")
defer log.Debug("dnsforward: finished processing dhcp addrs")
pctx := dctx.proxyCtx
if pctx.Res != nil {
return resultCodeSuccess
@@ -658,6 +679,9 @@ func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) {
// processLocalPTR responds to PTR requests if the target IP is detected to be
// inside the local network and the query was not answered from DHCP.
func (s *Server) processLocalPTR(dctx *dnsContext) (rc resultCode) {
log.Debug("dnsforward: started processing local ptr")
defer log.Debug("dnsforward: finished processing local ptr")
pctx := dctx.proxyCtx
if pctx.Res != nil {
return resultCodeSuccess
@@ -692,6 +716,9 @@ func (s *Server) processLocalPTR(dctx *dnsContext) (rc resultCode) {
// Apply filtering logic
func (s *Server) processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode) {
log.Debug("dnsforward: started processing filtering before req")
defer log.Debug("dnsforward: finished processing filtering before req")
if ctx.proxyCtx.Res != nil {
// Go on since the response is already set.
return resultCodeSuccess
@@ -725,6 +752,9 @@ func ipStringFromAddr(addr net.Addr) (ipStr string) {
// processUpstream passes request to upstream servers and handles the response.
func (s *Server) processUpstream(dctx *dnsContext) (rc resultCode) {
log.Debug("dnsforward: started processing upstream")
defer log.Debug("dnsforward: finished processing upstream")
pctx := dctx.proxyCtx
req := pctx.Req
q := req.Question[0]
@@ -871,6 +901,9 @@ func (s *Server) setCustomUpstream(pctx *proxy.DNSContext, clientID string) {
// Apply filtering logic after we have received response from upstream servers
func (s *Server) processFilteringAfterResponse(dctx *dnsContext) (rc resultCode) {
log.Debug("dnsforward: started processing filtering after resp")
defer log.Debug("dnsforward: finished processing filtering after resp")
pctx := dctx.proxyCtx
switch res := dctx.result; res.Reason {
case filtering.NotFilteredAllowList:

View File

@@ -48,12 +48,33 @@ var webRegistered bool
// hostToIPTable is a convenient type alias for tables of host names to an IP
// address.
//
// TODO(e.burkov): Use the [DHCP] interface instead.
type hostToIPTable = map[string]netip.Addr
// ipToHostTable is a convenient type alias for tables of IP addresses to their
// host names. For example, for use with PTR queries.
//
// TODO(e.burkov): Use the [DHCP] interface instead.
type ipToHostTable = map[netip.Addr]string
// DHCP is an interface for accessing DHCP lease data needed in this package.
type DHCP interface {
// HostByIP returns the hostname of the DHCP client with the given IP
// address. The address will be netip.Addr{} if there is no such client,
// due to an assumption that a DHCP client must always have an IP address.
HostByIP(ip netip.Addr) (host string)
// IPByHost returns the IP address of the DHCP client with the given
// hostname. The hostname will be an empty string if there is no such
// client, due to an assumption that a DHCP client must always have a
// hostname, either set by the client or assigned automatically.
IPByHost(host string) (ip netip.Addr)
// Enabled returns true if DHCP provides information about clients.
Enabled() (ok bool)
}
// Server is the main way to start a DNS server.
//
// Example:
@@ -215,7 +236,7 @@ func (s *Server) Close() {
s.dnsProxy = nil
if err := s.ipset.close(); err != nil {
log.Error("closing ipset: %s", err)
log.Error("dnsforward: closing ipset: %s", err)
}
}
@@ -443,21 +464,17 @@ func (s *Server) setupResolvers(localAddrs []string) (err error) {
return err
}
log.Debug("upstreams to resolve PTR for local addresses: %v", localAddrs)
log.Debug("dnsforward: upstreams to resolve ptr for local addresses: %v", localAddrs)
var upsConfig *proxy.UpstreamConfig
upsConfig, err = proxy.ParseUpstreamsConfig(
localAddrs,
&upstream.Options{
Bootstrap: bootstraps,
Timeout: defaultLocalTimeout,
// TODO(e.burkov): Should we verify server's certificates?
upsConfig, err := s.prepareUpstreamConfig(localAddrs, nil, &upstream.Options{
Bootstrap: bootstraps,
Timeout: defaultLocalTimeout,
// TODO(e.burkov): Should we verify server's certificates?
PreferIPv6: s.conf.BootstrapPreferIPv6,
},
)
PreferIPv6: s.conf.BootstrapPreferIPv6,
})
if err != nil {
return fmt.Errorf("parsing upstreams: %w", err)
return fmt.Errorf("parsing private upstreams: %w", err)
}
s.localResolvers = &proxy.Proxy{
@@ -489,7 +506,8 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
err = s.prepareUpstreamSettings()
if err != nil {
return fmt.Errorf("preparing upstream settings: %w", err)
// Don't wrap the error, because it's informative enough as is.
return err
}
var proxyConfig proxy.Config
@@ -656,7 +674,9 @@ func (s *Server) Reconfigure(conf *ServerConfig) error {
s.serverLock.Lock()
defer s.serverLock.Unlock()
log.Print("Start reconfiguring the server")
log.Info("dnsforward: starting reconfiguring server")
defer log.Info("dnsforward: finished reconfiguring server")
err := s.stopLocked()
if err != nil {
return fmt.Errorf("could not reconfigure the server: %w", err)
@@ -708,13 +728,13 @@ func (s *Server) IsBlockedClient(ip netip.Addr, clientID string) (blocked bool,
// Allow if at least one of the checks allows in allowlist mode, but block
// if at least one of the checks blocks in blocklist mode.
if allowlistMode && blockedByIP && blockedByClientID {
log.Debug("client %v (id %q) is not in access allowlist", ip, clientID)
log.Debug("dnsforward: client %v (id %q) is not in access allowlist", ip, clientID)
// Return now without substituting the empty rule for the
// clientID because the rule can't be empty here.
return true, rule
} else if !allowlistMode && (blockedByIP || blockedByClientID) {
log.Debug("client %v (id %q) is in access blocklist", ip, clientID)
log.Debug("dnsforward: client %v (id %q) is in access blocklist", ip, clientID)
blocked = true
}

View File

@@ -53,14 +53,14 @@ func (s *Server) beforeRequestHandler(
// getClientRequestFilteringSettings looks up client filtering settings using
// the client's IP address and ID, if any, from dctx.
func (s *Server) getClientRequestFilteringSettings(dctx *dnsContext) *filtering.Settings {
setts := s.dnsFilter.GetConfig()
setts := s.dnsFilter.Settings()
setts.ProtectionEnabled = dctx.protectionEnabled
if s.conf.FilterHandler != nil {
ip, _ := netutil.IPAndPortFromAddr(dctx.proxyCtx.Addr)
s.conf.FilterHandler(ip, dctx.clientID, &setts)
s.conf.FilterHandler(ip, dctx.clientID, setts)
}
return &setts
return setts
}
// filterDNSRequest applies the dnsFilter and sets dctx.proxyCtx.Res if the

View File

@@ -633,61 +633,70 @@ func (err domainSpecificTestError) Error() (msg string) {
return fmt.Sprintf("WARNING: %s", err.error)
}
// checkDNS checks the upstream server defined by upstreamConfigStr using
// healthCheck for actually exchange messages. It uses bootstrap to resolve the
// upstream's address.
func checkDNS(
upstreamConfigStr string,
bootstrap []string,
bootstrapPrefIPv6 bool,
timeout time.Duration,
healthCheck healthCheckFunc,
) (err error) {
if IsCommentOrEmpty(upstreamConfigStr) {
return nil
// parseUpstreamLine parses line and creates the [upstream.Upstream] using opts
// and information from [s.dnsFilter.EtcHosts]. It returns an error if the line
// is not a valid upstream line, see [upstream.AddressToUpstream]. It's a
// caller's responsibility to close u.
func (s *Server) parseUpstreamLine(
line string,
opts *upstream.Options,
) (u upstream.Upstream, specific bool, err error) {
// Separate upstream from domains list.
upstreamAddr, domains, err := separateUpstream(line)
if err != nil {
return nil, false, fmt.Errorf("wrong upstream format: %w", err)
}
// Separate upstream from domains list.
upstreamAddr, domains, err := separateUpstream(upstreamConfigStr)
if err != nil {
return fmt.Errorf("wrong upstream format: %w", err)
}
specific = len(domains) > 0
useDefault, err := validateUpstream(upstreamAddr, domains)
if err != nil {
return fmt.Errorf("wrong upstream format: %w", err)
return nil, specific, fmt.Errorf("wrong upstream format: %w", err)
} else if useDefault {
return nil
}
if len(bootstrap) == 0 {
bootstrap = defaultBootstrap
return nil, specific, nil
}
log.Debug("dnsforward: checking if upstream %q works", upstreamAddr)
u, err := upstream.AddressToUpstream(upstreamAddr, &upstream.Options{
Bootstrap: bootstrap,
Timeout: timeout,
PreferIPv6: bootstrapPrefIPv6,
})
opts = &upstream.Options{
Bootstrap: opts.Bootstrap,
Timeout: opts.Timeout,
PreferIPv6: opts.PreferIPv6,
}
if s.dnsFilter != nil && s.dnsFilter.EtcHosts != nil {
resolved := s.resolveUpstreamHost(extractUpstreamHost(upstreamAddr))
sortNetIPAddrs(resolved, opts.PreferIPv6)
opts.ServerIPAddrs = resolved
}
u, err = upstream.AddressToUpstream(upstreamAddr, opts)
if err != nil {
return fmt.Errorf("failed to choose upstream for %q: %w", upstreamAddr, err)
return nil, specific, fmt.Errorf("creating upstream for %q: %w", upstreamAddr, err)
}
return u, specific, nil
}
func (s *Server) checkDNS(line string, opts *upstream.Options, check healthCheckFunc) (err error) {
if IsCommentOrEmpty(line) {
return nil
}
var u upstream.Upstream
var specific bool
defer func() {
if err != nil && specific {
err = domainSpecificTestError{error: err}
}
}()
u, specific, err = s.parseUpstreamLine(line, opts)
if err != nil || u == nil {
return err
}
defer func() { err = errors.WithDeferred(err, u.Close()) }()
if err = healthCheck(u); err != nil {
err = fmt.Errorf("upstream %q fails to exchange: %w", upstreamAddr, err)
if domains != nil {
return domainSpecificTestError{error: err}
}
return err
}
log.Debug("dnsforward: upstream %q is ok", upstreamAddr)
return nil
return check(u)
}
func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
@@ -699,47 +708,54 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
return
}
result := map[string]string{}
bootstraps := req.BootstrapDNS
bootstrapPrefIPv6 := s.conf.BootstrapPreferIPv6
timeout := s.conf.UpstreamTimeout
opts := &upstream.Options{
Bootstrap: req.BootstrapDNS,
Timeout: s.conf.UpstreamTimeout,
PreferIPv6: s.conf.BootstrapPreferIPv6,
}
if len(opts.Bootstrap) == 0 {
opts.Bootstrap = defaultBootstrap
}
type upsCheckResult = struct {
res string
err error
host string
}
req.Upstreams = stringutil.FilterOut(req.Upstreams, IsCommentOrEmpty)
req.PrivateUpstreams = stringutil.FilterOut(req.PrivateUpstreams, IsCommentOrEmpty)
upsNum := len(req.Upstreams) + len(req.PrivateUpstreams)
result := make(map[string]string, upsNum)
resCh := make(chan upsCheckResult, upsNum)
checkUps := func(ups string, healthCheck healthCheckFunc) {
res := upsCheckResult{
host: ups,
}
defer func() { resCh <- res }()
checkErr := checkDNS(ups, bootstraps, bootstrapPrefIPv6, timeout, healthCheck)
if checkErr != nil {
res.res = checkErr.Error()
} else {
res.res = "OK"
}
}
for _, ups := range req.Upstreams {
go checkUps(ups, checkDNSUpstreamExc)
go func(ups string) {
resCh <- upsCheckResult{
host: ups,
err: s.checkDNS(ups, opts, checkDNSUpstreamExc),
}
}(ups)
}
for _, ups := range req.PrivateUpstreams {
go checkUps(ups, checkPrivateUpstreamExc)
go func(ups string) {
resCh <- upsCheckResult{
host: ups,
err: s.checkDNS(ups, opts, checkPrivateUpstreamExc),
}
}(ups)
}
for i := 0; i < upsNum; i++ {
pair := <-resCh
// TODO(e.burkov): The upstreams used for both common and private
// resolving should be reported separately.
result[pair.host] = pair.res
pair := <-resCh
if pair.err != nil {
result[pair.host] = pair.err.Error()
} else {
result[pair.host] = "OK"
}
}
close(resCh)
_ = aghhttp.WriteJSONResponse(w, r, result)
}

View File

@@ -13,10 +13,12 @@ import (
"path/filepath"
"strings"
"testing"
"testing/fstest"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/httphdr"
"github.com/AdguardTeam/golibs/netutil"
@@ -280,6 +282,10 @@ func TestIsCommentOrEmpty(t *testing.T) {
}
func TestValidateUpstreams(t *testing.T) {
const sdnsStamp = `sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_J` +
`S3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczE` +
`uYWRndWFyZC5jb20`
testCases := []struct {
name string
wantErr string
@@ -300,7 +306,7 @@ func TestValidateUpstreams(t *testing.T) {
"[//]tls://1.1.1.1",
"[/www.host.com/]#",
"[/host.com/google.com/]8.8.8.8",
"[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
"[/host/]" + sdnsStamp,
},
}, {
name: "with_default",
@@ -310,7 +316,7 @@ func TestValidateUpstreams(t *testing.T) {
"[//]tls://1.1.1.1",
"[/www.host.com/]#",
"[/host.com/google.com/]8.8.8.8",
"[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
"[/host/]" + sdnsStamp,
"8.8.8.8",
},
}, {
@@ -326,9 +332,10 @@ func TestValidateUpstreams(t *testing.T) {
wantErr: `validating upstream "123.3.7m": not an ip:port`,
set: []string{"123.3.7m"},
}, {
name: "invalid",
wantErr: `bad upstream for domain "[/host.com]tls://dns.adguard.com": missing separator`,
set: []string{"[/host.com]tls://dns.adguard.com"},
name: "invalid",
wantErr: `bad upstream for domain "[/host.com]tls://dns.adguard.com": ` +
`missing separator`,
set: []string{"[/host.com]tls://dns.adguard.com"},
}, {
name: "invalid",
wantErr: `validating upstream "[host.ru]#": not an ip:port`,
@@ -340,14 +347,14 @@ func TestValidateUpstreams(t *testing.T) {
"1.1.1.1",
"tls://1.1.1.1",
"https://dns.adguard.com/dns-query",
"sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
sdnsStamp,
"udp://dns.google",
"udp://8.8.8.8",
"[/host.com/]1.1.1.1",
"[//]tls://1.1.1.1",
"[/www.host.com/]#",
"[/host.com/google.com/]8.8.8.8",
"[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
"[/host/]" + sdnsStamp,
"[/пример.рф/]8.8.8.8",
},
}, {
@@ -418,27 +425,28 @@ func TestValidateUpstreamsPrivate(t *testing.T) {
}
}
func newLocalUpstreamListener(t *testing.T, port int, handler dns.Handler) (real net.Addr) {
func newLocalUpstreamListener(t *testing.T, port uint16, handler dns.Handler) (real netip.AddrPort) {
t.Helper()
startCh := make(chan struct{})
upsSrv := &dns.Server{
Addr: netip.AddrPortFrom(netutil.IPv4Localhost(), uint16(port)).String(),
Addr: netip.AddrPortFrom(netutil.IPv4Localhost(), port).String(),
Net: "tcp",
Handler: handler,
NotifyStartedFunc: func() { close(startCh) },
}
go func() {
t := testutil.PanicT{}
err := upsSrv.ListenAndServe()
require.NoError(t, err)
require.NoError(testutil.PanicT{}, err)
}()
<-startCh
testutil.CleanupAndRequireSuccess(t, upsSrv.Shutdown)
return upsSrv.Listener.Addr()
return testutil.RequireTypeAssert[*net.TCPAddr](t, upsSrv.Listener.Addr()).AddrPort()
}
func TestServer_handleTestUpstreaDNS(t *testing.T) {
func TestServer_HandleTestUpstreamDNS(t *testing.T) {
goodHandler := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
err := w.WriteMsg(new(dns.Msg).SetReply(m))
require.NoError(testutil.PanicT{}, err)
@@ -457,9 +465,38 @@ func TestServer_handleTestUpstreaDNS(t *testing.T) {
Host: newLocalUpstreamListener(t, 0, badHandler).String(),
}).String()
const upsTimeout = 100 * time.Millisecond
const (
upsTimeout = 100 * time.Millisecond
srv := createTestServer(t, &filtering.Config{}, ServerConfig{
hostsFileName = "hosts"
upstreamHost = "custom.localhost"
)
hostsListener := newLocalUpstreamListener(t, 0, goodHandler)
hostsUps := (&url.URL{
Scheme: "tcp",
Host: netutil.JoinHostPort(upstreamHost, int(hostsListener.Port())),
}).String()
hc, err := aghnet.NewHostsContainer(
filtering.SysHostsListID,
fstest.MapFS{
hostsFileName: &fstest.MapFile{
Data: []byte(hostsListener.Addr().String() + " " + upstreamHost),
},
},
&aghtest.FSWatcher{
OnEvents: func() (e <-chan struct{}) { return nil },
OnAdd: func(_ string) (err error) { return nil },
OnClose: func() (err error) { return nil },
},
hostsFileName,
)
require.NoError(t, err)
srv := createTestServer(t, &filtering.Config{
EtcHosts: hc,
}, ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
UpstreamTimeout: upsTimeout,
@@ -486,8 +523,7 @@ func TestServer_handleTestUpstreaDNS(t *testing.T) {
"upstream_dns": []string{badUps},
},
wantResp: map[string]any{
badUps: `upstream "` + badUps + `" fails to exchange: ` +
`couldn't communicate with upstream: exchanging with ` +
badUps: `couldn't communicate with upstream: exchanging with ` +
badUps + ` over tcp: dns: id mismatch`,
},
name: "broken",
@@ -497,20 +533,40 @@ func TestServer_handleTestUpstreaDNS(t *testing.T) {
},
wantResp: map[string]any{
goodUps: "OK",
badUps: `upstream "` + badUps + `" fails to exchange: ` +
`couldn't communicate with upstream: exchanging with ` +
badUps: `couldn't communicate with upstream: exchanging with ` +
badUps + ` over tcp: dns: id mismatch`,
},
name: "both",
}, {
body: map[string]any{
"upstream_dns": []string{"[/domain.example/]" + badUps},
},
wantResp: map[string]any{
"[/domain.example/]" + badUps: `WARNING: couldn't communicate ` +
`with upstream: exchanging with ` + badUps + ` over tcp: ` +
`dns: id mismatch`,
},
name: "domain_specific_error",
}, {
body: map[string]any{
"upstream_dns": []string{hostsUps},
},
wantResp: map[string]any{
hostsUps: "OK",
},
name: "etc_hosts",
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
reqBody, err := json.Marshal(tc.body)
var reqBody []byte
reqBody, err = json.Marshal(tc.body)
require.NoError(t, err)
w := httptest.NewRecorder()
r, err := http.NewRequest(http.MethodPost, "", bytes.NewReader(reqBody))
var r *http.Request
r, err = http.NewRequest(http.MethodPost, "", bytes.NewReader(reqBody))
require.NoError(t, err)
srv.handleTestUpstreamDNS(w, r)
@@ -538,11 +594,15 @@ func TestServer_handleTestUpstreaDNS(t *testing.T) {
req := map[string]any{
"upstream_dns": []string{sleepyUps},
}
reqBody, err := json.Marshal(req)
var reqBody []byte
reqBody, err = json.Marshal(req)
require.NoError(t, err)
w := httptest.NewRecorder()
r, err := http.NewRequest(http.MethodPost, "", bytes.NewReader(reqBody))
var r *http.Request
r, err = http.NewRequest(http.MethodPost, "", bytes.NewReader(reqBody))
require.NoError(t, err)
srv.handleTestUpstreamDNS(w, r)

View File

@@ -110,6 +110,9 @@ func ipsFromAnswer(ans []dns.RR) (ip4s, ip6s []net.IP) {
// process adds the resolved IP addresses to the domain's ipsets, if any.
func (c *ipsetCtx) process(dctx *dnsContext) (rc resultCode) {
log.Debug("dnsforward: ipset: started processing")
defer log.Debug("dnsforward: ipset: finished processing")
if c.skipIpsetProcessing(dctx) {
return resultCodeSuccess
}
@@ -125,12 +128,12 @@ func (c *ipsetCtx) process(dctx *dnsContext) (rc resultCode) {
n, err := c.ipsetMgr.Add(host, ip4s, ip6s)
if err != nil {
// Consider ipset errors non-critical to the request.
log.Error("ipset: adding host ips: %s", err)
log.Error("dnsforward: ipset: adding host ips: %s", err)
return resultCodeSuccess
}
log.Debug("ipset: added %d new ipset entries", n)
log.Debug("dnsforward: ipset: added %d new ipset entries", n)
return resultCodeSuccess
}

View File

@@ -57,16 +57,13 @@ func (s *Server) genDNSFilterMessage(
return s.genBlockedHost(req, s.conf.SafeBrowsingBlockHost, dctx)
case filtering.FilteredParental:
return s.genBlockedHost(req, s.conf.ParentalBlockHost, dctx)
case filtering.FilteredSafeSearch:
// If Safe Search generated the necessary IP addresses, use them.
// Otherwise, if there were no errors, there are no addresses for the
// requested IP version, so produce a NODATA response.
return s.genResponseWithIPs(req, ipsFromRules(res.Rules))
default:
// If the query was filtered by Safe Search, filtering also must return
// the IP addresses that must be used in response. Return them
// regardless of the filtering method.
ips := ipsFromRules(res.Rules)
if res.Reason == filtering.FilteredSafeSearch && len(ips) > 0 {
return s.genResponseWithIPs(req, ips)
}
return s.genForBlockingMode(req, ips)
return s.genForBlockingMode(req, ipsFromRules(res.Rules))
}
}

View File

@@ -17,60 +17,78 @@ import (
// Write Stats data and logs
func (s *Server) processQueryLogsAndStats(dctx *dnsContext) (rc resultCode) {
log.Debug("dnsforward: started processing querylog and stats")
defer log.Debug("dnsforward: finished processing querylog and stats")
elapsed := time.Since(dctx.startTime)
pctx := dctx.proxyCtx
shouldLog := true
msg := pctx.Req
q := msg.Question[0]
q := pctx.Req.Question[0]
host := strings.ToLower(strings.TrimSuffix(q.Name, "."))
// don't log ANY request if refuseAny is enabled
if q.Qtype == dns.TypeANY && s.conf.RefuseAny {
shouldLog = false
}
ip, _ := netutil.IPAndPortFromAddr(pctx.Addr)
ip = slices.Clone(ip)
s.serverLock.RLock()
defer s.serverLock.RUnlock()
s.anonymizer.Load()(ip)
log.Debug("client ip: %s", ip)
log.Debug("dnsforward: client ip for stats and querylog: %s", ip)
ipStr := ip.String()
ids := []string{ipStr, dctx.clientID}
qt, cl := q.Qtype, q.Qclass
// Synchronize access to s.queryLog and s.stats so they won't be suddenly
// uninitialized while in use. This can happen after proxy server has been
// stopped, but its workers haven't yet exited.
if shouldLog &&
s.queryLog != nil &&
// TODO(s.chzhen): Use dnsforward.dnsContext when it will start
// containing persistent client.
s.queryLog.ShouldLog(host, q.Qtype, q.Qclass, ids) {
s.serverLock.RLock()
defer s.serverLock.RUnlock()
if s.shouldLog(host, qt, cl, ids) {
s.logQuery(dctx, pctx, elapsed, ip)
} else {
log.Debug(
"dnsforward: request %s %s from %s ignored; not logging",
dns.Type(q.Qtype),
"dnsforward: request %s %s %q from %s ignored; not adding to querylog",
dns.Class(cl),
dns.Type(qt),
host,
ip,
)
}
if s.stats != nil &&
// TODO(s.chzhen): Use dnsforward.dnsContext when it will start
// containing persistent client.
s.stats.ShouldCount(host, q.Qtype, q.Qclass, ids) {
if s.shouldCountStat(host, qt, cl, ids) {
s.updateStats(dctx, elapsed, *dctx.result, ipStr)
} else {
log.Debug(
"dnsforward: request %s %s %q from %s ignored; not counting in stats",
dns.Class(cl),
dns.Type(qt),
host,
ip,
)
}
return resultCodeSuccess
}
// shouldLog returns true if the query with the given data should be logged in
// the query log. s.serverLock is expected to be locked.
func (s *Server) shouldLog(host string, qt, cl uint16, ids []string) (ok bool) {
if qt == dns.TypeANY && s.conf.RefuseAny {
return false
}
// TODO(s.chzhen): Use dnsforward.dnsContext when it will start containing
// persistent client.
return s.queryLog != nil && s.queryLog.ShouldLog(host, qt, cl, ids)
}
// shouldCountStat returns true if the query with the given data should be
// counted in the statistics. s.serverLock is expected to be locked.
func (s *Server) shouldCountStat(host string, qt, cl uint16, ids []string) (ok bool) {
// TODO(s.chzhen): Use dnsforward.dnsContext when it will start containing
// persistent client.
return s.stats != nil && s.stats.ShouldCount(host, qt, cl, ids)
}
// logQuery pushes the request details into the query log.
func (s *Server) logQuery(
dctx *dnsContext,
@@ -123,7 +141,10 @@ func (s *Server) updateStats(
pctx := ctx.proxyCtx
e := stats.Entry{}
e.Domain = strings.ToLower(pctx.Req.Question[0].Name)
e.Domain = e.Domain[:len(e.Domain)-1] // remove last "."
if e.Domain != "." {
// Remove last ".", but save the domain as is for "." queries.
e.Domain = e.Domain[:len(e.Domain)-1]
}
if clientID := ctx.clientID; clientID != "" {
e.Client = clientID

View File

@@ -46,6 +46,10 @@ type testStats struct {
// Update implements the [stats.Interface] interface for *testStats.
func (l *testStats) Update(e stats.Entry) {
if e.Domain == "" {
return
}
l.lastEntry = e
}
@@ -54,9 +58,12 @@ func (l *testStats) ShouldCount(string, uint16, uint16, []string) bool {
return true
}
func TestProcessQueryLogsAndStats(t *testing.T) {
func TestServer_ProcessQueryLogsAndStats(t *testing.T) {
const domain = "example.com."
testCases := []struct {
name string
domain string
proto proxy.Proto
addr net.Addr
clientID string
@@ -67,6 +74,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
wantStatResult stats.Result
}{{
name: "success_udp",
domain: domain,
proto: proxy.ProtoUDP,
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
clientID: "",
@@ -77,6 +85,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
wantStatResult: stats.RNotFiltered,
}, {
name: "success_tls_clientid",
domain: domain,
proto: proxy.ProtoTLS,
addr: &net.TCPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
clientID: "cli42",
@@ -87,6 +96,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
wantStatResult: stats.RNotFiltered,
}, {
name: "success_tls",
domain: domain,
proto: proxy.ProtoTLS,
addr: &net.TCPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
clientID: "",
@@ -97,6 +107,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
wantStatResult: stats.RNotFiltered,
}, {
name: "success_quic",
domain: domain,
proto: proxy.ProtoQUIC,
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
clientID: "",
@@ -107,6 +118,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
wantStatResult: stats.RNotFiltered,
}, {
name: "success_https",
domain: domain,
proto: proxy.ProtoHTTPS,
addr: &net.TCPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
clientID: "",
@@ -117,6 +129,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
wantStatResult: stats.RNotFiltered,
}, {
name: "success_dnscrypt",
domain: domain,
proto: proxy.ProtoDNSCrypt,
addr: &net.TCPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
clientID: "",
@@ -127,6 +140,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
wantStatResult: stats.RNotFiltered,
}, {
name: "success_udp_filtered",
domain: domain,
proto: proxy.ProtoUDP,
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
clientID: "",
@@ -137,6 +151,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
wantStatResult: stats.RFiltered,
}, {
name: "success_udp_sb",
domain: domain,
proto: proxy.ProtoUDP,
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
clientID: "",
@@ -147,6 +162,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
wantStatResult: stats.RSafeBrowsing,
}, {
name: "success_udp_ss",
domain: domain,
proto: proxy.ProtoUDP,
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
clientID: "",
@@ -157,6 +173,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
wantStatResult: stats.RSafeSearch,
}, {
name: "success_udp_pc",
domain: domain,
proto: proxy.ProtoUDP,
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
clientID: "",
@@ -165,6 +182,17 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
wantCode: resultCodeSuccess,
reason: filtering.FilteredParental,
wantStatResult: stats.RParental,
}, {
name: "success_udp_pc_empty_fqdn",
domain: ".",
proto: proxy.ProtoUDP,
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 5}, Port: 1234},
clientID: "",
wantLogProto: "",
wantStatClient: "1.2.3.5",
wantCode: resultCodeSuccess,
reason: filtering.FilteredParental,
wantStatResult: stats.RParental,
}}
ups, err := upstream.AddressToUpstream("1.1.1.1", nil)
@@ -181,7 +209,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
req := &dns.Msg{
Question: []dns.Question{{
Name: "example.com.",
Name: tc.domain,
}},
}
pctx := &proxy.DNSContext{

View File

@@ -0,0 +1,311 @@
package dnsforward
import (
"bytes"
"fmt"
"net"
"net/url"
"os"
"strings"
"time"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/AdguardTeam/urlfilter"
"github.com/miekg/dns"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
)
// loadUpstreams parses upstream DNS servers from the configured file or from
// the configuration itself.
func (s *Server) loadUpstreams() (upstreams []string, err error) {
if s.conf.UpstreamDNSFileName == "" {
return stringutil.FilterOut(s.conf.UpstreamDNS, IsCommentOrEmpty), nil
}
var data []byte
data, err = os.ReadFile(s.conf.UpstreamDNSFileName)
if err != nil {
return nil, fmt.Errorf("reading upstream from file: %w", err)
}
upstreams = stringutil.SplitTrimmed(string(data), "\n")
log.Debug("dnsforward: got %d upstreams in %q", len(upstreams), s.conf.UpstreamDNSFileName)
return stringutil.FilterOut(upstreams, IsCommentOrEmpty), nil
}
// prepareUpstreamSettings sets upstream DNS server settings.
func (s *Server) prepareUpstreamSettings() (err error) {
// We're setting a customized set of RootCAs. The reason is that Go default
// mechanism of loading TLS roots does not always work properly on some
// routers so we're loading roots manually and pass it here.
//
// See [aghtls.SystemRootCAs].
upstream.RootCAs = s.conf.TLSv12Roots
upstream.CipherSuites = s.conf.TLSCiphers
// Load upstreams either from the file, or from the settings
var upstreams []string
upstreams, err = s.loadUpstreams()
if err != nil {
return fmt.Errorf("loading upstreams: %w", err)
}
s.conf.UpstreamConfig, err = s.prepareUpstreamConfig(upstreams, defaultDNS, &upstream.Options{
Bootstrap: s.conf.BootstrapDNS,
Timeout: s.conf.UpstreamTimeout,
HTTPVersions: UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams),
PreferIPv6: s.conf.BootstrapPreferIPv6,
})
if err != nil {
return fmt.Errorf("preparing upstream config: %w", err)
}
return nil
}
// prepareUpstreamConfig sets upstream configuration based on upstreams and
// configuration of s.
func (s *Server) prepareUpstreamConfig(
upstreams []string,
defaultUpstreams []string,
opts *upstream.Options,
) (uc *proxy.UpstreamConfig, err error) {
uc, err = proxy.ParseUpstreamsConfig(upstreams, opts)
if err != nil {
return nil, fmt.Errorf("parsing upstream config: %w", err)
}
if len(uc.Upstreams) == 0 && defaultUpstreams != nil {
log.Info("dnsforward: warning: no default upstreams specified, using %v", defaultUpstreams)
var defaultUpstreamConfig *proxy.UpstreamConfig
defaultUpstreamConfig, err = proxy.ParseUpstreamsConfig(defaultUpstreams, opts)
if err != nil {
return nil, fmt.Errorf("parsing default upstreams: %w", err)
}
uc.Upstreams = defaultUpstreamConfig.Upstreams
}
if s.dnsFilter != nil && s.dnsFilter.EtcHosts != nil {
err = s.replaceUpstreamsWithHosts(uc, opts)
if err != nil {
return nil, fmt.Errorf("resolving upstreams with hosts: %w", err)
}
}
return uc, nil
}
// replaceUpstreamsWithHosts replaces unique upstreams with their resolved
// versions based on the system hosts file.
//
// TODO(e.burkov): This should be performed inside dnsproxy, which should
// actually consider /etc/hosts. See TODO on [aghnet.HostsContainer].
func (s *Server) replaceUpstreamsWithHosts(
upsConf *proxy.UpstreamConfig,
opts *upstream.Options,
) (err error) {
resolved := map[string]*upstream.Options{}
err = s.resolveUpstreamsWithHosts(resolved, upsConf.Upstreams, opts)
if err != nil {
return fmt.Errorf("resolving upstreams: %w", err)
}
hosts := maps.Keys(upsConf.DomainReservedUpstreams)
// TODO(e.burkov): Think of extracting sorted range into an util function.
slices.Sort(hosts)
for _, host := range hosts {
err = s.resolveUpstreamsWithHosts(resolved, upsConf.DomainReservedUpstreams[host], opts)
if err != nil {
return fmt.Errorf("resolving upstreams reserved for %s: %w", host, err)
}
}
hosts = maps.Keys(upsConf.SpecifiedDomainUpstreams)
slices.Sort(hosts)
for _, host := range hosts {
err = s.resolveUpstreamsWithHosts(resolved, upsConf.SpecifiedDomainUpstreams[host], opts)
if err != nil {
return fmt.Errorf("resolving upstreams specific for %s: %w", host, err)
}
}
return nil
}
// resolveUpstreamsWithHosts resolves the IP addresses of each of the upstreams
// and replaces those both in upstreams and resolved. Upstreams that failed to
// resolve are placed to resolved as-is. This function only returns error of
// upstreams closing.
func (s *Server) resolveUpstreamsWithHosts(
resolved map[string]*upstream.Options,
upstreams []upstream.Upstream,
opts *upstream.Options,
) (err error) {
for i := range upstreams {
u := upstreams[i]
addr := u.Address()
host := extractUpstreamHost(addr)
withIPs, ok := resolved[host]
if !ok {
ips := s.resolveUpstreamHost(host)
if len(ips) == 0 {
resolved[host] = nil
return nil
}
sortNetIPAddrs(ips, opts.PreferIPv6)
withIPs = opts.Clone()
withIPs.ServerIPAddrs = ips
resolved[host] = withIPs
} else if withIPs == nil {
continue
}
if err = u.Close(); err != nil {
return fmt.Errorf("closing upstream %s: %w", addr, err)
}
upstreams[i], err = upstream.AddressToUpstream(addr, withIPs)
if err != nil {
return fmt.Errorf("replacing upstream %s with resolved %s: %w", addr, host, err)
}
log.Debug("dnsforward: using %s for %s", withIPs.ServerIPAddrs, upstreams[i].Address())
}
return nil
}
// extractUpstreamHost returns the hostname of addr without port with an
// assumption that any address passed here has already been successfully parsed
// by [upstream.AddressToUpstream]. This function eesentially mirrors the logic
// of [upstream.AddressToUpstream], see TODO on [replaceUpstreamsWithHosts].
func extractUpstreamHost(addr string) (host string) {
var err error
if strings.Contains(addr, "://") {
var u *url.URL
u, err = url.Parse(addr)
if err != nil {
log.Debug("dnsforward: parsing upstream %s: %s", addr, err)
return addr
}
return u.Hostname()
}
// Probably, plain UDP upstream defined by address or address:port.
host, err = netutil.SplitHost(addr)
if err != nil {
return addr
}
return host
}
// resolveUpstreamHost returns the version of ups with IP addresses from the
// system hosts file placed into its options.
func (s *Server) resolveUpstreamHost(host string) (addrs []net.IP) {
req := &urlfilter.DNSRequest{
Hostname: host,
DNSType: dns.TypeA,
}
aRes, _ := s.dnsFilter.EtcHosts.MatchRequest(req)
req.DNSType = dns.TypeAAAA
aaaaRes, _ := s.dnsFilter.EtcHosts.MatchRequest(req)
var ips []net.IP
for _, rw := range append(aRes.DNSRewrites(), aaaaRes.DNSRewrites()...) {
dr := rw.DNSRewrite
if dr == nil || dr.Value == nil {
continue
}
if ip, ok := dr.Value.(net.IP); ok {
ips = append(ips, ip)
}
}
return ips
}
// sortNetIPAddrs sorts addrs in accordance with the protocol preferences.
// Invalid addresses are sorted near the end.
//
// TODO(e.burkov): This function taken from dnsproxy, which also already
// contains a few similar functions. Think of moving to golibs.
func sortNetIPAddrs(addrs []net.IP, preferIPv6 bool) {
l := len(addrs)
if l <= 1 {
return
}
slices.SortStableFunc(addrs, func(addrA, addrB net.IP) (sortsBefore bool) {
switch len(addrA) {
case net.IPv4len, net.IPv6len:
switch len(addrB) {
case net.IPv4len, net.IPv6len:
// Go on.
default:
return true
}
default:
return false
}
if aIs4, bIs4 := addrA.To4() != nil, addrB.To4() != nil; aIs4 != bIs4 {
if aIs4 {
return !preferIPv6
}
return preferIPv6
}
return bytes.Compare(addrA, addrB) < 0
})
}
// UpstreamHTTPVersions returns the HTTP versions for upstream configuration
// depending on configuration.
func UpstreamHTTPVersions(http3 bool) (v []upstream.HTTPVersion) {
if !http3 {
return upstream.DefaultHTTPVersions
}
return []upstream.HTTPVersion{
upstream.HTTPVersion3,
upstream.HTTPVersion2,
upstream.HTTPVersion11,
}
}
// setProxyUpstreamMode sets the upstream mode and related settings in conf
// based on provided parameters.
func setProxyUpstreamMode(
conf *proxy.Config,
allServers bool,
fastestAddr bool,
fastestTimeout time.Duration,
) {
if allServers {
conf.UpstreamMode = proxy.UModeParallel
} else if fastestAddr {
conf.UpstreamMode = proxy.UModeFastestAddr
conf.FastestPingTimeout = fastestTimeout
} else {
conf.UpstreamMode = proxy.UModeLoadBalance
}
}

View File

@@ -2,9 +2,12 @@ package filtering
import (
"encoding/json"
"fmt"
"net/http"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/urlfilter/rules"
"golang.org/x/exp/slices"
@@ -44,23 +47,57 @@ func initBlockedServices() {
log.Debug("filtering: initialized %d services", l)
}
// BlockedSvcKnown returns true if a blocked service ID is known.
func BlockedSvcKnown(s string) (ok bool) {
_, ok = serviceRules[s]
// BlockedServices is the configuration of blocked services.
type BlockedServices struct {
// Schedule is blocked services schedule for every day of the week.
Schedule *schedule.Weekly `yaml:"schedule"`
return ok
// IDs is the names of blocked services.
IDs []string `yaml:"ids"`
}
// Clone returns a deep copy of blocked services.
func (s *BlockedServices) Clone() (c *BlockedServices) {
if s == nil {
return nil
}
return &BlockedServices{
Schedule: s.Schedule.Clone(),
IDs: slices.Clone(s.IDs),
}
}
// Validate returns an error if blocked services contain unknown service ID. s
// must not be nil.
func (s *BlockedServices) Validate() (err error) {
for _, id := range s.IDs {
_, ok := serviceRules[id]
if !ok {
return fmt.Errorf("unknown blocked-service %q", id)
}
}
return nil
}
// ApplyBlockedServices - set blocked services settings for this DNS request
func (d *DNSFilter) ApplyBlockedServices(setts *Settings, list []string) {
func (d *DNSFilter) ApplyBlockedServices(setts *Settings) {
d.confLock.RLock()
defer d.confLock.RUnlock()
setts.ServicesRules = []ServiceEntry{}
if list == nil {
d.confLock.RLock()
defer d.confLock.RUnlock()
list = d.Config.BlockedServices
bsvc := d.BlockedServices
// TODO(s.chzhen): Use startTime from [dnsforward.dnsContext].
if !bsvc.Schedule.Contains(time.Now()) {
d.ApplyBlockedServicesList(setts, bsvc.IDs)
}
}
// ApplyBlockedServicesList appends filtering rules to the settings.
func (d *DNSFilter) ApplyBlockedServicesList(setts *Settings, list []string) {
for _, name := range list {
rules, ok := serviceRules[name]
if !ok {
@@ -90,7 +127,7 @@ func (d *DNSFilter) handleBlockedServicesAll(w http.ResponseWriter, r *http.Requ
func (d *DNSFilter) handleBlockedServicesList(w http.ResponseWriter, r *http.Request) {
d.confLock.RLock()
list := d.Config.BlockedServices
list := d.Config.BlockedServices.IDs
d.confLock.RUnlock()
_ = aghhttp.WriteJSONResponse(w, r, list)
@@ -106,7 +143,7 @@ func (d *DNSFilter) handleBlockedServicesSet(w http.ResponseWriter, r *http.Requ
}
d.confLock.Lock()
d.Config.BlockedServices = list
d.Config.BlockedServices.IDs = list
d.confLock.Unlock()
log.Debug("Updated blocked services list: %d", len(list))

View File

@@ -103,9 +103,9 @@ type Config struct {
Rewrites []*LegacyRewrite `yaml:"rewrites"`
// Names of services to block (globally).
// BlockedServices is the configuration of blocked services.
// Per-client settings can override this configuration.
BlockedServices []string `yaml:"blocked_services"`
BlockedServices *BlockedServices `yaml:"blocked_services"`
// EtcHosts is a container of IP-hostname pairs taken from the operating
// system configuration files (e.g. /etc/hosts).
@@ -298,12 +298,12 @@ func (d *DNSFilter) SetEnabled(enabled bool) {
atomic.StoreUint32(&d.enabled, mathutil.BoolToNumber[uint32](enabled))
}
// GetConfig - get configuration
func (d *DNSFilter) GetConfig() (s Settings) {
// Settings returns filtering settings.
func (d *DNSFilter) Settings() (s *Settings) {
d.confLock.RLock()
defer d.confLock.RUnlock()
return Settings{
return &Settings{
FilteringEnabled: atomic.LoadUint32(&d.Config.enabled) != 0,
SafeSearchEnabled: d.Config.SafeSearchConf.Enabled,
SafeBrowsingEnabled: d.Config.SafeBrowsingEnabled,
@@ -519,7 +519,7 @@ func (d *DNSFilter) matchSysHosts(
dnsres, _ := d.EtcHosts.MatchRequest(&urlfilter.DNSRequest{
Hostname: host,
SortedClientTags: setts.ClientTags,
// TODO(e.burkov): Wait for urlfilter update to pass net.IP.
// TODO(e.burkov): Wait for urlfilter update to pass netip.Addr.
ClientIP: setts.ClientIP.String(),
ClientName: setts.ClientName,
DNSType: qtype,
@@ -987,16 +987,13 @@ func New(c *Config, blockFilters []Filter) (d *DNSFilter, err error) {
return nil, fmt.Errorf("rewrites: preparing: %s", err)
}
bsvcs := []string{}
for _, s := range d.BlockedServices {
if !BlockedSvcKnown(s) {
log.Debug("skipping unknown blocked-service %q", s)
if d.BlockedServices != nil {
err = d.BlockedServices.Validate()
continue
if err != nil {
return nil, fmt.Errorf("filtering: %w", err)
}
bsvcs = append(bsvcs, s)
}
d.BlockedServices = bsvcs
if blockFilters != nil {
err = d.initFiltering(nil, blockFilters)

View File

@@ -169,7 +169,7 @@ func (d *DNSFilter) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Requ
deleted = (*filters)[delIdx]
p := deleted.Path(d.DataDir)
err = os.Rename(p, p+".old")
if err != nil {
if err != nil && !errors.Is(err, os.ErrNotExist) {
log.Error("deleting filter %d: renaming file %q: %s", deleted.ID, p, err)
return
@@ -416,12 +416,12 @@ type checkHostResp struct {
func (d *DNSFilter) handleCheckHost(w http.ResponseWriter, r *http.Request) {
host := r.URL.Query().Get("name")
setts := d.GetConfig()
setts := d.Settings()
setts.FilteringEnabled = true
setts.ProtectionEnabled = true
d.ApplyBlockedServices(&setts, nil)
result, err := d.CheckHost(host, dns.TypeA, &setts)
d.ApplyBlockedServices(setts)
result, err := d.CheckHost(host, dns.TypeA, setts)
if err != nil {
aghhttp.Error(
r,
@@ -555,6 +555,7 @@ func (d *DNSFilter) RegisterFilteringHandlers() {
registerHTTP(http.MethodGet, "/control/rewrite/list", d.handleRewriteList)
registerHTTP(http.MethodPost, "/control/rewrite/add", d.handleRewriteAdd)
registerHTTP(http.MethodPut, "/control/rewrite/update", d.handleRewriteUpdate)
registerHTTP(http.MethodPost, "/control/rewrite/delete", d.handleRewriteDelete)
registerHTTP(http.MethodGet, "/control/blocked_services/services", d.handleBlockedServicesIDs)

View File

@@ -84,7 +84,7 @@ func (s *DefaultStorage) MatchRequest(dReq *urlfilter.DNSRequest) (rws []*rules.
return nil
}
// TODO(a.garipov): Check cnames for cycles on initialisation.
// TODO(a.garipov): Check cnames for cycles on initialization.
cnames := stringutil.NewSet()
host := dReq.Hostname
for len(rrules) > 0 && rrules[0].DNSRewrite != nil && rrules[0].DNSRewrite.NewCNAME != "" {

View File

@@ -6,6 +6,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/golibs/log"
"golang.org/x/exp/slices"
)
// TODO(d.kolyshev): Use [rewrite.Item] instead.
@@ -91,3 +92,62 @@ func (d *DNSFilter) handleRewriteDelete(w http.ResponseWriter, r *http.Request)
d.Config.ConfigModified()
}
// rewriteUpdateJSON is a struct for JSON object with rewrite rule update info.
type rewriteUpdateJSON struct {
Target rewriteEntryJSON `json:"target"`
Update rewriteEntryJSON `json:"update"`
}
// handleRewriteUpdate is the handler for the PUT /control/rewrite/update HTTP
// API.
func (d *DNSFilter) handleRewriteUpdate(w http.ResponseWriter, r *http.Request) {
updateJSON := rewriteUpdateJSON{}
err := json.NewDecoder(r.Body).Decode(&updateJSON)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err)
return
}
rwDel := &LegacyRewrite{
Domain: updateJSON.Target.Domain,
Answer: updateJSON.Target.Answer,
}
rwAdd := &LegacyRewrite{
Domain: updateJSON.Update.Domain,
Answer: updateJSON.Update.Answer,
}
err = rwAdd.normalize()
if err != nil {
// Shouldn't happen currently, since normalize only returns a non-nil
// error when a rewrite is nil, but be change-proof.
aghhttp.Error(r, w, http.StatusBadRequest, "normalizing: %s", err)
return
}
index := -1
defer func() {
if index >= 0 {
d.Config.ConfigModified()
}
}()
d.confLock.Lock()
defer d.confLock.Unlock()
index = slices.IndexFunc(d.Config.Rewrites, rwDel.equal)
if index == -1 {
aghhttp.Error(r, w, http.StatusBadRequest, "target rule not found")
return
}
d.Config.Rewrites = slices.Replace(d.Config.Rewrites, index, index+1, rwAdd)
log.Debug("rewrite: removed element: %s -> %s", rwDel.Domain, rwDel.Answer)
log.Debug("rewrite: added element: %s -> %s", rwAdd.Domain, rwAdd.Answer)
}

View File

@@ -0,0 +1,237 @@
package filtering_test
import (
"bytes"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TODO(d.kolyshev): Use [rewrite.Item] instead.
type rewriteJSON struct {
Domain string `json:"domain"`
Answer string `json:"answer"`
}
type rewriteUpdateJSON struct {
Target rewriteJSON `json:"target"`
Update rewriteJSON `json:"update"`
}
const (
// testTimeout is the common timeout for tests.
testTimeout = 100 * time.Millisecond
listURL = "/control/rewrite/list"
addURL = "/control/rewrite/add"
deleteURL = "/control/rewrite/delete"
updateURL = "/control/rewrite/update"
decodeErrorMsg = "json.Decode: json: cannot unmarshal string into Go value of type" +
" filtering.rewriteEntryJSON\n"
)
func TestDNSFilter_handleRewriteHTTP(t *testing.T) {
confModCh := make(chan struct{})
reqCh := make(chan struct{})
testRewrites := []*rewriteJSON{
{Domain: "example.local", Answer: "example.rewrite"},
{Domain: "one.local", Answer: "one.rewrite"},
}
testRewritesJSON, mErr := json.Marshal(testRewrites)
require.NoError(t, mErr)
testCases := []struct {
reqData any
name string
url string
method string
wantList []*rewriteJSON
wantBody string
wantConfMod bool
wantStatus int
}{{
name: "list",
url: listURL,
method: http.MethodGet,
reqData: nil,
wantConfMod: false,
wantStatus: http.StatusOK,
wantBody: string(testRewritesJSON) + "\n",
wantList: testRewrites,
}, {
name: "add",
url: addURL,
method: http.MethodPost,
reqData: rewriteJSON{Domain: "add.local", Answer: "add.rewrite"},
wantConfMod: true,
wantStatus: http.StatusOK,
wantBody: "",
wantList: append(
testRewrites,
&rewriteJSON{Domain: "add.local", Answer: "add.rewrite"},
),
}, {
name: "add_error",
url: addURL,
method: http.MethodPost,
reqData: "invalid_json",
wantConfMod: false,
wantStatus: http.StatusBadRequest,
wantBody: decodeErrorMsg,
wantList: testRewrites,
}, {
name: "delete",
url: deleteURL,
method: http.MethodPost,
reqData: rewriteJSON{Domain: "one.local", Answer: "one.rewrite"},
wantConfMod: true,
wantStatus: http.StatusOK,
wantBody: "",
wantList: []*rewriteJSON{{Domain: "example.local", Answer: "example.rewrite"}},
}, {
name: "delete_error",
url: deleteURL,
method: http.MethodPost,
reqData: "invalid_json",
wantConfMod: false,
wantStatus: http.StatusBadRequest,
wantBody: decodeErrorMsg,
wantList: testRewrites,
}, {
name: "update",
url: updateURL,
method: http.MethodPut,
reqData: rewriteUpdateJSON{
Target: rewriteJSON{Domain: "one.local", Answer: "one.rewrite"},
Update: rewriteJSON{Domain: "upd.local", Answer: "upd.rewrite"},
},
wantConfMod: true,
wantStatus: http.StatusOK,
wantBody: "",
wantList: []*rewriteJSON{
{Domain: "example.local", Answer: "example.rewrite"},
{Domain: "upd.local", Answer: "upd.rewrite"},
},
}, {
name: "update_error",
url: updateURL,
method: http.MethodPut,
reqData: "invalid_json",
wantConfMod: false,
wantStatus: http.StatusBadRequest,
wantBody: "json.Decode: json: cannot unmarshal string into Go value of type" +
" filtering.rewriteUpdateJSON\n",
wantList: testRewrites,
}, {
name: "update_error_target",
url: updateURL,
method: http.MethodPut,
reqData: rewriteUpdateJSON{
Target: rewriteJSON{Domain: "inv.local", Answer: "inv.rewrite"},
Update: rewriteJSON{Domain: "upd.local", Answer: "upd.rewrite"},
},
wantConfMod: false,
wantStatus: http.StatusBadRequest,
wantBody: "target rule not found\n",
wantList: testRewrites,
}}
for _, tc := range testCases {
onConfModified := func() {
if !tc.wantConfMod {
panic("config modified has been fired")
}
testutil.RequireSend(testutil.PanicT{}, confModCh, struct{}{}, testTimeout)
}
t.Run(tc.name, func(t *testing.T) {
handlers := make(map[string]http.Handler)
d, err := filtering.New(&filtering.Config{
ConfigModified: onConfModified,
HTTPRegister: func(_, url string, handler http.HandlerFunc) {
handlers[url] = handler
},
Rewrites: rewriteEntriesToLegacyRewrites(testRewrites),
}, nil)
require.NoError(t, err)
t.Cleanup(d.Close)
d.RegisterFilteringHandlers()
require.NotEmpty(t, handlers)
require.Contains(t, handlers, listURL)
require.Contains(t, handlers, tc.url)
var body io.Reader
if tc.reqData != nil {
data, rErr := json.Marshal(tc.reqData)
require.NoError(t, rErr)
body = bytes.NewReader(data)
}
r := httptest.NewRequest(tc.method, tc.url, body)
w := httptest.NewRecorder()
go func() {
handlers[tc.url].ServeHTTP(w, r)
testutil.RequireSend(testutil.PanicT{}, reqCh, struct{}{}, testTimeout)
}()
if tc.wantConfMod {
testutil.RequireReceive(t, confModCh, testTimeout)
}
testutil.RequireReceive(t, reqCh, testTimeout)
assert.Equal(t, tc.wantStatus, w.Code)
respBody, err := io.ReadAll(w.Body)
require.NoError(t, err)
assert.Equal(t, []byte(tc.wantBody), respBody)
assertRewritesList(t, handlers[listURL], tc.wantList)
})
}
}
// assertRewritesList checks if rewrites list equals the list received from the
// handler by listURL.
func assertRewritesList(t *testing.T, handler http.Handler, wantList []*rewriteJSON) {
t.Helper()
r := httptest.NewRequest(http.MethodGet, listURL, nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, r)
require.Equal(t, http.StatusOK, w.Code)
var actual []*rewriteJSON
err := json.NewDecoder(w.Body).Decode(&actual)
require.NoError(t, err)
assert.Equal(t, wantList, actual)
}
// rewriteEntriesToLegacyRewrites gets legacy rewrites from json entries.
func rewriteEntriesToLegacyRewrites(entries []*rewriteJSON) (rw []*filtering.LegacyRewrite) {
for _, entry := range entries {
rw = append(rw, &filtering.LegacyRewrite{
Domain: entry.Domain,
Answer: entry.Answer,
})
}
return rw
}

View File

@@ -161,12 +161,8 @@ func (ss *Default) resetEngine(
// type check
var _ filtering.SafeSearch = (*Default)(nil)
// CheckHost implements the [filtering.SafeSearch] interface for
// *DefaultSafeSearch.
func (ss *Default) CheckHost(
host string,
qtype rules.RRType,
) (res filtering.Result, err error) {
// CheckHost implements the [filtering.SafeSearch] interface for *Default.
func (ss *Default) CheckHost(host string, qtype rules.RRType) (res filtering.Result, err error) {
start := time.Now()
defer func() {
ss.log(log.DEBUG, "lookup for %q finished in %s", host, time.Since(start))
@@ -196,14 +192,10 @@ func (ss *Default) CheckHost(
return filtering.Result{}, err
}
if fltRes != nil {
res = *fltRes
ss.setCacheResult(host, qtype, res)
res = *fltRes
ss.setCacheResult(host, qtype, res)
return res, nil
}
return filtering.Result{}, fmt.Errorf("no ipv4 addresses for %q", host)
return res, nil
}
// searchHost looks up DNS rewrites in the internal DNS filtering engine.
@@ -229,7 +221,11 @@ func (ss *Default) searchHost(host string, qtype rules.RRType) (res *rules.DNSRe
}
// newResult creates Result object from rewrite rule. qtype must be either
// [dns.TypeA] or [dns.TypeAAAA].
// [dns.TypeA] or [dns.TypeAAAA]. If err is nil, res is never nil, so that the
// empty result is converted into a NODATA response.
//
// TODO(a.garipov): Use the main rewrite result mechanism used in
// [dnsforward.Server.filterDNSRequest].
func (ss *Default) newResult(
rewrite *rules.DNSRewrite,
qtype rules.RRType,
@@ -243,9 +239,10 @@ func (ss *Default) newResult(
}
if rewrite.RRType == qtype {
ip, ok := rewrite.Value.(net.IP)
v := rewrite.Value
ip, ok := v.(net.IP)
if !ok || ip == nil {
return nil, nil
return nil, fmt.Errorf("expected ip rewrite value, got %T(%[1]v)", v)
}
res.Rules[0].IP = ip
@@ -255,14 +252,14 @@ func (ss *Default) newResult(
host := rewrite.NewCNAME
if host == "" {
return nil, nil
return res, nil
}
ss.log(log.DEBUG, "resolving %q", host)
ips, err := ss.resolver.LookupIP(context.Background(), qtypeToProto(qtype), host)
if err != nil {
return nil, err
return nil, fmt.Errorf("resolving cname: %w", err)
}
ss.log(log.DEBUG, "resolved %s", ips)
@@ -276,11 +273,9 @@ func (ss *Default) newResult(
}
res.Rules[0].IP = ip
return res, nil
}
return nil, nil
return res, nil
}
// qtypeToProto returns "ip4" for [dns.TypeA] and "ip6" for [dns.TypeAAAA].

View File

@@ -1,6 +1,7 @@
package safesearch_test
import (
"context"
"net"
"testing"
"time"
@@ -71,6 +72,25 @@ func TestDefault_CheckHost_yandex(t *testing.T) {
}
}
func TestDefault_CheckHost_yandexAAAA(t *testing.T) {
conf := testConf
ss, err := safesearch.NewDefault(conf, "", testCacheSize, testCacheTTL)
require.NoError(t, err)
res, err := ss.CheckHost("www.yandex.ru", dns.TypeAAAA)
require.NoError(t, err)
assert.True(t, res.IsFiltered)
// TODO(a.garipov): Currently, the safe-search filter returns a single rule
// with a nil IP address. This isn't really necessary and should be changed
// once the TODO in [safesearch.Default.newResult] is resolved.
require.Len(t, res.Rules, 1)
assert.Nil(t, res.Rules[0].IP)
assert.EqualValues(t, filtering.SafeSearchListID, res.Rules[0].FilterListID)
}
func TestDefault_CheckHost_google(t *testing.T) {
resolver := &aghtest.TestResolver{}
ip, _ := resolver.HostToIPs("forcesafesearch.google.com")
@@ -105,6 +125,56 @@ func TestDefault_CheckHost_google(t *testing.T) {
}
}
// testResolver is a [filtering.Resolver] for tests.
//
// TODO(a.garipov): Move to aghtest and use everywhere.
type testResolver struct {
OnLookupIP func(ctx context.Context, network, host string) (ips []net.IP, err error)
}
// type check
var _ filtering.Resolver = (*testResolver)(nil)
// LookupIP implements the [filtering.Resolver] interface for *testResolver.
func (r *testResolver) LookupIP(
ctx context.Context,
network string,
host string,
) (ips []net.IP, err error) {
return r.OnLookupIP(ctx, network, host)
}
func TestDefault_CheckHost_duckduckgoAAAA(t *testing.T) {
conf := testConf
conf.CustomResolver = &testResolver{
OnLookupIP: func(_ context.Context, network, host string) (ips []net.IP, err error) {
assert.Equal(t, "ip6", network)
assert.Equal(t, "safe.duckduckgo.com", host)
return nil, nil
},
}
ss, err := safesearch.NewDefault(conf, "", testCacheSize, testCacheTTL)
require.NoError(t, err)
// The DuckDuckGo safe-search addresses are resolved through CNAMEs, but
// DuckDuckGo doesn't have a safe-search IPv6 address. The result should be
// the same as the one for Yandex IPv6. That is, a NODATA response.
res, err := ss.CheckHost("www.duckduckgo.com", dns.TypeAAAA)
require.NoError(t, err)
assert.True(t, res.IsFiltered)
// TODO(a.garipov): Currently, the safe-search filter returns a single rule
// with a nil IP address. This isn't really necessary and should be changed
// once the TODO in [safesearch.Default.newResult] is resolved.
require.Len(t, res.Rules, 1)
assert.Nil(t, res.Rules[0].IP)
assert.EqualValues(t, filtering.SafeSearchListID, res.Rules[0].FilterListID)
}
func TestDefault_Update(t *testing.T) {
conf := testConf
ss, err := safesearch.NewDefault(conf, "", testCacheSize, testCacheTTL)

View File

@@ -27,6 +27,25 @@ var blockedServices = []blockedService{{
"||9cache.com^",
"||9gag.com^",
},
}, {
ID: "activision_blizzard",
Name: "Activision Blizzard",
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" viewBox=\"-237 0 1572 1572\"><path d=\"m549.1.2 548.4 1571.4H798l-74.2-200H374.5l-74.3 200H.7zM626 1085.1l-83-274.3-82.9 274.3z\"/></svg>"),
Rules: []string{
"||activision.com^",
"||activisionblizzard.com^",
"||demonware.net^",
},
}, {
ID: "aliexpress",
Name: "AliExpress",
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" fill=\"currentColor\" viewBox=\"0 0 50 50\"><path d=\"M9 4C6.25 4 4 6.25 4 9v32c0 2.75 2.25 5 5 5h32c2.75 0 5-2.25 5-5V9c0-2.75-2.25-5-5-5H9zm0 2h32c1.668 0 3 1.332 3 3v3.38A3.973 3.973 0 0 0 41 11H9a3.973 3.973 0 0 0-3 1.38V9c0-1.668 1.332-3 3-3zm6 11a1 1 0 0 1 1 1c0 4.962 4.037 9 9 9s9-4.038 9-9a1 1 0 1 1 2 0c0 6.065-4.935 11-11 11s-11-4.935-11-11a1 1 0 0 1 1-1z\"/></svg>"),
Rules: []string{
"||ae-rus.net^",
"||ae-rus.ru^",
"||aliexpress.com^",
"||aliexpress.ru^",
},
}, {
ID: "amazon",
Name: "Amazon",
@@ -234,6 +253,16 @@ var blockedServices = []blockedService{{
"||z.cn^",
"||zappos^",
},
}, {
ID: "battle_net",
Name: "Battle.net",
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" fill=\"currentColor\" viewBox=\"0 0 50 50\"><path d=\"M43.11 22.15s3.95.2 3.95-2.12c0-3.03-5.26-5.77-5.26-5.77s.83-1.74 1.34-2.72a37.3 37.3 0 0 0 2.09-5.65c.16-1.1-.09-1.44-.09-1.44-.35 2.34-4.17 9.09-4.47 9.32-3.72-1.75-8.83-2.23-8.83-2.23S26.84 1 22.13 1c-4.67 0-4.65 9.02-4.65 9.02s-1.32-2.56-2.97-2.56c-2.42 0-3.22 3.67-3.22 7.64a37.8 37.8 0 0 0-9.16 1.17c-.36.1-1.49.92-.97.82 1.04-.34 5.95-1.1 10.25-.72.24 3.77 2.44 8.68 2.44 8.68S9.13 31.9 9.13 36.78c0 1.29.56 3.64 3.95 3.64 2.84 0 6.03-1.7 6.63-2.06a6.33 6.33 0 0 0-.91 2.83c0 .54.31 2.06 2.5 2.06 2.82 0 5.96-2.16 5.96-2.16s2.96 4.93 5.5 7.2c.69.6 1.34.71 1.34.71s-2.52-2.43-5.84-8.68c3.08-1.9 6.3-6.4 6.3-6.4l3.3.01c4.6 0 11.11-.96 11.11-4.61 0-3.77-5.86-7.17-5.86-7.17Zm.52-2.26c0 1.33-1.27 1.3-1.27 1.3l-.97.08s-1.82-.97-2.93-1.41c0 0 1.72-2.65 2.12-3.4.3.18 3.05 1.9 3.05 3.43ZM24.43 6.3c2.15 0 5.23 5.1 5.23 5.1s-4.8-.44-8.76 1.89c.1-3.67 1.34-7 3.52-7Zm-8.56 4.13c.69 0 1.36.83 1.64 1.54 0 .47.24 3.2.24 3.2l-3.96-.16c0-3.57 1.4-4.58 2.08-4.58Zm-.4 24.8c-2.17 0-2.62-1.2-2.62-2.29 0-2.45 1.96-5.9 1.96-5.9s2.2 4.63 6.04 6.59a10.02 10.02 0 0 1-5.39 1.6Zm7.02 4.85c-1.52 0-1.7-.98-1.7-1.21 0-.7.55-1.54.55-1.54s2.55-1.73 2.71-1.91l1.89 3.52s-1.93 1.14-3.45 1.14Zm4.74-1.92c-.93-1.62-1.6-3.3-1.6-3.3s3.78.24 5.82-1.86a11.2 11.2 0 0 1-5.65 1.07c4.93-4.34 7.8-7.48 10.23-10.74a9.46 9.46 0 0 0-1.6-1.15c-1.46 1.76-7.16 7.86-12.45 10.88-6.69-3.64-8.09-14.38-8.23-16.6l3.65.34s-1.37 2.44-1.37 4.23c0 1.79.21 1.89.21 1.89s-.04-3.13 1.89-5.54c1.46 7.82 3 11.83 4.19 14.22.6-.25 1.74-.76 1.74-.76s-3.38-9.73-3.19-16.31a13.8 13.8 0 0 1 6.36-1.66c6.73 0 12.14 2.9 12.14 2.9l-2.12 2.95s-1.89-3.42-4.55-4.03c1.4 1.05 2.98 2.44 3.8 4.43a68.4 68.4 0 0 0-14.47-3.59c-.19.8-.17 1.94-.17 1.94s9.03 1.66 15.6 5.43c-.05 8.21-9 14.53-10.23 15.26Zm8.55-6.14s2.8-3.68 2.76-8.55c0 0 4.52 2.8 4.52 5.54 0 3.05-7.28 3-7.28 3Z\"/></svg>"),
Rules: []string{
"||battle.net^",
"||battlenet.com.cn^",
"||bnet.163.com^",
"||bnet.cn^",
},
}, {
ID: "bilibili",
Name: "Bilibili",
@@ -283,6 +312,21 @@ var blockedServices = []blockedService{{
"||mincdn.com^",
"||yo9.com^",
},
}, {
ID: "blizzard_entertainment",
Name: "Blizzard Entertainment",
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" fill=\"currentColor\" viewBox=\"0 -32 128 128\"><path fill-rule=\"evenodd\" d=\"M105 2h3v1h2l2 1 1 1h3l1 1h4l1 1 2 2v1l1 3v4l1 2v6l-1 2v2l-1 3v2l-1 2v14l-1 2v1l-1 3-1 1h-3l-1 1h-6a5 5 0 0 0 1-6l2-1h-1l-1-3v-3a350 350 0 0 1 0-8l-1-3v-1l-1-1V9h1V6l-1-1-4-3Zm9 13v10h1v25a8 8 0 0 0 2-4l1-1 1-3V30l1-1v-2l1-1v-5l-1-2-2-3-1-1h-3Z\" clip-rule=\"evenodd\"/><path fill-rule=\"evenodd\" d=\"M101 24v1l2 1h1v2h1l1 2v5l1 2s0-1 0 0l1 7 1 2v7l-1 5h-2l-2-2-4-1 1-3 1-2a22 22 0 0 0-1-10l-1-4h-1l1-4-1-1-2-3v2l-1 1v3l1 6v4l1 1-1 3v4l1 1-1 2v4l-1-1a13 13 0 0 0-4-5l-2-2 2-5V27l-1-1v-4l-1-1v-5h-1a33 33 0 0 1 0-4l4-4h-2l-4-4h-1V3h10l2 1 2 1h1c2 0 2 1 3 2l2 3 1 3v1l-1 2v1a11 11 0 0 1-1 4l-4 3ZM96 9v13l1 1a3 3 0 0 0 1-1c1 0 2-1 2-3v-1l1-1v-3l-2-3-2-2h-1ZM26 3l1 1h1l2 3v5l1 1v2l-1 1v9l1 1 1 1-1 7v9l-1 1 1 1-1 1v8h3l1-1h7v-1h16v6l1 2h-6l-1-1h-2l-1-1H31a4 4 0 0 0-3-1l-1 1h-1l-1 1h-5l1-1a10 10 0 0 0 3-2v-9l1-1-1-1V35l1-1V21l-1-1v-4l1-1v-3l1-2-1-3h-1l-2-2-1-1 1-1h4Z\" clip-rule=\"evenodd\"/><path fill-rule=\"evenodd\" d=\"M84 60v-3l-1-2v-4l-3-2v-1l1-2a11 11 0 0 0 2-6l-1-1-3-2h-2v3l1 1h1l-1 2h-4l-2 1-2 1-1-2v-1l1-1 1-1 1-2v-5l1-1v-6l1-1v-3l1-1v-3l1-2 1-1-1-1 1-1 1-3 1-1V7l1-1c1-1 0-4 2-3l1 3 1 1 1 2v1l1 5 1 3v2l1 1v2l1 1v8l1 3v9l-1 1-2 5v3l-1 2v4l-1 1h-1Zm-4-36-1 1v2l-1 2v4l4 1h2v-7l-1-3-2-1-1-1v2Z\" clip-rule=\"evenodd\"/><path fill-rule=\"evenodd\" d=\"M77 4v1l-2 3v2l-1 2v1l-1 1-1 4v7h-1v2a5 5 0 0 1-1 2v2l-2 2v7l-1 2v2l-2 4v3-1h3v-1l3-1 1-1 3-2h3l1 1-1 1a3 3 0 0 0 0 1l1 1v5l-1 1h-7v-1h-2l-2 1h-4l-2 1-1-2v-2l1-1v-1l1-1-1-1 1-1v-2l1-2-1-2v-8l1-2 2-5-1-1 1-2v-1l1-1v-4l1-1 2-4v-2l1-1h-3V8h-1l-1 1-2 3-1 4h-1l-1-1v-2l1-1V4h16ZM32 4h9l1 2-3 2 1 2-1 1v13l1 2-1 2v6l-1 1v2l1 1v5l-1 1 1 2 1 1 2 1v2h-7l-2 1-1-1 3-2v-8a4 4 0 0 1 0-2l1-1v-3l-1-14v-2l1-1h-1V7l-2-1h-1l-1-1 1-1Zm12 0h14v15c-2 1-2 4-3 6v2c-1 0-3 1-2 4h-1l-2 3-1 2v3l-1 2-2 5h2l1-1h2l1-1c1-1 1-3 3-3l1-2 2-2h1l1 3h-1v1l-1 1v7h-8l-1 1-2-1h-3l-1-1 1-1v-3l-1-2 1-1-1-1 1-3v-2l1-2 1-3a7 7 0 0 1 2-4l1-4 2-2 2-3v-3h1l2-3V8l-3-1h-2l-1 1a3 3 0 0 0-2 3l-1 1v4l-1 1-1 1v-1l-1-1V4ZM17 22l1 1h1v3s0-1 0 0l2 1v5l1 2-1 8v3a6 6 0 0 1 0 2l-1 2-1 2-1 3-3 2-2 2-3 1-1-1-1 1H1l-1-1 2-1 1-4V26l1-1-1-3V11l1-1-1-1H2V8L1 7 0 6V5l1-1h15l1 1c2 0 3 1 3 2l1 3v6l-4 6Zm-6-11v9h1l1-2 2-1v-6h-1l-1-1h-2v1Zm0 19-1 1 1 2-1 3v9a2 2 0 0 0 0 1v6l-1 1 3-1 1-2h1v-4l1-4v-5l-1-1 1-3-1-1v-2s0 1 0 0v-1l-1-1a20 20 0 0 1-2-2v4Z\" clip-rule=\"evenodd\"/></svg>"),
Rules: []string{
"||battle.net^",
"||battlenet.com.cn^",
"||blizzard.cn^",
"||blizzardgames.cn^",
"||blz-contentstack.com^",
"||blzstatic.cn^",
"||bnet.163.com^",
"||bnet.cn^",
"||lizzard.com^",
},
}, {
ID: "cloudflare",
Name: "CloudFlare",
@@ -319,6 +363,14 @@ var blockedServices = []blockedService{{
"||warp.plus^",
"||workers.dev^",
},
}, {
ID: "clubhouse",
Name: "Clubhouse",
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" fill=\"currentColor\" viewBox=\"0 0 50 50\"><path d=\"M29.8 4a1 1 0 0 0-.92.7 1 1 0 0 0 .36 1.1 31.2 31.2 0 0 1 6 6.02 1 1 0 1 0 1.6-1.2 33.2 33.2 0 0 0-6.4-6.4A1 1 0 0 0 29.8 4Zm-7.16 1.06c-.46 0-.87.3-.99.74a1 1 0 0 0 .5 1.15 31.13 31.13 0 0 1 11.13 10.6 1 1 0 1 0 1.7-1.07A33.12 33.12 0 0 0 23.11 5.2a.96.96 0 0 0-.48-.14ZM14.5 7.01a3.42 3.42 0 0 0-3.27 2.28l-.26-.27A3.49 3.49 0 0 0 8.5 8.01c-.9 0-1.8.34-2.48 1.01a3.51 3.51 0 0 0-.57 4.17c-.52.15-1.01.42-1.43.84a3.52 3.52 0 0 0 0 4.94l.27.27c-.46.16-.9.41-1.27.79a3.52 3.52 0 0 0 0 4.94l.88.88 16.47 16.47a9.01 9.01 0 0 0 12.72 0l4.23-4.22a9.94 9.94 0 0 0 2.3-3.59l2.63-7.08a8.03 8.03 0 0 1 1.84-2.87l1.74-1.73 1-1a4.02 4.02 0 0 0 0-5.66 4.02 4.02 0 0 0-5.66 0l-1 1-.7.71-4.2 4.2a2.98 2.98 0 0 1-4.24 0L17.9 8.96l-.94-.94a3.49 3.49 0 0 0-2.47-1.01Zm0 1.98c.38 0 .76.15 1.06.45l.94.94 13.1 13.1a5.02 5.02 0 0 0 7.08 0l4.2-4.18.7-.71 1-1c.8-.8 2.05-.8 2.83 0 .8.79.8 2.04 0 2.83l-2.73 2.73a10.03 10.03 0 0 0-2.3 3.58l-2.63 7.08a8.02 8.02 0 0 1-1.84 2.87l-4.23 4.23a6.99 6.99 0 0 1-9.9 0L4.44 23.56a1.5 1.5 0 0 1 0-2.12c.59-.59 1.45-.55 2.08.08l.1.09 8.2 8.37a1 1 0 0 0 .97.29 1 1 0 0 0 .46-1.68l-9.52-9.73-.01-.01-1.28-1.29a1.5 1.5 0 0 1 0-2.12c.6-.6 1.47-.58 2.08.03l9.18 9.17a1 1 0 0 0 1.69-.43 1 1 0 0 0-.28-.98L9 14.13l-.06-.07-1.5-1.5c-.6-.6-.6-1.53 0-2.12a1.5 1.5 0 0 1 2.12 0L20.8 21.67a1 1 0 0 0 1.68-.44 1 1 0 0 0-.27-.97l-8.7-8.7-.06-.06a1.4 1.4 0 0 1-.01-2.06c.3-.3.68-.45 1.06-.45ZM4.23 32a1 1 0 0 0-.82 1.51c3 5.18 7.36 9.46 12.59 12.37a1 1 0 0 0 1.51-.89 1 1 0 0 0-.54-.86A31.16 31.16 0 0 1 5.15 32.5a1.01 1.01 0 0 0-.92-.51Z\"/></svg>"),
Rules: []string{
"||clubhouse.com^",
"||clubhouseapi.com^",
},
}, {
ID: "crunchyroll",
Name: "Crunchyroll",
@@ -726,6 +778,18 @@ var blockedServices = []blockedService{{
"||xxbay.com^",
"||yibei.org^",
},
}, {
ID: "electronic_arts",
Name: "Electronic Arts",
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" fill=\"currentColor\" viewBox=\"0 0 1000 1000\"><path d=\"M500 1000C224.3 1000 0 775.7 0 500S224.3 0 500 0s500 224.3 500 500-224.3 500-500 500zm84.63-693.4H302.05l-42.87 68.9h282.25zm57.75.66L469.63 582.33H278.02l44.2-68.96h114.85l43.87-68.93h-265.5l-43.86 68.93h62.9L147.2 651.05h364.2L645.9 438.9l49.05 74.46h-44.23l-41.88 68.96H739.8l45.48 68.72h83.54z\"/></svg>"),
Rules: []string{
"||ea.com^",
"||eamobile.com^",
"||easports.com^",
"||nearpolar.com^",
"||swtor.com^",
"||tnt-ea.com^",
},
}, {
ID: "epic_games",
Name: "Epic Games",
@@ -1390,11 +1454,39 @@ var blockedServices = []blockedService{{
"||line-apps.com^",
"||line-cdn.net^",
"||line-scdn.net^",
"||line.biz^",
"||line.me^",
"||line.naver.jp^",
"||linecorp.com^",
"||linefriends.com.tw^",
"||linefriends.com^",
"||linegame.jp^",
"||linemobile.com^",
"||linemyshop.com^",
"||lineshoppingseller.com^",
"||linetv.tw^",
},
}, {
ID: "linkedin",
Name: "LinkedIn",
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" fill=\"currentColor\" viewBox=\"0 0 50 50\"><path d=\"M41,4H9C6.24,4,4,6.24,4,9v32c0,2.76,2.24,5,5,5h32c2.76,0,5-2.24,5-5V9C46,6.24,43.76,4,41,4z M17,20v19h-6V20H17z M11,14.47c0-1.4,1.2-2.47,3-2.47s2.93,1.07,3,2.47c0,1.4-1.12,2.53-3,2.53C12.2,17,11,15.87,11,14.47z M39,39h-6c0,0,0-9.26,0-10 c0-2-1-4-3.5-4.04h-0.08C27,24.96,26,27.02,26,29c0,0.91,0,10,0,10h-6V20h6v2.56c0,0,1.93-2.56,5.81-2.56 c3.97,0,7.19,2.73,7.19,8.26V39z\"/></svg>"),
Rules: []string{
"||bizographics.com^",
"||cs1404.wpc.epsiloncdn.net^",
"||cs767.wpc.epsiloncdn.net^",
"||l-0005.dc-msedge.net^",
"||l-0005.l-dc-msedge.net^",
"||l-0005.l-msedge.net^",
"||l-0015.l-msedge.net^",
"||licdn.cn^",
"||licdn.com^",
"||linkedin.at^",
"||linkedin.be^",
"||linkedin.cn^",
"||linkedin.com^",
"||linkedin.nl^",
"||linkedin.qtlcdn.com^",
"||lnkd.in^",
},
}, {
ID: "mail_ru",
@@ -1438,7 +1530,6 @@ var blockedServices = []blockedService{{
"||masto.pt^",
"||mastodon.au^",
"||mastodon.bida.im^",
"||mastodon.com.tr^",
"||mastodon.eus^",
"||mastodon.green^",
"||mastodon.ie^",
@@ -1454,7 +1545,7 @@ var blockedServices = []blockedService{{
"||mastodon.social^",
"||mastodon.uno^",
"||mastodon.world^",
"||mastodon.xyz^",
"||mastodon.zaclys.com^",
"||mastodonapp.uk^",
"||mastodonners.nl^",
"||mastodont.cat^",
@@ -1465,12 +1556,12 @@ var blockedServices = []blockedService{{
"||metalhead.club^",
"||mindly.social^",
"||mstdn.ca^",
"||mstdn.jp^",
"||mstdn.party^",
"||mstdn.plus^",
"||mstdn.social^",
"||muenchen.social^",
"||newsie.social^",
"||muenster.im^",
"||nerdculture.de^",
"||noc.social^",
"||norden.social^",
"||nrw.social^",
@@ -1498,16 +1589,17 @@ var blockedServices = []blockedService{{
"||techhub.social^",
"||theblower.au^",
"||tkz.one^",
"||todon.eu^",
"||toot.aquilenet.fr^",
"||toot.community^",
"||toot.funami.tech^",
"||toot.io^",
"||toot.wales^",
"||troet.cafe^",
"||twingyeo.kr^",
"||union.place^",
"||universeodon.com^",
"||urbanists.social^",
"||wien.rocks^",
"||wxw.moe^",
},
}, {
@@ -1550,6 +1642,44 @@ var blockedServices = []blockedService{{
"||nflxso.net^",
"||nflxvideo.net^",
},
}, {
ID: "nintendo",
Name: "Nintendo",
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" fill=\"currentColor\" viewBox=\"0 0 50 50\"><path d=\"M6 7v36h12.6V21.75l13 20.78.27.47H44V7H31.4v1l.04 20.22L18.5 7.47 18.22 7Zm2 2h9.1l14.5 23.22 1.84 3v-3.5L33.4 9H42v32h-9L18.44 17.75l-1.85-2.94V41H8Z\"/></svg>"),
Rules: []string{
"||nintendo-europe.com^",
"||nintendo.be^",
"||nintendo.co.jp^",
"||nintendo.co.uk^",
"||nintendo.com.au^",
"||nintendo.com^",
"||nintendo.de^",
"||nintendo.es^",
"||nintendo.eu^",
"||nintendo.fr^",
"||nintendo.it^",
"||nintendo.jp^",
"||nintendo.net^",
"||nintendo.nl^",
"||nintendoswitch.cn^",
"||nintendowifi.net^",
},
}, {
ID: "nvidia",
Name: "Nvidia",
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" fill=\"currentColor\" viewBox=\"0 0 48 48\"><path d=\"M20 8a2 2 0 0 0-2 2v2.55l.84-.05c10.76-.37 17.78 8.82 17.78 8.82s-8.05 9.8-16.44 9.8c-.73 0-1.47-.07-2.18-.19v-2.2c.73.23 1.52.35 2.3.35 5.88 0 11.35-7.6 11.35-7.6s-5.07-6.91-12.81-6.66l-.82.03v-2.3c-9.49.77-17.68 8.8-17.68 8.8S4.97 34.76 18 35.98v-2.44c.59.07 1.22.12 1.81.12 7.82 0 13.47-3.99 18.94-8.7.91.73 4.62 2.49 5.4 3.26-5.2 4.36-17.33 7.86-24.2 7.86-.66 0-1.32-.03-1.95-.1V38c0 1.1.9 2 2 2h25a2 2 0 0 0 2-2V10a2 2 0 0 0-2-2H20zm-2 6.86v2.82a11.8 11.8 0 0 1 1.57-.07c4.95 0 7.9 3.85 7.9 3.85l-4.03 3.39c-1.8-3.02-2.43-4.35-5.44-4.7v8.57c-4.06-1.38-5.4-6.14-5.4-6.14s2.37-2.83 5.38-2.46H18v-2.44a15.66 15.66 0 0 0-9.22 4.46s2 7.52 9.22 8.8v2.6c-9.56-1.17-12.82-11.7-12.82-11.7s4.27-6.3 12.82-6.97z\"/></svg>"),
Rules: []string{
"||geforce.com^",
"||geforcenow.com^",
"||nvidia.cn^",
"||nvidia.com.global.ogslb.com^",
"||nvidia.com^",
"||nvidia.eu^",
"||nvidia.partners^",
"||nvidiagrid.net^",
"||nvidianews.com^",
"||tegrazone.com^",
},
}, {
ID: "ok",
Name: "OK.ru",
@@ -1704,6 +1834,14 @@ var blockedServices = []blockedService{{
"||robloxcdn.com^",
"||robloxdev.cn^",
},
}, {
ID: "rockstar_games",
Name: "Rockstar games",
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" fill=\"currentColor\" viewBox=\"0 0 50 50\"><path d=\"M12 3c-4.96 0-9 4.04-9 9v26c0 4.96 4.04 9 9 9h26c4.96 0 9-4.04 9-9V12c0-4.96-4.04-9-9-9H12zm0 2h26c3.88 0 7 3.12 7 7v26c0 3.88-3.12 7-7 7H12c-3.88 0-7-3.12-7-7V12c0-3.88 3.12-7 7-7zm3.72 5a1 1 0 0 0-.97.79l-3.87 18a1 1 0 0 0 .98 1.21h4.27a1 1 0 0 0 .97-.79L18.47 23h2.07c.94 0 1.12.15 1.36.73.24.57.3 1.76.1 3.4-.08.68-.05 1.22.02 1.6v.03a1 1 0 0 0 .3.97l3.37 3.12-2.6 5.74a1 1 0 0 0 1.43 1.26l5.58-3.39 4.29 3.33a1 1 0 0 0 1.6-.98l-1.09-5.56 4.7-3.47a1 1 0 0 0-.6-1.8h-4.86l-.82-5.14a1 1 0 0 0-.98-.84 1 1 0 0 0-.88.51l-2.77 5a14.3 14.3 0 0 1 .06-2.83c.15-1.48.01-2.64-.18-3.45-.06-.28-.08-.25-.15-.45.3-.17.4-.13.77-.5.8-.8 1.6-2.18 1.75-4.26.17-2.26-.55-3.98-1.92-4.9C27.65 10.17 25.91 10 24 10h-8.28zm.81 2H24c1.75 0 3.13.25 3.9.77.76.52 1.18 1.27 1.05 3.1-.13 1.67-.69 2.51-1.17 3a2 2 0 0 1-.82.56 1 1 0 0 0-.6 1.44s.12.21.27.82c.14.6.26 1.53.13 2.79a14.24 14.24 0 0 0-.01 3.52h-2.76c-.01-.19-.04-.32 0-.62.22-1.78.25-3.21-.24-4.42A3.38 3.38 0 0 0 20.54 21h-2.87a1 1 0 0 0-.98.78L15.32 28H13.1l3.44-16zm2.76 1.03a1 1 0 0 0-.98.8l-.98 4.94a1 1 0 0 0 .98 1.2h4.47c.79 0 1.65-.12 2.44-.58a3.6 3.6 0 0 0 1.68-2.41 3.3 3.3 0 0 0-.72-2.92 3.35 3.35 0 0 0-2.47-1.03h-4.42zm.82 2h3.6c.41 0 .79.16 1 .4.22.22.36.52.23 1.15-.13.62-.36.88-.72 1.08a3 3 0 0 1-1.44.3h-3.25l.58-2.93zm11.7 10.99.49 3.11a1 1 0 0 0 .98.84h2.69l-2.76 2.05a1 1 0 0 0-.4 1l.7 3.56-2.73-2.12a1 1 0 0 0-1.13-.07l-3.4 2.07 1.56-3.44a1 1 0 0 0-.23-1.15L25.55 30H29a1 1 0 0 0 .88-.51l1.92-3.47z\"/></svg>"),
Rules: []string{
"||rockstargames.com^",
"||rsg.sc^",
},
}, {
ID: "shopee",
Name: "Shopee",
@@ -1940,6 +2078,16 @@ var blockedServices = []blockedService{{
"||twvid.com^",
"||vine.co^",
},
}, {
ID: "ubisoft",
Name: "Ubisoft",
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" fill=\"currentColor\" viewBox=\"0 0 32 32\"><path d=\"M15.22 3C7.14 3 3.66 10.18 3.66 10.18l1.03.74s-1.3 2.45-1.26 5.6A12.5 12.5 0 0 0 16.08 29a12.5 12.5 0 0 0 12.49-12.46c0-9-6.98-13.54-13.35-13.54zm.07 2.2c6.3 0 11.2 5.07 11.2 10.98 0 6.27-4.71 10.62-10.2 10.62-4.04 0-7.69-3.08-7.69-7.3a5.8 5.8 0 0 1 2.75-5.03l.21.23a6.37 6.37 0 0 0-1.53 3.91c0 3.32 2.6 5.62 5.88 5.62 4.18 0 6.97-3.56 6.97-7.7 0-4.81-4.25-8.9-9.36-8.9a11.1 11.1 0 0 0-6.61 2.3l-.21-.2a10.07 10.07 0 0 1 8.59-4.54zM13.4 9.8c3.26 0 6.44 2.15 7.24 5.22l-.3.1a8.35 8.35 0 0 0-6.52-3.44c-5.08 0-7.75 4.62-7.36 8.47l-.3.12s-.56-1.24-.56-2.71a7.8 7.8 0 0 1 7.8-7.76zm2.15 5.33a2.77 2.77 0 0 1 2.78 2.74c0 1.23-.79 1.96-.79 1.96l.94.65s-.93 1.46-2.82 1.46a3.4 3.4 0 0 1-.1-6.8z\"/></svg>"),
Rules: []string{
"||ubi.com^",
"||ubisoft.com^",
"||ubisoft.org^",
"||ubisoftconnect.com^",
},
}, {
ID: "valorant",
Name: "Valorant",

View File

@@ -7,6 +7,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
"github.com/AdguardTeam/AdGuardHome/internal/whois"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/stringutil"
)
@@ -22,12 +23,14 @@ type Client struct {
safeSearchConf filtering.SafeSearchConfig
SafeSearch filtering.SafeSearch
// BlockedServices is the configuration of blocked services of a client.
BlockedServices *filtering.BlockedServices
Name string
IDs []string
Tags []string
BlockedServices []string
Upstreams []string
IDs []string
Tags []string
Upstreams []string
UseOwnSettings bool
FilteringEnabled bool
@@ -43,9 +46,9 @@ type Client struct {
func (c *Client) ShallowClone() (sh *Client) {
clone := *c
clone.BlockedServices = c.BlockedServices.Clone()
clone.IDs = stringutil.CloneSlice(c.IDs)
clone.Tags = stringutil.CloneSlice(c.Tags)
clone.BlockedServices = stringutil.CloneSlice(c.BlockedServices)
clone.Upstreams = stringutil.CloneSlice(c.Upstreams)
return &clone
@@ -127,14 +130,13 @@ func (cs clientSource) MarshalText() (text []byte, err error) {
// RuntimeClient is a client information about which has been obtained using the
// source described in the Source field.
type RuntimeClient struct {
WHOISInfo *RuntimeClientWHOISInfo
Host string
Source clientSource
}
// WHOIS is the filtered WHOIS data of a client.
WHOIS *whois.Info
// RuntimeClientWHOISInfo is the filtered WHOIS data for a runtime client.
type RuntimeClientWHOISInfo struct {
City string `json:"city,omitempty"`
Country string `json:"country,omitempty"`
Orgname string `json:"orgname,omitempty"`
// Host is the host name of a client.
Host string
// Source is the source from which the information about the client has
// been obtained.
Source clientSource
}

View File

@@ -11,9 +11,11 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
"github.com/AdguardTeam/AdGuardHome/internal/whois"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
@@ -23,6 +25,23 @@ import (
"golang.org/x/exp/slices"
)
// DHCP is an interface for accessing DHCP lease data the [clientsContainer]
// needs.
type DHCP interface {
// Leases returns all the DHCP leases.
Leases() (leases []*dhcpsvc.Lease)
// HostByIP returns the hostname of the DHCP client with the given IP
// address. The address will be netip.Addr{} if there is no such client,
// due to an assumption that a DHCP client must always have an IP address.
HostByIP(ip netip.Addr) (host string)
// MACByIP returns the MAC address for the given IP address leased. It
// returns nil if there is no such client, due to an assumption that a DHCP
// client must always have a MAC address.
MACByIP(ip netip.Addr) (mac net.HardwareAddr)
}
// clientsContainer is the storage of all runtime and persistent clients.
type clientsContainer struct {
// TODO(a.garipov): Perhaps use a number of separate indices for different
@@ -77,7 +96,7 @@ func (clients *clientsContainer) Init(
etcHosts *aghnet.HostsContainer,
arpdb aghnet.ARPDB,
filteringConf *filtering.Config,
) {
) (err error) {
if clients.list != nil {
log.Fatal("clients.list != nil")
}
@@ -91,23 +110,29 @@ func (clients *clientsContainer) Init(
clients.dhcpServer = dhcpServer
clients.etcHosts = etcHosts
clients.arpdb = arpdb
clients.addFromConfig(objects, filteringConf)
err = clients.addFromConfig(objects, filteringConf)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
}
clients.safeSearchCacheSize = filteringConf.SafeSearchCacheSize
clients.safeSearchCacheTTL = time.Minute * time.Duration(filteringConf.CacheTime)
if clients.testing {
return
return nil
}
clients.updateFromDHCP(true)
if clients.dhcpServer != nil {
clients.dhcpServer.SetOnLeaseChanged(clients.onDHCPLeaseChanged)
clients.onDHCPLeaseChanged(dhcpd.LeaseChangedAdded)
}
if clients.etcHosts != nil {
go clients.handleHostsUpdates()
}
return nil
}
func (clients *clientsContainer) handleHostsUpdates() {
@@ -147,12 +172,14 @@ func (clients *clientsContainer) reloadARP() {
type clientObject struct {
SafeSearchConf filtering.SafeSearchConfig `yaml:"safe_search"`
// BlockedServices is the configuration of blocked services of a client.
BlockedServices *filtering.BlockedServices `yaml:"blocked_services"`
Name string `yaml:"name"`
Tags []string `yaml:"tags"`
IDs []string `yaml:"ids"`
BlockedServices []string `yaml:"blocked_services"`
Upstreams []string `yaml:"upstreams"`
IDs []string `yaml:"ids"`
Tags []string `yaml:"tags"`
Upstreams []string `yaml:"upstreams"`
UseGlobalSettings bool `yaml:"use_global_settings"`
FilteringEnabled bool `yaml:"filtering_enabled"`
@@ -166,7 +193,10 @@ type clientObject struct {
// addFromConfig initializes the clients container with objects from the
// configuration file.
func (clients *clientsContainer) addFromConfig(objects []*clientObject, filteringConf *filtering.Config) {
func (clients *clientsContainer) addFromConfig(
objects []*clientObject,
filteringConf *filtering.Config,
) (err error) {
for _, o := range objects {
cli := &Client{
Name: o.Name,
@@ -187,7 +217,7 @@ func (clients *clientsContainer) addFromConfig(objects []*clientObject, filterin
if o.SafeSearchConf.Enabled {
o.SafeSearchConf.CustomResolver = safeSearchResolver{}
err := cli.setSafeSearch(
err = cli.setSafeSearch(
o.SafeSearchConf,
filteringConf.SafeSearchCacheSize,
time.Minute*time.Duration(filteringConf.CacheTime),
@@ -199,14 +229,13 @@ func (clients *clientsContainer) addFromConfig(objects []*clientObject, filterin
}
}
for _, s := range o.BlockedServices {
if filtering.BlockedSvcKnown(s) {
cli.BlockedServices = append(cli.BlockedServices, s)
} else {
log.Info("clients: skipping unknown blocked service %q", s)
}
err = o.BlockedServices.Validate()
if err != nil {
return fmt.Errorf("clients: init client blocked services %q: %w", cli.Name, err)
}
cli.BlockedServices = o.BlockedServices.Clone()
for _, t := range o.Tags {
if clients.allTags.Has(t) {
cli.Tags = append(cli.Tags, t)
@@ -217,11 +246,13 @@ func (clients *clientsContainer) addFromConfig(objects []*clientObject, filterin
slices.Sort(cli.Tags)
_, err := clients.Add(cli)
_, err = clients.Add(cli)
if err != nil {
log.Error("clients: adding clients %s: %s", cli.Name, err)
}
}
return nil
}
// forConfig returns all currently known persistent clients as objects for the
@@ -235,10 +266,11 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) {
o := &clientObject{
Name: cli.Name,
Tags: stringutil.CloneSlice(cli.Tags),
IDs: stringutil.CloneSlice(cli.IDs),
BlockedServices: stringutil.CloneSlice(cli.BlockedServices),
Upstreams: stringutil.CloneSlice(cli.Upstreams),
BlockedServices: cli.BlockedServices.Clone(),
IDs: stringutil.CloneSlice(cli.IDs),
Tags: stringutil.CloneSlice(cli.Tags),
Upstreams: stringutil.CloneSlice(cli.Upstreams),
UseGlobalSettings: !cli.UseOwnSettings,
FilteringEnabled: cli.FilteringEnabled,
@@ -276,15 +308,38 @@ func (clients *clientsContainer) periodicUpdate() {
}
}
// onDHCPLeaseChanged is a callback for the DHCP server. It updates the list of
// runtime clients using the DHCP server's leases.
//
// TODO(e.burkov): Remove when switched to dhcpsvc.
func (clients *clientsContainer) onDHCPLeaseChanged(flags int) {
switch flags {
case dhcpd.LeaseChangedAdded,
dhcpd.LeaseChangedAddedStatic,
dhcpd.LeaseChangedRemovedStatic:
clients.updateFromDHCP(true)
case dhcpd.LeaseChangedRemovedAll:
clients.updateFromDHCP(false)
if clients.dhcpServer == nil || !config.Clients.Sources.DHCP {
return
}
clients.lock.Lock()
defer clients.lock.Unlock()
clients.rmHostsBySrc(ClientSourceDHCP)
if flags == dhcpd.LeaseChangedRemovedAll {
return
}
leases := clients.dhcpServer.Leases(dhcpd.LeasesAll)
n := 0
for _, l := range leases {
if l.Hostname == "" {
continue
}
ok := clients.addHostLocked(l.IP, l.Hostname, ClientSourceDHCP)
if ok {
n++
}
}
log.Debug("clients: added %d client aliases from dhcp", n)
}
// clientSource checks if client with this IP address already exists and returns
@@ -300,23 +355,11 @@ func (clients *clientsContainer) clientSource(ip netip.Addr) (src clientSource)
}
rc, ok := clients.ipToRC[ip]
if !ok {
return ClientSourceNone
if ok {
return rc.Source
}
return rc.Source
}
func toQueryLogWHOIS(wi *RuntimeClientWHOISInfo) (cw *querylog.ClientWHOIS) {
if wi == nil {
return &querylog.ClientWHOIS{}
}
return &querylog.ClientWHOIS{
City: wi.City,
Country: wi.Country,
Orgname: wi.Orgname,
}
return ClientSourceNone
}
// findMultiple is a wrapper around Find to make it a valid client finder for
@@ -352,7 +395,7 @@ func (clients *clientsContainer) clientOrArtificial(
defer func() {
c.Disallowed, c.DisallowedRule = clients.dnsServer.IsBlockedClient(ip, id)
if c.WHOIS == nil {
c.WHOIS = &querylog.ClientWHOIS{}
c.WHOIS = &whois.Info{}
}
}()
@@ -369,7 +412,7 @@ func (clients *clientsContainer) clientOrArtificial(
if ok {
return &querylog.Client{
Name: rc.Host,
WHOIS: toQueryLogWHOIS(rc.WHOISInfo),
WHOIS: rc.WHOIS,
}, false
}
@@ -477,11 +520,11 @@ func (clients *clientsContainer) findLocked(id string) (c *Client, ok bool) {
}
}
if clients.dhcpServer == nil {
return nil, false
if clients.dhcpServer != nil {
return clients.findDHCP(ip)
}
return clients.findDHCP(ip)
return nil, false
}
// findDHCP searches for a client by its MAC, if the DHCP server is active and
@@ -701,35 +744,34 @@ func (clients *clientsContainer) Update(prev, c *Client) (err error) {
}
// setWHOISInfo sets the WHOIS information for a client.
func (clients *clientsContainer) setWHOISInfo(ip netip.Addr, wi *RuntimeClientWHOISInfo) {
func (clients *clientsContainer) setWHOISInfo(ip netip.Addr, wi *whois.Info) {
clients.lock.Lock()
defer clients.lock.Unlock()
_, ok := clients.findLocked(ip.String())
if ok {
log.Debug("clients: client for %s is already created, ignore whois info", ip)
return
}
// TODO(e.burkov): Consider storing WHOIS information separately and
// potentially get rid of [RuntimeClient].
rc, ok := clients.ipToRC[ip]
if ok {
rc.WHOISInfo = wi
if !ok {
// Create a RuntimeClient implicitly so that we don't do this check
// again.
rc = &RuntimeClient{
Source: ClientSourceWHOIS,
}
clients.ipToRC[ip] = rc
log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi)
} else {
log.Debug("clients: set whois info for runtime client %s: %+v", rc.Host, wi)
return
}
// Create a RuntimeClient implicitly so that we don't do this check
// again.
rc = &RuntimeClient{
Source: ClientSourceWHOIS,
}
rc.WHOISInfo = wi
clients.ipToRC[ip] = rc
log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi)
rc.WHOIS = wi
}
// AddHost adds a new IP-hostname pairing. The priorities of the sources are
@@ -753,23 +795,19 @@ func (clients *clientsContainer) addHostLocked(
src clientSource,
) (ok bool) {
rc, ok := clients.ipToRC[ip]
if ok {
if rc.Source > src {
return false
}
rc.Host = host
rc.Source = src
} else {
if !ok {
rc = &RuntimeClient{
Host: host,
Source: src,
WHOISInfo: &RuntimeClientWHOISInfo{},
WHOIS: &whois.Info{},
}
clients.ipToRC[ip] = rc
} else if src < rc.Source {
return false
}
rc.Host = host
rc.Source = src
log.Debug("clients: added %s -> %q [%d]", ip, host, len(clients.ipToRC))
return true
@@ -838,38 +876,6 @@ func (clients *clientsContainer) addFromSystemARP() {
log.Debug("clients: added %d client aliases from arp neighborhood", added)
}
// updateFromDHCP adds the clients that have a non-empty hostname from the DHCP
// server.
func (clients *clientsContainer) updateFromDHCP(add bool) {
if clients.dhcpServer == nil || !config.Clients.Sources.DHCP {
return
}
clients.lock.Lock()
defer clients.lock.Unlock()
clients.rmHostsBySrc(ClientSourceDHCP)
if !add {
return
}
leases := clients.dhcpServer.Leases(dhcpd.LeasesAll)
n := 0
for _, l := range leases {
if l.Hostname == "" {
continue
}
ok := clients.addHostLocked(l.IP, l.Hostname, ClientSourceDHCP)
if ok {
n++
}
}
log.Debug("clients: added %d client aliases from dhcp", n)
}
// close gracefully closes all the client-specific upstream configurations of
// the persistent clients.
func (clients *clientsContainer) close() (err error) {

View File

@@ -9,25 +9,26 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/whois"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// newClientsContainer is a helper that creates a new clients container for
// tests.
func newClientsContainer() (c *clientsContainer) {
func newClientsContainer(t *testing.T) (c *clientsContainer) {
c = &clientsContainer{
testing: true,
}
c.Init(nil, nil, nil, nil, &filtering.Config{})
err := c.Init(nil, nil, nil, nil, &filtering.Config{})
require.NoError(t, err)
return c
}
func TestClients(t *testing.T) {
clients := newClientsContainer()
clients := newClientsContainer(t)
t.Run("add_success", func(t *testing.T) {
var (
@@ -198,8 +199,8 @@ func TestClients(t *testing.T) {
}
func TestClientsWHOIS(t *testing.T) {
clients := newClientsContainer()
whois := &RuntimeClientWHOISInfo{
clients := newClientsContainer(t)
whois := &whois.Info{
Country: "AU",
Orgname: "Example Org",
}
@@ -210,7 +211,7 @@ func TestClientsWHOIS(t *testing.T) {
rc := clients.ipToRC[ip]
require.NotNil(t, rc)
assert.Equal(t, rc.WHOISInfo, whois)
assert.Equal(t, rc.WHOIS, whois)
})
t.Run("existing_auto-client", func(t *testing.T) {
@@ -222,7 +223,7 @@ func TestClientsWHOIS(t *testing.T) {
rc := clients.ipToRC[ip]
require.NotNil(t, rc)
assert.Equal(t, rc.WHOISInfo, whois)
assert.Equal(t, rc.WHOIS, whois)
})
t.Run("can't_set_manually-added", func(t *testing.T) {
@@ -244,7 +245,7 @@ func TestClientsWHOIS(t *testing.T) {
}
func TestClientsAddExisting(t *testing.T) {
clients := newClientsContainer()
clients := newClientsContainer(t)
t.Run("simple", func(t *testing.T) {
ip := netip.MustParseAddr("1.1.1.1")
@@ -316,7 +317,7 @@ func TestClientsAddExisting(t *testing.T) {
}
func TestClientsCustomUpstream(t *testing.T) {
clients := newClientsContainer()
clients := newClientsContainer(t)
// Add client with upstreams.
ok, err := clients.Add(&Client{

View File

@@ -9,6 +9,8 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
"github.com/AdguardTeam/AdGuardHome/internal/whois"
)
// clientJSON is a common structure used by several handlers to deal with
@@ -28,7 +30,8 @@ type clientJSON struct {
// the allowlist.
DisallowedRule *string `json:"disallowed_rule,omitempty"`
WHOISInfo *RuntimeClientWHOISInfo `json:"whois_info,omitempty"`
// WHOIS is the filtered WHOIS data of a client.
WHOIS *whois.Info `json:"whois_info,omitempty"`
SafeSearchConf *filtering.SafeSearchConfig `json:"safe_search"`
Name string `json:"name"`
@@ -51,7 +54,7 @@ type clientJSON struct {
}
type runtimeClientJSON struct {
WHOISInfo *RuntimeClientWHOISInfo `json:"whois_info"`
WHOIS *whois.Info `json:"whois_info"`
IP netip.Addr `json:"ip"`
Name string `json:"name"`
@@ -78,7 +81,7 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
for ip, rc := range clients.ipToRC {
cj := runtimeClientJSON{
WHOISInfo: rc.WHOISInfo,
WHOIS: rc.WHOIS,
Name: rc.Host,
Source: rc.Source,
@@ -116,15 +119,24 @@ func (clients *clientsContainer) jsonToClient(cj clientJSON, prev *Client) (c *C
}
}
weekly := schedule.EmptyWeekly()
if prev != nil {
weekly = prev.BlockedServices.Schedule.Clone()
}
c = &Client{
safeSearchConf: safeSearchConf,
Name: cj.Name,
IDs: cj.IDs,
Tags: cj.Tags,
BlockedServices: cj.BlockedServices,
Upstreams: cj.Upstreams,
BlockedServices: &filtering.BlockedServices{
Schedule: weekly,
IDs: cj.BlockedServices,
},
IDs: cj.IDs,
Tags: cj.Tags,
Upstreams: cj.Upstreams,
UseOwnSettings: !cj.UseGlobalSettings,
FilteringEnabled: cj.FilteringEnabled,
@@ -178,7 +190,8 @@ func clientToJSON(c *Client) (cj *clientJSON) {
SafeBrowsingEnabled: c.SafeBrowsingEnabled,
UseGlobalBlockedServices: !c.UseOwnBlockedServices,
BlockedServices: c.BlockedServices,
BlockedServices: c.BlockedServices.IDs,
Upstreams: c.Upstreams,
@@ -344,16 +357,16 @@ func (clients *clientsContainer) findRuntime(ip netip.Addr, idStr string) (cj *c
IDs: []string{idStr},
Disallowed: &disallowed,
DisallowedRule: &rule,
WHOISInfo: &RuntimeClientWHOISInfo{},
WHOIS: &whois.Info{},
}
return cj
}
cj = &clientJSON{
Name: rc.Host,
IDs: []string{idStr},
WHOISInfo: rc.WHOISInfo,
Name: rc.Host,
IDs: []string{idStr},
WHOIS: rc.WHOIS,
}
disallowed, rule := clients.dnsServer.IsBlockedClient(ip, idStr)

View File

@@ -14,6 +14,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
"github.com/AdguardTeam/AdGuardHome/internal/stats"
"github.com/AdguardTeam/dnsproxy/fastip"
"github.com/AdguardTeam/golibs/errors"
@@ -90,18 +91,17 @@ type clientSourcesConfig struct {
HostsFile bool `yaml:"hosts"`
}
// configuration is loaded from YAML
// field ordering is important -- yaml fields will mirror ordering from here
// configuration is loaded from YAML.
//
// Field ordering is important, YAML fields better not to be reordered, if it's
// not absolutely necessary.
type configuration struct {
// Raw file data to avoid re-reading of configuration file
// It's reset after config is parsed
fileData []byte
// BindHost is the address for the web interface server to listen on.
BindHost netip.Addr `yaml:"bind_host"`
// BindPort is the port for the web interface server to listen on.
BindPort int `yaml:"bind_port"`
// HTTPConfig is the block with http conf.
HTTPConfig httpConfig `yaml:"http"`
// Users are the clients capable for accessing the web interface.
Users []webUser `yaml:"users"`
// AuthAttempts is the maximum number of failed login attempts a user
@@ -119,10 +119,6 @@ type configuration struct {
// DebugPProf defines if the profiling HTTP handler will listen on :6060.
DebugPProf bool `yaml:"debug_pprof"`
// TTL for a web session (in hours)
// An active session is automatically refreshed once a day.
WebSessionTTLHours uint32 `yaml:"web_session_ttl"`
DNS dnsConfig `yaml:"dns"`
TLS tlsConfigSettings `yaml:"tls"`
QueryLog queryLogConfig `yaml:"querylog"`
@@ -155,7 +151,23 @@ type configuration struct {
SchemaVersion int `yaml:"schema_version"` // keeping last so that users will be less tempted to change it -- used when upgrading between versions
}
// field ordering is important -- yaml fields will mirror ordering from here
// httpConfig is a block with HTTP configuration params.
//
// Field ordering is important, YAML fields better not to be reordered, if it's
// not absolutely necessary.
type httpConfig struct {
// Address is the address to serve the web UI on.
Address netip.AddrPort
// SessionTTL for a web session.
// An active session is automatically refreshed once a day.
SessionTTL timeutil.Duration `yaml:"session_ttl"`
}
// dnsConfig is a block with DNS configuration params.
//
// Field ordering is important, YAML fields better not to be reordered, if it's
// not absolutely necessary.
type dnsConfig struct {
BindHosts []netip.Addr `yaml:"bind_hosts"`
Port int `yaml:"port"`
@@ -260,11 +272,12 @@ type statsConfig struct {
//
// TODO(a.garipov, e.burkov): This global is awful and must be removed.
var config = &configuration{
BindPort: 3000,
BindHost: netip.IPv4Unspecified(),
AuthAttempts: 5,
AuthBlockMin: 15,
WebSessionTTLHours: 30 * 24,
AuthAttempts: 5,
AuthBlockMin: 15,
HTTPConfig: httpConfig{
Address: netip.AddrPortFrom(netip.IPv4Unspecified(), 3000),
SessionTTL: timeutil.Duration{Duration: 30 * timeutil.Day},
},
DNS: dnsConfig{
BindHosts: []netip.Addr{netip.IPv4Unspecified()},
Port: defaultPortDNS,
@@ -316,6 +329,11 @@ var config = &configuration{
Yandex: true,
YouTube: true,
},
BlockedServices: &filtering.BlockedServices{
Schedule: schedule.EmptyWeekly(),
IDs: []string{},
},
},
UpstreamTimeout: timeutil.Duration{Duration: dnsforward.DefaultTimeout},
UsePrivateRDNS: true,
@@ -421,8 +439,8 @@ func readLogSettings() (ls *logSettings) {
// validateBindHosts returns error if any of binding hosts from configuration is
// not a valid IP address.
func validateBindHosts(conf *configuration) (err error) {
if !conf.BindHost.IsValid() {
return errors.Error("bind_host is not a valid ip address")
if !conf.HTTPConfig.Address.IsValid() {
return errors.Error("http.address is not a valid ip address")
}
for i, addr := range conf.DNS.BindHosts {
@@ -456,7 +474,7 @@ func parseConfig() (err error) {
}
tcpPorts := aghalg.UniqChecker[tcpPort]{}
addPorts(tcpPorts, tcpPort(config.BindPort))
addPorts(tcpPorts, tcpPort(config.HTTPConfig.Address.Port()))
udpPorts := aghalg.UniqChecker[udpPort]{}
addPorts(udpPorts, udpPort(config.DNS.Port))

View File

@@ -103,7 +103,7 @@ type statusResponse struct {
Language string `json:"language"`
DNSAddrs []string `json:"dns_addresses"`
DNSPort int `json:"dns_port"`
HTTPPort int `json:"http_port"`
HTTPPort uint16 `json:"http_port"`
// ProtectionDisabledDuration is the duration of the protection pause in
// milliseconds.
@@ -158,7 +158,7 @@ func handleStatus(w http.ResponseWriter, r *http.Request) {
Language: config.Language,
DNSAddrs: dnsAddrs,
DNSPort: config.DNS.Port,
HTTPPort: config.BindPort,
HTTPPort: config.HTTPConfig.Address.Port(),
ProtectionDisabledDuration: protectionDisabledDuration,
ProtectionEnabled: protectionEnabled,
IsRunning: isRunning(),

View File

@@ -96,8 +96,9 @@ type checkConfResp struct {
func (req *checkConfReq) validateWeb(tcpPorts aghalg.UniqChecker[tcpPort]) (err error) {
defer func() { err = errors.Annotate(err, "validating ports: %w") }()
portInt := req.Web.Port
port := tcpPort(portInt)
// TODO(a.garipov): Declare all port variables anywhere as uint16.
reqPort := uint16(req.Web.Port)
port := tcpPort(reqPort)
addPorts(tcpPorts, port)
if err = tcpPorts.Validate(); err != nil {
// Reset the value for the port to 1 to make sure that validateDNS
@@ -108,15 +109,15 @@ func (req *checkConfReq) validateWeb(tcpPorts aghalg.UniqChecker[tcpPort]) (err
return err
}
switch portInt {
case 0, config.BindPort:
switch reqPort {
case 0, config.HTTPConfig.Address.Port():
return nil
default:
// Go on and check the port binding only if it's not zero or won't be
// unbound after install.
}
return aghnet.CheckPort("tcp", netip.AddrPortFrom(req.Web.IP, uint16(portInt)))
return aghnet.CheckPort("tcp", netip.AddrPortFrom(req.Web.IP, reqPort))
}
// validateDNS returns error if the DNS part of the initial configuration can't
@@ -127,11 +128,11 @@ func (req *checkConfReq) validateDNS(
) (canAutofix bool, err error) {
defer func() { err = errors.Annotate(err, "validating ports: %w") }()
port := req.DNS.Port
port := uint16(req.DNS.Port)
switch port {
case 0:
return false, nil
case config.BindPort:
case config.HTTPConfig.Address.Port():
// Go on and only check the UDP port since the TCP one is already bound
// by AdGuard Home for web interface.
default:
@@ -318,8 +319,7 @@ type applyConfigReq struct {
// copyInstallSettings copies the installation parameters between two
// configuration structures.
func copyInstallSettings(dst, src *configuration) {
dst.BindHost = src.BindHost
dst.BindPort = src.BindPort
dst.HTTPConfig = src.HTTPConfig
dst.DNS.BindHosts = src.DNS.BindHosts
dst.DNS.Port = src.DNS.Port
}
@@ -413,8 +413,7 @@ func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request
copyInstallSettings(curConfig, config)
Context.firstRun = false
config.BindHost = req.Web.IP
config.BindPort = req.Web.Port
config.HTTPConfig.Address = netip.AddrPortFrom(req.Web.IP, uint16(req.Web.Port))
config.DNS.BindHosts = []netip.Addr{req.DNS.IP}
config.DNS.Port = req.DNS.Port
@@ -487,7 +486,8 @@ func decodeApplyConfigReq(r io.Reader) (req *applyConfigReq, restartHTTP bool, e
return nil, false, errors.Error("ports cannot be 0")
}
restartHTTP = config.BindHost != req.Web.IP || config.BindPort != req.Web.Port
addrPort := config.HTTPConfig.Address
restartHTTP = addrPort.Addr() != req.Web.IP || int(addrPort.Port()) != req.Web.Port
if restartHTTP {
err = aghnet.CheckPort("tcp", netip.AddrPortFrom(req.Web.IP, uint16(req.Web.Port)))
if err != nil {

View File

@@ -157,7 +157,9 @@ func (vr *versionResponse) setAllowedToAutoUpdate() (err error) {
Context.tls.WriteDiskConfig(tlsConf)
canUpdate := true
if tlsConfUsesPrivilegedPorts(tlsConf) || config.BindPort < 1024 || config.DNS.Port < 1024 {
if tlsConfUsesPrivilegedPorts(tlsConf) ||
config.HTTPConfig.Address.Port() < 1024 ||
config.DNS.Port < 1024 {
canUpdate, err = aghnet.CanBindPrivilegedPorts()
if err != nil {
return fmt.Errorf("checking ability to bind privileged ports: %w", err)

View File

@@ -8,6 +8,7 @@ import (
"net/url"
"os"
"path/filepath"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
@@ -17,6 +18,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
"github.com/AdguardTeam/AdGuardHome/internal/stats"
"github.com/AdguardTeam/AdGuardHome/internal/whois"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
@@ -25,7 +27,7 @@ import (
yaml "gopkg.in/yaml.v3"
)
// Default ports.
// Default listening ports.
const (
defaultPortDNS = 53
defaultPortHTTP = 80
@@ -169,13 +171,72 @@ func initDNSServer(
Context.rdns = NewRDNS(Context.dnsServer, &Context.clients, config.DNS.UsePrivateRDNS)
}
if config.Clients.Sources.WHOIS {
Context.whois = initWHOIS(&Context.clients)
}
initWHOIS()
return nil
}
// initWHOIS initializes the WHOIS.
//
// TODO(s.chzhen): Consider making configurable.
func initWHOIS() {
const (
// defaultQueueSize is the size of queue of IPs for WHOIS processing.
defaultQueueSize = 255
// defaultTimeout is the timeout for WHOIS requests.
defaultTimeout = 5 * time.Second
// defaultCacheSize is the maximum size of the cache. If it's zero,
// cache size is unlimited.
defaultCacheSize = 10_000
// defaultMaxConnReadSize is an upper limit in bytes for reading from
// net.Conn.
defaultMaxConnReadSize = 64 * 1024
// defaultMaxRedirects is the maximum redirects count.
defaultMaxRedirects = 5
// defaultMaxInfoLen is the maximum length of whois.Info fields.
defaultMaxInfoLen = 250
// defaultIPTTL is the Time to Live duration for cached IP addresses.
defaultIPTTL = 1 * time.Hour
)
Context.whoisCh = make(chan netip.Addr, defaultQueueSize)
var w whois.Interface
if config.Clients.Sources.WHOIS {
w = whois.New(&whois.Config{
DialContext: customDialContext,
ServerAddr: whois.DefaultServer,
Port: whois.DefaultPort,
Timeout: defaultTimeout,
CacheSize: defaultCacheSize,
MaxConnReadSize: defaultMaxConnReadSize,
MaxRedirects: defaultMaxRedirects,
MaxInfoLen: defaultMaxInfoLen,
CacheTTL: defaultIPTTL,
})
} else {
w = whois.Empty{}
}
go func() {
defer log.OnPanic("whois")
for ip := range Context.whoisCh {
info, changed := w.Process(context.Background(), ip)
if info != nil && changed {
Context.clients.setWHOISInfo(ip, info)
}
}
}()
}
// parseSubnetSet parses a slice of subnets. If the slice is empty, it returns
// a subnet set that matches all locally served networks, see
// [netutil.IsLocallyServed].
@@ -218,9 +279,7 @@ func onDNSRequest(pctx *proxy.DNSContext) {
Context.rdns.Begin(ip)
}
if srcs.WHOIS && !netutil.IsSpecialPurposeAddr(ip) {
Context.whois.Begin(ip)
}
Context.whoisCh <- ip
}
func ipsToTCPAddrs(ips []netip.Addr, port int) (tcpAddrs []*net.TCPAddr) {
@@ -390,7 +449,7 @@ func applyAdditionalFiltering(clientIP net.IP, clientID string, setts *filtering
// pref is a prefix for logging messages around the scope.
const pref = "applying filters"
Context.filters.ApplyBlockedServices(setts, nil)
Context.filters.ApplyBlockedServices(setts)
log.Debug("%s: looking for client with ip %s and clientid %q", pref, clientIP, clientID)
@@ -414,12 +473,12 @@ func applyAdditionalFiltering(clientIP net.IP, clientID string, setts *filtering
if c.UseOwnBlockedServices {
// TODO(e.burkov): Get rid of this crutch.
svcs := c.BlockedServices
if svcs == nil {
svcs = []string{}
setts.ServicesRules = nil
svcs := c.BlockedServices.IDs
if !c.BlockedServices.Schedule.Contains(time.Now()) {
Context.filters.ApplyBlockedServicesList(setts, svcs)
log.Debug("%s: services for client %q set: %s", pref, c.Name, svcs)
}
Context.filters.ApplyBlockedServices(setts, svcs)
log.Debug("%s: services for client %q set: %s", pref, c.Name, svcs)
}
setts.ClientName = c.Name
@@ -463,9 +522,7 @@ func startDNSServer() error {
Context.rdns.Begin(ip)
}
if srcs.WHOIS && !netutil.IsSpecialPurposeAddr(ip) {
Context.whois.Begin(ip)
}
Context.whoisCh <- ip
}
return nil

View File

@@ -0,0 +1,176 @@
package home
import (
"net"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestApplyAdditionalFiltering(t *testing.T) {
var err error
Context.filters, err = filtering.New(&filtering.Config{
BlockedServices: &filtering.BlockedServices{
Schedule: schedule.EmptyWeekly(),
},
}, nil)
require.NoError(t, err)
Context.clients.idIndex = map[string]*Client{
"default": {
UseOwnSettings: false,
safeSearchConf: filtering.SafeSearchConfig{Enabled: false},
FilteringEnabled: false,
SafeBrowsingEnabled: false,
ParentalEnabled: false,
},
"custom_filtering": {
UseOwnSettings: true,
safeSearchConf: filtering.SafeSearchConfig{Enabled: true},
FilteringEnabled: true,
SafeBrowsingEnabled: true,
ParentalEnabled: true,
},
"partial_custom_filtering": {
UseOwnSettings: true,
safeSearchConf: filtering.SafeSearchConfig{Enabled: true},
FilteringEnabled: true,
SafeBrowsingEnabled: false,
ParentalEnabled: false,
},
}
testCases := []struct {
name string
id string
FilteringEnabled assert.BoolAssertionFunc
SafeSearchEnabled assert.BoolAssertionFunc
SafeBrowsingEnabled assert.BoolAssertionFunc
ParentalEnabled assert.BoolAssertionFunc
}{{
name: "global_settings",
id: "default",
FilteringEnabled: assert.False,
SafeSearchEnabled: assert.False,
SafeBrowsingEnabled: assert.False,
ParentalEnabled: assert.False,
}, {
name: "custom_settings",
id: "custom_filtering",
FilteringEnabled: assert.True,
SafeSearchEnabled: assert.True,
SafeBrowsingEnabled: assert.True,
ParentalEnabled: assert.True,
}, {
name: "partial",
id: "partial_custom_filtering",
FilteringEnabled: assert.True,
SafeSearchEnabled: assert.True,
SafeBrowsingEnabled: assert.False,
ParentalEnabled: assert.False,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
setts := &filtering.Settings{}
applyAdditionalFiltering(net.IP{1, 2, 3, 4}, tc.id, setts)
tc.FilteringEnabled(t, setts.FilteringEnabled)
tc.SafeSearchEnabled(t, setts.SafeSearchEnabled)
tc.SafeBrowsingEnabled(t, setts.SafeBrowsingEnabled)
tc.ParentalEnabled(t, setts.ParentalEnabled)
})
}
}
func TestApplyAdditionalFiltering_blockedServices(t *testing.T) {
filtering.InitModule()
var (
globalBlockedServices = []string{"ok"}
clientBlockedServices = []string{"ok", "mail_ru", "vk"}
invalidBlockedServices = []string{"invalid"}
err error
)
Context.filters, err = filtering.New(&filtering.Config{
BlockedServices: &filtering.BlockedServices{
Schedule: schedule.EmptyWeekly(),
IDs: globalBlockedServices,
},
}, nil)
require.NoError(t, err)
Context.clients.idIndex = map[string]*Client{
"default": {
UseOwnBlockedServices: false,
},
"no_services": {
BlockedServices: &filtering.BlockedServices{
Schedule: schedule.EmptyWeekly(),
},
UseOwnBlockedServices: true,
},
"services": {
BlockedServices: &filtering.BlockedServices{
Schedule: schedule.EmptyWeekly(),
IDs: clientBlockedServices,
},
UseOwnBlockedServices: true,
},
"invalid_services": {
BlockedServices: &filtering.BlockedServices{
Schedule: schedule.EmptyWeekly(),
IDs: invalidBlockedServices,
},
UseOwnBlockedServices: true,
},
"allow_all": {
BlockedServices: &filtering.BlockedServices{
Schedule: schedule.FullWeekly(),
IDs: clientBlockedServices,
},
UseOwnBlockedServices: true,
},
}
testCases := []struct {
name string
id string
wantLen int
}{{
name: "global_settings",
id: "default",
wantLen: len(globalBlockedServices),
}, {
name: "custom_settings",
id: "no_services",
wantLen: 0,
}, {
name: "custom_settings_block",
id: "services",
wantLen: len(clientBlockedServices),
}, {
name: "custom_settings_invalid",
id: "invalid_services",
wantLen: 0,
}, {
name: "custom_settings_inactive_schedule",
id: "allow_all",
wantLen: 0,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
setts := &filtering.Settings{}
applyAdditionalFiltering(net.IP{1, 2, 3, 4}, tc.id, setts)
require.Len(t, setts.ServicesRules, tc.wantLen)
})
}
}

View File

@@ -57,7 +57,6 @@ type homeContext struct {
queryLog querylog.QueryLog // query log module
dnsServer *dnsforward.Server // DNS module
rdns *RDNS // rDNS module
whois *WHOIS // WHOIS module
dhcpServer dhcpd.Interface // DHCP module
auth *Auth // HTTP authentication module
filters *filtering.DNSFilter // DNS filtering module
@@ -84,6 +83,9 @@ type homeContext struct {
client *http.Client
appSignalChannel chan os.Signal // Channel for receiving OS signals by the console app
// whoisCh is the channel for receiving IPs for WHOIS processing.
whoisCh chan netip.Addr
// tlsCipherIDs are the ID of the cipher suites that AdGuard Home must use.
tlsCipherIDs []uint16
@@ -353,21 +355,43 @@ func initContextClients() (err error) {
arpdb = aghnet.NewARPDB()
}
Context.clients.Init(
err = Context.clients.Init(
config.Clients.Persistent,
Context.dhcpServer,
Context.etcHosts,
arpdb,
config.DNS.DnsfilterConf,
)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
}
return nil
}
// setupBindOpts overrides bind host/port from the opts.
func setupBindOpts(opts options) (err error) {
bindAddr := opts.bindAddr
if bindAddr != (netip.AddrPort{}) {
config.HTTPConfig.Address = bindAddr
if config.HTTPConfig.Address.Port() != 0 {
err = checkPorts()
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
}
}
return nil
}
if opts.bindPort != 0 {
config.BindPort = opts.bindPort
config.HTTPConfig.Address = netip.AddrPortFrom(
config.HTTPConfig.Address.Addr(),
uint16(opts.bindPort),
)
err = checkPorts()
if err != nil {
@@ -377,7 +401,10 @@ func setupBindOpts(opts options) (err error) {
}
if opts.bindHost.IsValid() {
config.BindHost = opts.bindHost
config.HTTPConfig.Address = netip.AddrPortFrom(
opts.bindHost,
config.HTTPConfig.Address.Port(),
)
}
return nil
@@ -461,7 +488,7 @@ func setupDNSFilteringConf(conf *filtering.Config) (err error) {
// checkPorts is a helper for ports validation in config.
func checkPorts() (err error) {
tcpPorts := aghalg.UniqChecker[tcpPort]{}
addPorts(tcpPorts, tcpPort(config.BindPort))
addPorts(tcpPorts, tcpPort(config.HTTPConfig.Address.Port()))
udpPorts := aghalg.UniqChecker[udpPort]{}
addPorts(udpPorts, udpPort(config.DNS.Port))
@@ -501,8 +528,8 @@ func initWeb(opts options, clientBuildFS fs.FS) (web *webAPI, err error) {
webConf := webConfig{
firstRun: Context.firstRun,
BindHost: config.BindHost,
BindPort: config.BindPort,
BindHost: config.HTTPConfig.Address.Addr(),
BindPort: int(config.HTTPConfig.Address.Port()),
ReadTimeout: readTimeout,
ReadHeaderTimeout: readHdrTimeout,
@@ -638,8 +665,8 @@ func initUsers() (auth *Auth, err error) {
log.Info("authratelimiter is disabled")
}
sessionTTL := config.WebSessionTTLHours * 60 * 60
auth = InitAuth(sessFilename, config.Users, sessionTTL, rateLimiter)
sessionTTL := config.HTTPConfig.SessionTTL.Seconds()
auth = InitAuth(sessFilename, config.Users, uint32(sessionTTL), rateLimiter)
if auth == nil {
return nil, errors.Error("initializing auth module failed")
}
@@ -917,7 +944,7 @@ func printHTTPAddresses(proto string) {
Context.tls.WriteDiskConfig(&tlsConf)
}
port := config.BindPort
port := int(config.HTTPConfig.Address.Port())
if proto == aghhttp.SchemeHTTPS {
port = tlsConf.PortHTTPS
}
@@ -929,9 +956,9 @@ func printHTTPAddresses(proto string) {
return
}
bindhost := config.BindHost
if !bindhost.IsUnspecified() {
printWebAddrs(proto, bindhost.String(), port)
bindHost := config.HTTPConfig.Address.Addr()
if !bindHost.IsUnspecified() {
printWebAddrs(proto, bindHost.String(), port)
return
}
@@ -942,14 +969,14 @@ func printHTTPAddresses(proto string) {
// That's weird, but we'll ignore it.
//
// TODO(e.burkov): Find out when it happens.
printWebAddrs(proto, bindhost.String(), port)
printWebAddrs(proto, bindHost.String(), port)
return
}
for _, iface := range ifaces {
for _, addr := range iface.Addresses {
printWebAddrs(proto, addr.String(), config.BindPort)
printWebAddrs(proto, addr.String(), port)
}
}
}

View File

@@ -35,11 +35,18 @@ type options struct {
serviceControlAction string
// bindHost is the address on which to serve the HTTP UI.
//
// Deprecated: Use bindAddr.
bindHost netip.Addr
// bindPort is the port on which to serve the HTTP UI.
//
// Deprecated: Use bindAddr.
bindPort int
// bindAddr is the address to serve the web UI on.
bindAddr netip.AddrPort
// checkConfig is true if the current invocation is only required to check
// the configuration file and exit.
checkConfig bool
@@ -147,9 +154,10 @@ var cmdLineOpts = []cmdLineOpt{{
return o.bindHost.String(), true
},
description: "Host address to bind HTTP server on.",
longName: "host",
shortName: "h",
description: "Deprecated. Host address to bind HTTP server on. Use --web-addr. " +
"The short -h will work as --help in the future.",
longName: "host",
shortName: "h",
}, {
updateWithValue: func(o options, v string) (options, error) {
var err error
@@ -174,9 +182,23 @@ var cmdLineOpts = []cmdLineOpt{{
return strconv.Itoa(o.bindPort), true
},
description: "Port to serve HTTP pages on.",
description: "Deprecated. Port to serve HTTP pages on. Use --web-addr.",
longName: "port",
shortName: "p",
}, {
updateWithValue: func(o options, v string) (oo options, err error) {
o.bindAddr, err = netip.ParseAddrPort(v)
return o, err
},
updateNoValue: nil,
effect: nil,
serialize: func(o options) (val string, ok bool) {
return o.bindAddr.String(), o.bindAddr.IsValid()
},
description: "Address to serve the web UI on, in the host:port format.",
longName: "web-addr",
shortName: "",
}, {
updateWithValue: func(o options, v string) (options, error) {
o.serviceControlAction = v

View File

@@ -82,6 +82,23 @@ func TestParseBindPort(t *testing.T) {
testParseErr(t, "port too high", "-p", "18446744073709551617") // 2^64 + 1
}
func TestParseBindAddr(t *testing.T) {
wantAddrPort := netip.MustParseAddrPort("1.2.3.4:8089")
assert.Zero(t, testParseOK(t).bindAddr, "empty is not web-addr")
assert.Equal(t, wantAddrPort, testParseOK(t, "--web-addr", "1.2.3.4:8089").bindAddr)
assert.Equal(t, netip.MustParseAddrPort("1.2.3.4:0"), testParseOK(t, "--web-addr", "1.2.3.4:0").bindAddr)
testParseParamMissing(t, "-web-addr")
testParseErr(t, "not an int", "--web-addr", "1.2.3.4:x")
testParseErr(t, "hex not supported", "--web-addr", "1.2.3.4:0x100")
testParseErr(t, "port negative", "--web-addr", "1.2.3.4:-1")
testParseErr(t, "port too high", "--web-addr", "1.2.3.4:65536")
testParseErr(t, "port too high", "--web-addr", "1.2.3.4:4294967297") // 2^32 + 1
testParseErr(t, "port too high", "--web-addr", "1.2.3.4:18446744073709551617") // 2^64 + 1
}
func TestParseLogfile(t *testing.T) {
assert.Equal(t, "", testParseOK(t).logFile, "empty is no log file")
assert.Equal(t, "path", testParseOK(t, "-l", "path").logFile, "-l is log file")
@@ -162,6 +179,10 @@ func TestOptsToArgs(t *testing.T) {
name: "bind_port",
args: []string{"-p", "666"},
opts: options{bindPort: 666},
}, {
name: "web-addr",
args: []string{"--web-addr", "1.2.3.4:8080"},
opts: options{bindAddr: netip.MustParseAddrPort("1.2.3.4:8080")},
}, {
name: "log_file",
args: []string{"-l", "path"},

View File

@@ -228,12 +228,7 @@ func TestRDNS_WorkerLoop(t *testing.T) {
for _, tc := range testCases {
w.Reset()
cc := &clientsContainer{
list: map[string]*Client{},
idIndex: map[string]*Client{},
ipToRC: map[netip.Addr]*RuntimeClient{},
allTags: stringutil.NewSet(),
}
cc := newClientsContainer(t)
ch := make(chan netip.Addr)
rdns := &RDNS{
exchanger: &rDNSExchanger{

View File

@@ -320,7 +320,7 @@ func (m *tlsManager) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
if setts.Enabled {
err = validatePorts(
tcpPort(config.BindPort),
tcpPort(config.HTTPConfig.Address.Port()),
tcpPort(setts.PortHTTPS),
tcpPort(setts.PortDNSOverTLS),
tcpPort(setts.PortDNSCrypt),
@@ -407,7 +407,7 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request)
if req.Enabled {
err = validatePorts(
tcpPort(config.BindPort),
tcpPort(config.HTTPConfig.Address.Port()),
tcpPort(req.PortHTTPS),
tcpPort(req.PortDNSOverTLS),
tcpPort(req.PortDNSCrypt),

View File

@@ -3,6 +3,7 @@ package home
import (
"bytes"
"fmt"
"net/netip"
"net/url"
"os"
"path"
@@ -22,7 +23,7 @@ import (
)
// currentSchemaVersion is the current schema version.
const currentSchemaVersion = 20
const currentSchemaVersion = 23
// These aliases are provided for convenience.
type (
@@ -94,6 +95,9 @@ func upgradeConfigSchema(oldVersion int, diskConf yobj) (err error) {
upgradeSchema17to18,
upgradeSchema18to19,
upgradeSchema19to20,
upgradeSchema20to21,
upgradeSchema21to22,
upgradeSchema22to23,
}
n := 0
@@ -1128,6 +1132,199 @@ func upgradeSchema19to20(diskConf yobj) (err error) {
return nil
}
// upgradeSchema20to21 performs the following changes:
//
// # BEFORE:
// 'dns':
// 'blocked_services':
// - 'svc_name'
//
// # AFTER:
// 'dns':
// 'blocked_services':
// 'ids':
// - 'svc_name'
// 'schedule':
// 'time_zone': 'Local'
func upgradeSchema20to21(diskConf yobj) (err error) {
log.Printf("Upgrade yaml: 20 to 21")
diskConf["schema_version"] = 21
const field = "blocked_services"
dnsVal, ok := diskConf["dns"]
if !ok {
return nil
}
dns, ok := dnsVal.(yobj)
if !ok {
return fmt.Errorf("unexpected type of dns: %T", dnsVal)
}
blockedVal, ok := dns[field]
if !ok {
return nil
}
services, ok := blockedVal.(yarr)
if !ok {
return fmt.Errorf("unexpected type of blocked: %T", blockedVal)
}
dns[field] = yobj{
"ids": services,
"schedule": yobj{
"time_zone": "Local",
},
}
return nil
}
// upgradeSchema21to22 performs the following changes:
//
// # BEFORE:
// 'persistent':
// - 'name': 'client_name'
// 'blocked_services':
// - 'svc_name'
//
// # AFTER:
// 'persistent':
// - 'name': 'client_name'
// 'blocked_services':
// 'ids':
// - 'svc_name'
// 'schedule':
// 'time_zone': 'Local'
func upgradeSchema21to22(diskConf yobj) (err error) {
log.Println("Upgrade yaml: 21 to 22")
diskConf["schema_version"] = 22
const field = "blocked_services"
clientsVal, ok := diskConf["clients"]
if !ok {
return nil
}
clients, ok := clientsVal.(yobj)
if !ok {
return fmt.Errorf("unexpected type of clients: %T", clientsVal)
}
persistentVal, ok := clients["persistent"]
if !ok {
return nil
}
persistent, ok := persistentVal.([]any)
if !ok {
return fmt.Errorf("unexpected type of persistent clients: %T", persistentVal)
}
for i, val := range persistent {
var c yobj
c, ok = val.(yobj)
if !ok {
return fmt.Errorf("persistent client at index %d: unexpected type %T", i, val)
}
var blockedVal any
blockedVal, ok = c[field]
if !ok {
continue
}
var services yarr
services, ok = blockedVal.(yarr)
if !ok {
return fmt.Errorf(
"persistent client at index %d: unexpected type of blocked services: %T",
i,
blockedVal,
)
}
c[field] = yobj{
"ids": services,
"schedule": yobj{
"time_zone": "Local",
},
}
}
return nil
}
// upgradeSchema22to23 performs the following changes:
//
// # BEFORE:
// 'bind_host': '1.2.3.4'
// 'bind_port': 8080
// 'web_session_ttl': 720
//
// # AFTER:
// 'http':
// 'address': '1.2.3.4:8080'
// 'session_ttl': '720h'
func upgradeSchema22to23(diskConf yobj) (err error) {
log.Printf("Upgrade yaml: 22 to 23")
diskConf["schema_version"] = 23
bindHostVal, ok := diskConf["bind_host"]
if !ok {
return nil
}
bindHost, ok := bindHostVal.(string)
if !ok {
return fmt.Errorf("unexpected type of bind_host: %T", bindHostVal)
}
bindHostAddr, err := netip.ParseAddr(bindHost)
if err != nil {
return fmt.Errorf("invalid bind_host value: %s", bindHost)
}
bindPortVal, ok := diskConf["bind_port"]
if !ok {
return nil
}
bindPort, ok := bindPortVal.(int)
if !ok {
return fmt.Errorf("unexpected type of bind_port: %T", bindPortVal)
}
sessionTTLVal, ok := diskConf["web_session_ttl"]
if !ok {
return nil
}
sessionTTL, ok := sessionTTLVal.(int)
if !ok {
return fmt.Errorf("unexpected type of web_session_ttl: %T", sessionTTLVal)
}
addr := netip.AddrPortFrom(bindHostAddr, uint16(bindPort))
if !addr.IsValid() {
return fmt.Errorf("invalid address: %s", addr)
}
diskConf["http"] = yobj{
"address": addr.String(),
"session_ttl": timeutil.Duration{Duration: time.Duration(sessionTTL) * time.Hour}.String(),
}
delete(diskConf, "bind_host")
delete(diskConf, "bind_port")
delete(diskConf, "web_session_ttl")
return nil
}
// TODO(a.garipov): Replace with log.Output when we port it to our logging
// package.
func funcName() string {

View File

@@ -1140,3 +1140,169 @@ func TestUpgradeSchema19to20(t *testing.T) {
assert.Equal(t, 24*time.Hour, ivlVal.Duration)
})
}
func TestUpgradeSchema20to21(t *testing.T) {
const newSchemaVer = 21
testCases := []struct {
in yobj
want yobj
name string
}{{
name: "nothing",
in: yobj{},
want: yobj{
"schema_version": newSchemaVer,
},
}, {
name: "no_clients",
in: yobj{
"dns": yobj{
"blocked_services": yarr{"ok"},
},
},
want: yobj{
"dns": yobj{
"blocked_services": yobj{
"ids": yarr{"ok"},
"schedule": yobj{
"time_zone": "Local",
},
},
},
"schema_version": newSchemaVer,
},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := upgradeSchema20to21(tc.in)
require.NoError(t, err)
assert.Equal(t, tc.want, tc.in)
})
}
}
func TestUpgradeSchema21to22(t *testing.T) {
const newSchemaVer = 22
testCases := []struct {
in yobj
want yobj
name string
}{{
in: yobj{
"clients": yobj{},
},
want: yobj{
"clients": yobj{},
"schema_version": newSchemaVer,
},
name: "nothing",
}, {
in: yobj{
"clients": yobj{
"persistent": []any{yobj{"name": "localhost", "blocked_services": yarr{}}},
},
},
want: yobj{
"clients": yobj{
"persistent": []any{yobj{
"name": "localhost",
"blocked_services": yobj{
"ids": yarr{},
"schedule": yobj{
"time_zone": "Local",
},
},
}},
},
"schema_version": newSchemaVer,
},
name: "no_services",
}, {
in: yobj{
"clients": yobj{
"persistent": []any{yobj{"name": "localhost", "blocked_services": yarr{"ok"}}},
},
},
want: yobj{
"clients": yobj{
"persistent": []any{yobj{
"name": "localhost",
"blocked_services": yobj{
"ids": yarr{"ok"},
"schedule": yobj{
"time_zone": "Local",
},
},
}},
},
"schema_version": newSchemaVer,
},
name: "services",
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := upgradeSchema21to22(tc.in)
require.NoError(t, err)
assert.Equal(t, tc.want, tc.in)
})
}
}
func TestUpgradeSchema22to23(t *testing.T) {
const newSchemaVer = 23
testCases := []struct {
in yobj
want yobj
name string
}{{
name: "empty",
in: yobj{},
want: yobj{
"schema_version": newSchemaVer,
},
}, {
name: "ok",
in: yobj{
"bind_host": "1.2.3.4",
"bind_port": 8081,
"web_session_ttl": 720,
},
want: yobj{
"http": yobj{
"address": "1.2.3.4:8081",
"session_ttl": "720h",
},
"schema_version": newSchemaVer,
},
}, {
name: "v6_address",
in: yobj{
"bind_host": "2001:db8::1",
"bind_port": 8081,
"web_session_ttl": 720,
},
want: yobj{
"http": yobj{
"address": "[2001:db8::1]:8081",
"session_ttl": "720h",
},
"schema_version": newSchemaVer,
},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := upgradeSchema22to23(tc.in)
require.NoError(t, err)
assert.Equal(t, tc.want, tc.in)
})
}
}

View File

@@ -119,7 +119,9 @@ func webCheckPortAvailable(port int) (ok bool) {
return true
}
return aghnet.CheckPort("tcp", netip.AddrPortFrom(config.BindHost, uint16(port))) == nil
addrPort := netip.AddrPortFrom(config.HTTPConfig.Address.Addr(), uint16(port))
return aghnet.CheckPort("tcp", addrPort) == nil
}
// tlsConfigChanged updates the TLS configuration and restarts the HTTPS server

View File

@@ -1,259 +0,0 @@
package home
import (
"context"
"encoding/binary"
"fmt"
"io"
"net"
"net/netip"
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
"github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/stringutil"
)
const (
defaultServer = "whois.arin.net"
defaultPort = "43"
maxValueLength = 250
whoisTTL = 1 * 60 * 60 // 1 hour
)
// WHOIS - module context
type WHOIS struct {
clients *clientsContainer
ipChan chan netip.Addr
// dialContext specifies the dial function for creating unencrypted TCP
// connections.
dialContext func(ctx context.Context, network, addr string) (conn net.Conn, err error)
// Contains IP addresses of clients
// An active IP address is resolved once again after it expires.
// If IP address couldn't be resolved, it stays here for some time to prevent further attempts to resolve the same IP.
ipAddrs cache.Cache
// TODO(a.garipov): Rewrite to use time.Duration. Like, seriously, why?
timeoutMsec uint
}
// initWHOIS creates the WHOIS module context.
func initWHOIS(clients *clientsContainer) *WHOIS {
w := WHOIS{
timeoutMsec: 5000,
clients: clients,
ipAddrs: cache.New(cache.Config{
EnableLRU: true,
MaxCount: 10000,
}),
dialContext: customDialContext,
ipChan: make(chan netip.Addr, 255),
}
go w.workerLoop()
return &w
}
// If the value is too large - cut it and append "..."
func trimValue(s string) string {
if len(s) <= maxValueLength {
return s
}
return s[:maxValueLength-3] + "..."
}
// isWHOISComment returns true if the string is empty or is a WHOIS comment.
func isWHOISComment(s string) (ok bool) {
return len(s) == 0 || s[0] == '#' || s[0] == '%'
}
// strmap is an alias for convenience.
type strmap = map[string]string
// whoisParse parses a subset of plain-text data from the WHOIS response into
// a string map.
func whoisParse(data string) (m strmap) {
m = strmap{}
var orgname string
lines := strings.Split(data, "\n")
for _, l := range lines {
if isWHOISComment(l) {
continue
}
kv := strings.SplitN(l, ":", 2)
if len(kv) != 2 {
continue
}
k := strings.ToLower(strings.TrimSpace(kv[0]))
v := strings.TrimSpace(kv[1])
if v == "" {
continue
}
switch k {
case "orgname", "org-name":
k = "orgname"
v = trimValue(v)
orgname = v
case "city", "country":
v = trimValue(v)
case "descr", "netname":
k = "orgname"
v = stringutil.Coalesce(orgname, v)
orgname = v
case "whois":
k = "whois"
case "referralserver":
k = "whois"
v = strings.TrimPrefix(v, "whois://")
default:
continue
}
m[k] = v
}
return m
}
// MaxConnReadSize is an upper limit in bytes for reading from net.Conn.
const MaxConnReadSize = 64 * 1024
// Send request to a server and receive the response
func (w *WHOIS) query(ctx context.Context, target, serverAddr string) (data string, err error) {
addr, _, _ := net.SplitHostPort(serverAddr)
if addr == "whois.arin.net" {
target = "n + " + target
}
conn, err := w.dialContext(ctx, "tcp", serverAddr)
if err != nil {
return "", err
}
defer func() { err = errors.WithDeferred(err, conn.Close()) }()
r, err := aghio.LimitReader(conn, MaxConnReadSize)
if err != nil {
return "", err
}
_ = conn.SetReadDeadline(time.Now().Add(time.Duration(w.timeoutMsec) * time.Millisecond))
_, err = conn.Write([]byte(target + "\r\n"))
if err != nil {
return "", err
}
// This use of ReadAll is now safe, because we limited the conn Reader.
var whoisData []byte
whoisData, err = io.ReadAll(r)
if err != nil {
return "", err
}
return string(whoisData), nil
}
// Query WHOIS servers (handle redirects)
func (w *WHOIS) queryAll(ctx context.Context, target string) (string, error) {
server := net.JoinHostPort(defaultServer, defaultPort)
const maxRedirects = 5
for i := 0; i != maxRedirects; i++ {
resp, err := w.query(ctx, target, server)
if err != nil {
return "", err
}
log.Debug("whois: received response (%d bytes) from %s IP:%s", len(resp), server, target)
m := whoisParse(resp)
redir, ok := m["whois"]
if !ok {
return resp, nil
}
redir = strings.ToLower(redir)
_, _, err = net.SplitHostPort(redir)
if err != nil {
server = net.JoinHostPort(redir, defaultPort)
} else {
server = redir
}
log.Debug("whois: redirected to %s IP:%s", redir, target)
}
return "", fmt.Errorf("whois: redirect loop")
}
// Request WHOIS information
func (w *WHOIS) process(ctx context.Context, ip netip.Addr) (wi *RuntimeClientWHOISInfo) {
resp, err := w.queryAll(ctx, ip.String())
if err != nil {
log.Debug("whois: error: %s IP:%s", err, ip)
return nil
}
log.Debug("whois: IP:%s response: %d bytes", ip, len(resp))
m := whoisParse(resp)
wi = &RuntimeClientWHOISInfo{
City: m["city"],
Country: m["country"],
Orgname: m["orgname"],
}
// Don't return an empty struct so that the frontend doesn't get
// confused.
if *wi == (RuntimeClientWHOISInfo{}) {
return nil
}
return wi
}
// Begin - begin requesting WHOIS info
func (w *WHOIS) Begin(ip netip.Addr) {
ipBytes := ip.AsSlice()
now := uint64(time.Now().Unix())
expire := w.ipAddrs.Get(ipBytes)
if len(expire) != 0 {
exp := binary.BigEndian.Uint64(expire)
if exp > now {
return
}
}
expire = make([]byte, 8)
binary.BigEndian.PutUint64(expire, now+whoisTTL)
_ = w.ipAddrs.Set(ipBytes, expire)
log.Debug("whois: adding %s", ip)
select {
case w.ipChan <- ip:
default:
log.Debug("whois: queue is full")
}
}
// workerLoop processes the IP addresses it got from the channel and associates
// the retrieving WHOIS info with a client.
func (w *WHOIS) workerLoop() {
for ip := range w.ipChan {
info := w.process(context.Background(), ip)
if info == nil {
continue
}
w.clients.setWHOISInfo(ip, info)
}
}

View File

@@ -1,152 +0,0 @@
package home
import (
"context"
"io"
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// fakeConn is a mock implementation of net.Conn to simplify testing.
//
// TODO(e.burkov): Search for other places in code where it may be used. Move
// into aghtest then.
type fakeConn struct {
// Conn is embedded here simply to make *fakeConn a net.Conn without
// actually implementing all methods.
net.Conn
data []byte
}
// Write implements net.Conn interface for *fakeConn. It always returns 0 and a
// nil error without mutating the slice.
func (c *fakeConn) Write(_ []byte) (n int, err error) {
return 0, nil
}
// Read implements net.Conn interface for *fakeConn. It puts the content of
// c.data field into b up to the b's capacity.
func (c *fakeConn) Read(b []byte) (n int, err error) {
return copy(b, c.data), io.EOF
}
// Close implements net.Conn interface for *fakeConn. It always returns nil.
func (c *fakeConn) Close() (err error) {
return nil
}
// SetReadDeadline implements net.Conn interface for *fakeConn. It always
// returns nil.
func (c *fakeConn) SetReadDeadline(_ time.Time) (err error) {
return nil
}
// fakeDial is a mock implementation of customDialContext to simplify testing.
func (c *fakeConn) fakeDial(ctx context.Context, network, addr string) (conn net.Conn, err error) {
return c, nil
}
func TestWHOIS(t *testing.T) {
const (
nl = "\n"
data = `OrgName: FakeOrg LLC` + nl +
`City: Nonreal` + nl +
`Country: Imagiland` + nl
)
fc := &fakeConn{
data: []byte(data),
}
w := WHOIS{
timeoutMsec: 5000,
dialContext: fc.fakeDial,
}
resp, err := w.queryAll(context.Background(), "1.2.3.4")
assert.NoError(t, err)
m := whoisParse(resp)
require.NotEmpty(t, m)
assert.Equal(t, "FakeOrg LLC", m["orgname"])
assert.Equal(t, "Imagiland", m["country"])
assert.Equal(t, "Nonreal", m["city"])
}
func TestWHOISParse(t *testing.T) {
const (
city = "Nonreal"
country = "Imagiland"
orgname = "FakeOrgLLC"
whois = "whois.example.net"
)
testCases := []struct {
want strmap
name string
in string
}{{
want: strmap{},
name: "empty",
in: ``,
}, {
want: strmap{},
name: "comments",
in: "%\n#",
}, {
want: strmap{},
name: "no_colon",
in: "city",
}, {
want: strmap{},
name: "no_value",
in: "city:",
}, {
want: strmap{"city": city},
name: "city",
in: `city: ` + city,
}, {
want: strmap{"country": country},
name: "country",
in: `country: ` + country,
}, {
want: strmap{"orgname": orgname},
name: "orgname",
in: `orgname: ` + orgname,
}, {
want: strmap{"orgname": orgname},
name: "orgname_hyphen",
in: `org-name: ` + orgname,
}, {
want: strmap{"orgname": orgname},
name: "orgname_descr",
in: `descr: ` + orgname,
}, {
want: strmap{"orgname": orgname},
name: "orgname_netname",
in: `netname: ` + orgname,
}, {
want: strmap{"whois": whois},
name: "whois",
in: `whois: ` + whois,
}, {
want: strmap{"whois": whois},
name: "referralserver",
in: `referralserver: whois://` + whois,
}, {
want: strmap{},
name: "other",
in: `other: value`,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
got := whoisParse(tc.in)
assert.Equal(t, tc.want, got)
})
}
}

63
internal/next/agh/agh.go Normal file
View File

@@ -0,0 +1,63 @@
// Package agh contains common entities and interfaces of AdGuard Home.
package agh
import "context"
// Service is the interface for API servers.
//
// TODO(a.garipov): Consider adding a context to Start.
//
// TODO(a.garipov): Consider adding a Wait method or making an extension
// interface for that.
type Service interface {
// Start starts the service. It does not block.
Start() (err error)
// Shutdown gracefully stops the service. ctx is used to determine
// a timeout before trying to stop the service less gracefully.
Shutdown(ctx context.Context) (err error)
}
// type check
var _ Service = EmptyService{}
// EmptyService is a [Service] that does nothing.
//
// TODO(a.garipov): Remove if unnecessary.
type EmptyService struct{}
// Start implements the [Service] interface for EmptyService.
func (EmptyService) Start() (err error) { return nil }
// Shutdown implements the [Service] interface for EmptyService.
func (EmptyService) Shutdown(_ context.Context) (err error) { return nil }
// ServiceWithConfig is an extension of the [Service] interface for services
// that can return their configuration.
//
// TODO(a.garipov): Consider removing this generic interface if we figure out
// how to make it testable in a better way.
type ServiceWithConfig[ConfigType any] interface {
Service
Config() (c ConfigType)
}
// type check
var _ ServiceWithConfig[struct{}] = (*EmptyServiceWithConfig[struct{}])(nil)
// EmptyServiceWithConfig is a ServiceWithConfig that does nothing. Its Config
// method returns Conf.
//
// TODO(a.garipov): Remove if unnecessary.
type EmptyServiceWithConfig[ConfigType any] struct {
EmptyService
Conf ConfigType
}
// Config implements the [ServiceWithConfig] interface for
// *EmptyServiceWithConfig.
func (s *EmptyServiceWithConfig[ConfigType]) Config() (conf ConfigType) {
return s.Conf
}

View File

@@ -1,23 +1,15 @@
package querylog
import "github.com/AdguardTeam/AdGuardHome/internal/whois"
// Client is the information required by the query log to match against clients
// during searches.
type Client struct {
WHOIS *ClientWHOIS `json:"whois,omitempty"`
Name string `json:"name"`
DisallowedRule string `json:"disallowed_rule"`
Disallowed bool `json:"disallowed"`
IgnoreQueryLog bool `json:"-"`
}
// ClientWHOIS is the filtered WHOIS data for the client.
//
// TODO(a.garipov): Merge with home.RuntimeClientWHOISInfo after the
// refactoring is done.
type ClientWHOIS struct {
City string `json:"city,omitempty"`
Country string `json:"country,omitempty"`
Orgname string `json:"orgname,omitempty"`
WHOIS *whois.Info `json:"whois,omitempty"`
Name string `json:"name"`
DisallowedRule string `json:"disallowed_rule"`
Disallowed bool `json:"disallowed"`
IgnoreQueryLog bool `json:"-"`
}
// clientCacheKey is the key by which a cached client information is found.

View File

@@ -161,12 +161,15 @@ func (l *queryLog) clear() {
// newLogEntry creates an instance of logEntry from parameters.
func newLogEntry(params *AddParams) (entry *logEntry) {
q := params.Question.Question[0]
qHost := q.Name
if qHost != "." {
qHost = strings.ToLower(q.Name[:len(q.Name)-1])
}
entry = &logEntry{
// TODO(d.kolyshev): Export this timestamp to func params.
Time: time.Now(),
QHost: strings.ToLower(q.Name[:len(q.Name)-1]),
Time: time.Now(),
QHost: qHost,
QType: dns.Type(q.Qtype).String(),
QClass: dns.Class(q.Qclass).String(),

View File

@@ -43,6 +43,7 @@ func TestQueryLog(t *testing.T) {
// Add memory entries.
addEntry(l, "test.example.org", net.IPv4(1, 1, 1, 3), net.IPv4(2, 2, 2, 3))
addEntry(l, "example.com", net.IPv4(1, 1, 1, 4), net.IPv4(2, 2, 2, 4))
addEntry(l, "", net.IPv4(1, 1, 1, 5), net.IPv4(2, 2, 2, 5))
type tcAssertion struct {
host string
@@ -59,10 +60,11 @@ func TestQueryLog(t *testing.T) {
name: "all",
sCr: []searchCriterion{},
want: []tcAssertion{
{num: 0, host: "example.com", answer: net.IPv4(1, 1, 1, 4), client: net.IPv4(2, 2, 2, 4)},
{num: 1, host: "test.example.org", answer: net.IPv4(1, 1, 1, 3), client: net.IPv4(2, 2, 2, 3)},
{num: 2, host: "example.org", answer: net.IPv4(1, 1, 1, 2), client: net.IPv4(2, 2, 2, 2)},
{num: 3, host: "example.org", answer: net.IPv4(1, 1, 1, 1), client: net.IPv4(2, 2, 2, 1)},
{num: 0, host: ".", answer: net.IPv4(1, 1, 1, 5), client: net.IPv4(2, 2, 2, 5)},
{num: 1, host: "example.com", answer: net.IPv4(1, 1, 1, 4), client: net.IPv4(2, 2, 2, 4)},
{num: 2, host: "test.example.org", answer: net.IPv4(1, 1, 1, 3), client: net.IPv4(2, 2, 2, 3)},
{num: 3, host: "example.org", answer: net.IPv4(1, 1, 1, 2), client: net.IPv4(2, 2, 2, 2)},
{num: 4, host: "example.org", answer: net.IPv4(1, 1, 1, 1), client: net.IPv4(2, 2, 2, 1)},
},
}, {
name: "by_domain_strict",
@@ -104,10 +106,11 @@ func TestQueryLog(t *testing.T) {
value: "2.2.2",
}},
want: []tcAssertion{
{num: 0, host: "example.com", answer: net.IPv4(1, 1, 1, 4), client: net.IPv4(2, 2, 2, 4)},
{num: 1, host: "test.example.org", answer: net.IPv4(1, 1, 1, 3), client: net.IPv4(2, 2, 2, 3)},
{num: 2, host: "example.org", answer: net.IPv4(1, 1, 1, 2), client: net.IPv4(2, 2, 2, 2)},
{num: 3, host: "example.org", answer: net.IPv4(1, 1, 1, 1), client: net.IPv4(2, 2, 2, 1)},
{num: 0, host: ".", answer: net.IPv4(1, 1, 1, 5), client: net.IPv4(2, 2, 2, 5)},
{num: 1, host: "example.com", answer: net.IPv4(1, 1, 1, 4), client: net.IPv4(2, 2, 2, 4)},
{num: 2, host: "test.example.org", answer: net.IPv4(1, 1, 1, 3), client: net.IPv4(2, 2, 2, 3)},
{num: 3, host: "example.org", answer: net.IPv4(1, 1, 1, 2), client: net.IPv4(2, 2, 2, 2)},
{num: 4, host: "example.org", answer: net.IPv4(1, 1, 1, 1), client: net.IPv4(2, 2, 2, 1)},
},
}}

View File

@@ -0,0 +1,250 @@
// Package schedule provides types for scheduling.
package schedule
import (
"fmt"
"time"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/timeutil"
"gopkg.in/yaml.v3"
)
// Weekly is a schedule for one week. Each day of the week has one range with
// a beginning and an end.
type Weekly struct {
// location is used to calculate the offsets of the day ranges.
location *time.Location
// days are the day ranges of this schedule. The indexes of this array are
// the [time.Weekday] values.
days [7]dayRange
}
// EmptyWeekly creates empty weekly schedule with local time zone.
func EmptyWeekly() (w *Weekly) {
return &Weekly{
location: time.Local,
}
}
// FullWeekly creates full weekly schedule with local time zone.
//
// TODO(s.chzhen): Consider moving into tests.
func FullWeekly() (w *Weekly) {
fullDay := dayRange{start: 0, end: maxDayRange}
return &Weekly{
location: time.Local,
days: [7]dayRange{
time.Sunday: fullDay,
time.Monday: fullDay,
time.Tuesday: fullDay,
time.Wednesday: fullDay,
time.Thursday: fullDay,
time.Friday: fullDay,
time.Saturday: fullDay,
},
}
}
// Clone returns a deep copy of a weekly.
func (w *Weekly) Clone() (c *Weekly) {
// NOTE: Do not use time.LoadLocation, because the results will be
// different on time zone database update.
return &Weekly{
location: w.location,
days: w.days,
}
}
// Contains returns true if t is within the corresponding day range of the
// schedule in the schedule's time zone.
func (w *Weekly) Contains(t time.Time) (ok bool) {
t = t.In(w.location)
wd := t.Weekday()
dr := w.days[wd]
// Calculate the offset of the day range.
//
// NOTE: Do not use [time.Truncate] since it requires UTC time zone.
y, m, d := t.Date()
day := time.Date(y, m, d, 0, 0, 0, 0, w.location)
offset := t.Sub(day)
return dr.contains(offset)
}
// type check
var _ yaml.Unmarshaler = (*Weekly)(nil)
// UnmarshalYAML implements the [yaml.Unmarshaler] interface for *Weekly.
func (w *Weekly) UnmarshalYAML(value *yaml.Node) (err error) {
conf := &weeklyConfig{}
err = value.Decode(conf)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return err
}
weekly := Weekly{}
weekly.location, err = time.LoadLocation(conf.TimeZone)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return err
}
days := []dayConfig{
time.Sunday: conf.Sunday,
time.Monday: conf.Monday,
time.Tuesday: conf.Tuesday,
time.Wednesday: conf.Wednesday,
time.Thursday: conf.Thursday,
time.Friday: conf.Friday,
time.Saturday: conf.Saturday,
}
for i, d := range days {
r := dayRange{
start: d.Start.Duration,
end: d.End.Duration,
}
err = w.validate(r)
if err != nil {
return fmt.Errorf("weekday %s: %w", time.Weekday(i), err)
}
weekly.days[i] = r
}
*w = weekly
return nil
}
// weeklyConfig is the YAML configuration structure of Weekly.
type weeklyConfig struct {
// TimeZone is the local time zone.
TimeZone string `yaml:"time_zone"`
// Days of the week.
Sunday dayConfig `yaml:"sun,omitempty"`
Monday dayConfig `yaml:"mon,omitempty"`
Tuesday dayConfig `yaml:"tue,omitempty"`
Wednesday dayConfig `yaml:"wed,omitempty"`
Thursday dayConfig `yaml:"thu,omitempty"`
Friday dayConfig `yaml:"fri,omitempty"`
Saturday dayConfig `yaml:"sat,omitempty"`
}
// dayConfig is the YAML configuration structure of dayRange.
type dayConfig struct {
Start timeutil.Duration `yaml:"start"`
End timeutil.Duration `yaml:"end"`
}
// maxDayRange is the maximum value for day range end.
const maxDayRange = 24 * time.Hour
// validate returns the day range rounding errors, if any.
func (w *Weekly) validate(r dayRange) (err error) {
defer func() { err = errors.Annotate(err, "bad day range: %w") }()
err = r.validate()
if err != nil {
// Don't wrap the error since it's informative enough as is.
return err
}
start := r.start.Truncate(time.Minute)
end := r.end.Truncate(time.Minute)
switch {
case start != r.start:
return fmt.Errorf("start %s isn't rounded to minutes", r.start)
case end != r.end:
return fmt.Errorf("end %s isn't rounded to minutes", r.end)
default:
return nil
}
}
// type check
var _ yaml.Marshaler = (*Weekly)(nil)
// MarshalYAML implements the [yaml.Marshaler] interface for *Weekly.
func (w *Weekly) MarshalYAML() (v any, err error) {
return weeklyConfig{
TimeZone: w.location.String(),
Sunday: dayConfig{
Start: timeutil.Duration{Duration: w.days[time.Sunday].start},
End: timeutil.Duration{Duration: w.days[time.Sunday].end},
},
Monday: dayConfig{
Start: timeutil.Duration{Duration: w.days[time.Monday].start},
End: timeutil.Duration{Duration: w.days[time.Monday].end},
},
Tuesday: dayConfig{
Start: timeutil.Duration{Duration: w.days[time.Tuesday].start},
End: timeutil.Duration{Duration: w.days[time.Tuesday].end},
},
Wednesday: dayConfig{
Start: timeutil.Duration{Duration: w.days[time.Wednesday].start},
End: timeutil.Duration{Duration: w.days[time.Wednesday].end},
},
Thursday: dayConfig{
Start: timeutil.Duration{Duration: w.days[time.Thursday].start},
End: timeutil.Duration{Duration: w.days[time.Thursday].end},
},
Friday: dayConfig{
Start: timeutil.Duration{Duration: w.days[time.Friday].start},
End: timeutil.Duration{Duration: w.days[time.Friday].end},
},
Saturday: dayConfig{
Start: timeutil.Duration{Duration: w.days[time.Saturday].start},
End: timeutil.Duration{Duration: w.days[time.Saturday].end},
},
}, nil
}
// dayRange represents a single interval within a day. The interval begins at
// start and ends before end. That is, it contains a time point T if start <=
// T < end.
type dayRange struct {
// start is an offset from the beginning of the day. It must be greater
// than or equal to zero and less than 24h.
start time.Duration
// end is an offset from the beginning of the day. It must be greater than
// or equal to zero and less than or equal to 24h.
end time.Duration
}
// validate returns the day range validation errors, if any.
func (r dayRange) validate() (err error) {
switch {
case r == dayRange{}:
return nil
case r.start < 0:
return fmt.Errorf("start %s is negative", r.start)
case r.end < 0:
return fmt.Errorf("end %s is negative", r.end)
case r.start >= r.end:
return fmt.Errorf("start %s is greater or equal to end %s", r.start, r.end)
case r.start >= maxDayRange:
return fmt.Errorf("start %s is greater or equal to %s", r.start, maxDayRange)
case r.end > maxDayRange:
return fmt.Errorf("end %s is greater than %s", r.end, maxDayRange)
default:
return nil
}
}
// contains returns true if start <= offset < end, where offset is the time
// duration from the beginning of the day.
func (r *dayRange) contains(offset time.Duration) (ok bool) {
return r.start <= offset && offset < r.end
}

View File

@@ -0,0 +1,371 @@
package schedule
import (
"testing"
"time"
"github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/golibs/timeutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gopkg.in/yaml.v3"
)
func TestWeekly_Contains(t *testing.T) {
baseTime := time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC)
otherTime := baseTime.Add(1 * timeutil.Day)
// NOTE: In the Etc area the sign of the offsets is flipped. So, Etc/GMT-3
// is actually UTC+03:00.
otherTZ := time.FixedZone("Etc/GMT-3", 3*60*60)
// baseSchedule, 12:00 to 14:00.
baseSchedule := &Weekly{
days: [7]dayRange{
time.Friday: {start: 12 * time.Hour, end: 14 * time.Hour},
},
location: time.UTC,
}
// allDaySchedule, 00:00 to 24:00.
allDaySchedule := &Weekly{
days: [7]dayRange{
time.Friday: {start: 0, end: 24 * time.Hour},
},
location: time.UTC,
}
// oneMinSchedule, 00:00 to 00:01.
oneMinSchedule := &Weekly{
days: [7]dayRange{
time.Friday: {start: 0, end: 1 * time.Minute},
},
location: time.UTC,
}
testCases := []struct {
schedule *Weekly
assert assert.BoolAssertionFunc
t time.Time
name string
}{{
schedule: EmptyWeekly(),
assert: assert.False,
t: baseTime,
name: "empty",
}, {
schedule: allDaySchedule,
assert: assert.True,
t: baseTime,
name: "same_day_all_day",
}, {
schedule: baseSchedule,
assert: assert.True,
t: baseTime.Add(13 * time.Hour),
name: "same_day_inside",
}, {
schedule: baseSchedule,
assert: assert.False,
t: baseTime.Add(11 * time.Hour),
name: "same_day_outside",
}, {
schedule: allDaySchedule,
assert: assert.True,
t: baseTime.Add(24*time.Hour - time.Second),
name: "same_day_last_second",
}, {
schedule: allDaySchedule,
assert: assert.False,
t: otherTime,
name: "other_day_all_day",
}, {
schedule: baseSchedule,
assert: assert.False,
t: otherTime.Add(13 * time.Hour),
name: "other_day_inside",
}, {
schedule: baseSchedule,
assert: assert.False,
t: otherTime.Add(11 * time.Hour),
name: "other_day_outside",
}, {
schedule: baseSchedule,
assert: assert.True,
t: baseTime.Add(13 * time.Hour).In(otherTZ),
name: "same_day_inside_other_tz",
}, {
schedule: baseSchedule,
assert: assert.False,
t: baseTime.Add(11 * time.Hour).In(otherTZ),
name: "same_day_outside_other_tz",
}, {
schedule: oneMinSchedule,
assert: assert.True,
t: baseTime,
name: "one_minute_beginning",
}, {
schedule: oneMinSchedule,
assert: assert.True,
t: baseTime.Add(1*time.Minute - 1),
name: "one_minute_end",
}, {
schedule: oneMinSchedule,
assert: assert.False,
t: baseTime.Add(1 * time.Minute),
name: "one_minute_past_end",
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tc.assert(t, tc.schedule.Contains(tc.t))
})
}
}
const brusselsSunday = `
sun:
start: 12h
end: 14h
time_zone: Europe/Brussels
`
func TestWeekly_UnmarshalYAML(t *testing.T) {
const (
sameTime = `
sun:
start: 9h
end: 9h
`
negativeStart = `
sun:
start: -1h
end: 1h
`
badTZ = `
time_zone: "bad_timezone"
`
badYAML = `
yaml: "bad"
yaml: "bad"
`
)
brusseltsTZ, err := time.LoadLocation("Europe/Brussels")
require.NoError(t, err)
brusselsWeekly := &Weekly{
days: [7]dayRange{{
start: time.Hour * 12,
end: time.Hour * 14,
}},
location: brusseltsTZ,
}
testCases := []struct {
name string
wantErrMsg string
data []byte
want *Weekly
}{{
name: "empty",
wantErrMsg: "",
data: []byte(""),
want: &Weekly{},
}, {
name: "null",
wantErrMsg: "",
data: []byte("null"),
want: &Weekly{},
}, {
name: "brussels_sunday",
wantErrMsg: "",
data: []byte(brusselsSunday),
want: brusselsWeekly,
}, {
name: "start_equal_end",
wantErrMsg: "weekday Sunday: bad day range: start 9h0m0s is greater or equal to end 9h0m0s",
data: []byte(sameTime),
want: &Weekly{},
}, {
name: "start_negative",
wantErrMsg: "weekday Sunday: bad day range: start -1h0m0s is negative",
data: []byte(negativeStart),
want: &Weekly{},
}, {
name: "bad_time_zone",
wantErrMsg: "unknown time zone bad_timezone",
data: []byte(badTZ),
want: &Weekly{},
}, {
name: "bad_yaml",
wantErrMsg: "yaml: unmarshal errors:\n line 3: mapping key \"yaml\" already defined at line 2",
data: []byte(badYAML),
want: &Weekly{},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
w := &Weekly{}
err = yaml.Unmarshal(tc.data, w)
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
assert.Equal(t, tc.want, w)
})
}
}
func TestWeekly_MarshalYAML(t *testing.T) {
brusselsTZ, err := time.LoadLocation("Europe/Brussels")
require.NoError(t, err)
brusselsWeekly := &Weekly{
days: [7]dayRange{time.Sunday: {
start: time.Hour * 12,
end: time.Hour * 14,
}},
location: brusselsTZ,
}
testCases := []struct {
name string
data []byte
want *Weekly
}{{
name: "empty",
data: []byte(""),
want: &Weekly{},
}, {
name: "null",
data: []byte("null"),
want: &Weekly{},
}, {
name: "brussels_sunday",
data: []byte(brusselsSunday),
want: brusselsWeekly,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var data []byte
data, err = yaml.Marshal(brusselsWeekly)
require.NoError(t, err)
w := &Weekly{}
err = yaml.Unmarshal(data, w)
require.NoError(t, err)
assert.Equal(t, brusselsWeekly, w)
})
}
}
func TestWeekly_Validate(t *testing.T) {
testCases := []struct {
name string
in dayRange
wantErrMsg string
}{{
name: "empty",
wantErrMsg: "",
in: dayRange{},
}, {
name: "start_seconds",
wantErrMsg: "bad day range: start 1s isn't rounded to minutes",
in: dayRange{
start: time.Second,
end: time.Hour,
},
}, {
name: "end_seconds",
wantErrMsg: "bad day range: end 1s isn't rounded to minutes",
in: dayRange{
start: 0,
end: time.Second,
},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
w := &Weekly{}
err := w.validate(tc.in)
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
})
}
}
func TestDayRange_Validate(t *testing.T) {
testCases := []struct {
name string
in dayRange
wantErrMsg string
}{{
name: "empty",
wantErrMsg: "",
in: dayRange{},
}, {
name: "valid",
wantErrMsg: "",
in: dayRange{
start: time.Hour,
end: time.Hour * 2,
},
}, {
name: "valid_end_max",
wantErrMsg: "",
in: dayRange{
start: 0,
end: time.Hour * 24,
},
}, {
name: "start_negative",
wantErrMsg: "start -1h0m0s is negative",
in: dayRange{
start: time.Hour * -1,
end: time.Hour * 2,
},
}, {
name: "end_negative",
wantErrMsg: "end -1h0m0s is negative",
in: dayRange{
start: 0,
end: time.Hour * -1,
},
}, {
name: "start_equal_end",
wantErrMsg: "start 1h0m0s is greater or equal to end 1h0m0s",
in: dayRange{
start: time.Hour,
end: time.Hour,
},
}, {
name: "start_greater_end",
wantErrMsg: "start 2h0m0s is greater or equal to end 1h0m0s",
in: dayRange{
start: time.Hour * 2,
end: time.Hour,
},
}, {
name: "start_equal_max",
wantErrMsg: "start 24h0m0s is greater or equal to 24h0m0s",
in: dayRange{
start: time.Hour * 24,
end: time.Hour * 48,
},
}, {
name: "end_greater_max",
wantErrMsg: "end 48h0m0s is greater than 24h0m0s",
in: dayRange{
start: 0,
end: time.Hour * 48,
},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := tc.in.validate()
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
})
}
}

View File

@@ -4,20 +4,21 @@ go 1.19
require (
github.com/fzipp/gocyclo v0.6.0
github.com/golangci/misspell v0.4.0
github.com/gordonklaus/ineffassign v0.0.0-20230107090616-13ace0543b28
github.com/golangci/misspell v0.4.1
github.com/gordonklaus/ineffassign v0.0.0-20230610083614-0e73809eb601
github.com/kisielk/errcheck v1.6.3
github.com/kyoh86/looppointer v0.2.1
github.com/securego/gosec/v2 v2.16.0
golang.org/x/tools v0.9.3
golang.org/x/vuln v0.1.0
github.com/uudashr/gocognit v1.0.6
golang.org/x/tools v0.10.0
golang.org/x/vuln v0.2.0
honnef.co/go/tools v0.4.3
mvdan.cc/gofumpt v0.5.0
mvdan.cc/unparam v0.0.0-20230312165513-e84e2d14e3b8
mvdan.cc/unparam v0.0.0-20230610194454-9ea02bef9868
)
require (
github.com/BurntSushi/toml v1.3.1 // indirect
github.com/BurntSushi/toml v1.3.2 // indirect
github.com/google/go-cmp v0.5.9 // indirect
github.com/google/uuid v1.3.0 // indirect
github.com/gookit/color v1.5.3 // indirect
@@ -25,9 +26,9 @@ require (
github.com/nbutton23/zxcvbn-go v0.0.0-20210217022336-fa2cb2858354 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
golang.org/x/exp v0.0.0-20230321023759-10a507213a29 // indirect
golang.org/x/exp/typeparams v0.0.0-20230522175609-2e198f4a06a1 // indirect
golang.org/x/mod v0.10.0 // indirect
golang.org/x/sync v0.2.0 // indirect
golang.org/x/sys v0.8.0 // indirect
golang.org/x/exp/typeparams v0.0.0-20230626212559-97b1e661b5df // indirect
golang.org/x/mod v0.11.0 // indirect
golang.org/x/sync v0.3.0 // indirect
golang.org/x/sys v0.9.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

View File

@@ -1,6 +1,5 @@
github.com/BurntSushi/toml v1.3.1 h1:rHnDkSK+/g6DlREUK73PkmIs60pqrnuduK+JmP++JmU=
github.com/BurntSushi/toml v1.3.1/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ=
github.com/client9/misspell v0.3.4 h1:ta993UF76GwbvJcIo3Y68y/M3WxlpEHPWIGDkJYwzJI=
github.com/BurntSushi/toml v1.3.2 h1:o7IhLm0Msx3BaB+n3Ag7L8EVlByGnpq14C4YWiu/gL8=
github.com/BurntSushi/toml v1.3.2/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/frankban/quicktest v1.14.4 h1:g2rn0vABPOOXmZUj+vbmUp0lPoXEMuhTpIluN0XL9UY=
@@ -8,8 +7,8 @@ github.com/fzipp/gocyclo v0.6.0 h1:lsblElZG7d3ALtGMx9fmxeTKZaLLpU8mET09yN4BBLo=
github.com/fzipp/gocyclo v0.6.0/go.mod h1:rXPyn8fnlpa0R2csP/31uerbiVBugk5whMdlyaLkLoA=
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
github.com/golangci/misspell v0.4.0 h1:KtVB/hTK4bbL/S6bs64rYyk8adjmh1BygbBiaAiX+a0=
github.com/golangci/misspell v0.4.0/go.mod h1:W6O/bwV6lGDxUCChm2ykw9NQdd5bYd1Xkjo88UcWyJc=
github.com/golangci/misspell v0.4.1 h1:+y73iSicVy2PqyX7kmUefHusENlrP9YwuHZHPLGQj/g=
github.com/golangci/misspell v0.4.1/go.mod h1:9mAN1quEo3DlpbaIKKyEvRxK1pwqR9s/Sea1bJCtlNI=
github.com/google/go-cmdtest v0.4.1-0.20220921163831-55ab3332a786 h1:rcv+Ippz6RAtvaGgKxc+8FQIpxHgsF+HBzPyYL2cyVU=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
@@ -19,8 +18,8 @@ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gookit/color v1.5.3 h1:twfIhZs4QLCtimkP7MOxlF3A0U/5cDPseRT9M/+2SCE=
github.com/gookit/color v1.5.3/go.mod h1:NUzwzeehUfl7GIb36pqId+UGmRfQcU/WiiyTTeNjHtE=
github.com/gordonklaus/ineffassign v0.0.0-20230107090616-13ace0543b28 h1:9alfqbrhuD+9fLZ4iaAVwhlp5PEhmnBt7yvK2Oy5C1U=
github.com/gordonklaus/ineffassign v0.0.0-20230107090616-13ace0543b28/go.mod h1:Qcp2HIAYhR7mNUVSIxZww3Guk4it82ghYcEXIAk+QT0=
github.com/gordonklaus/ineffassign v0.0.0-20230610083614-0e73809eb601 h1:mrEEilTAUmaAORhssPPkxj84TsHrPMLBGW2Z4SoTxm8=
github.com/gordonklaus/ineffassign v0.0.0-20230610083614-0e73809eb601/go.mod h1:Qcp2HIAYhR7mNUVSIxZww3Guk4it82ghYcEXIAk+QT0=
github.com/kisielk/errcheck v1.6.3 h1:dEKh+GLHcWm2oN34nMvDzn1sqI0i0WxPvrgiJA5JuM8=
github.com/kisielk/errcheck v1.6.3/go.mod h1:nXw/i/MfnvRHqXa7XXmQMUB0oNFGuBrNI8d8NLy0LPw=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
@@ -40,6 +39,8 @@ github.com/securego/gosec/v2 v2.16.0 h1:Pi0JKoasQQ3NnoRao/ww/N/XdynIB9NRYYZT5CyO
github.com/securego/gosec/v2 v2.16.0/go.mod h1:xvLcVZqUfo4aAQu56TNv7/Ltz6emAOQAEsrZrt7uGlI=
github.com/stretchr/testify v1.1.4/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk=
github.com/uudashr/gocognit v1.0.6 h1:2Cgi6MweCsdB6kpcVQp7EW4U23iBFQWfTXiWlyp842Y=
github.com/uudashr/gocognit v1.0.6/go.mod h1:nAIUuVBnYU7pcninia3BHOvQkpQCeO76Uscky5BOwcY=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
@@ -51,25 +52,26 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/exp v0.0.0-20230321023759-10a507213a29 h1:ooxPy7fPvB4kwsA2h+iBNHkAbp/4JxTSwCmvdjEYmug=
golang.org/x/exp v0.0.0-20230321023759-10a507213a29/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc=
golang.org/x/exp/typeparams v0.0.0-20230522175609-2e198f4a06a1 h1:pnP8r+W8Fm7XJ8CWtXi4S9oJmPBTrkfYN/dNbaPj6Y4=
golang.org/x/exp/typeparams v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk=
golang.org/x/exp/typeparams v0.0.0-20230626212559-97b1e661b5df h1:jfUqBujZx2dktJVEmZpCkyngz7MWrVv1y9kLOqFNsqw=
golang.org/x/exp/typeparams v0.0.0-20230626212559-97b1e661b5df/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY=
golang.org/x/mod v0.10.0 h1:lFO9qtOdlre5W1jxS3r/4szv2/6iXxScdzjoBMXNhYk=
golang.org/x/mod v0.10.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.11.0 h1:bUO06HqtnRcc/7l71XBe4WcqTZ+3AH1J59zWDDwLKgU=
golang.org/x/mod v0.11.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M=
golang.org/x/net v0.11.0 h1:Gi2tvZIJyBtO9SDr1q9h5hEQCp/4L2RQ+ar0qjx2oNU=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.2.0 h1:PUR+T4wwASmuSTYdKjYHI5TD22Wy5ogLU5qZCOLxBrI=
golang.org/x/sync v0.2.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -79,8 +81,9 @@ golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220702020025-31831981b65f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s=
golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
@@ -92,10 +95,11 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn
golang.org/x/tools v0.0.0-20201007032633-0806396f153e/go.mod h1:z6u4i615ZeAfBE4XtMziQW1fSVJXACjjbWkB/mvPzlU=
golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E=
golang.org/x/tools v0.9.3 h1:Gn1I8+64MsuTb/HpH+LmQtNas23LhUVr3rYZ0eKuaMM=
golang.org/x/tools v0.9.3/go.mod h1:owI94Op576fPu3cIGQeHs3joujW/2Oc6MtlxbF5dfNc=
golang.org/x/vuln v0.1.0 h1:9GRdj6wAIkDrsMevuolY+SXERPjQPp2P1ysYA0jpZe0=
golang.org/x/vuln v0.1.0/go.mod h1:/YuzZYjGbwB8y19CisAppfyw3uTZnuCz3r+qgx/QRzU=
golang.org/x/tools v0.1.11/go.mod h1:SgwaegtQh8clINPpECJMqnxLv9I09HLqnW3RMqW0CA4=
golang.org/x/tools v0.10.0 h1:tvDr/iQoUqNdohiYm0LmmKcBk+q86lb9EprIUFhHHGg=
golang.org/x/tools v0.10.0/go.mod h1:UJwyiVBsOA2uwvK/e5OY3GTpDUJriEd+/YlqAwLPmyM=
golang.org/x/vuln v0.2.0 h1:Dlz47lW0pvPHU7tnb10S8vbMn9GnV2B6eyT7Tem5XBI=
golang.org/x/vuln v0.2.0/go.mod h1:V0eyhHwaAaHrt42J9bgrN6rd12f6GU4T0Lu0ex2wDg4=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -107,5 +111,5 @@ honnef.co/go/tools v0.4.3 h1:o/n5/K5gXqk8Gozvs2cnL0F2S1/g1vcGCAx2vETjITw=
honnef.co/go/tools v0.4.3/go.mod h1:36ZgoUOrqOk1GxwHhyryEkq8FQWkUO2xGuSMhUCcdvA=
mvdan.cc/gofumpt v0.5.0 h1:0EQ+Z56k8tXjj/6TQD25BFNKQXpCvT0rnansIc7Ug5E=
mvdan.cc/gofumpt v0.5.0/go.mod h1:HBeVDtMKRZpXyxFciAirzdKklDlGu8aAy1wEbH5Y9js=
mvdan.cc/unparam v0.0.0-20230312165513-e84e2d14e3b8 h1:VuJo4Mt0EVPychre4fNlDWDuE5AjXtPJpRUWqZDQhaI=
mvdan.cc/unparam v0.0.0-20230312165513-e84e2d14e3b8/go.mod h1:Oh/d7dEtzsNHGOq1Cdv8aMm3KdKhVvPbRQcM8WFpBR8=
mvdan.cc/unparam v0.0.0-20230610194454-9ea02bef9868 h1:F4Q7pXcrU9UiU1fq0ZWqSOxKjNAteRuDr7JDk7uVLRQ=
mvdan.cc/unparam v0.0.0-20230610194454-9ea02bef9868/go.mod h1:6ZaiQyI7Tiq0HQ56g6N8TlkSd80/LyagZeaw8mb7jYE=

View File

@@ -9,6 +9,7 @@ import (
_ "github.com/kisielk/errcheck"
_ "github.com/kyoh86/looppointer"
_ "github.com/securego/gosec/v2/cmd/gosec"
_ "github.com/uudashr/gocognit/cmd/gocognit"
_ "golang.org/x/tools/go/analysis/passes/nilness/cmd/nilness"
_ "golang.org/x/tools/go/analysis/passes/shadow/cmd/shadow"
_ "golang.org/x/vuln/cmd/govulncheck"

View File

@@ -143,14 +143,7 @@ func Verbose() (v string) {
runtime.Version(),
)
if committime != "" {
commitTimeUnix, err := strconv.ParseInt(committime, 10, 64)
if err != nil {
stringutil.WriteToBuilder(b, nl, vFmtTimeHdr, fmt.Sprintf("parse error: %s", err))
} else {
stringutil.WriteToBuilder(b, nl, vFmtTimeHdr, time.Unix(commitTimeUnix, 0).String())
}
}
writeCommitTime(b)
stringutil.WriteToBuilder(b, nl, vFmtGOOSHdr, nl, vFmtGOARCHHdr)
if goarm != "" {
@@ -179,3 +172,16 @@ func Verbose() (v string) {
return b.String()
}
func writeCommitTime(b *strings.Builder) {
if committime == "" {
return
}
commitTimeUnix, err := strconv.ParseInt(committime, 10, 64)
if err != nil {
stringutil.WriteToBuilder(b, "\n", vFmtTimeHdr, fmt.Sprintf("parse error: %s", err))
} else {
stringutil.WriteToBuilder(b, "\n", vFmtTimeHdr, time.Unix(commitTimeUnix, 0).String())
}
}

388
internal/whois/whois.go Normal file
View File

@@ -0,0 +1,388 @@
// Package whois provides WHOIS functionality.
package whois
import (
"bytes"
"context"
"fmt"
"io"
"net"
"net/netip"
"strconv"
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/bluele/gcache"
)
const (
// DefaultServer is the default WHOIS server.
DefaultServer = "whois.arin.net"
// DefaultPort is the default port for WHOIS requests.
DefaultPort = 43
)
// Interface provides WHOIS functionality.
type Interface interface {
// Process makes WHOIS request and returns WHOIS information or nil.
// changed indicates that Info was updated since last request.
Process(ctx context.Context, ip netip.Addr) (info *Info, changed bool)
}
// Empty is an empty [Interface] implementation which does nothing.
type Empty struct{}
// type check
var _ Interface = (*Empty)(nil)
// Process implements the [Interface] interface for Empty.
func (Empty) Process(_ context.Context, _ netip.Addr) (info *Info, changed bool) {
return nil, false
}
// Config is the configuration structure for Default.
type Config struct {
// DialContext specifies the dial function for creating unencrypted TCP
// connections.
DialContext func(ctx context.Context, network, addr string) (conn net.Conn, err error)
// ServerAddr is the address of the WHOIS server.
ServerAddr string
// Timeout is the timeout for WHOIS requests.
Timeout time.Duration
// CacheTTL is the Time to Live duration for cached IP addresses.
CacheTTL time.Duration
// MaxConnReadSize is an upper limit in bytes for reading from net.Conn.
MaxConnReadSize int64
// MaxRedirects is the maximum redirects count.
MaxRedirects int
// MaxInfoLen is the maximum length of Info fields returned by Process.
MaxInfoLen int
// CacheSize is the maximum size of the cache. It must be greater than
// zero.
CacheSize int
// Port is the port for WHOIS requests.
Port uint16
}
// Default is the default WHOIS information processor.
type Default struct {
// cache is the cache containing IP addresses of clients. An active IP
// address is resolved once again after it expires. If IP address couldn't
// be resolved, it stays here for some time to prevent further attempts to
// resolve the same IP.
cache gcache.Cache
// dialContext connects to a remote server resolving hostname using our own
// DNS server and unecrypted TCP connection.
dialContext func(ctx context.Context, network, addr string) (conn net.Conn, err error)
// serverAddr is the address of the WHOIS server.
serverAddr string
// portStr is the port for WHOIS requests.
portStr string
// timeout is the timeout for WHOIS requests.
timeout time.Duration
// cacheTTL is the Time to Live duration for cached IP addresses.
cacheTTL time.Duration
// maxConnReadSize is an upper limit in bytes for reading from net.Conn.
maxConnReadSize int64
// maxRedirects is the maximum redirects count.
maxRedirects int
// maxInfoLen is the maximum length of Info fields returned by Process.
maxInfoLen int
}
// New returns a new default WHOIS information processor. conf must not be
// nil.
func New(conf *Config) (w *Default) {
return &Default{
serverAddr: conf.ServerAddr,
dialContext: conf.DialContext,
timeout: conf.Timeout,
cache: gcache.New(conf.CacheSize).LRU().Build(),
maxConnReadSize: conf.MaxConnReadSize,
maxRedirects: conf.MaxRedirects,
portStr: strconv.Itoa(int(conf.Port)),
maxInfoLen: conf.MaxInfoLen,
cacheTTL: conf.CacheTTL,
}
}
// trimValue trims s and replaces the last 3 characters of the cut with "..."
// to fit into max. max must be greater than 3.
func trimValue(s string, max int) string {
if len(s) <= max {
return s
}
return s[:max-3] + "..."
}
// isWHOISComment returns true if the data is empty or is a WHOIS comment.
func isWHOISComment(data []byte) (ok bool) {
return len(data) == 0 || data[0] == '#' || data[0] == '%'
}
// whoisParse parses a subset of plain-text data from the WHOIS response into a
// string map. It trims values of the returned map to maxLen.
func whoisParse(data []byte, maxLen int) (info map[string]string) {
info = map[string]string{}
var orgname string
lines := bytes.Split(data, []byte("\n"))
for _, l := range lines {
if isWHOISComment(l) {
continue
}
before, after, found := bytes.Cut(l, []byte(":"))
if !found {
continue
}
key := strings.ToLower(string(before))
val := strings.TrimSpace(string(after))
if val == "" {
continue
}
switch key {
case "orgname", "org-name":
key = "orgname"
val = trimValue(val, maxLen)
orgname = val
case "city", "country":
val = trimValue(val, maxLen)
case "descr", "netname":
key = "orgname"
val = stringutil.Coalesce(orgname, val)
orgname = val
case "whois":
key = "whois"
case "referralserver":
key = "whois"
val = strings.TrimPrefix(val, "whois://")
default:
continue
}
info[key] = val
}
return info
}
// query sends request to a server and returns the response or error.
func (w *Default) query(ctx context.Context, target, serverAddr string) (data []byte, err error) {
addr, _, _ := net.SplitHostPort(serverAddr)
if addr == DefaultServer {
// Display type flags for query.
//
// See https://www.arin.net/resources/registry/whois/rws/api/#nicname-whois-queries.
target = "n + " + target
}
conn, err := w.dialContext(ctx, "tcp", serverAddr)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return nil, err
}
defer func() { err = errors.WithDeferred(err, conn.Close()) }()
r, err := aghio.LimitReader(conn, w.maxConnReadSize)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return nil, err
}
_ = conn.SetReadDeadline(time.Now().Add(w.timeout))
_, err = io.WriteString(conn, target+"\r\n")
if err != nil {
// Don't wrap the error since it's informative enough as is.
return nil, err
}
// This use of ReadAll is now safe, because we limited the conn Reader.
data, err = io.ReadAll(r)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return nil, err
}
return data, nil
}
// queryAll queries WHOIS server and handles redirects.
func (w *Default) queryAll(ctx context.Context, target string) (info map[string]string, err error) {
server := net.JoinHostPort(w.serverAddr, w.portStr)
var data []byte
for i := 0; i < w.maxRedirects; i++ {
data, err = w.query(ctx, target, server)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return nil, err
}
log.Debug("whois: received response (%d bytes) from %q about %q", len(data), server, target)
info = whoisParse(data, w.maxInfoLen)
redir, ok := info["whois"]
if !ok {
return info, nil
}
redir = strings.ToLower(redir)
_, _, err = net.SplitHostPort(redir)
if err != nil {
server = net.JoinHostPort(redir, w.portStr)
} else {
server = redir
}
log.Debug("whois: redirected to %q about %q", redir, target)
}
return nil, fmt.Errorf("whois: redirect loop")
}
// type check
var _ Interface = (*Default)(nil)
// Process makes WHOIS request and returns WHOIS information or nil. changed
// indicates that Info was updated since last request.
func (w *Default) Process(ctx context.Context, ip netip.Addr) (wi *Info, changed bool) {
if netutil.IsSpecialPurposeAddr(ip) {
return nil, false
}
wi, expired := w.findInCache(ip)
if wi != nil && !expired {
// Don't return an empty struct so that the frontend doesn't get
// confused.
if (*wi == Info{}) {
return nil, false
}
return wi, false
}
return w.requestInfo(ctx, ip, wi)
}
// requestInfo makes WHOIS request and returns WHOIS info. changed is false if
// received information is equal to cached.
func (w *Default) requestInfo(
ctx context.Context,
ip netip.Addr,
cached *Info,
) (wi *Info, changed bool) {
var info Info
defer func() {
item := toCacheItem(info, w.cacheTTL)
err := w.cache.Set(ip, item)
if err != nil {
log.Debug("whois: cache: adding item %q: %s", ip, err)
}
}()
kv, err := w.queryAll(ctx, ip.String())
if err != nil {
log.Debug("whois: quering about %q: %s", ip, err)
return nil, true
}
info = Info{
City: kv["city"],
Country: kv["country"],
Orgname: kv["orgname"],
}
changed = cached == nil || info != *cached
// Don't return an empty struct so that the frontend doesn't get confused.
if (info == Info{}) {
return nil, changed
}
return &info, changed
}
// findInCache finds Info in the cache. expired indicates that Info is valid.
func (w *Default) findInCache(ip netip.Addr) (wi *Info, expired bool) {
val, err := w.cache.Get(ip)
if err != nil {
if !errors.Is(err, gcache.KeyNotFoundError) {
log.Debug("whois: cache: retrieving info about %q: %s", ip, err)
}
return nil, false
}
item, ok := val.(*cacheItem)
if !ok {
log.Debug("whois: cache: %q bad type %T", ip, val)
return nil, false
}
return fromCacheItem(item)
}
// Info is the filtered WHOIS data for a runtime client.
type Info struct {
City string `json:"city,omitempty"`
Country string `json:"country,omitempty"`
Orgname string `json:"orgname,omitempty"`
}
// cacheItem represents an item that we will store in the cache.
type cacheItem struct {
// expiry is the time when cacheItem will expire.
expiry time.Time
// info is the WHOIS data for a runtime client.
info *Info
}
// toCacheItem creates a cached item from a WHOIS info and Time to Live
// duration.
func toCacheItem(info Info, ttl time.Duration) (item *cacheItem) {
return &cacheItem{
expiry: time.Now().Add(ttl),
info: &info,
}
}
// fromCacheItem creates a WHOIS info from the cached item. expired indicates
// that WHOIS info is valid. item must not be nil.
func fromCacheItem(item *cacheItem) (info *Info, expired bool) {
if time.Now().After(item.expiry) {
return item.info, true
}
return item.info, false
}

View File

@@ -0,0 +1,155 @@
package whois_test
import (
"context"
"io"
"net"
"net/netip"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/whois"
"github.com/AdguardTeam/golibs/testutil/fakenet"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestDefault_Process(t *testing.T) {
const (
nl = "\n"
city = "Nonreal"
country = "Imagiland"
orgname = "FakeOrgLLC"
referralserver = "whois.example.net"
)
ip := netip.MustParseAddr("1.2.3.4")
testCases := []struct {
want *whois.Info
name string
data string
}{{
want: nil,
name: "empty",
data: "",
}, {
want: nil,
name: "comments",
data: "%\n#",
}, {
want: nil,
name: "no_colon",
data: "city",
}, {
want: nil,
name: "no_value",
data: "city:",
}, {
want: &whois.Info{
City: city,
},
name: "city",
data: "city: " + city,
}, {
want: &whois.Info{
Country: country,
},
name: "country",
data: "country: " + country,
}, {
want: &whois.Info{
Orgname: orgname,
},
name: "orgname",
data: "orgname: " + orgname,
}, {
want: &whois.Info{
Orgname: orgname,
},
name: "orgname_hyphen",
data: "org-name: " + orgname,
}, {
want: &whois.Info{
Orgname: orgname,
},
name: "orgname_descr",
data: "descr: " + orgname,
}, {
want: &whois.Info{
Orgname: orgname,
},
name: "orgname_netname",
data: "netname: " + orgname,
}, {
want: &whois.Info{
City: city,
Country: country,
Orgname: orgname,
},
name: "full",
data: "OrgName: " + orgname + nl + "City: " + city + nl + "Country: " + country,
}, {
want: nil,
name: "whois",
data: "whois: " + referralserver,
}, {
want: nil,
name: "referralserver",
data: "referralserver: whois://" + referralserver,
}, {
want: nil,
name: "other",
data: "other: value",
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
hit := 0
fakeConn := &fakenet.Conn{
OnRead: func(b []byte) (n int, err error) {
hit++
return copy(b, tc.data), io.EOF
},
OnWrite: func(b []byte) (n int, err error) {
return len(b), nil
},
OnClose: func() (err error) {
return nil
},
OnSetReadDeadline: func(t time.Time) (err error) {
return nil
},
}
w := whois.New(&whois.Config{
Timeout: 5 * time.Second,
DialContext: func(_ context.Context, _, addr string) (_ net.Conn, _ error) {
hit = 0
return fakeConn, nil
},
MaxConnReadSize: 1024,
MaxRedirects: 3,
MaxInfoLen: 250,
CacheSize: 100,
CacheTTL: time.Hour,
})
got, changed := w.Process(context.Background(), ip)
require.True(t, changed)
assert.Equal(t, tc.want, got)
assert.Equal(t, 1, hit)
// From cache.
got, changed = w.Process(context.Background(), ip)
require.False(t, changed)
assert.Equal(t, tc.want, got)
assert.Equal(t, 1, hit)
})
}
}