Pull request: all: mv some utilities to netutil
Merge in DNS/adguard-home from mv-netutil to master Squashed commit of the following: commit 5698fceed656dca7f8644e7dbd7e1a7fc57a68ce Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Mon Aug 9 15:44:17 2021 +0300 dnsforward: add todos commit 122fb6e3de658b296931e0f608cf24ef85547666 Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Mon Aug 9 14:27:46 2021 +0300 all: mv some utilities to netutil
This commit is contained in:
@@ -7,8 +7,8 @@ import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/AdguardTeam/urlfilter"
|
||||
"github.com/AdguardTeam/urlfilter/filterlist"
|
||||
@@ -17,8 +17,8 @@ import (
|
||||
// accessCtx controls IP and client blocking that takes place before all other
|
||||
// processing. An accessCtx is safe for concurrent use.
|
||||
type accessCtx struct {
|
||||
allowedIPs *aghnet.IPMap
|
||||
blockedIPs *aghnet.IPMap
|
||||
allowedIPs *netutil.IPMap
|
||||
blockedIPs *netutil.IPMap
|
||||
|
||||
allowedClientIDs *stringutil.Set
|
||||
blockedClientIDs *stringutil.Set
|
||||
@@ -26,7 +26,7 @@ type accessCtx struct {
|
||||
blockedHostsEng *urlfilter.DNSEngine
|
||||
|
||||
// TODO(a.garipov): Create a type for a set of IP networks.
|
||||
// aghnet.IPNetSet?
|
||||
// netutil.IPNetSet?
|
||||
allowedNets []*net.IPNet
|
||||
blockedNets []*net.IPNet
|
||||
}
|
||||
@@ -38,7 +38,7 @@ type unit = struct{}
|
||||
// which may be an IP address, a CIDR, or a ClientID.
|
||||
func processAccessClients(
|
||||
clientStrs []string,
|
||||
ips *aghnet.IPMap,
|
||||
ips *netutil.IPMap,
|
||||
nets *[]*net.IPNet,
|
||||
clientIDs *stringutil.Set,
|
||||
) (err error) {
|
||||
@@ -68,8 +68,8 @@ func processAccessClients(
|
||||
// newAccessCtx creates a new accessCtx.
|
||||
func newAccessCtx(allowed, blocked, blockedHosts []string) (a *accessCtx, err error) {
|
||||
a = &accessCtx{
|
||||
allowedIPs: aghnet.NewIPMap(0),
|
||||
blockedIPs: aghnet.NewIPMap(0),
|
||||
allowedIPs: netutil.NewIPMap(0),
|
||||
blockedIPs: netutil.NewIPMap(0),
|
||||
|
||||
allowedClientIDs: stringutil.NewSet(),
|
||||
blockedClientIDs: stringutil.NewSet(),
|
||||
|
||||
@@ -7,15 +7,15 @@ import (
|
||||
"path"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
)
|
||||
|
||||
// ValidateClientID returns an error if clientID is not a valid client ID.
|
||||
func ValidateClientID(clientID string) (err error) {
|
||||
err = aghnet.ValidateDomainNameLabel(clientID)
|
||||
err = netutil.ValidateDomainNameLabel(clientID)
|
||||
if err != nil {
|
||||
// Replace the domain name label wrapper with our own.
|
||||
return fmt.Errorf("invalid client id %q: %w", clientID, errors.Unwrap(err))
|
||||
|
||||
@@ -46,6 +46,8 @@ func (c testQUICSession) ConnectionState() (cs quic.ConnectionState) {
|
||||
}
|
||||
|
||||
func TestServer_clientIDFromDNSContext(t *testing.T) {
|
||||
// TODO(a.garipov): Consider moving away from the text-based error
|
||||
// checks and onto a more structured approach.
|
||||
testCases := []struct {
|
||||
name string
|
||||
proto proxy.Proto
|
||||
@@ -111,7 +113,7 @@ func TestServer_clientIDFromDNSContext(t *testing.T) {
|
||||
cliSrvName: "!!!.example.com",
|
||||
wantClientID: "",
|
||||
wantErrMsg: `client id check: invalid client id "!!!": ` +
|
||||
`invalid char '!' at index 0`,
|
||||
`bad domain name label rune '!'`,
|
||||
strictSNI: true,
|
||||
}, {
|
||||
name: "tls_client_id_too_long",
|
||||
@@ -122,7 +124,7 @@ func TestServer_clientIDFromDNSContext(t *testing.T) {
|
||||
wantClientID: "",
|
||||
wantErrMsg: `client id check: invalid client id "abcdefghijklmno` +
|
||||
`pqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz0123456789": ` +
|
||||
`label is too long, max: 63`,
|
||||
`domain name label is too long: got 72, max 63`,
|
||||
strictSNI: true,
|
||||
}, {
|
||||
name: "quic_client_id",
|
||||
@@ -220,7 +222,7 @@ func TestClientIDFromDNSContextHTTPS(t *testing.T) {
|
||||
path: "/dns-query/!!!",
|
||||
wantClientID: "",
|
||||
wantErrMsg: `client id check: invalid client id "!!!": ` +
|
||||
`invalid char '!' at index 0`,
|
||||
`bad domain name label rune '!'`,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
|
||||
@@ -11,12 +11,12 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/ameshkov/dnscrypt/v2"
|
||||
)
|
||||
@@ -451,7 +451,7 @@ func matchesDomainWildcard(host, pat string) (ok bool) {
|
||||
// anyNameMatches returns true if sni, the client's SNI value, matches any of
|
||||
// the DNS names and patterns from certificate. dnsNames must be sorted.
|
||||
func anyNameMatches(dnsNames []string, sni string) (ok bool) {
|
||||
if aghnet.ValidateDomainName(sni) != nil {
|
||||
if netutil.ValidateDomainName(sni) != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
|
||||
@@ -5,11 +5,11 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
@@ -165,7 +165,7 @@ func (s *Server) setTableHostToIP(t hostToIPTable) {
|
||||
s.tableHostToIP = t
|
||||
}
|
||||
|
||||
func (s *Server) setTableIPToHost(t *aghnet.IPMap) {
|
||||
func (s *Server) setTableIPToHost(t *netutil.IPMap) {
|
||||
s.tableIPToHostLock.Lock()
|
||||
defer s.tableIPToHostLock.Unlock()
|
||||
|
||||
@@ -188,18 +188,18 @@ func (s *Server) onDHCPLeaseChanged(flags int) {
|
||||
}
|
||||
|
||||
var hostToIP hostToIPTable
|
||||
var ipToHost *aghnet.IPMap
|
||||
var ipToHost *netutil.IPMap
|
||||
if add {
|
||||
ll := s.dhcpServer.Leases(dhcpd.LeasesAll)
|
||||
|
||||
hostToIP = make(hostToIPTable, len(ll))
|
||||
ipToHost = aghnet.NewIPMap(len(ll))
|
||||
ipToHost = netutil.NewIPMap(len(ll))
|
||||
|
||||
for _, l := range ll {
|
||||
// TODO(a.garipov): Remove this after we're finished
|
||||
// with the client hostname validations in the DHCP
|
||||
// server code.
|
||||
err = aghnet.ValidateDomainName(l.Hostname)
|
||||
err = netutil.ValidateDomainName(l.Hostname)
|
||||
if err != nil {
|
||||
log.Debug(
|
||||
"dns: skipping invalid hostname %q from dhcp: %s",
|
||||
@@ -230,7 +230,7 @@ func (s *Server) processDetermineLocal(dctx *dnsContext) (rc resultCode) {
|
||||
rc = resultCodeSuccess
|
||||
|
||||
var ip net.IP
|
||||
if ip = aghnet.IPFromAddr(dctx.proxyCtx.Addr); ip == nil {
|
||||
if ip, _ = netutil.IPAndPortFromAddr(dctx.proxyCtx.Addr); ip == nil {
|
||||
return rc
|
||||
}
|
||||
|
||||
@@ -331,12 +331,11 @@ func (s *Server) processRestrictLocal(ctx *dnsContext) (rc resultCode) {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
ip := aghnet.UnreverseAddr(q.Name)
|
||||
if ip == nil {
|
||||
// That's weird.
|
||||
//
|
||||
// TODO(e.burkov): Research the cases when it could happen.
|
||||
return resultCodeSuccess
|
||||
ip, err := netutil.IPFromReversedAddr(q.Name)
|
||||
if err != nil {
|
||||
log.Debug("dns: reversed addr: %s", err)
|
||||
|
||||
return resultCodeError
|
||||
}
|
||||
|
||||
// Restrict an access to local addresses for external clients. We also
|
||||
@@ -502,7 +501,7 @@ func processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode) {
|
||||
|
||||
// ipStringFromAddr extracts an IP address string from net.Addr.
|
||||
func ipStringFromAddr(addr net.Addr) (ipStr string) {
|
||||
if ip := aghnet.IPFromAddr(addr); ip != nil {
|
||||
if ip, _ := netutil.IPAndPortFromAddr(addr); ip != nil {
|
||||
return ip.String()
|
||||
}
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"github.com/AdguardTeam/golibs/cache"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
@@ -81,7 +82,7 @@ type Server struct {
|
||||
tableHostToIP hostToIPTable
|
||||
tableHostToIPLock sync.Mutex
|
||||
|
||||
tableIPToHost *aghnet.IPMap
|
||||
tableIPToHost *netutil.IPMap
|
||||
tableIPToHostLock sync.Mutex
|
||||
|
||||
// clientIDCache is a temporary storage for clientIDs that were
|
||||
@@ -141,7 +142,7 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
|
||||
if p.LocalDomain == "" {
|
||||
localDomainSuffix = defaultLocalDomainSuffix
|
||||
} else {
|
||||
err = aghnet.ValidateDomainName(p.LocalDomain)
|
||||
err = netutil.ValidateDomainName(p.LocalDomain)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("local domain: %w", err)
|
||||
}
|
||||
@@ -281,7 +282,12 @@ func (s *Server) Exchange(ip net.IP) (host string, err error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
arpa := dns.Fqdn(aghnet.ReverseAddr(ip))
|
||||
arpa, err := netutil.IPToReversedAddr(ip)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("reversing ip: %w", err)
|
||||
}
|
||||
|
||||
arpa = dns.Fqdn(arpa)
|
||||
req := &dns.Msg{
|
||||
MsgHdr: dns.MsgHdr{
|
||||
Id: dns.Id(),
|
||||
|
||||
@@ -1119,6 +1119,8 @@ func TestPTRResponseFromHosts(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNewServer(t *testing.T) {
|
||||
// TODO(a.garipov): Consider moving away from the text-based error
|
||||
// checks and onto a more structured approach.
|
||||
testCases := []struct {
|
||||
name string
|
||||
in DNSCreateParams
|
||||
@@ -1144,9 +1146,8 @@ func TestNewServer(t *testing.T) {
|
||||
in: DNSCreateParams{
|
||||
LocalDomain: "!!!",
|
||||
},
|
||||
wantErrMsg: `local domain: validating domain name "!!!": ` +
|
||||
`invalid domain name label at index 0: ` +
|
||||
`validating label "!!!": invalid char '!' at index 0`,
|
||||
wantErrMsg: `local domain: bad domain name "!!!": ` +
|
||||
`bad domain name label "!!!": bad domain name label rune '!'`,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
|
||||
@@ -5,10 +5,10 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
@@ -19,7 +19,7 @@ func (s *Server) beforeRequestHandler(
|
||||
_ *proxy.Proxy,
|
||||
pctx *proxy.DNSContext,
|
||||
) (reply bool, err error) {
|
||||
ip := aghnet.IPFromAddr(pctx.Addr)
|
||||
ip, _ := netutil.IPAndPortFromAddr(pctx.Addr)
|
||||
clientID, err := s.clientIDFromDNSContext(pctx)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("getting clientid: %w", err)
|
||||
@@ -53,7 +53,8 @@ func (s *Server) beforeRequestHandler(
|
||||
func (s *Server) getClientRequestFilteringSettings(ctx *dnsContext) *filtering.Settings {
|
||||
setts := s.dnsFilter.GetConfig()
|
||||
if s.conf.FilterHandler != nil {
|
||||
s.conf.FilterHandler(aghnet.IPFromAddr(ctx.proxyCtx.Addr), ctx.clientID, &setts)
|
||||
ip, _ := netutil.IPAndPortFromAddr(ctx.proxyCtx.Addr)
|
||||
s.conf.FilterHandler(ip, ctx.clientID, &setts)
|
||||
}
|
||||
|
||||
return &setts
|
||||
|
||||
@@ -9,11 +9,11 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
@@ -443,7 +443,7 @@ func separateUpstream(upstreamStr string) (upstream string, useDefault bool, err
|
||||
continue
|
||||
}
|
||||
|
||||
err = aghnet.ValidateDomainName(host)
|
||||
err = netutil.ValidateDomainName(host)
|
||||
if err != nil {
|
||||
return "", false, fmt.Errorf("domain at index %d: %w", i, err)
|
||||
}
|
||||
|
||||
@@ -5,9 +5,9 @@ import (
|
||||
"encoding/binary"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/golibs/cache"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
@@ -77,7 +77,7 @@ func newRecursionDetector(ttl time.Duration, suspectsNum uint) (rd *recursionDet
|
||||
|
||||
// msgToSignature converts msg into it's signature represented in bytes.
|
||||
func msgToSignature(msg dns.Msg) (sig []byte) {
|
||||
sig = make([]byte, uint16sz*2+aghnet.MaxDomainNameLen)
|
||||
sig = make([]byte, uint16sz*2+netutil.MaxDomainNameLen)
|
||||
// The binary.BigEndian byte order is used everywhere except when the
|
||||
// real machine's endianess is needed.
|
||||
byteOrder := binary.BigEndian
|
||||
@@ -95,7 +95,7 @@ func msgToSignature(msg dns.Msg) (sig []byte) {
|
||||
// See BenchmarkMsgToSignature.
|
||||
func msgToSignatureSlow(msg dns.Msg) (sig []byte) {
|
||||
type msgSignature struct {
|
||||
name [aghnet.MaxDomainNameLen]byte
|
||||
name [netutil.MaxDomainNameLen]byte
|
||||
id uint16
|
||||
qtype uint16
|
||||
}
|
||||
|
||||
@@ -4,11 +4,11 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
@@ -32,13 +32,14 @@ func processQueryLogsAndStats(ctx *dnsContext) (rc resultCode) {
|
||||
// Synchronize access to s.queryLog and s.stats so they won't be suddenly uninitialized while in use.
|
||||
// This can happen after proxy server has been stopped, but its workers haven't yet exited.
|
||||
if shouldLog && s.queryLog != nil {
|
||||
ip, _ := netutil.IPAndPortFromAddr(pctx.Addr)
|
||||
p := querylog.AddParams{
|
||||
Question: msg,
|
||||
Answer: pctx.Res,
|
||||
OrigAnswer: ctx.origResp,
|
||||
Result: ctx.result,
|
||||
Elapsed: elapsed,
|
||||
ClientIP: aghnet.IPFromAddr(pctx.Addr),
|
||||
ClientIP: ip,
|
||||
ClientID: ctx.clientID,
|
||||
}
|
||||
|
||||
@@ -80,7 +81,7 @@ func (s *Server) updateStats(ctx *dnsContext, elapsed time.Duration, res filteri
|
||||
|
||||
if clientID := ctx.clientID; clientID != "" {
|
||||
e.Client = clientID
|
||||
} else if ip := aghnet.IPFromAddr(pctx.Addr); ip != nil {
|
||||
} else if ip, _ := netutil.IPAndPortFromAddr(pctx.Addr); ip != nil {
|
||||
e.Client = ip.String()
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user