Pull request: all: fix client upstreams, imp code

Updates #3186.

Squashed commit of the following:

commit a8dd0e2cda3039839d069fe71a5bd0f9635ec064
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Fri May 28 12:54:07 2021 +0300

    all: imp code, names

commit 98f86c21ae23b665095075feb4a59dcfcc622bc7
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Thu May 27 21:11:37 2021 +0300

    all: fix client upstreams, imp code
This commit is contained in:
Ainar Garipov
2021-05-28 13:02:59 +03:00
parent 48b8579703
commit 3be783bd34
18 changed files with 249 additions and 270 deletions

View File

@@ -335,37 +335,44 @@ func (clients *clientsContainer) Find(id string) (c *Client, ok bool) {
return c, true
}
// FindUpstreams looks for upstreams configured for the client
// If no client found for this IP, or if no custom upstreams are configured,
// this method returns nil
func (clients *clientsContainer) FindUpstreams(ip string) *proxy.UpstreamConfig {
// findUpstreams returns upstreams configured for the client, identified either
// by its IP address or its ClientID. upsConf is nil if the client isn't found
// or if the client has no custom upstreams.
func (clients *clientsContainer) findUpstreams(
id string,
) (upsConf *proxy.UpstreamConfig, err error) {
clients.lock.Lock()
defer clients.lock.Unlock()
c, ok := clients.findLocked(ip)
c, ok := clients.findLocked(id)
if !ok {
return nil
return nil, nil
}
upstreams := aghstrings.FilterOut(c.Upstreams, aghstrings.IsCommentOrEmpty)
if len(upstreams) == 0 {
return nil
return nil, nil
}
if c.upstreamConfig == nil {
conf, err := proxy.ParseUpstreamsConfig(
upstreams,
upstream.Options{
Bootstrap: config.DNS.BootstrapDNS,
Timeout: dnsforward.DefaultTimeout,
},
)
if err == nil {
c.upstreamConfig = &conf
}
if c.upstreamConfig != nil {
return c.upstreamConfig, nil
}
return c.upstreamConfig
var conf proxy.UpstreamConfig
conf, err = proxy.ParseUpstreamsConfig(
upstreams,
upstream.Options{
Bootstrap: config.DNS.BootstrapDNS,
Timeout: dnsforward.DefaultTimeout,
},
)
if err != nil {
return nil, err
}
c.upstreamConfig = &conf
return &conf, nil
}
// findLocked searches for a client by its ID. For internal use only.

View File

@@ -25,7 +25,7 @@ func TestClients(t *testing.T) {
}
ok, err := clients.Add(c)
require.Nil(t, err)
require.NoError(t, err)
assert.True(t, ok)
c = &Client{
@@ -34,7 +34,7 @@ func TestClients(t *testing.T) {
}
ok, err = clients.Add(c)
require.Nil(t, err)
require.NoError(t, err)
assert.True(t, ok)
c, ok = clients.Find("1.1.1.1")
@@ -59,7 +59,7 @@ func TestClients(t *testing.T) {
IDs: []string{"1.2.3.5"},
Name: "client1",
})
require.Nil(t, err)
require.NoError(t, err)
assert.False(t, ok)
})
@@ -68,7 +68,7 @@ func TestClients(t *testing.T) {
IDs: []string{"2.2.2.2"},
Name: "client3",
})
require.NotNil(t, err)
require.Error(t, err)
assert.False(t, ok)
})
@@ -77,13 +77,13 @@ func TestClients(t *testing.T) {
IDs: []string{"1.2.3.0"},
Name: "client3",
})
require.NotNil(t, err)
require.Error(t, err)
err = clients.Update("client3", &Client{
IDs: []string{"1.2.3.0"},
Name: "client2",
})
assert.NotNil(t, err)
assert.Error(t, err)
})
t.Run("update_fail_ip", func(t *testing.T) {
@@ -91,7 +91,7 @@ func TestClients(t *testing.T) {
IDs: []string{"2.2.2.2"},
Name: "client1",
})
assert.NotNil(t, err)
assert.Error(t, err)
})
t.Run("update_success", func(t *testing.T) {
@@ -99,7 +99,7 @@ func TestClients(t *testing.T) {
IDs: []string{"1.1.1.2"},
Name: "client1",
})
require.Nil(t, err)
require.NoError(t, err)
assert.False(t, clients.Exists("1.1.1.1", ClientSourceHostsFile))
assert.True(t, clients.Exists("1.1.1.2", ClientSourceHostsFile))
@@ -109,7 +109,7 @@ func TestClients(t *testing.T) {
Name: "client1-renamed",
UseOwnSettings: true,
})
require.Nil(t, err)
require.NoError(t, err)
c, ok := clients.Find("1.1.1.2")
require.True(t, ok)
@@ -137,15 +137,15 @@ func TestClients(t *testing.T) {
t.Run("addhost_success", func(t *testing.T) {
ok, err := clients.AddHost("1.1.1.1", "host", ClientSourceARP)
require.Nil(t, err)
require.NoError(t, err)
assert.True(t, ok)
ok, err = clients.AddHost("1.1.1.1", "host2", ClientSourceARP)
require.Nil(t, err)
require.NoError(t, err)
assert.True(t, ok)
ok, err = clients.AddHost("1.1.1.1", "host3", ClientSourceHostsFile)
require.Nil(t, err)
require.NoError(t, err)
assert.True(t, ok)
assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile))
@@ -153,7 +153,7 @@ func TestClients(t *testing.T) {
t.Run("addhost_fail", func(t *testing.T) {
ok, err := clients.AddHost("1.1.1.1", "host1", ClientSourceRDNS)
require.Nil(t, err)
require.NoError(t, err)
assert.False(t, ok)
})
}
@@ -181,7 +181,7 @@ func TestClientsWhois(t *testing.T) {
t.Run("existing_auto-client", func(t *testing.T) {
ok, err := clients.AddHost("1.1.1.1", "host", ClientSourceRDNS)
require.Nil(t, err)
require.NoError(t, err)
assert.True(t, ok)
clients.SetWhoisInfo("1.1.1.1", whois)
@@ -198,7 +198,7 @@ func TestClientsWhois(t *testing.T) {
IDs: []string{"1.1.1.2"},
Name: "client1",
})
require.Nil(t, err)
require.NoError(t, err)
assert.True(t, ok)
clients.SetWhoisInfo("1.1.1.2", whois)
@@ -219,12 +219,12 @@ func TestClientsAddExisting(t *testing.T) {
IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa", "2.2.2.0/24"},
Name: "client1",
})
require.Nil(t, err)
require.NoError(t, err)
assert.True(t, ok)
// Now add an auto-client with the same IP.
ok, err = clients.AddHost("1.1.1.1", "test", ClientSourceRDNS)
require.Nil(t, err)
require.NoError(t, err)
assert.True(t, ok)
})
@@ -253,14 +253,14 @@ func TestClientsAddExisting(t *testing.T) {
Hostname: "testhost",
Expiry: time.Now().Add(time.Hour),
})
require.Nil(t, err)
require.NoError(t, err)
// Add a new client with the same IP as for a client with MAC.
ok, err := clients.Add(&Client{
IDs: []string{testIP.String()},
Name: "client2",
})
require.Nil(t, err)
require.NoError(t, err)
assert.True(t, ok)
// Add a new client with the IP from the first client's IP
@@ -269,7 +269,7 @@ func TestClientsAddExisting(t *testing.T) {
IDs: []string{"2.2.2.2"},
Name: "client3",
})
require.Nil(t, err)
require.NoError(t, err)
assert.True(t, ok)
})
}
@@ -289,14 +289,16 @@ func TestClientsCustomUpstream(t *testing.T) {
"[/example.org/]8.8.8.8",
},
})
require.Nil(t, err)
require.NoError(t, err)
assert.True(t, ok)
config := clients.FindUpstreams("1.2.3.4")
config, err := clients.findUpstreams("1.2.3.4")
assert.Nil(t, config)
assert.NoError(t, err)
config = clients.FindUpstreams("1.1.1.1")
config, err = clients.findUpstreams("1.1.1.1")
require.NotNil(t, config)
assert.NoError(t, err)
assert.Len(t, config.Upstreams, 1)
assert.Len(t, config.DomainReservedUpstreams, 1)
}

View File

@@ -8,6 +8,7 @@ import (
"path/filepath"
"strconv"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
@@ -106,7 +107,7 @@ func isRunning() bool {
}
func onDNSRequest(d *proxy.DNSContext) {
ip := dnsforward.IPFromAddr(d.Addr)
ip := aghnet.IPFromAddr(d.Addr)
if ip == nil {
// This would be quite weird if we get here.
return
@@ -197,7 +198,7 @@ func generateServerConfig() (newConf dnsforward.ServerConfig, err error) {
newConf.TLSAllowUnencryptedDOH = tlsConf.AllowUnencryptedDOH
newConf.FilterHandler = applyAdditionalFiltering
newConf.GetCustomUpstreamByClient = Context.clients.FindUpstreams
newConf.GetCustomUpstreamByClient = Context.clients.findUpstreams
newConf.ResolveClients = dnsConf.ResolveClients
newConf.UsePrivateRDNS = dnsConf.UsePrivateRDNS

View File

@@ -10,6 +10,7 @@ import (
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
"github.com/AdguardTeam/AdGuardHome/internal/aghstrings"
"github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
@@ -66,19 +67,6 @@ func trimValue(s string) string {
return s[:maxValueLength-3] + "..."
}
// coalesceStr returns the first non-empty string.
//
// TODO(a.garipov): Move to aghstrings?
func coalesceStr(strs ...string) (res string) {
for _, s := range strs {
if s != "" {
return s
}
}
return ""
}
// isWhoisComment returns true if the string is empty or is a WHOIS comment.
func isWhoisComment(s string) (ok bool) {
return len(s) == 0 || s[0] == '#' || s[0] == '%'
@@ -119,7 +107,7 @@ func whoisParse(data string) (m strmap) {
v = trimValue(v)
case "descr", "netname":
k = "orgname"
v = coalesceStr(orgname, v)
v = aghstrings.Coalesce(orgname, v)
orgname = v
case "whois":
k = "whois"