all: sync with master

This commit is contained in:
Ainar Garipov
2024-01-30 18:43:51 +03:00
parent f6ad64bf69
commit b01c10b73e
196 changed files with 3190 additions and 1790 deletions

View File

@@ -1,19 +1,53 @@
package home
import (
"encoding"
"fmt"
"net"
"net/netip"
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
"github.com/AdguardTeam/AdGuardHome/internal/whois"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/google/uuid"
"golang.org/x/exp/slices"
)
// Client contains information about persistent clients.
type Client struct {
// UID is the type for the unique IDs of persistent clients.
type UID uuid.UUID
// NewUID returns a new persistent client UID. Any error returned is an error
// from the cryptographic randomness reader.
func NewUID() (uid UID, err error) {
uuidv7, err := uuid.NewV7()
return UID(uuidv7), err
}
// type check
var _ encoding.TextMarshaler = UID{}
// MarshalText implements the [encoding.TextMarshaler] for UID.
func (uid UID) MarshalText() ([]byte, error) {
return uuid.UUID(uid).MarshalText()
}
// type check
var _ encoding.TextUnmarshaler = (*UID)(nil)
// UnmarshalText implements the [encoding.TextUnmarshaler] interface for UID.
func (uid *UID) UnmarshalText(data []byte) error {
return (*uuid.UUID)(uid).UnmarshalText(data)
}
// persistentClient contains information about persistent clients.
type persistentClient struct {
// upstreamConfig is the custom upstream configuration for this client. If
// it's nil, it has not been initialized yet. If it's non-nil and empty,
// there are no valid upstreams. If it's non-nil and non-empty, these
@@ -29,10 +63,18 @@ type Client struct {
Name string
IDs []string
Tags []string
Upstreams []string
IPs []netip.Addr
// TODO(s.chzhen): Use netutil.Prefix.
Subnets []netip.Prefix
MACs []net.HardwareAddr
ClientIDs []string
// UID is the unique identifier of the persistent client.
UID UID
UpstreamsCacheSize uint32
UpstreamsCacheEnabled bool
@@ -45,21 +87,153 @@ type Client struct {
IgnoreStatistics bool
}
// ShallowClone returns a deep copy of the client, except upstreamConfig,
// setTags sets the tags if they are known, otherwise logs an unknown tag.
func (c *persistentClient) setTags(tags []string, known *stringutil.Set) {
for _, t := range tags {
if !known.Has(t) {
log.Info("skipping unknown tag %q", t)
continue
}
c.Tags = append(c.Tags, t)
}
slices.Sort(c.Tags)
}
// setIDs parses a list of strings into typed fields and returns an error if
// there is one.
func (c *persistentClient) setIDs(ids []string) (err error) {
for _, id := range ids {
err = c.setID(id)
if err != nil {
return err
}
}
slices.SortFunc(c.IPs, netip.Addr.Compare)
// TODO(s.chzhen): Use netip.PrefixCompare in Go 1.23.
slices.SortFunc(c.Subnets, subnetCompare)
slices.SortFunc(c.MACs, slices.Compare[net.HardwareAddr])
slices.Sort(c.ClientIDs)
return nil
}
// subnetCompare is a comparison function for the two subnets. It returns -1 if
// x sorts before y, 1 if x sorts after y, and 0 if their relative sorting
// position is the same.
func subnetCompare(x, y netip.Prefix) (cmp int) {
if x == y {
return 0
}
xAddr, xBits := x.Addr(), x.Bits()
yAddr, yBits := y.Addr(), y.Bits()
if xBits == yBits {
return xAddr.Compare(yAddr)
}
if xBits > yBits {
return -1
} else {
return 1
}
}
// setID parses id into typed field if there is no error.
func (c *persistentClient) setID(id string) (err error) {
if id == "" {
return errors.Error("clientid is empty")
}
var ip netip.Addr
if ip, err = netip.ParseAddr(id); err == nil {
c.IPs = append(c.IPs, ip)
return nil
}
var subnet netip.Prefix
if subnet, err = netip.ParsePrefix(id); err == nil {
c.Subnets = append(c.Subnets, subnet)
return nil
}
var mac net.HardwareAddr
if mac, err = net.ParseMAC(id); err == nil {
c.MACs = append(c.MACs, mac)
return nil
}
err = dnsforward.ValidateClientID(id)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
}
c.ClientIDs = append(c.ClientIDs, strings.ToLower(id))
return nil
}
// ids returns a list of client ids containing at least one element.
func (c *persistentClient) ids() (ids []string) {
ids = make([]string, 0, c.idsLen())
for _, ip := range c.IPs {
ids = append(ids, ip.String())
}
for _, subnet := range c.Subnets {
ids = append(ids, subnet.String())
}
for _, mac := range c.MACs {
ids = append(ids, mac.String())
}
return append(ids, c.ClientIDs...)
}
// idsLen returns a length of client ids.
func (c *persistentClient) idsLen() (n int) {
return len(c.IPs) + len(c.Subnets) + len(c.MACs) + len(c.ClientIDs)
}
// equalIDs returns true if the ids of the current and previous clients are the
// same.
func (c *persistentClient) equalIDs(prev *persistentClient) (equal bool) {
return slices.Equal(c.IPs, prev.IPs) &&
slices.Equal(c.Subnets, prev.Subnets) &&
slices.EqualFunc(c.MACs, prev.MACs, slices.Equal[net.HardwareAddr]) &&
slices.Equal(c.ClientIDs, prev.ClientIDs)
}
// shallowClone returns a deep copy of the client, except upstreamConfig,
// safeSearchConf, SafeSearch fields, because it's difficult to copy them.
func (c *Client) ShallowClone() (sh *Client) {
clone := *c
func (c *persistentClient) shallowClone() (clone *persistentClient) {
clone = &persistentClient{}
*clone = *c
clone.BlockedServices = c.BlockedServices.Clone()
clone.IDs = stringutil.CloneSlice(c.IDs)
clone.Tags = stringutil.CloneSlice(c.Tags)
clone.Upstreams = stringutil.CloneSlice(c.Upstreams)
clone.Tags = slices.Clone(c.Tags)
clone.Upstreams = slices.Clone(c.Upstreams)
return &clone
clone.IPs = slices.Clone(c.IPs)
clone.Subnets = slices.Clone(c.Subnets)
clone.MACs = slices.Clone(c.MACs)
clone.ClientIDs = slices.Clone(c.ClientIDs)
return clone
}
// closeUpstreams closes the client-specific upstream config of c if any.
func (c *Client) closeUpstreams() (err error) {
func (c *persistentClient) closeUpstreams() (err error) {
if c.upstreamConfig != nil {
if err = c.upstreamConfig.Close(); err != nil {
return fmt.Errorf("closing upstreams of client %q: %w", c.Name, err)
@@ -70,7 +244,7 @@ func (c *Client) closeUpstreams() (err error) {
}
// setSafeSearch initializes and sets the safe search filter for this client.
func (c *Client) setSafeSearch(
func (c *persistentClient) setSafeSearch(
conf filtering.SafeSearchConfig,
cacheSize uint,
cacheTTL time.Duration,
@@ -85,17 +259,3 @@ func (c *Client) setSafeSearch(
return nil
}
// RuntimeClient is a client information about which has been obtained using the
// source described in the Source field.
type RuntimeClient struct {
// WHOIS is the filtered WHOIS data of a client.
WHOIS *whois.Info
// Host is the host name of a client.
Host string
// Source is the source from which the information about the client has
// been obtained.
Source client.Source
}

View File

@@ -0,0 +1,124 @@
package home
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestPersistentClient_EqualIDs(t *testing.T) {
const (
ip = "0.0.0.0"
ip1 = "1.1.1.1"
ip2 = "2.2.2.2"
cidr = "0.0.0.0/0"
cidr1 = "1.1.1.1/11"
cidr2 = "2.2.2.2/22"
mac = "00-00-00-00-00-00"
mac1 = "11-11-11-11-11-11"
mac2 = "22-22-22-22-22-22"
cli = "client0"
cli1 = "client1"
cli2 = "client2"
)
testCases := []struct {
name string
ids []string
prevIDs []string
want assert.BoolAssertionFunc
}{{
name: "single_ip",
ids: []string{ip1},
prevIDs: []string{ip1},
want: assert.True,
}, {
name: "single_ip_not_equal",
ids: []string{ip1},
prevIDs: []string{ip2},
want: assert.False,
}, {
name: "ips_not_equal",
ids: []string{ip1, ip2},
prevIDs: []string{ip1, ip},
want: assert.False,
}, {
name: "ips_mixed_equal",
ids: []string{ip1, ip2},
prevIDs: []string{ip2, ip1},
want: assert.True,
}, {
name: "single_subnet",
ids: []string{cidr1},
prevIDs: []string{cidr1},
want: assert.True,
}, {
name: "subnets_not_equal",
ids: []string{ip1, ip2, cidr1, cidr2},
prevIDs: []string{ip1, ip2, cidr1, cidr},
want: assert.False,
}, {
name: "subnets_mixed_equal",
ids: []string{ip1, ip2, cidr1, cidr2},
prevIDs: []string{cidr2, cidr1, ip2, ip1},
want: assert.True,
}, {
name: "single_mac",
ids: []string{mac1},
prevIDs: []string{mac1},
want: assert.True,
}, {
name: "single_mac_not_equal",
ids: []string{mac1},
prevIDs: []string{mac2},
want: assert.False,
}, {
name: "macs_not_equal",
ids: []string{ip1, ip2, cidr1, cidr2, mac1, mac2},
prevIDs: []string{ip1, ip2, cidr1, cidr2, mac1, mac},
want: assert.False,
}, {
name: "macs_mixed_equal",
ids: []string{ip1, ip2, cidr1, cidr2, mac1, mac2},
prevIDs: []string{mac2, mac1, cidr2, cidr1, ip2, ip1},
want: assert.True,
}, {
name: "single_client_id",
ids: []string{cli1},
prevIDs: []string{cli1},
want: assert.True,
}, {
name: "single_client_id_not_equal",
ids: []string{cli1},
prevIDs: []string{cli2},
want: assert.False,
}, {
name: "client_ids_not_equal",
ids: []string{ip1, ip2, cidr1, cidr2, mac1, mac2, cli1, cli2},
prevIDs: []string{ip1, ip2, cidr1, cidr2, mac1, mac2, cli1, cli},
want: assert.False,
}, {
name: "client_ids_mixed_equal",
ids: []string{ip1, ip2, cidr1, cidr2, mac1, mac2, cli1, cli2},
prevIDs: []string{cli2, cli1, mac2, mac1, cidr2, cidr1, ip2, ip1},
want: assert.True,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
c := &persistentClient{}
err := c.setIDs(tc.ids)
require.NoError(t, err)
prev := &persistentClient{}
err = prev.setIDs(tc.prevIDs)
require.NoError(t, err)
tc.want(t, c.equalIDs(prev))
})
}
}

View File

@@ -1,7 +1,6 @@
package home
import (
"bytes"
"fmt"
"net"
"net/netip"
@@ -20,6 +19,7 @@ import (
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/hostsfile"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/stringutil"
"golang.org/x/exp/maps"
@@ -47,11 +47,11 @@ type DHCP interface {
type clientsContainer struct {
// TODO(a.garipov): Perhaps use a number of separate indices for different
// types (string, netip.Addr, and so on).
list map[string]*Client // name -> client
idIndex map[string]*Client // ID -> client
list map[string]*persistentClient // name -> client
idIndex map[string]*persistentClient // ID -> client
// ipToRC is the IP address to *RuntimeClient map.
ipToRC map[netip.Addr]*RuntimeClient
// ipToRC maps IP addresses to runtime client information.
ipToRC map[netip.Addr]*client.Runtime
allTags *stringutil.Set
@@ -102,9 +102,9 @@ func (clients *clientsContainer) Init(
log.Fatal("clients.list != nil")
}
clients.list = make(map[string]*Client)
clients.idIndex = make(map[string]*Client)
clients.ipToRC = map[netip.Addr]*RuntimeClient{}
clients.list = map[string]*persistentClient{}
clients.idIndex = map[string]*persistentClient{}
clients.ipToRC = map[netip.Addr]*client.Runtime{}
clients.allTags = stringutil.NewSet(clientTags...)
@@ -139,6 +139,9 @@ func (clients *clientsContainer) Init(
return nil
}
// handleHostsUpdates receives the updates from the hosts container and adds
// them to the clients container. It's used to be called in a separate
// goroutine.
func (clients *clientsContainer) handleHostsUpdates() {
for upd := range clients.etcHosts.Upd() {
clients.addFromHostsFile(upd)
@@ -185,6 +188,9 @@ type clientObject struct {
Tags []string `yaml:"tags"`
Upstreams []string `yaml:"upstreams"`
// UID is the unique identifier of the persistent client.
UID UID `yaml:"uid"`
// UpstreamsCacheSize is the DNS cache size (in bytes).
//
// TODO(d.kolyshev): Use [datasize.Bytesize].
@@ -203,66 +209,83 @@ type clientObject struct {
IgnoreStatistics bool `yaml:"ignore_statistics"`
}
// toPersistent returns an initialized persistent client if there are no errors.
func (o *clientObject) toPersistent(
filteringConf *filtering.Config,
allTags *stringutil.Set,
) (cli *persistentClient, err error) {
cli = &persistentClient{
Name: o.Name,
Upstreams: o.Upstreams,
UID: o.UID,
UseOwnSettings: !o.UseGlobalSettings,
FilteringEnabled: o.FilteringEnabled,
ParentalEnabled: o.ParentalEnabled,
safeSearchConf: o.SafeSearchConf,
SafeBrowsingEnabled: o.SafeBrowsingEnabled,
UseOwnBlockedServices: !o.UseGlobalBlockedServices,
IgnoreQueryLog: o.IgnoreQueryLog,
IgnoreStatistics: o.IgnoreStatistics,
UpstreamsCacheEnabled: o.UpstreamsCacheEnabled,
UpstreamsCacheSize: o.UpstreamsCacheSize,
}
err = cli.setIDs(o.IDs)
if err != nil {
return nil, fmt.Errorf("parsing ids: %w", err)
}
if (cli.UID == UID{}) {
cli.UID, err = NewUID()
if err != nil {
return nil, fmt.Errorf("generating uid: %w", err)
}
}
if o.SafeSearchConf.Enabled {
o.SafeSearchConf.CustomResolver = safeSearchResolver{}
err = cli.setSafeSearch(
o.SafeSearchConf,
filteringConf.SafeSearchCacheSize,
time.Minute*time.Duration(filteringConf.CacheTime),
)
if err != nil {
return nil, fmt.Errorf("init safesearch %q: %w", cli.Name, err)
}
}
err = o.BlockedServices.Validate()
if err != nil {
return nil, fmt.Errorf("init blocked services %q: %w", cli.Name, err)
}
cli.BlockedServices = o.BlockedServices.Clone()
cli.setTags(o.Tags, allTags)
return cli, nil
}
// addFromConfig initializes the clients container with objects from the
// configuration file.
func (clients *clientsContainer) addFromConfig(
objects []*clientObject,
filteringConf *filtering.Config,
) (err error) {
for _, o := range objects {
cli := &Client{
Name: o.Name,
IDs: o.IDs,
Upstreams: o.Upstreams,
UseOwnSettings: !o.UseGlobalSettings,
FilteringEnabled: o.FilteringEnabled,
ParentalEnabled: o.ParentalEnabled,
safeSearchConf: o.SafeSearchConf,
SafeBrowsingEnabled: o.SafeBrowsingEnabled,
UseOwnBlockedServices: !o.UseGlobalBlockedServices,
IgnoreQueryLog: o.IgnoreQueryLog,
IgnoreStatistics: o.IgnoreStatistics,
UpstreamsCacheEnabled: o.UpstreamsCacheEnabled,
UpstreamsCacheSize: o.UpstreamsCacheSize,
}
if o.SafeSearchConf.Enabled {
o.SafeSearchConf.CustomResolver = safeSearchResolver{}
err = cli.setSafeSearch(
o.SafeSearchConf,
filteringConf.SafeSearchCacheSize,
time.Minute*time.Duration(filteringConf.CacheTime),
)
if err != nil {
log.Error("clients: init client safesearch %q: %s", cli.Name, err)
continue
}
}
err = o.BlockedServices.Validate()
for i, o := range objects {
var cli *persistentClient
cli, err = o.toPersistent(filteringConf, clients.allTags)
if err != nil {
return fmt.Errorf("clients: init client blocked services %q: %w", cli.Name, err)
return fmt.Errorf("clients: init persistent client at index %d: %w", i, err)
}
cli.BlockedServices = o.BlockedServices.Clone()
for _, t := range o.Tags {
if clients.allTags.Has(t) {
cli.Tags = append(cli.Tags, t)
} else {
log.Info("clients: skipping unknown tag %q", t)
}
}
slices.Sort(cli.Tags)
_, err = clients.Add(cli)
_, err = clients.add(cli)
if err != nil {
log.Error("clients: adding clients %s: %s", cli.Name, err)
log.Error("clients: adding client at index %d %s: %s", i, cli.Name, err)
}
}
@@ -282,10 +305,12 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) {
BlockedServices: cli.BlockedServices.Clone(),
IDs: stringutil.CloneSlice(cli.IDs),
IDs: cli.ids(),
Tags: stringutil.CloneSlice(cli.Tags),
Upstreams: stringutil.CloneSlice(cli.Upstreams),
UID: cli.UID,
UseGlobalSettings: !cli.UseOwnSettings,
FilteringEnabled: cli.FilteringEnabled,
ParentalEnabled: cli.ParentalEnabled,
@@ -338,7 +363,7 @@ func (clients *clientsContainer) clientSource(ip netip.Addr) (src client.Source)
rc, ok := clients.ipToRC[ip]
if ok {
src = rc.Source
src, _ = rc.Info()
}
if src < client.SourceDHCP && clients.dhcp.HostByIP(ip) != "" {
@@ -348,10 +373,10 @@ func (clients *clientsContainer) clientSource(ip netip.Addr) (src client.Source)
return src
}
// findMultiple is a wrapper around Find to make it a valid client finder for
// the query log. c is never nil; if no information about the client is found,
// it returns an artificial client record by only setting the blocking-related
// fields. err is always nil.
// findMultiple is a wrapper around [clientsContainer.find] to make it a valid
// client finder for the query log. c is never nil; if no information about the
// client is found, it returns an artificial client record by only setting the
// blocking-related fields. err is always nil.
func (clients *clientsContainer) findMultiple(ids []string) (c *querylog.Client, err error) {
var artClient *querylog.Client
var art bool
@@ -385,20 +410,22 @@ func (clients *clientsContainer) clientOrArtificial(
}
}()
client, ok := clients.Find(id)
cli, ok := clients.find(id)
if ok {
return &querylog.Client{
Name: client.Name,
IgnoreQueryLog: client.IgnoreQueryLog,
Name: cli.Name,
IgnoreQueryLog: cli.IgnoreQueryLog,
}, false
}
var rc *RuntimeClient
var rc *client.Runtime
rc, ok = clients.findRuntimeClient(ip)
if ok {
_, host := rc.Info()
return &querylog.Client{
Name: rc.Host,
WHOIS: rc.WHOIS,
Name: host,
WHOIS: rc.WHOIS(),
}, false
}
@@ -407,8 +434,8 @@ func (clients *clientsContainer) clientOrArtificial(
}, true
}
// Find returns a shallow copy of the client if there is one found.
func (clients *clientsContainer) Find(id string) (c *Client, ok bool) {
// find returns a shallow copy of the client if there is one found.
func (clients *clientsContainer) find(id string) (c *persistentClient, ok bool) {
clients.lock.Lock()
defer clients.lock.Unlock()
@@ -417,12 +444,12 @@ func (clients *clientsContainer) Find(id string) (c *Client, ok bool) {
return nil, false
}
return c.ShallowClone(), true
return c.shallowClone(), true
}
// shouldCountClient is a wrapper around Find to make it a valid client
// information finder for the statistics. If no information about the client
// is found, it returns true.
// shouldCountClient is a wrapper around [clientsContainer.find] to make it a
// valid client information finder for the statistics. If no information about
// the client is found, it returns true.
func (clients *clientsContainer) shouldCountClient(ids []string) (y bool) {
clients.lock.Lock()
defer clients.lock.Unlock()
@@ -490,7 +517,7 @@ func (clients *clientsContainer) UpstreamConfigByID(
// findLocked searches for a client by its ID. clients.lock is expected to be
// locked.
func (clients *clientsContainer) findLocked(id string) (c *Client, ok bool) {
func (clients *clientsContainer) findLocked(id string) (c *persistentClient, ok bool) {
c, ok = clients.idIndex[id]
if ok {
return c, true
@@ -502,13 +529,7 @@ func (clients *clientsContainer) findLocked(id string) (c *Client, ok bool) {
}
for _, c = range clients.list {
for _, id := range c.IDs {
var subnet netip.Prefix
subnet, err = netip.ParsePrefix(id)
if err != nil {
continue
}
for _, subnet := range c.Subnets {
if subnet.Contains(ip) {
return c, true
}
@@ -521,22 +542,16 @@ func (clients *clientsContainer) findLocked(id string) (c *Client, ok bool) {
// findDHCP searches for a client by its MAC, if the DHCP server is active and
// there is such client. clients.lock is expected to be locked.
func (clients *clientsContainer) findDHCP(ip netip.Addr) (c *Client, ok bool) {
func (clients *clientsContainer) findDHCP(ip netip.Addr) (c *persistentClient, ok bool) {
foundMAC := clients.dhcp.MACByIP(ip)
if foundMAC == nil {
return nil, false
}
for _, c = range clients.list {
for _, id := range c.IDs {
mac, err := net.ParseMAC(id)
if err != nil {
continue
}
if bytes.Equal(mac, foundMAC) {
return c, true
}
_, found := slices.BinarySearchFunc(c.MACs, foundMAC, slices.Compare[net.HardwareAddr])
if found {
return c, true
}
}
@@ -545,7 +560,7 @@ func (clients *clientsContainer) findDHCP(ip netip.Addr) (c *Client, ok bool) {
// runtimeClient returns a runtime client from internal index. Note that it
// doesn't include DHCP clients.
func (clients *clientsContainer) runtimeClient(ip netip.Addr) (rc *RuntimeClient, ok bool) {
func (clients *clientsContainer) runtimeClient(ip netip.Addr) (rc *client.Runtime, ok bool) {
if ip == (netip.Addr{}) {
return nil, false
}
@@ -559,52 +574,43 @@ func (clients *clientsContainer) runtimeClient(ip netip.Addr) (rc *RuntimeClient
}
// findRuntimeClient finds a runtime client by their IP.
func (clients *clientsContainer) findRuntimeClient(ip netip.Addr) (rc *RuntimeClient, ok bool) {
if rc, ok = clients.runtimeClient(ip); ok && rc.Source > client.SourceDHCP {
return rc, ok
}
func (clients *clientsContainer) findRuntimeClient(ip netip.Addr) (rc *client.Runtime, ok bool) {
rc, ok = clients.runtimeClient(ip)
host := clients.dhcp.HostByIP(ip)
if host == "" {
return rc, ok
if host != "" {
if !ok {
rc = &client.Runtime{}
}
rc.SetInfo(client.SourceDHCP, []string{host})
return rc, true
}
return &RuntimeClient{
Host: host,
Source: client.SourceDHCP,
WHOIS: &whois.Info{},
}, true
return rc, ok
}
// check validates the client.
func (clients *clientsContainer) check(c *Client) (err error) {
// check validates the client. It also sorts the client tags.
func (clients *clientsContainer) check(c *persistentClient) (err error) {
switch {
case c == nil:
return errors.Error("client is nil")
case c.Name == "":
return errors.Error("invalid name")
case len(c.IDs) == 0:
case c.idsLen() == 0:
return errors.Error("id required")
default:
// Go on.
}
for i, id := range c.IDs {
var norm string
norm, err = normalizeClientIdentifier(id)
if err != nil {
return fmt.Errorf("client at index %d: %w", i, err)
}
c.IDs[i] = norm
}
for _, t := range c.Tags {
if !clients.allTags.Has(t) {
return fmt.Errorf("invalid tag: %q", t)
}
}
// TODO(s.chzhen): Move to the constructor.
slices.Sort(c.Tags)
err = dnsforward.ValidateUpstreams(c.Upstreams)
@@ -615,38 +621,9 @@ func (clients *clientsContainer) check(c *Client) (err error) {
return nil
}
// normalizeClientIdentifier returns a normalized version of idStr. If idStr
// cannot be normalized, it returns an error.
func normalizeClientIdentifier(idStr string) (norm string, err error) {
if idStr == "" {
return "", errors.Error("clientid is empty")
}
var ip netip.Addr
if ip, err = netip.ParseAddr(idStr); err == nil {
return ip.String(), nil
}
var subnet netip.Prefix
if subnet, err = netip.ParsePrefix(idStr); err == nil {
return subnet.String(), nil
}
var mac net.HardwareAddr
if mac, err = net.ParseMAC(idStr); err == nil {
return mac.String(), nil
}
if err = dnsforward.ValidateClientID(idStr); err == nil {
return strings.ToLower(idStr), nil
}
return "", fmt.Errorf("bad client identifier %q", idStr)
}
// Add adds a new client object. ok is false if such client already exists or
// add adds a new client object. ok is false if such client already exists or
// if an error occurred.
func (clients *clientsContainer) Add(c *Client) (ok bool, err error) {
func (clients *clientsContainer) add(c *persistentClient) (ok bool, err error) {
err = clients.check(c)
if err != nil {
return false, err
@@ -662,50 +639,52 @@ func (clients *clientsContainer) Add(c *Client) (ok bool, err error) {
}
// check ID index
for _, id := range c.IDs {
var c2 *Client
ids := c.ids()
for _, id := range ids {
var c2 *persistentClient
c2, ok = clients.idIndex[id]
if ok {
return false, fmt.Errorf("another client uses the same ID (%q): %q", id, c2.Name)
}
}
clients.add(c)
clients.addLocked(c)
log.Debug("clients: added %q: ID:%q [%d]", c.Name, c.IDs, len(clients.list))
log.Debug("clients: added %q: ID:%q [%d]", c.Name, ids, len(clients.list))
return true, nil
}
// add c to the indexes. clients.lock is expected to be locked.
func (clients *clientsContainer) add(c *Client) {
// addLocked c to the indexes. clients.lock is expected to be locked.
func (clients *clientsContainer) addLocked(c *persistentClient) {
// update Name index
clients.list[c.Name] = c
// update ID index
for _, id := range c.IDs {
for _, id := range c.ids() {
clients.idIndex[id] = c
}
}
// Del removes a client. ok is false if there is no such client.
func (clients *clientsContainer) Del(name string) (ok bool) {
// remove removes a client. ok is false if there is no such client.
func (clients *clientsContainer) remove(name string) (ok bool) {
clients.lock.Lock()
defer clients.lock.Unlock()
var c *Client
var c *persistentClient
c, ok = clients.list[name]
if !ok {
return false
}
clients.del(c)
clients.removeLocked(c)
return true
}
// del removes c from the indexes. clients.lock is expected to be locked.
func (clients *clientsContainer) del(c *Client) {
// removeLocked removes c from the indexes. clients.lock is expected to be
// locked.
func (clients *clientsContainer) removeLocked(c *persistentClient) {
if err := c.closeUpstreams(); err != nil {
log.Error("client container: removing client %s: %s", c.Name, err)
}
@@ -714,13 +693,13 @@ func (clients *clientsContainer) del(c *Client) {
delete(clients.list, c.Name)
// Update the ID index.
for _, id := range c.IDs {
for _, id := range c.ids() {
delete(clients.idIndex, id)
}
}
// Update updates a client by its name.
func (clients *clientsContainer) Update(prev, c *Client) (err error) {
// update updates a client by its name.
func (clients *clientsContainer) update(prev, c *persistentClient) (err error) {
err = clients.check(c)
if err != nil {
// Don't wrap the error since it's informative enough as is.
@@ -738,18 +717,23 @@ func (clients *clientsContainer) Update(prev, c *Client) (err error) {
}
}
if c.equalIDs(prev) {
clients.removeLocked(prev)
clients.addLocked(c)
return nil
}
// Check the ID index.
if !slices.Equal(prev.IDs, c.IDs) {
for _, id := range c.IDs {
existing, ok := clients.idIndex[id]
if ok && existing != prev {
return fmt.Errorf("id %q is used by client with name %q", id, existing.Name)
}
for _, id := range c.ids() {
existing, ok := clients.idIndex[id]
if ok && existing != prev {
return fmt.Errorf("id %q is used by client with name %q", id, existing.Name)
}
}
clients.del(prev)
clients.add(c)
clients.removeLocked(prev)
clients.addLocked(c)
return nil
}
@@ -764,23 +748,20 @@ func (clients *clientsContainer) setWHOISInfo(ip netip.Addr, wi *whois.Info) {
return
}
// TODO(e.burkov): Consider storing WHOIS information separately and
// potentially get rid of [RuntimeClient].
rc, ok := clients.ipToRC[ip]
if !ok {
// Create a RuntimeClient implicitly so that we don't do this check
// again.
rc = &RuntimeClient{
Source: client.SourceWHOIS,
}
rc = &client.Runtime{}
clients.ipToRC[ip] = rc
log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi)
} else {
log.Debug("clients: set whois info for runtime client %s: %+v", rc.Host, wi)
host, _ := rc.Info()
log.Debug("clients: set whois info for runtime client %s: %+v", host, wi)
}
rc.WHOIS = wi
rc.SetWHOIS(wi)
}
// addHost adds a new IP-hostname pairing. The priorities of the sources are
@@ -839,18 +820,13 @@ func (clients *clientsContainer) addHostLocked(
}
}
rc = &RuntimeClient{
WHOIS: &whois.Info{},
}
rc = &client.Runtime{}
clients.ipToRC[ip] = rc
} else if src < rc.Source {
return false
}
rc.Host = host
rc.Source = src
rc.SetInfo(src, []string{host})
log.Debug("clients: added %s -> %q [%d]", ip, host, len(clients.ipToRC))
log.Debug("clients: adding client info %s -> %q %q [%d]", ip, src, host, len(clients.ipToRC))
return true
}
@@ -859,7 +835,8 @@ func (clients *clientsContainer) addHostLocked(
func (clients *clientsContainer) rmHostsBySrc(src client.Source) {
n := 0
for ip, rc := range clients.ipToRC {
if rc.Source == src {
rc.Unset(src)
if rc.IsEmpty() {
delete(clients.ipToRC, ip)
n++
}
@@ -870,21 +847,24 @@ func (clients *clientsContainer) rmHostsBySrc(src client.Source) {
// addFromHostsFile fills the client-hostname pairing index from the system's
// hosts files.
func (clients *clientsContainer) addFromHostsFile(hosts aghnet.Hosts) {
func (clients *clientsContainer) addFromHostsFile(hosts *hostsfile.DefaultStorage) {
clients.lock.Lock()
defer clients.lock.Unlock()
clients.rmHostsBySrc(client.SourceHostsFile)
n := 0
for addr, rec := range hosts {
hosts.RangeNames(func(addr netip.Addr, names []string) (cont bool) {
// Only the first name of the first record is considered a canonical
// hostname for the IP address.
//
// TODO(e.burkov): Consider using all the names from all the records.
clients.addHostLocked(addr, rec[0].Names[0], client.SourceHostsFile)
n++
}
if clients.addHostLocked(addr, names[0], client.SourceHostsFile) {
n++
}
return true
})
log.Debug("clients: added %d client aliases from system hosts file", n)
}
@@ -926,7 +906,7 @@ func (clients *clientsContainer) addFromSystemARP() {
// the persistent clients.
func (clients *clientsContainer) close() (err error) {
persistent := maps.Values(clients.list)
slices.SortFunc(persistent, func(a, b *Client) (res int) {
slices.SortFunc(persistent, func(a, b *persistentClient) (res int) {
return strings.Compare(a.Name, b.Name)
})

View File

@@ -60,63 +60,65 @@ func TestClients(t *testing.T) {
cli1 = "1.1.1.1"
cli2 = "2.2.2.2"
cliNoneIP = netip.MustParseAddr(cliNone)
cli1IP = netip.MustParseAddr(cli1)
cli2IP = netip.MustParseAddr(cli2)
cli1IP = netip.MustParseAddr(cli1)
cli2IP = netip.MustParseAddr(cli2)
cliIPv6 = netip.MustParseAddr("1:2:3::4")
)
c := &Client{
IDs: []string{cli1, "1:2:3::4", "aa:aa:aa:aa:aa:aa"},
c := &persistentClient{
Name: "client1",
IPs: []netip.Addr{cli1IP, cliIPv6},
}
ok, err := clients.Add(c)
ok, err := clients.add(c)
require.NoError(t, err)
assert.True(t, ok)
c = &Client{
IDs: []string{cli2},
c = &persistentClient{
Name: "client2",
IPs: []netip.Addr{cli2IP},
}
ok, err = clients.Add(c)
ok, err = clients.add(c)
require.NoError(t, err)
assert.True(t, ok)
c, ok = clients.Find(cli1)
c, ok = clients.find(cli1)
require.True(t, ok)
assert.Equal(t, "client1", c.Name)
c, ok = clients.Find("1:2:3::4")
c, ok = clients.find("1:2:3::4")
require.True(t, ok)
assert.Equal(t, "client1", c.Name)
c, ok = clients.Find(cli2)
c, ok = clients.find(cli2)
require.True(t, ok)
assert.Equal(t, "client2", c.Name)
assert.Equal(t, clients.clientSource(cliNoneIP), client.SourceNone)
_, ok = clients.find(cliNone)
assert.False(t, ok)
assert.Equal(t, clients.clientSource(cli1IP), client.SourcePersistent)
assert.Equal(t, clients.clientSource(cli2IP), client.SourcePersistent)
})
t.Run("add_fail_name", func(t *testing.T) {
ok, err := clients.Add(&Client{
IDs: []string{"1.2.3.5"},
ok, err := clients.add(&persistentClient{
Name: "client1",
IPs: []netip.Addr{netip.MustParseAddr("1.2.3.5")},
})
require.NoError(t, err)
assert.False(t, ok)
})
t.Run("add_fail_ip", func(t *testing.T) {
ok, err := clients.Add(&Client{
IDs: []string{"2.2.2.2"},
ok, err := clients.add(&persistentClient{
Name: "client3",
})
require.Error(t, err)
@@ -124,8 +126,7 @@ func TestClients(t *testing.T) {
})
t.Run("update_fail_ip", func(t *testing.T) {
err := clients.Update(&Client{Name: "client1"}, &Client{
IDs: []string{"2.2.2.2"},
err := clients.update(&persistentClient{Name: "client1"}, &persistentClient{
Name: "client1",
})
assert.Error(t, err)
@@ -136,33 +137,34 @@ func TestClients(t *testing.T) {
cliOld = "1.1.1.1"
cliNew = "1.1.1.2"
cliOldIP = netip.MustParseAddr(cliOld)
cliNewIP = netip.MustParseAddr(cliNew)
)
prev, ok := clients.list["client1"]
require.True(t, ok)
err := clients.Update(prev, &Client{
IDs: []string{cliNew},
err := clients.update(prev, &persistentClient{
Name: "client1",
IPs: []netip.Addr{cliNewIP},
})
require.NoError(t, err)
assert.Equal(t, clients.clientSource(cliOldIP), client.SourceNone)
_, ok = clients.find(cliOld)
assert.False(t, ok)
assert.Equal(t, clients.clientSource(cliNewIP), client.SourcePersistent)
prev, ok = clients.list["client1"]
require.True(t, ok)
err = clients.Update(prev, &Client{
IDs: []string{cliNew},
err = clients.update(prev, &persistentClient{
Name: "client1-renamed",
IPs: []netip.Addr{cliNewIP},
UseOwnSettings: true,
})
require.NoError(t, err)
c, ok := clients.Find(cliNew)
c, ok := clients.find(cliNew)
require.True(t, ok)
assert.Equal(t, "client1-renamed", c.Name)
@@ -173,20 +175,21 @@ func TestClients(t *testing.T) {
assert.Nil(t, nilCli)
require.Len(t, c.IDs, 1)
require.Len(t, c.ids(), 1)
assert.Equal(t, cliNew, c.IDs[0])
assert.Equal(t, cliNewIP, c.IPs[0])
})
t.Run("del_success", func(t *testing.T) {
ok := clients.Del("client1-renamed")
ok := clients.remove("client1-renamed")
require.True(t, ok)
assert.Equal(t, clients.clientSource(netip.MustParseAddr("1.1.1.2")), client.SourceNone)
_, ok = clients.find("1.1.1.2")
assert.False(t, ok)
})
t.Run("del_fail", func(t *testing.T) {
ok := clients.Del("client3")
ok := clients.remove("client3")
assert.False(t, ok)
})
@@ -215,10 +218,12 @@ func TestClients(t *testing.T) {
assert.Equal(t, clients.clientSource(ip), client.SourceDHCP)
})
t.Run("addhost_fail", func(t *testing.T) {
t.Run("addhost_priority", func(t *testing.T) {
ip := netip.MustParseAddr("1.1.1.1")
ok := clients.addHost(ip, "host1", client.SourceRDNS)
assert.False(t, ok)
assert.True(t, ok)
assert.Equal(t, client.SourceHostsFile, clients.clientSource(ip))
})
}
@@ -235,7 +240,7 @@ func TestClientsWHOIS(t *testing.T) {
rc := clients.ipToRC[ip]
require.NotNil(t, rc)
assert.Equal(t, rc.WHOIS, whois)
assert.Equal(t, whois, rc.WHOIS())
})
t.Run("existing_auto-client", func(t *testing.T) {
@@ -247,15 +252,15 @@ func TestClientsWHOIS(t *testing.T) {
rc := clients.ipToRC[ip]
require.NotNil(t, rc)
assert.Equal(t, rc.WHOIS, whois)
assert.Equal(t, whois, rc.WHOIS())
})
t.Run("can't_set_manually-added", func(t *testing.T) {
ip := netip.MustParseAddr("1.1.1.2")
ok, err := clients.Add(&Client{
IDs: []string{"1.1.1.2"},
ok, err := clients.add(&persistentClient{
Name: "client1",
IPs: []netip.Addr{netip.MustParseAddr("1.1.1.2")},
})
require.NoError(t, err)
assert.True(t, ok)
@@ -264,7 +269,7 @@ func TestClientsWHOIS(t *testing.T) {
rc := clients.ipToRC[ip]
require.Nil(t, rc)
assert.True(t, clients.Del("client1"))
assert.True(t, clients.remove("client1"))
})
}
@@ -275,9 +280,11 @@ func TestClientsAddExisting(t *testing.T) {
ip := netip.MustParseAddr("1.1.1.1")
// Add a client.
ok, err := clients.Add(&Client{
IDs: []string{ip.String(), "1:2:3::4", "aa:aa:aa:aa:aa:aa", "2.2.2.0/24"},
Name: "client1",
ok, err := clients.add(&persistentClient{
Name: "client1",
IPs: []netip.Addr{ip, netip.MustParseAddr("1:2:3::4")},
Subnets: []netip.Prefix{netip.MustParsePrefix("2.2.2.0/24")},
MACs: []net.HardwareAddr{{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}},
})
require.NoError(t, err)
assert.True(t, ok)
@@ -323,17 +330,17 @@ func TestClientsAddExisting(t *testing.T) {
require.NoError(t, err)
// Add a new client with the same IP as for a client with MAC.
ok, err := clients.Add(&Client{
IDs: []string{ip.String()},
ok, err := clients.add(&persistentClient{
Name: "client2",
IPs: []netip.Addr{ip},
})
require.NoError(t, err)
assert.True(t, ok)
// Add a new client with the IP from the first client's IP range.
ok, err = clients.Add(&Client{
IDs: []string{"2.2.2.2"},
ok, err = clients.add(&persistentClient{
Name: "client3",
IPs: []netip.Addr{netip.MustParseAddr("2.2.2.2")},
})
require.NoError(t, err)
assert.True(t, ok)
@@ -344,9 +351,9 @@ func TestClientsCustomUpstream(t *testing.T) {
clients := newClientsContainer(t)
// Add client with upstreams.
ok, err := clients.Add(&Client{
IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa"},
ok, err := clients.add(&persistentClient{
Name: "client1",
IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1"), netip.MustParseAddr("1:2:3::4")},
Upstreams: []string{
"1.1.1.1",
"[/example.org/]8.8.8.8",

View File

@@ -61,6 +61,7 @@ type clientJSON struct {
UpstreamsCacheEnabled aghalg.NullBool `json:"upstreams_cache_enabled"`
}
// runtimeClientJSON is a JSON representation of the [client.Runtime].
type runtimeClientJSON struct {
WHOIS *whois.Info `json:"whois_info"`
@@ -69,12 +70,25 @@ type runtimeClientJSON struct {
Source client.Source `json:"source"`
}
// clientListJSON contains lists of persistent clients, runtime clients and also
// supported tags.
type clientListJSON struct {
Clients []*clientJSON `json:"clients"`
RuntimeClients []runtimeClientJSON `json:"auto_clients"`
Tags []string `json:"supported_tags"`
}
// whoisOrEmpty returns a WHOIS client information or a pointer to an empty
// struct. Frontend expects a non-nil value.
func whoisOrEmpty(r *client.Runtime) (info *whois.Info) {
info = r.WHOIS()
if info != nil {
return info
}
return &whois.Info{}
}
// handleGetClients is the handler for GET /control/clients HTTP API.
func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http.Request) {
data := clientListJSON{}
@@ -88,11 +102,11 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
}
for ip, rc := range clients.ipToRC {
src, host := rc.Info()
cj := runtimeClientJSON{
WHOIS: rc.WHOIS,
Name: rc.Host,
Source: rc.Source,
WHOIS: whoisOrEmpty(rc),
Name: host,
Source: src,
IP: ip,
}
@@ -115,32 +129,36 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
aghhttp.WriteJSONResponseOK(w, r, data)
}
// jsonToClient converts JSON object to Client object.
func (clients *clientsContainer) jsonToClient(cj clientJSON, prev *Client) (c *Client, err error) {
safeSearchConf := copySafeSearch(cj.SafeSearchConf, cj.SafeSearchEnabled)
// initPrev initializes the persistent client with the default or previous
// client properties.
func initPrev(cj clientJSON, prev *persistentClient) (c *persistentClient, err error) {
var (
uid UID
ignoreQueryLog bool
ignoreStatistics bool
upsCacheEnabled bool
upsCacheSize uint32
)
if prev != nil {
uid = prev.UID
ignoreQueryLog = prev.IgnoreQueryLog
ignoreStatistics = prev.IgnoreStatistics
upsCacheEnabled = prev.UpstreamsCacheEnabled
upsCacheSize = prev.UpstreamsCacheSize
}
var ignoreQueryLog bool
if cj.IgnoreQueryLog != aghalg.NBNull {
ignoreQueryLog = cj.IgnoreQueryLog == aghalg.NBTrue
} else if prev != nil {
ignoreQueryLog = prev.IgnoreQueryLog
}
var ignoreStatistics bool
if cj.IgnoreStatistics != aghalg.NBNull {
ignoreStatistics = cj.IgnoreStatistics == aghalg.NBTrue
} else if prev != nil {
ignoreStatistics = prev.IgnoreStatistics
}
var upsCacheEnabled bool
var upsCacheSize uint32
if cj.UpstreamsCacheEnabled != aghalg.NBNull {
upsCacheEnabled = cj.UpstreamsCacheEnabled == aghalg.NBTrue
upsCacheSize = cj.UpstreamsCacheSize
} else if prev != nil {
upsCacheEnabled = prev.UpstreamsCacheEnabled
upsCacheSize = prev.UpstreamsCacheSize
}
svcs, err := copyBlockedServices(cj.Schedule, cj.BlockedServices, prev)
@@ -148,31 +166,54 @@ func (clients *clientsContainer) jsonToClient(cj clientJSON, prev *Client) (c *C
return nil, fmt.Errorf("invalid blocked services: %w", err)
}
c = &Client{
safeSearchConf: safeSearchConf,
if (uid == UID{}) {
uid, err = NewUID()
if err != nil {
return nil, fmt.Errorf("generating uid: %w", err)
}
}
Name: cj.Name,
BlockedServices: svcs,
IDs: cj.IDs,
Tags: cj.Tags,
Upstreams: cj.Upstreams,
UseOwnSettings: !cj.UseGlobalSettings,
FilteringEnabled: cj.FilteringEnabled,
ParentalEnabled: cj.ParentalEnabled,
SafeBrowsingEnabled: cj.SafeBrowsingEnabled,
UseOwnBlockedServices: !cj.UseGlobalBlockedServices,
return &persistentClient{
BlockedServices: svcs,
UID: uid,
IgnoreQueryLog: ignoreQueryLog,
IgnoreStatistics: ignoreStatistics,
UpstreamsCacheEnabled: upsCacheEnabled,
UpstreamsCacheSize: upsCacheSize,
}, nil
}
// jsonToClient converts JSON object to persistent client object if there are no
// errors.
func (clients *clientsContainer) jsonToClient(
cj clientJSON,
prev *persistentClient,
) (c *persistentClient, err error) {
c, err = initPrev(cj, prev)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return nil, err
}
if safeSearchConf.Enabled {
err = c.setIDs(cj.IDs)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return nil, err
}
c.safeSearchConf = copySafeSearch(cj.SafeSearchConf, cj.SafeSearchEnabled)
c.Name = cj.Name
c.Tags = cj.Tags
c.Upstreams = cj.Upstreams
c.UseOwnSettings = !cj.UseGlobalSettings
c.FilteringEnabled = cj.FilteringEnabled
c.ParentalEnabled = cj.ParentalEnabled
c.SafeBrowsingEnabled = cj.SafeBrowsingEnabled
c.UseOwnBlockedServices = !cj.UseGlobalBlockedServices
if c.safeSearchConf.Enabled {
err = c.setSafeSearch(
safeSearchConf,
c.safeSearchConf,
clients.safeSearchCacheSize,
clients.safeSearchCacheTTL,
)
@@ -217,7 +258,7 @@ func copySafeSearch(
func copyBlockedServices(
sch *schedule.Weekly,
svcStrs []string,
prev *Client,
prev *persistentClient,
) (svcs *filtering.BlockedServices, err error) {
var weekly *schedule.Weekly
if sch != nil {
@@ -241,8 +282,8 @@ func copyBlockedServices(
return svcs, nil
}
// clientToJSON converts Client object to JSON.
func clientToJSON(c *Client) (cj *clientJSON) {
// clientToJSON converts persistent client object to JSON object.
func clientToJSON(c *persistentClient) (cj *clientJSON) {
// TODO(d.kolyshev): Remove after cleaning the deprecated
// [clientJSON.SafeSearchEnabled] field.
cloneVal := c.safeSearchConf
@@ -250,7 +291,7 @@ func clientToJSON(c *Client) (cj *clientJSON) {
return &clientJSON{
Name: c.Name,
IDs: c.IDs,
IDs: c.ids(),
Tags: c.Tags,
UseGlobalSettings: !c.UseOwnSettings,
FilteringEnabled: c.FilteringEnabled,
@@ -291,7 +332,7 @@ func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http.
return
}
ok, err := clients.Add(c)
ok, err := clients.add(c)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
@@ -323,7 +364,7 @@ func (clients *clientsContainer) handleDelClient(w http.ResponseWriter, r *http.
return
}
if !clients.Del(cj.Name) {
if !clients.remove(cj.Name) {
aghhttp.Error(r, w, http.StatusBadRequest, "Client not found")
return
@@ -332,6 +373,7 @@ func (clients *clientsContainer) handleDelClient(w http.ResponseWriter, r *http.
onConfigModified()
}
// updateJSON contains the name and data of the updated persistent client.
type updateJSON struct {
Name string `json:"name"`
Data clientJSON `json:"data"`
@@ -355,7 +397,7 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht
return
}
var prev *Client
var prev *persistentClient
var ok bool
func() {
@@ -367,6 +409,8 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht
if !ok {
aghhttp.Error(r, w, http.StatusBadRequest, "client not found")
return
}
c, err := clients.jsonToClient(dj.Data, prev)
@@ -376,7 +420,7 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht
return
}
err = clients.Update(prev, c)
err = clients.update(prev, c)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
@@ -397,7 +441,7 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
}
ip, _ := netip.ParseAddr(idStr)
c, ok := clients.Find(idStr)
c, ok := clients.find(idStr)
var cj *clientJSON
if !ok {
cj = clients.findRuntime(ip, idStr)
@@ -437,10 +481,11 @@ func (clients *clientsContainer) findRuntime(ip netip.Addr, idStr string) (cj *c
return cj
}
_, host := rc.Info()
cj = &clientJSON{
Name: rc.Host,
Name: host,
IDs: []string{idStr},
WHOIS: rc.WHOIS,
WHOIS: whoisOrEmpty(rc),
}
disallowed, rule := clients.dnsServer.IsBlockedClient(ip, idStr)

View File

@@ -10,7 +10,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
"github.com/AdguardTeam/AdGuardHome/internal/confmigrate"
"github.com/AdguardTeam/AdGuardHome/internal/configmigrate"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
@@ -20,6 +20,7 @@ import (
"github.com/AdguardTeam/dnsproxy/fastip"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/timeutil"
"github.com/google/renameio/v2/maybe"
yaml "gopkg.in/yaml.v3"
@@ -149,7 +150,7 @@ type configuration struct {
sync.RWMutex `yaml:"-"`
// SchemaVersion is the version of the configuration schema. See
// [confmigrate.LastSchemaVersion].
// [configmigrate.LastSchemaVersion].
SchemaVersion uint `yaml:"schema_version"`
}
@@ -200,7 +201,7 @@ type dnsConfig struct {
// PrivateNets is the set of IP networks for which the private reverse DNS
// resolver should be used.
PrivateNets []string `yaml:"private_networks"`
PrivateNets []netutil.Prefix `yaml:"private_networks"`
// UsePrivateRDNS defines if the PTR requests for unknown addresses from
// locally-served networks should be resolved via private PTR resolvers.
@@ -267,7 +268,7 @@ type queryLogConfig struct {
// MemSize is the number of entries kept in memory before they are flushed
// to disk.
MemSize int `yaml:"size_memory"`
MemSize uint `yaml:"size_memory"`
// Enabled defines if the query log is enabled.
Enabled bool `yaml:"enabled"`
@@ -315,14 +316,18 @@ var config = &configuration{
RatelimitSubnetLenIPv4: 24,
RatelimitSubnetLenIPv6: 56,
RefuseAny: true,
AllServers: false,
UpstreamMode: dnsforward.UpstreamModeLoadBalance,
HandleDDR: true,
FastestTimeout: timeutil.Duration{
Duration: fastip.DefaultPingWaitTimeout,
},
TrustedProxies: []string{"127.0.0.0/8", "::1/128"},
CacheSize: 4 * 1024 * 1024,
TrustedProxies: []netutil.Prefix{{
Prefix: netip.MustParsePrefix("127.0.0.0/8"),
}, {
Prefix: netip.MustParsePrefix("::1/128"),
}},
CacheSize: 4 * 1024 * 1024,
EDNSClientSubnet: &dnsforward.EDNSClientSubnet{
CustomIP: netip.Addr{},
@@ -434,7 +439,7 @@ var config = &configuration{
MaxAge: 3,
},
OSConfig: &osConfig{},
SchemaVersion: confmigrate.LastSchemaVersion,
SchemaVersion: configmigrate.LastSchemaVersion,
Theme: ThemeAuto,
}
@@ -479,14 +484,14 @@ func parseConfig() (err error) {
return err
}
migrator := confmigrate.New(&confmigrate.Config{
migrator := configmigrate.New(&configmigrate.Config{
WorkingDir: Context.workDir,
})
var upgraded bool
config.fileData, upgraded, err = migrator.Migrate(
config.fileData,
confmigrate.LastSchemaVersion,
configmigrate.LastSchemaVersion,
)
if err != nil {
// Don't wrap the error, because it's informative enough as is.

View File

@@ -127,16 +127,11 @@ func initDNSServer(
httpReg aghhttp.RegisterFunc,
tlsConf *tlsConfigSettings,
) (err error) {
privateNets, err := parseSubnetSet(config.DNS.PrivateNets)
if err != nil {
return fmt.Errorf("preparing set of private subnets: %w", err)
}
Context.dnsServer, err = dnsforward.NewServer(dnsforward.DNSCreateParams{
DNSFilter: filters,
Stats: sts,
QueryLog: qlog,
PrivateNets: privateNets,
PrivateNets: parseSubnetSet(config.DNS.PrivateNets),
Anonymizer: anonymizer,
DHCPServer: dhcpSrv,
EtcHosts: Context.etcHosts,
@@ -169,26 +164,15 @@ func initDNSServer(
// parseSubnetSet parses a slice of subnets. If the slice is empty, it returns
// a subnet set that matches all locally served networks, see
// [netutil.IsLocallyServed].
func parseSubnetSet(nets []string) (s netutil.SubnetSet, err error) {
func parseSubnetSet(nets []netutil.Prefix) (s netutil.SubnetSet) {
switch len(nets) {
case 0:
// Use an optimized function-based matcher.
return netutil.SubnetSetFunc(netutil.IsLocallyServed), nil
return netutil.SubnetSetFunc(netutil.IsLocallyServed)
case 1:
s, err = netutil.ParseSubnet(nets[0])
if err != nil {
return nil, err
}
return s, nil
return nets[0].Prefix
default:
var nets []*net.IPNet
nets, err = netutil.ParseSubnets(config.DNS.PrivateNets...)
if err != nil {
return nil, err
}
return netutil.SliceSubnetSet(nets), nil
return netutil.SliceSubnetSet(netutil.UnembedPrefixes(nets))
}
}
@@ -411,9 +395,9 @@ func applyAdditionalFiltering(clientIP netip.Addr, clientID string, setts *filte
setts.ClientIP = clientIP
c, ok := Context.clients.Find(clientID)
c, ok := Context.clients.find(clientID)
if !ok {
c, ok = Context.clients.Find(clientIP.String())
c, ok = Context.clients.find(clientIP.String())
if !ok {
log.Debug("%s: no clients with ip %s and clientid %q", pref, clientIP, clientID)

View File

@@ -22,7 +22,7 @@ func TestApplyAdditionalFiltering(t *testing.T) {
}, nil)
require.NoError(t, err)
Context.clients.idIndex = map[string]*Client{
Context.clients.idIndex = map[string]*persistentClient{
"default": {
UseOwnSettings: false,
safeSearchConf: filtering.SafeSearchConfig{Enabled: false},
@@ -108,7 +108,7 @@ func TestApplyAdditionalFiltering_blockedServices(t *testing.T) {
}, nil)
require.NoError(t, err)
Context.clients.idIndex = map[string]*Client{
Context.clients.idIndex = map[string]*persistentClient{
"default": {
UseOwnBlockedServices: false,
},

View File

@@ -35,8 +35,10 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/version"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/hostsfile"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/osutil"
"github.com/AdguardTeam/golibs/stringutil"
"golang.org/x/exp/slices"
)
@@ -231,11 +233,12 @@ func setupHostsContainer() (err error) {
return fmt.Errorf("initing hosts watcher: %w", err)
}
Context.etcHosts, err = aghnet.NewHostsContainer(
aghos.RootDirFS(),
hostsWatcher,
aghnet.DefaultHostsPaths()...,
)
paths, err := hostsfile.DefaultHostsPaths()
if err != nil {
return fmt.Errorf("getting default system hosts paths: %w", err)
}
Context.etcHosts, err = aghnet.NewHostsContainer(osutil.RootDirFS(), hostsWatcher, paths...)
if err != nil {
closeErr := hostsWatcher.Close()
if errors.Is(err, aghnet.ErrNoHostsPaths) {
@@ -357,6 +360,11 @@ func setupDNSFilteringConf(conf *filtering.Config) (err error) {
)
conf.EtcHosts = Context.etcHosts
// TODO(s.chzhen): Use empty interface.
if Context.etcHosts == nil {
conf.EtcHosts = nil
}
conf.ConfigModified = onConfigModified
conf.HTTPRegister = httpRegister
conf.DataDir = Context.getDataDir()
@@ -605,7 +613,7 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
Context.auth, err = initUsers()
fatalOnError(err)
Context.tls, err = newTLSManager(config.TLS)
Context.tls, err = newTLSManager(config.TLS, config.DNS.ServePlainDNS)
if err != nil {
log.Error("initializing tls: %s", err)
onConfigModified()

View File

@@ -7,6 +7,7 @@ import (
"strconv"
"strings"
"github.com/AdguardTeam/AdGuardHome/internal/configmigrate"
"github.com/AdguardTeam/AdGuardHome/internal/version"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/stringutil"
@@ -308,7 +309,7 @@ var cmdLineOpts = []cmdLineOpt{{
effect: func(o options, exec string) (effect, error) {
return func() error {
if o.verbose {
fmt.Println(version.Verbose())
fmt.Print(version.Verbose(configmigrate.LastSchemaVersion))
} else {
fmt.Println(version.Full())
}

View File

@@ -38,15 +38,19 @@ type tlsManager struct {
confLock sync.Mutex
conf tlsConfigSettings
// servePlainDNS defines if plain DNS is allowed for incoming requests.
servePlainDNS bool
}
// newTLSManager initializes the manager of TLS configuration. m is always
// non-nil while any returned error indicates that the TLS configuration isn't
// valid. Thus TLS may be initialized later, e.g. via the web UI.
func newTLSManager(conf tlsConfigSettings) (m *tlsManager, err error) {
func newTLSManager(conf tlsConfigSettings, servePlainDNS bool) (m *tlsManager, err error) {
m = &tlsManager{
status: &tlsConfigStatus{},
conf: conf,
status: &tlsConfigStatus{},
conf: conf,
servePlainDNS: servePlainDNS,
}
if m.conf.Enabled {
@@ -283,21 +287,29 @@ type tlsConfig struct {
tlsConfigSettingsExt `json:",inline"`
}
// tlsConfigSettingsExt is used to (un)marshal the PrivateKeySaved field to
// ensure that clients don't send and receive previously saved private keys.
// tlsConfigSettingsExt is used to (un)marshal PrivateKeySaved field and
// ServePlainDNS field.
type tlsConfigSettingsExt struct {
tlsConfigSettings `json:",inline"`
// PrivateKeySaved is true if the private key is saved as a string and omit
// key from answer.
PrivateKeySaved bool `yaml:"-" json:"private_key_saved,inline"`
// key from answer. It is used to ensure that clients don't send and
// receive previously saved private keys.
PrivateKeySaved bool `yaml:"-" json:"private_key_saved"`
// ServePlainDNS defines if plain DNS is allowed for incoming requests. It
// is an [aghalg.NullBool] to be able to tell when it's set without using
// pointers.
ServePlainDNS aghalg.NullBool `yaml:"-" json:"serve_plain_dns"`
}
// handleTLSStatus is the handler for the GET /control/tls/status HTTP API.
func (m *tlsManager) handleTLSStatus(w http.ResponseWriter, r *http.Request) {
m.confLock.Lock()
data := tlsConfig{
tlsConfigSettingsExt: tlsConfigSettingsExt{
tlsConfigSettings: m.conf,
ServePlainDNS: aghalg.BoolToNullBool(m.servePlainDNS),
},
tlsConfigStatus: m.status,
}
@@ -306,6 +318,7 @@ func (m *tlsManager) handleTLSStatus(w http.ResponseWriter, r *http.Request) {
marshalTLS(w, r, data)
}
// handleTLSValidate is the handler for the POST /control/tls/validate HTTP API.
func (m *tlsManager) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
setts, err := unmarshalTLS(r)
if err != nil {
@@ -318,30 +331,8 @@ func (m *tlsManager) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
setts.PrivateKey = m.conf.PrivateKey
}
if setts.Enabled {
err = validatePorts(
tcpPort(config.HTTPConfig.Address.Port()),
tcpPort(setts.PortHTTPS),
tcpPort(setts.PortDNSOverTLS),
tcpPort(setts.PortDNSCrypt),
udpPort(config.DNS.Port),
udpPort(setts.PortDNSOverQUIC),
)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return
}
}
if !webCheckPortAvailable(setts.PortHTTPS) {
aghhttp.Error(
r,
w,
http.StatusBadRequest,
"port %d is not available, cannot enable HTTPS on it",
setts.PortHTTPS,
)
if err = validateTLSSettings(setts); err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return
}
@@ -358,7 +349,12 @@ func (m *tlsManager) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
marshalTLS(w, r, resp)
}
func (m *tlsManager) setConfig(newConf tlsConfigSettings, status *tlsConfigStatus) (restartHTTPS bool) {
// setConfig updates manager conf with the given one.
func (m *tlsManager) setConfig(
newConf tlsConfigSettings,
status *tlsConfigStatus,
servePlain aghalg.NullBool,
) (restartHTTPS bool) {
m.confLock.Lock()
defer m.confLock.Unlock()
@@ -390,9 +386,15 @@ func (m *tlsManager) setConfig(newConf tlsConfigSettings, status *tlsConfigStatu
m.conf.PrivateKeyData = newConf.PrivateKeyData
m.status = status
if servePlain != aghalg.NBNull {
m.servePlainDNS = servePlain == aghalg.NBTrue
}
return restartHTTPS
}
// handleTLSConfigure is the handler for the POST /control/tls/configure HTTP
// API.
func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
req, err := unmarshalTLS(r)
if err != nil {
@@ -405,31 +407,8 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request)
req.PrivateKey = m.conf.PrivateKey
}
if req.Enabled {
err = validatePorts(
tcpPort(config.HTTPConfig.Address.Port()),
tcpPort(req.PortHTTPS),
tcpPort(req.PortDNSOverTLS),
tcpPort(req.PortDNSCrypt),
udpPort(config.DNS.Port),
udpPort(req.PortDNSOverQUIC),
)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return
}
}
// TODO(e.burkov): Investigate and perhaps check other ports.
if !webCheckPortAvailable(req.PortHTTPS) {
aghhttp.Error(
r,
w,
http.StatusBadRequest,
"port %d is not available, cannot enable https on it",
req.PortHTTPS,
)
if err = validateTLSSettings(req); err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return
}
@@ -447,8 +426,18 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request)
return
}
restartHTTPS := m.setConfig(req.tlsConfigSettings, status)
restartHTTPS := m.setConfig(req.tlsConfigSettings, status, req.ServePlainDNS)
m.setCertFileTime()
if req.ServePlainDNS != aghalg.NBNull {
func() {
m.confLock.Lock()
defer m.confLock.Unlock()
config.DNS.ServePlainDNS = req.ServePlainDNS == aghalg.NBTrue
}()
}
onConfigModified()
err = reconfigureDNSServer()
@@ -479,6 +468,33 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request)
}
}
// validateTLSSettings returns error if the setts are not valid.
func validateTLSSettings(setts tlsConfigSettingsExt) (err error) {
if setts.Enabled {
err = validatePorts(
tcpPort(config.HTTPConfig.Address.Port()),
tcpPort(setts.PortHTTPS),
tcpPort(setts.PortDNSOverTLS),
tcpPort(setts.PortDNSCrypt),
udpPort(config.DNS.Port),
udpPort(setts.PortDNSOverQUIC),
)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return err
}
} else if setts.ServePlainDNS == aghalg.NBFalse {
// TODO(a.garipov): Support full disabling of all DNS.
return errors.Error("plain DNS is required in case encryption protocols are disabled")
}
if !webCheckPortAvailable(setts.PortHTTPS) {
return fmt.Errorf("port %d is not available, cannot enable HTTPS on it", setts.PortHTTPS)
}
return nil
}
// validatePorts validates the uniqueness of TCP and UDP ports for AdGuard Home
// DNS protocols.
func validatePorts(