Compare commits
6 Commits
v0.108.0-b
...
4927-refac
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0c7d56dca3 | ||
|
|
08282dc4d9 | ||
|
|
f36efa26a4 | ||
|
|
a8850059db | ||
|
|
93882d6860 | ||
|
|
167b112511 |
@@ -25,6 +25,13 @@ See also the [v0.107.19 GitHub milestone][ms-v0.107.19].
|
||||
[ms-v0.107.19]: https://github.com/AdguardTeam/AdGuardHome/milestone/55?closed=1
|
||||
-->
|
||||
|
||||
### Added
|
||||
|
||||
- The new `--update` command-line option, which allows updating AdGuard Home
|
||||
silently ([#4223]).
|
||||
|
||||
[#4223]: https://github.com/AdguardTeam/AdGuardHome/issues/4223
|
||||
|
||||
|
||||
|
||||
## [v0.107.18] - 2022-11-08
|
||||
|
||||
@@ -393,6 +393,7 @@
|
||||
"encryption_issuer": "Issuer",
|
||||
"encryption_hostnames": "Hostnames",
|
||||
"encryption_reset": "Are you sure you want to reset encryption settings?",
|
||||
"encryption_warning": "Warning",
|
||||
"topline_expiring_certificate": "Your SSL certificate is about to expire. Update <0>Encryption settings</0>.",
|
||||
"topline_expired_certificate": "Your SSL certificate is expired. Update <0>Encryption settings</0>.",
|
||||
"form_error_port_range": "Enter port number in the range of 80-65535",
|
||||
|
||||
@@ -56,6 +56,26 @@ const clearFields = (change, setTlsConfig, t) => {
|
||||
}
|
||||
};
|
||||
|
||||
const validationMessage = (warningValidation, isWarning) => {
|
||||
if (!warningValidation) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (isWarning) {
|
||||
return (
|
||||
<div className="col-12">
|
||||
<p><Trans>encryption_warning</Trans>: {warningValidation}</p>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="col-12">
|
||||
<p className="text-danger">{warningValidation}</p>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
let Form = (props) => {
|
||||
const {
|
||||
t,
|
||||
@@ -95,6 +115,8 @@ let Form = (props) => {
|
||||
|| !valid_cert
|
||||
|| !valid_pair;
|
||||
|
||||
const isWarning = valid_key && valid_cert && valid_pair;
|
||||
|
||||
return (
|
||||
<form onSubmit={handleSubmit}>
|
||||
<div className="row">
|
||||
@@ -382,11 +404,7 @@ let Form = (props) => {
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
{warning_validation && (
|
||||
<div className="col-12">
|
||||
<p className="text-danger">{warning_validation}</p>
|
||||
</div>
|
||||
)}
|
||||
{validationMessage(warning_validation, isWarning)}
|
||||
</div>
|
||||
|
||||
<div className="btn-list mt-2">
|
||||
|
||||
@@ -31,12 +31,6 @@ var (
|
||||
// the IP being static is available.
|
||||
const ErrNoStaticIPInfo errors.Error = "no information about static ip"
|
||||
|
||||
// IPv4Localhost returns 127.0.0.1, which returns true for [netip.Addr.Is4].
|
||||
func IPv4Localhost() (ip netip.Addr) { return netip.AddrFrom4([4]byte{127, 0, 0, 1}) }
|
||||
|
||||
// IPv6Localhost returns ::1, which returns true for [netip.Addr.Is6].
|
||||
func IPv6Localhost() (ip netip.Addr) { return netip.AddrFrom16([16]byte{15: 1}) }
|
||||
|
||||
// IfaceHasStaticIP checks if interface is configured to have static IP address.
|
||||
// If it can't give a definitive answer, it returns false and an error for which
|
||||
// errors.Is(err, ErrNoStaticIPInfo) is true.
|
||||
|
||||
@@ -188,7 +188,7 @@ func TestBroadcastFromIPNet(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCheckPort(t *testing.T) {
|
||||
laddr := netip.AddrPortFrom(IPv4Localhost(), 0)
|
||||
laddr := netip.AddrPortFrom(netutil.IPv4Localhost(), 0)
|
||||
|
||||
t.Run("tcp_bound", func(t *testing.T) {
|
||||
l, err := net.Listen("tcp", laddr.String())
|
||||
|
||||
@@ -23,16 +23,6 @@ func ValidateClientID(id string) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// hasLabelSuffix returns true if s ends with suffix preceded by a dot. It's
|
||||
// a helper function to prevent unnecessary allocations in code like:
|
||||
//
|
||||
// if strings.HasSuffix(s, "." + suffix) { /* … */ }
|
||||
//
|
||||
// s must be longer than suffix.
|
||||
func hasLabelSuffix(s, suffix string) (ok bool) {
|
||||
return strings.HasSuffix(s, suffix) && s[len(s)-len(suffix)-1] == '.'
|
||||
}
|
||||
|
||||
// clientIDFromClientServerName extracts and validates a ClientID. hostSrvName
|
||||
// is the server name of the host. cliSrvName is the server name as sent by the
|
||||
// client. When strict is true, and client and host server name don't match,
|
||||
@@ -46,7 +36,7 @@ func clientIDFromClientServerName(
|
||||
return "", nil
|
||||
}
|
||||
|
||||
if !hasLabelSuffix(cliSrvName, hostSrvName) {
|
||||
if !netutil.IsImmediateSubdomain(cliSrvName, hostSrvName) {
|
||||
if !strict {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
@@ -145,7 +145,8 @@ type FilteringConfig struct {
|
||||
IpsetListFileName string `yaml:"ipset_file"`
|
||||
}
|
||||
|
||||
// TLSConfig is the TLS configuration for HTTPS, DNS-over-HTTPS, and DNS-over-TLS
|
||||
// TLSConfig is the TLS configuration for HTTPS, DNS-over-HTTPS, DNS-over-TLS,
|
||||
// and DNS-over-QUIC.
|
||||
type TLSConfig struct {
|
||||
cert tls.Certificate
|
||||
|
||||
|
||||
@@ -246,6 +246,7 @@ type RDNSExchanger interface {
|
||||
// Exchange tries to resolve the ip in a suitable way, e.g. either as
|
||||
// local or as external.
|
||||
Exchange(ip net.IP) (host string, err error)
|
||||
|
||||
// ResolvesPrivatePTR returns true if the RDNSExchanger is able to
|
||||
// resolve PTR requests for locally-served addresses.
|
||||
ResolvesPrivatePTR() (ok bool)
|
||||
@@ -261,6 +262,9 @@ const (
|
||||
rDNSNotPTRErr errors.Error = "the response is not a ptr"
|
||||
)
|
||||
|
||||
// type check
|
||||
var _ RDNSExchanger = (*Server)(nil)
|
||||
|
||||
// Exchange implements the RDNSExchanger interface for *Server.
|
||||
func (s *Server) Exchange(ip net.IP) (host string, err error) {
|
||||
s.serverLock.RLock()
|
||||
@@ -675,21 +679,13 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// IsBlockedClient returns true if the client is blocked by the current access
|
||||
// settings.
|
||||
func (s *Server) IsBlockedClient(ip net.IP, clientID string) (blocked bool, rule string) {
|
||||
func (s *Server) IsBlockedClient(ip netip.Addr, clientID string) (blocked bool, rule string) {
|
||||
s.serverLock.RLock()
|
||||
defer s.serverLock.RUnlock()
|
||||
|
||||
blockedByIP := false
|
||||
if ip != nil {
|
||||
// TODO(a.garipov): Remove once we switch to netip.Addr more fully.
|
||||
ipAddr, err := netutil.IPToAddrNoMapped(ip)
|
||||
if err != nil {
|
||||
log.Error("dnsforward: bad client ip %v: %s", ip, err)
|
||||
|
||||
return false, ""
|
||||
}
|
||||
|
||||
blockedByIP, rule = s.access.isBlockedIP(ipAddr)
|
||||
if ip != (netip.Addr{}) {
|
||||
blockedByIP, rule = s.access.isBlockedIP(ip)
|
||||
}
|
||||
|
||||
allowlistMode := s.access.allowlistMode()
|
||||
|
||||
@@ -19,13 +19,13 @@ func (s *Server) beforeRequestHandler(
|
||||
_ *proxy.Proxy,
|
||||
pctx *proxy.DNSContext,
|
||||
) (reply bool, err error) {
|
||||
ip, _ := netutil.IPAndPortFromAddr(pctx.Addr)
|
||||
clientID, err := s.clientIDFromDNSContext(pctx)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("getting clientid: %w", err)
|
||||
}
|
||||
|
||||
blocked, _ := s.IsBlockedClient(ip, clientID)
|
||||
addrPort := netutil.NetAddrToAddrPort(pctx.Addr)
|
||||
blocked, _ := s.IsBlockedClient(addrPort.Addr(), clientID)
|
||||
if blocked {
|
||||
return s.preBlockedResponse(pctx)
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -40,7 +40,7 @@ func serveFiltersLocally(t *testing.T, fltContent []byte) (ipp netip.AddrPort) {
|
||||
addr := l.Addr()
|
||||
require.IsType(t, new(net.TCPAddr), addr)
|
||||
|
||||
return netip.AddrPortFrom(aghnet.IPv4Localhost(), uint16(addr.(*net.TCPAddr).Port))
|
||||
return netip.AddrPortFrom(netutil.IPv4Localhost(), uint16(addr.(*net.TCPAddr).Port))
|
||||
}
|
||||
|
||||
func TestFilters(t *testing.T) {
|
||||
|
||||
@@ -129,7 +129,7 @@ type RuntimeClientWHOISInfo struct {
|
||||
|
||||
type clientsContainer struct {
|
||||
// TODO(a.garipov): Perhaps use a number of separate indices for
|
||||
// different types (string, net.IP, and so on).
|
||||
// different types (string, netip.Addr, and so on).
|
||||
list map[string]*Client // name -> client
|
||||
idIndex map[string]*Client // ID -> client
|
||||
|
||||
@@ -333,7 +333,7 @@ func (clients *clientsContainer) onDHCPLeaseChanged(flags int) {
|
||||
}
|
||||
|
||||
// exists checks if client with this IP address already exists.
|
||||
func (clients *clientsContainer) exists(ip net.IP, source clientSource) (ok bool) {
|
||||
func (clients *clientsContainer) exists(ip netip.Addr, source clientSource) (ok bool) {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
@@ -342,7 +342,7 @@ func (clients *clientsContainer) exists(ip net.IP, source clientSource) (ok bool
|
||||
return true
|
||||
}
|
||||
|
||||
rc, ok := clients.findRuntimeClientLocked(ip)
|
||||
rc, ok := clients.ipToRC[ip]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
@@ -371,7 +371,8 @@ func (clients *clientsContainer) findMultiple(ids []string) (c *querylog.Client,
|
||||
var artClient *querylog.Client
|
||||
var art bool
|
||||
for _, id := range ids {
|
||||
c, art = clients.clientOrArtificial(net.ParseIP(id), id)
|
||||
ip, _ := netip.ParseAddr(id)
|
||||
c, art = clients.clientOrArtificial(ip, id)
|
||||
if art {
|
||||
artClient = c
|
||||
|
||||
@@ -389,7 +390,7 @@ func (clients *clientsContainer) findMultiple(ids []string) (c *querylog.Client,
|
||||
// records about this client besides maybe whether or not it is blocked. c is
|
||||
// never nil.
|
||||
func (clients *clientsContainer) clientOrArtificial(
|
||||
ip net.IP,
|
||||
ip netip.Addr,
|
||||
id string,
|
||||
) (c *querylog.Client, art bool) {
|
||||
defer func() {
|
||||
@@ -406,13 +407,6 @@ func (clients *clientsContainer) clientOrArtificial(
|
||||
}, false
|
||||
}
|
||||
|
||||
if ip == nil {
|
||||
// Technically should never happen, but still.
|
||||
return &querylog.Client{
|
||||
Name: "",
|
||||
}, true
|
||||
}
|
||||
|
||||
var rc *RuntimeClient
|
||||
rc, ok = clients.findRuntimeClient(ip)
|
||||
if ok {
|
||||
@@ -492,19 +486,20 @@ func (clients *clientsContainer) findLocked(id string) (c *Client, ok bool) {
|
||||
return c, true
|
||||
}
|
||||
|
||||
ip := net.ParseIP(id)
|
||||
if ip == nil {
|
||||
ip, err := netip.ParseAddr(id)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
for _, c = range clients.list {
|
||||
for _, id := range c.IDs {
|
||||
_, ipnet, err := net.ParseCIDR(id)
|
||||
var n netip.Prefix
|
||||
n, err = netip.ParsePrefix(id)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if ipnet.Contains(ip) {
|
||||
if n.Contains(ip) {
|
||||
return c, true
|
||||
}
|
||||
}
|
||||
@@ -514,19 +509,20 @@ func (clients *clientsContainer) findLocked(id string) (c *Client, ok bool) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
macFound := clients.dhcpServer.FindMACbyIP(ip)
|
||||
macFound := clients.dhcpServer.FindMACbyIP(ip.AsSlice())
|
||||
if macFound == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
for _, c = range clients.list {
|
||||
for _, id := range c.IDs {
|
||||
hwAddr, err := net.ParseMAC(id)
|
||||
var mac net.HardwareAddr
|
||||
mac, err = net.ParseMAC(id)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if bytes.Equal(hwAddr, macFound) {
|
||||
if bytes.Equal(mac, macFound) {
|
||||
return c, true
|
||||
}
|
||||
}
|
||||
@@ -535,32 +531,18 @@ func (clients *clientsContainer) findLocked(id string) (c *Client, ok bool) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// findRuntimeClientLocked finds a runtime client by their IP address. For
|
||||
// internal use only.
|
||||
func (clients *clientsContainer) findRuntimeClientLocked(ip net.IP) (rc *RuntimeClient, ok bool) {
|
||||
// TODO(a.garipov): Remove once we switch to netip.Addr more fully.
|
||||
ipAddr, err := netutil.IPToAddrNoMapped(ip)
|
||||
if err != nil {
|
||||
log.Error("clients: bad client ip %v: %s", ip, err)
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
rc, ok = clients.ipToRC[ipAddr]
|
||||
|
||||
return rc, ok
|
||||
}
|
||||
|
||||
// findRuntimeClient finds a runtime client by their IP.
|
||||
func (clients *clientsContainer) findRuntimeClient(ip net.IP) (rc *RuntimeClient, ok bool) {
|
||||
if ip == nil {
|
||||
func (clients *clientsContainer) findRuntimeClient(ip netip.Addr) (rc *RuntimeClient, ok bool) {
|
||||
if ip == (netip.Addr{}) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
return clients.findRuntimeClientLocked(ip)
|
||||
rc, ok = clients.ipToRC[ip]
|
||||
|
||||
return rc, ok
|
||||
}
|
||||
|
||||
// check validates the client.
|
||||
@@ -578,14 +560,16 @@ func (clients *clientsContainer) check(c *Client) (err error) {
|
||||
|
||||
for i, id := range c.IDs {
|
||||
// Normalize structured data.
|
||||
var ip net.IP
|
||||
var ipnet *net.IPNet
|
||||
var mac net.HardwareAddr
|
||||
if ip = net.ParseIP(id); ip != nil {
|
||||
var (
|
||||
ip netip.Addr
|
||||
n netip.Prefix
|
||||
mac net.HardwareAddr
|
||||
)
|
||||
|
||||
if ip, err = netip.ParseAddr(id); err == nil {
|
||||
c.IDs[i] = ip.String()
|
||||
} else if ip, ipnet, err = net.ParseCIDR(id); err == nil {
|
||||
ipnet.IP = ip
|
||||
c.IDs[i] = ipnet.String()
|
||||
} else if n, err = netip.ParsePrefix(id); err == nil {
|
||||
c.IDs[i] = n.String()
|
||||
} else if mac, err = net.ParseMAC(id); err == nil {
|
||||
c.IDs[i] = mac.String()
|
||||
} else if err = dnsforward.ValidateClientID(id); err == nil {
|
||||
@@ -750,7 +734,7 @@ func (clients *clientsContainer) Update(name string, c *Client) (err error) {
|
||||
}
|
||||
|
||||
// setWHOISInfo sets the WHOIS information for a client.
|
||||
func (clients *clientsContainer) setWHOISInfo(ip net.IP, wi *RuntimeClientWHOISInfo) {
|
||||
func (clients *clientsContainer) setWHOISInfo(ip netip.Addr, wi *RuntimeClientWHOISInfo) {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
@@ -760,7 +744,7 @@ func (clients *clientsContainer) setWHOISInfo(ip net.IP, wi *RuntimeClientWHOISI
|
||||
return
|
||||
}
|
||||
|
||||
rc, ok := clients.findRuntimeClientLocked(ip)
|
||||
rc, ok := clients.ipToRC[ip]
|
||||
if ok {
|
||||
rc.WHOISInfo = wi
|
||||
log.Debug("clients: set whois info for runtime client %s: %+v", rc.Host, wi)
|
||||
@@ -776,32 +760,22 @@ func (clients *clientsContainer) setWHOISInfo(ip net.IP, wi *RuntimeClientWHOISI
|
||||
|
||||
rc.WHOISInfo = wi
|
||||
|
||||
// TODO(a.garipov): Remove once we switch to netip.Addr more fully.
|
||||
ipAddr, err := netutil.IPToAddrNoMapped(ip)
|
||||
if err != nil {
|
||||
log.Error("clients: bad client ip %v: %s", ip, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
clients.ipToRC[ipAddr] = rc
|
||||
clients.ipToRC[ip] = rc
|
||||
|
||||
log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi)
|
||||
}
|
||||
|
||||
// AddHost adds a new IP-hostname pairing. The priorities of the sources are
|
||||
// taken into account. ok is true if the pairing was added.
|
||||
func (clients *clientsContainer) AddHost(ip net.IP, host string, src clientSource) (ok bool, err error) {
|
||||
func (clients *clientsContainer) AddHost(
|
||||
ip netip.Addr,
|
||||
host string,
|
||||
src clientSource,
|
||||
) (ok bool) {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
// TODO(a.garipov): Remove once we switch to netip.Addr more fully.
|
||||
ipAddr, err := netutil.IPToAddrNoMapped(ip)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("adding host: %w", err)
|
||||
}
|
||||
|
||||
return clients.addHostLocked(ipAddr, host, src), nil
|
||||
return clients.addHostLocked(ip, host, src)
|
||||
}
|
||||
|
||||
// addHostLocked adds a new IP-hostname pairing. clients.lock is expected to be
|
||||
|
||||
@@ -22,8 +22,18 @@ func TestClients(t *testing.T) {
|
||||
clients.Init(nil, nil, nil, nil)
|
||||
|
||||
t.Run("add_success", func(t *testing.T) {
|
||||
var (
|
||||
cliNone = "1.2.3.4"
|
||||
cli1 = "1.1.1.1"
|
||||
cli2 = "2.2.2.2"
|
||||
|
||||
cliNoneIP = netip.MustParseAddr(cliNone)
|
||||
cli1IP = netip.MustParseAddr(cli1)
|
||||
cli2IP = netip.MustParseAddr(cli2)
|
||||
)
|
||||
|
||||
c := &Client{
|
||||
IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa"},
|
||||
IDs: []string{cli1, "1:2:3::4", "aa:aa:aa:aa:aa:aa"},
|
||||
Name: "client1",
|
||||
}
|
||||
|
||||
@@ -33,7 +43,7 @@ func TestClients(t *testing.T) {
|
||||
assert.True(t, ok)
|
||||
|
||||
c = &Client{
|
||||
IDs: []string{"2.2.2.2"},
|
||||
IDs: []string{cli2},
|
||||
Name: "client2",
|
||||
}
|
||||
|
||||
@@ -42,7 +52,7 @@ func TestClients(t *testing.T) {
|
||||
|
||||
assert.True(t, ok)
|
||||
|
||||
c, ok = clients.Find("1.1.1.1")
|
||||
c, ok = clients.Find(cli1)
|
||||
require.True(t, ok)
|
||||
|
||||
assert.Equal(t, "client1", c.Name)
|
||||
@@ -52,14 +62,14 @@ func TestClients(t *testing.T) {
|
||||
|
||||
assert.Equal(t, "client1", c.Name)
|
||||
|
||||
c, ok = clients.Find("2.2.2.2")
|
||||
c, ok = clients.Find(cli2)
|
||||
require.True(t, ok)
|
||||
|
||||
assert.Equal(t, "client2", c.Name)
|
||||
|
||||
assert.False(t, clients.exists(net.IP{1, 2, 3, 4}, ClientSourceHostsFile))
|
||||
assert.True(t, clients.exists(net.IP{1, 1, 1, 1}, ClientSourceHostsFile))
|
||||
assert.True(t, clients.exists(net.IP{2, 2, 2, 2}, ClientSourceHostsFile))
|
||||
assert.False(t, clients.exists(cliNoneIP, ClientSourceHostsFile))
|
||||
assert.True(t, clients.exists(cli1IP, ClientSourceHostsFile))
|
||||
assert.True(t, clients.exists(cli2IP, ClientSourceHostsFile))
|
||||
})
|
||||
|
||||
t.Run("add_fail_name", func(t *testing.T) {
|
||||
@@ -103,23 +113,31 @@ func TestClients(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("update_success", func(t *testing.T) {
|
||||
var (
|
||||
cliOld = "1.1.1.1"
|
||||
cliNew = "1.1.1.2"
|
||||
|
||||
cliOldIP = netip.MustParseAddr(cliOld)
|
||||
cliNewIP = netip.MustParseAddr(cliNew)
|
||||
)
|
||||
|
||||
err := clients.Update("client1", &Client{
|
||||
IDs: []string{"1.1.1.2"},
|
||||
IDs: []string{cliNew},
|
||||
Name: "client1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.False(t, clients.exists(net.IP{1, 1, 1, 1}, ClientSourceHostsFile))
|
||||
assert.True(t, clients.exists(net.IP{1, 1, 1, 2}, ClientSourceHostsFile))
|
||||
assert.False(t, clients.exists(cliOldIP, ClientSourceHostsFile))
|
||||
assert.True(t, clients.exists(cliNewIP, ClientSourceHostsFile))
|
||||
|
||||
err = clients.Update("client1", &Client{
|
||||
IDs: []string{"1.1.1.2"},
|
||||
IDs: []string{cliNew},
|
||||
Name: "client1-renamed",
|
||||
UseOwnSettings: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
c, ok := clients.Find("1.1.1.2")
|
||||
c, ok := clients.Find(cliNew)
|
||||
require.True(t, ok)
|
||||
|
||||
assert.Equal(t, "client1-renamed", c.Name)
|
||||
@@ -132,14 +150,14 @@ func TestClients(t *testing.T) {
|
||||
|
||||
require.Len(t, c.IDs, 1)
|
||||
|
||||
assert.Equal(t, "1.1.1.2", c.IDs[0])
|
||||
assert.Equal(t, cliNew, c.IDs[0])
|
||||
})
|
||||
|
||||
t.Run("del_success", func(t *testing.T) {
|
||||
ok := clients.Del("client1-renamed")
|
||||
require.True(t, ok)
|
||||
|
||||
assert.False(t, clients.exists(net.IP{1, 1, 1, 2}, ClientSourceHostsFile))
|
||||
assert.False(t, clients.exists(netip.MustParseAddr("1.1.1.2"), ClientSourceHostsFile))
|
||||
})
|
||||
|
||||
t.Run("del_fail", func(t *testing.T) {
|
||||
@@ -148,45 +166,33 @@ func TestClients(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("addhost_success", func(t *testing.T) {
|
||||
ip := net.IP{1, 1, 1, 1}
|
||||
|
||||
ok, err := clients.AddHost(ip, "host", ClientSourceARP)
|
||||
require.NoError(t, err)
|
||||
|
||||
ip := netip.MustParseAddr("1.1.1.1")
|
||||
ok := clients.AddHost(ip, "host", ClientSourceARP)
|
||||
assert.True(t, ok)
|
||||
|
||||
ok, err = clients.AddHost(ip, "host2", ClientSourceARP)
|
||||
require.NoError(t, err)
|
||||
|
||||
ok = clients.AddHost(ip, "host2", ClientSourceARP)
|
||||
assert.True(t, ok)
|
||||
|
||||
ok, err = clients.AddHost(ip, "host3", ClientSourceHostsFile)
|
||||
require.NoError(t, err)
|
||||
|
||||
ok = clients.AddHost(ip, "host3", ClientSourceHostsFile)
|
||||
assert.True(t, ok)
|
||||
|
||||
assert.True(t, clients.exists(ip, ClientSourceHostsFile))
|
||||
})
|
||||
|
||||
t.Run("dhcp_replaces_arp", func(t *testing.T) {
|
||||
ip := net.IP{1, 2, 3, 4}
|
||||
|
||||
ok, err := clients.AddHost(ip, "from_arp", ClientSourceARP)
|
||||
require.NoError(t, err)
|
||||
|
||||
ip := netip.MustParseAddr("1.2.3.4")
|
||||
ok := clients.AddHost(ip, "from_arp", ClientSourceARP)
|
||||
assert.True(t, ok)
|
||||
assert.True(t, clients.exists(ip, ClientSourceARP))
|
||||
|
||||
ok, err = clients.AddHost(ip, "from_dhcp", ClientSourceDHCP)
|
||||
require.NoError(t, err)
|
||||
|
||||
ok = clients.AddHost(ip, "from_dhcp", ClientSourceDHCP)
|
||||
assert.True(t, ok)
|
||||
assert.True(t, clients.exists(ip, ClientSourceDHCP))
|
||||
})
|
||||
|
||||
t.Run("addhost_fail", func(t *testing.T) {
|
||||
ok, err := clients.AddHost(net.IP{1, 1, 1, 1}, "host1", ClientSourceRDNS)
|
||||
require.NoError(t, err)
|
||||
ip := netip.MustParseAddr("1.1.1.1")
|
||||
ok := clients.AddHost(ip, "host1", ClientSourceRDNS)
|
||||
assert.False(t, ok)
|
||||
})
|
||||
}
|
||||
@@ -203,7 +209,7 @@ func TestClientsWHOIS(t *testing.T) {
|
||||
|
||||
t.Run("new_client", func(t *testing.T) {
|
||||
ip := netip.MustParseAddr("1.1.1.255")
|
||||
clients.setWHOISInfo(ip.AsSlice(), whois)
|
||||
clients.setWHOISInfo(ip, whois)
|
||||
rc := clients.ipToRC[ip]
|
||||
require.NotNil(t, rc)
|
||||
|
||||
@@ -212,12 +218,10 @@ func TestClientsWHOIS(t *testing.T) {
|
||||
|
||||
t.Run("existing_auto-client", func(t *testing.T) {
|
||||
ip := netip.MustParseAddr("1.1.1.1")
|
||||
ok, err := clients.AddHost(ip.AsSlice(), "host", ClientSourceRDNS)
|
||||
require.NoError(t, err)
|
||||
|
||||
ok := clients.AddHost(ip, "host", ClientSourceRDNS)
|
||||
assert.True(t, ok)
|
||||
|
||||
clients.setWHOISInfo(ip.AsSlice(), whois)
|
||||
clients.setWHOISInfo(ip, whois)
|
||||
rc := clients.ipToRC[ip]
|
||||
require.NotNil(t, rc)
|
||||
|
||||
@@ -234,7 +238,7 @@ func TestClientsWHOIS(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
clients.setWHOISInfo(ip.AsSlice(), whois)
|
||||
clients.setWHOISInfo(ip, whois)
|
||||
rc := clients.ipToRC[ip]
|
||||
require.Nil(t, rc)
|
||||
|
||||
@@ -249,7 +253,7 @@ func TestClientsAddExisting(t *testing.T) {
|
||||
clients.Init(nil, nil, nil, nil)
|
||||
|
||||
t.Run("simple", func(t *testing.T) {
|
||||
ip := net.IP{1, 1, 1, 1}
|
||||
ip := netip.MustParseAddr("1.1.1.1")
|
||||
|
||||
// Add a client.
|
||||
ok, err := clients.Add(&Client{
|
||||
@@ -260,8 +264,7 @@ func TestClientsAddExisting(t *testing.T) {
|
||||
assert.True(t, ok)
|
||||
|
||||
// Now add an auto-client with the same IP.
|
||||
ok, err = clients.AddHost(ip, "test", ClientSourceRDNS)
|
||||
require.NoError(t, err)
|
||||
ok = clients.AddHost(ip, "test", ClientSourceRDNS)
|
||||
assert.True(t, ok)
|
||||
})
|
||||
|
||||
|
||||
@@ -3,8 +3,8 @@ package home
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
)
|
||||
@@ -47,8 +47,8 @@ type runtimeClientJSON struct {
|
||||
WHOISInfo *RuntimeClientWHOISInfo `json:"whois_info"`
|
||||
|
||||
Name string `json:"name"`
|
||||
IP netip.Addr `json:"ip"`
|
||||
Source clientSource `json:"source"`
|
||||
IP net.IP `json:"ip"`
|
||||
}
|
||||
|
||||
type clientListJSON struct {
|
||||
@@ -75,7 +75,7 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
|
||||
|
||||
Name: rc.Host,
|
||||
Source: rc.Source,
|
||||
IP: ip.AsSlice(),
|
||||
IP: ip,
|
||||
}
|
||||
|
||||
data.RuntimeClients = append(data.RuntimeClients, cj)
|
||||
@@ -218,7 +218,7 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
|
||||
break
|
||||
}
|
||||
|
||||
ip := net.ParseIP(idStr)
|
||||
ip, _ := netip.ParseAddr(idStr)
|
||||
c, ok := clients.Find(idStr)
|
||||
var cj *clientJSON
|
||||
if !ok {
|
||||
@@ -240,7 +240,7 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
|
||||
// findRuntime looks up the IP in runtime and temporary storages, like
|
||||
// /etc/hosts tables, DHCP leases, or blocklists. cj is guaranteed to be
|
||||
// non-nil.
|
||||
func (clients *clientsContainer) findRuntime(ip net.IP, idStr string) (cj *clientJSON) {
|
||||
func (clients *clientsContainer) findRuntime(ip netip.Addr, idStr string) (cj *clientJSON) {
|
||||
rc, ok := clients.findRuntimeClient(ip)
|
||||
if !ok {
|
||||
// It is still possible that the IP used to be in the runtime clients
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"github.com/google/renameio/maybe"
|
||||
"golang.org/x/exp/slices"
|
||||
yaml "gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
@@ -113,8 +114,8 @@ type configuration struct {
|
||||
// An active session is automatically refreshed once a day.
|
||||
WebSessionTTLHours uint32 `yaml:"web_session_ttl"`
|
||||
|
||||
DNS dnsConfig `yaml:"dns"`
|
||||
TLS tlsConfigSettings `yaml:"tls"`
|
||||
DNS dnsConfig `yaml:"dns"`
|
||||
TLS tlsConfiguration `yaml:"tls"`
|
||||
|
||||
// Filters reflects the filters from [filtering.Config]. It's cloned to the
|
||||
// config used in the filtering module at the startup. Afterwards it's
|
||||
@@ -199,7 +200,8 @@ type dnsConfig struct {
|
||||
UseHTTP3Upstreams bool `yaml:"use_http3_upstreams"`
|
||||
}
|
||||
|
||||
type tlsConfigSettings struct {
|
||||
// tlsConfiguration is the on-disk TLS configuration.
|
||||
type tlsConfiguration struct {
|
||||
Enabled bool `yaml:"enabled" json:"enabled"` // Enabled is the encryption (DoT/DoH/HTTPS) status
|
||||
ServerName string `yaml:"server_name" json:"server_name,omitempty"` // ServerName is the hostname of your HTTPS/TLS server
|
||||
ForceHTTPS bool `yaml:"force_https" json:"force_https"` // ForceHTTPS: if true, forces HTTP->HTTPS redirect
|
||||
@@ -223,6 +225,29 @@ type tlsConfigSettings struct {
|
||||
dnsforward.TLSConfig `yaml:",inline" json:",inline"`
|
||||
}
|
||||
|
||||
// cloneForEncoding returns a clone of c with all top-level fields of c and all
|
||||
// exported and YAML-encoded fields of c.TLSConfig cloned.
|
||||
//
|
||||
// TODO(a.garipov): This is better than races, but still not good enough.
|
||||
func (c *tlsConfiguration) cloneForEncoding() (cloned *tlsConfiguration) {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
v := *c
|
||||
cloned = &v
|
||||
cloned.TLSConfig = dnsforward.TLSConfig{
|
||||
CertificateChain: c.CertificateChain,
|
||||
PrivateKey: c.PrivateKey,
|
||||
CertificatePath: c.CertificatePath,
|
||||
PrivateKeyPath: c.PrivateKeyPath,
|
||||
OverrideTLSCiphers: slices.Clone(c.OverrideTLSCiphers),
|
||||
StrictSNICheck: c.StrictSNICheck,
|
||||
}
|
||||
|
||||
return cloned
|
||||
}
|
||||
|
||||
// config is the global configuration structure.
|
||||
//
|
||||
// TODO(a.garipov, e.burkov): This global is awful and must be removed.
|
||||
@@ -273,7 +298,7 @@ var config = &configuration{
|
||||
UpstreamTimeout: timeutil.Duration{Duration: dnsforward.DefaultTimeout},
|
||||
UsePrivateRDNS: true,
|
||||
},
|
||||
TLS: tlsConfigSettings{
|
||||
TLS: tlsConfiguration{
|
||||
PortHTTPS: defaultPortHTTPS,
|
||||
PortDNSOverTLS: defaultPortTLS, // needs to be passed through to dnsproxy
|
||||
PortDNSOverQUIC: defaultPortQUIC,
|
||||
@@ -442,7 +467,7 @@ func (c *configuration) write() (err error) {
|
||||
}
|
||||
|
||||
if Context.tls != nil {
|
||||
tlsConf := tlsConfigSettings{}
|
||||
tlsConf := tlsConfiguration{}
|
||||
Context.tls.WriteDiskConfig(&tlsConf)
|
||||
config.TLS = tlsConf
|
||||
}
|
||||
|
||||
@@ -71,9 +71,7 @@ func appendDNSAddrsWithIfaces(dst []string, src []netip.Addr) (res []string, err
|
||||
// on, including the addresses on all interfaces in cases of unspecified IPs.
|
||||
func collectDNSAddresses() (addrs []string, err error) {
|
||||
if hosts := config.DNS.BindHosts; len(hosts) == 0 {
|
||||
addr := aghnet.IPv4Localhost()
|
||||
|
||||
addrs = appendDNSAddrs(addrs, addr)
|
||||
addrs = appendDNSAddrs(addrs, netutil.IPv4Localhost())
|
||||
} else {
|
||||
addrs, err = appendDNSAddrsWithIfaces(addrs, hosts)
|
||||
if err != nil {
|
||||
|
||||
@@ -154,7 +154,7 @@ func (vr *versionResponse) setAllowedToAutoUpdate() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
tlsConf := &tlsConfigSettings{}
|
||||
tlsConf := &tlsConfiguration{}
|
||||
Context.tls.WriteDiskConfig(tlsConf)
|
||||
|
||||
canUpdate := true
|
||||
@@ -172,7 +172,7 @@ func (vr *versionResponse) setAllowedToAutoUpdate() (err error) {
|
||||
|
||||
// tlsConfUsesPrivilegedPorts returns true if the provided TLS configuration
|
||||
// indicates that privileged ports are used.
|
||||
func tlsConfUsesPrivilegedPorts(c *tlsConfigSettings) (ok bool) {
|
||||
func tlsConfUsesPrivilegedPorts(c *tlsConfiguration) (ok bool) {
|
||||
return c.Enabled && (c.PortHTTPS < 1024 || c.PortDNSOverTLS < 1024 || c.PortDNSOverQUIC < 1024)
|
||||
}
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
@@ -150,8 +151,8 @@ func isRunning() bool {
|
||||
}
|
||||
|
||||
func onDNSRequest(pctx *proxy.DNSContext) {
|
||||
ip, _ := netutil.IPAndPortFromAddr(pctx.Addr)
|
||||
if ip == nil {
|
||||
ip := netutil.NetAddrToAddrPort(pctx.Addr).Addr()
|
||||
if ip == (netip.Addr{}) {
|
||||
// This would be quite weird if we get here.
|
||||
return
|
||||
}
|
||||
@@ -160,7 +161,8 @@ func onDNSRequest(pctx *proxy.DNSContext) {
|
||||
if srcs.RDNS && !ip.IsLoopback() {
|
||||
Context.rdns.Begin(ip)
|
||||
}
|
||||
if srcs.WHOIS && !netutil.IsSpecialPurpose(ip) {
|
||||
|
||||
if srcs.WHOIS && !netutil.IsSpecialPurposeAddr(ip) {
|
||||
Context.whois.Begin(ip)
|
||||
}
|
||||
}
|
||||
@@ -193,11 +195,7 @@ func ipsToUDPAddrs(ips []netip.Addr, port int) (udpAddrs []*net.UDPAddr) {
|
||||
|
||||
func generateServerConfig() (newConf dnsforward.ServerConfig, err error) {
|
||||
dnsConf := config.DNS
|
||||
hosts := dnsConf.BindHosts
|
||||
if len(hosts) == 0 {
|
||||
hosts = []netip.Addr{aghnet.IPv4Localhost()}
|
||||
}
|
||||
|
||||
hosts := aghalg.CoalesceSlice(dnsConf.BindHosts, []netip.Addr{netutil.IPv4Localhost()})
|
||||
newConf = dnsforward.ServerConfig{
|
||||
UDPListenAddrs: ipsToUDPAddrs(hosts, dnsConf.Port),
|
||||
TCPListenAddrs: ipsToTCPAddrs(hosts, dnsConf.Port),
|
||||
@@ -207,7 +205,7 @@ func generateServerConfig() (newConf dnsforward.ServerConfig, err error) {
|
||||
OnDNSRequest: onDNSRequest,
|
||||
}
|
||||
|
||||
tlsConf := tlsConfigSettings{}
|
||||
tlsConf := tlsConfiguration{}
|
||||
Context.tls.WriteDiskConfig(&tlsConf)
|
||||
if tlsConf.Enabled {
|
||||
newConf.TLSConfig = tlsConf.TLSConfig
|
||||
@@ -252,7 +250,7 @@ func generateServerConfig() (newConf dnsforward.ServerConfig, err error) {
|
||||
return newConf, nil
|
||||
}
|
||||
|
||||
func newDNSCrypt(hosts []netip.Addr, tlsConf tlsConfigSettings) (dnscc dnsforward.DNSCryptConfig, err error) {
|
||||
func newDNSCrypt(hosts []netip.Addr, tlsConf tlsConfiguration) (dnscc dnsforward.DNSCryptConfig, err error) {
|
||||
if tlsConf.DNSCryptConfigFile == "" {
|
||||
return dnscc, errors.Error("no dnscrypt_config_file")
|
||||
}
|
||||
@@ -290,7 +288,7 @@ type dnsEncryption struct {
|
||||
}
|
||||
|
||||
func getDNSEncryption() (de dnsEncryption) {
|
||||
tlsConf := tlsConfigSettings{}
|
||||
tlsConf := tlsConfiguration{}
|
||||
|
||||
Context.tls.WriteDiskConfig(&tlsConf)
|
||||
|
||||
@@ -400,15 +398,12 @@ func startDNSServer() error {
|
||||
|
||||
const topClientsNumber = 100 // the number of clients to get
|
||||
for _, ip := range Context.stats.TopClientsIP(topClientsNumber) {
|
||||
if ip == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
srcs := config.Clients.Sources
|
||||
if srcs.RDNS && !ip.IsLoopback() {
|
||||
Context.rdns.Begin(ip)
|
||||
}
|
||||
if srcs.WHOIS && !netutil.IsSpecialPurpose(ip) {
|
||||
|
||||
if srcs.WHOIS && !netutil.IsSpecialPurposeAddr(ip) {
|
||||
Context.whois.Begin(ip)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -512,7 +512,7 @@ func run(opts options, clientBuildFS fs.FS) {
|
||||
}
|
||||
config.Users = nil
|
||||
|
||||
Context.tls, err = newTLSManager(config.TLS)
|
||||
Context.tls, err = newTLSManager(&config.TLS)
|
||||
if err != nil {
|
||||
log.Fatalf("initializing tls: %s", err)
|
||||
}
|
||||
@@ -542,6 +542,11 @@ func run(opts options, clientBuildFS fs.FS) {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(a.garipov): This could be made much earlier and could be done on
|
||||
// the first run as well, but to achieve this we need to bypass requests
|
||||
// over dnsforward resolver.
|
||||
cmdlineUpdate(opts)
|
||||
|
||||
Context.web.Start()
|
||||
|
||||
// wait indefinitely for other go-routines to complete their job
|
||||
@@ -576,7 +581,7 @@ func checkPermissions() {
|
||||
}
|
||||
|
||||
// We should check if AdGuard Home is able to bind to port 53
|
||||
err := aghnet.CheckPort("tcp", netip.AddrPortFrom(aghnet.IPv4Localhost(), defaultPortDNS))
|
||||
err := aghnet.CheckPort("tcp", netip.AddrPortFrom(netutil.IPv4Localhost(), defaultPortDNS))
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrPermission) {
|
||||
log.Fatal(`Permission check failed.
|
||||
@@ -812,7 +817,7 @@ func printWebAddrs(proto, addr string, port, betaPort int) {
|
||||
// printHTTPAddresses prints the IP addresses which user can use to access the
|
||||
// admin interface. proto is either schemeHTTP or schemeHTTPS.
|
||||
func printHTTPAddresses(proto string) {
|
||||
tlsConf := tlsConfigSettings{}
|
||||
tlsConf := tlsConfiguration{}
|
||||
if Context.tls != nil {
|
||||
Context.tls.WriteDiskConfig(&tlsConf)
|
||||
}
|
||||
@@ -927,3 +932,37 @@ type jsonError struct {
|
||||
// Message is the error message, an opaque string.
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// cmdlineUpdate updates current application and exits.
|
||||
func cmdlineUpdate(opts options) {
|
||||
if !opts.performUpdate {
|
||||
return
|
||||
}
|
||||
|
||||
log.Info("starting update")
|
||||
|
||||
if Context.firstRun {
|
||||
log.Info("update not allowed on first run")
|
||||
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
_, err := Context.updater.VersionInfo(true)
|
||||
if err != nil {
|
||||
vcu := Context.updater.VersionCheckURL()
|
||||
log.Error("getting version info from %s: %s", vcu, err)
|
||||
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
if Context.updater.NewVersion() == "" {
|
||||
log.Info("no updates available")
|
||||
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
err = Context.updater.Update()
|
||||
fatalOnError(err)
|
||||
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
@@ -32,7 +32,11 @@ func setupDNSIPs(t testing.TB) {
|
||||
},
|
||||
}
|
||||
|
||||
Context.tls = &tlsManager{}
|
||||
var err error
|
||||
Context.tls, err = newTLSManager(&tlsConfiguration{
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestHandleMobileConfigDoH(t *testing.T) {
|
||||
@@ -65,7 +69,11 @@ func TestHandleMobileConfigDoH(t *testing.T) {
|
||||
oldTLSConf := Context.tls
|
||||
t.Cleanup(func() { Context.tls = oldTLSConf })
|
||||
|
||||
Context.tls = &tlsManager{conf: tlsConfigSettings{}}
|
||||
var err error
|
||||
Context.tls, err = newTLSManager(&tlsConfiguration{
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/doh.mobileconfig", nil)
|
||||
require.NoError(t, err)
|
||||
@@ -137,7 +145,11 @@ func TestHandleMobileConfigDoT(t *testing.T) {
|
||||
oldTLSConf := Context.tls
|
||||
t.Cleanup(func() { Context.tls = oldTLSConf })
|
||||
|
||||
Context.tls = &tlsManager{conf: tlsConfigSettings{}}
|
||||
var err error
|
||||
Context.tls, err = newTLSManager(&tlsConfiguration{
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/dot.mobileconfig", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -47,6 +47,9 @@ type options struct {
|
||||
// disableUpdate, if set, makes AdGuard Home not check for updates.
|
||||
disableUpdate bool
|
||||
|
||||
// performUpdate, if set, updates AdGuard Home without GUI and exits.
|
||||
performUpdate bool
|
||||
|
||||
// verbose shows if verbose logging is enabled.
|
||||
verbose bool
|
||||
|
||||
@@ -221,6 +224,14 @@ var cmdLineOpts = []cmdLineOpt{{
|
||||
description: "Don't check for updates.",
|
||||
longName: "no-check-update",
|
||||
shortName: "",
|
||||
}, {
|
||||
updateWithValue: nil,
|
||||
updateNoValue: func(o options) (options, error) { o.performUpdate = true; return o, nil },
|
||||
effect: nil,
|
||||
serialize: func(o options) (val string, ok bool) { return "", o.performUpdate },
|
||||
description: "Update application and exit.",
|
||||
longName: "update",
|
||||
shortName: "",
|
||||
}, {
|
||||
updateWithValue: nil,
|
||||
updateNoValue: nil,
|
||||
|
||||
@@ -103,6 +103,11 @@ func TestParseDisableUpdate(t *testing.T) {
|
||||
assert.True(t, testParseOK(t, "--no-check-update").disableUpdate, "--no-check-update is disable update")
|
||||
}
|
||||
|
||||
func TestParsePerformUpdate(t *testing.T) {
|
||||
assert.False(t, testParseOK(t).performUpdate, "empty is not perform update")
|
||||
assert.True(t, testParseOK(t, "--update").performUpdate, "--update is perform update")
|
||||
}
|
||||
|
||||
// TODO(e.burkov): Remove after v0.108.0.
|
||||
func TestParseDisableMemoryOptimization(t *testing.T) {
|
||||
o, eff, err := parseCmdOpts("", []string{"--no-mem-optimization"})
|
||||
@@ -169,6 +174,10 @@ func TestOptsToArgs(t *testing.T) {
|
||||
name: "disable_update",
|
||||
args: []string{"--no-check-update"},
|
||||
opts: options{disableUpdate: true},
|
||||
}, {
|
||||
name: "perform_update",
|
||||
args: []string{"--update"},
|
||||
opts: options{performUpdate: true},
|
||||
}, {
|
||||
name: "control_action",
|
||||
args: []string{"-s", "run"},
|
||||
|
||||
@@ -2,7 +2,7 @@ package home
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
@@ -21,7 +21,7 @@ type RDNS struct {
|
||||
usePrivate uint32
|
||||
|
||||
// ipCh used to pass client's IP to rDNS workerLoop.
|
||||
ipCh chan net.IP
|
||||
ipCh chan netip.Addr
|
||||
|
||||
// ipCache caches the IP addresses to be resolved by rDNS. The resolved
|
||||
// address stays here while it's inside clients. After leaving clients the
|
||||
@@ -50,7 +50,7 @@ func NewRDNS(
|
||||
EnableLRU: true,
|
||||
MaxCount: defaultRDNSCacheSize,
|
||||
}),
|
||||
ipCh: make(chan net.IP, defaultRDNSIPChSize),
|
||||
ipCh: make(chan netip.Addr, defaultRDNSIPChSize),
|
||||
}
|
||||
if usePrivate {
|
||||
rDNS.usePrivate = 1
|
||||
@@ -80,9 +80,10 @@ func (r *RDNS) ensurePrivateCache() {
|
||||
|
||||
// isCached returns true if ip is already cached and not expired yet. It also
|
||||
// caches it otherwise.
|
||||
func (r *RDNS) isCached(ip net.IP) (ok bool) {
|
||||
func (r *RDNS) isCached(ip netip.Addr) (ok bool) {
|
||||
ipBytes := ip.AsSlice()
|
||||
now := uint64(time.Now().Unix())
|
||||
if expire := r.ipCache.Get(ip); len(expire) != 0 {
|
||||
if expire := r.ipCache.Get(ipBytes); len(expire) != 0 {
|
||||
if binary.BigEndian.Uint64(expire) > now {
|
||||
return true
|
||||
}
|
||||
@@ -91,14 +92,14 @@ func (r *RDNS) isCached(ip net.IP) (ok bool) {
|
||||
// The cache entry either expired or doesn't exist.
|
||||
ttl := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(ttl, now+defaultRDNSCacheTTL)
|
||||
r.ipCache.Set(ip, ttl)
|
||||
r.ipCache.Set(ipBytes, ttl)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Begin adds the ip to the resolving queue if it is not cached or already
|
||||
// resolved.
|
||||
func (r *RDNS) Begin(ip net.IP) {
|
||||
func (r *RDNS) Begin(ip netip.Addr) {
|
||||
r.ensurePrivateCache()
|
||||
|
||||
if r.isCached(ip) || r.clients.exists(ip, ClientSourceRDNS) {
|
||||
@@ -107,9 +108,9 @@ func (r *RDNS) Begin(ip net.IP) {
|
||||
|
||||
select {
|
||||
case r.ipCh <- ip:
|
||||
log.Tracef("rdns: %q added to queue", ip)
|
||||
log.Debug("rdns: %q added to queue", ip)
|
||||
default:
|
||||
log.Tracef("rdns: queue is full")
|
||||
log.Debug("rdns: queue is full")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -119,7 +120,7 @@ func (r *RDNS) workerLoop() {
|
||||
defer log.OnPanic("rdns")
|
||||
|
||||
for ip := range r.ipCh {
|
||||
host, err := r.exchanger.Exchange(ip)
|
||||
host, err := r.exchanger.Exchange(ip.AsSlice())
|
||||
if err != nil {
|
||||
log.Debug("rdns: resolving %q: %s", ip, err)
|
||||
|
||||
@@ -128,8 +129,6 @@ func (r *RDNS) workerLoop() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Don't handle any errors since AddHost doesn't return non-nil errors
|
||||
// for now.
|
||||
_, _ = r.clients.AddHost(ip, host, ClientSourceRDNS)
|
||||
_ = r.clients.AddHost(ip, host, ClientSourceRDNS)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,14 +27,14 @@ func TestRDNS_Begin(t *testing.T) {
|
||||
w := &bytes.Buffer{}
|
||||
aghtest.ReplaceLogWriter(t, w)
|
||||
|
||||
ip1234, ip1235 := net.IP{1, 2, 3, 4}, net.IP{1, 2, 3, 5}
|
||||
ip1234, ip1235 := netip.MustParseAddr("1.2.3.4"), netip.MustParseAddr("1.2.3.5")
|
||||
|
||||
testCases := []struct {
|
||||
cliIDIndex map[string]*Client
|
||||
customChan chan net.IP
|
||||
customChan chan netip.Addr
|
||||
name string
|
||||
wantLog string
|
||||
req net.IP
|
||||
ip netip.Addr
|
||||
wantCacheHit int
|
||||
wantCacheMiss int
|
||||
}{{
|
||||
@@ -42,7 +42,7 @@ func TestRDNS_Begin(t *testing.T) {
|
||||
customChan: nil,
|
||||
name: "cached",
|
||||
wantLog: "",
|
||||
req: ip1234,
|
||||
ip: ip1234,
|
||||
wantCacheHit: 1,
|
||||
wantCacheMiss: 0,
|
||||
}, {
|
||||
@@ -50,7 +50,7 @@ func TestRDNS_Begin(t *testing.T) {
|
||||
customChan: nil,
|
||||
name: "not_cached",
|
||||
wantLog: "rdns: queue is full",
|
||||
req: ip1235,
|
||||
ip: ip1235,
|
||||
wantCacheHit: 0,
|
||||
wantCacheMiss: 1,
|
||||
}, {
|
||||
@@ -58,15 +58,15 @@ func TestRDNS_Begin(t *testing.T) {
|
||||
customChan: nil,
|
||||
name: "already_in_clients",
|
||||
wantLog: "",
|
||||
req: ip1235,
|
||||
ip: ip1235,
|
||||
wantCacheHit: 0,
|
||||
wantCacheMiss: 1,
|
||||
}, {
|
||||
cliIDIndex: map[string]*Client{},
|
||||
customChan: make(chan net.IP, 1),
|
||||
customChan: make(chan netip.Addr, 1),
|
||||
name: "add_to_queue",
|
||||
wantLog: `rdns: "1.2.3.5" added to queue`,
|
||||
req: ip1235,
|
||||
ip: ip1235,
|
||||
wantCacheHit: 0,
|
||||
wantCacheMiss: 1,
|
||||
}}
|
||||
@@ -102,7 +102,7 @@ func TestRDNS_Begin(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
rdns.Begin(tc.req)
|
||||
rdns.Begin(tc.ip)
|
||||
assert.Equal(t, tc.wantCacheHit, ipCache.Stats().Hit)
|
||||
assert.Equal(t, tc.wantCacheMiss, ipCache.Stats().Miss)
|
||||
assert.Contains(t, w.String(), tc.wantLog)
|
||||
@@ -179,8 +179,8 @@ func TestRDNS_WorkerLoop(t *testing.T) {
|
||||
w := &bytes.Buffer{}
|
||||
aghtest.ReplaceLogWriter(t, w)
|
||||
|
||||
localIP := net.IP{192, 168, 1, 1}
|
||||
revIPv4, err := netutil.IPToReversedAddr(localIP)
|
||||
localIP := netip.MustParseAddr("192.168.1.1")
|
||||
revIPv4, err := netutil.IPToReversedAddr(localIP.AsSlice())
|
||||
require.NoError(t, err)
|
||||
|
||||
revIPv6, err := netutil.IPToReversedAddr(net.ParseIP("2a00:1450:400c:c06::93"))
|
||||
@@ -201,24 +201,24 @@ func TestRDNS_WorkerLoop(t *testing.T) {
|
||||
|
||||
testCases := []struct {
|
||||
ups upstream.Upstream
|
||||
cliIP netip.Addr
|
||||
wantLog string
|
||||
name string
|
||||
cliIP net.IP
|
||||
}{{
|
||||
ups: locUpstream,
|
||||
cliIP: localIP,
|
||||
wantLog: "",
|
||||
name: "all_good",
|
||||
cliIP: localIP,
|
||||
}, {
|
||||
ups: errUpstream,
|
||||
cliIP: netip.MustParseAddr("192.168.1.2"),
|
||||
wantLog: `rdns: resolving "192.168.1.2": test upstream error`,
|
||||
name: "resolve_error",
|
||||
cliIP: net.IP{192, 168, 1, 2},
|
||||
}, {
|
||||
ups: locUpstream,
|
||||
cliIP: netip.MustParseAddr("2a00:1450:400c:c06::93"),
|
||||
wantLog: "",
|
||||
name: "ipv6_good",
|
||||
cliIP: net.ParseIP("2a00:1450:400c:c06::93"),
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
@@ -230,7 +230,7 @@ func TestRDNS_WorkerLoop(t *testing.T) {
|
||||
ipToRC: map[netip.Addr]*RuntimeClient{},
|
||||
allTags: stringutil.NewSet(),
|
||||
}
|
||||
ch := make(chan net.IP)
|
||||
ch := make(chan netip.Addr)
|
||||
rdns := &RDNS{
|
||||
exchanger: &rDNSExchanger{
|
||||
ex: tc.ups,
|
||||
|
||||
@@ -8,42 +8,39 @@ import (
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
// tlsManager contains the current configuration and state of AdGuard Home TLS
|
||||
// encryption.
|
||||
type tlsManager struct {
|
||||
// status is the current status of the configuration. It is never nil.
|
||||
status *tlsConfigStatus
|
||||
// mu protects all fields.
|
||||
mu *sync.RWMutex
|
||||
|
||||
// certLastMod is the last modification time of the certificate file.
|
||||
certLastMod time.Time
|
||||
|
||||
confLock sync.Mutex
|
||||
conf tlsConfigSettings
|
||||
// status is the current status of the configuration. It is never nil.
|
||||
status *tlsConfigStatus
|
||||
|
||||
// conf is the current TLS configuration.
|
||||
conf *tlsConfiguration
|
||||
}
|
||||
|
||||
// newTLSManager initializes the TLS configuration.
|
||||
func newTLSManager(conf tlsConfigSettings) (m *tlsManager, err error) {
|
||||
func newTLSManager(conf *tlsConfiguration) (m *tlsManager, err error) {
|
||||
m = &tlsManager{
|
||||
status: &tlsConfigStatus{},
|
||||
mu: &sync.RWMutex{},
|
||||
conf: conf,
|
||||
}
|
||||
|
||||
@@ -59,9 +56,19 @@ func newTLSManager(conf tlsConfigSettings) (m *tlsManager, err error) {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// confForEncoding returns a partial clone of the current TLS configuration. It
|
||||
// is safe for concurrent use.
|
||||
func (m *tlsManager) confForEncoding() (conf *tlsConfiguration) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
return m.conf.cloneForEncoding()
|
||||
}
|
||||
|
||||
// load reloads the TLS configuration from files or data from the config file.
|
||||
// m.mu is expected to be locked for writing.
|
||||
func (m *tlsManager) load() (err error) {
|
||||
err = loadTLSConf(&m.conf, m.status)
|
||||
err = loadTLSConf(m.conf, m.status)
|
||||
if err != nil {
|
||||
return fmt.Errorf("loading config: %w", err)
|
||||
}
|
||||
@@ -70,14 +77,12 @@ func (m *tlsManager) load() (err error) {
|
||||
}
|
||||
|
||||
// WriteDiskConfig - write config
|
||||
func (m *tlsManager) WriteDiskConfig(conf *tlsConfigSettings) {
|
||||
m.confLock.Lock()
|
||||
*conf = m.conf
|
||||
m.confLock.Unlock()
|
||||
func (m *tlsManager) WriteDiskConfig(conf *tlsConfiguration) {
|
||||
*conf = *m.confForEncoding()
|
||||
}
|
||||
|
||||
// setCertFileTime sets t.certLastMod from the certificate. If there are
|
||||
// errors, setCertFileTime logs them.
|
||||
// errors, setCertFileTime logs them. mu is expected to be locked for writing.
|
||||
func (m *tlsManager) setCertFileTime() {
|
||||
if len(m.conf.CertificatePath) == 0 {
|
||||
return
|
||||
@@ -97,27 +102,22 @@ func (m *tlsManager) setCertFileTime() {
|
||||
func (m *tlsManager) start() {
|
||||
m.registerWebHandlers()
|
||||
|
||||
m.confLock.Lock()
|
||||
tlsConf := m.conf
|
||||
m.confLock.Unlock()
|
||||
|
||||
// The background context is used because the TLSConfigChanged wraps context
|
||||
// with timeout on its own and shuts down the server, which handles current
|
||||
// request.
|
||||
Context.web.TLSConfigChanged(context.Background(), tlsConf)
|
||||
Context.web.TLSConfigChanged(context.Background(), m.confForEncoding())
|
||||
}
|
||||
|
||||
// reload updates the configuration and restarts t.
|
||||
// reload updates the configuration and restarts m.
|
||||
func (m *tlsManager) reload() {
|
||||
m.confLock.Lock()
|
||||
tlsConf := m.conf
|
||||
m.confLock.Unlock()
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if !tlsConf.Enabled || len(tlsConf.CertificatePath) == 0 {
|
||||
if !m.conf.Enabled || len(m.conf.CertificatePath) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
fi, err := os.Stat(tlsConf.CertificatePath)
|
||||
fi, err := os.Stat(m.conf.CertificatePath)
|
||||
if err != nil {
|
||||
log.Error("tls: %s", err)
|
||||
|
||||
@@ -132,9 +132,7 @@ func (m *tlsManager) reload() {
|
||||
|
||||
log.Debug("tls: certificate file is modified")
|
||||
|
||||
m.confLock.Lock()
|
||||
err = m.load()
|
||||
m.confLock.Unlock()
|
||||
if err != nil {
|
||||
log.Error("tls: reloading: %s", err)
|
||||
|
||||
@@ -145,19 +143,15 @@ func (m *tlsManager) reload() {
|
||||
|
||||
_ = reconfigureDNSServer()
|
||||
|
||||
m.confLock.Lock()
|
||||
tlsConf = m.conf
|
||||
m.confLock.Unlock()
|
||||
|
||||
// The background context is used because the TLSConfigChanged wraps context
|
||||
// with timeout on its own and shuts down the server, which handles current
|
||||
// request.
|
||||
Context.web.TLSConfigChanged(context.Background(), tlsConf)
|
||||
Context.web.TLSConfigChanged(context.Background(), m.conf)
|
||||
}
|
||||
|
||||
// loadTLSConf loads and validates the TLS configuration. The returned error is
|
||||
// also set in status.WarningValidation.
|
||||
func loadTLSConf(tlsConf *tlsConfigSettings, status *tlsConfigStatus) (err error) {
|
||||
func loadTLSConf(tlsConf *tlsConfiguration, status *tlsConfigStatus) (err error) {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
status.WarningValidation = err.Error()
|
||||
@@ -172,13 +166,10 @@ func loadTLSConf(tlsConf *tlsConfigSettings, status *tlsConfigStatus) (err error
|
||||
tlsConf.PrivateKeyData = []byte(tlsConf.PrivateKey)
|
||||
|
||||
if tlsConf.CertificatePath != "" {
|
||||
if tlsConf.CertificateChain != "" {
|
||||
return errors.Error("certificate data and file can't be set together")
|
||||
}
|
||||
|
||||
tlsConf.CertificateChainData, err = os.ReadFile(tlsConf.CertificatePath)
|
||||
err = loadCert(tlsConf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading cert file: %w", err)
|
||||
// Don't wrap the error, since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
// Set status.ValidCert to true to signal the frontend that the
|
||||
@@ -187,13 +178,10 @@ func loadTLSConf(tlsConf *tlsConfigSettings, status *tlsConfigStatus) (err error
|
||||
}
|
||||
|
||||
if tlsConf.PrivateKeyPath != "" {
|
||||
if tlsConf.PrivateKey != "" {
|
||||
return errors.Error("private key data and file can't be set together")
|
||||
}
|
||||
|
||||
tlsConf.PrivateKeyData, err = os.ReadFile(tlsConf.PrivateKeyPath)
|
||||
err = loadPKey(tlsConf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading key file: %w", err)
|
||||
// Don't wrap the error, since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
status.ValidKey = true
|
||||
@@ -212,278 +200,29 @@ func loadTLSConf(tlsConf *tlsConfigSettings, status *tlsConfigStatus) (err error
|
||||
return nil
|
||||
}
|
||||
|
||||
// tlsConfigStatus contains the status of a certificate chain and key pair.
|
||||
type tlsConfigStatus struct {
|
||||
// Subject is the subject of the first certificate in the chain.
|
||||
Subject string `json:"subject,omitempty"`
|
||||
|
||||
// Issuer is the issuer of the first certificate in the chain.
|
||||
Issuer string `json:"issuer,omitempty"`
|
||||
|
||||
// KeyType is the type of the private key.
|
||||
KeyType string `json:"key_type,omitempty"`
|
||||
|
||||
// NotBefore is the NotBefore field of the first certificate in the chain.
|
||||
NotBefore time.Time `json:"not_before,omitempty"`
|
||||
|
||||
// NotAfter is the NotAfter field of the first certificate in the chain.
|
||||
NotAfter time.Time `json:"not_after,omitempty"`
|
||||
|
||||
// WarningValidation is a validation warning message with the issue
|
||||
// description.
|
||||
WarningValidation string `json:"warning_validation,omitempty"`
|
||||
|
||||
// DNSNames is the value of SubjectAltNames field of the first certificate
|
||||
// in the chain.
|
||||
DNSNames []string `json:"dns_names"`
|
||||
|
||||
// ValidCert is true if the specified certificate chain is a valid chain of
|
||||
// X509 certificates.
|
||||
ValidCert bool `json:"valid_cert"`
|
||||
|
||||
// ValidChain is true if the specified certificate chain is verified and
|
||||
// issued by a known CA.
|
||||
ValidChain bool `json:"valid_chain"`
|
||||
|
||||
// ValidKey is true if the key is a valid private key.
|
||||
ValidKey bool `json:"valid_key"`
|
||||
|
||||
// ValidPair is true if both certificate and private key are correct for
|
||||
// each other.
|
||||
ValidPair bool `json:"valid_pair"`
|
||||
}
|
||||
|
||||
// tlsConfig is the TLS configuration and status response.
|
||||
type tlsConfig struct {
|
||||
*tlsConfigStatus `json:",inline"`
|
||||
tlsConfigSettingsExt `json:",inline"`
|
||||
}
|
||||
|
||||
// tlsConfigSettingsExt is used to (un)marshal the PrivateKeySaved field to
|
||||
// ensure that clients don't send and receive previously saved private keys.
|
||||
type tlsConfigSettingsExt struct {
|
||||
tlsConfigSettings `json:",inline"`
|
||||
|
||||
// PrivateKeySaved is true if the private key is saved as a string and omit
|
||||
// key from answer.
|
||||
PrivateKeySaved bool `yaml:"-" json:"private_key_saved,inline"`
|
||||
}
|
||||
|
||||
func (m *tlsManager) handleTLSStatus(w http.ResponseWriter, r *http.Request) {
|
||||
m.confLock.Lock()
|
||||
data := tlsConfig{
|
||||
tlsConfigSettingsExt: tlsConfigSettingsExt{
|
||||
tlsConfigSettings: m.conf,
|
||||
},
|
||||
tlsConfigStatus: m.status,
|
||||
// loadCert loads the certificate from file, if necessary.
|
||||
func loadCert(tlsConf *tlsConfiguration) (err error) {
|
||||
if tlsConf.CertificateChain != "" {
|
||||
return errors.Error("certificate data and file can't be set together")
|
||||
}
|
||||
m.confLock.Unlock()
|
||||
|
||||
marshalTLS(w, r, data)
|
||||
}
|
||||
|
||||
func (m *tlsManager) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
|
||||
setts, err := unmarshalTLS(r)
|
||||
tlsConf.CertificateChainData, err = os.ReadFile(tlsConf.CertificatePath)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err)
|
||||
|
||||
return
|
||||
return fmt.Errorf("reading cert file: %w", err)
|
||||
}
|
||||
|
||||
if setts.PrivateKeySaved {
|
||||
setts.PrivateKey = m.conf.PrivateKey
|
||||
}
|
||||
|
||||
if setts.Enabled {
|
||||
err = validatePorts(
|
||||
tcpPort(config.BindPort),
|
||||
tcpPort(config.BetaBindPort),
|
||||
tcpPort(setts.PortHTTPS),
|
||||
tcpPort(setts.PortDNSOverTLS),
|
||||
tcpPort(setts.PortDNSCrypt),
|
||||
udpPort(config.DNS.Port),
|
||||
udpPort(setts.PortDNSOverQUIC),
|
||||
)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if !webCheckPortAvailable(setts.PortHTTPS) {
|
||||
aghhttp.Error(
|
||||
r,
|
||||
w,
|
||||
http.StatusBadRequest,
|
||||
"port %d is not available, cannot enable HTTPS on it",
|
||||
setts.PortHTTPS,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Skip the error check, since we are only interested in the value of
|
||||
// status.WarningValidation.
|
||||
status := &tlsConfigStatus{}
|
||||
_ = loadTLSConf(&setts.tlsConfigSettings, status)
|
||||
resp := tlsConfig{
|
||||
tlsConfigSettingsExt: setts,
|
||||
tlsConfigStatus: status,
|
||||
}
|
||||
|
||||
marshalTLS(w, r, resp)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *tlsManager) setConfig(newConf tlsConfigSettings, status *tlsConfigStatus) (restartHTTPS bool) {
|
||||
m.confLock.Lock()
|
||||
defer m.confLock.Unlock()
|
||||
|
||||
// Reset the DNSCrypt data before comparing, since we currently do not
|
||||
// accept these from the frontend.
|
||||
//
|
||||
// TODO(a.garipov): Define a custom comparer for dnsforward.TLSConfig.
|
||||
newConf.DNSCryptConfigFile = m.conf.DNSCryptConfigFile
|
||||
newConf.PortDNSCrypt = m.conf.PortDNSCrypt
|
||||
if !cmp.Equal(m.conf, newConf, cmp.AllowUnexported(dnsforward.TLSConfig{})) {
|
||||
log.Info("tls config has changed, restarting https server")
|
||||
restartHTTPS = true
|
||||
} else {
|
||||
log.Info("tls: config has not changed")
|
||||
// loadPKey loads the private key from file, if necessary.
|
||||
func loadPKey(tlsConf *tlsConfiguration) (err error) {
|
||||
if tlsConf.PrivateKey != "" {
|
||||
return errors.Error("private key data and file cannot be set together")
|
||||
}
|
||||
|
||||
// Note: don't do just `t.conf = data` because we must preserve all other members of t.conf
|
||||
m.conf.Enabled = newConf.Enabled
|
||||
m.conf.ServerName = newConf.ServerName
|
||||
m.conf.ForceHTTPS = newConf.ForceHTTPS
|
||||
m.conf.PortHTTPS = newConf.PortHTTPS
|
||||
m.conf.PortDNSOverTLS = newConf.PortDNSOverTLS
|
||||
m.conf.PortDNSOverQUIC = newConf.PortDNSOverQUIC
|
||||
m.conf.CertificateChain = newConf.CertificateChain
|
||||
m.conf.CertificatePath = newConf.CertificatePath
|
||||
m.conf.CertificateChainData = newConf.CertificateChainData
|
||||
m.conf.PrivateKey = newConf.PrivateKey
|
||||
m.conf.PrivateKeyPath = newConf.PrivateKeyPath
|
||||
m.conf.PrivateKeyData = newConf.PrivateKeyData
|
||||
m.status = status
|
||||
|
||||
return restartHTTPS
|
||||
}
|
||||
|
||||
func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
|
||||
req, err := unmarshalTLS(r)
|
||||
tlsConf.PrivateKeyData, err = os.ReadFile(tlsConf.PrivateKeyPath)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if req.PrivateKeySaved {
|
||||
req.PrivateKey = m.conf.PrivateKey
|
||||
}
|
||||
|
||||
if req.Enabled {
|
||||
err = validatePorts(
|
||||
tcpPort(config.BindPort),
|
||||
tcpPort(config.BetaBindPort),
|
||||
tcpPort(req.PortHTTPS),
|
||||
tcpPort(req.PortDNSOverTLS),
|
||||
tcpPort(req.PortDNSCrypt),
|
||||
udpPort(config.DNS.Port),
|
||||
udpPort(req.PortDNSOverQUIC),
|
||||
)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(e.burkov): Investigate and perhaps check other ports.
|
||||
if !webCheckPortAvailable(req.PortHTTPS) {
|
||||
aghhttp.Error(
|
||||
r,
|
||||
w,
|
||||
http.StatusBadRequest,
|
||||
"port %d is not available, cannot enable https on it",
|
||||
req.PortHTTPS,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
status := &tlsConfigStatus{}
|
||||
err = loadTLSConf(&req.tlsConfigSettings, status)
|
||||
if err != nil {
|
||||
resp := tlsConfig{
|
||||
tlsConfigSettingsExt: req,
|
||||
tlsConfigStatus: status,
|
||||
}
|
||||
|
||||
marshalTLS(w, r, resp)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
restartHTTPS := m.setConfig(req.tlsConfigSettings, status)
|
||||
m.setCertFileTime()
|
||||
onConfigModified()
|
||||
|
||||
err = reconfigureDNSServer()
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
resp := tlsConfig{
|
||||
tlsConfigSettingsExt: req,
|
||||
tlsConfigStatus: m.status,
|
||||
}
|
||||
|
||||
marshalTLS(w, r, resp)
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
|
||||
// The background context is used because the TLSConfigChanged wraps context
|
||||
// with timeout on its own and shuts down the server, which handles current
|
||||
// request. It is also should be done in a separate goroutine due to the
|
||||
// same reason.
|
||||
if restartHTTPS {
|
||||
go func() {
|
||||
Context.web.TLSConfigChanged(context.Background(), req.tlsConfigSettings)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// validatePorts validates the uniqueness of TCP and UDP ports for AdGuard Home
|
||||
// DNS protocols.
|
||||
func validatePorts(
|
||||
bindPort, betaBindPort, dohPort, dotPort, dnscryptTCPPort tcpPort,
|
||||
dnsPort, doqPort udpPort,
|
||||
) (err error) {
|
||||
tcpPorts := aghalg.UniqChecker[tcpPort]{}
|
||||
addPorts(
|
||||
tcpPorts,
|
||||
tcpPort(bindPort),
|
||||
tcpPort(betaBindPort),
|
||||
tcpPort(dohPort),
|
||||
tcpPort(dotPort),
|
||||
tcpPort(dnscryptTCPPort),
|
||||
)
|
||||
|
||||
err = tcpPorts.Validate()
|
||||
if err != nil {
|
||||
return fmt.Errorf("validating tcp ports: %w", err)
|
||||
}
|
||||
|
||||
udpPorts := aghalg.UniqChecker[udpPort]{}
|
||||
addPorts(udpPorts, udpPort(dnsPort), udpPort(doqPort))
|
||||
|
||||
err = udpPorts.Validate()
|
||||
if err != nil {
|
||||
return fmt.Errorf("validating udp ports: %w", err)
|
||||
return fmt.Errorf("reading key file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -513,6 +252,11 @@ func validateCertChain(certs []*x509.Certificate, srvName string) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// errNoIPInCert is the error that is returned from [parseCertChain] if the leaf
|
||||
// certificate doesn't contain IPs.
|
||||
const errNoIPInCert errors.Error = `certificates has no IP addresses; ` +
|
||||
`DNS-over-TLS won't be advertised via DDR`
|
||||
|
||||
// parseCertChain parses the certificate chain from raw data, and returns it.
|
||||
// If ok is true, the returned error, if any, is not critical.
|
||||
func parseCertChain(chain []byte) (parsedCerts []*x509.Certificate, ok bool, err error) {
|
||||
@@ -535,8 +279,7 @@ func parseCertChain(chain []byte) (parsedCerts []*x509.Certificate, ok bool, err
|
||||
log.Info("tls: number of certs: %d", len(parsedCerts))
|
||||
|
||||
if !aghtls.CertificateHasIP(parsedCerts[0]) {
|
||||
err = errors.Error(`certificate has no IP addresses` +
|
||||
`, this may cause issues with DNS-over-TLS clients`)
|
||||
err = errNoIPInCert
|
||||
}
|
||||
|
||||
return parsedCerts, true, err
|
||||
@@ -696,61 +439,3 @@ func parsePrivateKey(der []byte) (key crypto.PrivateKey, typ string, err error)
|
||||
|
||||
return nil, "", errors.Error("tls: failed to parse private key")
|
||||
}
|
||||
|
||||
// unmarshalTLS handles base64-encoded certificates transparently
|
||||
func unmarshalTLS(r *http.Request) (tlsConfigSettingsExt, error) {
|
||||
data := tlsConfigSettingsExt{}
|
||||
err := json.NewDecoder(r.Body).Decode(&data)
|
||||
if err != nil {
|
||||
return data, fmt.Errorf("failed to parse new TLS config json: %w", err)
|
||||
}
|
||||
|
||||
if data.CertificateChain != "" {
|
||||
var cert []byte
|
||||
cert, err = base64.StdEncoding.DecodeString(data.CertificateChain)
|
||||
if err != nil {
|
||||
return data, fmt.Errorf("failed to base64-decode certificate chain: %w", err)
|
||||
}
|
||||
|
||||
data.CertificateChain = string(cert)
|
||||
if data.CertificatePath != "" {
|
||||
return data, fmt.Errorf("certificate data and file can't be set together")
|
||||
}
|
||||
}
|
||||
|
||||
if data.PrivateKey != "" {
|
||||
var key []byte
|
||||
key, err = base64.StdEncoding.DecodeString(data.PrivateKey)
|
||||
if err != nil {
|
||||
return data, fmt.Errorf("failed to base64-decode private key: %w", err)
|
||||
}
|
||||
|
||||
data.PrivateKey = string(key)
|
||||
if data.PrivateKeyPath != "" {
|
||||
return data, fmt.Errorf("private key data and file can't be set together")
|
||||
}
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func marshalTLS(w http.ResponseWriter, r *http.Request, data tlsConfig) {
|
||||
if data.CertificateChain != "" {
|
||||
encoded := base64.StdEncoding.EncodeToString([]byte(data.CertificateChain))
|
||||
data.CertificateChain = encoded
|
||||
}
|
||||
|
||||
if data.PrivateKey != "" {
|
||||
data.PrivateKeySaved = true
|
||||
data.PrivateKey = ""
|
||||
}
|
||||
|
||||
_ = aghhttp.WriteJSONResponse(w, r, data)
|
||||
}
|
||||
|
||||
// registerWebHandlers registers HTTP handlers for TLS configuration.
|
||||
func (m *tlsManager) registerWebHandlers() {
|
||||
httpRegister(http.MethodGet, "/control/tls/status", m.handleTLSStatus)
|
||||
httpRegister(http.MethodPost, "/control/tls/configure", m.handleTLSConfigure)
|
||||
httpRegister(http.MethodPost, "/control/tls/validate", m.handleTLSValidate)
|
||||
}
|
||||
|
||||
362
internal/home/tlshttp.go
Normal file
362
internal/home/tlshttp.go
Normal file
@@ -0,0 +1,362 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
// Encryption Settings HTTP API
|
||||
|
||||
// tlsConfigStatus contains the status of a certificate chain and key pair.
|
||||
type tlsConfigStatus struct {
|
||||
// Subject is the subject of the first certificate in the chain.
|
||||
Subject string `json:"subject,omitempty"`
|
||||
|
||||
// Issuer is the issuer of the first certificate in the chain.
|
||||
Issuer string `json:"issuer,omitempty"`
|
||||
|
||||
// KeyType is the type of the private key.
|
||||
KeyType string `json:"key_type,omitempty"`
|
||||
|
||||
// NotBefore is the NotBefore field of the first certificate in the chain.
|
||||
NotBefore time.Time `json:"not_before,omitempty"`
|
||||
|
||||
// NotAfter is the NotAfter field of the first certificate in the chain.
|
||||
NotAfter time.Time `json:"not_after,omitempty"`
|
||||
|
||||
// WarningValidation is a validation warning message with the issue
|
||||
// description.
|
||||
WarningValidation string `json:"warning_validation,omitempty"`
|
||||
|
||||
// DNSNames is the value of SubjectAltNames field of the first certificate
|
||||
// in the chain.
|
||||
DNSNames []string `json:"dns_names"`
|
||||
|
||||
// ValidCert is true if the specified certificate chain is a valid chain of
|
||||
// X509 certificates.
|
||||
ValidCert bool `json:"valid_cert"`
|
||||
|
||||
// ValidChain is true if the specified certificate chain is verified and
|
||||
// issued by a known CA.
|
||||
ValidChain bool `json:"valid_chain"`
|
||||
|
||||
// ValidKey is true if the key is a valid private key.
|
||||
ValidKey bool `json:"valid_key"`
|
||||
|
||||
// ValidPair is true if both certificate and private key are correct for
|
||||
// each other.
|
||||
ValidPair bool `json:"valid_pair"`
|
||||
}
|
||||
|
||||
// tlsConfigResp is the TLS configuration and status response.
|
||||
type tlsConfigResp struct {
|
||||
*tlsConfigStatus
|
||||
*tlsConfiguration
|
||||
|
||||
// PrivateKeySaved is true if the private key is saved as a string and omit
|
||||
// key from answer.
|
||||
PrivateKeySaved bool `yaml:"-" json:"private_key_saved"`
|
||||
}
|
||||
|
||||
// tlsConfigReq is the TLS configuration request.
|
||||
type tlsConfigReq struct {
|
||||
tlsConfiguration
|
||||
|
||||
// PrivateKeySaved is true if the private key is saved as a string and omit
|
||||
// key from answer.
|
||||
PrivateKeySaved bool `yaml:"-" json:"private_key_saved"`
|
||||
}
|
||||
|
||||
// handleTLSStatus is the handler for the GET /control/tls/status HTTP API.
|
||||
func (m *tlsManager) handleTLSStatus(w http.ResponseWriter, r *http.Request) {
|
||||
var resp *tlsConfigResp
|
||||
func() {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
resp = &tlsConfigResp{
|
||||
tlsConfigStatus: m.status,
|
||||
tlsConfiguration: m.conf.cloneForEncoding(),
|
||||
}
|
||||
}()
|
||||
|
||||
marshalTLS(w, r, resp)
|
||||
}
|
||||
|
||||
// handleTLSValidate is the handler for the POST /control/tls/validate HTTP API.
|
||||
func (m *tlsManager) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
|
||||
req, err := unmarshalTLS(r)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if req.PrivateKeySaved {
|
||||
req.PrivateKey = m.confForEncoding().PrivateKey
|
||||
}
|
||||
|
||||
if req.Enabled {
|
||||
err = validatePorts(
|
||||
tcpPort(config.BindPort),
|
||||
tcpPort(config.BetaBindPort),
|
||||
tcpPort(req.PortHTTPS),
|
||||
tcpPort(req.PortDNSOverTLS),
|
||||
tcpPort(req.PortDNSCrypt),
|
||||
udpPort(config.DNS.Port),
|
||||
udpPort(req.PortDNSOverQUIC),
|
||||
)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if !webCheckPortAvailable(req.PortHTTPS) {
|
||||
aghhttp.Error(
|
||||
r,
|
||||
w,
|
||||
http.StatusBadRequest,
|
||||
"port %d is not available, cannot enable HTTPS on it",
|
||||
req.PortHTTPS,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
resp := &tlsConfigResp{
|
||||
tlsConfigStatus: &tlsConfigStatus{},
|
||||
tlsConfiguration: &req.tlsConfiguration,
|
||||
}
|
||||
|
||||
// Skip the error check, since we are only interested in the value of
|
||||
// resl.tlsConfigStatus.WarningValidation.
|
||||
_ = loadTLSConf(resp.tlsConfiguration, resp.tlsConfigStatus)
|
||||
|
||||
marshalTLS(w, r, resp)
|
||||
}
|
||||
|
||||
// validatePorts validates the uniqueness of TCP and UDP ports for AdGuard Home
|
||||
// DNS protocols.
|
||||
func validatePorts(
|
||||
bindPort, betaBindPort, dohPort, dotPort, dnscryptTCPPort tcpPort,
|
||||
dnsPort, doqPort udpPort,
|
||||
) (err error) {
|
||||
tcpPorts := aghalg.UniqChecker[tcpPort]{}
|
||||
addPorts(
|
||||
tcpPorts,
|
||||
tcpPort(bindPort),
|
||||
tcpPort(betaBindPort),
|
||||
tcpPort(dohPort),
|
||||
tcpPort(dotPort),
|
||||
tcpPort(dnscryptTCPPort),
|
||||
)
|
||||
|
||||
err = tcpPorts.Validate()
|
||||
if err != nil {
|
||||
return fmt.Errorf("validating tcp ports: %w", err)
|
||||
}
|
||||
|
||||
udpPorts := aghalg.UniqChecker[udpPort]{}
|
||||
addPorts(udpPorts, udpPort(dnsPort), udpPort(doqPort))
|
||||
|
||||
err = udpPorts.Validate()
|
||||
if err != nil {
|
||||
return fmt.Errorf("validating udp ports: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleTLSConfigure is the handler for the POST /control/tls/configure HTTP
|
||||
// API.
|
||||
func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
|
||||
req, err := unmarshalTLS(r)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if req.PrivateKeySaved {
|
||||
req.PrivateKey = m.confForEncoding().PrivateKey
|
||||
}
|
||||
|
||||
if req.Enabled {
|
||||
err = validatePorts(
|
||||
tcpPort(config.BindPort),
|
||||
tcpPort(config.BetaBindPort),
|
||||
tcpPort(req.PortHTTPS),
|
||||
tcpPort(req.PortDNSOverTLS),
|
||||
tcpPort(req.PortDNSCrypt),
|
||||
udpPort(config.DNS.Port),
|
||||
udpPort(req.PortDNSOverQUIC),
|
||||
)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(e.burkov): Investigate and perhaps check other ports.
|
||||
if !webCheckPortAvailable(req.PortHTTPS) {
|
||||
aghhttp.Error(
|
||||
r,
|
||||
w,
|
||||
http.StatusBadRequest,
|
||||
"port %d is not available, cannot enable https on it",
|
||||
req.PortHTTPS,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
resp := &tlsConfigResp{
|
||||
tlsConfigStatus: &tlsConfigStatus{},
|
||||
tlsConfiguration: &req.tlsConfiguration,
|
||||
}
|
||||
err = loadTLSConf(resp.tlsConfiguration, resp.tlsConfigStatus)
|
||||
if err != nil {
|
||||
marshalTLS(w, r, resp)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
restartRequired := m.setConf(resp)
|
||||
onConfigModified()
|
||||
|
||||
err = reconfigureDNSServer()
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
resp.tlsConfiguration = m.confForEncoding()
|
||||
marshalTLS(w, r, resp)
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
|
||||
// The background context is used because the TLSConfigChanged wraps context
|
||||
// with timeout on its own and shuts down the server, which handles current
|
||||
// request. It is also should be done in a separate goroutine due to the
|
||||
// same reason.
|
||||
if restartRequired {
|
||||
go func() {
|
||||
Context.web.TLSConfigChanged(context.Background(), resp.tlsConfiguration)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// setConf sets the necessary values from the new configuration.
|
||||
func (m *tlsManager) setConf(newConf *tlsConfigResp) (restartRequired bool) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// Reset the DNSCrypt data before comparing, since we currently do not
|
||||
// accept these from the frontend.
|
||||
//
|
||||
// TODO(a.garipov): Define a custom comparer for dnsforward.TLSConfig.
|
||||
newConf.DNSCryptConfigFile = m.conf.DNSCryptConfigFile
|
||||
newConf.PortDNSCrypt = m.conf.PortDNSCrypt
|
||||
if !cmp.Equal(m.conf, newConf, cmp.AllowUnexported(dnsforward.TLSConfig{})) {
|
||||
log.Info("tls: config has changed, restarting https server")
|
||||
restartRequired = true
|
||||
} else {
|
||||
log.Info("tls: config has not changed")
|
||||
}
|
||||
|
||||
// Do not just write "m.conf = *newConf.tlsConfiguration", because all other
|
||||
// members of m.conf must be preserved.
|
||||
m.conf.Enabled = newConf.Enabled
|
||||
m.conf.ServerName = newConf.ServerName
|
||||
m.conf.ForceHTTPS = newConf.ForceHTTPS
|
||||
m.conf.PortHTTPS = newConf.PortHTTPS
|
||||
m.conf.PortDNSOverTLS = newConf.PortDNSOverTLS
|
||||
m.conf.PortDNSOverQUIC = newConf.PortDNSOverQUIC
|
||||
|
||||
m.conf.CertificateChain = newConf.CertificateChain
|
||||
m.conf.CertificatePath = newConf.CertificatePath
|
||||
m.conf.CertificateChainData = newConf.CertificateChainData
|
||||
m.conf.PrivateKey = newConf.PrivateKey
|
||||
m.conf.PrivateKeyPath = newConf.PrivateKeyPath
|
||||
m.conf.PrivateKeyData = newConf.PrivateKeyData
|
||||
|
||||
m.setCertFileTime()
|
||||
|
||||
m.status = newConf.tlsConfigStatus
|
||||
|
||||
return restartRequired
|
||||
}
|
||||
|
||||
// marshalTLS handles Base64-encoded certificates transparently.
|
||||
func marshalTLS(w http.ResponseWriter, r *http.Request, conf *tlsConfigResp) {
|
||||
if conf.CertificateChain != "" {
|
||||
encoded := base64.StdEncoding.EncodeToString([]byte(conf.CertificateChain))
|
||||
conf.CertificateChain = encoded
|
||||
}
|
||||
|
||||
if conf.PrivateKey != "" {
|
||||
conf.PrivateKeySaved = true
|
||||
conf.PrivateKey = ""
|
||||
}
|
||||
|
||||
_ = aghhttp.WriteJSONResponse(w, r, conf)
|
||||
}
|
||||
|
||||
// unmarshalTLS handles Base64-encoded certificates transparently.
|
||||
func unmarshalTLS(r *http.Request) (req *tlsConfigReq, err error) {
|
||||
req = &tlsConfigReq{}
|
||||
err = json.NewDecoder(r.Body).Decode(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing tls config: %w", err)
|
||||
}
|
||||
|
||||
if req.CertificateChain != "" {
|
||||
var cert []byte
|
||||
cert, err = base64.StdEncoding.DecodeString(req.CertificateChain)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to base64-decode certificate chain: %w", err)
|
||||
}
|
||||
|
||||
req.CertificateChain = string(cert)
|
||||
if req.CertificatePath != "" {
|
||||
return nil, fmt.Errorf("certificate data and file can't be set together")
|
||||
}
|
||||
}
|
||||
|
||||
if req.PrivateKey != "" {
|
||||
var key []byte
|
||||
key, err = base64.StdEncoding.DecodeString(req.PrivateKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to base64-decode private key: %w", err)
|
||||
}
|
||||
|
||||
req.PrivateKey = string(key)
|
||||
if req.PrivateKeyPath != "" {
|
||||
return nil, fmt.Errorf("private key data and file can't be set together")
|
||||
}
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// registerWebHandlers registers HTTP handlers for TLS configuration.
|
||||
func (m *tlsManager) registerWebHandlers() {
|
||||
httpRegister(http.MethodGet, "/control/tls/status", m.handleTLSStatus)
|
||||
httpRegister(http.MethodPost, "/control/tls/configure", m.handleTLSConfigure)
|
||||
httpRegister(http.MethodPost, "/control/tls/validate", m.handleTLSValidate)
|
||||
}
|
||||
@@ -143,7 +143,7 @@ func webCheckPortAvailable(port int) (ok bool) {
|
||||
|
||||
// TLSConfigChanged updates the TLS configuration and restarts the HTTPS server
|
||||
// if necessary.
|
||||
func (web *Web) TLSConfigChanged(ctx context.Context, tlsConf tlsConfigSettings) {
|
||||
func (web *Web) TLSConfigChanged(ctx context.Context, tlsConf *tlsConfiguration) {
|
||||
log.Debug("web: applying new tls configuration")
|
||||
web.conf.PortHTTPS = tlsConf.PortHTTPS
|
||||
web.forceHTTPS = (tlsConf.ForceHTTPS && tlsConf.Enabled && tlsConf.PortHTTPS != 0)
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -26,7 +27,7 @@ const (
|
||||
// WHOIS - module context
|
||||
type WHOIS struct {
|
||||
clients *clientsContainer
|
||||
ipChan chan net.IP
|
||||
ipChan chan netip.Addr
|
||||
|
||||
// dialContext specifies the dial function for creating unencrypted TCP
|
||||
// connections.
|
||||
@@ -51,7 +52,7 @@ func initWHOIS(clients *clientsContainer) *WHOIS {
|
||||
MaxCount: 10000,
|
||||
}),
|
||||
dialContext: customDialContext,
|
||||
ipChan: make(chan net.IP, 255),
|
||||
ipChan: make(chan netip.Addr, 255),
|
||||
}
|
||||
|
||||
go w.workerLoop()
|
||||
@@ -192,7 +193,7 @@ func (w *WHOIS) queryAll(ctx context.Context, target string) (string, error) {
|
||||
}
|
||||
|
||||
// Request WHOIS information
|
||||
func (w *WHOIS) process(ctx context.Context, ip net.IP) (wi *RuntimeClientWHOISInfo) {
|
||||
func (w *WHOIS) process(ctx context.Context, ip netip.Addr) (wi *RuntimeClientWHOISInfo) {
|
||||
resp, err := w.queryAll(ctx, ip.String())
|
||||
if err != nil {
|
||||
log.Debug("whois: error: %s IP:%s", err, ip)
|
||||
@@ -220,24 +221,25 @@ func (w *WHOIS) process(ctx context.Context, ip net.IP) (wi *RuntimeClientWHOISI
|
||||
}
|
||||
|
||||
// Begin - begin requesting WHOIS info
|
||||
func (w *WHOIS) Begin(ip net.IP) {
|
||||
func (w *WHOIS) Begin(ip netip.Addr) {
|
||||
ipBytes := ip.AsSlice()
|
||||
now := uint64(time.Now().Unix())
|
||||
expire := w.ipAddrs.Get([]byte(ip))
|
||||
expire := w.ipAddrs.Get(ipBytes)
|
||||
if len(expire) != 0 {
|
||||
exp := binary.BigEndian.Uint64(expire)
|
||||
if exp > now {
|
||||
return
|
||||
}
|
||||
// TTL expired
|
||||
}
|
||||
|
||||
expire = make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(expire, now+whoisTTL)
|
||||
_ = w.ipAddrs.Set([]byte(ip), expire)
|
||||
_ = w.ipAddrs.Set(ipBytes, expire)
|
||||
|
||||
log.Debug("whois: adding %s", ip)
|
||||
|
||||
select {
|
||||
case w.ipChan <- ip:
|
||||
//
|
||||
default:
|
||||
log.Debug("whois: queue is full")
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@@ -64,7 +65,7 @@ type Interface interface {
|
||||
|
||||
// GetTopClientIP returns at most limit IP addresses corresponding to the
|
||||
// clients with the most number of requests.
|
||||
TopClientsIP(limit uint) []net.IP
|
||||
TopClientsIP(limit uint) []netip.Addr
|
||||
|
||||
// WriteDiskConfig puts the Interface's configuration to the dc.
|
||||
WriteDiskConfig(dc *DiskConfig)
|
||||
@@ -107,8 +108,6 @@ type StatsCtx struct {
|
||||
filename string
|
||||
}
|
||||
|
||||
var _ Interface = &StatsCtx{}
|
||||
|
||||
// New creates s from conf and properly initializes it. Don't use s before
|
||||
// calling it's Start method.
|
||||
func New(conf Config) (s *StatsCtx, err error) {
|
||||
@@ -178,6 +177,9 @@ func withRecovered(orig *error) {
|
||||
*orig = errors.WithDeferred(*orig, err)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ Interface = (*StatsCtx)(nil)
|
||||
|
||||
// Start implements the Interface interface for *StatsCtx.
|
||||
func (s *StatsCtx) Start() {
|
||||
s.initWeb()
|
||||
@@ -250,8 +252,8 @@ func (s *StatsCtx) WriteDiskConfig(dc *DiskConfig) {
|
||||
dc.Interval = atomic.LoadUint32(&s.limitHours) / 24
|
||||
}
|
||||
|
||||
// TopClientsIP implements the Interface interface for *StatsCtx.
|
||||
func (s *StatsCtx) TopClientsIP(maxCount uint) (ips []net.IP) {
|
||||
// TopClientsIP implements the [Interface] interface for *StatsCtx.
|
||||
func (s *StatsCtx) TopClientsIP(maxCount uint) (ips []netip.Addr) {
|
||||
limit := atomic.LoadUint32(&s.limitHours)
|
||||
if limit == 0 {
|
||||
return nil
|
||||
@@ -271,10 +273,10 @@ func (s *StatsCtx) TopClientsIP(maxCount uint) (ips []net.IP) {
|
||||
}
|
||||
|
||||
a := convertMapToSlice(m, int(maxCount))
|
||||
ips = []net.IP{}
|
||||
ips = []netip.Addr{}
|
||||
for _, it := range a {
|
||||
ip := net.ParseIP(it.Name)
|
||||
if ip != nil {
|
||||
ip, err := netip.ParseAddr(it.Name)
|
||||
if err == nil {
|
||||
ips = append(ips, ip)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -45,7 +46,7 @@ func assertSuccessAndUnmarshal(t *testing.T, to any, handler http.Handler, req *
|
||||
}
|
||||
|
||||
func TestStats(t *testing.T) {
|
||||
cliIP := net.IP{127, 0, 0, 1}
|
||||
cliIP := netutil.IPv4Localhost()
|
||||
cliIPStr := cliIP.String()
|
||||
|
||||
handlers := map[string]http.Handler{}
|
||||
@@ -123,7 +124,7 @@ func TestStats(t *testing.T) {
|
||||
topClients := s.TopClientsIP(2)
|
||||
require.NotEmpty(t, topClients)
|
||||
|
||||
assert.True(t, cliIP.Equal(topClients[0]))
|
||||
assert.Equal(t, cliIP, topClients[0])
|
||||
})
|
||||
|
||||
t.Run("reset", func(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user