all: sync with master; upd chlog
This commit is contained in:
@@ -56,15 +56,20 @@ func (rm *requestMatcher) MatchRequest(
|
||||
) (res *urlfilter.DNSResult, ok bool) {
|
||||
switch req.DNSType {
|
||||
case dns.TypeA, dns.TypeAAAA, dns.TypePTR:
|
||||
log.Debug("%s: handling the request for %s", hostsContainerPrefix, req.Hostname)
|
||||
log.Debug(
|
||||
"%s: handling %s request for %s",
|
||||
hostsContainerPrefix,
|
||||
dns.Type(req.DNSType),
|
||||
req.Hostname,
|
||||
)
|
||||
|
||||
rm.stateLock.RLock()
|
||||
defer rm.stateLock.RUnlock()
|
||||
|
||||
return rm.engine.MatchRequest(req)
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
|
||||
rm.stateLock.RLock()
|
||||
defer rm.stateLock.RUnlock()
|
||||
|
||||
return rm.engine.MatchRequest(req)
|
||||
}
|
||||
|
||||
// Translate returns the source hosts-syntax rule for the generated dnsrewrite
|
||||
@@ -96,6 +101,8 @@ const hostsContainerPrefix = "hosts container"
|
||||
|
||||
// HostsContainer stores the relevant hosts database provided by the OS and
|
||||
// processes both A/AAAA and PTR DNS requests for those.
|
||||
//
|
||||
// TODO(e.burkov): Improve API and move to golibs.
|
||||
type HostsContainer struct {
|
||||
// requestMatcher matches the requests and translates the rules. It's
|
||||
// embedded to implement MatchRequest and Translate for *HostsContainer.
|
||||
|
||||
@@ -25,11 +25,8 @@ func (s *bitSet) isSet(n uint64) (ok bool) {
|
||||
|
||||
var word uint64
|
||||
word, ok = s.words[wordIdx]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
return word&(1<<bitIdx) != 0
|
||||
return ok && word&(1<<bitIdx) != 0
|
||||
}
|
||||
|
||||
// set sets or unsets a bit.
|
||||
|
||||
@@ -249,31 +249,30 @@ func (c *dhcpConn) buildEtherPkt(payload []byte, peer *dhcpUnicastAddr) (pkt []b
|
||||
func (s *v4Server) send(peer net.Addr, conn net.PacketConn, req, resp *dhcpv4.DHCPv4) {
|
||||
switch giaddr, ciaddr, mtype := req.GatewayIPAddr, req.ClientIPAddr, resp.MessageType(); {
|
||||
case giaddr != nil && !giaddr.IsUnspecified():
|
||||
// Send any return messages to the server port on the BOOTP
|
||||
// relay agent whose address appears in giaddr.
|
||||
// Send any return messages to the server port on the BOOTP relay agent
|
||||
// whose address appears in giaddr.
|
||||
peer = &net.UDPAddr{
|
||||
IP: giaddr,
|
||||
Port: dhcpv4.ServerPort,
|
||||
}
|
||||
if mtype == dhcpv4.MessageTypeNak {
|
||||
// Set the broadcast bit in the DHCPNAK, so that the relay agent
|
||||
// broadcasts it to the client, because the client may not have
|
||||
// a correct network address or subnet mask, and the client may not
|
||||
// be answering ARP requests.
|
||||
// broadcasts it to the client, because the client may not have a
|
||||
// correct network address or subnet mask, and the client may not be
|
||||
// answering ARP requests.
|
||||
resp.SetBroadcast()
|
||||
}
|
||||
case mtype == dhcpv4.MessageTypeNak:
|
||||
// Broadcast any DHCPNAK messages to 0xffffffff.
|
||||
case ciaddr != nil && !ciaddr.IsUnspecified():
|
||||
// Unicast DHCPOFFER and DHCPACK messages to the address in
|
||||
// ciaddr.
|
||||
// Unicast DHCPOFFER and DHCPACK messages to the address in ciaddr.
|
||||
peer = &net.UDPAddr{
|
||||
IP: ciaddr,
|
||||
Port: dhcpv4.ClientPort,
|
||||
}
|
||||
case !req.IsBroadcast() && req.ClientHWAddr != nil:
|
||||
// Unicast DHCPOFFER and DHCPACK messages to the client's
|
||||
// hardware address and yiaddr.
|
||||
// Unicast DHCPOFFER and DHCPACK messages to the client's hardware
|
||||
// address and yiaddr.
|
||||
peer = &dhcpUnicastAddr{
|
||||
Addr: raw.Addr{HardwareAddr: req.ClientHWAddr},
|
||||
yiaddr: resp.YourIPAddr,
|
||||
|
||||
@@ -247,31 +247,30 @@ func (c *dhcpConn) buildEtherPkt(payload []byte, peer *dhcpUnicastAddr) (pkt []b
|
||||
func (s *v4Server) send(peer net.Addr, conn net.PacketConn, req, resp *dhcpv4.DHCPv4) {
|
||||
switch giaddr, ciaddr, mtype := req.GatewayIPAddr, req.ClientIPAddr, resp.MessageType(); {
|
||||
case giaddr != nil && !giaddr.IsUnspecified():
|
||||
// Send any return messages to the server port on the BOOTP
|
||||
// relay agent whose address appears in giaddr.
|
||||
// Send any return messages to the server port on the BOOTP relay agent
|
||||
// whose address appears in giaddr.
|
||||
peer = &net.UDPAddr{
|
||||
IP: giaddr,
|
||||
Port: dhcpv4.ServerPort,
|
||||
}
|
||||
if mtype == dhcpv4.MessageTypeNak {
|
||||
// Set the broadcast bit in the DHCPNAK, so that the relay agent
|
||||
// broadcasts it to the client, because the client may not have
|
||||
// a correct network address or subnet mask, and the client may not
|
||||
// be answering ARP requests.
|
||||
// broadcasts it to the client, because the client may not have a
|
||||
// correct network address or subnet mask, and the client may not be
|
||||
// answering ARP requests.
|
||||
resp.SetBroadcast()
|
||||
}
|
||||
case mtype == dhcpv4.MessageTypeNak:
|
||||
// Broadcast any DHCPNAK messages to 0xffffffff.
|
||||
case ciaddr != nil && !ciaddr.IsUnspecified():
|
||||
// Unicast DHCPOFFER and DHCPACK messages to the address in
|
||||
// ciaddr.
|
||||
// Unicast DHCPOFFER and DHCPACK messages to the address in ciaddr.
|
||||
peer = &net.UDPAddr{
|
||||
IP: ciaddr,
|
||||
Port: dhcpv4.ClientPort,
|
||||
}
|
||||
case !req.IsBroadcast() && req.ClientHWAddr != nil:
|
||||
// Unicast DHCPOFFER and DHCPACK messages to the client's
|
||||
// hardware address and yiaddr.
|
||||
// Unicast DHCPOFFER and DHCPACK messages to the client's hardware
|
||||
// address and yiaddr.
|
||||
peer = &dhcpUnicastAddr{
|
||||
Addr: packet.Addr{HardwareAddr: req.ClientHWAddr},
|
||||
yiaddr: resp.YourIPAddr,
|
||||
|
||||
@@ -28,8 +28,9 @@ const (
|
||||
defaultBackoff time.Duration = 500 * time.Millisecond
|
||||
)
|
||||
|
||||
// Lease contains the necessary information about a DHCP lease. It's used in
|
||||
// various places. So don't change it without good reason.
|
||||
// Lease contains the necessary information about a DHCP lease. It's used as is
|
||||
// in the database, so don't change it until it's absolutely necessary, see
|
||||
// [dataVersion].
|
||||
type Lease struct {
|
||||
// Expiry is the expiration time of the lease.
|
||||
Expiry time.Time `json:"expires"`
|
||||
@@ -41,8 +42,6 @@ type Lease struct {
|
||||
HWAddr net.HardwareAddr `json:"mac"`
|
||||
|
||||
// IP is the IP address leased to the client.
|
||||
//
|
||||
// TODO(a.garipov): Migrate leases.db.
|
||||
IP netip.Addr `json:"ip"`
|
||||
|
||||
// IsStatic defines if the lease is static.
|
||||
|
||||
@@ -51,6 +51,9 @@ func migrateDB(conf *ServerConfig) (err error) {
|
||||
oldLeasesPath := filepath.Join(conf.WorkDir, dbFilename)
|
||||
dataDirPath := filepath.Join(conf.DataDir, dataFilename)
|
||||
|
||||
// #nosec G304 -- Trust this path, since it's taken from the old file name
|
||||
// relative to the working directory and should generally be considered
|
||||
// safe.
|
||||
file, err := os.Open(oldLeasesPath)
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
// Nothing to migrate.
|
||||
|
||||
@@ -200,7 +200,7 @@ func createICMPv6RAPacket(params icmpv6RA) (data []byte, err error) {
|
||||
func (ra *raCtx) Init() (err error) {
|
||||
ra.stop.Store(0)
|
||||
ra.conn = nil
|
||||
if !(ra.raAllowSLAAC || ra.raSLAACOnly) {
|
||||
if !ra.raAllowSLAAC && !ra.raSLAACOnly {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
86
internal/dhcpsvc/config.go
Normal file
86
internal/dhcpsvc/config.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package dhcpsvc
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/google/gopacket/layers"
|
||||
)
|
||||
|
||||
// Config is the configuration for the DHCP service.
|
||||
type Config struct {
|
||||
// Interfaces stores configurations of DHCP server specific for the network
|
||||
// interface identified by its name.
|
||||
Interfaces map[string]*InterfaceConfig
|
||||
|
||||
// LocalDomainName is the top-level domain name to use for resolving DHCP
|
||||
// clients' hostnames.
|
||||
LocalDomainName string
|
||||
|
||||
// ICMPTimeout is the timeout for checking another DHCP server's presence.
|
||||
ICMPTimeout time.Duration
|
||||
|
||||
// Enabled is the state of the service, whether it is enabled or not.
|
||||
Enabled bool
|
||||
}
|
||||
|
||||
// InterfaceConfig is the configuration of a single DHCP interface.
|
||||
type InterfaceConfig struct {
|
||||
// IPv4 is the configuration of DHCP protocol for IPv4.
|
||||
IPv4 *IPv4Config
|
||||
|
||||
// IPv6 is the configuration of DHCP protocol for IPv6.
|
||||
IPv6 *IPv6Config
|
||||
}
|
||||
|
||||
// IPv4Config is the interface-specific configuration for DHCPv4.
|
||||
type IPv4Config struct {
|
||||
// GatewayIP is the IPv4 address of the network's gateway. It is used as
|
||||
// the default gateway for DHCP clients and also used in calculating the
|
||||
// network-specific broadcast address.
|
||||
GatewayIP netip.Addr
|
||||
|
||||
// SubnetMask is the IPv4 subnet mask of the network. It should be a valid
|
||||
// IPv4 subnet mask (i.e. all 1s followed by all 0s).
|
||||
SubnetMask netip.Addr
|
||||
|
||||
// RangeStart is the first address in the range to assign to DHCP clients.
|
||||
RangeStart netip.Addr
|
||||
|
||||
// RangeEnd is the last address in the range to assign to DHCP clients.
|
||||
RangeEnd netip.Addr
|
||||
|
||||
// Options is the list of DHCP options to send to DHCP clients.
|
||||
Options layers.DHCPOptions
|
||||
|
||||
// LeaseDuration is the TTL of a DHCP lease.
|
||||
LeaseDuration time.Duration
|
||||
|
||||
// Enabled is the state of the DHCPv4 service, whether it is enabled or not
|
||||
// on the specific interface.
|
||||
Enabled bool
|
||||
}
|
||||
|
||||
// IPv6Config is the interface-specific configuration for DHCPv6.
|
||||
type IPv6Config struct {
|
||||
// RangeStart is the first address in the range to assign to DHCP clients.
|
||||
RangeStart netip.Addr
|
||||
|
||||
// Options is the list of DHCP options to send to DHCP clients.
|
||||
Options layers.DHCPOptions
|
||||
|
||||
// LeaseDuration is the TTL of a DHCP lease.
|
||||
LeaseDuration time.Duration
|
||||
|
||||
// RASlaacOnly defines whether the DHCP clients should only use SLAAC for
|
||||
// address assignment.
|
||||
RASLAACOnly bool
|
||||
|
||||
// RAAllowSlaac defines whether the DHCP clients may use SLAAC for address
|
||||
// assignment.
|
||||
RAAllowSLAAC bool
|
||||
|
||||
// Enabled is the state of the DHCPv6 service, whether it is enabled or not
|
||||
// on the specific interface.
|
||||
Enabled bool
|
||||
}
|
||||
120
internal/dhcpsvc/dhcpsvc.go
Normal file
120
internal/dhcpsvc/dhcpsvc.go
Normal file
@@ -0,0 +1,120 @@
|
||||
// Package dhcpsvc contains the AdGuard Home DHCP service.
|
||||
//
|
||||
// TODO(e.burkov): Add tests.
|
||||
package dhcpsvc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
|
||||
)
|
||||
|
||||
// Lease is a DHCP lease.
|
||||
//
|
||||
// TODO(e.burkov): Consider it to [agh], since it also may be needed in
|
||||
// [websvc]. Also think of implementing iterating methods with appropriate
|
||||
// signatures.
|
||||
type Lease struct {
|
||||
// IP is the IP address leased to the client.
|
||||
IP netip.Addr
|
||||
|
||||
// Expiry is the expiration time of the lease.
|
||||
Expiry time.Time
|
||||
|
||||
// Hostname of the client.
|
||||
Hostname string
|
||||
|
||||
// HWAddr is the physical hardware address (MAC address).
|
||||
HWAddr net.HardwareAddr
|
||||
|
||||
// IsStatic defines if the lease is static.
|
||||
IsStatic bool
|
||||
}
|
||||
|
||||
type Interface interface {
|
||||
agh.ServiceWithConfig[*Config]
|
||||
|
||||
// Enabled returns true if DHCP provides information about clients.
|
||||
Enabled() (ok bool)
|
||||
|
||||
// HostByIP returns the hostname of the DHCP client with the given IP
|
||||
// address. The address will be netip.Addr{} if there is no such client,
|
||||
// due to an assumption that a DHCP client must always have an IP address.
|
||||
HostByIP(ip netip.Addr) (host string)
|
||||
|
||||
// MACByIP returns the MAC address for the given IP address leased. It
|
||||
// returns nil if there is no such client, due to an assumption that a DHCP
|
||||
// client must always have a MAC address.
|
||||
MACByIP(ip netip.Addr) (mac net.HardwareAddr)
|
||||
|
||||
// IPByHost returns the IP address of the DHCP client with the given
|
||||
// hostname. The hostname will be an empty string if there is no such
|
||||
// client, due to an assumption that a DHCP client must always have a
|
||||
// hostname, either set by the client or assigned automatically.
|
||||
IPByHost(host string) (ip netip.Addr)
|
||||
|
||||
// Leases returns all the DHCP leases.
|
||||
Leases() (leases []*Lease)
|
||||
|
||||
// AddLease adds a new DHCP lease. It returns an error if the lease is
|
||||
// invalid or already exists.
|
||||
AddLease(l *Lease) (err error)
|
||||
|
||||
// EditLease changes an existing DHCP lease. It returns an error if there
|
||||
// is no lease equal to old or if new is invalid or already exists.
|
||||
EditLease(old, new *Lease) (err error)
|
||||
|
||||
// RemoveLease removes an existing DHCP lease. It returns an error if there
|
||||
// is no lease equal to l.
|
||||
RemoveLease(l *Lease) (err error)
|
||||
|
||||
// Reset removes all the DHCP leases.
|
||||
Reset() (err error)
|
||||
}
|
||||
|
||||
// Empty is an [Interface] implementation that does nothing.
|
||||
type Empty struct{}
|
||||
|
||||
// type check
|
||||
var _ Interface = Empty{}
|
||||
|
||||
// Start implements the [Service] interface for Empty.
|
||||
func (Empty) Start() (err error) { return nil }
|
||||
|
||||
// Shutdown implements the [Service] interface for Empty.
|
||||
func (Empty) Shutdown(_ context.Context) (err error) { return nil }
|
||||
|
||||
var _ agh.ServiceWithConfig[*Config] = Empty{}
|
||||
|
||||
// Config implements the [ServiceWithConfig] interface for Empty.
|
||||
func (Empty) Config() (conf *Config) { return nil }
|
||||
|
||||
// Enabled implements the [Interface] interface for Empty.
|
||||
func (Empty) Enabled() (ok bool) { return false }
|
||||
|
||||
// HostByIP implements the [Interface] interface for Empty.
|
||||
func (Empty) HostByIP(_ netip.Addr) (host string) { return "" }
|
||||
|
||||
// MACByIP implements the [Interface] interface for Empty.
|
||||
func (Empty) MACByIP(_ netip.Addr) (mac net.HardwareAddr) { return nil }
|
||||
|
||||
// IPByHost implements the [Interface] interface for Empty.
|
||||
func (Empty) IPByHost(_ string) (ip netip.Addr) { return netip.Addr{} }
|
||||
|
||||
// Leases implements the [Interface] interface for Empty.
|
||||
func (Empty) Leases() (leases []*Lease) { return nil }
|
||||
|
||||
// AddLease implements the [Interface] interface for Empty.
|
||||
func (Empty) AddLease(_ *Lease) (err error) { return nil }
|
||||
|
||||
// EditLease implements the [Interface] interface for Empty.
|
||||
func (Empty) EditLease(_, _ *Lease) (err error) { return nil }
|
||||
|
||||
// RemoveLease implements the [Interface] interface for Empty.
|
||||
func (Empty) RemoveLease(_ *Lease) (err error) { return nil }
|
||||
|
||||
// Reset implements the [Interface] interface for Empty.
|
||||
func (Empty) Reset() (err error) { return nil }
|
||||
@@ -15,7 +15,6 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
@@ -436,102 +435,6 @@ func (s *Server) initDefaultSettings() {
|
||||
}
|
||||
}
|
||||
|
||||
// UpstreamHTTPVersions returns the HTTP versions for upstream configuration
|
||||
// depending on configuration.
|
||||
func UpstreamHTTPVersions(http3 bool) (v []upstream.HTTPVersion) {
|
||||
if !http3 {
|
||||
return upstream.DefaultHTTPVersions
|
||||
}
|
||||
|
||||
return []upstream.HTTPVersion{
|
||||
upstream.HTTPVersion3,
|
||||
upstream.HTTPVersion2,
|
||||
upstream.HTTPVersion11,
|
||||
}
|
||||
}
|
||||
|
||||
// prepareUpstreamSettings - prepares upstream DNS server settings
|
||||
func (s *Server) prepareUpstreamSettings() error {
|
||||
// We're setting a customized set of RootCAs. The reason is that Go default
|
||||
// mechanism of loading TLS roots does not always work properly on some
|
||||
// routers so we're loading roots manually and pass it here.
|
||||
//
|
||||
// See [aghtls.SystemRootCAs].
|
||||
upstream.RootCAs = s.conf.TLSv12Roots
|
||||
upstream.CipherSuites = s.conf.TLSCiphers
|
||||
|
||||
// Load upstreams either from the file, or from the settings
|
||||
var upstreams []string
|
||||
if s.conf.UpstreamDNSFileName != "" {
|
||||
data, err := os.ReadFile(s.conf.UpstreamDNSFileName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading upstream from file: %w", err)
|
||||
}
|
||||
|
||||
upstreams = stringutil.SplitTrimmed(string(data), "\n")
|
||||
|
||||
log.Debug("dns: using %d upstream servers from file %s", len(upstreams), s.conf.UpstreamDNSFileName)
|
||||
} else {
|
||||
upstreams = s.conf.UpstreamDNS
|
||||
}
|
||||
|
||||
httpVersions := UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams)
|
||||
upstreams = stringutil.FilterOut(upstreams, IsCommentOrEmpty)
|
||||
upstreamConfig, err := proxy.ParseUpstreamsConfig(
|
||||
upstreams,
|
||||
&upstream.Options{
|
||||
Bootstrap: s.conf.BootstrapDNS,
|
||||
Timeout: s.conf.UpstreamTimeout,
|
||||
HTTPVersions: httpVersions,
|
||||
PreferIPv6: s.conf.BootstrapPreferIPv6,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing upstream config: %w", err)
|
||||
}
|
||||
|
||||
if len(upstreamConfig.Upstreams) == 0 {
|
||||
log.Info("warning: no default upstream servers specified, using %v", defaultDNS)
|
||||
var uc *proxy.UpstreamConfig
|
||||
uc, err = proxy.ParseUpstreamsConfig(
|
||||
defaultDNS,
|
||||
&upstream.Options{
|
||||
Bootstrap: s.conf.BootstrapDNS,
|
||||
Timeout: s.conf.UpstreamTimeout,
|
||||
HTTPVersions: httpVersions,
|
||||
PreferIPv6: s.conf.BootstrapPreferIPv6,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing default upstreams: %w", err)
|
||||
}
|
||||
|
||||
upstreamConfig.Upstreams = uc.Upstreams
|
||||
}
|
||||
|
||||
s.conf.UpstreamConfig = upstreamConfig
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// setProxyUpstreamMode sets the upstream mode and related settings in conf
|
||||
// based on provided parameters.
|
||||
func setProxyUpstreamMode(
|
||||
conf *proxy.Config,
|
||||
allServers bool,
|
||||
fastestAddr bool,
|
||||
fastestTimeout time.Duration,
|
||||
) {
|
||||
if allServers {
|
||||
conf.UpstreamMode = proxy.UModeParallel
|
||||
} else if fastestAddr {
|
||||
conf.UpstreamMode = proxy.UModeFastestAddr
|
||||
conf.FastestPingTimeout = fastestTimeout
|
||||
} else {
|
||||
conf.UpstreamMode = proxy.UModeLoadBalance
|
||||
}
|
||||
}
|
||||
|
||||
// prepareIpsetListSettings reads and prepares the ipset configuration either
|
||||
// from a file or from the data in the configuration file.
|
||||
func (s *Server) prepareIpsetListSettings() (err error) {
|
||||
@@ -540,6 +443,7 @@ func (s *Server) prepareIpsetListSettings() (err error) {
|
||||
return s.ipset.init(s.conf.IpsetList)
|
||||
}
|
||||
|
||||
// #nosec G304 -- Trust the path explicitly given by the user.
|
||||
data, err := os.ReadFile(fn)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -145,10 +145,13 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, pctx *proxy.DNSContext) error
|
||||
// processRecursion checks the incoming request and halts its handling by
|
||||
// answering NXDOMAIN if s has tried to resolve it recently.
|
||||
func (s *Server) processRecursion(dctx *dnsContext) (rc resultCode) {
|
||||
log.Debug("dnsforward: started processing recursion")
|
||||
defer log.Debug("dnsforward: finished processing recursion")
|
||||
|
||||
pctx := dctx.proxyCtx
|
||||
|
||||
if msg := pctx.Req; msg != nil && s.recDetector.check(*msg) {
|
||||
log.Debug("recursion detected resolving %q", msg.Question[0].Name)
|
||||
log.Debug("dnsforward: recursion detected resolving %q", msg.Question[0].Name)
|
||||
pctx.Res = s.genNXDomain(pctx.Req)
|
||||
|
||||
return resultCodeFinish
|
||||
@@ -158,10 +161,13 @@ func (s *Server) processRecursion(dctx *dnsContext) (rc resultCode) {
|
||||
}
|
||||
|
||||
// processInitial terminates the following processing for some requests if
|
||||
// needed and enriches the ctx with some client-specific information.
|
||||
// needed and enriches dctx with some client-specific information.
|
||||
//
|
||||
// TODO(e.burkov): Decompose into less general processors.
|
||||
func (s *Server) processInitial(dctx *dnsContext) (rc resultCode) {
|
||||
log.Debug("dnsforward: started processing initial")
|
||||
defer log.Debug("dnsforward: finished processing initial")
|
||||
|
||||
pctx := dctx.proxyCtx
|
||||
q := pctx.Req.Question[0]
|
||||
qt := q.Qtype
|
||||
@@ -282,6 +288,9 @@ func (s *Server) onDHCPLeaseChanged(flags int) {
|
||||
//
|
||||
// See https://www.ietf.org/archive/id/draft-ietf-add-ddr-10.html.
|
||||
func (s *Server) processDDRQuery(dctx *dnsContext) (rc resultCode) {
|
||||
log.Debug("dnsforward: started processing ddr")
|
||||
defer log.Debug("dnsforward: finished processing ddr")
|
||||
|
||||
if !s.conf.HandleDDR {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
@@ -375,6 +384,9 @@ func (s *Server) makeDDRResponse(req *dns.Msg) (resp *dns.Msg) {
|
||||
// processDetermineLocal determines if the client's IP address is from locally
|
||||
// served network and saves the result into the context.
|
||||
func (s *Server) processDetermineLocal(dctx *dnsContext) (rc resultCode) {
|
||||
log.Debug("dnsforward: started processing local detection")
|
||||
defer log.Debug("dnsforward: finished processing local detection")
|
||||
|
||||
rc = resultCodeSuccess
|
||||
|
||||
var ip net.IP
|
||||
@@ -405,6 +417,9 @@ func (s *Server) dhcpHostToIP(host string) (ip netip.Addr, ok bool) {
|
||||
//
|
||||
// TODO(a.garipov): Adapt to AAAA as well.
|
||||
func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) {
|
||||
log.Debug("dnsforward: started processing dhcp hosts")
|
||||
defer log.Debug("dnsforward: finished processing dhcp hosts")
|
||||
|
||||
pctx := dctx.proxyCtx
|
||||
req := pctx.Req
|
||||
q := req.Question[0]
|
||||
@@ -544,6 +559,9 @@ func extractARPASubnet(domain string) (pref netip.Prefix, err error) {
|
||||
// processRestrictLocal responds with NXDOMAIN to PTR requests for IP addresses
|
||||
// in locally served network from external clients.
|
||||
func (s *Server) processRestrictLocal(dctx *dnsContext) (rc resultCode) {
|
||||
log.Debug("dnsforward: started processing local restriction")
|
||||
defer log.Debug("dnsforward: finished processing local restriction")
|
||||
|
||||
pctx := dctx.proxyCtx
|
||||
req := pctx.Req
|
||||
q := req.Question[0]
|
||||
@@ -613,6 +631,9 @@ func (s *Server) ipToDHCPHost(ip netip.Addr) (host string, ok bool) {
|
||||
// processDHCPAddrs responds to PTR requests if the target IP is leased by the
|
||||
// DHCP server.
|
||||
func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) {
|
||||
log.Debug("dnsforward: started processing dhcp addrs")
|
||||
defer log.Debug("dnsforward: finished processing dhcp addrs")
|
||||
|
||||
pctx := dctx.proxyCtx
|
||||
if pctx.Res != nil {
|
||||
return resultCodeSuccess
|
||||
@@ -658,6 +679,9 @@ func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) {
|
||||
// processLocalPTR responds to PTR requests if the target IP is detected to be
|
||||
// inside the local network and the query was not answered from DHCP.
|
||||
func (s *Server) processLocalPTR(dctx *dnsContext) (rc resultCode) {
|
||||
log.Debug("dnsforward: started processing local ptr")
|
||||
defer log.Debug("dnsforward: finished processing local ptr")
|
||||
|
||||
pctx := dctx.proxyCtx
|
||||
if pctx.Res != nil {
|
||||
return resultCodeSuccess
|
||||
@@ -692,6 +716,9 @@ func (s *Server) processLocalPTR(dctx *dnsContext) (rc resultCode) {
|
||||
|
||||
// Apply filtering logic
|
||||
func (s *Server) processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode) {
|
||||
log.Debug("dnsforward: started processing filtering before req")
|
||||
defer log.Debug("dnsforward: finished processing filtering before req")
|
||||
|
||||
if ctx.proxyCtx.Res != nil {
|
||||
// Go on since the response is already set.
|
||||
return resultCodeSuccess
|
||||
@@ -725,6 +752,9 @@ func ipStringFromAddr(addr net.Addr) (ipStr string) {
|
||||
|
||||
// processUpstream passes request to upstream servers and handles the response.
|
||||
func (s *Server) processUpstream(dctx *dnsContext) (rc resultCode) {
|
||||
log.Debug("dnsforward: started processing upstream")
|
||||
defer log.Debug("dnsforward: finished processing upstream")
|
||||
|
||||
pctx := dctx.proxyCtx
|
||||
req := pctx.Req
|
||||
q := req.Question[0]
|
||||
@@ -871,6 +901,9 @@ func (s *Server) setCustomUpstream(pctx *proxy.DNSContext, clientID string) {
|
||||
|
||||
// Apply filtering logic after we have received response from upstream servers
|
||||
func (s *Server) processFilteringAfterResponse(dctx *dnsContext) (rc resultCode) {
|
||||
log.Debug("dnsforward: started processing filtering after resp")
|
||||
defer log.Debug("dnsforward: finished processing filtering after resp")
|
||||
|
||||
pctx := dctx.proxyCtx
|
||||
switch res := dctx.result; res.Reason {
|
||||
case filtering.NotFilteredAllowList:
|
||||
|
||||
@@ -48,12 +48,33 @@ var webRegistered bool
|
||||
|
||||
// hostToIPTable is a convenient type alias for tables of host names to an IP
|
||||
// address.
|
||||
//
|
||||
// TODO(e.burkov): Use the [DHCP] interface instead.
|
||||
type hostToIPTable = map[string]netip.Addr
|
||||
|
||||
// ipToHostTable is a convenient type alias for tables of IP addresses to their
|
||||
// host names. For example, for use with PTR queries.
|
||||
//
|
||||
// TODO(e.burkov): Use the [DHCP] interface instead.
|
||||
type ipToHostTable = map[netip.Addr]string
|
||||
|
||||
// DHCP is an interface for accessing DHCP lease data needed in this package.
|
||||
type DHCP interface {
|
||||
// HostByIP returns the hostname of the DHCP client with the given IP
|
||||
// address. The address will be netip.Addr{} if there is no such client,
|
||||
// due to an assumption that a DHCP client must always have an IP address.
|
||||
HostByIP(ip netip.Addr) (host string)
|
||||
|
||||
// IPByHost returns the IP address of the DHCP client with the given
|
||||
// hostname. The hostname will be an empty string if there is no such
|
||||
// client, due to an assumption that a DHCP client must always have a
|
||||
// hostname, either set by the client or assigned automatically.
|
||||
IPByHost(host string) (ip netip.Addr)
|
||||
|
||||
// Enabled returns true if DHCP provides information about clients.
|
||||
Enabled() (ok bool)
|
||||
}
|
||||
|
||||
// Server is the main way to start a DNS server.
|
||||
//
|
||||
// Example:
|
||||
@@ -215,7 +236,7 @@ func (s *Server) Close() {
|
||||
s.dnsProxy = nil
|
||||
|
||||
if err := s.ipset.close(); err != nil {
|
||||
log.Error("closing ipset: %s", err)
|
||||
log.Error("dnsforward: closing ipset: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -443,21 +464,17 @@ func (s *Server) setupResolvers(localAddrs []string) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug("upstreams to resolve PTR for local addresses: %v", localAddrs)
|
||||
log.Debug("dnsforward: upstreams to resolve ptr for local addresses: %v", localAddrs)
|
||||
|
||||
var upsConfig *proxy.UpstreamConfig
|
||||
upsConfig, err = proxy.ParseUpstreamsConfig(
|
||||
localAddrs,
|
||||
&upstream.Options{
|
||||
Bootstrap: bootstraps,
|
||||
Timeout: defaultLocalTimeout,
|
||||
// TODO(e.burkov): Should we verify server's certificates?
|
||||
upsConfig, err := s.prepareUpstreamConfig(localAddrs, nil, &upstream.Options{
|
||||
Bootstrap: bootstraps,
|
||||
Timeout: defaultLocalTimeout,
|
||||
// TODO(e.burkov): Should we verify server's certificates?
|
||||
|
||||
PreferIPv6: s.conf.BootstrapPreferIPv6,
|
||||
},
|
||||
)
|
||||
PreferIPv6: s.conf.BootstrapPreferIPv6,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing upstreams: %w", err)
|
||||
return fmt.Errorf("parsing private upstreams: %w", err)
|
||||
}
|
||||
|
||||
s.localResolvers = &proxy.Proxy{
|
||||
@@ -489,7 +506,8 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
|
||||
|
||||
err = s.prepareUpstreamSettings()
|
||||
if err != nil {
|
||||
return fmt.Errorf("preparing upstream settings: %w", err)
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
var proxyConfig proxy.Config
|
||||
@@ -656,7 +674,9 @@ func (s *Server) Reconfigure(conf *ServerConfig) error {
|
||||
s.serverLock.Lock()
|
||||
defer s.serverLock.Unlock()
|
||||
|
||||
log.Print("Start reconfiguring the server")
|
||||
log.Info("dnsforward: starting reconfiguring server")
|
||||
defer log.Info("dnsforward: finished reconfiguring server")
|
||||
|
||||
err := s.stopLocked()
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not reconfigure the server: %w", err)
|
||||
@@ -708,13 +728,13 @@ func (s *Server) IsBlockedClient(ip netip.Addr, clientID string) (blocked bool,
|
||||
// Allow if at least one of the checks allows in allowlist mode, but block
|
||||
// if at least one of the checks blocks in blocklist mode.
|
||||
if allowlistMode && blockedByIP && blockedByClientID {
|
||||
log.Debug("client %v (id %q) is not in access allowlist", ip, clientID)
|
||||
log.Debug("dnsforward: client %v (id %q) is not in access allowlist", ip, clientID)
|
||||
|
||||
// Return now without substituting the empty rule for the
|
||||
// clientID because the rule can't be empty here.
|
||||
return true, rule
|
||||
} else if !allowlistMode && (blockedByIP || blockedByClientID) {
|
||||
log.Debug("client %v (id %q) is in access blocklist", ip, clientID)
|
||||
log.Debug("dnsforward: client %v (id %q) is in access blocklist", ip, clientID)
|
||||
|
||||
blocked = true
|
||||
}
|
||||
|
||||
@@ -53,14 +53,14 @@ func (s *Server) beforeRequestHandler(
|
||||
// getClientRequestFilteringSettings looks up client filtering settings using
|
||||
// the client's IP address and ID, if any, from dctx.
|
||||
func (s *Server) getClientRequestFilteringSettings(dctx *dnsContext) *filtering.Settings {
|
||||
setts := s.dnsFilter.GetConfig()
|
||||
setts := s.dnsFilter.Settings()
|
||||
setts.ProtectionEnabled = dctx.protectionEnabled
|
||||
if s.conf.FilterHandler != nil {
|
||||
ip, _ := netutil.IPAndPortFromAddr(dctx.proxyCtx.Addr)
|
||||
s.conf.FilterHandler(ip, dctx.clientID, &setts)
|
||||
s.conf.FilterHandler(ip, dctx.clientID, setts)
|
||||
}
|
||||
|
||||
return &setts
|
||||
return setts
|
||||
}
|
||||
|
||||
// filterDNSRequest applies the dnsFilter and sets dctx.proxyCtx.Res if the
|
||||
|
||||
@@ -633,61 +633,70 @@ func (err domainSpecificTestError) Error() (msg string) {
|
||||
return fmt.Sprintf("WARNING: %s", err.error)
|
||||
}
|
||||
|
||||
// checkDNS checks the upstream server defined by upstreamConfigStr using
|
||||
// healthCheck for actually exchange messages. It uses bootstrap to resolve the
|
||||
// upstream's address.
|
||||
func checkDNS(
|
||||
upstreamConfigStr string,
|
||||
bootstrap []string,
|
||||
bootstrapPrefIPv6 bool,
|
||||
timeout time.Duration,
|
||||
healthCheck healthCheckFunc,
|
||||
) (err error) {
|
||||
if IsCommentOrEmpty(upstreamConfigStr) {
|
||||
return nil
|
||||
// parseUpstreamLine parses line and creates the [upstream.Upstream] using opts
|
||||
// and information from [s.dnsFilter.EtcHosts]. It returns an error if the line
|
||||
// is not a valid upstream line, see [upstream.AddressToUpstream]. It's a
|
||||
// caller's responsibility to close u.
|
||||
func (s *Server) parseUpstreamLine(
|
||||
line string,
|
||||
opts *upstream.Options,
|
||||
) (u upstream.Upstream, specific bool, err error) {
|
||||
// Separate upstream from domains list.
|
||||
upstreamAddr, domains, err := separateUpstream(line)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("wrong upstream format: %w", err)
|
||||
}
|
||||
|
||||
// Separate upstream from domains list.
|
||||
upstreamAddr, domains, err := separateUpstream(upstreamConfigStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("wrong upstream format: %w", err)
|
||||
}
|
||||
specific = len(domains) > 0
|
||||
|
||||
useDefault, err := validateUpstream(upstreamAddr, domains)
|
||||
if err != nil {
|
||||
return fmt.Errorf("wrong upstream format: %w", err)
|
||||
return nil, specific, fmt.Errorf("wrong upstream format: %w", err)
|
||||
} else if useDefault {
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(bootstrap) == 0 {
|
||||
bootstrap = defaultBootstrap
|
||||
return nil, specific, nil
|
||||
}
|
||||
|
||||
log.Debug("dnsforward: checking if upstream %q works", upstreamAddr)
|
||||
|
||||
u, err := upstream.AddressToUpstream(upstreamAddr, &upstream.Options{
|
||||
Bootstrap: bootstrap,
|
||||
Timeout: timeout,
|
||||
PreferIPv6: bootstrapPrefIPv6,
|
||||
})
|
||||
opts = &upstream.Options{
|
||||
Bootstrap: opts.Bootstrap,
|
||||
Timeout: opts.Timeout,
|
||||
PreferIPv6: opts.PreferIPv6,
|
||||
}
|
||||
|
||||
if s.dnsFilter != nil && s.dnsFilter.EtcHosts != nil {
|
||||
resolved := s.resolveUpstreamHost(extractUpstreamHost(upstreamAddr))
|
||||
sortNetIPAddrs(resolved, opts.PreferIPv6)
|
||||
opts.ServerIPAddrs = resolved
|
||||
}
|
||||
u, err = upstream.AddressToUpstream(upstreamAddr, opts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to choose upstream for %q: %w", upstreamAddr, err)
|
||||
return nil, specific, fmt.Errorf("creating upstream for %q: %w", upstreamAddr, err)
|
||||
}
|
||||
|
||||
return u, specific, nil
|
||||
}
|
||||
|
||||
func (s *Server) checkDNS(line string, opts *upstream.Options, check healthCheckFunc) (err error) {
|
||||
if IsCommentOrEmpty(line) {
|
||||
return nil
|
||||
}
|
||||
|
||||
var u upstream.Upstream
|
||||
var specific bool
|
||||
defer func() {
|
||||
if err != nil && specific {
|
||||
err = domainSpecificTestError{error: err}
|
||||
}
|
||||
}()
|
||||
|
||||
u, specific, err = s.parseUpstreamLine(line, opts)
|
||||
if err != nil || u == nil {
|
||||
return err
|
||||
}
|
||||
defer func() { err = errors.WithDeferred(err, u.Close()) }()
|
||||
|
||||
if err = healthCheck(u); err != nil {
|
||||
err = fmt.Errorf("upstream %q fails to exchange: %w", upstreamAddr, err)
|
||||
if domains != nil {
|
||||
return domainSpecificTestError{error: err}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug("dnsforward: upstream %q is ok", upstreamAddr)
|
||||
|
||||
return nil
|
||||
return check(u)
|
||||
}
|
||||
|
||||
func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -699,47 +708,54 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
result := map[string]string{}
|
||||
bootstraps := req.BootstrapDNS
|
||||
bootstrapPrefIPv6 := s.conf.BootstrapPreferIPv6
|
||||
timeout := s.conf.UpstreamTimeout
|
||||
opts := &upstream.Options{
|
||||
Bootstrap: req.BootstrapDNS,
|
||||
Timeout: s.conf.UpstreamTimeout,
|
||||
PreferIPv6: s.conf.BootstrapPreferIPv6,
|
||||
}
|
||||
if len(opts.Bootstrap) == 0 {
|
||||
opts.Bootstrap = defaultBootstrap
|
||||
}
|
||||
|
||||
type upsCheckResult = struct {
|
||||
res string
|
||||
err error
|
||||
host string
|
||||
}
|
||||
|
||||
req.Upstreams = stringutil.FilterOut(req.Upstreams, IsCommentOrEmpty)
|
||||
req.PrivateUpstreams = stringutil.FilterOut(req.PrivateUpstreams, IsCommentOrEmpty)
|
||||
|
||||
upsNum := len(req.Upstreams) + len(req.PrivateUpstreams)
|
||||
result := make(map[string]string, upsNum)
|
||||
resCh := make(chan upsCheckResult, upsNum)
|
||||
|
||||
checkUps := func(ups string, healthCheck healthCheckFunc) {
|
||||
res := upsCheckResult{
|
||||
host: ups,
|
||||
}
|
||||
defer func() { resCh <- res }()
|
||||
|
||||
checkErr := checkDNS(ups, bootstraps, bootstrapPrefIPv6, timeout, healthCheck)
|
||||
if checkErr != nil {
|
||||
res.res = checkErr.Error()
|
||||
} else {
|
||||
res.res = "OK"
|
||||
}
|
||||
}
|
||||
|
||||
for _, ups := range req.Upstreams {
|
||||
go checkUps(ups, checkDNSUpstreamExc)
|
||||
go func(ups string) {
|
||||
resCh <- upsCheckResult{
|
||||
host: ups,
|
||||
err: s.checkDNS(ups, opts, checkDNSUpstreamExc),
|
||||
}
|
||||
}(ups)
|
||||
}
|
||||
for _, ups := range req.PrivateUpstreams {
|
||||
go checkUps(ups, checkPrivateUpstreamExc)
|
||||
go func(ups string) {
|
||||
resCh <- upsCheckResult{
|
||||
host: ups,
|
||||
err: s.checkDNS(ups, opts, checkPrivateUpstreamExc),
|
||||
}
|
||||
}(ups)
|
||||
}
|
||||
|
||||
for i := 0; i < upsNum; i++ {
|
||||
pair := <-resCh
|
||||
// TODO(e.burkov): The upstreams used for both common and private
|
||||
// resolving should be reported separately.
|
||||
result[pair.host] = pair.res
|
||||
pair := <-resCh
|
||||
if pair.err != nil {
|
||||
result[pair.host] = pair.err.Error()
|
||||
} else {
|
||||
result[pair.host] = "OK"
|
||||
}
|
||||
}
|
||||
close(resCh)
|
||||
|
||||
_ = aghhttp.WriteJSONResponse(w, r, result)
|
||||
}
|
||||
|
||||
@@ -13,10 +13,12 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"testing/fstest"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/golibs/httphdr"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
@@ -280,6 +282,10 @@ func TestIsCommentOrEmpty(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestValidateUpstreams(t *testing.T) {
|
||||
const sdnsStamp = `sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_J` +
|
||||
`S3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczE` +
|
||||
`uYWRndWFyZC5jb20`
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
wantErr string
|
||||
@@ -300,7 +306,7 @@ func TestValidateUpstreams(t *testing.T) {
|
||||
"[//]tls://1.1.1.1",
|
||||
"[/www.host.com/]#",
|
||||
"[/host.com/google.com/]8.8.8.8",
|
||||
"[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
|
||||
"[/host/]" + sdnsStamp,
|
||||
},
|
||||
}, {
|
||||
name: "with_default",
|
||||
@@ -310,7 +316,7 @@ func TestValidateUpstreams(t *testing.T) {
|
||||
"[//]tls://1.1.1.1",
|
||||
"[/www.host.com/]#",
|
||||
"[/host.com/google.com/]8.8.8.8",
|
||||
"[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
|
||||
"[/host/]" + sdnsStamp,
|
||||
"8.8.8.8",
|
||||
},
|
||||
}, {
|
||||
@@ -326,9 +332,10 @@ func TestValidateUpstreams(t *testing.T) {
|
||||
wantErr: `validating upstream "123.3.7m": not an ip:port`,
|
||||
set: []string{"123.3.7m"},
|
||||
}, {
|
||||
name: "invalid",
|
||||
wantErr: `bad upstream for domain "[/host.com]tls://dns.adguard.com": missing separator`,
|
||||
set: []string{"[/host.com]tls://dns.adguard.com"},
|
||||
name: "invalid",
|
||||
wantErr: `bad upstream for domain "[/host.com]tls://dns.adguard.com": ` +
|
||||
`missing separator`,
|
||||
set: []string{"[/host.com]tls://dns.adguard.com"},
|
||||
}, {
|
||||
name: "invalid",
|
||||
wantErr: `validating upstream "[host.ru]#": not an ip:port`,
|
||||
@@ -340,14 +347,14 @@ func TestValidateUpstreams(t *testing.T) {
|
||||
"1.1.1.1",
|
||||
"tls://1.1.1.1",
|
||||
"https://dns.adguard.com/dns-query",
|
||||
"sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
|
||||
sdnsStamp,
|
||||
"udp://dns.google",
|
||||
"udp://8.8.8.8",
|
||||
"[/host.com/]1.1.1.1",
|
||||
"[//]tls://1.1.1.1",
|
||||
"[/www.host.com/]#",
|
||||
"[/host.com/google.com/]8.8.8.8",
|
||||
"[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
|
||||
"[/host/]" + sdnsStamp,
|
||||
"[/пример.рф/]8.8.8.8",
|
||||
},
|
||||
}, {
|
||||
@@ -418,27 +425,28 @@ func TestValidateUpstreamsPrivate(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func newLocalUpstreamListener(t *testing.T, port int, handler dns.Handler) (real net.Addr) {
|
||||
func newLocalUpstreamListener(t *testing.T, port uint16, handler dns.Handler) (real netip.AddrPort) {
|
||||
t.Helper()
|
||||
|
||||
startCh := make(chan struct{})
|
||||
upsSrv := &dns.Server{
|
||||
Addr: netip.AddrPortFrom(netutil.IPv4Localhost(), uint16(port)).String(),
|
||||
Addr: netip.AddrPortFrom(netutil.IPv4Localhost(), port).String(),
|
||||
Net: "tcp",
|
||||
Handler: handler,
|
||||
NotifyStartedFunc: func() { close(startCh) },
|
||||
}
|
||||
go func() {
|
||||
t := testutil.PanicT{}
|
||||
|
||||
err := upsSrv.ListenAndServe()
|
||||
require.NoError(t, err)
|
||||
require.NoError(testutil.PanicT{}, err)
|
||||
}()
|
||||
|
||||
<-startCh
|
||||
testutil.CleanupAndRequireSuccess(t, upsSrv.Shutdown)
|
||||
|
||||
return upsSrv.Listener.Addr()
|
||||
return testutil.RequireTypeAssert[*net.TCPAddr](t, upsSrv.Listener.Addr()).AddrPort()
|
||||
}
|
||||
|
||||
func TestServer_handleTestUpstreaDNS(t *testing.T) {
|
||||
func TestServer_HandleTestUpstreamDNS(t *testing.T) {
|
||||
goodHandler := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
|
||||
err := w.WriteMsg(new(dns.Msg).SetReply(m))
|
||||
require.NoError(testutil.PanicT{}, err)
|
||||
@@ -457,9 +465,38 @@ func TestServer_handleTestUpstreaDNS(t *testing.T) {
|
||||
Host: newLocalUpstreamListener(t, 0, badHandler).String(),
|
||||
}).String()
|
||||
|
||||
const upsTimeout = 100 * time.Millisecond
|
||||
const (
|
||||
upsTimeout = 100 * time.Millisecond
|
||||
|
||||
srv := createTestServer(t, &filtering.Config{}, ServerConfig{
|
||||
hostsFileName = "hosts"
|
||||
upstreamHost = "custom.localhost"
|
||||
)
|
||||
|
||||
hostsListener := newLocalUpstreamListener(t, 0, goodHandler)
|
||||
hostsUps := (&url.URL{
|
||||
Scheme: "tcp",
|
||||
Host: netutil.JoinHostPort(upstreamHost, int(hostsListener.Port())),
|
||||
}).String()
|
||||
|
||||
hc, err := aghnet.NewHostsContainer(
|
||||
filtering.SysHostsListID,
|
||||
fstest.MapFS{
|
||||
hostsFileName: &fstest.MapFile{
|
||||
Data: []byte(hostsListener.Addr().String() + " " + upstreamHost),
|
||||
},
|
||||
},
|
||||
&aghtest.FSWatcher{
|
||||
OnEvents: func() (e <-chan struct{}) { return nil },
|
||||
OnAdd: func(_ string) (err error) { return nil },
|
||||
OnClose: func() (err error) { return nil },
|
||||
},
|
||||
hostsFileName,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
srv := createTestServer(t, &filtering.Config{
|
||||
EtcHosts: hc,
|
||||
}, ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
UpstreamTimeout: upsTimeout,
|
||||
@@ -486,8 +523,7 @@ func TestServer_handleTestUpstreaDNS(t *testing.T) {
|
||||
"upstream_dns": []string{badUps},
|
||||
},
|
||||
wantResp: map[string]any{
|
||||
badUps: `upstream "` + badUps + `" fails to exchange: ` +
|
||||
`couldn't communicate with upstream: exchanging with ` +
|
||||
badUps: `couldn't communicate with upstream: exchanging with ` +
|
||||
badUps + ` over tcp: dns: id mismatch`,
|
||||
},
|
||||
name: "broken",
|
||||
@@ -497,20 +533,40 @@ func TestServer_handleTestUpstreaDNS(t *testing.T) {
|
||||
},
|
||||
wantResp: map[string]any{
|
||||
goodUps: "OK",
|
||||
badUps: `upstream "` + badUps + `" fails to exchange: ` +
|
||||
`couldn't communicate with upstream: exchanging with ` +
|
||||
badUps: `couldn't communicate with upstream: exchanging with ` +
|
||||
badUps + ` over tcp: dns: id mismatch`,
|
||||
},
|
||||
name: "both",
|
||||
}, {
|
||||
body: map[string]any{
|
||||
"upstream_dns": []string{"[/domain.example/]" + badUps},
|
||||
},
|
||||
wantResp: map[string]any{
|
||||
"[/domain.example/]" + badUps: `WARNING: couldn't communicate ` +
|
||||
`with upstream: exchanging with ` + badUps + ` over tcp: ` +
|
||||
`dns: id mismatch`,
|
||||
},
|
||||
name: "domain_specific_error",
|
||||
}, {
|
||||
body: map[string]any{
|
||||
"upstream_dns": []string{hostsUps},
|
||||
},
|
||||
wantResp: map[string]any{
|
||||
hostsUps: "OK",
|
||||
},
|
||||
name: "etc_hosts",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
reqBody, err := json.Marshal(tc.body)
|
||||
var reqBody []byte
|
||||
reqBody, err = json.Marshal(tc.body)
|
||||
require.NoError(t, err)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r, err := http.NewRequest(http.MethodPost, "", bytes.NewReader(reqBody))
|
||||
|
||||
var r *http.Request
|
||||
r, err = http.NewRequest(http.MethodPost, "", bytes.NewReader(reqBody))
|
||||
require.NoError(t, err)
|
||||
|
||||
srv.handleTestUpstreamDNS(w, r)
|
||||
@@ -538,11 +594,15 @@ func TestServer_handleTestUpstreaDNS(t *testing.T) {
|
||||
req := map[string]any{
|
||||
"upstream_dns": []string{sleepyUps},
|
||||
}
|
||||
reqBody, err := json.Marshal(req)
|
||||
|
||||
var reqBody []byte
|
||||
reqBody, err = json.Marshal(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r, err := http.NewRequest(http.MethodPost, "", bytes.NewReader(reqBody))
|
||||
|
||||
var r *http.Request
|
||||
r, err = http.NewRequest(http.MethodPost, "", bytes.NewReader(reqBody))
|
||||
require.NoError(t, err)
|
||||
|
||||
srv.handleTestUpstreamDNS(w, r)
|
||||
|
||||
@@ -110,6 +110,9 @@ func ipsFromAnswer(ans []dns.RR) (ip4s, ip6s []net.IP) {
|
||||
|
||||
// process adds the resolved IP addresses to the domain's ipsets, if any.
|
||||
func (c *ipsetCtx) process(dctx *dnsContext) (rc resultCode) {
|
||||
log.Debug("dnsforward: ipset: started processing")
|
||||
defer log.Debug("dnsforward: ipset: finished processing")
|
||||
|
||||
if c.skipIpsetProcessing(dctx) {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
@@ -125,12 +128,12 @@ func (c *ipsetCtx) process(dctx *dnsContext) (rc resultCode) {
|
||||
n, err := c.ipsetMgr.Add(host, ip4s, ip6s)
|
||||
if err != nil {
|
||||
// Consider ipset errors non-critical to the request.
|
||||
log.Error("ipset: adding host ips: %s", err)
|
||||
log.Error("dnsforward: ipset: adding host ips: %s", err)
|
||||
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
log.Debug("ipset: added %d new ipset entries", n)
|
||||
log.Debug("dnsforward: ipset: added %d new ipset entries", n)
|
||||
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
@@ -57,16 +57,13 @@ func (s *Server) genDNSFilterMessage(
|
||||
return s.genBlockedHost(req, s.conf.SafeBrowsingBlockHost, dctx)
|
||||
case filtering.FilteredParental:
|
||||
return s.genBlockedHost(req, s.conf.ParentalBlockHost, dctx)
|
||||
case filtering.FilteredSafeSearch:
|
||||
// If Safe Search generated the necessary IP addresses, use them.
|
||||
// Otherwise, if there were no errors, there are no addresses for the
|
||||
// requested IP version, so produce a NODATA response.
|
||||
return s.genResponseWithIPs(req, ipsFromRules(res.Rules))
|
||||
default:
|
||||
// If the query was filtered by Safe Search, filtering also must return
|
||||
// the IP addresses that must be used in response. Return them
|
||||
// regardless of the filtering method.
|
||||
ips := ipsFromRules(res.Rules)
|
||||
if res.Reason == filtering.FilteredSafeSearch && len(ips) > 0 {
|
||||
return s.genResponseWithIPs(req, ips)
|
||||
}
|
||||
|
||||
return s.genForBlockingMode(req, ips)
|
||||
return s.genForBlockingMode(req, ipsFromRules(res.Rules))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -17,60 +17,78 @@ import (
|
||||
|
||||
// Write Stats data and logs
|
||||
func (s *Server) processQueryLogsAndStats(dctx *dnsContext) (rc resultCode) {
|
||||
log.Debug("dnsforward: started processing querylog and stats")
|
||||
defer log.Debug("dnsforward: finished processing querylog and stats")
|
||||
|
||||
elapsed := time.Since(dctx.startTime)
|
||||
pctx := dctx.proxyCtx
|
||||
|
||||
shouldLog := true
|
||||
msg := pctx.Req
|
||||
q := msg.Question[0]
|
||||
q := pctx.Req.Question[0]
|
||||
host := strings.ToLower(strings.TrimSuffix(q.Name, "."))
|
||||
|
||||
// don't log ANY request if refuseAny is enabled
|
||||
if q.Qtype == dns.TypeANY && s.conf.RefuseAny {
|
||||
shouldLog = false
|
||||
}
|
||||
|
||||
ip, _ := netutil.IPAndPortFromAddr(pctx.Addr)
|
||||
ip = slices.Clone(ip)
|
||||
|
||||
s.serverLock.RLock()
|
||||
defer s.serverLock.RUnlock()
|
||||
|
||||
s.anonymizer.Load()(ip)
|
||||
|
||||
log.Debug("client ip: %s", ip)
|
||||
log.Debug("dnsforward: client ip for stats and querylog: %s", ip)
|
||||
|
||||
ipStr := ip.String()
|
||||
ids := []string{ipStr, dctx.clientID}
|
||||
qt, cl := q.Qtype, q.Qclass
|
||||
|
||||
// Synchronize access to s.queryLog and s.stats so they won't be suddenly
|
||||
// uninitialized while in use. This can happen after proxy server has been
|
||||
// stopped, but its workers haven't yet exited.
|
||||
if shouldLog &&
|
||||
s.queryLog != nil &&
|
||||
// TODO(s.chzhen): Use dnsforward.dnsContext when it will start
|
||||
// containing persistent client.
|
||||
s.queryLog.ShouldLog(host, q.Qtype, q.Qclass, ids) {
|
||||
s.serverLock.RLock()
|
||||
defer s.serverLock.RUnlock()
|
||||
|
||||
if s.shouldLog(host, qt, cl, ids) {
|
||||
s.logQuery(dctx, pctx, elapsed, ip)
|
||||
} else {
|
||||
log.Debug(
|
||||
"dnsforward: request %s %s from %s ignored; not logging",
|
||||
dns.Type(q.Qtype),
|
||||
"dnsforward: request %s %s %q from %s ignored; not adding to querylog",
|
||||
dns.Class(cl),
|
||||
dns.Type(qt),
|
||||
host,
|
||||
ip,
|
||||
)
|
||||
}
|
||||
|
||||
if s.stats != nil &&
|
||||
// TODO(s.chzhen): Use dnsforward.dnsContext when it will start
|
||||
// containing persistent client.
|
||||
s.stats.ShouldCount(host, q.Qtype, q.Qclass, ids) {
|
||||
if s.shouldCountStat(host, qt, cl, ids) {
|
||||
s.updateStats(dctx, elapsed, *dctx.result, ipStr)
|
||||
} else {
|
||||
log.Debug(
|
||||
"dnsforward: request %s %s %q from %s ignored; not counting in stats",
|
||||
dns.Class(cl),
|
||||
dns.Type(qt),
|
||||
host,
|
||||
ip,
|
||||
)
|
||||
}
|
||||
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
// shouldLog returns true if the query with the given data should be logged in
|
||||
// the query log. s.serverLock is expected to be locked.
|
||||
func (s *Server) shouldLog(host string, qt, cl uint16, ids []string) (ok bool) {
|
||||
if qt == dns.TypeANY && s.conf.RefuseAny {
|
||||
return false
|
||||
}
|
||||
|
||||
// TODO(s.chzhen): Use dnsforward.dnsContext when it will start containing
|
||||
// persistent client.
|
||||
return s.queryLog != nil && s.queryLog.ShouldLog(host, qt, cl, ids)
|
||||
}
|
||||
|
||||
// shouldCountStat returns true if the query with the given data should be
|
||||
// counted in the statistics. s.serverLock is expected to be locked.
|
||||
func (s *Server) shouldCountStat(host string, qt, cl uint16, ids []string) (ok bool) {
|
||||
// TODO(s.chzhen): Use dnsforward.dnsContext when it will start containing
|
||||
// persistent client.
|
||||
return s.stats != nil && s.stats.ShouldCount(host, qt, cl, ids)
|
||||
}
|
||||
|
||||
// logQuery pushes the request details into the query log.
|
||||
func (s *Server) logQuery(
|
||||
dctx *dnsContext,
|
||||
@@ -123,7 +141,10 @@ func (s *Server) updateStats(
|
||||
pctx := ctx.proxyCtx
|
||||
e := stats.Entry{}
|
||||
e.Domain = strings.ToLower(pctx.Req.Question[0].Name)
|
||||
e.Domain = e.Domain[:len(e.Domain)-1] // remove last "."
|
||||
if e.Domain != "." {
|
||||
// Remove last ".", but save the domain as is for "." queries.
|
||||
e.Domain = e.Domain[:len(e.Domain)-1]
|
||||
}
|
||||
|
||||
if clientID := ctx.clientID; clientID != "" {
|
||||
e.Client = clientID
|
||||
|
||||
@@ -46,6 +46,10 @@ type testStats struct {
|
||||
|
||||
// Update implements the [stats.Interface] interface for *testStats.
|
||||
func (l *testStats) Update(e stats.Entry) {
|
||||
if e.Domain == "" {
|
||||
return
|
||||
}
|
||||
|
||||
l.lastEntry = e
|
||||
}
|
||||
|
||||
@@ -54,9 +58,12 @@ func (l *testStats) ShouldCount(string, uint16, uint16, []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func TestProcessQueryLogsAndStats(t *testing.T) {
|
||||
func TestServer_ProcessQueryLogsAndStats(t *testing.T) {
|
||||
const domain = "example.com."
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
domain string
|
||||
proto proxy.Proto
|
||||
addr net.Addr
|
||||
clientID string
|
||||
@@ -67,6 +74,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
|
||||
wantStatResult stats.Result
|
||||
}{{
|
||||
name: "success_udp",
|
||||
domain: domain,
|
||||
proto: proxy.ProtoUDP,
|
||||
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
|
||||
clientID: "",
|
||||
@@ -77,6 +85,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
|
||||
wantStatResult: stats.RNotFiltered,
|
||||
}, {
|
||||
name: "success_tls_clientid",
|
||||
domain: domain,
|
||||
proto: proxy.ProtoTLS,
|
||||
addr: &net.TCPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
|
||||
clientID: "cli42",
|
||||
@@ -87,6 +96,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
|
||||
wantStatResult: stats.RNotFiltered,
|
||||
}, {
|
||||
name: "success_tls",
|
||||
domain: domain,
|
||||
proto: proxy.ProtoTLS,
|
||||
addr: &net.TCPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
|
||||
clientID: "",
|
||||
@@ -97,6 +107,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
|
||||
wantStatResult: stats.RNotFiltered,
|
||||
}, {
|
||||
name: "success_quic",
|
||||
domain: domain,
|
||||
proto: proxy.ProtoQUIC,
|
||||
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
|
||||
clientID: "",
|
||||
@@ -107,6 +118,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
|
||||
wantStatResult: stats.RNotFiltered,
|
||||
}, {
|
||||
name: "success_https",
|
||||
domain: domain,
|
||||
proto: proxy.ProtoHTTPS,
|
||||
addr: &net.TCPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
|
||||
clientID: "",
|
||||
@@ -117,6 +129,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
|
||||
wantStatResult: stats.RNotFiltered,
|
||||
}, {
|
||||
name: "success_dnscrypt",
|
||||
domain: domain,
|
||||
proto: proxy.ProtoDNSCrypt,
|
||||
addr: &net.TCPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
|
||||
clientID: "",
|
||||
@@ -127,6 +140,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
|
||||
wantStatResult: stats.RNotFiltered,
|
||||
}, {
|
||||
name: "success_udp_filtered",
|
||||
domain: domain,
|
||||
proto: proxy.ProtoUDP,
|
||||
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
|
||||
clientID: "",
|
||||
@@ -137,6 +151,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
|
||||
wantStatResult: stats.RFiltered,
|
||||
}, {
|
||||
name: "success_udp_sb",
|
||||
domain: domain,
|
||||
proto: proxy.ProtoUDP,
|
||||
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
|
||||
clientID: "",
|
||||
@@ -147,6 +162,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
|
||||
wantStatResult: stats.RSafeBrowsing,
|
||||
}, {
|
||||
name: "success_udp_ss",
|
||||
domain: domain,
|
||||
proto: proxy.ProtoUDP,
|
||||
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
|
||||
clientID: "",
|
||||
@@ -157,6 +173,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
|
||||
wantStatResult: stats.RSafeSearch,
|
||||
}, {
|
||||
name: "success_udp_pc",
|
||||
domain: domain,
|
||||
proto: proxy.ProtoUDP,
|
||||
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
|
||||
clientID: "",
|
||||
@@ -165,6 +182,17 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
|
||||
wantCode: resultCodeSuccess,
|
||||
reason: filtering.FilteredParental,
|
||||
wantStatResult: stats.RParental,
|
||||
}, {
|
||||
name: "success_udp_pc_empty_fqdn",
|
||||
domain: ".",
|
||||
proto: proxy.ProtoUDP,
|
||||
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 5}, Port: 1234},
|
||||
clientID: "",
|
||||
wantLogProto: "",
|
||||
wantStatClient: "1.2.3.5",
|
||||
wantCode: resultCodeSuccess,
|
||||
reason: filtering.FilteredParental,
|
||||
wantStatResult: stats.RParental,
|
||||
}}
|
||||
|
||||
ups, err := upstream.AddressToUpstream("1.1.1.1", nil)
|
||||
@@ -181,7 +209,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := &dns.Msg{
|
||||
Question: []dns.Question{{
|
||||
Name: "example.com.",
|
||||
Name: tc.domain,
|
||||
}},
|
||||
}
|
||||
pctx := &proxy.DNSContext{
|
||||
|
||||
311
internal/dnsforward/upstreams.go
Normal file
311
internal/dnsforward/upstreams.go
Normal file
@@ -0,0 +1,311 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/AdguardTeam/urlfilter"
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
// loadUpstreams parses upstream DNS servers from the configured file or from
|
||||
// the configuration itself.
|
||||
func (s *Server) loadUpstreams() (upstreams []string, err error) {
|
||||
if s.conf.UpstreamDNSFileName == "" {
|
||||
return stringutil.FilterOut(s.conf.UpstreamDNS, IsCommentOrEmpty), nil
|
||||
}
|
||||
|
||||
var data []byte
|
||||
data, err = os.ReadFile(s.conf.UpstreamDNSFileName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading upstream from file: %w", err)
|
||||
}
|
||||
|
||||
upstreams = stringutil.SplitTrimmed(string(data), "\n")
|
||||
|
||||
log.Debug("dnsforward: got %d upstreams in %q", len(upstreams), s.conf.UpstreamDNSFileName)
|
||||
|
||||
return stringutil.FilterOut(upstreams, IsCommentOrEmpty), nil
|
||||
}
|
||||
|
||||
// prepareUpstreamSettings sets upstream DNS server settings.
|
||||
func (s *Server) prepareUpstreamSettings() (err error) {
|
||||
// We're setting a customized set of RootCAs. The reason is that Go default
|
||||
// mechanism of loading TLS roots does not always work properly on some
|
||||
// routers so we're loading roots manually and pass it here.
|
||||
//
|
||||
// See [aghtls.SystemRootCAs].
|
||||
upstream.RootCAs = s.conf.TLSv12Roots
|
||||
upstream.CipherSuites = s.conf.TLSCiphers
|
||||
|
||||
// Load upstreams either from the file, or from the settings
|
||||
var upstreams []string
|
||||
upstreams, err = s.loadUpstreams()
|
||||
if err != nil {
|
||||
return fmt.Errorf("loading upstreams: %w", err)
|
||||
}
|
||||
|
||||
s.conf.UpstreamConfig, err = s.prepareUpstreamConfig(upstreams, defaultDNS, &upstream.Options{
|
||||
Bootstrap: s.conf.BootstrapDNS,
|
||||
Timeout: s.conf.UpstreamTimeout,
|
||||
HTTPVersions: UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams),
|
||||
PreferIPv6: s.conf.BootstrapPreferIPv6,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("preparing upstream config: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// prepareUpstreamConfig sets upstream configuration based on upstreams and
|
||||
// configuration of s.
|
||||
func (s *Server) prepareUpstreamConfig(
|
||||
upstreams []string,
|
||||
defaultUpstreams []string,
|
||||
opts *upstream.Options,
|
||||
) (uc *proxy.UpstreamConfig, err error) {
|
||||
uc, err = proxy.ParseUpstreamsConfig(upstreams, opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing upstream config: %w", err)
|
||||
}
|
||||
|
||||
if len(uc.Upstreams) == 0 && defaultUpstreams != nil {
|
||||
log.Info("dnsforward: warning: no default upstreams specified, using %v", defaultUpstreams)
|
||||
var defaultUpstreamConfig *proxy.UpstreamConfig
|
||||
defaultUpstreamConfig, err = proxy.ParseUpstreamsConfig(defaultUpstreams, opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing default upstreams: %w", err)
|
||||
}
|
||||
|
||||
uc.Upstreams = defaultUpstreamConfig.Upstreams
|
||||
}
|
||||
|
||||
if s.dnsFilter != nil && s.dnsFilter.EtcHosts != nil {
|
||||
err = s.replaceUpstreamsWithHosts(uc, opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("resolving upstreams with hosts: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return uc, nil
|
||||
}
|
||||
|
||||
// replaceUpstreamsWithHosts replaces unique upstreams with their resolved
|
||||
// versions based on the system hosts file.
|
||||
//
|
||||
// TODO(e.burkov): This should be performed inside dnsproxy, which should
|
||||
// actually consider /etc/hosts. See TODO on [aghnet.HostsContainer].
|
||||
func (s *Server) replaceUpstreamsWithHosts(
|
||||
upsConf *proxy.UpstreamConfig,
|
||||
opts *upstream.Options,
|
||||
) (err error) {
|
||||
resolved := map[string]*upstream.Options{}
|
||||
|
||||
err = s.resolveUpstreamsWithHosts(resolved, upsConf.Upstreams, opts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("resolving upstreams: %w", err)
|
||||
}
|
||||
|
||||
hosts := maps.Keys(upsConf.DomainReservedUpstreams)
|
||||
// TODO(e.burkov): Think of extracting sorted range into an util function.
|
||||
slices.Sort(hosts)
|
||||
for _, host := range hosts {
|
||||
err = s.resolveUpstreamsWithHosts(resolved, upsConf.DomainReservedUpstreams[host], opts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("resolving upstreams reserved for %s: %w", host, err)
|
||||
}
|
||||
}
|
||||
|
||||
hosts = maps.Keys(upsConf.SpecifiedDomainUpstreams)
|
||||
slices.Sort(hosts)
|
||||
for _, host := range hosts {
|
||||
err = s.resolveUpstreamsWithHosts(resolved, upsConf.SpecifiedDomainUpstreams[host], opts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("resolving upstreams specific for %s: %w", host, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// resolveUpstreamsWithHosts resolves the IP addresses of each of the upstreams
|
||||
// and replaces those both in upstreams and resolved. Upstreams that failed to
|
||||
// resolve are placed to resolved as-is. This function only returns error of
|
||||
// upstreams closing.
|
||||
func (s *Server) resolveUpstreamsWithHosts(
|
||||
resolved map[string]*upstream.Options,
|
||||
upstreams []upstream.Upstream,
|
||||
opts *upstream.Options,
|
||||
) (err error) {
|
||||
for i := range upstreams {
|
||||
u := upstreams[i]
|
||||
addr := u.Address()
|
||||
host := extractUpstreamHost(addr)
|
||||
|
||||
withIPs, ok := resolved[host]
|
||||
if !ok {
|
||||
ips := s.resolveUpstreamHost(host)
|
||||
if len(ips) == 0 {
|
||||
resolved[host] = nil
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
sortNetIPAddrs(ips, opts.PreferIPv6)
|
||||
|
||||
withIPs = opts.Clone()
|
||||
withIPs.ServerIPAddrs = ips
|
||||
resolved[host] = withIPs
|
||||
} else if withIPs == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if err = u.Close(); err != nil {
|
||||
return fmt.Errorf("closing upstream %s: %w", addr, err)
|
||||
}
|
||||
|
||||
upstreams[i], err = upstream.AddressToUpstream(addr, withIPs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("replacing upstream %s with resolved %s: %w", addr, host, err)
|
||||
}
|
||||
|
||||
log.Debug("dnsforward: using %s for %s", withIPs.ServerIPAddrs, upstreams[i].Address())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractUpstreamHost returns the hostname of addr without port with an
|
||||
// assumption that any address passed here has already been successfully parsed
|
||||
// by [upstream.AddressToUpstream]. This function eesentially mirrors the logic
|
||||
// of [upstream.AddressToUpstream], see TODO on [replaceUpstreamsWithHosts].
|
||||
func extractUpstreamHost(addr string) (host string) {
|
||||
var err error
|
||||
if strings.Contains(addr, "://") {
|
||||
var u *url.URL
|
||||
u, err = url.Parse(addr)
|
||||
if err != nil {
|
||||
log.Debug("dnsforward: parsing upstream %s: %s", addr, err)
|
||||
|
||||
return addr
|
||||
}
|
||||
|
||||
return u.Hostname()
|
||||
}
|
||||
|
||||
// Probably, plain UDP upstream defined by address or address:port.
|
||||
host, err = netutil.SplitHost(addr)
|
||||
if err != nil {
|
||||
return addr
|
||||
}
|
||||
|
||||
return host
|
||||
}
|
||||
|
||||
// resolveUpstreamHost returns the version of ups with IP addresses from the
|
||||
// system hosts file placed into its options.
|
||||
func (s *Server) resolveUpstreamHost(host string) (addrs []net.IP) {
|
||||
req := &urlfilter.DNSRequest{
|
||||
Hostname: host,
|
||||
DNSType: dns.TypeA,
|
||||
}
|
||||
aRes, _ := s.dnsFilter.EtcHosts.MatchRequest(req)
|
||||
|
||||
req.DNSType = dns.TypeAAAA
|
||||
aaaaRes, _ := s.dnsFilter.EtcHosts.MatchRequest(req)
|
||||
|
||||
var ips []net.IP
|
||||
for _, rw := range append(aRes.DNSRewrites(), aaaaRes.DNSRewrites()...) {
|
||||
dr := rw.DNSRewrite
|
||||
if dr == nil || dr.Value == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if ip, ok := dr.Value.(net.IP); ok {
|
||||
ips = append(ips, ip)
|
||||
}
|
||||
}
|
||||
|
||||
return ips
|
||||
}
|
||||
|
||||
// sortNetIPAddrs sorts addrs in accordance with the protocol preferences.
|
||||
// Invalid addresses are sorted near the end.
|
||||
//
|
||||
// TODO(e.burkov): This function taken from dnsproxy, which also already
|
||||
// contains a few similar functions. Think of moving to golibs.
|
||||
func sortNetIPAddrs(addrs []net.IP, preferIPv6 bool) {
|
||||
l := len(addrs)
|
||||
if l <= 1 {
|
||||
return
|
||||
}
|
||||
|
||||
slices.SortStableFunc(addrs, func(addrA, addrB net.IP) (sortsBefore bool) {
|
||||
switch len(addrA) {
|
||||
case net.IPv4len, net.IPv6len:
|
||||
switch len(addrB) {
|
||||
case net.IPv4len, net.IPv6len:
|
||||
// Go on.
|
||||
default:
|
||||
return true
|
||||
}
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
if aIs4, bIs4 := addrA.To4() != nil, addrB.To4() != nil; aIs4 != bIs4 {
|
||||
if aIs4 {
|
||||
return !preferIPv6
|
||||
}
|
||||
|
||||
return preferIPv6
|
||||
}
|
||||
|
||||
return bytes.Compare(addrA, addrB) < 0
|
||||
})
|
||||
}
|
||||
|
||||
// UpstreamHTTPVersions returns the HTTP versions for upstream configuration
|
||||
// depending on configuration.
|
||||
func UpstreamHTTPVersions(http3 bool) (v []upstream.HTTPVersion) {
|
||||
if !http3 {
|
||||
return upstream.DefaultHTTPVersions
|
||||
}
|
||||
|
||||
return []upstream.HTTPVersion{
|
||||
upstream.HTTPVersion3,
|
||||
upstream.HTTPVersion2,
|
||||
upstream.HTTPVersion11,
|
||||
}
|
||||
}
|
||||
|
||||
// setProxyUpstreamMode sets the upstream mode and related settings in conf
|
||||
// based on provided parameters.
|
||||
func setProxyUpstreamMode(
|
||||
conf *proxy.Config,
|
||||
allServers bool,
|
||||
fastestAddr bool,
|
||||
fastestTimeout time.Duration,
|
||||
) {
|
||||
if allServers {
|
||||
conf.UpstreamMode = proxy.UModeParallel
|
||||
} else if fastestAddr {
|
||||
conf.UpstreamMode = proxy.UModeFastestAddr
|
||||
conf.FastestPingTimeout = fastestTimeout
|
||||
} else {
|
||||
conf.UpstreamMode = proxy.UModeLoadBalance
|
||||
}
|
||||
}
|
||||
@@ -2,9 +2,12 @@ package filtering
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
"golang.org/x/exp/slices"
|
||||
@@ -44,23 +47,57 @@ func initBlockedServices() {
|
||||
log.Debug("filtering: initialized %d services", l)
|
||||
}
|
||||
|
||||
// BlockedSvcKnown returns true if a blocked service ID is known.
|
||||
func BlockedSvcKnown(s string) (ok bool) {
|
||||
_, ok = serviceRules[s]
|
||||
// BlockedServices is the configuration of blocked services.
|
||||
type BlockedServices struct {
|
||||
// Schedule is blocked services schedule for every day of the week.
|
||||
Schedule *schedule.Weekly `yaml:"schedule"`
|
||||
|
||||
return ok
|
||||
// IDs is the names of blocked services.
|
||||
IDs []string `yaml:"ids"`
|
||||
}
|
||||
|
||||
// Clone returns a deep copy of blocked services.
|
||||
func (s *BlockedServices) Clone() (c *BlockedServices) {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &BlockedServices{
|
||||
Schedule: s.Schedule.Clone(),
|
||||
IDs: slices.Clone(s.IDs),
|
||||
}
|
||||
}
|
||||
|
||||
// Validate returns an error if blocked services contain unknown service ID. s
|
||||
// must not be nil.
|
||||
func (s *BlockedServices) Validate() (err error) {
|
||||
for _, id := range s.IDs {
|
||||
_, ok := serviceRules[id]
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown blocked-service %q", id)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ApplyBlockedServices - set blocked services settings for this DNS request
|
||||
func (d *DNSFilter) ApplyBlockedServices(setts *Settings, list []string) {
|
||||
func (d *DNSFilter) ApplyBlockedServices(setts *Settings) {
|
||||
d.confLock.RLock()
|
||||
defer d.confLock.RUnlock()
|
||||
|
||||
setts.ServicesRules = []ServiceEntry{}
|
||||
if list == nil {
|
||||
d.confLock.RLock()
|
||||
defer d.confLock.RUnlock()
|
||||
|
||||
list = d.Config.BlockedServices
|
||||
bsvc := d.BlockedServices
|
||||
|
||||
// TODO(s.chzhen): Use startTime from [dnsforward.dnsContext].
|
||||
if !bsvc.Schedule.Contains(time.Now()) {
|
||||
d.ApplyBlockedServicesList(setts, bsvc.IDs)
|
||||
}
|
||||
}
|
||||
|
||||
// ApplyBlockedServicesList appends filtering rules to the settings.
|
||||
func (d *DNSFilter) ApplyBlockedServicesList(setts *Settings, list []string) {
|
||||
for _, name := range list {
|
||||
rules, ok := serviceRules[name]
|
||||
if !ok {
|
||||
@@ -90,7 +127,7 @@ func (d *DNSFilter) handleBlockedServicesAll(w http.ResponseWriter, r *http.Requ
|
||||
|
||||
func (d *DNSFilter) handleBlockedServicesList(w http.ResponseWriter, r *http.Request) {
|
||||
d.confLock.RLock()
|
||||
list := d.Config.BlockedServices
|
||||
list := d.Config.BlockedServices.IDs
|
||||
d.confLock.RUnlock()
|
||||
|
||||
_ = aghhttp.WriteJSONResponse(w, r, list)
|
||||
@@ -106,7 +143,7 @@ func (d *DNSFilter) handleBlockedServicesSet(w http.ResponseWriter, r *http.Requ
|
||||
}
|
||||
|
||||
d.confLock.Lock()
|
||||
d.Config.BlockedServices = list
|
||||
d.Config.BlockedServices.IDs = list
|
||||
d.confLock.Unlock()
|
||||
|
||||
log.Debug("Updated blocked services list: %d", len(list))
|
||||
|
||||
@@ -103,9 +103,9 @@ type Config struct {
|
||||
|
||||
Rewrites []*LegacyRewrite `yaml:"rewrites"`
|
||||
|
||||
// Names of services to block (globally).
|
||||
// BlockedServices is the configuration of blocked services.
|
||||
// Per-client settings can override this configuration.
|
||||
BlockedServices []string `yaml:"blocked_services"`
|
||||
BlockedServices *BlockedServices `yaml:"blocked_services"`
|
||||
|
||||
// EtcHosts is a container of IP-hostname pairs taken from the operating
|
||||
// system configuration files (e.g. /etc/hosts).
|
||||
@@ -298,12 +298,12 @@ func (d *DNSFilter) SetEnabled(enabled bool) {
|
||||
atomic.StoreUint32(&d.enabled, mathutil.BoolToNumber[uint32](enabled))
|
||||
}
|
||||
|
||||
// GetConfig - get configuration
|
||||
func (d *DNSFilter) GetConfig() (s Settings) {
|
||||
// Settings returns filtering settings.
|
||||
func (d *DNSFilter) Settings() (s *Settings) {
|
||||
d.confLock.RLock()
|
||||
defer d.confLock.RUnlock()
|
||||
|
||||
return Settings{
|
||||
return &Settings{
|
||||
FilteringEnabled: atomic.LoadUint32(&d.Config.enabled) != 0,
|
||||
SafeSearchEnabled: d.Config.SafeSearchConf.Enabled,
|
||||
SafeBrowsingEnabled: d.Config.SafeBrowsingEnabled,
|
||||
@@ -519,7 +519,7 @@ func (d *DNSFilter) matchSysHosts(
|
||||
dnsres, _ := d.EtcHosts.MatchRequest(&urlfilter.DNSRequest{
|
||||
Hostname: host,
|
||||
SortedClientTags: setts.ClientTags,
|
||||
// TODO(e.burkov): Wait for urlfilter update to pass net.IP.
|
||||
// TODO(e.burkov): Wait for urlfilter update to pass netip.Addr.
|
||||
ClientIP: setts.ClientIP.String(),
|
||||
ClientName: setts.ClientName,
|
||||
DNSType: qtype,
|
||||
@@ -987,16 +987,13 @@ func New(c *Config, blockFilters []Filter) (d *DNSFilter, err error) {
|
||||
return nil, fmt.Errorf("rewrites: preparing: %s", err)
|
||||
}
|
||||
|
||||
bsvcs := []string{}
|
||||
for _, s := range d.BlockedServices {
|
||||
if !BlockedSvcKnown(s) {
|
||||
log.Debug("skipping unknown blocked-service %q", s)
|
||||
if d.BlockedServices != nil {
|
||||
err = d.BlockedServices.Validate()
|
||||
|
||||
continue
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("filtering: %w", err)
|
||||
}
|
||||
bsvcs = append(bsvcs, s)
|
||||
}
|
||||
d.BlockedServices = bsvcs
|
||||
|
||||
if blockFilters != nil {
|
||||
err = d.initFiltering(nil, blockFilters)
|
||||
|
||||
@@ -169,7 +169,7 @@ func (d *DNSFilter) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Requ
|
||||
deleted = (*filters)[delIdx]
|
||||
p := deleted.Path(d.DataDir)
|
||||
err = os.Rename(p, p+".old")
|
||||
if err != nil {
|
||||
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
log.Error("deleting filter %d: renaming file %q: %s", deleted.ID, p, err)
|
||||
|
||||
return
|
||||
@@ -416,12 +416,12 @@ type checkHostResp struct {
|
||||
func (d *DNSFilter) handleCheckHost(w http.ResponseWriter, r *http.Request) {
|
||||
host := r.URL.Query().Get("name")
|
||||
|
||||
setts := d.GetConfig()
|
||||
setts := d.Settings()
|
||||
setts.FilteringEnabled = true
|
||||
setts.ProtectionEnabled = true
|
||||
|
||||
d.ApplyBlockedServices(&setts, nil)
|
||||
result, err := d.CheckHost(host, dns.TypeA, &setts)
|
||||
d.ApplyBlockedServices(setts)
|
||||
result, err := d.CheckHost(host, dns.TypeA, setts)
|
||||
if err != nil {
|
||||
aghhttp.Error(
|
||||
r,
|
||||
@@ -555,6 +555,7 @@ func (d *DNSFilter) RegisterFilteringHandlers() {
|
||||
|
||||
registerHTTP(http.MethodGet, "/control/rewrite/list", d.handleRewriteList)
|
||||
registerHTTP(http.MethodPost, "/control/rewrite/add", d.handleRewriteAdd)
|
||||
registerHTTP(http.MethodPut, "/control/rewrite/update", d.handleRewriteUpdate)
|
||||
registerHTTP(http.MethodPost, "/control/rewrite/delete", d.handleRewriteDelete)
|
||||
|
||||
registerHTTP(http.MethodGet, "/control/blocked_services/services", d.handleBlockedServicesIDs)
|
||||
|
||||
@@ -84,7 +84,7 @@ func (s *DefaultStorage) MatchRequest(dReq *urlfilter.DNSRequest) (rws []*rules.
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO(a.garipov): Check cnames for cycles on initialisation.
|
||||
// TODO(a.garipov): Check cnames for cycles on initialization.
|
||||
cnames := stringutil.NewSet()
|
||||
host := dReq.Hostname
|
||||
for len(rrules) > 0 && rrules[0].DNSRewrite != nil && rrules[0].DNSRewrite.NewCNAME != "" {
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
// TODO(d.kolyshev): Use [rewrite.Item] instead.
|
||||
@@ -91,3 +92,62 @@ func (d *DNSFilter) handleRewriteDelete(w http.ResponseWriter, r *http.Request)
|
||||
|
||||
d.Config.ConfigModified()
|
||||
}
|
||||
|
||||
// rewriteUpdateJSON is a struct for JSON object with rewrite rule update info.
|
||||
type rewriteUpdateJSON struct {
|
||||
Target rewriteEntryJSON `json:"target"`
|
||||
Update rewriteEntryJSON `json:"update"`
|
||||
}
|
||||
|
||||
// handleRewriteUpdate is the handler for the PUT /control/rewrite/update HTTP
|
||||
// API.
|
||||
func (d *DNSFilter) handleRewriteUpdate(w http.ResponseWriter, r *http.Request) {
|
||||
updateJSON := rewriteUpdateJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&updateJSON)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
rwDel := &LegacyRewrite{
|
||||
Domain: updateJSON.Target.Domain,
|
||||
Answer: updateJSON.Target.Answer,
|
||||
}
|
||||
|
||||
rwAdd := &LegacyRewrite{
|
||||
Domain: updateJSON.Update.Domain,
|
||||
Answer: updateJSON.Update.Answer,
|
||||
}
|
||||
|
||||
err = rwAdd.normalize()
|
||||
if err != nil {
|
||||
// Shouldn't happen currently, since normalize only returns a non-nil
|
||||
// error when a rewrite is nil, but be change-proof.
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "normalizing: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
index := -1
|
||||
defer func() {
|
||||
if index >= 0 {
|
||||
d.Config.ConfigModified()
|
||||
}
|
||||
}()
|
||||
|
||||
d.confLock.Lock()
|
||||
defer d.confLock.Unlock()
|
||||
|
||||
index = slices.IndexFunc(d.Config.Rewrites, rwDel.equal)
|
||||
if index == -1 {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "target rule not found")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
d.Config.Rewrites = slices.Replace(d.Config.Rewrites, index, index+1, rwAdd)
|
||||
|
||||
log.Debug("rewrite: removed element: %s -> %s", rwDel.Domain, rwDel.Answer)
|
||||
log.Debug("rewrite: added element: %s -> %s", rwAdd.Domain, rwAdd.Answer)
|
||||
}
|
||||
|
||||
237
internal/filtering/rewritehttp_test.go
Normal file
237
internal/filtering/rewritehttp_test.go
Normal file
@@ -0,0 +1,237 @@
|
||||
package filtering_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TODO(d.kolyshev): Use [rewrite.Item] instead.
|
||||
type rewriteJSON struct {
|
||||
Domain string `json:"domain"`
|
||||
Answer string `json:"answer"`
|
||||
}
|
||||
|
||||
type rewriteUpdateJSON struct {
|
||||
Target rewriteJSON `json:"target"`
|
||||
Update rewriteJSON `json:"update"`
|
||||
}
|
||||
|
||||
const (
|
||||
// testTimeout is the common timeout for tests.
|
||||
testTimeout = 100 * time.Millisecond
|
||||
|
||||
listURL = "/control/rewrite/list"
|
||||
addURL = "/control/rewrite/add"
|
||||
deleteURL = "/control/rewrite/delete"
|
||||
updateURL = "/control/rewrite/update"
|
||||
|
||||
decodeErrorMsg = "json.Decode: json: cannot unmarshal string into Go value of type" +
|
||||
" filtering.rewriteEntryJSON\n"
|
||||
)
|
||||
|
||||
func TestDNSFilter_handleRewriteHTTP(t *testing.T) {
|
||||
confModCh := make(chan struct{})
|
||||
reqCh := make(chan struct{})
|
||||
testRewrites := []*rewriteJSON{
|
||||
{Domain: "example.local", Answer: "example.rewrite"},
|
||||
{Domain: "one.local", Answer: "one.rewrite"},
|
||||
}
|
||||
|
||||
testRewritesJSON, mErr := json.Marshal(testRewrites)
|
||||
require.NoError(t, mErr)
|
||||
|
||||
testCases := []struct {
|
||||
reqData any
|
||||
name string
|
||||
url string
|
||||
method string
|
||||
wantList []*rewriteJSON
|
||||
wantBody string
|
||||
wantConfMod bool
|
||||
wantStatus int
|
||||
}{{
|
||||
name: "list",
|
||||
url: listURL,
|
||||
method: http.MethodGet,
|
||||
reqData: nil,
|
||||
wantConfMod: false,
|
||||
wantStatus: http.StatusOK,
|
||||
wantBody: string(testRewritesJSON) + "\n",
|
||||
wantList: testRewrites,
|
||||
}, {
|
||||
name: "add",
|
||||
url: addURL,
|
||||
method: http.MethodPost,
|
||||
reqData: rewriteJSON{Domain: "add.local", Answer: "add.rewrite"},
|
||||
wantConfMod: true,
|
||||
wantStatus: http.StatusOK,
|
||||
wantBody: "",
|
||||
wantList: append(
|
||||
testRewrites,
|
||||
&rewriteJSON{Domain: "add.local", Answer: "add.rewrite"},
|
||||
),
|
||||
}, {
|
||||
name: "add_error",
|
||||
url: addURL,
|
||||
method: http.MethodPost,
|
||||
reqData: "invalid_json",
|
||||
wantConfMod: false,
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantBody: decodeErrorMsg,
|
||||
wantList: testRewrites,
|
||||
}, {
|
||||
name: "delete",
|
||||
url: deleteURL,
|
||||
method: http.MethodPost,
|
||||
reqData: rewriteJSON{Domain: "one.local", Answer: "one.rewrite"},
|
||||
wantConfMod: true,
|
||||
wantStatus: http.StatusOK,
|
||||
wantBody: "",
|
||||
wantList: []*rewriteJSON{{Domain: "example.local", Answer: "example.rewrite"}},
|
||||
}, {
|
||||
name: "delete_error",
|
||||
url: deleteURL,
|
||||
method: http.MethodPost,
|
||||
reqData: "invalid_json",
|
||||
wantConfMod: false,
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantBody: decodeErrorMsg,
|
||||
wantList: testRewrites,
|
||||
}, {
|
||||
name: "update",
|
||||
url: updateURL,
|
||||
method: http.MethodPut,
|
||||
reqData: rewriteUpdateJSON{
|
||||
Target: rewriteJSON{Domain: "one.local", Answer: "one.rewrite"},
|
||||
Update: rewriteJSON{Domain: "upd.local", Answer: "upd.rewrite"},
|
||||
},
|
||||
wantConfMod: true,
|
||||
wantStatus: http.StatusOK,
|
||||
wantBody: "",
|
||||
wantList: []*rewriteJSON{
|
||||
{Domain: "example.local", Answer: "example.rewrite"},
|
||||
{Domain: "upd.local", Answer: "upd.rewrite"},
|
||||
},
|
||||
}, {
|
||||
name: "update_error",
|
||||
url: updateURL,
|
||||
method: http.MethodPut,
|
||||
reqData: "invalid_json",
|
||||
wantConfMod: false,
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantBody: "json.Decode: json: cannot unmarshal string into Go value of type" +
|
||||
" filtering.rewriteUpdateJSON\n",
|
||||
wantList: testRewrites,
|
||||
}, {
|
||||
name: "update_error_target",
|
||||
url: updateURL,
|
||||
method: http.MethodPut,
|
||||
reqData: rewriteUpdateJSON{
|
||||
Target: rewriteJSON{Domain: "inv.local", Answer: "inv.rewrite"},
|
||||
Update: rewriteJSON{Domain: "upd.local", Answer: "upd.rewrite"},
|
||||
},
|
||||
wantConfMod: false,
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantBody: "target rule not found\n",
|
||||
wantList: testRewrites,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
onConfModified := func() {
|
||||
if !tc.wantConfMod {
|
||||
panic("config modified has been fired")
|
||||
}
|
||||
|
||||
testutil.RequireSend(testutil.PanicT{}, confModCh, struct{}{}, testTimeout)
|
||||
}
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
handlers := make(map[string]http.Handler)
|
||||
|
||||
d, err := filtering.New(&filtering.Config{
|
||||
ConfigModified: onConfModified,
|
||||
HTTPRegister: func(_, url string, handler http.HandlerFunc) {
|
||||
handlers[url] = handler
|
||||
},
|
||||
Rewrites: rewriteEntriesToLegacyRewrites(testRewrites),
|
||||
}, nil)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
d.RegisterFilteringHandlers()
|
||||
require.NotEmpty(t, handlers)
|
||||
require.Contains(t, handlers, listURL)
|
||||
require.Contains(t, handlers, tc.url)
|
||||
|
||||
var body io.Reader
|
||||
if tc.reqData != nil {
|
||||
data, rErr := json.Marshal(tc.reqData)
|
||||
require.NoError(t, rErr)
|
||||
|
||||
body = bytes.NewReader(data)
|
||||
}
|
||||
|
||||
r := httptest.NewRequest(tc.method, tc.url, body)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
go func() {
|
||||
handlers[tc.url].ServeHTTP(w, r)
|
||||
|
||||
testutil.RequireSend(testutil.PanicT{}, reqCh, struct{}{}, testTimeout)
|
||||
}()
|
||||
|
||||
if tc.wantConfMod {
|
||||
testutil.RequireReceive(t, confModCh, testTimeout)
|
||||
}
|
||||
|
||||
testutil.RequireReceive(t, reqCh, testTimeout)
|
||||
assert.Equal(t, tc.wantStatus, w.Code)
|
||||
|
||||
respBody, err := io.ReadAll(w.Body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte(tc.wantBody), respBody)
|
||||
|
||||
assertRewritesList(t, handlers[listURL], tc.wantList)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// assertRewritesList checks if rewrites list equals the list received from the
|
||||
// handler by listURL.
|
||||
func assertRewritesList(t *testing.T, handler http.Handler, wantList []*rewriteJSON) {
|
||||
t.Helper()
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, listURL, nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, r)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var actual []*rewriteJSON
|
||||
err := json.NewDecoder(w.Body).Decode(&actual)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, wantList, actual)
|
||||
}
|
||||
|
||||
// rewriteEntriesToLegacyRewrites gets legacy rewrites from json entries.
|
||||
func rewriteEntriesToLegacyRewrites(entries []*rewriteJSON) (rw []*filtering.LegacyRewrite) {
|
||||
for _, entry := range entries {
|
||||
rw = append(rw, &filtering.LegacyRewrite{
|
||||
Domain: entry.Domain,
|
||||
Answer: entry.Answer,
|
||||
})
|
||||
}
|
||||
|
||||
return rw
|
||||
}
|
||||
@@ -161,12 +161,8 @@ func (ss *Default) resetEngine(
|
||||
// type check
|
||||
var _ filtering.SafeSearch = (*Default)(nil)
|
||||
|
||||
// CheckHost implements the [filtering.SafeSearch] interface for
|
||||
// *DefaultSafeSearch.
|
||||
func (ss *Default) CheckHost(
|
||||
host string,
|
||||
qtype rules.RRType,
|
||||
) (res filtering.Result, err error) {
|
||||
// CheckHost implements the [filtering.SafeSearch] interface for *Default.
|
||||
func (ss *Default) CheckHost(host string, qtype rules.RRType) (res filtering.Result, err error) {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
ss.log(log.DEBUG, "lookup for %q finished in %s", host, time.Since(start))
|
||||
@@ -196,14 +192,10 @@ func (ss *Default) CheckHost(
|
||||
return filtering.Result{}, err
|
||||
}
|
||||
|
||||
if fltRes != nil {
|
||||
res = *fltRes
|
||||
ss.setCacheResult(host, qtype, res)
|
||||
res = *fltRes
|
||||
ss.setCacheResult(host, qtype, res)
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
return filtering.Result{}, fmt.Errorf("no ipv4 addresses for %q", host)
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// searchHost looks up DNS rewrites in the internal DNS filtering engine.
|
||||
@@ -229,7 +221,11 @@ func (ss *Default) searchHost(host string, qtype rules.RRType) (res *rules.DNSRe
|
||||
}
|
||||
|
||||
// newResult creates Result object from rewrite rule. qtype must be either
|
||||
// [dns.TypeA] or [dns.TypeAAAA].
|
||||
// [dns.TypeA] or [dns.TypeAAAA]. If err is nil, res is never nil, so that the
|
||||
// empty result is converted into a NODATA response.
|
||||
//
|
||||
// TODO(a.garipov): Use the main rewrite result mechanism used in
|
||||
// [dnsforward.Server.filterDNSRequest].
|
||||
func (ss *Default) newResult(
|
||||
rewrite *rules.DNSRewrite,
|
||||
qtype rules.RRType,
|
||||
@@ -243,9 +239,10 @@ func (ss *Default) newResult(
|
||||
}
|
||||
|
||||
if rewrite.RRType == qtype {
|
||||
ip, ok := rewrite.Value.(net.IP)
|
||||
v := rewrite.Value
|
||||
ip, ok := v.(net.IP)
|
||||
if !ok || ip == nil {
|
||||
return nil, nil
|
||||
return nil, fmt.Errorf("expected ip rewrite value, got %T(%[1]v)", v)
|
||||
}
|
||||
|
||||
res.Rules[0].IP = ip
|
||||
@@ -255,14 +252,14 @@ func (ss *Default) newResult(
|
||||
|
||||
host := rewrite.NewCNAME
|
||||
if host == "" {
|
||||
return nil, nil
|
||||
return res, nil
|
||||
}
|
||||
|
||||
ss.log(log.DEBUG, "resolving %q", host)
|
||||
|
||||
ips, err := ss.resolver.LookupIP(context.Background(), qtypeToProto(qtype), host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("resolving cname: %w", err)
|
||||
}
|
||||
|
||||
ss.log(log.DEBUG, "resolved %s", ips)
|
||||
@@ -276,11 +273,9 @@ func (ss *Default) newResult(
|
||||
}
|
||||
|
||||
res.Rules[0].IP = ip
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// qtypeToProto returns "ip4" for [dns.TypeA] and "ip6" for [dns.TypeAAAA].
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package safesearch_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -71,6 +72,25 @@ func TestDefault_CheckHost_yandex(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefault_CheckHost_yandexAAAA(t *testing.T) {
|
||||
conf := testConf
|
||||
ss, err := safesearch.NewDefault(conf, "", testCacheSize, testCacheTTL)
|
||||
require.NoError(t, err)
|
||||
|
||||
res, err := ss.CheckHost("www.yandex.ru", dns.TypeAAAA)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, res.IsFiltered)
|
||||
|
||||
// TODO(a.garipov): Currently, the safe-search filter returns a single rule
|
||||
// with a nil IP address. This isn't really necessary and should be changed
|
||||
// once the TODO in [safesearch.Default.newResult] is resolved.
|
||||
require.Len(t, res.Rules, 1)
|
||||
|
||||
assert.Nil(t, res.Rules[0].IP)
|
||||
assert.EqualValues(t, filtering.SafeSearchListID, res.Rules[0].FilterListID)
|
||||
}
|
||||
|
||||
func TestDefault_CheckHost_google(t *testing.T) {
|
||||
resolver := &aghtest.TestResolver{}
|
||||
ip, _ := resolver.HostToIPs("forcesafesearch.google.com")
|
||||
@@ -105,6 +125,56 @@ func TestDefault_CheckHost_google(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// testResolver is a [filtering.Resolver] for tests.
|
||||
//
|
||||
// TODO(a.garipov): Move to aghtest and use everywhere.
|
||||
type testResolver struct {
|
||||
OnLookupIP func(ctx context.Context, network, host string) (ips []net.IP, err error)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ filtering.Resolver = (*testResolver)(nil)
|
||||
|
||||
// LookupIP implements the [filtering.Resolver] interface for *testResolver.
|
||||
func (r *testResolver) LookupIP(
|
||||
ctx context.Context,
|
||||
network string,
|
||||
host string,
|
||||
) (ips []net.IP, err error) {
|
||||
return r.OnLookupIP(ctx, network, host)
|
||||
}
|
||||
|
||||
func TestDefault_CheckHost_duckduckgoAAAA(t *testing.T) {
|
||||
conf := testConf
|
||||
conf.CustomResolver = &testResolver{
|
||||
OnLookupIP: func(_ context.Context, network, host string) (ips []net.IP, err error) {
|
||||
assert.Equal(t, "ip6", network)
|
||||
assert.Equal(t, "safe.duckduckgo.com", host)
|
||||
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
ss, err := safesearch.NewDefault(conf, "", testCacheSize, testCacheTTL)
|
||||
require.NoError(t, err)
|
||||
|
||||
// The DuckDuckGo safe-search addresses are resolved through CNAMEs, but
|
||||
// DuckDuckGo doesn't have a safe-search IPv6 address. The result should be
|
||||
// the same as the one for Yandex IPv6. That is, a NODATA response.
|
||||
res, err := ss.CheckHost("www.duckduckgo.com", dns.TypeAAAA)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, res.IsFiltered)
|
||||
|
||||
// TODO(a.garipov): Currently, the safe-search filter returns a single rule
|
||||
// with a nil IP address. This isn't really necessary and should be changed
|
||||
// once the TODO in [safesearch.Default.newResult] is resolved.
|
||||
require.Len(t, res.Rules, 1)
|
||||
|
||||
assert.Nil(t, res.Rules[0].IP)
|
||||
assert.EqualValues(t, filtering.SafeSearchListID, res.Rules[0].FilterListID)
|
||||
}
|
||||
|
||||
func TestDefault_Update(t *testing.T) {
|
||||
conf := testConf
|
||||
ss, err := safesearch.NewDefault(conf, "", testCacheSize, testCacheTTL)
|
||||
|
||||
@@ -27,6 +27,25 @@ var blockedServices = []blockedService{{
|
||||
"||9cache.com^",
|
||||
"||9gag.com^",
|
||||
},
|
||||
}, {
|
||||
ID: "activision_blizzard",
|
||||
Name: "Activision Blizzard",
|
||||
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" viewBox=\"-237 0 1572 1572\"><path d=\"m549.1.2 548.4 1571.4H798l-74.2-200H374.5l-74.3 200H.7zM626 1085.1l-83-274.3-82.9 274.3z\"/></svg>"),
|
||||
Rules: []string{
|
||||
"||activision.com^",
|
||||
"||activisionblizzard.com^",
|
||||
"||demonware.net^",
|
||||
},
|
||||
}, {
|
||||
ID: "aliexpress",
|
||||
Name: "AliExpress",
|
||||
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" fill=\"currentColor\" viewBox=\"0 0 50 50\"><path d=\"M9 4C6.25 4 4 6.25 4 9v32c0 2.75 2.25 5 5 5h32c2.75 0 5-2.25 5-5V9c0-2.75-2.25-5-5-5H9zm0 2h32c1.668 0 3 1.332 3 3v3.38A3.973 3.973 0 0 0 41 11H9a3.973 3.973 0 0 0-3 1.38V9c0-1.668 1.332-3 3-3zm6 11a1 1 0 0 1 1 1c0 4.962 4.037 9 9 9s9-4.038 9-9a1 1 0 1 1 2 0c0 6.065-4.935 11-11 11s-11-4.935-11-11a1 1 0 0 1 1-1z\"/></svg>"),
|
||||
Rules: []string{
|
||||
"||ae-rus.net^",
|
||||
"||ae-rus.ru^",
|
||||
"||aliexpress.com^",
|
||||
"||aliexpress.ru^",
|
||||
},
|
||||
}, {
|
||||
ID: "amazon",
|
||||
Name: "Amazon",
|
||||
@@ -234,6 +253,16 @@ var blockedServices = []blockedService{{
|
||||
"||z.cn^",
|
||||
"||zappos^",
|
||||
},
|
||||
}, {
|
||||
ID: "battle_net",
|
||||
Name: "Battle.net",
|
||||
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" fill=\"currentColor\" viewBox=\"0 0 50 50\"><path d=\"M43.11 22.15s3.95.2 3.95-2.12c0-3.03-5.26-5.77-5.26-5.77s.83-1.74 1.34-2.72a37.3 37.3 0 0 0 2.09-5.65c.16-1.1-.09-1.44-.09-1.44-.35 2.34-4.17 9.09-4.47 9.32-3.72-1.75-8.83-2.23-8.83-2.23S26.84 1 22.13 1c-4.67 0-4.65 9.02-4.65 9.02s-1.32-2.56-2.97-2.56c-2.42 0-3.22 3.67-3.22 7.64a37.8 37.8 0 0 0-9.16 1.17c-.36.1-1.49.92-.97.82 1.04-.34 5.95-1.1 10.25-.72.24 3.77 2.44 8.68 2.44 8.68S9.13 31.9 9.13 36.78c0 1.29.56 3.64 3.95 3.64 2.84 0 6.03-1.7 6.63-2.06a6.33 6.33 0 0 0-.91 2.83c0 .54.31 2.06 2.5 2.06 2.82 0 5.96-2.16 5.96-2.16s2.96 4.93 5.5 7.2c.69.6 1.34.71 1.34.71s-2.52-2.43-5.84-8.68c3.08-1.9 6.3-6.4 6.3-6.4l3.3.01c4.6 0 11.11-.96 11.11-4.61 0-3.77-5.86-7.17-5.86-7.17Zm.52-2.26c0 1.33-1.27 1.3-1.27 1.3l-.97.08s-1.82-.97-2.93-1.41c0 0 1.72-2.65 2.12-3.4.3.18 3.05 1.9 3.05 3.43ZM24.43 6.3c2.15 0 5.23 5.1 5.23 5.1s-4.8-.44-8.76 1.89c.1-3.67 1.34-7 3.52-7Zm-8.56 4.13c.69 0 1.36.83 1.64 1.54 0 .47.24 3.2.24 3.2l-3.96-.16c0-3.57 1.4-4.58 2.08-4.58Zm-.4 24.8c-2.17 0-2.62-1.2-2.62-2.29 0-2.45 1.96-5.9 1.96-5.9s2.2 4.63 6.04 6.59a10.02 10.02 0 0 1-5.39 1.6Zm7.02 4.85c-1.52 0-1.7-.98-1.7-1.21 0-.7.55-1.54.55-1.54s2.55-1.73 2.71-1.91l1.89 3.52s-1.93 1.14-3.45 1.14Zm4.74-1.92c-.93-1.62-1.6-3.3-1.6-3.3s3.78.24 5.82-1.86a11.2 11.2 0 0 1-5.65 1.07c4.93-4.34 7.8-7.48 10.23-10.74a9.46 9.46 0 0 0-1.6-1.15c-1.46 1.76-7.16 7.86-12.45 10.88-6.69-3.64-8.09-14.38-8.23-16.6l3.65.34s-1.37 2.44-1.37 4.23c0 1.79.21 1.89.21 1.89s-.04-3.13 1.89-5.54c1.46 7.82 3 11.83 4.19 14.22.6-.25 1.74-.76 1.74-.76s-3.38-9.73-3.19-16.31a13.8 13.8 0 0 1 6.36-1.66c6.73 0 12.14 2.9 12.14 2.9l-2.12 2.95s-1.89-3.42-4.55-4.03c1.4 1.05 2.98 2.44 3.8 4.43a68.4 68.4 0 0 0-14.47-3.59c-.19.8-.17 1.94-.17 1.94s9.03 1.66 15.6 5.43c-.05 8.21-9 14.53-10.23 15.26Zm8.55-6.14s2.8-3.68 2.76-8.55c0 0 4.52 2.8 4.52 5.54 0 3.05-7.28 3-7.28 3Z\"/></svg>"),
|
||||
Rules: []string{
|
||||
"||battle.net^",
|
||||
"||battlenet.com.cn^",
|
||||
"||bnet.163.com^",
|
||||
"||bnet.cn^",
|
||||
},
|
||||
}, {
|
||||
ID: "bilibili",
|
||||
Name: "Bilibili",
|
||||
@@ -283,6 +312,21 @@ var blockedServices = []blockedService{{
|
||||
"||mincdn.com^",
|
||||
"||yo9.com^",
|
||||
},
|
||||
}, {
|
||||
ID: "blizzard_entertainment",
|
||||
Name: "Blizzard Entertainment",
|
||||
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" fill=\"currentColor\" viewBox=\"0 -32 128 128\"><path fill-rule=\"evenodd\" d=\"M105 2h3v1h2l2 1 1 1h3l1 1h4l1 1 2 2v1l1 3v4l1 2v6l-1 2v2l-1 3v2l-1 2v14l-1 2v1l-1 3-1 1h-3l-1 1h-6a5 5 0 0 0 1-6l2-1h-1l-1-3v-3a350 350 0 0 1 0-8l-1-3v-1l-1-1V9h1V6l-1-1-4-3Zm9 13v10h1v25a8 8 0 0 0 2-4l1-1 1-3V30l1-1v-2l1-1v-5l-1-2-2-3-1-1h-3Z\" clip-rule=\"evenodd\"/><path fill-rule=\"evenodd\" d=\"M101 24v1l2 1h1v2h1l1 2v5l1 2s0-1 0 0l1 7 1 2v7l-1 5h-2l-2-2-4-1 1-3 1-2a22 22 0 0 0-1-10l-1-4h-1l1-4-1-1-2-3v2l-1 1v3l1 6v4l1 1-1 3v4l1 1-1 2v4l-1-1a13 13 0 0 0-4-5l-2-2 2-5V27l-1-1v-4l-1-1v-5h-1a33 33 0 0 1 0-4l4-4h-2l-4-4h-1V3h10l2 1 2 1h1c2 0 2 1 3 2l2 3 1 3v1l-1 2v1a11 11 0 0 1-1 4l-4 3ZM96 9v13l1 1a3 3 0 0 0 1-1c1 0 2-1 2-3v-1l1-1v-3l-2-3-2-2h-1ZM26 3l1 1h1l2 3v5l1 1v2l-1 1v9l1 1 1 1-1 7v9l-1 1 1 1-1 1v8h3l1-1h7v-1h16v6l1 2h-6l-1-1h-2l-1-1H31a4 4 0 0 0-3-1l-1 1h-1l-1 1h-5l1-1a10 10 0 0 0 3-2v-9l1-1-1-1V35l1-1V21l-1-1v-4l1-1v-3l1-2-1-3h-1l-2-2-1-1 1-1h4Z\" clip-rule=\"evenodd\"/><path fill-rule=\"evenodd\" d=\"M84 60v-3l-1-2v-4l-3-2v-1l1-2a11 11 0 0 0 2-6l-1-1-3-2h-2v3l1 1h1l-1 2h-4l-2 1-2 1-1-2v-1l1-1 1-1 1-2v-5l1-1v-6l1-1v-3l1-1v-3l1-2 1-1-1-1 1-1 1-3 1-1V7l1-1c1-1 0-4 2-3l1 3 1 1 1 2v1l1 5 1 3v2l1 1v2l1 1v8l1 3v9l-1 1-2 5v3l-1 2v4l-1 1h-1Zm-4-36-1 1v2l-1 2v4l4 1h2v-7l-1-3-2-1-1-1v2Z\" clip-rule=\"evenodd\"/><path fill-rule=\"evenodd\" d=\"M77 4v1l-2 3v2l-1 2v1l-1 1-1 4v7h-1v2a5 5 0 0 1-1 2v2l-2 2v7l-1 2v2l-2 4v3-1h3v-1l3-1 1-1 3-2h3l1 1-1 1a3 3 0 0 0 0 1l1 1v5l-1 1h-7v-1h-2l-2 1h-4l-2 1-1-2v-2l1-1v-1l1-1-1-1 1-1v-2l1-2-1-2v-8l1-2 2-5-1-1 1-2v-1l1-1v-4l1-1 2-4v-2l1-1h-3V8h-1l-1 1-2 3-1 4h-1l-1-1v-2l1-1V4h16ZM32 4h9l1 2-3 2 1 2-1 1v13l1 2-1 2v6l-1 1v2l1 1v5l-1 1 1 2 1 1 2 1v2h-7l-2 1-1-1 3-2v-8a4 4 0 0 1 0-2l1-1v-3l-1-14v-2l1-1h-1V7l-2-1h-1l-1-1 1-1Zm12 0h14v15c-2 1-2 4-3 6v2c-1 0-3 1-2 4h-1l-2 3-1 2v3l-1 2-2 5h2l1-1h2l1-1c1-1 1-3 3-3l1-2 2-2h1l1 3h-1v1l-1 1v7h-8l-1 1-2-1h-3l-1-1 1-1v-3l-1-2 1-1-1-1 1-3v-2l1-2 1-3a7 7 0 0 1 2-4l1-4 2-2 2-3v-3h1l2-3V8l-3-1h-2l-1 1a3 3 0 0 0-2 3l-1 1v4l-1 1-1 1v-1l-1-1V4ZM17 22l1 1h1v3s0-1 0 0l2 1v5l1 2-1 8v3a6 6 0 0 1 0 2l-1 2-1 2-1 3-3 2-2 2-3 1-1-1-1 1H1l-1-1 2-1 1-4V26l1-1-1-3V11l1-1-1-1H2V8L1 7 0 6V5l1-1h15l1 1c2 0 3 1 3 2l1 3v6l-4 6Zm-6-11v9h1l1-2 2-1v-6h-1l-1-1h-2v1Zm0 19-1 1 1 2-1 3v9a2 2 0 0 0 0 1v6l-1 1 3-1 1-2h1v-4l1-4v-5l-1-1 1-3-1-1v-2s0 1 0 0v-1l-1-1a20 20 0 0 1-2-2v4Z\" clip-rule=\"evenodd\"/></svg>"),
|
||||
Rules: []string{
|
||||
"||battle.net^",
|
||||
"||battlenet.com.cn^",
|
||||
"||blizzard.cn^",
|
||||
"||blizzardgames.cn^",
|
||||
"||blz-contentstack.com^",
|
||||
"||blzstatic.cn^",
|
||||
"||bnet.163.com^",
|
||||
"||bnet.cn^",
|
||||
"||lizzard.com^",
|
||||
},
|
||||
}, {
|
||||
ID: "cloudflare",
|
||||
Name: "CloudFlare",
|
||||
@@ -319,6 +363,14 @@ var blockedServices = []blockedService{{
|
||||
"||warp.plus^",
|
||||
"||workers.dev^",
|
||||
},
|
||||
}, {
|
||||
ID: "clubhouse",
|
||||
Name: "Clubhouse",
|
||||
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" fill=\"currentColor\" viewBox=\"0 0 50 50\"><path d=\"M29.8 4a1 1 0 0 0-.92.7 1 1 0 0 0 .36 1.1 31.2 31.2 0 0 1 6 6.02 1 1 0 1 0 1.6-1.2 33.2 33.2 0 0 0-6.4-6.4A1 1 0 0 0 29.8 4Zm-7.16 1.06c-.46 0-.87.3-.99.74a1 1 0 0 0 .5 1.15 31.13 31.13 0 0 1 11.13 10.6 1 1 0 1 0 1.7-1.07A33.12 33.12 0 0 0 23.11 5.2a.96.96 0 0 0-.48-.14ZM14.5 7.01a3.42 3.42 0 0 0-3.27 2.28l-.26-.27A3.49 3.49 0 0 0 8.5 8.01c-.9 0-1.8.34-2.48 1.01a3.51 3.51 0 0 0-.57 4.17c-.52.15-1.01.42-1.43.84a3.52 3.52 0 0 0 0 4.94l.27.27c-.46.16-.9.41-1.27.79a3.52 3.52 0 0 0 0 4.94l.88.88 16.47 16.47a9.01 9.01 0 0 0 12.72 0l4.23-4.22a9.94 9.94 0 0 0 2.3-3.59l2.63-7.08a8.03 8.03 0 0 1 1.84-2.87l1.74-1.73 1-1a4.02 4.02 0 0 0 0-5.66 4.02 4.02 0 0 0-5.66 0l-1 1-.7.71-4.2 4.2a2.98 2.98 0 0 1-4.24 0L17.9 8.96l-.94-.94a3.49 3.49 0 0 0-2.47-1.01Zm0 1.98c.38 0 .76.15 1.06.45l.94.94 13.1 13.1a5.02 5.02 0 0 0 7.08 0l4.2-4.18.7-.71 1-1c.8-.8 2.05-.8 2.83 0 .8.79.8 2.04 0 2.83l-2.73 2.73a10.03 10.03 0 0 0-2.3 3.58l-2.63 7.08a8.02 8.02 0 0 1-1.84 2.87l-4.23 4.23a6.99 6.99 0 0 1-9.9 0L4.44 23.56a1.5 1.5 0 0 1 0-2.12c.59-.59 1.45-.55 2.08.08l.1.09 8.2 8.37a1 1 0 0 0 .97.29 1 1 0 0 0 .46-1.68l-9.52-9.73-.01-.01-1.28-1.29a1.5 1.5 0 0 1 0-2.12c.6-.6 1.47-.58 2.08.03l9.18 9.17a1 1 0 0 0 1.69-.43 1 1 0 0 0-.28-.98L9 14.13l-.06-.07-1.5-1.5c-.6-.6-.6-1.53 0-2.12a1.5 1.5 0 0 1 2.12 0L20.8 21.67a1 1 0 0 0 1.68-.44 1 1 0 0 0-.27-.97l-8.7-8.7-.06-.06a1.4 1.4 0 0 1-.01-2.06c.3-.3.68-.45 1.06-.45ZM4.23 32a1 1 0 0 0-.82 1.51c3 5.18 7.36 9.46 12.59 12.37a1 1 0 0 0 1.51-.89 1 1 0 0 0-.54-.86A31.16 31.16 0 0 1 5.15 32.5a1.01 1.01 0 0 0-.92-.51Z\"/></svg>"),
|
||||
Rules: []string{
|
||||
"||clubhouse.com^",
|
||||
"||clubhouseapi.com^",
|
||||
},
|
||||
}, {
|
||||
ID: "crunchyroll",
|
||||
Name: "Crunchyroll",
|
||||
@@ -726,6 +778,18 @@ var blockedServices = []blockedService{{
|
||||
"||xxbay.com^",
|
||||
"||yibei.org^",
|
||||
},
|
||||
}, {
|
||||
ID: "electronic_arts",
|
||||
Name: "Electronic Arts",
|
||||
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" fill=\"currentColor\" viewBox=\"0 0 1000 1000\"><path d=\"M500 1000C224.3 1000 0 775.7 0 500S224.3 0 500 0s500 224.3 500 500-224.3 500-500 500zm84.63-693.4H302.05l-42.87 68.9h282.25zm57.75.66L469.63 582.33H278.02l44.2-68.96h114.85l43.87-68.93h-265.5l-43.86 68.93h62.9L147.2 651.05h364.2L645.9 438.9l49.05 74.46h-44.23l-41.88 68.96H739.8l45.48 68.72h83.54z\"/></svg>"),
|
||||
Rules: []string{
|
||||
"||ea.com^",
|
||||
"||eamobile.com^",
|
||||
"||easports.com^",
|
||||
"||nearpolar.com^",
|
||||
"||swtor.com^",
|
||||
"||tnt-ea.com^",
|
||||
},
|
||||
}, {
|
||||
ID: "epic_games",
|
||||
Name: "Epic Games",
|
||||
@@ -1390,11 +1454,39 @@ var blockedServices = []blockedService{{
|
||||
"||line-apps.com^",
|
||||
"||line-cdn.net^",
|
||||
"||line-scdn.net^",
|
||||
"||line.biz^",
|
||||
"||line.me^",
|
||||
"||line.naver.jp^",
|
||||
"||linecorp.com^",
|
||||
"||linefriends.com.tw^",
|
||||
"||linefriends.com^",
|
||||
"||linegame.jp^",
|
||||
"||linemobile.com^",
|
||||
"||linemyshop.com^",
|
||||
"||lineshoppingseller.com^",
|
||||
"||linetv.tw^",
|
||||
},
|
||||
}, {
|
||||
ID: "linkedin",
|
||||
Name: "LinkedIn",
|
||||
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" fill=\"currentColor\" viewBox=\"0 0 50 50\"><path d=\"M41,4H9C6.24,4,4,6.24,4,9v32c0,2.76,2.24,5,5,5h32c2.76,0,5-2.24,5-5V9C46,6.24,43.76,4,41,4z M17,20v19h-6V20H17z M11,14.47c0-1.4,1.2-2.47,3-2.47s2.93,1.07,3,2.47c0,1.4-1.12,2.53-3,2.53C12.2,17,11,15.87,11,14.47z M39,39h-6c0,0,0-9.26,0-10 c0-2-1-4-3.5-4.04h-0.08C27,24.96,26,27.02,26,29c0,0.91,0,10,0,10h-6V20h6v2.56c0,0,1.93-2.56,5.81-2.56 c3.97,0,7.19,2.73,7.19,8.26V39z\"/></svg>"),
|
||||
Rules: []string{
|
||||
"||bizographics.com^",
|
||||
"||cs1404.wpc.epsiloncdn.net^",
|
||||
"||cs767.wpc.epsiloncdn.net^",
|
||||
"||l-0005.dc-msedge.net^",
|
||||
"||l-0005.l-dc-msedge.net^",
|
||||
"||l-0005.l-msedge.net^",
|
||||
"||l-0015.l-msedge.net^",
|
||||
"||licdn.cn^",
|
||||
"||licdn.com^",
|
||||
"||linkedin.at^",
|
||||
"||linkedin.be^",
|
||||
"||linkedin.cn^",
|
||||
"||linkedin.com^",
|
||||
"||linkedin.nl^",
|
||||
"||linkedin.qtlcdn.com^",
|
||||
"||lnkd.in^",
|
||||
},
|
||||
}, {
|
||||
ID: "mail_ru",
|
||||
@@ -1438,7 +1530,6 @@ var blockedServices = []blockedService{{
|
||||
"||masto.pt^",
|
||||
"||mastodon.au^",
|
||||
"||mastodon.bida.im^",
|
||||
"||mastodon.com.tr^",
|
||||
"||mastodon.eus^",
|
||||
"||mastodon.green^",
|
||||
"||mastodon.ie^",
|
||||
@@ -1454,7 +1545,7 @@ var blockedServices = []blockedService{{
|
||||
"||mastodon.social^",
|
||||
"||mastodon.uno^",
|
||||
"||mastodon.world^",
|
||||
"||mastodon.xyz^",
|
||||
"||mastodon.zaclys.com^",
|
||||
"||mastodonapp.uk^",
|
||||
"||mastodonners.nl^",
|
||||
"||mastodont.cat^",
|
||||
@@ -1465,12 +1556,12 @@ var blockedServices = []blockedService{{
|
||||
"||metalhead.club^",
|
||||
"||mindly.social^",
|
||||
"||mstdn.ca^",
|
||||
"||mstdn.jp^",
|
||||
"||mstdn.party^",
|
||||
"||mstdn.plus^",
|
||||
"||mstdn.social^",
|
||||
"||muenchen.social^",
|
||||
"||newsie.social^",
|
||||
"||muenster.im^",
|
||||
"||nerdculture.de^",
|
||||
"||noc.social^",
|
||||
"||norden.social^",
|
||||
"||nrw.social^",
|
||||
@@ -1498,16 +1589,17 @@ var blockedServices = []blockedService{{
|
||||
"||techhub.social^",
|
||||
"||theblower.au^",
|
||||
"||tkz.one^",
|
||||
"||todon.eu^",
|
||||
"||toot.aquilenet.fr^",
|
||||
"||toot.community^",
|
||||
"||toot.funami.tech^",
|
||||
"||toot.io^",
|
||||
"||toot.wales^",
|
||||
"||troet.cafe^",
|
||||
"||twingyeo.kr^",
|
||||
"||union.place^",
|
||||
"||universeodon.com^",
|
||||
"||urbanists.social^",
|
||||
"||wien.rocks^",
|
||||
"||wxw.moe^",
|
||||
},
|
||||
}, {
|
||||
@@ -1550,6 +1642,44 @@ var blockedServices = []blockedService{{
|
||||
"||nflxso.net^",
|
||||
"||nflxvideo.net^",
|
||||
},
|
||||
}, {
|
||||
ID: "nintendo",
|
||||
Name: "Nintendo",
|
||||
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" fill=\"currentColor\" viewBox=\"0 0 50 50\"><path d=\"M6 7v36h12.6V21.75l13 20.78.27.47H44V7H31.4v1l.04 20.22L18.5 7.47 18.22 7Zm2 2h9.1l14.5 23.22 1.84 3v-3.5L33.4 9H42v32h-9L18.44 17.75l-1.85-2.94V41H8Z\"/></svg>"),
|
||||
Rules: []string{
|
||||
"||nintendo-europe.com^",
|
||||
"||nintendo.be^",
|
||||
"||nintendo.co.jp^",
|
||||
"||nintendo.co.uk^",
|
||||
"||nintendo.com.au^",
|
||||
"||nintendo.com^",
|
||||
"||nintendo.de^",
|
||||
"||nintendo.es^",
|
||||
"||nintendo.eu^",
|
||||
"||nintendo.fr^",
|
||||
"||nintendo.it^",
|
||||
"||nintendo.jp^",
|
||||
"||nintendo.net^",
|
||||
"||nintendo.nl^",
|
||||
"||nintendoswitch.cn^",
|
||||
"||nintendowifi.net^",
|
||||
},
|
||||
}, {
|
||||
ID: "nvidia",
|
||||
Name: "Nvidia",
|
||||
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" fill=\"currentColor\" viewBox=\"0 0 48 48\"><path d=\"M20 8a2 2 0 0 0-2 2v2.55l.84-.05c10.76-.37 17.78 8.82 17.78 8.82s-8.05 9.8-16.44 9.8c-.73 0-1.47-.07-2.18-.19v-2.2c.73.23 1.52.35 2.3.35 5.88 0 11.35-7.6 11.35-7.6s-5.07-6.91-12.81-6.66l-.82.03v-2.3c-9.49.77-17.68 8.8-17.68 8.8S4.97 34.76 18 35.98v-2.44c.59.07 1.22.12 1.81.12 7.82 0 13.47-3.99 18.94-8.7.91.73 4.62 2.49 5.4 3.26-5.2 4.36-17.33 7.86-24.2 7.86-.66 0-1.32-.03-1.95-.1V38c0 1.1.9 2 2 2h25a2 2 0 0 0 2-2V10a2 2 0 0 0-2-2H20zm-2 6.86v2.82a11.8 11.8 0 0 1 1.57-.07c4.95 0 7.9 3.85 7.9 3.85l-4.03 3.39c-1.8-3.02-2.43-4.35-5.44-4.7v8.57c-4.06-1.38-5.4-6.14-5.4-6.14s2.37-2.83 5.38-2.46H18v-2.44a15.66 15.66 0 0 0-9.22 4.46s2 7.52 9.22 8.8v2.6c-9.56-1.17-12.82-11.7-12.82-11.7s4.27-6.3 12.82-6.97z\"/></svg>"),
|
||||
Rules: []string{
|
||||
"||geforce.com^",
|
||||
"||geforcenow.com^",
|
||||
"||nvidia.cn^",
|
||||
"||nvidia.com.global.ogslb.com^",
|
||||
"||nvidia.com^",
|
||||
"||nvidia.eu^",
|
||||
"||nvidia.partners^",
|
||||
"||nvidiagrid.net^",
|
||||
"||nvidianews.com^",
|
||||
"||tegrazone.com^",
|
||||
},
|
||||
}, {
|
||||
ID: "ok",
|
||||
Name: "OK.ru",
|
||||
@@ -1704,6 +1834,14 @@ var blockedServices = []blockedService{{
|
||||
"||robloxcdn.com^",
|
||||
"||robloxdev.cn^",
|
||||
},
|
||||
}, {
|
||||
ID: "rockstar_games",
|
||||
Name: "Rockstar games",
|
||||
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" fill=\"currentColor\" viewBox=\"0 0 50 50\"><path d=\"M12 3c-4.96 0-9 4.04-9 9v26c0 4.96 4.04 9 9 9h26c4.96 0 9-4.04 9-9V12c0-4.96-4.04-9-9-9H12zm0 2h26c3.88 0 7 3.12 7 7v26c0 3.88-3.12 7-7 7H12c-3.88 0-7-3.12-7-7V12c0-3.88 3.12-7 7-7zm3.72 5a1 1 0 0 0-.97.79l-3.87 18a1 1 0 0 0 .98 1.21h4.27a1 1 0 0 0 .97-.79L18.47 23h2.07c.94 0 1.12.15 1.36.73.24.57.3 1.76.1 3.4-.08.68-.05 1.22.02 1.6v.03a1 1 0 0 0 .3.97l3.37 3.12-2.6 5.74a1 1 0 0 0 1.43 1.26l5.58-3.39 4.29 3.33a1 1 0 0 0 1.6-.98l-1.09-5.56 4.7-3.47a1 1 0 0 0-.6-1.8h-4.86l-.82-5.14a1 1 0 0 0-.98-.84 1 1 0 0 0-.88.51l-2.77 5a14.3 14.3 0 0 1 .06-2.83c.15-1.48.01-2.64-.18-3.45-.06-.28-.08-.25-.15-.45.3-.17.4-.13.77-.5.8-.8 1.6-2.18 1.75-4.26.17-2.26-.55-3.98-1.92-4.9C27.65 10.17 25.91 10 24 10h-8.28zm.81 2H24c1.75 0 3.13.25 3.9.77.76.52 1.18 1.27 1.05 3.1-.13 1.67-.69 2.51-1.17 3a2 2 0 0 1-.82.56 1 1 0 0 0-.6 1.44s.12.21.27.82c.14.6.26 1.53.13 2.79a14.24 14.24 0 0 0-.01 3.52h-2.76c-.01-.19-.04-.32 0-.62.22-1.78.25-3.21-.24-4.42A3.38 3.38 0 0 0 20.54 21h-2.87a1 1 0 0 0-.98.78L15.32 28H13.1l3.44-16zm2.76 1.03a1 1 0 0 0-.98.8l-.98 4.94a1 1 0 0 0 .98 1.2h4.47c.79 0 1.65-.12 2.44-.58a3.6 3.6 0 0 0 1.68-2.41 3.3 3.3 0 0 0-.72-2.92 3.35 3.35 0 0 0-2.47-1.03h-4.42zm.82 2h3.6c.41 0 .79.16 1 .4.22.22.36.52.23 1.15-.13.62-.36.88-.72 1.08a3 3 0 0 1-1.44.3h-3.25l.58-2.93zm11.7 10.99.49 3.11a1 1 0 0 0 .98.84h2.69l-2.76 2.05a1 1 0 0 0-.4 1l.7 3.56-2.73-2.12a1 1 0 0 0-1.13-.07l-3.4 2.07 1.56-3.44a1 1 0 0 0-.23-1.15L25.55 30H29a1 1 0 0 0 .88-.51l1.92-3.47z\"/></svg>"),
|
||||
Rules: []string{
|
||||
"||rockstargames.com^",
|
||||
"||rsg.sc^",
|
||||
},
|
||||
}, {
|
||||
ID: "shopee",
|
||||
Name: "Shopee",
|
||||
@@ -1940,6 +2078,16 @@ var blockedServices = []blockedService{{
|
||||
"||twvid.com^",
|
||||
"||vine.co^",
|
||||
},
|
||||
}, {
|
||||
ID: "ubisoft",
|
||||
Name: "Ubisoft",
|
||||
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" fill=\"currentColor\" viewBox=\"0 0 32 32\"><path d=\"M15.22 3C7.14 3 3.66 10.18 3.66 10.18l1.03.74s-1.3 2.45-1.26 5.6A12.5 12.5 0 0 0 16.08 29a12.5 12.5 0 0 0 12.49-12.46c0-9-6.98-13.54-13.35-13.54zm.07 2.2c6.3 0 11.2 5.07 11.2 10.98 0 6.27-4.71 10.62-10.2 10.62-4.04 0-7.69-3.08-7.69-7.3a5.8 5.8 0 0 1 2.75-5.03l.21.23a6.37 6.37 0 0 0-1.53 3.91c0 3.32 2.6 5.62 5.88 5.62 4.18 0 6.97-3.56 6.97-7.7 0-4.81-4.25-8.9-9.36-8.9a11.1 11.1 0 0 0-6.61 2.3l-.21-.2a10.07 10.07 0 0 1 8.59-4.54zM13.4 9.8c3.26 0 6.44 2.15 7.24 5.22l-.3.1a8.35 8.35 0 0 0-6.52-3.44c-5.08 0-7.75 4.62-7.36 8.47l-.3.12s-.56-1.24-.56-2.71a7.8 7.8 0 0 1 7.8-7.76zm2.15 5.33a2.77 2.77 0 0 1 2.78 2.74c0 1.23-.79 1.96-.79 1.96l.94.65s-.93 1.46-2.82 1.46a3.4 3.4 0 0 1-.1-6.8z\"/></svg>"),
|
||||
Rules: []string{
|
||||
"||ubi.com^",
|
||||
"||ubisoft.com^",
|
||||
"||ubisoft.org^",
|
||||
"||ubisoftconnect.com^",
|
||||
},
|
||||
}, {
|
||||
ID: "valorant",
|
||||
Name: "Valorant",
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
)
|
||||
@@ -22,12 +23,14 @@ type Client struct {
|
||||
safeSearchConf filtering.SafeSearchConfig
|
||||
SafeSearch filtering.SafeSearch
|
||||
|
||||
// BlockedServices is the configuration of blocked services of a client.
|
||||
BlockedServices *filtering.BlockedServices
|
||||
|
||||
Name string
|
||||
|
||||
IDs []string
|
||||
Tags []string
|
||||
BlockedServices []string
|
||||
Upstreams []string
|
||||
IDs []string
|
||||
Tags []string
|
||||
Upstreams []string
|
||||
|
||||
UseOwnSettings bool
|
||||
FilteringEnabled bool
|
||||
@@ -43,9 +46,9 @@ type Client struct {
|
||||
func (c *Client) ShallowClone() (sh *Client) {
|
||||
clone := *c
|
||||
|
||||
clone.BlockedServices = c.BlockedServices.Clone()
|
||||
clone.IDs = stringutil.CloneSlice(c.IDs)
|
||||
clone.Tags = stringutil.CloneSlice(c.Tags)
|
||||
clone.BlockedServices = stringutil.CloneSlice(c.BlockedServices)
|
||||
clone.Upstreams = stringutil.CloneSlice(c.Upstreams)
|
||||
|
||||
return &clone
|
||||
@@ -127,14 +130,13 @@ func (cs clientSource) MarshalText() (text []byte, err error) {
|
||||
// RuntimeClient is a client information about which has been obtained using the
|
||||
// source described in the Source field.
|
||||
type RuntimeClient struct {
|
||||
WHOISInfo *RuntimeClientWHOISInfo
|
||||
Host string
|
||||
Source clientSource
|
||||
}
|
||||
// WHOIS is the filtered WHOIS data of a client.
|
||||
WHOIS *whois.Info
|
||||
|
||||
// RuntimeClientWHOISInfo is the filtered WHOIS data for a runtime client.
|
||||
type RuntimeClientWHOISInfo struct {
|
||||
City string `json:"city,omitempty"`
|
||||
Country string `json:"country,omitempty"`
|
||||
Orgname string `json:"orgname,omitempty"`
|
||||
// Host is the host name of a client.
|
||||
Host string
|
||||
|
||||
// Source is the source from which the information about the client has
|
||||
// been obtained.
|
||||
Source clientSource
|
||||
}
|
||||
|
||||
@@ -11,9 +11,11 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
@@ -23,6 +25,23 @@ import (
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
// DHCP is an interface for accessing DHCP lease data the [clientsContainer]
|
||||
// needs.
|
||||
type DHCP interface {
|
||||
// Leases returns all the DHCP leases.
|
||||
Leases() (leases []*dhcpsvc.Lease)
|
||||
|
||||
// HostByIP returns the hostname of the DHCP client with the given IP
|
||||
// address. The address will be netip.Addr{} if there is no such client,
|
||||
// due to an assumption that a DHCP client must always have an IP address.
|
||||
HostByIP(ip netip.Addr) (host string)
|
||||
|
||||
// MACByIP returns the MAC address for the given IP address leased. It
|
||||
// returns nil if there is no such client, due to an assumption that a DHCP
|
||||
// client must always have a MAC address.
|
||||
MACByIP(ip netip.Addr) (mac net.HardwareAddr)
|
||||
}
|
||||
|
||||
// clientsContainer is the storage of all runtime and persistent clients.
|
||||
type clientsContainer struct {
|
||||
// TODO(a.garipov): Perhaps use a number of separate indices for different
|
||||
@@ -77,7 +96,7 @@ func (clients *clientsContainer) Init(
|
||||
etcHosts *aghnet.HostsContainer,
|
||||
arpdb aghnet.ARPDB,
|
||||
filteringConf *filtering.Config,
|
||||
) {
|
||||
) (err error) {
|
||||
if clients.list != nil {
|
||||
log.Fatal("clients.list != nil")
|
||||
}
|
||||
@@ -91,23 +110,29 @@ func (clients *clientsContainer) Init(
|
||||
clients.dhcpServer = dhcpServer
|
||||
clients.etcHosts = etcHosts
|
||||
clients.arpdb = arpdb
|
||||
clients.addFromConfig(objects, filteringConf)
|
||||
err = clients.addFromConfig(objects, filteringConf)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
clients.safeSearchCacheSize = filteringConf.SafeSearchCacheSize
|
||||
clients.safeSearchCacheTTL = time.Minute * time.Duration(filteringConf.CacheTime)
|
||||
|
||||
if clients.testing {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
clients.updateFromDHCP(true)
|
||||
if clients.dhcpServer != nil {
|
||||
clients.dhcpServer.SetOnLeaseChanged(clients.onDHCPLeaseChanged)
|
||||
clients.onDHCPLeaseChanged(dhcpd.LeaseChangedAdded)
|
||||
}
|
||||
|
||||
if clients.etcHosts != nil {
|
||||
go clients.handleHostsUpdates()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (clients *clientsContainer) handleHostsUpdates() {
|
||||
@@ -147,12 +172,14 @@ func (clients *clientsContainer) reloadARP() {
|
||||
type clientObject struct {
|
||||
SafeSearchConf filtering.SafeSearchConfig `yaml:"safe_search"`
|
||||
|
||||
// BlockedServices is the configuration of blocked services of a client.
|
||||
BlockedServices *filtering.BlockedServices `yaml:"blocked_services"`
|
||||
|
||||
Name string `yaml:"name"`
|
||||
|
||||
Tags []string `yaml:"tags"`
|
||||
IDs []string `yaml:"ids"`
|
||||
BlockedServices []string `yaml:"blocked_services"`
|
||||
Upstreams []string `yaml:"upstreams"`
|
||||
IDs []string `yaml:"ids"`
|
||||
Tags []string `yaml:"tags"`
|
||||
Upstreams []string `yaml:"upstreams"`
|
||||
|
||||
UseGlobalSettings bool `yaml:"use_global_settings"`
|
||||
FilteringEnabled bool `yaml:"filtering_enabled"`
|
||||
@@ -166,7 +193,10 @@ type clientObject struct {
|
||||
|
||||
// addFromConfig initializes the clients container with objects from the
|
||||
// configuration file.
|
||||
func (clients *clientsContainer) addFromConfig(objects []*clientObject, filteringConf *filtering.Config) {
|
||||
func (clients *clientsContainer) addFromConfig(
|
||||
objects []*clientObject,
|
||||
filteringConf *filtering.Config,
|
||||
) (err error) {
|
||||
for _, o := range objects {
|
||||
cli := &Client{
|
||||
Name: o.Name,
|
||||
@@ -187,7 +217,7 @@ func (clients *clientsContainer) addFromConfig(objects []*clientObject, filterin
|
||||
if o.SafeSearchConf.Enabled {
|
||||
o.SafeSearchConf.CustomResolver = safeSearchResolver{}
|
||||
|
||||
err := cli.setSafeSearch(
|
||||
err = cli.setSafeSearch(
|
||||
o.SafeSearchConf,
|
||||
filteringConf.SafeSearchCacheSize,
|
||||
time.Minute*time.Duration(filteringConf.CacheTime),
|
||||
@@ -199,14 +229,13 @@ func (clients *clientsContainer) addFromConfig(objects []*clientObject, filterin
|
||||
}
|
||||
}
|
||||
|
||||
for _, s := range o.BlockedServices {
|
||||
if filtering.BlockedSvcKnown(s) {
|
||||
cli.BlockedServices = append(cli.BlockedServices, s)
|
||||
} else {
|
||||
log.Info("clients: skipping unknown blocked service %q", s)
|
||||
}
|
||||
err = o.BlockedServices.Validate()
|
||||
if err != nil {
|
||||
return fmt.Errorf("clients: init client blocked services %q: %w", cli.Name, err)
|
||||
}
|
||||
|
||||
cli.BlockedServices = o.BlockedServices.Clone()
|
||||
|
||||
for _, t := range o.Tags {
|
||||
if clients.allTags.Has(t) {
|
||||
cli.Tags = append(cli.Tags, t)
|
||||
@@ -217,11 +246,13 @@ func (clients *clientsContainer) addFromConfig(objects []*clientObject, filterin
|
||||
|
||||
slices.Sort(cli.Tags)
|
||||
|
||||
_, err := clients.Add(cli)
|
||||
_, err = clients.Add(cli)
|
||||
if err != nil {
|
||||
log.Error("clients: adding clients %s: %s", cli.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// forConfig returns all currently known persistent clients as objects for the
|
||||
@@ -235,10 +266,11 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) {
|
||||
o := &clientObject{
|
||||
Name: cli.Name,
|
||||
|
||||
Tags: stringutil.CloneSlice(cli.Tags),
|
||||
IDs: stringutil.CloneSlice(cli.IDs),
|
||||
BlockedServices: stringutil.CloneSlice(cli.BlockedServices),
|
||||
Upstreams: stringutil.CloneSlice(cli.Upstreams),
|
||||
BlockedServices: cli.BlockedServices.Clone(),
|
||||
|
||||
IDs: stringutil.CloneSlice(cli.IDs),
|
||||
Tags: stringutil.CloneSlice(cli.Tags),
|
||||
Upstreams: stringutil.CloneSlice(cli.Upstreams),
|
||||
|
||||
UseGlobalSettings: !cli.UseOwnSettings,
|
||||
FilteringEnabled: cli.FilteringEnabled,
|
||||
@@ -276,15 +308,38 @@ func (clients *clientsContainer) periodicUpdate() {
|
||||
}
|
||||
}
|
||||
|
||||
// onDHCPLeaseChanged is a callback for the DHCP server. It updates the list of
|
||||
// runtime clients using the DHCP server's leases.
|
||||
//
|
||||
// TODO(e.burkov): Remove when switched to dhcpsvc.
|
||||
func (clients *clientsContainer) onDHCPLeaseChanged(flags int) {
|
||||
switch flags {
|
||||
case dhcpd.LeaseChangedAdded,
|
||||
dhcpd.LeaseChangedAddedStatic,
|
||||
dhcpd.LeaseChangedRemovedStatic:
|
||||
clients.updateFromDHCP(true)
|
||||
case dhcpd.LeaseChangedRemovedAll:
|
||||
clients.updateFromDHCP(false)
|
||||
if clients.dhcpServer == nil || !config.Clients.Sources.DHCP {
|
||||
return
|
||||
}
|
||||
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
clients.rmHostsBySrc(ClientSourceDHCP)
|
||||
|
||||
if flags == dhcpd.LeaseChangedRemovedAll {
|
||||
return
|
||||
}
|
||||
|
||||
leases := clients.dhcpServer.Leases(dhcpd.LeasesAll)
|
||||
n := 0
|
||||
for _, l := range leases {
|
||||
if l.Hostname == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
ok := clients.addHostLocked(l.IP, l.Hostname, ClientSourceDHCP)
|
||||
if ok {
|
||||
n++
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug("clients: added %d client aliases from dhcp", n)
|
||||
}
|
||||
|
||||
// clientSource checks if client with this IP address already exists and returns
|
||||
@@ -300,23 +355,11 @@ func (clients *clientsContainer) clientSource(ip netip.Addr) (src clientSource)
|
||||
}
|
||||
|
||||
rc, ok := clients.ipToRC[ip]
|
||||
if !ok {
|
||||
return ClientSourceNone
|
||||
if ok {
|
||||
return rc.Source
|
||||
}
|
||||
|
||||
return rc.Source
|
||||
}
|
||||
|
||||
func toQueryLogWHOIS(wi *RuntimeClientWHOISInfo) (cw *querylog.ClientWHOIS) {
|
||||
if wi == nil {
|
||||
return &querylog.ClientWHOIS{}
|
||||
}
|
||||
|
||||
return &querylog.ClientWHOIS{
|
||||
City: wi.City,
|
||||
Country: wi.Country,
|
||||
Orgname: wi.Orgname,
|
||||
}
|
||||
return ClientSourceNone
|
||||
}
|
||||
|
||||
// findMultiple is a wrapper around Find to make it a valid client finder for
|
||||
@@ -352,7 +395,7 @@ func (clients *clientsContainer) clientOrArtificial(
|
||||
defer func() {
|
||||
c.Disallowed, c.DisallowedRule = clients.dnsServer.IsBlockedClient(ip, id)
|
||||
if c.WHOIS == nil {
|
||||
c.WHOIS = &querylog.ClientWHOIS{}
|
||||
c.WHOIS = &whois.Info{}
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -369,7 +412,7 @@ func (clients *clientsContainer) clientOrArtificial(
|
||||
if ok {
|
||||
return &querylog.Client{
|
||||
Name: rc.Host,
|
||||
WHOIS: toQueryLogWHOIS(rc.WHOISInfo),
|
||||
WHOIS: rc.WHOIS,
|
||||
}, false
|
||||
}
|
||||
|
||||
@@ -477,11 +520,11 @@ func (clients *clientsContainer) findLocked(id string) (c *Client, ok bool) {
|
||||
}
|
||||
}
|
||||
|
||||
if clients.dhcpServer == nil {
|
||||
return nil, false
|
||||
if clients.dhcpServer != nil {
|
||||
return clients.findDHCP(ip)
|
||||
}
|
||||
|
||||
return clients.findDHCP(ip)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// findDHCP searches for a client by its MAC, if the DHCP server is active and
|
||||
@@ -701,35 +744,34 @@ func (clients *clientsContainer) Update(prev, c *Client) (err error) {
|
||||
}
|
||||
|
||||
// setWHOISInfo sets the WHOIS information for a client.
|
||||
func (clients *clientsContainer) setWHOISInfo(ip netip.Addr, wi *RuntimeClientWHOISInfo) {
|
||||
func (clients *clientsContainer) setWHOISInfo(ip netip.Addr, wi *whois.Info) {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
_, ok := clients.findLocked(ip.String())
|
||||
if ok {
|
||||
log.Debug("clients: client for %s is already created, ignore whois info", ip)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// TODO(e.burkov): Consider storing WHOIS information separately and
|
||||
// potentially get rid of [RuntimeClient].
|
||||
rc, ok := clients.ipToRC[ip]
|
||||
if ok {
|
||||
rc.WHOISInfo = wi
|
||||
if !ok {
|
||||
// Create a RuntimeClient implicitly so that we don't do this check
|
||||
// again.
|
||||
rc = &RuntimeClient{
|
||||
Source: ClientSourceWHOIS,
|
||||
}
|
||||
clients.ipToRC[ip] = rc
|
||||
|
||||
log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi)
|
||||
} else {
|
||||
log.Debug("clients: set whois info for runtime client %s: %+v", rc.Host, wi)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Create a RuntimeClient implicitly so that we don't do this check
|
||||
// again.
|
||||
rc = &RuntimeClient{
|
||||
Source: ClientSourceWHOIS,
|
||||
}
|
||||
|
||||
rc.WHOISInfo = wi
|
||||
|
||||
clients.ipToRC[ip] = rc
|
||||
|
||||
log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi)
|
||||
rc.WHOIS = wi
|
||||
}
|
||||
|
||||
// AddHost adds a new IP-hostname pairing. The priorities of the sources are
|
||||
@@ -753,23 +795,19 @@ func (clients *clientsContainer) addHostLocked(
|
||||
src clientSource,
|
||||
) (ok bool) {
|
||||
rc, ok := clients.ipToRC[ip]
|
||||
if ok {
|
||||
if rc.Source > src {
|
||||
return false
|
||||
}
|
||||
|
||||
rc.Host = host
|
||||
rc.Source = src
|
||||
} else {
|
||||
if !ok {
|
||||
rc = &RuntimeClient{
|
||||
Host: host,
|
||||
Source: src,
|
||||
WHOISInfo: &RuntimeClientWHOISInfo{},
|
||||
WHOIS: &whois.Info{},
|
||||
}
|
||||
|
||||
clients.ipToRC[ip] = rc
|
||||
} else if src < rc.Source {
|
||||
return false
|
||||
}
|
||||
|
||||
rc.Host = host
|
||||
rc.Source = src
|
||||
|
||||
log.Debug("clients: added %s -> %q [%d]", ip, host, len(clients.ipToRC))
|
||||
|
||||
return true
|
||||
@@ -838,38 +876,6 @@ func (clients *clientsContainer) addFromSystemARP() {
|
||||
log.Debug("clients: added %d client aliases from arp neighborhood", added)
|
||||
}
|
||||
|
||||
// updateFromDHCP adds the clients that have a non-empty hostname from the DHCP
|
||||
// server.
|
||||
func (clients *clientsContainer) updateFromDHCP(add bool) {
|
||||
if clients.dhcpServer == nil || !config.Clients.Sources.DHCP {
|
||||
return
|
||||
}
|
||||
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
clients.rmHostsBySrc(ClientSourceDHCP)
|
||||
|
||||
if !add {
|
||||
return
|
||||
}
|
||||
|
||||
leases := clients.dhcpServer.Leases(dhcpd.LeasesAll)
|
||||
n := 0
|
||||
for _, l := range leases {
|
||||
if l.Hostname == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
ok := clients.addHostLocked(l.IP, l.Hostname, ClientSourceDHCP)
|
||||
if ok {
|
||||
n++
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug("clients: added %d client aliases from dhcp", n)
|
||||
}
|
||||
|
||||
// close gracefully closes all the client-specific upstream configurations of
|
||||
// the persistent clients.
|
||||
func (clients *clientsContainer) close() (err error) {
|
||||
|
||||
@@ -9,25 +9,26 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// newClientsContainer is a helper that creates a new clients container for
|
||||
// tests.
|
||||
func newClientsContainer() (c *clientsContainer) {
|
||||
func newClientsContainer(t *testing.T) (c *clientsContainer) {
|
||||
c = &clientsContainer{
|
||||
testing: true,
|
||||
}
|
||||
|
||||
c.Init(nil, nil, nil, nil, &filtering.Config{})
|
||||
err := c.Init(nil, nil, nil, nil, &filtering.Config{})
|
||||
require.NoError(t, err)
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func TestClients(t *testing.T) {
|
||||
clients := newClientsContainer()
|
||||
clients := newClientsContainer(t)
|
||||
|
||||
t.Run("add_success", func(t *testing.T) {
|
||||
var (
|
||||
@@ -198,8 +199,8 @@ func TestClients(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientsWHOIS(t *testing.T) {
|
||||
clients := newClientsContainer()
|
||||
whois := &RuntimeClientWHOISInfo{
|
||||
clients := newClientsContainer(t)
|
||||
whois := &whois.Info{
|
||||
Country: "AU",
|
||||
Orgname: "Example Org",
|
||||
}
|
||||
@@ -210,7 +211,7 @@ func TestClientsWHOIS(t *testing.T) {
|
||||
rc := clients.ipToRC[ip]
|
||||
require.NotNil(t, rc)
|
||||
|
||||
assert.Equal(t, rc.WHOISInfo, whois)
|
||||
assert.Equal(t, rc.WHOIS, whois)
|
||||
})
|
||||
|
||||
t.Run("existing_auto-client", func(t *testing.T) {
|
||||
@@ -222,7 +223,7 @@ func TestClientsWHOIS(t *testing.T) {
|
||||
rc := clients.ipToRC[ip]
|
||||
require.NotNil(t, rc)
|
||||
|
||||
assert.Equal(t, rc.WHOISInfo, whois)
|
||||
assert.Equal(t, rc.WHOIS, whois)
|
||||
})
|
||||
|
||||
t.Run("can't_set_manually-added", func(t *testing.T) {
|
||||
@@ -244,7 +245,7 @@ func TestClientsWHOIS(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientsAddExisting(t *testing.T) {
|
||||
clients := newClientsContainer()
|
||||
clients := newClientsContainer(t)
|
||||
|
||||
t.Run("simple", func(t *testing.T) {
|
||||
ip := netip.MustParseAddr("1.1.1.1")
|
||||
@@ -316,7 +317,7 @@ func TestClientsAddExisting(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientsCustomUpstream(t *testing.T) {
|
||||
clients := newClientsContainer()
|
||||
clients := newClientsContainer(t)
|
||||
|
||||
// Add client with upstreams.
|
||||
ok, err := clients.Add(&Client{
|
||||
|
||||
@@ -9,6 +9,8 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||
)
|
||||
|
||||
// clientJSON is a common structure used by several handlers to deal with
|
||||
@@ -28,7 +30,8 @@ type clientJSON struct {
|
||||
// the allowlist.
|
||||
DisallowedRule *string `json:"disallowed_rule,omitempty"`
|
||||
|
||||
WHOISInfo *RuntimeClientWHOISInfo `json:"whois_info,omitempty"`
|
||||
// WHOIS is the filtered WHOIS data of a client.
|
||||
WHOIS *whois.Info `json:"whois_info,omitempty"`
|
||||
SafeSearchConf *filtering.SafeSearchConfig `json:"safe_search"`
|
||||
|
||||
Name string `json:"name"`
|
||||
@@ -51,7 +54,7 @@ type clientJSON struct {
|
||||
}
|
||||
|
||||
type runtimeClientJSON struct {
|
||||
WHOISInfo *RuntimeClientWHOISInfo `json:"whois_info"`
|
||||
WHOIS *whois.Info `json:"whois_info"`
|
||||
|
||||
IP netip.Addr `json:"ip"`
|
||||
Name string `json:"name"`
|
||||
@@ -78,7 +81,7 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
|
||||
|
||||
for ip, rc := range clients.ipToRC {
|
||||
cj := runtimeClientJSON{
|
||||
WHOISInfo: rc.WHOISInfo,
|
||||
WHOIS: rc.WHOIS,
|
||||
|
||||
Name: rc.Host,
|
||||
Source: rc.Source,
|
||||
@@ -116,15 +119,24 @@ func (clients *clientsContainer) jsonToClient(cj clientJSON, prev *Client) (c *C
|
||||
}
|
||||
}
|
||||
|
||||
weekly := schedule.EmptyWeekly()
|
||||
if prev != nil {
|
||||
weekly = prev.BlockedServices.Schedule.Clone()
|
||||
}
|
||||
|
||||
c = &Client{
|
||||
safeSearchConf: safeSearchConf,
|
||||
|
||||
Name: cj.Name,
|
||||
|
||||
IDs: cj.IDs,
|
||||
Tags: cj.Tags,
|
||||
BlockedServices: cj.BlockedServices,
|
||||
Upstreams: cj.Upstreams,
|
||||
BlockedServices: &filtering.BlockedServices{
|
||||
Schedule: weekly,
|
||||
IDs: cj.BlockedServices,
|
||||
},
|
||||
|
||||
IDs: cj.IDs,
|
||||
Tags: cj.Tags,
|
||||
Upstreams: cj.Upstreams,
|
||||
|
||||
UseOwnSettings: !cj.UseGlobalSettings,
|
||||
FilteringEnabled: cj.FilteringEnabled,
|
||||
@@ -178,7 +190,8 @@ func clientToJSON(c *Client) (cj *clientJSON) {
|
||||
SafeBrowsingEnabled: c.SafeBrowsingEnabled,
|
||||
|
||||
UseGlobalBlockedServices: !c.UseOwnBlockedServices,
|
||||
BlockedServices: c.BlockedServices,
|
||||
|
||||
BlockedServices: c.BlockedServices.IDs,
|
||||
|
||||
Upstreams: c.Upstreams,
|
||||
|
||||
@@ -344,16 +357,16 @@ func (clients *clientsContainer) findRuntime(ip netip.Addr, idStr string) (cj *c
|
||||
IDs: []string{idStr},
|
||||
Disallowed: &disallowed,
|
||||
DisallowedRule: &rule,
|
||||
WHOISInfo: &RuntimeClientWHOISInfo{},
|
||||
WHOIS: &whois.Info{},
|
||||
}
|
||||
|
||||
return cj
|
||||
}
|
||||
|
||||
cj = &clientJSON{
|
||||
Name: rc.Host,
|
||||
IDs: []string{idStr},
|
||||
WHOISInfo: rc.WHOISInfo,
|
||||
Name: rc.Host,
|
||||
IDs: []string{idStr},
|
||||
WHOIS: rc.WHOIS,
|
||||
}
|
||||
|
||||
disallowed, rule := clients.dnsServer.IsBlockedClient(ip, idStr)
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
||||
"github.com/AdguardTeam/dnsproxy/fastip"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
@@ -90,18 +91,17 @@ type clientSourcesConfig struct {
|
||||
HostsFile bool `yaml:"hosts"`
|
||||
}
|
||||
|
||||
// configuration is loaded from YAML
|
||||
// field ordering is important -- yaml fields will mirror ordering from here
|
||||
// configuration is loaded from YAML.
|
||||
//
|
||||
// Field ordering is important, YAML fields better not to be reordered, if it's
|
||||
// not absolutely necessary.
|
||||
type configuration struct {
|
||||
// Raw file data to avoid re-reading of configuration file
|
||||
// It's reset after config is parsed
|
||||
fileData []byte
|
||||
|
||||
// BindHost is the address for the web interface server to listen on.
|
||||
BindHost netip.Addr `yaml:"bind_host"`
|
||||
// BindPort is the port for the web interface server to listen on.
|
||||
BindPort int `yaml:"bind_port"`
|
||||
|
||||
// HTTPConfig is the block with http conf.
|
||||
HTTPConfig httpConfig `yaml:"http"`
|
||||
// Users are the clients capable for accessing the web interface.
|
||||
Users []webUser `yaml:"users"`
|
||||
// AuthAttempts is the maximum number of failed login attempts a user
|
||||
@@ -119,10 +119,6 @@ type configuration struct {
|
||||
// DebugPProf defines if the profiling HTTP handler will listen on :6060.
|
||||
DebugPProf bool `yaml:"debug_pprof"`
|
||||
|
||||
// TTL for a web session (in hours)
|
||||
// An active session is automatically refreshed once a day.
|
||||
WebSessionTTLHours uint32 `yaml:"web_session_ttl"`
|
||||
|
||||
DNS dnsConfig `yaml:"dns"`
|
||||
TLS tlsConfigSettings `yaml:"tls"`
|
||||
QueryLog queryLogConfig `yaml:"querylog"`
|
||||
@@ -155,7 +151,23 @@ type configuration struct {
|
||||
SchemaVersion int `yaml:"schema_version"` // keeping last so that users will be less tempted to change it -- used when upgrading between versions
|
||||
}
|
||||
|
||||
// field ordering is important -- yaml fields will mirror ordering from here
|
||||
// httpConfig is a block with HTTP configuration params.
|
||||
//
|
||||
// Field ordering is important, YAML fields better not to be reordered, if it's
|
||||
// not absolutely necessary.
|
||||
type httpConfig struct {
|
||||
// Address is the address to serve the web UI on.
|
||||
Address netip.AddrPort
|
||||
|
||||
// SessionTTL for a web session.
|
||||
// An active session is automatically refreshed once a day.
|
||||
SessionTTL timeutil.Duration `yaml:"session_ttl"`
|
||||
}
|
||||
|
||||
// dnsConfig is a block with DNS configuration params.
|
||||
//
|
||||
// Field ordering is important, YAML fields better not to be reordered, if it's
|
||||
// not absolutely necessary.
|
||||
type dnsConfig struct {
|
||||
BindHosts []netip.Addr `yaml:"bind_hosts"`
|
||||
Port int `yaml:"port"`
|
||||
@@ -260,11 +272,12 @@ type statsConfig struct {
|
||||
//
|
||||
// TODO(a.garipov, e.burkov): This global is awful and must be removed.
|
||||
var config = &configuration{
|
||||
BindPort: 3000,
|
||||
BindHost: netip.IPv4Unspecified(),
|
||||
AuthAttempts: 5,
|
||||
AuthBlockMin: 15,
|
||||
WebSessionTTLHours: 30 * 24,
|
||||
AuthAttempts: 5,
|
||||
AuthBlockMin: 15,
|
||||
HTTPConfig: httpConfig{
|
||||
Address: netip.AddrPortFrom(netip.IPv4Unspecified(), 3000),
|
||||
SessionTTL: timeutil.Duration{Duration: 30 * timeutil.Day},
|
||||
},
|
||||
DNS: dnsConfig{
|
||||
BindHosts: []netip.Addr{netip.IPv4Unspecified()},
|
||||
Port: defaultPortDNS,
|
||||
@@ -316,6 +329,11 @@ var config = &configuration{
|
||||
Yandex: true,
|
||||
YouTube: true,
|
||||
},
|
||||
|
||||
BlockedServices: &filtering.BlockedServices{
|
||||
Schedule: schedule.EmptyWeekly(),
|
||||
IDs: []string{},
|
||||
},
|
||||
},
|
||||
UpstreamTimeout: timeutil.Duration{Duration: dnsforward.DefaultTimeout},
|
||||
UsePrivateRDNS: true,
|
||||
@@ -421,8 +439,8 @@ func readLogSettings() (ls *logSettings) {
|
||||
// validateBindHosts returns error if any of binding hosts from configuration is
|
||||
// not a valid IP address.
|
||||
func validateBindHosts(conf *configuration) (err error) {
|
||||
if !conf.BindHost.IsValid() {
|
||||
return errors.Error("bind_host is not a valid ip address")
|
||||
if !conf.HTTPConfig.Address.IsValid() {
|
||||
return errors.Error("http.address is not a valid ip address")
|
||||
}
|
||||
|
||||
for i, addr := range conf.DNS.BindHosts {
|
||||
@@ -456,7 +474,7 @@ func parseConfig() (err error) {
|
||||
}
|
||||
|
||||
tcpPorts := aghalg.UniqChecker[tcpPort]{}
|
||||
addPorts(tcpPorts, tcpPort(config.BindPort))
|
||||
addPorts(tcpPorts, tcpPort(config.HTTPConfig.Address.Port()))
|
||||
|
||||
udpPorts := aghalg.UniqChecker[udpPort]{}
|
||||
addPorts(udpPorts, udpPort(config.DNS.Port))
|
||||
|
||||
@@ -103,7 +103,7 @@ type statusResponse struct {
|
||||
Language string `json:"language"`
|
||||
DNSAddrs []string `json:"dns_addresses"`
|
||||
DNSPort int `json:"dns_port"`
|
||||
HTTPPort int `json:"http_port"`
|
||||
HTTPPort uint16 `json:"http_port"`
|
||||
|
||||
// ProtectionDisabledDuration is the duration of the protection pause in
|
||||
// milliseconds.
|
||||
@@ -158,7 +158,7 @@ func handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||
Language: config.Language,
|
||||
DNSAddrs: dnsAddrs,
|
||||
DNSPort: config.DNS.Port,
|
||||
HTTPPort: config.BindPort,
|
||||
HTTPPort: config.HTTPConfig.Address.Port(),
|
||||
ProtectionDisabledDuration: protectionDisabledDuration,
|
||||
ProtectionEnabled: protectionEnabled,
|
||||
IsRunning: isRunning(),
|
||||
|
||||
@@ -96,8 +96,9 @@ type checkConfResp struct {
|
||||
func (req *checkConfReq) validateWeb(tcpPorts aghalg.UniqChecker[tcpPort]) (err error) {
|
||||
defer func() { err = errors.Annotate(err, "validating ports: %w") }()
|
||||
|
||||
portInt := req.Web.Port
|
||||
port := tcpPort(portInt)
|
||||
// TODO(a.garipov): Declare all port variables anywhere as uint16.
|
||||
reqPort := uint16(req.Web.Port)
|
||||
port := tcpPort(reqPort)
|
||||
addPorts(tcpPorts, port)
|
||||
if err = tcpPorts.Validate(); err != nil {
|
||||
// Reset the value for the port to 1 to make sure that validateDNS
|
||||
@@ -108,15 +109,15 @@ func (req *checkConfReq) validateWeb(tcpPorts aghalg.UniqChecker[tcpPort]) (err
|
||||
return err
|
||||
}
|
||||
|
||||
switch portInt {
|
||||
case 0, config.BindPort:
|
||||
switch reqPort {
|
||||
case 0, config.HTTPConfig.Address.Port():
|
||||
return nil
|
||||
default:
|
||||
// Go on and check the port binding only if it's not zero or won't be
|
||||
// unbound after install.
|
||||
}
|
||||
|
||||
return aghnet.CheckPort("tcp", netip.AddrPortFrom(req.Web.IP, uint16(portInt)))
|
||||
return aghnet.CheckPort("tcp", netip.AddrPortFrom(req.Web.IP, reqPort))
|
||||
}
|
||||
|
||||
// validateDNS returns error if the DNS part of the initial configuration can't
|
||||
@@ -127,11 +128,11 @@ func (req *checkConfReq) validateDNS(
|
||||
) (canAutofix bool, err error) {
|
||||
defer func() { err = errors.Annotate(err, "validating ports: %w") }()
|
||||
|
||||
port := req.DNS.Port
|
||||
port := uint16(req.DNS.Port)
|
||||
switch port {
|
||||
case 0:
|
||||
return false, nil
|
||||
case config.BindPort:
|
||||
case config.HTTPConfig.Address.Port():
|
||||
// Go on and only check the UDP port since the TCP one is already bound
|
||||
// by AdGuard Home for web interface.
|
||||
default:
|
||||
@@ -318,8 +319,7 @@ type applyConfigReq struct {
|
||||
// copyInstallSettings copies the installation parameters between two
|
||||
// configuration structures.
|
||||
func copyInstallSettings(dst, src *configuration) {
|
||||
dst.BindHost = src.BindHost
|
||||
dst.BindPort = src.BindPort
|
||||
dst.HTTPConfig = src.HTTPConfig
|
||||
dst.DNS.BindHosts = src.DNS.BindHosts
|
||||
dst.DNS.Port = src.DNS.Port
|
||||
}
|
||||
@@ -413,8 +413,7 @@ func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request
|
||||
copyInstallSettings(curConfig, config)
|
||||
|
||||
Context.firstRun = false
|
||||
config.BindHost = req.Web.IP
|
||||
config.BindPort = req.Web.Port
|
||||
config.HTTPConfig.Address = netip.AddrPortFrom(req.Web.IP, uint16(req.Web.Port))
|
||||
config.DNS.BindHosts = []netip.Addr{req.DNS.IP}
|
||||
config.DNS.Port = req.DNS.Port
|
||||
|
||||
@@ -487,7 +486,8 @@ func decodeApplyConfigReq(r io.Reader) (req *applyConfigReq, restartHTTP bool, e
|
||||
return nil, false, errors.Error("ports cannot be 0")
|
||||
}
|
||||
|
||||
restartHTTP = config.BindHost != req.Web.IP || config.BindPort != req.Web.Port
|
||||
addrPort := config.HTTPConfig.Address
|
||||
restartHTTP = addrPort.Addr() != req.Web.IP || int(addrPort.Port()) != req.Web.Port
|
||||
if restartHTTP {
|
||||
err = aghnet.CheckPort("tcp", netip.AddrPortFrom(req.Web.IP, uint16(req.Web.Port)))
|
||||
if err != nil {
|
||||
|
||||
@@ -157,7 +157,9 @@ func (vr *versionResponse) setAllowedToAutoUpdate() (err error) {
|
||||
Context.tls.WriteDiskConfig(tlsConf)
|
||||
|
||||
canUpdate := true
|
||||
if tlsConfUsesPrivilegedPorts(tlsConf) || config.BindPort < 1024 || config.DNS.Port < 1024 {
|
||||
if tlsConfUsesPrivilegedPorts(tlsConf) ||
|
||||
config.HTTPConfig.Address.Port() < 1024 ||
|
||||
config.DNS.Port < 1024 {
|
||||
canUpdate, err = aghnet.CanBindPrivilegedPorts()
|
||||
if err != nil {
|
||||
return fmt.Errorf("checking ability to bind privileged ports: %w", err)
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
@@ -17,6 +18,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
@@ -25,7 +27,7 @@ import (
|
||||
yaml "gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// Default ports.
|
||||
// Default listening ports.
|
||||
const (
|
||||
defaultPortDNS = 53
|
||||
defaultPortHTTP = 80
|
||||
@@ -169,13 +171,72 @@ func initDNSServer(
|
||||
Context.rdns = NewRDNS(Context.dnsServer, &Context.clients, config.DNS.UsePrivateRDNS)
|
||||
}
|
||||
|
||||
if config.Clients.Sources.WHOIS {
|
||||
Context.whois = initWHOIS(&Context.clients)
|
||||
}
|
||||
initWHOIS()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// initWHOIS initializes the WHOIS.
|
||||
//
|
||||
// TODO(s.chzhen): Consider making configurable.
|
||||
func initWHOIS() {
|
||||
const (
|
||||
// defaultQueueSize is the size of queue of IPs for WHOIS processing.
|
||||
defaultQueueSize = 255
|
||||
|
||||
// defaultTimeout is the timeout for WHOIS requests.
|
||||
defaultTimeout = 5 * time.Second
|
||||
|
||||
// defaultCacheSize is the maximum size of the cache. If it's zero,
|
||||
// cache size is unlimited.
|
||||
defaultCacheSize = 10_000
|
||||
|
||||
// defaultMaxConnReadSize is an upper limit in bytes for reading from
|
||||
// net.Conn.
|
||||
defaultMaxConnReadSize = 64 * 1024
|
||||
|
||||
// defaultMaxRedirects is the maximum redirects count.
|
||||
defaultMaxRedirects = 5
|
||||
|
||||
// defaultMaxInfoLen is the maximum length of whois.Info fields.
|
||||
defaultMaxInfoLen = 250
|
||||
|
||||
// defaultIPTTL is the Time to Live duration for cached IP addresses.
|
||||
defaultIPTTL = 1 * time.Hour
|
||||
)
|
||||
|
||||
Context.whoisCh = make(chan netip.Addr, defaultQueueSize)
|
||||
|
||||
var w whois.Interface
|
||||
|
||||
if config.Clients.Sources.WHOIS {
|
||||
w = whois.New(&whois.Config{
|
||||
DialContext: customDialContext,
|
||||
ServerAddr: whois.DefaultServer,
|
||||
Port: whois.DefaultPort,
|
||||
Timeout: defaultTimeout,
|
||||
CacheSize: defaultCacheSize,
|
||||
MaxConnReadSize: defaultMaxConnReadSize,
|
||||
MaxRedirects: defaultMaxRedirects,
|
||||
MaxInfoLen: defaultMaxInfoLen,
|
||||
CacheTTL: defaultIPTTL,
|
||||
})
|
||||
} else {
|
||||
w = whois.Empty{}
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer log.OnPanic("whois")
|
||||
|
||||
for ip := range Context.whoisCh {
|
||||
info, changed := w.Process(context.Background(), ip)
|
||||
if info != nil && changed {
|
||||
Context.clients.setWHOISInfo(ip, info)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// parseSubnetSet parses a slice of subnets. If the slice is empty, it returns
|
||||
// a subnet set that matches all locally served networks, see
|
||||
// [netutil.IsLocallyServed].
|
||||
@@ -218,9 +279,7 @@ func onDNSRequest(pctx *proxy.DNSContext) {
|
||||
Context.rdns.Begin(ip)
|
||||
}
|
||||
|
||||
if srcs.WHOIS && !netutil.IsSpecialPurposeAddr(ip) {
|
||||
Context.whois.Begin(ip)
|
||||
}
|
||||
Context.whoisCh <- ip
|
||||
}
|
||||
|
||||
func ipsToTCPAddrs(ips []netip.Addr, port int) (tcpAddrs []*net.TCPAddr) {
|
||||
@@ -390,7 +449,7 @@ func applyAdditionalFiltering(clientIP net.IP, clientID string, setts *filtering
|
||||
// pref is a prefix for logging messages around the scope.
|
||||
const pref = "applying filters"
|
||||
|
||||
Context.filters.ApplyBlockedServices(setts, nil)
|
||||
Context.filters.ApplyBlockedServices(setts)
|
||||
|
||||
log.Debug("%s: looking for client with ip %s and clientid %q", pref, clientIP, clientID)
|
||||
|
||||
@@ -414,12 +473,12 @@ func applyAdditionalFiltering(clientIP net.IP, clientID string, setts *filtering
|
||||
|
||||
if c.UseOwnBlockedServices {
|
||||
// TODO(e.burkov): Get rid of this crutch.
|
||||
svcs := c.BlockedServices
|
||||
if svcs == nil {
|
||||
svcs = []string{}
|
||||
setts.ServicesRules = nil
|
||||
svcs := c.BlockedServices.IDs
|
||||
if !c.BlockedServices.Schedule.Contains(time.Now()) {
|
||||
Context.filters.ApplyBlockedServicesList(setts, svcs)
|
||||
log.Debug("%s: services for client %q set: %s", pref, c.Name, svcs)
|
||||
}
|
||||
Context.filters.ApplyBlockedServices(setts, svcs)
|
||||
log.Debug("%s: services for client %q set: %s", pref, c.Name, svcs)
|
||||
}
|
||||
|
||||
setts.ClientName = c.Name
|
||||
@@ -463,9 +522,7 @@ func startDNSServer() error {
|
||||
Context.rdns.Begin(ip)
|
||||
}
|
||||
|
||||
if srcs.WHOIS && !netutil.IsSpecialPurposeAddr(ip) {
|
||||
Context.whois.Begin(ip)
|
||||
}
|
||||
Context.whoisCh <- ip
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
176
internal/home/dns_internal_test.go
Normal file
176
internal/home/dns_internal_test.go
Normal file
@@ -0,0 +1,176 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestApplyAdditionalFiltering(t *testing.T) {
|
||||
var err error
|
||||
|
||||
Context.filters, err = filtering.New(&filtering.Config{
|
||||
BlockedServices: &filtering.BlockedServices{
|
||||
Schedule: schedule.EmptyWeekly(),
|
||||
},
|
||||
}, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
Context.clients.idIndex = map[string]*Client{
|
||||
"default": {
|
||||
UseOwnSettings: false,
|
||||
safeSearchConf: filtering.SafeSearchConfig{Enabled: false},
|
||||
FilteringEnabled: false,
|
||||
SafeBrowsingEnabled: false,
|
||||
ParentalEnabled: false,
|
||||
},
|
||||
"custom_filtering": {
|
||||
UseOwnSettings: true,
|
||||
safeSearchConf: filtering.SafeSearchConfig{Enabled: true},
|
||||
FilteringEnabled: true,
|
||||
SafeBrowsingEnabled: true,
|
||||
ParentalEnabled: true,
|
||||
},
|
||||
"partial_custom_filtering": {
|
||||
UseOwnSettings: true,
|
||||
safeSearchConf: filtering.SafeSearchConfig{Enabled: true},
|
||||
FilteringEnabled: true,
|
||||
SafeBrowsingEnabled: false,
|
||||
ParentalEnabled: false,
|
||||
},
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
id string
|
||||
FilteringEnabled assert.BoolAssertionFunc
|
||||
SafeSearchEnabled assert.BoolAssertionFunc
|
||||
SafeBrowsingEnabled assert.BoolAssertionFunc
|
||||
ParentalEnabled assert.BoolAssertionFunc
|
||||
}{{
|
||||
name: "global_settings",
|
||||
id: "default",
|
||||
FilteringEnabled: assert.False,
|
||||
SafeSearchEnabled: assert.False,
|
||||
SafeBrowsingEnabled: assert.False,
|
||||
ParentalEnabled: assert.False,
|
||||
}, {
|
||||
name: "custom_settings",
|
||||
id: "custom_filtering",
|
||||
FilteringEnabled: assert.True,
|
||||
SafeSearchEnabled: assert.True,
|
||||
SafeBrowsingEnabled: assert.True,
|
||||
ParentalEnabled: assert.True,
|
||||
}, {
|
||||
name: "partial",
|
||||
id: "partial_custom_filtering",
|
||||
FilteringEnabled: assert.True,
|
||||
SafeSearchEnabled: assert.True,
|
||||
SafeBrowsingEnabled: assert.False,
|
||||
ParentalEnabled: assert.False,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
setts := &filtering.Settings{}
|
||||
|
||||
applyAdditionalFiltering(net.IP{1, 2, 3, 4}, tc.id, setts)
|
||||
tc.FilteringEnabled(t, setts.FilteringEnabled)
|
||||
tc.SafeSearchEnabled(t, setts.SafeSearchEnabled)
|
||||
tc.SafeBrowsingEnabled(t, setts.SafeBrowsingEnabled)
|
||||
tc.ParentalEnabled(t, setts.ParentalEnabled)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyAdditionalFiltering_blockedServices(t *testing.T) {
|
||||
filtering.InitModule()
|
||||
|
||||
var (
|
||||
globalBlockedServices = []string{"ok"}
|
||||
clientBlockedServices = []string{"ok", "mail_ru", "vk"}
|
||||
invalidBlockedServices = []string{"invalid"}
|
||||
|
||||
err error
|
||||
)
|
||||
|
||||
Context.filters, err = filtering.New(&filtering.Config{
|
||||
BlockedServices: &filtering.BlockedServices{
|
||||
Schedule: schedule.EmptyWeekly(),
|
||||
IDs: globalBlockedServices,
|
||||
},
|
||||
}, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
Context.clients.idIndex = map[string]*Client{
|
||||
"default": {
|
||||
UseOwnBlockedServices: false,
|
||||
},
|
||||
"no_services": {
|
||||
BlockedServices: &filtering.BlockedServices{
|
||||
Schedule: schedule.EmptyWeekly(),
|
||||
},
|
||||
UseOwnBlockedServices: true,
|
||||
},
|
||||
"services": {
|
||||
BlockedServices: &filtering.BlockedServices{
|
||||
Schedule: schedule.EmptyWeekly(),
|
||||
IDs: clientBlockedServices,
|
||||
},
|
||||
UseOwnBlockedServices: true,
|
||||
},
|
||||
"invalid_services": {
|
||||
BlockedServices: &filtering.BlockedServices{
|
||||
Schedule: schedule.EmptyWeekly(),
|
||||
IDs: invalidBlockedServices,
|
||||
},
|
||||
UseOwnBlockedServices: true,
|
||||
},
|
||||
"allow_all": {
|
||||
BlockedServices: &filtering.BlockedServices{
|
||||
Schedule: schedule.FullWeekly(),
|
||||
IDs: clientBlockedServices,
|
||||
},
|
||||
UseOwnBlockedServices: true,
|
||||
},
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
id string
|
||||
wantLen int
|
||||
}{{
|
||||
name: "global_settings",
|
||||
id: "default",
|
||||
wantLen: len(globalBlockedServices),
|
||||
}, {
|
||||
name: "custom_settings",
|
||||
id: "no_services",
|
||||
wantLen: 0,
|
||||
}, {
|
||||
name: "custom_settings_block",
|
||||
id: "services",
|
||||
wantLen: len(clientBlockedServices),
|
||||
}, {
|
||||
name: "custom_settings_invalid",
|
||||
id: "invalid_services",
|
||||
wantLen: 0,
|
||||
}, {
|
||||
name: "custom_settings_inactive_schedule",
|
||||
id: "allow_all",
|
||||
wantLen: 0,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
setts := &filtering.Settings{}
|
||||
|
||||
applyAdditionalFiltering(net.IP{1, 2, 3, 4}, tc.id, setts)
|
||||
require.Len(t, setts.ServicesRules, tc.wantLen)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -57,7 +57,6 @@ type homeContext struct {
|
||||
queryLog querylog.QueryLog // query log module
|
||||
dnsServer *dnsforward.Server // DNS module
|
||||
rdns *RDNS // rDNS module
|
||||
whois *WHOIS // WHOIS module
|
||||
dhcpServer dhcpd.Interface // DHCP module
|
||||
auth *Auth // HTTP authentication module
|
||||
filters *filtering.DNSFilter // DNS filtering module
|
||||
@@ -84,6 +83,9 @@ type homeContext struct {
|
||||
client *http.Client
|
||||
appSignalChannel chan os.Signal // Channel for receiving OS signals by the console app
|
||||
|
||||
// whoisCh is the channel for receiving IPs for WHOIS processing.
|
||||
whoisCh chan netip.Addr
|
||||
|
||||
// tlsCipherIDs are the ID of the cipher suites that AdGuard Home must use.
|
||||
tlsCipherIDs []uint16
|
||||
|
||||
@@ -353,21 +355,43 @@ func initContextClients() (err error) {
|
||||
arpdb = aghnet.NewARPDB()
|
||||
}
|
||||
|
||||
Context.clients.Init(
|
||||
err = Context.clients.Init(
|
||||
config.Clients.Persistent,
|
||||
Context.dhcpServer,
|
||||
Context.etcHosts,
|
||||
arpdb,
|
||||
config.DNS.DnsfilterConf,
|
||||
)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupBindOpts overrides bind host/port from the opts.
|
||||
func setupBindOpts(opts options) (err error) {
|
||||
bindAddr := opts.bindAddr
|
||||
if bindAddr != (netip.AddrPort{}) {
|
||||
config.HTTPConfig.Address = bindAddr
|
||||
|
||||
if config.HTTPConfig.Address.Port() != 0 {
|
||||
err = checkPorts()
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if opts.bindPort != 0 {
|
||||
config.BindPort = opts.bindPort
|
||||
config.HTTPConfig.Address = netip.AddrPortFrom(
|
||||
config.HTTPConfig.Address.Addr(),
|
||||
uint16(opts.bindPort),
|
||||
)
|
||||
|
||||
err = checkPorts()
|
||||
if err != nil {
|
||||
@@ -377,7 +401,10 @@ func setupBindOpts(opts options) (err error) {
|
||||
}
|
||||
|
||||
if opts.bindHost.IsValid() {
|
||||
config.BindHost = opts.bindHost
|
||||
config.HTTPConfig.Address = netip.AddrPortFrom(
|
||||
opts.bindHost,
|
||||
config.HTTPConfig.Address.Port(),
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -461,7 +488,7 @@ func setupDNSFilteringConf(conf *filtering.Config) (err error) {
|
||||
// checkPorts is a helper for ports validation in config.
|
||||
func checkPorts() (err error) {
|
||||
tcpPorts := aghalg.UniqChecker[tcpPort]{}
|
||||
addPorts(tcpPorts, tcpPort(config.BindPort))
|
||||
addPorts(tcpPorts, tcpPort(config.HTTPConfig.Address.Port()))
|
||||
|
||||
udpPorts := aghalg.UniqChecker[udpPort]{}
|
||||
addPorts(udpPorts, udpPort(config.DNS.Port))
|
||||
@@ -501,8 +528,8 @@ func initWeb(opts options, clientBuildFS fs.FS) (web *webAPI, err error) {
|
||||
|
||||
webConf := webConfig{
|
||||
firstRun: Context.firstRun,
|
||||
BindHost: config.BindHost,
|
||||
BindPort: config.BindPort,
|
||||
BindHost: config.HTTPConfig.Address.Addr(),
|
||||
BindPort: int(config.HTTPConfig.Address.Port()),
|
||||
|
||||
ReadTimeout: readTimeout,
|
||||
ReadHeaderTimeout: readHdrTimeout,
|
||||
@@ -638,8 +665,8 @@ func initUsers() (auth *Auth, err error) {
|
||||
log.Info("authratelimiter is disabled")
|
||||
}
|
||||
|
||||
sessionTTL := config.WebSessionTTLHours * 60 * 60
|
||||
auth = InitAuth(sessFilename, config.Users, sessionTTL, rateLimiter)
|
||||
sessionTTL := config.HTTPConfig.SessionTTL.Seconds()
|
||||
auth = InitAuth(sessFilename, config.Users, uint32(sessionTTL), rateLimiter)
|
||||
if auth == nil {
|
||||
return nil, errors.Error("initializing auth module failed")
|
||||
}
|
||||
@@ -917,7 +944,7 @@ func printHTTPAddresses(proto string) {
|
||||
Context.tls.WriteDiskConfig(&tlsConf)
|
||||
}
|
||||
|
||||
port := config.BindPort
|
||||
port := int(config.HTTPConfig.Address.Port())
|
||||
if proto == aghhttp.SchemeHTTPS {
|
||||
port = tlsConf.PortHTTPS
|
||||
}
|
||||
@@ -929,9 +956,9 @@ func printHTTPAddresses(proto string) {
|
||||
return
|
||||
}
|
||||
|
||||
bindhost := config.BindHost
|
||||
if !bindhost.IsUnspecified() {
|
||||
printWebAddrs(proto, bindhost.String(), port)
|
||||
bindHost := config.HTTPConfig.Address.Addr()
|
||||
if !bindHost.IsUnspecified() {
|
||||
printWebAddrs(proto, bindHost.String(), port)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -942,14 +969,14 @@ func printHTTPAddresses(proto string) {
|
||||
// That's weird, but we'll ignore it.
|
||||
//
|
||||
// TODO(e.burkov): Find out when it happens.
|
||||
printWebAddrs(proto, bindhost.String(), port)
|
||||
printWebAddrs(proto, bindHost.String(), port)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
for _, iface := range ifaces {
|
||||
for _, addr := range iface.Addresses {
|
||||
printWebAddrs(proto, addr.String(), config.BindPort)
|
||||
printWebAddrs(proto, addr.String(), port)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -35,11 +35,18 @@ type options struct {
|
||||
serviceControlAction string
|
||||
|
||||
// bindHost is the address on which to serve the HTTP UI.
|
||||
//
|
||||
// Deprecated: Use bindAddr.
|
||||
bindHost netip.Addr
|
||||
|
||||
// bindPort is the port on which to serve the HTTP UI.
|
||||
//
|
||||
// Deprecated: Use bindAddr.
|
||||
bindPort int
|
||||
|
||||
// bindAddr is the address to serve the web UI on.
|
||||
bindAddr netip.AddrPort
|
||||
|
||||
// checkConfig is true if the current invocation is only required to check
|
||||
// the configuration file and exit.
|
||||
checkConfig bool
|
||||
@@ -147,9 +154,10 @@ var cmdLineOpts = []cmdLineOpt{{
|
||||
|
||||
return o.bindHost.String(), true
|
||||
},
|
||||
description: "Host address to bind HTTP server on.",
|
||||
longName: "host",
|
||||
shortName: "h",
|
||||
description: "Deprecated. Host address to bind HTTP server on. Use --web-addr. " +
|
||||
"The short -h will work as --help in the future.",
|
||||
longName: "host",
|
||||
shortName: "h",
|
||||
}, {
|
||||
updateWithValue: func(o options, v string) (options, error) {
|
||||
var err error
|
||||
@@ -174,9 +182,23 @@ var cmdLineOpts = []cmdLineOpt{{
|
||||
|
||||
return strconv.Itoa(o.bindPort), true
|
||||
},
|
||||
description: "Port to serve HTTP pages on.",
|
||||
description: "Deprecated. Port to serve HTTP pages on. Use --web-addr.",
|
||||
longName: "port",
|
||||
shortName: "p",
|
||||
}, {
|
||||
updateWithValue: func(o options, v string) (oo options, err error) {
|
||||
o.bindAddr, err = netip.ParseAddrPort(v)
|
||||
|
||||
return o, err
|
||||
},
|
||||
updateNoValue: nil,
|
||||
effect: nil,
|
||||
serialize: func(o options) (val string, ok bool) {
|
||||
return o.bindAddr.String(), o.bindAddr.IsValid()
|
||||
},
|
||||
description: "Address to serve the web UI on, in the host:port format.",
|
||||
longName: "web-addr",
|
||||
shortName: "",
|
||||
}, {
|
||||
updateWithValue: func(o options, v string) (options, error) {
|
||||
o.serviceControlAction = v
|
||||
|
||||
@@ -82,6 +82,23 @@ func TestParseBindPort(t *testing.T) {
|
||||
testParseErr(t, "port too high", "-p", "18446744073709551617") // 2^64 + 1
|
||||
}
|
||||
|
||||
func TestParseBindAddr(t *testing.T) {
|
||||
wantAddrPort := netip.MustParseAddrPort("1.2.3.4:8089")
|
||||
|
||||
assert.Zero(t, testParseOK(t).bindAddr, "empty is not web-addr")
|
||||
|
||||
assert.Equal(t, wantAddrPort, testParseOK(t, "--web-addr", "1.2.3.4:8089").bindAddr)
|
||||
assert.Equal(t, netip.MustParseAddrPort("1.2.3.4:0"), testParseOK(t, "--web-addr", "1.2.3.4:0").bindAddr)
|
||||
testParseParamMissing(t, "-web-addr")
|
||||
|
||||
testParseErr(t, "not an int", "--web-addr", "1.2.3.4:x")
|
||||
testParseErr(t, "hex not supported", "--web-addr", "1.2.3.4:0x100")
|
||||
testParseErr(t, "port negative", "--web-addr", "1.2.3.4:-1")
|
||||
testParseErr(t, "port too high", "--web-addr", "1.2.3.4:65536")
|
||||
testParseErr(t, "port too high", "--web-addr", "1.2.3.4:4294967297") // 2^32 + 1
|
||||
testParseErr(t, "port too high", "--web-addr", "1.2.3.4:18446744073709551617") // 2^64 + 1
|
||||
}
|
||||
|
||||
func TestParseLogfile(t *testing.T) {
|
||||
assert.Equal(t, "", testParseOK(t).logFile, "empty is no log file")
|
||||
assert.Equal(t, "path", testParseOK(t, "-l", "path").logFile, "-l is log file")
|
||||
@@ -162,6 +179,10 @@ func TestOptsToArgs(t *testing.T) {
|
||||
name: "bind_port",
|
||||
args: []string{"-p", "666"},
|
||||
opts: options{bindPort: 666},
|
||||
}, {
|
||||
name: "web-addr",
|
||||
args: []string{"--web-addr", "1.2.3.4:8080"},
|
||||
opts: options{bindAddr: netip.MustParseAddrPort("1.2.3.4:8080")},
|
||||
}, {
|
||||
name: "log_file",
|
||||
args: []string{"-l", "path"},
|
||||
|
||||
@@ -228,12 +228,7 @@ func TestRDNS_WorkerLoop(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
w.Reset()
|
||||
|
||||
cc := &clientsContainer{
|
||||
list: map[string]*Client{},
|
||||
idIndex: map[string]*Client{},
|
||||
ipToRC: map[netip.Addr]*RuntimeClient{},
|
||||
allTags: stringutil.NewSet(),
|
||||
}
|
||||
cc := newClientsContainer(t)
|
||||
ch := make(chan netip.Addr)
|
||||
rdns := &RDNS{
|
||||
exchanger: &rDNSExchanger{
|
||||
|
||||
@@ -320,7 +320,7 @@ func (m *tlsManager) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
if setts.Enabled {
|
||||
err = validatePorts(
|
||||
tcpPort(config.BindPort),
|
||||
tcpPort(config.HTTPConfig.Address.Port()),
|
||||
tcpPort(setts.PortHTTPS),
|
||||
tcpPort(setts.PortDNSOverTLS),
|
||||
tcpPort(setts.PortDNSCrypt),
|
||||
@@ -407,7 +407,7 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request)
|
||||
|
||||
if req.Enabled {
|
||||
err = validatePorts(
|
||||
tcpPort(config.BindPort),
|
||||
tcpPort(config.HTTPConfig.Address.Port()),
|
||||
tcpPort(req.PortHTTPS),
|
||||
tcpPort(req.PortDNSOverTLS),
|
||||
tcpPort(req.PortDNSCrypt),
|
||||
|
||||
@@ -3,6 +3,7 @@ package home
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
@@ -22,7 +23,7 @@ import (
|
||||
)
|
||||
|
||||
// currentSchemaVersion is the current schema version.
|
||||
const currentSchemaVersion = 20
|
||||
const currentSchemaVersion = 23
|
||||
|
||||
// These aliases are provided for convenience.
|
||||
type (
|
||||
@@ -94,6 +95,9 @@ func upgradeConfigSchema(oldVersion int, diskConf yobj) (err error) {
|
||||
upgradeSchema17to18,
|
||||
upgradeSchema18to19,
|
||||
upgradeSchema19to20,
|
||||
upgradeSchema20to21,
|
||||
upgradeSchema21to22,
|
||||
upgradeSchema22to23,
|
||||
}
|
||||
|
||||
n := 0
|
||||
@@ -1128,6 +1132,199 @@ func upgradeSchema19to20(diskConf yobj) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// upgradeSchema20to21 performs the following changes:
|
||||
//
|
||||
// # BEFORE:
|
||||
// 'dns':
|
||||
// 'blocked_services':
|
||||
// - 'svc_name'
|
||||
//
|
||||
// # AFTER:
|
||||
// 'dns':
|
||||
// 'blocked_services':
|
||||
// 'ids':
|
||||
// - 'svc_name'
|
||||
// 'schedule':
|
||||
// 'time_zone': 'Local'
|
||||
func upgradeSchema20to21(diskConf yobj) (err error) {
|
||||
log.Printf("Upgrade yaml: 20 to 21")
|
||||
diskConf["schema_version"] = 21
|
||||
|
||||
const field = "blocked_services"
|
||||
|
||||
dnsVal, ok := diskConf["dns"]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
dns, ok := dnsVal.(yobj)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type of dns: %T", dnsVal)
|
||||
}
|
||||
|
||||
blockedVal, ok := dns[field]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
services, ok := blockedVal.(yarr)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type of blocked: %T", blockedVal)
|
||||
}
|
||||
|
||||
dns[field] = yobj{
|
||||
"ids": services,
|
||||
"schedule": yobj{
|
||||
"time_zone": "Local",
|
||||
},
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// upgradeSchema21to22 performs the following changes:
|
||||
//
|
||||
// # BEFORE:
|
||||
// 'persistent':
|
||||
// - 'name': 'client_name'
|
||||
// 'blocked_services':
|
||||
// - 'svc_name'
|
||||
//
|
||||
// # AFTER:
|
||||
// 'persistent':
|
||||
// - 'name': 'client_name'
|
||||
// 'blocked_services':
|
||||
// 'ids':
|
||||
// - 'svc_name'
|
||||
// 'schedule':
|
||||
// 'time_zone': 'Local'
|
||||
func upgradeSchema21to22(diskConf yobj) (err error) {
|
||||
log.Println("Upgrade yaml: 21 to 22")
|
||||
diskConf["schema_version"] = 22
|
||||
|
||||
const field = "blocked_services"
|
||||
|
||||
clientsVal, ok := diskConf["clients"]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
clients, ok := clientsVal.(yobj)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type of clients: %T", clientsVal)
|
||||
}
|
||||
|
||||
persistentVal, ok := clients["persistent"]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
persistent, ok := persistentVal.([]any)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type of persistent clients: %T", persistentVal)
|
||||
}
|
||||
|
||||
for i, val := range persistent {
|
||||
var c yobj
|
||||
c, ok = val.(yobj)
|
||||
if !ok {
|
||||
return fmt.Errorf("persistent client at index %d: unexpected type %T", i, val)
|
||||
}
|
||||
|
||||
var blockedVal any
|
||||
blockedVal, ok = c[field]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
var services yarr
|
||||
services, ok = blockedVal.(yarr)
|
||||
if !ok {
|
||||
return fmt.Errorf(
|
||||
"persistent client at index %d: unexpected type of blocked services: %T",
|
||||
i,
|
||||
blockedVal,
|
||||
)
|
||||
}
|
||||
|
||||
c[field] = yobj{
|
||||
"ids": services,
|
||||
"schedule": yobj{
|
||||
"time_zone": "Local",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// upgradeSchema22to23 performs the following changes:
|
||||
//
|
||||
// # BEFORE:
|
||||
// 'bind_host': '1.2.3.4'
|
||||
// 'bind_port': 8080
|
||||
// 'web_session_ttl': 720
|
||||
//
|
||||
// # AFTER:
|
||||
// 'http':
|
||||
// 'address': '1.2.3.4:8080'
|
||||
// 'session_ttl': '720h'
|
||||
func upgradeSchema22to23(diskConf yobj) (err error) {
|
||||
log.Printf("Upgrade yaml: 22 to 23")
|
||||
diskConf["schema_version"] = 23
|
||||
|
||||
bindHostVal, ok := diskConf["bind_host"]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
bindHost, ok := bindHostVal.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type of bind_host: %T", bindHostVal)
|
||||
}
|
||||
|
||||
bindHostAddr, err := netip.ParseAddr(bindHost)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid bind_host value: %s", bindHost)
|
||||
}
|
||||
|
||||
bindPortVal, ok := diskConf["bind_port"]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
bindPort, ok := bindPortVal.(int)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type of bind_port: %T", bindPortVal)
|
||||
}
|
||||
|
||||
sessionTTLVal, ok := diskConf["web_session_ttl"]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
sessionTTL, ok := sessionTTLVal.(int)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type of web_session_ttl: %T", sessionTTLVal)
|
||||
}
|
||||
|
||||
addr := netip.AddrPortFrom(bindHostAddr, uint16(bindPort))
|
||||
if !addr.IsValid() {
|
||||
return fmt.Errorf("invalid address: %s", addr)
|
||||
}
|
||||
|
||||
diskConf["http"] = yobj{
|
||||
"address": addr.String(),
|
||||
"session_ttl": timeutil.Duration{Duration: time.Duration(sessionTTL) * time.Hour}.String(),
|
||||
}
|
||||
|
||||
delete(diskConf, "bind_host")
|
||||
delete(diskConf, "bind_port")
|
||||
delete(diskConf, "web_session_ttl")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO(a.garipov): Replace with log.Output when we port it to our logging
|
||||
// package.
|
||||
func funcName() string {
|
||||
|
||||
@@ -1140,3 +1140,169 @@ func TestUpgradeSchema19to20(t *testing.T) {
|
||||
assert.Equal(t, 24*time.Hour, ivlVal.Duration)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUpgradeSchema20to21(t *testing.T) {
|
||||
const newSchemaVer = 21
|
||||
|
||||
testCases := []struct {
|
||||
in yobj
|
||||
want yobj
|
||||
name string
|
||||
}{{
|
||||
name: "nothing",
|
||||
in: yobj{},
|
||||
want: yobj{
|
||||
"schema_version": newSchemaVer,
|
||||
},
|
||||
}, {
|
||||
name: "no_clients",
|
||||
in: yobj{
|
||||
"dns": yobj{
|
||||
"blocked_services": yarr{"ok"},
|
||||
},
|
||||
},
|
||||
want: yobj{
|
||||
"dns": yobj{
|
||||
"blocked_services": yobj{
|
||||
"ids": yarr{"ok"},
|
||||
"schedule": yobj{
|
||||
"time_zone": "Local",
|
||||
},
|
||||
},
|
||||
},
|
||||
"schema_version": newSchemaVer,
|
||||
},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := upgradeSchema20to21(tc.in)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.want, tc.in)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpgradeSchema21to22(t *testing.T) {
|
||||
const newSchemaVer = 22
|
||||
|
||||
testCases := []struct {
|
||||
in yobj
|
||||
want yobj
|
||||
name string
|
||||
}{{
|
||||
in: yobj{
|
||||
"clients": yobj{},
|
||||
},
|
||||
want: yobj{
|
||||
"clients": yobj{},
|
||||
"schema_version": newSchemaVer,
|
||||
},
|
||||
name: "nothing",
|
||||
}, {
|
||||
in: yobj{
|
||||
"clients": yobj{
|
||||
"persistent": []any{yobj{"name": "localhost", "blocked_services": yarr{}}},
|
||||
},
|
||||
},
|
||||
want: yobj{
|
||||
"clients": yobj{
|
||||
"persistent": []any{yobj{
|
||||
"name": "localhost",
|
||||
"blocked_services": yobj{
|
||||
"ids": yarr{},
|
||||
"schedule": yobj{
|
||||
"time_zone": "Local",
|
||||
},
|
||||
},
|
||||
}},
|
||||
},
|
||||
"schema_version": newSchemaVer,
|
||||
},
|
||||
name: "no_services",
|
||||
}, {
|
||||
in: yobj{
|
||||
"clients": yobj{
|
||||
"persistent": []any{yobj{"name": "localhost", "blocked_services": yarr{"ok"}}},
|
||||
},
|
||||
},
|
||||
want: yobj{
|
||||
"clients": yobj{
|
||||
"persistent": []any{yobj{
|
||||
"name": "localhost",
|
||||
"blocked_services": yobj{
|
||||
"ids": yarr{"ok"},
|
||||
"schedule": yobj{
|
||||
"time_zone": "Local",
|
||||
},
|
||||
},
|
||||
}},
|
||||
},
|
||||
"schema_version": newSchemaVer,
|
||||
},
|
||||
name: "services",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := upgradeSchema21to22(tc.in)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.want, tc.in)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpgradeSchema22to23(t *testing.T) {
|
||||
const newSchemaVer = 23
|
||||
|
||||
testCases := []struct {
|
||||
in yobj
|
||||
want yobj
|
||||
name string
|
||||
}{{
|
||||
name: "empty",
|
||||
in: yobj{},
|
||||
want: yobj{
|
||||
"schema_version": newSchemaVer,
|
||||
},
|
||||
}, {
|
||||
name: "ok",
|
||||
in: yobj{
|
||||
"bind_host": "1.2.3.4",
|
||||
"bind_port": 8081,
|
||||
"web_session_ttl": 720,
|
||||
},
|
||||
want: yobj{
|
||||
"http": yobj{
|
||||
"address": "1.2.3.4:8081",
|
||||
"session_ttl": "720h",
|
||||
},
|
||||
"schema_version": newSchemaVer,
|
||||
},
|
||||
}, {
|
||||
name: "v6_address",
|
||||
in: yobj{
|
||||
"bind_host": "2001:db8::1",
|
||||
"bind_port": 8081,
|
||||
"web_session_ttl": 720,
|
||||
},
|
||||
want: yobj{
|
||||
"http": yobj{
|
||||
"address": "[2001:db8::1]:8081",
|
||||
"session_ttl": "720h",
|
||||
},
|
||||
"schema_version": newSchemaVer,
|
||||
},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := upgradeSchema22to23(tc.in)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.want, tc.in)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -119,7 +119,9 @@ func webCheckPortAvailable(port int) (ok bool) {
|
||||
return true
|
||||
}
|
||||
|
||||
return aghnet.CheckPort("tcp", netip.AddrPortFrom(config.BindHost, uint16(port))) == nil
|
||||
addrPort := netip.AddrPortFrom(config.HTTPConfig.Address.Addr(), uint16(port))
|
||||
|
||||
return aghnet.CheckPort("tcp", addrPort) == nil
|
||||
}
|
||||
|
||||
// tlsConfigChanged updates the TLS configuration and restarts the HTTPS server
|
||||
|
||||
@@ -1,259 +0,0 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
|
||||
"github.com/AdguardTeam/golibs/cache"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultServer = "whois.arin.net"
|
||||
defaultPort = "43"
|
||||
maxValueLength = 250
|
||||
whoisTTL = 1 * 60 * 60 // 1 hour
|
||||
)
|
||||
|
||||
// WHOIS - module context
|
||||
type WHOIS struct {
|
||||
clients *clientsContainer
|
||||
ipChan chan netip.Addr
|
||||
|
||||
// dialContext specifies the dial function for creating unencrypted TCP
|
||||
// connections.
|
||||
dialContext func(ctx context.Context, network, addr string) (conn net.Conn, err error)
|
||||
|
||||
// Contains IP addresses of clients
|
||||
// An active IP address is resolved once again after it expires.
|
||||
// If IP address couldn't be resolved, it stays here for some time to prevent further attempts to resolve the same IP.
|
||||
ipAddrs cache.Cache
|
||||
|
||||
// TODO(a.garipov): Rewrite to use time.Duration. Like, seriously, why?
|
||||
timeoutMsec uint
|
||||
}
|
||||
|
||||
// initWHOIS creates the WHOIS module context.
|
||||
func initWHOIS(clients *clientsContainer) *WHOIS {
|
||||
w := WHOIS{
|
||||
timeoutMsec: 5000,
|
||||
clients: clients,
|
||||
ipAddrs: cache.New(cache.Config{
|
||||
EnableLRU: true,
|
||||
MaxCount: 10000,
|
||||
}),
|
||||
dialContext: customDialContext,
|
||||
ipChan: make(chan netip.Addr, 255),
|
||||
}
|
||||
|
||||
go w.workerLoop()
|
||||
|
||||
return &w
|
||||
}
|
||||
|
||||
// If the value is too large - cut it and append "..."
|
||||
func trimValue(s string) string {
|
||||
if len(s) <= maxValueLength {
|
||||
return s
|
||||
}
|
||||
return s[:maxValueLength-3] + "..."
|
||||
}
|
||||
|
||||
// isWHOISComment returns true if the string is empty or is a WHOIS comment.
|
||||
func isWHOISComment(s string) (ok bool) {
|
||||
return len(s) == 0 || s[0] == '#' || s[0] == '%'
|
||||
}
|
||||
|
||||
// strmap is an alias for convenience.
|
||||
type strmap = map[string]string
|
||||
|
||||
// whoisParse parses a subset of plain-text data from the WHOIS response into
|
||||
// a string map.
|
||||
func whoisParse(data string) (m strmap) {
|
||||
m = strmap{}
|
||||
|
||||
var orgname string
|
||||
lines := strings.Split(data, "\n")
|
||||
for _, l := range lines {
|
||||
if isWHOISComment(l) {
|
||||
continue
|
||||
}
|
||||
|
||||
kv := strings.SplitN(l, ":", 2)
|
||||
if len(kv) != 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
k := strings.ToLower(strings.TrimSpace(kv[0]))
|
||||
v := strings.TrimSpace(kv[1])
|
||||
if v == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
switch k {
|
||||
case "orgname", "org-name":
|
||||
k = "orgname"
|
||||
v = trimValue(v)
|
||||
orgname = v
|
||||
case "city", "country":
|
||||
v = trimValue(v)
|
||||
case "descr", "netname":
|
||||
k = "orgname"
|
||||
v = stringutil.Coalesce(orgname, v)
|
||||
orgname = v
|
||||
case "whois":
|
||||
k = "whois"
|
||||
case "referralserver":
|
||||
k = "whois"
|
||||
v = strings.TrimPrefix(v, "whois://")
|
||||
default:
|
||||
continue
|
||||
}
|
||||
|
||||
m[k] = v
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// MaxConnReadSize is an upper limit in bytes for reading from net.Conn.
|
||||
const MaxConnReadSize = 64 * 1024
|
||||
|
||||
// Send request to a server and receive the response
|
||||
func (w *WHOIS) query(ctx context.Context, target, serverAddr string) (data string, err error) {
|
||||
addr, _, _ := net.SplitHostPort(serverAddr)
|
||||
if addr == "whois.arin.net" {
|
||||
target = "n + " + target
|
||||
}
|
||||
|
||||
conn, err := w.dialContext(ctx, "tcp", serverAddr)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer func() { err = errors.WithDeferred(err, conn.Close()) }()
|
||||
|
||||
r, err := aghio.LimitReader(conn, MaxConnReadSize)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
_ = conn.SetReadDeadline(time.Now().Add(time.Duration(w.timeoutMsec) * time.Millisecond))
|
||||
_, err = conn.Write([]byte(target + "\r\n"))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// This use of ReadAll is now safe, because we limited the conn Reader.
|
||||
var whoisData []byte
|
||||
whoisData, err = io.ReadAll(r)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(whoisData), nil
|
||||
}
|
||||
|
||||
// Query WHOIS servers (handle redirects)
|
||||
func (w *WHOIS) queryAll(ctx context.Context, target string) (string, error) {
|
||||
server := net.JoinHostPort(defaultServer, defaultPort)
|
||||
const maxRedirects = 5
|
||||
for i := 0; i != maxRedirects; i++ {
|
||||
resp, err := w.query(ctx, target, server)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
log.Debug("whois: received response (%d bytes) from %s IP:%s", len(resp), server, target)
|
||||
|
||||
m := whoisParse(resp)
|
||||
redir, ok := m["whois"]
|
||||
if !ok {
|
||||
return resp, nil
|
||||
}
|
||||
redir = strings.ToLower(redir)
|
||||
|
||||
_, _, err = net.SplitHostPort(redir)
|
||||
if err != nil {
|
||||
server = net.JoinHostPort(redir, defaultPort)
|
||||
} else {
|
||||
server = redir
|
||||
}
|
||||
|
||||
log.Debug("whois: redirected to %s IP:%s", redir, target)
|
||||
}
|
||||
return "", fmt.Errorf("whois: redirect loop")
|
||||
}
|
||||
|
||||
// Request WHOIS information
|
||||
func (w *WHOIS) process(ctx context.Context, ip netip.Addr) (wi *RuntimeClientWHOISInfo) {
|
||||
resp, err := w.queryAll(ctx, ip.String())
|
||||
if err != nil {
|
||||
log.Debug("whois: error: %s IP:%s", err, ip)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Debug("whois: IP:%s response: %d bytes", ip, len(resp))
|
||||
|
||||
m := whoisParse(resp)
|
||||
|
||||
wi = &RuntimeClientWHOISInfo{
|
||||
City: m["city"],
|
||||
Country: m["country"],
|
||||
Orgname: m["orgname"],
|
||||
}
|
||||
|
||||
// Don't return an empty struct so that the frontend doesn't get
|
||||
// confused.
|
||||
if *wi == (RuntimeClientWHOISInfo{}) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return wi
|
||||
}
|
||||
|
||||
// Begin - begin requesting WHOIS info
|
||||
func (w *WHOIS) Begin(ip netip.Addr) {
|
||||
ipBytes := ip.AsSlice()
|
||||
now := uint64(time.Now().Unix())
|
||||
expire := w.ipAddrs.Get(ipBytes)
|
||||
if len(expire) != 0 {
|
||||
exp := binary.BigEndian.Uint64(expire)
|
||||
if exp > now {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
expire = make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(expire, now+whoisTTL)
|
||||
_ = w.ipAddrs.Set(ipBytes, expire)
|
||||
|
||||
log.Debug("whois: adding %s", ip)
|
||||
|
||||
select {
|
||||
case w.ipChan <- ip:
|
||||
default:
|
||||
log.Debug("whois: queue is full")
|
||||
}
|
||||
}
|
||||
|
||||
// workerLoop processes the IP addresses it got from the channel and associates
|
||||
// the retrieving WHOIS info with a client.
|
||||
func (w *WHOIS) workerLoop() {
|
||||
for ip := range w.ipChan {
|
||||
info := w.process(context.Background(), ip)
|
||||
if info == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
w.clients.setWHOISInfo(ip, info)
|
||||
}
|
||||
}
|
||||
@@ -1,152 +0,0 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// fakeConn is a mock implementation of net.Conn to simplify testing.
|
||||
//
|
||||
// TODO(e.burkov): Search for other places in code where it may be used. Move
|
||||
// into aghtest then.
|
||||
type fakeConn struct {
|
||||
// Conn is embedded here simply to make *fakeConn a net.Conn without
|
||||
// actually implementing all methods.
|
||||
net.Conn
|
||||
data []byte
|
||||
}
|
||||
|
||||
// Write implements net.Conn interface for *fakeConn. It always returns 0 and a
|
||||
// nil error without mutating the slice.
|
||||
func (c *fakeConn) Write(_ []byte) (n int, err error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Read implements net.Conn interface for *fakeConn. It puts the content of
|
||||
// c.data field into b up to the b's capacity.
|
||||
func (c *fakeConn) Read(b []byte) (n int, err error) {
|
||||
return copy(b, c.data), io.EOF
|
||||
}
|
||||
|
||||
// Close implements net.Conn interface for *fakeConn. It always returns nil.
|
||||
func (c *fakeConn) Close() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetReadDeadline implements net.Conn interface for *fakeConn. It always
|
||||
// returns nil.
|
||||
func (c *fakeConn) SetReadDeadline(_ time.Time) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// fakeDial is a mock implementation of customDialContext to simplify testing.
|
||||
func (c *fakeConn) fakeDial(ctx context.Context, network, addr string) (conn net.Conn, err error) {
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func TestWHOIS(t *testing.T) {
|
||||
const (
|
||||
nl = "\n"
|
||||
data = `OrgName: FakeOrg LLC` + nl +
|
||||
`City: Nonreal` + nl +
|
||||
`Country: Imagiland` + nl
|
||||
)
|
||||
|
||||
fc := &fakeConn{
|
||||
data: []byte(data),
|
||||
}
|
||||
|
||||
w := WHOIS{
|
||||
timeoutMsec: 5000,
|
||||
dialContext: fc.fakeDial,
|
||||
}
|
||||
resp, err := w.queryAll(context.Background(), "1.2.3.4")
|
||||
assert.NoError(t, err)
|
||||
|
||||
m := whoisParse(resp)
|
||||
require.NotEmpty(t, m)
|
||||
|
||||
assert.Equal(t, "FakeOrg LLC", m["orgname"])
|
||||
assert.Equal(t, "Imagiland", m["country"])
|
||||
assert.Equal(t, "Nonreal", m["city"])
|
||||
}
|
||||
|
||||
func TestWHOISParse(t *testing.T) {
|
||||
const (
|
||||
city = "Nonreal"
|
||||
country = "Imagiland"
|
||||
orgname = "FakeOrgLLC"
|
||||
whois = "whois.example.net"
|
||||
)
|
||||
|
||||
testCases := []struct {
|
||||
want strmap
|
||||
name string
|
||||
in string
|
||||
}{{
|
||||
want: strmap{},
|
||||
name: "empty",
|
||||
in: ``,
|
||||
}, {
|
||||
want: strmap{},
|
||||
name: "comments",
|
||||
in: "%\n#",
|
||||
}, {
|
||||
want: strmap{},
|
||||
name: "no_colon",
|
||||
in: "city",
|
||||
}, {
|
||||
want: strmap{},
|
||||
name: "no_value",
|
||||
in: "city:",
|
||||
}, {
|
||||
want: strmap{"city": city},
|
||||
name: "city",
|
||||
in: `city: ` + city,
|
||||
}, {
|
||||
want: strmap{"country": country},
|
||||
name: "country",
|
||||
in: `country: ` + country,
|
||||
}, {
|
||||
want: strmap{"orgname": orgname},
|
||||
name: "orgname",
|
||||
in: `orgname: ` + orgname,
|
||||
}, {
|
||||
want: strmap{"orgname": orgname},
|
||||
name: "orgname_hyphen",
|
||||
in: `org-name: ` + orgname,
|
||||
}, {
|
||||
want: strmap{"orgname": orgname},
|
||||
name: "orgname_descr",
|
||||
in: `descr: ` + orgname,
|
||||
}, {
|
||||
want: strmap{"orgname": orgname},
|
||||
name: "orgname_netname",
|
||||
in: `netname: ` + orgname,
|
||||
}, {
|
||||
want: strmap{"whois": whois},
|
||||
name: "whois",
|
||||
in: `whois: ` + whois,
|
||||
}, {
|
||||
want: strmap{"whois": whois},
|
||||
name: "referralserver",
|
||||
in: `referralserver: whois://` + whois,
|
||||
}, {
|
||||
want: strmap{},
|
||||
name: "other",
|
||||
in: `other: value`,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := whoisParse(tc.in)
|
||||
assert.Equal(t, tc.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
63
internal/next/agh/agh.go
Normal file
63
internal/next/agh/agh.go
Normal file
@@ -0,0 +1,63 @@
|
||||
// Package agh contains common entities and interfaces of AdGuard Home.
|
||||
package agh
|
||||
|
||||
import "context"
|
||||
|
||||
// Service is the interface for API servers.
|
||||
//
|
||||
// TODO(a.garipov): Consider adding a context to Start.
|
||||
//
|
||||
// TODO(a.garipov): Consider adding a Wait method or making an extension
|
||||
// interface for that.
|
||||
type Service interface {
|
||||
// Start starts the service. It does not block.
|
||||
Start() (err error)
|
||||
|
||||
// Shutdown gracefully stops the service. ctx is used to determine
|
||||
// a timeout before trying to stop the service less gracefully.
|
||||
Shutdown(ctx context.Context) (err error)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ Service = EmptyService{}
|
||||
|
||||
// EmptyService is a [Service] that does nothing.
|
||||
//
|
||||
// TODO(a.garipov): Remove if unnecessary.
|
||||
type EmptyService struct{}
|
||||
|
||||
// Start implements the [Service] interface for EmptyService.
|
||||
func (EmptyService) Start() (err error) { return nil }
|
||||
|
||||
// Shutdown implements the [Service] interface for EmptyService.
|
||||
func (EmptyService) Shutdown(_ context.Context) (err error) { return nil }
|
||||
|
||||
// ServiceWithConfig is an extension of the [Service] interface for services
|
||||
// that can return their configuration.
|
||||
//
|
||||
// TODO(a.garipov): Consider removing this generic interface if we figure out
|
||||
// how to make it testable in a better way.
|
||||
type ServiceWithConfig[ConfigType any] interface {
|
||||
Service
|
||||
|
||||
Config() (c ConfigType)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ ServiceWithConfig[struct{}] = (*EmptyServiceWithConfig[struct{}])(nil)
|
||||
|
||||
// EmptyServiceWithConfig is a ServiceWithConfig that does nothing. Its Config
|
||||
// method returns Conf.
|
||||
//
|
||||
// TODO(a.garipov): Remove if unnecessary.
|
||||
type EmptyServiceWithConfig[ConfigType any] struct {
|
||||
EmptyService
|
||||
|
||||
Conf ConfigType
|
||||
}
|
||||
|
||||
// Config implements the [ServiceWithConfig] interface for
|
||||
// *EmptyServiceWithConfig.
|
||||
func (s *EmptyServiceWithConfig[ConfigType]) Config() (conf ConfigType) {
|
||||
return s.Conf
|
||||
}
|
||||
@@ -1,23 +1,15 @@
|
||||
package querylog
|
||||
|
||||
import "github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||
|
||||
// Client is the information required by the query log to match against clients
|
||||
// during searches.
|
||||
type Client struct {
|
||||
WHOIS *ClientWHOIS `json:"whois,omitempty"`
|
||||
Name string `json:"name"`
|
||||
DisallowedRule string `json:"disallowed_rule"`
|
||||
Disallowed bool `json:"disallowed"`
|
||||
IgnoreQueryLog bool `json:"-"`
|
||||
}
|
||||
|
||||
// ClientWHOIS is the filtered WHOIS data for the client.
|
||||
//
|
||||
// TODO(a.garipov): Merge with home.RuntimeClientWHOISInfo after the
|
||||
// refactoring is done.
|
||||
type ClientWHOIS struct {
|
||||
City string `json:"city,omitempty"`
|
||||
Country string `json:"country,omitempty"`
|
||||
Orgname string `json:"orgname,omitempty"`
|
||||
WHOIS *whois.Info `json:"whois,omitempty"`
|
||||
Name string `json:"name"`
|
||||
DisallowedRule string `json:"disallowed_rule"`
|
||||
Disallowed bool `json:"disallowed"`
|
||||
IgnoreQueryLog bool `json:"-"`
|
||||
}
|
||||
|
||||
// clientCacheKey is the key by which a cached client information is found.
|
||||
|
||||
@@ -161,12 +161,15 @@ func (l *queryLog) clear() {
|
||||
// newLogEntry creates an instance of logEntry from parameters.
|
||||
func newLogEntry(params *AddParams) (entry *logEntry) {
|
||||
q := params.Question.Question[0]
|
||||
qHost := q.Name
|
||||
if qHost != "." {
|
||||
qHost = strings.ToLower(q.Name[:len(q.Name)-1])
|
||||
}
|
||||
|
||||
entry = &logEntry{
|
||||
// TODO(d.kolyshev): Export this timestamp to func params.
|
||||
Time: time.Now(),
|
||||
|
||||
QHost: strings.ToLower(q.Name[:len(q.Name)-1]),
|
||||
Time: time.Now(),
|
||||
QHost: qHost,
|
||||
QType: dns.Type(q.Qtype).String(),
|
||||
QClass: dns.Class(q.Qclass).String(),
|
||||
|
||||
|
||||
@@ -43,6 +43,7 @@ func TestQueryLog(t *testing.T) {
|
||||
// Add memory entries.
|
||||
addEntry(l, "test.example.org", net.IPv4(1, 1, 1, 3), net.IPv4(2, 2, 2, 3))
|
||||
addEntry(l, "example.com", net.IPv4(1, 1, 1, 4), net.IPv4(2, 2, 2, 4))
|
||||
addEntry(l, "", net.IPv4(1, 1, 1, 5), net.IPv4(2, 2, 2, 5))
|
||||
|
||||
type tcAssertion struct {
|
||||
host string
|
||||
@@ -59,10 +60,11 @@ func TestQueryLog(t *testing.T) {
|
||||
name: "all",
|
||||
sCr: []searchCriterion{},
|
||||
want: []tcAssertion{
|
||||
{num: 0, host: "example.com", answer: net.IPv4(1, 1, 1, 4), client: net.IPv4(2, 2, 2, 4)},
|
||||
{num: 1, host: "test.example.org", answer: net.IPv4(1, 1, 1, 3), client: net.IPv4(2, 2, 2, 3)},
|
||||
{num: 2, host: "example.org", answer: net.IPv4(1, 1, 1, 2), client: net.IPv4(2, 2, 2, 2)},
|
||||
{num: 3, host: "example.org", answer: net.IPv4(1, 1, 1, 1), client: net.IPv4(2, 2, 2, 1)},
|
||||
{num: 0, host: ".", answer: net.IPv4(1, 1, 1, 5), client: net.IPv4(2, 2, 2, 5)},
|
||||
{num: 1, host: "example.com", answer: net.IPv4(1, 1, 1, 4), client: net.IPv4(2, 2, 2, 4)},
|
||||
{num: 2, host: "test.example.org", answer: net.IPv4(1, 1, 1, 3), client: net.IPv4(2, 2, 2, 3)},
|
||||
{num: 3, host: "example.org", answer: net.IPv4(1, 1, 1, 2), client: net.IPv4(2, 2, 2, 2)},
|
||||
{num: 4, host: "example.org", answer: net.IPv4(1, 1, 1, 1), client: net.IPv4(2, 2, 2, 1)},
|
||||
},
|
||||
}, {
|
||||
name: "by_domain_strict",
|
||||
@@ -104,10 +106,11 @@ func TestQueryLog(t *testing.T) {
|
||||
value: "2.2.2",
|
||||
}},
|
||||
want: []tcAssertion{
|
||||
{num: 0, host: "example.com", answer: net.IPv4(1, 1, 1, 4), client: net.IPv4(2, 2, 2, 4)},
|
||||
{num: 1, host: "test.example.org", answer: net.IPv4(1, 1, 1, 3), client: net.IPv4(2, 2, 2, 3)},
|
||||
{num: 2, host: "example.org", answer: net.IPv4(1, 1, 1, 2), client: net.IPv4(2, 2, 2, 2)},
|
||||
{num: 3, host: "example.org", answer: net.IPv4(1, 1, 1, 1), client: net.IPv4(2, 2, 2, 1)},
|
||||
{num: 0, host: ".", answer: net.IPv4(1, 1, 1, 5), client: net.IPv4(2, 2, 2, 5)},
|
||||
{num: 1, host: "example.com", answer: net.IPv4(1, 1, 1, 4), client: net.IPv4(2, 2, 2, 4)},
|
||||
{num: 2, host: "test.example.org", answer: net.IPv4(1, 1, 1, 3), client: net.IPv4(2, 2, 2, 3)},
|
||||
{num: 3, host: "example.org", answer: net.IPv4(1, 1, 1, 2), client: net.IPv4(2, 2, 2, 2)},
|
||||
{num: 4, host: "example.org", answer: net.IPv4(1, 1, 1, 1), client: net.IPv4(2, 2, 2, 1)},
|
||||
},
|
||||
}}
|
||||
|
||||
|
||||
250
internal/schedule/schedule.go
Normal file
250
internal/schedule/schedule.go
Normal file
@@ -0,0 +1,250 @@
|
||||
// Package schedule provides types for scheduling.
|
||||
package schedule
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// Weekly is a schedule for one week. Each day of the week has one range with
|
||||
// a beginning and an end.
|
||||
type Weekly struct {
|
||||
// location is used to calculate the offsets of the day ranges.
|
||||
location *time.Location
|
||||
|
||||
// days are the day ranges of this schedule. The indexes of this array are
|
||||
// the [time.Weekday] values.
|
||||
days [7]dayRange
|
||||
}
|
||||
|
||||
// EmptyWeekly creates empty weekly schedule with local time zone.
|
||||
func EmptyWeekly() (w *Weekly) {
|
||||
return &Weekly{
|
||||
location: time.Local,
|
||||
}
|
||||
}
|
||||
|
||||
// FullWeekly creates full weekly schedule with local time zone.
|
||||
//
|
||||
// TODO(s.chzhen): Consider moving into tests.
|
||||
func FullWeekly() (w *Weekly) {
|
||||
fullDay := dayRange{start: 0, end: maxDayRange}
|
||||
|
||||
return &Weekly{
|
||||
location: time.Local,
|
||||
days: [7]dayRange{
|
||||
time.Sunday: fullDay,
|
||||
time.Monday: fullDay,
|
||||
time.Tuesday: fullDay,
|
||||
time.Wednesday: fullDay,
|
||||
time.Thursday: fullDay,
|
||||
time.Friday: fullDay,
|
||||
time.Saturday: fullDay,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Clone returns a deep copy of a weekly.
|
||||
func (w *Weekly) Clone() (c *Weekly) {
|
||||
// NOTE: Do not use time.LoadLocation, because the results will be
|
||||
// different on time zone database update.
|
||||
return &Weekly{
|
||||
location: w.location,
|
||||
days: w.days,
|
||||
}
|
||||
}
|
||||
|
||||
// Contains returns true if t is within the corresponding day range of the
|
||||
// schedule in the schedule's time zone.
|
||||
func (w *Weekly) Contains(t time.Time) (ok bool) {
|
||||
t = t.In(w.location)
|
||||
wd := t.Weekday()
|
||||
dr := w.days[wd]
|
||||
|
||||
// Calculate the offset of the day range.
|
||||
//
|
||||
// NOTE: Do not use [time.Truncate] since it requires UTC time zone.
|
||||
y, m, d := t.Date()
|
||||
day := time.Date(y, m, d, 0, 0, 0, 0, w.location)
|
||||
offset := t.Sub(day)
|
||||
|
||||
return dr.contains(offset)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ yaml.Unmarshaler = (*Weekly)(nil)
|
||||
|
||||
// UnmarshalYAML implements the [yaml.Unmarshaler] interface for *Weekly.
|
||||
func (w *Weekly) UnmarshalYAML(value *yaml.Node) (err error) {
|
||||
conf := &weeklyConfig{}
|
||||
|
||||
err = value.Decode(conf)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
weekly := Weekly{}
|
||||
|
||||
weekly.location, err = time.LoadLocation(conf.TimeZone)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
days := []dayConfig{
|
||||
time.Sunday: conf.Sunday,
|
||||
time.Monday: conf.Monday,
|
||||
time.Tuesday: conf.Tuesday,
|
||||
time.Wednesday: conf.Wednesday,
|
||||
time.Thursday: conf.Thursday,
|
||||
time.Friday: conf.Friday,
|
||||
time.Saturday: conf.Saturday,
|
||||
}
|
||||
for i, d := range days {
|
||||
r := dayRange{
|
||||
start: d.Start.Duration,
|
||||
end: d.End.Duration,
|
||||
}
|
||||
|
||||
err = w.validate(r)
|
||||
if err != nil {
|
||||
return fmt.Errorf("weekday %s: %w", time.Weekday(i), err)
|
||||
}
|
||||
|
||||
weekly.days[i] = r
|
||||
}
|
||||
|
||||
*w = weekly
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// weeklyConfig is the YAML configuration structure of Weekly.
|
||||
type weeklyConfig struct {
|
||||
// TimeZone is the local time zone.
|
||||
TimeZone string `yaml:"time_zone"`
|
||||
|
||||
// Days of the week.
|
||||
|
||||
Sunday dayConfig `yaml:"sun,omitempty"`
|
||||
Monday dayConfig `yaml:"mon,omitempty"`
|
||||
Tuesday dayConfig `yaml:"tue,omitempty"`
|
||||
Wednesday dayConfig `yaml:"wed,omitempty"`
|
||||
Thursday dayConfig `yaml:"thu,omitempty"`
|
||||
Friday dayConfig `yaml:"fri,omitempty"`
|
||||
Saturday dayConfig `yaml:"sat,omitempty"`
|
||||
}
|
||||
|
||||
// dayConfig is the YAML configuration structure of dayRange.
|
||||
type dayConfig struct {
|
||||
Start timeutil.Duration `yaml:"start"`
|
||||
End timeutil.Duration `yaml:"end"`
|
||||
}
|
||||
|
||||
// maxDayRange is the maximum value for day range end.
|
||||
const maxDayRange = 24 * time.Hour
|
||||
|
||||
// validate returns the day range rounding errors, if any.
|
||||
func (w *Weekly) validate(r dayRange) (err error) {
|
||||
defer func() { err = errors.Annotate(err, "bad day range: %w") }()
|
||||
|
||||
err = r.validate()
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
start := r.start.Truncate(time.Minute)
|
||||
end := r.end.Truncate(time.Minute)
|
||||
|
||||
switch {
|
||||
case start != r.start:
|
||||
return fmt.Errorf("start %s isn't rounded to minutes", r.start)
|
||||
case end != r.end:
|
||||
return fmt.Errorf("end %s isn't rounded to minutes", r.end)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ yaml.Marshaler = (*Weekly)(nil)
|
||||
|
||||
// MarshalYAML implements the [yaml.Marshaler] interface for *Weekly.
|
||||
func (w *Weekly) MarshalYAML() (v any, err error) {
|
||||
return weeklyConfig{
|
||||
TimeZone: w.location.String(),
|
||||
Sunday: dayConfig{
|
||||
Start: timeutil.Duration{Duration: w.days[time.Sunday].start},
|
||||
End: timeutil.Duration{Duration: w.days[time.Sunday].end},
|
||||
},
|
||||
Monday: dayConfig{
|
||||
Start: timeutil.Duration{Duration: w.days[time.Monday].start},
|
||||
End: timeutil.Duration{Duration: w.days[time.Monday].end},
|
||||
},
|
||||
Tuesday: dayConfig{
|
||||
Start: timeutil.Duration{Duration: w.days[time.Tuesday].start},
|
||||
End: timeutil.Duration{Duration: w.days[time.Tuesday].end},
|
||||
},
|
||||
Wednesday: dayConfig{
|
||||
Start: timeutil.Duration{Duration: w.days[time.Wednesday].start},
|
||||
End: timeutil.Duration{Duration: w.days[time.Wednesday].end},
|
||||
},
|
||||
Thursday: dayConfig{
|
||||
Start: timeutil.Duration{Duration: w.days[time.Thursday].start},
|
||||
End: timeutil.Duration{Duration: w.days[time.Thursday].end},
|
||||
},
|
||||
Friday: dayConfig{
|
||||
Start: timeutil.Duration{Duration: w.days[time.Friday].start},
|
||||
End: timeutil.Duration{Duration: w.days[time.Friday].end},
|
||||
},
|
||||
Saturday: dayConfig{
|
||||
Start: timeutil.Duration{Duration: w.days[time.Saturday].start},
|
||||
End: timeutil.Duration{Duration: w.days[time.Saturday].end},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// dayRange represents a single interval within a day. The interval begins at
|
||||
// start and ends before end. That is, it contains a time point T if start <=
|
||||
// T < end.
|
||||
type dayRange struct {
|
||||
// start is an offset from the beginning of the day. It must be greater
|
||||
// than or equal to zero and less than 24h.
|
||||
start time.Duration
|
||||
|
||||
// end is an offset from the beginning of the day. It must be greater than
|
||||
// or equal to zero and less than or equal to 24h.
|
||||
end time.Duration
|
||||
}
|
||||
|
||||
// validate returns the day range validation errors, if any.
|
||||
func (r dayRange) validate() (err error) {
|
||||
switch {
|
||||
case r == dayRange{}:
|
||||
return nil
|
||||
case r.start < 0:
|
||||
return fmt.Errorf("start %s is negative", r.start)
|
||||
case r.end < 0:
|
||||
return fmt.Errorf("end %s is negative", r.end)
|
||||
case r.start >= r.end:
|
||||
return fmt.Errorf("start %s is greater or equal to end %s", r.start, r.end)
|
||||
case r.start >= maxDayRange:
|
||||
return fmt.Errorf("start %s is greater or equal to %s", r.start, maxDayRange)
|
||||
case r.end > maxDayRange:
|
||||
return fmt.Errorf("end %s is greater than %s", r.end, maxDayRange)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// contains returns true if start <= offset < end, where offset is the time
|
||||
// duration from the beginning of the day.
|
||||
func (r *dayRange) contains(offset time.Duration) (ok bool) {
|
||||
return r.start <= offset && offset < r.end
|
||||
}
|
||||
371
internal/schedule/schedule_internal_test.go
Normal file
371
internal/schedule/schedule_internal_test.go
Normal file
@@ -0,0 +1,371 @@
|
||||
package schedule
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func TestWeekly_Contains(t *testing.T) {
|
||||
baseTime := time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
otherTime := baseTime.Add(1 * timeutil.Day)
|
||||
|
||||
// NOTE: In the Etc area the sign of the offsets is flipped. So, Etc/GMT-3
|
||||
// is actually UTC+03:00.
|
||||
otherTZ := time.FixedZone("Etc/GMT-3", 3*60*60)
|
||||
|
||||
// baseSchedule, 12:00 to 14:00.
|
||||
baseSchedule := &Weekly{
|
||||
days: [7]dayRange{
|
||||
time.Friday: {start: 12 * time.Hour, end: 14 * time.Hour},
|
||||
},
|
||||
location: time.UTC,
|
||||
}
|
||||
|
||||
// allDaySchedule, 00:00 to 24:00.
|
||||
allDaySchedule := &Weekly{
|
||||
days: [7]dayRange{
|
||||
time.Friday: {start: 0, end: 24 * time.Hour},
|
||||
},
|
||||
location: time.UTC,
|
||||
}
|
||||
|
||||
// oneMinSchedule, 00:00 to 00:01.
|
||||
oneMinSchedule := &Weekly{
|
||||
days: [7]dayRange{
|
||||
time.Friday: {start: 0, end: 1 * time.Minute},
|
||||
},
|
||||
location: time.UTC,
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
schedule *Weekly
|
||||
assert assert.BoolAssertionFunc
|
||||
t time.Time
|
||||
name string
|
||||
}{{
|
||||
schedule: EmptyWeekly(),
|
||||
assert: assert.False,
|
||||
t: baseTime,
|
||||
name: "empty",
|
||||
}, {
|
||||
schedule: allDaySchedule,
|
||||
assert: assert.True,
|
||||
t: baseTime,
|
||||
name: "same_day_all_day",
|
||||
}, {
|
||||
schedule: baseSchedule,
|
||||
assert: assert.True,
|
||||
t: baseTime.Add(13 * time.Hour),
|
||||
name: "same_day_inside",
|
||||
}, {
|
||||
schedule: baseSchedule,
|
||||
assert: assert.False,
|
||||
t: baseTime.Add(11 * time.Hour),
|
||||
name: "same_day_outside",
|
||||
}, {
|
||||
schedule: allDaySchedule,
|
||||
assert: assert.True,
|
||||
t: baseTime.Add(24*time.Hour - time.Second),
|
||||
name: "same_day_last_second",
|
||||
}, {
|
||||
schedule: allDaySchedule,
|
||||
assert: assert.False,
|
||||
t: otherTime,
|
||||
name: "other_day_all_day",
|
||||
}, {
|
||||
schedule: baseSchedule,
|
||||
assert: assert.False,
|
||||
t: otherTime.Add(13 * time.Hour),
|
||||
name: "other_day_inside",
|
||||
}, {
|
||||
schedule: baseSchedule,
|
||||
assert: assert.False,
|
||||
t: otherTime.Add(11 * time.Hour),
|
||||
name: "other_day_outside",
|
||||
}, {
|
||||
schedule: baseSchedule,
|
||||
assert: assert.True,
|
||||
t: baseTime.Add(13 * time.Hour).In(otherTZ),
|
||||
name: "same_day_inside_other_tz",
|
||||
}, {
|
||||
schedule: baseSchedule,
|
||||
assert: assert.False,
|
||||
t: baseTime.Add(11 * time.Hour).In(otherTZ),
|
||||
name: "same_day_outside_other_tz",
|
||||
}, {
|
||||
schedule: oneMinSchedule,
|
||||
assert: assert.True,
|
||||
t: baseTime,
|
||||
name: "one_minute_beginning",
|
||||
}, {
|
||||
schedule: oneMinSchedule,
|
||||
assert: assert.True,
|
||||
t: baseTime.Add(1*time.Minute - 1),
|
||||
name: "one_minute_end",
|
||||
}, {
|
||||
schedule: oneMinSchedule,
|
||||
assert: assert.False,
|
||||
t: baseTime.Add(1 * time.Minute),
|
||||
name: "one_minute_past_end",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tc.assert(t, tc.schedule.Contains(tc.t))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const brusselsSunday = `
|
||||
sun:
|
||||
start: 12h
|
||||
end: 14h
|
||||
time_zone: Europe/Brussels
|
||||
`
|
||||
|
||||
func TestWeekly_UnmarshalYAML(t *testing.T) {
|
||||
const (
|
||||
sameTime = `
|
||||
sun:
|
||||
start: 9h
|
||||
end: 9h
|
||||
`
|
||||
negativeStart = `
|
||||
sun:
|
||||
start: -1h
|
||||
end: 1h
|
||||
`
|
||||
badTZ = `
|
||||
time_zone: "bad_timezone"
|
||||
`
|
||||
badYAML = `
|
||||
yaml: "bad"
|
||||
yaml: "bad"
|
||||
`
|
||||
)
|
||||
|
||||
brusseltsTZ, err := time.LoadLocation("Europe/Brussels")
|
||||
require.NoError(t, err)
|
||||
|
||||
brusselsWeekly := &Weekly{
|
||||
days: [7]dayRange{{
|
||||
start: time.Hour * 12,
|
||||
end: time.Hour * 14,
|
||||
}},
|
||||
location: brusseltsTZ,
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
wantErrMsg string
|
||||
data []byte
|
||||
want *Weekly
|
||||
}{{
|
||||
name: "empty",
|
||||
wantErrMsg: "",
|
||||
data: []byte(""),
|
||||
want: &Weekly{},
|
||||
}, {
|
||||
name: "null",
|
||||
wantErrMsg: "",
|
||||
data: []byte("null"),
|
||||
want: &Weekly{},
|
||||
}, {
|
||||
name: "brussels_sunday",
|
||||
wantErrMsg: "",
|
||||
data: []byte(brusselsSunday),
|
||||
want: brusselsWeekly,
|
||||
}, {
|
||||
name: "start_equal_end",
|
||||
wantErrMsg: "weekday Sunday: bad day range: start 9h0m0s is greater or equal to end 9h0m0s",
|
||||
data: []byte(sameTime),
|
||||
want: &Weekly{},
|
||||
}, {
|
||||
name: "start_negative",
|
||||
wantErrMsg: "weekday Sunday: bad day range: start -1h0m0s is negative",
|
||||
data: []byte(negativeStart),
|
||||
want: &Weekly{},
|
||||
}, {
|
||||
name: "bad_time_zone",
|
||||
wantErrMsg: "unknown time zone bad_timezone",
|
||||
data: []byte(badTZ),
|
||||
want: &Weekly{},
|
||||
}, {
|
||||
name: "bad_yaml",
|
||||
wantErrMsg: "yaml: unmarshal errors:\n line 3: mapping key \"yaml\" already defined at line 2",
|
||||
data: []byte(badYAML),
|
||||
want: &Weekly{},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
w := &Weekly{}
|
||||
err = yaml.Unmarshal(tc.data, w)
|
||||
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
||||
|
||||
assert.Equal(t, tc.want, w)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWeekly_MarshalYAML(t *testing.T) {
|
||||
brusselsTZ, err := time.LoadLocation("Europe/Brussels")
|
||||
require.NoError(t, err)
|
||||
|
||||
brusselsWeekly := &Weekly{
|
||||
days: [7]dayRange{time.Sunday: {
|
||||
start: time.Hour * 12,
|
||||
end: time.Hour * 14,
|
||||
}},
|
||||
location: brusselsTZ,
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
data []byte
|
||||
want *Weekly
|
||||
}{{
|
||||
name: "empty",
|
||||
data: []byte(""),
|
||||
want: &Weekly{},
|
||||
}, {
|
||||
name: "null",
|
||||
data: []byte("null"),
|
||||
want: &Weekly{},
|
||||
}, {
|
||||
name: "brussels_sunday",
|
||||
data: []byte(brusselsSunday),
|
||||
want: brusselsWeekly,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var data []byte
|
||||
data, err = yaml.Marshal(brusselsWeekly)
|
||||
require.NoError(t, err)
|
||||
|
||||
w := &Weekly{}
|
||||
err = yaml.Unmarshal(data, w)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, brusselsWeekly, w)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWeekly_Validate(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
in dayRange
|
||||
wantErrMsg string
|
||||
}{{
|
||||
name: "empty",
|
||||
wantErrMsg: "",
|
||||
in: dayRange{},
|
||||
}, {
|
||||
name: "start_seconds",
|
||||
wantErrMsg: "bad day range: start 1s isn't rounded to minutes",
|
||||
in: dayRange{
|
||||
start: time.Second,
|
||||
end: time.Hour,
|
||||
},
|
||||
}, {
|
||||
name: "end_seconds",
|
||||
wantErrMsg: "bad day range: end 1s isn't rounded to minutes",
|
||||
in: dayRange{
|
||||
start: 0,
|
||||
end: time.Second,
|
||||
},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
w := &Weekly{}
|
||||
err := w.validate(tc.in)
|
||||
|
||||
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDayRange_Validate(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
in dayRange
|
||||
wantErrMsg string
|
||||
}{{
|
||||
name: "empty",
|
||||
wantErrMsg: "",
|
||||
in: dayRange{},
|
||||
}, {
|
||||
name: "valid",
|
||||
wantErrMsg: "",
|
||||
in: dayRange{
|
||||
start: time.Hour,
|
||||
end: time.Hour * 2,
|
||||
},
|
||||
}, {
|
||||
name: "valid_end_max",
|
||||
wantErrMsg: "",
|
||||
in: dayRange{
|
||||
start: 0,
|
||||
end: time.Hour * 24,
|
||||
},
|
||||
}, {
|
||||
name: "start_negative",
|
||||
wantErrMsg: "start -1h0m0s is negative",
|
||||
in: dayRange{
|
||||
start: time.Hour * -1,
|
||||
end: time.Hour * 2,
|
||||
},
|
||||
}, {
|
||||
name: "end_negative",
|
||||
wantErrMsg: "end -1h0m0s is negative",
|
||||
in: dayRange{
|
||||
start: 0,
|
||||
end: time.Hour * -1,
|
||||
},
|
||||
}, {
|
||||
name: "start_equal_end",
|
||||
wantErrMsg: "start 1h0m0s is greater or equal to end 1h0m0s",
|
||||
in: dayRange{
|
||||
start: time.Hour,
|
||||
end: time.Hour,
|
||||
},
|
||||
}, {
|
||||
name: "start_greater_end",
|
||||
wantErrMsg: "start 2h0m0s is greater or equal to end 1h0m0s",
|
||||
in: dayRange{
|
||||
start: time.Hour * 2,
|
||||
end: time.Hour,
|
||||
},
|
||||
}, {
|
||||
name: "start_equal_max",
|
||||
wantErrMsg: "start 24h0m0s is greater or equal to 24h0m0s",
|
||||
in: dayRange{
|
||||
start: time.Hour * 24,
|
||||
end: time.Hour * 48,
|
||||
},
|
||||
}, {
|
||||
name: "end_greater_max",
|
||||
wantErrMsg: "end 48h0m0s is greater than 24h0m0s",
|
||||
in: dayRange{
|
||||
start: 0,
|
||||
end: time.Hour * 48,
|
||||
},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := tc.in.validate()
|
||||
|
||||
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -4,20 +4,21 @@ go 1.19
|
||||
|
||||
require (
|
||||
github.com/fzipp/gocyclo v0.6.0
|
||||
github.com/golangci/misspell v0.4.0
|
||||
github.com/gordonklaus/ineffassign v0.0.0-20230107090616-13ace0543b28
|
||||
github.com/golangci/misspell v0.4.1
|
||||
github.com/gordonklaus/ineffassign v0.0.0-20230610083614-0e73809eb601
|
||||
github.com/kisielk/errcheck v1.6.3
|
||||
github.com/kyoh86/looppointer v0.2.1
|
||||
github.com/securego/gosec/v2 v2.16.0
|
||||
golang.org/x/tools v0.9.3
|
||||
golang.org/x/vuln v0.1.0
|
||||
github.com/uudashr/gocognit v1.0.6
|
||||
golang.org/x/tools v0.10.0
|
||||
golang.org/x/vuln v0.2.0
|
||||
honnef.co/go/tools v0.4.3
|
||||
mvdan.cc/gofumpt v0.5.0
|
||||
mvdan.cc/unparam v0.0.0-20230312165513-e84e2d14e3b8
|
||||
mvdan.cc/unparam v0.0.0-20230610194454-9ea02bef9868
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/BurntSushi/toml v1.3.1 // indirect
|
||||
github.com/BurntSushi/toml v1.3.2 // indirect
|
||||
github.com/google/go-cmp v0.5.9 // indirect
|
||||
github.com/google/uuid v1.3.0 // indirect
|
||||
github.com/gookit/color v1.5.3 // indirect
|
||||
@@ -25,9 +26,9 @@ require (
|
||||
github.com/nbutton23/zxcvbn-go v0.0.0-20210217022336-fa2cb2858354 // indirect
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||
golang.org/x/exp v0.0.0-20230321023759-10a507213a29 // indirect
|
||||
golang.org/x/exp/typeparams v0.0.0-20230522175609-2e198f4a06a1 // indirect
|
||||
golang.org/x/mod v0.10.0 // indirect
|
||||
golang.org/x/sync v0.2.0 // indirect
|
||||
golang.org/x/sys v0.8.0 // indirect
|
||||
golang.org/x/exp/typeparams v0.0.0-20230626212559-97b1e661b5df // indirect
|
||||
golang.org/x/mod v0.11.0 // indirect
|
||||
golang.org/x/sync v0.3.0 // indirect
|
||||
golang.org/x/sys v0.9.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
github.com/BurntSushi/toml v1.3.1 h1:rHnDkSK+/g6DlREUK73PkmIs60pqrnuduK+JmP++JmU=
|
||||
github.com/BurntSushi/toml v1.3.1/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ=
|
||||
github.com/client9/misspell v0.3.4 h1:ta993UF76GwbvJcIo3Y68y/M3WxlpEHPWIGDkJYwzJI=
|
||||
github.com/BurntSushi/toml v1.3.2 h1:o7IhLm0Msx3BaB+n3Ag7L8EVlByGnpq14C4YWiu/gL8=
|
||||
github.com/BurntSushi/toml v1.3.2/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/frankban/quicktest v1.14.4 h1:g2rn0vABPOOXmZUj+vbmUp0lPoXEMuhTpIluN0XL9UY=
|
||||
@@ -8,8 +7,8 @@ github.com/fzipp/gocyclo v0.6.0 h1:lsblElZG7d3ALtGMx9fmxeTKZaLLpU8mET09yN4BBLo=
|
||||
github.com/fzipp/gocyclo v0.6.0/go.mod h1:rXPyn8fnlpa0R2csP/31uerbiVBugk5whMdlyaLkLoA=
|
||||
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
|
||||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
|
||||
github.com/golangci/misspell v0.4.0 h1:KtVB/hTK4bbL/S6bs64rYyk8adjmh1BygbBiaAiX+a0=
|
||||
github.com/golangci/misspell v0.4.0/go.mod h1:W6O/bwV6lGDxUCChm2ykw9NQdd5bYd1Xkjo88UcWyJc=
|
||||
github.com/golangci/misspell v0.4.1 h1:+y73iSicVy2PqyX7kmUefHusENlrP9YwuHZHPLGQj/g=
|
||||
github.com/golangci/misspell v0.4.1/go.mod h1:9mAN1quEo3DlpbaIKKyEvRxK1pwqR9s/Sea1bJCtlNI=
|
||||
github.com/google/go-cmdtest v0.4.1-0.20220921163831-55ab3332a786 h1:rcv+Ippz6RAtvaGgKxc+8FQIpxHgsF+HBzPyYL2cyVU=
|
||||
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
|
||||
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
@@ -19,8 +18,8 @@ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/gookit/color v1.5.3 h1:twfIhZs4QLCtimkP7MOxlF3A0U/5cDPseRT9M/+2SCE=
|
||||
github.com/gookit/color v1.5.3/go.mod h1:NUzwzeehUfl7GIb36pqId+UGmRfQcU/WiiyTTeNjHtE=
|
||||
github.com/gordonklaus/ineffassign v0.0.0-20230107090616-13ace0543b28 h1:9alfqbrhuD+9fLZ4iaAVwhlp5PEhmnBt7yvK2Oy5C1U=
|
||||
github.com/gordonklaus/ineffassign v0.0.0-20230107090616-13ace0543b28/go.mod h1:Qcp2HIAYhR7mNUVSIxZww3Guk4it82ghYcEXIAk+QT0=
|
||||
github.com/gordonklaus/ineffassign v0.0.0-20230610083614-0e73809eb601 h1:mrEEilTAUmaAORhssPPkxj84TsHrPMLBGW2Z4SoTxm8=
|
||||
github.com/gordonklaus/ineffassign v0.0.0-20230610083614-0e73809eb601/go.mod h1:Qcp2HIAYhR7mNUVSIxZww3Guk4it82ghYcEXIAk+QT0=
|
||||
github.com/kisielk/errcheck v1.6.3 h1:dEKh+GLHcWm2oN34nMvDzn1sqI0i0WxPvrgiJA5JuM8=
|
||||
github.com/kisielk/errcheck v1.6.3/go.mod h1:nXw/i/MfnvRHqXa7XXmQMUB0oNFGuBrNI8d8NLy0LPw=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
@@ -40,6 +39,8 @@ github.com/securego/gosec/v2 v2.16.0 h1:Pi0JKoasQQ3NnoRao/ww/N/XdynIB9NRYYZT5CyO
|
||||
github.com/securego/gosec/v2 v2.16.0/go.mod h1:xvLcVZqUfo4aAQu56TNv7/Ltz6emAOQAEsrZrt7uGlI=
|
||||
github.com/stretchr/testify v1.1.4/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||
github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk=
|
||||
github.com/uudashr/gocognit v1.0.6 h1:2Cgi6MweCsdB6kpcVQp7EW4U23iBFQWfTXiWlyp842Y=
|
||||
github.com/uudashr/gocognit v1.0.6/go.mod h1:nAIUuVBnYU7pcninia3BHOvQkpQCeO76Uscky5BOwcY=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
|
||||
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
@@ -51,25 +52,26 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/exp v0.0.0-20230321023759-10a507213a29 h1:ooxPy7fPvB4kwsA2h+iBNHkAbp/4JxTSwCmvdjEYmug=
|
||||
golang.org/x/exp v0.0.0-20230321023759-10a507213a29/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc=
|
||||
golang.org/x/exp/typeparams v0.0.0-20230522175609-2e198f4a06a1 h1:pnP8r+W8Fm7XJ8CWtXi4S9oJmPBTrkfYN/dNbaPj6Y4=
|
||||
golang.org/x/exp/typeparams v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk=
|
||||
golang.org/x/exp/typeparams v0.0.0-20230626212559-97b1e661b5df h1:jfUqBujZx2dktJVEmZpCkyngz7MWrVv1y9kLOqFNsqw=
|
||||
golang.org/x/exp/typeparams v0.0.0-20230626212559-97b1e661b5df/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk=
|
||||
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY=
|
||||
golang.org/x/mod v0.10.0 h1:lFO9qtOdlre5W1jxS3r/4szv2/6iXxScdzjoBMXNhYk=
|
||||
golang.org/x/mod v0.10.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
golang.org/x/mod v0.11.0 h1:bUO06HqtnRcc/7l71XBe4WcqTZ+3AH1J59zWDDwLKgU=
|
||||
golang.org/x/mod v0.11.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
|
||||
golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M=
|
||||
golang.org/x/net v0.11.0 h1:Gi2tvZIJyBtO9SDr1q9h5hEQCp/4L2RQ+ar0qjx2oNU=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.2.0 h1:PUR+T4wwASmuSTYdKjYHI5TD22Wy5ogLU5qZCOLxBrI=
|
||||
golang.org/x/sync v0.2.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
|
||||
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
@@ -79,8 +81,9 @@ golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7w
|
||||
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220702020025-31831981b65f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s=
|
||||
golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
@@ -92,10 +95,11 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn
|
||||
golang.org/x/tools v0.0.0-20201007032633-0806396f153e/go.mod h1:z6u4i615ZeAfBE4XtMziQW1fSVJXACjjbWkB/mvPzlU=
|
||||
golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
|
||||
golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E=
|
||||
golang.org/x/tools v0.9.3 h1:Gn1I8+64MsuTb/HpH+LmQtNas23LhUVr3rYZ0eKuaMM=
|
||||
golang.org/x/tools v0.9.3/go.mod h1:owI94Op576fPu3cIGQeHs3joujW/2Oc6MtlxbF5dfNc=
|
||||
golang.org/x/vuln v0.1.0 h1:9GRdj6wAIkDrsMevuolY+SXERPjQPp2P1ysYA0jpZe0=
|
||||
golang.org/x/vuln v0.1.0/go.mod h1:/YuzZYjGbwB8y19CisAppfyw3uTZnuCz3r+qgx/QRzU=
|
||||
golang.org/x/tools v0.1.11/go.mod h1:SgwaegtQh8clINPpECJMqnxLv9I09HLqnW3RMqW0CA4=
|
||||
golang.org/x/tools v0.10.0 h1:tvDr/iQoUqNdohiYm0LmmKcBk+q86lb9EprIUFhHHGg=
|
||||
golang.org/x/tools v0.10.0/go.mod h1:UJwyiVBsOA2uwvK/e5OY3GTpDUJriEd+/YlqAwLPmyM=
|
||||
golang.org/x/vuln v0.2.0 h1:Dlz47lW0pvPHU7tnb10S8vbMn9GnV2B6eyT7Tem5XBI=
|
||||
golang.org/x/vuln v0.2.0/go.mod h1:V0eyhHwaAaHrt42J9bgrN6rd12f6GU4T0Lu0ex2wDg4=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
@@ -107,5 +111,5 @@ honnef.co/go/tools v0.4.3 h1:o/n5/K5gXqk8Gozvs2cnL0F2S1/g1vcGCAx2vETjITw=
|
||||
honnef.co/go/tools v0.4.3/go.mod h1:36ZgoUOrqOk1GxwHhyryEkq8FQWkUO2xGuSMhUCcdvA=
|
||||
mvdan.cc/gofumpt v0.5.0 h1:0EQ+Z56k8tXjj/6TQD25BFNKQXpCvT0rnansIc7Ug5E=
|
||||
mvdan.cc/gofumpt v0.5.0/go.mod h1:HBeVDtMKRZpXyxFciAirzdKklDlGu8aAy1wEbH5Y9js=
|
||||
mvdan.cc/unparam v0.0.0-20230312165513-e84e2d14e3b8 h1:VuJo4Mt0EVPychre4fNlDWDuE5AjXtPJpRUWqZDQhaI=
|
||||
mvdan.cc/unparam v0.0.0-20230312165513-e84e2d14e3b8/go.mod h1:Oh/d7dEtzsNHGOq1Cdv8aMm3KdKhVvPbRQcM8WFpBR8=
|
||||
mvdan.cc/unparam v0.0.0-20230610194454-9ea02bef9868 h1:F4Q7pXcrU9UiU1fq0ZWqSOxKjNAteRuDr7JDk7uVLRQ=
|
||||
mvdan.cc/unparam v0.0.0-20230610194454-9ea02bef9868/go.mod h1:6ZaiQyI7Tiq0HQ56g6N8TlkSd80/LyagZeaw8mb7jYE=
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
_ "github.com/kisielk/errcheck"
|
||||
_ "github.com/kyoh86/looppointer"
|
||||
_ "github.com/securego/gosec/v2/cmd/gosec"
|
||||
_ "github.com/uudashr/gocognit/cmd/gocognit"
|
||||
_ "golang.org/x/tools/go/analysis/passes/nilness/cmd/nilness"
|
||||
_ "golang.org/x/tools/go/analysis/passes/shadow/cmd/shadow"
|
||||
_ "golang.org/x/vuln/cmd/govulncheck"
|
||||
|
||||
@@ -143,14 +143,7 @@ func Verbose() (v string) {
|
||||
runtime.Version(),
|
||||
)
|
||||
|
||||
if committime != "" {
|
||||
commitTimeUnix, err := strconv.ParseInt(committime, 10, 64)
|
||||
if err != nil {
|
||||
stringutil.WriteToBuilder(b, nl, vFmtTimeHdr, fmt.Sprintf("parse error: %s", err))
|
||||
} else {
|
||||
stringutil.WriteToBuilder(b, nl, vFmtTimeHdr, time.Unix(commitTimeUnix, 0).String())
|
||||
}
|
||||
}
|
||||
writeCommitTime(b)
|
||||
|
||||
stringutil.WriteToBuilder(b, nl, vFmtGOOSHdr, nl, vFmtGOARCHHdr)
|
||||
if goarm != "" {
|
||||
@@ -179,3 +172,16 @@ func Verbose() (v string) {
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func writeCommitTime(b *strings.Builder) {
|
||||
if committime == "" {
|
||||
return
|
||||
}
|
||||
|
||||
commitTimeUnix, err := strconv.ParseInt(committime, 10, 64)
|
||||
if err != nil {
|
||||
stringutil.WriteToBuilder(b, "\n", vFmtTimeHdr, fmt.Sprintf("parse error: %s", err))
|
||||
} else {
|
||||
stringutil.WriteToBuilder(b, "\n", vFmtTimeHdr, time.Unix(commitTimeUnix, 0).String())
|
||||
}
|
||||
}
|
||||
|
||||
388
internal/whois/whois.go
Normal file
388
internal/whois/whois.go
Normal file
@@ -0,0 +1,388 @@
|
||||
// Package whois provides WHOIS functionality.
|
||||
package whois
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/bluele/gcache"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultServer is the default WHOIS server.
|
||||
DefaultServer = "whois.arin.net"
|
||||
|
||||
// DefaultPort is the default port for WHOIS requests.
|
||||
DefaultPort = 43
|
||||
)
|
||||
|
||||
// Interface provides WHOIS functionality.
|
||||
type Interface interface {
|
||||
// Process makes WHOIS request and returns WHOIS information or nil.
|
||||
// changed indicates that Info was updated since last request.
|
||||
Process(ctx context.Context, ip netip.Addr) (info *Info, changed bool)
|
||||
}
|
||||
|
||||
// Empty is an empty [Interface] implementation which does nothing.
|
||||
type Empty struct{}
|
||||
|
||||
// type check
|
||||
var _ Interface = (*Empty)(nil)
|
||||
|
||||
// Process implements the [Interface] interface for Empty.
|
||||
func (Empty) Process(_ context.Context, _ netip.Addr) (info *Info, changed bool) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Config is the configuration structure for Default.
|
||||
type Config struct {
|
||||
// DialContext specifies the dial function for creating unencrypted TCP
|
||||
// connections.
|
||||
DialContext func(ctx context.Context, network, addr string) (conn net.Conn, err error)
|
||||
|
||||
// ServerAddr is the address of the WHOIS server.
|
||||
ServerAddr string
|
||||
|
||||
// Timeout is the timeout for WHOIS requests.
|
||||
Timeout time.Duration
|
||||
|
||||
// CacheTTL is the Time to Live duration for cached IP addresses.
|
||||
CacheTTL time.Duration
|
||||
|
||||
// MaxConnReadSize is an upper limit in bytes for reading from net.Conn.
|
||||
MaxConnReadSize int64
|
||||
|
||||
// MaxRedirects is the maximum redirects count.
|
||||
MaxRedirects int
|
||||
|
||||
// MaxInfoLen is the maximum length of Info fields returned by Process.
|
||||
MaxInfoLen int
|
||||
|
||||
// CacheSize is the maximum size of the cache. It must be greater than
|
||||
// zero.
|
||||
CacheSize int
|
||||
|
||||
// Port is the port for WHOIS requests.
|
||||
Port uint16
|
||||
}
|
||||
|
||||
// Default is the default WHOIS information processor.
|
||||
type Default struct {
|
||||
// cache is the cache containing IP addresses of clients. An active IP
|
||||
// address is resolved once again after it expires. If IP address couldn't
|
||||
// be resolved, it stays here for some time to prevent further attempts to
|
||||
// resolve the same IP.
|
||||
cache gcache.Cache
|
||||
|
||||
// dialContext connects to a remote server resolving hostname using our own
|
||||
// DNS server and unecrypted TCP connection.
|
||||
dialContext func(ctx context.Context, network, addr string) (conn net.Conn, err error)
|
||||
|
||||
// serverAddr is the address of the WHOIS server.
|
||||
serverAddr string
|
||||
|
||||
// portStr is the port for WHOIS requests.
|
||||
portStr string
|
||||
|
||||
// timeout is the timeout for WHOIS requests.
|
||||
timeout time.Duration
|
||||
|
||||
// cacheTTL is the Time to Live duration for cached IP addresses.
|
||||
cacheTTL time.Duration
|
||||
|
||||
// maxConnReadSize is an upper limit in bytes for reading from net.Conn.
|
||||
maxConnReadSize int64
|
||||
|
||||
// maxRedirects is the maximum redirects count.
|
||||
maxRedirects int
|
||||
|
||||
// maxInfoLen is the maximum length of Info fields returned by Process.
|
||||
maxInfoLen int
|
||||
}
|
||||
|
||||
// New returns a new default WHOIS information processor. conf must not be
|
||||
// nil.
|
||||
func New(conf *Config) (w *Default) {
|
||||
return &Default{
|
||||
serverAddr: conf.ServerAddr,
|
||||
dialContext: conf.DialContext,
|
||||
timeout: conf.Timeout,
|
||||
cache: gcache.New(conf.CacheSize).LRU().Build(),
|
||||
maxConnReadSize: conf.MaxConnReadSize,
|
||||
maxRedirects: conf.MaxRedirects,
|
||||
portStr: strconv.Itoa(int(conf.Port)),
|
||||
maxInfoLen: conf.MaxInfoLen,
|
||||
cacheTTL: conf.CacheTTL,
|
||||
}
|
||||
}
|
||||
|
||||
// trimValue trims s and replaces the last 3 characters of the cut with "..."
|
||||
// to fit into max. max must be greater than 3.
|
||||
func trimValue(s string, max int) string {
|
||||
if len(s) <= max {
|
||||
return s
|
||||
}
|
||||
|
||||
return s[:max-3] + "..."
|
||||
}
|
||||
|
||||
// isWHOISComment returns true if the data is empty or is a WHOIS comment.
|
||||
func isWHOISComment(data []byte) (ok bool) {
|
||||
return len(data) == 0 || data[0] == '#' || data[0] == '%'
|
||||
}
|
||||
|
||||
// whoisParse parses a subset of plain-text data from the WHOIS response into a
|
||||
// string map. It trims values of the returned map to maxLen.
|
||||
func whoisParse(data []byte, maxLen int) (info map[string]string) {
|
||||
info = map[string]string{}
|
||||
|
||||
var orgname string
|
||||
lines := bytes.Split(data, []byte("\n"))
|
||||
for _, l := range lines {
|
||||
if isWHOISComment(l) {
|
||||
continue
|
||||
}
|
||||
|
||||
before, after, found := bytes.Cut(l, []byte(":"))
|
||||
if !found {
|
||||
continue
|
||||
}
|
||||
|
||||
key := strings.ToLower(string(before))
|
||||
val := strings.TrimSpace(string(after))
|
||||
if val == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
switch key {
|
||||
case "orgname", "org-name":
|
||||
key = "orgname"
|
||||
val = trimValue(val, maxLen)
|
||||
orgname = val
|
||||
case "city", "country":
|
||||
val = trimValue(val, maxLen)
|
||||
case "descr", "netname":
|
||||
key = "orgname"
|
||||
val = stringutil.Coalesce(orgname, val)
|
||||
orgname = val
|
||||
case "whois":
|
||||
key = "whois"
|
||||
case "referralserver":
|
||||
key = "whois"
|
||||
val = strings.TrimPrefix(val, "whois://")
|
||||
default:
|
||||
continue
|
||||
}
|
||||
|
||||
info[key] = val
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
// query sends request to a server and returns the response or error.
|
||||
func (w *Default) query(ctx context.Context, target, serverAddr string) (data []byte, err error) {
|
||||
addr, _, _ := net.SplitHostPort(serverAddr)
|
||||
if addr == DefaultServer {
|
||||
// Display type flags for query.
|
||||
//
|
||||
// See https://www.arin.net/resources/registry/whois/rws/api/#nicname-whois-queries.
|
||||
target = "n + " + target
|
||||
}
|
||||
|
||||
conn, err := w.dialContext(ctx, "tcp", serverAddr)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
defer func() { err = errors.WithDeferred(err, conn.Close()) }()
|
||||
|
||||
r, err := aghio.LimitReader(conn, w.maxConnReadSize)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_ = conn.SetReadDeadline(time.Now().Add(w.timeout))
|
||||
_, err = io.WriteString(conn, target+"\r\n")
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// This use of ReadAll is now safe, because we limited the conn Reader.
|
||||
data, err = io.ReadAll(r)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// queryAll queries WHOIS server and handles redirects.
|
||||
func (w *Default) queryAll(ctx context.Context, target string) (info map[string]string, err error) {
|
||||
server := net.JoinHostPort(w.serverAddr, w.portStr)
|
||||
var data []byte
|
||||
|
||||
for i := 0; i < w.maxRedirects; i++ {
|
||||
data, err = w.query(ctx, target, server)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Debug("whois: received response (%d bytes) from %q about %q", len(data), server, target)
|
||||
|
||||
info = whoisParse(data, w.maxInfoLen)
|
||||
redir, ok := info["whois"]
|
||||
if !ok {
|
||||
return info, nil
|
||||
}
|
||||
|
||||
redir = strings.ToLower(redir)
|
||||
|
||||
_, _, err = net.SplitHostPort(redir)
|
||||
if err != nil {
|
||||
server = net.JoinHostPort(redir, w.portStr)
|
||||
} else {
|
||||
server = redir
|
||||
}
|
||||
|
||||
log.Debug("whois: redirected to %q about %q", redir, target)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("whois: redirect loop")
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ Interface = (*Default)(nil)
|
||||
|
||||
// Process makes WHOIS request and returns WHOIS information or nil. changed
|
||||
// indicates that Info was updated since last request.
|
||||
func (w *Default) Process(ctx context.Context, ip netip.Addr) (wi *Info, changed bool) {
|
||||
if netutil.IsSpecialPurposeAddr(ip) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
wi, expired := w.findInCache(ip)
|
||||
if wi != nil && !expired {
|
||||
// Don't return an empty struct so that the frontend doesn't get
|
||||
// confused.
|
||||
if (*wi == Info{}) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return wi, false
|
||||
}
|
||||
|
||||
return w.requestInfo(ctx, ip, wi)
|
||||
}
|
||||
|
||||
// requestInfo makes WHOIS request and returns WHOIS info. changed is false if
|
||||
// received information is equal to cached.
|
||||
func (w *Default) requestInfo(
|
||||
ctx context.Context,
|
||||
ip netip.Addr,
|
||||
cached *Info,
|
||||
) (wi *Info, changed bool) {
|
||||
var info Info
|
||||
|
||||
defer func() {
|
||||
item := toCacheItem(info, w.cacheTTL)
|
||||
err := w.cache.Set(ip, item)
|
||||
if err != nil {
|
||||
log.Debug("whois: cache: adding item %q: %s", ip, err)
|
||||
}
|
||||
}()
|
||||
|
||||
kv, err := w.queryAll(ctx, ip.String())
|
||||
if err != nil {
|
||||
log.Debug("whois: quering about %q: %s", ip, err)
|
||||
|
||||
return nil, true
|
||||
}
|
||||
|
||||
info = Info{
|
||||
City: kv["city"],
|
||||
Country: kv["country"],
|
||||
Orgname: kv["orgname"],
|
||||
}
|
||||
|
||||
changed = cached == nil || info != *cached
|
||||
|
||||
// Don't return an empty struct so that the frontend doesn't get confused.
|
||||
if (info == Info{}) {
|
||||
return nil, changed
|
||||
}
|
||||
|
||||
return &info, changed
|
||||
}
|
||||
|
||||
// findInCache finds Info in the cache. expired indicates that Info is valid.
|
||||
func (w *Default) findInCache(ip netip.Addr) (wi *Info, expired bool) {
|
||||
val, err := w.cache.Get(ip)
|
||||
if err != nil {
|
||||
if !errors.Is(err, gcache.KeyNotFoundError) {
|
||||
log.Debug("whois: cache: retrieving info about %q: %s", ip, err)
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
item, ok := val.(*cacheItem)
|
||||
if !ok {
|
||||
log.Debug("whois: cache: %q bad type %T", ip, val)
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return fromCacheItem(item)
|
||||
}
|
||||
|
||||
// Info is the filtered WHOIS data for a runtime client.
|
||||
type Info struct {
|
||||
City string `json:"city,omitempty"`
|
||||
Country string `json:"country,omitempty"`
|
||||
Orgname string `json:"orgname,omitempty"`
|
||||
}
|
||||
|
||||
// cacheItem represents an item that we will store in the cache.
|
||||
type cacheItem struct {
|
||||
// expiry is the time when cacheItem will expire.
|
||||
expiry time.Time
|
||||
|
||||
// info is the WHOIS data for a runtime client.
|
||||
info *Info
|
||||
}
|
||||
|
||||
// toCacheItem creates a cached item from a WHOIS info and Time to Live
|
||||
// duration.
|
||||
func toCacheItem(info Info, ttl time.Duration) (item *cacheItem) {
|
||||
return &cacheItem{
|
||||
expiry: time.Now().Add(ttl),
|
||||
info: &info,
|
||||
}
|
||||
}
|
||||
|
||||
// fromCacheItem creates a WHOIS info from the cached item. expired indicates
|
||||
// that WHOIS info is valid. item must not be nil.
|
||||
func fromCacheItem(item *cacheItem) (info *Info, expired bool) {
|
||||
if time.Now().After(item.expiry) {
|
||||
return item.info, true
|
||||
}
|
||||
|
||||
return item.info, false
|
||||
}
|
||||
155
internal/whois/whois_test.go
Normal file
155
internal/whois/whois_test.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package whois_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||
"github.com/AdguardTeam/golibs/testutil/fakenet"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDefault_Process(t *testing.T) {
|
||||
const (
|
||||
nl = "\n"
|
||||
city = "Nonreal"
|
||||
country = "Imagiland"
|
||||
orgname = "FakeOrgLLC"
|
||||
referralserver = "whois.example.net"
|
||||
)
|
||||
|
||||
ip := netip.MustParseAddr("1.2.3.4")
|
||||
|
||||
testCases := []struct {
|
||||
want *whois.Info
|
||||
name string
|
||||
data string
|
||||
}{{
|
||||
want: nil,
|
||||
name: "empty",
|
||||
data: "",
|
||||
}, {
|
||||
want: nil,
|
||||
name: "comments",
|
||||
data: "%\n#",
|
||||
}, {
|
||||
want: nil,
|
||||
name: "no_colon",
|
||||
data: "city",
|
||||
}, {
|
||||
want: nil,
|
||||
name: "no_value",
|
||||
data: "city:",
|
||||
}, {
|
||||
want: &whois.Info{
|
||||
City: city,
|
||||
},
|
||||
name: "city",
|
||||
data: "city: " + city,
|
||||
}, {
|
||||
want: &whois.Info{
|
||||
Country: country,
|
||||
},
|
||||
name: "country",
|
||||
data: "country: " + country,
|
||||
}, {
|
||||
want: &whois.Info{
|
||||
Orgname: orgname,
|
||||
},
|
||||
name: "orgname",
|
||||
data: "orgname: " + orgname,
|
||||
}, {
|
||||
want: &whois.Info{
|
||||
Orgname: orgname,
|
||||
},
|
||||
name: "orgname_hyphen",
|
||||
data: "org-name: " + orgname,
|
||||
}, {
|
||||
want: &whois.Info{
|
||||
Orgname: orgname,
|
||||
},
|
||||
name: "orgname_descr",
|
||||
data: "descr: " + orgname,
|
||||
}, {
|
||||
want: &whois.Info{
|
||||
Orgname: orgname,
|
||||
},
|
||||
name: "orgname_netname",
|
||||
data: "netname: " + orgname,
|
||||
}, {
|
||||
want: &whois.Info{
|
||||
City: city,
|
||||
Country: country,
|
||||
Orgname: orgname,
|
||||
},
|
||||
name: "full",
|
||||
data: "OrgName: " + orgname + nl + "City: " + city + nl + "Country: " + country,
|
||||
}, {
|
||||
want: nil,
|
||||
name: "whois",
|
||||
data: "whois: " + referralserver,
|
||||
}, {
|
||||
want: nil,
|
||||
name: "referralserver",
|
||||
data: "referralserver: whois://" + referralserver,
|
||||
}, {
|
||||
want: nil,
|
||||
name: "other",
|
||||
data: "other: value",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
hit := 0
|
||||
|
||||
fakeConn := &fakenet.Conn{
|
||||
OnRead: func(b []byte) (n int, err error) {
|
||||
hit++
|
||||
|
||||
return copy(b, tc.data), io.EOF
|
||||
},
|
||||
OnWrite: func(b []byte) (n int, err error) {
|
||||
return len(b), nil
|
||||
},
|
||||
OnClose: func() (err error) {
|
||||
return nil
|
||||
},
|
||||
OnSetReadDeadline: func(t time.Time) (err error) {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
w := whois.New(&whois.Config{
|
||||
Timeout: 5 * time.Second,
|
||||
DialContext: func(_ context.Context, _, addr string) (_ net.Conn, _ error) {
|
||||
hit = 0
|
||||
|
||||
return fakeConn, nil
|
||||
},
|
||||
MaxConnReadSize: 1024,
|
||||
MaxRedirects: 3,
|
||||
MaxInfoLen: 250,
|
||||
CacheSize: 100,
|
||||
CacheTTL: time.Hour,
|
||||
})
|
||||
|
||||
got, changed := w.Process(context.Background(), ip)
|
||||
require.True(t, changed)
|
||||
|
||||
assert.Equal(t, tc.want, got)
|
||||
assert.Equal(t, 1, hit)
|
||||
|
||||
// From cache.
|
||||
got, changed = w.Process(context.Background(), ip)
|
||||
require.False(t, changed)
|
||||
|
||||
assert.Equal(t, tc.want, got)
|
||||
assert.Equal(t, 1, hit)
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user