diff --git a/CHANGELOG.md b/CHANGELOG.md index c7ebcb85..388ff3dc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,10 @@ See also the [v0.107.61 GitHub milestone][ms-v0.107.61]. **NOTE:** We thank [Xiang Li][mr-xiang-li] for reporting this security issue. It's strongly recommended to leave it enabled, otherwise AdGuard Home will be vulnerable to untrusted clients. +### Fixed + +- Searching for persistent clients using an exact match for CIDR in the `POST /clients/search HTTP API`. + [mr-xiang-li]: https://lixiang521.com/ [ms-v0.107.61]: https://github.com/AdguardTeam/AdGuardHome/milestone/96?closed=1 diff --git a/internal/client/client.go b/internal/client/client.go index c24c846c..f72ba4de 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -11,8 +11,34 @@ import ( "slices" "github.com/AdguardTeam/AdGuardHome/internal/whois" + "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/netutil" ) +// ClientID is a unique identifier for a persistent client used in +// DNS-over-HTTPS, DNS-over-TLS, and DNS-over-QUIC queries. +// +// TODO(s.chzhen): Use everywhere. +type ClientID string + +// ValidateClientID returns an error if id is not a valid ClientID. +// +// TODO(s.chzhen): Consider implementing [validate.Interface] for ClientID. +func ValidateClientID(id string) (err error) { + err = netutil.ValidateHostnameLabel(id) + if err != nil { + // Replace the domain name label wrapper with our own. + return fmt.Errorf("invalid clientid %q: %w", id, errors.Unwrap(err)) + } + + return nil +} + +// isValidClientID returns false if id is not a valid ClientID. +func isValidClientID(id string) (ok bool) { + return netutil.IsValidHostnameLabel(id) +} + // Source represents the source from which the information about the client has // been obtained. type Source uint8 diff --git a/internal/client/index.go b/internal/client/index.go index d34e0e51..a900ab14 100644 --- a/internal/client/index.go +++ b/internal/client/index.go @@ -35,7 +35,7 @@ type index struct { nameToUID map[string]UID // clientIDToUID maps ClientID to UID. - clientIDToUID map[string]UID + clientIDToUID map[ClientID]UID // ipToUID maps IP address to UID. ipToUID map[netip.Addr]UID @@ -54,7 +54,7 @@ type index struct { func newIndex() (ci *index) { return &index{ nameToUID: map[string]UID{}, - clientIDToUID: map[string]UID{}, + clientIDToUID: map[ClientID]UID{}, ipToUID: map[netip.Addr]UID{}, subnetToUID: aghalg.NewSortedMap[netip.Prefix, UID](subnetCompare), macToUID: map[macKey]UID{}, @@ -207,7 +207,7 @@ func (ci *index) clashesMAC(c *Persistent) (p *Persistent, mac net.HardwareAddr) // find finds persistent client by string representation of the ClientID, IP // address, or MAC. func (ci *index) find(id string) (c *Persistent, ok bool) { - c, ok = ci.findByClientID(id) + c, ok = ci.findByClientID(ClientID(id)) if ok { return c, true } @@ -230,7 +230,7 @@ func (ci *index) find(id string) (c *Persistent, ok bool) { } // findByClientID finds persistent client by ClientID. -func (ci *index) findByClientID(clientID string) (c *Persistent, ok bool) { +func (ci *index) findByClientID(clientID ClientID) (c *Persistent, ok bool) { uid, ok := ci.clientIDToUID[clientID] if ok { return ci.uidToClient[uid], true @@ -275,6 +275,26 @@ func (ci *index) findByIP(ip netip.Addr) (c *Persistent, found bool) { return nil, false } +// findByCIDR searches for a persistent client with the provided subnet as an +// identifier. Note that this function looks for an exact match of subnets, +// rather than checking if one subnet contains another. +func (ci *index) findByCIDR(subnet netip.Prefix) (c *Persistent, ok bool) { + var uid UID + for pref, id := range ci.subnetToUID.Range { + if subnet == pref { + uid, ok = id, true + + break + } + } + + if ok { + return ci.uidToClient[uid], true + } + + return nil, false +} + // findByMAC finds persistent client by MAC. func (ci *index) findByMAC(mac net.HardwareAddr) (c *Persistent, found bool) { k := macToKey(mac) diff --git a/internal/client/index_internal_test.go b/internal/client/index_internal_test.go index f514b995..7fd4b0a2 100644 --- a/internal/client/index_internal_test.go +++ b/internal/client/index_internal_test.go @@ -5,6 +5,7 @@ import ( "net/netip" "testing" + "github.com/AdguardTeam/golibs/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -58,12 +59,12 @@ func TestClientIndex_Find(t *testing.T) { clientWithMAC = &Persistent{ Name: "client_with_mac", - MACs: []net.HardwareAddr{mustParseMAC(cliMAC)}, + MACs: []net.HardwareAddr{errors.Must(net.ParseMAC(cliMAC))}, } clientWithID = &Persistent{ Name: "client_with_id", - ClientIDs: []string{cliID}, + ClientIDs: []ClientID{cliID}, } clientLinkLocal = &Persistent{ @@ -141,10 +142,10 @@ func TestClientIndex_Clashes(t *testing.T) { Subnets: []netip.Prefix{netip.MustParsePrefix(cliSubnet)}, }, { Name: "client_with_mac", - MACs: []net.HardwareAddr{mustParseMAC(cliMAC)}, + MACs: []net.HardwareAddr{errors.Must(net.ParseMAC(cliMAC))}, }, { Name: "client_with_id", - ClientIDs: []string{cliID}, + ClientIDs: []ClientID{cliID}, }} ci := newIDIndex(clients) @@ -181,17 +182,6 @@ func TestClientIndex_Clashes(t *testing.T) { } } -// mustParseMAC is wrapper around [net.ParseMAC] that panics if there is an -// error. -func mustParseMAC(s string) (mac net.HardwareAddr) { - mac, err := net.ParseMAC(s) - if err != nil { - panic(err) - } - - return mac -} - func TestMACToKey(t *testing.T) { testCases := []struct { want any @@ -200,44 +190,44 @@ func TestMACToKey(t *testing.T) { }{{ name: "column6", in: "00:00:5e:00:53:01", - want: [6]byte(mustParseMAC("00:00:5e:00:53:01")), + want: [6]byte(errors.Must(net.ParseMAC("00:00:5e:00:53:01"))), }, { name: "column8", in: "02:00:5e:10:00:00:00:01", - want: [8]byte(mustParseMAC("02:00:5e:10:00:00:00:01")), + want: [8]byte(errors.Must(net.ParseMAC("02:00:5e:10:00:00:00:01"))), }, { name: "column20", in: "00:00:00:00:fe:80:00:00:00:00:00:00:02:00:5e:10:00:00:00:01", - want: [20]byte(mustParseMAC("00:00:00:00:fe:80:00:00:00:00:00:00:02:00:5e:10:00:00:00:01")), + want: [20]byte(errors.Must(net.ParseMAC("00:00:00:00:fe:80:00:00:00:00:00:00:02:00:5e:10:00:00:00:01"))), }, { name: "hyphen6", in: "00-00-5e-00-53-01", - want: [6]byte(mustParseMAC("00-00-5e-00-53-01")), + want: [6]byte(errors.Must(net.ParseMAC("00-00-5e-00-53-01"))), }, { name: "hyphen8", in: "02-00-5e-10-00-00-00-01", - want: [8]byte(mustParseMAC("02-00-5e-10-00-00-00-01")), + want: [8]byte(errors.Must(net.ParseMAC("02-00-5e-10-00-00-00-01"))), }, { name: "hyphen20", in: "00-00-00-00-fe-80-00-00-00-00-00-00-02-00-5e-10-00-00-00-01", - want: [20]byte(mustParseMAC("00-00-00-00-fe-80-00-00-00-00-00-00-02-00-5e-10-00-00-00-01")), + want: [20]byte(errors.Must(net.ParseMAC("00-00-00-00-fe-80-00-00-00-00-00-00-02-00-5e-10-00-00-00-01"))), }, { name: "dot6", in: "0000.5e00.5301", - want: [6]byte(mustParseMAC("0000.5e00.5301")), + want: [6]byte(errors.Must(net.ParseMAC("0000.5e00.5301"))), }, { name: "dot8", in: "0200.5e10.0000.0001", - want: [8]byte(mustParseMAC("0200.5e10.0000.0001")), + want: [8]byte(errors.Must(net.ParseMAC("0200.5e10.0000.0001"))), }, { name: "dot20", in: "0000.0000.fe80.0000.0000.0000.0200.5e10.0000.0001", - want: [20]byte(mustParseMAC("0000.0000.fe80.0000.0000.0000.0200.5e10.0000.0001")), + want: [20]byte(errors.Must(net.ParseMAC("0000.0000.fe80.0000.0000.0000.0200.5e10.0000.0001"))), }} for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - mac := mustParseMAC(tc.in) + mac := errors.Must(net.ParseMAC(tc.in)) key := macToKey(mac) assert.Equal(t, tc.want, key) @@ -302,19 +292,19 @@ func TestIndex_FindByIPWithoutZone(t *testing.T) { func TestClientIndex_RangeByName(t *testing.T) { sortedClients := []*Persistent{{ Name: "clientA", - ClientIDs: []string{"A"}, + ClientIDs: []ClientID{"A"}, }, { Name: "clientB", - ClientIDs: []string{"B"}, + ClientIDs: []ClientID{"B"}, }, { Name: "clientC", - ClientIDs: []string{"C"}, + ClientIDs: []ClientID{"C"}, }, { Name: "clientD", - ClientIDs: []string{"D"}, + ClientIDs: []ClientID{"D"}, }, { Name: "clientE", - ClientIDs: []string{"E"}, + ClientIDs: []ClientID{"E"}, }} testCases := []struct { @@ -349,3 +339,115 @@ func TestClientIndex_RangeByName(t *testing.T) { }) } } + +func TestIndex_FindByName(t *testing.T) { + const ( + clientExistingName = "client_existing" + clientAnotherExistingName = "client_another_existing" + nonExistingClientName = "client_non_existing" + ) + + var ( + clientExisting = &Persistent{ + Name: clientExistingName, + IPs: []netip.Addr{netip.MustParseAddr("192.0.2.1")}, + } + + clientAnotherExisting = &Persistent{ + Name: clientAnotherExistingName, + IPs: []netip.Addr{netip.MustParseAddr("192.0.2.2")}, + } + ) + + clients := []*Persistent{ + clientExisting, + clientAnotherExisting, + } + ci := newIDIndex(clients) + + testCases := []struct { + want *Persistent + found assert.BoolAssertionFunc + name string + clientName string + }{{ + want: clientExisting, + found: assert.True, + name: "existing", + clientName: clientExistingName, + }, { + want: clientAnotherExisting, + found: assert.True, + name: "another_existing", + clientName: clientAnotherExistingName, + }, { + want: nil, + found: assert.False, + name: "non_existing", + clientName: nonExistingClientName, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c, ok := ci.findByName(tc.clientName) + assert.Equal(t, tc.want, c) + tc.found(t, ok) + }) + } +} + +func TestIndex_FindByMAC(t *testing.T) { + var ( + cliMAC = errors.Must(net.ParseMAC("11:11:11:11:11:11")) + cliAnotherMAC = errors.Must(net.ParseMAC("22:22:22:22:22:22")) + nonExistingClientMAC = errors.Must(net.ParseMAC("33:33:33:33:33:33")) + ) + + var ( + clientExisting = &Persistent{ + Name: "client", + MACs: []net.HardwareAddr{cliMAC}, + } + + clientAnotherExisting = &Persistent{ + Name: "another_client", + MACs: []net.HardwareAddr{cliAnotherMAC}, + } + ) + + clients := []*Persistent{ + clientExisting, + clientAnotherExisting, + } + ci := newIDIndex(clients) + + testCases := []struct { + want *Persistent + found assert.BoolAssertionFunc + name string + clientMAC net.HardwareAddr + }{{ + want: clientExisting, + found: assert.True, + name: "existing", + clientMAC: cliMAC, + }, { + want: clientAnotherExisting, + found: assert.True, + name: "another_existing", + clientMAC: cliAnotherMAC, + }, { + want: nil, + found: assert.False, + name: "non_existing", + clientMAC: nonExistingClientMAC, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c, ok := ci.findByMAC(tc.clientMAC) + assert.Equal(t, tc.want, c) + tc.found(t, ok) + }) + } +} diff --git a/internal/client/persistent.go b/internal/client/persistent.go index 4ec3695e..56bd8f47 100644 --- a/internal/client/persistent.go +++ b/internal/client/persistent.go @@ -15,7 +15,6 @@ import ( "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/logutil/slogutil" - "github.com/AdguardTeam/golibs/netutil" "github.com/google/uuid" ) @@ -90,7 +89,7 @@ type Persistent struct { // ClientIDs identifying the client. The client must have at least one ID // (IP, subnet, MAC, or ClientID). - ClientIDs []string + ClientIDs []ClientID // UID is the unique identifier of the persistent client. UID UID @@ -134,7 +133,7 @@ func (c *Persistent) validate(ctx context.Context, l *slog.Logger, allTags []str switch { case c.Name == "": return errors.Error("empty name") - case c.IDsLen() == 0: + case c.idendifiersLen() == 0: return errors.Error("id required") case c.UID == UID{}: return errors.Error("uid required") @@ -237,28 +236,15 @@ func (c *Persistent) setID(id string) (err error) { return err } - c.ClientIDs = append(c.ClientIDs, strings.ToLower(id)) + c.ClientIDs = append(c.ClientIDs, ClientID(strings.ToLower(id))) return nil } -// ValidateClientID returns an error if id is not a valid ClientID. -// -// TODO(s.chzhen): It's an exact copy of the [dnsforward.ValidateClientID] to -// avoid the import cycle. Remove it. -func ValidateClientID(id string) (err error) { - err = netutil.ValidateHostnameLabel(id) - if err != nil { - // Replace the domain name label wrapper with our own. - return fmt.Errorf("invalid clientid %q: %w", id, errors.Unwrap(err)) - } - - return nil -} - -// IDs returns a list of ClientIDs containing at least one element. -func (c *Persistent) IDs() (ids []string) { - ids = make([]string, 0, c.IDsLen()) +// Identifiers returns a list of client identifiers containing at least one +// element. +func (c *Persistent) Identifiers() (ids []string) { + ids = make([]string, 0, c.idendifiersLen()) for _, ip := range c.IPs { ids = append(ids, ip.String()) @@ -272,11 +258,15 @@ func (c *Persistent) IDs() (ids []string) { ids = append(ids, mac.String()) } - return append(ids, c.ClientIDs...) + for _, cid := range c.ClientIDs { + ids = append(ids, string(cid)) + } + + return ids } -// IDsLen returns a length of ClientIDs. -func (c *Persistent) IDsLen() (n int) { +// identifiersLen returns the number of client identifiers. +func (c *Persistent) idendifiersLen() (n int) { return len(c.IPs) + len(c.Subnets) + len(c.MACs) + len(c.ClientIDs) } diff --git a/internal/client/storage.go b/internal/client/storage.go index 76696d79..c5382648 100644 --- a/internal/client/storage.go +++ b/internal/client/storage.go @@ -7,6 +7,7 @@ import ( "net" "net/netip" "slices" + "strings" "sync" "time" @@ -18,6 +19,7 @@ import ( "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/hostsfile" "github.com/AdguardTeam/golibs/logutil/slogutil" + "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/timeutil" ) @@ -433,48 +435,186 @@ func (s *Storage) Add(ctx context.Context, p *Persistent) (err error) { ctx, "client added", "name", p.Name, - "ids", p.IDs(), + "ids", p.Identifiers(), "clients_count", s.index.size(), ) return nil } -// FindByName finds persistent client by name. And returns its shallow copy. -func (s *Storage) FindByName(name string) (p *Persistent, ok bool) { - s.mu.Lock() - defer s.mu.Unlock() +// FindParams represents the parameters for searching a client. At least one +// field must be non-empty. +type FindParams struct { + // ClientID is a unique identifier for the client used in DoH, DoT, and DoQ + // DNS queries. + ClientID ClientID - p, ok = s.index.findByName(name) - if ok { - return p.ShallowClone(), ok - } + // RemoteIP is the IP address used as a client search parameter. + RemoteIP netip.Addr - return nil, false + // Subnet is the CIDR used as a client search parameter. + Subnet netip.Prefix + + // MAC is the physical hardware address used as a client search parameter. + MAC net.HardwareAddr + + // UID is the unique ID of persistent client used as a search parameter. + // + // TODO(s.chzhen): Use this. + UID UID } -// Find finds persistent client by string representation of the ClientID, IP -// address, or MAC. And returns its shallow copy. +// ErrBadIdentifier is returned by [FindParams.Set] when it cannot parse the +// provided client identifier. +const ErrBadIdentifier errors.Error = "bad client identifier" + +// Set clears the stored search parameters and parses the string representation +// of the search parameter into typed parameter, storing it. In some cases, it +// may result in storing both an IP address and a MAC address because they might +// have identical string representations. It returns [ErrBadIdentifier] if id +// cannot be parsed. // -// TODO(s.chzhen): Accept ClientIDData structure instead, which will contain -// the parsed IP address, if any. -func (s *Storage) Find(id string) (p *Persistent, ok bool) { +// TODO(s.chzhen): Add support for UID. +func (p *FindParams) Set(id string) (err error) { + *p = FindParams{} + + isClientID := true + + if netutil.IsValidIPString(id) { + // It is safe to use [netip.MustParseAddr] because it has already been + // validated that id contains the string representation of the IP + // address. + p.RemoteIP = netip.MustParseAddr(id) + + // Even if id can be parsed as an IP address, it may be a MAC address. + // So do not return prematurely, continue parsing. + isClientID = false + } + + if canBeValidIPPrefixString(id) { + p.Subnet, err = netip.ParsePrefix(id) + if err == nil { + isClientID = false + } + } + + if canBeMACString(id) { + p.MAC, err = net.ParseMAC(id) + if err == nil { + isClientID = false + } + } + + if !isClientID { + return nil + } + + if !isValidClientID(id) { + return ErrBadIdentifier + } + + p.ClientID = ClientID(id) + + return nil +} + +// canBeValidIPPrefixString is a best-effort check to determine if s is a valid +// CIDR before using [netip.ParsePrefix], aimed at reducing allocations. +// +// TODO(s.chzhen): Replace this implementation with the more robust version +// from golibs. +func canBeValidIPPrefixString(s string) (ok bool) { + ipStr, bitStr, ok := strings.Cut(s, "/") + if !ok { + return false + } + + if bitStr == "" || len(bitStr) > 3 { + return false + } + + bits := 0 + for _, c := range bitStr { + if c < '0' || c > '9' { + return false + } + + bits = bits*10 + int(c-'0') + } + + if bits > 128 { + return false + } + + return netutil.IsValidIPString(ipStr) +} + +// canBeMACString is a best-effort check to determine if s is a valid MAC +// address before using [net.ParseMAC], aimed at reducing allocations. +// +// TODO(s.chzhen): Replace this implementation with the more robust version +// from golibs. +func canBeMACString(s string) (ok bool) { + switch len(s) { + case + len("0000.0000.0000"), + len("00:00:00:00:00:00"), + len("0000.0000.0000.0000"), + len("00:00:00:00:00:00:00:00"), + len("0000.0000.0000.0000.0000.0000.0000.0000.0000.0000"), + len("00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00"): + return true + default: + return false + } +} + +// Find represents the parameters for searching a client. params must not be +// nil and must have at least one non-empty field. +func (s *Storage) Find(params *FindParams) (p *Persistent, ok bool) { s.mu.Lock() defer s.mu.Unlock() - p, ok = s.index.find(id) + isClientID := params.ClientID != "" + isRemoteIP := params.RemoteIP != (netip.Addr{}) + isSubnet := params.Subnet != (netip.Prefix{}) + isMAC := params.MAC != nil + + for { + switch { + case isClientID: + isClientID = false + p, ok = s.index.findByClientID(params.ClientID) + case isRemoteIP: + isRemoteIP = false + p, ok = s.findByIP(params.RemoteIP) + case isSubnet: + isSubnet = false + p, ok = s.index.findByCIDR(params.Subnet) + case isMAC: + isMAC = false + p, ok = s.index.findByMAC(params.MAC) + default: + return nil, false + } + + if ok { + return p.ShallowClone(), true + } + } +} + +// findByIP finds persistent client by IP address. s.mu is expected to be +// locked. +func (s *Storage) findByIP(addr netip.Addr) (p *Persistent, ok bool) { + p, ok = s.index.findByIP(addr) if ok { - return p.ShallowClone(), ok + return p, true } - ip, err := netip.ParseAddr(id) - if err != nil { - return nil, false - } - - foundMAC := s.dhcp.MACByIP(ip) + foundMAC := s.dhcp.MACByIP(addr) if foundMAC != nil { - return s.FindByMAC(foundMAC) + return s.index.findByMAC(foundMAC) } return nil, false @@ -487,6 +627,8 @@ func (s *Storage) Find(id string) (p *Persistent, ok bool) { // // Note that multiple clients can have the same IP address with different zones. // Therefore, the result of this method is indeterminate. +// +// TODO(s.chzhen): Consider accepting [FindParams]. func (s *Storage) FindLoose(ip netip.Addr, id string) (p *Persistent, ok bool) { s.mu.Lock() defer s.mu.Unlock() @@ -498,7 +640,7 @@ func (s *Storage) FindLoose(ip netip.Addr, id string) (p *Persistent, ok bool) { foundMAC := s.dhcp.MACByIP(ip) if foundMAC != nil { - return s.FindByMAC(foundMAC) + return s.index.findByMAC(foundMAC) } p = s.index.findByIPWithoutZone(ip) @@ -509,17 +651,6 @@ func (s *Storage) FindLoose(ip netip.Addr, id string) (p *Persistent, ok bool) { return nil, false } -// FindByMAC finds persistent client by MAC and returns its shallow copy. s.mu -// is expected to be locked. -func (s *Storage) FindByMAC(mac net.HardwareAddr) (p *Persistent, ok bool) { - p, ok = s.index.findByMAC(mac) - if ok { - return p.ShallowClone(), ok - } - - return nil, false -} - // RemoveByName removes persistent client information. ok is false if no such // client exists by that name. func (s *Storage) RemoveByName(ctx context.Context, name string) (ok bool) { @@ -648,7 +779,7 @@ func (s *Storage) CustomUpstreamConfig( s.mu.Lock() defer s.mu.Unlock() - c, ok := s.index.findByClientID(id) + c, ok := s.index.findByClientID(ClientID(id)) if !ok { c, ok = s.index.findByIP(addr) } @@ -682,7 +813,7 @@ func (s *Storage) ClearUpstreamCache() { // ClientID or client IP address, and applies it to the filtering settings. // setts must not be nil. func (s *Storage) ApplyClientFiltering(id string, addr netip.Addr, setts *filtering.Settings) { - c, ok := s.index.findByClientID(id) + c, ok := s.index.findByClientID(ClientID(id)) if !ok { c, ok = s.index.findByIP(addr) } @@ -690,7 +821,7 @@ func (s *Storage) ApplyClientFiltering(id string, addr netip.Addr, setts *filter if !ok { foundMAC := s.dhcp.MACByIP(addr) if foundMAC != nil { - c, ok = s.FindByMAC(foundMAC) + c, ok = s.index.findByMAC(foundMAC) } } diff --git a/internal/client/storage_test.go b/internal/client/storage_test.go index d00f8350..7d6051ca 100644 --- a/internal/client/storage_test.go +++ b/internal/client/storage_test.go @@ -15,6 +15,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/AdGuardHome/internal/whois" + "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/hostsfile" "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/testutil" @@ -350,15 +351,15 @@ func TestClientsDHCP(t *testing.T) { cliName1 = "one.dhcp" cliIP2 = netip.MustParseAddr("2.2.2.2") - cliMAC2 = mustParseMAC("22:22:22:22:22:22") + cliMAC2 = errors.Must(net.ParseMAC("22:22:22:22:22:22")) cliName2 = "two.dhcp" cliIP3 = netip.MustParseAddr("3.3.3.3") - cliMAC3 = mustParseMAC("33:33:33:33:33:33") + cliMAC3 = errors.Must(net.ParseMAC("33:33:33:33:33:33")) cliName3 = "three.dhcp" prsCliIP = netip.MustParseAddr("4.3.2.1") - prsCliMAC = mustParseMAC("AA:AA:AA:AA:AA:AA") + prsCliMAC = errors.Must(net.ParseMAC("AA:AA:AA:AA:AA:AA")) prsCliName = "persistent.dhcp" otherARPCliName = "other.arp" @@ -519,7 +520,11 @@ func TestClientsDHCP(t *testing.T) { }) require.NoError(t, err) - prsCli, ok := storage.Find(prsCliIP.String()) + params := &client.FindParams{} + err = params.Set(prsCliIP.String()) + require.NoError(t, err) + + prsCli, ok := storage.Find(params) require.True(t, ok) assert.Equal(t, prsCliName, prsCli.Name) @@ -663,17 +668,6 @@ func newStorage(tb testing.TB, m []*client.Persistent) (s *client.Storage) { return s } -// mustParseMAC is wrapper around [net.ParseMAC] that panics if there is an -// error. -func mustParseMAC(s string) (mac net.HardwareAddr) { - mac, err := net.ParseMAC(s) - if err != nil { - panic(err) - } - - return mac -} - func TestStorage_Add(t *testing.T) { const ( existingName = "existing_name" @@ -693,7 +687,7 @@ func TestStorage_Add(t *testing.T) { Name: existingName, IPs: []netip.Addr{existingIP}, Subnets: []netip.Prefix{existingSubnet}, - ClientIDs: []string{existingClientID}, + ClientIDs: []client.ClientID{existingClientID}, UID: existingClientUID, } @@ -761,7 +755,7 @@ func TestStorage_Add(t *testing.T) { name: "duplicate_client_id", cli: &client.Persistent{ Name: "duplicate_client_id", - ClientIDs: []string{existingClientID}, + ClientIDs: []client.ClientID{existingClientID}, UID: client.MustNewUID(), }, wantErrMsg: `adding client: another client "existing_name" ` + @@ -898,12 +892,12 @@ func TestStorage_Find(t *testing.T) { clientWithMAC = &client.Persistent{ Name: "client_with_mac", - MACs: []net.HardwareAddr{mustParseMAC(cliMAC)}, + MACs: []net.HardwareAddr{errors.Must(net.ParseMAC(cliMAC))}, } clientWithID = &client.Persistent{ Name: "client_with_id", - ClientIDs: []string{cliID}, + ClientIDs: []client.ClientID{cliID}, } clientLinkLocal = &client.Persistent{ @@ -950,7 +944,11 @@ func TestStorage_Find(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { for _, id := range tc.ids { - c, ok := s.Find(id) + params := &client.FindParams{} + err := params.Set(id) + require.NoError(t, err) + + c, ok := s.Find(params) require.True(t, ok) assert.Equal(t, tc.want, c) @@ -959,7 +957,11 @@ func TestStorage_Find(t *testing.T) { } t.Run("not_found", func(t *testing.T) { - _, ok := s.Find(cliIPNone) + params := &client.FindParams{} + err := params.Set(cliIPNone) + require.NoError(t, err) + + _, ok := s.Find(params) assert.False(t, ok) }) } @@ -1025,127 +1027,6 @@ func TestStorage_FindLoose(t *testing.T) { } } -func TestStorage_FindByName(t *testing.T) { - const ( - cliIP1 = "1.1.1.1" - cliIP2 = "2.2.2.2" - ) - - const ( - clientExistingName = "client_existing" - clientAnotherExistingName = "client_another_existing" - nonExistingClientName = "client_non_existing" - ) - - var ( - clientExisting = &client.Persistent{ - Name: clientExistingName, - IPs: []netip.Addr{netip.MustParseAddr(cliIP1)}, - } - - clientAnotherExisting = &client.Persistent{ - Name: clientAnotherExistingName, - IPs: []netip.Addr{netip.MustParseAddr(cliIP2)}, - } - ) - - clients := []*client.Persistent{ - clientExisting, - clientAnotherExisting, - } - s := newStorage(t, clients) - - testCases := []struct { - want *client.Persistent - name string - clientName string - }{{ - name: "existing", - clientName: clientExistingName, - want: clientExisting, - }, { - name: "another_existing", - clientName: clientAnotherExistingName, - want: clientAnotherExisting, - }, { - name: "non_existing", - clientName: nonExistingClientName, - want: nil, - }} - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - c, ok := s.FindByName(tc.clientName) - if tc.want == nil { - assert.False(t, ok) - - return - } - - assert.True(t, ok) - assert.Equal(t, tc.want, c) - }) - } -} - -func TestStorage_FindByMAC(t *testing.T) { - var ( - cliMAC = mustParseMAC("11:11:11:11:11:11") - cliAnotherMAC = mustParseMAC("22:22:22:22:22:22") - nonExistingClientMAC = mustParseMAC("33:33:33:33:33:33") - ) - - var ( - clientExisting = &client.Persistent{ - Name: "client", - MACs: []net.HardwareAddr{cliMAC}, - } - - clientAnotherExisting = &client.Persistent{ - Name: "another_client", - MACs: []net.HardwareAddr{cliAnotherMAC}, - } - ) - - clients := []*client.Persistent{ - clientExisting, - clientAnotherExisting, - } - s := newStorage(t, clients) - - testCases := []struct { - want *client.Persistent - name string - clientMAC net.HardwareAddr - }{{ - name: "existing", - clientMAC: cliMAC, - want: clientExisting, - }, { - name: "another_existing", - clientMAC: cliAnotherMAC, - want: clientAnotherExisting, - }, { - name: "non_existing", - clientMAC: nonExistingClientMAC, - want: nil, - }} - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - c, ok := s.FindByMAC(tc.clientMAC) - if tc.want == nil { - assert.False(t, ok) - - return - } - - assert.True(t, ok) - assert.Equal(t, tc.want, c) - }) - } -} - func TestStorage_Update(t *testing.T) { const ( clientName = "client_name" @@ -1162,7 +1043,7 @@ func TestStorage_Update(t *testing.T) { Name: obstructingName, IPs: []netip.Addr{obstructingIP}, Subnets: []netip.Prefix{obstructingSubnet}, - ClientIDs: []string{obstructingClientID}, + ClientIDs: []client.ClientID{obstructingClientID}, } clientToUpdate := &client.Persistent{ @@ -1211,7 +1092,7 @@ func TestStorage_Update(t *testing.T) { name: "duplicate_client_id", cli: &client.Persistent{ Name: "duplicate_client_id", - ClientIDs: []string{obstructingClientID}, + ClientIDs: []client.ClientID{obstructingClientID}, UID: client.MustNewUID(), }, wantErrMsg: `updating client: another client "obstructing_name" ` + @@ -1238,19 +1119,19 @@ func TestStorage_Update(t *testing.T) { func TestStorage_RangeByName(t *testing.T) { sortedClients := []*client.Persistent{{ Name: "clientA", - ClientIDs: []string{"A"}, + ClientIDs: []client.ClientID{"A"}, }, { Name: "clientB", - ClientIDs: []string{"B"}, + ClientIDs: []client.ClientID{"B"}, }, { Name: "clientC", - ClientIDs: []string{"C"}, + ClientIDs: []client.ClientID{"C"}, }, { Name: "clientD", - ClientIDs: []string{"D"}, + ClientIDs: []client.ClientID{"D"}, }, { Name: "clientE", - ClientIDs: []string{"E"}, + ClientIDs: []client.ClientID{"E"}, }} testCases := []struct { @@ -1306,7 +1187,7 @@ func TestStorage_CustomUpstreamConfig(t *testing.T) { existingClient := &client.Persistent{ Name: existingName, IPs: []netip.Addr{existingIP}, - ClientIDs: []string{existingClientID}, + ClientIDs: []client.ClientID{existingClientID}, UID: existingClientUID, Upstreams: []string{"192.0.2.0"}, } @@ -1381,3 +1262,182 @@ func TestStorage_CustomUpstreamConfig(t *testing.T) { assert.NotEqual(t, conf, updConf) }) } + +func BenchmarkFindParams_Set(b *testing.B) { + const ( + testIPStr = "192.0.2.1" + testCIDRStr = "192.0.2.0/24" + testMACStr = "02:00:00:00:00:00" + testClientID = "clientid" + ) + + benchCases := []struct { + wantErr error + params *client.FindParams + name string + id string + }{{ + wantErr: nil, + params: &client.FindParams{ + ClientID: testClientID, + }, + name: "client_id", + id: testClientID, + }, { + wantErr: nil, + params: &client.FindParams{ + RemoteIP: netip.MustParseAddr(testIPStr), + }, + name: "ip_address", + id: testIPStr, + }, { + wantErr: nil, + params: &client.FindParams{ + Subnet: netip.MustParsePrefix(testCIDRStr), + }, + name: "subnet", + id: testCIDRStr, + }, { + wantErr: nil, + params: &client.FindParams{ + MAC: errors.Must(net.ParseMAC(testMACStr)), + }, + name: "mac_address", + id: testMACStr, + }, { + wantErr: client.ErrBadIdentifier, + params: &client.FindParams{}, + name: "bad_id", + id: "!@#$%^&*()_+", + }} + + for _, bc := range benchCases { + b.Run(bc.name, func(b *testing.B) { + params := &client.FindParams{} + var err error + + b.ReportAllocs() + for b.Loop() { + err = params.Set(bc.id) + } + + assert.ErrorIs(b, err, bc.wantErr) + assert.Equal(b, bc.params, params) + }) + } + + // Most recent results: + // + // goos: linux + // goarch: amd64 + // pkg: github.com/AdguardTeam/AdGuardHome/internal/client + // cpu: Intel(R) Core(TM) i7-10510U CPU @ 1.80GHz + // BenchmarkFindParams_Set/client_id-8 49463488 24.27 ns/op 0 B/op 0 allocs/op + // BenchmarkFindParams_Set/ip_address-8 18740977 62.22 ns/op 0 B/op 0 allocs/op + // BenchmarkFindParams_Set/subnet-8 10848192 110.0 ns/op 0 B/op 0 allocs/op + // BenchmarkFindParams_Set/mac_address-8 8148494 133.2 ns/op 8 B/op 1 allocs/op + // BenchmarkFindParams_Set/bad_id-8 73894278 16.29 ns/op 0 B/op 0 allocs/op +} + +func BenchmarkStorage_Find(b *testing.B) { + const ( + cliID = "cid" + cliMAC = "02:00:00:00:00:00" + ) + + const ( + cliNameWithID = "client_with_id" + cliNameWithIP = "client_with_ip" + cliNameWithCIDR = "client_with_cidr" + cliNameWithMAC = "client_with_mac" + ) + + var ( + cliIP = netip.MustParseAddr("192.0.2.1") + cliCIDR = netip.MustParsePrefix("192.0.2.0/24") + ) + + var ( + clientWithID = &client.Persistent{ + Name: cliNameWithID, + ClientIDs: []client.ClientID{cliID}, + } + clientWithIP = &client.Persistent{ + Name: cliNameWithIP, + IPs: []netip.Addr{cliIP}, + } + clientWithCIDR = &client.Persistent{ + Name: cliNameWithCIDR, + Subnets: []netip.Prefix{cliCIDR}, + } + clientWithMAC = &client.Persistent{ + Name: cliNameWithMAC, + MACs: []net.HardwareAddr{errors.Must(net.ParseMAC(cliMAC))}, + } + ) + + clients := []*client.Persistent{ + clientWithID, + clientWithIP, + clientWithCIDR, + clientWithMAC, + } + s := newStorage(b, clients) + + benchCases := []struct { + params *client.FindParams + name string + wantName string + }{{ + params: &client.FindParams{ + ClientID: cliID, + }, + name: "client_id", + wantName: cliNameWithID, + }, { + params: &client.FindParams{ + RemoteIP: cliIP, + }, + name: "ip_address", + wantName: cliNameWithIP, + }, { + params: &client.FindParams{ + Subnet: cliCIDR, + }, + name: "subnet", + wantName: cliNameWithCIDR, + }, { + params: &client.FindParams{ + MAC: errors.Must(net.ParseMAC(cliMAC)), + }, + name: "mac_address", + wantName: cliNameWithMAC, + }} + + for _, bc := range benchCases { + b.Run(bc.name, func(b *testing.B) { + var p *client.Persistent + var ok bool + + b.ReportAllocs() + for b.Loop() { + p, ok = s.Find(bc.params) + } + + assert.True(b, ok) + assert.NotNil(b, p) + assert.Equal(b, bc.wantName, p.Name) + }) + } + + // Most recent results: + // + // goos: linux + // goarch: amd64 + // pkg: github.com/AdguardTeam/AdGuardHome/internal/client + // cpu: Intel(R) Core(TM) i7-10510U CPU @ 1.80GHz + // BenchmarkStorage_Find/client_id-8 7070107 154.4 ns/op 240 B/op 2 allocs/op + // BenchmarkStorage_Find/ip_address-8 6831823 168.6 ns/op 248 B/op 2 allocs/op + // BenchmarkStorage_Find/subnet-8 7209050 167.5 ns/op 256 B/op 2 allocs/op + // BenchmarkStorage_Find/mac_address-8 5776131 199.7 ns/op 256 B/op 3 allocs/op +} diff --git a/internal/dhcpsvc/dhcpsvc_test.go b/internal/dhcpsvc/dhcpsvc_test.go index f8b993f6..016dc7cd 100644 --- a/internal/dhcpsvc/dhcpsvc_test.go +++ b/internal/dhcpsvc/dhcpsvc_test.go @@ -1,13 +1,11 @@ package dhcpsvc_test import ( - "net" "net/netip" "time" "github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc" "github.com/AdguardTeam/golibs/logutil/slogutil" - "github.com/stretchr/testify/require" ) // testLocalTLD is a common local TLD for tests. @@ -56,11 +54,3 @@ var testInterfaceConf = map[string]*dhcpsvc.InterfaceConfig{ }, }, } - -// mustParseMAC parses a hardware address from s and requires no errors. -func mustParseMAC(t require.TestingT, s string) (mac net.HardwareAddr) { - mac, err := net.ParseMAC(s) - require.NoError(t, err) - - return mac -} diff --git a/internal/dhcpsvc/server_test.go b/internal/dhcpsvc/server_test.go index 94509e37..88123e17 100644 --- a/internal/dhcpsvc/server_test.go +++ b/internal/dhcpsvc/server_test.go @@ -2,6 +2,7 @@ package dhcpsvc_test import ( "io/fs" + "net" "net/netip" "os" "path" @@ -11,6 +12,7 @@ import ( "time" "github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc" + "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -176,9 +178,9 @@ func TestDHCPServer_AddLease(t *testing.T) { newIP = netip.MustParseAddr("192.168.0.3") newIPv6 = netip.MustParseAddr("2001:db8::2") - existMAC = mustParseMAC(t, "01:02:03:04:05:06") - newMAC = mustParseMAC(t, "06:05:04:03:02:01") - ipv6MAC = mustParseMAC(t, "02:03:04:05:06:07") + existMAC = errors.Must(net.ParseMAC("01:02:03:04:05:06")) + newMAC = errors.Must(net.ParseMAC("06:05:04:03:02:01")) + ipv6MAC = errors.Must(net.ParseMAC("02:03:04:05:06:07")) ) require.NoError(t, srv.AddLease(ctx, &dhcpsvc.Lease{ @@ -291,9 +293,9 @@ func TestDHCPServer_index(t *testing.T) { ip3 = netip.MustParseAddr("172.16.0.3") ip4 = netip.MustParseAddr("172.16.0.4") - mac1 = mustParseMAC(t, "01:02:03:04:05:06") - mac2 = mustParseMAC(t, "06:05:04:03:02:01") - mac3 = mustParseMAC(t, "02:03:04:05:06:07") + mac1 = errors.Must(net.ParseMAC("01:02:03:04:05:06")) + mac2 = errors.Must(net.ParseMAC("06:05:04:03:02:01")) + mac3 = errors.Must(net.ParseMAC("02:03:04:05:06:07")) ) t.Run("ip_idx", func(t *testing.T) { @@ -349,9 +351,9 @@ func TestDHCPServer_UpdateStaticLease(t *testing.T) { ip3 = netip.MustParseAddr("192.168.0.4") ip4 = netip.MustParseAddr("2001:db8::3") - mac1 = mustParseMAC(t, "01:02:03:04:05:06") - mac2 = mustParseMAC(t, "06:05:04:03:02:01") - mac3 = mustParseMAC(t, "06:05:04:03:02:02") + mac1 = errors.Must(net.ParseMAC("01:02:03:04:05:06")) + mac2 = errors.Must(net.ParseMAC("06:05:04:03:02:01")) + mac3 = errors.Must(net.ParseMAC("06:05:04:03:02:02")) ) testCases := []struct { @@ -452,9 +454,9 @@ func TestDHCPServer_RemoveLease(t *testing.T) { newIP = netip.MustParseAddr("192.168.0.3") newIPv6 = netip.MustParseAddr("2001:db8::2") - existMAC = mustParseMAC(t, "01:02:03:04:05:06") - newMAC = mustParseMAC(t, "02:03:04:05:06:07") - ipv6MAC = mustParseMAC(t, "06:05:04:03:02:01") + existMAC = errors.Must(net.ParseMAC("01:02:03:04:05:06")) + newMAC = errors.Must(net.ParseMAC("02:03:04:05:06:07")) + ipv6MAC = errors.Must(net.ParseMAC("06:05:04:03:02:01")) ) testCases := []struct { @@ -559,13 +561,13 @@ func TestServer_Leases(t *testing.T) { Expiry: expiry, IP: netip.MustParseAddr("192.168.0.3"), Hostname: "example.host", - HWAddr: mustParseMAC(t, "AA:AA:AA:AA:AA:AA"), + HWAddr: errors.Must(net.ParseMAC("AA:AA:AA:AA:AA:AA")), IsStatic: false, }, { Expiry: time.Time{}, IP: netip.MustParseAddr("192.168.0.4"), Hostname: "example.static.host", - HWAddr: mustParseMAC(t, "BB:BB:BB:BB:BB:BB"), + HWAddr: errors.Must(net.ParseMAC("BB:BB:BB:BB:BB:BB")), IsStatic: true, }} assert.ElementsMatch(t, wantLeases, srv.Leases()) diff --git a/internal/dnsforward/access.go b/internal/dnsforward/access.go index 42e4b758..c5535d30 100644 --- a/internal/dnsforward/access.go +++ b/internal/dnsforward/access.go @@ -10,6 +10,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghalg" "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" + "github.com/AdguardTeam/AdGuardHome/internal/client" "github.com/AdguardTeam/golibs/container" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/stringutil" @@ -51,7 +52,7 @@ func processAccessClients( } else if ipnet, err = netip.ParsePrefix(s); err == nil { *nets = append(*nets, ipnet) } else { - err = ValidateClientID(s) + err = client.ValidateClientID(s) if err != nil { return fmt.Errorf("value %q at index %d: bad ip, cidr, or clientid", s, i) } diff --git a/internal/dnsforward/clientid.go b/internal/dnsforward/clientid.go index 9c18b342..2a2d3825 100644 --- a/internal/dnsforward/clientid.go +++ b/internal/dnsforward/clientid.go @@ -7,26 +7,13 @@ import ( "path" "strings" + "github.com/AdguardTeam/AdGuardHome/internal/client" "github.com/AdguardTeam/dnsproxy/proxy" - "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/netutil" "github.com/quic-go/quic-go" ) -// ValidateClientID returns an error if id is not a valid ClientID. -// -// Keep in sync with [client.ValidateClientID]. -func ValidateClientID(id string) (err error) { - err = netutil.ValidateHostnameLabel(id) - if err != nil { - // Replace the domain name label wrapper with our own. - return fmt.Errorf("invalid clientid %q: %w", id, errors.Unwrap(err)) - } - - return nil -} - // clientIDFromClientServerName extracts and validates a ClientID. hostSrvName // is the server name of the host. cliSrvName is the server name as sent by the // client. When strict is true, and client and host server name don't match, @@ -53,7 +40,7 @@ func clientIDFromClientServerName( } clientID = cliSrvName[:len(cliSrvName)-len(hostSrvName)-1] - err = ValidateClientID(clientID) + err = client.ValidateClientID(clientID) if err != nil { // Don't wrap the error, because it's informative enough as is. return "", err @@ -93,7 +80,7 @@ func clientIDFromDNSContextHTTPS(pctx *proxy.DNSContext) (clientID string, err e return "", fmt.Errorf("clientid check: invalid path %q: extra parts", origPath) } - err = ValidateClientID(clientID) + err = client.ValidateClientID(clientID) if err != nil { return "", fmt.Errorf("clientid check: %w", err) } diff --git a/internal/home/clients.go b/internal/home/clients.go index 781e5e9d..208b326b 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -28,6 +28,10 @@ type clientsContainer struct { // filter. It must not be nil. baseLogger *slog.Logger + // logger is used for logging the operation of the client container. It + // must not be nil. + logger *slog.Logger + // storage stores information about persistent clients. storage *client.Storage @@ -58,6 +62,7 @@ type clientsContainer struct { // BlockedClientChecker checks if a client is blocked by the current access // settings. type BlockedClientChecker interface { + // TODO(s.chzhen): Accept [client.FindParams]. IsBlockedClient(ip netip.Addr, clientID string) (blocked bool, rule string) } @@ -80,6 +85,7 @@ func (clients *clientsContainer) Init( } clients.baseLogger = baseLogger + clients.logger = baseLogger.With(slogutil.KeyPrefix, "client_container") clients.safeSearchCacheSize = filteringConf.SafeSearchCacheSize clients.safeSearchCacheTTL = time.Minute * time.Duration(filteringConf.CacheTime) @@ -269,7 +275,7 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) { BlockedServices: cli.BlockedServices.Clone(), - IDs: cli.IDs(), + IDs: cli.Identifiers(), Tags: slices.Clone(cli.Tags), Upstreams: slices.Clone(cli.Upstreams), @@ -356,15 +362,27 @@ func (clients *clientsContainer) clientOrArtificial( }, true } -// shouldCountClient is a wrapper around [clientsContainer.find] to make it a +// shouldCountClient is a wrapper around [client.Storage.Find] to make it a // valid client information finder for the statistics. If no information about -// the client is found, it returns true. +// the client is found, it returns true. Values of ids must be either a valid +// ClientID or a valid IP address. +// +// TODO(s.chzhen): Accept [client.FindParams]. func (clients *clientsContainer) shouldCountClient(ids []string) (y bool) { clients.lock.Lock() defer clients.lock.Unlock() + params := &client.FindParams{} for _, id := range ids { - client, ok := clients.storage.Find(id) + err := params.Set(id) + if err != nil { + // Should not happen. + clients.logger.Warn("parsing find params", slogutil.KeyError, err) + + continue + } + + client, ok := clients.storage.Find(params) if ok { return !client.IgnoreStatistics } diff --git a/internal/home/clientshttp.go b/internal/home/clientshttp.go index 2971dfea..010df861 100644 --- a/internal/home/clientshttp.go +++ b/internal/home/clientshttp.go @@ -300,7 +300,7 @@ func clientToJSON(c *client.Persistent) (cj *clientJSON) { return &clientJSON{ Name: c.Name, - IDs: c.IDs(), + IDs: c.Identifiers(), Tags: c.Tags, UseGlobalSettings: !c.UseOwnSettings, FilteringEnabled: c.FilteringEnabled, @@ -428,32 +428,53 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht // Deprecated: Remove it when migration to the new API is over. func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http.Request) { q := r.URL.Query() - data := []map[string]*clientJSON{} + data := make([]map[string]*clientJSON, 0, len(q)) + params := &client.FindParams{} + var err error + for i := range len(q) { idStr := q.Get(fmt.Sprintf("ip%d", i)) if idStr == "" { break } + err = params.Set(idStr) + if err != nil { + clients.logger.DebugContext( + r.Context(), + "finding client", + "id", idStr, + slogutil.KeyError, err, + ) + + continue + } + data = append(data, map[string]*clientJSON{ - idStr: clients.findClient(idStr), + idStr: clients.findClient(idStr, params), }) } aghhttp.WriteJSONResponseOK(w, r, data) } -// findClient returns available information about a client by idStr from the -// client's storage or access settings. cj is guaranteed to be non-nil. -func (clients *clientsContainer) findClient(idStr string) (cj *clientJSON) { - ip, _ := netip.ParseAddr(idStr) - c, ok := clients.storage.Find(idStr) +// findClient returns available information about a client by params from the +// client's storage or access settings. idStr is the string representation of +// typed params. params must not be nil. cj is guaranteed to be non-nil. +func (clients *clientsContainer) findClient( + idStr string, + params *client.FindParams, +) (cj *clientJSON) { + c, ok := clients.storage.Find(params) if !ok { - return clients.findRuntime(ip, idStr) + return clients.findRuntime(idStr, params) } cj = clientToJSON(c) - disallowed, rule := clients.clientChecker.IsBlockedClient(ip, idStr) + disallowed, rule := clients.clientChecker.IsBlockedClient( + params.RemoteIP, + string(params.ClientID), + ) cj.Disallowed, cj.DisallowedRule = &disallowed, &rule return cj @@ -472,7 +493,8 @@ type searchClientJSON struct { ID string `json:"id"` } -// handleSearchClient is the handler for the POST /control/clients/search HTTP API. +// handleSearchClient is the handler for the POST /control/clients/search HTTP +// API. func (clients *clientsContainer) handleSearchClient(w http.ResponseWriter, r *http.Request) { q := searchQueryJSON{} err := json.NewDecoder(r.Body).Decode(&q) @@ -482,11 +504,25 @@ func (clients *clientsContainer) handleSearchClient(w http.ResponseWriter, r *ht return } - data := []map[string]*clientJSON{} + data := make([]map[string]*clientJSON, 0, len(q.Clients)) + params := &client.FindParams{} + for _, c := range q.Clients { idStr := c.ID + err = params.Set(idStr) + if err != nil { + clients.logger.DebugContext( + r.Context(), + "searching client", + "id", idStr, + slogutil.KeyError, err, + ) + + continue + } + data = append(data, map[string]*clientJSON{ - idStr: clients.findClient(idStr), + idStr: clients.findClient(idStr, params), }) } @@ -494,38 +530,37 @@ func (clients *clientsContainer) handleSearchClient(w http.ResponseWriter, r *ht } // findRuntime looks up the IP in runtime and temporary storages, like -// /etc/hosts tables, DHCP leases, or blocklists. cj is guaranteed to be -// non-nil. -func (clients *clientsContainer) findRuntime(ip netip.Addr, idStr string) (cj *clientJSON) { +// /etc/hosts tables, DHCP leases, or blocklists. params must not be nil. cj +// is guaranteed to be non-nil. +func (clients *clientsContainer) findRuntime( + idStr string, + params *client.FindParams, +) (cj *clientJSON) { + var host string + whois := &whois.Info{} + + ip := params.RemoteIP rc := clients.storage.ClientRuntime(ip) - if rc == nil { - // It is still possible that the IP used to be in the runtime clients - // list, but then the server was reloaded. So, check the DNS server's - // blocked IP list. - // - // See https://github.com/AdguardTeam/AdGuardHome/issues/2428. - disallowed, rule := clients.clientChecker.IsBlockedClient(ip, idStr) - cj = &clientJSON{ - IDs: []string{idStr}, - Disallowed: &disallowed, - DisallowedRule: &rule, - WHOIS: &whois.Info{}, - } - - return cj + if rc != nil { + _, host = rc.Info() + whois = whoisOrEmpty(rc) } - _, host := rc.Info() - cj = &clientJSON{ - Name: host, - IDs: []string{idStr}, - WHOIS: whoisOrEmpty(rc), + // Check the DNS server's blocked IP list regardless of whether a runtime + // client was found or not. This is because it's still possible that the + // runtime client associated with the IP address was stored previously, but + // then the server was reloaded. + // + // See https://github.com/AdguardTeam/AdGuardHome/issues/2428. + disallowed, rule := clients.clientChecker.IsBlockedClient(ip, string(params.ClientID)) + + return &clientJSON{ + Name: host, + IDs: []string{idStr}, + WHOIS: whois, + Disallowed: &disallowed, + DisallowedRule: &rule, } - - disallowed, rule := clients.clientChecker.IsBlockedClient(ip, idStr) - cj.Disallowed, cj.DisallowedRule = &disallowed, &rule - - return cj } // RegisterClientsHandlers registers HTTP handlers diff --git a/internal/home/clientshttp_internal_test.go b/internal/home/clientshttp_internal_test.go index c1c495f2..01197983 100644 --- a/internal/home/clientshttp_internal_test.go +++ b/internal/home/clientshttp_internal_test.go @@ -153,7 +153,7 @@ func TestClientsContainer_HandleAddClient(t *testing.T) { clientTwo := newPersistentClientWithIDs(t, "client2", []string{testClientIP2}) clientEmptyID := newPersistentClient("empty_client_id") - clientEmptyID.ClientIDs = []string{""} + clientEmptyID.ClientIDs = []client.ClientID{""} testCases := []struct { name string @@ -278,7 +278,7 @@ func TestClientsContainer_HandleUpdateClient(t *testing.T) { clientModified := newPersistentClientWithIDs(t, "client2", []string{testClientIP2}) clientEmptyID := newPersistentClient("empty_client_id") - clientEmptyID.ClientIDs = []string{""} + clientEmptyID.ClientIDs = []client.ClientID{""} testCases := []struct { name string diff --git a/internal/home/mobileconfig.go b/internal/home/mobileconfig.go index f3c82278..1f4c3955 100644 --- a/internal/home/mobileconfig.go +++ b/internal/home/mobileconfig.go @@ -8,7 +8,7 @@ import ( "net/url" "path" - "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" + "github.com/AdguardTeam/AdGuardHome/internal/client" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/httphdr" "github.com/AdguardTeam/golibs/log" @@ -151,7 +151,7 @@ func handleMobileConfig(w http.ResponseWriter, r *http.Request, dnsp string) { clientID := q.Get("client_id") if clientID != "" { - err = dnsforward.ValidateClientID(clientID) + err = client.ValidateClientID(clientID) if err != nil { respondJSONError(w, http.StatusBadRequest, err.Error()) diff --git a/openapi/openapi.yaml b/openapi/openapi.yaml index d6c47ce2..77315d41 100644 --- a/openapi/openapi.yaml +++ b/openapi/openapi.yaml @@ -980,7 +980,8 @@ - 'clients' 'operationId': 'clientsSearch' 'summary': > - Get information about clients by their IP addresses, CIDRs, MAC addresses, or ClientIDs. + Retrieve information about clients by performing an exact match search + using IP addresses, CIDRs, MAC addresses, or ClientIDs. 'requestBody': 'content': 'application/json':