Merge branch 'master' into 2476-rwmutex
This commit is contained in:
@@ -5,14 +5,9 @@ import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
testutil.DiscardLogOutput(m)
|
||||
}
|
||||
|
||||
func TestError_Error(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestLimitReadCloser(t *testing.T) {
|
||||
@@ -78,11 +79,11 @@ func TestLimitedReadCloser_Read(t *testing.T) {
|
||||
buf := make([]byte, tc.limit+1)
|
||||
|
||||
lreader, err := LimitReadCloser(readCloser, tc.limit)
|
||||
assert.Nil(t, err)
|
||||
require.Nil(t, err)
|
||||
|
||||
n, err := lreader.Read(buf)
|
||||
assert.Equal(t, n, tc.want)
|
||||
assert.Equal(t, tc.err, err)
|
||||
require.Equal(t, tc.err, err)
|
||||
assert.Equal(t, tc.want, n)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// Package testutil contains utilities for testing.
|
||||
package testutil
|
||||
// Package aghtest contains utilities for testing.
|
||||
package aghtest
|
||||
|
||||
import (
|
||||
"io"
|
||||
45
internal/aghtest/os.go
Normal file
45
internal/aghtest/os.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package aghtest
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// PrepareTestDir returns the full path to temporary created directory and
|
||||
// registers the appropriate cleanup for *t.
|
||||
func PrepareTestDir(t *testing.T) (dir string) {
|
||||
t.Helper()
|
||||
|
||||
wd, err := os.Getwd()
|
||||
require.Nil(t, err)
|
||||
|
||||
dir, err = ioutil.TempDir(wd, "agh-test")
|
||||
require.Nil(t, err)
|
||||
require.NotEmpty(t, dir)
|
||||
|
||||
t.Cleanup(func() {
|
||||
// TODO(e.burkov): Replace with t.TempDir methods after updating
|
||||
// go version to 1.15.
|
||||
start := time.Now()
|
||||
for {
|
||||
err := os.RemoveAll(dir)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
|
||||
if runtime.GOOS != "windows" || time.Since(start) >= 500*time.Millisecond {
|
||||
break
|
||||
}
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
}
|
||||
assert.Nil(t, err)
|
||||
})
|
||||
|
||||
return dir
|
||||
}
|
||||
63
internal/aghtest/resolver.go
Normal file
63
internal/aghtest/resolver.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package aghtest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"net"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// TestResolver is a Resolver for tests.
|
||||
type TestResolver struct {
|
||||
counter int
|
||||
counterLock sync.Mutex
|
||||
}
|
||||
|
||||
// HostToIPs generates IPv4 and IPv6 from host.
|
||||
//
|
||||
// TODO(e.burkov): Replace with LookupIP after upgrading go to v1.15.
|
||||
func (r *TestResolver) HostToIPs(host string) (ipv4, ipv6 net.IP) {
|
||||
hash := sha256.Sum256([]byte(host))
|
||||
|
||||
return net.IP(hash[:4]), net.IP(hash[4:20])
|
||||
}
|
||||
|
||||
// LookupIPAddr implements Resolver interface for *testResolver. It returns the
|
||||
// slice of net.IPAddr with IPv4 and IPv6 instances.
|
||||
func (r *TestResolver) LookupIPAddr(_ context.Context, host string) (ips []net.IPAddr, err error) {
|
||||
ipv4, ipv6 := r.HostToIPs(host)
|
||||
addrs := []net.IPAddr{{
|
||||
IP: ipv4,
|
||||
}, {
|
||||
IP: ipv6,
|
||||
}}
|
||||
|
||||
r.counterLock.Lock()
|
||||
defer r.counterLock.Unlock()
|
||||
r.counter++
|
||||
|
||||
return addrs, nil
|
||||
}
|
||||
|
||||
// LookupHost implements Resolver interface for *testResolver. It returns the
|
||||
// slice of IPv4 and IPv6 instances converted to strings.
|
||||
func (r *TestResolver) LookupHost(host string) (addrs []string, err error) {
|
||||
ipv4, ipv6 := r.HostToIPs(host)
|
||||
|
||||
r.counterLock.Lock()
|
||||
defer r.counterLock.Unlock()
|
||||
r.counter++
|
||||
|
||||
return []string{
|
||||
ipv4.String(),
|
||||
ipv6.String(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Counter returns the number of requests handled.
|
||||
func (r *TestResolver) Counter() int {
|
||||
r.counterLock.Lock()
|
||||
defer r.counterLock.Unlock()
|
||||
|
||||
return r.counter
|
||||
}
|
||||
175
internal/aghtest/upstream.go
Normal file
175
internal/aghtest/upstream.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package aghtest
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// TestUpstream is a mock of real upstream.
|
||||
type TestUpstream struct {
|
||||
// Addr is the address for Address method.
|
||||
Addr string
|
||||
// CName is a map of hostname to canonical name.
|
||||
CName map[string]string
|
||||
// IPv4 is a map of hostname to IPv4.
|
||||
IPv4 map[string][]net.IP
|
||||
// IPv6 is a map of hostname to IPv6.
|
||||
IPv6 map[string][]net.IP
|
||||
// Reverse is a map of address to domain name.
|
||||
Reverse map[string][]string
|
||||
}
|
||||
|
||||
// Exchange implements upstream.Upstream interface for *TestUpstream.
|
||||
func (u *TestUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
|
||||
resp = &dns.Msg{}
|
||||
resp.SetReply(m)
|
||||
|
||||
if len(m.Question) == 0 {
|
||||
return nil, fmt.Errorf("question should not be empty")
|
||||
}
|
||||
name := m.Question[0].Name
|
||||
|
||||
if cname, ok := u.CName[name]; ok {
|
||||
resp.Answer = append(resp.Answer, &dns.CNAME{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: name,
|
||||
Rrtype: dns.TypeCNAME,
|
||||
},
|
||||
Target: cname,
|
||||
})
|
||||
}
|
||||
|
||||
var hasRec bool
|
||||
var rrType uint16
|
||||
var ips []net.IP
|
||||
switch m.Question[0].Qtype {
|
||||
case dns.TypeA:
|
||||
rrType = dns.TypeA
|
||||
if ipv4addr, ok := u.IPv4[name]; ok {
|
||||
hasRec = true
|
||||
ips = ipv4addr
|
||||
}
|
||||
case dns.TypeAAAA:
|
||||
rrType = dns.TypeAAAA
|
||||
if ipv6addr, ok := u.IPv6[name]; ok {
|
||||
hasRec = true
|
||||
ips = ipv6addr
|
||||
}
|
||||
case dns.TypePTR:
|
||||
names, ok := u.Reverse[name]
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
|
||||
for _, n := range names {
|
||||
resp.Answer = append(resp.Answer, &dns.PTR{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: name,
|
||||
Rrtype: rrType,
|
||||
},
|
||||
Ptr: n,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
for _, ip := range ips {
|
||||
resp.Answer = append(resp.Answer, &dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: name,
|
||||
Rrtype: rrType,
|
||||
},
|
||||
A: ip,
|
||||
})
|
||||
}
|
||||
|
||||
if len(resp.Answer) == 0 {
|
||||
if hasRec {
|
||||
// Set no error RCode if there are some records for
|
||||
// given Qname but we didn't apply them.
|
||||
resp.SetRcode(m, dns.RcodeSuccess)
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
// Set NXDomain RCode otherwise.
|
||||
resp.SetRcode(m, dns.RcodeNameError)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// Address implements upstream.Upstream interface for *TestUpstream.
|
||||
func (u *TestUpstream) Address() string {
|
||||
return u.Addr
|
||||
}
|
||||
|
||||
// TestBlockUpstream implements upstream.Upstream interface for replacing real
|
||||
// upstream in tests.
|
||||
type TestBlockUpstream struct {
|
||||
Hostname string
|
||||
Block bool
|
||||
requestsCount int
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
// Exchange returns a message unique for TestBlockUpstream's Hostname-Block
|
||||
// pair.
|
||||
func (u *TestBlockUpstream) Exchange(r *dns.Msg) (*dns.Msg, error) {
|
||||
u.lock.Lock()
|
||||
defer u.lock.Unlock()
|
||||
u.requestsCount++
|
||||
|
||||
hash := sha256.Sum256([]byte(u.Hostname))
|
||||
hashToReturn := hex.EncodeToString(hash[:])
|
||||
if !u.Block {
|
||||
hashToReturn = hex.EncodeToString(hash[:])[:2] + strings.Repeat("ab", 28)
|
||||
}
|
||||
|
||||
m := &dns.Msg{}
|
||||
m.Answer = []dns.RR{
|
||||
&dns.TXT{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: r.Question[0].Name,
|
||||
},
|
||||
Txt: []string{
|
||||
hashToReturn,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Address always returns an empty string.
|
||||
func (u *TestBlockUpstream) Address() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// RequestsCount returns the number of handled requests. It's safe for
|
||||
// concurrent use.
|
||||
func (u *TestBlockUpstream) RequestsCount() int {
|
||||
u.lock.Lock()
|
||||
defer u.lock.Unlock()
|
||||
|
||||
return u.requestsCount
|
||||
}
|
||||
|
||||
// TestErrUpstream implements upstream.Upstream interface for replacing real
|
||||
// upstream in tests.
|
||||
type TestErrUpstream struct{}
|
||||
|
||||
// Exchange always returns nil Msg and non-nil error.
|
||||
func (u *TestErrUpstream) Exchange(*dns.Msg) (*dns.Msg, error) {
|
||||
return nil, agherr.Error("bad")
|
||||
}
|
||||
|
||||
// Address always returns an empty string.
|
||||
func (u *TestErrUpstream) Address() string {
|
||||
return ""
|
||||
}
|
||||
@@ -3,6 +3,8 @@ package dhcpd
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
@@ -17,7 +19,12 @@ import (
|
||||
|
||||
const (
|
||||
defaultDiscoverTime = time.Second * 3
|
||||
leaseExpireStatic = 1
|
||||
// leaseExpireStatic is used to define the Expiry field for static
|
||||
// leases.
|
||||
//
|
||||
// TODO(e.burkov): Remove it when static leases determining mechanism
|
||||
// will be improved.
|
||||
leaseExpireStatic = 1
|
||||
)
|
||||
|
||||
var webHandlersRegistered = false
|
||||
@@ -33,6 +40,51 @@ type Lease struct {
|
||||
Expiry time.Time `json:"expires"`
|
||||
}
|
||||
|
||||
// MarshalJSON implements the json.Marshaler interface for *Lease.
|
||||
func (l *Lease) MarshalJSON() ([]byte, error) {
|
||||
var expiryStr string
|
||||
if expiry := l.Expiry; expiry.Unix() != leaseExpireStatic {
|
||||
// The front-end is waiting for RFC 3999 format of the time
|
||||
// value. It also shouldn't got an Expiry field for static
|
||||
// leases.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/2692.
|
||||
expiryStr = expiry.Format(time.RFC3339)
|
||||
}
|
||||
|
||||
type lease Lease
|
||||
return json.Marshal(&struct {
|
||||
HWAddr string `json:"mac"`
|
||||
Expiry string `json:"expires,omitempty"`
|
||||
*lease
|
||||
}{
|
||||
HWAddr: l.HWAddr.String(),
|
||||
Expiry: expiryStr,
|
||||
lease: (*lease)(l),
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements the json.Unmarshaler interface for *Lease.
|
||||
func (l *Lease) UnmarshalJSON(data []byte) (err error) {
|
||||
type lease Lease
|
||||
aux := struct {
|
||||
HWAddr string `json:"mac"`
|
||||
*lease
|
||||
}{
|
||||
lease: (*lease)(l),
|
||||
}
|
||||
if err = json.Unmarshal(data, &aux); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
l.HWAddr, err = net.ParseMAC(aux.HWAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't parse MAC address: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ServerConfig - DHCP server configuration
|
||||
// field ordering is important -- yaml fields will mirror ordering from here
|
||||
type ServerConfig struct {
|
||||
@@ -82,14 +134,14 @@ type ServerInterface interface {
|
||||
}
|
||||
|
||||
// Create - create object
|
||||
func Create(config ServerConfig) *Server {
|
||||
func Create(conf ServerConfig) *Server {
|
||||
s := &Server{}
|
||||
|
||||
s.conf.Enabled = config.Enabled
|
||||
s.conf.InterfaceName = config.InterfaceName
|
||||
s.conf.HTTPRegister = config.HTTPRegister
|
||||
s.conf.ConfigModified = config.ConfigModified
|
||||
s.conf.DBFilePath = filepath.Join(config.WorkDir, dbFilename)
|
||||
s.conf.Enabled = conf.Enabled
|
||||
s.conf.InterfaceName = conf.InterfaceName
|
||||
s.conf.HTTPRegister = conf.HTTPRegister
|
||||
s.conf.ConfigModified = conf.ConfigModified
|
||||
s.conf.DBFilePath = filepath.Join(conf.WorkDir, dbFilename)
|
||||
|
||||
if !webHandlersRegistered && s.conf.HTTPRegister != nil {
|
||||
if runtime.GOOS == "windows" {
|
||||
@@ -110,7 +162,7 @@ func Create(config ServerConfig) *Server {
|
||||
}
|
||||
|
||||
var err4, err6 error
|
||||
v4conf := config.Conf4
|
||||
v4conf := conf.Conf4
|
||||
v4conf.Enabled = s.conf.Enabled
|
||||
if len(v4conf.RangeStart) == 0 {
|
||||
v4conf.Enabled = false
|
||||
@@ -119,7 +171,7 @@ func Create(config ServerConfig) *Server {
|
||||
v4conf.notify = s.onNotify
|
||||
s.srv4, err4 = v4Create(v4conf)
|
||||
|
||||
v6conf := config.Conf6
|
||||
v6conf := conf.Conf6
|
||||
v6conf.Enabled = s.conf.Enabled
|
||||
if len(v6conf.RangeStart) == 0 {
|
||||
v6conf.Enabled = false
|
||||
@@ -137,6 +189,9 @@ func Create(config ServerConfig) *Server {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.conf.Conf4 = conf.Conf4
|
||||
s.conf.Conf6 = conf.Conf6
|
||||
|
||||
if s.conf.Enabled && !v4conf.Enabled && !v6conf.Enabled {
|
||||
log.Error("Can't enable DHCP server because neither DHCPv4 nor DHCPv6 servers are configured")
|
||||
return nil
|
||||
@@ -210,14 +265,10 @@ const (
|
||||
LeasesAll = LeasesDynamic | LeasesStatic
|
||||
)
|
||||
|
||||
// Leases returns the list of current DHCP leases (thread-safe)
|
||||
func (s *Server) Leases(flags int) []Lease {
|
||||
result := s.srv4.GetLeases(flags)
|
||||
|
||||
v6leases := s.srv6.GetLeases(flags)
|
||||
result = append(result, v6leases...)
|
||||
|
||||
return result
|
||||
// Leases returns the list of active IPv4 and IPv6 DHCP leases. It's safe for
|
||||
// concurrent use.
|
||||
func (s *Server) Leases(flags int) (leases []Lease) {
|
||||
return append(s.srv4.GetLeases(flags), s.srv6.GetLeases(flags)...)
|
||||
}
|
||||
|
||||
// FindMACbyIP - find a MAC address by IP address in the currently active DHCP leases
|
||||
@@ -255,17 +306,22 @@ func parseOptionString(s string) (uint8, []byte) {
|
||||
if err != nil {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
case "ip":
|
||||
ip := net.ParseIP(sval)
|
||||
if ip == nil {
|
||||
return 0, nil
|
||||
}
|
||||
val = ip
|
||||
if ip.To4() != nil {
|
||||
val = ip.To4()
|
||||
}
|
||||
|
||||
// Most DHCP options require IPv4, so do not put the 16-byte
|
||||
// version if we can. Otherwise, the clients will receive weird
|
||||
// data that looks like four IPv4 addresses.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/2688.
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
val = ip4
|
||||
} else {
|
||||
val = ip
|
||||
}
|
||||
default:
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
@@ -3,128 +3,188 @@
|
||||
package dhcpd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/testutil"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
testutil.DiscardLogOutput(m)
|
||||
aghtest.DiscardLogOutput(m)
|
||||
}
|
||||
|
||||
func testNotify(flags uint32) {
|
||||
}
|
||||
|
||||
// Leases database store/load
|
||||
// Leases database store/load.
|
||||
func TestDB(t *testing.T) {
|
||||
var err error
|
||||
s := Server{}
|
||||
s.conf.DBFilePath = dbFilename
|
||||
|
||||
conf := V4ServerConf{
|
||||
Enabled: true,
|
||||
RangeStart: "192.168.10.100",
|
||||
RangeEnd: "192.168.10.200",
|
||||
GatewayIP: "192.168.10.1",
|
||||
SubnetMask: "255.255.255.0",
|
||||
notify: testNotify,
|
||||
s := Server{
|
||||
conf: ServerConfig{
|
||||
DBFilePath: dbFilename,
|
||||
},
|
||||
}
|
||||
s.srv4, err = v4Create(conf)
|
||||
assert.True(t, err == nil)
|
||||
|
||||
s.srv4, err = v4Create(V4ServerConf{
|
||||
Enabled: true,
|
||||
RangeStart: net.IP{192, 168, 10, 100},
|
||||
RangeEnd: net.IP{192, 168, 10, 200},
|
||||
GatewayIP: net.IP{192, 168, 10, 1},
|
||||
SubnetMask: net.IP{255, 255, 255, 0},
|
||||
notify: testNotify,
|
||||
})
|
||||
require.Nil(t, err)
|
||||
|
||||
s.srv6, err = v6Create(V6ServerConf{})
|
||||
assert.True(t, err == nil)
|
||||
require.Nil(t, err)
|
||||
|
||||
l := Lease{}
|
||||
l.IP = net.ParseIP("192.168.10.100").To4()
|
||||
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
|
||||
exp1 := time.Now().Add(time.Hour)
|
||||
l.Expiry = exp1
|
||||
s.srv4.(*v4Server).addLease(&l)
|
||||
leases := []Lease{{
|
||||
IP: net.IP{192, 168, 10, 100},
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
Expiry: time.Now().Add(time.Hour),
|
||||
}, {
|
||||
IP: net.IP{192, 168, 10, 101},
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xBB},
|
||||
}}
|
||||
|
||||
l2 := Lease{}
|
||||
l2.IP = net.ParseIP("192.168.10.101").To4()
|
||||
l2.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:bb")
|
||||
s.srv4.AddStaticLease(l2)
|
||||
srv4, ok := s.srv4.(*v4Server)
|
||||
require.True(t, ok)
|
||||
|
||||
srv4.addLease(&leases[0])
|
||||
require.Nil(t, s.srv4.AddStaticLease(leases[1]))
|
||||
|
||||
_ = os.Remove("leases.db")
|
||||
s.dbStore()
|
||||
t.Cleanup(func() {
|
||||
assert.Nil(t, os.Remove(dbFilename))
|
||||
})
|
||||
s.srv4.ResetLeases(nil)
|
||||
|
||||
s.dbLoad()
|
||||
|
||||
ll := s.srv4.GetLeases(LeasesAll)
|
||||
require.Len(t, ll, len(leases))
|
||||
|
||||
assert.Equal(t, "aa:aa:aa:aa:aa:bb", ll[0].HWAddr.String())
|
||||
assert.Equal(t, "192.168.10.101", ll[0].IP.String())
|
||||
assert.Equal(t, int64(leaseExpireStatic), ll[0].Expiry.Unix())
|
||||
assert.Equal(t, leases[1].HWAddr, ll[0].HWAddr)
|
||||
assert.Equal(t, leases[1].IP, ll[0].IP)
|
||||
assert.EqualValues(t, leaseExpireStatic, ll[0].Expiry.Unix())
|
||||
|
||||
assert.Equal(t, "aa:aa:aa:aa:aa:aa", ll[1].HWAddr.String())
|
||||
assert.Equal(t, "192.168.10.100", ll[1].IP.String())
|
||||
assert.Equal(t, exp1.Unix(), ll[1].Expiry.Unix())
|
||||
|
||||
_ = os.Remove("leases.db")
|
||||
assert.Equal(t, leases[0].HWAddr, ll[1].HWAddr)
|
||||
assert.Equal(t, leases[0].IP, ll[1].IP)
|
||||
assert.Equal(t, leases[0].Expiry.Unix(), ll[1].Expiry.Unix())
|
||||
}
|
||||
|
||||
func TestIsValidSubnetMask(t *testing.T) {
|
||||
assert.True(t, isValidSubnetMask([]byte{255, 255, 255, 0}))
|
||||
assert.True(t, isValidSubnetMask([]byte{255, 255, 254, 0}))
|
||||
assert.True(t, isValidSubnetMask([]byte{255, 255, 252, 0}))
|
||||
assert.True(t, !isValidSubnetMask([]byte{255, 255, 253, 0}))
|
||||
assert.True(t, !isValidSubnetMask([]byte{255, 255, 255, 1}))
|
||||
testCases := []struct {
|
||||
mask net.IP
|
||||
want bool
|
||||
}{{
|
||||
mask: net.IP{255, 255, 255, 0},
|
||||
want: true,
|
||||
}, {
|
||||
mask: net.IP{255, 255, 254, 0},
|
||||
want: true,
|
||||
}, {
|
||||
mask: net.IP{255, 255, 252, 0},
|
||||
want: true,
|
||||
}, {
|
||||
mask: net.IP{255, 255, 253, 0},
|
||||
}, {
|
||||
mask: net.IP{255, 255, 255, 1},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.mask.String(), func(t *testing.T) {
|
||||
assert.Equal(t, tc.want, isValidSubnetMask(tc.mask))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeLeases(t *testing.T) {
|
||||
dynLeases := []*Lease{}
|
||||
staticLeases := []*Lease{}
|
||||
dynLeases := []*Lease{{
|
||||
HWAddr: net.HardwareAddr{1, 2, 3, 4},
|
||||
}, {
|
||||
HWAddr: net.HardwareAddr{1, 2, 3, 5},
|
||||
}}
|
||||
|
||||
lease := &Lease{}
|
||||
lease.HWAddr = []byte{1, 2, 3, 4}
|
||||
dynLeases = append(dynLeases, lease)
|
||||
lease = new(Lease)
|
||||
lease.HWAddr = []byte{1, 2, 3, 5}
|
||||
dynLeases = append(dynLeases, lease)
|
||||
|
||||
lease = new(Lease)
|
||||
lease.HWAddr = []byte{1, 2, 3, 4}
|
||||
lease.IP = []byte{0, 2, 3, 4}
|
||||
staticLeases = append(staticLeases, lease)
|
||||
lease = new(Lease)
|
||||
lease.HWAddr = []byte{2, 2, 3, 4}
|
||||
staticLeases = append(staticLeases, lease)
|
||||
staticLeases := []*Lease{{
|
||||
HWAddr: net.HardwareAddr{1, 2, 3, 4},
|
||||
IP: net.IP{0, 2, 3, 4},
|
||||
}, {
|
||||
HWAddr: net.HardwareAddr{2, 2, 3, 4},
|
||||
}}
|
||||
|
||||
leases := normalizeLeases(staticLeases, dynLeases)
|
||||
require.Len(t, leases, 3)
|
||||
|
||||
assert.True(t, len(leases) == 3)
|
||||
assert.True(t, bytes.Equal(leases[0].HWAddr, []byte{1, 2, 3, 4}))
|
||||
assert.True(t, bytes.Equal(leases[0].IP, []byte{0, 2, 3, 4}))
|
||||
assert.True(t, bytes.Equal(leases[1].HWAddr, []byte{2, 2, 3, 4}))
|
||||
assert.True(t, bytes.Equal(leases[2].HWAddr, []byte{1, 2, 3, 5}))
|
||||
assert.Equal(t, leases[0].HWAddr, dynLeases[0].HWAddr)
|
||||
assert.Equal(t, leases[0].IP, staticLeases[0].IP)
|
||||
assert.Equal(t, leases[1].HWAddr, staticLeases[1].HWAddr)
|
||||
assert.Equal(t, leases[2].HWAddr, dynLeases[1].HWAddr)
|
||||
}
|
||||
|
||||
func TestOptions(t *testing.T) {
|
||||
code, val := parseOptionString(" 12 hex abcdef ")
|
||||
assert.Equal(t, uint8(12), code)
|
||||
assert.True(t, bytes.Equal([]byte{0xab, 0xcd, 0xef}, val))
|
||||
testCases := []struct {
|
||||
name string
|
||||
optStr string
|
||||
wantVal []byte
|
||||
wantCode uint8
|
||||
}{{
|
||||
name: "success_hex",
|
||||
optStr: "12 hex abcdef",
|
||||
wantVal: []byte{0xab, 0xcd, 0xef},
|
||||
wantCode: 12,
|
||||
}, {
|
||||
name: "bad_hex",
|
||||
optStr: "12 hex abcdefx",
|
||||
wantVal: nil,
|
||||
wantCode: 0,
|
||||
}, {
|
||||
name: "success_ip",
|
||||
optStr: "123 ip 1.2.3.4",
|
||||
wantVal: net.IP{1, 2, 3, 4},
|
||||
wantCode: 123,
|
||||
}, {
|
||||
name: "success_ipv6",
|
||||
optStr: "123 ip ::1234",
|
||||
wantVal: net.IP{
|
||||
0, 0, 0, 0,
|
||||
0, 0, 0, 0,
|
||||
0, 0, 0, 0,
|
||||
0, 0, 0x12, 0x34,
|
||||
},
|
||||
wantCode: 123,
|
||||
}, {
|
||||
name: "bad_code",
|
||||
optStr: "256 ip 1.1.1.1",
|
||||
wantVal: nil,
|
||||
wantCode: 0,
|
||||
}, {
|
||||
name: "negative_code",
|
||||
optStr: "-1 ip 1.1.1.1",
|
||||
wantVal: nil,
|
||||
wantCode: 0,
|
||||
}, {
|
||||
name: "bad_ip",
|
||||
optStr: "12 ip 1.1.1.1x",
|
||||
wantVal: nil,
|
||||
wantCode: 0,
|
||||
}, {
|
||||
name: "bad_mode",
|
||||
wantVal: nil,
|
||||
optStr: "12 x 1.1.1.1",
|
||||
wantCode: 0,
|
||||
}}
|
||||
|
||||
code, _ = parseOptionString(" 12 hex abcdef1 ")
|
||||
assert.Equal(t, uint8(0), code)
|
||||
|
||||
code, val = parseOptionString("123 ip 1.2.3.4")
|
||||
assert.Equal(t, uint8(123), code)
|
||||
assert.Equal(t, "1.2.3.4", net.IP(string(val)).String())
|
||||
|
||||
code, _ = parseOptionString("256 ip 1.1.1.1")
|
||||
assert.Equal(t, uint8(0), code)
|
||||
code, _ = parseOptionString("-1 ip 1.1.1.1")
|
||||
assert.Equal(t, uint8(0), code)
|
||||
code, _ = parseOptionString("12 ip 1.1.1.1x")
|
||||
assert.Equal(t, uint8(0), code)
|
||||
code, _ = parseOptionString("12 x 1.1.1.1")
|
||||
assert.Equal(t, uint8(0), code)
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
code, val := parseOptionString(tc.optStr)
|
||||
require.Equal(t, tc.wantCode, code)
|
||||
if tc.wantVal != nil {
|
||||
assert.Equal(t, tc.wantVal, val)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,15 +14,17 @@ func isTimeout(err error) bool {
|
||||
return operr.Timeout()
|
||||
}
|
||||
|
||||
func parseIPv4(text string) (net.IP, error) {
|
||||
result := net.ParseIP(text)
|
||||
if result == nil {
|
||||
return nil, fmt.Errorf("%s is not an IP address", text)
|
||||
func tryTo4(ip net.IP) (ip4 net.IP, err error) {
|
||||
if ip == nil {
|
||||
return nil, fmt.Errorf("%v is not an IP address", ip)
|
||||
}
|
||||
if result.To4() == nil {
|
||||
return nil, fmt.Errorf("%s is not an IPv4 address", text)
|
||||
|
||||
ip4 = ip.To4()
|
||||
if ip4 == nil {
|
||||
return nil, fmt.Errorf("%v is not an IPv4 address", ip)
|
||||
}
|
||||
return result.To4(), nil
|
||||
|
||||
return ip4, nil
|
||||
}
|
||||
|
||||
// Return TRUE if subnet mask is correct (e.g. 255.255.255.0)
|
||||
|
||||
@@ -2,17 +2,16 @@ package dhcpd
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/sysutil"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/util"
|
||||
"github.com/AdguardTeam/golibs/jsonutil"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
@@ -22,44 +21,19 @@ func httpError(r *http.Request, w http.ResponseWriter, code int, format string,
|
||||
http.Error(w, text, code)
|
||||
}
|
||||
|
||||
// []Lease -> JSON
|
||||
func convertLeases(inputLeases []Lease, includeExpires bool) []map[string]string {
|
||||
leases := []map[string]string{}
|
||||
for _, l := range inputLeases {
|
||||
lease := map[string]string{
|
||||
"mac": l.HWAddr.String(),
|
||||
"ip": l.IP.String(),
|
||||
"hostname": l.Hostname,
|
||||
}
|
||||
|
||||
if includeExpires {
|
||||
lease["expires"] = l.Expiry.Format(time.RFC3339)
|
||||
}
|
||||
|
||||
leases = append(leases, lease)
|
||||
}
|
||||
return leases
|
||||
}
|
||||
|
||||
type v4ServerConfJSON struct {
|
||||
GatewayIP string `json:"gateway_ip"`
|
||||
SubnetMask string `json:"subnet_mask"`
|
||||
RangeStart string `json:"range_start"`
|
||||
RangeEnd string `json:"range_end"`
|
||||
GatewayIP net.IP `json:"gateway_ip"`
|
||||
SubnetMask net.IP `json:"subnet_mask"`
|
||||
RangeStart net.IP `json:"range_start"`
|
||||
RangeEnd net.IP `json:"range_end"`
|
||||
LeaseDuration uint32 `json:"lease_duration"`
|
||||
}
|
||||
|
||||
func v4ServerConfToJSON(c V4ServerConf) v4ServerConfJSON {
|
||||
return v4ServerConfJSON{
|
||||
GatewayIP: c.GatewayIP,
|
||||
SubnetMask: c.SubnetMask,
|
||||
RangeStart: c.RangeStart,
|
||||
RangeEnd: c.RangeEnd,
|
||||
LeaseDuration: c.LeaseDuration,
|
||||
func v4JSONToServerConf(j *v4ServerConfJSON) V4ServerConf {
|
||||
if j == nil {
|
||||
return V4ServerConf{}
|
||||
}
|
||||
}
|
||||
|
||||
func v4JSONToServerConf(j v4ServerConfJSON) V4ServerConf {
|
||||
return V4ServerConf{
|
||||
GatewayIP: j.GatewayIP,
|
||||
SubnetMask: j.SubnetMask,
|
||||
@@ -70,43 +44,45 @@ func v4JSONToServerConf(j v4ServerConfJSON) V4ServerConf {
|
||||
}
|
||||
|
||||
type v6ServerConfJSON struct {
|
||||
RangeStart string `json:"range_start"`
|
||||
RangeStart net.IP `json:"range_start"`
|
||||
LeaseDuration uint32 `json:"lease_duration"`
|
||||
}
|
||||
|
||||
func v6ServerConfToJSON(c V6ServerConf) v6ServerConfJSON {
|
||||
return v6ServerConfJSON{
|
||||
RangeStart: c.RangeStart,
|
||||
LeaseDuration: c.LeaseDuration,
|
||||
func v6JSONToServerConf(j *v6ServerConfJSON) V6ServerConf {
|
||||
if j == nil {
|
||||
return V6ServerConf{}
|
||||
}
|
||||
}
|
||||
|
||||
func v6JSONToServerConf(j v6ServerConfJSON) V6ServerConf {
|
||||
return V6ServerConf{
|
||||
RangeStart: j.RangeStart,
|
||||
LeaseDuration: j.LeaseDuration,
|
||||
}
|
||||
}
|
||||
|
||||
// dhcpStatusResponse is the response for /control/dhcp/status endpoint.
|
||||
type dhcpStatusResponse struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
IfaceName string `json:"interface_name"`
|
||||
V4 V4ServerConf `json:"v4"`
|
||||
V6 V6ServerConf `json:"v6"`
|
||||
Leases []Lease `json:"leases"`
|
||||
StaticLeases []Lease `json:"static_leases"`
|
||||
}
|
||||
|
||||
func (s *Server) handleDHCPStatus(w http.ResponseWriter, r *http.Request) {
|
||||
leases := convertLeases(s.Leases(LeasesDynamic), true)
|
||||
staticLeases := convertLeases(s.Leases(LeasesStatic), false)
|
||||
|
||||
v4conf := V4ServerConf{}
|
||||
s.srv4.WriteDiskConfig4(&v4conf)
|
||||
|
||||
v6conf := V6ServerConf{}
|
||||
s.srv6.WriteDiskConfig6(&v6conf)
|
||||
|
||||
status := map[string]interface{}{
|
||||
"enabled": s.conf.Enabled,
|
||||
"interface_name": s.conf.InterfaceName,
|
||||
"v4": v4ServerConfToJSON(v4conf),
|
||||
"v6": v6ServerConfToJSON(v6conf),
|
||||
"leases": leases,
|
||||
"static_leases": staticLeases,
|
||||
status := &dhcpStatusResponse{
|
||||
Enabled: s.conf.Enabled,
|
||||
IfaceName: s.conf.InterfaceName,
|
||||
V4: V4ServerConf{},
|
||||
V6: V6ServerConf{},
|
||||
}
|
||||
|
||||
s.srv4.WriteDiskConfig4(&status.V4)
|
||||
s.srv6.WriteDiskConfig6(&status.V6)
|
||||
|
||||
status.Leases = s.Leases(LeasesDynamic)
|
||||
status.StaticLeases = s.Leases(LeasesStatic)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err := json.NewEncoder(w).Encode(status)
|
||||
if err != nil {
|
||||
@@ -115,27 +91,72 @@ func (s *Server) handleDHCPStatus(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
type staticLeaseJSON struct {
|
||||
HWAddr string `json:"mac"`
|
||||
IP string `json:"ip"`
|
||||
Hostname string `json:"hostname"`
|
||||
func (s *Server) enableDHCP(ifaceName string) (code int, err error) {
|
||||
var hasStaticIP bool
|
||||
hasStaticIP, err = sysutil.IfaceHasStaticIP(ifaceName)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrPermission) {
|
||||
// ErrPermission may happen here on Linux systems where
|
||||
// AdGuard Home is installed using Snap. That doesn't
|
||||
// necessarily mean that the machine doesn't have
|
||||
// a static IP, so we can assume that it has and go on.
|
||||
// If the machine doesn't, we'll get an error later.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/2667.
|
||||
//
|
||||
// TODO(a.garipov): I was thinking about moving this
|
||||
// into IfaceHasStaticIP, but then we wouldn't be able
|
||||
// to log it. Think about it more.
|
||||
log.Info("error while checking static ip: %s; "+
|
||||
"assuming machine has static ip and going on", err)
|
||||
hasStaticIP = true
|
||||
} else if errors.Is(err, sysutil.ErrNoStaticIPInfo) {
|
||||
// Couldn't obtain a definitive answer. Assume static
|
||||
// IP an go on.
|
||||
log.Info("can't check for static ip; " +
|
||||
"assuming machine has static ip and going on")
|
||||
hasStaticIP = true
|
||||
} else {
|
||||
err = fmt.Errorf("checking static ip: %w", err)
|
||||
|
||||
return http.StatusInternalServerError, err
|
||||
}
|
||||
}
|
||||
|
||||
if !hasStaticIP {
|
||||
err = sysutil.IfaceSetStaticIP(ifaceName)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("setting static ip: %w", err)
|
||||
|
||||
return http.StatusInternalServerError, err
|
||||
}
|
||||
}
|
||||
|
||||
err = s.Start()
|
||||
if err != nil {
|
||||
return http.StatusBadRequest, fmt.Errorf("starting dhcp server: %w", err)
|
||||
}
|
||||
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
type dhcpServerConfigJSON struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
InterfaceName string `json:"interface_name"`
|
||||
V4 v4ServerConfJSON `json:"v4"`
|
||||
V6 v6ServerConfJSON `json:"v6"`
|
||||
V4 *v4ServerConfJSON `json:"v4"`
|
||||
V6 *v6ServerConfJSON `json:"v6"`
|
||||
InterfaceName string `json:"interface_name"`
|
||||
Enabled nullBool `json:"enabled"`
|
||||
}
|
||||
|
||||
func (s *Server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
newconfig := dhcpServerConfigJSON{}
|
||||
newconfig.Enabled = s.conf.Enabled
|
||||
newconfig.InterfaceName = s.conf.InterfaceName
|
||||
conf := dhcpServerConfigJSON{}
|
||||
conf.Enabled = boolToNullBool(s.conf.Enabled)
|
||||
conf.InterfaceName = s.conf.InterfaceName
|
||||
|
||||
js, err := jsonutil.DecodeObject(&newconfig, r.Body)
|
||||
err := json.NewDecoder(r.Body).Decode(&conf)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "Failed to parse new DHCP config json: %s", err)
|
||||
httpError(r, w, http.StatusBadRequest,
|
||||
"failed to parse new dhcp config json: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -144,80 +165,91 @@ func (s *Server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
v4Enabled := false
|
||||
v6Enabled := false
|
||||
|
||||
if js.Exists("v4") {
|
||||
v4conf := v4JSONToServerConf(newconfig.V4)
|
||||
v4conf.Enabled = newconfig.Enabled
|
||||
if len(v4conf.RangeStart) == 0 {
|
||||
v4conf.Enabled = false
|
||||
if conf.V4 != nil {
|
||||
v4Conf := v4JSONToServerConf(conf.V4)
|
||||
v4Conf.Enabled = conf.Enabled == nbTrue
|
||||
if len(v4Conf.RangeStart) == 0 {
|
||||
v4Conf.Enabled = false
|
||||
}
|
||||
v4Enabled = v4conf.Enabled
|
||||
v4conf.InterfaceName = newconfig.InterfaceName
|
||||
|
||||
v4Enabled = v4Conf.Enabled
|
||||
v4Conf.InterfaceName = conf.InterfaceName
|
||||
|
||||
c4 := V4ServerConf{}
|
||||
s.srv4.WriteDiskConfig4(&c4)
|
||||
v4conf.notify = c4.notify
|
||||
v4conf.ICMPTimeout = c4.ICMPTimeout
|
||||
v4Conf.notify = c4.notify
|
||||
v4Conf.ICMPTimeout = c4.ICMPTimeout
|
||||
|
||||
s4, err = v4Create(v4conf)
|
||||
s4, err = v4Create(v4Conf)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "Invalid DHCPv4 configuration: %s", err)
|
||||
httpError(r, w, http.StatusBadRequest,
|
||||
"invalid dhcpv4 configuration: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if js.Exists("v6") {
|
||||
v6conf := v6JSONToServerConf(newconfig.V6)
|
||||
v6conf.Enabled = newconfig.Enabled
|
||||
if len(v6conf.RangeStart) == 0 {
|
||||
v6conf.Enabled = false
|
||||
if conf.V6 != nil {
|
||||
v6Conf := v6JSONToServerConf(conf.V6)
|
||||
v6Conf.Enabled = conf.Enabled == nbTrue
|
||||
if len(v6Conf.RangeStart) == 0 {
|
||||
v6Conf.Enabled = false
|
||||
}
|
||||
v6Enabled = v6conf.Enabled
|
||||
v6conf.InterfaceName = newconfig.InterfaceName
|
||||
v6conf.notify = s.onNotify
|
||||
s6, err = v6Create(v6conf)
|
||||
|
||||
// Don't overwrite the RA/SLAAC settings from the config file.
|
||||
//
|
||||
// TODO(a.garipov): Perhaps include them into the request to
|
||||
// allow changing them from the HTTP API?
|
||||
v6Conf.RASLAACOnly = s.conf.Conf6.RASLAACOnly
|
||||
v6Conf.RAAllowSLAAC = s.conf.Conf6.RAAllowSLAAC
|
||||
|
||||
v6Enabled = v6Conf.Enabled
|
||||
v6Conf.InterfaceName = conf.InterfaceName
|
||||
v6Conf.notify = s.onNotify
|
||||
|
||||
s6, err = v6Create(v6Conf)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "Invalid DHCPv6 configuration: %s", err)
|
||||
httpError(r, w, http.StatusBadRequest,
|
||||
"invalid dhcpv6 configuration: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if newconfig.Enabled && !v4Enabled && !v6Enabled {
|
||||
httpError(r, w, http.StatusBadRequest, "DHCPv4 or DHCPv6 configuration must be complete")
|
||||
if conf.Enabled == nbTrue && !v4Enabled && !v6Enabled {
|
||||
httpError(r, w, http.StatusBadRequest,
|
||||
"dhcpv4 or dhcpv6 configuration must be complete")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
s.Stop()
|
||||
|
||||
if js.Exists("enabled") {
|
||||
s.conf.Enabled = newconfig.Enabled
|
||||
if conf.Enabled != nbNull {
|
||||
s.conf.Enabled = conf.Enabled == nbTrue
|
||||
}
|
||||
|
||||
if js.Exists("interface_name") {
|
||||
s.conf.InterfaceName = newconfig.InterfaceName
|
||||
if conf.InterfaceName != "" {
|
||||
s.conf.InterfaceName = conf.InterfaceName
|
||||
}
|
||||
|
||||
if s4 != nil {
|
||||
s.srv4 = s4
|
||||
}
|
||||
|
||||
if s6 != nil {
|
||||
s.srv6 = s6
|
||||
}
|
||||
|
||||
s.conf.ConfigModified()
|
||||
s.dbLoad()
|
||||
|
||||
if s.conf.Enabled {
|
||||
staticIP, err := sysutil.IfaceHasStaticIP(newconfig.InterfaceName)
|
||||
if !staticIP && err == nil {
|
||||
err = sysutil.IfaceSetStaticIP(newconfig.InterfaceName)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusInternalServerError, "Failed to configure static IP: %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
err = s.Start()
|
||||
var code int
|
||||
code, err = s.enableDHCP(conf.InterfaceName)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "Failed to start DHCP server: %s", err)
|
||||
httpError(r, w, code, "enabling dhcp: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -225,15 +257,15 @@ func (s *Server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
type netInterfaceJSON struct {
|
||||
Name string `json:"name"`
|
||||
GatewayIP string `json:"gateway_ip"`
|
||||
GatewayIP net.IP `json:"gateway_ip"`
|
||||
HardwareAddr string `json:"hardware_address"`
|
||||
Addrs4 []string `json:"ipv4_addresses"`
|
||||
Addrs6 []string `json:"ipv6_addresses"`
|
||||
Addrs4 []net.IP `json:"ipv4_addresses"`
|
||||
Addrs6 []net.IP `json:"ipv6_addresses"`
|
||||
Flags string `json:"flags"`
|
||||
}
|
||||
|
||||
func (s *Server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
|
||||
response := map[string]interface{}{}
|
||||
response := map[string]netInterfaceJSON{}
|
||||
|
||||
ifaces, err := util.GetValidNetInterfaces()
|
||||
if err != nil {
|
||||
@@ -277,9 +309,9 @@ func (s *Server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
|
||||
continue
|
||||
}
|
||||
if ipnet.IP.To4() != nil {
|
||||
jsonIface.Addrs4 = append(jsonIface.Addrs4, ipnet.IP.String())
|
||||
jsonIface.Addrs4 = append(jsonIface.Addrs4, ipnet.IP)
|
||||
} else {
|
||||
jsonIface.Addrs6 = append(jsonIface.Addrs6, ipnet.IP.String())
|
||||
jsonIface.Addrs6 = append(jsonIface.Addrs6, ipnet.IP)
|
||||
}
|
||||
}
|
||||
if len(jsonIface.Addrs4)+len(jsonIface.Addrs6) != 0 {
|
||||
@@ -295,6 +327,40 @@ func (s *Server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
// dhcpSearchOtherResult contains information about other DHCP server for
|
||||
// specific network interface.
|
||||
type dhcpSearchOtherResult struct {
|
||||
Found string `json:"found,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// dhcpStaticIPStatus contains information about static IP address for DHCP
|
||||
// server.
|
||||
type dhcpStaticIPStatus struct {
|
||||
Static string `json:"static"`
|
||||
IP string `json:"ip,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// dhcpSearchV4Result contains information about DHCPv4 server for specific
|
||||
// network interface.
|
||||
type dhcpSearchV4Result struct {
|
||||
OtherServer dhcpSearchOtherResult `json:"other_server"`
|
||||
StaticIP dhcpStaticIPStatus `json:"static_ip"`
|
||||
}
|
||||
|
||||
// dhcpSearchV6Result contains information about DHCPv6 server for specific
|
||||
// network interface.
|
||||
type dhcpSearchV6Result struct {
|
||||
OtherServer dhcpSearchOtherResult `json:"other_server"`
|
||||
}
|
||||
|
||||
// dhcpSearchResult is a response for /control/dhcp/find_active_dhcp endpoint.
|
||||
type dhcpSearchResult struct {
|
||||
V4 dhcpSearchV4Result `json:"v4"`
|
||||
V6 dhcpSearchV6Result `json:"v6"`
|
||||
}
|
||||
|
||||
// Perform the following tasks:
|
||||
// . Search for another DHCP server running
|
||||
// . Check if a static IP is configured for the network interface
|
||||
@@ -317,50 +383,42 @@ func (s *Server) handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Reque
|
||||
return
|
||||
}
|
||||
|
||||
result := dhcpSearchResult{
|
||||
V4: dhcpSearchV4Result{
|
||||
OtherServer: dhcpSearchOtherResult{},
|
||||
StaticIP: dhcpStaticIPStatus{},
|
||||
},
|
||||
V6: dhcpSearchV6Result{
|
||||
OtherServer: dhcpSearchOtherResult{},
|
||||
},
|
||||
}
|
||||
|
||||
found4, err4 := CheckIfOtherDHCPServersPresentV4(interfaceName)
|
||||
|
||||
staticIP := map[string]interface{}{}
|
||||
isStaticIP, err := sysutil.IfaceHasStaticIP(interfaceName)
|
||||
staticIPStatus := "yes"
|
||||
if err != nil {
|
||||
staticIPStatus = "error"
|
||||
staticIP["error"] = err.Error()
|
||||
result.V4.StaticIP.Static = "error"
|
||||
result.V4.StaticIP.Error = err.Error()
|
||||
} else if !isStaticIP {
|
||||
staticIPStatus = "no"
|
||||
staticIP["ip"] = util.GetSubnet(interfaceName)
|
||||
result.V4.StaticIP.Static = "no"
|
||||
result.V4.StaticIP.IP = util.GetSubnet(interfaceName).String()
|
||||
}
|
||||
staticIP["static"] = staticIPStatus
|
||||
|
||||
v4 := map[string]interface{}{}
|
||||
othSrv := map[string]interface{}{}
|
||||
foundVal := "no"
|
||||
if found4 {
|
||||
foundVal = "yes"
|
||||
result.V4.OtherServer.Found = "yes"
|
||||
} else if err4 != nil {
|
||||
foundVal = "error"
|
||||
othSrv["error"] = err4.Error()
|
||||
result.V4.OtherServer.Found = "error"
|
||||
result.V4.OtherServer.Error = err4.Error()
|
||||
}
|
||||
othSrv["found"] = foundVal
|
||||
v4["other_server"] = othSrv
|
||||
v4["static_ip"] = staticIP
|
||||
|
||||
found6, err6 := CheckIfOtherDHCPServersPresentV6(interfaceName)
|
||||
|
||||
v6 := map[string]interface{}{}
|
||||
othSrv = map[string]interface{}{}
|
||||
foundVal = "no"
|
||||
if found6 {
|
||||
foundVal = "yes"
|
||||
result.V6.OtherServer.Found = "yes"
|
||||
} else if err6 != nil {
|
||||
foundVal = "error"
|
||||
othSrv["error"] = err6.Error()
|
||||
result.V6.OtherServer.Found = "error"
|
||||
result.V6.OtherServer.Error = err6.Error()
|
||||
}
|
||||
othSrv["found"] = foundVal
|
||||
v6["other_server"] = othSrv
|
||||
|
||||
result := map[string]interface{}{}
|
||||
result["v4"] = v4
|
||||
result["v6"] = v6
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w).Encode(result)
|
||||
@@ -371,103 +429,75 @@ func (s *Server) handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Reque
|
||||
}
|
||||
|
||||
func (s *Server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request) {
|
||||
lj := staticLeaseJSON{}
|
||||
lj := Lease{}
|
||||
err := json.NewDecoder(r.Body).Decode(&lj)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ip := net.ParseIP(lj.IP)
|
||||
if ip != nil && ip.To4() == nil {
|
||||
mac, err := net.ParseMAC(lj.HWAddr)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "invalid MAC")
|
||||
return
|
||||
}
|
||||
if lj.IP == nil {
|
||||
httpError(r, w, http.StatusBadRequest, "invalid IP")
|
||||
|
||||
lease := Lease{
|
||||
IP: ip,
|
||||
HWAddr: mac,
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
err = s.srv6.AddStaticLease(lease)
|
||||
ip4 := lj.IP.To4()
|
||||
|
||||
if ip4 == nil {
|
||||
lj.IP = lj.IP.To16()
|
||||
|
||||
err = s.srv6.AddStaticLease(lj)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "%s", err)
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ip, _ = parseIPv4(lj.IP)
|
||||
if ip == nil {
|
||||
httpError(r, w, http.StatusBadRequest, "invalid IP")
|
||||
return
|
||||
}
|
||||
|
||||
mac, err := net.ParseMAC(lj.HWAddr)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "invalid MAC")
|
||||
return
|
||||
}
|
||||
|
||||
lease := Lease{
|
||||
IP: ip,
|
||||
HWAddr: mac,
|
||||
Hostname: lj.Hostname,
|
||||
}
|
||||
err = s.srv4.AddStaticLease(lease)
|
||||
lj.IP = ip4
|
||||
err = s.srv4.AddStaticLease(lj)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleDHCPRemoveStaticLease(w http.ResponseWriter, r *http.Request) {
|
||||
lj := staticLeaseJSON{}
|
||||
lj := Lease{}
|
||||
err := json.NewDecoder(r.Body).Decode(&lj)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ip := net.ParseIP(lj.IP)
|
||||
if ip != nil && ip.To4() == nil {
|
||||
mac, err := net.ParseMAC(lj.HWAddr)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "invalid MAC")
|
||||
return
|
||||
}
|
||||
if lj.IP == nil {
|
||||
httpError(r, w, http.StatusBadRequest, "invalid IP")
|
||||
|
||||
lease := Lease{
|
||||
IP: ip,
|
||||
HWAddr: mac,
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
err = s.srv6.RemoveStaticLease(lease)
|
||||
ip4 := lj.IP.To4()
|
||||
|
||||
if ip4 == nil {
|
||||
lj.IP = lj.IP.To16()
|
||||
|
||||
err = s.srv6.RemoveStaticLease(lj)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "%s", err)
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ip, _ = parseIPv4(lj.IP)
|
||||
if ip == nil {
|
||||
httpError(r, w, http.StatusBadRequest, "invalid IP")
|
||||
return
|
||||
}
|
||||
|
||||
mac, _ := net.ParseMAC(lj.HWAddr)
|
||||
|
||||
lease := Lease{
|
||||
IP: ip,
|
||||
HWAddr: mac,
|
||||
Hostname: lj.Hostname,
|
||||
}
|
||||
err = s.srv4.RemoveStaticLease(lease)
|
||||
lj.IP = ip4
|
||||
err = s.srv4.RemoveStaticLease(lj)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestServer_notImplemented(t *testing.T) {
|
||||
@@ -14,7 +15,7 @@ func TestServer_notImplemented(t *testing.T) {
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r, err := http.NewRequest(http.MethodGet, "/unsupported", nil)
|
||||
assert.Nil(t, err)
|
||||
require.Nil(t, err)
|
||||
|
||||
h(w, r)
|
||||
assert.Equal(t, http.StatusNotImplemented, w.Code)
|
||||
@@ -17,14 +17,14 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/testutil"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/hugelgupf/socketpair"
|
||||
"github.com/insomniacslk/dhcp/dhcpv4"
|
||||
"github.com/insomniacslk/dhcp/dhcpv4/server4"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
testutil.DiscardLogOutput(m)
|
||||
aghtest.DiscardLogOutput(m)
|
||||
}
|
||||
|
||||
type handler struct {
|
||||
|
||||
58
internal/dhcpd/nullbool.go
Normal file
58
internal/dhcpd/nullbool.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package dhcpd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// nullBool is a nullable boolean. Use these in JSON requests and responses
|
||||
// instead of pointers to bool.
|
||||
//
|
||||
// TODO(a.garipov): Inspect uses of *bool, move this type into some new package
|
||||
// if we need it somewhere else.
|
||||
type nullBool uint8
|
||||
|
||||
// nullBool values
|
||||
const (
|
||||
nbNull nullBool = iota
|
||||
nbTrue
|
||||
nbFalse
|
||||
)
|
||||
|
||||
// String implements the fmt.Stringer interface for nullBool.
|
||||
func (nb nullBool) String() (s string) {
|
||||
switch nb {
|
||||
case nbNull:
|
||||
return "null"
|
||||
case nbTrue:
|
||||
return "true"
|
||||
case nbFalse:
|
||||
return "false"
|
||||
}
|
||||
|
||||
return fmt.Sprintf("!invalid nullBool %d", uint8(nb))
|
||||
}
|
||||
|
||||
// boolToNullBool converts a bool into a nullBool.
|
||||
func boolToNullBool(cond bool) (nb nullBool) {
|
||||
if cond {
|
||||
return nbTrue
|
||||
}
|
||||
|
||||
return nbFalse
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements the json.Unmarshaler interface for *nullBool.
|
||||
func (nb *nullBool) UnmarshalJSON(b []byte) (err error) {
|
||||
if len(b) == 0 || bytes.Equal(b, []byte("null")) {
|
||||
*nb = nbNull
|
||||
} else if bytes.Equal(b, []byte("true")) {
|
||||
*nb = nbTrue
|
||||
} else if bytes.Equal(b, []byte("false")) {
|
||||
*nb = nbFalse
|
||||
} else {
|
||||
return fmt.Errorf("invalid nullBool value %q", b)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
69
internal/dhcpd/nullbool_test.go
Normal file
69
internal/dhcpd/nullbool_test.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package dhcpd
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNullBool_UnmarshalText(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
data []byte
|
||||
wantErrMsg string
|
||||
want nullBool
|
||||
}{{
|
||||
name: "empty",
|
||||
data: []byte{},
|
||||
wantErrMsg: "",
|
||||
want: nbNull,
|
||||
}, {
|
||||
name: "null",
|
||||
data: []byte("null"),
|
||||
wantErrMsg: "",
|
||||
want: nbNull,
|
||||
}, {
|
||||
name: "true",
|
||||
data: []byte("true"),
|
||||
wantErrMsg: "",
|
||||
want: nbTrue,
|
||||
}, {
|
||||
name: "false",
|
||||
data: []byte("false"),
|
||||
wantErrMsg: "",
|
||||
want: nbFalse,
|
||||
}, {
|
||||
name: "invalid",
|
||||
data: []byte("flase"),
|
||||
wantErrMsg: `invalid nullBool value "flase"`,
|
||||
want: nbNull,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var got nullBool
|
||||
err := got.UnmarshalJSON(tc.data)
|
||||
if tc.wantErrMsg == "" {
|
||||
assert.Nil(t, err)
|
||||
} else {
|
||||
require.NotNil(t, err)
|
||||
assert.Equal(t, tc.wantErrMsg, err.Error())
|
||||
}
|
||||
|
||||
assert.Equal(t, tc.want, got)
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("json", func(t *testing.T) {
|
||||
want := nbTrue
|
||||
var got struct {
|
||||
A nullBool
|
||||
}
|
||||
|
||||
err := json.Unmarshal([]byte(`{"A":true}`), &got)
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, want, got.A)
|
||||
})
|
||||
}
|
||||
@@ -13,8 +13,8 @@ import (
|
||||
)
|
||||
|
||||
type raCtx struct {
|
||||
raAllowSlaac bool // send RA packets without MO flags
|
||||
raSlaacOnly bool // send RA packets with MO flags
|
||||
raAllowSLAAC bool // send RA packets without MO flags
|
||||
raSLAACOnly bool // send RA packets with MO flags
|
||||
ipAddr net.IP // source IP address (link-local-unicast)
|
||||
dnsIPAddr net.IP // IP address for DNS Server option
|
||||
prefixIPAddr net.IP // IP address for Prefix option
|
||||
@@ -159,7 +159,7 @@ func createICMPv6RAPacket(params icmpv6RA) []byte {
|
||||
func (ra *raCtx) Init() error {
|
||||
ra.stop.Store(0)
|
||||
ra.conn = nil
|
||||
if !(ra.raAllowSlaac || ra.raSlaacOnly) {
|
||||
if !(ra.raAllowSLAAC || ra.raSLAACOnly) {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -167,8 +167,8 @@ func (ra *raCtx) Init() error {
|
||||
ra.ipAddr, ra.dnsIPAddr)
|
||||
|
||||
params := icmpv6RA{
|
||||
managedAddressConfiguration: !ra.raSlaacOnly,
|
||||
otherConfiguration: !ra.raSlaacOnly,
|
||||
managedAddressConfiguration: !ra.raSLAACOnly,
|
||||
otherConfiguration: !ra.raSLAACOnly,
|
||||
mtu: uint32(ra.iface.MTU),
|
||||
prefixLen: 64,
|
||||
recursiveDNSServer: ra.dnsIPAddr,
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package dhcpd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
@@ -9,7 +8,7 @@ import (
|
||||
)
|
||||
|
||||
func TestRA(t *testing.T) {
|
||||
ra := icmpv6RA{
|
||||
data := createICMPv6RAPacket(icmpv6RA{
|
||||
managedAddressConfiguration: false,
|
||||
otherConfiguration: true,
|
||||
mtu: 1500,
|
||||
@@ -17,8 +16,7 @@ func TestRA(t *testing.T) {
|
||||
prefixLen: 64,
|
||||
recursiveDNSServer: net.ParseIP("fe80::800:27ff:fe00:0"),
|
||||
sourceLinkLayerAddress: []byte{0x0a, 0x00, 0x27, 0x00, 0x00, 0x00},
|
||||
}
|
||||
data := createICMPv6RAPacket(ra)
|
||||
})
|
||||
dataCorrect := []byte{
|
||||
0x86, 0x00, 0x00, 0x00, 0x40, 0x40, 0x07, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x03, 0x04, 0x40, 0xc0, 0x00, 0x00, 0x0e, 0x10, 0x00, 0x00, 0x0e, 0x10, 0x00, 0x00, 0x00, 0x00,
|
||||
@@ -27,5 +25,5 @@ func TestRA(t *testing.T) {
|
||||
0x19, 0x03, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x10, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x08, 0x00, 0x27, 0xff, 0xfe, 0x00, 0x00, 0x00,
|
||||
}
|
||||
assert.True(t, bytes.Equal(data, dataCorrect))
|
||||
assert.Equal(t, dataCorrect, data)
|
||||
}
|
||||
|
||||
@@ -33,22 +33,22 @@ type DHCPServer interface {
|
||||
|
||||
// V4ServerConf - server configuration
|
||||
type V4ServerConf struct {
|
||||
Enabled bool `yaml:"-"`
|
||||
InterfaceName string `yaml:"-"`
|
||||
Enabled bool `yaml:"-" json:"-"`
|
||||
InterfaceName string `yaml:"-" json:"-"`
|
||||
|
||||
GatewayIP string `yaml:"gateway_ip"`
|
||||
SubnetMask string `yaml:"subnet_mask"`
|
||||
GatewayIP net.IP `yaml:"gateway_ip" json:"gateway_ip"`
|
||||
SubnetMask net.IP `yaml:"subnet_mask" json:"subnet_mask"`
|
||||
|
||||
// The first & the last IP address for dynamic leases
|
||||
// Bytes [0..2] of the last allowed IP address must match the first IP
|
||||
RangeStart string `yaml:"range_start"`
|
||||
RangeEnd string `yaml:"range_end"`
|
||||
RangeStart net.IP `yaml:"range_start" json:"range_start"`
|
||||
RangeEnd net.IP `yaml:"range_end" json:"range_end"`
|
||||
|
||||
LeaseDuration uint32 `yaml:"lease_duration"` // in seconds
|
||||
LeaseDuration uint32 `yaml:"lease_duration" json:"lease_duration"` // in seconds
|
||||
|
||||
// IP conflict detector: time (ms) to wait for ICMP reply
|
||||
// 0: disable
|
||||
ICMPTimeout uint32 `yaml:"icmp_timeout_msec"`
|
||||
ICMPTimeout uint32 `yaml:"icmp_timeout_msec" json:"-"`
|
||||
|
||||
// Custom Options.
|
||||
//
|
||||
@@ -58,7 +58,7 @@ type V4ServerConf struct {
|
||||
//
|
||||
// Option with IP data (only 1 IP is supported):
|
||||
// DEC_CODE ip IP_ADDR
|
||||
Options []string `yaml:"options"`
|
||||
Options []string `yaml:"options" json:"-"`
|
||||
|
||||
ipStart net.IP // starting IP address for dynamic leases
|
||||
ipEnd net.IP // ending IP address for dynamic leases
|
||||
@@ -74,17 +74,17 @@ type V4ServerConf struct {
|
||||
|
||||
// V6ServerConf - server configuration
|
||||
type V6ServerConf struct {
|
||||
Enabled bool `yaml:"-"`
|
||||
InterfaceName string `yaml:"-"`
|
||||
Enabled bool `yaml:"-" json:"-"`
|
||||
InterfaceName string `yaml:"-" json:"-"`
|
||||
|
||||
// The first IP address for dynamic leases
|
||||
// The last allowed IP address ends with 0xff byte
|
||||
RangeStart string `yaml:"range_start"`
|
||||
RangeStart net.IP `yaml:"range_start" json:"range_start"`
|
||||
|
||||
LeaseDuration uint32 `yaml:"lease_duration"` // in seconds
|
||||
LeaseDuration uint32 `yaml:"lease_duration" json:"lease_duration"` // in seconds
|
||||
|
||||
RaSlaacOnly bool `yaml:"ra_slaac_only"` // send ICMPv6.RA packets without MO flags
|
||||
RaAllowSlaac bool `yaml:"ra_allow_slaac"` // send ICMPv6.RA packets with MO flags
|
||||
RASLAACOnly bool `yaml:"ra_slaac_only" json:"-"` // send ICMPv6.RA packets without MO flags
|
||||
RAAllowSLAAC bool `yaml:"ra_allow_slaac" json:"-"` // send ICMPv6.RA packets with MO flags
|
||||
|
||||
ipStart net.IP // starting IP address for dynamic leases
|
||||
leaseTime time.Duration // the time during which a dynamic lease is considered valid
|
||||
|
||||
@@ -23,7 +23,8 @@ type v4Server struct {
|
||||
srv *server4.Server
|
||||
leasesLock sync.Mutex
|
||||
leases []*Lease
|
||||
ipAddrs [256]byte
|
||||
// TODO(e.burkov): This field type should be a normal bitmap.
|
||||
ipAddrs [256]byte
|
||||
|
||||
conf V4ServerConf
|
||||
}
|
||||
@@ -77,7 +78,10 @@ func (s *v4Server) blacklisted(l *Lease) bool {
|
||||
|
||||
// GetLeases returns the list of current DHCP leases (thread-safe)
|
||||
func (s *v4Server) GetLeases(flags int) []Lease {
|
||||
var result []Lease
|
||||
// The function shouldn't return nil value because zero-length slice
|
||||
// behaves differently in cases like marshalling. Our front-end also
|
||||
// requires non-nil value in the response.
|
||||
result := []Lease{}
|
||||
now := time.Now().Unix()
|
||||
|
||||
s.leasesLock.Lock()
|
||||
@@ -589,7 +593,7 @@ func (s *v4Server) Start() error {
|
||||
s.conf.dnsIPAddrs = dnsIPAddrs
|
||||
|
||||
laddr := &net.UDPAddr{
|
||||
IP: net.ParseIP("0.0.0.0"),
|
||||
IP: net.IP{0, 0, 0, 0},
|
||||
Port: dhcpv4.ServerPort,
|
||||
}
|
||||
s.srv, err = server4.NewServer(iface.Name, laddr, s.packetHandler, server4.WithDebugLogger())
|
||||
@@ -632,19 +636,18 @@ func v4Create(conf V4ServerConf) (DHCPServer, error) {
|
||||
}
|
||||
|
||||
var err error
|
||||
s.conf.routerIP, err = parseIPv4(s.conf.GatewayIP)
|
||||
s.conf.routerIP, err = tryTo4(s.conf.GatewayIP)
|
||||
if err != nil {
|
||||
return s, fmt.Errorf("dhcpv4: %w", err)
|
||||
}
|
||||
|
||||
subnet, err := parseIPv4(s.conf.SubnetMask)
|
||||
if err != nil || !isValidSubnetMask(subnet) {
|
||||
return s, fmt.Errorf("dhcpv4: invalid subnet mask: %s", s.conf.SubnetMask)
|
||||
if s.conf.SubnetMask == nil {
|
||||
return s, fmt.Errorf("dhcpv4: invalid subnet mask: %v", s.conf.SubnetMask)
|
||||
}
|
||||
s.conf.subnetMask = make([]byte, 4)
|
||||
copy(s.conf.subnetMask, subnet)
|
||||
copy(s.conf.subnetMask, s.conf.SubnetMask.To4())
|
||||
|
||||
s.conf.ipStart, err = parseIPv4(conf.RangeStart)
|
||||
s.conf.ipStart, err = tryTo4(conf.RangeStart)
|
||||
if s.conf.ipStart == nil {
|
||||
return s, fmt.Errorf("dhcpv4: %w", err)
|
||||
}
|
||||
@@ -652,7 +655,7 @@ func v4Create(conf V4ServerConf) (DHCPServer, error) {
|
||||
return s, fmt.Errorf("dhcpv4: invalid range start IP")
|
||||
}
|
||||
|
||||
s.conf.ipEnd, err = parseIPv4(conf.RangeEnd)
|
||||
s.conf.ipEnd, err = tryTo4(conf.RangeEnd)
|
||||
if s.conf.ipEnd == nil {
|
||||
return s, fmt.Errorf("dhcpv4: %w", err)
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type fakeIface struct {
|
||||
@@ -79,8 +80,8 @@ func TestIfaceIPAddrs(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got, gotErr := ifaceIPAddrs(tc.iface, tc.ipv)
|
||||
require.True(t, errors.Is(gotErr, tc.wantErr))
|
||||
assert.Equal(t, tc.want, got)
|
||||
assert.True(t, errors.Is(gotErr, tc.wantErr))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -140,12 +141,8 @@ func TestIfaceDNSIPAddrs(t *testing.T) {
|
||||
want: nil,
|
||||
wantErr: errTest,
|
||||
}, {
|
||||
name: "ipv4_wait",
|
||||
iface: &waitingFakeIface{
|
||||
addrs: []net.Addr{addr4},
|
||||
err: nil,
|
||||
n: 1,
|
||||
},
|
||||
name: "ipv4_wait",
|
||||
iface: &waitingFakeIface{addrs: []net.Addr{addr4}, err: nil, n: 1},
|
||||
ipv: ipVersion4,
|
||||
want: []net.IP{ip4, ip4},
|
||||
wantErr: nil,
|
||||
@@ -168,12 +165,8 @@ func TestIfaceDNSIPAddrs(t *testing.T) {
|
||||
want: nil,
|
||||
wantErr: errTest,
|
||||
}, {
|
||||
name: "ipv6_wait",
|
||||
iface: &waitingFakeIface{
|
||||
addrs: []net.Addr{addr6},
|
||||
err: nil,
|
||||
n: 1,
|
||||
},
|
||||
name: "ipv6_wait",
|
||||
iface: &waitingFakeIface{addrs: []net.Addr{addr6}, err: nil, n: 1},
|
||||
ipv: ipVersion6,
|
||||
want: []net.IP{ip6, ip6},
|
||||
wantErr: nil,
|
||||
@@ -182,8 +175,8 @@ func TestIfaceDNSIPAddrs(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got, gotErr := ifaceDNSIPAddrs(tc.iface, tc.ipv, 2, 0)
|
||||
require.True(t, errors.Is(gotErr, tc.wantErr))
|
||||
assert.Equal(t, tc.want, got)
|
||||
assert.True(t, errors.Is(gotErr, tc.wantErr))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,231 +8,283 @@ import (
|
||||
|
||||
"github.com/insomniacslk/dhcp/dhcpv4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func notify4(flags uint32) {
|
||||
}
|
||||
|
||||
func TestV4StaticLeaseAddRemove(t *testing.T) {
|
||||
conf := V4ServerConf{
|
||||
func TestV4_AddRemove_static(t *testing.T) {
|
||||
s, err := v4Create(V4ServerConf{
|
||||
Enabled: true,
|
||||
RangeStart: "192.168.10.100",
|
||||
RangeEnd: "192.168.10.200",
|
||||
GatewayIP: "192.168.10.1",
|
||||
SubnetMask: "255.255.255.0",
|
||||
RangeStart: net.IP{192, 168, 10, 100},
|
||||
RangeEnd: net.IP{192, 168, 10, 200},
|
||||
GatewayIP: net.IP{192, 168, 10, 1},
|
||||
SubnetMask: net.IP{255, 255, 255, 0},
|
||||
notify: notify4,
|
||||
}
|
||||
s, err := v4Create(conf)
|
||||
assert.True(t, err == nil)
|
||||
})
|
||||
require.Nil(t, err)
|
||||
|
||||
ls := s.GetLeases(LeasesStatic)
|
||||
assert.Equal(t, 0, len(ls))
|
||||
assert.Empty(t, ls)
|
||||
|
||||
// add static lease
|
||||
l := Lease{}
|
||||
l.IP = net.ParseIP("192.168.10.150").To4()
|
||||
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
|
||||
assert.True(t, s.AddStaticLease(l) == nil)
|
||||
// Add static lease.
|
||||
l := Lease{
|
||||
IP: net.IP{192, 168, 10, 150},
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
}
|
||||
require.Nil(t, s.AddStaticLease(l))
|
||||
assert.NotNil(t, s.AddStaticLease(l))
|
||||
|
||||
// try to add the same static lease - fail
|
||||
assert.True(t, s.AddStaticLease(l) != nil)
|
||||
|
||||
// check
|
||||
ls = s.GetLeases(LeasesStatic)
|
||||
assert.Equal(t, 1, len(ls))
|
||||
assert.Equal(t, "192.168.10.150", ls[0].IP.String())
|
||||
assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
|
||||
assert.True(t, ls[0].Expiry.Unix() == leaseExpireStatic)
|
||||
require.Len(t, ls, 1)
|
||||
assert.True(t, l.IP.Equal(ls[0].IP))
|
||||
assert.Equal(t, l.HWAddr, ls[0].HWAddr)
|
||||
assert.EqualValues(t, leaseExpireStatic, ls[0].Expiry.Unix())
|
||||
|
||||
// try to remove static lease - fail
|
||||
l.IP = net.ParseIP("192.168.10.110").To4()
|
||||
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
|
||||
assert.True(t, s.RemoveStaticLease(l) != nil)
|
||||
// Try to remove static lease.
|
||||
assert.NotNil(t, s.RemoveStaticLease(Lease{
|
||||
IP: net.IP{192, 168, 10, 110},
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
}))
|
||||
|
||||
// remove static lease
|
||||
l.IP = net.ParseIP("192.168.10.150").To4()
|
||||
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
|
||||
assert.True(t, s.RemoveStaticLease(l) == nil)
|
||||
|
||||
// check
|
||||
// Remove static lease.
|
||||
require.Nil(t, s.RemoveStaticLease(l))
|
||||
ls = s.GetLeases(LeasesStatic)
|
||||
assert.Equal(t, 0, len(ls))
|
||||
assert.Empty(t, ls)
|
||||
}
|
||||
|
||||
func TestV4StaticLeaseAddReplaceDynamic(t *testing.T) {
|
||||
conf := V4ServerConf{
|
||||
func TestV4_AddReplace(t *testing.T) {
|
||||
sIface, err := v4Create(V4ServerConf{
|
||||
Enabled: true,
|
||||
RangeStart: "192.168.10.100",
|
||||
RangeEnd: "192.168.10.200",
|
||||
GatewayIP: "192.168.10.1",
|
||||
SubnetMask: "255.255.255.0",
|
||||
RangeStart: net.IP{192, 168, 10, 100},
|
||||
RangeEnd: net.IP{192, 168, 10, 200},
|
||||
GatewayIP: net.IP{192, 168, 10, 1},
|
||||
SubnetMask: net.IP{255, 255, 255, 0},
|
||||
notify: notify4,
|
||||
}
|
||||
sIface, err := v4Create(conf)
|
||||
s := sIface.(*v4Server)
|
||||
assert.True(t, err == nil)
|
||||
})
|
||||
require.Nil(t, err)
|
||||
|
||||
// add dynamic lease
|
||||
ld := Lease{}
|
||||
ld.IP = net.ParseIP("192.168.10.150").To4()
|
||||
ld.HWAddr, _ = net.ParseMAC("11:aa:aa:aa:aa:aa")
|
||||
s.addLease(&ld)
|
||||
s, ok := sIface.(*v4Server)
|
||||
require.True(t, ok)
|
||||
|
||||
// add dynamic lease
|
||||
{
|
||||
ld := Lease{}
|
||||
ld.IP = net.ParseIP("192.168.10.151").To4()
|
||||
ld.HWAddr, _ = net.ParseMAC("22:aa:aa:aa:aa:aa")
|
||||
s.addLease(&ld)
|
||||
dynLeases := []Lease{{
|
||||
IP: net.IP{192, 168, 10, 150},
|
||||
HWAddr: net.HardwareAddr{0x11, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
}, {
|
||||
IP: net.IP{192, 168, 10, 151},
|
||||
HWAddr: net.HardwareAddr{0x22, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
}}
|
||||
|
||||
for i := range dynLeases {
|
||||
s.addLease(&dynLeases[i])
|
||||
}
|
||||
|
||||
// add static lease with the same IP
|
||||
l := Lease{}
|
||||
l.IP = net.ParseIP("192.168.10.150").To4()
|
||||
l.HWAddr, _ = net.ParseMAC("33:aa:aa:aa:aa:aa")
|
||||
assert.True(t, s.AddStaticLease(l) == nil)
|
||||
stLeases := []Lease{{
|
||||
IP: net.IP{192, 168, 10, 150},
|
||||
HWAddr: net.HardwareAddr{0x33, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
}, {
|
||||
IP: net.IP{192, 168, 10, 152},
|
||||
HWAddr: net.HardwareAddr{0x22, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
}}
|
||||
|
||||
// add static lease with the same MAC
|
||||
l = Lease{}
|
||||
l.IP = net.ParseIP("192.168.10.152").To4()
|
||||
l.HWAddr, _ = net.ParseMAC("22:aa:aa:aa:aa:aa")
|
||||
assert.True(t, s.AddStaticLease(l) == nil)
|
||||
for _, l := range stLeases {
|
||||
require.Nil(t, s.AddStaticLease(l))
|
||||
}
|
||||
|
||||
// check
|
||||
ls := s.GetLeases(LeasesStatic)
|
||||
assert.Equal(t, 2, len(ls))
|
||||
require.Len(t, ls, 2)
|
||||
|
||||
assert.Equal(t, "192.168.10.150", ls[0].IP.String())
|
||||
assert.Equal(t, "33:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
|
||||
assert.True(t, ls[0].Expiry.Unix() == leaseExpireStatic)
|
||||
|
||||
assert.Equal(t, "192.168.10.152", ls[1].IP.String())
|
||||
assert.Equal(t, "22:aa:aa:aa:aa:aa", ls[1].HWAddr.String())
|
||||
assert.True(t, ls[1].Expiry.Unix() == leaseExpireStatic)
|
||||
for i, l := range ls {
|
||||
assert.True(t, stLeases[i].IP.Equal(l.IP))
|
||||
assert.Equal(t, stLeases[i].HWAddr, l.HWAddr)
|
||||
assert.EqualValues(t, leaseExpireStatic, l.Expiry.Unix())
|
||||
}
|
||||
}
|
||||
|
||||
func TestV4StaticLeaseGet(t *testing.T) {
|
||||
conf := V4ServerConf{
|
||||
func TestV4StaticLease_Get(t *testing.T) {
|
||||
var err error
|
||||
sIface, err := v4Create(V4ServerConf{
|
||||
Enabled: true,
|
||||
RangeStart: "192.168.10.100",
|
||||
RangeEnd: "192.168.10.200",
|
||||
GatewayIP: "192.168.10.1",
|
||||
SubnetMask: "255.255.255.0",
|
||||
RangeStart: net.IP{192, 168, 10, 100},
|
||||
RangeEnd: net.IP{192, 168, 10, 200},
|
||||
GatewayIP: net.IP{192, 168, 10, 1},
|
||||
SubnetMask: net.IP{255, 255, 255, 0},
|
||||
notify: notify4,
|
||||
})
|
||||
require.Nil(t, err)
|
||||
|
||||
s, ok := sIface.(*v4Server)
|
||||
require.True(t, ok)
|
||||
s.conf.dnsIPAddrs = []net.IP{{192, 168, 10, 1}}
|
||||
|
||||
l := Lease{
|
||||
IP: net.IP{192, 168, 10, 150},
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
}
|
||||
sIface, err := v4Create(conf)
|
||||
s := sIface.(*v4Server)
|
||||
assert.True(t, err == nil)
|
||||
s.conf.dnsIPAddrs = []net.IP{net.ParseIP("192.168.10.1").To4()}
|
||||
require.Nil(t, s.AddStaticLease(l))
|
||||
|
||||
l := Lease{}
|
||||
l.IP = net.ParseIP("192.168.10.150").To4()
|
||||
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
|
||||
assert.True(t, s.AddStaticLease(l) == nil)
|
||||
var req, resp *dhcpv4.DHCPv4
|
||||
mac := net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}
|
||||
|
||||
// "Discover"
|
||||
mac, _ := net.ParseMAC("aa:aa:aa:aa:aa:aa")
|
||||
req, _ := dhcpv4.NewDiscovery(mac)
|
||||
resp, _ := dhcpv4.NewReplyFromRequest(req)
|
||||
assert.Equal(t, 1, s.process(req, resp))
|
||||
t.Run("discover", func(t *testing.T) {
|
||||
var err error
|
||||
|
||||
// check "Offer"
|
||||
assert.Equal(t, dhcpv4.MessageTypeOffer, resp.MessageType())
|
||||
assert.Equal(t, "aa:aa:aa:aa:aa:aa", resp.ClientHWAddr.String())
|
||||
assert.Equal(t, "192.168.10.150", resp.YourIPAddr.String())
|
||||
assert.Equal(t, "192.168.10.1", resp.Router()[0].String())
|
||||
assert.Equal(t, "192.168.10.1", resp.ServerIdentifier().String())
|
||||
assert.Equal(t, "255.255.255.0", net.IP(resp.SubnetMask()).String())
|
||||
assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds())
|
||||
req, err = dhcpv4.NewDiscovery(mac)
|
||||
require.Nil(t, err)
|
||||
|
||||
// "Request"
|
||||
req, _ = dhcpv4.NewRequestFromOffer(resp)
|
||||
resp, _ = dhcpv4.NewReplyFromRequest(req)
|
||||
assert.Equal(t, 1, s.process(req, resp))
|
||||
resp, err = dhcpv4.NewReplyFromRequest(req)
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, 1, s.process(req, resp))
|
||||
})
|
||||
require.Nil(t, err)
|
||||
|
||||
// check "Ack"
|
||||
assert.Equal(t, dhcpv4.MessageTypeAck, resp.MessageType())
|
||||
assert.Equal(t, "aa:aa:aa:aa:aa:aa", resp.ClientHWAddr.String())
|
||||
assert.Equal(t, "192.168.10.150", resp.YourIPAddr.String())
|
||||
assert.Equal(t, "192.168.10.1", resp.Router()[0].String())
|
||||
assert.Equal(t, "192.168.10.1", resp.ServerIdentifier().String())
|
||||
assert.Equal(t, "255.255.255.0", net.IP(resp.SubnetMask()).String())
|
||||
assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds())
|
||||
t.Run("offer", func(t *testing.T) {
|
||||
assert.Equal(t, dhcpv4.MessageTypeOffer, resp.MessageType())
|
||||
assert.Equal(t, mac, resp.ClientHWAddr)
|
||||
assert.True(t, l.IP.Equal(resp.YourIPAddr))
|
||||
assert.True(t, s.conf.GatewayIP.Equal(resp.Router()[0]))
|
||||
assert.True(t, s.conf.GatewayIP.Equal(resp.ServerIdentifier()))
|
||||
assert.Equal(t, s.conf.subnetMask, resp.SubnetMask())
|
||||
assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds())
|
||||
})
|
||||
|
||||
t.Run("request", func(t *testing.T) {
|
||||
req, err = dhcpv4.NewRequestFromOffer(resp)
|
||||
require.Nil(t, err)
|
||||
|
||||
resp, err = dhcpv4.NewReplyFromRequest(req)
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, 1, s.process(req, resp))
|
||||
})
|
||||
require.Nil(t, err)
|
||||
|
||||
t.Run("ack", func(t *testing.T) {
|
||||
assert.Equal(t, dhcpv4.MessageTypeAck, resp.MessageType())
|
||||
assert.Equal(t, mac, resp.ClientHWAddr)
|
||||
assert.True(t, l.IP.Equal(resp.YourIPAddr))
|
||||
assert.True(t, s.conf.GatewayIP.Equal(resp.Router()[0]))
|
||||
assert.True(t, s.conf.GatewayIP.Equal(resp.ServerIdentifier()))
|
||||
assert.Equal(t, s.conf.subnetMask, resp.SubnetMask())
|
||||
assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds())
|
||||
})
|
||||
|
||||
dnsAddrs := resp.DNS()
|
||||
assert.Equal(t, 1, len(dnsAddrs))
|
||||
assert.Equal(t, "192.168.10.1", dnsAddrs[0].String())
|
||||
require.Len(t, dnsAddrs, 1)
|
||||
assert.True(t, s.conf.GatewayIP.Equal(dnsAddrs[0]))
|
||||
|
||||
// check lease
|
||||
ls := s.GetLeases(LeasesStatic)
|
||||
assert.Equal(t, 1, len(ls))
|
||||
assert.Equal(t, "192.168.10.150", ls[0].IP.String())
|
||||
assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
|
||||
t.Run("check_lease", func(t *testing.T) {
|
||||
ls := s.GetLeases(LeasesStatic)
|
||||
require.Len(t, ls, 1)
|
||||
assert.True(t, l.IP.Equal(ls[0].IP))
|
||||
assert.Equal(t, mac, ls[0].HWAddr)
|
||||
})
|
||||
}
|
||||
|
||||
func TestV4DynamicLeaseGet(t *testing.T) {
|
||||
conf := V4ServerConf{
|
||||
func TestV4DynamicLease_Get(t *testing.T) {
|
||||
var err error
|
||||
sIface, err := v4Create(V4ServerConf{
|
||||
Enabled: true,
|
||||
RangeStart: "192.168.10.100",
|
||||
RangeEnd: "192.168.10.200",
|
||||
GatewayIP: "192.168.10.1",
|
||||
SubnetMask: "255.255.255.0",
|
||||
RangeStart: net.IP{192, 168, 10, 100},
|
||||
RangeEnd: net.IP{192, 168, 10, 200},
|
||||
GatewayIP: net.IP{192, 168, 10, 1},
|
||||
SubnetMask: net.IP{255, 255, 255, 0},
|
||||
notify: notify4,
|
||||
Options: []string{
|
||||
"81 hex 303132",
|
||||
"82 ip 1.2.3.4",
|
||||
},
|
||||
}
|
||||
sIface, err := v4Create(conf)
|
||||
s := sIface.(*v4Server)
|
||||
assert.True(t, err == nil)
|
||||
s.conf.dnsIPAddrs = []net.IP{net.ParseIP("192.168.10.1").To4()}
|
||||
})
|
||||
require.Nil(t, err)
|
||||
|
||||
// "Discover"
|
||||
mac, _ := net.ParseMAC("aa:aa:aa:aa:aa:aa")
|
||||
req, _ := dhcpv4.NewDiscovery(mac)
|
||||
resp, _ := dhcpv4.NewReplyFromRequest(req)
|
||||
assert.Equal(t, 1, s.process(req, resp))
|
||||
s, ok := sIface.(*v4Server)
|
||||
require.True(t, ok)
|
||||
s.conf.dnsIPAddrs = []net.IP{{192, 168, 10, 1}}
|
||||
|
||||
// check "Offer"
|
||||
assert.Equal(t, dhcpv4.MessageTypeOffer, resp.MessageType())
|
||||
assert.Equal(t, "aa:aa:aa:aa:aa:aa", resp.ClientHWAddr.String())
|
||||
assert.Equal(t, "192.168.10.100", resp.YourIPAddr.String())
|
||||
assert.Equal(t, "192.168.10.1", resp.Router()[0].String())
|
||||
assert.Equal(t, "192.168.10.1", resp.ServerIdentifier().String())
|
||||
assert.Equal(t, "255.255.255.0", net.IP(resp.SubnetMask()).String())
|
||||
assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds())
|
||||
assert.Equal(t, []byte("012"), resp.Options[uint8(dhcpv4.OptionFQDN)])
|
||||
assert.Equal(t, "1.2.3.4", net.IP(resp.Options[uint8(dhcpv4.OptionRelayAgentInformation)]).String())
|
||||
var req, resp *dhcpv4.DHCPv4
|
||||
mac := net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}
|
||||
|
||||
// "Request"
|
||||
req, _ = dhcpv4.NewRequestFromOffer(resp)
|
||||
resp, _ = dhcpv4.NewReplyFromRequest(req)
|
||||
assert.Equal(t, 1, s.process(req, resp))
|
||||
t.Run("discover", func(t *testing.T) {
|
||||
req, err = dhcpv4.NewDiscovery(mac)
|
||||
require.Nil(t, err)
|
||||
|
||||
// check "Ack"
|
||||
assert.Equal(t, dhcpv4.MessageTypeAck, resp.MessageType())
|
||||
assert.Equal(t, "aa:aa:aa:aa:aa:aa", resp.ClientHWAddr.String())
|
||||
assert.Equal(t, "192.168.10.100", resp.YourIPAddr.String())
|
||||
assert.Equal(t, "192.168.10.1", resp.Router()[0].String())
|
||||
assert.Equal(t, "192.168.10.1", resp.ServerIdentifier().String())
|
||||
assert.Equal(t, "255.255.255.0", net.IP(resp.SubnetMask()).String())
|
||||
assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds())
|
||||
resp, err = dhcpv4.NewReplyFromRequest(req)
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, 1, s.process(req, resp))
|
||||
})
|
||||
require.Nil(t, err)
|
||||
|
||||
t.Run("offer", func(t *testing.T) {
|
||||
assert.Equal(t, dhcpv4.MessageTypeOffer, resp.MessageType())
|
||||
assert.Equal(t, mac, resp.ClientHWAddr)
|
||||
assert.True(t, s.conf.RangeStart.Equal(resp.YourIPAddr))
|
||||
assert.True(t, s.conf.GatewayIP.Equal(resp.Router()[0]))
|
||||
assert.True(t, s.conf.GatewayIP.Equal(resp.ServerIdentifier()))
|
||||
assert.Equal(t, s.conf.subnetMask, resp.SubnetMask())
|
||||
assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds())
|
||||
assert.Equal(t, []byte("012"), resp.Options[uint8(dhcpv4.OptionFQDN)])
|
||||
assert.True(t, net.IP{1, 2, 3, 4}.Equal(net.IP(resp.Options[uint8(dhcpv4.OptionRelayAgentInformation)])))
|
||||
})
|
||||
|
||||
t.Run("request", func(t *testing.T) {
|
||||
var err error
|
||||
|
||||
req, err = dhcpv4.NewRequestFromOffer(resp)
|
||||
require.Nil(t, err)
|
||||
|
||||
resp, err = dhcpv4.NewReplyFromRequest(req)
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, 1, s.process(req, resp))
|
||||
})
|
||||
require.Nil(t, err)
|
||||
|
||||
t.Run("ack", func(t *testing.T) {
|
||||
assert.Equal(t, dhcpv4.MessageTypeAck, resp.MessageType())
|
||||
assert.Equal(t, mac, resp.ClientHWAddr)
|
||||
assert.True(t, s.conf.RangeStart.Equal(resp.YourIPAddr))
|
||||
assert.True(t, s.conf.GatewayIP.Equal(resp.Router()[0]))
|
||||
assert.True(t, s.conf.GatewayIP.Equal(resp.ServerIdentifier()))
|
||||
assert.Equal(t, s.conf.subnetMask, resp.SubnetMask())
|
||||
assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds())
|
||||
})
|
||||
|
||||
dnsAddrs := resp.DNS()
|
||||
assert.Equal(t, 1, len(dnsAddrs))
|
||||
assert.Equal(t, "192.168.10.1", dnsAddrs[0].String())
|
||||
require.Len(t, dnsAddrs, 1)
|
||||
assert.True(t, net.IP{192, 168, 10, 1}.Equal(dnsAddrs[0]))
|
||||
|
||||
// check lease
|
||||
ls := s.GetLeases(LeasesDynamic)
|
||||
assert.Equal(t, 1, len(ls))
|
||||
assert.Equal(t, "192.168.10.100", ls[0].IP.String())
|
||||
assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
|
||||
|
||||
start := net.ParseIP("192.168.10.100").To4()
|
||||
stop := net.ParseIP("192.168.10.200").To4()
|
||||
assert.True(t, !ip4InRange(start, stop, net.ParseIP("192.168.10.99").To4()))
|
||||
assert.True(t, !ip4InRange(start, stop, net.ParseIP("192.168.11.100").To4()))
|
||||
assert.True(t, !ip4InRange(start, stop, net.ParseIP("192.168.11.201").To4()))
|
||||
assert.True(t, ip4InRange(start, stop, net.ParseIP("192.168.10.100").To4()))
|
||||
t.Run("check_lease", func(t *testing.T) {
|
||||
ls := s.GetLeases(LeasesDynamic)
|
||||
assert.Len(t, ls, 1)
|
||||
assert.True(t, net.IP{192, 168, 10, 100}.Equal(ls[0].IP))
|
||||
assert.Equal(t, mac, ls[0].HWAddr)
|
||||
})
|
||||
}
|
||||
|
||||
func TestIP4InRange(t *testing.T) {
|
||||
start := net.IP{192, 168, 10, 100}
|
||||
stop := net.IP{192, 168, 10, 200}
|
||||
|
||||
testCases := []struct {
|
||||
ip net.IP
|
||||
want bool
|
||||
}{{
|
||||
ip: net.IP{192, 168, 10, 99},
|
||||
want: false,
|
||||
}, {
|
||||
ip: net.IP{192, 168, 11, 100},
|
||||
want: false,
|
||||
}, {
|
||||
ip: net.IP{192, 168, 11, 201},
|
||||
want: false,
|
||||
}, {
|
||||
ip: start,
|
||||
want: true,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.ip.String(), func(t *testing.T) {
|
||||
assert.Equal(t, tc.want, ip4InRange(start, stop, tc.ip))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -42,7 +42,6 @@ func (s *v6Server) WriteDiskConfig6(c *V6ServerConf) {
|
||||
}
|
||||
|
||||
// Return TRUE if IP address is within range [start..0xff]
|
||||
// nolint(staticcheck)
|
||||
func ip6InRange(start, ip net.IP) bool {
|
||||
if len(start) != 16 {
|
||||
return false
|
||||
@@ -72,7 +71,10 @@ func (s *v6Server) ResetLeases(ll []*Lease) {
|
||||
|
||||
// GetLeases - get current leases
|
||||
func (s *v6Server) GetLeases(flags int) []Lease {
|
||||
var result []Lease
|
||||
// The function shouldn't return nil value because zero-length slice
|
||||
// behaves differently in cases like marshalling. Our front-end also
|
||||
// requires non-nil value in the response.
|
||||
result := []Lease{}
|
||||
s.leasesLock.Lock()
|
||||
for _, lease := range s.leases {
|
||||
if lease.Expiry.Unix() == leaseExpireStatic {
|
||||
@@ -550,8 +552,8 @@ func (s *v6Server) initRA(iface *net.Interface) error {
|
||||
}
|
||||
}
|
||||
|
||||
s.ra.raAllowSlaac = s.conf.RaAllowSlaac
|
||||
s.ra.raSlaacOnly = s.conf.RaSlaacOnly
|
||||
s.ra.raAllowSLAAC = s.conf.RAAllowSLAAC
|
||||
s.ra.raSLAACOnly = s.conf.RASLAACOnly
|
||||
s.ra.dnsIPAddr = s.ra.ipAddr
|
||||
s.ra.prefixIPAddr = s.conf.ipStart
|
||||
s.ra.ifaceName = s.conf.InterfaceName
|
||||
@@ -592,7 +594,7 @@ func (s *v6Server) Start() error {
|
||||
}
|
||||
|
||||
// don't initialize DHCPv6 server if we must force the clients to use SLAAC
|
||||
if s.conf.RaSlaacOnly {
|
||||
if s.conf.RASLAACOnly {
|
||||
log.Debug("DHCPv6: not starting DHCPv6 server due to ra_slaac_only=true")
|
||||
return nil
|
||||
}
|
||||
@@ -657,7 +659,7 @@ func v6Create(conf V6ServerConf) (DHCPServer, error) {
|
||||
return s, nil
|
||||
}
|
||||
|
||||
s.conf.ipStart = net.ParseIP(conf.RangeStart)
|
||||
s.conf.ipStart = conf.RangeStart
|
||||
if s.conf.ipStart == nil || s.conf.ipStart.To16() == nil {
|
||||
return s, fmt.Errorf("dhcpv6: invalid range-start IP: %s", conf.RangeStart)
|
||||
}
|
||||
|
||||
@@ -9,217 +9,283 @@ import (
|
||||
"github.com/insomniacslk/dhcp/dhcpv6"
|
||||
"github.com/insomniacslk/dhcp/iana"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func notify6(flags uint32) {
|
||||
}
|
||||
|
||||
func TestV6StaticLeaseAddRemove(t *testing.T) {
|
||||
conf := V6ServerConf{
|
||||
func TestV6_AddRemove_static(t *testing.T) {
|
||||
s, err := v6Create(V6ServerConf{
|
||||
Enabled: true,
|
||||
RangeStart: "2001::1",
|
||||
RangeStart: net.ParseIP("2001::1"),
|
||||
notify: notify6,
|
||||
})
|
||||
require.Nil(t, err)
|
||||
|
||||
require.Empty(t, s.GetLeases(LeasesStatic))
|
||||
|
||||
// Add static lease.
|
||||
l := Lease{
|
||||
IP: net.ParseIP("2001::1"),
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
}
|
||||
s, err := v6Create(conf)
|
||||
assert.True(t, err == nil)
|
||||
require.Nil(t, s.AddStaticLease(l))
|
||||
|
||||
// Try to add the same static lease.
|
||||
require.NotNil(t, s.AddStaticLease(l))
|
||||
|
||||
ls := s.GetLeases(LeasesStatic)
|
||||
assert.Equal(t, 0, len(ls))
|
||||
require.Len(t, ls, 1)
|
||||
assert.Equal(t, l.IP, ls[0].IP)
|
||||
assert.Equal(t, l.HWAddr, ls[0].HWAddr)
|
||||
assert.EqualValues(t, leaseExpireStatic, ls[0].Expiry.Unix())
|
||||
|
||||
// add static lease
|
||||
l := Lease{}
|
||||
l.IP = net.ParseIP("2001::1")
|
||||
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
|
||||
assert.True(t, s.AddStaticLease(l) == nil)
|
||||
// Try to remove non-existent static lease.
|
||||
require.NotNil(t, s.RemoveStaticLease(Lease{
|
||||
IP: net.ParseIP("2001::2"),
|
||||
HWAddr: l.HWAddr,
|
||||
}))
|
||||
|
||||
// try to add static lease - fail
|
||||
assert.True(t, s.AddStaticLease(l) != nil)
|
||||
// Remove static lease.
|
||||
require.Nil(t, s.RemoveStaticLease(l))
|
||||
|
||||
// check
|
||||
ls = s.GetLeases(LeasesStatic)
|
||||
assert.Equal(t, 1, len(ls))
|
||||
assert.Equal(t, "2001::1", ls[0].IP.String())
|
||||
assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
|
||||
assert.True(t, ls[0].Expiry.Unix() == leaseExpireStatic)
|
||||
|
||||
// try to remove static lease - fail
|
||||
l.IP = net.ParseIP("2001::2")
|
||||
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
|
||||
assert.True(t, s.RemoveStaticLease(l) != nil)
|
||||
|
||||
// remove static lease
|
||||
l.IP = net.ParseIP("2001::1")
|
||||
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
|
||||
assert.True(t, s.RemoveStaticLease(l) == nil)
|
||||
|
||||
// check
|
||||
ls = s.GetLeases(LeasesStatic)
|
||||
assert.Equal(t, 0, len(ls))
|
||||
assert.Empty(t, s.GetLeases(LeasesStatic))
|
||||
}
|
||||
|
||||
func TestV6StaticLeaseAddReplaceDynamic(t *testing.T) {
|
||||
conf := V6ServerConf{
|
||||
func TestV6_AddReplace(t *testing.T) {
|
||||
sIface, err := v6Create(V6ServerConf{
|
||||
Enabled: true,
|
||||
RangeStart: "2001::1",
|
||||
RangeStart: net.ParseIP("2001::1"),
|
||||
notify: notify6,
|
||||
}
|
||||
sIface, err := v6Create(conf)
|
||||
s := sIface.(*v6Server)
|
||||
assert.True(t, err == nil)
|
||||
})
|
||||
require.Nil(t, err)
|
||||
s, ok := sIface.(*v6Server)
|
||||
require.True(t, ok)
|
||||
|
||||
// add dynamic lease
|
||||
ld := Lease{}
|
||||
ld.IP = net.ParseIP("2001::1")
|
||||
ld.HWAddr, _ = net.ParseMAC("11:aa:aa:aa:aa:aa")
|
||||
s.addLease(&ld)
|
||||
// Add dynamic leases.
|
||||
dynLeases := []*Lease{{
|
||||
IP: net.ParseIP("2001::1"),
|
||||
HWAddr: net.HardwareAddr{0x11, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
}, {
|
||||
IP: net.ParseIP("2001::2"),
|
||||
HWAddr: net.HardwareAddr{0x22, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
}}
|
||||
|
||||
// add dynamic lease
|
||||
{
|
||||
ld := Lease{}
|
||||
ld.IP = net.ParseIP("2001::2")
|
||||
ld.HWAddr, _ = net.ParseMAC("22:aa:aa:aa:aa:aa")
|
||||
s.addLease(&ld)
|
||||
for _, l := range dynLeases {
|
||||
s.addLease(l)
|
||||
}
|
||||
|
||||
// add static lease with the same IP
|
||||
l := Lease{}
|
||||
l.IP = net.ParseIP("2001::1")
|
||||
l.HWAddr, _ = net.ParseMAC("33:aa:aa:aa:aa:aa")
|
||||
assert.True(t, s.AddStaticLease(l) == nil)
|
||||
stLeases := []Lease{{
|
||||
IP: net.ParseIP("2001::1"),
|
||||
HWAddr: net.HardwareAddr{0x33, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
}, {
|
||||
IP: net.ParseIP("2001::3"),
|
||||
HWAddr: net.HardwareAddr{0x22, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
}}
|
||||
|
||||
// add static lease with the same MAC
|
||||
l = Lease{}
|
||||
l.IP = net.ParseIP("2001::3")
|
||||
l.HWAddr, _ = net.ParseMAC("22:aa:aa:aa:aa:aa")
|
||||
assert.True(t, s.AddStaticLease(l) == nil)
|
||||
for _, l := range stLeases {
|
||||
require.Nil(t, s.AddStaticLease(l))
|
||||
}
|
||||
|
||||
// check
|
||||
ls := s.GetLeases(LeasesStatic)
|
||||
assert.Equal(t, 2, len(ls))
|
||||
require.Len(t, ls, 2)
|
||||
|
||||
assert.Equal(t, "2001::1", ls[0].IP.String())
|
||||
assert.Equal(t, "33:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
|
||||
assert.True(t, ls[0].Expiry.Unix() == leaseExpireStatic)
|
||||
|
||||
assert.Equal(t, "2001::3", ls[1].IP.String())
|
||||
assert.Equal(t, "22:aa:aa:aa:aa:aa", ls[1].HWAddr.String())
|
||||
assert.True(t, ls[1].Expiry.Unix() == leaseExpireStatic)
|
||||
for i, l := range ls {
|
||||
assert.True(t, stLeases[i].IP.Equal(l.IP))
|
||||
assert.Equal(t, stLeases[i].HWAddr, l.HWAddr)
|
||||
assert.EqualValues(t, leaseExpireStatic, l.Expiry.Unix())
|
||||
}
|
||||
}
|
||||
|
||||
func TestV6GetLease(t *testing.T) {
|
||||
conf := V6ServerConf{
|
||||
var err error
|
||||
sIface, err := v6Create(V6ServerConf{
|
||||
Enabled: true,
|
||||
RangeStart: "2001::1",
|
||||
RangeStart: net.ParseIP("2001::1"),
|
||||
notify: notify6,
|
||||
}
|
||||
sIface, err := v6Create(conf)
|
||||
s := sIface.(*v6Server)
|
||||
assert.True(t, err == nil)
|
||||
s.conf.dnsIPAddrs = []net.IP{net.ParseIP("2000::1")}
|
||||
})
|
||||
require.Nil(t, err)
|
||||
s, ok := sIface.(*v6Server)
|
||||
require.True(t, ok)
|
||||
|
||||
dnsAddr := net.ParseIP("2000::1")
|
||||
s.conf.dnsIPAddrs = []net.IP{dnsAddr}
|
||||
s.sid = dhcpv6.Duid{
|
||||
Type: dhcpv6.DUID_LLT,
|
||||
HwType: iana.HWTypeEthernet,
|
||||
Type: dhcpv6.DUID_LLT,
|
||||
HwType: iana.HWTypeEthernet,
|
||||
LinkLayerAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
}
|
||||
s.sid.LinkLayerAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
|
||||
|
||||
l := Lease{}
|
||||
l.IP = net.ParseIP("2001::1")
|
||||
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
|
||||
assert.True(t, s.AddStaticLease(l) == nil)
|
||||
l := Lease{
|
||||
IP: net.ParseIP("2001::1"),
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
}
|
||||
require.Nil(t, s.AddStaticLease(l))
|
||||
|
||||
// "Solicit"
|
||||
mac, _ := net.ParseMAC("aa:aa:aa:aa:aa:aa")
|
||||
req, _ := dhcpv6.NewSolicit(mac)
|
||||
msg, _ := req.GetInnerMessage()
|
||||
resp, _ := dhcpv6.NewAdvertiseFromSolicit(msg)
|
||||
assert.True(t, s.process(msg, req, resp))
|
||||
var req, resp, msg *dhcpv6.Message
|
||||
mac := net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}
|
||||
t.Run("solicit", func(t *testing.T) {
|
||||
req, err = dhcpv6.NewSolicit(mac)
|
||||
require.Nil(t, err)
|
||||
|
||||
msg, err = req.GetInnerMessage()
|
||||
require.Nil(t, err)
|
||||
|
||||
resp, err = dhcpv6.NewAdvertiseFromSolicit(msg)
|
||||
require.Nil(t, err)
|
||||
|
||||
assert.True(t, s.process(msg, req, resp))
|
||||
})
|
||||
require.Nil(t, err)
|
||||
resp.AddOption(dhcpv6.OptServerID(s.sid))
|
||||
|
||||
// check "Advertise"
|
||||
assert.Equal(t, dhcpv6.MessageTypeAdvertise, resp.Type())
|
||||
oia := resp.Options.OneIANA()
|
||||
oiaAddr := oia.Options.OneAddress()
|
||||
assert.Equal(t, "2001::1", oiaAddr.IPv6Addr.String())
|
||||
assert.Equal(t, s.conf.leaseTime.Seconds(), oiaAddr.ValidLifetime.Seconds())
|
||||
var oia *dhcpv6.OptIANA
|
||||
var oiaAddr *dhcpv6.OptIAAddress
|
||||
t.Run("advertise", func(t *testing.T) {
|
||||
require.Equal(t, dhcpv6.MessageTypeAdvertise, resp.Type())
|
||||
oia = resp.Options.OneIANA()
|
||||
oiaAddr = oia.Options.OneAddress()
|
||||
|
||||
// "Request"
|
||||
req, _ = dhcpv6.NewRequestFromAdvertise(resp)
|
||||
msg, _ = req.GetInnerMessage()
|
||||
resp, _ = dhcpv6.NewReplyFromMessage(msg)
|
||||
assert.True(t, s.process(msg, req, resp))
|
||||
assert.Equal(t, l.IP, oiaAddr.IPv6Addr)
|
||||
assert.Equal(t, s.conf.leaseTime.Seconds(), oiaAddr.ValidLifetime.Seconds())
|
||||
})
|
||||
|
||||
// check "Reply"
|
||||
assert.Equal(t, dhcpv6.MessageTypeReply, resp.Type())
|
||||
oia = resp.Options.OneIANA()
|
||||
oiaAddr = oia.Options.OneAddress()
|
||||
assert.Equal(t, "2001::1", oiaAddr.IPv6Addr.String())
|
||||
assert.Equal(t, s.conf.leaseTime.Seconds(), oiaAddr.ValidLifetime.Seconds())
|
||||
t.Run("request", func(t *testing.T) {
|
||||
req, err = dhcpv6.NewRequestFromAdvertise(resp)
|
||||
require.Nil(t, err)
|
||||
|
||||
msg, err = req.GetInnerMessage()
|
||||
require.Nil(t, err)
|
||||
|
||||
resp, err = dhcpv6.NewReplyFromMessage(msg)
|
||||
require.Nil(t, err)
|
||||
|
||||
assert.True(t, s.process(msg, req, resp))
|
||||
})
|
||||
require.Nil(t, err)
|
||||
|
||||
t.Run("reply", func(t *testing.T) {
|
||||
require.Equal(t, dhcpv6.MessageTypeReply, resp.Type())
|
||||
oia = resp.Options.OneIANA()
|
||||
oiaAddr = oia.Options.OneAddress()
|
||||
|
||||
assert.Equal(t, l.IP, oiaAddr.IPv6Addr)
|
||||
assert.Equal(t, s.conf.leaseTime.Seconds(), oiaAddr.ValidLifetime.Seconds())
|
||||
})
|
||||
|
||||
dnsAddrs := resp.Options.DNS()
|
||||
assert.Equal(t, 1, len(dnsAddrs))
|
||||
assert.Equal(t, "2000::1", dnsAddrs[0].String())
|
||||
require.Len(t, dnsAddrs, 1)
|
||||
assert.Equal(t, dnsAddr, dnsAddrs[0])
|
||||
|
||||
// check lease
|
||||
ls := s.GetLeases(LeasesStatic)
|
||||
assert.Equal(t, 1, len(ls))
|
||||
assert.Equal(t, "2001::1", ls[0].IP.String())
|
||||
assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
|
||||
t.Run("lease", func(t *testing.T) {
|
||||
ls := s.GetLeases(LeasesStatic)
|
||||
require.Len(t, ls, 1)
|
||||
assert.Equal(t, l.IP, ls[0].IP)
|
||||
assert.Equal(t, l.HWAddr, ls[0].HWAddr)
|
||||
})
|
||||
}
|
||||
|
||||
func TestV6GetDynamicLease(t *testing.T) {
|
||||
conf := V6ServerConf{
|
||||
sIface, err := v6Create(V6ServerConf{
|
||||
Enabled: true,
|
||||
RangeStart: "2001::2",
|
||||
RangeStart: net.ParseIP("2001::2"),
|
||||
notify: notify6,
|
||||
}
|
||||
sIface, err := v6Create(conf)
|
||||
s := sIface.(*v6Server)
|
||||
assert.True(t, err == nil)
|
||||
s.conf.dnsIPAddrs = []net.IP{net.ParseIP("2000::1")}
|
||||
s.sid = dhcpv6.Duid{
|
||||
Type: dhcpv6.DUID_LLT,
|
||||
HwType: iana.HWTypeEthernet,
|
||||
}
|
||||
s.sid.LinkLayerAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
|
||||
})
|
||||
require.Nil(t, err)
|
||||
s, ok := sIface.(*v6Server)
|
||||
require.True(t, ok)
|
||||
|
||||
// "Solicit"
|
||||
mac, _ := net.ParseMAC("aa:aa:aa:aa:aa:aa")
|
||||
req, _ := dhcpv6.NewSolicit(mac)
|
||||
msg, _ := req.GetInnerMessage()
|
||||
resp, _ := dhcpv6.NewAdvertiseFromSolicit(msg)
|
||||
assert.True(t, s.process(msg, req, resp))
|
||||
dnsAddr := net.ParseIP("2000::1")
|
||||
s.conf.dnsIPAddrs = []net.IP{dnsAddr}
|
||||
s.sid = dhcpv6.Duid{
|
||||
Type: dhcpv6.DUID_LLT,
|
||||
HwType: iana.HWTypeEthernet,
|
||||
LinkLayerAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
}
|
||||
|
||||
var req, resp, msg *dhcpv6.Message
|
||||
mac := net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}
|
||||
t.Run("solicit", func(t *testing.T) {
|
||||
req, err = dhcpv6.NewSolicit(mac)
|
||||
require.Nil(t, err)
|
||||
|
||||
msg, err = req.GetInnerMessage()
|
||||
require.Nil(t, err)
|
||||
|
||||
resp, err = dhcpv6.NewAdvertiseFromSolicit(msg)
|
||||
require.Nil(t, err)
|
||||
|
||||
assert.True(t, s.process(msg, req, resp))
|
||||
})
|
||||
require.Nil(t, err)
|
||||
resp.AddOption(dhcpv6.OptServerID(s.sid))
|
||||
|
||||
// check "Advertise"
|
||||
assert.Equal(t, dhcpv6.MessageTypeAdvertise, resp.Type())
|
||||
oia := resp.Options.OneIANA()
|
||||
oiaAddr := oia.Options.OneAddress()
|
||||
assert.Equal(t, "2001::2", oiaAddr.IPv6Addr.String())
|
||||
var oia *dhcpv6.OptIANA
|
||||
var oiaAddr *dhcpv6.OptIAAddress
|
||||
t.Run("advertise", func(t *testing.T) {
|
||||
require.Equal(t, dhcpv6.MessageTypeAdvertise, resp.Type())
|
||||
oia = resp.Options.OneIANA()
|
||||
oiaAddr = oia.Options.OneAddress()
|
||||
assert.Equal(t, "2001::2", oiaAddr.IPv6Addr.String())
|
||||
})
|
||||
|
||||
// "Request"
|
||||
req, _ = dhcpv6.NewRequestFromAdvertise(resp)
|
||||
msg, _ = req.GetInnerMessage()
|
||||
resp, _ = dhcpv6.NewReplyFromMessage(msg)
|
||||
assert.True(t, s.process(msg, req, resp))
|
||||
t.Run("request", func(t *testing.T) {
|
||||
req, err = dhcpv6.NewRequestFromAdvertise(resp)
|
||||
require.Nil(t, err)
|
||||
|
||||
// check "Reply"
|
||||
assert.Equal(t, dhcpv6.MessageTypeReply, resp.Type())
|
||||
oia = resp.Options.OneIANA()
|
||||
oiaAddr = oia.Options.OneAddress()
|
||||
assert.Equal(t, "2001::2", oiaAddr.IPv6Addr.String())
|
||||
msg, err = req.GetInnerMessage()
|
||||
require.Nil(t, err)
|
||||
|
||||
resp, err = dhcpv6.NewReplyFromMessage(msg)
|
||||
require.Nil(t, err)
|
||||
|
||||
assert.True(t, s.process(msg, req, resp))
|
||||
})
|
||||
require.Nil(t, err)
|
||||
|
||||
t.Run("reply", func(t *testing.T) {
|
||||
require.Equal(t, dhcpv6.MessageTypeReply, resp.Type())
|
||||
oia = resp.Options.OneIANA()
|
||||
oiaAddr = oia.Options.OneAddress()
|
||||
assert.Equal(t, "2001::2", oiaAddr.IPv6Addr.String())
|
||||
})
|
||||
|
||||
dnsAddrs := resp.Options.DNS()
|
||||
assert.Equal(t, 1, len(dnsAddrs))
|
||||
assert.Equal(t, "2000::1", dnsAddrs[0].String())
|
||||
require.Len(t, dnsAddrs, 1)
|
||||
assert.Equal(t, dnsAddr, dnsAddrs[0])
|
||||
|
||||
// check lease
|
||||
ls := s.GetLeases(LeasesDynamic)
|
||||
assert.Equal(t, 1, len(ls))
|
||||
assert.Equal(t, "2001::2", ls[0].IP.String())
|
||||
assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
|
||||
|
||||
assert.True(t, !ip6InRange(net.ParseIP("2001::2"), net.ParseIP("2001::1")))
|
||||
assert.True(t, !ip6InRange(net.ParseIP("2001::2"), net.ParseIP("2002::2")))
|
||||
assert.True(t, ip6InRange(net.ParseIP("2001::2"), net.ParseIP("2001::2")))
|
||||
assert.True(t, ip6InRange(net.ParseIP("2001::2"), net.ParseIP("2001::3")))
|
||||
t.Run("lease", func(t *testing.T) {
|
||||
ls := s.GetLeases(LeasesDynamic)
|
||||
require.Len(t, ls, 1)
|
||||
assert.Equal(t, "2001::2", ls[0].IP.String())
|
||||
assert.Equal(t, mac, ls[0].HWAddr)
|
||||
})
|
||||
}
|
||||
|
||||
func TestIP6InRange(t *testing.T) {
|
||||
start := net.ParseIP("2001::2")
|
||||
|
||||
testCases := []struct {
|
||||
ip net.IP
|
||||
want bool
|
||||
}{{
|
||||
ip: net.ParseIP("2001::1"),
|
||||
want: false,
|
||||
}, {
|
||||
ip: net.ParseIP("2002::2"),
|
||||
want: false,
|
||||
}, {
|
||||
ip: start,
|
||||
want: true,
|
||||
}, {
|
||||
ip: net.ParseIP("2001::3"),
|
||||
want: true,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.ip.String(), func(t *testing.T) {
|
||||
assert.Equal(t, tc.want, ip6InRange(start, tc.ip))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -161,7 +161,73 @@ var serviceRulesArray = []svc{
|
||||
"||douyin.com^",
|
||||
"||tiktokv.com^",
|
||||
}},
|
||||
{"qq", []string{"||qq.com^", "||qqzaixian.com^"}},
|
||||
{"vimeo", []string{
|
||||
"||vimeo.com^",
|
||||
"||vimeocdn.com^",
|
||||
"*vod-adaptive.akamaized.net^",
|
||||
}},
|
||||
{"pinterest", []string{
|
||||
"||pinterest.*^",
|
||||
"||pinimg.com^",
|
||||
}},
|
||||
{"imgur", []string{
|
||||
"||imgur.com^",
|
||||
}},
|
||||
{"dailymotion", []string{
|
||||
"||dailymotion.com^",
|
||||
"||dm-event.net^",
|
||||
"||dmcdn.net^",
|
||||
}},
|
||||
{"qq", []string{
|
||||
// block qq.com and subdomains excluding WeChat domains
|
||||
"^(?!weixin|wx)([^.]+\\.)?qq\\.com$",
|
||||
"||qqzaixian.com^",
|
||||
}},
|
||||
{"wechat", []string{
|
||||
"||wechat.com^",
|
||||
"||weixin.qq.com^",
|
||||
"||wx.qq.com^",
|
||||
}},
|
||||
{"viber", []string{
|
||||
"||viber.com^",
|
||||
}},
|
||||
{"weibo", []string{
|
||||
"||weibo.com^",
|
||||
}},
|
||||
{"9gag", []string{
|
||||
"||9cache.com^",
|
||||
"||gag.com^",
|
||||
}},
|
||||
{"telegram", []string{
|
||||
"||t.me^",
|
||||
"||telegram.me^",
|
||||
"||telegram.org^",
|
||||
}},
|
||||
{"disneyplus", []string{
|
||||
"||disney-plus.net^",
|
||||
"||disneyplus.com^",
|
||||
}},
|
||||
{"hulu", []string{
|
||||
"||hulu.com^",
|
||||
}},
|
||||
{"spotify", []string{
|
||||
"/_spotify-connect._tcp.local/",
|
||||
"||spotify.com^",
|
||||
"||scdn.co^",
|
||||
"||spotify.com.edgesuite.net^",
|
||||
"||spotify.map.fastly.net^",
|
||||
"||spotify.map.fastlylb.net^",
|
||||
"||spotifycdn.net^",
|
||||
"||audio-ak-spotify-com.akamaized.net^",
|
||||
"||audio4-ak-spotify-com.akamaized.net^",
|
||||
"||heads-ak-spotify-com.akamaized.net^",
|
||||
"||heads4-ak-spotify-com.akamaized.net^",
|
||||
}},
|
||||
{"tinder", []string{
|
||||
"||gotinder.com^",
|
||||
"||tinder.com^",
|
||||
"||tindersparks.com^",
|
||||
}},
|
||||
}
|
||||
|
||||
// convert array to map
|
||||
@@ -242,6 +308,6 @@ func (d *DNSFilter) handleBlockedServicesSet(w http.ResponseWriter, r *http.Requ
|
||||
|
||||
// registerBlockedServicesHandlers - register HTTP handlers
|
||||
func (d *DNSFilter) registerBlockedServicesHandlers() {
|
||||
d.Config.HTTPRegister("GET", "/control/blocked_services/list", d.handleBlockedServicesList)
|
||||
d.Config.HTTPRegister("POST", "/control/blocked_services/set", d.handleBlockedServicesSet)
|
||||
d.Config.HTTPRegister(http.MethodGet, "/control/blocked_services/list", d.handleBlockedServicesList)
|
||||
d.Config.HTTPRegister(http.MethodPost, "/control/blocked_services/set", d.handleBlockedServicesSet)
|
||||
}
|
||||
|
||||
37
internal/dnsfilter/blocked_test.go
Normal file
37
internal/dnsfilter/blocked_test.go
Normal file
@@ -0,0 +1,37 @@
|
||||
// +build ignore
|
||||
|
||||
package dnsfilter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// This is a simple tool that takes a list of services and prints them to the output.
|
||||
// It is supposed to be used to update:
|
||||
// client/src/helpers/constants.js
|
||||
// client/src/components/ui/Icons.js
|
||||
//
|
||||
// Usage:
|
||||
// 1. go run ./internal/dnsfilter/blocked_test.go
|
||||
// 2. Use the output to replace `SERVICES` array in "client/src/helpers/constants.js".
|
||||
// 3. You'll need to enter services names manually.
|
||||
// 4. Don't forget to add missing icons to "client/src/components/ui/Icons.js".
|
||||
//
|
||||
// TODO(ameshkov): Rework generator: have a JSON file with all the metadata we need
|
||||
// then use this JSON file to generate JS and Go code
|
||||
func TestGenServicesArray(t *testing.T) {
|
||||
services := make([]svc, len(serviceRulesArray))
|
||||
copy(services, serviceRulesArray)
|
||||
|
||||
sort.Slice(services, func(i, j int) bool {
|
||||
return services[i].name < services[j].name
|
||||
})
|
||||
|
||||
fmt.Println("export const SERVICES = [")
|
||||
for _, s := range services {
|
||||
fmt.Printf(" {\n id: '%s',\n name: '%s',\n },\n", s.name, s.name)
|
||||
}
|
||||
fmt.Println("];")
|
||||
}
|
||||
@@ -2,6 +2,7 @@
|
||||
package dnsfilter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
@@ -36,12 +37,18 @@ type RequestFilteringSettings struct {
|
||||
ParentalEnabled bool
|
||||
|
||||
ClientName string
|
||||
ClientIP string
|
||||
ClientIP net.IP
|
||||
ClientTags []string
|
||||
|
||||
ServicesRules []ServiceEntry
|
||||
}
|
||||
|
||||
// Resolver is the interface for net.Resolver to simplify testing.
|
||||
type Resolver interface {
|
||||
// TODO(e.burkov): Replace with LookupIP after upgrading go to v1.15.
|
||||
LookupIPAddr(ctx context.Context, host string) (ips []net.IPAddr, err error)
|
||||
}
|
||||
|
||||
// Config allows you to configure DNS filtering with New() or just change variables directly.
|
||||
type Config struct {
|
||||
ParentalEnabled bool `yaml:"parental_enabled"`
|
||||
@@ -68,6 +75,9 @@ type Config struct {
|
||||
|
||||
// Register an HTTP handler
|
||||
HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request)) `yaml:"-"`
|
||||
|
||||
// CustomResolver is the resolver used by DNSFilter.
|
||||
CustomResolver Resolver
|
||||
}
|
||||
|
||||
// LookupStats store stats collected during safebrowsing or parental checks
|
||||
@@ -110,6 +120,11 @@ type DNSFilter struct {
|
||||
// Channel for passing data to filters-initializer goroutine
|
||||
filtersInitializerChan chan filtersInitializerParams
|
||||
filtersInitializerLock sync.Mutex
|
||||
|
||||
// resolver only looks up the IP address of the host while safe search.
|
||||
//
|
||||
// TODO(e.burkov): Use upstream that configured in dnsforward instead.
|
||||
resolver Resolver
|
||||
}
|
||||
|
||||
// Filter represents a filter list
|
||||
@@ -148,17 +163,21 @@ const (
|
||||
// FilteredBlockedService - the host is blocked by "blocked services" settings
|
||||
FilteredBlockedService
|
||||
|
||||
// ReasonRewrite is returned when there was a rewrite by
|
||||
// a legacy DNS Rewrite rule.
|
||||
ReasonRewrite
|
||||
// Rewritten is returned when there was a rewrite by a legacy DNS
|
||||
// rewrite rule.
|
||||
Rewritten
|
||||
|
||||
// RewriteAutoHosts is returned when there was a rewrite by
|
||||
// autohosts rules (/etc/hosts and so on).
|
||||
RewriteAutoHosts
|
||||
// RewrittenAutoHosts is returned when there was a rewrite by autohosts
|
||||
// rules (/etc/hosts and so on).
|
||||
RewrittenAutoHosts
|
||||
|
||||
// DNSRewriteRule is returned when a $dnsrewrite filter rule was
|
||||
// applied.
|
||||
DNSRewriteRule
|
||||
// RewrittenRule is returned when a $dnsrewrite filter rule was applied.
|
||||
//
|
||||
// TODO(a.garipov): Remove Rewritten and RewrittenAutoHosts by merging
|
||||
// their functionality into RewrittenRule.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/2499.
|
||||
RewrittenRule
|
||||
)
|
||||
|
||||
// TODO(a.garipov): Resync with actual code names or replace completely
|
||||
@@ -175,11 +194,9 @@ var reasonNames = []string{
|
||||
FilteredSafeSearch: "FilteredSafeSearch",
|
||||
FilteredBlockedService: "FilteredBlockedService",
|
||||
|
||||
ReasonRewrite: "Rewrite",
|
||||
|
||||
RewriteAutoHosts: "RewriteEtcHosts",
|
||||
|
||||
DNSRewriteRule: "DNSRewriteRule",
|
||||
Rewritten: "Rewrite",
|
||||
RewrittenAutoHosts: "RewriteEtcHosts",
|
||||
RewrittenRule: "RewriteRule",
|
||||
}
|
||||
|
||||
func (r Reason) String() string {
|
||||
@@ -331,15 +348,15 @@ type Result struct {
|
||||
Rules []*ResultRule `json:",omitempty"`
|
||||
|
||||
// ReverseHosts is the reverse lookup rewrite result. It is
|
||||
// empty unless Reason is set to RewriteAutoHosts.
|
||||
// empty unless Reason is set to RewrittenAutoHosts.
|
||||
ReverseHosts []string `json:",omitempty"`
|
||||
|
||||
// IPList is the lookup rewrite result. It is empty unless
|
||||
// Reason is set to RewriteAutoHosts or ReasonRewrite.
|
||||
// Reason is set to RewrittenAutoHosts or Rewritten.
|
||||
IPList []net.IP `json:",omitempty"`
|
||||
|
||||
// CanonName is the CNAME value from the lookup rewrite result.
|
||||
// It is empty unless Reason is set to ReasonRewrite.
|
||||
// It is empty unless Reason is set to Rewritten or RewrittenRule.
|
||||
CanonName string `json:",omitempty"`
|
||||
|
||||
// ServiceName is the name of the blocked service. It is empty
|
||||
@@ -379,7 +396,7 @@ func (d *DNSFilter) CheckHost(host string, qtype uint16, setts *RequestFiltering
|
||||
|
||||
// first - check rewrites, they have the highest priority
|
||||
result = d.processRewrites(host, qtype)
|
||||
if result.Reason == ReasonRewrite {
|
||||
if result.Reason == Rewritten {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -453,7 +470,7 @@ func (d *DNSFilter) CheckHost(host string, qtype uint16, setts *RequestFiltering
|
||||
func (d *DNSFilter) checkAutoHosts(host string, qtype uint16, result *Result) (matched bool) {
|
||||
ips := d.Config.AutoHosts.Process(host, qtype)
|
||||
if ips != nil {
|
||||
result.Reason = RewriteAutoHosts
|
||||
result.Reason = RewrittenAutoHosts
|
||||
result.IPList = ips
|
||||
|
||||
return true
|
||||
@@ -461,7 +478,7 @@ func (d *DNSFilter) checkAutoHosts(host string, qtype uint16, result *Result) (m
|
||||
|
||||
revHosts := d.Config.AutoHosts.ProcessReverse(host, qtype)
|
||||
if len(revHosts) != 0 {
|
||||
result.Reason = RewriteAutoHosts
|
||||
result.Reason = RewrittenAutoHosts
|
||||
|
||||
// TODO(a.garipov): Optimize this with a buffer.
|
||||
result.ReverseHosts = make([]string, len(revHosts))
|
||||
@@ -488,7 +505,7 @@ func (d *DNSFilter) processRewrites(host string, qtype uint16) (res Result) {
|
||||
|
||||
rr := findRewrites(d.Rewrites, host)
|
||||
if len(rr) != 0 {
|
||||
res.Reason = ReasonRewrite
|
||||
res.Reason = Rewritten
|
||||
}
|
||||
|
||||
cnames := map[string]bool{}
|
||||
@@ -674,9 +691,10 @@ func (d *DNSFilter) matchHost(host string, qtype uint16, setts RequestFilteringS
|
||||
ureq := urlfilter.DNSRequest{
|
||||
Hostname: host,
|
||||
SortedClientTags: setts.ClientTags,
|
||||
ClientIP: setts.ClientIP,
|
||||
ClientName: setts.ClientName,
|
||||
DNSType: qtype,
|
||||
// TODO(e.burkov): Wait for urlfilter update to pass net.IP.
|
||||
ClientIP: setts.ClientIP.String(),
|
||||
ClientName: setts.ClientName,
|
||||
DNSType: qtype,
|
||||
}
|
||||
|
||||
if d.filteringEngineAllow != nil {
|
||||
@@ -696,7 +714,7 @@ func (d *DNSFilter) matchHost(host string, qtype uint16, setts RequestFilteringS
|
||||
// awkward.
|
||||
if dnsr := dnsres.DNSRewrites(); len(dnsr) > 0 {
|
||||
res = d.processDNSRewrites(dnsr)
|
||||
if res.Reason == DNSRewriteRule && res.CanonName == host {
|
||||
if res.Reason == RewrittenRule && res.CanonName == host {
|
||||
// A rewrite of a host to itself. Go on and
|
||||
// try matching other things.
|
||||
} else {
|
||||
@@ -781,6 +799,7 @@ func InitModule() {
|
||||
|
||||
// New creates properly initialized DNS Filter that is ready to be used.
|
||||
func New(c *Config, blockFilters []Filter) *DNSFilter {
|
||||
var resolver Resolver = net.DefaultResolver
|
||||
if c != nil {
|
||||
cacheConf := cache.Config{
|
||||
EnableLRU: true,
|
||||
@@ -800,9 +819,15 @@ func New(c *Config, blockFilters []Filter) *DNSFilter {
|
||||
cacheConf.MaxSize = c.ParentalCacheSize
|
||||
gctx.parentalCache = cache.New(cacheConf)
|
||||
}
|
||||
|
||||
if c.CustomResolver != nil {
|
||||
resolver = c.CustomResolver
|
||||
}
|
||||
}
|
||||
|
||||
d := new(DNSFilter)
|
||||
d := &DNSFilter{
|
||||
resolver: resolver,
|
||||
}
|
||||
|
||||
err := d.initSecurityServices()
|
||||
if err != nil {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -7,8 +7,8 @@ import (
|
||||
|
||||
// DNSRewriteResult is the result of application of $dnsrewrite rules.
|
||||
type DNSRewriteResult struct {
|
||||
RCode rules.RCode `json:",omitempty"`
|
||||
Response DNSRewriteResultResponse `json:",omitempty"`
|
||||
RCode rules.RCode `json:",omitempty"`
|
||||
}
|
||||
|
||||
// DNSRewriteResultResponse is the collection of DNS response records
|
||||
@@ -33,13 +33,13 @@ func (d *DNSFilter) processDNSRewrites(dnsr []*rules.NetworkRule) (res Result) {
|
||||
if dr.NewCNAME != "" {
|
||||
// NewCNAME rules have a higher priority than
|
||||
// the other rules.
|
||||
rules := []*ResultRule{{
|
||||
rules = []*ResultRule{{
|
||||
FilterListID: int64(nr.GetFilterListID()),
|
||||
Text: nr.RuleText,
|
||||
}}
|
||||
|
||||
return Result{
|
||||
Reason: DNSRewriteRule,
|
||||
Reason: RewrittenRule,
|
||||
Rules: rules,
|
||||
CanonName: dr.NewCNAME,
|
||||
}
|
||||
@@ -56,7 +56,7 @@ func (d *DNSFilter) processDNSRewrites(dnsr []*rules.NetworkRule) (res Result) {
|
||||
default:
|
||||
// RcodeRefused and other such codes have higher
|
||||
// priority. Return immediately.
|
||||
rules := []*ResultRule{{
|
||||
rules = []*ResultRule{{
|
||||
FilterListID: int64(nr.GetFilterListID()),
|
||||
Text: nr.RuleText,
|
||||
}}
|
||||
@@ -65,7 +65,7 @@ func (d *DNSFilter) processDNSRewrites(dnsr []*rules.NetworkRule) (res Result) {
|
||||
}
|
||||
|
||||
return Result{
|
||||
Reason: DNSRewriteRule,
|
||||
Reason: RewrittenRule,
|
||||
Rules: rules,
|
||||
DNSRewriteResult: dnsrr,
|
||||
}
|
||||
@@ -73,7 +73,7 @@ func (d *DNSFilter) processDNSRewrites(dnsr []*rules.NetworkRule) (res Result) {
|
||||
}
|
||||
|
||||
return Result{
|
||||
Reason: DNSRewriteRule,
|
||||
Reason: RewrittenRule,
|
||||
Rules: rules,
|
||||
DNSRewriteResult: dnsrr,
|
||||
}
|
||||
|
||||
@@ -11,40 +11,41 @@ import (
|
||||
|
||||
func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
|
||||
const text = `
|
||||
|cname^$dnsrewrite=new_cname
|
||||
|cname^$dnsrewrite=new-cname
|
||||
|
||||
|a_record^$dnsrewrite=127.0.0.1
|
||||
|a-record^$dnsrewrite=127.0.0.1
|
||||
|
||||
|aaaa_record^$dnsrewrite=::1
|
||||
|aaaa-record^$dnsrewrite=::1
|
||||
|
||||
|txt_record^$dnsrewrite=NOERROR;TXT;hello_world
|
||||
|txt-record^$dnsrewrite=NOERROR;TXT;hello-world
|
||||
|
||||
|refused^$dnsrewrite=REFUSED
|
||||
|
||||
|a_records^$dnsrewrite=127.0.0.1
|
||||
|a_records^$dnsrewrite=127.0.0.2
|
||||
|a-records^$dnsrewrite=127.0.0.1
|
||||
|a-records^$dnsrewrite=127.0.0.2
|
||||
|
||||
|aaaa_records^$dnsrewrite=::1
|
||||
|aaaa_records^$dnsrewrite=::2
|
||||
|aaaa-records^$dnsrewrite=::1
|
||||
|aaaa-records^$dnsrewrite=::2
|
||||
|
||||
|disable_one^$dnsrewrite=127.0.0.1
|
||||
|disable_one^$dnsrewrite=127.0.0.2
|
||||
@@||disable_one^$dnsrewrite=127.0.0.1
|
||||
|disable-one^$dnsrewrite=127.0.0.1
|
||||
|disable-one^$dnsrewrite=127.0.0.2
|
||||
@@||disable-one^$dnsrewrite=127.0.0.1
|
||||
|
||||
|disable_cname^$dnsrewrite=127.0.0.1
|
||||
|disable_cname^$dnsrewrite=new_cname
|
||||
@@||disable_cname^$dnsrewrite=new_cname
|
||||
|disable-cname^$dnsrewrite=127.0.0.1
|
||||
|disable-cname^$dnsrewrite=new-cname
|
||||
@@||disable-cname^$dnsrewrite=new-cname
|
||||
|
||||
|disable_cname_many^$dnsrewrite=127.0.0.1
|
||||
|disable_cname_many^$dnsrewrite=new_cname_1
|
||||
|disable_cname_many^$dnsrewrite=new_cname_2
|
||||
@@||disable_cname_many^$dnsrewrite=new_cname_1
|
||||
|disable-cname-many^$dnsrewrite=127.0.0.1
|
||||
|disable-cname-many^$dnsrewrite=new-cname-1
|
||||
|disable-cname-many^$dnsrewrite=new-cname-2
|
||||
@@||disable-cname-many^$dnsrewrite=new-cname-1
|
||||
|
||||
|disable_all^$dnsrewrite=127.0.0.1
|
||||
|disable_all^$dnsrewrite=127.0.0.2
|
||||
@@||disable_all^$dnsrewrite
|
||||
|disable-all^$dnsrewrite=127.0.0.1
|
||||
|disable-all^$dnsrewrite=127.0.0.2
|
||||
@@||disable-all^$dnsrewrite
|
||||
`
|
||||
f := NewForTest(nil, []Filter{{ID: 0, Data: []byte(text)}})
|
||||
|
||||
f := newForTest(nil, []Filter{{ID: 0, Data: []byte(text)}})
|
||||
setts := &RequestFilteringSettings{
|
||||
FilteringEnabled: true,
|
||||
}
|
||||
@@ -60,10 +61,10 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
|
||||
|
||||
res, err := f.CheckHostRules(host, dtyp, setts)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "new_cname", res.CanonName)
|
||||
assert.Equal(t, "new-cname", res.CanonName)
|
||||
})
|
||||
|
||||
t.Run("a_record", func(t *testing.T) {
|
||||
t.Run("a-record", func(t *testing.T) {
|
||||
dtyp := dns.TypeA
|
||||
host := path.Base(t.Name())
|
||||
|
||||
@@ -78,7 +79,7 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("aaaa_record", func(t *testing.T) {
|
||||
t.Run("aaaa-record", func(t *testing.T) {
|
||||
dtyp := dns.TypeAAAA
|
||||
host := path.Base(t.Name())
|
||||
|
||||
@@ -93,7 +94,7 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("txt_record", func(t *testing.T) {
|
||||
t.Run("txt-record", func(t *testing.T) {
|
||||
dtyp := dns.TypeTXT
|
||||
host := path.Base(t.Name())
|
||||
res, err := f.CheckHostRules(host, dtyp, setts)
|
||||
@@ -102,7 +103,7 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
|
||||
if dnsrr := res.DNSRewriteResult; assert.NotNil(t, dnsrr) {
|
||||
assert.Equal(t, dns.RcodeSuccess, dnsrr.RCode)
|
||||
if strVals := dnsrr.Response[dtyp]; assert.Len(t, strVals, 1) {
|
||||
assert.Equal(t, "hello_world", strVals[0])
|
||||
assert.Equal(t, "hello-world", strVals[0])
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -117,7 +118,7 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("a_records", func(t *testing.T) {
|
||||
t.Run("a-records", func(t *testing.T) {
|
||||
dtyp := dns.TypeA
|
||||
host := path.Base(t.Name())
|
||||
|
||||
@@ -133,7 +134,7 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("aaaa_records", func(t *testing.T) {
|
||||
t.Run("aaaa-records", func(t *testing.T) {
|
||||
dtyp := dns.TypeAAAA
|
||||
host := path.Base(t.Name())
|
||||
|
||||
@@ -149,7 +150,7 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("disable_one", func(t *testing.T) {
|
||||
t.Run("disable-one", func(t *testing.T) {
|
||||
dtyp := dns.TypeA
|
||||
host := path.Base(t.Name())
|
||||
|
||||
@@ -164,13 +165,13 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("disable_cname", func(t *testing.T) {
|
||||
t.Run("disable-cname", func(t *testing.T) {
|
||||
dtyp := dns.TypeA
|
||||
host := path.Base(t.Name())
|
||||
|
||||
res, err := f.CheckHostRules(host, dtyp, setts)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "", res.CanonName)
|
||||
assert.Empty(t, res.CanonName)
|
||||
|
||||
if dnsrr := res.DNSRewriteResult; assert.NotNil(t, dnsrr) {
|
||||
assert.Equal(t, dns.RcodeSuccess, dnsrr.RCode)
|
||||
@@ -180,23 +181,23 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("disable_cname_many", func(t *testing.T) {
|
||||
t.Run("disable-cname-many", func(t *testing.T) {
|
||||
dtyp := dns.TypeA
|
||||
host := path.Base(t.Name())
|
||||
|
||||
res, err := f.CheckHostRules(host, dtyp, setts)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "new_cname_2", res.CanonName)
|
||||
assert.Equal(t, "new-cname-2", res.CanonName)
|
||||
assert.Nil(t, res.DNSRewriteResult)
|
||||
})
|
||||
|
||||
t.Run("disable_all", func(t *testing.T) {
|
||||
t.Run("disable-all", func(t *testing.T) {
|
||||
dtyp := dns.TypeA
|
||||
host := path.Base(t.Name())
|
||||
|
||||
res, err := f.CheckHostRules(host, dtyp, setts)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "", res.CanonName)
|
||||
assert.Len(t, res.Rules, 0)
|
||||
assert.Empty(t, res.CanonName)
|
||||
assert.Empty(t, res.Rules)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -219,7 +219,7 @@ func (d *DNSFilter) handleRewriteDelete(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
func (d *DNSFilter) registerRewritesHandlers() {
|
||||
d.Config.HTTPRegister("GET", "/control/rewrite/list", d.handleRewriteList)
|
||||
d.Config.HTTPRegister("POST", "/control/rewrite/add", d.handleRewriteAdd)
|
||||
d.Config.HTTPRegister("POST", "/control/rewrite/delete", d.handleRewriteDelete)
|
||||
d.Config.HTTPRegister(http.MethodGet, "/control/rewrite/list", d.handleRewriteList)
|
||||
d.Config.HTTPRegister(http.MethodPost, "/control/rewrite/add", d.handleRewriteAdd)
|
||||
d.Config.HTTPRegister(http.MethodPost, "/control/rewrite/delete", d.handleRewriteDelete)
|
||||
}
|
||||
|
||||
@@ -9,7 +9,8 @@ import (
|
||||
)
|
||||
|
||||
func TestRewrites(t *testing.T) {
|
||||
d := DNSFilter{}
|
||||
d := newForTest(nil, nil)
|
||||
t.Cleanup(d.Close)
|
||||
// CNAME, A, AAAA
|
||||
d.Rewrites = []RewriteEntry{
|
||||
{"somecname", "somehost.com", 0, nil},
|
||||
@@ -25,16 +26,16 @@ func TestRewrites(t *testing.T) {
|
||||
assert.Equal(t, NotFilteredNotFound, r.Reason)
|
||||
|
||||
r = d.processRewrites("www.host.com", dns.TypeA)
|
||||
assert.Equal(t, ReasonRewrite, r.Reason)
|
||||
assert.Equal(t, Rewritten, r.Reason)
|
||||
assert.Equal(t, "host.com", r.CanonName)
|
||||
assert.Equal(t, 2, len(r.IPList))
|
||||
assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4")))
|
||||
assert.True(t, r.IPList[1].Equal(net.ParseIP("1.2.3.5")))
|
||||
assert.Len(t, r.IPList, 2)
|
||||
assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4}))
|
||||
assert.True(t, r.IPList[1].Equal(net.IP{1, 2, 3, 5}))
|
||||
|
||||
r = d.processRewrites("www.host.com", dns.TypeAAAA)
|
||||
assert.Equal(t, ReasonRewrite, r.Reason)
|
||||
assert.Equal(t, Rewritten, r.Reason)
|
||||
assert.Equal(t, "host.com", r.CanonName)
|
||||
assert.Equal(t, 1, len(r.IPList))
|
||||
assert.Len(t, r.IPList, 1)
|
||||
assert.True(t, r.IPList[0].Equal(net.ParseIP("1:2:3::4")))
|
||||
|
||||
// wildcard
|
||||
@@ -44,12 +45,12 @@ func TestRewrites(t *testing.T) {
|
||||
}
|
||||
d.prepareRewrites()
|
||||
r = d.processRewrites("host.com", dns.TypeA)
|
||||
assert.Equal(t, ReasonRewrite, r.Reason)
|
||||
assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4")))
|
||||
assert.Equal(t, Rewritten, r.Reason)
|
||||
assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4}))
|
||||
|
||||
r = d.processRewrites("www.host.com", dns.TypeA)
|
||||
assert.Equal(t, ReasonRewrite, r.Reason)
|
||||
assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.5")))
|
||||
assert.Equal(t, Rewritten, r.Reason)
|
||||
assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 5}))
|
||||
|
||||
r = d.processRewrites("www.host2.com", dns.TypeA)
|
||||
assert.Equal(t, NotFilteredNotFound, r.Reason)
|
||||
@@ -61,9 +62,9 @@ func TestRewrites(t *testing.T) {
|
||||
}
|
||||
d.prepareRewrites()
|
||||
r = d.processRewrites("a.host.com", dns.TypeA)
|
||||
assert.Equal(t, ReasonRewrite, r.Reason)
|
||||
assert.True(t, len(r.IPList) == 1)
|
||||
assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4")))
|
||||
assert.Equal(t, Rewritten, r.Reason)
|
||||
assert.Len(t, r.IPList, 1)
|
||||
assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4}))
|
||||
|
||||
// wildcard + CNAME
|
||||
d.Rewrites = []RewriteEntry{
|
||||
@@ -72,9 +73,9 @@ func TestRewrites(t *testing.T) {
|
||||
}
|
||||
d.prepareRewrites()
|
||||
r = d.processRewrites("www.host.com", dns.TypeA)
|
||||
assert.Equal(t, ReasonRewrite, r.Reason)
|
||||
assert.Equal(t, Rewritten, r.Reason)
|
||||
assert.Equal(t, "host.com", r.CanonName)
|
||||
assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4")))
|
||||
assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4}))
|
||||
|
||||
// 2 CNAMEs
|
||||
d.Rewrites = []RewriteEntry{
|
||||
@@ -84,10 +85,10 @@ func TestRewrites(t *testing.T) {
|
||||
}
|
||||
d.prepareRewrites()
|
||||
r = d.processRewrites("b.host.com", dns.TypeA)
|
||||
assert.Equal(t, ReasonRewrite, r.Reason)
|
||||
assert.Equal(t, Rewritten, r.Reason)
|
||||
assert.Equal(t, "host.com", r.CanonName)
|
||||
assert.True(t, len(r.IPList) == 1)
|
||||
assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4")))
|
||||
assert.Len(t, r.IPList, 1)
|
||||
assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4}))
|
||||
|
||||
// 2 CNAMEs + wildcard
|
||||
d.Rewrites = []RewriteEntry{
|
||||
@@ -97,14 +98,15 @@ func TestRewrites(t *testing.T) {
|
||||
}
|
||||
d.prepareRewrites()
|
||||
r = d.processRewrites("b.host.com", dns.TypeA)
|
||||
assert.Equal(t, ReasonRewrite, r.Reason)
|
||||
assert.Equal(t, Rewritten, r.Reason)
|
||||
assert.Equal(t, "x.somehost.com", r.CanonName)
|
||||
assert.True(t, len(r.IPList) == 1)
|
||||
assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4")))
|
||||
assert.Len(t, r.IPList, 1)
|
||||
assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4}))
|
||||
}
|
||||
|
||||
func TestRewritesLevels(t *testing.T) {
|
||||
d := DNSFilter{}
|
||||
d := newForTest(nil, nil)
|
||||
t.Cleanup(d.Close)
|
||||
// exact host, wildcard L2, wildcard L3
|
||||
d.Rewrites = []RewriteEntry{
|
||||
{"host.com", "1.1.1.1", 0, nil},
|
||||
@@ -115,25 +117,26 @@ func TestRewritesLevels(t *testing.T) {
|
||||
|
||||
// match exact
|
||||
r := d.processRewrites("host.com", dns.TypeA)
|
||||
assert.Equal(t, ReasonRewrite, r.Reason)
|
||||
assert.Equal(t, 1, len(r.IPList))
|
||||
assert.Equal(t, "1.1.1.1", r.IPList[0].String())
|
||||
assert.Equal(t, Rewritten, r.Reason)
|
||||
assert.Len(t, r.IPList, 1)
|
||||
assert.True(t, net.IP{1, 1, 1, 1}.Equal(r.IPList[0]))
|
||||
|
||||
// match L2
|
||||
r = d.processRewrites("sub.host.com", dns.TypeA)
|
||||
assert.Equal(t, ReasonRewrite, r.Reason)
|
||||
assert.Equal(t, 1, len(r.IPList))
|
||||
assert.Equal(t, "2.2.2.2", r.IPList[0].String())
|
||||
assert.Equal(t, Rewritten, r.Reason)
|
||||
assert.Len(t, r.IPList, 1)
|
||||
assert.True(t, net.IP{2, 2, 2, 2}.Equal(r.IPList[0]))
|
||||
|
||||
// match L3
|
||||
r = d.processRewrites("my.sub.host.com", dns.TypeA)
|
||||
assert.Equal(t, ReasonRewrite, r.Reason)
|
||||
assert.Equal(t, 1, len(r.IPList))
|
||||
assert.Equal(t, "3.3.3.3", r.IPList[0].String())
|
||||
assert.Equal(t, Rewritten, r.Reason)
|
||||
assert.Len(t, r.IPList, 1)
|
||||
assert.True(t, net.IP{3, 3, 3, 3}.Equal(r.IPList[0]))
|
||||
}
|
||||
|
||||
func TestRewritesExceptionCNAME(t *testing.T) {
|
||||
d := DNSFilter{}
|
||||
d := newForTest(nil, nil)
|
||||
t.Cleanup(d.Close)
|
||||
// wildcard; exception for a sub-domain
|
||||
d.Rewrites = []RewriteEntry{
|
||||
{"*.host.com", "2.2.2.2", 0, nil},
|
||||
@@ -143,9 +146,9 @@ func TestRewritesExceptionCNAME(t *testing.T) {
|
||||
|
||||
// match sub-domain
|
||||
r := d.processRewrites("my.host.com", dns.TypeA)
|
||||
assert.Equal(t, ReasonRewrite, r.Reason)
|
||||
assert.Equal(t, 1, len(r.IPList))
|
||||
assert.Equal(t, "2.2.2.2", r.IPList[0].String())
|
||||
assert.Equal(t, Rewritten, r.Reason)
|
||||
assert.Len(t, r.IPList, 1)
|
||||
assert.True(t, net.IP{2, 2, 2, 2}.Equal(r.IPList[0]))
|
||||
|
||||
// match sub-domain, but handle exception
|
||||
r = d.processRewrites("sub.host.com", dns.TypeA)
|
||||
@@ -153,7 +156,8 @@ func TestRewritesExceptionCNAME(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRewritesExceptionWC(t *testing.T) {
|
||||
d := DNSFilter{}
|
||||
d := newForTest(nil, nil)
|
||||
t.Cleanup(d.Close)
|
||||
// wildcard; exception for a sub-wildcard
|
||||
d.Rewrites = []RewriteEntry{
|
||||
{"*.host.com", "2.2.2.2", 0, nil},
|
||||
@@ -163,9 +167,9 @@ func TestRewritesExceptionWC(t *testing.T) {
|
||||
|
||||
// match sub-domain
|
||||
r := d.processRewrites("my.host.com", dns.TypeA)
|
||||
assert.Equal(t, ReasonRewrite, r.Reason)
|
||||
assert.Equal(t, 1, len(r.IPList))
|
||||
assert.Equal(t, "2.2.2.2", r.IPList[0].String())
|
||||
assert.Equal(t, Rewritten, r.Reason)
|
||||
assert.Len(t, r.IPList, 1)
|
||||
assert.True(t, net.IP{2, 2, 2, 2}.Equal(r.IPList[0]))
|
||||
|
||||
// match sub-domain, but handle exception
|
||||
r = d.processRewrites("my.sub.host.com", dns.TypeA)
|
||||
@@ -173,7 +177,8 @@ func TestRewritesExceptionWC(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRewritesExceptionIP(t *testing.T) {
|
||||
d := DNSFilter{}
|
||||
d := newForTest(nil, nil)
|
||||
t.Cleanup(d.Close)
|
||||
// exception for AAAA record
|
||||
d.Rewrites = []RewriteEntry{
|
||||
{"host.com", "1.2.3.4", 0, nil},
|
||||
@@ -186,9 +191,9 @@ func TestRewritesExceptionIP(t *testing.T) {
|
||||
|
||||
// match domain
|
||||
r := d.processRewrites("host.com", dns.TypeA)
|
||||
assert.Equal(t, ReasonRewrite, r.Reason)
|
||||
assert.Equal(t, 1, len(r.IPList))
|
||||
assert.Equal(t, "1.2.3.4", r.IPList[0].String())
|
||||
assert.Equal(t, Rewritten, r.Reason)
|
||||
assert.Len(t, r.IPList, 1)
|
||||
assert.True(t, net.IP{1, 2, 3, 4}.Equal(r.IPList[0]))
|
||||
|
||||
// match exception
|
||||
r = d.processRewrites("host.com", dns.TypeAAAA)
|
||||
@@ -200,8 +205,8 @@ func TestRewritesExceptionIP(t *testing.T) {
|
||||
|
||||
// match domain
|
||||
r = d.processRewrites("host2.com", dns.TypeAAAA)
|
||||
assert.Equal(t, ReasonRewrite, r.Reason)
|
||||
assert.Equal(t, 1, len(r.IPList))
|
||||
assert.Equal(t, Rewritten, r.Reason)
|
||||
assert.Len(t, r.IPList, 1)
|
||||
assert.Equal(t, "::1", r.IPList[0].String())
|
||||
|
||||
// match exception
|
||||
@@ -210,6 +215,6 @@ func TestRewritesExceptionIP(t *testing.T) {
|
||||
|
||||
// match domain
|
||||
r = d.processRewrites("host3.com", dns.TypeAAAA)
|
||||
assert.Equal(t, ReasonRewrite, r.Reason)
|
||||
assert.Equal(t, 0, len(r.IPList))
|
||||
assert.Equal(t, Rewritten, r.Reason)
|
||||
assert.Empty(t, r.IPList)
|
||||
}
|
||||
|
||||
@@ -30,6 +30,20 @@ const (
|
||||
pcTXTSuffix = `pc.dns.adguard.com.`
|
||||
)
|
||||
|
||||
// SetParentalUpstream sets the parental upstream for *DNSFilter.
|
||||
//
|
||||
// TODO(e.burkov): Remove this in v1 API to forbid the direct access.
|
||||
func (d *DNSFilter) SetParentalUpstream(u upstream.Upstream) {
|
||||
d.parentalUpstream = u
|
||||
}
|
||||
|
||||
// SetSafeBrowsingUpstream sets the safe browsing upstream for *DNSFilter.
|
||||
//
|
||||
// TODO(e.burkov): Remove this in v1 API to forbid the direct access.
|
||||
func (d *DNSFilter) SetSafeBrowsingUpstream(u upstream.Upstream) {
|
||||
d.safeBrowsingUpstream = u
|
||||
}
|
||||
|
||||
func (d *DNSFilter) initSecurityServices() error {
|
||||
var err error
|
||||
d.safeBrowsingServer = defaultSafebrowsingServer
|
||||
@@ -37,22 +51,24 @@ func (d *DNSFilter) initSecurityServices() error {
|
||||
opts := upstream.Options{
|
||||
Timeout: dnsTimeout,
|
||||
ServerIPAddrs: []net.IP{
|
||||
net.ParseIP("94.140.14.15"),
|
||||
net.ParseIP("94.140.15.16"),
|
||||
{94, 140, 14, 15},
|
||||
{94, 140, 15, 16},
|
||||
net.ParseIP("2a10:50c0::bad1:ff"),
|
||||
net.ParseIP("2a10:50c0::bad2:ff"),
|
||||
},
|
||||
}
|
||||
|
||||
d.parentalUpstream, err = upstream.AddressToUpstream(d.parentalServer, opts)
|
||||
parUps, err := upstream.AddressToUpstream(d.parentalServer, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("converting parental server: %w", err)
|
||||
}
|
||||
d.SetParentalUpstream(parUps)
|
||||
|
||||
d.safeBrowsingUpstream, err = upstream.AddressToUpstream(d.safeBrowsingServer, opts)
|
||||
sbUps, err := upstream.AddressToUpstream(d.safeBrowsingServer, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("converting safe browsing server: %w", err)
|
||||
}
|
||||
d.SetSafeBrowsingUpstream(sbUps)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -200,7 +216,6 @@ func (c *sbCtx) processTXT(resp *dns.Msg) (bool, [][]byte) {
|
||||
log.Debug("%s: received hashes for %s: %v", c.svc, c.host, txt.Txt)
|
||||
|
||||
for _, t := range txt.Txt {
|
||||
|
||||
if len(t) != 32*2 {
|
||||
continue
|
||||
}
|
||||
@@ -228,7 +243,7 @@ func (c *sbCtx) processTXT(resp *dns.Msg) (bool, [][]byte) {
|
||||
|
||||
func (c *sbCtx) storeCache(hashes [][]byte) {
|
||||
sort.Slice(hashes, func(a, b int) bool {
|
||||
return bytes.Compare(hashes[a], hashes[b]) < 0
|
||||
return bytes.Compare(hashes[a], hashes[b]) == -1
|
||||
})
|
||||
|
||||
var curData []byte
|
||||
@@ -346,16 +361,12 @@ func (d *DNSFilter) handleSafeBrowsingDisable(w http.ResponseWriter, r *http.Req
|
||||
}
|
||||
|
||||
func (d *DNSFilter) handleSafeBrowsingStatus(w http.ResponseWriter, r *http.Request) {
|
||||
data := map[string]interface{}{
|
||||
"enabled": d.Config.SafeBrowsingEnabled,
|
||||
}
|
||||
jsonVal, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusInternalServerError, "Unable to marshal status json: %s", err)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, err = w.Write(jsonVal)
|
||||
err := json.NewEncoder(w).Encode(&struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
}{
|
||||
Enabled: d.Config.SafeBrowsingEnabled,
|
||||
})
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err)
|
||||
return
|
||||
@@ -373,17 +384,12 @@ func (d *DNSFilter) handleParentalDisable(w http.ResponseWriter, r *http.Request
|
||||
}
|
||||
|
||||
func (d *DNSFilter) handleParentalStatus(w http.ResponseWriter, r *http.Request) {
|
||||
data := map[string]interface{}{
|
||||
"enabled": d.Config.ParentalEnabled,
|
||||
}
|
||||
jsonVal, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusInternalServerError, "Unable to marshal status json: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, err = w.Write(jsonVal)
|
||||
err := json.NewEncoder(w).Encode(&struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
}{
|
||||
Enabled: d.Config.ParentalEnabled,
|
||||
})
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err)
|
||||
return
|
||||
@@ -391,15 +397,15 @@ func (d *DNSFilter) handleParentalStatus(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
func (d *DNSFilter) registerSecurityHandlers() {
|
||||
d.Config.HTTPRegister("POST", "/control/safebrowsing/enable", d.handleSafeBrowsingEnable)
|
||||
d.Config.HTTPRegister("POST", "/control/safebrowsing/disable", d.handleSafeBrowsingDisable)
|
||||
d.Config.HTTPRegister("GET", "/control/safebrowsing/status", d.handleSafeBrowsingStatus)
|
||||
d.Config.HTTPRegister(http.MethodPost, "/control/safebrowsing/enable", d.handleSafeBrowsingEnable)
|
||||
d.Config.HTTPRegister(http.MethodPost, "/control/safebrowsing/disable", d.handleSafeBrowsingDisable)
|
||||
d.Config.HTTPRegister(http.MethodGet, "/control/safebrowsing/status", d.handleSafeBrowsingStatus)
|
||||
|
||||
d.Config.HTTPRegister("POST", "/control/parental/enable", d.handleParentalEnable)
|
||||
d.Config.HTTPRegister("POST", "/control/parental/disable", d.handleParentalDisable)
|
||||
d.Config.HTTPRegister("GET", "/control/parental/status", d.handleParentalStatus)
|
||||
d.Config.HTTPRegister(http.MethodPost, "/control/parental/enable", d.handleParentalEnable)
|
||||
d.Config.HTTPRegister(http.MethodPost, "/control/parental/disable", d.handleParentalDisable)
|
||||
d.Config.HTTPRegister(http.MethodGet, "/control/parental/status", d.handleParentalStatus)
|
||||
|
||||
d.Config.HTTPRegister("POST", "/control/safesearch/enable", d.handleSafeSearchEnable)
|
||||
d.Config.HTTPRegister("POST", "/control/safesearch/disable", d.handleSafeSearchDisable)
|
||||
d.Config.HTTPRegister("GET", "/control/safesearch/status", d.handleSafeSearchStatus)
|
||||
d.Config.HTTPRegister(http.MethodPost, "/control/safesearch/enable", d.handleSafeSearchEnable)
|
||||
d.Config.HTTPRegister(http.MethodPost, "/control/safesearch/disable", d.handleSafeSearchDisable)
|
||||
d.Config.HTTPRegister(http.MethodGet, "/control/safesearch/status", d.handleSafeSearchStatus)
|
||||
}
|
||||
|
||||
@@ -5,16 +5,15 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/golibs/cache"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestSafeBrowsingHash(t *testing.T) {
|
||||
// test hostnameToHashes()
|
||||
hashes := hostnameToHashes("1.2.3.sub.host.com")
|
||||
assert.Equal(t, 3, len(hashes))
|
||||
assert.Len(t, hashes, 3)
|
||||
_, ok := hashes[sha256.Sum256([]byte("3.sub.host.com"))]
|
||||
assert.True(t, ok)
|
||||
_, ok = hashes[sha256.Sum256([]byte("sub.host.com"))]
|
||||
@@ -31,9 +30,9 @@ func TestSafeBrowsingHash(t *testing.T) {
|
||||
|
||||
q := c.getQuestion()
|
||||
|
||||
assert.True(t, strings.Contains(q, "7a1b."))
|
||||
assert.True(t, strings.Contains(q, "af5a."))
|
||||
assert.True(t, strings.Contains(q, "eb11."))
|
||||
assert.Contains(t, q, "7a1b.")
|
||||
assert.Contains(t, q, "af5a.")
|
||||
assert.Contains(t, q, "eb11.")
|
||||
assert.True(t, strings.HasSuffix(q, "sb.dns.adguard.com."))
|
||||
}
|
||||
|
||||
@@ -81,7 +80,7 @@ func TestSafeBrowsingCache(t *testing.T) {
|
||||
c.hashToHost[hash] = "sub.host.com"
|
||||
hash = sha256.Sum256([]byte("nonexisting.com"))
|
||||
c.hashToHost[hash] = "nonexisting.com"
|
||||
assert.Equal(t, 0, c.getCached())
|
||||
assert.Empty(t, c.getCached())
|
||||
|
||||
hash = sha256.Sum256([]byte("sub.host.com"))
|
||||
_, ok := c.hashToHost[hash]
|
||||
@@ -103,30 +102,17 @@ func TestSafeBrowsingCache(t *testing.T) {
|
||||
c.hashToHost[hash] = "sub.host.com"
|
||||
|
||||
c.cache.Set(hash[0:2], make([]byte, 32))
|
||||
assert.Equal(t, 0, c.getCached())
|
||||
}
|
||||
|
||||
// testErrUpstream implements upstream.Upstream interface for replacing real
|
||||
// upstream in tests.
|
||||
type testErrUpstream struct{}
|
||||
|
||||
// Exchange always returns nil Msg and non-nil error.
|
||||
func (teu *testErrUpstream) Exchange(*dns.Msg) (*dns.Msg, error) {
|
||||
return nil, agherr.Error("bad")
|
||||
}
|
||||
|
||||
func (teu *testErrUpstream) Address() string {
|
||||
return ""
|
||||
assert.Empty(t, c.getCached())
|
||||
}
|
||||
|
||||
func TestSBPC_checkErrorUpstream(t *testing.T) {
|
||||
d := NewForTest(&Config{SafeBrowsingEnabled: true}, nil)
|
||||
defer d.Close()
|
||||
d := newForTest(&Config{SafeBrowsingEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
ups := &testErrUpstream{}
|
||||
ups := &aghtest.TestErrUpstream{}
|
||||
|
||||
d.safeBrowsingUpstream = ups
|
||||
d.parentalUpstream = ups
|
||||
d.SetSafeBrowsingUpstream(ups)
|
||||
d.SetParentalUpstream(ups)
|
||||
|
||||
_, err := d.checkSafeBrowsing("smthng.com")
|
||||
assert.NotNil(t, err)
|
||||
@@ -134,3 +120,87 @@ func TestSBPC_checkErrorUpstream(t *testing.T) {
|
||||
_, err = d.checkParental("smthng.com")
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
func TestSBPC(t *testing.T) {
|
||||
d := newForTest(&Config{SafeBrowsingEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
const hostname = "example.org"
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
block bool
|
||||
testFunc func(string) (Result, error)
|
||||
testCache cache.Cache
|
||||
}{{
|
||||
name: "sb_no_block",
|
||||
block: false,
|
||||
testFunc: d.checkSafeBrowsing,
|
||||
testCache: gctx.safebrowsingCache,
|
||||
}, {
|
||||
name: "sb_block",
|
||||
block: true,
|
||||
testFunc: d.checkSafeBrowsing,
|
||||
testCache: gctx.safebrowsingCache,
|
||||
}, {
|
||||
name: "pc_no_block",
|
||||
block: false,
|
||||
testFunc: d.checkParental,
|
||||
testCache: gctx.parentalCache,
|
||||
}, {
|
||||
name: "pc_block",
|
||||
block: true,
|
||||
testFunc: d.checkParental,
|
||||
testCache: gctx.parentalCache,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Prepare the upstream.
|
||||
ups := &aghtest.TestBlockUpstream{
|
||||
Hostname: hostname,
|
||||
Block: tc.block,
|
||||
}
|
||||
d.SetSafeBrowsingUpstream(ups)
|
||||
d.SetParentalUpstream(ups)
|
||||
|
||||
// Firstly, check the request blocking.
|
||||
hits := 0
|
||||
res, err := tc.testFunc(hostname)
|
||||
assert.Nil(t, err)
|
||||
if tc.block {
|
||||
assert.True(t, res.IsFiltered)
|
||||
assert.Len(t, res.Rules, 1)
|
||||
hits++
|
||||
} else {
|
||||
assert.False(t, res.IsFiltered)
|
||||
}
|
||||
|
||||
// Check the cache state, check the response is now cached.
|
||||
assert.Equal(t, 1, tc.testCache.Stats().Count)
|
||||
assert.Equal(t, hits, tc.testCache.Stats().Hit)
|
||||
|
||||
// There was one request to an upstream.
|
||||
assert.Equal(t, 1, ups.RequestsCount())
|
||||
|
||||
// Now make the same request to check the cache was used.
|
||||
res, err = tc.testFunc(hostname)
|
||||
assert.Nil(t, err)
|
||||
if tc.block {
|
||||
assert.True(t, res.IsFiltered)
|
||||
assert.Len(t, res.Rules, 1)
|
||||
} else {
|
||||
assert.False(t, res.IsFiltered)
|
||||
}
|
||||
|
||||
// Check the cache state, it should've been used.
|
||||
assert.Equal(t, 1, tc.testCache.Stats().Count)
|
||||
assert.Equal(t, hits+1, tc.testCache.Stats().Hit)
|
||||
|
||||
// Check that there were no additional requests.
|
||||
assert.Equal(t, 1, ups.RequestsCount())
|
||||
|
||||
purgeCaches()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package dnsfilter
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"encoding/gob"
|
||||
"encoding/json"
|
||||
@@ -101,15 +102,14 @@ func (d *DNSFilter) checkSafeSearch(host string) (Result, error) {
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// TODO this address should be resolved with upstream that was configured in dnsforward
|
||||
ips, err := net.LookupIP(safeHost)
|
||||
ipAddrs, err := d.resolver.LookupIPAddr(context.Background(), safeHost)
|
||||
if err != nil {
|
||||
log.Tracef("SafeSearchDomain for %s was found but failed to lookup for %s cause %s", host, safeHost, err)
|
||||
return Result{}, err
|
||||
}
|
||||
|
||||
for _, ip := range ips {
|
||||
if ipv4 := ip.To4(); ipv4 != nil {
|
||||
for _, ipAddr := range ipAddrs {
|
||||
if ipv4 := ipAddr.IP.To4(); ipv4 != nil {
|
||||
res.Rules[0].IP = ipv4
|
||||
|
||||
l := d.setCacheResult(gctx.safeSearchCache, host, res)
|
||||
@@ -133,17 +133,12 @@ func (d *DNSFilter) handleSafeSearchDisable(w http.ResponseWriter, r *http.Reque
|
||||
}
|
||||
|
||||
func (d *DNSFilter) handleSafeSearchStatus(w http.ResponseWriter, r *http.Request) {
|
||||
data := map[string]interface{}{
|
||||
"enabled": d.Config.SafeSearchEnabled,
|
||||
}
|
||||
jsonVal, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusInternalServerError, "Unable to marshal status json: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, err = w.Write(jsonVal)
|
||||
err := json.NewEncoder(w).Encode(&struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
}{
|
||||
Enabled: d.Config.SafeSearchEnabled,
|
||||
})
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err)
|
||||
return
|
||||
|
||||
@@ -83,20 +83,21 @@ func processIPCIDRArray(dst *map[string]bool, dstIPNet *[]net.IPNet, src []strin
|
||||
// Returns the item from the "disallowedClients" list that lead to blocking IP.
|
||||
// If it returns TRUE and an empty string, it means that the "allowedClients" is not empty,
|
||||
// but the ip does not belong to it.
|
||||
func (a *accessCtx) IsBlockedIP(ip string) (bool, string) {
|
||||
func (a *accessCtx) IsBlockedIP(ip net.IP) (bool, string) {
|
||||
ipStr := ip.String()
|
||||
|
||||
a.lock.Lock()
|
||||
defer a.lock.Unlock()
|
||||
|
||||
if len(a.allowedClients) != 0 || len(a.allowedClientsIPNet) != 0 {
|
||||
_, ok := a.allowedClients[ip]
|
||||
_, ok := a.allowedClients[ipStr]
|
||||
if ok {
|
||||
return false, ""
|
||||
}
|
||||
|
||||
if len(a.allowedClientsIPNet) != 0 {
|
||||
ipAddr := net.ParseIP(ip)
|
||||
for _, ipnet := range a.allowedClientsIPNet {
|
||||
if ipnet.Contains(ipAddr) {
|
||||
if ipnet.Contains(ip) {
|
||||
return false, ""
|
||||
}
|
||||
}
|
||||
@@ -105,15 +106,14 @@ func (a *accessCtx) IsBlockedIP(ip string) (bool, string) {
|
||||
return true, ""
|
||||
}
|
||||
|
||||
_, ok := a.disallowedClients[ip]
|
||||
_, ok := a.disallowedClients[ipStr]
|
||||
if ok {
|
||||
return true, ip
|
||||
return true, ipStr
|
||||
}
|
||||
|
||||
if len(a.disallowedClientsIPNet) != 0 {
|
||||
ipAddr := net.ParseIP(ip)
|
||||
for _, ipnet := range a.disallowedClientsIPNet {
|
||||
if ipnet.Contains(ipAddr) {
|
||||
if ipnet.Contains(ip) {
|
||||
return true, ipnet.String()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -8,44 +9,44 @@ import (
|
||||
|
||||
func TestIsBlockedIPAllowed(t *testing.T) {
|
||||
a := &accessCtx{}
|
||||
assert.True(t, a.Init([]string{"1.1.1.1", "2.2.0.0/16"}, nil, nil) == nil)
|
||||
assert.Nil(t, a.Init([]string{"1.1.1.1", "2.2.0.0/16"}, nil, nil))
|
||||
|
||||
disallowed, disallowedRule := a.IsBlockedIP("1.1.1.1")
|
||||
disallowed, disallowedRule := a.IsBlockedIP(net.IPv4(1, 1, 1, 1))
|
||||
assert.False(t, disallowed)
|
||||
assert.Equal(t, "", disallowedRule)
|
||||
assert.Empty(t, disallowedRule)
|
||||
|
||||
disallowed, disallowedRule = a.IsBlockedIP("1.1.1.2")
|
||||
disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(1, 1, 1, 2))
|
||||
assert.True(t, disallowed)
|
||||
assert.Equal(t, "", disallowedRule)
|
||||
assert.Empty(t, disallowedRule)
|
||||
|
||||
disallowed, disallowedRule = a.IsBlockedIP("2.2.1.1")
|
||||
disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(2, 2, 1, 1))
|
||||
assert.False(t, disallowed)
|
||||
assert.Equal(t, "", disallowedRule)
|
||||
assert.Empty(t, disallowedRule)
|
||||
|
||||
disallowed, disallowedRule = a.IsBlockedIP("2.3.1.1")
|
||||
disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(2, 3, 1, 1))
|
||||
assert.True(t, disallowed)
|
||||
assert.Equal(t, "", disallowedRule)
|
||||
assert.Empty(t, disallowedRule)
|
||||
}
|
||||
|
||||
func TestIsBlockedIPDisallowed(t *testing.T) {
|
||||
a := &accessCtx{}
|
||||
assert.True(t, a.Init(nil, []string{"1.1.1.1", "2.2.0.0/16"}, nil) == nil)
|
||||
assert.Nil(t, a.Init(nil, []string{"1.1.1.1", "2.2.0.0/16"}, nil))
|
||||
|
||||
disallowed, disallowedRule := a.IsBlockedIP("1.1.1.1")
|
||||
disallowed, disallowedRule := a.IsBlockedIP(net.IPv4(1, 1, 1, 1))
|
||||
assert.True(t, disallowed)
|
||||
assert.Equal(t, "1.1.1.1", disallowedRule)
|
||||
|
||||
disallowed, disallowedRule = a.IsBlockedIP("1.1.1.2")
|
||||
disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(1, 1, 1, 2))
|
||||
assert.False(t, disallowed)
|
||||
assert.Equal(t, "", disallowedRule)
|
||||
assert.Empty(t, disallowedRule)
|
||||
|
||||
disallowed, disallowedRule = a.IsBlockedIP("2.2.1.1")
|
||||
disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(2, 2, 1, 1))
|
||||
assert.True(t, disallowed)
|
||||
assert.Equal(t, "2.2.0.0/16", disallowedRule)
|
||||
|
||||
disallowed, disallowedRule = a.IsBlockedIP("2.3.1.1")
|
||||
disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(2, 3, 1, 1))
|
||||
assert.False(t, disallowed)
|
||||
assert.Equal(t, "", disallowedRule)
|
||||
assert.Empty(t, disallowedRule)
|
||||
}
|
||||
|
||||
func TestIsBlockedIPBlockedDomain(t *testing.T) {
|
||||
@@ -60,13 +61,13 @@ func TestIsBlockedIPBlockedDomain(t *testing.T) {
|
||||
// match by "host2.com"
|
||||
assert.True(t, a.IsBlockedDomain("host1"))
|
||||
assert.True(t, a.IsBlockedDomain("host2"))
|
||||
assert.True(t, !a.IsBlockedDomain("host3"))
|
||||
assert.False(t, a.IsBlockedDomain("host3"))
|
||||
|
||||
// match by wildcard "*.host.com"
|
||||
assert.True(t, !a.IsBlockedDomain("host.com"))
|
||||
assert.False(t, a.IsBlockedDomain("host.com"))
|
||||
assert.True(t, a.IsBlockedDomain("asdf.host.com"))
|
||||
assert.True(t, a.IsBlockedDomain("qwer.asdf.host.com"))
|
||||
assert.True(t, !a.IsBlockedDomain("asdf.zhost.com"))
|
||||
assert.False(t, a.IsBlockedDomain("asdf.zhost.com"))
|
||||
|
||||
// match by wildcard "||host3.com^"
|
||||
assert.True(t, a.IsBlockedDomain("host3.com"))
|
||||
|
||||
165
internal/dnsforward/clientid.go
Normal file
165
internal/dnsforward/clientid.go
Normal file
@@ -0,0 +1,165 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"path"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
)
|
||||
|
||||
const maxDomainPartLen = 64
|
||||
|
||||
// ValidateClientID returns an error if clientID is not a valid client ID.
|
||||
func ValidateClientID(clientID string) (err error) {
|
||||
if len(clientID) > maxDomainPartLen {
|
||||
return fmt.Errorf("client id %q is too long, max: %d", clientID, maxDomainPartLen)
|
||||
}
|
||||
|
||||
for i, r := range clientID {
|
||||
if (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '-' {
|
||||
continue
|
||||
}
|
||||
|
||||
return fmt.Errorf("invalid char %q at index %d in client id %q", r, i, clientID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// clientIDFromClientServerName extracts and validates a client ID. hostSrvName
|
||||
// is the server name of the host. cliSrvName is the server name as sent by the
|
||||
// client. When strict is true, and client and host server name don't match,
|
||||
// clientIDFromClientServerName will return an error.
|
||||
func clientIDFromClientServerName(hostSrvName, cliSrvName string, strict bool) (clientID string, err error) {
|
||||
if hostSrvName == cliSrvName {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
if !strings.HasSuffix(cliSrvName, hostSrvName) {
|
||||
if !strict {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("client server name %q doesn't match host server name %q", cliSrvName, hostSrvName)
|
||||
}
|
||||
|
||||
clientID = cliSrvName[:len(cliSrvName)-len(hostSrvName)-1]
|
||||
err = ValidateClientID(clientID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid client id: %w", err)
|
||||
}
|
||||
|
||||
return clientID, nil
|
||||
}
|
||||
|
||||
// processClientIDHTTPS extracts the client's ID from the path of the
|
||||
// client's DNS-over-HTTPS request.
|
||||
func processClientIDHTTPS(ctx *dnsContext) (rc resultCode) {
|
||||
pctx := ctx.proxyCtx
|
||||
r := pctx.HTTPRequest
|
||||
if r == nil {
|
||||
ctx.err = fmt.Errorf("proxy ctx http request of proto %s is nil", pctx.Proto)
|
||||
|
||||
return resultCodeError
|
||||
}
|
||||
|
||||
origPath := r.URL.Path
|
||||
parts := strings.Split(path.Clean(origPath), "/")
|
||||
if parts[0] == "" {
|
||||
parts = parts[1:]
|
||||
}
|
||||
|
||||
if len(parts) == 0 || parts[0] != "dns-query" {
|
||||
ctx.err = fmt.Errorf("client id check: invalid path %q", origPath)
|
||||
|
||||
return resultCodeError
|
||||
}
|
||||
|
||||
clientID := ""
|
||||
switch len(parts) {
|
||||
case 1:
|
||||
// Just /dns-query, no client ID.
|
||||
return resultCodeSuccess
|
||||
case 2:
|
||||
clientID = parts[1]
|
||||
default:
|
||||
ctx.err = fmt.Errorf("client id check: invalid path %q: extra parts", origPath)
|
||||
|
||||
return resultCodeError
|
||||
}
|
||||
|
||||
err := ValidateClientID(clientID)
|
||||
if err != nil {
|
||||
ctx.err = fmt.Errorf("client id check: invalid client id: %w", err)
|
||||
|
||||
return resultCodeError
|
||||
}
|
||||
|
||||
ctx.clientID = clientID
|
||||
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
// tlsConn is a narrow interface for *tls.Conn to simplify testing.
|
||||
type tlsConn interface {
|
||||
ConnectionState() (cs tls.ConnectionState)
|
||||
}
|
||||
|
||||
// quicSession is a narrow interface for quic.Session to simplify testing.
|
||||
type quicSession interface {
|
||||
ConnectionState() (cs quic.ConnectionState)
|
||||
}
|
||||
|
||||
// processClientID extracts the client's ID from the server name of the client's
|
||||
// DOT or DOQ request or the path of the client's DOH.
|
||||
func processClientID(dctx *dnsContext) (rc resultCode) {
|
||||
pctx := dctx.proxyCtx
|
||||
proto := pctx.Proto
|
||||
if proto == proxy.ProtoHTTPS {
|
||||
return processClientIDHTTPS(dctx)
|
||||
} else if proto != proxy.ProtoTLS && proto != proxy.ProtoQUIC {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
srvConf := dctx.srv.conf
|
||||
hostSrvName := srvConf.TLSConfig.ServerName
|
||||
if hostSrvName == "" {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
cliSrvName := ""
|
||||
if proto == proxy.ProtoTLS {
|
||||
conn := pctx.Conn
|
||||
tc, ok := conn.(tlsConn)
|
||||
if !ok {
|
||||
dctx.err = fmt.Errorf("proxy ctx conn of proto %s is %T, want *tls.Conn", proto, conn)
|
||||
|
||||
return resultCodeError
|
||||
}
|
||||
|
||||
cliSrvName = tc.ConnectionState().ServerName
|
||||
} else if proto == proxy.ProtoQUIC {
|
||||
qs, ok := pctx.QUICSession.(quicSession)
|
||||
if !ok {
|
||||
dctx.err = fmt.Errorf("proxy ctx quic session of proto %s is %T, want quic.Session", proto, pctx.QUICSession)
|
||||
|
||||
return resultCodeError
|
||||
}
|
||||
|
||||
cliSrvName = qs.ConnectionState().ServerName
|
||||
}
|
||||
|
||||
clientID, err := clientIDFromClientServerName(hostSrvName, cliSrvName, srvConf.StrictSNICheck)
|
||||
if err != nil {
|
||||
dctx.err = fmt.Errorf("client id check: %w", err)
|
||||
|
||||
return resultCodeError
|
||||
}
|
||||
|
||||
dctx.clientID = clientID
|
||||
|
||||
return resultCodeSuccess
|
||||
}
|
||||
273
internal/dnsforward/clientid_test.go
Normal file
273
internal/dnsforward/clientid_test.go
Normal file
@@ -0,0 +1,273 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// testTLSConn is a tlsConn for tests.
|
||||
type testTLSConn struct {
|
||||
// Conn is embedded here simply to make testTLSConn a net.Conn without
|
||||
// acctually implementing all methods.
|
||||
net.Conn
|
||||
|
||||
serverName string
|
||||
}
|
||||
|
||||
// ConnectionState implements the tlsConn interface for testTLSConn.
|
||||
func (c testTLSConn) ConnectionState() (cs tls.ConnectionState) {
|
||||
cs.ServerName = c.serverName
|
||||
|
||||
return cs
|
||||
}
|
||||
|
||||
// testQUICSession is a quicSession for tests.
|
||||
type testQUICSession struct {
|
||||
// Session is embedded here simply to make testQUICSession
|
||||
// a quic.Session without acctually implementing all methods.
|
||||
quic.Session
|
||||
|
||||
serverName string
|
||||
}
|
||||
|
||||
// ConnectionState implements the quicSession interface for testQUICSession.
|
||||
func (c testQUICSession) ConnectionState() (cs quic.ConnectionState) {
|
||||
cs.ServerName = c.serverName
|
||||
|
||||
return cs
|
||||
}
|
||||
|
||||
func TestProcessClientID(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
proto string
|
||||
hostSrvName string
|
||||
cliSrvName string
|
||||
wantClientID string
|
||||
wantErrMsg string
|
||||
wantRes resultCode
|
||||
strictSNI bool
|
||||
}{{
|
||||
name: "udp",
|
||||
proto: proxy.ProtoUDP,
|
||||
hostSrvName: "",
|
||||
cliSrvName: "",
|
||||
wantClientID: "",
|
||||
wantErrMsg: "",
|
||||
wantRes: resultCodeSuccess,
|
||||
strictSNI: false,
|
||||
}, {
|
||||
name: "tls_no_client_id",
|
||||
proto: proxy.ProtoTLS,
|
||||
hostSrvName: "example.com",
|
||||
cliSrvName: "example.com",
|
||||
wantClientID: "",
|
||||
wantErrMsg: "",
|
||||
wantRes: resultCodeSuccess,
|
||||
strictSNI: true,
|
||||
}, {
|
||||
name: "tls_no_client_server_name",
|
||||
proto: proxy.ProtoTLS,
|
||||
hostSrvName: "example.com",
|
||||
cliSrvName: "",
|
||||
wantClientID: "",
|
||||
wantErrMsg: `client id check: client server name "" ` +
|
||||
`doesn't match host server name "example.com"`,
|
||||
wantRes: resultCodeError,
|
||||
strictSNI: true,
|
||||
}, {
|
||||
name: "tls_no_client_server_name_no_strict",
|
||||
proto: proxy.ProtoTLS,
|
||||
hostSrvName: "example.com",
|
||||
cliSrvName: "",
|
||||
wantClientID: "",
|
||||
wantErrMsg: "",
|
||||
wantRes: resultCodeSuccess,
|
||||
strictSNI: false,
|
||||
}, {
|
||||
name: "tls_client_id",
|
||||
proto: proxy.ProtoTLS,
|
||||
hostSrvName: "example.com",
|
||||
cliSrvName: "cli.example.com",
|
||||
wantClientID: "cli",
|
||||
wantErrMsg: "",
|
||||
wantRes: resultCodeSuccess,
|
||||
strictSNI: true,
|
||||
}, {
|
||||
name: "tls_client_id_hostname_error",
|
||||
proto: proxy.ProtoTLS,
|
||||
hostSrvName: "example.com",
|
||||
cliSrvName: "cli.example.net",
|
||||
wantClientID: "",
|
||||
wantErrMsg: `client id check: client server name "cli.example.net" ` +
|
||||
`doesn't match host server name "example.com"`,
|
||||
wantRes: resultCodeError,
|
||||
strictSNI: true,
|
||||
}, {
|
||||
name: "tls_invalid_client_id",
|
||||
proto: proxy.ProtoTLS,
|
||||
hostSrvName: "example.com",
|
||||
cliSrvName: "!!!.example.com",
|
||||
wantClientID: "",
|
||||
wantErrMsg: `client id check: invalid client id: invalid char '!' ` +
|
||||
`at index 0 in client id "!!!"`,
|
||||
wantRes: resultCodeError,
|
||||
strictSNI: true,
|
||||
}, {
|
||||
name: "tls_client_id_too_long",
|
||||
proto: proxy.ProtoTLS,
|
||||
hostSrvName: "example.com",
|
||||
cliSrvName: `abcdefghijklmnopqrstuvwxyz0123456789abcdefghijklmno` +
|
||||
`pqrstuvwxyz0123456789.example.com`,
|
||||
wantClientID: "",
|
||||
wantErrMsg: `client id check: invalid client id: client id "abcdefghijklmno` +
|
||||
`pqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz0123456789" ` +
|
||||
`is too long, max: 64`,
|
||||
wantRes: resultCodeError,
|
||||
strictSNI: true,
|
||||
}, {
|
||||
name: "quic_client_id",
|
||||
proto: proxy.ProtoQUIC,
|
||||
hostSrvName: "example.com",
|
||||
cliSrvName: "cli.example.com",
|
||||
wantClientID: "cli",
|
||||
wantErrMsg: "",
|
||||
wantRes: resultCodeSuccess,
|
||||
strictSNI: true,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tlsConf := TLSConfig{
|
||||
ServerName: tc.hostSrvName,
|
||||
StrictSNICheck: tc.strictSNI,
|
||||
}
|
||||
srv := &Server{
|
||||
conf: ServerConfig{TLSConfig: tlsConf},
|
||||
}
|
||||
|
||||
var conn net.Conn
|
||||
if tc.proto == proxy.ProtoTLS {
|
||||
conn = testTLSConn{
|
||||
serverName: tc.cliSrvName,
|
||||
}
|
||||
}
|
||||
|
||||
var qs quic.Session
|
||||
if tc.proto == proxy.ProtoQUIC {
|
||||
qs = testQUICSession{
|
||||
serverName: tc.cliSrvName,
|
||||
}
|
||||
}
|
||||
|
||||
dctx := &dnsContext{
|
||||
srv: srv,
|
||||
proxyCtx: &proxy.DNSContext{
|
||||
Proto: tc.proto,
|
||||
Conn: conn,
|
||||
QUICSession: qs,
|
||||
},
|
||||
}
|
||||
|
||||
res := processClientID(dctx)
|
||||
assert.Equal(t, tc.wantRes, res)
|
||||
assert.Equal(t, tc.wantClientID, dctx.clientID)
|
||||
|
||||
if tc.wantErrMsg == "" {
|
||||
assert.Nil(t, dctx.err)
|
||||
} else {
|
||||
require.NotNil(t, dctx.err)
|
||||
assert.Equal(t, tc.wantErrMsg, dctx.err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessClientID_https(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
path string
|
||||
wantClientID string
|
||||
wantErrMsg string
|
||||
wantRes resultCode
|
||||
}{{
|
||||
name: "no_client_id",
|
||||
path: "/dns-query",
|
||||
wantClientID: "",
|
||||
wantErrMsg: "",
|
||||
wantRes: resultCodeSuccess,
|
||||
}, {
|
||||
name: "no_client_id_slash",
|
||||
path: "/dns-query/",
|
||||
wantClientID: "",
|
||||
wantErrMsg: "",
|
||||
wantRes: resultCodeSuccess,
|
||||
}, {
|
||||
name: "client_id",
|
||||
path: "/dns-query/cli",
|
||||
wantClientID: "cli",
|
||||
wantErrMsg: "",
|
||||
wantRes: resultCodeSuccess,
|
||||
}, {
|
||||
name: "client_id_slash",
|
||||
path: "/dns-query/cli/",
|
||||
wantClientID: "cli",
|
||||
wantErrMsg: "",
|
||||
wantRes: resultCodeSuccess,
|
||||
}, {
|
||||
name: "bad_url",
|
||||
path: "/foo",
|
||||
wantClientID: "",
|
||||
wantErrMsg: `client id check: invalid path "/foo"`,
|
||||
wantRes: resultCodeError,
|
||||
}, {
|
||||
name: "extra",
|
||||
path: "/dns-query/cli/foo",
|
||||
wantClientID: "",
|
||||
wantErrMsg: `client id check: invalid path "/dns-query/cli/foo": extra parts`,
|
||||
wantRes: resultCodeError,
|
||||
}, {
|
||||
name: "invalid_client_id",
|
||||
path: "/dns-query/!!!",
|
||||
wantClientID: "",
|
||||
wantErrMsg: `client id check: invalid client id: invalid char '!'` +
|
||||
` at index 0 in client id "!!!"`,
|
||||
wantRes: resultCodeError,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
r := &http.Request{
|
||||
URL: &url.URL{
|
||||
Path: tc.path,
|
||||
},
|
||||
}
|
||||
|
||||
dctx := &dnsContext{
|
||||
proxyCtx: &proxy.DNSContext{
|
||||
Proto: proxy.ProtoHTTPS,
|
||||
HTTPRequest: r,
|
||||
},
|
||||
}
|
||||
|
||||
res := processClientID(dctx)
|
||||
assert.Equal(t, tc.wantRes, res)
|
||||
assert.Equal(t, tc.wantClientID, dctx.clientID)
|
||||
|
||||
if tc.wantErrMsg == "" {
|
||||
assert.Nil(t, dctx.err)
|
||||
} else {
|
||||
require.NotNil(t, dctx.err)
|
||||
assert.Equal(t, tc.wantErrMsg, dctx.err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -24,22 +24,22 @@ type FilteringConfig struct {
|
||||
// Callbacks for other modules
|
||||
// --
|
||||
|
||||
// Filtering callback function
|
||||
FilterHandler func(clientAddr string, settings *dnsfilter.RequestFilteringSettings) `yaml:"-"`
|
||||
// FilterHandler is an optional additional filtering callback.
|
||||
FilterHandler func(clientAddr net.IP, clientID string, settings *dnsfilter.RequestFilteringSettings) `yaml:"-"`
|
||||
|
||||
// GetCustomUpstreamByClient - a callback function that returns upstreams configuration
|
||||
// based on the client IP address. Returns nil if there are no custom upstreams for the client
|
||||
//
|
||||
// TODO(e.burkov): Replace argument type with net.IP.
|
||||
GetCustomUpstreamByClient func(clientAddr string) *proxy.UpstreamConfig `yaml:"-"`
|
||||
|
||||
// Protection configuration
|
||||
// --
|
||||
|
||||
ProtectionEnabled bool `yaml:"protection_enabled"` // whether or not use any of dnsfilter features
|
||||
BlockingMode string `yaml:"blocking_mode"` // mode how to answer filtered requests
|
||||
BlockingIPv4 string `yaml:"blocking_ipv4"` // IP address to be returned for a blocked A request
|
||||
BlockingIPv6 string `yaml:"blocking_ipv6"` // IP address to be returned for a blocked AAAA request
|
||||
BlockingIPAddrv4 net.IP `yaml:"-"`
|
||||
BlockingIPAddrv6 net.IP `yaml:"-"`
|
||||
ProtectionEnabled bool `yaml:"protection_enabled"` // whether or not use any of dnsfilter features
|
||||
BlockingMode string `yaml:"blocking_mode"` // mode how to answer filtered requests
|
||||
BlockingIPv4 net.IP `yaml:"blocking_ipv4"` // IP address to be returned for a blocked A request
|
||||
BlockingIPv6 net.IP `yaml:"blocking_ipv6"` // IP address to be returned for a blocked AAAA request
|
||||
BlockedResponseTTL uint32 `yaml:"blocked_response_ttl"` // if 0, then default is used (3600)
|
||||
|
||||
// IP (or domain name) which is used to respond to DNS requests blocked by parental control or safe-browsing
|
||||
@@ -110,6 +110,10 @@ type TLSConfig struct {
|
||||
CertificateChainData []byte `yaml:"-" json:"-"`
|
||||
PrivateKeyData []byte `yaml:"-" json:"-"`
|
||||
|
||||
// ServerName is the hostname of the server. Currently, it is only
|
||||
// being used for client ID checking.
|
||||
ServerName string `yaml:"-" json:"-"`
|
||||
|
||||
cert tls.Certificate
|
||||
// DNS names from certificate (SAN) or CN value from Subject
|
||||
dnsNames []string
|
||||
@@ -278,7 +282,7 @@ func (s *Server) prepareUpstreamSettings() error {
|
||||
}
|
||||
|
||||
if len(upstreamConfig.Upstreams) == 0 {
|
||||
log.Info("Warning: no default upstream servers specified, using %v", defaultDNS)
|
||||
log.Info("warning: no default upstream servers specified, using %v", defaultDNS)
|
||||
uc, err := proxy.ParseUpstreamsConfig(defaultDNS, s.conf.BootstrapDNS, DefaultTimeout)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dns: failed to parse default upstreams: %v", err)
|
||||
@@ -292,12 +296,13 @@ func (s *Server) prepareUpstreamSettings() error {
|
||||
|
||||
// prepareIntlProxy - initializes DNS proxy that we use for internal DNS queries
|
||||
func (s *Server) prepareIntlProxy() {
|
||||
intlProxyConfig := proxy.Config{
|
||||
CacheEnabled: true,
|
||||
CacheSizeBytes: 4096,
|
||||
UpstreamConfig: s.conf.UpstreamConfig,
|
||||
s.internalProxy = &proxy.Proxy{
|
||||
Config: proxy.Config{
|
||||
CacheEnabled: true,
|
||||
CacheSizeBytes: 4096,
|
||||
UpstreamConfig: s.conf.UpstreamConfig,
|
||||
},
|
||||
}
|
||||
s.internalProxy = &proxy.Proxy{Config: intlProxyConfig}
|
||||
}
|
||||
|
||||
// prepareTLS - prepares TLS configuration for the DNS proxy
|
||||
|
||||
@@ -15,36 +15,69 @@ import (
|
||||
|
||||
// To transfer information between modules
|
||||
type dnsContext struct {
|
||||
srv *Server
|
||||
proxyCtx *proxy.DNSContext
|
||||
setts *dnsfilter.RequestFilteringSettings // filtering settings for this client
|
||||
startTime time.Time
|
||||
result *dnsfilter.Result
|
||||
origResp *dns.Msg // response received from upstream servers. Set when response is modified by filtering
|
||||
origQuestion dns.Question // question received from client. Set when Rewrites are used.
|
||||
err error // error returned from the module
|
||||
protectionEnabled bool // filtering is enabled, dnsfilter object is ready
|
||||
responseFromUpstream bool // response is received from upstream servers
|
||||
origReqDNSSEC bool // DNSSEC flag in the original request from user
|
||||
srv *Server
|
||||
proxyCtx *proxy.DNSContext
|
||||
// setts are the filtering settings for the client.
|
||||
setts *dnsfilter.RequestFilteringSettings
|
||||
startTime time.Time
|
||||
result *dnsfilter.Result
|
||||
// origResp is the response received from upstream. It is set when the
|
||||
// response is modified by filters.
|
||||
origResp *dns.Msg
|
||||
// err is the error returned from a processing function.
|
||||
err error
|
||||
// clientID is the clientID from DOH, DOQ, or DOT, if provided.
|
||||
clientID string
|
||||
// origQuestion is the question received from the client. It is set
|
||||
// when the request is modified by rewrites.
|
||||
origQuestion dns.Question
|
||||
// protectionEnabled shows if the filtering is enabled, and if the
|
||||
// server's DNS filter is ready.
|
||||
protectionEnabled bool
|
||||
// responseFromUpstream shows if the response is received from the
|
||||
// upstream servers.
|
||||
responseFromUpstream bool
|
||||
// origReqDNSSEC shows if the DNSSEC flag in the original request from
|
||||
// the client is set.
|
||||
origReqDNSSEC bool
|
||||
}
|
||||
|
||||
// resultCode is the result of a request processing function.
|
||||
type resultCode int
|
||||
|
||||
const (
|
||||
resultDone = iota // module has completed its job, continue
|
||||
resultFinish // module has completed its job, exit normally
|
||||
resultError // an error occurred, exit with an error
|
||||
// resultCodeSuccess is returned when a handler performed successfully,
|
||||
// and the next handler must be called.
|
||||
resultCodeSuccess resultCode = iota
|
||||
// resultCodeFinish is returned when a handler performed successfully,
|
||||
// and the processing of the request must be stopped.
|
||||
resultCodeFinish
|
||||
// resultCodeError is returned when a handler failed, and the processing
|
||||
// of the request must be stopped.
|
||||
resultCodeError
|
||||
)
|
||||
|
||||
// handleDNSRequest filters the incoming DNS requests and writes them to the query log
|
||||
func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
|
||||
ctx := &dnsContext{srv: s, proxyCtx: d}
|
||||
ctx.result = &dnsfilter.Result{}
|
||||
ctx.startTime = time.Now()
|
||||
ctx := &dnsContext{
|
||||
srv: s,
|
||||
proxyCtx: d,
|
||||
result: &dnsfilter.Result{},
|
||||
startTime: time.Now(),
|
||||
}
|
||||
|
||||
type modProcessFunc func(ctx *dnsContext) int
|
||||
type modProcessFunc func(ctx *dnsContext) (rc resultCode)
|
||||
|
||||
// Since (*dnsforward.Server).handleDNSRequest(...) is used as
|
||||
// proxy.(Config).RequestHandler, there is no need for additional index
|
||||
// out of range checking in any of the following functions, because the
|
||||
// (*proxy.Proxy).handleDNSRequest method performs it before calling the
|
||||
// appropriate handler.
|
||||
mods := []modProcessFunc{
|
||||
processInitial,
|
||||
processInternalHosts,
|
||||
processInternalIPAddrs,
|
||||
processClientID,
|
||||
processFilteringBeforeRequest,
|
||||
processUpstream,
|
||||
processDNSSECAfterResponse,
|
||||
@@ -55,13 +88,13 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
|
||||
for _, process := range mods {
|
||||
r := process(ctx)
|
||||
switch r {
|
||||
case resultDone:
|
||||
case resultCodeSuccess:
|
||||
// continue: call the next filter
|
||||
|
||||
case resultFinish:
|
||||
case resultCodeFinish:
|
||||
return nil
|
||||
|
||||
case resultError:
|
||||
case resultCodeError:
|
||||
return ctx.err
|
||||
}
|
||||
}
|
||||
@@ -73,12 +106,12 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
|
||||
}
|
||||
|
||||
// Perform initial checks; process WHOIS & rDNS
|
||||
func processInitial(ctx *dnsContext) int {
|
||||
func processInitial(ctx *dnsContext) (rc resultCode) {
|
||||
s := ctx.srv
|
||||
d := ctx.proxyCtx
|
||||
if s.conf.AAAADisabled && d.Req.Question[0].Qtype == dns.TypeAAAA {
|
||||
_ = proxy.CheckDisabledAAAARequest(d, true)
|
||||
return resultFinish
|
||||
return resultCodeFinish
|
||||
}
|
||||
|
||||
if s.conf.OnDNSRequest != nil {
|
||||
@@ -90,10 +123,10 @@ func processInitial(ctx *dnsContext) int {
|
||||
if (d.Req.Question[0].Qtype == dns.TypeA || d.Req.Question[0].Qtype == dns.TypeAAAA) &&
|
||||
d.Req.Question[0].Name == "use-application-dns.net." {
|
||||
d.Res = s.genNXDomain(d.Req)
|
||||
return resultFinish
|
||||
return resultCodeFinish
|
||||
}
|
||||
|
||||
return resultDone
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
// Return TRUE if host names doesn't contain disallowed characters
|
||||
@@ -151,32 +184,32 @@ func (s *Server) onDHCPLeaseChanged(flags int) {
|
||||
}
|
||||
|
||||
// Respond to A requests if the target host name is associated with a lease from our DHCP server
|
||||
func processInternalHosts(ctx *dnsContext) int {
|
||||
func processInternalHosts(ctx *dnsContext) (rc resultCode) {
|
||||
s := ctx.srv
|
||||
req := ctx.proxyCtx.Req
|
||||
if !(req.Question[0].Qtype == dns.TypeA || req.Question[0].Qtype == dns.TypeAAAA) {
|
||||
return resultDone
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
host := req.Question[0].Name
|
||||
host = strings.ToLower(host)
|
||||
if !strings.HasSuffix(host, ".lan.") {
|
||||
return resultDone
|
||||
return resultCodeSuccess
|
||||
}
|
||||
host = strings.TrimSuffix(host, ".lan.")
|
||||
|
||||
s.tableHostToIPLock.Lock()
|
||||
if s.tableHostToIP == nil {
|
||||
s.tableHostToIPLock.Unlock()
|
||||
return resultDone
|
||||
return resultCodeSuccess
|
||||
}
|
||||
ip, ok := s.tableHostToIP[host]
|
||||
s.tableHostToIPLock.Unlock()
|
||||
if !ok {
|
||||
return resultDone
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
log.Debug("DNS: internal record: %s -> %s", req.Question[0].Name, ip.String())
|
||||
log.Debug("DNS: internal record: %s -> %s", req.Question[0].Name, ip)
|
||||
|
||||
resp := s.makeResponse(req)
|
||||
|
||||
@@ -194,15 +227,15 @@ func processInternalHosts(ctx *dnsContext) int {
|
||||
}
|
||||
|
||||
ctx.proxyCtx.Res = resp
|
||||
return resultDone
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
// Respond to PTR requests if the target IP address is leased by our DHCP server
|
||||
func processInternalIPAddrs(ctx *dnsContext) int {
|
||||
func processInternalIPAddrs(ctx *dnsContext) (rc resultCode) {
|
||||
s := ctx.srv
|
||||
req := ctx.proxyCtx.Req
|
||||
if req.Question[0].Qtype != dns.TypePTR {
|
||||
return resultDone
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
arpa := req.Question[0].Name
|
||||
@@ -210,18 +243,18 @@ func processInternalIPAddrs(ctx *dnsContext) int {
|
||||
arpa = strings.ToLower(arpa)
|
||||
ip := util.DNSUnreverseAddr(arpa)
|
||||
if ip == nil {
|
||||
return resultDone
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
s.tablePTRLock.Lock()
|
||||
if s.tablePTR == nil {
|
||||
s.tablePTRLock.Unlock()
|
||||
return resultDone
|
||||
return resultCodeSuccess
|
||||
}
|
||||
host, ok := s.tablePTR[ip.String()]
|
||||
s.tablePTRLock.Unlock()
|
||||
if !ok {
|
||||
return resultDone
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
log.Debug("DNS: reverse-lookup: %s -> %s", arpa, host)
|
||||
@@ -237,16 +270,16 @@ func processInternalIPAddrs(ctx *dnsContext) int {
|
||||
ptr.Ptr = host + "."
|
||||
resp.Answer = append(resp.Answer, ptr)
|
||||
ctx.proxyCtx.Res = resp
|
||||
return resultDone
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
// Apply filtering logic
|
||||
func processFilteringBeforeRequest(ctx *dnsContext) int {
|
||||
func processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode) {
|
||||
s := ctx.srv
|
||||
d := ctx.proxyCtx
|
||||
|
||||
if d.Res != nil {
|
||||
return resultDone // response is already set - nothing to do
|
||||
return resultCodeSuccess // response is already set - nothing to do
|
||||
}
|
||||
|
||||
s.RLock()
|
||||
@@ -260,28 +293,28 @@ func processFilteringBeforeRequest(ctx *dnsContext) int {
|
||||
var err error
|
||||
ctx.protectionEnabled = s.conf.ProtectionEnabled && s.dnsFilter != nil
|
||||
if ctx.protectionEnabled {
|
||||
ctx.setts = s.getClientRequestFilteringSettings(d)
|
||||
ctx.setts = s.getClientRequestFilteringSettings(ctx)
|
||||
ctx.result, err = s.filterDNSRequest(ctx)
|
||||
}
|
||||
s.RUnlock()
|
||||
|
||||
if err != nil {
|
||||
ctx.err = err
|
||||
return resultError
|
||||
return resultCodeError
|
||||
}
|
||||
return resultDone
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
// Pass request to upstream servers; process the response
|
||||
func processUpstream(ctx *dnsContext) int {
|
||||
// processUpstream passes request to upstream servers and handles the response.
|
||||
func processUpstream(ctx *dnsContext) (rc resultCode) {
|
||||
s := ctx.srv
|
||||
d := ctx.proxyCtx
|
||||
if d.Res != nil {
|
||||
return resultDone // response is already set - nothing to do
|
||||
return resultCodeSuccess // response is already set - nothing to do
|
||||
}
|
||||
|
||||
if d.Addr != nil && s.conf.GetCustomUpstreamByClient != nil {
|
||||
clientIP := ipFromAddr(d.Addr)
|
||||
clientIP := IPStringFromAddr(d.Addr)
|
||||
upstreamsConf := s.conf.GetCustomUpstreamByClient(clientIP)
|
||||
if upstreamsConf != nil {
|
||||
log.Debug("Using custom upstreams for %s", clientIP)
|
||||
@@ -305,26 +338,26 @@ func processUpstream(ctx *dnsContext) int {
|
||||
err := s.dnsProxy.Resolve(d)
|
||||
if err != nil {
|
||||
ctx.err = err
|
||||
return resultError
|
||||
return resultCodeError
|
||||
}
|
||||
|
||||
ctx.responseFromUpstream = true
|
||||
return resultDone
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
// Process DNSSEC after response from upstream server
|
||||
func processDNSSECAfterResponse(ctx *dnsContext) int {
|
||||
func processDNSSECAfterResponse(ctx *dnsContext) (rc resultCode) {
|
||||
d := ctx.proxyCtx
|
||||
|
||||
if !ctx.responseFromUpstream || // don't process response if it's not from upstream servers
|
||||
!ctx.srv.conf.EnableDNSSEC {
|
||||
return resultDone
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
if !ctx.origReqDNSSEC {
|
||||
optResp := d.Res.IsEdns0()
|
||||
if optResp != nil && !optResp.Do() {
|
||||
return resultDone
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
// Remove RRSIG records from response
|
||||
@@ -355,19 +388,19 @@ func processDNSSECAfterResponse(ctx *dnsContext) int {
|
||||
d.Res.Ns = answers
|
||||
}
|
||||
|
||||
return resultDone
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
// Apply filtering logic after we have received response from upstream servers
|
||||
func processFilteringAfterResponse(ctx *dnsContext) int {
|
||||
func processFilteringAfterResponse(ctx *dnsContext) (rc resultCode) {
|
||||
s := ctx.srv
|
||||
d := ctx.proxyCtx
|
||||
res := ctx.result
|
||||
var err error
|
||||
|
||||
switch res.Reason {
|
||||
case dnsfilter.ReasonRewrite,
|
||||
dnsfilter.DNSRewriteRule:
|
||||
case dnsfilter.Rewritten,
|
||||
dnsfilter.RewrittenRule:
|
||||
|
||||
if len(ctx.origQuestion.Name) == 0 {
|
||||
// origQuestion is set in case we get only CNAME without IP from rewrites table
|
||||
@@ -379,7 +412,7 @@ func processFilteringAfterResponse(ctx *dnsContext) int {
|
||||
|
||||
if len(d.Res.Answer) != 0 {
|
||||
answer := []dns.RR{}
|
||||
answer = append(answer, s.genCNAMEAnswer(d.Req, res.CanonName))
|
||||
answer = append(answer, s.genAnswerCNAME(d.Req, res.CanonName))
|
||||
answer = append(answer, d.Res.Answer...)
|
||||
d.Res.Answer = answer
|
||||
}
|
||||
@@ -396,7 +429,7 @@ func processFilteringAfterResponse(ctx *dnsContext) int {
|
||||
ctx.result, err = s.filterDNSResponse(ctx)
|
||||
if err != nil {
|
||||
ctx.err = err
|
||||
return resultError
|
||||
return resultCodeError
|
||||
}
|
||||
if ctx.result != nil {
|
||||
ctx.origResp = origResp2 // matched by response
|
||||
@@ -405,5 +438,5 @@ func processFilteringAfterResponse(ctx *dnsContext) int {
|
||||
}
|
||||
}
|
||||
|
||||
return resultDone
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
@@ -2,9 +2,11 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -83,10 +85,11 @@ type DNSCreateParams struct {
|
||||
// NewServer creates a new instance of the dnsforward.Server
|
||||
// Note: this function must be called only once
|
||||
func NewServer(p DNSCreateParams) *Server {
|
||||
s := &Server{}
|
||||
s.dnsFilter = p.DNSFilter
|
||||
s.stats = p.Stats
|
||||
s.queryLog = p.QueryLog
|
||||
s := &Server{
|
||||
dnsFilter: p.DNSFilter,
|
||||
stats: p.Stats,
|
||||
queryLog: p.QueryLog,
|
||||
}
|
||||
|
||||
if p.DHCPServer != nil {
|
||||
s.dhcpServer = p.DHCPServer
|
||||
@@ -101,6 +104,16 @@ func NewServer(p DNSCreateParams) *Server {
|
||||
return s
|
||||
}
|
||||
|
||||
// NewCustomServer creates a new instance of *Server with custom internal proxy.
|
||||
func NewCustomServer(internalProxy *proxy.Proxy) *Server {
|
||||
s := &Server{}
|
||||
if internalProxy != nil {
|
||||
s.internalProxy = internalProxy
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// Close - close object
|
||||
func (s *Server) Close() {
|
||||
s.Lock()
|
||||
@@ -108,6 +121,12 @@ func (s *Server) Close() {
|
||||
s.stats = nil
|
||||
s.queryLog = nil
|
||||
s.dnsProxy = nil
|
||||
|
||||
err := s.ipset.Close()
|
||||
if err != nil {
|
||||
log.Error("closing ipset: %s", err)
|
||||
}
|
||||
|
||||
s.Unlock()
|
||||
}
|
||||
|
||||
@@ -155,15 +174,15 @@ func (s *Server) Exchange(req *dns.Msg) (*dns.Msg, error) {
|
||||
return ctx.Res, nil
|
||||
}
|
||||
|
||||
// Start starts the DNS server
|
||||
// Start starts the DNS server.
|
||||
func (s *Server) Start() error {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
return s.startInternal()
|
||||
return s.startLocked()
|
||||
}
|
||||
|
||||
// startInternal starts without locking
|
||||
func (s *Server) startInternal() error {
|
||||
// startLocked starts the DNS server without locking. For internal use only.
|
||||
func (s *Server) startLocked() error {
|
||||
err := s.dnsProxy.Start()
|
||||
if err == nil {
|
||||
s.isRunning = true
|
||||
@@ -178,9 +197,7 @@ func (s *Server) Prepare(config *ServerConfig) error {
|
||||
if config != nil {
|
||||
s.conf = *config
|
||||
if s.conf.BlockingMode == "custom_ip" {
|
||||
s.conf.BlockingIPAddrv4 = net.ParseIP(s.conf.BlockingIPv4)
|
||||
s.conf.BlockingIPAddrv6 = net.ParseIP(s.conf.BlockingIPv6)
|
||||
if s.conf.BlockingIPAddrv4 == nil || s.conf.BlockingIPAddrv6 == nil {
|
||||
if s.conf.BlockingIPv4 == nil || s.conf.BlockingIPv6 == nil {
|
||||
return fmt.Errorf("dns: invalid custom blocking IP address specified")
|
||||
}
|
||||
}
|
||||
@@ -192,11 +209,27 @@ func (s *Server) Prepare(config *ServerConfig) error {
|
||||
|
||||
// Initialize IPSET configuration
|
||||
// --
|
||||
s.ipset.init(s.conf.IPSETList)
|
||||
err := s.ipset.init(s.conf.IPSETList)
|
||||
if err != nil {
|
||||
if !errors.Is(err, os.ErrInvalid) && !errors.Is(err, os.ErrPermission) {
|
||||
return fmt.Errorf("cannot initialize ipset: %w", err)
|
||||
}
|
||||
|
||||
// ipset cannot currently be initialized if the server was
|
||||
// installed from Snap or when the user or the binary doesn't
|
||||
// have the required permissions, or when the kernel doesn't
|
||||
// support netfilter.
|
||||
//
|
||||
// Log and go on.
|
||||
//
|
||||
// TODO(a.garipov): The Snap problem can probably be solved if
|
||||
// we add the netlink-connector interface plug.
|
||||
log.Error("cannot initialize ipset: %s", err)
|
||||
}
|
||||
|
||||
// Prepare DNS servers settings
|
||||
// --
|
||||
err := s.prepareUpstreamSettings()
|
||||
err = s.prepareUpstreamSettings()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -234,15 +267,15 @@ func (s *Server) Prepare(config *ServerConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop stops the DNS server
|
||||
// Stop stops the DNS server.
|
||||
func (s *Server) Stop() error {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
return s.stopInternal()
|
||||
return s.stopLocked()
|
||||
}
|
||||
|
||||
// stopInternal stops without locking
|
||||
func (s *Server) stopInternal() error {
|
||||
// stopLocked stops the DNS server without locking. For internal use only.
|
||||
func (s *Server) stopLocked() error {
|
||||
if s.dnsProxy != nil {
|
||||
err := s.dnsProxy.Stop()
|
||||
if err != nil {
|
||||
@@ -267,7 +300,7 @@ func (s *Server) Reconfigure(config *ServerConfig) error {
|
||||
defer s.Unlock()
|
||||
|
||||
log.Print("Start reconfiguring the server")
|
||||
err := s.stopInternal()
|
||||
err := s.stopLocked()
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not reconfigure the server: %w", err)
|
||||
}
|
||||
@@ -281,7 +314,7 @@ func (s *Server) Reconfigure(config *ServerConfig) error {
|
||||
return fmt.Errorf("could not reconfigure the server: %w", err)
|
||||
}
|
||||
|
||||
err = s.startInternal()
|
||||
err = s.startLocked()
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not reconfigure the server: %w", err)
|
||||
}
|
||||
@@ -300,6 +333,6 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// IsBlockedIP - return TRUE if this client should be blocked
|
||||
func (s *Server) IsBlockedIP(ip string) (bool, string) {
|
||||
func (s *Server) IsBlockedIP(ip net.IP) (bool, string) {
|
||||
return s.access.IsBlockedIP(ip)
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -13,27 +13,55 @@ import (
|
||||
)
|
||||
|
||||
// filterDNSRewriteResponse handles a single DNS rewrite response entry.
|
||||
// It returns the constructed answer resource record.
|
||||
// It returns the properly constructed answer resource record.
|
||||
func (s *Server) filterDNSRewriteResponse(req *dns.Msg, rr rules.RRType, v rules.RRValue) (ans dns.RR, err error) {
|
||||
// TODO(a.garipov): As more types are added, we will probably want to
|
||||
// use a handler-oriented approach here. So, think of a way to decouple
|
||||
// the answer generation logic from the Server.
|
||||
|
||||
switch rr {
|
||||
case dns.TypeA, dns.TypeAAAA:
|
||||
ip, ok := v.(net.IP)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("value has type %T, not net.IP", v)
|
||||
return nil, fmt.Errorf("value for rr type %d has type %T, not net.IP", rr, v)
|
||||
}
|
||||
|
||||
if rr == dns.TypeA {
|
||||
return s.genAAnswer(req, ip.To4()), nil
|
||||
return s.genAnswerA(req, ip.To4()), nil
|
||||
}
|
||||
|
||||
return s.genAAAAAnswer(req, ip), nil
|
||||
case dns.TypeTXT:
|
||||
return s.genAnswerAAAA(req, ip), nil
|
||||
case dns.TypePTR,
|
||||
dns.TypeTXT:
|
||||
str, ok := v.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("value has type %T, not string", v)
|
||||
return nil, fmt.Errorf("value for rr type %d has type %T, not string", rr, v)
|
||||
}
|
||||
|
||||
return s.genTXTAnswer(req, []string{str}), nil
|
||||
if rr == dns.TypeTXT {
|
||||
return s.genAnswerTXT(req, []string{str}), nil
|
||||
}
|
||||
|
||||
return s.genAnswerPTR(req, str), nil
|
||||
case dns.TypeMX:
|
||||
mx, ok := v.(*rules.DNSMX)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("value for rr type %d has type %T, not *rules.DNSMX", rr, v)
|
||||
}
|
||||
|
||||
return s.genAnswerMX(req, mx), nil
|
||||
case dns.TypeHTTPS,
|
||||
dns.TypeSVCB:
|
||||
svcb, ok := v.(*rules.DNSSVCB)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("value for rr type %d has type %T, not *rules.DNSSVCB", rr, v)
|
||||
}
|
||||
|
||||
if rr == dns.TypeHTTPS {
|
||||
return s.genAnswerHTTPS(req, svcb), nil
|
||||
}
|
||||
|
||||
return s.genAnswerSVCB(req, svcb), nil
|
||||
default:
|
||||
log.Debug("don't know how to handle dns rr type %d, skipping", rr)
|
||||
|
||||
|
||||
178
internal/dnsforward/dnsrewrite_test.go
Normal file
178
internal/dnsforward/dnsrewrite_test.go
Normal file
@@ -0,0 +1,178 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsfilter"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||
// Helper data.
|
||||
ip4 := net.IP{127, 0, 0, 1}
|
||||
ip6 := net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
|
||||
mx := &rules.DNSMX{
|
||||
Exchange: "mail.example.com",
|
||||
Preference: 32,
|
||||
}
|
||||
svcb := &rules.DNSSVCB{
|
||||
Params: map[string]string{"alpn": "h3"},
|
||||
Target: "example.com",
|
||||
Priority: 32,
|
||||
}
|
||||
const domain = "example.com"
|
||||
|
||||
// Helper functions and entities.
|
||||
srv := &Server{}
|
||||
makeQ := func(qtype rules.RRType) (req *dns.Msg) {
|
||||
return &dns.Msg{
|
||||
Question: []dns.Question{{
|
||||
Qtype: qtype,
|
||||
}},
|
||||
}
|
||||
}
|
||||
makeRes := func(rcode rules.RCode, rr rules.RRType, v rules.RRValue) (res dnsfilter.Result) {
|
||||
resp := dnsfilter.DNSRewriteResultResponse{
|
||||
rr: []rules.RRValue{v},
|
||||
}
|
||||
return dnsfilter.Result{
|
||||
DNSRewriteResult: &dnsfilter.DNSRewriteResult{
|
||||
RCode: rcode,
|
||||
Response: resp,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Tests.
|
||||
t.Run("nxdomain", func(t *testing.T) {
|
||||
req := makeQ(dns.TypeA)
|
||||
res := makeRes(dns.RcodeNameError, 0, nil)
|
||||
d := &proxy.DNSContext{}
|
||||
|
||||
err := srv.filterDNSRewrite(req, res, d)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, dns.RcodeNameError, d.Res.Rcode)
|
||||
})
|
||||
|
||||
t.Run("noerror_empty", func(t *testing.T) {
|
||||
req := makeQ(dns.TypeA)
|
||||
res := makeRes(dns.RcodeSuccess, 0, nil)
|
||||
d := &proxy.DNSContext{}
|
||||
|
||||
err := srv.filterDNSRewrite(req, res, d)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
|
||||
assert.Empty(t, d.Res.Answer)
|
||||
})
|
||||
|
||||
t.Run("noerror_a", func(t *testing.T) {
|
||||
req := makeQ(dns.TypeA)
|
||||
res := makeRes(dns.RcodeSuccess, dns.TypeA, ip4)
|
||||
d := &proxy.DNSContext{}
|
||||
|
||||
err := srv.filterDNSRewrite(req, res, d)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
|
||||
if assert.Len(t, d.Res.Answer, 1) {
|
||||
assert.Equal(t, ip4, d.Res.Answer[0].(*dns.A).A)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("noerror_aaaa", func(t *testing.T) {
|
||||
req := makeQ(dns.TypeAAAA)
|
||||
res := makeRes(dns.RcodeSuccess, dns.TypeAAAA, ip6)
|
||||
d := &proxy.DNSContext{}
|
||||
|
||||
err := srv.filterDNSRewrite(req, res, d)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
|
||||
if assert.Len(t, d.Res.Answer, 1) {
|
||||
assert.Equal(t, ip6, d.Res.Answer[0].(*dns.AAAA).AAAA)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("noerror_ptr", func(t *testing.T) {
|
||||
req := makeQ(dns.TypePTR)
|
||||
res := makeRes(dns.RcodeSuccess, dns.TypePTR, domain)
|
||||
d := &proxy.DNSContext{}
|
||||
|
||||
err := srv.filterDNSRewrite(req, res, d)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
|
||||
if assert.Len(t, d.Res.Answer, 1) {
|
||||
assert.Equal(t, domain, d.Res.Answer[0].(*dns.PTR).Ptr)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("noerror_txt", func(t *testing.T) {
|
||||
req := makeQ(dns.TypeTXT)
|
||||
res := makeRes(dns.RcodeSuccess, dns.TypeTXT, domain)
|
||||
d := &proxy.DNSContext{}
|
||||
|
||||
err := srv.filterDNSRewrite(req, res, d)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
|
||||
if assert.Len(t, d.Res.Answer, 1) {
|
||||
assert.Equal(t, []string{domain}, d.Res.Answer[0].(*dns.TXT).Txt)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("noerror_mx", func(t *testing.T) {
|
||||
req := makeQ(dns.TypeMX)
|
||||
res := makeRes(dns.RcodeSuccess, dns.TypeMX, mx)
|
||||
d := &proxy.DNSContext{}
|
||||
|
||||
err := srv.filterDNSRewrite(req, res, d)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
|
||||
if assert.Len(t, d.Res.Answer, 1) {
|
||||
ans, ok := d.Res.Answer[0].(*dns.MX)
|
||||
if assert.True(t, ok) {
|
||||
assert.Equal(t, mx.Exchange, ans.Mx)
|
||||
assert.Equal(t, mx.Preference, ans.Preference)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("noerror_svcb", func(t *testing.T) {
|
||||
req := makeQ(dns.TypeSVCB)
|
||||
res := makeRes(dns.RcodeSuccess, dns.TypeSVCB, svcb)
|
||||
d := &proxy.DNSContext{}
|
||||
|
||||
err := srv.filterDNSRewrite(req, res, d)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
|
||||
if assert.Len(t, d.Res.Answer, 1) {
|
||||
ans, ok := d.Res.Answer[0].(*dns.SVCB)
|
||||
if assert.True(t, ok) {
|
||||
assert.Equal(t, dns.SVCB_ALPN, ans.Value[0].Key())
|
||||
assert.Equal(t, svcb.Params["alpn"], ans.Value[0].String())
|
||||
assert.Equal(t, svcb.Target, ans.Target)
|
||||
assert.Equal(t, svcb.Priority, ans.Priority)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("noerror_https", func(t *testing.T) {
|
||||
req := makeQ(dns.TypeHTTPS)
|
||||
res := makeRes(dns.RcodeSuccess, dns.TypeHTTPS, svcb)
|
||||
d := &proxy.DNSContext{}
|
||||
|
||||
err := srv.filterDNSRewrite(req, res, d)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
|
||||
if assert.Len(t, d.Res.Answer, 1) {
|
||||
ans, ok := d.Res.Answer[0].(*dns.HTTPS)
|
||||
if assert.True(t, ok) {
|
||||
assert.Equal(t, dns.SVCB_ALPN, ans.Value[0].Key())
|
||||
assert.Equal(t, svcb.Params["alpn"], ans.Value[0].String())
|
||||
assert.Equal(t, svcb.Target, ans.Target)
|
||||
assert.Equal(t, svcb.Priority, ans.Priority)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
)
|
||||
|
||||
func (s *Server) beforeRequestHandler(_ *proxy.Proxy, d *proxy.DNSContext) (bool, error) {
|
||||
ip := ipFromAddr(d.Addr)
|
||||
ip := IPFromAddr(d.Addr)
|
||||
disallowed, _ := s.access.IsBlockedIP(ip)
|
||||
if disallowed {
|
||||
log.Tracef("Client IP %s is blocked by settings", ip)
|
||||
@@ -30,15 +30,15 @@ func (s *Server) beforeRequestHandler(_ *proxy.Proxy, d *proxy.DNSContext) (bool
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// getClientRequestFilteringSettings lookups client filtering settings
|
||||
// using the client's IP address from the DNSContext
|
||||
func (s *Server) getClientRequestFilteringSettings(d *proxy.DNSContext) *dnsfilter.RequestFilteringSettings {
|
||||
// getClientRequestFilteringSettings looks up client filtering settings using
|
||||
// the client's IP address and ID, if any, from ctx.
|
||||
func (s *Server) getClientRequestFilteringSettings(ctx *dnsContext) *dnsfilter.RequestFilteringSettings {
|
||||
setts := s.dnsFilter.GetConfig()
|
||||
setts.FilteringEnabled = true
|
||||
if s.conf.FilterHandler != nil {
|
||||
clientAddr := ipFromAddr(d.Addr)
|
||||
s.conf.FilterHandler(clientAddr, &setts)
|
||||
s.conf.FilterHandler(IPFromAddr(ctx.proxyCtx.Addr), ctx.clientID, &setts)
|
||||
}
|
||||
|
||||
return &setts
|
||||
}
|
||||
|
||||
@@ -55,7 +55,7 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*dnsfilter.Result, error) {
|
||||
} else if res.IsFiltered {
|
||||
log.Tracef("Host %s is filtered, reason - %q, matched rule: %q", host, res.Reason, res.Rules[0].Text)
|
||||
d.Res = s.genDNSFilterMessage(d, &res)
|
||||
} else if res.Reason.In(dnsfilter.ReasonRewrite, dnsfilter.DNSRewriteRule) &&
|
||||
} else if res.Reason.In(dnsfilter.Rewritten, dnsfilter.RewrittenRule) &&
|
||||
res.CanonName != "" &&
|
||||
len(res.IPList) == 0 {
|
||||
// Resolve the new canonical name, not the original host
|
||||
@@ -63,7 +63,7 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*dnsfilter.Result, error) {
|
||||
// processFilteringAfterResponse.
|
||||
ctx.origQuestion = d.Req.Question[0]
|
||||
d.Req.Question[0].Name = dns.Fqdn(res.CanonName)
|
||||
} else if res.Reason == dnsfilter.RewriteAutoHosts && len(res.ReverseHosts) != 0 {
|
||||
} else if res.Reason == dnsfilter.RewrittenAutoHosts && len(res.ReverseHosts) != 0 {
|
||||
resp := s.makeResponse(req)
|
||||
for _, h := range res.ReverseHosts {
|
||||
hdr := dns.RR_Header{
|
||||
@@ -82,29 +82,29 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*dnsfilter.Result, error) {
|
||||
}
|
||||
|
||||
d.Res = resp
|
||||
} else if res.Reason == dnsfilter.ReasonRewrite || res.Reason == dnsfilter.RewriteAutoHosts {
|
||||
} else if res.Reason == dnsfilter.Rewritten || res.Reason == dnsfilter.RewrittenAutoHosts {
|
||||
resp := s.makeResponse(req)
|
||||
|
||||
name := host
|
||||
if len(res.CanonName) != 0 {
|
||||
resp.Answer = append(resp.Answer, s.genCNAMEAnswer(req, res.CanonName))
|
||||
resp.Answer = append(resp.Answer, s.genAnswerCNAME(req, res.CanonName))
|
||||
name = res.CanonName
|
||||
}
|
||||
|
||||
for _, ip := range res.IPList {
|
||||
if req.Question[0].Qtype == dns.TypeA {
|
||||
a := s.genAAnswer(req, ip.To4())
|
||||
a := s.genAnswerA(req, ip.To4())
|
||||
a.Hdr.Name = dns.Fqdn(name)
|
||||
resp.Answer = append(resp.Answer, a)
|
||||
} else if req.Question[0].Qtype == dns.TypeAAAA {
|
||||
a := s.genAAAAAnswer(req, ip)
|
||||
a := s.genAnswerAAAA(req, ip)
|
||||
a.Hdr.Name = dns.Fqdn(name)
|
||||
resp.Answer = append(resp.Answer, a)
|
||||
}
|
||||
}
|
||||
|
||||
d.Res = resp
|
||||
} else if res.Reason == dnsfilter.DNSRewriteRule {
|
||||
} else if res.Reason == dnsfilter.RewrittenRule {
|
||||
err = s.filterDNSRewrite(req, res, d)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/utils"
|
||||
@@ -28,8 +29,8 @@ type dnsConfig struct {
|
||||
ProtectionEnabled *bool `json:"protection_enabled"`
|
||||
RateLimit *uint32 `json:"ratelimit"`
|
||||
BlockingMode *string `json:"blocking_mode"`
|
||||
BlockingIPv4 *string `json:"blocking_ipv4"`
|
||||
BlockingIPv6 *string `json:"blocking_ipv6"`
|
||||
BlockingIPv4 net.IP `json:"blocking_ipv4"`
|
||||
BlockingIPv6 net.IP `json:"blocking_ipv6"`
|
||||
EDNSCSEnabled *bool `json:"edns_cs_enabled"`
|
||||
DNSSECEnabled *bool `json:"dnssec_enabled"`
|
||||
DisableIPv6 *bool `json:"disable_ipv6"`
|
||||
@@ -68,8 +69,8 @@ func (s *Server) getDNSConfig() dnsConfig {
|
||||
Bootstraps: &bootstraps,
|
||||
ProtectionEnabled: &protectionEnabled,
|
||||
BlockingMode: &blockingMode,
|
||||
BlockingIPv4: &BlockingIPv4,
|
||||
BlockingIPv6: &BlockingIPv6,
|
||||
BlockingIPv4: BlockingIPv4,
|
||||
BlockingIPv6: BlockingIPv6,
|
||||
RateLimit: &Ratelimit,
|
||||
EDNSCSEnabled: &EnableEDNSClientSubnet,
|
||||
DNSSECEnabled: &EnableDNSSEC,
|
||||
@@ -100,17 +101,11 @@ func (req *dnsConfig) checkBlockingMode() bool {
|
||||
|
||||
bm := *req.BlockingMode
|
||||
if bm == "custom_ip" {
|
||||
if req.BlockingIPv4 == nil || req.BlockingIPv6 == nil {
|
||||
if req.BlockingIPv4.To4() == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
ip4 := net.ParseIP(*req.BlockingIPv4)
|
||||
if ip4 == nil || ip4.To4() == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
ip6 := net.ParseIP(*req.BlockingIPv6)
|
||||
return ip6 != nil
|
||||
return req.BlockingIPv6 != nil
|
||||
}
|
||||
|
||||
for _, valid := range []string{
|
||||
@@ -247,10 +242,8 @@ func (s *Server) setConfig(dc dnsConfig) (restart bool) {
|
||||
if dc.BlockingMode != nil {
|
||||
s.conf.BlockingMode = *dc.BlockingMode
|
||||
if *dc.BlockingMode == "custom_ip" {
|
||||
s.conf.BlockingIPv4 = *dc.BlockingIPv4
|
||||
s.conf.BlockingIPAddrv4 = net.ParseIP(*dc.BlockingIPv4)
|
||||
s.conf.BlockingIPv6 = *dc.BlockingIPv6
|
||||
s.conf.BlockingIPAddrv6 = net.ParseIP(*dc.BlockingIPv6)
|
||||
s.conf.BlockingIPv4 = dc.BlockingIPv4.To4()
|
||||
s.conf.BlockingIPv6 = dc.BlockingIPv6.To16()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -322,6 +315,11 @@ func ValidateUpstreams(upstreams []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := proxy.ParseUpstreamsConfig(upstreams, []string{}, DefaultTimeout)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var defaultUpstreamFound bool
|
||||
for _, u := range upstreams {
|
||||
d, err := validateUpstream(u)
|
||||
@@ -530,12 +528,20 @@ func (s *Server) handleDOH(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (s *Server) registerHandlers() {
|
||||
s.conf.HTTPRegister("GET", "/control/dns_info", s.handleGetConfig)
|
||||
s.conf.HTTPRegister("POST", "/control/dns_config", s.handleSetConfig)
|
||||
s.conf.HTTPRegister("POST", "/control/test_upstream_dns", s.handleTestUpstreamDNS)
|
||||
s.conf.HTTPRegister(http.MethodGet, "/control/dns_info", s.handleGetConfig)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dns_config", s.handleSetConfig)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/test_upstream_dns", s.handleTestUpstreamDNS)
|
||||
|
||||
s.conf.HTTPRegister("GET", "/control/access/list", s.handleAccessList)
|
||||
s.conf.HTTPRegister("POST", "/control/access/set", s.handleAccessSet)
|
||||
s.conf.HTTPRegister(http.MethodGet, "/control/access/list", s.handleAccessList)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/access/set", s.handleAccessSet)
|
||||
|
||||
// Register both versions, with and without the trailing slash, to
|
||||
// prevent a 301 Moved Permanently redirect when clients request the
|
||||
// path without the trailing slash. Those redirects break some clients.
|
||||
//
|
||||
// See go doc net/http.ServeMux.
|
||||
//
|
||||
// See also https://github.com/AdguardTeam/AdGuardHome/issues/2628.
|
||||
s.conf.HTTPRegister("", "/dns-query", s.handleDOH)
|
||||
s.conf.HTTPRegister("", "/dns-query/", s.handleDOH)
|
||||
}
|
||||
|
||||
@@ -2,16 +2,35 @@ package dnsforward
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsfilter"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestDNSForwardHTTTP_handleGetConfig(t *testing.T) {
|
||||
s := createTestServer(t)
|
||||
filterConf := &dnsfilter.Config{
|
||||
SafeBrowsingEnabled: true,
|
||||
SafeBrowsingCacheSize: 1000,
|
||||
SafeSearchEnabled: true,
|
||||
SafeSearchCacheSize: 1000,
|
||||
ParentalCacheSize: 1000,
|
||||
CacheTime: 30,
|
||||
}
|
||||
forwardConf := ServerConfig{
|
||||
UDPListenAddr: &net.UDPAddr{},
|
||||
TCPListenAddr: &net.TCPAddr{},
|
||||
FilteringConfig: FilteringConfig{
|
||||
ProtectionEnabled: true,
|
||||
UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"},
|
||||
},
|
||||
ConfigModified: func() {},
|
||||
}
|
||||
s := createTestServer(t, filterConf, forwardConf)
|
||||
err := s.Start()
|
||||
assert.Nil(t, err)
|
||||
defer assert.Nil(t, s.Stop())
|
||||
@@ -35,6 +54,7 @@ func TestDNSForwardHTTTP_handleGetConfig(t *testing.T) {
|
||||
conf: func() ServerConfig {
|
||||
conf := defaultConf
|
||||
conf.FastestAddr = true
|
||||
|
||||
return conf
|
||||
},
|
||||
want: "{\"upstream_dns\":[\"8.8.8.8:53\",\"8.8.4.4:53\"],\"upstream_dns_file\":\"\",\"bootstrap_dns\":[\"9.9.9.10\",\"149.112.112.10\",\"2620:fe::10\",\"2620:fe::fe:10\"],\"protection_enabled\":true,\"ratelimit\":0,\"blocking_mode\":\"\",\"blocking_ipv4\":\"\",\"blocking_ipv6\":\"\",\"edns_cs_enabled\":false,\"dnssec_enabled\":false,\"disable_ipv6\":false,\"upstream_mode\":\"fastest_addr\",\"cache_size\":0,\"cache_ttl_min\":0,\"cache_ttl_max\":0}\n",
|
||||
@@ -43,6 +63,7 @@ func TestDNSForwardHTTTP_handleGetConfig(t *testing.T) {
|
||||
conf: func() ServerConfig {
|
||||
conf := defaultConf
|
||||
conf.AllServers = true
|
||||
|
||||
return conf
|
||||
},
|
||||
want: "{\"upstream_dns\":[\"8.8.8.8:53\",\"8.8.4.4:53\"],\"upstream_dns_file\":\"\",\"bootstrap_dns\":[\"9.9.9.10\",\"149.112.112.10\",\"2620:fe::10\",\"2620:fe::fe:10\"],\"protection_enabled\":true,\"ratelimit\":0,\"blocking_mode\":\"\",\"blocking_ipv4\":\"\",\"blocking_ipv6\":\"\",\"edns_cs_enabled\":false,\"dnssec_enabled\":false,\"disable_ipv6\":false,\"upstream_mode\":\"parallel\",\"cache_size\":0,\"cache_ttl_min\":0,\"cache_ttl_max\":0}\n",
|
||||
@@ -61,7 +82,24 @@ func TestDNSForwardHTTTP_handleGetConfig(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDNSForwardHTTTP_handleSetConfig(t *testing.T) {
|
||||
s := createTestServer(t)
|
||||
filterConf := &dnsfilter.Config{
|
||||
SafeBrowsingEnabled: true,
|
||||
SafeBrowsingCacheSize: 1000,
|
||||
SafeSearchEnabled: true,
|
||||
SafeSearchCacheSize: 1000,
|
||||
ParentalCacheSize: 1000,
|
||||
CacheTime: 30,
|
||||
}
|
||||
forwardConf := ServerConfig{
|
||||
UDPListenAddr: &net.UDPAddr{},
|
||||
TCPListenAddr: &net.TCPAddr{},
|
||||
FilteringConfig: FilteringConfig{
|
||||
ProtectionEnabled: true,
|
||||
UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"},
|
||||
},
|
||||
ConfigModified: func() {},
|
||||
}
|
||||
s := createTestServer(t, filterConf, forwardConf)
|
||||
|
||||
defaultConf := s.conf
|
||||
|
||||
|
||||
@@ -1,142 +0,0 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/util"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type ipsetCtx struct {
|
||||
ipsetList map[string][]string // domain -> []ipset_name
|
||||
ipsetCache map[[4]byte]bool // cache for IP[] to prevent duplicate calls to ipset program
|
||||
ipsetMutex *sync.Mutex
|
||||
ipset6Cache map[[16]byte]bool // cache for IP[] to prevent duplicate calls to ipset program
|
||||
ipset6Mutex *sync.Mutex
|
||||
}
|
||||
|
||||
// Convert configuration settings to an internal map
|
||||
// DOMAIN[,DOMAIN].../IPSET1_NAME[,IPSET2_NAME]...
|
||||
func (c *ipsetCtx) init(ipsetConfig []string) {
|
||||
c.ipsetList = make(map[string][]string)
|
||||
c.ipsetCache = make(map[[4]byte]bool)
|
||||
c.ipsetMutex = &sync.Mutex{}
|
||||
c.ipset6Cache = make(map[[16]byte]bool)
|
||||
c.ipset6Mutex = &sync.Mutex{}
|
||||
|
||||
for _, it := range ipsetConfig {
|
||||
it = strings.TrimSpace(it)
|
||||
hostsAndNames := strings.Split(it, "/")
|
||||
if len(hostsAndNames) != 2 {
|
||||
log.Debug("IPSET: invalid value %q", it)
|
||||
continue
|
||||
}
|
||||
|
||||
ipsetNames := strings.Split(hostsAndNames[1], ",")
|
||||
if len(ipsetNames) == 0 {
|
||||
log.Debug("IPSET: invalid value %q", it)
|
||||
continue
|
||||
}
|
||||
bad := false
|
||||
for i := range ipsetNames {
|
||||
ipsetNames[i] = strings.TrimSpace(ipsetNames[i])
|
||||
if len(ipsetNames[i]) == 0 {
|
||||
bad = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if bad {
|
||||
log.Debug("IPSET: invalid value %q", it)
|
||||
continue
|
||||
}
|
||||
|
||||
hosts := strings.Split(hostsAndNames[0], ",")
|
||||
for _, host := range hosts {
|
||||
host = strings.TrimSpace(host)
|
||||
host = strings.ToLower(host)
|
||||
if len(host) == 0 {
|
||||
log.Debug("IPSET: invalid value %q", it)
|
||||
continue
|
||||
}
|
||||
c.ipsetList[host] = ipsetNames
|
||||
}
|
||||
}
|
||||
log.Debug("IPSET: added %d hosts", len(c.ipsetList))
|
||||
}
|
||||
|
||||
func (c *ipsetCtx) getIP(rr dns.RR) net.IP {
|
||||
switch a := rr.(type) {
|
||||
case *dns.A:
|
||||
var ip4 [4]byte
|
||||
copy(ip4[:], a.A.To4())
|
||||
c.ipsetMutex.Lock()
|
||||
defer c.ipsetMutex.Unlock()
|
||||
_, found := c.ipsetCache[ip4]
|
||||
if found {
|
||||
return nil // this IP was added before
|
||||
}
|
||||
c.ipsetCache[ip4] = false
|
||||
return a.A
|
||||
|
||||
case *dns.AAAA:
|
||||
var ip6 [16]byte
|
||||
copy(ip6[:], a.AAAA)
|
||||
c.ipset6Mutex.Lock()
|
||||
defer c.ipset6Mutex.Unlock()
|
||||
_, found := c.ipset6Cache[ip6]
|
||||
if found {
|
||||
return nil // this IP was added before
|
||||
}
|
||||
c.ipset6Cache[ip6] = false
|
||||
return a.AAAA
|
||||
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Add IP addresses of the specified in configuration domain names to an ipset list
|
||||
func (c *ipsetCtx) process(ctx *dnsContext) int {
|
||||
req := ctx.proxyCtx.Req
|
||||
if !(req.Question[0].Qtype == dns.TypeA ||
|
||||
req.Question[0].Qtype == dns.TypeAAAA) ||
|
||||
!ctx.responseFromUpstream {
|
||||
return resultDone
|
||||
}
|
||||
|
||||
host := req.Question[0].Name
|
||||
host = strings.TrimSuffix(host, ".")
|
||||
host = strings.ToLower(host)
|
||||
ipsetNames, found := c.ipsetList[host]
|
||||
if !found {
|
||||
return resultDone
|
||||
}
|
||||
|
||||
log.Debug("IPSET: found ipsets %v for host %s", ipsetNames, host)
|
||||
|
||||
for _, it := range ctx.proxyCtx.Res.Answer {
|
||||
ip := c.getIP(it)
|
||||
if ip == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
ipStr := ip.String()
|
||||
for _, name := range ipsetNames {
|
||||
code, out, err := util.RunCommand("ipset", "add", name, ipStr)
|
||||
if err != nil {
|
||||
log.Info("IPSET: %s(%s) -> %s: %s", host, ipStr, name, err)
|
||||
continue
|
||||
}
|
||||
if code != 0 {
|
||||
log.Info("IPSET: ipset add: code:%d output:%q", code, out)
|
||||
continue
|
||||
}
|
||||
log.Debug("IPSET: added %s(%s) -> %s", host, ipStr, name)
|
||||
}
|
||||
}
|
||||
|
||||
return resultDone
|
||||
}
|
||||
400
internal/dnsforward/ipset_linux.go
Normal file
400
internal/dnsforward/ipset_linux.go
Normal file
@@ -0,0 +1,400 @@
|
||||
// +build linux
|
||||
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/digineo/go-ipset/v2"
|
||||
"github.com/mdlayher/netlink"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/ti-mo/netfilter"
|
||||
)
|
||||
|
||||
// TODO(a.garipov): Cover with unit tests as well as document how to test it
|
||||
// manually. The original PR by @dsheets on Github contained an integration
|
||||
// test, but unfortunately I didn't have the time to properly refactor it and
|
||||
// check it in.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/2611.
|
||||
|
||||
// ipsetProps contains one Linux Netfilter ipset properties.
|
||||
type ipsetProps struct {
|
||||
name string
|
||||
family netfilter.ProtoFamily
|
||||
}
|
||||
|
||||
// ipsetCtx is the Linux Netfilter ipset context.
|
||||
type ipsetCtx struct {
|
||||
// mu protects all properties below.
|
||||
mu *sync.Mutex
|
||||
|
||||
nameToIpset map[string]ipsetProps
|
||||
domainToIpsets map[string][]ipsetProps
|
||||
|
||||
// TODO(a.garipov): Currently, the ipset list is static, and we don't
|
||||
// read the IPs already in sets, so we can assume that all incoming IPs
|
||||
// are either added to all corresponding ipsets or not. When that stops
|
||||
// being the case, for example if we add dynamic reconfiguration of
|
||||
// ipsets, this map will need to become a per-ipset-name one.
|
||||
addedIPs map[[16]byte]struct{}
|
||||
|
||||
ipv4Conn *ipset.Conn
|
||||
ipv6Conn *ipset.Conn
|
||||
}
|
||||
|
||||
// dialNetfilter establishes connections to Linux's netfilter module.
|
||||
func (c *ipsetCtx) dialNetfilter(config *netlink.Config) (err error) {
|
||||
// The kernel API does not actually require two sockets but package
|
||||
// github.com/digineo/go-ipset does.
|
||||
//
|
||||
// TODO(a.garipov): Perhaps we can ditch package ipset altogether and
|
||||
// just use packages netfilter and netlink.
|
||||
c.ipv4Conn, err = ipset.Dial(netfilter.ProtoIPv4, config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dialing v4: %w", err)
|
||||
}
|
||||
|
||||
c.ipv6Conn, err = ipset.Dial(netfilter.ProtoIPv6, config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dialing v6: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ipsetProps returns the properties of an ipset with the given name.
|
||||
func (c *ipsetCtx) ipsetProps(name string) (set ipsetProps, err error) {
|
||||
// The family doesn't seem to matter when we use a header query, so
|
||||
// query only the IPv4 one.
|
||||
//
|
||||
// TODO(a.garipov): Find out if this is a bug or a feature.
|
||||
res, err := c.ipv4Conn.Header(name)
|
||||
if err != nil {
|
||||
return set, err
|
||||
}
|
||||
|
||||
if res == nil || res.Family == nil {
|
||||
return set, agherr.Error("empty response or no family data")
|
||||
}
|
||||
|
||||
family := netfilter.ProtoFamily(res.Family.Value)
|
||||
if family != netfilter.ProtoIPv4 && family != netfilter.ProtoIPv6 {
|
||||
return set, fmt.Errorf("unexpected ipset family %s", family)
|
||||
}
|
||||
|
||||
return ipsetProps{
|
||||
name: name,
|
||||
family: family,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ipsets returns currently known ipsets.
|
||||
func (c *ipsetCtx) ipsets(names []string) (sets []ipsetProps, err error) {
|
||||
for _, name := range names {
|
||||
set, ok := c.nameToIpset[name]
|
||||
if ok {
|
||||
sets = append(sets, set)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
var err error
|
||||
set, err = c.ipsetProps(name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("querying ipset %q: %w", name, err)
|
||||
}
|
||||
|
||||
c.nameToIpset[name] = set
|
||||
sets = append(sets, set)
|
||||
}
|
||||
|
||||
return sets, nil
|
||||
}
|
||||
|
||||
// parseIpsetConfig parses one ipset configuration string.
|
||||
func parseIpsetConfig(cfgStr string) (hosts, ipsetNames []string, err error) {
|
||||
cfgStr = strings.TrimSpace(cfgStr)
|
||||
hostsAndNames := strings.Split(cfgStr, "/")
|
||||
if len(hostsAndNames) != 2 {
|
||||
return nil, nil, fmt.Errorf("invalid value %q: expected one slash", cfgStr)
|
||||
}
|
||||
|
||||
hosts = strings.Split(hostsAndNames[0], ",")
|
||||
ipsetNames = strings.Split(hostsAndNames[1], ",")
|
||||
|
||||
if len(ipsetNames) == 0 {
|
||||
log.Info("ipset: resolutions for %q will not be stored", hosts)
|
||||
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
for i := range ipsetNames {
|
||||
ipsetNames[i] = strings.TrimSpace(ipsetNames[i])
|
||||
if len(ipsetNames[i]) == 0 {
|
||||
return nil, nil, fmt.Errorf("invalid value %q: empty ipset name", cfgStr)
|
||||
}
|
||||
}
|
||||
|
||||
for i := range hosts {
|
||||
hosts[i] = strings.TrimSpace(hosts[i])
|
||||
hosts[i] = strings.ToLower(hosts[i])
|
||||
if len(hosts[i]) == 0 {
|
||||
log.Info("ipset: root catchall in %q", ipsetNames)
|
||||
}
|
||||
}
|
||||
|
||||
return hosts, ipsetNames, nil
|
||||
}
|
||||
|
||||
// init initializes the ipset context. It is not safe for concurrent use.
|
||||
//
|
||||
// TODO(a.garipov): Rewrite into a simple constructor?
|
||||
func (c *ipsetCtx) init(ipsetConfig []string) (err error) {
|
||||
c.mu = &sync.Mutex{}
|
||||
c.nameToIpset = make(map[string]ipsetProps)
|
||||
c.domainToIpsets = make(map[string][]ipsetProps)
|
||||
c.addedIPs = make(map[[16]byte]struct{})
|
||||
|
||||
err = c.dialNetfilter(&netlink.Config{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("ipset: dialing netfilter: %w", err)
|
||||
}
|
||||
|
||||
for i, cfgStr := range ipsetConfig {
|
||||
var hosts, ipsetNames []string
|
||||
hosts, ipsetNames, err = parseIpsetConfig(cfgStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ipset: config line at index %d: %w", i, err)
|
||||
}
|
||||
|
||||
var ipsets []ipsetProps
|
||||
ipsets, err = c.ipsets(ipsetNames)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ipset: getting ipsets config line at index %d: %w", i, err)
|
||||
}
|
||||
|
||||
for _, host := range hosts {
|
||||
c.domainToIpsets[host] = append(c.domainToIpsets[host], ipsets...)
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug("ipset: added %d domains for %d ipsets", len(c.domainToIpsets), len(c.nameToIpset))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the Linux Netfilter connections.
|
||||
func (c *ipsetCtx) Close() (err error) {
|
||||
var errors []error
|
||||
if c.ipv4Conn != nil {
|
||||
err = c.ipv4Conn.Close()
|
||||
if err != nil {
|
||||
errors = append(errors, err)
|
||||
}
|
||||
}
|
||||
|
||||
if c.ipv6Conn != nil {
|
||||
err = c.ipv6Conn.Close()
|
||||
if err != nil {
|
||||
errors = append(errors, err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(errors) != 0 {
|
||||
return agherr.Many("closing ipsets", errors...)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ipFromRR returns an IP address from a DNS resource record.
|
||||
func ipFromRR(rr dns.RR) (ip net.IP) {
|
||||
switch a := rr.(type) {
|
||||
case *dns.A:
|
||||
return a.A
|
||||
case *dns.AAAA:
|
||||
return a.AAAA
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// lookupHost find the ipsets for the host, taking subdomain wildcards into
|
||||
// account.
|
||||
func (c *ipsetCtx) lookupHost(host string) (sets []ipsetProps) {
|
||||
// Search for matching ipset hosts starting with most specific
|
||||
// subdomain. We could use a trie here but the simple, inefficient
|
||||
// solution isn't that expensive. ~75 % for 10 subdomains vs 0, but
|
||||
// still sub-microsecond on a Core i7.
|
||||
//
|
||||
// TODO(a.garipov): Re-add benchmarks from the original PR.
|
||||
for i := 0; i != -1; i++ {
|
||||
host = host[i:]
|
||||
sets = c.domainToIpsets[host]
|
||||
if sets != nil {
|
||||
return sets
|
||||
}
|
||||
|
||||
i = strings.Index(host, ".")
|
||||
if i == -1 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Check the root catch-all one.
|
||||
return c.domainToIpsets[""]
|
||||
}
|
||||
|
||||
// addIPs adds the IP addresses for the host to the ipset. set must be same
|
||||
// family as set's family.
|
||||
func (c *ipsetCtx) addIPs(host string, set ipsetProps, ips []net.IP) (err error) {
|
||||
if len(ips) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
entries := make([]*ipset.Entry, 0, len(ips))
|
||||
for _, ip := range ips {
|
||||
entries = append(entries, ipset.NewEntry(ipset.EntryIP(ip)))
|
||||
}
|
||||
|
||||
var conn *ipset.Conn
|
||||
switch set.family {
|
||||
case netfilter.ProtoIPv4:
|
||||
conn = c.ipv4Conn
|
||||
case netfilter.ProtoIPv6:
|
||||
conn = c.ipv6Conn
|
||||
default:
|
||||
return fmt.Errorf("unexpected family %s for ipset %q", set.family, set.name)
|
||||
}
|
||||
|
||||
err = conn.Add(set.name, entries...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("adding %q%s to ipset %q: %w", host, ips, set.name, err)
|
||||
}
|
||||
|
||||
log.Debug("ipset: added %s%s to ipset %s", host, ips, set.name)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// skipIpsetProcessing returns true when the ipset processing can be skipped for
|
||||
// this request.
|
||||
func (c *ipsetCtx) skipIpsetProcessing(ctx *dnsContext) (ok bool) {
|
||||
if len(c.domainToIpsets) == 0 || ctx == nil || !ctx.responseFromUpstream {
|
||||
return true
|
||||
}
|
||||
|
||||
req := ctx.proxyCtx.Req
|
||||
if req == nil || len(req.Question) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
qt := req.Question[0].Qtype
|
||||
return qt != dns.TypeA && qt != dns.TypeAAAA && qt != dns.TypeANY
|
||||
}
|
||||
|
||||
// process adds the resolved IP addresses to the domain's ipsets, if any.
|
||||
func (c *ipsetCtx) process(ctx *dnsContext) (rc resultCode) {
|
||||
var err error
|
||||
|
||||
if c == nil {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
log.Debug("ipset: starting processing")
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.skipIpsetProcessing(ctx) {
|
||||
log.Debug("ipset: skipped processing for request")
|
||||
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
req := ctx.proxyCtx.Req
|
||||
host := req.Question[0].Name
|
||||
host = strings.TrimSuffix(host, ".")
|
||||
host = strings.ToLower(host)
|
||||
sets := c.lookupHost(host)
|
||||
if len(sets) == 0 {
|
||||
log.Debug("ipset: no ipsets for host %s", host)
|
||||
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
log.Debug("ipset: found ipsets %+v for host %s", sets, host)
|
||||
|
||||
if ctx.proxyCtx.Res == nil {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
ans := ctx.proxyCtx.Res.Answer
|
||||
l := len(ans)
|
||||
v4s := make([]net.IP, 0, l)
|
||||
v6s := make([]net.IP, 0, l)
|
||||
for _, rr := range ans {
|
||||
ip := ipFromRR(rr)
|
||||
if ip == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var iparr [16]byte
|
||||
copy(iparr[:], ip.To16())
|
||||
if _, added := c.addedIPs[iparr]; added {
|
||||
continue
|
||||
}
|
||||
|
||||
if ip.To4() == nil {
|
||||
v6s = append(v6s, ip)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
v4s = append(v4s, ip)
|
||||
}
|
||||
|
||||
setLoop:
|
||||
for _, set := range sets {
|
||||
switch set.family {
|
||||
case netfilter.ProtoIPv4:
|
||||
err = c.addIPs(host, set, v4s)
|
||||
if err != nil {
|
||||
break setLoop
|
||||
}
|
||||
case netfilter.ProtoIPv6:
|
||||
err = c.addIPs(host, set, v6s)
|
||||
if err != nil {
|
||||
break setLoop
|
||||
}
|
||||
default:
|
||||
err = fmt.Errorf("unexpected family %s for ipset %q", set.family, set.name)
|
||||
break setLoop
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
log.Error("ipset: adding host ips: %s", err)
|
||||
} else {
|
||||
log.Debug("ipset: processed %d new ips", len(v4s)+len(v6s))
|
||||
}
|
||||
|
||||
for _, ip := range v4s {
|
||||
var iparr [16]byte
|
||||
copy(iparr[:], ip.To16())
|
||||
c.addedIPs[iparr] = struct{}{}
|
||||
}
|
||||
|
||||
for _, ip := range v6s {
|
||||
var iparr [16]byte
|
||||
copy(iparr[:], ip.To16())
|
||||
c.addedIPs[iparr] = struct{}{}
|
||||
}
|
||||
|
||||
return resultCodeSuccess
|
||||
}
|
||||
26
internal/dnsforward/ipset_others.go
Normal file
26
internal/dnsforward/ipset_others.go
Normal file
@@ -0,0 +1,26 @@
|
||||
// +build !linux
|
||||
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
type ipsetCtx struct{}
|
||||
|
||||
// init initializes the ipset context.
|
||||
func (c *ipsetCtx) init(ipsetConfig []string) (err error) {
|
||||
if len(ipsetConfig) != 0 {
|
||||
log.Info("ipset: only available on linux")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// process adds the resolved IP addresses to the domain's ipsets, if any.
|
||||
func (c *ipsetCtx) process(_ *dnsContext) (rc resultCode) {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
// Close closes the Linux Netfilter connections.
|
||||
func (c *ipsetCtx) Close() (_ error) { return nil }
|
||||
@@ -1,41 +0,0 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestIPSET(t *testing.T) {
|
||||
s := Server{}
|
||||
s.conf.IPSETList = append(s.conf.IPSETList, "HOST.com/name")
|
||||
s.conf.IPSETList = append(s.conf.IPSETList, "host2.com,host3.com/name23")
|
||||
s.conf.IPSETList = append(s.conf.IPSETList, "host4.com/name4,name41")
|
||||
c := ipsetCtx{}
|
||||
c.init(s.conf.IPSETList)
|
||||
|
||||
assert.Equal(t, "name", c.ipsetList["host.com"][0])
|
||||
assert.Equal(t, "name23", c.ipsetList["host2.com"][0])
|
||||
assert.Equal(t, "name23", c.ipsetList["host3.com"][0])
|
||||
assert.Equal(t, "name4", c.ipsetList["host4.com"][0])
|
||||
assert.Equal(t, "name41", c.ipsetList["host4.com"][1])
|
||||
|
||||
_, ok := c.ipsetList["host0.com"]
|
||||
assert.False(t, ok)
|
||||
|
||||
ctx := &dnsContext{
|
||||
srv: &s,
|
||||
}
|
||||
ctx.proxyCtx = &proxy.DNSContext{}
|
||||
ctx.proxyCtx.Req = &dns.Msg{
|
||||
Question: []dns.Question{
|
||||
{
|
||||
Name: "host.com.",
|
||||
Qtype: dns.TypeA,
|
||||
},
|
||||
},
|
||||
}
|
||||
assert.Equal(t, resultDone, c.process(ctx))
|
||||
}
|
||||
@@ -7,10 +7,12 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsfilter"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Create a DNS response by DNS request and set necessary flags
|
||||
// makeResponse creates a DNS response by req and sets necessary flags. It also
|
||||
// guarantees that req.Question will be not empty.
|
||||
func (s *Server) makeResponse(req *dns.Msg) (resp *dns.Msg) {
|
||||
resp = &dns.Msg{
|
||||
MsgHdr: dns.MsgHdr{
|
||||
@@ -58,9 +60,9 @@ func (s *Server) genDNSFilterMessage(d *proxy.DNSContext, result *dnsfilter.Resu
|
||||
|
||||
switch m.Question[0].Qtype {
|
||||
case dns.TypeA:
|
||||
return s.genARecord(m, s.conf.BlockingIPAddrv4)
|
||||
return s.genARecord(m, s.conf.BlockingIPv4)
|
||||
case dns.TypeAAAA:
|
||||
return s.genAAAARecord(m, s.conf.BlockingIPAddrv6)
|
||||
return s.genAAAARecord(m, s.conf.BlockingIPv6)
|
||||
}
|
||||
} else if s.conf.BlockingMode == "nxdomain" {
|
||||
// means that we should return NXDOMAIN for any blocked request
|
||||
@@ -92,48 +94,64 @@ func (s *Server) genServerFailure(request *dns.Msg) *dns.Msg {
|
||||
|
||||
func (s *Server) genARecord(request *dns.Msg, ip net.IP) *dns.Msg {
|
||||
resp := s.makeResponse(request)
|
||||
resp.Answer = append(resp.Answer, s.genAAnswer(request, ip))
|
||||
resp.Answer = append(resp.Answer, s.genAnswerA(request, ip))
|
||||
return resp
|
||||
}
|
||||
|
||||
func (s *Server) genAAAARecord(request *dns.Msg, ip net.IP) *dns.Msg {
|
||||
resp := s.makeResponse(request)
|
||||
resp.Answer = append(resp.Answer, s.genAAAAAnswer(request, ip))
|
||||
resp.Answer = append(resp.Answer, s.genAnswerAAAA(request, ip))
|
||||
return resp
|
||||
}
|
||||
|
||||
func (s *Server) genAAnswer(req *dns.Msg, ip net.IP) *dns.A {
|
||||
answer := new(dns.A)
|
||||
answer.Hdr = dns.RR_Header{
|
||||
func (s *Server) hdr(req *dns.Msg, rrType rules.RRType) (h dns.RR_Header) {
|
||||
return dns.RR_Header{
|
||||
Name: req.Question[0].Name,
|
||||
Rrtype: dns.TypeA,
|
||||
Rrtype: rrType,
|
||||
Ttl: s.conf.BlockedResponseTTL,
|
||||
Class: dns.ClassINET,
|
||||
}
|
||||
answer.A = ip
|
||||
return answer
|
||||
}
|
||||
|
||||
func (s *Server) genAAAAAnswer(req *dns.Msg, ip net.IP) *dns.AAAA {
|
||||
answer := new(dns.AAAA)
|
||||
answer.Hdr = dns.RR_Header{
|
||||
Name: req.Question[0].Name,
|
||||
Rrtype: dns.TypeAAAA,
|
||||
Ttl: s.conf.BlockedResponseTTL,
|
||||
Class: dns.ClassINET,
|
||||
func (s *Server) genAnswerA(req *dns.Msg, ip net.IP) (ans *dns.A) {
|
||||
return &dns.A{
|
||||
Hdr: s.hdr(req, dns.TypeA),
|
||||
A: ip,
|
||||
}
|
||||
answer.AAAA = ip
|
||||
return answer
|
||||
}
|
||||
|
||||
func (s *Server) genTXTAnswer(req *dns.Msg, strs []string) (answer *dns.TXT) {
|
||||
func (s *Server) genAnswerAAAA(req *dns.Msg, ip net.IP) (ans *dns.AAAA) {
|
||||
return &dns.AAAA{
|
||||
Hdr: s.hdr(req, dns.TypeAAAA),
|
||||
AAAA: ip,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) genAnswerCNAME(req *dns.Msg, cname string) (ans *dns.CNAME) {
|
||||
return &dns.CNAME{
|
||||
Hdr: s.hdr(req, dns.TypeCNAME),
|
||||
Target: dns.Fqdn(cname),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) genAnswerMX(req *dns.Msg, mx *rules.DNSMX) (ans *dns.MX) {
|
||||
return &dns.MX{
|
||||
Hdr: s.hdr(req, dns.TypePTR),
|
||||
Preference: mx.Preference,
|
||||
Mx: mx.Exchange,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) genAnswerPTR(req *dns.Msg, ptr string) (ans *dns.PTR) {
|
||||
return &dns.PTR{
|
||||
Hdr: s.hdr(req, dns.TypePTR),
|
||||
Ptr: ptr,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) genAnswerTXT(req *dns.Msg, strs []string) (ans *dns.TXT) {
|
||||
return &dns.TXT{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: req.Question[0].Name,
|
||||
Rrtype: dns.TypeTXT,
|
||||
Ttl: s.conf.BlockedResponseTTL,
|
||||
Class: dns.ClassINET,
|
||||
},
|
||||
Hdr: s.hdr(req, dns.TypeTXT),
|
||||
Txt: strs,
|
||||
}
|
||||
}
|
||||
@@ -198,19 +216,6 @@ func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, d *proxy.DNSCo
|
||||
return resp
|
||||
}
|
||||
|
||||
// Make a CNAME response
|
||||
func (s *Server) genCNAMEAnswer(req *dns.Msg, cname string) *dns.CNAME {
|
||||
answer := new(dns.CNAME)
|
||||
answer.Hdr = dns.RR_Header{
|
||||
Name: req.Question[0].Name,
|
||||
Rrtype: dns.TypeCNAME,
|
||||
Ttl: s.conf.BlockedResponseTTL,
|
||||
Class: dns.ClassINET,
|
||||
}
|
||||
answer.Target = dns.Fqdn(cname)
|
||||
return answer
|
||||
}
|
||||
|
||||
// Create REFUSED DNS response
|
||||
func (s *Server) makeResponseREFUSED(request *dns.Msg) *dns.Msg {
|
||||
resp := dns.Msg{}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -13,13 +12,13 @@ import (
|
||||
)
|
||||
|
||||
// Write Stats data and logs
|
||||
func processQueryLogsAndStats(ctx *dnsContext) int {
|
||||
func processQueryLogsAndStats(ctx *dnsContext) (rc resultCode) {
|
||||
elapsed := time.Since(ctx.startTime)
|
||||
s := ctx.srv
|
||||
d := ctx.proxyCtx
|
||||
pctx := ctx.proxyCtx
|
||||
|
||||
shouldLog := true
|
||||
msg := d.Req
|
||||
msg := pctx.Req
|
||||
|
||||
// don't log ANY request if refuseAny is enabled
|
||||
if len(msg.Question) >= 1 && msg.Question[0].Qtype == dns.TypeANY && s.conf.RefuseAny {
|
||||
@@ -32,65 +31,67 @@ func processQueryLogsAndStats(ctx *dnsContext) int {
|
||||
if shouldLog && s.queryLog != nil {
|
||||
p := querylog.AddParams{
|
||||
Question: msg,
|
||||
Answer: d.Res,
|
||||
Answer: pctx.Res,
|
||||
OrigAnswer: ctx.origResp,
|
||||
Result: ctx.result,
|
||||
Elapsed: elapsed,
|
||||
ClientIP: getIP(d.Addr),
|
||||
ClientIP: IPFromAddr(pctx.Addr),
|
||||
ClientID: ctx.clientID,
|
||||
}
|
||||
|
||||
switch d.Proto {
|
||||
switch pctx.Proto {
|
||||
case proxy.ProtoHTTPS:
|
||||
p.ClientProto = querylog.ClientProtoDOH
|
||||
case proxy.ProtoQUIC:
|
||||
p.ClientProto = querylog.ClientProtoDOQ
|
||||
case proxy.ProtoTLS:
|
||||
p.ClientProto = querylog.ClientProtoDOT
|
||||
case proxy.ProtoDNSCrypt:
|
||||
p.ClientProto = querylog.ClientProtoDNSCrypt
|
||||
default:
|
||||
// Consider this a plain DNS-over-UDP or DNS-over-TCL
|
||||
// Consider this a plain DNS-over-UDP or DNS-over-TCP
|
||||
// request.
|
||||
}
|
||||
|
||||
if d.Upstream != nil {
|
||||
p.Upstream = d.Upstream.Address()
|
||||
if pctx.Upstream != nil {
|
||||
p.Upstream = pctx.Upstream.Address()
|
||||
}
|
||||
|
||||
s.queryLog.Add(p)
|
||||
}
|
||||
|
||||
s.updateStats(d, elapsed, *ctx.result)
|
||||
s.updateStats(ctx, elapsed, *ctx.result)
|
||||
s.RUnlock()
|
||||
|
||||
return resultDone
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
func (s *Server) updateStats(d *proxy.DNSContext, elapsed time.Duration, res dnsfilter.Result) {
|
||||
func (s *Server) updateStats(ctx *dnsContext, elapsed time.Duration, res dnsfilter.Result) {
|
||||
if s.stats == nil {
|
||||
return
|
||||
}
|
||||
|
||||
pctx := ctx.proxyCtx
|
||||
e := stats.Entry{}
|
||||
e.Domain = strings.ToLower(d.Req.Question[0].Name)
|
||||
e.Domain = strings.ToLower(pctx.Req.Question[0].Name)
|
||||
e.Domain = e.Domain[:len(e.Domain)-1] // remove last "."
|
||||
switch addr := d.Addr.(type) {
|
||||
case *net.UDPAddr:
|
||||
e.Client = addr.IP
|
||||
case *net.TCPAddr:
|
||||
e.Client = addr.IP
|
||||
|
||||
if clientID := ctx.clientID; clientID != "" {
|
||||
e.Client = clientID
|
||||
} else if ip := IPFromAddr(pctx.Addr); ip != nil {
|
||||
e.Client = ip.String()
|
||||
}
|
||||
|
||||
e.Time = uint32(elapsed / 1000)
|
||||
e.Result = stats.RNotFiltered
|
||||
|
||||
switch res.Reason {
|
||||
|
||||
case dnsfilter.FilteredSafeBrowsing:
|
||||
e.Result = stats.RSafeBrowsing
|
||||
|
||||
case dnsfilter.FilteredParental:
|
||||
e.Result = stats.RParental
|
||||
|
||||
case dnsfilter.FilteredSafeSearch:
|
||||
e.Result = stats.RSafeSearch
|
||||
|
||||
case dnsfilter.FilteredBlockList:
|
||||
fallthrough
|
||||
case dnsfilter.FilteredInvalid:
|
||||
|
||||
198
internal/dnsforward/stats_test.go
Normal file
198
internal/dnsforward/stats_test.go
Normal file
@@ -0,0 +1,198 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsfilter"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// testQueryLog is a simple querylog.QueryLog implementation for tests.
|
||||
type testQueryLog struct {
|
||||
// QueryLog is embedded here simply to make testQueryLog
|
||||
// a querylog.QueryLog without acctually implementing all methods.
|
||||
querylog.QueryLog
|
||||
|
||||
lastParams querylog.AddParams
|
||||
}
|
||||
|
||||
// Add implements the querylog.QueryLog interface for *testQueryLog.
|
||||
func (l *testQueryLog) Add(p querylog.AddParams) {
|
||||
l.lastParams = p
|
||||
}
|
||||
|
||||
// testStats is a simple stats.Stats implementation for tests.
|
||||
type testStats struct {
|
||||
// Stats is embedded here simply to make testStats a stats.Stats without
|
||||
// acctually implementing all methods.
|
||||
stats.Stats
|
||||
|
||||
lastEntry stats.Entry
|
||||
}
|
||||
|
||||
// Update implements the stats.Stats interface for *testStats.
|
||||
func (l *testStats) Update(e stats.Entry) {
|
||||
l.lastEntry = e
|
||||
}
|
||||
|
||||
func TestProcessQueryLogsAndStats(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
proto string
|
||||
addr net.Addr
|
||||
clientID string
|
||||
wantLogProto querylog.ClientProto
|
||||
wantStatClient string
|
||||
wantCode resultCode
|
||||
reason dnsfilter.Reason
|
||||
wantStatResult stats.Result
|
||||
}{{
|
||||
name: "success_udp",
|
||||
proto: proxy.ProtoUDP,
|
||||
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
|
||||
clientID: "",
|
||||
wantLogProto: "",
|
||||
wantStatClient: "1.2.3.4",
|
||||
wantCode: resultCodeSuccess,
|
||||
reason: dnsfilter.NotFilteredNotFound,
|
||||
wantStatResult: stats.RNotFiltered,
|
||||
}, {
|
||||
name: "success_tls_client_id",
|
||||
proto: proxy.ProtoTLS,
|
||||
addr: &net.TCPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
|
||||
clientID: "cli42",
|
||||
wantLogProto: querylog.ClientProtoDOT,
|
||||
wantStatClient: "cli42",
|
||||
wantCode: resultCodeSuccess,
|
||||
reason: dnsfilter.NotFilteredNotFound,
|
||||
wantStatResult: stats.RNotFiltered,
|
||||
}, {
|
||||
name: "success_tls",
|
||||
proto: proxy.ProtoTLS,
|
||||
addr: &net.TCPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
|
||||
clientID: "",
|
||||
wantLogProto: querylog.ClientProtoDOT,
|
||||
wantStatClient: "1.2.3.4",
|
||||
wantCode: resultCodeSuccess,
|
||||
reason: dnsfilter.NotFilteredNotFound,
|
||||
wantStatResult: stats.RNotFiltered,
|
||||
}, {
|
||||
name: "success_quic",
|
||||
proto: proxy.ProtoQUIC,
|
||||
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
|
||||
clientID: "",
|
||||
wantLogProto: querylog.ClientProtoDOQ,
|
||||
wantStatClient: "1.2.3.4",
|
||||
wantCode: resultCodeSuccess,
|
||||
reason: dnsfilter.NotFilteredNotFound,
|
||||
wantStatResult: stats.RNotFiltered,
|
||||
}, {
|
||||
name: "success_https",
|
||||
proto: proxy.ProtoHTTPS,
|
||||
addr: &net.TCPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
|
||||
clientID: "",
|
||||
wantLogProto: querylog.ClientProtoDOH,
|
||||
wantStatClient: "1.2.3.4",
|
||||
wantCode: resultCodeSuccess,
|
||||
reason: dnsfilter.NotFilteredNotFound,
|
||||
wantStatResult: stats.RNotFiltered,
|
||||
}, {
|
||||
name: "success_dnscrypt",
|
||||
proto: proxy.ProtoDNSCrypt,
|
||||
addr: &net.TCPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
|
||||
clientID: "",
|
||||
wantLogProto: querylog.ClientProtoDNSCrypt,
|
||||
wantStatClient: "1.2.3.4",
|
||||
wantCode: resultCodeSuccess,
|
||||
reason: dnsfilter.NotFilteredNotFound,
|
||||
wantStatResult: stats.RNotFiltered,
|
||||
}, {
|
||||
name: "success_udp_filtered",
|
||||
proto: proxy.ProtoUDP,
|
||||
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
|
||||
clientID: "",
|
||||
wantLogProto: "",
|
||||
wantStatClient: "1.2.3.4",
|
||||
wantCode: resultCodeSuccess,
|
||||
reason: dnsfilter.FilteredBlockList,
|
||||
wantStatResult: stats.RFiltered,
|
||||
}, {
|
||||
name: "success_udp_sb",
|
||||
proto: proxy.ProtoUDP,
|
||||
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
|
||||
clientID: "",
|
||||
wantLogProto: "",
|
||||
wantStatClient: "1.2.3.4",
|
||||
wantCode: resultCodeSuccess,
|
||||
reason: dnsfilter.FilteredSafeBrowsing,
|
||||
wantStatResult: stats.RSafeBrowsing,
|
||||
}, {
|
||||
name: "success_udp_ss",
|
||||
proto: proxy.ProtoUDP,
|
||||
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
|
||||
clientID: "",
|
||||
wantLogProto: "",
|
||||
wantStatClient: "1.2.3.4",
|
||||
wantCode: resultCodeSuccess,
|
||||
reason: dnsfilter.FilteredSafeSearch,
|
||||
wantStatResult: stats.RSafeSearch,
|
||||
}, {
|
||||
name: "success_udp_pc",
|
||||
proto: proxy.ProtoUDP,
|
||||
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
|
||||
clientID: "",
|
||||
wantLogProto: "",
|
||||
wantStatClient: "1.2.3.4",
|
||||
wantCode: resultCodeSuccess,
|
||||
reason: dnsfilter.FilteredParental,
|
||||
wantStatResult: stats.RParental,
|
||||
}}
|
||||
|
||||
ups, err := upstream.AddressToUpstream("1.1.1.1", upstream.Options{})
|
||||
assert.Nil(t, err)
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := &dns.Msg{
|
||||
Question: []dns.Question{{
|
||||
Name: "example.com.",
|
||||
}},
|
||||
}
|
||||
pctx := &proxy.DNSContext{
|
||||
Proto: tc.proto,
|
||||
Req: req,
|
||||
Res: &dns.Msg{},
|
||||
Addr: tc.addr,
|
||||
Upstream: ups,
|
||||
}
|
||||
|
||||
ql := &testQueryLog{}
|
||||
st := &testStats{}
|
||||
dctx := &dnsContext{
|
||||
srv: &Server{
|
||||
queryLog: ql,
|
||||
stats: st,
|
||||
},
|
||||
proxyCtx: pctx,
|
||||
startTime: time.Now(),
|
||||
result: &dnsfilter.Result{
|
||||
Reason: tc.reason,
|
||||
},
|
||||
clientID: tc.clientID,
|
||||
}
|
||||
|
||||
code := processQueryLogsAndStats(dctx)
|
||||
assert.Equal(t, tc.wantCode, code)
|
||||
assert.Equal(t, tc.wantLogProto, ql.lastParams.ClientProto)
|
||||
assert.Equal(t, tc.wantStatClient, st.lastEntry.Client)
|
||||
assert.Equal(t, tc.wantStatResult, st.lastEntry.Result)
|
||||
})
|
||||
}
|
||||
}
|
||||
168
internal/dnsforward/svcbmsg.go
Normal file
168
internal/dnsforward/svcbmsg.go
Normal file
@@ -0,0 +1,168 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"net"
|
||||
"strconv"
|
||||
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// genAnswerHTTPS returns a properly initialized HTTPS resource record.
|
||||
//
|
||||
// See the comment on genAnswerSVCB for a list of current restrictions on
|
||||
// parameter values.
|
||||
func (s *Server) genAnswerHTTPS(req *dns.Msg, svcb *rules.DNSSVCB) (ans *dns.HTTPS) {
|
||||
ans = &dns.HTTPS{
|
||||
SVCB: *s.genAnswerSVCB(req, svcb),
|
||||
}
|
||||
|
||||
ans.Hdr.Rrtype = dns.TypeHTTPS
|
||||
|
||||
return ans
|
||||
}
|
||||
|
||||
// strToSVCBKey is the string-to-svcb-key mapping.
|
||||
//
|
||||
// See https://github.com/miekg/dns/blob/23c4faca9d32b0abbb6e179aa1aadc45ac53a916/svcb.go#L27.
|
||||
//
|
||||
// TODO(a.garipov): Propose exporting this API or something similar in the
|
||||
// github.com/miekg/dns module.
|
||||
var strToSVCBKey = map[string]dns.SVCBKey{
|
||||
"alpn": dns.SVCB_ALPN,
|
||||
"echconfig": dns.SVCB_ECHCONFIG,
|
||||
"ipv4hint": dns.SVCB_IPV4HINT,
|
||||
"ipv6hint": dns.SVCB_IPV6HINT,
|
||||
"mandatory": dns.SVCB_MANDATORY,
|
||||
"no-default-alpn": dns.SVCB_NO_DEFAULT_ALPN,
|
||||
"port": dns.SVCB_PORT,
|
||||
}
|
||||
|
||||
// svcbKeyHandler is a handler for one SVCB parameter key.
|
||||
type svcbKeyHandler func(valStr string) (val dns.SVCBKeyValue)
|
||||
|
||||
// svcbKeyHandlers are the supported SVCB parameters handlers.
|
||||
var svcbKeyHandlers = map[string]svcbKeyHandler{
|
||||
"alpn": func(valStr string) (val dns.SVCBKeyValue) {
|
||||
return &dns.SVCBAlpn{
|
||||
Alpn: []string{valStr},
|
||||
}
|
||||
},
|
||||
|
||||
"echconfig": func(valStr string) (val dns.SVCBKeyValue) {
|
||||
ech, err := base64.StdEncoding.DecodeString(valStr)
|
||||
if err != nil {
|
||||
log.Debug("can't parse svcb/https echconfig: %s; ignoring", err)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return &dns.SVCBECHConfig{
|
||||
ECH: ech,
|
||||
}
|
||||
},
|
||||
|
||||
"ipv4hint": func(valStr string) (val dns.SVCBKeyValue) {
|
||||
ip := net.ParseIP(valStr)
|
||||
if ip4 := ip.To4(); ip == nil || ip4 == nil {
|
||||
log.Debug("can't parse svcb/https ipv4 hint %q; ignoring", valStr)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return &dns.SVCBIPv4Hint{
|
||||
Hint: []net.IP{ip},
|
||||
}
|
||||
},
|
||||
|
||||
"ipv6hint": func(valStr string) (val dns.SVCBKeyValue) {
|
||||
ip := net.ParseIP(valStr)
|
||||
if ip == nil {
|
||||
log.Debug("can't parse svcb/https ipv6 hint %q; ignoring", valStr)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return &dns.SVCBIPv6Hint{
|
||||
Hint: []net.IP{ip},
|
||||
}
|
||||
},
|
||||
|
||||
"mandatory": func(valStr string) (val dns.SVCBKeyValue) {
|
||||
code, ok := strToSVCBKey[valStr]
|
||||
if !ok {
|
||||
log.Debug("unknown svcb/https mandatory key %q, ignoring", valStr)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return &dns.SVCBMandatory{
|
||||
Code: []dns.SVCBKey{code},
|
||||
}
|
||||
},
|
||||
|
||||
"no-default-alpn": func(_ string) (val dns.SVCBKeyValue) {
|
||||
return &dns.SVCBNoDefaultAlpn{}
|
||||
},
|
||||
|
||||
"port": func(valStr string) (val dns.SVCBKeyValue) {
|
||||
port64, err := strconv.ParseUint(valStr, 10, 16)
|
||||
if err != nil {
|
||||
log.Debug("can't parse svcb/https port: %s; ignoring", err)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return &dns.SVCBPort{
|
||||
Port: uint16(port64),
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
// genAnswerSVCB returns a properly initialized SVCB resource record.
|
||||
//
|
||||
// Currently, there are several restrictions on how the parameters are parsed.
|
||||
// Firstly, the parsing of non-contiguous values isn't supported. Secondly, the
|
||||
// parsing of value-lists is not supported either.
|
||||
//
|
||||
// ipv4hint=127.0.0.1 // Supported.
|
||||
// ipv4hint="127.0.0.1" // Unsupported.
|
||||
// ipv4hint=127.0.0.1,127.0.0.2 // Unsupported.
|
||||
// ipv4hint="127.0.0.1,127.0.0.2" // Unsupported.
|
||||
//
|
||||
// TODO(a.garipov): Support all of these.
|
||||
func (s *Server) genAnswerSVCB(req *dns.Msg, svcb *rules.DNSSVCB) (ans *dns.SVCB) {
|
||||
ans = &dns.SVCB{
|
||||
Hdr: s.hdr(req, dns.TypeSVCB),
|
||||
Priority: svcb.Priority,
|
||||
Target: svcb.Target,
|
||||
}
|
||||
if len(svcb.Params) == 0 {
|
||||
return ans
|
||||
}
|
||||
|
||||
values := make([]dns.SVCBKeyValue, 0, len(svcb.Params))
|
||||
for k, valStr := range svcb.Params {
|
||||
handler, ok := svcbKeyHandlers[k]
|
||||
if !ok {
|
||||
log.Debug("unknown svcb/https key %q, ignoring", k)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
val := handler(valStr)
|
||||
if val == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
values = append(values, val)
|
||||
}
|
||||
|
||||
if len(values) > 0 {
|
||||
ans.Value = values
|
||||
}
|
||||
|
||||
return ans
|
||||
}
|
||||
154
internal/dnsforward/svcbmsg_test.go
Normal file
154
internal/dnsforward/svcbmsg_test.go
Normal file
@@ -0,0 +1,154 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestGenAnswerHTTPS_andSVCB(t *testing.T) {
|
||||
// Preconditions.
|
||||
|
||||
s := &Server{
|
||||
conf: ServerConfig{
|
||||
FilteringConfig: FilteringConfig{
|
||||
BlockedResponseTTL: 3600,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
req := &dns.Msg{
|
||||
Question: []dns.Question{{
|
||||
Name: "abcd",
|
||||
}},
|
||||
}
|
||||
|
||||
// Constants and helper values.
|
||||
|
||||
const host = "example.com"
|
||||
const prio = 32
|
||||
|
||||
ip4 := net.IPv4(127, 0, 0, 1)
|
||||
ip6 := net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
|
||||
|
||||
// Helper functions.
|
||||
|
||||
dnssvcb := func(key, value string) (svcb *rules.DNSSVCB) {
|
||||
svcb = &rules.DNSSVCB{
|
||||
Target: host,
|
||||
Priority: prio,
|
||||
}
|
||||
|
||||
if key == "" {
|
||||
return svcb
|
||||
}
|
||||
|
||||
svcb.Params = map[string]string{
|
||||
key: value,
|
||||
}
|
||||
|
||||
return svcb
|
||||
}
|
||||
|
||||
wantsvcb := func(kv dns.SVCBKeyValue) (want *dns.SVCB) {
|
||||
want = &dns.SVCB{
|
||||
Hdr: s.hdr(req, dns.TypeSVCB),
|
||||
Priority: prio,
|
||||
Target: host,
|
||||
}
|
||||
|
||||
if kv == nil {
|
||||
return want
|
||||
}
|
||||
|
||||
want.Value = []dns.SVCBKeyValue{kv}
|
||||
|
||||
return want
|
||||
}
|
||||
|
||||
// Tests.
|
||||
|
||||
testCases := []struct {
|
||||
svcb *rules.DNSSVCB
|
||||
want *dns.SVCB
|
||||
name string
|
||||
}{{
|
||||
svcb: dnssvcb("", ""),
|
||||
want: wantsvcb(nil),
|
||||
name: "no_params",
|
||||
}, {
|
||||
svcb: dnssvcb("foo", "bar"),
|
||||
want: wantsvcb(nil),
|
||||
name: "invalid",
|
||||
}, {
|
||||
svcb: dnssvcb("alpn", "h3"),
|
||||
want: wantsvcb(&dns.SVCBAlpn{Alpn: []string{"h3"}}),
|
||||
name: "alpn",
|
||||
}, {
|
||||
svcb: dnssvcb("echconfig", "AAAA"),
|
||||
want: wantsvcb(&dns.SVCBECHConfig{ECH: []byte{0, 0, 0}}),
|
||||
name: "echconfig",
|
||||
}, {
|
||||
svcb: dnssvcb("echconfig", "%BAD%"),
|
||||
want: wantsvcb(nil),
|
||||
name: "echconfig_invalid",
|
||||
}, {
|
||||
svcb: dnssvcb("ipv4hint", "127.0.0.1"),
|
||||
want: wantsvcb(&dns.SVCBIPv4Hint{Hint: []net.IP{ip4}}),
|
||||
name: "ipv4hint",
|
||||
}, {
|
||||
svcb: dnssvcb("ipv4hint", "127.0.01"),
|
||||
want: wantsvcb(nil),
|
||||
name: "ipv4hint_invalid",
|
||||
}, {
|
||||
svcb: dnssvcb("ipv6hint", "::1"),
|
||||
want: wantsvcb(&dns.SVCBIPv6Hint{Hint: []net.IP{ip6}}),
|
||||
name: "ipv6hint",
|
||||
}, {
|
||||
svcb: dnssvcb("ipv6hint", ":::1"),
|
||||
want: wantsvcb(nil),
|
||||
name: "ipv6hint_invalid",
|
||||
}, {
|
||||
svcb: dnssvcb("mandatory", "alpn"),
|
||||
want: wantsvcb(&dns.SVCBMandatory{Code: []dns.SVCBKey{dns.SVCB_ALPN}}),
|
||||
name: "mandatory",
|
||||
}, {
|
||||
svcb: dnssvcb("mandatory", "alpnn"),
|
||||
want: wantsvcb(nil),
|
||||
name: "mandatory_invalid",
|
||||
}, {
|
||||
svcb: dnssvcb("no-default-alpn", ""),
|
||||
want: wantsvcb(&dns.SVCBNoDefaultAlpn{}),
|
||||
name: "no-default-alpn",
|
||||
}, {
|
||||
svcb: dnssvcb("port", "8080"),
|
||||
want: wantsvcb(&dns.SVCBPort{Port: 8080}),
|
||||
name: "port",
|
||||
}, {
|
||||
svcb: dnssvcb("port", "1005008080"),
|
||||
want: wantsvcb(nil),
|
||||
name: "port",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run("https", func(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
want := &dns.HTTPS{SVCB: *tc.want}
|
||||
want.Hdr.Rrtype = dns.TypeHTTPS
|
||||
|
||||
got := s.genAnswerHTTPS(req, tc.svcb)
|
||||
assert.Equal(t, want, got)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("svcb", func(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := s.genAnswerSVCB(req, tc.svcb)
|
||||
assert.Equal(t, tc.want, got)
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -8,38 +8,8 @@ import (
|
||||
"github.com/AdguardTeam/golibs/utils"
|
||||
)
|
||||
|
||||
// GetIPString is a helper function that extracts IP address from net.Addr
|
||||
func GetIPString(addr net.Addr) string {
|
||||
switch addr := addr.(type) {
|
||||
case *net.UDPAddr:
|
||||
return addr.IP.String()
|
||||
case *net.TCPAddr:
|
||||
return addr.IP.String()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func stringArrayDup(a []string) []string {
|
||||
a2 := make([]string, len(a))
|
||||
copy(a2, a)
|
||||
return a2
|
||||
}
|
||||
|
||||
// Get IP address from net.Addr object
|
||||
// Note: we can't use net.SplitHostPort(a.String()) because of IPv6 zone:
|
||||
// https://github.com/AdguardTeam/AdGuardHome/internal/issues/1261
|
||||
func ipFromAddr(a net.Addr) string {
|
||||
switch addr := a.(type) {
|
||||
case *net.UDPAddr:
|
||||
return addr.IP.String()
|
||||
case *net.TCPAddr:
|
||||
return addr.IP.String()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Get IP address from net.Addr
|
||||
func getIP(addr net.Addr) net.IP {
|
||||
// IPFromAddr gets IP address from addr.
|
||||
func IPFromAddr(addr net.Addr) (ip net.IP) {
|
||||
switch addr := addr.(type) {
|
||||
case *net.UDPAddr:
|
||||
return addr.IP
|
||||
@@ -49,6 +19,23 @@ func getIP(addr net.Addr) net.IP {
|
||||
return nil
|
||||
}
|
||||
|
||||
// IPStringFromAddr extracts IP address from net.Addr.
|
||||
// Note: we can't use net.SplitHostPort(a.String()) because of IPv6 zone:
|
||||
// https://github.com/AdguardTeam/AdGuardHome/internal/issues/1261
|
||||
func IPStringFromAddr(addr net.Addr) (ipStr string) {
|
||||
if ip := IPFromAddr(addr); ip != nil {
|
||||
return ip.String()
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func stringArrayDup(a []string) []string {
|
||||
a2 := make([]string, len(a))
|
||||
copy(a2, a)
|
||||
return a2
|
||||
}
|
||||
|
||||
// Find value in a sorted array
|
||||
func findSorted(ar []string, val string) int {
|
||||
i := sort.SearchStrings(ar, val)
|
||||
|
||||
@@ -2,13 +2,10 @@ package home
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -20,8 +17,12 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
cookieTTL = 365 * 24 // in hours
|
||||
// cookieTTL is given in hours.
|
||||
cookieTTL = 365 * 24
|
||||
sessionCookieName = "agh_session"
|
||||
|
||||
// sessionTokenSize is the length of session token in bytes.
|
||||
sessionTokenSize = 16
|
||||
)
|
||||
|
||||
type session struct {
|
||||
@@ -59,10 +60,10 @@ func (s *session) deserialize(data []byte) bool {
|
||||
// Auth - global object
|
||||
type Auth struct {
|
||||
db *bbolt.DB
|
||||
sessions map[string]*session // session name -> session data
|
||||
lock sync.Mutex
|
||||
sessions map[string]*session
|
||||
users []User
|
||||
sessionTTL uint32 // in seconds
|
||||
lock sync.Mutex
|
||||
sessionTTL uint32
|
||||
}
|
||||
|
||||
// User object
|
||||
@@ -223,24 +224,35 @@ func (a *Auth) removeSession(sess []byte) {
|
||||
log.Debug("Auth: removed session from DB")
|
||||
}
|
||||
|
||||
// CheckSession - check if session is valid
|
||||
// Return 0 if OK; -1 if session doesn't exist; 1 if session has expired
|
||||
func (a *Auth) CheckSession(sess string) int {
|
||||
// checkSessionResult is the result of checking a session.
|
||||
type checkSessionResult int
|
||||
|
||||
// checkSessionResult constants.
|
||||
const (
|
||||
checkSessionOK checkSessionResult = 0
|
||||
checkSessionNotFound checkSessionResult = -1
|
||||
checkSessionExpired checkSessionResult = 1
|
||||
)
|
||||
|
||||
// checkSession checks if the session is valid.
|
||||
func (a *Auth) checkSession(sess string) (res checkSessionResult) {
|
||||
now := uint32(time.Now().UTC().Unix())
|
||||
update := false
|
||||
|
||||
a.lock.Lock()
|
||||
defer a.lock.Unlock()
|
||||
|
||||
s, ok := a.sessions[sess]
|
||||
if !ok {
|
||||
a.lock.Unlock()
|
||||
return -1
|
||||
return checkSessionNotFound
|
||||
}
|
||||
|
||||
if s.expire <= now {
|
||||
delete(a.sessions, sess)
|
||||
key, _ := hex.DecodeString(sess)
|
||||
a.removeSession(key)
|
||||
a.lock.Unlock()
|
||||
return 1
|
||||
|
||||
return checkSessionExpired
|
||||
}
|
||||
|
||||
newExpire := now + a.sessionTTL
|
||||
@@ -250,8 +262,6 @@ func (a *Auth) CheckSession(sess string) int {
|
||||
s.expire = newExpire
|
||||
}
|
||||
|
||||
a.lock.Unlock()
|
||||
|
||||
if update {
|
||||
key, _ := hex.DecodeString(sess)
|
||||
if a.storeSession(key, s) {
|
||||
@@ -259,7 +269,7 @@ func (a *Auth) CheckSession(sess string) int {
|
||||
}
|
||||
}
|
||||
|
||||
return 0
|
||||
return checkSessionOK
|
||||
}
|
||||
|
||||
// RemoveSession - remove session
|
||||
@@ -276,16 +286,29 @@ type loginJSON struct {
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
func getSession(u *User) ([]byte, error) {
|
||||
maxSalt := big.NewInt(math.MaxUint32)
|
||||
salt, err := rand.Int(rand.Reader, maxSalt)
|
||||
// newSessionToken returns cryptographically secure randomly generated slice of
|
||||
// bytes of sessionTokenSize length.
|
||||
//
|
||||
// TODO(e.burkov): Think about using byte array instead of byte slice.
|
||||
func newSessionToken() (data []byte, err error) {
|
||||
randData := make([]byte, sessionTokenSize)
|
||||
|
||||
_, err = rand.Read(randData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
d := []byte(fmt.Sprintf("%s%s%s", salt, u.Name, u.PasswordHash))
|
||||
hash := sha256.Sum256(d)
|
||||
return hash[:], nil
|
||||
return randData, nil
|
||||
}
|
||||
|
||||
// cookieTimeFormat is the format to be used in (time.Time).Format for cookie's
|
||||
// expiry field.
|
||||
const cookieTimeFormat = "Mon, 02 Jan 2006 15:04:05 GMT"
|
||||
|
||||
// cookieExpiryFormat returns the formatted exp to be used in cookie string.
|
||||
// It's quite simple for now, but probably will be expanded in the future.
|
||||
func cookieExpiryFormat(exp time.Time) (formatted string) {
|
||||
return exp.Format(cookieTimeFormat)
|
||||
}
|
||||
|
||||
func (a *Auth) httpCookie(req loginJSON) (string, error) {
|
||||
@@ -294,24 +317,23 @@ func (a *Auth) httpCookie(req loginJSON) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
sess, err := getSession(&u)
|
||||
sess, err := newSessionToken()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
expire := now.Add(cookieTTL * time.Hour)
|
||||
expstr := expire.Format(time.RFC1123)
|
||||
expstr = expstr[:len(expstr)-len("UTC")] // "UTC" -> "GMT"
|
||||
expstr += "GMT"
|
||||
|
||||
s := session{}
|
||||
s.userName = u.Name
|
||||
s.expire = uint32(now.Unix()) + a.sessionTTL
|
||||
a.addSession(sess, &s)
|
||||
a.addSession(sess, &session{
|
||||
userName: u.Name,
|
||||
expire: uint32(now.Unix()) + a.sessionTTL,
|
||||
})
|
||||
|
||||
return fmt.Sprintf("%s=%s; Path=/; HttpOnly; Expires=%s",
|
||||
sessionCookieName, hex.EncodeToString(sess), expstr), nil
|
||||
return fmt.Sprintf(
|
||||
"%s=%s; Path=/; HttpOnly; Expires=%s",
|
||||
sessionCookieName, hex.EncodeToString(sess),
|
||||
cookieExpiryFormat(now.Add(cookieTTL*time.Hour)),
|
||||
), nil
|
||||
}
|
||||
|
||||
func handleLogin(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -360,8 +382,8 @@ func handleLogout(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// RegisterAuthHandlers - register handlers
|
||||
func RegisterAuthHandlers() {
|
||||
Context.mux.Handle("/control/login", postInstallHandler(ensureHandler("POST", handleLogin)))
|
||||
httpRegister("GET", "/control/logout", handleLogout)
|
||||
Context.mux.Handle("/control/login", postInstallHandler(ensureHandler(http.MethodPost, handleLogin)))
|
||||
httpRegister(http.MethodGet, "/control/logout", handleLogout)
|
||||
}
|
||||
|
||||
func parseCookie(cookie string) string {
|
||||
@@ -392,8 +414,8 @@ func optionalAuthThird(w http.ResponseWriter, r *http.Request) (authFirst bool)
|
||||
ok = true
|
||||
|
||||
} else if err == nil {
|
||||
r := Context.auth.CheckSession(cookie.Value)
|
||||
if r == 0 {
|
||||
r := Context.auth.checkSession(cookie.Value)
|
||||
if r == checkSessionOK {
|
||||
ok = true
|
||||
} else if r < 0 {
|
||||
log.Debug("Auth: invalid cookie value: %s", cookie)
|
||||
@@ -434,12 +456,13 @@ func optionalAuth(handler func(http.ResponseWriter, *http.Request)) func(http.Re
|
||||
authRequired := Context.auth != nil && Context.auth.AuthRequired()
|
||||
cookie, err := r.Cookie(sessionCookieName)
|
||||
if authRequired && err == nil {
|
||||
r := Context.auth.CheckSession(cookie.Value)
|
||||
if r == 0 {
|
||||
r := Context.auth.checkSession(cookie.Value)
|
||||
if r == checkSessionOK {
|
||||
w.Header().Set("Location", "/")
|
||||
w.WriteHeader(http.StatusFound)
|
||||
|
||||
return
|
||||
} else if r < 0 {
|
||||
} else if r == checkSessionNotFound {
|
||||
log.Debug("Auth: invalid cookie value: %s", cookie)
|
||||
}
|
||||
}
|
||||
@@ -503,32 +526,34 @@ func (a *Auth) UserFind(login, password string) User {
|
||||
return User{}
|
||||
}
|
||||
|
||||
// GetCurrentUser - get the current user
|
||||
func (a *Auth) GetCurrentUser(r *http.Request) User {
|
||||
// getCurrentUser returns the current user. It returns an empty User if the
|
||||
// user is not found.
|
||||
func (a *Auth) getCurrentUser(r *http.Request) User {
|
||||
cookie, err := r.Cookie(sessionCookieName)
|
||||
if err != nil {
|
||||
// there's no Cookie, check Basic authentication
|
||||
// There's no Cookie, check Basic authentication.
|
||||
user, pass, ok := r.BasicAuth()
|
||||
if ok {
|
||||
u := Context.auth.UserFind(user, pass)
|
||||
return u
|
||||
return Context.auth.UserFind(user, pass)
|
||||
}
|
||||
|
||||
return User{}
|
||||
}
|
||||
|
||||
a.lock.Lock()
|
||||
defer a.lock.Unlock()
|
||||
|
||||
s, ok := a.sessions[cookie.Value]
|
||||
if !ok {
|
||||
a.lock.Unlock()
|
||||
return User{}
|
||||
}
|
||||
|
||||
for _, u := range a.users {
|
||||
if u.Name == s.userName {
|
||||
a.lock.Unlock()
|
||||
return u
|
||||
}
|
||||
}
|
||||
a.lock.Unlock()
|
||||
|
||||
return User{}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@@ -9,12 +11,13 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/testutil"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
testutil.DiscardLogOutput(m)
|
||||
aghtest.DiscardLogOutput(m)
|
||||
}
|
||||
|
||||
func prepareTestDir() string {
|
||||
@@ -24,24 +27,44 @@ func prepareTestDir() string {
|
||||
return dir
|
||||
}
|
||||
|
||||
func TestNewSessionToken(t *testing.T) {
|
||||
// Successful case.
|
||||
token, err := newSessionToken()
|
||||
require.Nil(t, err)
|
||||
assert.Len(t, token, sessionTokenSize)
|
||||
|
||||
// Break the rand.Reader.
|
||||
prevReader := rand.Reader
|
||||
t.Cleanup(func() {
|
||||
rand.Reader = prevReader
|
||||
})
|
||||
rand.Reader = &bytes.Buffer{}
|
||||
|
||||
// Unsuccessful case.
|
||||
token, err = newSessionToken()
|
||||
require.NotNil(t, err)
|
||||
assert.Empty(t, token)
|
||||
}
|
||||
|
||||
func TestAuth(t *testing.T) {
|
||||
dir := prepareTestDir()
|
||||
defer func() { _ = os.RemoveAll(dir) }()
|
||||
t.Cleanup(func() { _ = os.RemoveAll(dir) })
|
||||
fn := filepath.Join(dir, "sessions.db")
|
||||
|
||||
users := []User{
|
||||
{Name: "name", PasswordHash: "$2y$05$..vyzAECIhJPfaQiOK17IukcQnqEgKJHy0iETyYqxn3YXJl8yZuo2"},
|
||||
}
|
||||
users := []User{{
|
||||
Name: "name",
|
||||
PasswordHash: "$2y$05$..vyzAECIhJPfaQiOK17IukcQnqEgKJHy0iETyYqxn3YXJl8yZuo2",
|
||||
}}
|
||||
a := InitAuth(fn, nil, 60)
|
||||
s := session{}
|
||||
|
||||
user := User{Name: "name"}
|
||||
a.UserAdd(&user, "password")
|
||||
|
||||
assert.True(t, a.CheckSession("notfound") == -1)
|
||||
assert.Equal(t, checkSessionNotFound, a.checkSession("notfound"))
|
||||
a.RemoveSession("notfound")
|
||||
|
||||
sess, err := getSession(&users[0])
|
||||
sess, err := newSessionToken()
|
||||
assert.Nil(t, err)
|
||||
sessStr := hex.EncodeToString(sess)
|
||||
|
||||
@@ -49,13 +72,13 @@ func TestAuth(t *testing.T) {
|
||||
// check expiration
|
||||
s.expire = uint32(now)
|
||||
a.addSession(sess, &s)
|
||||
assert.True(t, a.CheckSession(sessStr) == 1)
|
||||
assert.Equal(t, checkSessionExpired, a.checkSession(sessStr))
|
||||
|
||||
// add session with TTL = 2 sec
|
||||
s = session{}
|
||||
s.expire = uint32(time.Now().UTC().Unix() + 2)
|
||||
a.addSession(sess, &s)
|
||||
assert.True(t, a.CheckSession(sessStr) == 0)
|
||||
assert.Equal(t, checkSessionOK, a.checkSession(sessStr))
|
||||
|
||||
a.Close()
|
||||
|
||||
@@ -63,23 +86,22 @@ func TestAuth(t *testing.T) {
|
||||
a = InitAuth(fn, users, 60)
|
||||
|
||||
// the session is still alive
|
||||
assert.True(t, a.CheckSession(sessStr) == 0)
|
||||
// reset our expiration time because CheckSession() has just updated it
|
||||
assert.Equal(t, checkSessionOK, a.checkSession(sessStr))
|
||||
// reset our expiration time because checkSession() has just updated it
|
||||
s.expire = uint32(time.Now().UTC().Unix() + 2)
|
||||
a.storeSession(sess, &s)
|
||||
a.Close()
|
||||
|
||||
u := a.UserFind("name", "password")
|
||||
assert.True(t, len(u.Name) != 0)
|
||||
assert.NotEmpty(t, u.Name)
|
||||
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
// load and remove expired sessions
|
||||
a = InitAuth(fn, users, 60)
|
||||
assert.True(t, a.CheckSession(sessStr) == -1)
|
||||
assert.Equal(t, checkSessionNotFound, a.checkSession(sessStr))
|
||||
|
||||
a.Close()
|
||||
os.Remove(fn)
|
||||
}
|
||||
|
||||
// implements http.ResponseWriter
|
||||
@@ -111,7 +133,7 @@ func TestAuthHTTP(t *testing.T) {
|
||||
Context.auth = InitAuth(fn, users, 60)
|
||||
|
||||
handlerCalled := false
|
||||
handler := func(w http.ResponseWriter, r *http.Request) {
|
||||
handler := func(_ http.ResponseWriter, _ *http.Request) {
|
||||
handlerCalled = true
|
||||
}
|
||||
handler2 := optionalAuth(handler)
|
||||
@@ -119,15 +141,15 @@ func TestAuthHTTP(t *testing.T) {
|
||||
w.hdr = make(http.Header)
|
||||
r := http.Request{}
|
||||
r.Header = make(http.Header)
|
||||
r.Method = "GET"
|
||||
r.Method = http.MethodGet
|
||||
|
||||
// get / - we're redirected to login page
|
||||
r.URL = &url.URL{Path: "/"}
|
||||
handlerCalled = false
|
||||
handler2(&w, &r)
|
||||
assert.True(t, w.statusCode == http.StatusFound)
|
||||
assert.True(t, w.hdr.Get("Location") != "")
|
||||
assert.True(t, !handlerCalled)
|
||||
assert.Equal(t, http.StatusFound, w.statusCode)
|
||||
assert.NotEmpty(t, w.hdr.Get("Location"))
|
||||
assert.False(t, handlerCalled)
|
||||
|
||||
// go to login page
|
||||
loginURL := w.hdr.Get("Location")
|
||||
@@ -139,7 +161,7 @@ func TestAuthHTTP(t *testing.T) {
|
||||
// perform login
|
||||
cookie, err := Context.auth.httpCookie(loginJSON{Name: "name", Password: "password"})
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, cookie != "")
|
||||
assert.NotEmpty(t, cookie)
|
||||
|
||||
// get /
|
||||
handler2 = optionalAuth(handler)
|
||||
@@ -168,8 +190,8 @@ func TestAuthHTTP(t *testing.T) {
|
||||
r.URL = &url.URL{Path: loginURL}
|
||||
handlerCalled = false
|
||||
handler2(&w, &r)
|
||||
assert.True(t, w.hdr.Get("Location") != "")
|
||||
assert.True(t, !handlerCalled)
|
||||
assert.NotEmpty(t, w.hdr.Get("Location"))
|
||||
assert.False(t, handlerCalled)
|
||||
r.Header.Del("Cookie")
|
||||
|
||||
// get login page with an invalid cookie
|
||||
|
||||
@@ -36,7 +36,7 @@ func TestAuthGL(t *testing.T) {
|
||||
binary.BigEndian.PutUint32(data, tval)
|
||||
}
|
||||
assert.Nil(t, ioutil.WriteFile(glFilePrefix+"test", data, 0o644))
|
||||
r, _ := http.NewRequest("GET", "http://localhost/", nil)
|
||||
r, _ := http.NewRequest(http.MethodGet, "http://localhost/", nil)
|
||||
r.AddCookie(&http.Cookie{Name: glCookieName, Value: "test"})
|
||||
assert.True(t, glProcessCookie(r))
|
||||
GLMode = false
|
||||
|
||||
@@ -11,23 +11,21 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsfilter"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/util"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
clientsUpdatePeriod = 10 * time.Minute
|
||||
)
|
||||
const clientsUpdatePeriod = 10 * time.Minute
|
||||
|
||||
var webHandlersRegistered = false
|
||||
|
||||
// Client information
|
||||
// Client contains information about persistent clients.
|
||||
type Client struct {
|
||||
IDs []string
|
||||
Tags []string
|
||||
@@ -52,14 +50,13 @@ type Client struct {
|
||||
|
||||
type clientSource uint
|
||||
|
||||
// Client sources
|
||||
// Client sources. The order determines the priority.
|
||||
const (
|
||||
// Priority: etc/hosts > DHCP > ARP > rDNS > WHOIS
|
||||
ClientSourceWHOIS clientSource = iota // from WHOIS
|
||||
ClientSourceRDNS // from rDNS
|
||||
ClientSourceDHCP // from DHCP
|
||||
ClientSourceARP // from 'arp -a'
|
||||
ClientSourceHostsFile // from /etc/hosts
|
||||
ClientSourceWHOIS clientSource = iota
|
||||
ClientSourceRDNS
|
||||
ClientSourceDHCP
|
||||
ClientSourceARP
|
||||
ClientSourceHostsFile
|
||||
)
|
||||
|
||||
// ClientHost information
|
||||
@@ -70,8 +67,10 @@ type ClientHost struct {
|
||||
}
|
||||
|
||||
type clientsContainer struct {
|
||||
// TODO(a.garipov): Perhaps use a number of separate indices for
|
||||
// different types (string, net.IP, and so on).
|
||||
list map[string]*Client // name -> client
|
||||
idIndex map[string]*Client // IP -> client
|
||||
idIndex map[string]*Client // ID -> client
|
||||
ipHost map[string]*ClientHost // IP -> Hostname
|
||||
lock sync.Mutex
|
||||
|
||||
@@ -156,7 +155,7 @@ func (clients *clientsContainer) tagKnown(tag string) bool {
|
||||
|
||||
func (clients *clientsContainer) addFromConfig(objects []clientObject) {
|
||||
for _, cy := range objects {
|
||||
cli := Client{
|
||||
cli := &Client{
|
||||
Name: cy.Name,
|
||||
IDs: cy.IDs,
|
||||
UseOwnSettings: !cy.UseGlobalSettings,
|
||||
@@ -172,7 +171,7 @@ func (clients *clientsContainer) addFromConfig(objects []clientObject) {
|
||||
|
||||
for _, s := range cy.BlockedServices {
|
||||
if !dnsfilter.BlockedSvcKnown(s) {
|
||||
log.Debug("Clients: skipping unknown blocked-service %q", s)
|
||||
log.Debug("clients: skipping unknown blocked-service %q", s)
|
||||
continue
|
||||
}
|
||||
cli.BlockedServices = append(cli.BlockedServices, s)
|
||||
@@ -180,7 +179,7 @@ func (clients *clientsContainer) addFromConfig(objects []clientObject) {
|
||||
|
||||
for _, t := range cy.Tags {
|
||||
if !clients.tagKnown(t) {
|
||||
log.Debug("Clients: skipping unknown tag %q", t)
|
||||
log.Debug("clients: skipping unknown tag %q", t)
|
||||
continue
|
||||
}
|
||||
cli.Tags = append(cli.Tags, t)
|
||||
@@ -208,10 +207,10 @@ func (clients *clientsContainer) WriteDiskConfig(objects *[]clientObject) {
|
||||
UseGlobalBlockedServices: !cli.UseOwnBlockedServices,
|
||||
}
|
||||
|
||||
cy.Tags = stringArrayDup(cli.Tags)
|
||||
cy.IDs = stringArrayDup(cli.IDs)
|
||||
cy.BlockedServices = stringArrayDup(cli.BlockedServices)
|
||||
cy.Upstreams = stringArrayDup(cli.Upstreams)
|
||||
cy.Tags = copyStrings(cli.Tags)
|
||||
cy.IDs = copyStrings(cli.IDs)
|
||||
cy.BlockedServices = copyStrings(cli.BlockedServices)
|
||||
cy.Upstreams = copyStrings(cli.Upstreams)
|
||||
|
||||
*objects = append(*objects, cy)
|
||||
}
|
||||
@@ -238,45 +237,44 @@ func (clients *clientsContainer) onHostsChanged() {
|
||||
clients.addFromHostsFile()
|
||||
}
|
||||
|
||||
// Exists checks if client with this IP already exists
|
||||
func (clients *clientsContainer) Exists(ip string, source clientSource) bool {
|
||||
// Exists checks if client with this ID already exists.
|
||||
func (clients *clientsContainer) Exists(id string, source clientSource) (ok bool) {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
_, ok := clients.findByIP(ip)
|
||||
_, ok = clients.findLocked(id)
|
||||
if ok {
|
||||
return true
|
||||
}
|
||||
|
||||
ch, ok := clients.ipHost[ip]
|
||||
var ch *ClientHost
|
||||
ch, ok = clients.ipHost[id]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if source > ch.Source {
|
||||
return false // we're going to overwrite this client's info with a stronger source
|
||||
}
|
||||
return true
|
||||
|
||||
// Return false if the new source has higher priority.
|
||||
return source <= ch.Source
|
||||
}
|
||||
|
||||
func stringArrayDup(a []string) []string {
|
||||
a2 := make([]string, len(a))
|
||||
copy(a2, a)
|
||||
return a2
|
||||
func copyStrings(a []string) (b []string) {
|
||||
return append(b, a...)
|
||||
}
|
||||
|
||||
// Find searches for a client by IP
|
||||
func (clients *clientsContainer) Find(ip string) (Client, bool) {
|
||||
// Find searches for a client by its ID.
|
||||
func (clients *clientsContainer) Find(id string) (c *Client, ok bool) {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
c, ok := clients.findByIP(ip)
|
||||
c, ok = clients.findLocked(id)
|
||||
if !ok {
|
||||
return Client{}, false
|
||||
return nil, false
|
||||
}
|
||||
c.IDs = stringArrayDup(c.IDs)
|
||||
c.Tags = stringArrayDup(c.Tags)
|
||||
c.BlockedServices = stringArrayDup(c.BlockedServices)
|
||||
c.Upstreams = stringArrayDup(c.Upstreams)
|
||||
|
||||
c.IDs = copyStrings(c.IDs)
|
||||
c.Tags = copyStrings(c.Tags)
|
||||
c.BlockedServices = copyStrings(c.BlockedServices)
|
||||
c.Upstreams = copyStrings(c.Upstreams)
|
||||
return c, true
|
||||
}
|
||||
|
||||
@@ -287,7 +285,7 @@ func (clients *clientsContainer) FindUpstreams(ip string) *proxy.UpstreamConfig
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
c, ok := clients.findByIP(ip)
|
||||
c, ok := clients.findLocked(ip)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
@@ -306,16 +304,16 @@ func (clients *clientsContainer) FindUpstreams(ip string) *proxy.UpstreamConfig
|
||||
return c.upstreamConfig
|
||||
}
|
||||
|
||||
// Find searches for a client by IP (and does not lock anything)
|
||||
func (clients *clientsContainer) findByIP(ip string) (Client, bool) {
|
||||
ipAddr := net.ParseIP(ip)
|
||||
if ipAddr == nil {
|
||||
return Client{}, false
|
||||
// findLocked searches for a client by its ID. For internal use only.
|
||||
func (clients *clientsContainer) findLocked(id string) (c *Client, ok bool) {
|
||||
c, ok = clients.idIndex[id]
|
||||
if ok {
|
||||
return c, true
|
||||
}
|
||||
|
||||
c, ok := clients.idIndex[ip]
|
||||
if ok {
|
||||
return *c, true
|
||||
ip := net.ParseIP(id)
|
||||
if ip == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
for _, c = range clients.list {
|
||||
@@ -324,32 +322,36 @@ func (clients *clientsContainer) findByIP(ip string) (Client, bool) {
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if ipnet.Contains(ipAddr) {
|
||||
return *c, true
|
||||
|
||||
if ipnet.Contains(ip) {
|
||||
return c, true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if clients.dhcpServer == nil {
|
||||
return Client{}, false
|
||||
return nil, false
|
||||
}
|
||||
macFound := clients.dhcpServer.FindMACbyIP(ipAddr)
|
||||
|
||||
macFound := clients.dhcpServer.FindMACbyIP(ip)
|
||||
if macFound == nil {
|
||||
return Client{}, false
|
||||
return nil, false
|
||||
}
|
||||
|
||||
for _, c = range clients.list {
|
||||
for _, id := range c.IDs {
|
||||
hwAddr, err := net.ParseMAC(id)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if bytes.Equal(hwAddr, macFound) {
|
||||
return *c, true
|
||||
return c, true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return Client{}, false
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// FindAutoClient - search for an auto-client by IP
|
||||
@@ -369,44 +371,47 @@ func (clients *clientsContainer) FindAutoClient(ip string) (ClientHost, bool) {
|
||||
return ClientHost{}, false
|
||||
}
|
||||
|
||||
// Check if Client object's fields are correct
|
||||
func (clients *clientsContainer) check(c *Client) error {
|
||||
if len(c.Name) == 0 {
|
||||
return fmt.Errorf("invalid Name")
|
||||
}
|
||||
|
||||
if len(c.IDs) == 0 {
|
||||
return fmt.Errorf("id required")
|
||||
// check validates the client.
|
||||
func (clients *clientsContainer) check(c *Client) (err error) {
|
||||
switch {
|
||||
case c == nil:
|
||||
return agherr.Error("client is nil")
|
||||
case c.Name == "":
|
||||
return agherr.Error("invalid name")
|
||||
case len(c.IDs) == 0:
|
||||
return agherr.Error("id required")
|
||||
default:
|
||||
// Go on.
|
||||
}
|
||||
|
||||
for i, id := range c.IDs {
|
||||
ip := net.ParseIP(id)
|
||||
if ip != nil {
|
||||
c.IDs[i] = ip.String() // normalize IP address
|
||||
continue
|
||||
// Normalize structured data.
|
||||
var ip net.IP
|
||||
var ipnet *net.IPNet
|
||||
var mac net.HardwareAddr
|
||||
if ip = net.ParseIP(id); ip != nil {
|
||||
c.IDs[i] = ip.String()
|
||||
} else if ip, ipnet, err = net.ParseCIDR(id); err == nil {
|
||||
ipnet.IP = ip
|
||||
c.IDs[i] = ipnet.String()
|
||||
} else if mac, err = net.ParseMAC(id); err == nil {
|
||||
c.IDs[i] = mac.String()
|
||||
} else if err = dnsforward.ValidateClientID(id); err == nil {
|
||||
c.IDs[i] = id
|
||||
} else {
|
||||
return fmt.Errorf("invalid client id at index %d: %q", i, id)
|
||||
}
|
||||
|
||||
_, _, err := net.ParseCIDR(id)
|
||||
if err == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
_, err = net.ParseMAC(id)
|
||||
if err == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
return fmt.Errorf("invalid ID: %s", id)
|
||||
}
|
||||
|
||||
for _, t := range c.Tags {
|
||||
if !clients.tagKnown(t) {
|
||||
return fmt.Errorf("invalid tag: %s", t)
|
||||
return fmt.Errorf("invalid tag: %q", t)
|
||||
}
|
||||
}
|
||||
|
||||
sort.Strings(c.Tags)
|
||||
|
||||
err := dnsforward.ValidateUpstreams(c.Upstreams)
|
||||
err = dnsforward.ValidateUpstreams(c.Upstreams)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid upstream servers: %w", err)
|
||||
}
|
||||
@@ -414,49 +419,52 @@ func (clients *clientsContainer) check(c *Client) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Add a new client object
|
||||
// Return true: success; false: client exists.
|
||||
func (clients *clientsContainer) Add(c Client) (bool, error) {
|
||||
e := clients.check(&c)
|
||||
if e != nil {
|
||||
return false, e
|
||||
// Add adds a new client object. ok is false if such client already exists or
|
||||
// if an error occurred.
|
||||
func (clients *clientsContainer) Add(c *Client) (ok bool, err error) {
|
||||
err = clients.check(c)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
// check Name index
|
||||
_, ok := clients.list[c.Name]
|
||||
_, ok = clients.list[c.Name]
|
||||
if ok {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// check ID index
|
||||
for _, id := range c.IDs {
|
||||
c2, ok := clients.idIndex[id]
|
||||
var c2 *Client
|
||||
c2, ok = clients.idIndex[id]
|
||||
if ok {
|
||||
return false, fmt.Errorf("another client uses the same ID (%s): %s", id, c2.Name)
|
||||
return false, fmt.Errorf("another client uses the same ID (%q): %q", id, c2.Name)
|
||||
}
|
||||
}
|
||||
|
||||
// update Name index
|
||||
clients.list[c.Name] = &c
|
||||
clients.list[c.Name] = c
|
||||
|
||||
// update ID index
|
||||
for _, id := range c.IDs {
|
||||
clients.idIndex[id] = &c
|
||||
clients.idIndex[id] = c
|
||||
}
|
||||
|
||||
log.Debug("Clients: added %q: ID:%v [%d]", c.Name, c.IDs, len(clients.list))
|
||||
log.Debug("clients: added %q: ID:%q [%d]", c.Name, c.IDs, len(clients.list))
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Del removes a client
|
||||
func (clients *clientsContainer) Del(name string) bool {
|
||||
// Del removes a client. ok is false if there is no such client.
|
||||
func (clients *clientsContainer) Del(name string) (ok bool) {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
c, ok := clients.list[name]
|
||||
var c *Client
|
||||
c, ok = clients.list[name]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
@@ -468,25 +476,28 @@ func (clients *clientsContainer) Del(name string) bool {
|
||||
for _, id := range c.IDs {
|
||||
delete(clients.idIndex, id)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// Return TRUE if arrays are equal
|
||||
func arraysEqual(a, b []string) bool {
|
||||
// equalStringSlices returns true if the slices are equal.
|
||||
func equalStringSlices(a, b []string) (ok bool) {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := 0; i != len(a); i++ {
|
||||
|
||||
for i := range a {
|
||||
if a[i] != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// Update a client
|
||||
func (clients *clientsContainer) Update(name string, c Client) error {
|
||||
err := clients.check(&c)
|
||||
// Update updates a client by its name.
|
||||
func (clients *clientsContainer) Update(name string, c *Client) (err error) {
|
||||
err = clients.check(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -494,65 +505,69 @@ func (clients *clientsContainer) Update(name string, c Client) error {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
old, ok := clients.list[name]
|
||||
prev, ok := clients.list[name]
|
||||
if !ok {
|
||||
return fmt.Errorf("client not found")
|
||||
return agherr.Error("client not found")
|
||||
}
|
||||
|
||||
// check Name index
|
||||
if old.Name != c.Name {
|
||||
if prev.Name != c.Name {
|
||||
_, ok = clients.list[c.Name]
|
||||
if ok {
|
||||
return fmt.Errorf("client already exists")
|
||||
return agherr.Error("client already exists")
|
||||
}
|
||||
}
|
||||
|
||||
// check IP index
|
||||
if !arraysEqual(old.IDs, c.IDs) {
|
||||
if !equalStringSlices(prev.IDs, c.IDs) {
|
||||
for _, id := range c.IDs {
|
||||
c2, ok := clients.idIndex[id]
|
||||
if ok && c2 != old {
|
||||
return fmt.Errorf("another client uses the same ID (%s): %s", id, c2.Name)
|
||||
if ok && c2 != prev {
|
||||
return fmt.Errorf("another client uses the same ID (%q): %q", id, c2.Name)
|
||||
}
|
||||
}
|
||||
|
||||
// update ID index
|
||||
for _, id := range old.IDs {
|
||||
for _, id := range prev.IDs {
|
||||
delete(clients.idIndex, id)
|
||||
}
|
||||
for _, id := range c.IDs {
|
||||
clients.idIndex[id] = old
|
||||
clients.idIndex[id] = prev
|
||||
}
|
||||
}
|
||||
|
||||
// update Name index
|
||||
if old.Name != c.Name {
|
||||
delete(clients.list, old.Name)
|
||||
clients.list[c.Name] = old
|
||||
if prev.Name != c.Name {
|
||||
delete(clients.list, prev.Name)
|
||||
clients.list[c.Name] = prev
|
||||
}
|
||||
|
||||
// update upstreams cache
|
||||
c.upstreamConfig = nil
|
||||
|
||||
*old = c
|
||||
*prev = *c
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetWhoisInfo - associate WHOIS information with a client
|
||||
// SetWhoisInfo sets the WHOIS information for a client.
|
||||
//
|
||||
// TODO(a.garipov): Perhaps replace [][]string with map[string]string.
|
||||
func (clients *clientsContainer) SetWhoisInfo(ip string, info [][]string) {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
_, ok := clients.findByIP(ip)
|
||||
_, ok := clients.findLocked(ip)
|
||||
if ok {
|
||||
log.Debug("Clients: client for %s is already created, ignore WHOIS info", ip)
|
||||
log.Debug("clients: client for %s is already created, ignore whois info", ip)
|
||||
return
|
||||
}
|
||||
|
||||
ch, ok := clients.ipHost[ip]
|
||||
if ok {
|
||||
ch.WhoisInfo = info
|
||||
log.Debug("Clients: set WHOIS info for auto-client %s: %v", ch.Host, ch.WhoisInfo)
|
||||
log.Debug("clients: set whois info for auto-client %s: %q", ch.Host, info)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -562,31 +577,33 @@ func (clients *clientsContainer) SetWhoisInfo(ip string, info [][]string) {
|
||||
}
|
||||
ch.WhoisInfo = info
|
||||
clients.ipHost[ip] = ch
|
||||
log.Debug("Clients: set WHOIS info for auto-client with IP %s: %v", ip, ch.WhoisInfo)
|
||||
log.Debug("clients: set whois info for auto-client with IP %s: %q", ip, info)
|
||||
}
|
||||
|
||||
// AddHost adds new IP -> Host pair
|
||||
// Use priority of the source (etc/hosts > ARP > rDNS)
|
||||
// so we overwrite existing entries with an equal or higher priority
|
||||
func (clients *clientsContainer) AddHost(ip, host string, source clientSource) (bool, error) {
|
||||
// AddHost adds a new IP-hostname pairing. The priorities of the sources is
|
||||
// taken into account. ok is true if the pairing was added.
|
||||
func (clients *clientsContainer) AddHost(ip, host string, src clientSource) (ok bool, err error) {
|
||||
clients.lock.Lock()
|
||||
b := clients.addHost(ip, host, source)
|
||||
ok = clients.addHostLocked(ip, host, src)
|
||||
clients.lock.Unlock()
|
||||
return b, nil
|
||||
|
||||
return ok, nil
|
||||
}
|
||||
|
||||
func (clients *clientsContainer) addHost(ip, host string, source clientSource) (addedNew bool) {
|
||||
ch, ok := clients.ipHost[ip]
|
||||
// addHostLocked adds a new IP-hostname pairing. For internal use only.
|
||||
func (clients *clientsContainer) addHostLocked(ip, host string, src clientSource) (ok bool) {
|
||||
var ch *ClientHost
|
||||
ch, ok = clients.ipHost[ip]
|
||||
if ok {
|
||||
if ch.Source > source {
|
||||
if ch.Source > src {
|
||||
return false
|
||||
}
|
||||
|
||||
ch.Source = source
|
||||
ch.Source = src
|
||||
} else {
|
||||
ch = &ClientHost{
|
||||
Host: host,
|
||||
Source: source,
|
||||
Source: src,
|
||||
}
|
||||
|
||||
clients.ipHost[ip] = ch
|
||||
@@ -597,11 +614,11 @@ func (clients *clientsContainer) addHost(ip, host string, source clientSource) (
|
||||
return true
|
||||
}
|
||||
|
||||
// Remove all entries that match the specified source
|
||||
func (clients *clientsContainer) rmHosts(source clientSource) {
|
||||
// rmHostsBySrc removes all entries that match the specified source.
|
||||
func (clients *clientsContainer) rmHostsBySrc(src clientSource) {
|
||||
n := 0
|
||||
for k, v := range clients.ipHost {
|
||||
if v.Source == source {
|
||||
if v.Source == src {
|
||||
delete(clients.ipHost, k)
|
||||
n++
|
||||
}
|
||||
@@ -610,19 +627,20 @@ func (clients *clientsContainer) rmHosts(source clientSource) {
|
||||
log.Debug("clients: removed %d client aliases", n)
|
||||
}
|
||||
|
||||
// addFromHostsFile fills the clients hosts list from the system's hosts files.
|
||||
// addFromHostsFile fills the client-hostname pairing index from the system's
|
||||
// hosts files.
|
||||
func (clients *clientsContainer) addFromHostsFile() {
|
||||
hosts := clients.autoHosts.List()
|
||||
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
clients.rmHosts(ClientSourceHostsFile)
|
||||
clients.rmHostsBySrc(ClientSourceHostsFile)
|
||||
|
||||
n := 0
|
||||
for ip, names := range hosts {
|
||||
for _, name := range names {
|
||||
ok := clients.addHost(ip, name, ClientSourceHostsFile)
|
||||
ok := clients.addHostLocked(ip, name, ClientSourceHostsFile)
|
||||
if ok {
|
||||
n++
|
||||
}
|
||||
@@ -632,31 +650,31 @@ func (clients *clientsContainer) addFromHostsFile() {
|
||||
log.Debug("Clients: added %d client aliases from system hosts-file", n)
|
||||
}
|
||||
|
||||
// Add IP -> Host pairs from the system's `arp -a` command output
|
||||
// The command's output is:
|
||||
// HOST (IP) at MAC on IFACE
|
||||
// addFromSystemARP adds the IP-hostname pairings from the output of the arp -a
|
||||
// command.
|
||||
func (clients *clientsContainer) addFromSystemARP() {
|
||||
if runtime.GOOS == "windows" {
|
||||
return
|
||||
}
|
||||
|
||||
cmd := exec.Command("arp", "-a")
|
||||
log.Tracef("executing %s %v", cmd.Path, cmd.Args)
|
||||
log.Tracef("executing %q %q", cmd.Path, cmd.Args)
|
||||
data, err := cmd.Output()
|
||||
if err != nil || cmd.ProcessState.ExitCode() != 0 {
|
||||
log.Debug("command %s has failed: %v code:%d",
|
||||
log.Debug("command %q has failed: %q code:%d",
|
||||
cmd.Path, err, cmd.ProcessState.ExitCode())
|
||||
return
|
||||
}
|
||||
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
clients.rmHosts(ClientSourceARP)
|
||||
|
||||
clients.rmHostsBySrc(ClientSourceARP)
|
||||
|
||||
n := 0
|
||||
// TODO(a.garipov): Rewrite to use bufio.Scanner.
|
||||
lines := strings.Split(string(data), "\n")
|
||||
for _, ln := range lines {
|
||||
|
||||
open := strings.Index(ln, " (")
|
||||
close := strings.Index(ln, ") ")
|
||||
if open == -1 || close == -1 || open >= close {
|
||||
@@ -669,16 +687,17 @@ func (clients *clientsContainer) addFromSystemARP() {
|
||||
continue
|
||||
}
|
||||
|
||||
ok := clients.addHost(ip, host, ClientSourceARP)
|
||||
ok := clients.addHostLocked(ip, host, ClientSourceARP)
|
||||
if ok {
|
||||
n++
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug("Clients: added %d client aliases from 'arp -a' command output", n)
|
||||
log.Debug("clients: added %d client aliases from 'arp -a' command output", n)
|
||||
}
|
||||
|
||||
// Add clients from DHCP that have non-empty Hostname property
|
||||
// addFromDHCP adds the clients that have a non-empty hostname from the DHCP
|
||||
// server.
|
||||
func (clients *clientsContainer) addFromDHCP() {
|
||||
if clients.dhcpServer == nil {
|
||||
return
|
||||
@@ -687,18 +706,20 @@ func (clients *clientsContainer) addFromDHCP() {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
clients.rmHosts(ClientSourceDHCP)
|
||||
clients.rmHostsBySrc(ClientSourceDHCP)
|
||||
|
||||
leases := clients.dhcpServer.Leases(dhcpd.LeasesAll)
|
||||
n := 0
|
||||
for _, l := range leases {
|
||||
if len(l.Hostname) == 0 {
|
||||
if l.Hostname == "" {
|
||||
continue
|
||||
}
|
||||
ok := clients.addHost(l.IP.String(), l.Hostname, ClientSourceDHCP)
|
||||
|
||||
ok := clients.addHostLocked(l.IP.String(), l.Hostname, ClientSourceDHCP)
|
||||
if ok {
|
||||
n++
|
||||
}
|
||||
}
|
||||
log.Debug("Clients: added %d client aliases from DHCP", n)
|
||||
|
||||
log.Debug("clients: added %d client aliases from dhcp", n)
|
||||
}
|
||||
|
||||
@@ -18,32 +18,35 @@ func TestClients(t *testing.T) {
|
||||
clients.Init(nil, nil, nil)
|
||||
|
||||
t.Run("add_success", func(t *testing.T) {
|
||||
c := Client{
|
||||
c := &Client{
|
||||
IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa"},
|
||||
Name: "client1",
|
||||
}
|
||||
|
||||
b, err := clients.Add(c)
|
||||
assert.True(t, b)
|
||||
ok, err := clients.Add(c)
|
||||
assert.True(t, ok)
|
||||
assert.Nil(t, err)
|
||||
|
||||
c = Client{
|
||||
c = &Client{
|
||||
IDs: []string{"2.2.2.2"},
|
||||
Name: "client2",
|
||||
}
|
||||
|
||||
b, err = clients.Add(c)
|
||||
assert.True(t, b)
|
||||
ok, err = clients.Add(c)
|
||||
assert.True(t, ok)
|
||||
assert.Nil(t, err)
|
||||
|
||||
c, b = clients.Find("1.1.1.1")
|
||||
assert.True(t, b && c.Name == "client1")
|
||||
c, ok = clients.Find("1.1.1.1")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "client1", c.Name)
|
||||
|
||||
c, b = clients.Find("1:2:3::4")
|
||||
assert.True(t, b && c.Name == "client1")
|
||||
c, ok = clients.Find("1:2:3::4")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "client1", c.Name)
|
||||
|
||||
c, b = clients.Find("2.2.2.2")
|
||||
assert.True(t, b && c.Name == "client2")
|
||||
c, ok = clients.Find("2.2.2.2")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "client2", c.Name)
|
||||
|
||||
assert.True(t, !clients.Exists("1.2.3.4", ClientSourceHostsFile))
|
||||
assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile))
|
||||
@@ -51,29 +54,29 @@ func TestClients(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("add_fail_name", func(t *testing.T) {
|
||||
c := Client{
|
||||
c := &Client{
|
||||
IDs: []string{"1.2.3.5"},
|
||||
Name: "client1",
|
||||
}
|
||||
|
||||
b, err := clients.Add(c)
|
||||
assert.False(t, b)
|
||||
ok, err := clients.Add(c)
|
||||
assert.False(t, ok)
|
||||
assert.Nil(t, err)
|
||||
})
|
||||
|
||||
t.Run("add_fail_ip", func(t *testing.T) {
|
||||
c := Client{
|
||||
c := &Client{
|
||||
IDs: []string{"2.2.2.2"},
|
||||
Name: "client3",
|
||||
}
|
||||
|
||||
b, err := clients.Add(c)
|
||||
assert.False(t, b)
|
||||
ok, err := clients.Add(c)
|
||||
assert.False(t, ok)
|
||||
assert.NotNil(t, err)
|
||||
})
|
||||
|
||||
t.Run("update_fail_name", func(t *testing.T) {
|
||||
c := Client{
|
||||
c := &Client{
|
||||
IDs: []string{"1.2.3.0"},
|
||||
Name: "client3",
|
||||
}
|
||||
@@ -81,7 +84,7 @@ func TestClients(t *testing.T) {
|
||||
err := clients.Update("client3", c)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
c = Client{
|
||||
c = &Client{
|
||||
IDs: []string{"1.2.3.0"},
|
||||
Name: "client2",
|
||||
}
|
||||
@@ -91,7 +94,7 @@ func TestClients(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("update_fail_ip", func(t *testing.T) {
|
||||
c := Client{
|
||||
c := &Client{
|
||||
IDs: []string{"2.2.2.2"},
|
||||
Name: "client1",
|
||||
}
|
||||
@@ -101,7 +104,7 @@ func TestClients(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("update_success", func(t *testing.T) {
|
||||
c := Client{
|
||||
c := &Client{
|
||||
IDs: []string{"1.1.1.2"},
|
||||
Name: "client1",
|
||||
}
|
||||
@@ -112,7 +115,7 @@ func TestClients(t *testing.T) {
|
||||
assert.True(t, !clients.Exists("1.1.1.1", ClientSourceHostsFile))
|
||||
assert.True(t, clients.Exists("1.1.1.2", ClientSourceHostsFile))
|
||||
|
||||
c = Client{
|
||||
c = &Client{
|
||||
IDs: []string{"1.1.1.2"},
|
||||
Name: "client1-renamed",
|
||||
UseOwnSettings: true,
|
||||
@@ -121,50 +124,52 @@ func TestClients(t *testing.T) {
|
||||
err = clients.Update("client1", c)
|
||||
assert.Nil(t, err)
|
||||
|
||||
c, b := clients.Find("1.1.1.2")
|
||||
assert.True(t, b)
|
||||
assert.True(t, c.Name == "client1-renamed")
|
||||
assert.True(t, c.IDs[0] == "1.1.1.2")
|
||||
c, ok := clients.Find("1.1.1.2")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "client1-renamed", c.Name)
|
||||
assert.True(t, c.UseOwnSettings)
|
||||
assert.Nil(t, clients.list["client1"])
|
||||
if assert.Len(t, c.IDs, 1) {
|
||||
assert.Equal(t, "1.1.1.2", c.IDs[0])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("del_success", func(t *testing.T) {
|
||||
b := clients.Del("client1-renamed")
|
||||
assert.True(t, b)
|
||||
ok := clients.Del("client1-renamed")
|
||||
assert.True(t, ok)
|
||||
assert.False(t, clients.Exists("1.1.1.2", ClientSourceHostsFile))
|
||||
})
|
||||
|
||||
t.Run("del_fail", func(t *testing.T) {
|
||||
b := clients.Del("client3")
|
||||
assert.False(t, b)
|
||||
ok := clients.Del("client3")
|
||||
assert.False(t, ok)
|
||||
})
|
||||
|
||||
t.Run("addhost_success", func(t *testing.T) {
|
||||
b, err := clients.AddHost("1.1.1.1", "host", ClientSourceARP)
|
||||
assert.True(t, b)
|
||||
ok, err := clients.AddHost("1.1.1.1", "host", ClientSourceARP)
|
||||
assert.True(t, ok)
|
||||
assert.Nil(t, err)
|
||||
|
||||
b, err = clients.AddHost("1.1.1.1", "host2", ClientSourceARP)
|
||||
assert.True(t, b)
|
||||
ok, err = clients.AddHost("1.1.1.1", "host2", ClientSourceARP)
|
||||
assert.True(t, ok)
|
||||
assert.Nil(t, err)
|
||||
|
||||
b, err = clients.AddHost("1.1.1.1", "host3", ClientSourceHostsFile)
|
||||
assert.True(t, b)
|
||||
ok, err = clients.AddHost("1.1.1.1", "host3", ClientSourceHostsFile)
|
||||
assert.True(t, ok)
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile))
|
||||
})
|
||||
|
||||
t.Run("addhost_fail", func(t *testing.T) {
|
||||
b, err := clients.AddHost("1.1.1.1", "host1", ClientSourceRDNS)
|
||||
assert.False(t, b)
|
||||
ok, err := clients.AddHost("1.1.1.1", "host1", ClientSourceRDNS)
|
||||
assert.False(t, ok)
|
||||
assert.Nil(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestClientsWhois(t *testing.T) {
|
||||
var c Client
|
||||
var c *Client
|
||||
clients := clientsContainer{}
|
||||
clients.testing = true
|
||||
clients.Init(nil, nil, nil)
|
||||
@@ -172,26 +177,36 @@ func TestClientsWhois(t *testing.T) {
|
||||
whois := [][]string{{"orgname", "orgname-val"}, {"country", "country-val"}}
|
||||
// set whois info on new client
|
||||
clients.SetWhoisInfo("1.1.1.255", whois)
|
||||
assert.True(t, clients.ipHost["1.1.1.255"].WhoisInfo[0][1] == "orgname-val")
|
||||
if assert.NotNil(t, clients.ipHost["1.1.1.255"]) {
|
||||
h := clients.ipHost["1.1.1.255"]
|
||||
if assert.Len(t, h.WhoisInfo, 2) && assert.Len(t, h.WhoisInfo[0], 2) {
|
||||
assert.Equal(t, "orgname-val", h.WhoisInfo[0][1])
|
||||
}
|
||||
}
|
||||
|
||||
// set whois info on existing auto-client
|
||||
_, _ = clients.AddHost("1.1.1.1", "host", ClientSourceRDNS)
|
||||
clients.SetWhoisInfo("1.1.1.1", whois)
|
||||
assert.True(t, clients.ipHost["1.1.1.1"].WhoisInfo[0][1] == "orgname-val")
|
||||
if assert.NotNil(t, clients.ipHost["1.1.1.1"]) {
|
||||
h := clients.ipHost["1.1.1.1"]
|
||||
if assert.Len(t, h.WhoisInfo, 2) && assert.Len(t, h.WhoisInfo[0], 2) {
|
||||
assert.Equal(t, "orgname-val", h.WhoisInfo[0][1])
|
||||
}
|
||||
}
|
||||
|
||||
// Check that we cannot set whois info on a manually-added client
|
||||
c = Client{
|
||||
c = &Client{
|
||||
IDs: []string{"1.1.1.2"},
|
||||
Name: "client1",
|
||||
}
|
||||
_, _ = clients.Add(c)
|
||||
clients.SetWhoisInfo("1.1.1.2", whois)
|
||||
assert.True(t, clients.ipHost["1.1.1.2"] == nil)
|
||||
assert.Nil(t, clients.ipHost["1.1.1.2"])
|
||||
_ = clients.Del("client1")
|
||||
}
|
||||
|
||||
func TestClientsAddExisting(t *testing.T) {
|
||||
var c Client
|
||||
var c *Client
|
||||
clients := clientsContainer{}
|
||||
clients.testing = true
|
||||
clients.Init(nil, nil, nil)
|
||||
@@ -201,7 +216,7 @@ func TestClientsAddExisting(t *testing.T) {
|
||||
testIP := "1.2.3.4"
|
||||
|
||||
// add a client
|
||||
c = Client{
|
||||
c = &Client{
|
||||
IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa", "2.2.2.0/24"},
|
||||
Name: "client1",
|
||||
}
|
||||
@@ -230,7 +245,7 @@ func TestClientsAddExisting(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
|
||||
// add a new client with the same IP as for a client with MAC
|
||||
c = Client{
|
||||
c = &Client{
|
||||
IDs: []string{testIP},
|
||||
Name: "client2",
|
||||
}
|
||||
@@ -239,7 +254,7 @@ func TestClientsAddExisting(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
|
||||
// add a new client with the IP from the client1's IP range
|
||||
c = Client{
|
||||
c = &Client{
|
||||
IDs: []string{"2.2.2.2"},
|
||||
Name: "client3",
|
||||
}
|
||||
@@ -255,7 +270,7 @@ func TestClientsCustomUpstream(t *testing.T) {
|
||||
clients.Init(nil, nil, nil)
|
||||
|
||||
// add client with upstreams
|
||||
client := Client{
|
||||
c := &Client{
|
||||
IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa"},
|
||||
Name: "client1",
|
||||
Upstreams: []string{
|
||||
@@ -263,7 +278,7 @@ func TestClientsCustomUpstream(t *testing.T) {
|
||||
"[/example.org/]8.8.8.8",
|
||||
},
|
||||
}
|
||||
ok, err := clients.Add(client)
|
||||
ok, err := clients.Add(c)
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package home
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
@@ -21,7 +22,7 @@ type clientJSON struct {
|
||||
|
||||
Upstreams []string `json:"upstreams"`
|
||||
|
||||
WhoisInfo map[string]interface{} `json:"whois_info"`
|
||||
WhoisInfo map[string]string `json:"whois_info"`
|
||||
|
||||
// Disallowed - if true -- client's IP is not disallowed
|
||||
// Otherwise, it is blocked.
|
||||
@@ -38,7 +39,7 @@ type clientHostJSON struct {
|
||||
Name string `json:"name"`
|
||||
Source string `json:"source"`
|
||||
|
||||
WhoisInfo map[string]interface{} `json:"whois_info"`
|
||||
WhoisInfo map[string]string `json:"whois_info"`
|
||||
}
|
||||
|
||||
type clientListJSON struct {
|
||||
@@ -74,7 +75,7 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, _ *http
|
||||
cj.Source = "WHOIS"
|
||||
}
|
||||
|
||||
cj.WhoisInfo = make(map[string]interface{})
|
||||
cj.WhoisInfo = map[string]string{}
|
||||
for _, wi := range ch.WhoisInfo {
|
||||
cj.WhoisInfo[wi[0]] = wi[1]
|
||||
}
|
||||
@@ -139,7 +140,7 @@ func clientHostToJSON(ip string, ch ClientHost) clientJSON {
|
||||
IDs: []string{ip},
|
||||
}
|
||||
|
||||
cj.WhoisInfo = make(map[string]interface{})
|
||||
cj.WhoisInfo = map[string]string{}
|
||||
for _, wi := range ch.WhoisInfo {
|
||||
cj.WhoisInfo[wi[0]] = wi[1]
|
||||
}
|
||||
@@ -157,7 +158,7 @@ func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http.
|
||||
}
|
||||
|
||||
c := jsonToClient(cj)
|
||||
ok, err := clients.Add(*c)
|
||||
ok, err := clients.Add(c)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "%s", err)
|
||||
return
|
||||
@@ -215,7 +216,7 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht
|
||||
}
|
||||
|
||||
c := jsonToClient(dj.Data)
|
||||
err = clients.Update(dj.Name, *c)
|
||||
err = clients.Update(dj.Name, c)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "%s", err)
|
||||
return
|
||||
@@ -227,51 +228,78 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht
|
||||
// Get the list of clients by IP address list
|
||||
func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http.Request) {
|
||||
q := r.URL.Query()
|
||||
data := []map[string]interface{}{}
|
||||
for i := 0; ; i++ {
|
||||
ip := q.Get(fmt.Sprintf("ip%d", i))
|
||||
if len(ip) == 0 {
|
||||
data := []map[string]clientJSON{}
|
||||
for i := 0; i < len(q); i++ {
|
||||
idStr := q.Get(fmt.Sprintf("ip%d", i))
|
||||
if idStr == "" {
|
||||
break
|
||||
}
|
||||
el := map[string]interface{}{}
|
||||
c, ok := clients.Find(ip)
|
||||
|
||||
ip := net.ParseIP(idStr)
|
||||
c, ok := clients.Find(idStr)
|
||||
var cj clientJSON
|
||||
if !ok {
|
||||
ch, ok := clients.FindAutoClient(ip)
|
||||
if !ok {
|
||||
continue // a client with this IP isn't found
|
||||
var found bool
|
||||
cj, found = clients.findTemporary(ip, idStr)
|
||||
if !found {
|
||||
continue
|
||||
}
|
||||
cj := clientHostToJSON(ip, ch)
|
||||
|
||||
cj.Disallowed, cj.DisallowedRule = clients.dnsServer.IsBlockedIP(ip)
|
||||
el[ip] = cj
|
||||
} else {
|
||||
cj := clientToJSON(&c)
|
||||
|
||||
cj = clientToJSON(c)
|
||||
cj.Disallowed, cj.DisallowedRule = clients.dnsServer.IsBlockedIP(ip)
|
||||
el[ip] = cj
|
||||
}
|
||||
|
||||
data = append(data, el)
|
||||
}
|
||||
|
||||
js, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "json.Marshal: %s", err)
|
||||
return
|
||||
data = append(data, map[string]clientJSON{
|
||||
idStr: cj,
|
||||
})
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, err = w.Write(js)
|
||||
err := json.NewEncoder(w).Encode(data)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "Couldn't write response: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// findTemporary looks up the IP in temporary storages, like autohosts or
|
||||
// blocklists.
|
||||
func (clients *clientsContainer) findTemporary(ip net.IP, idStr string) (cj clientJSON, found bool) {
|
||||
if ip == nil {
|
||||
return cj, false
|
||||
}
|
||||
|
||||
ch, ok := clients.FindAutoClient(idStr)
|
||||
if !ok {
|
||||
// It is still possible that the IP used to be in the runtime
|
||||
// clients list, but then the server was reloaded. So, check
|
||||
// the DNS server's blocked IP list.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/2428.
|
||||
disallowed, rule := clients.dnsServer.IsBlockedIP(ip)
|
||||
if rule == "" {
|
||||
return clientJSON{}, false
|
||||
}
|
||||
|
||||
cj = clientJSON{
|
||||
IDs: []string{idStr},
|
||||
Disallowed: disallowed,
|
||||
DisallowedRule: rule,
|
||||
}
|
||||
|
||||
return cj, true
|
||||
}
|
||||
|
||||
cj = clientHostToJSON(idStr, ch)
|
||||
cj.Disallowed, cj.DisallowedRule = clients.dnsServer.IsBlockedIP(ip)
|
||||
|
||||
return cj, true
|
||||
}
|
||||
|
||||
// RegisterClientsHandlers registers HTTP handlers
|
||||
func (clients *clientsContainer) registerWebHandlers() {
|
||||
httpRegister("GET", "/control/clients", clients.handleGetClients)
|
||||
httpRegister("POST", "/control/clients/add", clients.handleAddClient)
|
||||
httpRegister("POST", "/control/clients/delete", clients.handleDelClient)
|
||||
httpRegister("POST", "/control/clients/update", clients.handleUpdateClient)
|
||||
httpRegister("GET", "/control/clients/find", clients.handleFindClient)
|
||||
httpRegister(http.MethodGet, "/control/clients", clients.handleGetClients)
|
||||
httpRegister(http.MethodPost, "/control/clients/add", clients.handleAddClient)
|
||||
httpRegister(http.MethodPost, "/control/clients/delete", clients.handleDelClient)
|
||||
httpRegister(http.MethodPost, "/control/clients/update", clients.handleUpdateClient)
|
||||
httpRegister(http.MethodGet, "/control/clients/find", clients.handleFindClient)
|
||||
}
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
@@ -11,6 +13,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
||||
"github.com/AdguardTeam/golibs/file"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
yaml "gopkg.in/yaml.v2"
|
||||
@@ -39,13 +42,14 @@ type configuration struct {
|
||||
// It's reset after config is parsed
|
||||
fileData []byte
|
||||
|
||||
BindHost string `yaml:"bind_host"` // BindHost is the IP address of the HTTP server to bind to
|
||||
BindPort int `yaml:"bind_port"` // BindPort is the port the HTTP server
|
||||
Users []User `yaml:"users"` // Users that can access HTTP server
|
||||
ProxyURL string `yaml:"http_proxy"` // Proxy address for our HTTP client
|
||||
Language string `yaml:"language"` // two-letter ISO 639-1 language code
|
||||
RlimitNoFile uint `yaml:"rlimit_nofile"` // Maximum number of opened fd's per process (0: default)
|
||||
DebugPProf bool `yaml:"debug_pprof"` // Enable pprof HTTP server on port 6060
|
||||
BindHost net.IP `yaml:"bind_host"` // BindHost is the IP address of the HTTP server to bind to
|
||||
BindPort int `yaml:"bind_port"` // BindPort is the port the HTTP server
|
||||
BetaBindPort int `yaml:"beta_bind_port"` // BetaBindPort is the port for new client
|
||||
Users []User `yaml:"users"` // Users that can access HTTP server
|
||||
ProxyURL string `yaml:"http_proxy"` // Proxy address for our HTTP client
|
||||
Language string `yaml:"language"` // two-letter ISO 639-1 language code
|
||||
RlimitNoFile uint `yaml:"rlimit_nofile"` // Maximum number of opened fd's per process (0: default)
|
||||
DebugPProf bool `yaml:"debug_pprof"` // Enable pprof HTTP server on port 6060
|
||||
|
||||
// TTL for a web session (in hours)
|
||||
// An active session is automatically refreshed once a day.
|
||||
@@ -72,7 +76,7 @@ type configuration struct {
|
||||
|
||||
// field ordering is important -- yaml fields will mirror ordering from here
|
||||
type dnsConfig struct {
|
||||
BindHost string `yaml:"bind_host"`
|
||||
BindHost net.IP `yaml:"bind_host"`
|
||||
Port int `yaml:"port"`
|
||||
|
||||
// time interval for statistics (in days)
|
||||
@@ -117,10 +121,11 @@ type tlsConfigSettings struct {
|
||||
|
||||
// initialize to default values, will be changed later when reading config or parsing command line
|
||||
var config = configuration{
|
||||
BindPort: 3000,
|
||||
BindHost: "0.0.0.0",
|
||||
BindPort: 3000,
|
||||
BetaBindPort: 0,
|
||||
BindHost: net.IP{0, 0, 0, 0},
|
||||
DNS: dnsConfig{
|
||||
BindHost: "0.0.0.0",
|
||||
BindHost: net.IP{0, 0, 0, 0},
|
||||
Port: 53,
|
||||
StatsInterval: 1,
|
||||
FilteringConfig: dnsforward.FilteringConfig{
|
||||
@@ -174,13 +179,17 @@ func initConfig() {
|
||||
config.DHCP.Conf4.LeaseDuration = 86400
|
||||
config.DHCP.Conf4.ICMPTimeout = 1000
|
||||
config.DHCP.Conf6.LeaseDuration = 86400
|
||||
|
||||
if ch := version.Channel(); ch == version.ChannelEdge || ch == version.ChannelDevelopment {
|
||||
config.BetaBindPort = 3001
|
||||
}
|
||||
}
|
||||
|
||||
// getConfigFilename returns path to the current config file
|
||||
func (c *configuration) getConfigFilename() string {
|
||||
configFile, err := filepath.EvalSymlinks(Context.configFilename)
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
log.Error("unexpected error while config file path evaluation: %s", err)
|
||||
}
|
||||
configFile = Context.configFilename
|
||||
|
||||
@@ -2,6 +2,7 @@ package home
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
@@ -11,6 +12,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/NYTimes/gziphandler"
|
||||
)
|
||||
@@ -35,48 +37,52 @@ func httpError(w http.ResponseWriter, code int, format string, args ...interface
|
||||
// ---------------
|
||||
// dns run control
|
||||
// ---------------
|
||||
func addDNSAddress(dnsAddresses *[]string, addr string) {
|
||||
func addDNSAddress(dnsAddresses *[]string, addr net.IP) {
|
||||
hostport := addr.String()
|
||||
if config.DNS.Port != 53 {
|
||||
addr = fmt.Sprintf("%s:%d", addr, config.DNS.Port)
|
||||
hostport = net.JoinHostPort(hostport, strconv.Itoa(config.DNS.Port))
|
||||
}
|
||||
*dnsAddresses = append(*dnsAddresses, addr)
|
||||
*dnsAddresses = append(*dnsAddresses, hostport)
|
||||
}
|
||||
|
||||
func handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||
c := dnsforward.FilteringConfig{}
|
||||
// statusResponse is a response for /control/status endpoint.
|
||||
type statusResponse struct {
|
||||
DNSAddrs []string `json:"dns_addresses"`
|
||||
DNSPort int `json:"dns_port"`
|
||||
HTTPPort int `json:"http_port"`
|
||||
IsProtectionEnabled bool `json:"protection_enabled"`
|
||||
// TODO(e.burkov): Inspect if front-end doesn't requires this field as
|
||||
// openapi.yaml declares.
|
||||
IsDHCPAvailable bool `json:"dhcp_available"`
|
||||
IsRunning bool `json:"running"`
|
||||
Version string `json:"version"`
|
||||
Language string `json:"language"`
|
||||
}
|
||||
|
||||
func handleStatus(w http.ResponseWriter, _ *http.Request) {
|
||||
resp := statusResponse{
|
||||
DNSAddrs: getDNSAddresses(),
|
||||
DNSPort: config.DNS.Port,
|
||||
HTTPPort: config.BindPort,
|
||||
IsRunning: isRunning(),
|
||||
Version: version.Version(),
|
||||
Language: config.Language,
|
||||
}
|
||||
|
||||
var c *dnsforward.FilteringConfig
|
||||
if Context.dnsServer != nil {
|
||||
Context.dnsServer.WriteDiskConfig(&c)
|
||||
c = &dnsforward.FilteringConfig{}
|
||||
Context.dnsServer.WriteDiskConfig(c)
|
||||
resp.IsProtectionEnabled = c.ProtectionEnabled
|
||||
}
|
||||
|
||||
data := map[string]interface{}{
|
||||
"dns_addresses": getDNSAddresses(),
|
||||
"http_port": config.BindPort,
|
||||
"dns_port": config.DNS.Port,
|
||||
"running": isRunning(),
|
||||
"version": versionString,
|
||||
"language": config.Language,
|
||||
|
||||
"protection_enabled": c.ProtectionEnabled,
|
||||
// IsDHCPAvailable field is now false by default for Windows.
|
||||
if runtime.GOOS != "windows" {
|
||||
resp.IsDHCPAvailable = Context.dhcpServer != nil
|
||||
}
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
// Set the DHCP to false explicitly, because Context.dhcpServer
|
||||
// is probably not nil, despite the fact that there is no
|
||||
// support for DHCP on Windows in AdGuardHome.
|
||||
//
|
||||
// See also the TODO in dhcpd.Create.
|
||||
data["dhcp_available"] = false
|
||||
} else {
|
||||
data["dhcp_available"] = (Context.dhcpServer != nil)
|
||||
}
|
||||
|
||||
jsonVal, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "Unable to marshal status json: %s", err)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, err = w.Write(jsonVal)
|
||||
err := json.NewEncoder(w).Encode(resp)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "Unable to write response json: %s", err)
|
||||
return
|
||||
@@ -89,7 +95,7 @@ type profileJSON struct {
|
||||
|
||||
func handleGetProfile(w http.ResponseWriter, r *http.Request) {
|
||||
pj := profileJSON{}
|
||||
u := Context.auth.GetCurrentUser(r)
|
||||
u := Context.auth.getCurrentUser(r)
|
||||
pj.Name = u.Name
|
||||
|
||||
data, err := json.Marshal(pj)
|
||||
@@ -118,7 +124,7 @@ func registerControlHandlers() {
|
||||
}
|
||||
|
||||
func httpRegister(method, url string, handler func(http.ResponseWriter, *http.Request)) {
|
||||
if len(method) == 0 {
|
||||
if method == "" {
|
||||
// "/dns-query" handler doesn't need auth, gzip and isn't restricted by 1 HTTP method
|
||||
Context.mux.HandleFunc(url, postInstall(handler))
|
||||
return
|
||||
@@ -139,7 +145,7 @@ func ensure(method string, handler func(http.ResponseWriter, *http.Request)) fun
|
||||
return
|
||||
}
|
||||
|
||||
if method == "POST" || method == "PUT" || method == "DELETE" {
|
||||
if method == http.MethodPost || method == http.MethodPut || method == http.MethodDelete {
|
||||
Context.controlLock.Lock()
|
||||
defer Context.controlLock.Unlock()
|
||||
}
|
||||
@@ -149,11 +155,11 @@ func ensure(method string, handler func(http.ResponseWriter, *http.Request)) fun
|
||||
}
|
||||
|
||||
func ensurePOST(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
|
||||
return ensure("POST", handler)
|
||||
return ensure(http.MethodPost, handler)
|
||||
}
|
||||
|
||||
func ensureGET(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
|
||||
return ensure("GET", handler)
|
||||
return ensure(http.MethodGet, handler)
|
||||
}
|
||||
|
||||
// Bridge between http.Handler object and Go function
|
||||
@@ -197,37 +203,84 @@ func preInstallHandler(handler http.Handler) http.Handler {
|
||||
return &preInstallHandlerStruct{handler}
|
||||
}
|
||||
|
||||
// postInstall lets the handler run only if firstRun is false, and redirects to /install.html otherwise
|
||||
// it also enforces HTTPS if it is enabled and configured
|
||||
const defaultHTTPSPort = 443
|
||||
|
||||
// handleHTTPSRedirect redirects the request to HTTPS, if needed. If ok is
|
||||
// true, the middleware must continue handling the request.
|
||||
func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (ok bool) {
|
||||
web := Context.web
|
||||
if web.httpsServer.server == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
host, _, err := net.SplitHostPort(r.Host)
|
||||
if err != nil {
|
||||
// Check for the missing port error. If it is that error, just
|
||||
// use the host as is.
|
||||
//
|
||||
// See the source code for net.SplitHostPort.
|
||||
const missingPort = "missing port in address"
|
||||
|
||||
addrErr := &net.AddrError{}
|
||||
if !errors.As(err, &addrErr) || addrErr.Err != missingPort {
|
||||
httpError(w, http.StatusBadRequest, "bad host: %s", err)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
host = r.Host
|
||||
}
|
||||
|
||||
if r.TLS == nil && web.forceHTTPS {
|
||||
hostPort := host
|
||||
if port := web.conf.PortHTTPS; port != defaultHTTPSPort {
|
||||
portStr := strconv.Itoa(port)
|
||||
hostPort = net.JoinHostPort(host, portStr)
|
||||
}
|
||||
|
||||
httpsURL := &url.URL{
|
||||
Scheme: "https",
|
||||
Host: hostPort,
|
||||
Path: r.URL.Path,
|
||||
RawQuery: r.URL.RawQuery,
|
||||
}
|
||||
http.Redirect(w, r, httpsURL.String(), http.StatusTemporaryRedirect)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Allow the frontend from the HTTP origin to send requests to the HTTPS
|
||||
// server. This can happen when the user has just set up HTTPS with
|
||||
// redirects. Prevent cache-related errors by setting the Vary header.
|
||||
//
|
||||
// See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin.
|
||||
originURL := &url.URL{
|
||||
Scheme: "http",
|
||||
Host: r.Host,
|
||||
}
|
||||
w.Header().Set("Access-Control-Allow-Origin", originURL.String())
|
||||
w.Header().Set("Vary", "Origin")
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// postInstall lets the handler to run only if firstRun is false. Otherwise, it
|
||||
// redirects to /install.html. It also enforces HTTPS if it is enabled and
|
||||
// configured and sets appropriate access control headers.
|
||||
func postInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if Context.firstRun &&
|
||||
!strings.HasPrefix(r.URL.Path, "/install.") &&
|
||||
!strings.HasPrefix(r.URL.Path, "/assets/") {
|
||||
path := r.URL.Path
|
||||
if Context.firstRun && !strings.HasPrefix(path, "/install.") &&
|
||||
!strings.HasPrefix(path, "/assets/") {
|
||||
http.Redirect(w, r, "/install.html", http.StatusFound)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// enforce https?
|
||||
if r.TLS == nil && Context.web.forceHTTPS && Context.web.httpsServer.server != nil {
|
||||
// yes, and we want host from host:port
|
||||
host, _, err := net.SplitHostPort(r.Host)
|
||||
if err != nil {
|
||||
// no port in host
|
||||
host = r.Host
|
||||
}
|
||||
// construct new URL to redirect to
|
||||
newURL := url.URL{
|
||||
Scheme: "https",
|
||||
Host: net.JoinHostPort(host, strconv.Itoa(Context.web.portHTTPS)),
|
||||
Path: r.URL.Path,
|
||||
RawQuery: r.URL.RawQuery,
|
||||
}
|
||||
http.Redirect(w, r, newURL.String(), http.StatusTemporaryRedirect)
|
||||
if !handleHTTPSRedirect(w, r) {
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
handler(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -68,8 +68,8 @@ kXS9jgARhhiWXJrk
|
||||
data.KeyType == "RSA" &&
|
||||
data.Subject == "CN=AdGuard Home,O=AdGuard Ltd" &&
|
||||
data.Issuer == "CN=AdGuard Home,O=AdGuard Ltd" &&
|
||||
data.NotBefore == notBefore &&
|
||||
data.NotAfter == notAfter &&
|
||||
data.NotBefore.Equal(notBefore) &&
|
||||
data.NotAfter.Equal(notAfter) &&
|
||||
// data.DNSNames[0] == &&
|
||||
data.ValidPair) {
|
||||
t.Fatalf("valid cert & priv key: validateCertificates(): %v", data)
|
||||
|
||||
@@ -369,7 +369,7 @@ type checkHostResp struct {
|
||||
// for FilteredBlockedService:
|
||||
SvcName string `json:"service_name"`
|
||||
|
||||
// for ReasonRewrite:
|
||||
// for Rewrite:
|
||||
CanonName string `json:"cname"` // CNAME value
|
||||
IPList []net.IP `json:"ip_addrs"` // list of IP addresses
|
||||
}
|
||||
@@ -417,14 +417,14 @@ func (f *Filtering) handleCheckHost(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// RegisterFilteringHandlers - register handlers
|
||||
func (f *Filtering) RegisterFilteringHandlers() {
|
||||
httpRegister("GET", "/control/filtering/status", f.handleFilteringStatus)
|
||||
httpRegister("POST", "/control/filtering/config", f.handleFilteringConfig)
|
||||
httpRegister("POST", "/control/filtering/add_url", f.handleFilteringAddURL)
|
||||
httpRegister("POST", "/control/filtering/remove_url", f.handleFilteringRemoveURL)
|
||||
httpRegister("POST", "/control/filtering/set_url", f.handleFilteringSetURL)
|
||||
httpRegister("POST", "/control/filtering/refresh", f.handleFilteringRefresh)
|
||||
httpRegister("POST", "/control/filtering/set_rules", f.handleFilteringSetRules)
|
||||
httpRegister("GET", "/control/filtering/check_host", f.handleCheckHost)
|
||||
httpRegister(http.MethodGet, "/control/filtering/status", f.handleFilteringStatus)
|
||||
httpRegister(http.MethodPost, "/control/filtering/config", f.handleFilteringConfig)
|
||||
httpRegister(http.MethodPost, "/control/filtering/add_url", f.handleFilteringAddURL)
|
||||
httpRegister(http.MethodPost, "/control/filtering/remove_url", f.handleFilteringRemoveURL)
|
||||
httpRegister(http.MethodPost, "/control/filtering/set_url", f.handleFilteringSetURL)
|
||||
httpRegister(http.MethodPost, "/control/filtering/refresh", f.handleFilteringRefresh)
|
||||
httpRegister(http.MethodPost, "/control/filtering/set_rules", f.handleFilteringSetRules)
|
||||
httpRegister(http.MethodGet, "/control/filtering/check_host", f.handleCheckHost)
|
||||
}
|
||||
|
||||
func checkFiltersUpdateIntervalHours(i uint32) bool {
|
||||
|
||||
@@ -12,6 +12,8 @@ import (
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/util"
|
||||
|
||||
@@ -20,23 +22,16 @@ import (
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
type firstRunData struct {
|
||||
WebPort int `json:"web_port"`
|
||||
DNSPort int `json:"dns_port"`
|
||||
Interfaces map[string]interface{} `json:"interfaces"`
|
||||
// getAddrsResponse is the response for /install/get_addresses endpoint.
|
||||
type getAddrsResponse struct {
|
||||
WebPort int `json:"web_port"`
|
||||
DNSPort int `json:"dns_port"`
|
||||
Interfaces map[string]*util.NetInterface `json:"interfaces"`
|
||||
}
|
||||
|
||||
type netInterfaceJSON struct {
|
||||
Name string `json:"name"`
|
||||
MTU int `json:"mtu"`
|
||||
HardwareAddr string `json:"hardware_address"`
|
||||
Addresses []string `json:"ip_addresses"`
|
||||
Flags string `json:"flags"`
|
||||
}
|
||||
|
||||
// Get initial installation settings
|
||||
// handleInstallGetAddresses is the handler for /install/get_addresses endpoint.
|
||||
func (web *Web) handleInstallGetAddresses(w http.ResponseWriter, r *http.Request) {
|
||||
data := firstRunData{}
|
||||
data := getAddrsResponse{}
|
||||
data.WebPort = 80
|
||||
data.DNSPort = 53
|
||||
|
||||
@@ -46,16 +41,9 @@ func (web *Web) handleInstallGetAddresses(w http.ResponseWriter, r *http.Request
|
||||
return
|
||||
}
|
||||
|
||||
data.Interfaces = make(map[string]interface{})
|
||||
data.Interfaces = make(map[string]*util.NetInterface)
|
||||
for _, iface := range ifaces {
|
||||
ifaceJSON := netInterfaceJSON{
|
||||
Name: iface.Name,
|
||||
MTU: iface.MTU,
|
||||
HardwareAddr: iface.HardwareAddr,
|
||||
Addresses: iface.Addresses,
|
||||
Flags: iface.Flags,
|
||||
}
|
||||
data.Interfaces[iface.Name] = ifaceJSON
|
||||
data.Interfaces[iface.Name] = iface
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
@@ -68,7 +56,7 @@ func (web *Web) handleInstallGetAddresses(w http.ResponseWriter, r *http.Request
|
||||
|
||||
type checkConfigReqEnt struct {
|
||||
Port int `json:"port"`
|
||||
IP string `json:"ip"`
|
||||
IP net.IP `json:"ip"`
|
||||
Autofix bool `json:"autofix"`
|
||||
}
|
||||
|
||||
@@ -105,10 +93,10 @@ func (web *Web) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request)
|
||||
return
|
||||
}
|
||||
|
||||
if reqData.Web.Port != 0 && reqData.Web.Port != config.BindPort {
|
||||
if reqData.Web.Port != 0 && reqData.Web.Port != config.BindPort && reqData.Web.Port != config.BetaBindPort {
|
||||
err = util.CheckPortAvailable(reqData.Web.IP, reqData.Web.Port)
|
||||
if err != nil {
|
||||
respData.Web.Status = fmt.Sprintf("%v", err)
|
||||
respData.Web.Status = err.Error()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -136,8 +124,8 @@ func (web *Web) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
respData.DNS.Status = fmt.Sprintf("%v", err)
|
||||
} else if reqData.DNS.IP != "0.0.0.0" {
|
||||
respData.DNS.Status = err.Error()
|
||||
} else if !reqData.DNS.IP.IsUnspecified() {
|
||||
respData.StaticIP = handleStaticIP(reqData.DNS.IP, reqData.SetStaticIP)
|
||||
}
|
||||
}
|
||||
@@ -153,7 +141,7 @@ func (web *Web) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request)
|
||||
// handleStaticIP - handles static IP request
|
||||
// It either checks if we have a static IP
|
||||
// Or if set=true, it tries to set it
|
||||
func handleStaticIP(ip string, set bool) staticIPJSON {
|
||||
func handleStaticIP(ip net.IP, set bool) staticIPJSON {
|
||||
resp := staticIPJSON{}
|
||||
|
||||
interfaceName := util.GetInterfaceByIP(ip)
|
||||
@@ -185,7 +173,7 @@ func handleStaticIP(ip string, set bool) staticIPJSON {
|
||||
if isStaticIP {
|
||||
resp.Static = "yes"
|
||||
}
|
||||
resp.IP = util.GetSubnet(interfaceName)
|
||||
resp.IP = util.GetSubnet(interfaceName).String()
|
||||
}
|
||||
return resp
|
||||
}
|
||||
@@ -261,7 +249,7 @@ func disableDNSStubListener() error {
|
||||
}
|
||||
|
||||
type applyConfigReqEnt struct {
|
||||
IP string `json:"ip"`
|
||||
IP net.IP `json:"ip"`
|
||||
Port int `json:"port"`
|
||||
}
|
||||
|
||||
@@ -276,10 +264,14 @@ type applyConfigReq struct {
|
||||
func copyInstallSettings(dst, src *configuration) {
|
||||
dst.BindHost = src.BindHost
|
||||
dst.BindPort = src.BindPort
|
||||
dst.BetaBindPort = src.BetaBindPort
|
||||
dst.DNS.BindHost = src.DNS.BindHost
|
||||
dst.DNS.Port = src.DNS.Port
|
||||
}
|
||||
|
||||
// shutdownTimeout is the timeout for shutting HTTP server down operation.
|
||||
const shutdownTimeout = 5 * time.Second
|
||||
|
||||
// Apply new configuration, start DNS server, restart Web server
|
||||
func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
|
||||
newSettings := applyConfigReq{}
|
||||
@@ -295,7 +287,7 @@ func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
restartHTTP := true
|
||||
if config.BindHost == newSettings.Web.IP && config.BindPort == newSettings.Web.Port {
|
||||
if config.BindHost.Equal(newSettings.Web.IP) && config.BindPort == newSettings.Web.Port {
|
||||
// no need to rebind
|
||||
restartHTTP = false
|
||||
}
|
||||
@@ -305,9 +297,10 @@ func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
|
||||
err = util.CheckPortAvailable(newSettings.Web.IP, newSettings.Web.Port)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "Impossible to listen on IP:port %s due to %s",
|
||||
net.JoinHostPort(newSettings.Web.IP, strconv.Itoa(newSettings.Web.Port)), err)
|
||||
net.JoinHostPort(newSettings.Web.IP.String(), strconv.Itoa(newSettings.Web.Port)), err)
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
err = util.CheckPacketPortAvailable(newSettings.DNS.IP, newSettings.DNS.Port)
|
||||
@@ -331,6 +324,10 @@ func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
|
||||
config.DNS.BindHost = newSettings.DNS.IP
|
||||
config.DNS.Port = newSettings.DNS.Port
|
||||
|
||||
// TODO(e.burkov): StartMods() should be put in a separate goroutine at
|
||||
// the moment we'll allow setting up TLS in the initial configuration or
|
||||
// the configuration itself will use HTTPS protocol, because the
|
||||
// underlying functions potentially restart the HTTPS server.
|
||||
err = StartMods()
|
||||
if err != nil {
|
||||
Context.firstRun = true
|
||||
@@ -362,12 +359,23 @@ func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
|
||||
f.Flush()
|
||||
}
|
||||
|
||||
// this needs to be done in a goroutine because Shutdown() is a blocking call, and it will block
|
||||
// until all requests are finished, and _we_ are inside a request right now, so it will block indefinitely
|
||||
// The Shutdown() method of (*http.Server) needs to be called in a
|
||||
// separate goroutine, because it waits until all requests are handled
|
||||
// and will be blocked by it's own caller.
|
||||
if restartHTTP {
|
||||
go func() {
|
||||
_ = Context.web.httpServer.Shutdown(context.TODO())
|
||||
}()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
|
||||
|
||||
shut := func(srv *http.Server) {
|
||||
defer cancel()
|
||||
err := srv.Shutdown(ctx)
|
||||
if err != nil {
|
||||
log.Debug("error while shutting down HTTP server: %s", err)
|
||||
}
|
||||
}
|
||||
go shut(web.httpServer)
|
||||
if web.httpServerBeta != nil {
|
||||
go shut(web.httpServerBeta)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -376,3 +384,186 @@ func (web *Web) registerInstallHandlers() {
|
||||
Context.mux.HandleFunc("/control/install/check_config", preInstall(ensurePOST(web.handleInstallCheckConfig)))
|
||||
Context.mux.HandleFunc("/control/install/configure", preInstall(ensurePOST(web.handleInstallConfigure)))
|
||||
}
|
||||
|
||||
// checkConfigReqEntBeta is a struct representing new client's config check
|
||||
// request entry. It supports multiple IP values unlike the checkConfigReqEnt.
|
||||
//
|
||||
// TODO(e.burkov): This should removed with the API v1 when the appropriate
|
||||
// functionality will appear in default checkConfigReqEnt.
|
||||
type checkConfigReqEntBeta struct {
|
||||
Port int `json:"port"`
|
||||
IP []net.IP `json:"ip"`
|
||||
Autofix bool `json:"autofix"`
|
||||
}
|
||||
|
||||
// checkConfigReqBeta is a struct representing new client's config check request
|
||||
// body. It uses checkConfigReqEntBeta instead of checkConfigReqEnt.
|
||||
//
|
||||
// TODO(e.burkov): This should removed with the API v1 when the appropriate
|
||||
// functionality will appear in default checkConfigReq.
|
||||
type checkConfigReqBeta struct {
|
||||
Web checkConfigReqEntBeta `json:"web"`
|
||||
DNS checkConfigReqEntBeta `json:"dns"`
|
||||
SetStaticIP bool `json:"set_static_ip"`
|
||||
}
|
||||
|
||||
// handleInstallCheckConfigBeta is a substitution of /install/check_config
|
||||
// handler for new client.
|
||||
//
|
||||
// TODO(e.burkov): This should removed with the API v1 when the appropriate
|
||||
// functionality will appear in default handleInstallCheckConfig.
|
||||
func (web *Web) handleInstallCheckConfigBeta(w http.ResponseWriter, r *http.Request) {
|
||||
reqData := checkConfigReqBeta{}
|
||||
err := json.NewDecoder(r.Body).Decode(&reqData)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "Failed to parse 'check_config' JSON data: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(reqData.DNS.IP) == 0 || len(reqData.Web.IP) == 0 {
|
||||
httpError(w, http.StatusBadRequest, http.StatusText(http.StatusBadRequest))
|
||||
return
|
||||
}
|
||||
|
||||
nonBetaReqData := checkConfigReq{
|
||||
Web: checkConfigReqEnt{
|
||||
Port: reqData.Web.Port,
|
||||
IP: reqData.Web.IP[0],
|
||||
Autofix: reqData.Web.Autofix,
|
||||
},
|
||||
DNS: checkConfigReqEnt{
|
||||
Port: reqData.DNS.Port,
|
||||
IP: reqData.DNS.IP[0],
|
||||
Autofix: reqData.DNS.Autofix,
|
||||
},
|
||||
SetStaticIP: reqData.SetStaticIP,
|
||||
}
|
||||
|
||||
nonBetaReqBody := &strings.Builder{}
|
||||
|
||||
err = json.NewEncoder(nonBetaReqBody).Encode(nonBetaReqData)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "Failed to encode 'check_config' JSON data: %s", err)
|
||||
return
|
||||
}
|
||||
body := nonBetaReqBody.String()
|
||||
r.Body = ioutil.NopCloser(strings.NewReader(body))
|
||||
r.ContentLength = int64(len(body))
|
||||
|
||||
web.handleInstallCheckConfig(w, r)
|
||||
}
|
||||
|
||||
// applyConfigReqEntBeta is a struct representing new client's config setting
|
||||
// request entry. It supports multiple IP values unlike the applyConfigReqEnt.
|
||||
//
|
||||
// TODO(e.burkov): This should removed with the API v1 when the appropriate
|
||||
// functionality will appear in default applyConfigReqEnt.
|
||||
type applyConfigReqEntBeta struct {
|
||||
IP []net.IP `json:"ip"`
|
||||
Port int `json:"port"`
|
||||
}
|
||||
|
||||
// applyConfigReqBeta is a struct representing new client's config setting
|
||||
// request body. It uses applyConfigReqEntBeta instead of applyConfigReqEnt.
|
||||
//
|
||||
// TODO(e.burkov): This should removed with the API v1 when the appropriate
|
||||
// functionality will appear in default applyConfigReq.
|
||||
type applyConfigReqBeta struct {
|
||||
Web applyConfigReqEntBeta `json:"web"`
|
||||
DNS applyConfigReqEntBeta `json:"dns"`
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
// handleInstallConfigureBeta is a substitution of /install/configure handler
|
||||
// for new client.
|
||||
//
|
||||
// TODO(e.burkov): This should removed with the API v1 when the appropriate
|
||||
// functionality will appear in default handleInstallConfigure.
|
||||
func (web *Web) handleInstallConfigureBeta(w http.ResponseWriter, r *http.Request) {
|
||||
reqData := applyConfigReqBeta{}
|
||||
err := json.NewDecoder(r.Body).Decode(&reqData)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "Failed to parse 'check_config' JSON data: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(reqData.DNS.IP) == 0 || len(reqData.Web.IP) == 0 {
|
||||
httpError(w, http.StatusBadRequest, http.StatusText(http.StatusBadRequest))
|
||||
return
|
||||
}
|
||||
|
||||
nonBetaReqData := applyConfigReq{
|
||||
Web: applyConfigReqEnt{
|
||||
IP: reqData.Web.IP[0],
|
||||
Port: reqData.Web.Port,
|
||||
},
|
||||
DNS: applyConfigReqEnt{
|
||||
IP: reqData.DNS.IP[0],
|
||||
Port: reqData.DNS.Port,
|
||||
},
|
||||
Username: reqData.Username,
|
||||
Password: reqData.Password,
|
||||
}
|
||||
|
||||
nonBetaReqBody := &strings.Builder{}
|
||||
|
||||
err = json.NewEncoder(nonBetaReqBody).Encode(nonBetaReqData)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "Failed to encode 'check_config' JSON data: %s", err)
|
||||
return
|
||||
}
|
||||
body := nonBetaReqBody.String()
|
||||
r.Body = ioutil.NopCloser(strings.NewReader(body))
|
||||
r.ContentLength = int64(len(body))
|
||||
|
||||
web.handleInstallConfigure(w, r)
|
||||
}
|
||||
|
||||
// getAddrsResponseBeta is a struct representing new client's getting addresses
|
||||
// request body. It uses array of structs instead of map.
|
||||
//
|
||||
// TODO(e.burkov): This should removed with the API v1 when the appropriate
|
||||
// functionality will appear in default firstRunData.
|
||||
type getAddrsResponseBeta struct {
|
||||
WebPort int `json:"web_port"`
|
||||
DNSPort int `json:"dns_port"`
|
||||
Interfaces []*util.NetInterface `json:"interfaces"`
|
||||
}
|
||||
|
||||
// handleInstallConfigureBeta is a substitution of /install/get_addresses
|
||||
// handler for new client.
|
||||
//
|
||||
// TODO(e.burkov): This should removed with the API v1 when the appropriate
|
||||
// functionality will appear in default handleInstallGetAddresses.
|
||||
func (web *Web) handleInstallGetAddressesBeta(w http.ResponseWriter, r *http.Request) {
|
||||
data := getAddrsResponseBeta{}
|
||||
data.WebPort = 80
|
||||
data.DNSPort = 53
|
||||
|
||||
ifaces, err := util.GetValidNetInterfacesForWeb()
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "Couldn't get interfaces: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
data.Interfaces = ifaces
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w).Encode(data)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "Unable to marshal default addresses to json: %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// registerBetaInstallHandlers registers the install handlers for new client
|
||||
// with the structures it supports.
|
||||
//
|
||||
// TODO(e.burkov): This should removed with the API v1 when the appropriate
|
||||
// functionality will appear in default handlers.
|
||||
func (web *Web) registerBetaInstallHandlers() {
|
||||
Context.mux.HandleFunc("/control/install/get_addresses_beta", preInstall(ensureGET(web.handleInstallGetAddressesBeta)))
|
||||
Context.mux.HandleFunc("/control/install/check_config_beta", preInstall(ensurePOST(web.handleInstallCheckConfigBeta)))
|
||||
Context.mux.HandleFunc("/control/install/configure_beta", preInstall(ensurePOST(web.handleInstallConfigureBeta)))
|
||||
}
|
||||
|
||||
@@ -1,76 +1,104 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/sysutil"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/update"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/updater"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
type getVersionJSONRequest struct {
|
||||
RecheckNow bool `json:"recheck_now"`
|
||||
// temporaryError is the interface for temporary errors from the Go standard
|
||||
// library.
|
||||
type temporaryError interface {
|
||||
error
|
||||
Temporary() (ok bool)
|
||||
}
|
||||
|
||||
// Get the latest available version from the Internet
|
||||
func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
|
||||
resp := &versionResponse{}
|
||||
if Context.disableUpdate {
|
||||
resp := make(map[string]interface{})
|
||||
resp["disabled"] = true
|
||||
d, _ := json.Marshal(resp)
|
||||
_, _ = w.Write(d)
|
||||
// w.Header().Set("Content-Type", "application/json")
|
||||
resp.Disabled = true
|
||||
_ = json.NewEncoder(w).Encode(resp)
|
||||
// TODO(e.burkov): Add error handling and deal with headers.
|
||||
return
|
||||
}
|
||||
|
||||
req := getVersionJSONRequest{}
|
||||
req := &struct {
|
||||
Recheck bool `json:"recheck_now"`
|
||||
}{}
|
||||
|
||||
var err error
|
||||
if r.ContentLength != 0 {
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
err = json.NewDecoder(r.Body).Decode(req)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "JSON parse: %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
var info update.VersionInfo
|
||||
for i := 0; i != 3; i++ {
|
||||
Context.controlLock.Lock()
|
||||
info, err = Context.updater.GetVersionResponse(req.RecheckNow)
|
||||
Context.controlLock.Unlock()
|
||||
if err != nil && strings.HasSuffix(err.Error(), "i/o timeout") {
|
||||
// This case may happen while we're restarting DNS server
|
||||
// https://github.com/AdguardTeam/AdGuardHome/internal/issues/934
|
||||
continue
|
||||
func() {
|
||||
Context.controlLock.Lock()
|
||||
defer Context.controlLock.Unlock()
|
||||
|
||||
resp.VersionInfo, err = Context.updater.VersionInfo(req.Recheck)
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
var terr temporaryError
|
||||
if errors.As(err, &terr) && terr.Temporary() {
|
||||
// Temporary network error. This case may happen while
|
||||
// we're restarting our DNS server. Log and sleep for
|
||||
// some time.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/934.
|
||||
d := time.Duration(i) * time.Second
|
||||
log.Info("temp net error: %q; sleeping for %s and retrying", err, d)
|
||||
time.Sleep(d)
|
||||
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadGateway, "Couldn't get version check json from %s: %T %s\n", versionCheckURL, err, err)
|
||||
vcu := Context.updater.VersionCheckURL()
|
||||
// TODO(a.garipov): Figure out the purpose of %T verb.
|
||||
httpError(w, http.StatusBadGateway, "Couldn't get version check json from %s: %T %s\n", vcu, err, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
resp.confirmAutoUpdate()
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, err = w.Write(getVersionResp(info))
|
||||
err = json.NewEncoder(w).Encode(resp)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "Couldn't write body: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Perform an update procedure to the latest available version
|
||||
// handleUpdate performs an update to the latest available version procedure.
|
||||
func handleUpdate(w http.ResponseWriter, _ *http.Request) {
|
||||
if len(Context.updater.NewVersion) == 0 {
|
||||
if Context.updater.NewVersion() == "" {
|
||||
httpError(w, http.StatusBadRequest, "/update request isn't allowed now")
|
||||
return
|
||||
}
|
||||
|
||||
err := Context.updater.DoUpdate()
|
||||
err := Context.updater.Update()
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "%s", err)
|
||||
return
|
||||
@@ -81,24 +109,33 @@ func handleUpdate(w http.ResponseWriter, _ *http.Request) {
|
||||
f.Flush()
|
||||
}
|
||||
|
||||
go finishUpdate()
|
||||
// The background context is used because the underlying functions wrap
|
||||
// it with timeout and shut down the server, which handles current
|
||||
// request. It also should be done in a separate goroutine due to the
|
||||
// same reason.
|
||||
go func() {
|
||||
finishUpdate(context.Background())
|
||||
}()
|
||||
}
|
||||
|
||||
// Convert version.json data to our JSON response
|
||||
func getVersionResp(info update.VersionInfo) []byte {
|
||||
ret := make(map[string]interface{})
|
||||
ret["can_autoupdate"] = false
|
||||
ret["new_version"] = info.NewVersion
|
||||
ret["announcement"] = info.Announcement
|
||||
ret["announcement_url"] = info.AnnouncementURL
|
||||
// versionResponse is the response for /control/version.json endpoint.
|
||||
type versionResponse struct {
|
||||
Disabled bool `json:"disabled"`
|
||||
updater.VersionInfo
|
||||
}
|
||||
|
||||
if info.CanAutoUpdate {
|
||||
// confirmAutoUpdate checks the real possibility of auto update.
|
||||
func (vr *versionResponse) confirmAutoUpdate() {
|
||||
if vr.CanAutoUpdate != nil && *vr.CanAutoUpdate {
|
||||
canUpdate := true
|
||||
|
||||
tlsConf := tlsConfigSettings{}
|
||||
Context.tls.WriteDiskConfig(&tlsConf)
|
||||
var tlsConf *tlsConfigSettings
|
||||
if runtime.GOOS != "windows" {
|
||||
tlsConf = &tlsConfigSettings{}
|
||||
Context.tls.WriteDiskConfig(tlsConf)
|
||||
}
|
||||
|
||||
if runtime.GOOS != "windows" &&
|
||||
if tlsConf != nil &&
|
||||
((tlsConf.Enabled && (tlsConf.PortHTTPS < 1024 ||
|
||||
tlsConf.PortDNSOverTLS < 1024 ||
|
||||
tlsConf.PortDNSOverQUIC < 1024)) ||
|
||||
@@ -106,17 +143,14 @@ func getVersionResp(info update.VersionInfo) []byte {
|
||||
config.DNS.Port < 1024) {
|
||||
canUpdate, _ = sysutil.CanBindPrivilegedPorts()
|
||||
}
|
||||
ret["can_autoupdate"] = canUpdate
|
||||
vr.CanAutoUpdate = &canUpdate
|
||||
}
|
||||
|
||||
d, _ := json.Marshal(ret)
|
||||
return d
|
||||
}
|
||||
|
||||
// Complete an update procedure
|
||||
func finishUpdate() {
|
||||
// finishUpdate completes an update procedure.
|
||||
func finishUpdate(ctx context.Context) {
|
||||
log.Info("Stopping all tasks")
|
||||
cleanup()
|
||||
cleanup(ctx)
|
||||
cleanupAlways()
|
||||
|
||||
exeName := "AdGuardHome"
|
||||
|
||||
@@ -1,102 +0,0 @@
|
||||
// +build ignore
|
||||
|
||||
package home
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestDoUpdate(t *testing.T) {
|
||||
config.DNS.Port = 0
|
||||
Context.workDir = "..." // set absolute path
|
||||
newver := "v0.96"
|
||||
|
||||
data := `{
|
||||
"version": "v0.96",
|
||||
"announcement": "AdGuard Home v0.96 is now available!",
|
||||
"announcement_url": "",
|
||||
"download_windows_amd64": "https://static.adguard.com/adguardhome/beta/AdGuardHome_windows_amd64.zip",
|
||||
"download_windows_386": "https://static.adguard.com/adguardhome/beta/AdGuardHome_windows_386.zip",
|
||||
"download_darwin_amd64": "https://static.adguard.com/adguardhome/beta/AdGuardHome_darwin_amd64.zip",
|
||||
"download_darwin_386": "https://static.adguard.com/adguardhome/beta/AdGuardHome_darwin_386.zip",
|
||||
"download_linux_amd64": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_amd64.tar.gz",
|
||||
"download_linux_386": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_386.tar.gz",
|
||||
"download_linux_arm": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_armv6.tar.gz",
|
||||
"download_linux_armv5": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_armv5.tar.gz",
|
||||
"download_linux_armv6": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_armv6.tar.gz",
|
||||
"download_linux_armv7": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_armv7.tar.gz",
|
||||
"download_linux_arm64": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_arm64.tar.gz",
|
||||
"download_linux_mips": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_mips_softfloat.tar.gz",
|
||||
"download_linux_mipsle": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_mipsle_softfloat.tar.gz",
|
||||
"download_linux_mips64": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_mips64_softfloat.tar.gz",
|
||||
"download_linux_mips64le": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_mips64le_softfloat.tar.gz",
|
||||
"download_freebsd_386": "https://static.adguard.com/adguardhome/beta/AdGuardHome_freebsd_386.tar.gz",
|
||||
"download_freebsd_amd64": "https://static.adguard.com/adguardhome/beta/AdGuardHome_freebsd_amd64.tar.gz",
|
||||
"download_freebsd_arm": "https://static.adguard.com/adguardhome/beta/AdGuardHome_freebsd_armv6.tar.gz",
|
||||
"download_freebsd_armv5": "https://static.adguard.com/adguardhome/beta/AdGuardHome_freebsd_armv5.tar.gz",
|
||||
"download_freebsd_armv6": "https://static.adguard.com/adguardhome/beta/AdGuardHome_freebsd_armv6.tar.gz",
|
||||
"download_freebsd_armv7": "https://static.adguard.com/adguardhome/beta/AdGuardHome_freebsd_armv7.tar.gz",
|
||||
"download_freebsd_arm64": "https://static.adguard.com/adguardhome/beta/AdGuardHome_freebsd_arm64.tar.gz"
|
||||
}`
|
||||
uu, err := getUpdateInfo([]byte(data))
|
||||
if err != nil {
|
||||
t.Fatalf("getUpdateInfo: %s", err)
|
||||
}
|
||||
|
||||
u := updateInfo{
|
||||
pkgURL: "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_armv6.tar.gz",
|
||||
pkgName: Context.workDir + "/agh-update-" + newver + "/AdGuardHome_linux_amd64.tar.gz",
|
||||
newVer: newver,
|
||||
updateDir: Context.workDir + "/agh-update-" + newver,
|
||||
backupDir: Context.workDir + "/agh-backup",
|
||||
configName: Context.workDir + "/AdGuardHome.yaml",
|
||||
updateConfigName: Context.workDir + "/agh-update-" + newver + "/AdGuardHome/internal/AdGuardHome.yaml",
|
||||
curBinName: Context.workDir + "/AdGuardHome",
|
||||
bkpBinName: Context.workDir + "/agh-backup/AdGuardHome",
|
||||
newBinName: Context.workDir + "/agh-update-" + newver + "/AdGuardHome/internal/AdGuardHome",
|
||||
}
|
||||
|
||||
assert.Equal(t, uu.pkgURL, u.pkgURL)
|
||||
assert.Equal(t, uu.pkgName, u.pkgName)
|
||||
assert.Equal(t, uu.newVer, u.newVer)
|
||||
assert.Equal(t, uu.updateDir, u.updateDir)
|
||||
assert.Equal(t, uu.backupDir, u.backupDir)
|
||||
assert.Equal(t, uu.configName, u.configName)
|
||||
assert.Equal(t, uu.updateConfigName, u.updateConfigName)
|
||||
assert.Equal(t, uu.curBinName, u.curBinName)
|
||||
assert.Equal(t, uu.bkpBinName, u.bkpBinName)
|
||||
assert.Equal(t, uu.newBinName, u.newBinName)
|
||||
|
||||
e := doUpdate(&u)
|
||||
if e != nil {
|
||||
t.Fatalf("FAILED: %s", e)
|
||||
}
|
||||
os.RemoveAll(u.backupDir)
|
||||
}
|
||||
|
||||
func TestTargzFileUnpack(t *testing.T) {
|
||||
fn := "../dist/AdGuardHome_linux_amd64.tar.gz"
|
||||
outdir := "../test-unpack"
|
||||
defer os.RemoveAll(outdir)
|
||||
_ = os.Mkdir(outdir, 0o755)
|
||||
files, e := targzFileUnpack(fn, outdir)
|
||||
if e != nil {
|
||||
t.Fatalf("FAILED: %s", e)
|
||||
}
|
||||
t.Logf("%v", files)
|
||||
}
|
||||
|
||||
func TestZipFileUnpack(t *testing.T) {
|
||||
fn := "../dist/AdGuardHome_windows_amd64.zip"
|
||||
outdir := "../test-unpack"
|
||||
_ = os.Mkdir(outdir, 0o755)
|
||||
files, e := zipFileUnpack(fn, outdir)
|
||||
if e != nil {
|
||||
t.Fatalf("FAILED: %s", e)
|
||||
}
|
||||
t.Logf("%v", files)
|
||||
os.RemoveAll(outdir)
|
||||
}
|
||||
@@ -3,8 +3,10 @@ package home
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsfilter"
|
||||
@@ -55,10 +57,10 @@ func initDNSServer() error {
|
||||
|
||||
filterConf := config.DNS.DnsfilterConf
|
||||
bindhost := config.DNS.BindHost
|
||||
if config.DNS.BindHost == "0.0.0.0" {
|
||||
bindhost = "127.0.0.1"
|
||||
if config.DNS.BindHost.IsUnspecified() {
|
||||
bindhost = net.IPv4(127, 0, 0, 1)
|
||||
}
|
||||
filterConf.ResolverAddress = fmt.Sprintf("%s:%d", bindhost, config.DNS.Port)
|
||||
filterConf.ResolverAddress = net.JoinHostPort(bindhost.String(), strconv.Itoa(config.DNS.Port))
|
||||
filterConf.AutoHosts = &Context.autoHosts
|
||||
filterConf.ConfigModified = onConfigModified
|
||||
filterConf.HTTPRegister = httpRegister
|
||||
@@ -98,26 +100,24 @@ func isRunning() bool {
|
||||
}
|
||||
|
||||
func onDNSRequest(d *proxy.DNSContext) {
|
||||
ip := dnsforward.GetIPString(d.Addr)
|
||||
if ip == "" {
|
||||
ip := dnsforward.IPFromAddr(d.Addr)
|
||||
if ip == nil {
|
||||
// This would be quite weird if we get here
|
||||
return
|
||||
}
|
||||
|
||||
ipAddr := net.ParseIP(ip)
|
||||
if !ipAddr.IsLoopback() {
|
||||
if !ip.IsLoopback() {
|
||||
Context.rdns.Begin(ip)
|
||||
}
|
||||
if !Context.ipDetector.detectSpecialNetwork(ipAddr) {
|
||||
if !Context.ipDetector.detectSpecialNetwork(ip) {
|
||||
Context.whois.Begin(ip)
|
||||
}
|
||||
}
|
||||
|
||||
func generateServerConfig() (newconfig dnsforward.ServerConfig, err error) {
|
||||
bindHost := net.ParseIP(config.DNS.BindHost)
|
||||
newconfig = dnsforward.ServerConfig{
|
||||
UDPListenAddr: &net.UDPAddr{IP: bindHost, Port: config.DNS.Port},
|
||||
TCPListenAddr: &net.TCPAddr{IP: bindHost, Port: config.DNS.Port},
|
||||
UDPListenAddr: &net.UDPAddr{IP: config.DNS.BindHost, Port: config.DNS.Port},
|
||||
TCPListenAddr: &net.TCPAddr{IP: config.DNS.BindHost, Port: config.DNS.Port},
|
||||
FilteringConfig: config.DNS.FilteringConfig,
|
||||
ConfigModified: onConfigModified,
|
||||
HTTPRegister: httpRegister,
|
||||
@@ -128,23 +128,24 @@ func generateServerConfig() (newconfig dnsforward.ServerConfig, err error) {
|
||||
Context.tls.WriteDiskConfig(&tlsConf)
|
||||
if tlsConf.Enabled {
|
||||
newconfig.TLSConfig = tlsConf.TLSConfig
|
||||
newconfig.TLSConfig.ServerName = tlsConf.ServerName
|
||||
|
||||
if tlsConf.PortDNSOverTLS != 0 {
|
||||
newconfig.TLSListenAddr = &net.TCPAddr{
|
||||
IP: bindHost,
|
||||
IP: config.DNS.BindHost,
|
||||
Port: tlsConf.PortDNSOverTLS,
|
||||
}
|
||||
}
|
||||
|
||||
if tlsConf.PortDNSOverQUIC != 0 {
|
||||
newconfig.QUICListenAddr = &net.UDPAddr{
|
||||
IP: bindHost,
|
||||
IP: config.DNS.BindHost,
|
||||
Port: int(tlsConf.PortDNSOverQUIC),
|
||||
}
|
||||
}
|
||||
|
||||
if tlsConf.PortDNSCrypt != 0 {
|
||||
newconfig.DNSCryptConfig, err = newDNSCrypt(bindHost, tlsConf)
|
||||
newconfig.DNSCryptConfig, err = newDNSCrypt(config.DNS.BindHost, tlsConf)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's already
|
||||
// wrapped by newDNSCrypt.
|
||||
@@ -209,43 +210,49 @@ type dnsEncryption struct {
|
||||
quic string
|
||||
}
|
||||
|
||||
func getDNSEncryption() dnsEncryption {
|
||||
dnsEncryption := dnsEncryption{}
|
||||
|
||||
func getDNSEncryption() (de dnsEncryption) {
|
||||
tlsConf := tlsConfigSettings{}
|
||||
|
||||
Context.tls.WriteDiskConfig(&tlsConf)
|
||||
|
||||
if tlsConf.Enabled && len(tlsConf.ServerName) != 0 {
|
||||
|
||||
hostname := tlsConf.ServerName
|
||||
if tlsConf.PortHTTPS != 0 {
|
||||
addr := tlsConf.ServerName
|
||||
addr := hostname
|
||||
if tlsConf.PortHTTPS != 443 {
|
||||
addr = fmt.Sprintf("%s:%d", addr, tlsConf.PortHTTPS)
|
||||
addr = net.JoinHostPort(addr, strconv.Itoa(tlsConf.PortHTTPS))
|
||||
}
|
||||
addr = fmt.Sprintf("https://%s/dns-query", addr)
|
||||
dnsEncryption.https = addr
|
||||
|
||||
de.https = (&url.URL{
|
||||
Scheme: "https",
|
||||
Host: addr,
|
||||
Path: "/dns-query",
|
||||
}).String()
|
||||
}
|
||||
|
||||
if tlsConf.PortDNSOverTLS != 0 {
|
||||
addr := fmt.Sprintf("tls://%s:%d", tlsConf.ServerName, tlsConf.PortDNSOverTLS)
|
||||
dnsEncryption.tls = addr
|
||||
de.tls = (&url.URL{
|
||||
Scheme: "tls",
|
||||
Host: net.JoinHostPort(hostname, strconv.Itoa(tlsConf.PortDNSOverTLS)),
|
||||
}).String()
|
||||
}
|
||||
|
||||
if tlsConf.PortDNSOverQUIC != 0 {
|
||||
addr := fmt.Sprintf("quic://%s:%d", tlsConf.ServerName, tlsConf.PortDNSOverQUIC)
|
||||
dnsEncryption.quic = addr
|
||||
de.quic = (&url.URL{
|
||||
Scheme: "quic",
|
||||
Host: net.JoinHostPort(hostname, strconv.Itoa(int(tlsConf.PortDNSOverQUIC))),
|
||||
}).String()
|
||||
}
|
||||
}
|
||||
|
||||
return dnsEncryption
|
||||
return de
|
||||
}
|
||||
|
||||
// Get the list of DNS addresses the server is listening on
|
||||
func getDNSAddresses() []string {
|
||||
dnsAddresses := []string{}
|
||||
|
||||
if config.DNS.BindHost == "0.0.0.0" {
|
||||
if config.DNS.BindHost.IsUnspecified() {
|
||||
ifaces, e := util.GetValidNetInterfacesForWeb()
|
||||
if e != nil {
|
||||
log.Error("Couldn't get network interfaces: %v", e)
|
||||
@@ -275,21 +282,26 @@ func getDNSAddresses() []string {
|
||||
return dnsAddresses
|
||||
}
|
||||
|
||||
// If a client has his own settings, apply them
|
||||
func applyAdditionalFiltering(clientAddr string, setts *dnsfilter.RequestFilteringSettings) {
|
||||
// applyAdditionalFiltering adds additional client information and settings if
|
||||
// the client has them.
|
||||
func applyAdditionalFiltering(clientAddr net.IP, clientID string, setts *dnsfilter.RequestFilteringSettings) {
|
||||
Context.dnsFilter.ApplyBlockedServices(setts, nil, true)
|
||||
|
||||
if len(clientAddr) == 0 {
|
||||
if clientAddr == nil {
|
||||
return
|
||||
}
|
||||
|
||||
setts.ClientIP = clientAddr
|
||||
|
||||
c, ok := Context.clients.Find(clientAddr)
|
||||
c, ok := Context.clients.Find(clientID)
|
||||
if !ok {
|
||||
return
|
||||
c, ok = Context.clients.Find(clientAddr.String())
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug("Using settings for client %s with IP %s", c.Name, clientAddr)
|
||||
log.Debug("using settings for client %s with ip %s and id %q", c.Name, clientAddr, clientID)
|
||||
|
||||
if c.UseOwnBlockedServices {
|
||||
Context.dnsFilter.ApplyBlockedServices(setts, c.BlockedServices, false)
|
||||
@@ -328,13 +340,11 @@ func startDNSServer() error {
|
||||
Context.queryLog.Start()
|
||||
|
||||
const topClientsNumber = 100 // the number of clients to get
|
||||
topClients := Context.stats.GetTopClientsIP(topClientsNumber)
|
||||
for _, ip := range topClients {
|
||||
ipAddr := net.ParseIP(ip)
|
||||
if !ipAddr.IsLoopback() {
|
||||
for _, ip := range Context.stats.GetTopClientsIP(topClientsNumber) {
|
||||
if !ip.IsLoopback() {
|
||||
Context.rdns.Begin(ip)
|
||||
}
|
||||
if !Context.ipDetector.detectSpecialNetwork(ipAddr) {
|
||||
if !Context.ipDetector.detectSpecialNetwork(ip) {
|
||||
Context.whois.Begin(ip)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -50,16 +50,17 @@ func TestFilters(t *testing.T) {
|
||||
|
||||
// download
|
||||
ok, err := Context.filters.update(&f)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, 3, f.RulesCount)
|
||||
|
||||
// refresh
|
||||
ok, err = Context.filters.update(&f)
|
||||
assert.True(t, !ok && err == nil)
|
||||
assert.False(t, ok)
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = Context.filters.load(&f)
|
||||
assert.True(t, err == nil)
|
||||
assert.Nil(t, err)
|
||||
|
||||
f.unload()
|
||||
_ = os.Remove(f.Path())
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
@@ -27,8 +28,9 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/sysutil"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/update"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/updater"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/util"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"gopkg.in/natefinch/lumberjack.v2"
|
||||
)
|
||||
@@ -38,15 +40,6 @@ const (
|
||||
configSyslog = "syslog"
|
||||
)
|
||||
|
||||
// Update-related variables
|
||||
var (
|
||||
versionString = "dev"
|
||||
updateChannel = "none"
|
||||
versionCheckURL = ""
|
||||
ARMVersion = ""
|
||||
MIPSVersion = ""
|
||||
)
|
||||
|
||||
// Global context
|
||||
type homeContext struct {
|
||||
// Modules
|
||||
@@ -65,7 +58,7 @@ type homeContext struct {
|
||||
web *Web // Web (HTTP, HTTPS) module
|
||||
tls *TLSMod // TLS module
|
||||
autoHosts util.AutoHosts // IP-hostname pairs taken from system configuration (e.g. /etc/hosts) files
|
||||
updater *update.Updater
|
||||
updater *updater.Updater
|
||||
|
||||
ipDetector *ipDetector
|
||||
|
||||
@@ -99,14 +92,7 @@ func (c *homeContext) getDataDir() string {
|
||||
var Context homeContext
|
||||
|
||||
// Main is the entry point
|
||||
func Main(version, channel, armVer, mipsVer string) {
|
||||
// Init update-related global variables
|
||||
versionString = version
|
||||
updateChannel = channel
|
||||
ARMVersion = armVer
|
||||
MIPSVersion = mipsVer
|
||||
versionCheckURL = "https://static.adguard.com/adguardhome/" + updateChannel + "/version.json"
|
||||
|
||||
func Main() {
|
||||
// config can be specified, which reads options from there, but other command line flags have to override config values
|
||||
// therefore, we must do it manually instead of using a lib
|
||||
args := loadOptions()
|
||||
@@ -123,7 +109,7 @@ func Main(version, channel, armVer, mipsVer string) {
|
||||
Context.tls.Reload()
|
||||
|
||||
default:
|
||||
cleanup()
|
||||
cleanup(context.Background())
|
||||
cleanupAlways()
|
||||
os.Exit(0)
|
||||
}
|
||||
@@ -139,23 +125,10 @@ func Main(version, channel, armVer, mipsVer string) {
|
||||
run(args)
|
||||
}
|
||||
|
||||
// version - returns the current version string
|
||||
func version() string {
|
||||
// TODO(a.garipov): I'm pretty sure we can extract some of this stuff
|
||||
// from the build info.
|
||||
msg := "AdGuard Home, version %s, channel %s, arch %s %s"
|
||||
if ARMVersion != "" {
|
||||
msg = msg + " v" + ARMVersion
|
||||
} else if MIPSVersion != "" {
|
||||
msg = msg + " " + MIPSVersion
|
||||
}
|
||||
|
||||
return fmt.Sprintf(msg, versionString, updateChannel, runtime.GOOS, runtime.GOARCH)
|
||||
}
|
||||
|
||||
func setupContext(args options) {
|
||||
Context.runningAsService = args.runningAsService
|
||||
Context.disableUpdate = args.disableUpdate
|
||||
Context.disableUpdate = args.disableUpdate ||
|
||||
version.Channel() == version.ChannelDevelopment
|
||||
|
||||
Context.firstRun = detectFirstRun()
|
||||
if Context.firstRun {
|
||||
@@ -214,15 +187,16 @@ func setupConfig(args options) {
|
||||
|
||||
Context.autoHosts.Init("")
|
||||
|
||||
Context.updater = update.NewUpdater(update.Config{
|
||||
Client: Context.client,
|
||||
WorkDir: Context.workDir,
|
||||
VersionURL: versionCheckURL,
|
||||
VersionString: versionString,
|
||||
OS: runtime.GOOS,
|
||||
Arch: runtime.GOARCH,
|
||||
ARMVersion: ARMVersion,
|
||||
ConfigName: config.getConfigFilename(),
|
||||
Context.updater = updater.NewUpdater(&updater.Config{
|
||||
Client: Context.client,
|
||||
Version: version.Version(),
|
||||
Channel: version.Channel(),
|
||||
GOARCH: runtime.GOARCH,
|
||||
GOOS: runtime.GOOS,
|
||||
GOARM: version.GOARM(),
|
||||
GOMIPS: version.GOMIPS(),
|
||||
WorkDir: Context.workDir,
|
||||
ConfName: config.getConfigFilename(),
|
||||
})
|
||||
|
||||
Context.clients.Init(config.Clients, Context.dhcpServer, &Context.autoHosts)
|
||||
@@ -234,7 +208,7 @@ func setupConfig(args options) {
|
||||
}
|
||||
|
||||
// override bind host/port from the console
|
||||
if args.bindHost != "" {
|
||||
if args.bindHost != nil {
|
||||
config.BindHost = args.bindHost
|
||||
}
|
||||
if args.bindPort != 0 {
|
||||
@@ -260,7 +234,7 @@ func run(args options) {
|
||||
memoryUsage(args)
|
||||
|
||||
// print the first message after logger is configured
|
||||
log.Println(version())
|
||||
log.Println(version.Full())
|
||||
log.Debug("Current working directory is %s", Context.workDir)
|
||||
if args.runningAsService {
|
||||
log.Info("AdGuard Home is running as a service")
|
||||
@@ -316,19 +290,25 @@ func run(args options) {
|
||||
}
|
||||
|
||||
webConf := webConfig{
|
||||
firstRun: Context.firstRun,
|
||||
BindHost: config.BindHost,
|
||||
BindPort: config.BindPort,
|
||||
firstRun: Context.firstRun,
|
||||
BindHost: config.BindHost,
|
||||
BindPort: config.BindPort,
|
||||
BetaBindPort: config.BetaBindPort,
|
||||
|
||||
ReadTimeout: ReadTimeout,
|
||||
ReadHeaderTimeout: ReadHeaderTimeout,
|
||||
WriteTimeout: WriteTimeout,
|
||||
ReadTimeout: readTimeout,
|
||||
ReadHeaderTimeout: readHdrTimeout,
|
||||
WriteTimeout: writeTimeout,
|
||||
}
|
||||
Context.web = CreateWeb(&webConf)
|
||||
if Context.web == nil {
|
||||
log.Fatalf("Can't initialize Web module")
|
||||
}
|
||||
|
||||
Context.ipDetector, err = newIPDetector()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
if !Context.firstRun {
|
||||
err := initDNSServer()
|
||||
if err != nil {
|
||||
@@ -340,6 +320,7 @@ func run(args options) {
|
||||
go func() {
|
||||
err := startDNSServer()
|
||||
if err != nil {
|
||||
closeDNSServer()
|
||||
log.Fatal(err)
|
||||
}
|
||||
}()
|
||||
@@ -349,18 +330,13 @@ func run(args options) {
|
||||
}
|
||||
}
|
||||
|
||||
Context.ipDetector, err = newIPDetector()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
Context.web.Start()
|
||||
|
||||
// wait indefinitely for other go-routines to complete their job
|
||||
select {}
|
||||
}
|
||||
|
||||
// StartMods - initialize and start DNS after installation
|
||||
// StartMods initializes and starts the DNS server after installation.
|
||||
func StartMods() error {
|
||||
err := initDNSServer()
|
||||
if err != nil {
|
||||
@@ -460,6 +436,10 @@ func initWorkingDir(args options) {
|
||||
} else {
|
||||
Context.workDir = filepath.Dir(execPath)
|
||||
}
|
||||
|
||||
if workDir, err := filepath.EvalSymlinks(Context.workDir); err == nil {
|
||||
Context.workDir = workDir
|
||||
}
|
||||
}
|
||||
|
||||
// configureLogger configures logger level and output
|
||||
@@ -527,11 +507,12 @@ func configureLogger(args options) {
|
||||
}
|
||||
}
|
||||
|
||||
func cleanup() {
|
||||
// cleanup stops and resets all the modules.
|
||||
func cleanup(ctx context.Context) {
|
||||
log.Info("Stopping AdGuard Home")
|
||||
|
||||
if Context.web != nil {
|
||||
Context.web.Close()
|
||||
Context.web.Close(ctx)
|
||||
Context.web = nil
|
||||
}
|
||||
if Context.auth != nil {
|
||||
@@ -592,8 +573,6 @@ func loadOptions() options {
|
||||
// prints IP addresses which user can use to open the admin interface
|
||||
// proto is either "http" or "https"
|
||||
func printHTTPAddresses(proto string) {
|
||||
var address string
|
||||
|
||||
tlsConf := tlsConfigSettings{}
|
||||
if Context.tls != nil {
|
||||
Context.tls.WriteDiskConfig(&tlsConf)
|
||||
@@ -604,31 +583,41 @@ func printHTTPAddresses(proto string) {
|
||||
port = strconv.Itoa(tlsConf.PortHTTPS)
|
||||
}
|
||||
|
||||
var hostStr string
|
||||
if proto == "https" && tlsConf.ServerName != "" {
|
||||
if tlsConf.PortHTTPS == 443 {
|
||||
log.Printf("Go to https://%s", tlsConf.ServerName)
|
||||
} else {
|
||||
log.Printf("Go to https://%s:%s", tlsConf.ServerName, port)
|
||||
}
|
||||
} else if config.BindHost == "0.0.0.0" {
|
||||
} else if config.BindHost.IsUnspecified() {
|
||||
log.Println("AdGuard Home is available on the following addresses:")
|
||||
ifaces, err := util.GetValidNetInterfacesForWeb()
|
||||
if err != nil {
|
||||
// That's weird, but we'll ignore it
|
||||
address = net.JoinHostPort(config.BindHost, port)
|
||||
log.Printf("Go to %s://%s", proto, address)
|
||||
hostStr = config.BindHost.String()
|
||||
log.Printf("Go to %s://%s", proto, net.JoinHostPort(hostStr, port))
|
||||
if config.BetaBindPort != 0 {
|
||||
log.Printf("Go to %s://%s (BETA)", proto, net.JoinHostPort(hostStr, strconv.Itoa(config.BetaBindPort)))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
for _, iface := range ifaces {
|
||||
for _, addr := range iface.Addresses {
|
||||
address = net.JoinHostPort(addr, strconv.Itoa(config.BindPort))
|
||||
log.Printf("Go to %s://%s", proto, address)
|
||||
hostStr = addr.String()
|
||||
log.Printf("Go to %s://%s", proto, net.JoinHostPort(hostStr, strconv.Itoa(config.BindPort)))
|
||||
if config.BetaBindPort != 0 {
|
||||
log.Printf("Go to %s://%s (BETA)", proto, net.JoinHostPort(hostStr, strconv.Itoa(config.BetaBindPort)))
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
address = net.JoinHostPort(config.BindHost, port)
|
||||
log.Printf("Go to %s://%s", proto, address)
|
||||
hostStr = config.BindHost.String()
|
||||
log.Printf("Go to %s://%s", proto, net.JoinHostPort(hostStr, port))
|
||||
if config.BetaBindPort != 0 {
|
||||
log.Printf("Go to %s://%s (BETA)", proto, net.JoinHostPort(hostStr, strconv.Itoa(config.BetaBindPort)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -641,7 +630,7 @@ func detectFirstRun() bool {
|
||||
configfile = filepath.Join(Context.workDir, Context.configFilename)
|
||||
}
|
||||
_, err := os.Stat(configfile)
|
||||
return os.IsNotExist(err)
|
||||
return errors.Is(err, os.ErrNotExist)
|
||||
}
|
||||
|
||||
// Connect to a remote server resolving hostname using our own DNS server
|
||||
@@ -685,10 +674,11 @@ func customDialContext(ctx context.Context, network, addr string) (net.Conn, err
|
||||
return nil, agherr.Many(fmt.Sprintf("couldn't dial to %s", addr), dialErrs...)
|
||||
}
|
||||
|
||||
func getHTTPProxy(req *http.Request) (*url.URL, error) {
|
||||
if len(config.ProxyURL) == 0 {
|
||||
func getHTTPProxy(_ *http.Request) (*url.URL, error) {
|
||||
if config.ProxyURL == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return url.Parse(config.ProxyURL)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
// +build !race
|
||||
|
||||
// TODO(e.burkov): remove this weird buildtag.
|
||||
// TODO(e.burkov): Remove this weird buildtag.
|
||||
|
||||
package home
|
||||
|
||||
@@ -119,7 +119,7 @@ func TestHome(t *testing.T) {
|
||||
fn := filepath.Join(dir, "AdGuardHome.yaml")
|
||||
|
||||
// Prepare the test config
|
||||
assert.True(t, ioutil.WriteFile(fn, []byte(yamlConf), 0o644) == nil)
|
||||
assert.Nil(t, ioutil.WriteFile(fn, []byte(yamlConf), 0o644))
|
||||
fn, _ = filepath.Abs(fn)
|
||||
|
||||
config = configuration{} // the global variable is dirty because of the previous tests run
|
||||
@@ -133,16 +133,16 @@ func TestHome(t *testing.T) {
|
||||
h := http.Client{}
|
||||
for i := 0; i != 50; i++ {
|
||||
resp, err = h.Get("http://127.0.0.1:3000/")
|
||||
if err == nil && resp.StatusCode != 404 {
|
||||
if err == nil && resp.StatusCode != http.StatusNotFound {
|
||||
break
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
assert.Truef(t, err == nil, "%s", err)
|
||||
assert.Nilf(t, err, "%s", err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
resp, err = h.Get("http://127.0.0.1:3000/control/status")
|
||||
assert.Truef(t, err == nil, "%s", err)
|
||||
assert.Nilf(t, err, "%s", err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
// test DNS over UDP
|
||||
@@ -159,16 +159,16 @@ func TestHome(t *testing.T) {
|
||||
req.RecursionDesired = true
|
||||
req.Question = []dns.Question{{Name: "static.adguard.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}}
|
||||
buf, err := req.Pack()
|
||||
assert.True(t, err == nil, "%s", err)
|
||||
assert.Nil(t, err)
|
||||
requestURL := "http://127.0.0.1:3000/dns-query?dns=" + base64.RawURLEncoding.EncodeToString(buf)
|
||||
resp, err = http.DefaultClient.Get(requestURL)
|
||||
assert.True(t, err == nil, "%s", err)
|
||||
assert.Nil(t, err)
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
assert.True(t, err == nil, "%s", err)
|
||||
assert.True(t, resp.StatusCode == http.StatusOK)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
response := dns.Msg{}
|
||||
err = response.Unpack(body)
|
||||
assert.True(t, err == nil, "%s", err)
|
||||
assert.Nil(t, err)
|
||||
addrs = nil
|
||||
proxyutil.AppendIPAddrs(&addrs, response.Answer)
|
||||
haveIP = len(addrs) != 0
|
||||
@@ -186,6 +186,6 @@ func TestHome(t *testing.T) {
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
|
||||
cleanup()
|
||||
cleanup(context.Background())
|
||||
cleanupAlways()
|
||||
}
|
||||
|
||||
@@ -5,16 +5,15 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIPDetector_detectSpecialNetwork(t *testing.T) {
|
||||
var ipd *ipDetector
|
||||
var err error
|
||||
|
||||
t.Run("newIPDetector", func(t *testing.T) {
|
||||
var err error
|
||||
ipd, err = newIPDetector()
|
||||
assert.Nil(t, err)
|
||||
})
|
||||
ipd, err = newIPDetector()
|
||||
require.Nil(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
|
||||
@@ -22,15 +22,43 @@ func withMiddlewares(h http.Handler, middlewares ...middleware) (wrapped http.Ha
|
||||
return wrapped
|
||||
}
|
||||
|
||||
// RequestBodySizeLimit is maximum request body length in bytes.
|
||||
const RequestBodySizeLimit = 64 * 1024
|
||||
// defaultReqBodySzLim is the default maximum request body size.
|
||||
const defaultReqBodySzLim = 64 * 1024
|
||||
|
||||
// largerReqBodySzLim is the maximum request body size for APIs expecting larger
|
||||
// requests.
|
||||
const largerReqBodySzLim = 4 * 1024 * 1024
|
||||
|
||||
// expectsLargerRequests shows if this request should use a larger body size
|
||||
// limit. These are exceptions for poorly designed current APIs as well as APIs
|
||||
// that are designed to expect large files and requests. Remove once the new,
|
||||
// better APIs are up.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/2666 and
|
||||
// https://github.com/AdguardTeam/AdGuardHome/issues/2675.
|
||||
func expectsLargerRequests(r *http.Request) (ok bool) {
|
||||
m := r.Method
|
||||
if m != http.MethodPost {
|
||||
return false
|
||||
}
|
||||
|
||||
p := r.URL.Path
|
||||
return p == "/control/access/set" ||
|
||||
p == "/control/filtering/set_rules"
|
||||
}
|
||||
|
||||
// limitRequestBody wraps underlying handler h, making it's request's body Read
|
||||
// method limited.
|
||||
func limitRequestBody(h http.Handler) (limited http.Handler) {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var err error
|
||||
r.Body, err = aghio.LimitReadCloser(r.Body, RequestBodySizeLimit)
|
||||
|
||||
var szLim int64 = defaultReqBodySzLim
|
||||
if expectsLargerRequests(r) {
|
||||
szLim = largerReqBodySzLim
|
||||
}
|
||||
|
||||
r.Body, err = aghio.LimitReadCloser(r.Body, szLim)
|
||||
if err != nil {
|
||||
log.Error("limitRequestBody: %s", err)
|
||||
|
||||
@@ -40,3 +68,18 @@ func limitRequestBody(h http.Handler) (limited http.Handler) {
|
||||
h.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// wrapIndexBeta returns handler that deals with new client.
|
||||
func (web *Web) wrapIndexBeta(http.Handler) (wrapped http.Handler) {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
h, pattern := Context.mux.Handler(r)
|
||||
switch pattern {
|
||||
case "/":
|
||||
web.handlerBeta.ServeHTTP(w, r)
|
||||
case "/install.html":
|
||||
web.installerBeta.ServeHTTP(w, r)
|
||||
default:
|
||||
h.ServeHTTP(w, r)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -9,11 +9,12 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestLimitRequestBody(t *testing.T) {
|
||||
errReqLimitReached := &aghio.LimitReachedError{
|
||||
Limit: RequestBodySizeLimit,
|
||||
Limit: defaultReqBodySzLim,
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
@@ -28,8 +29,8 @@ func TestLimitRequestBody(t *testing.T) {
|
||||
wantErr: nil,
|
||||
}, {
|
||||
name: "so_big",
|
||||
body: string(make([]byte, RequestBodySizeLimit+1)),
|
||||
want: make([]byte, RequestBodySizeLimit),
|
||||
body: string(make([]byte, defaultReqBodySzLim+1)),
|
||||
want: make([]byte, defaultReqBodySzLim),
|
||||
wantErr: errReqLimitReached,
|
||||
}, {
|
||||
name: "empty",
|
||||
@@ -42,7 +43,10 @@ func TestLimitRequestBody(t *testing.T) {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var b []byte
|
||||
b, *err = ioutil.ReadAll(r.Body)
|
||||
w.Write(b)
|
||||
_, werr := w.Write(b)
|
||||
if werr != nil {
|
||||
panic(werr)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -57,8 +61,8 @@ func TestLimitRequestBody(t *testing.T) {
|
||||
|
||||
lim.ServeHTTP(res, req)
|
||||
|
||||
require.Equal(t, tc.wantErr, err)
|
||||
assert.Equal(t, tc.want, res.Body.Bytes())
|
||||
assert.Equal(t, tc.wantErr, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,10 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
uuid "github.com/satori/go.uuid"
|
||||
"howett.net/plist"
|
||||
@@ -14,6 +17,7 @@ type dnsSettings struct {
|
||||
DNSProtocol string
|
||||
ServerURL string `plist:",omitempty"`
|
||||
ServerName string `plist:",omitempty"`
|
||||
clientID string
|
||||
}
|
||||
|
||||
type payloadContent struct {
|
||||
@@ -23,19 +27,19 @@ type payloadContent struct {
|
||||
PayloadIdentifier string
|
||||
PayloadType string
|
||||
PayloadUUID string
|
||||
PayloadVersion int
|
||||
DNSSettings dnsSettings
|
||||
PayloadVersion int
|
||||
}
|
||||
|
||||
type mobileConfig struct {
|
||||
PayloadContent []payloadContent
|
||||
PayloadDescription string
|
||||
PayloadDisplayName string
|
||||
PayloadIdentifier string
|
||||
PayloadRemovalDisallowed bool
|
||||
PayloadType string
|
||||
PayloadUUID string
|
||||
PayloadContent []payloadContent
|
||||
PayloadVersion int
|
||||
PayloadRemovalDisallowed bool
|
||||
}
|
||||
|
||||
func genUUIDv4() string {
|
||||
@@ -48,22 +52,35 @@ const (
|
||||
)
|
||||
|
||||
func getMobileConfig(d dnsSettings) ([]byte, error) {
|
||||
var name string
|
||||
var dspName string
|
||||
switch d.DNSProtocol {
|
||||
case dnsProtoHTTPS:
|
||||
name = fmt.Sprintf("%s DoH", d.ServerName)
|
||||
d.ServerURL = fmt.Sprintf("https://%s/dns-query", d.ServerName)
|
||||
dspName = fmt.Sprintf("%s DoH", d.ServerName)
|
||||
|
||||
u := &url.URL{
|
||||
Scheme: "https",
|
||||
Host: d.ServerName,
|
||||
Path: "/dns-query",
|
||||
}
|
||||
if d.clientID != "" {
|
||||
u.Path = path.Join(u.Path, d.clientID)
|
||||
}
|
||||
|
||||
d.ServerURL = u.String()
|
||||
case dnsProtoTLS:
|
||||
name = fmt.Sprintf("%s DoT", d.ServerName)
|
||||
dspName = fmt.Sprintf("%s DoT", d.ServerName)
|
||||
if d.clientID != "" {
|
||||
d.ServerName = d.clientID + "." + d.ServerName
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("bad dns protocol %q", d.DNSProtocol)
|
||||
}
|
||||
|
||||
data := mobileConfig{
|
||||
PayloadContent: []payloadContent{{
|
||||
Name: name,
|
||||
Name: dspName,
|
||||
PayloadDescription: "Configures device to use AdGuard Home",
|
||||
PayloadDisplayName: name,
|
||||
PayloadDisplayName: dspName,
|
||||
PayloadIdentifier: fmt.Sprintf("com.apple.dnsSettings.managed.%s", genUUIDv4()),
|
||||
PayloadType: "com.apple.dnsSettings.managed",
|
||||
PayloadUUID: genUUIDv4(),
|
||||
@@ -71,7 +88,7 @@ func getMobileConfig(d dnsSettings) ([]byte, error) {
|
||||
DNSSettings: d,
|
||||
}},
|
||||
PayloadDescription: "Adds AdGuard Home to Big Sur and iOS 14 or newer systems",
|
||||
PayloadDisplayName: name,
|
||||
PayloadDisplayName: dspName,
|
||||
PayloadIdentifier: genUUIDv4(),
|
||||
PayloadRemovalDisallowed: false,
|
||||
PayloadType: "Configuration",
|
||||
@@ -83,7 +100,10 @@ func getMobileConfig(d dnsSettings) ([]byte, error) {
|
||||
}
|
||||
|
||||
func handleMobileConfig(w http.ResponseWriter, r *http.Request, dnsp string) {
|
||||
host := r.URL.Query().Get("host")
|
||||
var err error
|
||||
|
||||
q := r.URL.Query()
|
||||
host := q.Get("host")
|
||||
if host == "" {
|
||||
host = Context.tls.conf.ServerName
|
||||
}
|
||||
@@ -92,7 +112,7 @@ func handleMobileConfig(w http.ResponseWriter, r *http.Request, dnsp string) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
|
||||
const msg = "no host in query parameters and no server_name"
|
||||
err := json.NewEncoder(w).Encode(&jsonError{
|
||||
err = json.NewEncoder(w).Encode(&jsonError{
|
||||
Message: msg,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -102,9 +122,25 @@ func handleMobileConfig(w http.ResponseWriter, r *http.Request, dnsp string) {
|
||||
return
|
||||
}
|
||||
|
||||
clientID := q.Get("client_id")
|
||||
err = dnsforward.ValidateClientID(clientID)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
|
||||
err = json.NewEncoder(w).Encode(&jsonError{
|
||||
Message: err.Error(),
|
||||
})
|
||||
if err != nil {
|
||||
log.Debug("writing 400 json response: %s", err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
d := dnsSettings{
|
||||
DNSProtocol: dnsp,
|
||||
ServerName: host,
|
||||
clientID: clientID,
|
||||
}
|
||||
|
||||
mobileconfig, err := getMobileConfig(d)
|
||||
@@ -115,6 +151,7 @@ func handleMobileConfig(w http.ResponseWriter, r *http.Request, dnsp string) {
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/xml")
|
||||
|
||||
_, _ = w.Write(mobileconfig)
|
||||
}
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ func TestHandleMobileConfigDOH(t *testing.T) {
|
||||
_, err = plist.Unmarshal(w.Body.Bytes(), &mc)
|
||||
assert.Nil(t, err)
|
||||
|
||||
if assert.Equal(t, 1, len(mc.PayloadContent)) {
|
||||
if assert.Len(t, mc.PayloadContent, 1) {
|
||||
assert.Equal(t, "example.org DoH", mc.PayloadContent[0].Name)
|
||||
assert.Equal(t, "example.org DoH", mc.PayloadContent[0].PayloadDisplayName)
|
||||
assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName)
|
||||
@@ -51,7 +51,7 @@ func TestHandleMobileConfigDOH(t *testing.T) {
|
||||
_, err = plist.Unmarshal(w.Body.Bytes(), &mc)
|
||||
assert.Nil(t, err)
|
||||
|
||||
if assert.Equal(t, 1, len(mc.PayloadContent)) {
|
||||
if assert.Len(t, mc.PayloadContent, 1) {
|
||||
assert.Equal(t, "example.org DoH", mc.PayloadContent[0].Name)
|
||||
assert.Equal(t, "example.org DoH", mc.PayloadContent[0].PayloadDisplayName)
|
||||
assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName)
|
||||
@@ -73,6 +73,27 @@ func TestHandleMobileConfigDOH(t *testing.T) {
|
||||
handleMobileConfigDOH(w, r)
|
||||
assert.Equal(t, http.StatusInternalServerError, w.Code)
|
||||
})
|
||||
|
||||
t.Run("client_id", func(t *testing.T) {
|
||||
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/doh.mobileconfig?host=example.org&client_id=cli42", nil)
|
||||
assert.Nil(t, err)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handleMobileConfigDOH(w, r)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var mc mobileConfig
|
||||
_, err = plist.Unmarshal(w.Body.Bytes(), &mc)
|
||||
assert.Nil(t, err)
|
||||
|
||||
if assert.Len(t, mc.PayloadContent, 1) {
|
||||
assert.Equal(t, "example.org DoH", mc.PayloadContent[0].Name)
|
||||
assert.Equal(t, "example.org DoH", mc.PayloadContent[0].PayloadDisplayName)
|
||||
assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName)
|
||||
assert.Equal(t, "https://example.org/dns-query/cli42", mc.PayloadContent[0].DNSSettings.ServerURL)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestHandleMobileConfigDOT(t *testing.T) {
|
||||
@@ -89,7 +110,7 @@ func TestHandleMobileConfigDOT(t *testing.T) {
|
||||
_, err = plist.Unmarshal(w.Body.Bytes(), &mc)
|
||||
assert.Nil(t, err)
|
||||
|
||||
if assert.Equal(t, 1, len(mc.PayloadContent)) {
|
||||
if assert.Len(t, mc.PayloadContent, 1) {
|
||||
assert.Equal(t, "example.org DoT", mc.PayloadContent[0].Name)
|
||||
assert.Equal(t, "example.org DoT", mc.PayloadContent[0].PayloadDisplayName)
|
||||
assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName)
|
||||
@@ -116,7 +137,7 @@ func TestHandleMobileConfigDOT(t *testing.T) {
|
||||
_, err = plist.Unmarshal(w.Body.Bytes(), &mc)
|
||||
assert.Nil(t, err)
|
||||
|
||||
if assert.Equal(t, 1, len(mc.PayloadContent)) {
|
||||
if assert.Len(t, mc.PayloadContent, 1) {
|
||||
assert.Equal(t, "example.org DoT", mc.PayloadContent[0].Name)
|
||||
assert.Equal(t, "example.org DoT", mc.PayloadContent[0].PayloadDisplayName)
|
||||
assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName)
|
||||
@@ -137,4 +158,24 @@ func TestHandleMobileConfigDOT(t *testing.T) {
|
||||
handleMobileConfigDOT(w, r)
|
||||
assert.Equal(t, http.StatusInternalServerError, w.Code)
|
||||
})
|
||||
|
||||
t.Run("client_id", func(t *testing.T) {
|
||||
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/dot.mobileconfig?host=example.org&client_id=cli42", nil)
|
||||
assert.Nil(t, err)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handleMobileConfigDOT(w, r)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var mc mobileConfig
|
||||
_, err = plist.Unmarshal(w.Body.Bytes(), &mc)
|
||||
assert.Nil(t, err)
|
||||
|
||||
if assert.Len(t, mc.PayloadContent, 1) {
|
||||
assert.Equal(t, "example.org DoT", mc.PayloadContent[0].Name)
|
||||
assert.Equal(t, "example.org DoT", mc.PayloadContent[0].PayloadDisplayName)
|
||||
assert.Equal(t, "cli42.example.org", mc.PayloadContent[0].DNSSettings.ServerName)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -2,8 +2,11 @@ package home
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
||||
)
|
||||
|
||||
// options passed from command-line arguments
|
||||
@@ -11,7 +14,7 @@ type options struct {
|
||||
verbose bool // is verbose logging enabled
|
||||
configFilename string // path to the config file
|
||||
workDir string // path to the working directory where we will store the filters data and the querylog
|
||||
bindHost string // host address to bind HTTP server on
|
||||
bindHost net.IP // host address to bind HTTP server on
|
||||
bindPort int // port to serve HTTP pages on
|
||||
logFile string // Path to the log file. If empty, write to stdout. If "syslog", writes to syslog
|
||||
pidFile string // File name to save PID to
|
||||
@@ -52,10 +55,19 @@ type arg struct {
|
||||
// against its zero value and return nil if the parameter value is
|
||||
// zero otherwise they return a string slice of the parameter
|
||||
|
||||
func ipSliceOrNil(ip net.IP) []string {
|
||||
if ip == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return []string{ip.String()}
|
||||
}
|
||||
|
||||
func stringSliceOrNil(s string) []string {
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
return []string{s}
|
||||
}
|
||||
|
||||
@@ -63,6 +75,7 @@ func intSliceOrNil(i int) []string {
|
||||
if i == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return []string{strconv.Itoa(i)}
|
||||
}
|
||||
|
||||
@@ -70,6 +83,7 @@ func boolSliceOrNil(b bool) []string {
|
||||
if b {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -94,8 +108,8 @@ var workDirArg = arg{
|
||||
var hostArg = arg{
|
||||
"Host address to bind HTTP server on",
|
||||
"host", "h",
|
||||
func(o options, v string) (options, error) { o.bindHost = v; return o, nil }, nil, nil,
|
||||
func(o options) []string { return stringSliceOrNil(o.bindHost) },
|
||||
func(o options, v string) (options, error) { o.bindHost = net.ParseIP(v); return o, nil }, nil, nil,
|
||||
func(o options) []string { return ipSliceOrNil(o.bindHost) },
|
||||
}
|
||||
|
||||
var portArg = arg{
|
||||
@@ -180,7 +194,7 @@ var versionArg = arg{
|
||||
"Show the version and exit",
|
||||
"version", "",
|
||||
nil, nil, func(o options, exec string) (effect, error) {
|
||||
return func() error { fmt.Println(version()); os.Exit(0); return nil }, nil
|
||||
return func() error { fmt.Println(version.Full()); os.Exit(0); return nil }, nil
|
||||
},
|
||||
func(o options) []string { return nil },
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package home
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -65,14 +66,14 @@ func TestParseWorkDir(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestParseBindHost(t *testing.T) {
|
||||
if testParseOk(t).bindHost != "" {
|
||||
if testParseOk(t).bindHost != nil {
|
||||
t.Fatal("empty is no host")
|
||||
}
|
||||
if testParseOk(t, "-h", "addr").bindHost != "addr" {
|
||||
if !testParseOk(t, "-h", "1.2.3.4").bindHost.Equal(net.IP{1, 2, 3, 4}) {
|
||||
t.Fatal("-h is host")
|
||||
}
|
||||
testParseParamMissing(t, "-h")
|
||||
if testParseOk(t, "--host", "addr").bindHost != "addr" {
|
||||
if !testParseOk(t, "--host", "1.2.3.4").bindHost.Equal(net.IP{1, 2, 3, 4}) {
|
||||
t.Fatal("--host is host")
|
||||
}
|
||||
testParseParamMissing(t, "--host")
|
||||
@@ -204,7 +205,7 @@ func TestSerializeWorkDir(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSerializeBindHost(t *testing.T) {
|
||||
testSerialize(t, options{bindHost: "addr"}, "-h", "addr")
|
||||
testSerialize(t, options{bindHost: net.IP{1, 2, 3, 4}}, "-h", "1.2.3.4")
|
||||
}
|
||||
|
||||
func TestSerializeBindPort(t *testing.T) {
|
||||
|
||||
@@ -2,6 +2,7 @@ package home
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -15,7 +16,7 @@ import (
|
||||
type RDNS struct {
|
||||
dnsServer *dnsforward.Server
|
||||
clients *clientsContainer
|
||||
ipChannel chan string // pass data from DNS request handling thread to rDNS thread
|
||||
ipChannel chan net.IP // pass data from DNS request handling thread to rDNS thread
|
||||
|
||||
// Contains IP addresses of clients to be resolved by rDNS
|
||||
// If IP address is resolved, it stays here while it's inside Clients.
|
||||
@@ -26,24 +27,24 @@ type RDNS struct {
|
||||
|
||||
// InitRDNS - create module context
|
||||
func InitRDNS(dnsServer *dnsforward.Server, clients *clientsContainer) *RDNS {
|
||||
r := RDNS{}
|
||||
r.dnsServer = dnsServer
|
||||
r.clients = clients
|
||||
r := &RDNS{
|
||||
dnsServer: dnsServer,
|
||||
clients: clients,
|
||||
ipAddrs: cache.New(cache.Config{
|
||||
EnableLRU: true,
|
||||
MaxCount: 10000,
|
||||
}),
|
||||
ipChannel: make(chan net.IP, 256),
|
||||
}
|
||||
|
||||
cconf := cache.Config{}
|
||||
cconf.EnableLRU = true
|
||||
cconf.MaxCount = 10000
|
||||
r.ipAddrs = cache.New(cconf)
|
||||
|
||||
r.ipChannel = make(chan string, 256)
|
||||
go r.workerLoop()
|
||||
return &r
|
||||
return r
|
||||
}
|
||||
|
||||
// Begin - add IP address to rDNS queue
|
||||
func (r *RDNS) Begin(ip string) {
|
||||
func (r *RDNS) Begin(ip net.IP) {
|
||||
now := uint64(time.Now().Unix())
|
||||
expire := r.ipAddrs.Get([]byte(ip))
|
||||
expire := r.ipAddrs.Get(ip)
|
||||
if len(expire) != 0 {
|
||||
exp := binary.BigEndian.Uint64(expire)
|
||||
if exp > now {
|
||||
@@ -54,9 +55,10 @@ func (r *RDNS) Begin(ip string) {
|
||||
expire = make([]byte, 8)
|
||||
const ttl = 1 * 60 * 60
|
||||
binary.BigEndian.PutUint64(expire, now+ttl)
|
||||
_ = r.ipAddrs.Set([]byte(ip), expire)
|
||||
_ = r.ipAddrs.Set(ip, expire)
|
||||
|
||||
if r.clients.Exists(ip, ClientSourceRDNS) {
|
||||
id := ip.String()
|
||||
if r.clients.Exists(id, ClientSourceRDNS) {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -70,26 +72,26 @@ func (r *RDNS) Begin(ip string) {
|
||||
}
|
||||
|
||||
// Use rDNS to get hostname by IP address
|
||||
func (r *RDNS) resolve(ip string) string {
|
||||
func (r *RDNS) resolve(ip net.IP) string {
|
||||
log.Tracef("Resolving host for %s", ip)
|
||||
|
||||
req := dns.Msg{}
|
||||
req.Id = dns.Id()
|
||||
req.RecursionDesired = true
|
||||
req.Question = []dns.Question{
|
||||
{
|
||||
Qtype: dns.TypePTR,
|
||||
Qclass: dns.ClassINET,
|
||||
},
|
||||
}
|
||||
var err error
|
||||
req.Question[0].Name, err = dns.ReverseAddr(ip)
|
||||
name, err := dns.ReverseAddr(ip.String())
|
||||
if err != nil {
|
||||
log.Debug("Error while calling dns.ReverseAddr(%s): %s", ip, err)
|
||||
return ""
|
||||
}
|
||||
|
||||
resp, err := r.dnsServer.Exchange(&req)
|
||||
resp, err := r.dnsServer.Exchange(&dns.Msg{
|
||||
MsgHdr: dns.MsgHdr{
|
||||
Id: dns.Id(),
|
||||
RecursionDesired: true,
|
||||
},
|
||||
Question: []dns.Question{{
|
||||
Name: name,
|
||||
Qtype: dns.TypePTR,
|
||||
Qclass: dns.ClassINET,
|
||||
}},
|
||||
})
|
||||
if err != nil {
|
||||
log.Debug("Error while making an rDNS lookup for %s: %s", ip, err)
|
||||
return ""
|
||||
@@ -123,6 +125,6 @@ func (r *RDNS) workerLoop() {
|
||||
continue
|
||||
}
|
||||
|
||||
_, _ = r.clients.AddHost(ip, host, ClientSourceRDNS)
|
||||
_, _ = r.clients.AddHost(ip.String(), host, ClientSourceRDNS)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,21 +1,32 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestResolveRDNS(t *testing.T) {
|
||||
dns := &dnsforward.Server{}
|
||||
conf := &dnsforward.ServerConfig{}
|
||||
conf.UpstreamDNS = []string{"8.8.8.8"}
|
||||
err := dns.Prepare(conf)
|
||||
assert.True(t, err == nil, "%s", err)
|
||||
ups := &aghtest.TestUpstream{
|
||||
Reverse: map[string][]string{
|
||||
"1.1.1.1.in-addr.arpa.": {"one.one.one.one"},
|
||||
},
|
||||
}
|
||||
dns := dnsforward.NewCustomServer(&proxy.Proxy{
|
||||
Config: proxy.Config{
|
||||
UpstreamConfig: &proxy.UpstreamConfig{
|
||||
Upstreams: []upstream.Upstream{ups},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
clients := &clientsContainer{}
|
||||
rdns := InitRDNS(dns, clients)
|
||||
r := rdns.resolve("1.1.1.1")
|
||||
assert.True(t, r == "one.one.one.one", "%s", r)
|
||||
r := rdns.resolve(net.IP{1, 1, 1, 1})
|
||||
assert.Equal(t, "one.one.one.one", r, r)
|
||||
}
|
||||
|
||||
@@ -15,6 +15,9 @@ import (
|
||||
"github.com/kardianos/service"
|
||||
)
|
||||
|
||||
// TODO(a.garipov): Move shell templates into actual files. Either during the
|
||||
// v0.106.0 cycle using packr or during the following cycle using go:embed.
|
||||
|
||||
const (
|
||||
launchdStdoutPath = "/var/log/AdGuardHome.stdout.log"
|
||||
launchdStderrPath = "/var/log/AdGuardHome.stderr.log"
|
||||
@@ -504,6 +507,10 @@ status() {
|
||||
}
|
||||
`
|
||||
|
||||
// TODO(a.garipov): Don't use .WorkingDirectory here. There are currently no
|
||||
// guarantees that it will actually be the required directory.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/2614.
|
||||
const freeBSDScript = `#!/bin/sh
|
||||
# PROVIDE: {{.Name}}
|
||||
# REQUIRE: networking
|
||||
@@ -514,6 +521,6 @@ name="{{.Name}}"
|
||||
{{.Name}}_user="root"
|
||||
pidfile="/var/run/${name}.pid"
|
||||
command="/usr/sbin/daemon"
|
||||
command_args="-P ${pidfile} -r -f {{.WorkingDirectory}}/{{.Name}}"
|
||||
command_args="-P ${pidfile} -f -r {{.WorkingDirectory}}/{{.Name}}"
|
||||
run_rc_command "$1"
|
||||
`
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rsa"
|
||||
@@ -92,7 +93,7 @@ func (t *TLSMod) setCertFileTime() {
|
||||
t.certLastMod = fi.ModTime().UTC()
|
||||
}
|
||||
|
||||
// Start - start the module
|
||||
// Start updates the configuration of TLSMod and starts it.
|
||||
func (t *TLSMod) Start() {
|
||||
if !tlsWebHandlersRegistered {
|
||||
tlsWebHandlersRegistered = true
|
||||
@@ -102,10 +103,14 @@ func (t *TLSMod) Start() {
|
||||
t.confLock.Lock()
|
||||
tlsConf := t.conf
|
||||
t.confLock.Unlock()
|
||||
Context.web.TLSConfigChanged(tlsConf)
|
||||
|
||||
// The background context is used because the TLSConfigChanged wraps
|
||||
// context with timeout on its own and shuts down the server, which
|
||||
// handles current request.
|
||||
Context.web.TLSConfigChanged(context.Background(), tlsConf)
|
||||
}
|
||||
|
||||
// Reload - reload certificate file
|
||||
// Reload updates the configuration of TLSMod and restarts it.
|
||||
func (t *TLSMod) Reload() {
|
||||
t.confLock.Lock()
|
||||
tlsConf := t.conf
|
||||
@@ -139,7 +144,10 @@ func (t *TLSMod) Reload() {
|
||||
t.confLock.Lock()
|
||||
tlsConf = t.conf
|
||||
t.confLock.Unlock()
|
||||
Context.web.TLSConfigChanged(tlsConf)
|
||||
// The background context is used because the TLSConfigChanged wraps
|
||||
// context with timeout on its own and shuts down the server, which
|
||||
// handles current request.
|
||||
Context.web.TLSConfigChanged(context.Background(), tlsConf)
|
||||
}
|
||||
|
||||
// Set certificate and private key data
|
||||
@@ -296,11 +304,13 @@ func (t *TLSMod) handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
|
||||
f.Flush()
|
||||
}
|
||||
|
||||
// this needs to be done in a goroutine because Shutdown() is a blocking call, and it will block
|
||||
// until all requests are finished, and _we_ are inside a request right now, so it will block indefinitely
|
||||
// The background context is used because the TLSConfigChanged wraps
|
||||
// context with timeout on its own and shuts down the server, which
|
||||
// handles current request. It is also should be done in a separate
|
||||
// goroutine due to the same reason.
|
||||
if restartHTTPS {
|
||||
go func() {
|
||||
Context.web.TLSConfigChanged(data)
|
||||
Context.web.TLSConfigChanged(context.Background(), data)
|
||||
}()
|
||||
}
|
||||
}
|
||||
@@ -534,7 +544,7 @@ func marshalTLS(w http.ResponseWriter, data tlsConfig) {
|
||||
|
||||
// registerWebHandlers registers HTTP handlers for TLS configuration
|
||||
func (t *TLSMod) registerWebHandlers() {
|
||||
httpRegister("GET", "/control/tls/status", t.handleTLSStatus)
|
||||
httpRegister("POST", "/control/tls/configure", t.handleTLSConfigure)
|
||||
httpRegister("POST", "/control/tls/validate", t.handleTLSValidate)
|
||||
httpRegister(http.MethodGet, "/control/tls/status", t.handleTLSStatus)
|
||||
httpRegister(http.MethodPost, "/control/tls/configure", t.handleTLSConfigure)
|
||||
httpRegister(http.MethodPost, "/control/tls/validate", t.handleTLSValidate)
|
||||
}
|
||||
|
||||
@@ -3,183 +3,108 @@ package home
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestUpgrade1to2(t *testing.T) {
|
||||
// let's create test config for 1 schema version
|
||||
diskConfig := createTestDiskConfig(1)
|
||||
// any is a convenient alias for interface{}.
|
||||
type any = interface{}
|
||||
|
||||
// update config
|
||||
err := upgradeSchema1to2(&diskConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("Can't upgrade schema version from 1 to 2")
|
||||
}
|
||||
// object is a convenient alias for map[string]interface{}.
|
||||
type object = map[string]any
|
||||
|
||||
// ensure that schema version was bumped
|
||||
compareSchemaVersion(t, diskConfig["schema_version"], 2)
|
||||
func TestUpgradeSchema1to2(t *testing.T) {
|
||||
diskConf := testDiskConf(1)
|
||||
|
||||
// old coredns entry should be removed
|
||||
_, ok := diskConfig["coredns"]
|
||||
if ok {
|
||||
t.Fatalf("Core DNS config was not removed after upgrade schema version from 1 to 2")
|
||||
}
|
||||
err := upgradeSchema1to2(&diskConf)
|
||||
require.Nil(t, err)
|
||||
|
||||
// pull out new dns config
|
||||
dnsMap, ok := diskConfig["dns"]
|
||||
if !ok {
|
||||
t.Fatalf("No DNS config after upgrade schema version from 1 to 2")
|
||||
}
|
||||
require.Equal(t, diskConf["schema_version"], 2)
|
||||
|
||||
// cast dns configurations to maps and compare them
|
||||
oldDNSConfig := castInterfaceToMap(t, createTestDNSConfig(1))
|
||||
newDNSConfig := castInterfaceToMap(t, dnsMap)
|
||||
compareConfigs(t, &oldDNSConfig, &newDNSConfig)
|
||||
_, ok := diskConf["coredns"]
|
||||
require.False(t, ok)
|
||||
|
||||
dnsMap, ok := diskConf["dns"]
|
||||
require.True(t, ok)
|
||||
|
||||
oldDNSConf := convertToObject(t, testDNSConf(1))
|
||||
newDNSConf := convertToObject(t, dnsMap)
|
||||
assert.Equal(t, oldDNSConf, newDNSConf)
|
||||
|
||||
// exclude dns config and schema version from disk config comparison
|
||||
oldExcludedEntries := []string{"coredns", "schema_version"}
|
||||
newExcludedEntries := []string{"dns", "schema_version"}
|
||||
oldDiskConfig := createTestDiskConfig(1)
|
||||
compareConfigsWithoutEntries(t, &oldDiskConfig, &diskConfig, oldExcludedEntries, newExcludedEntries)
|
||||
oldDiskConf := testDiskConf(1)
|
||||
assertEqualExcept(t, oldDiskConf, diskConf, oldExcludedEntries, newExcludedEntries)
|
||||
}
|
||||
|
||||
func TestUpgrade2to3(t *testing.T) {
|
||||
// let's create test config
|
||||
diskConfig := createTestDiskConfig(2)
|
||||
func TestUpgradeSchema2to3(t *testing.T) {
|
||||
diskConf := testDiskConf(2)
|
||||
|
||||
// upgrade schema from 2 to 3
|
||||
err := upgradeSchema2to3(&diskConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("Can't update schema version from 2 to 3: %s", err)
|
||||
}
|
||||
err := upgradeSchema2to3(&diskConf)
|
||||
require.Nil(t, err)
|
||||
|
||||
// check new schema version
|
||||
compareSchemaVersion(t, diskConfig["schema_version"], 3)
|
||||
require.Equal(t, diskConf["schema_version"], 3)
|
||||
|
||||
// pull out new dns configuration
|
||||
dnsMap, ok := diskConfig["dns"]
|
||||
if !ok {
|
||||
t.Fatalf("No dns config in new configuration")
|
||||
}
|
||||
dnsMap, ok := diskConf["dns"]
|
||||
require.True(t, ok)
|
||||
|
||||
// cast dns configuration to map
|
||||
newDNSConfig := castInterfaceToMap(t, dnsMap)
|
||||
|
||||
// check if bootstrap DNS becomes an array
|
||||
bootstrapDNS := newDNSConfig["bootstrap_dns"]
|
||||
newDNSConf := convertToObject(t, dnsMap)
|
||||
bootstrapDNS := newDNSConf["bootstrap_dns"]
|
||||
switch v := bootstrapDNS.(type) {
|
||||
case []string:
|
||||
if len(v) != 1 {
|
||||
t.Fatalf("Wrong count of bootsrap DNS servers: %d", len(v))
|
||||
}
|
||||
|
||||
if v[0] != "8.8.8.8:53" {
|
||||
t.Fatalf("Bootsrap DNS server is not 8.8.8.8:53 : %s", v[0])
|
||||
}
|
||||
require.Len(t, v, 1)
|
||||
require.Equal(t, "8.8.8.8:53", v[0])
|
||||
default:
|
||||
t.Fatalf("Wrong type for bootsrap DNS: %T", v)
|
||||
t.Fatalf("wrong type for bootsrap dns: %T", v)
|
||||
}
|
||||
|
||||
// exclude bootstrap DNS from DNS configs comparison
|
||||
excludedEntries := []string{"bootstrap_dns"}
|
||||
oldDNSConfig := castInterfaceToMap(t, createTestDNSConfig(2))
|
||||
compareConfigsWithoutEntries(t, &oldDNSConfig, &newDNSConfig, excludedEntries, excludedEntries)
|
||||
oldDNSConf := convertToObject(t, testDNSConf(2))
|
||||
assertEqualExcept(t, oldDNSConf, newDNSConf, excludedEntries, excludedEntries)
|
||||
|
||||
// excluded dns config and schema version from disk config comparison
|
||||
excludedEntries = []string{"dns", "schema_version"}
|
||||
oldDiskConfig := createTestDiskConfig(2)
|
||||
compareConfigsWithoutEntries(t, &oldDiskConfig, &diskConfig, excludedEntries, excludedEntries)
|
||||
oldDiskConf := testDiskConf(2)
|
||||
assertEqualExcept(t, oldDiskConf, diskConf, excludedEntries, excludedEntries)
|
||||
}
|
||||
|
||||
func castInterfaceToMap(t *testing.T, oldConfig interface{}) (newConfig map[string]interface{}) {
|
||||
newConfig = make(map[string]interface{})
|
||||
switch v := oldConfig.(type) {
|
||||
case map[interface{}]interface{}:
|
||||
func convertToObject(t *testing.T, oldConf any) (newConf object) {
|
||||
t.Helper()
|
||||
|
||||
switch v := oldConf.(type) {
|
||||
case map[any]any:
|
||||
newConf = make(object, len(v))
|
||||
for key, value := range v {
|
||||
newConfig[fmt.Sprint(key)] = value
|
||||
newConf[fmt.Sprint(key)] = value
|
||||
}
|
||||
case map[string]interface{}:
|
||||
case object:
|
||||
newConf = make(object, len(v))
|
||||
for key, value := range v {
|
||||
newConfig[key] = value
|
||||
newConf[key] = value
|
||||
}
|
||||
default:
|
||||
t.Fatalf("DNS configuration is not a map")
|
||||
t.Fatalf("dns configuration is not a map, got %T", oldConf)
|
||||
}
|
||||
return
|
||||
|
||||
return newConf
|
||||
}
|
||||
|
||||
// compareConfigsWithoutEntry removes entries from configs and returns result of compareConfigs
|
||||
func compareConfigsWithoutEntries(t *testing.T, oldConfig, newConfig *map[string]interface{}, oldKey, newKey []string) {
|
||||
for _, k := range oldKey {
|
||||
delete(*oldConfig, k)
|
||||
// assertEqualExcept removes entries from configs and compares them.
|
||||
func assertEqualExcept(t *testing.T, oldConf, newConf object, oldKeys, newKeys []string) {
|
||||
t.Helper()
|
||||
|
||||
for _, k := range oldKeys {
|
||||
delete(oldConf, k)
|
||||
}
|
||||
for _, k := range newKey {
|
||||
delete(*newConfig, k)
|
||||
for _, k := range newKeys {
|
||||
delete(newConf, k)
|
||||
}
|
||||
compareConfigs(t, oldConfig, newConfig)
|
||||
|
||||
assert.Equal(t, oldConf, newConf)
|
||||
}
|
||||
|
||||
// compares configs before and after schema upgrade
|
||||
func compareConfigs(t *testing.T, oldConfig, newConfig *map[string]interface{}) {
|
||||
if len(*oldConfig) != len(*newConfig) {
|
||||
t.Fatalf("wrong config entries count! Before upgrade: %d; After upgrade: %d", len(*oldConfig), len(*oldConfig))
|
||||
}
|
||||
|
||||
// Check old and new entries
|
||||
for k, v := range *newConfig {
|
||||
switch value := v.(type) {
|
||||
case string:
|
||||
if value != (*oldConfig)[k] {
|
||||
t.Fatalf("wrong value for string %s. Before update: %s; After update: %s", k, (*oldConfig)[k], value)
|
||||
}
|
||||
case int:
|
||||
if value != (*oldConfig)[k] {
|
||||
t.Fatalf("wrong value for int %s. Before update: %d; After update: %d", k, (*oldConfig)[k], value)
|
||||
}
|
||||
case []string:
|
||||
for i, line := range value {
|
||||
if len((*oldConfig)[k].([]string)) != len(value) {
|
||||
t.Fatalf("wrong array length for %s. Before update: %d; After update: %d", k, len((*oldConfig)[k].([]string)), len(value))
|
||||
}
|
||||
if (*oldConfig)[k].([]string)[i] != line {
|
||||
t.Fatalf("wrong data for string array %s. Before update: %s; After update: %s", k, (*oldConfig)[k].([]string)[i], line)
|
||||
}
|
||||
}
|
||||
case bool:
|
||||
if v != (*oldConfig)[k].(bool) {
|
||||
t.Fatalf("wrong boolean value for %s", k)
|
||||
}
|
||||
case []filter:
|
||||
if len((*oldConfig)[k].([]filter)) != len(value) {
|
||||
t.Fatalf("wrong filters count. Before update: %d; After update: %d", len((*oldConfig)[k].([]filter)), len(value))
|
||||
}
|
||||
for i, newFilter := range value {
|
||||
oldFilter := (*oldConfig)[k].([]filter)[i]
|
||||
if oldFilter.Enabled != newFilter.Enabled || oldFilter.Name != newFilter.Name || oldFilter.RulesCount != newFilter.RulesCount {
|
||||
t.Fatalf("old filter %s not equals new filter %s", oldFilter.Name, newFilter.Name)
|
||||
}
|
||||
}
|
||||
default:
|
||||
t.Fatalf("uknown data type for %s: %T", k, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// compareSchemaVersion check if newSchemaVersion equals schemaVersion
|
||||
func compareSchemaVersion(t *testing.T, newSchemaVersion interface{}, schemaVersion int) {
|
||||
switch v := newSchemaVersion.(type) {
|
||||
case int:
|
||||
if v != schemaVersion {
|
||||
t.Fatalf("Wrong schema version in new config file")
|
||||
}
|
||||
default:
|
||||
t.Fatalf("Schema version is not an integer after update")
|
||||
}
|
||||
}
|
||||
|
||||
func createTestDiskConfig(schemaVersion int) (diskConfig map[string]interface{}) {
|
||||
diskConfig = make(map[string]interface{})
|
||||
diskConfig["language"] = "en"
|
||||
diskConfig["filters"] = []filter{
|
||||
func testDiskConf(schemaVersion int) (diskConf object) {
|
||||
filters := []filter{
|
||||
{
|
||||
URL: "https://filters.adtidy.org/android/filters/111_optimized.txt",
|
||||
Name: "Latvian filter",
|
||||
@@ -191,40 +116,51 @@ func createTestDiskConfig(schemaVersion int) (diskConfig map[string]interface{})
|
||||
RulesCount: 200,
|
||||
},
|
||||
}
|
||||
diskConfig["user_rules"] = []string{}
|
||||
diskConfig["schema_version"] = schemaVersion
|
||||
diskConfig["bind_host"] = "0.0.0.0"
|
||||
diskConfig["bind_port"] = 80
|
||||
diskConfig["auth_name"] = "name"
|
||||
diskConfig["auth_pass"] = "pass"
|
||||
dnsConfig := createTestDNSConfig(schemaVersion)
|
||||
if schemaVersion > 1 {
|
||||
diskConfig["dns"] = dnsConfig
|
||||
} else {
|
||||
diskConfig["coredns"] = dnsConfig
|
||||
diskConf = object{
|
||||
"language": "en",
|
||||
"filters": filters,
|
||||
"user_rules": []string{},
|
||||
"schema_version": schemaVersion,
|
||||
"bind_host": "0.0.0.0",
|
||||
"bind_port": 80,
|
||||
"auth_name": "name",
|
||||
"auth_pass": "pass",
|
||||
}
|
||||
return diskConfig
|
||||
|
||||
dnsConf := testDNSConf(schemaVersion)
|
||||
if schemaVersion > 1 {
|
||||
diskConf["dns"] = dnsConf
|
||||
} else {
|
||||
diskConf["coredns"] = dnsConf
|
||||
}
|
||||
|
||||
return diskConf
|
||||
}
|
||||
|
||||
func createTestDNSConfig(schemaVersion int) map[interface{}]interface{} {
|
||||
dnsConfig := make(map[interface{}]interface{})
|
||||
dnsConfig["port"] = 53
|
||||
dnsConfig["blocked_response_ttl"] = 10
|
||||
dnsConfig["querylog_enabled"] = true
|
||||
dnsConfig["ratelimit"] = 20
|
||||
dnsConfig["bootstrap_dns"] = "8.8.8.8:53"
|
||||
if schemaVersion > 2 {
|
||||
dnsConfig["bootstrap_dns"] = []string{"8.8.8.8:53"}
|
||||
// testDNSConf creates a DNS config for test the way gopkg.in/yaml.v2 would
|
||||
// unmarshal it. In YAML, keys aren't guaranteed to always only be strings.
|
||||
func testDNSConf(schemaVersion int) (dnsConf map[any]any) {
|
||||
dnsConf = map[any]any{
|
||||
"port": 53,
|
||||
"blocked_response_ttl": 10,
|
||||
"querylog_enabled": true,
|
||||
"ratelimit": 20,
|
||||
"bootstrap_dns": "8.8.8.8:53",
|
||||
"parental_sensitivity": 13,
|
||||
"ratelimit_whitelist": []string{},
|
||||
"upstream_dns": []string{"tls://1.1.1.1", "tls://1.0.0.1", "8.8.8.8"},
|
||||
"filtering_enabled": true,
|
||||
"refuse_any": true,
|
||||
"parental_enabled": true,
|
||||
"bind_host": "0.0.0.0",
|
||||
"protection_enabled": true,
|
||||
"safesearch_enabled": true,
|
||||
"safebrowsing_enabled": true,
|
||||
}
|
||||
dnsConfig["parental_sensitivity"] = 13
|
||||
dnsConfig["ratelimit_whitelist"] = []string{}
|
||||
dnsConfig["upstream_dns"] = []string{"tls://1.1.1.1", "tls://1.0.0.1", "8.8.8.8"}
|
||||
dnsConfig["filtering_enabled"] = true
|
||||
dnsConfig["refuse_any"] = true
|
||||
dnsConfig["parental_enabled"] = true
|
||||
dnsConfig["bind_host"] = "0.0.0.0"
|
||||
dnsConfig["protection_enabled"] = true
|
||||
dnsConfig["safesearch_enabled"] = true
|
||||
dnsConfig["safebrowsing_enabled"] = true
|
||||
return dnsConfig
|
||||
|
||||
if schemaVersion > 2 {
|
||||
dnsConf["bootstrap_dns"] = []string{"8.8.8.8:53"}
|
||||
}
|
||||
|
||||
return dnsConf
|
||||
}
|
||||
|
||||
@@ -16,24 +16,22 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
// ReadTimeout is the maximum duration for reading the entire request,
|
||||
// readTimeout is the maximum duration for reading the entire request,
|
||||
// including the body.
|
||||
ReadTimeout = 10 * time.Second
|
||||
|
||||
// ReadHeaderTimeout is the amount of time allowed to read request
|
||||
// headers.
|
||||
ReadHeaderTimeout = 10 * time.Second
|
||||
|
||||
// WriteTimeout is the maximum duration before timing out writes of the
|
||||
readTimeout = 60 * time.Second
|
||||
// readHdrTimeout is the amount of time allowed to read request headers.
|
||||
readHdrTimeout = 60 * time.Second
|
||||
// writeTimeout is the maximum duration before timing out writes of the
|
||||
// response.
|
||||
WriteTimeout = 10 * time.Second
|
||||
writeTimeout = 60 * time.Second
|
||||
)
|
||||
|
||||
type webConfig struct {
|
||||
firstRun bool
|
||||
BindHost string
|
||||
BindPort int
|
||||
PortHTTPS int
|
||||
firstRun bool
|
||||
BindHost net.IP
|
||||
BindPort int
|
||||
BetaBindPort int
|
||||
PortHTTPS int
|
||||
|
||||
// ReadTimeout is an option to pass to http.Server for setting an
|
||||
// appropriate field.
|
||||
@@ -62,9 +60,16 @@ type HTTPSServer struct {
|
||||
type Web struct {
|
||||
conf *webConfig
|
||||
forceHTTPS bool
|
||||
portHTTPS int
|
||||
httpServer *http.Server // HTTP module
|
||||
httpsServer HTTPSServer // HTTPS module
|
||||
|
||||
// handlerBeta is the handler for new client.
|
||||
handlerBeta http.Handler
|
||||
// installerBeta is the pre-install handler for new client.
|
||||
installerBeta http.Handler
|
||||
|
||||
// httpServerBeta is a server for new client.
|
||||
httpServerBeta *http.Server
|
||||
}
|
||||
|
||||
// CreateWeb - create module
|
||||
@@ -76,15 +81,20 @@ func CreateWeb(conf *webConfig) *Web {
|
||||
|
||||
// Initialize and run the admin Web interface
|
||||
box := packr.NewBox("../../build/static")
|
||||
boxBeta := packr.NewBox("../../build2/static")
|
||||
|
||||
// if not configured, redirect / to /install.html, otherwise redirect /install.html to /
|
||||
Context.mux.Handle("/", postInstallHandler(optionalAuthHandler(gziphandler.GzipHandler(http.FileServer(box)))))
|
||||
Context.mux.Handle("/", withMiddlewares(http.FileServer(box), gziphandler.GzipHandler, optionalAuthHandler, postInstallHandler))
|
||||
w.handlerBeta = withMiddlewares(http.FileServer(boxBeta), gziphandler.GzipHandler, optionalAuthHandler, postInstallHandler)
|
||||
|
||||
// add handlers for /install paths, we only need them when we're not configured yet
|
||||
if conf.firstRun {
|
||||
log.Info("This is the first launch of AdGuard Home, redirecting everything to /install.html ")
|
||||
Context.mux.Handle("/install.html", preInstallHandler(http.FileServer(box)))
|
||||
w.installerBeta = preInstallHandler(http.FileServer(boxBeta))
|
||||
w.registerInstallHandlers()
|
||||
// This must be removed in API v1.
|
||||
w.registerBetaInstallHandlers()
|
||||
} else {
|
||||
registerControlHandlers()
|
||||
}
|
||||
@@ -109,12 +119,12 @@ func WebCheckPortAvailable(port int) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// TLSConfigChanged - called when TLS configuration has changed
|
||||
func (web *Web) TLSConfigChanged(tlsConf tlsConfigSettings) {
|
||||
// TLSConfigChanged updates the TLS configuration and restarts the HTTPS server
|
||||
// if necessary.
|
||||
func (web *Web) TLSConfigChanged(ctx context.Context, tlsConf tlsConfigSettings) {
|
||||
log.Debug("Web: applying new TLS configuration")
|
||||
web.conf.PortHTTPS = tlsConf.PortHTTPS
|
||||
web.forceHTTPS = (tlsConf.ForceHTTPS && tlsConf.Enabled && tlsConf.PortHTTPS != 0)
|
||||
web.portHTTPS = tlsConf.PortHTTPS
|
||||
|
||||
enabled := tlsConf.Enabled &&
|
||||
tlsConf.PortHTTPS != 0 &&
|
||||
@@ -131,7 +141,12 @@ func (web *Web) TLSConfigChanged(tlsConf tlsConfigSettings) {
|
||||
|
||||
web.httpsServer.cond.L.Lock()
|
||||
if web.httpsServer.server != nil {
|
||||
_ = web.httpsServer.server.Shutdown(context.TODO())
|
||||
ctx, cancel := context.WithTimeout(ctx, shutdownTimeout)
|
||||
err = web.httpsServer.server.Shutdown(ctx)
|
||||
cancel()
|
||||
if err != nil {
|
||||
log.Debug("error while shutting down HTTP server: %s", err)
|
||||
}
|
||||
}
|
||||
web.httpsServer.enabled = enabled
|
||||
web.httpsServer.cert = cert
|
||||
@@ -147,19 +162,40 @@ func (web *Web) Start() {
|
||||
// this loop is used as an ability to change listening host and/or port
|
||||
for !web.httpsServer.shutdown {
|
||||
printHTTPAddresses("http")
|
||||
errs := make(chan error, 2)
|
||||
|
||||
hostStr := web.conf.BindHost.String()
|
||||
// we need to have new instance, because after Shutdown() the Server is not usable
|
||||
address := net.JoinHostPort(web.conf.BindHost, strconv.Itoa(web.conf.BindPort))
|
||||
web.httpServer = &http.Server{
|
||||
ErrorLog: log.StdLog("web: http", log.DEBUG),
|
||||
Addr: address,
|
||||
Addr: net.JoinHostPort(hostStr, strconv.Itoa(web.conf.BindPort)),
|
||||
Handler: withMiddlewares(Context.mux, limitRequestBody),
|
||||
ReadTimeout: web.conf.ReadTimeout,
|
||||
ReadHeaderTimeout: web.conf.ReadHeaderTimeout,
|
||||
WriteTimeout: web.conf.WriteTimeout,
|
||||
}
|
||||
go func() {
|
||||
errs <- web.httpServer.ListenAndServe()
|
||||
}()
|
||||
|
||||
err := web.httpServer.ListenAndServe()
|
||||
if web.conf.BetaBindPort != 0 {
|
||||
web.httpServerBeta = &http.Server{
|
||||
ErrorLog: log.StdLog("web: http", log.DEBUG),
|
||||
Addr: net.JoinHostPort(hostStr, strconv.Itoa(web.conf.BetaBindPort)),
|
||||
Handler: withMiddlewares(Context.mux, limitRequestBody, web.wrapIndexBeta),
|
||||
ReadTimeout: web.conf.ReadTimeout,
|
||||
ReadHeaderTimeout: web.conf.ReadHeaderTimeout,
|
||||
WriteTimeout: web.conf.WriteTimeout,
|
||||
}
|
||||
go func() {
|
||||
betaErr := web.httpServerBeta.ListenAndServe()
|
||||
if betaErr != nil {
|
||||
log.Error("starting beta http server: %s", betaErr)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
err := <-errs
|
||||
if err != http.ErrServerClosed {
|
||||
cleanupAlways()
|
||||
log.Fatal(err)
|
||||
@@ -168,19 +204,28 @@ func (web *Web) Start() {
|
||||
}
|
||||
}
|
||||
|
||||
// Close - stop HTTP server, possibly waiting for all active connections to be closed
|
||||
func (web *Web) Close() {
|
||||
// Close gracefully shuts down the HTTP servers.
|
||||
func (web *Web) Close(ctx context.Context) {
|
||||
log.Info("Stopping HTTP server...")
|
||||
web.httpsServer.cond.L.Lock()
|
||||
web.httpsServer.shutdown = true
|
||||
web.httpsServer.cond.L.Unlock()
|
||||
if web.httpsServer.server != nil {
|
||||
_ = web.httpsServer.server.Shutdown(context.TODO())
|
||||
}
|
||||
if web.httpServer != nil {
|
||||
_ = web.httpServer.Shutdown(context.TODO())
|
||||
|
||||
shut := func(srv *http.Server) {
|
||||
if srv == nil {
|
||||
return
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(ctx, shutdownTimeout)
|
||||
defer cancel()
|
||||
if err := srv.Shutdown(ctx); err != nil {
|
||||
log.Debug("error while shutting down HTTP server: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
shut(web.httpsServer.server)
|
||||
shut(web.httpServer)
|
||||
shut(web.httpServerBeta)
|
||||
|
||||
log.Info("Stopped HTTP server")
|
||||
}
|
||||
|
||||
@@ -204,7 +249,7 @@ func (web *Web) tlsServerLoop() {
|
||||
web.httpsServer.cond.L.Unlock()
|
||||
|
||||
// prepare HTTPS server
|
||||
address := net.JoinHostPort(web.conf.BindHost, strconv.Itoa(web.conf.PortHTTPS))
|
||||
address := net.JoinHostPort(web.conf.BindHost.String(), strconv.Itoa(web.conf.PortHTTPS))
|
||||
web.httpsServer.server = &http.Server{
|
||||
ErrorLog: log.StdLog("web: https", log.DEBUG),
|
||||
Addr: address,
|
||||
@@ -214,7 +259,7 @@ func (web *Web) tlsServerLoop() {
|
||||
RootCAs: Context.tlsRoots,
|
||||
CipherSuites: Context.tlsCiphers,
|
||||
},
|
||||
Handler: Context.mux,
|
||||
Handler: withMiddlewares(Context.mux, limitRequestBody),
|
||||
ReadTimeout: web.conf.ReadTimeout,
|
||||
ReadHeaderTimeout: web.conf.ReadHeaderTimeout,
|
||||
WriteTimeout: web.conf.WriteTimeout,
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/util"
|
||||
|
||||
"github.com/AdguardTeam/golibs/cache"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
@@ -25,29 +24,32 @@ const (
|
||||
|
||||
// Whois - module context
|
||||
type Whois struct {
|
||||
clients *clientsContainer
|
||||
ipChan chan string
|
||||
timeoutMsec uint
|
||||
clients *clientsContainer
|
||||
ipChan chan net.IP
|
||||
|
||||
// Contains IP addresses of clients
|
||||
// An active IP address is resolved once again after it expires.
|
||||
// If IP address couldn't be resolved, it stays here for some time to prevent further attempts to resolve the same IP.
|
||||
ipAddrs cache.Cache
|
||||
|
||||
// TODO(a.garipov): Rewrite to use time.Duration. Like, seriously, why?
|
||||
timeoutMsec uint
|
||||
}
|
||||
|
||||
// Create module context
|
||||
// initWhois creates the Whois module context.
|
||||
func initWhois(clients *clientsContainer) *Whois {
|
||||
w := Whois{}
|
||||
w.timeoutMsec = 5000
|
||||
w.clients = clients
|
||||
w := Whois{
|
||||
timeoutMsec: 5000,
|
||||
clients: clients,
|
||||
ipAddrs: cache.New(cache.Config{
|
||||
EnableLRU: true,
|
||||
MaxCount: 10000,
|
||||
}),
|
||||
ipChan: make(chan net.IP, 255),
|
||||
}
|
||||
|
||||
cconf := cache.Config{}
|
||||
cconf.EnableLRU = true
|
||||
cconf.MaxCount = 10000
|
||||
w.ipAddrs = cache.New(cconf)
|
||||
|
||||
w.ipChan = make(chan string, 255)
|
||||
go w.workerLoop()
|
||||
|
||||
return &w
|
||||
}
|
||||
|
||||
@@ -81,23 +83,16 @@ func whoisParse(data string) map[string]string {
|
||||
switch k {
|
||||
case "org-name":
|
||||
m["orgname"] = trimValue(v)
|
||||
case "orgname":
|
||||
fallthrough
|
||||
case "city":
|
||||
fallthrough
|
||||
case "country":
|
||||
case "city", "country", "orgname":
|
||||
m[k] = trimValue(v)
|
||||
|
||||
case "descr":
|
||||
if len(descr) == 0 {
|
||||
descr = v
|
||||
}
|
||||
case "netname":
|
||||
netname = v
|
||||
|
||||
case "whois": // "whois: whois.arin.net"
|
||||
m["whois"] = v
|
||||
|
||||
case "referralserver": // "ReferralServer: whois://whois.ripe.net"
|
||||
if strings.HasPrefix(v, "whois://") {
|
||||
m["whois"] = v[len("whois://"):]
|
||||
@@ -105,12 +100,16 @@ func whoisParse(data string) map[string]string {
|
||||
}
|
||||
}
|
||||
|
||||
// descr or netname -> orgname
|
||||
_, ok := m["orgname"]
|
||||
if !ok && len(descr) != 0 {
|
||||
m["orgname"] = trimValue(descr)
|
||||
} else if !ok && len(netname) != 0 {
|
||||
m["orgname"] = trimValue(netname)
|
||||
if !ok {
|
||||
// Set orgname from either descr or netname for the frontent.
|
||||
//
|
||||
// TODO(a.garipov): Perhaps don't do that in the V1 HTTP API?
|
||||
if descr != "" {
|
||||
m["orgname"] = trimValue(descr)
|
||||
} else if netname != "" {
|
||||
m["orgname"] = trimValue(netname)
|
||||
}
|
||||
}
|
||||
|
||||
return m
|
||||
@@ -120,12 +119,12 @@ func whoisParse(data string) map[string]string {
|
||||
const MaxConnReadSize = 64 * 1024
|
||||
|
||||
// Send request to a server and receive the response
|
||||
func (w *Whois) query(target, serverAddr string) (string, error) {
|
||||
func (w *Whois) query(ctx context.Context, target, serverAddr string) (string, error) {
|
||||
addr, _, _ := net.SplitHostPort(serverAddr)
|
||||
if addr == "whois.arin.net" {
|
||||
target = "n + " + target
|
||||
}
|
||||
conn, err := customDialContext(context.TODO(), "tcp", serverAddr)
|
||||
conn, err := customDialContext(ctx, "tcp", serverAddr)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -153,11 +152,11 @@ func (w *Whois) query(target, serverAddr string) (string, error) {
|
||||
}
|
||||
|
||||
// Query WHOIS servers (handle redirects)
|
||||
func (w *Whois) queryAll(target string) (string, error) {
|
||||
func (w *Whois) queryAll(ctx context.Context, target string) (string, error) {
|
||||
server := net.JoinHostPort(defaultServer, defaultPort)
|
||||
const maxRedirects = 5
|
||||
for i := 0; i != maxRedirects; i++ {
|
||||
resp, err := w.query(target, server)
|
||||
resp, err := w.query(ctx, target, server)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -183,9 +182,9 @@ func (w *Whois) queryAll(target string) (string, error) {
|
||||
}
|
||||
|
||||
// Request WHOIS information
|
||||
func (w *Whois) process(ip string) [][]string {
|
||||
func (w *Whois) process(ctx context.Context, ip net.IP) [][]string {
|
||||
data := [][]string{}
|
||||
resp, err := w.queryAll(ip)
|
||||
resp, err := w.queryAll(ctx, ip.String())
|
||||
if err != nil {
|
||||
log.Debug("Whois: error: %s IP:%s", err, ip)
|
||||
return data
|
||||
@@ -209,7 +208,7 @@ func (w *Whois) process(ip string) [][]string {
|
||||
}
|
||||
|
||||
// Begin - begin requesting WHOIS info
|
||||
func (w *Whois) Begin(ip string) {
|
||||
func (w *Whois) Begin(ip net.IP) {
|
||||
now := uint64(time.Now().Unix())
|
||||
expire := w.ipAddrs.Get([]byte(ip))
|
||||
if len(expire) != 0 {
|
||||
@@ -232,16 +231,18 @@ func (w *Whois) Begin(ip string) {
|
||||
}
|
||||
}
|
||||
|
||||
// Get IP address from channel; get WHOIS info; associate info with a client
|
||||
// workerLoop processes the IP addresses it got from the channel and associates
|
||||
// the retrieving WHOIS info with a client.
|
||||
func (w *Whois) workerLoop() {
|
||||
for {
|
||||
ip := <-w.ipChan
|
||||
|
||||
info := w.process(ip)
|
||||
info := w.process(context.Background(), ip)
|
||||
if len(info) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
w.clients.SetWhoisInfo(ip, info)
|
||||
id := ip.String()
|
||||
w.clients.SetWhoisInfo(id, info)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
@@ -12,14 +13,19 @@ func prepareTestDNSServer() error {
|
||||
Context.dnsServer = dnsforward.NewServer(dnsforward.DNSCreateParams{})
|
||||
conf := &dnsforward.ServerConfig{}
|
||||
conf.UpstreamDNS = []string{"8.8.8.8"}
|
||||
|
||||
return Context.dnsServer.Prepare(conf)
|
||||
}
|
||||
|
||||
// TODO(e.burkov): It's kind of complicated to get rid of network access in this
|
||||
// test. The thing is that *Whois creates new *net.Dialer each time it requests
|
||||
// the server, so it becomes hard to simulate handling of request from test even
|
||||
// with substituted upstream. However, it must be done.
|
||||
func TestWhois(t *testing.T) {
|
||||
assert.Nil(t, prepareTestDNSServer())
|
||||
|
||||
w := Whois{timeoutMsec: 5000}
|
||||
resp, err := w.queryAll("8.8.8.8")
|
||||
resp, err := w.queryAll(context.Background(), "8.8.8.8")
|
||||
assert.Nil(t, err)
|
||||
m := whoisParse(resp)
|
||||
assert.Equal(t, "Google LLC", m["orgname"])
|
||||
|
||||
@@ -17,14 +17,26 @@ import (
|
||||
type logEntryHandler (func(t json.Token, ent *logEntry) error)
|
||||
|
||||
var logEntryHandlers = map[string]logEntryHandler{
|
||||
"CID": func(t json.Token, ent *logEntry) error {
|
||||
v, ok := t.(string)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
ent.ClientID = v
|
||||
|
||||
return nil
|
||||
},
|
||||
"IP": func(t json.Token, ent *logEntry) error {
|
||||
v, ok := t.(string)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
if len(ent.IP) == 0 {
|
||||
ent.IP = v
|
||||
|
||||
if ent.IP == nil {
|
||||
ent.IP = net.ParseIP(v)
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
"T": func(t json.Token, ent *logEntry) error {
|
||||
|
||||
@@ -8,8 +8,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsfilter"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/testutil"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
"github.com/miekg/dns"
|
||||
@@ -19,12 +19,13 @@ import (
|
||||
func TestDecodeLogEntry(t *testing.T) {
|
||||
logOutput := &bytes.Buffer{}
|
||||
|
||||
testutil.ReplaceLogWriter(t, logOutput)
|
||||
testutil.ReplaceLogLevel(t, log.DEBUG)
|
||||
aghtest.ReplaceLogWriter(t, logOutput)
|
||||
aghtest.ReplaceLogLevel(t, log.DEBUG)
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
const ansStr = `Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==`
|
||||
const data = `{"IP":"127.0.0.1",` +
|
||||
`"CID":"cli42",` +
|
||||
`"T":"2020-11-25T18:55:56.519796+03:00",` +
|
||||
`"QH":"an.yandex.ru",` +
|
||||
`"QT":"A",` +
|
||||
@@ -47,11 +48,12 @@ func TestDecodeLogEntry(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
|
||||
want := &logEntry{
|
||||
IP: "127.0.0.1",
|
||||
IP: net.IPv4(127, 0, 0, 1),
|
||||
Time: time.Date(2020, 11, 25, 15, 55, 56, 519796000, time.UTC),
|
||||
QHost: "an.yandex.ru",
|
||||
QType: "A",
|
||||
QClass: "IN",
|
||||
ClientID: "cli42",
|
||||
ClientProto: "",
|
||||
Answer: ans,
|
||||
Result: dnsfilter.Result{
|
||||
@@ -84,7 +86,7 @@ func TestDecodeLogEntry(t *testing.T) {
|
||||
decodeLogEntry(got, data)
|
||||
|
||||
s := logOutput.String()
|
||||
assert.Equal(t, "", s)
|
||||
assert.Empty(t, s)
|
||||
|
||||
// Correct for time zones.
|
||||
got.Time = got.Time.UTC()
|
||||
@@ -172,7 +174,7 @@ func TestDecodeLogEntry(t *testing.T) {
|
||||
|
||||
s := logOutput.String()
|
||||
if tc.want == "" {
|
||||
assert.Equal(t, "", s)
|
||||
assert.Empty(t, s)
|
||||
} else {
|
||||
assert.True(t, strings.HasSuffix(s, tc.want),
|
||||
"got %q", s)
|
||||
|
||||
@@ -22,10 +22,10 @@ type qlogConfig struct {
|
||||
|
||||
// Register web handlers
|
||||
func (l *queryLog) initWeb() {
|
||||
l.conf.HTTPRegister("GET", "/control/querylog", l.handleQueryLog)
|
||||
l.conf.HTTPRegister("GET", "/control/querylog_info", l.handleQueryLogInfo)
|
||||
l.conf.HTTPRegister("POST", "/control/querylog_clear", l.handleQueryLogClear)
|
||||
l.conf.HTTPRegister("POST", "/control/querylog_config", l.handleQueryLogConfig)
|
||||
l.conf.HTTPRegister(http.MethodGet, "/control/querylog", l.handleQueryLog)
|
||||
l.conf.HTTPRegister(http.MethodGet, "/control/querylog_info", l.handleQueryLogInfo)
|
||||
l.conf.HTTPRegister(http.MethodPost, "/control/querylog_clear", l.handleQueryLogClear)
|
||||
l.conf.HTTPRegister(http.MethodPost, "/control/querylog_config", l.handleQueryLogConfig)
|
||||
}
|
||||
|
||||
func httpError(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) {
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsfilter"
|
||||
@@ -14,22 +15,19 @@ import (
|
||||
// TODO(a.garipov): Use a proper structured approach here.
|
||||
|
||||
// Get Client IP address
|
||||
func (l *queryLog) getClientIP(clientIP string) string {
|
||||
if l.conf.AnonymizeClientIP {
|
||||
ip := net.ParseIP(clientIP)
|
||||
if ip != nil {
|
||||
ip4 := ip.To4()
|
||||
const AnonymizeClientIP4Mask = 16
|
||||
const AnonymizeClientIP6Mask = 112
|
||||
if ip4 != nil {
|
||||
clientIP = ip4.Mask(net.CIDRMask(AnonymizeClientIP4Mask, 32)).String()
|
||||
} else {
|
||||
clientIP = ip.Mask(net.CIDRMask(AnonymizeClientIP6Mask, 128)).String()
|
||||
}
|
||||
func (l *queryLog) getClientIP(ip net.IP) (clientIP net.IP) {
|
||||
if l.conf.AnonymizeClientIP && ip != nil {
|
||||
const AnonymizeClientIPv4Mask = 16
|
||||
const AnonymizeClientIPv6Mask = 112
|
||||
|
||||
if ip.To4() != nil {
|
||||
return ip.Mask(net.CIDRMask(AnonymizeClientIPv4Mask, 32))
|
||||
}
|
||||
|
||||
return ip.Mask(net.CIDRMask(AnonymizeClientIPv6Mask, 128))
|
||||
}
|
||||
|
||||
return clientIP
|
||||
return ip
|
||||
}
|
||||
|
||||
// jobject is a JSON object alias.
|
||||
@@ -82,6 +80,10 @@ func (l *queryLog) logEntryToJSONEntry(entry *logEntry) (jsonEntry jobject) {
|
||||
},
|
||||
}
|
||||
|
||||
if entry.ClientID != "" {
|
||||
jsonEntry["client_id"] = entry.ClientID
|
||||
}
|
||||
|
||||
if msg != nil {
|
||||
jsonEntry["status"] = dns.RcodeToString[msg.Rcode]
|
||||
|
||||
@@ -138,48 +140,60 @@ func resultRulesToJSONRules(rules []*dnsfilter.ResultRule) (jsonRules []jobject)
|
||||
return jsonRules
|
||||
}
|
||||
|
||||
func answerToMap(a *dns.Msg) (answers []jobject) {
|
||||
type dnsAnswer struct {
|
||||
Type string `json:"type"`
|
||||
Value string `json:"value"`
|
||||
TTL uint32 `json:"ttl"`
|
||||
}
|
||||
|
||||
func answerToMap(a *dns.Msg) (answers []*dnsAnswer) {
|
||||
if a == nil || len(a.Answer) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
answers = []jobject{}
|
||||
answers = make([]*dnsAnswer, 0, len(a.Answer))
|
||||
for _, k := range a.Answer {
|
||||
header := k.Header()
|
||||
answer := jobject{
|
||||
"type": dns.TypeToString[header.Rrtype],
|
||||
"ttl": header.Ttl,
|
||||
answer := &dnsAnswer{
|
||||
Type: dns.TypeToString[header.Rrtype],
|
||||
TTL: header.Ttl,
|
||||
}
|
||||
// try most common record types
|
||||
|
||||
// Some special treatment for some well-known types.
|
||||
//
|
||||
// TODO(a.garipov): Consider just calling String() for everyone
|
||||
// instead.
|
||||
switch v := k.(type) {
|
||||
case nil:
|
||||
// Probably unlikely, but go on.
|
||||
case *dns.A:
|
||||
answer["value"] = v.A.String()
|
||||
answer.Value = v.A.String()
|
||||
case *dns.AAAA:
|
||||
answer["value"] = v.AAAA.String()
|
||||
answer.Value = v.AAAA.String()
|
||||
case *dns.MX:
|
||||
answer["value"] = fmt.Sprintf("%v %v", v.Preference, v.Mx)
|
||||
answer.Value = fmt.Sprintf("%v %v", v.Preference, v.Mx)
|
||||
case *dns.CNAME:
|
||||
answer["value"] = v.Target
|
||||
answer.Value = v.Target
|
||||
case *dns.NS:
|
||||
answer["value"] = v.Ns
|
||||
answer.Value = v.Ns
|
||||
case *dns.SPF:
|
||||
answer["value"] = v.Txt
|
||||
answer.Value = strings.Join(v.Txt, "\n")
|
||||
case *dns.TXT:
|
||||
answer["value"] = v.Txt
|
||||
answer.Value = strings.Join(v.Txt, "\n")
|
||||
case *dns.PTR:
|
||||
answer["value"] = v.Ptr
|
||||
answer.Value = v.Ptr
|
||||
case *dns.SOA:
|
||||
answer["value"] = fmt.Sprintf("%v %v %v %v %v %v %v", v.Ns, v.Mbox, v.Serial, v.Refresh, v.Retry, v.Expire, v.Minttl)
|
||||
answer.Value = fmt.Sprintf("%v %v %v %v %v %v %v", v.Ns, v.Mbox, v.Serial, v.Refresh, v.Retry, v.Expire, v.Minttl)
|
||||
case *dns.CAA:
|
||||
answer["value"] = fmt.Sprintf("%v %v \"%v\"", v.Flag, v.Tag, v.Value)
|
||||
answer.Value = fmt.Sprintf("%v %v \"%v\"", v.Flag, v.Tag, v.Value)
|
||||
case *dns.HINFO:
|
||||
answer["value"] = fmt.Sprintf("\"%v\" \"%v\"", v.Cpu, v.Os)
|
||||
answer.Value = fmt.Sprintf("\"%v\" \"%v\"", v.Cpu, v.Os)
|
||||
case *dns.RRSIG:
|
||||
answer["value"] = fmt.Sprintf("%v %v %v %v %v %v %v %v %v", dns.TypeToString[v.TypeCovered], v.Algorithm, v.Labels, v.OrigTtl, v.Expiration, v.Inception, v.KeyTag, v.SignerName, v.Signature)
|
||||
answer.Value = fmt.Sprintf("%v %v %v %v %v %v %v %v %v", dns.TypeToString[v.TypeCovered], v.Algorithm, v.Labels, v.OrigTtl, v.Expiration, v.Inception, v.KeyTag, v.SignerName, v.Signature)
|
||||
default:
|
||||
// type unknown, marshall it as-is
|
||||
answer["value"] = v
|
||||
answer.Value = v.String()
|
||||
}
|
||||
|
||||
answers = append(answers, answer)
|
||||
}
|
||||
|
||||
|
||||
@@ -2,7 +2,9 @@
|
||||
package querylog
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -36,10 +38,11 @@ type ClientProto string
|
||||
|
||||
// Client protocol names.
|
||||
const (
|
||||
ClientProtoDOH ClientProto = "doh"
|
||||
ClientProtoDOQ ClientProto = "doq"
|
||||
ClientProtoDOT ClientProto = "dot"
|
||||
ClientProtoPlain ClientProto = ""
|
||||
ClientProtoDOH ClientProto = "doh"
|
||||
ClientProtoDOQ ClientProto = "doq"
|
||||
ClientProtoDOT ClientProto = "dot"
|
||||
ClientProtoDNSCrypt ClientProto = "dnscrypt"
|
||||
ClientProtoPlain ClientProto = ""
|
||||
)
|
||||
|
||||
// NewClientProto validates that the client protocol name is valid and returns
|
||||
@@ -50,6 +53,7 @@ func NewClientProto(s string) (cp ClientProto, err error) {
|
||||
ClientProtoDOH,
|
||||
ClientProtoDOQ,
|
||||
ClientProtoDOT,
|
||||
ClientProtoDNSCrypt,
|
||||
ClientProtoPlain:
|
||||
|
||||
return cp, nil
|
||||
@@ -60,13 +64,14 @@ func NewClientProto(s string) (cp ClientProto, err error) {
|
||||
|
||||
// logEntry - represents a single log entry
|
||||
type logEntry struct {
|
||||
IP string `json:"IP"` // Client IP
|
||||
IP net.IP `json:"IP"` // Client IP
|
||||
Time time.Time `json:"T"`
|
||||
|
||||
QHost string `json:"QH"`
|
||||
QType string `json:"QT"`
|
||||
QClass string `json:"QC"`
|
||||
|
||||
ClientID string `json:"CID,omitempty"`
|
||||
ClientProto ClientProto `json:"CP"`
|
||||
|
||||
Answer []byte `json:",omitempty"` // sometimes empty answers happen like binerdunt.top or rev2.globalrootservers.net
|
||||
@@ -118,14 +123,15 @@ func (l *queryLog) clear() {
|
||||
l.flushPending = false
|
||||
l.bufferLock.Unlock()
|
||||
|
||||
err := os.Remove(l.logFile + ".1")
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
log.Error("file remove: %s: %s", l.logFile+".1", err)
|
||||
oldLogFile := l.logFile + ".1"
|
||||
err := os.Remove(oldLogFile)
|
||||
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
log.Error("removing old log file %q: %s", oldLogFile, err)
|
||||
}
|
||||
|
||||
err = os.Remove(l.logFile)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
log.Error("file remove: %s: %s", l.logFile, err)
|
||||
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
log.Error("removing log file %q: %s", l.logFile, err)
|
||||
}
|
||||
|
||||
log.Debug("Query log: cleared")
|
||||
@@ -147,12 +153,13 @@ func (l *queryLog) Add(params AddParams) {
|
||||
|
||||
now := time.Now()
|
||||
entry := logEntry{
|
||||
IP: l.getClientIP(params.ClientIP.String()),
|
||||
IP: l.getClientIP(params.ClientIP),
|
||||
Time: now,
|
||||
|
||||
Result: *params.Result,
|
||||
Elapsed: params.Elapsed,
|
||||
Upstream: params.Upstream,
|
||||
ClientID: params.ClientID,
|
||||
ClientProto: params.ClientProto,
|
||||
}
|
||||
q := params.Question.Question[0]
|
||||
|
||||
@@ -1,242 +1,276 @@
|
||||
package querylog
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"os"
|
||||
"sort"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/dnsproxy/proxyutil"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsfilter"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/testutil"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
testutil.DiscardLogOutput(m)
|
||||
aghtest.DiscardLogOutput(m)
|
||||
}
|
||||
|
||||
func prepareTestDir() string {
|
||||
const dir = "./agh-test"
|
||||
_ = os.RemoveAll(dir)
|
||||
_ = os.MkdirAll(dir, 0o755)
|
||||
return dir
|
||||
}
|
||||
|
||||
// Check adding and loading (with filtering) entries from disk and memory
|
||||
// TestQueryLog tests adding and loading (with filtering) entries from disk and
|
||||
// memory.
|
||||
func TestQueryLog(t *testing.T) {
|
||||
conf := Config{
|
||||
l := newQueryLog(Config{
|
||||
Enabled: true,
|
||||
FileEnabled: true,
|
||||
Interval: 1,
|
||||
MemSize: 100,
|
||||
BaseDir: aghtest.PrepareTestDir(t),
|
||||
})
|
||||
|
||||
// Add disk entries.
|
||||
addEntry(l, "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
|
||||
// Write to disk (first file).
|
||||
require.Nil(t, l.flushLogBuffer(true))
|
||||
// Start writing to the second file.
|
||||
require.Nil(t, l.rotate())
|
||||
// Add disk entries.
|
||||
addEntry(l, "example.org", net.IPv4(1, 1, 1, 2), net.IPv4(2, 2, 2, 2))
|
||||
// Write to disk.
|
||||
require.Nil(t, l.flushLogBuffer(true))
|
||||
// Add memory entries.
|
||||
addEntry(l, "test.example.org", net.IPv4(1, 1, 1, 3), net.IPv4(2, 2, 2, 3))
|
||||
addEntry(l, "example.com", net.IPv4(1, 1, 1, 4), net.IPv4(2, 2, 2, 4))
|
||||
|
||||
type tcAssertion struct {
|
||||
num int
|
||||
host string
|
||||
answer, client net.IP
|
||||
}
|
||||
conf.BaseDir = prepareTestDir()
|
||||
defer func() { _ = os.RemoveAll(conf.BaseDir) }()
|
||||
l := newQueryLog(conf)
|
||||
|
||||
// add disk entries
|
||||
addEntry(l, "example.org", "1.1.1.1", "2.2.2.1")
|
||||
// write to disk (first file)
|
||||
_ = l.flushLogBuffer(true)
|
||||
// start writing to the second file
|
||||
_ = l.rotate()
|
||||
// add disk entries
|
||||
addEntry(l, "example.org", "1.1.1.2", "2.2.2.2")
|
||||
// write to disk
|
||||
_ = l.flushLogBuffer(true)
|
||||
// add memory entries
|
||||
addEntry(l, "test.example.org", "1.1.1.3", "2.2.2.3")
|
||||
addEntry(l, "example.com", "1.1.1.4", "2.2.2.4")
|
||||
testCases := []struct {
|
||||
name string
|
||||
sCr []searchCriteria
|
||||
want []tcAssertion
|
||||
}{{
|
||||
name: "all",
|
||||
sCr: []searchCriteria{},
|
||||
want: []tcAssertion{
|
||||
{num: 0, host: "example.com", answer: net.IPv4(1, 1, 1, 4), client: net.IPv4(2, 2, 2, 4)},
|
||||
{num: 1, host: "test.example.org", answer: net.IPv4(1, 1, 1, 3), client: net.IPv4(2, 2, 2, 3)},
|
||||
{num: 2, host: "example.org", answer: net.IPv4(1, 1, 1, 2), client: net.IPv4(2, 2, 2, 2)},
|
||||
{num: 3, host: "example.org", answer: net.IPv4(1, 1, 1, 1), client: net.IPv4(2, 2, 2, 1)},
|
||||
},
|
||||
}, {
|
||||
name: "by_domain_strict",
|
||||
sCr: []searchCriteria{{
|
||||
criteriaType: ctDomainOrClient,
|
||||
strict: true,
|
||||
value: "TEST.example.org",
|
||||
}},
|
||||
want: []tcAssertion{{
|
||||
num: 0, host: "test.example.org", answer: net.IPv4(1, 1, 1, 3), client: net.IPv4(2, 2, 2, 3),
|
||||
}},
|
||||
}, {
|
||||
name: "by_domain_non-strict",
|
||||
sCr: []searchCriteria{{
|
||||
criteriaType: ctDomainOrClient,
|
||||
strict: false,
|
||||
value: "example.ORG",
|
||||
}},
|
||||
want: []tcAssertion{
|
||||
{num: 0, host: "test.example.org", answer: net.IPv4(1, 1, 1, 3), client: net.IPv4(2, 2, 2, 3)},
|
||||
{num: 1, host: "example.org", answer: net.IPv4(1, 1, 1, 2), client: net.IPv4(2, 2, 2, 2)},
|
||||
{num: 2, host: "example.org", answer: net.IPv4(1, 1, 1, 1), client: net.IPv4(2, 2, 2, 1)},
|
||||
},
|
||||
}, {
|
||||
name: "by_client_ip_strict",
|
||||
sCr: []searchCriteria{{
|
||||
criteriaType: ctDomainOrClient,
|
||||
strict: true,
|
||||
value: "2.2.2.2",
|
||||
}},
|
||||
want: []tcAssertion{{
|
||||
num: 0, host: "example.org", answer: net.IPv4(1, 1, 1, 2), client: net.IPv4(2, 2, 2, 2),
|
||||
}},
|
||||
}, {
|
||||
name: "by_client_ip_non-strict",
|
||||
sCr: []searchCriteria{{
|
||||
criteriaType: ctDomainOrClient,
|
||||
strict: false,
|
||||
value: "2.2.2",
|
||||
}},
|
||||
want: []tcAssertion{
|
||||
{num: 0, host: "example.com", answer: net.IPv4(1, 1, 1, 4), client: net.IPv4(2, 2, 2, 4)},
|
||||
{num: 1, host: "test.example.org", answer: net.IPv4(1, 1, 1, 3), client: net.IPv4(2, 2, 2, 3)},
|
||||
{num: 2, host: "example.org", answer: net.IPv4(1, 1, 1, 2), client: net.IPv4(2, 2, 2, 2)},
|
||||
{num: 3, host: "example.org", answer: net.IPv4(1, 1, 1, 1), client: net.IPv4(2, 2, 2, 1)},
|
||||
},
|
||||
}}
|
||||
|
||||
// get all entries
|
||||
params := newSearchParams()
|
||||
entries, _ := l.search(params)
|
||||
assert.Equal(t, 4, len(entries))
|
||||
assertLogEntry(t, entries[0], "example.com", "1.1.1.4", "2.2.2.4")
|
||||
assertLogEntry(t, entries[1], "test.example.org", "1.1.1.3", "2.2.2.3")
|
||||
assertLogEntry(t, entries[2], "example.org", "1.1.1.2", "2.2.2.2")
|
||||
assertLogEntry(t, entries[3], "example.org", "1.1.1.1", "2.2.2.1")
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
params := newSearchParams()
|
||||
params.searchCriteria = tc.sCr
|
||||
|
||||
// search by domain (strict)
|
||||
params = newSearchParams()
|
||||
params.searchCriteria = append(params.searchCriteria, searchCriteria{
|
||||
criteriaType: ctDomainOrClient,
|
||||
strict: true,
|
||||
value: "TEST.example.org",
|
||||
})
|
||||
entries, _ = l.search(params)
|
||||
assert.Equal(t, 1, len(entries))
|
||||
assertLogEntry(t, entries[0], "test.example.org", "1.1.1.3", "2.2.2.3")
|
||||
|
||||
// search by domain (not strict)
|
||||
params = newSearchParams()
|
||||
params.searchCriteria = append(params.searchCriteria, searchCriteria{
|
||||
criteriaType: ctDomainOrClient,
|
||||
strict: false,
|
||||
value: "example.ORG",
|
||||
})
|
||||
entries, _ = l.search(params)
|
||||
assert.Equal(t, 3, len(entries))
|
||||
assertLogEntry(t, entries[0], "test.example.org", "1.1.1.3", "2.2.2.3")
|
||||
assertLogEntry(t, entries[1], "example.org", "1.1.1.2", "2.2.2.2")
|
||||
assertLogEntry(t, entries[2], "example.org", "1.1.1.1", "2.2.2.1")
|
||||
|
||||
// search by client IP (strict)
|
||||
params = newSearchParams()
|
||||
params.searchCriteria = append(params.searchCriteria, searchCriteria{
|
||||
criteriaType: ctDomainOrClient,
|
||||
strict: true,
|
||||
value: "2.2.2.2",
|
||||
})
|
||||
entries, _ = l.search(params)
|
||||
assert.Equal(t, 1, len(entries))
|
||||
assertLogEntry(t, entries[0], "example.org", "1.1.1.2", "2.2.2.2")
|
||||
|
||||
// search by client IP (part of)
|
||||
params = newSearchParams()
|
||||
params.searchCriteria = append(params.searchCriteria, searchCriteria{
|
||||
criteriaType: ctDomainOrClient,
|
||||
strict: false,
|
||||
value: "2.2.2",
|
||||
})
|
||||
entries, _ = l.search(params)
|
||||
assert.Equal(t, 4, len(entries))
|
||||
assertLogEntry(t, entries[0], "example.com", "1.1.1.4", "2.2.2.4")
|
||||
assertLogEntry(t, entries[1], "test.example.org", "1.1.1.3", "2.2.2.3")
|
||||
assertLogEntry(t, entries[2], "example.org", "1.1.1.2", "2.2.2.2")
|
||||
assertLogEntry(t, entries[3], "example.org", "1.1.1.1", "2.2.2.1")
|
||||
entries, _ := l.search(params)
|
||||
require.Len(t, entries, len(tc.want))
|
||||
for _, want := range tc.want {
|
||||
assertLogEntry(t, entries[want.num], want.host, want.answer, want.client)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryLogOffsetLimit(t *testing.T) {
|
||||
conf := Config{
|
||||
l := newQueryLog(Config{
|
||||
Enabled: true,
|
||||
Interval: 1,
|
||||
MemSize: 100,
|
||||
}
|
||||
conf.BaseDir = prepareTestDir()
|
||||
defer func() { _ = os.RemoveAll(conf.BaseDir) }()
|
||||
l := newQueryLog(conf)
|
||||
BaseDir: aghtest.PrepareTestDir(t),
|
||||
})
|
||||
|
||||
// add 10 entries to the log
|
||||
for i := 0; i < 10; i++ {
|
||||
addEntry(l, "second.example.org", "1.1.1.1", "2.2.2.1")
|
||||
const (
|
||||
entNum = 10
|
||||
firstPageDomain = "first.example.org"
|
||||
secondPageDomain = "second.example.org"
|
||||
)
|
||||
// Add entries to the log.
|
||||
for i := 0; i < entNum; i++ {
|
||||
addEntry(l, secondPageDomain, net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
|
||||
}
|
||||
// write them to disk (first file)
|
||||
_ = l.flushLogBuffer(true)
|
||||
// add 10 more entries to the log (memory)
|
||||
for i := 0; i < 10; i++ {
|
||||
addEntry(l, "first.example.org", "1.1.1.1", "2.2.2.1")
|
||||
// Write them to the first file.
|
||||
require.Nil(t, l.flushLogBuffer(true))
|
||||
// Add more to the in-memory part of log.
|
||||
for i := 0; i < entNum; i++ {
|
||||
addEntry(l, firstPageDomain, net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
|
||||
}
|
||||
|
||||
// First page
|
||||
params := newSearchParams()
|
||||
params.offset = 0
|
||||
params.limit = 10
|
||||
entries, _ := l.search(params)
|
||||
assert.Equal(t, 10, len(entries))
|
||||
assert.Equal(t, entries[0].QHost, "first.example.org")
|
||||
assert.Equal(t, entries[9].QHost, "first.example.org")
|
||||
|
||||
// Second page
|
||||
params.offset = 10
|
||||
params.limit = 10
|
||||
entries, _ = l.search(params)
|
||||
assert.Equal(t, 10, len(entries))
|
||||
assert.Equal(t, entries[0].QHost, "second.example.org")
|
||||
assert.Equal(t, entries[9].QHost, "second.example.org")
|
||||
testCases := []struct {
|
||||
name string
|
||||
offset int
|
||||
limit int
|
||||
wantLen int
|
||||
want string
|
||||
}{{
|
||||
name: "page_1",
|
||||
offset: 0,
|
||||
limit: 10,
|
||||
wantLen: 10,
|
||||
want: firstPageDomain,
|
||||
}, {
|
||||
name: "page_2",
|
||||
offset: 10,
|
||||
limit: 10,
|
||||
wantLen: 10,
|
||||
want: secondPageDomain,
|
||||
}, {
|
||||
name: "page_2.5",
|
||||
offset: 15,
|
||||
limit: 10,
|
||||
wantLen: 5,
|
||||
want: secondPageDomain,
|
||||
}, {
|
||||
name: "page_3",
|
||||
offset: 20,
|
||||
limit: 10,
|
||||
wantLen: 0,
|
||||
}}
|
||||
|
||||
// Second and a half page
|
||||
params.offset = 15
|
||||
params.limit = 10
|
||||
entries, _ = l.search(params)
|
||||
assert.Equal(t, 5, len(entries))
|
||||
assert.Equal(t, entries[0].QHost, "second.example.org")
|
||||
assert.Equal(t, entries[4].QHost, "second.example.org")
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
params.offset = tc.offset
|
||||
params.limit = tc.limit
|
||||
entries, _ := l.search(params)
|
||||
|
||||
// Third page
|
||||
params.offset = 20
|
||||
params.limit = 10
|
||||
entries, _ = l.search(params)
|
||||
assert.Equal(t, 0, len(entries))
|
||||
require.Len(t, entries, tc.wantLen)
|
||||
|
||||
if tc.wantLen > 0 {
|
||||
assert.Equal(t, entries[0].QHost, tc.want)
|
||||
assert.Equal(t, entries[tc.wantLen-1].QHost, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryLogMaxFileScanEntries(t *testing.T) {
|
||||
conf := Config{
|
||||
l := newQueryLog(Config{
|
||||
Enabled: true,
|
||||
FileEnabled: true,
|
||||
Interval: 1,
|
||||
MemSize: 100,
|
||||
}
|
||||
conf.BaseDir = prepareTestDir()
|
||||
defer func() { _ = os.RemoveAll(conf.BaseDir) }()
|
||||
l := newQueryLog(conf)
|
||||
BaseDir: aghtest.PrepareTestDir(t),
|
||||
})
|
||||
|
||||
// add 10 entries to the log
|
||||
for i := 0; i < 10; i++ {
|
||||
addEntry(l, "example.org", "1.1.1.1", "2.2.2.1")
|
||||
const entNum = 10
|
||||
// Add entries to the log.
|
||||
for i := 0; i < entNum; i++ {
|
||||
addEntry(l, "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
|
||||
}
|
||||
// write them to disk (first file)
|
||||
_ = l.flushLogBuffer(true)
|
||||
// Write them to disk.
|
||||
require.Nil(t, l.flushLogBuffer(true))
|
||||
|
||||
params := newSearchParams()
|
||||
params.maxFileScanEntries = 5 // do not scan more than 5 records
|
||||
entries, _ := l.search(params)
|
||||
assert.Equal(t, 5, len(entries))
|
||||
|
||||
params.maxFileScanEntries = 0 // disable the limit
|
||||
entries, _ = l.search(params)
|
||||
assert.Equal(t, 10, len(entries))
|
||||
for _, maxFileScanEntries := range []int{5, 0} {
|
||||
t.Run(fmt.Sprintf("limit_%d", maxFileScanEntries), func(t *testing.T) {
|
||||
params.maxFileScanEntries = maxFileScanEntries
|
||||
entries, _ := l.search(params)
|
||||
assert.Len(t, entries, entNum-maxFileScanEntries)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryLogFileDisabled(t *testing.T) {
|
||||
conf := Config{
|
||||
l := newQueryLog(Config{
|
||||
Enabled: true,
|
||||
FileEnabled: false,
|
||||
Interval: 1,
|
||||
MemSize: 2,
|
||||
}
|
||||
conf.BaseDir = prepareTestDir()
|
||||
defer func() { _ = os.RemoveAll(conf.BaseDir) }()
|
||||
l := newQueryLog(conf)
|
||||
BaseDir: aghtest.PrepareTestDir(t),
|
||||
})
|
||||
|
||||
addEntry(l, "example1.org", "1.1.1.1", "2.2.2.1")
|
||||
addEntry(l, "example2.org", "1.1.1.1", "2.2.2.1")
|
||||
addEntry(l, "example3.org", "1.1.1.1", "2.2.2.1")
|
||||
// the oldest entry is now removed from mem buffer
|
||||
addEntry(l, "example1.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
|
||||
addEntry(l, "example2.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
|
||||
// The oldest entry is going to be removed from memory buffer.
|
||||
addEntry(l, "example3.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
|
||||
|
||||
params := newSearchParams()
|
||||
ll, _ := l.search(params)
|
||||
assert.Equal(t, 2, len(ll))
|
||||
require.Len(t, ll, 2)
|
||||
assert.Equal(t, "example3.org", ll[0].QHost)
|
||||
assert.Equal(t, "example2.org", ll[1].QHost)
|
||||
}
|
||||
|
||||
func addEntry(l *queryLog, host, answerStr, client string) {
|
||||
q := dns.Msg{}
|
||||
q.Question = append(q.Question, dns.Question{
|
||||
Name: host + ".",
|
||||
Qtype: dns.TypeA,
|
||||
Qclass: dns.ClassINET,
|
||||
})
|
||||
|
||||
a := dns.Msg{}
|
||||
a.Question = append(a.Question, q.Question[0])
|
||||
answer := new(dns.A)
|
||||
answer.Hdr = dns.RR_Header{
|
||||
Name: q.Question[0].Name,
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
func addEntry(l *queryLog, host string, answerStr, client net.IP) {
|
||||
q := dns.Msg{
|
||||
Question: []dns.Question{{
|
||||
Name: host + ".",
|
||||
Qtype: dns.TypeA,
|
||||
Qclass: dns.ClassINET,
|
||||
}},
|
||||
}
|
||||
|
||||
a := dns.Msg{
|
||||
Question: q.Question,
|
||||
Answer: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: q.Question[0].Name,
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
},
|
||||
A: answerStr,
|
||||
}},
|
||||
}
|
||||
answer.A = net.ParseIP(answerStr)
|
||||
a.Answer = append(a.Answer, answer)
|
||||
res := dnsfilter.Result{
|
||||
IsFiltered: true,
|
||||
Reason: dnsfilter.ReasonRewrite,
|
||||
Reason: dnsfilter.Rewritten,
|
||||
ServiceName: "SomeService",
|
||||
Rules: []*dnsfilter.ResultRule{{
|
||||
FilterListID: 1,
|
||||
@@ -248,25 +282,28 @@ func addEntry(l *queryLog, host, answerStr, client string) {
|
||||
Answer: &a,
|
||||
OrigAnswer: &a,
|
||||
Result: &res,
|
||||
ClientIP: net.ParseIP(client),
|
||||
ClientIP: client,
|
||||
Upstream: "upstream",
|
||||
}
|
||||
l.Add(params)
|
||||
}
|
||||
|
||||
func assertLogEntry(t *testing.T, entry *logEntry, host, answer, client string) bool {
|
||||
func assertLogEntry(t *testing.T, entry *logEntry, host string, answer, client net.IP) {
|
||||
t.Helper()
|
||||
|
||||
require.NotNil(t, entry)
|
||||
|
||||
assert.Equal(t, host, entry.QHost)
|
||||
assert.Equal(t, client, entry.IP)
|
||||
assert.Equal(t, "A", entry.QType)
|
||||
assert.Equal(t, "IN", entry.QClass)
|
||||
|
||||
msg := new(dns.Msg)
|
||||
assert.Nil(t, msg.Unpack(entry.Answer))
|
||||
assert.Equal(t, 1, len(msg.Answer))
|
||||
ip := proxyutil.GetIPFromDNSRecord(msg.Answer[0])
|
||||
assert.NotNil(t, ip)
|
||||
assert.Equal(t, answer, ip.String())
|
||||
return true
|
||||
msg := &dns.Msg{}
|
||||
require.Nil(t, msg.Unpack(entry.Answer))
|
||||
require.Len(t, msg.Answer, 1)
|
||||
|
||||
ip := proxyutil.GetIPFromDNSRecord(msg.Answer[0]).To16()
|
||||
assert.Equal(t, answer, ip)
|
||||
}
|
||||
|
||||
func testEntries() (entries []*logEntry) {
|
||||
@@ -332,8 +369,8 @@ func TestLogEntriesByTime_sort(t *testing.T) {
|
||||
entries := testEntries()
|
||||
sort.Sort(logEntriesByTimeDesc(entries))
|
||||
|
||||
for i := 1; i < len(entries); i++ {
|
||||
assert.False(t, entries[i].Time.After(entries[i-1].Time),
|
||||
"%s %s", entries[i].Time, entries[i-1].Time)
|
||||
for i := range entries[1:] {
|
||||
assert.False(t, entries[i+1].Time.After(entries[i].Time),
|
||||
"%s %s", entries[i+1].Time, entries[i].Time)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -251,7 +251,7 @@ func (q *QLogFile) readNextLine(position int64) (string, int64, error) {
|
||||
// the goal is to read a chunk of file that includes the line with the specified position.
|
||||
func (q *QLogFile) initBuffer(position int64) error {
|
||||
q.bufferStart = int64(0)
|
||||
if (position - bufferSize) > 0 {
|
||||
if position > bufferSize {
|
||||
q.bufferStart = position - bufferSize
|
||||
}
|
||||
|
||||
@@ -264,12 +264,10 @@ func (q *QLogFile) initBuffer(position int64) error {
|
||||
if q.buffer == nil {
|
||||
q.buffer = make([]byte, bufferSize)
|
||||
}
|
||||
q.bufferLen, err = q.file.Read(q.buffer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
q.bufferLen, err = q.file.Read(q.buffer)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// readProbeLine reads a line that includes the specified position
|
||||
@@ -280,7 +278,7 @@ func (q *QLogFile) readProbeLine(position int64) (string, int64, int64, error) {
|
||||
// In order to do this, we'll define the boundaries
|
||||
seekPosition := int64(0)
|
||||
relativePos := position // position relative to the buffer we're going to read
|
||||
if (position - maxEntrySize) > 0 {
|
||||
if position > maxEntrySize {
|
||||
seekPosition = position - maxEntrySize
|
||||
relativePos = maxEntrySize
|
||||
}
|
||||
|
||||
@@ -2,347 +2,347 @@ package querylog
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"math"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestQLogFileEmpty(t *testing.T) {
|
||||
testDir := prepareTestDir()
|
||||
defer func() { _ = os.RemoveAll(testDir) }()
|
||||
testFile := prepareTestFile(testDir, 0)
|
||||
// prepareTestFiles prepares several test query log files, each with the
|
||||
// specified lines count.
|
||||
func prepareTestFiles(t *testing.T, filesNum, linesNum int) []string {
|
||||
t.Helper()
|
||||
|
||||
// create the new QLogFile instance
|
||||
q, err := NewQLogFile(testFile)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, q)
|
||||
defer q.Close()
|
||||
|
||||
// seek to the start
|
||||
pos, err := q.SeekStart()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(0), pos)
|
||||
|
||||
// try reading anyway
|
||||
line, err := q.ReadNext()
|
||||
assert.Equal(t, io.EOF, err)
|
||||
assert.Equal(t, "", line)
|
||||
}
|
||||
|
||||
func TestQLogFileLarge(t *testing.T) {
|
||||
// should be large enough
|
||||
count := 50000
|
||||
|
||||
testDir := prepareTestDir()
|
||||
defer func() { _ = os.RemoveAll(testDir) }()
|
||||
testFile := prepareTestFile(testDir, count)
|
||||
|
||||
// create the new QLogFile instance
|
||||
q, err := NewQLogFile(testFile)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, q)
|
||||
defer q.Close()
|
||||
|
||||
// seek to the start
|
||||
pos, err := q.SeekStart()
|
||||
assert.Nil(t, err)
|
||||
assert.NotEqual(t, int64(0), pos)
|
||||
|
||||
read := 0
|
||||
var line string
|
||||
for err == nil {
|
||||
line, err = q.ReadNext()
|
||||
if err == nil {
|
||||
assert.True(t, len(line) > 0)
|
||||
read++
|
||||
}
|
||||
if filesNum == 0 {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
assert.Equal(t, count, read)
|
||||
assert.Equal(t, io.EOF, err)
|
||||
}
|
||||
|
||||
func TestQLogFileSeekLargeFile(t *testing.T) {
|
||||
// more or less big file
|
||||
count := 10000
|
||||
|
||||
testDir := prepareTestDir()
|
||||
defer func() { _ = os.RemoveAll(testDir) }()
|
||||
testFile := prepareTestFile(testDir, count)
|
||||
|
||||
// create the new QLogFile instance
|
||||
q, err := NewQLogFile(testFile)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, q)
|
||||
defer q.Close()
|
||||
|
||||
// CASE 1: NOT TOO OLD LINE
|
||||
testSeekLineQLogFile(t, q, 300)
|
||||
|
||||
// CASE 2: OLD LINE
|
||||
testSeekLineQLogFile(t, q, count-300)
|
||||
|
||||
// CASE 3: FIRST LINE
|
||||
testSeekLineQLogFile(t, q, 0)
|
||||
|
||||
// CASE 4: LAST LINE
|
||||
testSeekLineQLogFile(t, q, count)
|
||||
|
||||
// CASE 5: Seek non-existent (too low)
|
||||
_, _, err = q.SeekTS(123)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
// CASE 6: Seek non-existent (too high)
|
||||
ts, _ := time.Parse(time.RFC3339, "2100-01-02T15:04:05Z07:00")
|
||||
_, _, err = q.SeekTS(ts.UnixNano())
|
||||
assert.NotNil(t, err)
|
||||
|
||||
// CASE 7: "Almost" found
|
||||
line, err := getQLogFileLine(q, count/2)
|
||||
assert.Nil(t, err)
|
||||
// ALMOST the record we need
|
||||
timestamp := readQLogTimestamp(line) - 1
|
||||
assert.NotEqual(t, uint64(0), timestamp)
|
||||
_, depth, err := q.SeekTS(timestamp)
|
||||
assert.NotNil(t, err)
|
||||
assert.True(t, depth <= int(math.Log2(float64(count))+3))
|
||||
}
|
||||
|
||||
func TestQLogFileSeekSmallFile(t *testing.T) {
|
||||
// more or less big file
|
||||
count := 10
|
||||
|
||||
testDir := prepareTestDir()
|
||||
defer func() { _ = os.RemoveAll(testDir) }()
|
||||
testFile := prepareTestFile(testDir, count)
|
||||
|
||||
// create the new QLogFile instance
|
||||
q, err := NewQLogFile(testFile)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, q)
|
||||
defer q.Close()
|
||||
|
||||
// CASE 1: NOT TOO OLD LINE
|
||||
testSeekLineQLogFile(t, q, 2)
|
||||
|
||||
// CASE 2: OLD LINE
|
||||
testSeekLineQLogFile(t, q, count-2)
|
||||
|
||||
// CASE 3: FIRST LINE
|
||||
testSeekLineQLogFile(t, q, 0)
|
||||
|
||||
// CASE 4: LAST LINE
|
||||
testSeekLineQLogFile(t, q, count)
|
||||
|
||||
// CASE 5: Seek non-existent (too low)
|
||||
_, _, err = q.SeekTS(123)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
// CASE 6: Seek non-existent (too high)
|
||||
ts, _ := time.Parse(time.RFC3339, "2100-01-02T15:04:05Z07:00")
|
||||
_, _, err = q.SeekTS(ts.UnixNano())
|
||||
assert.NotNil(t, err)
|
||||
|
||||
// CASE 7: "Almost" found
|
||||
line, err := getQLogFileLine(q, count/2)
|
||||
assert.Nil(t, err)
|
||||
// ALMOST the record we need
|
||||
timestamp := readQLogTimestamp(line) - 1
|
||||
assert.NotEqual(t, uint64(0), timestamp)
|
||||
_, depth, err := q.SeekTS(timestamp)
|
||||
assert.NotNil(t, err)
|
||||
assert.True(t, depth <= int(math.Log2(float64(count))+3))
|
||||
}
|
||||
|
||||
func testSeekLineQLogFile(t *testing.T, q *QLogFile, lineNumber int) {
|
||||
line, err := getQLogFileLine(q, lineNumber)
|
||||
assert.Nil(t, err)
|
||||
ts := readQLogTimestamp(line)
|
||||
assert.NotEqual(t, uint64(0), ts)
|
||||
|
||||
// try seeking to that line now
|
||||
pos, _, err := q.SeekTS(ts)
|
||||
assert.Nil(t, err)
|
||||
assert.NotEqual(t, int64(0), pos)
|
||||
|
||||
testLine, err := q.ReadNext()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, line, testLine)
|
||||
}
|
||||
|
||||
func getQLogFileLine(q *QLogFile, lineNumber int) (string, error) {
|
||||
_, err := q.SeekStart()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
for i := 1; i < lineNumber; i++ {
|
||||
_, err := q.ReadNext()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
return q.ReadNext()
|
||||
}
|
||||
|
||||
// Check adding and loading (with filtering) entries from disk and memory
|
||||
func TestQLogFile(t *testing.T) {
|
||||
testDir := prepareTestDir()
|
||||
defer func() { _ = os.RemoveAll(testDir) }()
|
||||
testFile := prepareTestFile(testDir, 2)
|
||||
|
||||
// create the new QLogFile instance
|
||||
q, err := NewQLogFile(testFile)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, q)
|
||||
defer q.Close()
|
||||
|
||||
// seek to the start
|
||||
pos, err := q.SeekStart()
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, pos > 0)
|
||||
|
||||
// read first line
|
||||
line, err := q.ReadNext()
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, strings.Contains(line, "0.0.0.2"), line)
|
||||
assert.True(t, strings.HasPrefix(line, "{"), line)
|
||||
assert.True(t, strings.HasSuffix(line, "}"), line)
|
||||
|
||||
// read second line
|
||||
line, err = q.ReadNext()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(0), q.position)
|
||||
assert.True(t, strings.Contains(line, "0.0.0.1"), line)
|
||||
assert.True(t, strings.HasPrefix(line, "{"), line)
|
||||
assert.True(t, strings.HasSuffix(line, "}"), line)
|
||||
|
||||
// try reading again (there's nothing to read anymore)
|
||||
line, err = q.ReadNext()
|
||||
assert.Equal(t, io.EOF, err)
|
||||
assert.Equal(t, "", line)
|
||||
}
|
||||
|
||||
// prepareTestFile - prepares a test query log file with the specified number of lines
|
||||
func prepareTestFile(dir string, linesCount int) string {
|
||||
return prepareTestFiles(dir, 1, linesCount)[0]
|
||||
}
|
||||
|
||||
// prepareTestFiles - prepares several test query log files
|
||||
// each of them -- with the specified linesCount
|
||||
func prepareTestFiles(dir string, filesCount, linesCount int) []string {
|
||||
format := `{"IP":"${IP}","T":"${TIMESTAMP}","QH":"example.org","QT":"A","QC":"IN","Answer":"AAAAAAABAAEAAAAAB2V4YW1wbGUDb3JnAAABAAEHZXhhbXBsZQNvcmcAAAEAAQAAAAAABAECAwQ=","Result":{},"Elapsed":0,"Upstream":"upstream"}`
|
||||
const strV = "\"%s\""
|
||||
const nl = "\n"
|
||||
const format = `{"IP":` + strV + `,"T":` + strV + `,` +
|
||||
`"QH":"example.org","QT":"A","QC":"IN",` +
|
||||
`"Answer":"AAAAAAABAAEAAAAAB2V4YW1wbGUDb3JnAAABAAEHZXhhbXBsZQNvcmcAAAEAAQAAAAAABAECAwQ=",` +
|
||||
`"Result":{},"Elapsed":0,"Upstream":"upstream"}` + nl
|
||||
|
||||
lineTime, _ := time.Parse(time.RFC3339Nano, "2020-02-18T22:36:35.920973+03:00")
|
||||
lineIP := uint32(0)
|
||||
|
||||
files := make([]string, filesCount)
|
||||
for j := 0; j < filesCount; j++ {
|
||||
f, _ := ioutil.TempFile(dir, "*.txt")
|
||||
files[filesCount-j-1] = f.Name()
|
||||
dir := aghtest.PrepareTestDir(t)
|
||||
|
||||
for i := 0; i < linesCount; i++ {
|
||||
files := make([]string, filesNum)
|
||||
for j := range files {
|
||||
f, err := ioutil.TempFile(dir, "*.txt")
|
||||
require.Nil(t, err)
|
||||
files[filesNum-j-1] = f.Name()
|
||||
|
||||
for i := 0; i < linesNum; i++ {
|
||||
lineIP++
|
||||
lineTime = lineTime.Add(time.Second)
|
||||
|
||||
ip := make(net.IP, 4)
|
||||
binary.BigEndian.PutUint32(ip, lineIP)
|
||||
|
||||
line := format
|
||||
line = strings.ReplaceAll(line, "${IP}", ip.String())
|
||||
line = strings.ReplaceAll(line, "${TIMESTAMP}", lineTime.Format(time.RFC3339Nano))
|
||||
line := fmt.Sprintf(format, ip, lineTime.Format(time.RFC3339Nano))
|
||||
|
||||
_, _ = f.WriteString(line)
|
||||
_, _ = f.WriteString("\n")
|
||||
_, err = f.WriteString(line)
|
||||
require.Nil(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
return files
|
||||
}
|
||||
|
||||
func TestQLogSeek(t *testing.T) {
|
||||
testDir := prepareTestDir()
|
||||
defer func() { _ = os.RemoveAll(testDir) }()
|
||||
// prepareTestFile prepares a test query log file with the specified number of
|
||||
// lines.
|
||||
func prepareTestFile(t *testing.T, linesCount int) string {
|
||||
t.Helper()
|
||||
|
||||
d := `{"T":"2020-08-31T18:44:23.911246629+03:00","QH":"wfqvjymurpwegyv","QT":"A","QC":"IN","CP":"","Answer":"","Result":{},"Elapsed":66286385,"Upstream":"tls://dns-unfiltered.adguard.com:853"}
|
||||
{"T":"2020-08-31T18:44:25.376690873+03:00"}
|
||||
{"T":"2020-08-31T18:44:25.382540454+03:00"}`
|
||||
f, _ := ioutil.TempFile(testDir, "*.txt")
|
||||
_, _ = f.WriteString(d)
|
||||
defer f.Close()
|
||||
|
||||
q, err := NewQLogFile(f.Name())
|
||||
assert.Nil(t, err)
|
||||
defer q.Close()
|
||||
|
||||
target, _ := time.Parse(time.RFC3339, "2020-08-31T18:44:25.376690873+03:00")
|
||||
|
||||
_, depth, err := q.SeekTS(target.UnixNano())
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 1, depth)
|
||||
return prepareTestFiles(t, 1, linesCount)[0]
|
||||
}
|
||||
|
||||
func TestQLogSeek_ErrTSTooLate(t *testing.T) {
|
||||
testDir := prepareTestDir()
|
||||
// newTestQLogFile creates new *QLogFile for tests and registers the required
|
||||
// cleanup functions.
|
||||
func newTestQLogFile(t *testing.T, linesNum int) (file *QLogFile) {
|
||||
t.Helper()
|
||||
|
||||
testFile := prepareTestFile(t, linesNum)
|
||||
|
||||
// Create the new QLogFile instance.
|
||||
file, err := NewQLogFile(testFile)
|
||||
require.Nil(t, err)
|
||||
assert.NotNil(t, file)
|
||||
t.Cleanup(func() {
|
||||
_ = os.RemoveAll(testDir)
|
||||
assert.Nil(t, file.Close())
|
||||
})
|
||||
|
||||
d := `{"T":"2020-08-31T18:44:23.911246629+03:00","QH":"wfqvjymurpwegyv","QT":"A","QC":"IN","CP":"","Answer":"","Result":{},"Elapsed":66286385,"Upstream":"tls://dns-unfiltered.adguard.com:853"}
|
||||
{"T":"2020-08-31T18:44:25.376690873+03:00"}
|
||||
{"T":"2020-08-31T18:44:25.382540454+03:00"}
|
||||
`
|
||||
f, err := ioutil.TempFile(testDir, "*.txt")
|
||||
assert.Nil(t, err)
|
||||
defer f.Close()
|
||||
|
||||
_, err = f.WriteString(d)
|
||||
assert.Nil(t, err)
|
||||
|
||||
q, err := NewQLogFile(f.Name())
|
||||
assert.Nil(t, err)
|
||||
defer q.Close()
|
||||
|
||||
target, err := time.Parse(time.RFC3339, "2020-08-31T18:44:25.382540454+03:00")
|
||||
assert.Nil(t, err)
|
||||
|
||||
_, depth, err := q.SeekTS(target.UnixNano() + int64(time.Second))
|
||||
assert.Equal(t, ErrTSTooLate, err)
|
||||
assert.Equal(t, 2, depth)
|
||||
return file
|
||||
}
|
||||
|
||||
func TestQLogSeek_ErrTSTooEarly(t *testing.T) {
|
||||
testDir := prepareTestDir()
|
||||
func TestQLogFile_ReadNext(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
linesNum int
|
||||
}{{
|
||||
name: "empty",
|
||||
linesNum: 0,
|
||||
}, {
|
||||
name: "large",
|
||||
linesNum: 50000,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
q := newTestQLogFile(t, tc.linesNum)
|
||||
|
||||
// Calculate the expected position.
|
||||
fileInfo, err := q.file.Stat()
|
||||
require.Nil(t, err)
|
||||
var expPos int64
|
||||
if expPos = fileInfo.Size(); expPos > 0 {
|
||||
expPos--
|
||||
}
|
||||
|
||||
// Seek to the start.
|
||||
pos, err := q.SeekStart()
|
||||
require.Nil(t, err)
|
||||
require.EqualValues(t, expPos, pos)
|
||||
|
||||
var read int
|
||||
var line string
|
||||
for err == nil {
|
||||
line, err = q.ReadNext()
|
||||
if err == nil {
|
||||
assert.NotEmpty(t, line)
|
||||
read++
|
||||
}
|
||||
}
|
||||
|
||||
require.Equal(t, io.EOF, err)
|
||||
assert.Equal(t, tc.linesNum, read)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestQLogFile_SeekTS_good(t *testing.T) {
|
||||
linesCases := []struct {
|
||||
name string
|
||||
num int
|
||||
}{{
|
||||
name: "large",
|
||||
num: 10000,
|
||||
}, {
|
||||
name: "small",
|
||||
num: 10,
|
||||
}}
|
||||
|
||||
for _, l := range linesCases {
|
||||
testCases := []struct {
|
||||
name string
|
||||
linesNum int
|
||||
line int
|
||||
}{{
|
||||
name: "not_too_old",
|
||||
line: 2,
|
||||
}, {
|
||||
name: "old",
|
||||
line: l.num - 2,
|
||||
}, {
|
||||
name: "first",
|
||||
line: 0,
|
||||
}, {
|
||||
name: "last",
|
||||
line: l.num,
|
||||
}}
|
||||
|
||||
q := newTestQLogFile(t, l.num)
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(l.name+"_"+tc.name, func(t *testing.T) {
|
||||
line, err := getQLogFileLine(q, tc.line)
|
||||
require.Nil(t, err)
|
||||
ts := readQLogTimestamp(line)
|
||||
assert.NotEqualValues(t, 0, ts)
|
||||
|
||||
// Try seeking to that line now.
|
||||
pos, _, err := q.SeekTS(ts)
|
||||
require.Nil(t, err)
|
||||
assert.NotEqualValues(t, 0, pos)
|
||||
|
||||
testLine, err := q.ReadNext()
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, line, testLine)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestQLogFile_SeekTS_bad(t *testing.T) {
|
||||
linesCases := []struct {
|
||||
name string
|
||||
num int
|
||||
}{{
|
||||
name: "large",
|
||||
num: 10000,
|
||||
}, {
|
||||
name: "small",
|
||||
num: 10,
|
||||
}}
|
||||
|
||||
for _, l := range linesCases {
|
||||
testCases := []struct {
|
||||
name string
|
||||
ts int64
|
||||
leq bool
|
||||
}{{
|
||||
name: "non-existent_long_ago",
|
||||
}, {
|
||||
name: "non-existent_far_ahead",
|
||||
}, {
|
||||
name: "almost",
|
||||
leq: true,
|
||||
}}
|
||||
|
||||
q := newTestQLogFile(t, l.num)
|
||||
testCases[0].ts = 123
|
||||
|
||||
lateTS, _ := time.Parse(time.RFC3339, "2100-01-02T15:04:05Z07:00")
|
||||
testCases[1].ts = lateTS.UnixNano()
|
||||
|
||||
line, err := getQLogFileLine(q, l.num/2)
|
||||
require.Nil(t, err)
|
||||
testCases[2].ts = readQLogTimestamp(line) - 1
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
assert.NotEqualValues(t, 0, tc.ts)
|
||||
|
||||
_, depth, err := q.SeekTS(tc.ts)
|
||||
assert.NotEmpty(t, l.num)
|
||||
require.NotNil(t, err)
|
||||
if tc.leq {
|
||||
assert.LessOrEqual(t, depth, int(math.Log2(float64(l.num))+3))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getQLogFileLine(q *QLogFile, lineNumber int) (line string, err error) {
|
||||
if _, err = q.SeekStart(); err != nil {
|
||||
return line, err
|
||||
}
|
||||
|
||||
for i := 1; i < lineNumber; i++ {
|
||||
if _, err = q.ReadNext(); err != nil {
|
||||
return line, err
|
||||
}
|
||||
}
|
||||
|
||||
return q.ReadNext()
|
||||
}
|
||||
|
||||
// Check adding and loading (with filtering) entries from disk and memory.
|
||||
func TestQLogFile(t *testing.T) {
|
||||
// Create the new QLogFile instance.
|
||||
q := newTestQLogFile(t, 2)
|
||||
|
||||
// Seek to the start.
|
||||
pos, err := q.SeekStart()
|
||||
require.Nil(t, err)
|
||||
assert.Greater(t, pos, int64(0))
|
||||
|
||||
// Read first line.
|
||||
line, err := q.ReadNext()
|
||||
require.Nil(t, err)
|
||||
assert.Contains(t, line, "0.0.0.2")
|
||||
assert.True(t, strings.HasPrefix(line, "{"), line)
|
||||
assert.True(t, strings.HasSuffix(line, "}"), line)
|
||||
|
||||
// Read second line.
|
||||
line, err = q.ReadNext()
|
||||
require.Nil(t, err)
|
||||
assert.EqualValues(t, 0, q.position)
|
||||
assert.Contains(t, line, "0.0.0.1")
|
||||
assert.True(t, strings.HasPrefix(line, "{"), line)
|
||||
assert.True(t, strings.HasSuffix(line, "}"), line)
|
||||
|
||||
// Try reading again (there's nothing to read anymore).
|
||||
line, err = q.ReadNext()
|
||||
require.Equal(t, io.EOF, err)
|
||||
assert.Empty(t, line)
|
||||
}
|
||||
|
||||
func NewTestQLogFileData(t *testing.T, data string) (file *QLogFile) {
|
||||
f, err := ioutil.TempFile(aghtest.PrepareTestDir(t), "*.txt")
|
||||
require.Nil(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = os.RemoveAll(testDir)
|
||||
assert.Nil(t, f.Close())
|
||||
})
|
||||
|
||||
d := `{"T":"2020-08-31T18:44:23.911246629+03:00","QH":"wfqvjymurpwegyv","QT":"A","QC":"IN","CP":"","Answer":"","Result":{},"Elapsed":66286385,"Upstream":"tls://dns-unfiltered.adguard.com:853"}
|
||||
{"T":"2020-08-31T18:44:25.376690873+03:00"}
|
||||
{"T":"2020-08-31T18:44:25.382540454+03:00"}
|
||||
`
|
||||
f, err := ioutil.TempFile(testDir, "*.txt")
|
||||
assert.Nil(t, err)
|
||||
defer f.Close()
|
||||
_, err = f.WriteString(data)
|
||||
require.Nil(t, err)
|
||||
|
||||
_, err = f.WriteString(d)
|
||||
assert.Nil(t, err)
|
||||
file, err = NewQLogFile(f.Name())
|
||||
require.Nil(t, err)
|
||||
t.Cleanup(func() {
|
||||
assert.Nil(t, file.Close())
|
||||
})
|
||||
|
||||
q, err := NewQLogFile(f.Name())
|
||||
assert.Nil(t, err)
|
||||
defer q.Close()
|
||||
|
||||
target, err := time.Parse(time.RFC3339, "2020-08-31T18:44:23.911246629+03:00")
|
||||
assert.Nil(t, err)
|
||||
|
||||
_, depth, err := q.SeekTS(target.UnixNano() - int64(time.Second))
|
||||
assert.Equal(t, ErrTSTooEarly, err)
|
||||
assert.Equal(t, 1, depth)
|
||||
return file
|
||||
}
|
||||
|
||||
func TestQLog_Seek(t *testing.T) {
|
||||
const nl = "\n"
|
||||
const strV = "%s"
|
||||
const recs = `{"T":"` + strV + `","QH":"wfqvjymurpwegyv","QT":"A","QC":"IN","CP":"","Answer":"","Result":{},"Elapsed":66286385,"Upstream":"tls://dns-unfiltered.adguard.com:853"}` + nl +
|
||||
`{"T":"` + strV + `"}` + nl +
|
||||
`{"T":"` + strV + `"}` + nl
|
||||
timestamp, _ := time.Parse(time.RFC3339Nano, "2020-08-31T18:44:25.376690873+03:00")
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
delta int
|
||||
wantErr error
|
||||
wantDepth int
|
||||
}{{
|
||||
name: "ok",
|
||||
delta: 0,
|
||||
wantErr: nil,
|
||||
wantDepth: 2,
|
||||
}, {
|
||||
name: "too_late",
|
||||
delta: 2,
|
||||
wantErr: ErrTSTooLate,
|
||||
wantDepth: 2,
|
||||
}, {
|
||||
name: "too_early",
|
||||
delta: -2,
|
||||
wantErr: ErrTSTooEarly,
|
||||
wantDepth: 1,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
data := fmt.Sprintf(recs,
|
||||
timestamp.Add(-time.Second).Format(time.RFC3339Nano),
|
||||
timestamp.Format(time.RFC3339Nano),
|
||||
timestamp.Add(time.Second).Format(time.RFC3339Nano),
|
||||
)
|
||||
|
||||
q := NewTestQLogFileData(t, data)
|
||||
|
||||
_, depth, err := q.SeekTS(timestamp.Add(time.Second * time.Duration(tc.delta)).UnixNano())
|
||||
require.Truef(t, errors.Is(err, tc.wantErr), "%v", err)
|
||||
assert.Equal(t, tc.wantDepth, depth)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,110 +3,77 @@ package querylog
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestQLogReaderEmpty(t *testing.T) {
|
||||
r, err := NewQLogReader([]string{})
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, r)
|
||||
defer r.Close()
|
||||
// newTestQLogReader creates new *QLogReader for tests and registers the
|
||||
// required cleanup functions.
|
||||
func newTestQLogReader(t *testing.T, filesNum, linesNum int) (reader *QLogReader) {
|
||||
t.Helper()
|
||||
|
||||
// seek to the start
|
||||
err = r.SeekStart()
|
||||
assert.Nil(t, err)
|
||||
testFiles := prepareTestFiles(t, filesNum, linesNum)
|
||||
|
||||
line, err := r.ReadNext()
|
||||
assert.Equal(t, "", line)
|
||||
assert.Equal(t, io.EOF, err)
|
||||
// Create the new QLogReader instance.
|
||||
reader, err := NewQLogReader(testFiles)
|
||||
require.Nil(t, err)
|
||||
assert.NotNil(t, reader)
|
||||
t.Cleanup(func() {
|
||||
assert.Nil(t, reader.Close())
|
||||
})
|
||||
|
||||
return reader
|
||||
}
|
||||
|
||||
func TestQLogReaderOneFile(t *testing.T) {
|
||||
// let's do one small file
|
||||
count := 10
|
||||
filesCount := 1
|
||||
func TestQLogReader(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
filesNum int
|
||||
linesNum int
|
||||
}{{
|
||||
name: "empty",
|
||||
filesNum: 0,
|
||||
linesNum: 0,
|
||||
}, {
|
||||
name: "one_file",
|
||||
filesNum: 1,
|
||||
linesNum: 10,
|
||||
}, {
|
||||
name: "multiple_files",
|
||||
filesNum: 5,
|
||||
linesNum: 10000,
|
||||
}}
|
||||
|
||||
testDir := prepareTestDir()
|
||||
defer func() { _ = os.RemoveAll(testDir) }()
|
||||
testFiles := prepareTestFiles(testDir, filesCount, count)
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
r := newTestQLogReader(t, tc.filesNum, tc.linesNum)
|
||||
|
||||
r, err := NewQLogReader(testFiles)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, r)
|
||||
defer r.Close()
|
||||
// Seek to the start.
|
||||
err := r.SeekStart()
|
||||
require.Nil(t, err)
|
||||
|
||||
// seek to the start
|
||||
err = r.SeekStart()
|
||||
assert.Nil(t, err)
|
||||
// Read everything.
|
||||
var read int
|
||||
var line string
|
||||
for err == nil {
|
||||
line, err = r.ReadNext()
|
||||
if err == nil {
|
||||
assert.NotEmpty(t, line)
|
||||
read++
|
||||
}
|
||||
}
|
||||
|
||||
// read everything
|
||||
read := 0
|
||||
var line string
|
||||
for err == nil {
|
||||
line, err = r.ReadNext()
|
||||
if err == nil {
|
||||
assert.True(t, len(line) > 0)
|
||||
read++
|
||||
}
|
||||
require.Equal(t, io.EOF, err)
|
||||
assert.Equal(t, tc.filesNum*tc.linesNum, read)
|
||||
})
|
||||
}
|
||||
|
||||
assert.Equal(t, count*filesCount, read)
|
||||
assert.Equal(t, io.EOF, err)
|
||||
}
|
||||
|
||||
func TestQLogReaderMultipleFiles(t *testing.T) {
|
||||
// should be large enough
|
||||
count := 10000
|
||||
filesCount := 5
|
||||
|
||||
testDir := prepareTestDir()
|
||||
defer func() { _ = os.RemoveAll(testDir) }()
|
||||
testFiles := prepareTestFiles(testDir, filesCount, count)
|
||||
|
||||
r, err := NewQLogReader(testFiles)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, r)
|
||||
defer r.Close()
|
||||
|
||||
// seek to the start
|
||||
err = r.SeekStart()
|
||||
assert.Nil(t, err)
|
||||
|
||||
// read everything
|
||||
read := 0
|
||||
var line string
|
||||
for err == nil {
|
||||
line, err = r.ReadNext()
|
||||
if err == nil {
|
||||
assert.True(t, len(line) > 0)
|
||||
read++
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equal(t, count*filesCount, read)
|
||||
assert.Equal(t, io.EOF, err)
|
||||
}
|
||||
|
||||
func TestQLogReader_Seek(t *testing.T) {
|
||||
count := 10000
|
||||
filesCount := 2
|
||||
|
||||
testDir := prepareTestDir()
|
||||
t.Cleanup(func() {
|
||||
_ = os.RemoveAll(testDir)
|
||||
})
|
||||
testFiles := prepareTestFiles(testDir, filesCount, count)
|
||||
|
||||
r, err := NewQLogReader(testFiles)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, r)
|
||||
t.Cleanup(func() {
|
||||
_ = r.Close()
|
||||
})
|
||||
r := newTestQLogReader(t, 2, 10000)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
@@ -114,7 +81,7 @@ func TestQLogReader_Seek(t *testing.T) {
|
||||
want error
|
||||
}{{
|
||||
name: "not_too_old",
|
||||
time: "2020-02-19T04:04:56.920973+03:00",
|
||||
time: "2020-02-18T22:39:35.920973+03:00",
|
||||
want: nil,
|
||||
}, {
|
||||
name: "old",
|
||||
@@ -122,7 +89,7 @@ func TestQLogReader_Seek(t *testing.T) {
|
||||
want: nil,
|
||||
}, {
|
||||
name: "first",
|
||||
time: "2020-02-19T04:09:55.920973+03:00",
|
||||
time: "2020-02-18T22:36:36.920973+03:00",
|
||||
want: nil,
|
||||
}, {
|
||||
name: "last",
|
||||
@@ -147,28 +114,20 @@ func TestQLogReader_Seek(t *testing.T) {
|
||||
timestamp, err := time.Parse(time.RFC3339Nano, tc.time)
|
||||
assert.Nil(t, err)
|
||||
|
||||
if tc.name == "first" {
|
||||
assert.True(t, true)
|
||||
}
|
||||
|
||||
err = r.SeekTS(timestamp.UnixNano())
|
||||
assert.True(t, errors.Is(err, tc.want), err)
|
||||
assert.True(t, errors.Is(err, tc.want))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestQLogReader_ReadNext(t *testing.T) {
|
||||
count := 10
|
||||
filesCount := 1
|
||||
|
||||
testDir := prepareTestDir()
|
||||
t.Cleanup(func() {
|
||||
_ = os.RemoveAll(testDir)
|
||||
})
|
||||
testFiles := prepareTestFiles(testDir, filesCount, count)
|
||||
|
||||
r, err := NewQLogReader(testFiles)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, r)
|
||||
t.Cleanup(func() {
|
||||
_ = r.Close()
|
||||
})
|
||||
const linesNum = 10
|
||||
const filesNum = 1
|
||||
r := newTestQLogReader(t, filesNum, linesNum)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
@@ -180,7 +139,7 @@ func TestQLogReader_ReadNext(t *testing.T) {
|
||||
want: nil,
|
||||
}, {
|
||||
name: "too_big",
|
||||
start: count + 1,
|
||||
start: linesNum + 1,
|
||||
want: io.EOF,
|
||||
}}
|
||||
|
||||
@@ -199,70 +158,3 @@ func TestQLogReader_ReadNext(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(e.burkov): Remove the tests below. Make tests above more compelling.
|
||||
func TestQLogReaderSeek(t *testing.T) {
|
||||
// more or less big file
|
||||
count := 10000
|
||||
filesCount := 2
|
||||
|
||||
testDir := prepareTestDir()
|
||||
defer func() { _ = os.RemoveAll(testDir) }()
|
||||
testFiles := prepareTestFiles(testDir, filesCount, count)
|
||||
|
||||
r, err := NewQLogReader(testFiles)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, r)
|
||||
defer r.Close()
|
||||
|
||||
// CASE 1: NOT TOO OLD LINE
|
||||
testSeekLineQLogReader(t, r, 300)
|
||||
|
||||
// CASE 2: OLD LINE
|
||||
testSeekLineQLogReader(t, r, count-300)
|
||||
|
||||
// CASE 3: FIRST LINE
|
||||
testSeekLineQLogReader(t, r, 0)
|
||||
|
||||
// CASE 4: LAST LINE
|
||||
testSeekLineQLogReader(t, r, count)
|
||||
|
||||
// CASE 5: Seek non-existent (too low)
|
||||
err = r.SeekTS(123)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
// CASE 6: Seek non-existent (too high)
|
||||
ts, _ := time.Parse(time.RFC3339, "2100-01-02T15:04:05Z07:00")
|
||||
err = r.SeekTS(ts.UnixNano())
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
func testSeekLineQLogReader(t *testing.T, r *QLogReader, lineNumber int) {
|
||||
line, err := getQLogReaderLine(r, lineNumber)
|
||||
assert.Nil(t, err)
|
||||
ts := readQLogTimestamp(line)
|
||||
assert.NotEqual(t, uint64(0), ts)
|
||||
|
||||
// try seeking to that line now
|
||||
err = r.SeekTS(ts)
|
||||
assert.Nil(t, err)
|
||||
|
||||
testLine, err := r.ReadNext()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, line, testLine)
|
||||
}
|
||||
|
||||
func getQLogReaderLine(r *QLogReader, lineNumber int) (string, error) {
|
||||
err := r.SeekStart()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
for i := 1; i < lineNumber; i++ {
|
||||
_, err := r.ReadNext()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
return r.ReadNext()
|
||||
}
|
||||
|
||||
@@ -46,6 +46,7 @@ type AddParams struct {
|
||||
OrigAnswer *dns.Msg // The response from an upstream server (optional)
|
||||
Result *dnsfilter.Result // Filtering result (optional)
|
||||
Elapsed time.Duration // Time spent for processing the request
|
||||
ClientID string
|
||||
ClientIP net.IP
|
||||
Upstream string // Upstream server URL
|
||||
ClientProto ClientProto
|
||||
|
||||
@@ -3,6 +3,7 @@ package querylog
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
@@ -87,18 +88,19 @@ func (l *queryLog) rotate() error {
|
||||
from := l.logFile
|
||||
to := l.logFile + ".1"
|
||||
|
||||
if _, err := os.Stat(from); os.IsNotExist(err) {
|
||||
// do nothing, file doesn't exist
|
||||
return nil
|
||||
}
|
||||
|
||||
err := os.Rename(from, to)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Error("querylog: failed to rename file: %s", err)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug("querylog: renamed %s -> %s", from, to)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -9,8 +9,13 @@ import (
|
||||
type criteriaType int
|
||||
|
||||
const (
|
||||
ctDomainOrClient criteriaType = iota // domain name or client IP address
|
||||
ctFilteringStatus // filtering status
|
||||
// ctDomainOrClient is for searching by the domain name, the client's IP
|
||||
// address, or the clinet's ID.
|
||||
ctDomainOrClient criteriaType = iota
|
||||
// ctFilteringStatus is for searching by the filtering status.
|
||||
//
|
||||
// See (*searchCriteria).ctFilteringStatusCase for details.
|
||||
ctFilteringStatus
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -38,9 +43,9 @@ var filteringStatusValues = []string{
|
||||
// searchCriteria - every search request may contain a list of different search criteria
|
||||
// we use each of them to match the query
|
||||
type searchCriteria struct {
|
||||
value string // search criteria value
|
||||
criteriaType criteriaType // type of the criteria
|
||||
strict bool // should we strictly match (equality) or not (indexOf)
|
||||
value string // search criteria value
|
||||
}
|
||||
|
||||
// quickMatch - quickly checks if the log entry matches this search criteria
|
||||
@@ -51,7 +56,8 @@ func (c *searchCriteria) quickMatch(line string) bool {
|
||||
switch c.criteriaType {
|
||||
case ctDomainOrClient:
|
||||
return c.quickMatchJSONValue(line, "QH") ||
|
||||
c.quickMatchJSONValue(line, "IP")
|
||||
c.quickMatchJSONValue(line, "IP") ||
|
||||
c.quickMatchJSONValue(line, "CID")
|
||||
default:
|
||||
return true
|
||||
}
|
||||
@@ -89,21 +95,26 @@ func (c *searchCriteria) match(entry *logEntry) bool {
|
||||
}
|
||||
|
||||
func (c *searchCriteria) ctDomainOrClientCase(entry *logEntry) bool {
|
||||
clientID := strings.ToLower(entry.ClientID)
|
||||
qhost := strings.ToLower(entry.QHost)
|
||||
searchVal := strings.ToLower(c.value)
|
||||
if c.strict && qhost == searchVal {
|
||||
return true
|
||||
}
|
||||
if !c.strict && strings.Contains(qhost, searchVal) {
|
||||
if c.strict && (qhost == searchVal || clientID == searchVal) {
|
||||
return true
|
||||
}
|
||||
|
||||
if c.strict && entry.IP == c.value {
|
||||
if !c.strict && (strings.Contains(qhost, searchVal) || strings.Contains(clientID, searchVal)) {
|
||||
return true
|
||||
}
|
||||
if !c.strict && strings.Contains(entry.IP, c.value) {
|
||||
|
||||
ipStr := entry.IP.String()
|
||||
if c.strict && ipStr == c.value {
|
||||
return true
|
||||
}
|
||||
|
||||
if !c.strict && strings.Contains(ipStr, c.value) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -116,8 +127,9 @@ func (c *searchCriteria) ctFilteringStatusCase(res dnsfilter.Result) bool {
|
||||
return res.IsFiltered ||
|
||||
res.Reason.In(
|
||||
dnsfilter.NotFilteredAllowList,
|
||||
dnsfilter.ReasonRewrite,
|
||||
dnsfilter.RewriteAutoHosts,
|
||||
dnsfilter.Rewritten,
|
||||
dnsfilter.RewrittenAutoHosts,
|
||||
dnsfilter.RewrittenRule,
|
||||
)
|
||||
|
||||
case filteringStatusBlocked:
|
||||
@@ -137,7 +149,11 @@ func (c *searchCriteria) ctFilteringStatusCase(res dnsfilter.Result) bool {
|
||||
return res.Reason == dnsfilter.NotFilteredAllowList
|
||||
|
||||
case filteringStatusRewritten:
|
||||
return res.Reason.In(dnsfilter.ReasonRewrite, dnsfilter.RewriteAutoHosts)
|
||||
return res.Reason.In(
|
||||
dnsfilter.Rewritten,
|
||||
dnsfilter.RewrittenAutoHosts,
|
||||
dnsfilter.RewrittenRule,
|
||||
)
|
||||
|
||||
case filteringStatusSafeSearch:
|
||||
return res.IsFiltered && res.Reason == dnsfilter.FilteredSafeSearch
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user