Pull request 2021: upd golibs
Merge in DNS/adguard-home from upd-golibs to master
Squashed commit of the following:
commit 266b002c5450329761dee21d918c80d08e5d8ab9
Merge: 99eb7745d e305bd8e4
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Thu Oct 5 14:21:51 2023 +0300
Merge branch 'master' into upd-golibs
commit 99eb7745d0bee190399f9b16cb7151f34a591b54
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Thu Oct 5 14:14:28 2023 +0300
home: imp alignment
commit 556cde56720ce449aec17b500825681fc8c084bf
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Thu Oct 5 13:35:35 2023 +0300
dnsforward: imp naming, docs
commit 1ee99655a3318263db1edbcb9e4eeb33bfe441c8
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Thu Oct 5 13:28:39 2023 +0300
home: make ports uint16
commit b228032ea1f5902ab9bac7b5d55d84aaf6354616
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Wed Oct 4 18:56:59 2023 +0300
all: rm system resolvers
commit 4b5becbed5890db80612e53861f000aaf4c869ff
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Wed Oct 4 17:30:16 2023 +0300
all: upd golibs
This commit is contained in:
@@ -12,12 +12,12 @@ import (
|
||||
// listenPacketReusable announces on the local network address additionally
|
||||
// configuring the socket to have a reusable binding.
|
||||
func listenPacketReusable(ifaceName, network, address string) (c net.PacketConn, err error) {
|
||||
var port int
|
||||
var port uint16
|
||||
_, port, err = netutil.SplitHostPort(address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO(e.burkov): Inspect nclient4.NewRawUDPConn and implement here.
|
||||
return nclient4.NewRawUDPConn(ifaceName, port)
|
||||
return nclient4.NewRawUDPConn(ifaceName, int(port))
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package aghnet_test
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"net"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
@@ -64,7 +65,7 @@ ParseAddr("256.256.256.256"): IPv4 field has value >255`,
|
||||
want: netip.AddrPort{},
|
||||
}, {
|
||||
name: "error_invalid_port",
|
||||
input: netutil.JoinHostPort(v4addr.String(), -5),
|
||||
input: net.JoinHostPort(v4addr.String(), "-5"),
|
||||
wantErrMsg: `invalid port "-5" parsing "1.2.3.4:-5"
|
||||
ParseAddr("1.2.3.4:-5"): unexpected character (at ":-5")`,
|
||||
want: netip.AddrPort{},
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
package aghnet
|
||||
|
||||
// DefaultRefreshIvl is the default period of time between refreshing cached
|
||||
// addresses.
|
||||
// const DefaultRefreshIvl = 5 * time.Minute
|
||||
|
||||
// HostGenFunc is the signature for functions generating fake hostnames. The
|
||||
// implementation must be safe for concurrent use.
|
||||
type HostGenFunc func() (host string)
|
||||
|
||||
// SystemResolvers helps to work with local resolvers' addresses provided by OS.
|
||||
type SystemResolvers interface {
|
||||
// Get returns the slice of local resolvers' addresses. It must be safe for
|
||||
// concurrent use.
|
||||
Get() (rs []string)
|
||||
// refresh refreshes the local resolvers' addresses cache. It must be safe
|
||||
// for concurrent use.
|
||||
refresh() (err error)
|
||||
}
|
||||
|
||||
// NewSystemResolvers returns a SystemResolvers with the cache refresh rate
|
||||
// defined by refreshIvl. It disables auto-refreshing if refreshIvl is 0. If
|
||||
// nil is passed for hostGenFunc, the default generator will be used.
|
||||
func NewSystemResolvers(
|
||||
hostGenFunc HostGenFunc,
|
||||
) (sr SystemResolvers, err error) {
|
||||
sr = newSystemResolvers(hostGenFunc)
|
||||
|
||||
// Fill cache.
|
||||
err = sr.refresh()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return sr, nil
|
||||
}
|
||||
@@ -1,146 +0,0 @@
|
||||
//go:build !windows
|
||||
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
)
|
||||
|
||||
// defaultHostGen is the default method of generating host for Refresh.
|
||||
func defaultHostGen() (host string) {
|
||||
// TODO(e.burkov): Use strings.Builder.
|
||||
return fmt.Sprintf("test%d.org", time.Now().UnixNano())
|
||||
}
|
||||
|
||||
// systemResolvers is a default implementation of SystemResolvers interface.
|
||||
type systemResolvers struct {
|
||||
// addrsLock protects addrs.
|
||||
addrsLock sync.RWMutex
|
||||
// addrs is the set that contains cached local resolvers' addresses.
|
||||
addrs *stringutil.Set
|
||||
|
||||
// resolver is used to fetch the resolvers' addresses.
|
||||
resolver *net.Resolver
|
||||
// hostGenFunc generates hosts to resolve.
|
||||
hostGenFunc HostGenFunc
|
||||
}
|
||||
|
||||
const (
|
||||
// errBadAddrPassed is returned when dialFunc can't parse an IP address.
|
||||
errBadAddrPassed errors.Error = "the passed string is not a valid IP address"
|
||||
|
||||
// errFakeDial is an error which dialFunc is expected to return.
|
||||
errFakeDial errors.Error = "this error signals the successful dialFunc work"
|
||||
|
||||
// errUnexpectedHostFormat is returned by validateDialedHost when the host has
|
||||
// more than one percent sign.
|
||||
errUnexpectedHostFormat errors.Error = "unexpected host format"
|
||||
)
|
||||
|
||||
// refresh implements the SystemResolvers interface for *systemResolvers.
|
||||
func (sr *systemResolvers) refresh() (err error) {
|
||||
defer func() { err = errors.Annotate(err, "systemResolvers: %w") }()
|
||||
|
||||
_, err = sr.resolver.LookupHost(context.Background(), sr.hostGenFunc())
|
||||
dnserr := &net.DNSError{}
|
||||
if errors.As(err, &dnserr) && dnserr.Err == errFakeDial.Error() {
|
||||
return nil
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func newSystemResolvers(hostGenFunc HostGenFunc) (sr SystemResolvers) {
|
||||
if hostGenFunc == nil {
|
||||
hostGenFunc = defaultHostGen
|
||||
}
|
||||
s := &systemResolvers{
|
||||
resolver: &net.Resolver{
|
||||
PreferGo: true,
|
||||
},
|
||||
hostGenFunc: hostGenFunc,
|
||||
addrs: stringutil.NewSet(),
|
||||
}
|
||||
s.resolver.Dial = s.dialFunc
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// validateDialedHost validated the host used by resolvers in dialFunc.
|
||||
func validateDialedHost(host string) (err error) {
|
||||
defer func() { err = errors.Annotate(err, "parsing %q: %w", host) }()
|
||||
|
||||
parts := strings.Split(host, "%")
|
||||
switch len(parts) {
|
||||
case 1:
|
||||
// host
|
||||
case 2:
|
||||
// Remove the zone and check the IP address part.
|
||||
host = parts[0]
|
||||
default:
|
||||
return errUnexpectedHostFormat
|
||||
}
|
||||
|
||||
if _, err = netutil.ParseIP(host); err != nil {
|
||||
return errBadAddrPassed
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// dockerEmbeddedDNS is the address of Docker's embedded DNS server.
|
||||
//
|
||||
// See
|
||||
// https://github.com/moby/moby/blob/v1.12.0/docs/userguide/networking/dockernetworks.md.
|
||||
const dockerEmbeddedDNS = "127.0.0.11"
|
||||
|
||||
// dialFunc gets the resolver's address and puts it into internal cache.
|
||||
func (sr *systemResolvers) dialFunc(_ context.Context, _, address string) (_ net.Conn, err error) {
|
||||
// Just validate the passed address is a valid IP.
|
||||
var host string
|
||||
host, err = netutil.SplitHost(address)
|
||||
if err != nil {
|
||||
// TODO(e.burkov): Maybe use a structured errBadAddrPassed to
|
||||
// allow unwrapping of the real error.
|
||||
return nil, fmt.Errorf("%s: %w", err, errBadAddrPassed)
|
||||
}
|
||||
|
||||
// Exclude Docker's embedded DNS server, as it may cause recursion if
|
||||
// the container is set as the host system's default DNS server.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/3064.
|
||||
//
|
||||
// TODO(a.garipov): Perhaps only do this when we are in the container?
|
||||
// Maybe use an environment variable?
|
||||
if host == dockerEmbeddedDNS {
|
||||
return nil, errFakeDial
|
||||
}
|
||||
|
||||
err = validateDialedHost(host)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("validating dialed host: %w", err)
|
||||
}
|
||||
|
||||
sr.addrsLock.Lock()
|
||||
defer sr.addrsLock.Unlock()
|
||||
|
||||
sr.addrs.Add(host)
|
||||
|
||||
return nil, errFakeDial
|
||||
}
|
||||
|
||||
func (sr *systemResolvers) Get() (rs []string) {
|
||||
sr.addrsLock.RLock()
|
||||
defer sr.addrsLock.RUnlock()
|
||||
|
||||
return sr.addrs.Values()
|
||||
}
|
||||
@@ -1,81 +0,0 @@
|
||||
//go:build !windows
|
||||
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func createTestSystemResolversImpl(
|
||||
t *testing.T,
|
||||
hostGenFunc HostGenFunc,
|
||||
) (imp *systemResolvers) {
|
||||
t.Helper()
|
||||
|
||||
sr := createTestSystemResolvers(t, hostGenFunc)
|
||||
|
||||
return testutil.RequireTypeAssert[*systemResolvers](t, sr)
|
||||
}
|
||||
|
||||
func TestSystemResolvers_Refresh(t *testing.T) {
|
||||
t.Run("expected_error", func(t *testing.T) {
|
||||
sr := createTestSystemResolvers(t, nil)
|
||||
|
||||
assert.NoError(t, sr.refresh())
|
||||
})
|
||||
|
||||
t.Run("unexpected_error", func(t *testing.T) {
|
||||
_, err := NewSystemResolvers(func() string {
|
||||
return "127.0.0.1::123"
|
||||
})
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSystemResolvers_DialFunc(t *testing.T) {
|
||||
imp := createTestSystemResolversImpl(t, nil)
|
||||
|
||||
testCases := []struct {
|
||||
want error
|
||||
name string
|
||||
address string
|
||||
}{{
|
||||
want: errFakeDial,
|
||||
name: "valid_ipv4",
|
||||
address: "127.0.0.1",
|
||||
}, {
|
||||
want: errFakeDial,
|
||||
name: "valid_ipv6_port",
|
||||
address: "[::1]:53",
|
||||
}, {
|
||||
want: errFakeDial,
|
||||
name: "valid_ipv6_zone_port",
|
||||
address: "[::1%lo0]:53",
|
||||
}, {
|
||||
want: errBadAddrPassed,
|
||||
name: "invalid_split_host",
|
||||
address: "127.0.0.1::123",
|
||||
}, {
|
||||
want: errUnexpectedHostFormat,
|
||||
name: "invalid_ipv6_zone_port",
|
||||
address: "[::1%%lo0]:53",
|
||||
}, {
|
||||
want: errBadAddrPassed,
|
||||
name: "invalid_parse_ip",
|
||||
address: "not-ip",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
conn, err := imp.dialFunc(context.Background(), "", tc.address)
|
||||
require.Nil(t, conn)
|
||||
|
||||
assert.ErrorIs(t, err, tc.want)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,37 +0,0 @@
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func createTestSystemResolvers(
|
||||
t *testing.T,
|
||||
hostGenFunc HostGenFunc,
|
||||
) (sr SystemResolvers) {
|
||||
t.Helper()
|
||||
|
||||
var err error
|
||||
sr, err = NewSystemResolvers(hostGenFunc)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, sr)
|
||||
|
||||
return sr
|
||||
}
|
||||
|
||||
func TestSystemResolvers_Get(t *testing.T) {
|
||||
sr := createTestSystemResolvers(t, nil)
|
||||
|
||||
var rs []string
|
||||
require.NotPanics(t, func() {
|
||||
rs = sr.Get()
|
||||
})
|
||||
|
||||
assert.NotEmpty(t, rs)
|
||||
}
|
||||
|
||||
// TODO(e.burkov): Write tests for refreshWithTicker.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/2846.
|
||||
@@ -1,163 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// systemResolvers implementation differs for Windows since Go's resolver
|
||||
// doesn't work there.
|
||||
//
|
||||
// See https://github.com/golang/go/issues/33097.
|
||||
type systemResolvers struct {
|
||||
// addrs is the slice of cached local resolvers' addresses.
|
||||
addrs []string
|
||||
addrsLock sync.RWMutex
|
||||
}
|
||||
|
||||
func newSystemResolvers(_ HostGenFunc) (sr SystemResolvers) {
|
||||
return &systemResolvers{}
|
||||
}
|
||||
|
||||
func (sr *systemResolvers) Get() (rs []string) {
|
||||
sr.addrsLock.RLock()
|
||||
defer sr.addrsLock.RUnlock()
|
||||
|
||||
addrs := sr.addrs
|
||||
rs = make([]string, len(addrs))
|
||||
copy(rs, addrs)
|
||||
|
||||
return rs
|
||||
}
|
||||
|
||||
// writeExit writes "exit" to w and closes it. It is supposed to be run in
|
||||
// a goroutine.
|
||||
func writeExit(w io.WriteCloser) {
|
||||
defer log.OnPanic("systemResolvers: writeExit")
|
||||
|
||||
defer func() {
|
||||
derr := w.Close()
|
||||
if derr != nil {
|
||||
log.Error("systemResolvers: writeExit: closing: %s", derr)
|
||||
}
|
||||
}()
|
||||
|
||||
_, err := io.WriteString(w, "exit")
|
||||
if err != nil {
|
||||
log.Error("systemResolvers: writeExit: writing: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// scanAddrs scans the DNS addresses from nslookup's output. The expected
|
||||
// output of nslookup looks like this:
|
||||
//
|
||||
// Default Server: 192-168-1-1.qualified.domain.ru
|
||||
// Address: 192.168.1.1
|
||||
func scanAddrs(s *bufio.Scanner) (addrs []string) {
|
||||
for s.Scan() {
|
||||
line := strings.TrimSpace(s.Text())
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) != 2 || fields[0] != "Address:" {
|
||||
continue
|
||||
}
|
||||
|
||||
// If the address contains port then it is separated with '#'.
|
||||
ipPort := strings.Split(fields[1], "#")
|
||||
if len(ipPort) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
addr := ipPort[0]
|
||||
if net.ParseIP(addr) == nil {
|
||||
log.Debug("systemResolvers: %q is not a valid ip", addr)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
addrs = append(addrs, addr)
|
||||
}
|
||||
|
||||
return addrs
|
||||
}
|
||||
|
||||
// getAddrs gets local resolvers' addresses from OS in a special Windows way.
|
||||
//
|
||||
// TODO(e.burkov): This whole function needs more detailed research on getting
|
||||
// local resolvers addresses on Windows. We execute the external command for
|
||||
// now that is not the most accurate way.
|
||||
func (sr *systemResolvers) getAddrs() (addrs []string, err error) {
|
||||
var cmdPath string
|
||||
cmdPath, err = exec.LookPath("nslookup.exe")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("looking up cmd path: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command(cmdPath)
|
||||
|
||||
var stdin io.WriteCloser
|
||||
stdin, err = cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting the command's stdin pipe: %w", err)
|
||||
}
|
||||
|
||||
var stdout io.ReadCloser
|
||||
stdout, err = cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting the command's stdout pipe: %w", err)
|
||||
}
|
||||
|
||||
go writeExit(stdin)
|
||||
|
||||
err = cmd.Start()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("start command executing: %w", err)
|
||||
}
|
||||
|
||||
s := bufio.NewScanner(stdout)
|
||||
addrs = scanAddrs(s)
|
||||
|
||||
err = cmd.Wait()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("executing the command: %w", err)
|
||||
}
|
||||
|
||||
err = s.Err()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("scanning output: %w", err)
|
||||
}
|
||||
|
||||
// Don't close StdoutPipe since Wait do it for us in ¿most? cases.
|
||||
//
|
||||
// See [exec.Cmd.StdoutPipe].
|
||||
|
||||
return addrs, nil
|
||||
}
|
||||
|
||||
func (sr *systemResolvers) refresh() (err error) {
|
||||
defer func() { err = errors.Annotate(err, "systemResolvers: %w") }()
|
||||
|
||||
got, err := sr.getAddrs()
|
||||
if err != nil {
|
||||
return fmt.Errorf("can't get addresses: %w", err)
|
||||
}
|
||||
if len(got) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
sr.addrsLock.Lock()
|
||||
defer sr.addrsLock.Unlock()
|
||||
|
||||
sr.addrs = got
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,7 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
package aghnet
|
||||
|
||||
// TODO(e.burkov): Write tests for Windows implementation.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/2846.
|
||||
@@ -436,18 +436,18 @@ func (conf *ServerConfig) collectDNSAddrs() (
|
||||
// defaultPlainDNSPort is the default port for plain DNS.
|
||||
const defaultPlainDNSPort uint16 = 53
|
||||
|
||||
// upstreamMatcher is a function that matches address of an upstream.
|
||||
type upstreamMatcher func(addr netip.AddrPort) (ok bool)
|
||||
// addrPortMatcher is a function that matches an IP address with port.
|
||||
type addrPortMatcher func(addr netip.AddrPort) (ok bool)
|
||||
|
||||
// filterOut filters out all the upstreams that match um. It returns all the
|
||||
// closing errors joined.
|
||||
func (um upstreamMatcher) filterOut(upsConf *proxy.UpstreamConfig) (err error) {
|
||||
func (m addrPortMatcher) filterOut(upsConf *proxy.UpstreamConfig) (err error) {
|
||||
var errs []error
|
||||
delFunc := func(u upstream.Upstream) (ok bool) {
|
||||
// TODO(e.burkov): We should probably consider the protocol of u to
|
||||
// only filter out the listening addresses of the same protocol.
|
||||
addr, parseErr := aghnet.ParseAddrPort(u.Address(), defaultPlainDNSPort)
|
||||
if parseErr != nil || !um(addr) {
|
||||
if parseErr != nil || !m(addr) {
|
||||
// Don't filter out the upstream if it either cannot be parsed, or
|
||||
// does not match um.
|
||||
return false
|
||||
@@ -469,23 +469,21 @@ func (um upstreamMatcher) filterOut(upsConf *proxy.UpstreamConfig) (err error) {
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
// filterOurAddrs filters out all the upstreams that pointing to the local
|
||||
// listening addresses to avoid recursive queries. upsConf may appear empty
|
||||
// after the filtering. All the filtered upstreams are closed and these
|
||||
// closings errors are joined.
|
||||
func (conf *ServerConfig) filterOurAddrs(upsConf *proxy.UpstreamConfig) (err error) {
|
||||
// ourAddrsMatcher returns a matcher that matches all the configured listening
|
||||
// addresses.
|
||||
func (conf *ServerConfig) ourAddrsMatcher() (m addrPortMatcher, err error) {
|
||||
addrs, unspecPorts := conf.collectDNSAddrs()
|
||||
if len(addrs) == 0 {
|
||||
log.Debug("dnsforward: no listen addresses")
|
||||
|
||||
return nil
|
||||
// Match no addresses.
|
||||
return func(_ netip.AddrPort) (ok bool) { return false }, nil
|
||||
}
|
||||
|
||||
var matcher upstreamMatcher
|
||||
if len(unspecPorts) == 0 {
|
||||
log.Debug("dnsforward: filtering out addresses %s", addrs)
|
||||
|
||||
matcher = func(a netip.AddrPort) (ok bool) {
|
||||
m = func(a netip.AddrPort) (ok bool) {
|
||||
_, ok = addrs[a]
|
||||
|
||||
return ok
|
||||
@@ -495,12 +493,12 @@ func (conf *ServerConfig) filterOurAddrs(upsConf *proxy.UpstreamConfig) (err err
|
||||
ifaceAddrs, err = aghnet.CollectAllIfacesAddrs()
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Debug("dnsforward: filtering out addresses %s on ports %d", ifaceAddrs, unspecPorts)
|
||||
|
||||
matcher = func(a netip.AddrPort) (ok bool) {
|
||||
m = func(a netip.AddrPort) (ok bool) {
|
||||
if _, ok = unspecPorts[a.Port()]; ok {
|
||||
return slices.Contains(ifaceAddrs, a.Addr())
|
||||
}
|
||||
@@ -509,7 +507,7 @@ func (conf *ServerConfig) filterOurAddrs(upsConf *proxy.UpstreamConfig) (err err
|
||||
}
|
||||
}
|
||||
|
||||
return matcher.filterOut(upsConf)
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// prepareTLS - prepares TLS configuration for the DNS proxy
|
||||
|
||||
@@ -25,8 +25,10 @@ import (
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/netutil/sysresolv"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
// DefaultTimeout is the default upstream timeout
|
||||
@@ -72,6 +74,11 @@ type DHCP interface {
|
||||
Enabled() (ok bool)
|
||||
}
|
||||
|
||||
type SystemResolvers interface {
|
||||
// Addrs returns the list of system resolvers' addresses.
|
||||
Addrs() (addrs []netip.AddrPort)
|
||||
}
|
||||
|
||||
// Server is the main way to start a DNS server.
|
||||
//
|
||||
// Example:
|
||||
@@ -126,7 +133,7 @@ type Server struct {
|
||||
|
||||
// sysResolvers used to fetch system resolvers to use by default for private
|
||||
// PTR resolving.
|
||||
sysResolvers aghnet.SystemResolvers
|
||||
sysResolvers SystemResolvers
|
||||
|
||||
// recDetector is a cache for recursive requests. It is used to detect
|
||||
// and prevent recursive requests only for private upstreams.
|
||||
@@ -225,9 +232,7 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
|
||||
anonymizer: p.Anonymizer,
|
||||
}
|
||||
|
||||
// TODO(e.burkov): Enable the refresher after the actual implementation
|
||||
// passes the public testing.
|
||||
s.sysResolvers, err = aghnet.NewSystemResolvers(nil)
|
||||
s.sysResolvers, err = sysresolv.NewSystemResolvers(nil, defaultPlainDNSPort)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("initializing system resolvers: %w", err)
|
||||
}
|
||||
@@ -442,19 +447,30 @@ const defaultLocalTimeout = 1 * time.Second
|
||||
// setupLocalResolvers initializes the resolvers for local addresses. For
|
||||
// internal use only.
|
||||
func (s *Server) setupLocalResolvers() (err error) {
|
||||
matcher, err := s.conf.ourAddrsMatcher()
|
||||
if err != nil {
|
||||
// Don't wrap the error because it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
bootstraps := s.conf.BootstrapDNS
|
||||
resolvers := s.conf.LocalPTRResolvers
|
||||
filterConfig := false
|
||||
|
||||
if len(resolvers) == 0 {
|
||||
resolvers = s.sysResolvers.Get()
|
||||
bootstraps = nil
|
||||
sysResolvers := slices.DeleteFunc(s.sysResolvers.Addrs(), matcher)
|
||||
resolvers = make([]string, 0, len(sysResolvers))
|
||||
for _, r := range sysResolvers {
|
||||
resolvers = append(resolvers, r.String())
|
||||
}
|
||||
} else {
|
||||
resolvers = stringutil.FilterOut(resolvers, IsCommentOrEmpty)
|
||||
filterConfig = true
|
||||
}
|
||||
|
||||
log.Debug("dnsforward: upstreams to resolve ptr for local addresses: %v", resolvers)
|
||||
|
||||
uc, err := s.prepareLocalUpstreamConfig(resolvers, nil, &upstream.Options{
|
||||
uc, err := s.prepareUpstreamConfig(resolvers, nil, &upstream.Options{
|
||||
Bootstrap: bootstraps,
|
||||
Timeout: defaultLocalTimeout,
|
||||
// TODO(e.burkov): Should we verify server's certificates?
|
||||
@@ -464,6 +480,12 @@ func (s *Server) setupLocalResolvers() (err error) {
|
||||
return fmt.Errorf("preparing private upstreams: %w", err)
|
||||
}
|
||||
|
||||
if filterConfig {
|
||||
if err = matcher.filterOut(uc); err != nil {
|
||||
return fmt.Errorf("filtering private upstreams: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
s.localResolvers = &proxy.Proxy{
|
||||
Config: proxy.Config{
|
||||
UpstreamConfig: uc,
|
||||
|
||||
@@ -179,17 +179,16 @@ func (s *Server) getDNSConfig() (c *jsonDNSConfig) {
|
||||
// defaultLocalPTRUpstreams returns the list of default local PTR resolvers
|
||||
// filtered of AdGuard Home's own DNS server addresses. It may appear empty.
|
||||
func (s *Server) defaultLocalPTRUpstreams() (ups []string, err error) {
|
||||
s.serverLock.RLock()
|
||||
defer s.serverLock.RUnlock()
|
||||
|
||||
uc, err := s.prepareLocalUpstreamConfig(s.sysResolvers.Get(), nil, nil)
|
||||
matcher, err := s.conf.ourAddrsMatcher()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting system upstream config: %w", err)
|
||||
// Don't wrap the error because it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
defer func() { err = errors.Join(err, uc.Close()) }()
|
||||
|
||||
for _, u := range uc.Upstreams {
|
||||
ups = append(ups, u.Address())
|
||||
sysResolvers := slices.DeleteFunc(s.sysResolvers.Addrs(), matcher)
|
||||
ups = make([]string, 0, len(sysResolvers))
|
||||
for _, r := range sysResolvers {
|
||||
ups = append(ups, r.String())
|
||||
}
|
||||
|
||||
return ups, nil
|
||||
@@ -228,7 +227,7 @@ func (req *jsonDNSConfig) checkBootstrap() (err error) {
|
||||
return errors.Error("empty")
|
||||
}
|
||||
|
||||
if _, err = upstream.NewResolver(b, nil); err != nil {
|
||||
if _, err = upstream.NewUpstreamResolver(b, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,17 +28,12 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// fakeSystemResolvers is a mock aghnet.SystemResolvers implementation for
|
||||
// tests.
|
||||
type fakeSystemResolvers struct {
|
||||
// SystemResolvers is embedded here simply to make *fakeSystemResolvers
|
||||
// an aghnet.SystemResolvers without actually implementing all methods.
|
||||
aghnet.SystemResolvers
|
||||
}
|
||||
// emptySysResolvers is an empty [SystemResolvers] implementation that always
|
||||
// returns nil.
|
||||
type emptySysResolvers struct{}
|
||||
|
||||
// Get implements the aghnet.SystemResolvers interface for *fakeSystemResolvers.
|
||||
// It always returns nil.
|
||||
func (fsr *fakeSystemResolvers) Get() (rs []string) {
|
||||
// Addrs implements the aghnet.SystemResolvers interface for emptySysResolvers.
|
||||
func (emptySysResolvers) Addrs() (addrs []netip.AddrPort) {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -79,7 +74,7 @@ func TestDNSForwardHTTP_handleGetConfig(t *testing.T) {
|
||||
ConfigModified: func() {},
|
||||
}
|
||||
s := createTestServer(t, filterConf, forwardConf, nil)
|
||||
s.sysResolvers = &fakeSystemResolvers{}
|
||||
s.sysResolvers = &emptySysResolvers{}
|
||||
|
||||
require.NoError(t, s.Start())
|
||||
testutil.CleanupAndRequireSuccess(t, s.Stop)
|
||||
@@ -156,7 +151,7 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
|
||||
ConfigModified: func() {},
|
||||
}
|
||||
s := createTestServer(t, filterConf, forwardConf, nil)
|
||||
s.sysResolvers = &fakeSystemResolvers{}
|
||||
s.sysResolvers = &emptySysResolvers{}
|
||||
|
||||
defaultConf := s.conf
|
||||
|
||||
@@ -485,7 +480,7 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) {
|
||||
hostsListener := newLocalUpstreamListener(t, 0, goodHandler)
|
||||
hostsUps := (&url.URL{
|
||||
Scheme: "tcp",
|
||||
Host: netutil.JoinHostPort(upstreamHost, int(hostsListener.Port())),
|
||||
Host: netutil.JoinHostPort(upstreamHost, hostsListener.Port()),
|
||||
}).String()
|
||||
|
||||
hc, err := aghnet.NewHostsContainer(
|
||||
|
||||
@@ -103,27 +103,6 @@ func (s *Server) prepareUpstreamConfig(
|
||||
return uc, nil
|
||||
}
|
||||
|
||||
// prepareLocalUpstreamConfig returns the upstream configuration for private
|
||||
// upstreams based on upstreams and configuration of s. It also filters out
|
||||
// the own listening addresses from the upstreams, so it may appear empty.
|
||||
func (s *Server) prepareLocalUpstreamConfig(
|
||||
upstreams []string,
|
||||
defaultUpstreams []string,
|
||||
opts *upstream.Options,
|
||||
) (uc *proxy.UpstreamConfig, err error) {
|
||||
uc, err = s.prepareUpstreamConfig(upstreams, defaultUpstreams, opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("preparing private upstreams: %w", err)
|
||||
}
|
||||
|
||||
err = s.conf.filterOurAddrs(uc)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("filtering private upstreams: %w", err)
|
||||
}
|
||||
|
||||
return uc, nil
|
||||
}
|
||||
|
||||
// replaceUpstreamsWithHosts replaces unique upstreams with their resolved
|
||||
// versions based on the system hosts file.
|
||||
//
|
||||
|
||||
@@ -182,7 +182,7 @@ type httpPprofConfig struct {
|
||||
// not absolutely necessary.
|
||||
type dnsConfig struct {
|
||||
BindHosts []netip.Addr `yaml:"bind_hosts"`
|
||||
Port int `yaml:"port"`
|
||||
Port uint16 `yaml:"port"`
|
||||
|
||||
// AnonymizeClientIP defines if clients' IP addresses should be anonymized
|
||||
// in query log and statistics.
|
||||
@@ -232,13 +232,13 @@ type tlsConfigSettings struct {
|
||||
Enabled bool `yaml:"enabled" json:"enabled"` // Enabled is the encryption (DoT/DoH/HTTPS) status
|
||||
ServerName string `yaml:"server_name" json:"server_name,omitempty"` // ServerName is the hostname of your HTTPS/TLS server
|
||||
ForceHTTPS bool `yaml:"force_https" json:"force_https"` // ForceHTTPS: if true, forces HTTP->HTTPS redirect
|
||||
PortHTTPS int `yaml:"port_https" json:"port_https,omitempty"` // HTTPS port. If 0, HTTPS will be disabled
|
||||
PortDNSOverTLS int `yaml:"port_dns_over_tls" json:"port_dns_over_tls,omitempty"` // DNS-over-TLS port. If 0, DoT will be disabled
|
||||
PortDNSOverQUIC int `yaml:"port_dns_over_quic" json:"port_dns_over_quic,omitempty"` // DNS-over-QUIC port. If 0, DoQ will be disabled
|
||||
PortHTTPS uint16 `yaml:"port_https" json:"port_https,omitempty"` // HTTPS port. If 0, HTTPS will be disabled
|
||||
PortDNSOverTLS uint16 `yaml:"port_dns_over_tls" json:"port_dns_over_tls,omitempty"` // DNS-over-TLS port. If 0, DoT will be disabled
|
||||
PortDNSOverQUIC uint16 `yaml:"port_dns_over_quic" json:"port_dns_over_quic,omitempty"` // DNS-over-QUIC port. If 0, DoQ will be disabled
|
||||
|
||||
// PortDNSCrypt is the port for DNSCrypt requests. If it's zero,
|
||||
// DNSCrypt is disabled.
|
||||
PortDNSCrypt int `yaml:"port_dnscrypt" json:"port_dnscrypt"`
|
||||
PortDNSCrypt uint16 `yaml:"port_dnscrypt" json:"port_dnscrypt"`
|
||||
// DNSCryptConfigFile is the path to the DNSCrypt config file. Must be
|
||||
// set if PortDNSCrypt is not zero.
|
||||
//
|
||||
@@ -554,10 +554,10 @@ func validateConfig() (err error) {
|
||||
}
|
||||
|
||||
// udpPort is the port number for UDP protocol.
|
||||
type udpPort int
|
||||
type udpPort uint16
|
||||
|
||||
// tcpPort is the port number for TCP protocol.
|
||||
type tcpPort int
|
||||
type tcpPort uint16
|
||||
|
||||
// addPorts is a helper for ports validation that skips zero ports.
|
||||
func addPorts[T tcpPort | udpPort](uc aghalg.UniqChecker[T], ports ...T) {
|
||||
|
||||
@@ -24,11 +24,9 @@ import (
|
||||
// addresses to a slice of strings.
|
||||
func appendDNSAddrs(dst []string, addrs ...netip.Addr) (res []string) {
|
||||
for _, addr := range addrs {
|
||||
var hostport string
|
||||
if config.DNS.Port != defaultPortDNS {
|
||||
hostport = netip.AddrPortFrom(addr, uint16(config.DNS.Port)).String()
|
||||
} else {
|
||||
hostport = addr.String()
|
||||
hostport := addr.String()
|
||||
if p := config.DNS.Port; p != defaultPortDNS {
|
||||
hostport = netutil.JoinHostPort(hostport, p)
|
||||
}
|
||||
|
||||
dst = append(dst, hostport)
|
||||
@@ -102,7 +100,7 @@ type statusResponse struct {
|
||||
Version string `json:"version"`
|
||||
Language string `json:"language"`
|
||||
DNSAddrs []string `json:"dns_addresses"`
|
||||
DNSPort int `json:"dns_port"`
|
||||
DNSPort uint16 `json:"dns_port"`
|
||||
HTTPPort uint16 `json:"http_port"`
|
||||
|
||||
// ProtectionDisabledDuration is the duration of the protection pause in
|
||||
@@ -340,7 +338,7 @@ func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (proceed bool)
|
||||
var (
|
||||
forceHTTPS bool
|
||||
serveHTTP3 bool
|
||||
portHTTPS int
|
||||
portHTTPS uint16
|
||||
)
|
||||
func() {
|
||||
config.RLock()
|
||||
@@ -394,7 +392,7 @@ func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (proceed bool)
|
||||
|
||||
// httpsURL returns a copy of u for redirection to the HTTPS version, taking the
|
||||
// hostname and the HTTPS port into account.
|
||||
func httpsURL(u *url.URL, host string, portHTTPS int) (redirectURL *url.URL) {
|
||||
func httpsURL(u *url.URL, host string, portHTTPS uint16) (redirectURL *url.URL) {
|
||||
hostPort := host
|
||||
if portHTTPS != defaultPortHTTPS {
|
||||
hostPort = netutil.JoinHostPort(host, portHTTPS)
|
||||
|
||||
@@ -43,8 +43,8 @@ func (web *webAPI) handleInstallGetAddresses(w http.ResponseWriter, r *http.Requ
|
||||
data := getAddrsResponse{
|
||||
Version: version.Version(),
|
||||
|
||||
WebPort: defaultPortHTTP,
|
||||
DNSPort: defaultPortDNS,
|
||||
WebPort: int(defaultPortHTTP),
|
||||
DNSPort: int(defaultPortDNS),
|
||||
}
|
||||
|
||||
ifaces, err := aghnet.GetValidNetInterfacesForWeb()
|
||||
@@ -64,7 +64,7 @@ func (web *webAPI) handleInstallGetAddresses(w http.ResponseWriter, r *http.Requ
|
||||
|
||||
type checkConfReqEnt struct {
|
||||
IP netip.Addr `json:"ip"`
|
||||
Port int `json:"port"`
|
||||
Port uint16 `json:"port"`
|
||||
Autofix bool `json:"autofix"`
|
||||
}
|
||||
|
||||
@@ -97,7 +97,7 @@ func (req *checkConfReq) validateWeb(tcpPorts aghalg.UniqChecker[tcpPort]) (err
|
||||
defer func() { err = errors.Annotate(err, "validating ports: %w") }()
|
||||
|
||||
// TODO(a.garipov): Declare all port variables anywhere as uint16.
|
||||
reqPort := uint16(req.Web.Port)
|
||||
reqPort := req.Web.Port
|
||||
port := tcpPort(reqPort)
|
||||
addPorts(tcpPorts, port)
|
||||
if err = tcpPorts.Validate(); err != nil {
|
||||
@@ -128,7 +128,7 @@ func (req *checkConfReq) validateDNS(
|
||||
) (canAutofix bool, err error) {
|
||||
defer func() { err = errors.Annotate(err, "validating ports: %w") }()
|
||||
|
||||
port := uint16(req.DNS.Port)
|
||||
port := req.DNS.Port
|
||||
switch port {
|
||||
case 0:
|
||||
return false, nil
|
||||
@@ -142,13 +142,13 @@ func (req *checkConfReq) validateDNS(
|
||||
return false, err
|
||||
}
|
||||
|
||||
err = aghnet.CheckPort("tcp", netip.AddrPortFrom(req.DNS.IP, uint16(port)))
|
||||
err = aghnet.CheckPort("tcp", netip.AddrPortFrom(req.DNS.IP, port))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
err = aghnet.CheckPort("udp", netip.AddrPortFrom(req.DNS.IP, uint16(port)))
|
||||
err = aghnet.CheckPort("udp", netip.AddrPortFrom(req.DNS.IP, port))
|
||||
if !aghnet.IsAddrInUse(err) {
|
||||
return false, err
|
||||
}
|
||||
@@ -160,7 +160,7 @@ func (req *checkConfReq) validateDNS(
|
||||
log.Error("disabling DNSStubListener: %s", err)
|
||||
}
|
||||
|
||||
err = aghnet.CheckPort("udp", netip.AddrPortFrom(req.DNS.IP, uint16(port)))
|
||||
err = aghnet.CheckPort("udp", netip.AddrPortFrom(req.DNS.IP, port))
|
||||
canAutofix = false
|
||||
}
|
||||
|
||||
@@ -305,7 +305,7 @@ func disableDNSStubListener() error {
|
||||
|
||||
type applyConfigReqEnt struct {
|
||||
IP netip.Addr `json:"ip"`
|
||||
Port int `json:"port"`
|
||||
Port uint16 `json:"port"`
|
||||
}
|
||||
|
||||
type applyConfigReq struct {
|
||||
@@ -395,14 +395,14 @@ func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request
|
||||
return
|
||||
}
|
||||
|
||||
err = aghnet.CheckPort("udp", netip.AddrPortFrom(req.DNS.IP, uint16(req.DNS.Port)))
|
||||
err = aghnet.CheckPort("udp", netip.AddrPortFrom(req.DNS.IP, req.DNS.Port))
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = aghnet.CheckPort("tcp", netip.AddrPortFrom(req.DNS.IP, uint16(req.DNS.Port)))
|
||||
err = aghnet.CheckPort("tcp", netip.AddrPortFrom(req.DNS.IP, req.DNS.Port))
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
@@ -413,7 +413,7 @@ func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request
|
||||
copyInstallSettings(curConfig, config)
|
||||
|
||||
Context.firstRun = false
|
||||
config.HTTPConfig.Address = netip.AddrPortFrom(req.Web.IP, uint16(req.Web.Port))
|
||||
config.HTTPConfig.Address = netip.AddrPortFrom(req.Web.IP, req.Web.Port)
|
||||
config.DNS.BindHosts = []netip.Addr{req.DNS.IP}
|
||||
config.DNS.Port = req.DNS.Port
|
||||
|
||||
@@ -445,8 +445,7 @@ func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request
|
||||
}
|
||||
|
||||
web.conf.firstRun = false
|
||||
web.conf.BindHost = req.Web.IP
|
||||
web.conf.BindPort = req.Web.Port
|
||||
web.conf.BindAddr = netip.AddrPortFrom(req.Web.IP, req.Web.Port)
|
||||
|
||||
registerControlHandlers(web)
|
||||
|
||||
@@ -487,9 +486,9 @@ func decodeApplyConfigReq(r io.Reader) (req *applyConfigReq, restartHTTP bool, e
|
||||
}
|
||||
|
||||
addrPort := config.HTTPConfig.Address
|
||||
restartHTTP = addrPort.Addr() != req.Web.IP || int(addrPort.Port()) != req.Web.Port
|
||||
restartHTTP = addrPort.Addr() != req.Web.IP || addrPort.Port() != req.Web.Port
|
||||
if restartHTTP {
|
||||
err = aghnet.CheckPort("tcp", netip.AddrPortFrom(req.Web.IP, uint16(req.Web.Port)))
|
||||
err = aghnet.CheckPort("tcp", netip.AddrPortFrom(req.Web.IP, req.Web.Port))
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf(
|
||||
"checking address %s:%d: %w",
|
||||
|
||||
@@ -27,11 +27,11 @@ import (
|
||||
|
||||
// Default listening ports.
|
||||
const (
|
||||
defaultPortDNS = 53
|
||||
defaultPortHTTP = 80
|
||||
defaultPortHTTPS = 443
|
||||
defaultPortQUIC = 853
|
||||
defaultPortTLS = 853
|
||||
defaultPortDNS uint16 = 53
|
||||
defaultPortHTTP uint16 = 80
|
||||
defaultPortHTTPS uint16 = 443
|
||||
defaultPortQUIC uint16 = 853
|
||||
defaultPortTLS uint16 = 853
|
||||
)
|
||||
|
||||
// Called by other modules when configuration is changed
|
||||
@@ -196,27 +196,27 @@ func isRunning() bool {
|
||||
return Context.dnsServer != nil && Context.dnsServer.IsRunning()
|
||||
}
|
||||
|
||||
func ipsToTCPAddrs(ips []netip.Addr, port int) (tcpAddrs []*net.TCPAddr) {
|
||||
func ipsToTCPAddrs(ips []netip.Addr, port uint16) (tcpAddrs []*net.TCPAddr) {
|
||||
if ips == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
tcpAddrs = make([]*net.TCPAddr, 0, len(ips))
|
||||
for _, ip := range ips {
|
||||
tcpAddrs = append(tcpAddrs, net.TCPAddrFromAddrPort(netip.AddrPortFrom(ip, uint16(port))))
|
||||
tcpAddrs = append(tcpAddrs, net.TCPAddrFromAddrPort(netip.AddrPortFrom(ip, port)))
|
||||
}
|
||||
|
||||
return tcpAddrs
|
||||
}
|
||||
|
||||
func ipsToUDPAddrs(ips []netip.Addr, port int) (udpAddrs []*net.UDPAddr) {
|
||||
func ipsToUDPAddrs(ips []netip.Addr, port uint16) (udpAddrs []*net.UDPAddr) {
|
||||
if ips == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
udpAddrs = make([]*net.UDPAddr, 0, len(ips))
|
||||
for _, ip := range ips {
|
||||
udpAddrs = append(udpAddrs, net.UDPAddrFromAddrPort(netip.AddrPortFrom(ip, uint16(port))))
|
||||
udpAddrs = append(udpAddrs, net.UDPAddrFromAddrPort(netip.AddrPortFrom(ip, port)))
|
||||
}
|
||||
|
||||
return udpAddrs
|
||||
@@ -346,8 +346,8 @@ func getDNSEncryption() (de dnsEncryption) {
|
||||
hostname := tlsConf.ServerName
|
||||
if tlsConf.PortHTTPS != 0 {
|
||||
addr := hostname
|
||||
if tlsConf.PortHTTPS != defaultPortHTTPS {
|
||||
addr = netutil.JoinHostPort(addr, tlsConf.PortHTTPS)
|
||||
if p := tlsConf.PortHTTPS; p != defaultPortHTTPS {
|
||||
addr = netutil.JoinHostPort(addr, p)
|
||||
}
|
||||
|
||||
de.https = (&url.URL{
|
||||
@@ -357,17 +357,17 @@ func getDNSEncryption() (de dnsEncryption) {
|
||||
}).String()
|
||||
}
|
||||
|
||||
if tlsConf.PortDNSOverTLS != 0 {
|
||||
if p := tlsConf.PortDNSOverTLS; p != 0 {
|
||||
de.tls = (&url.URL{
|
||||
Scheme: "tls",
|
||||
Host: netutil.JoinHostPort(hostname, tlsConf.PortDNSOverTLS),
|
||||
Host: netutil.JoinHostPort(hostname, p),
|
||||
}).String()
|
||||
}
|
||||
|
||||
if tlsConf.PortDNSOverQUIC != 0 {
|
||||
if p := tlsConf.PortDNSOverQUIC; p != 0 {
|
||||
de.quic = (&url.URL{
|
||||
Scheme: "quic",
|
||||
Host: netutil.JoinHostPort(hostname, tlsConf.PortDNSOverQUIC),
|
||||
Host: netutil.JoinHostPort(hostname, p),
|
||||
}).String()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -329,7 +329,7 @@ func setupBindOpts(opts options) (err error) {
|
||||
if opts.bindPort != 0 {
|
||||
config.HTTPConfig.Address = netip.AddrPortFrom(
|
||||
config.HTTPConfig.Address.Addr(),
|
||||
uint16(opts.bindPort),
|
||||
opts.bindPort,
|
||||
)
|
||||
|
||||
err = checkPorts()
|
||||
@@ -497,8 +497,7 @@ func initWeb(opts options, clientBuildFS fs.FS, upd *updater.Updater) (web *webA
|
||||
|
||||
clientFS: clientFS,
|
||||
|
||||
BindHost: config.HTTPConfig.Address.Addr(),
|
||||
BindPort: int(config.HTTPConfig.Address.Port()),
|
||||
BindAddr: config.HTTPConfig.Address,
|
||||
|
||||
ReadTimeout: readTimeout,
|
||||
ReadHeaderTimeout: readHdrTimeout,
|
||||
@@ -853,7 +852,7 @@ func loadCmdLineOpts() (opts options) {
|
||||
// example:
|
||||
//
|
||||
// go to http://127.0.0.1:80
|
||||
func printWebAddrs(proto, addr string, port int) {
|
||||
func printWebAddrs(proto, addr string, port uint16) {
|
||||
log.Printf("go to %s://%s", proto, netutil.JoinHostPort(addr, port))
|
||||
}
|
||||
|
||||
@@ -865,7 +864,7 @@ func printHTTPAddresses(proto string) {
|
||||
Context.tls.WriteDiskConfig(&tlsConf)
|
||||
}
|
||||
|
||||
port := int(config.HTTPConfig.Address.Port())
|
||||
port := config.HTTPConfig.Address.Port()
|
||||
if proto == aghhttp.SchemeHTTPS {
|
||||
port = tlsConf.PortHTTPS
|
||||
}
|
||||
|
||||
@@ -42,7 +42,7 @@ type options struct {
|
||||
// bindPort is the port on which to serve the HTTP UI.
|
||||
//
|
||||
// Deprecated: Use bindAddr.
|
||||
bindPort int
|
||||
bindPort uint16
|
||||
|
||||
// bindAddr is the address to serve the web UI on.
|
||||
bindAddr netip.AddrPort
|
||||
@@ -160,15 +160,11 @@ var cmdLineOpts = []cmdLineOpt{{
|
||||
shortName: "h",
|
||||
}, {
|
||||
updateWithValue: func(o options, v string) (options, error) {
|
||||
var err error
|
||||
var p int
|
||||
minPort, maxPort := 0, 1<<16-1
|
||||
if p, err = strconv.Atoi(v); err != nil {
|
||||
err = fmt.Errorf("port %q is not a number", v)
|
||||
} else if p < minPort || p > maxPort {
|
||||
err = fmt.Errorf("port %d not in range %d - %d", p, minPort, maxPort)
|
||||
p, err := strconv.ParseUint(v, 10, 16)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("parsing port: %w", err)
|
||||
} else {
|
||||
o.bindPort = p
|
||||
o.bindPort = uint16(p)
|
||||
}
|
||||
|
||||
return o, err
|
||||
@@ -180,7 +176,7 @@ var cmdLineOpts = []cmdLineOpt{{
|
||||
return "", false
|
||||
}
|
||||
|
||||
return strconv.Itoa(o.bindPort), true
|
||||
return strconv.Itoa(int(o.bindPort)), true
|
||||
},
|
||||
description: "Deprecated. Port to serve HTTP pages on. Use --web-addr.",
|
||||
longName: "port",
|
||||
|
||||
@@ -67,11 +67,11 @@ func TestParseBindHost(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestParseBindPort(t *testing.T) {
|
||||
assert.Equal(t, 0, testParseOK(t).bindPort, "empty is port 0")
|
||||
assert.Equal(t, 65535, testParseOK(t, "-p", "65535").bindPort, "-p is port")
|
||||
assert.Equal(t, uint16(0), testParseOK(t).bindPort, "empty is port 0")
|
||||
assert.Equal(t, uint16(65535), testParseOK(t, "-p", "65535").bindPort, "-p is port")
|
||||
testParseParamMissing(t, "-p")
|
||||
|
||||
assert.Equal(t, 65535, testParseOK(t, "--port", "65535").bindPort, "--port is port")
|
||||
assert.Equal(t, uint16(65535), testParseOK(t, "--port", "65535").bindPort, "--port is port")
|
||||
testParseParamMissing(t, "--port")
|
||||
|
||||
testParseErr(t, "not an int", "-p", "x")
|
||||
|
||||
@@ -39,8 +39,8 @@ type webConfig struct {
|
||||
|
||||
clientFS fs.FS
|
||||
|
||||
BindHost netip.Addr
|
||||
BindPort int
|
||||
// BindAddr is the binding address with port for plain HTTP web interface.
|
||||
BindAddr netip.AddrPort
|
||||
|
||||
// ReadTimeout is an option to pass to http.Server for setting an
|
||||
// appropriate field.
|
||||
@@ -125,12 +125,12 @@ func newWebAPI(conf *webConfig) (w *webAPI) {
|
||||
// available, unless the HTTPS server isn't active.
|
||||
//
|
||||
// TODO(a.garipov): Adapt for HTTP/3.
|
||||
func webCheckPortAvailable(port int) (ok bool) {
|
||||
func webCheckPortAvailable(port uint16) (ok bool) {
|
||||
if Context.web.httpsServer.server != nil {
|
||||
return true
|
||||
}
|
||||
|
||||
addrPort := netip.AddrPortFrom(config.HTTPConfig.Address.Addr(), uint16(port))
|
||||
addrPort := netip.AddrPortFrom(config.HTTPConfig.Address.Addr(), port)
|
||||
|
||||
return aghnet.CheckPort("tcp", addrPort) == nil
|
||||
}
|
||||
@@ -185,10 +185,9 @@ func (web *webAPI) start() {
|
||||
hdlr := h2c.NewHandler(withMiddlewares(Context.mux, limitRequestBody), &http2.Server{})
|
||||
|
||||
// Create a new instance, because the Web is not usable after Shutdown.
|
||||
hostStr := web.conf.BindHost.String()
|
||||
web.httpServer = &http.Server{
|
||||
ErrorLog: log.StdLog("web: plain", log.DEBUG),
|
||||
Addr: netutil.JoinHostPort(hostStr, web.conf.BindPort),
|
||||
Addr: web.conf.BindAddr.String(),
|
||||
Handler: hdlr,
|
||||
ReadTimeout: web.conf.ReadTimeout,
|
||||
ReadHeaderTimeout: web.conf.ReadHeaderTimeout,
|
||||
@@ -249,7 +248,7 @@ func (web *webAPI) tlsServerLoop() {
|
||||
|
||||
web.httpsServer.cond.L.Unlock()
|
||||
|
||||
var portHTTPS int
|
||||
var portHTTPS uint16
|
||||
func() {
|
||||
config.RLock()
|
||||
defer config.RUnlock()
|
||||
@@ -257,7 +256,7 @@ func (web *webAPI) tlsServerLoop() {
|
||||
portHTTPS = config.TLS.PortHTTPS
|
||||
}()
|
||||
|
||||
addr := netutil.JoinHostPort(web.conf.BindHost.String(), portHTTPS)
|
||||
addr := netip.AddrPortFrom(web.conf.BindAddr.Addr(), portHTTPS).String()
|
||||
web.httpsServer.server = &http.Server{
|
||||
ErrorLog: log.StdLog("web: https", log.DEBUG),
|
||||
Addr: addr,
|
||||
|
||||
Reference in New Issue
Block a user