all: sync with master
This commit is contained in:
@@ -1,19 +1,53 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"encoding"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
// Client contains information about persistent clients.
|
||||
type Client struct {
|
||||
// UID is the type for the unique IDs of persistent clients.
|
||||
type UID uuid.UUID
|
||||
|
||||
// NewUID returns a new persistent client UID. Any error returned is an error
|
||||
// from the cryptographic randomness reader.
|
||||
func NewUID() (uid UID, err error) {
|
||||
uuidv7, err := uuid.NewV7()
|
||||
|
||||
return UID(uuidv7), err
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ encoding.TextMarshaler = UID{}
|
||||
|
||||
// MarshalText implements the [encoding.TextMarshaler] for UID.
|
||||
func (uid UID) MarshalText() ([]byte, error) {
|
||||
return uuid.UUID(uid).MarshalText()
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ encoding.TextUnmarshaler = (*UID)(nil)
|
||||
|
||||
// UnmarshalText implements the [encoding.TextUnmarshaler] interface for UID.
|
||||
func (uid *UID) UnmarshalText(data []byte) error {
|
||||
return (*uuid.UUID)(uid).UnmarshalText(data)
|
||||
}
|
||||
|
||||
// persistentClient contains information about persistent clients.
|
||||
type persistentClient struct {
|
||||
// upstreamConfig is the custom upstream configuration for this client. If
|
||||
// it's nil, it has not been initialized yet. If it's non-nil and empty,
|
||||
// there are no valid upstreams. If it's non-nil and non-empty, these
|
||||
@@ -29,10 +63,18 @@ type Client struct {
|
||||
|
||||
Name string
|
||||
|
||||
IDs []string
|
||||
Tags []string
|
||||
Upstreams []string
|
||||
|
||||
IPs []netip.Addr
|
||||
// TODO(s.chzhen): Use netutil.Prefix.
|
||||
Subnets []netip.Prefix
|
||||
MACs []net.HardwareAddr
|
||||
ClientIDs []string
|
||||
|
||||
// UID is the unique identifier of the persistent client.
|
||||
UID UID
|
||||
|
||||
UpstreamsCacheSize uint32
|
||||
UpstreamsCacheEnabled bool
|
||||
|
||||
@@ -45,21 +87,153 @@ type Client struct {
|
||||
IgnoreStatistics bool
|
||||
}
|
||||
|
||||
// ShallowClone returns a deep copy of the client, except upstreamConfig,
|
||||
// setTags sets the tags if they are known, otherwise logs an unknown tag.
|
||||
func (c *persistentClient) setTags(tags []string, known *stringutil.Set) {
|
||||
for _, t := range tags {
|
||||
if !known.Has(t) {
|
||||
log.Info("skipping unknown tag %q", t)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
c.Tags = append(c.Tags, t)
|
||||
}
|
||||
|
||||
slices.Sort(c.Tags)
|
||||
}
|
||||
|
||||
// setIDs parses a list of strings into typed fields and returns an error if
|
||||
// there is one.
|
||||
func (c *persistentClient) setIDs(ids []string) (err error) {
|
||||
for _, id := range ids {
|
||||
err = c.setID(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
slices.SortFunc(c.IPs, netip.Addr.Compare)
|
||||
|
||||
// TODO(s.chzhen): Use netip.PrefixCompare in Go 1.23.
|
||||
slices.SortFunc(c.Subnets, subnetCompare)
|
||||
slices.SortFunc(c.MACs, slices.Compare[net.HardwareAddr])
|
||||
slices.Sort(c.ClientIDs)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// subnetCompare is a comparison function for the two subnets. It returns -1 if
|
||||
// x sorts before y, 1 if x sorts after y, and 0 if their relative sorting
|
||||
// position is the same.
|
||||
func subnetCompare(x, y netip.Prefix) (cmp int) {
|
||||
if x == y {
|
||||
return 0
|
||||
}
|
||||
|
||||
xAddr, xBits := x.Addr(), x.Bits()
|
||||
yAddr, yBits := y.Addr(), y.Bits()
|
||||
if xBits == yBits {
|
||||
return xAddr.Compare(yAddr)
|
||||
}
|
||||
|
||||
if xBits > yBits {
|
||||
return -1
|
||||
} else {
|
||||
return 1
|
||||
}
|
||||
}
|
||||
|
||||
// setID parses id into typed field if there is no error.
|
||||
func (c *persistentClient) setID(id string) (err error) {
|
||||
if id == "" {
|
||||
return errors.Error("clientid is empty")
|
||||
}
|
||||
|
||||
var ip netip.Addr
|
||||
if ip, err = netip.ParseAddr(id); err == nil {
|
||||
c.IPs = append(c.IPs, ip)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var subnet netip.Prefix
|
||||
if subnet, err = netip.ParsePrefix(id); err == nil {
|
||||
c.Subnets = append(c.Subnets, subnet)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var mac net.HardwareAddr
|
||||
if mac, err = net.ParseMAC(id); err == nil {
|
||||
c.MACs = append(c.MACs, mac)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
err = dnsforward.ValidateClientID(id)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
c.ClientIDs = append(c.ClientIDs, strings.ToLower(id))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ids returns a list of client ids containing at least one element.
|
||||
func (c *persistentClient) ids() (ids []string) {
|
||||
ids = make([]string, 0, c.idsLen())
|
||||
|
||||
for _, ip := range c.IPs {
|
||||
ids = append(ids, ip.String())
|
||||
}
|
||||
|
||||
for _, subnet := range c.Subnets {
|
||||
ids = append(ids, subnet.String())
|
||||
}
|
||||
|
||||
for _, mac := range c.MACs {
|
||||
ids = append(ids, mac.String())
|
||||
}
|
||||
|
||||
return append(ids, c.ClientIDs...)
|
||||
}
|
||||
|
||||
// idsLen returns a length of client ids.
|
||||
func (c *persistentClient) idsLen() (n int) {
|
||||
return len(c.IPs) + len(c.Subnets) + len(c.MACs) + len(c.ClientIDs)
|
||||
}
|
||||
|
||||
// equalIDs returns true if the ids of the current and previous clients are the
|
||||
// same.
|
||||
func (c *persistentClient) equalIDs(prev *persistentClient) (equal bool) {
|
||||
return slices.Equal(c.IPs, prev.IPs) &&
|
||||
slices.Equal(c.Subnets, prev.Subnets) &&
|
||||
slices.EqualFunc(c.MACs, prev.MACs, slices.Equal[net.HardwareAddr]) &&
|
||||
slices.Equal(c.ClientIDs, prev.ClientIDs)
|
||||
}
|
||||
|
||||
// shallowClone returns a deep copy of the client, except upstreamConfig,
|
||||
// safeSearchConf, SafeSearch fields, because it's difficult to copy them.
|
||||
func (c *Client) ShallowClone() (sh *Client) {
|
||||
clone := *c
|
||||
func (c *persistentClient) shallowClone() (clone *persistentClient) {
|
||||
clone = &persistentClient{}
|
||||
*clone = *c
|
||||
|
||||
clone.BlockedServices = c.BlockedServices.Clone()
|
||||
clone.IDs = stringutil.CloneSlice(c.IDs)
|
||||
clone.Tags = stringutil.CloneSlice(c.Tags)
|
||||
clone.Upstreams = stringutil.CloneSlice(c.Upstreams)
|
||||
clone.Tags = slices.Clone(c.Tags)
|
||||
clone.Upstreams = slices.Clone(c.Upstreams)
|
||||
|
||||
return &clone
|
||||
clone.IPs = slices.Clone(c.IPs)
|
||||
clone.Subnets = slices.Clone(c.Subnets)
|
||||
clone.MACs = slices.Clone(c.MACs)
|
||||
clone.ClientIDs = slices.Clone(c.ClientIDs)
|
||||
|
||||
return clone
|
||||
}
|
||||
|
||||
// closeUpstreams closes the client-specific upstream config of c if any.
|
||||
func (c *Client) closeUpstreams() (err error) {
|
||||
func (c *persistentClient) closeUpstreams() (err error) {
|
||||
if c.upstreamConfig != nil {
|
||||
if err = c.upstreamConfig.Close(); err != nil {
|
||||
return fmt.Errorf("closing upstreams of client %q: %w", c.Name, err)
|
||||
@@ -70,7 +244,7 @@ func (c *Client) closeUpstreams() (err error) {
|
||||
}
|
||||
|
||||
// setSafeSearch initializes and sets the safe search filter for this client.
|
||||
func (c *Client) setSafeSearch(
|
||||
func (c *persistentClient) setSafeSearch(
|
||||
conf filtering.SafeSearchConfig,
|
||||
cacheSize uint,
|
||||
cacheTTL time.Duration,
|
||||
@@ -85,17 +259,3 @@ func (c *Client) setSafeSearch(
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RuntimeClient is a client information about which has been obtained using the
|
||||
// source described in the Source field.
|
||||
type RuntimeClient struct {
|
||||
// WHOIS is the filtered WHOIS data of a client.
|
||||
WHOIS *whois.Info
|
||||
|
||||
// Host is the host name of a client.
|
||||
Host string
|
||||
|
||||
// Source is the source from which the information about the client has
|
||||
// been obtained.
|
||||
Source client.Source
|
||||
}
|
||||
|
||||
124
internal/home/client_internal_test.go
Normal file
124
internal/home/client_internal_test.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPersistentClient_EqualIDs(t *testing.T) {
|
||||
const (
|
||||
ip = "0.0.0.0"
|
||||
ip1 = "1.1.1.1"
|
||||
ip2 = "2.2.2.2"
|
||||
|
||||
cidr = "0.0.0.0/0"
|
||||
cidr1 = "1.1.1.1/11"
|
||||
cidr2 = "2.2.2.2/22"
|
||||
|
||||
mac = "00-00-00-00-00-00"
|
||||
mac1 = "11-11-11-11-11-11"
|
||||
mac2 = "22-22-22-22-22-22"
|
||||
|
||||
cli = "client0"
|
||||
cli1 = "client1"
|
||||
cli2 = "client2"
|
||||
)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
ids []string
|
||||
prevIDs []string
|
||||
want assert.BoolAssertionFunc
|
||||
}{{
|
||||
name: "single_ip",
|
||||
ids: []string{ip1},
|
||||
prevIDs: []string{ip1},
|
||||
want: assert.True,
|
||||
}, {
|
||||
name: "single_ip_not_equal",
|
||||
ids: []string{ip1},
|
||||
prevIDs: []string{ip2},
|
||||
want: assert.False,
|
||||
}, {
|
||||
name: "ips_not_equal",
|
||||
ids: []string{ip1, ip2},
|
||||
prevIDs: []string{ip1, ip},
|
||||
want: assert.False,
|
||||
}, {
|
||||
name: "ips_mixed_equal",
|
||||
ids: []string{ip1, ip2},
|
||||
prevIDs: []string{ip2, ip1},
|
||||
want: assert.True,
|
||||
}, {
|
||||
name: "single_subnet",
|
||||
ids: []string{cidr1},
|
||||
prevIDs: []string{cidr1},
|
||||
want: assert.True,
|
||||
}, {
|
||||
name: "subnets_not_equal",
|
||||
ids: []string{ip1, ip2, cidr1, cidr2},
|
||||
prevIDs: []string{ip1, ip2, cidr1, cidr},
|
||||
want: assert.False,
|
||||
}, {
|
||||
name: "subnets_mixed_equal",
|
||||
ids: []string{ip1, ip2, cidr1, cidr2},
|
||||
prevIDs: []string{cidr2, cidr1, ip2, ip1},
|
||||
want: assert.True,
|
||||
}, {
|
||||
name: "single_mac",
|
||||
ids: []string{mac1},
|
||||
prevIDs: []string{mac1},
|
||||
want: assert.True,
|
||||
}, {
|
||||
name: "single_mac_not_equal",
|
||||
ids: []string{mac1},
|
||||
prevIDs: []string{mac2},
|
||||
want: assert.False,
|
||||
}, {
|
||||
name: "macs_not_equal",
|
||||
ids: []string{ip1, ip2, cidr1, cidr2, mac1, mac2},
|
||||
prevIDs: []string{ip1, ip2, cidr1, cidr2, mac1, mac},
|
||||
want: assert.False,
|
||||
}, {
|
||||
name: "macs_mixed_equal",
|
||||
ids: []string{ip1, ip2, cidr1, cidr2, mac1, mac2},
|
||||
prevIDs: []string{mac2, mac1, cidr2, cidr1, ip2, ip1},
|
||||
want: assert.True,
|
||||
}, {
|
||||
name: "single_client_id",
|
||||
ids: []string{cli1},
|
||||
prevIDs: []string{cli1},
|
||||
want: assert.True,
|
||||
}, {
|
||||
name: "single_client_id_not_equal",
|
||||
ids: []string{cli1},
|
||||
prevIDs: []string{cli2},
|
||||
want: assert.False,
|
||||
}, {
|
||||
name: "client_ids_not_equal",
|
||||
ids: []string{ip1, ip2, cidr1, cidr2, mac1, mac2, cli1, cli2},
|
||||
prevIDs: []string{ip1, ip2, cidr1, cidr2, mac1, mac2, cli1, cli},
|
||||
want: assert.False,
|
||||
}, {
|
||||
name: "client_ids_mixed_equal",
|
||||
ids: []string{ip1, ip2, cidr1, cidr2, mac1, mac2, cli1, cli2},
|
||||
prevIDs: []string{cli2, cli1, mac2, mac1, cidr2, cidr1, ip2, ip1},
|
||||
want: assert.True,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
c := &persistentClient{}
|
||||
err := c.setIDs(tc.ids)
|
||||
require.NoError(t, err)
|
||||
|
||||
prev := &persistentClient{}
|
||||
err = prev.setIDs(tc.prevIDs)
|
||||
require.NoError(t, err)
|
||||
|
||||
tc.want(t, c.equalIDs(prev))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,6 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
@@ -20,6 +19,7 @@ import (
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/hostsfile"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"golang.org/x/exp/maps"
|
||||
@@ -47,11 +47,11 @@ type DHCP interface {
|
||||
type clientsContainer struct {
|
||||
// TODO(a.garipov): Perhaps use a number of separate indices for different
|
||||
// types (string, netip.Addr, and so on).
|
||||
list map[string]*Client // name -> client
|
||||
idIndex map[string]*Client // ID -> client
|
||||
list map[string]*persistentClient // name -> client
|
||||
idIndex map[string]*persistentClient // ID -> client
|
||||
|
||||
// ipToRC is the IP address to *RuntimeClient map.
|
||||
ipToRC map[netip.Addr]*RuntimeClient
|
||||
// ipToRC maps IP addresses to runtime client information.
|
||||
ipToRC map[netip.Addr]*client.Runtime
|
||||
|
||||
allTags *stringutil.Set
|
||||
|
||||
@@ -102,9 +102,9 @@ func (clients *clientsContainer) Init(
|
||||
log.Fatal("clients.list != nil")
|
||||
}
|
||||
|
||||
clients.list = make(map[string]*Client)
|
||||
clients.idIndex = make(map[string]*Client)
|
||||
clients.ipToRC = map[netip.Addr]*RuntimeClient{}
|
||||
clients.list = map[string]*persistentClient{}
|
||||
clients.idIndex = map[string]*persistentClient{}
|
||||
clients.ipToRC = map[netip.Addr]*client.Runtime{}
|
||||
|
||||
clients.allTags = stringutil.NewSet(clientTags...)
|
||||
|
||||
@@ -139,6 +139,9 @@ func (clients *clientsContainer) Init(
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleHostsUpdates receives the updates from the hosts container and adds
|
||||
// them to the clients container. It's used to be called in a separate
|
||||
// goroutine.
|
||||
func (clients *clientsContainer) handleHostsUpdates() {
|
||||
for upd := range clients.etcHosts.Upd() {
|
||||
clients.addFromHostsFile(upd)
|
||||
@@ -185,6 +188,9 @@ type clientObject struct {
|
||||
Tags []string `yaml:"tags"`
|
||||
Upstreams []string `yaml:"upstreams"`
|
||||
|
||||
// UID is the unique identifier of the persistent client.
|
||||
UID UID `yaml:"uid"`
|
||||
|
||||
// UpstreamsCacheSize is the DNS cache size (in bytes).
|
||||
//
|
||||
// TODO(d.kolyshev): Use [datasize.Bytesize].
|
||||
@@ -203,66 +209,83 @@ type clientObject struct {
|
||||
IgnoreStatistics bool `yaml:"ignore_statistics"`
|
||||
}
|
||||
|
||||
// toPersistent returns an initialized persistent client if there are no errors.
|
||||
func (o *clientObject) toPersistent(
|
||||
filteringConf *filtering.Config,
|
||||
allTags *stringutil.Set,
|
||||
) (cli *persistentClient, err error) {
|
||||
cli = &persistentClient{
|
||||
Name: o.Name,
|
||||
|
||||
Upstreams: o.Upstreams,
|
||||
|
||||
UID: o.UID,
|
||||
|
||||
UseOwnSettings: !o.UseGlobalSettings,
|
||||
FilteringEnabled: o.FilteringEnabled,
|
||||
ParentalEnabled: o.ParentalEnabled,
|
||||
safeSearchConf: o.SafeSearchConf,
|
||||
SafeBrowsingEnabled: o.SafeBrowsingEnabled,
|
||||
UseOwnBlockedServices: !o.UseGlobalBlockedServices,
|
||||
IgnoreQueryLog: o.IgnoreQueryLog,
|
||||
IgnoreStatistics: o.IgnoreStatistics,
|
||||
UpstreamsCacheEnabled: o.UpstreamsCacheEnabled,
|
||||
UpstreamsCacheSize: o.UpstreamsCacheSize,
|
||||
}
|
||||
|
||||
err = cli.setIDs(o.IDs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing ids: %w", err)
|
||||
}
|
||||
|
||||
if (cli.UID == UID{}) {
|
||||
cli.UID, err = NewUID()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generating uid: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if o.SafeSearchConf.Enabled {
|
||||
o.SafeSearchConf.CustomResolver = safeSearchResolver{}
|
||||
|
||||
err = cli.setSafeSearch(
|
||||
o.SafeSearchConf,
|
||||
filteringConf.SafeSearchCacheSize,
|
||||
time.Minute*time.Duration(filteringConf.CacheTime),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init safesearch %q: %w", cli.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
err = o.BlockedServices.Validate()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init blocked services %q: %w", cli.Name, err)
|
||||
}
|
||||
|
||||
cli.BlockedServices = o.BlockedServices.Clone()
|
||||
|
||||
cli.setTags(o.Tags, allTags)
|
||||
|
||||
return cli, nil
|
||||
}
|
||||
|
||||
// addFromConfig initializes the clients container with objects from the
|
||||
// configuration file.
|
||||
func (clients *clientsContainer) addFromConfig(
|
||||
objects []*clientObject,
|
||||
filteringConf *filtering.Config,
|
||||
) (err error) {
|
||||
for _, o := range objects {
|
||||
cli := &Client{
|
||||
Name: o.Name,
|
||||
|
||||
IDs: o.IDs,
|
||||
Upstreams: o.Upstreams,
|
||||
|
||||
UseOwnSettings: !o.UseGlobalSettings,
|
||||
FilteringEnabled: o.FilteringEnabled,
|
||||
ParentalEnabled: o.ParentalEnabled,
|
||||
safeSearchConf: o.SafeSearchConf,
|
||||
SafeBrowsingEnabled: o.SafeBrowsingEnabled,
|
||||
UseOwnBlockedServices: !o.UseGlobalBlockedServices,
|
||||
IgnoreQueryLog: o.IgnoreQueryLog,
|
||||
IgnoreStatistics: o.IgnoreStatistics,
|
||||
UpstreamsCacheEnabled: o.UpstreamsCacheEnabled,
|
||||
UpstreamsCacheSize: o.UpstreamsCacheSize,
|
||||
}
|
||||
|
||||
if o.SafeSearchConf.Enabled {
|
||||
o.SafeSearchConf.CustomResolver = safeSearchResolver{}
|
||||
|
||||
err = cli.setSafeSearch(
|
||||
o.SafeSearchConf,
|
||||
filteringConf.SafeSearchCacheSize,
|
||||
time.Minute*time.Duration(filteringConf.CacheTime),
|
||||
)
|
||||
if err != nil {
|
||||
log.Error("clients: init client safesearch %q: %s", cli.Name, err)
|
||||
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
err = o.BlockedServices.Validate()
|
||||
for i, o := range objects {
|
||||
var cli *persistentClient
|
||||
cli, err = o.toPersistent(filteringConf, clients.allTags)
|
||||
if err != nil {
|
||||
return fmt.Errorf("clients: init client blocked services %q: %w", cli.Name, err)
|
||||
return fmt.Errorf("clients: init persistent client at index %d: %w", i, err)
|
||||
}
|
||||
|
||||
cli.BlockedServices = o.BlockedServices.Clone()
|
||||
|
||||
for _, t := range o.Tags {
|
||||
if clients.allTags.Has(t) {
|
||||
cli.Tags = append(cli.Tags, t)
|
||||
} else {
|
||||
log.Info("clients: skipping unknown tag %q", t)
|
||||
}
|
||||
}
|
||||
|
||||
slices.Sort(cli.Tags)
|
||||
|
||||
_, err = clients.Add(cli)
|
||||
_, err = clients.add(cli)
|
||||
if err != nil {
|
||||
log.Error("clients: adding clients %s: %s", cli.Name, err)
|
||||
log.Error("clients: adding client at index %d %s: %s", i, cli.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -282,10 +305,12 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) {
|
||||
|
||||
BlockedServices: cli.BlockedServices.Clone(),
|
||||
|
||||
IDs: stringutil.CloneSlice(cli.IDs),
|
||||
IDs: cli.ids(),
|
||||
Tags: stringutil.CloneSlice(cli.Tags),
|
||||
Upstreams: stringutil.CloneSlice(cli.Upstreams),
|
||||
|
||||
UID: cli.UID,
|
||||
|
||||
UseGlobalSettings: !cli.UseOwnSettings,
|
||||
FilteringEnabled: cli.FilteringEnabled,
|
||||
ParentalEnabled: cli.ParentalEnabled,
|
||||
@@ -338,7 +363,7 @@ func (clients *clientsContainer) clientSource(ip netip.Addr) (src client.Source)
|
||||
|
||||
rc, ok := clients.ipToRC[ip]
|
||||
if ok {
|
||||
src = rc.Source
|
||||
src, _ = rc.Info()
|
||||
}
|
||||
|
||||
if src < client.SourceDHCP && clients.dhcp.HostByIP(ip) != "" {
|
||||
@@ -348,10 +373,10 @@ func (clients *clientsContainer) clientSource(ip netip.Addr) (src client.Source)
|
||||
return src
|
||||
}
|
||||
|
||||
// findMultiple is a wrapper around Find to make it a valid client finder for
|
||||
// the query log. c is never nil; if no information about the client is found,
|
||||
// it returns an artificial client record by only setting the blocking-related
|
||||
// fields. err is always nil.
|
||||
// findMultiple is a wrapper around [clientsContainer.find] to make it a valid
|
||||
// client finder for the query log. c is never nil; if no information about the
|
||||
// client is found, it returns an artificial client record by only setting the
|
||||
// blocking-related fields. err is always nil.
|
||||
func (clients *clientsContainer) findMultiple(ids []string) (c *querylog.Client, err error) {
|
||||
var artClient *querylog.Client
|
||||
var art bool
|
||||
@@ -385,20 +410,22 @@ func (clients *clientsContainer) clientOrArtificial(
|
||||
}
|
||||
}()
|
||||
|
||||
client, ok := clients.Find(id)
|
||||
cli, ok := clients.find(id)
|
||||
if ok {
|
||||
return &querylog.Client{
|
||||
Name: client.Name,
|
||||
IgnoreQueryLog: client.IgnoreQueryLog,
|
||||
Name: cli.Name,
|
||||
IgnoreQueryLog: cli.IgnoreQueryLog,
|
||||
}, false
|
||||
}
|
||||
|
||||
var rc *RuntimeClient
|
||||
var rc *client.Runtime
|
||||
rc, ok = clients.findRuntimeClient(ip)
|
||||
if ok {
|
||||
_, host := rc.Info()
|
||||
|
||||
return &querylog.Client{
|
||||
Name: rc.Host,
|
||||
WHOIS: rc.WHOIS,
|
||||
Name: host,
|
||||
WHOIS: rc.WHOIS(),
|
||||
}, false
|
||||
}
|
||||
|
||||
@@ -407,8 +434,8 @@ func (clients *clientsContainer) clientOrArtificial(
|
||||
}, true
|
||||
}
|
||||
|
||||
// Find returns a shallow copy of the client if there is one found.
|
||||
func (clients *clientsContainer) Find(id string) (c *Client, ok bool) {
|
||||
// find returns a shallow copy of the client if there is one found.
|
||||
func (clients *clientsContainer) find(id string) (c *persistentClient, ok bool) {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
@@ -417,12 +444,12 @@ func (clients *clientsContainer) Find(id string) (c *Client, ok bool) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return c.ShallowClone(), true
|
||||
return c.shallowClone(), true
|
||||
}
|
||||
|
||||
// shouldCountClient is a wrapper around Find to make it a valid client
|
||||
// information finder for the statistics. If no information about the client
|
||||
// is found, it returns true.
|
||||
// shouldCountClient is a wrapper around [clientsContainer.find] to make it a
|
||||
// valid client information finder for the statistics. If no information about
|
||||
// the client is found, it returns true.
|
||||
func (clients *clientsContainer) shouldCountClient(ids []string) (y bool) {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
@@ -490,7 +517,7 @@ func (clients *clientsContainer) UpstreamConfigByID(
|
||||
|
||||
// findLocked searches for a client by its ID. clients.lock is expected to be
|
||||
// locked.
|
||||
func (clients *clientsContainer) findLocked(id string) (c *Client, ok bool) {
|
||||
func (clients *clientsContainer) findLocked(id string) (c *persistentClient, ok bool) {
|
||||
c, ok = clients.idIndex[id]
|
||||
if ok {
|
||||
return c, true
|
||||
@@ -502,13 +529,7 @@ func (clients *clientsContainer) findLocked(id string) (c *Client, ok bool) {
|
||||
}
|
||||
|
||||
for _, c = range clients.list {
|
||||
for _, id := range c.IDs {
|
||||
var subnet netip.Prefix
|
||||
subnet, err = netip.ParsePrefix(id)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, subnet := range c.Subnets {
|
||||
if subnet.Contains(ip) {
|
||||
return c, true
|
||||
}
|
||||
@@ -521,22 +542,16 @@ func (clients *clientsContainer) findLocked(id string) (c *Client, ok bool) {
|
||||
|
||||
// findDHCP searches for a client by its MAC, if the DHCP server is active and
|
||||
// there is such client. clients.lock is expected to be locked.
|
||||
func (clients *clientsContainer) findDHCP(ip netip.Addr) (c *Client, ok bool) {
|
||||
func (clients *clientsContainer) findDHCP(ip netip.Addr) (c *persistentClient, ok bool) {
|
||||
foundMAC := clients.dhcp.MACByIP(ip)
|
||||
if foundMAC == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
for _, c = range clients.list {
|
||||
for _, id := range c.IDs {
|
||||
mac, err := net.ParseMAC(id)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if bytes.Equal(mac, foundMAC) {
|
||||
return c, true
|
||||
}
|
||||
_, found := slices.BinarySearchFunc(c.MACs, foundMAC, slices.Compare[net.HardwareAddr])
|
||||
if found {
|
||||
return c, true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -545,7 +560,7 @@ func (clients *clientsContainer) findDHCP(ip netip.Addr) (c *Client, ok bool) {
|
||||
|
||||
// runtimeClient returns a runtime client from internal index. Note that it
|
||||
// doesn't include DHCP clients.
|
||||
func (clients *clientsContainer) runtimeClient(ip netip.Addr) (rc *RuntimeClient, ok bool) {
|
||||
func (clients *clientsContainer) runtimeClient(ip netip.Addr) (rc *client.Runtime, ok bool) {
|
||||
if ip == (netip.Addr{}) {
|
||||
return nil, false
|
||||
}
|
||||
@@ -559,52 +574,43 @@ func (clients *clientsContainer) runtimeClient(ip netip.Addr) (rc *RuntimeClient
|
||||
}
|
||||
|
||||
// findRuntimeClient finds a runtime client by their IP.
|
||||
func (clients *clientsContainer) findRuntimeClient(ip netip.Addr) (rc *RuntimeClient, ok bool) {
|
||||
if rc, ok = clients.runtimeClient(ip); ok && rc.Source > client.SourceDHCP {
|
||||
return rc, ok
|
||||
}
|
||||
|
||||
func (clients *clientsContainer) findRuntimeClient(ip netip.Addr) (rc *client.Runtime, ok bool) {
|
||||
rc, ok = clients.runtimeClient(ip)
|
||||
host := clients.dhcp.HostByIP(ip)
|
||||
if host == "" {
|
||||
return rc, ok
|
||||
|
||||
if host != "" {
|
||||
if !ok {
|
||||
rc = &client.Runtime{}
|
||||
}
|
||||
|
||||
rc.SetInfo(client.SourceDHCP, []string{host})
|
||||
|
||||
return rc, true
|
||||
}
|
||||
|
||||
return &RuntimeClient{
|
||||
Host: host,
|
||||
Source: client.SourceDHCP,
|
||||
WHOIS: &whois.Info{},
|
||||
}, true
|
||||
return rc, ok
|
||||
}
|
||||
|
||||
// check validates the client.
|
||||
func (clients *clientsContainer) check(c *Client) (err error) {
|
||||
// check validates the client. It also sorts the client tags.
|
||||
func (clients *clientsContainer) check(c *persistentClient) (err error) {
|
||||
switch {
|
||||
case c == nil:
|
||||
return errors.Error("client is nil")
|
||||
case c.Name == "":
|
||||
return errors.Error("invalid name")
|
||||
case len(c.IDs) == 0:
|
||||
case c.idsLen() == 0:
|
||||
return errors.Error("id required")
|
||||
default:
|
||||
// Go on.
|
||||
}
|
||||
|
||||
for i, id := range c.IDs {
|
||||
var norm string
|
||||
norm, err = normalizeClientIdentifier(id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("client at index %d: %w", i, err)
|
||||
}
|
||||
|
||||
c.IDs[i] = norm
|
||||
}
|
||||
|
||||
for _, t := range c.Tags {
|
||||
if !clients.allTags.Has(t) {
|
||||
return fmt.Errorf("invalid tag: %q", t)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(s.chzhen): Move to the constructor.
|
||||
slices.Sort(c.Tags)
|
||||
|
||||
err = dnsforward.ValidateUpstreams(c.Upstreams)
|
||||
@@ -615,38 +621,9 @@ func (clients *clientsContainer) check(c *Client) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// normalizeClientIdentifier returns a normalized version of idStr. If idStr
|
||||
// cannot be normalized, it returns an error.
|
||||
func normalizeClientIdentifier(idStr string) (norm string, err error) {
|
||||
if idStr == "" {
|
||||
return "", errors.Error("clientid is empty")
|
||||
}
|
||||
|
||||
var ip netip.Addr
|
||||
if ip, err = netip.ParseAddr(idStr); err == nil {
|
||||
return ip.String(), nil
|
||||
}
|
||||
|
||||
var subnet netip.Prefix
|
||||
if subnet, err = netip.ParsePrefix(idStr); err == nil {
|
||||
return subnet.String(), nil
|
||||
}
|
||||
|
||||
var mac net.HardwareAddr
|
||||
if mac, err = net.ParseMAC(idStr); err == nil {
|
||||
return mac.String(), nil
|
||||
}
|
||||
|
||||
if err = dnsforward.ValidateClientID(idStr); err == nil {
|
||||
return strings.ToLower(idStr), nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("bad client identifier %q", idStr)
|
||||
}
|
||||
|
||||
// Add adds a new client object. ok is false if such client already exists or
|
||||
// add adds a new client object. ok is false if such client already exists or
|
||||
// if an error occurred.
|
||||
func (clients *clientsContainer) Add(c *Client) (ok bool, err error) {
|
||||
func (clients *clientsContainer) add(c *persistentClient) (ok bool, err error) {
|
||||
err = clients.check(c)
|
||||
if err != nil {
|
||||
return false, err
|
||||
@@ -662,50 +639,52 @@ func (clients *clientsContainer) Add(c *Client) (ok bool, err error) {
|
||||
}
|
||||
|
||||
// check ID index
|
||||
for _, id := range c.IDs {
|
||||
var c2 *Client
|
||||
ids := c.ids()
|
||||
for _, id := range ids {
|
||||
var c2 *persistentClient
|
||||
c2, ok = clients.idIndex[id]
|
||||
if ok {
|
||||
return false, fmt.Errorf("another client uses the same ID (%q): %q", id, c2.Name)
|
||||
}
|
||||
}
|
||||
|
||||
clients.add(c)
|
||||
clients.addLocked(c)
|
||||
|
||||
log.Debug("clients: added %q: ID:%q [%d]", c.Name, c.IDs, len(clients.list))
|
||||
log.Debug("clients: added %q: ID:%q [%d]", c.Name, ids, len(clients.list))
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// add c to the indexes. clients.lock is expected to be locked.
|
||||
func (clients *clientsContainer) add(c *Client) {
|
||||
// addLocked c to the indexes. clients.lock is expected to be locked.
|
||||
func (clients *clientsContainer) addLocked(c *persistentClient) {
|
||||
// update Name index
|
||||
clients.list[c.Name] = c
|
||||
|
||||
// update ID index
|
||||
for _, id := range c.IDs {
|
||||
for _, id := range c.ids() {
|
||||
clients.idIndex[id] = c
|
||||
}
|
||||
}
|
||||
|
||||
// Del removes a client. ok is false if there is no such client.
|
||||
func (clients *clientsContainer) Del(name string) (ok bool) {
|
||||
// remove removes a client. ok is false if there is no such client.
|
||||
func (clients *clientsContainer) remove(name string) (ok bool) {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
var c *Client
|
||||
var c *persistentClient
|
||||
c, ok = clients.list[name]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
clients.del(c)
|
||||
clients.removeLocked(c)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// del removes c from the indexes. clients.lock is expected to be locked.
|
||||
func (clients *clientsContainer) del(c *Client) {
|
||||
// removeLocked removes c from the indexes. clients.lock is expected to be
|
||||
// locked.
|
||||
func (clients *clientsContainer) removeLocked(c *persistentClient) {
|
||||
if err := c.closeUpstreams(); err != nil {
|
||||
log.Error("client container: removing client %s: %s", c.Name, err)
|
||||
}
|
||||
@@ -714,13 +693,13 @@ func (clients *clientsContainer) del(c *Client) {
|
||||
delete(clients.list, c.Name)
|
||||
|
||||
// Update the ID index.
|
||||
for _, id := range c.IDs {
|
||||
for _, id := range c.ids() {
|
||||
delete(clients.idIndex, id)
|
||||
}
|
||||
}
|
||||
|
||||
// Update updates a client by its name.
|
||||
func (clients *clientsContainer) Update(prev, c *Client) (err error) {
|
||||
// update updates a client by its name.
|
||||
func (clients *clientsContainer) update(prev, c *persistentClient) (err error) {
|
||||
err = clients.check(c)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
@@ -738,18 +717,23 @@ func (clients *clientsContainer) Update(prev, c *Client) (err error) {
|
||||
}
|
||||
}
|
||||
|
||||
if c.equalIDs(prev) {
|
||||
clients.removeLocked(prev)
|
||||
clients.addLocked(c)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check the ID index.
|
||||
if !slices.Equal(prev.IDs, c.IDs) {
|
||||
for _, id := range c.IDs {
|
||||
existing, ok := clients.idIndex[id]
|
||||
if ok && existing != prev {
|
||||
return fmt.Errorf("id %q is used by client with name %q", id, existing.Name)
|
||||
}
|
||||
for _, id := range c.ids() {
|
||||
existing, ok := clients.idIndex[id]
|
||||
if ok && existing != prev {
|
||||
return fmt.Errorf("id %q is used by client with name %q", id, existing.Name)
|
||||
}
|
||||
}
|
||||
|
||||
clients.del(prev)
|
||||
clients.add(c)
|
||||
clients.removeLocked(prev)
|
||||
clients.addLocked(c)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -764,23 +748,20 @@ func (clients *clientsContainer) setWHOISInfo(ip netip.Addr, wi *whois.Info) {
|
||||
return
|
||||
}
|
||||
|
||||
// TODO(e.burkov): Consider storing WHOIS information separately and
|
||||
// potentially get rid of [RuntimeClient].
|
||||
rc, ok := clients.ipToRC[ip]
|
||||
if !ok {
|
||||
// Create a RuntimeClient implicitly so that we don't do this check
|
||||
// again.
|
||||
rc = &RuntimeClient{
|
||||
Source: client.SourceWHOIS,
|
||||
}
|
||||
rc = &client.Runtime{}
|
||||
clients.ipToRC[ip] = rc
|
||||
|
||||
log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi)
|
||||
} else {
|
||||
log.Debug("clients: set whois info for runtime client %s: %+v", rc.Host, wi)
|
||||
host, _ := rc.Info()
|
||||
log.Debug("clients: set whois info for runtime client %s: %+v", host, wi)
|
||||
}
|
||||
|
||||
rc.WHOIS = wi
|
||||
rc.SetWHOIS(wi)
|
||||
}
|
||||
|
||||
// addHost adds a new IP-hostname pairing. The priorities of the sources are
|
||||
@@ -839,18 +820,13 @@ func (clients *clientsContainer) addHostLocked(
|
||||
}
|
||||
}
|
||||
|
||||
rc = &RuntimeClient{
|
||||
WHOIS: &whois.Info{},
|
||||
}
|
||||
rc = &client.Runtime{}
|
||||
clients.ipToRC[ip] = rc
|
||||
} else if src < rc.Source {
|
||||
return false
|
||||
}
|
||||
|
||||
rc.Host = host
|
||||
rc.Source = src
|
||||
rc.SetInfo(src, []string{host})
|
||||
|
||||
log.Debug("clients: added %s -> %q [%d]", ip, host, len(clients.ipToRC))
|
||||
log.Debug("clients: adding client info %s -> %q %q [%d]", ip, src, host, len(clients.ipToRC))
|
||||
|
||||
return true
|
||||
}
|
||||
@@ -859,7 +835,8 @@ func (clients *clientsContainer) addHostLocked(
|
||||
func (clients *clientsContainer) rmHostsBySrc(src client.Source) {
|
||||
n := 0
|
||||
for ip, rc := range clients.ipToRC {
|
||||
if rc.Source == src {
|
||||
rc.Unset(src)
|
||||
if rc.IsEmpty() {
|
||||
delete(clients.ipToRC, ip)
|
||||
n++
|
||||
}
|
||||
@@ -870,21 +847,24 @@ func (clients *clientsContainer) rmHostsBySrc(src client.Source) {
|
||||
|
||||
// addFromHostsFile fills the client-hostname pairing index from the system's
|
||||
// hosts files.
|
||||
func (clients *clientsContainer) addFromHostsFile(hosts aghnet.Hosts) {
|
||||
func (clients *clientsContainer) addFromHostsFile(hosts *hostsfile.DefaultStorage) {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
clients.rmHostsBySrc(client.SourceHostsFile)
|
||||
|
||||
n := 0
|
||||
for addr, rec := range hosts {
|
||||
hosts.RangeNames(func(addr netip.Addr, names []string) (cont bool) {
|
||||
// Only the first name of the first record is considered a canonical
|
||||
// hostname for the IP address.
|
||||
//
|
||||
// TODO(e.burkov): Consider using all the names from all the records.
|
||||
clients.addHostLocked(addr, rec[0].Names[0], client.SourceHostsFile)
|
||||
n++
|
||||
}
|
||||
if clients.addHostLocked(addr, names[0], client.SourceHostsFile) {
|
||||
n++
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
log.Debug("clients: added %d client aliases from system hosts file", n)
|
||||
}
|
||||
@@ -926,7 +906,7 @@ func (clients *clientsContainer) addFromSystemARP() {
|
||||
// the persistent clients.
|
||||
func (clients *clientsContainer) close() (err error) {
|
||||
persistent := maps.Values(clients.list)
|
||||
slices.SortFunc(persistent, func(a, b *Client) (res int) {
|
||||
slices.SortFunc(persistent, func(a, b *persistentClient) (res int) {
|
||||
return strings.Compare(a.Name, b.Name)
|
||||
})
|
||||
|
||||
|
||||
@@ -60,63 +60,65 @@ func TestClients(t *testing.T) {
|
||||
cli1 = "1.1.1.1"
|
||||
cli2 = "2.2.2.2"
|
||||
|
||||
cliNoneIP = netip.MustParseAddr(cliNone)
|
||||
cli1IP = netip.MustParseAddr(cli1)
|
||||
cli2IP = netip.MustParseAddr(cli2)
|
||||
cli1IP = netip.MustParseAddr(cli1)
|
||||
cli2IP = netip.MustParseAddr(cli2)
|
||||
|
||||
cliIPv6 = netip.MustParseAddr("1:2:3::4")
|
||||
)
|
||||
|
||||
c := &Client{
|
||||
IDs: []string{cli1, "1:2:3::4", "aa:aa:aa:aa:aa:aa"},
|
||||
c := &persistentClient{
|
||||
Name: "client1",
|
||||
IPs: []netip.Addr{cli1IP, cliIPv6},
|
||||
}
|
||||
|
||||
ok, err := clients.Add(c)
|
||||
ok, err := clients.add(c)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, ok)
|
||||
|
||||
c = &Client{
|
||||
IDs: []string{cli2},
|
||||
c = &persistentClient{
|
||||
Name: "client2",
|
||||
IPs: []netip.Addr{cli2IP},
|
||||
}
|
||||
|
||||
ok, err = clients.Add(c)
|
||||
ok, err = clients.add(c)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, ok)
|
||||
|
||||
c, ok = clients.Find(cli1)
|
||||
c, ok = clients.find(cli1)
|
||||
require.True(t, ok)
|
||||
|
||||
assert.Equal(t, "client1", c.Name)
|
||||
|
||||
c, ok = clients.Find("1:2:3::4")
|
||||
c, ok = clients.find("1:2:3::4")
|
||||
require.True(t, ok)
|
||||
|
||||
assert.Equal(t, "client1", c.Name)
|
||||
|
||||
c, ok = clients.Find(cli2)
|
||||
c, ok = clients.find(cli2)
|
||||
require.True(t, ok)
|
||||
|
||||
assert.Equal(t, "client2", c.Name)
|
||||
|
||||
assert.Equal(t, clients.clientSource(cliNoneIP), client.SourceNone)
|
||||
_, ok = clients.find(cliNone)
|
||||
assert.False(t, ok)
|
||||
|
||||
assert.Equal(t, clients.clientSource(cli1IP), client.SourcePersistent)
|
||||
assert.Equal(t, clients.clientSource(cli2IP), client.SourcePersistent)
|
||||
})
|
||||
|
||||
t.Run("add_fail_name", func(t *testing.T) {
|
||||
ok, err := clients.Add(&Client{
|
||||
IDs: []string{"1.2.3.5"},
|
||||
ok, err := clients.add(&persistentClient{
|
||||
Name: "client1",
|
||||
IPs: []netip.Addr{netip.MustParseAddr("1.2.3.5")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ok)
|
||||
})
|
||||
|
||||
t.Run("add_fail_ip", func(t *testing.T) {
|
||||
ok, err := clients.Add(&Client{
|
||||
IDs: []string{"2.2.2.2"},
|
||||
ok, err := clients.add(&persistentClient{
|
||||
Name: "client3",
|
||||
})
|
||||
require.Error(t, err)
|
||||
@@ -124,8 +126,7 @@ func TestClients(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("update_fail_ip", func(t *testing.T) {
|
||||
err := clients.Update(&Client{Name: "client1"}, &Client{
|
||||
IDs: []string{"2.2.2.2"},
|
||||
err := clients.update(&persistentClient{Name: "client1"}, &persistentClient{
|
||||
Name: "client1",
|
||||
})
|
||||
assert.Error(t, err)
|
||||
@@ -136,33 +137,34 @@ func TestClients(t *testing.T) {
|
||||
cliOld = "1.1.1.1"
|
||||
cliNew = "1.1.1.2"
|
||||
|
||||
cliOldIP = netip.MustParseAddr(cliOld)
|
||||
cliNewIP = netip.MustParseAddr(cliNew)
|
||||
)
|
||||
|
||||
prev, ok := clients.list["client1"]
|
||||
require.True(t, ok)
|
||||
|
||||
err := clients.Update(prev, &Client{
|
||||
IDs: []string{cliNew},
|
||||
err := clients.update(prev, &persistentClient{
|
||||
Name: "client1",
|
||||
IPs: []netip.Addr{cliNewIP},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, clients.clientSource(cliOldIP), client.SourceNone)
|
||||
_, ok = clients.find(cliOld)
|
||||
assert.False(t, ok)
|
||||
|
||||
assert.Equal(t, clients.clientSource(cliNewIP), client.SourcePersistent)
|
||||
|
||||
prev, ok = clients.list["client1"]
|
||||
require.True(t, ok)
|
||||
|
||||
err = clients.Update(prev, &Client{
|
||||
IDs: []string{cliNew},
|
||||
err = clients.update(prev, &persistentClient{
|
||||
Name: "client1-renamed",
|
||||
IPs: []netip.Addr{cliNewIP},
|
||||
UseOwnSettings: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
c, ok := clients.Find(cliNew)
|
||||
c, ok := clients.find(cliNew)
|
||||
require.True(t, ok)
|
||||
|
||||
assert.Equal(t, "client1-renamed", c.Name)
|
||||
@@ -173,20 +175,21 @@ func TestClients(t *testing.T) {
|
||||
|
||||
assert.Nil(t, nilCli)
|
||||
|
||||
require.Len(t, c.IDs, 1)
|
||||
require.Len(t, c.ids(), 1)
|
||||
|
||||
assert.Equal(t, cliNew, c.IDs[0])
|
||||
assert.Equal(t, cliNewIP, c.IPs[0])
|
||||
})
|
||||
|
||||
t.Run("del_success", func(t *testing.T) {
|
||||
ok := clients.Del("client1-renamed")
|
||||
ok := clients.remove("client1-renamed")
|
||||
require.True(t, ok)
|
||||
|
||||
assert.Equal(t, clients.clientSource(netip.MustParseAddr("1.1.1.2")), client.SourceNone)
|
||||
_, ok = clients.find("1.1.1.2")
|
||||
assert.False(t, ok)
|
||||
})
|
||||
|
||||
t.Run("del_fail", func(t *testing.T) {
|
||||
ok := clients.Del("client3")
|
||||
ok := clients.remove("client3")
|
||||
assert.False(t, ok)
|
||||
})
|
||||
|
||||
@@ -215,10 +218,12 @@ func TestClients(t *testing.T) {
|
||||
assert.Equal(t, clients.clientSource(ip), client.SourceDHCP)
|
||||
})
|
||||
|
||||
t.Run("addhost_fail", func(t *testing.T) {
|
||||
t.Run("addhost_priority", func(t *testing.T) {
|
||||
ip := netip.MustParseAddr("1.1.1.1")
|
||||
ok := clients.addHost(ip, "host1", client.SourceRDNS)
|
||||
assert.False(t, ok)
|
||||
assert.True(t, ok)
|
||||
|
||||
assert.Equal(t, client.SourceHostsFile, clients.clientSource(ip))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -235,7 +240,7 @@ func TestClientsWHOIS(t *testing.T) {
|
||||
rc := clients.ipToRC[ip]
|
||||
require.NotNil(t, rc)
|
||||
|
||||
assert.Equal(t, rc.WHOIS, whois)
|
||||
assert.Equal(t, whois, rc.WHOIS())
|
||||
})
|
||||
|
||||
t.Run("existing_auto-client", func(t *testing.T) {
|
||||
@@ -247,15 +252,15 @@ func TestClientsWHOIS(t *testing.T) {
|
||||
rc := clients.ipToRC[ip]
|
||||
require.NotNil(t, rc)
|
||||
|
||||
assert.Equal(t, rc.WHOIS, whois)
|
||||
assert.Equal(t, whois, rc.WHOIS())
|
||||
})
|
||||
|
||||
t.Run("can't_set_manually-added", func(t *testing.T) {
|
||||
ip := netip.MustParseAddr("1.1.1.2")
|
||||
|
||||
ok, err := clients.Add(&Client{
|
||||
IDs: []string{"1.1.1.2"},
|
||||
ok, err := clients.add(&persistentClient{
|
||||
Name: "client1",
|
||||
IPs: []netip.Addr{netip.MustParseAddr("1.1.1.2")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
@@ -264,7 +269,7 @@ func TestClientsWHOIS(t *testing.T) {
|
||||
rc := clients.ipToRC[ip]
|
||||
require.Nil(t, rc)
|
||||
|
||||
assert.True(t, clients.Del("client1"))
|
||||
assert.True(t, clients.remove("client1"))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -275,9 +280,11 @@ func TestClientsAddExisting(t *testing.T) {
|
||||
ip := netip.MustParseAddr("1.1.1.1")
|
||||
|
||||
// Add a client.
|
||||
ok, err := clients.Add(&Client{
|
||||
IDs: []string{ip.String(), "1:2:3::4", "aa:aa:aa:aa:aa:aa", "2.2.2.0/24"},
|
||||
Name: "client1",
|
||||
ok, err := clients.add(&persistentClient{
|
||||
Name: "client1",
|
||||
IPs: []netip.Addr{ip, netip.MustParseAddr("1:2:3::4")},
|
||||
Subnets: []netip.Prefix{netip.MustParsePrefix("2.2.2.0/24")},
|
||||
MACs: []net.HardwareAddr{{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
@@ -323,17 +330,17 @@ func TestClientsAddExisting(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add a new client with the same IP as for a client with MAC.
|
||||
ok, err := clients.Add(&Client{
|
||||
IDs: []string{ip.String()},
|
||||
ok, err := clients.add(&persistentClient{
|
||||
Name: "client2",
|
||||
IPs: []netip.Addr{ip},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
// Add a new client with the IP from the first client's IP range.
|
||||
ok, err = clients.Add(&Client{
|
||||
IDs: []string{"2.2.2.2"},
|
||||
ok, err = clients.add(&persistentClient{
|
||||
Name: "client3",
|
||||
IPs: []netip.Addr{netip.MustParseAddr("2.2.2.2")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
@@ -344,9 +351,9 @@ func TestClientsCustomUpstream(t *testing.T) {
|
||||
clients := newClientsContainer(t)
|
||||
|
||||
// Add client with upstreams.
|
||||
ok, err := clients.Add(&Client{
|
||||
IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa"},
|
||||
ok, err := clients.add(&persistentClient{
|
||||
Name: "client1",
|
||||
IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1"), netip.MustParseAddr("1:2:3::4")},
|
||||
Upstreams: []string{
|
||||
"1.1.1.1",
|
||||
"[/example.org/]8.8.8.8",
|
||||
|
||||
@@ -61,6 +61,7 @@ type clientJSON struct {
|
||||
UpstreamsCacheEnabled aghalg.NullBool `json:"upstreams_cache_enabled"`
|
||||
}
|
||||
|
||||
// runtimeClientJSON is a JSON representation of the [client.Runtime].
|
||||
type runtimeClientJSON struct {
|
||||
WHOIS *whois.Info `json:"whois_info"`
|
||||
|
||||
@@ -69,12 +70,25 @@ type runtimeClientJSON struct {
|
||||
Source client.Source `json:"source"`
|
||||
}
|
||||
|
||||
// clientListJSON contains lists of persistent clients, runtime clients and also
|
||||
// supported tags.
|
||||
type clientListJSON struct {
|
||||
Clients []*clientJSON `json:"clients"`
|
||||
RuntimeClients []runtimeClientJSON `json:"auto_clients"`
|
||||
Tags []string `json:"supported_tags"`
|
||||
}
|
||||
|
||||
// whoisOrEmpty returns a WHOIS client information or a pointer to an empty
|
||||
// struct. Frontend expects a non-nil value.
|
||||
func whoisOrEmpty(r *client.Runtime) (info *whois.Info) {
|
||||
info = r.WHOIS()
|
||||
if info != nil {
|
||||
return info
|
||||
}
|
||||
|
||||
return &whois.Info{}
|
||||
}
|
||||
|
||||
// handleGetClients is the handler for GET /control/clients HTTP API.
|
||||
func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http.Request) {
|
||||
data := clientListJSON{}
|
||||
@@ -88,11 +102,11 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
|
||||
}
|
||||
|
||||
for ip, rc := range clients.ipToRC {
|
||||
src, host := rc.Info()
|
||||
cj := runtimeClientJSON{
|
||||
WHOIS: rc.WHOIS,
|
||||
|
||||
Name: rc.Host,
|
||||
Source: rc.Source,
|
||||
WHOIS: whoisOrEmpty(rc),
|
||||
Name: host,
|
||||
Source: src,
|
||||
IP: ip,
|
||||
}
|
||||
|
||||
@@ -115,32 +129,36 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
|
||||
aghhttp.WriteJSONResponseOK(w, r, data)
|
||||
}
|
||||
|
||||
// jsonToClient converts JSON object to Client object.
|
||||
func (clients *clientsContainer) jsonToClient(cj clientJSON, prev *Client) (c *Client, err error) {
|
||||
safeSearchConf := copySafeSearch(cj.SafeSearchConf, cj.SafeSearchEnabled)
|
||||
// initPrev initializes the persistent client with the default or previous
|
||||
// client properties.
|
||||
func initPrev(cj clientJSON, prev *persistentClient) (c *persistentClient, err error) {
|
||||
var (
|
||||
uid UID
|
||||
ignoreQueryLog bool
|
||||
ignoreStatistics bool
|
||||
upsCacheEnabled bool
|
||||
upsCacheSize uint32
|
||||
)
|
||||
|
||||
if prev != nil {
|
||||
uid = prev.UID
|
||||
ignoreQueryLog = prev.IgnoreQueryLog
|
||||
ignoreStatistics = prev.IgnoreStatistics
|
||||
upsCacheEnabled = prev.UpstreamsCacheEnabled
|
||||
upsCacheSize = prev.UpstreamsCacheSize
|
||||
}
|
||||
|
||||
var ignoreQueryLog bool
|
||||
if cj.IgnoreQueryLog != aghalg.NBNull {
|
||||
ignoreQueryLog = cj.IgnoreQueryLog == aghalg.NBTrue
|
||||
} else if prev != nil {
|
||||
ignoreQueryLog = prev.IgnoreQueryLog
|
||||
}
|
||||
|
||||
var ignoreStatistics bool
|
||||
if cj.IgnoreStatistics != aghalg.NBNull {
|
||||
ignoreStatistics = cj.IgnoreStatistics == aghalg.NBTrue
|
||||
} else if prev != nil {
|
||||
ignoreStatistics = prev.IgnoreStatistics
|
||||
}
|
||||
|
||||
var upsCacheEnabled bool
|
||||
var upsCacheSize uint32
|
||||
if cj.UpstreamsCacheEnabled != aghalg.NBNull {
|
||||
upsCacheEnabled = cj.UpstreamsCacheEnabled == aghalg.NBTrue
|
||||
upsCacheSize = cj.UpstreamsCacheSize
|
||||
} else if prev != nil {
|
||||
upsCacheEnabled = prev.UpstreamsCacheEnabled
|
||||
upsCacheSize = prev.UpstreamsCacheSize
|
||||
}
|
||||
|
||||
svcs, err := copyBlockedServices(cj.Schedule, cj.BlockedServices, prev)
|
||||
@@ -148,31 +166,54 @@ func (clients *clientsContainer) jsonToClient(cj clientJSON, prev *Client) (c *C
|
||||
return nil, fmt.Errorf("invalid blocked services: %w", err)
|
||||
}
|
||||
|
||||
c = &Client{
|
||||
safeSearchConf: safeSearchConf,
|
||||
if (uid == UID{}) {
|
||||
uid, err = NewUID()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generating uid: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
Name: cj.Name,
|
||||
|
||||
BlockedServices: svcs,
|
||||
|
||||
IDs: cj.IDs,
|
||||
Tags: cj.Tags,
|
||||
Upstreams: cj.Upstreams,
|
||||
|
||||
UseOwnSettings: !cj.UseGlobalSettings,
|
||||
FilteringEnabled: cj.FilteringEnabled,
|
||||
ParentalEnabled: cj.ParentalEnabled,
|
||||
SafeBrowsingEnabled: cj.SafeBrowsingEnabled,
|
||||
UseOwnBlockedServices: !cj.UseGlobalBlockedServices,
|
||||
return &persistentClient{
|
||||
BlockedServices: svcs,
|
||||
UID: uid,
|
||||
IgnoreQueryLog: ignoreQueryLog,
|
||||
IgnoreStatistics: ignoreStatistics,
|
||||
UpstreamsCacheEnabled: upsCacheEnabled,
|
||||
UpstreamsCacheSize: upsCacheSize,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// jsonToClient converts JSON object to persistent client object if there are no
|
||||
// errors.
|
||||
func (clients *clientsContainer) jsonToClient(
|
||||
cj clientJSON,
|
||||
prev *persistentClient,
|
||||
) (c *persistentClient, err error) {
|
||||
c, err = initPrev(cj, prev)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if safeSearchConf.Enabled {
|
||||
err = c.setIDs(cj.IDs)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.safeSearchConf = copySafeSearch(cj.SafeSearchConf, cj.SafeSearchEnabled)
|
||||
c.Name = cj.Name
|
||||
c.Tags = cj.Tags
|
||||
c.Upstreams = cj.Upstreams
|
||||
c.UseOwnSettings = !cj.UseGlobalSettings
|
||||
c.FilteringEnabled = cj.FilteringEnabled
|
||||
c.ParentalEnabled = cj.ParentalEnabled
|
||||
c.SafeBrowsingEnabled = cj.SafeBrowsingEnabled
|
||||
c.UseOwnBlockedServices = !cj.UseGlobalBlockedServices
|
||||
|
||||
if c.safeSearchConf.Enabled {
|
||||
err = c.setSafeSearch(
|
||||
safeSearchConf,
|
||||
c.safeSearchConf,
|
||||
clients.safeSearchCacheSize,
|
||||
clients.safeSearchCacheTTL,
|
||||
)
|
||||
@@ -217,7 +258,7 @@ func copySafeSearch(
|
||||
func copyBlockedServices(
|
||||
sch *schedule.Weekly,
|
||||
svcStrs []string,
|
||||
prev *Client,
|
||||
prev *persistentClient,
|
||||
) (svcs *filtering.BlockedServices, err error) {
|
||||
var weekly *schedule.Weekly
|
||||
if sch != nil {
|
||||
@@ -241,8 +282,8 @@ func copyBlockedServices(
|
||||
return svcs, nil
|
||||
}
|
||||
|
||||
// clientToJSON converts Client object to JSON.
|
||||
func clientToJSON(c *Client) (cj *clientJSON) {
|
||||
// clientToJSON converts persistent client object to JSON object.
|
||||
func clientToJSON(c *persistentClient) (cj *clientJSON) {
|
||||
// TODO(d.kolyshev): Remove after cleaning the deprecated
|
||||
// [clientJSON.SafeSearchEnabled] field.
|
||||
cloneVal := c.safeSearchConf
|
||||
@@ -250,7 +291,7 @@ func clientToJSON(c *Client) (cj *clientJSON) {
|
||||
|
||||
return &clientJSON{
|
||||
Name: c.Name,
|
||||
IDs: c.IDs,
|
||||
IDs: c.ids(),
|
||||
Tags: c.Tags,
|
||||
UseGlobalSettings: !c.UseOwnSettings,
|
||||
FilteringEnabled: c.FilteringEnabled,
|
||||
@@ -291,7 +332,7 @@ func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http.
|
||||
return
|
||||
}
|
||||
|
||||
ok, err := clients.Add(c)
|
||||
ok, err := clients.add(c)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
@@ -323,7 +364,7 @@ func (clients *clientsContainer) handleDelClient(w http.ResponseWriter, r *http.
|
||||
return
|
||||
}
|
||||
|
||||
if !clients.Del(cj.Name) {
|
||||
if !clients.remove(cj.Name) {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "Client not found")
|
||||
|
||||
return
|
||||
@@ -332,6 +373,7 @@ func (clients *clientsContainer) handleDelClient(w http.ResponseWriter, r *http.
|
||||
onConfigModified()
|
||||
}
|
||||
|
||||
// updateJSON contains the name and data of the updated persistent client.
|
||||
type updateJSON struct {
|
||||
Name string `json:"name"`
|
||||
Data clientJSON `json:"data"`
|
||||
@@ -355,7 +397,7 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht
|
||||
return
|
||||
}
|
||||
|
||||
var prev *Client
|
||||
var prev *persistentClient
|
||||
var ok bool
|
||||
|
||||
func() {
|
||||
@@ -367,6 +409,8 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht
|
||||
|
||||
if !ok {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "client not found")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
c, err := clients.jsonToClient(dj.Data, prev)
|
||||
@@ -376,7 +420,7 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht
|
||||
return
|
||||
}
|
||||
|
||||
err = clients.Update(prev, c)
|
||||
err = clients.update(prev, c)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
@@ -397,7 +441,7 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
|
||||
}
|
||||
|
||||
ip, _ := netip.ParseAddr(idStr)
|
||||
c, ok := clients.Find(idStr)
|
||||
c, ok := clients.find(idStr)
|
||||
var cj *clientJSON
|
||||
if !ok {
|
||||
cj = clients.findRuntime(ip, idStr)
|
||||
@@ -437,10 +481,11 @@ func (clients *clientsContainer) findRuntime(ip netip.Addr, idStr string) (cj *c
|
||||
return cj
|
||||
}
|
||||
|
||||
_, host := rc.Info()
|
||||
cj = &clientJSON{
|
||||
Name: rc.Host,
|
||||
Name: host,
|
||||
IDs: []string{idStr},
|
||||
WHOIS: rc.WHOIS,
|
||||
WHOIS: whoisOrEmpty(rc),
|
||||
}
|
||||
|
||||
disallowed, rule := clients.dnsServer.IsBlockedClient(ip, idStr)
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/confmigrate"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/configmigrate"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"github.com/AdguardTeam/dnsproxy/fastip"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"github.com/google/renameio/v2/maybe"
|
||||
yaml "gopkg.in/yaml.v3"
|
||||
@@ -149,7 +150,7 @@ type configuration struct {
|
||||
sync.RWMutex `yaml:"-"`
|
||||
|
||||
// SchemaVersion is the version of the configuration schema. See
|
||||
// [confmigrate.LastSchemaVersion].
|
||||
// [configmigrate.LastSchemaVersion].
|
||||
SchemaVersion uint `yaml:"schema_version"`
|
||||
}
|
||||
|
||||
@@ -200,7 +201,7 @@ type dnsConfig struct {
|
||||
|
||||
// PrivateNets is the set of IP networks for which the private reverse DNS
|
||||
// resolver should be used.
|
||||
PrivateNets []string `yaml:"private_networks"`
|
||||
PrivateNets []netutil.Prefix `yaml:"private_networks"`
|
||||
|
||||
// UsePrivateRDNS defines if the PTR requests for unknown addresses from
|
||||
// locally-served networks should be resolved via private PTR resolvers.
|
||||
@@ -267,7 +268,7 @@ type queryLogConfig struct {
|
||||
|
||||
// MemSize is the number of entries kept in memory before they are flushed
|
||||
// to disk.
|
||||
MemSize int `yaml:"size_memory"`
|
||||
MemSize uint `yaml:"size_memory"`
|
||||
|
||||
// Enabled defines if the query log is enabled.
|
||||
Enabled bool `yaml:"enabled"`
|
||||
@@ -315,14 +316,18 @@ var config = &configuration{
|
||||
RatelimitSubnetLenIPv4: 24,
|
||||
RatelimitSubnetLenIPv6: 56,
|
||||
RefuseAny: true,
|
||||
AllServers: false,
|
||||
UpstreamMode: dnsforward.UpstreamModeLoadBalance,
|
||||
HandleDDR: true,
|
||||
FastestTimeout: timeutil.Duration{
|
||||
Duration: fastip.DefaultPingWaitTimeout,
|
||||
},
|
||||
|
||||
TrustedProxies: []string{"127.0.0.0/8", "::1/128"},
|
||||
CacheSize: 4 * 1024 * 1024,
|
||||
TrustedProxies: []netutil.Prefix{{
|
||||
Prefix: netip.MustParsePrefix("127.0.0.0/8"),
|
||||
}, {
|
||||
Prefix: netip.MustParsePrefix("::1/128"),
|
||||
}},
|
||||
CacheSize: 4 * 1024 * 1024,
|
||||
|
||||
EDNSClientSubnet: &dnsforward.EDNSClientSubnet{
|
||||
CustomIP: netip.Addr{},
|
||||
@@ -434,7 +439,7 @@ var config = &configuration{
|
||||
MaxAge: 3,
|
||||
},
|
||||
OSConfig: &osConfig{},
|
||||
SchemaVersion: confmigrate.LastSchemaVersion,
|
||||
SchemaVersion: configmigrate.LastSchemaVersion,
|
||||
Theme: ThemeAuto,
|
||||
}
|
||||
|
||||
@@ -479,14 +484,14 @@ func parseConfig() (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
migrator := confmigrate.New(&confmigrate.Config{
|
||||
migrator := configmigrate.New(&configmigrate.Config{
|
||||
WorkingDir: Context.workDir,
|
||||
})
|
||||
|
||||
var upgraded bool
|
||||
config.fileData, upgraded, err = migrator.Migrate(
|
||||
config.fileData,
|
||||
confmigrate.LastSchemaVersion,
|
||||
configmigrate.LastSchemaVersion,
|
||||
)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
|
||||
@@ -127,16 +127,11 @@ func initDNSServer(
|
||||
httpReg aghhttp.RegisterFunc,
|
||||
tlsConf *tlsConfigSettings,
|
||||
) (err error) {
|
||||
privateNets, err := parseSubnetSet(config.DNS.PrivateNets)
|
||||
if err != nil {
|
||||
return fmt.Errorf("preparing set of private subnets: %w", err)
|
||||
}
|
||||
|
||||
Context.dnsServer, err = dnsforward.NewServer(dnsforward.DNSCreateParams{
|
||||
DNSFilter: filters,
|
||||
Stats: sts,
|
||||
QueryLog: qlog,
|
||||
PrivateNets: privateNets,
|
||||
PrivateNets: parseSubnetSet(config.DNS.PrivateNets),
|
||||
Anonymizer: anonymizer,
|
||||
DHCPServer: dhcpSrv,
|
||||
EtcHosts: Context.etcHosts,
|
||||
@@ -169,26 +164,15 @@ func initDNSServer(
|
||||
// parseSubnetSet parses a slice of subnets. If the slice is empty, it returns
|
||||
// a subnet set that matches all locally served networks, see
|
||||
// [netutil.IsLocallyServed].
|
||||
func parseSubnetSet(nets []string) (s netutil.SubnetSet, err error) {
|
||||
func parseSubnetSet(nets []netutil.Prefix) (s netutil.SubnetSet) {
|
||||
switch len(nets) {
|
||||
case 0:
|
||||
// Use an optimized function-based matcher.
|
||||
return netutil.SubnetSetFunc(netutil.IsLocallyServed), nil
|
||||
return netutil.SubnetSetFunc(netutil.IsLocallyServed)
|
||||
case 1:
|
||||
s, err = netutil.ParseSubnet(nets[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s, nil
|
||||
return nets[0].Prefix
|
||||
default:
|
||||
var nets []*net.IPNet
|
||||
nets, err = netutil.ParseSubnets(config.DNS.PrivateNets...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return netutil.SliceSubnetSet(nets), nil
|
||||
return netutil.SliceSubnetSet(netutil.UnembedPrefixes(nets))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -411,9 +395,9 @@ func applyAdditionalFiltering(clientIP netip.Addr, clientID string, setts *filte
|
||||
|
||||
setts.ClientIP = clientIP
|
||||
|
||||
c, ok := Context.clients.Find(clientID)
|
||||
c, ok := Context.clients.find(clientID)
|
||||
if !ok {
|
||||
c, ok = Context.clients.Find(clientIP.String())
|
||||
c, ok = Context.clients.find(clientIP.String())
|
||||
if !ok {
|
||||
log.Debug("%s: no clients with ip %s and clientid %q", pref, clientIP, clientID)
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ func TestApplyAdditionalFiltering(t *testing.T) {
|
||||
}, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
Context.clients.idIndex = map[string]*Client{
|
||||
Context.clients.idIndex = map[string]*persistentClient{
|
||||
"default": {
|
||||
UseOwnSettings: false,
|
||||
safeSearchConf: filtering.SafeSearchConfig{Enabled: false},
|
||||
@@ -108,7 +108,7 @@ func TestApplyAdditionalFiltering_blockedServices(t *testing.T) {
|
||||
}, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
Context.clients.idIndex = map[string]*Client{
|
||||
Context.clients.idIndex = map[string]*persistentClient{
|
||||
"default": {
|
||||
UseOwnBlockedServices: false,
|
||||
},
|
||||
|
||||
@@ -35,8 +35,10 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/hostsfile"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/osutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
@@ -231,11 +233,12 @@ func setupHostsContainer() (err error) {
|
||||
return fmt.Errorf("initing hosts watcher: %w", err)
|
||||
}
|
||||
|
||||
Context.etcHosts, err = aghnet.NewHostsContainer(
|
||||
aghos.RootDirFS(),
|
||||
hostsWatcher,
|
||||
aghnet.DefaultHostsPaths()...,
|
||||
)
|
||||
paths, err := hostsfile.DefaultHostsPaths()
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting default system hosts paths: %w", err)
|
||||
}
|
||||
|
||||
Context.etcHosts, err = aghnet.NewHostsContainer(osutil.RootDirFS(), hostsWatcher, paths...)
|
||||
if err != nil {
|
||||
closeErr := hostsWatcher.Close()
|
||||
if errors.Is(err, aghnet.ErrNoHostsPaths) {
|
||||
@@ -357,6 +360,11 @@ func setupDNSFilteringConf(conf *filtering.Config) (err error) {
|
||||
)
|
||||
|
||||
conf.EtcHosts = Context.etcHosts
|
||||
// TODO(s.chzhen): Use empty interface.
|
||||
if Context.etcHosts == nil {
|
||||
conf.EtcHosts = nil
|
||||
}
|
||||
|
||||
conf.ConfigModified = onConfigModified
|
||||
conf.HTTPRegister = httpRegister
|
||||
conf.DataDir = Context.getDataDir()
|
||||
@@ -605,7 +613,7 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
|
||||
Context.auth, err = initUsers()
|
||||
fatalOnError(err)
|
||||
|
||||
Context.tls, err = newTLSManager(config.TLS)
|
||||
Context.tls, err = newTLSManager(config.TLS, config.DNS.ServePlainDNS)
|
||||
if err != nil {
|
||||
log.Error("initializing tls: %s", err)
|
||||
onConfigModified()
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/configmigrate"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
@@ -308,7 +309,7 @@ var cmdLineOpts = []cmdLineOpt{{
|
||||
effect: func(o options, exec string) (effect, error) {
|
||||
return func() error {
|
||||
if o.verbose {
|
||||
fmt.Println(version.Verbose())
|
||||
fmt.Print(version.Verbose(configmigrate.LastSchemaVersion))
|
||||
} else {
|
||||
fmt.Println(version.Full())
|
||||
}
|
||||
|
||||
@@ -38,15 +38,19 @@ type tlsManager struct {
|
||||
|
||||
confLock sync.Mutex
|
||||
conf tlsConfigSettings
|
||||
|
||||
// servePlainDNS defines if plain DNS is allowed for incoming requests.
|
||||
servePlainDNS bool
|
||||
}
|
||||
|
||||
// newTLSManager initializes the manager of TLS configuration. m is always
|
||||
// non-nil while any returned error indicates that the TLS configuration isn't
|
||||
// valid. Thus TLS may be initialized later, e.g. via the web UI.
|
||||
func newTLSManager(conf tlsConfigSettings) (m *tlsManager, err error) {
|
||||
func newTLSManager(conf tlsConfigSettings, servePlainDNS bool) (m *tlsManager, err error) {
|
||||
m = &tlsManager{
|
||||
status: &tlsConfigStatus{},
|
||||
conf: conf,
|
||||
status: &tlsConfigStatus{},
|
||||
conf: conf,
|
||||
servePlainDNS: servePlainDNS,
|
||||
}
|
||||
|
||||
if m.conf.Enabled {
|
||||
@@ -283,21 +287,29 @@ type tlsConfig struct {
|
||||
tlsConfigSettingsExt `json:",inline"`
|
||||
}
|
||||
|
||||
// tlsConfigSettingsExt is used to (un)marshal the PrivateKeySaved field to
|
||||
// ensure that clients don't send and receive previously saved private keys.
|
||||
// tlsConfigSettingsExt is used to (un)marshal PrivateKeySaved field and
|
||||
// ServePlainDNS field.
|
||||
type tlsConfigSettingsExt struct {
|
||||
tlsConfigSettings `json:",inline"`
|
||||
|
||||
// PrivateKeySaved is true if the private key is saved as a string and omit
|
||||
// key from answer.
|
||||
PrivateKeySaved bool `yaml:"-" json:"private_key_saved,inline"`
|
||||
// key from answer. It is used to ensure that clients don't send and
|
||||
// receive previously saved private keys.
|
||||
PrivateKeySaved bool `yaml:"-" json:"private_key_saved"`
|
||||
|
||||
// ServePlainDNS defines if plain DNS is allowed for incoming requests. It
|
||||
// is an [aghalg.NullBool] to be able to tell when it's set without using
|
||||
// pointers.
|
||||
ServePlainDNS aghalg.NullBool `yaml:"-" json:"serve_plain_dns"`
|
||||
}
|
||||
|
||||
// handleTLSStatus is the handler for the GET /control/tls/status HTTP API.
|
||||
func (m *tlsManager) handleTLSStatus(w http.ResponseWriter, r *http.Request) {
|
||||
m.confLock.Lock()
|
||||
data := tlsConfig{
|
||||
tlsConfigSettingsExt: tlsConfigSettingsExt{
|
||||
tlsConfigSettings: m.conf,
|
||||
ServePlainDNS: aghalg.BoolToNullBool(m.servePlainDNS),
|
||||
},
|
||||
tlsConfigStatus: m.status,
|
||||
}
|
||||
@@ -306,6 +318,7 @@ func (m *tlsManager) handleTLSStatus(w http.ResponseWriter, r *http.Request) {
|
||||
marshalTLS(w, r, data)
|
||||
}
|
||||
|
||||
// handleTLSValidate is the handler for the POST /control/tls/validate HTTP API.
|
||||
func (m *tlsManager) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
|
||||
setts, err := unmarshalTLS(r)
|
||||
if err != nil {
|
||||
@@ -318,30 +331,8 @@ func (m *tlsManager) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
|
||||
setts.PrivateKey = m.conf.PrivateKey
|
||||
}
|
||||
|
||||
if setts.Enabled {
|
||||
err = validatePorts(
|
||||
tcpPort(config.HTTPConfig.Address.Port()),
|
||||
tcpPort(setts.PortHTTPS),
|
||||
tcpPort(setts.PortDNSOverTLS),
|
||||
tcpPort(setts.PortDNSCrypt),
|
||||
udpPort(config.DNS.Port),
|
||||
udpPort(setts.PortDNSOverQUIC),
|
||||
)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if !webCheckPortAvailable(setts.PortHTTPS) {
|
||||
aghhttp.Error(
|
||||
r,
|
||||
w,
|
||||
http.StatusBadRequest,
|
||||
"port %d is not available, cannot enable HTTPS on it",
|
||||
setts.PortHTTPS,
|
||||
)
|
||||
if err = validateTLSSettings(setts); err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -358,7 +349,12 @@ func (m *tlsManager) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
|
||||
marshalTLS(w, r, resp)
|
||||
}
|
||||
|
||||
func (m *tlsManager) setConfig(newConf tlsConfigSettings, status *tlsConfigStatus) (restartHTTPS bool) {
|
||||
// setConfig updates manager conf with the given one.
|
||||
func (m *tlsManager) setConfig(
|
||||
newConf tlsConfigSettings,
|
||||
status *tlsConfigStatus,
|
||||
servePlain aghalg.NullBool,
|
||||
) (restartHTTPS bool) {
|
||||
m.confLock.Lock()
|
||||
defer m.confLock.Unlock()
|
||||
|
||||
@@ -390,9 +386,15 @@ func (m *tlsManager) setConfig(newConf tlsConfigSettings, status *tlsConfigStatu
|
||||
m.conf.PrivateKeyData = newConf.PrivateKeyData
|
||||
m.status = status
|
||||
|
||||
if servePlain != aghalg.NBNull {
|
||||
m.servePlainDNS = servePlain == aghalg.NBTrue
|
||||
}
|
||||
|
||||
return restartHTTPS
|
||||
}
|
||||
|
||||
// handleTLSConfigure is the handler for the POST /control/tls/configure HTTP
|
||||
// API.
|
||||
func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
|
||||
req, err := unmarshalTLS(r)
|
||||
if err != nil {
|
||||
@@ -405,31 +407,8 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request)
|
||||
req.PrivateKey = m.conf.PrivateKey
|
||||
}
|
||||
|
||||
if req.Enabled {
|
||||
err = validatePorts(
|
||||
tcpPort(config.HTTPConfig.Address.Port()),
|
||||
tcpPort(req.PortHTTPS),
|
||||
tcpPort(req.PortDNSOverTLS),
|
||||
tcpPort(req.PortDNSCrypt),
|
||||
udpPort(config.DNS.Port),
|
||||
udpPort(req.PortDNSOverQUIC),
|
||||
)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(e.burkov): Investigate and perhaps check other ports.
|
||||
if !webCheckPortAvailable(req.PortHTTPS) {
|
||||
aghhttp.Error(
|
||||
r,
|
||||
w,
|
||||
http.StatusBadRequest,
|
||||
"port %d is not available, cannot enable https on it",
|
||||
req.PortHTTPS,
|
||||
)
|
||||
if err = validateTLSSettings(req); err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -447,8 +426,18 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request)
|
||||
return
|
||||
}
|
||||
|
||||
restartHTTPS := m.setConfig(req.tlsConfigSettings, status)
|
||||
restartHTTPS := m.setConfig(req.tlsConfigSettings, status, req.ServePlainDNS)
|
||||
m.setCertFileTime()
|
||||
|
||||
if req.ServePlainDNS != aghalg.NBNull {
|
||||
func() {
|
||||
m.confLock.Lock()
|
||||
defer m.confLock.Unlock()
|
||||
|
||||
config.DNS.ServePlainDNS = req.ServePlainDNS == aghalg.NBTrue
|
||||
}()
|
||||
}
|
||||
|
||||
onConfigModified()
|
||||
|
||||
err = reconfigureDNSServer()
|
||||
@@ -479,6 +468,33 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
}
|
||||
|
||||
// validateTLSSettings returns error if the setts are not valid.
|
||||
func validateTLSSettings(setts tlsConfigSettingsExt) (err error) {
|
||||
if setts.Enabled {
|
||||
err = validatePorts(
|
||||
tcpPort(config.HTTPConfig.Address.Port()),
|
||||
tcpPort(setts.PortHTTPS),
|
||||
tcpPort(setts.PortDNSOverTLS),
|
||||
tcpPort(setts.PortDNSCrypt),
|
||||
udpPort(config.DNS.Port),
|
||||
udpPort(setts.PortDNSOverQUIC),
|
||||
)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
} else if setts.ServePlainDNS == aghalg.NBFalse {
|
||||
// TODO(a.garipov): Support full disabling of all DNS.
|
||||
return errors.Error("plain DNS is required in case encryption protocols are disabled")
|
||||
}
|
||||
|
||||
if !webCheckPortAvailable(setts.PortHTTPS) {
|
||||
return fmt.Errorf("port %d is not available, cannot enable HTTPS on it", setts.PortHTTPS)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validatePorts validates the uniqueness of TCP and UDP ports for AdGuard Home
|
||||
// DNS protocols.
|
||||
func validatePorts(
|
||||
|
||||
Reference in New Issue
Block a user