Pull request: all: mv some utilities to netutil

Merge in DNS/adguard-home from mv-netutil to master

Squashed commit of the following:

commit 5698fceed656dca7f8644e7dbd7e1a7fc57a68ce
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Mon Aug 9 15:44:17 2021 +0300

    dnsforward: add todos

commit 122fb6e3de658b296931e0f608cf24ef85547666
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Mon Aug 9 14:27:46 2021 +0300

    all: mv some utilities to netutil
This commit is contained in:
Ainar Garipov
2021-08-09 16:03:37 +03:00
parent 133a1c57b1
commit 9bd7294fad
37 changed files with 242 additions and 1033 deletions

View File

@@ -7,8 +7,8 @@ import (
"net/http"
"strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/AdguardTeam/urlfilter"
"github.com/AdguardTeam/urlfilter/filterlist"
@@ -17,8 +17,8 @@ import (
// accessCtx controls IP and client blocking that takes place before all other
// processing. An accessCtx is safe for concurrent use.
type accessCtx struct {
allowedIPs *aghnet.IPMap
blockedIPs *aghnet.IPMap
allowedIPs *netutil.IPMap
blockedIPs *netutil.IPMap
allowedClientIDs *stringutil.Set
blockedClientIDs *stringutil.Set
@@ -26,7 +26,7 @@ type accessCtx struct {
blockedHostsEng *urlfilter.DNSEngine
// TODO(a.garipov): Create a type for a set of IP networks.
// aghnet.IPNetSet?
// netutil.IPNetSet?
allowedNets []*net.IPNet
blockedNets []*net.IPNet
}
@@ -38,7 +38,7 @@ type unit = struct{}
// which may be an IP address, a CIDR, or a ClientID.
func processAccessClients(
clientStrs []string,
ips *aghnet.IPMap,
ips *netutil.IPMap,
nets *[]*net.IPNet,
clientIDs *stringutil.Set,
) (err error) {
@@ -68,8 +68,8 @@ func processAccessClients(
// newAccessCtx creates a new accessCtx.
func newAccessCtx(allowed, blocked, blockedHosts []string) (a *accessCtx, err error) {
a = &accessCtx{
allowedIPs: aghnet.NewIPMap(0),
blockedIPs: aghnet.NewIPMap(0),
allowedIPs: netutil.NewIPMap(0),
blockedIPs: netutil.NewIPMap(0),
allowedClientIDs: stringutil.NewSet(),
blockedClientIDs: stringutil.NewSet(),

View File

@@ -7,15 +7,15 @@ import (
"path"
"strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
"github.com/lucas-clemente/quic-go"
)
// ValidateClientID returns an error if clientID is not a valid client ID.
func ValidateClientID(clientID string) (err error) {
err = aghnet.ValidateDomainNameLabel(clientID)
err = netutil.ValidateDomainNameLabel(clientID)
if err != nil {
// Replace the domain name label wrapper with our own.
return fmt.Errorf("invalid client id %q: %w", clientID, errors.Unwrap(err))

View File

@@ -46,6 +46,8 @@ func (c testQUICSession) ConnectionState() (cs quic.ConnectionState) {
}
func TestServer_clientIDFromDNSContext(t *testing.T) {
// TODO(a.garipov): Consider moving away from the text-based error
// checks and onto a more structured approach.
testCases := []struct {
name string
proto proxy.Proto
@@ -111,7 +113,7 @@ func TestServer_clientIDFromDNSContext(t *testing.T) {
cliSrvName: "!!!.example.com",
wantClientID: "",
wantErrMsg: `client id check: invalid client id "!!!": ` +
`invalid char '!' at index 0`,
`bad domain name label rune '!'`,
strictSNI: true,
}, {
name: "tls_client_id_too_long",
@@ -122,7 +124,7 @@ func TestServer_clientIDFromDNSContext(t *testing.T) {
wantClientID: "",
wantErrMsg: `client id check: invalid client id "abcdefghijklmno` +
`pqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz0123456789": ` +
`label is too long, max: 63`,
`domain name label is too long: got 72, max 63`,
strictSNI: true,
}, {
name: "quic_client_id",
@@ -220,7 +222,7 @@ func TestClientIDFromDNSContextHTTPS(t *testing.T) {
path: "/dns-query/!!!",
wantClientID: "",
wantErrMsg: `client id check: invalid client id "!!!": ` +
`invalid char '!' at index 0`,
`bad domain name label rune '!'`,
}}
for _, tc := range testCases {

View File

@@ -11,12 +11,12 @@ import (
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/ameshkov/dnscrypt/v2"
)
@@ -451,7 +451,7 @@ func matchesDomainWildcard(host, pat string) (ok bool) {
// anyNameMatches returns true if sni, the client's SNI value, matches any of
// the DNS names and patterns from certificate. dnsNames must be sorted.
func anyNameMatches(dnsNames []string, sni string) (ok bool) {
if aghnet.ValidateDomainName(sni) != nil {
if netutil.ValidateDomainName(sni) != nil {
return false
}

View File

@@ -5,11 +5,11 @@ import (
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/miekg/dns"
)
@@ -165,7 +165,7 @@ func (s *Server) setTableHostToIP(t hostToIPTable) {
s.tableHostToIP = t
}
func (s *Server) setTableIPToHost(t *aghnet.IPMap) {
func (s *Server) setTableIPToHost(t *netutil.IPMap) {
s.tableIPToHostLock.Lock()
defer s.tableIPToHostLock.Unlock()
@@ -188,18 +188,18 @@ func (s *Server) onDHCPLeaseChanged(flags int) {
}
var hostToIP hostToIPTable
var ipToHost *aghnet.IPMap
var ipToHost *netutil.IPMap
if add {
ll := s.dhcpServer.Leases(dhcpd.LeasesAll)
hostToIP = make(hostToIPTable, len(ll))
ipToHost = aghnet.NewIPMap(len(ll))
ipToHost = netutil.NewIPMap(len(ll))
for _, l := range ll {
// TODO(a.garipov): Remove this after we're finished
// with the client hostname validations in the DHCP
// server code.
err = aghnet.ValidateDomainName(l.Hostname)
err = netutil.ValidateDomainName(l.Hostname)
if err != nil {
log.Debug(
"dns: skipping invalid hostname %q from dhcp: %s",
@@ -230,7 +230,7 @@ func (s *Server) processDetermineLocal(dctx *dnsContext) (rc resultCode) {
rc = resultCodeSuccess
var ip net.IP
if ip = aghnet.IPFromAddr(dctx.proxyCtx.Addr); ip == nil {
if ip, _ = netutil.IPAndPortFromAddr(dctx.proxyCtx.Addr); ip == nil {
return rc
}
@@ -331,12 +331,11 @@ func (s *Server) processRestrictLocal(ctx *dnsContext) (rc resultCode) {
return resultCodeSuccess
}
ip := aghnet.UnreverseAddr(q.Name)
if ip == nil {
// That's weird.
//
// TODO(e.burkov): Research the cases when it could happen.
return resultCodeSuccess
ip, err := netutil.IPFromReversedAddr(q.Name)
if err != nil {
log.Debug("dns: reversed addr: %s", err)
return resultCodeError
}
// Restrict an access to local addresses for external clients. We also
@@ -502,7 +501,7 @@ func processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode) {
// ipStringFromAddr extracts an IP address string from net.Addr.
func ipStringFromAddr(addr net.Addr) (ipStr string) {
if ip := aghnet.IPFromAddr(addr); ip != nil {
if ip, _ := netutil.IPAndPortFromAddr(addr); ip != nil {
return ip.String()
}

View File

@@ -20,6 +20,7 @@ import (
"github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/miekg/dns"
)
@@ -81,7 +82,7 @@ type Server struct {
tableHostToIP hostToIPTable
tableHostToIPLock sync.Mutex
tableIPToHost *aghnet.IPMap
tableIPToHost *netutil.IPMap
tableIPToHostLock sync.Mutex
// clientIDCache is a temporary storage for clientIDs that were
@@ -141,7 +142,7 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
if p.LocalDomain == "" {
localDomainSuffix = defaultLocalDomainSuffix
} else {
err = aghnet.ValidateDomainName(p.LocalDomain)
err = netutil.ValidateDomainName(p.LocalDomain)
if err != nil {
return nil, fmt.Errorf("local domain: %w", err)
}
@@ -281,7 +282,12 @@ func (s *Server) Exchange(ip net.IP) (host string, err error) {
return "", nil
}
arpa := dns.Fqdn(aghnet.ReverseAddr(ip))
arpa, err := netutil.IPToReversedAddr(ip)
if err != nil {
return "", fmt.Errorf("reversing ip: %w", err)
}
arpa = dns.Fqdn(arpa)
req := &dns.Msg{
MsgHdr: dns.MsgHdr{
Id: dns.Id(),

View File

@@ -1119,6 +1119,8 @@ func TestPTRResponseFromHosts(t *testing.T) {
}
func TestNewServer(t *testing.T) {
// TODO(a.garipov): Consider moving away from the text-based error
// checks and onto a more structured approach.
testCases := []struct {
name string
in DNSCreateParams
@@ -1144,9 +1146,8 @@ func TestNewServer(t *testing.T) {
in: DNSCreateParams{
LocalDomain: "!!!",
},
wantErrMsg: `local domain: validating domain name "!!!": ` +
`invalid domain name label at index 0: ` +
`validating label "!!!": invalid char '!' at index 0`,
wantErrMsg: `local domain: bad domain name "!!!": ` +
`bad domain name label "!!!": bad domain name label rune '!'`,
}}
for _, tc := range testCases {

View File

@@ -5,10 +5,10 @@ import (
"fmt"
"strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns"
)
@@ -19,7 +19,7 @@ func (s *Server) beforeRequestHandler(
_ *proxy.Proxy,
pctx *proxy.DNSContext,
) (reply bool, err error) {
ip := aghnet.IPFromAddr(pctx.Addr)
ip, _ := netutil.IPAndPortFromAddr(pctx.Addr)
clientID, err := s.clientIDFromDNSContext(pctx)
if err != nil {
return false, fmt.Errorf("getting clientid: %w", err)
@@ -53,7 +53,8 @@ func (s *Server) beforeRequestHandler(
func (s *Server) getClientRequestFilteringSettings(ctx *dnsContext) *filtering.Settings {
setts := s.dnsFilter.GetConfig()
if s.conf.FilterHandler != nil {
s.conf.FilterHandler(aghnet.IPFromAddr(ctx.proxyCtx.Addr), ctx.clientID, &setts)
ip, _ := netutil.IPAndPortFromAddr(ctx.proxyCtx.Addr)
s.conf.FilterHandler(ip, ctx.clientID, &setts)
}
return &setts

View File

@@ -9,11 +9,11 @@ import (
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/miekg/dns"
)
@@ -443,7 +443,7 @@ func separateUpstream(upstreamStr string) (upstream string, useDefault bool, err
continue
}
err = aghnet.ValidateDomainName(host)
err = netutil.ValidateDomainName(host)
if err != nil {
return "", false, fmt.Errorf("domain at index %d: %w", i, err)
}

View File

@@ -5,9 +5,9 @@ import (
"encoding/binary"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns"
)
@@ -77,7 +77,7 @@ func newRecursionDetector(ttl time.Duration, suspectsNum uint) (rd *recursionDet
// msgToSignature converts msg into it's signature represented in bytes.
func msgToSignature(msg dns.Msg) (sig []byte) {
sig = make([]byte, uint16sz*2+aghnet.MaxDomainNameLen)
sig = make([]byte, uint16sz*2+netutil.MaxDomainNameLen)
// The binary.BigEndian byte order is used everywhere except when the
// real machine's endianess is needed.
byteOrder := binary.BigEndian
@@ -95,7 +95,7 @@ func msgToSignature(msg dns.Msg) (sig []byte) {
// See BenchmarkMsgToSignature.
func msgToSignatureSlow(msg dns.Msg) (sig []byte) {
type msgSignature struct {
name [aghnet.MaxDomainNameLen]byte
name [netutil.MaxDomainNameLen]byte
id uint16
qtype uint16
}

View File

@@ -4,11 +4,11 @@ import (
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
"github.com/AdguardTeam/AdGuardHome/internal/stats"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns"
)
@@ -32,13 +32,14 @@ func processQueryLogsAndStats(ctx *dnsContext) (rc resultCode) {
// Synchronize access to s.queryLog and s.stats so they won't be suddenly uninitialized while in use.
// This can happen after proxy server has been stopped, but its workers haven't yet exited.
if shouldLog && s.queryLog != nil {
ip, _ := netutil.IPAndPortFromAddr(pctx.Addr)
p := querylog.AddParams{
Question: msg,
Answer: pctx.Res,
OrigAnswer: ctx.origResp,
Result: ctx.result,
Elapsed: elapsed,
ClientIP: aghnet.IPFromAddr(pctx.Addr),
ClientIP: ip,
ClientID: ctx.clientID,
}
@@ -80,7 +81,7 @@ func (s *Server) updateStats(ctx *dnsContext, elapsed time.Duration, res filteri
if clientID := ctx.clientID; clientID != "" {
e.Client = clientID
} else if ip := aghnet.IPFromAddr(pctx.Addr); ip != nil {
} else if ip, _ := netutil.IPAndPortFromAddr(pctx.Addr); ip != nil {
e.Client = ip.String()
}