Compare commits
35 Commits
v0.108.0-b
...
2499-rewri
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c1be2bab4d | ||
|
|
53cd9b7a1a | ||
|
|
d8d7a5c335 | ||
|
|
18a6066df5 | ||
|
|
18392943fa | ||
|
|
c2abedec70 | ||
|
|
bbdcc673a2 | ||
|
|
d3bf5fcb05 | ||
|
|
5a794411d9 | ||
|
|
8e058b8042 | ||
|
|
d76834f843 | ||
|
|
e7fc61a997 | ||
|
|
97af23b0af | ||
|
|
5480bed1f7 | ||
|
|
c5fb7e6b0d | ||
|
|
9efc381224 | ||
|
|
e481922d91 | ||
|
|
defde7d0fe | ||
|
|
0c03063c8a | ||
|
|
0ddd8e3dcc | ||
|
|
48cbc7bdf0 | ||
|
|
299371e0fd | ||
|
|
12f52f07c5 | ||
|
|
de08ef0077 | ||
|
|
990311c9e0 | ||
|
|
526c358697 | ||
|
|
e657899c32 | ||
|
|
fb3602853a | ||
|
|
2cf171f21e | ||
|
|
e56f465ad8 | ||
|
|
a8e80bc583 | ||
|
|
9a186d0a8a | ||
|
|
2d29455d7f | ||
|
|
55a0dec144 | ||
|
|
6b607e982b |
11
CHANGELOG.md
11
CHANGELOG.md
@@ -42,11 +42,22 @@ See also the [v0.107.21 GitHub milestone][ms-v0.107.21].
|
||||
|
||||
### Fixed
|
||||
|
||||
- `AdGuardHome --update` freezing when another instance of AdGuard Home is
|
||||
running ([#4223], [#5191]).
|
||||
- The `--update` flag performing an update even with the same version.
|
||||
- Failing HTTPS redirection on saving the encryption settings ([#4898]).
|
||||
- Zeroing rules counter of erroneusly edited filtering rule lists ([#5290]).
|
||||
- Filters updating strategy, which could sometimes lead to use of broken or
|
||||
incompletely downloaded lists ([#5258]).
|
||||
- Errors popping up during updates of settings, which could sometimes cause the
|
||||
server to stop responding ([#5251]).
|
||||
|
||||
[#4898]: https://github.com/AdguardTeam/AdGuardHome/issues/4898
|
||||
[#5191]: https://github.com/AdguardTeam/AdGuardHome/issues/5191
|
||||
[#5238]: https://github.com/AdguardTeam/AdGuardHome/issues/5238
|
||||
[#5251]: https://github.com/AdguardTeam/AdGuardHome/issues/5251
|
||||
[#5258]: https://github.com/AdguardTeam/AdGuardHome/issues/5258
|
||||
[#5290]: https://github.com/AdguardTeam/AdGuardHome/issues/5290
|
||||
|
||||
[ms-v0.107.21]: https://github.com/AdguardTeam/AdGuardHome/milestone/57?closed=1
|
||||
|
||||
|
||||
@@ -239,18 +239,12 @@
|
||||
;;
|
||||
esac
|
||||
|
||||
# Ignore errors from the Snapstore upload script, because it seems to
|
||||
# have a lot of issues recently.
|
||||
#
|
||||
# TODO(a.garipov): Stop ignoring those errors once they fix the issues.
|
||||
#
|
||||
# See https://forum.snapcraft.io/t/unable-to-upload-promote-snaps-to-edge/33120.
|
||||
env\
|
||||
SNAPCRAFT_CHANNEL="$snapchannel"\
|
||||
SNAPCRAFT_EMAIL="${bamboo.snapcraftEmail}"\
|
||||
SNAPCRAFT_MACAROON="${bamboo.snapcraftMacaroonPassword}"\
|
||||
SNAPCRAFT_UBUNTU_DISCHARGE="${bamboo.snapcraftUbuntuDischargePassword}"\
|
||||
../bamboo-deploy-publisher/deploy.sh adguard-home-snap || :
|
||||
../bamboo-deploy-publisher/deploy.sh adguard-home-snap
|
||||
'final-tasks':
|
||||
- 'clean'
|
||||
'requirements':
|
||||
|
||||
@@ -41,6 +41,12 @@ export const setTlsConfig = (config) => async (dispatch, getState) => {
|
||||
response.certificate_chain = atob(response.certificate_chain);
|
||||
response.private_key = atob(response.private_key);
|
||||
|
||||
if (values.enabled && values.force_https && window.location.protocol === 'http:') {
|
||||
window.location.reload();
|
||||
return;
|
||||
}
|
||||
redirectToCurrentProtocol(response, httpPort);
|
||||
|
||||
const dnsStatus = await apiClient.getGlobalStatus();
|
||||
if (dnsStatus) {
|
||||
dispatch(dnsStatusSuccess(dnsStatus));
|
||||
@@ -48,7 +54,6 @@ export const setTlsConfig = (config) => async (dispatch, getState) => {
|
||||
|
||||
dispatch(setTlsConfigSuccess(response));
|
||||
dispatch(addSuccessToast('encryption_config_saved'));
|
||||
redirectToCurrentProtocol(response, httpPort);
|
||||
} catch (error) {
|
||||
dispatch(addErrorToast({ error }));
|
||||
dispatch(setTlsConfigFailure());
|
||||
|
||||
@@ -155,7 +155,7 @@ const Form = (props) => {
|
||||
name={FORM_NAMES.search}
|
||||
component={renderFilterField}
|
||||
type="text"
|
||||
className={classNames('form-control--search form-control--transparent', className)}
|
||||
className={classNames('form-control form-control--search form-control--transparent', className)}
|
||||
placeholder={t('domain_or_client')}
|
||||
tooltip={t('query_log_strict_search')}
|
||||
onClearInputClick={onInputClear}
|
||||
|
||||
@@ -103,14 +103,12 @@
|
||||
}
|
||||
|
||||
.form-control--search {
|
||||
box-shadow: 0 1px 0 #ddd;
|
||||
padding: 0 2.5rem;
|
||||
height: 2.25rem;
|
||||
flex-grow: 1;
|
||||
}
|
||||
|
||||
.form-control--transparent {
|
||||
border: 0 solid transparent !important;
|
||||
background-color: transparent !important;
|
||||
}
|
||||
|
||||
@@ -174,10 +172,8 @@
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
|
||||
--size: 2.5rem;
|
||||
width: var(--size);
|
||||
height: var(--size);
|
||||
width: 2.5rem;
|
||||
height: 2.5rem;
|
||||
padding: 0;
|
||||
margin-left: 0.9375rem;
|
||||
background-color: transparent;
|
||||
@@ -474,7 +470,7 @@
|
||||
|
||||
.filteringRules__filter {
|
||||
font-style: italic;
|
||||
font-weight: normal;
|
||||
font-weight: 400;
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
|
||||
|
||||
@@ -11,12 +11,13 @@ import Select from 'react-select';
|
||||
import i18n from '../../../i18n';
|
||||
import Tabs from '../../ui/Tabs';
|
||||
import Examples from '../Dns/Upstream/Examples';
|
||||
import { toggleAllServices } from '../../../helpers/helpers';
|
||||
import { toggleAllServices, trimLinesAndRemoveEmpty } from '../../../helpers/helpers';
|
||||
import {
|
||||
renderInputField,
|
||||
renderGroupField,
|
||||
CheckboxField,
|
||||
renderServiceField,
|
||||
renderTextareaField,
|
||||
} from '../../../helpers/form';
|
||||
import { validateClientId, validateRequiredValue } from '../../../helpers/validators';
|
||||
import { CLIENT_ID_LINK, FORM_NAME } from '../../../helpers/constants';
|
||||
@@ -230,10 +231,11 @@ let Form = (props) => {
|
||||
<Field
|
||||
id="upstreams"
|
||||
name="upstreams"
|
||||
component="textarea"
|
||||
component={renderTextareaField}
|
||||
type="text"
|
||||
className="form-control form-control--textarea mb-5"
|
||||
placeholder={t('upstream_dns')}
|
||||
normalizeOnBlur={trimLinesAndRemoveEmpty}
|
||||
/>
|
||||
<Examples />
|
||||
</div>,
|
||||
|
||||
@@ -390,6 +390,7 @@ export const SPECIAL_FILTER_ID = {
|
||||
PARENTAL: -3,
|
||||
SAFE_BROWSING: -4,
|
||||
SAFE_SEARCH: -5,
|
||||
REWRITES: -6,
|
||||
};
|
||||
|
||||
export const BLOCK_ACTIONS = {
|
||||
|
||||
@@ -530,14 +530,14 @@ func validateBlockingMode(mode BlockingMode, blockingIPv4, blockingIPv6 net.IP)
|
||||
// prepareInternalProxy initializes the DNS proxy that is used for internal DNS
|
||||
// queries, such as public clients PTR resolving and updater hostname resolving.
|
||||
func (s *Server) prepareInternalProxy() (err error) {
|
||||
srvConf := s.conf
|
||||
conf := &proxy.Config{
|
||||
CacheEnabled: true,
|
||||
CacheSizeBytes: 4096,
|
||||
UpstreamConfig: s.conf.UpstreamConfig,
|
||||
UpstreamConfig: srvConf.UpstreamConfig,
|
||||
MaxGoroutines: int(s.conf.MaxGoroutines),
|
||||
}
|
||||
|
||||
srvConf := s.conf
|
||||
setProxyUpstreamMode(
|
||||
conf,
|
||||
srvConf.AllServers,
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rewrite"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
@@ -67,7 +68,7 @@ func createTestServer(
|
||||
ID: 0, Data: []byte(rules),
|
||||
}}
|
||||
|
||||
f, err := filtering.New(filterConf, filters)
|
||||
f, err := filtering.New(filterConf, filters, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
f.SetEnabled(true)
|
||||
@@ -760,7 +761,7 @@ func TestBlockedCustomIP(t *testing.T) {
|
||||
Data: []byte(rules),
|
||||
}}
|
||||
|
||||
f, err := filtering.New(&filtering.Config{}, filters)
|
||||
f, err := filtering.New(&filtering.Config{}, filters, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
s, err := NewServer(DNSCreateParams{
|
||||
@@ -880,21 +881,22 @@ func TestBlockedBySafeBrowsing(t *testing.T) {
|
||||
|
||||
func TestRewrite(t *testing.T) {
|
||||
c := &filtering.Config{
|
||||
Rewrites: []*filtering.LegacyRewrite{{
|
||||
Rewrites: []*filtering.RewriteItem{{
|
||||
Domain: "test.com",
|
||||
Answer: "1.2.3.4",
|
||||
Type: dns.TypeA,
|
||||
}, {
|
||||
Domain: "alias.test.com",
|
||||
Answer: "test.com",
|
||||
Type: dns.TypeCNAME,
|
||||
}, {
|
||||
Domain: "my.alias.example.org",
|
||||
Answer: "example.org",
|
||||
Type: dns.TypeCNAME,
|
||||
}},
|
||||
}
|
||||
f, err := filtering.New(c, nil)
|
||||
|
||||
rewriteStorage, err := rewrite.NewDefaultStorage(c.Rewrites)
|
||||
require.NoError(t, err)
|
||||
|
||||
f, err := filtering.New(c, nil, rewriteStorage)
|
||||
require.NoError(t, err)
|
||||
|
||||
f.SetEnabled(true)
|
||||
@@ -945,6 +947,12 @@ func TestRewrite(t *testing.T) {
|
||||
|
||||
assert.Empty(t, reply.Answer)
|
||||
|
||||
req = createTestMessageWithType("test.com.", dns.TypeTXT)
|
||||
reply, eerr = dns.Exchange(req, addr.String())
|
||||
require.NoError(t, eerr)
|
||||
|
||||
assert.Empty(t, reply.Answer)
|
||||
|
||||
req = createTestMessageWithType("alias.test.com.", dns.TypeA)
|
||||
reply, eerr = dns.Exchange(req, addr.String())
|
||||
require.NoError(t, eerr)
|
||||
@@ -952,8 +960,15 @@ func TestRewrite(t *testing.T) {
|
||||
require.Len(t, reply.Answer, 2)
|
||||
|
||||
assert.Equal(t, "test.com.", reply.Answer[0].(*dns.CNAME).Target)
|
||||
assert.Equal(t, dns.TypeA, reply.Answer[1].Header().Rrtype)
|
||||
assert.True(t, net.IP{1, 2, 3, 4}.Equal(reply.Answer[1].(*dns.A).A))
|
||||
|
||||
req = createTestMessageWithType("alias.test.com.", dns.TypeTXT)
|
||||
reply, eerr = dns.Exchange(req, addr.String())
|
||||
require.NoError(t, eerr)
|
||||
|
||||
assert.Empty(t, reply.Answer)
|
||||
|
||||
req = createTestMessageWithType("my.alias.example.org.", dns.TypeA)
|
||||
reply, eerr = dns.Exchange(req, addr.String())
|
||||
require.NoError(t, eerr)
|
||||
@@ -967,6 +982,12 @@ func TestRewrite(t *testing.T) {
|
||||
|
||||
assert.Equal(t, "example.org.", reply.Answer[0].(*dns.CNAME).Target)
|
||||
assert.Equal(t, dns.TypeA, reply.Answer[1].Header().Rrtype)
|
||||
|
||||
req = createTestMessageWithType("my.alias.test.com.", dns.TypeTXT)
|
||||
reply, eerr = dns.Exchange(req, addr.String())
|
||||
require.NoError(t, eerr)
|
||||
|
||||
assert.Empty(t, reply.Answer)
|
||||
}
|
||||
|
||||
for _, protect := range []bool{true, false} {
|
||||
@@ -1011,7 +1032,7 @@ var testDHCP = &dhcpd.MockInterface{
|
||||
func TestPTRResponseFromDHCPLeases(t *testing.T) {
|
||||
const localDomain = "lan"
|
||||
|
||||
flt, err := filtering.New(&filtering.Config{}, nil)
|
||||
flt, err := filtering.New(&filtering.Config{}, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
s, err := NewServer(DNSCreateParams{
|
||||
@@ -1085,7 +1106,7 @@ func TestPTRResponseFromHosts(t *testing.T) {
|
||||
|
||||
flt, err := filtering.New(&filtering.Config{
|
||||
EtcHosts: hc,
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
flt.SetEnabled(true)
|
||||
|
||||
@@ -35,7 +35,7 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
|
||||
ID: 0, Data: []byte(rules),
|
||||
}}
|
||||
|
||||
f, err := filtering.New(&filtering.Config{}, filters)
|
||||
f, err := filtering.New(&filtering.Config{}, filters, nil)
|
||||
require.NoError(t, err)
|
||||
f.SetEnabled(true)
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package filtering
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"hash/crc32"
|
||||
"io"
|
||||
@@ -12,6 +13,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
@@ -97,14 +99,15 @@ func (d *DNSFilter) filterSetProperties(
|
||||
filt.URL,
|
||||
)
|
||||
|
||||
defer func(oldURL, oldName string, oldEnabled bool, oldUpdated time.Time) {
|
||||
defer func(oldURL, oldName string, oldEnabled bool, oldUpdated time.Time, oldRulesCount int) {
|
||||
if err != nil {
|
||||
filt.URL = oldURL
|
||||
filt.Name = oldName
|
||||
filt.Enabled = oldEnabled
|
||||
filt.LastUpdated = oldUpdated
|
||||
filt.RulesCount = oldRulesCount
|
||||
}
|
||||
}(filt.URL, filt.Name, filt.Enabled, filt.LastUpdated)
|
||||
}(filt.URL, filt.Name, filt.Enabled, filt.LastUpdated, filt.RulesCount)
|
||||
|
||||
filt.Name = newList.Name
|
||||
|
||||
@@ -134,8 +137,8 @@ func (d *DNSFilter) filterSetProperties(
|
||||
// TODO(e.burkov): The validation of the contents of the new URL is
|
||||
// currently skipped if the rule list is disabled. This makes it
|
||||
// possible to set a bad rules source, but the validation should still
|
||||
// kick in when the filter is enabled. Consider making changing this
|
||||
// behavior to be stricter.
|
||||
// kick in when the filter is enabled. Consider changing this behavior
|
||||
// to be stricter.
|
||||
filt.unload()
|
||||
}
|
||||
|
||||
@@ -269,10 +272,10 @@ func (d *DNSFilter) periodicallyRefreshFilters() {
|
||||
// already going on.
|
||||
//
|
||||
// TODO(e.burkov): Get rid of the concurrency pattern which requires the
|
||||
// sync.Mutex.TryLock.
|
||||
// [sync.Mutex.TryLock].
|
||||
func (d *DNSFilter) tryRefreshFilters(block, allow, force bool) (updated int, isNetworkErr, ok bool) {
|
||||
if ok = d.refreshLock.TryLock(); !ok {
|
||||
return 0, false, ok
|
||||
return 0, false, false
|
||||
}
|
||||
defer d.refreshLock.Unlock()
|
||||
|
||||
@@ -427,52 +430,124 @@ func (d *DNSFilter) refreshFiltersIntl(block, allow, force bool) (int, bool) {
|
||||
return updNum, false
|
||||
}
|
||||
|
||||
// Allows printable UTF-8 text with CR, LF, TAB characters
|
||||
func isPrintableText(data []byte, len int) bool {
|
||||
for i := 0; i < len; i++ {
|
||||
c := data[i]
|
||||
// isPrintableText returns true if data is printable UTF-8 text with CR, LF, TAB
|
||||
// characters.
|
||||
//
|
||||
// TODO(e.burkov): Investigate the purpose of this and improve the
|
||||
// implementation. Perhaps, use something from the unicode package.
|
||||
func isPrintableText(data string) (ok bool) {
|
||||
for _, c := range []byte(data) {
|
||||
if (c >= ' ' && c != 0x7f) || c == '\n' || c == '\r' || c == '\t' {
|
||||
continue
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// A helper function that parses filter contents and returns a number of rules and a filter name (if there's any)
|
||||
func (d *DNSFilter) parseFilterContents(file io.Reader) (int, uint32, string) {
|
||||
rulesCount := 0
|
||||
name := ""
|
||||
seenTitle := false
|
||||
r := bufio.NewReader(file)
|
||||
checksum := uint32(0)
|
||||
// scanLinesWithBreak is essentially a [bufio.ScanLines] which keeps trailing
|
||||
// line breaks.
|
||||
func scanLinesWithBreak(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
|
||||
for {
|
||||
line, err := r.ReadString('\n')
|
||||
checksum = crc32.Update(checksum, crc32.IEEETable, []byte(line))
|
||||
if i := bytes.IndexByte(data, '\n'); i >= 0 {
|
||||
return i + 1, data[0 : i+1], nil
|
||||
}
|
||||
|
||||
line = strings.TrimSpace(line)
|
||||
if len(line) == 0 {
|
||||
//
|
||||
} else if line[0] == '!' {
|
||||
m := d.filterTitleRegexp.FindAllStringSubmatch(line, -1)
|
||||
if len(m) > 0 && len(m[0]) >= 2 && !seenTitle {
|
||||
name = m[0][1]
|
||||
seenTitle = true
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
|
||||
} else if line[0] == '#' {
|
||||
//
|
||||
} else {
|
||||
rulesCount++
|
||||
// Request more data.
|
||||
return 0, nil, nil
|
||||
}
|
||||
|
||||
// parseFilter copies filter's content from src to dst and returns the number of
|
||||
// rules, name, number of bytes written, checksum, and title of the parsed list.
|
||||
// dst must not be nil.
|
||||
func (d *DNSFilter) parseFilter(
|
||||
src io.Reader,
|
||||
dst io.Writer,
|
||||
) (rulesNum, written int, checksum uint32, title string, err error) {
|
||||
scanner := bufio.NewScanner(src)
|
||||
scanner.Split(scanLinesWithBreak)
|
||||
|
||||
titleFound := false
|
||||
for n := 0; scanner.Scan(); written += n {
|
||||
line := scanner.Text()
|
||||
var isRule bool
|
||||
var likelyTitle string
|
||||
isRule, likelyTitle, err = d.parseFilterLine(line, !titleFound, written == 0)
|
||||
if err != nil {
|
||||
return 0, written, 0, "", err
|
||||
}
|
||||
|
||||
if isRule {
|
||||
rulesNum++
|
||||
} else if likelyTitle != "" {
|
||||
title, titleFound = likelyTitle, true
|
||||
}
|
||||
|
||||
checksum = crc32.Update(checksum, crc32.IEEETable, []byte(line))
|
||||
|
||||
n, err = dst.Write([]byte(line))
|
||||
if err != nil {
|
||||
break
|
||||
return 0, written, 0, "", fmt.Errorf("writing filter line: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return rulesCount, checksum, name
|
||||
if err = scanner.Err(); err != nil {
|
||||
return 0, written, 0, "", fmt.Errorf("scanning filter contents: %w", err)
|
||||
}
|
||||
|
||||
return rulesNum, written, checksum, title, nil
|
||||
}
|
||||
|
||||
// parseFilterLine returns true if the passed line is a rule. line is
|
||||
// considered a rule if it's not a comment and contains no title.
|
||||
func (d *DNSFilter) parseFilterLine(
|
||||
line string,
|
||||
lookForTitle bool,
|
||||
testHTML bool,
|
||||
) (isRule bool, title string, err error) {
|
||||
if !isPrintableText(line) {
|
||||
return false, "", errors.Error("filter contains non-printable characters")
|
||||
}
|
||||
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" || line[0] == '#' {
|
||||
return false, "", nil
|
||||
}
|
||||
|
||||
if testHTML && isHTML(line) {
|
||||
return false, "", errors.Error("data is HTML, not plain text")
|
||||
}
|
||||
|
||||
if line[0] == '!' && lookForTitle {
|
||||
match := d.filterTitleRegexp.FindStringSubmatch(line)
|
||||
if len(match) > 1 {
|
||||
title = match[1]
|
||||
}
|
||||
|
||||
return false, title, nil
|
||||
}
|
||||
|
||||
return true, "", nil
|
||||
}
|
||||
|
||||
// isHTML returns true if the line contains HTML tags instead of plain text.
|
||||
// line shouldn have no leading space symbols.
|
||||
//
|
||||
// TODO(ameshkov): It actually gives too much false-positives. Perhaps, just
|
||||
// check if trimmed string begins with angle bracket.
|
||||
func isHTML(line string) (ok bool) {
|
||||
line = strings.ToLower(line)
|
||||
|
||||
return strings.HasPrefix(line, "<html") || strings.HasPrefix(line, "<!doctype")
|
||||
}
|
||||
|
||||
// Perform upgrade on a filter and update LastUpdated value
|
||||
@@ -485,57 +560,10 @@ func (d *DNSFilter) update(filter *FilterYAML) (bool, error) {
|
||||
log.Error("os.Chtimes(): %v", e)
|
||||
}
|
||||
}
|
||||
|
||||
return b, err
|
||||
}
|
||||
|
||||
func (d *DNSFilter) read(reader io.Reader, tmpFile *os.File, filter *FilterYAML) (int, error) {
|
||||
htmlTest := true
|
||||
firstChunk := make([]byte, 4*1024)
|
||||
firstChunkLen := 0
|
||||
buf := make([]byte, 64*1024)
|
||||
total := 0
|
||||
for {
|
||||
n, err := reader.Read(buf)
|
||||
total += n
|
||||
|
||||
if htmlTest {
|
||||
num := len(firstChunk) - firstChunkLen
|
||||
if n < num {
|
||||
num = n
|
||||
}
|
||||
copied := copy(firstChunk[firstChunkLen:], buf[:num])
|
||||
firstChunkLen += copied
|
||||
|
||||
if firstChunkLen == len(firstChunk) || err == io.EOF {
|
||||
if !isPrintableText(firstChunk, firstChunkLen) {
|
||||
return total, fmt.Errorf("data contains non-printable characters")
|
||||
}
|
||||
|
||||
s := strings.ToLower(string(firstChunk))
|
||||
if strings.Contains(s, "<html") || strings.Contains(s, "<!doctype") {
|
||||
return total, fmt.Errorf("data is HTML, not plain text")
|
||||
}
|
||||
|
||||
htmlTest = false
|
||||
firstChunk = nil
|
||||
}
|
||||
}
|
||||
|
||||
_, err2 := tmpFile.Write(buf[:n])
|
||||
if err2 != nil {
|
||||
return total, err2
|
||||
}
|
||||
|
||||
if err == io.EOF {
|
||||
return total, nil
|
||||
}
|
||||
if err != nil {
|
||||
log.Printf("Couldn't fetch filter contents from URL %s, skipping: %s", filter.URL, err)
|
||||
return total, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// finalizeUpdate closes and gets rid of temporary file f with filter's content
|
||||
// according to updated. It also saves new values of flt's name, rules number
|
||||
// and checksum if sucсeeded.
|
||||
@@ -552,7 +580,8 @@ func (d *DNSFilter) finalizeUpdate(
|
||||
// Close the file before renaming it because it's required on Windows.
|
||||
//
|
||||
// See https://github.com/adguardTeam/adGuardHome/issues/1553.
|
||||
if err = file.Close(); err != nil {
|
||||
err = file.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("closing temporary file: %w", err)
|
||||
}
|
||||
|
||||
@@ -564,38 +593,18 @@ func (d *DNSFilter) finalizeUpdate(
|
||||
|
||||
log.Printf("saving filter %d contents to: %s", flt.ID, flt.Path(d.DataDir))
|
||||
|
||||
if err = os.Rename(tmpFileName, flt.Path(d.DataDir)); err != nil {
|
||||
// Don't use renamio or maybe packages, since those will require loading the
|
||||
// whole filter content to the memory on Windows.
|
||||
err = os.Rename(tmpFileName, flt.Path(d.DataDir))
|
||||
if err != nil {
|
||||
return errors.WithDeferred(err, os.Remove(tmpFileName))
|
||||
}
|
||||
|
||||
flt.Name = stringutil.Coalesce(flt.Name, name)
|
||||
flt.checksum = cs
|
||||
flt.RulesCount = rnum
|
||||
flt.Name, flt.checksum, flt.RulesCount = aghalg.Coalesce(flt.Name, name), cs, rnum
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// processUpdate copies filter's content from src to dst and returns the name,
|
||||
// rules number, and checksum for it. It also returns the number of bytes read
|
||||
// from src.
|
||||
func (d *DNSFilter) processUpdate(
|
||||
src io.Reader,
|
||||
dst *os.File,
|
||||
flt *FilterYAML,
|
||||
) (name string, rnum int, cs uint32, n int, err error) {
|
||||
if n, err = d.read(src, dst, flt); err != nil {
|
||||
return "", 0, 0, 0, err
|
||||
}
|
||||
|
||||
if _, err = dst.Seek(0, io.SeekStart); err != nil {
|
||||
return "", 0, 0, 0, err
|
||||
}
|
||||
|
||||
rnum, cs, name = d.parseFilterContents(dst)
|
||||
|
||||
return name, rnum, cs, n, nil
|
||||
}
|
||||
|
||||
// updateIntl updates the flt rewriting it's actual file. It returns true if
|
||||
// the actual update has been performed.
|
||||
func (d *DNSFilter) updateIntl(flt *FilterYAML) (ok bool, err error) {
|
||||
@@ -612,31 +621,21 @@ func (d *DNSFilter) updateIntl(flt *FilterYAML) (ok bool, err error) {
|
||||
}
|
||||
defer func() {
|
||||
err = errors.WithDeferred(err, d.finalizeUpdate(tmpFile, flt, ok, name, rnum, cs))
|
||||
ok = ok && err == nil
|
||||
if ok {
|
||||
if ok && err == nil {
|
||||
log.Printf("updated filter %d: %d bytes, %d rules", flt.ID, n, rnum)
|
||||
}
|
||||
}()
|
||||
|
||||
// Change the default 0o600 permission to something more acceptable by
|
||||
// end users.
|
||||
// Change the default 0o600 permission to something more acceptable by end
|
||||
// users.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/3198.
|
||||
if err = tmpFile.Chmod(0o644); err != nil {
|
||||
return false, fmt.Errorf("changing file mode: %w", err)
|
||||
}
|
||||
|
||||
var r io.Reader
|
||||
if filepath.IsAbs(flt.URL) {
|
||||
var file io.ReadCloser
|
||||
file, err = os.Open(flt.URL)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("open file: %w", err)
|
||||
}
|
||||
defer func() { err = errors.WithDeferred(err, file.Close()) }()
|
||||
|
||||
r = file
|
||||
} else {
|
||||
var rc io.ReadCloser
|
||||
if !filepath.IsAbs(flt.URL) {
|
||||
var resp *http.Response
|
||||
resp, err = d.HTTPClient.Get(flt.URL)
|
||||
if err != nil {
|
||||
@@ -649,24 +648,30 @@ func (d *DNSFilter) updateIntl(flt *FilterYAML) (ok bool, err error) {
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Printf("got status code %d from %s, skip", resp.StatusCode, flt.URL)
|
||||
|
||||
return false, fmt.Errorf("got status code != 200: %d", resp.StatusCode)
|
||||
return false, fmt.Errorf("got status code %d, want %d", resp.StatusCode, http.StatusOK)
|
||||
}
|
||||
|
||||
r = resp.Body
|
||||
rc = resp.Body
|
||||
} else {
|
||||
rc, err = os.Open(flt.URL)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("open file: %w", err)
|
||||
}
|
||||
defer func() { err = errors.WithDeferred(err, rc.Close()) }()
|
||||
}
|
||||
|
||||
name, rnum, cs, n, err = d.processUpdate(r, tmpFile, flt)
|
||||
rnum, n, cs, name, err = d.parseFilter(rc, tmpFile)
|
||||
|
||||
return cs != flt.checksum, err
|
||||
return cs != flt.checksum && err == nil, err
|
||||
}
|
||||
|
||||
// loads filter contents from the file in dataDir
|
||||
func (d *DNSFilter) load(filter *FilterYAML) (err error) {
|
||||
filterFilePath := filter.Path(d.DataDir)
|
||||
func (d *DNSFilter) load(flt *FilterYAML) (err error) {
|
||||
fileName := flt.Path(d.DataDir)
|
||||
|
||||
log.Tracef("filtering: loading filter %d from %s", filter.ID, filterFilePath)
|
||||
log.Debug("filtering: loading filter %d from %s", flt.ID, fileName)
|
||||
|
||||
file, err := os.Open(filterFilePath)
|
||||
file, err := os.Open(fileName)
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
// Do nothing, file doesn't exist.
|
||||
return nil
|
||||
@@ -680,13 +685,14 @@ func (d *DNSFilter) load(filter *FilterYAML) (err error) {
|
||||
return fmt.Errorf("getting filter file stat: %w", err)
|
||||
}
|
||||
|
||||
log.Tracef("filtering: File %s, id %d, length %d", filterFilePath, filter.ID, st.Size())
|
||||
log.Debug("filtering: file %s, id %d, length %d", fileName, flt.ID, st.Size())
|
||||
|
||||
rulesCount, checksum, _ := d.parseFilterContents(file)
|
||||
rulesCount, _, checksum, _, err := d.parseFilter(file, io.Discard)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing filter file: %w", err)
|
||||
}
|
||||
|
||||
filter.RulesCount = rulesCount
|
||||
filter.checksum = checksum
|
||||
filter.LastUpdated = st.ModTime()
|
||||
flt.RulesCount, flt.checksum, flt.LastUpdated = rulesCount, checksum, st.ModTime()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -4,33 +4,23 @@ import (
|
||||
"io/fs"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// serveFiltersLocally is a helper that concurrently listens on a free port to
|
||||
// respond with fltContent. It also gracefully closes the listener when the
|
||||
// test under t finishes.
|
||||
func serveFiltersLocally(t *testing.T, fltContent []byte) (ipp netip.AddrPort) {
|
||||
// serveHTTPLocally starts a new HTTP server, that handles its index with h. It
|
||||
// also gracefully closes the listener when the test under t finishes.
|
||||
func serveHTTPLocally(t *testing.T, h http.Handler) (urlStr string) {
|
||||
t.Helper()
|
||||
|
||||
h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
pt := testutil.PanicT{}
|
||||
|
||||
n, werr := w.Write(fltContent)
|
||||
require.NoError(pt, werr)
|
||||
require.Equal(pt, len(fltContent), n)
|
||||
})
|
||||
|
||||
l, err := net.Listen("tcp", ":0")
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -38,9 +28,26 @@ func serveFiltersLocally(t *testing.T, fltContent []byte) (ipp netip.AddrPort) {
|
||||
testutil.CleanupAndRequireSuccess(t, l.Close)
|
||||
|
||||
addr := l.Addr()
|
||||
require.IsType(t, new(net.TCPAddr), addr)
|
||||
require.IsType(t, (*net.TCPAddr)(nil), addr)
|
||||
|
||||
return netip.AddrPortFrom(netutil.IPv4Localhost(), uint16(addr.(*net.TCPAddr).Port))
|
||||
return (&url.URL{
|
||||
Scheme: aghhttp.SchemeHTTP,
|
||||
Host: addr.String(),
|
||||
}).String()
|
||||
}
|
||||
|
||||
// serveFiltersLocally is a helper that concurrently listens on a free port to
|
||||
// respond with fltContent.
|
||||
func serveFiltersLocally(t *testing.T, fltContent []byte) (urlStr string) {
|
||||
t.Helper()
|
||||
|
||||
return serveHTTPLocally(t, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
pt := testutil.PanicT{}
|
||||
|
||||
n, werr := w.Write(fltContent)
|
||||
require.NoError(pt, werr)
|
||||
require.Equal(pt, len(fltContent), n)
|
||||
}))
|
||||
}
|
||||
|
||||
func TestFilters(t *testing.T) {
|
||||
@@ -61,14 +68,11 @@ func TestFilters(t *testing.T) {
|
||||
HTTPClient: &http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
},
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
f := &FilterYAML{
|
||||
URL: (&url.URL{
|
||||
Scheme: "http",
|
||||
Host: addr.String(),
|
||||
}).String(),
|
||||
URL: addr,
|
||||
}
|
||||
|
||||
updateAndAssert := func(t *testing.T, want require.BoolAssertionFunc, wantRulesCount int) {
|
||||
@@ -103,11 +107,7 @@ func TestFilters(t *testing.T) {
|
||||
anotherContent := []byte(`||example.com^`)
|
||||
oldURL := f.URL
|
||||
|
||||
ipp := serveFiltersLocally(t, anotherContent)
|
||||
f.URL = (&url.URL{
|
||||
Scheme: "http",
|
||||
Host: ipp.String(),
|
||||
}).String()
|
||||
f.URL = serveFiltersLocally(t, anotherContent)
|
||||
t.Cleanup(func() { f.URL = oldURL })
|
||||
|
||||
updateAndAssert(t, require.True, 1)
|
||||
|
||||
@@ -33,7 +33,6 @@ import (
|
||||
// The IDs of built-in filter lists.
|
||||
//
|
||||
// Keep in sync with client/src/helpers/constants.js.
|
||||
// TODO(d.kolyshev): Add RewritesListID and don't forget to keep in sync.
|
||||
const (
|
||||
CustomListID = -iota
|
||||
SysHostsListID
|
||||
@@ -41,6 +40,7 @@ const (
|
||||
ParentalListID
|
||||
SafeBrowsingListID
|
||||
SafeSearchListID
|
||||
RewritesListID
|
||||
)
|
||||
|
||||
// ServiceEntry - blocked service array element
|
||||
@@ -90,7 +90,7 @@ type Config struct {
|
||||
ParentalCacheSize uint `yaml:"parental_cache_size"` // (in bytes)
|
||||
CacheTime uint `yaml:"cache_time"` // Element's TTL (in minutes)
|
||||
|
||||
Rewrites []*LegacyRewrite `yaml:"rewrites"`
|
||||
Rewrites []*RewriteItem `yaml:"rewrites"`
|
||||
|
||||
// Names of services to block (globally).
|
||||
// Per-client settings can override this configuration.
|
||||
@@ -190,8 +190,12 @@ type DNSFilter struct {
|
||||
|
||||
// filterTitleRegexp is the regular expression to retrieve a name of a
|
||||
// filter list.
|
||||
//
|
||||
// TODO(e.burkov): Don't use regexp for such a simple text processing task.
|
||||
filterTitleRegexp *regexp.Regexp
|
||||
|
||||
rewriteStorage RewriteStorage
|
||||
|
||||
hostCheckers []hostChecker
|
||||
}
|
||||
|
||||
@@ -313,7 +317,7 @@ func (d *DNSFilter) WriteDiskConfig(c *Config) {
|
||||
defer d.confLock.Unlock()
|
||||
|
||||
*c = d.Config
|
||||
c.Rewrites = cloneRewrites(c.Rewrites)
|
||||
c.Rewrites = slices.Clone(c.Rewrites)
|
||||
}()
|
||||
|
||||
d.filtersMu.RLock()
|
||||
@@ -324,16 +328,6 @@ func (d *DNSFilter) WriteDiskConfig(c *Config) {
|
||||
c.UserRules = slices.Clone(d.UserRules)
|
||||
}
|
||||
|
||||
// cloneRewrites returns a deep copy of entries.
|
||||
func cloneRewrites(entries []*LegacyRewrite) (clone []*LegacyRewrite) {
|
||||
clone = make([]*LegacyRewrite, len(entries))
|
||||
for i, rw := range entries {
|
||||
clone[i] = rw.clone()
|
||||
}
|
||||
|
||||
return clone
|
||||
}
|
||||
|
||||
// SetFilters sets new filters, synchronously or asynchronously. When filters
|
||||
// are set asynchronously, the old filters continue working until the new
|
||||
// filters are ready.
|
||||
@@ -544,75 +538,52 @@ func (d *DNSFilter) matchSysHosts(
|
||||
// CNAME, breaking loops in the process.
|
||||
//
|
||||
// Secondly, it finds A or AAAA rewrites for host and, if found, sets res.IPList
|
||||
// accordingly. If the found rewrite has a special value of "A" or "AAAA", the
|
||||
// result is an exception.
|
||||
// accordingly.
|
||||
func (d *DNSFilter) processRewrites(host string, qtype uint16) (res Result) {
|
||||
d.confLock.RLock()
|
||||
defer d.confLock.RUnlock()
|
||||
|
||||
rewrites, matched := findRewrites(d.Rewrites, host, qtype)
|
||||
if !matched {
|
||||
return Result{}
|
||||
if d.rewriteStorage == nil {
|
||||
return res
|
||||
}
|
||||
|
||||
res.Reason = Rewritten
|
||||
dnsr := d.rewriteStorage.MatchRequest(&urlfilter.DNSRequest{
|
||||
Hostname: host,
|
||||
DNSType: qtype,
|
||||
})
|
||||
|
||||
cnames := stringutil.NewSet()
|
||||
origHost := host
|
||||
for matched && len(rewrites) > 0 && rewrites[0].Type == dns.TypeCNAME {
|
||||
rw := rewrites[0]
|
||||
rwPat := rw.Domain
|
||||
rwAns := rw.Answer
|
||||
|
||||
log.Debug("rewrite: cname for %s is %s", host, rwAns)
|
||||
|
||||
if origHost == rwAns || rwPat == rwAns {
|
||||
// Either a request for the hostname itself or a rewrite of
|
||||
// a pattern onto itself, both of which are an exception rules.
|
||||
// Return a not filtered result.
|
||||
return Result{}
|
||||
} else if host == rwAns && isWildcard(rwPat) {
|
||||
// An "*.example.com → sub.example.com" rewrite matching in a loop.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/4016.
|
||||
|
||||
res.CanonName = host
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
host = rwAns
|
||||
if cnames.Has(host) {
|
||||
log.Info("rewrite: cname loop for %q on %q", origHost, host)
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
cnames.Add(host)
|
||||
res.CanonName = host
|
||||
rewrites, matched = findRewrites(d.Rewrites, host, qtype)
|
||||
}
|
||||
|
||||
setRewriteResult(&res, host, rewrites, qtype)
|
||||
setRewriteResult(&res, host, dnsr, qtype)
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
// setRewriteResult sets the Reason or IPList of res if necessary. res must not
|
||||
// be nil.
|
||||
func setRewriteResult(res *Result, host string, rewrites []*LegacyRewrite, qtype uint16) {
|
||||
for _, rw := range rewrites {
|
||||
if rw.Type == qtype && (qtype == dns.TypeA || qtype == dns.TypeAAAA) {
|
||||
if rw.IP == nil {
|
||||
// "A"/"AAAA" exception: allow getting from upstream.
|
||||
res.Reason = NotFilteredNotFound
|
||||
func setRewriteResult(res *Result, host string, dnsr []*rules.DNSRewrite, qtype uint16) {
|
||||
if len(dnsr) == 0 {
|
||||
res.Reason = NotFilteredNotFound
|
||||
|
||||
return
|
||||
return
|
||||
}
|
||||
|
||||
res.Reason = Rewritten
|
||||
|
||||
for _, dnsRewrite := range dnsr {
|
||||
if dnsRewrite.RRType == qtype && (qtype == dns.TypeA || qtype == dns.TypeAAAA) {
|
||||
ip, ok := dnsRewrite.Value.(net.IP)
|
||||
if !ok || ip == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
res.IPList = append(res.IPList, rw.IP)
|
||||
if qtype == dns.TypeA {
|
||||
ip = ip.To4()
|
||||
}
|
||||
|
||||
log.Debug("rewrite: a/aaaa for %s is %s", host, rw.IP)
|
||||
res.IPList = append(res.IPList, ip)
|
||||
|
||||
log.Debug("rewrite: a/aaaa for %s is %s", host, ip)
|
||||
} else if dnsRewrite.NewCNAME != "" {
|
||||
res.CanonName = dnsRewrite.NewCNAME
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -925,7 +896,7 @@ func InitModule() {
|
||||
|
||||
// New creates properly initialized DNS Filter that is ready to be used. c must
|
||||
// be non-nil.
|
||||
func New(c *Config, blockFilters []Filter) (d *DNSFilter, err error) {
|
||||
func New(c *Config, blockFilters []Filter, rewriteStorage RewriteStorage) (d *DNSFilter, err error) {
|
||||
d = &DNSFilter{
|
||||
resolver: net.DefaultResolver,
|
||||
refreshLock: &sync.Mutex{},
|
||||
@@ -978,11 +949,7 @@ func New(c *Config, blockFilters []Filter) (d *DNSFilter, err error) {
|
||||
|
||||
d.Config = *c
|
||||
d.filtersMu = &sync.RWMutex{}
|
||||
|
||||
err = d.prepareRewrites()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("rewrites: preparing: %s", err)
|
||||
}
|
||||
d.rewriteStorage = rewriteStorage
|
||||
|
||||
bsvcs := []string{}
|
||||
for _, s := range d.BlockedServices {
|
||||
|
||||
@@ -46,6 +46,7 @@ func newForTest(t testing.TB, c *Config, filters []Filter) (f *DNSFilter, setts
|
||||
ProtectionEnabled: true,
|
||||
FilteringEnabled: true,
|
||||
}
|
||||
|
||||
if c != nil {
|
||||
c.SafeBrowsingCacheSize = 10000
|
||||
c.ParentalCacheSize = 10000
|
||||
@@ -58,7 +59,8 @@ func newForTest(t testing.TB, c *Config, filters []Filter) (f *DNSFilter, setts
|
||||
// It must not be nil.
|
||||
c = &Config{}
|
||||
}
|
||||
f, err := New(c, filters)
|
||||
|
||||
f, err := New(c, filters, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
purgeCaches(f)
|
||||
@@ -417,274 +419,275 @@ func TestMatching(t *testing.T) {
|
||||
host string
|
||||
wantReason Reason
|
||||
wantIsFiltered bool
|
||||
wantDNSType uint16
|
||||
qtype uint16
|
||||
}{{
|
||||
name: "sanity",
|
||||
rules: "||doubleclick.net^",
|
||||
host: "www.doubleclick.net",
|
||||
wantIsFiltered: true,
|
||||
wantReason: FilteredBlockList,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "sanity",
|
||||
rules: "||doubleclick.net^",
|
||||
host: "nodoubleclick.net",
|
||||
wantIsFiltered: false,
|
||||
wantReason: NotFilteredNotFound,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "sanity",
|
||||
rules: "||doubleclick.net^",
|
||||
host: "doubleclick.net.ru",
|
||||
wantIsFiltered: false,
|
||||
wantReason: NotFilteredNotFound,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "sanity",
|
||||
rules: "||doubleclick.net^",
|
||||
host: sbBlocked,
|
||||
wantIsFiltered: false,
|
||||
wantReason: NotFilteredNotFound,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "blocking",
|
||||
rules: blockingRules,
|
||||
host: "example.org",
|
||||
wantIsFiltered: true,
|
||||
wantReason: FilteredBlockList,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "blocking",
|
||||
rules: blockingRules,
|
||||
host: "test.example.org",
|
||||
wantIsFiltered: true,
|
||||
wantReason: FilteredBlockList,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "blocking",
|
||||
rules: blockingRules,
|
||||
host: "test.test.example.org",
|
||||
wantIsFiltered: true,
|
||||
wantReason: FilteredBlockList,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "blocking",
|
||||
rules: blockingRules,
|
||||
host: "testexample.org",
|
||||
wantIsFiltered: false,
|
||||
wantReason: NotFilteredNotFound,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "blocking",
|
||||
rules: blockingRules,
|
||||
host: "onemoreexample.org",
|
||||
wantIsFiltered: false,
|
||||
wantReason: NotFilteredNotFound,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "allowlist",
|
||||
rules: allowlistRules,
|
||||
host: "example.org",
|
||||
wantIsFiltered: true,
|
||||
wantReason: FilteredBlockList,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "allowlist",
|
||||
rules: allowlistRules,
|
||||
host: "test.example.org",
|
||||
wantIsFiltered: false,
|
||||
wantReason: NotFilteredAllowList,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "allowlist",
|
||||
rules: allowlistRules,
|
||||
host: "test.test.example.org",
|
||||
wantIsFiltered: false,
|
||||
wantReason: NotFilteredAllowList,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "allowlist",
|
||||
rules: allowlistRules,
|
||||
host: "testexample.org",
|
||||
wantIsFiltered: false,
|
||||
wantReason: NotFilteredNotFound,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "allowlist",
|
||||
rules: allowlistRules,
|
||||
host: "onemoreexample.org",
|
||||
wantIsFiltered: false,
|
||||
wantReason: NotFilteredNotFound,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "important",
|
||||
rules: importantRules,
|
||||
host: "example.org",
|
||||
wantIsFiltered: false,
|
||||
wantReason: NotFilteredAllowList,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "important",
|
||||
rules: importantRules,
|
||||
host: "test.example.org",
|
||||
wantIsFiltered: true,
|
||||
wantReason: FilteredBlockList,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "important",
|
||||
rules: importantRules,
|
||||
host: "test.test.example.org",
|
||||
wantIsFiltered: true,
|
||||
wantReason: FilteredBlockList,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "important",
|
||||
rules: importantRules,
|
||||
host: "testexample.org",
|
||||
wantIsFiltered: false,
|
||||
wantReason: NotFilteredNotFound,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "important",
|
||||
rules: importantRules,
|
||||
host: "onemoreexample.org",
|
||||
wantIsFiltered: false,
|
||||
wantReason: NotFilteredNotFound,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "regex",
|
||||
rules: regexRules,
|
||||
host: "example.org",
|
||||
wantIsFiltered: true,
|
||||
wantReason: FilteredBlockList,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "regex",
|
||||
rules: regexRules,
|
||||
host: "test.example.org",
|
||||
wantIsFiltered: false,
|
||||
wantReason: NotFilteredAllowList,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "regex",
|
||||
rules: regexRules,
|
||||
host: "test.test.example.org",
|
||||
wantIsFiltered: false,
|
||||
wantReason: NotFilteredAllowList,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "regex",
|
||||
rules: regexRules,
|
||||
host: "testexample.org",
|
||||
wantIsFiltered: true,
|
||||
wantReason: FilteredBlockList,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "regex",
|
||||
rules: regexRules,
|
||||
host: "onemoreexample.org",
|
||||
wantIsFiltered: true,
|
||||
wantReason: FilteredBlockList,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "mask",
|
||||
rules: maskRules,
|
||||
host: "test.example.org",
|
||||
wantIsFiltered: true,
|
||||
wantReason: FilteredBlockList,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "mask",
|
||||
rules: maskRules,
|
||||
host: "test2.example.org",
|
||||
wantIsFiltered: true,
|
||||
wantReason: FilteredBlockList,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "mask",
|
||||
rules: maskRules,
|
||||
host: "example.com",
|
||||
wantIsFiltered: true,
|
||||
wantReason: FilteredBlockList,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "mask",
|
||||
rules: maskRules,
|
||||
host: "exampleeee.com",
|
||||
wantIsFiltered: true,
|
||||
wantReason: FilteredBlockList,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "mask",
|
||||
rules: maskRules,
|
||||
host: "onemoreexamsite.com",
|
||||
wantIsFiltered: true,
|
||||
wantReason: FilteredBlockList,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "mask",
|
||||
rules: maskRules,
|
||||
host: "example.org",
|
||||
wantIsFiltered: false,
|
||||
wantReason: NotFilteredNotFound,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "mask",
|
||||
rules: maskRules,
|
||||
host: "testexample.org",
|
||||
wantIsFiltered: false,
|
||||
wantReason: NotFilteredNotFound,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "mask",
|
||||
rules: maskRules,
|
||||
host: "example.co.uk",
|
||||
wantIsFiltered: false,
|
||||
wantReason: NotFilteredNotFound,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "dnstype",
|
||||
rules: dnstypeRules,
|
||||
host: "onemoreexample.org",
|
||||
wantIsFiltered: false,
|
||||
wantReason: NotFilteredNotFound,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "dnstype",
|
||||
rules: dnstypeRules,
|
||||
host: "example.org",
|
||||
wantIsFiltered: false,
|
||||
wantReason: NotFilteredNotFound,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "dnstype",
|
||||
rules: dnstypeRules,
|
||||
host: "example.org",
|
||||
wantIsFiltered: true,
|
||||
wantReason: FilteredBlockList,
|
||||
wantDNSType: dns.TypeAAAA,
|
||||
qtype: dns.TypeAAAA,
|
||||
}, {
|
||||
name: "dnstype",
|
||||
rules: dnstypeRules,
|
||||
host: "test.example.org",
|
||||
wantIsFiltered: false,
|
||||
wantReason: NotFilteredAllowList,
|
||||
wantDNSType: dns.TypeA,
|
||||
qtype: dns.TypeA,
|
||||
}, {
|
||||
name: "dnstype",
|
||||
rules: dnstypeRules,
|
||||
host: "test.example.org",
|
||||
wantIsFiltered: false,
|
||||
wantReason: NotFilteredAllowList,
|
||||
wantDNSType: dns.TypeAAAA,
|
||||
qtype: dns.TypeAAAA,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(fmt.Sprintf("%s-%s", tc.name, tc.host), func(t *testing.T) {
|
||||
filters := []Filter{{ID: 0, Data: []byte(tc.rules)}}
|
||||
d, setts := newForTest(t, nil, filters)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
res, err := d.CheckHost(tc.host, tc.wantDNSType, setts)
|
||||
res, err := d.CheckHost(tc.host, tc.qtype, setts)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equalf(t, tc.wantIsFiltered, res.IsFiltered, "Hostname %s has wrong result (%v must be %v)", tc.host, res.IsFiltered, tc.wantIsFiltered)
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -30,11 +29,7 @@ func TestDNSFilter_handleFilteringSetURL(t *testing.T) {
|
||||
endpoint: &badRulesEndpoint,
|
||||
content: []byte(`<html></html>`),
|
||||
}} {
|
||||
ipp := serveFiltersLocally(t, rulesSource.content)
|
||||
*rulesSource.endpoint = (&url.URL{
|
||||
Scheme: "http",
|
||||
Host: ipp.String(),
|
||||
}).String()
|
||||
*rulesSource.endpoint = serveFiltersLocally(t, rulesSource.content)
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
@@ -110,7 +105,7 @@ func TestDNSFilter_handleFilteringSetURL(t *testing.T) {
|
||||
},
|
||||
ConfigModified: func() { confModifiedCalled = true },
|
||||
DataDir: filtersDir,
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
|
||||
42
internal/filtering/rewrite.go
Normal file
42
internal/filtering/rewrite.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package filtering
|
||||
|
||||
import (
|
||||
"github.com/AdguardTeam/urlfilter"
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
)
|
||||
|
||||
// RewriteStorage is a storage for rewrite rules.
|
||||
type RewriteStorage interface {
|
||||
// MatchRequest returns matching dnsrewrites for the specified request.
|
||||
MatchRequest(dReq *urlfilter.DNSRequest) (rws []*rules.DNSRewrite)
|
||||
|
||||
// Add adds item to the storage.
|
||||
Add(item *RewriteItem) (err error)
|
||||
|
||||
// Remove deletes item from the storage.
|
||||
Remove(item *RewriteItem) (err error)
|
||||
|
||||
// List returns all items from the storage.
|
||||
List() (items []*RewriteItem)
|
||||
}
|
||||
|
||||
// RewriteItem is a single DNS rewrite record.
|
||||
type RewriteItem struct {
|
||||
// Domain is the domain pattern for which this rewrite should work.
|
||||
Domain string `yaml:"domain" json:"domain"`
|
||||
|
||||
// Answer is the IP address, canonical name, or one of the special
|
||||
// values: "A" or "AAAA".
|
||||
Answer string `yaml:"answer" json:"answer"`
|
||||
}
|
||||
|
||||
// Equal returns true if rw is Equal to other.
|
||||
func (rw *RewriteItem) Equal(other *RewriteItem) (ok bool) {
|
||||
if rw == nil {
|
||||
return other == nil
|
||||
} else if other == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return *rw == *other
|
||||
}
|
||||
@@ -1,73 +0,0 @@
|
||||
package rewrite
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strings"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Item is a single DNS rewrite record.
|
||||
type Item struct {
|
||||
// Domain is the domain pattern for which this rewrite should work.
|
||||
Domain string `yaml:"domain"`
|
||||
|
||||
// Answer is the IP address, canonical name, or one of the special
|
||||
// values: "A" or "AAAA".
|
||||
Answer string `yaml:"answer"`
|
||||
}
|
||||
|
||||
// equal returns true if rw is equal to other.
|
||||
func (rw *Item) equal(other *Item) (ok bool) {
|
||||
if rw == nil {
|
||||
return other == nil
|
||||
} else if other == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return *rw == *other
|
||||
}
|
||||
|
||||
// toRule converts rw to a filter rule.
|
||||
func (rw *Item) toRule() (res string) {
|
||||
if rw == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
domain := strings.ToLower(rw.Domain)
|
||||
|
||||
dType, exception := rw.rewriteParams()
|
||||
dTypeKey := dns.TypeToString[dType]
|
||||
if exception {
|
||||
return fmt.Sprintf("@@||%s^$dnstype=%s,dnsrewrite", domain, dTypeKey)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("|%s^$dnsrewrite=NOERROR;%s;%s", domain, dTypeKey, rw.Answer)
|
||||
}
|
||||
|
||||
// rewriteParams returns dns request type and exception flag for rw.
|
||||
func (rw *Item) rewriteParams() (dType uint16, exception bool) {
|
||||
switch rw.Answer {
|
||||
case "AAAA":
|
||||
return dns.TypeAAAA, true
|
||||
case "A":
|
||||
return dns.TypeA, true
|
||||
default:
|
||||
// Go on.
|
||||
}
|
||||
|
||||
addr, err := netip.ParseAddr(rw.Answer)
|
||||
if err != nil {
|
||||
// TODO(d.kolyshev): Validate rw.Answer as a domain name.
|
||||
return dns.TypeCNAME, false
|
||||
}
|
||||
|
||||
if addr.Is4() {
|
||||
dType = dns.TypeA
|
||||
} else {
|
||||
dType = dns.TypeAAAA
|
||||
}
|
||||
|
||||
return dType, false
|
||||
}
|
||||
@@ -1,124 +0,0 @@
|
||||
package rewrite
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestItem_equal(t *testing.T) {
|
||||
const (
|
||||
testDomain = "example.org"
|
||||
testAnswer = "1.1.1.1"
|
||||
)
|
||||
|
||||
testItem := &Item{
|
||||
Domain: testDomain,
|
||||
Answer: testAnswer,
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
left *Item
|
||||
right *Item
|
||||
want bool
|
||||
}{{
|
||||
name: "nil_left",
|
||||
left: nil,
|
||||
right: testItem,
|
||||
want: false,
|
||||
}, {
|
||||
name: "nil_right",
|
||||
left: testItem,
|
||||
right: nil,
|
||||
want: false,
|
||||
}, {
|
||||
name: "nils",
|
||||
left: nil,
|
||||
right: nil,
|
||||
want: true,
|
||||
}, {
|
||||
name: "equal",
|
||||
left: testItem,
|
||||
right: testItem,
|
||||
want: true,
|
||||
}, {
|
||||
name: "distinct",
|
||||
left: testItem,
|
||||
right: &Item{
|
||||
Domain: "other",
|
||||
Answer: "other",
|
||||
},
|
||||
want: false,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
res := tc.left.equal(tc.right)
|
||||
assert.Equal(t, tc.want, res)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestItem_toRule(t *testing.T) {
|
||||
const testDomain = "example.org"
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
item *Item
|
||||
want string
|
||||
}{{
|
||||
name: "nil",
|
||||
item: nil,
|
||||
want: "",
|
||||
}, {
|
||||
name: "a_rule",
|
||||
item: &Item{
|
||||
Domain: testDomain,
|
||||
Answer: "1.1.1.1",
|
||||
},
|
||||
want: "|example.org^$dnsrewrite=NOERROR;A;1.1.1.1",
|
||||
}, {
|
||||
name: "aaaa_rule",
|
||||
item: &Item{
|
||||
Domain: testDomain,
|
||||
Answer: "1:2:3::4",
|
||||
},
|
||||
want: "|example.org^$dnsrewrite=NOERROR;AAAA;1:2:3::4",
|
||||
}, {
|
||||
name: "cname_rule",
|
||||
item: &Item{
|
||||
Domain: testDomain,
|
||||
Answer: "other.org",
|
||||
},
|
||||
want: "|example.org^$dnsrewrite=NOERROR;CNAME;other.org",
|
||||
}, {
|
||||
name: "wildcard_rule",
|
||||
item: &Item{
|
||||
Domain: "*.example.org",
|
||||
Answer: "other.org",
|
||||
},
|
||||
want: "|*.example.org^$dnsrewrite=NOERROR;CNAME;other.org",
|
||||
}, {
|
||||
name: "aaaa_exception",
|
||||
item: &Item{
|
||||
Domain: testDomain,
|
||||
Answer: "A",
|
||||
},
|
||||
want: "@@||example.org^$dnstype=A,dnsrewrite",
|
||||
}, {
|
||||
name: "aaaa_exception",
|
||||
item: &Item{
|
||||
Domain: testDomain,
|
||||
Answer: "AAAA",
|
||||
},
|
||||
want: "@@||example.org^$dnstype=AAAA,dnsrewrite",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
res := tc.item.toRule()
|
||||
assert.Equal(t, tc.want, res)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -3,9 +3,11 @@ package rewrite
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/AdguardTeam/urlfilter"
|
||||
@@ -15,21 +17,6 @@ import (
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
// Storage is a storage for rewrite rules.
|
||||
type Storage interface {
|
||||
// MatchRequest returns matching dnsrewrites for the specified request.
|
||||
MatchRequest(dReq *urlfilter.DNSRequest) (rws []*rules.DNSRewrite)
|
||||
|
||||
// Add adds item to the storage.
|
||||
Add(item *Item) (err error)
|
||||
|
||||
// Remove deletes item from the storage.
|
||||
Remove(item *Item) (err error)
|
||||
|
||||
// List returns all items from the storage.
|
||||
List() (items []*Item)
|
||||
}
|
||||
|
||||
// DefaultStorage is the default storage for rewrite rules.
|
||||
type DefaultStorage struct {
|
||||
// mu protects items.
|
||||
@@ -42,7 +29,7 @@ type DefaultStorage struct {
|
||||
ruleList filterlist.RuleList
|
||||
|
||||
// rewrites stores the rewrite entries from configuration.
|
||||
rewrites []*Item
|
||||
rewrites []*filtering.RewriteItem
|
||||
|
||||
// urlFilterID is the synthetic integer identifier for the urlfilter engine.
|
||||
//
|
||||
@@ -53,16 +40,13 @@ type DefaultStorage struct {
|
||||
|
||||
// NewDefaultStorage returns new rewrites storage. listID is used as an
|
||||
// identifier of the underlying rules list. rewrites must not be nil.
|
||||
func NewDefaultStorage(listID int, rewrites []*Item) (s *DefaultStorage, err error) {
|
||||
func NewDefaultStorage(rewrites []*filtering.RewriteItem) (s *DefaultStorage, err error) {
|
||||
s = &DefaultStorage{
|
||||
mu: &sync.RWMutex{},
|
||||
urlFilterID: listID,
|
||||
urlFilterID: filtering.RewritesListID,
|
||||
rewrites: rewrites,
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
err = s.resetRules()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -72,9 +56,9 @@ func NewDefaultStorage(listID int, rewrites []*Item) (s *DefaultStorage, err err
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ Storage = (*DefaultStorage)(nil)
|
||||
var _ filtering.RewriteStorage = (*DefaultStorage)(nil)
|
||||
|
||||
// MatchRequest implements the [Storage] interface for *DefaultStorage.
|
||||
// MatchRequest implements the [RewriteStorage] interface for *DefaultStorage.
|
||||
func (s *DefaultStorage) MatchRequest(dReq *urlfilter.DNSRequest) (rws []*rules.DNSRewrite) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
@@ -84,28 +68,32 @@ func (s *DefaultStorage) MatchRequest(dReq *urlfilter.DNSRequest) (rws []*rules.
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO(a.garipov): Check cnames for cycles on initialisation.
|
||||
// TODO(a.garipov): Check cnames for cycles on initialization.
|
||||
cnames := stringutil.NewSet()
|
||||
host := dReq.Hostname
|
||||
var lastCNAMERule *rules.NetworkRule
|
||||
for len(rrules) > 0 && rrules[0].DNSRewrite != nil && rrules[0].DNSRewrite.NewCNAME != "" {
|
||||
rule := rrules[0]
|
||||
rwAns := rule.DNSRewrite.NewCNAME
|
||||
lastCNAMERule = rrules[0]
|
||||
lastDNSRewrite := lastCNAMERule.DNSRewrite
|
||||
rwAns := lastDNSRewrite.NewCNAME
|
||||
|
||||
log.Debug("rewrite: cname for %s is %s", host, rwAns)
|
||||
|
||||
if dReq.Hostname == rwAns {
|
||||
// A request for the hostname itself is an exception rule.
|
||||
// A request for the hostname itself.
|
||||
// TODO(d.kolyshev): Check rewrite of a pattern onto itself.
|
||||
log.Debug("rewrite: request for hostname itself for %q", dReq.Hostname)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if host == rwAns && isWildcard(rule.RuleText) {
|
||||
if host == rwAns && isWildcard(lastCNAMERule.RuleText) {
|
||||
// An "*.example.com → sub.example.com" rewrite matching in a loop.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/4016.
|
||||
log.Debug("rewrite: cname wildcard loop for %q on %q", dReq.Hostname, rwAns)
|
||||
|
||||
return []*rules.DNSRewrite{rule.DNSRewrite}
|
||||
return []*rules.DNSRewrite{lastDNSRewrite}
|
||||
}
|
||||
|
||||
if cnames.Has(rwAns) {
|
||||
@@ -120,21 +108,28 @@ func (s *DefaultStorage) MatchRequest(dReq *urlfilter.DNSRequest) (rws []*rules.
|
||||
Hostname: rwAns,
|
||||
DNSType: dReq.DNSType,
|
||||
})
|
||||
if drules != nil {
|
||||
rrules = drules
|
||||
|
||||
if drules == nil {
|
||||
break
|
||||
}
|
||||
|
||||
rrules = drules
|
||||
host = rwAns
|
||||
}
|
||||
|
||||
return s.collectDNSRewrites(rrules, dReq.DNSType)
|
||||
return s.collectDNSRewrites(rrules, lastCNAMERule, dReq.DNSType)
|
||||
}
|
||||
|
||||
// collectDNSRewrites filters DNSRewrite by question type.
|
||||
func (s *DefaultStorage) collectDNSRewrites(
|
||||
rewrites []*rules.NetworkRule,
|
||||
cnameRule *rules.NetworkRule,
|
||||
qtyp uint16,
|
||||
) (rws []*rules.DNSRewrite) {
|
||||
if cnameRule != nil {
|
||||
rewrites = append([]*rules.NetworkRule{cnameRule}, rewrites...)
|
||||
}
|
||||
|
||||
for _, rewrite := range rewrites {
|
||||
dnsRewrite := rewrite.DNSRewrite
|
||||
if matchesQType(dnsRewrite, qtyp) {
|
||||
@@ -152,8 +147,8 @@ func (s *DefaultStorage) rewriteRulesForReq(dReq *urlfilter.DNSRequest) (rules [
|
||||
return res.DNSRewrites()
|
||||
}
|
||||
|
||||
// Add implements the [Storage] interface for *DefaultStorage.
|
||||
func (s *DefaultStorage) Add(item *Item) (err error) {
|
||||
// Add implements the [RewriteStorage] interface for *DefaultStorage.
|
||||
func (s *DefaultStorage) Add(item *filtering.RewriteItem) (err error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
@@ -163,16 +158,16 @@ func (s *DefaultStorage) Add(item *Item) (err error) {
|
||||
return s.resetRules()
|
||||
}
|
||||
|
||||
// Remove implements the [Storage] interface for *DefaultStorage.
|
||||
func (s *DefaultStorage) Remove(item *Item) (err error) {
|
||||
// Remove implements the [RewriteStorage] interface for *DefaultStorage.
|
||||
func (s *DefaultStorage) Remove(item *filtering.RewriteItem) (err error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
arr := []*Item{}
|
||||
arr := []*filtering.RewriteItem{}
|
||||
|
||||
// TODO(d.kolyshev): Use slices.IndexFunc + slices.Delete?
|
||||
for _, ent := range s.rewrites {
|
||||
if ent.equal(item) {
|
||||
if ent.Equal(item) {
|
||||
log.Debug("rewrite: removed element: %s -> %s", ent.Domain, ent.Answer)
|
||||
|
||||
continue
|
||||
@@ -185,8 +180,8 @@ func (s *DefaultStorage) Remove(item *Item) (err error) {
|
||||
return s.resetRules()
|
||||
}
|
||||
|
||||
// List implements the [Storage] interface for *DefaultStorage.
|
||||
func (s *DefaultStorage) List() (items []*Item) {
|
||||
// List implements the [RewriteStorage] interface for *DefaultStorage.
|
||||
func (s *DefaultStorage) List() (items []*filtering.RewriteItem) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
@@ -198,7 +193,7 @@ func (s *DefaultStorage) resetRules() (err error) {
|
||||
// TODO(a.garipov): Use strings.Builder.
|
||||
var rulesText []string
|
||||
for _, rewrite := range s.rewrites {
|
||||
rulesText = append(rulesText, rewrite.toRule())
|
||||
rulesText = append(rulesText, toRule(rewrite))
|
||||
}
|
||||
|
||||
strList := &filterlist.StringRuleList{
|
||||
@@ -222,20 +217,60 @@ func (s *DefaultStorage) resetRules() (err error) {
|
||||
|
||||
// matchesQType returns true if dnsrewrite matches the question type qt.
|
||||
func matchesQType(dnsrr *rules.DNSRewrite, qt uint16) (ok bool) {
|
||||
// Add CNAMEs, since they match for all types requests.
|
||||
if dnsrr.RRType == dns.TypeCNAME {
|
||||
switch qt {
|
||||
case dns.TypeA:
|
||||
return dnsrr.RRType != dns.TypeAAAA
|
||||
case dns.TypeAAAA:
|
||||
return dnsrr.RRType != dns.TypeA
|
||||
default:
|
||||
return true
|
||||
}
|
||||
|
||||
// Reject types other than A and AAAA.
|
||||
if qt != dns.TypeA && qt != dns.TypeAAAA {
|
||||
return false
|
||||
}
|
||||
|
||||
return dnsrr.RRType == qt
|
||||
}
|
||||
|
||||
// isWildcard returns true if pat is a wildcard domain pattern.
|
||||
func isWildcard(pat string) (res bool) {
|
||||
return strings.HasPrefix(pat, "|*.")
|
||||
}
|
||||
|
||||
// toRule converts rw to a filter rule.
|
||||
func toRule(rw *filtering.RewriteItem) (res string) {
|
||||
if rw == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
domain := strings.ToLower(rw.Domain)
|
||||
|
||||
dType, exception := rewriteParams(rw)
|
||||
dTypeKey := dns.TypeToString[dType]
|
||||
if exception {
|
||||
return fmt.Sprintf("@@||%s^$dnstype=%s,dnsrewrite", domain, dTypeKey)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("|%s^$dnsrewrite=NOERROR;%s;%s", domain, dTypeKey, rw.Answer)
|
||||
}
|
||||
|
||||
// RewriteParams returns dns request type and exception flag for rw.
|
||||
func rewriteParams(rw *filtering.RewriteItem) (dType uint16, exception bool) {
|
||||
switch rw.Answer {
|
||||
case "AAAA":
|
||||
return dns.TypeAAAA, true
|
||||
case "A":
|
||||
return dns.TypeA, true
|
||||
default:
|
||||
// Go on.
|
||||
}
|
||||
|
||||
addr, err := netip.ParseAddr(rw.Answer)
|
||||
if err != nil {
|
||||
// TODO(d.kolyshev): Validate rw.Answer as a domain name.
|
||||
return dns.TypeCNAME, false
|
||||
}
|
||||
|
||||
if addr.Is4() {
|
||||
dType = dns.TypeA
|
||||
} else {
|
||||
dType = dns.TypeAAAA
|
||||
}
|
||||
|
||||
return dType, false
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/urlfilter"
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
"github.com/miekg/dns"
|
||||
@@ -12,32 +13,32 @@ import (
|
||||
)
|
||||
|
||||
func TestNewDefaultStorage(t *testing.T) {
|
||||
items := []*Item{{
|
||||
items := []*filtering.RewriteItem{{
|
||||
Domain: "example.com",
|
||||
Answer: "answer.com",
|
||||
}}
|
||||
|
||||
s, err := NewDefaultStorage(-1, items)
|
||||
s, err := NewDefaultStorage(items)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, s.List(), 1)
|
||||
}
|
||||
|
||||
func TestDefaultStorage_CRUD(t *testing.T) {
|
||||
var items []*Item
|
||||
var items []*filtering.RewriteItem
|
||||
|
||||
s, err := NewDefaultStorage(-1, items)
|
||||
s, err := NewDefaultStorage(items)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, s.List(), 0)
|
||||
|
||||
item := &Item{Domain: "example.com", Answer: "answer.com"}
|
||||
item := &filtering.RewriteItem{Domain: "example.com", Answer: "answer.com"}
|
||||
|
||||
err = s.Add(item)
|
||||
require.NoError(t, err)
|
||||
|
||||
list := s.List()
|
||||
require.Len(t, list, 1)
|
||||
require.True(t, item.equal(list[0]))
|
||||
require.True(t, item.Equal(list[0]))
|
||||
|
||||
err = s.Remove(item)
|
||||
require.NoError(t, err)
|
||||
@@ -45,7 +46,7 @@ func TestDefaultStorage_CRUD(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDefaultStorage_MatchRequest(t *testing.T) {
|
||||
items := []*Item{{
|
||||
items := []*filtering.RewriteItem{{
|
||||
// This one and below are about CNAME, A and AAAA.
|
||||
Domain: "somecname",
|
||||
Answer: "somehost.com",
|
||||
@@ -101,7 +102,7 @@ func TestDefaultStorage_MatchRequest(t *testing.T) {
|
||||
Answer: "sub.issue4016.com",
|
||||
}}
|
||||
|
||||
s, err := NewDefaultStorage(-1, items)
|
||||
s, err := NewDefaultStorage(items)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
@@ -115,14 +116,39 @@ func TestDefaultStorage_MatchRequest(t *testing.T) {
|
||||
wantDNSRewrites: nil,
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "not_filtered_qtype",
|
||||
host: "www.host.com",
|
||||
wantDNSRewrites: nil,
|
||||
dtyp: dns.TypeMX,
|
||||
name: "other_qtype",
|
||||
host: "www.host.com",
|
||||
wantDNSRewrites: []*rules.DNSRewrite{{
|
||||
Value: nil,
|
||||
NewCNAME: "host.com",
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeNone,
|
||||
}, {
|
||||
Value: net.IP{1, 2, 3, 4}.To16(),
|
||||
NewCNAME: "",
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeA,
|
||||
}, {
|
||||
Value: net.IP{1, 2, 3, 5}.To16(),
|
||||
NewCNAME: "",
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeA,
|
||||
}, {
|
||||
Value: net.ParseIP("1:2:3::4"),
|
||||
NewCNAME: "",
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeAAAA,
|
||||
}},
|
||||
dtyp: dns.TypeMX,
|
||||
}, {
|
||||
name: "rewritten_a",
|
||||
host: "www.host.com",
|
||||
wantDNSRewrites: []*rules.DNSRewrite{{
|
||||
Value: nil,
|
||||
NewCNAME: "host.com",
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeNone,
|
||||
}, {
|
||||
Value: net.IP{1, 2, 3, 4}.To16(),
|
||||
NewCNAME: "",
|
||||
RCode: dns.RcodeSuccess,
|
||||
@@ -138,6 +164,11 @@ func TestDefaultStorage_MatchRequest(t *testing.T) {
|
||||
name: "rewritten_aaaa",
|
||||
host: "www.host.com",
|
||||
wantDNSRewrites: []*rules.DNSRewrite{{
|
||||
Value: nil,
|
||||
NewCNAME: "host.com",
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeNone,
|
||||
}, {
|
||||
Value: net.ParseIP("1:2:3::4"),
|
||||
NewCNAME: "",
|
||||
RCode: dns.RcodeSuccess,
|
||||
@@ -154,21 +185,30 @@ func TestDefaultStorage_MatchRequest(t *testing.T) {
|
||||
RRType: dns.TypeA,
|
||||
}},
|
||||
dtyp: dns.TypeA,
|
||||
//}, {
|
||||
// TODO(d.kolyshev): This is about matching in urlfilter.
|
||||
// name: "wildcard_override",
|
||||
// host: "a.host.com",
|
||||
// wantDNSRewrites: []*rules.DNSRewrite{{
|
||||
// Value: net.IP{1, 2, 3, 4}.To16(),
|
||||
// NewCNAME: "",
|
||||
// RCode: dns.RcodeSuccess,
|
||||
// RRType: dns.TypeA,
|
||||
// }},
|
||||
// dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "wildcard_override",
|
||||
host: "a.host.com",
|
||||
wantDNSRewrites: []*rules.DNSRewrite{{
|
||||
Value: net.IP{1, 2, 3, 4}.To16(),
|
||||
NewCNAME: "",
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeA,
|
||||
}, {
|
||||
Value: net.IP{1, 2, 3, 5}.To16(),
|
||||
NewCNAME: "",
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeA,
|
||||
}},
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "wildcard_cname_interaction",
|
||||
host: "www.host2.com",
|
||||
wantDNSRewrites: []*rules.DNSRewrite{{
|
||||
Value: nil,
|
||||
NewCNAME: "host.com",
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeNone,
|
||||
}, {
|
||||
Value: net.IP{1, 2, 3, 4}.To16(),
|
||||
NewCNAME: "",
|
||||
RCode: dns.RcodeSuccess,
|
||||
@@ -184,6 +224,11 @@ func TestDefaultStorage_MatchRequest(t *testing.T) {
|
||||
name: "two_cnames",
|
||||
host: "b.host.com",
|
||||
wantDNSRewrites: []*rules.DNSRewrite{{
|
||||
Value: nil,
|
||||
NewCNAME: "somehost.com",
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeNone,
|
||||
}, {
|
||||
Value: net.IP{0, 0, 0, 0}.To16(),
|
||||
NewCNAME: "",
|
||||
RCode: dns.RcodeSuccess,
|
||||
@@ -194,6 +239,11 @@ func TestDefaultStorage_MatchRequest(t *testing.T) {
|
||||
name: "two_cnames_and_wildcard",
|
||||
host: "b.host3.com",
|
||||
wantDNSRewrites: []*rules.DNSRewrite{{
|
||||
Value: nil,
|
||||
NewCNAME: "x.host.com",
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeNone,
|
||||
}, {
|
||||
Value: net.IP{1, 2, 3, 5}.To16(),
|
||||
NewCNAME: "",
|
||||
RCode: dns.RcodeSuccess,
|
||||
@@ -221,10 +271,15 @@ func TestDefaultStorage_MatchRequest(t *testing.T) {
|
||||
}},
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "issue4008",
|
||||
host: "somehost.com",
|
||||
wantDNSRewrites: nil,
|
||||
dtyp: dns.TypeHTTPS,
|
||||
name: "issue4008",
|
||||
host: "somehost.com",
|
||||
wantDNSRewrites: []*rules.DNSRewrite{{
|
||||
Value: net.IP{0, 0, 0, 0}.To16(),
|
||||
NewCNAME: "",
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeA,
|
||||
}},
|
||||
dtyp: dns.TypeHTTPS,
|
||||
}, {
|
||||
name: "issue4016",
|
||||
host: "www.issue4016.com",
|
||||
@@ -256,7 +311,7 @@ func TestDefaultStorage_MatchRequest(t *testing.T) {
|
||||
|
||||
func TestDefaultStorage_MatchRequest_Levels(t *testing.T) {
|
||||
// Exact host, wildcard L2, wildcard L3.
|
||||
items := []*Item{{
|
||||
items := []*filtering.RewriteItem{{
|
||||
Domain: "host.com",
|
||||
Answer: "1.1.1.1",
|
||||
}, {
|
||||
@@ -267,7 +322,7 @@ func TestDefaultStorage_MatchRequest_Levels(t *testing.T) {
|
||||
Answer: "3.3.3.3",
|
||||
}}
|
||||
|
||||
s, err := NewDefaultStorage(-1, items)
|
||||
s, err := NewDefaultStorage(items)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
@@ -295,17 +350,21 @@ func TestDefaultStorage_MatchRequest_Levels(t *testing.T) {
|
||||
RRType: dns.TypeA,
|
||||
}},
|
||||
dtyp: dns.TypeA,
|
||||
//}, {
|
||||
// TODO(d.kolyshev): This is about matching in urlfilter.
|
||||
// name: "l3_match",
|
||||
// host: "my.sub.host.com",
|
||||
// wantDNSRewrites: []*rules.DNSRewrite{{
|
||||
// Value: net.IP{3, 3, 3, 3}.To16(),
|
||||
// NewCNAME: "",
|
||||
// RCode: dns.RcodeSuccess,
|
||||
// RRType: dns.TypeA,
|
||||
// }},
|
||||
// dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "l3_match",
|
||||
host: "my.sub.host.com",
|
||||
wantDNSRewrites: []*rules.DNSRewrite{{
|
||||
Value: net.IP{3, 3, 3, 3}.To16(),
|
||||
NewCNAME: "",
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeA,
|
||||
}, {
|
||||
Value: net.IP{2, 2, 2, 2}.To16(),
|
||||
NewCNAME: "",
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeA,
|
||||
}},
|
||||
dtyp: dns.TypeA,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
@@ -322,7 +381,7 @@ func TestDefaultStorage_MatchRequest_Levels(t *testing.T) {
|
||||
|
||||
func TestDefaultStorage_MatchRequest_ExceptionCNAME(t *testing.T) {
|
||||
// Wildcard and exception for a sub-domain.
|
||||
items := []*Item{{
|
||||
items := []*filtering.RewriteItem{{
|
||||
Domain: "*.host.com",
|
||||
Answer: "2.2.2.2",
|
||||
}, {
|
||||
@@ -330,10 +389,10 @@ func TestDefaultStorage_MatchRequest_ExceptionCNAME(t *testing.T) {
|
||||
Answer: "sub.host.com",
|
||||
}, {
|
||||
Domain: "*.sub.host.com",
|
||||
Answer: "*.sub.host.com",
|
||||
Answer: "sub.host.com",
|
||||
}}
|
||||
|
||||
s, err := NewDefaultStorage(-1, items)
|
||||
s, err := NewDefaultStorage(items)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
@@ -356,12 +415,79 @@ func TestDefaultStorage_MatchRequest_ExceptionCNAME(t *testing.T) {
|
||||
host: "sub.host.com",
|
||||
wantDNSRewrites: nil,
|
||||
dtyp: dns.TypeA,
|
||||
//}, {
|
||||
// TODO(d.kolyshev): This is about matching in urlfilter.
|
||||
// name: "exception_wildcard",
|
||||
// host: "my.sub.host.com",
|
||||
// wantDNSRewrites: nil,
|
||||
// dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "exception_wildcard",
|
||||
host: "my.sub.host.com",
|
||||
wantDNSRewrites: nil,
|
||||
dtyp: dns.TypeA,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
dnsRewrites := s.MatchRequest(&urlfilter.DNSRequest{
|
||||
Hostname: tc.host,
|
||||
DNSType: tc.dtyp,
|
||||
})
|
||||
|
||||
assert.Equal(t, tc.wantDNSRewrites, dnsRewrites)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultStorage_MatchRequest_CNAMEs(t *testing.T) {
|
||||
// Two cname rules for one subdomain
|
||||
items := []*filtering.RewriteItem{{
|
||||
Domain: "cname.org",
|
||||
Answer: "1.1.1.1",
|
||||
}, {
|
||||
Domain: "sub_cname.org",
|
||||
Answer: "2.2.2.2",
|
||||
}, {
|
||||
Domain: "*.host.com",
|
||||
Answer: "cname.org",
|
||||
}, {
|
||||
Domain: "*.sub.host.com",
|
||||
Answer: "sub_cname.org",
|
||||
}}
|
||||
|
||||
s, err := NewDefaultStorage(items)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
host string
|
||||
wantDNSRewrites []*rules.DNSRewrite
|
||||
dtyp uint16
|
||||
}{{
|
||||
name: "match_my_domain",
|
||||
host: "my.host.com",
|
||||
wantDNSRewrites: []*rules.DNSRewrite{{
|
||||
Value: nil,
|
||||
NewCNAME: "cname.org",
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeNone,
|
||||
}, {
|
||||
Value: net.IP{1, 1, 1, 1}.To16(),
|
||||
NewCNAME: "",
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeA,
|
||||
}},
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "match_sub_my_domain",
|
||||
host: "my.sub.host.com",
|
||||
wantDNSRewrites: []*rules.DNSRewrite{{
|
||||
Value: nil,
|
||||
NewCNAME: "cname.org",
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeNone,
|
||||
}, {
|
||||
Value: net.IP{1, 1, 1, 1}.To16(),
|
||||
NewCNAME: "",
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeA,
|
||||
}},
|
||||
dtyp: dns.TypeA,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
@@ -378,7 +504,7 @@ func TestDefaultStorage_MatchRequest_ExceptionCNAME(t *testing.T) {
|
||||
|
||||
func TestDefaultStorage_MatchRequest_ExceptionIP(t *testing.T) {
|
||||
// Exception for AAAA record.
|
||||
items := []*Item{{
|
||||
items := []*filtering.RewriteItem{{
|
||||
Domain: "host.com",
|
||||
Answer: "1.2.3.4",
|
||||
}, {
|
||||
@@ -395,7 +521,7 @@ func TestDefaultStorage_MatchRequest_ExceptionIP(t *testing.T) {
|
||||
Answer: "A",
|
||||
}}
|
||||
|
||||
s, err := NewDefaultStorage(-1, items)
|
||||
s, err := NewDefaultStorage(items)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
@@ -456,3 +582,66 @@ func TestDefaultStorage_MatchRequest_ExceptionIP(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToRule(t *testing.T) {
|
||||
const testDomain = "example.org"
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
item *filtering.RewriteItem
|
||||
want string
|
||||
}{{
|
||||
name: "nil",
|
||||
item: nil,
|
||||
want: "",
|
||||
}, {
|
||||
name: "a_rule",
|
||||
item: &filtering.RewriteItem{
|
||||
Domain: testDomain,
|
||||
Answer: "1.1.1.1",
|
||||
},
|
||||
want: "|example.org^$dnsrewrite=NOERROR;A;1.1.1.1",
|
||||
}, {
|
||||
name: "aaaa_rule",
|
||||
item: &filtering.RewriteItem{
|
||||
Domain: testDomain,
|
||||
Answer: "1:2:3::4",
|
||||
},
|
||||
want: "|example.org^$dnsrewrite=NOERROR;AAAA;1:2:3::4",
|
||||
}, {
|
||||
name: "cname_rule",
|
||||
item: &filtering.RewriteItem{
|
||||
Domain: testDomain,
|
||||
Answer: "other.org",
|
||||
},
|
||||
want: "|example.org^$dnsrewrite=NOERROR;CNAME;other.org",
|
||||
}, {
|
||||
name: "wildcard_rule",
|
||||
item: &filtering.RewriteItem{
|
||||
Domain: "*.example.org",
|
||||
Answer: "other.org",
|
||||
},
|
||||
want: "|*.example.org^$dnsrewrite=NOERROR;CNAME;other.org",
|
||||
}, {
|
||||
name: "aaaa_exception",
|
||||
item: &filtering.RewriteItem{
|
||||
Domain: testDomain,
|
||||
Answer: "A",
|
||||
},
|
||||
want: "@@||example.org^$dnstype=A,dnsrewrite",
|
||||
}, {
|
||||
name: "aaaa_exception",
|
||||
item: &filtering.RewriteItem{
|
||||
Domain: testDomain,
|
||||
Answer: "AAAA",
|
||||
},
|
||||
want: "@@||example.org^$dnstype=AAAA,dnsrewrite",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
res := toRule(tc.item)
|
||||
assert.Equal(t, tc.want, res)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
61
internal/filtering/rewrite_test.go
Normal file
61
internal/filtering/rewrite_test.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package filtering
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestItem_equal(t *testing.T) {
|
||||
const (
|
||||
testDomain = "example.org"
|
||||
testAnswer = "1.1.1.1"
|
||||
)
|
||||
|
||||
testItem := &RewriteItem{
|
||||
Domain: testDomain,
|
||||
Answer: testAnswer,
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
left *RewriteItem
|
||||
right *RewriteItem
|
||||
want bool
|
||||
}{{
|
||||
name: "nil_left",
|
||||
left: nil,
|
||||
right: testItem,
|
||||
want: false,
|
||||
}, {
|
||||
name: "nil_right",
|
||||
left: testItem,
|
||||
right: nil,
|
||||
want: false,
|
||||
}, {
|
||||
name: "nils",
|
||||
left: nil,
|
||||
right: nil,
|
||||
want: true,
|
||||
}, {
|
||||
name: "equal",
|
||||
left: testItem,
|
||||
right: testItem,
|
||||
want: true,
|
||||
}, {
|
||||
name: "distinct",
|
||||
left: testItem,
|
||||
right: &RewriteItem{
|
||||
Domain: "other",
|
||||
Answer: "other",
|
||||
},
|
||||
want: false,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
res := tc.left.Equal(tc.right)
|
||||
assert.Equal(t, tc.want, res)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -8,85 +8,57 @@ import (
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// TODO(d.kolyshev): Use [rewrite.Item] instead.
|
||||
type rewriteEntryJSON struct {
|
||||
Domain string `json:"domain"`
|
||||
Answer string `json:"answer"`
|
||||
}
|
||||
|
||||
// handleRewriteList is the handler for the GET /control/rewrite/list HTTP API.
|
||||
func (d *DNSFilter) handleRewriteList(w http.ResponseWriter, r *http.Request) {
|
||||
arr := []*rewriteEntryJSON{}
|
||||
|
||||
d.confLock.Lock()
|
||||
for _, ent := range d.Config.Rewrites {
|
||||
jsent := rewriteEntryJSON{
|
||||
Domain: ent.Domain,
|
||||
Answer: ent.Answer,
|
||||
}
|
||||
arr = append(arr, &jsent)
|
||||
}
|
||||
d.confLock.Unlock()
|
||||
|
||||
_ = aghhttp.WriteJSONResponse(w, r, arr)
|
||||
_ = aghhttp.WriteJSONResponse(w, r, d.rewriteStorage.List())
|
||||
}
|
||||
|
||||
// handleRewriteAdd is the handler for the POST /control/rewrite/add HTTP API.
|
||||
func (d *DNSFilter) handleRewriteAdd(w http.ResponseWriter, r *http.Request) {
|
||||
rwJSON := rewriteEntryJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&rwJSON)
|
||||
rw := &RewriteItem{}
|
||||
err := json.NewDecoder(r.Body).Decode(rw)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
rw := &LegacyRewrite{
|
||||
Domain: rwJSON.Domain,
|
||||
Answer: rwJSON.Answer,
|
||||
}
|
||||
|
||||
err = rw.normalize()
|
||||
err = d.rewriteStorage.Add(rw)
|
||||
if err != nil {
|
||||
// Shouldn't happen currently, since normalize only returns a non-nil
|
||||
// error when a rewrite is nil, but be change-proof.
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "normalizing: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "add rewrite: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug("rewrite: added element: %s -> %s", rw.Domain, rw.Answer)
|
||||
|
||||
d.confLock.Lock()
|
||||
d.Config.Rewrites = append(d.Config.Rewrites, rw)
|
||||
d.Config.Rewrites = d.rewriteStorage.List()
|
||||
d.confLock.Unlock()
|
||||
log.Debug("rewrite: added element: %s -> %s [%d]", rw.Domain, rw.Answer, len(d.Config.Rewrites))
|
||||
|
||||
d.Config.ConfigModified()
|
||||
}
|
||||
|
||||
// handleRewriteDelete is the handler for the POST /control/rewrite/delete HTTP
|
||||
// API.
|
||||
func (d *DNSFilter) handleRewriteDelete(w http.ResponseWriter, r *http.Request) {
|
||||
jsent := rewriteEntryJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&jsent)
|
||||
entDel := RewriteItem{}
|
||||
err := json.NewDecoder(r.Body).Decode(&entDel)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
entDel := &LegacyRewrite{
|
||||
Domain: jsent.Domain,
|
||||
Answer: jsent.Answer,
|
||||
err = d.rewriteStorage.Remove(&entDel)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "remove rewrite: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
arr := []*LegacyRewrite{}
|
||||
|
||||
d.confLock.Lock()
|
||||
for _, ent := range d.Config.Rewrites {
|
||||
if ent.equal(entDel) {
|
||||
log.Debug("rewrite: removed element: %s -> %s", ent.Domain, ent.Answer)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
arr = append(arr, ent)
|
||||
}
|
||||
d.Config.Rewrites = arr
|
||||
d.Config.Rewrites = d.rewriteStorage.List()
|
||||
d.confLock.Unlock()
|
||||
|
||||
d.Config.ConfigModified()
|
||||
|
||||
@@ -1,219 +0,0 @@
|
||||
// DNS Rewrites
|
||||
|
||||
package filtering
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
// LegacyRewrite is a single legacy DNS rewrite record.
|
||||
//
|
||||
// Instances of *LegacyRewrite must never be nil.
|
||||
type LegacyRewrite struct {
|
||||
// Domain is the domain pattern for which this rewrite should work.
|
||||
Domain string `yaml:"domain"`
|
||||
|
||||
// Answer is the IP address, canonical name, or one of the special
|
||||
// values: "A" or "AAAA".
|
||||
Answer string `yaml:"answer"`
|
||||
|
||||
// IP is the IP address that should be used in the response if Type is
|
||||
// dns.TypeA or dns.TypeAAAA.
|
||||
IP net.IP `yaml:"-"`
|
||||
|
||||
// Type is the DNS record type: A, AAAA, or CNAME.
|
||||
Type uint16 `yaml:"-"`
|
||||
}
|
||||
|
||||
// clone returns a deep clone of rw.
|
||||
func (rw *LegacyRewrite) clone() (cloneRW *LegacyRewrite) {
|
||||
return &LegacyRewrite{
|
||||
Domain: rw.Domain,
|
||||
Answer: rw.Answer,
|
||||
IP: slices.Clone(rw.IP),
|
||||
Type: rw.Type,
|
||||
}
|
||||
}
|
||||
|
||||
// equal returns true if the rw is equal to the other.
|
||||
func (rw *LegacyRewrite) equal(other *LegacyRewrite) (ok bool) {
|
||||
return rw.Domain == other.Domain && rw.Answer == other.Answer
|
||||
}
|
||||
|
||||
// matchesQType returns true if the entry matches the question type qt.
|
||||
func (rw *LegacyRewrite) matchesQType(qt uint16) (ok bool) {
|
||||
// Add CNAMEs, since they match for all types requests.
|
||||
if rw.Type == dns.TypeCNAME {
|
||||
return true
|
||||
}
|
||||
|
||||
// Reject types other than A and AAAA.
|
||||
if qt != dns.TypeA && qt != dns.TypeAAAA {
|
||||
return false
|
||||
}
|
||||
|
||||
// If the types match or the entry is set to allow only the other type,
|
||||
// include them.
|
||||
return rw.Type == qt || rw.IP == nil
|
||||
}
|
||||
|
||||
// normalize makes sure that the a new or decoded entry is normalized with
|
||||
// regards to domain name case, IP length, and so on.
|
||||
//
|
||||
// If rw is nil, it returns an errors.
|
||||
func (rw *LegacyRewrite) normalize() (err error) {
|
||||
if rw == nil {
|
||||
return errors.Error("nil rewrite entry")
|
||||
}
|
||||
|
||||
// TODO(a.garipov): Write a case-agnostic version of strings.HasSuffix and
|
||||
// use it in matchDomainWildcard instead of using strings.ToLower
|
||||
// everywhere.
|
||||
rw.Domain = strings.ToLower(rw.Domain)
|
||||
|
||||
switch rw.Answer {
|
||||
case "AAAA":
|
||||
rw.IP = nil
|
||||
rw.Type = dns.TypeAAAA
|
||||
|
||||
return nil
|
||||
case "A":
|
||||
rw.IP = nil
|
||||
rw.Type = dns.TypeA
|
||||
|
||||
return nil
|
||||
default:
|
||||
// Go on.
|
||||
}
|
||||
|
||||
ip := net.ParseIP(rw.Answer)
|
||||
if ip == nil {
|
||||
rw.Type = dns.TypeCNAME
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
ip4 := ip.To4()
|
||||
if ip4 != nil {
|
||||
rw.IP = ip4
|
||||
rw.Type = dns.TypeA
|
||||
} else {
|
||||
rw.IP = ip
|
||||
rw.Type = dns.TypeAAAA
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isWildcard returns true if pat is a wildcard domain pattern.
|
||||
func isWildcard(pat string) bool {
|
||||
return len(pat) > 1 && pat[0] == '*' && pat[1] == '.'
|
||||
}
|
||||
|
||||
// matchDomainWildcard returns true if host matches the wildcard pattern.
|
||||
func matchDomainWildcard(host, wildcard string) (ok bool) {
|
||||
return isWildcard(wildcard) && strings.HasSuffix(host, wildcard[1:])
|
||||
}
|
||||
|
||||
// rewritesSorted is a slice of legacy rewrites for sorting.
|
||||
//
|
||||
// The sorting priority:
|
||||
//
|
||||
// 1. A and AAAA > CNAME
|
||||
// 2. wildcard > exact
|
||||
// 3. lower level wildcard > higher level wildcard
|
||||
//
|
||||
// TODO(a.garipov): Replace with slices.Sort.
|
||||
type rewritesSorted []*LegacyRewrite
|
||||
|
||||
// Len implements the sort.Interface interface for rewritesSorted.
|
||||
func (a rewritesSorted) Len() (l int) { return len(a) }
|
||||
|
||||
// Swap implements the sort.Interface interface for rewritesSorted.
|
||||
func (a rewritesSorted) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
|
||||
|
||||
// Less implements the sort.Interface interface for rewritesSorted.
|
||||
func (a rewritesSorted) Less(i, j int) (less bool) {
|
||||
ith, jth := a[i], a[j]
|
||||
if ith.Type == dns.TypeCNAME && jth.Type != dns.TypeCNAME {
|
||||
return true
|
||||
} else if ith.Type != dns.TypeCNAME && jth.Type == dns.TypeCNAME {
|
||||
return false
|
||||
}
|
||||
|
||||
if iw, jw := isWildcard(ith.Domain), isWildcard(jth.Domain); iw != jw {
|
||||
return jw
|
||||
}
|
||||
|
||||
// Both are either wildcards or not.
|
||||
return len(ith.Domain) > len(jth.Domain)
|
||||
}
|
||||
|
||||
// prepareRewrites normalizes and validates all legacy DNS rewrites.
|
||||
func (d *DNSFilter) prepareRewrites() (err error) {
|
||||
for i, r := range d.Rewrites {
|
||||
err = r.normalize()
|
||||
if err != nil {
|
||||
return fmt.Errorf("at index %d: %w", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// findRewrites returns the list of matched rewrite entries. If rewrites are
|
||||
// empty, but matched is true, the domain is found among the rewrite rules but
|
||||
// not for this question type.
|
||||
//
|
||||
// The result priority is: CNAME, then A and AAAA; exact, then wildcard. If the
|
||||
// host is matched exactly, wildcard entries aren't returned. If the host
|
||||
// matched by wildcards, return the most specific for the question type.
|
||||
func findRewrites(
|
||||
entries []*LegacyRewrite,
|
||||
host string,
|
||||
qtype uint16,
|
||||
) (rewrites []*LegacyRewrite, matched bool) {
|
||||
for _, e := range entries {
|
||||
if e.Domain != host && !matchDomainWildcard(host, e.Domain) {
|
||||
continue
|
||||
}
|
||||
|
||||
matched = true
|
||||
if e.matchesQType(qtype) {
|
||||
rewrites = append(rewrites, e)
|
||||
}
|
||||
}
|
||||
|
||||
if len(rewrites) == 0 {
|
||||
return nil, matched
|
||||
}
|
||||
|
||||
sort.Sort(rewritesSorted(rewrites))
|
||||
|
||||
for i, r := range rewrites {
|
||||
if isWildcard(r.Domain) {
|
||||
// Don't use rewrites[:0], because we need to return at least one
|
||||
// item here.
|
||||
rewrites = rewrites[:max(1, i)]
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return rewrites, matched
|
||||
}
|
||||
|
||||
func max(a, b int) int {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
@@ -1,371 +0,0 @@
|
||||
package filtering
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TODO(e.burkov): All the tests in this file may and should me merged together.
|
||||
|
||||
func TestRewrites(t *testing.T) {
|
||||
d, _ := newForTest(t, nil, nil)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
d.Rewrites = []*LegacyRewrite{{
|
||||
// This one and below are about CNAME, A and AAAA.
|
||||
Domain: "somecname",
|
||||
Answer: "somehost.com",
|
||||
}, {
|
||||
Domain: "somehost.com",
|
||||
Answer: "0.0.0.0",
|
||||
}, {
|
||||
Domain: "host.com",
|
||||
Answer: "1.2.3.4",
|
||||
}, {
|
||||
Domain: "host.com",
|
||||
Answer: "1.2.3.5",
|
||||
}, {
|
||||
Domain: "host.com",
|
||||
Answer: "1:2:3::4",
|
||||
}, {
|
||||
Domain: "www.host.com",
|
||||
Answer: "host.com",
|
||||
}, {
|
||||
// This one is a wildcard.
|
||||
Domain: "*.host.com",
|
||||
Answer: "1.2.3.5",
|
||||
}, {
|
||||
// This one and below are about wildcard overriding.
|
||||
Domain: "a.host.com",
|
||||
Answer: "1.2.3.4",
|
||||
}, {
|
||||
// This one is about CNAME and wildcard interacting.
|
||||
Domain: "*.host2.com",
|
||||
Answer: "host.com",
|
||||
}, {
|
||||
// This one and below are about 2 level CNAME.
|
||||
Domain: "b.host.com",
|
||||
Answer: "somecname",
|
||||
}, {
|
||||
// This one and below are about 2 level CNAME and wildcard.
|
||||
Domain: "b.host3.com",
|
||||
Answer: "a.host3.com",
|
||||
}, {
|
||||
Domain: "a.host3.com",
|
||||
Answer: "x.host.com",
|
||||
}, {
|
||||
Domain: "*.hostboth.com",
|
||||
Answer: "1.2.3.6",
|
||||
}, {
|
||||
Domain: "*.hostboth.com",
|
||||
Answer: "1234::5678",
|
||||
}, {
|
||||
Domain: "BIGHOST.COM",
|
||||
Answer: "1.2.3.7",
|
||||
}, {
|
||||
Domain: "*.issue4016.com",
|
||||
Answer: "sub.issue4016.com",
|
||||
}}
|
||||
|
||||
require.NoError(t, d.prepareRewrites())
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
host string
|
||||
wantCName string
|
||||
wantIPs []net.IP
|
||||
wantReason Reason
|
||||
dtyp uint16
|
||||
}{{
|
||||
name: "not_filtered_not_found",
|
||||
host: "hoost.com",
|
||||
wantCName: "",
|
||||
wantIPs: nil,
|
||||
wantReason: NotFilteredNotFound,
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "rewritten_a",
|
||||
host: "www.host.com",
|
||||
wantCName: "host.com",
|
||||
wantIPs: []net.IP{{1, 2, 3, 4}, {1, 2, 3, 5}},
|
||||
wantReason: Rewritten,
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "rewritten_aaaa",
|
||||
host: "www.host.com",
|
||||
wantCName: "host.com",
|
||||
wantIPs: []net.IP{net.ParseIP("1:2:3::4")},
|
||||
wantReason: Rewritten,
|
||||
dtyp: dns.TypeAAAA,
|
||||
}, {
|
||||
name: "wildcard_match",
|
||||
host: "abc.host.com",
|
||||
wantCName: "",
|
||||
wantIPs: []net.IP{{1, 2, 3, 5}},
|
||||
wantReason: Rewritten,
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "wildcard_override",
|
||||
host: "a.host.com",
|
||||
wantCName: "",
|
||||
wantIPs: []net.IP{{1, 2, 3, 4}},
|
||||
wantReason: Rewritten,
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "wildcard_cname_interaction",
|
||||
host: "www.host2.com",
|
||||
wantCName: "host.com",
|
||||
wantIPs: []net.IP{{1, 2, 3, 4}, {1, 2, 3, 5}},
|
||||
wantReason: Rewritten,
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "two_cnames",
|
||||
host: "b.host.com",
|
||||
wantCName: "somehost.com",
|
||||
wantIPs: []net.IP{{0, 0, 0, 0}},
|
||||
wantReason: Rewritten,
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "two_cnames_and_wildcard",
|
||||
host: "b.host3.com",
|
||||
wantCName: "x.host.com",
|
||||
wantIPs: []net.IP{{1, 2, 3, 5}},
|
||||
wantReason: Rewritten,
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "issue3343",
|
||||
host: "www.hostboth.com",
|
||||
wantCName: "",
|
||||
wantIPs: []net.IP{net.ParseIP("1234::5678")},
|
||||
wantReason: Rewritten,
|
||||
dtyp: dns.TypeAAAA,
|
||||
}, {
|
||||
name: "issue3351",
|
||||
host: "bighost.com",
|
||||
wantCName: "",
|
||||
wantIPs: []net.IP{{1, 2, 3, 7}},
|
||||
wantReason: Rewritten,
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "issue4008",
|
||||
host: "somehost.com",
|
||||
wantCName: "",
|
||||
wantIPs: nil,
|
||||
wantReason: Rewritten,
|
||||
dtyp: dns.TypeHTTPS,
|
||||
}, {
|
||||
name: "issue4016",
|
||||
host: "www.issue4016.com",
|
||||
wantCName: "sub.issue4016.com",
|
||||
wantIPs: nil,
|
||||
wantReason: Rewritten,
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "issue4016_self",
|
||||
host: "sub.issue4016.com",
|
||||
wantCName: "",
|
||||
wantIPs: nil,
|
||||
wantReason: NotFilteredNotFound,
|
||||
dtyp: dns.TypeA,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
r := d.processRewrites(tc.host, tc.dtyp)
|
||||
require.Equalf(t, tc.wantReason, r.Reason, "got %s", r.Reason)
|
||||
|
||||
if tc.wantCName != "" {
|
||||
assert.Equal(t, tc.wantCName, r.CanonName)
|
||||
}
|
||||
|
||||
assert.Equal(t, tc.wantIPs, r.IPList)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewritesLevels(t *testing.T) {
|
||||
d, _ := newForTest(t, nil, nil)
|
||||
t.Cleanup(d.Close)
|
||||
// Exact host, wildcard L2, wildcard L3.
|
||||
d.Rewrites = []*LegacyRewrite{{
|
||||
Domain: "host.com",
|
||||
Answer: "1.1.1.1",
|
||||
Type: dns.TypeA,
|
||||
}, {
|
||||
Domain: "*.host.com",
|
||||
Answer: "2.2.2.2",
|
||||
Type: dns.TypeA,
|
||||
}, {
|
||||
Domain: "*.sub.host.com",
|
||||
Answer: "3.3.3.3",
|
||||
Type: dns.TypeA,
|
||||
}}
|
||||
|
||||
require.NoError(t, d.prepareRewrites())
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
host string
|
||||
want net.IP
|
||||
}{{
|
||||
name: "exact_match",
|
||||
host: "host.com",
|
||||
want: net.IP{1, 1, 1, 1},
|
||||
}, {
|
||||
name: "l2_match",
|
||||
host: "sub.host.com",
|
||||
want: net.IP{2, 2, 2, 2},
|
||||
}, {
|
||||
name: "l3_match",
|
||||
host: "my.sub.host.com",
|
||||
want: net.IP{3, 3, 3, 3},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
r := d.processRewrites(tc.host, dns.TypeA)
|
||||
assert.Equal(t, Rewritten, r.Reason)
|
||||
require.Len(t, r.IPList, 1)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewritesExceptionCNAME(t *testing.T) {
|
||||
d, _ := newForTest(t, nil, nil)
|
||||
t.Cleanup(d.Close)
|
||||
// Wildcard and exception for a sub-domain.
|
||||
d.Rewrites = []*LegacyRewrite{{
|
||||
Domain: "*.host.com",
|
||||
Answer: "2.2.2.2",
|
||||
}, {
|
||||
Domain: "sub.host.com",
|
||||
Answer: "sub.host.com",
|
||||
}, {
|
||||
Domain: "*.sub.host.com",
|
||||
Answer: "*.sub.host.com",
|
||||
}}
|
||||
|
||||
require.NoError(t, d.prepareRewrites())
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
host string
|
||||
want net.IP
|
||||
}{{
|
||||
name: "match_subdomain",
|
||||
host: "my.host.com",
|
||||
want: net.IP{2, 2, 2, 2},
|
||||
}, {
|
||||
name: "exception_cname",
|
||||
host: "sub.host.com",
|
||||
want: nil,
|
||||
}, {
|
||||
name: "exception_wildcard",
|
||||
host: "my.sub.host.com",
|
||||
want: nil,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
r := d.processRewrites(tc.host, dns.TypeA)
|
||||
if tc.want == nil {
|
||||
assert.Equal(t, NotFilteredNotFound, r.Reason, "got %s", r.Reason)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, Rewritten, r.Reason)
|
||||
require.Len(t, r.IPList, 1)
|
||||
assert.True(t, tc.want.Equal(r.IPList[0]))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewritesExceptionIP(t *testing.T) {
|
||||
d, _ := newForTest(t, nil, nil)
|
||||
t.Cleanup(d.Close)
|
||||
// Exception for AAAA record.
|
||||
d.Rewrites = []*LegacyRewrite{{
|
||||
Domain: "host.com",
|
||||
Answer: "1.2.3.4",
|
||||
Type: dns.TypeA,
|
||||
}, {
|
||||
Domain: "host.com",
|
||||
Answer: "AAAA",
|
||||
Type: dns.TypeAAAA,
|
||||
}, {
|
||||
Domain: "host2.com",
|
||||
Answer: "::1",
|
||||
Type: dns.TypeAAAA,
|
||||
}, {
|
||||
Domain: "host2.com",
|
||||
Answer: "A",
|
||||
Type: dns.TypeA,
|
||||
}, {
|
||||
Domain: "host3.com",
|
||||
Answer: "A",
|
||||
Type: dns.TypeA,
|
||||
}}
|
||||
|
||||
require.NoError(t, d.prepareRewrites())
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
host string
|
||||
want []net.IP
|
||||
dtyp uint16
|
||||
}{{
|
||||
name: "match_A",
|
||||
host: "host.com",
|
||||
want: []net.IP{{1, 2, 3, 4}},
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "exception_AAAA_host.com",
|
||||
host: "host.com",
|
||||
want: nil,
|
||||
dtyp: dns.TypeAAAA,
|
||||
}, {
|
||||
name: "exception_A_host2.com",
|
||||
host: "host2.com",
|
||||
want: nil,
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "match_AAAA_host2.com",
|
||||
host: "host2.com",
|
||||
want: []net.IP{net.ParseIP("::1")},
|
||||
dtyp: dns.TypeAAAA,
|
||||
}, {
|
||||
name: "exception_A_host3.com",
|
||||
host: "host3.com",
|
||||
want: nil,
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "match_AAAA_host3.com",
|
||||
host: "host3.com",
|
||||
want: []net.IP{},
|
||||
dtyp: dns.TypeAAAA,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name+"_"+tc.host, func(t *testing.T) {
|
||||
r := d.processRewrites(tc.host, tc.dtyp)
|
||||
if tc.want == nil {
|
||||
assert.Equal(t, NotFilteredNotFound, r.Reason)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equalf(t, Rewritten, r.Reason, "got %s", r.Reason)
|
||||
|
||||
require.Len(t, r.IPList, len(tc.want))
|
||||
|
||||
for _, ip := range tc.want {
|
||||
assert.True(t, ip.Equal(r.IPList[0]))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -123,7 +123,7 @@ func handleUpdate(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
err = Context.updater.Update()
|
||||
err = Context.updater.Update(false)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
|
||||
|
||||
|
||||
@@ -9,9 +9,12 @@ import (
|
||||
"path/filepath"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rewrite"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
@@ -39,17 +42,13 @@ func onConfigModified() {
|
||||
}
|
||||
}
|
||||
|
||||
// initDNSServer creates an instance of the dnsforward.Server
|
||||
// Please note that we must do it even if we don't start it
|
||||
// so that we had access to the query log and the stats
|
||||
func initDNSServer() (err error) {
|
||||
// initDNS updates all the fields of the [Context] needed to initialize the DNS
|
||||
// server and initializes it at last. It also must not be called unless
|
||||
// [config] and [Context] are initialized.
|
||||
func initDNS() (err error) {
|
||||
baseDir := Context.getDataDir()
|
||||
|
||||
var anonFunc aghnet.IPMutFunc
|
||||
if config.DNS.AnonymizeClientIP {
|
||||
anonFunc = querylog.AnonymizeIP
|
||||
}
|
||||
anonymizer := aghnet.NewIPMut(anonFunc)
|
||||
anonymizer := config.anonymizer()
|
||||
|
||||
statsConf := stats.Config{
|
||||
Filename: filepath.Join(baseDir, "stats.db"),
|
||||
@@ -76,40 +75,57 @@ func initDNSServer() (err error) {
|
||||
}
|
||||
Context.queryLog = querylog.New(conf)
|
||||
|
||||
Context.filters, err = filtering.New(config.DNS.DnsfilterConf, nil)
|
||||
rewriteStorage, err := rewrite.NewDefaultStorage(config.DNS.DnsfilterConf.Rewrites)
|
||||
if err != nil {
|
||||
return fmt.Errorf("rewrites: init: %w", err)
|
||||
}
|
||||
|
||||
Context.filters, err = filtering.New(config.DNS.DnsfilterConf, nil, rewriteStorage)
|
||||
if err != nil {
|
||||
// Don't wrap the error, since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
var privateNets netutil.SubnetSet
|
||||
switch len(config.DNS.PrivateNets) {
|
||||
case 0:
|
||||
// Use an optimized locally-served matcher.
|
||||
privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed)
|
||||
case 1:
|
||||
privateNets, err = netutil.ParseSubnet(config.DNS.PrivateNets[0])
|
||||
if err != nil {
|
||||
return fmt.Errorf("preparing the set of private subnets: %w", err)
|
||||
}
|
||||
default:
|
||||
var nets []*net.IPNet
|
||||
nets, err = netutil.ParseSubnets(config.DNS.PrivateNets...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("preparing the set of private subnets: %w", err)
|
||||
}
|
||||
tlsConf := &tlsConfigSettings{}
|
||||
Context.tls.WriteDiskConfig(tlsConf)
|
||||
|
||||
privateNets = netutil.SliceSubnetSet(nets)
|
||||
return initDNSServer(
|
||||
Context.filters,
|
||||
Context.stats,
|
||||
Context.queryLog,
|
||||
Context.dhcpServer,
|
||||
anonymizer,
|
||||
httpRegister,
|
||||
tlsConf,
|
||||
)
|
||||
}
|
||||
|
||||
// initDNSServer initializes the [context.dnsServer]. To only use the internal
|
||||
// proxy, none of the arguments are required, but tlsConf still must not be nil,
|
||||
// in other cases all the arguments also must not be nil. It also must not be
|
||||
// called unless [config] and [Context] are initialized.
|
||||
func initDNSServer(
|
||||
filters *filtering.DNSFilter,
|
||||
sts stats.Interface,
|
||||
qlog querylog.QueryLog,
|
||||
dhcpSrv dhcpd.Interface,
|
||||
anonymizer *aghnet.IPMut,
|
||||
httpReg aghhttp.RegisterFunc,
|
||||
tlsConf *tlsConfigSettings,
|
||||
) (err error) {
|
||||
privateNets, err := parseSubnetSet(config.DNS.PrivateNets)
|
||||
if err != nil {
|
||||
return fmt.Errorf("preparing set of private subnets: %w", err)
|
||||
}
|
||||
|
||||
p := dnsforward.DNSCreateParams{
|
||||
DNSFilter: Context.filters,
|
||||
Stats: Context.stats,
|
||||
QueryLog: Context.queryLog,
|
||||
DNSFilter: filters,
|
||||
Stats: sts,
|
||||
QueryLog: qlog,
|
||||
PrivateNets: privateNets,
|
||||
Anonymizer: anonymizer,
|
||||
LocalDomain: config.DHCP.LocalDomainName,
|
||||
DHCPServer: Context.dhcpServer,
|
||||
DHCPServer: dhcpSrv,
|
||||
}
|
||||
|
||||
Context.dnsServer, err = dnsforward.NewServer(p)
|
||||
@@ -120,15 +136,15 @@ func initDNSServer() (err error) {
|
||||
}
|
||||
|
||||
Context.clients.dnsServer = Context.dnsServer
|
||||
var dnsConfig dnsforward.ServerConfig
|
||||
dnsConfig, err = generateServerConfig()
|
||||
|
||||
dnsConf, err := generateServerConfig(tlsConf, httpReg)
|
||||
if err != nil {
|
||||
closeDNSServer()
|
||||
|
||||
return fmt.Errorf("generateServerConfig: %w", err)
|
||||
}
|
||||
|
||||
err = Context.dnsServer.Prepare(&dnsConfig)
|
||||
err = Context.dnsServer.Prepare(&dnsConf)
|
||||
if err != nil {
|
||||
closeDNSServer()
|
||||
|
||||
@@ -146,6 +162,32 @@ func initDNSServer() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseSubnetSet parses a slice of subnets. If the slice is empty, it returns
|
||||
// a subnet set that matches all locally served networks, see
|
||||
// [netutil.IsLocallyServed].
|
||||
func parseSubnetSet(nets []string) (s netutil.SubnetSet, err error) {
|
||||
switch len(nets) {
|
||||
case 0:
|
||||
// Use an optimized function-based matcher.
|
||||
return netutil.SubnetSetFunc(netutil.IsLocallyServed), nil
|
||||
case 1:
|
||||
s, err = netutil.ParseSubnet(nets[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s, nil
|
||||
default:
|
||||
var nets []*net.IPNet
|
||||
nets, err = netutil.ParseSubnets(config.DNS.PrivateNets...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return netutil.SliceSubnetSet(nets), nil
|
||||
}
|
||||
}
|
||||
|
||||
func isRunning() bool {
|
||||
return Context.dnsServer != nil && Context.dnsServer.IsRunning()
|
||||
}
|
||||
@@ -193,7 +235,10 @@ func ipsToUDPAddrs(ips []netip.Addr, port int) (udpAddrs []*net.UDPAddr) {
|
||||
return udpAddrs
|
||||
}
|
||||
|
||||
func generateServerConfig() (newConf dnsforward.ServerConfig, err error) {
|
||||
func generateServerConfig(
|
||||
tlsConf *tlsConfigSettings,
|
||||
httpReg aghhttp.RegisterFunc,
|
||||
) (newConf dnsforward.ServerConfig, err error) {
|
||||
dnsConf := config.DNS
|
||||
hosts := aghalg.CoalesceSlice(dnsConf.BindHosts, []netip.Addr{netutil.IPv4Localhost()})
|
||||
newConf = dnsforward.ServerConfig{
|
||||
@@ -201,12 +246,10 @@ func generateServerConfig() (newConf dnsforward.ServerConfig, err error) {
|
||||
TCPListenAddrs: ipsToTCPAddrs(hosts, dnsConf.Port),
|
||||
FilteringConfig: dnsConf.FilteringConfig,
|
||||
ConfigModified: onConfigModified,
|
||||
HTTPRegister: httpRegister,
|
||||
HTTPRegister: httpReg,
|
||||
OnDNSRequest: onDNSRequest,
|
||||
}
|
||||
|
||||
tlsConf := tlsConfigSettings{}
|
||||
Context.tls.WriteDiskConfig(&tlsConf)
|
||||
if tlsConf.Enabled {
|
||||
newConf.TLSConfig = tlsConf.TLSConfig
|
||||
newConf.TLSConfig.ServerName = tlsConf.ServerName
|
||||
@@ -224,7 +267,7 @@ func generateServerConfig() (newConf dnsforward.ServerConfig, err error) {
|
||||
}
|
||||
|
||||
if tlsConf.PortDNSCrypt != 0 {
|
||||
newConf.DNSCryptConfig, err = newDNSCrypt(hosts, tlsConf)
|
||||
newConf.DNSCryptConfig, err = newDNSCrypt(hosts, *tlsConf)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's already
|
||||
// wrapped by newDNSCrypt.
|
||||
@@ -413,7 +456,11 @@ func startDNSServer() error {
|
||||
|
||||
func reconfigureDNSServer() (err error) {
|
||||
var newConf dnsforward.ServerConfig
|
||||
newConf, err = generateServerConfig()
|
||||
|
||||
tlsConf := &tlsConfigSettings{}
|
||||
Context.tls.WriteDiskConfig(tlsConf)
|
||||
|
||||
newConf, err = generateServerConfig(tlsConf, httpRegister)
|
||||
if err != nil {
|
||||
return fmt.Errorf("generating forwarding dns server config: %w", err)
|
||||
}
|
||||
|
||||
@@ -455,6 +455,10 @@ func run(opts options, clientBuildFS fs.FS) {
|
||||
err = setupConfig(opts)
|
||||
fatalOnError(err)
|
||||
|
||||
// TODO(e.burkov): This could be made earlier, probably as the option's
|
||||
// effect.
|
||||
cmdlineUpdate(opts)
|
||||
|
||||
if !Context.firstRun {
|
||||
// Save the updated config
|
||||
err = config.write()
|
||||
@@ -522,7 +526,7 @@ func run(opts options, clientBuildFS fs.FS) {
|
||||
fatalOnError(err)
|
||||
|
||||
if !Context.firstRun {
|
||||
err = initDNSServer()
|
||||
err = initDNS()
|
||||
fatalOnError(err)
|
||||
|
||||
Context.tls.start()
|
||||
@@ -543,20 +547,24 @@ func run(opts options, clientBuildFS fs.FS) {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(a.garipov): This could be made much earlier and could be done on
|
||||
// the first run as well, but to achieve this we need to bypass requests
|
||||
// over dnsforward resolver.
|
||||
cmdlineUpdate(opts)
|
||||
|
||||
Context.web.Start()
|
||||
|
||||
// wait indefinitely for other go-routines to complete their job
|
||||
select {}
|
||||
}
|
||||
|
||||
func (c *configuration) anonymizer() (ipmut *aghnet.IPMut) {
|
||||
var anonFunc aghnet.IPMutFunc
|
||||
if c.DNS.AnonymizeClientIP {
|
||||
anonFunc = querylog.AnonymizeIP
|
||||
}
|
||||
|
||||
return aghnet.NewIPMut(anonFunc)
|
||||
}
|
||||
|
||||
// startMods initializes and starts the DNS server after installation.
|
||||
func startMods() error {
|
||||
err := initDNSServer()
|
||||
func startMods() (err error) {
|
||||
err = initDNS()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -927,8 +935,8 @@ func getHTTPProxy(_ *http.Request) (*url.URL, error) {
|
||||
|
||||
// jsonError is a generic JSON error response.
|
||||
//
|
||||
// TODO(a.garipov): Merge together with the implementations in .../dhcpd and
|
||||
// other packages after refactoring the web handler registering.
|
||||
// TODO(a.garipov): Merge together with the implementations in [dhcpd] and other
|
||||
// packages after refactoring the web handler registering.
|
||||
type jsonError struct {
|
||||
// Message is the error message, an opaque string.
|
||||
Message string `json:"message"`
|
||||
@@ -940,30 +948,40 @@ func cmdlineUpdate(opts options) {
|
||||
return
|
||||
}
|
||||
|
||||
log.Info("starting update")
|
||||
// Initialize the DNS server to use the internal resolver which the updater
|
||||
// needs to be able to resolve the update source hostname.
|
||||
//
|
||||
// TODO(e.burkov): We could probably initialize the internal resolver
|
||||
// separately.
|
||||
err := initDNSServer(nil, nil, nil, nil, nil, nil, &tlsConfigSettings{})
|
||||
fatalOnError(err)
|
||||
|
||||
if Context.firstRun {
|
||||
log.Info("update not allowed on first run")
|
||||
log.Info("cmdline update: performing update")
|
||||
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
_, err := Context.updater.VersionInfo(true)
|
||||
updater := Context.updater
|
||||
info, err := updater.VersionInfo(true)
|
||||
if err != nil {
|
||||
vcu := Context.updater.VersionCheckURL()
|
||||
vcu := updater.VersionCheckURL()
|
||||
log.Error("getting version info from %s: %s", vcu, err)
|
||||
|
||||
os.Exit(0)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if Context.updater.NewVersion() == "" {
|
||||
if info.NewVersion == version.Version() {
|
||||
log.Info("no updates available")
|
||||
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
err = Context.updater.Update()
|
||||
err = updater.Update(Context.firstRun)
|
||||
fatalOnError(err)
|
||||
|
||||
err = restartService()
|
||||
if err != nil {
|
||||
log.Debug("restarting service: %s", err)
|
||||
log.Info("AdGuard Home was not installed as a service. " +
|
||||
"Please restart running instances of AdGuardHome manually.")
|
||||
}
|
||||
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
@@ -229,7 +229,7 @@ var cmdLineOpts = []cmdLineOpt{{
|
||||
updateNoValue: func(o options) (options, error) { o.performUpdate = true; return o, nil },
|
||||
effect: nil,
|
||||
serialize: func(o options) (val string, ok bool) { return "", o.performUpdate },
|
||||
description: "Update application and exit.",
|
||||
description: "Update the current binary and restart the service in case it's installed.",
|
||||
longName: "update",
|
||||
shortName: "",
|
||||
}, {
|
||||
|
||||
@@ -159,6 +159,38 @@ func sendSigReload() {
|
||||
log.Debug("service: sent signal to pid %d", pid)
|
||||
}
|
||||
|
||||
// restartService restarts the service. It returns error if the service is not
|
||||
// running.
|
||||
func restartService() (err error) {
|
||||
// Call chooseSystem explicitly to introduce OpenBSD support for service
|
||||
// package. It's a noop for other GOOS values.
|
||||
chooseSystem()
|
||||
|
||||
pwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting current directory: %w", err)
|
||||
}
|
||||
|
||||
svcConfig := &service.Config{
|
||||
Name: serviceName,
|
||||
DisplayName: serviceDisplayName,
|
||||
Description: serviceDescription,
|
||||
WorkingDirectory: pwd,
|
||||
}
|
||||
configureService(svcConfig)
|
||||
|
||||
var s service.Service
|
||||
if s, err = service.New(&program{}, svcConfig); err != nil {
|
||||
return fmt.Errorf("initializing service: %w", err)
|
||||
}
|
||||
|
||||
if err = svcAction(s, "restart"); err != nil {
|
||||
return fmt.Errorf("restarting service: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleServiceControlAction one of the possible control actions:
|
||||
//
|
||||
// - install: Installs a service/daemon.
|
||||
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
"github.com/kardianos/service"
|
||||
)
|
||||
|
||||
// chooseSystem checks the current system detected and substitutes it with local
|
||||
// implementation if needed.
|
||||
func chooseSystem() {
|
||||
sys := service.ChosenSystem()
|
||||
// By default, package service uses the SysV system if it cannot detect
|
||||
|
||||
@@ -30,6 +30,8 @@ import (
|
||||
// sysVersion is the version of local service.System interface implementation.
|
||||
const sysVersion = "openbsd-runcom"
|
||||
|
||||
// chooseSystem checks the current system detected and substitutes it with local
|
||||
// implementation if needed.
|
||||
func chooseSystem() {
|
||||
service.ChooseSystem(openbsdSystem{})
|
||||
}
|
||||
|
||||
@@ -180,7 +180,7 @@ func withRecovered(orig *error) {
|
||||
// type check
|
||||
var _ Interface = (*StatsCtx)(nil)
|
||||
|
||||
// Start implements the Interface interface for *StatsCtx.
|
||||
// Start implements the [Interface] interface for *StatsCtx.
|
||||
func (s *StatsCtx) Start() {
|
||||
s.initWeb()
|
||||
|
||||
|
||||
@@ -61,7 +61,7 @@ func (u *Updater) VersionInfo(forceRecheck bool) (vi VersionInfo, err error) {
|
||||
return VersionInfo{}, fmt.Errorf("updater: HTTP GET %s: %w", vcu, err)
|
||||
}
|
||||
|
||||
u.prevCheckTime = time.Now()
|
||||
u.prevCheckTime = now
|
||||
u.prevCheckResult, u.prevCheckError = u.parseVersionResponse(body)
|
||||
|
||||
return u.prevCheckResult, u.prevCheckError
|
||||
@@ -92,7 +92,11 @@ func (u *Updater) parseVersionResponse(data []byte) (VersionInfo, error) {
|
||||
info.AnnouncementURL = versionJSON["announcement_url"]
|
||||
|
||||
packageURL, ok := u.downloadURL(versionJSON)
|
||||
info.CanAutoUpdate = aghalg.BoolToNullBool(ok && info.NewVersion != u.version)
|
||||
if !ok {
|
||||
return info, fmt.Errorf("version.json: packageURL not found")
|
||||
}
|
||||
|
||||
info.CanAutoUpdate = aghalg.BoolToNullBool(info.NewVersion != u.version)
|
||||
|
||||
u.newVersion = info.NewVersion
|
||||
u.packageURL = packageURL
|
||||
|
||||
@@ -104,49 +104,58 @@ func NewUpdater(conf *Config) *Updater {
|
||||
}
|
||||
}
|
||||
|
||||
// Update performs the auto-update.
|
||||
func (u *Updater) Update() (err error) {
|
||||
// Update performs the auto-update. It returns an error if the update failed.
|
||||
// If firstRun is true, it assumes the configuration file doesn't exist.
|
||||
func (u *Updater) Update(firstRun bool) (err error) {
|
||||
u.mu.Lock()
|
||||
defer u.mu.Unlock()
|
||||
|
||||
log.Info("updater: updating")
|
||||
defer func() { log.Info("updater: finished; errors: %v", err) }()
|
||||
defer func() {
|
||||
if err != nil {
|
||||
log.Error("updater: failed: %v", err)
|
||||
} else {
|
||||
log.Info("updater: finished")
|
||||
}
|
||||
}()
|
||||
|
||||
execPath, err := os.Executable()
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("getting executable path: %w", err)
|
||||
}
|
||||
|
||||
err = u.prepare(execPath)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("preparing: %w", err)
|
||||
}
|
||||
|
||||
defer u.clean()
|
||||
|
||||
err = u.downloadPackageFile(u.packageURL, u.packageName)
|
||||
err = u.downloadPackageFile()
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("downloading package file: %w", err)
|
||||
}
|
||||
|
||||
err = u.unpack()
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("unpacking: %w", err)
|
||||
}
|
||||
|
||||
err = u.check()
|
||||
if err != nil {
|
||||
return err
|
||||
if !firstRun {
|
||||
err = u.check()
|
||||
if err != nil {
|
||||
return fmt.Errorf("checking config: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = u.backup()
|
||||
err = u.backup(firstRun)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("making backup: %w", err)
|
||||
}
|
||||
|
||||
err = u.replace()
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("replacing: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -174,7 +183,7 @@ func (u *Updater) prepare(exePath string) (err error) {
|
||||
|
||||
_, pkgNameOnly := filepath.Split(u.packageURL)
|
||||
if pkgNameOnly == "" {
|
||||
return fmt.Errorf("invalid PackageURL")
|
||||
return fmt.Errorf("invalid PackageURL: %q", u.packageURL)
|
||||
}
|
||||
|
||||
u.packageName = filepath.Join(u.updateDir, pkgNameOnly)
|
||||
@@ -204,6 +213,7 @@ func (u *Updater) prepare(exePath string) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// unpack extracts the files from the downloaded archive.
|
||||
func (u *Updater) unpack() error {
|
||||
var err error
|
||||
_, pkgNameOnly := filepath.Split(u.packageURL)
|
||||
@@ -228,38 +238,48 @@ func (u *Updater) unpack() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// check returns an error if the configuration file couldn't be used with the
|
||||
// version of AdGuard Home just downloaded.
|
||||
func (u *Updater) check() error {
|
||||
log.Debug("updater: checking configuration")
|
||||
|
||||
err := copyFile(u.confName, filepath.Join(u.updateDir, "AdGuardHome.yaml"))
|
||||
if err != nil {
|
||||
return fmt.Errorf("copyFile() failed: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command(u.updateExeName, "--check-config")
|
||||
err = cmd.Run()
|
||||
if err != nil || cmd.ProcessState.ExitCode() != 0 {
|
||||
return fmt.Errorf("exec.Command(): %s %d", err, cmd.ProcessState.ExitCode())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *Updater) backup() error {
|
||||
// backup makes a backup of the current configuration and supporting files. It
|
||||
// ignores the configuration file if firstRun is true.
|
||||
func (u *Updater) backup(firstRun bool) (err error) {
|
||||
log.Debug("updater: backing up current configuration")
|
||||
_ = os.Mkdir(u.backupDir, 0o755)
|
||||
err := copyFile(u.confName, filepath.Join(u.backupDir, "AdGuardHome.yaml"))
|
||||
if err != nil {
|
||||
return fmt.Errorf("copyFile() failed: %w", err)
|
||||
if !firstRun {
|
||||
err = copyFile(u.confName, filepath.Join(u.backupDir, "AdGuardHome.yaml"))
|
||||
if err != nil {
|
||||
return fmt.Errorf("copyFile() failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
wd := u.workDir
|
||||
err = copySupportingFiles(u.unpackedFiles, wd, u.backupDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("copySupportingFiles(%s, %s) failed: %s",
|
||||
wd, u.backupDir, err)
|
||||
return fmt.Errorf("copySupportingFiles(%s, %s) failed: %s", wd, u.backupDir, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// replace moves the current executable with the updated one and also copies the
|
||||
// supporting files.
|
||||
func (u *Updater) replace() error {
|
||||
err := copySupportingFiles(u.unpackedFiles, u.updateDir, u.workDir)
|
||||
if err != nil {
|
||||
@@ -287,6 +307,7 @@ func (u *Updater) replace() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// clean removes the temporary directory itself and all it's contents.
|
||||
func (u *Updater) clean() {
|
||||
_ = os.RemoveAll(u.updateDir)
|
||||
}
|
||||
@@ -297,9 +318,9 @@ func (u *Updater) clean() {
|
||||
const MaxPackageFileSize = 32 * 1024 * 1024
|
||||
|
||||
// Download package file and save it to disk
|
||||
func (u *Updater) downloadPackageFile(url, filename string) (err error) {
|
||||
func (u *Updater) downloadPackageFile() (err error) {
|
||||
var resp *http.Response
|
||||
resp, err = u.client.Get(url)
|
||||
resp, err = u.client.Get(u.packageURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("http request failed: %w", err)
|
||||
}
|
||||
@@ -321,7 +342,7 @@ func (u *Updater) downloadPackageFile(url, filename string) (err error) {
|
||||
_ = os.Mkdir(u.updateDir, 0o755)
|
||||
|
||||
log.Debug("updater: saving package to file")
|
||||
err = os.WriteFile(filename, body, 0o644)
|
||||
err = os.WriteFile(u.packageName, body, 0o644)
|
||||
if err != nil {
|
||||
return fmt.Errorf("os.WriteFile() failed: %w", err)
|
||||
}
|
||||
|
||||
@@ -136,10 +136,10 @@ func TestUpdate(t *testing.T) {
|
||||
u.packageURL = fakeURL.String()
|
||||
|
||||
require.NoError(t, u.prepare(exePath))
|
||||
require.NoError(t, u.downloadPackageFile(u.packageURL, u.packageName))
|
||||
require.NoError(t, u.downloadPackageFile())
|
||||
require.NoError(t, u.unpack())
|
||||
// require.NoError(t, u.check())
|
||||
require.NoError(t, u.backup())
|
||||
require.NoError(t, u.backup(false))
|
||||
require.NoError(t, u.replace())
|
||||
|
||||
u.clean()
|
||||
@@ -215,10 +215,10 @@ func TestUpdateWindows(t *testing.T) {
|
||||
u.packageURL = fakeURL.String()
|
||||
|
||||
require.NoError(t, u.prepare(exePath))
|
||||
require.NoError(t, u.downloadPackageFile(u.packageURL, u.packageName))
|
||||
require.NoError(t, u.downloadPackageFile())
|
||||
require.NoError(t, u.unpack())
|
||||
// assert.Nil(t, u.check())
|
||||
require.NoError(t, u.backup())
|
||||
require.NoError(t, u.backup(false))
|
||||
require.NoError(t, u.replace())
|
||||
|
||||
u.clean()
|
||||
|
||||
Reference in New Issue
Block a user