Pull request: all: client id support

Merge in DNS/adguard-home from 1383-client-id to master

Updates #1383.

Squashed commit of the following:

commit ebe2678bfa9bf651a2cb1e64499b38edcf19a7ad
Author: Ildar Kamalov <ik@adguard.com>
Date:   Wed Jan 27 17:51:59 2021 +0300

    - client: check if IP is valid

commit 0c330585a170ea149ee75e43dfa65211e057299c
Author: Ildar Kamalov <ik@adguard.com>
Date:   Wed Jan 27 17:07:50 2021 +0300

    - client: find clients by client_id

commit 71c9593ee35d996846f061e114b7867c3aa3c978
Merge: 9104f161 3e9edd9e
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Wed Jan 27 16:09:45 2021 +0300

    Merge branch 'master' into 1383-client-id

commit 9104f1615d2d462606c52017df25a422df872cea
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Wed Jan 27 13:28:50 2021 +0300

    dnsforward: imp tests

commit ed47f26e611ade625a2cc2c2f71a291b796bbf8f
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Wed Jan 27 12:39:52 2021 +0300

    dnsforward: fix address

commit 98b222ba69a5d265f620c180c960d01c84a1fb3b
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Tue Jan 26 19:50:31 2021 +0300

    home: imp code

commit 4f3966548a2d8437d0b68207dd108dd1a6cb7d20
Merge: 199fdc05 c215b820
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Tue Jan 26 19:45:13 2021 +0300

    Merge branch 'master' into 1383-client-id

commit 199fdc056f8a8be5500584f3aaee32865188aedc
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Tue Jan 26 19:20:37 2021 +0300

    all: imp tests, logging, etc

commit 35ff14f4d534251aecb2ea60baba225f3eed8a3e
Author: Ildar Kamalov <ik@adguard.com>
Date:   Tue Jan 26 18:55:19 2021 +0300

    + client: remove block button from clients with client_id

commit 32991a0b4c56583a02fb5e00bba95d96000bce20
Author: Ildar Kamalov <ik@adguard.com>
Date:   Tue Jan 26 18:54:25 2021 +0300

    + client: add requests count for client_id

commit 2d68df4d2eac4a296d7469923e601dad4575c1a1
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Tue Jan 26 15:49:50 2021 +0300

    stats: handle client ids

commit 4e14ab3590328f93a8cd6e9cbe1665baf74f220b
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Tue Jan 26 13:45:25 2021 +0300

    openapi: fix example

commit ca9cf3f744fe197cace2c28ddc5bc68f71dad1f3
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Tue Jan 26 13:37:10 2021 +0300

    openapi: improve clients find api docs

commit f79876e550c424558b704bc316a4cd04f25db011
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Tue Jan 26 13:18:52 2021 +0300

    home: accept ids in clients find

commit 5b72595122aa0bd64debadfd753ed8a0e0840629
Merge: 607e241f abf8f65f
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Mon Jan 25 18:34:56 2021 +0300

    Merge branch 'master' into 1383-client-id

commit 607e241f1c339dd6397218f70b8301e3de6a1ee0
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Mon Jan 25 18:30:39 2021 +0300

    dnsforward: fix quic

commit f046352fef93e46234c2bbe8ae316d21034260e5
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Mon Jan 25 16:53:09 2021 +0300

    all: remove wildcard requirement

commit 3b679489bae82c54177372be453fe184d8f0bab6
Author: Andrey Meshkov <am@adguard.com>
Date:   Mon Jan 25 16:02:28 2021 +0300

    workDir now supports symlinks

commit 0647ab4f113de2223f6949df001f42ecab05c995
Author: Ildar Kamalov <ik@adguard.com>
Date:   Mon Jan 25 14:59:46 2021 +0300

    - client: remove wildcard from domain validation

commit b1aec04a4ecadc9d65648ed6d284188fecce01c3
Author: Ildar Kamalov <ik@adguard.com>
Date:   Mon Jan 25 14:55:39 2021 +0300

    + client: add form to download mobileconfig

... and 12 more commits
This commit is contained in:
Ainar Garipov
2021-01-27 18:32:13 +03:00
parent 3e9edd9eac
commit fc9ddcf941
56 changed files with 1735 additions and 623 deletions

View File

@@ -24,11 +24,12 @@ type FilteringConfig struct {
// Callbacks for other modules
// --
// Filtering callback function
FilterHandler func(clientAddr net.IP, settings *dnsfilter.RequestFilteringSettings) `yaml:"-"`
// FilterHandler is an optional additional filtering callback.
FilterHandler func(clientAddr net.IP, clientID string, settings *dnsfilter.RequestFilteringSettings) `yaml:"-"`
// GetCustomUpstreamByClient - a callback function that returns upstreams configuration
// based on the client IP address. Returns nil if there are no custom upstreams for the client
//
// TODO(e.burkov): Replace argument type with net.IP.
GetCustomUpstreamByClient func(clientAddr string) *proxy.UpstreamConfig `yaml:"-"`
@@ -109,6 +110,10 @@ type TLSConfig struct {
CertificateChainData []byte `yaml:"-" json:"-"`
PrivateKeyData []byte `yaml:"-" json:"-"`
// ServerName is the hostname of the server. Currently, it is only
// being used for client ID checking.
ServerName string `yaml:"-" json:"-"`
cert tls.Certificate
// DNS names from certificate (SAN) or CN value from Subject
dnsNames []string

View File

@@ -1,7 +1,10 @@
package dnsforward
import (
"crypto/tls"
"fmt"
"net"
"path"
"strings"
"time"
@@ -10,37 +13,64 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/util"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/log"
"github.com/lucas-clemente/quic-go"
"github.com/miekg/dns"
)
// To transfer information between modules
type dnsContext struct {
srv *Server
proxyCtx *proxy.DNSContext
setts *dnsfilter.RequestFilteringSettings // filtering settings for this client
startTime time.Time
result *dnsfilter.Result
origResp *dns.Msg // response received from upstream servers. Set when response is modified by filtering
origQuestion dns.Question // question received from client. Set when Rewrites are used.
err error // error returned from the module
protectionEnabled bool // filtering is enabled, dnsfilter object is ready
responseFromUpstream bool // response is received from upstream servers
origReqDNSSEC bool // DNSSEC flag in the original request from user
srv *Server
proxyCtx *proxy.DNSContext
// setts are the filtering settings for the client.
setts *dnsfilter.RequestFilteringSettings
startTime time.Time
result *dnsfilter.Result
// origResp is the response received from upstream. It is set when the
// response is modified by filters.
origResp *dns.Msg
// err is the error returned from a processing function.
err error
// clientID is the clientID from DOH, DOQ, or DOT, if provided.
clientID string
// origQuestion is the question received from the client. It is set
// when the request is modified by rewrites.
origQuestion dns.Question
// protectionEnabled shows if the filtering is enabled, and if the
// server's DNS filter is ready.
protectionEnabled bool
// responseFromUpstream shows if the response is received from the
// upstream servers.
responseFromUpstream bool
// origReqDNSSEC shows if the DNSSEC flag in the original request from
// the client is set.
origReqDNSSEC bool
}
// resultCode is the result of a request processing function.
type resultCode int
const (
resultDone = iota // module has completed its job, continue
resultFinish // module has completed its job, exit normally
resultError // an error occurred, exit with an error
// resultCodeSuccess is returned when a handler performed successfully,
// and the next handler must be called.
resultCodeSuccess resultCode = iota
// resultCodeFinish is returned when a handler performed successfully,
// and the processing of the request must be stopped.
resultCodeFinish
// resultCodeError is returned when a handler failed, and the processing
// of the request must be stopped.
resultCodeError
)
// handleDNSRequest filters the incoming DNS requests and writes them to the query log
func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
ctx := &dnsContext{srv: s, proxyCtx: d}
ctx.result = &dnsfilter.Result{}
ctx.startTime = time.Now()
ctx := &dnsContext{
srv: s,
proxyCtx: d,
result: &dnsfilter.Result{},
startTime: time.Now(),
}
type modProcessFunc func(ctx *dnsContext) int
type modProcessFunc func(ctx *dnsContext) (rc resultCode)
// Since (*dnsforward.Server).handleDNSRequest(...) is used as
// proxy.(Config).RequestHandler, there is no need for additional index
@@ -51,6 +81,7 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
processInitial,
processInternalHosts,
processInternalIPAddrs,
processClientID,
processFilteringBeforeRequest,
processUpstream,
processDNSSECAfterResponse,
@@ -61,13 +92,13 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
for _, process := range mods {
r := process(ctx)
switch r {
case resultDone:
case resultCodeSuccess:
// continue: call the next filter
case resultFinish:
case resultCodeFinish:
return nil
case resultError:
case resultCodeError:
return ctx.err
}
}
@@ -79,12 +110,12 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
}
// Perform initial checks; process WHOIS & rDNS
func processInitial(ctx *dnsContext) int {
func processInitial(ctx *dnsContext) (rc resultCode) {
s := ctx.srv
d := ctx.proxyCtx
if s.conf.AAAADisabled && d.Req.Question[0].Qtype == dns.TypeAAAA {
_ = proxy.CheckDisabledAAAARequest(d, true)
return resultFinish
return resultCodeFinish
}
if s.conf.OnDNSRequest != nil {
@@ -96,10 +127,10 @@ func processInitial(ctx *dnsContext) int {
if (d.Req.Question[0].Qtype == dns.TypeA || d.Req.Question[0].Qtype == dns.TypeAAAA) &&
d.Req.Question[0].Name == "use-application-dns.net." {
d.Res = s.genNXDomain(d.Req)
return resultFinish
return resultCodeFinish
}
return resultDone
return resultCodeSuccess
}
// Return TRUE if host names doesn't contain disallowed characters
@@ -157,29 +188,29 @@ func (s *Server) onDHCPLeaseChanged(flags int) {
}
// Respond to A requests if the target host name is associated with a lease from our DHCP server
func processInternalHosts(ctx *dnsContext) int {
func processInternalHosts(ctx *dnsContext) (rc resultCode) {
s := ctx.srv
req := ctx.proxyCtx.Req
if !(req.Question[0].Qtype == dns.TypeA || req.Question[0].Qtype == dns.TypeAAAA) {
return resultDone
return resultCodeSuccess
}
host := req.Question[0].Name
host = strings.ToLower(host)
if !strings.HasSuffix(host, ".lan.") {
return resultDone
return resultCodeSuccess
}
host = strings.TrimSuffix(host, ".lan.")
s.tableHostToIPLock.Lock()
if s.tableHostToIP == nil {
s.tableHostToIPLock.Unlock()
return resultDone
return resultCodeSuccess
}
ip, ok := s.tableHostToIP[host]
s.tableHostToIPLock.Unlock()
if !ok {
return resultDone
return resultCodeSuccess
}
log.Debug("DNS: internal record: %s -> %s", req.Question[0].Name, ip)
@@ -200,15 +231,163 @@ func processInternalHosts(ctx *dnsContext) int {
}
ctx.proxyCtx.Res = resp
return resultDone
return resultCodeSuccess
}
const maxDomainPartLen = 64
// ValidateClientID returns an error if clientID is not a valid client ID.
func ValidateClientID(clientID string) (err error) {
if len(clientID) > maxDomainPartLen {
return fmt.Errorf("client id %q is too long, max: %d", clientID, maxDomainPartLen)
}
for i, r := range clientID {
if (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '-' {
continue
}
return fmt.Errorf("invalid char %q at index %d in client id %q", r, i, clientID)
}
return nil
}
// clientIDFromClientServerName extracts and validates a client ID. hostSrvName
// is the server name of the host. cliSrvName is the server name as sent by the
// client.
func clientIDFromClientServerName(hostSrvName, cliSrvName string) (clientID string, err error) {
if hostSrvName == cliSrvName {
return "", nil
}
if !strings.HasSuffix(cliSrvName, hostSrvName) {
return "", fmt.Errorf("client server name %q doesn't match host server name %q", cliSrvName, hostSrvName)
}
clientID = cliSrvName[:len(cliSrvName)-len(hostSrvName)-1]
err = ValidateClientID(clientID)
if err != nil {
return "", fmt.Errorf("invalid client id: %w", err)
}
return clientID, nil
}
// processClientIDHTTPS extracts the client's ID from the path of the
// client's DNS-over-HTTPS request.
func processClientIDHTTPS(ctx *dnsContext) (rc resultCode) {
pctx := ctx.proxyCtx
r := pctx.HTTPRequest
if r == nil {
ctx.err = fmt.Errorf("proxy ctx http request of proto %s is nil", pctx.Proto)
return resultCodeError
}
origPath := r.URL.Path
parts := strings.Split(path.Clean(origPath), "/")
if parts[0] == "" {
parts = parts[1:]
}
if len(parts) == 0 || parts[0] != "dns-query" {
ctx.err = fmt.Errorf("client id check: invalid path %q", origPath)
return resultCodeError
}
clientID := ""
switch len(parts) {
case 1:
// Just /dns-query, no client ID.
return resultCodeSuccess
case 2:
clientID = parts[1]
default:
ctx.err = fmt.Errorf("client id check: invalid path %q: extra parts", origPath)
return resultCodeError
}
err := ValidateClientID(clientID)
if err != nil {
ctx.err = fmt.Errorf("client id check: invalid client id: %w", err)
return resultCodeError
}
ctx.clientID = clientID
return resultCodeSuccess
}
// tlsConn is a narrow interface for *tls.Conn to simplify testing.
type tlsConn interface {
ConnectionState() (cs tls.ConnectionState)
}
// quicSession is a narrow interface for quic.Session to simplify testing.
type quicSession interface {
ConnectionState() (cs quic.ConnectionState)
}
// processClientID extracts the client's ID from the server name of the client's
// DOT or DOQ request or the path of the client's DOH.
func processClientID(ctx *dnsContext) (rc resultCode) {
pctx := ctx.proxyCtx
proto := pctx.Proto
if proto == proxy.ProtoHTTPS {
return processClientIDHTTPS(ctx)
} else if proto != proxy.ProtoTLS && proto != proxy.ProtoQUIC {
return resultCodeSuccess
}
hostSrvName := ctx.srv.conf.TLSConfig.ServerName
if hostSrvName == "" {
return resultCodeSuccess
}
cliSrvName := ""
if proto == proxy.ProtoTLS {
conn := pctx.Conn
tc, ok := conn.(tlsConn)
if !ok {
ctx.err = fmt.Errorf("proxy ctx conn of proto %s is %T, want *tls.Conn", proto, conn)
return resultCodeError
}
cliSrvName = tc.ConnectionState().ServerName
} else if proto == proxy.ProtoQUIC {
qs, ok := pctx.QUICSession.(quicSession)
if !ok {
ctx.err = fmt.Errorf("proxy ctx quic session of proto %s is %T, want quic.Session", proto, pctx.QUICSession)
return resultCodeError
}
cliSrvName = qs.ConnectionState().ServerName
}
clientID, err := clientIDFromClientServerName(hostSrvName, cliSrvName)
if err != nil {
ctx.err = fmt.Errorf("client id check: %w", err)
return resultCodeError
}
ctx.clientID = clientID
return resultCodeSuccess
}
// Respond to PTR requests if the target IP address is leased by our DHCP server
func processInternalIPAddrs(ctx *dnsContext) int {
func processInternalIPAddrs(ctx *dnsContext) (rc resultCode) {
s := ctx.srv
req := ctx.proxyCtx.Req
if req.Question[0].Qtype != dns.TypePTR {
return resultDone
return resultCodeSuccess
}
arpa := req.Question[0].Name
@@ -216,18 +395,18 @@ func processInternalIPAddrs(ctx *dnsContext) int {
arpa = strings.ToLower(arpa)
ip := util.DNSUnreverseAddr(arpa)
if ip == nil {
return resultDone
return resultCodeSuccess
}
s.tablePTRLock.Lock()
if s.tablePTR == nil {
s.tablePTRLock.Unlock()
return resultDone
return resultCodeSuccess
}
host, ok := s.tablePTR[ip.String()]
s.tablePTRLock.Unlock()
if !ok {
return resultDone
return resultCodeSuccess
}
log.Debug("DNS: reverse-lookup: %s -> %s", arpa, host)
@@ -243,16 +422,16 @@ func processInternalIPAddrs(ctx *dnsContext) int {
ptr.Ptr = host + "."
resp.Answer = append(resp.Answer, ptr)
ctx.proxyCtx.Res = resp
return resultDone
return resultCodeSuccess
}
// Apply filtering logic
func processFilteringBeforeRequest(ctx *dnsContext) int {
func processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode) {
s := ctx.srv
d := ctx.proxyCtx
if d.Res != nil {
return resultDone // response is already set - nothing to do
return resultCodeSuccess // response is already set - nothing to do
}
s.RLock()
@@ -266,24 +445,24 @@ func processFilteringBeforeRequest(ctx *dnsContext) int {
var err error
ctx.protectionEnabled = s.conf.ProtectionEnabled && s.dnsFilter != nil
if ctx.protectionEnabled {
ctx.setts = s.getClientRequestFilteringSettings(d)
ctx.setts = s.getClientRequestFilteringSettings(ctx)
ctx.result, err = s.filterDNSRequest(ctx)
}
s.RUnlock()
if err != nil {
ctx.err = err
return resultError
return resultCodeError
}
return resultDone
return resultCodeSuccess
}
// processUpstream passes request to upstream servers and handles the response.
func processUpstream(ctx *dnsContext) int {
func processUpstream(ctx *dnsContext) (rc resultCode) {
s := ctx.srv
d := ctx.proxyCtx
if d.Res != nil {
return resultDone // response is already set - nothing to do
return resultCodeSuccess // response is already set - nothing to do
}
if d.Addr != nil && s.conf.GetCustomUpstreamByClient != nil {
@@ -311,26 +490,26 @@ func processUpstream(ctx *dnsContext) int {
err := s.dnsProxy.Resolve(d)
if err != nil {
ctx.err = err
return resultError
return resultCodeError
}
ctx.responseFromUpstream = true
return resultDone
return resultCodeSuccess
}
// Process DNSSEC after response from upstream server
func processDNSSECAfterResponse(ctx *dnsContext) int {
func processDNSSECAfterResponse(ctx *dnsContext) (rc resultCode) {
d := ctx.proxyCtx
if !ctx.responseFromUpstream || // don't process response if it's not from upstream servers
!ctx.srv.conf.EnableDNSSEC {
return resultDone
return resultCodeSuccess
}
if !ctx.origReqDNSSEC {
optResp := d.Res.IsEdns0()
if optResp != nil && !optResp.Do() {
return resultDone
return resultCodeSuccess
}
// Remove RRSIG records from response
@@ -361,11 +540,11 @@ func processDNSSECAfterResponse(ctx *dnsContext) int {
d.Res.Ns = answers
}
return resultDone
return resultCodeSuccess
}
// Apply filtering logic after we have received response from upstream servers
func processFilteringAfterResponse(ctx *dnsContext) int {
func processFilteringAfterResponse(ctx *dnsContext) (rc resultCode) {
s := ctx.srv
d := ctx.proxyCtx
res := ctx.result
@@ -402,7 +581,7 @@ func processFilteringAfterResponse(ctx *dnsContext) int {
ctx.result, err = s.filterDNSResponse(ctx)
if err != nil {
ctx.err = err
return resultError
return resultCodeError
}
if ctx.result != nil {
ctx.origResp = origResp2 // matched by response
@@ -411,5 +590,5 @@ func processFilteringAfterResponse(ctx *dnsContext) int {
}
}
return resultDone
return resultCodeSuccess
}

View File

@@ -0,0 +1,235 @@
package dnsforward
import (
"crypto/tls"
"net"
"net/http"
"net/url"
"testing"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/lucas-clemente/quic-go"
"github.com/stretchr/testify/assert"
)
// testTLSConn is a tlsConn for tests.
type testTLSConn struct {
// Conn is embedded here simply to make testTLSConn a net.Conn without
// acctually implementing all methods.
net.Conn
serverName string
}
// ConnectionState implements the tlsConn interface for testTLSConn.
func (c testTLSConn) ConnectionState() (cs tls.ConnectionState) {
cs.ServerName = c.serverName
return cs
}
// testQUICSession is a quicSession for tests.
type testQUICSession struct {
// Session is embedded here simply to make testQUICSession
// a quic.Session without acctually implementing all methods.
quic.Session
serverName string
}
// ConnectionState implements the quicSession interface for testQUICSession.
func (c testQUICSession) ConnectionState() (cs quic.ConnectionState) {
cs.ServerName = c.serverName
return cs
}
func TestProcessClientID(t *testing.T) {
testCases := []struct {
name string
proto string
hostSrvName string
cliSrvName string
wantClientID string
wantErrMsg string
wantRes resultCode
}{{
name: "udp",
proto: proxy.ProtoUDP,
hostSrvName: "",
cliSrvName: "",
wantClientID: "",
wantErrMsg: "",
wantRes: resultCodeSuccess,
}, {
name: "tls_no_client_id",
proto: proxy.ProtoTLS,
hostSrvName: "example.com",
cliSrvName: "example.com",
wantClientID: "",
wantErrMsg: "",
wantRes: resultCodeSuccess,
}, {
name: "tls_client_id",
proto: proxy.ProtoTLS,
hostSrvName: "example.com",
cliSrvName: "cli.example.com",
wantClientID: "cli",
wantErrMsg: "",
wantRes: resultCodeSuccess,
}, {
name: "tls_client_id_hostname_error",
proto: proxy.ProtoTLS,
hostSrvName: "example.com",
cliSrvName: "cli.example.net",
wantClientID: "",
wantErrMsg: `client id check: client server name "cli.example.net" doesn't match host server name "example.com"`,
wantRes: resultCodeError,
}, {
name: "tls_invalid_client_id",
proto: proxy.ProtoTLS,
hostSrvName: "example.com",
cliSrvName: "!!!.example.com",
wantClientID: "",
wantErrMsg: `client id check: invalid client id: invalid char '!' at index 0 in client id "!!!"`,
wantRes: resultCodeError,
}, {
name: "tls_client_id_too_long",
proto: proxy.ProtoTLS,
hostSrvName: "example.com",
cliSrvName: "abcdefghijklmnopqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz0123456789.example.com",
wantClientID: "",
wantErrMsg: `client id check: invalid client id: client id "abcdefghijklmnopqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz0123456789" is too long, max: 64`,
wantRes: resultCodeError,
}, {
name: "quic_client_id",
proto: proxy.ProtoQUIC,
hostSrvName: "example.com",
cliSrvName: "cli.example.com",
wantClientID: "cli",
wantErrMsg: "",
wantRes: resultCodeSuccess,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
srv := &Server{
conf: ServerConfig{
TLSConfig: TLSConfig{ServerName: tc.hostSrvName},
},
}
var conn net.Conn
if tc.proto == proxy.ProtoTLS {
conn = testTLSConn{
serverName: tc.cliSrvName,
}
}
var qs quic.Session
if tc.proto == proxy.ProtoQUIC {
qs = testQUICSession{
serverName: tc.cliSrvName,
}
}
dctx := &dnsContext{
srv: srv,
proxyCtx: &proxy.DNSContext{
Proto: tc.proto,
Conn: conn,
QUICSession: qs,
},
}
res := processClientID(dctx)
assert.Equal(t, tc.wantRes, res)
assert.Equal(t, tc.wantClientID, dctx.clientID)
if tc.wantErrMsg != "" && assert.NotNil(t, dctx.err) {
assert.Equal(t, tc.wantErrMsg, dctx.err.Error())
} else {
assert.Nil(t, dctx.err)
}
})
}
}
func TestProcessClientID_https(t *testing.T) {
testCases := []struct {
name string
path string
wantClientID string
wantErrMsg string
wantRes resultCode
}{{
name: "no_client_id",
path: "/dns-query",
wantClientID: "",
wantErrMsg: "",
wantRes: resultCodeSuccess,
}, {
name: "no_client_id_slash",
path: "/dns-query/",
wantClientID: "",
wantErrMsg: "",
wantRes: resultCodeSuccess,
}, {
name: "client_id",
path: "/dns-query/cli",
wantClientID: "cli",
wantErrMsg: "",
wantRes: resultCodeSuccess,
}, {
name: "client_id_slash",
path: "/dns-query/cli/",
wantClientID: "cli",
wantErrMsg: "",
wantRes: resultCodeSuccess,
}, {
name: "bad_url",
path: "/foo",
wantClientID: "",
wantErrMsg: `client id check: invalid path "/foo"`,
wantRes: resultCodeError,
}, {
name: "extra",
path: "/dns-query/cli/foo",
wantClientID: "",
wantErrMsg: `client id check: invalid path "/dns-query/cli/foo": extra parts`,
wantRes: resultCodeError,
}, {
name: "invalid_client_id",
path: "/dns-query/!!!",
wantClientID: "",
wantErrMsg: `client id check: invalid client id: invalid char '!' at index 0 in client id "!!!"`,
wantRes: resultCodeError,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
r := &http.Request{
URL: &url.URL{
Path: tc.path,
},
}
dctx := &dnsContext{
proxyCtx: &proxy.DNSContext{
Proto: proxy.ProtoHTTPS,
HTTPRequest: r,
},
}
res := processClientID(dctx)
assert.Equal(t, tc.wantRes, res)
assert.Equal(t, tc.wantClientID, dctx.clientID)
if tc.wantErrMsg != "" && assert.NotNil(t, dctx.err) {
assert.Equal(t, tc.wantErrMsg, dctx.err.Error())
} else {
assert.Nil(t, dctx.err)
}
})
}
}

View File

@@ -473,7 +473,7 @@ func TestBlockCNAME(t *testing.T) {
func TestClientRulesForCNAMEMatching(t *testing.T) {
s := createTestServer(t)
testUpstm := &testUpstream{testCNAMEs, testIPv4, nil}
s.conf.FilterHandler = func(_ net.IP, settings *dnsfilter.RequestFilteringSettings) {
s.conf.FilterHandler = func(_ net.IP, _ string, settings *dnsfilter.RequestFilteringSettings) {
settings.FilteringEnabled = false
}
err := s.startWithUpstream(testUpstm)
@@ -1033,8 +1033,7 @@ func TestMatchDNSName(t *testing.T) {
assert.False(t, matchDNSName(dnsNames, "*.host2"))
}
type testDHCP struct {
}
type testDHCP struct{}
func (d *testDHCP) Leases(flags int) []dhcpd.Lease {
l := dhcpd.Lease{}

View File

@@ -30,14 +30,15 @@ func (s *Server) beforeRequestHandler(_ *proxy.Proxy, d *proxy.DNSContext) (bool
return true, nil
}
// getClientRequestFilteringSettings lookups client filtering settings
// using the client's IP address from the DNSContext
func (s *Server) getClientRequestFilteringSettings(d *proxy.DNSContext) *dnsfilter.RequestFilteringSettings {
// getClientRequestFilteringSettings looks up client filtering settings using
// the client's IP address and ID, if any, from ctx.
func (s *Server) getClientRequestFilteringSettings(ctx *dnsContext) *dnsfilter.RequestFilteringSettings {
setts := s.dnsFilter.GetConfig()
setts.FilteringEnabled = true
if s.conf.FilterHandler != nil {
s.conf.FilterHandler(IPFromAddr(d.Addr), &setts)
s.conf.FilterHandler(IPFromAddr(ctx.proxyCtx.Addr), ctx.clientID, &setts)
}
return &setts
}

View File

@@ -529,5 +529,5 @@ func (s *Server) registerHandlers() {
s.conf.HTTPRegister(http.MethodGet, "/control/access/list", s.handleAccessList)
s.conf.HTTPRegister(http.MethodPost, "/control/access/set", s.handleAccessSet)
s.conf.HTTPRegister("", "/dns-query", s.handleDOH)
s.conf.HTTPRegister("", "/dns-query/", s.handleDOH)
}

View File

@@ -99,12 +99,12 @@ func (c *ipsetCtx) getIP(rr dns.RR) net.IP {
}
// Add IP addresses of the specified in configuration domain names to an ipset list
func (c *ipsetCtx) process(ctx *dnsContext) int {
func (c *ipsetCtx) process(ctx *dnsContext) (rc resultCode) {
req := ctx.proxyCtx.Req
if !(req.Question[0].Qtype == dns.TypeA ||
req.Question[0].Qtype == dns.TypeAAAA) ||
!ctx.responseFromUpstream {
return resultDone
return resultCodeSuccess
}
host := req.Question[0].Name
@@ -112,7 +112,7 @@ func (c *ipsetCtx) process(ctx *dnsContext) int {
host = strings.ToLower(host)
ipsetNames, found := c.ipsetList[host]
if !found {
return resultDone
return resultCodeSuccess
}
log.Debug("IPSET: found ipsets %v for host %s", ipsetNames, host)
@@ -138,5 +138,5 @@ func (c *ipsetCtx) process(ctx *dnsContext) int {
}
}
return resultDone
return resultCodeSuccess
}

View File

@@ -37,5 +37,5 @@ func TestIPSET(t *testing.T) {
},
},
}
assert.Equal(t, resultDone, c.process(ctx))
assert.Equal(t, resultCodeSuccess, c.process(ctx))
}

View File

@@ -1,7 +1,6 @@
package dnsforward
import (
"net"
"strings"
"time"
@@ -13,13 +12,13 @@ import (
)
// Write Stats data and logs
func processQueryLogsAndStats(ctx *dnsContext) int {
func processQueryLogsAndStats(ctx *dnsContext) (rc resultCode) {
elapsed := time.Since(ctx.startTime)
s := ctx.srv
d := ctx.proxyCtx
pctx := ctx.proxyCtx
shouldLog := true
msg := d.Req
msg := pctx.Req
// don't log ANY request if refuseAny is enabled
if len(msg.Question) >= 1 && msg.Question[0].Qtype == dns.TypeANY && s.conf.RefuseAny {
@@ -32,65 +31,67 @@ func processQueryLogsAndStats(ctx *dnsContext) int {
if shouldLog && s.queryLog != nil {
p := querylog.AddParams{
Question: msg,
Answer: d.Res,
Answer: pctx.Res,
OrigAnswer: ctx.origResp,
Result: ctx.result,
Elapsed: elapsed,
ClientIP: IPFromAddr(d.Addr),
ClientIP: IPFromAddr(pctx.Addr),
ClientID: ctx.clientID,
}
switch d.Proto {
switch pctx.Proto {
case proxy.ProtoHTTPS:
p.ClientProto = querylog.ClientProtoDOH
case proxy.ProtoQUIC:
p.ClientProto = querylog.ClientProtoDOQ
case proxy.ProtoTLS:
p.ClientProto = querylog.ClientProtoDOT
case proxy.ProtoDNSCrypt:
p.ClientProto = querylog.ClientProtoDNSCrypt
default:
// Consider this a plain DNS-over-UDP or DNS-over-TCL
// Consider this a plain DNS-over-UDP or DNS-over-TCP
// request.
}
if d.Upstream != nil {
p.Upstream = d.Upstream.Address()
if pctx.Upstream != nil {
p.Upstream = pctx.Upstream.Address()
}
s.queryLog.Add(p)
}
s.updateStats(d, elapsed, *ctx.result)
s.updateStats(ctx, elapsed, *ctx.result)
s.RUnlock()
return resultDone
return resultCodeSuccess
}
func (s *Server) updateStats(d *proxy.DNSContext, elapsed time.Duration, res dnsfilter.Result) {
func (s *Server) updateStats(ctx *dnsContext, elapsed time.Duration, res dnsfilter.Result) {
if s.stats == nil {
return
}
pctx := ctx.proxyCtx
e := stats.Entry{}
e.Domain = strings.ToLower(d.Req.Question[0].Name)
e.Domain = strings.ToLower(pctx.Req.Question[0].Name)
e.Domain = e.Domain[:len(e.Domain)-1] // remove last "."
switch addr := d.Addr.(type) {
case *net.UDPAddr:
e.Client = addr.IP
case *net.TCPAddr:
e.Client = addr.IP
if clientID := ctx.clientID; clientID != "" {
e.Client = clientID
} else if ip := IPFromAddr(pctx.Addr); ip != nil {
e.Client = ip.String()
}
e.Time = uint32(elapsed / 1000)
e.Result = stats.RNotFiltered
switch res.Reason {
case dnsfilter.FilteredSafeBrowsing:
e.Result = stats.RSafeBrowsing
case dnsfilter.FilteredParental:
e.Result = stats.RParental
case dnsfilter.FilteredSafeSearch:
e.Result = stats.RSafeSearch
case dnsfilter.FilteredBlockList:
fallthrough
case dnsfilter.FilteredInvalid:

View File

@@ -0,0 +1,198 @@
package dnsforward
import (
"net"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/dnsfilter"
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
"github.com/AdguardTeam/AdGuardHome/internal/stats"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
)
// testQueryLog is a simple querylog.QueryLog implementation for tests.
type testQueryLog struct {
// QueryLog is embedded here simply to make testQueryLog
// a querylog.QueryLog without acctually implementing all methods.
querylog.QueryLog
lastParams querylog.AddParams
}
// Add implements the querylog.QueryLog interface for *testQueryLog.
func (l *testQueryLog) Add(p querylog.AddParams) {
l.lastParams = p
}
// testStats is a simple stats.Stats implementation for tests.
type testStats struct {
// Stats is embedded here simply to make testStats a stats.Stats without
// acctually implementing all methods.
stats.Stats
lastEntry stats.Entry
}
// Update implements the stats.Stats interface for *testStats.
func (l *testStats) Update(e stats.Entry) {
l.lastEntry = e
}
func TestProcessQueryLogsAndStats(t *testing.T) {
testCases := []struct {
name string
proto string
addr net.Addr
clientID string
wantLogProto querylog.ClientProto
wantStatClient string
wantCode resultCode
reason dnsfilter.Reason
wantStatResult stats.Result
}{{
name: "success_udp",
proto: proxy.ProtoUDP,
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
clientID: "",
wantLogProto: "",
wantStatClient: "1.2.3.4",
wantCode: resultCodeSuccess,
reason: dnsfilter.NotFilteredNotFound,
wantStatResult: stats.RNotFiltered,
}, {
name: "success_tls_client_id",
proto: proxy.ProtoTLS,
addr: &net.TCPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
clientID: "cli42",
wantLogProto: querylog.ClientProtoDOT,
wantStatClient: "cli42",
wantCode: resultCodeSuccess,
reason: dnsfilter.NotFilteredNotFound,
wantStatResult: stats.RNotFiltered,
}, {
name: "success_tls",
proto: proxy.ProtoTLS,
addr: &net.TCPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
clientID: "",
wantLogProto: querylog.ClientProtoDOT,
wantStatClient: "1.2.3.4",
wantCode: resultCodeSuccess,
reason: dnsfilter.NotFilteredNotFound,
wantStatResult: stats.RNotFiltered,
}, {
name: "success_quic",
proto: proxy.ProtoQUIC,
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
clientID: "",
wantLogProto: querylog.ClientProtoDOQ,
wantStatClient: "1.2.3.4",
wantCode: resultCodeSuccess,
reason: dnsfilter.NotFilteredNotFound,
wantStatResult: stats.RNotFiltered,
}, {
name: "success_https",
proto: proxy.ProtoHTTPS,
addr: &net.TCPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
clientID: "",
wantLogProto: querylog.ClientProtoDOH,
wantStatClient: "1.2.3.4",
wantCode: resultCodeSuccess,
reason: dnsfilter.NotFilteredNotFound,
wantStatResult: stats.RNotFiltered,
}, {
name: "success_dnscrypt",
proto: proxy.ProtoDNSCrypt,
addr: &net.TCPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
clientID: "",
wantLogProto: querylog.ClientProtoDNSCrypt,
wantStatClient: "1.2.3.4",
wantCode: resultCodeSuccess,
reason: dnsfilter.NotFilteredNotFound,
wantStatResult: stats.RNotFiltered,
}, {
name: "success_udp_filtered",
proto: proxy.ProtoUDP,
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
clientID: "",
wantLogProto: "",
wantStatClient: "1.2.3.4",
wantCode: resultCodeSuccess,
reason: dnsfilter.FilteredBlockList,
wantStatResult: stats.RFiltered,
}, {
name: "success_udp_sb",
proto: proxy.ProtoUDP,
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
clientID: "",
wantLogProto: "",
wantStatClient: "1.2.3.4",
wantCode: resultCodeSuccess,
reason: dnsfilter.FilteredSafeBrowsing,
wantStatResult: stats.RSafeBrowsing,
}, {
name: "success_udp_ss",
proto: proxy.ProtoUDP,
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
clientID: "",
wantLogProto: "",
wantStatClient: "1.2.3.4",
wantCode: resultCodeSuccess,
reason: dnsfilter.FilteredSafeSearch,
wantStatResult: stats.RSafeSearch,
}, {
name: "success_udp_pc",
proto: proxy.ProtoUDP,
addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234},
clientID: "",
wantLogProto: "",
wantStatClient: "1.2.3.4",
wantCode: resultCodeSuccess,
reason: dnsfilter.FilteredParental,
wantStatResult: stats.RParental,
}}
ups, err := upstream.AddressToUpstream("1.1.1.1", upstream.Options{})
assert.Nil(t, err)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
req := &dns.Msg{
Question: []dns.Question{{
Name: "example.com.",
}},
}
pctx := &proxy.DNSContext{
Proto: tc.proto,
Req: req,
Res: &dns.Msg{},
Addr: tc.addr,
Upstream: ups,
}
ql := &testQueryLog{}
st := &testStats{}
dctx := &dnsContext{
srv: &Server{
queryLog: ql,
stats: st,
},
proxyCtx: pctx,
startTime: time.Now(),
result: &dnsfilter.Result{
Reason: tc.reason,
},
clientID: tc.clientID,
}
code := processQueryLogsAndStats(dctx)
assert.Equal(t, tc.wantCode, code)
assert.Equal(t, tc.wantLogProto, ql.lastParams.ClientProto)
assert.Equal(t, tc.wantStatClient, st.lastEntry.Client)
assert.Equal(t, tc.wantStatResult, st.lastEntry.Result)
})
}
}