all: sync with master; upd chlog
This commit is contained in:
@@ -7,6 +7,7 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
)
|
||||
@@ -22,12 +23,14 @@ type Client struct {
|
||||
safeSearchConf filtering.SafeSearchConfig
|
||||
SafeSearch filtering.SafeSearch
|
||||
|
||||
// BlockedServices is the configuration of blocked services of a client.
|
||||
BlockedServices *filtering.BlockedServices
|
||||
|
||||
Name string
|
||||
|
||||
IDs []string
|
||||
Tags []string
|
||||
BlockedServices []string
|
||||
Upstreams []string
|
||||
IDs []string
|
||||
Tags []string
|
||||
Upstreams []string
|
||||
|
||||
UseOwnSettings bool
|
||||
FilteringEnabled bool
|
||||
@@ -43,9 +46,9 @@ type Client struct {
|
||||
func (c *Client) ShallowClone() (sh *Client) {
|
||||
clone := *c
|
||||
|
||||
clone.BlockedServices = c.BlockedServices.Clone()
|
||||
clone.IDs = stringutil.CloneSlice(c.IDs)
|
||||
clone.Tags = stringutil.CloneSlice(c.Tags)
|
||||
clone.BlockedServices = stringutil.CloneSlice(c.BlockedServices)
|
||||
clone.Upstreams = stringutil.CloneSlice(c.Upstreams)
|
||||
|
||||
return &clone
|
||||
@@ -127,14 +130,13 @@ func (cs clientSource) MarshalText() (text []byte, err error) {
|
||||
// RuntimeClient is a client information about which has been obtained using the
|
||||
// source described in the Source field.
|
||||
type RuntimeClient struct {
|
||||
WHOISInfo *RuntimeClientWHOISInfo
|
||||
Host string
|
||||
Source clientSource
|
||||
}
|
||||
// WHOIS is the filtered WHOIS data of a client.
|
||||
WHOIS *whois.Info
|
||||
|
||||
// RuntimeClientWHOISInfo is the filtered WHOIS data for a runtime client.
|
||||
type RuntimeClientWHOISInfo struct {
|
||||
City string `json:"city,omitempty"`
|
||||
Country string `json:"country,omitempty"`
|
||||
Orgname string `json:"orgname,omitempty"`
|
||||
// Host is the host name of a client.
|
||||
Host string
|
||||
|
||||
// Source is the source from which the information about the client has
|
||||
// been obtained.
|
||||
Source clientSource
|
||||
}
|
||||
|
||||
@@ -11,9 +11,11 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
@@ -23,6 +25,23 @@ import (
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
// DHCP is an interface for accessing DHCP lease data the [clientsContainer]
|
||||
// needs.
|
||||
type DHCP interface {
|
||||
// Leases returns all the DHCP leases.
|
||||
Leases() (leases []*dhcpsvc.Lease)
|
||||
|
||||
// HostByIP returns the hostname of the DHCP client with the given IP
|
||||
// address. The address will be netip.Addr{} if there is no such client,
|
||||
// due to an assumption that a DHCP client must always have an IP address.
|
||||
HostByIP(ip netip.Addr) (host string)
|
||||
|
||||
// MACByIP returns the MAC address for the given IP address leased. It
|
||||
// returns nil if there is no such client, due to an assumption that a DHCP
|
||||
// client must always have a MAC address.
|
||||
MACByIP(ip netip.Addr) (mac net.HardwareAddr)
|
||||
}
|
||||
|
||||
// clientsContainer is the storage of all runtime and persistent clients.
|
||||
type clientsContainer struct {
|
||||
// TODO(a.garipov): Perhaps use a number of separate indices for different
|
||||
@@ -77,7 +96,7 @@ func (clients *clientsContainer) Init(
|
||||
etcHosts *aghnet.HostsContainer,
|
||||
arpdb aghnet.ARPDB,
|
||||
filteringConf *filtering.Config,
|
||||
) {
|
||||
) (err error) {
|
||||
if clients.list != nil {
|
||||
log.Fatal("clients.list != nil")
|
||||
}
|
||||
@@ -91,23 +110,29 @@ func (clients *clientsContainer) Init(
|
||||
clients.dhcpServer = dhcpServer
|
||||
clients.etcHosts = etcHosts
|
||||
clients.arpdb = arpdb
|
||||
clients.addFromConfig(objects, filteringConf)
|
||||
err = clients.addFromConfig(objects, filteringConf)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
clients.safeSearchCacheSize = filteringConf.SafeSearchCacheSize
|
||||
clients.safeSearchCacheTTL = time.Minute * time.Duration(filteringConf.CacheTime)
|
||||
|
||||
if clients.testing {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
clients.updateFromDHCP(true)
|
||||
if clients.dhcpServer != nil {
|
||||
clients.dhcpServer.SetOnLeaseChanged(clients.onDHCPLeaseChanged)
|
||||
clients.onDHCPLeaseChanged(dhcpd.LeaseChangedAdded)
|
||||
}
|
||||
|
||||
if clients.etcHosts != nil {
|
||||
go clients.handleHostsUpdates()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (clients *clientsContainer) handleHostsUpdates() {
|
||||
@@ -147,12 +172,14 @@ func (clients *clientsContainer) reloadARP() {
|
||||
type clientObject struct {
|
||||
SafeSearchConf filtering.SafeSearchConfig `yaml:"safe_search"`
|
||||
|
||||
// BlockedServices is the configuration of blocked services of a client.
|
||||
BlockedServices *filtering.BlockedServices `yaml:"blocked_services"`
|
||||
|
||||
Name string `yaml:"name"`
|
||||
|
||||
Tags []string `yaml:"tags"`
|
||||
IDs []string `yaml:"ids"`
|
||||
BlockedServices []string `yaml:"blocked_services"`
|
||||
Upstreams []string `yaml:"upstreams"`
|
||||
IDs []string `yaml:"ids"`
|
||||
Tags []string `yaml:"tags"`
|
||||
Upstreams []string `yaml:"upstreams"`
|
||||
|
||||
UseGlobalSettings bool `yaml:"use_global_settings"`
|
||||
FilteringEnabled bool `yaml:"filtering_enabled"`
|
||||
@@ -166,7 +193,10 @@ type clientObject struct {
|
||||
|
||||
// addFromConfig initializes the clients container with objects from the
|
||||
// configuration file.
|
||||
func (clients *clientsContainer) addFromConfig(objects []*clientObject, filteringConf *filtering.Config) {
|
||||
func (clients *clientsContainer) addFromConfig(
|
||||
objects []*clientObject,
|
||||
filteringConf *filtering.Config,
|
||||
) (err error) {
|
||||
for _, o := range objects {
|
||||
cli := &Client{
|
||||
Name: o.Name,
|
||||
@@ -187,7 +217,7 @@ func (clients *clientsContainer) addFromConfig(objects []*clientObject, filterin
|
||||
if o.SafeSearchConf.Enabled {
|
||||
o.SafeSearchConf.CustomResolver = safeSearchResolver{}
|
||||
|
||||
err := cli.setSafeSearch(
|
||||
err = cli.setSafeSearch(
|
||||
o.SafeSearchConf,
|
||||
filteringConf.SafeSearchCacheSize,
|
||||
time.Minute*time.Duration(filteringConf.CacheTime),
|
||||
@@ -199,14 +229,13 @@ func (clients *clientsContainer) addFromConfig(objects []*clientObject, filterin
|
||||
}
|
||||
}
|
||||
|
||||
for _, s := range o.BlockedServices {
|
||||
if filtering.BlockedSvcKnown(s) {
|
||||
cli.BlockedServices = append(cli.BlockedServices, s)
|
||||
} else {
|
||||
log.Info("clients: skipping unknown blocked service %q", s)
|
||||
}
|
||||
err = o.BlockedServices.Validate()
|
||||
if err != nil {
|
||||
return fmt.Errorf("clients: init client blocked services %q: %w", cli.Name, err)
|
||||
}
|
||||
|
||||
cli.BlockedServices = o.BlockedServices.Clone()
|
||||
|
||||
for _, t := range o.Tags {
|
||||
if clients.allTags.Has(t) {
|
||||
cli.Tags = append(cli.Tags, t)
|
||||
@@ -217,11 +246,13 @@ func (clients *clientsContainer) addFromConfig(objects []*clientObject, filterin
|
||||
|
||||
slices.Sort(cli.Tags)
|
||||
|
||||
_, err := clients.Add(cli)
|
||||
_, err = clients.Add(cli)
|
||||
if err != nil {
|
||||
log.Error("clients: adding clients %s: %s", cli.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// forConfig returns all currently known persistent clients as objects for the
|
||||
@@ -235,10 +266,11 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) {
|
||||
o := &clientObject{
|
||||
Name: cli.Name,
|
||||
|
||||
Tags: stringutil.CloneSlice(cli.Tags),
|
||||
IDs: stringutil.CloneSlice(cli.IDs),
|
||||
BlockedServices: stringutil.CloneSlice(cli.BlockedServices),
|
||||
Upstreams: stringutil.CloneSlice(cli.Upstreams),
|
||||
BlockedServices: cli.BlockedServices.Clone(),
|
||||
|
||||
IDs: stringutil.CloneSlice(cli.IDs),
|
||||
Tags: stringutil.CloneSlice(cli.Tags),
|
||||
Upstreams: stringutil.CloneSlice(cli.Upstreams),
|
||||
|
||||
UseGlobalSettings: !cli.UseOwnSettings,
|
||||
FilteringEnabled: cli.FilteringEnabled,
|
||||
@@ -276,15 +308,38 @@ func (clients *clientsContainer) periodicUpdate() {
|
||||
}
|
||||
}
|
||||
|
||||
// onDHCPLeaseChanged is a callback for the DHCP server. It updates the list of
|
||||
// runtime clients using the DHCP server's leases.
|
||||
//
|
||||
// TODO(e.burkov): Remove when switched to dhcpsvc.
|
||||
func (clients *clientsContainer) onDHCPLeaseChanged(flags int) {
|
||||
switch flags {
|
||||
case dhcpd.LeaseChangedAdded,
|
||||
dhcpd.LeaseChangedAddedStatic,
|
||||
dhcpd.LeaseChangedRemovedStatic:
|
||||
clients.updateFromDHCP(true)
|
||||
case dhcpd.LeaseChangedRemovedAll:
|
||||
clients.updateFromDHCP(false)
|
||||
if clients.dhcpServer == nil || !config.Clients.Sources.DHCP {
|
||||
return
|
||||
}
|
||||
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
clients.rmHostsBySrc(ClientSourceDHCP)
|
||||
|
||||
if flags == dhcpd.LeaseChangedRemovedAll {
|
||||
return
|
||||
}
|
||||
|
||||
leases := clients.dhcpServer.Leases(dhcpd.LeasesAll)
|
||||
n := 0
|
||||
for _, l := range leases {
|
||||
if l.Hostname == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
ok := clients.addHostLocked(l.IP, l.Hostname, ClientSourceDHCP)
|
||||
if ok {
|
||||
n++
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug("clients: added %d client aliases from dhcp", n)
|
||||
}
|
||||
|
||||
// clientSource checks if client with this IP address already exists and returns
|
||||
@@ -300,23 +355,11 @@ func (clients *clientsContainer) clientSource(ip netip.Addr) (src clientSource)
|
||||
}
|
||||
|
||||
rc, ok := clients.ipToRC[ip]
|
||||
if !ok {
|
||||
return ClientSourceNone
|
||||
if ok {
|
||||
return rc.Source
|
||||
}
|
||||
|
||||
return rc.Source
|
||||
}
|
||||
|
||||
func toQueryLogWHOIS(wi *RuntimeClientWHOISInfo) (cw *querylog.ClientWHOIS) {
|
||||
if wi == nil {
|
||||
return &querylog.ClientWHOIS{}
|
||||
}
|
||||
|
||||
return &querylog.ClientWHOIS{
|
||||
City: wi.City,
|
||||
Country: wi.Country,
|
||||
Orgname: wi.Orgname,
|
||||
}
|
||||
return ClientSourceNone
|
||||
}
|
||||
|
||||
// findMultiple is a wrapper around Find to make it a valid client finder for
|
||||
@@ -352,7 +395,7 @@ func (clients *clientsContainer) clientOrArtificial(
|
||||
defer func() {
|
||||
c.Disallowed, c.DisallowedRule = clients.dnsServer.IsBlockedClient(ip, id)
|
||||
if c.WHOIS == nil {
|
||||
c.WHOIS = &querylog.ClientWHOIS{}
|
||||
c.WHOIS = &whois.Info{}
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -369,7 +412,7 @@ func (clients *clientsContainer) clientOrArtificial(
|
||||
if ok {
|
||||
return &querylog.Client{
|
||||
Name: rc.Host,
|
||||
WHOIS: toQueryLogWHOIS(rc.WHOISInfo),
|
||||
WHOIS: rc.WHOIS,
|
||||
}, false
|
||||
}
|
||||
|
||||
@@ -477,11 +520,11 @@ func (clients *clientsContainer) findLocked(id string) (c *Client, ok bool) {
|
||||
}
|
||||
}
|
||||
|
||||
if clients.dhcpServer == nil {
|
||||
return nil, false
|
||||
if clients.dhcpServer != nil {
|
||||
return clients.findDHCP(ip)
|
||||
}
|
||||
|
||||
return clients.findDHCP(ip)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// findDHCP searches for a client by its MAC, if the DHCP server is active and
|
||||
@@ -701,35 +744,34 @@ func (clients *clientsContainer) Update(prev, c *Client) (err error) {
|
||||
}
|
||||
|
||||
// setWHOISInfo sets the WHOIS information for a client.
|
||||
func (clients *clientsContainer) setWHOISInfo(ip netip.Addr, wi *RuntimeClientWHOISInfo) {
|
||||
func (clients *clientsContainer) setWHOISInfo(ip netip.Addr, wi *whois.Info) {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
_, ok := clients.findLocked(ip.String())
|
||||
if ok {
|
||||
log.Debug("clients: client for %s is already created, ignore whois info", ip)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// TODO(e.burkov): Consider storing WHOIS information separately and
|
||||
// potentially get rid of [RuntimeClient].
|
||||
rc, ok := clients.ipToRC[ip]
|
||||
if ok {
|
||||
rc.WHOISInfo = wi
|
||||
if !ok {
|
||||
// Create a RuntimeClient implicitly so that we don't do this check
|
||||
// again.
|
||||
rc = &RuntimeClient{
|
||||
Source: ClientSourceWHOIS,
|
||||
}
|
||||
clients.ipToRC[ip] = rc
|
||||
|
||||
log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi)
|
||||
} else {
|
||||
log.Debug("clients: set whois info for runtime client %s: %+v", rc.Host, wi)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Create a RuntimeClient implicitly so that we don't do this check
|
||||
// again.
|
||||
rc = &RuntimeClient{
|
||||
Source: ClientSourceWHOIS,
|
||||
}
|
||||
|
||||
rc.WHOISInfo = wi
|
||||
|
||||
clients.ipToRC[ip] = rc
|
||||
|
||||
log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi)
|
||||
rc.WHOIS = wi
|
||||
}
|
||||
|
||||
// AddHost adds a new IP-hostname pairing. The priorities of the sources are
|
||||
@@ -753,23 +795,19 @@ func (clients *clientsContainer) addHostLocked(
|
||||
src clientSource,
|
||||
) (ok bool) {
|
||||
rc, ok := clients.ipToRC[ip]
|
||||
if ok {
|
||||
if rc.Source > src {
|
||||
return false
|
||||
}
|
||||
|
||||
rc.Host = host
|
||||
rc.Source = src
|
||||
} else {
|
||||
if !ok {
|
||||
rc = &RuntimeClient{
|
||||
Host: host,
|
||||
Source: src,
|
||||
WHOISInfo: &RuntimeClientWHOISInfo{},
|
||||
WHOIS: &whois.Info{},
|
||||
}
|
||||
|
||||
clients.ipToRC[ip] = rc
|
||||
} else if src < rc.Source {
|
||||
return false
|
||||
}
|
||||
|
||||
rc.Host = host
|
||||
rc.Source = src
|
||||
|
||||
log.Debug("clients: added %s -> %q [%d]", ip, host, len(clients.ipToRC))
|
||||
|
||||
return true
|
||||
@@ -838,38 +876,6 @@ func (clients *clientsContainer) addFromSystemARP() {
|
||||
log.Debug("clients: added %d client aliases from arp neighborhood", added)
|
||||
}
|
||||
|
||||
// updateFromDHCP adds the clients that have a non-empty hostname from the DHCP
|
||||
// server.
|
||||
func (clients *clientsContainer) updateFromDHCP(add bool) {
|
||||
if clients.dhcpServer == nil || !config.Clients.Sources.DHCP {
|
||||
return
|
||||
}
|
||||
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
clients.rmHostsBySrc(ClientSourceDHCP)
|
||||
|
||||
if !add {
|
||||
return
|
||||
}
|
||||
|
||||
leases := clients.dhcpServer.Leases(dhcpd.LeasesAll)
|
||||
n := 0
|
||||
for _, l := range leases {
|
||||
if l.Hostname == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
ok := clients.addHostLocked(l.IP, l.Hostname, ClientSourceDHCP)
|
||||
if ok {
|
||||
n++
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug("clients: added %d client aliases from dhcp", n)
|
||||
}
|
||||
|
||||
// close gracefully closes all the client-specific upstream configurations of
|
||||
// the persistent clients.
|
||||
func (clients *clientsContainer) close() (err error) {
|
||||
|
||||
@@ -9,25 +9,26 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// newClientsContainer is a helper that creates a new clients container for
|
||||
// tests.
|
||||
func newClientsContainer() (c *clientsContainer) {
|
||||
func newClientsContainer(t *testing.T) (c *clientsContainer) {
|
||||
c = &clientsContainer{
|
||||
testing: true,
|
||||
}
|
||||
|
||||
c.Init(nil, nil, nil, nil, &filtering.Config{})
|
||||
err := c.Init(nil, nil, nil, nil, &filtering.Config{})
|
||||
require.NoError(t, err)
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func TestClients(t *testing.T) {
|
||||
clients := newClientsContainer()
|
||||
clients := newClientsContainer(t)
|
||||
|
||||
t.Run("add_success", func(t *testing.T) {
|
||||
var (
|
||||
@@ -198,8 +199,8 @@ func TestClients(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientsWHOIS(t *testing.T) {
|
||||
clients := newClientsContainer()
|
||||
whois := &RuntimeClientWHOISInfo{
|
||||
clients := newClientsContainer(t)
|
||||
whois := &whois.Info{
|
||||
Country: "AU",
|
||||
Orgname: "Example Org",
|
||||
}
|
||||
@@ -210,7 +211,7 @@ func TestClientsWHOIS(t *testing.T) {
|
||||
rc := clients.ipToRC[ip]
|
||||
require.NotNil(t, rc)
|
||||
|
||||
assert.Equal(t, rc.WHOISInfo, whois)
|
||||
assert.Equal(t, rc.WHOIS, whois)
|
||||
})
|
||||
|
||||
t.Run("existing_auto-client", func(t *testing.T) {
|
||||
@@ -222,7 +223,7 @@ func TestClientsWHOIS(t *testing.T) {
|
||||
rc := clients.ipToRC[ip]
|
||||
require.NotNil(t, rc)
|
||||
|
||||
assert.Equal(t, rc.WHOISInfo, whois)
|
||||
assert.Equal(t, rc.WHOIS, whois)
|
||||
})
|
||||
|
||||
t.Run("can't_set_manually-added", func(t *testing.T) {
|
||||
@@ -244,7 +245,7 @@ func TestClientsWHOIS(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientsAddExisting(t *testing.T) {
|
||||
clients := newClientsContainer()
|
||||
clients := newClientsContainer(t)
|
||||
|
||||
t.Run("simple", func(t *testing.T) {
|
||||
ip := netip.MustParseAddr("1.1.1.1")
|
||||
@@ -316,7 +317,7 @@ func TestClientsAddExisting(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientsCustomUpstream(t *testing.T) {
|
||||
clients := newClientsContainer()
|
||||
clients := newClientsContainer(t)
|
||||
|
||||
// Add client with upstreams.
|
||||
ok, err := clients.Add(&Client{
|
||||
|
||||
@@ -9,6 +9,8 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||
)
|
||||
|
||||
// clientJSON is a common structure used by several handlers to deal with
|
||||
@@ -28,7 +30,8 @@ type clientJSON struct {
|
||||
// the allowlist.
|
||||
DisallowedRule *string `json:"disallowed_rule,omitempty"`
|
||||
|
||||
WHOISInfo *RuntimeClientWHOISInfo `json:"whois_info,omitempty"`
|
||||
// WHOIS is the filtered WHOIS data of a client.
|
||||
WHOIS *whois.Info `json:"whois_info,omitempty"`
|
||||
SafeSearchConf *filtering.SafeSearchConfig `json:"safe_search"`
|
||||
|
||||
Name string `json:"name"`
|
||||
@@ -51,7 +54,7 @@ type clientJSON struct {
|
||||
}
|
||||
|
||||
type runtimeClientJSON struct {
|
||||
WHOISInfo *RuntimeClientWHOISInfo `json:"whois_info"`
|
||||
WHOIS *whois.Info `json:"whois_info"`
|
||||
|
||||
IP netip.Addr `json:"ip"`
|
||||
Name string `json:"name"`
|
||||
@@ -78,7 +81,7 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
|
||||
|
||||
for ip, rc := range clients.ipToRC {
|
||||
cj := runtimeClientJSON{
|
||||
WHOISInfo: rc.WHOISInfo,
|
||||
WHOIS: rc.WHOIS,
|
||||
|
||||
Name: rc.Host,
|
||||
Source: rc.Source,
|
||||
@@ -116,15 +119,24 @@ func (clients *clientsContainer) jsonToClient(cj clientJSON, prev *Client) (c *C
|
||||
}
|
||||
}
|
||||
|
||||
weekly := schedule.EmptyWeekly()
|
||||
if prev != nil {
|
||||
weekly = prev.BlockedServices.Schedule.Clone()
|
||||
}
|
||||
|
||||
c = &Client{
|
||||
safeSearchConf: safeSearchConf,
|
||||
|
||||
Name: cj.Name,
|
||||
|
||||
IDs: cj.IDs,
|
||||
Tags: cj.Tags,
|
||||
BlockedServices: cj.BlockedServices,
|
||||
Upstreams: cj.Upstreams,
|
||||
BlockedServices: &filtering.BlockedServices{
|
||||
Schedule: weekly,
|
||||
IDs: cj.BlockedServices,
|
||||
},
|
||||
|
||||
IDs: cj.IDs,
|
||||
Tags: cj.Tags,
|
||||
Upstreams: cj.Upstreams,
|
||||
|
||||
UseOwnSettings: !cj.UseGlobalSettings,
|
||||
FilteringEnabled: cj.FilteringEnabled,
|
||||
@@ -178,7 +190,8 @@ func clientToJSON(c *Client) (cj *clientJSON) {
|
||||
SafeBrowsingEnabled: c.SafeBrowsingEnabled,
|
||||
|
||||
UseGlobalBlockedServices: !c.UseOwnBlockedServices,
|
||||
BlockedServices: c.BlockedServices,
|
||||
|
||||
BlockedServices: c.BlockedServices.IDs,
|
||||
|
||||
Upstreams: c.Upstreams,
|
||||
|
||||
@@ -344,16 +357,16 @@ func (clients *clientsContainer) findRuntime(ip netip.Addr, idStr string) (cj *c
|
||||
IDs: []string{idStr},
|
||||
Disallowed: &disallowed,
|
||||
DisallowedRule: &rule,
|
||||
WHOISInfo: &RuntimeClientWHOISInfo{},
|
||||
WHOIS: &whois.Info{},
|
||||
}
|
||||
|
||||
return cj
|
||||
}
|
||||
|
||||
cj = &clientJSON{
|
||||
Name: rc.Host,
|
||||
IDs: []string{idStr},
|
||||
WHOISInfo: rc.WHOISInfo,
|
||||
Name: rc.Host,
|
||||
IDs: []string{idStr},
|
||||
WHOIS: rc.WHOIS,
|
||||
}
|
||||
|
||||
disallowed, rule := clients.dnsServer.IsBlockedClient(ip, idStr)
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
||||
"github.com/AdguardTeam/dnsproxy/fastip"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
@@ -90,18 +91,17 @@ type clientSourcesConfig struct {
|
||||
HostsFile bool `yaml:"hosts"`
|
||||
}
|
||||
|
||||
// configuration is loaded from YAML
|
||||
// field ordering is important -- yaml fields will mirror ordering from here
|
||||
// configuration is loaded from YAML.
|
||||
//
|
||||
// Field ordering is important, YAML fields better not to be reordered, if it's
|
||||
// not absolutely necessary.
|
||||
type configuration struct {
|
||||
// Raw file data to avoid re-reading of configuration file
|
||||
// It's reset after config is parsed
|
||||
fileData []byte
|
||||
|
||||
// BindHost is the address for the web interface server to listen on.
|
||||
BindHost netip.Addr `yaml:"bind_host"`
|
||||
// BindPort is the port for the web interface server to listen on.
|
||||
BindPort int `yaml:"bind_port"`
|
||||
|
||||
// HTTPConfig is the block with http conf.
|
||||
HTTPConfig httpConfig `yaml:"http"`
|
||||
// Users are the clients capable for accessing the web interface.
|
||||
Users []webUser `yaml:"users"`
|
||||
// AuthAttempts is the maximum number of failed login attempts a user
|
||||
@@ -119,10 +119,6 @@ type configuration struct {
|
||||
// DebugPProf defines if the profiling HTTP handler will listen on :6060.
|
||||
DebugPProf bool `yaml:"debug_pprof"`
|
||||
|
||||
// TTL for a web session (in hours)
|
||||
// An active session is automatically refreshed once a day.
|
||||
WebSessionTTLHours uint32 `yaml:"web_session_ttl"`
|
||||
|
||||
DNS dnsConfig `yaml:"dns"`
|
||||
TLS tlsConfigSettings `yaml:"tls"`
|
||||
QueryLog queryLogConfig `yaml:"querylog"`
|
||||
@@ -155,7 +151,23 @@ type configuration struct {
|
||||
SchemaVersion int `yaml:"schema_version"` // keeping last so that users will be less tempted to change it -- used when upgrading between versions
|
||||
}
|
||||
|
||||
// field ordering is important -- yaml fields will mirror ordering from here
|
||||
// httpConfig is a block with HTTP configuration params.
|
||||
//
|
||||
// Field ordering is important, YAML fields better not to be reordered, if it's
|
||||
// not absolutely necessary.
|
||||
type httpConfig struct {
|
||||
// Address is the address to serve the web UI on.
|
||||
Address netip.AddrPort
|
||||
|
||||
// SessionTTL for a web session.
|
||||
// An active session is automatically refreshed once a day.
|
||||
SessionTTL timeutil.Duration `yaml:"session_ttl"`
|
||||
}
|
||||
|
||||
// dnsConfig is a block with DNS configuration params.
|
||||
//
|
||||
// Field ordering is important, YAML fields better not to be reordered, if it's
|
||||
// not absolutely necessary.
|
||||
type dnsConfig struct {
|
||||
BindHosts []netip.Addr `yaml:"bind_hosts"`
|
||||
Port int `yaml:"port"`
|
||||
@@ -260,11 +272,12 @@ type statsConfig struct {
|
||||
//
|
||||
// TODO(a.garipov, e.burkov): This global is awful and must be removed.
|
||||
var config = &configuration{
|
||||
BindPort: 3000,
|
||||
BindHost: netip.IPv4Unspecified(),
|
||||
AuthAttempts: 5,
|
||||
AuthBlockMin: 15,
|
||||
WebSessionTTLHours: 30 * 24,
|
||||
AuthAttempts: 5,
|
||||
AuthBlockMin: 15,
|
||||
HTTPConfig: httpConfig{
|
||||
Address: netip.AddrPortFrom(netip.IPv4Unspecified(), 3000),
|
||||
SessionTTL: timeutil.Duration{Duration: 30 * timeutil.Day},
|
||||
},
|
||||
DNS: dnsConfig{
|
||||
BindHosts: []netip.Addr{netip.IPv4Unspecified()},
|
||||
Port: defaultPortDNS,
|
||||
@@ -316,6 +329,11 @@ var config = &configuration{
|
||||
Yandex: true,
|
||||
YouTube: true,
|
||||
},
|
||||
|
||||
BlockedServices: &filtering.BlockedServices{
|
||||
Schedule: schedule.EmptyWeekly(),
|
||||
IDs: []string{},
|
||||
},
|
||||
},
|
||||
UpstreamTimeout: timeutil.Duration{Duration: dnsforward.DefaultTimeout},
|
||||
UsePrivateRDNS: true,
|
||||
@@ -421,8 +439,8 @@ func readLogSettings() (ls *logSettings) {
|
||||
// validateBindHosts returns error if any of binding hosts from configuration is
|
||||
// not a valid IP address.
|
||||
func validateBindHosts(conf *configuration) (err error) {
|
||||
if !conf.BindHost.IsValid() {
|
||||
return errors.Error("bind_host is not a valid ip address")
|
||||
if !conf.HTTPConfig.Address.IsValid() {
|
||||
return errors.Error("http.address is not a valid ip address")
|
||||
}
|
||||
|
||||
for i, addr := range conf.DNS.BindHosts {
|
||||
@@ -456,7 +474,7 @@ func parseConfig() (err error) {
|
||||
}
|
||||
|
||||
tcpPorts := aghalg.UniqChecker[tcpPort]{}
|
||||
addPorts(tcpPorts, tcpPort(config.BindPort))
|
||||
addPorts(tcpPorts, tcpPort(config.HTTPConfig.Address.Port()))
|
||||
|
||||
udpPorts := aghalg.UniqChecker[udpPort]{}
|
||||
addPorts(udpPorts, udpPort(config.DNS.Port))
|
||||
|
||||
@@ -103,7 +103,7 @@ type statusResponse struct {
|
||||
Language string `json:"language"`
|
||||
DNSAddrs []string `json:"dns_addresses"`
|
||||
DNSPort int `json:"dns_port"`
|
||||
HTTPPort int `json:"http_port"`
|
||||
HTTPPort uint16 `json:"http_port"`
|
||||
|
||||
// ProtectionDisabledDuration is the duration of the protection pause in
|
||||
// milliseconds.
|
||||
@@ -158,7 +158,7 @@ func handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||
Language: config.Language,
|
||||
DNSAddrs: dnsAddrs,
|
||||
DNSPort: config.DNS.Port,
|
||||
HTTPPort: config.BindPort,
|
||||
HTTPPort: config.HTTPConfig.Address.Port(),
|
||||
ProtectionDisabledDuration: protectionDisabledDuration,
|
||||
ProtectionEnabled: protectionEnabled,
|
||||
IsRunning: isRunning(),
|
||||
|
||||
@@ -96,8 +96,9 @@ type checkConfResp struct {
|
||||
func (req *checkConfReq) validateWeb(tcpPorts aghalg.UniqChecker[tcpPort]) (err error) {
|
||||
defer func() { err = errors.Annotate(err, "validating ports: %w") }()
|
||||
|
||||
portInt := req.Web.Port
|
||||
port := tcpPort(portInt)
|
||||
// TODO(a.garipov): Declare all port variables anywhere as uint16.
|
||||
reqPort := uint16(req.Web.Port)
|
||||
port := tcpPort(reqPort)
|
||||
addPorts(tcpPorts, port)
|
||||
if err = tcpPorts.Validate(); err != nil {
|
||||
// Reset the value for the port to 1 to make sure that validateDNS
|
||||
@@ -108,15 +109,15 @@ func (req *checkConfReq) validateWeb(tcpPorts aghalg.UniqChecker[tcpPort]) (err
|
||||
return err
|
||||
}
|
||||
|
||||
switch portInt {
|
||||
case 0, config.BindPort:
|
||||
switch reqPort {
|
||||
case 0, config.HTTPConfig.Address.Port():
|
||||
return nil
|
||||
default:
|
||||
// Go on and check the port binding only if it's not zero or won't be
|
||||
// unbound after install.
|
||||
}
|
||||
|
||||
return aghnet.CheckPort("tcp", netip.AddrPortFrom(req.Web.IP, uint16(portInt)))
|
||||
return aghnet.CheckPort("tcp", netip.AddrPortFrom(req.Web.IP, reqPort))
|
||||
}
|
||||
|
||||
// validateDNS returns error if the DNS part of the initial configuration can't
|
||||
@@ -127,11 +128,11 @@ func (req *checkConfReq) validateDNS(
|
||||
) (canAutofix bool, err error) {
|
||||
defer func() { err = errors.Annotate(err, "validating ports: %w") }()
|
||||
|
||||
port := req.DNS.Port
|
||||
port := uint16(req.DNS.Port)
|
||||
switch port {
|
||||
case 0:
|
||||
return false, nil
|
||||
case config.BindPort:
|
||||
case config.HTTPConfig.Address.Port():
|
||||
// Go on and only check the UDP port since the TCP one is already bound
|
||||
// by AdGuard Home for web interface.
|
||||
default:
|
||||
@@ -318,8 +319,7 @@ type applyConfigReq struct {
|
||||
// copyInstallSettings copies the installation parameters between two
|
||||
// configuration structures.
|
||||
func copyInstallSettings(dst, src *configuration) {
|
||||
dst.BindHost = src.BindHost
|
||||
dst.BindPort = src.BindPort
|
||||
dst.HTTPConfig = src.HTTPConfig
|
||||
dst.DNS.BindHosts = src.DNS.BindHosts
|
||||
dst.DNS.Port = src.DNS.Port
|
||||
}
|
||||
@@ -413,8 +413,7 @@ func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request
|
||||
copyInstallSettings(curConfig, config)
|
||||
|
||||
Context.firstRun = false
|
||||
config.BindHost = req.Web.IP
|
||||
config.BindPort = req.Web.Port
|
||||
config.HTTPConfig.Address = netip.AddrPortFrom(req.Web.IP, uint16(req.Web.Port))
|
||||
config.DNS.BindHosts = []netip.Addr{req.DNS.IP}
|
||||
config.DNS.Port = req.DNS.Port
|
||||
|
||||
@@ -487,7 +486,8 @@ func decodeApplyConfigReq(r io.Reader) (req *applyConfigReq, restartHTTP bool, e
|
||||
return nil, false, errors.Error("ports cannot be 0")
|
||||
}
|
||||
|
||||
restartHTTP = config.BindHost != req.Web.IP || config.BindPort != req.Web.Port
|
||||
addrPort := config.HTTPConfig.Address
|
||||
restartHTTP = addrPort.Addr() != req.Web.IP || int(addrPort.Port()) != req.Web.Port
|
||||
if restartHTTP {
|
||||
err = aghnet.CheckPort("tcp", netip.AddrPortFrom(req.Web.IP, uint16(req.Web.Port)))
|
||||
if err != nil {
|
||||
|
||||
@@ -157,7 +157,9 @@ func (vr *versionResponse) setAllowedToAutoUpdate() (err error) {
|
||||
Context.tls.WriteDiskConfig(tlsConf)
|
||||
|
||||
canUpdate := true
|
||||
if tlsConfUsesPrivilegedPorts(tlsConf) || config.BindPort < 1024 || config.DNS.Port < 1024 {
|
||||
if tlsConfUsesPrivilegedPorts(tlsConf) ||
|
||||
config.HTTPConfig.Address.Port() < 1024 ||
|
||||
config.DNS.Port < 1024 {
|
||||
canUpdate, err = aghnet.CanBindPrivilegedPorts()
|
||||
if err != nil {
|
||||
return fmt.Errorf("checking ability to bind privileged ports: %w", err)
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
@@ -17,6 +18,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
@@ -25,7 +27,7 @@ import (
|
||||
yaml "gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// Default ports.
|
||||
// Default listening ports.
|
||||
const (
|
||||
defaultPortDNS = 53
|
||||
defaultPortHTTP = 80
|
||||
@@ -169,13 +171,72 @@ func initDNSServer(
|
||||
Context.rdns = NewRDNS(Context.dnsServer, &Context.clients, config.DNS.UsePrivateRDNS)
|
||||
}
|
||||
|
||||
if config.Clients.Sources.WHOIS {
|
||||
Context.whois = initWHOIS(&Context.clients)
|
||||
}
|
||||
initWHOIS()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// initWHOIS initializes the WHOIS.
|
||||
//
|
||||
// TODO(s.chzhen): Consider making configurable.
|
||||
func initWHOIS() {
|
||||
const (
|
||||
// defaultQueueSize is the size of queue of IPs for WHOIS processing.
|
||||
defaultQueueSize = 255
|
||||
|
||||
// defaultTimeout is the timeout for WHOIS requests.
|
||||
defaultTimeout = 5 * time.Second
|
||||
|
||||
// defaultCacheSize is the maximum size of the cache. If it's zero,
|
||||
// cache size is unlimited.
|
||||
defaultCacheSize = 10_000
|
||||
|
||||
// defaultMaxConnReadSize is an upper limit in bytes for reading from
|
||||
// net.Conn.
|
||||
defaultMaxConnReadSize = 64 * 1024
|
||||
|
||||
// defaultMaxRedirects is the maximum redirects count.
|
||||
defaultMaxRedirects = 5
|
||||
|
||||
// defaultMaxInfoLen is the maximum length of whois.Info fields.
|
||||
defaultMaxInfoLen = 250
|
||||
|
||||
// defaultIPTTL is the Time to Live duration for cached IP addresses.
|
||||
defaultIPTTL = 1 * time.Hour
|
||||
)
|
||||
|
||||
Context.whoisCh = make(chan netip.Addr, defaultQueueSize)
|
||||
|
||||
var w whois.Interface
|
||||
|
||||
if config.Clients.Sources.WHOIS {
|
||||
w = whois.New(&whois.Config{
|
||||
DialContext: customDialContext,
|
||||
ServerAddr: whois.DefaultServer,
|
||||
Port: whois.DefaultPort,
|
||||
Timeout: defaultTimeout,
|
||||
CacheSize: defaultCacheSize,
|
||||
MaxConnReadSize: defaultMaxConnReadSize,
|
||||
MaxRedirects: defaultMaxRedirects,
|
||||
MaxInfoLen: defaultMaxInfoLen,
|
||||
CacheTTL: defaultIPTTL,
|
||||
})
|
||||
} else {
|
||||
w = whois.Empty{}
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer log.OnPanic("whois")
|
||||
|
||||
for ip := range Context.whoisCh {
|
||||
info, changed := w.Process(context.Background(), ip)
|
||||
if info != nil && changed {
|
||||
Context.clients.setWHOISInfo(ip, info)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// parseSubnetSet parses a slice of subnets. If the slice is empty, it returns
|
||||
// a subnet set that matches all locally served networks, see
|
||||
// [netutil.IsLocallyServed].
|
||||
@@ -218,9 +279,7 @@ func onDNSRequest(pctx *proxy.DNSContext) {
|
||||
Context.rdns.Begin(ip)
|
||||
}
|
||||
|
||||
if srcs.WHOIS && !netutil.IsSpecialPurposeAddr(ip) {
|
||||
Context.whois.Begin(ip)
|
||||
}
|
||||
Context.whoisCh <- ip
|
||||
}
|
||||
|
||||
func ipsToTCPAddrs(ips []netip.Addr, port int) (tcpAddrs []*net.TCPAddr) {
|
||||
@@ -390,7 +449,7 @@ func applyAdditionalFiltering(clientIP net.IP, clientID string, setts *filtering
|
||||
// pref is a prefix for logging messages around the scope.
|
||||
const pref = "applying filters"
|
||||
|
||||
Context.filters.ApplyBlockedServices(setts, nil)
|
||||
Context.filters.ApplyBlockedServices(setts)
|
||||
|
||||
log.Debug("%s: looking for client with ip %s and clientid %q", pref, clientIP, clientID)
|
||||
|
||||
@@ -414,12 +473,12 @@ func applyAdditionalFiltering(clientIP net.IP, clientID string, setts *filtering
|
||||
|
||||
if c.UseOwnBlockedServices {
|
||||
// TODO(e.burkov): Get rid of this crutch.
|
||||
svcs := c.BlockedServices
|
||||
if svcs == nil {
|
||||
svcs = []string{}
|
||||
setts.ServicesRules = nil
|
||||
svcs := c.BlockedServices.IDs
|
||||
if !c.BlockedServices.Schedule.Contains(time.Now()) {
|
||||
Context.filters.ApplyBlockedServicesList(setts, svcs)
|
||||
log.Debug("%s: services for client %q set: %s", pref, c.Name, svcs)
|
||||
}
|
||||
Context.filters.ApplyBlockedServices(setts, svcs)
|
||||
log.Debug("%s: services for client %q set: %s", pref, c.Name, svcs)
|
||||
}
|
||||
|
||||
setts.ClientName = c.Name
|
||||
@@ -463,9 +522,7 @@ func startDNSServer() error {
|
||||
Context.rdns.Begin(ip)
|
||||
}
|
||||
|
||||
if srcs.WHOIS && !netutil.IsSpecialPurposeAddr(ip) {
|
||||
Context.whois.Begin(ip)
|
||||
}
|
||||
Context.whoisCh <- ip
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
176
internal/home/dns_internal_test.go
Normal file
176
internal/home/dns_internal_test.go
Normal file
@@ -0,0 +1,176 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestApplyAdditionalFiltering(t *testing.T) {
|
||||
var err error
|
||||
|
||||
Context.filters, err = filtering.New(&filtering.Config{
|
||||
BlockedServices: &filtering.BlockedServices{
|
||||
Schedule: schedule.EmptyWeekly(),
|
||||
},
|
||||
}, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
Context.clients.idIndex = map[string]*Client{
|
||||
"default": {
|
||||
UseOwnSettings: false,
|
||||
safeSearchConf: filtering.SafeSearchConfig{Enabled: false},
|
||||
FilteringEnabled: false,
|
||||
SafeBrowsingEnabled: false,
|
||||
ParentalEnabled: false,
|
||||
},
|
||||
"custom_filtering": {
|
||||
UseOwnSettings: true,
|
||||
safeSearchConf: filtering.SafeSearchConfig{Enabled: true},
|
||||
FilteringEnabled: true,
|
||||
SafeBrowsingEnabled: true,
|
||||
ParentalEnabled: true,
|
||||
},
|
||||
"partial_custom_filtering": {
|
||||
UseOwnSettings: true,
|
||||
safeSearchConf: filtering.SafeSearchConfig{Enabled: true},
|
||||
FilteringEnabled: true,
|
||||
SafeBrowsingEnabled: false,
|
||||
ParentalEnabled: false,
|
||||
},
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
id string
|
||||
FilteringEnabled assert.BoolAssertionFunc
|
||||
SafeSearchEnabled assert.BoolAssertionFunc
|
||||
SafeBrowsingEnabled assert.BoolAssertionFunc
|
||||
ParentalEnabled assert.BoolAssertionFunc
|
||||
}{{
|
||||
name: "global_settings",
|
||||
id: "default",
|
||||
FilteringEnabled: assert.False,
|
||||
SafeSearchEnabled: assert.False,
|
||||
SafeBrowsingEnabled: assert.False,
|
||||
ParentalEnabled: assert.False,
|
||||
}, {
|
||||
name: "custom_settings",
|
||||
id: "custom_filtering",
|
||||
FilteringEnabled: assert.True,
|
||||
SafeSearchEnabled: assert.True,
|
||||
SafeBrowsingEnabled: assert.True,
|
||||
ParentalEnabled: assert.True,
|
||||
}, {
|
||||
name: "partial",
|
||||
id: "partial_custom_filtering",
|
||||
FilteringEnabled: assert.True,
|
||||
SafeSearchEnabled: assert.True,
|
||||
SafeBrowsingEnabled: assert.False,
|
||||
ParentalEnabled: assert.False,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
setts := &filtering.Settings{}
|
||||
|
||||
applyAdditionalFiltering(net.IP{1, 2, 3, 4}, tc.id, setts)
|
||||
tc.FilteringEnabled(t, setts.FilteringEnabled)
|
||||
tc.SafeSearchEnabled(t, setts.SafeSearchEnabled)
|
||||
tc.SafeBrowsingEnabled(t, setts.SafeBrowsingEnabled)
|
||||
tc.ParentalEnabled(t, setts.ParentalEnabled)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyAdditionalFiltering_blockedServices(t *testing.T) {
|
||||
filtering.InitModule()
|
||||
|
||||
var (
|
||||
globalBlockedServices = []string{"ok"}
|
||||
clientBlockedServices = []string{"ok", "mail_ru", "vk"}
|
||||
invalidBlockedServices = []string{"invalid"}
|
||||
|
||||
err error
|
||||
)
|
||||
|
||||
Context.filters, err = filtering.New(&filtering.Config{
|
||||
BlockedServices: &filtering.BlockedServices{
|
||||
Schedule: schedule.EmptyWeekly(),
|
||||
IDs: globalBlockedServices,
|
||||
},
|
||||
}, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
Context.clients.idIndex = map[string]*Client{
|
||||
"default": {
|
||||
UseOwnBlockedServices: false,
|
||||
},
|
||||
"no_services": {
|
||||
BlockedServices: &filtering.BlockedServices{
|
||||
Schedule: schedule.EmptyWeekly(),
|
||||
},
|
||||
UseOwnBlockedServices: true,
|
||||
},
|
||||
"services": {
|
||||
BlockedServices: &filtering.BlockedServices{
|
||||
Schedule: schedule.EmptyWeekly(),
|
||||
IDs: clientBlockedServices,
|
||||
},
|
||||
UseOwnBlockedServices: true,
|
||||
},
|
||||
"invalid_services": {
|
||||
BlockedServices: &filtering.BlockedServices{
|
||||
Schedule: schedule.EmptyWeekly(),
|
||||
IDs: invalidBlockedServices,
|
||||
},
|
||||
UseOwnBlockedServices: true,
|
||||
},
|
||||
"allow_all": {
|
||||
BlockedServices: &filtering.BlockedServices{
|
||||
Schedule: schedule.FullWeekly(),
|
||||
IDs: clientBlockedServices,
|
||||
},
|
||||
UseOwnBlockedServices: true,
|
||||
},
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
id string
|
||||
wantLen int
|
||||
}{{
|
||||
name: "global_settings",
|
||||
id: "default",
|
||||
wantLen: len(globalBlockedServices),
|
||||
}, {
|
||||
name: "custom_settings",
|
||||
id: "no_services",
|
||||
wantLen: 0,
|
||||
}, {
|
||||
name: "custom_settings_block",
|
||||
id: "services",
|
||||
wantLen: len(clientBlockedServices),
|
||||
}, {
|
||||
name: "custom_settings_invalid",
|
||||
id: "invalid_services",
|
||||
wantLen: 0,
|
||||
}, {
|
||||
name: "custom_settings_inactive_schedule",
|
||||
id: "allow_all",
|
||||
wantLen: 0,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
setts := &filtering.Settings{}
|
||||
|
||||
applyAdditionalFiltering(net.IP{1, 2, 3, 4}, tc.id, setts)
|
||||
require.Len(t, setts.ServicesRules, tc.wantLen)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -57,7 +57,6 @@ type homeContext struct {
|
||||
queryLog querylog.QueryLog // query log module
|
||||
dnsServer *dnsforward.Server // DNS module
|
||||
rdns *RDNS // rDNS module
|
||||
whois *WHOIS // WHOIS module
|
||||
dhcpServer dhcpd.Interface // DHCP module
|
||||
auth *Auth // HTTP authentication module
|
||||
filters *filtering.DNSFilter // DNS filtering module
|
||||
@@ -84,6 +83,9 @@ type homeContext struct {
|
||||
client *http.Client
|
||||
appSignalChannel chan os.Signal // Channel for receiving OS signals by the console app
|
||||
|
||||
// whoisCh is the channel for receiving IPs for WHOIS processing.
|
||||
whoisCh chan netip.Addr
|
||||
|
||||
// tlsCipherIDs are the ID of the cipher suites that AdGuard Home must use.
|
||||
tlsCipherIDs []uint16
|
||||
|
||||
@@ -353,21 +355,43 @@ func initContextClients() (err error) {
|
||||
arpdb = aghnet.NewARPDB()
|
||||
}
|
||||
|
||||
Context.clients.Init(
|
||||
err = Context.clients.Init(
|
||||
config.Clients.Persistent,
|
||||
Context.dhcpServer,
|
||||
Context.etcHosts,
|
||||
arpdb,
|
||||
config.DNS.DnsfilterConf,
|
||||
)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupBindOpts overrides bind host/port from the opts.
|
||||
func setupBindOpts(opts options) (err error) {
|
||||
bindAddr := opts.bindAddr
|
||||
if bindAddr != (netip.AddrPort{}) {
|
||||
config.HTTPConfig.Address = bindAddr
|
||||
|
||||
if config.HTTPConfig.Address.Port() != 0 {
|
||||
err = checkPorts()
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if opts.bindPort != 0 {
|
||||
config.BindPort = opts.bindPort
|
||||
config.HTTPConfig.Address = netip.AddrPortFrom(
|
||||
config.HTTPConfig.Address.Addr(),
|
||||
uint16(opts.bindPort),
|
||||
)
|
||||
|
||||
err = checkPorts()
|
||||
if err != nil {
|
||||
@@ -377,7 +401,10 @@ func setupBindOpts(opts options) (err error) {
|
||||
}
|
||||
|
||||
if opts.bindHost.IsValid() {
|
||||
config.BindHost = opts.bindHost
|
||||
config.HTTPConfig.Address = netip.AddrPortFrom(
|
||||
opts.bindHost,
|
||||
config.HTTPConfig.Address.Port(),
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -461,7 +488,7 @@ func setupDNSFilteringConf(conf *filtering.Config) (err error) {
|
||||
// checkPorts is a helper for ports validation in config.
|
||||
func checkPorts() (err error) {
|
||||
tcpPorts := aghalg.UniqChecker[tcpPort]{}
|
||||
addPorts(tcpPorts, tcpPort(config.BindPort))
|
||||
addPorts(tcpPorts, tcpPort(config.HTTPConfig.Address.Port()))
|
||||
|
||||
udpPorts := aghalg.UniqChecker[udpPort]{}
|
||||
addPorts(udpPorts, udpPort(config.DNS.Port))
|
||||
@@ -501,8 +528,8 @@ func initWeb(opts options, clientBuildFS fs.FS) (web *webAPI, err error) {
|
||||
|
||||
webConf := webConfig{
|
||||
firstRun: Context.firstRun,
|
||||
BindHost: config.BindHost,
|
||||
BindPort: config.BindPort,
|
||||
BindHost: config.HTTPConfig.Address.Addr(),
|
||||
BindPort: int(config.HTTPConfig.Address.Port()),
|
||||
|
||||
ReadTimeout: readTimeout,
|
||||
ReadHeaderTimeout: readHdrTimeout,
|
||||
@@ -638,8 +665,8 @@ func initUsers() (auth *Auth, err error) {
|
||||
log.Info("authratelimiter is disabled")
|
||||
}
|
||||
|
||||
sessionTTL := config.WebSessionTTLHours * 60 * 60
|
||||
auth = InitAuth(sessFilename, config.Users, sessionTTL, rateLimiter)
|
||||
sessionTTL := config.HTTPConfig.SessionTTL.Seconds()
|
||||
auth = InitAuth(sessFilename, config.Users, uint32(sessionTTL), rateLimiter)
|
||||
if auth == nil {
|
||||
return nil, errors.Error("initializing auth module failed")
|
||||
}
|
||||
@@ -917,7 +944,7 @@ func printHTTPAddresses(proto string) {
|
||||
Context.tls.WriteDiskConfig(&tlsConf)
|
||||
}
|
||||
|
||||
port := config.BindPort
|
||||
port := int(config.HTTPConfig.Address.Port())
|
||||
if proto == aghhttp.SchemeHTTPS {
|
||||
port = tlsConf.PortHTTPS
|
||||
}
|
||||
@@ -929,9 +956,9 @@ func printHTTPAddresses(proto string) {
|
||||
return
|
||||
}
|
||||
|
||||
bindhost := config.BindHost
|
||||
if !bindhost.IsUnspecified() {
|
||||
printWebAddrs(proto, bindhost.String(), port)
|
||||
bindHost := config.HTTPConfig.Address.Addr()
|
||||
if !bindHost.IsUnspecified() {
|
||||
printWebAddrs(proto, bindHost.String(), port)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -942,14 +969,14 @@ func printHTTPAddresses(proto string) {
|
||||
// That's weird, but we'll ignore it.
|
||||
//
|
||||
// TODO(e.burkov): Find out when it happens.
|
||||
printWebAddrs(proto, bindhost.String(), port)
|
||||
printWebAddrs(proto, bindHost.String(), port)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
for _, iface := range ifaces {
|
||||
for _, addr := range iface.Addresses {
|
||||
printWebAddrs(proto, addr.String(), config.BindPort)
|
||||
printWebAddrs(proto, addr.String(), port)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -35,11 +35,18 @@ type options struct {
|
||||
serviceControlAction string
|
||||
|
||||
// bindHost is the address on which to serve the HTTP UI.
|
||||
//
|
||||
// Deprecated: Use bindAddr.
|
||||
bindHost netip.Addr
|
||||
|
||||
// bindPort is the port on which to serve the HTTP UI.
|
||||
//
|
||||
// Deprecated: Use bindAddr.
|
||||
bindPort int
|
||||
|
||||
// bindAddr is the address to serve the web UI on.
|
||||
bindAddr netip.AddrPort
|
||||
|
||||
// checkConfig is true if the current invocation is only required to check
|
||||
// the configuration file and exit.
|
||||
checkConfig bool
|
||||
@@ -147,9 +154,10 @@ var cmdLineOpts = []cmdLineOpt{{
|
||||
|
||||
return o.bindHost.String(), true
|
||||
},
|
||||
description: "Host address to bind HTTP server on.",
|
||||
longName: "host",
|
||||
shortName: "h",
|
||||
description: "Deprecated. Host address to bind HTTP server on. Use --web-addr. " +
|
||||
"The short -h will work as --help in the future.",
|
||||
longName: "host",
|
||||
shortName: "h",
|
||||
}, {
|
||||
updateWithValue: func(o options, v string) (options, error) {
|
||||
var err error
|
||||
@@ -174,9 +182,23 @@ var cmdLineOpts = []cmdLineOpt{{
|
||||
|
||||
return strconv.Itoa(o.bindPort), true
|
||||
},
|
||||
description: "Port to serve HTTP pages on.",
|
||||
description: "Deprecated. Port to serve HTTP pages on. Use --web-addr.",
|
||||
longName: "port",
|
||||
shortName: "p",
|
||||
}, {
|
||||
updateWithValue: func(o options, v string) (oo options, err error) {
|
||||
o.bindAddr, err = netip.ParseAddrPort(v)
|
||||
|
||||
return o, err
|
||||
},
|
||||
updateNoValue: nil,
|
||||
effect: nil,
|
||||
serialize: func(o options) (val string, ok bool) {
|
||||
return o.bindAddr.String(), o.bindAddr.IsValid()
|
||||
},
|
||||
description: "Address to serve the web UI on, in the host:port format.",
|
||||
longName: "web-addr",
|
||||
shortName: "",
|
||||
}, {
|
||||
updateWithValue: func(o options, v string) (options, error) {
|
||||
o.serviceControlAction = v
|
||||
|
||||
@@ -82,6 +82,23 @@ func TestParseBindPort(t *testing.T) {
|
||||
testParseErr(t, "port too high", "-p", "18446744073709551617") // 2^64 + 1
|
||||
}
|
||||
|
||||
func TestParseBindAddr(t *testing.T) {
|
||||
wantAddrPort := netip.MustParseAddrPort("1.2.3.4:8089")
|
||||
|
||||
assert.Zero(t, testParseOK(t).bindAddr, "empty is not web-addr")
|
||||
|
||||
assert.Equal(t, wantAddrPort, testParseOK(t, "--web-addr", "1.2.3.4:8089").bindAddr)
|
||||
assert.Equal(t, netip.MustParseAddrPort("1.2.3.4:0"), testParseOK(t, "--web-addr", "1.2.3.4:0").bindAddr)
|
||||
testParseParamMissing(t, "-web-addr")
|
||||
|
||||
testParseErr(t, "not an int", "--web-addr", "1.2.3.4:x")
|
||||
testParseErr(t, "hex not supported", "--web-addr", "1.2.3.4:0x100")
|
||||
testParseErr(t, "port negative", "--web-addr", "1.2.3.4:-1")
|
||||
testParseErr(t, "port too high", "--web-addr", "1.2.3.4:65536")
|
||||
testParseErr(t, "port too high", "--web-addr", "1.2.3.4:4294967297") // 2^32 + 1
|
||||
testParseErr(t, "port too high", "--web-addr", "1.2.3.4:18446744073709551617") // 2^64 + 1
|
||||
}
|
||||
|
||||
func TestParseLogfile(t *testing.T) {
|
||||
assert.Equal(t, "", testParseOK(t).logFile, "empty is no log file")
|
||||
assert.Equal(t, "path", testParseOK(t, "-l", "path").logFile, "-l is log file")
|
||||
@@ -162,6 +179,10 @@ func TestOptsToArgs(t *testing.T) {
|
||||
name: "bind_port",
|
||||
args: []string{"-p", "666"},
|
||||
opts: options{bindPort: 666},
|
||||
}, {
|
||||
name: "web-addr",
|
||||
args: []string{"--web-addr", "1.2.3.4:8080"},
|
||||
opts: options{bindAddr: netip.MustParseAddrPort("1.2.3.4:8080")},
|
||||
}, {
|
||||
name: "log_file",
|
||||
args: []string{"-l", "path"},
|
||||
|
||||
@@ -228,12 +228,7 @@ func TestRDNS_WorkerLoop(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
w.Reset()
|
||||
|
||||
cc := &clientsContainer{
|
||||
list: map[string]*Client{},
|
||||
idIndex: map[string]*Client{},
|
||||
ipToRC: map[netip.Addr]*RuntimeClient{},
|
||||
allTags: stringutil.NewSet(),
|
||||
}
|
||||
cc := newClientsContainer(t)
|
||||
ch := make(chan netip.Addr)
|
||||
rdns := &RDNS{
|
||||
exchanger: &rDNSExchanger{
|
||||
|
||||
@@ -320,7 +320,7 @@ func (m *tlsManager) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
if setts.Enabled {
|
||||
err = validatePorts(
|
||||
tcpPort(config.BindPort),
|
||||
tcpPort(config.HTTPConfig.Address.Port()),
|
||||
tcpPort(setts.PortHTTPS),
|
||||
tcpPort(setts.PortDNSOverTLS),
|
||||
tcpPort(setts.PortDNSCrypt),
|
||||
@@ -407,7 +407,7 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request)
|
||||
|
||||
if req.Enabled {
|
||||
err = validatePorts(
|
||||
tcpPort(config.BindPort),
|
||||
tcpPort(config.HTTPConfig.Address.Port()),
|
||||
tcpPort(req.PortHTTPS),
|
||||
tcpPort(req.PortDNSOverTLS),
|
||||
tcpPort(req.PortDNSCrypt),
|
||||
|
||||
@@ -3,6 +3,7 @@ package home
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
@@ -22,7 +23,7 @@ import (
|
||||
)
|
||||
|
||||
// currentSchemaVersion is the current schema version.
|
||||
const currentSchemaVersion = 20
|
||||
const currentSchemaVersion = 23
|
||||
|
||||
// These aliases are provided for convenience.
|
||||
type (
|
||||
@@ -94,6 +95,9 @@ func upgradeConfigSchema(oldVersion int, diskConf yobj) (err error) {
|
||||
upgradeSchema17to18,
|
||||
upgradeSchema18to19,
|
||||
upgradeSchema19to20,
|
||||
upgradeSchema20to21,
|
||||
upgradeSchema21to22,
|
||||
upgradeSchema22to23,
|
||||
}
|
||||
|
||||
n := 0
|
||||
@@ -1128,6 +1132,199 @@ func upgradeSchema19to20(diskConf yobj) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// upgradeSchema20to21 performs the following changes:
|
||||
//
|
||||
// # BEFORE:
|
||||
// 'dns':
|
||||
// 'blocked_services':
|
||||
// - 'svc_name'
|
||||
//
|
||||
// # AFTER:
|
||||
// 'dns':
|
||||
// 'blocked_services':
|
||||
// 'ids':
|
||||
// - 'svc_name'
|
||||
// 'schedule':
|
||||
// 'time_zone': 'Local'
|
||||
func upgradeSchema20to21(diskConf yobj) (err error) {
|
||||
log.Printf("Upgrade yaml: 20 to 21")
|
||||
diskConf["schema_version"] = 21
|
||||
|
||||
const field = "blocked_services"
|
||||
|
||||
dnsVal, ok := diskConf["dns"]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
dns, ok := dnsVal.(yobj)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type of dns: %T", dnsVal)
|
||||
}
|
||||
|
||||
blockedVal, ok := dns[field]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
services, ok := blockedVal.(yarr)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type of blocked: %T", blockedVal)
|
||||
}
|
||||
|
||||
dns[field] = yobj{
|
||||
"ids": services,
|
||||
"schedule": yobj{
|
||||
"time_zone": "Local",
|
||||
},
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// upgradeSchema21to22 performs the following changes:
|
||||
//
|
||||
// # BEFORE:
|
||||
// 'persistent':
|
||||
// - 'name': 'client_name'
|
||||
// 'blocked_services':
|
||||
// - 'svc_name'
|
||||
//
|
||||
// # AFTER:
|
||||
// 'persistent':
|
||||
// - 'name': 'client_name'
|
||||
// 'blocked_services':
|
||||
// 'ids':
|
||||
// - 'svc_name'
|
||||
// 'schedule':
|
||||
// 'time_zone': 'Local'
|
||||
func upgradeSchema21to22(diskConf yobj) (err error) {
|
||||
log.Println("Upgrade yaml: 21 to 22")
|
||||
diskConf["schema_version"] = 22
|
||||
|
||||
const field = "blocked_services"
|
||||
|
||||
clientsVal, ok := diskConf["clients"]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
clients, ok := clientsVal.(yobj)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type of clients: %T", clientsVal)
|
||||
}
|
||||
|
||||
persistentVal, ok := clients["persistent"]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
persistent, ok := persistentVal.([]any)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type of persistent clients: %T", persistentVal)
|
||||
}
|
||||
|
||||
for i, val := range persistent {
|
||||
var c yobj
|
||||
c, ok = val.(yobj)
|
||||
if !ok {
|
||||
return fmt.Errorf("persistent client at index %d: unexpected type %T", i, val)
|
||||
}
|
||||
|
||||
var blockedVal any
|
||||
blockedVal, ok = c[field]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
var services yarr
|
||||
services, ok = blockedVal.(yarr)
|
||||
if !ok {
|
||||
return fmt.Errorf(
|
||||
"persistent client at index %d: unexpected type of blocked services: %T",
|
||||
i,
|
||||
blockedVal,
|
||||
)
|
||||
}
|
||||
|
||||
c[field] = yobj{
|
||||
"ids": services,
|
||||
"schedule": yobj{
|
||||
"time_zone": "Local",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// upgradeSchema22to23 performs the following changes:
|
||||
//
|
||||
// # BEFORE:
|
||||
// 'bind_host': '1.2.3.4'
|
||||
// 'bind_port': 8080
|
||||
// 'web_session_ttl': 720
|
||||
//
|
||||
// # AFTER:
|
||||
// 'http':
|
||||
// 'address': '1.2.3.4:8080'
|
||||
// 'session_ttl': '720h'
|
||||
func upgradeSchema22to23(diskConf yobj) (err error) {
|
||||
log.Printf("Upgrade yaml: 22 to 23")
|
||||
diskConf["schema_version"] = 23
|
||||
|
||||
bindHostVal, ok := diskConf["bind_host"]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
bindHost, ok := bindHostVal.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type of bind_host: %T", bindHostVal)
|
||||
}
|
||||
|
||||
bindHostAddr, err := netip.ParseAddr(bindHost)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid bind_host value: %s", bindHost)
|
||||
}
|
||||
|
||||
bindPortVal, ok := diskConf["bind_port"]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
bindPort, ok := bindPortVal.(int)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type of bind_port: %T", bindPortVal)
|
||||
}
|
||||
|
||||
sessionTTLVal, ok := diskConf["web_session_ttl"]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
sessionTTL, ok := sessionTTLVal.(int)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type of web_session_ttl: %T", sessionTTLVal)
|
||||
}
|
||||
|
||||
addr := netip.AddrPortFrom(bindHostAddr, uint16(bindPort))
|
||||
if !addr.IsValid() {
|
||||
return fmt.Errorf("invalid address: %s", addr)
|
||||
}
|
||||
|
||||
diskConf["http"] = yobj{
|
||||
"address": addr.String(),
|
||||
"session_ttl": timeutil.Duration{Duration: time.Duration(sessionTTL) * time.Hour}.String(),
|
||||
}
|
||||
|
||||
delete(diskConf, "bind_host")
|
||||
delete(diskConf, "bind_port")
|
||||
delete(diskConf, "web_session_ttl")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO(a.garipov): Replace with log.Output when we port it to our logging
|
||||
// package.
|
||||
func funcName() string {
|
||||
|
||||
@@ -1140,3 +1140,169 @@ func TestUpgradeSchema19to20(t *testing.T) {
|
||||
assert.Equal(t, 24*time.Hour, ivlVal.Duration)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUpgradeSchema20to21(t *testing.T) {
|
||||
const newSchemaVer = 21
|
||||
|
||||
testCases := []struct {
|
||||
in yobj
|
||||
want yobj
|
||||
name string
|
||||
}{{
|
||||
name: "nothing",
|
||||
in: yobj{},
|
||||
want: yobj{
|
||||
"schema_version": newSchemaVer,
|
||||
},
|
||||
}, {
|
||||
name: "no_clients",
|
||||
in: yobj{
|
||||
"dns": yobj{
|
||||
"blocked_services": yarr{"ok"},
|
||||
},
|
||||
},
|
||||
want: yobj{
|
||||
"dns": yobj{
|
||||
"blocked_services": yobj{
|
||||
"ids": yarr{"ok"},
|
||||
"schedule": yobj{
|
||||
"time_zone": "Local",
|
||||
},
|
||||
},
|
||||
},
|
||||
"schema_version": newSchemaVer,
|
||||
},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := upgradeSchema20to21(tc.in)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.want, tc.in)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpgradeSchema21to22(t *testing.T) {
|
||||
const newSchemaVer = 22
|
||||
|
||||
testCases := []struct {
|
||||
in yobj
|
||||
want yobj
|
||||
name string
|
||||
}{{
|
||||
in: yobj{
|
||||
"clients": yobj{},
|
||||
},
|
||||
want: yobj{
|
||||
"clients": yobj{},
|
||||
"schema_version": newSchemaVer,
|
||||
},
|
||||
name: "nothing",
|
||||
}, {
|
||||
in: yobj{
|
||||
"clients": yobj{
|
||||
"persistent": []any{yobj{"name": "localhost", "blocked_services": yarr{}}},
|
||||
},
|
||||
},
|
||||
want: yobj{
|
||||
"clients": yobj{
|
||||
"persistent": []any{yobj{
|
||||
"name": "localhost",
|
||||
"blocked_services": yobj{
|
||||
"ids": yarr{},
|
||||
"schedule": yobj{
|
||||
"time_zone": "Local",
|
||||
},
|
||||
},
|
||||
}},
|
||||
},
|
||||
"schema_version": newSchemaVer,
|
||||
},
|
||||
name: "no_services",
|
||||
}, {
|
||||
in: yobj{
|
||||
"clients": yobj{
|
||||
"persistent": []any{yobj{"name": "localhost", "blocked_services": yarr{"ok"}}},
|
||||
},
|
||||
},
|
||||
want: yobj{
|
||||
"clients": yobj{
|
||||
"persistent": []any{yobj{
|
||||
"name": "localhost",
|
||||
"blocked_services": yobj{
|
||||
"ids": yarr{"ok"},
|
||||
"schedule": yobj{
|
||||
"time_zone": "Local",
|
||||
},
|
||||
},
|
||||
}},
|
||||
},
|
||||
"schema_version": newSchemaVer,
|
||||
},
|
||||
name: "services",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := upgradeSchema21to22(tc.in)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.want, tc.in)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpgradeSchema22to23(t *testing.T) {
|
||||
const newSchemaVer = 23
|
||||
|
||||
testCases := []struct {
|
||||
in yobj
|
||||
want yobj
|
||||
name string
|
||||
}{{
|
||||
name: "empty",
|
||||
in: yobj{},
|
||||
want: yobj{
|
||||
"schema_version": newSchemaVer,
|
||||
},
|
||||
}, {
|
||||
name: "ok",
|
||||
in: yobj{
|
||||
"bind_host": "1.2.3.4",
|
||||
"bind_port": 8081,
|
||||
"web_session_ttl": 720,
|
||||
},
|
||||
want: yobj{
|
||||
"http": yobj{
|
||||
"address": "1.2.3.4:8081",
|
||||
"session_ttl": "720h",
|
||||
},
|
||||
"schema_version": newSchemaVer,
|
||||
},
|
||||
}, {
|
||||
name: "v6_address",
|
||||
in: yobj{
|
||||
"bind_host": "2001:db8::1",
|
||||
"bind_port": 8081,
|
||||
"web_session_ttl": 720,
|
||||
},
|
||||
want: yobj{
|
||||
"http": yobj{
|
||||
"address": "[2001:db8::1]:8081",
|
||||
"session_ttl": "720h",
|
||||
},
|
||||
"schema_version": newSchemaVer,
|
||||
},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := upgradeSchema22to23(tc.in)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.want, tc.in)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -119,7 +119,9 @@ func webCheckPortAvailable(port int) (ok bool) {
|
||||
return true
|
||||
}
|
||||
|
||||
return aghnet.CheckPort("tcp", netip.AddrPortFrom(config.BindHost, uint16(port))) == nil
|
||||
addrPort := netip.AddrPortFrom(config.HTTPConfig.Address.Addr(), uint16(port))
|
||||
|
||||
return aghnet.CheckPort("tcp", addrPort) == nil
|
||||
}
|
||||
|
||||
// tlsConfigChanged updates the TLS configuration and restarts the HTTPS server
|
||||
|
||||
@@ -1,259 +0,0 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
|
||||
"github.com/AdguardTeam/golibs/cache"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultServer = "whois.arin.net"
|
||||
defaultPort = "43"
|
||||
maxValueLength = 250
|
||||
whoisTTL = 1 * 60 * 60 // 1 hour
|
||||
)
|
||||
|
||||
// WHOIS - module context
|
||||
type WHOIS struct {
|
||||
clients *clientsContainer
|
||||
ipChan chan netip.Addr
|
||||
|
||||
// dialContext specifies the dial function for creating unencrypted TCP
|
||||
// connections.
|
||||
dialContext func(ctx context.Context, network, addr string) (conn net.Conn, err error)
|
||||
|
||||
// Contains IP addresses of clients
|
||||
// An active IP address is resolved once again after it expires.
|
||||
// If IP address couldn't be resolved, it stays here for some time to prevent further attempts to resolve the same IP.
|
||||
ipAddrs cache.Cache
|
||||
|
||||
// TODO(a.garipov): Rewrite to use time.Duration. Like, seriously, why?
|
||||
timeoutMsec uint
|
||||
}
|
||||
|
||||
// initWHOIS creates the WHOIS module context.
|
||||
func initWHOIS(clients *clientsContainer) *WHOIS {
|
||||
w := WHOIS{
|
||||
timeoutMsec: 5000,
|
||||
clients: clients,
|
||||
ipAddrs: cache.New(cache.Config{
|
||||
EnableLRU: true,
|
||||
MaxCount: 10000,
|
||||
}),
|
||||
dialContext: customDialContext,
|
||||
ipChan: make(chan netip.Addr, 255),
|
||||
}
|
||||
|
||||
go w.workerLoop()
|
||||
|
||||
return &w
|
||||
}
|
||||
|
||||
// If the value is too large - cut it and append "..."
|
||||
func trimValue(s string) string {
|
||||
if len(s) <= maxValueLength {
|
||||
return s
|
||||
}
|
||||
return s[:maxValueLength-3] + "..."
|
||||
}
|
||||
|
||||
// isWHOISComment returns true if the string is empty or is a WHOIS comment.
|
||||
func isWHOISComment(s string) (ok bool) {
|
||||
return len(s) == 0 || s[0] == '#' || s[0] == '%'
|
||||
}
|
||||
|
||||
// strmap is an alias for convenience.
|
||||
type strmap = map[string]string
|
||||
|
||||
// whoisParse parses a subset of plain-text data from the WHOIS response into
|
||||
// a string map.
|
||||
func whoisParse(data string) (m strmap) {
|
||||
m = strmap{}
|
||||
|
||||
var orgname string
|
||||
lines := strings.Split(data, "\n")
|
||||
for _, l := range lines {
|
||||
if isWHOISComment(l) {
|
||||
continue
|
||||
}
|
||||
|
||||
kv := strings.SplitN(l, ":", 2)
|
||||
if len(kv) != 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
k := strings.ToLower(strings.TrimSpace(kv[0]))
|
||||
v := strings.TrimSpace(kv[1])
|
||||
if v == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
switch k {
|
||||
case "orgname", "org-name":
|
||||
k = "orgname"
|
||||
v = trimValue(v)
|
||||
orgname = v
|
||||
case "city", "country":
|
||||
v = trimValue(v)
|
||||
case "descr", "netname":
|
||||
k = "orgname"
|
||||
v = stringutil.Coalesce(orgname, v)
|
||||
orgname = v
|
||||
case "whois":
|
||||
k = "whois"
|
||||
case "referralserver":
|
||||
k = "whois"
|
||||
v = strings.TrimPrefix(v, "whois://")
|
||||
default:
|
||||
continue
|
||||
}
|
||||
|
||||
m[k] = v
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// MaxConnReadSize is an upper limit in bytes for reading from net.Conn.
|
||||
const MaxConnReadSize = 64 * 1024
|
||||
|
||||
// Send request to a server and receive the response
|
||||
func (w *WHOIS) query(ctx context.Context, target, serverAddr string) (data string, err error) {
|
||||
addr, _, _ := net.SplitHostPort(serverAddr)
|
||||
if addr == "whois.arin.net" {
|
||||
target = "n + " + target
|
||||
}
|
||||
|
||||
conn, err := w.dialContext(ctx, "tcp", serverAddr)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer func() { err = errors.WithDeferred(err, conn.Close()) }()
|
||||
|
||||
r, err := aghio.LimitReader(conn, MaxConnReadSize)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
_ = conn.SetReadDeadline(time.Now().Add(time.Duration(w.timeoutMsec) * time.Millisecond))
|
||||
_, err = conn.Write([]byte(target + "\r\n"))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// This use of ReadAll is now safe, because we limited the conn Reader.
|
||||
var whoisData []byte
|
||||
whoisData, err = io.ReadAll(r)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(whoisData), nil
|
||||
}
|
||||
|
||||
// Query WHOIS servers (handle redirects)
|
||||
func (w *WHOIS) queryAll(ctx context.Context, target string) (string, error) {
|
||||
server := net.JoinHostPort(defaultServer, defaultPort)
|
||||
const maxRedirects = 5
|
||||
for i := 0; i != maxRedirects; i++ {
|
||||
resp, err := w.query(ctx, target, server)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
log.Debug("whois: received response (%d bytes) from %s IP:%s", len(resp), server, target)
|
||||
|
||||
m := whoisParse(resp)
|
||||
redir, ok := m["whois"]
|
||||
if !ok {
|
||||
return resp, nil
|
||||
}
|
||||
redir = strings.ToLower(redir)
|
||||
|
||||
_, _, err = net.SplitHostPort(redir)
|
||||
if err != nil {
|
||||
server = net.JoinHostPort(redir, defaultPort)
|
||||
} else {
|
||||
server = redir
|
||||
}
|
||||
|
||||
log.Debug("whois: redirected to %s IP:%s", redir, target)
|
||||
}
|
||||
return "", fmt.Errorf("whois: redirect loop")
|
||||
}
|
||||
|
||||
// Request WHOIS information
|
||||
func (w *WHOIS) process(ctx context.Context, ip netip.Addr) (wi *RuntimeClientWHOISInfo) {
|
||||
resp, err := w.queryAll(ctx, ip.String())
|
||||
if err != nil {
|
||||
log.Debug("whois: error: %s IP:%s", err, ip)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Debug("whois: IP:%s response: %d bytes", ip, len(resp))
|
||||
|
||||
m := whoisParse(resp)
|
||||
|
||||
wi = &RuntimeClientWHOISInfo{
|
||||
City: m["city"],
|
||||
Country: m["country"],
|
||||
Orgname: m["orgname"],
|
||||
}
|
||||
|
||||
// Don't return an empty struct so that the frontend doesn't get
|
||||
// confused.
|
||||
if *wi == (RuntimeClientWHOISInfo{}) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return wi
|
||||
}
|
||||
|
||||
// Begin - begin requesting WHOIS info
|
||||
func (w *WHOIS) Begin(ip netip.Addr) {
|
||||
ipBytes := ip.AsSlice()
|
||||
now := uint64(time.Now().Unix())
|
||||
expire := w.ipAddrs.Get(ipBytes)
|
||||
if len(expire) != 0 {
|
||||
exp := binary.BigEndian.Uint64(expire)
|
||||
if exp > now {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
expire = make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(expire, now+whoisTTL)
|
||||
_ = w.ipAddrs.Set(ipBytes, expire)
|
||||
|
||||
log.Debug("whois: adding %s", ip)
|
||||
|
||||
select {
|
||||
case w.ipChan <- ip:
|
||||
default:
|
||||
log.Debug("whois: queue is full")
|
||||
}
|
||||
}
|
||||
|
||||
// workerLoop processes the IP addresses it got from the channel and associates
|
||||
// the retrieving WHOIS info with a client.
|
||||
func (w *WHOIS) workerLoop() {
|
||||
for ip := range w.ipChan {
|
||||
info := w.process(context.Background(), ip)
|
||||
if info == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
w.clients.setWHOISInfo(ip, info)
|
||||
}
|
||||
}
|
||||
@@ -1,152 +0,0 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// fakeConn is a mock implementation of net.Conn to simplify testing.
|
||||
//
|
||||
// TODO(e.burkov): Search for other places in code where it may be used. Move
|
||||
// into aghtest then.
|
||||
type fakeConn struct {
|
||||
// Conn is embedded here simply to make *fakeConn a net.Conn without
|
||||
// actually implementing all methods.
|
||||
net.Conn
|
||||
data []byte
|
||||
}
|
||||
|
||||
// Write implements net.Conn interface for *fakeConn. It always returns 0 and a
|
||||
// nil error without mutating the slice.
|
||||
func (c *fakeConn) Write(_ []byte) (n int, err error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Read implements net.Conn interface for *fakeConn. It puts the content of
|
||||
// c.data field into b up to the b's capacity.
|
||||
func (c *fakeConn) Read(b []byte) (n int, err error) {
|
||||
return copy(b, c.data), io.EOF
|
||||
}
|
||||
|
||||
// Close implements net.Conn interface for *fakeConn. It always returns nil.
|
||||
func (c *fakeConn) Close() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetReadDeadline implements net.Conn interface for *fakeConn. It always
|
||||
// returns nil.
|
||||
func (c *fakeConn) SetReadDeadline(_ time.Time) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// fakeDial is a mock implementation of customDialContext to simplify testing.
|
||||
func (c *fakeConn) fakeDial(ctx context.Context, network, addr string) (conn net.Conn, err error) {
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func TestWHOIS(t *testing.T) {
|
||||
const (
|
||||
nl = "\n"
|
||||
data = `OrgName: FakeOrg LLC` + nl +
|
||||
`City: Nonreal` + nl +
|
||||
`Country: Imagiland` + nl
|
||||
)
|
||||
|
||||
fc := &fakeConn{
|
||||
data: []byte(data),
|
||||
}
|
||||
|
||||
w := WHOIS{
|
||||
timeoutMsec: 5000,
|
||||
dialContext: fc.fakeDial,
|
||||
}
|
||||
resp, err := w.queryAll(context.Background(), "1.2.3.4")
|
||||
assert.NoError(t, err)
|
||||
|
||||
m := whoisParse(resp)
|
||||
require.NotEmpty(t, m)
|
||||
|
||||
assert.Equal(t, "FakeOrg LLC", m["orgname"])
|
||||
assert.Equal(t, "Imagiland", m["country"])
|
||||
assert.Equal(t, "Nonreal", m["city"])
|
||||
}
|
||||
|
||||
func TestWHOISParse(t *testing.T) {
|
||||
const (
|
||||
city = "Nonreal"
|
||||
country = "Imagiland"
|
||||
orgname = "FakeOrgLLC"
|
||||
whois = "whois.example.net"
|
||||
)
|
||||
|
||||
testCases := []struct {
|
||||
want strmap
|
||||
name string
|
||||
in string
|
||||
}{{
|
||||
want: strmap{},
|
||||
name: "empty",
|
||||
in: ``,
|
||||
}, {
|
||||
want: strmap{},
|
||||
name: "comments",
|
||||
in: "%\n#",
|
||||
}, {
|
||||
want: strmap{},
|
||||
name: "no_colon",
|
||||
in: "city",
|
||||
}, {
|
||||
want: strmap{},
|
||||
name: "no_value",
|
||||
in: "city:",
|
||||
}, {
|
||||
want: strmap{"city": city},
|
||||
name: "city",
|
||||
in: `city: ` + city,
|
||||
}, {
|
||||
want: strmap{"country": country},
|
||||
name: "country",
|
||||
in: `country: ` + country,
|
||||
}, {
|
||||
want: strmap{"orgname": orgname},
|
||||
name: "orgname",
|
||||
in: `orgname: ` + orgname,
|
||||
}, {
|
||||
want: strmap{"orgname": orgname},
|
||||
name: "orgname_hyphen",
|
||||
in: `org-name: ` + orgname,
|
||||
}, {
|
||||
want: strmap{"orgname": orgname},
|
||||
name: "orgname_descr",
|
||||
in: `descr: ` + orgname,
|
||||
}, {
|
||||
want: strmap{"orgname": orgname},
|
||||
name: "orgname_netname",
|
||||
in: `netname: ` + orgname,
|
||||
}, {
|
||||
want: strmap{"whois": whois},
|
||||
name: "whois",
|
||||
in: `whois: ` + whois,
|
||||
}, {
|
||||
want: strmap{"whois": whois},
|
||||
name: "referralserver",
|
||||
in: `referralserver: whois://` + whois,
|
||||
}, {
|
||||
want: strmap{},
|
||||
name: "other",
|
||||
in: `other: value`,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := whoisParse(tc.in)
|
||||
assert.Equal(t, tc.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user