Merge branch 'master' into 4535-list-services
This commit is contained in:
@@ -1,32 +1,45 @@
|
||||
// Package aghalg contains common generic algorithms and data structures.
|
||||
//
|
||||
// TODO(a.garipov): Update to use type parameters in Go 1.18.
|
||||
// TODO(a.garipov): Move parts of this into golibs.
|
||||
package aghalg
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"golang.org/x/exp/constraints"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
// comparable is an alias for interface{}. Values passed as arguments of this
|
||||
// type alias must be comparable.
|
||||
//
|
||||
// TODO(a.garipov): Remove in Go 1.18.
|
||||
type comparable = interface{}
|
||||
// Coalesce returns the first non-zero value. It is named after the function
|
||||
// COALESCE in SQL. If values or all its elements are empty, it returns a zero
|
||||
// value.
|
||||
func Coalesce[T comparable](values ...T) (res T) {
|
||||
var zero T
|
||||
for _, v := range values {
|
||||
if v != zero {
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
return zero
|
||||
}
|
||||
|
||||
// UniqChecker allows validating uniqueness of comparable items.
|
||||
type UniqChecker map[comparable]int64
|
||||
//
|
||||
// TODO(a.garipov): The Ordered constraint is only really necessary in Validate.
|
||||
// Consider ways of making this constraint comparable instead.
|
||||
type UniqChecker[T constraints.Ordered] map[T]int64
|
||||
|
||||
// Add adds a value to the validator. v must not be nil.
|
||||
func (uc UniqChecker) Add(elems ...comparable) {
|
||||
func (uc UniqChecker[T]) Add(elems ...T) {
|
||||
for _, e := range elems {
|
||||
uc[e]++
|
||||
}
|
||||
}
|
||||
|
||||
// Merge returns a checker containing data from both uc and other.
|
||||
func (uc UniqChecker) Merge(other UniqChecker) (merged UniqChecker) {
|
||||
merged = make(UniqChecker, len(uc)+len(other))
|
||||
func (uc UniqChecker[T]) Merge(other UniqChecker[T]) (merged UniqChecker[T]) {
|
||||
merged = make(UniqChecker[T], len(uc)+len(other))
|
||||
for elem, num := range uc {
|
||||
merged[elem] += num
|
||||
}
|
||||
@@ -39,10 +52,8 @@ func (uc UniqChecker) Merge(other UniqChecker) (merged UniqChecker) {
|
||||
}
|
||||
|
||||
// Validate returns an error enumerating all elements that aren't unique.
|
||||
// isBefore is an optional sorting function to make the error message
|
||||
// deterministic.
|
||||
func (uc UniqChecker) Validate(isBefore func(a, b comparable) (less bool)) (err error) {
|
||||
var dup []comparable
|
||||
func (uc UniqChecker[T]) Validate() (err error) {
|
||||
var dup []T
|
||||
for elem, num := range uc {
|
||||
if num > 1 {
|
||||
dup = append(dup, elem)
|
||||
@@ -53,23 +64,7 @@ func (uc UniqChecker) Validate(isBefore func(a, b comparable) (less bool)) (err
|
||||
return nil
|
||||
}
|
||||
|
||||
if isBefore != nil {
|
||||
sort.Slice(dup, func(i, j int) (less bool) {
|
||||
return isBefore(dup[i], dup[j])
|
||||
})
|
||||
}
|
||||
slices.Sort(dup)
|
||||
|
||||
return fmt.Errorf("duplicated values: %v", dup)
|
||||
}
|
||||
|
||||
// IntIsBefore is a helper sort function for UniqChecker.Validate.
|
||||
// a and b must be of type int.
|
||||
func IntIsBefore(a, b comparable) (less bool) {
|
||||
return a.(int) < b.(int)
|
||||
}
|
||||
|
||||
// StringIsBefore is a helper sort function for UniqChecker.Validate.
|
||||
// a and b must be of type string.
|
||||
func StringIsBefore(a, b comparable) (less bool) {
|
||||
return a.(string) < b.(string)
|
||||
}
|
||||
|
||||
67
internal/aghalg/nullbool.go
Normal file
67
internal/aghalg/nullbool.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package aghalg
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// NullBool is a nullable boolean. Use these in JSON requests and responses
|
||||
// instead of pointers to bool.
|
||||
type NullBool uint8
|
||||
|
||||
// NullBool values
|
||||
const (
|
||||
NBNull NullBool = iota
|
||||
NBTrue
|
||||
NBFalse
|
||||
)
|
||||
|
||||
// String implements the fmt.Stringer interface for NullBool.
|
||||
func (nb NullBool) String() (s string) {
|
||||
switch nb {
|
||||
case NBNull:
|
||||
return "null"
|
||||
case NBTrue:
|
||||
return "true"
|
||||
case NBFalse:
|
||||
return "false"
|
||||
}
|
||||
|
||||
return fmt.Sprintf("!invalid NullBool %d", uint8(nb))
|
||||
}
|
||||
|
||||
// BoolToNullBool converts a bool into a NullBool.
|
||||
func BoolToNullBool(cond bool) (nb NullBool) {
|
||||
if cond {
|
||||
return NBTrue
|
||||
}
|
||||
|
||||
return NBFalse
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ json.Marshaler = NBNull
|
||||
|
||||
// MarshalJSON implements the json.Marshaler interface for NullBool.
|
||||
func (nb NullBool) MarshalJSON() (b []byte, err error) {
|
||||
return []byte(nb.String()), nil
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ json.Unmarshaler = (*NullBool)(nil)
|
||||
|
||||
// UnmarshalJSON implements the json.Unmarshaler interface for *NullBool.
|
||||
func (nb *NullBool) UnmarshalJSON(b []byte) (err error) {
|
||||
if len(b) == 0 || bytes.Equal(b, []byte("null")) {
|
||||
*nb = NBNull
|
||||
} else if bytes.Equal(b, []byte("true")) {
|
||||
*nb = NBTrue
|
||||
} else if bytes.Equal(b, []byte("false")) {
|
||||
*nb = NBFalse
|
||||
} else {
|
||||
return fmt.Errorf("unmarshalling json data into aghalg.NullBool: bad value %q", b)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
113
internal/aghalg/nullbool_test.go
Normal file
113
internal/aghalg/nullbool_test.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package aghalg_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNullBool_MarshalJSON(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
wantErrMsg string
|
||||
want []byte
|
||||
in aghalg.NullBool
|
||||
}{{
|
||||
name: "null",
|
||||
wantErrMsg: "",
|
||||
want: []byte("null"),
|
||||
in: aghalg.NBNull,
|
||||
}, {
|
||||
name: "true",
|
||||
wantErrMsg: "",
|
||||
want: []byte("true"),
|
||||
in: aghalg.NBTrue,
|
||||
}, {
|
||||
name: "false",
|
||||
wantErrMsg: "",
|
||||
want: []byte("false"),
|
||||
in: aghalg.NBFalse,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got, err := tc.in.MarshalJSON()
|
||||
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
||||
|
||||
assert.Equal(t, tc.want, got)
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("json", func(t *testing.T) {
|
||||
in := &struct {
|
||||
A aghalg.NullBool
|
||||
}{
|
||||
A: aghalg.NBTrue,
|
||||
}
|
||||
|
||||
got, err := json.Marshal(in)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, []byte(`{"A":true}`), got)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNullBool_UnmarshalJSON(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
wantErrMsg string
|
||||
data []byte
|
||||
want aghalg.NullBool
|
||||
}{{
|
||||
name: "empty",
|
||||
wantErrMsg: "",
|
||||
data: []byte{},
|
||||
want: aghalg.NBNull,
|
||||
}, {
|
||||
name: "null",
|
||||
wantErrMsg: "",
|
||||
data: []byte("null"),
|
||||
want: aghalg.NBNull,
|
||||
}, {
|
||||
name: "true",
|
||||
wantErrMsg: "",
|
||||
data: []byte("true"),
|
||||
want: aghalg.NBTrue,
|
||||
}, {
|
||||
name: "false",
|
||||
wantErrMsg: "",
|
||||
data: []byte("false"),
|
||||
want: aghalg.NBFalse,
|
||||
}, {
|
||||
name: "invalid",
|
||||
wantErrMsg: `unmarshalling json data into aghalg.NullBool: bad value "invalid"`,
|
||||
data: []byte("invalid"),
|
||||
want: aghalg.NBNull,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var got aghalg.NullBool
|
||||
err := got.UnmarshalJSON(tc.data)
|
||||
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
||||
|
||||
assert.Equal(t, tc.want, got)
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("json", func(t *testing.T) {
|
||||
want := aghalg.NBTrue
|
||||
var got struct {
|
||||
A aghalg.NullBool
|
||||
}
|
||||
|
||||
err := json.Unmarshal([]byte(`{"A":true}`), &got)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, want, got.A)
|
||||
})
|
||||
}
|
||||
@@ -9,6 +9,12 @@ import (
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// RegisterFunc is the function that sets the handler to handle the URL for the
|
||||
// method.
|
||||
//
|
||||
// TODO(e.burkov, a.garipov): Get rid of it.
|
||||
type RegisterFunc func(method, url string, handler http.HandlerFunc)
|
||||
|
||||
// OK responds with word OK.
|
||||
func OK(w http.ResponseWriter) {
|
||||
if _, err := io.WriteString(w, "OK\n"); err != nil {
|
||||
@@ -17,7 +23,7 @@ func OK(w http.ResponseWriter) {
|
||||
}
|
||||
|
||||
// Error writes formatted message to w and also logs it.
|
||||
func Error(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) {
|
||||
func Error(r *http.Request, w http.ResponseWriter, code int, format string, args ...any) {
|
||||
text := fmt.Sprintf(format, args...)
|
||||
log.Error("%s %s: %s", r.Method, r.URL, text)
|
||||
http.Error(w, text, code)
|
||||
|
||||
@@ -198,7 +198,7 @@ func (hc *HostsContainer) Close() (err error) {
|
||||
}
|
||||
|
||||
// Upd returns the channel into which the updates are sent. The receivable
|
||||
// map's values are guaranteed to be of type of *stringutil.Set.
|
||||
// map's values are guaranteed to be of type of *HostsRecord.
|
||||
func (hc *HostsContainer) Upd() (updates <-chan *netutil.IPMap) {
|
||||
return hc.updates
|
||||
}
|
||||
@@ -290,7 +290,7 @@ func (hp *hostsParser) parseFile(r io.Reader) (patterns []string, cont bool, err
|
||||
continue
|
||||
}
|
||||
|
||||
hp.addPairs(ip, hosts)
|
||||
hp.addRecord(ip, hosts)
|
||||
}
|
||||
|
||||
return nil, true, s.Err()
|
||||
@@ -335,39 +335,66 @@ func (hp *hostsParser) parseLine(line string) (ip net.IP, hosts []string) {
|
||||
return ip, hosts
|
||||
}
|
||||
|
||||
// addPair puts the pair of ip and host to the rules builder if needed. For
|
||||
// each ip the first member of hosts will become the main one.
|
||||
func (hp *hostsParser) addPairs(ip net.IP, hosts []string) {
|
||||
// HostsRecord represents a single hosts file record.
|
||||
type HostsRecord struct {
|
||||
Aliases *stringutil.Set
|
||||
Canonical string
|
||||
}
|
||||
|
||||
// Equal returns true if all fields of rec are equal to field in other or they
|
||||
// both are nil.
|
||||
func (rec *HostsRecord) Equal(other *HostsRecord) (ok bool) {
|
||||
if rec == nil {
|
||||
return other == nil
|
||||
}
|
||||
|
||||
return rec.Canonical == other.Canonical && rec.Aliases.Equal(other.Aliases)
|
||||
}
|
||||
|
||||
// addRecord puts the record for the IP address to the rules builder if needed.
|
||||
// The first host is considered to be the canonical name for the IP address.
|
||||
// hosts must have at least one name.
|
||||
func (hp *hostsParser) addRecord(ip net.IP, hosts []string) {
|
||||
line := strings.Join(append([]string{ip.String()}, hosts...), " ")
|
||||
|
||||
var rec *HostsRecord
|
||||
v, ok := hp.table.Get(ip)
|
||||
if !ok {
|
||||
// This ip is added at the first time.
|
||||
v = stringutil.NewSet()
|
||||
hp.table.Set(ip, v)
|
||||
rec = &HostsRecord{
|
||||
Aliases: stringutil.NewSet(),
|
||||
}
|
||||
|
||||
rec.Canonical, hosts = hosts[0], hosts[1:]
|
||||
hp.addRules(ip, rec.Canonical, line)
|
||||
hp.table.Set(ip, rec)
|
||||
} else {
|
||||
rec, ok = v.(*HostsRecord)
|
||||
if !ok {
|
||||
log.Error("%s: adding pairs: unexpected type %T", hostsContainerPref, v)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
var set *stringutil.Set
|
||||
set, ok = v.(*stringutil.Set)
|
||||
if !ok {
|
||||
log.Debug("%s: adding pairs: unexpected value type %T", hostsContainerPref, v)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
processed := strings.Join(append([]string{ip.String()}, hosts...), " ")
|
||||
for _, h := range hosts {
|
||||
if set.Has(h) {
|
||||
for _, host := range hosts {
|
||||
if rec.Canonical == host || rec.Aliases.Has(host) {
|
||||
continue
|
||||
}
|
||||
|
||||
set.Add(h)
|
||||
rec.Aliases.Add(host)
|
||||
|
||||
rule, rulePtr := hp.writeRules(h, ip)
|
||||
hp.translations[rule], hp.translations[rulePtr] = processed, processed
|
||||
|
||||
log.Debug("%s: added ip-host pair %q-%q", hostsContainerPref, ip, h)
|
||||
hp.addRules(ip, host, line)
|
||||
}
|
||||
}
|
||||
|
||||
// addRules adds rules and rule translations for the line.
|
||||
func (hp *hostsParser) addRules(ip net.IP, host, line string) {
|
||||
rule, rulePtr := hp.writeRules(host, ip)
|
||||
hp.translations[rule], hp.translations[rulePtr] = line, line
|
||||
|
||||
log.Debug("%s: added ip-host pair %q-%q", hostsContainerPref, ip, host)
|
||||
}
|
||||
|
||||
// writeRules writes the actual rule for the qtype and the PTR for the host-ip
|
||||
// pair into internal builders.
|
||||
func (hp *hostsParser) writeRules(host string, ip net.IP) (rule, rulePtr string) {
|
||||
@@ -417,6 +444,7 @@ func (hp *hostsParser) writeRules(host string, ip net.IP) (rule, rulePtr string)
|
||||
}
|
||||
|
||||
// equalSet returns true if the internal hosts table just parsed equals target.
|
||||
// target's values must be of type *HostsRecord.
|
||||
func (hp *hostsParser) equalSet(target *netutil.IPMap) (ok bool) {
|
||||
if target == nil {
|
||||
// hp.table shouldn't appear nil since it's initialized on each refresh.
|
||||
@@ -427,22 +455,35 @@ func (hp *hostsParser) equalSet(target *netutil.IPMap) (ok bool) {
|
||||
return false
|
||||
}
|
||||
|
||||
hp.table.Range(func(ip net.IP, b interface{}) (cont bool) {
|
||||
// ok is set to true if the target doesn't contain ip or if the
|
||||
// appropriate hosts set isn't equal to the checked one.
|
||||
if a, hasIP := target.Get(ip); !hasIP {
|
||||
ok = true
|
||||
} else if hosts, aok := a.(*stringutil.Set); aok {
|
||||
ok = !hosts.Equal(b.(*stringutil.Set))
|
||||
hp.table.Range(func(ip net.IP, recVal any) (cont bool) {
|
||||
var targetVal any
|
||||
targetVal, ok = target.Get(ip)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// Continue only if maps has no discrepancies.
|
||||
return !ok
|
||||
var rec *HostsRecord
|
||||
rec, ok = recVal.(*HostsRecord)
|
||||
if !ok {
|
||||
log.Error("%s: comparing: unexpected type %T", hostsContainerPref, recVal)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
var targetRec *HostsRecord
|
||||
targetRec, ok = targetVal.(*HostsRecord)
|
||||
if !ok {
|
||||
log.Error("%s: comparing: target: unexpected type %T", hostsContainerPref, targetVal)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
ok = rec.Equal(targetRec)
|
||||
|
||||
return ok
|
||||
})
|
||||
|
||||
// Return true if every value from the IP map has no discrepancies with the
|
||||
// appropriate one from the target.
|
||||
return !ok
|
||||
return ok
|
||||
}
|
||||
|
||||
// sendUpd tries to send the parsed data to the ch.
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/AdguardTeam/urlfilter"
|
||||
@@ -159,31 +160,47 @@ func TestHostsContainer_refresh(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
testutil.CleanupAndRequireSuccess(t, hc.Close)
|
||||
|
||||
checkRefresh := func(t *testing.T, wantHosts *stringutil.Set) {
|
||||
upd, ok := <-hc.Upd()
|
||||
require.True(t, ok)
|
||||
require.NotNil(t, upd)
|
||||
checkRefresh := func(t *testing.T, want *HostsRecord) {
|
||||
t.Helper()
|
||||
|
||||
var ok bool
|
||||
var upd *netutil.IPMap
|
||||
select {
|
||||
case upd, ok = <-hc.Upd():
|
||||
require.True(t, ok)
|
||||
require.NotNil(t, upd)
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Fatal("did not receive after 1s")
|
||||
}
|
||||
|
||||
assert.Equal(t, 1, upd.Len())
|
||||
|
||||
v, ok := upd.Get(ip)
|
||||
require.True(t, ok)
|
||||
|
||||
var set *stringutil.Set
|
||||
set, ok = v.(*stringutil.Set)
|
||||
require.True(t, ok)
|
||||
require.IsType(t, (*HostsRecord)(nil), v)
|
||||
|
||||
assert.True(t, set.Equal(wantHosts))
|
||||
rec, _ := v.(*HostsRecord)
|
||||
require.NotNil(t, rec)
|
||||
|
||||
assert.Truef(t, rec.Equal(want), "%+v != %+v", rec, want)
|
||||
}
|
||||
|
||||
t.Run("initial_refresh", func(t *testing.T) {
|
||||
checkRefresh(t, stringutil.NewSet("hostname"))
|
||||
checkRefresh(t, &HostsRecord{
|
||||
Aliases: stringutil.NewSet(),
|
||||
Canonical: "hostname",
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("second_refresh", func(t *testing.T) {
|
||||
testFS["dir/file2"] = &fstest.MapFile{Data: []byte(ipStr + ` alias` + nl)}
|
||||
eventsCh <- event{}
|
||||
checkRefresh(t, stringutil.NewSet("hostname", "alias"))
|
||||
|
||||
checkRefresh(t, &HostsRecord{
|
||||
Aliases: stringutil.NewSet("alias"),
|
||||
Canonical: "hostname",
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("double_refresh", func(t *testing.T) {
|
||||
@@ -363,10 +380,15 @@ func TestHostsContainer(t *testing.T) {
|
||||
require.NoError(t, fstest.TestFS(testdata, "etc_hosts"))
|
||||
|
||||
testCases := []struct {
|
||||
want []*rules.DNSRewrite
|
||||
name string
|
||||
req *urlfilter.DNSRequest
|
||||
name string
|
||||
want []*rules.DNSRewrite
|
||||
}{{
|
||||
req: &urlfilter.DNSRequest{
|
||||
Hostname: "simplehost",
|
||||
DNSType: dns.TypeA,
|
||||
},
|
||||
name: "simple",
|
||||
want: []*rules.DNSRewrite{{
|
||||
RCode: dns.RcodeSuccess,
|
||||
Value: net.IPv4(1, 0, 0, 1),
|
||||
@@ -376,27 +398,12 @@ func TestHostsContainer(t *testing.T) {
|
||||
Value: net.ParseIP("::1"),
|
||||
RRType: dns.TypeAAAA,
|
||||
}},
|
||||
name: "simple",
|
||||
req: &urlfilter.DNSRequest{
|
||||
Hostname: "simplehost",
|
||||
DNSType: dns.TypeA,
|
||||
},
|
||||
}, {
|
||||
want: []*rules.DNSRewrite{{
|
||||
RCode: dns.RcodeSuccess,
|
||||
Value: net.IPv4(1, 0, 0, 0),
|
||||
RRType: dns.TypeA,
|
||||
}, {
|
||||
RCode: dns.RcodeSuccess,
|
||||
Value: net.ParseIP("::"),
|
||||
RRType: dns.TypeAAAA,
|
||||
}},
|
||||
name: "hello_alias",
|
||||
req: &urlfilter.DNSRequest{
|
||||
Hostname: "hello.world",
|
||||
DNSType: dns.TypeA,
|
||||
},
|
||||
}, {
|
||||
name: "hello_alias",
|
||||
want: []*rules.DNSRewrite{{
|
||||
RCode: dns.RcodeSuccess,
|
||||
Value: net.IPv4(1, 0, 0, 0),
|
||||
@@ -406,26 +413,41 @@ func TestHostsContainer(t *testing.T) {
|
||||
Value: net.ParseIP("::"),
|
||||
RRType: dns.TypeAAAA,
|
||||
}},
|
||||
name: "other_line_alias",
|
||||
}, {
|
||||
req: &urlfilter.DNSRequest{
|
||||
Hostname: "hello.world.again",
|
||||
DNSType: dns.TypeA,
|
||||
},
|
||||
name: "other_line_alias",
|
||||
want: []*rules.DNSRewrite{{
|
||||
RCode: dns.RcodeSuccess,
|
||||
Value: net.IPv4(1, 0, 0, 0),
|
||||
RRType: dns.TypeA,
|
||||
}, {
|
||||
RCode: dns.RcodeSuccess,
|
||||
Value: net.ParseIP("::"),
|
||||
RRType: dns.TypeAAAA,
|
||||
}},
|
||||
}, {
|
||||
want: []*rules.DNSRewrite{},
|
||||
name: "hello_subdomain",
|
||||
req: &urlfilter.DNSRequest{
|
||||
Hostname: "say.hello",
|
||||
DNSType: dns.TypeA,
|
||||
},
|
||||
}, {
|
||||
name: "hello_subdomain",
|
||||
want: []*rules.DNSRewrite{},
|
||||
name: "hello_alias_subdomain",
|
||||
}, {
|
||||
req: &urlfilter.DNSRequest{
|
||||
Hostname: "say.hello.world",
|
||||
DNSType: dns.TypeA,
|
||||
},
|
||||
name: "hello_alias_subdomain",
|
||||
want: []*rules.DNSRewrite{},
|
||||
}, {
|
||||
req: &urlfilter.DNSRequest{
|
||||
Hostname: "for.testing",
|
||||
DNSType: dns.TypeA,
|
||||
},
|
||||
name: "lots_of_aliases",
|
||||
want: []*rules.DNSRewrite{{
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeA,
|
||||
@@ -435,37 +457,37 @@ func TestHostsContainer(t *testing.T) {
|
||||
RRType: dns.TypeAAAA,
|
||||
Value: net.ParseIP("::2"),
|
||||
}},
|
||||
name: "lots_of_aliases",
|
||||
req: &urlfilter.DNSRequest{
|
||||
Hostname: "for.testing",
|
||||
DNSType: dns.TypeA,
|
||||
},
|
||||
}, {
|
||||
req: &urlfilter.DNSRequest{
|
||||
Hostname: "1.0.0.1.in-addr.arpa",
|
||||
DNSType: dns.TypePTR,
|
||||
},
|
||||
name: "reverse",
|
||||
want: []*rules.DNSRewrite{{
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypePTR,
|
||||
Value: "simplehost.",
|
||||
}},
|
||||
name: "reverse",
|
||||
req: &urlfilter.DNSRequest{
|
||||
Hostname: "1.0.0.1.in-addr.arpa",
|
||||
DNSType: dns.TypePTR,
|
||||
},
|
||||
}, {
|
||||
want: []*rules.DNSRewrite{},
|
||||
name: "non-existing",
|
||||
req: &urlfilter.DNSRequest{
|
||||
Hostname: "nonexisting",
|
||||
Hostname: "nonexistent.example",
|
||||
DNSType: dns.TypeA,
|
||||
},
|
||||
name: "non-existing",
|
||||
want: []*rules.DNSRewrite{},
|
||||
}, {
|
||||
want: nil,
|
||||
name: "bad_type",
|
||||
req: &urlfilter.DNSRequest{
|
||||
Hostname: "1.0.0.1.in-addr.arpa",
|
||||
DNSType: dns.TypeSRV,
|
||||
},
|
||||
name: "bad_type",
|
||||
want: nil,
|
||||
}, {
|
||||
req: &urlfilter.DNSRequest{
|
||||
Hostname: "domain",
|
||||
DNSType: dns.TypeA,
|
||||
},
|
||||
name: "issue_4216_4_6",
|
||||
want: []*rules.DNSRewrite{{
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeA,
|
||||
@@ -475,12 +497,12 @@ func TestHostsContainer(t *testing.T) {
|
||||
RRType: dns.TypeAAAA,
|
||||
Value: net.ParseIP("::42"),
|
||||
}},
|
||||
name: "issue_4216_4_6",
|
||||
}, {
|
||||
req: &urlfilter.DNSRequest{
|
||||
Hostname: "domain",
|
||||
Hostname: "domain4",
|
||||
DNSType: dns.TypeA,
|
||||
},
|
||||
}, {
|
||||
name: "issue_4216_4",
|
||||
want: []*rules.DNSRewrite{{
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeA,
|
||||
@@ -490,12 +512,12 @@ func TestHostsContainer(t *testing.T) {
|
||||
RRType: dns.TypeA,
|
||||
Value: net.IPv4(1, 3, 5, 7),
|
||||
}},
|
||||
name: "issue_4216_4",
|
||||
req: &urlfilter.DNSRequest{
|
||||
Hostname: "domain4",
|
||||
DNSType: dns.TypeA,
|
||||
},
|
||||
}, {
|
||||
req: &urlfilter.DNSRequest{
|
||||
Hostname: "domain6",
|
||||
DNSType: dns.TypeAAAA,
|
||||
},
|
||||
name: "issue_4216_6",
|
||||
want: []*rules.DNSRewrite{{
|
||||
RCode: dns.RcodeSuccess,
|
||||
RRType: dns.TypeAAAA,
|
||||
@@ -505,11 +527,6 @@ func TestHostsContainer(t *testing.T) {
|
||||
RRType: dns.TypeAAAA,
|
||||
Value: net.ParseIP("::31"),
|
||||
}},
|
||||
name: "issue_4216_6",
|
||||
req: &urlfilter.DNSRequest{
|
||||
Hostname: "domain6",
|
||||
DNSType: dns.TypeAAAA,
|
||||
},
|
||||
}}
|
||||
|
||||
stubWatcher := aghtest.FSWatcher{
|
||||
|
||||
@@ -154,10 +154,13 @@ func GetValidNetInterfacesForWeb() (netIfaces []*NetInterface, err error) {
|
||||
return netIfaces, nil
|
||||
}
|
||||
|
||||
// GetInterfaceByIP returns the name of interface containing provided ip.
|
||||
// InterfaceByIP returns the name of the interface bound to ip.
|
||||
//
|
||||
// TODO(e.burkov): See TODO on GetValidInterfacesForWeb.
|
||||
func GetInterfaceByIP(ip net.IP) string {
|
||||
// TODO(a.garipov, e.burkov): This function is technically incorrect, since one
|
||||
// IP address can be shared by multiple interfaces in some configurations.
|
||||
//
|
||||
// TODO(e.burkov): See TODO on GetValidNetInterfacesForWeb.
|
||||
func InterfaceByIP(ip net.IP) (ifaceName string) {
|
||||
ifaces, err := GetValidNetInterfacesForWeb()
|
||||
if err != nil {
|
||||
return ""
|
||||
@@ -177,7 +180,7 @@ func GetInterfaceByIP(ip net.IP) string {
|
||||
// GetSubnet returns pointer to net.IPNet for the specified interface or nil if
|
||||
// the search fails.
|
||||
//
|
||||
// TODO(e.burkov): See TODO on GetValidInterfacesForWeb.
|
||||
// TODO(e.burkov): See TODO on GetValidNetInterfacesForWeb.
|
||||
func GetSubnet(ifaceName string) *net.IPNet {
|
||||
netIfaces, err := GetValidNetInterfacesForWeb()
|
||||
if err != nil {
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/google/renameio/maybe"
|
||||
"golang.org/x/sys/unix"
|
||||
@@ -22,17 +23,27 @@ import (
|
||||
const dhcpcdConf = "etc/dhcpcd.conf"
|
||||
|
||||
func canBindPrivilegedPorts() (can bool, err error) {
|
||||
cnbs, err := unix.PrctlRetInt(
|
||||
res, err := unix.PrctlRetInt(
|
||||
unix.PR_CAP_AMBIENT,
|
||||
unix.PR_CAP_AMBIENT_IS_SET,
|
||||
unix.CAP_NET_BIND_SERVICE,
|
||||
0,
|
||||
0,
|
||||
)
|
||||
if err != nil {
|
||||
if errors.Is(err, unix.EINVAL) {
|
||||
// Older versions of Linux kernel do not support this. Print a
|
||||
// warning and check admin rights.
|
||||
log.Info("warning: cannot check capability cap_net_bind_service: %s", err)
|
||||
} else {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
// Don't check the error because it's always nil on Linux.
|
||||
adm, _ := aghos.HaveAdminRights()
|
||||
|
||||
return cnbs == 1 || adm, err
|
||||
return res == 1 || adm, nil
|
||||
}
|
||||
|
||||
// dhcpcdStaticConfig checks if interface is configured by /etc/dhcpcd.conf to
|
||||
|
||||
@@ -132,7 +132,7 @@ func TestGatewayIP(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetInterfaceByIP(t *testing.T) {
|
||||
func TestInterfaceByIP(t *testing.T) {
|
||||
ifaces, err := GetValidNetInterfacesForWeb()
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, ifaces)
|
||||
@@ -142,7 +142,7 @@ func TestGetInterfaceByIP(t *testing.T) {
|
||||
require.NotEmpty(t, iface.Addresses)
|
||||
|
||||
for _, ip := range iface.Addresses {
|
||||
ifaceName := GetInterfaceByIP(ip)
|
||||
ifaceName := InterfaceByIP(ip)
|
||||
require.Equal(t, iface.Name, ifaceName)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -19,7 +19,7 @@ type SystemResolvers interface {
|
||||
}
|
||||
|
||||
// NewSystemResolvers returns a SystemResolvers with the cache refresh rate
|
||||
// defined by refreshIvl. It disables auto-resfreshing if refreshIvl is 0. If
|
||||
// defined by refreshIvl. It disables auto-refreshing if refreshIvl is 0. If
|
||||
// nil is passed for hostGenFunc, the default generator will be used.
|
||||
func NewSystemResolvers(
|
||||
hostGenFunc HostGenFunc,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package aghos
|
||||
package aghos_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
57
internal/aghos/filewalker_internal_test.go
Normal file
57
internal/aghos/filewalker_internal_test.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package aghos
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"path"
|
||||
"testing"
|
||||
"testing/fstest"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// errFS is an fs.FS implementation, method Open of which always returns
|
||||
// errFSOpen.
|
||||
type errFS struct{}
|
||||
|
||||
// errFSOpen is returned from errGlobFS.Open.
|
||||
const errFSOpen errors.Error = "test open error"
|
||||
|
||||
// Open implements the fs.FS interface for *errGlobFS. fsys is always nil and
|
||||
// err is always errFSOpen.
|
||||
func (efs *errFS) Open(name string) (fsys fs.File, err error) {
|
||||
return nil, errFSOpen
|
||||
}
|
||||
|
||||
func TestWalkerFunc_CheckFile(t *testing.T) {
|
||||
emptyFS := fstest.MapFS{}
|
||||
|
||||
t.Run("non-existing", func(t *testing.T) {
|
||||
_, ok, err := checkFile(emptyFS, nil, "lol")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, ok)
|
||||
})
|
||||
|
||||
t.Run("invalid_argument", func(t *testing.T) {
|
||||
_, ok, err := checkFile(&errFS{}, nil, "")
|
||||
require.ErrorIs(t, err, errFSOpen)
|
||||
|
||||
assert.False(t, ok)
|
||||
})
|
||||
|
||||
t.Run("ignore_dirs", func(t *testing.T) {
|
||||
const dirName = "dir"
|
||||
|
||||
testFS := fstest.MapFS{
|
||||
path.Join(dirName, "file"): &fstest.MapFile{Data: []byte{}},
|
||||
}
|
||||
|
||||
patterns, ok, err := checkFile(testFS, nil, dirName)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Empty(t, patterns)
|
||||
assert.True(t, ok)
|
||||
})
|
||||
}
|
||||
@@ -1,13 +1,13 @@
|
||||
package aghos
|
||||
package aghos_test
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
"io/fs"
|
||||
"path"
|
||||
"testing"
|
||||
"testing/fstest"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -16,7 +16,7 @@ import (
|
||||
func TestFileWalker_Walk(t *testing.T) {
|
||||
const attribute = `000`
|
||||
|
||||
makeFileWalker := func(_ string) (fw FileWalker) {
|
||||
makeFileWalker := func(_ string) (fw aghos.FileWalker) {
|
||||
return func(r io.Reader) (patterns []string, cont bool, err error) {
|
||||
s := bufio.NewScanner(r)
|
||||
for s.Scan() {
|
||||
@@ -113,7 +113,7 @@ func TestFileWalker_Walk(t *testing.T) {
|
||||
f := fstest.MapFS{
|
||||
filename: &fstest.MapFile{Data: []byte("[]")},
|
||||
}
|
||||
ok, err := FileWalker(func(r io.Reader) (patterns []string, cont bool, err error) {
|
||||
ok, err := aghos.FileWalker(func(r io.Reader) (patterns []string, cont bool, err error) {
|
||||
s := bufio.NewScanner(r)
|
||||
for s.Scan() {
|
||||
patterns = append(patterns, s.Text())
|
||||
@@ -134,7 +134,7 @@ func TestFileWalker_Walk(t *testing.T) {
|
||||
"mockfile.txt": &fstest.MapFile{Data: []byte(`mockdata`)},
|
||||
}
|
||||
|
||||
ok, err := FileWalker(func(r io.Reader) (patterns []string, ok bool, err error) {
|
||||
ok, err := aghos.FileWalker(func(r io.Reader) (patterns []string, ok bool, err error) {
|
||||
return nil, true, rerr
|
||||
}).Walk(f, "*")
|
||||
require.ErrorIs(t, err, rerr)
|
||||
@@ -142,45 +142,3 @@ func TestFileWalker_Walk(t *testing.T) {
|
||||
assert.False(t, ok)
|
||||
})
|
||||
}
|
||||
|
||||
type errFS struct {
|
||||
fs.GlobFS
|
||||
}
|
||||
|
||||
const errErrFSOpen errors.Error = "this error is always returned"
|
||||
|
||||
func (efs *errFS) Open(name string) (fs.File, error) {
|
||||
return nil, errErrFSOpen
|
||||
}
|
||||
|
||||
func TestWalkerFunc_CheckFile(t *testing.T) {
|
||||
emptyFS := fstest.MapFS{}
|
||||
|
||||
t.Run("non-existing", func(t *testing.T) {
|
||||
_, ok, err := checkFile(emptyFS, nil, "lol")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, ok)
|
||||
})
|
||||
|
||||
t.Run("invalid_argument", func(t *testing.T) {
|
||||
_, ok, err := checkFile(&errFS{}, nil, "")
|
||||
require.ErrorIs(t, err, errErrFSOpen)
|
||||
|
||||
assert.False(t, ok)
|
||||
})
|
||||
|
||||
t.Run("ignore_dirs", func(t *testing.T) {
|
||||
const dirName = "dir"
|
||||
|
||||
testFS := fstest.MapFS{
|
||||
path.Join(dirName, "file"): &fstest.MapFile{Data: []byte{}},
|
||||
}
|
||||
|
||||
patterns, ok, err := checkFile(testFS, nil, dirName)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Empty(t, patterns)
|
||||
assert.True(t, ok)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
package aghtest
|
||||
|
||||
import (
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Exchanger is a mock aghnet.Exchanger implementation for tests.
|
||||
type Exchanger struct {
|
||||
Ups upstream.Upstream
|
||||
}
|
||||
|
||||
// Exchange implements aghnet.Exchanger interface for *Exchanger.
|
||||
func (e *Exchanger) Exchange(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
if e.Ups == nil {
|
||||
e.Ups = &TestErrUpstream{}
|
||||
}
|
||||
|
||||
return e.Ups.Exchange(req)
|
||||
}
|
||||
@@ -1,23 +0,0 @@
|
||||
package aghtest
|
||||
|
||||
// FSWatcher is a mock aghos.FSWatcher implementation to use in tests.
|
||||
type FSWatcher struct {
|
||||
OnEvents func() (e <-chan struct{})
|
||||
OnAdd func(name string) (err error)
|
||||
OnClose func() (err error)
|
||||
}
|
||||
|
||||
// Events implements the aghos.FSWatcher interface for *FSWatcher.
|
||||
func (w *FSWatcher) Events() (e <-chan struct{}) {
|
||||
return w.OnEvents()
|
||||
}
|
||||
|
||||
// Add implements the aghos.FSWatcher interface for *FSWatcher.
|
||||
func (w *FSWatcher) Add(name string) (err error) {
|
||||
return w.OnAdd(name)
|
||||
}
|
||||
|
||||
// Close implements the aghos.FSWatcher interface for *FSWatcher.
|
||||
func (w *FSWatcher) Close() (err error) {
|
||||
return w.OnClose()
|
||||
}
|
||||
135
internal/aghtest/interface.go
Normal file
135
internal/aghtest/interface.go
Normal file
@@ -0,0 +1,135 @@
|
||||
package aghtest
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"net"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Interface Mocks
|
||||
//
|
||||
// Keep entities in this file in alphabetic order.
|
||||
|
||||
// Standard Library
|
||||
|
||||
// type check
|
||||
var _ fs.FS = &FS{}
|
||||
|
||||
// FS is a mock [fs.FS] implementation for tests.
|
||||
type FS struct {
|
||||
OnOpen func(name string) (fs.File, error)
|
||||
}
|
||||
|
||||
// Open implements the [fs.FS] interface for *FS.
|
||||
func (fsys *FS) Open(name string) (fs.File, error) {
|
||||
return fsys.OnOpen(name)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ fs.GlobFS = &GlobFS{}
|
||||
|
||||
// GlobFS is a mock [fs.GlobFS] implementation for tests.
|
||||
type GlobFS struct {
|
||||
// FS is embedded here to avoid implementing all it's methods.
|
||||
FS
|
||||
OnGlob func(pattern string) ([]string, error)
|
||||
}
|
||||
|
||||
// Glob implements the [fs.GlobFS] interface for *GlobFS.
|
||||
func (fsys *GlobFS) Glob(pattern string) ([]string, error) {
|
||||
return fsys.OnGlob(pattern)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ fs.StatFS = &StatFS{}
|
||||
|
||||
// StatFS is a mock [fs.StatFS] implementation for tests.
|
||||
type StatFS struct {
|
||||
// FS is embedded here to avoid implementing all it's methods.
|
||||
FS
|
||||
OnStat func(name string) (fs.FileInfo, error)
|
||||
}
|
||||
|
||||
// Stat implements the [fs.StatFS] interface for *StatFS.
|
||||
func (fsys *StatFS) Stat(name string) (fs.FileInfo, error) {
|
||||
return fsys.OnStat(name)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ net.Listener = (*Listener)(nil)
|
||||
|
||||
// Listener is a mock [net.Listener] implementation for tests.
|
||||
type Listener struct {
|
||||
OnAccept func() (conn net.Conn, err error)
|
||||
OnAddr func() (addr net.Addr)
|
||||
OnClose func() (err error)
|
||||
}
|
||||
|
||||
// Accept implements the [net.Listener] interface for *Listener.
|
||||
func (l *Listener) Accept() (conn net.Conn, err error) {
|
||||
return l.OnAccept()
|
||||
}
|
||||
|
||||
// Addr implements the [net.Listener] interface for *Listener.
|
||||
func (l *Listener) Addr() (addr net.Addr) {
|
||||
return l.OnAddr()
|
||||
}
|
||||
|
||||
// Close implements the [net.Listener] interface for *Listener.
|
||||
func (l *Listener) Close() (err error) {
|
||||
return l.OnClose()
|
||||
}
|
||||
|
||||
// Module dnsproxy
|
||||
|
||||
// type check
|
||||
var _ upstream.Upstream = (*UpstreamMock)(nil)
|
||||
|
||||
// UpstreamMock is a mock [upstream.Upstream] implementation for tests.
|
||||
//
|
||||
// TODO(a.garipov): Replace with all uses of Upstream with UpstreamMock and
|
||||
// rename it to just Upstream.
|
||||
type UpstreamMock struct {
|
||||
OnAddress func() (addr string)
|
||||
OnExchange func(req *dns.Msg) (resp *dns.Msg, err error)
|
||||
}
|
||||
|
||||
// Address implements the [upstream.Upstream] interface for *UpstreamMock.
|
||||
func (u *UpstreamMock) Address() (addr string) {
|
||||
return u.OnAddress()
|
||||
}
|
||||
|
||||
// Exchange implements the [upstream.Upstream] interface for *UpstreamMock.
|
||||
func (u *UpstreamMock) Exchange(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
return u.OnExchange(req)
|
||||
}
|
||||
|
||||
// Module AdGuardHome
|
||||
|
||||
// type check
|
||||
var _ aghos.FSWatcher = (*FSWatcher)(nil)
|
||||
|
||||
// FSWatcher is a mock [aghos.FSWatcher] implementation for tests.
|
||||
type FSWatcher struct {
|
||||
OnEvents func() (e <-chan struct{})
|
||||
OnAdd func(name string) (err error)
|
||||
OnClose func() (err error)
|
||||
}
|
||||
|
||||
// Events implements the [aghos.FSWatcher] interface for *FSWatcher.
|
||||
func (w *FSWatcher) Events() (e <-chan struct{}) {
|
||||
return w.OnEvents()
|
||||
}
|
||||
|
||||
// Add implements the [aghos.FSWatcher] interface for *FSWatcher.
|
||||
func (w *FSWatcher) Add(name string) (err error) {
|
||||
return w.OnAdd(name)
|
||||
}
|
||||
|
||||
// Close implements the [aghos.FSWatcher] interface for *FSWatcher.
|
||||
func (w *FSWatcher) Close() (err error) {
|
||||
return w.OnClose()
|
||||
}
|
||||
9
internal/aghtest/interface_test.go
Normal file
9
internal/aghtest/interface_test.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package aghtest_test
|
||||
|
||||
import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
)
|
||||
|
||||
// type check
|
||||
var _ aghos.FSWatcher = (*aghtest.FSWatcher)(nil)
|
||||
@@ -1,46 +0,0 @@
|
||||
package aghtest
|
||||
|
||||
import "io/fs"
|
||||
|
||||
// type check
|
||||
var _ fs.FS = &FS{}
|
||||
|
||||
// FS is a mock fs.FS implementation to use in tests.
|
||||
type FS struct {
|
||||
OnOpen func(name string) (fs.File, error)
|
||||
}
|
||||
|
||||
// Open implements the fs.FS interface for *FS.
|
||||
func (fsys *FS) Open(name string) (fs.File, error) {
|
||||
return fsys.OnOpen(name)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ fs.StatFS = &StatFS{}
|
||||
|
||||
// StatFS is a mock fs.StatFS implementation to use in tests.
|
||||
type StatFS struct {
|
||||
// FS is embedded here to avoid implementing all it's methods.
|
||||
FS
|
||||
OnStat func(name string) (fs.FileInfo, error)
|
||||
}
|
||||
|
||||
// Stat implements the fs.StatFS interface for *StatFS.
|
||||
func (fsys *StatFS) Stat(name string) (fs.FileInfo, error) {
|
||||
return fsys.OnStat(name)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ fs.GlobFS = &GlobFS{}
|
||||
|
||||
// GlobFS is a mock fs.GlobFS implementation to use in tests.
|
||||
type GlobFS struct {
|
||||
// FS is embedded here to avoid implementing all it's methods.
|
||||
FS
|
||||
OnGlob func(pattern string) ([]string, error)
|
||||
}
|
||||
|
||||
// Glob implements the fs.GlobFS interface for *GlobFS.
|
||||
func (fsys *GlobFS) Glob(pattern string) ([]string, error) {
|
||||
return fsys.OnGlob(pattern)
|
||||
}
|
||||
@@ -6,12 +6,18 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Additional Upstream Testing Utilities
|
||||
|
||||
// Upstream is a mock implementation of upstream.Upstream.
|
||||
//
|
||||
// TODO(a.garipov): Replace with UpstreamMock and rename it to just Upstream.
|
||||
type Upstream struct {
|
||||
// CName is a map of hostname to canonical name.
|
||||
CName map[string][]string
|
||||
@@ -25,6 +31,43 @@ type Upstream struct {
|
||||
Addr string
|
||||
}
|
||||
|
||||
// RespondTo returns a response with answer if req has class cl, question type
|
||||
// qt, and target targ.
|
||||
func RespondTo(t testing.TB, req *dns.Msg, cl, qt uint16, targ, answer string) (resp *dns.Msg) {
|
||||
t.Helper()
|
||||
|
||||
require.NotNil(t, req)
|
||||
require.Len(t, req.Question, 1)
|
||||
|
||||
q := req.Question[0]
|
||||
targ = dns.Fqdn(targ)
|
||||
if q.Qclass != cl || q.Qtype != qt || q.Name != targ {
|
||||
return nil
|
||||
}
|
||||
|
||||
respHdr := dns.RR_Header{
|
||||
Name: targ,
|
||||
Rrtype: qt,
|
||||
Class: cl,
|
||||
Ttl: 60,
|
||||
}
|
||||
|
||||
resp = new(dns.Msg).SetReply(req)
|
||||
switch qt {
|
||||
case dns.TypePTR:
|
||||
resp.Answer = []dns.RR{
|
||||
&dns.PTR{
|
||||
Hdr: respHdr,
|
||||
Ptr: answer,
|
||||
},
|
||||
}
|
||||
default:
|
||||
t.Fatalf("unsupported question type: %s", dns.Type(qt))
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
// Exchange implements the upstream.Upstream interface for *Upstream.
|
||||
//
|
||||
// TODO(a.garipov): Split further into handlers.
|
||||
@@ -76,74 +119,57 @@ func (u *Upstream) Address() string {
|
||||
return u.Addr
|
||||
}
|
||||
|
||||
// TestBlockUpstream implements upstream.Upstream interface for replacing real
|
||||
// upstream in tests.
|
||||
type TestBlockUpstream struct {
|
||||
Hostname string
|
||||
|
||||
// lock protects reqNum.
|
||||
lock sync.RWMutex
|
||||
reqNum int
|
||||
|
||||
Block bool
|
||||
}
|
||||
|
||||
// Exchange returns a message unique for TestBlockUpstream's Hostname-Block
|
||||
// pair.
|
||||
func (u *TestBlockUpstream) Exchange(r *dns.Msg) (*dns.Msg, error) {
|
||||
u.lock.Lock()
|
||||
defer u.lock.Unlock()
|
||||
u.reqNum++
|
||||
|
||||
hash := sha256.Sum256([]byte(u.Hostname))
|
||||
hashToReturn := hex.EncodeToString(hash[:])
|
||||
if !u.Block {
|
||||
hashToReturn = hex.EncodeToString(hash[:])[:2] + strings.Repeat("ab", 28)
|
||||
// NewBlockUpstream returns an [*UpstreamMock] that works like an upstream that
|
||||
// supports hash-based safe-browsing/adult-blocking feature. If shouldBlock is
|
||||
// true, hostname's actual hash is returned, blocking it. Otherwise, it returns
|
||||
// a different hash.
|
||||
func NewBlockUpstream(hostname string, shouldBlock bool) (u *UpstreamMock) {
|
||||
hash := sha256.Sum256([]byte(hostname))
|
||||
hashStr := hex.EncodeToString(hash[:])
|
||||
if !shouldBlock {
|
||||
hashStr = hex.EncodeToString(hash[:])[:2] + strings.Repeat("ab", 28)
|
||||
}
|
||||
|
||||
m := &dns.Msg{}
|
||||
m.SetReply(r)
|
||||
m.Answer = []dns.RR{
|
||||
&dns.TXT{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: r.Question[0].Name,
|
||||
},
|
||||
Txt: []string{
|
||||
hashToReturn,
|
||||
},
|
||||
ans := &dns.TXT{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: "",
|
||||
Rrtype: dns.TypeTXT,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 60,
|
||||
},
|
||||
Txt: []string{hashStr},
|
||||
}
|
||||
respTmpl := &dns.Msg{
|
||||
Answer: []dns.RR{ans},
|
||||
}
|
||||
|
||||
return &UpstreamMock{
|
||||
OnAddress: func() (addr string) {
|
||||
return "sbpc.upstream.example"
|
||||
},
|
||||
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
resp = respTmpl.Copy()
|
||||
resp.SetReply(req)
|
||||
resp.Answer[0].(*dns.TXT).Hdr.Name = req.Question[0].Name
|
||||
|
||||
return resp, nil
|
||||
},
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Address always returns an empty string.
|
||||
func (u *TestBlockUpstream) Address() string {
|
||||
return ""
|
||||
}
|
||||
// ErrUpstream is the error returned from the [*UpstreamMock] created by
|
||||
// [NewErrorUpstream].
|
||||
const ErrUpstream errors.Error = "test upstream error"
|
||||
|
||||
// RequestsCount returns the number of handled requests. It's safe for
|
||||
// concurrent use.
|
||||
func (u *TestBlockUpstream) RequestsCount() int {
|
||||
u.lock.Lock()
|
||||
defer u.lock.Unlock()
|
||||
|
||||
return u.reqNum
|
||||
}
|
||||
|
||||
// TestErrUpstream implements upstream.Upstream interface for replacing real
|
||||
// upstream in tests.
|
||||
type TestErrUpstream struct {
|
||||
// The error returned by Exchange may be unwrapped to the Err.
|
||||
Err error
|
||||
}
|
||||
|
||||
// Exchange always returns nil Msg and non-nil error.
|
||||
func (u *TestErrUpstream) Exchange(*dns.Msg) (*dns.Msg, error) {
|
||||
return nil, fmt.Errorf("errupstream: %w", u.Err)
|
||||
}
|
||||
|
||||
// Address always returns an empty string.
|
||||
func (u *TestErrUpstream) Address() string {
|
||||
return ""
|
||||
// NewErrorUpstream returns an [*UpstreamMock] that returns [ErrUpstream] from
|
||||
// its Exchange method.
|
||||
func NewErrorUpstream() (u *UpstreamMock) {
|
||||
return &UpstreamMock{
|
||||
OnAddress: func() (addr string) {
|
||||
return "error.upstream.example"
|
||||
},
|
||||
OnExchange: func(_ *dns.Msg) (resp *dns.Msg, err error) {
|
||||
return nil, errors.Error("test upstream error")
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,6 +16,8 @@ import (
|
||||
"github.com/insomniacslk/dhcp/dhcpv4"
|
||||
"github.com/insomniacslk/dhcp/dhcpv4/server4"
|
||||
"github.com/mdlayher/ethernet"
|
||||
|
||||
//lint:ignore SA1019 See the TODO in go.mod.
|
||||
"github.com/mdlayher/raw"
|
||||
)
|
||||
|
||||
@@ -49,16 +51,15 @@ type dhcpConn struct {
|
||||
}
|
||||
|
||||
// newDHCPConn creates the special connection for DHCP server.
|
||||
func (s *v4Server) newDHCPConn(ifi *net.Interface) (c net.PacketConn, err error) {
|
||||
// Create the raw connection.
|
||||
func (s *v4Server) newDHCPConn(iface *net.Interface) (c net.PacketConn, err error) {
|
||||
var ucast net.PacketConn
|
||||
if ucast, err = raw.ListenPacket(ifi, uint16(ethernet.EtherTypeIPv4), nil); err != nil {
|
||||
if ucast, err = raw.ListenPacket(iface, uint16(ethernet.EtherTypeIPv4), nil); err != nil {
|
||||
return nil, fmt.Errorf("creating raw udp connection: %w", err)
|
||||
}
|
||||
|
||||
// Create the UDP connection.
|
||||
var bcast net.PacketConn
|
||||
bcast, err = server4.NewIPv4UDPConn(ifi.Name, &net.UDPAddr{
|
||||
bcast, err = server4.NewIPv4UDPConn(iface.Name, &net.UDPAddr{
|
||||
// TODO(e.burkov): Listening on zeroes makes the server handle
|
||||
// requests from all the interfaces. Inspect the ways to
|
||||
// specify the interface-specific listening addresses.
|
||||
@@ -75,7 +76,7 @@ func (s *v4Server) newDHCPConn(ifi *net.Interface) (c net.PacketConn, err error)
|
||||
udpConn: bcast,
|
||||
bcastIP: s.conf.broadcastIP,
|
||||
rawConn: ucast,
|
||||
srcMAC: ifi.HardwareAddr,
|
||||
srcMAC: iface.HardwareAddr,
|
||||
srcIP: s.conf.dnsIPAddrs[0],
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -11,9 +11,11 @@ import (
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/insomniacslk/dhcp/dhcpv4"
|
||||
"github.com/mdlayher/raw"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
//lint:ignore SA1019 See the TODO in go.mod.
|
||||
"github.com/mdlayher/raw"
|
||||
)
|
||||
|
||||
func TestDHCPConn_WriteTo_common(t *testing.T) {
|
||||
|
||||
@@ -5,11 +5,11 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
)
|
||||
@@ -126,7 +126,7 @@ type ServerConfig struct {
|
||||
ConfigModified func() `yaml:"-"`
|
||||
|
||||
// Register an HTTP handler
|
||||
HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request)) `yaml:"-"`
|
||||
HTTPRegister aghhttp.RegisterFunc `yaml:"-"`
|
||||
|
||||
Enabled bool `yaml:"enabled"`
|
||||
InterfaceName string `yaml:"interface_name"`
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
@@ -145,7 +146,7 @@ type dhcpServerConfigJSON struct {
|
||||
V4 *v4ServerConfJSON `json:"v4"`
|
||||
V6 *v6ServerConfJSON `json:"v6"`
|
||||
InterfaceName string `json:"interface_name"`
|
||||
Enabled nullBool `json:"enabled"`
|
||||
Enabled aghalg.NullBool `json:"enabled"`
|
||||
}
|
||||
|
||||
func (s *Server) handleDHCPSetConfigV4(
|
||||
@@ -156,7 +157,7 @@ func (s *Server) handleDHCPSetConfigV4(
|
||||
}
|
||||
|
||||
v4Conf := v4JSONToServerConf(conf.V4)
|
||||
v4Conf.Enabled = conf.Enabled == nbTrue
|
||||
v4Conf.Enabled = conf.Enabled == aghalg.NBTrue
|
||||
if len(v4Conf.RangeStart) == 0 {
|
||||
v4Conf.Enabled = false
|
||||
}
|
||||
@@ -183,7 +184,7 @@ func (s *Server) handleDHCPSetConfigV6(
|
||||
}
|
||||
|
||||
v6Conf := v6JSONToServerConf(conf.V6)
|
||||
v6Conf.Enabled = conf.Enabled == nbTrue
|
||||
v6Conf.Enabled = conf.Enabled == aghalg.NBTrue
|
||||
if len(v6Conf.RangeStart) == 0 {
|
||||
v6Conf.Enabled = false
|
||||
}
|
||||
@@ -206,7 +207,7 @@ func (s *Server) handleDHCPSetConfigV6(
|
||||
|
||||
func (s *Server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
conf := &dhcpServerConfigJSON{}
|
||||
conf.Enabled = boolToNullBool(s.conf.Enabled)
|
||||
conf.Enabled = aghalg.BoolToNullBool(s.conf.Enabled)
|
||||
conf.InterfaceName = s.conf.InterfaceName
|
||||
|
||||
err := json.NewDecoder(r.Body).Decode(conf)
|
||||
@@ -230,7 +231,7 @@ func (s *Server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if conf.Enabled == nbTrue && !v4Enabled && !v6Enabled {
|
||||
if conf.Enabled == aghalg.NBTrue && !v4Enabled && !v6Enabled {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "dhcpv4 or dhcpv6 configuration must be complete")
|
||||
|
||||
return
|
||||
@@ -243,8 +244,8 @@ func (s *Server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if conf.Enabled != nbNull {
|
||||
s.conf.Enabled = conf.Enabled == nbTrue
|
||||
if conf.Enabled != aghalg.NBNull {
|
||||
s.conf.Enabled = conf.Enabled == aghalg.NBTrue
|
||||
}
|
||||
|
||||
if conf.InterfaceName != "" {
|
||||
@@ -279,11 +280,11 @@ func (s *Server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
type netInterfaceJSON struct {
|
||||
Name string `json:"name"`
|
||||
GatewayIP net.IP `json:"gateway_ip"`
|
||||
HardwareAddr string `json:"hardware_address"`
|
||||
Flags string `json:"flags"`
|
||||
GatewayIP net.IP `json:"gateway_ip"`
|
||||
Addrs4 []net.IP `json:"ipv4_addresses"`
|
||||
Addrs6 []net.IP `json:"ipv6_addresses"`
|
||||
Flags string `json:"flags"`
|
||||
}
|
||||
|
||||
func (s *Server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -497,7 +498,6 @@ func (s *Server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request
|
||||
}
|
||||
|
||||
ip4 := l.IP.To4()
|
||||
|
||||
if ip4 == nil {
|
||||
l.IP = l.IP.To16()
|
||||
|
||||
|
||||
@@ -1,58 +0,0 @@
|
||||
package dhcpd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// nullBool is a nullable boolean. Use these in JSON requests and responses
|
||||
// instead of pointers to bool.
|
||||
//
|
||||
// TODO(a.garipov): Inspect uses of *bool, move this type into some new package
|
||||
// if we need it somewhere else.
|
||||
type nullBool uint8
|
||||
|
||||
// nullBool values
|
||||
const (
|
||||
nbNull nullBool = iota
|
||||
nbTrue
|
||||
nbFalse
|
||||
)
|
||||
|
||||
// String implements the fmt.Stringer interface for nullBool.
|
||||
func (nb nullBool) String() (s string) {
|
||||
switch nb {
|
||||
case nbNull:
|
||||
return "null"
|
||||
case nbTrue:
|
||||
return "true"
|
||||
case nbFalse:
|
||||
return "false"
|
||||
}
|
||||
|
||||
return fmt.Sprintf("!invalid nullBool %d", uint8(nb))
|
||||
}
|
||||
|
||||
// boolToNullBool converts a bool into a nullBool.
|
||||
func boolToNullBool(cond bool) (nb nullBool) {
|
||||
if cond {
|
||||
return nbTrue
|
||||
}
|
||||
|
||||
return nbFalse
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements the json.Unmarshaler interface for *nullBool.
|
||||
func (nb *nullBool) UnmarshalJSON(b []byte) (err error) {
|
||||
if len(b) == 0 || bytes.Equal(b, []byte("null")) {
|
||||
*nb = nbNull
|
||||
} else if bytes.Equal(b, []byte("true")) {
|
||||
*nb = nbTrue
|
||||
} else if bytes.Equal(b, []byte("false")) {
|
||||
*nb = nbFalse
|
||||
} else {
|
||||
return fmt.Errorf("invalid nullBool value %q", b)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,66 +0,0 @@
|
||||
package dhcpd
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNullBool_UnmarshalJSON(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
wantErrMsg string
|
||||
data []byte
|
||||
want nullBool
|
||||
}{{
|
||||
name: "empty",
|
||||
wantErrMsg: "",
|
||||
data: []byte{},
|
||||
want: nbNull,
|
||||
}, {
|
||||
name: "null",
|
||||
wantErrMsg: "",
|
||||
data: []byte("null"),
|
||||
want: nbNull,
|
||||
}, {
|
||||
name: "true",
|
||||
wantErrMsg: "",
|
||||
data: []byte("true"),
|
||||
want: nbTrue,
|
||||
}, {
|
||||
name: "false",
|
||||
wantErrMsg: "",
|
||||
data: []byte("false"),
|
||||
want: nbFalse,
|
||||
}, {
|
||||
name: "invalid",
|
||||
wantErrMsg: `invalid nullBool value "invalid"`,
|
||||
data: []byte("invalid"),
|
||||
want: nbNull,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var got nullBool
|
||||
err := got.UnmarshalJSON(tc.data)
|
||||
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
||||
|
||||
assert.Equal(t, tc.want, got)
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("json", func(t *testing.T) {
|
||||
want := nbTrue
|
||||
var got struct {
|
||||
A nullBool
|
||||
}
|
||||
|
||||
err := json.Unmarshal([]byte(`{"A":true}`), &got)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, want, got.A)
|
||||
})
|
||||
}
|
||||
@@ -20,6 +20,9 @@ import (
|
||||
"github.com/go-ping/ping"
|
||||
"github.com/insomniacslk/dhcp/dhcpv4"
|
||||
"github.com/insomniacslk/dhcp/dhcpv4/server4"
|
||||
"golang.org/x/exp/slices"
|
||||
|
||||
//lint:ignore SA1019 See the TODO in go.mod.
|
||||
"github.com/mdlayher/raw"
|
||||
)
|
||||
|
||||
@@ -91,6 +94,9 @@ func (s *v4Server) validHostnameForClient(cliHostname string, ip net.IP) (hostna
|
||||
|
||||
if hostname == "" {
|
||||
hostname = aghnet.GenerateHostname(ip)
|
||||
} else if s.leaseHosts.Has(hostname) {
|
||||
log.Info("dhcpv4: hostname %q already exists", hostname)
|
||||
hostname = aghnet.GenerateHostname(ip)
|
||||
}
|
||||
|
||||
err = netutil.ValidateDomainName(hostname)
|
||||
@@ -250,11 +256,11 @@ func (s *v4Server) rmLeaseByIndex(i int) {
|
||||
// Remove a dynamic lease with the same properties
|
||||
// Return error if a static lease is found
|
||||
func (s *v4Server) rmDynamicLease(lease *Lease) (err error) {
|
||||
for i := 0; i < len(s.leases); i++ {
|
||||
l := s.leases[i]
|
||||
for i, l := range s.leases {
|
||||
isStatic := l.IsStatic()
|
||||
|
||||
if bytes.Equal(l.HWAddr, lease.HWAddr) {
|
||||
if l.IsStatic() {
|
||||
if bytes.Equal(l.HWAddr, lease.HWAddr) || l.IP.Equal(lease.IP) {
|
||||
if isStatic {
|
||||
return errors.Error("static lease already exists")
|
||||
}
|
||||
|
||||
@@ -266,20 +272,7 @@ func (s *v4Server) rmDynamicLease(lease *Lease) (err error) {
|
||||
l = s.leases[i]
|
||||
}
|
||||
|
||||
if l.IP.Equal(lease.IP) {
|
||||
if l.IsStatic() {
|
||||
return errors.Error("static lease already exists")
|
||||
}
|
||||
|
||||
s.rmLeaseByIndex(i)
|
||||
if i == len(s.leases) {
|
||||
break
|
||||
}
|
||||
|
||||
l = s.leases[i]
|
||||
}
|
||||
|
||||
if l.Hostname == lease.Hostname {
|
||||
if !isStatic && l.Hostname == lease.Hostname {
|
||||
l.Hostname = ""
|
||||
}
|
||||
}
|
||||
@@ -287,6 +280,10 @@ func (s *v4Server) rmDynamicLease(lease *Lease) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ErrDupHostname is returned by addLease when the added lease has a not empty
|
||||
// non-unique hostname.
|
||||
const ErrDupHostname = errors.Error("hostname is not unique")
|
||||
|
||||
// addLease adds a dynamic or static lease.
|
||||
func (s *v4Server) addLease(l *Lease) (err error) {
|
||||
r := s.conf.ipRange
|
||||
@@ -302,13 +299,17 @@ func (s *v4Server) addLease(l *Lease) (err error) {
|
||||
return fmt.Errorf("lease %s (%s) out of range, not adding", l.IP, l.HWAddr)
|
||||
}
|
||||
|
||||
s.leases = append(s.leases, l)
|
||||
s.leasedOffsets.set(offset, true)
|
||||
|
||||
if l.Hostname != "" {
|
||||
if s.leaseHosts.Has(l.Hostname) {
|
||||
return ErrDupHostname
|
||||
}
|
||||
|
||||
s.leaseHosts.Add(l.Hostname)
|
||||
}
|
||||
|
||||
s.leases = append(s.leases, l)
|
||||
s.leasedOffsets.set(offset, true)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -333,12 +334,16 @@ func (s *v4Server) rmLease(lease *Lease) (err error) {
|
||||
return errors.Error("lease not found")
|
||||
}
|
||||
|
||||
// AddStaticLease adds a static lease. It is safe for concurrent use.
|
||||
// AddStaticLease implements the DHCPServer interface for *v4Server. It is safe
|
||||
// for concurrent use.
|
||||
func (s *v4Server) AddStaticLease(l *Lease) (err error) {
|
||||
defer func() { err = errors.Annotate(err, "dhcpv4: adding static lease: %w") }()
|
||||
|
||||
if ip4 := l.IP.To4(); ip4 == nil {
|
||||
ip := l.IP.To4()
|
||||
if ip == nil {
|
||||
return fmt.Errorf("invalid ip %q, only ipv4 is supported", l.IP)
|
||||
} else if gwIP := s.conf.GatewayIP; gwIP.Equal(ip) {
|
||||
return fmt.Errorf("can't assign the gateway IP %s to the lease", gwIP)
|
||||
}
|
||||
|
||||
l.Expiry = time.Unix(leaseExpireStatic, 0)
|
||||
@@ -359,10 +364,11 @@ func (s *v4Server) AddStaticLease(l *Lease) (err error) {
|
||||
return fmt.Errorf("validating hostname: %w", err)
|
||||
}
|
||||
|
||||
// Don't check for hostname uniqueness, since we try to emulate
|
||||
// dnsmasq here, which means that rmDynamicLease below will
|
||||
// simply empty the hostname of the dynamic lease if there even
|
||||
// is one.
|
||||
// Don't check for hostname uniqueness, since we try to emulate dnsmasq
|
||||
// here, which means that rmDynamicLease below will simply empty the
|
||||
// hostname of the dynamic lease if there even is one. In case a static
|
||||
// lease with the same name already exists, addLease will return an
|
||||
// error and the lease won't be added.
|
||||
|
||||
l.Hostname = hostname
|
||||
}
|
||||
@@ -377,7 +383,7 @@ func (s *v4Server) AddStaticLease(l *Lease) (err error) {
|
||||
if err != nil {
|
||||
err = fmt.Errorf(
|
||||
"removing dynamic leases for %s (%s): %w",
|
||||
l.IP,
|
||||
ip,
|
||||
l.HWAddr,
|
||||
err,
|
||||
)
|
||||
@@ -387,7 +393,7 @@ func (s *v4Server) AddStaticLease(l *Lease) (err error) {
|
||||
|
||||
err = s.addLease(l)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("adding static lease for %s (%s): %w", l.IP, l.HWAddr, err)
|
||||
err = fmt.Errorf("adding static lease for %s (%s): %w", ip, l.HWAddr, err)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -517,11 +523,7 @@ func (s *v4Server) findExpiredLease() int {
|
||||
// reserveLease reserves a lease for a client by its MAC-address. It returns
|
||||
// nil if it couldn't allocate a new lease.
|
||||
func (s *v4Server) reserveLease(mac net.HardwareAddr) (l *Lease, err error) {
|
||||
l = &Lease{
|
||||
HWAddr: make([]byte, len(mac)),
|
||||
}
|
||||
|
||||
copy(l.HWAddr, mac)
|
||||
l = &Lease{HWAddr: slices.Clone(mac)}
|
||||
|
||||
l.IP = s.nextIP()
|
||||
if l.IP == nil {
|
||||
@@ -614,33 +616,25 @@ func (s *v4Server) processDiscover(req, resp *dhcpv4.DHCPv4) (l *Lease, err erro
|
||||
return l, nil
|
||||
}
|
||||
|
||||
type optFQDN struct {
|
||||
name string
|
||||
}
|
||||
// OptionFQDN returns a DHCPv4 option for sending the FQDN to the client
|
||||
// requested another hostname.
|
||||
//
|
||||
// See https://datatracker.ietf.org/doc/html/rfc4702.
|
||||
func OptionFQDN(fqdn string) (opt dhcpv4.Option) {
|
||||
optData := []byte{
|
||||
// Set only S and O DHCP client FQDN option flags.
|
||||
//
|
||||
// See https://datatracker.ietf.org/doc/html/rfc4702#section-2.1.
|
||||
1<<0 | 1<<1,
|
||||
// The RCODE fields should be set to 0xFF in the server responses.
|
||||
//
|
||||
// See https://datatracker.ietf.org/doc/html/rfc4702#section-2.2.
|
||||
0xFF,
|
||||
0xFF,
|
||||
}
|
||||
optData = append(optData, fqdn...)
|
||||
|
||||
func (o *optFQDN) String() string {
|
||||
return "optFQDN"
|
||||
}
|
||||
|
||||
// flags[1]
|
||||
// A-RR[1]
|
||||
// PTR-RR[1]
|
||||
// name[]
|
||||
func (o *optFQDN) ToBytes() []byte {
|
||||
b := make([]byte, 3+len(o.name))
|
||||
i := 0
|
||||
|
||||
b[i] = 0x03 // f_server_overrides | f_server
|
||||
i++
|
||||
|
||||
b[i] = 255 // A-RR
|
||||
i++
|
||||
|
||||
b[i] = 255 // PTR-RR
|
||||
i++
|
||||
|
||||
copy(b[i:], []byte(o.name))
|
||||
return b
|
||||
return dhcpv4.OptGeneric(dhcpv4.OptionFQDN, optData)
|
||||
}
|
||||
|
||||
// checkLease checks if the pair of mac and ip is already leased. The mismatch
|
||||
@@ -673,6 +667,8 @@ func (s *v4Server) checkLease(mac net.HardwareAddr, ip net.IP) (lease *Lease, mi
|
||||
// processRequest is the handler for the DHCP Request request.
|
||||
func (s *v4Server) processRequest(req, resp *dhcpv4.DHCPv4) (lease *Lease, needsReply bool) {
|
||||
mac := req.ClientHWAddr
|
||||
// TODO(e.burkov): The IP address can only be requested in DHCPDISCOVER
|
||||
// message.
|
||||
reqIP := req.RequestedIPAddress()
|
||||
if reqIP == nil {
|
||||
reqIP = req.ClientIPAddr
|
||||
@@ -705,24 +701,17 @@ func (s *v4Server) processRequest(req, resp *dhcpv4.DHCPv4) (lease *Lease, needs
|
||||
if !lease.IsStatic() {
|
||||
cliHostname := req.HostName()
|
||||
hostname := s.validHostnameForClient(cliHostname, reqIP)
|
||||
if hostname != lease.Hostname && s.leaseHosts.Has(hostname) {
|
||||
log.Info("dhcpv4: hostname %q already exists", hostname)
|
||||
lease.Hostname = ""
|
||||
} else {
|
||||
if lease.Hostname != hostname {
|
||||
lease.Hostname = hostname
|
||||
resp.UpdateOption(dhcpv4.OptHostName(hostname))
|
||||
}
|
||||
|
||||
s.commitLease(lease)
|
||||
} else if lease.Hostname != "" {
|
||||
o := &optFQDN{
|
||||
name: lease.Hostname,
|
||||
}
|
||||
fqdn := dhcpv4.Option{
|
||||
Code: dhcpv4.OptionFQDN,
|
||||
Value: o,
|
||||
}
|
||||
|
||||
resp.UpdateOption(fqdn)
|
||||
// TODO(e.burkov): This option is used to update the server's DNS
|
||||
// mapping. The option should only be answered when it has been
|
||||
// requested.
|
||||
resp.UpdateOption(OptionFQDN(lease.Hostname))
|
||||
}
|
||||
|
||||
resp.UpdateOption(dhcpv4.OptMessageType(dhcpv4.MessageTypeAck))
|
||||
@@ -845,7 +834,7 @@ func (s *v4Server) process(req, resp *dhcpv4.DHCPv4) int {
|
||||
|
||||
// TODO(a.garipov): Refactor this into handlers.
|
||||
var l *Lease
|
||||
switch req.MessageType() {
|
||||
switch mt := req.MessageType(); mt {
|
||||
case dhcpv4.MessageTypeDiscover:
|
||||
l, err = s.processDiscover(req, resp)
|
||||
if err != nil {
|
||||
|
||||
@@ -4,16 +4,29 @@
|
||||
package dhcpd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/insomniacslk/dhcp/dhcpv4"
|
||||
"github.com/mdlayher/raw"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
//lint:ignore SA1019 See the TODO in go.mod.
|
||||
"github.com/mdlayher/raw"
|
||||
)
|
||||
|
||||
var (
|
||||
DefaultRangeStart = net.IP{192, 168, 10, 100}
|
||||
DefaultRangeEnd = net.IP{192, 168, 10, 200}
|
||||
DefaultGatewayIP = net.IP{192, 168, 10, 1}
|
||||
DefaultSelfIP = net.IP{192, 168, 10, 2}
|
||||
DefaultSubnetMask = net.IP{255, 255, 255, 0}
|
||||
)
|
||||
|
||||
func notify4(flags uint32) {
|
||||
@@ -24,11 +37,12 @@ func notify4(flags uint32) {
|
||||
func defaultV4ServerConf() (conf V4ServerConf) {
|
||||
return V4ServerConf{
|
||||
Enabled: true,
|
||||
RangeStart: net.IP{192, 168, 10, 100},
|
||||
RangeEnd: net.IP{192, 168, 10, 200},
|
||||
GatewayIP: net.IP{192, 168, 10, 1},
|
||||
SubnetMask: net.IP{255, 255, 255, 0},
|
||||
RangeStart: DefaultRangeStart,
|
||||
RangeEnd: DefaultRangeEnd,
|
||||
GatewayIP: DefaultGatewayIP,
|
||||
SubnetMask: DefaultSubnetMask,
|
||||
notify: notify4,
|
||||
dnsIPAddrs: []net.IP{DefaultSelfIP},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -44,44 +58,228 @@ func defaultSrv(t *testing.T) (s DHCPServer) {
|
||||
return s
|
||||
}
|
||||
|
||||
func TestV4_AddRemove_static(t *testing.T) {
|
||||
func TestV4Server_leasing(t *testing.T) {
|
||||
const (
|
||||
staticName = "static-client"
|
||||
anotherName = "another-client"
|
||||
)
|
||||
|
||||
staticIP := net.IP{192, 168, 10, 10}
|
||||
anotherIP := DefaultRangeStart
|
||||
staticMAC := net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}
|
||||
anotherMAC := net.HardwareAddr{0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB}
|
||||
|
||||
s := defaultSrv(t)
|
||||
|
||||
t.Run("add_static", func(t *testing.T) {
|
||||
err := s.AddStaticLease(&Lease{
|
||||
Expiry: time.Unix(leaseExpireStatic, 0),
|
||||
Hostname: staticName,
|
||||
HWAddr: staticMAC,
|
||||
IP: staticIP,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("same_name", func(t *testing.T) {
|
||||
err = s.AddStaticLease(&Lease{
|
||||
Expiry: time.Unix(leaseExpireStatic, 0),
|
||||
Hostname: staticName,
|
||||
HWAddr: anotherMAC,
|
||||
IP: anotherIP,
|
||||
})
|
||||
assert.ErrorIs(t, err, ErrDupHostname)
|
||||
})
|
||||
|
||||
t.Run("same_mac", func(t *testing.T) {
|
||||
wantErrMsg := "dhcpv4: adding static lease: removing " +
|
||||
"dynamic leases for " + anotherIP.String() +
|
||||
" (" + staticMAC.String() + "): static lease already exists"
|
||||
|
||||
err = s.AddStaticLease(&Lease{
|
||||
Expiry: time.Unix(leaseExpireStatic, 0),
|
||||
Hostname: anotherName,
|
||||
HWAddr: staticMAC,
|
||||
IP: anotherIP,
|
||||
})
|
||||
testutil.AssertErrorMsg(t, wantErrMsg, err)
|
||||
})
|
||||
|
||||
t.Run("same_ip", func(t *testing.T) {
|
||||
wantErrMsg := "dhcpv4: adding static lease: removing " +
|
||||
"dynamic leases for " + staticIP.String() +
|
||||
" (" + anotherMAC.String() + "): static lease already exists"
|
||||
|
||||
err = s.AddStaticLease(&Lease{
|
||||
Expiry: time.Unix(leaseExpireStatic, 0),
|
||||
Hostname: anotherName,
|
||||
HWAddr: anotherMAC,
|
||||
IP: staticIP,
|
||||
})
|
||||
testutil.AssertErrorMsg(t, wantErrMsg, err)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("add_dynamic", func(t *testing.T) {
|
||||
s4, ok := s.(*v4Server)
|
||||
require.True(t, ok)
|
||||
|
||||
discoverAnOffer := func(
|
||||
t *testing.T,
|
||||
name string,
|
||||
ip net.IP,
|
||||
mac net.HardwareAddr,
|
||||
) (resp *dhcpv4.DHCPv4) {
|
||||
testutil.CleanupAndRequireSuccess(t, func() (err error) {
|
||||
return s.ResetLeases(s.GetLeases(LeasesStatic))
|
||||
})
|
||||
|
||||
req, err := dhcpv4.NewDiscovery(
|
||||
mac,
|
||||
dhcpv4.WithOption(dhcpv4.OptHostName(name)),
|
||||
dhcpv4.WithOption(dhcpv4.OptRequestedIPAddress(ip)),
|
||||
dhcpv4.WithOption(dhcpv4.OptClientIdentifier([]byte{1, 2, 3, 4, 5, 6, 8})),
|
||||
dhcpv4.WithGatewayIP(DefaultGatewayIP),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp = &dhcpv4.DHCPv4{}
|
||||
res := s4.process(req, resp)
|
||||
require.Positive(t, res)
|
||||
require.Equal(t, dhcpv4.MessageTypeOffer, resp.MessageType())
|
||||
|
||||
resp.ClientHWAddr = mac
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
t.Run("same_name", func(t *testing.T) {
|
||||
resp := discoverAnOffer(t, staticName, anotherIP, anotherMAC)
|
||||
|
||||
req, err := dhcpv4.NewRequestFromOffer(resp, dhcpv4.WithOption(
|
||||
dhcpv4.OptHostName(staticName),
|
||||
))
|
||||
require.NoError(t, err)
|
||||
|
||||
res := s4.process(req, resp)
|
||||
require.Positive(t, res)
|
||||
|
||||
assert.Equal(t, aghnet.GenerateHostname(resp.YourIPAddr), resp.HostName())
|
||||
})
|
||||
|
||||
t.Run("same_mac", func(t *testing.T) {
|
||||
resp := discoverAnOffer(t, anotherName, anotherIP, staticMAC)
|
||||
|
||||
req, err := dhcpv4.NewRequestFromOffer(resp, dhcpv4.WithOption(
|
||||
dhcpv4.OptHostName(anotherName),
|
||||
))
|
||||
require.NoError(t, err)
|
||||
|
||||
res := s4.process(req, resp)
|
||||
require.Positive(t, res)
|
||||
|
||||
fqdnOptData := resp.Options.Get(dhcpv4.OptionFQDN)
|
||||
require.Len(t, fqdnOptData, 3+len(staticName))
|
||||
assert.Equal(t, []uint8(staticName), fqdnOptData[3:])
|
||||
|
||||
assert.Equal(t, staticIP, resp.YourIPAddr)
|
||||
})
|
||||
|
||||
t.Run("same_ip", func(t *testing.T) {
|
||||
resp := discoverAnOffer(t, anotherName, staticIP, anotherMAC)
|
||||
|
||||
req, err := dhcpv4.NewRequestFromOffer(resp, dhcpv4.WithOption(
|
||||
dhcpv4.OptHostName(anotherName),
|
||||
))
|
||||
require.NoError(t, err)
|
||||
|
||||
res := s4.process(req, resp)
|
||||
require.Positive(t, res)
|
||||
|
||||
assert.NotEqual(t, staticIP, resp.YourIPAddr)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestV4Server_AddRemove_static(t *testing.T) {
|
||||
s := defaultSrv(t)
|
||||
|
||||
ls := s.GetLeases(LeasesStatic)
|
||||
assert.Empty(t, ls)
|
||||
require.Empty(t, ls)
|
||||
|
||||
// Add static lease.
|
||||
l := &Lease{
|
||||
Hostname: "static-1.local",
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
IP: net.IP{192, 168, 10, 150},
|
||||
testCases := []struct {
|
||||
lease *Lease
|
||||
name string
|
||||
wantErrMsg string
|
||||
}{{
|
||||
lease: &Lease{
|
||||
Hostname: "success.local",
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
IP: net.IP{192, 168, 10, 150},
|
||||
},
|
||||
name: "success",
|
||||
wantErrMsg: "",
|
||||
}, {
|
||||
lease: &Lease{
|
||||
Hostname: "probably-router.local",
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
IP: DefaultGatewayIP,
|
||||
},
|
||||
name: "with_gateway_ip",
|
||||
wantErrMsg: "dhcpv4: adding static lease: " +
|
||||
"can't assign the gateway IP 192.168.10.1 to the lease",
|
||||
}, {
|
||||
lease: &Lease{
|
||||
Hostname: "ip6.local",
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
IP: net.ParseIP("ffff::1"),
|
||||
},
|
||||
name: "ipv6",
|
||||
wantErrMsg: `dhcpv4: adding static lease: ` +
|
||||
`invalid ip "ffff::1", only ipv4 is supported`,
|
||||
}, {
|
||||
lease: &Lease{
|
||||
Hostname: "bad-mac.local",
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA},
|
||||
IP: net.IP{192, 168, 10, 150},
|
||||
},
|
||||
name: "bad_mac",
|
||||
wantErrMsg: `dhcpv4: adding static lease: bad mac address "aa:aa": ` +
|
||||
`bad mac address length 2, allowed: [6 8 20]`,
|
||||
}, {
|
||||
lease: &Lease{
|
||||
Hostname: "bad-lbl-.local",
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
IP: net.IP{192, 168, 10, 150},
|
||||
},
|
||||
name: "bad_hostname",
|
||||
wantErrMsg: `dhcpv4: adding static lease: validating hostname: ` +
|
||||
`bad domain name "bad-lbl-.local": ` +
|
||||
`bad domain name label "bad-lbl-": bad domain name label rune '-'`,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := s.AddStaticLease(tc.lease)
|
||||
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
||||
if tc.wantErrMsg != "" {
|
||||
return
|
||||
}
|
||||
|
||||
err = s.RemoveStaticLease(&Lease{
|
||||
IP: tc.lease.IP,
|
||||
HWAddr: tc.lease.HWAddr,
|
||||
})
|
||||
diffErrMsg := fmt.Sprintf("dhcpv4: lease for ip %s is different: %+v", tc.lease.IP, tc.lease)
|
||||
testutil.AssertErrorMsg(t, diffErrMsg, err)
|
||||
|
||||
// Remove static lease.
|
||||
err = s.RemoveStaticLease(tc.lease)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
ls = s.GetLeases(LeasesStatic)
|
||||
require.Emptyf(t, ls, "after %s", tc.name)
|
||||
}
|
||||
|
||||
err := s.AddStaticLease(l)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = s.AddStaticLease(l)
|
||||
assert.Error(t, err)
|
||||
|
||||
ls = s.GetLeases(LeasesStatic)
|
||||
require.Len(t, ls, 1)
|
||||
|
||||
assert.True(t, l.IP.Equal(ls[0].IP))
|
||||
assert.Equal(t, l.HWAddr, ls[0].HWAddr)
|
||||
assert.True(t, ls[0].IsStatic())
|
||||
|
||||
// Try to remove static lease.
|
||||
err = s.RemoveStaticLease(&Lease{
|
||||
IP: net.IP{192, 168, 10, 110},
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
})
|
||||
assert.Error(t, err)
|
||||
|
||||
// Remove static lease.
|
||||
err = s.RemoveStaticLease(l)
|
||||
require.NoError(t, err)
|
||||
ls = s.GetLeases(LeasesStatic)
|
||||
assert.Empty(t, ls)
|
||||
}
|
||||
|
||||
func TestV4_AddReplace(t *testing.T) {
|
||||
|
||||
@@ -214,7 +214,7 @@ func validateAccessSet(list *accessListJSON) (err error) {
|
||||
}
|
||||
|
||||
merged := allowed.Merge(disallowed)
|
||||
err = merged.Validate(aghalg.StringIsBefore)
|
||||
err = merged.Validate()
|
||||
if err != nil {
|
||||
return fmt.Errorf("items in allowed and disallowed clients intersect: %w", err)
|
||||
}
|
||||
@@ -223,13 +223,13 @@ func validateAccessSet(list *accessListJSON) (err error) {
|
||||
}
|
||||
|
||||
// validateStrUniq returns an informative error if clients are not unique.
|
||||
func validateStrUniq(clients []string) (uc aghalg.UniqChecker, err error) {
|
||||
uc = make(aghalg.UniqChecker, len(clients))
|
||||
func validateStrUniq(clients []string) (uc aghalg.UniqChecker[string], err error) {
|
||||
uc = make(aghalg.UniqChecker[string], len(clients))
|
||||
for _, c := range clients {
|
||||
uc.Add(c)
|
||||
}
|
||||
|
||||
return uc, uc.Validate(aghalg.StringIsBefore)
|
||||
return uc, uc.Validate()
|
||||
}
|
||||
|
||||
func (s *Server) handleAccessSet(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
@@ -65,7 +65,7 @@ func clientIDFromClientServerName(
|
||||
return "", err
|
||||
}
|
||||
|
||||
return clientID, nil
|
||||
return strings.ToLower(clientID), nil
|
||||
}
|
||||
|
||||
// clientIDFromDNSContextHTTPS extracts the client's ID from the path of the
|
||||
@@ -104,7 +104,7 @@ func clientIDFromDNSContextHTTPS(pctx *proxy.DNSContext) (clientID string, err e
|
||||
return "", fmt.Errorf("clientid check: %w", err)
|
||||
}
|
||||
|
||||
return clientID, nil
|
||||
return strings.ToLower(clientID), nil
|
||||
}
|
||||
|
||||
// tlsConn is a narrow interface for *tls.Conn to simplify testing.
|
||||
@@ -112,8 +112,8 @@ type tlsConn interface {
|
||||
ConnectionState() (cs tls.ConnectionState)
|
||||
}
|
||||
|
||||
// quicSession is a narrow interface for quic.Session to simplify testing.
|
||||
type quicSession interface {
|
||||
// quicConnection is a narrow interface for quic.Connection to simplify testing.
|
||||
type quicConnection interface {
|
||||
ConnectionState() (cs quic.ConnectionState)
|
||||
}
|
||||
|
||||
@@ -148,16 +148,16 @@ func (s *Server) clientIDFromDNSContext(pctx *proxy.DNSContext) (clientID string
|
||||
|
||||
cliSrvName = tc.ConnectionState().ServerName
|
||||
case proxy.ProtoQUIC:
|
||||
qs, ok := pctx.QUICSession.(quicSession)
|
||||
conn, ok := pctx.QUICConnection.(quicConnection)
|
||||
if !ok {
|
||||
return "", fmt.Errorf(
|
||||
"proxy ctx quic session of proto %s is %T, want quic.Session",
|
||||
"proxy ctx quic conn of proto %s is %T, want quic.Connection",
|
||||
proto,
|
||||
pctx.QUICSession,
|
||||
pctx.QUICConnection,
|
||||
)
|
||||
}
|
||||
|
||||
cliSrvName = qs.ConnectionState().TLS.ServerName
|
||||
cliSrvName = conn.ConnectionState().TLS.ServerName
|
||||
}
|
||||
|
||||
clientID, err = clientIDFromClientServerName(
|
||||
|
||||
@@ -29,17 +29,18 @@ func (c testTLSConn) ConnectionState() (cs tls.ConnectionState) {
|
||||
return cs
|
||||
}
|
||||
|
||||
// testQUICSession is a quicSession for tests.
|
||||
type testQUICSession struct {
|
||||
// Session is embedded here simply to make testQUICSession a quic.Session
|
||||
// without actually implementing all methods.
|
||||
quic.Session
|
||||
// testQUICConnection is a quicConnection for tests.
|
||||
type testQUICConnection struct {
|
||||
// Connection is embedded here simply to make testQUICConnection a
|
||||
// quic.Connection without actually implementing all methods.
|
||||
quic.Connection
|
||||
|
||||
serverName string
|
||||
}
|
||||
|
||||
// ConnectionState implements the quicSession interface for testQUICSession.
|
||||
func (c testQUICSession) ConnectionState() (cs quic.ConnectionState) {
|
||||
// ConnectionState implements the quicConnection interface for
|
||||
// testQUICConnection.
|
||||
func (c testQUICConnection) ConnectionState() (cs quic.ConnectionState) {
|
||||
cs.TLS.ServerName = c.serverName
|
||||
|
||||
return cs
|
||||
@@ -143,6 +144,22 @@ func TestServer_clientIDFromDNSContext(t *testing.T) {
|
||||
wantErrMsg: `clientid check: client server name "cli.myexample.com" ` +
|
||||
`doesn't match host server name "example.com"`,
|
||||
strictSNI: true,
|
||||
}, {
|
||||
name: "tls_case",
|
||||
proto: proxy.ProtoTLS,
|
||||
hostSrvName: "example.com",
|
||||
cliSrvName: "InSeNsItIvE.example.com",
|
||||
wantClientID: "insensitive",
|
||||
wantErrMsg: ``,
|
||||
strictSNI: true,
|
||||
}, {
|
||||
name: "quic_case",
|
||||
proto: proxy.ProtoQUIC,
|
||||
hostSrvName: "example.com",
|
||||
cliSrvName: "InSeNsItIvE.example.com",
|
||||
wantClientID: "insensitive",
|
||||
wantErrMsg: ``,
|
||||
strictSNI: true,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
@@ -163,17 +180,17 @@ func TestServer_clientIDFromDNSContext(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
var qs quic.Session
|
||||
var qconn quic.Connection
|
||||
if tc.proto == proxy.ProtoQUIC {
|
||||
qs = testQUICSession{
|
||||
qconn = testQUICConnection{
|
||||
serverName: tc.cliSrvName,
|
||||
}
|
||||
}
|
||||
|
||||
pctx := &proxy.DNSContext{
|
||||
Proto: tc.proto,
|
||||
Conn: conn,
|
||||
QUICSession: qs,
|
||||
Proto: tc.proto,
|
||||
Conn: conn,
|
||||
QUICConnection: qconn,
|
||||
}
|
||||
|
||||
clientID, err := srv.clientIDFromDNSContext(pctx)
|
||||
@@ -210,6 +227,11 @@ func TestClientIDFromDNSContextHTTPS(t *testing.T) {
|
||||
path: "/dns-query/cli/",
|
||||
wantClientID: "cli",
|
||||
wantErrMsg: "",
|
||||
}, {
|
||||
name: "clientid_case",
|
||||
path: "/dns-query/InSeNsItIvE",
|
||||
wantClientID: "insensitive",
|
||||
wantErrMsg: ``,
|
||||
}, {
|
||||
name: "bad_url",
|
||||
path: "/foo",
|
||||
|
||||
@@ -5,12 +5,12 @@ import (
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
@@ -134,8 +134,9 @@ type FilteringConfig struct {
|
||||
|
||||
// TLSConfig is the TLS configuration for HTTPS, DNS-over-HTTPS, and DNS-over-TLS
|
||||
type TLSConfig struct {
|
||||
TLSListenAddrs []*net.TCPAddr `yaml:"-" json:"-"`
|
||||
QUICListenAddrs []*net.UDPAddr `yaml:"-" json:"-"`
|
||||
TLSListenAddrs []*net.TCPAddr `yaml:"-" json:"-"`
|
||||
QUICListenAddrs []*net.UDPAddr `yaml:"-" json:"-"`
|
||||
HTTPSListenAddrs []*net.TCPAddr `yaml:"-" json:"-"`
|
||||
|
||||
// Reject connection if the client uses server name (in SNI) that doesn't match the certificate
|
||||
StrictSNICheck bool `yaml:"strict_sni_check" json:"-"`
|
||||
@@ -192,7 +193,7 @@ type ServerConfig struct {
|
||||
ConfigModified func()
|
||||
|
||||
// Register an HTTP handler
|
||||
HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request))
|
||||
HTTPRegister aghhttp.RegisterFunc
|
||||
|
||||
// ResolveClients signals if the RDNS should resolve clients' addresses.
|
||||
ResolveClients bool
|
||||
@@ -277,6 +278,11 @@ func (s *Server) createProxyConfig() (proxy.Config, error) {
|
||||
return proxyConfig, nil
|
||||
}
|
||||
|
||||
const (
|
||||
defaultSafeBrowsingBlockHost = "standard-block.dns.adguard.com"
|
||||
defaultParentalBlockHost = "family-block.dns.adguard.com"
|
||||
)
|
||||
|
||||
// initDefaultSettings initializes default settings if nothing
|
||||
// is configured
|
||||
func (s *Server) initDefaultSettings() {
|
||||
@@ -288,12 +294,12 @@ func (s *Server) initDefaultSettings() {
|
||||
s.conf.BootstrapDNS = defaultBootstrap
|
||||
}
|
||||
|
||||
if len(s.conf.ParentalBlockHost) == 0 {
|
||||
s.conf.ParentalBlockHost = parentalBlockHost
|
||||
if s.conf.ParentalBlockHost == "" {
|
||||
s.conf.ParentalBlockHost = defaultParentalBlockHost
|
||||
}
|
||||
|
||||
if len(s.conf.SafeBrowsingBlockHost) == 0 {
|
||||
s.conf.SafeBrowsingBlockHost = safeBrowsingBlockHost
|
||||
if s.conf.SafeBrowsingBlockHost == "" {
|
||||
s.conf.SafeBrowsingBlockHost = defaultSafeBrowsingBlockHost
|
||||
}
|
||||
|
||||
if s.conf.UDPListenAddrs == nil {
|
||||
|
||||
@@ -100,9 +100,9 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
|
||||
s.processInitial,
|
||||
s.processDDRQuery,
|
||||
s.processDetermineLocal,
|
||||
s.processInternalHosts,
|
||||
s.processDHCPHosts,
|
||||
s.processRestrictLocal,
|
||||
s.processInternalIPAddrs,
|
||||
s.processDHCPAddrs,
|
||||
s.processFilteringBeforeRequest,
|
||||
s.processLocalPTR,
|
||||
s.processUpstream,
|
||||
@@ -230,12 +230,10 @@ func (s *Server) onDHCPLeaseChanged(flags int) {
|
||||
)
|
||||
}
|
||||
|
||||
lowhost := strings.ToLower(l.Hostname)
|
||||
lowhost := strings.ToLower(l.Hostname + "." + s.localDomainSuffix)
|
||||
ip := netutil.CloneIP(l.IP)
|
||||
|
||||
ipToHost.Set(l.IP, lowhost)
|
||||
|
||||
ip := make(net.IP, 4)
|
||||
copy(ip, l.IP.To4())
|
||||
ipToHost.Set(ip, lowhost)
|
||||
hostToIP[lowhost] = ip
|
||||
}
|
||||
|
||||
@@ -260,9 +258,8 @@ func (s *Server) processDDRQuery(ctx *dnsContext) (rc resultCode) {
|
||||
}
|
||||
|
||||
if question.Name == ddrHostFQDN {
|
||||
// TODO(a.garipov): Check DoQ support in next RFC drafts.
|
||||
if s.dnsProxy.TLSListenAddr == nil && s.dnsProxy.HTTPSListenAddr == nil ||
|
||||
question.Qtype != dns.TypeSVCB {
|
||||
if s.dnsProxy.TLSListenAddr == nil && s.conf.HTTPSListenAddrs == nil &&
|
||||
s.dnsProxy.QUICListenAddr == nil || question.Qtype != dns.TypeSVCB {
|
||||
d.Res = s.makeResponse(d.Req)
|
||||
|
||||
return resultCodeFinish
|
||||
@@ -276,12 +273,18 @@ func (s *Server) processDDRQuery(ctx *dnsContext) (rc resultCode) {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
// makeDDRResponse creates DDR answer according to server configuration.
|
||||
// makeDDRResponse creates DDR answer according to server configuration. The
|
||||
// contructed SVCB resource records have the priority of 1 for each entry,
|
||||
// similar to examples provided by https://www.ietf.org/archive/id/draft-ietf-add-ddr-06.html.
|
||||
//
|
||||
// TODO(a.meshkov): Consider setting the priority values based on the protocol.
|
||||
func (s *Server) makeDDRResponse(req *dns.Msg) (resp *dns.Msg) {
|
||||
resp = s.makeResponse(req)
|
||||
domainName := s.conf.ServerName
|
||||
// TODO(e.burkov): Think about storing the FQDN version of the server's
|
||||
// name somewhere.
|
||||
domainName := dns.Fqdn(s.conf.ServerName)
|
||||
|
||||
for _, addr := range s.dnsProxy.HTTPSListenAddr {
|
||||
for _, addr := range s.conf.HTTPSListenAddrs {
|
||||
values := []dns.SVCBKeyValue{
|
||||
&dns.SVCBAlpn{Alpn: []string{"h2"}},
|
||||
&dns.SVCBPort{Port: uint16(addr.Port)},
|
||||
@@ -306,7 +309,23 @@ func (s *Server) makeDDRResponse(req *dns.Msg) (resp *dns.Msg) {
|
||||
|
||||
ans := &dns.SVCB{
|
||||
Hdr: s.hdr(req, dns.TypeSVCB),
|
||||
Priority: 2,
|
||||
Priority: 1,
|
||||
Target: domainName,
|
||||
Value: values,
|
||||
}
|
||||
|
||||
resp.Answer = append(resp.Answer, ans)
|
||||
}
|
||||
|
||||
for _, addr := range s.dnsProxy.QUICListenAddr {
|
||||
values := []dns.SVCBKeyValue{
|
||||
&dns.SVCBAlpn{Alpn: []string{"doq"}},
|
||||
&dns.SVCBPort{Port: uint16(addr.Port)},
|
||||
}
|
||||
|
||||
ans := &dns.SVCB{
|
||||
Hdr: s.hdr(req, dns.TypeSVCB),
|
||||
Priority: 1,
|
||||
Target: domainName,
|
||||
Value: values,
|
||||
}
|
||||
@@ -355,11 +374,11 @@ func (s *Server) hostToIP(host string) (ip net.IP, ok bool) {
|
||||
return ip, true
|
||||
}
|
||||
|
||||
// processInternalHosts respond to A requests if the target hostname is known to
|
||||
// processDHCPHosts respond to A requests if the target hostname is known to
|
||||
// the server.
|
||||
//
|
||||
// TODO(a.garipov): Adapt to AAAA as well.
|
||||
func (s *Server) processInternalHosts(dctx *dnsContext) (rc resultCode) {
|
||||
func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) {
|
||||
if !s.dhcpServer.Enabled() {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
@@ -374,11 +393,10 @@ func (s *Server) processInternalHosts(dctx *dnsContext) (rc resultCode) {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
reqHost := strings.ToLower(q.Name)
|
||||
reqHost := strings.ToLower(q.Name[:len(q.Name)-1])
|
||||
// TODO(a.garipov): Move everything related to DHCP local domain to the DHCP
|
||||
// server.
|
||||
host := strings.TrimSuffix(reqHost, s.localDomainSuffix)
|
||||
if host == reqHost {
|
||||
if !strings.HasSuffix(reqHost, s.localDomainSuffix) {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
@@ -391,7 +409,7 @@ func (s *Server) processInternalHosts(dctx *dnsContext) (rc resultCode) {
|
||||
return resultCodeFinish
|
||||
}
|
||||
|
||||
ip, ok := s.hostToIP(host)
|
||||
ip, ok := s.hostToIP(reqHost)
|
||||
if !ok {
|
||||
// TODO(e.burkov): Inspect special cases when user want to apply some
|
||||
// rules handled by other processors to the hosts with TLD.
|
||||
@@ -448,7 +466,7 @@ func (s *Server) processRestrictLocal(ctx *dnsContext) (rc resultCode) {
|
||||
|
||||
// Restrict an access to local addresses for external clients. We also
|
||||
// assume that all the DHCP leases we give are locally-served or at least
|
||||
// don't need to be inaccessible externally.
|
||||
// don't need to be accessible externally.
|
||||
if !s.privateNets.Contains(ip) {
|
||||
log.Debug("dns: addr %s is not from locally-served network", ip)
|
||||
|
||||
@@ -488,7 +506,7 @@ func (s *Server) ipToHost(ip net.IP) (host string, ok bool) {
|
||||
return "", false
|
||||
}
|
||||
|
||||
var v interface{}
|
||||
var v any
|
||||
v, ok = s.tableIPToHost.Get(ip)
|
||||
if !ok {
|
||||
return "", false
|
||||
@@ -505,7 +523,7 @@ func (s *Server) ipToHost(ip net.IP) (host string, ok bool) {
|
||||
|
||||
// Respond to PTR requests if the target IP is leased by our DHCP server and the
|
||||
// requestor is inside the local network.
|
||||
func (s *Server) processInternalIPAddrs(ctx *dnsContext) (rc resultCode) {
|
||||
func (s *Server) processDHCPAddrs(ctx *dnsContext) (rc resultCode) {
|
||||
d := ctx.proxyCtx
|
||||
if d.Res != nil {
|
||||
return resultCodeSuccess
|
||||
|
||||
@@ -14,12 +14,15 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const ddrTestDomainName = "dns.example.net"
|
||||
const (
|
||||
ddrTestDomainName = "dns.example.net"
|
||||
ddrTestFQDN = ddrTestDomainName + "."
|
||||
)
|
||||
|
||||
func TestServer_ProcessDDRQuery(t *testing.T) {
|
||||
dohSVCB := &dns.SVCB{
|
||||
Priority: 1,
|
||||
Target: ddrTestDomainName,
|
||||
Target: ddrTestFQDN,
|
||||
Value: []dns.SVCBKeyValue{
|
||||
&dns.SVCBAlpn{Alpn: []string{"h2"}},
|
||||
&dns.SVCBPort{Port: 8044},
|
||||
@@ -28,14 +31,23 @@ func TestServer_ProcessDDRQuery(t *testing.T) {
|
||||
}
|
||||
|
||||
dotSVCB := &dns.SVCB{
|
||||
Priority: 2,
|
||||
Target: ddrTestDomainName,
|
||||
Priority: 1,
|
||||
Target: ddrTestFQDN,
|
||||
Value: []dns.SVCBKeyValue{
|
||||
&dns.SVCBAlpn{Alpn: []string{"dot"}},
|
||||
&dns.SVCBPort{Port: 8043},
|
||||
},
|
||||
}
|
||||
|
||||
doqSVCB := &dns.SVCB{
|
||||
Priority: 1,
|
||||
Target: ddrTestFQDN,
|
||||
Value: []dns.SVCBKeyValue{
|
||||
&dns.SVCBAlpn{Alpn: []string{"doq"}},
|
||||
&dns.SVCBPort{Port: 8042},
|
||||
},
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
host string
|
||||
@@ -43,6 +55,7 @@ func TestServer_ProcessDDRQuery(t *testing.T) {
|
||||
wantRes resultCode
|
||||
portDoH int
|
||||
portDoT int
|
||||
portDoQ int
|
||||
qtype uint16
|
||||
ddrEnabled bool
|
||||
}{{
|
||||
@@ -88,6 +101,14 @@ func TestServer_ProcessDDRQuery(t *testing.T) {
|
||||
qtype: dns.TypeSVCB,
|
||||
ddrEnabled: true,
|
||||
portDoH: 8044,
|
||||
}, {
|
||||
name: "doq",
|
||||
wantRes: resultCodeFinish,
|
||||
want: []*dns.SVCB{doqSVCB},
|
||||
host: ddrHostFQDN,
|
||||
qtype: dns.TypeSVCB,
|
||||
ddrEnabled: true,
|
||||
portDoQ: 8042,
|
||||
}, {
|
||||
name: "dot_doh",
|
||||
wantRes: resultCodeFinish,
|
||||
@@ -101,7 +122,7 @@ func TestServer_ProcessDDRQuery(t *testing.T) {
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
s := prepareTestServer(t, tc.portDoH, tc.portDoT, tc.ddrEnabled)
|
||||
s := prepareTestServer(t, tc.portDoH, tc.portDoT, tc.portDoQ, tc.ddrEnabled)
|
||||
|
||||
req := createTestMessageWithType(tc.host, tc.qtype)
|
||||
|
||||
@@ -130,19 +151,19 @@ func TestServer_ProcessDDRQuery(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func prepareTestServer(t *testing.T, portDoH, portDoT int, ddrEnabled bool) (s *Server) {
|
||||
func prepareTestServer(t *testing.T, portDoH, portDoT, portDoQ int, ddrEnabled bool) (s *Server) {
|
||||
t.Helper()
|
||||
|
||||
proxyConf := proxy.Config{}
|
||||
|
||||
if portDoH > 0 {
|
||||
proxyConf.HTTPSListenAddr = []*net.TCPAddr{{Port: portDoH}}
|
||||
}
|
||||
|
||||
if portDoT > 0 {
|
||||
proxyConf.TLSListenAddr = []*net.TCPAddr{{Port: portDoT}}
|
||||
}
|
||||
|
||||
if portDoQ > 0 {
|
||||
proxyConf.QUICListenAddr = []*net.UDPAddr{{Port: portDoQ}}
|
||||
}
|
||||
|
||||
s = &Server{
|
||||
dnsProxy: &proxy.Proxy{
|
||||
Config: proxyConf,
|
||||
@@ -157,6 +178,10 @@ func prepareTestServer(t *testing.T, portDoH, portDoT int, ddrEnabled bool) (s *
|
||||
},
|
||||
}
|
||||
|
||||
if portDoH > 0 {
|
||||
s.conf.TLSConfig.HTTPSListenAddrs = []*net.TCPAddr{{Port: portDoH}}
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
@@ -204,7 +229,7 @@ func TestServer_ProcessDetermineLocal(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_ProcessInternalHosts_localRestriction(t *testing.T) {
|
||||
func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) {
|
||||
knownIP := net.IP{1, 2, 3, 4}
|
||||
|
||||
testCases := []struct {
|
||||
@@ -245,7 +270,7 @@ func TestServer_ProcessInternalHosts_localRestriction(t *testing.T) {
|
||||
dhcpServer: &testDHCP{},
|
||||
localDomainSuffix: defaultLocalDomainSuffix,
|
||||
tableHostToIP: hostToIPTable{
|
||||
"example": knownIP,
|
||||
"example." + defaultLocalDomainSuffix: knownIP,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -267,7 +292,7 @@ func TestServer_ProcessInternalHosts_localRestriction(t *testing.T) {
|
||||
isLocalClient: tc.isLocalCli,
|
||||
}
|
||||
|
||||
res := s.processInternalHosts(dctx)
|
||||
res := s.processDHCPHosts(dctx)
|
||||
require.Equal(t, tc.wantRes, res)
|
||||
pctx := dctx.proxyCtx
|
||||
if tc.wantRes == resultCodeFinish {
|
||||
@@ -293,10 +318,10 @@ func TestServer_ProcessInternalHosts_localRestriction(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_ProcessInternalHosts(t *testing.T) {
|
||||
func TestServer_ProcessDHCPHosts(t *testing.T) {
|
||||
const (
|
||||
examplecom = "example.com"
|
||||
examplelan = "example.lan"
|
||||
examplelan = "example." + defaultLocalDomainSuffix
|
||||
)
|
||||
|
||||
knownIP := net.IP{1, 2, 3, 4}
|
||||
@@ -345,41 +370,41 @@ func TestServer_ProcessInternalHosts(t *testing.T) {
|
||||
}, {
|
||||
name: "success_custom_suffix",
|
||||
host: "example.custom",
|
||||
suffix: ".custom.",
|
||||
suffix: "custom",
|
||||
wantIP: knownIP,
|
||||
wantRes: resultCodeSuccess,
|
||||
qtyp: dns.TypeA,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s := &Server{
|
||||
dhcpServer: &testDHCP{},
|
||||
localDomainSuffix: tc.suffix,
|
||||
tableHostToIP: hostToIPTable{
|
||||
"example." + tc.suffix: knownIP,
|
||||
},
|
||||
}
|
||||
|
||||
req := &dns.Msg{
|
||||
MsgHdr: dns.MsgHdr{
|
||||
Id: 1234,
|
||||
},
|
||||
Question: []dns.Question{{
|
||||
Name: dns.Fqdn(tc.host),
|
||||
Qtype: tc.qtyp,
|
||||
Qclass: dns.ClassINET,
|
||||
}},
|
||||
}
|
||||
|
||||
dctx := &dnsContext{
|
||||
proxyCtx: &proxy.DNSContext{
|
||||
Req: req,
|
||||
},
|
||||
isLocalClient: true,
|
||||
}
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
s := &Server{
|
||||
dhcpServer: &testDHCP{},
|
||||
localDomainSuffix: tc.suffix,
|
||||
tableHostToIP: hostToIPTable{
|
||||
"example": knownIP,
|
||||
},
|
||||
}
|
||||
|
||||
req := &dns.Msg{
|
||||
MsgHdr: dns.MsgHdr{
|
||||
Id: 1234,
|
||||
},
|
||||
Question: []dns.Question{{
|
||||
Name: dns.Fqdn(tc.host),
|
||||
Qtype: tc.qtyp,
|
||||
Qclass: dns.ClassINET,
|
||||
}},
|
||||
}
|
||||
|
||||
dctx := &dnsContext{
|
||||
proxyCtx: &proxy.DNSContext{
|
||||
Req: req,
|
||||
},
|
||||
isLocalClient: true,
|
||||
}
|
||||
|
||||
res := s.processInternalHosts(dctx)
|
||||
res := s.processDHCPHosts(dctx)
|
||||
pctx := dctx.proxyCtx
|
||||
assert.Equal(t, tc.wantRes, res)
|
||||
if tc.wantRes == resultCodeFinish {
|
||||
|
||||
@@ -33,11 +33,6 @@ const DefaultTimeout = 10 * time.Second
|
||||
// requests between the BeforeRequestHandler stage and the actual processing.
|
||||
const defaultClientIDCacheCount = 1024
|
||||
|
||||
const (
|
||||
safeBrowsingBlockHost = "standard-block.dns.adguard.com"
|
||||
parentalBlockHost = "family-block.dns.adguard.com"
|
||||
)
|
||||
|
||||
var defaultDNS = []string{
|
||||
"https://dns10.quad9.net/dns-query",
|
||||
}
|
||||
@@ -66,7 +61,7 @@ type Server struct {
|
||||
dnsFilter *filtering.DNSFilter // DNS filter instance
|
||||
dhcpServer dhcpd.ServerInterface // DHCP server instance (optional)
|
||||
queryLog querylog.QueryLog // Query log instance
|
||||
stats stats.Stats
|
||||
stats stats.Interface
|
||||
access *accessCtx
|
||||
|
||||
// localDomainSuffix is the suffix used to detect internal hosts. It
|
||||
@@ -107,12 +102,12 @@ type Server struct {
|
||||
// when no suffix is provided.
|
||||
//
|
||||
// See the documentation for Server.localDomainSuffix.
|
||||
const defaultLocalDomainSuffix = ".lan."
|
||||
const defaultLocalDomainSuffix = "lan"
|
||||
|
||||
// DNSCreateParams are parameters to create a new server.
|
||||
type DNSCreateParams struct {
|
||||
DNSFilter *filtering.DNSFilter
|
||||
Stats stats.Stats
|
||||
Stats stats.Interface
|
||||
QueryLog querylog.QueryLog
|
||||
DHCPServer dhcpd.ServerInterface
|
||||
PrivateNets netutil.SubnetSet
|
||||
@@ -120,17 +115,6 @@ type DNSCreateParams struct {
|
||||
LocalDomain string
|
||||
}
|
||||
|
||||
// domainNameToSuffix converts a domain name into a local domain suffix.
|
||||
func domainNameToSuffix(tld string) (suffix string) {
|
||||
l := len(tld) + 2
|
||||
b := make([]byte, l)
|
||||
b[0] = '.'
|
||||
copy(b[1:], tld)
|
||||
b[l-1] = '.'
|
||||
|
||||
return string(b)
|
||||
}
|
||||
|
||||
const (
|
||||
// recursionTTL is the time recursive request is cached for.
|
||||
recursionTTL = 1 * time.Second
|
||||
@@ -151,7 +135,7 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
|
||||
return nil, fmt.Errorf("local domain: %w", err)
|
||||
}
|
||||
|
||||
localDomainSuffix = domainNameToSuffix(p.LocalDomain)
|
||||
localDomainSuffix = p.LocalDomain
|
||||
}
|
||||
|
||||
if p.Anonymizer == nil {
|
||||
|
||||
@@ -17,13 +17,13 @@ import (
|
||||
"testing/fstest"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
@@ -853,10 +853,7 @@ func TestBlockedByHosts(t *testing.T) {
|
||||
func TestBlockedBySafeBrowsing(t *testing.T) {
|
||||
const hostname = "wmconvirus.narod.ru"
|
||||
|
||||
sbUps := &aghtest.TestBlockUpstream{
|
||||
Hostname: hostname,
|
||||
Block: true,
|
||||
}
|
||||
sbUps := aghtest.NewBlockUpstream(hostname, true)
|
||||
ans4, _ := (&aghtest.TestResolver{}).HostToIPs(hostname)
|
||||
|
||||
filterConf := &filtering.Config{
|
||||
@@ -988,7 +985,7 @@ func TestRewrite(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func publicKey(priv interface{}) interface{} {
|
||||
func publicKey(priv any) any {
|
||||
switch k := priv.(type) {
|
||||
case *rsa.PrivateKey:
|
||||
return &k.PublicKey
|
||||
@@ -1016,31 +1013,33 @@ func (d *testDHCP) Leases(flags dhcpd.GetLeasesFlags) (leases []*dhcpd.Lease) {
|
||||
func (d *testDHCP) SetOnLeaseChanged(onLeaseChanged dhcpd.OnLeaseChangedT) {}
|
||||
|
||||
func TestPTRResponseFromDHCPLeases(t *testing.T) {
|
||||
const localDomain = "lan"
|
||||
|
||||
s, err := NewServer(DNSCreateParams{
|
||||
DNSFilter: filtering.New(&filtering.Config{}, nil),
|
||||
DHCPServer: &testDHCP{},
|
||||
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
LocalDomain: localDomain,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
s.conf.UDPListenAddrs = []*net.UDPAddr{{}}
|
||||
s.conf.TCPListenAddrs = []*net.TCPAddr{{}}
|
||||
s.conf.UpstreamDNS = []string{"127.0.0.1:53"}
|
||||
s.conf.FilteringConfig.ProtectionEnabled = true
|
||||
s.conf.ProtectionEnabled = true
|
||||
|
||||
err = s.Prepare(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = s.Start()
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(s.Close)
|
||||
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||
req := createTestMessageWithType("34.12.168.192.in-addr.arpa.", dns.TypePTR)
|
||||
|
||||
resp, err := dns.Exchange(req, addr.String())
|
||||
require.NoError(t, err)
|
||||
require.NoErrorf(t, err, "%s", addr)
|
||||
|
||||
require.Len(t, resp.Answer, 1)
|
||||
|
||||
@@ -1049,7 +1048,7 @@ func TestPTRResponseFromDHCPLeases(t *testing.T) {
|
||||
|
||||
ptr, ok := resp.Answer[0].(*dns.PTR)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "myhost.", ptr.Ptr)
|
||||
assert.Equal(t, dns.Fqdn("myhost."+localDomain), ptr.Ptr)
|
||||
}
|
||||
|
||||
func TestPTRResponseFromHosts(t *testing.T) {
|
||||
@@ -1175,25 +1174,48 @@ func TestNewServer(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestServer_Exchange(t *testing.T) {
|
||||
extUpstream := &aghtest.Upstream{
|
||||
Reverse: map[string][]string{
|
||||
"1.1.1.1.in-addr.arpa.": {"one.one.one.one"},
|
||||
const (
|
||||
onesHost = "one.one.one.one"
|
||||
localDomainHost = "local.domain"
|
||||
)
|
||||
|
||||
var (
|
||||
onesIP = net.IP{1, 1, 1, 1}
|
||||
localIP = net.IP{192, 168, 1, 1}
|
||||
)
|
||||
|
||||
revExtIPv4, err := netutil.IPToReversedAddr(onesIP)
|
||||
require.NoError(t, err)
|
||||
|
||||
extUpstream := &aghtest.UpstreamMock{
|
||||
OnAddress: func() (addr string) { return "external.upstream.example" },
|
||||
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
resp = aghalg.Coalesce(
|
||||
aghtest.RespondTo(t, req, dns.ClassINET, dns.TypePTR, revExtIPv4, onesHost),
|
||||
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
|
||||
)
|
||||
|
||||
return resp, nil
|
||||
},
|
||||
}
|
||||
locUpstream := &aghtest.Upstream{
|
||||
Reverse: map[string][]string{
|
||||
"1.1.168.192.in-addr.arpa.": {"local.domain"},
|
||||
"2.1.168.192.in-addr.arpa.": {},
|
||||
|
||||
revLocIPv4, err := netutil.IPToReversedAddr(localIP)
|
||||
require.NoError(t, err)
|
||||
|
||||
locUpstream := &aghtest.UpstreamMock{
|
||||
OnAddress: func() (addr string) { return "local.upstream.example" },
|
||||
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
resp = aghalg.Coalesce(
|
||||
aghtest.RespondTo(t, req, dns.ClassINET, dns.TypePTR, revLocIPv4, localDomainHost),
|
||||
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
|
||||
)
|
||||
|
||||
return resp, nil
|
||||
},
|
||||
}
|
||||
upstreamErr := errors.Error("upstream error")
|
||||
errUpstream := &aghtest.TestErrUpstream{
|
||||
Err: upstreamErr,
|
||||
}
|
||||
nonPtrUpstream := &aghtest.TestBlockUpstream{
|
||||
Hostname: "some-host",
|
||||
Block: true,
|
||||
}
|
||||
|
||||
errUpstream := aghtest.NewErrorUpstream()
|
||||
nonPtrUpstream := aghtest.NewBlockUpstream("some-host", true)
|
||||
|
||||
srv := NewCustomServer(&proxy.Proxy{
|
||||
Config: proxy.Config{
|
||||
@@ -1207,7 +1229,6 @@ func TestServer_Exchange(t *testing.T) {
|
||||
|
||||
srv.privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed)
|
||||
|
||||
localIP := net.IP{192, 168, 1, 1}
|
||||
testCases := []struct {
|
||||
name string
|
||||
want string
|
||||
@@ -1216,20 +1237,20 @@ func TestServer_Exchange(t *testing.T) {
|
||||
req net.IP
|
||||
}{{
|
||||
name: "external_good",
|
||||
want: "one.one.one.one",
|
||||
want: onesHost,
|
||||
wantErr: nil,
|
||||
locUpstream: nil,
|
||||
req: net.IP{1, 1, 1, 1},
|
||||
req: onesIP,
|
||||
}, {
|
||||
name: "local_good",
|
||||
want: "local.domain",
|
||||
want: localDomainHost,
|
||||
wantErr: nil,
|
||||
locUpstream: locUpstream,
|
||||
req: localIP,
|
||||
}, {
|
||||
name: "upstream_error",
|
||||
want: "",
|
||||
wantErr: upstreamErr,
|
||||
wantErr: aghtest.ErrUpstream,
|
||||
locUpstream: errUpstream,
|
||||
req: localIP,
|
||||
}, {
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -17,6 +16,8 @@ import (
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
type dnsConfig struct {
|
||||
@@ -363,6 +364,21 @@ func newUpstreamConfig(upstreams []string) (conf *proxy.UpstreamConfig, err erro
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
for _, u := range upstreams {
|
||||
var ups string
|
||||
var domains []string
|
||||
ups, domains, err = separateUpstream(u)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = validateUpstream(ups, domains)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("validating upstream %q: %w", u, err)
|
||||
}
|
||||
}
|
||||
|
||||
conf, err = proxy.ParseUpstreamsConfig(
|
||||
upstreams,
|
||||
&upstream.Options{Bootstrap: []string{}, Timeout: DefaultTimeout},
|
||||
@@ -373,13 +389,6 @@ func newUpstreamConfig(upstreams []string) (conf *proxy.UpstreamConfig, err erro
|
||||
return nil, errors.Error("no default upstreams specified")
|
||||
}
|
||||
|
||||
for _, u := range upstreams {
|
||||
_, err = validateUpstream(u)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
@@ -393,20 +402,6 @@ func ValidateUpstreams(upstreams []string) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
// stringKeysSorted returns the sorted slice of string keys of m.
|
||||
//
|
||||
// TODO(e.burkov): Use generics in Go 1.18. Move into golibs.
|
||||
func stringKeysSorted(m map[string][]upstream.Upstream) (sorted []string) {
|
||||
sorted = make([]string, 0, len(m))
|
||||
for s := range m {
|
||||
sorted = append(sorted, s)
|
||||
}
|
||||
|
||||
sort.Strings(sorted)
|
||||
|
||||
return sorted
|
||||
}
|
||||
|
||||
// ValidateUpstreamsPrivate validates each upstream and returns an error if any
|
||||
// upstream is invalid or if there are no default upstreams specified. It also
|
||||
// checks each domain of domain-specific upstreams for being ARPA pointing to
|
||||
@@ -421,9 +416,11 @@ func ValidateUpstreamsPrivate(upstreams []string, privateNets netutil.SubnetSet)
|
||||
return nil
|
||||
}
|
||||
|
||||
var errs []error
|
||||
keys := maps.Keys(conf.DomainReservedUpstreams)
|
||||
slices.Sort(keys)
|
||||
|
||||
for _, domain := range stringKeysSorted(conf.DomainReservedUpstreams) {
|
||||
var errs []error
|
||||
for _, domain := range keys {
|
||||
var subnet *net.IPNet
|
||||
subnet, err = netutil.SubnetFromReversedAddr(domain)
|
||||
if err != nil {
|
||||
@@ -449,16 +446,14 @@ func ValidateUpstreamsPrivate(upstreams []string, privateNets netutil.SubnetSet)
|
||||
|
||||
var protocols = []string{"udp://", "tcp://", "tls://", "https://", "sdns://", "quic://"}
|
||||
|
||||
func validateUpstream(u string) (useDefault bool, err error) {
|
||||
// Check if the user tries to specify upstream for domain.
|
||||
var isDomainSpec bool
|
||||
u, isDomainSpec, err = separateUpstream(u)
|
||||
if err != nil {
|
||||
return !isDomainSpec, err
|
||||
}
|
||||
|
||||
// validateUpstream returns an error if u alongside with domains is not a valid
|
||||
// upstream configuration. useDefault is true if the upstream is
|
||||
// domain-specific and is configured to point at the default upstream server
|
||||
// which is validated separately. The upstream is considered domain-specific
|
||||
// only if domains is at least not nil.
|
||||
func validateUpstream(u string, domains []string) (useDefault bool, err error) {
|
||||
// The special server address '#' means that default server must be used.
|
||||
if useDefault = !isDomainSpec; u == "#" && isDomainSpec {
|
||||
if useDefault = u == "#" && domains != nil; useDefault {
|
||||
return useDefault, nil
|
||||
}
|
||||
|
||||
@@ -485,12 +480,14 @@ func validateUpstream(u string) (useDefault bool, err error) {
|
||||
return useDefault, nil
|
||||
}
|
||||
|
||||
// separateUpstream returns the upstream without the specified domains.
|
||||
// isDomainSpec is true when the upstream is domains-specific.
|
||||
func separateUpstream(upstreamStr string) (upstream string, isDomainSpec bool, err error) {
|
||||
// separateUpstream returns the upstream and the specified domains. domains is
|
||||
// nil when the upstream is not domains-specific. Otherwise it may also be
|
||||
// empty.
|
||||
func separateUpstream(upstreamStr string) (ups string, domains []string, err error) {
|
||||
if !strings.HasPrefix(upstreamStr, "[/") {
|
||||
return upstreamStr, false, nil
|
||||
return upstreamStr, nil, nil
|
||||
}
|
||||
|
||||
defer func() { err = errors.Annotate(err, "bad upstream for domain %q: %w", upstreamStr) }()
|
||||
|
||||
parts := strings.Split(upstreamStr[2:], "/]")
|
||||
@@ -498,39 +495,46 @@ func separateUpstream(upstreamStr string) (upstream string, isDomainSpec bool, e
|
||||
case 2:
|
||||
// Go on.
|
||||
case 1:
|
||||
return "", false, errors.Error("missing separator")
|
||||
return "", nil, errors.Error("missing separator")
|
||||
default:
|
||||
return "", true, errors.Error("duplicated separator")
|
||||
return "", []string{}, errors.Error("duplicated separator")
|
||||
}
|
||||
|
||||
var domains string
|
||||
domains, upstream = parts[0], parts[1]
|
||||
for i, host := range strings.Split(domains, "/") {
|
||||
for i, host := range strings.Split(parts[0], "/") {
|
||||
if host == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
err = netutil.ValidateDomainName(host)
|
||||
err = netutil.ValidateDomainName(strings.TrimPrefix(host, "*."))
|
||||
if err != nil {
|
||||
return "", true, fmt.Errorf("domain at index %d: %w", i, err)
|
||||
return "", domains, fmt.Errorf("domain at index %d: %w", i, err)
|
||||
}
|
||||
|
||||
domains = append(domains, host)
|
||||
}
|
||||
|
||||
return upstream, true, nil
|
||||
return parts[1], domains, nil
|
||||
}
|
||||
|
||||
// excFunc is a signature of function to check if upstream exchanges correctly.
|
||||
type excFunc func(u upstream.Upstream) (err error)
|
||||
// healthCheckFunc is a signature of function to check if upstream exchanges
|
||||
// properly.
|
||||
type healthCheckFunc func(u upstream.Upstream) (err error)
|
||||
|
||||
// checkDNSUpstreamExc checks if the DNS upstream exchanges correctly.
|
||||
func checkDNSUpstreamExc(u upstream.Upstream) (err error) {
|
||||
// testTLD is the special-use fully-qualified domain name for testing the
|
||||
// DNS server reachability.
|
||||
//
|
||||
// See https://datatracker.ietf.org/doc/html/rfc6761#section-6.2.
|
||||
const testTLD = "test."
|
||||
|
||||
req := &dns.Msg{
|
||||
MsgHdr: dns.MsgHdr{
|
||||
Id: dns.Id(),
|
||||
RecursionDesired: true,
|
||||
},
|
||||
Question: []dns.Question{{
|
||||
Name: "google-public-dns-a.google.com.",
|
||||
Name: testTLD,
|
||||
Qtype: dns.TypeA,
|
||||
Qclass: dns.ClassINET,
|
||||
}},
|
||||
@@ -540,12 +544,8 @@ func checkDNSUpstreamExc(u upstream.Upstream) (err error) {
|
||||
reply, err = u.Exchange(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't communicate with upstream: %w", err)
|
||||
}
|
||||
|
||||
if len(reply.Answer) != 1 {
|
||||
return fmt.Errorf("wrong response")
|
||||
} else if a, ok := reply.Answer[0].(*dns.A); !ok || !a.A.Equal(net.IP{8, 8, 8, 8}) {
|
||||
return fmt.Errorf("wrong response")
|
||||
} else if len(reply.Answer) != 0 {
|
||||
return errors.Error("wrong response")
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -553,14 +553,22 @@ func checkDNSUpstreamExc(u upstream.Upstream) (err error) {
|
||||
|
||||
// checkPrivateUpstreamExc checks if the upstream for resolving private
|
||||
// addresses exchanges correctly.
|
||||
//
|
||||
// TODO(e.burkov): Think about testing the ip6.arpa. as well.
|
||||
func checkPrivateUpstreamExc(u upstream.Upstream) (err error) {
|
||||
// inAddrArpaTLD is the special-use fully-qualified domain name for PTR IP
|
||||
// address resolution.
|
||||
//
|
||||
// See https://datatracker.ietf.org/doc/html/rfc1035#section-3.5.
|
||||
const inAddrArpaTLD = "in-addr.arpa."
|
||||
|
||||
req := &dns.Msg{
|
||||
MsgHdr: dns.MsgHdr{
|
||||
Id: dns.Id(),
|
||||
RecursionDesired: true,
|
||||
},
|
||||
Question: []dns.Question{{
|
||||
Name: "1.0.0.127.in-addr.arpa.",
|
||||
Name: inAddrArpaTLD,
|
||||
Qtype: dns.TypePTR,
|
||||
Qclass: dns.ClassINET,
|
||||
}},
|
||||
@@ -573,46 +581,66 @@ func checkPrivateUpstreamExc(u upstream.Upstream) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkDNS(input string, bootstrap []string, timeout time.Duration, ef excFunc) (err error) {
|
||||
if IsCommentOrEmpty(input) {
|
||||
// domainSpecificTestError is a wrapper for errors returned by checkDNS to mark
|
||||
// the tested upstream domain-specific and therefore consider its errors
|
||||
// non-critical.
|
||||
//
|
||||
// TODO(a.garipov): Some common mechanism of distinguishing between errors and
|
||||
// warnings (non-critical errors) is desired.
|
||||
type domainSpecificTestError struct {
|
||||
error
|
||||
}
|
||||
|
||||
// checkDNS checks the upstream server defined by upstreamConfigStr using
|
||||
// healthCheck for actually exchange messages. It uses bootstrap to resolve the
|
||||
// upstream's address.
|
||||
func checkDNS(
|
||||
upstreamConfigStr string,
|
||||
bootstrap []string,
|
||||
timeout time.Duration,
|
||||
healthCheck healthCheckFunc,
|
||||
) (err error) {
|
||||
if IsCommentOrEmpty(upstreamConfigStr) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Separate upstream from domains list.
|
||||
var useDefault bool
|
||||
if useDefault, err = validateUpstream(input); err != nil {
|
||||
upstreamAddr, domains, err := separateUpstream(upstreamConfigStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("wrong upstream format: %w", err)
|
||||
}
|
||||
|
||||
// No need to check this DNS server.
|
||||
if !useDefault {
|
||||
useDefault, err := validateUpstream(upstreamAddr, domains)
|
||||
if err != nil {
|
||||
return fmt.Errorf("wrong upstream format: %w", err)
|
||||
} else if useDefault {
|
||||
return nil
|
||||
}
|
||||
|
||||
if input, _, err = separateUpstream(input); err != nil {
|
||||
return fmt.Errorf("wrong upstream format: %w", err)
|
||||
}
|
||||
|
||||
if len(bootstrap) == 0 {
|
||||
bootstrap = defaultBootstrap
|
||||
}
|
||||
|
||||
log.Debug("checking if upstream %s works", input)
|
||||
log.Debug("dnsforward: checking if upstream %q works", upstreamAddr)
|
||||
|
||||
var u upstream.Upstream
|
||||
u, err = upstream.AddressToUpstream(input, &upstream.Options{
|
||||
u, err := upstream.AddressToUpstream(upstreamAddr, &upstream.Options{
|
||||
Bootstrap: bootstrap,
|
||||
Timeout: timeout,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to choose upstream for %q: %w", input, err)
|
||||
return fmt.Errorf("failed to choose upstream for %q: %w", upstreamAddr, err)
|
||||
}
|
||||
|
||||
if err = ef(u); err != nil {
|
||||
return fmt.Errorf("upstream %q fails to exchange: %w", input, err)
|
||||
if err = healthCheck(u); err != nil {
|
||||
err = fmt.Errorf("upstream %q fails to exchange: %w", upstreamAddr, err)
|
||||
if domains != nil {
|
||||
return domainSpecificTestError{error: err}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug("upstream %s is ok", input)
|
||||
log.Debug("dnsforward: upstream %q is ok", upstreamAddr)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -635,6 +663,9 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||
if err != nil {
|
||||
log.Info("%v", err)
|
||||
result[host] = err.Error()
|
||||
if _, ok := err.(domainSpecificTestError); ok {
|
||||
result[host] = fmt.Sprintf("WARNING: %s", result[host])
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
@@ -650,6 +681,9 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||
// above, we rewriting the error for it. These cases should be
|
||||
// handled properly instead.
|
||||
result[host] = err.Error()
|
||||
if _, ok := err.(domainSpecificTestError); ok {
|
||||
result[host] = fmt.Sprintf("WARNING: %s", result[host])
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -34,7 +34,7 @@ func (fsr *fakeSystemResolvers) Get() (rs []string) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func loadTestData(t *testing.T, casesFileName string, cases interface{}) {
|
||||
func loadTestData(t *testing.T, casesFileName string, cases any) {
|
||||
t.Helper()
|
||||
|
||||
var f *os.File
|
||||
@@ -185,7 +185,8 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
|
||||
wantSet: "",
|
||||
}, {
|
||||
name: "upstream_dns_bad",
|
||||
wantSet: `validating upstream servers: bad ipport address "!!!": ` +
|
||||
wantSet: `validating upstream servers: ` +
|
||||
`validating upstream "!!!": bad ipport address "!!!": ` +
|
||||
`address !!!: missing port in address`,
|
||||
}, {
|
||||
name: "bootstraps_bad",
|
||||
@@ -256,112 +257,6 @@ func TestIsCommentOrEmpty(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateUpstream(t *testing.T) {
|
||||
testCases := []struct {
|
||||
wantDef assert.BoolAssertionFunc
|
||||
name string
|
||||
upstream string
|
||||
wantErr string
|
||||
}{{
|
||||
wantDef: assert.True,
|
||||
name: "invalid",
|
||||
upstream: "1.2.3.4.5",
|
||||
wantErr: `bad ipport address "1.2.3.4.5": address 1.2.3.4.5: missing port in address`,
|
||||
}, {
|
||||
wantDef: assert.True,
|
||||
name: "invalid",
|
||||
upstream: "123.3.7m",
|
||||
wantErr: `bad ipport address "123.3.7m": address 123.3.7m: missing port in address`,
|
||||
}, {
|
||||
wantDef: assert.True,
|
||||
name: "invalid",
|
||||
upstream: "htttps://google.com/dns-query",
|
||||
wantErr: `wrong protocol`,
|
||||
}, {
|
||||
wantDef: assert.True,
|
||||
name: "invalid",
|
||||
upstream: "[/host.com]tls://dns.adguard.com",
|
||||
wantErr: `bad upstream for domain "[/host.com]tls://dns.adguard.com": missing separator`,
|
||||
}, {
|
||||
wantDef: assert.True,
|
||||
name: "invalid",
|
||||
upstream: "[host.ru]#",
|
||||
wantErr: `bad ipport address "[host.ru]#": address [host.ru]#: missing port in address`,
|
||||
}, {
|
||||
wantDef: assert.True,
|
||||
name: "valid_default",
|
||||
upstream: "1.1.1.1",
|
||||
wantErr: ``,
|
||||
}, {
|
||||
wantDef: assert.True,
|
||||
name: "valid_default",
|
||||
upstream: "tls://1.1.1.1",
|
||||
wantErr: ``,
|
||||
}, {
|
||||
wantDef: assert.True,
|
||||
name: "valid_default",
|
||||
upstream: "https://dns.adguard.com/dns-query",
|
||||
wantErr: ``,
|
||||
}, {
|
||||
wantDef: assert.True,
|
||||
name: "valid_default",
|
||||
upstream: "sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
|
||||
wantErr: ``,
|
||||
}, {
|
||||
wantDef: assert.True,
|
||||
name: "default_udp_host",
|
||||
upstream: "udp://dns.google",
|
||||
}, {
|
||||
wantDef: assert.True,
|
||||
name: "default_udp_ip",
|
||||
upstream: "udp://8.8.8.8",
|
||||
}, {
|
||||
wantDef: assert.False,
|
||||
name: "valid",
|
||||
upstream: "[/host.com/]1.1.1.1",
|
||||
wantErr: ``,
|
||||
}, {
|
||||
wantDef: assert.False,
|
||||
name: "valid",
|
||||
upstream: "[//]tls://1.1.1.1",
|
||||
wantErr: ``,
|
||||
}, {
|
||||
wantDef: assert.False,
|
||||
name: "valid",
|
||||
upstream: "[/www.host.com/]#",
|
||||
wantErr: ``,
|
||||
}, {
|
||||
wantDef: assert.False,
|
||||
name: "valid",
|
||||
upstream: "[/host.com/google.com/]8.8.8.8",
|
||||
wantErr: ``,
|
||||
}, {
|
||||
wantDef: assert.False,
|
||||
name: "valid",
|
||||
upstream: "[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
|
||||
wantErr: ``,
|
||||
}, {
|
||||
wantDef: assert.False,
|
||||
name: "idna",
|
||||
upstream: "[/пример.рф/]8.8.8.8",
|
||||
wantErr: ``,
|
||||
}, {
|
||||
wantDef: assert.False,
|
||||
name: "bad_domain",
|
||||
upstream: "[/!/]8.8.8.8",
|
||||
wantErr: `bad upstream for domain "[/!/]8.8.8.8": domain at index 0: ` +
|
||||
`bad domain name "!": bad domain name label "!": bad domain name label rune '!'`,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
defaultUpstream, err := validateUpstream(tc.upstream)
|
||||
testutil.AssertErrorMsg(t, tc.wantErr, err)
|
||||
tc.wantDef(t, defaultUpstream)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateUpstreams(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
@@ -376,7 +271,7 @@ func TestValidateUpstreams(t *testing.T) {
|
||||
wantErr: ``,
|
||||
set: []string{"# comment"},
|
||||
}, {
|
||||
name: "valid_no_default",
|
||||
name: "no_default",
|
||||
wantErr: `no default upstreams specified`,
|
||||
set: []string{
|
||||
"[/host.com/]1.1.1.1",
|
||||
@@ -386,7 +281,7 @@ func TestValidateUpstreams(t *testing.T) {
|
||||
"[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
|
||||
},
|
||||
}, {
|
||||
name: "valid_with_default",
|
||||
name: "with_default",
|
||||
wantErr: ``,
|
||||
set: []string{
|
||||
"[/host.com/]1.1.1.1",
|
||||
@@ -398,8 +293,46 @@ func TestValidateUpstreams(t *testing.T) {
|
||||
},
|
||||
}, {
|
||||
name: "invalid",
|
||||
wantErr: `cannot prepare the upstream dhcp://fake.dns ([]): unsupported url scheme: dhcp`,
|
||||
wantErr: `validating upstream "dhcp://fake.dns": wrong protocol`,
|
||||
set: []string{"dhcp://fake.dns"},
|
||||
}, {
|
||||
name: "invalid",
|
||||
wantErr: `validating upstream "1.2.3.4.5": bad ipport address "1.2.3.4.5": address 1.2.3.4.5: missing port in address`,
|
||||
set: []string{"1.2.3.4.5"},
|
||||
}, {
|
||||
name: "invalid",
|
||||
wantErr: `validating upstream "123.3.7m": bad ipport address "123.3.7m": address 123.3.7m: missing port in address`,
|
||||
set: []string{"123.3.7m"},
|
||||
}, {
|
||||
name: "invalid",
|
||||
wantErr: `bad upstream for domain "[/host.com]tls://dns.adguard.com": missing separator`,
|
||||
set: []string{"[/host.com]tls://dns.adguard.com"},
|
||||
}, {
|
||||
name: "invalid",
|
||||
wantErr: `validating upstream "[host.ru]#": bad ipport address "[host.ru]#": address [host.ru]#: missing port in address`,
|
||||
set: []string{"[host.ru]#"},
|
||||
}, {
|
||||
name: "valid_default",
|
||||
wantErr: ``,
|
||||
set: []string{
|
||||
"1.1.1.1",
|
||||
"tls://1.1.1.1",
|
||||
"https://dns.adguard.com/dns-query",
|
||||
"sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
|
||||
"udp://dns.google",
|
||||
"udp://8.8.8.8",
|
||||
"[/host.com/]1.1.1.1",
|
||||
"[//]tls://1.1.1.1",
|
||||
"[/www.host.com/]#",
|
||||
"[/host.com/google.com/]8.8.8.8",
|
||||
"[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
|
||||
"[/пример.рф/]8.8.8.8",
|
||||
},
|
||||
}, {
|
||||
name: "bad_domain",
|
||||
wantErr: `bad upstream for domain "[/!/]8.8.8.8": domain at index 0: ` +
|
||||
`bad domain name "!": bad domain name label "!": bad domain name label rune '!'`,
|
||||
set: []string{"[/!/]8.8.8.8"},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
|
||||
@@ -64,9 +64,9 @@ func (s *Server) logQuery(
|
||||
Answer: pctx.Res,
|
||||
OrigAnswer: dctx.origResp,
|
||||
Result: dctx.result,
|
||||
Elapsed: elapsed,
|
||||
ClientID: dctx.clientID,
|
||||
ClientIP: ip,
|
||||
Elapsed: elapsed,
|
||||
AuthenticatedData: dctx.responseAD,
|
||||
}
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ func (l *testQueryLog) Add(p *querylog.AddParams) {
|
||||
type testStats struct {
|
||||
// Stats is embedded here simply to make testStats a stats.Stats without
|
||||
// actually implementing all methods.
|
||||
stats.Stats
|
||||
stats.Interface
|
||||
|
||||
lastEntry stats.Entry
|
||||
}
|
||||
|
||||
@@ -20,8 +20,12 @@ type svc struct {
|
||||
// client/src/helpers/constants.js
|
||||
// client/src/components/ui/Icons.js
|
||||
var serviceRulesArray = []svc{{
|
||||
name: "whatsapp",
|
||||
rules: []string{"||whatsapp.net^", "||whatsapp.com^"},
|
||||
name: "whatsapp",
|
||||
rules: []string{
|
||||
"||wa.me^",
|
||||
"||whatsapp.com^",
|
||||
"||whatsapp.net^",
|
||||
},
|
||||
}, {
|
||||
name: "facebook",
|
||||
rules: []string{
|
||||
@@ -31,29 +35,43 @@ var serviceRulesArray = []svc{{
|
||||
"||accountkit.com^",
|
||||
"||fb.me^",
|
||||
"||fb.com^",
|
||||
"||fb.gg^",
|
||||
"||fbsbx.com^",
|
||||
"||fbwat.ch^",
|
||||
"||messenger.com^",
|
||||
"||facebookcorewwwi.onion^",
|
||||
"||fbcdn.com^",
|
||||
"||fb.watch^",
|
||||
},
|
||||
}, {
|
||||
name: "twitter",
|
||||
rules: []string{"||twitter.com^", "||twttr.com^", "||t.co^", "||twimg.com^"},
|
||||
name: "twitter",
|
||||
rules: []string{
|
||||
"||t.co^",
|
||||
"||twimg.com^",
|
||||
"||twitter.com^",
|
||||
"||twttr.com^",
|
||||
},
|
||||
}, {
|
||||
name: "youtube",
|
||||
rules: []string{
|
||||
"||youtube.com^",
|
||||
"||ytimg.com^",
|
||||
"||youtu.be^",
|
||||
"||googlevideo.com^",
|
||||
"||youtubei.googleapis.com^",
|
||||
"||youtube-nocookie.com^",
|
||||
"||wide-youtube.l.google.com^",
|
||||
"||youtu.be^",
|
||||
"||youtube",
|
||||
"||youtube-nocookie.com^",
|
||||
"||youtube.com^",
|
||||
"||youtubei.googleapis.com^",
|
||||
"||youtubekids.com^",
|
||||
"||ytimg.com^",
|
||||
},
|
||||
}, {
|
||||
name: "twitch",
|
||||
rules: []string{"||twitch.tv^", "||ttvnw.net^", "||jtvnw.net^", "||twitchcdn.net^"},
|
||||
name: "twitch",
|
||||
rules: []string{
|
||||
"||jtvnw.net^",
|
||||
"||ttvnw.net^",
|
||||
"||twitch.tv^",
|
||||
"||twitchcdn.net^",
|
||||
},
|
||||
}, {
|
||||
name: "netflix",
|
||||
rules: []string{
|
||||
@@ -83,20 +101,36 @@ var serviceRulesArray = []svc{{
|
||||
"||discordapp.net^",
|
||||
"||discordapp.com^",
|
||||
"||discord.com^",
|
||||
"||discord.gift",
|
||||
"||discord.media^",
|
||||
},
|
||||
}, {
|
||||
name: "ok",
|
||||
rules: []string{"||ok.ru^"},
|
||||
}, {
|
||||
name: "skype",
|
||||
rules: []string{"||skype.com^", "||skypeassets.com^"},
|
||||
name: "skype",
|
||||
rules: []string{
|
||||
"||edge-skype-com.s-0001.s-msedge.net^",
|
||||
"||skype-edf.akadns.net^",
|
||||
"||skype.com^",
|
||||
"||skypeassets.com^",
|
||||
"||skypedata.akadns.net^",
|
||||
},
|
||||
}, {
|
||||
name: "vk",
|
||||
rules: []string{"||vk.com^", "||userapi.com^", "||vk-cdn.net^", "||vkuservideo.net^"},
|
||||
name: "vk",
|
||||
rules: []string{
|
||||
"||userapi.com^",
|
||||
"||vk-cdn.net^",
|
||||
"||vk.com^",
|
||||
"||vkuservideo.net^",
|
||||
},
|
||||
}, {
|
||||
name: "origin",
|
||||
rules: []string{"||origin.com^", "||signin.ea.com^", "||accounts.ea.com^"},
|
||||
name: "origin",
|
||||
rules: []string{
|
||||
"||accounts.ea.com^",
|
||||
"||origin.com^",
|
||||
"||signin.ea.com^",
|
||||
},
|
||||
}, {
|
||||
name: "steam",
|
||||
rules: []string{
|
||||
@@ -160,6 +194,7 @@ var serviceRulesArray = []svc{{
|
||||
"||amazon.com.br^",
|
||||
"||amazon.co.jp^",
|
||||
"||amazon.com.mx^",
|
||||
"||amazon.com.tr^",
|
||||
"||amazon.co.uk^",
|
||||
"||createspace.com^",
|
||||
"||aws",
|
||||
@@ -209,47 +244,81 @@ var serviceRulesArray = []svc{{
|
||||
"||toutiaocloud.net^",
|
||||
"||bdurl.com^",
|
||||
"||bytecdn.cn^",
|
||||
"||bytedapm.com^",
|
||||
"||byteimg.com^",
|
||||
"||byteoversea.com^",
|
||||
"||ixigua.com^",
|
||||
"||muscdn.com^",
|
||||
"||bytedance.map.fastly.net^",
|
||||
"||douyin.com^",
|
||||
"||tiktokv.com^",
|
||||
"||toutiaovod.com^",
|
||||
"||douyincdn.com^",
|
||||
},
|
||||
}, {
|
||||
name: "vimeo",
|
||||
rules: []string{"||vimeo.com^", "||vimeocdn.com^", "*vod-adaptive.akamaized.net^"},
|
||||
name: "vimeo",
|
||||
rules: []string{
|
||||
"*vod-adaptive.akamaized.net^",
|
||||
"||vimeo.com^",
|
||||
"||vimeocdn.com^",
|
||||
},
|
||||
}, {
|
||||
name: "pinterest",
|
||||
rules: []string{"||pinterest.*^", "||pinimg.com^"},
|
||||
name: "pinterest",
|
||||
rules: []string{
|
||||
"||pinimg.com^",
|
||||
"||pinterest.*^",
|
||||
},
|
||||
}, {
|
||||
name: "imgur",
|
||||
rules: []string{"||imgur.com^"},
|
||||
}, {
|
||||
name: "dailymotion",
|
||||
rules: []string{"||dailymotion.com^", "||dm-event.net^", "||dmcdn.net^"},
|
||||
name: "dailymotion",
|
||||
rules: []string{
|
||||
"||dailymotion.com^",
|
||||
"||dm-event.net^",
|
||||
"||dmcdn.net^",
|
||||
},
|
||||
}, {
|
||||
name: "qq",
|
||||
rules: []string{
|
||||
// Block qq.com and subdomains excluding WeChat's domains.
|
||||
"||qq.com^$denyallow=wx.qq.com|weixin.qq.com",
|
||||
"||qqzaixian.com^",
|
||||
"||qq-video.cdn-go.cn^",
|
||||
"||url.cn^",
|
||||
},
|
||||
}, {
|
||||
name: "wechat",
|
||||
rules: []string{"||wechat.com^", "||weixin.qq.com^", "||wx.qq.com^"},
|
||||
name: "wechat",
|
||||
rules: []string{
|
||||
"||wechat.com^",
|
||||
"||weixin.qq.com.cn^",
|
||||
"||weixin.qq.com^",
|
||||
"||weixinbridge.com^",
|
||||
"||wx.qq.com^",
|
||||
},
|
||||
}, {
|
||||
name: "viber",
|
||||
rules: []string{"||viber.com^"},
|
||||
}, {
|
||||
name: "weibo",
|
||||
rules: []string{"||weibo.com^"},
|
||||
name: "weibo",
|
||||
rules: []string{
|
||||
"||weibo.cn^",
|
||||
"||weibo.com^",
|
||||
"||weibocdn.com^",
|
||||
},
|
||||
}, {
|
||||
name: "9gag",
|
||||
rules: []string{"||9cache.com^", "||9gag.com^"},
|
||||
name: "9gag",
|
||||
rules: []string{
|
||||
"||9cache.com^",
|
||||
"||9gag.com^",
|
||||
},
|
||||
}, {
|
||||
name: "telegram",
|
||||
rules: []string{"||t.me^", "||telegram.me^", "||telegram.org^"},
|
||||
name: "telegram",
|
||||
rules: []string{
|
||||
"||t.me^",
|
||||
"||telegram.me^",
|
||||
"||telegram.org^",
|
||||
},
|
||||
}, {
|
||||
name: "disneyplus",
|
||||
rules: []string{
|
||||
@@ -283,6 +352,17 @@ var serviceRulesArray = []svc{{
|
||||
"||tinder.com^",
|
||||
"||tindersparks.com^",
|
||||
},
|
||||
}, {
|
||||
name: "bilibili",
|
||||
rules: []string{
|
||||
"||biliapi.net^",
|
||||
"||bilibili.com^",
|
||||
"||biligame.com^",
|
||||
"||bilivideo.cn^",
|
||||
"||bilivideo.com^",
|
||||
"||dreamcast.hk^",
|
||||
"||hdslb.com^",
|
||||
},
|
||||
}}
|
||||
|
||||
// convert array to map
|
||||
|
||||
@@ -61,22 +61,22 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
|
||||
|
||||
testCasesA := []struct {
|
||||
name string
|
||||
want []interface{}
|
||||
want []any
|
||||
rcode int
|
||||
dtyp uint16
|
||||
}{{
|
||||
name: "a-record",
|
||||
rcode: dns.RcodeSuccess,
|
||||
want: []interface{}{ipv4p1},
|
||||
want: []any{ipv4p1},
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "aaaa-record",
|
||||
want: []interface{}{ipv6p1},
|
||||
want: []any{ipv6p1},
|
||||
rcode: dns.RcodeSuccess,
|
||||
dtyp: dns.TypeAAAA,
|
||||
}, {
|
||||
name: "txt-record",
|
||||
want: []interface{}{"hello-world"},
|
||||
want: []any{"hello-world"},
|
||||
rcode: dns.RcodeSuccess,
|
||||
dtyp: dns.TypeTXT,
|
||||
}, {
|
||||
@@ -86,22 +86,22 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
|
||||
dtyp: 0,
|
||||
}, {
|
||||
name: "a-records",
|
||||
want: []interface{}{ipv4p1, ipv4p2},
|
||||
want: []any{ipv4p1, ipv4p2},
|
||||
rcode: dns.RcodeSuccess,
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "aaaa-records",
|
||||
want: []interface{}{ipv6p1, ipv6p2},
|
||||
want: []any{ipv6p1, ipv6p2},
|
||||
rcode: dns.RcodeSuccess,
|
||||
dtyp: dns.TypeAAAA,
|
||||
}, {
|
||||
name: "disable-one",
|
||||
want: []interface{}{ipv4p2},
|
||||
want: []any{ipv4p2},
|
||||
rcode: dns.RcodeSuccess,
|
||||
dtyp: dns.TypeA,
|
||||
}, {
|
||||
name: "disable-cname",
|
||||
want: []interface{}{ipv4p1},
|
||||
want: []any{ipv4p1},
|
||||
rcode: dns.RcodeSuccess,
|
||||
dtyp: dns.TypeA,
|
||||
}}
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"runtime"
|
||||
"runtime/debug"
|
||||
@@ -14,6 +13,7 @@ import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/cache"
|
||||
@@ -94,7 +94,7 @@ type Config struct {
|
||||
ConfigModified func() `yaml:"-"`
|
||||
|
||||
// Register an HTTP handler
|
||||
HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request)) `yaml:"-"`
|
||||
HTTPRegister aghhttp.RegisterFunc `yaml:"-"`
|
||||
|
||||
// CustomResolver is the resolver used by DNSFilter.
|
||||
CustomResolver Resolver `yaml:"-"`
|
||||
|
||||
@@ -21,6 +21,11 @@ func TestMain(m *testing.M) {
|
||||
aghtest.DiscardLogOutput(m)
|
||||
}
|
||||
|
||||
const (
|
||||
sbBlocked = "wmconvirus.narod.ru"
|
||||
pcBlocked = "pornhub.com"
|
||||
)
|
||||
|
||||
var setts = Settings{
|
||||
ProtectionEnabled: true,
|
||||
}
|
||||
@@ -173,43 +178,37 @@ func TestSafeBrowsing(t *testing.T) {
|
||||
|
||||
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
const matching = "wmconvirus.narod.ru"
|
||||
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
|
||||
Hostname: matching,
|
||||
Block: true,
|
||||
})
|
||||
d.checkMatch(t, matching)
|
||||
|
||||
require.Contains(t, logOutput.String(), "SafeBrowsing lookup for "+matching)
|
||||
d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
|
||||
d.checkMatch(t, sbBlocked)
|
||||
|
||||
d.checkMatch(t, "test."+matching)
|
||||
require.Contains(t, logOutput.String(), fmt.Sprintf("safebrowsing lookup for %q", sbBlocked))
|
||||
|
||||
d.checkMatch(t, "test."+sbBlocked)
|
||||
d.checkMatchEmpty(t, "yandex.ru")
|
||||
d.checkMatchEmpty(t, "pornhub.com")
|
||||
d.checkMatchEmpty(t, pcBlocked)
|
||||
|
||||
// Cached result.
|
||||
d.safeBrowsingServer = "127.0.0.1"
|
||||
d.checkMatch(t, matching)
|
||||
d.checkMatchEmpty(t, "pornhub.com")
|
||||
d.checkMatch(t, sbBlocked)
|
||||
d.checkMatchEmpty(t, pcBlocked)
|
||||
d.safeBrowsingServer = defaultSafebrowsingServer
|
||||
}
|
||||
|
||||
func TestParallelSB(t *testing.T) {
|
||||
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
const matching = "wmconvirus.narod.ru"
|
||||
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
|
||||
Hostname: matching,
|
||||
Block: true,
|
||||
})
|
||||
|
||||
d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
|
||||
|
||||
t.Run("group", func(t *testing.T) {
|
||||
for i := 0; i < 100; i++ {
|
||||
t.Run(fmt.Sprintf("aaa%d", i), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
d.checkMatch(t, matching)
|
||||
d.checkMatch(t, "test."+matching)
|
||||
d.checkMatch(t, sbBlocked)
|
||||
d.checkMatch(t, "test."+sbBlocked)
|
||||
d.checkMatchEmpty(t, "yandex.ru")
|
||||
d.checkMatchEmpty(t, "pornhub.com")
|
||||
d.checkMatchEmpty(t, pcBlocked)
|
||||
})
|
||||
}
|
||||
})
|
||||
@@ -382,23 +381,19 @@ func TestParentalControl(t *testing.T) {
|
||||
|
||||
d := newForTest(t, &Config{ParentalEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
const matching = "pornhub.com"
|
||||
d.SetParentalUpstream(&aghtest.TestBlockUpstream{
|
||||
Hostname: matching,
|
||||
Block: true,
|
||||
})
|
||||
|
||||
d.checkMatch(t, matching)
|
||||
require.Contains(t, logOutput.String(), "Parental lookup for "+matching)
|
||||
d.SetParentalUpstream(aghtest.NewBlockUpstream(pcBlocked, true))
|
||||
d.checkMatch(t, pcBlocked)
|
||||
require.Contains(t, logOutput.String(), fmt.Sprintf("parental lookup for %q", pcBlocked))
|
||||
|
||||
d.checkMatch(t, "www."+matching)
|
||||
d.checkMatch(t, "www."+pcBlocked)
|
||||
d.checkMatchEmpty(t, "www.yandex.ru")
|
||||
d.checkMatchEmpty(t, "yandex.ru")
|
||||
d.checkMatchEmpty(t, "api.jquery.com")
|
||||
|
||||
// Test cached result.
|
||||
d.parentalServer = "127.0.0.1"
|
||||
d.checkMatch(t, matching)
|
||||
d.checkMatch(t, pcBlocked)
|
||||
d.checkMatchEmpty(t, "yandex.ru")
|
||||
}
|
||||
|
||||
@@ -445,7 +440,7 @@ func TestMatching(t *testing.T) {
|
||||
}, {
|
||||
name: "sanity",
|
||||
rules: "||doubleclick.net^",
|
||||
host: "wmconvirus.narod.ru",
|
||||
host: sbBlocked,
|
||||
wantIsFiltered: false,
|
||||
wantReason: NotFilteredNotFound,
|
||||
wantDNSType: dns.TypeA,
|
||||
@@ -765,14 +760,9 @@ func TestClientSettings(t *testing.T) {
|
||||
}},
|
||||
)
|
||||
t.Cleanup(d.Close)
|
||||
d.SetParentalUpstream(&aghtest.TestBlockUpstream{
|
||||
Hostname: "pornhub.com",
|
||||
Block: true,
|
||||
})
|
||||
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
|
||||
Hostname: "wmconvirus.narod.ru",
|
||||
Block: true,
|
||||
})
|
||||
|
||||
d.SetParentalUpstream(aghtest.NewBlockUpstream(pcBlocked, true))
|
||||
d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
|
||||
|
||||
type testCase struct {
|
||||
name string
|
||||
@@ -787,12 +777,12 @@ func TestClientSettings(t *testing.T) {
|
||||
wantReason: FilteredBlockList,
|
||||
}, {
|
||||
name: "parental",
|
||||
host: "pornhub.com",
|
||||
host: pcBlocked,
|
||||
before: true,
|
||||
wantReason: FilteredParental,
|
||||
}, {
|
||||
name: "safebrowsing",
|
||||
host: "wmconvirus.narod.ru",
|
||||
host: sbBlocked,
|
||||
before: false,
|
||||
wantReason: FilteredSafeBrowsing,
|
||||
}, {
|
||||
@@ -836,33 +826,29 @@ func TestClientSettings(t *testing.T) {
|
||||
func BenchmarkSafeBrowsing(b *testing.B) {
|
||||
d := newForTest(b, &Config{SafeBrowsingEnabled: true}, nil)
|
||||
b.Cleanup(d.Close)
|
||||
blocked := "wmconvirus.narod.ru"
|
||||
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
|
||||
Hostname: blocked,
|
||||
Block: true,
|
||||
})
|
||||
|
||||
d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
res, err := d.CheckHost(blocked, dns.TypeA, &setts)
|
||||
res, err := d.CheckHost(sbBlocked, dns.TypeA, &setts)
|
||||
require.NoError(b, err)
|
||||
|
||||
assert.True(b, res.IsFiltered, "Expected hostname %s to match", blocked)
|
||||
assert.Truef(b, res.IsFiltered, "expected hostname %q to match", sbBlocked)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSafeBrowsingParallel(b *testing.B) {
|
||||
d := newForTest(b, &Config{SafeBrowsingEnabled: true}, nil)
|
||||
b.Cleanup(d.Close)
|
||||
blocked := "wmconvirus.narod.ru"
|
||||
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
|
||||
Hostname: blocked,
|
||||
Block: true,
|
||||
})
|
||||
|
||||
d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
res, err := d.CheckHost(blocked, dns.TypeA, &setts)
|
||||
res, err := d.CheckHost(sbBlocked, dns.TypeA, &setts)
|
||||
require.NoError(b, err)
|
||||
|
||||
assert.True(b, res.IsFiltered, "Expected hostname %s to match", blocked)
|
||||
assert.Truef(b, res.IsFiltered, "expected hostname %q to match", sbBlocked)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -24,10 +24,11 @@ import (
|
||||
|
||||
// Safe browsing and parental control methods.
|
||||
|
||||
// TODO(a.garipov): Make configurable.
|
||||
const (
|
||||
dnsTimeout = 3 * time.Second
|
||||
defaultSafebrowsingServer = `https://dns-family.adguard.com/dns-query`
|
||||
defaultParentalServer = `https://dns-family.adguard.com/dns-query`
|
||||
defaultSafebrowsingServer = `https://family.adguard-dns.com/dns-query`
|
||||
defaultParentalServer = `https://family.adguard-dns.com/dns-query`
|
||||
sbTXTSuffix = `sb.dns.adguard.com.`
|
||||
pcTXTSuffix = `pc.dns.adguard.com.`
|
||||
)
|
||||
@@ -313,7 +314,7 @@ func (d *DNSFilter) checkSafeBrowsing(
|
||||
|
||||
if log.GetLevel() >= log.DEBUG {
|
||||
timer := log.StartTimer()
|
||||
defer timer.LogElapsed("SafeBrowsing lookup for %s", host)
|
||||
defer timer.LogElapsed("safebrowsing lookup for %q", host)
|
||||
}
|
||||
|
||||
sctx := &sbCtx{
|
||||
@@ -347,7 +348,7 @@ func (d *DNSFilter) checkParental(
|
||||
|
||||
if log.GetLevel() >= log.DEBUG {
|
||||
timer := log.StartTimer()
|
||||
defer timer.LogElapsed("Parental lookup for %s", host)
|
||||
defer timer.LogElapsed("parental lookup for %q", host)
|
||||
}
|
||||
|
||||
sctx := &sbCtx{
|
||||
|
||||
@@ -74,21 +74,20 @@ func TestSafeBrowsingCache(t *testing.T) {
|
||||
c.hashToHost[hash] = "sub.host.com"
|
||||
assert.Equal(t, -1, c.getCached())
|
||||
|
||||
// match "sub.host.com" from cache,
|
||||
// but another hash for "nonexisting.com" is not in cache
|
||||
// which means that we must get data from server for it
|
||||
// Match "sub.host.com" from cache. Another hash for "host.example" is not
|
||||
// in the cache, so get data for it from the server.
|
||||
c.hashToHost = make(map[[32]byte]string)
|
||||
hash = sha256.Sum256([]byte("sub.host.com"))
|
||||
c.hashToHost[hash] = "sub.host.com"
|
||||
hash = sha256.Sum256([]byte("nonexisting.com"))
|
||||
c.hashToHost[hash] = "nonexisting.com"
|
||||
hash = sha256.Sum256([]byte("host.example"))
|
||||
c.hashToHost[hash] = "host.example"
|
||||
assert.Empty(t, c.getCached())
|
||||
|
||||
hash = sha256.Sum256([]byte("sub.host.com"))
|
||||
_, ok := c.hashToHost[hash]
|
||||
assert.False(t, ok)
|
||||
|
||||
hash = sha256.Sum256([]byte("nonexisting.com"))
|
||||
hash = sha256.Sum256([]byte("host.example"))
|
||||
_, ok = c.hashToHost[hash]
|
||||
assert.True(t, ok)
|
||||
|
||||
@@ -111,8 +110,7 @@ func TestSBPC_checkErrorUpstream(t *testing.T) {
|
||||
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
ups := &aghtest.TestErrUpstream{}
|
||||
|
||||
ups := aghtest.NewErrorUpstream()
|
||||
d.SetSafeBrowsingUpstream(ups)
|
||||
d.SetParentalUpstream(ups)
|
||||
|
||||
@@ -170,10 +168,16 @@ func TestSBPC(t *testing.T) {
|
||||
|
||||
for _, tc := range testCases {
|
||||
// Prepare the upstream.
|
||||
ups := &aghtest.TestBlockUpstream{
|
||||
Hostname: hostname,
|
||||
Block: tc.block,
|
||||
ups := aghtest.NewBlockUpstream(hostname, tc.block)
|
||||
|
||||
var numReq int
|
||||
onExchange := ups.OnExchange
|
||||
ups.OnExchange = func(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
numReq++
|
||||
|
||||
return onExchange(req)
|
||||
}
|
||||
|
||||
d.SetSafeBrowsingUpstream(ups)
|
||||
d.SetParentalUpstream(ups)
|
||||
|
||||
@@ -196,7 +200,7 @@ func TestSBPC(t *testing.T) {
|
||||
assert.Equal(t, hits, tc.testCache.Stats().Hit)
|
||||
|
||||
// There was one request to an upstream.
|
||||
assert.Equal(t, 1, ups.RequestsCount())
|
||||
assert.Equal(t, 1, numReq)
|
||||
|
||||
// Now make the same request to check the cache was used.
|
||||
res, err = tc.testFunc(hostname, dns.TypeA, setts)
|
||||
@@ -214,7 +218,7 @@ func TestSBPC(t *testing.T) {
|
||||
assert.Equal(t, hits+1, tc.testCache.Stats().Hit)
|
||||
|
||||
// Check that there were no additional requests.
|
||||
assert.Equal(t, 1, ups.RequestsCount())
|
||||
assert.Equal(t, 1, numReq)
|
||||
})
|
||||
|
||||
purgeCaches(d)
|
||||
|
||||
@@ -2,9 +2,11 @@ package home
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding"
|
||||
"fmt"
|
||||
"net"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -59,6 +61,33 @@ const (
|
||||
ClientSourceHostsFile
|
||||
)
|
||||
|
||||
var _ fmt.Stringer = clientSource(0)
|
||||
|
||||
// String returns a human-readable name of cs.
|
||||
func (cs clientSource) String() (s string) {
|
||||
switch cs {
|
||||
case ClientSourceWHOIS:
|
||||
return "WHOIS"
|
||||
case ClientSourceARP:
|
||||
return "ARP"
|
||||
case ClientSourceRDNS:
|
||||
return "rDNS"
|
||||
case ClientSourceDHCP:
|
||||
return "DHCP"
|
||||
case ClientSourceHostsFile:
|
||||
return "etc/hosts"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
var _ encoding.TextMarshaler = clientSource(0)
|
||||
|
||||
// MarshalText implements encoding.TextMarshaler for the clientSource.
|
||||
func (cs clientSource) MarshalText() (text []byte, err error) {
|
||||
return []byte(cs.String()), nil
|
||||
}
|
||||
|
||||
// clientSourceConf is used to configure where the runtime clients will be
|
||||
// obtained from.
|
||||
type clientSourcesConf struct {
|
||||
@@ -396,6 +425,7 @@ func (clients *clientsContainer) Find(id string) (c *Client, ok bool) {
|
||||
c.Tags = stringutil.CloneSlice(c.Tags)
|
||||
c.BlockedServices = stringutil.CloneSlice(c.BlockedServices)
|
||||
c.Upstreams = stringutil.CloneSlice(c.Upstreams)
|
||||
|
||||
return c, true
|
||||
}
|
||||
|
||||
@@ -492,7 +522,7 @@ func (clients *clientsContainer) findLocked(id string) (c *Client, ok bool) {
|
||||
// findRuntimeClientLocked finds a runtime client by their IP address. For
|
||||
// internal use only.
|
||||
func (clients *clientsContainer) findRuntimeClientLocked(ip net.IP) (rc *RuntimeClient, ok bool) {
|
||||
var v interface{}
|
||||
var v any
|
||||
v, ok = clients.ipToRC.Get(ip)
|
||||
if !ok {
|
||||
return nil, false
|
||||
@@ -546,7 +576,7 @@ func (clients *clientsContainer) check(c *Client) (err error) {
|
||||
} else if mac, err = net.ParseMAC(id); err == nil {
|
||||
c.IDs[i] = mac.String()
|
||||
} else if err = dnsforward.ValidateClientID(id); err == nil {
|
||||
c.IDs[i] = id
|
||||
c.IDs[i] = strings.ToLower(id)
|
||||
} else {
|
||||
return fmt.Errorf("invalid clientid at index %d: %q", i, id)
|
||||
}
|
||||
@@ -742,8 +772,7 @@ func (clients *clientsContainer) AddHost(ip net.IP, host string, src clientSourc
|
||||
|
||||
// addHostLocked adds a new IP-hostname pairing. For internal use only.
|
||||
func (clients *clientsContainer) addHostLocked(ip net.IP, host string, src clientSource) (ok bool) {
|
||||
var rc *RuntimeClient
|
||||
rc, ok = clients.findRuntimeClientLocked(ip)
|
||||
rc, ok := clients.findRuntimeClientLocked(ip)
|
||||
if ok {
|
||||
if rc.Source > src {
|
||||
return false
|
||||
@@ -769,7 +798,7 @@ func (clients *clientsContainer) addHostLocked(ip net.IP, host string, src clien
|
||||
// rmHostsBySrc removes all entries that match the specified source.
|
||||
func (clients *clientsContainer) rmHostsBySrc(src clientSource) {
|
||||
n := 0
|
||||
clients.ipToRC.Range(func(ip net.IP, v interface{}) (cont bool) {
|
||||
clients.ipToRC.Range(func(ip net.IP, v any) (cont bool) {
|
||||
rc, ok := v.(*RuntimeClient)
|
||||
if !ok {
|
||||
log.Error("clients: bad type %T in ipToRC for %s", v, ip)
|
||||
@@ -797,26 +826,21 @@ func (clients *clientsContainer) addFromHostsFile(hosts *netutil.IPMap) {
|
||||
clients.rmHostsBySrc(ClientSourceHostsFile)
|
||||
|
||||
n := 0
|
||||
hosts.Range(func(ip net.IP, v interface{}) (cont bool) {
|
||||
hosts, ok := v.(*stringutil.Set)
|
||||
hosts.Range(func(ip net.IP, v any) (cont bool) {
|
||||
rec, ok := v.(*aghnet.HostsRecord)
|
||||
if !ok {
|
||||
log.Error("dns: bad type %T in ipToRC for %s", v, ip)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
hosts.Range(func(name string) (cont bool) {
|
||||
if clients.addHostLocked(ip, name, ClientSourceHostsFile) {
|
||||
n++
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
clients.addHostLocked(ip, rec.Canonical, ClientSourceHostsFile)
|
||||
n++
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
log.Debug("clients: added %d client aliases from system hosts-file", n)
|
||||
log.Debug("clients: added %d client aliases from system hosts file", n)
|
||||
}
|
||||
|
||||
// addFromSystemARP adds the IP-hostname pairings from the output of the arp -a
|
||||
|
||||
@@ -47,9 +47,9 @@ type clientJSON struct {
|
||||
type runtimeClientJSON struct {
|
||||
WHOISInfo *RuntimeClientWHOISInfo `json:"whois_info"`
|
||||
|
||||
Name string `json:"name"`
|
||||
Source string `json:"source"`
|
||||
IP net.IP `json:"ip"`
|
||||
Name string `json:"name"`
|
||||
Source clientSource `json:"source"`
|
||||
IP net.IP `json:"ip"`
|
||||
}
|
||||
|
||||
type clientListJSON struct {
|
||||
@@ -70,7 +70,7 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
|
||||
data.Clients = append(data.Clients, cj)
|
||||
}
|
||||
|
||||
clients.ipToRC.Range(func(ip net.IP, v interface{}) (cont bool) {
|
||||
clients.ipToRC.Range(func(ip net.IP, v any) (cont bool) {
|
||||
rc, ok := v.(*RuntimeClient)
|
||||
if !ok {
|
||||
log.Error("dns: bad type %T in ipToRC for %s", v, ip)
|
||||
@@ -81,20 +81,9 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
|
||||
cj := runtimeClientJSON{
|
||||
WHOISInfo: rc.WHOISInfo,
|
||||
|
||||
Name: rc.Host,
|
||||
IP: ip,
|
||||
}
|
||||
|
||||
cj.Source = "etc/hosts"
|
||||
switch rc.Source {
|
||||
case ClientSourceDHCP:
|
||||
cj.Source = "DHCP"
|
||||
case ClientSourceRDNS:
|
||||
cj.Source = "rDNS"
|
||||
case ClientSourceARP:
|
||||
cj.Source = "ARP"
|
||||
case ClientSourceWHOIS:
|
||||
cj.Source = "WHOIS"
|
||||
Name: rc.Host,
|
||||
Source: rc.Source,
|
||||
IP: ip,
|
||||
}
|
||||
|
||||
data.RuntimeClients = append(data.RuntimeClients, cj)
|
||||
@@ -107,13 +96,7 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
e := json.NewEncoder(w).Encode(data)
|
||||
if e != nil {
|
||||
aghhttp.Error(
|
||||
r,
|
||||
w,
|
||||
http.StatusInternalServerError,
|
||||
"Failed to encode to json: %v",
|
||||
e,
|
||||
)
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "failed to encode to json: %v", e)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -279,9 +262,9 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
|
||||
func (clients *clientsContainer) findRuntime(ip net.IP, idStr string) (cj *clientJSON) {
|
||||
rc, ok := clients.FindRuntimeClient(ip)
|
||||
if !ok {
|
||||
// It is still possible that the IP used to be in the runtime
|
||||
// clients list, but then the server was reloaded. So, check
|
||||
// the DNS server's blocked IP list.
|
||||
// It is still possible that the IP used to be in the runtime clients
|
||||
// list, but then the server was reloaded. So, check the DNS server's
|
||||
// blocked IP list.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/2428.
|
||||
disallowed, rule := clients.dnsServer.IsBlockedClient(ip, idStr)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
@@ -19,7 +20,7 @@ import (
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"github.com/google/renameio/maybe"
|
||||
yaml "gopkg.in/yaml.v2"
|
||||
yaml "gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -27,15 +28,36 @@ const (
|
||||
filterDir = "filters" // cache location for downloaded filters, it's under DataDir
|
||||
)
|
||||
|
||||
// logSettings
|
||||
// logSettings are the logging settings part of the configuration file.
|
||||
//
|
||||
// TODO(a.garipov): Put them into a separate object.
|
||||
type logSettings struct {
|
||||
LogCompress bool `yaml:"log_compress"` // Compress determines if the rotated log files should be compressed using gzip (default: false)
|
||||
LogLocalTime bool `yaml:"log_localtime"` // If the time used for formatting the timestamps in is the computer's local time (default: false [UTC])
|
||||
LogMaxBackups int `yaml:"log_max_backups"` // Maximum number of old log files to retain (MaxAge may still cause them to get deleted)
|
||||
LogMaxSize int `yaml:"log_max_size"` // Maximum size in megabytes of the log file before it gets rotated (default 100 MB)
|
||||
LogMaxAge int `yaml:"log_max_age"` // MaxAge is the maximum number of days to retain old log files
|
||||
LogFile string `yaml:"log_file"` // Path to the log file. If empty, write to stdout. If "syslog", writes to syslog
|
||||
Verbose bool `yaml:"verbose"` // If true, verbose logging is enabled
|
||||
// File is the path to the log file. If empty, logs are written to stdout.
|
||||
// If "syslog", logs are written to syslog.
|
||||
File string `yaml:"log_file"`
|
||||
|
||||
// MaxBackups is the maximum number of old log files to retain.
|
||||
//
|
||||
// NOTE: MaxAge may still cause them to get deleted.
|
||||
MaxBackups int `yaml:"log_max_backups"`
|
||||
|
||||
// MaxSize is the maximum size of the log file before it gets rotated, in
|
||||
// megabytes. The default value is 100 MB.
|
||||
MaxSize int `yaml:"log_max_size"`
|
||||
|
||||
// MaxAge is the maximum duration for retaining old log files, in days.
|
||||
MaxAge int `yaml:"log_max_age"`
|
||||
|
||||
// Compress determines, if the rotated log files should be compressed using
|
||||
// gzip.
|
||||
Compress bool `yaml:"log_compress"`
|
||||
|
||||
// LocalTime determines, if the time used for formatting the timestamps in
|
||||
// is the computer's local time.
|
||||
LocalTime bool `yaml:"log_localtime"`
|
||||
|
||||
// Verbose determines, if verbose (aka debug) logging is enabled.
|
||||
Verbose bool `yaml:"verbose"`
|
||||
}
|
||||
|
||||
// osConfig contains OS-related configuration.
|
||||
@@ -223,11 +245,11 @@ var config = &configuration{
|
||||
},
|
||||
},
|
||||
logSettings: logSettings{
|
||||
LogCompress: false,
|
||||
LogLocalTime: false,
|
||||
LogMaxBackups: 0,
|
||||
LogMaxSize: 100,
|
||||
LogMaxAge: 3,
|
||||
Compress: false,
|
||||
LocalTime: false,
|
||||
MaxBackups: 0,
|
||||
MaxSize: 100,
|
||||
MaxAge: 3,
|
||||
},
|
||||
OSConfig: &osConfig{},
|
||||
SchemaVersion: currentSchemaVersion,
|
||||
@@ -302,27 +324,28 @@ func parseConfig() (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
uc := aghalg.UniqChecker{}
|
||||
addPorts(
|
||||
uc,
|
||||
tcpPort(config.BindPort),
|
||||
tcpPort(config.BetaBindPort),
|
||||
udpPort(config.DNS.Port),
|
||||
)
|
||||
tcpPorts := aghalg.UniqChecker[tcpPort]{}
|
||||
addPorts(tcpPorts, tcpPort(config.BindPort), tcpPort(config.BetaBindPort))
|
||||
|
||||
udpPorts := aghalg.UniqChecker[udpPort]{}
|
||||
addPorts(udpPorts, udpPort(config.DNS.Port))
|
||||
|
||||
if config.TLS.Enabled {
|
||||
addPorts(
|
||||
uc,
|
||||
// TODO(e.burkov): Consider adding a udpPort with the same value if
|
||||
// we ever support the HTTP/3 for web admin interface.
|
||||
tcpPorts,
|
||||
tcpPort(config.TLS.PortHTTPS),
|
||||
tcpPort(config.TLS.PortDNSOverTLS),
|
||||
udpPort(config.TLS.PortDNSOverQUIC),
|
||||
tcpPort(config.TLS.PortDNSCrypt),
|
||||
)
|
||||
|
||||
// TODO(e.burkov): Consider adding a udpPort with the same value when
|
||||
// we add support for HTTP/3 for web admin interface.
|
||||
addPorts(udpPorts, udpPort(config.TLS.PortDNSOverQUIC))
|
||||
}
|
||||
if err = uc.Validate(aghalg.IntIsBefore); err != nil {
|
||||
return fmt.Errorf("validating ports: %w", err)
|
||||
if err = tcpPorts.Validate(); err != nil {
|
||||
return fmt.Errorf("validating tcp ports: %w", err)
|
||||
} else if err = udpPorts.Validate(); err != nil {
|
||||
return fmt.Errorf("validating udp ports: %w", err)
|
||||
}
|
||||
|
||||
if !checkFiltersUpdateIntervalHours(config.DNS.FiltersUpdateIntervalHours) {
|
||||
@@ -342,23 +365,11 @@ type udpPort int
|
||||
// tcpPort is the port number for TCP protocol.
|
||||
type tcpPort int
|
||||
|
||||
// addPorts is a helper for ports validation. It skips zero ports. Each of
|
||||
// ports should be either a udpPort or a tcpPort.
|
||||
func addPorts(uc aghalg.UniqChecker, ports ...interface{}) {
|
||||
// addPorts is a helper for ports validation that skips zero ports.
|
||||
func addPorts[T tcpPort | udpPort](uc aghalg.UniqChecker[T], ports ...T) {
|
||||
for _, p := range ports {
|
||||
// Use separate cases for tcpPort and udpPort so that the untyped
|
||||
// constant zero is converted to the appropriate type.
|
||||
switch p := p.(type) {
|
||||
case tcpPort:
|
||||
if p != 0 {
|
||||
uc.Add(p)
|
||||
}
|
||||
case udpPort:
|
||||
if p != 0 {
|
||||
uc.Add(p)
|
||||
}
|
||||
default:
|
||||
// Go on.
|
||||
if p != 0 {
|
||||
uc.Add(p)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -377,13 +388,14 @@ func readConfigFile() (fileData []byte, err error) {
|
||||
}
|
||||
|
||||
// Saves configuration to the YAML file and also saves the user filter contents to a file
|
||||
func (c *configuration) write() error {
|
||||
func (c *configuration) write() (err error) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
if Context.auth != nil {
|
||||
config.Users = Context.auth.GetUsers()
|
||||
}
|
||||
|
||||
if Context.tls != nil {
|
||||
tlsConf := tlsConfigSettings{}
|
||||
Context.tls.WriteDiskConfig(&tlsConf)
|
||||
@@ -429,19 +441,20 @@ func (c *configuration) write() error {
|
||||
config.Clients.Persistent = Context.clients.forConfig()
|
||||
|
||||
configFile := config.getConfigFilename()
|
||||
log.Debug("Writing YAML file: %s", configFile)
|
||||
yamlText, err := yaml.Marshal(&config)
|
||||
if err != nil {
|
||||
log.Error("Couldn't generate YAML file: %s", err)
|
||||
log.Debug("writing config file %q", configFile)
|
||||
|
||||
return err
|
||||
buf := &bytes.Buffer{}
|
||||
enc := yaml.NewEncoder(buf)
|
||||
enc.SetIndent(2)
|
||||
|
||||
err = enc.Encode(config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("generating config file: %w", err)
|
||||
}
|
||||
|
||||
err = maybe.WriteFile(configFile, yamlText, 0o644)
|
||||
err = maybe.WriteFile(configFile, buf.Bytes(), 0o644)
|
||||
if err != nil {
|
||||
log.Error("Couldn't save YAML config: %s", err)
|
||||
|
||||
return err
|
||||
return fmt.Errorf("writing config file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -189,7 +189,7 @@ func registerControlHandlers() {
|
||||
RegisterAuthHandlers()
|
||||
}
|
||||
|
||||
func httpRegister(method, url string, handler func(http.ResponseWriter, *http.Request)) {
|
||||
func httpRegister(method, url string, handler http.HandlerFunc) {
|
||||
if method == "" {
|
||||
// "/dns-query" handler doesn't need auth, gzip and isn't restricted by 1 HTTP method
|
||||
Context.mux.HandleFunc(url, postInstall(handler))
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
@@ -57,8 +58,8 @@ func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request
|
||||
|
||||
err = validateFilterURL(fj.URL)
|
||||
if err != nil {
|
||||
msg := fmt.Sprintf("invalid url: %s", err)
|
||||
http.Error(w, msg, http.StatusBadRequest)
|
||||
err = fmt.Errorf("invalid url: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -178,16 +179,16 @@ func (f *Filtering) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Requ
|
||||
}
|
||||
}
|
||||
|
||||
type filterURLJSON struct {
|
||||
type filterURLReqData struct {
|
||||
Name string `json:"name"`
|
||||
URL string `json:"url"`
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
type filterURLReq struct {
|
||||
URL string `json:"url"`
|
||||
Whitelist bool `json:"whitelist"`
|
||||
Data filterURLJSON `json:"data"`
|
||||
Data *filterURLReqData `json:"data"`
|
||||
URL string `json:"url"`
|
||||
Whitelist bool `json:"whitelist"`
|
||||
}
|
||||
|
||||
func (f *Filtering) handleFilteringSetURL(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -199,10 +200,17 @@ func (f *Filtering) handleFilteringSetURL(w http.ResponseWriter, r *http.Request
|
||||
return
|
||||
}
|
||||
|
||||
if fj.Data == nil {
|
||||
err = errors.Error("data cannot be null")
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = validateFilterURL(fj.Data.URL)
|
||||
if err != nil {
|
||||
msg := fmt.Sprintf("invalid url: %s", err)
|
||||
http.Error(w, msg, http.StatusBadRequest)
|
||||
err = fmt.Errorf("invalid url: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -223,11 +231,8 @@ func (f *Filtering) handleFilteringSetURL(w http.ResponseWriter, r *http.Request
|
||||
}
|
||||
|
||||
onConfigModified()
|
||||
restart := false
|
||||
if (status & statusEnabledChanged) != 0 {
|
||||
// we must add or remove filter rules
|
||||
restart = true
|
||||
}
|
||||
|
||||
restart := (status & statusEnabledChanged) != 0
|
||||
if (status&statusUpdateRequired) != 0 && fj.Data.Enabled {
|
||||
// download new filter and apply its rules
|
||||
flags := filterRefreshBlocklists
|
||||
@@ -242,6 +247,7 @@ func (f *Filtering) handleFilteringSetURL(w http.ResponseWriter, r *http.Request
|
||||
restart = true
|
||||
}
|
||||
}
|
||||
|
||||
if restart {
|
||||
enableFilters(true)
|
||||
}
|
||||
@@ -311,20 +317,20 @@ func (f *Filtering) handleFilteringRefresh(w http.ResponseWriter, r *http.Reques
|
||||
}
|
||||
|
||||
type filterJSON struct {
|
||||
ID int64 `json:"id"`
|
||||
Enabled bool `json:"enabled"`
|
||||
URL string `json:"url"`
|
||||
Name string `json:"name"`
|
||||
LastUpdated string `json:"last_updated,omitempty"`
|
||||
ID int64 `json:"id"`
|
||||
RulesCount uint32 `json:"rules_count"`
|
||||
LastUpdated string `json:"last_updated"`
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
type filteringConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Interval uint32 `json:"interval"` // in hours
|
||||
Filters []filterJSON `json:"filters"`
|
||||
WhitelistFilters []filterJSON `json:"whitelist_filters"`
|
||||
UserRules []string `json:"user_rules"`
|
||||
Interval uint32 `json:"interval"` // in hours
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
func filterToJSON(f filter) filterJSON {
|
||||
@@ -402,16 +408,12 @@ func (f *Filtering) handleFilteringConfig(w http.ResponseWriter, r *http.Request
|
||||
}
|
||||
|
||||
type checkHostRespRule struct {
|
||||
FilterListID int64 `json:"filter_list_id"`
|
||||
Text string `json:"text"`
|
||||
FilterListID int64 `json:"filter_list_id"`
|
||||
}
|
||||
|
||||
type checkHostResp struct {
|
||||
Reason string `json:"reason"`
|
||||
// FilterID is the ID of the rule's filter list.
|
||||
//
|
||||
// Deprecated: Use Rules[*].FilterListID.
|
||||
FilterID int64 `json:"filter_id"`
|
||||
|
||||
// Rule is the text of the matched rule.
|
||||
//
|
||||
@@ -426,6 +428,11 @@ type checkHostResp struct {
|
||||
// for Rewrite:
|
||||
CanonName string `json:"cname"` // CNAME value
|
||||
IPList []net.IP `json:"ip_addrs"` // list of IP addresses
|
||||
|
||||
// FilterID is the ID of the rule's filter list.
|
||||
//
|
||||
// Deprecated: Use Rules[*].FilterListID.
|
||||
FilterID int64 `json:"filter_id"`
|
||||
}
|
||||
|
||||
func (f *Filtering) handleCheckHost(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
@@ -105,19 +105,22 @@ type checkConfResp struct {
|
||||
|
||||
// validateWeb returns error is the web part if the initial configuration can't
|
||||
// be set.
|
||||
func (req *checkConfReq) validateWeb(uc aghalg.UniqChecker) (err error) {
|
||||
func (req *checkConfReq) validateWeb(tcpPorts aghalg.UniqChecker[tcpPort]) (err error) {
|
||||
defer func() { err = errors.Annotate(err, "validating ports: %w") }()
|
||||
|
||||
port := req.Web.Port
|
||||
addPorts(uc, tcpPort(config.BetaBindPort), tcpPort(port))
|
||||
if err = uc.Validate(aghalg.IntIsBefore); err != nil {
|
||||
// Avoid duplicating the error into the status of DNS.
|
||||
uc[port] = 1
|
||||
portInt := req.Web.Port
|
||||
port := tcpPort(portInt)
|
||||
addPorts(tcpPorts, tcpPort(config.BetaBindPort), port)
|
||||
if err = tcpPorts.Validate(); err != nil {
|
||||
// Reset the value for the port to 1 to make sure that validateDNS
|
||||
// doesn't throw the same error, unless the same TCP port is set there
|
||||
// as well.
|
||||
tcpPorts[port] = 1
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
switch port {
|
||||
switch portInt {
|
||||
case 0, config.BindPort:
|
||||
return nil
|
||||
default:
|
||||
@@ -125,21 +128,18 @@ func (req *checkConfReq) validateWeb(uc aghalg.UniqChecker) (err error) {
|
||||
// unbound after install.
|
||||
}
|
||||
|
||||
return aghnet.CheckPort("tcp", req.Web.IP, port)
|
||||
return aghnet.CheckPort("tcp", req.Web.IP, portInt)
|
||||
}
|
||||
|
||||
// validateDNS returns error if the DNS part of the initial configuration can't
|
||||
// be set. canAutofix is true if the port can be unbound by AdGuard Home
|
||||
// automatically.
|
||||
func (req *checkConfReq) validateDNS(uc aghalg.UniqChecker) (canAutofix bool, err error) {
|
||||
func (req *checkConfReq) validateDNS(
|
||||
tcpPorts aghalg.UniqChecker[tcpPort],
|
||||
) (canAutofix bool, err error) {
|
||||
defer func() { err = errors.Annotate(err, "validating ports: %w") }()
|
||||
|
||||
port := req.DNS.Port
|
||||
addPorts(uc, udpPort(port))
|
||||
if err = uc.Validate(aghalg.IntIsBefore); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
switch port {
|
||||
case 0:
|
||||
return false, nil
|
||||
@@ -148,6 +148,11 @@ func (req *checkConfReq) validateDNS(uc aghalg.UniqChecker) (canAutofix bool, er
|
||||
// by AdGuard Home for web interface.
|
||||
default:
|
||||
// Check TCP as well.
|
||||
addPorts(tcpPorts, tcpPort(port))
|
||||
if err = tcpPorts.Validate(); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
err = aghnet.CheckPort("tcp", req.DNS.IP, port)
|
||||
if err != nil {
|
||||
return false, err
|
||||
@@ -185,13 +190,12 @@ func (web *Web) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
resp := &checkConfResp{}
|
||||
uc := aghalg.UniqChecker{}
|
||||
|
||||
if err = req.validateWeb(uc); err != nil {
|
||||
tcpPorts := aghalg.UniqChecker[tcpPort]{}
|
||||
if err = req.validateWeb(tcpPorts); err != nil {
|
||||
resp.Web.Status = err.Error()
|
||||
}
|
||||
|
||||
if resp.DNS.CanAutofix, err = req.validateDNS(uc); err != nil {
|
||||
if resp.DNS.CanAutofix, err = req.validateDNS(tcpPorts); err != nil {
|
||||
resp.DNS.Status = err.Error()
|
||||
} else if !req.DNS.IP.IsUnspecified() {
|
||||
resp.StaticIP = handleStaticIP(req.DNS.IP, req.SetStaticIP)
|
||||
@@ -212,7 +216,7 @@ func (web *Web) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request)
|
||||
func handleStaticIP(ip net.IP, set bool) staticIPJSON {
|
||||
resp := staticIPJSON{}
|
||||
|
||||
interfaceName := aghnet.GetInterfaceByIP(ip)
|
||||
interfaceName := aghnet.InterfaceByIP(ip)
|
||||
resp.Static = "no"
|
||||
|
||||
if len(interfaceName) == 0 {
|
||||
|
||||
@@ -7,11 +7,11 @@ import (
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/updater"
|
||||
@@ -117,7 +117,18 @@ func handleUpdate(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
err := Context.updater.Update()
|
||||
// Retain the current absolute path of the executable, since the updater is
|
||||
// likely to change the position current one to the backup directory.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/4735.
|
||||
execPath, err := os.Executable()
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "getting path: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = Context.updater.Update()
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
|
||||
|
||||
@@ -129,13 +140,10 @@ func handleUpdate(w http.ResponseWriter, r *http.Request) {
|
||||
f.Flush()
|
||||
}
|
||||
|
||||
// The background context is used because the underlying functions wrap
|
||||
// it with timeout and shut down the server, which handles current
|
||||
// request. It also should be done in a separate goroutine due to the
|
||||
// same reason.
|
||||
go func() {
|
||||
finishUpdate(context.Background())
|
||||
}()
|
||||
// The background context is used because the underlying functions wrap it
|
||||
// with timeout and shut down the server, which handles current request. It
|
||||
// also should be done in a separate goroutine for the same reason.
|
||||
go finishUpdate(context.Background(), execPath)
|
||||
}
|
||||
|
||||
// versionResponse is the response for /control/version.json endpoint.
|
||||
@@ -147,8 +155,8 @@ type versionResponse struct {
|
||||
// setAllowedToAutoUpdate sets CanAutoUpdate to true if AdGuard Home is actually
|
||||
// allowed to perform an automatic update by the OS.
|
||||
func (vr *versionResponse) setAllowedToAutoUpdate() (err error) {
|
||||
if vr.CanAutoUpdate == nil || !*vr.CanAutoUpdate {
|
||||
return
|
||||
if vr.CanAutoUpdate != aghalg.NBTrue {
|
||||
return nil
|
||||
}
|
||||
|
||||
tlsConf := &tlsConfigSettings{}
|
||||
@@ -162,7 +170,7 @@ func (vr *versionResponse) setAllowedToAutoUpdate() (err error) {
|
||||
}
|
||||
}
|
||||
|
||||
vr.CanAutoUpdate = &canUpdate
|
||||
vr.CanAutoUpdate = aghalg.BoolToNullBool(canUpdate)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -174,46 +182,46 @@ func tlsConfUsesPrivilegedPorts(c *tlsConfigSettings) (ok bool) {
|
||||
}
|
||||
|
||||
// finishUpdate completes an update procedure.
|
||||
func finishUpdate(ctx context.Context) {
|
||||
log.Info("Stopping all tasks")
|
||||
func finishUpdate(ctx context.Context, execPath string) {
|
||||
var err error
|
||||
|
||||
log.Info("stopping all tasks")
|
||||
|
||||
cleanup(ctx)
|
||||
cleanupAlways()
|
||||
|
||||
exeName := "AdGuardHome"
|
||||
if runtime.GOOS == "windows" {
|
||||
exeName = "AdGuardHome.exe"
|
||||
}
|
||||
curBinName := filepath.Join(Context.workDir, exeName)
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
if Context.runningAsService {
|
||||
// Note:
|
||||
// we can't restart the service via "kardianos/service" package - it kills the process first
|
||||
// we can't start a new instance - Windows doesn't allow it
|
||||
// NOTE: We can't restart the service via "kardianos/service"
|
||||
// package, because it kills the process first we can't start a new
|
||||
// instance, because Windows doesn't allow it.
|
||||
//
|
||||
// TODO(a.garipov): Recheck the claim above.
|
||||
cmd := exec.Command("cmd", "/c", "net stop AdGuardHome & net start AdGuardHome")
|
||||
err := cmd.Start()
|
||||
err = cmd.Start()
|
||||
if err != nil {
|
||||
log.Fatalf("exec.Command() failed: %s", err)
|
||||
log.Fatalf("restarting: stopping: %s", err)
|
||||
}
|
||||
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
cmd := exec.Command(curBinName, os.Args[1:]...)
|
||||
log.Info("Restarting: %v", cmd.Args)
|
||||
cmd := exec.Command(execPath, os.Args[1:]...)
|
||||
log.Info("restarting: %q %q", execPath, os.Args[1:])
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
err := cmd.Start()
|
||||
err = cmd.Start()
|
||||
if err != nil {
|
||||
log.Fatalf("exec.Command() failed: %s", err)
|
||||
log.Fatalf("restarting:: %s", err)
|
||||
}
|
||||
|
||||
os.Exit(0)
|
||||
} else {
|
||||
log.Info("Restarting: %v", os.Args)
|
||||
err := syscall.Exec(curBinName, os.Args, os.Environ())
|
||||
if err != nil {
|
||||
log.Fatalf("syscall.Exec() failed: %s", err)
|
||||
}
|
||||
// Unreachable code
|
||||
}
|
||||
|
||||
log.Info("restarting: %q %q", execPath, os.Args[1:])
|
||||
err = syscall.Exec(execPath, os.Args, os.Environ())
|
||||
if err != nil {
|
||||
log.Fatalf("restarting: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/ameshkov/dnscrypt/v2"
|
||||
yaml "gopkg.in/yaml.v2"
|
||||
yaml "gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// Default ports.
|
||||
@@ -58,6 +58,7 @@ func initDNSServer() (err error) {
|
||||
}
|
||||
|
||||
conf := querylog.Config{
|
||||
Anonymizer: anonymizer,
|
||||
ConfigModified: onConfigModified,
|
||||
HTTPRegister: httpRegister,
|
||||
FindClient: Context.clients.findMultiple,
|
||||
@@ -67,7 +68,6 @@ func initDNSServer() (err error) {
|
||||
Enabled: config.DNS.QueryLogEnabled,
|
||||
FileEnabled: config.DNS.QueryLogFileEnabled,
|
||||
AnonymizeClientIP: config.DNS.AnonymizeClientIP,
|
||||
Anonymizer: anonymizer,
|
||||
}
|
||||
Context.queryLog = querylog.New(conf)
|
||||
|
||||
@@ -221,6 +221,10 @@ func generateServerConfig() (newConf dnsforward.ServerConfig, err error) {
|
||||
newConf.TLSConfig = tlsConf.TLSConfig
|
||||
newConf.TLSConfig.ServerName = tlsConf.ServerName
|
||||
|
||||
if tlsConf.PortHTTPS != 0 {
|
||||
newConf.HTTPSListenAddrs = ipsToTCPAddrs(hosts, tlsConf.PortHTTPS)
|
||||
}
|
||||
|
||||
if tlsConf.PortDNSOverTLS != 0 {
|
||||
newConf.TLSListenAddrs = ipsToTCPAddrs(hosts, tlsConf.PortDNSOverTLS)
|
||||
}
|
||||
@@ -392,7 +396,7 @@ func startDNSServer() error {
|
||||
Context.queryLog.Start()
|
||||
|
||||
const topClientsNumber = 100 // the number of clients to get
|
||||
for _, ip := range Context.stats.GetTopClientsIP(topClientsNumber) {
|
||||
for _, ip := range Context.stats.TopClientsIP(topClientsNumber) {
|
||||
if ip == nil {
|
||||
continue
|
||||
}
|
||||
@@ -451,7 +455,12 @@ func closeDNSServer() {
|
||||
}
|
||||
|
||||
if Context.stats != nil {
|
||||
Context.stats.Close()
|
||||
err := Context.stats.Close()
|
||||
if err != nil {
|
||||
log.Debug("closing stats: %s", err)
|
||||
}
|
||||
|
||||
// TODO(e.burkov): Find out if it's safe.
|
||||
Context.stats = nil
|
||||
}
|
||||
|
||||
|
||||
@@ -47,7 +47,7 @@ type homeContext struct {
|
||||
// --
|
||||
|
||||
clients clientsContainer // per-client-settings module
|
||||
stats stats.Stats // statistics module
|
||||
stats stats.Interface // statistics module
|
||||
queryLog querylog.QueryLog // query log module
|
||||
dnsServer *dnsforward.Server // DNS module
|
||||
rdns *RDNS // rDNS module
|
||||
@@ -298,24 +298,27 @@ func setupConfig(args options) (err error) {
|
||||
Context.clients.Init(config.Clients.Persistent, Context.dhcpServer, Context.etcHosts, arpdb)
|
||||
|
||||
if args.bindPort != 0 {
|
||||
uc := aghalg.UniqChecker{}
|
||||
addPorts(
|
||||
uc,
|
||||
tcpPort(args.bindPort),
|
||||
tcpPort(config.BetaBindPort),
|
||||
udpPort(config.DNS.Port),
|
||||
)
|
||||
tcpPorts := aghalg.UniqChecker[tcpPort]{}
|
||||
addPorts(tcpPorts, tcpPort(args.bindPort), tcpPort(config.BetaBindPort))
|
||||
|
||||
udpPorts := aghalg.UniqChecker[udpPort]{}
|
||||
addPorts(udpPorts, udpPort(config.DNS.Port))
|
||||
|
||||
if config.TLS.Enabled {
|
||||
addPorts(
|
||||
uc,
|
||||
tcpPorts,
|
||||
tcpPort(config.TLS.PortHTTPS),
|
||||
tcpPort(config.TLS.PortDNSOverTLS),
|
||||
udpPort(config.TLS.PortDNSOverQUIC),
|
||||
tcpPort(config.TLS.PortDNSCrypt),
|
||||
)
|
||||
|
||||
addPorts(udpPorts, udpPort(config.TLS.PortDNSOverQUIC))
|
||||
}
|
||||
if err = uc.Validate(aghalg.IntIsBefore); err != nil {
|
||||
return fmt.Errorf("validating ports: %w", err)
|
||||
|
||||
if err = tcpPorts.Validate(); err != nil {
|
||||
return fmt.Errorf("validating tcp ports: %w", err)
|
||||
} else if err = udpPorts.Validate(); err != nil {
|
||||
return fmt.Errorf("validating udp ports: %w", err)
|
||||
}
|
||||
|
||||
config.BindPort = args.bindPort
|
||||
@@ -599,17 +602,17 @@ func configureLogger(args options) {
|
||||
ls.Verbose = true
|
||||
}
|
||||
if args.logFile != "" {
|
||||
ls.LogFile = args.logFile
|
||||
} else if config.LogFile != "" {
|
||||
ls.LogFile = config.LogFile
|
||||
ls.File = args.logFile
|
||||
} else if config.File != "" {
|
||||
ls.File = config.File
|
||||
}
|
||||
|
||||
// Handle default log settings overrides
|
||||
ls.LogCompress = config.LogCompress
|
||||
ls.LogLocalTime = config.LogLocalTime
|
||||
ls.LogMaxBackups = config.LogMaxBackups
|
||||
ls.LogMaxSize = config.LogMaxSize
|
||||
ls.LogMaxAge = config.LogMaxAge
|
||||
ls.Compress = config.Compress
|
||||
ls.LocalTime = config.LocalTime
|
||||
ls.MaxBackups = config.MaxBackups
|
||||
ls.MaxSize = config.MaxSize
|
||||
ls.MaxAge = config.MaxAge
|
||||
|
||||
// log.SetLevel(log.INFO) - default
|
||||
if ls.Verbose {
|
||||
@@ -620,27 +623,27 @@ func configureLogger(args options) {
|
||||
// happen pretty quickly.
|
||||
log.SetFlags(log.LstdFlags | log.Lmicroseconds)
|
||||
|
||||
if args.runningAsService && ls.LogFile == "" && runtime.GOOS == "windows" {
|
||||
if args.runningAsService && ls.File == "" && runtime.GOOS == "windows" {
|
||||
// When running as a Windows service, use eventlog by default if nothing
|
||||
// else is configured. Otherwise, we'll simply lose the log output.
|
||||
ls.LogFile = configSyslog
|
||||
ls.File = configSyslog
|
||||
}
|
||||
|
||||
// logs are written to stdout (default)
|
||||
if ls.LogFile == "" {
|
||||
if ls.File == "" {
|
||||
return
|
||||
}
|
||||
|
||||
if ls.LogFile == configSyslog {
|
||||
if ls.File == configSyslog {
|
||||
// Use syslog where it is possible and eventlog on Windows
|
||||
err := aghos.ConfigureSyslog(serviceName)
|
||||
if err != nil {
|
||||
log.Fatalf("cannot initialize syslog: %s", err)
|
||||
}
|
||||
} else {
|
||||
logFilePath := filepath.Join(Context.workDir, ls.LogFile)
|
||||
if filepath.IsAbs(ls.LogFile) {
|
||||
logFilePath = ls.LogFile
|
||||
logFilePath := filepath.Join(Context.workDir, ls.File)
|
||||
if filepath.IsAbs(ls.File) {
|
||||
logFilePath = ls.File
|
||||
}
|
||||
|
||||
_, err := os.OpenFile(logFilePath, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0o644)
|
||||
@@ -650,11 +653,11 @@ func configureLogger(args options) {
|
||||
|
||||
log.SetOutput(&lumberjack.Logger{
|
||||
Filename: logFilePath,
|
||||
Compress: ls.LogCompress, // disabled by default
|
||||
LocalTime: ls.LogLocalTime,
|
||||
MaxBackups: ls.LogMaxBackups,
|
||||
MaxSize: ls.LogMaxSize, // megabytes
|
||||
MaxAge: ls.LogMaxAge, // days
|
||||
Compress: ls.Compress, // disabled by default
|
||||
LocalTime: ls.LocalTime,
|
||||
MaxBackups: ls.MaxBackups,
|
||||
MaxSize: ls.MaxSize, // megabytes
|
||||
MaxAge: ls.MaxAge, // days
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
|
||||
// TODO(a.garipov): Get rid of a global or generate from .twosky.json.
|
||||
var allowedLanguages = stringutil.NewSet(
|
||||
"ar",
|
||||
"be",
|
||||
"bg",
|
||||
"cs",
|
||||
@@ -50,7 +51,7 @@ var allowedLanguages = stringutil.NewSet(
|
||||
"zh-tw",
|
||||
)
|
||||
|
||||
func handleI18nCurrentLanguage(w http.ResponseWriter, r *http.Request) {
|
||||
func handleI18nCurrentLanguage(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
log.Printf("config.Language is %s", config.Language)
|
||||
_, err := fmt.Fprintf(w, "%s\n", config.Language)
|
||||
@@ -58,6 +59,7 @@ func handleI18nCurrentLanguage(w http.ResponseWriter, r *http.Request) {
|
||||
msg := fmt.Sprintf("Unable to write response json: %s", err)
|
||||
log.Println(msg)
|
||||
http.Error(w, msg, http.StatusInternalServerError)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -69,6 +71,7 @@ func handleI18nChangeLanguage(w http.ResponseWriter, r *http.Request) {
|
||||
msg := fmt.Sprintf("failed to read request body: %s", err)
|
||||
log.Println(msg)
|
||||
http.Error(w, msg, http.StatusBadRequest)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
uuid "github.com/satori/go.uuid"
|
||||
"github.com/google/uuid"
|
||||
"howett.net/plist"
|
||||
)
|
||||
|
||||
@@ -47,9 +47,9 @@ type payloadContent struct {
|
||||
|
||||
PayloadType string
|
||||
PayloadIdentifier string
|
||||
PayloadUUID string
|
||||
PayloadDisplayName string
|
||||
PayloadDescription string
|
||||
PayloadUUID uuid.UUID
|
||||
PayloadVersion int
|
||||
}
|
||||
|
||||
@@ -63,18 +63,14 @@ const dnsSettingsPayloadType = "com.apple.dnsSettings.managed"
|
||||
type mobileConfig struct {
|
||||
PayloadDescription string
|
||||
PayloadDisplayName string
|
||||
PayloadIdentifier string
|
||||
PayloadType string
|
||||
PayloadUUID string
|
||||
PayloadContent []*payloadContent
|
||||
PayloadIdentifier uuid.UUID
|
||||
PayloadUUID uuid.UUID
|
||||
PayloadVersion int
|
||||
PayloadRemovalDisallowed bool
|
||||
}
|
||||
|
||||
func genUUIDv4() string {
|
||||
return uuid.NewV4().String()
|
||||
}
|
||||
|
||||
const (
|
||||
dnsProtoHTTPS = "HTTPS"
|
||||
dnsProtoTLS = "TLS"
|
||||
@@ -104,23 +100,23 @@ func encodeMobileConfig(d *dnsSettings, clientID string) ([]byte, error) {
|
||||
return nil, fmt.Errorf("bad dns protocol %q", proto)
|
||||
}
|
||||
|
||||
payloadID := fmt.Sprintf("%s.%s", dnsSettingsPayloadType, genUUIDv4())
|
||||
payloadID := fmt.Sprintf("%s.%s", dnsSettingsPayloadType, uuid.New())
|
||||
data := &mobileConfig{
|
||||
PayloadDescription: "Adds AdGuard Home to macOS Big Sur " +
|
||||
"and iOS 14 or newer systems",
|
||||
PayloadDescription: "Adds AdGuard Home to macOS Big Sur and iOS 14 or newer systems",
|
||||
PayloadDisplayName: dspName,
|
||||
PayloadIdentifier: genUUIDv4(),
|
||||
PayloadType: "Configuration",
|
||||
PayloadUUID: genUUIDv4(),
|
||||
PayloadContent: []*payloadContent{{
|
||||
DNSSettings: d,
|
||||
|
||||
PayloadType: dnsSettingsPayloadType,
|
||||
PayloadIdentifier: payloadID,
|
||||
PayloadUUID: genUUIDv4(),
|
||||
PayloadDisplayName: dspName,
|
||||
PayloadDescription: "Configures device to use AdGuard Home",
|
||||
PayloadUUID: uuid.New(),
|
||||
PayloadVersion: 1,
|
||||
DNSSettings: d,
|
||||
}},
|
||||
PayloadIdentifier: uuid.New(),
|
||||
PayloadUUID: uuid.New(),
|
||||
PayloadVersion: 1,
|
||||
PayloadRemovalDisallowed: false,
|
||||
}
|
||||
|
||||
@@ -3,15 +3,16 @@ package home
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/cache"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
@@ -80,8 +81,10 @@ func TestRDNS_Begin(t *testing.T) {
|
||||
binary.BigEndian.PutUint64(ttl, uint64(time.Now().Add(100*time.Hour).Unix()))
|
||||
|
||||
rdns := &RDNS{
|
||||
ipCache: ipCache,
|
||||
exchanger: &rDNSExchanger{},
|
||||
ipCache: ipCache,
|
||||
exchanger: &rDNSExchanger{
|
||||
ex: aghtest.NewErrorUpstream(),
|
||||
},
|
||||
clients: &clientsContainer{
|
||||
list: map[string]*Client{},
|
||||
idIndex: tc.cliIDIndex,
|
||||
@@ -108,16 +111,22 @@ func TestRDNS_Begin(t *testing.T) {
|
||||
|
||||
// rDNSExchanger is a mock dnsforward.RDNSExchanger implementation for tests.
|
||||
type rDNSExchanger struct {
|
||||
ex aghtest.Exchanger
|
||||
ex upstream.Upstream
|
||||
usePrivate bool
|
||||
}
|
||||
|
||||
// Exchange implements dnsforward.RDNSExchanger interface for *RDNSExchanger.
|
||||
func (e *rDNSExchanger) Exchange(ip net.IP) (host string, err error) {
|
||||
rev, err := netutil.IPToReversedAddr(ip)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("reversing ip: %w", err)
|
||||
}
|
||||
|
||||
req := &dns.Msg{
|
||||
Question: []dns.Question{{
|
||||
Name: ip.String(),
|
||||
Qtype: dns.TypePTR,
|
||||
Name: dns.Fqdn(rev),
|
||||
Qclass: dns.ClassINET,
|
||||
Qtype: dns.TypePTR,
|
||||
}},
|
||||
}
|
||||
|
||||
@@ -146,7 +155,9 @@ func TestRDNS_ensurePrivateCache(t *testing.T) {
|
||||
MaxCount: defaultRDNSCacheSize,
|
||||
})
|
||||
|
||||
ex := &rDNSExchanger{}
|
||||
ex := &rDNSExchanger{
|
||||
ex: aghtest.NewErrorUpstream(),
|
||||
}
|
||||
|
||||
rdns := &RDNS{
|
||||
ipCache: ipCache,
|
||||
@@ -167,15 +178,27 @@ func TestRDNS_WorkerLoop(t *testing.T) {
|
||||
w := &bytes.Buffer{}
|
||||
aghtest.ReplaceLogWriter(t, w)
|
||||
|
||||
locUpstream := &aghtest.Upstream{
|
||||
Reverse: map[string][]string{
|
||||
"192.168.1.1": {"local.domain"},
|
||||
"2a00:1450:400c:c06::93": {"ipv6.domain"},
|
||||
localIP := net.IP{192, 168, 1, 1}
|
||||
revIPv4, err := netutil.IPToReversedAddr(localIP)
|
||||
require.NoError(t, err)
|
||||
|
||||
revIPv6, err := netutil.IPToReversedAddr(net.ParseIP("2a00:1450:400c:c06::93"))
|
||||
require.NoError(t, err)
|
||||
|
||||
locUpstream := &aghtest.UpstreamMock{
|
||||
OnAddress: func() (addr string) { return "local.upstream.example" },
|
||||
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
resp = aghalg.Coalesce(
|
||||
aghtest.RespondTo(t, req, dns.ClassINET, dns.TypePTR, revIPv4, "local.domain"),
|
||||
aghtest.RespondTo(t, req, dns.ClassINET, dns.TypePTR, revIPv6, "ipv6.domain"),
|
||||
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
|
||||
)
|
||||
|
||||
return resp, nil
|
||||
},
|
||||
}
|
||||
errUpstream := &aghtest.TestErrUpstream{
|
||||
Err: errors.Error("1234"),
|
||||
}
|
||||
|
||||
errUpstream := aghtest.NewErrorUpstream()
|
||||
|
||||
testCases := []struct {
|
||||
ups upstream.Upstream
|
||||
@@ -186,10 +209,10 @@ func TestRDNS_WorkerLoop(t *testing.T) {
|
||||
ups: locUpstream,
|
||||
wantLog: "",
|
||||
name: "all_good",
|
||||
cliIP: net.IP{192, 168, 1, 1},
|
||||
cliIP: localIP,
|
||||
}, {
|
||||
ups: errUpstream,
|
||||
wantLog: `rdns: resolving "192.168.1.2": errupstream: 1234`,
|
||||
wantLog: `rdns: resolving "192.168.1.2": test upstream error`,
|
||||
name: "resolve_error",
|
||||
cliIP: net.IP{192, 168, 1, 2},
|
||||
}, {
|
||||
@@ -211,9 +234,7 @@ func TestRDNS_WorkerLoop(t *testing.T) {
|
||||
ch := make(chan net.IP)
|
||||
rdns := &RDNS{
|
||||
exchanger: &rDNSExchanger{
|
||||
ex: aghtest.Exchanger{
|
||||
Ups: tc.ups,
|
||||
},
|
||||
ex: tc.ups,
|
||||
},
|
||||
clients: cc,
|
||||
ipCh: ch,
|
||||
|
||||
@@ -433,8 +433,11 @@ EnvironmentFile=-/etc/sysconfig/{{.Name}}
|
||||
WantedBy=multi-user.target
|
||||
`
|
||||
|
||||
// Note: we should keep it in sync with the template from service_sysv_linux.go file
|
||||
// Use "ps | grep -v grep | grep $(get_pid)" because "ps PID" may not work on OpenWrt
|
||||
// sysvScript is the source of the daemon script for SysV-based Linux systems.
|
||||
// Keep as close as possible to the https://github.com/kardianos/service/blob/29f8c79c511bc18422bb99992779f96e6bc33921/service_sysv_linux.go#L187.
|
||||
//
|
||||
// Use ps command instead of reading the procfs since it's a more
|
||||
// implementation-independent approach.
|
||||
const sysvScript = `#!/bin/sh
|
||||
# For RedHat and cousins:
|
||||
# chkconfig: - 99 01
|
||||
@@ -465,7 +468,7 @@ get_pid() {
|
||||
}
|
||||
|
||||
is_running() {
|
||||
[ -f "$pid_file" ] && ps | grep -v grep | grep $(get_pid) > /dev/null 2>&1
|
||||
[ -f "$pid_file" ] && ps -p "$(get_pid)" > /dev/null 2>&1
|
||||
}
|
||||
|
||||
case "$1" in
|
||||
@@ -609,7 +612,7 @@ command_args="-P ${pidfile} -p ${pidfile_child} -T ${name} -r {{.WorkingDirector
|
||||
run_rc_command "$1"
|
||||
`
|
||||
|
||||
const openBSDScript = `#!/bin/sh
|
||||
const openBSDScript = `#!/bin/ksh
|
||||
#
|
||||
# $OpenBSD: {{ .SvcInfo }}
|
||||
|
||||
|
||||
83
internal/home/service_linux.go
Normal file
83
internal/home/service_linux.go
Normal file
@@ -0,0 +1,83 @@
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package home
|
||||
|
||||
import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/kardianos/service"
|
||||
)
|
||||
|
||||
func chooseSystem() {
|
||||
sys := service.ChosenSystem()
|
||||
// By default, package service uses the SysV system if it cannot detect
|
||||
// anything other, but the update-rc.d fix should not be applied on OpenWrt,
|
||||
// so exclude it explicitly.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/4480 and
|
||||
// https://github.com/AdguardTeam/AdGuardHome/issues/4677.
|
||||
if sys.String() == "unix-systemv" && !aghos.IsOpenWrt() {
|
||||
service.ChooseSystem(sysvSystem{System: sys})
|
||||
}
|
||||
}
|
||||
|
||||
// sysvSystem is a wrapper for service.System that wraps the service.Service
|
||||
// while creating a new one.
|
||||
//
|
||||
// TODO(e.burkov): File a PR to github.com/kardianos/service.
|
||||
type sysvSystem struct {
|
||||
// System is expected to have an unexported type
|
||||
// *service.linuxSystemService.
|
||||
service.System
|
||||
}
|
||||
|
||||
// New returns a wrapped service.Service.
|
||||
func (sys sysvSystem) New(i service.Interface, c *service.Config) (s service.Service, err error) {
|
||||
s, err = sys.System.New(i, c)
|
||||
if err != nil {
|
||||
return s, err
|
||||
}
|
||||
|
||||
return sysvService{
|
||||
Service: s,
|
||||
name: c.Name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// sysvService is a wrapper for a service.Service that also calls update-rc.d in
|
||||
// a proper way on installing and uninstalling.
|
||||
type sysvService struct {
|
||||
// Service is expected to have an unexported type *service.sysv.
|
||||
service.Service
|
||||
// name stores the name of the service to call updating script with it.
|
||||
name string
|
||||
}
|
||||
|
||||
// Install wraps service.Service.Install call with calling the updating script.
|
||||
func (svc sysvService) Install() (err error) {
|
||||
err = svc.Service.Install()
|
||||
if err != nil {
|
||||
// Don't wrap an error since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
_, _, err = aghos.RunCommand("update-rc.d", svc.name, "defaults")
|
||||
|
||||
// Don't wrap an error since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
// Uninstall wraps service.Service.Uninstall call with calling the updating
|
||||
// script.
|
||||
func (svc sysvService) Uninstall() (err error) {
|
||||
err = svc.Service.Uninstall()
|
||||
if err != nil {
|
||||
// Don't wrap an error since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
_, _, err = aghos.RunCommand("update-rc.d", svc.name, "remove")
|
||||
|
||||
// Don't wrap an error since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
@@ -160,7 +160,7 @@ rc_cmd $1
|
||||
|
||||
// template returns the script template to put into rc.d.
|
||||
func (s *openbsdRunComService) template() (t *template.Template) {
|
||||
tf := map[string]interface{}{
|
||||
tf := map[string]any{
|
||||
"args": func(sl []string) string {
|
||||
return `"` + strings.Join(sl, " ") + `"`
|
||||
},
|
||||
@@ -390,42 +390,42 @@ func newSysLogger(_ string, _ chan<- error) (service.Logger, error) {
|
||||
type sysLogger struct{}
|
||||
|
||||
// Error implements service.Logger interface for sysLogger.
|
||||
func (sysLogger) Error(v ...interface{}) error {
|
||||
func (sysLogger) Error(v ...any) error {
|
||||
log.Error(fmt.Sprint(v...))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Warning implements service.Logger interface for sysLogger.
|
||||
func (sysLogger) Warning(v ...interface{}) error {
|
||||
func (sysLogger) Warning(v ...any) error {
|
||||
log.Info("warning: %s", fmt.Sprint(v...))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Info implements service.Logger interface for sysLogger.
|
||||
func (sysLogger) Info(v ...interface{}) error {
|
||||
func (sysLogger) Info(v ...any) error {
|
||||
log.Info(fmt.Sprint(v...))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Errorf implements service.Logger interface for sysLogger.
|
||||
func (sysLogger) Errorf(format string, a ...interface{}) error {
|
||||
func (sysLogger) Errorf(format string, a ...any) error {
|
||||
log.Error(format, a...)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Warningf implements service.Logger interface for sysLogger.
|
||||
func (sysLogger) Warningf(format string, a ...interface{}) error {
|
||||
func (sysLogger) Warningf(format string, a ...any) error {
|
||||
log.Info("warning: %s", fmt.Sprintf(format, a...))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Infof implements service.Logger interface for sysLogger.
|
||||
func (sysLogger) Infof(format string, a ...interface{}) error {
|
||||
func (sysLogger) Infof(format string, a ...any) error {
|
||||
log.Info(format, a...)
|
||||
|
||||
return nil
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
//go:build !openbsd
|
||||
// +build !openbsd
|
||||
//go:build !(openbsd || linux)
|
||||
// +build !openbsd,!linux
|
||||
|
||||
package home
|
||||
|
||||
// chooseSystem checks the current system detected and substitutes it with local
|
||||
// implementation if needed.
|
||||
func chooseSystem() {}
|
||||
|
||||
@@ -250,21 +250,17 @@ func (t *TLSMod) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
if setts.Enabled {
|
||||
uc := aghalg.UniqChecker{}
|
||||
addPorts(
|
||||
uc,
|
||||
err = validatePorts(
|
||||
tcpPort(config.BindPort),
|
||||
tcpPort(config.BetaBindPort),
|
||||
udpPort(config.DNS.Port),
|
||||
tcpPort(setts.PortHTTPS),
|
||||
tcpPort(setts.PortDNSOverTLS),
|
||||
udpPort(setts.PortDNSOverQUIC),
|
||||
tcpPort(setts.PortDNSCrypt),
|
||||
udpPort(config.DNS.Port),
|
||||
udpPort(setts.PortDNSOverQUIC),
|
||||
)
|
||||
|
||||
err = uc.Validate(aghalg.IntIsBefore)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "validating ports: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -343,19 +339,15 @@ func (t *TLSMod) handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
if data.Enabled {
|
||||
uc := aghalg.UniqChecker{}
|
||||
addPorts(
|
||||
uc,
|
||||
err = validatePorts(
|
||||
tcpPort(config.BindPort),
|
||||
tcpPort(config.BetaBindPort),
|
||||
udpPort(config.DNS.Port),
|
||||
tcpPort(data.PortHTTPS),
|
||||
tcpPort(data.PortDNSOverTLS),
|
||||
udpPort(data.PortDNSOverQUIC),
|
||||
tcpPort(data.PortDNSCrypt),
|
||||
udpPort(config.DNS.Port),
|
||||
udpPort(data.PortDNSOverQUIC),
|
||||
)
|
||||
|
||||
err = uc.Validate(aghalg.IntIsBefore)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
@@ -421,6 +413,38 @@ func (t *TLSMod) handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
// validatePorts validates the uniqueness of TCP and UDP ports for AdGuard Home
|
||||
// DNS protocols.
|
||||
func validatePorts(
|
||||
bindPort, betaBindPort, dohPort, dotPort, dnscryptTCPPort tcpPort,
|
||||
dnsPort, doqPort udpPort,
|
||||
) (err error) {
|
||||
tcpPorts := aghalg.UniqChecker[tcpPort]{}
|
||||
addPorts(
|
||||
tcpPorts,
|
||||
tcpPort(bindPort),
|
||||
tcpPort(betaBindPort),
|
||||
tcpPort(dohPort),
|
||||
tcpPort(dotPort),
|
||||
tcpPort(dnscryptTCPPort),
|
||||
)
|
||||
|
||||
err = tcpPorts.Validate()
|
||||
if err != nil {
|
||||
return fmt.Errorf("validating tcp ports: %w", err)
|
||||
}
|
||||
|
||||
udpPorts := aghalg.UniqChecker[udpPort]{}
|
||||
addPorts(udpPorts, udpPort(dnsPort), udpPort(doqPort))
|
||||
|
||||
err = udpPorts.Validate()
|
||||
if err != nil {
|
||||
return fmt.Errorf("validating udp ports: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func verifyCertChain(data *tlsConfigStatus, certChain, serverName string) error {
|
||||
log.Tracef("TLS: got certificate: %d bytes", len(certChain))
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
@@ -17,19 +18,16 @@ import (
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"github.com/google/renameio/maybe"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
yaml "gopkg.in/yaml.v2"
|
||||
yaml "gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// currentSchemaVersion is the current schema version.
|
||||
const currentSchemaVersion = 14
|
||||
|
||||
// These aliases are provided for convenience.
|
||||
//
|
||||
// TODO(e.burkov): Remove any after updating to Go 1.18.
|
||||
type (
|
||||
any = interface{}
|
||||
yarr = []any
|
||||
yobj = map[any]any
|
||||
yobj = map[string]any
|
||||
)
|
||||
|
||||
// Performs necessary upgrade operations if needed
|
||||
@@ -107,16 +105,20 @@ func upgradeConfigSchema(oldVersion int, diskConf yobj) (err error) {
|
||||
return fmt.Errorf("unknown configuration schema version %d", oldVersion)
|
||||
}
|
||||
|
||||
body, err := yaml.Marshal(diskConf)
|
||||
buf := &bytes.Buffer{}
|
||||
enc := yaml.NewEncoder(buf)
|
||||
enc.SetIndent(2)
|
||||
|
||||
err = enc.Encode(diskConf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("generating new config: %w", err)
|
||||
}
|
||||
|
||||
config.fileData = body
|
||||
config.fileData = buf.Bytes()
|
||||
confFile := config.getConfigFilename()
|
||||
err = maybe.WriteFile(confFile, body, 0o644)
|
||||
err = maybe.WriteFile(confFile, config.fileData, 0o644)
|
||||
if err != nil {
|
||||
return fmt.Errorf("saving new config: %w", err)
|
||||
return fmt.Errorf("writing new config: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -176,16 +178,16 @@ func upgradeSchema2to3(diskConf yobj) error {
|
||||
return fmt.Errorf("no DNS configuration in config file")
|
||||
}
|
||||
|
||||
// Convert interface{} to yobj
|
||||
// Convert any to yobj
|
||||
newDNSConfig := make(yobj)
|
||||
|
||||
switch v := dnsConfig.(type) {
|
||||
case map[interface{}]interface{}:
|
||||
case yobj:
|
||||
for k, v := range v {
|
||||
newDNSConfig[fmt.Sprint(k)] = v
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("dns configuration is not a map")
|
||||
return fmt.Errorf("unexpected type of dns: %T", dnsConfig)
|
||||
}
|
||||
|
||||
// Replace bootstrap_dns value filed with new array contains old bootstrap_dns inside
|
||||
@@ -216,12 +218,12 @@ func upgradeSchema3to4(diskConf yobj) error {
|
||||
}
|
||||
|
||||
switch arr := clients.(type) {
|
||||
case []interface{}:
|
||||
case []any:
|
||||
|
||||
for i := range arr {
|
||||
switch c := arr[i].(type) {
|
||||
|
||||
case map[interface{}]interface{}:
|
||||
case map[any]any:
|
||||
c["use_global_blocked_services"] = true
|
||||
|
||||
default:
|
||||
@@ -307,11 +309,11 @@ func upgradeSchema5to6(diskConf yobj) error {
|
||||
}
|
||||
|
||||
switch arr := clients.(type) {
|
||||
case []interface{}:
|
||||
case []any:
|
||||
for i := range arr {
|
||||
switch c := arr[i].(type) {
|
||||
case map[interface{}]interface{}:
|
||||
var ipVal interface{}
|
||||
case map[any]any:
|
||||
var ipVal any
|
||||
ipVal, ok = c["ip"]
|
||||
ids := []string{}
|
||||
if ok {
|
||||
@@ -326,7 +328,7 @@ func upgradeSchema5to6(diskConf yobj) error {
|
||||
}
|
||||
}
|
||||
|
||||
var macVal interface{}
|
||||
var macVal any
|
||||
macVal, ok = c["mac"]
|
||||
if ok {
|
||||
var mac string
|
||||
@@ -377,7 +379,7 @@ func upgradeSchema6to7(diskConf yobj) error {
|
||||
}
|
||||
|
||||
switch dhcp := dhcpVal.(type) {
|
||||
case map[interface{}]interface{}:
|
||||
case map[any]any:
|
||||
var str string
|
||||
str, ok = dhcp["gateway_ip"].(string)
|
||||
if !ok {
|
||||
|
||||
@@ -190,7 +190,7 @@ func testDiskConf(schemaVersion int) (diskConf yobj) {
|
||||
return diskConf
|
||||
}
|
||||
|
||||
// testDNSConf creates a DNS config for test the way gopkg.in/yaml.v2 would
|
||||
// testDNSConf creates a DNS config for test the way gopkg.in/yaml.v3 would
|
||||
// unmarshal it. In YAML, keys aren't guaranteed to always only be strings.
|
||||
func testDNSConf(schemaVersion int) (dnsConf yobj) {
|
||||
dnsConf = yobj{
|
||||
@@ -500,7 +500,7 @@ func TestUpgradeSchema11to12(t *testing.T) {
|
||||
dnsVal, ok = dns.(yobj)
|
||||
require.True(t, ok)
|
||||
|
||||
var ivl interface{}
|
||||
var ivl any
|
||||
ivl, ok = dnsVal["querylog_interval"]
|
||||
require.True(t, ok)
|
||||
|
||||
|
||||
@@ -19,10 +19,10 @@ import (
|
||||
)
|
||||
|
||||
type qlogConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
// Use float64 here to support fractional numbers and not mess the API
|
||||
// users by changing the units.
|
||||
Interval float64 `json:"interval"`
|
||||
Enabled bool `json:"enabled"`
|
||||
AnonymizeClientIP bool `json:"anonymize_client_ip"`
|
||||
}
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
// TODO(a.garipov): Use a proper structured approach here.
|
||||
|
||||
// jobject is a JSON object alias.
|
||||
type jobject = map[string]interface{}
|
||||
type jobject = map[string]any
|
||||
|
||||
// entriesToJSON converts query log entries to JSON.
|
||||
func (l *queryLog) entriesToJSON(entries []*logEntry, oldest time.Time) (res jobject) {
|
||||
|
||||
@@ -149,7 +149,7 @@ func (l *queryLog) clear() {
|
||||
log.Error("removing log file %q: %s", l.logFile, err)
|
||||
}
|
||||
|
||||
log.Debug("Query log: cleared")
|
||||
log.Debug("querylog: cleared")
|
||||
}
|
||||
|
||||
func (l *queryLog) Add(params *AddParams) {
|
||||
|
||||
@@ -285,8 +285,8 @@ func addEntry(l *queryLog, host string, answerStr, client net.IP) {
|
||||
Answer: &a,
|
||||
OrigAnswer: &a,
|
||||
Result: &res,
|
||||
ClientIP: client,
|
||||
Upstream: "upstream",
|
||||
ClientIP: client,
|
||||
}
|
||||
|
||||
l.Add(params)
|
||||
|
||||
@@ -303,7 +303,7 @@ func NewTestQLogFileData(t *testing.T, data string) (file *QLogFile) {
|
||||
func TestQLog_Seek(t *testing.T) {
|
||||
const nl = "\n"
|
||||
const strV = "%s"
|
||||
const recs = `{"T":"` + strV + `","QH":"wfqvjymurpwegyv","QT":"A","QC":"IN","CP":"","Answer":"","Result":{},"Elapsed":66286385,"Upstream":"tls://dns-unfiltered.adguard.com:853"}` + nl +
|
||||
const recs = `{"T":"` + strV + `","QH":"wfqvjymurpwegyv","QT":"A","QC":"IN","CP":"","Answer":"","Result":{},"Elapsed":66286385,"Upstream":"tls://unfiltered.adguard-dns.com:853"}` + nl +
|
||||
`{"T":"` + strV + `"}` + nl +
|
||||
`{"T":"` + strV + `"}` + nl
|
||||
timestamp, _ := time.Parse(time.RFC3339Nano, "2020-08-31T18:44:25.376690873+03:00")
|
||||
|
||||
@@ -2,10 +2,10 @@ package querylog
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
@@ -28,14 +28,17 @@ type QueryLog interface {
|
||||
WriteDiskConfig(c *Config)
|
||||
}
|
||||
|
||||
// Config - configuration object
|
||||
// Config is the query log configuration structure.
|
||||
type Config struct {
|
||||
// Anonymizer processes the IP addresses to anonymize those if needed.
|
||||
Anonymizer *aghnet.IPMut
|
||||
|
||||
// ConfigModified is called when the configuration is changed, for
|
||||
// example by HTTP requests.
|
||||
ConfigModified func()
|
||||
|
||||
// HTTPRegister registers an HTTP handler.
|
||||
HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request))
|
||||
HTTPRegister aghhttp.RegisterFunc
|
||||
|
||||
// FindClient returns client information by their IDs.
|
||||
FindClient func(ids []string) (c *Client, err error)
|
||||
@@ -68,9 +71,6 @@ type Config struct {
|
||||
// AnonymizeClientIP tells if the query log should anonymize clients' IP
|
||||
// addresses.
|
||||
AnonymizeClientIP bool
|
||||
|
||||
// Anonymizer processes the IP addresses to anonymize those if needed.
|
||||
Anonymizer *aghnet.IPMut
|
||||
}
|
||||
|
||||
// AddParams is the parameters for adding an entry.
|
||||
@@ -91,18 +91,18 @@ type AddParams struct {
|
||||
// Result is the filtering result (optional).
|
||||
Result *filtering.Result
|
||||
|
||||
// Elapsed is the time spent for processing the request.
|
||||
Elapsed time.Duration
|
||||
|
||||
ClientID string
|
||||
|
||||
ClientIP net.IP
|
||||
|
||||
// Upstream is the URL of the upstream DNS server.
|
||||
Upstream string
|
||||
|
||||
ClientProto ClientProto
|
||||
|
||||
ClientIP net.IP
|
||||
|
||||
// Elapsed is the time spent for processing the request.
|
||||
Elapsed time.Duration
|
||||
|
||||
// Cached indicates if the response is served from cache.
|
||||
Cached bool
|
||||
|
||||
|
||||
@@ -73,7 +73,7 @@ func (l *queryLog) searchMemory(params *searchParams, cache clientCache) (entrie
|
||||
|
||||
// search - searches log entries in the query log using specified parameters
|
||||
// returns the list of entries found + time of the oldest entry
|
||||
func (l *queryLog) search(params *searchParams) ([]*logEntry, time.Time) {
|
||||
func (l *queryLog) search(params *searchParams) (entries []*logEntry, oldest time.Time) {
|
||||
now := time.Now()
|
||||
|
||||
if params.limit == 0 {
|
||||
@@ -88,7 +88,7 @@ func (l *queryLog) search(params *searchParams) ([]*logEntry, time.Time) {
|
||||
totalLimit := params.offset + params.limit
|
||||
|
||||
// now let's get a unified collection
|
||||
entries := append(memoryEntries, fileEntries...)
|
||||
entries = append(memoryEntries, fileEntries...)
|
||||
if len(entries) > totalLimit {
|
||||
// remove extra records
|
||||
entries = entries[:totalLimit]
|
||||
@@ -111,13 +111,18 @@ func (l *queryLog) search(params *searchParams) ([]*logEntry, time.Time) {
|
||||
}
|
||||
}
|
||||
|
||||
if len(entries) > 0 && len(entries) <= totalLimit {
|
||||
if len(entries) > 0 {
|
||||
// Update oldest after merging in the memory buffer.
|
||||
oldest = entries[len(entries)-1].Time
|
||||
}
|
||||
|
||||
log.Debug("QueryLog: prepared data (%d/%d) older than %s in %s",
|
||||
len(entries), total, params.olderThan, time.Since(now))
|
||||
log.Debug(
|
||||
"querylog: prepared data (%d/%d) older than %s in %s",
|
||||
len(entries),
|
||||
total,
|
||||
params.olderThan,
|
||||
time.Since(now),
|
||||
)
|
||||
|
||||
return entries, oldest
|
||||
}
|
||||
@@ -180,6 +185,8 @@ func (l *queryLog) searchFiles(
|
||||
e, ts, err = l.readNextEntry(r, params, cache)
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
oldestNano = 0
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ package stats
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
@@ -15,18 +16,10 @@ import (
|
||||
// The key is either a client's address or a requested address.
|
||||
type topAddrs = map[string]uint64
|
||||
|
||||
// statsResponse is a response for getting statistics.
|
||||
type statsResponse struct {
|
||||
// StatsResp is a response to the GET /control/stats.
|
||||
type StatsResp struct {
|
||||
TimeUnits string `json:"time_units"`
|
||||
|
||||
NumDNSQueries uint64 `json:"num_dns_queries"`
|
||||
NumBlockedFiltering uint64 `json:"num_blocked_filtering"`
|
||||
NumReplacedSafebrowsing uint64 `json:"num_replaced_safebrowsing"`
|
||||
NumReplacedSafesearch uint64 `json:"num_replaced_safesearch"`
|
||||
NumReplacedParental uint64 `json:"num_replaced_parental"`
|
||||
|
||||
AvgProcessingTime float64 `json:"avg_processing_time"`
|
||||
|
||||
TopQueried []topAddrs `json:"top_queried_domains"`
|
||||
TopClients []topAddrs `json:"top_clients"`
|
||||
TopBlocked []topAddrs `json:"top_blocked_domains"`
|
||||
@@ -36,37 +29,30 @@ type statsResponse struct {
|
||||
BlockedFiltering []uint64 `json:"blocked_filtering"`
|
||||
ReplacedSafebrowsing []uint64 `json:"replaced_safebrowsing"`
|
||||
ReplacedParental []uint64 `json:"replaced_parental"`
|
||||
|
||||
NumDNSQueries uint64 `json:"num_dns_queries"`
|
||||
NumBlockedFiltering uint64 `json:"num_blocked_filtering"`
|
||||
NumReplacedSafebrowsing uint64 `json:"num_replaced_safebrowsing"`
|
||||
NumReplacedSafesearch uint64 `json:"num_replaced_safesearch"`
|
||||
NumReplacedParental uint64 `json:"num_replaced_parental"`
|
||||
|
||||
AvgProcessingTime float64 `json:"avg_processing_time"`
|
||||
}
|
||||
|
||||
// handleStats is a handler for getting statistics.
|
||||
func (s *statsCtx) handleStats(w http.ResponseWriter, r *http.Request) {
|
||||
// handleStats handles requests to the GET /control/stats endpoint.
|
||||
func (s *StatsCtx) handleStats(w http.ResponseWriter, r *http.Request) {
|
||||
limit := atomic.LoadUint32(&s.limitHours)
|
||||
|
||||
start := time.Now()
|
||||
resp, ok := s.getData(limit)
|
||||
log.Debug("stats: prepared data in %v", time.Since(start))
|
||||
|
||||
var resp statsResponse
|
||||
if s.conf.limit == 0 {
|
||||
resp = statsResponse{
|
||||
TimeUnits: "days",
|
||||
if !ok {
|
||||
// Don't bring the message to the lower case since it's a part of UI
|
||||
// text for the moment.
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't get statistics data")
|
||||
|
||||
TopBlocked: []topAddrs{},
|
||||
TopClients: []topAddrs{},
|
||||
TopQueried: []topAddrs{},
|
||||
|
||||
BlockedFiltering: []uint64{},
|
||||
DNSQueries: []uint64{},
|
||||
ReplacedParental: []uint64{},
|
||||
ReplacedSafebrowsing: []uint64{},
|
||||
}
|
||||
} else {
|
||||
var ok bool
|
||||
resp, ok = s.getData()
|
||||
|
||||
log.Debug("stats: prepared data in %v", time.Since(start))
|
||||
|
||||
if !ok {
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't get statistics data")
|
||||
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
@@ -74,36 +60,30 @@ func (s *statsCtx) handleStats(w http.ResponseWriter, r *http.Request) {
|
||||
err := json.NewEncoder(w).Encode(resp)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "json encode: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
type config struct {
|
||||
// configResp is the response to the GET /control/stats_info.
|
||||
type configResp struct {
|
||||
IntervalDays uint32 `json:"interval"`
|
||||
}
|
||||
|
||||
// Get configuration
|
||||
func (s *statsCtx) handleStatsInfo(w http.ResponseWriter, r *http.Request) {
|
||||
resp := config{}
|
||||
resp.IntervalDays = s.conf.limit / 24
|
||||
// handleStatsInfo handles requests to the GET /control/stats_info endpoint.
|
||||
func (s *StatsCtx) handleStatsInfo(w http.ResponseWriter, r *http.Request) {
|
||||
resp := configResp{IntervalDays: atomic.LoadUint32(&s.limitHours) / 24}
|
||||
|
||||
data, err := json.Marshal(resp)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
err := json.NewEncoder(w).Encode(resp)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "json encode: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, err = w.Write(data)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "http write: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Set configuration
|
||||
func (s *statsCtx) handleStatsConfig(w http.ResponseWriter, r *http.Request) {
|
||||
reqData := config{}
|
||||
// handleStatsConfig handles requests to the POST /control/stats_config
|
||||
// endpoint.
|
||||
func (s *StatsCtx) handleStatsConfig(w http.ResponseWriter, r *http.Request) {
|
||||
reqData := configResp{}
|
||||
err := json.NewDecoder(r.Body).Decode(&reqData)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "json decode: %s", err)
|
||||
@@ -118,22 +98,25 @@ func (s *statsCtx) handleStatsConfig(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
s.setLimit(int(reqData.IntervalDays))
|
||||
s.conf.ConfigModified()
|
||||
s.configModified()
|
||||
}
|
||||
|
||||
// Reset data
|
||||
func (s *statsCtx) handleStatsReset(w http.ResponseWriter, r *http.Request) {
|
||||
s.clear()
|
||||
// handleStatsReset handles requests to the POST /control/stats_reset endpoint.
|
||||
func (s *StatsCtx) handleStatsReset(w http.ResponseWriter, r *http.Request) {
|
||||
err := s.clear()
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "stats: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Register web handlers
|
||||
func (s *statsCtx) initWeb() {
|
||||
if s.conf.HTTPRegister == nil {
|
||||
// initWeb registers the handlers for web endpoints of statistics module.
|
||||
func (s *StatsCtx) initWeb() {
|
||||
if s.httpRegister == nil {
|
||||
return
|
||||
}
|
||||
|
||||
s.conf.HTTPRegister(http.MethodGet, "/control/stats", s.handleStats)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/stats_reset", s.handleStatsReset)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/stats_config", s.handleStatsConfig)
|
||||
s.conf.HTTPRegister(http.MethodGet, "/control/stats_info", s.handleStatsInfo)
|
||||
s.httpRegister(http.MethodGet, "/control/stats", s.handleStats)
|
||||
s.httpRegister(http.MethodPost, "/control/stats_reset", s.handleStatsReset)
|
||||
s.httpRegister(http.MethodPost, "/control/stats_config", s.handleStatsConfig)
|
||||
s.httpRegister(http.MethodGet, "/control/stats_info", s.handleStatsInfo)
|
||||
}
|
||||
|
||||
@@ -3,86 +3,545 @@
|
||||
package stats
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"go.etcd.io/bbolt"
|
||||
)
|
||||
|
||||
type unitIDCallback func() uint32
|
||||
|
||||
// DiskConfig - configuration settings that are stored on disk
|
||||
// DiskConfig is the configuration structure that is stored in file.
|
||||
type DiskConfig struct {
|
||||
Interval uint32 `yaml:"statistics_interval"` // time interval for statistics (in days)
|
||||
// Interval is the number of days for which the statistics are collected
|
||||
// before flushing to the database.
|
||||
Interval uint32 `yaml:"statistics_interval"`
|
||||
}
|
||||
|
||||
// Config - module configuration
|
||||
type Config struct {
|
||||
Filename string // database file name
|
||||
LimitDays uint32 // time limit (in days)
|
||||
UnitID unitIDCallback // user function to get the current unit ID. If nil, the current time hour is used.
|
||||
// checkInterval returns true if days is valid to be used as statistics
|
||||
// retention interval. The valid values are 0, 1, 7, 30 and 90.
|
||||
func checkInterval(days uint32) (ok bool) {
|
||||
return days == 0 || days == 1 || days == 7 || days == 30 || days == 90
|
||||
}
|
||||
|
||||
// Called when the configuration is changed by HTTP request
|
||||
// Config is the configuration structure for the statistics collecting.
|
||||
type Config struct {
|
||||
// UnitID is the function to generate the identifier for current unit. If
|
||||
// nil, the default function is used, see newUnitID.
|
||||
UnitID UnitIDGenFunc
|
||||
|
||||
// ConfigModified will be called each time the configuration changed via web
|
||||
// interface.
|
||||
ConfigModified func()
|
||||
|
||||
// Register an HTTP handler
|
||||
HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request))
|
||||
// HTTPRegister is the function that registers handlers for the stats
|
||||
// endpoints.
|
||||
HTTPRegister aghhttp.RegisterFunc
|
||||
|
||||
limit uint32 // maximum time we need to keep data for (in hours)
|
||||
// Filename is the name of the database file.
|
||||
Filename string
|
||||
|
||||
// LimitDays is the maximum number of days to collect statistics into the
|
||||
// current unit.
|
||||
LimitDays uint32
|
||||
}
|
||||
|
||||
// New - create object
|
||||
func New(conf Config) (Stats, error) {
|
||||
return createObject(conf)
|
||||
}
|
||||
|
||||
// Stats - main interface
|
||||
type Stats interface {
|
||||
// Interface is the statistics interface to be used by other packages.
|
||||
type Interface interface {
|
||||
// Start begins the statistics collecting.
|
||||
Start()
|
||||
|
||||
// Close object.
|
||||
// This function is not thread safe
|
||||
// (can't be called in parallel with any other function of this interface).
|
||||
Close()
|
||||
io.Closer
|
||||
|
||||
// Update counters
|
||||
// Update collects the incoming statistics data.
|
||||
Update(e Entry)
|
||||
|
||||
// Get IP addresses of the clients with the most number of requests
|
||||
GetTopClientsIP(limit uint) []net.IP
|
||||
// GetTopClientIP returns at most limit IP addresses corresponding to the
|
||||
// clients with the most number of requests.
|
||||
TopClientsIP(limit uint) []net.IP
|
||||
|
||||
// WriteDiskConfig - write configuration
|
||||
// WriteDiskConfig puts the Interface's configuration to the dc.
|
||||
WriteDiskConfig(dc *DiskConfig)
|
||||
}
|
||||
|
||||
// TimeUnit - time unit
|
||||
type TimeUnit int
|
||||
|
||||
// Supported time units
|
||||
const (
|
||||
Hours TimeUnit = iota
|
||||
Days
|
||||
)
|
||||
|
||||
// Result of DNS request processing
|
||||
type Result int
|
||||
|
||||
// Supported result values
|
||||
const (
|
||||
RNotFiltered Result = iota + 1
|
||||
RFiltered
|
||||
RSafeBrowsing
|
||||
RSafeSearch
|
||||
RParental
|
||||
rLast
|
||||
)
|
||||
|
||||
// Entry is a statistics data entry.
|
||||
type Entry struct {
|
||||
// Clients is the client's primary ID.
|
||||
// StatsCtx collects the statistics and flushes it to the database. Its default
|
||||
// flushing interval is one hour.
|
||||
//
|
||||
// TODO(e.burkov): Use atomic.Pointer for accessing db in go1.19.
|
||||
type StatsCtx struct {
|
||||
// limitHours is the maximum number of hours to collect statistics into the
|
||||
// current unit.
|
||||
//
|
||||
// TODO(a.garipov): Make this a {net.IP, string} enum?
|
||||
Client string
|
||||
// It is of type uint32 to be accessed by atomic. It's arranged at the
|
||||
// beginning of the structure to keep 64-bit alignment.
|
||||
limitHours uint32
|
||||
|
||||
Domain string
|
||||
Result Result
|
||||
Time uint32 // processing time (msec)
|
||||
// currMu protects curr.
|
||||
currMu *sync.RWMutex
|
||||
// curr is the actual statistics collection result.
|
||||
curr *unit
|
||||
|
||||
// dbMu protects db.
|
||||
dbMu *sync.Mutex
|
||||
// db is the opened statistics database, if any.
|
||||
db *bbolt.DB
|
||||
|
||||
// unitIDGen is the function that generates an identifier for the current
|
||||
// unit. It's here for only testing purposes.
|
||||
unitIDGen UnitIDGenFunc
|
||||
|
||||
// httpRegister is used to set HTTP handlers.
|
||||
httpRegister aghhttp.RegisterFunc
|
||||
|
||||
// configModified is called whenever the configuration is modified via web
|
||||
// interface.
|
||||
configModified func()
|
||||
|
||||
// filename is the name of database file.
|
||||
filename string
|
||||
}
|
||||
|
||||
var _ Interface = &StatsCtx{}
|
||||
|
||||
// New creates s from conf and properly initializes it. Don't use s before
|
||||
// calling it's Start method.
|
||||
func New(conf Config) (s *StatsCtx, err error) {
|
||||
defer withRecovered(&err)
|
||||
|
||||
s = &StatsCtx{
|
||||
currMu: &sync.RWMutex{},
|
||||
dbMu: &sync.Mutex{},
|
||||
filename: conf.Filename,
|
||||
configModified: conf.ConfigModified,
|
||||
httpRegister: conf.HTTPRegister,
|
||||
}
|
||||
if s.limitHours = conf.LimitDays * 24; !checkInterval(conf.LimitDays) {
|
||||
s.limitHours = 24
|
||||
}
|
||||
if s.unitIDGen = newUnitID; conf.UnitID != nil {
|
||||
s.unitIDGen = conf.UnitID
|
||||
}
|
||||
|
||||
// TODO(e.burkov): Move the code below to the Start method.
|
||||
|
||||
err = s.openDB()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("opening database: %w", err)
|
||||
}
|
||||
|
||||
var udb *unitDB
|
||||
id := s.unitIDGen()
|
||||
|
||||
tx, err := s.db.Begin(true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("stats: opening a transaction: %w", err)
|
||||
}
|
||||
|
||||
deleted := deleteOldUnits(tx, id-s.limitHours-1)
|
||||
udb = loadUnitFromDB(tx, id)
|
||||
|
||||
err = finishTxn(tx, deleted > 0)
|
||||
if err != nil {
|
||||
log.Error("stats: %s", err)
|
||||
}
|
||||
|
||||
s.curr = newUnit(id)
|
||||
s.curr.deserialize(udb)
|
||||
|
||||
log.Debug("stats: initialized")
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// withRecovered turns the value recovered from panic if any into an error and
|
||||
// combines it with the one pointed by orig. orig must be non-nil.
|
||||
func withRecovered(orig *error) {
|
||||
p := recover()
|
||||
if p == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var err error
|
||||
switch p := p.(type) {
|
||||
case error:
|
||||
err = fmt.Errorf("panic: %w", p)
|
||||
default:
|
||||
err = fmt.Errorf("panic: recovered value of type %[1]T: %[1]v", p)
|
||||
}
|
||||
|
||||
*orig = errors.WithDeferred(*orig, err)
|
||||
}
|
||||
|
||||
// Start implements the Interface interface for *StatsCtx.
|
||||
func (s *StatsCtx) Start() {
|
||||
s.initWeb()
|
||||
|
||||
go s.periodicFlush()
|
||||
}
|
||||
|
||||
// Close implements the io.Closer interface for *StatsCtx.
|
||||
func (s *StatsCtx) Close() (err error) {
|
||||
defer func() { err = errors.Annotate(err, "stats: closing: %w") }()
|
||||
|
||||
db := s.swapDatabase(nil)
|
||||
if db == nil {
|
||||
return nil
|
||||
}
|
||||
defer func() {
|
||||
cerr := db.Close()
|
||||
if cerr == nil {
|
||||
log.Debug("stats: database closed")
|
||||
}
|
||||
|
||||
err = errors.WithDeferred(err, cerr)
|
||||
}()
|
||||
|
||||
tx, err := db.Begin(true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("opening transaction: %w", err)
|
||||
}
|
||||
defer func() { err = errors.WithDeferred(err, finishTxn(tx, err == nil)) }()
|
||||
|
||||
s.currMu.RLock()
|
||||
defer s.currMu.RUnlock()
|
||||
|
||||
udb := s.curr.serialize()
|
||||
|
||||
return udb.flushUnitToDB(tx, s.curr.id)
|
||||
}
|
||||
|
||||
// Update implements the Interface interface for *StatsCtx.
|
||||
func (s *StatsCtx) Update(e Entry) {
|
||||
if atomic.LoadUint32(&s.limitHours) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
if e.Result == 0 || e.Result >= resultLast || e.Domain == "" || e.Client == "" {
|
||||
log.Debug("stats: malformed entry")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
s.currMu.Lock()
|
||||
defer s.currMu.Unlock()
|
||||
|
||||
if s.curr == nil {
|
||||
log.Error("stats: current unit is nil")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
clientID := e.Client
|
||||
if ip := net.ParseIP(clientID); ip != nil {
|
||||
clientID = ip.String()
|
||||
}
|
||||
|
||||
s.curr.add(e.Result, e.Domain, clientID, uint64(e.Time))
|
||||
}
|
||||
|
||||
// WriteDiskConfig implements the Interface interface for *StatsCtx.
|
||||
func (s *StatsCtx) WriteDiskConfig(dc *DiskConfig) {
|
||||
dc.Interval = atomic.LoadUint32(&s.limitHours) / 24
|
||||
}
|
||||
|
||||
// TopClientsIP implements the Interface interface for *StatsCtx.
|
||||
func (s *StatsCtx) TopClientsIP(maxCount uint) (ips []net.IP) {
|
||||
limit := atomic.LoadUint32(&s.limitHours)
|
||||
if limit == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
units, _ := s.loadUnits(limit)
|
||||
if units == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Collect data for all the clients to sort and crop it afterwards.
|
||||
m := map[string]uint64{}
|
||||
for _, u := range units {
|
||||
for _, it := range u.Clients {
|
||||
m[it.Name] += it.Count
|
||||
}
|
||||
}
|
||||
|
||||
a := convertMapToSlice(m, int(maxCount))
|
||||
ips = []net.IP{}
|
||||
for _, it := range a {
|
||||
ip := net.ParseIP(it.Name)
|
||||
if ip != nil {
|
||||
ips = append(ips, ip)
|
||||
}
|
||||
}
|
||||
|
||||
return ips
|
||||
}
|
||||
|
||||
// database returns the database if it's opened. It's safe for concurrent use.
|
||||
func (s *StatsCtx) database() (db *bbolt.DB) {
|
||||
s.dbMu.Lock()
|
||||
defer s.dbMu.Unlock()
|
||||
|
||||
return s.db
|
||||
}
|
||||
|
||||
// swapDatabase swaps the database with another one and returns it. It's safe
|
||||
// for concurrent use.
|
||||
func (s *StatsCtx) swapDatabase(with *bbolt.DB) (old *bbolt.DB) {
|
||||
s.dbMu.Lock()
|
||||
defer s.dbMu.Unlock()
|
||||
|
||||
old, s.db = s.db, with
|
||||
|
||||
return old
|
||||
}
|
||||
|
||||
// deleteOldUnits walks the buckets available to tx and deletes old units. It
|
||||
// returns the number of deletions performed.
|
||||
func deleteOldUnits(tx *bbolt.Tx, firstID uint32) (deleted int) {
|
||||
log.Debug("stats: deleting old units until id %d", firstID)
|
||||
|
||||
// TODO(a.garipov): See if this is actually necessary. Looks like a rather
|
||||
// bizarre solution.
|
||||
const errStop errors.Error = "stop iteration"
|
||||
|
||||
walk := func(name []byte, _ *bbolt.Bucket) (err error) {
|
||||
nameID, ok := unitNameToID(name)
|
||||
if ok && nameID >= firstID {
|
||||
return errStop
|
||||
}
|
||||
|
||||
err = tx.DeleteBucket(name)
|
||||
if err != nil {
|
||||
log.Debug("stats: deleting bucket: %s", err)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Debug("stats: deleted unit %d (name %x)", nameID, name)
|
||||
|
||||
deleted++
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
err := tx.ForEach(walk)
|
||||
if err != nil && !errors.Is(err, errStop) {
|
||||
log.Debug("stats: deleting units: %s", err)
|
||||
}
|
||||
|
||||
return deleted
|
||||
}
|
||||
|
||||
// openDB returns an error if the database can't be opened from the specified
|
||||
// file. It's safe for concurrent use.
|
||||
func (s *StatsCtx) openDB() (err error) {
|
||||
log.Debug("stats: opening database")
|
||||
|
||||
var db *bbolt.DB
|
||||
db, err = bbolt.Open(s.filename, 0o644, nil)
|
||||
if err != nil {
|
||||
if err.Error() == "invalid argument" {
|
||||
log.Error("AdGuard Home cannot be initialized due to an incompatible file system.\nPlease read the explanation here: https://github.com/AdguardTeam/AdGuardHome/wiki/Getting-Started#limitations")
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Use defer to unlock the mutex as soon as possible.
|
||||
defer log.Debug("stats: database opened")
|
||||
|
||||
s.dbMu.Lock()
|
||||
defer s.dbMu.Unlock()
|
||||
|
||||
s.db = db
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *StatsCtx) flush() (cont bool, sleepFor time.Duration) {
|
||||
id := s.unitIDGen()
|
||||
|
||||
s.currMu.Lock()
|
||||
defer s.currMu.Unlock()
|
||||
|
||||
ptr := s.curr
|
||||
if ptr == nil {
|
||||
return false, 0
|
||||
}
|
||||
|
||||
limit := atomic.LoadUint32(&s.limitHours)
|
||||
if limit == 0 || ptr.id == id {
|
||||
return true, time.Second
|
||||
}
|
||||
|
||||
db := s.database()
|
||||
if db == nil {
|
||||
return true, 0
|
||||
}
|
||||
|
||||
isCommitable := true
|
||||
tx, err := db.Begin(true)
|
||||
if err != nil {
|
||||
log.Error("stats: opening transaction: %s", err)
|
||||
|
||||
return true, 0
|
||||
}
|
||||
defer func() {
|
||||
if err = finishTxn(tx, isCommitable); err != nil {
|
||||
log.Error("stats: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
s.curr = newUnit(id)
|
||||
|
||||
flushErr := ptr.serialize().flushUnitToDB(tx, ptr.id)
|
||||
if flushErr != nil {
|
||||
log.Error("stats: flushing unit: %s", flushErr)
|
||||
isCommitable = false
|
||||
}
|
||||
|
||||
delErr := tx.DeleteBucket(idToUnitName(id - limit))
|
||||
if delErr != nil {
|
||||
// TODO(e.burkov): Improve the algorithm of deleting the oldest bucket
|
||||
// to avoid the error.
|
||||
if errors.Is(delErr, bbolt.ErrBucketNotFound) {
|
||||
log.Debug("stats: warning: deleting unit: %s", delErr)
|
||||
} else {
|
||||
isCommitable = false
|
||||
log.Error("stats: deleting unit: %s", delErr)
|
||||
}
|
||||
}
|
||||
|
||||
return true, 0
|
||||
}
|
||||
|
||||
// periodicFlush checks and flushes the unit to the database if the freshly
|
||||
// generated unit ID differs from the current's ID. Flushing process includes:
|
||||
// - swapping the current unit with the new empty one;
|
||||
// - writing the current unit to the database;
|
||||
// - removing the stale unit from the database.
|
||||
func (s *StatsCtx) periodicFlush() {
|
||||
for cont, sleepFor := true, time.Duration(0); cont; time.Sleep(sleepFor) {
|
||||
cont, sleepFor = s.flush()
|
||||
}
|
||||
|
||||
log.Debug("periodic flushing finished")
|
||||
}
|
||||
|
||||
func (s *StatsCtx) setLimit(limitDays int) {
|
||||
atomic.StoreUint32(&s.limitHours, uint32(24*limitDays))
|
||||
if limitDays == 0 {
|
||||
if err := s.clear(); err != nil {
|
||||
log.Error("stats: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug("stats: set limit: %d days", limitDays)
|
||||
}
|
||||
|
||||
// Reset counters and clear database
|
||||
func (s *StatsCtx) clear() (err error) {
|
||||
defer func() { err = errors.Annotate(err, "clearing: %w") }()
|
||||
|
||||
db := s.swapDatabase(nil)
|
||||
if db != nil {
|
||||
var tx *bbolt.Tx
|
||||
tx, err = db.Begin(true)
|
||||
if err != nil {
|
||||
log.Error("stats: opening a transaction: %s", err)
|
||||
} else if err = finishTxn(tx, false); err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
// Active transactions will continue using database, but new ones won't
|
||||
// be created.
|
||||
err = db.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("closing database: %w", err)
|
||||
}
|
||||
|
||||
// All active transactions are now closed.
|
||||
log.Debug("stats: database closed")
|
||||
}
|
||||
|
||||
err = os.Remove(s.filename)
|
||||
if err != nil {
|
||||
log.Error("stats: %s", err)
|
||||
}
|
||||
|
||||
err = s.openDB()
|
||||
if err != nil {
|
||||
log.Error("stats: opening database: %s", err)
|
||||
}
|
||||
|
||||
// Use defer to unlock the mutex as soon as possible.
|
||||
defer log.Debug("stats: cleared")
|
||||
|
||||
s.currMu.Lock()
|
||||
defer s.currMu.Unlock()
|
||||
|
||||
s.curr = newUnit(s.unitIDGen())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *StatsCtx) loadUnits(limit uint32) (units []*unitDB, firstID uint32) {
|
||||
db := s.database()
|
||||
if db == nil {
|
||||
return nil, 0
|
||||
}
|
||||
|
||||
// Use writable transaction to ensure any ongoing writable transaction is
|
||||
// taken into account.
|
||||
tx, err := db.Begin(true)
|
||||
if err != nil {
|
||||
log.Error("stats: opening transaction: %s", err)
|
||||
|
||||
return nil, 0
|
||||
}
|
||||
|
||||
s.currMu.RLock()
|
||||
defer s.currMu.RUnlock()
|
||||
|
||||
cur := s.curr
|
||||
|
||||
var curID uint32
|
||||
if cur != nil {
|
||||
curID = cur.id
|
||||
} else {
|
||||
curID = s.unitIDGen()
|
||||
}
|
||||
|
||||
// Per-hour units.
|
||||
units = make([]*unitDB, 0, limit)
|
||||
firstID = curID - limit + 1
|
||||
for i := firstID; i != curID; i++ {
|
||||
u := loadUnitFromDB(tx, i)
|
||||
if u == nil {
|
||||
u = &unitDB{NResult: make([]uint64, resultLast)}
|
||||
}
|
||||
units = append(units, u)
|
||||
}
|
||||
|
||||
err = finishTxn(tx, false)
|
||||
if err != nil {
|
||||
log.Error("stats: %s", err)
|
||||
}
|
||||
|
||||
if cur != nil {
|
||||
units = append(units, cur.serialize())
|
||||
}
|
||||
|
||||
if unitsLen := len(units); unitsLen != int(limit) {
|
||||
log.Fatalf("loaded %d units whilst the desired number is %d", unitsLen, limit)
|
||||
}
|
||||
|
||||
return units, firstID
|
||||
}
|
||||
|
||||
26
internal/stats/stats_internal_test.go
Normal file
26
internal/stats/stats_internal_test.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package stats
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TODO(e.burkov): Use more realistic data.
|
||||
func TestStatsCollector(t *testing.T) {
|
||||
ng := func(_ *unitDB) uint64 { return 0 }
|
||||
units := make([]*unitDB, 720)
|
||||
|
||||
t.Run("hours", func(t *testing.T) {
|
||||
statsData := statsCollector(units, 0, Hours, ng)
|
||||
assert.Len(t, statsData, 720)
|
||||
})
|
||||
|
||||
t.Run("days", func(t *testing.T) {
|
||||
for i := 0; i != 25; i++ {
|
||||
statsData := statsCollector(units, uint32(i), Days, ng)
|
||||
require.Lenf(t, statsData, 30, "i=%d", i)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1,13 +1,17 @@
|
||||
package stats
|
||||
package stats_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -17,147 +21,176 @@ func TestMain(m *testing.M) {
|
||||
aghtest.DiscardLogOutput(m)
|
||||
}
|
||||
|
||||
func UIntArrayEquals(a, b []uint64) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
// constUnitID is the UnitIDGenFunc which always return 0.
|
||||
func constUnitID() (id uint32) { return 0 }
|
||||
|
||||
func assertSuccessAndUnmarshal(t *testing.T, to any, handler http.Handler, req *http.Request) {
|
||||
t.Helper()
|
||||
|
||||
require.NotNil(t, handler)
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rw, req)
|
||||
require.Equal(t, http.StatusOK, rw.Code)
|
||||
|
||||
data := rw.Body.Bytes()
|
||||
if to == nil {
|
||||
assert.Empty(t, data)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
for i := range a {
|
||||
if a[i] != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
err := json.Unmarshal(data, to)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestStats(t *testing.T) {
|
||||
conf := Config{
|
||||
Filename: "./stats.db",
|
||||
cliIP := net.IP{127, 0, 0, 1}
|
||||
cliIPStr := cliIP.String()
|
||||
|
||||
handlers := map[string]http.Handler{}
|
||||
conf := stats.Config{
|
||||
Filename: filepath.Join(t.TempDir(), "stats.db"),
|
||||
LimitDays: 1,
|
||||
UnitID: constUnitID,
|
||||
HTTPRegister: func(_, url string, handler http.HandlerFunc) {
|
||||
handlers[url] = handler
|
||||
},
|
||||
}
|
||||
|
||||
s, err := createObject(conf)
|
||||
s, err := stats.New(conf)
|
||||
require.NoError(t, err)
|
||||
testutil.CleanupAndRequireSuccess(t, func() (err error) {
|
||||
s.clear()
|
||||
s.Close()
|
||||
|
||||
return os.Remove(conf.Filename)
|
||||
s.Start()
|
||||
testutil.CleanupAndRequireSuccess(t, s.Close)
|
||||
|
||||
t.Run("data", func(t *testing.T) {
|
||||
const reqDomain = "domain"
|
||||
|
||||
entries := []stats.Entry{{
|
||||
Domain: reqDomain,
|
||||
Client: cliIPStr,
|
||||
Result: stats.RFiltered,
|
||||
Time: 123456,
|
||||
}, {
|
||||
Domain: reqDomain,
|
||||
Client: cliIPStr,
|
||||
Result: stats.RNotFiltered,
|
||||
Time: 123456,
|
||||
}}
|
||||
|
||||
wantData := &stats.StatsResp{
|
||||
TimeUnits: "hours",
|
||||
TopQueried: []map[string]uint64{0: {reqDomain: 1}},
|
||||
TopClients: []map[string]uint64{0: {cliIPStr: 2}},
|
||||
TopBlocked: []map[string]uint64{0: {reqDomain: 1}},
|
||||
DNSQueries: []uint64{
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2,
|
||||
},
|
||||
BlockedFiltering: []uint64{
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
|
||||
},
|
||||
ReplacedSafebrowsing: []uint64{
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
},
|
||||
ReplacedParental: []uint64{
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
},
|
||||
NumDNSQueries: 2,
|
||||
NumBlockedFiltering: 1,
|
||||
NumReplacedSafebrowsing: 0,
|
||||
NumReplacedSafesearch: 0,
|
||||
NumReplacedParental: 0,
|
||||
AvgProcessingTime: 0.123456,
|
||||
}
|
||||
|
||||
for _, e := range entries {
|
||||
s.Update(e)
|
||||
}
|
||||
|
||||
data := &stats.StatsResp{}
|
||||
req := httptest.NewRequest(http.MethodGet, "/control/stats", nil)
|
||||
assertSuccessAndUnmarshal(t, data, handlers["/control/stats"], req)
|
||||
|
||||
assert.Equal(t, wantData, data)
|
||||
})
|
||||
|
||||
s.Update(Entry{
|
||||
Domain: "domain",
|
||||
Client: "127.0.0.1",
|
||||
Result: RFiltered,
|
||||
Time: 123456,
|
||||
})
|
||||
s.Update(Entry{
|
||||
Domain: "domain",
|
||||
Client: "127.0.0.1",
|
||||
Result: RNotFiltered,
|
||||
Time: 123456,
|
||||
t.Run("tops", func(t *testing.T) {
|
||||
topClients := s.TopClientsIP(2)
|
||||
require.NotEmpty(t, topClients)
|
||||
|
||||
assert.True(t, cliIP.Equal(topClients[0]))
|
||||
})
|
||||
|
||||
d, ok := s.getData()
|
||||
require.True(t, ok)
|
||||
t.Run("reset", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodPost, "/control/stats_reset", nil)
|
||||
assertSuccessAndUnmarshal(t, nil, handlers["/control/stats_reset"], req)
|
||||
|
||||
a := []uint64{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2}
|
||||
assert.True(t, UIntArrayEquals(d.DNSQueries, a))
|
||||
_24zeroes := [24]uint64{}
|
||||
emptyData := &stats.StatsResp{
|
||||
TimeUnits: "hours",
|
||||
TopQueried: []map[string]uint64{},
|
||||
TopClients: []map[string]uint64{},
|
||||
TopBlocked: []map[string]uint64{},
|
||||
DNSQueries: _24zeroes[:],
|
||||
BlockedFiltering: _24zeroes[:],
|
||||
ReplacedSafebrowsing: _24zeroes[:],
|
||||
ReplacedParental: _24zeroes[:],
|
||||
}
|
||||
|
||||
a = []uint64{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
|
||||
assert.True(t, UIntArrayEquals(d.BlockedFiltering, a))
|
||||
req = httptest.NewRequest(http.MethodGet, "/control/stats", nil)
|
||||
data := &stats.StatsResp{}
|
||||
|
||||
a = []uint64{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
|
||||
assert.True(t, UIntArrayEquals(d.ReplacedSafebrowsing, a))
|
||||
|
||||
a = []uint64{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
|
||||
assert.True(t, UIntArrayEquals(d.ReplacedParental, a))
|
||||
|
||||
m := d.TopQueried
|
||||
require.NotEmpty(t, m)
|
||||
assert.EqualValues(t, 1, m[0]["domain"])
|
||||
|
||||
m = d.TopBlocked
|
||||
require.NotEmpty(t, m)
|
||||
assert.EqualValues(t, 1, m[0]["domain"])
|
||||
|
||||
m = d.TopClients
|
||||
require.NotEmpty(t, m)
|
||||
assert.EqualValues(t, 2, m[0]["127.0.0.1"])
|
||||
|
||||
assert.EqualValues(t, 2, d.NumDNSQueries)
|
||||
assert.EqualValues(t, 1, d.NumBlockedFiltering)
|
||||
assert.EqualValues(t, 0, d.NumReplacedSafebrowsing)
|
||||
assert.EqualValues(t, 0, d.NumReplacedSafesearch)
|
||||
assert.EqualValues(t, 0, d.NumReplacedParental)
|
||||
assert.EqualValues(t, 0.123456, d.AvgProcessingTime)
|
||||
|
||||
topClients := s.GetTopClientsIP(2)
|
||||
require.NotEmpty(t, topClients)
|
||||
assert.True(t, net.IP{127, 0, 0, 1}.Equal(topClients[0]))
|
||||
assertSuccessAndUnmarshal(t, data, handlers["/control/stats"], req)
|
||||
assert.Equal(t, emptyData, data)
|
||||
})
|
||||
}
|
||||
|
||||
func TestLargeNumbers(t *testing.T) {
|
||||
var hour int32 = 0
|
||||
newID := func() uint32 {
|
||||
// Use "atomic" to make go race detector happy.
|
||||
return uint32(atomic.LoadInt32(&hour))
|
||||
var curHour uint32 = 1
|
||||
handlers := map[string]http.Handler{}
|
||||
|
||||
conf := stats.Config{
|
||||
Filename: filepath.Join(t.TempDir(), "stats.db"),
|
||||
LimitDays: 1,
|
||||
UnitID: func() (id uint32) { return atomic.LoadUint32(&curHour) },
|
||||
HTTPRegister: func(_, url string, handler http.HandlerFunc) { handlers[url] = handler },
|
||||
}
|
||||
|
||||
conf := Config{
|
||||
Filename: "./stats.db",
|
||||
LimitDays: 1,
|
||||
UnitID: newID,
|
||||
}
|
||||
s, err := createObject(conf)
|
||||
s, err := stats.New(conf)
|
||||
require.NoError(t, err)
|
||||
testutil.CleanupAndRequireSuccess(t, func() (err error) {
|
||||
s.Close()
|
||||
|
||||
return os.Remove(conf.Filename)
|
||||
})
|
||||
s.Start()
|
||||
testutil.CleanupAndRequireSuccess(t, s.Close)
|
||||
|
||||
// Number of distinct clients and domains every hour.
|
||||
const n = 1000
|
||||
const (
|
||||
hoursNum = 12
|
||||
cliNumPerHour = 1000
|
||||
)
|
||||
|
||||
for h := 0; h < 12; h++ {
|
||||
atomic.AddInt32(&hour, 1)
|
||||
for i := 0; i < n; i++ {
|
||||
s.Update(Entry{
|
||||
Domain: fmt.Sprintf("domain%d", i),
|
||||
Client: net.IP{
|
||||
127,
|
||||
0,
|
||||
byte((i & 0xff00) >> 8),
|
||||
byte(i & 0xff),
|
||||
}.String(),
|
||||
Result: RNotFiltered,
|
||||
req := httptest.NewRequest(http.MethodGet, "/control/stats", nil)
|
||||
|
||||
for h := 0; h < hoursNum; h++ {
|
||||
atomic.AddUint32(&curHour, 1)
|
||||
|
||||
for i := 0; i < cliNumPerHour; i++ {
|
||||
ip := net.IP{127, 0, byte((i & 0xff00) >> 8), byte(i & 0xff)}
|
||||
e := stats.Entry{
|
||||
Domain: fmt.Sprintf("domain%d.hour%d", i, h),
|
||||
Client: ip.String(),
|
||||
Result: stats.RNotFiltered,
|
||||
Time: 123456,
|
||||
})
|
||||
}
|
||||
s.Update(e)
|
||||
}
|
||||
}
|
||||
|
||||
d, ok := s.getData()
|
||||
require.True(t, ok)
|
||||
assert.EqualValues(t, hour*n, d.NumDNSQueries)
|
||||
}
|
||||
|
||||
func TestStatsCollector(t *testing.T) {
|
||||
ng := func(_ *unitDB) uint64 {
|
||||
return 0
|
||||
}
|
||||
units := make([]*unitDB, 720)
|
||||
|
||||
t.Run("hours", func(t *testing.T) {
|
||||
statsData := statsCollector(units, 0, Hours, ng)
|
||||
assert.Len(t, statsData, 720)
|
||||
})
|
||||
|
||||
t.Run("days", func(t *testing.T) {
|
||||
for i := 0; i != 25; i++ {
|
||||
statsData := statsCollector(units, uint32(i), Days, ng)
|
||||
require.Lenf(t, statsData, 30, "i=%d", i)
|
||||
}
|
||||
})
|
||||
data := &stats.StatsResp{}
|
||||
assertSuccessAndUnmarshal(t, data, handlers["/control/stats"], req)
|
||||
assert.Equal(t, hoursNum*cliNumPerHour, int(data.NumDNSQueries))
|
||||
}
|
||||
|
||||
@@ -5,253 +5,148 @@ import (
|
||||
"encoding/binary"
|
||||
"encoding/gob"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
bolt "go.etcd.io/bbolt"
|
||||
"go.etcd.io/bbolt"
|
||||
)
|
||||
|
||||
// TODO(a.garipov): Rewrite all of this. Add proper error handling and
|
||||
// inspection. Improve logging. Decrease complexity.
|
||||
|
||||
const (
|
||||
maxDomains = 100 // max number of top domains to store in file or return via Get()
|
||||
maxClients = 100 // max number of top clients to store in file or return via Get()
|
||||
// maxDomains is the max number of top domains to return.
|
||||
maxDomains = 100
|
||||
// maxClients is the max number of top clients to return.
|
||||
maxClients = 100
|
||||
)
|
||||
|
||||
// statsCtx - global context
|
||||
type statsCtx struct {
|
||||
// mu protects unit.
|
||||
mu *sync.Mutex
|
||||
// current is the actual statistics collection result.
|
||||
current *unit
|
||||
// UnitIDGenFunc is the signature of a function that generates a unique ID for
|
||||
// the statistics unit.
|
||||
type UnitIDGenFunc func() (id uint32)
|
||||
|
||||
db *bolt.DB
|
||||
conf *Config
|
||||
// TimeUnit is the unit of measuring time while aggregating the statistics.
|
||||
type TimeUnit int
|
||||
|
||||
// Supported TimeUnit values.
|
||||
const (
|
||||
Hours TimeUnit = iota
|
||||
Days
|
||||
)
|
||||
|
||||
// Result is the resulting code of processing the DNS request.
|
||||
type Result int
|
||||
|
||||
// Supported Result values.
|
||||
//
|
||||
// TODO(e.burkov): Think about better naming.
|
||||
const (
|
||||
RNotFiltered Result = iota + 1
|
||||
RFiltered
|
||||
RSafeBrowsing
|
||||
RSafeSearch
|
||||
RParental
|
||||
|
||||
resultLast = RParental + 1
|
||||
)
|
||||
|
||||
// Entry is a statistics data entry.
|
||||
type Entry struct {
|
||||
// Clients is the client's primary ID.
|
||||
//
|
||||
// TODO(a.garipov): Make this a {net.IP, string} enum?
|
||||
Client string
|
||||
|
||||
// Domain is the domain name requested.
|
||||
Domain string
|
||||
|
||||
// Result is the result of processing the request.
|
||||
Result Result
|
||||
|
||||
// Time is the duration of the request processing in milliseconds.
|
||||
Time uint32
|
||||
}
|
||||
|
||||
// data for 1 time unit
|
||||
// unit collects the statistics data for a specific period of time.
|
||||
type unit struct {
|
||||
id uint32 // unit ID. Default: absolute hour since Jan 1, 1970
|
||||
// id is the unique unit's identifier. It's set to an absolute hour number
|
||||
// since the beginning of UNIX time by the default ID generating function.
|
||||
//
|
||||
// Must not be rewritten after creating to be accessed concurrently without
|
||||
// using mu.
|
||||
id uint32
|
||||
|
||||
nTotal uint64 // total requests
|
||||
nResult []uint64 // number of requests per one result
|
||||
timeSum uint64 // sum of processing time of all requests (usec)
|
||||
// nTotal stores the total number of requests.
|
||||
nTotal uint64
|
||||
// nResult stores the number of requests grouped by it's result.
|
||||
nResult []uint64
|
||||
// timeSum stores the sum of processing time in milliseconds of each request
|
||||
// written by the unit.
|
||||
timeSum uint64
|
||||
|
||||
// top:
|
||||
domains map[string]uint64 // number of requests per domain
|
||||
blockedDomains map[string]uint64 // number of blocked requests per domain
|
||||
clients map[string]uint64 // number of requests per client
|
||||
// domains stores the number of requests for each domain.
|
||||
domains map[string]uint64
|
||||
// blockedDomains stores the number of requests for each domain that has
|
||||
// been blocked.
|
||||
blockedDomains map[string]uint64
|
||||
// clients stores the number of requests from each client.
|
||||
clients map[string]uint64
|
||||
}
|
||||
|
||||
// name-count pair
|
||||
// newUnit allocates the new *unit.
|
||||
func newUnit(id uint32) (u *unit) {
|
||||
return &unit{
|
||||
id: id,
|
||||
nResult: make([]uint64, resultLast),
|
||||
domains: make(map[string]uint64),
|
||||
blockedDomains: make(map[string]uint64),
|
||||
clients: make(map[string]uint64),
|
||||
}
|
||||
}
|
||||
|
||||
// countPair is a single name-number pair for deserializing statistics data into
|
||||
// the database.
|
||||
type countPair struct {
|
||||
Name string
|
||||
Count uint64
|
||||
}
|
||||
|
||||
// structure for storing data in file
|
||||
// unitDB is the structure for serializing statistics data into the database.
|
||||
type unitDB struct {
|
||||
NTotal uint64
|
||||
// NTotal is the total number of requests.
|
||||
NTotal uint64
|
||||
// NResult is the number of requests by the result's kind.
|
||||
NResult []uint64
|
||||
|
||||
Domains []countPair
|
||||
// Domains is the number of requests for each domain name.
|
||||
Domains []countPair
|
||||
// BlockedDomains is the number of requests blocked for each domain name.
|
||||
BlockedDomains []countPair
|
||||
Clients []countPair
|
||||
// Clients is the number of requests from each client.
|
||||
Clients []countPair
|
||||
|
||||
TimeAvg uint32 // usec
|
||||
// TimeAvg is the average of processing times in milliseconds of all the
|
||||
// requests in the unit.
|
||||
TimeAvg uint32
|
||||
}
|
||||
|
||||
// withRecovered turns the value recovered from panic if any into an error and
|
||||
// combines it with the one pointed by orig. orig must be non-nil.
|
||||
func withRecovered(orig *error) {
|
||||
p := recover()
|
||||
if p == nil {
|
||||
return
|
||||
}
|
||||
// newUnitID is the default UnitIDGenFunc that generates the unique id hourly.
|
||||
func newUnitID() (id uint32) {
|
||||
const secsInHour = int64(time.Hour / time.Second)
|
||||
|
||||
var err error
|
||||
switch p := p.(type) {
|
||||
case error:
|
||||
err = fmt.Errorf("panic: %w", p)
|
||||
default:
|
||||
err = fmt.Errorf("panic: recovered value of type %[1]T: %[1]v", p)
|
||||
}
|
||||
|
||||
*orig = errors.WithDeferred(*orig, err)
|
||||
return uint32(time.Now().Unix() / secsInHour)
|
||||
}
|
||||
|
||||
// createObject creates s from conf and properly initializes it.
|
||||
func createObject(conf Config) (s *statsCtx, err error) {
|
||||
defer withRecovered(&err)
|
||||
|
||||
s = &statsCtx{
|
||||
mu: &sync.Mutex{},
|
||||
}
|
||||
if !checkInterval(conf.LimitDays) {
|
||||
conf.LimitDays = 1
|
||||
func finishTxn(tx *bbolt.Tx, commit bool) (err error) {
|
||||
if commit {
|
||||
err = errors.Annotate(tx.Commit(), "committing: %w")
|
||||
} else {
|
||||
err = errors.Annotate(tx.Rollback(), "rolling back: %w")
|
||||
}
|
||||
|
||||
s.conf = &Config{}
|
||||
*s.conf = conf
|
||||
s.conf.limit = conf.LimitDays * 24
|
||||
if conf.UnitID == nil {
|
||||
s.conf.UnitID = newUnitID
|
||||
}
|
||||
|
||||
if !s.dbOpen() {
|
||||
return nil, fmt.Errorf("open database")
|
||||
}
|
||||
|
||||
id := s.conf.UnitID()
|
||||
tx := s.beginTxn(true)
|
||||
var udb *unitDB
|
||||
if tx != nil {
|
||||
log.Tracef("Deleting old units...")
|
||||
firstID := id - s.conf.limit - 1
|
||||
unitDel := 0
|
||||
|
||||
err = tx.ForEach(newBucketWalker(tx, &unitDel, firstID))
|
||||
if err != nil && !errors.Is(err, errStop) {
|
||||
log.Debug("stats: deleting units: %s", err)
|
||||
}
|
||||
|
||||
udb = s.loadUnitFromDB(tx, id)
|
||||
|
||||
if unitDel != 0 {
|
||||
s.commitTxn(tx)
|
||||
} else {
|
||||
err = tx.Rollback()
|
||||
if err != nil {
|
||||
log.Debug("rolling back: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
u := unit{}
|
||||
s.initUnit(&u, id)
|
||||
if udb != nil {
|
||||
deserialize(&u, udb)
|
||||
}
|
||||
s.current = &u
|
||||
|
||||
log.Debug("stats: initialized")
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// TODO(a.garipov): See if this is actually necessary. Looks like a rather
|
||||
// bizarre solution.
|
||||
const errStop errors.Error = "stop iteration"
|
||||
|
||||
// newBucketWalker returns a new bucket walker that deletes old units. The
|
||||
// integer that unitDelPtr points to is incremented for every successful
|
||||
// deletion. If the bucket isn't deleted, f returns errStop.
|
||||
func newBucketWalker(
|
||||
tx *bolt.Tx,
|
||||
unitDelPtr *int,
|
||||
firstID uint32,
|
||||
) (f func(name []byte, b *bolt.Bucket) (err error)) {
|
||||
return func(name []byte, _ *bolt.Bucket) (err error) {
|
||||
nameID, ok := unitNameToID(name)
|
||||
if !ok || nameID < firstID {
|
||||
err = tx.DeleteBucket(name)
|
||||
if err != nil {
|
||||
log.Debug("stats: tx.DeleteBucket: %s", err)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Debug("stats: deleted unit %d (name %x)", nameID, name)
|
||||
|
||||
*unitDelPtr++
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return errStop
|
||||
}
|
||||
}
|
||||
|
||||
func (s *statsCtx) Start() {
|
||||
s.initWeb()
|
||||
go s.periodicFlush()
|
||||
}
|
||||
|
||||
func checkInterval(days uint32) bool {
|
||||
return days == 0 || days == 1 || days == 7 || days == 30 || days == 90
|
||||
}
|
||||
|
||||
func (s *statsCtx) dbOpen() bool {
|
||||
var err error
|
||||
log.Tracef("db.Open...")
|
||||
s.db, err = bolt.Open(s.conf.Filename, 0o644, nil)
|
||||
if err != nil {
|
||||
log.Error("stats: open DB: %s: %s", s.conf.Filename, err)
|
||||
if err.Error() == "invalid argument" {
|
||||
log.Error("AdGuard Home cannot be initialized due to an incompatible file system.\nPlease read the explanation here: https://github.com/AdguardTeam/AdGuardHome/wiki/Getting-Started#limitations")
|
||||
}
|
||||
return false
|
||||
}
|
||||
log.Tracef("db.Open")
|
||||
return true
|
||||
}
|
||||
|
||||
// Atomically swap the currently active unit with a new value
|
||||
// Return old value
|
||||
func (s *statsCtx) swapUnit(new *unit) (u *unit) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
u = s.current
|
||||
s.current = new
|
||||
|
||||
return u
|
||||
}
|
||||
|
||||
// Get unit ID for the current hour
|
||||
func newUnitID() uint32 {
|
||||
return uint32(time.Now().Unix() / (60 * 60))
|
||||
}
|
||||
|
||||
// Initialize a unit
|
||||
func (s *statsCtx) initUnit(u *unit, id uint32) {
|
||||
u.id = id
|
||||
u.nResult = make([]uint64, rLast)
|
||||
u.domains = make(map[string]uint64)
|
||||
u.blockedDomains = make(map[string]uint64)
|
||||
u.clients = make(map[string]uint64)
|
||||
}
|
||||
|
||||
// Open a DB transaction
|
||||
func (s *statsCtx) beginTxn(wr bool) *bolt.Tx {
|
||||
db := s.db
|
||||
if db == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Tracef("db.Begin...")
|
||||
tx, err := db.Begin(wr)
|
||||
if err != nil {
|
||||
log.Error("db.Begin: %s", err)
|
||||
return nil
|
||||
}
|
||||
log.Tracef("db.Begin")
|
||||
return tx
|
||||
}
|
||||
|
||||
func (s *statsCtx) commitTxn(tx *bolt.Tx) {
|
||||
err := tx.Commit()
|
||||
if err != nil {
|
||||
log.Debug("tx.Commit: %s", err)
|
||||
return
|
||||
}
|
||||
log.Tracef("tx.Commit")
|
||||
return err
|
||||
}
|
||||
|
||||
// bucketNameLen is the length of a bucket, a 64-bit unsigned integer.
|
||||
@@ -262,10 +157,10 @@ const bucketNameLen = 8
|
||||
|
||||
// idToUnitName converts a numerical ID into a database unit name.
|
||||
func idToUnitName(id uint32) (name []byte) {
|
||||
name = make([]byte, bucketNameLen)
|
||||
binary.BigEndian.PutUint64(name, uint64(id))
|
||||
n := [bucketNameLen]byte{}
|
||||
binary.BigEndian.PutUint64(n[:], uint64(id))
|
||||
|
||||
return name
|
||||
return n[:]
|
||||
}
|
||||
|
||||
// unitNameToID converts a database unit name into a numerical ID. ok is false
|
||||
@@ -278,316 +173,131 @@ func unitNameToID(name []byte) (id uint32, ok bool) {
|
||||
return uint32(binary.BigEndian.Uint64(name)), true
|
||||
}
|
||||
|
||||
func (s *statsCtx) ongoing() (u *unit) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
return s.current
|
||||
}
|
||||
|
||||
// Flush the current unit to DB and delete an old unit when a new hour is started
|
||||
// If a unit must be flushed:
|
||||
// . lock DB
|
||||
// . atomically set a new empty unit as the current one and get the old unit
|
||||
// This is important to do it inside DB lock, so the reader won't get inconsistent results.
|
||||
// . write the unit to DB
|
||||
// . remove the stale unit from DB
|
||||
// . unlock DB
|
||||
func (s *statsCtx) periodicFlush() {
|
||||
for {
|
||||
ptr := s.ongoing()
|
||||
if ptr == nil {
|
||||
break
|
||||
}
|
||||
|
||||
id := s.conf.UnitID()
|
||||
if ptr.id == id || s.conf.limit == 0 {
|
||||
time.Sleep(time.Second)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
tx := s.beginTxn(true)
|
||||
|
||||
nu := unit{}
|
||||
s.initUnit(&nu, id)
|
||||
u := s.swapUnit(&nu)
|
||||
udb := serialize(u)
|
||||
|
||||
if tx == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
ok1 := s.flushUnitToDB(tx, u.id, udb)
|
||||
ok2 := s.deleteUnit(tx, id-s.conf.limit)
|
||||
if ok1 || ok2 {
|
||||
s.commitTxn(tx)
|
||||
} else {
|
||||
_ = tx.Rollback()
|
||||
}
|
||||
}
|
||||
|
||||
log.Tracef("periodicFlush() exited")
|
||||
}
|
||||
|
||||
// Delete unit's data from file
|
||||
func (s *statsCtx) deleteUnit(tx *bolt.Tx, id uint32) bool {
|
||||
err := tx.DeleteBucket(idToUnitName(id))
|
||||
if err != nil {
|
||||
log.Tracef("stats: bolt DeleteBucket: %s", err)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
log.Debug("stats: deleted unit %d", id)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func convertMapToSlice(m map[string]uint64, max int) []countPair {
|
||||
a := []countPair{}
|
||||
func convertMapToSlice(m map[string]uint64, max int) (s []countPair) {
|
||||
s = make([]countPair, 0, len(m))
|
||||
for k, v := range m {
|
||||
pair := countPair{}
|
||||
pair.Name = k
|
||||
pair.Count = v
|
||||
a = append(a, pair)
|
||||
s = append(s, countPair{Name: k, Count: v})
|
||||
}
|
||||
less := func(i, j int) bool {
|
||||
return a[j].Count < a[i].Count
|
||||
|
||||
sort.Slice(s, func(i, j int) bool {
|
||||
return s[j].Count < s[i].Count
|
||||
})
|
||||
if max > len(s) {
|
||||
max = len(s)
|
||||
}
|
||||
sort.Slice(a, less)
|
||||
if max > len(a) {
|
||||
max = len(a)
|
||||
}
|
||||
return a[:max]
|
||||
|
||||
return s[:max]
|
||||
}
|
||||
|
||||
func convertSliceToMap(a []countPair) map[string]uint64 {
|
||||
m := map[string]uint64{}
|
||||
func convertSliceToMap(a []countPair) (m map[string]uint64) {
|
||||
m = map[string]uint64{}
|
||||
for _, it := range a {
|
||||
m[it.Name] = it.Count
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
func serialize(u *unit) *unitDB {
|
||||
udb := unitDB{}
|
||||
udb.NTotal = u.nTotal
|
||||
|
||||
udb.NResult = append(udb.NResult, u.nResult...)
|
||||
|
||||
// serialize converts u to the *unitDB. It's safe for concurrent use. u must
|
||||
// not be nil.
|
||||
func (u *unit) serialize() (udb *unitDB) {
|
||||
var timeAvg uint32 = 0
|
||||
if u.nTotal != 0 {
|
||||
udb.TimeAvg = uint32(u.timeSum / u.nTotal)
|
||||
timeAvg = uint32(u.timeSum / u.nTotal)
|
||||
}
|
||||
|
||||
udb.Domains = convertMapToSlice(u.domains, maxDomains)
|
||||
udb.BlockedDomains = convertMapToSlice(u.blockedDomains, maxDomains)
|
||||
udb.Clients = convertMapToSlice(u.clients, maxClients)
|
||||
|
||||
return &udb
|
||||
return &unitDB{
|
||||
NTotal: u.nTotal,
|
||||
NResult: append([]uint64{}, u.nResult...),
|
||||
Domains: convertMapToSlice(u.domains, maxDomains),
|
||||
BlockedDomains: convertMapToSlice(u.blockedDomains, maxDomains),
|
||||
Clients: convertMapToSlice(u.clients, maxClients),
|
||||
TimeAvg: timeAvg,
|
||||
}
|
||||
}
|
||||
|
||||
func deserialize(u *unit, udb *unitDB) {
|
||||
u.nTotal = udb.NTotal
|
||||
|
||||
n := len(udb.NResult)
|
||||
if n < len(u.nResult) {
|
||||
n = len(u.nResult) // n = min(len(udb.NResult), len(u.nResult))
|
||||
}
|
||||
for i := 1; i < n; i++ {
|
||||
u.nResult[i] = udb.NResult[i]
|
||||
}
|
||||
|
||||
u.domains = convertSliceToMap(udb.Domains)
|
||||
u.blockedDomains = convertSliceToMap(udb.BlockedDomains)
|
||||
u.clients = convertSliceToMap(udb.Clients)
|
||||
u.timeSum = uint64(udb.TimeAvg) * u.nTotal
|
||||
}
|
||||
|
||||
func (s *statsCtx) flushUnitToDB(tx *bolt.Tx, id uint32, udb *unitDB) bool {
|
||||
log.Tracef("Flushing unit %d", id)
|
||||
|
||||
bkt, err := tx.CreateBucketIfNotExists(idToUnitName(id))
|
||||
if err != nil {
|
||||
log.Error("tx.CreateBucketIfNotExists: %s", err)
|
||||
return false
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
enc := gob.NewEncoder(&buf)
|
||||
err = enc.Encode(udb)
|
||||
if err != nil {
|
||||
log.Error("gob.Encode: %s", err)
|
||||
return false
|
||||
}
|
||||
|
||||
err = bkt.Put([]byte{0}, buf.Bytes())
|
||||
if err != nil {
|
||||
log.Error("bkt.Put: %s", err)
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *statsCtx) loadUnitFromDB(tx *bolt.Tx, id uint32) *unitDB {
|
||||
func loadUnitFromDB(tx *bbolt.Tx, id uint32) (udb *unitDB) {
|
||||
bkt := tx.Bucket(idToUnitName(id))
|
||||
if bkt == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// log.Tracef("Loading unit %d", id)
|
||||
log.Tracef("Loading unit %d", id)
|
||||
|
||||
var buf bytes.Buffer
|
||||
buf.Write(bkt.Get([]byte{0}))
|
||||
dec := gob.NewDecoder(&buf)
|
||||
udb := unitDB{}
|
||||
err := dec.Decode(&udb)
|
||||
udb = &unitDB{}
|
||||
|
||||
err := gob.NewDecoder(&buf).Decode(udb)
|
||||
if err != nil {
|
||||
log.Error("gob Decode: %s", err)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return &udb
|
||||
return udb
|
||||
}
|
||||
|
||||
func convertTopSlice(a []countPair) []map[string]uint64 {
|
||||
m := []map[string]uint64{}
|
||||
for _, it := range a {
|
||||
ent := map[string]uint64{}
|
||||
ent[it.Name] = it.Count
|
||||
m = append(m, ent)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func (s *statsCtx) setLimit(limitDays int) {
|
||||
s.conf.limit = uint32(limitDays) * 24
|
||||
if limitDays == 0 {
|
||||
s.clear()
|
||||
}
|
||||
|
||||
log.Debug("stats: set limit: %d", limitDays)
|
||||
}
|
||||
|
||||
func (s *statsCtx) WriteDiskConfig(dc *DiskConfig) {
|
||||
dc.Interval = s.conf.limit / 24
|
||||
}
|
||||
|
||||
func (s *statsCtx) Close() {
|
||||
u := s.swapUnit(nil)
|
||||
udb := serialize(u)
|
||||
tx := s.beginTxn(true)
|
||||
if tx != nil {
|
||||
if s.flushUnitToDB(tx, u.id, udb) {
|
||||
s.commitTxn(tx)
|
||||
} else {
|
||||
_ = tx.Rollback()
|
||||
}
|
||||
}
|
||||
|
||||
if s.db != nil {
|
||||
log.Tracef("db.Close...")
|
||||
_ = s.db.Close()
|
||||
log.Tracef("db.Close")
|
||||
}
|
||||
|
||||
log.Debug("stats: closed")
|
||||
}
|
||||
|
||||
// Reset counters and clear database
|
||||
func (s *statsCtx) clear() {
|
||||
tx := s.beginTxn(true)
|
||||
if tx != nil {
|
||||
db := s.db
|
||||
s.db = nil
|
||||
_ = tx.Rollback()
|
||||
// the active transactions can continue using database,
|
||||
// but no new transactions will be opened
|
||||
_ = db.Close()
|
||||
log.Tracef("db.Close")
|
||||
// all active transactions are now closed
|
||||
}
|
||||
|
||||
u := unit{}
|
||||
s.initUnit(&u, s.conf.UnitID())
|
||||
_ = s.swapUnit(&u)
|
||||
|
||||
err := os.Remove(s.conf.Filename)
|
||||
if err != nil {
|
||||
log.Error("os.Remove: %s", err)
|
||||
}
|
||||
|
||||
_ = s.dbOpen()
|
||||
|
||||
log.Debug("stats: cleared")
|
||||
}
|
||||
|
||||
func (s *statsCtx) Update(e Entry) {
|
||||
if s.conf.limit == 0 {
|
||||
// deserealize assigns the appropriate values from udb to u. u must not be nil.
|
||||
// It's safe for concurrent use.
|
||||
func (u *unit) deserialize(udb *unitDB) {
|
||||
if udb == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if e.Result == 0 ||
|
||||
e.Result >= rLast ||
|
||||
e.Domain == "" ||
|
||||
e.Client == "" {
|
||||
return
|
||||
}
|
||||
u.nTotal = udb.NTotal
|
||||
u.nResult = make([]uint64, resultLast)
|
||||
copy(u.nResult, udb.NResult)
|
||||
u.domains = convertSliceToMap(udb.Domains)
|
||||
u.blockedDomains = convertSliceToMap(udb.BlockedDomains)
|
||||
u.clients = convertSliceToMap(udb.Clients)
|
||||
u.timeSum = uint64(udb.TimeAvg) * udb.NTotal
|
||||
}
|
||||
|
||||
clientID := e.Client
|
||||
if ip := net.ParseIP(clientID); ip != nil {
|
||||
clientID = ip.String()
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
u := s.current
|
||||
|
||||
u.nResult[e.Result]++
|
||||
|
||||
if e.Result == RNotFiltered {
|
||||
u.domains[e.Domain]++
|
||||
// add adds new data to u. It's safe for concurrent use.
|
||||
func (u *unit) add(res Result, domain, cli string, dur uint64) {
|
||||
u.nResult[res]++
|
||||
if res == RNotFiltered {
|
||||
u.domains[domain]++
|
||||
} else {
|
||||
u.blockedDomains[e.Domain]++
|
||||
u.blockedDomains[domain]++
|
||||
}
|
||||
|
||||
u.clients[clientID]++
|
||||
u.timeSum += uint64(e.Time)
|
||||
u.clients[cli]++
|
||||
u.timeSum += dur
|
||||
u.nTotal++
|
||||
}
|
||||
|
||||
func (s *statsCtx) loadUnits(limit uint32) ([]*unitDB, uint32) {
|
||||
tx := s.beginTxn(false)
|
||||
if tx == nil {
|
||||
return nil, 0
|
||||
// flushUnitToDB puts udb to the database at id.
|
||||
func (udb *unitDB) flushUnitToDB(tx *bbolt.Tx, id uint32) (err error) {
|
||||
log.Debug("stats: flushing unit with id %d and total of %d", id, udb.NTotal)
|
||||
|
||||
bkt, err := tx.CreateBucketIfNotExists(idToUnitName(id))
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating bucket: %w", err)
|
||||
}
|
||||
|
||||
cur := s.ongoing()
|
||||
curID := cur.id
|
||||
|
||||
// Per-hour units.
|
||||
units := []*unitDB{}
|
||||
firstID := curID - limit + 1
|
||||
for i := firstID; i != curID; i++ {
|
||||
u := s.loadUnitFromDB(tx, i)
|
||||
if u == nil {
|
||||
u = &unitDB{}
|
||||
u.NResult = make([]uint64, rLast)
|
||||
}
|
||||
units = append(units, u)
|
||||
buf := &bytes.Buffer{}
|
||||
err = gob.NewEncoder(buf).Encode(udb)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encoding unit: %w", err)
|
||||
}
|
||||
|
||||
_ = tx.Rollback()
|
||||
|
||||
units = append(units, serialize(cur))
|
||||
|
||||
if len(units) != int(limit) {
|
||||
log.Fatalf("len(units) != limit: %d %d", len(units), limit)
|
||||
err = bkt.Put([]byte{0}, buf.Bytes())
|
||||
if err != nil {
|
||||
return fmt.Errorf("putting unit to database: %w", err)
|
||||
}
|
||||
|
||||
return units, firstID
|
||||
return nil
|
||||
}
|
||||
|
||||
func convertTopSlice(a []countPair) (m []map[string]uint64) {
|
||||
m = make([]map[string]uint64, 0, len(a))
|
||||
for _, it := range a {
|
||||
m = append(m, map[string]uint64{it.Name: it.Count})
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// numsGetter is a signature for statsCollector argument.
|
||||
@@ -597,6 +307,7 @@ type numsGetter func(u *unitDB) (num uint64)
|
||||
// timeUnit using ng to retrieve data.
|
||||
func statsCollector(units []*unitDB, firstID uint32, timeUnit TimeUnit, ng numsGetter) (nums []uint64) {
|
||||
if timeUnit == Hours {
|
||||
nums = make([]uint64, 0, len(units))
|
||||
for _, u := range units {
|
||||
nums = append(nums, ng(u))
|
||||
}
|
||||
@@ -628,16 +339,17 @@ func statsCollector(units []*unitDB, firstID uint32, timeUnit TimeUnit, ng numsG
|
||||
// pairsGetter is a signature for topsCollector argument.
|
||||
type pairsGetter func(u *unitDB) (pairs []countPair)
|
||||
|
||||
// topsCollector collects statistics about highest values fro the given *unitDB
|
||||
// topsCollector collects statistics about highest values from the given *unitDB
|
||||
// slice using pg to retrieve data.
|
||||
func topsCollector(units []*unitDB, max int, pg pairsGetter) []map[string]uint64 {
|
||||
m := map[string]uint64{}
|
||||
for _, u := range units {
|
||||
for _, it := range pg(u) {
|
||||
m[it.Name] += it.Count
|
||||
for _, cp := range pg(u) {
|
||||
m[cp.Name] += cp.Count
|
||||
}
|
||||
}
|
||||
a2 := convertMapToSlice(m, max)
|
||||
|
||||
return convertTopSlice(a2)
|
||||
}
|
||||
|
||||
@@ -668,8 +380,21 @@ func topsCollector(units []*unitDB, max int, pg pairsGetter) []map[string]uint64
|
||||
* parental-blocked
|
||||
These values are just the sum of data for all units.
|
||||
*/
|
||||
func (s *statsCtx) getData() (statsResponse, bool) {
|
||||
limit := s.conf.limit
|
||||
func (s *StatsCtx) getData(limit uint32) (StatsResp, bool) {
|
||||
if limit == 0 {
|
||||
return StatsResp{
|
||||
TimeUnits: "days",
|
||||
|
||||
TopBlocked: []topAddrs{},
|
||||
TopClients: []topAddrs{},
|
||||
TopQueried: []topAddrs{},
|
||||
|
||||
BlockedFiltering: []uint64{},
|
||||
DNSQueries: []uint64{},
|
||||
ReplacedParental: []uint64{},
|
||||
ReplacedSafebrowsing: []uint64{},
|
||||
}, true
|
||||
}
|
||||
|
||||
timeUnit := Hours
|
||||
if limit/24 > 7 {
|
||||
@@ -678,7 +403,7 @@ func (s *statsCtx) getData() (statsResponse, bool) {
|
||||
|
||||
units, firstID := s.loadUnits(limit)
|
||||
if units == nil {
|
||||
return statsResponse{}, false
|
||||
return StatsResp{}, false
|
||||
}
|
||||
|
||||
dnsQueries := statsCollector(units, firstID, timeUnit, func(u *unitDB) (num uint64) { return u.NTotal })
|
||||
@@ -686,7 +411,7 @@ func (s *statsCtx) getData() (statsResponse, bool) {
|
||||
log.Fatalf("len(dnsQueries) != limit: %d %d", len(dnsQueries), limit)
|
||||
}
|
||||
|
||||
data := statsResponse{
|
||||
data := StatsResp{
|
||||
DNSQueries: dnsQueries,
|
||||
BlockedFiltering: statsCollector(units, firstID, timeUnit, func(u *unitDB) (num uint64) { return u.NResult[RFiltered] }),
|
||||
ReplacedSafebrowsing: statsCollector(units, firstID, timeUnit, func(u *unitDB) (num uint64) { return u.NResult[RSafeBrowsing] }),
|
||||
@@ -698,7 +423,7 @@ func (s *statsCtx) getData() (statsResponse, bool) {
|
||||
|
||||
// Total counters:
|
||||
sum := unitDB{
|
||||
NResult: make([]uint64, rLast),
|
||||
NResult: make([]uint64, resultLast),
|
||||
}
|
||||
timeN := 0
|
||||
for _, u := range units {
|
||||
@@ -730,31 +455,3 @@ func (s *statsCtx) getData() (statsResponse, bool) {
|
||||
|
||||
return data, true
|
||||
}
|
||||
|
||||
func (s *statsCtx) GetTopClientsIP(maxCount uint) []net.IP {
|
||||
if s.conf.limit == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
units, _ := s.loadUnits(s.conf.limit)
|
||||
if units == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// top clients
|
||||
m := map[string]uint64{}
|
||||
for _, u := range units {
|
||||
for _, it := range u.Clients {
|
||||
m[it.Name] += it.Count
|
||||
}
|
||||
}
|
||||
a := convertMapToSlice(m, int(maxCount))
|
||||
d := []net.IP{}
|
||||
for _, it := range a {
|
||||
ip := net.ParseIP(it.Name)
|
||||
if ip != nil {
|
||||
d = append(d, ip)
|
||||
}
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
@@ -1,34 +1,32 @@
|
||||
module github.com/AdguardTeam/AdGuardHome/internal/tools
|
||||
|
||||
go 1.17
|
||||
go 1.18
|
||||
|
||||
require (
|
||||
github.com/fzipp/gocyclo v0.5.1
|
||||
github.com/fzipp/gocyclo v0.6.0
|
||||
github.com/golangci/misspell v0.3.5
|
||||
github.com/gordonklaus/ineffassign v0.0.0-20210914165742-4cc7213b9bc8
|
||||
github.com/kisielk/errcheck v1.6.0
|
||||
github.com/kisielk/errcheck v1.6.2
|
||||
github.com/kyoh86/looppointer v0.1.7
|
||||
github.com/securego/gosec/v2 v2.11.0
|
||||
golang.org/x/lint v0.0.0-20210508222113-6edffad5e616
|
||||
golang.org/x/tools v0.1.11-0.20220316014157-77aa08bb151a
|
||||
honnef.co/go/tools v0.3.1
|
||||
github.com/securego/gosec/v2 v2.12.0
|
||||
golang.org/x/tools v0.1.12
|
||||
honnef.co/go/tools v0.3.3
|
||||
mvdan.cc/gofumpt v0.3.1
|
||||
mvdan.cc/unparam v0.0.0-20220316160445-06cc5682983b
|
||||
mvdan.cc/unparam v0.0.0-20220706161116-678bad134442
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/BurntSushi/toml v1.1.0 // indirect
|
||||
github.com/BurntSushi/toml v1.2.0 // indirect
|
||||
github.com/client9/misspell v0.3.4 // indirect
|
||||
github.com/google/go-cmp v0.5.8 // indirect
|
||||
github.com/google/uuid v1.3.0 // indirect
|
||||
github.com/gookit/color v1.5.0 // indirect
|
||||
github.com/gookit/color v1.5.1 // indirect
|
||||
github.com/kyoh86/nolint v0.0.1 // indirect
|
||||
github.com/nbutton23/zxcvbn-go v0.0.0-20210217022336-fa2cb2858354 // indirect
|
||||
github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778 // indirect
|
||||
golang.org/x/exp/typeparams v0.0.0-20220426173459-3bcf042a4bf5 // indirect
|
||||
golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3 // indirect
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect
|
||||
golang.org/x/sys v0.0.0-20220422013727-9388b58f7150 // indirect
|
||||
golang.org/x/xerrors v0.0.0-20220411194840-2f41105eb62f // indirect
|
||||
golang.org/x/exp/typeparams v0.0.0-20220722155223-a9213eeb770e // indirect
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 // indirect
|
||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 // indirect
|
||||
golang.org/x/sys v0.0.0-20220804214406-8e32c043e418 // indirect
|
||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||
)
|
||||
|
||||
@@ -34,9 +34,8 @@ cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RX
|
||||
contrib.go.opencensus.io/exporter/stackdriver v0.13.4/go.mod h1:aXENhDJ1Y4lIg4EUaVTwzvYETVNZk10Pu26tevFKLUc=
|
||||
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
|
||||
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
|
||||
github.com/BurntSushi/toml v0.4.1/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ=
|
||||
github.com/BurntSushi/toml v1.1.0 h1:ksErzDEI1khOiGPgpwuI7x2ebx/uXQNw7xJpn9Eq1+I=
|
||||
github.com/BurntSushi/toml v1.1.0/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ=
|
||||
github.com/BurntSushi/toml v1.2.0 h1:Rt8g24XnyGTyglgET/PRUNlrUeu9F5L+7FilkXfZgs0=
|
||||
github.com/BurntSushi/toml v1.2.0/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ=
|
||||
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
|
||||
github.com/Masterminds/goutils v1.1.0/go.mod h1:8cTjp+g8YejhMuvIA5y2vz3BpJxksy863GQaJW2MFNU=
|
||||
github.com/Masterminds/semver v1.4.2/go.mod h1:MB6lktGJrhw8PrUyiEoblNEGEQ+RzHPF078ddwwvV3Y=
|
||||
@@ -78,7 +77,6 @@ github.com/cpuguy83/go-md2man v1.0.10/go.mod h1:SmD6nW6nTyfqj6ABTjUi3V3JVMnlJmwc
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU=
|
||||
github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY=
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/davecgh/go-spew v0.0.0-20161028175848-04cdfd42973b/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
@@ -94,12 +92,11 @@ github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7
|
||||
github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4=
|
||||
github.com/fatih/color v1.10.0/go.mod h1:ELkj/draVOlAH/xkhN6mQ50Qd0MPOk5AAr3maGEBuJM=
|
||||
github.com/frankban/quicktest v1.14.2 h1:SPb1KFFmM+ybpEjPUhCCkZOM5xlovT5UbrMvWnXyBns=
|
||||
github.com/frankban/quicktest v1.14.2/go.mod h1:mgiwOwqx65TmIk1wJ6Q7wvnVMocbUorkibMOrVTHZps=
|
||||
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
|
||||
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
|
||||
github.com/fullstorydev/grpcurl v1.6.0/go.mod h1:ZQ+ayqbKMJNhzLmbpCiurTVlaK2M/3nqZCxaQ2Ze/sM=
|
||||
github.com/fzipp/gocyclo v0.5.1 h1:L66amyuYogbxl0j2U+vGqJXusPF2IkduvXLnYD5TFgw=
|
||||
github.com/fzipp/gocyclo v0.5.1/go.mod h1:rXPyn8fnlpa0R2csP/31uerbiVBugk5whMdlyaLkLoA=
|
||||
github.com/fzipp/gocyclo v0.6.0 h1:lsblElZG7d3ALtGMx9fmxeTKZaLLpU8mET09yN4BBLo=
|
||||
github.com/fzipp/gocyclo v0.6.0/go.mod h1:rXPyn8fnlpa0R2csP/31uerbiVBugk5whMdlyaLkLoA=
|
||||
github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
|
||||
github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
|
||||
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
|
||||
@@ -157,7 +154,6 @@ github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
|
||||
github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE=
|
||||
github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg=
|
||||
github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
@@ -179,8 +175,8 @@ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
|
||||
github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
|
||||
github.com/gookit/color v1.5.0 h1:1Opow3+BWDwqor78DcJkJCIwnkviFi+rrOANki9BUFw=
|
||||
github.com/gookit/color v1.5.0/go.mod h1:43aQb+Zerm/BWh2GnrgOQm7ffz7tvQXEKV6BFMl7wAo=
|
||||
github.com/gookit/color v1.5.1 h1:Vjg2VEcdHpwq+oY63s/ksHrgJYCTo0bwWvmmYWdE9fQ=
|
||||
github.com/gookit/color v1.5.1/go.mod h1:wZFzea4X8qN6vHOSP2apMb4/+w/orMznEzYsIHPaqKM=
|
||||
github.com/gordonklaus/ineffassign v0.0.0-20200309095847-7953dde2c7bf/go.mod h1:cuNKsD1zp2v6XfE/orVX2QE1LC+i254ceGcVeDT3pTU=
|
||||
github.com/gordonklaus/ineffassign v0.0.0-20210914165742-4cc7213b9bc8 h1:PVRE9d4AQKmbelZ7emNig1+NT27DUmKZn5qXxfio54U=
|
||||
github.com/gordonklaus/ineffassign v0.0.0-20210914165742-4cc7213b9bc8/go.mod h1:Qcp2HIAYhR7mNUVSIxZww3Guk4it82ghYcEXIAk+QT0=
|
||||
@@ -222,19 +218,17 @@ github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7V
|
||||
github.com/k0kubun/colorstring v0.0.0-20150214042306-9440f1994b88/go.mod h1:3w7q1U84EfirKl04SVQ/s7nPm1ZPhiXd34z40TNz36k=
|
||||
github.com/kisielk/errcheck v1.1.0/go.mod h1:EZBBE59ingxPouuu3KfxchcWSUPOHkagtvWXihfKN4Q=
|
||||
github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00=
|
||||
github.com/kisielk/errcheck v1.6.0 h1:YTDO4pNy7AUN/021p+JGHycQyYNIyMoenM1YDVK6RlY=
|
||||
github.com/kisielk/errcheck v1.6.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
|
||||
github.com/kisielk/errcheck v1.6.2 h1:uGQ9xI8/pgc9iOoCe7kWQgRE6SBTrCGmTSf0LrEtY7c=
|
||||
github.com/kisielk/errcheck v1.6.2/go.mod h1:nXw/i/MfnvRHqXa7XXmQMUB0oNFGuBrNI8d8NLy0LPw=
|
||||
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
||||
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
||||
github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
||||
github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
|
||||
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
|
||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||
github.com/kyoh86/looppointer v0.1.7 h1:q5sZOhFvmvQ6ZoZxvPB/Mjj2croWX7L49BBuI4XQWCM=
|
||||
github.com/kyoh86/looppointer v0.1.7/go.mod h1:l0cRF49N6xDPx8IuBGC/imZo8Yn1BBLJY0vzI+4fepc=
|
||||
@@ -244,7 +238,7 @@ github.com/kyoh86/nolint v0.0.1/go.mod h1:1ZiZZ7qqrZ9dZegU96phwVcdQOMKIqRzFJL3ew
|
||||
github.com/letsencrypt/pkcs11key/v4 v4.0.0/go.mod h1:EFUvBDay26dErnNb70Nd0/VW3tJiIbETBPTl9ATXQag=
|
||||
github.com/lib/pq v1.8.0/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||
github.com/lib/pq v1.9.0/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||
github.com/lib/pq v1.10.4/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||
github.com/lib/pq v1.10.6/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||
github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ=
|
||||
github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU=
|
||||
github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE=
|
||||
@@ -288,19 +282,18 @@ github.com/onsi/ginkgo v1.10.3/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+
|
||||
github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk=
|
||||
github.com/onsi/ginkgo v1.16.4 h1:29JGrr5oVBm5ulCWet69zQkzWipVXIol6ygQUe/EzNc=
|
||||
github.com/onsi/ginkgo v1.16.4/go.mod h1:dX+/inL/fNMqNlz0e9LfyB9TswhZpCVdJM/Z6Vvnwo0=
|
||||
github.com/onsi/ginkgo/v2 v2.0.0/go.mod h1:vw5CSIxN1JObi/U8gcbwft7ZxR2dgaR70JSE3/PpL4c=
|
||||
github.com/onsi/ginkgo/v2 v2.1.3 h1:e/3Cwtogj0HA+25nMP1jCMDIf8RtRYbGwGGuBIFztkc=
|
||||
github.com/onsi/ginkgo/v2 v2.1.3/go.mod h1:vw5CSIxN1JObi/U8gcbwft7ZxR2dgaR70JSE3/PpL4c=
|
||||
github.com/onsi/ginkgo/v2 v2.1.4 h1:GNapqRSid3zijZ9H77KrgVG4/8KqiyRsxcSxe+7ApXY=
|
||||
github.com/onsi/ginkgo/v2 v2.1.4/go.mod h1:um6tUpWM/cxCK3/FK8BXqEiUMUwRgSM4JXG47RKZmLU=
|
||||
github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY=
|
||||
github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo=
|
||||
github.com/onsi/gomega v1.17.0/go.mod h1:HnhC7FXeEQY45zxNK3PPoIUhzk/80Xly9PcubAlGdZY=
|
||||
github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE=
|
||||
github.com/onsi/gomega v1.18.1/go.mod h1:0q+aL8jAiMXy9hbwj2mr5GziHiwhAIQpFmmtT5hitRs=
|
||||
github.com/onsi/gomega v1.19.0 h1:4ieX6qQjPP/BfC3mpsAtIGGlxTWPeA3Inl/7DtXw1tw=
|
||||
github.com/onsi/gomega v1.19.0/go.mod h1:LY+I3pBVzYsTBU1AnDwOSxaYi9WoWiqgwooUqq9yPro=
|
||||
github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o=
|
||||
github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic=
|
||||
github.com/peterbourgon/diskv v2.0.1+incompatible/go.mod h1:uqqh8zWWbv1HBMNONnaR/tNboyR3/BZd58JJSHlUSCU=
|
||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e h1:aoZm08cpOy4WuID//EZDgcC4zIxODThtZNPirFr42+A=
|
||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
||||
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
@@ -324,14 +317,12 @@ github.com/pseudomuto/protokit v0.2.0/go.mod h1:2PdH30hxVHsup8KpBTOXTBeMVhJZVio3
|
||||
github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg=
|
||||
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
|
||||
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
|
||||
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
|
||||
github.com/rogpeppe/go-internal v1.8.1 h1:geMPLpDpQOgVyCg5z5GoRwLHepNdb71NXb67XFkP+Eg=
|
||||
github.com/rogpeppe/go-internal v1.8.1/go.mod h1:JeRgkft04UBgHMgCIwADu4Pn6Mtm5d4nPKWu0nJ5d+o=
|
||||
github.com/rs/cors v1.7.0/go.mod h1:gFx+x8UowdsKA9AchylcLynDq+nNFfI8FkUZdN/jGCU=
|
||||
github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g=
|
||||
github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||
github.com/securego/gosec/v2 v2.11.0 h1:+PDkpzR41OI2jrw1q6AdXZCbsNGNGT7pQjal0H0cArI=
|
||||
github.com/securego/gosec/v2 v2.11.0/go.mod h1:SX8bptShuG8reGC0XS09+a4H2BoWSJi+fscA+Pulbpo=
|
||||
github.com/securego/gosec/v2 v2.12.0 h1:CQWdW7ATFpvLSohMVsajscfyHJ5rsGmEXmsNcsDNmAg=
|
||||
github.com/securego/gosec/v2 v2.12.0/go.mod h1:iTpT+eKTw59bSgklBHlSnH5O2tNygHMDxfvMubA4i7I=
|
||||
github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM=
|
||||
github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc=
|
||||
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
|
||||
@@ -356,8 +347,8 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV
|
||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
|
||||
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.2 h1:4jaiDzPyXQvSd7D0EjG45355tLlV3VOECpq10pLC+8s=
|
||||
github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals=
|
||||
github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U=
|
||||
github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U=
|
||||
github.com/tmc/grpc-websocket-proxy v0.0.0-20200427203606-3cfed13b9966/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U=
|
||||
@@ -407,7 +398,7 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U
|
||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.0.0-20220313003712-b769efc7c000/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
|
||||
@@ -418,11 +409,9 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0
|
||||
golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4=
|
||||
golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM=
|
||||
golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU=
|
||||
golang.org/x/exp v0.0.0-20200331195152-e8c3332aa8e5 h1:FR+oGxGfbQu1d+jglI3rCkjAjUnhRSZcUxr+DqlDLNo=
|
||||
golang.org/x/exp v0.0.0-20200331195152-e8c3332aa8e5/go.mod h1:4M0jN8W1tt0AVLNr8HDosyJCDCDuyL9N9+3m7wDWgKw=
|
||||
golang.org/x/exp/typeparams v0.0.0-20220218215828-6cf2b201936e/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk=
|
||||
golang.org/x/exp/typeparams v0.0.0-20220426173459-3bcf042a4bf5 h1:pKfHvPtBtqS0+V/V9Y0cZQa2h8HJV/qSRJiGgYu+LQA=
|
||||
golang.org/x/exp/typeparams v0.0.0-20220426173459-3bcf042a4bf5/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk=
|
||||
golang.org/x/exp/typeparams v0.0.0-20220722155223-a9213eeb770e h1:7Xs2YCOpMlNqSQSmrrnhlzBXIE/bpMecZplbLePTJvE=
|
||||
golang.org/x/exp/typeparams v0.0.0-20220722155223-a9213eeb770e/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk=
|
||||
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
|
||||
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
|
||||
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
|
||||
@@ -435,7 +424,6 @@ golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHl
|
||||
golang.org/x/lint v0.0.0-20191125180803-fdd1cda4f05f/go.mod h1:5qLYkcX4OjUUV8bRuDixDT3tpyyb+LUpUlRWLxfhWrs=
|
||||
golang.org/x/lint v0.0.0-20200130185559-910be7a94367/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
|
||||
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
|
||||
golang.org/x/lint v0.0.0-20210508222113-6edffad5e616 h1:VLliZ0d+/avPrXXH+OakdXhpJuEoBZuwh1m2j7U6Iug=
|
||||
golang.org/x/lint v0.0.0-20210508222113-6edffad5e616/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
|
||||
golang.org/x/mobile v0.0.0-20190312151609-d3739f865fa6/go.mod h1:z+o9i4GpDbdi3rU15maQ/Ox0txvL9dWGYEHz965HBQE=
|
||||
golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o=
|
||||
@@ -446,9 +434,9 @@ golang.org/x/mod v0.1.1-0.20191107180719-034126e5016b/go.mod h1:QqPTAvyqsEbceGzB
|
||||
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.5.1/go.mod h1:5OXOZSfqPIIbmVBIIKWRFfZjPR0E5r58TLhUjH0a2Ro=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3 h1:kQgndtyPBW/JIYERgdxfwMYh3AVStj88WQTlNDi2a+o=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 h1:6zppjxzCulZykYSLyVDYbneBfbaBIQPYMevg0bEwv2s=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
@@ -488,8 +476,9 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v
|
||||
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
|
||||
golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk=
|
||||
golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2 h1:CIJ76btIcR3eFI5EgSo6k1qKw9KJexJuRLI9G7Hp5wE=
|
||||
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
|
||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b h1:PxfKdU9lEEDYjdIzOtC4qFWgkU2rGHdKlKowJSMN9h0=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
@@ -505,8 +494,9 @@ golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJ
|
||||
golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ=
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 h1:uVc8UZUe6tr40fFVnUP5Oj+veunVezqYl9z7DYw9xzw=
|
||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
@@ -558,12 +548,12 @@ golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220315194320-039c03cc5b86/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220319134239-a9b59b0215f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220422013727-9388b58f7150 h1:xHms4gcpe1YE7A3yIllJXP16CMAGuqwO2lX1mTyyRRc=
|
||||
golang.org/x/sys v0.0.0-20220422013727-9388b58f7150/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220804214406-8e32c043e418 h1:9vYwv7OjYaky/tlAeD7C4oC9EsPTlaFl1H2jS++V+ME=
|
||||
golang.org/x/sys v0.0.0-20220804214406-8e32c043e418/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
@@ -620,7 +610,6 @@ golang.org/x/tools v0.0.0-20200426102838-f3a5411a4c3b/go.mod h1:EkVYQZoAsY45+roY
|
||||
golang.org/x/tools v0.0.0-20200501065659-ab2804fb9c9d/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
|
||||
golang.org/x/tools v0.0.0-20200512131952-2bc93b1c0c88/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
|
||||
golang.org/x/tools v0.0.0-20200515010526-7d3b6ebf133d/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
|
||||
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
|
||||
golang.org/x/tools v0.0.0-20200626171337-aa94e735be7f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
|
||||
golang.org/x/tools v0.0.0-20200630154851-b2d8b0336632/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
|
||||
golang.org/x/tools v0.0.0-20200706234117-b22de6825cf7/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA=
|
||||
@@ -628,16 +617,14 @@ golang.org/x/tools v0.0.0-20200710042808-f1c4188a97a1/go.mod h1:njjCfa9FT2d7l9Bc
|
||||
golang.org/x/tools v0.0.0-20201007032633-0806396f153e/go.mod h1:z6u4i615ZeAfBE4XtMziQW1fSVJXACjjbWkB/mvPzlU=
|
||||
golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
|
||||
golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
|
||||
golang.org/x/tools v0.1.9/go.mod h1:nABZi5QlRsZVlzPpHl034qft6wpY4eDcsTt5AaioBiU=
|
||||
golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E=
|
||||
golang.org/x/tools v0.1.11-0.20220316014157-77aa08bb151a h1:ofrrl6c6NG5/IOSx/R1cyiQxxjqlur0h/TvbUhkH0II=
|
||||
golang.org/x/tools v0.1.11-0.20220316014157-77aa08bb151a/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E=
|
||||
golang.org/x/tools v0.1.11/go.mod h1:SgwaegtQh8clINPpECJMqnxLv9I09HLqnW3RMqW0CA4=
|
||||
golang.org/x/tools v0.1.12 h1:VveCTK38A2rkS8ZqFY25HIDFscX5X9OoEhJd3quQmXU=
|
||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20220411194840-2f41105eb62f h1:GGU+dLjvlC3qDwqYgL6UgRmHXhOOgns0bZu2Ty5mm6U=
|
||||
golang.org/x/xerrors v0.0.0-20220411194840-2f41105eb62f/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE=
|
||||
google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M=
|
||||
google.golang.org/api v0.8.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg=
|
||||
@@ -743,8 +730,9 @@ gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||
honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||
honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||
@@ -752,12 +740,12 @@ honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWh
|
||||
honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg=
|
||||
honnef.co/go/tools v0.0.1-2020.1.3/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k=
|
||||
honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k=
|
||||
honnef.co/go/tools v0.3.1 h1:1kJlrWJLkaGXgcaeosRXViwviqjI7nkBvU2+sZW0AYc=
|
||||
honnef.co/go/tools v0.3.1/go.mod h1:vlRD9XErLMGT+mDuofSr0mMMquscM/1nQqtRSsh6m70=
|
||||
honnef.co/go/tools v0.3.3 h1:oDx7VAwstgpYpb3wv0oxiZlxY+foCpRAwY7Vk6XpAgA=
|
||||
honnef.co/go/tools v0.3.3/go.mod h1:jzwdWgg7Jdq75wlfblQxO4neNaFFSvgc1tD5Wv8U0Yw=
|
||||
mvdan.cc/gofumpt v0.3.1 h1:avhhrOmv0IuvQVK7fvwV91oFSGAk5/6Po8GXTzICeu8=
|
||||
mvdan.cc/gofumpt v0.3.1/go.mod h1:w3ymliuxvzVx8DAutBnVyDqYb1Niy/yCJt/lk821YCE=
|
||||
mvdan.cc/unparam v0.0.0-20220316160445-06cc5682983b h1:C8Pi6noat8BcrL9WnSRYeQ63fpkJk3hKVHtF5731kIw=
|
||||
mvdan.cc/unparam v0.0.0-20220316160445-06cc5682983b/go.mod h1:WqFWCt8MGPoFSYGsQSiIORRlYVhkJsIk+n2MY6rhNbA=
|
||||
mvdan.cc/unparam v0.0.0-20220706161116-678bad134442 h1:seuXWbRB1qPrS3NQnHmFKLJLtskWyueeIzmLXghMGgk=
|
||||
mvdan.cc/unparam v0.0.0-20220706161116-678bad134442/go.mod h1:F/Cxw/6mVrNKqrR2YjFf5CaW0Bw4RL8RfbEf4GRggJk=
|
||||
rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8=
|
||||
rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0=
|
||||
rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA=
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
_ "github.com/kisielk/errcheck"
|
||||
_ "github.com/kyoh86/looppointer"
|
||||
_ "github.com/securego/gosec/v2/cmd/gosec"
|
||||
_ "golang.org/x/lint/golint"
|
||||
_ "golang.org/x/tools/go/analysis/passes/nilness/cmd/nilness"
|
||||
_ "golang.org/x/tools/go/analysis/passes/shadow/cmd/shadow"
|
||||
_ "honnef.co/go/tools/cmd/staticcheck"
|
||||
|
||||
@@ -5,9 +5,9 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
)
|
||||
@@ -17,11 +17,12 @@ const versionCheckPeriod = 8 * time.Hour
|
||||
|
||||
// VersionInfo contains information about a new version.
|
||||
type VersionInfo struct {
|
||||
CanAutoUpdate *bool `json:"can_autoupdate,omitempty"`
|
||||
NewVersion string `json:"new_version,omitempty"`
|
||||
Announcement string `json:"announcement,omitempty"`
|
||||
AnnouncementURL string `json:"announcement_url,omitempty"`
|
||||
SelfUpdateMinVersion string `json:"-"`
|
||||
NewVersion string `json:"new_version,omitempty"`
|
||||
Announcement string `json:"announcement,omitempty"`
|
||||
AnnouncementURL string `json:"announcement_url,omitempty"`
|
||||
// TODO(a.garipov): See if the frontend actually still cares about
|
||||
// nullability.
|
||||
CanAutoUpdate aghalg.NullBool `json:"can_autoupdate,omitempty"`
|
||||
}
|
||||
|
||||
// MaxResponseSize is responses on server's requests maximum length in bytes.
|
||||
@@ -67,15 +68,13 @@ func (u *Updater) VersionInfo(forceRecheck bool) (vi VersionInfo, err error) {
|
||||
}
|
||||
|
||||
func (u *Updater) parseVersionResponse(data []byte) (VersionInfo, error) {
|
||||
var canAutoUpdate bool
|
||||
info := VersionInfo{
|
||||
CanAutoUpdate: &canAutoUpdate,
|
||||
CanAutoUpdate: aghalg.NBFalse,
|
||||
}
|
||||
versionJSON := map[string]string{
|
||||
"version": "",
|
||||
"announcement": "",
|
||||
"announcement_url": "",
|
||||
"selfupdate_min_version": "",
|
||||
"version": "",
|
||||
"announcement": "",
|
||||
"announcement_url": "",
|
||||
}
|
||||
err := json.Unmarshal(data, &versionJSON)
|
||||
if err != nil {
|
||||
@@ -91,14 +90,9 @@ func (u *Updater) parseVersionResponse(data []byte) (VersionInfo, error) {
|
||||
info.NewVersion = versionJSON["version"]
|
||||
info.Announcement = versionJSON["announcement"]
|
||||
info.AnnouncementURL = versionJSON["announcement_url"]
|
||||
info.SelfUpdateMinVersion = versionJSON["selfupdate_min_version"]
|
||||
|
||||
packageURL, ok := u.downloadURL(versionJSON)
|
||||
if ok &&
|
||||
info.NewVersion != u.version &&
|
||||
strings.TrimPrefix(u.version, "v") >= strings.TrimPrefix(info.SelfUpdateMinVersion, "v") {
|
||||
canAutoUpdate = true
|
||||
}
|
||||
info.CanAutoUpdate = aghalg.BoolToNullBool(ok && info.NewVersion != u.version)
|
||||
|
||||
u.newVersion = info.NewVersion
|
||||
u.packageURL = packageURL
|
||||
|
||||
@@ -82,8 +82,9 @@ type Config struct {
|
||||
func NewUpdater(conf *Config) *Updater {
|
||||
u := &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "static.adguard.com",
|
||||
Path: path.Join("adguardhome", conf.Channel, "version.json"),
|
||||
// TODO(a.garipov): Make configurable.
|
||||
Host: "static.adtidy.org",
|
||||
Path: path.Join("adguardhome", conf.Channel, "version.json"),
|
||||
}
|
||||
return &Updater{
|
||||
client: conf.Client,
|
||||
@@ -104,11 +105,19 @@ func NewUpdater(conf *Config) *Updater {
|
||||
}
|
||||
|
||||
// Update performs the auto-update.
|
||||
func (u *Updater) Update() error {
|
||||
func (u *Updater) Update() (err error) {
|
||||
u.mu.Lock()
|
||||
defer u.mu.Unlock()
|
||||
|
||||
err := u.prepare()
|
||||
log.Info("updater: updating")
|
||||
defer func() { log.Info("updater: finished; errors: %v", err) }()
|
||||
|
||||
execPath, err := os.Executable()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = u.prepare(filepath.Base(execPath))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -159,7 +168,8 @@ func (u *Updater) VersionCheckURL() (vcu string) {
|
||||
return u.versionCheckURL
|
||||
}
|
||||
|
||||
func (u *Updater) prepare() (err error) {
|
||||
// prepare fills all necessary fields in Updater object.
|
||||
func (u *Updater) prepare(exeName string) (err error) {
|
||||
u.updateDir = filepath.Join(u.workDir, fmt.Sprintf("agh-update-%s", u.newVersion))
|
||||
|
||||
_, pkgNameOnly := filepath.Split(u.packageURL)
|
||||
@@ -170,17 +180,21 @@ func (u *Updater) prepare() (err error) {
|
||||
u.packageName = filepath.Join(u.updateDir, pkgNameOnly)
|
||||
u.backupDir = filepath.Join(u.workDir, "agh-backup")
|
||||
|
||||
exeName := "AdGuardHome"
|
||||
updateExeName := "AdGuardHome"
|
||||
if u.goos == "windows" {
|
||||
exeName = "AdGuardHome.exe"
|
||||
updateExeName = "AdGuardHome.exe"
|
||||
}
|
||||
|
||||
u.backupExeName = filepath.Join(u.backupDir, exeName)
|
||||
u.updateExeName = filepath.Join(u.updateDir, exeName)
|
||||
u.updateExeName = filepath.Join(u.updateDir, updateExeName)
|
||||
|
||||
log.Info("Updating from %s to %s. URL:%s", version.Version(), u.newVersion, u.packageURL)
|
||||
log.Debug(
|
||||
"updater: updating from %s to %s using url: %s",
|
||||
version.Version(),
|
||||
u.newVersion,
|
||||
u.packageURL,
|
||||
)
|
||||
|
||||
// TODO(a.garipov): Use os.Args[0] instead?
|
||||
u.currentExeName = filepath.Join(u.workDir, exeName)
|
||||
_, err = os.Stat(u.currentExeName)
|
||||
if err != nil {
|
||||
@@ -194,7 +208,7 @@ func (u *Updater) unpack() error {
|
||||
var err error
|
||||
_, pkgNameOnly := filepath.Split(u.packageURL)
|
||||
|
||||
log.Debug("updater: unpacking the package")
|
||||
log.Debug("updater: unpacking package")
|
||||
if strings.HasSuffix(pkgNameOnly, ".zip") {
|
||||
u.unpackedFiles, err = zipFileUnpack(u.packageName, u.updateDir)
|
||||
if err != nil {
|
||||
@@ -229,7 +243,7 @@ func (u *Updater) check() error {
|
||||
}
|
||||
|
||||
func (u *Updater) backup() error {
|
||||
log.Debug("updater: backing up the current configuration")
|
||||
log.Debug("updater: backing up current configuration")
|
||||
_ = os.Mkdir(u.backupDir, 0o755)
|
||||
err := copyFile(u.confName, filepath.Join(u.backupDir, "AdGuardHome.yaml"))
|
||||
if err != nil {
|
||||
@@ -252,7 +266,7 @@ func (u *Updater) replace() error {
|
||||
return fmt.Errorf("copySupportingFiles(%s, %s) failed: %s", u.updateDir, u.workDir, err)
|
||||
}
|
||||
|
||||
log.Debug("updater: renaming: %s -> %s", u.currentExeName, u.backupExeName)
|
||||
log.Debug("updater: renaming: %s to %s", u.currentExeName, u.backupExeName)
|
||||
err = os.Rename(u.currentExeName, u.backupExeName)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -268,7 +282,7 @@ func (u *Updater) replace() error {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug("updater: renamed: %s -> %s", u.updateExeName, u.currentExeName)
|
||||
log.Debug("updater: renamed: %s to %s", u.updateExeName, u.currentExeName)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -297,7 +311,7 @@ func (u *Updater) downloadPackageFile(url, filename string) (err error) {
|
||||
return fmt.Errorf("http request failed: %w", err)
|
||||
}
|
||||
|
||||
log.Debug("updater: reading HTTP body")
|
||||
log.Debug("updater: reading http body")
|
||||
// This use of ReadAll is now safe, because we limited body's Reader.
|
||||
body, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
@@ -343,7 +357,7 @@ func tarGzFileUnpackOne(outDir string, tr *tar.Reader, hdr *tar.Header) (name st
|
||||
}
|
||||
|
||||
if hdr.Typeflag != tar.TypeReg {
|
||||
log.Debug("updater: %s: unknown file type %d, skipping", name, hdr.Typeflag)
|
||||
log.Info("updater: %s: unknown file type %d, skipping", name, hdr.Typeflag)
|
||||
|
||||
return "", nil
|
||||
}
|
||||
@@ -364,7 +378,7 @@ func tarGzFileUnpackOne(outDir string, tr *tar.Reader, hdr *tar.Header) (name st
|
||||
return "", fmt.Errorf("io.Copy(): %w", err)
|
||||
}
|
||||
|
||||
log.Tracef("updater: created file %s", outputName)
|
||||
log.Debug("updater: created file %q", outputName)
|
||||
|
||||
return name, nil
|
||||
}
|
||||
@@ -440,7 +454,7 @@ func zipFileUnpackOne(outDir string, zf *zip.File) (name string, err error) {
|
||||
return "", fmt.Errorf("os.Mkdir(%q): %w", outputName, err)
|
||||
}
|
||||
|
||||
log.Tracef("created directory %q", outputName)
|
||||
log.Debug("updater: created directory %q", outputName)
|
||||
|
||||
return "", nil
|
||||
}
|
||||
@@ -457,7 +471,7 @@ func zipFileUnpackOne(outDir string, zf *zip.File) (name string, err error) {
|
||||
return "", fmt.Errorf("io.Copy(): %w", err)
|
||||
}
|
||||
|
||||
log.Tracef("created file %s", outputName)
|
||||
log.Debug("updater: created file %q", outputName)
|
||||
|
||||
return name, nil
|
||||
}
|
||||
@@ -516,7 +530,7 @@ func copySupportingFiles(files []string, srcdir, dstdir string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug("updater: copied: %q -> %q", src, dst)
|
||||
log.Debug("updater: copied: %q to %q", src, dst)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
@@ -44,28 +45,28 @@ func TestUpdateGetVersion(t *testing.T) {
|
||||
"announcement": "AdGuard Home v0.103.0-beta.2 is now available!",
|
||||
"announcement_url": "https://github.com/AdguardTeam/AdGuardHome/internal/releases",
|
||||
"selfupdate_min_version": "v0.0",
|
||||
"download_windows_amd64": "https://static.adguard.com/adguardhome/beta/AdGuardHome_windows_amd64.zip",
|
||||
"download_windows_386": "https://static.adguard.com/adguardhome/beta/AdGuardHome_windows_386.zip",
|
||||
"download_darwin_amd64": "https://static.adguard.com/adguardhome/beta/AdGuardHome_darwin_amd64.zip",
|
||||
"download_darwin_386": "https://static.adguard.com/adguardhome/beta/AdGuardHome_darwin_386.zip",
|
||||
"download_linux_amd64": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_amd64.tar.gz",
|
||||
"download_linux_386": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_386.tar.gz",
|
||||
"download_linux_arm": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_armv6.tar.gz",
|
||||
"download_linux_armv5": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_armv5.tar.gz",
|
||||
"download_linux_armv6": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_armv6.tar.gz",
|
||||
"download_linux_armv7": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_armv7.tar.gz",
|
||||
"download_linux_arm64": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_arm64.tar.gz",
|
||||
"download_linux_mips": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_mips_softfloat.tar.gz",
|
||||
"download_linux_mipsle": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_mipsle_softfloat.tar.gz",
|
||||
"download_linux_mips64": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_mips64_softfloat.tar.gz",
|
||||
"download_linux_mips64le": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_mips64le_softfloat.tar.gz",
|
||||
"download_freebsd_386": "https://static.adguard.com/adguardhome/beta/AdGuardHome_freebsd_386.tar.gz",
|
||||
"download_freebsd_amd64": "https://static.adguard.com/adguardhome/beta/AdGuardHome_freebsd_amd64.tar.gz",
|
||||
"download_freebsd_arm": "https://static.adguard.com/adguardhome/beta/AdGuardHome_freebsd_armv6.tar.gz",
|
||||
"download_freebsd_armv5": "https://static.adguard.com/adguardhome/beta/AdGuardHome_freebsd_armv5.tar.gz",
|
||||
"download_freebsd_armv6": "https://static.adguard.com/adguardhome/beta/AdGuardHome_freebsd_armv6.tar.gz",
|
||||
"download_freebsd_armv7": "https://static.adguard.com/adguardhome/beta/AdGuardHome_freebsd_armv7.tar.gz",
|
||||
"download_freebsd_arm64": "https://static.adguard.com/adguardhome/beta/AdGuardHome_freebsd_arm64.tar.gz"
|
||||
"download_windows_amd64": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_windows_amd64.zip",
|
||||
"download_windows_386": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_windows_386.zip",
|
||||
"download_darwin_amd64": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_darwin_amd64.zip",
|
||||
"download_darwin_386": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_darwin_386.zip",
|
||||
"download_linux_amd64": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_amd64.tar.gz",
|
||||
"download_linux_386": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_386.tar.gz",
|
||||
"download_linux_arm": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_armv6.tar.gz",
|
||||
"download_linux_armv5": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_armv5.tar.gz",
|
||||
"download_linux_armv6": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_armv6.tar.gz",
|
||||
"download_linux_armv7": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_armv7.tar.gz",
|
||||
"download_linux_arm64": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_arm64.tar.gz",
|
||||
"download_linux_mips": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_mips_softfloat.tar.gz",
|
||||
"download_linux_mipsle": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_mipsle_softfloat.tar.gz",
|
||||
"download_linux_mips64": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_mips64_softfloat.tar.gz",
|
||||
"download_linux_mips64le": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_mips64le_softfloat.tar.gz",
|
||||
"download_freebsd_386": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_freebsd_386.tar.gz",
|
||||
"download_freebsd_amd64": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_freebsd_amd64.tar.gz",
|
||||
"download_freebsd_arm": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_freebsd_armv6.tar.gz",
|
||||
"download_freebsd_armv5": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_freebsd_armv5.tar.gz",
|
||||
"download_freebsd_armv6": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_freebsd_armv6.tar.gz",
|
||||
"download_freebsd_armv7": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_freebsd_armv7.tar.gz",
|
||||
"download_freebsd_arm64": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_freebsd_arm64.tar.gz"
|
||||
}`
|
||||
|
||||
l, lport := startHTTPServer(jsonData)
|
||||
@@ -92,10 +93,7 @@ func TestUpdateGetVersion(t *testing.T) {
|
||||
assert.Equal(t, "v0.103.0-beta.2", info.NewVersion)
|
||||
assert.Equal(t, "AdGuard Home v0.103.0-beta.2 is now available!", info.Announcement)
|
||||
assert.Equal(t, "https://github.com/AdguardTeam/AdGuardHome/internal/releases", info.AnnouncementURL)
|
||||
assert.Equal(t, "v0.0", info.SelfUpdateMinVersion)
|
||||
if assert.NotNil(t, info.CanAutoUpdate) {
|
||||
assert.True(t, *info.CanAutoUpdate)
|
||||
}
|
||||
assert.Equal(t, aghalg.NBTrue, info.CanAutoUpdate)
|
||||
|
||||
// check cached
|
||||
_, err = u.VersionInfo(false)
|
||||
@@ -133,7 +131,7 @@ func TestUpdate(t *testing.T) {
|
||||
u.newVersion = "v0.103.1"
|
||||
u.packageURL = fakeURL.String()
|
||||
|
||||
require.NoError(t, u.prepare())
|
||||
require.NoError(t, u.prepare("AdGuardHome"))
|
||||
|
||||
u.currentExeName = filepath.Join(wd, "AdGuardHome")
|
||||
|
||||
@@ -211,7 +209,7 @@ func TestUpdateWindows(t *testing.T) {
|
||||
u.newVersion = "v0.103.1"
|
||||
u.packageURL = fakeURL.String()
|
||||
|
||||
require.NoError(t, u.prepare())
|
||||
require.NoError(t, u.prepare("AdGuardHome.exe"))
|
||||
|
||||
u.currentExeName = filepath.Join(wd, "AdGuardHome.exe")
|
||||
|
||||
@@ -262,7 +260,7 @@ func TestUpdater_VersionInto_ARM(t *testing.T) {
|
||||
"announcement": "AdGuard Home v0.103.0-beta.2 is now available!",
|
||||
"announcement_url": "https://github.com/AdguardTeam/AdGuardHome/internal/releases",
|
||||
"selfupdate_min_version": "v0.0",
|
||||
"download_linux_armv7": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_armv7.tar.gz"
|
||||
"download_linux_armv7": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_armv7.tar.gz"
|
||||
}`
|
||||
|
||||
l, lport := startHTTPServer(jsonData)
|
||||
@@ -290,10 +288,7 @@ func TestUpdater_VersionInto_ARM(t *testing.T) {
|
||||
assert.Equal(t, "v0.103.0-beta.2", info.NewVersion)
|
||||
assert.Equal(t, "AdGuard Home v0.103.0-beta.2 is now available!", info.Announcement)
|
||||
assert.Equal(t, "https://github.com/AdguardTeam/AdGuardHome/internal/releases", info.AnnouncementURL)
|
||||
assert.Equal(t, "v0.0", info.SelfUpdateMinVersion)
|
||||
if assert.NotNil(t, info.CanAutoUpdate) {
|
||||
assert.True(t, *info.CanAutoUpdate)
|
||||
}
|
||||
assert.Equal(t, aghalg.NBTrue, info.CanAutoUpdate)
|
||||
}
|
||||
|
||||
func TestUpdater_VersionInto_MIPS(t *testing.T) {
|
||||
@@ -302,7 +297,7 @@ func TestUpdater_VersionInto_MIPS(t *testing.T) {
|
||||
"announcement": "AdGuard Home v0.103.0-beta.2 is now available!",
|
||||
"announcement_url": "https://github.com/AdguardTeam/AdGuardHome/internal/releases",
|
||||
"selfupdate_min_version": "v0.0",
|
||||
"download_linux_mips_softfloat": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_mips_softfloat.tar.gz"
|
||||
"download_linux_mips_softfloat": "https://static.adtidy.org/adguardhome/beta/AdGuardHome_linux_mips_softfloat.tar.gz"
|
||||
}`
|
||||
|
||||
l, lport := startHTTPServer(jsonData)
|
||||
@@ -330,8 +325,5 @@ func TestUpdater_VersionInto_MIPS(t *testing.T) {
|
||||
assert.Equal(t, "v0.103.0-beta.2", info.NewVersion)
|
||||
assert.Equal(t, "AdGuard Home v0.103.0-beta.2 is now available!", info.Announcement)
|
||||
assert.Equal(t, "https://github.com/AdguardTeam/AdGuardHome/internal/releases", info.AnnouncementURL)
|
||||
assert.Equal(t, "v0.0", info.SelfUpdateMinVersion)
|
||||
if assert.NotNil(t, info.CanAutoUpdate) {
|
||||
assert.True(t, *info.CanAutoUpdate)
|
||||
}
|
||||
assert.Equal(t, aghalg.NBTrue, info.CanAutoUpdate)
|
||||
}
|
||||
|
||||
@@ -8,19 +8,19 @@ import (
|
||||
"context"
|
||||
"io/fs"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/v1/websvc"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
)
|
||||
|
||||
// Main is the entry point of application.
|
||||
func Main(clientBuildFS fs.FS) {
|
||||
// # Initial Configuration
|
||||
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
start := time.Now()
|
||||
rand.Seed(start.UnixNano())
|
||||
|
||||
// TODO(a.garipov): Set up logging.
|
||||
|
||||
@@ -31,11 +31,9 @@ func Main(clientBuildFS fs.FS) {
|
||||
|
||||
// TODO(a.garipov): Make configurable.
|
||||
web := websvc.New(&websvc.Config{
|
||||
Addresses: []*netutil.IPPort{{
|
||||
IP: net.IP{127, 0, 0, 1},
|
||||
Port: 3001,
|
||||
}},
|
||||
Timeout: 60 * time.Second,
|
||||
Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:3001")},
|
||||
Start: start,
|
||||
Timeout: 60 * time.Second,
|
||||
})
|
||||
|
||||
err := web.Start()
|
||||
|
||||
@@ -19,10 +19,10 @@ type signalHandler struct {
|
||||
|
||||
// handle processes OS signals.
|
||||
func (h *signalHandler) handle() {
|
||||
defer log.OnPanic("signalProcessor.handle")
|
||||
defer log.OnPanic("signalHandler.handle")
|
||||
|
||||
for sig := range h.signal {
|
||||
log.Info("sigproc: received signal %q", sig)
|
||||
log.Info("sighdlr: received signal %q", sig)
|
||||
|
||||
if aghos.IsShutdownSignal(sig) {
|
||||
h.shutdown()
|
||||
@@ -43,16 +43,16 @@ func (h *signalHandler) shutdown() {
|
||||
|
||||
status := statusSuccess
|
||||
|
||||
log.Info("sigproc: shutting down services")
|
||||
log.Info("sighdlr: shutting down services")
|
||||
for i, service := range h.services {
|
||||
err := service.Shutdown(ctx)
|
||||
if err != nil {
|
||||
log.Error("sigproc: shutting down service at index %d: %s", i, err)
|
||||
log.Error("sighdlr: shutting down service at index %d: %s", i, err)
|
||||
status = statusError
|
||||
}
|
||||
}
|
||||
|
||||
log.Info("sigproc: shutting down adguard home")
|
||||
log.Info("sighdlr: shutting down adguard home")
|
||||
|
||||
os.Exit(status)
|
||||
}
|
||||
|
||||
193
internal/v1/dnssvc/dnssvc.go
Normal file
193
internal/v1/dnssvc/dnssvc.go
Normal file
@@ -0,0 +1,193 @@
|
||||
// Package dnssvc contains the AdGuard Home DNS service.
|
||||
//
|
||||
// TODO(a.garipov): Define, if all methods of a *Service should work with a nil
|
||||
// receiver.
|
||||
package dnssvc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/v1/agh"
|
||||
// TODO(a.garipov): Add a “dnsproxy proxy” package to shield us from changes
|
||||
// and replacement of module dnsproxy.
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
)
|
||||
|
||||
// Config is the AdGuard Home DNS service configuration structure.
|
||||
//
|
||||
// TODO(a.garipov): Add timeout for incoming requests.
|
||||
type Config struct {
|
||||
// Addresses are the addresses on which to serve plain DNS queries.
|
||||
Addresses []netip.AddrPort
|
||||
|
||||
// Upstreams are the DNS upstreams to use. If not set, upstreams are
|
||||
// created using data from BootstrapServers, UpstreamServers, and
|
||||
// UpstreamTimeout.
|
||||
//
|
||||
// TODO(a.garipov): Think of a better scheme. Those other three parameters
|
||||
// are here only to make Config work properly.
|
||||
Upstreams []upstream.Upstream
|
||||
|
||||
// BootstrapServers are the addresses for bootstrapping the upstream DNS
|
||||
// server addresses.
|
||||
BootstrapServers []string
|
||||
|
||||
// UpstreamServers are the upstream DNS server addresses to use.
|
||||
UpstreamServers []string
|
||||
|
||||
// UpstreamTimeout is the timeout for upstream requests.
|
||||
UpstreamTimeout time.Duration
|
||||
}
|
||||
|
||||
// Service is the AdGuard Home DNS service. A nil *Service is a valid
|
||||
// [agh.Service] that does nothing.
|
||||
type Service struct {
|
||||
proxy *proxy.Proxy
|
||||
bootstraps []string
|
||||
upstreams []string
|
||||
upsTimeout time.Duration
|
||||
}
|
||||
|
||||
// New returns a new properly initialized *Service. If c is nil, svc is a nil
|
||||
// *Service that does nothing. The fields of c must not be modified after
|
||||
// calling New.
|
||||
func New(c *Config) (svc *Service, err error) {
|
||||
if c == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
svc = &Service{
|
||||
bootstraps: c.BootstrapServers,
|
||||
upstreams: c.UpstreamServers,
|
||||
upsTimeout: c.UpstreamTimeout,
|
||||
}
|
||||
|
||||
var upstreams []upstream.Upstream
|
||||
if len(c.Upstreams) > 0 {
|
||||
upstreams = c.Upstreams
|
||||
} else {
|
||||
upstreams, err = addressesToUpstreams(
|
||||
c.UpstreamServers,
|
||||
c.BootstrapServers,
|
||||
c.UpstreamTimeout,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("converting upstreams: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
svc.proxy = &proxy.Proxy{
|
||||
Config: proxy.Config{
|
||||
UDPListenAddr: udpAddrs(c.Addresses),
|
||||
TCPListenAddr: tcpAddrs(c.Addresses),
|
||||
UpstreamConfig: &proxy.UpstreamConfig{
|
||||
Upstreams: upstreams,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err = svc.proxy.Init()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("proxy: %w", err)
|
||||
}
|
||||
|
||||
return svc, nil
|
||||
}
|
||||
|
||||
// addressesToUpstreams is a wrapper around [upstream.AddressToUpstream]. It
|
||||
// accepts a slice of addresses and other upstream parameters, and returns a
|
||||
// slice of upstreams.
|
||||
func addressesToUpstreams(
|
||||
upsStrs []string,
|
||||
bootstraps []string,
|
||||
timeout time.Duration,
|
||||
) (upstreams []upstream.Upstream, err error) {
|
||||
upstreams = make([]upstream.Upstream, len(upsStrs))
|
||||
for i, upsStr := range upsStrs {
|
||||
upstreams[i], err = upstream.AddressToUpstream(upsStr, &upstream.Options{
|
||||
Bootstrap: bootstraps,
|
||||
Timeout: timeout,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("upstream at index %d: %w", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
return upstreams, nil
|
||||
}
|
||||
|
||||
// tcpAddrs converts []netip.AddrPort into []*net.TCPAddr.
|
||||
func tcpAddrs(addrPorts []netip.AddrPort) (tcpAddrs []*net.TCPAddr) {
|
||||
if addrPorts == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
tcpAddrs = make([]*net.TCPAddr, len(addrPorts))
|
||||
for i, a := range addrPorts {
|
||||
tcpAddrs[i] = net.TCPAddrFromAddrPort(a)
|
||||
}
|
||||
|
||||
return tcpAddrs
|
||||
}
|
||||
|
||||
// udpAddrs converts []netip.AddrPort into []*net.UDPAddr.
|
||||
func udpAddrs(addrPorts []netip.AddrPort) (udpAddrs []*net.UDPAddr) {
|
||||
if addrPorts == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
udpAddrs = make([]*net.UDPAddr, len(addrPorts))
|
||||
for i, a := range addrPorts {
|
||||
udpAddrs[i] = net.UDPAddrFromAddrPort(a)
|
||||
}
|
||||
|
||||
return udpAddrs
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ agh.Service = (*Service)(nil)
|
||||
|
||||
// Start implements the [agh.Service] interface for *Service. svc may be nil.
|
||||
// After Start exits, all DNS servers have tried to start, but there is no
|
||||
// guarantee that they did. Errors from the servers are written to the log.
|
||||
func (svc *Service) Start() (err error) {
|
||||
if svc == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return svc.proxy.Start()
|
||||
}
|
||||
|
||||
// Shutdown implements the [agh.Service] interface for *Service. svc may be
|
||||
// nil.
|
||||
func (svc *Service) Shutdown(ctx context.Context) (err error) {
|
||||
if svc == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return svc.proxy.Stop()
|
||||
}
|
||||
|
||||
// Config returns the current configuration of the web service.
|
||||
func (svc *Service) Config() (c *Config) {
|
||||
// TODO(a.garipov): Do we need to get the TCP addresses separately?
|
||||
udpAddrs := svc.proxy.Addrs(proxy.ProtoUDP)
|
||||
addrs := make([]netip.AddrPort, len(udpAddrs))
|
||||
for i, a := range udpAddrs {
|
||||
addrs[i] = a.(*net.UDPAddr).AddrPort()
|
||||
}
|
||||
|
||||
c = &Config{
|
||||
Addresses: addrs,
|
||||
BootstrapServers: svc.bootstraps,
|
||||
UpstreamServers: svc.upstreams,
|
||||
UpstreamTimeout: svc.upsTimeout,
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
89
internal/v1/dnssvc/dnssvc_test.go
Normal file
89
internal/v1/dnssvc/dnssvc_test.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package dnssvc_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/v1/dnssvc"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
aghtest.DiscardLogOutput(m)
|
||||
}
|
||||
|
||||
// testTimeout is the common timeout for tests.
|
||||
const testTimeout = 100 * time.Millisecond
|
||||
|
||||
func TestService(t *testing.T) {
|
||||
const (
|
||||
bootstrapAddr = "bootstrap.example"
|
||||
upstreamAddr = "upstream.example"
|
||||
)
|
||||
|
||||
ups := &aghtest.UpstreamMock{
|
||||
OnAddress: func() (addr string) {
|
||||
return upstreamAddr
|
||||
},
|
||||
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
resp = (&dns.Msg{}).SetReply(req)
|
||||
|
||||
return resp, nil
|
||||
},
|
||||
}
|
||||
|
||||
c := &dnssvc.Config{
|
||||
Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:0")},
|
||||
Upstreams: []upstream.Upstream{ups},
|
||||
BootstrapServers: []string{bootstrapAddr},
|
||||
UpstreamServers: []string{upstreamAddr},
|
||||
UpstreamTimeout: testTimeout,
|
||||
}
|
||||
|
||||
svc, err := dnssvc.New(c)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = svc.Start()
|
||||
require.NoError(t, err)
|
||||
|
||||
gotConf := svc.Config()
|
||||
require.NotNil(t, gotConf)
|
||||
require.Len(t, gotConf.Addresses, 1)
|
||||
|
||||
addr := gotConf.Addresses[0]
|
||||
|
||||
t.Run("dns", func(t *testing.T) {
|
||||
req := &dns.Msg{
|
||||
MsgHdr: dns.MsgHdr{
|
||||
Id: dns.Id(),
|
||||
RecursionDesired: true,
|
||||
},
|
||||
Question: []dns.Question{{
|
||||
Name: "example.com.",
|
||||
Qtype: dns.TypeA,
|
||||
Qclass: dns.ClassINET,
|
||||
}},
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||
defer cancel()
|
||||
|
||||
cli := &dns.Client{}
|
||||
resp, _, excErr := cli.ExchangeContext(ctx, req, addr.String())
|
||||
require.NoError(t, excErr)
|
||||
|
||||
assert.NotNil(t, resp)
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||
defer cancel()
|
||||
|
||||
err = svc.Shutdown(ctx)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
61
internal/v1/websvc/json.go
Normal file
61
internal/v1/websvc/json.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package websvc
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// JSON Utilities
|
||||
|
||||
// jsonTime is a time.Time that can be decoded from JSON and encoded into JSON
|
||||
// according to our API conventions.
|
||||
type jsonTime time.Time
|
||||
|
||||
// type check
|
||||
var _ json.Marshaler = jsonTime{}
|
||||
|
||||
// nsecPerMsec is the number of nanoseconds in a millisecond.
|
||||
const nsecPerMsec = float64(time.Millisecond / time.Nanosecond)
|
||||
|
||||
// MarshalJSON implements the json.Marshaler interface for jsonTime. err is
|
||||
// always nil.
|
||||
func (t jsonTime) MarshalJSON() (b []byte, err error) {
|
||||
msec := float64(time.Time(t).UnixNano()) / nsecPerMsec
|
||||
b = strconv.AppendFloat(nil, msec, 'f', 3, 64)
|
||||
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ json.Unmarshaler = (*jsonTime)(nil)
|
||||
|
||||
// UnmarshalJSON implements the json.Marshaler interface for *jsonTime.
|
||||
func (t *jsonTime) UnmarshalJSON(b []byte) (err error) {
|
||||
if t == nil {
|
||||
return fmt.Errorf("json time is nil")
|
||||
}
|
||||
|
||||
msec, err := strconv.ParseFloat(string(b), 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing json time: %w", err)
|
||||
}
|
||||
|
||||
*t = jsonTime(time.Unix(0, int64(msec*nsecPerMsec)).UTC())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeJSONResponse encodes v into w and logs any errors it encounters. r is
|
||||
// used to get additional information from the request.
|
||||
func writeJSONResponse(w io.Writer, r *http.Request, v any) {
|
||||
err := json.NewEncoder(w).Encode(v)
|
||||
if err != nil {
|
||||
log.Error("websvc: writing resp to %s %s: %s", r.Method, r.URL.Path, err)
|
||||
}
|
||||
}
|
||||
16
internal/v1/websvc/middleware.go
Normal file
16
internal/v1/websvc/middleware.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package websvc
|
||||
|
||||
import "net/http"
|
||||
|
||||
// Middlewares
|
||||
|
||||
// jsonMw sets the content type of the response to application/json.
|
||||
func jsonMw(h http.Handler) (wrapped http.HandlerFunc) {
|
||||
f := func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
h.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
return http.HandlerFunc(f)
|
||||
}
|
||||
8
internal/v1/websvc/path.go
Normal file
8
internal/v1/websvc/path.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package websvc
|
||||
|
||||
// Path constants
|
||||
const (
|
||||
PathHealthCheck = "/health-check"
|
||||
|
||||
PathV1SystemInfo = "/api/v1/system/info"
|
||||
)
|
||||
35
internal/v1/websvc/system.go
Normal file
35
internal/v1/websvc/system.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package websvc
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"runtime"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
||||
)
|
||||
|
||||
// System Handlers
|
||||
|
||||
// RespGetV1SystemInfo describes the response of the GET /api/v1/system/info
|
||||
// HTTP API.
|
||||
type RespGetV1SystemInfo struct {
|
||||
Arch string `json:"arch"`
|
||||
Channel string `json:"channel"`
|
||||
OS string `json:"os"`
|
||||
NewVersion string `json:"new_version,omitempty"`
|
||||
Start jsonTime `json:"start"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
// handleGetV1SystemInfo is the handler for the GET /api/v1/system/info HTTP
|
||||
// API.
|
||||
func (svc *Service) handleGetV1SystemInfo(w http.ResponseWriter, r *http.Request) {
|
||||
writeJSONResponse(w, r, &RespGetV1SystemInfo{
|
||||
Arch: runtime.GOARCH,
|
||||
Channel: version.Channel(),
|
||||
OS: runtime.GOOS,
|
||||
// TODO(a.garipov): Fill this when we have an updater.
|
||||
NewVersion: "",
|
||||
Start: jsonTime(svc.start),
|
||||
Version: version.Version(),
|
||||
})
|
||||
}
|
||||
36
internal/v1/websvc/system_test.go
Normal file
36
internal/v1/websvc/system_test.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package websvc_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/v1/websvc"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestService_handleGetV1SystemInfo(t *testing.T) {
|
||||
_, addr := newTestServer(t)
|
||||
u := &url.URL{
|
||||
Scheme: "http",
|
||||
Host: addr,
|
||||
Path: websvc.PathV1SystemInfo,
|
||||
}
|
||||
|
||||
body := httpGet(t, u, http.StatusOK)
|
||||
resp := &websvc.RespGetV1SystemInfo{}
|
||||
err := json.Unmarshal(body, resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
// TODO(a.garipov): Consider making version.Channel and version.Version
|
||||
// testable and test these better.
|
||||
assert.NotEmpty(t, resp.Channel)
|
||||
|
||||
assert.Equal(t, resp.Arch, runtime.GOARCH)
|
||||
assert.Equal(t, resp.OS, runtime.GOOS)
|
||||
assert.Equal(t, testStart, time.Time(resp.Start))
|
||||
}
|
||||
@@ -10,13 +10,14 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/v1/agh"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
httptreemux "github.com/dimfeld/httptreemux/v5"
|
||||
)
|
||||
|
||||
// Config is the AdGuard Home web service configuration structure.
|
||||
@@ -26,21 +27,25 @@ type Config struct {
|
||||
TLS *tls.Config
|
||||
|
||||
// Addresses are the addresses on which to serve the plain HTTP API.
|
||||
Addresses []*netutil.IPPort
|
||||
Addresses []netip.AddrPort
|
||||
|
||||
// SecureAddresses are the addresses on which to serve the HTTPS API. If
|
||||
// SecureAddresses is not empty, TLS must not be nil.
|
||||
SecureAddresses []*netutil.IPPort
|
||||
SecureAddresses []netip.AddrPort
|
||||
|
||||
// Start is the time of start of AdGuard Home.
|
||||
Start time.Time
|
||||
|
||||
// Timeout is the timeout for all server operations.
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
// Service is the AdGuard Home web service. A nil *Service is a valid service
|
||||
// that does nothing.
|
||||
// Service is the AdGuard Home web service. A nil *Service is a valid
|
||||
// [agh.Service] that does nothing.
|
||||
type Service struct {
|
||||
tls *tls.Config
|
||||
servers []*http.Server
|
||||
start time.Time
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
@@ -53,11 +58,11 @@ func New(c *Config) (svc *Service) {
|
||||
|
||||
svc = &Service{
|
||||
tls: c.TLS,
|
||||
start: c.Start,
|
||||
timeout: c.Timeout,
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/health-check", svc.handleGetHealthCheck)
|
||||
mux := newMux(svc)
|
||||
|
||||
for _, a := range c.Addresses {
|
||||
addr := a.String()
|
||||
@@ -91,6 +96,43 @@ func New(c *Config) (svc *Service) {
|
||||
return svc
|
||||
}
|
||||
|
||||
// newMux returns a new HTTP request multiplexor for the AdGuard Home web
|
||||
// service.
|
||||
func newMux(svc *Service) (mux *httptreemux.ContextMux) {
|
||||
mux = httptreemux.NewContextMux()
|
||||
|
||||
routes := []struct {
|
||||
handler http.HandlerFunc
|
||||
method string
|
||||
path string
|
||||
isJSON bool
|
||||
}{{
|
||||
handler: svc.handleGetHealthCheck,
|
||||
method: http.MethodGet,
|
||||
path: PathHealthCheck,
|
||||
isJSON: false,
|
||||
}, {
|
||||
handler: svc.handleGetV1SystemInfo,
|
||||
method: http.MethodGet,
|
||||
path: PathV1SystemInfo,
|
||||
isJSON: true,
|
||||
}}
|
||||
|
||||
for _, r := range routes {
|
||||
var h http.HandlerFunc
|
||||
if r.isJSON {
|
||||
// TODO(a.garipov): Consider using httptreemux's MiddlewareFunc.
|
||||
h = jsonMw(r.handler)
|
||||
} else {
|
||||
h = r.handler
|
||||
}
|
||||
|
||||
mux.Handle(r.method, r.path, h)
|
||||
}
|
||||
|
||||
return mux
|
||||
}
|
||||
|
||||
// Addrs returns all addresses on which this server serves the HTTP API. Addrs
|
||||
// must not be called until Start returns.
|
||||
func (svc *Service) Addrs() (addrs []string) {
|
||||
@@ -113,7 +155,7 @@ type unit = struct{}
|
||||
// type check
|
||||
var _ agh.Service = (*Service)(nil)
|
||||
|
||||
// Start implements the agh.Service interface for *Service. svc may be nil.
|
||||
// Start implements the [agh.Service] interface for *Service. svc may be nil.
|
||||
// After Start exits, all HTTP servers have tried to start, possibly failing and
|
||||
// writing error messages to the log.
|
||||
func (svc *Service) Start() (err error) {
|
||||
@@ -163,7 +205,8 @@ func serve(srv *http.Server, wg *sync.WaitGroup) {
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown implements the agh.Service interface for *Service. svc may be nil.
|
||||
// Shutdown implements the [agh.Service] interface for *Service. svc may be
|
||||
// nil.
|
||||
func (svc *Service) Shutdown(ctx context.Context) (err error) {
|
||||
if svc == nil {
|
||||
return nil
|
||||
|
||||
@@ -3,14 +3,13 @@ package websvc_test
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/v1/websvc"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -18,18 +17,26 @@ import (
|
||||
|
||||
const testTimeout = 1 * time.Second
|
||||
|
||||
func TestService_Start_getHealthCheck(t *testing.T) {
|
||||
// testStart is the server start value for tests.
|
||||
var testStart = time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
|
||||
// newTestServer creates and starts a new web service instance as well as its
|
||||
// sole address. It also registers a cleanup procedure, which shuts the
|
||||
// instance down.
|
||||
//
|
||||
// TODO(a.garipov): Use svc or remove it.
|
||||
func newTestServer(t testing.TB) (svc *websvc.Service, addr string) {
|
||||
t.Helper()
|
||||
|
||||
c := &websvc.Config{
|
||||
TLS: nil,
|
||||
Addresses: []*netutil.IPPort{{
|
||||
IP: net.IP{127, 0, 0, 1},
|
||||
Port: 0,
|
||||
}},
|
||||
TLS: nil,
|
||||
Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:0")},
|
||||
SecureAddresses: nil,
|
||||
Timeout: testTimeout,
|
||||
Start: testStart,
|
||||
}
|
||||
|
||||
svc := websvc.New(c)
|
||||
svc = websvc.New(c)
|
||||
|
||||
err := svc.Start()
|
||||
require.NoError(t, err)
|
||||
@@ -44,26 +51,43 @@ func TestService_Start_getHealthCheck(t *testing.T) {
|
||||
addrs := svc.Addrs()
|
||||
require.Len(t, addrs, 1)
|
||||
|
||||
u := &url.URL{
|
||||
Scheme: "http",
|
||||
Host: addrs[0],
|
||||
Path: "/health-check",
|
||||
}
|
||||
return svc, addrs[0]
|
||||
}
|
||||
|
||||
// httpGet is a helper that performs an HTTP GET request and returns the body of
|
||||
// the response as well as checks that the status code is correct.
|
||||
//
|
||||
// TODO(a.garipov): Add helpers for other methods.
|
||||
func httpGet(t testing.TB, u *url.URL, wantCode int) (body []byte) {
|
||||
t.Helper()
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, u.String(), nil)
|
||||
require.NoError(t, err)
|
||||
require.NoErrorf(t, err, "creating req")
|
||||
|
||||
httpCli := &http.Client{
|
||||
Timeout: testTimeout,
|
||||
}
|
||||
resp, err := httpCli.Do(req)
|
||||
require.NoError(t, err)
|
||||
require.NoErrorf(t, err, "performing req")
|
||||
require.Equal(t, wantCode, resp.StatusCode)
|
||||
|
||||
testutil.CleanupAndRequireSuccess(t, resp.Body.Close)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
body, err = io.ReadAll(resp.Body)
|
||||
require.NoErrorf(t, err, "reading body")
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
return body
|
||||
}
|
||||
|
||||
func TestService_Start_getHealthCheck(t *testing.T) {
|
||||
_, addr := newTestServer(t)
|
||||
u := &url.URL{
|
||||
Scheme: "http",
|
||||
Host: addr,
|
||||
Path: websvc.PathHealthCheck,
|
||||
}
|
||||
|
||||
body := httpGet(t, u, http.StatusOK)
|
||||
|
||||
assert.Equal(t, []byte("OK"), body)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user