Merge branch 'master' into feature/1574
This commit is contained in:
@@ -11,11 +11,12 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
|
||||
"github.com/AdguardTeam/AdGuardHome/dnsforward"
|
||||
"github.com/AdguardTeam/AdGuardHome/util"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/utils"
|
||||
)
|
||||
@@ -41,11 +42,12 @@ type Client struct {
|
||||
BlockedServices []string
|
||||
|
||||
Upstreams []string // list of upstream servers to be used for the client's requests
|
||||
// Upstream objects:
|
||||
|
||||
// Custom upstream config for this client
|
||||
// nil: not yet initialized
|
||||
// not nil, but empty: initialized, no good upstreams
|
||||
// not nil, not empty: Upstreams ready to be used
|
||||
upstreamObjects []upstream.Upstream
|
||||
upstreamConfig *proxy.UpstreamConfig
|
||||
}
|
||||
|
||||
type clientSource uint
|
||||
@@ -273,16 +275,10 @@ func (clients *clientsContainer) Find(ip string) (Client, bool) {
|
||||
return c, true
|
||||
}
|
||||
|
||||
func upstreamArrayCopy(a []upstream.Upstream) []upstream.Upstream {
|
||||
a2 := make([]upstream.Upstream, len(a))
|
||||
copy(a2, a)
|
||||
return a2
|
||||
}
|
||||
|
||||
// FindUpstreams looks for upstreams configured for the client
|
||||
// If no client found for this IP, or if no custom upstreams are configured,
|
||||
// this method returns nil
|
||||
func (clients *clientsContainer) FindUpstreams(ip string) []upstream.Upstream {
|
||||
func (clients *clientsContainer) FindUpstreams(ip string) *proxy.UpstreamConfig {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
@@ -291,22 +287,18 @@ func (clients *clientsContainer) FindUpstreams(ip string) []upstream.Upstream {
|
||||
return nil
|
||||
}
|
||||
|
||||
if c.upstreamObjects == nil {
|
||||
c.upstreamObjects = make([]upstream.Upstream, 0)
|
||||
for _, us := range c.Upstreams {
|
||||
u, err := upstream.AddressToUpstream(us, upstream.Options{Timeout: dnsforward.DefaultTimeout})
|
||||
if err != nil {
|
||||
log.Error("upstream.AddressToUpstream: %s: %s", us, err)
|
||||
continue
|
||||
}
|
||||
c.upstreamObjects = append(c.upstreamObjects, u)
|
||||
if len(c.Upstreams) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if c.upstreamConfig == nil {
|
||||
config, err := proxy.ParseUpstreamsConfig(c.Upstreams, config.DNS.BootstrapDNS, dnsforward.DefaultTimeout)
|
||||
if err == nil {
|
||||
c.upstreamConfig = &config
|
||||
}
|
||||
}
|
||||
|
||||
if len(c.upstreamObjects) == 0 {
|
||||
return nil
|
||||
}
|
||||
return upstreamArrayCopy(c.upstreamObjects)
|
||||
return c.upstreamConfig
|
||||
}
|
||||
|
||||
// Find searches for a client by IP (and does not lock anything)
|
||||
@@ -537,7 +529,7 @@ func (clients *clientsContainer) Update(name string, c Client) error {
|
||||
}
|
||||
|
||||
// update upstreams cache
|
||||
c.upstreamObjects = nil
|
||||
c.upstreamConfig = nil
|
||||
|
||||
*old = c
|
||||
return nil
|
||||
|
||||
@@ -236,3 +236,31 @@ func TestClientsAddExisting(t *testing.T) {
|
||||
assert.True(t, ok)
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestClientsCustomUpstream(t *testing.T) {
|
||||
clients := clientsContainer{}
|
||||
clients.testing = true
|
||||
|
||||
clients.Init(nil, nil, nil)
|
||||
|
||||
// add client with upstreams
|
||||
client := Client{
|
||||
IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa"},
|
||||
Name: "client1",
|
||||
Upstreams: []string{
|
||||
"1.1.1.1",
|
||||
"[/example.org/]8.8.8.8",
|
||||
},
|
||||
}
|
||||
ok, err := clients.Add(client)
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
config := clients.FindUpstreams("1.2.3.4")
|
||||
assert.Nil(t, config)
|
||||
|
||||
config = clients.FindUpstreams("1.1.1.1")
|
||||
assert.NotNil(t, config)
|
||||
assert.Equal(t, 1, len(config.Upstreams))
|
||||
assert.Equal(t, 1, len(config.DomainReservedUpstreams))
|
||||
}
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/stats"
|
||||
"github.com/AdguardTeam/AdGuardHome/util"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/joomcode/errorx"
|
||||
)
|
||||
@@ -176,7 +175,7 @@ func generateServerConfig() dnsforward.ServerConfig {
|
||||
newconfig.TLSAllowUnencryptedDOH = tlsConf.AllowUnencryptedDOH
|
||||
|
||||
newconfig.FilterHandler = applyAdditionalFiltering
|
||||
newconfig.GetUpstreamsByClient = getUpstreamsByClient
|
||||
newconfig.GetCustomUpstreamByClient = Context.clients.FindUpstreams
|
||||
return newconfig
|
||||
}
|
||||
|
||||
@@ -222,10 +221,6 @@ func getDNSAddresses() []string {
|
||||
return dnsAddresses
|
||||
}
|
||||
|
||||
func getUpstreamsByClient(clientAddr string) []upstream.Upstream {
|
||||
return Context.clients.FindUpstreams(clientAddr)
|
||||
}
|
||||
|
||||
// If a client has his own settings, apply them
|
||||
func applyAdditionalFiltering(clientAddr string, setts *dnsfilter.RequestFilteringSettings) {
|
||||
Context.dnsFilter.ApplyBlockedServices(setts, nil, true)
|
||||
|
||||
@@ -478,6 +478,8 @@ func (f *Filtering) parseFilterContents(file io.Reader) (int, uint32, string) {
|
||||
name = m[0][1]
|
||||
seenTitle = true
|
||||
}
|
||||
} else if line[0] == '#' {
|
||||
continue
|
||||
} else {
|
||||
rulesCount++
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
func testStartFilterListener() net.Listener {
|
||||
http.HandleFunc("/filters/1.txt", func(w http.ResponseWriter, r *http.Request) {
|
||||
content := `||example.org^$third-party
|
||||
# Inline comment example
|
||||
||example.com^$third-party
|
||||
0.0.0.0 example.com
|
||||
`
|
||||
|
||||
@@ -336,6 +336,8 @@ func requireAdminRights() {
|
||||
admin, _ := util.HaveAdminRights()
|
||||
if //noinspection ALL
|
||||
admin || isdelve.Enabled {
|
||||
// Don't forget that for this to work you need to add "delve" tag explicitly
|
||||
// https://stackoverflow.com/questions/47879070/how-can-i-see-if-the-goland-debugger-is-running-in-the-program
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user