all: sync with master; upd chlog

This commit is contained in:
Ainar Garipov
2023-07-03 14:10:40 +03:00
parent cadb765b7d
commit b22b16d98c
140 changed files with 6739 additions and 2521 deletions

View File

@@ -7,6 +7,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
"github.com/AdguardTeam/AdGuardHome/internal/whois"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/stringutil"
)
@@ -22,12 +23,14 @@ type Client struct {
safeSearchConf filtering.SafeSearchConfig
SafeSearch filtering.SafeSearch
// BlockedServices is the configuration of blocked services of a client.
BlockedServices *filtering.BlockedServices
Name string
IDs []string
Tags []string
BlockedServices []string
Upstreams []string
IDs []string
Tags []string
Upstreams []string
UseOwnSettings bool
FilteringEnabled bool
@@ -43,9 +46,9 @@ type Client struct {
func (c *Client) ShallowClone() (sh *Client) {
clone := *c
clone.BlockedServices = c.BlockedServices.Clone()
clone.IDs = stringutil.CloneSlice(c.IDs)
clone.Tags = stringutil.CloneSlice(c.Tags)
clone.BlockedServices = stringutil.CloneSlice(c.BlockedServices)
clone.Upstreams = stringutil.CloneSlice(c.Upstreams)
return &clone
@@ -127,14 +130,13 @@ func (cs clientSource) MarshalText() (text []byte, err error) {
// RuntimeClient is a client information about which has been obtained using the
// source described in the Source field.
type RuntimeClient struct {
WHOISInfo *RuntimeClientWHOISInfo
Host string
Source clientSource
}
// WHOIS is the filtered WHOIS data of a client.
WHOIS *whois.Info
// RuntimeClientWHOISInfo is the filtered WHOIS data for a runtime client.
type RuntimeClientWHOISInfo struct {
City string `json:"city,omitempty"`
Country string `json:"country,omitempty"`
Orgname string `json:"orgname,omitempty"`
// Host is the host name of a client.
Host string
// Source is the source from which the information about the client has
// been obtained.
Source clientSource
}

View File

@@ -11,9 +11,11 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
"github.com/AdguardTeam/AdGuardHome/internal/whois"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
@@ -23,6 +25,23 @@ import (
"golang.org/x/exp/slices"
)
// DHCP is an interface for accessing DHCP lease data the [clientsContainer]
// needs.
type DHCP interface {
// Leases returns all the DHCP leases.
Leases() (leases []*dhcpsvc.Lease)
// HostByIP returns the hostname of the DHCP client with the given IP
// address. The address will be netip.Addr{} if there is no such client,
// due to an assumption that a DHCP client must always have an IP address.
HostByIP(ip netip.Addr) (host string)
// MACByIP returns the MAC address for the given IP address leased. It
// returns nil if there is no such client, due to an assumption that a DHCP
// client must always have a MAC address.
MACByIP(ip netip.Addr) (mac net.HardwareAddr)
}
// clientsContainer is the storage of all runtime and persistent clients.
type clientsContainer struct {
// TODO(a.garipov): Perhaps use a number of separate indices for different
@@ -77,7 +96,7 @@ func (clients *clientsContainer) Init(
etcHosts *aghnet.HostsContainer,
arpdb aghnet.ARPDB,
filteringConf *filtering.Config,
) {
) (err error) {
if clients.list != nil {
log.Fatal("clients.list != nil")
}
@@ -91,23 +110,29 @@ func (clients *clientsContainer) Init(
clients.dhcpServer = dhcpServer
clients.etcHosts = etcHosts
clients.arpdb = arpdb
clients.addFromConfig(objects, filteringConf)
err = clients.addFromConfig(objects, filteringConf)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
}
clients.safeSearchCacheSize = filteringConf.SafeSearchCacheSize
clients.safeSearchCacheTTL = time.Minute * time.Duration(filteringConf.CacheTime)
if clients.testing {
return
return nil
}
clients.updateFromDHCP(true)
if clients.dhcpServer != nil {
clients.dhcpServer.SetOnLeaseChanged(clients.onDHCPLeaseChanged)
clients.onDHCPLeaseChanged(dhcpd.LeaseChangedAdded)
}
if clients.etcHosts != nil {
go clients.handleHostsUpdates()
}
return nil
}
func (clients *clientsContainer) handleHostsUpdates() {
@@ -147,12 +172,14 @@ func (clients *clientsContainer) reloadARP() {
type clientObject struct {
SafeSearchConf filtering.SafeSearchConfig `yaml:"safe_search"`
// BlockedServices is the configuration of blocked services of a client.
BlockedServices *filtering.BlockedServices `yaml:"blocked_services"`
Name string `yaml:"name"`
Tags []string `yaml:"tags"`
IDs []string `yaml:"ids"`
BlockedServices []string `yaml:"blocked_services"`
Upstreams []string `yaml:"upstreams"`
IDs []string `yaml:"ids"`
Tags []string `yaml:"tags"`
Upstreams []string `yaml:"upstreams"`
UseGlobalSettings bool `yaml:"use_global_settings"`
FilteringEnabled bool `yaml:"filtering_enabled"`
@@ -166,7 +193,10 @@ type clientObject struct {
// addFromConfig initializes the clients container with objects from the
// configuration file.
func (clients *clientsContainer) addFromConfig(objects []*clientObject, filteringConf *filtering.Config) {
func (clients *clientsContainer) addFromConfig(
objects []*clientObject,
filteringConf *filtering.Config,
) (err error) {
for _, o := range objects {
cli := &Client{
Name: o.Name,
@@ -187,7 +217,7 @@ func (clients *clientsContainer) addFromConfig(objects []*clientObject, filterin
if o.SafeSearchConf.Enabled {
o.SafeSearchConf.CustomResolver = safeSearchResolver{}
err := cli.setSafeSearch(
err = cli.setSafeSearch(
o.SafeSearchConf,
filteringConf.SafeSearchCacheSize,
time.Minute*time.Duration(filteringConf.CacheTime),
@@ -199,14 +229,13 @@ func (clients *clientsContainer) addFromConfig(objects []*clientObject, filterin
}
}
for _, s := range o.BlockedServices {
if filtering.BlockedSvcKnown(s) {
cli.BlockedServices = append(cli.BlockedServices, s)
} else {
log.Info("clients: skipping unknown blocked service %q", s)
}
err = o.BlockedServices.Validate()
if err != nil {
return fmt.Errorf("clients: init client blocked services %q: %w", cli.Name, err)
}
cli.BlockedServices = o.BlockedServices.Clone()
for _, t := range o.Tags {
if clients.allTags.Has(t) {
cli.Tags = append(cli.Tags, t)
@@ -217,11 +246,13 @@ func (clients *clientsContainer) addFromConfig(objects []*clientObject, filterin
slices.Sort(cli.Tags)
_, err := clients.Add(cli)
_, err = clients.Add(cli)
if err != nil {
log.Error("clients: adding clients %s: %s", cli.Name, err)
}
}
return nil
}
// forConfig returns all currently known persistent clients as objects for the
@@ -235,10 +266,11 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) {
o := &clientObject{
Name: cli.Name,
Tags: stringutil.CloneSlice(cli.Tags),
IDs: stringutil.CloneSlice(cli.IDs),
BlockedServices: stringutil.CloneSlice(cli.BlockedServices),
Upstreams: stringutil.CloneSlice(cli.Upstreams),
BlockedServices: cli.BlockedServices.Clone(),
IDs: stringutil.CloneSlice(cli.IDs),
Tags: stringutil.CloneSlice(cli.Tags),
Upstreams: stringutil.CloneSlice(cli.Upstreams),
UseGlobalSettings: !cli.UseOwnSettings,
FilteringEnabled: cli.FilteringEnabled,
@@ -276,15 +308,38 @@ func (clients *clientsContainer) periodicUpdate() {
}
}
// onDHCPLeaseChanged is a callback for the DHCP server. It updates the list of
// runtime clients using the DHCP server's leases.
//
// TODO(e.burkov): Remove when switched to dhcpsvc.
func (clients *clientsContainer) onDHCPLeaseChanged(flags int) {
switch flags {
case dhcpd.LeaseChangedAdded,
dhcpd.LeaseChangedAddedStatic,
dhcpd.LeaseChangedRemovedStatic:
clients.updateFromDHCP(true)
case dhcpd.LeaseChangedRemovedAll:
clients.updateFromDHCP(false)
if clients.dhcpServer == nil || !config.Clients.Sources.DHCP {
return
}
clients.lock.Lock()
defer clients.lock.Unlock()
clients.rmHostsBySrc(ClientSourceDHCP)
if flags == dhcpd.LeaseChangedRemovedAll {
return
}
leases := clients.dhcpServer.Leases(dhcpd.LeasesAll)
n := 0
for _, l := range leases {
if l.Hostname == "" {
continue
}
ok := clients.addHostLocked(l.IP, l.Hostname, ClientSourceDHCP)
if ok {
n++
}
}
log.Debug("clients: added %d client aliases from dhcp", n)
}
// clientSource checks if client with this IP address already exists and returns
@@ -300,23 +355,11 @@ func (clients *clientsContainer) clientSource(ip netip.Addr) (src clientSource)
}
rc, ok := clients.ipToRC[ip]
if !ok {
return ClientSourceNone
if ok {
return rc.Source
}
return rc.Source
}
func toQueryLogWHOIS(wi *RuntimeClientWHOISInfo) (cw *querylog.ClientWHOIS) {
if wi == nil {
return &querylog.ClientWHOIS{}
}
return &querylog.ClientWHOIS{
City: wi.City,
Country: wi.Country,
Orgname: wi.Orgname,
}
return ClientSourceNone
}
// findMultiple is a wrapper around Find to make it a valid client finder for
@@ -352,7 +395,7 @@ func (clients *clientsContainer) clientOrArtificial(
defer func() {
c.Disallowed, c.DisallowedRule = clients.dnsServer.IsBlockedClient(ip, id)
if c.WHOIS == nil {
c.WHOIS = &querylog.ClientWHOIS{}
c.WHOIS = &whois.Info{}
}
}()
@@ -369,7 +412,7 @@ func (clients *clientsContainer) clientOrArtificial(
if ok {
return &querylog.Client{
Name: rc.Host,
WHOIS: toQueryLogWHOIS(rc.WHOISInfo),
WHOIS: rc.WHOIS,
}, false
}
@@ -477,11 +520,11 @@ func (clients *clientsContainer) findLocked(id string) (c *Client, ok bool) {
}
}
if clients.dhcpServer == nil {
return nil, false
if clients.dhcpServer != nil {
return clients.findDHCP(ip)
}
return clients.findDHCP(ip)
return nil, false
}
// findDHCP searches for a client by its MAC, if the DHCP server is active and
@@ -701,35 +744,34 @@ func (clients *clientsContainer) Update(prev, c *Client) (err error) {
}
// setWHOISInfo sets the WHOIS information for a client.
func (clients *clientsContainer) setWHOISInfo(ip netip.Addr, wi *RuntimeClientWHOISInfo) {
func (clients *clientsContainer) setWHOISInfo(ip netip.Addr, wi *whois.Info) {
clients.lock.Lock()
defer clients.lock.Unlock()
_, ok := clients.findLocked(ip.String())
if ok {
log.Debug("clients: client for %s is already created, ignore whois info", ip)
return
}
// TODO(e.burkov): Consider storing WHOIS information separately and
// potentially get rid of [RuntimeClient].
rc, ok := clients.ipToRC[ip]
if ok {
rc.WHOISInfo = wi
if !ok {
// Create a RuntimeClient implicitly so that we don't do this check
// again.
rc = &RuntimeClient{
Source: ClientSourceWHOIS,
}
clients.ipToRC[ip] = rc
log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi)
} else {
log.Debug("clients: set whois info for runtime client %s: %+v", rc.Host, wi)
return
}
// Create a RuntimeClient implicitly so that we don't do this check
// again.
rc = &RuntimeClient{
Source: ClientSourceWHOIS,
}
rc.WHOISInfo = wi
clients.ipToRC[ip] = rc
log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi)
rc.WHOIS = wi
}
// AddHost adds a new IP-hostname pairing. The priorities of the sources are
@@ -753,23 +795,19 @@ func (clients *clientsContainer) addHostLocked(
src clientSource,
) (ok bool) {
rc, ok := clients.ipToRC[ip]
if ok {
if rc.Source > src {
return false
}
rc.Host = host
rc.Source = src
} else {
if !ok {
rc = &RuntimeClient{
Host: host,
Source: src,
WHOISInfo: &RuntimeClientWHOISInfo{},
WHOIS: &whois.Info{},
}
clients.ipToRC[ip] = rc
} else if src < rc.Source {
return false
}
rc.Host = host
rc.Source = src
log.Debug("clients: added %s -> %q [%d]", ip, host, len(clients.ipToRC))
return true
@@ -838,38 +876,6 @@ func (clients *clientsContainer) addFromSystemARP() {
log.Debug("clients: added %d client aliases from arp neighborhood", added)
}
// updateFromDHCP adds the clients that have a non-empty hostname from the DHCP
// server.
func (clients *clientsContainer) updateFromDHCP(add bool) {
if clients.dhcpServer == nil || !config.Clients.Sources.DHCP {
return
}
clients.lock.Lock()
defer clients.lock.Unlock()
clients.rmHostsBySrc(ClientSourceDHCP)
if !add {
return
}
leases := clients.dhcpServer.Leases(dhcpd.LeasesAll)
n := 0
for _, l := range leases {
if l.Hostname == "" {
continue
}
ok := clients.addHostLocked(l.IP, l.Hostname, ClientSourceDHCP)
if ok {
n++
}
}
log.Debug("clients: added %d client aliases from dhcp", n)
}
// close gracefully closes all the client-specific upstream configurations of
// the persistent clients.
func (clients *clientsContainer) close() (err error) {

View File

@@ -9,25 +9,26 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/whois"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// newClientsContainer is a helper that creates a new clients container for
// tests.
func newClientsContainer() (c *clientsContainer) {
func newClientsContainer(t *testing.T) (c *clientsContainer) {
c = &clientsContainer{
testing: true,
}
c.Init(nil, nil, nil, nil, &filtering.Config{})
err := c.Init(nil, nil, nil, nil, &filtering.Config{})
require.NoError(t, err)
return c
}
func TestClients(t *testing.T) {
clients := newClientsContainer()
clients := newClientsContainer(t)
t.Run("add_success", func(t *testing.T) {
var (
@@ -198,8 +199,8 @@ func TestClients(t *testing.T) {
}
func TestClientsWHOIS(t *testing.T) {
clients := newClientsContainer()
whois := &RuntimeClientWHOISInfo{
clients := newClientsContainer(t)
whois := &whois.Info{
Country: "AU",
Orgname: "Example Org",
}
@@ -210,7 +211,7 @@ func TestClientsWHOIS(t *testing.T) {
rc := clients.ipToRC[ip]
require.NotNil(t, rc)
assert.Equal(t, rc.WHOISInfo, whois)
assert.Equal(t, rc.WHOIS, whois)
})
t.Run("existing_auto-client", func(t *testing.T) {
@@ -222,7 +223,7 @@ func TestClientsWHOIS(t *testing.T) {
rc := clients.ipToRC[ip]
require.NotNil(t, rc)
assert.Equal(t, rc.WHOISInfo, whois)
assert.Equal(t, rc.WHOIS, whois)
})
t.Run("can't_set_manually-added", func(t *testing.T) {
@@ -244,7 +245,7 @@ func TestClientsWHOIS(t *testing.T) {
}
func TestClientsAddExisting(t *testing.T) {
clients := newClientsContainer()
clients := newClientsContainer(t)
t.Run("simple", func(t *testing.T) {
ip := netip.MustParseAddr("1.1.1.1")
@@ -316,7 +317,7 @@ func TestClientsAddExisting(t *testing.T) {
}
func TestClientsCustomUpstream(t *testing.T) {
clients := newClientsContainer()
clients := newClientsContainer(t)
// Add client with upstreams.
ok, err := clients.Add(&Client{

View File

@@ -9,6 +9,8 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
"github.com/AdguardTeam/AdGuardHome/internal/whois"
)
// clientJSON is a common structure used by several handlers to deal with
@@ -28,7 +30,8 @@ type clientJSON struct {
// the allowlist.
DisallowedRule *string `json:"disallowed_rule,omitempty"`
WHOISInfo *RuntimeClientWHOISInfo `json:"whois_info,omitempty"`
// WHOIS is the filtered WHOIS data of a client.
WHOIS *whois.Info `json:"whois_info,omitempty"`
SafeSearchConf *filtering.SafeSearchConfig `json:"safe_search"`
Name string `json:"name"`
@@ -51,7 +54,7 @@ type clientJSON struct {
}
type runtimeClientJSON struct {
WHOISInfo *RuntimeClientWHOISInfo `json:"whois_info"`
WHOIS *whois.Info `json:"whois_info"`
IP netip.Addr `json:"ip"`
Name string `json:"name"`
@@ -78,7 +81,7 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
for ip, rc := range clients.ipToRC {
cj := runtimeClientJSON{
WHOISInfo: rc.WHOISInfo,
WHOIS: rc.WHOIS,
Name: rc.Host,
Source: rc.Source,
@@ -116,15 +119,24 @@ func (clients *clientsContainer) jsonToClient(cj clientJSON, prev *Client) (c *C
}
}
weekly := schedule.EmptyWeekly()
if prev != nil {
weekly = prev.BlockedServices.Schedule.Clone()
}
c = &Client{
safeSearchConf: safeSearchConf,
Name: cj.Name,
IDs: cj.IDs,
Tags: cj.Tags,
BlockedServices: cj.BlockedServices,
Upstreams: cj.Upstreams,
BlockedServices: &filtering.BlockedServices{
Schedule: weekly,
IDs: cj.BlockedServices,
},
IDs: cj.IDs,
Tags: cj.Tags,
Upstreams: cj.Upstreams,
UseOwnSettings: !cj.UseGlobalSettings,
FilteringEnabled: cj.FilteringEnabled,
@@ -178,7 +190,8 @@ func clientToJSON(c *Client) (cj *clientJSON) {
SafeBrowsingEnabled: c.SafeBrowsingEnabled,
UseGlobalBlockedServices: !c.UseOwnBlockedServices,
BlockedServices: c.BlockedServices,
BlockedServices: c.BlockedServices.IDs,
Upstreams: c.Upstreams,
@@ -344,16 +357,16 @@ func (clients *clientsContainer) findRuntime(ip netip.Addr, idStr string) (cj *c
IDs: []string{idStr},
Disallowed: &disallowed,
DisallowedRule: &rule,
WHOISInfo: &RuntimeClientWHOISInfo{},
WHOIS: &whois.Info{},
}
return cj
}
cj = &clientJSON{
Name: rc.Host,
IDs: []string{idStr},
WHOISInfo: rc.WHOISInfo,
Name: rc.Host,
IDs: []string{idStr},
WHOIS: rc.WHOIS,
}
disallowed, rule := clients.dnsServer.IsBlockedClient(ip, idStr)

View File

@@ -14,6 +14,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
"github.com/AdguardTeam/AdGuardHome/internal/stats"
"github.com/AdguardTeam/dnsproxy/fastip"
"github.com/AdguardTeam/golibs/errors"
@@ -90,18 +91,17 @@ type clientSourcesConfig struct {
HostsFile bool `yaml:"hosts"`
}
// configuration is loaded from YAML
// field ordering is important -- yaml fields will mirror ordering from here
// configuration is loaded from YAML.
//
// Field ordering is important, YAML fields better not to be reordered, if it's
// not absolutely necessary.
type configuration struct {
// Raw file data to avoid re-reading of configuration file
// It's reset after config is parsed
fileData []byte
// BindHost is the address for the web interface server to listen on.
BindHost netip.Addr `yaml:"bind_host"`
// BindPort is the port for the web interface server to listen on.
BindPort int `yaml:"bind_port"`
// HTTPConfig is the block with http conf.
HTTPConfig httpConfig `yaml:"http"`
// Users are the clients capable for accessing the web interface.
Users []webUser `yaml:"users"`
// AuthAttempts is the maximum number of failed login attempts a user
@@ -119,10 +119,6 @@ type configuration struct {
// DebugPProf defines if the profiling HTTP handler will listen on :6060.
DebugPProf bool `yaml:"debug_pprof"`
// TTL for a web session (in hours)
// An active session is automatically refreshed once a day.
WebSessionTTLHours uint32 `yaml:"web_session_ttl"`
DNS dnsConfig `yaml:"dns"`
TLS tlsConfigSettings `yaml:"tls"`
QueryLog queryLogConfig `yaml:"querylog"`
@@ -155,7 +151,23 @@ type configuration struct {
SchemaVersion int `yaml:"schema_version"` // keeping last so that users will be less tempted to change it -- used when upgrading between versions
}
// field ordering is important -- yaml fields will mirror ordering from here
// httpConfig is a block with HTTP configuration params.
//
// Field ordering is important, YAML fields better not to be reordered, if it's
// not absolutely necessary.
type httpConfig struct {
// Address is the address to serve the web UI on.
Address netip.AddrPort
// SessionTTL for a web session.
// An active session is automatically refreshed once a day.
SessionTTL timeutil.Duration `yaml:"session_ttl"`
}
// dnsConfig is a block with DNS configuration params.
//
// Field ordering is important, YAML fields better not to be reordered, if it's
// not absolutely necessary.
type dnsConfig struct {
BindHosts []netip.Addr `yaml:"bind_hosts"`
Port int `yaml:"port"`
@@ -260,11 +272,12 @@ type statsConfig struct {
//
// TODO(a.garipov, e.burkov): This global is awful and must be removed.
var config = &configuration{
BindPort: 3000,
BindHost: netip.IPv4Unspecified(),
AuthAttempts: 5,
AuthBlockMin: 15,
WebSessionTTLHours: 30 * 24,
AuthAttempts: 5,
AuthBlockMin: 15,
HTTPConfig: httpConfig{
Address: netip.AddrPortFrom(netip.IPv4Unspecified(), 3000),
SessionTTL: timeutil.Duration{Duration: 30 * timeutil.Day},
},
DNS: dnsConfig{
BindHosts: []netip.Addr{netip.IPv4Unspecified()},
Port: defaultPortDNS,
@@ -316,6 +329,11 @@ var config = &configuration{
Yandex: true,
YouTube: true,
},
BlockedServices: &filtering.BlockedServices{
Schedule: schedule.EmptyWeekly(),
IDs: []string{},
},
},
UpstreamTimeout: timeutil.Duration{Duration: dnsforward.DefaultTimeout},
UsePrivateRDNS: true,
@@ -421,8 +439,8 @@ func readLogSettings() (ls *logSettings) {
// validateBindHosts returns error if any of binding hosts from configuration is
// not a valid IP address.
func validateBindHosts(conf *configuration) (err error) {
if !conf.BindHost.IsValid() {
return errors.Error("bind_host is not a valid ip address")
if !conf.HTTPConfig.Address.IsValid() {
return errors.Error("http.address is not a valid ip address")
}
for i, addr := range conf.DNS.BindHosts {
@@ -456,7 +474,7 @@ func parseConfig() (err error) {
}
tcpPorts := aghalg.UniqChecker[tcpPort]{}
addPorts(tcpPorts, tcpPort(config.BindPort))
addPorts(tcpPorts, tcpPort(config.HTTPConfig.Address.Port()))
udpPorts := aghalg.UniqChecker[udpPort]{}
addPorts(udpPorts, udpPort(config.DNS.Port))

View File

@@ -103,7 +103,7 @@ type statusResponse struct {
Language string `json:"language"`
DNSAddrs []string `json:"dns_addresses"`
DNSPort int `json:"dns_port"`
HTTPPort int `json:"http_port"`
HTTPPort uint16 `json:"http_port"`
// ProtectionDisabledDuration is the duration of the protection pause in
// milliseconds.
@@ -158,7 +158,7 @@ func handleStatus(w http.ResponseWriter, r *http.Request) {
Language: config.Language,
DNSAddrs: dnsAddrs,
DNSPort: config.DNS.Port,
HTTPPort: config.BindPort,
HTTPPort: config.HTTPConfig.Address.Port(),
ProtectionDisabledDuration: protectionDisabledDuration,
ProtectionEnabled: protectionEnabled,
IsRunning: isRunning(),

View File

@@ -96,8 +96,9 @@ type checkConfResp struct {
func (req *checkConfReq) validateWeb(tcpPorts aghalg.UniqChecker[tcpPort]) (err error) {
defer func() { err = errors.Annotate(err, "validating ports: %w") }()
portInt := req.Web.Port
port := tcpPort(portInt)
// TODO(a.garipov): Declare all port variables anywhere as uint16.
reqPort := uint16(req.Web.Port)
port := tcpPort(reqPort)
addPorts(tcpPorts, port)
if err = tcpPorts.Validate(); err != nil {
// Reset the value for the port to 1 to make sure that validateDNS
@@ -108,15 +109,15 @@ func (req *checkConfReq) validateWeb(tcpPorts aghalg.UniqChecker[tcpPort]) (err
return err
}
switch portInt {
case 0, config.BindPort:
switch reqPort {
case 0, config.HTTPConfig.Address.Port():
return nil
default:
// Go on and check the port binding only if it's not zero or won't be
// unbound after install.
}
return aghnet.CheckPort("tcp", netip.AddrPortFrom(req.Web.IP, uint16(portInt)))
return aghnet.CheckPort("tcp", netip.AddrPortFrom(req.Web.IP, reqPort))
}
// validateDNS returns error if the DNS part of the initial configuration can't
@@ -127,11 +128,11 @@ func (req *checkConfReq) validateDNS(
) (canAutofix bool, err error) {
defer func() { err = errors.Annotate(err, "validating ports: %w") }()
port := req.DNS.Port
port := uint16(req.DNS.Port)
switch port {
case 0:
return false, nil
case config.BindPort:
case config.HTTPConfig.Address.Port():
// Go on and only check the UDP port since the TCP one is already bound
// by AdGuard Home for web interface.
default:
@@ -318,8 +319,7 @@ type applyConfigReq struct {
// copyInstallSettings copies the installation parameters between two
// configuration structures.
func copyInstallSettings(dst, src *configuration) {
dst.BindHost = src.BindHost
dst.BindPort = src.BindPort
dst.HTTPConfig = src.HTTPConfig
dst.DNS.BindHosts = src.DNS.BindHosts
dst.DNS.Port = src.DNS.Port
}
@@ -413,8 +413,7 @@ func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request
copyInstallSettings(curConfig, config)
Context.firstRun = false
config.BindHost = req.Web.IP
config.BindPort = req.Web.Port
config.HTTPConfig.Address = netip.AddrPortFrom(req.Web.IP, uint16(req.Web.Port))
config.DNS.BindHosts = []netip.Addr{req.DNS.IP}
config.DNS.Port = req.DNS.Port
@@ -487,7 +486,8 @@ func decodeApplyConfigReq(r io.Reader) (req *applyConfigReq, restartHTTP bool, e
return nil, false, errors.Error("ports cannot be 0")
}
restartHTTP = config.BindHost != req.Web.IP || config.BindPort != req.Web.Port
addrPort := config.HTTPConfig.Address
restartHTTP = addrPort.Addr() != req.Web.IP || int(addrPort.Port()) != req.Web.Port
if restartHTTP {
err = aghnet.CheckPort("tcp", netip.AddrPortFrom(req.Web.IP, uint16(req.Web.Port)))
if err != nil {

View File

@@ -157,7 +157,9 @@ func (vr *versionResponse) setAllowedToAutoUpdate() (err error) {
Context.tls.WriteDiskConfig(tlsConf)
canUpdate := true
if tlsConfUsesPrivilegedPorts(tlsConf) || config.BindPort < 1024 || config.DNS.Port < 1024 {
if tlsConfUsesPrivilegedPorts(tlsConf) ||
config.HTTPConfig.Address.Port() < 1024 ||
config.DNS.Port < 1024 {
canUpdate, err = aghnet.CanBindPrivilegedPorts()
if err != nil {
return fmt.Errorf("checking ability to bind privileged ports: %w", err)

View File

@@ -8,6 +8,7 @@ import (
"net/url"
"os"
"path/filepath"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
@@ -17,6 +18,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
"github.com/AdguardTeam/AdGuardHome/internal/stats"
"github.com/AdguardTeam/AdGuardHome/internal/whois"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
@@ -25,7 +27,7 @@ import (
yaml "gopkg.in/yaml.v3"
)
// Default ports.
// Default listening ports.
const (
defaultPortDNS = 53
defaultPortHTTP = 80
@@ -169,13 +171,72 @@ func initDNSServer(
Context.rdns = NewRDNS(Context.dnsServer, &Context.clients, config.DNS.UsePrivateRDNS)
}
if config.Clients.Sources.WHOIS {
Context.whois = initWHOIS(&Context.clients)
}
initWHOIS()
return nil
}
// initWHOIS initializes the WHOIS.
//
// TODO(s.chzhen): Consider making configurable.
func initWHOIS() {
const (
// defaultQueueSize is the size of queue of IPs for WHOIS processing.
defaultQueueSize = 255
// defaultTimeout is the timeout for WHOIS requests.
defaultTimeout = 5 * time.Second
// defaultCacheSize is the maximum size of the cache. If it's zero,
// cache size is unlimited.
defaultCacheSize = 10_000
// defaultMaxConnReadSize is an upper limit in bytes for reading from
// net.Conn.
defaultMaxConnReadSize = 64 * 1024
// defaultMaxRedirects is the maximum redirects count.
defaultMaxRedirects = 5
// defaultMaxInfoLen is the maximum length of whois.Info fields.
defaultMaxInfoLen = 250
// defaultIPTTL is the Time to Live duration for cached IP addresses.
defaultIPTTL = 1 * time.Hour
)
Context.whoisCh = make(chan netip.Addr, defaultQueueSize)
var w whois.Interface
if config.Clients.Sources.WHOIS {
w = whois.New(&whois.Config{
DialContext: customDialContext,
ServerAddr: whois.DefaultServer,
Port: whois.DefaultPort,
Timeout: defaultTimeout,
CacheSize: defaultCacheSize,
MaxConnReadSize: defaultMaxConnReadSize,
MaxRedirects: defaultMaxRedirects,
MaxInfoLen: defaultMaxInfoLen,
CacheTTL: defaultIPTTL,
})
} else {
w = whois.Empty{}
}
go func() {
defer log.OnPanic("whois")
for ip := range Context.whoisCh {
info, changed := w.Process(context.Background(), ip)
if info != nil && changed {
Context.clients.setWHOISInfo(ip, info)
}
}
}()
}
// parseSubnetSet parses a slice of subnets. If the slice is empty, it returns
// a subnet set that matches all locally served networks, see
// [netutil.IsLocallyServed].
@@ -218,9 +279,7 @@ func onDNSRequest(pctx *proxy.DNSContext) {
Context.rdns.Begin(ip)
}
if srcs.WHOIS && !netutil.IsSpecialPurposeAddr(ip) {
Context.whois.Begin(ip)
}
Context.whoisCh <- ip
}
func ipsToTCPAddrs(ips []netip.Addr, port int) (tcpAddrs []*net.TCPAddr) {
@@ -390,7 +449,7 @@ func applyAdditionalFiltering(clientIP net.IP, clientID string, setts *filtering
// pref is a prefix for logging messages around the scope.
const pref = "applying filters"
Context.filters.ApplyBlockedServices(setts, nil)
Context.filters.ApplyBlockedServices(setts)
log.Debug("%s: looking for client with ip %s and clientid %q", pref, clientIP, clientID)
@@ -414,12 +473,12 @@ func applyAdditionalFiltering(clientIP net.IP, clientID string, setts *filtering
if c.UseOwnBlockedServices {
// TODO(e.burkov): Get rid of this crutch.
svcs := c.BlockedServices
if svcs == nil {
svcs = []string{}
setts.ServicesRules = nil
svcs := c.BlockedServices.IDs
if !c.BlockedServices.Schedule.Contains(time.Now()) {
Context.filters.ApplyBlockedServicesList(setts, svcs)
log.Debug("%s: services for client %q set: %s", pref, c.Name, svcs)
}
Context.filters.ApplyBlockedServices(setts, svcs)
log.Debug("%s: services for client %q set: %s", pref, c.Name, svcs)
}
setts.ClientName = c.Name
@@ -463,9 +522,7 @@ func startDNSServer() error {
Context.rdns.Begin(ip)
}
if srcs.WHOIS && !netutil.IsSpecialPurposeAddr(ip) {
Context.whois.Begin(ip)
}
Context.whoisCh <- ip
}
return nil

View File

@@ -0,0 +1,176 @@
package home
import (
"net"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestApplyAdditionalFiltering(t *testing.T) {
var err error
Context.filters, err = filtering.New(&filtering.Config{
BlockedServices: &filtering.BlockedServices{
Schedule: schedule.EmptyWeekly(),
},
}, nil)
require.NoError(t, err)
Context.clients.idIndex = map[string]*Client{
"default": {
UseOwnSettings: false,
safeSearchConf: filtering.SafeSearchConfig{Enabled: false},
FilteringEnabled: false,
SafeBrowsingEnabled: false,
ParentalEnabled: false,
},
"custom_filtering": {
UseOwnSettings: true,
safeSearchConf: filtering.SafeSearchConfig{Enabled: true},
FilteringEnabled: true,
SafeBrowsingEnabled: true,
ParentalEnabled: true,
},
"partial_custom_filtering": {
UseOwnSettings: true,
safeSearchConf: filtering.SafeSearchConfig{Enabled: true},
FilteringEnabled: true,
SafeBrowsingEnabled: false,
ParentalEnabled: false,
},
}
testCases := []struct {
name string
id string
FilteringEnabled assert.BoolAssertionFunc
SafeSearchEnabled assert.BoolAssertionFunc
SafeBrowsingEnabled assert.BoolAssertionFunc
ParentalEnabled assert.BoolAssertionFunc
}{{
name: "global_settings",
id: "default",
FilteringEnabled: assert.False,
SafeSearchEnabled: assert.False,
SafeBrowsingEnabled: assert.False,
ParentalEnabled: assert.False,
}, {
name: "custom_settings",
id: "custom_filtering",
FilteringEnabled: assert.True,
SafeSearchEnabled: assert.True,
SafeBrowsingEnabled: assert.True,
ParentalEnabled: assert.True,
}, {
name: "partial",
id: "partial_custom_filtering",
FilteringEnabled: assert.True,
SafeSearchEnabled: assert.True,
SafeBrowsingEnabled: assert.False,
ParentalEnabled: assert.False,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
setts := &filtering.Settings{}
applyAdditionalFiltering(net.IP{1, 2, 3, 4}, tc.id, setts)
tc.FilteringEnabled(t, setts.FilteringEnabled)
tc.SafeSearchEnabled(t, setts.SafeSearchEnabled)
tc.SafeBrowsingEnabled(t, setts.SafeBrowsingEnabled)
tc.ParentalEnabled(t, setts.ParentalEnabled)
})
}
}
func TestApplyAdditionalFiltering_blockedServices(t *testing.T) {
filtering.InitModule()
var (
globalBlockedServices = []string{"ok"}
clientBlockedServices = []string{"ok", "mail_ru", "vk"}
invalidBlockedServices = []string{"invalid"}
err error
)
Context.filters, err = filtering.New(&filtering.Config{
BlockedServices: &filtering.BlockedServices{
Schedule: schedule.EmptyWeekly(),
IDs: globalBlockedServices,
},
}, nil)
require.NoError(t, err)
Context.clients.idIndex = map[string]*Client{
"default": {
UseOwnBlockedServices: false,
},
"no_services": {
BlockedServices: &filtering.BlockedServices{
Schedule: schedule.EmptyWeekly(),
},
UseOwnBlockedServices: true,
},
"services": {
BlockedServices: &filtering.BlockedServices{
Schedule: schedule.EmptyWeekly(),
IDs: clientBlockedServices,
},
UseOwnBlockedServices: true,
},
"invalid_services": {
BlockedServices: &filtering.BlockedServices{
Schedule: schedule.EmptyWeekly(),
IDs: invalidBlockedServices,
},
UseOwnBlockedServices: true,
},
"allow_all": {
BlockedServices: &filtering.BlockedServices{
Schedule: schedule.FullWeekly(),
IDs: clientBlockedServices,
},
UseOwnBlockedServices: true,
},
}
testCases := []struct {
name string
id string
wantLen int
}{{
name: "global_settings",
id: "default",
wantLen: len(globalBlockedServices),
}, {
name: "custom_settings",
id: "no_services",
wantLen: 0,
}, {
name: "custom_settings_block",
id: "services",
wantLen: len(clientBlockedServices),
}, {
name: "custom_settings_invalid",
id: "invalid_services",
wantLen: 0,
}, {
name: "custom_settings_inactive_schedule",
id: "allow_all",
wantLen: 0,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
setts := &filtering.Settings{}
applyAdditionalFiltering(net.IP{1, 2, 3, 4}, tc.id, setts)
require.Len(t, setts.ServicesRules, tc.wantLen)
})
}
}

View File

@@ -57,7 +57,6 @@ type homeContext struct {
queryLog querylog.QueryLog // query log module
dnsServer *dnsforward.Server // DNS module
rdns *RDNS // rDNS module
whois *WHOIS // WHOIS module
dhcpServer dhcpd.Interface // DHCP module
auth *Auth // HTTP authentication module
filters *filtering.DNSFilter // DNS filtering module
@@ -84,6 +83,9 @@ type homeContext struct {
client *http.Client
appSignalChannel chan os.Signal // Channel for receiving OS signals by the console app
// whoisCh is the channel for receiving IPs for WHOIS processing.
whoisCh chan netip.Addr
// tlsCipherIDs are the ID of the cipher suites that AdGuard Home must use.
tlsCipherIDs []uint16
@@ -353,21 +355,43 @@ func initContextClients() (err error) {
arpdb = aghnet.NewARPDB()
}
Context.clients.Init(
err = Context.clients.Init(
config.Clients.Persistent,
Context.dhcpServer,
Context.etcHosts,
arpdb,
config.DNS.DnsfilterConf,
)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
}
return nil
}
// setupBindOpts overrides bind host/port from the opts.
func setupBindOpts(opts options) (err error) {
bindAddr := opts.bindAddr
if bindAddr != (netip.AddrPort{}) {
config.HTTPConfig.Address = bindAddr
if config.HTTPConfig.Address.Port() != 0 {
err = checkPorts()
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
}
}
return nil
}
if opts.bindPort != 0 {
config.BindPort = opts.bindPort
config.HTTPConfig.Address = netip.AddrPortFrom(
config.HTTPConfig.Address.Addr(),
uint16(opts.bindPort),
)
err = checkPorts()
if err != nil {
@@ -377,7 +401,10 @@ func setupBindOpts(opts options) (err error) {
}
if opts.bindHost.IsValid() {
config.BindHost = opts.bindHost
config.HTTPConfig.Address = netip.AddrPortFrom(
opts.bindHost,
config.HTTPConfig.Address.Port(),
)
}
return nil
@@ -461,7 +488,7 @@ func setupDNSFilteringConf(conf *filtering.Config) (err error) {
// checkPorts is a helper for ports validation in config.
func checkPorts() (err error) {
tcpPorts := aghalg.UniqChecker[tcpPort]{}
addPorts(tcpPorts, tcpPort(config.BindPort))
addPorts(tcpPorts, tcpPort(config.HTTPConfig.Address.Port()))
udpPorts := aghalg.UniqChecker[udpPort]{}
addPorts(udpPorts, udpPort(config.DNS.Port))
@@ -501,8 +528,8 @@ func initWeb(opts options, clientBuildFS fs.FS) (web *webAPI, err error) {
webConf := webConfig{
firstRun: Context.firstRun,
BindHost: config.BindHost,
BindPort: config.BindPort,
BindHost: config.HTTPConfig.Address.Addr(),
BindPort: int(config.HTTPConfig.Address.Port()),
ReadTimeout: readTimeout,
ReadHeaderTimeout: readHdrTimeout,
@@ -638,8 +665,8 @@ func initUsers() (auth *Auth, err error) {
log.Info("authratelimiter is disabled")
}
sessionTTL := config.WebSessionTTLHours * 60 * 60
auth = InitAuth(sessFilename, config.Users, sessionTTL, rateLimiter)
sessionTTL := config.HTTPConfig.SessionTTL.Seconds()
auth = InitAuth(sessFilename, config.Users, uint32(sessionTTL), rateLimiter)
if auth == nil {
return nil, errors.Error("initializing auth module failed")
}
@@ -917,7 +944,7 @@ func printHTTPAddresses(proto string) {
Context.tls.WriteDiskConfig(&tlsConf)
}
port := config.BindPort
port := int(config.HTTPConfig.Address.Port())
if proto == aghhttp.SchemeHTTPS {
port = tlsConf.PortHTTPS
}
@@ -929,9 +956,9 @@ func printHTTPAddresses(proto string) {
return
}
bindhost := config.BindHost
if !bindhost.IsUnspecified() {
printWebAddrs(proto, bindhost.String(), port)
bindHost := config.HTTPConfig.Address.Addr()
if !bindHost.IsUnspecified() {
printWebAddrs(proto, bindHost.String(), port)
return
}
@@ -942,14 +969,14 @@ func printHTTPAddresses(proto string) {
// That's weird, but we'll ignore it.
//
// TODO(e.burkov): Find out when it happens.
printWebAddrs(proto, bindhost.String(), port)
printWebAddrs(proto, bindHost.String(), port)
return
}
for _, iface := range ifaces {
for _, addr := range iface.Addresses {
printWebAddrs(proto, addr.String(), config.BindPort)
printWebAddrs(proto, addr.String(), port)
}
}
}

View File

@@ -35,11 +35,18 @@ type options struct {
serviceControlAction string
// bindHost is the address on which to serve the HTTP UI.
//
// Deprecated: Use bindAddr.
bindHost netip.Addr
// bindPort is the port on which to serve the HTTP UI.
//
// Deprecated: Use bindAddr.
bindPort int
// bindAddr is the address to serve the web UI on.
bindAddr netip.AddrPort
// checkConfig is true if the current invocation is only required to check
// the configuration file and exit.
checkConfig bool
@@ -147,9 +154,10 @@ var cmdLineOpts = []cmdLineOpt{{
return o.bindHost.String(), true
},
description: "Host address to bind HTTP server on.",
longName: "host",
shortName: "h",
description: "Deprecated. Host address to bind HTTP server on. Use --web-addr. " +
"The short -h will work as --help in the future.",
longName: "host",
shortName: "h",
}, {
updateWithValue: func(o options, v string) (options, error) {
var err error
@@ -174,9 +182,23 @@ var cmdLineOpts = []cmdLineOpt{{
return strconv.Itoa(o.bindPort), true
},
description: "Port to serve HTTP pages on.",
description: "Deprecated. Port to serve HTTP pages on. Use --web-addr.",
longName: "port",
shortName: "p",
}, {
updateWithValue: func(o options, v string) (oo options, err error) {
o.bindAddr, err = netip.ParseAddrPort(v)
return o, err
},
updateNoValue: nil,
effect: nil,
serialize: func(o options) (val string, ok bool) {
return o.bindAddr.String(), o.bindAddr.IsValid()
},
description: "Address to serve the web UI on, in the host:port format.",
longName: "web-addr",
shortName: "",
}, {
updateWithValue: func(o options, v string) (options, error) {
o.serviceControlAction = v

View File

@@ -82,6 +82,23 @@ func TestParseBindPort(t *testing.T) {
testParseErr(t, "port too high", "-p", "18446744073709551617") // 2^64 + 1
}
func TestParseBindAddr(t *testing.T) {
wantAddrPort := netip.MustParseAddrPort("1.2.3.4:8089")
assert.Zero(t, testParseOK(t).bindAddr, "empty is not web-addr")
assert.Equal(t, wantAddrPort, testParseOK(t, "--web-addr", "1.2.3.4:8089").bindAddr)
assert.Equal(t, netip.MustParseAddrPort("1.2.3.4:0"), testParseOK(t, "--web-addr", "1.2.3.4:0").bindAddr)
testParseParamMissing(t, "-web-addr")
testParseErr(t, "not an int", "--web-addr", "1.2.3.4:x")
testParseErr(t, "hex not supported", "--web-addr", "1.2.3.4:0x100")
testParseErr(t, "port negative", "--web-addr", "1.2.3.4:-1")
testParseErr(t, "port too high", "--web-addr", "1.2.3.4:65536")
testParseErr(t, "port too high", "--web-addr", "1.2.3.4:4294967297") // 2^32 + 1
testParseErr(t, "port too high", "--web-addr", "1.2.3.4:18446744073709551617") // 2^64 + 1
}
func TestParseLogfile(t *testing.T) {
assert.Equal(t, "", testParseOK(t).logFile, "empty is no log file")
assert.Equal(t, "path", testParseOK(t, "-l", "path").logFile, "-l is log file")
@@ -162,6 +179,10 @@ func TestOptsToArgs(t *testing.T) {
name: "bind_port",
args: []string{"-p", "666"},
opts: options{bindPort: 666},
}, {
name: "web-addr",
args: []string{"--web-addr", "1.2.3.4:8080"},
opts: options{bindAddr: netip.MustParseAddrPort("1.2.3.4:8080")},
}, {
name: "log_file",
args: []string{"-l", "path"},

View File

@@ -228,12 +228,7 @@ func TestRDNS_WorkerLoop(t *testing.T) {
for _, tc := range testCases {
w.Reset()
cc := &clientsContainer{
list: map[string]*Client{},
idIndex: map[string]*Client{},
ipToRC: map[netip.Addr]*RuntimeClient{},
allTags: stringutil.NewSet(),
}
cc := newClientsContainer(t)
ch := make(chan netip.Addr)
rdns := &RDNS{
exchanger: &rDNSExchanger{

View File

@@ -320,7 +320,7 @@ func (m *tlsManager) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
if setts.Enabled {
err = validatePorts(
tcpPort(config.BindPort),
tcpPort(config.HTTPConfig.Address.Port()),
tcpPort(setts.PortHTTPS),
tcpPort(setts.PortDNSOverTLS),
tcpPort(setts.PortDNSCrypt),
@@ -407,7 +407,7 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request)
if req.Enabled {
err = validatePorts(
tcpPort(config.BindPort),
tcpPort(config.HTTPConfig.Address.Port()),
tcpPort(req.PortHTTPS),
tcpPort(req.PortDNSOverTLS),
tcpPort(req.PortDNSCrypt),

View File

@@ -3,6 +3,7 @@ package home
import (
"bytes"
"fmt"
"net/netip"
"net/url"
"os"
"path"
@@ -22,7 +23,7 @@ import (
)
// currentSchemaVersion is the current schema version.
const currentSchemaVersion = 20
const currentSchemaVersion = 23
// These aliases are provided for convenience.
type (
@@ -94,6 +95,9 @@ func upgradeConfigSchema(oldVersion int, diskConf yobj) (err error) {
upgradeSchema17to18,
upgradeSchema18to19,
upgradeSchema19to20,
upgradeSchema20to21,
upgradeSchema21to22,
upgradeSchema22to23,
}
n := 0
@@ -1128,6 +1132,199 @@ func upgradeSchema19to20(diskConf yobj) (err error) {
return nil
}
// upgradeSchema20to21 performs the following changes:
//
// # BEFORE:
// 'dns':
// 'blocked_services':
// - 'svc_name'
//
// # AFTER:
// 'dns':
// 'blocked_services':
// 'ids':
// - 'svc_name'
// 'schedule':
// 'time_zone': 'Local'
func upgradeSchema20to21(diskConf yobj) (err error) {
log.Printf("Upgrade yaml: 20 to 21")
diskConf["schema_version"] = 21
const field = "blocked_services"
dnsVal, ok := diskConf["dns"]
if !ok {
return nil
}
dns, ok := dnsVal.(yobj)
if !ok {
return fmt.Errorf("unexpected type of dns: %T", dnsVal)
}
blockedVal, ok := dns[field]
if !ok {
return nil
}
services, ok := blockedVal.(yarr)
if !ok {
return fmt.Errorf("unexpected type of blocked: %T", blockedVal)
}
dns[field] = yobj{
"ids": services,
"schedule": yobj{
"time_zone": "Local",
},
}
return nil
}
// upgradeSchema21to22 performs the following changes:
//
// # BEFORE:
// 'persistent':
// - 'name': 'client_name'
// 'blocked_services':
// - 'svc_name'
//
// # AFTER:
// 'persistent':
// - 'name': 'client_name'
// 'blocked_services':
// 'ids':
// - 'svc_name'
// 'schedule':
// 'time_zone': 'Local'
func upgradeSchema21to22(diskConf yobj) (err error) {
log.Println("Upgrade yaml: 21 to 22")
diskConf["schema_version"] = 22
const field = "blocked_services"
clientsVal, ok := diskConf["clients"]
if !ok {
return nil
}
clients, ok := clientsVal.(yobj)
if !ok {
return fmt.Errorf("unexpected type of clients: %T", clientsVal)
}
persistentVal, ok := clients["persistent"]
if !ok {
return nil
}
persistent, ok := persistentVal.([]any)
if !ok {
return fmt.Errorf("unexpected type of persistent clients: %T", persistentVal)
}
for i, val := range persistent {
var c yobj
c, ok = val.(yobj)
if !ok {
return fmt.Errorf("persistent client at index %d: unexpected type %T", i, val)
}
var blockedVal any
blockedVal, ok = c[field]
if !ok {
continue
}
var services yarr
services, ok = blockedVal.(yarr)
if !ok {
return fmt.Errorf(
"persistent client at index %d: unexpected type of blocked services: %T",
i,
blockedVal,
)
}
c[field] = yobj{
"ids": services,
"schedule": yobj{
"time_zone": "Local",
},
}
}
return nil
}
// upgradeSchema22to23 performs the following changes:
//
// # BEFORE:
// 'bind_host': '1.2.3.4'
// 'bind_port': 8080
// 'web_session_ttl': 720
//
// # AFTER:
// 'http':
// 'address': '1.2.3.4:8080'
// 'session_ttl': '720h'
func upgradeSchema22to23(diskConf yobj) (err error) {
log.Printf("Upgrade yaml: 22 to 23")
diskConf["schema_version"] = 23
bindHostVal, ok := diskConf["bind_host"]
if !ok {
return nil
}
bindHost, ok := bindHostVal.(string)
if !ok {
return fmt.Errorf("unexpected type of bind_host: %T", bindHostVal)
}
bindHostAddr, err := netip.ParseAddr(bindHost)
if err != nil {
return fmt.Errorf("invalid bind_host value: %s", bindHost)
}
bindPortVal, ok := diskConf["bind_port"]
if !ok {
return nil
}
bindPort, ok := bindPortVal.(int)
if !ok {
return fmt.Errorf("unexpected type of bind_port: %T", bindPortVal)
}
sessionTTLVal, ok := diskConf["web_session_ttl"]
if !ok {
return nil
}
sessionTTL, ok := sessionTTLVal.(int)
if !ok {
return fmt.Errorf("unexpected type of web_session_ttl: %T", sessionTTLVal)
}
addr := netip.AddrPortFrom(bindHostAddr, uint16(bindPort))
if !addr.IsValid() {
return fmt.Errorf("invalid address: %s", addr)
}
diskConf["http"] = yobj{
"address": addr.String(),
"session_ttl": timeutil.Duration{Duration: time.Duration(sessionTTL) * time.Hour}.String(),
}
delete(diskConf, "bind_host")
delete(diskConf, "bind_port")
delete(diskConf, "web_session_ttl")
return nil
}
// TODO(a.garipov): Replace with log.Output when we port it to our logging
// package.
func funcName() string {

View File

@@ -1140,3 +1140,169 @@ func TestUpgradeSchema19to20(t *testing.T) {
assert.Equal(t, 24*time.Hour, ivlVal.Duration)
})
}
func TestUpgradeSchema20to21(t *testing.T) {
const newSchemaVer = 21
testCases := []struct {
in yobj
want yobj
name string
}{{
name: "nothing",
in: yobj{},
want: yobj{
"schema_version": newSchemaVer,
},
}, {
name: "no_clients",
in: yobj{
"dns": yobj{
"blocked_services": yarr{"ok"},
},
},
want: yobj{
"dns": yobj{
"blocked_services": yobj{
"ids": yarr{"ok"},
"schedule": yobj{
"time_zone": "Local",
},
},
},
"schema_version": newSchemaVer,
},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := upgradeSchema20to21(tc.in)
require.NoError(t, err)
assert.Equal(t, tc.want, tc.in)
})
}
}
func TestUpgradeSchema21to22(t *testing.T) {
const newSchemaVer = 22
testCases := []struct {
in yobj
want yobj
name string
}{{
in: yobj{
"clients": yobj{},
},
want: yobj{
"clients": yobj{},
"schema_version": newSchemaVer,
},
name: "nothing",
}, {
in: yobj{
"clients": yobj{
"persistent": []any{yobj{"name": "localhost", "blocked_services": yarr{}}},
},
},
want: yobj{
"clients": yobj{
"persistent": []any{yobj{
"name": "localhost",
"blocked_services": yobj{
"ids": yarr{},
"schedule": yobj{
"time_zone": "Local",
},
},
}},
},
"schema_version": newSchemaVer,
},
name: "no_services",
}, {
in: yobj{
"clients": yobj{
"persistent": []any{yobj{"name": "localhost", "blocked_services": yarr{"ok"}}},
},
},
want: yobj{
"clients": yobj{
"persistent": []any{yobj{
"name": "localhost",
"blocked_services": yobj{
"ids": yarr{"ok"},
"schedule": yobj{
"time_zone": "Local",
},
},
}},
},
"schema_version": newSchemaVer,
},
name: "services",
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := upgradeSchema21to22(tc.in)
require.NoError(t, err)
assert.Equal(t, tc.want, tc.in)
})
}
}
func TestUpgradeSchema22to23(t *testing.T) {
const newSchemaVer = 23
testCases := []struct {
in yobj
want yobj
name string
}{{
name: "empty",
in: yobj{},
want: yobj{
"schema_version": newSchemaVer,
},
}, {
name: "ok",
in: yobj{
"bind_host": "1.2.3.4",
"bind_port": 8081,
"web_session_ttl": 720,
},
want: yobj{
"http": yobj{
"address": "1.2.3.4:8081",
"session_ttl": "720h",
},
"schema_version": newSchemaVer,
},
}, {
name: "v6_address",
in: yobj{
"bind_host": "2001:db8::1",
"bind_port": 8081,
"web_session_ttl": 720,
},
want: yobj{
"http": yobj{
"address": "[2001:db8::1]:8081",
"session_ttl": "720h",
},
"schema_version": newSchemaVer,
},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := upgradeSchema22to23(tc.in)
require.NoError(t, err)
assert.Equal(t, tc.want, tc.in)
})
}
}

View File

@@ -119,7 +119,9 @@ func webCheckPortAvailable(port int) (ok bool) {
return true
}
return aghnet.CheckPort("tcp", netip.AddrPortFrom(config.BindHost, uint16(port))) == nil
addrPort := netip.AddrPortFrom(config.HTTPConfig.Address.Addr(), uint16(port))
return aghnet.CheckPort("tcp", addrPort) == nil
}
// tlsConfigChanged updates the TLS configuration and restarts the HTTPS server

View File

@@ -1,259 +0,0 @@
package home
import (
"context"
"encoding/binary"
"fmt"
"io"
"net"
"net/netip"
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
"github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/stringutil"
)
const (
defaultServer = "whois.arin.net"
defaultPort = "43"
maxValueLength = 250
whoisTTL = 1 * 60 * 60 // 1 hour
)
// WHOIS - module context
type WHOIS struct {
clients *clientsContainer
ipChan chan netip.Addr
// dialContext specifies the dial function for creating unencrypted TCP
// connections.
dialContext func(ctx context.Context, network, addr string) (conn net.Conn, err error)
// Contains IP addresses of clients
// An active IP address is resolved once again after it expires.
// If IP address couldn't be resolved, it stays here for some time to prevent further attempts to resolve the same IP.
ipAddrs cache.Cache
// TODO(a.garipov): Rewrite to use time.Duration. Like, seriously, why?
timeoutMsec uint
}
// initWHOIS creates the WHOIS module context.
func initWHOIS(clients *clientsContainer) *WHOIS {
w := WHOIS{
timeoutMsec: 5000,
clients: clients,
ipAddrs: cache.New(cache.Config{
EnableLRU: true,
MaxCount: 10000,
}),
dialContext: customDialContext,
ipChan: make(chan netip.Addr, 255),
}
go w.workerLoop()
return &w
}
// If the value is too large - cut it and append "..."
func trimValue(s string) string {
if len(s) <= maxValueLength {
return s
}
return s[:maxValueLength-3] + "..."
}
// isWHOISComment returns true if the string is empty or is a WHOIS comment.
func isWHOISComment(s string) (ok bool) {
return len(s) == 0 || s[0] == '#' || s[0] == '%'
}
// strmap is an alias for convenience.
type strmap = map[string]string
// whoisParse parses a subset of plain-text data from the WHOIS response into
// a string map.
func whoisParse(data string) (m strmap) {
m = strmap{}
var orgname string
lines := strings.Split(data, "\n")
for _, l := range lines {
if isWHOISComment(l) {
continue
}
kv := strings.SplitN(l, ":", 2)
if len(kv) != 2 {
continue
}
k := strings.ToLower(strings.TrimSpace(kv[0]))
v := strings.TrimSpace(kv[1])
if v == "" {
continue
}
switch k {
case "orgname", "org-name":
k = "orgname"
v = trimValue(v)
orgname = v
case "city", "country":
v = trimValue(v)
case "descr", "netname":
k = "orgname"
v = stringutil.Coalesce(orgname, v)
orgname = v
case "whois":
k = "whois"
case "referralserver":
k = "whois"
v = strings.TrimPrefix(v, "whois://")
default:
continue
}
m[k] = v
}
return m
}
// MaxConnReadSize is an upper limit in bytes for reading from net.Conn.
const MaxConnReadSize = 64 * 1024
// Send request to a server and receive the response
func (w *WHOIS) query(ctx context.Context, target, serverAddr string) (data string, err error) {
addr, _, _ := net.SplitHostPort(serverAddr)
if addr == "whois.arin.net" {
target = "n + " + target
}
conn, err := w.dialContext(ctx, "tcp", serverAddr)
if err != nil {
return "", err
}
defer func() { err = errors.WithDeferred(err, conn.Close()) }()
r, err := aghio.LimitReader(conn, MaxConnReadSize)
if err != nil {
return "", err
}
_ = conn.SetReadDeadline(time.Now().Add(time.Duration(w.timeoutMsec) * time.Millisecond))
_, err = conn.Write([]byte(target + "\r\n"))
if err != nil {
return "", err
}
// This use of ReadAll is now safe, because we limited the conn Reader.
var whoisData []byte
whoisData, err = io.ReadAll(r)
if err != nil {
return "", err
}
return string(whoisData), nil
}
// Query WHOIS servers (handle redirects)
func (w *WHOIS) queryAll(ctx context.Context, target string) (string, error) {
server := net.JoinHostPort(defaultServer, defaultPort)
const maxRedirects = 5
for i := 0; i != maxRedirects; i++ {
resp, err := w.query(ctx, target, server)
if err != nil {
return "", err
}
log.Debug("whois: received response (%d bytes) from %s IP:%s", len(resp), server, target)
m := whoisParse(resp)
redir, ok := m["whois"]
if !ok {
return resp, nil
}
redir = strings.ToLower(redir)
_, _, err = net.SplitHostPort(redir)
if err != nil {
server = net.JoinHostPort(redir, defaultPort)
} else {
server = redir
}
log.Debug("whois: redirected to %s IP:%s", redir, target)
}
return "", fmt.Errorf("whois: redirect loop")
}
// Request WHOIS information
func (w *WHOIS) process(ctx context.Context, ip netip.Addr) (wi *RuntimeClientWHOISInfo) {
resp, err := w.queryAll(ctx, ip.String())
if err != nil {
log.Debug("whois: error: %s IP:%s", err, ip)
return nil
}
log.Debug("whois: IP:%s response: %d bytes", ip, len(resp))
m := whoisParse(resp)
wi = &RuntimeClientWHOISInfo{
City: m["city"],
Country: m["country"],
Orgname: m["orgname"],
}
// Don't return an empty struct so that the frontend doesn't get
// confused.
if *wi == (RuntimeClientWHOISInfo{}) {
return nil
}
return wi
}
// Begin - begin requesting WHOIS info
func (w *WHOIS) Begin(ip netip.Addr) {
ipBytes := ip.AsSlice()
now := uint64(time.Now().Unix())
expire := w.ipAddrs.Get(ipBytes)
if len(expire) != 0 {
exp := binary.BigEndian.Uint64(expire)
if exp > now {
return
}
}
expire = make([]byte, 8)
binary.BigEndian.PutUint64(expire, now+whoisTTL)
_ = w.ipAddrs.Set(ipBytes, expire)
log.Debug("whois: adding %s", ip)
select {
case w.ipChan <- ip:
default:
log.Debug("whois: queue is full")
}
}
// workerLoop processes the IP addresses it got from the channel and associates
// the retrieving WHOIS info with a client.
func (w *WHOIS) workerLoop() {
for ip := range w.ipChan {
info := w.process(context.Background(), ip)
if info == nil {
continue
}
w.clients.setWHOISInfo(ip, info)
}
}

View File

@@ -1,152 +0,0 @@
package home
import (
"context"
"io"
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// fakeConn is a mock implementation of net.Conn to simplify testing.
//
// TODO(e.burkov): Search for other places in code where it may be used. Move
// into aghtest then.
type fakeConn struct {
// Conn is embedded here simply to make *fakeConn a net.Conn without
// actually implementing all methods.
net.Conn
data []byte
}
// Write implements net.Conn interface for *fakeConn. It always returns 0 and a
// nil error without mutating the slice.
func (c *fakeConn) Write(_ []byte) (n int, err error) {
return 0, nil
}
// Read implements net.Conn interface for *fakeConn. It puts the content of
// c.data field into b up to the b's capacity.
func (c *fakeConn) Read(b []byte) (n int, err error) {
return copy(b, c.data), io.EOF
}
// Close implements net.Conn interface for *fakeConn. It always returns nil.
func (c *fakeConn) Close() (err error) {
return nil
}
// SetReadDeadline implements net.Conn interface for *fakeConn. It always
// returns nil.
func (c *fakeConn) SetReadDeadline(_ time.Time) (err error) {
return nil
}
// fakeDial is a mock implementation of customDialContext to simplify testing.
func (c *fakeConn) fakeDial(ctx context.Context, network, addr string) (conn net.Conn, err error) {
return c, nil
}
func TestWHOIS(t *testing.T) {
const (
nl = "\n"
data = `OrgName: FakeOrg LLC` + nl +
`City: Nonreal` + nl +
`Country: Imagiland` + nl
)
fc := &fakeConn{
data: []byte(data),
}
w := WHOIS{
timeoutMsec: 5000,
dialContext: fc.fakeDial,
}
resp, err := w.queryAll(context.Background(), "1.2.3.4")
assert.NoError(t, err)
m := whoisParse(resp)
require.NotEmpty(t, m)
assert.Equal(t, "FakeOrg LLC", m["orgname"])
assert.Equal(t, "Imagiland", m["country"])
assert.Equal(t, "Nonreal", m["city"])
}
func TestWHOISParse(t *testing.T) {
const (
city = "Nonreal"
country = "Imagiland"
orgname = "FakeOrgLLC"
whois = "whois.example.net"
)
testCases := []struct {
want strmap
name string
in string
}{{
want: strmap{},
name: "empty",
in: ``,
}, {
want: strmap{},
name: "comments",
in: "%\n#",
}, {
want: strmap{},
name: "no_colon",
in: "city",
}, {
want: strmap{},
name: "no_value",
in: "city:",
}, {
want: strmap{"city": city},
name: "city",
in: `city: ` + city,
}, {
want: strmap{"country": country},
name: "country",
in: `country: ` + country,
}, {
want: strmap{"orgname": orgname},
name: "orgname",
in: `orgname: ` + orgname,
}, {
want: strmap{"orgname": orgname},
name: "orgname_hyphen",
in: `org-name: ` + orgname,
}, {
want: strmap{"orgname": orgname},
name: "orgname_descr",
in: `descr: ` + orgname,
}, {
want: strmap{"orgname": orgname},
name: "orgname_netname",
in: `netname: ` + orgname,
}, {
want: strmap{"whois": whois},
name: "whois",
in: `whois: ` + whois,
}, {
want: strmap{"whois": whois},
name: "referralserver",
in: `referralserver: whois://` + whois,
}, {
want: strmap{},
name: "other",
in: `other: value`,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
got := whoisParse(tc.in)
assert.Equal(t, tc.want, got)
})
}
}