Merge branch 'master' into 2476-rwmutex

This commit is contained in:
Ainar Garipov
2021-03-03 16:10:29 +03:00
485 changed files with 35236 additions and 6572 deletions

View File

@@ -5,14 +5,9 @@ import (
"fmt"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/testutil"
"github.com/stretchr/testify/assert"
)
func TestMain(m *testing.M) {
testutil.DiscardLogOutput(m)
}
func TestError_Error(t *testing.T) {
testCases := []struct {
name string

View File

@@ -8,6 +8,7 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestLimitReadCloser(t *testing.T) {
@@ -78,11 +79,11 @@ func TestLimitedReadCloser_Read(t *testing.T) {
buf := make([]byte, tc.limit+1)
lreader, err := LimitReadCloser(readCloser, tc.limit)
assert.Nil(t, err)
require.Nil(t, err)
n, err := lreader.Read(buf)
assert.Equal(t, n, tc.want)
assert.Equal(t, tc.err, err)
require.Equal(t, tc.err, err)
assert.Equal(t, tc.want, n)
})
}
}

View File

@@ -1,5 +1,5 @@
// Package testutil contains utilities for testing.
package testutil
// Package aghtest contains utilities for testing.
package aghtest
import (
"io"

45
internal/aghtest/os.go Normal file
View File

@@ -0,0 +1,45 @@
package aghtest
import (
"io/ioutil"
"os"
"runtime"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// PrepareTestDir returns the full path to temporary created directory and
// registers the appropriate cleanup for *t.
func PrepareTestDir(t *testing.T) (dir string) {
t.Helper()
wd, err := os.Getwd()
require.Nil(t, err)
dir, err = ioutil.TempDir(wd, "agh-test")
require.Nil(t, err)
require.NotEmpty(t, dir)
t.Cleanup(func() {
// TODO(e.burkov): Replace with t.TempDir methods after updating
// go version to 1.15.
start := time.Now()
for {
err := os.RemoveAll(dir)
if err == nil {
break
}
if runtime.GOOS != "windows" || time.Since(start) >= 500*time.Millisecond {
break
}
time.Sleep(5 * time.Millisecond)
}
assert.Nil(t, err)
})
return dir
}

View File

@@ -0,0 +1,63 @@
package aghtest
import (
"context"
"crypto/sha256"
"net"
"sync"
)
// TestResolver is a Resolver for tests.
type TestResolver struct {
counter int
counterLock sync.Mutex
}
// HostToIPs generates IPv4 and IPv6 from host.
//
// TODO(e.burkov): Replace with LookupIP after upgrading go to v1.15.
func (r *TestResolver) HostToIPs(host string) (ipv4, ipv6 net.IP) {
hash := sha256.Sum256([]byte(host))
return net.IP(hash[:4]), net.IP(hash[4:20])
}
// LookupIPAddr implements Resolver interface for *testResolver. It returns the
// slice of net.IPAddr with IPv4 and IPv6 instances.
func (r *TestResolver) LookupIPAddr(_ context.Context, host string) (ips []net.IPAddr, err error) {
ipv4, ipv6 := r.HostToIPs(host)
addrs := []net.IPAddr{{
IP: ipv4,
}, {
IP: ipv6,
}}
r.counterLock.Lock()
defer r.counterLock.Unlock()
r.counter++
return addrs, nil
}
// LookupHost implements Resolver interface for *testResolver. It returns the
// slice of IPv4 and IPv6 instances converted to strings.
func (r *TestResolver) LookupHost(host string) (addrs []string, err error) {
ipv4, ipv6 := r.HostToIPs(host)
r.counterLock.Lock()
defer r.counterLock.Unlock()
r.counter++
return []string{
ipv4.String(),
ipv6.String(),
}, nil
}
// Counter returns the number of requests handled.
func (r *TestResolver) Counter() int {
r.counterLock.Lock()
defer r.counterLock.Unlock()
return r.counter
}

View File

@@ -0,0 +1,175 @@
package aghtest
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"net"
"strings"
"sync"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/miekg/dns"
)
// TestUpstream is a mock of real upstream.
type TestUpstream struct {
// Addr is the address for Address method.
Addr string
// CName is a map of hostname to canonical name.
CName map[string]string
// IPv4 is a map of hostname to IPv4.
IPv4 map[string][]net.IP
// IPv6 is a map of hostname to IPv6.
IPv6 map[string][]net.IP
// Reverse is a map of address to domain name.
Reverse map[string][]string
}
// Exchange implements upstream.Upstream interface for *TestUpstream.
func (u *TestUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
resp = &dns.Msg{}
resp.SetReply(m)
if len(m.Question) == 0 {
return nil, fmt.Errorf("question should not be empty")
}
name := m.Question[0].Name
if cname, ok := u.CName[name]; ok {
resp.Answer = append(resp.Answer, &dns.CNAME{
Hdr: dns.RR_Header{
Name: name,
Rrtype: dns.TypeCNAME,
},
Target: cname,
})
}
var hasRec bool
var rrType uint16
var ips []net.IP
switch m.Question[0].Qtype {
case dns.TypeA:
rrType = dns.TypeA
if ipv4addr, ok := u.IPv4[name]; ok {
hasRec = true
ips = ipv4addr
}
case dns.TypeAAAA:
rrType = dns.TypeAAAA
if ipv6addr, ok := u.IPv6[name]; ok {
hasRec = true
ips = ipv6addr
}
case dns.TypePTR:
names, ok := u.Reverse[name]
if !ok {
break
}
for _, n := range names {
resp.Answer = append(resp.Answer, &dns.PTR{
Hdr: dns.RR_Header{
Name: name,
Rrtype: rrType,
},
Ptr: n,
})
}
}
for _, ip := range ips {
resp.Answer = append(resp.Answer, &dns.A{
Hdr: dns.RR_Header{
Name: name,
Rrtype: rrType,
},
A: ip,
})
}
if len(resp.Answer) == 0 {
if hasRec {
// Set no error RCode if there are some records for
// given Qname but we didn't apply them.
resp.SetRcode(m, dns.RcodeSuccess)
return resp, nil
}
// Set NXDomain RCode otherwise.
resp.SetRcode(m, dns.RcodeNameError)
}
return resp, nil
}
// Address implements upstream.Upstream interface for *TestUpstream.
func (u *TestUpstream) Address() string {
return u.Addr
}
// TestBlockUpstream implements upstream.Upstream interface for replacing real
// upstream in tests.
type TestBlockUpstream struct {
Hostname string
Block bool
requestsCount int
lock sync.RWMutex
}
// Exchange returns a message unique for TestBlockUpstream's Hostname-Block
// pair.
func (u *TestBlockUpstream) Exchange(r *dns.Msg) (*dns.Msg, error) {
u.lock.Lock()
defer u.lock.Unlock()
u.requestsCount++
hash := sha256.Sum256([]byte(u.Hostname))
hashToReturn := hex.EncodeToString(hash[:])
if !u.Block {
hashToReturn = hex.EncodeToString(hash[:])[:2] + strings.Repeat("ab", 28)
}
m := &dns.Msg{}
m.Answer = []dns.RR{
&dns.TXT{
Hdr: dns.RR_Header{
Name: r.Question[0].Name,
},
Txt: []string{
hashToReturn,
},
},
}
return m, nil
}
// Address always returns an empty string.
func (u *TestBlockUpstream) Address() string {
return ""
}
// RequestsCount returns the number of handled requests. It's safe for
// concurrent use.
func (u *TestBlockUpstream) RequestsCount() int {
u.lock.Lock()
defer u.lock.Unlock()
return u.requestsCount
}
// TestErrUpstream implements upstream.Upstream interface for replacing real
// upstream in tests.
type TestErrUpstream struct{}
// Exchange always returns nil Msg and non-nil error.
func (u *TestErrUpstream) Exchange(*dns.Msg) (*dns.Msg, error) {
return nil, agherr.Error("bad")
}
// Address always returns an empty string.
func (u *TestErrUpstream) Address() string {
return ""
}

View File

@@ -3,6 +3,8 @@ package dhcpd
import (
"encoding/hex"
"encoding/json"
"fmt"
"net"
"net/http"
"path/filepath"
@@ -17,7 +19,12 @@ import (
const (
defaultDiscoverTime = time.Second * 3
leaseExpireStatic = 1
// leaseExpireStatic is used to define the Expiry field for static
// leases.
//
// TODO(e.burkov): Remove it when static leases determining mechanism
// will be improved.
leaseExpireStatic = 1
)
var webHandlersRegistered = false
@@ -33,6 +40,51 @@ type Lease struct {
Expiry time.Time `json:"expires"`
}
// MarshalJSON implements the json.Marshaler interface for *Lease.
func (l *Lease) MarshalJSON() ([]byte, error) {
var expiryStr string
if expiry := l.Expiry; expiry.Unix() != leaseExpireStatic {
// The front-end is waiting for RFC 3999 format of the time
// value. It also shouldn't got an Expiry field for static
// leases.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/2692.
expiryStr = expiry.Format(time.RFC3339)
}
type lease Lease
return json.Marshal(&struct {
HWAddr string `json:"mac"`
Expiry string `json:"expires,omitempty"`
*lease
}{
HWAddr: l.HWAddr.String(),
Expiry: expiryStr,
lease: (*lease)(l),
})
}
// UnmarshalJSON implements the json.Unmarshaler interface for *Lease.
func (l *Lease) UnmarshalJSON(data []byte) (err error) {
type lease Lease
aux := struct {
HWAddr string `json:"mac"`
*lease
}{
lease: (*lease)(l),
}
if err = json.Unmarshal(data, &aux); err != nil {
return err
}
l.HWAddr, err = net.ParseMAC(aux.HWAddr)
if err != nil {
return fmt.Errorf("couldn't parse MAC address: %w", err)
}
return nil
}
// ServerConfig - DHCP server configuration
// field ordering is important -- yaml fields will mirror ordering from here
type ServerConfig struct {
@@ -82,14 +134,14 @@ type ServerInterface interface {
}
// Create - create object
func Create(config ServerConfig) *Server {
func Create(conf ServerConfig) *Server {
s := &Server{}
s.conf.Enabled = config.Enabled
s.conf.InterfaceName = config.InterfaceName
s.conf.HTTPRegister = config.HTTPRegister
s.conf.ConfigModified = config.ConfigModified
s.conf.DBFilePath = filepath.Join(config.WorkDir, dbFilename)
s.conf.Enabled = conf.Enabled
s.conf.InterfaceName = conf.InterfaceName
s.conf.HTTPRegister = conf.HTTPRegister
s.conf.ConfigModified = conf.ConfigModified
s.conf.DBFilePath = filepath.Join(conf.WorkDir, dbFilename)
if !webHandlersRegistered && s.conf.HTTPRegister != nil {
if runtime.GOOS == "windows" {
@@ -110,7 +162,7 @@ func Create(config ServerConfig) *Server {
}
var err4, err6 error
v4conf := config.Conf4
v4conf := conf.Conf4
v4conf.Enabled = s.conf.Enabled
if len(v4conf.RangeStart) == 0 {
v4conf.Enabled = false
@@ -119,7 +171,7 @@ func Create(config ServerConfig) *Server {
v4conf.notify = s.onNotify
s.srv4, err4 = v4Create(v4conf)
v6conf := config.Conf6
v6conf := conf.Conf6
v6conf.Enabled = s.conf.Enabled
if len(v6conf.RangeStart) == 0 {
v6conf.Enabled = false
@@ -137,6 +189,9 @@ func Create(config ServerConfig) *Server {
return nil
}
s.conf.Conf4 = conf.Conf4
s.conf.Conf6 = conf.Conf6
if s.conf.Enabled && !v4conf.Enabled && !v6conf.Enabled {
log.Error("Can't enable DHCP server because neither DHCPv4 nor DHCPv6 servers are configured")
return nil
@@ -210,14 +265,10 @@ const (
LeasesAll = LeasesDynamic | LeasesStatic
)
// Leases returns the list of current DHCP leases (thread-safe)
func (s *Server) Leases(flags int) []Lease {
result := s.srv4.GetLeases(flags)
v6leases := s.srv6.GetLeases(flags)
result = append(result, v6leases...)
return result
// Leases returns the list of active IPv4 and IPv6 DHCP leases. It's safe for
// concurrent use.
func (s *Server) Leases(flags int) (leases []Lease) {
return append(s.srv4.GetLeases(flags), s.srv6.GetLeases(flags)...)
}
// FindMACbyIP - find a MAC address by IP address in the currently active DHCP leases
@@ -255,17 +306,22 @@ func parseOptionString(s string) (uint8, []byte) {
if err != nil {
return 0, nil
}
case "ip":
ip := net.ParseIP(sval)
if ip == nil {
return 0, nil
}
val = ip
if ip.To4() != nil {
val = ip.To4()
}
// Most DHCP options require IPv4, so do not put the 16-byte
// version if we can. Otherwise, the clients will receive weird
// data that looks like four IPv4 addresses.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/2688.
if ip4 := ip.To4(); ip4 != nil {
val = ip4
} else {
val = ip
}
default:
return 0, nil
}

View File

@@ -3,128 +3,188 @@
package dhcpd
import (
"bytes"
"net"
"os"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/testutil"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestMain(m *testing.M) {
testutil.DiscardLogOutput(m)
aghtest.DiscardLogOutput(m)
}
func testNotify(flags uint32) {
}
// Leases database store/load
// Leases database store/load.
func TestDB(t *testing.T) {
var err error
s := Server{}
s.conf.DBFilePath = dbFilename
conf := V4ServerConf{
Enabled: true,
RangeStart: "192.168.10.100",
RangeEnd: "192.168.10.200",
GatewayIP: "192.168.10.1",
SubnetMask: "255.255.255.0",
notify: testNotify,
s := Server{
conf: ServerConfig{
DBFilePath: dbFilename,
},
}
s.srv4, err = v4Create(conf)
assert.True(t, err == nil)
s.srv4, err = v4Create(V4ServerConf{
Enabled: true,
RangeStart: net.IP{192, 168, 10, 100},
RangeEnd: net.IP{192, 168, 10, 200},
GatewayIP: net.IP{192, 168, 10, 1},
SubnetMask: net.IP{255, 255, 255, 0},
notify: testNotify,
})
require.Nil(t, err)
s.srv6, err = v6Create(V6ServerConf{})
assert.True(t, err == nil)
require.Nil(t, err)
l := Lease{}
l.IP = net.ParseIP("192.168.10.100").To4()
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
exp1 := time.Now().Add(time.Hour)
l.Expiry = exp1
s.srv4.(*v4Server).addLease(&l)
leases := []Lease{{
IP: net.IP{192, 168, 10, 100},
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
Expiry: time.Now().Add(time.Hour),
}, {
IP: net.IP{192, 168, 10, 101},
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xBB},
}}
l2 := Lease{}
l2.IP = net.ParseIP("192.168.10.101").To4()
l2.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:bb")
s.srv4.AddStaticLease(l2)
srv4, ok := s.srv4.(*v4Server)
require.True(t, ok)
srv4.addLease(&leases[0])
require.Nil(t, s.srv4.AddStaticLease(leases[1]))
_ = os.Remove("leases.db")
s.dbStore()
t.Cleanup(func() {
assert.Nil(t, os.Remove(dbFilename))
})
s.srv4.ResetLeases(nil)
s.dbLoad()
ll := s.srv4.GetLeases(LeasesAll)
require.Len(t, ll, len(leases))
assert.Equal(t, "aa:aa:aa:aa:aa:bb", ll[0].HWAddr.String())
assert.Equal(t, "192.168.10.101", ll[0].IP.String())
assert.Equal(t, int64(leaseExpireStatic), ll[0].Expiry.Unix())
assert.Equal(t, leases[1].HWAddr, ll[0].HWAddr)
assert.Equal(t, leases[1].IP, ll[0].IP)
assert.EqualValues(t, leaseExpireStatic, ll[0].Expiry.Unix())
assert.Equal(t, "aa:aa:aa:aa:aa:aa", ll[1].HWAddr.String())
assert.Equal(t, "192.168.10.100", ll[1].IP.String())
assert.Equal(t, exp1.Unix(), ll[1].Expiry.Unix())
_ = os.Remove("leases.db")
assert.Equal(t, leases[0].HWAddr, ll[1].HWAddr)
assert.Equal(t, leases[0].IP, ll[1].IP)
assert.Equal(t, leases[0].Expiry.Unix(), ll[1].Expiry.Unix())
}
func TestIsValidSubnetMask(t *testing.T) {
assert.True(t, isValidSubnetMask([]byte{255, 255, 255, 0}))
assert.True(t, isValidSubnetMask([]byte{255, 255, 254, 0}))
assert.True(t, isValidSubnetMask([]byte{255, 255, 252, 0}))
assert.True(t, !isValidSubnetMask([]byte{255, 255, 253, 0}))
assert.True(t, !isValidSubnetMask([]byte{255, 255, 255, 1}))
testCases := []struct {
mask net.IP
want bool
}{{
mask: net.IP{255, 255, 255, 0},
want: true,
}, {
mask: net.IP{255, 255, 254, 0},
want: true,
}, {
mask: net.IP{255, 255, 252, 0},
want: true,
}, {
mask: net.IP{255, 255, 253, 0},
}, {
mask: net.IP{255, 255, 255, 1},
}}
for _, tc := range testCases {
t.Run(tc.mask.String(), func(t *testing.T) {
assert.Equal(t, tc.want, isValidSubnetMask(tc.mask))
})
}
}
func TestNormalizeLeases(t *testing.T) {
dynLeases := []*Lease{}
staticLeases := []*Lease{}
dynLeases := []*Lease{{
HWAddr: net.HardwareAddr{1, 2, 3, 4},
}, {
HWAddr: net.HardwareAddr{1, 2, 3, 5},
}}
lease := &Lease{}
lease.HWAddr = []byte{1, 2, 3, 4}
dynLeases = append(dynLeases, lease)
lease = new(Lease)
lease.HWAddr = []byte{1, 2, 3, 5}
dynLeases = append(dynLeases, lease)
lease = new(Lease)
lease.HWAddr = []byte{1, 2, 3, 4}
lease.IP = []byte{0, 2, 3, 4}
staticLeases = append(staticLeases, lease)
lease = new(Lease)
lease.HWAddr = []byte{2, 2, 3, 4}
staticLeases = append(staticLeases, lease)
staticLeases := []*Lease{{
HWAddr: net.HardwareAddr{1, 2, 3, 4},
IP: net.IP{0, 2, 3, 4},
}, {
HWAddr: net.HardwareAddr{2, 2, 3, 4},
}}
leases := normalizeLeases(staticLeases, dynLeases)
require.Len(t, leases, 3)
assert.True(t, len(leases) == 3)
assert.True(t, bytes.Equal(leases[0].HWAddr, []byte{1, 2, 3, 4}))
assert.True(t, bytes.Equal(leases[0].IP, []byte{0, 2, 3, 4}))
assert.True(t, bytes.Equal(leases[1].HWAddr, []byte{2, 2, 3, 4}))
assert.True(t, bytes.Equal(leases[2].HWAddr, []byte{1, 2, 3, 5}))
assert.Equal(t, leases[0].HWAddr, dynLeases[0].HWAddr)
assert.Equal(t, leases[0].IP, staticLeases[0].IP)
assert.Equal(t, leases[1].HWAddr, staticLeases[1].HWAddr)
assert.Equal(t, leases[2].HWAddr, dynLeases[1].HWAddr)
}
func TestOptions(t *testing.T) {
code, val := parseOptionString(" 12 hex abcdef ")
assert.Equal(t, uint8(12), code)
assert.True(t, bytes.Equal([]byte{0xab, 0xcd, 0xef}, val))
testCases := []struct {
name string
optStr string
wantVal []byte
wantCode uint8
}{{
name: "success_hex",
optStr: "12 hex abcdef",
wantVal: []byte{0xab, 0xcd, 0xef},
wantCode: 12,
}, {
name: "bad_hex",
optStr: "12 hex abcdefx",
wantVal: nil,
wantCode: 0,
}, {
name: "success_ip",
optStr: "123 ip 1.2.3.4",
wantVal: net.IP{1, 2, 3, 4},
wantCode: 123,
}, {
name: "success_ipv6",
optStr: "123 ip ::1234",
wantVal: net.IP{
0, 0, 0, 0,
0, 0, 0, 0,
0, 0, 0, 0,
0, 0, 0x12, 0x34,
},
wantCode: 123,
}, {
name: "bad_code",
optStr: "256 ip 1.1.1.1",
wantVal: nil,
wantCode: 0,
}, {
name: "negative_code",
optStr: "-1 ip 1.1.1.1",
wantVal: nil,
wantCode: 0,
}, {
name: "bad_ip",
optStr: "12 ip 1.1.1.1x",
wantVal: nil,
wantCode: 0,
}, {
name: "bad_mode",
wantVal: nil,
optStr: "12 x 1.1.1.1",
wantCode: 0,
}}
code, _ = parseOptionString(" 12 hex abcdef1 ")
assert.Equal(t, uint8(0), code)
code, val = parseOptionString("123 ip 1.2.3.4")
assert.Equal(t, uint8(123), code)
assert.Equal(t, "1.2.3.4", net.IP(string(val)).String())
code, _ = parseOptionString("256 ip 1.1.1.1")
assert.Equal(t, uint8(0), code)
code, _ = parseOptionString("-1 ip 1.1.1.1")
assert.Equal(t, uint8(0), code)
code, _ = parseOptionString("12 ip 1.1.1.1x")
assert.Equal(t, uint8(0), code)
code, _ = parseOptionString("12 x 1.1.1.1")
assert.Equal(t, uint8(0), code)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
code, val := parseOptionString(tc.optStr)
require.Equal(t, tc.wantCode, code)
if tc.wantVal != nil {
assert.Equal(t, tc.wantVal, val)
}
})
}
}

View File

@@ -14,15 +14,17 @@ func isTimeout(err error) bool {
return operr.Timeout()
}
func parseIPv4(text string) (net.IP, error) {
result := net.ParseIP(text)
if result == nil {
return nil, fmt.Errorf("%s is not an IP address", text)
func tryTo4(ip net.IP) (ip4 net.IP, err error) {
if ip == nil {
return nil, fmt.Errorf("%v is not an IP address", ip)
}
if result.To4() == nil {
return nil, fmt.Errorf("%s is not an IPv4 address", text)
ip4 = ip.To4()
if ip4 == nil {
return nil, fmt.Errorf("%v is not an IPv4 address", ip)
}
return result.To4(), nil
return ip4, nil
}
// Return TRUE if subnet mask is correct (e.g. 255.255.255.0)

View File

@@ -2,17 +2,16 @@ package dhcpd
import (
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net"
"net/http"
"os"
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/sysutil"
"github.com/AdguardTeam/AdGuardHome/internal/util"
"github.com/AdguardTeam/golibs/jsonutil"
"github.com/AdguardTeam/golibs/log"
)
@@ -22,44 +21,19 @@ func httpError(r *http.Request, w http.ResponseWriter, code int, format string,
http.Error(w, text, code)
}
// []Lease -> JSON
func convertLeases(inputLeases []Lease, includeExpires bool) []map[string]string {
leases := []map[string]string{}
for _, l := range inputLeases {
lease := map[string]string{
"mac": l.HWAddr.String(),
"ip": l.IP.String(),
"hostname": l.Hostname,
}
if includeExpires {
lease["expires"] = l.Expiry.Format(time.RFC3339)
}
leases = append(leases, lease)
}
return leases
}
type v4ServerConfJSON struct {
GatewayIP string `json:"gateway_ip"`
SubnetMask string `json:"subnet_mask"`
RangeStart string `json:"range_start"`
RangeEnd string `json:"range_end"`
GatewayIP net.IP `json:"gateway_ip"`
SubnetMask net.IP `json:"subnet_mask"`
RangeStart net.IP `json:"range_start"`
RangeEnd net.IP `json:"range_end"`
LeaseDuration uint32 `json:"lease_duration"`
}
func v4ServerConfToJSON(c V4ServerConf) v4ServerConfJSON {
return v4ServerConfJSON{
GatewayIP: c.GatewayIP,
SubnetMask: c.SubnetMask,
RangeStart: c.RangeStart,
RangeEnd: c.RangeEnd,
LeaseDuration: c.LeaseDuration,
func v4JSONToServerConf(j *v4ServerConfJSON) V4ServerConf {
if j == nil {
return V4ServerConf{}
}
}
func v4JSONToServerConf(j v4ServerConfJSON) V4ServerConf {
return V4ServerConf{
GatewayIP: j.GatewayIP,
SubnetMask: j.SubnetMask,
@@ -70,43 +44,45 @@ func v4JSONToServerConf(j v4ServerConfJSON) V4ServerConf {
}
type v6ServerConfJSON struct {
RangeStart string `json:"range_start"`
RangeStart net.IP `json:"range_start"`
LeaseDuration uint32 `json:"lease_duration"`
}
func v6ServerConfToJSON(c V6ServerConf) v6ServerConfJSON {
return v6ServerConfJSON{
RangeStart: c.RangeStart,
LeaseDuration: c.LeaseDuration,
func v6JSONToServerConf(j *v6ServerConfJSON) V6ServerConf {
if j == nil {
return V6ServerConf{}
}
}
func v6JSONToServerConf(j v6ServerConfJSON) V6ServerConf {
return V6ServerConf{
RangeStart: j.RangeStart,
LeaseDuration: j.LeaseDuration,
}
}
// dhcpStatusResponse is the response for /control/dhcp/status endpoint.
type dhcpStatusResponse struct {
Enabled bool `json:"enabled"`
IfaceName string `json:"interface_name"`
V4 V4ServerConf `json:"v4"`
V6 V6ServerConf `json:"v6"`
Leases []Lease `json:"leases"`
StaticLeases []Lease `json:"static_leases"`
}
func (s *Server) handleDHCPStatus(w http.ResponseWriter, r *http.Request) {
leases := convertLeases(s.Leases(LeasesDynamic), true)
staticLeases := convertLeases(s.Leases(LeasesStatic), false)
v4conf := V4ServerConf{}
s.srv4.WriteDiskConfig4(&v4conf)
v6conf := V6ServerConf{}
s.srv6.WriteDiskConfig6(&v6conf)
status := map[string]interface{}{
"enabled": s.conf.Enabled,
"interface_name": s.conf.InterfaceName,
"v4": v4ServerConfToJSON(v4conf),
"v6": v6ServerConfToJSON(v6conf),
"leases": leases,
"static_leases": staticLeases,
status := &dhcpStatusResponse{
Enabled: s.conf.Enabled,
IfaceName: s.conf.InterfaceName,
V4: V4ServerConf{},
V6: V6ServerConf{},
}
s.srv4.WriteDiskConfig4(&status.V4)
s.srv6.WriteDiskConfig6(&status.V6)
status.Leases = s.Leases(LeasesDynamic)
status.StaticLeases = s.Leases(LeasesStatic)
w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(status)
if err != nil {
@@ -115,27 +91,72 @@ func (s *Server) handleDHCPStatus(w http.ResponseWriter, r *http.Request) {
}
}
type staticLeaseJSON struct {
HWAddr string `json:"mac"`
IP string `json:"ip"`
Hostname string `json:"hostname"`
func (s *Server) enableDHCP(ifaceName string) (code int, err error) {
var hasStaticIP bool
hasStaticIP, err = sysutil.IfaceHasStaticIP(ifaceName)
if err != nil {
if errors.Is(err, os.ErrPermission) {
// ErrPermission may happen here on Linux systems where
// AdGuard Home is installed using Snap. That doesn't
// necessarily mean that the machine doesn't have
// a static IP, so we can assume that it has and go on.
// If the machine doesn't, we'll get an error later.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/2667.
//
// TODO(a.garipov): I was thinking about moving this
// into IfaceHasStaticIP, but then we wouldn't be able
// to log it. Think about it more.
log.Info("error while checking static ip: %s; "+
"assuming machine has static ip and going on", err)
hasStaticIP = true
} else if errors.Is(err, sysutil.ErrNoStaticIPInfo) {
// Couldn't obtain a definitive answer. Assume static
// IP an go on.
log.Info("can't check for static ip; " +
"assuming machine has static ip and going on")
hasStaticIP = true
} else {
err = fmt.Errorf("checking static ip: %w", err)
return http.StatusInternalServerError, err
}
}
if !hasStaticIP {
err = sysutil.IfaceSetStaticIP(ifaceName)
if err != nil {
err = fmt.Errorf("setting static ip: %w", err)
return http.StatusInternalServerError, err
}
}
err = s.Start()
if err != nil {
return http.StatusBadRequest, fmt.Errorf("starting dhcp server: %w", err)
}
return 0, nil
}
type dhcpServerConfigJSON struct {
Enabled bool `json:"enabled"`
InterfaceName string `json:"interface_name"`
V4 v4ServerConfJSON `json:"v4"`
V6 v6ServerConfJSON `json:"v6"`
V4 *v4ServerConfJSON `json:"v4"`
V6 *v6ServerConfJSON `json:"v6"`
InterfaceName string `json:"interface_name"`
Enabled nullBool `json:"enabled"`
}
func (s *Server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
newconfig := dhcpServerConfigJSON{}
newconfig.Enabled = s.conf.Enabled
newconfig.InterfaceName = s.conf.InterfaceName
conf := dhcpServerConfigJSON{}
conf.Enabled = boolToNullBool(s.conf.Enabled)
conf.InterfaceName = s.conf.InterfaceName
js, err := jsonutil.DecodeObject(&newconfig, r.Body)
err := json.NewDecoder(r.Body).Decode(&conf)
if err != nil {
httpError(r, w, http.StatusBadRequest, "Failed to parse new DHCP config json: %s", err)
httpError(r, w, http.StatusBadRequest,
"failed to parse new dhcp config json: %s", err)
return
}
@@ -144,80 +165,91 @@ func (s *Server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
v4Enabled := false
v6Enabled := false
if js.Exists("v4") {
v4conf := v4JSONToServerConf(newconfig.V4)
v4conf.Enabled = newconfig.Enabled
if len(v4conf.RangeStart) == 0 {
v4conf.Enabled = false
if conf.V4 != nil {
v4Conf := v4JSONToServerConf(conf.V4)
v4Conf.Enabled = conf.Enabled == nbTrue
if len(v4Conf.RangeStart) == 0 {
v4Conf.Enabled = false
}
v4Enabled = v4conf.Enabled
v4conf.InterfaceName = newconfig.InterfaceName
v4Enabled = v4Conf.Enabled
v4Conf.InterfaceName = conf.InterfaceName
c4 := V4ServerConf{}
s.srv4.WriteDiskConfig4(&c4)
v4conf.notify = c4.notify
v4conf.ICMPTimeout = c4.ICMPTimeout
v4Conf.notify = c4.notify
v4Conf.ICMPTimeout = c4.ICMPTimeout
s4, err = v4Create(v4conf)
s4, err = v4Create(v4Conf)
if err != nil {
httpError(r, w, http.StatusBadRequest, "Invalid DHCPv4 configuration: %s", err)
httpError(r, w, http.StatusBadRequest,
"invalid dhcpv4 configuration: %s", err)
return
}
}
if js.Exists("v6") {
v6conf := v6JSONToServerConf(newconfig.V6)
v6conf.Enabled = newconfig.Enabled
if len(v6conf.RangeStart) == 0 {
v6conf.Enabled = false
if conf.V6 != nil {
v6Conf := v6JSONToServerConf(conf.V6)
v6Conf.Enabled = conf.Enabled == nbTrue
if len(v6Conf.RangeStart) == 0 {
v6Conf.Enabled = false
}
v6Enabled = v6conf.Enabled
v6conf.InterfaceName = newconfig.InterfaceName
v6conf.notify = s.onNotify
s6, err = v6Create(v6conf)
// Don't overwrite the RA/SLAAC settings from the config file.
//
// TODO(a.garipov): Perhaps include them into the request to
// allow changing them from the HTTP API?
v6Conf.RASLAACOnly = s.conf.Conf6.RASLAACOnly
v6Conf.RAAllowSLAAC = s.conf.Conf6.RAAllowSLAAC
v6Enabled = v6Conf.Enabled
v6Conf.InterfaceName = conf.InterfaceName
v6Conf.notify = s.onNotify
s6, err = v6Create(v6Conf)
if err != nil {
httpError(r, w, http.StatusBadRequest, "Invalid DHCPv6 configuration: %s", err)
httpError(r, w, http.StatusBadRequest,
"invalid dhcpv6 configuration: %s", err)
return
}
}
if newconfig.Enabled && !v4Enabled && !v6Enabled {
httpError(r, w, http.StatusBadRequest, "DHCPv4 or DHCPv6 configuration must be complete")
if conf.Enabled == nbTrue && !v4Enabled && !v6Enabled {
httpError(r, w, http.StatusBadRequest,
"dhcpv4 or dhcpv6 configuration must be complete")
return
}
s.Stop()
if js.Exists("enabled") {
s.conf.Enabled = newconfig.Enabled
if conf.Enabled != nbNull {
s.conf.Enabled = conf.Enabled == nbTrue
}
if js.Exists("interface_name") {
s.conf.InterfaceName = newconfig.InterfaceName
if conf.InterfaceName != "" {
s.conf.InterfaceName = conf.InterfaceName
}
if s4 != nil {
s.srv4 = s4
}
if s6 != nil {
s.srv6 = s6
}
s.conf.ConfigModified()
s.dbLoad()
if s.conf.Enabled {
staticIP, err := sysutil.IfaceHasStaticIP(newconfig.InterfaceName)
if !staticIP && err == nil {
err = sysutil.IfaceSetStaticIP(newconfig.InterfaceName)
if err != nil {
httpError(r, w, http.StatusInternalServerError, "Failed to configure static IP: %s", err)
return
}
}
err = s.Start()
var code int
code, err = s.enableDHCP(conf.InterfaceName)
if err != nil {
httpError(r, w, http.StatusBadRequest, "Failed to start DHCP server: %s", err)
httpError(r, w, code, "enabling dhcp: %s", err)
return
}
}
@@ -225,15 +257,15 @@ func (s *Server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
type netInterfaceJSON struct {
Name string `json:"name"`
GatewayIP string `json:"gateway_ip"`
GatewayIP net.IP `json:"gateway_ip"`
HardwareAddr string `json:"hardware_address"`
Addrs4 []string `json:"ipv4_addresses"`
Addrs6 []string `json:"ipv6_addresses"`
Addrs4 []net.IP `json:"ipv4_addresses"`
Addrs6 []net.IP `json:"ipv6_addresses"`
Flags string `json:"flags"`
}
func (s *Server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
response := map[string]interface{}{}
response := map[string]netInterfaceJSON{}
ifaces, err := util.GetValidNetInterfaces()
if err != nil {
@@ -277,9 +309,9 @@ func (s *Server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
continue
}
if ipnet.IP.To4() != nil {
jsonIface.Addrs4 = append(jsonIface.Addrs4, ipnet.IP.String())
jsonIface.Addrs4 = append(jsonIface.Addrs4, ipnet.IP)
} else {
jsonIface.Addrs6 = append(jsonIface.Addrs6, ipnet.IP.String())
jsonIface.Addrs6 = append(jsonIface.Addrs6, ipnet.IP)
}
}
if len(jsonIface.Addrs4)+len(jsonIface.Addrs6) != 0 {
@@ -295,6 +327,40 @@ func (s *Server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
}
}
// dhcpSearchOtherResult contains information about other DHCP server for
// specific network interface.
type dhcpSearchOtherResult struct {
Found string `json:"found,omitempty"`
Error string `json:"error,omitempty"`
}
// dhcpStaticIPStatus contains information about static IP address for DHCP
// server.
type dhcpStaticIPStatus struct {
Static string `json:"static"`
IP string `json:"ip,omitempty"`
Error string `json:"error,omitempty"`
}
// dhcpSearchV4Result contains information about DHCPv4 server for specific
// network interface.
type dhcpSearchV4Result struct {
OtherServer dhcpSearchOtherResult `json:"other_server"`
StaticIP dhcpStaticIPStatus `json:"static_ip"`
}
// dhcpSearchV6Result contains information about DHCPv6 server for specific
// network interface.
type dhcpSearchV6Result struct {
OtherServer dhcpSearchOtherResult `json:"other_server"`
}
// dhcpSearchResult is a response for /control/dhcp/find_active_dhcp endpoint.
type dhcpSearchResult struct {
V4 dhcpSearchV4Result `json:"v4"`
V6 dhcpSearchV6Result `json:"v6"`
}
// Perform the following tasks:
// . Search for another DHCP server running
// . Check if a static IP is configured for the network interface
@@ -317,50 +383,42 @@ func (s *Server) handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Reque
return
}
result := dhcpSearchResult{
V4: dhcpSearchV4Result{
OtherServer: dhcpSearchOtherResult{},
StaticIP: dhcpStaticIPStatus{},
},
V6: dhcpSearchV6Result{
OtherServer: dhcpSearchOtherResult{},
},
}
found4, err4 := CheckIfOtherDHCPServersPresentV4(interfaceName)
staticIP := map[string]interface{}{}
isStaticIP, err := sysutil.IfaceHasStaticIP(interfaceName)
staticIPStatus := "yes"
if err != nil {
staticIPStatus = "error"
staticIP["error"] = err.Error()
result.V4.StaticIP.Static = "error"
result.V4.StaticIP.Error = err.Error()
} else if !isStaticIP {
staticIPStatus = "no"
staticIP["ip"] = util.GetSubnet(interfaceName)
result.V4.StaticIP.Static = "no"
result.V4.StaticIP.IP = util.GetSubnet(interfaceName).String()
}
staticIP["static"] = staticIPStatus
v4 := map[string]interface{}{}
othSrv := map[string]interface{}{}
foundVal := "no"
if found4 {
foundVal = "yes"
result.V4.OtherServer.Found = "yes"
} else if err4 != nil {
foundVal = "error"
othSrv["error"] = err4.Error()
result.V4.OtherServer.Found = "error"
result.V4.OtherServer.Error = err4.Error()
}
othSrv["found"] = foundVal
v4["other_server"] = othSrv
v4["static_ip"] = staticIP
found6, err6 := CheckIfOtherDHCPServersPresentV6(interfaceName)
v6 := map[string]interface{}{}
othSrv = map[string]interface{}{}
foundVal = "no"
if found6 {
foundVal = "yes"
result.V6.OtherServer.Found = "yes"
} else if err6 != nil {
foundVal = "error"
othSrv["error"] = err6.Error()
result.V6.OtherServer.Found = "error"
result.V6.OtherServer.Error = err6.Error()
}
othSrv["found"] = foundVal
v6["other_server"] = othSrv
result := map[string]interface{}{}
result["v4"] = v4
result["v6"] = v6
w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(result)
@@ -371,103 +429,75 @@ func (s *Server) handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Reque
}
func (s *Server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request) {
lj := staticLeaseJSON{}
lj := Lease{}
err := json.NewDecoder(r.Body).Decode(&lj)
if err != nil {
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
return
}
ip := net.ParseIP(lj.IP)
if ip != nil && ip.To4() == nil {
mac, err := net.ParseMAC(lj.HWAddr)
if err != nil {
httpError(r, w, http.StatusBadRequest, "invalid MAC")
return
}
if lj.IP == nil {
httpError(r, w, http.StatusBadRequest, "invalid IP")
lease := Lease{
IP: ip,
HWAddr: mac,
}
return
}
err = s.srv6.AddStaticLease(lease)
ip4 := lj.IP.To4()
if ip4 == nil {
lj.IP = lj.IP.To16()
err = s.srv6.AddStaticLease(lj)
if err != nil {
httpError(r, w, http.StatusBadRequest, "%s", err)
return
}
return
}
ip, _ = parseIPv4(lj.IP)
if ip == nil {
httpError(r, w, http.StatusBadRequest, "invalid IP")
return
}
mac, err := net.ParseMAC(lj.HWAddr)
if err != nil {
httpError(r, w, http.StatusBadRequest, "invalid MAC")
return
}
lease := Lease{
IP: ip,
HWAddr: mac,
Hostname: lj.Hostname,
}
err = s.srv4.AddStaticLease(lease)
lj.IP = ip4
err = s.srv4.AddStaticLease(lj)
if err != nil {
httpError(r, w, http.StatusBadRequest, "%s", err)
return
}
}
func (s *Server) handleDHCPRemoveStaticLease(w http.ResponseWriter, r *http.Request) {
lj := staticLeaseJSON{}
lj := Lease{}
err := json.NewDecoder(r.Body).Decode(&lj)
if err != nil {
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
return
}
ip := net.ParseIP(lj.IP)
if ip != nil && ip.To4() == nil {
mac, err := net.ParseMAC(lj.HWAddr)
if err != nil {
httpError(r, w, http.StatusBadRequest, "invalid MAC")
return
}
if lj.IP == nil {
httpError(r, w, http.StatusBadRequest, "invalid IP")
lease := Lease{
IP: ip,
HWAddr: mac,
}
return
}
err = s.srv6.RemoveStaticLease(lease)
ip4 := lj.IP.To4()
if ip4 == nil {
lj.IP = lj.IP.To16()
err = s.srv6.RemoveStaticLease(lj)
if err != nil {
httpError(r, w, http.StatusBadRequest, "%s", err)
return
}
return
}
ip, _ = parseIPv4(lj.IP)
if ip == nil {
httpError(r, w, http.StatusBadRequest, "invalid IP")
return
}
mac, _ := net.ParseMAC(lj.HWAddr)
lease := Lease{
IP: ip,
HWAddr: mac,
Hostname: lj.Hostname,
}
err = s.srv4.RemoveStaticLease(lease)
lj.IP = ip4
err = s.srv4.RemoveStaticLease(lj)
if err != nil {
httpError(r, w, http.StatusBadRequest, "%s", err)
return
}
}

View File

@@ -6,6 +6,7 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestServer_notImplemented(t *testing.T) {
@@ -14,7 +15,7 @@ func TestServer_notImplemented(t *testing.T) {
w := httptest.NewRecorder()
r, err := http.NewRequest(http.MethodGet, "/unsupported", nil)
assert.Nil(t, err)
require.Nil(t, err)
h(w, r)
assert.Equal(t, http.StatusNotImplemented, w.Code)

View File

@@ -17,14 +17,14 @@ import (
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/testutil"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/hugelgupf/socketpair"
"github.com/insomniacslk/dhcp/dhcpv4"
"github.com/insomniacslk/dhcp/dhcpv4/server4"
)
func TestMain(m *testing.M) {
testutil.DiscardLogOutput(m)
aghtest.DiscardLogOutput(m)
}
type handler struct {

View File

@@ -0,0 +1,58 @@
package dhcpd
import (
"bytes"
"fmt"
)
// nullBool is a nullable boolean. Use these in JSON requests and responses
// instead of pointers to bool.
//
// TODO(a.garipov): Inspect uses of *bool, move this type into some new package
// if we need it somewhere else.
type nullBool uint8
// nullBool values
const (
nbNull nullBool = iota
nbTrue
nbFalse
)
// String implements the fmt.Stringer interface for nullBool.
func (nb nullBool) String() (s string) {
switch nb {
case nbNull:
return "null"
case nbTrue:
return "true"
case nbFalse:
return "false"
}
return fmt.Sprintf("!invalid nullBool %d", uint8(nb))
}
// boolToNullBool converts a bool into a nullBool.
func boolToNullBool(cond bool) (nb nullBool) {
if cond {
return nbTrue
}
return nbFalse
}
// UnmarshalJSON implements the json.Unmarshaler interface for *nullBool.
func (nb *nullBool) UnmarshalJSON(b []byte) (err error) {
if len(b) == 0 || bytes.Equal(b, []byte("null")) {
*nb = nbNull
} else if bytes.Equal(b, []byte("true")) {
*nb = nbTrue
} else if bytes.Equal(b, []byte("false")) {
*nb = nbFalse
} else {
return fmt.Errorf("invalid nullBool value %q", b)
}
return nil
}

View File

@@ -0,0 +1,69 @@
package dhcpd
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNullBool_UnmarshalText(t *testing.T) {
testCases := []struct {
name string
data []byte
wantErrMsg string
want nullBool
}{{
name: "empty",
data: []byte{},
wantErrMsg: "",
want: nbNull,
}, {
name: "null",
data: []byte("null"),
wantErrMsg: "",
want: nbNull,
}, {
name: "true",
data: []byte("true"),
wantErrMsg: "",
want: nbTrue,
}, {
name: "false",
data: []byte("false"),
wantErrMsg: "",
want: nbFalse,
}, {
name: "invalid",
data: []byte("flase"),
wantErrMsg: `invalid nullBool value "flase"`,
want: nbNull,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var got nullBool
err := got.UnmarshalJSON(tc.data)
if tc.wantErrMsg == "" {
assert.Nil(t, err)
} else {
require.NotNil(t, err)
assert.Equal(t, tc.wantErrMsg, err.Error())
}
assert.Equal(t, tc.want, got)
})
}
t.Run("json", func(t *testing.T) {
want := nbTrue
var got struct {
A nullBool
}
err := json.Unmarshal([]byte(`{"A":true}`), &got)
require.Nil(t, err)
assert.Equal(t, want, got.A)
})
}

View File

@@ -13,8 +13,8 @@ import (
)
type raCtx struct {
raAllowSlaac bool // send RA packets without MO flags
raSlaacOnly bool // send RA packets with MO flags
raAllowSLAAC bool // send RA packets without MO flags
raSLAACOnly bool // send RA packets with MO flags
ipAddr net.IP // source IP address (link-local-unicast)
dnsIPAddr net.IP // IP address for DNS Server option
prefixIPAddr net.IP // IP address for Prefix option
@@ -159,7 +159,7 @@ func createICMPv6RAPacket(params icmpv6RA) []byte {
func (ra *raCtx) Init() error {
ra.stop.Store(0)
ra.conn = nil
if !(ra.raAllowSlaac || ra.raSlaacOnly) {
if !(ra.raAllowSLAAC || ra.raSLAACOnly) {
return nil
}
@@ -167,8 +167,8 @@ func (ra *raCtx) Init() error {
ra.ipAddr, ra.dnsIPAddr)
params := icmpv6RA{
managedAddressConfiguration: !ra.raSlaacOnly,
otherConfiguration: !ra.raSlaacOnly,
managedAddressConfiguration: !ra.raSLAACOnly,
otherConfiguration: !ra.raSLAACOnly,
mtu: uint32(ra.iface.MTU),
prefixLen: 64,
recursiveDNSServer: ra.dnsIPAddr,

View File

@@ -1,7 +1,6 @@
package dhcpd
import (
"bytes"
"net"
"testing"
@@ -9,7 +8,7 @@ import (
)
func TestRA(t *testing.T) {
ra := icmpv6RA{
data := createICMPv6RAPacket(icmpv6RA{
managedAddressConfiguration: false,
otherConfiguration: true,
mtu: 1500,
@@ -17,8 +16,7 @@ func TestRA(t *testing.T) {
prefixLen: 64,
recursiveDNSServer: net.ParseIP("fe80::800:27ff:fe00:0"),
sourceLinkLayerAddress: []byte{0x0a, 0x00, 0x27, 0x00, 0x00, 0x00},
}
data := createICMPv6RAPacket(ra)
})
dataCorrect := []byte{
0x86, 0x00, 0x00, 0x00, 0x40, 0x40, 0x07, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x03, 0x04, 0x40, 0xc0, 0x00, 0x00, 0x0e, 0x10, 0x00, 0x00, 0x0e, 0x10, 0x00, 0x00, 0x00, 0x00,
@@ -27,5 +25,5 @@ func TestRA(t *testing.T) {
0x19, 0x03, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x10, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x08, 0x00, 0x27, 0xff, 0xfe, 0x00, 0x00, 0x00,
}
assert.True(t, bytes.Equal(data, dataCorrect))
assert.Equal(t, dataCorrect, data)
}

View File

@@ -33,22 +33,22 @@ type DHCPServer interface {
// V4ServerConf - server configuration
type V4ServerConf struct {
Enabled bool `yaml:"-"`
InterfaceName string `yaml:"-"`
Enabled bool `yaml:"-" json:"-"`
InterfaceName string `yaml:"-" json:"-"`
GatewayIP string `yaml:"gateway_ip"`
SubnetMask string `yaml:"subnet_mask"`
GatewayIP net.IP `yaml:"gateway_ip" json:"gateway_ip"`
SubnetMask net.IP `yaml:"subnet_mask" json:"subnet_mask"`
// The first & the last IP address for dynamic leases
// Bytes [0..2] of the last allowed IP address must match the first IP
RangeStart string `yaml:"range_start"`
RangeEnd string `yaml:"range_end"`
RangeStart net.IP `yaml:"range_start" json:"range_start"`
RangeEnd net.IP `yaml:"range_end" json:"range_end"`
LeaseDuration uint32 `yaml:"lease_duration"` // in seconds
LeaseDuration uint32 `yaml:"lease_duration" json:"lease_duration"` // in seconds
// IP conflict detector: time (ms) to wait for ICMP reply
// 0: disable
ICMPTimeout uint32 `yaml:"icmp_timeout_msec"`
ICMPTimeout uint32 `yaml:"icmp_timeout_msec" json:"-"`
// Custom Options.
//
@@ -58,7 +58,7 @@ type V4ServerConf struct {
//
// Option with IP data (only 1 IP is supported):
// DEC_CODE ip IP_ADDR
Options []string `yaml:"options"`
Options []string `yaml:"options" json:"-"`
ipStart net.IP // starting IP address for dynamic leases
ipEnd net.IP // ending IP address for dynamic leases
@@ -74,17 +74,17 @@ type V4ServerConf struct {
// V6ServerConf - server configuration
type V6ServerConf struct {
Enabled bool `yaml:"-"`
InterfaceName string `yaml:"-"`
Enabled bool `yaml:"-" json:"-"`
InterfaceName string `yaml:"-" json:"-"`
// The first IP address for dynamic leases
// The last allowed IP address ends with 0xff byte
RangeStart string `yaml:"range_start"`
RangeStart net.IP `yaml:"range_start" json:"range_start"`
LeaseDuration uint32 `yaml:"lease_duration"` // in seconds
LeaseDuration uint32 `yaml:"lease_duration" json:"lease_duration"` // in seconds
RaSlaacOnly bool `yaml:"ra_slaac_only"` // send ICMPv6.RA packets without MO flags
RaAllowSlaac bool `yaml:"ra_allow_slaac"` // send ICMPv6.RA packets with MO flags
RASLAACOnly bool `yaml:"ra_slaac_only" json:"-"` // send ICMPv6.RA packets without MO flags
RAAllowSLAAC bool `yaml:"ra_allow_slaac" json:"-"` // send ICMPv6.RA packets with MO flags
ipStart net.IP // starting IP address for dynamic leases
leaseTime time.Duration // the time during which a dynamic lease is considered valid

View File

@@ -23,7 +23,8 @@ type v4Server struct {
srv *server4.Server
leasesLock sync.Mutex
leases []*Lease
ipAddrs [256]byte
// TODO(e.burkov): This field type should be a normal bitmap.
ipAddrs [256]byte
conf V4ServerConf
}
@@ -77,7 +78,10 @@ func (s *v4Server) blacklisted(l *Lease) bool {
// GetLeases returns the list of current DHCP leases (thread-safe)
func (s *v4Server) GetLeases(flags int) []Lease {
var result []Lease
// The function shouldn't return nil value because zero-length slice
// behaves differently in cases like marshalling. Our front-end also
// requires non-nil value in the response.
result := []Lease{}
now := time.Now().Unix()
s.leasesLock.Lock()
@@ -589,7 +593,7 @@ func (s *v4Server) Start() error {
s.conf.dnsIPAddrs = dnsIPAddrs
laddr := &net.UDPAddr{
IP: net.ParseIP("0.0.0.0"),
IP: net.IP{0, 0, 0, 0},
Port: dhcpv4.ServerPort,
}
s.srv, err = server4.NewServer(iface.Name, laddr, s.packetHandler, server4.WithDebugLogger())
@@ -632,19 +636,18 @@ func v4Create(conf V4ServerConf) (DHCPServer, error) {
}
var err error
s.conf.routerIP, err = parseIPv4(s.conf.GatewayIP)
s.conf.routerIP, err = tryTo4(s.conf.GatewayIP)
if err != nil {
return s, fmt.Errorf("dhcpv4: %w", err)
}
subnet, err := parseIPv4(s.conf.SubnetMask)
if err != nil || !isValidSubnetMask(subnet) {
return s, fmt.Errorf("dhcpv4: invalid subnet mask: %s", s.conf.SubnetMask)
if s.conf.SubnetMask == nil {
return s, fmt.Errorf("dhcpv4: invalid subnet mask: %v", s.conf.SubnetMask)
}
s.conf.subnetMask = make([]byte, 4)
copy(s.conf.subnetMask, subnet)
copy(s.conf.subnetMask, s.conf.SubnetMask.To4())
s.conf.ipStart, err = parseIPv4(conf.RangeStart)
s.conf.ipStart, err = tryTo4(conf.RangeStart)
if s.conf.ipStart == nil {
return s, fmt.Errorf("dhcpv4: %w", err)
}
@@ -652,7 +655,7 @@ func v4Create(conf V4ServerConf) (DHCPServer, error) {
return s, fmt.Errorf("dhcpv4: invalid range start IP")
}
s.conf.ipEnd, err = parseIPv4(conf.RangeEnd)
s.conf.ipEnd, err = tryTo4(conf.RangeEnd)
if s.conf.ipEnd == nil {
return s, fmt.Errorf("dhcpv4: %w", err)
}

View File

@@ -7,6 +7,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type fakeIface struct {
@@ -79,8 +80,8 @@ func TestIfaceIPAddrs(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
got, gotErr := ifaceIPAddrs(tc.iface, tc.ipv)
require.True(t, errors.Is(gotErr, tc.wantErr))
assert.Equal(t, tc.want, got)
assert.True(t, errors.Is(gotErr, tc.wantErr))
})
}
}
@@ -140,12 +141,8 @@ func TestIfaceDNSIPAddrs(t *testing.T) {
want: nil,
wantErr: errTest,
}, {
name: "ipv4_wait",
iface: &waitingFakeIface{
addrs: []net.Addr{addr4},
err: nil,
n: 1,
},
name: "ipv4_wait",
iface: &waitingFakeIface{addrs: []net.Addr{addr4}, err: nil, n: 1},
ipv: ipVersion4,
want: []net.IP{ip4, ip4},
wantErr: nil,
@@ -168,12 +165,8 @@ func TestIfaceDNSIPAddrs(t *testing.T) {
want: nil,
wantErr: errTest,
}, {
name: "ipv6_wait",
iface: &waitingFakeIface{
addrs: []net.Addr{addr6},
err: nil,
n: 1,
},
name: "ipv6_wait",
iface: &waitingFakeIface{addrs: []net.Addr{addr6}, err: nil, n: 1},
ipv: ipVersion6,
want: []net.IP{ip6, ip6},
wantErr: nil,
@@ -182,8 +175,8 @@ func TestIfaceDNSIPAddrs(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
got, gotErr := ifaceDNSIPAddrs(tc.iface, tc.ipv, 2, 0)
require.True(t, errors.Is(gotErr, tc.wantErr))
assert.Equal(t, tc.want, got)
assert.True(t, errors.Is(gotErr, tc.wantErr))
})
}
}

View File

@@ -8,231 +8,283 @@ import (
"github.com/insomniacslk/dhcp/dhcpv4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func notify4(flags uint32) {
}
func TestV4StaticLeaseAddRemove(t *testing.T) {
conf := V4ServerConf{
func TestV4_AddRemove_static(t *testing.T) {
s, err := v4Create(V4ServerConf{
Enabled: true,
RangeStart: "192.168.10.100",
RangeEnd: "192.168.10.200",
GatewayIP: "192.168.10.1",
SubnetMask: "255.255.255.0",
RangeStart: net.IP{192, 168, 10, 100},
RangeEnd: net.IP{192, 168, 10, 200},
GatewayIP: net.IP{192, 168, 10, 1},
SubnetMask: net.IP{255, 255, 255, 0},
notify: notify4,
}
s, err := v4Create(conf)
assert.True(t, err == nil)
})
require.Nil(t, err)
ls := s.GetLeases(LeasesStatic)
assert.Equal(t, 0, len(ls))
assert.Empty(t, ls)
// add static lease
l := Lease{}
l.IP = net.ParseIP("192.168.10.150").To4()
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
assert.True(t, s.AddStaticLease(l) == nil)
// Add static lease.
l := Lease{
IP: net.IP{192, 168, 10, 150},
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
}
require.Nil(t, s.AddStaticLease(l))
assert.NotNil(t, s.AddStaticLease(l))
// try to add the same static lease - fail
assert.True(t, s.AddStaticLease(l) != nil)
// check
ls = s.GetLeases(LeasesStatic)
assert.Equal(t, 1, len(ls))
assert.Equal(t, "192.168.10.150", ls[0].IP.String())
assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
assert.True(t, ls[0].Expiry.Unix() == leaseExpireStatic)
require.Len(t, ls, 1)
assert.True(t, l.IP.Equal(ls[0].IP))
assert.Equal(t, l.HWAddr, ls[0].HWAddr)
assert.EqualValues(t, leaseExpireStatic, ls[0].Expiry.Unix())
// try to remove static lease - fail
l.IP = net.ParseIP("192.168.10.110").To4()
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
assert.True(t, s.RemoveStaticLease(l) != nil)
// Try to remove static lease.
assert.NotNil(t, s.RemoveStaticLease(Lease{
IP: net.IP{192, 168, 10, 110},
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
}))
// remove static lease
l.IP = net.ParseIP("192.168.10.150").To4()
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
assert.True(t, s.RemoveStaticLease(l) == nil)
// check
// Remove static lease.
require.Nil(t, s.RemoveStaticLease(l))
ls = s.GetLeases(LeasesStatic)
assert.Equal(t, 0, len(ls))
assert.Empty(t, ls)
}
func TestV4StaticLeaseAddReplaceDynamic(t *testing.T) {
conf := V4ServerConf{
func TestV4_AddReplace(t *testing.T) {
sIface, err := v4Create(V4ServerConf{
Enabled: true,
RangeStart: "192.168.10.100",
RangeEnd: "192.168.10.200",
GatewayIP: "192.168.10.1",
SubnetMask: "255.255.255.0",
RangeStart: net.IP{192, 168, 10, 100},
RangeEnd: net.IP{192, 168, 10, 200},
GatewayIP: net.IP{192, 168, 10, 1},
SubnetMask: net.IP{255, 255, 255, 0},
notify: notify4,
}
sIface, err := v4Create(conf)
s := sIface.(*v4Server)
assert.True(t, err == nil)
})
require.Nil(t, err)
// add dynamic lease
ld := Lease{}
ld.IP = net.ParseIP("192.168.10.150").To4()
ld.HWAddr, _ = net.ParseMAC("11:aa:aa:aa:aa:aa")
s.addLease(&ld)
s, ok := sIface.(*v4Server)
require.True(t, ok)
// add dynamic lease
{
ld := Lease{}
ld.IP = net.ParseIP("192.168.10.151").To4()
ld.HWAddr, _ = net.ParseMAC("22:aa:aa:aa:aa:aa")
s.addLease(&ld)
dynLeases := []Lease{{
IP: net.IP{192, 168, 10, 150},
HWAddr: net.HardwareAddr{0x11, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
}, {
IP: net.IP{192, 168, 10, 151},
HWAddr: net.HardwareAddr{0x22, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
}}
for i := range dynLeases {
s.addLease(&dynLeases[i])
}
// add static lease with the same IP
l := Lease{}
l.IP = net.ParseIP("192.168.10.150").To4()
l.HWAddr, _ = net.ParseMAC("33:aa:aa:aa:aa:aa")
assert.True(t, s.AddStaticLease(l) == nil)
stLeases := []Lease{{
IP: net.IP{192, 168, 10, 150},
HWAddr: net.HardwareAddr{0x33, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
}, {
IP: net.IP{192, 168, 10, 152},
HWAddr: net.HardwareAddr{0x22, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
}}
// add static lease with the same MAC
l = Lease{}
l.IP = net.ParseIP("192.168.10.152").To4()
l.HWAddr, _ = net.ParseMAC("22:aa:aa:aa:aa:aa")
assert.True(t, s.AddStaticLease(l) == nil)
for _, l := range stLeases {
require.Nil(t, s.AddStaticLease(l))
}
// check
ls := s.GetLeases(LeasesStatic)
assert.Equal(t, 2, len(ls))
require.Len(t, ls, 2)
assert.Equal(t, "192.168.10.150", ls[0].IP.String())
assert.Equal(t, "33:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
assert.True(t, ls[0].Expiry.Unix() == leaseExpireStatic)
assert.Equal(t, "192.168.10.152", ls[1].IP.String())
assert.Equal(t, "22:aa:aa:aa:aa:aa", ls[1].HWAddr.String())
assert.True(t, ls[1].Expiry.Unix() == leaseExpireStatic)
for i, l := range ls {
assert.True(t, stLeases[i].IP.Equal(l.IP))
assert.Equal(t, stLeases[i].HWAddr, l.HWAddr)
assert.EqualValues(t, leaseExpireStatic, l.Expiry.Unix())
}
}
func TestV4StaticLeaseGet(t *testing.T) {
conf := V4ServerConf{
func TestV4StaticLease_Get(t *testing.T) {
var err error
sIface, err := v4Create(V4ServerConf{
Enabled: true,
RangeStart: "192.168.10.100",
RangeEnd: "192.168.10.200",
GatewayIP: "192.168.10.1",
SubnetMask: "255.255.255.0",
RangeStart: net.IP{192, 168, 10, 100},
RangeEnd: net.IP{192, 168, 10, 200},
GatewayIP: net.IP{192, 168, 10, 1},
SubnetMask: net.IP{255, 255, 255, 0},
notify: notify4,
})
require.Nil(t, err)
s, ok := sIface.(*v4Server)
require.True(t, ok)
s.conf.dnsIPAddrs = []net.IP{{192, 168, 10, 1}}
l := Lease{
IP: net.IP{192, 168, 10, 150},
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
}
sIface, err := v4Create(conf)
s := sIface.(*v4Server)
assert.True(t, err == nil)
s.conf.dnsIPAddrs = []net.IP{net.ParseIP("192.168.10.1").To4()}
require.Nil(t, s.AddStaticLease(l))
l := Lease{}
l.IP = net.ParseIP("192.168.10.150").To4()
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
assert.True(t, s.AddStaticLease(l) == nil)
var req, resp *dhcpv4.DHCPv4
mac := net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}
// "Discover"
mac, _ := net.ParseMAC("aa:aa:aa:aa:aa:aa")
req, _ := dhcpv4.NewDiscovery(mac)
resp, _ := dhcpv4.NewReplyFromRequest(req)
assert.Equal(t, 1, s.process(req, resp))
t.Run("discover", func(t *testing.T) {
var err error
// check "Offer"
assert.Equal(t, dhcpv4.MessageTypeOffer, resp.MessageType())
assert.Equal(t, "aa:aa:aa:aa:aa:aa", resp.ClientHWAddr.String())
assert.Equal(t, "192.168.10.150", resp.YourIPAddr.String())
assert.Equal(t, "192.168.10.1", resp.Router()[0].String())
assert.Equal(t, "192.168.10.1", resp.ServerIdentifier().String())
assert.Equal(t, "255.255.255.0", net.IP(resp.SubnetMask()).String())
assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds())
req, err = dhcpv4.NewDiscovery(mac)
require.Nil(t, err)
// "Request"
req, _ = dhcpv4.NewRequestFromOffer(resp)
resp, _ = dhcpv4.NewReplyFromRequest(req)
assert.Equal(t, 1, s.process(req, resp))
resp, err = dhcpv4.NewReplyFromRequest(req)
require.Nil(t, err)
assert.Equal(t, 1, s.process(req, resp))
})
require.Nil(t, err)
// check "Ack"
assert.Equal(t, dhcpv4.MessageTypeAck, resp.MessageType())
assert.Equal(t, "aa:aa:aa:aa:aa:aa", resp.ClientHWAddr.String())
assert.Equal(t, "192.168.10.150", resp.YourIPAddr.String())
assert.Equal(t, "192.168.10.1", resp.Router()[0].String())
assert.Equal(t, "192.168.10.1", resp.ServerIdentifier().String())
assert.Equal(t, "255.255.255.0", net.IP(resp.SubnetMask()).String())
assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds())
t.Run("offer", func(t *testing.T) {
assert.Equal(t, dhcpv4.MessageTypeOffer, resp.MessageType())
assert.Equal(t, mac, resp.ClientHWAddr)
assert.True(t, l.IP.Equal(resp.YourIPAddr))
assert.True(t, s.conf.GatewayIP.Equal(resp.Router()[0]))
assert.True(t, s.conf.GatewayIP.Equal(resp.ServerIdentifier()))
assert.Equal(t, s.conf.subnetMask, resp.SubnetMask())
assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds())
})
t.Run("request", func(t *testing.T) {
req, err = dhcpv4.NewRequestFromOffer(resp)
require.Nil(t, err)
resp, err = dhcpv4.NewReplyFromRequest(req)
require.Nil(t, err)
assert.Equal(t, 1, s.process(req, resp))
})
require.Nil(t, err)
t.Run("ack", func(t *testing.T) {
assert.Equal(t, dhcpv4.MessageTypeAck, resp.MessageType())
assert.Equal(t, mac, resp.ClientHWAddr)
assert.True(t, l.IP.Equal(resp.YourIPAddr))
assert.True(t, s.conf.GatewayIP.Equal(resp.Router()[0]))
assert.True(t, s.conf.GatewayIP.Equal(resp.ServerIdentifier()))
assert.Equal(t, s.conf.subnetMask, resp.SubnetMask())
assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds())
})
dnsAddrs := resp.DNS()
assert.Equal(t, 1, len(dnsAddrs))
assert.Equal(t, "192.168.10.1", dnsAddrs[0].String())
require.Len(t, dnsAddrs, 1)
assert.True(t, s.conf.GatewayIP.Equal(dnsAddrs[0]))
// check lease
ls := s.GetLeases(LeasesStatic)
assert.Equal(t, 1, len(ls))
assert.Equal(t, "192.168.10.150", ls[0].IP.String())
assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
t.Run("check_lease", func(t *testing.T) {
ls := s.GetLeases(LeasesStatic)
require.Len(t, ls, 1)
assert.True(t, l.IP.Equal(ls[0].IP))
assert.Equal(t, mac, ls[0].HWAddr)
})
}
func TestV4DynamicLeaseGet(t *testing.T) {
conf := V4ServerConf{
func TestV4DynamicLease_Get(t *testing.T) {
var err error
sIface, err := v4Create(V4ServerConf{
Enabled: true,
RangeStart: "192.168.10.100",
RangeEnd: "192.168.10.200",
GatewayIP: "192.168.10.1",
SubnetMask: "255.255.255.0",
RangeStart: net.IP{192, 168, 10, 100},
RangeEnd: net.IP{192, 168, 10, 200},
GatewayIP: net.IP{192, 168, 10, 1},
SubnetMask: net.IP{255, 255, 255, 0},
notify: notify4,
Options: []string{
"81 hex 303132",
"82 ip 1.2.3.4",
},
}
sIface, err := v4Create(conf)
s := sIface.(*v4Server)
assert.True(t, err == nil)
s.conf.dnsIPAddrs = []net.IP{net.ParseIP("192.168.10.1").To4()}
})
require.Nil(t, err)
// "Discover"
mac, _ := net.ParseMAC("aa:aa:aa:aa:aa:aa")
req, _ := dhcpv4.NewDiscovery(mac)
resp, _ := dhcpv4.NewReplyFromRequest(req)
assert.Equal(t, 1, s.process(req, resp))
s, ok := sIface.(*v4Server)
require.True(t, ok)
s.conf.dnsIPAddrs = []net.IP{{192, 168, 10, 1}}
// check "Offer"
assert.Equal(t, dhcpv4.MessageTypeOffer, resp.MessageType())
assert.Equal(t, "aa:aa:aa:aa:aa:aa", resp.ClientHWAddr.String())
assert.Equal(t, "192.168.10.100", resp.YourIPAddr.String())
assert.Equal(t, "192.168.10.1", resp.Router()[0].String())
assert.Equal(t, "192.168.10.1", resp.ServerIdentifier().String())
assert.Equal(t, "255.255.255.0", net.IP(resp.SubnetMask()).String())
assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds())
assert.Equal(t, []byte("012"), resp.Options[uint8(dhcpv4.OptionFQDN)])
assert.Equal(t, "1.2.3.4", net.IP(resp.Options[uint8(dhcpv4.OptionRelayAgentInformation)]).String())
var req, resp *dhcpv4.DHCPv4
mac := net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}
// "Request"
req, _ = dhcpv4.NewRequestFromOffer(resp)
resp, _ = dhcpv4.NewReplyFromRequest(req)
assert.Equal(t, 1, s.process(req, resp))
t.Run("discover", func(t *testing.T) {
req, err = dhcpv4.NewDiscovery(mac)
require.Nil(t, err)
// check "Ack"
assert.Equal(t, dhcpv4.MessageTypeAck, resp.MessageType())
assert.Equal(t, "aa:aa:aa:aa:aa:aa", resp.ClientHWAddr.String())
assert.Equal(t, "192.168.10.100", resp.YourIPAddr.String())
assert.Equal(t, "192.168.10.1", resp.Router()[0].String())
assert.Equal(t, "192.168.10.1", resp.ServerIdentifier().String())
assert.Equal(t, "255.255.255.0", net.IP(resp.SubnetMask()).String())
assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds())
resp, err = dhcpv4.NewReplyFromRequest(req)
require.Nil(t, err)
assert.Equal(t, 1, s.process(req, resp))
})
require.Nil(t, err)
t.Run("offer", func(t *testing.T) {
assert.Equal(t, dhcpv4.MessageTypeOffer, resp.MessageType())
assert.Equal(t, mac, resp.ClientHWAddr)
assert.True(t, s.conf.RangeStart.Equal(resp.YourIPAddr))
assert.True(t, s.conf.GatewayIP.Equal(resp.Router()[0]))
assert.True(t, s.conf.GatewayIP.Equal(resp.ServerIdentifier()))
assert.Equal(t, s.conf.subnetMask, resp.SubnetMask())
assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds())
assert.Equal(t, []byte("012"), resp.Options[uint8(dhcpv4.OptionFQDN)])
assert.True(t, net.IP{1, 2, 3, 4}.Equal(net.IP(resp.Options[uint8(dhcpv4.OptionRelayAgentInformation)])))
})
t.Run("request", func(t *testing.T) {
var err error
req, err = dhcpv4.NewRequestFromOffer(resp)
require.Nil(t, err)
resp, err = dhcpv4.NewReplyFromRequest(req)
require.Nil(t, err)
assert.Equal(t, 1, s.process(req, resp))
})
require.Nil(t, err)
t.Run("ack", func(t *testing.T) {
assert.Equal(t, dhcpv4.MessageTypeAck, resp.MessageType())
assert.Equal(t, mac, resp.ClientHWAddr)
assert.True(t, s.conf.RangeStart.Equal(resp.YourIPAddr))
assert.True(t, s.conf.GatewayIP.Equal(resp.Router()[0]))
assert.True(t, s.conf.GatewayIP.Equal(resp.ServerIdentifier()))
assert.Equal(t, s.conf.subnetMask, resp.SubnetMask())
assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds())
})
dnsAddrs := resp.DNS()
assert.Equal(t, 1, len(dnsAddrs))
assert.Equal(t, "192.168.10.1", dnsAddrs[0].String())
require.Len(t, dnsAddrs, 1)
assert.True(t, net.IP{192, 168, 10, 1}.Equal(dnsAddrs[0]))
// check lease
ls := s.GetLeases(LeasesDynamic)
assert.Equal(t, 1, len(ls))
assert.Equal(t, "192.168.10.100", ls[0].IP.String())
assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
start := net.ParseIP("192.168.10.100").To4()
stop := net.ParseIP("192.168.10.200").To4()
assert.True(t, !ip4InRange(start, stop, net.ParseIP("192.168.10.99").To4()))
assert.True(t, !ip4InRange(start, stop, net.ParseIP("192.168.11.100").To4()))
assert.True(t, !ip4InRange(start, stop, net.ParseIP("192.168.11.201").To4()))
assert.True(t, ip4InRange(start, stop, net.ParseIP("192.168.10.100").To4()))
t.Run("check_lease", func(t *testing.T) {
ls := s.GetLeases(LeasesDynamic)
assert.Len(t, ls, 1)
assert.True(t, net.IP{192, 168, 10, 100}.Equal(ls[0].IP))
assert.Equal(t, mac, ls[0].HWAddr)
})
}
func TestIP4InRange(t *testing.T) {
start := net.IP{192, 168, 10, 100}
stop := net.IP{192, 168, 10, 200}
testCases := []struct {
ip net.IP
want bool
}{{
ip: net.IP{192, 168, 10, 99},
want: false,
}, {
ip: net.IP{192, 168, 11, 100},
want: false,
}, {
ip: net.IP{192, 168, 11, 201},
want: false,
}, {
ip: start,
want: true,
}}
for _, tc := range testCases {
t.Run(tc.ip.String(), func(t *testing.T) {
assert.Equal(t, tc.want, ip4InRange(start, stop, tc.ip))
})
}
}

View File

@@ -42,7 +42,6 @@ func (s *v6Server) WriteDiskConfig6(c *V6ServerConf) {
}
// Return TRUE if IP address is within range [start..0xff]
// nolint(staticcheck)
func ip6InRange(start, ip net.IP) bool {
if len(start) != 16 {
return false
@@ -72,7 +71,10 @@ func (s *v6Server) ResetLeases(ll []*Lease) {
// GetLeases - get current leases
func (s *v6Server) GetLeases(flags int) []Lease {
var result []Lease
// The function shouldn't return nil value because zero-length slice
// behaves differently in cases like marshalling. Our front-end also
// requires non-nil value in the response.
result := []Lease{}
s.leasesLock.Lock()
for _, lease := range s.leases {
if lease.Expiry.Unix() == leaseExpireStatic {
@@ -550,8 +552,8 @@ func (s *v6Server) initRA(iface *net.Interface) error {
}
}
s.ra.raAllowSlaac = s.conf.RaAllowSlaac
s.ra.raSlaacOnly = s.conf.RaSlaacOnly
s.ra.raAllowSLAAC = s.conf.RAAllowSLAAC
s.ra.raSLAACOnly = s.conf.RASLAACOnly
s.ra.dnsIPAddr = s.ra.ipAddr
s.ra.prefixIPAddr = s.conf.ipStart
s.ra.ifaceName = s.conf.InterfaceName
@@ -592,7 +594,7 @@ func (s *v6Server) Start() error {
}
// don't initialize DHCPv6 server if we must force the clients to use SLAAC
if s.conf.RaSlaacOnly {
if s.conf.RASLAACOnly {
log.Debug("DHCPv6: not starting DHCPv6 server due to ra_slaac_only=true")
return nil
}
@@ -657,7 +659,7 @@ func v6Create(conf V6ServerConf) (DHCPServer, error) {
return s, nil
}
s.conf.ipStart = net.ParseIP(conf.RangeStart)
s.conf.ipStart = conf.RangeStart
if s.conf.ipStart == nil || s.conf.ipStart.To16() == nil {
return s, fmt.Errorf("dhcpv6: invalid range-start IP: %s", conf.RangeStart)
}

View File

@@ -9,217 +9,283 @@ import (
"github.com/insomniacslk/dhcp/dhcpv6"
"github.com/insomniacslk/dhcp/iana"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func notify6(flags uint32) {
}
func TestV6StaticLeaseAddRemove(t *testing.T) {
conf := V6ServerConf{
func TestV6_AddRemove_static(t *testing.T) {
s, err := v6Create(V6ServerConf{
Enabled: true,
RangeStart: "2001::1",
RangeStart: net.ParseIP("2001::1"),
notify: notify6,
})
require.Nil(t, err)
require.Empty(t, s.GetLeases(LeasesStatic))
// Add static lease.
l := Lease{
IP: net.ParseIP("2001::1"),
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
}
s, err := v6Create(conf)
assert.True(t, err == nil)
require.Nil(t, s.AddStaticLease(l))
// Try to add the same static lease.
require.NotNil(t, s.AddStaticLease(l))
ls := s.GetLeases(LeasesStatic)
assert.Equal(t, 0, len(ls))
require.Len(t, ls, 1)
assert.Equal(t, l.IP, ls[0].IP)
assert.Equal(t, l.HWAddr, ls[0].HWAddr)
assert.EqualValues(t, leaseExpireStatic, ls[0].Expiry.Unix())
// add static lease
l := Lease{}
l.IP = net.ParseIP("2001::1")
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
assert.True(t, s.AddStaticLease(l) == nil)
// Try to remove non-existent static lease.
require.NotNil(t, s.RemoveStaticLease(Lease{
IP: net.ParseIP("2001::2"),
HWAddr: l.HWAddr,
}))
// try to add static lease - fail
assert.True(t, s.AddStaticLease(l) != nil)
// Remove static lease.
require.Nil(t, s.RemoveStaticLease(l))
// check
ls = s.GetLeases(LeasesStatic)
assert.Equal(t, 1, len(ls))
assert.Equal(t, "2001::1", ls[0].IP.String())
assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
assert.True(t, ls[0].Expiry.Unix() == leaseExpireStatic)
// try to remove static lease - fail
l.IP = net.ParseIP("2001::2")
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
assert.True(t, s.RemoveStaticLease(l) != nil)
// remove static lease
l.IP = net.ParseIP("2001::1")
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
assert.True(t, s.RemoveStaticLease(l) == nil)
// check
ls = s.GetLeases(LeasesStatic)
assert.Equal(t, 0, len(ls))
assert.Empty(t, s.GetLeases(LeasesStatic))
}
func TestV6StaticLeaseAddReplaceDynamic(t *testing.T) {
conf := V6ServerConf{
func TestV6_AddReplace(t *testing.T) {
sIface, err := v6Create(V6ServerConf{
Enabled: true,
RangeStart: "2001::1",
RangeStart: net.ParseIP("2001::1"),
notify: notify6,
}
sIface, err := v6Create(conf)
s := sIface.(*v6Server)
assert.True(t, err == nil)
})
require.Nil(t, err)
s, ok := sIface.(*v6Server)
require.True(t, ok)
// add dynamic lease
ld := Lease{}
ld.IP = net.ParseIP("2001::1")
ld.HWAddr, _ = net.ParseMAC("11:aa:aa:aa:aa:aa")
s.addLease(&ld)
// Add dynamic leases.
dynLeases := []*Lease{{
IP: net.ParseIP("2001::1"),
HWAddr: net.HardwareAddr{0x11, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
}, {
IP: net.ParseIP("2001::2"),
HWAddr: net.HardwareAddr{0x22, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
}}
// add dynamic lease
{
ld := Lease{}
ld.IP = net.ParseIP("2001::2")
ld.HWAddr, _ = net.ParseMAC("22:aa:aa:aa:aa:aa")
s.addLease(&ld)
for _, l := range dynLeases {
s.addLease(l)
}
// add static lease with the same IP
l := Lease{}
l.IP = net.ParseIP("2001::1")
l.HWAddr, _ = net.ParseMAC("33:aa:aa:aa:aa:aa")
assert.True(t, s.AddStaticLease(l) == nil)
stLeases := []Lease{{
IP: net.ParseIP("2001::1"),
HWAddr: net.HardwareAddr{0x33, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
}, {
IP: net.ParseIP("2001::3"),
HWAddr: net.HardwareAddr{0x22, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
}}
// add static lease with the same MAC
l = Lease{}
l.IP = net.ParseIP("2001::3")
l.HWAddr, _ = net.ParseMAC("22:aa:aa:aa:aa:aa")
assert.True(t, s.AddStaticLease(l) == nil)
for _, l := range stLeases {
require.Nil(t, s.AddStaticLease(l))
}
// check
ls := s.GetLeases(LeasesStatic)
assert.Equal(t, 2, len(ls))
require.Len(t, ls, 2)
assert.Equal(t, "2001::1", ls[0].IP.String())
assert.Equal(t, "33:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
assert.True(t, ls[0].Expiry.Unix() == leaseExpireStatic)
assert.Equal(t, "2001::3", ls[1].IP.String())
assert.Equal(t, "22:aa:aa:aa:aa:aa", ls[1].HWAddr.String())
assert.True(t, ls[1].Expiry.Unix() == leaseExpireStatic)
for i, l := range ls {
assert.True(t, stLeases[i].IP.Equal(l.IP))
assert.Equal(t, stLeases[i].HWAddr, l.HWAddr)
assert.EqualValues(t, leaseExpireStatic, l.Expiry.Unix())
}
}
func TestV6GetLease(t *testing.T) {
conf := V6ServerConf{
var err error
sIface, err := v6Create(V6ServerConf{
Enabled: true,
RangeStart: "2001::1",
RangeStart: net.ParseIP("2001::1"),
notify: notify6,
}
sIface, err := v6Create(conf)
s := sIface.(*v6Server)
assert.True(t, err == nil)
s.conf.dnsIPAddrs = []net.IP{net.ParseIP("2000::1")}
})
require.Nil(t, err)
s, ok := sIface.(*v6Server)
require.True(t, ok)
dnsAddr := net.ParseIP("2000::1")
s.conf.dnsIPAddrs = []net.IP{dnsAddr}
s.sid = dhcpv6.Duid{
Type: dhcpv6.DUID_LLT,
HwType: iana.HWTypeEthernet,
Type: dhcpv6.DUID_LLT,
HwType: iana.HWTypeEthernet,
LinkLayerAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
}
s.sid.LinkLayerAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
l := Lease{}
l.IP = net.ParseIP("2001::1")
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
assert.True(t, s.AddStaticLease(l) == nil)
l := Lease{
IP: net.ParseIP("2001::1"),
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
}
require.Nil(t, s.AddStaticLease(l))
// "Solicit"
mac, _ := net.ParseMAC("aa:aa:aa:aa:aa:aa")
req, _ := dhcpv6.NewSolicit(mac)
msg, _ := req.GetInnerMessage()
resp, _ := dhcpv6.NewAdvertiseFromSolicit(msg)
assert.True(t, s.process(msg, req, resp))
var req, resp, msg *dhcpv6.Message
mac := net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}
t.Run("solicit", func(t *testing.T) {
req, err = dhcpv6.NewSolicit(mac)
require.Nil(t, err)
msg, err = req.GetInnerMessage()
require.Nil(t, err)
resp, err = dhcpv6.NewAdvertiseFromSolicit(msg)
require.Nil(t, err)
assert.True(t, s.process(msg, req, resp))
})
require.Nil(t, err)
resp.AddOption(dhcpv6.OptServerID(s.sid))
// check "Advertise"
assert.Equal(t, dhcpv6.MessageTypeAdvertise, resp.Type())
oia := resp.Options.OneIANA()
oiaAddr := oia.Options.OneAddress()
assert.Equal(t, "2001::1", oiaAddr.IPv6Addr.String())
assert.Equal(t, s.conf.leaseTime.Seconds(), oiaAddr.ValidLifetime.Seconds())
var oia *dhcpv6.OptIANA
var oiaAddr *dhcpv6.OptIAAddress
t.Run("advertise", func(t *testing.T) {
require.Equal(t, dhcpv6.MessageTypeAdvertise, resp.Type())
oia = resp.Options.OneIANA()
oiaAddr = oia.Options.OneAddress()
// "Request"
req, _ = dhcpv6.NewRequestFromAdvertise(resp)
msg, _ = req.GetInnerMessage()
resp, _ = dhcpv6.NewReplyFromMessage(msg)
assert.True(t, s.process(msg, req, resp))
assert.Equal(t, l.IP, oiaAddr.IPv6Addr)
assert.Equal(t, s.conf.leaseTime.Seconds(), oiaAddr.ValidLifetime.Seconds())
})
// check "Reply"
assert.Equal(t, dhcpv6.MessageTypeReply, resp.Type())
oia = resp.Options.OneIANA()
oiaAddr = oia.Options.OneAddress()
assert.Equal(t, "2001::1", oiaAddr.IPv6Addr.String())
assert.Equal(t, s.conf.leaseTime.Seconds(), oiaAddr.ValidLifetime.Seconds())
t.Run("request", func(t *testing.T) {
req, err = dhcpv6.NewRequestFromAdvertise(resp)
require.Nil(t, err)
msg, err = req.GetInnerMessage()
require.Nil(t, err)
resp, err = dhcpv6.NewReplyFromMessage(msg)
require.Nil(t, err)
assert.True(t, s.process(msg, req, resp))
})
require.Nil(t, err)
t.Run("reply", func(t *testing.T) {
require.Equal(t, dhcpv6.MessageTypeReply, resp.Type())
oia = resp.Options.OneIANA()
oiaAddr = oia.Options.OneAddress()
assert.Equal(t, l.IP, oiaAddr.IPv6Addr)
assert.Equal(t, s.conf.leaseTime.Seconds(), oiaAddr.ValidLifetime.Seconds())
})
dnsAddrs := resp.Options.DNS()
assert.Equal(t, 1, len(dnsAddrs))
assert.Equal(t, "2000::1", dnsAddrs[0].String())
require.Len(t, dnsAddrs, 1)
assert.Equal(t, dnsAddr, dnsAddrs[0])
// check lease
ls := s.GetLeases(LeasesStatic)
assert.Equal(t, 1, len(ls))
assert.Equal(t, "2001::1", ls[0].IP.String())
assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
t.Run("lease", func(t *testing.T) {
ls := s.GetLeases(LeasesStatic)
require.Len(t, ls, 1)
assert.Equal(t, l.IP, ls[0].IP)
assert.Equal(t, l.HWAddr, ls[0].HWAddr)
})
}
func TestV6GetDynamicLease(t *testing.T) {
conf := V6ServerConf{
sIface, err := v6Create(V6ServerConf{
Enabled: true,
RangeStart: "2001::2",
RangeStart: net.ParseIP("2001::2"),
notify: notify6,
}
sIface, err := v6Create(conf)
s := sIface.(*v6Server)
assert.True(t, err == nil)
s.conf.dnsIPAddrs = []net.IP{net.ParseIP("2000::1")}
s.sid = dhcpv6.Duid{
Type: dhcpv6.DUID_LLT,
HwType: iana.HWTypeEthernet,
}
s.sid.LinkLayerAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
})
require.Nil(t, err)
s, ok := sIface.(*v6Server)
require.True(t, ok)
// "Solicit"
mac, _ := net.ParseMAC("aa:aa:aa:aa:aa:aa")
req, _ := dhcpv6.NewSolicit(mac)
msg, _ := req.GetInnerMessage()
resp, _ := dhcpv6.NewAdvertiseFromSolicit(msg)
assert.True(t, s.process(msg, req, resp))
dnsAddr := net.ParseIP("2000::1")
s.conf.dnsIPAddrs = []net.IP{dnsAddr}
s.sid = dhcpv6.Duid{
Type: dhcpv6.DUID_LLT,
HwType: iana.HWTypeEthernet,
LinkLayerAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
}
var req, resp, msg *dhcpv6.Message
mac := net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}
t.Run("solicit", func(t *testing.T) {
req, err = dhcpv6.NewSolicit(mac)
require.Nil(t, err)
msg, err = req.GetInnerMessage()
require.Nil(t, err)
resp, err = dhcpv6.NewAdvertiseFromSolicit(msg)
require.Nil(t, err)
assert.True(t, s.process(msg, req, resp))
})
require.Nil(t, err)
resp.AddOption(dhcpv6.OptServerID(s.sid))
// check "Advertise"
assert.Equal(t, dhcpv6.MessageTypeAdvertise, resp.Type())
oia := resp.Options.OneIANA()
oiaAddr := oia.Options.OneAddress()
assert.Equal(t, "2001::2", oiaAddr.IPv6Addr.String())
var oia *dhcpv6.OptIANA
var oiaAddr *dhcpv6.OptIAAddress
t.Run("advertise", func(t *testing.T) {
require.Equal(t, dhcpv6.MessageTypeAdvertise, resp.Type())
oia = resp.Options.OneIANA()
oiaAddr = oia.Options.OneAddress()
assert.Equal(t, "2001::2", oiaAddr.IPv6Addr.String())
})
// "Request"
req, _ = dhcpv6.NewRequestFromAdvertise(resp)
msg, _ = req.GetInnerMessage()
resp, _ = dhcpv6.NewReplyFromMessage(msg)
assert.True(t, s.process(msg, req, resp))
t.Run("request", func(t *testing.T) {
req, err = dhcpv6.NewRequestFromAdvertise(resp)
require.Nil(t, err)
// check "Reply"
assert.Equal(t, dhcpv6.MessageTypeReply, resp.Type())
oia = resp.Options.OneIANA()
oiaAddr = oia.Options.OneAddress()
assert.Equal(t, "2001::2", oiaAddr.IPv6Addr.String())
msg, err = req.GetInnerMessage()
require.Nil(t, err)
resp, err = dhcpv6.NewReplyFromMessage(msg)
require.Nil(t, err)
assert.True(t, s.process(msg, req, resp))
})
require.Nil(t, err)
t.Run("reply", func(t *testing.T) {
require.Equal(t, dhcpv6.MessageTypeReply, resp.Type())
oia = resp.Options.OneIANA()
oiaAddr = oia.Options.OneAddress()
assert.Equal(t, "2001::2", oiaAddr.IPv6Addr.String())
})
dnsAddrs := resp.Options.DNS()
assert.Equal(t, 1, len(dnsAddrs))
assert.Equal(t, "2000::1", dnsAddrs[0].String())
require.Len(t, dnsAddrs, 1)
assert.Equal(t, dnsAddr, dnsAddrs[0])
// check lease
ls := s.GetLeases(LeasesDynamic)
assert.Equal(t, 1, len(ls))
assert.Equal(t, "2001::2", ls[0].IP.String())
assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
assert.True(t, !ip6InRange(net.ParseIP("2001::2"), net.ParseIP("2001::1")))
assert.True(t, !ip6InRange(net.ParseIP("2001::2"), net.ParseIP("2002::2")))
assert.True(t, ip6InRange(net.ParseIP("2001::2"), net.ParseIP("2001::2")))
assert.True(t, ip6InRange(net.ParseIP("2001::2"), net.ParseIP("2001::3")))
t.Run("lease", func(t *testing.T) {
ls := s.GetLeases(LeasesDynamic)
require.Len(t, ls, 1)
assert.Equal(t, "2001::2", ls[0].IP.String())
assert.Equal(t, mac, ls[0].HWAddr)
})
}
func TestIP6InRange(t *testing.T) {
start := net.ParseIP("2001::2")
testCases := []struct {
ip net.IP
want bool
}{{
ip: net.ParseIP("2001::1"),
want: false,
}, {
ip: net.ParseIP("2002::2"),
want: false,
}, {
ip: start,
want: true,
}, {
ip: net.ParseIP("2001::3"),
want: true,
}}
for _, tc := range testCases {
t.Run(tc.ip.String(), func(t *testing.T) {
assert.Equal(t, tc.want, ip6InRange(start, tc.ip))
})
}
}

View File

@@ -161,7 +161,73 @@ var serviceRulesArray = []svc{
"||douyin.com^",
"||tiktokv.com^",
}},
{"qq", []string{"||qq.com^", "||qqzaixian.com^"}},
{"vimeo", []string{
"||vimeo.com^",
"||vimeocdn.com^",
"*vod-adaptive.akamaized.net^",
}},
{"pinterest", []string{
"||pinterest.*^",
"||pinimg.com^",
}},
{"imgur", []string{
"||imgur.com^",
}},
{"dailymotion", []string{
"||dailymotion.com^",
"||dm-event.net^",
"||dmcdn.net^",
}},
{"qq", []string{
// block qq.com and subdomains excluding WeChat domains
"^(?!weixin|wx)([^.]+\\.)?qq\\.com$",
"||qqzaixian.com^",
}},
{"wechat", []string{
"||wechat.com^",
"||weixin.qq.com^",
"||wx.qq.com^",
}},
{"viber", []string{
"||viber.com^",
}},
{"weibo", []string{
"||weibo.com^",
}},
{"9gag", []string{
"||9cache.com^",
"||gag.com^",
}},
{"telegram", []string{
"||t.me^",
"||telegram.me^",
"||telegram.org^",
}},
{"disneyplus", []string{
"||disney-plus.net^",
"||disneyplus.com^",
}},
{"hulu", []string{
"||hulu.com^",
}},
{"spotify", []string{
"/_spotify-connect._tcp.local/",
"||spotify.com^",
"||scdn.co^",
"||spotify.com.edgesuite.net^",
"||spotify.map.fastly.net^",
"||spotify.map.fastlylb.net^",
"||spotifycdn.net^",
"||audio-ak-spotify-com.akamaized.net^",
"||audio4-ak-spotify-com.akamaized.net^",
"||heads-ak-spotify-com.akamaized.net^",
"||heads4-ak-spotify-com.akamaized.net^",
}},
{"tinder", []string{
"||gotinder.com^",
"||tinder.com^",
"||tindersparks.com^",
}},
}
// convert array to map
@@ -242,6 +308,6 @@ func (d *DNSFilter) handleBlockedServicesSet(w http.ResponseWriter, r *http.Requ
// registerBlockedServicesHandlers - register HTTP handlers
func (d *DNSFilter) registerBlockedServicesHandlers() {
d.Config.HTTPRegister("GET", "/control/blocked_services/list", d.handleBlockedServicesList)
d.Config.HTTPRegister("POST", "/control/blocked_services/set", d.handleBlockedServicesSet)
d.Config.HTTPRegister(http.MethodGet, "/control/blocked_services/list", d.handleBlockedServicesList)
d.Config.HTTPRegister(http.MethodPost, "/control/blocked_services/set", d.handleBlockedServicesSet)
}

View File

@@ -0,0 +1,37 @@
// +build ignore
package dnsfilter
import (
"fmt"
"sort"
"testing"
)
// This is a simple tool that takes a list of services and prints them to the output.
// It is supposed to be used to update:
// client/src/helpers/constants.js
// client/src/components/ui/Icons.js
//
// Usage:
// 1. go run ./internal/dnsfilter/blocked_test.go
// 2. Use the output to replace `SERVICES` array in "client/src/helpers/constants.js".
// 3. You'll need to enter services names manually.
// 4. Don't forget to add missing icons to "client/src/components/ui/Icons.js".
//
// TODO(ameshkov): Rework generator: have a JSON file with all the metadata we need
// then use this JSON file to generate JS and Go code
func TestGenServicesArray(t *testing.T) {
services := make([]svc, len(serviceRulesArray))
copy(services, serviceRulesArray)
sort.Slice(services, func(i, j int) bool {
return services[i].name < services[j].name
})
fmt.Println("export const SERVICES = [")
for _, s := range services {
fmt.Printf(" {\n id: '%s',\n name: '%s',\n },\n", s.name, s.name)
}
fmt.Println("];")
}

View File

@@ -2,6 +2,7 @@
package dnsfilter
import (
"context"
"fmt"
"io/ioutil"
"net"
@@ -36,12 +37,18 @@ type RequestFilteringSettings struct {
ParentalEnabled bool
ClientName string
ClientIP string
ClientIP net.IP
ClientTags []string
ServicesRules []ServiceEntry
}
// Resolver is the interface for net.Resolver to simplify testing.
type Resolver interface {
// TODO(e.burkov): Replace with LookupIP after upgrading go to v1.15.
LookupIPAddr(ctx context.Context, host string) (ips []net.IPAddr, err error)
}
// Config allows you to configure DNS filtering with New() or just change variables directly.
type Config struct {
ParentalEnabled bool `yaml:"parental_enabled"`
@@ -68,6 +75,9 @@ type Config struct {
// Register an HTTP handler
HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request)) `yaml:"-"`
// CustomResolver is the resolver used by DNSFilter.
CustomResolver Resolver
}
// LookupStats store stats collected during safebrowsing or parental checks
@@ -110,6 +120,11 @@ type DNSFilter struct {
// Channel for passing data to filters-initializer goroutine
filtersInitializerChan chan filtersInitializerParams
filtersInitializerLock sync.Mutex
// resolver only looks up the IP address of the host while safe search.
//
// TODO(e.burkov): Use upstream that configured in dnsforward instead.
resolver Resolver
}
// Filter represents a filter list
@@ -148,17 +163,21 @@ const (
// FilteredBlockedService - the host is blocked by "blocked services" settings
FilteredBlockedService
// ReasonRewrite is returned when there was a rewrite by
// a legacy DNS Rewrite rule.
ReasonRewrite
// Rewritten is returned when there was a rewrite by a legacy DNS
// rewrite rule.
Rewritten
// RewriteAutoHosts is returned when there was a rewrite by
// autohosts rules (/etc/hosts and so on).
RewriteAutoHosts
// RewrittenAutoHosts is returned when there was a rewrite by autohosts
// rules (/etc/hosts and so on).
RewrittenAutoHosts
// DNSRewriteRule is returned when a $dnsrewrite filter rule was
// applied.
DNSRewriteRule
// RewrittenRule is returned when a $dnsrewrite filter rule was applied.
//
// TODO(a.garipov): Remove Rewritten and RewrittenAutoHosts by merging
// their functionality into RewrittenRule.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/2499.
RewrittenRule
)
// TODO(a.garipov): Resync with actual code names or replace completely
@@ -175,11 +194,9 @@ var reasonNames = []string{
FilteredSafeSearch: "FilteredSafeSearch",
FilteredBlockedService: "FilteredBlockedService",
ReasonRewrite: "Rewrite",
RewriteAutoHosts: "RewriteEtcHosts",
DNSRewriteRule: "DNSRewriteRule",
Rewritten: "Rewrite",
RewrittenAutoHosts: "RewriteEtcHosts",
RewrittenRule: "RewriteRule",
}
func (r Reason) String() string {
@@ -331,15 +348,15 @@ type Result struct {
Rules []*ResultRule `json:",omitempty"`
// ReverseHosts is the reverse lookup rewrite result. It is
// empty unless Reason is set to RewriteAutoHosts.
// empty unless Reason is set to RewrittenAutoHosts.
ReverseHosts []string `json:",omitempty"`
// IPList is the lookup rewrite result. It is empty unless
// Reason is set to RewriteAutoHosts or ReasonRewrite.
// Reason is set to RewrittenAutoHosts or Rewritten.
IPList []net.IP `json:",omitempty"`
// CanonName is the CNAME value from the lookup rewrite result.
// It is empty unless Reason is set to ReasonRewrite.
// It is empty unless Reason is set to Rewritten or RewrittenRule.
CanonName string `json:",omitempty"`
// ServiceName is the name of the blocked service. It is empty
@@ -379,7 +396,7 @@ func (d *DNSFilter) CheckHost(host string, qtype uint16, setts *RequestFiltering
// first - check rewrites, they have the highest priority
result = d.processRewrites(host, qtype)
if result.Reason == ReasonRewrite {
if result.Reason == Rewritten {
return result, nil
}
@@ -453,7 +470,7 @@ func (d *DNSFilter) CheckHost(host string, qtype uint16, setts *RequestFiltering
func (d *DNSFilter) checkAutoHosts(host string, qtype uint16, result *Result) (matched bool) {
ips := d.Config.AutoHosts.Process(host, qtype)
if ips != nil {
result.Reason = RewriteAutoHosts
result.Reason = RewrittenAutoHosts
result.IPList = ips
return true
@@ -461,7 +478,7 @@ func (d *DNSFilter) checkAutoHosts(host string, qtype uint16, result *Result) (m
revHosts := d.Config.AutoHosts.ProcessReverse(host, qtype)
if len(revHosts) != 0 {
result.Reason = RewriteAutoHosts
result.Reason = RewrittenAutoHosts
// TODO(a.garipov): Optimize this with a buffer.
result.ReverseHosts = make([]string, len(revHosts))
@@ -488,7 +505,7 @@ func (d *DNSFilter) processRewrites(host string, qtype uint16) (res Result) {
rr := findRewrites(d.Rewrites, host)
if len(rr) != 0 {
res.Reason = ReasonRewrite
res.Reason = Rewritten
}
cnames := map[string]bool{}
@@ -674,9 +691,10 @@ func (d *DNSFilter) matchHost(host string, qtype uint16, setts RequestFilteringS
ureq := urlfilter.DNSRequest{
Hostname: host,
SortedClientTags: setts.ClientTags,
ClientIP: setts.ClientIP,
ClientName: setts.ClientName,
DNSType: qtype,
// TODO(e.burkov): Wait for urlfilter update to pass net.IP.
ClientIP: setts.ClientIP.String(),
ClientName: setts.ClientName,
DNSType: qtype,
}
if d.filteringEngineAllow != nil {
@@ -696,7 +714,7 @@ func (d *DNSFilter) matchHost(host string, qtype uint16, setts RequestFilteringS
// awkward.
if dnsr := dnsres.DNSRewrites(); len(dnsr) > 0 {
res = d.processDNSRewrites(dnsr)
if res.Reason == DNSRewriteRule && res.CanonName == host {
if res.Reason == RewrittenRule && res.CanonName == host {
// A rewrite of a host to itself. Go on and
// try matching other things.
} else {
@@ -781,6 +799,7 @@ func InitModule() {
// New creates properly initialized DNS Filter that is ready to be used.
func New(c *Config, blockFilters []Filter) *DNSFilter {
var resolver Resolver = net.DefaultResolver
if c != nil {
cacheConf := cache.Config{
EnableLRU: true,
@@ -800,9 +819,15 @@ func New(c *Config, blockFilters []Filter) *DNSFilter {
cacheConf.MaxSize = c.ParentalCacheSize
gctx.parentalCache = cache.New(cacheConf)
}
if c.CustomResolver != nil {
resolver = c.CustomResolver
}
}
d := new(DNSFilter)
d := &DNSFilter{
resolver: resolver,
}
err := d.initSecurityServices()
if err != nil {

File diff suppressed because it is too large Load Diff

View File

@@ -7,8 +7,8 @@ import (
// DNSRewriteResult is the result of application of $dnsrewrite rules.
type DNSRewriteResult struct {
RCode rules.RCode `json:",omitempty"`
Response DNSRewriteResultResponse `json:",omitempty"`
RCode rules.RCode `json:",omitempty"`
}
// DNSRewriteResultResponse is the collection of DNS response records
@@ -33,13 +33,13 @@ func (d *DNSFilter) processDNSRewrites(dnsr []*rules.NetworkRule) (res Result) {
if dr.NewCNAME != "" {
// NewCNAME rules have a higher priority than
// the other rules.
rules := []*ResultRule{{
rules = []*ResultRule{{
FilterListID: int64(nr.GetFilterListID()),
Text: nr.RuleText,
}}
return Result{
Reason: DNSRewriteRule,
Reason: RewrittenRule,
Rules: rules,
CanonName: dr.NewCNAME,
}
@@ -56,7 +56,7 @@ func (d *DNSFilter) processDNSRewrites(dnsr []*rules.NetworkRule) (res Result) {
default:
// RcodeRefused and other such codes have higher
// priority. Return immediately.
rules := []*ResultRule{{
rules = []*ResultRule{{
FilterListID: int64(nr.GetFilterListID()),
Text: nr.RuleText,
}}
@@ -65,7 +65,7 @@ func (d *DNSFilter) processDNSRewrites(dnsr []*rules.NetworkRule) (res Result) {
}
return Result{
Reason: DNSRewriteRule,
Reason: RewrittenRule,
Rules: rules,
DNSRewriteResult: dnsrr,
}
@@ -73,7 +73,7 @@ func (d *DNSFilter) processDNSRewrites(dnsr []*rules.NetworkRule) (res Result) {
}
return Result{
Reason: DNSRewriteRule,
Reason: RewrittenRule,
Rules: rules,
DNSRewriteResult: dnsrr,
}

View File

@@ -11,40 +11,41 @@ import (
func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
const text = `
|cname^$dnsrewrite=new_cname
|cname^$dnsrewrite=new-cname
|a_record^$dnsrewrite=127.0.0.1
|a-record^$dnsrewrite=127.0.0.1
|aaaa_record^$dnsrewrite=::1
|aaaa-record^$dnsrewrite=::1
|txt_record^$dnsrewrite=NOERROR;TXT;hello_world
|txt-record^$dnsrewrite=NOERROR;TXT;hello-world
|refused^$dnsrewrite=REFUSED
|a_records^$dnsrewrite=127.0.0.1
|a_records^$dnsrewrite=127.0.0.2
|a-records^$dnsrewrite=127.0.0.1
|a-records^$dnsrewrite=127.0.0.2
|aaaa_records^$dnsrewrite=::1
|aaaa_records^$dnsrewrite=::2
|aaaa-records^$dnsrewrite=::1
|aaaa-records^$dnsrewrite=::2
|disable_one^$dnsrewrite=127.0.0.1
|disable_one^$dnsrewrite=127.0.0.2
@@||disable_one^$dnsrewrite=127.0.0.1
|disable-one^$dnsrewrite=127.0.0.1
|disable-one^$dnsrewrite=127.0.0.2
@@||disable-one^$dnsrewrite=127.0.0.1
|disable_cname^$dnsrewrite=127.0.0.1
|disable_cname^$dnsrewrite=new_cname
@@||disable_cname^$dnsrewrite=new_cname
|disable-cname^$dnsrewrite=127.0.0.1
|disable-cname^$dnsrewrite=new-cname
@@||disable-cname^$dnsrewrite=new-cname
|disable_cname_many^$dnsrewrite=127.0.0.1
|disable_cname_many^$dnsrewrite=new_cname_1
|disable_cname_many^$dnsrewrite=new_cname_2
@@||disable_cname_many^$dnsrewrite=new_cname_1
|disable-cname-many^$dnsrewrite=127.0.0.1
|disable-cname-many^$dnsrewrite=new-cname-1
|disable-cname-many^$dnsrewrite=new-cname-2
@@||disable-cname-many^$dnsrewrite=new-cname-1
|disable_all^$dnsrewrite=127.0.0.1
|disable_all^$dnsrewrite=127.0.0.2
@@||disable_all^$dnsrewrite
|disable-all^$dnsrewrite=127.0.0.1
|disable-all^$dnsrewrite=127.0.0.2
@@||disable-all^$dnsrewrite
`
f := NewForTest(nil, []Filter{{ID: 0, Data: []byte(text)}})
f := newForTest(nil, []Filter{{ID: 0, Data: []byte(text)}})
setts := &RequestFilteringSettings{
FilteringEnabled: true,
}
@@ -60,10 +61,10 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
res, err := f.CheckHostRules(host, dtyp, setts)
assert.Nil(t, err)
assert.Equal(t, "new_cname", res.CanonName)
assert.Equal(t, "new-cname", res.CanonName)
})
t.Run("a_record", func(t *testing.T) {
t.Run("a-record", func(t *testing.T) {
dtyp := dns.TypeA
host := path.Base(t.Name())
@@ -78,7 +79,7 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
}
})
t.Run("aaaa_record", func(t *testing.T) {
t.Run("aaaa-record", func(t *testing.T) {
dtyp := dns.TypeAAAA
host := path.Base(t.Name())
@@ -93,7 +94,7 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
}
})
t.Run("txt_record", func(t *testing.T) {
t.Run("txt-record", func(t *testing.T) {
dtyp := dns.TypeTXT
host := path.Base(t.Name())
res, err := f.CheckHostRules(host, dtyp, setts)
@@ -102,7 +103,7 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
if dnsrr := res.DNSRewriteResult; assert.NotNil(t, dnsrr) {
assert.Equal(t, dns.RcodeSuccess, dnsrr.RCode)
if strVals := dnsrr.Response[dtyp]; assert.Len(t, strVals, 1) {
assert.Equal(t, "hello_world", strVals[0])
assert.Equal(t, "hello-world", strVals[0])
}
}
})
@@ -117,7 +118,7 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
}
})
t.Run("a_records", func(t *testing.T) {
t.Run("a-records", func(t *testing.T) {
dtyp := dns.TypeA
host := path.Base(t.Name())
@@ -133,7 +134,7 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
}
})
t.Run("aaaa_records", func(t *testing.T) {
t.Run("aaaa-records", func(t *testing.T) {
dtyp := dns.TypeAAAA
host := path.Base(t.Name())
@@ -149,7 +150,7 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
}
})
t.Run("disable_one", func(t *testing.T) {
t.Run("disable-one", func(t *testing.T) {
dtyp := dns.TypeA
host := path.Base(t.Name())
@@ -164,13 +165,13 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
}
})
t.Run("disable_cname", func(t *testing.T) {
t.Run("disable-cname", func(t *testing.T) {
dtyp := dns.TypeA
host := path.Base(t.Name())
res, err := f.CheckHostRules(host, dtyp, setts)
assert.Nil(t, err)
assert.Equal(t, "", res.CanonName)
assert.Empty(t, res.CanonName)
if dnsrr := res.DNSRewriteResult; assert.NotNil(t, dnsrr) {
assert.Equal(t, dns.RcodeSuccess, dnsrr.RCode)
@@ -180,23 +181,23 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
}
})
t.Run("disable_cname_many", func(t *testing.T) {
t.Run("disable-cname-many", func(t *testing.T) {
dtyp := dns.TypeA
host := path.Base(t.Name())
res, err := f.CheckHostRules(host, dtyp, setts)
assert.Nil(t, err)
assert.Equal(t, "new_cname_2", res.CanonName)
assert.Equal(t, "new-cname-2", res.CanonName)
assert.Nil(t, res.DNSRewriteResult)
})
t.Run("disable_all", func(t *testing.T) {
t.Run("disable-all", func(t *testing.T) {
dtyp := dns.TypeA
host := path.Base(t.Name())
res, err := f.CheckHostRules(host, dtyp, setts)
assert.Nil(t, err)
assert.Equal(t, "", res.CanonName)
assert.Len(t, res.Rules, 0)
assert.Empty(t, res.CanonName)
assert.Empty(t, res.Rules)
})
}

View File

@@ -219,7 +219,7 @@ func (d *DNSFilter) handleRewriteDelete(w http.ResponseWriter, r *http.Request)
}
func (d *DNSFilter) registerRewritesHandlers() {
d.Config.HTTPRegister("GET", "/control/rewrite/list", d.handleRewriteList)
d.Config.HTTPRegister("POST", "/control/rewrite/add", d.handleRewriteAdd)
d.Config.HTTPRegister("POST", "/control/rewrite/delete", d.handleRewriteDelete)
d.Config.HTTPRegister(http.MethodGet, "/control/rewrite/list", d.handleRewriteList)
d.Config.HTTPRegister(http.MethodPost, "/control/rewrite/add", d.handleRewriteAdd)
d.Config.HTTPRegister(http.MethodPost, "/control/rewrite/delete", d.handleRewriteDelete)
}

View File

@@ -9,7 +9,8 @@ import (
)
func TestRewrites(t *testing.T) {
d := DNSFilter{}
d := newForTest(nil, nil)
t.Cleanup(d.Close)
// CNAME, A, AAAA
d.Rewrites = []RewriteEntry{
{"somecname", "somehost.com", 0, nil},
@@ -25,16 +26,16 @@ func TestRewrites(t *testing.T) {
assert.Equal(t, NotFilteredNotFound, r.Reason)
r = d.processRewrites("www.host.com", dns.TypeA)
assert.Equal(t, ReasonRewrite, r.Reason)
assert.Equal(t, Rewritten, r.Reason)
assert.Equal(t, "host.com", r.CanonName)
assert.Equal(t, 2, len(r.IPList))
assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4")))
assert.True(t, r.IPList[1].Equal(net.ParseIP("1.2.3.5")))
assert.Len(t, r.IPList, 2)
assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4}))
assert.True(t, r.IPList[1].Equal(net.IP{1, 2, 3, 5}))
r = d.processRewrites("www.host.com", dns.TypeAAAA)
assert.Equal(t, ReasonRewrite, r.Reason)
assert.Equal(t, Rewritten, r.Reason)
assert.Equal(t, "host.com", r.CanonName)
assert.Equal(t, 1, len(r.IPList))
assert.Len(t, r.IPList, 1)
assert.True(t, r.IPList[0].Equal(net.ParseIP("1:2:3::4")))
// wildcard
@@ -44,12 +45,12 @@ func TestRewrites(t *testing.T) {
}
d.prepareRewrites()
r = d.processRewrites("host.com", dns.TypeA)
assert.Equal(t, ReasonRewrite, r.Reason)
assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4")))
assert.Equal(t, Rewritten, r.Reason)
assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4}))
r = d.processRewrites("www.host.com", dns.TypeA)
assert.Equal(t, ReasonRewrite, r.Reason)
assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.5")))
assert.Equal(t, Rewritten, r.Reason)
assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 5}))
r = d.processRewrites("www.host2.com", dns.TypeA)
assert.Equal(t, NotFilteredNotFound, r.Reason)
@@ -61,9 +62,9 @@ func TestRewrites(t *testing.T) {
}
d.prepareRewrites()
r = d.processRewrites("a.host.com", dns.TypeA)
assert.Equal(t, ReasonRewrite, r.Reason)
assert.True(t, len(r.IPList) == 1)
assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4")))
assert.Equal(t, Rewritten, r.Reason)
assert.Len(t, r.IPList, 1)
assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4}))
// wildcard + CNAME
d.Rewrites = []RewriteEntry{
@@ -72,9 +73,9 @@ func TestRewrites(t *testing.T) {
}
d.prepareRewrites()
r = d.processRewrites("www.host.com", dns.TypeA)
assert.Equal(t, ReasonRewrite, r.Reason)
assert.Equal(t, Rewritten, r.Reason)
assert.Equal(t, "host.com", r.CanonName)
assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4")))
assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4}))
// 2 CNAMEs
d.Rewrites = []RewriteEntry{
@@ -84,10 +85,10 @@ func TestRewrites(t *testing.T) {
}
d.prepareRewrites()
r = d.processRewrites("b.host.com", dns.TypeA)
assert.Equal(t, ReasonRewrite, r.Reason)
assert.Equal(t, Rewritten, r.Reason)
assert.Equal(t, "host.com", r.CanonName)
assert.True(t, len(r.IPList) == 1)
assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4")))
assert.Len(t, r.IPList, 1)
assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4}))
// 2 CNAMEs + wildcard
d.Rewrites = []RewriteEntry{
@@ -97,14 +98,15 @@ func TestRewrites(t *testing.T) {
}
d.prepareRewrites()
r = d.processRewrites("b.host.com", dns.TypeA)
assert.Equal(t, ReasonRewrite, r.Reason)
assert.Equal(t, Rewritten, r.Reason)
assert.Equal(t, "x.somehost.com", r.CanonName)
assert.True(t, len(r.IPList) == 1)
assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4")))
assert.Len(t, r.IPList, 1)
assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4}))
}
func TestRewritesLevels(t *testing.T) {
d := DNSFilter{}
d := newForTest(nil, nil)
t.Cleanup(d.Close)
// exact host, wildcard L2, wildcard L3
d.Rewrites = []RewriteEntry{
{"host.com", "1.1.1.1", 0, nil},
@@ -115,25 +117,26 @@ func TestRewritesLevels(t *testing.T) {
// match exact
r := d.processRewrites("host.com", dns.TypeA)
assert.Equal(t, ReasonRewrite, r.Reason)
assert.Equal(t, 1, len(r.IPList))
assert.Equal(t, "1.1.1.1", r.IPList[0].String())
assert.Equal(t, Rewritten, r.Reason)
assert.Len(t, r.IPList, 1)
assert.True(t, net.IP{1, 1, 1, 1}.Equal(r.IPList[0]))
// match L2
r = d.processRewrites("sub.host.com", dns.TypeA)
assert.Equal(t, ReasonRewrite, r.Reason)
assert.Equal(t, 1, len(r.IPList))
assert.Equal(t, "2.2.2.2", r.IPList[0].String())
assert.Equal(t, Rewritten, r.Reason)
assert.Len(t, r.IPList, 1)
assert.True(t, net.IP{2, 2, 2, 2}.Equal(r.IPList[0]))
// match L3
r = d.processRewrites("my.sub.host.com", dns.TypeA)
assert.Equal(t, ReasonRewrite, r.Reason)
assert.Equal(t, 1, len(r.IPList))
assert.Equal(t, "3.3.3.3", r.IPList[0].String())
assert.Equal(t, Rewritten, r.Reason)
assert.Len(t, r.IPList, 1)
assert.True(t, net.IP{3, 3, 3, 3}.Equal(r.IPList[0]))
}
func TestRewritesExceptionCNAME(t *testing.T) {
d := DNSFilter{}
d := newForTest(nil, nil)
t.Cleanup(d.Close)
// wildcard; exception for a sub-domain
d.Rewrites = []RewriteEntry{
{"*.host.com", "2.2.2.2", 0, nil},
@@ -143,9 +146,9 @@ func TestRewritesExceptionCNAME(t *testing.T) {
// match sub-domain
r := d.processRewrites("my.host.com", dns.TypeA)
assert.Equal(t, ReasonRewrite, r.Reason)
assert.Equal(t, 1, len(r.IPList))
assert.Equal(t, "2.2.2.2", r.IPList[0].String())
assert.Equal(t, Rewritten, r.Reason)
assert.Len(t, r.IPList, 1)
assert.True(t, net.IP{2, 2, 2, 2}.Equal(r.IPList[0]))
// match sub-domain, but handle exception
r = d.processRewrites("sub.host.com", dns.TypeA)
@@ -153,7 +156,8 @@ func TestRewritesExceptionCNAME(t *testing.T) {
}
func TestRewritesExceptionWC(t *testing.T) {
d := DNSFilter{}
d := newForTest(nil, nil)
t.Cleanup(d.Close)
// wildcard; exception for a sub-wildcard
d.Rewrites = []RewriteEntry{
{"*.host.com", "2.2.2.2", 0, nil},
@@ -163,9 +167,9 @@ func TestRewritesExceptionWC(t *testing.T) {
// match sub-domain
r := d.processRewrites("my.host.com", dns.TypeA)
assert.Equal(t, ReasonRewrite, r.Reason)
assert.Equal(t, 1, len(r.IPList))
assert.Equal(t, "2.2.2.2", r.IPList[0].String())
assert.Equal(t, Rewritten, r.Reason)
assert.Len(t, r.IPList, 1)
assert.True(t, net.IP{2, 2, 2, 2}.Equal(r.IPList[0]))
// match sub-domain, but handle exception
r = d.processRewrites("my.sub.host.com", dns.TypeA)
@@ -173,7 +177,8 @@ func TestRewritesExceptionWC(t *testing.T) {
}
func TestRewritesExceptionIP(t *testing.T) {
d := DNSFilter{}
d := newForTest(nil, nil)
t.Cleanup(d.Close)
// exception for AAAA record
d.Rewrites = []RewriteEntry{
{"host.com", "1.2.3.4", 0, nil},
@@ -186,9 +191,9 @@ func TestRewritesExceptionIP(t *testing.T) {
// match domain
r := d.processRewrites("host.com", dns.TypeA)
assert.Equal(t, ReasonRewrite, r.Reason)
assert.Equal(t, 1, len(r.IPList))
assert.Equal(t, "1.2.3.4", r.IPList[0].String())
assert.Equal(t, Rewritten, r.Reason)
assert.Len(t, r.IPList, 1)
assert.True(t, net.IP{1, 2, 3, 4}.Equal(r.IPList[0]))
// match exception
r = d.processRewrites("host.com", dns.TypeAAAA)
@@ -200,8 +205,8 @@ func TestRewritesExceptionIP(t *testing.T) {
// match domain
r = d.processRewrites("host2.com", dns.TypeAAAA)
assert.Equal(t, ReasonRewrite, r.Reason)
assert.Equal(t, 1, len(r.IPList))
assert.Equal(t, Rewritten, r.Reason)
assert.Len(t, r.IPList, 1)
assert.Equal(t, "::1", r.IPList[0].String())
// match exception
@@ -210,6 +215,6 @@ func TestRewritesExceptionIP(t *testing.T) {
// match domain
r = d.processRewrites("host3.com", dns.TypeAAAA)
assert.Equal(t, ReasonRewrite, r.Reason)
assert.Equal(t, 0, len(r.IPList))
assert.Equal(t, Rewritten, r.Reason)
assert.Empty(t, r.IPList)
}

View File

@@ -30,6 +30,20 @@ const (
pcTXTSuffix = `pc.dns.adguard.com.`
)
// SetParentalUpstream sets the parental upstream for *DNSFilter.
//
// TODO(e.burkov): Remove this in v1 API to forbid the direct access.
func (d *DNSFilter) SetParentalUpstream(u upstream.Upstream) {
d.parentalUpstream = u
}
// SetSafeBrowsingUpstream sets the safe browsing upstream for *DNSFilter.
//
// TODO(e.burkov): Remove this in v1 API to forbid the direct access.
func (d *DNSFilter) SetSafeBrowsingUpstream(u upstream.Upstream) {
d.safeBrowsingUpstream = u
}
func (d *DNSFilter) initSecurityServices() error {
var err error
d.safeBrowsingServer = defaultSafebrowsingServer
@@ -37,22 +51,24 @@ func (d *DNSFilter) initSecurityServices() error {
opts := upstream.Options{
Timeout: dnsTimeout,
ServerIPAddrs: []net.IP{
net.ParseIP("94.140.14.15"),
net.ParseIP("94.140.15.16"),
{94, 140, 14, 15},
{94, 140, 15, 16},
net.ParseIP("2a10:50c0::bad1:ff"),
net.ParseIP("2a10:50c0::bad2:ff"),
},
}
d.parentalUpstream, err = upstream.AddressToUpstream(d.parentalServer, opts)
parUps, err := upstream.AddressToUpstream(d.parentalServer, opts)
if err != nil {
return err
return fmt.Errorf("converting parental server: %w", err)
}
d.SetParentalUpstream(parUps)
d.safeBrowsingUpstream, err = upstream.AddressToUpstream(d.safeBrowsingServer, opts)
sbUps, err := upstream.AddressToUpstream(d.safeBrowsingServer, opts)
if err != nil {
return err
return fmt.Errorf("converting safe browsing server: %w", err)
}
d.SetSafeBrowsingUpstream(sbUps)
return nil
}
@@ -200,7 +216,6 @@ func (c *sbCtx) processTXT(resp *dns.Msg) (bool, [][]byte) {
log.Debug("%s: received hashes for %s: %v", c.svc, c.host, txt.Txt)
for _, t := range txt.Txt {
if len(t) != 32*2 {
continue
}
@@ -228,7 +243,7 @@ func (c *sbCtx) processTXT(resp *dns.Msg) (bool, [][]byte) {
func (c *sbCtx) storeCache(hashes [][]byte) {
sort.Slice(hashes, func(a, b int) bool {
return bytes.Compare(hashes[a], hashes[b]) < 0
return bytes.Compare(hashes[a], hashes[b]) == -1
})
var curData []byte
@@ -346,16 +361,12 @@ func (d *DNSFilter) handleSafeBrowsingDisable(w http.ResponseWriter, r *http.Req
}
func (d *DNSFilter) handleSafeBrowsingStatus(w http.ResponseWriter, r *http.Request) {
data := map[string]interface{}{
"enabled": d.Config.SafeBrowsingEnabled,
}
jsonVal, err := json.Marshal(data)
if err != nil {
httpError(r, w, http.StatusInternalServerError, "Unable to marshal status json: %s", err)
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal)
err := json.NewEncoder(w).Encode(&struct {
Enabled bool `json:"enabled"`
}{
Enabled: d.Config.SafeBrowsingEnabled,
})
if err != nil {
httpError(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err)
return
@@ -373,17 +384,12 @@ func (d *DNSFilter) handleParentalDisable(w http.ResponseWriter, r *http.Request
}
func (d *DNSFilter) handleParentalStatus(w http.ResponseWriter, r *http.Request) {
data := map[string]interface{}{
"enabled": d.Config.ParentalEnabled,
}
jsonVal, err := json.Marshal(data)
if err != nil {
httpError(r, w, http.StatusInternalServerError, "Unable to marshal status json: %s", err)
return
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal)
err := json.NewEncoder(w).Encode(&struct {
Enabled bool `json:"enabled"`
}{
Enabled: d.Config.ParentalEnabled,
})
if err != nil {
httpError(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err)
return
@@ -391,15 +397,15 @@ func (d *DNSFilter) handleParentalStatus(w http.ResponseWriter, r *http.Request)
}
func (d *DNSFilter) registerSecurityHandlers() {
d.Config.HTTPRegister("POST", "/control/safebrowsing/enable", d.handleSafeBrowsingEnable)
d.Config.HTTPRegister("POST", "/control/safebrowsing/disable", d.handleSafeBrowsingDisable)
d.Config.HTTPRegister("GET", "/control/safebrowsing/status", d.handleSafeBrowsingStatus)
d.Config.HTTPRegister(http.MethodPost, "/control/safebrowsing/enable", d.handleSafeBrowsingEnable)
d.Config.HTTPRegister(http.MethodPost, "/control/safebrowsing/disable", d.handleSafeBrowsingDisable)
d.Config.HTTPRegister(http.MethodGet, "/control/safebrowsing/status", d.handleSafeBrowsingStatus)
d.Config.HTTPRegister("POST", "/control/parental/enable", d.handleParentalEnable)
d.Config.HTTPRegister("POST", "/control/parental/disable", d.handleParentalDisable)
d.Config.HTTPRegister("GET", "/control/parental/status", d.handleParentalStatus)
d.Config.HTTPRegister(http.MethodPost, "/control/parental/enable", d.handleParentalEnable)
d.Config.HTTPRegister(http.MethodPost, "/control/parental/disable", d.handleParentalDisable)
d.Config.HTTPRegister(http.MethodGet, "/control/parental/status", d.handleParentalStatus)
d.Config.HTTPRegister("POST", "/control/safesearch/enable", d.handleSafeSearchEnable)
d.Config.HTTPRegister("POST", "/control/safesearch/disable", d.handleSafeSearchDisable)
d.Config.HTTPRegister("GET", "/control/safesearch/status", d.handleSafeSearchStatus)
d.Config.HTTPRegister(http.MethodPost, "/control/safesearch/enable", d.handleSafeSearchEnable)
d.Config.HTTPRegister(http.MethodPost, "/control/safesearch/disable", d.handleSafeSearchDisable)
d.Config.HTTPRegister(http.MethodGet, "/control/safesearch/status", d.handleSafeSearchStatus)
}

View File

@@ -5,16 +5,15 @@ import (
"strings"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/golibs/cache"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
)
func TestSafeBrowsingHash(t *testing.T) {
// test hostnameToHashes()
hashes := hostnameToHashes("1.2.3.sub.host.com")
assert.Equal(t, 3, len(hashes))
assert.Len(t, hashes, 3)
_, ok := hashes[sha256.Sum256([]byte("3.sub.host.com"))]
assert.True(t, ok)
_, ok = hashes[sha256.Sum256([]byte("sub.host.com"))]
@@ -31,9 +30,9 @@ func TestSafeBrowsingHash(t *testing.T) {
q := c.getQuestion()
assert.True(t, strings.Contains(q, "7a1b."))
assert.True(t, strings.Contains(q, "af5a."))
assert.True(t, strings.Contains(q, "eb11."))
assert.Contains(t, q, "7a1b.")
assert.Contains(t, q, "af5a.")
assert.Contains(t, q, "eb11.")
assert.True(t, strings.HasSuffix(q, "sb.dns.adguard.com."))
}
@@ -81,7 +80,7 @@ func TestSafeBrowsingCache(t *testing.T) {
c.hashToHost[hash] = "sub.host.com"
hash = sha256.Sum256([]byte("nonexisting.com"))
c.hashToHost[hash] = "nonexisting.com"
assert.Equal(t, 0, c.getCached())
assert.Empty(t, c.getCached())
hash = sha256.Sum256([]byte("sub.host.com"))
_, ok := c.hashToHost[hash]
@@ -103,30 +102,17 @@ func TestSafeBrowsingCache(t *testing.T) {
c.hashToHost[hash] = "sub.host.com"
c.cache.Set(hash[0:2], make([]byte, 32))
assert.Equal(t, 0, c.getCached())
}
// testErrUpstream implements upstream.Upstream interface for replacing real
// upstream in tests.
type testErrUpstream struct{}
// Exchange always returns nil Msg and non-nil error.
func (teu *testErrUpstream) Exchange(*dns.Msg) (*dns.Msg, error) {
return nil, agherr.Error("bad")
}
func (teu *testErrUpstream) Address() string {
return ""
assert.Empty(t, c.getCached())
}
func TestSBPC_checkErrorUpstream(t *testing.T) {
d := NewForTest(&Config{SafeBrowsingEnabled: true}, nil)
defer d.Close()
d := newForTest(&Config{SafeBrowsingEnabled: true}, nil)
t.Cleanup(d.Close)
ups := &testErrUpstream{}
ups := &aghtest.TestErrUpstream{}
d.safeBrowsingUpstream = ups
d.parentalUpstream = ups
d.SetSafeBrowsingUpstream(ups)
d.SetParentalUpstream(ups)
_, err := d.checkSafeBrowsing("smthng.com")
assert.NotNil(t, err)
@@ -134,3 +120,87 @@ func TestSBPC_checkErrorUpstream(t *testing.T) {
_, err = d.checkParental("smthng.com")
assert.NotNil(t, err)
}
func TestSBPC(t *testing.T) {
d := newForTest(&Config{SafeBrowsingEnabled: true}, nil)
t.Cleanup(d.Close)
const hostname = "example.org"
testCases := []struct {
name string
block bool
testFunc func(string) (Result, error)
testCache cache.Cache
}{{
name: "sb_no_block",
block: false,
testFunc: d.checkSafeBrowsing,
testCache: gctx.safebrowsingCache,
}, {
name: "sb_block",
block: true,
testFunc: d.checkSafeBrowsing,
testCache: gctx.safebrowsingCache,
}, {
name: "pc_no_block",
block: false,
testFunc: d.checkParental,
testCache: gctx.parentalCache,
}, {
name: "pc_block",
block: true,
testFunc: d.checkParental,
testCache: gctx.parentalCache,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Prepare the upstream.
ups := &aghtest.TestBlockUpstream{
Hostname: hostname,
Block: tc.block,
}
d.SetSafeBrowsingUpstream(ups)
d.SetParentalUpstream(ups)
// Firstly, check the request blocking.
hits := 0
res, err := tc.testFunc(hostname)
assert.Nil(t, err)
if tc.block {
assert.True(t, res.IsFiltered)
assert.Len(t, res.Rules, 1)
hits++
} else {
assert.False(t, res.IsFiltered)
}
// Check the cache state, check the response is now cached.
assert.Equal(t, 1, tc.testCache.Stats().Count)
assert.Equal(t, hits, tc.testCache.Stats().Hit)
// There was one request to an upstream.
assert.Equal(t, 1, ups.RequestsCount())
// Now make the same request to check the cache was used.
res, err = tc.testFunc(hostname)
assert.Nil(t, err)
if tc.block {
assert.True(t, res.IsFiltered)
assert.Len(t, res.Rules, 1)
} else {
assert.False(t, res.IsFiltered)
}
// Check the cache state, it should've been used.
assert.Equal(t, 1, tc.testCache.Stats().Count)
assert.Equal(t, hits+1, tc.testCache.Stats().Hit)
// Check that there were no additional requests.
assert.Equal(t, 1, ups.RequestsCount())
purgeCaches()
})
}
}

View File

@@ -2,6 +2,7 @@ package dnsfilter
import (
"bytes"
"context"
"encoding/binary"
"encoding/gob"
"encoding/json"
@@ -101,15 +102,14 @@ func (d *DNSFilter) checkSafeSearch(host string) (Result, error) {
return res, nil
}
// TODO this address should be resolved with upstream that was configured in dnsforward
ips, err := net.LookupIP(safeHost)
ipAddrs, err := d.resolver.LookupIPAddr(context.Background(), safeHost)
if err != nil {
log.Tracef("SafeSearchDomain for %s was found but failed to lookup for %s cause %s", host, safeHost, err)
return Result{}, err
}
for _, ip := range ips {
if ipv4 := ip.To4(); ipv4 != nil {
for _, ipAddr := range ipAddrs {
if ipv4 := ipAddr.IP.To4(); ipv4 != nil {
res.Rules[0].IP = ipv4
l := d.setCacheResult(gctx.safeSearchCache, host, res)
@@ -133,17 +133,12 @@ func (d *DNSFilter) handleSafeSearchDisable(w http.ResponseWriter, r *http.Reque
}
func (d *DNSFilter) handleSafeSearchStatus(w http.ResponseWriter, r *http.Request) {
data := map[string]interface{}{
"enabled": d.Config.SafeSearchEnabled,
}
jsonVal, err := json.Marshal(data)
if err != nil {
httpError(r, w, http.StatusInternalServerError, "Unable to marshal status json: %s", err)
return
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal)
err := json.NewEncoder(w).Encode(&struct {
Enabled bool `json:"enabled"`
}{
Enabled: d.Config.SafeSearchEnabled,
})
if err != nil {
httpError(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err)
return

View File

@@ -83,20 +83,21 @@ func processIPCIDRArray(dst *map[string]bool, dstIPNet *[]net.IPNet, src []strin
// Returns the item from the "disallowedClients" list that lead to blocking IP.
// If it returns TRUE and an empty string, it means that the "allowedClients" is not empty,
// but the ip does not belong to it.
func (a *accessCtx) IsBlockedIP(ip string) (bool, string) {
func (a *accessCtx) IsBlockedIP(ip net.IP) (bool, string) {
ipStr := ip.String()
a.lock.Lock()
defer a.lock.Unlock()
if len(a.allowedClients) != 0 || len(a.allowedClientsIPNet) != 0 {
_, ok := a.allowedClients[ip]
_, ok := a.allowedClients[ipStr]
if ok {
return false, ""
}
if len(a.allowedClientsIPNet) != 0 {
ipAddr := net.ParseIP(ip)
for _, ipnet := range a.allowedClientsIPNet {
if ipnet.Contains(ipAddr) {
if ipnet.Contains(ip) {
return false, ""
}
}
@@ -105,15 +106,14 @@ func (a *accessCtx) IsBlockedIP(ip string) (bool, string) {
return true, ""
}
_, ok := a.disallowedClients[ip]
_, ok := a.disallowedClients[ipStr]
if ok {
return true, ip
return true, ipStr
}
if len(a.disallowedClientsIPNet) != 0 {
ipAddr := net.ParseIP(ip)
for _, ipnet := range a.disallowedClientsIPNet {
if ipnet.Contains(ipAddr) {
if ipnet.Contains(ip) {
return true, ipnet.String()
}
}

View File

@@ -1,6 +1,7 @@
package dnsforward
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
@@ -8,44 +9,44 @@ import (
func TestIsBlockedIPAllowed(t *testing.T) {
a := &accessCtx{}
assert.True(t, a.Init([]string{"1.1.1.1", "2.2.0.0/16"}, nil, nil) == nil)
assert.Nil(t, a.Init([]string{"1.1.1.1", "2.2.0.0/16"}, nil, nil))
disallowed, disallowedRule := a.IsBlockedIP("1.1.1.1")
disallowed, disallowedRule := a.IsBlockedIP(net.IPv4(1, 1, 1, 1))
assert.False(t, disallowed)
assert.Equal(t, "", disallowedRule)
assert.Empty(t, disallowedRule)
disallowed, disallowedRule = a.IsBlockedIP("1.1.1.2")
disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(1, 1, 1, 2))
assert.True(t, disallowed)
assert.Equal(t, "", disallowedRule)
assert.Empty(t, disallowedRule)
disallowed, disallowedRule = a.IsBlockedIP("2.2.1.1")
disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(2, 2, 1, 1))
assert.False(t, disallowed)
assert.Equal(t, "", disallowedRule)
assert.Empty(t, disallowedRule)
disallowed, disallowedRule = a.IsBlockedIP("2.3.1.1")
disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(2, 3, 1, 1))
assert.True(t, disallowed)
assert.Equal(t, "", disallowedRule)
assert.Empty(t, disallowedRule)
}
func TestIsBlockedIPDisallowed(t *testing.T) {
a := &accessCtx{}
assert.True(t, a.Init(nil, []string{"1.1.1.1", "2.2.0.0/16"}, nil) == nil)
assert.Nil(t, a.Init(nil, []string{"1.1.1.1", "2.2.0.0/16"}, nil))
disallowed, disallowedRule := a.IsBlockedIP("1.1.1.1")
disallowed, disallowedRule := a.IsBlockedIP(net.IPv4(1, 1, 1, 1))
assert.True(t, disallowed)
assert.Equal(t, "1.1.1.1", disallowedRule)
disallowed, disallowedRule = a.IsBlockedIP("1.1.1.2")
disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(1, 1, 1, 2))
assert.False(t, disallowed)
assert.Equal(t, "", disallowedRule)
assert.Empty(t, disallowedRule)
disallowed, disallowedRule = a.IsBlockedIP("2.2.1.1")
disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(2, 2, 1, 1))
assert.True(t, disallowed)
assert.Equal(t, "2.2.0.0/16", disallowedRule)
disallowed, disallowedRule = a.IsBlockedIP("2.3.1.1")
disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(2, 3, 1, 1))
assert.False(t, disallowed)
assert.Equal(t, "", disallowedRule)
assert.Empty(t, disallowedRule)
}
func TestIsBlockedIPBlockedDomain(t *testing.T) {
@@ -60,13 +61,13 @@ func TestIsBlockedIPBlockedDomain(t *testing.T) {
// match by "host2.com"
assert.True(t, a.IsBlockedDomain("host1"))
assert.True(t, a.IsBlockedDomain("host2"))
assert.True(t, !a.IsBlockedDomain("host3"))
assert.False(t, a.IsBlockedDomain("host3"))
// match by wildcard "*.host.com"
assert.True(t, !a.IsBlockedDomain("host.com"))
assert.False(t, a.IsBlockedDomain("host.com"))
assert.True(t, a.IsBlockedDomain("asdf.host.com"))
assert.True(t, a.IsBlockedDomain("qwer.asdf.host.com"))
assert.True(t, !a.IsBlockedDomain("asdf.zhost.com"))
assert.False(t, a.IsBlockedDomain("asdf.zhost.com"))
// match by wildcard "||host3.com^"
assert.True(t, a.IsBlockedDomain("host3.com"))

View File

@@ -0,0 +1,165 @@
package dnsforward
import (
"crypto/tls"
"fmt"
"path"
"strings"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/lucas-clemente/quic-go"
)
const maxDomainPartLen = 64
// ValidateClientID returns an error if clientID is not a valid client ID.
func ValidateClientID(clientID string) (err error) {
if len(clientID) > maxDomainPartLen {
return fmt.Errorf("client id %q is too long, max: %d", clientID, maxDomainPartLen)
}
for i, r := range clientID {
if (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '-' {
continue
}
return fmt.Errorf("invalid char %q at index %d in client id %q", r, i, clientID)
}
return nil
}
// clientIDFromClientServerName extracts and validates a client ID. hostSrvName
// is the server name of the host. cliSrvName is the server name as sent by the
// client. When strict is true, and client and host server name don't match,
// clientIDFromClientServerName will return an error.
func clientIDFromClientServerName(hostSrvName, cliSrvName string, strict bool) (clientID string, err error) {
if hostSrvName == cliSrvName {
return "", nil
}
if !strings.HasSuffix(cliSrvName, hostSrvName) {
if !strict {
return "", nil
}
return "", fmt.Errorf("client server name %q doesn't match host server name %q", cliSrvName, hostSrvName)
}
clientID = cliSrvName[:len(cliSrvName)-len(hostSrvName)-1]
err = ValidateClientID(clientID)
if err != nil {
return "", fmt.Errorf("invalid client id: %w", err)
}
return clientID, nil
}
// processClientIDHTTPS extracts the client's ID from the path of the
// client's DNS-over-HTTPS request.
func processClientIDHTTPS(ctx *dnsContext) (rc resultCode) {
pctx := ctx.proxyCtx
r := pctx.HTTPRequest
if r == nil {
ctx.err = fmt.Errorf("proxy ctx http request of proto %s is nil", pctx.Proto)
return resultCodeError
}
origPath := r.URL.Path
parts := strings.Split(path.Clean(origPath), "/")
if parts[0] == "" {
parts = parts[1:]
}
if len(parts) == 0 || parts[0] != "dns-query" {
ctx.err = fmt.Errorf("client id check: invalid path %q", origPath)
return resultCodeError
}
clientID := ""
switch len(parts) {
case 1:
// Just /dns-query, no client ID.
return resultCodeSuccess
case 2:
clientID = parts[1]
default:
ctx.err = fmt.Errorf("client id check: invalid path %q: extra parts", origPath)
return resultCodeError
}
err := ValidateClientID(clientID)
if err != nil {
ctx.err = fmt.Errorf("client id check: invalid client id: %w", err)
return resultCodeError
}
ctx.clientID = clientID
return resultCodeSuccess
}
// tlsConn is a narrow interface for *tls.Conn to simplify testing.
type tlsConn interface {
ConnectionState() (cs tls.ConnectionState)
}
// quicSession is a narrow interface for quic.Session to simplify testing.
type quicSession interface {
ConnectionState() (cs quic.ConnectionState)
}
// processClientID extracts the client's ID from the server name of the client's
// DOT or DOQ request or the path of the client's DOH.
func processClientID(dctx *dnsContext) (rc resultCode) {
pctx := dctx.proxyCtx
proto := pctx.Proto
if proto == proxy.ProtoHTTPS {
return processClientIDHTTPS(dctx)
} else if proto != proxy.ProtoTLS && proto != proxy.ProtoQUIC {
return resultCodeSuccess
}
srvConf := dctx.srv.conf
hostSrvName := srvConf.TLSConfig.ServerName
if hostSrvName == "" {
return resultCodeSuccess
}
cliSrvName := ""
if proto == proxy.ProtoTLS {
conn := pctx.Conn
tc, ok := conn.(tlsConn)
if !ok {
dctx.err = fmt.Errorf("proxy ctx conn of proto %s is %T, want *tls.Conn", proto, conn)
return resultCodeError
}
cliSrvName = tc.ConnectionState().ServerName
} else if proto == proxy.ProtoQUIC {
qs, ok := pctx.QUICSession.(quicSession)
if !ok {
dctx.err = fmt.Errorf("proxy ctx quic session of proto %s is %T, want quic.Session", proto, pctx.QUICSession)
return resultCodeError
}
cliSrvName = qs.ConnectionState().ServerName
}
clientID, err := clientIDFromClientServerName(hostSrvName, cliSrvName, srvConf.StrictSNICheck)
if err != nil {
dctx.err = fmt.Errorf("client id check: %w", err)
return resultCodeError
}
dctx.clientID = clientID
return resultCodeSuccess
}

View File

@@ -0,0 +1,273 @@
package dnsforward
import (
"crypto/tls"
"net"
"net/http"
"net/url"
"testing"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/lucas-clemente/quic-go"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// testTLSConn is a tlsConn for tests.
type testTLSConn struct {
// Conn is embedded here simply to make testTLSConn a net.Conn without
// acctually implementing all methods.
net.Conn
serverName string
}
// ConnectionState implements the tlsConn interface for testTLSConn.
func (c testTLSConn) ConnectionState() (cs tls.ConnectionState) {
cs.ServerName = c.serverName
return cs
}
// testQUICSession is a quicSession for tests.
type testQUICSession struct {
// Session is embedded here simply to make testQUICSession
// a quic.Session without acctually implementing all methods.
quic.Session
serverName string
}
// ConnectionState implements the quicSession interface for testQUICSession.
func (c testQUICSession) ConnectionState() (cs quic.ConnectionState) {
cs.ServerName = c.serverName
return cs
}
func TestProcessClientID(t *testing.T) {
testCases := []struct {
name string
proto string
hostSrvName string
cliSrvName string
wantClientID string
wantErrMsg string
wantRes resultCode
strictSNI bool
}{{
name: "udp",
proto: proxy.ProtoUDP,
hostSrvName: "",
cliSrvName: "",
wantClientID: "",
wantErrMsg: "",
wantRes: resultCodeSuccess,
strictSNI: false,
}, {
name: "tls_no_client_id",
proto: proxy.ProtoTLS,
hostSrvName: "example.com",
cliSrvName: "example.com",
wantClientID: "",
wantErrMsg: "",
wantRes: resultCodeSuccess,
strictSNI: true,
}, {
name: "tls_no_client_server_name",
proto: proxy.ProtoTLS,
hostSrvName: "example.com",
cliSrvName: "",
wantClientID: "",
wantErrMsg: `client id check: client server name "" ` +
`doesn't match host server name "example.com"`,
wantRes: resultCodeError,
strictSNI: true,
}, {
name: "tls_no_client_server_name_no_strict",
proto: proxy.ProtoTLS,
hostSrvName: "example.com",
cliSrvName: "",
wantClientID: "",
wantErrMsg: "",
wantRes: resultCodeSuccess,
strictSNI: false,
}, {
name: "tls_client_id",
proto: proxy.ProtoTLS,
hostSrvName: "example.com",
cliSrvName: "cli.example.com",
wantClientID: "cli",
wantErrMsg: "",
wantRes: resultCodeSuccess,
strictSNI: true,
}, {
name: "tls_client_id_hostname_error",
proto: proxy.ProtoTLS,
hostSrvName: "example.com",
cliSrvName: "cli.example.net",
wantClientID: "",
wantErrMsg: `client id check: client server name "cli.example.net" ` +
`doesn't match host server name "example.com"`,
wantRes: resultCodeError,
strictSNI: true,
}, {
name: "tls_invalid_client_id",
proto: proxy.ProtoTLS,
hostSrvName: "example.com",
cliSrvName: "!!!.example.com",
wantClientID: "",
wantErrMsg: `client id check: invalid client id: invalid char '!' ` +
`at index 0 in client id "!!!"`,
wantRes: resultCodeError,
strictSNI: true,
}, {
name: "tls_client_id_too_long",
proto: proxy.ProtoTLS,
hostSrvName: "example.com",
cliSrvName: `abcdefghijklmnopqrstuvwxyz0123456789abcdefghijklmno` +
`pqrstuvwxyz0123456789.example.com`,
wantClientID: "",
wantErrMsg: `client id check: invalid client id: client id "abcdefghijklmno` +
`pqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz0123456789" ` +
`is too long, max: 64`,
wantRes: resultCodeError,
strictSNI: true,
}, {
name: "quic_client_id",
proto: proxy.ProtoQUIC,
hostSrvName: "example.com",
cliSrvName: "cli.example.com",
wantClientID: "cli",
wantErrMsg: "",
wantRes: resultCodeSuccess,
strictSNI: true,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tlsConf := TLSConfig{
ServerName: tc.hostSrvName,
StrictSNICheck: tc.strictSNI,
}
srv := &Server{
conf: ServerConfig{TLSConfig: tlsConf},
}
var conn net.Conn
if tc.proto == proxy.ProtoTLS {
conn = testTLSConn{
serverName: tc.cliSrvName,
}
}
var qs quic.Session
if tc.proto == proxy.ProtoQUIC {
qs = testQUICSession{
serverName: tc.cliSrvName,
}
}
dctx := &dnsContext{
srv: srv,
proxyCtx: &proxy.DNSContext{
Proto: tc.proto,
Conn: conn,
QUICSession: qs,
},
}
res := processClientID(dctx)
assert.Equal(t, tc.wantRes, res)
assert.Equal(t, tc.wantClientID, dctx.clientID)
if tc.wantErrMsg == "" {
assert.Nil(t, dctx.err)
} else {
require.NotNil(t, dctx.err)
assert.Equal(t, tc.wantErrMsg, dctx.err.Error())
}
})
}
}
func TestProcessClientID_https(t *testing.T) {
testCases := []struct {
name string
path string
wantClientID string
wantErrMsg string
wantRes resultCode
}{{
name: "no_client_id",
path: "/dns-query",
wantClientID: "",
wantErrMsg: "",
wantRes: resultCodeSuccess,
}, {
name: "no_client_id_slash",
path: "/dns-query/",
wantClientID: "",
wantErrMsg: "",
wantRes: resultCodeSuccess,
}, {
name: "client_id",
path: "/dns-query/cli",
wantClientID: "cli",
wantErrMsg: "",
wantRes: resultCodeSuccess,
}, {
name: "client_id_slash",
path: "/dns-query/cli/",
wantClientID: "cli",
wantErrMsg: "",
wantRes: resultCodeSuccess,
}, {
name: "bad_url",
path: "/foo",
wantClientID: "",
wantErrMsg: `client id check: invalid path "/foo"`,
wantRes: resultCodeError,
}, {
name: "extra",
path: "/dns-query/cli/foo",
wantClientID: "",
wantErrMsg: `client id check: invalid path "/dns-query/cli/foo": extra parts`,
wantRes: resultCodeError,
}, {
name: "invalid_client_id",
path: "/dns-query/!!!",
wantClientID: "",
wantErrMsg: `client id check: invalid client id: invalid char '!'` +
` at index 0 in client id "!!!"`,
wantRes: resultCodeError,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
r := &http.Request{
URL: &url.URL{
Path: tc.path,
},
}
dctx := &dnsContext{
proxyCtx: &proxy.DNSContext{
Proto: proxy.ProtoHTTPS,
HTTPRequest: r,
},
}
res := processClientID(dctx)
assert.Equal(t, tc.wantRes, res)
assert.Equal(t, tc.wantClientID, dctx.clientID)
if tc.wantErrMsg == "" {
assert.Nil(t, dctx.err)
} else {
require.NotNil(t, dctx.err)
assert.Equal(t, tc.wantErrMsg, dctx.err.Error())
}
})
}
}

View File

@@ -24,22 +24,22 @@ type FilteringConfig struct {
// Callbacks for other modules
// --
// Filtering callback function
FilterHandler func(clientAddr string, settings *dnsfilter.RequestFilteringSettings) `yaml:"-"`
// FilterHandler is an optional additional filtering callback.
FilterHandler func(clientAddr net.IP, clientID string, settings *dnsfilter.RequestFilteringSettings) `yaml:"-"`
// GetCustomUpstreamByClient - a callback function that returns upstreams configuration
// based on the client IP address. Returns nil if there are no custom upstreams for the client
//
// TODO(e.burkov): Replace argument type with net.IP.
GetCustomUpstreamByClient func(clientAddr string) *proxy.UpstreamConfig `yaml:"-"`
// Protection configuration
// --
ProtectionEnabled bool `yaml:"protection_enabled"` // whether or not use any of dnsfilter features
BlockingMode string `yaml:"blocking_mode"` // mode how to answer filtered requests
BlockingIPv4 string `yaml:"blocking_ipv4"` // IP address to be returned for a blocked A request
BlockingIPv6 string `yaml:"blocking_ipv6"` // IP address to be returned for a blocked AAAA request
BlockingIPAddrv4 net.IP `yaml:"-"`
BlockingIPAddrv6 net.IP `yaml:"-"`
ProtectionEnabled bool `yaml:"protection_enabled"` // whether or not use any of dnsfilter features
BlockingMode string `yaml:"blocking_mode"` // mode how to answer filtered requests
BlockingIPv4 net.IP `yaml:"blocking_ipv4"` // IP address to be returned for a blocked A request
BlockingIPv6 net.IP `yaml:"blocking_ipv6"` // IP address to be returned for a blocked AAAA request
BlockedResponseTTL uint32 `yaml:"blocked_response_ttl"` // if 0, then default is used (3600)
// IP (or domain name) which is used to respond to DNS requests blocked by parental control or safe-browsing
@@ -110,6 +110,10 @@ type TLSConfig struct {
CertificateChainData []byte `yaml:"-" json:"-"`
PrivateKeyData []byte `yaml:"-" json:"-"`
// ServerName is the hostname of the server. Currently, it is only
// being used for client ID checking.
ServerName string `yaml:"-" json:"-"`
cert tls.Certificate
// DNS names from certificate (SAN) or CN value from Subject
dnsNames []string
@@ -278,7 +282,7 @@ func (s *Server) prepareUpstreamSettings() error {
}
if len(upstreamConfig.Upstreams) == 0 {
log.Info("Warning: no default upstream servers specified, using %v", defaultDNS)
log.Info("warning: no default upstream servers specified, using %v", defaultDNS)
uc, err := proxy.ParseUpstreamsConfig(defaultDNS, s.conf.BootstrapDNS, DefaultTimeout)
if err != nil {
return fmt.Errorf("dns: failed to parse default upstreams: %v", err)
@@ -292,12 +296,13 @@ func (s *Server) prepareUpstreamSettings() error {
// prepareIntlProxy - initializes DNS proxy that we use for internal DNS queries
func (s *Server) prepareIntlProxy() {
intlProxyConfig := proxy.Config{
CacheEnabled: true,
CacheSizeBytes: 4096,
UpstreamConfig: s.conf.UpstreamConfig,
s.internalProxy = &proxy.Proxy{
Config: proxy.Config{
CacheEnabled: true,
CacheSizeBytes: 4096,
UpstreamConfig: s.conf.UpstreamConfig,
},
}
s.internalProxy = &proxy.Proxy{Config: intlProxyConfig}
}
// prepareTLS - prepares TLS configuration for the DNS proxy

View File

@@ -15,36 +15,69 @@ import (
// To transfer information between modules
type dnsContext struct {
srv *Server
proxyCtx *proxy.DNSContext
setts *dnsfilter.RequestFilteringSettings // filtering settings for this client
startTime time.Time
result *dnsfilter.Result
origResp *dns.Msg // response received from upstream servers. Set when response is modified by filtering
origQuestion dns.Question // question received from client. Set when Rewrites are used.
err error // error returned from the module
protectionEnabled bool // filtering is enabled, dnsfilter object is ready
responseFromUpstream bool // response is received from upstream servers
origReqDNSSEC bool // DNSSEC flag in the original request from user
srv *Server
proxyCtx *proxy.DNSContext
// setts are the filtering settings for the client.
setts *dnsfilter.RequestFilteringSettings
startTime time.Time
result *dnsfilter.Result
// origResp is the response received from upstream. It is set when the
// response is modified by filters.
origResp *dns.Msg
// err is the error returned from a processing function.
err error
// clientID is the clientID from DOH, DOQ, or DOT, if provided.
clientID string
// origQuestion is the question received from the client. It is set
// when the request is modified by rewrites.
origQuestion dns.Question
// protectionEnabled shows if the filtering is enabled, and if the
// server's DNS filter is ready.
protectionEnabled bool
// responseFromUpstream shows if the response is received from the
// upstream servers.
responseFromUpstream bool
// origReqDNSSEC shows if the DNSSEC flag in the original request from
// the client is set.
origReqDNSSEC bool
}
// resultCode is the result of a request processing function.
type resultCode int
const (
resultDone = iota // module has completed its job, continue
resultFinish // module has completed its job, exit normally
resultError // an error occurred, exit with an error
// resultCodeSuccess is returned when a handler performed successfully,
// and the next handler must be called.
resultCodeSuccess resultCode = iota
// resultCodeFinish is returned when a handler performed successfully,
// and the processing of the request must be stopped.
resultCodeFinish
// resultCodeError is returned when a handler failed, and the processing
// of the request must be stopped.
resultCodeError
)
// handleDNSRequest filters the incoming DNS requests and writes them to the query log
func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
ctx := &dnsContext{srv: s, proxyCtx: d}
ctx.result = &dnsfilter.Result{}
ctx.startTime = time.Now()
ctx := &dnsContext{
srv: s,
proxyCtx: d,
result: &dnsfilter.Result{},
startTime: time.Now(),
}
type modProcessFunc func(ctx *dnsContext) int
type modProcessFunc func(ctx *dnsContext) (rc resultCode)
// Since (*dnsforward.Server).handleDNSRequest(...) is used as
// proxy.(Config).RequestHandler, there is no need for additional index
// out of range checking in any of the following functions, because the
// (*proxy.Proxy).handleDNSRequest method performs it before calling the
// appropriate handler.
mods := []modProcessFunc{
processInitial,
processInternalHosts,
processInternalIPAddrs,
processClientID,
processFilteringBeforeRequest,
processUpstream,
processDNSSECAfterResponse,
@@ -55,13 +88,13 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
for _, process := range mods {
r := process(ctx)
switch r {
case resultDone:
case resultCodeSuccess:
// continue: call the next filter
case resultFinish:
case resultCodeFinish:
return nil
case resultError:
case resultCodeError:
return ctx.err
}
}
@@ -73,12 +106,12 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
}
// Perform initial checks; process WHOIS & rDNS
func processInitial(ctx *dnsContext) int {
func processInitial(ctx *dnsContext) (rc resultCode) {
s := ctx.srv
d := ctx.proxyCtx
if s.conf.AAAADisabled && d.Req.Question[0].Qtype == dns.TypeAAAA {
_ = proxy.CheckDisabledAAAARequest(d, true)
return resultFinish
return resultCodeFinish
}
if s.conf.OnDNSRequest != nil {
@@ -90,10 +123,10 @@ func processInitial(ctx *dnsContext) int {
if (d.Req.Question[0].Qtype == dns.TypeA || d.Req.Question[0].Qtype == dns.TypeAAAA) &&
d.Req.Question[0].Name == "use-application-dns.net." {
d.Res = s.genNXDomain(d.Req)
return resultFinish
return resultCodeFinish
}
return resultDone
return resultCodeSuccess
}
// Return TRUE if host names doesn't contain disallowed characters
@@ -151,32 +184,32 @@ func (s *Server) onDHCPLeaseChanged(flags int) {
}
// Respond to A requests if the target host name is associated with a lease from our DHCP server
func processInternalHosts(ctx *dnsContext) int {
func processInternalHosts(ctx *dnsContext) (rc resultCode) {
s := ctx.srv
req := ctx.proxyCtx.Req
if !(req.Question[0].Qtype == dns.TypeA || req.Question[0].Qtype == dns.TypeAAAA) {
return resultDone
return resultCodeSuccess
}
host := req.Question[0].Name
host = strings.ToLower(host)
if !strings.HasSuffix(host, ".lan.") {
return resultDone
return resultCodeSuccess
}
host = strings.TrimSuffix(host, ".lan.")
s.tableHostToIPLock.Lock()
if s.tableHostToIP == nil {
s.tableHostToIPLock.Unlock()
return resultDone
return resultCodeSuccess
}
ip, ok := s.tableHostToIP[host]
s.tableHostToIPLock.Unlock()
if !ok {
return resultDone
return resultCodeSuccess
}
log.Debug("DNS: internal record: %s -> %s", req.Question[0].Name, ip.String())
log.Debug("DNS: internal record: %s -> %s", req.Question[0].Name, ip)
resp := s.makeResponse(req)
@@ -194,15 +227,15 @@ func processInternalHosts(ctx *dnsContext) int {
}
ctx.proxyCtx.Res = resp
return resultDone
return resultCodeSuccess
}
// Respond to PTR requests if the target IP address is leased by our DHCP server
func processInternalIPAddrs(ctx *dnsContext) int {
func processInternalIPAddrs(ctx *dnsContext) (rc resultCode) {
s := ctx.srv
req := ctx.proxyCtx.Req
if req.Question[0].Qtype != dns.TypePTR {
return resultDone
return resultCodeSuccess
}
arpa := req.Question[0].Name
@@ -210,18 +243,18 @@ func processInternalIPAddrs(ctx *dnsContext) int {
arpa = strings.ToLower(arpa)
ip := util.DNSUnreverseAddr(arpa)
if ip == nil {
return resultDone
return resultCodeSuccess
}
s.tablePTRLock.Lock()
if s.tablePTR == nil {
s.tablePTRLock.Unlock()
return resultDone
return resultCodeSuccess
}
host, ok := s.tablePTR[ip.String()]
s.tablePTRLock.Unlock()
if !ok {
return resultDone
return resultCodeSuccess
}
log.Debug("DNS: reverse-lookup: %s -> %s", arpa, host)
@@ -237,16 +270,16 @@ func processInternalIPAddrs(ctx *dnsContext) int {
ptr.Ptr = host + "."
resp.Answer = append(resp.Answer, ptr)
ctx.proxyCtx.Res = resp
return resultDone
return resultCodeSuccess
}
// Apply filtering logic
func processFilteringBeforeRequest(ctx *dnsContext) int {
func processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode) {
s := ctx.srv
d := ctx.proxyCtx
if d.Res != nil {
return resultDone // response is already set - nothing to do
return resultCodeSuccess // response is already set - nothing to do
}
s.RLock()
@@ -260,28 +293,28 @@ func processFilteringBeforeRequest(ctx *dnsContext) int {
var err error
ctx.protectionEnabled = s.conf.ProtectionEnabled && s.dnsFilter != nil
if ctx.protectionEnabled {
ctx.setts = s.getClientRequestFilteringSettings(d)
ctx.setts = s.getClientRequestFilteringSettings(ctx)
ctx.result, err = s.filterDNSRequest(ctx)
}
s.RUnlock()
if err != nil {
ctx.err = err
return resultError
return resultCodeError
}
return resultDone
return resultCodeSuccess
}
// Pass request to upstream servers; process the response
func processUpstream(ctx *dnsContext) int {
// processUpstream passes request to upstream servers and handles the response.
func processUpstream(ctx *dnsContext) (rc resultCode) {
s := ctx.srv
d := ctx.proxyCtx
if d.Res != nil {
return resultDone // response is already set - nothing to do
return resultCodeSuccess // response is already set - nothing to do
}
if d.Addr != nil && s.conf.GetCustomUpstreamByClient != nil {
clientIP := ipFromAddr(d.Addr)
clientIP := IPStringFromAddr(d.Addr)
upstreamsConf := s.conf.GetCustomUpstreamByClient(clientIP)
if upstreamsConf != nil {
log.Debug("Using custom upstreams for %s", clientIP)
@@ -305,26 +338,26 @@ func processUpstream(ctx *dnsContext) int {
err := s.dnsProxy.Resolve(d)
if err != nil {
ctx.err = err
return resultError
return resultCodeError
}
ctx.responseFromUpstream = true
return resultDone
return resultCodeSuccess
}
// Process DNSSEC after response from upstream server
func processDNSSECAfterResponse(ctx *dnsContext) int {
func processDNSSECAfterResponse(ctx *dnsContext) (rc resultCode) {
d := ctx.proxyCtx
if !ctx.responseFromUpstream || // don't process response if it's not from upstream servers
!ctx.srv.conf.EnableDNSSEC {
return resultDone
return resultCodeSuccess
}
if !ctx.origReqDNSSEC {
optResp := d.Res.IsEdns0()
if optResp != nil && !optResp.Do() {
return resultDone
return resultCodeSuccess
}
// Remove RRSIG records from response
@@ -355,19 +388,19 @@ func processDNSSECAfterResponse(ctx *dnsContext) int {
d.Res.Ns = answers
}
return resultDone
return resultCodeSuccess
}
// Apply filtering logic after we have received response from upstream servers
func processFilteringAfterResponse(ctx *dnsContext) int {
func processFilteringAfterResponse(ctx *dnsContext) (rc resultCode) {
s := ctx.srv
d := ctx.proxyCtx
res := ctx.result
var err error
switch res.Reason {
case dnsfilter.ReasonRewrite,
dnsfilter.DNSRewriteRule:
case dnsfilter.Rewritten,
dnsfilter.RewrittenRule:
if len(ctx.origQuestion.Name) == 0 {
// origQuestion is set in case we get only CNAME without IP from rewrites table
@@ -379,7 +412,7 @@ func processFilteringAfterResponse(ctx *dnsContext) int {
if len(d.Res.Answer) != 0 {
answer := []dns.RR{}
answer = append(answer, s.genCNAMEAnswer(d.Req, res.CanonName))
answer = append(answer, s.genAnswerCNAME(d.Req, res.CanonName))
answer = append(answer, d.Res.Answer...)
d.Res.Answer = answer
}
@@ -396,7 +429,7 @@ func processFilteringAfterResponse(ctx *dnsContext) int {
ctx.result, err = s.filterDNSResponse(ctx)
if err != nil {
ctx.err = err
return resultError
return resultCodeError
}
if ctx.result != nil {
ctx.origResp = origResp2 // matched by response
@@ -405,5 +438,5 @@ func processFilteringAfterResponse(ctx *dnsContext) int {
}
}
return resultDone
return resultCodeSuccess
}

View File

@@ -2,9 +2,11 @@
package dnsforward
import (
"errors"
"fmt"
"net"
"net/http"
"os"
"runtime"
"sync"
"time"
@@ -83,10 +85,11 @@ type DNSCreateParams struct {
// NewServer creates a new instance of the dnsforward.Server
// Note: this function must be called only once
func NewServer(p DNSCreateParams) *Server {
s := &Server{}
s.dnsFilter = p.DNSFilter
s.stats = p.Stats
s.queryLog = p.QueryLog
s := &Server{
dnsFilter: p.DNSFilter,
stats: p.Stats,
queryLog: p.QueryLog,
}
if p.DHCPServer != nil {
s.dhcpServer = p.DHCPServer
@@ -101,6 +104,16 @@ func NewServer(p DNSCreateParams) *Server {
return s
}
// NewCustomServer creates a new instance of *Server with custom internal proxy.
func NewCustomServer(internalProxy *proxy.Proxy) *Server {
s := &Server{}
if internalProxy != nil {
s.internalProxy = internalProxy
}
return s
}
// Close - close object
func (s *Server) Close() {
s.Lock()
@@ -108,6 +121,12 @@ func (s *Server) Close() {
s.stats = nil
s.queryLog = nil
s.dnsProxy = nil
err := s.ipset.Close()
if err != nil {
log.Error("closing ipset: %s", err)
}
s.Unlock()
}
@@ -155,15 +174,15 @@ func (s *Server) Exchange(req *dns.Msg) (*dns.Msg, error) {
return ctx.Res, nil
}
// Start starts the DNS server
// Start starts the DNS server.
func (s *Server) Start() error {
s.Lock()
defer s.Unlock()
return s.startInternal()
return s.startLocked()
}
// startInternal starts without locking
func (s *Server) startInternal() error {
// startLocked starts the DNS server without locking. For internal use only.
func (s *Server) startLocked() error {
err := s.dnsProxy.Start()
if err == nil {
s.isRunning = true
@@ -178,9 +197,7 @@ func (s *Server) Prepare(config *ServerConfig) error {
if config != nil {
s.conf = *config
if s.conf.BlockingMode == "custom_ip" {
s.conf.BlockingIPAddrv4 = net.ParseIP(s.conf.BlockingIPv4)
s.conf.BlockingIPAddrv6 = net.ParseIP(s.conf.BlockingIPv6)
if s.conf.BlockingIPAddrv4 == nil || s.conf.BlockingIPAddrv6 == nil {
if s.conf.BlockingIPv4 == nil || s.conf.BlockingIPv6 == nil {
return fmt.Errorf("dns: invalid custom blocking IP address specified")
}
}
@@ -192,11 +209,27 @@ func (s *Server) Prepare(config *ServerConfig) error {
// Initialize IPSET configuration
// --
s.ipset.init(s.conf.IPSETList)
err := s.ipset.init(s.conf.IPSETList)
if err != nil {
if !errors.Is(err, os.ErrInvalid) && !errors.Is(err, os.ErrPermission) {
return fmt.Errorf("cannot initialize ipset: %w", err)
}
// ipset cannot currently be initialized if the server was
// installed from Snap or when the user or the binary doesn't
// have the required permissions, or when the kernel doesn't
// support netfilter.
//
// Log and go on.
//
// TODO(a.garipov): The Snap problem can probably be solved if
// we add the netlink-connector interface plug.
log.Error("cannot initialize ipset: %s", err)
}
// Prepare DNS servers settings
// --
err := s.prepareUpstreamSettings()
err = s.prepareUpstreamSettings()
if err != nil {
return err
}
@@ -234,15 +267,15 @@ func (s *Server) Prepare(config *ServerConfig) error {
return nil
}
// Stop stops the DNS server
// Stop stops the DNS server.
func (s *Server) Stop() error {
s.Lock()
defer s.Unlock()
return s.stopInternal()
return s.stopLocked()
}
// stopInternal stops without locking
func (s *Server) stopInternal() error {
// stopLocked stops the DNS server without locking. For internal use only.
func (s *Server) stopLocked() error {
if s.dnsProxy != nil {
err := s.dnsProxy.Stop()
if err != nil {
@@ -267,7 +300,7 @@ func (s *Server) Reconfigure(config *ServerConfig) error {
defer s.Unlock()
log.Print("Start reconfiguring the server")
err := s.stopInternal()
err := s.stopLocked()
if err != nil {
return fmt.Errorf("could not reconfigure the server: %w", err)
}
@@ -281,7 +314,7 @@ func (s *Server) Reconfigure(config *ServerConfig) error {
return fmt.Errorf("could not reconfigure the server: %w", err)
}
err = s.startInternal()
err = s.startLocked()
if err != nil {
return fmt.Errorf("could not reconfigure the server: %w", err)
}
@@ -300,6 +333,6 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
// IsBlockedIP - return TRUE if this client should be blocked
func (s *Server) IsBlockedIP(ip string) (bool, string) {
func (s *Server) IsBlockedIP(ip net.IP) (bool, string) {
return s.access.IsBlockedIP(ip)
}

File diff suppressed because it is too large Load Diff

View File

@@ -13,27 +13,55 @@ import (
)
// filterDNSRewriteResponse handles a single DNS rewrite response entry.
// It returns the constructed answer resource record.
// It returns the properly constructed answer resource record.
func (s *Server) filterDNSRewriteResponse(req *dns.Msg, rr rules.RRType, v rules.RRValue) (ans dns.RR, err error) {
// TODO(a.garipov): As more types are added, we will probably want to
// use a handler-oriented approach here. So, think of a way to decouple
// the answer generation logic from the Server.
switch rr {
case dns.TypeA, dns.TypeAAAA:
ip, ok := v.(net.IP)
if !ok {
return nil, fmt.Errorf("value has type %T, not net.IP", v)
return nil, fmt.Errorf("value for rr type %d has type %T, not net.IP", rr, v)
}
if rr == dns.TypeA {
return s.genAAnswer(req, ip.To4()), nil
return s.genAnswerA(req, ip.To4()), nil
}
return s.genAAAAAnswer(req, ip), nil
case dns.TypeTXT:
return s.genAnswerAAAA(req, ip), nil
case dns.TypePTR,
dns.TypeTXT:
str, ok := v.(string)
if !ok {
return nil, fmt.Errorf("value has type %T, not string", v)
return nil, fmt.Errorf("value for rr type %d has type %T, not string", rr, v)
}
return s.genTXTAnswer(req, []string{str}), nil
if rr == dns.TypeTXT {
return s.genAnswerTXT(req, []string{str}), nil
}
return s.genAnswerPTR(req, str), nil
case dns.TypeMX:
mx, ok := v.(*rules.DNSMX)
if !ok {
return nil, fmt.Errorf("value for rr type %d has type %T, not *rules.DNSMX", rr, v)
}
return s.genAnswerMX(req, mx), nil
case dns.TypeHTTPS,
dns.TypeSVCB:
svcb, ok := v.(*rules.DNSSVCB)
if !ok {
return nil, fmt.Errorf("value for rr type %d has type %T, not *rules.DNSSVCB", rr, v)
}
if rr == dns.TypeHTTPS {
return s.genAnswerHTTPS(req, svcb), nil
}
return s.genAnswerSVCB(req, svcb), nil
default:
log.Debug("don't know how to handle dns rr type %d, skipping", rr)

View File

@@ -0,0 +1,178 @@
package dnsforward
import (
"net"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/dnsfilter"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
)
func TestServer_FilterDNSRewrite(t *testing.T) {
// Helper data.
ip4 := net.IP{127, 0, 0, 1}
ip6 := net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
mx := &rules.DNSMX{
Exchange: "mail.example.com",
Preference: 32,
}
svcb := &rules.DNSSVCB{
Params: map[string]string{"alpn": "h3"},
Target: "example.com",
Priority: 32,
}
const domain = "example.com"
// Helper functions and entities.
srv := &Server{}
makeQ := func(qtype rules.RRType) (req *dns.Msg) {
return &dns.Msg{
Question: []dns.Question{{
Qtype: qtype,
}},
}
}
makeRes := func(rcode rules.RCode, rr rules.RRType, v rules.RRValue) (res dnsfilter.Result) {
resp := dnsfilter.DNSRewriteResultResponse{
rr: []rules.RRValue{v},
}
return dnsfilter.Result{
DNSRewriteResult: &dnsfilter.DNSRewriteResult{
RCode: rcode,
Response: resp,
},
}
}
// Tests.
t.Run("nxdomain", func(t *testing.T) {
req := makeQ(dns.TypeA)
res := makeRes(dns.RcodeNameError, 0, nil)
d := &proxy.DNSContext{}
err := srv.filterDNSRewrite(req, res, d)
assert.Nil(t, err)
assert.Equal(t, dns.RcodeNameError, d.Res.Rcode)
})
t.Run("noerror_empty", func(t *testing.T) {
req := makeQ(dns.TypeA)
res := makeRes(dns.RcodeSuccess, 0, nil)
d := &proxy.DNSContext{}
err := srv.filterDNSRewrite(req, res, d)
assert.Nil(t, err)
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
assert.Empty(t, d.Res.Answer)
})
t.Run("noerror_a", func(t *testing.T) {
req := makeQ(dns.TypeA)
res := makeRes(dns.RcodeSuccess, dns.TypeA, ip4)
d := &proxy.DNSContext{}
err := srv.filterDNSRewrite(req, res, d)
assert.Nil(t, err)
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
if assert.Len(t, d.Res.Answer, 1) {
assert.Equal(t, ip4, d.Res.Answer[0].(*dns.A).A)
}
})
t.Run("noerror_aaaa", func(t *testing.T) {
req := makeQ(dns.TypeAAAA)
res := makeRes(dns.RcodeSuccess, dns.TypeAAAA, ip6)
d := &proxy.DNSContext{}
err := srv.filterDNSRewrite(req, res, d)
assert.Nil(t, err)
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
if assert.Len(t, d.Res.Answer, 1) {
assert.Equal(t, ip6, d.Res.Answer[0].(*dns.AAAA).AAAA)
}
})
t.Run("noerror_ptr", func(t *testing.T) {
req := makeQ(dns.TypePTR)
res := makeRes(dns.RcodeSuccess, dns.TypePTR, domain)
d := &proxy.DNSContext{}
err := srv.filterDNSRewrite(req, res, d)
assert.Nil(t, err)
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
if assert.Len(t, d.Res.Answer, 1) {
assert.Equal(t, domain, d.Res.Answer[0].(*dns.PTR).Ptr)
}
})
t.Run("noerror_txt", func(t *testing.T) {
req := makeQ(dns.TypeTXT)
res := makeRes(dns.RcodeSuccess, dns.TypeTXT, domain)
d := &proxy.DNSContext{}
err := srv.filterDNSRewrite(req, res, d)
assert.Nil(t, err)
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
if assert.Len(t, d.Res.Answer, 1) {
assert.Equal(t, []string{domain}, d.Res.Answer[0].(*dns.TXT).Txt)
}
})
t.Run("noerror_mx", func(t *testing.T) {
req := makeQ(dns.TypeMX)
res := makeRes(dns.RcodeSuccess, dns.TypeMX, mx)
d := &proxy.DNSContext{}
err := srv.filterDNSRewrite(req, res, d)
assert.Nil(t, err)
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
if assert.Len(t, d.Res.Answer, 1) {
ans, ok := d.Res.Answer[0].(*dns.MX)
if assert.True(t, ok) {
assert.Equal(t, mx.Exchange, ans.Mx)
assert.Equal(t, mx.Preference, ans.Preference)
}
}
})
t.Run("noerror_svcb", func(t *testing.T) {
req := makeQ(dns.TypeSVCB)
res := makeRes(dns.RcodeSuccess, dns.TypeSVCB, svcb)
d := &proxy.DNSContext{}
err := srv.filterDNSRewrite(req, res, d)
assert.Nil(t, err)
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
if assert.Len(t, d.Res.Answer, 1) {
ans, ok := d.Res.Answer[0].(*dns.SVCB)
if assert.True(t, ok) {
assert.Equal(t, dns.SVCB_ALPN, ans.Value[0].Key())
assert.Equal(t, svcb.Params["alpn"], ans.Value[0].String())
assert.Equal(t, svcb.Target, ans.Target)
assert.Equal(t, svcb.Priority, ans.Priority)
}
}
})
t.Run("noerror_https", func(t *testing.T) {
req := makeQ(dns.TypeHTTPS)
res := makeRes(dns.RcodeSuccess, dns.TypeHTTPS, svcb)
d := &proxy.DNSContext{}
err := srv.filterDNSRewrite(req, res, d)
assert.Nil(t, err)
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
if assert.Len(t, d.Res.Answer, 1) {
ans, ok := d.Res.Answer[0].(*dns.HTTPS)
if assert.True(t, ok) {
assert.Equal(t, dns.SVCB_ALPN, ans.Value[0].Key())
assert.Equal(t, svcb.Params["alpn"], ans.Value[0].String())
assert.Equal(t, svcb.Target, ans.Target)
assert.Equal(t, svcb.Priority, ans.Priority)
}
}
})
}

View File

@@ -12,7 +12,7 @@ import (
)
func (s *Server) beforeRequestHandler(_ *proxy.Proxy, d *proxy.DNSContext) (bool, error) {
ip := ipFromAddr(d.Addr)
ip := IPFromAddr(d.Addr)
disallowed, _ := s.access.IsBlockedIP(ip)
if disallowed {
log.Tracef("Client IP %s is blocked by settings", ip)
@@ -30,15 +30,15 @@ func (s *Server) beforeRequestHandler(_ *proxy.Proxy, d *proxy.DNSContext) (bool
return true, nil
}
// getClientRequestFilteringSettings lookups client filtering settings
// using the client's IP address from the DNSContext
func (s *Server) getClientRequestFilteringSettings(d *proxy.DNSContext) *dnsfilter.RequestFilteringSettings {
// getClientRequestFilteringSettings looks up client filtering settings using
// the client's IP address and ID, if any, from ctx.
func (s *Server) getClientRequestFilteringSettings(ctx *dnsContext) *dnsfilter.RequestFilteringSettings {
setts := s.dnsFilter.GetConfig()
setts.FilteringEnabled = true
if s.conf.FilterHandler != nil {
clientAddr := ipFromAddr(d.Addr)
s.conf.FilterHandler(clientAddr, &setts)
s.conf.FilterHandler(IPFromAddr(ctx.proxyCtx.Addr), ctx.clientID, &setts)
}
return &setts
}
@@ -55,7 +55,7 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*dnsfilter.Result, error) {
} else if res.IsFiltered {
log.Tracef("Host %s is filtered, reason - %q, matched rule: %q", host, res.Reason, res.Rules[0].Text)
d.Res = s.genDNSFilterMessage(d, &res)
} else if res.Reason.In(dnsfilter.ReasonRewrite, dnsfilter.DNSRewriteRule) &&
} else if res.Reason.In(dnsfilter.Rewritten, dnsfilter.RewrittenRule) &&
res.CanonName != "" &&
len(res.IPList) == 0 {
// Resolve the new canonical name, not the original host
@@ -63,7 +63,7 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*dnsfilter.Result, error) {
// processFilteringAfterResponse.
ctx.origQuestion = d.Req.Question[0]
d.Req.Question[0].Name = dns.Fqdn(res.CanonName)
} else if res.Reason == dnsfilter.RewriteAutoHosts && len(res.ReverseHosts) != 0 {
} else if res.Reason == dnsfilter.RewrittenAutoHosts && len(res.ReverseHosts) != 0 {
resp := s.makeResponse(req)
for _, h := range res.ReverseHosts {
hdr := dns.RR_Header{
@@ -82,29 +82,29 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*dnsfilter.Result, error) {
}
d.Res = resp
} else if res.Reason == dnsfilter.ReasonRewrite || res.Reason == dnsfilter.RewriteAutoHosts {
} else if res.Reason == dnsfilter.Rewritten || res.Reason == dnsfilter.RewrittenAutoHosts {
resp := s.makeResponse(req)
name := host
if len(res.CanonName) != 0 {
resp.Answer = append(resp.Answer, s.genCNAMEAnswer(req, res.CanonName))
resp.Answer = append(resp.Answer, s.genAnswerCNAME(req, res.CanonName))
name = res.CanonName
}
for _, ip := range res.IPList {
if req.Question[0].Qtype == dns.TypeA {
a := s.genAAnswer(req, ip.To4())
a := s.genAnswerA(req, ip.To4())
a.Hdr.Name = dns.Fqdn(name)
resp.Answer = append(resp.Answer, a)
} else if req.Question[0].Qtype == dns.TypeAAAA {
a := s.genAAAAAnswer(req, ip)
a := s.genAnswerAAAA(req, ip)
a.Hdr.Name = dns.Fqdn(name)
resp.Answer = append(resp.Answer, a)
}
}
d.Res = resp
} else if res.Reason == dnsfilter.DNSRewriteRule {
} else if res.Reason == dnsfilter.RewrittenRule {
err = s.filterDNSRewrite(req, res, d)
if err != nil {
return nil, err

View File

@@ -8,6 +8,7 @@ import (
"strconv"
"strings"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/utils"
@@ -28,8 +29,8 @@ type dnsConfig struct {
ProtectionEnabled *bool `json:"protection_enabled"`
RateLimit *uint32 `json:"ratelimit"`
BlockingMode *string `json:"blocking_mode"`
BlockingIPv4 *string `json:"blocking_ipv4"`
BlockingIPv6 *string `json:"blocking_ipv6"`
BlockingIPv4 net.IP `json:"blocking_ipv4"`
BlockingIPv6 net.IP `json:"blocking_ipv6"`
EDNSCSEnabled *bool `json:"edns_cs_enabled"`
DNSSECEnabled *bool `json:"dnssec_enabled"`
DisableIPv6 *bool `json:"disable_ipv6"`
@@ -68,8 +69,8 @@ func (s *Server) getDNSConfig() dnsConfig {
Bootstraps: &bootstraps,
ProtectionEnabled: &protectionEnabled,
BlockingMode: &blockingMode,
BlockingIPv4: &BlockingIPv4,
BlockingIPv6: &BlockingIPv6,
BlockingIPv4: BlockingIPv4,
BlockingIPv6: BlockingIPv6,
RateLimit: &Ratelimit,
EDNSCSEnabled: &EnableEDNSClientSubnet,
DNSSECEnabled: &EnableDNSSEC,
@@ -100,17 +101,11 @@ func (req *dnsConfig) checkBlockingMode() bool {
bm := *req.BlockingMode
if bm == "custom_ip" {
if req.BlockingIPv4 == nil || req.BlockingIPv6 == nil {
if req.BlockingIPv4.To4() == nil {
return false
}
ip4 := net.ParseIP(*req.BlockingIPv4)
if ip4 == nil || ip4.To4() == nil {
return false
}
ip6 := net.ParseIP(*req.BlockingIPv6)
return ip6 != nil
return req.BlockingIPv6 != nil
}
for _, valid := range []string{
@@ -247,10 +242,8 @@ func (s *Server) setConfig(dc dnsConfig) (restart bool) {
if dc.BlockingMode != nil {
s.conf.BlockingMode = *dc.BlockingMode
if *dc.BlockingMode == "custom_ip" {
s.conf.BlockingIPv4 = *dc.BlockingIPv4
s.conf.BlockingIPAddrv4 = net.ParseIP(*dc.BlockingIPv4)
s.conf.BlockingIPv6 = *dc.BlockingIPv6
s.conf.BlockingIPAddrv6 = net.ParseIP(*dc.BlockingIPv6)
s.conf.BlockingIPv4 = dc.BlockingIPv4.To4()
s.conf.BlockingIPv6 = dc.BlockingIPv6.To16()
}
}
@@ -322,6 +315,11 @@ func ValidateUpstreams(upstreams []string) error {
return nil
}
_, err := proxy.ParseUpstreamsConfig(upstreams, []string{}, DefaultTimeout)
if err != nil {
return err
}
var defaultUpstreamFound bool
for _, u := range upstreams {
d, err := validateUpstream(u)
@@ -530,12 +528,20 @@ func (s *Server) handleDOH(w http.ResponseWriter, r *http.Request) {
}
func (s *Server) registerHandlers() {
s.conf.HTTPRegister("GET", "/control/dns_info", s.handleGetConfig)
s.conf.HTTPRegister("POST", "/control/dns_config", s.handleSetConfig)
s.conf.HTTPRegister("POST", "/control/test_upstream_dns", s.handleTestUpstreamDNS)
s.conf.HTTPRegister(http.MethodGet, "/control/dns_info", s.handleGetConfig)
s.conf.HTTPRegister(http.MethodPost, "/control/dns_config", s.handleSetConfig)
s.conf.HTTPRegister(http.MethodPost, "/control/test_upstream_dns", s.handleTestUpstreamDNS)
s.conf.HTTPRegister("GET", "/control/access/list", s.handleAccessList)
s.conf.HTTPRegister("POST", "/control/access/set", s.handleAccessSet)
s.conf.HTTPRegister(http.MethodGet, "/control/access/list", s.handleAccessList)
s.conf.HTTPRegister(http.MethodPost, "/control/access/set", s.handleAccessSet)
// Register both versions, with and without the trailing slash, to
// prevent a 301 Moved Permanently redirect when clients request the
// path without the trailing slash. Those redirects break some clients.
//
// See go doc net/http.ServeMux.
//
// See also https://github.com/AdguardTeam/AdGuardHome/issues/2628.
s.conf.HTTPRegister("", "/dns-query", s.handleDOH)
s.conf.HTTPRegister("", "/dns-query/", s.handleDOH)
}

View File

@@ -2,16 +2,35 @@ package dnsforward
import (
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/dnsfilter"
"github.com/stretchr/testify/assert"
)
func TestDNSForwardHTTTP_handleGetConfig(t *testing.T) {
s := createTestServer(t)
filterConf := &dnsfilter.Config{
SafeBrowsingEnabled: true,
SafeBrowsingCacheSize: 1000,
SafeSearchEnabled: true,
SafeSearchCacheSize: 1000,
ParentalCacheSize: 1000,
CacheTime: 30,
}
forwardConf := ServerConfig{
UDPListenAddr: &net.UDPAddr{},
TCPListenAddr: &net.TCPAddr{},
FilteringConfig: FilteringConfig{
ProtectionEnabled: true,
UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"},
},
ConfigModified: func() {},
}
s := createTestServer(t, filterConf, forwardConf)
err := s.Start()
assert.Nil(t, err)
defer assert.Nil(t, s.Stop())
@@ -35,6 +54,7 @@ func TestDNSForwardHTTTP_handleGetConfig(t *testing.T) {
conf: func() ServerConfig {
conf := defaultConf
conf.FastestAddr = true
return conf
},
want: "{\"upstream_dns\":[\"8.8.8.8:53\",\"8.8.4.4:53\"],\"upstream_dns_file\":\"\",\"bootstrap_dns\":[\"9.9.9.10\",\"149.112.112.10\",\"2620:fe::10\",\"2620:fe::fe:10\"],\"protection_enabled\":true,\"ratelimit\":0,\"blocking_mode\":\"\",\"blocking_ipv4\":\"\",\"blocking_ipv6\":\"\",\"edns_cs_enabled\":false,\"dnssec_enabled\":false,\"disable_ipv6\":false,\"upstream_mode\":\"fastest_addr\",\"cache_size\":0,\"cache_ttl_min\":0,\"cache_ttl_max\":0}\n",
@@ -43,6 +63,7 @@ func TestDNSForwardHTTTP_handleGetConfig(t *testing.T) {
conf: func() ServerConfig {
conf := defaultConf
conf.AllServers = true
return conf
},
want: "{\"upstream_dns\":[\"8.8.8.8:53\",\"8.8.4.4:53\"],\"upstream_dns_file\":\"\",\"bootstrap_dns\":[\"9.9.9.10\",\"149.112.112.10\",\"2620:fe::10\",\"2620:fe::fe:10\"],\"protection_enabled\":true,\"ratelimit\":0,\"blocking_mode\":\"\",\"blocking_ipv4\":\"\",\"blocking_ipv6\":\"\",\"edns_cs_enabled\":false,\"dnssec_enabled\":false,\"disable_ipv6\":false,\"upstream_mode\":\"parallel\",\"cache_size\":0,\"cache_ttl_min\":0,\"cache_ttl_max\":0}\n",
@@ -61,7 +82,24 @@ func TestDNSForwardHTTTP_handleGetConfig(t *testing.T) {
}
func TestDNSForwardHTTTP_handleSetConfig(t *testing.T) {
s := createTestServer(t)
filterConf := &dnsfilter.Config{
SafeBrowsingEnabled: true,
SafeBrowsingCacheSize: 1000,
SafeSearchEnabled: true,
SafeSearchCacheSize: 1000,
ParentalCacheSize: 1000,
CacheTime: 30,
}
forwardConf := ServerConfig{
UDPListenAddr: &net.UDPAddr{},
TCPListenAddr: &net.TCPAddr{},
FilteringConfig: FilteringConfig{
ProtectionEnabled: true,
UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"},
},
ConfigModified: func() {},
}
s := createTestServer(t, filterConf, forwardConf)
defaultConf := s.conf

View File

@@ -1,142 +0,0 @@
package dnsforward
import (
"net"
"strings"
"sync"
"github.com/AdguardTeam/AdGuardHome/internal/util"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
)
type ipsetCtx struct {
ipsetList map[string][]string // domain -> []ipset_name
ipsetCache map[[4]byte]bool // cache for IP[] to prevent duplicate calls to ipset program
ipsetMutex *sync.Mutex
ipset6Cache map[[16]byte]bool // cache for IP[] to prevent duplicate calls to ipset program
ipset6Mutex *sync.Mutex
}
// Convert configuration settings to an internal map
// DOMAIN[,DOMAIN].../IPSET1_NAME[,IPSET2_NAME]...
func (c *ipsetCtx) init(ipsetConfig []string) {
c.ipsetList = make(map[string][]string)
c.ipsetCache = make(map[[4]byte]bool)
c.ipsetMutex = &sync.Mutex{}
c.ipset6Cache = make(map[[16]byte]bool)
c.ipset6Mutex = &sync.Mutex{}
for _, it := range ipsetConfig {
it = strings.TrimSpace(it)
hostsAndNames := strings.Split(it, "/")
if len(hostsAndNames) != 2 {
log.Debug("IPSET: invalid value %q", it)
continue
}
ipsetNames := strings.Split(hostsAndNames[1], ",")
if len(ipsetNames) == 0 {
log.Debug("IPSET: invalid value %q", it)
continue
}
bad := false
for i := range ipsetNames {
ipsetNames[i] = strings.TrimSpace(ipsetNames[i])
if len(ipsetNames[i]) == 0 {
bad = true
break
}
}
if bad {
log.Debug("IPSET: invalid value %q", it)
continue
}
hosts := strings.Split(hostsAndNames[0], ",")
for _, host := range hosts {
host = strings.TrimSpace(host)
host = strings.ToLower(host)
if len(host) == 0 {
log.Debug("IPSET: invalid value %q", it)
continue
}
c.ipsetList[host] = ipsetNames
}
}
log.Debug("IPSET: added %d hosts", len(c.ipsetList))
}
func (c *ipsetCtx) getIP(rr dns.RR) net.IP {
switch a := rr.(type) {
case *dns.A:
var ip4 [4]byte
copy(ip4[:], a.A.To4())
c.ipsetMutex.Lock()
defer c.ipsetMutex.Unlock()
_, found := c.ipsetCache[ip4]
if found {
return nil // this IP was added before
}
c.ipsetCache[ip4] = false
return a.A
case *dns.AAAA:
var ip6 [16]byte
copy(ip6[:], a.AAAA)
c.ipset6Mutex.Lock()
defer c.ipset6Mutex.Unlock()
_, found := c.ipset6Cache[ip6]
if found {
return nil // this IP was added before
}
c.ipset6Cache[ip6] = false
return a.AAAA
default:
return nil
}
}
// Add IP addresses of the specified in configuration domain names to an ipset list
func (c *ipsetCtx) process(ctx *dnsContext) int {
req := ctx.proxyCtx.Req
if !(req.Question[0].Qtype == dns.TypeA ||
req.Question[0].Qtype == dns.TypeAAAA) ||
!ctx.responseFromUpstream {
return resultDone
}
host := req.Question[0].Name
host = strings.TrimSuffix(host, ".")
host = strings.ToLower(host)
ipsetNames, found := c.ipsetList[host]
if !found {
return resultDone
}
log.Debug("IPSET: found ipsets %v for host %s", ipsetNames, host)
for _, it := range ctx.proxyCtx.Res.Answer {
ip := c.getIP(it)
if ip == nil {
continue
}
ipStr := ip.String()
for _, name := range ipsetNames {
code, out, err := util.RunCommand("ipset", "add", name, ipStr)
if err != nil {
log.Info("IPSET: %s(%s) -> %s: %s", host, ipStr, name, err)
continue
}
if code != 0 {
log.Info("IPSET: ipset add: code:%d output:%q", code, out)
continue
}
log.Debug("IPSET: added %s(%s) -> %s", host, ipStr, name)
}
}
return resultDone
}

View File

@@ -0,0 +1,400 @@
// +build linux
package dnsforward
import (
"fmt"
"net"
"strings"
"sync"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/golibs/log"
"github.com/digineo/go-ipset/v2"
"github.com/mdlayher/netlink"
"github.com/miekg/dns"
"github.com/ti-mo/netfilter"
)
// TODO(a.garipov): Cover with unit tests as well as document how to test it
// manually. The original PR by @dsheets on Github contained an integration
// test, but unfortunately I didn't have the time to properly refactor it and
// check it in.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/2611.
// ipsetProps contains one Linux Netfilter ipset properties.
type ipsetProps struct {
name string
family netfilter.ProtoFamily
}
// ipsetCtx is the Linux Netfilter ipset context.
type ipsetCtx struct {
// mu protects all properties below.
mu *sync.Mutex
nameToIpset map[string]ipsetProps
domainToIpsets map[string][]ipsetProps
// TODO(a.garipov): Currently, the ipset list is static, and we don't
// read the IPs already in sets, so we can assume that all incoming IPs
// are either added to all corresponding ipsets or not. When that stops
// being the case, for example if we add dynamic reconfiguration of
// ipsets, this map will need to become a per-ipset-name one.
addedIPs map[[16]byte]struct{}
ipv4Conn *ipset.Conn
ipv6Conn *ipset.Conn
}
// dialNetfilter establishes connections to Linux's netfilter module.
func (c *ipsetCtx) dialNetfilter(config *netlink.Config) (err error) {
// The kernel API does not actually require two sockets but package
// github.com/digineo/go-ipset does.
//
// TODO(a.garipov): Perhaps we can ditch package ipset altogether and
// just use packages netfilter and netlink.
c.ipv4Conn, err = ipset.Dial(netfilter.ProtoIPv4, config)
if err != nil {
return fmt.Errorf("dialing v4: %w", err)
}
c.ipv6Conn, err = ipset.Dial(netfilter.ProtoIPv6, config)
if err != nil {
return fmt.Errorf("dialing v6: %w", err)
}
return nil
}
// ipsetProps returns the properties of an ipset with the given name.
func (c *ipsetCtx) ipsetProps(name string) (set ipsetProps, err error) {
// The family doesn't seem to matter when we use a header query, so
// query only the IPv4 one.
//
// TODO(a.garipov): Find out if this is a bug or a feature.
res, err := c.ipv4Conn.Header(name)
if err != nil {
return set, err
}
if res == nil || res.Family == nil {
return set, agherr.Error("empty response or no family data")
}
family := netfilter.ProtoFamily(res.Family.Value)
if family != netfilter.ProtoIPv4 && family != netfilter.ProtoIPv6 {
return set, fmt.Errorf("unexpected ipset family %s", family)
}
return ipsetProps{
name: name,
family: family,
}, nil
}
// ipsets returns currently known ipsets.
func (c *ipsetCtx) ipsets(names []string) (sets []ipsetProps, err error) {
for _, name := range names {
set, ok := c.nameToIpset[name]
if ok {
sets = append(sets, set)
continue
}
var err error
set, err = c.ipsetProps(name)
if err != nil {
return nil, fmt.Errorf("querying ipset %q: %w", name, err)
}
c.nameToIpset[name] = set
sets = append(sets, set)
}
return sets, nil
}
// parseIpsetConfig parses one ipset configuration string.
func parseIpsetConfig(cfgStr string) (hosts, ipsetNames []string, err error) {
cfgStr = strings.TrimSpace(cfgStr)
hostsAndNames := strings.Split(cfgStr, "/")
if len(hostsAndNames) != 2 {
return nil, nil, fmt.Errorf("invalid value %q: expected one slash", cfgStr)
}
hosts = strings.Split(hostsAndNames[0], ",")
ipsetNames = strings.Split(hostsAndNames[1], ",")
if len(ipsetNames) == 0 {
log.Info("ipset: resolutions for %q will not be stored", hosts)
return nil, nil, nil
}
for i := range ipsetNames {
ipsetNames[i] = strings.TrimSpace(ipsetNames[i])
if len(ipsetNames[i]) == 0 {
return nil, nil, fmt.Errorf("invalid value %q: empty ipset name", cfgStr)
}
}
for i := range hosts {
hosts[i] = strings.TrimSpace(hosts[i])
hosts[i] = strings.ToLower(hosts[i])
if len(hosts[i]) == 0 {
log.Info("ipset: root catchall in %q", ipsetNames)
}
}
return hosts, ipsetNames, nil
}
// init initializes the ipset context. It is not safe for concurrent use.
//
// TODO(a.garipov): Rewrite into a simple constructor?
func (c *ipsetCtx) init(ipsetConfig []string) (err error) {
c.mu = &sync.Mutex{}
c.nameToIpset = make(map[string]ipsetProps)
c.domainToIpsets = make(map[string][]ipsetProps)
c.addedIPs = make(map[[16]byte]struct{})
err = c.dialNetfilter(&netlink.Config{})
if err != nil {
return fmt.Errorf("ipset: dialing netfilter: %w", err)
}
for i, cfgStr := range ipsetConfig {
var hosts, ipsetNames []string
hosts, ipsetNames, err = parseIpsetConfig(cfgStr)
if err != nil {
return fmt.Errorf("ipset: config line at index %d: %w", i, err)
}
var ipsets []ipsetProps
ipsets, err = c.ipsets(ipsetNames)
if err != nil {
return fmt.Errorf("ipset: getting ipsets config line at index %d: %w", i, err)
}
for _, host := range hosts {
c.domainToIpsets[host] = append(c.domainToIpsets[host], ipsets...)
}
}
log.Debug("ipset: added %d domains for %d ipsets", len(c.domainToIpsets), len(c.nameToIpset))
return nil
}
// Close closes the Linux Netfilter connections.
func (c *ipsetCtx) Close() (err error) {
var errors []error
if c.ipv4Conn != nil {
err = c.ipv4Conn.Close()
if err != nil {
errors = append(errors, err)
}
}
if c.ipv6Conn != nil {
err = c.ipv6Conn.Close()
if err != nil {
errors = append(errors, err)
}
}
if len(errors) != 0 {
return agherr.Many("closing ipsets", errors...)
}
return nil
}
// ipFromRR returns an IP address from a DNS resource record.
func ipFromRR(rr dns.RR) (ip net.IP) {
switch a := rr.(type) {
case *dns.A:
return a.A
case *dns.AAAA:
return a.AAAA
default:
return nil
}
}
// lookupHost find the ipsets for the host, taking subdomain wildcards into
// account.
func (c *ipsetCtx) lookupHost(host string) (sets []ipsetProps) {
// Search for matching ipset hosts starting with most specific
// subdomain. We could use a trie here but the simple, inefficient
// solution isn't that expensive. ~75 % for 10 subdomains vs 0, but
// still sub-microsecond on a Core i7.
//
// TODO(a.garipov): Re-add benchmarks from the original PR.
for i := 0; i != -1; i++ {
host = host[i:]
sets = c.domainToIpsets[host]
if sets != nil {
return sets
}
i = strings.Index(host, ".")
if i == -1 {
break
}
}
// Check the root catch-all one.
return c.domainToIpsets[""]
}
// addIPs adds the IP addresses for the host to the ipset. set must be same
// family as set's family.
func (c *ipsetCtx) addIPs(host string, set ipsetProps, ips []net.IP) (err error) {
if len(ips) == 0 {
return
}
entries := make([]*ipset.Entry, 0, len(ips))
for _, ip := range ips {
entries = append(entries, ipset.NewEntry(ipset.EntryIP(ip)))
}
var conn *ipset.Conn
switch set.family {
case netfilter.ProtoIPv4:
conn = c.ipv4Conn
case netfilter.ProtoIPv6:
conn = c.ipv6Conn
default:
return fmt.Errorf("unexpected family %s for ipset %q", set.family, set.name)
}
err = conn.Add(set.name, entries...)
if err != nil {
return fmt.Errorf("adding %q%s to ipset %q: %w", host, ips, set.name, err)
}
log.Debug("ipset: added %s%s to ipset %s", host, ips, set.name)
return nil
}
// skipIpsetProcessing returns true when the ipset processing can be skipped for
// this request.
func (c *ipsetCtx) skipIpsetProcessing(ctx *dnsContext) (ok bool) {
if len(c.domainToIpsets) == 0 || ctx == nil || !ctx.responseFromUpstream {
return true
}
req := ctx.proxyCtx.Req
if req == nil || len(req.Question) == 0 {
return true
}
qt := req.Question[0].Qtype
return qt != dns.TypeA && qt != dns.TypeAAAA && qt != dns.TypeANY
}
// process adds the resolved IP addresses to the domain's ipsets, if any.
func (c *ipsetCtx) process(ctx *dnsContext) (rc resultCode) {
var err error
if c == nil {
return resultCodeSuccess
}
log.Debug("ipset: starting processing")
c.mu.Lock()
defer c.mu.Unlock()
if c.skipIpsetProcessing(ctx) {
log.Debug("ipset: skipped processing for request")
return resultCodeSuccess
}
req := ctx.proxyCtx.Req
host := req.Question[0].Name
host = strings.TrimSuffix(host, ".")
host = strings.ToLower(host)
sets := c.lookupHost(host)
if len(sets) == 0 {
log.Debug("ipset: no ipsets for host %s", host)
return resultCodeSuccess
}
log.Debug("ipset: found ipsets %+v for host %s", sets, host)
if ctx.proxyCtx.Res == nil {
return resultCodeSuccess
}
ans := ctx.proxyCtx.Res.Answer
l := len(ans)
v4s := make([]net.IP, 0, l)
v6s := make([]net.IP, 0, l)
for _, rr := range ans {
ip := ipFromRR(rr)
if ip == nil {
continue
}
var iparr [16]byte
copy(iparr[:], ip.To16())
if _, added := c.addedIPs[iparr]; added {
continue
}
if ip.To4() == nil {
v6s = append(v6s, ip)
continue
}
v4s = append(v4s, ip)
}
setLoop:
for _, set := range sets {
switch set.family {
case netfilter.ProtoIPv4:
err = c.addIPs(host, set, v4s)
if err != nil {
break setLoop
}
case netfilter.ProtoIPv6:
err = c.addIPs(host, set, v6s)
if err != nil {
break setLoop
}
default:
err = fmt.Errorf("unexpected family %s for ipset %q", set.family, set.name)
break setLoop
}
}
if err != nil {
log.Error("ipset: adding host ips: %s", err)
} else {
log.Debug("ipset: processed %d new ips", len(v4s)+len(v6s))
}
for _, ip := range v4s {
var iparr [16]byte
copy(iparr[:], ip.To16())
c.addedIPs[iparr] = struct{}{}
}
for _, ip := range v6s {
var iparr [16]byte
copy(iparr[:], ip.To16())
c.addedIPs[iparr] = struct{}{}
}
return resultCodeSuccess
}

View File

@@ -0,0 +1,26 @@
// +build !linux
package dnsforward
import (
"github.com/AdguardTeam/golibs/log"
)
type ipsetCtx struct{}
// init initializes the ipset context.
func (c *ipsetCtx) init(ipsetConfig []string) (err error) {
if len(ipsetConfig) != 0 {
log.Info("ipset: only available on linux")
}
return nil
}
// process adds the resolved IP addresses to the domain's ipsets, if any.
func (c *ipsetCtx) process(_ *dnsContext) (rc resultCode) {
return resultCodeSuccess
}
// Close closes the Linux Netfilter connections.
func (c *ipsetCtx) Close() (_ error) { return nil }

View File

@@ -1,41 +0,0 @@
package dnsforward
import (
"testing"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
)
func TestIPSET(t *testing.T) {
s := Server{}
s.conf.IPSETList = append(s.conf.IPSETList, "HOST.com/name")
s.conf.IPSETList = append(s.conf.IPSETList, "host2.com,host3.com/name23")
s.conf.IPSETList = append(s.conf.IPSETList, "host4.com/name4,name41")
c := ipsetCtx{}
c.init(s.conf.IPSETList)
assert.Equal(t, "name", c.ipsetList["host.com"][0])
assert.Equal(t, "name23", c.ipsetList["host2.com"][0])
assert.Equal(t, "name23", c.ipsetList["host3.com"][0])
assert.Equal(t, "name4", c.ipsetList["host4.com"][0])
assert.Equal(t, "name41", c.ipsetList["host4.com"][1])
_, ok := c.ipsetList["host0.com"]
assert.False(t, ok)
ctx := &dnsContext{
srv: &s,
}
ctx.proxyCtx = &proxy.DNSContext{}
ctx.proxyCtx.Req = &dns.Msg{
Question: []dns.Question{
{
Name: "host.com.",
Qtype: dns.TypeA,
},
},
}
assert.Equal(t, resultDone, c.process(ctx))
}

View File

@@ -7,10 +7,12 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/dnsfilter"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns"
)
// Create a DNS response by DNS request and set necessary flags
// makeResponse creates a DNS response by req and sets necessary flags. It also
// guarantees that req.Question will be not empty.
func (s *Server) makeResponse(req *dns.Msg) (resp *dns.Msg) {
resp = &dns.Msg{
MsgHdr: dns.MsgHdr{
@@ -58,9 +60,9 @@ func (s *Server) genDNSFilterMessage(d *proxy.DNSContext, result *dnsfilter.Resu
switch m.Question[0].Qtype {
case dns.TypeA:
return s.genARecord(m, s.conf.BlockingIPAddrv4)
return s.genARecord(m, s.conf.BlockingIPv4)
case dns.TypeAAAA:
return s.genAAAARecord(m, s.conf.BlockingIPAddrv6)
return s.genAAAARecord(m, s.conf.BlockingIPv6)
}
} else if s.conf.BlockingMode == "nxdomain" {
// means that we should return NXDOMAIN for any blocked request
@@ -92,48 +94,64 @@ func (s *Server) genServerFailure(request *dns.Msg) *dns.Msg {
func (s *Server) genARecord(request *dns.Msg, ip net.IP) *dns.Msg {
resp := s.makeResponse(request)
resp.Answer = append(resp.Answer, s.genAAnswer(request, ip))
resp.Answer = append(resp.Answer, s.genAnswerA(request, ip))
return resp
}
func (s *Server) genAAAARecord(request *dns.Msg, ip net.IP) *dns.Msg {
resp := s.makeResponse(request)
resp.Answer = append(resp.Answer, s.genAAAAAnswer(request, ip))
resp.Answer = append(resp.Answer, s.genAnswerAAAA(request, ip))
return resp
}
func (s *Server) genAAnswer(req *dns.Msg, ip net.IP) *dns.A {
answer := new(dns.A)
answer.Hdr = dns.RR_Header{
func (s *Server) hdr(req *dns.Msg, rrType rules.RRType) (h dns.RR_Header) {
return dns.RR_Header{
Name: req.Question[0].Name,
Rrtype: dns.TypeA,
Rrtype: rrType,
Ttl: s.conf.BlockedResponseTTL,
Class: dns.ClassINET,
}
answer.A = ip
return answer
}
func (s *Server) genAAAAAnswer(req *dns.Msg, ip net.IP) *dns.AAAA {
answer := new(dns.AAAA)
answer.Hdr = dns.RR_Header{
Name: req.Question[0].Name,
Rrtype: dns.TypeAAAA,
Ttl: s.conf.BlockedResponseTTL,
Class: dns.ClassINET,
func (s *Server) genAnswerA(req *dns.Msg, ip net.IP) (ans *dns.A) {
return &dns.A{
Hdr: s.hdr(req, dns.TypeA),
A: ip,
}
answer.AAAA = ip
return answer
}
func (s *Server) genTXTAnswer(req *dns.Msg, strs []string) (answer *dns.TXT) {
func (s *Server) genAnswerAAAA(req *dns.Msg, ip net.IP) (ans *dns.AAAA) {
return &dns.AAAA{
Hdr: s.hdr(req, dns.TypeAAAA),
AAAA: ip,
}
}
func (s *Server) genAnswerCNAME(req *dns.Msg, cname string) (ans *dns.CNAME) {
return &dns.CNAME{
Hdr: s.hdr(req, dns.TypeCNAME),
Target: dns.Fqdn(cname),
}
}
func (s *Server) genAnswerMX(req *dns.Msg, mx *rules.DNSMX) (ans *dns.MX) {
return &dns.MX{
Hdr: s.hdr(req, dns.TypePTR),
Preference: mx.Preference,
Mx: mx.Exchange,
}
}
func (s *Server) genAnswerPTR(req *dns.Msg, ptr string) (ans *dns.PTR) {
return &dns.PTR{
Hdr: s.hdr(req, dns.TypePTR),
Ptr: ptr,
}
}
func (s *Server) genAnswerTXT(req *dns.Msg, strs []string) (ans *dns.TXT) {
return &dns.TXT{
Hdr: dns.RR_Header{
Name: req.Question[0].Name,
Rrtype: dns.TypeTXT,
Ttl: s.conf.BlockedResponseTTL,
Class: dns.ClassINET,
},
Hdr: s.hdr(req, dns.TypeTXT),
Txt: strs,
}
}
@@ -198,19 +216,6 @@ func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, d *proxy.DNSCo
return resp
}
// Make a CNAME response
func (s *Server) genCNAMEAnswer(req *dns.Msg, cname string) *dns.CNAME {
answer := new(dns.CNAME)
answer.Hdr = dns.RR_Header{
Name: req.Question[0].Name,
Rrtype: dns.TypeCNAME,
Ttl: s.conf.BlockedResponseTTL,
Class: dns.ClassINET,
}
answer.Target = dns.Fqdn(cname)
return answer
}
// Create REFUSED DNS response
func (s *Server) makeResponseREFUSED(request *dns.Msg) *dns.Msg {
resp := dns.Msg{}

View File

@@ -1,7 +1,6 @@
package dnsforward
import (
"net"
"strings"
"time"
@@ -13,13 +12,13 @@ import (
)
// Write Stats data and logs
func processQueryLogsAndStats(ctx *dnsContext) int {
func processQueryLogsAndStats(ctx *dnsContext) (rc resultCode) {
elapsed := time.Since(ctx.startTime)
s := ctx.srv
d := ctx.proxyCtx
pctx := ctx.proxyCtx
shouldLog := true
msg := d.Req
msg := pctx.Req
// don't log ANY request if refuseAny is enabled
if len(msg.Question) >= 1 && msg.Question[0].Qtype == dns.TypeANY && s.conf.RefuseAny {
@@ -32,65 +31,67 @@ func processQueryLogsAndStats(ctx *dnsContext) int {
if shouldLog && s.queryLog != nil {
p := querylog.AddParams{
Question: msg,
Answer: d.Res,
Answer: pctx.Res,
OrigAnswer: ctx.origResp,
Result: ctx.result,
Elapsed: elapsed,
ClientIP: getIP(d.Addr),
ClientIP: IPFromAddr(pctx.Addr),
ClientID: ctx.clientID,
}
switch d.Proto {
switch pctx.Proto {
case proxy.ProtoHTTPS:
p.ClientProto = querylog.ClientProtoDOH
case proxy.ProtoQUIC:
p.ClientProto = querylog.ClientProtoDOQ
case proxy.ProtoTLS:
p.ClientProto = querylog.ClientProtoDOT
case proxy.ProtoDNSCrypt:
p.ClientProto = querylog.ClientProtoDNSCrypt
default:
// Consider this a plain DNS-over-UDP or DNS-over-TCL
// Consider this a plain DNS-over-UDP or DNS-over-TCP
// request.
}
if d.Upstream != nil {
p.Upstream = d.Upstream.Address()
if pctx.Upstream != nil {
p.Upstream = pctx.Upstream.Address()
}
s.queryLog.Add(p)
}
s.updateStats(d, elapsed, *ctx.result)
s.updateStats(ctx, elapsed, *ctx.result)
s.RUnlock()
return resultDone
return resultCodeSuccess
}
func (s *Server) updateStats(d *proxy.DNSContext, elapsed time.Duration, res dnsfilter.Result) {
func (s *Server) updateStats(ctx *dnsContext, elapsed time.Duration, res dnsfilter.Result) {
if s.stats == nil {
return
}
pctx := ctx.proxyCtx
e := stats.Entry{}
e.Domain = strings.ToLower(d.Req.Question[0].Name)
e.Domain = strings.ToLower(pctx.Req.Question[0].Name)
e.Domain = e.Domain[:len(e.Domain)-1] // remove last "."
switch addr := d.Addr.(type) {
case *net.UDPAddr:
e.Client = addr.IP
case *net.TCPAddr:
e.Client = addr.IP
if clientID := ctx.clientID; clientID != "" {
e.Client = clientID
} else if ip := IPFromAddr(pctx.Addr); ip != nil {
e.Client = ip.String()
}
e.Time = uint32(elapsed / 1000)
e.Result = stats.RNotFiltered
switch res.Reason {
case dnsfilter.FilteredSafeBrowsing:
e.Result = stats.RSafeBrowsing
case dnsfilter.FilteredParental:
e.Result = stats.RParental
case dnsfilter.FilteredSafeSearch:
e.Result = stats.RSafeSearch
case dnsfilter.FilteredBlockList:
fallthrough
case dnsfilter.FilteredInvalid:

View File

@@ -0,0 +1,198 @@
package dnsforward
import (
"net"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/dnsfilter"
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
"github.com/AdguardTeam/AdGuardHome/internal/stats"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
)
// testQueryLog is a simple querylog.QueryLog implementation for tests.
type testQueryLog struct {
// QueryLog is embedded here simply to make testQueryLog
// a querylog.QueryLog without acctually implementing all methods.
querylog.QueryLog
lastParams querylog.AddParams
}
// Add implements the querylog.QueryLog interface for *testQueryLog.
func (l *testQueryLog) Add(p querylog.AddParams) {
l.lastParams = p
}
// testStats is a simple stats.Stats implementation for tests.
type testStats struct {
// Stats is embedded here simply to make testStats a stats.Stats without
// acctually implementing all methods.
stats.Stats
lastEntry stats.Entry
}
// Update implements the stats.Stats interface for *testStats.
func (l *testStats) Update(e stats.Entry) {
l.lastEntry = e
}
func TestProcessQueryLogsAndStats(t *testing.T) {
testCases := []struct {
name string
proto string
addr net.Addr
clientID string
wantLogProto querylog.ClientProto
wantStatClient string
wantCode resultCode
reason dnsfilter.Reason
wantStatResult stats.Result
}{{
name: "success_udp",
proto: proxy.ProtoUDP,
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
clientID: "",
wantLogProto: "",
wantStatClient: "1.2.3.4",
wantCode: resultCodeSuccess,
reason: dnsfilter.NotFilteredNotFound,
wantStatResult: stats.RNotFiltered,
}, {
name: "success_tls_client_id",
proto: proxy.ProtoTLS,
addr: &net.TCPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
clientID: "cli42",
wantLogProto: querylog.ClientProtoDOT,
wantStatClient: "cli42",
wantCode: resultCodeSuccess,
reason: dnsfilter.NotFilteredNotFound,
wantStatResult: stats.RNotFiltered,
}, {
name: "success_tls",
proto: proxy.ProtoTLS,
addr: &net.TCPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
clientID: "",
wantLogProto: querylog.ClientProtoDOT,
wantStatClient: "1.2.3.4",
wantCode: resultCodeSuccess,
reason: dnsfilter.NotFilteredNotFound,
wantStatResult: stats.RNotFiltered,
}, {
name: "success_quic",
proto: proxy.ProtoQUIC,
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
clientID: "",
wantLogProto: querylog.ClientProtoDOQ,
wantStatClient: "1.2.3.4",
wantCode: resultCodeSuccess,
reason: dnsfilter.NotFilteredNotFound,
wantStatResult: stats.RNotFiltered,
}, {
name: "success_https",
proto: proxy.ProtoHTTPS,
addr: &net.TCPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
clientID: "",
wantLogProto: querylog.ClientProtoDOH,
wantStatClient: "1.2.3.4",
wantCode: resultCodeSuccess,
reason: dnsfilter.NotFilteredNotFound,
wantStatResult: stats.RNotFiltered,
}, {
name: "success_dnscrypt",
proto: proxy.ProtoDNSCrypt,
addr: &net.TCPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
clientID: "",
wantLogProto: querylog.ClientProtoDNSCrypt,
wantStatClient: "1.2.3.4",
wantCode: resultCodeSuccess,
reason: dnsfilter.NotFilteredNotFound,
wantStatResult: stats.RNotFiltered,
}, {
name: "success_udp_filtered",
proto: proxy.ProtoUDP,
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
clientID: "",
wantLogProto: "",
wantStatClient: "1.2.3.4",
wantCode: resultCodeSuccess,
reason: dnsfilter.FilteredBlockList,
wantStatResult: stats.RFiltered,
}, {
name: "success_udp_sb",
proto: proxy.ProtoUDP,
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
clientID: "",
wantLogProto: "",
wantStatClient: "1.2.3.4",
wantCode: resultCodeSuccess,
reason: dnsfilter.FilteredSafeBrowsing,
wantStatResult: stats.RSafeBrowsing,
}, {
name: "success_udp_ss",
proto: proxy.ProtoUDP,
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
clientID: "",
wantLogProto: "",
wantStatClient: "1.2.3.4",
wantCode: resultCodeSuccess,
reason: dnsfilter.FilteredSafeSearch,
wantStatResult: stats.RSafeSearch,
}, {
name: "success_udp_pc",
proto: proxy.ProtoUDP,
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
clientID: "",
wantLogProto: "",
wantStatClient: "1.2.3.4",
wantCode: resultCodeSuccess,
reason: dnsfilter.FilteredParental,
wantStatResult: stats.RParental,
}}
ups, err := upstream.AddressToUpstream("1.1.1.1", upstream.Options{})
assert.Nil(t, err)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
req := &dns.Msg{
Question: []dns.Question{{
Name: "example.com.",
}},
}
pctx := &proxy.DNSContext{
Proto: tc.proto,
Req: req,
Res: &dns.Msg{},
Addr: tc.addr,
Upstream: ups,
}
ql := &testQueryLog{}
st := &testStats{}
dctx := &dnsContext{
srv: &Server{
queryLog: ql,
stats: st,
},
proxyCtx: pctx,
startTime: time.Now(),
result: &dnsfilter.Result{
Reason: tc.reason,
},
clientID: tc.clientID,
}
code := processQueryLogsAndStats(dctx)
assert.Equal(t, tc.wantCode, code)
assert.Equal(t, tc.wantLogProto, ql.lastParams.ClientProto)
assert.Equal(t, tc.wantStatClient, st.lastEntry.Client)
assert.Equal(t, tc.wantStatResult, st.lastEntry.Result)
})
}
}

View File

@@ -0,0 +1,168 @@
package dnsforward
import (
"encoding/base64"
"net"
"strconv"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns"
)
// genAnswerHTTPS returns a properly initialized HTTPS resource record.
//
// See the comment on genAnswerSVCB for a list of current restrictions on
// parameter values.
func (s *Server) genAnswerHTTPS(req *dns.Msg, svcb *rules.DNSSVCB) (ans *dns.HTTPS) {
ans = &dns.HTTPS{
SVCB: *s.genAnswerSVCB(req, svcb),
}
ans.Hdr.Rrtype = dns.TypeHTTPS
return ans
}
// strToSVCBKey is the string-to-svcb-key mapping.
//
// See https://github.com/miekg/dns/blob/23c4faca9d32b0abbb6e179aa1aadc45ac53a916/svcb.go#L27.
//
// TODO(a.garipov): Propose exporting this API or something similar in the
// github.com/miekg/dns module.
var strToSVCBKey = map[string]dns.SVCBKey{
"alpn": dns.SVCB_ALPN,
"echconfig": dns.SVCB_ECHCONFIG,
"ipv4hint": dns.SVCB_IPV4HINT,
"ipv6hint": dns.SVCB_IPV6HINT,
"mandatory": dns.SVCB_MANDATORY,
"no-default-alpn": dns.SVCB_NO_DEFAULT_ALPN,
"port": dns.SVCB_PORT,
}
// svcbKeyHandler is a handler for one SVCB parameter key.
type svcbKeyHandler func(valStr string) (val dns.SVCBKeyValue)
// svcbKeyHandlers are the supported SVCB parameters handlers.
var svcbKeyHandlers = map[string]svcbKeyHandler{
"alpn": func(valStr string) (val dns.SVCBKeyValue) {
return &dns.SVCBAlpn{
Alpn: []string{valStr},
}
},
"echconfig": func(valStr string) (val dns.SVCBKeyValue) {
ech, err := base64.StdEncoding.DecodeString(valStr)
if err != nil {
log.Debug("can't parse svcb/https echconfig: %s; ignoring", err)
return nil
}
return &dns.SVCBECHConfig{
ECH: ech,
}
},
"ipv4hint": func(valStr string) (val dns.SVCBKeyValue) {
ip := net.ParseIP(valStr)
if ip4 := ip.To4(); ip == nil || ip4 == nil {
log.Debug("can't parse svcb/https ipv4 hint %q; ignoring", valStr)
return nil
}
return &dns.SVCBIPv4Hint{
Hint: []net.IP{ip},
}
},
"ipv6hint": func(valStr string) (val dns.SVCBKeyValue) {
ip := net.ParseIP(valStr)
if ip == nil {
log.Debug("can't parse svcb/https ipv6 hint %q; ignoring", valStr)
return nil
}
return &dns.SVCBIPv6Hint{
Hint: []net.IP{ip},
}
},
"mandatory": func(valStr string) (val dns.SVCBKeyValue) {
code, ok := strToSVCBKey[valStr]
if !ok {
log.Debug("unknown svcb/https mandatory key %q, ignoring", valStr)
return nil
}
return &dns.SVCBMandatory{
Code: []dns.SVCBKey{code},
}
},
"no-default-alpn": func(_ string) (val dns.SVCBKeyValue) {
return &dns.SVCBNoDefaultAlpn{}
},
"port": func(valStr string) (val dns.SVCBKeyValue) {
port64, err := strconv.ParseUint(valStr, 10, 16)
if err != nil {
log.Debug("can't parse svcb/https port: %s; ignoring", err)
return nil
}
return &dns.SVCBPort{
Port: uint16(port64),
}
},
}
// genAnswerSVCB returns a properly initialized SVCB resource record.
//
// Currently, there are several restrictions on how the parameters are parsed.
// Firstly, the parsing of non-contiguous values isn't supported. Secondly, the
// parsing of value-lists is not supported either.
//
// ipv4hint=127.0.0.1 // Supported.
// ipv4hint="127.0.0.1" // Unsupported.
// ipv4hint=127.0.0.1,127.0.0.2 // Unsupported.
// ipv4hint="127.0.0.1,127.0.0.2" // Unsupported.
//
// TODO(a.garipov): Support all of these.
func (s *Server) genAnswerSVCB(req *dns.Msg, svcb *rules.DNSSVCB) (ans *dns.SVCB) {
ans = &dns.SVCB{
Hdr: s.hdr(req, dns.TypeSVCB),
Priority: svcb.Priority,
Target: svcb.Target,
}
if len(svcb.Params) == 0 {
return ans
}
values := make([]dns.SVCBKeyValue, 0, len(svcb.Params))
for k, valStr := range svcb.Params {
handler, ok := svcbKeyHandlers[k]
if !ok {
log.Debug("unknown svcb/https key %q, ignoring", k)
continue
}
val := handler(valStr)
if val == nil {
continue
}
values = append(values, val)
}
if len(values) > 0 {
ans.Value = values
}
return ans
}

View File

@@ -0,0 +1,154 @@
package dnsforward
import (
"net"
"testing"
"github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
)
func TestGenAnswerHTTPS_andSVCB(t *testing.T) {
// Preconditions.
s := &Server{
conf: ServerConfig{
FilteringConfig: FilteringConfig{
BlockedResponseTTL: 3600,
},
},
}
req := &dns.Msg{
Question: []dns.Question{{
Name: "abcd",
}},
}
// Constants and helper values.
const host = "example.com"
const prio = 32
ip4 := net.IPv4(127, 0, 0, 1)
ip6 := net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
// Helper functions.
dnssvcb := func(key, value string) (svcb *rules.DNSSVCB) {
svcb = &rules.DNSSVCB{
Target: host,
Priority: prio,
}
if key == "" {
return svcb
}
svcb.Params = map[string]string{
key: value,
}
return svcb
}
wantsvcb := func(kv dns.SVCBKeyValue) (want *dns.SVCB) {
want = &dns.SVCB{
Hdr: s.hdr(req, dns.TypeSVCB),
Priority: prio,
Target: host,
}
if kv == nil {
return want
}
want.Value = []dns.SVCBKeyValue{kv}
return want
}
// Tests.
testCases := []struct {
svcb *rules.DNSSVCB
want *dns.SVCB
name string
}{{
svcb: dnssvcb("", ""),
want: wantsvcb(nil),
name: "no_params",
}, {
svcb: dnssvcb("foo", "bar"),
want: wantsvcb(nil),
name: "invalid",
}, {
svcb: dnssvcb("alpn", "h3"),
want: wantsvcb(&dns.SVCBAlpn{Alpn: []string{"h3"}}),
name: "alpn",
}, {
svcb: dnssvcb("echconfig", "AAAA"),
want: wantsvcb(&dns.SVCBECHConfig{ECH: []byte{0, 0, 0}}),
name: "echconfig",
}, {
svcb: dnssvcb("echconfig", "%BAD%"),
want: wantsvcb(nil),
name: "echconfig_invalid",
}, {
svcb: dnssvcb("ipv4hint", "127.0.0.1"),
want: wantsvcb(&dns.SVCBIPv4Hint{Hint: []net.IP{ip4}}),
name: "ipv4hint",
}, {
svcb: dnssvcb("ipv4hint", "127.0.01"),
want: wantsvcb(nil),
name: "ipv4hint_invalid",
}, {
svcb: dnssvcb("ipv6hint", "::1"),
want: wantsvcb(&dns.SVCBIPv6Hint{Hint: []net.IP{ip6}}),
name: "ipv6hint",
}, {
svcb: dnssvcb("ipv6hint", ":::1"),
want: wantsvcb(nil),
name: "ipv6hint_invalid",
}, {
svcb: dnssvcb("mandatory", "alpn"),
want: wantsvcb(&dns.SVCBMandatory{Code: []dns.SVCBKey{dns.SVCB_ALPN}}),
name: "mandatory",
}, {
svcb: dnssvcb("mandatory", "alpnn"),
want: wantsvcb(nil),
name: "mandatory_invalid",
}, {
svcb: dnssvcb("no-default-alpn", ""),
want: wantsvcb(&dns.SVCBNoDefaultAlpn{}),
name: "no-default-alpn",
}, {
svcb: dnssvcb("port", "8080"),
want: wantsvcb(&dns.SVCBPort{Port: 8080}),
name: "port",
}, {
svcb: dnssvcb("port", "1005008080"),
want: wantsvcb(nil),
name: "port",
}}
for _, tc := range testCases {
t.Run("https", func(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
want := &dns.HTTPS{SVCB: *tc.want}
want.Hdr.Rrtype = dns.TypeHTTPS
got := s.genAnswerHTTPS(req, tc.svcb)
assert.Equal(t, want, got)
})
})
t.Run("svcb", func(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
got := s.genAnswerSVCB(req, tc.svcb)
assert.Equal(t, tc.want, got)
})
})
}
}

View File

@@ -8,38 +8,8 @@ import (
"github.com/AdguardTeam/golibs/utils"
)
// GetIPString is a helper function that extracts IP address from net.Addr
func GetIPString(addr net.Addr) string {
switch addr := addr.(type) {
case *net.UDPAddr:
return addr.IP.String()
case *net.TCPAddr:
return addr.IP.String()
}
return ""
}
func stringArrayDup(a []string) []string {
a2 := make([]string, len(a))
copy(a2, a)
return a2
}
// Get IP address from net.Addr object
// Note: we can't use net.SplitHostPort(a.String()) because of IPv6 zone:
// https://github.com/AdguardTeam/AdGuardHome/internal/issues/1261
func ipFromAddr(a net.Addr) string {
switch addr := a.(type) {
case *net.UDPAddr:
return addr.IP.String()
case *net.TCPAddr:
return addr.IP.String()
}
return ""
}
// Get IP address from net.Addr
func getIP(addr net.Addr) net.IP {
// IPFromAddr gets IP address from addr.
func IPFromAddr(addr net.Addr) (ip net.IP) {
switch addr := addr.(type) {
case *net.UDPAddr:
return addr.IP
@@ -49,6 +19,23 @@ func getIP(addr net.Addr) net.IP {
return nil
}
// IPStringFromAddr extracts IP address from net.Addr.
// Note: we can't use net.SplitHostPort(a.String()) because of IPv6 zone:
// https://github.com/AdguardTeam/AdGuardHome/internal/issues/1261
func IPStringFromAddr(addr net.Addr) (ipStr string) {
if ip := IPFromAddr(addr); ip != nil {
return ip.String()
}
return ""
}
func stringArrayDup(a []string) []string {
a2 := make([]string, len(a))
copy(a2, a)
return a2
}
// Find value in a sorted array
func findSorted(ar []string, val string) int {
i := sort.SearchStrings(ar, val)

View File

@@ -2,13 +2,10 @@ package home
import (
"crypto/rand"
"crypto/sha256"
"encoding/binary"
"encoding/hex"
"encoding/json"
"fmt"
"math"
"math/big"
"net/http"
"strings"
"sync"
@@ -20,8 +17,12 @@ import (
)
const (
cookieTTL = 365 * 24 // in hours
// cookieTTL is given in hours.
cookieTTL = 365 * 24
sessionCookieName = "agh_session"
// sessionTokenSize is the length of session token in bytes.
sessionTokenSize = 16
)
type session struct {
@@ -59,10 +60,10 @@ func (s *session) deserialize(data []byte) bool {
// Auth - global object
type Auth struct {
db *bbolt.DB
sessions map[string]*session // session name -> session data
lock sync.Mutex
sessions map[string]*session
users []User
sessionTTL uint32 // in seconds
lock sync.Mutex
sessionTTL uint32
}
// User object
@@ -223,24 +224,35 @@ func (a *Auth) removeSession(sess []byte) {
log.Debug("Auth: removed session from DB")
}
// CheckSession - check if session is valid
// Return 0 if OK; -1 if session doesn't exist; 1 if session has expired
func (a *Auth) CheckSession(sess string) int {
// checkSessionResult is the result of checking a session.
type checkSessionResult int
// checkSessionResult constants.
const (
checkSessionOK checkSessionResult = 0
checkSessionNotFound checkSessionResult = -1
checkSessionExpired checkSessionResult = 1
)
// checkSession checks if the session is valid.
func (a *Auth) checkSession(sess string) (res checkSessionResult) {
now := uint32(time.Now().UTC().Unix())
update := false
a.lock.Lock()
defer a.lock.Unlock()
s, ok := a.sessions[sess]
if !ok {
a.lock.Unlock()
return -1
return checkSessionNotFound
}
if s.expire <= now {
delete(a.sessions, sess)
key, _ := hex.DecodeString(sess)
a.removeSession(key)
a.lock.Unlock()
return 1
return checkSessionExpired
}
newExpire := now + a.sessionTTL
@@ -250,8 +262,6 @@ func (a *Auth) CheckSession(sess string) int {
s.expire = newExpire
}
a.lock.Unlock()
if update {
key, _ := hex.DecodeString(sess)
if a.storeSession(key, s) {
@@ -259,7 +269,7 @@ func (a *Auth) CheckSession(sess string) int {
}
}
return 0
return checkSessionOK
}
// RemoveSession - remove session
@@ -276,16 +286,29 @@ type loginJSON struct {
Password string `json:"password"`
}
func getSession(u *User) ([]byte, error) {
maxSalt := big.NewInt(math.MaxUint32)
salt, err := rand.Int(rand.Reader, maxSalt)
// newSessionToken returns cryptographically secure randomly generated slice of
// bytes of sessionTokenSize length.
//
// TODO(e.burkov): Think about using byte array instead of byte slice.
func newSessionToken() (data []byte, err error) {
randData := make([]byte, sessionTokenSize)
_, err = rand.Read(randData)
if err != nil {
return nil, err
}
d := []byte(fmt.Sprintf("%s%s%s", salt, u.Name, u.PasswordHash))
hash := sha256.Sum256(d)
return hash[:], nil
return randData, nil
}
// cookieTimeFormat is the format to be used in (time.Time).Format for cookie's
// expiry field.
const cookieTimeFormat = "Mon, 02 Jan 2006 15:04:05 GMT"
// cookieExpiryFormat returns the formatted exp to be used in cookie string.
// It's quite simple for now, but probably will be expanded in the future.
func cookieExpiryFormat(exp time.Time) (formatted string) {
return exp.Format(cookieTimeFormat)
}
func (a *Auth) httpCookie(req loginJSON) (string, error) {
@@ -294,24 +317,23 @@ func (a *Auth) httpCookie(req loginJSON) (string, error) {
return "", nil
}
sess, err := getSession(&u)
sess, err := newSessionToken()
if err != nil {
return "", err
}
now := time.Now().UTC()
expire := now.Add(cookieTTL * time.Hour)
expstr := expire.Format(time.RFC1123)
expstr = expstr[:len(expstr)-len("UTC")] // "UTC" -> "GMT"
expstr += "GMT"
s := session{}
s.userName = u.Name
s.expire = uint32(now.Unix()) + a.sessionTTL
a.addSession(sess, &s)
a.addSession(sess, &session{
userName: u.Name,
expire: uint32(now.Unix()) + a.sessionTTL,
})
return fmt.Sprintf("%s=%s; Path=/; HttpOnly; Expires=%s",
sessionCookieName, hex.EncodeToString(sess), expstr), nil
return fmt.Sprintf(
"%s=%s; Path=/; HttpOnly; Expires=%s",
sessionCookieName, hex.EncodeToString(sess),
cookieExpiryFormat(now.Add(cookieTTL*time.Hour)),
), nil
}
func handleLogin(w http.ResponseWriter, r *http.Request) {
@@ -360,8 +382,8 @@ func handleLogout(w http.ResponseWriter, r *http.Request) {
// RegisterAuthHandlers - register handlers
func RegisterAuthHandlers() {
Context.mux.Handle("/control/login", postInstallHandler(ensureHandler("POST", handleLogin)))
httpRegister("GET", "/control/logout", handleLogout)
Context.mux.Handle("/control/login", postInstallHandler(ensureHandler(http.MethodPost, handleLogin)))
httpRegister(http.MethodGet, "/control/logout", handleLogout)
}
func parseCookie(cookie string) string {
@@ -392,8 +414,8 @@ func optionalAuthThird(w http.ResponseWriter, r *http.Request) (authFirst bool)
ok = true
} else if err == nil {
r := Context.auth.CheckSession(cookie.Value)
if r == 0 {
r := Context.auth.checkSession(cookie.Value)
if r == checkSessionOK {
ok = true
} else if r < 0 {
log.Debug("Auth: invalid cookie value: %s", cookie)
@@ -434,12 +456,13 @@ func optionalAuth(handler func(http.ResponseWriter, *http.Request)) func(http.Re
authRequired := Context.auth != nil && Context.auth.AuthRequired()
cookie, err := r.Cookie(sessionCookieName)
if authRequired && err == nil {
r := Context.auth.CheckSession(cookie.Value)
if r == 0 {
r := Context.auth.checkSession(cookie.Value)
if r == checkSessionOK {
w.Header().Set("Location", "/")
w.WriteHeader(http.StatusFound)
return
} else if r < 0 {
} else if r == checkSessionNotFound {
log.Debug("Auth: invalid cookie value: %s", cookie)
}
}
@@ -503,32 +526,34 @@ func (a *Auth) UserFind(login, password string) User {
return User{}
}
// GetCurrentUser - get the current user
func (a *Auth) GetCurrentUser(r *http.Request) User {
// getCurrentUser returns the current user. It returns an empty User if the
// user is not found.
func (a *Auth) getCurrentUser(r *http.Request) User {
cookie, err := r.Cookie(sessionCookieName)
if err != nil {
// there's no Cookie, check Basic authentication
// There's no Cookie, check Basic authentication.
user, pass, ok := r.BasicAuth()
if ok {
u := Context.auth.UserFind(user, pass)
return u
return Context.auth.UserFind(user, pass)
}
return User{}
}
a.lock.Lock()
defer a.lock.Unlock()
s, ok := a.sessions[cookie.Value]
if !ok {
a.lock.Unlock()
return User{}
}
for _, u := range a.users {
if u.Name == s.userName {
a.lock.Unlock()
return u
}
}
a.lock.Unlock()
return User{}
}

View File

@@ -1,6 +1,8 @@
package home
import (
"bytes"
"crypto/rand"
"encoding/hex"
"net/http"
"net/url"
@@ -9,12 +11,13 @@ import (
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/testutil"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestMain(m *testing.M) {
testutil.DiscardLogOutput(m)
aghtest.DiscardLogOutput(m)
}
func prepareTestDir() string {
@@ -24,24 +27,44 @@ func prepareTestDir() string {
return dir
}
func TestNewSessionToken(t *testing.T) {
// Successful case.
token, err := newSessionToken()
require.Nil(t, err)
assert.Len(t, token, sessionTokenSize)
// Break the rand.Reader.
prevReader := rand.Reader
t.Cleanup(func() {
rand.Reader = prevReader
})
rand.Reader = &bytes.Buffer{}
// Unsuccessful case.
token, err = newSessionToken()
require.NotNil(t, err)
assert.Empty(t, token)
}
func TestAuth(t *testing.T) {
dir := prepareTestDir()
defer func() { _ = os.RemoveAll(dir) }()
t.Cleanup(func() { _ = os.RemoveAll(dir) })
fn := filepath.Join(dir, "sessions.db")
users := []User{
{Name: "name", PasswordHash: "$2y$05$..vyzAECIhJPfaQiOK17IukcQnqEgKJHy0iETyYqxn3YXJl8yZuo2"},
}
users := []User{{
Name: "name",
PasswordHash: "$2y$05$..vyzAECIhJPfaQiOK17IukcQnqEgKJHy0iETyYqxn3YXJl8yZuo2",
}}
a := InitAuth(fn, nil, 60)
s := session{}
user := User{Name: "name"}
a.UserAdd(&user, "password")
assert.True(t, a.CheckSession("notfound") == -1)
assert.Equal(t, checkSessionNotFound, a.checkSession("notfound"))
a.RemoveSession("notfound")
sess, err := getSession(&users[0])
sess, err := newSessionToken()
assert.Nil(t, err)
sessStr := hex.EncodeToString(sess)
@@ -49,13 +72,13 @@ func TestAuth(t *testing.T) {
// check expiration
s.expire = uint32(now)
a.addSession(sess, &s)
assert.True(t, a.CheckSession(sessStr) == 1)
assert.Equal(t, checkSessionExpired, a.checkSession(sessStr))
// add session with TTL = 2 sec
s = session{}
s.expire = uint32(time.Now().UTC().Unix() + 2)
a.addSession(sess, &s)
assert.True(t, a.CheckSession(sessStr) == 0)
assert.Equal(t, checkSessionOK, a.checkSession(sessStr))
a.Close()
@@ -63,23 +86,22 @@ func TestAuth(t *testing.T) {
a = InitAuth(fn, users, 60)
// the session is still alive
assert.True(t, a.CheckSession(sessStr) == 0)
// reset our expiration time because CheckSession() has just updated it
assert.Equal(t, checkSessionOK, a.checkSession(sessStr))
// reset our expiration time because checkSession() has just updated it
s.expire = uint32(time.Now().UTC().Unix() + 2)
a.storeSession(sess, &s)
a.Close()
u := a.UserFind("name", "password")
assert.True(t, len(u.Name) != 0)
assert.NotEmpty(t, u.Name)
time.Sleep(3 * time.Second)
// load and remove expired sessions
a = InitAuth(fn, users, 60)
assert.True(t, a.CheckSession(sessStr) == -1)
assert.Equal(t, checkSessionNotFound, a.checkSession(sessStr))
a.Close()
os.Remove(fn)
}
// implements http.ResponseWriter
@@ -111,7 +133,7 @@ func TestAuthHTTP(t *testing.T) {
Context.auth = InitAuth(fn, users, 60)
handlerCalled := false
handler := func(w http.ResponseWriter, r *http.Request) {
handler := func(_ http.ResponseWriter, _ *http.Request) {
handlerCalled = true
}
handler2 := optionalAuth(handler)
@@ -119,15 +141,15 @@ func TestAuthHTTP(t *testing.T) {
w.hdr = make(http.Header)
r := http.Request{}
r.Header = make(http.Header)
r.Method = "GET"
r.Method = http.MethodGet
// get / - we're redirected to login page
r.URL = &url.URL{Path: "/"}
handlerCalled = false
handler2(&w, &r)
assert.True(t, w.statusCode == http.StatusFound)
assert.True(t, w.hdr.Get("Location") != "")
assert.True(t, !handlerCalled)
assert.Equal(t, http.StatusFound, w.statusCode)
assert.NotEmpty(t, w.hdr.Get("Location"))
assert.False(t, handlerCalled)
// go to login page
loginURL := w.hdr.Get("Location")
@@ -139,7 +161,7 @@ func TestAuthHTTP(t *testing.T) {
// perform login
cookie, err := Context.auth.httpCookie(loginJSON{Name: "name", Password: "password"})
assert.Nil(t, err)
assert.True(t, cookie != "")
assert.NotEmpty(t, cookie)
// get /
handler2 = optionalAuth(handler)
@@ -168,8 +190,8 @@ func TestAuthHTTP(t *testing.T) {
r.URL = &url.URL{Path: loginURL}
handlerCalled = false
handler2(&w, &r)
assert.True(t, w.hdr.Get("Location") != "")
assert.True(t, !handlerCalled)
assert.NotEmpty(t, w.hdr.Get("Location"))
assert.False(t, handlerCalled)
r.Header.Del("Cookie")
// get login page with an invalid cookie

View File

@@ -36,7 +36,7 @@ func TestAuthGL(t *testing.T) {
binary.BigEndian.PutUint32(data, tval)
}
assert.Nil(t, ioutil.WriteFile(glFilePrefix+"test", data, 0o644))
r, _ := http.NewRequest("GET", "http://localhost/", nil)
r, _ := http.NewRequest(http.MethodGet, "http://localhost/", nil)
r.AddCookie(&http.Cookie{Name: glCookieName, Value: "test"})
assert.True(t, glProcessCookie(r))
GLMode = false

View File

@@ -11,23 +11,21 @@ import (
"sync"
"time"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/AdguardTeam/AdGuardHome/internal/dnsfilter"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/util"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/utils"
)
const (
clientsUpdatePeriod = 10 * time.Minute
)
const clientsUpdatePeriod = 10 * time.Minute
var webHandlersRegistered = false
// Client information
// Client contains information about persistent clients.
type Client struct {
IDs []string
Tags []string
@@ -52,14 +50,13 @@ type Client struct {
type clientSource uint
// Client sources
// Client sources. The order determines the priority.
const (
// Priority: etc/hosts > DHCP > ARP > rDNS > WHOIS
ClientSourceWHOIS clientSource = iota // from WHOIS
ClientSourceRDNS // from rDNS
ClientSourceDHCP // from DHCP
ClientSourceARP // from 'arp -a'
ClientSourceHostsFile // from /etc/hosts
ClientSourceWHOIS clientSource = iota
ClientSourceRDNS
ClientSourceDHCP
ClientSourceARP
ClientSourceHostsFile
)
// ClientHost information
@@ -70,8 +67,10 @@ type ClientHost struct {
}
type clientsContainer struct {
// TODO(a.garipov): Perhaps use a number of separate indices for
// different types (string, net.IP, and so on).
list map[string]*Client // name -> client
idIndex map[string]*Client // IP -> client
idIndex map[string]*Client // ID -> client
ipHost map[string]*ClientHost // IP -> Hostname
lock sync.Mutex
@@ -156,7 +155,7 @@ func (clients *clientsContainer) tagKnown(tag string) bool {
func (clients *clientsContainer) addFromConfig(objects []clientObject) {
for _, cy := range objects {
cli := Client{
cli := &Client{
Name: cy.Name,
IDs: cy.IDs,
UseOwnSettings: !cy.UseGlobalSettings,
@@ -172,7 +171,7 @@ func (clients *clientsContainer) addFromConfig(objects []clientObject) {
for _, s := range cy.BlockedServices {
if !dnsfilter.BlockedSvcKnown(s) {
log.Debug("Clients: skipping unknown blocked-service %q", s)
log.Debug("clients: skipping unknown blocked-service %q", s)
continue
}
cli.BlockedServices = append(cli.BlockedServices, s)
@@ -180,7 +179,7 @@ func (clients *clientsContainer) addFromConfig(objects []clientObject) {
for _, t := range cy.Tags {
if !clients.tagKnown(t) {
log.Debug("Clients: skipping unknown tag %q", t)
log.Debug("clients: skipping unknown tag %q", t)
continue
}
cli.Tags = append(cli.Tags, t)
@@ -208,10 +207,10 @@ func (clients *clientsContainer) WriteDiskConfig(objects *[]clientObject) {
UseGlobalBlockedServices: !cli.UseOwnBlockedServices,
}
cy.Tags = stringArrayDup(cli.Tags)
cy.IDs = stringArrayDup(cli.IDs)
cy.BlockedServices = stringArrayDup(cli.BlockedServices)
cy.Upstreams = stringArrayDup(cli.Upstreams)
cy.Tags = copyStrings(cli.Tags)
cy.IDs = copyStrings(cli.IDs)
cy.BlockedServices = copyStrings(cli.BlockedServices)
cy.Upstreams = copyStrings(cli.Upstreams)
*objects = append(*objects, cy)
}
@@ -238,45 +237,44 @@ func (clients *clientsContainer) onHostsChanged() {
clients.addFromHostsFile()
}
// Exists checks if client with this IP already exists
func (clients *clientsContainer) Exists(ip string, source clientSource) bool {
// Exists checks if client with this ID already exists.
func (clients *clientsContainer) Exists(id string, source clientSource) (ok bool) {
clients.lock.Lock()
defer clients.lock.Unlock()
_, ok := clients.findByIP(ip)
_, ok = clients.findLocked(id)
if ok {
return true
}
ch, ok := clients.ipHost[ip]
var ch *ClientHost
ch, ok = clients.ipHost[id]
if !ok {
return false
}
if source > ch.Source {
return false // we're going to overwrite this client's info with a stronger source
}
return true
// Return false if the new source has higher priority.
return source <= ch.Source
}
func stringArrayDup(a []string) []string {
a2 := make([]string, len(a))
copy(a2, a)
return a2
func copyStrings(a []string) (b []string) {
return append(b, a...)
}
// Find searches for a client by IP
func (clients *clientsContainer) Find(ip string) (Client, bool) {
// Find searches for a client by its ID.
func (clients *clientsContainer) Find(id string) (c *Client, ok bool) {
clients.lock.Lock()
defer clients.lock.Unlock()
c, ok := clients.findByIP(ip)
c, ok = clients.findLocked(id)
if !ok {
return Client{}, false
return nil, false
}
c.IDs = stringArrayDup(c.IDs)
c.Tags = stringArrayDup(c.Tags)
c.BlockedServices = stringArrayDup(c.BlockedServices)
c.Upstreams = stringArrayDup(c.Upstreams)
c.IDs = copyStrings(c.IDs)
c.Tags = copyStrings(c.Tags)
c.BlockedServices = copyStrings(c.BlockedServices)
c.Upstreams = copyStrings(c.Upstreams)
return c, true
}
@@ -287,7 +285,7 @@ func (clients *clientsContainer) FindUpstreams(ip string) *proxy.UpstreamConfig
clients.lock.Lock()
defer clients.lock.Unlock()
c, ok := clients.findByIP(ip)
c, ok := clients.findLocked(ip)
if !ok {
return nil
}
@@ -306,16 +304,16 @@ func (clients *clientsContainer) FindUpstreams(ip string) *proxy.UpstreamConfig
return c.upstreamConfig
}
// Find searches for a client by IP (and does not lock anything)
func (clients *clientsContainer) findByIP(ip string) (Client, bool) {
ipAddr := net.ParseIP(ip)
if ipAddr == nil {
return Client{}, false
// findLocked searches for a client by its ID. For internal use only.
func (clients *clientsContainer) findLocked(id string) (c *Client, ok bool) {
c, ok = clients.idIndex[id]
if ok {
return c, true
}
c, ok := clients.idIndex[ip]
if ok {
return *c, true
ip := net.ParseIP(id)
if ip == nil {
return nil, false
}
for _, c = range clients.list {
@@ -324,32 +322,36 @@ func (clients *clientsContainer) findByIP(ip string) (Client, bool) {
if err != nil {
continue
}
if ipnet.Contains(ipAddr) {
return *c, true
if ipnet.Contains(ip) {
return c, true
}
}
}
if clients.dhcpServer == nil {
return Client{}, false
return nil, false
}
macFound := clients.dhcpServer.FindMACbyIP(ipAddr)
macFound := clients.dhcpServer.FindMACbyIP(ip)
if macFound == nil {
return Client{}, false
return nil, false
}
for _, c = range clients.list {
for _, id := range c.IDs {
hwAddr, err := net.ParseMAC(id)
if err != nil {
continue
}
if bytes.Equal(hwAddr, macFound) {
return *c, true
return c, true
}
}
}
return Client{}, false
return nil, false
}
// FindAutoClient - search for an auto-client by IP
@@ -369,44 +371,47 @@ func (clients *clientsContainer) FindAutoClient(ip string) (ClientHost, bool) {
return ClientHost{}, false
}
// Check if Client object's fields are correct
func (clients *clientsContainer) check(c *Client) error {
if len(c.Name) == 0 {
return fmt.Errorf("invalid Name")
}
if len(c.IDs) == 0 {
return fmt.Errorf("id required")
// check validates the client.
func (clients *clientsContainer) check(c *Client) (err error) {
switch {
case c == nil:
return agherr.Error("client is nil")
case c.Name == "":
return agherr.Error("invalid name")
case len(c.IDs) == 0:
return agherr.Error("id required")
default:
// Go on.
}
for i, id := range c.IDs {
ip := net.ParseIP(id)
if ip != nil {
c.IDs[i] = ip.String() // normalize IP address
continue
// Normalize structured data.
var ip net.IP
var ipnet *net.IPNet
var mac net.HardwareAddr
if ip = net.ParseIP(id); ip != nil {
c.IDs[i] = ip.String()
} else if ip, ipnet, err = net.ParseCIDR(id); err == nil {
ipnet.IP = ip
c.IDs[i] = ipnet.String()
} else if mac, err = net.ParseMAC(id); err == nil {
c.IDs[i] = mac.String()
} else if err = dnsforward.ValidateClientID(id); err == nil {
c.IDs[i] = id
} else {
return fmt.Errorf("invalid client id at index %d: %q", i, id)
}
_, _, err := net.ParseCIDR(id)
if err == nil {
continue
}
_, err = net.ParseMAC(id)
if err == nil {
continue
}
return fmt.Errorf("invalid ID: %s", id)
}
for _, t := range c.Tags {
if !clients.tagKnown(t) {
return fmt.Errorf("invalid tag: %s", t)
return fmt.Errorf("invalid tag: %q", t)
}
}
sort.Strings(c.Tags)
err := dnsforward.ValidateUpstreams(c.Upstreams)
err = dnsforward.ValidateUpstreams(c.Upstreams)
if err != nil {
return fmt.Errorf("invalid upstream servers: %w", err)
}
@@ -414,49 +419,52 @@ func (clients *clientsContainer) check(c *Client) error {
return nil
}
// Add a new client object
// Return true: success; false: client exists.
func (clients *clientsContainer) Add(c Client) (bool, error) {
e := clients.check(&c)
if e != nil {
return false, e
// Add adds a new client object. ok is false if such client already exists or
// if an error occurred.
func (clients *clientsContainer) Add(c *Client) (ok bool, err error) {
err = clients.check(c)
if err != nil {
return false, err
}
clients.lock.Lock()
defer clients.lock.Unlock()
// check Name index
_, ok := clients.list[c.Name]
_, ok = clients.list[c.Name]
if ok {
return false, nil
}
// check ID index
for _, id := range c.IDs {
c2, ok := clients.idIndex[id]
var c2 *Client
c2, ok = clients.idIndex[id]
if ok {
return false, fmt.Errorf("another client uses the same ID (%s): %s", id, c2.Name)
return false, fmt.Errorf("another client uses the same ID (%q): %q", id, c2.Name)
}
}
// update Name index
clients.list[c.Name] = &c
clients.list[c.Name] = c
// update ID index
for _, id := range c.IDs {
clients.idIndex[id] = &c
clients.idIndex[id] = c
}
log.Debug("Clients: added %q: ID:%v [%d]", c.Name, c.IDs, len(clients.list))
log.Debug("clients: added %q: ID:%q [%d]", c.Name, c.IDs, len(clients.list))
return true, nil
}
// Del removes a client
func (clients *clientsContainer) Del(name string) bool {
// Del removes a client. ok is false if there is no such client.
func (clients *clientsContainer) Del(name string) (ok bool) {
clients.lock.Lock()
defer clients.lock.Unlock()
c, ok := clients.list[name]
var c *Client
c, ok = clients.list[name]
if !ok {
return false
}
@@ -468,25 +476,28 @@ func (clients *clientsContainer) Del(name string) bool {
for _, id := range c.IDs {
delete(clients.idIndex, id)
}
return true
}
// Return TRUE if arrays are equal
func arraysEqual(a, b []string) bool {
// equalStringSlices returns true if the slices are equal.
func equalStringSlices(a, b []string) (ok bool) {
if len(a) != len(b) {
return false
}
for i := 0; i != len(a); i++ {
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
// Update a client
func (clients *clientsContainer) Update(name string, c Client) error {
err := clients.check(&c)
// Update updates a client by its name.
func (clients *clientsContainer) Update(name string, c *Client) (err error) {
err = clients.check(c)
if err != nil {
return err
}
@@ -494,65 +505,69 @@ func (clients *clientsContainer) Update(name string, c Client) error {
clients.lock.Lock()
defer clients.lock.Unlock()
old, ok := clients.list[name]
prev, ok := clients.list[name]
if !ok {
return fmt.Errorf("client not found")
return agherr.Error("client not found")
}
// check Name index
if old.Name != c.Name {
if prev.Name != c.Name {
_, ok = clients.list[c.Name]
if ok {
return fmt.Errorf("client already exists")
return agherr.Error("client already exists")
}
}
// check IP index
if !arraysEqual(old.IDs, c.IDs) {
if !equalStringSlices(prev.IDs, c.IDs) {
for _, id := range c.IDs {
c2, ok := clients.idIndex[id]
if ok && c2 != old {
return fmt.Errorf("another client uses the same ID (%s): %s", id, c2.Name)
if ok && c2 != prev {
return fmt.Errorf("another client uses the same ID (%q): %q", id, c2.Name)
}
}
// update ID index
for _, id := range old.IDs {
for _, id := range prev.IDs {
delete(clients.idIndex, id)
}
for _, id := range c.IDs {
clients.idIndex[id] = old
clients.idIndex[id] = prev
}
}
// update Name index
if old.Name != c.Name {
delete(clients.list, old.Name)
clients.list[c.Name] = old
if prev.Name != c.Name {
delete(clients.list, prev.Name)
clients.list[c.Name] = prev
}
// update upstreams cache
c.upstreamConfig = nil
*old = c
*prev = *c
return nil
}
// SetWhoisInfo - associate WHOIS information with a client
// SetWhoisInfo sets the WHOIS information for a client.
//
// TODO(a.garipov): Perhaps replace [][]string with map[string]string.
func (clients *clientsContainer) SetWhoisInfo(ip string, info [][]string) {
clients.lock.Lock()
defer clients.lock.Unlock()
_, ok := clients.findByIP(ip)
_, ok := clients.findLocked(ip)
if ok {
log.Debug("Clients: client for %s is already created, ignore WHOIS info", ip)
log.Debug("clients: client for %s is already created, ignore whois info", ip)
return
}
ch, ok := clients.ipHost[ip]
if ok {
ch.WhoisInfo = info
log.Debug("Clients: set WHOIS info for auto-client %s: %v", ch.Host, ch.WhoisInfo)
log.Debug("clients: set whois info for auto-client %s: %q", ch.Host, info)
return
}
@@ -562,31 +577,33 @@ func (clients *clientsContainer) SetWhoisInfo(ip string, info [][]string) {
}
ch.WhoisInfo = info
clients.ipHost[ip] = ch
log.Debug("Clients: set WHOIS info for auto-client with IP %s: %v", ip, ch.WhoisInfo)
log.Debug("clients: set whois info for auto-client with IP %s: %q", ip, info)
}
// AddHost adds new IP -> Host pair
// Use priority of the source (etc/hosts > ARP > rDNS)
// so we overwrite existing entries with an equal or higher priority
func (clients *clientsContainer) AddHost(ip, host string, source clientSource) (bool, error) {
// AddHost adds a new IP-hostname pairing. The priorities of the sources is
// taken into account. ok is true if the pairing was added.
func (clients *clientsContainer) AddHost(ip, host string, src clientSource) (ok bool, err error) {
clients.lock.Lock()
b := clients.addHost(ip, host, source)
ok = clients.addHostLocked(ip, host, src)
clients.lock.Unlock()
return b, nil
return ok, nil
}
func (clients *clientsContainer) addHost(ip, host string, source clientSource) (addedNew bool) {
ch, ok := clients.ipHost[ip]
// addHostLocked adds a new IP-hostname pairing. For internal use only.
func (clients *clientsContainer) addHostLocked(ip, host string, src clientSource) (ok bool) {
var ch *ClientHost
ch, ok = clients.ipHost[ip]
if ok {
if ch.Source > source {
if ch.Source > src {
return false
}
ch.Source = source
ch.Source = src
} else {
ch = &ClientHost{
Host: host,
Source: source,
Source: src,
}
clients.ipHost[ip] = ch
@@ -597,11 +614,11 @@ func (clients *clientsContainer) addHost(ip, host string, source clientSource) (
return true
}
// Remove all entries that match the specified source
func (clients *clientsContainer) rmHosts(source clientSource) {
// rmHostsBySrc removes all entries that match the specified source.
func (clients *clientsContainer) rmHostsBySrc(src clientSource) {
n := 0
for k, v := range clients.ipHost {
if v.Source == source {
if v.Source == src {
delete(clients.ipHost, k)
n++
}
@@ -610,19 +627,20 @@ func (clients *clientsContainer) rmHosts(source clientSource) {
log.Debug("clients: removed %d client aliases", n)
}
// addFromHostsFile fills the clients hosts list from the system's hosts files.
// addFromHostsFile fills the client-hostname pairing index from the system's
// hosts files.
func (clients *clientsContainer) addFromHostsFile() {
hosts := clients.autoHosts.List()
clients.lock.Lock()
defer clients.lock.Unlock()
clients.rmHosts(ClientSourceHostsFile)
clients.rmHostsBySrc(ClientSourceHostsFile)
n := 0
for ip, names := range hosts {
for _, name := range names {
ok := clients.addHost(ip, name, ClientSourceHostsFile)
ok := clients.addHostLocked(ip, name, ClientSourceHostsFile)
if ok {
n++
}
@@ -632,31 +650,31 @@ func (clients *clientsContainer) addFromHostsFile() {
log.Debug("Clients: added %d client aliases from system hosts-file", n)
}
// Add IP -> Host pairs from the system's `arp -a` command output
// The command's output is:
// HOST (IP) at MAC on IFACE
// addFromSystemARP adds the IP-hostname pairings from the output of the arp -a
// command.
func (clients *clientsContainer) addFromSystemARP() {
if runtime.GOOS == "windows" {
return
}
cmd := exec.Command("arp", "-a")
log.Tracef("executing %s %v", cmd.Path, cmd.Args)
log.Tracef("executing %q %q", cmd.Path, cmd.Args)
data, err := cmd.Output()
if err != nil || cmd.ProcessState.ExitCode() != 0 {
log.Debug("command %s has failed: %v code:%d",
log.Debug("command %q has failed: %q code:%d",
cmd.Path, err, cmd.ProcessState.ExitCode())
return
}
clients.lock.Lock()
defer clients.lock.Unlock()
clients.rmHosts(ClientSourceARP)
clients.rmHostsBySrc(ClientSourceARP)
n := 0
// TODO(a.garipov): Rewrite to use bufio.Scanner.
lines := strings.Split(string(data), "\n")
for _, ln := range lines {
open := strings.Index(ln, " (")
close := strings.Index(ln, ") ")
if open == -1 || close == -1 || open >= close {
@@ -669,16 +687,17 @@ func (clients *clientsContainer) addFromSystemARP() {
continue
}
ok := clients.addHost(ip, host, ClientSourceARP)
ok := clients.addHostLocked(ip, host, ClientSourceARP)
if ok {
n++
}
}
log.Debug("Clients: added %d client aliases from 'arp -a' command output", n)
log.Debug("clients: added %d client aliases from 'arp -a' command output", n)
}
// Add clients from DHCP that have non-empty Hostname property
// addFromDHCP adds the clients that have a non-empty hostname from the DHCP
// server.
func (clients *clientsContainer) addFromDHCP() {
if clients.dhcpServer == nil {
return
@@ -687,18 +706,20 @@ func (clients *clientsContainer) addFromDHCP() {
clients.lock.Lock()
defer clients.lock.Unlock()
clients.rmHosts(ClientSourceDHCP)
clients.rmHostsBySrc(ClientSourceDHCP)
leases := clients.dhcpServer.Leases(dhcpd.LeasesAll)
n := 0
for _, l := range leases {
if len(l.Hostname) == 0 {
if l.Hostname == "" {
continue
}
ok := clients.addHost(l.IP.String(), l.Hostname, ClientSourceDHCP)
ok := clients.addHostLocked(l.IP.String(), l.Hostname, ClientSourceDHCP)
if ok {
n++
}
}
log.Debug("Clients: added %d client aliases from DHCP", n)
log.Debug("clients: added %d client aliases from dhcp", n)
}

View File

@@ -18,32 +18,35 @@ func TestClients(t *testing.T) {
clients.Init(nil, nil, nil)
t.Run("add_success", func(t *testing.T) {
c := Client{
c := &Client{
IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa"},
Name: "client1",
}
b, err := clients.Add(c)
assert.True(t, b)
ok, err := clients.Add(c)
assert.True(t, ok)
assert.Nil(t, err)
c = Client{
c = &Client{
IDs: []string{"2.2.2.2"},
Name: "client2",
}
b, err = clients.Add(c)
assert.True(t, b)
ok, err = clients.Add(c)
assert.True(t, ok)
assert.Nil(t, err)
c, b = clients.Find("1.1.1.1")
assert.True(t, b && c.Name == "client1")
c, ok = clients.Find("1.1.1.1")
assert.True(t, ok)
assert.Equal(t, "client1", c.Name)
c, b = clients.Find("1:2:3::4")
assert.True(t, b && c.Name == "client1")
c, ok = clients.Find("1:2:3::4")
assert.True(t, ok)
assert.Equal(t, "client1", c.Name)
c, b = clients.Find("2.2.2.2")
assert.True(t, b && c.Name == "client2")
c, ok = clients.Find("2.2.2.2")
assert.True(t, ok)
assert.Equal(t, "client2", c.Name)
assert.True(t, !clients.Exists("1.2.3.4", ClientSourceHostsFile))
assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile))
@@ -51,29 +54,29 @@ func TestClients(t *testing.T) {
})
t.Run("add_fail_name", func(t *testing.T) {
c := Client{
c := &Client{
IDs: []string{"1.2.3.5"},
Name: "client1",
}
b, err := clients.Add(c)
assert.False(t, b)
ok, err := clients.Add(c)
assert.False(t, ok)
assert.Nil(t, err)
})
t.Run("add_fail_ip", func(t *testing.T) {
c := Client{
c := &Client{
IDs: []string{"2.2.2.2"},
Name: "client3",
}
b, err := clients.Add(c)
assert.False(t, b)
ok, err := clients.Add(c)
assert.False(t, ok)
assert.NotNil(t, err)
})
t.Run("update_fail_name", func(t *testing.T) {
c := Client{
c := &Client{
IDs: []string{"1.2.3.0"},
Name: "client3",
}
@@ -81,7 +84,7 @@ func TestClients(t *testing.T) {
err := clients.Update("client3", c)
assert.NotNil(t, err)
c = Client{
c = &Client{
IDs: []string{"1.2.3.0"},
Name: "client2",
}
@@ -91,7 +94,7 @@ func TestClients(t *testing.T) {
})
t.Run("update_fail_ip", func(t *testing.T) {
c := Client{
c := &Client{
IDs: []string{"2.2.2.2"},
Name: "client1",
}
@@ -101,7 +104,7 @@ func TestClients(t *testing.T) {
})
t.Run("update_success", func(t *testing.T) {
c := Client{
c := &Client{
IDs: []string{"1.1.1.2"},
Name: "client1",
}
@@ -112,7 +115,7 @@ func TestClients(t *testing.T) {
assert.True(t, !clients.Exists("1.1.1.1", ClientSourceHostsFile))
assert.True(t, clients.Exists("1.1.1.2", ClientSourceHostsFile))
c = Client{
c = &Client{
IDs: []string{"1.1.1.2"},
Name: "client1-renamed",
UseOwnSettings: true,
@@ -121,50 +124,52 @@ func TestClients(t *testing.T) {
err = clients.Update("client1", c)
assert.Nil(t, err)
c, b := clients.Find("1.1.1.2")
assert.True(t, b)
assert.True(t, c.Name == "client1-renamed")
assert.True(t, c.IDs[0] == "1.1.1.2")
c, ok := clients.Find("1.1.1.2")
assert.True(t, ok)
assert.Equal(t, "client1-renamed", c.Name)
assert.True(t, c.UseOwnSettings)
assert.Nil(t, clients.list["client1"])
if assert.Len(t, c.IDs, 1) {
assert.Equal(t, "1.1.1.2", c.IDs[0])
}
})
t.Run("del_success", func(t *testing.T) {
b := clients.Del("client1-renamed")
assert.True(t, b)
ok := clients.Del("client1-renamed")
assert.True(t, ok)
assert.False(t, clients.Exists("1.1.1.2", ClientSourceHostsFile))
})
t.Run("del_fail", func(t *testing.T) {
b := clients.Del("client3")
assert.False(t, b)
ok := clients.Del("client3")
assert.False(t, ok)
})
t.Run("addhost_success", func(t *testing.T) {
b, err := clients.AddHost("1.1.1.1", "host", ClientSourceARP)
assert.True(t, b)
ok, err := clients.AddHost("1.1.1.1", "host", ClientSourceARP)
assert.True(t, ok)
assert.Nil(t, err)
b, err = clients.AddHost("1.1.1.1", "host2", ClientSourceARP)
assert.True(t, b)
ok, err = clients.AddHost("1.1.1.1", "host2", ClientSourceARP)
assert.True(t, ok)
assert.Nil(t, err)
b, err = clients.AddHost("1.1.1.1", "host3", ClientSourceHostsFile)
assert.True(t, b)
ok, err = clients.AddHost("1.1.1.1", "host3", ClientSourceHostsFile)
assert.True(t, ok)
assert.Nil(t, err)
assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile))
})
t.Run("addhost_fail", func(t *testing.T) {
b, err := clients.AddHost("1.1.1.1", "host1", ClientSourceRDNS)
assert.False(t, b)
ok, err := clients.AddHost("1.1.1.1", "host1", ClientSourceRDNS)
assert.False(t, ok)
assert.Nil(t, err)
})
}
func TestClientsWhois(t *testing.T) {
var c Client
var c *Client
clients := clientsContainer{}
clients.testing = true
clients.Init(nil, nil, nil)
@@ -172,26 +177,36 @@ func TestClientsWhois(t *testing.T) {
whois := [][]string{{"orgname", "orgname-val"}, {"country", "country-val"}}
// set whois info on new client
clients.SetWhoisInfo("1.1.1.255", whois)
assert.True(t, clients.ipHost["1.1.1.255"].WhoisInfo[0][1] == "orgname-val")
if assert.NotNil(t, clients.ipHost["1.1.1.255"]) {
h := clients.ipHost["1.1.1.255"]
if assert.Len(t, h.WhoisInfo, 2) && assert.Len(t, h.WhoisInfo[0], 2) {
assert.Equal(t, "orgname-val", h.WhoisInfo[0][1])
}
}
// set whois info on existing auto-client
_, _ = clients.AddHost("1.1.1.1", "host", ClientSourceRDNS)
clients.SetWhoisInfo("1.1.1.1", whois)
assert.True(t, clients.ipHost["1.1.1.1"].WhoisInfo[0][1] == "orgname-val")
if assert.NotNil(t, clients.ipHost["1.1.1.1"]) {
h := clients.ipHost["1.1.1.1"]
if assert.Len(t, h.WhoisInfo, 2) && assert.Len(t, h.WhoisInfo[0], 2) {
assert.Equal(t, "orgname-val", h.WhoisInfo[0][1])
}
}
// Check that we cannot set whois info on a manually-added client
c = Client{
c = &Client{
IDs: []string{"1.1.1.2"},
Name: "client1",
}
_, _ = clients.Add(c)
clients.SetWhoisInfo("1.1.1.2", whois)
assert.True(t, clients.ipHost["1.1.1.2"] == nil)
assert.Nil(t, clients.ipHost["1.1.1.2"])
_ = clients.Del("client1")
}
func TestClientsAddExisting(t *testing.T) {
var c Client
var c *Client
clients := clientsContainer{}
clients.testing = true
clients.Init(nil, nil, nil)
@@ -201,7 +216,7 @@ func TestClientsAddExisting(t *testing.T) {
testIP := "1.2.3.4"
// add a client
c = Client{
c = &Client{
IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa", "2.2.2.0/24"},
Name: "client1",
}
@@ -230,7 +245,7 @@ func TestClientsAddExisting(t *testing.T) {
assert.Nil(t, err)
// add a new client with the same IP as for a client with MAC
c = Client{
c = &Client{
IDs: []string{testIP},
Name: "client2",
}
@@ -239,7 +254,7 @@ func TestClientsAddExisting(t *testing.T) {
assert.Nil(t, err)
// add a new client with the IP from the client1's IP range
c = Client{
c = &Client{
IDs: []string{"2.2.2.2"},
Name: "client3",
}
@@ -255,7 +270,7 @@ func TestClientsCustomUpstream(t *testing.T) {
clients.Init(nil, nil, nil)
// add client with upstreams
client := Client{
c := &Client{
IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa"},
Name: "client1",
Upstreams: []string{
@@ -263,7 +278,7 @@ func TestClientsCustomUpstream(t *testing.T) {
"[/example.org/]8.8.8.8",
},
}
ok, err := clients.Add(client)
ok, err := clients.Add(c)
assert.Nil(t, err)
assert.True(t, ok)

View File

@@ -3,6 +3,7 @@ package home
import (
"encoding/json"
"fmt"
"net"
"net/http"
)
@@ -21,7 +22,7 @@ type clientJSON struct {
Upstreams []string `json:"upstreams"`
WhoisInfo map[string]interface{} `json:"whois_info"`
WhoisInfo map[string]string `json:"whois_info"`
// Disallowed - if true -- client's IP is not disallowed
// Otherwise, it is blocked.
@@ -38,7 +39,7 @@ type clientHostJSON struct {
Name string `json:"name"`
Source string `json:"source"`
WhoisInfo map[string]interface{} `json:"whois_info"`
WhoisInfo map[string]string `json:"whois_info"`
}
type clientListJSON struct {
@@ -74,7 +75,7 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, _ *http
cj.Source = "WHOIS"
}
cj.WhoisInfo = make(map[string]interface{})
cj.WhoisInfo = map[string]string{}
for _, wi := range ch.WhoisInfo {
cj.WhoisInfo[wi[0]] = wi[1]
}
@@ -139,7 +140,7 @@ func clientHostToJSON(ip string, ch ClientHost) clientJSON {
IDs: []string{ip},
}
cj.WhoisInfo = make(map[string]interface{})
cj.WhoisInfo = map[string]string{}
for _, wi := range ch.WhoisInfo {
cj.WhoisInfo[wi[0]] = wi[1]
}
@@ -157,7 +158,7 @@ func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http.
}
c := jsonToClient(cj)
ok, err := clients.Add(*c)
ok, err := clients.Add(c)
if err != nil {
httpError(w, http.StatusBadRequest, "%s", err)
return
@@ -215,7 +216,7 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht
}
c := jsonToClient(dj.Data)
err = clients.Update(dj.Name, *c)
err = clients.Update(dj.Name, c)
if err != nil {
httpError(w, http.StatusBadRequest, "%s", err)
return
@@ -227,51 +228,78 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht
// Get the list of clients by IP address list
func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query()
data := []map[string]interface{}{}
for i := 0; ; i++ {
ip := q.Get(fmt.Sprintf("ip%d", i))
if len(ip) == 0 {
data := []map[string]clientJSON{}
for i := 0; i < len(q); i++ {
idStr := q.Get(fmt.Sprintf("ip%d", i))
if idStr == "" {
break
}
el := map[string]interface{}{}
c, ok := clients.Find(ip)
ip := net.ParseIP(idStr)
c, ok := clients.Find(idStr)
var cj clientJSON
if !ok {
ch, ok := clients.FindAutoClient(ip)
if !ok {
continue // a client with this IP isn't found
var found bool
cj, found = clients.findTemporary(ip, idStr)
if !found {
continue
}
cj := clientHostToJSON(ip, ch)
cj.Disallowed, cj.DisallowedRule = clients.dnsServer.IsBlockedIP(ip)
el[ip] = cj
} else {
cj := clientToJSON(&c)
cj = clientToJSON(c)
cj.Disallowed, cj.DisallowedRule = clients.dnsServer.IsBlockedIP(ip)
el[ip] = cj
}
data = append(data, el)
}
js, err := json.Marshal(data)
if err != nil {
httpError(w, http.StatusInternalServerError, "json.Marshal: %s", err)
return
data = append(data, map[string]clientJSON{
idStr: cj,
})
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(js)
err := json.NewEncoder(w).Encode(data)
if err != nil {
httpError(w, http.StatusInternalServerError, "Couldn't write response: %s", err)
}
}
// findTemporary looks up the IP in temporary storages, like autohosts or
// blocklists.
func (clients *clientsContainer) findTemporary(ip net.IP, idStr string) (cj clientJSON, found bool) {
if ip == nil {
return cj, false
}
ch, ok := clients.FindAutoClient(idStr)
if !ok {
// It is still possible that the IP used to be in the runtime
// clients list, but then the server was reloaded. So, check
// the DNS server's blocked IP list.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/2428.
disallowed, rule := clients.dnsServer.IsBlockedIP(ip)
if rule == "" {
return clientJSON{}, false
}
cj = clientJSON{
IDs: []string{idStr},
Disallowed: disallowed,
DisallowedRule: rule,
}
return cj, true
}
cj = clientHostToJSON(idStr, ch)
cj.Disallowed, cj.DisallowedRule = clients.dnsServer.IsBlockedIP(ip)
return cj, true
}
// RegisterClientsHandlers registers HTTP handlers
func (clients *clientsContainer) registerWebHandlers() {
httpRegister("GET", "/control/clients", clients.handleGetClients)
httpRegister("POST", "/control/clients/add", clients.handleAddClient)
httpRegister("POST", "/control/clients/delete", clients.handleDelClient)
httpRegister("POST", "/control/clients/update", clients.handleUpdateClient)
httpRegister("GET", "/control/clients/find", clients.handleFindClient)
httpRegister(http.MethodGet, "/control/clients", clients.handleGetClients)
httpRegister(http.MethodPost, "/control/clients/add", clients.handleAddClient)
httpRegister(http.MethodPost, "/control/clients/delete", clients.handleDelClient)
httpRegister(http.MethodPost, "/control/clients/update", clients.handleUpdateClient)
httpRegister(http.MethodGet, "/control/clients/find", clients.handleFindClient)
}

View File

@@ -1,7 +1,9 @@
package home
import (
"errors"
"io/ioutil"
"net"
"os"
"path/filepath"
"sync"
@@ -11,6 +13,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
"github.com/AdguardTeam/AdGuardHome/internal/stats"
"github.com/AdguardTeam/AdGuardHome/internal/version"
"github.com/AdguardTeam/golibs/file"
"github.com/AdguardTeam/golibs/log"
yaml "gopkg.in/yaml.v2"
@@ -39,13 +42,14 @@ type configuration struct {
// It's reset after config is parsed
fileData []byte
BindHost string `yaml:"bind_host"` // BindHost is the IP address of the HTTP server to bind to
BindPort int `yaml:"bind_port"` // BindPort is the port the HTTP server
Users []User `yaml:"users"` // Users that can access HTTP server
ProxyURL string `yaml:"http_proxy"` // Proxy address for our HTTP client
Language string `yaml:"language"` // two-letter ISO 639-1 language code
RlimitNoFile uint `yaml:"rlimit_nofile"` // Maximum number of opened fd's per process (0: default)
DebugPProf bool `yaml:"debug_pprof"` // Enable pprof HTTP server on port 6060
BindHost net.IP `yaml:"bind_host"` // BindHost is the IP address of the HTTP server to bind to
BindPort int `yaml:"bind_port"` // BindPort is the port the HTTP server
BetaBindPort int `yaml:"beta_bind_port"` // BetaBindPort is the port for new client
Users []User `yaml:"users"` // Users that can access HTTP server
ProxyURL string `yaml:"http_proxy"` // Proxy address for our HTTP client
Language string `yaml:"language"` // two-letter ISO 639-1 language code
RlimitNoFile uint `yaml:"rlimit_nofile"` // Maximum number of opened fd's per process (0: default)
DebugPProf bool `yaml:"debug_pprof"` // Enable pprof HTTP server on port 6060
// TTL for a web session (in hours)
// An active session is automatically refreshed once a day.
@@ -72,7 +76,7 @@ type configuration struct {
// field ordering is important -- yaml fields will mirror ordering from here
type dnsConfig struct {
BindHost string `yaml:"bind_host"`
BindHost net.IP `yaml:"bind_host"`
Port int `yaml:"port"`
// time interval for statistics (in days)
@@ -117,10 +121,11 @@ type tlsConfigSettings struct {
// initialize to default values, will be changed later when reading config or parsing command line
var config = configuration{
BindPort: 3000,
BindHost: "0.0.0.0",
BindPort: 3000,
BetaBindPort: 0,
BindHost: net.IP{0, 0, 0, 0},
DNS: dnsConfig{
BindHost: "0.0.0.0",
BindHost: net.IP{0, 0, 0, 0},
Port: 53,
StatsInterval: 1,
FilteringConfig: dnsforward.FilteringConfig{
@@ -174,13 +179,17 @@ func initConfig() {
config.DHCP.Conf4.LeaseDuration = 86400
config.DHCP.Conf4.ICMPTimeout = 1000
config.DHCP.Conf6.LeaseDuration = 86400
if ch := version.Channel(); ch == version.ChannelEdge || ch == version.ChannelDevelopment {
config.BetaBindPort = 3001
}
}
// getConfigFilename returns path to the current config file
func (c *configuration) getConfigFilename() string {
configFile, err := filepath.EvalSymlinks(Context.configFilename)
if err != nil {
if !os.IsNotExist(err) {
if !errors.Is(err, os.ErrNotExist) {
log.Error("unexpected error while config file path evaluation: %s", err)
}
configFile = Context.configFilename

View File

@@ -2,6 +2,7 @@ package home
import (
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
@@ -11,6 +12,7 @@ import (
"strings"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/version"
"github.com/AdguardTeam/golibs/log"
"github.com/NYTimes/gziphandler"
)
@@ -35,48 +37,52 @@ func httpError(w http.ResponseWriter, code int, format string, args ...interface
// ---------------
// dns run control
// ---------------
func addDNSAddress(dnsAddresses *[]string, addr string) {
func addDNSAddress(dnsAddresses *[]string, addr net.IP) {
hostport := addr.String()
if config.DNS.Port != 53 {
addr = fmt.Sprintf("%s:%d", addr, config.DNS.Port)
hostport = net.JoinHostPort(hostport, strconv.Itoa(config.DNS.Port))
}
*dnsAddresses = append(*dnsAddresses, addr)
*dnsAddresses = append(*dnsAddresses, hostport)
}
func handleStatus(w http.ResponseWriter, r *http.Request) {
c := dnsforward.FilteringConfig{}
// statusResponse is a response for /control/status endpoint.
type statusResponse struct {
DNSAddrs []string `json:"dns_addresses"`
DNSPort int `json:"dns_port"`
HTTPPort int `json:"http_port"`
IsProtectionEnabled bool `json:"protection_enabled"`
// TODO(e.burkov): Inspect if front-end doesn't requires this field as
// openapi.yaml declares.
IsDHCPAvailable bool `json:"dhcp_available"`
IsRunning bool `json:"running"`
Version string `json:"version"`
Language string `json:"language"`
}
func handleStatus(w http.ResponseWriter, _ *http.Request) {
resp := statusResponse{
DNSAddrs: getDNSAddresses(),
DNSPort: config.DNS.Port,
HTTPPort: config.BindPort,
IsRunning: isRunning(),
Version: version.Version(),
Language: config.Language,
}
var c *dnsforward.FilteringConfig
if Context.dnsServer != nil {
Context.dnsServer.WriteDiskConfig(&c)
c = &dnsforward.FilteringConfig{}
Context.dnsServer.WriteDiskConfig(c)
resp.IsProtectionEnabled = c.ProtectionEnabled
}
data := map[string]interface{}{
"dns_addresses": getDNSAddresses(),
"http_port": config.BindPort,
"dns_port": config.DNS.Port,
"running": isRunning(),
"version": versionString,
"language": config.Language,
"protection_enabled": c.ProtectionEnabled,
// IsDHCPAvailable field is now false by default for Windows.
if runtime.GOOS != "windows" {
resp.IsDHCPAvailable = Context.dhcpServer != nil
}
if runtime.GOOS == "windows" {
// Set the DHCP to false explicitly, because Context.dhcpServer
// is probably not nil, despite the fact that there is no
// support for DHCP on Windows in AdGuardHome.
//
// See also the TODO in dhcpd.Create.
data["dhcp_available"] = false
} else {
data["dhcp_available"] = (Context.dhcpServer != nil)
}
jsonVal, err := json.Marshal(data)
if err != nil {
httpError(w, http.StatusInternalServerError, "Unable to marshal status json: %s", err)
return
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal)
err := json.NewEncoder(w).Encode(resp)
if err != nil {
httpError(w, http.StatusInternalServerError, "Unable to write response json: %s", err)
return
@@ -89,7 +95,7 @@ type profileJSON struct {
func handleGetProfile(w http.ResponseWriter, r *http.Request) {
pj := profileJSON{}
u := Context.auth.GetCurrentUser(r)
u := Context.auth.getCurrentUser(r)
pj.Name = u.Name
data, err := json.Marshal(pj)
@@ -118,7 +124,7 @@ func registerControlHandlers() {
}
func httpRegister(method, url string, handler func(http.ResponseWriter, *http.Request)) {
if len(method) == 0 {
if method == "" {
// "/dns-query" handler doesn't need auth, gzip and isn't restricted by 1 HTTP method
Context.mux.HandleFunc(url, postInstall(handler))
return
@@ -139,7 +145,7 @@ func ensure(method string, handler func(http.ResponseWriter, *http.Request)) fun
return
}
if method == "POST" || method == "PUT" || method == "DELETE" {
if method == http.MethodPost || method == http.MethodPut || method == http.MethodDelete {
Context.controlLock.Lock()
defer Context.controlLock.Unlock()
}
@@ -149,11 +155,11 @@ func ensure(method string, handler func(http.ResponseWriter, *http.Request)) fun
}
func ensurePOST(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
return ensure("POST", handler)
return ensure(http.MethodPost, handler)
}
func ensureGET(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
return ensure("GET", handler)
return ensure(http.MethodGet, handler)
}
// Bridge between http.Handler object and Go function
@@ -197,37 +203,84 @@ func preInstallHandler(handler http.Handler) http.Handler {
return &preInstallHandlerStruct{handler}
}
// postInstall lets the handler run only if firstRun is false, and redirects to /install.html otherwise
// it also enforces HTTPS if it is enabled and configured
const defaultHTTPSPort = 443
// handleHTTPSRedirect redirects the request to HTTPS, if needed. If ok is
// true, the middleware must continue handling the request.
func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (ok bool) {
web := Context.web
if web.httpsServer.server == nil {
return true
}
host, _, err := net.SplitHostPort(r.Host)
if err != nil {
// Check for the missing port error. If it is that error, just
// use the host as is.
//
// See the source code for net.SplitHostPort.
const missingPort = "missing port in address"
addrErr := &net.AddrError{}
if !errors.As(err, &addrErr) || addrErr.Err != missingPort {
httpError(w, http.StatusBadRequest, "bad host: %s", err)
return false
}
host = r.Host
}
if r.TLS == nil && web.forceHTTPS {
hostPort := host
if port := web.conf.PortHTTPS; port != defaultHTTPSPort {
portStr := strconv.Itoa(port)
hostPort = net.JoinHostPort(host, portStr)
}
httpsURL := &url.URL{
Scheme: "https",
Host: hostPort,
Path: r.URL.Path,
RawQuery: r.URL.RawQuery,
}
http.Redirect(w, r, httpsURL.String(), http.StatusTemporaryRedirect)
return false
}
// Allow the frontend from the HTTP origin to send requests to the HTTPS
// server. This can happen when the user has just set up HTTPS with
// redirects. Prevent cache-related errors by setting the Vary header.
//
// See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin.
originURL := &url.URL{
Scheme: "http",
Host: r.Host,
}
w.Header().Set("Access-Control-Allow-Origin", originURL.String())
w.Header().Set("Vary", "Origin")
return true
}
// postInstall lets the handler to run only if firstRun is false. Otherwise, it
// redirects to /install.html. It also enforces HTTPS if it is enabled and
// configured and sets appropriate access control headers.
func postInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
if Context.firstRun &&
!strings.HasPrefix(r.URL.Path, "/install.") &&
!strings.HasPrefix(r.URL.Path, "/assets/") {
path := r.URL.Path
if Context.firstRun && !strings.HasPrefix(path, "/install.") &&
!strings.HasPrefix(path, "/assets/") {
http.Redirect(w, r, "/install.html", http.StatusFound)
return
}
// enforce https?
if r.TLS == nil && Context.web.forceHTTPS && Context.web.httpsServer.server != nil {
// yes, and we want host from host:port
host, _, err := net.SplitHostPort(r.Host)
if err != nil {
// no port in host
host = r.Host
}
// construct new URL to redirect to
newURL := url.URL{
Scheme: "https",
Host: net.JoinHostPort(host, strconv.Itoa(Context.web.portHTTPS)),
Path: r.URL.Path,
RawQuery: r.URL.RawQuery,
}
http.Redirect(w, r, newURL.String(), http.StatusTemporaryRedirect)
if !handleHTTPSRedirect(w, r) {
return
}
w.Header().Set("Access-Control-Allow-Origin", "*")
handler(w, r)
}
}

View File

@@ -68,8 +68,8 @@ kXS9jgARhhiWXJrk
data.KeyType == "RSA" &&
data.Subject == "CN=AdGuard Home,O=AdGuard Ltd" &&
data.Issuer == "CN=AdGuard Home,O=AdGuard Ltd" &&
data.NotBefore == notBefore &&
data.NotAfter == notAfter &&
data.NotBefore.Equal(notBefore) &&
data.NotAfter.Equal(notAfter) &&
// data.DNSNames[0] == &&
data.ValidPair) {
t.Fatalf("valid cert & priv key: validateCertificates(): %v", data)

View File

@@ -369,7 +369,7 @@ type checkHostResp struct {
// for FilteredBlockedService:
SvcName string `json:"service_name"`
// for ReasonRewrite:
// for Rewrite:
CanonName string `json:"cname"` // CNAME value
IPList []net.IP `json:"ip_addrs"` // list of IP addresses
}
@@ -417,14 +417,14 @@ func (f *Filtering) handleCheckHost(w http.ResponseWriter, r *http.Request) {
// RegisterFilteringHandlers - register handlers
func (f *Filtering) RegisterFilteringHandlers() {
httpRegister("GET", "/control/filtering/status", f.handleFilteringStatus)
httpRegister("POST", "/control/filtering/config", f.handleFilteringConfig)
httpRegister("POST", "/control/filtering/add_url", f.handleFilteringAddURL)
httpRegister("POST", "/control/filtering/remove_url", f.handleFilteringRemoveURL)
httpRegister("POST", "/control/filtering/set_url", f.handleFilteringSetURL)
httpRegister("POST", "/control/filtering/refresh", f.handleFilteringRefresh)
httpRegister("POST", "/control/filtering/set_rules", f.handleFilteringSetRules)
httpRegister("GET", "/control/filtering/check_host", f.handleCheckHost)
httpRegister(http.MethodGet, "/control/filtering/status", f.handleFilteringStatus)
httpRegister(http.MethodPost, "/control/filtering/config", f.handleFilteringConfig)
httpRegister(http.MethodPost, "/control/filtering/add_url", f.handleFilteringAddURL)
httpRegister(http.MethodPost, "/control/filtering/remove_url", f.handleFilteringRemoveURL)
httpRegister(http.MethodPost, "/control/filtering/set_url", f.handleFilteringSetURL)
httpRegister(http.MethodPost, "/control/filtering/refresh", f.handleFilteringRefresh)
httpRegister(http.MethodPost, "/control/filtering/set_rules", f.handleFilteringSetRules)
httpRegister(http.MethodGet, "/control/filtering/check_host", f.handleCheckHost)
}
func checkFiltersUpdateIntervalHours(i uint32) bool {

View File

@@ -12,6 +12,8 @@ import (
"path/filepath"
"runtime"
"strconv"
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/util"
@@ -20,23 +22,16 @@ import (
"github.com/AdguardTeam/golibs/log"
)
type firstRunData struct {
WebPort int `json:"web_port"`
DNSPort int `json:"dns_port"`
Interfaces map[string]interface{} `json:"interfaces"`
// getAddrsResponse is the response for /install/get_addresses endpoint.
type getAddrsResponse struct {
WebPort int `json:"web_port"`
DNSPort int `json:"dns_port"`
Interfaces map[string]*util.NetInterface `json:"interfaces"`
}
type netInterfaceJSON struct {
Name string `json:"name"`
MTU int `json:"mtu"`
HardwareAddr string `json:"hardware_address"`
Addresses []string `json:"ip_addresses"`
Flags string `json:"flags"`
}
// Get initial installation settings
// handleInstallGetAddresses is the handler for /install/get_addresses endpoint.
func (web *Web) handleInstallGetAddresses(w http.ResponseWriter, r *http.Request) {
data := firstRunData{}
data := getAddrsResponse{}
data.WebPort = 80
data.DNSPort = 53
@@ -46,16 +41,9 @@ func (web *Web) handleInstallGetAddresses(w http.ResponseWriter, r *http.Request
return
}
data.Interfaces = make(map[string]interface{})
data.Interfaces = make(map[string]*util.NetInterface)
for _, iface := range ifaces {
ifaceJSON := netInterfaceJSON{
Name: iface.Name,
MTU: iface.MTU,
HardwareAddr: iface.HardwareAddr,
Addresses: iface.Addresses,
Flags: iface.Flags,
}
data.Interfaces[iface.Name] = ifaceJSON
data.Interfaces[iface.Name] = iface
}
w.Header().Set("Content-Type", "application/json")
@@ -68,7 +56,7 @@ func (web *Web) handleInstallGetAddresses(w http.ResponseWriter, r *http.Request
type checkConfigReqEnt struct {
Port int `json:"port"`
IP string `json:"ip"`
IP net.IP `json:"ip"`
Autofix bool `json:"autofix"`
}
@@ -105,10 +93,10 @@ func (web *Web) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request)
return
}
if reqData.Web.Port != 0 && reqData.Web.Port != config.BindPort {
if reqData.Web.Port != 0 && reqData.Web.Port != config.BindPort && reqData.Web.Port != config.BetaBindPort {
err = util.CheckPortAvailable(reqData.Web.IP, reqData.Web.Port)
if err != nil {
respData.Web.Status = fmt.Sprintf("%v", err)
respData.Web.Status = err.Error()
}
}
@@ -136,8 +124,8 @@ func (web *Web) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request)
}
if err != nil {
respData.DNS.Status = fmt.Sprintf("%v", err)
} else if reqData.DNS.IP != "0.0.0.0" {
respData.DNS.Status = err.Error()
} else if !reqData.DNS.IP.IsUnspecified() {
respData.StaticIP = handleStaticIP(reqData.DNS.IP, reqData.SetStaticIP)
}
}
@@ -153,7 +141,7 @@ func (web *Web) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request)
// handleStaticIP - handles static IP request
// It either checks if we have a static IP
// Or if set=true, it tries to set it
func handleStaticIP(ip string, set bool) staticIPJSON {
func handleStaticIP(ip net.IP, set bool) staticIPJSON {
resp := staticIPJSON{}
interfaceName := util.GetInterfaceByIP(ip)
@@ -185,7 +173,7 @@ func handleStaticIP(ip string, set bool) staticIPJSON {
if isStaticIP {
resp.Static = "yes"
}
resp.IP = util.GetSubnet(interfaceName)
resp.IP = util.GetSubnet(interfaceName).String()
}
return resp
}
@@ -261,7 +249,7 @@ func disableDNSStubListener() error {
}
type applyConfigReqEnt struct {
IP string `json:"ip"`
IP net.IP `json:"ip"`
Port int `json:"port"`
}
@@ -276,10 +264,14 @@ type applyConfigReq struct {
func copyInstallSettings(dst, src *configuration) {
dst.BindHost = src.BindHost
dst.BindPort = src.BindPort
dst.BetaBindPort = src.BetaBindPort
dst.DNS.BindHost = src.DNS.BindHost
dst.DNS.Port = src.DNS.Port
}
// shutdownTimeout is the timeout for shutting HTTP server down operation.
const shutdownTimeout = 5 * time.Second
// Apply new configuration, start DNS server, restart Web server
func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
newSettings := applyConfigReq{}
@@ -295,7 +287,7 @@ func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
}
restartHTTP := true
if config.BindHost == newSettings.Web.IP && config.BindPort == newSettings.Web.Port {
if config.BindHost.Equal(newSettings.Web.IP) && config.BindPort == newSettings.Web.Port {
// no need to rebind
restartHTTP = false
}
@@ -305,9 +297,10 @@ func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
err = util.CheckPortAvailable(newSettings.Web.IP, newSettings.Web.Port)
if err != nil {
httpError(w, http.StatusBadRequest, "Impossible to listen on IP:port %s due to %s",
net.JoinHostPort(newSettings.Web.IP, strconv.Itoa(newSettings.Web.Port)), err)
net.JoinHostPort(newSettings.Web.IP.String(), strconv.Itoa(newSettings.Web.Port)), err)
return
}
}
err = util.CheckPacketPortAvailable(newSettings.DNS.IP, newSettings.DNS.Port)
@@ -331,6 +324,10 @@ func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
config.DNS.BindHost = newSettings.DNS.IP
config.DNS.Port = newSettings.DNS.Port
// TODO(e.burkov): StartMods() should be put in a separate goroutine at
// the moment we'll allow setting up TLS in the initial configuration or
// the configuration itself will use HTTPS protocol, because the
// underlying functions potentially restart the HTTPS server.
err = StartMods()
if err != nil {
Context.firstRun = true
@@ -362,12 +359,23 @@ func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
f.Flush()
}
// this needs to be done in a goroutine because Shutdown() is a blocking call, and it will block
// until all requests are finished, and _we_ are inside a request right now, so it will block indefinitely
// The Shutdown() method of (*http.Server) needs to be called in a
// separate goroutine, because it waits until all requests are handled
// and will be blocked by it's own caller.
if restartHTTP {
go func() {
_ = Context.web.httpServer.Shutdown(context.TODO())
}()
ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
shut := func(srv *http.Server) {
defer cancel()
err := srv.Shutdown(ctx)
if err != nil {
log.Debug("error while shutting down HTTP server: %s", err)
}
}
go shut(web.httpServer)
if web.httpServerBeta != nil {
go shut(web.httpServerBeta)
}
}
}
@@ -376,3 +384,186 @@ func (web *Web) registerInstallHandlers() {
Context.mux.HandleFunc("/control/install/check_config", preInstall(ensurePOST(web.handleInstallCheckConfig)))
Context.mux.HandleFunc("/control/install/configure", preInstall(ensurePOST(web.handleInstallConfigure)))
}
// checkConfigReqEntBeta is a struct representing new client's config check
// request entry. It supports multiple IP values unlike the checkConfigReqEnt.
//
// TODO(e.burkov): This should removed with the API v1 when the appropriate
// functionality will appear in default checkConfigReqEnt.
type checkConfigReqEntBeta struct {
Port int `json:"port"`
IP []net.IP `json:"ip"`
Autofix bool `json:"autofix"`
}
// checkConfigReqBeta is a struct representing new client's config check request
// body. It uses checkConfigReqEntBeta instead of checkConfigReqEnt.
//
// TODO(e.burkov): This should removed with the API v1 when the appropriate
// functionality will appear in default checkConfigReq.
type checkConfigReqBeta struct {
Web checkConfigReqEntBeta `json:"web"`
DNS checkConfigReqEntBeta `json:"dns"`
SetStaticIP bool `json:"set_static_ip"`
}
// handleInstallCheckConfigBeta is a substitution of /install/check_config
// handler for new client.
//
// TODO(e.burkov): This should removed with the API v1 when the appropriate
// functionality will appear in default handleInstallCheckConfig.
func (web *Web) handleInstallCheckConfigBeta(w http.ResponseWriter, r *http.Request) {
reqData := checkConfigReqBeta{}
err := json.NewDecoder(r.Body).Decode(&reqData)
if err != nil {
httpError(w, http.StatusBadRequest, "Failed to parse 'check_config' JSON data: %s", err)
return
}
if len(reqData.DNS.IP) == 0 || len(reqData.Web.IP) == 0 {
httpError(w, http.StatusBadRequest, http.StatusText(http.StatusBadRequest))
return
}
nonBetaReqData := checkConfigReq{
Web: checkConfigReqEnt{
Port: reqData.Web.Port,
IP: reqData.Web.IP[0],
Autofix: reqData.Web.Autofix,
},
DNS: checkConfigReqEnt{
Port: reqData.DNS.Port,
IP: reqData.DNS.IP[0],
Autofix: reqData.DNS.Autofix,
},
SetStaticIP: reqData.SetStaticIP,
}
nonBetaReqBody := &strings.Builder{}
err = json.NewEncoder(nonBetaReqBody).Encode(nonBetaReqData)
if err != nil {
httpError(w, http.StatusBadRequest, "Failed to encode 'check_config' JSON data: %s", err)
return
}
body := nonBetaReqBody.String()
r.Body = ioutil.NopCloser(strings.NewReader(body))
r.ContentLength = int64(len(body))
web.handleInstallCheckConfig(w, r)
}
// applyConfigReqEntBeta is a struct representing new client's config setting
// request entry. It supports multiple IP values unlike the applyConfigReqEnt.
//
// TODO(e.burkov): This should removed with the API v1 when the appropriate
// functionality will appear in default applyConfigReqEnt.
type applyConfigReqEntBeta struct {
IP []net.IP `json:"ip"`
Port int `json:"port"`
}
// applyConfigReqBeta is a struct representing new client's config setting
// request body. It uses applyConfigReqEntBeta instead of applyConfigReqEnt.
//
// TODO(e.burkov): This should removed with the API v1 when the appropriate
// functionality will appear in default applyConfigReq.
type applyConfigReqBeta struct {
Web applyConfigReqEntBeta `json:"web"`
DNS applyConfigReqEntBeta `json:"dns"`
Username string `json:"username"`
Password string `json:"password"`
}
// handleInstallConfigureBeta is a substitution of /install/configure handler
// for new client.
//
// TODO(e.burkov): This should removed with the API v1 when the appropriate
// functionality will appear in default handleInstallConfigure.
func (web *Web) handleInstallConfigureBeta(w http.ResponseWriter, r *http.Request) {
reqData := applyConfigReqBeta{}
err := json.NewDecoder(r.Body).Decode(&reqData)
if err != nil {
httpError(w, http.StatusBadRequest, "Failed to parse 'check_config' JSON data: %s", err)
return
}
if len(reqData.DNS.IP) == 0 || len(reqData.Web.IP) == 0 {
httpError(w, http.StatusBadRequest, http.StatusText(http.StatusBadRequest))
return
}
nonBetaReqData := applyConfigReq{
Web: applyConfigReqEnt{
IP: reqData.Web.IP[0],
Port: reqData.Web.Port,
},
DNS: applyConfigReqEnt{
IP: reqData.DNS.IP[0],
Port: reqData.DNS.Port,
},
Username: reqData.Username,
Password: reqData.Password,
}
nonBetaReqBody := &strings.Builder{}
err = json.NewEncoder(nonBetaReqBody).Encode(nonBetaReqData)
if err != nil {
httpError(w, http.StatusBadRequest, "Failed to encode 'check_config' JSON data: %s", err)
return
}
body := nonBetaReqBody.String()
r.Body = ioutil.NopCloser(strings.NewReader(body))
r.ContentLength = int64(len(body))
web.handleInstallConfigure(w, r)
}
// getAddrsResponseBeta is a struct representing new client's getting addresses
// request body. It uses array of structs instead of map.
//
// TODO(e.burkov): This should removed with the API v1 when the appropriate
// functionality will appear in default firstRunData.
type getAddrsResponseBeta struct {
WebPort int `json:"web_port"`
DNSPort int `json:"dns_port"`
Interfaces []*util.NetInterface `json:"interfaces"`
}
// handleInstallConfigureBeta is a substitution of /install/get_addresses
// handler for new client.
//
// TODO(e.burkov): This should removed with the API v1 when the appropriate
// functionality will appear in default handleInstallGetAddresses.
func (web *Web) handleInstallGetAddressesBeta(w http.ResponseWriter, r *http.Request) {
data := getAddrsResponseBeta{}
data.WebPort = 80
data.DNSPort = 53
ifaces, err := util.GetValidNetInterfacesForWeb()
if err != nil {
httpError(w, http.StatusInternalServerError, "Couldn't get interfaces: %s", err)
return
}
data.Interfaces = ifaces
w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(data)
if err != nil {
httpError(w, http.StatusInternalServerError, "Unable to marshal default addresses to json: %s", err)
return
}
}
// registerBetaInstallHandlers registers the install handlers for new client
// with the structures it supports.
//
// TODO(e.burkov): This should removed with the API v1 when the appropriate
// functionality will appear in default handlers.
func (web *Web) registerBetaInstallHandlers() {
Context.mux.HandleFunc("/control/install/get_addresses_beta", preInstall(ensureGET(web.handleInstallGetAddressesBeta)))
Context.mux.HandleFunc("/control/install/check_config_beta", preInstall(ensurePOST(web.handleInstallCheckConfigBeta)))
Context.mux.HandleFunc("/control/install/configure_beta", preInstall(ensurePOST(web.handleInstallConfigureBeta)))
}

View File

@@ -1,76 +1,104 @@
package home
import (
"context"
"encoding/json"
"errors"
"net/http"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"syscall"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/sysutil"
"github.com/AdguardTeam/AdGuardHome/internal/update"
"github.com/AdguardTeam/AdGuardHome/internal/updater"
"github.com/AdguardTeam/golibs/log"
)
type getVersionJSONRequest struct {
RecheckNow bool `json:"recheck_now"`
// temporaryError is the interface for temporary errors from the Go standard
// library.
type temporaryError interface {
error
Temporary() (ok bool)
}
// Get the latest available version from the Internet
func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
resp := &versionResponse{}
if Context.disableUpdate {
resp := make(map[string]interface{})
resp["disabled"] = true
d, _ := json.Marshal(resp)
_, _ = w.Write(d)
// w.Header().Set("Content-Type", "application/json")
resp.Disabled = true
_ = json.NewEncoder(w).Encode(resp)
// TODO(e.burkov): Add error handling and deal with headers.
return
}
req := getVersionJSONRequest{}
req := &struct {
Recheck bool `json:"recheck_now"`
}{}
var err error
if r.ContentLength != 0 {
err = json.NewDecoder(r.Body).Decode(&req)
err = json.NewDecoder(r.Body).Decode(req)
if err != nil {
httpError(w, http.StatusBadRequest, "JSON parse: %s", err)
return
}
}
var info update.VersionInfo
for i := 0; i != 3; i++ {
Context.controlLock.Lock()
info, err = Context.updater.GetVersionResponse(req.RecheckNow)
Context.controlLock.Unlock()
if err != nil && strings.HasSuffix(err.Error(), "i/o timeout") {
// This case may happen while we're restarting DNS server
// https://github.com/AdguardTeam/AdGuardHome/internal/issues/934
continue
func() {
Context.controlLock.Lock()
defer Context.controlLock.Unlock()
resp.VersionInfo, err = Context.updater.VersionInfo(req.Recheck)
}()
if err != nil {
var terr temporaryError
if errors.As(err, &terr) && terr.Temporary() {
// Temporary network error. This case may happen while
// we're restarting our DNS server. Log and sleep for
// some time.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/934.
d := time.Duration(i) * time.Second
log.Info("temp net error: %q; sleeping for %s and retrying", err, d)
time.Sleep(d)
continue
}
}
break
}
if err != nil {
httpError(w, http.StatusBadGateway, "Couldn't get version check json from %s: %T %s\n", versionCheckURL, err, err)
vcu := Context.updater.VersionCheckURL()
// TODO(a.garipov): Figure out the purpose of %T verb.
httpError(w, http.StatusBadGateway, "Couldn't get version check json from %s: %T %s\n", vcu, err, err)
return
}
resp.confirmAutoUpdate()
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(getVersionResp(info))
err = json.NewEncoder(w).Encode(resp)
if err != nil {
httpError(w, http.StatusInternalServerError, "Couldn't write body: %s", err)
}
}
// Perform an update procedure to the latest available version
// handleUpdate performs an update to the latest available version procedure.
func handleUpdate(w http.ResponseWriter, _ *http.Request) {
if len(Context.updater.NewVersion) == 0 {
if Context.updater.NewVersion() == "" {
httpError(w, http.StatusBadRequest, "/update request isn't allowed now")
return
}
err := Context.updater.DoUpdate()
err := Context.updater.Update()
if err != nil {
httpError(w, http.StatusInternalServerError, "%s", err)
return
@@ -81,24 +109,33 @@ func handleUpdate(w http.ResponseWriter, _ *http.Request) {
f.Flush()
}
go finishUpdate()
// The background context is used because the underlying functions wrap
// it with timeout and shut down the server, which handles current
// request. It also should be done in a separate goroutine due to the
// same reason.
go func() {
finishUpdate(context.Background())
}()
}
// Convert version.json data to our JSON response
func getVersionResp(info update.VersionInfo) []byte {
ret := make(map[string]interface{})
ret["can_autoupdate"] = false
ret["new_version"] = info.NewVersion
ret["announcement"] = info.Announcement
ret["announcement_url"] = info.AnnouncementURL
// versionResponse is the response for /control/version.json endpoint.
type versionResponse struct {
Disabled bool `json:"disabled"`
updater.VersionInfo
}
if info.CanAutoUpdate {
// confirmAutoUpdate checks the real possibility of auto update.
func (vr *versionResponse) confirmAutoUpdate() {
if vr.CanAutoUpdate != nil && *vr.CanAutoUpdate {
canUpdate := true
tlsConf := tlsConfigSettings{}
Context.tls.WriteDiskConfig(&tlsConf)
var tlsConf *tlsConfigSettings
if runtime.GOOS != "windows" {
tlsConf = &tlsConfigSettings{}
Context.tls.WriteDiskConfig(tlsConf)
}
if runtime.GOOS != "windows" &&
if tlsConf != nil &&
((tlsConf.Enabled && (tlsConf.PortHTTPS < 1024 ||
tlsConf.PortDNSOverTLS < 1024 ||
tlsConf.PortDNSOverQUIC < 1024)) ||
@@ -106,17 +143,14 @@ func getVersionResp(info update.VersionInfo) []byte {
config.DNS.Port < 1024) {
canUpdate, _ = sysutil.CanBindPrivilegedPorts()
}
ret["can_autoupdate"] = canUpdate
vr.CanAutoUpdate = &canUpdate
}
d, _ := json.Marshal(ret)
return d
}
// Complete an update procedure
func finishUpdate() {
// finishUpdate completes an update procedure.
func finishUpdate(ctx context.Context) {
log.Info("Stopping all tasks")
cleanup()
cleanup(ctx)
cleanupAlways()
exeName := "AdGuardHome"

View File

@@ -1,102 +0,0 @@
// +build ignore
package home
import (
"os"
"testing"
"github.com/stretchr/testify/assert"
)
func TestDoUpdate(t *testing.T) {
config.DNS.Port = 0
Context.workDir = "..." // set absolute path
newver := "v0.96"
data := `{
"version": "v0.96",
"announcement": "AdGuard Home v0.96 is now available!",
"announcement_url": "",
"download_windows_amd64": "https://static.adguard.com/adguardhome/beta/AdGuardHome_windows_amd64.zip",
"download_windows_386": "https://static.adguard.com/adguardhome/beta/AdGuardHome_windows_386.zip",
"download_darwin_amd64": "https://static.adguard.com/adguardhome/beta/AdGuardHome_darwin_amd64.zip",
"download_darwin_386": "https://static.adguard.com/adguardhome/beta/AdGuardHome_darwin_386.zip",
"download_linux_amd64": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_amd64.tar.gz",
"download_linux_386": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_386.tar.gz",
"download_linux_arm": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_armv6.tar.gz",
"download_linux_armv5": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_armv5.tar.gz",
"download_linux_armv6": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_armv6.tar.gz",
"download_linux_armv7": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_armv7.tar.gz",
"download_linux_arm64": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_arm64.tar.gz",
"download_linux_mips": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_mips_softfloat.tar.gz",
"download_linux_mipsle": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_mipsle_softfloat.tar.gz",
"download_linux_mips64": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_mips64_softfloat.tar.gz",
"download_linux_mips64le": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_mips64le_softfloat.tar.gz",
"download_freebsd_386": "https://static.adguard.com/adguardhome/beta/AdGuardHome_freebsd_386.tar.gz",
"download_freebsd_amd64": "https://static.adguard.com/adguardhome/beta/AdGuardHome_freebsd_amd64.tar.gz",
"download_freebsd_arm": "https://static.adguard.com/adguardhome/beta/AdGuardHome_freebsd_armv6.tar.gz",
"download_freebsd_armv5": "https://static.adguard.com/adguardhome/beta/AdGuardHome_freebsd_armv5.tar.gz",
"download_freebsd_armv6": "https://static.adguard.com/adguardhome/beta/AdGuardHome_freebsd_armv6.tar.gz",
"download_freebsd_armv7": "https://static.adguard.com/adguardhome/beta/AdGuardHome_freebsd_armv7.tar.gz",
"download_freebsd_arm64": "https://static.adguard.com/adguardhome/beta/AdGuardHome_freebsd_arm64.tar.gz"
}`
uu, err := getUpdateInfo([]byte(data))
if err != nil {
t.Fatalf("getUpdateInfo: %s", err)
}
u := updateInfo{
pkgURL: "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_armv6.tar.gz",
pkgName: Context.workDir + "/agh-update-" + newver + "/AdGuardHome_linux_amd64.tar.gz",
newVer: newver,
updateDir: Context.workDir + "/agh-update-" + newver,
backupDir: Context.workDir + "/agh-backup",
configName: Context.workDir + "/AdGuardHome.yaml",
updateConfigName: Context.workDir + "/agh-update-" + newver + "/AdGuardHome/internal/AdGuardHome.yaml",
curBinName: Context.workDir + "/AdGuardHome",
bkpBinName: Context.workDir + "/agh-backup/AdGuardHome",
newBinName: Context.workDir + "/agh-update-" + newver + "/AdGuardHome/internal/AdGuardHome",
}
assert.Equal(t, uu.pkgURL, u.pkgURL)
assert.Equal(t, uu.pkgName, u.pkgName)
assert.Equal(t, uu.newVer, u.newVer)
assert.Equal(t, uu.updateDir, u.updateDir)
assert.Equal(t, uu.backupDir, u.backupDir)
assert.Equal(t, uu.configName, u.configName)
assert.Equal(t, uu.updateConfigName, u.updateConfigName)
assert.Equal(t, uu.curBinName, u.curBinName)
assert.Equal(t, uu.bkpBinName, u.bkpBinName)
assert.Equal(t, uu.newBinName, u.newBinName)
e := doUpdate(&u)
if e != nil {
t.Fatalf("FAILED: %s", e)
}
os.RemoveAll(u.backupDir)
}
func TestTargzFileUnpack(t *testing.T) {
fn := "../dist/AdGuardHome_linux_amd64.tar.gz"
outdir := "../test-unpack"
defer os.RemoveAll(outdir)
_ = os.Mkdir(outdir, 0o755)
files, e := targzFileUnpack(fn, outdir)
if e != nil {
t.Fatalf("FAILED: %s", e)
}
t.Logf("%v", files)
}
func TestZipFileUnpack(t *testing.T) {
fn := "../dist/AdGuardHome_windows_amd64.zip"
outdir := "../test-unpack"
_ = os.Mkdir(outdir, 0o755)
files, e := zipFileUnpack(fn, outdir)
if e != nil {
t.Fatalf("FAILED: %s", e)
}
t.Logf("%v", files)
os.RemoveAll(outdir)
}

View File

@@ -3,8 +3,10 @@ package home
import (
"fmt"
"net"
"net/url"
"os"
"path/filepath"
"strconv"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/AdGuardHome/internal/dnsfilter"
@@ -55,10 +57,10 @@ func initDNSServer() error {
filterConf := config.DNS.DnsfilterConf
bindhost := config.DNS.BindHost
if config.DNS.BindHost == "0.0.0.0" {
bindhost = "127.0.0.1"
if config.DNS.BindHost.IsUnspecified() {
bindhost = net.IPv4(127, 0, 0, 1)
}
filterConf.ResolverAddress = fmt.Sprintf("%s:%d", bindhost, config.DNS.Port)
filterConf.ResolverAddress = net.JoinHostPort(bindhost.String(), strconv.Itoa(config.DNS.Port))
filterConf.AutoHosts = &Context.autoHosts
filterConf.ConfigModified = onConfigModified
filterConf.HTTPRegister = httpRegister
@@ -98,26 +100,24 @@ func isRunning() bool {
}
func onDNSRequest(d *proxy.DNSContext) {
ip := dnsforward.GetIPString(d.Addr)
if ip == "" {
ip := dnsforward.IPFromAddr(d.Addr)
if ip == nil {
// This would be quite weird if we get here
return
}
ipAddr := net.ParseIP(ip)
if !ipAddr.IsLoopback() {
if !ip.IsLoopback() {
Context.rdns.Begin(ip)
}
if !Context.ipDetector.detectSpecialNetwork(ipAddr) {
if !Context.ipDetector.detectSpecialNetwork(ip) {
Context.whois.Begin(ip)
}
}
func generateServerConfig() (newconfig dnsforward.ServerConfig, err error) {
bindHost := net.ParseIP(config.DNS.BindHost)
newconfig = dnsforward.ServerConfig{
UDPListenAddr: &net.UDPAddr{IP: bindHost, Port: config.DNS.Port},
TCPListenAddr: &net.TCPAddr{IP: bindHost, Port: config.DNS.Port},
UDPListenAddr: &net.UDPAddr{IP: config.DNS.BindHost, Port: config.DNS.Port},
TCPListenAddr: &net.TCPAddr{IP: config.DNS.BindHost, Port: config.DNS.Port},
FilteringConfig: config.DNS.FilteringConfig,
ConfigModified: onConfigModified,
HTTPRegister: httpRegister,
@@ -128,23 +128,24 @@ func generateServerConfig() (newconfig dnsforward.ServerConfig, err error) {
Context.tls.WriteDiskConfig(&tlsConf)
if tlsConf.Enabled {
newconfig.TLSConfig = tlsConf.TLSConfig
newconfig.TLSConfig.ServerName = tlsConf.ServerName
if tlsConf.PortDNSOverTLS != 0 {
newconfig.TLSListenAddr = &net.TCPAddr{
IP: bindHost,
IP: config.DNS.BindHost,
Port: tlsConf.PortDNSOverTLS,
}
}
if tlsConf.PortDNSOverQUIC != 0 {
newconfig.QUICListenAddr = &net.UDPAddr{
IP: bindHost,
IP: config.DNS.BindHost,
Port: int(tlsConf.PortDNSOverQUIC),
}
}
if tlsConf.PortDNSCrypt != 0 {
newconfig.DNSCryptConfig, err = newDNSCrypt(bindHost, tlsConf)
newconfig.DNSCryptConfig, err = newDNSCrypt(config.DNS.BindHost, tlsConf)
if err != nil {
// Don't wrap the error, because it's already
// wrapped by newDNSCrypt.
@@ -209,43 +210,49 @@ type dnsEncryption struct {
quic string
}
func getDNSEncryption() dnsEncryption {
dnsEncryption := dnsEncryption{}
func getDNSEncryption() (de dnsEncryption) {
tlsConf := tlsConfigSettings{}
Context.tls.WriteDiskConfig(&tlsConf)
if tlsConf.Enabled && len(tlsConf.ServerName) != 0 {
hostname := tlsConf.ServerName
if tlsConf.PortHTTPS != 0 {
addr := tlsConf.ServerName
addr := hostname
if tlsConf.PortHTTPS != 443 {
addr = fmt.Sprintf("%s:%d", addr, tlsConf.PortHTTPS)
addr = net.JoinHostPort(addr, strconv.Itoa(tlsConf.PortHTTPS))
}
addr = fmt.Sprintf("https://%s/dns-query", addr)
dnsEncryption.https = addr
de.https = (&url.URL{
Scheme: "https",
Host: addr,
Path: "/dns-query",
}).String()
}
if tlsConf.PortDNSOverTLS != 0 {
addr := fmt.Sprintf("tls://%s:%d", tlsConf.ServerName, tlsConf.PortDNSOverTLS)
dnsEncryption.tls = addr
de.tls = (&url.URL{
Scheme: "tls",
Host: net.JoinHostPort(hostname, strconv.Itoa(tlsConf.PortDNSOverTLS)),
}).String()
}
if tlsConf.PortDNSOverQUIC != 0 {
addr := fmt.Sprintf("quic://%s:%d", tlsConf.ServerName, tlsConf.PortDNSOverQUIC)
dnsEncryption.quic = addr
de.quic = (&url.URL{
Scheme: "quic",
Host: net.JoinHostPort(hostname, strconv.Itoa(int(tlsConf.PortDNSOverQUIC))),
}).String()
}
}
return dnsEncryption
return de
}
// Get the list of DNS addresses the server is listening on
func getDNSAddresses() []string {
dnsAddresses := []string{}
if config.DNS.BindHost == "0.0.0.0" {
if config.DNS.BindHost.IsUnspecified() {
ifaces, e := util.GetValidNetInterfacesForWeb()
if e != nil {
log.Error("Couldn't get network interfaces: %v", e)
@@ -275,21 +282,26 @@ func getDNSAddresses() []string {
return dnsAddresses
}
// If a client has his own settings, apply them
func applyAdditionalFiltering(clientAddr string, setts *dnsfilter.RequestFilteringSettings) {
// applyAdditionalFiltering adds additional client information and settings if
// the client has them.
func applyAdditionalFiltering(clientAddr net.IP, clientID string, setts *dnsfilter.RequestFilteringSettings) {
Context.dnsFilter.ApplyBlockedServices(setts, nil, true)
if len(clientAddr) == 0 {
if clientAddr == nil {
return
}
setts.ClientIP = clientAddr
c, ok := Context.clients.Find(clientAddr)
c, ok := Context.clients.Find(clientID)
if !ok {
return
c, ok = Context.clients.Find(clientAddr.String())
if !ok {
return
}
}
log.Debug("Using settings for client %s with IP %s", c.Name, clientAddr)
log.Debug("using settings for client %s with ip %s and id %q", c.Name, clientAddr, clientID)
if c.UseOwnBlockedServices {
Context.dnsFilter.ApplyBlockedServices(setts, c.BlockedServices, false)
@@ -328,13 +340,11 @@ func startDNSServer() error {
Context.queryLog.Start()
const topClientsNumber = 100 // the number of clients to get
topClients := Context.stats.GetTopClientsIP(topClientsNumber)
for _, ip := range topClients {
ipAddr := net.ParseIP(ip)
if !ipAddr.IsLoopback() {
for _, ip := range Context.stats.GetTopClientsIP(topClientsNumber) {
if !ip.IsLoopback() {
Context.rdns.Begin(ip)
}
if !Context.ipDetector.detectSpecialNetwork(ipAddr) {
if !Context.ipDetector.detectSpecialNetwork(ip) {
Context.whois.Begin(ip)
}
}

View File

@@ -50,16 +50,17 @@ func TestFilters(t *testing.T) {
// download
ok, err := Context.filters.update(&f)
assert.Equal(t, nil, err)
assert.Nil(t, err)
assert.True(t, ok)
assert.Equal(t, 3, f.RulesCount)
// refresh
ok, err = Context.filters.update(&f)
assert.True(t, !ok && err == nil)
assert.False(t, ok)
assert.Nil(t, err)
err = Context.filters.load(&f)
assert.True(t, err == nil)
assert.Nil(t, err)
f.unload()
_ = os.Remove(f.Path())

View File

@@ -5,6 +5,7 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io/ioutil"
"net"
@@ -27,8 +28,9 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
"github.com/AdguardTeam/AdGuardHome/internal/stats"
"github.com/AdguardTeam/AdGuardHome/internal/sysutil"
"github.com/AdguardTeam/AdGuardHome/internal/update"
"github.com/AdguardTeam/AdGuardHome/internal/updater"
"github.com/AdguardTeam/AdGuardHome/internal/util"
"github.com/AdguardTeam/AdGuardHome/internal/version"
"github.com/AdguardTeam/golibs/log"
"gopkg.in/natefinch/lumberjack.v2"
)
@@ -38,15 +40,6 @@ const (
configSyslog = "syslog"
)
// Update-related variables
var (
versionString = "dev"
updateChannel = "none"
versionCheckURL = ""
ARMVersion = ""
MIPSVersion = ""
)
// Global context
type homeContext struct {
// Modules
@@ -65,7 +58,7 @@ type homeContext struct {
web *Web // Web (HTTP, HTTPS) module
tls *TLSMod // TLS module
autoHosts util.AutoHosts // IP-hostname pairs taken from system configuration (e.g. /etc/hosts) files
updater *update.Updater
updater *updater.Updater
ipDetector *ipDetector
@@ -99,14 +92,7 @@ func (c *homeContext) getDataDir() string {
var Context homeContext
// Main is the entry point
func Main(version, channel, armVer, mipsVer string) {
// Init update-related global variables
versionString = version
updateChannel = channel
ARMVersion = armVer
MIPSVersion = mipsVer
versionCheckURL = "https://static.adguard.com/adguardhome/" + updateChannel + "/version.json"
func Main() {
// config can be specified, which reads options from there, but other command line flags have to override config values
// therefore, we must do it manually instead of using a lib
args := loadOptions()
@@ -123,7 +109,7 @@ func Main(version, channel, armVer, mipsVer string) {
Context.tls.Reload()
default:
cleanup()
cleanup(context.Background())
cleanupAlways()
os.Exit(0)
}
@@ -139,23 +125,10 @@ func Main(version, channel, armVer, mipsVer string) {
run(args)
}
// version - returns the current version string
func version() string {
// TODO(a.garipov): I'm pretty sure we can extract some of this stuff
// from the build info.
msg := "AdGuard Home, version %s, channel %s, arch %s %s"
if ARMVersion != "" {
msg = msg + " v" + ARMVersion
} else if MIPSVersion != "" {
msg = msg + " " + MIPSVersion
}
return fmt.Sprintf(msg, versionString, updateChannel, runtime.GOOS, runtime.GOARCH)
}
func setupContext(args options) {
Context.runningAsService = args.runningAsService
Context.disableUpdate = args.disableUpdate
Context.disableUpdate = args.disableUpdate ||
version.Channel() == version.ChannelDevelopment
Context.firstRun = detectFirstRun()
if Context.firstRun {
@@ -214,15 +187,16 @@ func setupConfig(args options) {
Context.autoHosts.Init("")
Context.updater = update.NewUpdater(update.Config{
Client: Context.client,
WorkDir: Context.workDir,
VersionURL: versionCheckURL,
VersionString: versionString,
OS: runtime.GOOS,
Arch: runtime.GOARCH,
ARMVersion: ARMVersion,
ConfigName: config.getConfigFilename(),
Context.updater = updater.NewUpdater(&updater.Config{
Client: Context.client,
Version: version.Version(),
Channel: version.Channel(),
GOARCH: runtime.GOARCH,
GOOS: runtime.GOOS,
GOARM: version.GOARM(),
GOMIPS: version.GOMIPS(),
WorkDir: Context.workDir,
ConfName: config.getConfigFilename(),
})
Context.clients.Init(config.Clients, Context.dhcpServer, &Context.autoHosts)
@@ -234,7 +208,7 @@ func setupConfig(args options) {
}
// override bind host/port from the console
if args.bindHost != "" {
if args.bindHost != nil {
config.BindHost = args.bindHost
}
if args.bindPort != 0 {
@@ -260,7 +234,7 @@ func run(args options) {
memoryUsage(args)
// print the first message after logger is configured
log.Println(version())
log.Println(version.Full())
log.Debug("Current working directory is %s", Context.workDir)
if args.runningAsService {
log.Info("AdGuard Home is running as a service")
@@ -316,19 +290,25 @@ func run(args options) {
}
webConf := webConfig{
firstRun: Context.firstRun,
BindHost: config.BindHost,
BindPort: config.BindPort,
firstRun: Context.firstRun,
BindHost: config.BindHost,
BindPort: config.BindPort,
BetaBindPort: config.BetaBindPort,
ReadTimeout: ReadTimeout,
ReadHeaderTimeout: ReadHeaderTimeout,
WriteTimeout: WriteTimeout,
ReadTimeout: readTimeout,
ReadHeaderTimeout: readHdrTimeout,
WriteTimeout: writeTimeout,
}
Context.web = CreateWeb(&webConf)
if Context.web == nil {
log.Fatalf("Can't initialize Web module")
}
Context.ipDetector, err = newIPDetector()
if err != nil {
log.Fatal(err)
}
if !Context.firstRun {
err := initDNSServer()
if err != nil {
@@ -340,6 +320,7 @@ func run(args options) {
go func() {
err := startDNSServer()
if err != nil {
closeDNSServer()
log.Fatal(err)
}
}()
@@ -349,18 +330,13 @@ func run(args options) {
}
}
Context.ipDetector, err = newIPDetector()
if err != nil {
log.Fatal(err)
}
Context.web.Start()
// wait indefinitely for other go-routines to complete their job
select {}
}
// StartMods - initialize and start DNS after installation
// StartMods initializes and starts the DNS server after installation.
func StartMods() error {
err := initDNSServer()
if err != nil {
@@ -460,6 +436,10 @@ func initWorkingDir(args options) {
} else {
Context.workDir = filepath.Dir(execPath)
}
if workDir, err := filepath.EvalSymlinks(Context.workDir); err == nil {
Context.workDir = workDir
}
}
// configureLogger configures logger level and output
@@ -527,11 +507,12 @@ func configureLogger(args options) {
}
}
func cleanup() {
// cleanup stops and resets all the modules.
func cleanup(ctx context.Context) {
log.Info("Stopping AdGuard Home")
if Context.web != nil {
Context.web.Close()
Context.web.Close(ctx)
Context.web = nil
}
if Context.auth != nil {
@@ -592,8 +573,6 @@ func loadOptions() options {
// prints IP addresses which user can use to open the admin interface
// proto is either "http" or "https"
func printHTTPAddresses(proto string) {
var address string
tlsConf := tlsConfigSettings{}
if Context.tls != nil {
Context.tls.WriteDiskConfig(&tlsConf)
@@ -604,31 +583,41 @@ func printHTTPAddresses(proto string) {
port = strconv.Itoa(tlsConf.PortHTTPS)
}
var hostStr string
if proto == "https" && tlsConf.ServerName != "" {
if tlsConf.PortHTTPS == 443 {
log.Printf("Go to https://%s", tlsConf.ServerName)
} else {
log.Printf("Go to https://%s:%s", tlsConf.ServerName, port)
}
} else if config.BindHost == "0.0.0.0" {
} else if config.BindHost.IsUnspecified() {
log.Println("AdGuard Home is available on the following addresses:")
ifaces, err := util.GetValidNetInterfacesForWeb()
if err != nil {
// That's weird, but we'll ignore it
address = net.JoinHostPort(config.BindHost, port)
log.Printf("Go to %s://%s", proto, address)
hostStr = config.BindHost.String()
log.Printf("Go to %s://%s", proto, net.JoinHostPort(hostStr, port))
if config.BetaBindPort != 0 {
log.Printf("Go to %s://%s (BETA)", proto, net.JoinHostPort(hostStr, strconv.Itoa(config.BetaBindPort)))
}
return
}
for _, iface := range ifaces {
for _, addr := range iface.Addresses {
address = net.JoinHostPort(addr, strconv.Itoa(config.BindPort))
log.Printf("Go to %s://%s", proto, address)
hostStr = addr.String()
log.Printf("Go to %s://%s", proto, net.JoinHostPort(hostStr, strconv.Itoa(config.BindPort)))
if config.BetaBindPort != 0 {
log.Printf("Go to %s://%s (BETA)", proto, net.JoinHostPort(hostStr, strconv.Itoa(config.BetaBindPort)))
}
}
}
} else {
address = net.JoinHostPort(config.BindHost, port)
log.Printf("Go to %s://%s", proto, address)
hostStr = config.BindHost.String()
log.Printf("Go to %s://%s", proto, net.JoinHostPort(hostStr, port))
if config.BetaBindPort != 0 {
log.Printf("Go to %s://%s (BETA)", proto, net.JoinHostPort(hostStr, strconv.Itoa(config.BetaBindPort)))
}
}
}
@@ -641,7 +630,7 @@ func detectFirstRun() bool {
configfile = filepath.Join(Context.workDir, Context.configFilename)
}
_, err := os.Stat(configfile)
return os.IsNotExist(err)
return errors.Is(err, os.ErrNotExist)
}
// Connect to a remote server resolving hostname using our own DNS server
@@ -685,10 +674,11 @@ func customDialContext(ctx context.Context, network, addr string) (net.Conn, err
return nil, agherr.Many(fmt.Sprintf("couldn't dial to %s", addr), dialErrs...)
}
func getHTTPProxy(req *http.Request) (*url.URL, error) {
if len(config.ProxyURL) == 0 {
func getHTTPProxy(_ *http.Request) (*url.URL, error) {
if config.ProxyURL == "" {
return nil, nil
}
return url.Parse(config.ProxyURL)
}

View File

@@ -1,6 +1,6 @@
// +build !race
// TODO(e.burkov): remove this weird buildtag.
// TODO(e.burkov): Remove this weird buildtag.
package home
@@ -119,7 +119,7 @@ func TestHome(t *testing.T) {
fn := filepath.Join(dir, "AdGuardHome.yaml")
// Prepare the test config
assert.True(t, ioutil.WriteFile(fn, []byte(yamlConf), 0o644) == nil)
assert.Nil(t, ioutil.WriteFile(fn, []byte(yamlConf), 0o644))
fn, _ = filepath.Abs(fn)
config = configuration{} // the global variable is dirty because of the previous tests run
@@ -133,16 +133,16 @@ func TestHome(t *testing.T) {
h := http.Client{}
for i := 0; i != 50; i++ {
resp, err = h.Get("http://127.0.0.1:3000/")
if err == nil && resp.StatusCode != 404 {
if err == nil && resp.StatusCode != http.StatusNotFound {
break
}
time.Sleep(100 * time.Millisecond)
}
assert.Truef(t, err == nil, "%s", err)
assert.Nilf(t, err, "%s", err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
resp, err = h.Get("http://127.0.0.1:3000/control/status")
assert.Truef(t, err == nil, "%s", err)
assert.Nilf(t, err, "%s", err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
// test DNS over UDP
@@ -159,16 +159,16 @@ func TestHome(t *testing.T) {
req.RecursionDesired = true
req.Question = []dns.Question{{Name: "static.adguard.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}}
buf, err := req.Pack()
assert.True(t, err == nil, "%s", err)
assert.Nil(t, err)
requestURL := "http://127.0.0.1:3000/dns-query?dns=" + base64.RawURLEncoding.EncodeToString(buf)
resp, err = http.DefaultClient.Get(requestURL)
assert.True(t, err == nil, "%s", err)
assert.Nil(t, err)
body, err := ioutil.ReadAll(resp.Body)
assert.True(t, err == nil, "%s", err)
assert.True(t, resp.StatusCode == http.StatusOK)
assert.Nil(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
response := dns.Msg{}
err = response.Unpack(body)
assert.True(t, err == nil, "%s", err)
assert.Nil(t, err)
addrs = nil
proxyutil.AppendIPAddrs(&addrs, response.Answer)
haveIP = len(addrs) != 0
@@ -186,6 +186,6 @@ func TestHome(t *testing.T) {
time.Sleep(1 * time.Second)
}
cleanup()
cleanup(context.Background())
cleanupAlways()
}

View File

@@ -5,16 +5,15 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestIPDetector_detectSpecialNetwork(t *testing.T) {
var ipd *ipDetector
var err error
t.Run("newIPDetector", func(t *testing.T) {
var err error
ipd, err = newIPDetector()
assert.Nil(t, err)
})
ipd, err = newIPDetector()
require.Nil(t, err)
testCases := []struct {
name string

View File

@@ -22,15 +22,43 @@ func withMiddlewares(h http.Handler, middlewares ...middleware) (wrapped http.Ha
return wrapped
}
// RequestBodySizeLimit is maximum request body length in bytes.
const RequestBodySizeLimit = 64 * 1024
// defaultReqBodySzLim is the default maximum request body size.
const defaultReqBodySzLim = 64 * 1024
// largerReqBodySzLim is the maximum request body size for APIs expecting larger
// requests.
const largerReqBodySzLim = 4 * 1024 * 1024
// expectsLargerRequests shows if this request should use a larger body size
// limit. These are exceptions for poorly designed current APIs as well as APIs
// that are designed to expect large files and requests. Remove once the new,
// better APIs are up.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/2666 and
// https://github.com/AdguardTeam/AdGuardHome/issues/2675.
func expectsLargerRequests(r *http.Request) (ok bool) {
m := r.Method
if m != http.MethodPost {
return false
}
p := r.URL.Path
return p == "/control/access/set" ||
p == "/control/filtering/set_rules"
}
// limitRequestBody wraps underlying handler h, making it's request's body Read
// method limited.
func limitRequestBody(h http.Handler) (limited http.Handler) {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var err error
r.Body, err = aghio.LimitReadCloser(r.Body, RequestBodySizeLimit)
var szLim int64 = defaultReqBodySzLim
if expectsLargerRequests(r) {
szLim = largerReqBodySzLim
}
r.Body, err = aghio.LimitReadCloser(r.Body, szLim)
if err != nil {
log.Error("limitRequestBody: %s", err)
@@ -40,3 +68,18 @@ func limitRequestBody(h http.Handler) (limited http.Handler) {
h.ServeHTTP(w, r)
})
}
// wrapIndexBeta returns handler that deals with new client.
func (web *Web) wrapIndexBeta(http.Handler) (wrapped http.Handler) {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
h, pattern := Context.mux.Handler(r)
switch pattern {
case "/":
web.handlerBeta.ServeHTTP(w, r)
case "/install.html":
web.installerBeta.ServeHTTP(w, r)
default:
h.ServeHTTP(w, r)
}
})
}

View File

@@ -9,11 +9,12 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestLimitRequestBody(t *testing.T) {
errReqLimitReached := &aghio.LimitReachedError{
Limit: RequestBodySizeLimit,
Limit: defaultReqBodySzLim,
}
testCases := []struct {
@@ -28,8 +29,8 @@ func TestLimitRequestBody(t *testing.T) {
wantErr: nil,
}, {
name: "so_big",
body: string(make([]byte, RequestBodySizeLimit+1)),
want: make([]byte, RequestBodySizeLimit),
body: string(make([]byte, defaultReqBodySzLim+1)),
want: make([]byte, defaultReqBodySzLim),
wantErr: errReqLimitReached,
}, {
name: "empty",
@@ -42,7 +43,10 @@ func TestLimitRequestBody(t *testing.T) {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var b []byte
b, *err = ioutil.ReadAll(r.Body)
w.Write(b)
_, werr := w.Write(b)
if werr != nil {
panic(werr)
}
})
}
@@ -57,8 +61,8 @@ func TestLimitRequestBody(t *testing.T) {
lim.ServeHTTP(res, req)
require.Equal(t, tc.wantErr, err)
assert.Equal(t, tc.want, res.Body.Bytes())
assert.Equal(t, tc.wantErr, err)
})
}
}

View File

@@ -4,7 +4,10 @@ import (
"encoding/json"
"fmt"
"net/http"
"net/url"
"path"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/golibs/log"
uuid "github.com/satori/go.uuid"
"howett.net/plist"
@@ -14,6 +17,7 @@ type dnsSettings struct {
DNSProtocol string
ServerURL string `plist:",omitempty"`
ServerName string `plist:",omitempty"`
clientID string
}
type payloadContent struct {
@@ -23,19 +27,19 @@ type payloadContent struct {
PayloadIdentifier string
PayloadType string
PayloadUUID string
PayloadVersion int
DNSSettings dnsSettings
PayloadVersion int
}
type mobileConfig struct {
PayloadContent []payloadContent
PayloadDescription string
PayloadDisplayName string
PayloadIdentifier string
PayloadRemovalDisallowed bool
PayloadType string
PayloadUUID string
PayloadContent []payloadContent
PayloadVersion int
PayloadRemovalDisallowed bool
}
func genUUIDv4() string {
@@ -48,22 +52,35 @@ const (
)
func getMobileConfig(d dnsSettings) ([]byte, error) {
var name string
var dspName string
switch d.DNSProtocol {
case dnsProtoHTTPS:
name = fmt.Sprintf("%s DoH", d.ServerName)
d.ServerURL = fmt.Sprintf("https://%s/dns-query", d.ServerName)
dspName = fmt.Sprintf("%s DoH", d.ServerName)
u := &url.URL{
Scheme: "https",
Host: d.ServerName,
Path: "/dns-query",
}
if d.clientID != "" {
u.Path = path.Join(u.Path, d.clientID)
}
d.ServerURL = u.String()
case dnsProtoTLS:
name = fmt.Sprintf("%s DoT", d.ServerName)
dspName = fmt.Sprintf("%s DoT", d.ServerName)
if d.clientID != "" {
d.ServerName = d.clientID + "." + d.ServerName
}
default:
return nil, fmt.Errorf("bad dns protocol %q", d.DNSProtocol)
}
data := mobileConfig{
PayloadContent: []payloadContent{{
Name: name,
Name: dspName,
PayloadDescription: "Configures device to use AdGuard Home",
PayloadDisplayName: name,
PayloadDisplayName: dspName,
PayloadIdentifier: fmt.Sprintf("com.apple.dnsSettings.managed.%s", genUUIDv4()),
PayloadType: "com.apple.dnsSettings.managed",
PayloadUUID: genUUIDv4(),
@@ -71,7 +88,7 @@ func getMobileConfig(d dnsSettings) ([]byte, error) {
DNSSettings: d,
}},
PayloadDescription: "Adds AdGuard Home to Big Sur and iOS 14 or newer systems",
PayloadDisplayName: name,
PayloadDisplayName: dspName,
PayloadIdentifier: genUUIDv4(),
PayloadRemovalDisallowed: false,
PayloadType: "Configuration",
@@ -83,7 +100,10 @@ func getMobileConfig(d dnsSettings) ([]byte, error) {
}
func handleMobileConfig(w http.ResponseWriter, r *http.Request, dnsp string) {
host := r.URL.Query().Get("host")
var err error
q := r.URL.Query()
host := q.Get("host")
if host == "" {
host = Context.tls.conf.ServerName
}
@@ -92,7 +112,7 @@ func handleMobileConfig(w http.ResponseWriter, r *http.Request, dnsp string) {
w.WriteHeader(http.StatusInternalServerError)
const msg = "no host in query parameters and no server_name"
err := json.NewEncoder(w).Encode(&jsonError{
err = json.NewEncoder(w).Encode(&jsonError{
Message: msg,
})
if err != nil {
@@ -102,9 +122,25 @@ func handleMobileConfig(w http.ResponseWriter, r *http.Request, dnsp string) {
return
}
clientID := q.Get("client_id")
err = dnsforward.ValidateClientID(clientID)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
err = json.NewEncoder(w).Encode(&jsonError{
Message: err.Error(),
})
if err != nil {
log.Debug("writing 400 json response: %s", err)
}
return
}
d := dnsSettings{
DNSProtocol: dnsp,
ServerName: host,
clientID: clientID,
}
mobileconfig, err := getMobileConfig(d)
@@ -115,6 +151,7 @@ func handleMobileConfig(w http.ResponseWriter, r *http.Request, dnsp string) {
}
w.Header().Set("Content-Type", "application/xml")
_, _ = w.Write(mobileconfig)
}

View File

@@ -23,7 +23,7 @@ func TestHandleMobileConfigDOH(t *testing.T) {
_, err = plist.Unmarshal(w.Body.Bytes(), &mc)
assert.Nil(t, err)
if assert.Equal(t, 1, len(mc.PayloadContent)) {
if assert.Len(t, mc.PayloadContent, 1) {
assert.Equal(t, "example.org DoH", mc.PayloadContent[0].Name)
assert.Equal(t, "example.org DoH", mc.PayloadContent[0].PayloadDisplayName)
assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName)
@@ -51,7 +51,7 @@ func TestHandleMobileConfigDOH(t *testing.T) {
_, err = plist.Unmarshal(w.Body.Bytes(), &mc)
assert.Nil(t, err)
if assert.Equal(t, 1, len(mc.PayloadContent)) {
if assert.Len(t, mc.PayloadContent, 1) {
assert.Equal(t, "example.org DoH", mc.PayloadContent[0].Name)
assert.Equal(t, "example.org DoH", mc.PayloadContent[0].PayloadDisplayName)
assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName)
@@ -73,6 +73,27 @@ func TestHandleMobileConfigDOH(t *testing.T) {
handleMobileConfigDOH(w, r)
assert.Equal(t, http.StatusInternalServerError, w.Code)
})
t.Run("client_id", func(t *testing.T) {
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/doh.mobileconfig?host=example.org&client_id=cli42", nil)
assert.Nil(t, err)
w := httptest.NewRecorder()
handleMobileConfigDOH(w, r)
assert.Equal(t, http.StatusOK, w.Code)
var mc mobileConfig
_, err = plist.Unmarshal(w.Body.Bytes(), &mc)
assert.Nil(t, err)
if assert.Len(t, mc.PayloadContent, 1) {
assert.Equal(t, "example.org DoH", mc.PayloadContent[0].Name)
assert.Equal(t, "example.org DoH", mc.PayloadContent[0].PayloadDisplayName)
assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName)
assert.Equal(t, "https://example.org/dns-query/cli42", mc.PayloadContent[0].DNSSettings.ServerURL)
}
})
}
func TestHandleMobileConfigDOT(t *testing.T) {
@@ -89,7 +110,7 @@ func TestHandleMobileConfigDOT(t *testing.T) {
_, err = plist.Unmarshal(w.Body.Bytes(), &mc)
assert.Nil(t, err)
if assert.Equal(t, 1, len(mc.PayloadContent)) {
if assert.Len(t, mc.PayloadContent, 1) {
assert.Equal(t, "example.org DoT", mc.PayloadContent[0].Name)
assert.Equal(t, "example.org DoT", mc.PayloadContent[0].PayloadDisplayName)
assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName)
@@ -116,7 +137,7 @@ func TestHandleMobileConfigDOT(t *testing.T) {
_, err = plist.Unmarshal(w.Body.Bytes(), &mc)
assert.Nil(t, err)
if assert.Equal(t, 1, len(mc.PayloadContent)) {
if assert.Len(t, mc.PayloadContent, 1) {
assert.Equal(t, "example.org DoT", mc.PayloadContent[0].Name)
assert.Equal(t, "example.org DoT", mc.PayloadContent[0].PayloadDisplayName)
assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName)
@@ -137,4 +158,24 @@ func TestHandleMobileConfigDOT(t *testing.T) {
handleMobileConfigDOT(w, r)
assert.Equal(t, http.StatusInternalServerError, w.Code)
})
t.Run("client_id", func(t *testing.T) {
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/dot.mobileconfig?host=example.org&client_id=cli42", nil)
assert.Nil(t, err)
w := httptest.NewRecorder()
handleMobileConfigDOT(w, r)
assert.Equal(t, http.StatusOK, w.Code)
var mc mobileConfig
_, err = plist.Unmarshal(w.Body.Bytes(), &mc)
assert.Nil(t, err)
if assert.Len(t, mc.PayloadContent, 1) {
assert.Equal(t, "example.org DoT", mc.PayloadContent[0].Name)
assert.Equal(t, "example.org DoT", mc.PayloadContent[0].PayloadDisplayName)
assert.Equal(t, "cli42.example.org", mc.PayloadContent[0].DNSSettings.ServerName)
}
})
}

View File

@@ -2,8 +2,11 @@ package home
import (
"fmt"
"net"
"os"
"strconv"
"github.com/AdguardTeam/AdGuardHome/internal/version"
)
// options passed from command-line arguments
@@ -11,7 +14,7 @@ type options struct {
verbose bool // is verbose logging enabled
configFilename string // path to the config file
workDir string // path to the working directory where we will store the filters data and the querylog
bindHost string // host address to bind HTTP server on
bindHost net.IP // host address to bind HTTP server on
bindPort int // port to serve HTTP pages on
logFile string // Path to the log file. If empty, write to stdout. If "syslog", writes to syslog
pidFile string // File name to save PID to
@@ -52,10 +55,19 @@ type arg struct {
// against its zero value and return nil if the parameter value is
// zero otherwise they return a string slice of the parameter
func ipSliceOrNil(ip net.IP) []string {
if ip == nil {
return nil
}
return []string{ip.String()}
}
func stringSliceOrNil(s string) []string {
if s == "" {
return nil
}
return []string{s}
}
@@ -63,6 +75,7 @@ func intSliceOrNil(i int) []string {
if i == 0 {
return nil
}
return []string{strconv.Itoa(i)}
}
@@ -70,6 +83,7 @@ func boolSliceOrNil(b bool) []string {
if b {
return []string{}
}
return nil
}
@@ -94,8 +108,8 @@ var workDirArg = arg{
var hostArg = arg{
"Host address to bind HTTP server on",
"host", "h",
func(o options, v string) (options, error) { o.bindHost = v; return o, nil }, nil, nil,
func(o options) []string { return stringSliceOrNil(o.bindHost) },
func(o options, v string) (options, error) { o.bindHost = net.ParseIP(v); return o, nil }, nil, nil,
func(o options) []string { return ipSliceOrNil(o.bindHost) },
}
var portArg = arg{
@@ -180,7 +194,7 @@ var versionArg = arg{
"Show the version and exit",
"version", "",
nil, nil, func(o options, exec string) (effect, error) {
return func() error { fmt.Println(version()); os.Exit(0); return nil }, nil
return func() error { fmt.Println(version.Full()); os.Exit(0); return nil }, nil
},
func(o options) []string { return nil },
}

View File

@@ -2,6 +2,7 @@ package home
import (
"fmt"
"net"
"testing"
)
@@ -65,14 +66,14 @@ func TestParseWorkDir(t *testing.T) {
}
func TestParseBindHost(t *testing.T) {
if testParseOk(t).bindHost != "" {
if testParseOk(t).bindHost != nil {
t.Fatal("empty is no host")
}
if testParseOk(t, "-h", "addr").bindHost != "addr" {
if !testParseOk(t, "-h", "1.2.3.4").bindHost.Equal(net.IP{1, 2, 3, 4}) {
t.Fatal("-h is host")
}
testParseParamMissing(t, "-h")
if testParseOk(t, "--host", "addr").bindHost != "addr" {
if !testParseOk(t, "--host", "1.2.3.4").bindHost.Equal(net.IP{1, 2, 3, 4}) {
t.Fatal("--host is host")
}
testParseParamMissing(t, "--host")
@@ -204,7 +205,7 @@ func TestSerializeWorkDir(t *testing.T) {
}
func TestSerializeBindHost(t *testing.T) {
testSerialize(t, options{bindHost: "addr"}, "-h", "addr")
testSerialize(t, options{bindHost: net.IP{1, 2, 3, 4}}, "-h", "1.2.3.4")
}
func TestSerializeBindPort(t *testing.T) {

View File

@@ -2,6 +2,7 @@ package home
import (
"encoding/binary"
"net"
"strings"
"time"
@@ -15,7 +16,7 @@ import (
type RDNS struct {
dnsServer *dnsforward.Server
clients *clientsContainer
ipChannel chan string // pass data from DNS request handling thread to rDNS thread
ipChannel chan net.IP // pass data from DNS request handling thread to rDNS thread
// Contains IP addresses of clients to be resolved by rDNS
// If IP address is resolved, it stays here while it's inside Clients.
@@ -26,24 +27,24 @@ type RDNS struct {
// InitRDNS - create module context
func InitRDNS(dnsServer *dnsforward.Server, clients *clientsContainer) *RDNS {
r := RDNS{}
r.dnsServer = dnsServer
r.clients = clients
r := &RDNS{
dnsServer: dnsServer,
clients: clients,
ipAddrs: cache.New(cache.Config{
EnableLRU: true,
MaxCount: 10000,
}),
ipChannel: make(chan net.IP, 256),
}
cconf := cache.Config{}
cconf.EnableLRU = true
cconf.MaxCount = 10000
r.ipAddrs = cache.New(cconf)
r.ipChannel = make(chan string, 256)
go r.workerLoop()
return &r
return r
}
// Begin - add IP address to rDNS queue
func (r *RDNS) Begin(ip string) {
func (r *RDNS) Begin(ip net.IP) {
now := uint64(time.Now().Unix())
expire := r.ipAddrs.Get([]byte(ip))
expire := r.ipAddrs.Get(ip)
if len(expire) != 0 {
exp := binary.BigEndian.Uint64(expire)
if exp > now {
@@ -54,9 +55,10 @@ func (r *RDNS) Begin(ip string) {
expire = make([]byte, 8)
const ttl = 1 * 60 * 60
binary.BigEndian.PutUint64(expire, now+ttl)
_ = r.ipAddrs.Set([]byte(ip), expire)
_ = r.ipAddrs.Set(ip, expire)
if r.clients.Exists(ip, ClientSourceRDNS) {
id := ip.String()
if r.clients.Exists(id, ClientSourceRDNS) {
return
}
@@ -70,26 +72,26 @@ func (r *RDNS) Begin(ip string) {
}
// Use rDNS to get hostname by IP address
func (r *RDNS) resolve(ip string) string {
func (r *RDNS) resolve(ip net.IP) string {
log.Tracef("Resolving host for %s", ip)
req := dns.Msg{}
req.Id = dns.Id()
req.RecursionDesired = true
req.Question = []dns.Question{
{
Qtype: dns.TypePTR,
Qclass: dns.ClassINET,
},
}
var err error
req.Question[0].Name, err = dns.ReverseAddr(ip)
name, err := dns.ReverseAddr(ip.String())
if err != nil {
log.Debug("Error while calling dns.ReverseAddr(%s): %s", ip, err)
return ""
}
resp, err := r.dnsServer.Exchange(&req)
resp, err := r.dnsServer.Exchange(&dns.Msg{
MsgHdr: dns.MsgHdr{
Id: dns.Id(),
RecursionDesired: true,
},
Question: []dns.Question{{
Name: name,
Qtype: dns.TypePTR,
Qclass: dns.ClassINET,
}},
})
if err != nil {
log.Debug("Error while making an rDNS lookup for %s: %s", ip, err)
return ""
@@ -123,6 +125,6 @@ func (r *RDNS) workerLoop() {
continue
}
_, _ = r.clients.AddHost(ip, host, ClientSourceRDNS)
_, _ = r.clients.AddHost(ip.String(), host, ClientSourceRDNS)
}
}

View File

@@ -1,21 +1,32 @@
package home
import (
"net"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/stretchr/testify/assert"
)
func TestResolveRDNS(t *testing.T) {
dns := &dnsforward.Server{}
conf := &dnsforward.ServerConfig{}
conf.UpstreamDNS = []string{"8.8.8.8"}
err := dns.Prepare(conf)
assert.True(t, err == nil, "%s", err)
ups := &aghtest.TestUpstream{
Reverse: map[string][]string{
"1.1.1.1.in-addr.arpa.": {"one.one.one.one"},
},
}
dns := dnsforward.NewCustomServer(&proxy.Proxy{
Config: proxy.Config{
UpstreamConfig: &proxy.UpstreamConfig{
Upstreams: []upstream.Upstream{ups},
},
},
})
clients := &clientsContainer{}
rdns := InitRDNS(dns, clients)
r := rdns.resolve("1.1.1.1")
assert.True(t, r == "one.one.one.one", "%s", r)
r := rdns.resolve(net.IP{1, 1, 1, 1})
assert.Equal(t, "one.one.one.one", r, r)
}

View File

@@ -15,6 +15,9 @@ import (
"github.com/kardianos/service"
)
// TODO(a.garipov): Move shell templates into actual files. Either during the
// v0.106.0 cycle using packr or during the following cycle using go:embed.
const (
launchdStdoutPath = "/var/log/AdGuardHome.stdout.log"
launchdStderrPath = "/var/log/AdGuardHome.stderr.log"
@@ -504,6 +507,10 @@ status() {
}
`
// TODO(a.garipov): Don't use .WorkingDirectory here. There are currently no
// guarantees that it will actually be the required directory.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/2614.
const freeBSDScript = `#!/bin/sh
# PROVIDE: {{.Name}}
# REQUIRE: networking
@@ -514,6 +521,6 @@ name="{{.Name}}"
{{.Name}}_user="root"
pidfile="/var/run/${name}.pid"
command="/usr/sbin/daemon"
command_args="-P ${pidfile} -r -f {{.WorkingDirectory}}/{{.Name}}"
command_args="-P ${pidfile} -f -r {{.WorkingDirectory}}/{{.Name}}"
run_rc_command "$1"
`

View File

@@ -1,6 +1,7 @@
package home
import (
"context"
"crypto"
"crypto/ecdsa"
"crypto/rsa"
@@ -92,7 +93,7 @@ func (t *TLSMod) setCertFileTime() {
t.certLastMod = fi.ModTime().UTC()
}
// Start - start the module
// Start updates the configuration of TLSMod and starts it.
func (t *TLSMod) Start() {
if !tlsWebHandlersRegistered {
tlsWebHandlersRegistered = true
@@ -102,10 +103,14 @@ func (t *TLSMod) Start() {
t.confLock.Lock()
tlsConf := t.conf
t.confLock.Unlock()
Context.web.TLSConfigChanged(tlsConf)
// The background context is used because the TLSConfigChanged wraps
// context with timeout on its own and shuts down the server, which
// handles current request.
Context.web.TLSConfigChanged(context.Background(), tlsConf)
}
// Reload - reload certificate file
// Reload updates the configuration of TLSMod and restarts it.
func (t *TLSMod) Reload() {
t.confLock.Lock()
tlsConf := t.conf
@@ -139,7 +144,10 @@ func (t *TLSMod) Reload() {
t.confLock.Lock()
tlsConf = t.conf
t.confLock.Unlock()
Context.web.TLSConfigChanged(tlsConf)
// The background context is used because the TLSConfigChanged wraps
// context with timeout on its own and shuts down the server, which
// handles current request.
Context.web.TLSConfigChanged(context.Background(), tlsConf)
}
// Set certificate and private key data
@@ -296,11 +304,13 @@ func (t *TLSMod) handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
f.Flush()
}
// this needs to be done in a goroutine because Shutdown() is a blocking call, and it will block
// until all requests are finished, and _we_ are inside a request right now, so it will block indefinitely
// The background context is used because the TLSConfigChanged wraps
// context with timeout on its own and shuts down the server, which
// handles current request. It is also should be done in a separate
// goroutine due to the same reason.
if restartHTTPS {
go func() {
Context.web.TLSConfigChanged(data)
Context.web.TLSConfigChanged(context.Background(), data)
}()
}
}
@@ -534,7 +544,7 @@ func marshalTLS(w http.ResponseWriter, data tlsConfig) {
// registerWebHandlers registers HTTP handlers for TLS configuration
func (t *TLSMod) registerWebHandlers() {
httpRegister("GET", "/control/tls/status", t.handleTLSStatus)
httpRegister("POST", "/control/tls/configure", t.handleTLSConfigure)
httpRegister("POST", "/control/tls/validate", t.handleTLSValidate)
httpRegister(http.MethodGet, "/control/tls/status", t.handleTLSStatus)
httpRegister(http.MethodPost, "/control/tls/configure", t.handleTLSConfigure)
httpRegister(http.MethodPost, "/control/tls/validate", t.handleTLSValidate)
}

View File

@@ -3,183 +3,108 @@ package home
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestUpgrade1to2(t *testing.T) {
// let's create test config for 1 schema version
diskConfig := createTestDiskConfig(1)
// any is a convenient alias for interface{}.
type any = interface{}
// update config
err := upgradeSchema1to2(&diskConfig)
if err != nil {
t.Fatalf("Can't upgrade schema version from 1 to 2")
}
// object is a convenient alias for map[string]interface{}.
type object = map[string]any
// ensure that schema version was bumped
compareSchemaVersion(t, diskConfig["schema_version"], 2)
func TestUpgradeSchema1to2(t *testing.T) {
diskConf := testDiskConf(1)
// old coredns entry should be removed
_, ok := diskConfig["coredns"]
if ok {
t.Fatalf("Core DNS config was not removed after upgrade schema version from 1 to 2")
}
err := upgradeSchema1to2(&diskConf)
require.Nil(t, err)
// pull out new dns config
dnsMap, ok := diskConfig["dns"]
if !ok {
t.Fatalf("No DNS config after upgrade schema version from 1 to 2")
}
require.Equal(t, diskConf["schema_version"], 2)
// cast dns configurations to maps and compare them
oldDNSConfig := castInterfaceToMap(t, createTestDNSConfig(1))
newDNSConfig := castInterfaceToMap(t, dnsMap)
compareConfigs(t, &oldDNSConfig, &newDNSConfig)
_, ok := diskConf["coredns"]
require.False(t, ok)
dnsMap, ok := diskConf["dns"]
require.True(t, ok)
oldDNSConf := convertToObject(t, testDNSConf(1))
newDNSConf := convertToObject(t, dnsMap)
assert.Equal(t, oldDNSConf, newDNSConf)
// exclude dns config and schema version from disk config comparison
oldExcludedEntries := []string{"coredns", "schema_version"}
newExcludedEntries := []string{"dns", "schema_version"}
oldDiskConfig := createTestDiskConfig(1)
compareConfigsWithoutEntries(t, &oldDiskConfig, &diskConfig, oldExcludedEntries, newExcludedEntries)
oldDiskConf := testDiskConf(1)
assertEqualExcept(t, oldDiskConf, diskConf, oldExcludedEntries, newExcludedEntries)
}
func TestUpgrade2to3(t *testing.T) {
// let's create test config
diskConfig := createTestDiskConfig(2)
func TestUpgradeSchema2to3(t *testing.T) {
diskConf := testDiskConf(2)
// upgrade schema from 2 to 3
err := upgradeSchema2to3(&diskConfig)
if err != nil {
t.Fatalf("Can't update schema version from 2 to 3: %s", err)
}
err := upgradeSchema2to3(&diskConf)
require.Nil(t, err)
// check new schema version
compareSchemaVersion(t, diskConfig["schema_version"], 3)
require.Equal(t, diskConf["schema_version"], 3)
// pull out new dns configuration
dnsMap, ok := diskConfig["dns"]
if !ok {
t.Fatalf("No dns config in new configuration")
}
dnsMap, ok := diskConf["dns"]
require.True(t, ok)
// cast dns configuration to map
newDNSConfig := castInterfaceToMap(t, dnsMap)
// check if bootstrap DNS becomes an array
bootstrapDNS := newDNSConfig["bootstrap_dns"]
newDNSConf := convertToObject(t, dnsMap)
bootstrapDNS := newDNSConf["bootstrap_dns"]
switch v := bootstrapDNS.(type) {
case []string:
if len(v) != 1 {
t.Fatalf("Wrong count of bootsrap DNS servers: %d", len(v))
}
if v[0] != "8.8.8.8:53" {
t.Fatalf("Bootsrap DNS server is not 8.8.8.8:53 : %s", v[0])
}
require.Len(t, v, 1)
require.Equal(t, "8.8.8.8:53", v[0])
default:
t.Fatalf("Wrong type for bootsrap DNS: %T", v)
t.Fatalf("wrong type for bootsrap dns: %T", v)
}
// exclude bootstrap DNS from DNS configs comparison
excludedEntries := []string{"bootstrap_dns"}
oldDNSConfig := castInterfaceToMap(t, createTestDNSConfig(2))
compareConfigsWithoutEntries(t, &oldDNSConfig, &newDNSConfig, excludedEntries, excludedEntries)
oldDNSConf := convertToObject(t, testDNSConf(2))
assertEqualExcept(t, oldDNSConf, newDNSConf, excludedEntries, excludedEntries)
// excluded dns config and schema version from disk config comparison
excludedEntries = []string{"dns", "schema_version"}
oldDiskConfig := createTestDiskConfig(2)
compareConfigsWithoutEntries(t, &oldDiskConfig, &diskConfig, excludedEntries, excludedEntries)
oldDiskConf := testDiskConf(2)
assertEqualExcept(t, oldDiskConf, diskConf, excludedEntries, excludedEntries)
}
func castInterfaceToMap(t *testing.T, oldConfig interface{}) (newConfig map[string]interface{}) {
newConfig = make(map[string]interface{})
switch v := oldConfig.(type) {
case map[interface{}]interface{}:
func convertToObject(t *testing.T, oldConf any) (newConf object) {
t.Helper()
switch v := oldConf.(type) {
case map[any]any:
newConf = make(object, len(v))
for key, value := range v {
newConfig[fmt.Sprint(key)] = value
newConf[fmt.Sprint(key)] = value
}
case map[string]interface{}:
case object:
newConf = make(object, len(v))
for key, value := range v {
newConfig[key] = value
newConf[key] = value
}
default:
t.Fatalf("DNS configuration is not a map")
t.Fatalf("dns configuration is not a map, got %T", oldConf)
}
return
return newConf
}
// compareConfigsWithoutEntry removes entries from configs and returns result of compareConfigs
func compareConfigsWithoutEntries(t *testing.T, oldConfig, newConfig *map[string]interface{}, oldKey, newKey []string) {
for _, k := range oldKey {
delete(*oldConfig, k)
// assertEqualExcept removes entries from configs and compares them.
func assertEqualExcept(t *testing.T, oldConf, newConf object, oldKeys, newKeys []string) {
t.Helper()
for _, k := range oldKeys {
delete(oldConf, k)
}
for _, k := range newKey {
delete(*newConfig, k)
for _, k := range newKeys {
delete(newConf, k)
}
compareConfigs(t, oldConfig, newConfig)
assert.Equal(t, oldConf, newConf)
}
// compares configs before and after schema upgrade
func compareConfigs(t *testing.T, oldConfig, newConfig *map[string]interface{}) {
if len(*oldConfig) != len(*newConfig) {
t.Fatalf("wrong config entries count! Before upgrade: %d; After upgrade: %d", len(*oldConfig), len(*oldConfig))
}
// Check old and new entries
for k, v := range *newConfig {
switch value := v.(type) {
case string:
if value != (*oldConfig)[k] {
t.Fatalf("wrong value for string %s. Before update: %s; After update: %s", k, (*oldConfig)[k], value)
}
case int:
if value != (*oldConfig)[k] {
t.Fatalf("wrong value for int %s. Before update: %d; After update: %d", k, (*oldConfig)[k], value)
}
case []string:
for i, line := range value {
if len((*oldConfig)[k].([]string)) != len(value) {
t.Fatalf("wrong array length for %s. Before update: %d; After update: %d", k, len((*oldConfig)[k].([]string)), len(value))
}
if (*oldConfig)[k].([]string)[i] != line {
t.Fatalf("wrong data for string array %s. Before update: %s; After update: %s", k, (*oldConfig)[k].([]string)[i], line)
}
}
case bool:
if v != (*oldConfig)[k].(bool) {
t.Fatalf("wrong boolean value for %s", k)
}
case []filter:
if len((*oldConfig)[k].([]filter)) != len(value) {
t.Fatalf("wrong filters count. Before update: %d; After update: %d", len((*oldConfig)[k].([]filter)), len(value))
}
for i, newFilter := range value {
oldFilter := (*oldConfig)[k].([]filter)[i]
if oldFilter.Enabled != newFilter.Enabled || oldFilter.Name != newFilter.Name || oldFilter.RulesCount != newFilter.RulesCount {
t.Fatalf("old filter %s not equals new filter %s", oldFilter.Name, newFilter.Name)
}
}
default:
t.Fatalf("uknown data type for %s: %T", k, value)
}
}
}
// compareSchemaVersion check if newSchemaVersion equals schemaVersion
func compareSchemaVersion(t *testing.T, newSchemaVersion interface{}, schemaVersion int) {
switch v := newSchemaVersion.(type) {
case int:
if v != schemaVersion {
t.Fatalf("Wrong schema version in new config file")
}
default:
t.Fatalf("Schema version is not an integer after update")
}
}
func createTestDiskConfig(schemaVersion int) (diskConfig map[string]interface{}) {
diskConfig = make(map[string]interface{})
diskConfig["language"] = "en"
diskConfig["filters"] = []filter{
func testDiskConf(schemaVersion int) (diskConf object) {
filters := []filter{
{
URL: "https://filters.adtidy.org/android/filters/111_optimized.txt",
Name: "Latvian filter",
@@ -191,40 +116,51 @@ func createTestDiskConfig(schemaVersion int) (diskConfig map[string]interface{})
RulesCount: 200,
},
}
diskConfig["user_rules"] = []string{}
diskConfig["schema_version"] = schemaVersion
diskConfig["bind_host"] = "0.0.0.0"
diskConfig["bind_port"] = 80
diskConfig["auth_name"] = "name"
diskConfig["auth_pass"] = "pass"
dnsConfig := createTestDNSConfig(schemaVersion)
if schemaVersion > 1 {
diskConfig["dns"] = dnsConfig
} else {
diskConfig["coredns"] = dnsConfig
diskConf = object{
"language": "en",
"filters": filters,
"user_rules": []string{},
"schema_version": schemaVersion,
"bind_host": "0.0.0.0",
"bind_port": 80,
"auth_name": "name",
"auth_pass": "pass",
}
return diskConfig
dnsConf := testDNSConf(schemaVersion)
if schemaVersion > 1 {
diskConf["dns"] = dnsConf
} else {
diskConf["coredns"] = dnsConf
}
return diskConf
}
func createTestDNSConfig(schemaVersion int) map[interface{}]interface{} {
dnsConfig := make(map[interface{}]interface{})
dnsConfig["port"] = 53
dnsConfig["blocked_response_ttl"] = 10
dnsConfig["querylog_enabled"] = true
dnsConfig["ratelimit"] = 20
dnsConfig["bootstrap_dns"] = "8.8.8.8:53"
if schemaVersion > 2 {
dnsConfig["bootstrap_dns"] = []string{"8.8.8.8:53"}
// testDNSConf creates a DNS config for test the way gopkg.in/yaml.v2 would
// unmarshal it. In YAML, keys aren't guaranteed to always only be strings.
func testDNSConf(schemaVersion int) (dnsConf map[any]any) {
dnsConf = map[any]any{
"port": 53,
"blocked_response_ttl": 10,
"querylog_enabled": true,
"ratelimit": 20,
"bootstrap_dns": "8.8.8.8:53",
"parental_sensitivity": 13,
"ratelimit_whitelist": []string{},
"upstream_dns": []string{"tls://1.1.1.1", "tls://1.0.0.1", "8.8.8.8"},
"filtering_enabled": true,
"refuse_any": true,
"parental_enabled": true,
"bind_host": "0.0.0.0",
"protection_enabled": true,
"safesearch_enabled": true,
"safebrowsing_enabled": true,
}
dnsConfig["parental_sensitivity"] = 13
dnsConfig["ratelimit_whitelist"] = []string{}
dnsConfig["upstream_dns"] = []string{"tls://1.1.1.1", "tls://1.0.0.1", "8.8.8.8"}
dnsConfig["filtering_enabled"] = true
dnsConfig["refuse_any"] = true
dnsConfig["parental_enabled"] = true
dnsConfig["bind_host"] = "0.0.0.0"
dnsConfig["protection_enabled"] = true
dnsConfig["safesearch_enabled"] = true
dnsConfig["safebrowsing_enabled"] = true
return dnsConfig
if schemaVersion > 2 {
dnsConf["bootstrap_dns"] = []string{"8.8.8.8:53"}
}
return dnsConf
}

View File

@@ -16,24 +16,22 @@ import (
)
const (
// ReadTimeout is the maximum duration for reading the entire request,
// readTimeout is the maximum duration for reading the entire request,
// including the body.
ReadTimeout = 10 * time.Second
// ReadHeaderTimeout is the amount of time allowed to read request
// headers.
ReadHeaderTimeout = 10 * time.Second
// WriteTimeout is the maximum duration before timing out writes of the
readTimeout = 60 * time.Second
// readHdrTimeout is the amount of time allowed to read request headers.
readHdrTimeout = 60 * time.Second
// writeTimeout is the maximum duration before timing out writes of the
// response.
WriteTimeout = 10 * time.Second
writeTimeout = 60 * time.Second
)
type webConfig struct {
firstRun bool
BindHost string
BindPort int
PortHTTPS int
firstRun bool
BindHost net.IP
BindPort int
BetaBindPort int
PortHTTPS int
// ReadTimeout is an option to pass to http.Server for setting an
// appropriate field.
@@ -62,9 +60,16 @@ type HTTPSServer struct {
type Web struct {
conf *webConfig
forceHTTPS bool
portHTTPS int
httpServer *http.Server // HTTP module
httpsServer HTTPSServer // HTTPS module
// handlerBeta is the handler for new client.
handlerBeta http.Handler
// installerBeta is the pre-install handler for new client.
installerBeta http.Handler
// httpServerBeta is a server for new client.
httpServerBeta *http.Server
}
// CreateWeb - create module
@@ -76,15 +81,20 @@ func CreateWeb(conf *webConfig) *Web {
// Initialize and run the admin Web interface
box := packr.NewBox("../../build/static")
boxBeta := packr.NewBox("../../build2/static")
// if not configured, redirect / to /install.html, otherwise redirect /install.html to /
Context.mux.Handle("/", postInstallHandler(optionalAuthHandler(gziphandler.GzipHandler(http.FileServer(box)))))
Context.mux.Handle("/", withMiddlewares(http.FileServer(box), gziphandler.GzipHandler, optionalAuthHandler, postInstallHandler))
w.handlerBeta = withMiddlewares(http.FileServer(boxBeta), gziphandler.GzipHandler, optionalAuthHandler, postInstallHandler)
// add handlers for /install paths, we only need them when we're not configured yet
if conf.firstRun {
log.Info("This is the first launch of AdGuard Home, redirecting everything to /install.html ")
Context.mux.Handle("/install.html", preInstallHandler(http.FileServer(box)))
w.installerBeta = preInstallHandler(http.FileServer(boxBeta))
w.registerInstallHandlers()
// This must be removed in API v1.
w.registerBetaInstallHandlers()
} else {
registerControlHandlers()
}
@@ -109,12 +119,12 @@ func WebCheckPortAvailable(port int) bool {
return true
}
// TLSConfigChanged - called when TLS configuration has changed
func (web *Web) TLSConfigChanged(tlsConf tlsConfigSettings) {
// TLSConfigChanged updates the TLS configuration and restarts the HTTPS server
// if necessary.
func (web *Web) TLSConfigChanged(ctx context.Context, tlsConf tlsConfigSettings) {
log.Debug("Web: applying new TLS configuration")
web.conf.PortHTTPS = tlsConf.PortHTTPS
web.forceHTTPS = (tlsConf.ForceHTTPS && tlsConf.Enabled && tlsConf.PortHTTPS != 0)
web.portHTTPS = tlsConf.PortHTTPS
enabled := tlsConf.Enabled &&
tlsConf.PortHTTPS != 0 &&
@@ -131,7 +141,12 @@ func (web *Web) TLSConfigChanged(tlsConf tlsConfigSettings) {
web.httpsServer.cond.L.Lock()
if web.httpsServer.server != nil {
_ = web.httpsServer.server.Shutdown(context.TODO())
ctx, cancel := context.WithTimeout(ctx, shutdownTimeout)
err = web.httpsServer.server.Shutdown(ctx)
cancel()
if err != nil {
log.Debug("error while shutting down HTTP server: %s", err)
}
}
web.httpsServer.enabled = enabled
web.httpsServer.cert = cert
@@ -147,19 +162,40 @@ func (web *Web) Start() {
// this loop is used as an ability to change listening host and/or port
for !web.httpsServer.shutdown {
printHTTPAddresses("http")
errs := make(chan error, 2)
hostStr := web.conf.BindHost.String()
// we need to have new instance, because after Shutdown() the Server is not usable
address := net.JoinHostPort(web.conf.BindHost, strconv.Itoa(web.conf.BindPort))
web.httpServer = &http.Server{
ErrorLog: log.StdLog("web: http", log.DEBUG),
Addr: address,
Addr: net.JoinHostPort(hostStr, strconv.Itoa(web.conf.BindPort)),
Handler: withMiddlewares(Context.mux, limitRequestBody),
ReadTimeout: web.conf.ReadTimeout,
ReadHeaderTimeout: web.conf.ReadHeaderTimeout,
WriteTimeout: web.conf.WriteTimeout,
}
go func() {
errs <- web.httpServer.ListenAndServe()
}()
err := web.httpServer.ListenAndServe()
if web.conf.BetaBindPort != 0 {
web.httpServerBeta = &http.Server{
ErrorLog: log.StdLog("web: http", log.DEBUG),
Addr: net.JoinHostPort(hostStr, strconv.Itoa(web.conf.BetaBindPort)),
Handler: withMiddlewares(Context.mux, limitRequestBody, web.wrapIndexBeta),
ReadTimeout: web.conf.ReadTimeout,
ReadHeaderTimeout: web.conf.ReadHeaderTimeout,
WriteTimeout: web.conf.WriteTimeout,
}
go func() {
betaErr := web.httpServerBeta.ListenAndServe()
if betaErr != nil {
log.Error("starting beta http server: %s", betaErr)
}
}()
}
err := <-errs
if err != http.ErrServerClosed {
cleanupAlways()
log.Fatal(err)
@@ -168,19 +204,28 @@ func (web *Web) Start() {
}
}
// Close - stop HTTP server, possibly waiting for all active connections to be closed
func (web *Web) Close() {
// Close gracefully shuts down the HTTP servers.
func (web *Web) Close(ctx context.Context) {
log.Info("Stopping HTTP server...")
web.httpsServer.cond.L.Lock()
web.httpsServer.shutdown = true
web.httpsServer.cond.L.Unlock()
if web.httpsServer.server != nil {
_ = web.httpsServer.server.Shutdown(context.TODO())
}
if web.httpServer != nil {
_ = web.httpServer.Shutdown(context.TODO())
shut := func(srv *http.Server) {
if srv == nil {
return
}
ctx, cancel := context.WithTimeout(ctx, shutdownTimeout)
defer cancel()
if err := srv.Shutdown(ctx); err != nil {
log.Debug("error while shutting down HTTP server: %s", err)
}
}
shut(web.httpsServer.server)
shut(web.httpServer)
shut(web.httpServerBeta)
log.Info("Stopped HTTP server")
}
@@ -204,7 +249,7 @@ func (web *Web) tlsServerLoop() {
web.httpsServer.cond.L.Unlock()
// prepare HTTPS server
address := net.JoinHostPort(web.conf.BindHost, strconv.Itoa(web.conf.PortHTTPS))
address := net.JoinHostPort(web.conf.BindHost.String(), strconv.Itoa(web.conf.PortHTTPS))
web.httpsServer.server = &http.Server{
ErrorLog: log.StdLog("web: https", log.DEBUG),
Addr: address,
@@ -214,7 +259,7 @@ func (web *Web) tlsServerLoop() {
RootCAs: Context.tlsRoots,
CipherSuites: Context.tlsCiphers,
},
Handler: Context.mux,
Handler: withMiddlewares(Context.mux, limitRequestBody),
ReadTimeout: web.conf.ReadTimeout,
ReadHeaderTimeout: web.conf.ReadHeaderTimeout,
WriteTimeout: web.conf.WriteTimeout,

View File

@@ -11,7 +11,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
"github.com/AdguardTeam/AdGuardHome/internal/util"
"github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/log"
)
@@ -25,29 +24,32 @@ const (
// Whois - module context
type Whois struct {
clients *clientsContainer
ipChan chan string
timeoutMsec uint
clients *clientsContainer
ipChan chan net.IP
// Contains IP addresses of clients
// An active IP address is resolved once again after it expires.
// If IP address couldn't be resolved, it stays here for some time to prevent further attempts to resolve the same IP.
ipAddrs cache.Cache
// TODO(a.garipov): Rewrite to use time.Duration. Like, seriously, why?
timeoutMsec uint
}
// Create module context
// initWhois creates the Whois module context.
func initWhois(clients *clientsContainer) *Whois {
w := Whois{}
w.timeoutMsec = 5000
w.clients = clients
w := Whois{
timeoutMsec: 5000,
clients: clients,
ipAddrs: cache.New(cache.Config{
EnableLRU: true,
MaxCount: 10000,
}),
ipChan: make(chan net.IP, 255),
}
cconf := cache.Config{}
cconf.EnableLRU = true
cconf.MaxCount = 10000
w.ipAddrs = cache.New(cconf)
w.ipChan = make(chan string, 255)
go w.workerLoop()
return &w
}
@@ -81,23 +83,16 @@ func whoisParse(data string) map[string]string {
switch k {
case "org-name":
m["orgname"] = trimValue(v)
case "orgname":
fallthrough
case "city":
fallthrough
case "country":
case "city", "country", "orgname":
m[k] = trimValue(v)
case "descr":
if len(descr) == 0 {
descr = v
}
case "netname":
netname = v
case "whois": // "whois: whois.arin.net"
m["whois"] = v
case "referralserver": // "ReferralServer: whois://whois.ripe.net"
if strings.HasPrefix(v, "whois://") {
m["whois"] = v[len("whois://"):]
@@ -105,12 +100,16 @@ func whoisParse(data string) map[string]string {
}
}
// descr or netname -> orgname
_, ok := m["orgname"]
if !ok && len(descr) != 0 {
m["orgname"] = trimValue(descr)
} else if !ok && len(netname) != 0 {
m["orgname"] = trimValue(netname)
if !ok {
// Set orgname from either descr or netname for the frontent.
//
// TODO(a.garipov): Perhaps don't do that in the V1 HTTP API?
if descr != "" {
m["orgname"] = trimValue(descr)
} else if netname != "" {
m["orgname"] = trimValue(netname)
}
}
return m
@@ -120,12 +119,12 @@ func whoisParse(data string) map[string]string {
const MaxConnReadSize = 64 * 1024
// Send request to a server and receive the response
func (w *Whois) query(target, serverAddr string) (string, error) {
func (w *Whois) query(ctx context.Context, target, serverAddr string) (string, error) {
addr, _, _ := net.SplitHostPort(serverAddr)
if addr == "whois.arin.net" {
target = "n + " + target
}
conn, err := customDialContext(context.TODO(), "tcp", serverAddr)
conn, err := customDialContext(ctx, "tcp", serverAddr)
if err != nil {
return "", err
}
@@ -153,11 +152,11 @@ func (w *Whois) query(target, serverAddr string) (string, error) {
}
// Query WHOIS servers (handle redirects)
func (w *Whois) queryAll(target string) (string, error) {
func (w *Whois) queryAll(ctx context.Context, target string) (string, error) {
server := net.JoinHostPort(defaultServer, defaultPort)
const maxRedirects = 5
for i := 0; i != maxRedirects; i++ {
resp, err := w.query(target, server)
resp, err := w.query(ctx, target, server)
if err != nil {
return "", err
}
@@ -183,9 +182,9 @@ func (w *Whois) queryAll(target string) (string, error) {
}
// Request WHOIS information
func (w *Whois) process(ip string) [][]string {
func (w *Whois) process(ctx context.Context, ip net.IP) [][]string {
data := [][]string{}
resp, err := w.queryAll(ip)
resp, err := w.queryAll(ctx, ip.String())
if err != nil {
log.Debug("Whois: error: %s IP:%s", err, ip)
return data
@@ -209,7 +208,7 @@ func (w *Whois) process(ip string) [][]string {
}
// Begin - begin requesting WHOIS info
func (w *Whois) Begin(ip string) {
func (w *Whois) Begin(ip net.IP) {
now := uint64(time.Now().Unix())
expire := w.ipAddrs.Get([]byte(ip))
if len(expire) != 0 {
@@ -232,16 +231,18 @@ func (w *Whois) Begin(ip string) {
}
}
// Get IP address from channel; get WHOIS info; associate info with a client
// workerLoop processes the IP addresses it got from the channel and associates
// the retrieving WHOIS info with a client.
func (w *Whois) workerLoop() {
for {
ip := <-w.ipChan
info := w.process(ip)
info := w.process(context.Background(), ip)
if len(info) == 0 {
continue
}
w.clients.SetWhoisInfo(ip, info)
id := ip.String()
w.clients.SetWhoisInfo(id, info)
}
}

View File

@@ -1,6 +1,7 @@
package home
import (
"context"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
@@ -12,14 +13,19 @@ func prepareTestDNSServer() error {
Context.dnsServer = dnsforward.NewServer(dnsforward.DNSCreateParams{})
conf := &dnsforward.ServerConfig{}
conf.UpstreamDNS = []string{"8.8.8.8"}
return Context.dnsServer.Prepare(conf)
}
// TODO(e.burkov): It's kind of complicated to get rid of network access in this
// test. The thing is that *Whois creates new *net.Dialer each time it requests
// the server, so it becomes hard to simulate handling of request from test even
// with substituted upstream. However, it must be done.
func TestWhois(t *testing.T) {
assert.Nil(t, prepareTestDNSServer())
w := Whois{timeoutMsec: 5000}
resp, err := w.queryAll("8.8.8.8")
resp, err := w.queryAll(context.Background(), "8.8.8.8")
assert.Nil(t, err)
m := whoisParse(resp)
assert.Equal(t, "Google LLC", m["orgname"])

View File

@@ -17,14 +17,26 @@ import (
type logEntryHandler (func(t json.Token, ent *logEntry) error)
var logEntryHandlers = map[string]logEntryHandler{
"CID": func(t json.Token, ent *logEntry) error {
v, ok := t.(string)
if !ok {
return nil
}
ent.ClientID = v
return nil
},
"IP": func(t json.Token, ent *logEntry) error {
v, ok := t.(string)
if !ok {
return nil
}
if len(ent.IP) == 0 {
ent.IP = v
if ent.IP == nil {
ent.IP = net.ParseIP(v)
}
return nil
},
"T": func(t json.Token, ent *logEntry) error {

View File

@@ -8,8 +8,8 @@ import (
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/dnsfilter"
"github.com/AdguardTeam/AdGuardHome/internal/testutil"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns"
@@ -19,12 +19,13 @@ import (
func TestDecodeLogEntry(t *testing.T) {
logOutput := &bytes.Buffer{}
testutil.ReplaceLogWriter(t, logOutput)
testutil.ReplaceLogLevel(t, log.DEBUG)
aghtest.ReplaceLogWriter(t, logOutput)
aghtest.ReplaceLogLevel(t, log.DEBUG)
t.Run("success", func(t *testing.T) {
const ansStr = `Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==`
const data = `{"IP":"127.0.0.1",` +
`"CID":"cli42",` +
`"T":"2020-11-25T18:55:56.519796+03:00",` +
`"QH":"an.yandex.ru",` +
`"QT":"A",` +
@@ -47,11 +48,12 @@ func TestDecodeLogEntry(t *testing.T) {
assert.Nil(t, err)
want := &logEntry{
IP: "127.0.0.1",
IP: net.IPv4(127, 0, 0, 1),
Time: time.Date(2020, 11, 25, 15, 55, 56, 519796000, time.UTC),
QHost: "an.yandex.ru",
QType: "A",
QClass: "IN",
ClientID: "cli42",
ClientProto: "",
Answer: ans,
Result: dnsfilter.Result{
@@ -84,7 +86,7 @@ func TestDecodeLogEntry(t *testing.T) {
decodeLogEntry(got, data)
s := logOutput.String()
assert.Equal(t, "", s)
assert.Empty(t, s)
// Correct for time zones.
got.Time = got.Time.UTC()
@@ -172,7 +174,7 @@ func TestDecodeLogEntry(t *testing.T) {
s := logOutput.String()
if tc.want == "" {
assert.Equal(t, "", s)
assert.Empty(t, s)
} else {
assert.True(t, strings.HasSuffix(s, tc.want),
"got %q", s)

View File

@@ -22,10 +22,10 @@ type qlogConfig struct {
// Register web handlers
func (l *queryLog) initWeb() {
l.conf.HTTPRegister("GET", "/control/querylog", l.handleQueryLog)
l.conf.HTTPRegister("GET", "/control/querylog_info", l.handleQueryLogInfo)
l.conf.HTTPRegister("POST", "/control/querylog_clear", l.handleQueryLogClear)
l.conf.HTTPRegister("POST", "/control/querylog_config", l.handleQueryLogConfig)
l.conf.HTTPRegister(http.MethodGet, "/control/querylog", l.handleQueryLog)
l.conf.HTTPRegister(http.MethodGet, "/control/querylog_info", l.handleQueryLogInfo)
l.conf.HTTPRegister(http.MethodPost, "/control/querylog_clear", l.handleQueryLogClear)
l.conf.HTTPRegister(http.MethodPost, "/control/querylog_config", l.handleQueryLogConfig)
}
func httpError(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) {

View File

@@ -4,6 +4,7 @@ import (
"fmt"
"net"
"strconv"
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/dnsfilter"
@@ -14,22 +15,19 @@ import (
// TODO(a.garipov): Use a proper structured approach here.
// Get Client IP address
func (l *queryLog) getClientIP(clientIP string) string {
if l.conf.AnonymizeClientIP {
ip := net.ParseIP(clientIP)
if ip != nil {
ip4 := ip.To4()
const AnonymizeClientIP4Mask = 16
const AnonymizeClientIP6Mask = 112
if ip4 != nil {
clientIP = ip4.Mask(net.CIDRMask(AnonymizeClientIP4Mask, 32)).String()
} else {
clientIP = ip.Mask(net.CIDRMask(AnonymizeClientIP6Mask, 128)).String()
}
func (l *queryLog) getClientIP(ip net.IP) (clientIP net.IP) {
if l.conf.AnonymizeClientIP && ip != nil {
const AnonymizeClientIPv4Mask = 16
const AnonymizeClientIPv6Mask = 112
if ip.To4() != nil {
return ip.Mask(net.CIDRMask(AnonymizeClientIPv4Mask, 32))
}
return ip.Mask(net.CIDRMask(AnonymizeClientIPv6Mask, 128))
}
return clientIP
return ip
}
// jobject is a JSON object alias.
@@ -82,6 +80,10 @@ func (l *queryLog) logEntryToJSONEntry(entry *logEntry) (jsonEntry jobject) {
},
}
if entry.ClientID != "" {
jsonEntry["client_id"] = entry.ClientID
}
if msg != nil {
jsonEntry["status"] = dns.RcodeToString[msg.Rcode]
@@ -138,48 +140,60 @@ func resultRulesToJSONRules(rules []*dnsfilter.ResultRule) (jsonRules []jobject)
return jsonRules
}
func answerToMap(a *dns.Msg) (answers []jobject) {
type dnsAnswer struct {
Type string `json:"type"`
Value string `json:"value"`
TTL uint32 `json:"ttl"`
}
func answerToMap(a *dns.Msg) (answers []*dnsAnswer) {
if a == nil || len(a.Answer) == 0 {
return nil
}
answers = []jobject{}
answers = make([]*dnsAnswer, 0, len(a.Answer))
for _, k := range a.Answer {
header := k.Header()
answer := jobject{
"type": dns.TypeToString[header.Rrtype],
"ttl": header.Ttl,
answer := &dnsAnswer{
Type: dns.TypeToString[header.Rrtype],
TTL: header.Ttl,
}
// try most common record types
// Some special treatment for some well-known types.
//
// TODO(a.garipov): Consider just calling String() for everyone
// instead.
switch v := k.(type) {
case nil:
// Probably unlikely, but go on.
case *dns.A:
answer["value"] = v.A.String()
answer.Value = v.A.String()
case *dns.AAAA:
answer["value"] = v.AAAA.String()
answer.Value = v.AAAA.String()
case *dns.MX:
answer["value"] = fmt.Sprintf("%v %v", v.Preference, v.Mx)
answer.Value = fmt.Sprintf("%v %v", v.Preference, v.Mx)
case *dns.CNAME:
answer["value"] = v.Target
answer.Value = v.Target
case *dns.NS:
answer["value"] = v.Ns
answer.Value = v.Ns
case *dns.SPF:
answer["value"] = v.Txt
answer.Value = strings.Join(v.Txt, "\n")
case *dns.TXT:
answer["value"] = v.Txt
answer.Value = strings.Join(v.Txt, "\n")
case *dns.PTR:
answer["value"] = v.Ptr
answer.Value = v.Ptr
case *dns.SOA:
answer["value"] = fmt.Sprintf("%v %v %v %v %v %v %v", v.Ns, v.Mbox, v.Serial, v.Refresh, v.Retry, v.Expire, v.Minttl)
answer.Value = fmt.Sprintf("%v %v %v %v %v %v %v", v.Ns, v.Mbox, v.Serial, v.Refresh, v.Retry, v.Expire, v.Minttl)
case *dns.CAA:
answer["value"] = fmt.Sprintf("%v %v \"%v\"", v.Flag, v.Tag, v.Value)
answer.Value = fmt.Sprintf("%v %v \"%v\"", v.Flag, v.Tag, v.Value)
case *dns.HINFO:
answer["value"] = fmt.Sprintf("\"%v\" \"%v\"", v.Cpu, v.Os)
answer.Value = fmt.Sprintf("\"%v\" \"%v\"", v.Cpu, v.Os)
case *dns.RRSIG:
answer["value"] = fmt.Sprintf("%v %v %v %v %v %v %v %v %v", dns.TypeToString[v.TypeCovered], v.Algorithm, v.Labels, v.OrigTtl, v.Expiration, v.Inception, v.KeyTag, v.SignerName, v.Signature)
answer.Value = fmt.Sprintf("%v %v %v %v %v %v %v %v %v", dns.TypeToString[v.TypeCovered], v.Algorithm, v.Labels, v.OrigTtl, v.Expiration, v.Inception, v.KeyTag, v.SignerName, v.Signature)
default:
// type unknown, marshall it as-is
answer["value"] = v
answer.Value = v.String()
}
answers = append(answers, answer)
}

View File

@@ -2,7 +2,9 @@
package querylog
import (
"errors"
"fmt"
"net"
"os"
"path/filepath"
"strings"
@@ -36,10 +38,11 @@ type ClientProto string
// Client protocol names.
const (
ClientProtoDOH ClientProto = "doh"
ClientProtoDOQ ClientProto = "doq"
ClientProtoDOT ClientProto = "dot"
ClientProtoPlain ClientProto = ""
ClientProtoDOH ClientProto = "doh"
ClientProtoDOQ ClientProto = "doq"
ClientProtoDOT ClientProto = "dot"
ClientProtoDNSCrypt ClientProto = "dnscrypt"
ClientProtoPlain ClientProto = ""
)
// NewClientProto validates that the client protocol name is valid and returns
@@ -50,6 +53,7 @@ func NewClientProto(s string) (cp ClientProto, err error) {
ClientProtoDOH,
ClientProtoDOQ,
ClientProtoDOT,
ClientProtoDNSCrypt,
ClientProtoPlain:
return cp, nil
@@ -60,13 +64,14 @@ func NewClientProto(s string) (cp ClientProto, err error) {
// logEntry - represents a single log entry
type logEntry struct {
IP string `json:"IP"` // Client IP
IP net.IP `json:"IP"` // Client IP
Time time.Time `json:"T"`
QHost string `json:"QH"`
QType string `json:"QT"`
QClass string `json:"QC"`
ClientID string `json:"CID,omitempty"`
ClientProto ClientProto `json:"CP"`
Answer []byte `json:",omitempty"` // sometimes empty answers happen like binerdunt.top or rev2.globalrootservers.net
@@ -118,14 +123,15 @@ func (l *queryLog) clear() {
l.flushPending = false
l.bufferLock.Unlock()
err := os.Remove(l.logFile + ".1")
if err != nil && !os.IsNotExist(err) {
log.Error("file remove: %s: %s", l.logFile+".1", err)
oldLogFile := l.logFile + ".1"
err := os.Remove(oldLogFile)
if err != nil && !errors.Is(err, os.ErrNotExist) {
log.Error("removing old log file %q: %s", oldLogFile, err)
}
err = os.Remove(l.logFile)
if err != nil && !os.IsNotExist(err) {
log.Error("file remove: %s: %s", l.logFile, err)
if err != nil && !errors.Is(err, os.ErrNotExist) {
log.Error("removing log file %q: %s", l.logFile, err)
}
log.Debug("Query log: cleared")
@@ -147,12 +153,13 @@ func (l *queryLog) Add(params AddParams) {
now := time.Now()
entry := logEntry{
IP: l.getClientIP(params.ClientIP.String()),
IP: l.getClientIP(params.ClientIP),
Time: now,
Result: *params.Result,
Elapsed: params.Elapsed,
Upstream: params.Upstream,
ClientID: params.ClientID,
ClientProto: params.ClientProto,
}
q := params.Question.Question[0]

View File

@@ -1,242 +1,276 @@
package querylog
import (
"fmt"
"math/rand"
"net"
"os"
"sort"
"testing"
"time"
"github.com/AdguardTeam/dnsproxy/proxyutil"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/dnsfilter"
"github.com/AdguardTeam/AdGuardHome/internal/testutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestMain(m *testing.M) {
testutil.DiscardLogOutput(m)
aghtest.DiscardLogOutput(m)
}
func prepareTestDir() string {
const dir = "./agh-test"
_ = os.RemoveAll(dir)
_ = os.MkdirAll(dir, 0o755)
return dir
}
// Check adding and loading (with filtering) entries from disk and memory
// TestQueryLog tests adding and loading (with filtering) entries from disk and
// memory.
func TestQueryLog(t *testing.T) {
conf := Config{
l := newQueryLog(Config{
Enabled: true,
FileEnabled: true,
Interval: 1,
MemSize: 100,
BaseDir: aghtest.PrepareTestDir(t),
})
// Add disk entries.
addEntry(l, "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
// Write to disk (first file).
require.Nil(t, l.flushLogBuffer(true))
// Start writing to the second file.
require.Nil(t, l.rotate())
// Add disk entries.
addEntry(l, "example.org", net.IPv4(1, 1, 1, 2), net.IPv4(2, 2, 2, 2))
// Write to disk.
require.Nil(t, l.flushLogBuffer(true))
// Add memory entries.
addEntry(l, "test.example.org", net.IPv4(1, 1, 1, 3), net.IPv4(2, 2, 2, 3))
addEntry(l, "example.com", net.IPv4(1, 1, 1, 4), net.IPv4(2, 2, 2, 4))
type tcAssertion struct {
num int
host string
answer, client net.IP
}
conf.BaseDir = prepareTestDir()
defer func() { _ = os.RemoveAll(conf.BaseDir) }()
l := newQueryLog(conf)
// add disk entries
addEntry(l, "example.org", "1.1.1.1", "2.2.2.1")
// write to disk (first file)
_ = l.flushLogBuffer(true)
// start writing to the second file
_ = l.rotate()
// add disk entries
addEntry(l, "example.org", "1.1.1.2", "2.2.2.2")
// write to disk
_ = l.flushLogBuffer(true)
// add memory entries
addEntry(l, "test.example.org", "1.1.1.3", "2.2.2.3")
addEntry(l, "example.com", "1.1.1.4", "2.2.2.4")
testCases := []struct {
name string
sCr []searchCriteria
want []tcAssertion
}{{
name: "all",
sCr: []searchCriteria{},
want: []tcAssertion{
{num: 0, host: "example.com", answer: net.IPv4(1, 1, 1, 4), client: net.IPv4(2, 2, 2, 4)},
{num: 1, host: "test.example.org", answer: net.IPv4(1, 1, 1, 3), client: net.IPv4(2, 2, 2, 3)},
{num: 2, host: "example.org", answer: net.IPv4(1, 1, 1, 2), client: net.IPv4(2, 2, 2, 2)},
{num: 3, host: "example.org", answer: net.IPv4(1, 1, 1, 1), client: net.IPv4(2, 2, 2, 1)},
},
}, {
name: "by_domain_strict",
sCr: []searchCriteria{{
criteriaType: ctDomainOrClient,
strict: true,
value: "TEST.example.org",
}},
want: []tcAssertion{{
num: 0, host: "test.example.org", answer: net.IPv4(1, 1, 1, 3), client: net.IPv4(2, 2, 2, 3),
}},
}, {
name: "by_domain_non-strict",
sCr: []searchCriteria{{
criteriaType: ctDomainOrClient,
strict: false,
value: "example.ORG",
}},
want: []tcAssertion{
{num: 0, host: "test.example.org", answer: net.IPv4(1, 1, 1, 3), client: net.IPv4(2, 2, 2, 3)},
{num: 1, host: "example.org", answer: net.IPv4(1, 1, 1, 2), client: net.IPv4(2, 2, 2, 2)},
{num: 2, host: "example.org", answer: net.IPv4(1, 1, 1, 1), client: net.IPv4(2, 2, 2, 1)},
},
}, {
name: "by_client_ip_strict",
sCr: []searchCriteria{{
criteriaType: ctDomainOrClient,
strict: true,
value: "2.2.2.2",
}},
want: []tcAssertion{{
num: 0, host: "example.org", answer: net.IPv4(1, 1, 1, 2), client: net.IPv4(2, 2, 2, 2),
}},
}, {
name: "by_client_ip_non-strict",
sCr: []searchCriteria{{
criteriaType: ctDomainOrClient,
strict: false,
value: "2.2.2",
}},
want: []tcAssertion{
{num: 0, host: "example.com", answer: net.IPv4(1, 1, 1, 4), client: net.IPv4(2, 2, 2, 4)},
{num: 1, host: "test.example.org", answer: net.IPv4(1, 1, 1, 3), client: net.IPv4(2, 2, 2, 3)},
{num: 2, host: "example.org", answer: net.IPv4(1, 1, 1, 2), client: net.IPv4(2, 2, 2, 2)},
{num: 3, host: "example.org", answer: net.IPv4(1, 1, 1, 1), client: net.IPv4(2, 2, 2, 1)},
},
}}
// get all entries
params := newSearchParams()
entries, _ := l.search(params)
assert.Equal(t, 4, len(entries))
assertLogEntry(t, entries[0], "example.com", "1.1.1.4", "2.2.2.4")
assertLogEntry(t, entries[1], "test.example.org", "1.1.1.3", "2.2.2.3")
assertLogEntry(t, entries[2], "example.org", "1.1.1.2", "2.2.2.2")
assertLogEntry(t, entries[3], "example.org", "1.1.1.1", "2.2.2.1")
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
params := newSearchParams()
params.searchCriteria = tc.sCr
// search by domain (strict)
params = newSearchParams()
params.searchCriteria = append(params.searchCriteria, searchCriteria{
criteriaType: ctDomainOrClient,
strict: true,
value: "TEST.example.org",
})
entries, _ = l.search(params)
assert.Equal(t, 1, len(entries))
assertLogEntry(t, entries[0], "test.example.org", "1.1.1.3", "2.2.2.3")
// search by domain (not strict)
params = newSearchParams()
params.searchCriteria = append(params.searchCriteria, searchCriteria{
criteriaType: ctDomainOrClient,
strict: false,
value: "example.ORG",
})
entries, _ = l.search(params)
assert.Equal(t, 3, len(entries))
assertLogEntry(t, entries[0], "test.example.org", "1.1.1.3", "2.2.2.3")
assertLogEntry(t, entries[1], "example.org", "1.1.1.2", "2.2.2.2")
assertLogEntry(t, entries[2], "example.org", "1.1.1.1", "2.2.2.1")
// search by client IP (strict)
params = newSearchParams()
params.searchCriteria = append(params.searchCriteria, searchCriteria{
criteriaType: ctDomainOrClient,
strict: true,
value: "2.2.2.2",
})
entries, _ = l.search(params)
assert.Equal(t, 1, len(entries))
assertLogEntry(t, entries[0], "example.org", "1.1.1.2", "2.2.2.2")
// search by client IP (part of)
params = newSearchParams()
params.searchCriteria = append(params.searchCriteria, searchCriteria{
criteriaType: ctDomainOrClient,
strict: false,
value: "2.2.2",
})
entries, _ = l.search(params)
assert.Equal(t, 4, len(entries))
assertLogEntry(t, entries[0], "example.com", "1.1.1.4", "2.2.2.4")
assertLogEntry(t, entries[1], "test.example.org", "1.1.1.3", "2.2.2.3")
assertLogEntry(t, entries[2], "example.org", "1.1.1.2", "2.2.2.2")
assertLogEntry(t, entries[3], "example.org", "1.1.1.1", "2.2.2.1")
entries, _ := l.search(params)
require.Len(t, entries, len(tc.want))
for _, want := range tc.want {
assertLogEntry(t, entries[want.num], want.host, want.answer, want.client)
}
})
}
}
func TestQueryLogOffsetLimit(t *testing.T) {
conf := Config{
l := newQueryLog(Config{
Enabled: true,
Interval: 1,
MemSize: 100,
}
conf.BaseDir = prepareTestDir()
defer func() { _ = os.RemoveAll(conf.BaseDir) }()
l := newQueryLog(conf)
BaseDir: aghtest.PrepareTestDir(t),
})
// add 10 entries to the log
for i := 0; i < 10; i++ {
addEntry(l, "second.example.org", "1.1.1.1", "2.2.2.1")
const (
entNum = 10
firstPageDomain = "first.example.org"
secondPageDomain = "second.example.org"
)
// Add entries to the log.
for i := 0; i < entNum; i++ {
addEntry(l, secondPageDomain, net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
}
// write them to disk (first file)
_ = l.flushLogBuffer(true)
// add 10 more entries to the log (memory)
for i := 0; i < 10; i++ {
addEntry(l, "first.example.org", "1.1.1.1", "2.2.2.1")
// Write them to the first file.
require.Nil(t, l.flushLogBuffer(true))
// Add more to the in-memory part of log.
for i := 0; i < entNum; i++ {
addEntry(l, firstPageDomain, net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
}
// First page
params := newSearchParams()
params.offset = 0
params.limit = 10
entries, _ := l.search(params)
assert.Equal(t, 10, len(entries))
assert.Equal(t, entries[0].QHost, "first.example.org")
assert.Equal(t, entries[9].QHost, "first.example.org")
// Second page
params.offset = 10
params.limit = 10
entries, _ = l.search(params)
assert.Equal(t, 10, len(entries))
assert.Equal(t, entries[0].QHost, "second.example.org")
assert.Equal(t, entries[9].QHost, "second.example.org")
testCases := []struct {
name string
offset int
limit int
wantLen int
want string
}{{
name: "page_1",
offset: 0,
limit: 10,
wantLen: 10,
want: firstPageDomain,
}, {
name: "page_2",
offset: 10,
limit: 10,
wantLen: 10,
want: secondPageDomain,
}, {
name: "page_2.5",
offset: 15,
limit: 10,
wantLen: 5,
want: secondPageDomain,
}, {
name: "page_3",
offset: 20,
limit: 10,
wantLen: 0,
}}
// Second and a half page
params.offset = 15
params.limit = 10
entries, _ = l.search(params)
assert.Equal(t, 5, len(entries))
assert.Equal(t, entries[0].QHost, "second.example.org")
assert.Equal(t, entries[4].QHost, "second.example.org")
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
params.offset = tc.offset
params.limit = tc.limit
entries, _ := l.search(params)
// Third page
params.offset = 20
params.limit = 10
entries, _ = l.search(params)
assert.Equal(t, 0, len(entries))
require.Len(t, entries, tc.wantLen)
if tc.wantLen > 0 {
assert.Equal(t, entries[0].QHost, tc.want)
assert.Equal(t, entries[tc.wantLen-1].QHost, tc.want)
}
})
}
}
func TestQueryLogMaxFileScanEntries(t *testing.T) {
conf := Config{
l := newQueryLog(Config{
Enabled: true,
FileEnabled: true,
Interval: 1,
MemSize: 100,
}
conf.BaseDir = prepareTestDir()
defer func() { _ = os.RemoveAll(conf.BaseDir) }()
l := newQueryLog(conf)
BaseDir: aghtest.PrepareTestDir(t),
})
// add 10 entries to the log
for i := 0; i < 10; i++ {
addEntry(l, "example.org", "1.1.1.1", "2.2.2.1")
const entNum = 10
// Add entries to the log.
for i := 0; i < entNum; i++ {
addEntry(l, "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
}
// write them to disk (first file)
_ = l.flushLogBuffer(true)
// Write them to disk.
require.Nil(t, l.flushLogBuffer(true))
params := newSearchParams()
params.maxFileScanEntries = 5 // do not scan more than 5 records
entries, _ := l.search(params)
assert.Equal(t, 5, len(entries))
params.maxFileScanEntries = 0 // disable the limit
entries, _ = l.search(params)
assert.Equal(t, 10, len(entries))
for _, maxFileScanEntries := range []int{5, 0} {
t.Run(fmt.Sprintf("limit_%d", maxFileScanEntries), func(t *testing.T) {
params.maxFileScanEntries = maxFileScanEntries
entries, _ := l.search(params)
assert.Len(t, entries, entNum-maxFileScanEntries)
})
}
}
func TestQueryLogFileDisabled(t *testing.T) {
conf := Config{
l := newQueryLog(Config{
Enabled: true,
FileEnabled: false,
Interval: 1,
MemSize: 2,
}
conf.BaseDir = prepareTestDir()
defer func() { _ = os.RemoveAll(conf.BaseDir) }()
l := newQueryLog(conf)
BaseDir: aghtest.PrepareTestDir(t),
})
addEntry(l, "example1.org", "1.1.1.1", "2.2.2.1")
addEntry(l, "example2.org", "1.1.1.1", "2.2.2.1")
addEntry(l, "example3.org", "1.1.1.1", "2.2.2.1")
// the oldest entry is now removed from mem buffer
addEntry(l, "example1.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
addEntry(l, "example2.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
// The oldest entry is going to be removed from memory buffer.
addEntry(l, "example3.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
params := newSearchParams()
ll, _ := l.search(params)
assert.Equal(t, 2, len(ll))
require.Len(t, ll, 2)
assert.Equal(t, "example3.org", ll[0].QHost)
assert.Equal(t, "example2.org", ll[1].QHost)
}
func addEntry(l *queryLog, host, answerStr, client string) {
q := dns.Msg{}
q.Question = append(q.Question, dns.Question{
Name: host + ".",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
})
a := dns.Msg{}
a.Question = append(a.Question, q.Question[0])
answer := new(dns.A)
answer.Hdr = dns.RR_Header{
Name: q.Question[0].Name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
func addEntry(l *queryLog, host string, answerStr, client net.IP) {
q := dns.Msg{
Question: []dns.Question{{
Name: host + ".",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}},
}
a := dns.Msg{
Question: q.Question,
Answer: []dns.RR{&dns.A{
Hdr: dns.RR_Header{
Name: q.Question[0].Name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
},
A: answerStr,
}},
}
answer.A = net.ParseIP(answerStr)
a.Answer = append(a.Answer, answer)
res := dnsfilter.Result{
IsFiltered: true,
Reason: dnsfilter.ReasonRewrite,
Reason: dnsfilter.Rewritten,
ServiceName: "SomeService",
Rules: []*dnsfilter.ResultRule{{
FilterListID: 1,
@@ -248,25 +282,28 @@ func addEntry(l *queryLog, host, answerStr, client string) {
Answer: &a,
OrigAnswer: &a,
Result: &res,
ClientIP: net.ParseIP(client),
ClientIP: client,
Upstream: "upstream",
}
l.Add(params)
}
func assertLogEntry(t *testing.T, entry *logEntry, host, answer, client string) bool {
func assertLogEntry(t *testing.T, entry *logEntry, host string, answer, client net.IP) {
t.Helper()
require.NotNil(t, entry)
assert.Equal(t, host, entry.QHost)
assert.Equal(t, client, entry.IP)
assert.Equal(t, "A", entry.QType)
assert.Equal(t, "IN", entry.QClass)
msg := new(dns.Msg)
assert.Nil(t, msg.Unpack(entry.Answer))
assert.Equal(t, 1, len(msg.Answer))
ip := proxyutil.GetIPFromDNSRecord(msg.Answer[0])
assert.NotNil(t, ip)
assert.Equal(t, answer, ip.String())
return true
msg := &dns.Msg{}
require.Nil(t, msg.Unpack(entry.Answer))
require.Len(t, msg.Answer, 1)
ip := proxyutil.GetIPFromDNSRecord(msg.Answer[0]).To16()
assert.Equal(t, answer, ip)
}
func testEntries() (entries []*logEntry) {
@@ -332,8 +369,8 @@ func TestLogEntriesByTime_sort(t *testing.T) {
entries := testEntries()
sort.Sort(logEntriesByTimeDesc(entries))
for i := 1; i < len(entries); i++ {
assert.False(t, entries[i].Time.After(entries[i-1].Time),
"%s %s", entries[i].Time, entries[i-1].Time)
for i := range entries[1:] {
assert.False(t, entries[i+1].Time.After(entries[i].Time),
"%s %s", entries[i+1].Time, entries[i].Time)
}
}

View File

@@ -251,7 +251,7 @@ func (q *QLogFile) readNextLine(position int64) (string, int64, error) {
// the goal is to read a chunk of file that includes the line with the specified position.
func (q *QLogFile) initBuffer(position int64) error {
q.bufferStart = int64(0)
if (position - bufferSize) > 0 {
if position > bufferSize {
q.bufferStart = position - bufferSize
}
@@ -264,12 +264,10 @@ func (q *QLogFile) initBuffer(position int64) error {
if q.buffer == nil {
q.buffer = make([]byte, bufferSize)
}
q.bufferLen, err = q.file.Read(q.buffer)
if err != nil {
return err
}
return nil
q.bufferLen, err = q.file.Read(q.buffer)
return err
}
// readProbeLine reads a line that includes the specified position
@@ -280,7 +278,7 @@ func (q *QLogFile) readProbeLine(position int64) (string, int64, int64, error) {
// In order to do this, we'll define the boundaries
seekPosition := int64(0)
relativePos := position // position relative to the buffer we're going to read
if (position - maxEntrySize) > 0 {
if position > maxEntrySize {
seekPosition = position - maxEntrySize
relativePos = maxEntrySize
}

View File

@@ -2,347 +2,347 @@ package querylog
import (
"encoding/binary"
"errors"
"fmt"
"io"
"io/ioutil"
"math"
"net"
"os"
"strings"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestQLogFileEmpty(t *testing.T) {
testDir := prepareTestDir()
defer func() { _ = os.RemoveAll(testDir) }()
testFile := prepareTestFile(testDir, 0)
// prepareTestFiles prepares several test query log files, each with the
// specified lines count.
func prepareTestFiles(t *testing.T, filesNum, linesNum int) []string {
t.Helper()
// create the new QLogFile instance
q, err := NewQLogFile(testFile)
assert.Nil(t, err)
assert.NotNil(t, q)
defer q.Close()
// seek to the start
pos, err := q.SeekStart()
assert.Nil(t, err)
assert.Equal(t, int64(0), pos)
// try reading anyway
line, err := q.ReadNext()
assert.Equal(t, io.EOF, err)
assert.Equal(t, "", line)
}
func TestQLogFileLarge(t *testing.T) {
// should be large enough
count := 50000
testDir := prepareTestDir()
defer func() { _ = os.RemoveAll(testDir) }()
testFile := prepareTestFile(testDir, count)
// create the new QLogFile instance
q, err := NewQLogFile(testFile)
assert.Nil(t, err)
assert.NotNil(t, q)
defer q.Close()
// seek to the start
pos, err := q.SeekStart()
assert.Nil(t, err)
assert.NotEqual(t, int64(0), pos)
read := 0
var line string
for err == nil {
line, err = q.ReadNext()
if err == nil {
assert.True(t, len(line) > 0)
read++
}
if filesNum == 0 {
return []string{}
}
assert.Equal(t, count, read)
assert.Equal(t, io.EOF, err)
}
func TestQLogFileSeekLargeFile(t *testing.T) {
// more or less big file
count := 10000
testDir := prepareTestDir()
defer func() { _ = os.RemoveAll(testDir) }()
testFile := prepareTestFile(testDir, count)
// create the new QLogFile instance
q, err := NewQLogFile(testFile)
assert.Nil(t, err)
assert.NotNil(t, q)
defer q.Close()
// CASE 1: NOT TOO OLD LINE
testSeekLineQLogFile(t, q, 300)
// CASE 2: OLD LINE
testSeekLineQLogFile(t, q, count-300)
// CASE 3: FIRST LINE
testSeekLineQLogFile(t, q, 0)
// CASE 4: LAST LINE
testSeekLineQLogFile(t, q, count)
// CASE 5: Seek non-existent (too low)
_, _, err = q.SeekTS(123)
assert.NotNil(t, err)
// CASE 6: Seek non-existent (too high)
ts, _ := time.Parse(time.RFC3339, "2100-01-02T15:04:05Z07:00")
_, _, err = q.SeekTS(ts.UnixNano())
assert.NotNil(t, err)
// CASE 7: "Almost" found
line, err := getQLogFileLine(q, count/2)
assert.Nil(t, err)
// ALMOST the record we need
timestamp := readQLogTimestamp(line) - 1
assert.NotEqual(t, uint64(0), timestamp)
_, depth, err := q.SeekTS(timestamp)
assert.NotNil(t, err)
assert.True(t, depth <= int(math.Log2(float64(count))+3))
}
func TestQLogFileSeekSmallFile(t *testing.T) {
// more or less big file
count := 10
testDir := prepareTestDir()
defer func() { _ = os.RemoveAll(testDir) }()
testFile := prepareTestFile(testDir, count)
// create the new QLogFile instance
q, err := NewQLogFile(testFile)
assert.Nil(t, err)
assert.NotNil(t, q)
defer q.Close()
// CASE 1: NOT TOO OLD LINE
testSeekLineQLogFile(t, q, 2)
// CASE 2: OLD LINE
testSeekLineQLogFile(t, q, count-2)
// CASE 3: FIRST LINE
testSeekLineQLogFile(t, q, 0)
// CASE 4: LAST LINE
testSeekLineQLogFile(t, q, count)
// CASE 5: Seek non-existent (too low)
_, _, err = q.SeekTS(123)
assert.NotNil(t, err)
// CASE 6: Seek non-existent (too high)
ts, _ := time.Parse(time.RFC3339, "2100-01-02T15:04:05Z07:00")
_, _, err = q.SeekTS(ts.UnixNano())
assert.NotNil(t, err)
// CASE 7: "Almost" found
line, err := getQLogFileLine(q, count/2)
assert.Nil(t, err)
// ALMOST the record we need
timestamp := readQLogTimestamp(line) - 1
assert.NotEqual(t, uint64(0), timestamp)
_, depth, err := q.SeekTS(timestamp)
assert.NotNil(t, err)
assert.True(t, depth <= int(math.Log2(float64(count))+3))
}
func testSeekLineQLogFile(t *testing.T, q *QLogFile, lineNumber int) {
line, err := getQLogFileLine(q, lineNumber)
assert.Nil(t, err)
ts := readQLogTimestamp(line)
assert.NotEqual(t, uint64(0), ts)
// try seeking to that line now
pos, _, err := q.SeekTS(ts)
assert.Nil(t, err)
assert.NotEqual(t, int64(0), pos)
testLine, err := q.ReadNext()
assert.Nil(t, err)
assert.Equal(t, line, testLine)
}
func getQLogFileLine(q *QLogFile, lineNumber int) (string, error) {
_, err := q.SeekStart()
if err != nil {
return "", err
}
for i := 1; i < lineNumber; i++ {
_, err := q.ReadNext()
if err != nil {
return "", err
}
}
return q.ReadNext()
}
// Check adding and loading (with filtering) entries from disk and memory
func TestQLogFile(t *testing.T) {
testDir := prepareTestDir()
defer func() { _ = os.RemoveAll(testDir) }()
testFile := prepareTestFile(testDir, 2)
// create the new QLogFile instance
q, err := NewQLogFile(testFile)
assert.Nil(t, err)
assert.NotNil(t, q)
defer q.Close()
// seek to the start
pos, err := q.SeekStart()
assert.Nil(t, err)
assert.True(t, pos > 0)
// read first line
line, err := q.ReadNext()
assert.Nil(t, err)
assert.True(t, strings.Contains(line, "0.0.0.2"), line)
assert.True(t, strings.HasPrefix(line, "{"), line)
assert.True(t, strings.HasSuffix(line, "}"), line)
// read second line
line, err = q.ReadNext()
assert.Nil(t, err)
assert.Equal(t, int64(0), q.position)
assert.True(t, strings.Contains(line, "0.0.0.1"), line)
assert.True(t, strings.HasPrefix(line, "{"), line)
assert.True(t, strings.HasSuffix(line, "}"), line)
// try reading again (there's nothing to read anymore)
line, err = q.ReadNext()
assert.Equal(t, io.EOF, err)
assert.Equal(t, "", line)
}
// prepareTestFile - prepares a test query log file with the specified number of lines
func prepareTestFile(dir string, linesCount int) string {
return prepareTestFiles(dir, 1, linesCount)[0]
}
// prepareTestFiles - prepares several test query log files
// each of them -- with the specified linesCount
func prepareTestFiles(dir string, filesCount, linesCount int) []string {
format := `{"IP":"${IP}","T":"${TIMESTAMP}","QH":"example.org","QT":"A","QC":"IN","Answer":"AAAAAAABAAEAAAAAB2V4YW1wbGUDb3JnAAABAAEHZXhhbXBsZQNvcmcAAAEAAQAAAAAABAECAwQ=","Result":{},"Elapsed":0,"Upstream":"upstream"}`
const strV = "\"%s\""
const nl = "\n"
const format = `{"IP":` + strV + `,"T":` + strV + `,` +
`"QH":"example.org","QT":"A","QC":"IN",` +
`"Answer":"AAAAAAABAAEAAAAAB2V4YW1wbGUDb3JnAAABAAEHZXhhbXBsZQNvcmcAAAEAAQAAAAAABAECAwQ=",` +
`"Result":{},"Elapsed":0,"Upstream":"upstream"}` + nl
lineTime, _ := time.Parse(time.RFC3339Nano, "2020-02-18T22:36:35.920973+03:00")
lineIP := uint32(0)
files := make([]string, filesCount)
for j := 0; j < filesCount; j++ {
f, _ := ioutil.TempFile(dir, "*.txt")
files[filesCount-j-1] = f.Name()
dir := aghtest.PrepareTestDir(t)
for i := 0; i < linesCount; i++ {
files := make([]string, filesNum)
for j := range files {
f, err := ioutil.TempFile(dir, "*.txt")
require.Nil(t, err)
files[filesNum-j-1] = f.Name()
for i := 0; i < linesNum; i++ {
lineIP++
lineTime = lineTime.Add(time.Second)
ip := make(net.IP, 4)
binary.BigEndian.PutUint32(ip, lineIP)
line := format
line = strings.ReplaceAll(line, "${IP}", ip.String())
line = strings.ReplaceAll(line, "${TIMESTAMP}", lineTime.Format(time.RFC3339Nano))
line := fmt.Sprintf(format, ip, lineTime.Format(time.RFC3339Nano))
_, _ = f.WriteString(line)
_, _ = f.WriteString("\n")
_, err = f.WriteString(line)
require.Nil(t, err)
}
}
return files
}
func TestQLogSeek(t *testing.T) {
testDir := prepareTestDir()
defer func() { _ = os.RemoveAll(testDir) }()
// prepareTestFile prepares a test query log file with the specified number of
// lines.
func prepareTestFile(t *testing.T, linesCount int) string {
t.Helper()
d := `{"T":"2020-08-31T18:44:23.911246629+03:00","QH":"wfqvjymurpwegyv","QT":"A","QC":"IN","CP":"","Answer":"","Result":{},"Elapsed":66286385,"Upstream":"tls://dns-unfiltered.adguard.com:853"}
{"T":"2020-08-31T18:44:25.376690873+03:00"}
{"T":"2020-08-31T18:44:25.382540454+03:00"}`
f, _ := ioutil.TempFile(testDir, "*.txt")
_, _ = f.WriteString(d)
defer f.Close()
q, err := NewQLogFile(f.Name())
assert.Nil(t, err)
defer q.Close()
target, _ := time.Parse(time.RFC3339, "2020-08-31T18:44:25.376690873+03:00")
_, depth, err := q.SeekTS(target.UnixNano())
assert.Nil(t, err)
assert.Equal(t, 1, depth)
return prepareTestFiles(t, 1, linesCount)[0]
}
func TestQLogSeek_ErrTSTooLate(t *testing.T) {
testDir := prepareTestDir()
// newTestQLogFile creates new *QLogFile for tests and registers the required
// cleanup functions.
func newTestQLogFile(t *testing.T, linesNum int) (file *QLogFile) {
t.Helper()
testFile := prepareTestFile(t, linesNum)
// Create the new QLogFile instance.
file, err := NewQLogFile(testFile)
require.Nil(t, err)
assert.NotNil(t, file)
t.Cleanup(func() {
_ = os.RemoveAll(testDir)
assert.Nil(t, file.Close())
})
d := `{"T":"2020-08-31T18:44:23.911246629+03:00","QH":"wfqvjymurpwegyv","QT":"A","QC":"IN","CP":"","Answer":"","Result":{},"Elapsed":66286385,"Upstream":"tls://dns-unfiltered.adguard.com:853"}
{"T":"2020-08-31T18:44:25.376690873+03:00"}
{"T":"2020-08-31T18:44:25.382540454+03:00"}
`
f, err := ioutil.TempFile(testDir, "*.txt")
assert.Nil(t, err)
defer f.Close()
_, err = f.WriteString(d)
assert.Nil(t, err)
q, err := NewQLogFile(f.Name())
assert.Nil(t, err)
defer q.Close()
target, err := time.Parse(time.RFC3339, "2020-08-31T18:44:25.382540454+03:00")
assert.Nil(t, err)
_, depth, err := q.SeekTS(target.UnixNano() + int64(time.Second))
assert.Equal(t, ErrTSTooLate, err)
assert.Equal(t, 2, depth)
return file
}
func TestQLogSeek_ErrTSTooEarly(t *testing.T) {
testDir := prepareTestDir()
func TestQLogFile_ReadNext(t *testing.T) {
testCases := []struct {
name string
linesNum int
}{{
name: "empty",
linesNum: 0,
}, {
name: "large",
linesNum: 50000,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
q := newTestQLogFile(t, tc.linesNum)
// Calculate the expected position.
fileInfo, err := q.file.Stat()
require.Nil(t, err)
var expPos int64
if expPos = fileInfo.Size(); expPos > 0 {
expPos--
}
// Seek to the start.
pos, err := q.SeekStart()
require.Nil(t, err)
require.EqualValues(t, expPos, pos)
var read int
var line string
for err == nil {
line, err = q.ReadNext()
if err == nil {
assert.NotEmpty(t, line)
read++
}
}
require.Equal(t, io.EOF, err)
assert.Equal(t, tc.linesNum, read)
})
}
}
func TestQLogFile_SeekTS_good(t *testing.T) {
linesCases := []struct {
name string
num int
}{{
name: "large",
num: 10000,
}, {
name: "small",
num: 10,
}}
for _, l := range linesCases {
testCases := []struct {
name string
linesNum int
line int
}{{
name: "not_too_old",
line: 2,
}, {
name: "old",
line: l.num - 2,
}, {
name: "first",
line: 0,
}, {
name: "last",
line: l.num,
}}
q := newTestQLogFile(t, l.num)
for _, tc := range testCases {
t.Run(l.name+"_"+tc.name, func(t *testing.T) {
line, err := getQLogFileLine(q, tc.line)
require.Nil(t, err)
ts := readQLogTimestamp(line)
assert.NotEqualValues(t, 0, ts)
// Try seeking to that line now.
pos, _, err := q.SeekTS(ts)
require.Nil(t, err)
assert.NotEqualValues(t, 0, pos)
testLine, err := q.ReadNext()
require.Nil(t, err)
assert.Equal(t, line, testLine)
})
}
}
}
func TestQLogFile_SeekTS_bad(t *testing.T) {
linesCases := []struct {
name string
num int
}{{
name: "large",
num: 10000,
}, {
name: "small",
num: 10,
}}
for _, l := range linesCases {
testCases := []struct {
name string
ts int64
leq bool
}{{
name: "non-existent_long_ago",
}, {
name: "non-existent_far_ahead",
}, {
name: "almost",
leq: true,
}}
q := newTestQLogFile(t, l.num)
testCases[0].ts = 123
lateTS, _ := time.Parse(time.RFC3339, "2100-01-02T15:04:05Z07:00")
testCases[1].ts = lateTS.UnixNano()
line, err := getQLogFileLine(q, l.num/2)
require.Nil(t, err)
testCases[2].ts = readQLogTimestamp(line) - 1
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
assert.NotEqualValues(t, 0, tc.ts)
_, depth, err := q.SeekTS(tc.ts)
assert.NotEmpty(t, l.num)
require.NotNil(t, err)
if tc.leq {
assert.LessOrEqual(t, depth, int(math.Log2(float64(l.num))+3))
}
})
}
}
}
func getQLogFileLine(q *QLogFile, lineNumber int) (line string, err error) {
if _, err = q.SeekStart(); err != nil {
return line, err
}
for i := 1; i < lineNumber; i++ {
if _, err = q.ReadNext(); err != nil {
return line, err
}
}
return q.ReadNext()
}
// Check adding and loading (with filtering) entries from disk and memory.
func TestQLogFile(t *testing.T) {
// Create the new QLogFile instance.
q := newTestQLogFile(t, 2)
// Seek to the start.
pos, err := q.SeekStart()
require.Nil(t, err)
assert.Greater(t, pos, int64(0))
// Read first line.
line, err := q.ReadNext()
require.Nil(t, err)
assert.Contains(t, line, "0.0.0.2")
assert.True(t, strings.HasPrefix(line, "{"), line)
assert.True(t, strings.HasSuffix(line, "}"), line)
// Read second line.
line, err = q.ReadNext()
require.Nil(t, err)
assert.EqualValues(t, 0, q.position)
assert.Contains(t, line, "0.0.0.1")
assert.True(t, strings.HasPrefix(line, "{"), line)
assert.True(t, strings.HasSuffix(line, "}"), line)
// Try reading again (there's nothing to read anymore).
line, err = q.ReadNext()
require.Equal(t, io.EOF, err)
assert.Empty(t, line)
}
func NewTestQLogFileData(t *testing.T, data string) (file *QLogFile) {
f, err := ioutil.TempFile(aghtest.PrepareTestDir(t), "*.txt")
require.Nil(t, err)
t.Cleanup(func() {
_ = os.RemoveAll(testDir)
assert.Nil(t, f.Close())
})
d := `{"T":"2020-08-31T18:44:23.911246629+03:00","QH":"wfqvjymurpwegyv","QT":"A","QC":"IN","CP":"","Answer":"","Result":{},"Elapsed":66286385,"Upstream":"tls://dns-unfiltered.adguard.com:853"}
{"T":"2020-08-31T18:44:25.376690873+03:00"}
{"T":"2020-08-31T18:44:25.382540454+03:00"}
`
f, err := ioutil.TempFile(testDir, "*.txt")
assert.Nil(t, err)
defer f.Close()
_, err = f.WriteString(data)
require.Nil(t, err)
_, err = f.WriteString(d)
assert.Nil(t, err)
file, err = NewQLogFile(f.Name())
require.Nil(t, err)
t.Cleanup(func() {
assert.Nil(t, file.Close())
})
q, err := NewQLogFile(f.Name())
assert.Nil(t, err)
defer q.Close()
target, err := time.Parse(time.RFC3339, "2020-08-31T18:44:23.911246629+03:00")
assert.Nil(t, err)
_, depth, err := q.SeekTS(target.UnixNano() - int64(time.Second))
assert.Equal(t, ErrTSTooEarly, err)
assert.Equal(t, 1, depth)
return file
}
func TestQLog_Seek(t *testing.T) {
const nl = "\n"
const strV = "%s"
const recs = `{"T":"` + strV + `","QH":"wfqvjymurpwegyv","QT":"A","QC":"IN","CP":"","Answer":"","Result":{},"Elapsed":66286385,"Upstream":"tls://dns-unfiltered.adguard.com:853"}` + nl +
`{"T":"` + strV + `"}` + nl +
`{"T":"` + strV + `"}` + nl
timestamp, _ := time.Parse(time.RFC3339Nano, "2020-08-31T18:44:25.376690873+03:00")
testCases := []struct {
name string
delta int
wantErr error
wantDepth int
}{{
name: "ok",
delta: 0,
wantErr: nil,
wantDepth: 2,
}, {
name: "too_late",
delta: 2,
wantErr: ErrTSTooLate,
wantDepth: 2,
}, {
name: "too_early",
delta: -2,
wantErr: ErrTSTooEarly,
wantDepth: 1,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
data := fmt.Sprintf(recs,
timestamp.Add(-time.Second).Format(time.RFC3339Nano),
timestamp.Format(time.RFC3339Nano),
timestamp.Add(time.Second).Format(time.RFC3339Nano),
)
q := NewTestQLogFileData(t, data)
_, depth, err := q.SeekTS(timestamp.Add(time.Second * time.Duration(tc.delta)).UnixNano())
require.Truef(t, errors.Is(err, tc.wantErr), "%v", err)
assert.Equal(t, tc.wantDepth, depth)
})
}
}

View File

@@ -3,110 +3,77 @@ package querylog
import (
"errors"
"io"
"os"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestQLogReaderEmpty(t *testing.T) {
r, err := NewQLogReader([]string{})
assert.Nil(t, err)
assert.NotNil(t, r)
defer r.Close()
// newTestQLogReader creates new *QLogReader for tests and registers the
// required cleanup functions.
func newTestQLogReader(t *testing.T, filesNum, linesNum int) (reader *QLogReader) {
t.Helper()
// seek to the start
err = r.SeekStart()
assert.Nil(t, err)
testFiles := prepareTestFiles(t, filesNum, linesNum)
line, err := r.ReadNext()
assert.Equal(t, "", line)
assert.Equal(t, io.EOF, err)
// Create the new QLogReader instance.
reader, err := NewQLogReader(testFiles)
require.Nil(t, err)
assert.NotNil(t, reader)
t.Cleanup(func() {
assert.Nil(t, reader.Close())
})
return reader
}
func TestQLogReaderOneFile(t *testing.T) {
// let's do one small file
count := 10
filesCount := 1
func TestQLogReader(t *testing.T) {
testCases := []struct {
name string
filesNum int
linesNum int
}{{
name: "empty",
filesNum: 0,
linesNum: 0,
}, {
name: "one_file",
filesNum: 1,
linesNum: 10,
}, {
name: "multiple_files",
filesNum: 5,
linesNum: 10000,
}}
testDir := prepareTestDir()
defer func() { _ = os.RemoveAll(testDir) }()
testFiles := prepareTestFiles(testDir, filesCount, count)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
r := newTestQLogReader(t, tc.filesNum, tc.linesNum)
r, err := NewQLogReader(testFiles)
assert.Nil(t, err)
assert.NotNil(t, r)
defer r.Close()
// Seek to the start.
err := r.SeekStart()
require.Nil(t, err)
// seek to the start
err = r.SeekStart()
assert.Nil(t, err)
// Read everything.
var read int
var line string
for err == nil {
line, err = r.ReadNext()
if err == nil {
assert.NotEmpty(t, line)
read++
}
}
// read everything
read := 0
var line string
for err == nil {
line, err = r.ReadNext()
if err == nil {
assert.True(t, len(line) > 0)
read++
}
require.Equal(t, io.EOF, err)
assert.Equal(t, tc.filesNum*tc.linesNum, read)
})
}
assert.Equal(t, count*filesCount, read)
assert.Equal(t, io.EOF, err)
}
func TestQLogReaderMultipleFiles(t *testing.T) {
// should be large enough
count := 10000
filesCount := 5
testDir := prepareTestDir()
defer func() { _ = os.RemoveAll(testDir) }()
testFiles := prepareTestFiles(testDir, filesCount, count)
r, err := NewQLogReader(testFiles)
assert.Nil(t, err)
assert.NotNil(t, r)
defer r.Close()
// seek to the start
err = r.SeekStart()
assert.Nil(t, err)
// read everything
read := 0
var line string
for err == nil {
line, err = r.ReadNext()
if err == nil {
assert.True(t, len(line) > 0)
read++
}
}
assert.Equal(t, count*filesCount, read)
assert.Equal(t, io.EOF, err)
}
func TestQLogReader_Seek(t *testing.T) {
count := 10000
filesCount := 2
testDir := prepareTestDir()
t.Cleanup(func() {
_ = os.RemoveAll(testDir)
})
testFiles := prepareTestFiles(testDir, filesCount, count)
r, err := NewQLogReader(testFiles)
assert.Nil(t, err)
assert.NotNil(t, r)
t.Cleanup(func() {
_ = r.Close()
})
r := newTestQLogReader(t, 2, 10000)
testCases := []struct {
name string
@@ -114,7 +81,7 @@ func TestQLogReader_Seek(t *testing.T) {
want error
}{{
name: "not_too_old",
time: "2020-02-19T04:04:56.920973+03:00",
time: "2020-02-18T22:39:35.920973+03:00",
want: nil,
}, {
name: "old",
@@ -122,7 +89,7 @@ func TestQLogReader_Seek(t *testing.T) {
want: nil,
}, {
name: "first",
time: "2020-02-19T04:09:55.920973+03:00",
time: "2020-02-18T22:36:36.920973+03:00",
want: nil,
}, {
name: "last",
@@ -147,28 +114,20 @@ func TestQLogReader_Seek(t *testing.T) {
timestamp, err := time.Parse(time.RFC3339Nano, tc.time)
assert.Nil(t, err)
if tc.name == "first" {
assert.True(t, true)
}
err = r.SeekTS(timestamp.UnixNano())
assert.True(t, errors.Is(err, tc.want), err)
assert.True(t, errors.Is(err, tc.want))
})
}
}
func TestQLogReader_ReadNext(t *testing.T) {
count := 10
filesCount := 1
testDir := prepareTestDir()
t.Cleanup(func() {
_ = os.RemoveAll(testDir)
})
testFiles := prepareTestFiles(testDir, filesCount, count)
r, err := NewQLogReader(testFiles)
assert.Nil(t, err)
assert.NotNil(t, r)
t.Cleanup(func() {
_ = r.Close()
})
const linesNum = 10
const filesNum = 1
r := newTestQLogReader(t, filesNum, linesNum)
testCases := []struct {
name string
@@ -180,7 +139,7 @@ func TestQLogReader_ReadNext(t *testing.T) {
want: nil,
}, {
name: "too_big",
start: count + 1,
start: linesNum + 1,
want: io.EOF,
}}
@@ -199,70 +158,3 @@ func TestQLogReader_ReadNext(t *testing.T) {
})
}
}
// TODO(e.burkov): Remove the tests below. Make tests above more compelling.
func TestQLogReaderSeek(t *testing.T) {
// more or less big file
count := 10000
filesCount := 2
testDir := prepareTestDir()
defer func() { _ = os.RemoveAll(testDir) }()
testFiles := prepareTestFiles(testDir, filesCount, count)
r, err := NewQLogReader(testFiles)
assert.Nil(t, err)
assert.NotNil(t, r)
defer r.Close()
// CASE 1: NOT TOO OLD LINE
testSeekLineQLogReader(t, r, 300)
// CASE 2: OLD LINE
testSeekLineQLogReader(t, r, count-300)
// CASE 3: FIRST LINE
testSeekLineQLogReader(t, r, 0)
// CASE 4: LAST LINE
testSeekLineQLogReader(t, r, count)
// CASE 5: Seek non-existent (too low)
err = r.SeekTS(123)
assert.NotNil(t, err)
// CASE 6: Seek non-existent (too high)
ts, _ := time.Parse(time.RFC3339, "2100-01-02T15:04:05Z07:00")
err = r.SeekTS(ts.UnixNano())
assert.NotNil(t, err)
}
func testSeekLineQLogReader(t *testing.T, r *QLogReader, lineNumber int) {
line, err := getQLogReaderLine(r, lineNumber)
assert.Nil(t, err)
ts := readQLogTimestamp(line)
assert.NotEqual(t, uint64(0), ts)
// try seeking to that line now
err = r.SeekTS(ts)
assert.Nil(t, err)
testLine, err := r.ReadNext()
assert.Nil(t, err)
assert.Equal(t, line, testLine)
}
func getQLogReaderLine(r *QLogReader, lineNumber int) (string, error) {
err := r.SeekStart()
if err != nil {
return "", err
}
for i := 1; i < lineNumber; i++ {
_, err := r.ReadNext()
if err != nil {
return "", err
}
}
return r.ReadNext()
}

View File

@@ -46,6 +46,7 @@ type AddParams struct {
OrigAnswer *dns.Msg // The response from an upstream server (optional)
Result *dnsfilter.Result // Filtering result (optional)
Elapsed time.Duration // Time spent for processing the request
ClientID string
ClientIP net.IP
Upstream string // Upstream server URL
ClientProto ClientProto

View File

@@ -3,6 +3,7 @@ package querylog
import (
"bytes"
"encoding/json"
"errors"
"os"
"time"
@@ -87,18 +88,19 @@ func (l *queryLog) rotate() error {
from := l.logFile
to := l.logFile + ".1"
if _, err := os.Stat(from); os.IsNotExist(err) {
// do nothing, file doesn't exist
return nil
}
err := os.Rename(from, to)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return nil
}
log.Error("querylog: failed to rename file: %s", err)
return err
}
log.Debug("querylog: renamed %s -> %s", from, to)
return nil
}

View File

@@ -9,8 +9,13 @@ import (
type criteriaType int
const (
ctDomainOrClient criteriaType = iota // domain name or client IP address
ctFilteringStatus // filtering status
// ctDomainOrClient is for searching by the domain name, the client's IP
// address, or the clinet's ID.
ctDomainOrClient criteriaType = iota
// ctFilteringStatus is for searching by the filtering status.
//
// See (*searchCriteria).ctFilteringStatusCase for details.
ctFilteringStatus
)
const (
@@ -38,9 +43,9 @@ var filteringStatusValues = []string{
// searchCriteria - every search request may contain a list of different search criteria
// we use each of them to match the query
type searchCriteria struct {
value string // search criteria value
criteriaType criteriaType // type of the criteria
strict bool // should we strictly match (equality) or not (indexOf)
value string // search criteria value
}
// quickMatch - quickly checks if the log entry matches this search criteria
@@ -51,7 +56,8 @@ func (c *searchCriteria) quickMatch(line string) bool {
switch c.criteriaType {
case ctDomainOrClient:
return c.quickMatchJSONValue(line, "QH") ||
c.quickMatchJSONValue(line, "IP")
c.quickMatchJSONValue(line, "IP") ||
c.quickMatchJSONValue(line, "CID")
default:
return true
}
@@ -89,21 +95,26 @@ func (c *searchCriteria) match(entry *logEntry) bool {
}
func (c *searchCriteria) ctDomainOrClientCase(entry *logEntry) bool {
clientID := strings.ToLower(entry.ClientID)
qhost := strings.ToLower(entry.QHost)
searchVal := strings.ToLower(c.value)
if c.strict && qhost == searchVal {
return true
}
if !c.strict && strings.Contains(qhost, searchVal) {
if c.strict && (qhost == searchVal || clientID == searchVal) {
return true
}
if c.strict && entry.IP == c.value {
if !c.strict && (strings.Contains(qhost, searchVal) || strings.Contains(clientID, searchVal)) {
return true
}
if !c.strict && strings.Contains(entry.IP, c.value) {
ipStr := entry.IP.String()
if c.strict && ipStr == c.value {
return true
}
if !c.strict && strings.Contains(ipStr, c.value) {
return true
}
return false
}
@@ -116,8 +127,9 @@ func (c *searchCriteria) ctFilteringStatusCase(res dnsfilter.Result) bool {
return res.IsFiltered ||
res.Reason.In(
dnsfilter.NotFilteredAllowList,
dnsfilter.ReasonRewrite,
dnsfilter.RewriteAutoHosts,
dnsfilter.Rewritten,
dnsfilter.RewrittenAutoHosts,
dnsfilter.RewrittenRule,
)
case filteringStatusBlocked:
@@ -137,7 +149,11 @@ func (c *searchCriteria) ctFilteringStatusCase(res dnsfilter.Result) bool {
return res.Reason == dnsfilter.NotFilteredAllowList
case filteringStatusRewritten:
return res.Reason.In(dnsfilter.ReasonRewrite, dnsfilter.RewriteAutoHosts)
return res.Reason.In(
dnsfilter.Rewritten,
dnsfilter.RewrittenAutoHosts,
dnsfilter.RewrittenRule,
)
case filteringStatusSafeSearch:
return res.IsFiltered && res.Reason == dnsfilter.FilteredSafeSearch

Some files were not shown because too many files have changed in this diff Show More