diff --git a/internal/home/client.go b/internal/home/client.go index 5e56df19..1aee021e 100644 --- a/internal/home/client.go +++ b/internal/home/client.go @@ -8,6 +8,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch" "github.com/AdguardTeam/dnsproxy/proxy" + "github.com/AdguardTeam/golibs/stringutil" ) // Client contains information about persistent clients. @@ -37,6 +38,19 @@ type Client struct { IgnoreStatistics bool } +// ShallowClone returns a deep copy of the client, except upstreamConfig, +// safeSearchConf, SafeSearch fields, because it's difficult to copy them. +func (c *Client) ShallowClone() (sh *Client) { + clone := *c + + clone.IDs = stringutil.CloneSlice(c.IDs) + clone.Tags = stringutil.CloneSlice(c.Tags) + clone.BlockedServices = stringutil.CloneSlice(c.BlockedServices) + clone.Upstreams = stringutil.CloneSlice(c.Upstreams) + + return &clone +} + // closeUpstreams closes the client-specific upstream config of c if any. func (c *Client) closeUpstreams() (err error) { if c.upstreamConfig != nil { diff --git a/internal/home/clients.go b/internal/home/clients.go index 1ea0247a..d2e4194b 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -378,6 +378,7 @@ func (clients *clientsContainer) clientOrArtificial( }, true } +// Find returns a shallow copy of the client if there is one found. func (clients *clientsContainer) Find(id string) (c *Client, ok bool) { clients.lock.Lock() defer clients.lock.Unlock() @@ -387,20 +388,18 @@ func (clients *clientsContainer) Find(id string) (c *Client, ok bool) { return nil, false } - c.IDs = stringutil.CloneSlice(c.IDs) - c.Tags = stringutil.CloneSlice(c.Tags) - c.BlockedServices = stringutil.CloneSlice(c.BlockedServices) - c.Upstreams = stringutil.CloneSlice(c.Upstreams) - - return c, true + return c.ShallowClone(), true } // shouldCountClient is a wrapper around Find to make it a valid client // information finder for the statistics. If no information about the client // is found, it returns true. func (clients *clientsContainer) shouldCountClient(ids []string) (y bool) { + clients.lock.Lock() + defer clients.lock.Unlock() + for _, id := range ids { - client, ok := clients.Find(id) + client, ok := clients.findLocked(id) if ok { return !client.IgnoreStatistics } @@ -617,6 +616,15 @@ func (clients *clientsContainer) Add(c *Client) (ok bool, err error) { } } + clients.add(c) + + log.Debug("clients: added %q: ID:%q [%d]", c.Name, c.IDs, len(clients.list)) + + return true, nil +} + +// add c to the indexes. clients.lock is expected to be locked. +func (clients *clientsContainer) add(c *Client) { // update Name index clients.list[c.Name] = c @@ -624,10 +632,6 @@ func (clients *clientsContainer) Add(c *Client) (ok bool, err error) { for _, id := range c.IDs { clients.idIndex[id] = c } - - log.Debug("clients: added %q: ID:%q [%d]", c.Name, c.IDs, len(clients.list)) - - return true, nil } // Del removes a client. ok is false if there is no such client. @@ -645,86 +649,53 @@ func (clients *clientsContainer) Del(name string) (ok bool) { log.Error("client container: removing client %s: %s", name, err) } + clients.del(c) + + return true +} + +// del removes c from the indexes. clients.lock is expected to be locked. +func (clients *clientsContainer) del(c *Client) { // update Name index - delete(clients.list, name) + delete(clients.list, c.Name) // update ID index for _, id := range c.IDs { delete(clients.idIndex, id) } - - return true } // Update updates a client by its name. -func (clients *clientsContainer) Update(name string, c *Client) (err error) { +func (clients *clientsContainer) Update(prev, c *Client) (err error) { err = clients.check(c) if err != nil { + // Don't wrap the error since it's informative enough as is. return err } clients.lock.Lock() defer clients.lock.Unlock() - prev, ok := clients.list[name] - if !ok { - return errors.Error("client not found") - } - - // First, check the name index. + // Check the name index. if prev.Name != c.Name { - _, ok = clients.list[c.Name] + _, ok := clients.list[c.Name] if ok { return errors.Error("client already exists") } } - // Second, update the ID index. - err = clients.updateIDIndex(prev, c.IDs) - if err != nil { - // Don't wrap the error, because it's informative enough as is. - return err - } - - // Update name index. - if prev.Name != c.Name { - delete(clients.list, prev.Name) - clients.list[c.Name] = prev - } - - // Update upstreams cache. - err = c.closeUpstreams() - if err != nil { - return err - } - - *prev = *c - - return nil -} - -// updateIDIndex updates the ID index data for cli using the information from -// newIDs. -func (clients *clientsContainer) updateIDIndex(cli *Client, newIDs []string) (err error) { - if slices.Equal(cli.IDs, newIDs) { - return nil - } - - for _, id := range newIDs { - existing, ok := clients.idIndex[id] - if ok && existing != cli { - return fmt.Errorf("id %q is used by client with name %q", id, existing.Name) + // Check the ID index. + if !slices.Equal(prev.IDs, c.IDs) { + for _, id := range c.IDs { + existing, ok := clients.idIndex[id] + if ok && existing != prev { + return fmt.Errorf("id %q is used by client with name %q", id, existing.Name) + } } } - // Update the IDs in the index. - for _, id := range cli.IDs { - delete(clients.idIndex, id) - } - - for _, id := range newIDs { - clients.idIndex[id] = cli - } + clients.del(prev) + clients.add(c) return nil } diff --git a/internal/home/clients_test.go b/internal/home/clients_test.go index ebf879ef..8361528a 100644 --- a/internal/home/clients_test.go +++ b/internal/home/clients_test.go @@ -98,22 +98,8 @@ func TestClients(t *testing.T) { assert.False(t, ok) }) - t.Run("update_fail_name", func(t *testing.T) { - err := clients.Update("client3", &Client{ - IDs: []string{"1.2.3.0"}, - Name: "client3", - }) - require.Error(t, err) - - err = clients.Update("client3", &Client{ - IDs: []string{"1.2.3.0"}, - Name: "client2", - }) - assert.Error(t, err) - }) - t.Run("update_fail_ip", func(t *testing.T) { - err := clients.Update("client1", &Client{ + err := clients.Update(&Client{Name: "client1"}, &Client{ IDs: []string{"2.2.2.2"}, Name: "client1", }) @@ -129,7 +115,10 @@ func TestClients(t *testing.T) { cliNewIP = netip.MustParseAddr(cliNew) ) - err := clients.Update("client1", &Client{ + prev, ok := clients.list["client1"] + require.True(t, ok) + + err := clients.Update(prev, &Client{ IDs: []string{cliNew}, Name: "client1", }) @@ -138,7 +127,10 @@ func TestClients(t *testing.T) { assert.Equal(t, clients.clientSource(cliOldIP), ClientSourceNone) assert.Equal(t, clients.clientSource(cliNewIP), ClientSourcePersistent) - err = clients.Update("client1", &Client{ + prev, ok = clients.list["client1"] + require.True(t, ok) + + err = clients.Update(prev, &Client{ IDs: []string{cliNew}, Name: "client1-renamed", UseOwnSettings: true, diff --git a/internal/home/clientshttp.go b/internal/home/clientshttp.go index 82a16713..6425f941 100644 --- a/internal/home/clientshttp.go +++ b/internal/home/clientshttp.go @@ -289,7 +289,7 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht return } - err = clients.Update(dj.Name, c) + err = clients.Update(prev, c) if err != nil { aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)