Pull request 1803: 5685-fix-safe-search
Updates #5685. Squashed commit of the following: commit 5312147abfa0914c896acbf1e88f8c8f1af90f2b Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Thu Apr 6 14:09:44 2023 +0300 safesearch: imp tests, logs commit 298b5d24ce292c5f83ebe33d1e92329e4b3c1acc Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Wed Apr 5 20:36:16 2023 +0300 safesearch: fix filters, logging commit 63d6ca5d694d45705473f2f0410e9e0b49cf7346 Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Wed Apr 5 20:24:47 2023 +0300 all: dry; fix logs commit fdbf2f364fd0484b47b3161bf6f4581856fdf47b Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Wed Apr 5 20:01:08 2023 +0300 all: fix safe search update
This commit is contained in:
@@ -3,8 +3,10 @@ package home
|
||||
import (
|
||||
"encoding"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
)
|
||||
|
||||
@@ -45,6 +47,23 @@ func (c *Client) closeUpstreams() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// setSafeSearch initializes and sets the safe search filter for this client.
|
||||
func (c *Client) setSafeSearch(
|
||||
conf filtering.SafeSearchConfig,
|
||||
cacheSize uint,
|
||||
cacheTTL time.Duration,
|
||||
) (err error) {
|
||||
ss, err := safesearch.NewDefault(conf, fmt.Sprintf("client %q", c.Name), cacheSize, cacheTTL)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
c.SafeSearch = ss
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// clientSource represents the source from which the information about the
|
||||
// client has been obtained.
|
||||
type clientSource uint
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
@@ -55,6 +54,14 @@ type clientsContainer struct {
|
||||
// more detail.
|
||||
lock sync.Mutex
|
||||
|
||||
// safeSearchCacheSize is the size of the safe search cache to use for
|
||||
// persistent clients.
|
||||
safeSearchCacheSize uint
|
||||
|
||||
// safeSearchCacheTTL is the TTL of the safe search cache to use for
|
||||
// persistent clients.
|
||||
safeSearchCacheTTL time.Duration
|
||||
|
||||
// testing is a flag that disables some features for internal tests.
|
||||
//
|
||||
// TODO(a.garipov): Awful. Remove.
|
||||
@@ -74,6 +81,7 @@ func (clients *clientsContainer) Init(
|
||||
if clients.list != nil {
|
||||
log.Fatal("clients.list != nil")
|
||||
}
|
||||
|
||||
clients.list = make(map[string]*Client)
|
||||
clients.idIndex = make(map[string]*Client)
|
||||
clients.ipToRC = map[netip.Addr]*RuntimeClient{}
|
||||
@@ -85,6 +93,9 @@ func (clients *clientsContainer) Init(
|
||||
clients.arpdb = arpdb
|
||||
clients.addFromConfig(objects, filteringConf)
|
||||
|
||||
clients.safeSearchCacheSize = filteringConf.SafeSearchCacheSize
|
||||
clients.safeSearchCacheTTL = time.Minute * time.Duration(filteringConf.CacheTime)
|
||||
|
||||
if clients.testing {
|
||||
return
|
||||
}
|
||||
@@ -171,18 +182,16 @@ func (clients *clientsContainer) addFromConfig(objects []*clientObject, filterin
|
||||
if o.SafeSearchConf.Enabled {
|
||||
o.SafeSearchConf.CustomResolver = safeSearchResolver{}
|
||||
|
||||
ss, err := safesearch.NewDefaultSafeSearch(
|
||||
err := cli.setSafeSearch(
|
||||
o.SafeSearchConf,
|
||||
filteringConf.SafeSearchCacheSize,
|
||||
time.Minute*time.Duration(filteringConf.CacheTime),
|
||||
)
|
||||
if err != nil {
|
||||
log.Error("clients: init client safesearch %s: %s", cli.Name, err)
|
||||
log.Error("clients: init client safesearch %q: %s", cli.Name, err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
cli.SafeSearch = ss
|
||||
}
|
||||
|
||||
for _, s := range o.BlockedServices {
|
||||
|
||||
@@ -9,17 +9,27 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestClients(t *testing.T) {
|
||||
clients := clientsContainer{}
|
||||
clients.testing = true
|
||||
// newClientsContainer is a helper that creates a new clients container for
|
||||
// tests.
|
||||
func newClientsContainer() (c *clientsContainer) {
|
||||
c = &clientsContainer{
|
||||
testing: true,
|
||||
}
|
||||
|
||||
clients.Init(nil, nil, nil, nil, nil)
|
||||
c.Init(nil, nil, nil, nil, &filtering.Config{})
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func TestClients(t *testing.T) {
|
||||
clients := newClientsContainer()
|
||||
|
||||
t.Run("add_success", func(t *testing.T) {
|
||||
var (
|
||||
@@ -198,10 +208,7 @@ func TestClients(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientsWHOIS(t *testing.T) {
|
||||
clients := clientsContainer{
|
||||
testing: true,
|
||||
}
|
||||
clients.Init(nil, nil, nil, nil, nil)
|
||||
clients := newClientsContainer()
|
||||
whois := &RuntimeClientWHOISInfo{
|
||||
Country: "AU",
|
||||
Orgname: "Example Org",
|
||||
@@ -247,10 +254,7 @@ func TestClientsWHOIS(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientsAddExisting(t *testing.T) {
|
||||
clients := clientsContainer{
|
||||
testing: true,
|
||||
}
|
||||
clients.Init(nil, nil, nil, nil, nil)
|
||||
clients := newClientsContainer()
|
||||
|
||||
t.Run("simple", func(t *testing.T) {
|
||||
ip := netip.MustParseAddr("1.1.1.1")
|
||||
@@ -325,10 +329,7 @@ func TestClientsAddExisting(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientsCustomUpstream(t *testing.T) {
|
||||
clients := clientsContainer{
|
||||
testing: true,
|
||||
}
|
||||
clients.Init(nil, nil, nil, nil, nil)
|
||||
clients := newClientsContainer()
|
||||
|
||||
// Add client with upstreams.
|
||||
ok, err := clients.Add(&Client{
|
||||
|
||||
@@ -49,8 +49,8 @@ type clientJSON struct {
|
||||
type runtimeClientJSON struct {
|
||||
WHOISInfo *RuntimeClientWHOISInfo `json:"whois_info"`
|
||||
|
||||
Name string `json:"name"`
|
||||
IP netip.Addr `json:"ip"`
|
||||
Name string `json:"name"`
|
||||
Source clientSource `json:"source"`
|
||||
}
|
||||
|
||||
@@ -90,14 +90,16 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
|
||||
}
|
||||
|
||||
// jsonToClient converts JSON object to Client object.
|
||||
func jsonToClient(cj clientJSON) (c *Client) {
|
||||
func (clients *clientsContainer) jsonToClient(cj clientJSON) (c *Client, err error) {
|
||||
var safeSearchConf filtering.SafeSearchConfig
|
||||
if cj.SafeSearchConf != nil {
|
||||
safeSearchConf = *cj.SafeSearchConf
|
||||
} else {
|
||||
// TODO(d.kolyshev): Remove after cleaning the deprecated
|
||||
// [clientJSON.SafeSearchEnabled] field.
|
||||
safeSearchConf = filtering.SafeSearchConfig{Enabled: cj.SafeSearchEnabled}
|
||||
safeSearchConf = filtering.SafeSearchConfig{
|
||||
Enabled: cj.SafeSearchEnabled,
|
||||
}
|
||||
|
||||
// Set default service flags for enabled safesearch.
|
||||
if safeSearchConf.Enabled {
|
||||
@@ -110,20 +112,35 @@ func jsonToClient(cj clientJSON) (c *Client) {
|
||||
}
|
||||
}
|
||||
|
||||
return &Client{
|
||||
Name: cj.Name,
|
||||
IDs: cj.IDs,
|
||||
Tags: cj.Tags,
|
||||
c = &Client{
|
||||
safeSearchConf: safeSearchConf,
|
||||
|
||||
Name: cj.Name,
|
||||
|
||||
IDs: cj.IDs,
|
||||
Tags: cj.Tags,
|
||||
BlockedServices: cj.BlockedServices,
|
||||
Upstreams: cj.Upstreams,
|
||||
|
||||
UseOwnSettings: !cj.UseGlobalSettings,
|
||||
FilteringEnabled: cj.FilteringEnabled,
|
||||
ParentalEnabled: cj.ParentalEnabled,
|
||||
SafeBrowsingEnabled: cj.SafeBrowsingEnabled,
|
||||
safeSearchConf: safeSearchConf,
|
||||
UseOwnBlockedServices: !cj.UseGlobalBlockedServices,
|
||||
BlockedServices: cj.BlockedServices,
|
||||
|
||||
Upstreams: cj.Upstreams,
|
||||
}
|
||||
|
||||
if safeSearchConf.Enabled {
|
||||
err = c.setSafeSearch(
|
||||
safeSearchConf,
|
||||
clients.safeSearchCacheSize,
|
||||
clients.safeSearchCacheTTL,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating safesearch for client %q: %w", c.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// clientToJSON converts Client object to JSON.
|
||||
@@ -161,7 +178,13 @@ func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http.
|
||||
return
|
||||
}
|
||||
|
||||
c := jsonToClient(cj)
|
||||
c, err := clients.jsonToClient(cj)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ok, err := clients.Add(c)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
@@ -224,7 +247,13 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht
|
||||
return
|
||||
}
|
||||
|
||||
c := jsonToClient(dj.Data)
|
||||
c, err := clients.jsonToClient(dj.Data)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = clients.Update(dj.Name, c)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
@@ -545,6 +545,8 @@ var _ filtering.Resolver = safeSearchResolver{}
|
||||
|
||||
// LookupIP implements [filtering.Resolver] interface for safeSearchResolver.
|
||||
// It returns the slice of net.IP with IPv4 and IPv6 instances.
|
||||
//
|
||||
// TODO(a.garipov): Support network.
|
||||
func (r safeSearchResolver) LookupIP(_ context.Context, _, host string) (ips []net.IP, err error) {
|
||||
addrs, err := Context.dnsServer.Resolve(host)
|
||||
if err != nil {
|
||||
|
||||
@@ -297,8 +297,9 @@ func setupConfig(opts options) (err error) {
|
||||
config.DNS.DnsfilterConf.HTTPClient = Context.client
|
||||
|
||||
config.DNS.DnsfilterConf.SafeSearchConf.CustomResolver = safeSearchResolver{}
|
||||
config.DNS.DnsfilterConf.SafeSearch, err = safesearch.NewDefaultSafeSearch(
|
||||
config.DNS.DnsfilterConf.SafeSearch, err = safesearch.NewDefault(
|
||||
config.DNS.DnsfilterConf.SafeSearchConf,
|
||||
"default",
|
||||
config.DNS.DnsfilterConf.SafeSearchCacheSize,
|
||||
time.Minute*time.Duration(config.DNS.DnsfilterConf.CacheTime),
|
||||
)
|
||||
@@ -869,8 +870,10 @@ func detectFirstRun() bool {
|
||||
// Connect to a remote server resolving hostname using our own DNS server.
|
||||
//
|
||||
// TODO(e.burkov): This messy logic should be decomposed and clarified.
|
||||
//
|
||||
// TODO(a.garipov): Support network.
|
||||
func customDialContext(ctx context.Context, network, addr string) (conn net.Conn, err error) {
|
||||
log.Tracef("network:%v addr:%v", network, addr)
|
||||
log.Debug("home: customdial: dialing addr %q for network %s", addr, network)
|
||||
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user