Pull request #2323: AGDNS-2598-clients-search
Merge in DNS/adguard-home from AGDNS-2598-clients-search to master Squashed commit of the following: commit 9df3c19acad16203ccaa7752902bce8bc835c8fb Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Mon Dec 16 19:11:43 2024 +0300 home: imp code commit 7bf8f0a516b57fab6c19c24e4a156c87d9c4d03f Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Mon Dec 16 18:34:06 2024 +0300 all: imp code commit 2dd1c941232ceeaef4c506717096a8b8e8555e6e Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Fri Dec 13 17:35:11 2024 +0300 all: clients search
This commit is contained in:
@@ -424,6 +424,8 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht
|
||||
}
|
||||
|
||||
// handleFindClient is the handler for GET /control/clients/find HTTP API.
|
||||
//
|
||||
// Deprecated: Remove it when migration to the new API is over.
|
||||
func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http.Request) {
|
||||
q := r.URL.Query()
|
||||
data := []map[string]*clientJSON{}
|
||||
@@ -433,19 +435,58 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
|
||||
break
|
||||
}
|
||||
|
||||
ip, _ := netip.ParseAddr(idStr)
|
||||
c, ok := clients.storage.Find(idStr)
|
||||
var cj *clientJSON
|
||||
if !ok {
|
||||
cj = clients.findRuntime(ip, idStr)
|
||||
} else {
|
||||
cj = clientToJSON(c)
|
||||
disallowed, rule := clients.clientChecker.IsBlockedClient(ip, idStr)
|
||||
cj.Disallowed, cj.DisallowedRule = &disallowed, &rule
|
||||
}
|
||||
|
||||
data = append(data, map[string]*clientJSON{
|
||||
idStr: cj,
|
||||
idStr: clients.findClient(idStr),
|
||||
})
|
||||
}
|
||||
|
||||
aghhttp.WriteJSONResponseOK(w, r, data)
|
||||
}
|
||||
|
||||
// findClient returns available information about a client by idStr from the
|
||||
// client's storage or access settings. cj is guaranteed to be non-nil.
|
||||
func (clients *clientsContainer) findClient(idStr string) (cj *clientJSON) {
|
||||
ip, _ := netip.ParseAddr(idStr)
|
||||
c, ok := clients.storage.Find(idStr)
|
||||
if !ok {
|
||||
return clients.findRuntime(ip, idStr)
|
||||
}
|
||||
|
||||
cj = clientToJSON(c)
|
||||
disallowed, rule := clients.clientChecker.IsBlockedClient(ip, idStr)
|
||||
cj.Disallowed, cj.DisallowedRule = &disallowed, &rule
|
||||
|
||||
return cj
|
||||
}
|
||||
|
||||
// searchQueryJSON is a request to the POST /control/clients/search HTTP API.
|
||||
//
|
||||
// TODO(s.chzhen): Add UIDs.
|
||||
type searchQueryJSON struct {
|
||||
Clients []searchClientJSON `json:"clients"`
|
||||
}
|
||||
|
||||
// searchClientJSON is a part of [searchQueryJSON] that contains a string
|
||||
// representation of the client's IP address, CIDR, MAC address, or ClientID.
|
||||
type searchClientJSON struct {
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
// handleSearchClient is the handler for the POST /control/clients/search HTTP API.
|
||||
func (clients *clientsContainer) handleSearchClient(w http.ResponseWriter, r *http.Request) {
|
||||
q := searchQueryJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&q)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "failed to process request body: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
data := []map[string]*clientJSON{}
|
||||
for _, c := range q.Clients {
|
||||
idStr := c.ID
|
||||
data = append(data, map[string]*clientJSON{
|
||||
idStr: clients.findClient(idStr),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -493,5 +534,8 @@ func (clients *clientsContainer) registerWebHandlers() {
|
||||
httpRegister(http.MethodPost, "/control/clients/add", clients.handleAddClient)
|
||||
httpRegister(http.MethodPost, "/control/clients/delete", clients.handleDelClient)
|
||||
httpRegister(http.MethodPost, "/control/clients/update", clients.handleUpdateClient)
|
||||
httpRegister(http.MethodPost, "/control/clients/search", clients.handleSearchClient)
|
||||
|
||||
// Deprecated handler.
|
||||
httpRegister(http.MethodGet, "/control/clients/find", clients.handleFindClient)
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -408,3 +409,145 @@ func TestClientsContainer_HandleFindClient(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientsContainer_HandleSearchClient(t *testing.T) {
|
||||
var (
|
||||
runtimeCli = "runtime_client1"
|
||||
|
||||
runtimeCliIP = "3.3.3.3"
|
||||
blockedCliIP = "4.4.4.4"
|
||||
nonExistentCliIP = "5.5.5.5"
|
||||
|
||||
allowed = false
|
||||
dissallowed = true
|
||||
|
||||
emptyRule = ""
|
||||
disallowedRule = "disallowed_rule"
|
||||
)
|
||||
|
||||
clients := newClientsContainer(t)
|
||||
clients.clientChecker = &testBlockedClientChecker{
|
||||
onIsBlockedClient: func(ip netip.Addr, _ string) (ok bool, rule string) {
|
||||
if ip == netip.MustParseAddr(blockedCliIP) {
|
||||
return true, disallowedRule
|
||||
}
|
||||
|
||||
return false, emptyRule
|
||||
},
|
||||
}
|
||||
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
|
||||
clientOne := newPersistentClientWithIDs(t, "client1", []string{testClientIP1})
|
||||
err := clients.storage.Add(ctx, clientOne)
|
||||
require.NoError(t, err)
|
||||
|
||||
clientTwo := newPersistentClientWithIDs(t, "client2", []string{testClientIP2})
|
||||
err = clients.storage.Add(ctx, clientTwo)
|
||||
require.NoError(t, err)
|
||||
|
||||
assertPersistentClients(t, clients, []*client.Persistent{clientOne, clientTwo})
|
||||
|
||||
clients.UpdateAddress(ctx, netip.MustParseAddr(runtimeCliIP), runtimeCli, nil)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
query *searchQueryJSON
|
||||
wantPersistent []*client.Persistent
|
||||
wantRuntime *clientJSON
|
||||
}{{
|
||||
name: "single",
|
||||
query: &searchQueryJSON{
|
||||
Clients: []searchClientJSON{{
|
||||
ID: testClientIP1,
|
||||
}},
|
||||
},
|
||||
wantPersistent: []*client.Persistent{clientOne},
|
||||
}, {
|
||||
name: "multiple",
|
||||
query: &searchQueryJSON{
|
||||
Clients: []searchClientJSON{{
|
||||
ID: testClientIP1,
|
||||
}, {
|
||||
ID: testClientIP2,
|
||||
}},
|
||||
},
|
||||
wantPersistent: []*client.Persistent{clientOne, clientTwo},
|
||||
}, {
|
||||
name: "runtime",
|
||||
query: &searchQueryJSON{
|
||||
Clients: []searchClientJSON{{
|
||||
ID: runtimeCliIP,
|
||||
}},
|
||||
},
|
||||
wantRuntime: &clientJSON{
|
||||
Name: runtimeCli,
|
||||
IDs: []string{runtimeCliIP},
|
||||
Disallowed: &allowed,
|
||||
DisallowedRule: &emptyRule,
|
||||
WHOIS: &whois.Info{},
|
||||
},
|
||||
}, {
|
||||
name: "blocked_access",
|
||||
query: &searchQueryJSON{
|
||||
Clients: []searchClientJSON{{
|
||||
ID: blockedCliIP,
|
||||
}},
|
||||
},
|
||||
wantRuntime: &clientJSON{
|
||||
IDs: []string{blockedCliIP},
|
||||
Disallowed: &dissallowed,
|
||||
DisallowedRule: &disallowedRule,
|
||||
WHOIS: &whois.Info{},
|
||||
},
|
||||
}, {
|
||||
name: "non_existing_client",
|
||||
query: &searchQueryJSON{
|
||||
Clients: []searchClientJSON{{
|
||||
ID: nonExistentCliIP,
|
||||
}},
|
||||
},
|
||||
wantRuntime: &clientJSON{
|
||||
IDs: []string{nonExistentCliIP},
|
||||
Disallowed: &allowed,
|
||||
DisallowedRule: &emptyRule,
|
||||
WHOIS: &whois.Info{},
|
||||
},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var body []byte
|
||||
body, err = json.Marshal(tc.query)
|
||||
require.NoError(t, err)
|
||||
|
||||
var r *http.Request
|
||||
r, err = http.NewRequest(http.MethodPost, "", bytes.NewReader(body))
|
||||
require.NoError(t, err)
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
clients.handleSearchClient(rw, r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusOK, rw.Code)
|
||||
|
||||
body, err = io.ReadAll(rw.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
clientData := []map[string]*clientJSON{}
|
||||
err = json.Unmarshal(body, &clientData)
|
||||
require.NoError(t, err)
|
||||
|
||||
if tc.wantPersistent != nil {
|
||||
assertPersistentClientsData(t, clients, clientData, tc.wantPersistent)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
require.Len(t, clientData, 1)
|
||||
require.Len(t, clientData[0], 1)
|
||||
|
||||
rc := clientData[0][tc.wantRuntime.IDs[0]]
|
||||
assert.Equal(t, tc.wantRuntime, rc)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user