Compare commits

..

9 Commits

Author SHA1 Message Date
Stanislav Chzhen
b34a8c169c filtering: imp tests 2023-05-15 10:33:52 +03:00
Stanislav Chzhen
1f2ba07eae filtering: wildcard interference 2023-05-12 12:25:33 +03:00
Dimitry Kolyshev
c77b2a0ce5 Pull request: home: imp code
Merge in DNS/adguard-home from home-imp-code to master

Squashed commit of the following:

commit 459297e189c55393bf0340dd51ec9608d3475e55
Author: Dimitry Kolyshev <dkolyshev@adguard.com>
Date:   Wed May 10 11:42:34 2023 +0300

    home: imp code

commit ab38e1e80fed7b24fe57d4afdc57b70608f65d73
Author: Dimitry Kolyshev <dkolyshev@adguard.com>
Date:   Wed May 10 11:01:23 2023 +0300

    all: lint script

commit 7df68b128bf32172ef2e3bf7116f4f72a97baa2b
Merge: bcb482714 db52f7a3a
Author: Dimitry Kolyshev <dkolyshev@adguard.com>
Date:   Wed May 10 10:59:40 2023 +0300

    Merge remote-tracking branch 'origin/master' into home-imp-code

commit bcb482714780da882e69c261be08511ea4f36f3b
Author: Dimitry Kolyshev <dkolyshev@adguard.com>
Date:   Thu May 4 13:48:27 2023 +0300

    all: lint script

commit 1c017f27715202ec1f40881f069a96f11f9822e8
Author: Dimitry Kolyshev <dkolyshev@adguard.com>
Date:   Thu May 4 13:45:25 2023 +0300

    all: lint script

commit ee3d427a7d6ee7e377e67c5eb99eebc7fb1e6acc
Author: Dimitry Kolyshev <dkolyshev@adguard.com>
Date:   Thu May 4 13:44:53 2023 +0300

    home: imp code

commit bc50430469123415216e60e178bd8e30fc229300
Author: Dimitry Kolyshev <dkolyshev@adguard.com>
Date:   Thu May 4 13:12:10 2023 +0300

    home: imp code

commit fc07e416aeab2612e68cf0e3f933aaed95931115
Author: Dimitry Kolyshev <dkolyshev@adguard.com>
Date:   Thu May 4 11:42:32 2023 +0300

    aghos: service precheck

commit a68480fd9c4cd6f3c89210bee6917c53074f7a82
Author: Dimitry Kolyshev <dkolyshev@adguard.com>
Date:   Thu May 4 11:07:05 2023 +0300

    home: imp code

commit 61b743a340ac1564c48212452c7a9acd1808d352
Author: Dimitry Kolyshev <dkolyshev@adguard.com>
Date:   Wed May 3 17:17:21 2023 +0300

    all: lint script

commit c6fe620510c4af5b65456e90cb3424831334e004
Author: Dimitry Kolyshev <dkolyshev@adguard.com>
Date:   Wed May 3 17:16:37 2023 +0300

    home: imp code

commit 4b2fb47ea9c932054ccc72b1fd1d11793c93e39c
Author: Dimitry Kolyshev <dkolyshev@adguard.com>
Date:   Wed May 3 16:55:44 2023 +0300

    home: imp code

commit 63df3e2ab58482920a074cfd5f4188e49a0f8448
Author: Dimitry Kolyshev <dkolyshev@adguard.com>
Date:   Wed May 3 16:25:38 2023 +0300

    home: imp code

commit c7f1502f976482c2891e0c64426218b549585e83
Author: Dimitry Kolyshev <dkolyshev@adguard.com>
Date:   Wed May 3 15:54:30 2023 +0300

    home: imp code

commit c64cdaf1c82495bb70d9cdcaf7be9eeee9a7c773
Author: Dimitry Kolyshev <dkolyshev@adguard.com>
Date:   Wed May 3 14:35:04 2023 +0300

    home: imp code

commit a50436e040b3a064ef51d5f936b879fe8de72d41
Author: Dimitry Kolyshev <dkolyshev@adguard.com>
Date:   Wed May 3 14:24:02 2023 +0300

    home: imp code

commit 2b66464f472df732ea27cbbe5ac5c673a13bc14b
Author: Dimitry Kolyshev <dkolyshev@adguard.com>
Date:   Wed May 3 14:11:53 2023 +0300

    home: imp code

commit 713ce2963c210887faa0a06e41e01e4ebbf96894
Author: Dimitry Kolyshev <dkolyshev@adguard.com>
Date:   Wed May 3 14:10:54 2023 +0300

    home: imp code
2023-05-10 16:30:03 +03:00
Stanislav Chzhen
db52f7a3ac Pull request 1841: AG-21462-safebrowsing-parental-http-tests
Merge in DNS/adguard-home from AG-21462-safebrowsing-parental-http-tests to master

Squashed commit of the following:

commit 22a83ebad08a27939a443530137a7c195f512ee4
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Wed May 3 17:39:46 2023 +0300

    filtering: fix test

commit c3ca8b4987245cdd552f6f09759804e716bcae80
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Wed May 3 16:43:35 2023 +0300

    filtering: imp tests even more

commit 7643bfae350373b5b6dfb61b64e57da66c6ab952
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Wed May 3 16:17:42 2023 +0300

    filtering: imp tests more

commit 399c05ee4d479a727b61378b7a07158a568d0181
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Wed May 3 14:45:41 2023 +0300

    filtering: imp tests

commit f361df39e784ec9c5191666736a6c64b332928e8
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Tue May 2 12:49:26 2023 +0300

    filtering: add tests
2023-05-03 19:52:06 +03:00
Stanislav Chzhen
381f2f651d Pull request 1837: AG-21462-imp-safebrowsing-parental
Merge in DNS/adguard-home from AG-21462-imp-safebrowsing-parental to master

Squashed commit of the following:

commit 85016d4f1105e21a407efade0bd45b8362808061
Merge: 0e61edade 620b51e3e
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Thu Apr 27 16:36:30 2023 +0300

    Merge branch 'master' into AG-21462-imp-safebrowsing-parental

commit 0e61edadeff34f6305e941c1db94575c82f238d9
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Thu Apr 27 14:51:37 2023 +0300

    filtering: imp tests

commit 994255514cc0f67dfe33d5a0892432e8924d1e36
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Thu Apr 27 11:13:19 2023 +0300

    filtering: fix typo

commit 96d1069573171538333330d6af94ef0f4208a9c4
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Thu Apr 27 11:00:18 2023 +0300

    filtering: imp code more

commit c2a5620b04c4a529eea69983f1520cd2bc82ea9b
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Wed Apr 26 19:13:26 2023 +0300

    all: add todo

commit e5dcc2e9701f8bccfde6ef8c01a4a2e7eb31599e
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Wed Apr 26 14:36:08 2023 +0300

    all: imp code more

commit b6e734ccbeda82669023f6578481260b7c1f7161
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Tue Apr 25 15:01:56 2023 +0300

    filtering: imp code

commit 530648dadf836c1a4bd9917e0d3b47256fa8ff52
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Mon Apr 24 20:06:36 2023 +0300

    all: imp code

commit 49fa6e587052a40bb431fea457701ee860493527
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Mon Apr 24 14:57:19 2023 +0300

    all: rm safe browsing ctx

commit bbcb66cb03e18fa875e3c33cf16295892739e507
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Fri Apr 21 17:54:18 2023 +0300

    filtering: add cache item

commit cb7c9fffe8c4ff5e7a21ca912c223c799f61385f
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Thu Apr 20 18:43:02 2023 +0300

    filtering: fix hashes

commit 153fec46270212af03f3631bfb42c5d680c4e142
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Thu Apr 20 16:15:15 2023 +0300

    filtering: add test cases

commit 09372f92bbb1fc082f1b1283594ee589100209c5
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Thu Apr 20 15:38:05 2023 +0300

    filtering: imp code

commit 466bc26d524ea6d1c3efb33692a7785d39e491ca
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Wed Apr 19 18:38:40 2023 +0300

    filtering: add tests

commit 24365ecf8c60512fdac65833ee603c80864ae018
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Wed Apr 19 11:38:57 2023 +0300

    filtering: add hashprefix
2023-04-27 16:39:35 +03:00
Eugene Burkov
620b51e3ea Pull request 1840: 5752-unspec-ipv6
Merge in DNS/adguard-home from 5752-unspec-ipv6 to master

Closes #5752.

Squashed commit of the following:

commit 654b808d17c6d2374b6be919515113b361fc5ff7
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Fri Apr 21 18:11:34 2023 +0300

    home: imp docs

commit 28b4c36df790f1eaa05b11a1f0a7b986894d37dc
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Fri Apr 21 16:50:16 2023 +0300

    all: fix empty bind host
2023-04-21 18:57:53 +03:00
Ainar Garipov
757ddb06f8 Pull request 1839: 5716-write-json
Updates #5716.

Squashed commit of the following:

commit 8cf7c4f404fffb646c9df8643924eb8dc1d8f49d
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Wed Apr 19 18:09:35 2023 +0300

    aghhttp: write json properly
2023-04-19 18:15:17 +03:00
Ainar Garipov
d6043e2352 Pull request 1838: 5716-content-type
Updates #5716.

Squashed commit of the following:

commit 584e6771c82b92857e3c13232e942cad5c183682
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Wed Apr 19 14:54:43 2023 +0300

    all: fix content types
2023-04-19 14:58:56 +03:00
Eugene Burkov
aeec9a86e2 Pull request 1836: 5714-handle-zeroes-health
Merge in DNS/adguard-home from 5714-handle-zeroes-health to master

Updates #5714.

Squashed commit of the following:

commit 24faab01faf723e313050294b3a35e249c3cd3e3
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Wed Apr 19 13:10:24 2023 +0300

    docker: add curly brackets

commit 67365d02856200685551a79aa23cf59df4a3484b
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Tue Apr 18 20:16:12 2023 +0300

    docker: imp zeroes check
2023-04-19 13:48:59 +03:00
30 changed files with 1636 additions and 978 deletions

View File

@@ -25,9 +25,18 @@ NOTE: Add new changes BELOW THIS COMMENT.
### Fixed ### Fixed
- Unquoted IPv6 bind hosts with trailing colons erroneously considered
unspecified addresses are now properly validated ([#5752]).
**NOTE:** the Docker healthcheck script now also doesn't interpret the `""`
value as unspecified address.
- Incorrect `Content-Type` header value in `POST /control/version.json` and `GET
/control/dhcp/interfaces` HTTP APIs ([#5716]).
- Provided bootstrap servers are now used to resolve the hostnames of plain - Provided bootstrap servers are now used to resolve the hostnames of plain
UDP/TCP upstream servers. UDP/TCP upstream servers.
[#5716]: https://github.com/AdguardTeam/AdGuardHome/issues/5716
<!-- <!--
NOTE: Add new changes ABOVE THIS COMMENT. NOTE: Add new changes ABOVE THIS COMMENT.
--> -->
@@ -60,6 +69,7 @@ See also the [v0.107.29 GitHub milestone][ms-v0.107.29].
[#5712]: https://github.com/AdguardTeam/AdGuardHome/issues/5712 [#5712]: https://github.com/AdguardTeam/AdGuardHome/issues/5712
[#5721]: https://github.com/AdguardTeam/AdGuardHome/issues/5721 [#5721]: https://github.com/AdguardTeam/AdGuardHome/issues/5721
[#5725]: https://github.com/AdguardTeam/AdGuardHome/issues/5725 [#5725]: https://github.com/AdguardTeam/AdGuardHome/issues/5725
[#5752]: https://github.com/AdguardTeam/AdGuardHome/issues/5752
[ms-v0.107.29]: https://github.com/AdguardTeam/AdGuardHome/milestone/65?closed=1 [ms-v0.107.29]: https://github.com/AdguardTeam/AdGuardHome/milestone/65?closed=1

View File

@@ -7,11 +7,10 @@
addrs[$2] = true addrs[$2] = true
prev_line = FNR prev_line = FNR
if ($2 == "0.0.0.0" || $2 == "::") { if ($2 == "0.0.0.0" || $2 == "'::'") {
delete addrs
addrs["localhost"] = true
# Drop all the other addresses. # Drop all the other addresses.
delete addrs
addrs[""] = true
prev_line = -1 prev_line = -1
} }
} }

View File

@@ -61,8 +61,11 @@ then
error_exit "no DNS bindings could be retrieved from $filename" error_exit "no DNS bindings could be retrieved from $filename"
fi fi
first_dns="$( echo "$dns_hosts" | head -n 1 )"
readonly first_dns
# TODO(e.burkov): Deal with 0 port. # TODO(e.burkov): Deal with 0 port.
case "$( echo "$dns_hosts" | head -n 1 )" case "$first_dns"
in in
(*':0') (*':0')
error_exit '0 in DNS port is not supported by healthcheck' error_exit '0 in DNS port is not supported by healthcheck'
@@ -82,8 +85,23 @@ esac
# See https://github.com/AdguardTeam/AdGuardHome/issues/5642. # See https://github.com/AdguardTeam/AdGuardHome/issues/5642.
wget --no-check-certificate "$web_url" -O /dev/null -q || exit 1 wget --no-check-certificate "$web_url" -O /dev/null -q || exit 1
echo "$dns_hosts" | while read -r host test_fqdn="healthcheck.adguardhome.test."
do readonly test_fqdn
nslookup -type=a healthcheck.adguardhome.test. "$host" > /dev/null ||\
# The awk script currently returns only port prefixed with colon in case of
# unspecified address.
case "$first_dns"
in
(':'*)
nslookup -type=a "$test_fqdn" "127.0.0.1${first_dns}" > /dev/null ||\
nslookup -type=a "$test_fqdn" "[::1]${first_dns}" > /dev/null ||\
error_exit "nslookup failed for $host" error_exit "nslookup failed for $host"
done ;;
(*)
echo "$dns_hosts" | while read -r host
do
nslookup -type=a "$test_fqdn" "$host" > /dev/null ||\
error_exit "nslookup failed for $host"
done
;;
esac

View File

@@ -72,8 +72,8 @@ func WriteJSONResponse(w http.ResponseWriter, r *http.Request, resp any) (err er
// WriteJSONResponseCode is like [WriteJSONResponse] but adds the ability to // WriteJSONResponseCode is like [WriteJSONResponse] but adds the ability to
// redefine the status code. // redefine the status code.
func WriteJSONResponseCode(w http.ResponseWriter, r *http.Request, code int, resp any) (err error) { func WriteJSONResponseCode(w http.ResponseWriter, r *http.Request, code int, resp any) (err error) {
w.WriteHeader(code)
w.Header().Set(httphdr.ContentType, HdrValApplicationJSON) w.Header().Set(httphdr.ContentType, HdrValApplicationJSON)
w.WriteHeader(code)
err = json.NewEncoder(w).Encode(resp) err = json.NewEncoder(w).Encode(resp)
if err != nil { if err != nil {
Error(r, w, http.StatusInternalServerError, "encoding resp: %s", err) Error(r, w, http.StatusInternalServerError, "encoding resp: %s", err)

View File

@@ -0,0 +1,6 @@
package aghos
// PreCheckActionStart performs the service start action pre-check.
func PreCheckActionStart() (err error) {
return preCheckActionStart()
}

View File

@@ -0,0 +1,32 @@
//go:build darwin
package aghos
import (
"fmt"
"os"
"path/filepath"
"strings"
"github.com/AdguardTeam/golibs/log"
)
// preCheckActionStart performs the service start action pre-check. It warns
// user that the service should be installed into Applications directory.
func preCheckActionStart() (err error) {
exe, err := os.Executable()
if err != nil {
return fmt.Errorf("getting executable path: %v", err)
}
exe, err = filepath.EvalSymlinks(exe)
if err != nil {
return fmt.Errorf("evaluating executable symlinks: %v", err)
}
if !strings.HasPrefix(exe, "/Applications/") {
log.Info("warning: service must be started from within the /Applications directory")
}
return err
}

View File

@@ -0,0 +1,8 @@
//go:build !darwin
package aghos
// preCheckActionStart performs the service start action pre-check.
func preCheckActionStart() (err error) {
return nil
}

View File

@@ -350,8 +350,10 @@ type netInterfaceJSON struct {
Addrs6 []netip.Addr `json:"ipv6_addresses"` Addrs6 []netip.Addr `json:"ipv6_addresses"`
} }
// handleDHCPInterfaces is the handler for the GET /control/dhcp/interfaces HTTP
// API.
func (s *server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) { func (s *server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
response := map[string]netInterfaceJSON{} resp := map[string]netInterfaceJSON{}
ifaces, err := net.Interfaces() ifaces, err := net.Interfaces()
if err != nil { if err != nil {
@@ -424,20 +426,11 @@ func (s *server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
} }
if len(jsonIface.Addrs4)+len(jsonIface.Addrs6) != 0 { if len(jsonIface.Addrs4)+len(jsonIface.Addrs6) != 0 {
jsonIface.GatewayIP = aghnet.GatewayIP(iface.Name) jsonIface.GatewayIP = aghnet.GatewayIP(iface.Name)
response[iface.Name] = jsonIface resp[iface.Name] = jsonIface
} }
} }
err = json.NewEncoder(w).Encode(response) _ = aghhttp.WriteJSONResponse(w, r, resp)
if err != nil {
aghhttp.Error(
r,
w,
http.StatusInternalServerError,
"Failed to marshal json with available interfaces: %s",
err,
)
}
} }
// dhcpSearchOtherResult contains information about other DHCP server for // dhcpSearchOtherResult contains information about other DHCP server for

View File

@@ -23,6 +23,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd" "github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/hashprefix"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch" "github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/dnsproxy/upstream"
@@ -915,13 +916,23 @@ func TestBlockedByHosts(t *testing.T) {
} }
func TestBlockedBySafeBrowsing(t *testing.T) { func TestBlockedBySafeBrowsing(t *testing.T) {
const hostname = "wmconvirus.narod.ru" const (
hostname = "wmconvirus.narod.ru"
cacheTime = 10 * time.Minute
cacheSize = 10000
)
sbChecker := hashprefix.New(&hashprefix.Config{
CacheTime: cacheTime,
CacheSize: cacheSize,
Upstream: aghtest.NewBlockUpstream(hostname, true),
})
sbUps := aghtest.NewBlockUpstream(hostname, true)
ans4, _ := (&aghtest.TestResolver{}).HostToIPs(hostname) ans4, _ := (&aghtest.TestResolver{}).HostToIPs(hostname)
filterConf := &filtering.Config{ filterConf := &filtering.Config{
SafeBrowsingEnabled: true, SafeBrowsingEnabled: true,
SafeBrowsingChecker: sbChecker,
} }
forwardConf := ServerConfig{ forwardConf := ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}}, UDPListenAddrs: []*net.UDPAddr{{}},
@@ -935,7 +946,6 @@ func TestBlockedBySafeBrowsing(t *testing.T) {
}, },
} }
s := createTestServer(t, filterConf, forwardConf, nil) s := createTestServer(t, filterConf, forwardConf, nil)
s.dnsFilter.SetSafeBrowsingUpstream(sbUps)
startDeferStop(t, s) startDeferStop(t, s)
addr := s.dnsProxy.Addr(proxy.ProtoUDP) addr := s.dnsProxy.Addr(proxy.ProtoUDP)

View File

@@ -3,6 +3,7 @@ package filtering
import ( import (
"github.com/AdguardTeam/urlfilter/rules" "github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns" "github.com/miekg/dns"
"golang.org/x/exp/slices"
) )
// DNSRewriteResult is the result of application of $dnsrewrite rules. // DNSRewriteResult is the result of application of $dnsrewrite rules.
@@ -24,7 +25,13 @@ func (d *DNSFilter) processDNSRewrites(dnsr []*rules.NetworkRule) (res Result) {
Response: DNSRewriteResultResponse{}, Response: DNSRewriteResultResponse{},
} }
for _, nr := range dnsr { slices.SortFunc(dnsr, rewriteSortsBefore)
for i, nr := range dnsr {
if i > 0 && containsWildcard(nr) {
break
}
dr := nr.DNSRewrite dr := nr.DNSRewrite
if dr.NewCNAME != "" { if dr.NewCNAME != "" {
// NewCNAME rules have a higher priority than other rules. // NewCNAME rules have a higher priority than other rules.
@@ -73,3 +80,19 @@ func (d *DNSFilter) processDNSRewrites(dnsr []*rules.NetworkRule) (res Result) {
Reason: RewrittenRule, Reason: RewrittenRule,
} }
} }
func rewriteSortsBefore(a, b *rules.NetworkRule) (sortsBefore bool) {
return len(a.Shortcut) > len(b.Shortcut)
}
func containsWildcard(r *rules.NetworkRule) (ok bool) {
for _, c := range r.RuleText {
if c == '*' {
return true
} else if c == '^' {
break
}
}
return false
}

View File

@@ -5,6 +5,7 @@ import (
"path" "path"
"testing" "testing"
"github.com/AdguardTeam/urlfilter"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -202,3 +203,32 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
assert.Equal(t, "new-ptr-with-dot.", ptr) assert.Equal(t, "new-ptr-with-dot.", ptr)
}) })
} }
func TestDNSFilter_ProcessDNSRewrites(t *testing.T) {
const text = `
|www.example.com^$dnsrewrite=127.0.0.1
|*.example.com^$dnsrewrite=127.0.0.2
`
host := "www.example.com"
rrtype := dns.TypeA
f, _ := newForTest(t, nil, []Filter{{ID: 0, Data: []byte(text)}})
setts := &Settings{
FilteringEnabled: true,
}
ufReq := &urlfilter.DNSRequest{
Hostname: host,
SortedClientTags: setts.ClientTags,
ClientIP: setts.ClientIP.String(),
ClientName: setts.ClientName,
DNSType: rrtype,
}
dres, matched := f.filteringEngine.MatchRequest(ufReq)
require.False(t, matched)
res := f.processDNSResultRewrites(dres, host)
assert.Len(t, res.Rules, 1)
}

View File

@@ -18,8 +18,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/mathutil" "github.com/AdguardTeam/golibs/mathutil"
@@ -75,6 +73,12 @@ type Resolver interface {
// Config allows you to configure DNS filtering with New() or just change variables directly. // Config allows you to configure DNS filtering with New() or just change variables directly.
type Config struct { type Config struct {
// SafeBrowsingChecker is the safe browsing hash-prefix checker.
SafeBrowsingChecker Checker `yaml:"-"`
// ParentControl is the parental control hash-prefix checker.
ParentalControlChecker Checker `yaml:"-"`
// enabled is used to be returned within Settings. // enabled is used to be returned within Settings.
// //
// It is of type uint32 to be accessed by atomic. // It is of type uint32 to be accessed by atomic.
@@ -158,8 +162,22 @@ type hostChecker struct {
name string name string
} }
// Checker is used for safe browsing or parental control hash-prefix filtering.
type Checker interface {
// Check returns true if request for the host should be blocked.
Check(host string) (block bool, err error)
}
// DNSFilter matches hostnames and DNS requests against filtering rules. // DNSFilter matches hostnames and DNS requests against filtering rules.
type DNSFilter struct { type DNSFilter struct {
safeSearch SafeSearch
// safeBrowsingChecker is the safe browsing hash-prefix checker.
safeBrowsingChecker Checker
// parentalControl is the parental control hash-prefix checker.
parentalControlChecker Checker
rulesStorage *filterlist.RuleStorage rulesStorage *filterlist.RuleStorage
filteringEngine *urlfilter.DNSEngine filteringEngine *urlfilter.DNSEngine
@@ -168,14 +186,6 @@ type DNSFilter struct {
engineLock sync.RWMutex engineLock sync.RWMutex
parentalServer string // access via methods
safeBrowsingServer string // access via methods
parentalUpstream upstream.Upstream
safeBrowsingUpstream upstream.Upstream
safebrowsingCache cache.Cache
parentalCache cache.Cache
Config // for direct access by library users, even a = assignment Config // for direct access by library users, even a = assignment
// confLock protects Config. // confLock protects Config.
confLock sync.RWMutex confLock sync.RWMutex
@@ -192,7 +202,6 @@ type DNSFilter struct {
// TODO(e.burkov): Don't use regexp for such a simple text processing task. // TODO(e.burkov): Don't use regexp for such a simple text processing task.
filterTitleRegexp *regexp.Regexp filterTitleRegexp *regexp.Regexp
safeSearch SafeSearch
hostCheckers []hostChecker hostCheckers []hostChecker
} }
@@ -940,19 +949,12 @@ func InitModule() {
// be non-nil. // be non-nil.
func New(c *Config, blockFilters []Filter) (d *DNSFilter, err error) { func New(c *Config, blockFilters []Filter) (d *DNSFilter, err error) {
d = &DNSFilter{ d = &DNSFilter{
refreshLock: &sync.Mutex{}, refreshLock: &sync.Mutex{},
filterTitleRegexp: regexp.MustCompile(`^! Title: +(.*)$`), filterTitleRegexp: regexp.MustCompile(`^! Title: +(.*)$`),
safeBrowsingChecker: c.SafeBrowsingChecker,
parentalControlChecker: c.ParentalControlChecker,
} }
d.safebrowsingCache = cache.New(cache.Config{
EnableLRU: true,
MaxSize: c.SafeBrowsingCacheSize,
})
d.parentalCache = cache.New(cache.Config{
EnableLRU: true,
MaxSize: c.ParentalCacheSize,
})
d.safeSearch = c.SafeSearch d.safeSearch = c.SafeSearch
d.hostCheckers = []hostChecker{{ d.hostCheckers = []hostChecker{{
@@ -977,11 +979,6 @@ func New(c *Config, blockFilters []Filter) (d *DNSFilter, err error) {
defer func() { err = errors.Annotate(err, "filtering: %w") }() defer func() { err = errors.Annotate(err, "filtering: %w") }()
err = d.initSecurityServices()
if err != nil {
return nil, fmt.Errorf("initializing services: %s", err)
}
d.Config = *c d.Config = *c
d.filtersMu = &sync.RWMutex{} d.filtersMu = &sync.RWMutex{}
@@ -1038,3 +1035,69 @@ func (d *DNSFilter) Start() {
// So for now we just start this periodic task from here. // So for now we just start this periodic task from here.
go d.periodicallyRefreshFilters() go d.periodicallyRefreshFilters()
} }
// Safe browsing and parental control methods.
// TODO(a.garipov): Unify with checkParental.
func (d *DNSFilter) checkSafeBrowsing(
host string,
_ uint16,
setts *Settings,
) (res Result, err error) {
if !setts.ProtectionEnabled || !setts.SafeBrowsingEnabled {
return Result{}, nil
}
if log.GetLevel() >= log.DEBUG {
timer := log.StartTimer()
defer timer.LogElapsed("safebrowsing lookup for %q", host)
}
res = Result{
Rules: []*ResultRule{{
Text: "adguard-malware-shavar",
FilterListID: SafeBrowsingListID,
}},
Reason: FilteredSafeBrowsing,
IsFiltered: true,
}
block, err := d.safeBrowsingChecker.Check(host)
if !block || err != nil {
return Result{}, err
}
return res, nil
}
// TODO(a.garipov): Unify with checkSafeBrowsing.
func (d *DNSFilter) checkParental(
host string,
_ uint16,
setts *Settings,
) (res Result, err error) {
if !setts.ProtectionEnabled || !setts.ParentalEnabled {
return Result{}, nil
}
if log.GetLevel() >= log.DEBUG {
timer := log.StartTimer()
defer timer.LogElapsed("parental lookup for %q", host)
}
res = Result{
Rules: []*ResultRule{{
Text: "parental CATEGORY_BLACKLISTED",
FilterListID: ParentalListID,
}},
Reason: FilteredParental,
IsFiltered: true,
}
block, err := d.parentalControlChecker.Check(host)
if !block || err != nil {
return Result{}, err
}
return res, nil
}

View File

@@ -7,7 +7,7 @@ import (
"testing" "testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/golibs/cache" "github.com/AdguardTeam/AdGuardHome/internal/filtering/hashprefix"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/urlfilter/rules" "github.com/AdguardTeam/urlfilter/rules"
@@ -27,17 +27,6 @@ const (
// Helpers. // Helpers.
func purgeCaches(d *DNSFilter) {
for _, c := range []cache.Cache{
d.safebrowsingCache,
d.parentalCache,
} {
if c != nil {
c.Clear()
}
}
}
func newForTest(t testing.TB, c *Config, filters []Filter) (f *DNSFilter, setts *Settings) { func newForTest(t testing.TB, c *Config, filters []Filter) (f *DNSFilter, setts *Settings) {
setts = &Settings{ setts = &Settings{
ProtectionEnabled: true, ProtectionEnabled: true,
@@ -58,11 +47,17 @@ func newForTest(t testing.TB, c *Config, filters []Filter) (f *DNSFilter, setts
f, err := New(c, filters) f, err := New(c, filters)
require.NoError(t, err) require.NoError(t, err)
purgeCaches(f)
return f, setts return f, setts
} }
func newChecker(host string) Checker {
return hashprefix.New(&hashprefix.Config{
CacheTime: 10,
CacheSize: 100000,
Upstream: aghtest.NewBlockUpstream(host, true),
})
}
func (d *DNSFilter) checkMatch(t *testing.T, hostname string, setts *Settings) { func (d *DNSFilter) checkMatch(t *testing.T, hostname string, setts *Settings) {
t.Helper() t.Helper()
@@ -175,10 +170,14 @@ func TestSafeBrowsing(t *testing.T) {
aghtest.ReplaceLogWriter(t, logOutput) aghtest.ReplaceLogWriter(t, logOutput)
aghtest.ReplaceLogLevel(t, log.DEBUG) aghtest.ReplaceLogLevel(t, log.DEBUG)
d, setts := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil) sbChecker := newChecker(sbBlocked)
d, setts := newForTest(t, &Config{
SafeBrowsingEnabled: true,
SafeBrowsingChecker: sbChecker,
}, nil)
t.Cleanup(d.Close) t.Cleanup(d.Close)
d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
d.checkMatch(t, sbBlocked, setts) d.checkMatch(t, sbBlocked, setts)
require.Contains(t, logOutput.String(), fmt.Sprintf("safebrowsing lookup for %q", sbBlocked)) require.Contains(t, logOutput.String(), fmt.Sprintf("safebrowsing lookup for %q", sbBlocked))
@@ -188,18 +187,17 @@ func TestSafeBrowsing(t *testing.T) {
d.checkMatchEmpty(t, pcBlocked, setts) d.checkMatchEmpty(t, pcBlocked, setts)
// Cached result. // Cached result.
d.safeBrowsingServer = "127.0.0.1"
d.checkMatch(t, sbBlocked, setts) d.checkMatch(t, sbBlocked, setts)
d.checkMatchEmpty(t, pcBlocked, setts) d.checkMatchEmpty(t, pcBlocked, setts)
d.safeBrowsingServer = defaultSafebrowsingServer
} }
func TestParallelSB(t *testing.T) { func TestParallelSB(t *testing.T) {
d, setts := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil) d, setts := newForTest(t, &Config{
SafeBrowsingEnabled: true,
SafeBrowsingChecker: newChecker(sbBlocked),
}, nil)
t.Cleanup(d.Close) t.Cleanup(d.Close)
d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
t.Run("group", func(t *testing.T) { t.Run("group", func(t *testing.T) {
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
t.Run(fmt.Sprintf("aaa%d", i), func(t *testing.T) { t.Run(fmt.Sprintf("aaa%d", i), func(t *testing.T) {
@@ -220,10 +218,12 @@ func TestParentalControl(t *testing.T) {
aghtest.ReplaceLogWriter(t, logOutput) aghtest.ReplaceLogWriter(t, logOutput)
aghtest.ReplaceLogLevel(t, log.DEBUG) aghtest.ReplaceLogLevel(t, log.DEBUG)
d, setts := newForTest(t, &Config{ParentalEnabled: true}, nil) d, setts := newForTest(t, &Config{
ParentalEnabled: true,
ParentalControlChecker: newChecker(pcBlocked),
}, nil)
t.Cleanup(d.Close) t.Cleanup(d.Close)
d.SetParentalUpstream(aghtest.NewBlockUpstream(pcBlocked, true))
d.checkMatch(t, pcBlocked, setts) d.checkMatch(t, pcBlocked, setts)
require.Contains(t, logOutput.String(), fmt.Sprintf("parental lookup for %q", pcBlocked)) require.Contains(t, logOutput.String(), fmt.Sprintf("parental lookup for %q", pcBlocked))
@@ -233,7 +233,6 @@ func TestParentalControl(t *testing.T) {
d.checkMatchEmpty(t, "api.jquery.com", setts) d.checkMatchEmpty(t, "api.jquery.com", setts)
// Test cached result. // Test cached result.
d.parentalServer = "127.0.0.1"
d.checkMatch(t, pcBlocked, setts) d.checkMatch(t, pcBlocked, setts)
d.checkMatchEmpty(t, "yandex.ru", setts) d.checkMatchEmpty(t, "yandex.ru", setts)
} }
@@ -593,8 +592,10 @@ func applyClientSettings(setts *Settings) {
func TestClientSettings(t *testing.T) { func TestClientSettings(t *testing.T) {
d, setts := newForTest(t, d, setts := newForTest(t,
&Config{ &Config{
ParentalEnabled: true, ParentalEnabled: true,
SafeBrowsingEnabled: false, SafeBrowsingEnabled: false,
SafeBrowsingChecker: newChecker(sbBlocked),
ParentalControlChecker: newChecker(pcBlocked),
}, },
[]Filter{{ []Filter{{
ID: 0, Data: []byte("||example.org^\n"), ID: 0, Data: []byte("||example.org^\n"),
@@ -602,9 +603,6 @@ func TestClientSettings(t *testing.T) {
) )
t.Cleanup(d.Close) t.Cleanup(d.Close)
d.SetParentalUpstream(aghtest.NewBlockUpstream(pcBlocked, true))
d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
type testCase struct { type testCase struct {
name string name string
host string host string
@@ -665,11 +663,12 @@ func TestClientSettings(t *testing.T) {
// Benchmarks. // Benchmarks.
func BenchmarkSafeBrowsing(b *testing.B) { func BenchmarkSafeBrowsing(b *testing.B) {
d, setts := newForTest(b, &Config{SafeBrowsingEnabled: true}, nil) d, setts := newForTest(b, &Config{
SafeBrowsingEnabled: true,
SafeBrowsingChecker: newChecker(sbBlocked),
}, nil)
b.Cleanup(d.Close) b.Cleanup(d.Close)
d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
res, err := d.CheckHost(sbBlocked, dns.TypeA, setts) res, err := d.CheckHost(sbBlocked, dns.TypeA, setts)
require.NoError(b, err) require.NoError(b, err)
@@ -679,11 +678,12 @@ func BenchmarkSafeBrowsing(b *testing.B) {
} }
func BenchmarkSafeBrowsingParallel(b *testing.B) { func BenchmarkSafeBrowsingParallel(b *testing.B) {
d, setts := newForTest(b, &Config{SafeBrowsingEnabled: true}, nil) d, setts := newForTest(b, &Config{
SafeBrowsingEnabled: true,
SafeBrowsingChecker: newChecker(sbBlocked),
}, nil)
b.Cleanup(d.Close) b.Cleanup(d.Close)
d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
b.RunParallel(func(pb *testing.PB) { b.RunParallel(func(pb *testing.PB) {
for pb.Next() { for pb.Next() {
res, err := d.CheckHost(sbBlocked, dns.TypeA, setts) res, err := d.CheckHost(sbBlocked, dns.TypeA, setts)

View File

@@ -0,0 +1,130 @@
package hashprefix
import (
"encoding/binary"
"time"
"github.com/AdguardTeam/golibs/log"
)
// expirySize is the size of expiry in cacheItem.
const expirySize = 8
// cacheItem represents an item that we will store in the cache.
type cacheItem struct {
// expiry is the time when cacheItem will expire.
expiry time.Time
// hashes is the hashed hostnames.
hashes []hostnameHash
}
// toCacheItem decodes cacheItem from data. data must be at least equal to
// expiry size.
func toCacheItem(data []byte) *cacheItem {
t := time.Unix(int64(binary.BigEndian.Uint64(data)), 0)
data = data[expirySize:]
hashes := make([]hostnameHash, len(data)/hashSize)
for i := 0; i < len(data); i += hashSize {
var hash hostnameHash
copy(hash[:], data[i:i+hashSize])
hashes = append(hashes, hash)
}
return &cacheItem{
expiry: t,
hashes: hashes,
}
}
// fromCacheItem encodes cacheItem into data.
func fromCacheItem(item *cacheItem) (data []byte) {
data = make([]byte, len(item.hashes)*hashSize+expirySize)
expiry := item.expiry.Unix()
binary.BigEndian.PutUint64(data[:expirySize], uint64(expiry))
for _, v := range item.hashes {
// nolint:looppointer // The subsilce is used for a copy.
data = append(data, v[:]...)
}
return data
}
// findInCache finds hashes in the cache. If nothing found returns list of
// hashes, prefixes of which will be sent to upstream.
func (c *Checker) findInCache(
hashes []hostnameHash,
) (found, blocked bool, hashesToRequest []hostnameHash) {
now := time.Now()
i := 0
for _, hash := range hashes {
// nolint:looppointer // The subsilce is used for a safe cache lookup.
data := c.cache.Get(hash[:prefixLen])
if data == nil {
hashes[i] = hash
i++
continue
}
item := toCacheItem(data)
if now.After(item.expiry) {
hashes[i] = hash
i++
continue
}
if ok := findMatch(hashes, item.hashes); ok {
return true, true, nil
}
}
if i == 0 {
return true, false, nil
}
return false, false, hashes[:i]
}
// storeInCache caches hashes.
func (c *Checker) storeInCache(hashesToRequest, respHashes []hostnameHash) {
hashToStore := make(map[prefix][]hostnameHash)
for _, hash := range respHashes {
var pref prefix
// nolint:looppointer // The subsilce is used for a copy.
copy(pref[:], hash[:])
hashToStore[pref] = append(hashToStore[pref], hash)
}
for pref, hash := range hashToStore {
// nolint:looppointer // The subsilce is used for a safe cache lookup.
c.setCache(pref[:], hash)
}
for _, hash := range hashesToRequest {
// nolint:looppointer // The subsilce is used for a safe cache lookup.
pref := hash[:prefixLen]
val := c.cache.Get(pref)
if val == nil {
c.setCache(pref, nil)
}
}
}
// setCache stores hash in cache.
func (c *Checker) setCache(pref []byte, hashes []hostnameHash) {
item := &cacheItem{
expiry: time.Now().Add(c.cacheTime),
hashes: hashes,
}
c.cache.Set(pref, fromCacheItem(item))
log.Debug("%s: stored in cache: %v", c.svc, pref)
}

View File

@@ -0,0 +1,245 @@
// Package hashprefix used for safe browsing and parent control.
package hashprefix
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"strings"
"time"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/miekg/dns"
"golang.org/x/exp/slices"
"golang.org/x/net/publicsuffix"
)
const (
// prefixLen is the length of the hash prefix of the filtered hostname.
prefixLen = 2
// hashSize is the size of hashed hostname.
hashSize = sha256.Size
// hexSize is the size of hexadecimal representation of hashed hostname.
hexSize = hashSize * 2
)
// prefix is the type of the SHA256 hash prefix used to match against the
// domain-name database.
type prefix [prefixLen]byte
// hostnameHash is the hashed hostname.
//
// TODO(s.chzhen): Split into prefix and suffix.
type hostnameHash [hashSize]byte
// findMatch returns true if one of the a hostnames matches one of the b.
func findMatch(a, b []hostnameHash) (matched bool) {
for _, hash := range a {
if slices.Contains(b, hash) {
return true
}
}
return false
}
// Config is the configuration structure for safe browsing and parental
// control.
type Config struct {
// Upstream is the upstream DNS server.
Upstream upstream.Upstream
// ServiceName is the name of the service.
ServiceName string
// TXTSuffix is the TXT suffix for DNS request.
TXTSuffix string
// CacheTime is the time period to store hash.
CacheTime time.Duration
// CacheSize is the maximum size of the cache. If it's zero, cache size is
// unlimited.
CacheSize uint
}
type Checker struct {
// upstream is the upstream DNS server.
upstream upstream.Upstream
// cache stores hostname hashes.
cache cache.Cache
// svc is the name of the service.
svc string
// txtSuffix is the TXT suffix for DNS request.
txtSuffix string
// cacheTime is the time period to store hash.
cacheTime time.Duration
}
// New returns Checker.
func New(conf *Config) (c *Checker) {
return &Checker{
upstream: conf.Upstream,
cache: cache.New(cache.Config{
EnableLRU: true,
MaxSize: conf.CacheSize,
}),
svc: conf.ServiceName,
txtSuffix: conf.TXTSuffix,
cacheTime: conf.CacheTime,
}
}
// Check returns true if request for the host should be blocked.
func (c *Checker) Check(host string) (ok bool, err error) {
hashes := hostnameToHashes(host)
found, blocked, hashesToRequest := c.findInCache(hashes)
if found {
log.Debug("%s: found %q in cache, blocked: %t", c.svc, host, blocked)
return blocked, nil
}
question := c.getQuestion(hashesToRequest)
log.Debug("%s: checking %s: %s", c.svc, host, question)
req := (&dns.Msg{}).SetQuestion(question, dns.TypeTXT)
resp, err := c.upstream.Exchange(req)
if err != nil {
return false, fmt.Errorf("getting hashes: %w", err)
}
matched, receivedHashes := c.processAnswer(hashesToRequest, resp, host)
c.storeInCache(hashesToRequest, receivedHashes)
return matched, nil
}
// hostnameToHashes returns hashes that should be checked by the hash prefix
// filter.
func hostnameToHashes(host string) (hashes []hostnameHash) {
// subDomainNum defines how many labels should be hashed to match against a
// hash prefix filter.
const subDomainNum = 4
pubSuf, icann := publicsuffix.PublicSuffix(host)
if !icann {
// Check the full private domain space.
pubSuf = ""
}
nDots := 0
i := strings.LastIndexFunc(host, func(r rune) (ok bool) {
if r == '.' {
nDots++
}
return nDots == subDomainNum
})
if i != -1 {
host = host[i+1:]
}
sub := netutil.Subdomains(host)
for _, s := range sub {
if s == pubSuf {
break
}
sum := sha256.Sum256([]byte(s))
hashes = append(hashes, sum)
}
return hashes
}
// getQuestion combines hexadecimal encoded prefixes of hashed hostnames into
// string.
func (c *Checker) getQuestion(hashes []hostnameHash) (q string) {
b := &strings.Builder{}
for _, hash := range hashes {
// nolint:looppointer // The subsilce is used for safe hex encoding.
stringutil.WriteToBuilder(b, hex.EncodeToString(hash[:prefixLen]), ".")
}
stringutil.WriteToBuilder(b, c.txtSuffix)
return b.String()
}
// processAnswer returns true if DNS response matches the hash, and received
// hashed hostnames from the upstream.
func (c *Checker) processAnswer(
hashesToRequest []hostnameHash,
resp *dns.Msg,
host string,
) (matched bool, receivedHashes []hostnameHash) {
txtCount := 0
for _, a := range resp.Answer {
txt, ok := a.(*dns.TXT)
if !ok {
continue
}
txtCount++
receivedHashes = c.appendHashesFromTXT(receivedHashes, txt, host)
}
log.Debug("%s: received answer for %s with %d TXT count", c.svc, host, txtCount)
matched = findMatch(hashesToRequest, receivedHashes)
if matched {
log.Debug("%s: matched %s", c.svc, host)
return true, receivedHashes
}
return false, receivedHashes
}
// appendHashesFromTXT appends received hashed hostnames.
func (c *Checker) appendHashesFromTXT(
hashes []hostnameHash,
txt *dns.TXT,
host string,
) (receivedHashes []hostnameHash) {
log.Debug("%s: received hashes for %s: %v", c.svc, host, txt.Txt)
for _, t := range txt.Txt {
if len(t) != hexSize {
log.Debug("%s: wrong hex size %d for %s %s", c.svc, len(t), host, t)
continue
}
buf, err := hex.DecodeString(t)
if err != nil {
log.Debug("%s: decoding hex string %s: %s", c.svc, t, err)
continue
}
var hash hostnameHash
copy(hash[:], buf)
hashes = append(hashes, hash)
}
return hashes
}

View File

@@ -0,0 +1,248 @@
package hashprefix
import (
"crypto/sha256"
"encoding/hex"
"strings"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/golibs/cache"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/exp/slices"
)
const (
cacheTime = 10 * time.Minute
cacheSize = 10000
)
func TestChcker_getQuestion(t *testing.T) {
const suf = "sb.dns.adguard.com."
// test hostnameToHashes()
hashes := hostnameToHashes("1.2.3.sub.host.com")
assert.Len(t, hashes, 3)
hash := sha256.Sum256([]byte("3.sub.host.com"))
hexPref1 := hex.EncodeToString(hash[:prefixLen])
assert.True(t, slices.Contains(hashes, hash))
hash = sha256.Sum256([]byte("sub.host.com"))
hexPref2 := hex.EncodeToString(hash[:prefixLen])
assert.True(t, slices.Contains(hashes, hash))
hash = sha256.Sum256([]byte("host.com"))
hexPref3 := hex.EncodeToString(hash[:prefixLen])
assert.True(t, slices.Contains(hashes, hash))
hash = sha256.Sum256([]byte("com"))
assert.False(t, slices.Contains(hashes, hash))
c := &Checker{
svc: "SafeBrowsing",
txtSuffix: suf,
}
q := c.getQuestion(hashes)
assert.Contains(t, q, hexPref1)
assert.Contains(t, q, hexPref2)
assert.Contains(t, q, hexPref3)
assert.True(t, strings.HasSuffix(q, suf))
}
func TestHostnameToHashes(t *testing.T) {
testCases := []struct {
name string
host string
wantLen int
}{{
name: "basic",
host: "example.com",
wantLen: 1,
}, {
name: "sub_basic",
host: "www.example.com",
wantLen: 2,
}, {
name: "private_domain",
host: "foo.co.uk",
wantLen: 1,
}, {
name: "sub_private_domain",
host: "bar.foo.co.uk",
wantLen: 2,
}, {
name: "private_domain_v2",
host: "foo.blogspot.co.uk",
wantLen: 4,
}, {
name: "sub_private_domain_v2",
host: "bar.foo.blogspot.co.uk",
wantLen: 4,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
hashes := hostnameToHashes(tc.host)
assert.Len(t, hashes, tc.wantLen)
})
}
}
func TestChecker_storeInCache(t *testing.T) {
c := &Checker{
svc: "SafeBrowsing",
cacheTime: cacheTime,
}
conf := cache.Config{}
c.cache = cache.New(conf)
// store in cache hashes for "3.sub.host.com" and "host.com"
// and empty data for hash-prefix for "sub.host.com"
hashes := []hostnameHash{}
hash := sha256.Sum256([]byte("sub.host.com"))
hashes = append(hashes, hash)
var hashesArray []hostnameHash
hash4 := sha256.Sum256([]byte("3.sub.host.com"))
hashesArray = append(hashesArray, hash4)
hash2 := sha256.Sum256([]byte("host.com"))
hashesArray = append(hashesArray, hash2)
c.storeInCache(hashes, hashesArray)
// match "3.sub.host.com" or "host.com" from cache
hashes = []hostnameHash{}
hash = sha256.Sum256([]byte("3.sub.host.com"))
hashes = append(hashes, hash)
hash = sha256.Sum256([]byte("sub.host.com"))
hashes = append(hashes, hash)
hash = sha256.Sum256([]byte("host.com"))
hashes = append(hashes, hash)
found, blocked, _ := c.findInCache(hashes)
assert.True(t, found)
assert.True(t, blocked)
// match "sub.host.com" from cache
hashes = []hostnameHash{}
hash = sha256.Sum256([]byte("sub.host.com"))
hashes = append(hashes, hash)
found, blocked, _ = c.findInCache(hashes)
assert.True(t, found)
assert.False(t, blocked)
// Match "sub.host.com" from cache. Another hash for "host.example" is not
// in the cache, so get data for it from the server.
hashes = []hostnameHash{}
hash = sha256.Sum256([]byte("sub.host.com"))
hashes = append(hashes, hash)
hash = sha256.Sum256([]byte("host.example"))
hashes = append(hashes, hash)
found, _, hashesToRequest := c.findInCache(hashes)
assert.False(t, found)
hash = sha256.Sum256([]byte("sub.host.com"))
ok := slices.Contains(hashesToRequest, hash)
assert.False(t, ok)
hash = sha256.Sum256([]byte("host.example"))
ok = slices.Contains(hashesToRequest, hash)
assert.True(t, ok)
c = &Checker{
svc: "SafeBrowsing",
cacheTime: cacheTime,
}
c.cache = cache.New(cache.Config{})
hashes = []hostnameHash{}
hash = sha256.Sum256([]byte("sub.host.com"))
hashes = append(hashes, hash)
c.cache.Set(hash[:prefixLen], make([]byte, expirySize+hashSize))
found, _, _ = c.findInCache(hashes)
assert.False(t, found)
}
func TestChecker_Check(t *testing.T) {
const hostname = "example.org"
testCases := []struct {
name string
wantBlock bool
}{{
name: "sb_no_block",
wantBlock: false,
}, {
name: "sb_block",
wantBlock: true,
}, {
name: "pc_no_block",
wantBlock: false,
}, {
name: "pc_block",
wantBlock: true,
}}
for _, tc := range testCases {
c := New(&Config{
CacheTime: cacheTime,
CacheSize: cacheSize,
})
// Prepare the upstream.
ups := aghtest.NewBlockUpstream(hostname, tc.wantBlock)
var numReq int
onExchange := ups.OnExchange
ups.OnExchange = func(req *dns.Msg) (resp *dns.Msg, err error) {
numReq++
return onExchange(req)
}
c.upstream = ups
t.Run(tc.name, func(t *testing.T) {
// Firstly, check the request blocking.
hits := 0
res := false
res, err := c.Check(hostname)
require.NoError(t, err)
if tc.wantBlock {
assert.True(t, res)
hits++
} else {
require.False(t, res)
}
// Check the cache state, check the response is now cached.
assert.Equal(t, 1, c.cache.Stats().Count)
assert.Equal(t, hits, c.cache.Stats().Hit)
// There was one request to an upstream.
assert.Equal(t, 1, numReq)
// Now make the same request to check the cache was used.
res, err = c.Check(hostname)
require.NoError(t, err)
if tc.wantBlock {
assert.True(t, res)
} else {
require.False(t, res)
}
// Check the cache state, it should've been used.
assert.Equal(t, 1, c.cache.Stats().Count)
assert.Equal(t, hits+1, c.cache.Stats().Hit)
// Check that there were no additional requests.
assert.Equal(t, 1, numReq)
})
}
}

View File

@@ -8,6 +8,7 @@ import (
"net/url" "net/url"
"os" "os"
"path/filepath" "path/filepath"
"sync"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
@@ -458,6 +459,80 @@ func (d *DNSFilter) handleCheckHost(w http.ResponseWriter, r *http.Request) {
_ = aghhttp.WriteJSONResponse(w, r, resp) _ = aghhttp.WriteJSONResponse(w, r, resp)
} }
// setProtectedBool sets the value of a boolean pointer under a lock. l must
// protect the value under ptr.
//
// TODO(e.burkov): Make it generic?
func setProtectedBool(mu *sync.RWMutex, ptr *bool, val bool) {
mu.Lock()
defer mu.Unlock()
*ptr = val
}
// protectedBool gets the value of a boolean pointer under a read lock. l must
// protect the value under ptr.
//
// TODO(e.burkov): Make it generic?
func protectedBool(mu *sync.RWMutex, ptr *bool) (val bool) {
mu.RLock()
defer mu.RUnlock()
return *ptr
}
// handleSafeBrowsingEnable is the handler for the POST
// /control/safebrowsing/enable HTTP API.
func (d *DNSFilter) handleSafeBrowsingEnable(w http.ResponseWriter, r *http.Request) {
setProtectedBool(&d.confLock, &d.Config.SafeBrowsingEnabled, true)
d.Config.ConfigModified()
}
// handleSafeBrowsingDisable is the handler for the POST
// /control/safebrowsing/disable HTTP API.
func (d *DNSFilter) handleSafeBrowsingDisable(w http.ResponseWriter, r *http.Request) {
setProtectedBool(&d.confLock, &d.Config.SafeBrowsingEnabled, false)
d.Config.ConfigModified()
}
// handleSafeBrowsingStatus is the handler for the GET
// /control/safebrowsing/status HTTP API.
func (d *DNSFilter) handleSafeBrowsingStatus(w http.ResponseWriter, r *http.Request) {
resp := &struct {
Enabled bool `json:"enabled"`
}{
Enabled: protectedBool(&d.confLock, &d.Config.SafeBrowsingEnabled),
}
_ = aghhttp.WriteJSONResponse(w, r, resp)
}
// handleParentalEnable is the handler for the POST /control/parental/enable
// HTTP API.
func (d *DNSFilter) handleParentalEnable(w http.ResponseWriter, r *http.Request) {
setProtectedBool(&d.confLock, &d.Config.ParentalEnabled, true)
d.Config.ConfigModified()
}
// handleParentalDisable is the handler for the POST /control/parental/disable
// HTTP API.
func (d *DNSFilter) handleParentalDisable(w http.ResponseWriter, r *http.Request) {
setProtectedBool(&d.confLock, &d.Config.ParentalEnabled, false)
d.Config.ConfigModified()
}
// handleParentalStatus is the handler for the GET /control/parental/status
// HTTP API.
func (d *DNSFilter) handleParentalStatus(w http.ResponseWriter, r *http.Request) {
resp := &struct {
Enabled bool `json:"enabled"`
}{
Enabled: protectedBool(&d.confLock, &d.Config.ParentalEnabled),
}
_ = aghhttp.WriteJSONResponse(w, r, resp)
}
// RegisterFilteringHandlers - register handlers // RegisterFilteringHandlers - register handlers
func (d *DNSFilter) RegisterFilteringHandlers() { func (d *DNSFilter) RegisterFilteringHandlers() {
registerHTTP := d.HTTPRegister registerHTTP := d.HTTPRegister

View File

@@ -8,6 +8,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -136,3 +137,171 @@ func TestDNSFilter_handleFilteringSetURL(t *testing.T) {
}) })
} }
} }
func TestDNSFilter_handleSafeBrowsingStatus(t *testing.T) {
const (
testTimeout = time.Second
statusURL = "/control/safebrowsing/status"
)
confModCh := make(chan struct{})
filtersDir := t.TempDir()
testCases := []struct {
name string
url string
enabled bool
wantStatus assert.BoolAssertionFunc
}{{
name: "enable_off",
url: "/control/safebrowsing/enable",
enabled: false,
wantStatus: assert.True,
}, {
name: "enable_on",
url: "/control/safebrowsing/enable",
enabled: true,
wantStatus: assert.True,
}, {
name: "disable_on",
url: "/control/safebrowsing/disable",
enabled: true,
wantStatus: assert.False,
}, {
name: "disable_off",
url: "/control/safebrowsing/disable",
enabled: false,
wantStatus: assert.False,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
handlers := make(map[string]http.Handler)
d, err := New(&Config{
ConfigModified: func() {
testutil.RequireSend(testutil.PanicT{}, confModCh, struct{}{}, testTimeout)
},
DataDir: filtersDir,
HTTPRegister: func(_, url string, handler http.HandlerFunc) {
handlers[url] = handler
},
SafeBrowsingEnabled: tc.enabled,
}, nil)
require.NoError(t, err)
t.Cleanup(d.Close)
d.RegisterFilteringHandlers()
require.NotEmpty(t, handlers)
require.Contains(t, handlers, statusURL)
r := httptest.NewRequest(http.MethodPost, tc.url, nil)
w := httptest.NewRecorder()
go handlers[tc.url].ServeHTTP(w, r)
testutil.RequireReceive(t, confModCh, testTimeout)
r = httptest.NewRequest(http.MethodGet, statusURL, nil)
w = httptest.NewRecorder()
handlers[statusURL].ServeHTTP(w, r)
require.Equal(t, http.StatusOK, w.Code)
status := struct {
Enabled bool `json:"enabled"`
}{
Enabled: false,
}
err = json.NewDecoder(w.Body).Decode(&status)
require.NoError(t, err)
tc.wantStatus(t, status.Enabled)
})
}
}
func TestDNSFilter_handleParentalStatus(t *testing.T) {
const (
testTimeout = time.Second
statusURL = "/control/parental/status"
)
confModCh := make(chan struct{})
filtersDir := t.TempDir()
testCases := []struct {
name string
url string
enabled bool
wantStatus assert.BoolAssertionFunc
}{{
name: "enable_off",
url: "/control/parental/enable",
enabled: false,
wantStatus: assert.True,
}, {
name: "enable_on",
url: "/control/parental/enable",
enabled: true,
wantStatus: assert.True,
}, {
name: "disable_on",
url: "/control/parental/disable",
enabled: true,
wantStatus: assert.False,
}, {
name: "disable_off",
url: "/control/parental/disable",
enabled: false,
wantStatus: assert.False,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
handlers := make(map[string]http.Handler)
d, err := New(&Config{
ConfigModified: func() {
testutil.RequireSend(testutil.PanicT{}, confModCh, struct{}{}, testTimeout)
},
DataDir: filtersDir,
HTTPRegister: func(_, url string, handler http.HandlerFunc) {
handlers[url] = handler
},
ParentalEnabled: tc.enabled,
}, nil)
require.NoError(t, err)
t.Cleanup(d.Close)
d.RegisterFilteringHandlers()
require.NotEmpty(t, handlers)
require.Contains(t, handlers, statusURL)
r := httptest.NewRequest(http.MethodPost, tc.url, nil)
w := httptest.NewRecorder()
go handlers[tc.url].ServeHTTP(w, r)
testutil.RequireReceive(t, confModCh, testTimeout)
r = httptest.NewRequest(http.MethodGet, statusURL, nil)
w = httptest.NewRecorder()
handlers[statusURL].ServeHTTP(w, r)
require.Equal(t, http.StatusOK, w.Code)
status := struct {
Enabled bool `json:"enabled"`
}{
Enabled: false,
}
err = json.NewDecoder(w.Body).Decode(&status)
require.NoError(t, err)
tc.wantStatus(t, status.Enabled)
})
}
}

View File

@@ -1,433 +0,0 @@
package filtering
import (
"bytes"
"crypto/sha256"
"encoding/binary"
"encoding/hex"
"fmt"
"net"
"net/http"
"strings"
"sync"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/miekg/dns"
"golang.org/x/exp/slices"
"golang.org/x/net/publicsuffix"
)
// Safe browsing and parental control methods.
// TODO(a.garipov): Make configurable.
const (
dnsTimeout = 3 * time.Second
defaultSafebrowsingServer = `https://family.adguard-dns.com/dns-query`
defaultParentalServer = `https://family.adguard-dns.com/dns-query`
sbTXTSuffix = `sb.dns.adguard.com.`
pcTXTSuffix = `pc.dns.adguard.com.`
)
// SetParentalUpstream sets the parental upstream for *DNSFilter.
//
// TODO(e.burkov): Remove this in v1 API to forbid the direct access.
func (d *DNSFilter) SetParentalUpstream(u upstream.Upstream) {
d.parentalUpstream = u
}
// SetSafeBrowsingUpstream sets the safe browsing upstream for *DNSFilter.
//
// TODO(e.burkov): Remove this in v1 API to forbid the direct access.
func (d *DNSFilter) SetSafeBrowsingUpstream(u upstream.Upstream) {
d.safeBrowsingUpstream = u
}
func (d *DNSFilter) initSecurityServices() error {
var err error
d.safeBrowsingServer = defaultSafebrowsingServer
d.parentalServer = defaultParentalServer
opts := &upstream.Options{
Timeout: dnsTimeout,
ServerIPAddrs: []net.IP{
{94, 140, 14, 15},
{94, 140, 15, 16},
net.ParseIP("2a10:50c0::bad1:ff"),
net.ParseIP("2a10:50c0::bad2:ff"),
},
}
parUps, err := upstream.AddressToUpstream(d.parentalServer, opts)
if err != nil {
return fmt.Errorf("converting parental server: %w", err)
}
d.SetParentalUpstream(parUps)
sbUps, err := upstream.AddressToUpstream(d.safeBrowsingServer, opts)
if err != nil {
return fmt.Errorf("converting safe browsing server: %w", err)
}
d.SetSafeBrowsingUpstream(sbUps)
return nil
}
/*
expire byte[4]
hash byte[32]
...
*/
func (c *sbCtx) setCache(prefix, hashes []byte) {
d := make([]byte, 4+len(hashes))
expire := uint(time.Now().Unix()) + c.cacheTime*60
binary.BigEndian.PutUint32(d[:4], uint32(expire))
copy(d[4:], hashes)
c.cache.Set(prefix, d)
log.Debug("%s: stored in cache: %v", c.svc, prefix)
}
// findInHash returns 32-byte hash if it's found in hashToHost.
func (c *sbCtx) findInHash(val []byte) (hash32 [32]byte, found bool) {
for i := 4; i < len(val); i += 32 {
hash := val[i : i+32]
copy(hash32[:], hash[0:32])
_, found = c.hashToHost[hash32]
if found {
return hash32, found
}
}
return [32]byte{}, false
}
func (c *sbCtx) getCached() int {
now := time.Now().Unix()
hashesToRequest := map[[32]byte]string{}
for k, v := range c.hashToHost {
// nolint:looppointer // The subsilce is used for a safe cache lookup.
val := c.cache.Get(k[0:2])
if val == nil || now >= int64(binary.BigEndian.Uint32(val)) {
hashesToRequest[k] = v
continue
}
if hash32, found := c.findInHash(val); found {
log.Debug("%s: found in cache: %s: blocked by %v", c.svc, c.host, hash32)
return 1
}
}
if len(hashesToRequest) == 0 {
log.Debug("%s: found in cache: %s: not blocked", c.svc, c.host)
return -1
}
c.hashToHost = hashesToRequest
return 0
}
type sbCtx struct {
host string
svc string
hashToHost map[[32]byte]string
cache cache.Cache
cacheTime uint
}
func hostnameToHashes(host string) map[[32]byte]string {
hashes := map[[32]byte]string{}
tld, icann := publicsuffix.PublicSuffix(host)
if !icann {
// private suffixes like cloudfront.net
tld = ""
}
curhost := host
nDots := 0
for i := len(curhost) - 1; i >= 0; i-- {
if curhost[i] == '.' {
nDots++
if nDots == 4 {
curhost = curhost[i+1:] // "xxx.a.b.c.d" -> "a.b.c.d"
break
}
}
}
for {
if curhost == "" {
// we've reached end of string
break
}
if tld != "" && curhost == tld {
// we've reached the TLD, don't hash it
break
}
sum := sha256.Sum256([]byte(curhost))
hashes[sum] = curhost
pos := strings.IndexByte(curhost, byte('.'))
if pos < 0 {
break
}
curhost = curhost[pos+1:]
}
return hashes
}
// convert hash array to string
func (c *sbCtx) getQuestion() string {
b := &strings.Builder{}
for hash := range c.hashToHost {
// nolint:looppointer // The subsilce is used for safe hex encoding.
stringutil.WriteToBuilder(b, hex.EncodeToString(hash[0:2]), ".")
}
if c.svc == "SafeBrowsing" {
stringutil.WriteToBuilder(b, sbTXTSuffix)
return b.String()
}
stringutil.WriteToBuilder(b, pcTXTSuffix)
return b.String()
}
// Find the target hash in TXT response
func (c *sbCtx) processTXT(resp *dns.Msg) (bool, [][]byte) {
matched := false
hashes := [][]byte{}
for _, a := range resp.Answer {
txt, ok := a.(*dns.TXT)
if !ok {
continue
}
log.Debug("%s: received hashes for %s: %v", c.svc, c.host, txt.Txt)
for _, t := range txt.Txt {
if len(t) != 32*2 {
continue
}
hash, err := hex.DecodeString(t)
if err != nil {
continue
}
hashes = append(hashes, hash)
if !matched {
var hash32 [32]byte
copy(hash32[:], hash)
var hashHost string
hashHost, ok = c.hashToHost[hash32]
if ok {
log.Debug("%s: matched %s by %s/%s", c.svc, c.host, hashHost, t)
matched = true
}
}
}
}
return matched, hashes
}
func (c *sbCtx) storeCache(hashes [][]byte) {
slices.SortFunc(hashes, func(a, b []byte) (sortsBefore bool) {
return bytes.Compare(a, b) == -1
})
var curData []byte
var prevPrefix []byte
for i, hash := range hashes {
// nolint:looppointer // The subsilce is used for a safe comparison.
if !bytes.Equal(hash[0:2], prevPrefix) {
if i != 0 {
c.setCache(prevPrefix, curData)
curData = nil
}
prevPrefix = hashes[i][0:2]
}
curData = append(curData, hash...)
}
if len(prevPrefix) != 0 {
c.setCache(prevPrefix, curData)
}
for hash := range c.hashToHost {
// nolint:looppointer // The subsilce is used for a safe cache lookup.
prefix := hash[0:2]
val := c.cache.Get(prefix)
if val == nil {
c.setCache(prefix, nil)
}
}
}
func check(c *sbCtx, r Result, u upstream.Upstream) (Result, error) {
c.hashToHost = hostnameToHashes(c.host)
switch c.getCached() {
case -1:
return Result{}, nil
case 1:
return r, nil
}
question := c.getQuestion()
log.Tracef("%s: checking %s: %s", c.svc, c.host, question)
req := (&dns.Msg{}).SetQuestion(question, dns.TypeTXT)
resp, err := u.Exchange(req)
if err != nil {
return Result{}, err
}
matched, receivedHashes := c.processTXT(resp)
c.storeCache(receivedHashes)
if matched {
return r, nil
}
return Result{}, nil
}
// TODO(a.garipov): Unify with checkParental.
func (d *DNSFilter) checkSafeBrowsing(
host string,
_ uint16,
setts *Settings,
) (res Result, err error) {
if !setts.ProtectionEnabled || !setts.SafeBrowsingEnabled {
return Result{}, nil
}
if log.GetLevel() >= log.DEBUG {
timer := log.StartTimer()
defer timer.LogElapsed("safebrowsing lookup for %q", host)
}
sctx := &sbCtx{
host: host,
svc: "SafeBrowsing",
cache: d.safebrowsingCache,
cacheTime: d.Config.CacheTime,
}
res = Result{
Rules: []*ResultRule{{
Text: "adguard-malware-shavar",
FilterListID: SafeBrowsingListID,
}},
Reason: FilteredSafeBrowsing,
IsFiltered: true,
}
return check(sctx, res, d.safeBrowsingUpstream)
}
// TODO(a.garipov): Unify with checkSafeBrowsing.
func (d *DNSFilter) checkParental(
host string,
_ uint16,
setts *Settings,
) (res Result, err error) {
if !setts.ProtectionEnabled || !setts.ParentalEnabled {
return Result{}, nil
}
if log.GetLevel() >= log.DEBUG {
timer := log.StartTimer()
defer timer.LogElapsed("parental lookup for %q", host)
}
sctx := &sbCtx{
host: host,
svc: "Parental",
cache: d.parentalCache,
cacheTime: d.Config.CacheTime,
}
res = Result{
Rules: []*ResultRule{{
Text: "parental CATEGORY_BLACKLISTED",
FilterListID: ParentalListID,
}},
Reason: FilteredParental,
IsFiltered: true,
}
return check(sctx, res, d.parentalUpstream)
}
// setProtectedBool sets the value of a boolean pointer under a lock. l must
// protect the value under ptr.
//
// TODO(e.burkov): Make it generic?
func setProtectedBool(mu *sync.RWMutex, ptr *bool, val bool) {
mu.Lock()
defer mu.Unlock()
*ptr = val
}
// protectedBool gets the value of a boolean pointer under a read lock. l must
// protect the value under ptr.
//
// TODO(e.burkov): Make it generic?
func protectedBool(mu *sync.RWMutex, ptr *bool) (val bool) {
mu.RLock()
defer mu.RUnlock()
return *ptr
}
func (d *DNSFilter) handleSafeBrowsingEnable(w http.ResponseWriter, r *http.Request) {
setProtectedBool(&d.confLock, &d.Config.SafeBrowsingEnabled, true)
d.Config.ConfigModified()
}
func (d *DNSFilter) handleSafeBrowsingDisable(w http.ResponseWriter, r *http.Request) {
setProtectedBool(&d.confLock, &d.Config.SafeBrowsingEnabled, false)
d.Config.ConfigModified()
}
func (d *DNSFilter) handleSafeBrowsingStatus(w http.ResponseWriter, r *http.Request) {
resp := &struct {
Enabled bool `json:"enabled"`
}{
Enabled: protectedBool(&d.confLock, &d.Config.SafeBrowsingEnabled),
}
_ = aghhttp.WriteJSONResponse(w, r, resp)
}
func (d *DNSFilter) handleParentalEnable(w http.ResponseWriter, r *http.Request) {
setProtectedBool(&d.confLock, &d.Config.ParentalEnabled, true)
d.Config.ConfigModified()
}
func (d *DNSFilter) handleParentalDisable(w http.ResponseWriter, r *http.Request) {
setProtectedBool(&d.confLock, &d.Config.ParentalEnabled, false)
d.Config.ConfigModified()
}
func (d *DNSFilter) handleParentalStatus(w http.ResponseWriter, r *http.Request) {
resp := &struct {
Enabled bool `json:"enabled"`
}{
Enabled: protectedBool(&d.confLock, &d.Config.ParentalEnabled),
}
_ = aghhttp.WriteJSONResponse(w, r, resp)
}

View File

@@ -1,226 +0,0 @@
package filtering
import (
"crypto/sha256"
"strings"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/golibs/cache"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSafeBrowsingHash(t *testing.T) {
// test hostnameToHashes()
hashes := hostnameToHashes("1.2.3.sub.host.com")
assert.Len(t, hashes, 3)
_, ok := hashes[sha256.Sum256([]byte("3.sub.host.com"))]
assert.True(t, ok)
_, ok = hashes[sha256.Sum256([]byte("sub.host.com"))]
assert.True(t, ok)
_, ok = hashes[sha256.Sum256([]byte("host.com"))]
assert.True(t, ok)
_, ok = hashes[sha256.Sum256([]byte("com"))]
assert.False(t, ok)
c := &sbCtx{
svc: "SafeBrowsing",
hashToHost: hashes,
}
q := c.getQuestion()
assert.Contains(t, q, "7a1b.")
assert.Contains(t, q, "af5a.")
assert.Contains(t, q, "eb11.")
assert.True(t, strings.HasSuffix(q, "sb.dns.adguard.com."))
}
func TestSafeBrowsingCache(t *testing.T) {
c := &sbCtx{
svc: "SafeBrowsing",
cacheTime: 100,
}
conf := cache.Config{}
c.cache = cache.New(conf)
// store in cache hashes for "3.sub.host.com" and "host.com"
// and empty data for hash-prefix for "sub.host.com"
hash := sha256.Sum256([]byte("sub.host.com"))
c.hashToHost = make(map[[32]byte]string)
c.hashToHost[hash] = "sub.host.com"
var hashesArray [][]byte
hash4 := sha256.Sum256([]byte("3.sub.host.com"))
hashesArray = append(hashesArray, hash4[:])
hash2 := sha256.Sum256([]byte("host.com"))
hashesArray = append(hashesArray, hash2[:])
c.storeCache(hashesArray)
// match "3.sub.host.com" or "host.com" from cache
c.hashToHost = make(map[[32]byte]string)
hash = sha256.Sum256([]byte("3.sub.host.com"))
c.hashToHost[hash] = "3.sub.host.com"
hash = sha256.Sum256([]byte("sub.host.com"))
c.hashToHost[hash] = "sub.host.com"
hash = sha256.Sum256([]byte("host.com"))
c.hashToHost[hash] = "host.com"
assert.Equal(t, 1, c.getCached())
// match "sub.host.com" from cache
c.hashToHost = make(map[[32]byte]string)
hash = sha256.Sum256([]byte("sub.host.com"))
c.hashToHost[hash] = "sub.host.com"
assert.Equal(t, -1, c.getCached())
// Match "sub.host.com" from cache. Another hash for "host.example" is not
// in the cache, so get data for it from the server.
c.hashToHost = make(map[[32]byte]string)
hash = sha256.Sum256([]byte("sub.host.com"))
c.hashToHost[hash] = "sub.host.com"
hash = sha256.Sum256([]byte("host.example"))
c.hashToHost[hash] = "host.example"
assert.Empty(t, c.getCached())
hash = sha256.Sum256([]byte("sub.host.com"))
_, ok := c.hashToHost[hash]
assert.False(t, ok)
hash = sha256.Sum256([]byte("host.example"))
_, ok = c.hashToHost[hash]
assert.True(t, ok)
c = &sbCtx{
svc: "SafeBrowsing",
cacheTime: 100,
}
conf = cache.Config{}
c.cache = cache.New(conf)
hash = sha256.Sum256([]byte("sub.host.com"))
c.hashToHost = make(map[[32]byte]string)
c.hashToHost[hash] = "sub.host.com"
c.cache.Set(hash[0:2], make([]byte, 32))
assert.Empty(t, c.getCached())
}
func TestSBPC_checkErrorUpstream(t *testing.T) {
d, _ := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
t.Cleanup(d.Close)
ups := aghtest.NewErrorUpstream()
d.SetSafeBrowsingUpstream(ups)
d.SetParentalUpstream(ups)
setts := &Settings{
ProtectionEnabled: true,
SafeBrowsingEnabled: true,
ParentalEnabled: true,
}
_, err := d.checkSafeBrowsing("smthng.com", dns.TypeA, setts)
assert.Error(t, err)
_, err = d.checkParental("smthng.com", dns.TypeA, setts)
assert.Error(t, err)
}
func TestSBPC(t *testing.T) {
d, _ := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
t.Cleanup(d.Close)
const hostname = "example.org"
setts := &Settings{
ProtectionEnabled: true,
SafeBrowsingEnabled: true,
ParentalEnabled: true,
}
testCases := []struct {
testCache cache.Cache
testFunc func(host string, _ uint16, _ *Settings) (res Result, err error)
name string
block bool
}{{
testCache: d.safebrowsingCache,
testFunc: d.checkSafeBrowsing,
name: "sb_no_block",
block: false,
}, {
testCache: d.safebrowsingCache,
testFunc: d.checkSafeBrowsing,
name: "sb_block",
block: true,
}, {
testCache: d.parentalCache,
testFunc: d.checkParental,
name: "pc_no_block",
block: false,
}, {
testCache: d.parentalCache,
testFunc: d.checkParental,
name: "pc_block",
block: true,
}}
for _, tc := range testCases {
// Prepare the upstream.
ups := aghtest.NewBlockUpstream(hostname, tc.block)
var numReq int
onExchange := ups.OnExchange
ups.OnExchange = func(req *dns.Msg) (resp *dns.Msg, err error) {
numReq++
return onExchange(req)
}
d.SetSafeBrowsingUpstream(ups)
d.SetParentalUpstream(ups)
t.Run(tc.name, func(t *testing.T) {
// Firstly, check the request blocking.
hits := 0
res, err := tc.testFunc(hostname, dns.TypeA, setts)
require.NoError(t, err)
if tc.block {
assert.True(t, res.IsFiltered)
require.Len(t, res.Rules, 1)
hits++
} else {
require.False(t, res.IsFiltered)
}
// Check the cache state, check the response is now cached.
assert.Equal(t, 1, tc.testCache.Stats().Count)
assert.Equal(t, hits, tc.testCache.Stats().Hit)
// There was one request to an upstream.
assert.Equal(t, 1, numReq)
// Now make the same request to check the cache was used.
res, err = tc.testFunc(hostname, dns.TypeA, setts)
require.NoError(t, err)
if tc.block {
assert.True(t, res.IsFiltered)
require.Len(t, res.Rules, 1)
} else {
require.False(t, res.IsFiltered)
}
// Check the cache state, it should've been used.
assert.Equal(t, 1, tc.testCache.Stats().Count)
assert.Equal(t, hits+1, tc.testCache.Stats().Hit)
// Check that there were no additional requests.
assert.Equal(t, 1, numReq)
})
purgeCaches(d)
}
}

View File

@@ -399,19 +399,39 @@ func (c *configuration) getConfigFilename() string {
return configFile return configFile
} }
// getLogSettings reads logging settings from the config file. // readLogSettings reads logging settings from the config file. We do it in a
// we do it in a separate method in order to configure logger before the actual configuration is parsed and applied. // separate method in order to configure logger before the actual configuration
func getLogSettings() logSettings { // is parsed and applied.
l := logSettings{} func readLogSettings() (ls *logSettings) {
ls = &logSettings{}
yamlFile, err := readConfigFile() yamlFile, err := readConfigFile()
if err != nil { if err != nil {
return l return ls
} }
err = yaml.Unmarshal(yamlFile, &l)
err = yaml.Unmarshal(yamlFile, ls)
if err != nil { if err != nil {
log.Error("Couldn't get logging settings from the configuration: %s", err) log.Error("Couldn't get logging settings from the configuration: %s", err)
} }
return l
return ls
}
// validateBindHosts returns error if any of binding hosts from configuration is
// not a valid IP address.
func validateBindHosts(conf *configuration) (err error) {
if !conf.BindHost.IsValid() {
return errors.Error("bind_host is not a valid ip address")
}
for i, addr := range conf.DNS.BindHosts {
if !addr.IsValid() {
return fmt.Errorf("dns.bind_hosts at index %d is not a valid ip address", i)
}
}
return nil
} }
// parseConfig loads configuration from the YAML file // parseConfig loads configuration from the YAML file
@@ -425,6 +445,13 @@ func parseConfig() (err error) {
config.fileData = nil config.fileData = nil
err = yaml.Unmarshal(fileData, &config) err = yaml.Unmarshal(fileData, &config)
if err != nil { if err != nil {
// Don't wrap the error since it's informative enough as is.
return err
}
err = validateBindHosts(config)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return err return err
} }

View File

@@ -180,7 +180,7 @@ func registerControlHandlers() {
httpRegister(http.MethodGet, "/control/status", handleStatus) httpRegister(http.MethodGet, "/control/status", handleStatus)
httpRegister(http.MethodPost, "/control/i18n/change_language", handleI18nChangeLanguage) httpRegister(http.MethodPost, "/control/i18n/change_language", handleI18nChangeLanguage)
httpRegister(http.MethodGet, "/control/i18n/current_language", handleI18nCurrentLanguage) httpRegister(http.MethodGet, "/control/i18n/current_language", handleI18nCurrentLanguage)
Context.mux.HandleFunc("/control/version.json", postInstall(optionalAuth(handleGetVersionJSON))) Context.mux.HandleFunc("/control/version.json", postInstall(optionalAuth(handleVersionJSON)))
httpRegister(http.MethodPost, "/control/update", handleUpdate) httpRegister(http.MethodPost, "/control/update", handleUpdate)
httpRegister(http.MethodGet, "/control/profile", handleGetProfile) httpRegister(http.MethodGet, "/control/profile", handleGetProfile)
httpRegister(http.MethodPut, "/control/profile/update", handlePutProfile) httpRegister(http.MethodPut, "/control/profile/update", handlePutProfile)

View File

@@ -26,15 +26,14 @@ type temporaryError interface {
Temporary() (ok bool) Temporary() (ok bool)
} }
// Get the latest available version from the Internet // handleVersionJSON is the handler for the POST /control/version.json HTTP API.
func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) { //
// TODO(a.garipov): Find out if this API used with a GET method by anyone.
func handleVersionJSON(w http.ResponseWriter, r *http.Request) {
resp := &versionResponse{} resp := &versionResponse{}
if Context.disableUpdate { if Context.disableUpdate {
resp.Disabled = true resp.Disabled = true
err := json.NewEncoder(w).Encode(resp) _ = aghhttp.WriteJSONResponse(w, r, resp)
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "writing body: %s", err)
}
return return
} }

View File

@@ -27,14 +27,17 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd" "github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/hashprefix"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch" "github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
"github.com/AdguardTeam/AdGuardHome/internal/querylog" "github.com/AdguardTeam/AdGuardHome/internal/querylog"
"github.com/AdguardTeam/AdGuardHome/internal/stats" "github.com/AdguardTeam/AdGuardHome/internal/stats"
"github.com/AdguardTeam/AdGuardHome/internal/updater" "github.com/AdguardTeam/AdGuardHome/internal/updater"
"github.com/AdguardTeam/AdGuardHome/internal/version" "github.com/AdguardTeam/AdGuardHome/internal/version"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil"
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
"gopkg.in/natefinch/lumberjack.v2" "gopkg.in/natefinch/lumberjack.v2"
) )
@@ -143,7 +146,9 @@ func Main(clientBuildFS fs.FS) {
run(opts, clientBuildFS) run(opts, clientBuildFS)
} }
func setupContext(opts options) { // setupContext initializes [Context] fields. It also reads and upgrades
// config file if necessary.
func setupContext(opts options) (err error) {
setupContextFlags(opts) setupContextFlags(opts)
Context.tlsRoots = aghtls.SystemRootCAs() Context.tlsRoots = aghtls.SystemRootCAs()
@@ -160,10 +165,15 @@ func setupContext(opts options) {
}, },
} }
Context.mux = http.NewServeMux()
if !Context.firstRun { if !Context.firstRun {
// Do the upgrade if necessary. // Do the upgrade if necessary.
err := upgradeConfig() err = upgradeConfig()
fatalOnError(err) if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
}
if err = parseConfig(); err != nil { if err = parseConfig(); err != nil {
log.Error("parsing configuration file: %s", err) log.Error("parsing configuration file: %s", err)
@@ -179,11 +189,14 @@ func setupContext(opts options) {
if !opts.noEtcHosts && config.Clients.Sources.HostsFile { if !opts.noEtcHosts && config.Clients.Sources.HostsFile {
err = setupHostsContainer() err = setupHostsContainer()
fatalOnError(err) if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
}
} }
} }
Context.mux = http.NewServeMux() return nil
} }
// setupContextFlags sets global flags and prints their status to the log. // setupContextFlags sets global flags and prints their status to the log.
@@ -285,25 +298,27 @@ func setupHostsContainer() (err error) {
return nil return nil
} }
func setupConfig(opts options) (err error) { // setupOpts sets up command-line options.
config.DNS.DnsfilterConf.EtcHosts = Context.etcHosts func setupOpts(opts options) (err error) {
config.DNS.DnsfilterConf.ConfigModified = onConfigModified err = setupBindOpts(opts)
config.DNS.DnsfilterConf.HTTPRegister = httpRegister
config.DNS.DnsfilterConf.DataDir = Context.getDataDir()
config.DNS.DnsfilterConf.Filters = slices.Clone(config.Filters)
config.DNS.DnsfilterConf.WhitelistFilters = slices.Clone(config.WhitelistFilters)
config.DNS.DnsfilterConf.UserRules = slices.Clone(config.UserRules)
config.DNS.DnsfilterConf.HTTPClient = Context.client
config.DNS.DnsfilterConf.SafeSearchConf.CustomResolver = safeSearchResolver{}
config.DNS.DnsfilterConf.SafeSearch, err = safesearch.NewDefault(
config.DNS.DnsfilterConf.SafeSearchConf,
"default",
config.DNS.DnsfilterConf.SafeSearchCacheSize,
time.Minute*time.Duration(config.DNS.DnsfilterConf.CacheTime),
)
if err != nil { if err != nil {
return fmt.Errorf("initializing safesearch: %w", err) // Don't wrap the error, because it's informative enough as is.
return err
}
if len(opts.pidFile) != 0 && writePIDFile(opts.pidFile) {
Context.pidFileName = opts.pidFile
}
return nil
}
// initContextClients initializes Context clients and related fields.
func initContextClients() (err error) {
err = setupDNSFilteringConf(config.DNS.DnsfilterConf)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
} }
//lint:ignore SA1019 Migration is not over. //lint:ignore SA1019 Migration is not over.
@@ -338,8 +353,19 @@ func setupConfig(opts options) (err error) {
arpdb = aghnet.NewARPDB() arpdb = aghnet.NewARPDB()
} }
Context.clients.Init(config.Clients.Persistent, Context.dhcpServer, Context.etcHosts, arpdb, config.DNS.DnsfilterConf) Context.clients.Init(
config.Clients.Persistent,
Context.dhcpServer,
Context.etcHosts,
arpdb,
config.DNS.DnsfilterConf,
)
return nil
}
// setupBindOpts overrides bind host/port from the opts.
func setupBindOpts(opts options) (err error) {
if opts.bindPort != 0 { if opts.bindPort != 0 {
config.BindPort = opts.bindPort config.BindPort = opts.bindPort
@@ -350,12 +376,83 @@ func setupConfig(opts options) (err error) {
} }
} }
// override bind host/port from the console
if opts.bindHost.IsValid() { if opts.bindHost.IsValid() {
config.BindHost = opts.bindHost config.BindHost = opts.bindHost
} }
if len(opts.pidFile) != 0 && writePIDFile(opts.pidFile) {
Context.pidFileName = opts.pidFile return nil
}
// setupDNSFilteringConf sets up DNS filtering configuration settings.
func setupDNSFilteringConf(conf *filtering.Config) (err error) {
const (
dnsTimeout = 3 * time.Second
sbService = "safe browsing"
defaultSafeBrowsingServer = `https://family.adguard-dns.com/dns-query`
sbTXTSuffix = `sb.dns.adguard.com.`
pcService = "parental control"
defaultParentalServer = `https://family.adguard-dns.com/dns-query`
pcTXTSuffix = `pc.dns.adguard.com.`
)
conf.EtcHosts = Context.etcHosts
conf.ConfigModified = onConfigModified
conf.HTTPRegister = httpRegister
conf.DataDir = Context.getDataDir()
conf.Filters = slices.Clone(config.Filters)
conf.WhitelistFilters = slices.Clone(config.WhitelistFilters)
conf.UserRules = slices.Clone(config.UserRules)
conf.HTTPClient = Context.client
cacheTime := time.Duration(conf.CacheTime) * time.Minute
upsOpts := &upstream.Options{
Timeout: dnsTimeout,
ServerIPAddrs: []net.IP{
{94, 140, 14, 15},
{94, 140, 15, 16},
net.ParseIP("2a10:50c0::bad1:ff"),
net.ParseIP("2a10:50c0::bad2:ff"),
},
}
sbUps, err := upstream.AddressToUpstream(defaultSafeBrowsingServer, upsOpts)
if err != nil {
return fmt.Errorf("converting safe browsing server: %w", err)
}
conf.SafeBrowsingChecker = hashprefix.New(&hashprefix.Config{
Upstream: sbUps,
ServiceName: sbService,
TXTSuffix: sbTXTSuffix,
CacheTime: cacheTime,
CacheSize: conf.SafeBrowsingCacheSize,
})
parUps, err := upstream.AddressToUpstream(defaultParentalServer, upsOpts)
if err != nil {
return fmt.Errorf("converting parental server: %w", err)
}
conf.ParentalControlChecker = hashprefix.New(&hashprefix.Config{
Upstream: parUps,
ServiceName: pcService,
TXTSuffix: pcTXTSuffix,
CacheTime: cacheTime,
CacheSize: conf.SafeBrowsingCacheSize,
})
conf.SafeSearchConf.CustomResolver = safeSearchResolver{}
conf.SafeSearch, err = safesearch.NewDefault(
conf.SafeSearchConf,
"default",
conf.SafeSearchCacheSize,
cacheTime,
)
if err != nil {
return fmt.Errorf("initializing safesearch: %w", err)
} }
return nil return nil
@@ -432,14 +529,16 @@ func fatalOnError(err error) {
// run configures and starts AdGuard Home. // run configures and starts AdGuard Home.
func run(opts options, clientBuildFS fs.FS) { func run(opts options, clientBuildFS fs.FS) {
// configure config filename // Configure config filename.
initConfigFilename(opts) initConfigFilename(opts)
// configure working dir and config path // Configure working dir and config path.
initWorkingDir(opts) err := initWorkingDir(opts)
fatalOnError(err)
// configure log level and output // Configure log level and output.
configureLogger(opts) err = configureLogger(opts)
fatalOnError(err)
// Print the first message after logger is configured. // Print the first message after logger is configured.
log.Info(version.Full()) log.Info(version.Full())
@@ -448,25 +547,29 @@ func run(opts options, clientBuildFS fs.FS) {
log.Info("AdGuard Home is running as a service") log.Info("AdGuard Home is running as a service")
} }
setupContext(opts) err = setupContext(opts)
err := configureOS(config)
fatalOnError(err) fatalOnError(err)
// clients package uses filtering package's static data (filtering.BlockedSvcKnown()), err = configureOS(config)
// so we have to initialize filtering's static data first, fatalOnError(err)
// but also avoid relying on automatic Go init() function
// Clients package uses filtering package's static data
// (filtering.BlockedSvcKnown()), so we have to initialize filtering static
// data first, but also to avoid relying on automatic Go init() function.
filtering.InitModule() filtering.InitModule()
err = setupConfig(opts) err = initContextClients()
fatalOnError(err) fatalOnError(err)
// TODO(e.burkov): This could be made earlier, probably as the option's err = setupOpts(opts)
fatalOnError(err)
// TODO(e.burkov): This could be made earlier, probably as the option's
// effect. // effect.
cmdlineUpdate(opts) cmdlineUpdate(opts)
if !Context.firstRun { if !Context.firstRun {
// Save the updated config // Save the updated config.
err = config.write() err = config.write()
fatalOnError(err) fatalOnError(err)
@@ -476,33 +579,15 @@ func run(opts options, clientBuildFS fs.FS) {
} }
} }
err = os.MkdirAll(Context.getDataDir(), 0o755) dir := Context.getDataDir()
if err != nil { err = os.MkdirAll(dir, 0o755)
log.Fatalf("Cannot create DNS data dir at %s: %s", Context.getDataDir(), err) fatalOnError(errors.Annotate(err, "creating DNS data dir at %s: %w", dir))
}
sessFilename := filepath.Join(Context.getDataDir(), "sessions.db")
GLMode = opts.glinetMode GLMode = opts.glinetMode
var rateLimiter *authRateLimiter
if config.AuthAttempts > 0 && config.AuthBlockMin > 0 {
rateLimiter = newAuthRateLimiter(
time.Duration(config.AuthBlockMin)*time.Minute,
config.AuthAttempts,
)
} else {
log.Info("authratelimiter is disabled")
}
Context.auth = InitAuth( // Init auth module.
sessFilename, Context.auth, err = initUsers()
config.Users, fatalOnError(err)
config.WebSessionTTLHours*60*60,
rateLimiter,
)
if Context.auth == nil {
log.Fatalf("Couldn't initialize Auth module")
}
config.Users = nil
Context.tls, err = newTLSManager(config.TLS) Context.tls, err = newTLSManager(config.TLS)
if err != nil { if err != nil {
@@ -520,10 +605,10 @@ func run(opts options, clientBuildFS fs.FS) {
Context.tls.start() Context.tls.start()
go func() { go func() {
serr := startDNSServer() sErr := startDNSServer()
if serr != nil { if sErr != nil {
closeDNSServer() closeDNSServer()
fatalOnError(serr) fatalOnError(sErr)
} }
}() }()
@@ -537,10 +622,33 @@ func run(opts options, clientBuildFS fs.FS) {
Context.web.start() Context.web.start()
// wait indefinitely for other go-routines to complete their job // Wait indefinitely for other goroutines to complete their job.
select {} select {}
} }
// initUsers initializes context auth module. Clears config users field.
func initUsers() (auth *Auth, err error) {
sessFilename := filepath.Join(Context.getDataDir(), "sessions.db")
var rateLimiter *authRateLimiter
if config.AuthAttempts > 0 && config.AuthBlockMin > 0 {
blockDur := time.Duration(config.AuthBlockMin) * time.Minute
rateLimiter = newAuthRateLimiter(blockDur, config.AuthAttempts)
} else {
log.Info("authratelimiter is disabled")
}
sessionTTL := config.WebSessionTTLHours * 60 * 60
auth = InitAuth(sessFilename, config.Users, sessionTTL, rateLimiter)
if auth == nil {
return nil, errors.Error("initializing auth module failed")
}
config.Users = nil
return auth, nil
}
func (c *configuration) anonymizer() (ipmut *aghnet.IPMut) { func (c *configuration) anonymizer() (ipmut *aghnet.IPMut) {
var anonFunc aghnet.IPMutFunc var anonFunc aghnet.IPMutFunc
if c.DNS.AnonymizeClientIP { if c.DNS.AnonymizeClientIP {
@@ -613,22 +721,19 @@ func writePIDFile(fn string) bool {
return true return true
} }
// initConfigFilename sets up context config file path. This file path can be
// overridden by command-line arguments, or is set to default.
func initConfigFilename(opts options) { func initConfigFilename(opts options) {
// config file path can be overridden by command-line arguments: Context.configFilename = stringutil.Coalesce(opts.confFilename, "AdGuardHome.yaml")
if opts.confFilename != "" {
Context.configFilename = opts.confFilename
} else {
// Default config file name
Context.configFilename = "AdGuardHome.yaml"
}
} }
// initWorkingDir initializes the workDir // initWorkingDir initializes the workDir. If no command-line arguments are
// if no command-line arguments specified, we use the directory where our binary file is located // specified, the directory with the binary file is used.
func initWorkingDir(opts options) { func initWorkingDir(opts options) (err error) {
execPath, err := os.Executable() execPath, err := os.Executable()
if err != nil { if err != nil {
panic(err) // Don't wrap the error, because it's informative enough as is.
return err
} }
if opts.workDir != "" { if opts.workDir != "" {
@@ -640,34 +745,20 @@ func initWorkingDir(opts options) {
workDir, err := filepath.EvalSymlinks(Context.workDir) workDir, err := filepath.EvalSymlinks(Context.workDir)
if err != nil { if err != nil {
panic(err) // Don't wrap the error, because it's informative enough as is.
return err
} }
Context.workDir = workDir Context.workDir = workDir
return nil
} }
// configureLogger configures logger level and output // configureLogger configures logger level and output.
func configureLogger(opts options) { func configureLogger(opts options) (err error) {
ls := getLogSettings() ls := getLogSettings(opts)
// command-line arguments can override config settings // Configure logger level.
if opts.verbose || config.Verbose {
ls.Verbose = true
}
if opts.logFile != "" {
ls.File = opts.logFile
} else if config.File != "" {
ls.File = config.File
}
// Handle default log settings overrides
ls.Compress = config.Compress
ls.LocalTime = config.LocalTime
ls.MaxBackups = config.MaxBackups
ls.MaxSize = config.MaxSize
ls.MaxAge = config.MaxAge
// log.SetLevel(log.INFO) - default
if ls.Verbose { if ls.Verbose {
log.SetLevel(log.DEBUG) log.SetLevel(log.DEBUG)
} }
@@ -676,38 +767,63 @@ func configureLogger(opts options) {
// happen pretty quickly. // happen pretty quickly.
log.SetFlags(log.LstdFlags | log.Lmicroseconds) log.SetFlags(log.LstdFlags | log.Lmicroseconds)
if opts.runningAsService && ls.File == "" && runtime.GOOS == "windows" { // Write logs to stdout by default.
// When running as a Windows service, use eventlog by default if nothing
// else is configured. Otherwise, we'll simply lose the log output.
ls.File = configSyslog
}
// logs are written to stdout (default)
if ls.File == "" { if ls.File == "" {
return return nil
} }
if ls.File == configSyslog { if ls.File == configSyslog {
// Use syslog where it is possible and eventlog on Windows // Use syslog where it is possible and eventlog on Windows.
err := aghos.ConfigureSyslog(serviceName) err = aghos.ConfigureSyslog(serviceName)
if err != nil { if err != nil {
log.Fatalf("cannot initialize syslog: %s", err) return fmt.Errorf("cannot initialize syslog: %w", err)
}
} else {
logFilePath := ls.File
if !filepath.IsAbs(logFilePath) {
logFilePath = filepath.Join(Context.workDir, logFilePath)
} }
log.SetOutput(&lumberjack.Logger{ return nil
Filename: logFilePath,
Compress: ls.Compress, // disabled by default
LocalTime: ls.LocalTime,
MaxBackups: ls.MaxBackups,
MaxSize: ls.MaxSize, // megabytes
MaxAge: ls.MaxAge, // days
})
} }
logFilePath := ls.File
if !filepath.IsAbs(logFilePath) {
logFilePath = filepath.Join(Context.workDir, logFilePath)
}
log.SetOutput(&lumberjack.Logger{
Filename: logFilePath,
Compress: ls.Compress,
LocalTime: ls.LocalTime,
MaxBackups: ls.MaxBackups,
MaxSize: ls.MaxSize,
MaxAge: ls.MaxAge,
})
return nil
}
// getLogSettings returns a log settings object properly initialized from opts.
func getLogSettings(opts options) (ls *logSettings) {
ls = readLogSettings()
// Command-line arguments can override config settings.
if opts.verbose || config.Verbose {
ls.Verbose = true
}
ls.File = stringutil.Coalesce(opts.logFile, config.File, ls.File)
// Handle default log settings overrides.
ls.Compress = config.Compress
ls.LocalTime = config.LocalTime
ls.MaxBackups = config.MaxBackups
ls.MaxSize = config.MaxSize
ls.MaxAge = config.MaxAge
if opts.runningAsService && ls.File == "" && runtime.GOOS == "windows" {
// When running as a Windows service, use eventlog by default if
// nothing else is configured. Otherwise, we'll lose the log output.
ls.File = configSyslog
}
return ls
} }
// cleanup stops and resets all the modules. // cleanup stops and resets all the modules.

View File

@@ -4,7 +4,6 @@ import (
"fmt" "fmt"
"io/fs" "io/fs"
"os" "os"
"path/filepath"
"runtime" "runtime"
"strconv" "strconv"
"strings" "strings"
@@ -84,14 +83,9 @@ func svcStatus(s service.Service) (status service.Status, err error) {
// On OpenWrt, the service utility may not exist. We use our service script // On OpenWrt, the service utility may not exist. We use our service script
// directly in this case. // directly in this case.
func svcAction(s service.Service, action string) (err error) { func svcAction(s service.Service, action string) (err error) {
if runtime.GOOS == "darwin" && action == "start" { if action == "start" {
var exe string if err = aghos.PreCheckActionStart(); err != nil {
if exe, err = os.Executable(); err != nil { log.Error("starting service: %s", err)
log.Error("starting service: getting executable path: %s", err)
} else if exe, err = filepath.EvalSymlinks(exe); err != nil {
log.Error("starting service: evaluating executable symlinks: %s", err)
} else if !strings.HasPrefix(exe, "/Applications/") {
log.Info("warning: service must be started from within the /Applications directory")
} }
} }
@@ -99,8 +93,6 @@ func svcAction(s service.Service, action string) (err error) {
if err != nil && service.Platform() == "unix-systemv" && if err != nil && service.Platform() == "unix-systemv" &&
(action == "start" || action == "stop" || action == "restart") { (action == "start" || action == "stop" || action == "restart") {
_, err = runInitdCommand(action) _, err = runInitdCommand(action)
return err
} }
return err return err
@@ -224,6 +216,7 @@ func handleServiceControlAction(opts options, clientBuildFS fs.FS) {
runOpts := opts runOpts := opts
runOpts.serviceControlAction = "run" runOpts.serviceControlAction = "run"
svcConfig := &service.Config{ svcConfig := &service.Config{
Name: serviceName, Name: serviceName,
DisplayName: serviceDisplayName, DisplayName: serviceDisplayName,
@@ -233,35 +226,48 @@ func handleServiceControlAction(opts options, clientBuildFS fs.FS) {
} }
configureService(svcConfig) configureService(svcConfig)
prg := &program{ s, err := service.New(&program{clientBuildFS: clientBuildFS, opts: runOpts}, svcConfig)
clientBuildFS: clientBuildFS, if err != nil {
opts: runOpts,
}
var s service.Service
if s, err = service.New(prg, svcConfig); err != nil {
log.Fatalf("service: initializing service: %s", err) log.Fatalf("service: initializing service: %s", err)
} }
err = handleServiceCommand(s, action, opts)
if err != nil {
log.Fatalf("service: %s", err)
}
log.Printf(
"service: action %s has been done successfully on %s",
action,
service.ChosenSystem(),
)
}
// handleServiceCommand handles service command.
func handleServiceCommand(s service.Service, action string, opts options) (err error) {
switch action { switch action {
case "status": case "status":
handleServiceStatusCommand(s) handleServiceStatusCommand(s)
case "run": case "run":
if err = s.Run(); err != nil { if err = s.Run(); err != nil {
log.Fatalf("service: failed to run service: %s", err) return fmt.Errorf("failed to run service: %w", err)
} }
case "install": case "install":
initConfigFilename(opts) initConfigFilename(opts)
initWorkingDir(opts) if err = initWorkingDir(opts); err != nil {
return fmt.Errorf("failed to init working dir: %w", err)
}
handleServiceInstallCommand(s) handleServiceInstallCommand(s)
case "uninstall": case "uninstall":
handleServiceUninstallCommand(s) handleServiceUninstallCommand(s)
default: default:
if err = svcAction(s, action); err != nil { if err = svcAction(s, action); err != nil {
log.Fatalf("service: executing action %q: %s", action, err) return fmt.Errorf("executing action %q: %w", action, err)
} }
} }
log.Printf("service: action %s has been done successfully on %s", action, service.ChosenSystem()) return nil
} }
// handleServiceStatusCommand handles service "status" command. // handleServiceStatusCommand handles service "status" command.

View File

@@ -172,9 +172,32 @@ func loadTLSConf(tlsConf *tlsConfigSettings, status *tlsConfigStatus) (err error
} }
}() }()
tlsConf.CertificateChainData = []byte(tlsConf.CertificateChain) err = loadCertificateChainData(tlsConf, status)
tlsConf.PrivateKeyData = []byte(tlsConf.PrivateKey) if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
}
err = loadPrivateKeyData(tlsConf, status)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
}
err = validateCertificates(
status,
tlsConf.CertificateChainData,
tlsConf.PrivateKeyData,
tlsConf.ServerName,
)
return errors.Annotate(err, "validating certificate pair: %w")
}
// loadCertificateChainData loads PEM-encoded certificates chain data to the
// TLS configuration.
func loadCertificateChainData(tlsConf *tlsConfigSettings, status *tlsConfigStatus) (err error) {
tlsConf.CertificateChainData = []byte(tlsConf.CertificateChain)
if tlsConf.CertificatePath != "" { if tlsConf.CertificatePath != "" {
if tlsConf.CertificateChain != "" { if tlsConf.CertificateChain != "" {
return errors.Error("certificate data and file can't be set together") return errors.Error("certificate data and file can't be set together")
@@ -190,6 +213,13 @@ func loadTLSConf(tlsConf *tlsConfigSettings, status *tlsConfigStatus) (err error
status.ValidCert = true status.ValidCert = true
} }
return nil
}
// loadPrivateKeyData loads PEM-encoded private key data to the TLS
// configuration.
func loadPrivateKeyData(tlsConf *tlsConfigSettings, status *tlsConfigStatus) (err error) {
tlsConf.PrivateKeyData = []byte(tlsConf.PrivateKey)
if tlsConf.PrivateKeyPath != "" { if tlsConf.PrivateKeyPath != "" {
if tlsConf.PrivateKey != "" { if tlsConf.PrivateKey != "" {
return errors.Error("private key data and file can't be set together") return errors.Error("private key data and file can't be set together")
@@ -203,16 +233,6 @@ func loadTLSConf(tlsConf *tlsConfigSettings, status *tlsConfigStatus) (err error
status.ValidKey = true status.ValidKey = true
} }
err = validateCertificates(
status,
tlsConf.CertificateChainData,
tlsConf.PrivateKeyData,
tlsConf.ServerName,
)
if err != nil {
return fmt.Errorf("validating certificate pair: %w", err)
}
return nil return nil
} }

View File

@@ -41,7 +41,8 @@ func upgradeConfig() error {
err = yaml.Unmarshal(body, &diskConf) err = yaml.Unmarshal(body, &diskConf)
if err != nil { if err != nil {
log.Printf("Couldn't parse config file: %s", err) log.Printf("parsing config file for upgrade: %s", err)
return err return err
} }
@@ -293,71 +294,61 @@ func upgradeSchema4to5(diskConf yobj) error {
return nil return nil
} }
// clients: // upgradeSchema5to6 performs the following changes:
// ...
// //
// ip: 127.0.0.1 // # BEFORE:
// mac: ... // 'clients':
// ...
// 'ip': 127.0.0.1
// 'mac': ...
// //
// -> // # AFTER:
// // 'clients':
// clients: // ...
// ... // 'ids':
// // - 127.0.0.1
// ids: // - ...
// - 127.0.0.1
// - ...
func upgradeSchema5to6(diskConf yobj) error { func upgradeSchema5to6(diskConf yobj) error {
log.Printf("%s(): called", funcName()) log.Printf("Upgrade yaml: 5 to 6")
diskConf["schema_version"] = 6 diskConf["schema_version"] = 6
clients, ok := diskConf["clients"] clientsVal, ok := diskConf["clients"]
if !ok { if !ok {
return nil return nil
} }
switch arr := clients.(type) { clients, ok := clientsVal.([]yobj)
case []any: if !ok {
for i := range arr { return fmt.Errorf("unexpected type of clients: %T", clientsVal)
switch c := arr[i].(type) { }
case map[any]any:
var ipVal any
ipVal, ok = c["ip"]
ids := []string{}
if ok {
var ip string
ip, ok = ipVal.(string)
if !ok {
log.Fatalf("client.ip is not a string: %v", ipVal)
return nil
}
if len(ip) != 0 {
ids = append(ids, ip)
}
}
var macVal any for i := range clients {
macVal, ok = c["mac"] c := clients[i]
if ok { var ids []string
var mac string
mac, ok = macVal.(string)
if !ok {
log.Fatalf("client.mac is not a string: %v", macVal)
return nil
}
if len(mac) != 0 {
ids = append(ids, mac)
}
}
c["ids"] = ids if ipVal, hasIP := c["ip"]; hasIP {
default: var ip string
continue if ip, ok = ipVal.(string); !ok {
return fmt.Errorf("client.ip is not a string: %v", ipVal)
}
if ip != "" {
ids = append(ids, ip)
} }
} }
default:
return nil if macVal, hasMac := c["mac"]; hasMac {
var mac string
if mac, ok = macVal.(string); !ok {
return fmt.Errorf("client.mac is not a string: %v", macVal)
}
if mac != "" {
ids = append(ids, mac)
}
}
c["ids"] = ids
} }
return nil return nil

View File

@@ -68,6 +68,95 @@ func TestUpgradeSchema2to3(t *testing.T) {
assertEqualExcept(t, oldDiskConf, diskConf, excludedEntries, excludedEntries) assertEqualExcept(t, oldDiskConf, diskConf, excludedEntries, excludedEntries)
} }
func TestUpgradeSchema5to6(t *testing.T) {
const newSchemaVer = 6
testCases := []struct {
in yobj
want yobj
wantErr string
name string
}{{
in: yobj{
"clients": []yobj{},
},
want: yobj{
"clients": []yobj{},
"schema_version": newSchemaVer,
},
wantErr: "",
name: "no_clients",
}, {
in: yobj{
"clients": []yobj{{"ip": "127.0.0.1"}},
},
want: yobj{
"clients": []yobj{{
"ids": []string{"127.0.0.1"},
"ip": "127.0.0.1",
}},
"schema_version": newSchemaVer,
},
wantErr: "",
name: "client_ip",
}, {
in: yobj{
"clients": []yobj{{"mac": "mac"}},
},
want: yobj{
"clients": []yobj{{
"ids": []string{"mac"},
"mac": "mac",
}},
"schema_version": newSchemaVer,
},
wantErr: "",
name: "client_mac",
}, {
in: yobj{
"clients": []yobj{{"ip": "127.0.0.1", "mac": "mac"}},
},
want: yobj{
"clients": []yobj{{
"ids": []string{"127.0.0.1", "mac"},
"ip": "127.0.0.1",
"mac": "mac",
}},
"schema_version": newSchemaVer,
},
wantErr: "",
name: "client_ip_mac",
}, {
in: yobj{
"clients": []yobj{{"ip": 1, "mac": "mac"}},
},
want: yobj{
"clients": []yobj{{"ip": 1, "mac": "mac"}},
"schema_version": newSchemaVer,
},
wantErr: "client.ip is not a string: 1",
name: "inv_client_ip",
}, {
in: yobj{
"clients": []yobj{{"ip": "127.0.0.1", "mac": 1}},
},
want: yobj{
"clients": []yobj{{"ip": "127.0.0.1", "mac": 1}},
"schema_version": newSchemaVer,
},
wantErr: "client.mac is not a string: 1",
name: "inv_client_mac",
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := upgradeSchema5to6(tc.in)
testutil.AssertErrorMsg(t, tc.wantErr, err)
assert.Equal(t, tc.want, tc.in)
})
}
}
func TestUpgradeSchema7to8(t *testing.T) { func TestUpgradeSchema7to8(t *testing.T) {
const host = "1.2.3.4" const host = "1.2.3.4"
oldConf := yobj{ oldConf := yobj{

View File

@@ -4,6 +4,16 @@
## v0.108.0: API changes ## v0.108.0: API changes
## v0.107.30: API changes
### `POST /control/version.json` and `GET /control/dhcp/interfaces` content type
* The value of the `Content-Type` header in the `POST /control/version.json` and
`GET /control/dhcp/interfaces` HTTP APIs is now correctly set to
`application/json` as opposed to `text/plain`.
## v0.107.29: API changes ## v0.107.29: API changes
### `GET /control/clients` And `GET /control/clients/find` ### `GET /control/clients` And `GET /control/clients/find`
@@ -16,6 +26,8 @@
set AdGuard Home will use default value (false). It can be changed in the set AdGuard Home will use default value (false). It can be changed in the
future versions. future versions.
## v0.107.27: API changes ## v0.107.27: API changes
### The new optional fields `"edns_cs_use_custom"` and `"edns_cs_custom_ip"` in `DNSConfig` ### The new optional fields `"edns_cs_use_custom"` and `"edns_cs_custom_ip"` in `DNSConfig`

View File

@@ -161,11 +161,8 @@ run_linter "$GO" vet ./...
run_linter govulncheck ./... run_linter govulncheck ./...
# Apply more lax standards to the code we haven't properly refactored yet. # Apply more lax standards to the code we haven't properly refactored yet.
run_linter gocyclo --over 13\ run_linter gocyclo --over 13 ./internal/querylog
./internal/dhcpd\ run_linter gocyclo --over 12 ./internal/dhcpd
./internal/home/\
./internal/querylog/\
;
# Apply the normal standards to new or somewhat refactored code. # Apply the normal standards to new or somewhat refactored code.
run_linter gocyclo --over 10\ run_linter gocyclo --over 10\
@@ -175,6 +172,7 @@ run_linter gocyclo --over 10\
./internal/aghtest/\ ./internal/aghtest/\
./internal/dnsforward/\ ./internal/dnsforward/\
./internal/filtering/\ ./internal/filtering/\
./internal/home/\
./internal/stats/\ ./internal/stats/\
./internal/tools/\ ./internal/tools/\
./internal/updater/\ ./internal/updater/\