Compare commits

..

1 Commits

Author SHA1 Message Date
Stanislav Chzhen
53cb84efc0 all: session storage usage 2025-04-22 15:42:12 +03:00
36 changed files with 674 additions and 1694 deletions

View File

@@ -9,41 +9,26 @@ The format is based on [*Keep a Changelog*](https://keepachangelog.com/en/1.0.0/
<!-- <!--
## [v0.108.0] TBA ## [v0.108.0] TBA
## [v0.107.62] - 2025-04-30 (APPROX.) ## [v0.107.61] - 2025-04-22 (APPROX.)
See also the [v0.107.62 GitHub milestone][ms-v0.107.62]. See also the [v0.107.61 GitHub milestone][ms-v0.107.61].
[ms-v0.107.62]: https://github.com/AdguardTeam/AdGuardHome/milestone/97?closed=1 [ms-v0.107.61]: https://github.com/AdguardTeam/AdGuardHome/milestone/96?closed=1
NOTE: Add new changes BELOW THIS COMMENT. NOTE: Add new changes BELOW THIS COMMENT.
--> -->
### Fixed
- DNS cache not working for custom upstream configurations.
- Validation process for the DNS-over-TLS, DNS-over-QUIC, and HTTPS ports on the *Encryption Settings* page.
<!--
NOTE: Add new changes ABOVE THIS COMMENT.
-->
## [v0.107.61] - 2025-04-22
See also the [v0.107.61 GitHub milestone][ms-v0.107.61].
### Security ### Security
- Any simultaneous requests that are considered duplicates will now only result in a single request to upstreams, reducing the chance of a cache poisoning attack succeeding. This is controlled by the new configuration object `pending_requests`, which has a single `enabled` property, set to `true` by default. - Any simultaneous requests that are considered duplicates will now only result in a single request to upstreams, reducing the chance of a cache poisoning attack succeeding. This is controlled by the new configuration object `pending_requests`, which has a single `enabled` property, set to `true` by default.
**NOTE:** We thank [Xiang Li][mr-xiang-li] for reporting this security issue. It's strongly recommended to leave it enabled, otherwise AdGuard Home will be vulnerable to untrusted clients. **NOTE:** We thank [Xiang Li][mr-xiang-li] for reporting this security issue. It's strongly recommended to leave it enabled, otherwise AdGuard Home will be vulnerable to untrusted clients.
### Fixed
- Searching for persistent clients using an exact match for CIDR in the `POST /clients/search HTTP API`.
[mr-xiang-li]: https://lixiang521.com/ [mr-xiang-li]: https://lixiang521.com/
[ms-v0.107.61]: https://github.com/AdguardTeam/AdGuardHome/milestone/96?closed=1
<!--
NOTE: Add new changes ABOVE THIS COMMENT.
-->
## [v0.107.60] - 2025-04-14 ## [v0.107.60] - 2025-04-14
@@ -86,6 +71,10 @@ See also the [v0.107.60 GitHub milestone][ms-v0.107.60].
See also the [v0.107.59 GitHub milestone][ms-v0.107.59]. See also the [v0.107.59 GitHub milestone][ms-v0.107.59].
### Fixed
- Validation process for the DNS-over-TLS, DNS-over-QUIC, and HTTPS ports on the *Encryption Settings* page.
- Rules with the `client` modifier not working ([#7708]). - Rules with the `client` modifier not working ([#7708]).
- The search form not working in the query log ([#7704]). - The search form not working in the query log ([#7704]).
@@ -3126,12 +3115,11 @@ See also the [v0.104.2 GitHub milestone][ms-v0.104.2].
[ms-v0.104.2]: https://github.com/AdguardTeam/AdGuardHome/milestone/28?closed=1 [ms-v0.104.2]: https://github.com/AdguardTeam/AdGuardHome/milestone/28?closed=1
<!-- <!--
[Unreleased]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.62...HEAD
[v0.107.62]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.61...v0.107.62
-->
[Unreleased]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.61...HEAD [Unreleased]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.61...HEAD
[v0.107.61]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.60...v0.107.61 [v0.107.61]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.60...v0.107.61
-->
[Unreleased]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.60...HEAD
[v0.107.60]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.59...v0.107.60 [v0.107.60]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.59...v0.107.60
[v0.107.59]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.58...v0.107.59 [v0.107.59]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.58...v0.107.59
[v0.107.58]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.57...v0.107.58 [v0.107.58]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.57...v0.107.58

View File

@@ -656,7 +656,7 @@
"blocklist": "Blocklist", "blocklist": "Blocklist",
"milliseconds_abbreviation": "ms", "milliseconds_abbreviation": "ms",
"cache_size": "Cache size", "cache_size": "Cache size",
"cache_size_desc": "DNS cache size (in bytes). To disable caching, set to 0.", "cache_size_desc": "DNS cache size (in bytes). To disable caching, leave empty.",
"cache_ttl_min_override": "Override minimum TTL", "cache_ttl_min_override": "Override minimum TTL",
"cache_ttl_max_override": "Override maximum TTL", "cache_ttl_max_override": "Override maximum TTL",
"enter_cache_size": "Enter cache size (bytes)", "enter_cache_size": "Enter cache size (bytes)",

View File

@@ -355,8 +355,12 @@ func (ds *DefaultSessionStorage) store(s *Session) (err error) {
return nil return nil
} }
// FindByToken implements the [SessionStorage] interface for *DefaultSessionStorage. // FindByToken implements the [SessionStorage] interface for
func (ds *DefaultSessionStorage) FindByToken(ctx context.Context, t SessionToken) (s *Session, err error) { // *DefaultSessionStorage.
func (ds *DefaultSessionStorage) FindByToken(
ctx context.Context,
t SessionToken,
) (s *Session, err error) {
ds.mu.Lock() ds.mu.Lock()
defer ds.mu.Unlock() defer ds.mu.Unlock()

View File

@@ -11,34 +11,8 @@ import (
"slices" "slices"
"github.com/AdguardTeam/AdGuardHome/internal/whois" "github.com/AdguardTeam/AdGuardHome/internal/whois"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
) )
// ClientID is a unique identifier for a persistent client used in
// DNS-over-HTTPS, DNS-over-TLS, and DNS-over-QUIC queries.
//
// TODO(s.chzhen): Use everywhere.
type ClientID string
// ValidateClientID returns an error if id is not a valid ClientID.
//
// TODO(s.chzhen): Consider implementing [validate.Interface] for ClientID.
func ValidateClientID(id string) (err error) {
err = netutil.ValidateHostnameLabel(id)
if err != nil {
// Replace the domain name label wrapper with our own.
return fmt.Errorf("invalid clientid %q: %w", id, errors.Unwrap(err))
}
return nil
}
// isValidClientID returns false if id is not a valid ClientID.
func isValidClientID(id string) (ok bool) {
return netutil.IsValidHostnameLabel(id)
}
// Source represents the source from which the information about the client has // Source represents the source from which the information about the client has
// been obtained. // been obtained.
type Source uint8 type Source uint8

View File

@@ -35,7 +35,7 @@ type index struct {
nameToUID map[string]UID nameToUID map[string]UID
// clientIDToUID maps ClientID to UID. // clientIDToUID maps ClientID to UID.
clientIDToUID map[ClientID]UID clientIDToUID map[string]UID
// ipToUID maps IP address to UID. // ipToUID maps IP address to UID.
ipToUID map[netip.Addr]UID ipToUID map[netip.Addr]UID
@@ -54,7 +54,7 @@ type index struct {
func newIndex() (ci *index) { func newIndex() (ci *index) {
return &index{ return &index{
nameToUID: map[string]UID{}, nameToUID: map[string]UID{},
clientIDToUID: map[ClientID]UID{}, clientIDToUID: map[string]UID{},
ipToUID: map[netip.Addr]UID{}, ipToUID: map[netip.Addr]UID{},
subnetToUID: aghalg.NewSortedMap[netip.Prefix, UID](subnetCompare), subnetToUID: aghalg.NewSortedMap[netip.Prefix, UID](subnetCompare),
macToUID: map[macKey]UID{}, macToUID: map[macKey]UID{},
@@ -207,7 +207,7 @@ func (ci *index) clashesMAC(c *Persistent) (p *Persistent, mac net.HardwareAddr)
// find finds persistent client by string representation of the ClientID, IP // find finds persistent client by string representation of the ClientID, IP
// address, or MAC. // address, or MAC.
func (ci *index) find(id string) (c *Persistent, ok bool) { func (ci *index) find(id string) (c *Persistent, ok bool) {
c, ok = ci.findByClientID(ClientID(id)) c, ok = ci.findByClientID(id)
if ok { if ok {
return c, true return c, true
} }
@@ -230,7 +230,7 @@ func (ci *index) find(id string) (c *Persistent, ok bool) {
} }
// findByClientID finds persistent client by ClientID. // findByClientID finds persistent client by ClientID.
func (ci *index) findByClientID(clientID ClientID) (c *Persistent, ok bool) { func (ci *index) findByClientID(clientID string) (c *Persistent, ok bool) {
uid, ok := ci.clientIDToUID[clientID] uid, ok := ci.clientIDToUID[clientID]
if ok { if ok {
return ci.uidToClient[uid], true return ci.uidToClient[uid], true
@@ -275,26 +275,6 @@ func (ci *index) findByIP(ip netip.Addr) (c *Persistent, found bool) {
return nil, false return nil, false
} }
// findByCIDR searches for a persistent client with the provided subnet as an
// identifier. Note that this function looks for an exact match of subnets,
// rather than checking if one subnet contains another.
func (ci *index) findByCIDR(subnet netip.Prefix) (c *Persistent, ok bool) {
var uid UID
for pref, id := range ci.subnetToUID.Range {
if subnet == pref {
uid, ok = id, true
break
}
}
if ok {
return ci.uidToClient[uid], true
}
return nil, false
}
// findByMAC finds persistent client by MAC. // findByMAC finds persistent client by MAC.
func (ci *index) findByMAC(mac net.HardwareAddr) (c *Persistent, found bool) { func (ci *index) findByMAC(mac net.HardwareAddr) (c *Persistent, found bool) {
k := macToKey(mac) k := macToKey(mac)

View File

@@ -5,7 +5,6 @@ import (
"net/netip" "net/netip"
"testing" "testing"
"github.com/AdguardTeam/golibs/errors"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -59,12 +58,12 @@ func TestClientIndex_Find(t *testing.T) {
clientWithMAC = &Persistent{ clientWithMAC = &Persistent{
Name: "client_with_mac", Name: "client_with_mac",
MACs: []net.HardwareAddr{errors.Must(net.ParseMAC(cliMAC))}, MACs: []net.HardwareAddr{mustParseMAC(cliMAC)},
} }
clientWithID = &Persistent{ clientWithID = &Persistent{
Name: "client_with_id", Name: "client_with_id",
ClientIDs: []ClientID{cliID}, ClientIDs: []string{cliID},
} }
clientLinkLocal = &Persistent{ clientLinkLocal = &Persistent{
@@ -142,10 +141,10 @@ func TestClientIndex_Clashes(t *testing.T) {
Subnets: []netip.Prefix{netip.MustParsePrefix(cliSubnet)}, Subnets: []netip.Prefix{netip.MustParsePrefix(cliSubnet)},
}, { }, {
Name: "client_with_mac", Name: "client_with_mac",
MACs: []net.HardwareAddr{errors.Must(net.ParseMAC(cliMAC))}, MACs: []net.HardwareAddr{mustParseMAC(cliMAC)},
}, { }, {
Name: "client_with_id", Name: "client_with_id",
ClientIDs: []ClientID{cliID}, ClientIDs: []string{cliID},
}} }}
ci := newIDIndex(clients) ci := newIDIndex(clients)
@@ -182,6 +181,17 @@ func TestClientIndex_Clashes(t *testing.T) {
} }
} }
// mustParseMAC is wrapper around [net.ParseMAC] that panics if there is an
// error.
func mustParseMAC(s string) (mac net.HardwareAddr) {
mac, err := net.ParseMAC(s)
if err != nil {
panic(err)
}
return mac
}
func TestMACToKey(t *testing.T) { func TestMACToKey(t *testing.T) {
testCases := []struct { testCases := []struct {
want any want any
@@ -190,44 +200,44 @@ func TestMACToKey(t *testing.T) {
}{{ }{{
name: "column6", name: "column6",
in: "00:00:5e:00:53:01", in: "00:00:5e:00:53:01",
want: [6]byte(errors.Must(net.ParseMAC("00:00:5e:00:53:01"))), want: [6]byte(mustParseMAC("00:00:5e:00:53:01")),
}, { }, {
name: "column8", name: "column8",
in: "02:00:5e:10:00:00:00:01", in: "02:00:5e:10:00:00:00:01",
want: [8]byte(errors.Must(net.ParseMAC("02:00:5e:10:00:00:00:01"))), want: [8]byte(mustParseMAC("02:00:5e:10:00:00:00:01")),
}, { }, {
name: "column20", name: "column20",
in: "00:00:00:00:fe:80:00:00:00:00:00:00:02:00:5e:10:00:00:00:01", in: "00:00:00:00:fe:80:00:00:00:00:00:00:02:00:5e:10:00:00:00:01",
want: [20]byte(errors.Must(net.ParseMAC("00:00:00:00:fe:80:00:00:00:00:00:00:02:00:5e:10:00:00:00:01"))), want: [20]byte(mustParseMAC("00:00:00:00:fe:80:00:00:00:00:00:00:02:00:5e:10:00:00:00:01")),
}, { }, {
name: "hyphen6", name: "hyphen6",
in: "00-00-5e-00-53-01", in: "00-00-5e-00-53-01",
want: [6]byte(errors.Must(net.ParseMAC("00-00-5e-00-53-01"))), want: [6]byte(mustParseMAC("00-00-5e-00-53-01")),
}, { }, {
name: "hyphen8", name: "hyphen8",
in: "02-00-5e-10-00-00-00-01", in: "02-00-5e-10-00-00-00-01",
want: [8]byte(errors.Must(net.ParseMAC("02-00-5e-10-00-00-00-01"))), want: [8]byte(mustParseMAC("02-00-5e-10-00-00-00-01")),
}, { }, {
name: "hyphen20", name: "hyphen20",
in: "00-00-00-00-fe-80-00-00-00-00-00-00-02-00-5e-10-00-00-00-01", in: "00-00-00-00-fe-80-00-00-00-00-00-00-02-00-5e-10-00-00-00-01",
want: [20]byte(errors.Must(net.ParseMAC("00-00-00-00-fe-80-00-00-00-00-00-00-02-00-5e-10-00-00-00-01"))), want: [20]byte(mustParseMAC("00-00-00-00-fe-80-00-00-00-00-00-00-02-00-5e-10-00-00-00-01")),
}, { }, {
name: "dot6", name: "dot6",
in: "0000.5e00.5301", in: "0000.5e00.5301",
want: [6]byte(errors.Must(net.ParseMAC("0000.5e00.5301"))), want: [6]byte(mustParseMAC("0000.5e00.5301")),
}, { }, {
name: "dot8", name: "dot8",
in: "0200.5e10.0000.0001", in: "0200.5e10.0000.0001",
want: [8]byte(errors.Must(net.ParseMAC("0200.5e10.0000.0001"))), want: [8]byte(mustParseMAC("0200.5e10.0000.0001")),
}, { }, {
name: "dot20", name: "dot20",
in: "0000.0000.fe80.0000.0000.0000.0200.5e10.0000.0001", in: "0000.0000.fe80.0000.0000.0000.0200.5e10.0000.0001",
want: [20]byte(errors.Must(net.ParseMAC("0000.0000.fe80.0000.0000.0000.0200.5e10.0000.0001"))), want: [20]byte(mustParseMAC("0000.0000.fe80.0000.0000.0000.0200.5e10.0000.0001")),
}} }}
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
mac := errors.Must(net.ParseMAC(tc.in)) mac := mustParseMAC(tc.in)
key := macToKey(mac) key := macToKey(mac)
assert.Equal(t, tc.want, key) assert.Equal(t, tc.want, key)
@@ -292,19 +302,19 @@ func TestIndex_FindByIPWithoutZone(t *testing.T) {
func TestClientIndex_RangeByName(t *testing.T) { func TestClientIndex_RangeByName(t *testing.T) {
sortedClients := []*Persistent{{ sortedClients := []*Persistent{{
Name: "clientA", Name: "clientA",
ClientIDs: []ClientID{"A"}, ClientIDs: []string{"A"},
}, { }, {
Name: "clientB", Name: "clientB",
ClientIDs: []ClientID{"B"}, ClientIDs: []string{"B"},
}, { }, {
Name: "clientC", Name: "clientC",
ClientIDs: []ClientID{"C"}, ClientIDs: []string{"C"},
}, { }, {
Name: "clientD", Name: "clientD",
ClientIDs: []ClientID{"D"}, ClientIDs: []string{"D"},
}, { }, {
Name: "clientE", Name: "clientE",
ClientIDs: []ClientID{"E"}, ClientIDs: []string{"E"},
}} }}
testCases := []struct { testCases := []struct {
@@ -339,115 +349,3 @@ func TestClientIndex_RangeByName(t *testing.T) {
}) })
} }
} }
func TestIndex_FindByName(t *testing.T) {
const (
clientExistingName = "client_existing"
clientAnotherExistingName = "client_another_existing"
nonExistingClientName = "client_non_existing"
)
var (
clientExisting = &Persistent{
Name: clientExistingName,
IPs: []netip.Addr{netip.MustParseAddr("192.0.2.1")},
}
clientAnotherExisting = &Persistent{
Name: clientAnotherExistingName,
IPs: []netip.Addr{netip.MustParseAddr("192.0.2.2")},
}
)
clients := []*Persistent{
clientExisting,
clientAnotherExisting,
}
ci := newIDIndex(clients)
testCases := []struct {
want *Persistent
found assert.BoolAssertionFunc
name string
clientName string
}{{
want: clientExisting,
found: assert.True,
name: "existing",
clientName: clientExistingName,
}, {
want: clientAnotherExisting,
found: assert.True,
name: "another_existing",
clientName: clientAnotherExistingName,
}, {
want: nil,
found: assert.False,
name: "non_existing",
clientName: nonExistingClientName,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
c, ok := ci.findByName(tc.clientName)
assert.Equal(t, tc.want, c)
tc.found(t, ok)
})
}
}
func TestIndex_FindByMAC(t *testing.T) {
var (
cliMAC = errors.Must(net.ParseMAC("11:11:11:11:11:11"))
cliAnotherMAC = errors.Must(net.ParseMAC("22:22:22:22:22:22"))
nonExistingClientMAC = errors.Must(net.ParseMAC("33:33:33:33:33:33"))
)
var (
clientExisting = &Persistent{
Name: "client",
MACs: []net.HardwareAddr{cliMAC},
}
clientAnotherExisting = &Persistent{
Name: "another_client",
MACs: []net.HardwareAddr{cliAnotherMAC},
}
)
clients := []*Persistent{
clientExisting,
clientAnotherExisting,
}
ci := newIDIndex(clients)
testCases := []struct {
want *Persistent
found assert.BoolAssertionFunc
name string
clientMAC net.HardwareAddr
}{{
want: clientExisting,
found: assert.True,
name: "existing",
clientMAC: cliMAC,
}, {
want: clientAnotherExisting,
found: assert.True,
name: "another_existing",
clientMAC: cliAnotherMAC,
}, {
want: nil,
found: assert.False,
name: "non_existing",
clientMAC: nonExistingClientMAC,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
c, ok := ci.findByMAC(tc.clientMAC)
assert.Equal(t, tc.want, c)
tc.found(t, ok)
})
}
}

View File

@@ -15,6 +15,7 @@ import (
"github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/google/uuid" "github.com/google/uuid"
) )
@@ -70,9 +71,7 @@ type Persistent struct {
// Tags is a list of client tags that categorize the client. // Tags is a list of client tags that categorize the client.
Tags []string Tags []string
// Upstreams is a list of custom upstream DNS servers for the client. If // Upstreams is a list of custom upstream DNS servers for the client.
// it's empty, the custom upstream cache is disabled, regardless of the
// value of UpstreamsCacheEnabled.
Upstreams []string Upstreams []string
// IPs is a list of IP addresses that identify the client. The client must // IPs is a list of IP addresses that identify the client. The client must
@@ -91,16 +90,15 @@ type Persistent struct {
// ClientIDs identifying the client. The client must have at least one ID // ClientIDs identifying the client. The client must have at least one ID
// (IP, subnet, MAC, or ClientID). // (IP, subnet, MAC, or ClientID).
ClientIDs []ClientID ClientIDs []string
// UID is the unique identifier of the persistent client. // UID is the unique identifier of the persistent client.
UID UID UID UID
// UpstreamsCacheSize defines the size of the custom upstream cache. // UpstreamsCacheSize is the cache size for custom upstreams.
UpstreamsCacheSize uint32 UpstreamsCacheSize uint32
// UpstreamsCacheEnabled specifies whether the custom upstream cache is // UpstreamsCacheEnabled specifies whether custom upstreams are used.
// used. If true, the list of Upstreams should not be empty.
UpstreamsCacheEnabled bool UpstreamsCacheEnabled bool
// UseOwnSettings specifies whether custom filtering settings are used. // UseOwnSettings specifies whether custom filtering settings are used.
@@ -136,7 +134,7 @@ func (c *Persistent) validate(ctx context.Context, l *slog.Logger, allTags []str
switch { switch {
case c.Name == "": case c.Name == "":
return errors.Error("empty name") return errors.Error("empty name")
case c.idendifiersLen() == 0: case c.IDsLen() == 0:
return errors.Error("id required") return errors.Error("id required")
case c.UID == UID{}: case c.UID == UID{}:
return errors.Error("uid required") return errors.Error("uid required")
@@ -239,15 +237,28 @@ func (c *Persistent) setID(id string) (err error) {
return err return err
} }
c.ClientIDs = append(c.ClientIDs, ClientID(strings.ToLower(id))) c.ClientIDs = append(c.ClientIDs, strings.ToLower(id))
return nil return nil
} }
// Identifiers returns a list of client identifiers containing at least one // ValidateClientID returns an error if id is not a valid ClientID.
// element. //
func (c *Persistent) Identifiers() (ids []string) { // TODO(s.chzhen): It's an exact copy of the [dnsforward.ValidateClientID] to
ids = make([]string, 0, c.idendifiersLen()) // avoid the import cycle. Remove it.
func ValidateClientID(id string) (err error) {
err = netutil.ValidateHostnameLabel(id)
if err != nil {
// Replace the domain name label wrapper with our own.
return fmt.Errorf("invalid clientid %q: %w", id, errors.Unwrap(err))
}
return nil
}
// IDs returns a list of ClientIDs containing at least one element.
func (c *Persistent) IDs() (ids []string) {
ids = make([]string, 0, c.IDsLen())
for _, ip := range c.IPs { for _, ip := range c.IPs {
ids = append(ids, ip.String()) ids = append(ids, ip.String())
@@ -261,15 +272,11 @@ func (c *Persistent) Identifiers() (ids []string) {
ids = append(ids, mac.String()) ids = append(ids, mac.String())
} }
for _, cid := range c.ClientIDs { return append(ids, c.ClientIDs...)
ids = append(ids, string(cid))
}
return ids
} }
// identifiersLen returns the number of client identifiers. // IDsLen returns a length of ClientIDs.
func (c *Persistent) idendifiersLen() (n int) { func (c *Persistent) IDsLen() (n int) {
return len(c.IPs) + len(c.Subnets) + len(c.MACs) + len(c.ClientIDs) return len(c.IPs) + len(c.Subnets) + len(c.MACs) + len(c.ClientIDs)
} }

View File

@@ -7,7 +7,6 @@ import (
"net" "net"
"net/netip" "net/netip"
"slices" "slices"
"strings"
"sync" "sync"
"time" "time"
@@ -19,7 +18,6 @@ import (
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/hostsfile" "github.com/AdguardTeam/golibs/hostsfile"
"github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/timeutil" "github.com/AdguardTeam/golibs/timeutil"
) )
@@ -435,186 +433,48 @@ func (s *Storage) Add(ctx context.Context, p *Persistent) (err error) {
ctx, ctx,
"client added", "client added",
"name", p.Name, "name", p.Name,
"ids", p.Identifiers(), "ids", p.IDs(),
"clients_count", s.index.size(), "clients_count", s.index.size(),
) )
return nil return nil
} }
// FindParams represents the parameters for searching a client. At least one // FindByName finds persistent client by name. And returns its shallow copy.
// field must be non-empty. func (s *Storage) FindByName(name string) (p *Persistent, ok bool) {
type FindParams struct {
// ClientID is a unique identifier for the client used in DoH, DoT, and DoQ
// DNS queries.
ClientID ClientID
// RemoteIP is the IP address used as a client search parameter.
RemoteIP netip.Addr
// Subnet is the CIDR used as a client search parameter.
Subnet netip.Prefix
// MAC is the physical hardware address used as a client search parameter.
MAC net.HardwareAddr
// UID is the unique ID of persistent client used as a search parameter.
//
// TODO(s.chzhen): Use this.
UID UID
}
// ErrBadIdentifier is returned by [FindParams.Set] when it cannot parse the
// provided client identifier.
const ErrBadIdentifier errors.Error = "bad client identifier"
// Set clears the stored search parameters and parses the string representation
// of the search parameter into typed parameter, storing it. In some cases, it
// may result in storing both an IP address and a MAC address because they might
// have identical string representations. It returns [ErrBadIdentifier] if id
// cannot be parsed.
//
// TODO(s.chzhen): Add support for UID.
func (p *FindParams) Set(id string) (err error) {
*p = FindParams{}
isClientID := true
if netutil.IsValidIPString(id) {
// It is safe to use [netip.MustParseAddr] because it has already been
// validated that id contains the string representation of the IP
// address.
p.RemoteIP = netip.MustParseAddr(id)
// Even if id can be parsed as an IP address, it may be a MAC address.
// So do not return prematurely, continue parsing.
isClientID = false
}
if canBeValidIPPrefixString(id) {
p.Subnet, err = netip.ParsePrefix(id)
if err == nil {
isClientID = false
}
}
if canBeMACString(id) {
p.MAC, err = net.ParseMAC(id)
if err == nil {
isClientID = false
}
}
if !isClientID {
return nil
}
if !isValidClientID(id) {
return ErrBadIdentifier
}
p.ClientID = ClientID(id)
return nil
}
// canBeValidIPPrefixString is a best-effort check to determine if s is a valid
// CIDR before using [netip.ParsePrefix], aimed at reducing allocations.
//
// TODO(s.chzhen): Replace this implementation with the more robust version
// from golibs.
func canBeValidIPPrefixString(s string) (ok bool) {
ipStr, bitStr, ok := strings.Cut(s, "/")
if !ok {
return false
}
if bitStr == "" || len(bitStr) > 3 {
return false
}
bits := 0
for _, c := range bitStr {
if c < '0' || c > '9' {
return false
}
bits = bits*10 + int(c-'0')
}
if bits > 128 {
return false
}
return netutil.IsValidIPString(ipStr)
}
// canBeMACString is a best-effort check to determine if s is a valid MAC
// address before using [net.ParseMAC], aimed at reducing allocations.
//
// TODO(s.chzhen): Replace this implementation with the more robust version
// from golibs.
func canBeMACString(s string) (ok bool) {
switch len(s) {
case
len("0000.0000.0000"),
len("00:00:00:00:00:00"),
len("0000.0000.0000.0000"),
len("00:00:00:00:00:00:00:00"),
len("0000.0000.0000.0000.0000.0000.0000.0000.0000.0000"),
len("00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00"):
return true
default:
return false
}
}
// Find represents the parameters for searching a client. params must not be
// nil and must have at least one non-empty field.
func (s *Storage) Find(params *FindParams) (p *Persistent, ok bool) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
isClientID := params.ClientID != "" p, ok = s.index.findByName(name)
isRemoteIP := params.RemoteIP != (netip.Addr{}) if ok {
isSubnet := params.Subnet != (netip.Prefix{}) return p.ShallowClone(), ok
isMAC := params.MAC != nil
for {
switch {
case isClientID:
isClientID = false
p, ok = s.index.findByClientID(params.ClientID)
case isRemoteIP:
isRemoteIP = false
p, ok = s.findByIP(params.RemoteIP)
case isSubnet:
isSubnet = false
p, ok = s.index.findByCIDR(params.Subnet)
case isMAC:
isMAC = false
p, ok = s.index.findByMAC(params.MAC)
default:
return nil, false
}
if ok {
return p.ShallowClone(), true
}
} }
return nil, false
} }
// findByIP finds persistent client by IP address. s.mu is expected to be // Find finds persistent client by string representation of the ClientID, IP
// locked. // address, or MAC. And returns its shallow copy.
func (s *Storage) findByIP(addr netip.Addr) (p *Persistent, ok bool) { //
p, ok = s.index.findByIP(addr) // TODO(s.chzhen): Accept ClientIDData structure instead, which will contain
// the parsed IP address, if any.
func (s *Storage) Find(id string) (p *Persistent, ok bool) {
s.mu.Lock()
defer s.mu.Unlock()
p, ok = s.index.find(id)
if ok { if ok {
return p, true return p.ShallowClone(), ok
} }
foundMAC := s.dhcp.MACByIP(addr) ip, err := netip.ParseAddr(id)
if err != nil {
return nil, false
}
foundMAC := s.dhcp.MACByIP(ip)
if foundMAC != nil { if foundMAC != nil {
return s.index.findByMAC(foundMAC) return s.FindByMAC(foundMAC)
} }
return nil, false return nil, false
@@ -627,8 +487,6 @@ func (s *Storage) findByIP(addr netip.Addr) (p *Persistent, ok bool) {
// //
// Note that multiple clients can have the same IP address with different zones. // Note that multiple clients can have the same IP address with different zones.
// Therefore, the result of this method is indeterminate. // Therefore, the result of this method is indeterminate.
//
// TODO(s.chzhen): Consider accepting [FindParams].
func (s *Storage) FindLoose(ip netip.Addr, id string) (p *Persistent, ok bool) { func (s *Storage) FindLoose(ip netip.Addr, id string) (p *Persistent, ok bool) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
@@ -640,7 +498,7 @@ func (s *Storage) FindLoose(ip netip.Addr, id string) (p *Persistent, ok bool) {
foundMAC := s.dhcp.MACByIP(ip) foundMAC := s.dhcp.MACByIP(ip)
if foundMAC != nil { if foundMAC != nil {
return s.index.findByMAC(foundMAC) return s.FindByMAC(foundMAC)
} }
p = s.index.findByIPWithoutZone(ip) p = s.index.findByIPWithoutZone(ip)
@@ -651,6 +509,17 @@ func (s *Storage) FindLoose(ip netip.Addr, id string) (p *Persistent, ok bool) {
return nil, false return nil, false
} }
// FindByMAC finds persistent client by MAC and returns its shallow copy. s.mu
// is expected to be locked.
func (s *Storage) FindByMAC(mac net.HardwareAddr) (p *Persistent, ok bool) {
p, ok = s.index.findByMAC(mac)
if ok {
return p.ShallowClone(), ok
}
return nil, false
}
// RemoveByName removes persistent client information. ok is false if no such // RemoveByName removes persistent client information. ok is false if no such
// client exists by that name. // client exists by that name.
func (s *Storage) RemoveByName(ctx context.Context, name string) (ok bool) { func (s *Storage) RemoveByName(ctx context.Context, name string) (ok bool) {
@@ -779,9 +648,9 @@ func (s *Storage) CustomUpstreamConfig(
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
c, ok := s.index.findByClientID(ClientID(id)) c, ok := s.index.findByClientID(id)
if !ok { if !ok {
c, ok = s.findByIP(addr) c, ok = s.index.findByIP(addr)
} }
if !ok { if !ok {
@@ -813,7 +682,7 @@ func (s *Storage) ClearUpstreamCache() {
// ClientID or client IP address, and applies it to the filtering settings. // ClientID or client IP address, and applies it to the filtering settings.
// setts must not be nil. // setts must not be nil.
func (s *Storage) ApplyClientFiltering(id string, addr netip.Addr, setts *filtering.Settings) { func (s *Storage) ApplyClientFiltering(id string, addr netip.Addr, setts *filtering.Settings) {
c, ok := s.index.findByClientID(ClientID(id)) c, ok := s.index.findByClientID(id)
if !ok { if !ok {
c, ok = s.index.findByIP(addr) c, ok = s.index.findByIP(addr)
} }
@@ -821,7 +690,7 @@ func (s *Storage) ApplyClientFiltering(id string, addr netip.Addr, setts *filter
if !ok { if !ok {
foundMAC := s.dhcp.MACByIP(addr) foundMAC := s.dhcp.MACByIP(addr)
if foundMAC != nil { if foundMAC != nil {
c, ok = s.index.findByMAC(foundMAC) c, ok = s.FindByMAC(foundMAC)
} }
} }

View File

@@ -15,7 +15,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc" "github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/whois" "github.com/AdguardTeam/AdGuardHome/internal/whois"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/hostsfile" "github.com/AdguardTeam/golibs/hostsfile"
"github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil"
@@ -351,15 +350,15 @@ func TestClientsDHCP(t *testing.T) {
cliName1 = "one.dhcp" cliName1 = "one.dhcp"
cliIP2 = netip.MustParseAddr("2.2.2.2") cliIP2 = netip.MustParseAddr("2.2.2.2")
cliMAC2 = errors.Must(net.ParseMAC("22:22:22:22:22:22")) cliMAC2 = mustParseMAC("22:22:22:22:22:22")
cliName2 = "two.dhcp" cliName2 = "two.dhcp"
cliIP3 = netip.MustParseAddr("3.3.3.3") cliIP3 = netip.MustParseAddr("3.3.3.3")
cliMAC3 = errors.Must(net.ParseMAC("33:33:33:33:33:33")) cliMAC3 = mustParseMAC("33:33:33:33:33:33")
cliName3 = "three.dhcp" cliName3 = "three.dhcp"
prsCliIP = netip.MustParseAddr("4.3.2.1") prsCliIP = netip.MustParseAddr("4.3.2.1")
prsCliMAC = errors.Must(net.ParseMAC("AA:AA:AA:AA:AA:AA")) prsCliMAC = mustParseMAC("AA:AA:AA:AA:AA:AA")
prsCliName = "persistent.dhcp" prsCliName = "persistent.dhcp"
otherARPCliName = "other.arp" otherARPCliName = "other.arp"
@@ -520,11 +519,7 @@ func TestClientsDHCP(t *testing.T) {
}) })
require.NoError(t, err) require.NoError(t, err)
params := &client.FindParams{} prsCli, ok := storage.Find(prsCliIP.String())
err = params.Set(prsCliIP.String())
require.NoError(t, err)
prsCli, ok := storage.Find(params)
require.True(t, ok) require.True(t, ok)
assert.Equal(t, prsCliName, prsCli.Name) assert.Equal(t, prsCliName, prsCli.Name)
@@ -668,6 +663,17 @@ func newStorage(tb testing.TB, m []*client.Persistent) (s *client.Storage) {
return s return s
} }
// mustParseMAC is wrapper around [net.ParseMAC] that panics if there is an
// error.
func mustParseMAC(s string) (mac net.HardwareAddr) {
mac, err := net.ParseMAC(s)
if err != nil {
panic(err)
}
return mac
}
func TestStorage_Add(t *testing.T) { func TestStorage_Add(t *testing.T) {
const ( const (
existingName = "existing_name" existingName = "existing_name"
@@ -687,7 +693,7 @@ func TestStorage_Add(t *testing.T) {
Name: existingName, Name: existingName,
IPs: []netip.Addr{existingIP}, IPs: []netip.Addr{existingIP},
Subnets: []netip.Prefix{existingSubnet}, Subnets: []netip.Prefix{existingSubnet},
ClientIDs: []client.ClientID{existingClientID}, ClientIDs: []string{existingClientID},
UID: existingClientUID, UID: existingClientUID,
} }
@@ -755,7 +761,7 @@ func TestStorage_Add(t *testing.T) {
name: "duplicate_client_id", name: "duplicate_client_id",
cli: &client.Persistent{ cli: &client.Persistent{
Name: "duplicate_client_id", Name: "duplicate_client_id",
ClientIDs: []client.ClientID{existingClientID}, ClientIDs: []string{existingClientID},
UID: client.MustNewUID(), UID: client.MustNewUID(),
}, },
wantErrMsg: `adding client: another client "existing_name" ` + wantErrMsg: `adding client: another client "existing_name" ` +
@@ -892,12 +898,12 @@ func TestStorage_Find(t *testing.T) {
clientWithMAC = &client.Persistent{ clientWithMAC = &client.Persistent{
Name: "client_with_mac", Name: "client_with_mac",
MACs: []net.HardwareAddr{errors.Must(net.ParseMAC(cliMAC))}, MACs: []net.HardwareAddr{mustParseMAC(cliMAC)},
} }
clientWithID = &client.Persistent{ clientWithID = &client.Persistent{
Name: "client_with_id", Name: "client_with_id",
ClientIDs: []client.ClientID{cliID}, ClientIDs: []string{cliID},
} }
clientLinkLocal = &client.Persistent{ clientLinkLocal = &client.Persistent{
@@ -944,11 +950,7 @@ func TestStorage_Find(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
for _, id := range tc.ids { for _, id := range tc.ids {
params := &client.FindParams{} c, ok := s.Find(id)
err := params.Set(id)
require.NoError(t, err)
c, ok := s.Find(params)
require.True(t, ok) require.True(t, ok)
assert.Equal(t, tc.want, c) assert.Equal(t, tc.want, c)
@@ -957,11 +959,7 @@ func TestStorage_Find(t *testing.T) {
} }
t.Run("not_found", func(t *testing.T) { t.Run("not_found", func(t *testing.T) {
params := &client.FindParams{} _, ok := s.Find(cliIPNone)
err := params.Set(cliIPNone)
require.NoError(t, err)
_, ok := s.Find(params)
assert.False(t, ok) assert.False(t, ok)
}) })
} }
@@ -1027,6 +1025,127 @@ func TestStorage_FindLoose(t *testing.T) {
} }
} }
func TestStorage_FindByName(t *testing.T) {
const (
cliIP1 = "1.1.1.1"
cliIP2 = "2.2.2.2"
)
const (
clientExistingName = "client_existing"
clientAnotherExistingName = "client_another_existing"
nonExistingClientName = "client_non_existing"
)
var (
clientExisting = &client.Persistent{
Name: clientExistingName,
IPs: []netip.Addr{netip.MustParseAddr(cliIP1)},
}
clientAnotherExisting = &client.Persistent{
Name: clientAnotherExistingName,
IPs: []netip.Addr{netip.MustParseAddr(cliIP2)},
}
)
clients := []*client.Persistent{
clientExisting,
clientAnotherExisting,
}
s := newStorage(t, clients)
testCases := []struct {
want *client.Persistent
name string
clientName string
}{{
name: "existing",
clientName: clientExistingName,
want: clientExisting,
}, {
name: "another_existing",
clientName: clientAnotherExistingName,
want: clientAnotherExisting,
}, {
name: "non_existing",
clientName: nonExistingClientName,
want: nil,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
c, ok := s.FindByName(tc.clientName)
if tc.want == nil {
assert.False(t, ok)
return
}
assert.True(t, ok)
assert.Equal(t, tc.want, c)
})
}
}
func TestStorage_FindByMAC(t *testing.T) {
var (
cliMAC = mustParseMAC("11:11:11:11:11:11")
cliAnotherMAC = mustParseMAC("22:22:22:22:22:22")
nonExistingClientMAC = mustParseMAC("33:33:33:33:33:33")
)
var (
clientExisting = &client.Persistent{
Name: "client",
MACs: []net.HardwareAddr{cliMAC},
}
clientAnotherExisting = &client.Persistent{
Name: "another_client",
MACs: []net.HardwareAddr{cliAnotherMAC},
}
)
clients := []*client.Persistent{
clientExisting,
clientAnotherExisting,
}
s := newStorage(t, clients)
testCases := []struct {
want *client.Persistent
name string
clientMAC net.HardwareAddr
}{{
name: "existing",
clientMAC: cliMAC,
want: clientExisting,
}, {
name: "another_existing",
clientMAC: cliAnotherMAC,
want: clientAnotherExisting,
}, {
name: "non_existing",
clientMAC: nonExistingClientMAC,
want: nil,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
c, ok := s.FindByMAC(tc.clientMAC)
if tc.want == nil {
assert.False(t, ok)
return
}
assert.True(t, ok)
assert.Equal(t, tc.want, c)
})
}
}
func TestStorage_Update(t *testing.T) { func TestStorage_Update(t *testing.T) {
const ( const (
clientName = "client_name" clientName = "client_name"
@@ -1043,7 +1162,7 @@ func TestStorage_Update(t *testing.T) {
Name: obstructingName, Name: obstructingName,
IPs: []netip.Addr{obstructingIP}, IPs: []netip.Addr{obstructingIP},
Subnets: []netip.Prefix{obstructingSubnet}, Subnets: []netip.Prefix{obstructingSubnet},
ClientIDs: []client.ClientID{obstructingClientID}, ClientIDs: []string{obstructingClientID},
} }
clientToUpdate := &client.Persistent{ clientToUpdate := &client.Persistent{
@@ -1092,7 +1211,7 @@ func TestStorage_Update(t *testing.T) {
name: "duplicate_client_id", name: "duplicate_client_id",
cli: &client.Persistent{ cli: &client.Persistent{
Name: "duplicate_client_id", Name: "duplicate_client_id",
ClientIDs: []client.ClientID{obstructingClientID}, ClientIDs: []string{obstructingClientID},
UID: client.MustNewUID(), UID: client.MustNewUID(),
}, },
wantErrMsg: `updating client: another client "obstructing_name" ` + wantErrMsg: `updating client: another client "obstructing_name" ` +
@@ -1119,19 +1238,19 @@ func TestStorage_Update(t *testing.T) {
func TestStorage_RangeByName(t *testing.T) { func TestStorage_RangeByName(t *testing.T) {
sortedClients := []*client.Persistent{{ sortedClients := []*client.Persistent{{
Name: "clientA", Name: "clientA",
ClientIDs: []client.ClientID{"A"}, ClientIDs: []string{"A"},
}, { }, {
Name: "clientB", Name: "clientB",
ClientIDs: []client.ClientID{"B"}, ClientIDs: []string{"B"},
}, { }, {
Name: "clientC", Name: "clientC",
ClientIDs: []client.ClientID{"C"}, ClientIDs: []string{"C"},
}, { }, {
Name: "clientD", Name: "clientD",
ClientIDs: []client.ClientID{"D"}, ClientIDs: []string{"D"},
}, { }, {
Name: "clientE", Name: "clientE",
ClientIDs: []client.ClientID{"E"}, ClientIDs: []string{"E"},
}} }}
testCases := []struct { testCases := []struct {
@@ -1169,20 +1288,29 @@ func TestStorage_RangeByName(t *testing.T) {
func TestStorage_CustomUpstreamConfig(t *testing.T) { func TestStorage_CustomUpstreamConfig(t *testing.T) {
const ( const (
existingClientID = "existing_client_id" existingName = "existing_name"
existingClientID = "existing_client_id"
nonExistingClientID = "non_existing_client_id" nonExistingClientID = "non_existing_client_id"
) )
var ( var (
existingIP = netip.MustParseAddr("192.0.2.1") existingClientUID = client.MustNewUID()
nonExistingIP = netip.MustParseAddr("192.0.2.255") existingIP = netip.MustParseAddr("192.0.2.1")
dhcpCliIP = netip.MustParseAddr("192.0.2.2") nonExistingIP = netip.MustParseAddr("192.0.2.255")
dhcpCliMAC = errors.Must(net.ParseMAC("02:00:00:00:00:00"))
testUpstreamTimeout = time.Second testUpstreamTimeout = time.Second
) )
existingClient := &client.Persistent{
Name: existingName,
IPs: []netip.Addr{existingIP},
ClientIDs: []string{existingClientID},
UID: existingClientUID,
Upstreams: []string{"192.0.2.0"},
}
date := time.Now() date := time.Now()
clock := &faketime.Clock{ clock := &faketime.Clock{
OnNow: func() (now time.Time) { OnNow: func() (now time.Time) {
@@ -1192,30 +1320,7 @@ func TestStorage_CustomUpstreamConfig(t *testing.T) {
}, },
} }
ipToMAC := map[netip.Addr]net.HardwareAddr{ s := newTestStorage(t, clock)
dhcpCliIP: dhcpCliMAC,
}
dhcp := &testDHCP{
OnLeases: func() (ls []*dhcpsvc.Lease) {
panic("not implemented")
},
OnHostBy: func(ip netip.Addr) (host string) {
panic("not implemented")
},
OnMACBy: func(ip netip.Addr) (mac net.HardwareAddr) {
return ipToMAC[ip]
},
}
ctx := testutil.ContextWithTimeout(t, testTimeout)
s, err := client.NewStorage(ctx, &client.StorageConfig{
Logger: slogutil.NewDiscardLogger(),
Clock: clock,
DHCP: dhcp,
})
require.NoError(t, err)
s.UpdateCommonUpstreamConfig(&client.CommonUpstreamConfig{ s.UpdateCommonUpstreamConfig(&client.CommonUpstreamConfig{
UpstreamTimeout: testUpstreamTimeout, UpstreamTimeout: testUpstreamTimeout,
}) })
@@ -1224,21 +1329,8 @@ func TestStorage_CustomUpstreamConfig(t *testing.T) {
return s.Shutdown(testutil.ContextWithTimeout(t, testTimeout)) return s.Shutdown(testutil.ContextWithTimeout(t, testTimeout))
}) })
err = s.Add(ctx, &client.Persistent{ ctx := testutil.ContextWithTimeout(t, testTimeout)
Name: "client_first", err := s.Add(ctx, existingClient)
IPs: []netip.Addr{existingIP},
ClientIDs: []client.ClientID{existingClientID},
UID: client.MustNewUID(),
Upstreams: []string{"192.0.2.0"},
})
require.NoError(t, err)
err = s.Add(ctx, &client.Persistent{
Name: "client_second",
MACs: []net.HardwareAddr{dhcpCliMAC},
UID: client.MustNewUID(),
Upstreams: []string{"192.0.2.0"},
})
require.NoError(t, err) require.NoError(t, err)
testCases := []struct { testCases := []struct {
@@ -1256,11 +1348,6 @@ func TestStorage_CustomUpstreamConfig(t *testing.T) {
cliID: "", cliID: "",
cliAddr: existingIP, cliAddr: existingIP,
wantNilConf: assert.NotNil, wantNilConf: assert.NotNil,
}, {
name: "client_dhcp",
cliID: "",
cliAddr: dhcpCliIP,
wantNilConf: assert.NotNil,
}, { }, {
name: "non_existing_client_id", name: "non_existing_client_id",
cliID: nonExistingClientID, cliID: nonExistingClientID,
@@ -1293,193 +1380,4 @@ func TestStorage_CustomUpstreamConfig(t *testing.T) {
assert.NotEqual(t, conf, updConf) assert.NotEqual(t, conf, updConf)
}) })
t.Run("same_custom_config", func(t *testing.T) {
firstConf := s.CustomUpstreamConfig(existingClientID, existingIP)
require.NotNil(t, firstConf)
secondConf := s.CustomUpstreamConfig(existingClientID, existingIP)
require.NotNil(t, secondConf)
assert.Same(t, firstConf, secondConf)
})
}
func BenchmarkFindParams_Set(b *testing.B) {
const (
testIPStr = "192.0.2.1"
testCIDRStr = "192.0.2.0/24"
testMACStr = "02:00:00:00:00:00"
testClientID = "clientid"
)
benchCases := []struct {
wantErr error
params *client.FindParams
name string
id string
}{{
wantErr: nil,
params: &client.FindParams{
ClientID: testClientID,
},
name: "client_id",
id: testClientID,
}, {
wantErr: nil,
params: &client.FindParams{
RemoteIP: netip.MustParseAddr(testIPStr),
},
name: "ip_address",
id: testIPStr,
}, {
wantErr: nil,
params: &client.FindParams{
Subnet: netip.MustParsePrefix(testCIDRStr),
},
name: "subnet",
id: testCIDRStr,
}, {
wantErr: nil,
params: &client.FindParams{
MAC: errors.Must(net.ParseMAC(testMACStr)),
},
name: "mac_address",
id: testMACStr,
}, {
wantErr: client.ErrBadIdentifier,
params: &client.FindParams{},
name: "bad_id",
id: "!@#$%^&*()_+",
}}
for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) {
params := &client.FindParams{}
var err error
b.ReportAllocs()
for b.Loop() {
err = params.Set(bc.id)
}
assert.ErrorIs(b, err, bc.wantErr)
assert.Equal(b, bc.params, params)
})
}
// Most recent results:
//
// goos: linux
// goarch: amd64
// pkg: github.com/AdguardTeam/AdGuardHome/internal/client
// cpu: Intel(R) Core(TM) i7-10510U CPU @ 1.80GHz
// BenchmarkFindParams_Set/client_id-8 49463488 24.27 ns/op 0 B/op 0 allocs/op
// BenchmarkFindParams_Set/ip_address-8 18740977 62.22 ns/op 0 B/op 0 allocs/op
// BenchmarkFindParams_Set/subnet-8 10848192 110.0 ns/op 0 B/op 0 allocs/op
// BenchmarkFindParams_Set/mac_address-8 8148494 133.2 ns/op 8 B/op 1 allocs/op
// BenchmarkFindParams_Set/bad_id-8 73894278 16.29 ns/op 0 B/op 0 allocs/op
}
func BenchmarkStorage_Find(b *testing.B) {
const (
cliID = "cid"
cliMAC = "02:00:00:00:00:00"
)
const (
cliNameWithID = "client_with_id"
cliNameWithIP = "client_with_ip"
cliNameWithCIDR = "client_with_cidr"
cliNameWithMAC = "client_with_mac"
)
var (
cliIP = netip.MustParseAddr("192.0.2.1")
cliCIDR = netip.MustParsePrefix("192.0.2.0/24")
)
var (
clientWithID = &client.Persistent{
Name: cliNameWithID,
ClientIDs: []client.ClientID{cliID},
}
clientWithIP = &client.Persistent{
Name: cliNameWithIP,
IPs: []netip.Addr{cliIP},
}
clientWithCIDR = &client.Persistent{
Name: cliNameWithCIDR,
Subnets: []netip.Prefix{cliCIDR},
}
clientWithMAC = &client.Persistent{
Name: cliNameWithMAC,
MACs: []net.HardwareAddr{errors.Must(net.ParseMAC(cliMAC))},
}
)
clients := []*client.Persistent{
clientWithID,
clientWithIP,
clientWithCIDR,
clientWithMAC,
}
s := newStorage(b, clients)
benchCases := []struct {
params *client.FindParams
name string
wantName string
}{{
params: &client.FindParams{
ClientID: cliID,
},
name: "client_id",
wantName: cliNameWithID,
}, {
params: &client.FindParams{
RemoteIP: cliIP,
},
name: "ip_address",
wantName: cliNameWithIP,
}, {
params: &client.FindParams{
Subnet: cliCIDR,
},
name: "subnet",
wantName: cliNameWithCIDR,
}, {
params: &client.FindParams{
MAC: errors.Must(net.ParseMAC(cliMAC)),
},
name: "mac_address",
wantName: cliNameWithMAC,
}}
for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) {
var p *client.Persistent
var ok bool
b.ReportAllocs()
for b.Loop() {
p, ok = s.Find(bc.params)
}
assert.True(b, ok)
assert.NotNil(b, p)
assert.Equal(b, bc.wantName, p.Name)
})
}
// Most recent results:
//
// goos: linux
// goarch: amd64
// pkg: github.com/AdguardTeam/AdGuardHome/internal/client
// cpu: Intel(R) Core(TM) i7-10510U CPU @ 1.80GHz
// BenchmarkStorage_Find/client_id-8 7070107 154.4 ns/op 240 B/op 2 allocs/op
// BenchmarkStorage_Find/ip_address-8 6831823 168.6 ns/op 248 B/op 2 allocs/op
// BenchmarkStorage_Find/subnet-8 7209050 167.5 ns/op 256 B/op 2 allocs/op
// BenchmarkStorage_Find/mac_address-8 5776131 199.7 ns/op 256 B/op 3 allocs/op
} }

View File

@@ -138,7 +138,6 @@ func (m *upstreamManager) customUpstreamConfig(uid UID) (proxyConf *proxy.Custom
proxyConf = newCustomUpstreamConfig(cliConf, m.commonConf) proxyConf = newCustomUpstreamConfig(cliConf, m.commonConf)
cliConf.proxyConf = proxyConf cliConf.proxyConf = proxyConf
cliConf.commonConfUpdate = m.confUpdate
cliConf.isChanged = false cliConf.isChanged = false
return proxyConf return proxyConf

View File

@@ -1,11 +1,13 @@
package dhcpsvc_test package dhcpsvc_test
import ( import (
"net"
"net/netip" "net/netip"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc" "github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
"github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/stretchr/testify/require"
) )
// testLocalTLD is a common local TLD for tests. // testLocalTLD is a common local TLD for tests.
@@ -54,3 +56,11 @@ var testInterfaceConf = map[string]*dhcpsvc.InterfaceConfig{
}, },
}, },
} }
// mustParseMAC parses a hardware address from s and requires no errors.
func mustParseMAC(t require.TestingT, s string) (mac net.HardwareAddr) {
mac, err := net.ParseMAC(s)
require.NoError(t, err)
return mac
}

View File

@@ -1,127 +0,0 @@
package dhcpsvc
import (
"context"
"fmt"
"net/netip"
"github.com/AdguardTeam/golibs/errors"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
)
// serveV4 handles the ethernet packet of IPv4 type.
func (srv *DHCPServer) serveV4(
ctx context.Context,
rw responseWriter4,
pkt gopacket.Packet,
) (err error) {
defer func() { err = errors.Annotate(err, "serving dhcpv4: %w") }()
req, ok := pkt.Layer(layers.LayerTypeDHCPv4).(*layers.DHCPv4)
if !ok {
srv.logger.DebugContext(ctx, "skipping non-dhcpv4 packet")
return nil
}
// TODO(e.burkov): Handle duplicate Xid.
if req.Operation != layers.DHCPOpRequest {
srv.logger.DebugContext(ctx, "skipping non-request dhcpv4 packet")
return nil
}
typ, ok := msg4Type(req)
if !ok {
// The "DHCP message type" option - must be included in every DHCP
// message.
//
// See https://datatracker.ietf.org/doc/html/rfc2131#section-3.
return fmt.Errorf("dhcpv4: message type: %w", errors.ErrNoValue)
}
return srv.handleDHCPv4(ctx, rw, typ, req)
}
// handleDHCPv4 handles the DHCPv4 message of the given type.
func (srv *DHCPServer) handleDHCPv4(
ctx context.Context,
rw responseWriter4,
typ layers.DHCPMsgType,
req *layers.DHCPv4,
) (err error) {
// Each interface should handle the DISCOVER and REQUEST messages offer and
// allocate the available leases. The RELEASE and DECLINE messages should
// be handled by the server itself as it should remove the lease.
switch typ {
case layers.DHCPMsgTypeDiscover:
srv.handleDiscover(ctx, rw, req)
case layers.DHCPMsgTypeRequest:
srv.handleRequest(ctx, rw, req)
case layers.DHCPMsgTypeRelease:
// TODO(e.burkov): !! Remove the lease, either allocated or offered.
case layers.DHCPMsgTypeDecline:
// TODO(e.burkov): !! Remove the allocated lease. RFC tells it only
// possible if the client found the address already in use.
default:
// TODO(e.burkov): Handle DHCPINFORM.
return fmt.Errorf("dhcpv4: request type: %w: %v", errors.ErrBadEnumValue, typ)
}
return nil
}
// handleDiscover handles the DHCPv4 message of discover type.
func (srv *DHCPServer) handleDiscover(ctx context.Context, rw responseWriter4, req *layers.DHCPv4) {
// TODO(e.burkov): Check existing leases, either allocated or offered.
for _, iface := range srv.interfaces4 {
go iface.handleDiscover(ctx, rw, req)
}
}
// handleRequest handles the DHCPv4 message of request type.
func (srv *DHCPServer) handleRequest(ctx context.Context, rw responseWriter4, req *layers.DHCPv4) {
srvID, hasSrvID := serverID4(req)
reqIP, hasReqIP := requestedIPv4(req)
switch {
case hasSrvID && !srvID.IsUnspecified():
// If the DHCPREQUEST message contains a server identifier option, the
// message is in response to a DHCPOFFER message. Otherwise, the
// message is a request to verify or extend an existing lease.
iface, hasIface := srv.interfaces4.findInterface(srvID)
if !hasIface {
srv.logger.DebugContext(ctx, "skipping selecting request", "serverid", srvID)
return
}
iface.handleSelecting(ctx, rw, req, reqIP)
case hasReqIP && !reqIP.IsUnspecified():
// Requested IP address option MUST be filled in with client's notion of
// its previously assigned address.
iface, hasIface := srv.interfaces4.findInterface(reqIP)
if !hasIface {
srv.logger.DebugContext(ctx, "skipping init-reboot request", "requestedip", reqIP)
return
}
iface.handleInitReboot(ctx, rw, req, reqIP)
default:
// Server identifier MUST NOT be filled in, requested IP address option
// MUST NOT be filled in.
ip, _ := netip.AddrFromSlice(req.ClientIP.To4())
iface, hasIface := srv.interfaces4.findInterface(ip)
if !hasIface {
srv.logger.DebugContext(ctx, "skipping init-reboot request", "clientip", ip)
return
}
iface.handleRenew(ctx, rw, req)
}
}

View File

@@ -1,57 +0,0 @@
package dhcpsvc
import (
"context"
"fmt"
"github.com/AdguardTeam/golibs/errors"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
)
// serveV6 handles the ethernet packet of IPv6 type.
func (srv *DHCPServer) serveV6(
ctx context.Context,
rw responseWriter4,
pkt gopacket.Packet,
) (err error) {
defer func() { err = errors.Annotate(err, "serving dhcpv6: %w") }()
msg, ok := pkt.Layer(layers.LayerTypeDHCPv6).(*layers.DHCPv6)
if !ok {
srv.logger.DebugContext(ctx, "skipping non-dhcpv6 packet")
return nil
}
// TODO(e.burkov): Handle duplicate TransactionID.
typ := msg.MsgType
return srv.handleDHCPv6(ctx, rw, typ, msg)
}
// handleDHCPv6 handles the DHCPv6 message of the given type.
func (srv *DHCPServer) handleDHCPv6(
_ context.Context,
_ responseWriter4,
typ layers.DHCPv6MsgType,
_ *layers.DHCPv6,
) (err error) {
switch typ {
case
layers.DHCPv6MsgTypeSolicit,
layers.DHCPv6MsgTypeRequest,
layers.DHCPv6MsgTypeConfirm,
layers.DHCPv6MsgTypeRenew,
layers.DHCPv6MsgTypeRebind,
layers.DHCPv6MsgTypeInformationRequest,
layers.DHCPv6MsgTypeRelease,
layers.DHCPv6MsgTypeDecline:
// TODO(e.burkov): Handle messages.
default:
return fmt.Errorf("dhcpv6: request type: %w: %v", errors.ErrBadEnumValue, typ)
}
return nil
}

View File

@@ -45,6 +45,17 @@ type netInterface struct {
leaseTTL time.Duration leaseTTL time.Duration
} }
// newNetInterface creates a new netInterface with the given name, leaseTTL, and
// logger.
func newNetInterface(name string, l *slog.Logger, leaseTTL time.Duration) (iface *netInterface) {
return &netInterface{
logger: l,
leases: map[macKey]*Lease{},
name: name,
leaseTTL: leaseTTL,
}
}
// reset clears all the slices in iface for reuse. // reset clears all the slices in iface for reuse.
func (iface *netInterface) reset() { func (iface *netInterface) reset() {
clear(iface.leases) clear(iface.leases)

View File

@@ -1,50 +0,0 @@
package dhcpsvc
import (
"context"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/google/gopacket/layers"
)
// responseWriter4 writes DHCPv4 response to the client.
type responseWriter4 interface {
// write writes the DHCPv4 response to the client.
write(ctx context.Context, pkt *layers.DHCPv4) (err error)
}
// serve handles the incoming packets and dispatches them to the appropriate
// handler based on the Ethernet type. It's used to run in a separate goroutine
// as it blocks until packets channel is closed.
func (srv *DHCPServer) serve(ctx context.Context) {
defer slogutil.RecoverAndLog(ctx, srv.logger)
for pkt := range srv.packetSource.Packets() {
etherLayer, ok := pkt.Layer(layers.LayerTypeEthernet).(*layers.Ethernet)
if !ok {
srv.logger.DebugContext(ctx, "skipping non-ethernet packet")
continue
}
var err error
// TODO(e.burkov): Set the response writer.
var rw responseWriter4
switch typ := etherLayer.EthernetType; typ {
case layers.EthernetTypeIPv4:
err = srv.serveV4(ctx, rw, pkt)
case layers.EthernetTypeIPv6:
err = srv.serveV6(ctx, rw, pkt)
default:
srv.logger.DebugContext(ctx, "skipping ethernet packet", "type", typ)
continue
}
if err != nil {
srv.logger.ErrorContext(ctx, "serving", slogutil.KeyError, err)
}
}
}

View File

@@ -13,13 +13,9 @@ import (
"time" "time"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
"github.com/google/gopacket"
) )
// DHCPServer is a DHCP server for both IPv4 and IPv6 address families. // DHCPServer is a DHCP server for both IPv4 and IPv6 address families.
//
// TODO(e.burkov): Rename to Default.
type DHCPServer struct { type DHCPServer struct {
// enabled indicates whether the DHCP server is enabled and can provide // enabled indicates whether the DHCP server is enabled and can provide
// information about its clients. // information about its clients.
@@ -28,9 +24,6 @@ type DHCPServer struct {
// logger logs common DHCP events. // logger logs common DHCP events.
logger *slog.Logger logger *slog.Logger
// TODO(e.burkov): Implement and set.
packetSource gopacket.PacketSource
// localTLD is the top-level domain name to use for resolving DHCP clients' // localTLD is the top-level domain name to use for resolving DHCP clients'
// hostnames. // hostnames.
localTLD string localTLD string
@@ -105,7 +98,7 @@ func New(ctx context.Context, conf *Config) (srv *DHCPServer, err error) {
// their configurations. // their configurations.
func newInterfaces( func newInterfaces(
ctx context.Context, ctx context.Context,
baseLogger *slog.Logger, l *slog.Logger,
ifaces map[string]*InterfaceConfig, ifaces map[string]*InterfaceConfig,
) (v4 dhcpInterfacesV4, v6 dhcpInterfacesV6, err error) { ) (v4 dhcpInterfacesV4, v6 dhcpInterfacesV6, err error) {
defer func() { err = errors.Annotate(err, "creating interfaces: %w") }() defer func() { err = errors.Annotate(err, "creating interfaces: %w") }()
@@ -117,27 +110,18 @@ func newInterfaces(
var errs []error var errs []error
for _, name := range slices.Sorted(maps.Keys(ifaces)) { for _, name := range slices.Sorted(maps.Keys(ifaces)) {
iface := ifaces[name] iface := ifaces[name]
var i4 *dhcpInterfaceV4
iface4, v4Err := newDHCPInterfaceV4( i4, err = newDHCPInterfaceV4(ctx, l, name, iface.IPv4)
ctx, if err != nil {
baseLogger.With(keyInterface, name, keyFamily, netutil.AddrFamilyIPv4), errs = append(errs, fmt.Errorf("interface %q: ipv4: %w", name, err))
name, } else if i4 != nil {
iface.IPv4, v4 = append(v4, i4)
)
if v4Err != nil {
v4Err = fmt.Errorf("interface %q: %s: %w", name, netutil.AddrFamilyIPv4, v4Err)
errs = append(errs, v4Err)
} else {
v4 = append(v4, iface4)
} }
iface6 := newDHCPInterfaceV6( i6 := newDHCPInterfaceV6(ctx, l, name, iface.IPv6)
ctx, if i6 != nil {
baseLogger.With(keyInterface, name, keyFamily, netutil.AddrFamilyIPv6), v6 = append(v6, i6)
name, }
iface.IPv6,
)
v6 = append(v6, iface6)
} }
if err = errors.Join(errs...); err != nil { if err = errors.Join(errs...); err != nil {
@@ -152,25 +136,6 @@ func newInterfaces(
// TODO(e.burkov): Uncomment when the [Interface] interface is implemented. // TODO(e.burkov): Uncomment when the [Interface] interface is implemented.
// var _ Interface = (*DHCPServer)(nil) // var _ Interface = (*DHCPServer)(nil)
// Start implements the [Interface] interface for *DHCPServer.
func (srv *DHCPServer) Start(ctx context.Context) (err error) {
srv.logger.DebugContext(ctx, "starting dhcp server")
// TODO(e.burkov): Listen to configured interfaces.
go srv.serve(context.Background())
return nil
}
func (srv *DHCPServer) Shutdown(ctx context.Context) (err error) {
srv.logger.DebugContext(ctx, "shutting down dhcp server")
// TODO(e.burkov): Close the packet source.
return nil
}
// Enabled implements the [Interface] interface for *DHCPServer. // Enabled implements the [Interface] interface for *DHCPServer.
func (srv *DHCPServer) Enabled() (ok bool) { func (srv *DHCPServer) Enabled() (ok bool) {
return srv.enabled.Load() return srv.enabled.Load()
@@ -370,50 +335,6 @@ func (srv *DHCPServer) RemoveLease(ctx context.Context, l *Lease) (err error) {
return nil return nil
} }
// removeLeaseByAddr removes the lease with the given IP address from the
// server. It returns an error if the lease can't be removed.
//
// TODO(e.burkov): !! Use.
func (srv *DHCPServer) removeLeaseByAddr(ctx context.Context, addr netip.Addr) (err error) {
defer func() { err = errors.Annotate(err, "removing lease by address: %w") }()
iface, err := srv.ifaceForAddr(addr)
if err != nil {
// Don't wrap the error since it's already informative enough as is.
return err
}
srv.leasesMu.Lock()
defer srv.leasesMu.Unlock()
l, ok := srv.leases.leaseByAddr(addr)
if !ok {
return fmt.Errorf("no lease for ip %s", addr)
}
err = srv.leases.remove(l, iface)
if err != nil {
// Don't wrap the error since there is already an annotation deferred.
return err
}
err = srv.dbStore(ctx)
if err != nil {
// Don't wrap the error since it's already informative enough as is.
return err
}
iface.logger.DebugContext(
ctx, "removed lease",
"hostname", l.Hostname,
"ip", l.IP,
"mac", l.HWAddr,
"static", l.IsStatic,
)
return nil
}
// ifaceForAddr returns the handled network interface for the given IP address, // ifaceForAddr returns the handled network interface for the given IP address,
// or an error if no such interface exists. // or an error if no such interface exists.
func (srv *DHCPServer) ifaceForAddr(addr netip.Addr) (iface *netInterface, err error) { func (srv *DHCPServer) ifaceForAddr(addr netip.Addr) (iface *netInterface, err error) {

View File

@@ -2,7 +2,6 @@ package dhcpsvc_test
import ( import (
"io/fs" "io/fs"
"net"
"net/netip" "net/netip"
"os" "os"
"path" "path"
@@ -12,7 +11,6 @@ import (
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc" "github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -178,9 +176,9 @@ func TestDHCPServer_AddLease(t *testing.T) {
newIP = netip.MustParseAddr("192.168.0.3") newIP = netip.MustParseAddr("192.168.0.3")
newIPv6 = netip.MustParseAddr("2001:db8::2") newIPv6 = netip.MustParseAddr("2001:db8::2")
existMAC = errors.Must(net.ParseMAC("01:02:03:04:05:06")) existMAC = mustParseMAC(t, "01:02:03:04:05:06")
newMAC = errors.Must(net.ParseMAC("06:05:04:03:02:01")) newMAC = mustParseMAC(t, "06:05:04:03:02:01")
ipv6MAC = errors.Must(net.ParseMAC("02:03:04:05:06:07")) ipv6MAC = mustParseMAC(t, "02:03:04:05:06:07")
) )
require.NoError(t, srv.AddLease(ctx, &dhcpsvc.Lease{ require.NoError(t, srv.AddLease(ctx, &dhcpsvc.Lease{
@@ -293,9 +291,9 @@ func TestDHCPServer_index(t *testing.T) {
ip3 = netip.MustParseAddr("172.16.0.3") ip3 = netip.MustParseAddr("172.16.0.3")
ip4 = netip.MustParseAddr("172.16.0.4") ip4 = netip.MustParseAddr("172.16.0.4")
mac1 = errors.Must(net.ParseMAC("01:02:03:04:05:06")) mac1 = mustParseMAC(t, "01:02:03:04:05:06")
mac2 = errors.Must(net.ParseMAC("06:05:04:03:02:01")) mac2 = mustParseMAC(t, "06:05:04:03:02:01")
mac3 = errors.Must(net.ParseMAC("02:03:04:05:06:07")) mac3 = mustParseMAC(t, "02:03:04:05:06:07")
) )
t.Run("ip_idx", func(t *testing.T) { t.Run("ip_idx", func(t *testing.T) {
@@ -351,9 +349,9 @@ func TestDHCPServer_UpdateStaticLease(t *testing.T) {
ip3 = netip.MustParseAddr("192.168.0.4") ip3 = netip.MustParseAddr("192.168.0.4")
ip4 = netip.MustParseAddr("2001:db8::3") ip4 = netip.MustParseAddr("2001:db8::3")
mac1 = errors.Must(net.ParseMAC("01:02:03:04:05:06")) mac1 = mustParseMAC(t, "01:02:03:04:05:06")
mac2 = errors.Must(net.ParseMAC("06:05:04:03:02:01")) mac2 = mustParseMAC(t, "06:05:04:03:02:01")
mac3 = errors.Must(net.ParseMAC("06:05:04:03:02:02")) mac3 = mustParseMAC(t, "06:05:04:03:02:02")
) )
testCases := []struct { testCases := []struct {
@@ -454,9 +452,9 @@ func TestDHCPServer_RemoveLease(t *testing.T) {
newIP = netip.MustParseAddr("192.168.0.3") newIP = netip.MustParseAddr("192.168.0.3")
newIPv6 = netip.MustParseAddr("2001:db8::2") newIPv6 = netip.MustParseAddr("2001:db8::2")
existMAC = errors.Must(net.ParseMAC("01:02:03:04:05:06")) existMAC = mustParseMAC(t, "01:02:03:04:05:06")
newMAC = errors.Must(net.ParseMAC("02:03:04:05:06:07")) newMAC = mustParseMAC(t, "02:03:04:05:06:07")
ipv6MAC = errors.Must(net.ParseMAC("06:05:04:03:02:01")) ipv6MAC = mustParseMAC(t, "06:05:04:03:02:01")
) )
testCases := []struct { testCases := []struct {
@@ -561,13 +559,13 @@ func TestServer_Leases(t *testing.T) {
Expiry: expiry, Expiry: expiry,
IP: netip.MustParseAddr("192.168.0.3"), IP: netip.MustParseAddr("192.168.0.3"),
Hostname: "example.host", Hostname: "example.host",
HWAddr: errors.Must(net.ParseMAC("AA:AA:AA:AA:AA:AA")), HWAddr: mustParseMAC(t, "AA:AA:AA:AA:AA:AA"),
IsStatic: false, IsStatic: false,
}, { }, {
Expiry: time.Time{}, Expiry: time.Time{},
IP: netip.MustParseAddr("192.168.0.4"), IP: netip.MustParseAddr("192.168.0.4"),
Hostname: "example.static.host", Hostname: "example.static.host",
HWAddr: errors.Must(net.ParseMAC("BB:BB:BB:BB:BB:BB")), HWAddr: mustParseMAC(t, "BB:BB:BB:BB:BB:BB"),
IsStatic: true, IsStatic: true,
}} }}
assert.ElementsMatch(t, wantLeases, srv.Leases()) assert.ElementsMatch(t, wantLeases, srv.Leases())

View File

@@ -91,7 +91,7 @@ type dhcpInterfaceV4 struct {
// gateway is the IP address of the network gateway. // gateway is the IP address of the network gateway.
gateway netip.Addr gateway netip.Addr
// subnet is the network subnet of the interface. // subnet is the network subnet.
subnet netip.Prefix subnet netip.Prefix
// addrSpace is the IPv4 address space allocated for leasing. // addrSpace is the IPv4 address space allocated for leasing.
@@ -115,7 +115,12 @@ func newDHCPInterfaceV4(
l *slog.Logger, l *slog.Logger,
name string, name string,
conf *IPv4Config, conf *IPv4Config,
) (iface *dhcpInterfaceV4, err error) { ) (i *dhcpInterfaceV4, err error) {
l = l.With(
keyInterface, name,
keyFamily, netutil.AddrFamilyIPv4,
)
if !conf.Enabled { if !conf.Enabled {
l.DebugContext(ctx, "disabled") l.DebugContext(ctx, "disabled")
@@ -139,20 +144,31 @@ func newDHCPInterfaceV4(
return nil, fmt.Errorf("gateway ip %s in the ip range %s", conf.GatewayIP, addrSpace) return nil, fmt.Errorf("gateway ip %s in the ip range %s", conf.GatewayIP, addrSpace)
} }
iface = &dhcpInterfaceV4{ i = &dhcpInterfaceV4{
gateway: conf.GatewayIP, gateway: conf.GatewayIP,
subnet: subnet, subnet: subnet,
addrSpace: addrSpace, addrSpace: addrSpace,
common: &netInterface{ common: newNetInterface(name, l, conf.LeaseDuration),
logger: l,
leases: map[macKey]*Lease{},
name: name,
leaseTTL: conf.LeaseDuration,
},
} }
iface.implicitOpts, iface.explicitOpts = conf.options(ctx, l) i.implicitOpts, i.explicitOpts = conf.options(ctx, l)
return iface, nil return i, nil
}
// dhcpInterfacesV4 is a slice of network interfaces of IPv4 address family.
type dhcpInterfacesV4 []*dhcpInterfaceV4
// find returns the first network interface within ifaces containing ip. It
// returns false if there is no such interface.
func (ifaces dhcpInterfacesV4) find(ip netip.Addr) (iface4 *netInterface, ok bool) {
i := slices.IndexFunc(ifaces, func(iface *dhcpInterfaceV4) (contains bool) {
return iface.subnet.Contains(ip)
})
if i < 0 {
return nil, false
}
return ifaces[i].common, true
} }
// options returns the implicit and explicit options for the interface. The two // options returns the implicit and explicit options for the interface. The two
@@ -345,104 +361,3 @@ func (c *IPv4Config) options(ctx context.Context, l *slog.Logger) (imp, exp laye
func compareV4OptionCodes(a, b layers.DHCPOption) (res int) { func compareV4OptionCodes(a, b layers.DHCPOption) (res int) {
return int(a.Type) - int(b.Type) return int(a.Type) - int(b.Type)
} }
// msg4Type returns the message type of msg, if it's present within the options.
func msg4Type(msg *layers.DHCPv4) (typ layers.DHCPMsgType, ok bool) {
for _, opt := range msg.Options {
if opt.Type == layers.DHCPOptMessageType && len(opt.Data) > 0 {
return layers.DHCPMsgType(opt.Data[0]), true
}
}
return 0, false
}
// requestedIPv4 returns the IPv4 address, requested by client in the DHCP
// message, if any.
func requestedIPv4(msg *layers.DHCPv4) (ip netip.Addr, ok bool) {
for _, opt := range msg.Options {
if opt.Type == layers.DHCPOptRequestIP && len(opt.Data) == net.IPv4len {
return netip.AddrFromSlice(opt.Data)
}
}
return netip.Addr{}, false
}
// serverID4 returns the server ID of the DHCP message, if any.
func serverID4(msg *layers.DHCPv4) (ip netip.Addr, ok bool) {
for _, opt := range msg.Options {
if opt.Type == layers.DHCPOptServerID && len(opt.Data) == net.IPv4len {
return netip.AddrFromSlice(opt.Data)
}
}
return netip.Addr{}, false
}
// handleDiscover handles messages of type discover.
func (iface *dhcpInterfaceV4) handleDiscover(
ctx context.Context,
rw responseWriter4,
msg *layers.DHCPv4,
) {
// TODO(e.burkov): !! Implement.
}
// handleSelecting handles messages of type request in SELECTING state.
func (iface *dhcpInterfaceV4) handleSelecting(
ctx context.Context,
rw responseWriter4,
msg *layers.DHCPv4,
reqIP netip.Addr,
) {
// TODO(e.burkov): !! Implement.
}
// handleSelecting handles messages of type request in INIT-REBOOT state.
func (iface *dhcpInterfaceV4) handleInitReboot(
ctx context.Context,
rw responseWriter4,
msg *layers.DHCPv4,
reqIP netip.Addr,
) {
// TODO(e.burkov): !! Implement.
}
// handleRenew handles messages of type request in RENEWING or REBINDING state.
func (iface *dhcpInterfaceV4) handleRenew(
ctx context.Context,
rw responseWriter4,
req *layers.DHCPv4,
) {
// TODO(e.burkov): !! Implement.
}
// dhcpInterfacesV4 is a slice of network interfaces of IPv4 address family.
type dhcpInterfacesV4 []*dhcpInterfaceV4
// find returns the first network interface within ifaces containing ip. It
// returns false if there is no such interface.
func (ifaces dhcpInterfacesV4) find(ip netip.Addr) (iface4 *netInterface, ok bool) {
i := slices.IndexFunc(ifaces, func(iface *dhcpInterfaceV4) (contains bool) {
return iface.subnet.Contains(ip)
})
if i < 0 {
return nil, false
}
return ifaces[i].common, true
}
// findInterface returns the first DHCPv4 interface within ifaces containing
// ip. It returns false if there is no such interface.
func (ifaces dhcpInterfacesV4) findInterface(ip netip.Addr) (iface *dhcpInterfaceV4, ok bool) {
i := slices.IndexFunc(ifaces, func(iface *dhcpInterfaceV4) (contains bool) {
return iface.subnet.Contains(ip)
})
if i < 0 {
return nil, false
}
return ifaces[i], true
}

View File

@@ -97,27 +97,23 @@ func newDHCPInterfaceV6(
l *slog.Logger, l *slog.Logger,
name string, name string,
conf *IPv6Config, conf *IPv6Config,
) (iface *dhcpInterfaceV6) { ) (i *dhcpInterfaceV6) {
l = l.With(keyInterface, name, keyFamily, netutil.AddrFamilyIPv6)
if !conf.Enabled { if !conf.Enabled {
l.DebugContext(ctx, "disabled") l.DebugContext(ctx, "disabled")
return nil return nil
} }
iface = &dhcpInterfaceV6{ i = &dhcpInterfaceV6{
rangeStart: conf.RangeStart, rangeStart: conf.RangeStart,
common: &netInterface{ common: newNetInterface(name, l, conf.LeaseDuration),
logger: l,
leases: map[macKey]*Lease{},
name: name,
leaseTTL: conf.LeaseDuration,
},
raSLAACOnly: conf.RASLAACOnly, raSLAACOnly: conf.RASLAACOnly,
raAllowSLAAC: conf.RAAllowSLAAC, raAllowSLAAC: conf.RAAllowSLAAC,
} }
iface.implicitOpts, iface.explicitOpts = conf.options(ctx, l) i.implicitOpts, i.explicitOpts = conf.options(ctx, l)
return iface return i
} }
// dhcpInterfacesV6 is a slice of network interfaces of IPv6 address family. // dhcpInterfacesV6 is a slice of network interfaces of IPv6 address family.

View File

@@ -10,7 +10,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghalg" "github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/golibs/container" "github.com/AdguardTeam/golibs/container"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/stringutil" "github.com/AdguardTeam/golibs/stringutil"
@@ -52,7 +51,7 @@ func processAccessClients(
} else if ipnet, err = netip.ParsePrefix(s); err == nil { } else if ipnet, err = netip.ParsePrefix(s); err == nil {
*nets = append(*nets, ipnet) *nets = append(*nets, ipnet)
} else { } else {
err = client.ValidateClientID(s) err = ValidateClientID(s)
if err != nil { if err != nil {
return fmt.Errorf("value %q at index %d: bad ip, cidr, or clientid", s, i) return fmt.Errorf("value %q at index %d: bad ip, cidr, or clientid", s, i)
} }

View File

@@ -7,13 +7,26 @@ import (
"path" "path"
"strings" "strings"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
"github.com/quic-go/quic-go" "github.com/quic-go/quic-go"
) )
// ValidateClientID returns an error if id is not a valid ClientID.
//
// Keep in sync with [client.ValidateClientID].
func ValidateClientID(id string) (err error) {
err = netutil.ValidateHostnameLabel(id)
if err != nil {
// Replace the domain name label wrapper with our own.
return fmt.Errorf("invalid clientid %q: %w", id, errors.Unwrap(err))
}
return nil
}
// clientIDFromClientServerName extracts and validates a ClientID. hostSrvName // clientIDFromClientServerName extracts and validates a ClientID. hostSrvName
// is the server name of the host. cliSrvName is the server name as sent by the // is the server name of the host. cliSrvName is the server name as sent by the
// client. When strict is true, and client and host server name don't match, // client. When strict is true, and client and host server name don't match,
@@ -40,7 +53,7 @@ func clientIDFromClientServerName(
} }
clientID = cliSrvName[:len(cliSrvName)-len(hostSrvName)-1] clientID = cliSrvName[:len(cliSrvName)-len(hostSrvName)-1]
err = client.ValidateClientID(clientID) err = ValidateClientID(clientID)
if err != nil { if err != nil {
// Don't wrap the error, because it's informative enough as is. // Don't wrap the error, because it's informative enough as is.
return "", err return "", err
@@ -80,7 +93,7 @@ func clientIDFromDNSContextHTTPS(pctx *proxy.DNSContext) (clientID string, err e
return "", fmt.Errorf("clientid check: invalid path %q: extra parts", origPath) return "", fmt.Errorf("clientid check: invalid path %q: extra parts", origPath)
} }
err = client.ValidateClientID(clientID) err = ValidateClientID(clientID)
if err != nil { if err != nil {
return "", fmt.Errorf("clientid check: %w", err) return "", fmt.Errorf("clientid check: %w", err)
} }

View File

@@ -1,317 +1,131 @@
package home package home
import ( import (
"crypto/rand" "context"
"encoding/binary"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"log/slog"
"net/http" "net/http"
"sync"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/AdGuardHome/internal/aghuser"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
"go.etcd.io/bbolt" "github.com/AdguardTeam/golibs/timeutil"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
) )
// sessionTokenSize is the length of session token in bytes. // webUser represents a user of the Web UI.
const sessionTokenSize = 16 type webUser struct {
// Name represents the login name of the web user.
Name string `yaml:"name"`
type session struct { // PasswordHash is the hashed representation of the web user password.
userName string PasswordHash string `yaml:"password"`
// expire is the expiration time, in seconds.
expire uint32 // UserID is the unique identifier of the web user.
//
// TODO(s.chzhen): !! Use this.
UserID aghuser.UserID `yaml:"-"`
} }
func (s *session) serialize() []byte { // toUser returns the new properly initialized *aghuser.User using stored
const ( // properties. It panics if there is an error generating the user ID.
expireLen = 4 func (wu *webUser) toUser() (u *aghuser.User) {
nameLen = 2 uid := wu.UserID
) if uid == (aghuser.UserID{}) {
data := make([]byte, expireLen+nameLen+len(s.userName)) uid = aghuser.MustNewUserID()
binary.BigEndian.PutUint32(data[0:4], s.expire)
binary.BigEndian.PutUint16(data[4:6], uint16(len(s.userName)))
copy(data[6:], []byte(s.userName))
return data
}
func (s *session) deserialize(data []byte) bool {
if len(data) < 4+2 {
return false
} }
s.expire = binary.BigEndian.Uint32(data[0:4])
nameLen := binary.BigEndian.Uint16(data[4:6])
data = data[6:]
if len(data) < int(nameLen) { return &aghuser.User{
return false Password: aghuser.NewDefaultPassword(wu.PasswordHash),
Login: aghuser.Login(wu.Name),
ID: uid,
} }
s.userName = string(data)
return true
} }
// Auth is the global authentication object. // Auth is the global authentication object.
type Auth struct { type Auth struct {
trustedProxies netutil.SubnetSet logger *slog.Logger
db *bbolt.DB
rateLimiter *authRateLimiter rateLimiter *authRateLimiter
sessions map[string]*session sessions aghuser.SessionStorage
users []webUser trustedProxies netutil.SubnetSet
lock sync.Mutex users aghuser.DB
sessionTTL uint32
} }
// webUser represents a user of the Web UI. // InitAuth initializes the global authentication object. baseLogger,
// // rateLimiter, trustedProxies must not be nil. dbFilename and sessionTTL
// TODO(s.chzhen): Improve naming. // should not be empty.
type webUser struct {
Name string `yaml:"name"`
PasswordHash string `yaml:"password"`
}
// InitAuth initializes the global authentication object.
func InitAuth( func InitAuth(
ctx context.Context,
baseLogger *slog.Logger,
dbFilename string, dbFilename string,
users []webUser, users []webUser,
sessionTTL uint32, sessionTTL time.Duration,
rateLimiter *authRateLimiter, rateLimiter *authRateLimiter,
trustedProxies netutil.SubnetSet, trustedProxies netutil.SubnetSet,
) (a *Auth) { ) (a *Auth, err error) {
log.Info("Initializing auth module: %s", dbFilename) userDB := aghuser.NewDefaultDB()
for i, u := range users {
a = &Auth{ err = userDB.Create(ctx, u.toUser())
sessionTTL: sessionTTL, if err != nil {
rateLimiter: rateLimiter, return nil, fmt.Errorf("users: at index %d: %w", i, err)
sessions: make(map[string]*session),
users: users,
trustedProxies: trustedProxies,
}
var err error
a.db, err = bbolt.Open(dbFilename, aghos.DefaultPermFile, nil)
if err != nil {
log.Error("auth: open DB: %s: %s", dbFilename, err)
if err.Error() == "invalid argument" {
log.Error("AdGuard Home cannot be initialized due to an incompatible file system.\nPlease read the explanation here: https://github.com/AdguardTeam/AdGuardHome/wiki/Getting-Started#limitations")
} }
return nil
} }
a.loadSessions()
log.Info("auth: initialized. users:%d sessions:%d", len(a.users), len(a.sessions))
return a s, err := aghuser.NewDefaultSessionStorage(ctx, &aghuser.DefaultSessionStorageConfig{
Logger: baseLogger.With(slogutil.KeyPrefix, "session_storage"),
Clock: timeutil.SystemClock{},
UserDB: aghuser.NewDefaultDB(),
DBPath: dbFilename,
SessionTTL: sessionTTL,
})
if err != nil {
return nil, fmt.Errorf("creating session storage: %w", err)
}
return &Auth{
logger: baseLogger.With(slogutil.KeyPrefix, "auth"),
rateLimiter: rateLimiter,
trustedProxies: trustedProxies,
sessions: s,
users: userDB,
}, nil
} }
// Close closes the authentication database. // Close closes the authentication database.
func (a *Auth) Close() { func (a *Auth) Close(ctx context.Context) {
_ = a.db.Close() err := a.sessions.Close()
}
func bucketName() []byte {
return []byte("sessions-2")
}
// loadSessions loads sessions from the database file and removes expired
// sessions.
func (a *Auth) loadSessions() {
tx, err := a.db.Begin(true)
if err != nil { if err != nil {
log.Error("auth: bbolt.Begin: %s", err) a.logger.ErrorContext(ctx, "closing session storage", slogutil.KeyError, err)
return
}
defer func() {
_ = tx.Rollback()
}()
bkt := tx.Bucket(bucketName())
if bkt == nil {
return
}
removed := 0
if tx.Bucket([]byte("sessions")) != nil {
_ = tx.DeleteBucket([]byte("sessions"))
removed = 1
}
now := uint32(time.Now().UTC().Unix())
forEach := func(k, v []byte) error {
s := session{}
if !s.deserialize(v) || s.expire <= now {
err = bkt.Delete(k)
if err != nil {
log.Error("auth: bbolt.Delete: %s", err)
} else {
removed++
}
return nil
}
a.sessions[hex.EncodeToString(k)] = &s
return nil
}
_ = bkt.ForEach(forEach)
if removed != 0 {
err = tx.Commit()
if err != nil {
log.Error("bolt.Commit(): %s", err)
}
}
log.Debug("auth: loaded %d sessions from DB (removed %d expired)", len(a.sessions), removed)
}
// addSession adds a new session to the list of sessions and saves it in the
// database file.
func (a *Auth) addSession(data []byte, s *session) {
name := hex.EncodeToString(data)
a.lock.Lock()
a.sessions[name] = s
a.lock.Unlock()
if a.storeSession(data, s) {
log.Debug("auth: created session %s: expire=%d", name, s.expire)
} }
} }
// storeSession saves a session in the database file. // isValidSession returns true if the session is valid.
func (a *Auth) storeSession(data []byte, s *session) bool { func (a *Auth) isValidSession(ctx context.Context, cookieSess string) (ok bool) {
tx, err := a.db.Begin(true) sess, err := hex.DecodeString(cookieSess)
if err != nil { if err != nil {
log.Error("auth: bbolt.Begin: %s", err) a.logger.ErrorContext(ctx, "checking session: decoding cookie", slogutil.KeyError, err)
return false
}
defer func() {
_ = tx.Rollback()
}()
bkt, err := tx.CreateBucketIfNotExists(bucketName())
if err != nil {
log.Error("auth: bbolt.CreateBucketIfNotExists: %s", err)
return false return false
} }
err = bkt.Put(data, s.serialize()) var t aghuser.SessionToken
copy(t[:], sess)
s, err := a.sessions.FindByToken(ctx, t)
if err != nil { if err != nil {
log.Error("auth: bbolt.Put: %s", err) a.logger.ErrorContext(ctx, "checking session", slogutil.KeyError, err)
return false return false
} }
err = tx.Commit() return s != nil
if err != nil {
log.Error("auth: bbolt.Commit: %s", err)
return false
}
return true
} }
// removeSessionFromFile removes a stored session from the DB file on disk. // addUser adds a new user with the given password. u must not be nil.
func (a *Auth) removeSessionFromFile(sess []byte) { func (a *Auth) addUser(ctx context.Context, u *webUser, password string) (err error) {
tx, err := a.db.Begin(true)
if err != nil {
log.Error("auth: bbolt.Begin: %s", err)
return
}
defer func() {
_ = tx.Rollback()
}()
bkt := tx.Bucket(bucketName())
if bkt == nil {
log.Error("auth: bbolt.Bucket")
return
}
err = bkt.Delete(sess)
if err != nil {
log.Error("auth: bbolt.Put: %s", err)
return
}
err = tx.Commit()
if err != nil {
log.Error("auth: bbolt.Commit: %s", err)
return
}
log.Debug("auth: removed session from DB")
}
// checkSessionResult is the result of checking a session.
type checkSessionResult int
// checkSessionResult constants.
const (
checkSessionOK checkSessionResult = 0
checkSessionNotFound checkSessionResult = -1
checkSessionExpired checkSessionResult = 1
)
// checkSession checks if the session is valid.
func (a *Auth) checkSession(sess string) (res checkSessionResult) {
now := uint32(time.Now().UTC().Unix())
update := false
a.lock.Lock()
defer a.lock.Unlock()
s, ok := a.sessions[sess]
if !ok {
return checkSessionNotFound
}
if s.expire <= now {
delete(a.sessions, sess)
key, _ := hex.DecodeString(sess)
a.removeSessionFromFile(key)
return checkSessionExpired
}
newExpire := now + a.sessionTTL
if s.expire/(24*60*60) != newExpire/(24*60*60) {
// update expiration time once a day
update = true
s.expire = newExpire
}
if update {
key, _ := hex.DecodeString(sess)
if a.storeSession(key, s) {
log.Debug("auth: updated session %s: expire=%d", sess, s.expire)
}
}
return checkSessionOK
}
// removeSession removes the session from the active sessions and the disk.
func (a *Auth) removeSession(sess string) {
key, _ := hex.DecodeString(sess)
a.lock.Lock()
delete(a.sessions, sess)
a.lock.Unlock()
a.removeSessionFromFile(key)
}
// addUser adds a new user with the given password.
func (a *Auth) addUser(u *webUser, password string) (err error) {
if len(password) == 0 { if len(password) == 0 {
return errors.Error("empty password") return errors.Error("empty password")
} }
@@ -323,97 +137,129 @@ func (a *Auth) addUser(u *webUser, password string) (err error) {
u.PasswordHash = string(hash) u.PasswordHash = string(hash)
a.lock.Lock() err = a.users.Create(ctx, u.toUser())
defer a.lock.Unlock() if err != nil {
// Should not happen.
panic(err)
}
a.users = append(a.users, *u) a.logger.DebugContext(ctx, "added user", "login", u.Name)
log.Debug("auth: added user with login %q", u.Name)
return nil return nil
} }
// findUser returns a user if there is one. // findUser returns a user if one exists with the provided login and the
func (a *Auth) findUser(login, password string) (u webUser, ok bool) { // password matches.
a.lock.Lock() func (a *Auth) findUser(ctx context.Context, login, password string) (user *aghuser.User) {
defer a.lock.Unlock() user, err := a.users.ByLogin(ctx, aghuser.Login(login))
if err != nil {
for _, u = range a.users { return nil
if u.Name == login &&
bcrypt.CompareHashAndPassword([]byte(u.PasswordHash), []byte(password)) == nil {
return u, true
}
} }
return webUser{}, false ok := user.Password.Authenticate(ctx, password)
if !ok {
return nil
}
return user
} }
// getCurrentUser returns the current user. It returns an empty User if the // getCurrentUser searches for a user using a cookie or credentials from basic
// user is not found. // authentication.
func (a *Auth) getCurrentUser(r *http.Request) (u webUser) { func (a *Auth) getCurrentUser(r *http.Request) (user *aghuser.User) {
ctx := r.Context()
cookie, err := r.Cookie(sessionCookieName) cookie, err := r.Cookie(sessionCookieName)
if err != nil { if err != nil {
// There's no Cookie, check Basic authentication. // There's no Cookie, check Basic authentication.
user, pass, ok := r.BasicAuth() user, pass, ok := r.BasicAuth()
if ok { if ok {
u, _ = globalContext.auth.findUser(user, pass) return a.findUser(ctx, user, pass)
return u
} }
return webUser{} return nil
} }
a.lock.Lock() sess, err := hex.DecodeString(cookie.Value)
defer a.lock.Unlock() if err != nil {
a.logger.ErrorContext(
ctx,
"searching for user: decoding cookie value",
slogutil.KeyError, err,
)
s, ok := a.sessions[cookie.Value] return nil
if !ok {
return webUser{}
} }
for _, u = range a.users { var t aghuser.SessionToken
if u.Name == s.userName { copy(t[:], sess)
return u
} s, err := a.sessions.FindByToken(ctx, t)
if err != nil {
a.logger.ErrorContext(ctx, "searching for user", slogutil.KeyError, err)
return nil
} }
return webUser{} if s == nil {
return nil
}
return &aghuser.User{
Login: s.UserLogin,
ID: s.UserID,
}
}
// removeSession deletes the session from the active sessions and the disk. It
// also logs any occurring errors.
func (a *Auth) removeSession(ctx context.Context, cookieSess string) {
sess, err := hex.DecodeString(cookieSess)
if err != nil {
a.logger.ErrorContext(ctx, "removing session: decoding cookie", slogutil.KeyError, err)
return
}
var t aghuser.SessionToken
copy(t[:], sess)
err = a.sessions.DeleteByToken(ctx, t)
if err != nil {
a.logger.ErrorContext(ctx, "removing session by token", slogutil.KeyError, err)
}
} }
// usersList returns a copy of a users list. // usersList returns a copy of a users list.
func (a *Auth) usersList() (users []webUser) { func (a *Auth) usersList(ctx context.Context) (webUsers []webUser) {
a.lock.Lock() users, err := a.users.All(ctx)
defer a.lock.Unlock() if err != nil {
// Should not happen.
panic(err)
}
users = make([]webUser, len(a.users)) webUsers = make([]webUser, 0, len(users))
copy(users, a.users) for _, u := range users {
webUsers = append(webUsers, webUser{
Name: string(u.Login),
PasswordHash: string(u.Password.Hash()),
UserID: u.ID,
})
}
return users return webUsers
} }
// authRequired returns true if a authentication is required. // authRequired returns true if a authentication is required.
func (a *Auth) authRequired() bool { func (a *Auth) authRequired(ctx context.Context) (ok bool) {
if GLMode { if GLMode {
return true return true
} }
a.lock.Lock() users, err := a.users.All(ctx)
defer a.lock.Unlock() if err != nil {
// Should not happen.
panic(err)
}
return len(a.users) != 0 return len(users) != 0
}
// newSessionToken returns cryptographically secure randomly generated slice of
// bytes of sessionTokenSize length.
//
// TODO(e.burkov): Think about using byte array instead of byte slice.
func newSessionToken() (data []byte) {
randData := make([]byte, sessionTokenSize)
// Since Go 1.24, crypto/rand.Read doesn't return an error and crashes
// unrecoverably instead.
_, _ = rand.Read(randData)
return randData
} }

View File

@@ -1,69 +0,0 @@
package home
import (
"encoding/hex"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestAuth(t *testing.T) {
dir := t.TempDir()
fn := filepath.Join(dir, "sessions.db")
users := []webUser{{
Name: "name",
PasswordHash: "$2y$05$..vyzAECIhJPfaQiOK17IukcQnqEgKJHy0iETyYqxn3YXJl8yZuo2",
}}
a := InitAuth(fn, nil, 60, nil, nil)
s := session{}
user := webUser{Name: "name"}
err := a.addUser(&user, "password")
require.NoError(t, err)
assert.Equal(t, checkSessionNotFound, a.checkSession("notfound"))
a.removeSession("notfound")
sess := newSessionToken()
sessStr := hex.EncodeToString(sess)
now := time.Now().UTC().Unix()
// check expiration
s.expire = uint32(now)
a.addSession(sess, &s)
assert.Equal(t, checkSessionExpired, a.checkSession(sessStr))
// add session with TTL = 2 sec
s = session{}
s.expire = uint32(time.Now().UTC().Unix() + 2)
a.addSession(sess, &s)
assert.Equal(t, checkSessionOK, a.checkSession(sessStr))
a.Close()
// load saved session
a = InitAuth(fn, users, 60, nil, nil)
// the session is still alive
assert.Equal(t, checkSessionOK, a.checkSession(sessStr))
// reset our expiration time because checkSession() has just updated it
s.expire = uint32(time.Now().UTC().Unix() + 2)
a.storeSession(sess, &s)
a.Close()
u, ok := a.findUser("name", "password")
assert.True(t, ok)
assert.NotEmpty(t, u.Name)
time.Sleep(3 * time.Second)
// load and remove expired sessions
a = InitAuth(fn, users, 60, nil, nil)
assert.Equal(t, checkSessionNotFound, a.checkSession(sessStr))
a.Close()
}

View File

@@ -1,6 +1,7 @@
package home package home
import ( import (
"context"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"fmt" "fmt"
@@ -32,10 +33,14 @@ type loginJSON struct {
} }
// newCookie creates a new authentication cookie. // newCookie creates a new authentication cookie.
func (a *Auth) newCookie(req loginJSON, addr string) (c *http.Cookie, err error) { func (a *Auth) newCookie(
ctx context.Context,
req loginJSON,
addr string,
) (c *http.Cookie, err error) {
rateLimiter := a.rateLimiter rateLimiter := a.rateLimiter
u, ok := a.findUser(req.Name, req.Password) u := a.findUser(ctx, req.Name, req.Password)
if !ok { if u == nil {
if rateLimiter != nil { if rateLimiter != nil {
rateLimiter.inc(addr) rateLimiter.inc(addr)
} }
@@ -47,19 +52,16 @@ func (a *Auth) newCookie(req loginJSON, addr string) (c *http.Cookie, err error)
rateLimiter.remove(addr) rateLimiter.remove(addr)
} }
sess := newSessionToken() s, err := a.sessions.New(ctx, u)
now := time.Now().UTC() if err != nil {
return nil, fmt.Errorf("creating session: %w", err)
a.addSession(sess, &session{ }
userName: u.Name,
expire: uint32(now.Unix()) + a.sessionTTL,
})
return &http.Cookie{ return &http.Cookie{
Name: sessionCookieName, Name: sessionCookieName,
Value: hex.EncodeToString(sess), Value: hex.EncodeToString(s.Token[:]),
Path: "/", Path: "/",
Expires: now.Add(cookieTTL), Expires: time.Now().Add(cookieTTL),
HttpOnly: true, HttpOnly: true,
SameSite: http.SameSiteLaxMode, SameSite: http.SameSiteLaxMode,
}, nil }, nil
@@ -172,7 +174,7 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
log.Error("auth: getting real ip from request with remote ip %s: %s", remoteIP, err) log.Error("auth: getting real ip from request with remote ip %s: %s", remoteIP, err)
} }
cookie, err := globalContext.auth.newCookie(req, remoteIP) cookie, err := globalContext.auth.newCookie(r.Context(), req, remoteIP)
if err != nil { if err != nil {
logIP := remoteIP logIP := remoteIP
if globalContext.auth.trustedProxies.Contains(ip.Unmap()) { if globalContext.auth.trustedProxies.Contains(ip.Unmap()) {
@@ -209,7 +211,7 @@ func handleLogout(w http.ResponseWriter, r *http.Request) {
return return
} }
globalContext.auth.removeSession(c.Value) globalContext.auth.removeSession(r.Context(), c.Value)
c = &http.Cookie{ c = &http.Cookie{
Name: sessionCookieName, Name: sessionCookieName,
@@ -242,28 +244,7 @@ func optionalAuthThird(w http.ResponseWriter, r *http.Request) (mustAuth bool) {
return false return false
} }
// redirect to login page if not authenticated if u := globalContext.auth.getCurrentUser(r); u != nil {
isAuthenticated := false
cookie, err := r.Cookie(sessionCookieName)
if err != nil {
// The only error that is returned from r.Cookie is [http.ErrNoCookie].
// Check Basic authentication.
user, pass, hasBasic := r.BasicAuth()
if hasBasic {
_, isAuthenticated = globalContext.auth.findUser(user, pass)
if !isAuthenticated {
log.Info("%s: invalid basic authorization value", pref)
}
}
} else {
res := globalContext.auth.checkSession(cookie.Value)
isAuthenticated = res == checkSessionOK
if !isAuthenticated {
log.Debug("%s: invalid cookie value: %q", pref, cookie)
}
}
if isAuthenticated {
return false return false
} }
@@ -289,14 +270,14 @@ func optionalAuth(
h func(http.ResponseWriter, *http.Request), h func(http.ResponseWriter, *http.Request),
) (wrapped func(http.ResponseWriter, *http.Request)) { ) (wrapped func(http.ResponseWriter, *http.Request)) {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
p := r.URL.Path p := r.URL.Path
authRequired := globalContext.auth != nil && globalContext.auth.authRequired() authRequired := globalContext.auth != nil && globalContext.auth.authRequired(ctx)
if p == "/login.html" { if p == "/login.html" {
cookie, err := r.Cookie(sessionCookieName) cookie, err := r.Cookie(sessionCookieName)
if authRequired && err == nil { if authRequired && err == nil {
// Redirect to the dashboard if already authenticated. // Redirect to the dashboard if already authenticated.
res := globalContext.auth.checkSession(cookie.Value) if globalContext.auth.isValidSession(ctx, cookie.Value) {
if res == checkSessionOK {
http.Redirect(w, r, "", http.StatusFound) http.Redirect(w, r, "", http.StatusFound)
return return

View File

@@ -7,8 +7,10 @@ import (
"net/url" "net/url"
"path/filepath" "path/filepath"
"testing" "testing"
"time"
"github.com/AdguardTeam/golibs/httphdr" "github.com/AdguardTeam/golibs/httphdr"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -33,13 +35,20 @@ func (w *testResponseWriter) WriteHeader(statusCode int) {
} }
func TestAuthHTTP(t *testing.T) { func TestAuthHTTP(t *testing.T) {
var (
ctx = testutil.ContextWithTimeout(t, testTimeout)
logger = slogutil.NewDiscardLogger()
err error
)
dir := t.TempDir() dir := t.TempDir()
fn := filepath.Join(dir, "sessions.db") fn := filepath.Join(dir, "sessions.db")
users := []webUser{ users := []webUser{
{Name: "name", PasswordHash: "$2y$05$..vyzAECIhJPfaQiOK17IukcQnqEgKJHy0iETyYqxn3YXJl8yZuo2"}, {Name: "name", PasswordHash: "$2y$05$..vyzAECIhJPfaQiOK17IukcQnqEgKJHy0iETyYqxn3YXJl8yZuo2"},
} }
globalContext.auth = InitAuth(fn, users, 60, nil, nil) globalContext.auth, err = InitAuth(ctx, logger, fn, users, time.Minute, nil, nil)
require.NoError(t, err)
handlerCalled := false handlerCalled := false
handler := func(_ http.ResponseWriter, _ *http.Request) { handler := func(_ http.ResponseWriter, _ *http.Request) {
@@ -68,7 +77,11 @@ func TestAuthHTTP(t *testing.T) {
assert.True(t, handlerCalled) assert.True(t, handlerCalled)
// perform login // perform login
cookie, err := globalContext.auth.newCookie(loginJSON{Name: "name", Password: "password"}, "") cookie, err := globalContext.auth.newCookie(
ctx,
loginJSON{Name: "name", Password: "password"},
"",
)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, cookie) require.NotNil(t, cookie)
@@ -114,7 +127,7 @@ func TestAuthHTTP(t *testing.T) {
assert.True(t, handlerCalled) assert.True(t, handlerCalled)
r.Header.Del(httphdr.Cookie) r.Header.Del(httphdr.Cookie)
globalContext.auth.Close() globalContext.auth.Close(ctx)
} }
func TestRealIP(t *testing.T) { func TestRealIP(t *testing.T) {

View File

@@ -28,10 +28,6 @@ type clientsContainer struct {
// filter. It must not be nil. // filter. It must not be nil.
baseLogger *slog.Logger baseLogger *slog.Logger
// logger is used for logging the operation of the client container. It
// must not be nil.
logger *slog.Logger
// storage stores information about persistent clients. // storage stores information about persistent clients.
storage *client.Storage storage *client.Storage
@@ -62,7 +58,6 @@ type clientsContainer struct {
// BlockedClientChecker checks if a client is blocked by the current access // BlockedClientChecker checks if a client is blocked by the current access
// settings. // settings.
type BlockedClientChecker interface { type BlockedClientChecker interface {
// TODO(s.chzhen): Accept [client.FindParams].
IsBlockedClient(ip netip.Addr, clientID string) (blocked bool, rule string) IsBlockedClient(ip netip.Addr, clientID string) (blocked bool, rule string)
} }
@@ -85,7 +80,6 @@ func (clients *clientsContainer) Init(
} }
clients.baseLogger = baseLogger clients.baseLogger = baseLogger
clients.logger = baseLogger.With(slogutil.KeyPrefix, "client_container")
clients.safeSearchCacheSize = filteringConf.SafeSearchCacheSize clients.safeSearchCacheSize = filteringConf.SafeSearchCacheSize
clients.safeSearchCacheTTL = time.Minute * time.Duration(filteringConf.CacheTime) clients.safeSearchCacheTTL = time.Minute * time.Duration(filteringConf.CacheTime)
@@ -275,7 +269,7 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) {
BlockedServices: cli.BlockedServices.Clone(), BlockedServices: cli.BlockedServices.Clone(),
IDs: cli.Identifiers(), IDs: cli.IDs(),
Tags: slices.Clone(cli.Tags), Tags: slices.Clone(cli.Tags),
Upstreams: slices.Clone(cli.Upstreams), Upstreams: slices.Clone(cli.Upstreams),
@@ -362,27 +356,15 @@ func (clients *clientsContainer) clientOrArtificial(
}, true }, true
} }
// shouldCountClient is a wrapper around [client.Storage.Find] to make it a // shouldCountClient is a wrapper around [clientsContainer.find] to make it a
// valid client information finder for the statistics. If no information about // valid client information finder for the statistics. If no information about
// the client is found, it returns true. Values of ids must be either a valid // the client is found, it returns true.
// ClientID or a valid IP address.
//
// TODO(s.chzhen): Accept [client.FindParams].
func (clients *clientsContainer) shouldCountClient(ids []string) (y bool) { func (clients *clientsContainer) shouldCountClient(ids []string) (y bool) {
clients.lock.Lock() clients.lock.Lock()
defer clients.lock.Unlock() defer clients.lock.Unlock()
params := &client.FindParams{}
for _, id := range ids { for _, id := range ids {
err := params.Set(id) client, ok := clients.storage.Find(id)
if err != nil {
// Should not happen.
clients.logger.Warn("parsing find params", slogutil.KeyError, err)
continue
}
client, ok := clients.storage.Find(params)
if ok { if ok {
return !client.IgnoreStatistics return !client.IgnoreStatistics
} }

View File

@@ -300,7 +300,7 @@ func clientToJSON(c *client.Persistent) (cj *clientJSON) {
return &clientJSON{ return &clientJSON{
Name: c.Name, Name: c.Name,
IDs: c.Identifiers(), IDs: c.IDs(),
Tags: c.Tags, Tags: c.Tags,
UseGlobalSettings: !c.UseOwnSettings, UseGlobalSettings: !c.UseOwnSettings,
FilteringEnabled: c.FilteringEnabled, FilteringEnabled: c.FilteringEnabled,
@@ -428,53 +428,32 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht
// Deprecated: Remove it when migration to the new API is over. // Deprecated: Remove it when migration to the new API is over.
func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http.Request) { func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query() q := r.URL.Query()
data := make([]map[string]*clientJSON, 0, len(q)) data := []map[string]*clientJSON{}
params := &client.FindParams{}
var err error
for i := range len(q) { for i := range len(q) {
idStr := q.Get(fmt.Sprintf("ip%d", i)) idStr := q.Get(fmt.Sprintf("ip%d", i))
if idStr == "" { if idStr == "" {
break break
} }
err = params.Set(idStr)
if err != nil {
clients.logger.DebugContext(
r.Context(),
"finding client",
"id", idStr,
slogutil.KeyError, err,
)
continue
}
data = append(data, map[string]*clientJSON{ data = append(data, map[string]*clientJSON{
idStr: clients.findClient(idStr, params), idStr: clients.findClient(idStr),
}) })
} }
aghhttp.WriteJSONResponseOK(w, r, data) aghhttp.WriteJSONResponseOK(w, r, data)
} }
// findClient returns available information about a client by params from the // findClient returns available information about a client by idStr from the
// client's storage or access settings. idStr is the string representation of // client's storage or access settings. cj is guaranteed to be non-nil.
// typed params. params must not be nil. cj is guaranteed to be non-nil. func (clients *clientsContainer) findClient(idStr string) (cj *clientJSON) {
func (clients *clientsContainer) findClient( ip, _ := netip.ParseAddr(idStr)
idStr string, c, ok := clients.storage.Find(idStr)
params *client.FindParams,
) (cj *clientJSON) {
c, ok := clients.storage.Find(params)
if !ok { if !ok {
return clients.findRuntime(idStr, params) return clients.findRuntime(ip, idStr)
} }
cj = clientToJSON(c) cj = clientToJSON(c)
disallowed, rule := clients.clientChecker.IsBlockedClient( disallowed, rule := clients.clientChecker.IsBlockedClient(ip, idStr)
params.RemoteIP,
string(params.ClientID),
)
cj.Disallowed, cj.DisallowedRule = &disallowed, &rule cj.Disallowed, cj.DisallowedRule = &disallowed, &rule
return cj return cj
@@ -493,8 +472,7 @@ type searchClientJSON struct {
ID string `json:"id"` ID string `json:"id"`
} }
// handleSearchClient is the handler for the POST /control/clients/search HTTP // handleSearchClient is the handler for the POST /control/clients/search HTTP API.
// API.
func (clients *clientsContainer) handleSearchClient(w http.ResponseWriter, r *http.Request) { func (clients *clientsContainer) handleSearchClient(w http.ResponseWriter, r *http.Request) {
q := searchQueryJSON{} q := searchQueryJSON{}
err := json.NewDecoder(r.Body).Decode(&q) err := json.NewDecoder(r.Body).Decode(&q)
@@ -504,25 +482,11 @@ func (clients *clientsContainer) handleSearchClient(w http.ResponseWriter, r *ht
return return
} }
data := make([]map[string]*clientJSON, 0, len(q.Clients)) data := []map[string]*clientJSON{}
params := &client.FindParams{}
for _, c := range q.Clients { for _, c := range q.Clients {
idStr := c.ID idStr := c.ID
err = params.Set(idStr)
if err != nil {
clients.logger.DebugContext(
r.Context(),
"searching client",
"id", idStr,
slogutil.KeyError, err,
)
continue
}
data = append(data, map[string]*clientJSON{ data = append(data, map[string]*clientJSON{
idStr: clients.findClient(idStr, params), idStr: clients.findClient(idStr),
}) })
} }
@@ -530,37 +494,38 @@ func (clients *clientsContainer) handleSearchClient(w http.ResponseWriter, r *ht
} }
// findRuntime looks up the IP in runtime and temporary storages, like // findRuntime looks up the IP in runtime and temporary storages, like
// /etc/hosts tables, DHCP leases, or blocklists. params must not be nil. cj // /etc/hosts tables, DHCP leases, or blocklists. cj is guaranteed to be
// is guaranteed to be non-nil. // non-nil.
func (clients *clientsContainer) findRuntime( func (clients *clientsContainer) findRuntime(ip netip.Addr, idStr string) (cj *clientJSON) {
idStr string,
params *client.FindParams,
) (cj *clientJSON) {
var host string
whois := &whois.Info{}
ip := params.RemoteIP
rc := clients.storage.ClientRuntime(ip) rc := clients.storage.ClientRuntime(ip)
if rc != nil { if rc == nil {
_, host = rc.Info() // It is still possible that the IP used to be in the runtime clients
whois = whoisOrEmpty(rc) // list, but then the server was reloaded. So, check the DNS server's
// blocked IP list.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/2428.
disallowed, rule := clients.clientChecker.IsBlockedClient(ip, idStr)
cj = &clientJSON{
IDs: []string{idStr},
Disallowed: &disallowed,
DisallowedRule: &rule,
WHOIS: &whois.Info{},
}
return cj
} }
// Check the DNS server's blocked IP list regardless of whether a runtime _, host := rc.Info()
// client was found or not. This is because it's still possible that the cj = &clientJSON{
// runtime client associated with the IP address was stored previously, but Name: host,
// then the server was reloaded. IDs: []string{idStr},
// WHOIS: whoisOrEmpty(rc),
// See https://github.com/AdguardTeam/AdGuardHome/issues/2428.
disallowed, rule := clients.clientChecker.IsBlockedClient(ip, string(params.ClientID))
return &clientJSON{
Name: host,
IDs: []string{idStr},
WHOIS: whois,
Disallowed: &disallowed,
DisallowedRule: &rule,
} }
disallowed, rule := clients.clientChecker.IsBlockedClient(ip, idStr)
cj.Disallowed, cj.DisallowedRule = &disallowed, &rule
return cj
} }
// RegisterClientsHandlers registers HTTP handlers // RegisterClientsHandlers registers HTTP handlers

View File

@@ -153,7 +153,7 @@ func TestClientsContainer_HandleAddClient(t *testing.T) {
clientTwo := newPersistentClientWithIDs(t, "client2", []string{testClientIP2}) clientTwo := newPersistentClientWithIDs(t, "client2", []string{testClientIP2})
clientEmptyID := newPersistentClient("empty_client_id") clientEmptyID := newPersistentClient("empty_client_id")
clientEmptyID.ClientIDs = []client.ClientID{""} clientEmptyID.ClientIDs = []string{""}
testCases := []struct { testCases := []struct {
name string name string
@@ -278,7 +278,7 @@ func TestClientsContainer_HandleUpdateClient(t *testing.T) {
clientModified := newPersistentClientWithIDs(t, "client2", []string{testClientIP2}) clientModified := newPersistentClientWithIDs(t, "client2", []string{testClientIP2})
clientEmptyID := newPersistentClient("empty_client_id") clientEmptyID := newPersistentClient("empty_client_id")
clientEmptyID.ClientIDs = []client.ClientID{""} clientEmptyID.ClientIDs = []string{""}
testCases := []struct { testCases := []struct {
name string name string

View File

@@ -2,6 +2,7 @@ package home
import ( import (
"bytes" "bytes"
"context"
"fmt" "fmt"
"net/netip" "net/netip"
"os" "os"
@@ -748,7 +749,8 @@ func (c *configuration) write(tlsMgr *tlsManager) (err error) {
defer c.Unlock() defer c.Unlock()
if globalContext.auth != nil { if globalContext.auth != nil {
config.Users = globalContext.auth.usersList() // TODO(s.chzhen): Pass context.
config.Users = globalContext.auth.usersList(context.TODO())
} }
if tlsMgr != nil { if tlsMgr != nil {

View File

@@ -392,6 +392,8 @@ const PasswordMinRunes = 8
// Apply new configuration, start DNS server, restart Web server // Apply new configuration, start DNS server, restart Web server
func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request) { func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
req, restartHTTP, err := decodeApplyConfigReq(r.Body) req, restartHTTP, err := decodeApplyConfigReq(r.Body)
if err != nil { if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
@@ -439,7 +441,7 @@ func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request
u := &webUser{ u := &webUser{
Name: req.Username, Name: req.Username,
} }
err = globalContext.auth.addUser(u, req.Password) err = globalContext.auth.addUser(ctx, u, req.Password)
if err != nil { if err != nil {
globalContext.firstRun = true globalContext.firstRun = true
copyInstallSettings(config, curConfig) copyInstallSettings(config, curConfig)
@@ -452,7 +454,7 @@ func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request
// moment we'll allow setting up TLS in the initial configuration or the // moment we'll allow setting up TLS in the initial configuration or the
// configuration itself will use HTTPS protocol, because the underlying // configuration itself will use HTTPS protocol, because the underlying
// functions potentially restart the HTTPS server. // functions potentially restart the HTTPS server.
err = startMods(r.Context(), web.baseLogger, web.tlsManager) err = startMods(ctx, web.baseLogger, web.tlsManager)
if err != nil { if err != nil {
globalContext.firstRun = true globalContext.firstRun = true
copyInstallSettings(config, curConfig) copyInstallSettings(config, curConfig)
@@ -488,11 +490,11 @@ func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request
// and with its own context, because it waits until all requests are handled // and with its own context, because it waits until all requests are handled
// and will be blocked by it's own caller. // and will be blocked by it's own caller.
go func(timeout time.Duration) { go func(timeout time.Duration) {
ctx, cancel := context.WithTimeout(context.Background(), timeout) shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
defer slogutil.RecoverAndLog(ctx, web.logger) defer slogutil.RecoverAndLog(shutdownCtx, web.logger)
defer cancel() defer cancel()
shutdownSrv(ctx, web.logger, web.httpServer) shutdownSrv(shutdownCtx, web.logger, web.httpServer)
}(shutdownTimeout) }(shutdownTimeout)
} }

View File

@@ -347,6 +347,13 @@ func newDNSTLSConfig(
return nil, fmt.Errorf(format, err) return nil, fmt.Errorf(format, err)
} }
// Unencrypted DoH is managed by AdGuard Home itself, not by dnsproxy.
// Therefore, avoid setting the certificate property to prevent dnsproxy
// from starting encrypted listeners. See [dnsforward.Server.prepareTLS].
if conf.AllowUnencryptedDoH {
return dnsConf, nil
}
dnsConf.Cert = &cert dnsConf.Cert = &cert
return dnsConf, nil return dnsConf, nil

View File

@@ -668,7 +668,7 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}, sigHdlr *signalH
GLMode = opts.glinetMode GLMode = opts.glinetMode
// Init auth module. // Init auth module.
globalContext.auth, err = initUsers() globalContext.auth, err = initUsers(ctx, slogLogger)
fatalOnError(err) fatalOnError(err)
web, err := initWeb(ctx, opts, clientBuildFS, upd, slogLogger, tlsMgr, customURL) web, err := initWeb(ctx, opts, clientBuildFS, upd, slogLogger, tlsMgr, customURL)
@@ -786,7 +786,8 @@ func checkPermissions(
} }
// initUsers initializes context auth module. Clears config users field. // initUsers initializes context auth module. Clears config users field.
func initUsers() (auth *Auth, err error) { // baseLogger must not be nil.
func initUsers(ctx context.Context, baseLogger *slog.Logger) (auth *Auth, err error) {
sessFilename := filepath.Join(globalContext.getDataDir(), "sessions.db") sessFilename := filepath.Join(globalContext.getDataDir(), "sessions.db")
var rateLimiter *authRateLimiter var rateLimiter *authRateLimiter
@@ -799,10 +800,17 @@ func initUsers() (auth *Auth, err error) {
trustedProxies := netutil.SliceSubnetSet(netutil.UnembedPrefixes(config.DNS.TrustedProxies)) trustedProxies := netutil.SliceSubnetSet(netutil.UnembedPrefixes(config.DNS.TrustedProxies))
sessionTTL := time.Duration(config.HTTPConfig.SessionTTL).Seconds() auth, err = InitAuth(
auth = InitAuth(sessFilename, config.Users, uint32(sessionTTL), rateLimiter, trustedProxies) ctx,
if auth == nil { baseLogger,
return nil, errors.Error("initializing auth module failed") sessFilename,
config.Users,
time.Duration(config.HTTPConfig.SessionTTL),
rateLimiter,
trustedProxies,
)
if err != nil {
return nil, fmt.Errorf("initializing auth module: %w", err)
} }
config.Users = nil config.Users = nil
@@ -916,7 +924,7 @@ func cleanup(ctx context.Context) {
globalContext.web = nil globalContext.web = nil
} }
if globalContext.auth != nil { if globalContext.auth != nil {
globalContext.auth.Close() globalContext.auth.Close(ctx)
globalContext.auth = nil globalContext.auth = nil
} }

View File

@@ -8,7 +8,7 @@ import (
"net/url" "net/url"
"path" "path"
"github.com/AdguardTeam/AdGuardHome/internal/client" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/httphdr" "github.com/AdguardTeam/golibs/httphdr"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
@@ -151,7 +151,7 @@ func handleMobileConfig(w http.ResponseWriter, r *http.Request, dnsp string) {
clientID := q.Get("client_id") clientID := q.Get("client_id")
if clientID != "" { if clientID != "" {
err = client.ValidateClientID(clientID) err = dnsforward.ValidateClientID(clientID)
if err != nil { if err != nil {
respondJSONError(w, http.StatusBadRequest, err.Error()) respondJSONError(w, http.StatusBadRequest, err.Error())

View File

@@ -47,7 +47,11 @@ type profileJSON struct {
// handleGetProfile is the handler for GET /control/profile endpoint. // handleGetProfile is the handler for GET /control/profile endpoint.
func handleGetProfile(w http.ResponseWriter, r *http.Request) { func handleGetProfile(w http.ResponseWriter, r *http.Request) {
name := ""
u := globalContext.auth.getCurrentUser(r) u := globalContext.auth.getCurrentUser(r)
if u != nil {
name = string(u.Login)
}
var resp profileJSON var resp profileJSON
func() { func() {
@@ -55,7 +59,7 @@ func handleGetProfile(w http.ResponseWriter, r *http.Request) {
defer config.RUnlock() defer config.RUnlock()
resp = profileJSON{ resp = profileJSON{
Name: u.Name, Name: name,
Language: config.Language, Language: config.Language,
Theme: config.Theme, Theme: config.Theme,
} }

View File

@@ -204,8 +204,6 @@ func assertCertSerialNumber(tb testing.TB, conf *tlsConfigSettings, wantSN int64
func TestTLSManager_Reload(t *testing.T) { func TestTLSManager_Reload(t *testing.T) {
storeGlobals(t) storeGlobals(t)
config.DNS.Port = 0
var ( var (
logger = slogutil.NewDiscardLogger() logger = slogutil.NewDiscardLogger()
ctx = testutil.ContextWithTimeout(t, testTimeout) ctx = testutil.ContextWithTimeout(t, testTimeout)
@@ -262,10 +260,6 @@ func TestTLSManager_Reload(t *testing.T) {
m.reload(ctx) m.reload(ctx)
// The [tlsManager.reload] method will start the DNS server and it should be
// stopped after the test ends.
testutil.CleanupAndRequireSuccess(t, globalContext.dnsServer.Stop)
conf = m.config() conf = m.config()
assertCertSerialNumber(t, conf, snAfter) assertCertSerialNumber(t, conf, snAfter)
} }

View File

@@ -980,8 +980,7 @@
- 'clients' - 'clients'
'operationId': 'clientsSearch' 'operationId': 'clientsSearch'
'summary': > 'summary': >
Retrieve information about clients by performing an exact match search Get information about clients by their IP addresses, CIDRs, MAC addresses, or ClientIDs.
using IP addresses, CIDRs, MAC addresses, or ClientIDs.
'requestBody': 'requestBody':
'content': 'content':
'application/json': 'application/json':