Pull request: all: fix client upstreams, imp code
Updates #3186. Squashed commit of the following: commit a8dd0e2cda3039839d069fe71a5bd0f9635ec064 Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Fri May 28 12:54:07 2021 +0300 all: imp code, names commit 98f86c21ae23b665095075feb4a59dcfcc622bc7 Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Thu May 27 21:11:37 2021 +0300 all: fix client upstreams, imp code
This commit is contained in:
@@ -335,37 +335,44 @@ func (clients *clientsContainer) Find(id string) (c *Client, ok bool) {
|
||||
return c, true
|
||||
}
|
||||
|
||||
// FindUpstreams looks for upstreams configured for the client
|
||||
// If no client found for this IP, or if no custom upstreams are configured,
|
||||
// this method returns nil
|
||||
func (clients *clientsContainer) FindUpstreams(ip string) *proxy.UpstreamConfig {
|
||||
// findUpstreams returns upstreams configured for the client, identified either
|
||||
// by its IP address or its ClientID. upsConf is nil if the client isn't found
|
||||
// or if the client has no custom upstreams.
|
||||
func (clients *clientsContainer) findUpstreams(
|
||||
id string,
|
||||
) (upsConf *proxy.UpstreamConfig, err error) {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
c, ok := clients.findLocked(ip)
|
||||
c, ok := clients.findLocked(id)
|
||||
if !ok {
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
upstreams := aghstrings.FilterOut(c.Upstreams, aghstrings.IsCommentOrEmpty)
|
||||
if len(upstreams) == 0 {
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if c.upstreamConfig == nil {
|
||||
conf, err := proxy.ParseUpstreamsConfig(
|
||||
upstreams,
|
||||
upstream.Options{
|
||||
Bootstrap: config.DNS.BootstrapDNS,
|
||||
Timeout: dnsforward.DefaultTimeout,
|
||||
},
|
||||
)
|
||||
if err == nil {
|
||||
c.upstreamConfig = &conf
|
||||
}
|
||||
if c.upstreamConfig != nil {
|
||||
return c.upstreamConfig, nil
|
||||
}
|
||||
|
||||
return c.upstreamConfig
|
||||
var conf proxy.UpstreamConfig
|
||||
conf, err = proxy.ParseUpstreamsConfig(
|
||||
upstreams,
|
||||
upstream.Options{
|
||||
Bootstrap: config.DNS.BootstrapDNS,
|
||||
Timeout: dnsforward.DefaultTimeout,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.upstreamConfig = &conf
|
||||
|
||||
return &conf, nil
|
||||
}
|
||||
|
||||
// findLocked searches for a client by its ID. For internal use only.
|
||||
|
||||
@@ -25,7 +25,7 @@ func TestClients(t *testing.T) {
|
||||
}
|
||||
|
||||
ok, err := clients.Add(c)
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
c = &Client{
|
||||
@@ -34,7 +34,7 @@ func TestClients(t *testing.T) {
|
||||
}
|
||||
|
||||
ok, err = clients.Add(c)
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
c, ok = clients.Find("1.1.1.1")
|
||||
@@ -59,7 +59,7 @@ func TestClients(t *testing.T) {
|
||||
IDs: []string{"1.2.3.5"},
|
||||
Name: "client1",
|
||||
})
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ok)
|
||||
})
|
||||
|
||||
@@ -68,7 +68,7 @@ func TestClients(t *testing.T) {
|
||||
IDs: []string{"2.2.2.2"},
|
||||
Name: "client3",
|
||||
})
|
||||
require.NotNil(t, err)
|
||||
require.Error(t, err)
|
||||
assert.False(t, ok)
|
||||
})
|
||||
|
||||
@@ -77,13 +77,13 @@ func TestClients(t *testing.T) {
|
||||
IDs: []string{"1.2.3.0"},
|
||||
Name: "client3",
|
||||
})
|
||||
require.NotNil(t, err)
|
||||
require.Error(t, err)
|
||||
|
||||
err = clients.Update("client3", &Client{
|
||||
IDs: []string{"1.2.3.0"},
|
||||
Name: "client2",
|
||||
})
|
||||
assert.NotNil(t, err)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("update_fail_ip", func(t *testing.T) {
|
||||
@@ -91,7 +91,7 @@ func TestClients(t *testing.T) {
|
||||
IDs: []string{"2.2.2.2"},
|
||||
Name: "client1",
|
||||
})
|
||||
assert.NotNil(t, err)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("update_success", func(t *testing.T) {
|
||||
@@ -99,7 +99,7 @@ func TestClients(t *testing.T) {
|
||||
IDs: []string{"1.1.1.2"},
|
||||
Name: "client1",
|
||||
})
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.False(t, clients.Exists("1.1.1.1", ClientSourceHostsFile))
|
||||
assert.True(t, clients.Exists("1.1.1.2", ClientSourceHostsFile))
|
||||
@@ -109,7 +109,7 @@ func TestClients(t *testing.T) {
|
||||
Name: "client1-renamed",
|
||||
UseOwnSettings: true,
|
||||
})
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
c, ok := clients.Find("1.1.1.2")
|
||||
require.True(t, ok)
|
||||
@@ -137,15 +137,15 @@ func TestClients(t *testing.T) {
|
||||
|
||||
t.Run("addhost_success", func(t *testing.T) {
|
||||
ok, err := clients.AddHost("1.1.1.1", "host", ClientSourceARP)
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
ok, err = clients.AddHost("1.1.1.1", "host2", ClientSourceARP)
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
ok, err = clients.AddHost("1.1.1.1", "host3", ClientSourceHostsFile)
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile))
|
||||
@@ -153,7 +153,7 @@ func TestClients(t *testing.T) {
|
||||
|
||||
t.Run("addhost_fail", func(t *testing.T) {
|
||||
ok, err := clients.AddHost("1.1.1.1", "host1", ClientSourceRDNS)
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ok)
|
||||
})
|
||||
}
|
||||
@@ -181,7 +181,7 @@ func TestClientsWhois(t *testing.T) {
|
||||
|
||||
t.Run("existing_auto-client", func(t *testing.T) {
|
||||
ok, err := clients.AddHost("1.1.1.1", "host", ClientSourceRDNS)
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
clients.SetWhoisInfo("1.1.1.1", whois)
|
||||
@@ -198,7 +198,7 @@ func TestClientsWhois(t *testing.T) {
|
||||
IDs: []string{"1.1.1.2"},
|
||||
Name: "client1",
|
||||
})
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
clients.SetWhoisInfo("1.1.1.2", whois)
|
||||
@@ -219,12 +219,12 @@ func TestClientsAddExisting(t *testing.T) {
|
||||
IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa", "2.2.2.0/24"},
|
||||
Name: "client1",
|
||||
})
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
// Now add an auto-client with the same IP.
|
||||
ok, err = clients.AddHost("1.1.1.1", "test", ClientSourceRDNS)
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
})
|
||||
|
||||
@@ -253,14 +253,14 @@ func TestClientsAddExisting(t *testing.T) {
|
||||
Hostname: "testhost",
|
||||
Expiry: time.Now().Add(time.Hour),
|
||||
})
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add a new client with the same IP as for a client with MAC.
|
||||
ok, err := clients.Add(&Client{
|
||||
IDs: []string{testIP.String()},
|
||||
Name: "client2",
|
||||
})
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
// Add a new client with the IP from the first client's IP
|
||||
@@ -269,7 +269,7 @@ func TestClientsAddExisting(t *testing.T) {
|
||||
IDs: []string{"2.2.2.2"},
|
||||
Name: "client3",
|
||||
})
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
})
|
||||
}
|
||||
@@ -289,14 +289,16 @@ func TestClientsCustomUpstream(t *testing.T) {
|
||||
"[/example.org/]8.8.8.8",
|
||||
},
|
||||
})
|
||||
require.Nil(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
config := clients.FindUpstreams("1.2.3.4")
|
||||
config, err := clients.findUpstreams("1.2.3.4")
|
||||
assert.Nil(t, config)
|
||||
assert.NoError(t, err)
|
||||
|
||||
config = clients.FindUpstreams("1.1.1.1")
|
||||
config, err = clients.findUpstreams("1.1.1.1")
|
||||
require.NotNil(t, config)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, config.Upstreams, 1)
|
||||
assert.Len(t, config.DomainReservedUpstreams, 1)
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
||||
@@ -106,7 +107,7 @@ func isRunning() bool {
|
||||
}
|
||||
|
||||
func onDNSRequest(d *proxy.DNSContext) {
|
||||
ip := dnsforward.IPFromAddr(d.Addr)
|
||||
ip := aghnet.IPFromAddr(d.Addr)
|
||||
if ip == nil {
|
||||
// This would be quite weird if we get here.
|
||||
return
|
||||
@@ -197,7 +198,7 @@ func generateServerConfig() (newConf dnsforward.ServerConfig, err error) {
|
||||
newConf.TLSAllowUnencryptedDOH = tlsConf.AllowUnencryptedDOH
|
||||
|
||||
newConf.FilterHandler = applyAdditionalFiltering
|
||||
newConf.GetCustomUpstreamByClient = Context.clients.FindUpstreams
|
||||
newConf.GetCustomUpstreamByClient = Context.clients.findUpstreams
|
||||
|
||||
newConf.ResolveClients = dnsConf.ResolveClients
|
||||
newConf.UsePrivateRDNS = dnsConf.UsePrivateRDNS
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghstrings"
|
||||
"github.com/AdguardTeam/golibs/cache"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
@@ -66,19 +67,6 @@ func trimValue(s string) string {
|
||||
return s[:maxValueLength-3] + "..."
|
||||
}
|
||||
|
||||
// coalesceStr returns the first non-empty string.
|
||||
//
|
||||
// TODO(a.garipov): Move to aghstrings?
|
||||
func coalesceStr(strs ...string) (res string) {
|
||||
for _, s := range strs {
|
||||
if s != "" {
|
||||
return s
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// isWhoisComment returns true if the string is empty or is a WHOIS comment.
|
||||
func isWhoisComment(s string) (ok bool) {
|
||||
return len(s) == 0 || s[0] == '#' || s[0] == '%'
|
||||
@@ -119,7 +107,7 @@ func whoisParse(data string) (m strmap) {
|
||||
v = trimValue(v)
|
||||
case "descr", "netname":
|
||||
k = "orgname"
|
||||
v = coalesceStr(orgname, v)
|
||||
v = aghstrings.Coalesce(orgname, v)
|
||||
orgname = v
|
||||
case "whois":
|
||||
k = "whois"
|
||||
|
||||
Reference in New Issue
Block a user