Compare commits

...

6 Commits

Author SHA1 Message Date
Ainar Garipov
0c7d56dca3 Merge branch 'master' into 4927-refactor-tls 2022-11-22 17:10:40 +03:00
Ainar Garipov
08282dc4d9 Pull request: 4927-imp-ui
Updates #4927.

Squashed commit of the following:

commit 510143325805133e379ebc207cdc6bff59c94ade
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Tue Nov 22 15:00:13 2022 +0300

    home: imp err

commit fd65a9914494b6dccdee7c0f0aa08bce80ce0945
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Mon Nov 21 18:53:39 2022 +0300

    client: imp validation ui
2022-11-22 17:07:49 +03:00
Ainar Garipov
f36efa26a4 home: refactor more 2022-11-21 19:45:18 +03:00
Ainar Garipov
a8850059db home: refactor tls 2022-11-21 19:05:49 +03:00
Dimitry Kolyshev
93882d6860 Pull request: 4223 home: cmd update
Merge in DNS/adguard-home from 4223-cmd-update to master

Squashed commit of the following:

commit ffda71246f37eaba0cb190840f1370ba65099d7c
Author: Dimitry Kolyshev <dkolyshev@adguard.com>
Date:   Tue Nov 15 16:32:10 2022 +0200

    home: cmd update

commit 9c4e1c33da78952a2b1477ac380a0cf042a8990f
Author: Dimitry Kolyshev <dkolyshev@adguard.com>
Date:   Tue Nov 15 13:51:33 2022 +0200

    home: cmd update

commit 6a564dc30771b3675e8861ca3befaaee15d83026
Author: Dimitry Kolyshev <dkolyshev@adguard.com>
Date:   Mon Nov 14 11:05:06 2022 +0200

    all: docs

commit a546bdbdb6f3f78c40908bc1864f2a1ae1c9071f
Author: Dimitry Kolyshev <dkolyshev@adguard.com>
Date:   Mon Nov 14 10:55:16 2022 +0200

    home: cmd update

commit cbbb594980d3d163fe0489494b0ddca5f679d6e6
Author: Dimitry Kolyshev <dkolyshev@adguard.com>
Date:   Mon Nov 14 10:16:09 2022 +0200

    home: imp code

commit 677f8a7ca0f47da0ac636e5bab9db24506cf5041
Author: Dimitry Kolyshev <dkolyshev@adguard.com>
Date:   Sun Nov 13 14:12:48 2022 +0200

    home: cmd update
2022-11-15 17:44:50 +03:00
Ainar Garipov
167b112511 Pull request: 5035-more-clients-netip-addr
Updates #5035.

Squashed commit of the following:

commit 1934ea14299921760e9fcf6dd9053bd3155cb40e
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Wed Nov 9 14:19:54 2022 +0300

    all: move more client code to netip.Addr
2022-11-09 14:37:07 +03:00
29 changed files with 727 additions and 603 deletions

View File

@@ -25,6 +25,13 @@ See also the [v0.107.19 GitHub milestone][ms-v0.107.19].
[ms-v0.107.19]: https://github.com/AdguardTeam/AdGuardHome/milestone/55?closed=1
-->
### Added
- The new `--update` command-line option, which allows updating AdGuard Home
silently ([#4223]).
[#4223]: https://github.com/AdguardTeam/AdGuardHome/issues/4223
## [v0.107.18] - 2022-11-08

View File

@@ -393,6 +393,7 @@
"encryption_issuer": "Issuer",
"encryption_hostnames": "Hostnames",
"encryption_reset": "Are you sure you want to reset encryption settings?",
"encryption_warning": "Warning",
"topline_expiring_certificate": "Your SSL certificate is about to expire. Update <0>Encryption settings</0>.",
"topline_expired_certificate": "Your SSL certificate is expired. Update <0>Encryption settings</0>.",
"form_error_port_range": "Enter port number in the range of 80-65535",

View File

@@ -56,6 +56,26 @@ const clearFields = (change, setTlsConfig, t) => {
}
};
const validationMessage = (warningValidation, isWarning) => {
if (!warningValidation) {
return null;
}
if (isWarning) {
return (
<div className="col-12">
<p><Trans>encryption_warning</Trans>: {warningValidation}</p>
</div>
);
}
return (
<div className="col-12">
<p className="text-danger">{warningValidation}</p>
</div>
);
};
let Form = (props) => {
const {
t,
@@ -95,6 +115,8 @@ let Form = (props) => {
|| !valid_cert
|| !valid_pair;
const isWarning = valid_key && valid_cert && valid_pair;
return (
<form onSubmit={handleSubmit}>
<div className="row">
@@ -382,11 +404,7 @@ let Form = (props) => {
)}
</div>
</div>
{warning_validation && (
<div className="col-12">
<p className="text-danger">{warning_validation}</p>
</div>
)}
{validationMessage(warning_validation, isWarning)}
</div>
<div className="btn-list mt-2">

View File

@@ -31,12 +31,6 @@ var (
// the IP being static is available.
const ErrNoStaticIPInfo errors.Error = "no information about static ip"
// IPv4Localhost returns 127.0.0.1, which returns true for [netip.Addr.Is4].
func IPv4Localhost() (ip netip.Addr) { return netip.AddrFrom4([4]byte{127, 0, 0, 1}) }
// IPv6Localhost returns ::1, which returns true for [netip.Addr.Is6].
func IPv6Localhost() (ip netip.Addr) { return netip.AddrFrom16([16]byte{15: 1}) }
// IfaceHasStaticIP checks if interface is configured to have static IP address.
// If it can't give a definitive answer, it returns false and an error for which
// errors.Is(err, ErrNoStaticIPInfo) is true.

View File

@@ -188,7 +188,7 @@ func TestBroadcastFromIPNet(t *testing.T) {
}
func TestCheckPort(t *testing.T) {
laddr := netip.AddrPortFrom(IPv4Localhost(), 0)
laddr := netip.AddrPortFrom(netutil.IPv4Localhost(), 0)
t.Run("tcp_bound", func(t *testing.T) {
l, err := net.Listen("tcp", laddr.String())

View File

@@ -23,16 +23,6 @@ func ValidateClientID(id string) (err error) {
return nil
}
// hasLabelSuffix returns true if s ends with suffix preceded by a dot. It's
// a helper function to prevent unnecessary allocations in code like:
//
// if strings.HasSuffix(s, "." + suffix) { /* … */ }
//
// s must be longer than suffix.
func hasLabelSuffix(s, suffix string) (ok bool) {
return strings.HasSuffix(s, suffix) && s[len(s)-len(suffix)-1] == '.'
}
// clientIDFromClientServerName extracts and validates a ClientID. hostSrvName
// is the server name of the host. cliSrvName is the server name as sent by the
// client. When strict is true, and client and host server name don't match,
@@ -46,7 +36,7 @@ func clientIDFromClientServerName(
return "", nil
}
if !hasLabelSuffix(cliSrvName, hostSrvName) {
if !netutil.IsImmediateSubdomain(cliSrvName, hostSrvName) {
if !strict {
return "", nil
}

View File

@@ -145,7 +145,8 @@ type FilteringConfig struct {
IpsetListFileName string `yaml:"ipset_file"`
}
// TLSConfig is the TLS configuration for HTTPS, DNS-over-HTTPS, and DNS-over-TLS
// TLSConfig is the TLS configuration for HTTPS, DNS-over-HTTPS, DNS-over-TLS,
// and DNS-over-QUIC.
type TLSConfig struct {
cert tls.Certificate

View File

@@ -246,6 +246,7 @@ type RDNSExchanger interface {
// Exchange tries to resolve the ip in a suitable way, e.g. either as
// local or as external.
Exchange(ip net.IP) (host string, err error)
// ResolvesPrivatePTR returns true if the RDNSExchanger is able to
// resolve PTR requests for locally-served addresses.
ResolvesPrivatePTR() (ok bool)
@@ -261,6 +262,9 @@ const (
rDNSNotPTRErr errors.Error = "the response is not a ptr"
)
// type check
var _ RDNSExchanger = (*Server)(nil)
// Exchange implements the RDNSExchanger interface for *Server.
func (s *Server) Exchange(ip net.IP) (host string, err error) {
s.serverLock.RLock()
@@ -675,21 +679,13 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// IsBlockedClient returns true if the client is blocked by the current access
// settings.
func (s *Server) IsBlockedClient(ip net.IP, clientID string) (blocked bool, rule string) {
func (s *Server) IsBlockedClient(ip netip.Addr, clientID string) (blocked bool, rule string) {
s.serverLock.RLock()
defer s.serverLock.RUnlock()
blockedByIP := false
if ip != nil {
// TODO(a.garipov): Remove once we switch to netip.Addr more fully.
ipAddr, err := netutil.IPToAddrNoMapped(ip)
if err != nil {
log.Error("dnsforward: bad client ip %v: %s", ip, err)
return false, ""
}
blockedByIP, rule = s.access.isBlockedIP(ipAddr)
if ip != (netip.Addr{}) {
blockedByIP, rule = s.access.isBlockedIP(ip)
}
allowlistMode := s.access.allowlistMode()

View File

@@ -19,13 +19,13 @@ func (s *Server) beforeRequestHandler(
_ *proxy.Proxy,
pctx *proxy.DNSContext,
) (reply bool, err error) {
ip, _ := netutil.IPAndPortFromAddr(pctx.Addr)
clientID, err := s.clientIDFromDNSContext(pctx)
if err != nil {
return false, fmt.Errorf("getting clientid: %w", err)
}
blocked, _ := s.IsBlockedClient(ip, clientID)
addrPort := netutil.NetAddrToAddrPort(pctx.Addr)
blocked, _ := s.IsBlockedClient(addrPort.Addr(), clientID)
if blocked {
return s.preBlockedResponse(pctx)
}

View File

@@ -11,7 +11,7 @@ import (
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -40,7 +40,7 @@ func serveFiltersLocally(t *testing.T, fltContent []byte) (ipp netip.AddrPort) {
addr := l.Addr()
require.IsType(t, new(net.TCPAddr), addr)
return netip.AddrPortFrom(aghnet.IPv4Localhost(), uint16(addr.(*net.TCPAddr).Port))
return netip.AddrPortFrom(netutil.IPv4Localhost(), uint16(addr.(*net.TCPAddr).Port))
}
func TestFilters(t *testing.T) {

View File

@@ -129,7 +129,7 @@ type RuntimeClientWHOISInfo struct {
type clientsContainer struct {
// TODO(a.garipov): Perhaps use a number of separate indices for
// different types (string, net.IP, and so on).
// different types (string, netip.Addr, and so on).
list map[string]*Client // name -> client
idIndex map[string]*Client // ID -> client
@@ -333,7 +333,7 @@ func (clients *clientsContainer) onDHCPLeaseChanged(flags int) {
}
// exists checks if client with this IP address already exists.
func (clients *clientsContainer) exists(ip net.IP, source clientSource) (ok bool) {
func (clients *clientsContainer) exists(ip netip.Addr, source clientSource) (ok bool) {
clients.lock.Lock()
defer clients.lock.Unlock()
@@ -342,7 +342,7 @@ func (clients *clientsContainer) exists(ip net.IP, source clientSource) (ok bool
return true
}
rc, ok := clients.findRuntimeClientLocked(ip)
rc, ok := clients.ipToRC[ip]
if !ok {
return false
}
@@ -371,7 +371,8 @@ func (clients *clientsContainer) findMultiple(ids []string) (c *querylog.Client,
var artClient *querylog.Client
var art bool
for _, id := range ids {
c, art = clients.clientOrArtificial(net.ParseIP(id), id)
ip, _ := netip.ParseAddr(id)
c, art = clients.clientOrArtificial(ip, id)
if art {
artClient = c
@@ -389,7 +390,7 @@ func (clients *clientsContainer) findMultiple(ids []string) (c *querylog.Client,
// records about this client besides maybe whether or not it is blocked. c is
// never nil.
func (clients *clientsContainer) clientOrArtificial(
ip net.IP,
ip netip.Addr,
id string,
) (c *querylog.Client, art bool) {
defer func() {
@@ -406,13 +407,6 @@ func (clients *clientsContainer) clientOrArtificial(
}, false
}
if ip == nil {
// Technically should never happen, but still.
return &querylog.Client{
Name: "",
}, true
}
var rc *RuntimeClient
rc, ok = clients.findRuntimeClient(ip)
if ok {
@@ -492,19 +486,20 @@ func (clients *clientsContainer) findLocked(id string) (c *Client, ok bool) {
return c, true
}
ip := net.ParseIP(id)
if ip == nil {
ip, err := netip.ParseAddr(id)
if err != nil {
return nil, false
}
for _, c = range clients.list {
for _, id := range c.IDs {
_, ipnet, err := net.ParseCIDR(id)
var n netip.Prefix
n, err = netip.ParsePrefix(id)
if err != nil {
continue
}
if ipnet.Contains(ip) {
if n.Contains(ip) {
return c, true
}
}
@@ -514,19 +509,20 @@ func (clients *clientsContainer) findLocked(id string) (c *Client, ok bool) {
return nil, false
}
macFound := clients.dhcpServer.FindMACbyIP(ip)
macFound := clients.dhcpServer.FindMACbyIP(ip.AsSlice())
if macFound == nil {
return nil, false
}
for _, c = range clients.list {
for _, id := range c.IDs {
hwAddr, err := net.ParseMAC(id)
var mac net.HardwareAddr
mac, err = net.ParseMAC(id)
if err != nil {
continue
}
if bytes.Equal(hwAddr, macFound) {
if bytes.Equal(mac, macFound) {
return c, true
}
}
@@ -535,32 +531,18 @@ func (clients *clientsContainer) findLocked(id string) (c *Client, ok bool) {
return nil, false
}
// findRuntimeClientLocked finds a runtime client by their IP address. For
// internal use only.
func (clients *clientsContainer) findRuntimeClientLocked(ip net.IP) (rc *RuntimeClient, ok bool) {
// TODO(a.garipov): Remove once we switch to netip.Addr more fully.
ipAddr, err := netutil.IPToAddrNoMapped(ip)
if err != nil {
log.Error("clients: bad client ip %v: %s", ip, err)
return nil, false
}
rc, ok = clients.ipToRC[ipAddr]
return rc, ok
}
// findRuntimeClient finds a runtime client by their IP.
func (clients *clientsContainer) findRuntimeClient(ip net.IP) (rc *RuntimeClient, ok bool) {
if ip == nil {
func (clients *clientsContainer) findRuntimeClient(ip netip.Addr) (rc *RuntimeClient, ok bool) {
if ip == (netip.Addr{}) {
return nil, false
}
clients.lock.Lock()
defer clients.lock.Unlock()
return clients.findRuntimeClientLocked(ip)
rc, ok = clients.ipToRC[ip]
return rc, ok
}
// check validates the client.
@@ -578,14 +560,16 @@ func (clients *clientsContainer) check(c *Client) (err error) {
for i, id := range c.IDs {
// Normalize structured data.
var ip net.IP
var ipnet *net.IPNet
var mac net.HardwareAddr
if ip = net.ParseIP(id); ip != nil {
var (
ip netip.Addr
n netip.Prefix
mac net.HardwareAddr
)
if ip, err = netip.ParseAddr(id); err == nil {
c.IDs[i] = ip.String()
} else if ip, ipnet, err = net.ParseCIDR(id); err == nil {
ipnet.IP = ip
c.IDs[i] = ipnet.String()
} else if n, err = netip.ParsePrefix(id); err == nil {
c.IDs[i] = n.String()
} else if mac, err = net.ParseMAC(id); err == nil {
c.IDs[i] = mac.String()
} else if err = dnsforward.ValidateClientID(id); err == nil {
@@ -750,7 +734,7 @@ func (clients *clientsContainer) Update(name string, c *Client) (err error) {
}
// setWHOISInfo sets the WHOIS information for a client.
func (clients *clientsContainer) setWHOISInfo(ip net.IP, wi *RuntimeClientWHOISInfo) {
func (clients *clientsContainer) setWHOISInfo(ip netip.Addr, wi *RuntimeClientWHOISInfo) {
clients.lock.Lock()
defer clients.lock.Unlock()
@@ -760,7 +744,7 @@ func (clients *clientsContainer) setWHOISInfo(ip net.IP, wi *RuntimeClientWHOISI
return
}
rc, ok := clients.findRuntimeClientLocked(ip)
rc, ok := clients.ipToRC[ip]
if ok {
rc.WHOISInfo = wi
log.Debug("clients: set whois info for runtime client %s: %+v", rc.Host, wi)
@@ -776,32 +760,22 @@ func (clients *clientsContainer) setWHOISInfo(ip net.IP, wi *RuntimeClientWHOISI
rc.WHOISInfo = wi
// TODO(a.garipov): Remove once we switch to netip.Addr more fully.
ipAddr, err := netutil.IPToAddrNoMapped(ip)
if err != nil {
log.Error("clients: bad client ip %v: %s", ip, err)
return
}
clients.ipToRC[ipAddr] = rc
clients.ipToRC[ip] = rc
log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi)
}
// AddHost adds a new IP-hostname pairing. The priorities of the sources are
// taken into account. ok is true if the pairing was added.
func (clients *clientsContainer) AddHost(ip net.IP, host string, src clientSource) (ok bool, err error) {
func (clients *clientsContainer) AddHost(
ip netip.Addr,
host string,
src clientSource,
) (ok bool) {
clients.lock.Lock()
defer clients.lock.Unlock()
// TODO(a.garipov): Remove once we switch to netip.Addr more fully.
ipAddr, err := netutil.IPToAddrNoMapped(ip)
if err != nil {
return false, fmt.Errorf("adding host: %w", err)
}
return clients.addHostLocked(ipAddr, host, src), nil
return clients.addHostLocked(ip, host, src)
}
// addHostLocked adds a new IP-hostname pairing. clients.lock is expected to be

View File

@@ -22,8 +22,18 @@ func TestClients(t *testing.T) {
clients.Init(nil, nil, nil, nil)
t.Run("add_success", func(t *testing.T) {
var (
cliNone = "1.2.3.4"
cli1 = "1.1.1.1"
cli2 = "2.2.2.2"
cliNoneIP = netip.MustParseAddr(cliNone)
cli1IP = netip.MustParseAddr(cli1)
cli2IP = netip.MustParseAddr(cli2)
)
c := &Client{
IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa"},
IDs: []string{cli1, "1:2:3::4", "aa:aa:aa:aa:aa:aa"},
Name: "client1",
}
@@ -33,7 +43,7 @@ func TestClients(t *testing.T) {
assert.True(t, ok)
c = &Client{
IDs: []string{"2.2.2.2"},
IDs: []string{cli2},
Name: "client2",
}
@@ -42,7 +52,7 @@ func TestClients(t *testing.T) {
assert.True(t, ok)
c, ok = clients.Find("1.1.1.1")
c, ok = clients.Find(cli1)
require.True(t, ok)
assert.Equal(t, "client1", c.Name)
@@ -52,14 +62,14 @@ func TestClients(t *testing.T) {
assert.Equal(t, "client1", c.Name)
c, ok = clients.Find("2.2.2.2")
c, ok = clients.Find(cli2)
require.True(t, ok)
assert.Equal(t, "client2", c.Name)
assert.False(t, clients.exists(net.IP{1, 2, 3, 4}, ClientSourceHostsFile))
assert.True(t, clients.exists(net.IP{1, 1, 1, 1}, ClientSourceHostsFile))
assert.True(t, clients.exists(net.IP{2, 2, 2, 2}, ClientSourceHostsFile))
assert.False(t, clients.exists(cliNoneIP, ClientSourceHostsFile))
assert.True(t, clients.exists(cli1IP, ClientSourceHostsFile))
assert.True(t, clients.exists(cli2IP, ClientSourceHostsFile))
})
t.Run("add_fail_name", func(t *testing.T) {
@@ -103,23 +113,31 @@ func TestClients(t *testing.T) {
})
t.Run("update_success", func(t *testing.T) {
var (
cliOld = "1.1.1.1"
cliNew = "1.1.1.2"
cliOldIP = netip.MustParseAddr(cliOld)
cliNewIP = netip.MustParseAddr(cliNew)
)
err := clients.Update("client1", &Client{
IDs: []string{"1.1.1.2"},
IDs: []string{cliNew},
Name: "client1",
})
require.NoError(t, err)
assert.False(t, clients.exists(net.IP{1, 1, 1, 1}, ClientSourceHostsFile))
assert.True(t, clients.exists(net.IP{1, 1, 1, 2}, ClientSourceHostsFile))
assert.False(t, clients.exists(cliOldIP, ClientSourceHostsFile))
assert.True(t, clients.exists(cliNewIP, ClientSourceHostsFile))
err = clients.Update("client1", &Client{
IDs: []string{"1.1.1.2"},
IDs: []string{cliNew},
Name: "client1-renamed",
UseOwnSettings: true,
})
require.NoError(t, err)
c, ok := clients.Find("1.1.1.2")
c, ok := clients.Find(cliNew)
require.True(t, ok)
assert.Equal(t, "client1-renamed", c.Name)
@@ -132,14 +150,14 @@ func TestClients(t *testing.T) {
require.Len(t, c.IDs, 1)
assert.Equal(t, "1.1.1.2", c.IDs[0])
assert.Equal(t, cliNew, c.IDs[0])
})
t.Run("del_success", func(t *testing.T) {
ok := clients.Del("client1-renamed")
require.True(t, ok)
assert.False(t, clients.exists(net.IP{1, 1, 1, 2}, ClientSourceHostsFile))
assert.False(t, clients.exists(netip.MustParseAddr("1.1.1.2"), ClientSourceHostsFile))
})
t.Run("del_fail", func(t *testing.T) {
@@ -148,45 +166,33 @@ func TestClients(t *testing.T) {
})
t.Run("addhost_success", func(t *testing.T) {
ip := net.IP{1, 1, 1, 1}
ok, err := clients.AddHost(ip, "host", ClientSourceARP)
require.NoError(t, err)
ip := netip.MustParseAddr("1.1.1.1")
ok := clients.AddHost(ip, "host", ClientSourceARP)
assert.True(t, ok)
ok, err = clients.AddHost(ip, "host2", ClientSourceARP)
require.NoError(t, err)
ok = clients.AddHost(ip, "host2", ClientSourceARP)
assert.True(t, ok)
ok, err = clients.AddHost(ip, "host3", ClientSourceHostsFile)
require.NoError(t, err)
ok = clients.AddHost(ip, "host3", ClientSourceHostsFile)
assert.True(t, ok)
assert.True(t, clients.exists(ip, ClientSourceHostsFile))
})
t.Run("dhcp_replaces_arp", func(t *testing.T) {
ip := net.IP{1, 2, 3, 4}
ok, err := clients.AddHost(ip, "from_arp", ClientSourceARP)
require.NoError(t, err)
ip := netip.MustParseAddr("1.2.3.4")
ok := clients.AddHost(ip, "from_arp", ClientSourceARP)
assert.True(t, ok)
assert.True(t, clients.exists(ip, ClientSourceARP))
ok, err = clients.AddHost(ip, "from_dhcp", ClientSourceDHCP)
require.NoError(t, err)
ok = clients.AddHost(ip, "from_dhcp", ClientSourceDHCP)
assert.True(t, ok)
assert.True(t, clients.exists(ip, ClientSourceDHCP))
})
t.Run("addhost_fail", func(t *testing.T) {
ok, err := clients.AddHost(net.IP{1, 1, 1, 1}, "host1", ClientSourceRDNS)
require.NoError(t, err)
ip := netip.MustParseAddr("1.1.1.1")
ok := clients.AddHost(ip, "host1", ClientSourceRDNS)
assert.False(t, ok)
})
}
@@ -203,7 +209,7 @@ func TestClientsWHOIS(t *testing.T) {
t.Run("new_client", func(t *testing.T) {
ip := netip.MustParseAddr("1.1.1.255")
clients.setWHOISInfo(ip.AsSlice(), whois)
clients.setWHOISInfo(ip, whois)
rc := clients.ipToRC[ip]
require.NotNil(t, rc)
@@ -212,12 +218,10 @@ func TestClientsWHOIS(t *testing.T) {
t.Run("existing_auto-client", func(t *testing.T) {
ip := netip.MustParseAddr("1.1.1.1")
ok, err := clients.AddHost(ip.AsSlice(), "host", ClientSourceRDNS)
require.NoError(t, err)
ok := clients.AddHost(ip, "host", ClientSourceRDNS)
assert.True(t, ok)
clients.setWHOISInfo(ip.AsSlice(), whois)
clients.setWHOISInfo(ip, whois)
rc := clients.ipToRC[ip]
require.NotNil(t, rc)
@@ -234,7 +238,7 @@ func TestClientsWHOIS(t *testing.T) {
require.NoError(t, err)
assert.True(t, ok)
clients.setWHOISInfo(ip.AsSlice(), whois)
clients.setWHOISInfo(ip, whois)
rc := clients.ipToRC[ip]
require.Nil(t, rc)
@@ -249,7 +253,7 @@ func TestClientsAddExisting(t *testing.T) {
clients.Init(nil, nil, nil, nil)
t.Run("simple", func(t *testing.T) {
ip := net.IP{1, 1, 1, 1}
ip := netip.MustParseAddr("1.1.1.1")
// Add a client.
ok, err := clients.Add(&Client{
@@ -260,8 +264,7 @@ func TestClientsAddExisting(t *testing.T) {
assert.True(t, ok)
// Now add an auto-client with the same IP.
ok, err = clients.AddHost(ip, "test", ClientSourceRDNS)
require.NoError(t, err)
ok = clients.AddHost(ip, "test", ClientSourceRDNS)
assert.True(t, ok)
})

View File

@@ -3,8 +3,8 @@ package home
import (
"encoding/json"
"fmt"
"net"
"net/http"
"net/netip"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
)
@@ -47,8 +47,8 @@ type runtimeClientJSON struct {
WHOISInfo *RuntimeClientWHOISInfo `json:"whois_info"`
Name string `json:"name"`
IP netip.Addr `json:"ip"`
Source clientSource `json:"source"`
IP net.IP `json:"ip"`
}
type clientListJSON struct {
@@ -75,7 +75,7 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
Name: rc.Host,
Source: rc.Source,
IP: ip.AsSlice(),
IP: ip,
}
data.RuntimeClients = append(data.RuntimeClients, cj)
@@ -218,7 +218,7 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
break
}
ip := net.ParseIP(idStr)
ip, _ := netip.ParseAddr(idStr)
c, ok := clients.Find(idStr)
var cj *clientJSON
if !ok {
@@ -240,7 +240,7 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
// findRuntime looks up the IP in runtime and temporary storages, like
// /etc/hosts tables, DHCP leases, or blocklists. cj is guaranteed to be
// non-nil.
func (clients *clientsContainer) findRuntime(ip net.IP, idStr string) (cj *clientJSON) {
func (clients *clientsContainer) findRuntime(ip netip.Addr, idStr string) (cj *clientJSON) {
rc, ok := clients.findRuntimeClient(ip)
if !ok {
// It is still possible that the IP used to be in the runtime clients

View File

@@ -20,6 +20,7 @@ import (
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/timeutil"
"github.com/google/renameio/maybe"
"golang.org/x/exp/slices"
yaml "gopkg.in/yaml.v3"
)
@@ -113,8 +114,8 @@ type configuration struct {
// An active session is automatically refreshed once a day.
WebSessionTTLHours uint32 `yaml:"web_session_ttl"`
DNS dnsConfig `yaml:"dns"`
TLS tlsConfigSettings `yaml:"tls"`
DNS dnsConfig `yaml:"dns"`
TLS tlsConfiguration `yaml:"tls"`
// Filters reflects the filters from [filtering.Config]. It's cloned to the
// config used in the filtering module at the startup. Afterwards it's
@@ -199,7 +200,8 @@ type dnsConfig struct {
UseHTTP3Upstreams bool `yaml:"use_http3_upstreams"`
}
type tlsConfigSettings struct {
// tlsConfiguration is the on-disk TLS configuration.
type tlsConfiguration struct {
Enabled bool `yaml:"enabled" json:"enabled"` // Enabled is the encryption (DoT/DoH/HTTPS) status
ServerName string `yaml:"server_name" json:"server_name,omitempty"` // ServerName is the hostname of your HTTPS/TLS server
ForceHTTPS bool `yaml:"force_https" json:"force_https"` // ForceHTTPS: if true, forces HTTP->HTTPS redirect
@@ -223,6 +225,29 @@ type tlsConfigSettings struct {
dnsforward.TLSConfig `yaml:",inline" json:",inline"`
}
// cloneForEncoding returns a clone of c with all top-level fields of c and all
// exported and YAML-encoded fields of c.TLSConfig cloned.
//
// TODO(a.garipov): This is better than races, but still not good enough.
func (c *tlsConfiguration) cloneForEncoding() (cloned *tlsConfiguration) {
if c == nil {
return nil
}
v := *c
cloned = &v
cloned.TLSConfig = dnsforward.TLSConfig{
CertificateChain: c.CertificateChain,
PrivateKey: c.PrivateKey,
CertificatePath: c.CertificatePath,
PrivateKeyPath: c.PrivateKeyPath,
OverrideTLSCiphers: slices.Clone(c.OverrideTLSCiphers),
StrictSNICheck: c.StrictSNICheck,
}
return cloned
}
// config is the global configuration structure.
//
// TODO(a.garipov, e.burkov): This global is awful and must be removed.
@@ -273,7 +298,7 @@ var config = &configuration{
UpstreamTimeout: timeutil.Duration{Duration: dnsforward.DefaultTimeout},
UsePrivateRDNS: true,
},
TLS: tlsConfigSettings{
TLS: tlsConfiguration{
PortHTTPS: defaultPortHTTPS,
PortDNSOverTLS: defaultPortTLS, // needs to be passed through to dnsproxy
PortDNSOverQUIC: defaultPortQUIC,
@@ -442,7 +467,7 @@ func (c *configuration) write() (err error) {
}
if Context.tls != nil {
tlsConf := tlsConfigSettings{}
tlsConf := tlsConfiguration{}
Context.tls.WriteDiskConfig(&tlsConf)
config.TLS = tlsConf
}

View File

@@ -71,9 +71,7 @@ func appendDNSAddrsWithIfaces(dst []string, src []netip.Addr) (res []string, err
// on, including the addresses on all interfaces in cases of unspecified IPs.
func collectDNSAddresses() (addrs []string, err error) {
if hosts := config.DNS.BindHosts; len(hosts) == 0 {
addr := aghnet.IPv4Localhost()
addrs = appendDNSAddrs(addrs, addr)
addrs = appendDNSAddrs(addrs, netutil.IPv4Localhost())
} else {
addrs, err = appendDNSAddrsWithIfaces(addrs, hosts)
if err != nil {

View File

@@ -154,7 +154,7 @@ func (vr *versionResponse) setAllowedToAutoUpdate() (err error) {
return nil
}
tlsConf := &tlsConfigSettings{}
tlsConf := &tlsConfiguration{}
Context.tls.WriteDiskConfig(tlsConf)
canUpdate := true
@@ -172,7 +172,7 @@ func (vr *versionResponse) setAllowedToAutoUpdate() (err error) {
// tlsConfUsesPrivilegedPorts returns true if the provided TLS configuration
// indicates that privileged ports are used.
func tlsConfUsesPrivilegedPorts(c *tlsConfigSettings) (ok bool) {
func tlsConfUsesPrivilegedPorts(c *tlsConfiguration) (ok bool) {
return c.Enabled && (c.PortHTTPS < 1024 || c.PortDNSOverTLS < 1024 || c.PortDNSOverQUIC < 1024)
}

View File

@@ -8,6 +8,7 @@ import (
"os"
"path/filepath"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
@@ -150,8 +151,8 @@ func isRunning() bool {
}
func onDNSRequest(pctx *proxy.DNSContext) {
ip, _ := netutil.IPAndPortFromAddr(pctx.Addr)
if ip == nil {
ip := netutil.NetAddrToAddrPort(pctx.Addr).Addr()
if ip == (netip.Addr{}) {
// This would be quite weird if we get here.
return
}
@@ -160,7 +161,8 @@ func onDNSRequest(pctx *proxy.DNSContext) {
if srcs.RDNS && !ip.IsLoopback() {
Context.rdns.Begin(ip)
}
if srcs.WHOIS && !netutil.IsSpecialPurpose(ip) {
if srcs.WHOIS && !netutil.IsSpecialPurposeAddr(ip) {
Context.whois.Begin(ip)
}
}
@@ -193,11 +195,7 @@ func ipsToUDPAddrs(ips []netip.Addr, port int) (udpAddrs []*net.UDPAddr) {
func generateServerConfig() (newConf dnsforward.ServerConfig, err error) {
dnsConf := config.DNS
hosts := dnsConf.BindHosts
if len(hosts) == 0 {
hosts = []netip.Addr{aghnet.IPv4Localhost()}
}
hosts := aghalg.CoalesceSlice(dnsConf.BindHosts, []netip.Addr{netutil.IPv4Localhost()})
newConf = dnsforward.ServerConfig{
UDPListenAddrs: ipsToUDPAddrs(hosts, dnsConf.Port),
TCPListenAddrs: ipsToTCPAddrs(hosts, dnsConf.Port),
@@ -207,7 +205,7 @@ func generateServerConfig() (newConf dnsforward.ServerConfig, err error) {
OnDNSRequest: onDNSRequest,
}
tlsConf := tlsConfigSettings{}
tlsConf := tlsConfiguration{}
Context.tls.WriteDiskConfig(&tlsConf)
if tlsConf.Enabled {
newConf.TLSConfig = tlsConf.TLSConfig
@@ -252,7 +250,7 @@ func generateServerConfig() (newConf dnsforward.ServerConfig, err error) {
return newConf, nil
}
func newDNSCrypt(hosts []netip.Addr, tlsConf tlsConfigSettings) (dnscc dnsforward.DNSCryptConfig, err error) {
func newDNSCrypt(hosts []netip.Addr, tlsConf tlsConfiguration) (dnscc dnsforward.DNSCryptConfig, err error) {
if tlsConf.DNSCryptConfigFile == "" {
return dnscc, errors.Error("no dnscrypt_config_file")
}
@@ -290,7 +288,7 @@ type dnsEncryption struct {
}
func getDNSEncryption() (de dnsEncryption) {
tlsConf := tlsConfigSettings{}
tlsConf := tlsConfiguration{}
Context.tls.WriteDiskConfig(&tlsConf)
@@ -400,15 +398,12 @@ func startDNSServer() error {
const topClientsNumber = 100 // the number of clients to get
for _, ip := range Context.stats.TopClientsIP(topClientsNumber) {
if ip == nil {
continue
}
srcs := config.Clients.Sources
if srcs.RDNS && !ip.IsLoopback() {
Context.rdns.Begin(ip)
}
if srcs.WHOIS && !netutil.IsSpecialPurpose(ip) {
if srcs.WHOIS && !netutil.IsSpecialPurposeAddr(ip) {
Context.whois.Begin(ip)
}
}

View File

@@ -512,7 +512,7 @@ func run(opts options, clientBuildFS fs.FS) {
}
config.Users = nil
Context.tls, err = newTLSManager(config.TLS)
Context.tls, err = newTLSManager(&config.TLS)
if err != nil {
log.Fatalf("initializing tls: %s", err)
}
@@ -542,6 +542,11 @@ func run(opts options, clientBuildFS fs.FS) {
}
}
// TODO(a.garipov): This could be made much earlier and could be done on
// the first run as well, but to achieve this we need to bypass requests
// over dnsforward resolver.
cmdlineUpdate(opts)
Context.web.Start()
// wait indefinitely for other go-routines to complete their job
@@ -576,7 +581,7 @@ func checkPermissions() {
}
// We should check if AdGuard Home is able to bind to port 53
err := aghnet.CheckPort("tcp", netip.AddrPortFrom(aghnet.IPv4Localhost(), defaultPortDNS))
err := aghnet.CheckPort("tcp", netip.AddrPortFrom(netutil.IPv4Localhost(), defaultPortDNS))
if err != nil {
if errors.Is(err, os.ErrPermission) {
log.Fatal(`Permission check failed.
@@ -812,7 +817,7 @@ func printWebAddrs(proto, addr string, port, betaPort int) {
// printHTTPAddresses prints the IP addresses which user can use to access the
// admin interface. proto is either schemeHTTP or schemeHTTPS.
func printHTTPAddresses(proto string) {
tlsConf := tlsConfigSettings{}
tlsConf := tlsConfiguration{}
if Context.tls != nil {
Context.tls.WriteDiskConfig(&tlsConf)
}
@@ -927,3 +932,37 @@ type jsonError struct {
// Message is the error message, an opaque string.
Message string `json:"message"`
}
// cmdlineUpdate updates current application and exits.
func cmdlineUpdate(opts options) {
if !opts.performUpdate {
return
}
log.Info("starting update")
if Context.firstRun {
log.Info("update not allowed on first run")
os.Exit(0)
}
_, err := Context.updater.VersionInfo(true)
if err != nil {
vcu := Context.updater.VersionCheckURL()
log.Error("getting version info from %s: %s", vcu, err)
os.Exit(0)
}
if Context.updater.NewVersion() == "" {
log.Info("no updates available")
os.Exit(0)
}
err = Context.updater.Update()
fatalOnError(err)
os.Exit(0)
}

View File

@@ -32,7 +32,11 @@ func setupDNSIPs(t testing.TB) {
},
}
Context.tls = &tlsManager{}
var err error
Context.tls, err = newTLSManager(&tlsConfiguration{
Enabled: true,
})
require.NoError(t, err)
}
func TestHandleMobileConfigDoH(t *testing.T) {
@@ -65,7 +69,11 @@ func TestHandleMobileConfigDoH(t *testing.T) {
oldTLSConf := Context.tls
t.Cleanup(func() { Context.tls = oldTLSConf })
Context.tls = &tlsManager{conf: tlsConfigSettings{}}
var err error
Context.tls, err = newTLSManager(&tlsConfiguration{
Enabled: true,
})
require.NoError(t, err)
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/doh.mobileconfig", nil)
require.NoError(t, err)
@@ -137,7 +145,11 @@ func TestHandleMobileConfigDoT(t *testing.T) {
oldTLSConf := Context.tls
t.Cleanup(func() { Context.tls = oldTLSConf })
Context.tls = &tlsManager{conf: tlsConfigSettings{}}
var err error
Context.tls, err = newTLSManager(&tlsConfiguration{
Enabled: true,
})
require.NoError(t, err)
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/dot.mobileconfig", nil)
require.NoError(t, err)

View File

@@ -47,6 +47,9 @@ type options struct {
// disableUpdate, if set, makes AdGuard Home not check for updates.
disableUpdate bool
// performUpdate, if set, updates AdGuard Home without GUI and exits.
performUpdate bool
// verbose shows if verbose logging is enabled.
verbose bool
@@ -221,6 +224,14 @@ var cmdLineOpts = []cmdLineOpt{{
description: "Don't check for updates.",
longName: "no-check-update",
shortName: "",
}, {
updateWithValue: nil,
updateNoValue: func(o options) (options, error) { o.performUpdate = true; return o, nil },
effect: nil,
serialize: func(o options) (val string, ok bool) { return "", o.performUpdate },
description: "Update application and exit.",
longName: "update",
shortName: "",
}, {
updateWithValue: nil,
updateNoValue: nil,

View File

@@ -103,6 +103,11 @@ func TestParseDisableUpdate(t *testing.T) {
assert.True(t, testParseOK(t, "--no-check-update").disableUpdate, "--no-check-update is disable update")
}
func TestParsePerformUpdate(t *testing.T) {
assert.False(t, testParseOK(t).performUpdate, "empty is not perform update")
assert.True(t, testParseOK(t, "--update").performUpdate, "--update is perform update")
}
// TODO(e.burkov): Remove after v0.108.0.
func TestParseDisableMemoryOptimization(t *testing.T) {
o, eff, err := parseCmdOpts("", []string{"--no-mem-optimization"})
@@ -169,6 +174,10 @@ func TestOptsToArgs(t *testing.T) {
name: "disable_update",
args: []string{"--no-check-update"},
opts: options{disableUpdate: true},
}, {
name: "perform_update",
args: []string{"--update"},
opts: options{performUpdate: true},
}, {
name: "control_action",
args: []string{"-s", "run"},

View File

@@ -2,7 +2,7 @@ package home
import (
"encoding/binary"
"net"
"net/netip"
"sync/atomic"
"time"
@@ -21,7 +21,7 @@ type RDNS struct {
usePrivate uint32
// ipCh used to pass client's IP to rDNS workerLoop.
ipCh chan net.IP
ipCh chan netip.Addr
// ipCache caches the IP addresses to be resolved by rDNS. The resolved
// address stays here while it's inside clients. After leaving clients the
@@ -50,7 +50,7 @@ func NewRDNS(
EnableLRU: true,
MaxCount: defaultRDNSCacheSize,
}),
ipCh: make(chan net.IP, defaultRDNSIPChSize),
ipCh: make(chan netip.Addr, defaultRDNSIPChSize),
}
if usePrivate {
rDNS.usePrivate = 1
@@ -80,9 +80,10 @@ func (r *RDNS) ensurePrivateCache() {
// isCached returns true if ip is already cached and not expired yet. It also
// caches it otherwise.
func (r *RDNS) isCached(ip net.IP) (ok bool) {
func (r *RDNS) isCached(ip netip.Addr) (ok bool) {
ipBytes := ip.AsSlice()
now := uint64(time.Now().Unix())
if expire := r.ipCache.Get(ip); len(expire) != 0 {
if expire := r.ipCache.Get(ipBytes); len(expire) != 0 {
if binary.BigEndian.Uint64(expire) > now {
return true
}
@@ -91,14 +92,14 @@ func (r *RDNS) isCached(ip net.IP) (ok bool) {
// The cache entry either expired or doesn't exist.
ttl := make([]byte, 8)
binary.BigEndian.PutUint64(ttl, now+defaultRDNSCacheTTL)
r.ipCache.Set(ip, ttl)
r.ipCache.Set(ipBytes, ttl)
return false
}
// Begin adds the ip to the resolving queue if it is not cached or already
// resolved.
func (r *RDNS) Begin(ip net.IP) {
func (r *RDNS) Begin(ip netip.Addr) {
r.ensurePrivateCache()
if r.isCached(ip) || r.clients.exists(ip, ClientSourceRDNS) {
@@ -107,9 +108,9 @@ func (r *RDNS) Begin(ip net.IP) {
select {
case r.ipCh <- ip:
log.Tracef("rdns: %q added to queue", ip)
log.Debug("rdns: %q added to queue", ip)
default:
log.Tracef("rdns: queue is full")
log.Debug("rdns: queue is full")
}
}
@@ -119,7 +120,7 @@ func (r *RDNS) workerLoop() {
defer log.OnPanic("rdns")
for ip := range r.ipCh {
host, err := r.exchanger.Exchange(ip)
host, err := r.exchanger.Exchange(ip.AsSlice())
if err != nil {
log.Debug("rdns: resolving %q: %s", ip, err)
@@ -128,8 +129,6 @@ func (r *RDNS) workerLoop() {
continue
}
// Don't handle any errors since AddHost doesn't return non-nil errors
// for now.
_, _ = r.clients.AddHost(ip, host, ClientSourceRDNS)
_ = r.clients.AddHost(ip, host, ClientSourceRDNS)
}
}

View File

@@ -27,14 +27,14 @@ func TestRDNS_Begin(t *testing.T) {
w := &bytes.Buffer{}
aghtest.ReplaceLogWriter(t, w)
ip1234, ip1235 := net.IP{1, 2, 3, 4}, net.IP{1, 2, 3, 5}
ip1234, ip1235 := netip.MustParseAddr("1.2.3.4"), netip.MustParseAddr("1.2.3.5")
testCases := []struct {
cliIDIndex map[string]*Client
customChan chan net.IP
customChan chan netip.Addr
name string
wantLog string
req net.IP
ip netip.Addr
wantCacheHit int
wantCacheMiss int
}{{
@@ -42,7 +42,7 @@ func TestRDNS_Begin(t *testing.T) {
customChan: nil,
name: "cached",
wantLog: "",
req: ip1234,
ip: ip1234,
wantCacheHit: 1,
wantCacheMiss: 0,
}, {
@@ -50,7 +50,7 @@ func TestRDNS_Begin(t *testing.T) {
customChan: nil,
name: "not_cached",
wantLog: "rdns: queue is full",
req: ip1235,
ip: ip1235,
wantCacheHit: 0,
wantCacheMiss: 1,
}, {
@@ -58,15 +58,15 @@ func TestRDNS_Begin(t *testing.T) {
customChan: nil,
name: "already_in_clients",
wantLog: "",
req: ip1235,
ip: ip1235,
wantCacheHit: 0,
wantCacheMiss: 1,
}, {
cliIDIndex: map[string]*Client{},
customChan: make(chan net.IP, 1),
customChan: make(chan netip.Addr, 1),
name: "add_to_queue",
wantLog: `rdns: "1.2.3.5" added to queue`,
req: ip1235,
ip: ip1235,
wantCacheHit: 0,
wantCacheMiss: 1,
}}
@@ -102,7 +102,7 @@ func TestRDNS_Begin(t *testing.T) {
}
t.Run(tc.name, func(t *testing.T) {
rdns.Begin(tc.req)
rdns.Begin(tc.ip)
assert.Equal(t, tc.wantCacheHit, ipCache.Stats().Hit)
assert.Equal(t, tc.wantCacheMiss, ipCache.Stats().Miss)
assert.Contains(t, w.String(), tc.wantLog)
@@ -179,8 +179,8 @@ func TestRDNS_WorkerLoop(t *testing.T) {
w := &bytes.Buffer{}
aghtest.ReplaceLogWriter(t, w)
localIP := net.IP{192, 168, 1, 1}
revIPv4, err := netutil.IPToReversedAddr(localIP)
localIP := netip.MustParseAddr("192.168.1.1")
revIPv4, err := netutil.IPToReversedAddr(localIP.AsSlice())
require.NoError(t, err)
revIPv6, err := netutil.IPToReversedAddr(net.ParseIP("2a00:1450:400c:c06::93"))
@@ -201,24 +201,24 @@ func TestRDNS_WorkerLoop(t *testing.T) {
testCases := []struct {
ups upstream.Upstream
cliIP netip.Addr
wantLog string
name string
cliIP net.IP
}{{
ups: locUpstream,
cliIP: localIP,
wantLog: "",
name: "all_good",
cliIP: localIP,
}, {
ups: errUpstream,
cliIP: netip.MustParseAddr("192.168.1.2"),
wantLog: `rdns: resolving "192.168.1.2": test upstream error`,
name: "resolve_error",
cliIP: net.IP{192, 168, 1, 2},
}, {
ups: locUpstream,
cliIP: netip.MustParseAddr("2a00:1450:400c:c06::93"),
wantLog: "",
name: "ipv6_good",
cliIP: net.ParseIP("2a00:1450:400c:c06::93"),
}}
for _, tc := range testCases {
@@ -230,7 +230,7 @@ func TestRDNS_WorkerLoop(t *testing.T) {
ipToRC: map[netip.Addr]*RuntimeClient{},
allTags: stringutil.NewSet(),
}
ch := make(chan net.IP)
ch := make(chan netip.Addr)
rdns := &RDNS{
exchanger: &rDNSExchanger{
ex: tc.ups,

View File

@@ -8,42 +8,39 @@ import (
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"net/http"
"os"
"strings"
"sync"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/google/go-cmp/cmp"
)
// tlsManager contains the current configuration and state of AdGuard Home TLS
// encryption.
type tlsManager struct {
// status is the current status of the configuration. It is never nil.
status *tlsConfigStatus
// mu protects all fields.
mu *sync.RWMutex
// certLastMod is the last modification time of the certificate file.
certLastMod time.Time
confLock sync.Mutex
conf tlsConfigSettings
// status is the current status of the configuration. It is never nil.
status *tlsConfigStatus
// conf is the current TLS configuration.
conf *tlsConfiguration
}
// newTLSManager initializes the TLS configuration.
func newTLSManager(conf tlsConfigSettings) (m *tlsManager, err error) {
func newTLSManager(conf *tlsConfiguration) (m *tlsManager, err error) {
m = &tlsManager{
status: &tlsConfigStatus{},
mu: &sync.RWMutex{},
conf: conf,
}
@@ -59,9 +56,19 @@ func newTLSManager(conf tlsConfigSettings) (m *tlsManager, err error) {
return m, nil
}
// confForEncoding returns a partial clone of the current TLS configuration. It
// is safe for concurrent use.
func (m *tlsManager) confForEncoding() (conf *tlsConfiguration) {
m.mu.RLock()
defer m.mu.RUnlock()
return m.conf.cloneForEncoding()
}
// load reloads the TLS configuration from files or data from the config file.
// m.mu is expected to be locked for writing.
func (m *tlsManager) load() (err error) {
err = loadTLSConf(&m.conf, m.status)
err = loadTLSConf(m.conf, m.status)
if err != nil {
return fmt.Errorf("loading config: %w", err)
}
@@ -70,14 +77,12 @@ func (m *tlsManager) load() (err error) {
}
// WriteDiskConfig - write config
func (m *tlsManager) WriteDiskConfig(conf *tlsConfigSettings) {
m.confLock.Lock()
*conf = m.conf
m.confLock.Unlock()
func (m *tlsManager) WriteDiskConfig(conf *tlsConfiguration) {
*conf = *m.confForEncoding()
}
// setCertFileTime sets t.certLastMod from the certificate. If there are
// errors, setCertFileTime logs them.
// errors, setCertFileTime logs them. mu is expected to be locked for writing.
func (m *tlsManager) setCertFileTime() {
if len(m.conf.CertificatePath) == 0 {
return
@@ -97,27 +102,22 @@ func (m *tlsManager) setCertFileTime() {
func (m *tlsManager) start() {
m.registerWebHandlers()
m.confLock.Lock()
tlsConf := m.conf
m.confLock.Unlock()
// The background context is used because the TLSConfigChanged wraps context
// with timeout on its own and shuts down the server, which handles current
// request.
Context.web.TLSConfigChanged(context.Background(), tlsConf)
Context.web.TLSConfigChanged(context.Background(), m.confForEncoding())
}
// reload updates the configuration and restarts t.
// reload updates the configuration and restarts m.
func (m *tlsManager) reload() {
m.confLock.Lock()
tlsConf := m.conf
m.confLock.Unlock()
m.mu.Lock()
defer m.mu.Unlock()
if !tlsConf.Enabled || len(tlsConf.CertificatePath) == 0 {
if !m.conf.Enabled || len(m.conf.CertificatePath) == 0 {
return
}
fi, err := os.Stat(tlsConf.CertificatePath)
fi, err := os.Stat(m.conf.CertificatePath)
if err != nil {
log.Error("tls: %s", err)
@@ -132,9 +132,7 @@ func (m *tlsManager) reload() {
log.Debug("tls: certificate file is modified")
m.confLock.Lock()
err = m.load()
m.confLock.Unlock()
if err != nil {
log.Error("tls: reloading: %s", err)
@@ -145,19 +143,15 @@ func (m *tlsManager) reload() {
_ = reconfigureDNSServer()
m.confLock.Lock()
tlsConf = m.conf
m.confLock.Unlock()
// The background context is used because the TLSConfigChanged wraps context
// with timeout on its own and shuts down the server, which handles current
// request.
Context.web.TLSConfigChanged(context.Background(), tlsConf)
Context.web.TLSConfigChanged(context.Background(), m.conf)
}
// loadTLSConf loads and validates the TLS configuration. The returned error is
// also set in status.WarningValidation.
func loadTLSConf(tlsConf *tlsConfigSettings, status *tlsConfigStatus) (err error) {
func loadTLSConf(tlsConf *tlsConfiguration, status *tlsConfigStatus) (err error) {
defer func() {
if err != nil {
status.WarningValidation = err.Error()
@@ -172,13 +166,10 @@ func loadTLSConf(tlsConf *tlsConfigSettings, status *tlsConfigStatus) (err error
tlsConf.PrivateKeyData = []byte(tlsConf.PrivateKey)
if tlsConf.CertificatePath != "" {
if tlsConf.CertificateChain != "" {
return errors.Error("certificate data and file can't be set together")
}
tlsConf.CertificateChainData, err = os.ReadFile(tlsConf.CertificatePath)
err = loadCert(tlsConf)
if err != nil {
return fmt.Errorf("reading cert file: %w", err)
// Don't wrap the error, since it's informative enough as is.
return err
}
// Set status.ValidCert to true to signal the frontend that the
@@ -187,13 +178,10 @@ func loadTLSConf(tlsConf *tlsConfigSettings, status *tlsConfigStatus) (err error
}
if tlsConf.PrivateKeyPath != "" {
if tlsConf.PrivateKey != "" {
return errors.Error("private key data and file can't be set together")
}
tlsConf.PrivateKeyData, err = os.ReadFile(tlsConf.PrivateKeyPath)
err = loadPKey(tlsConf)
if err != nil {
return fmt.Errorf("reading key file: %w", err)
// Don't wrap the error, since it's informative enough as is.
return err
}
status.ValidKey = true
@@ -212,278 +200,29 @@ func loadTLSConf(tlsConf *tlsConfigSettings, status *tlsConfigStatus) (err error
return nil
}
// tlsConfigStatus contains the status of a certificate chain and key pair.
type tlsConfigStatus struct {
// Subject is the subject of the first certificate in the chain.
Subject string `json:"subject,omitempty"`
// Issuer is the issuer of the first certificate in the chain.
Issuer string `json:"issuer,omitempty"`
// KeyType is the type of the private key.
KeyType string `json:"key_type,omitempty"`
// NotBefore is the NotBefore field of the first certificate in the chain.
NotBefore time.Time `json:"not_before,omitempty"`
// NotAfter is the NotAfter field of the first certificate in the chain.
NotAfter time.Time `json:"not_after,omitempty"`
// WarningValidation is a validation warning message with the issue
// description.
WarningValidation string `json:"warning_validation,omitempty"`
// DNSNames is the value of SubjectAltNames field of the first certificate
// in the chain.
DNSNames []string `json:"dns_names"`
// ValidCert is true if the specified certificate chain is a valid chain of
// X509 certificates.
ValidCert bool `json:"valid_cert"`
// ValidChain is true if the specified certificate chain is verified and
// issued by a known CA.
ValidChain bool `json:"valid_chain"`
// ValidKey is true if the key is a valid private key.
ValidKey bool `json:"valid_key"`
// ValidPair is true if both certificate and private key are correct for
// each other.
ValidPair bool `json:"valid_pair"`
}
// tlsConfig is the TLS configuration and status response.
type tlsConfig struct {
*tlsConfigStatus `json:",inline"`
tlsConfigSettingsExt `json:",inline"`
}
// tlsConfigSettingsExt is used to (un)marshal the PrivateKeySaved field to
// ensure that clients don't send and receive previously saved private keys.
type tlsConfigSettingsExt struct {
tlsConfigSettings `json:",inline"`
// PrivateKeySaved is true if the private key is saved as a string and omit
// key from answer.
PrivateKeySaved bool `yaml:"-" json:"private_key_saved,inline"`
}
func (m *tlsManager) handleTLSStatus(w http.ResponseWriter, r *http.Request) {
m.confLock.Lock()
data := tlsConfig{
tlsConfigSettingsExt: tlsConfigSettingsExt{
tlsConfigSettings: m.conf,
},
tlsConfigStatus: m.status,
// loadCert loads the certificate from file, if necessary.
func loadCert(tlsConf *tlsConfiguration) (err error) {
if tlsConf.CertificateChain != "" {
return errors.Error("certificate data and file can't be set together")
}
m.confLock.Unlock()
marshalTLS(w, r, data)
}
func (m *tlsManager) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
setts, err := unmarshalTLS(r)
tlsConf.CertificateChainData, err = os.ReadFile(tlsConf.CertificatePath)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err)
return
return fmt.Errorf("reading cert file: %w", err)
}
if setts.PrivateKeySaved {
setts.PrivateKey = m.conf.PrivateKey
}
if setts.Enabled {
err = validatePorts(
tcpPort(config.BindPort),
tcpPort(config.BetaBindPort),
tcpPort(setts.PortHTTPS),
tcpPort(setts.PortDNSOverTLS),
tcpPort(setts.PortDNSCrypt),
udpPort(config.DNS.Port),
udpPort(setts.PortDNSOverQUIC),
)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return
}
}
if !webCheckPortAvailable(setts.PortHTTPS) {
aghhttp.Error(
r,
w,
http.StatusBadRequest,
"port %d is not available, cannot enable HTTPS on it",
setts.PortHTTPS,
)
return
}
// Skip the error check, since we are only interested in the value of
// status.WarningValidation.
status := &tlsConfigStatus{}
_ = loadTLSConf(&setts.tlsConfigSettings, status)
resp := tlsConfig{
tlsConfigSettingsExt: setts,
tlsConfigStatus: status,
}
marshalTLS(w, r, resp)
return nil
}
func (m *tlsManager) setConfig(newConf tlsConfigSettings, status *tlsConfigStatus) (restartHTTPS bool) {
m.confLock.Lock()
defer m.confLock.Unlock()
// Reset the DNSCrypt data before comparing, since we currently do not
// accept these from the frontend.
//
// TODO(a.garipov): Define a custom comparer for dnsforward.TLSConfig.
newConf.DNSCryptConfigFile = m.conf.DNSCryptConfigFile
newConf.PortDNSCrypt = m.conf.PortDNSCrypt
if !cmp.Equal(m.conf, newConf, cmp.AllowUnexported(dnsforward.TLSConfig{})) {
log.Info("tls config has changed, restarting https server")
restartHTTPS = true
} else {
log.Info("tls: config has not changed")
// loadPKey loads the private key from file, if necessary.
func loadPKey(tlsConf *tlsConfiguration) (err error) {
if tlsConf.PrivateKey != "" {
return errors.Error("private key data and file cannot be set together")
}
// Note: don't do just `t.conf = data` because we must preserve all other members of t.conf
m.conf.Enabled = newConf.Enabled
m.conf.ServerName = newConf.ServerName
m.conf.ForceHTTPS = newConf.ForceHTTPS
m.conf.PortHTTPS = newConf.PortHTTPS
m.conf.PortDNSOverTLS = newConf.PortDNSOverTLS
m.conf.PortDNSOverQUIC = newConf.PortDNSOverQUIC
m.conf.CertificateChain = newConf.CertificateChain
m.conf.CertificatePath = newConf.CertificatePath
m.conf.CertificateChainData = newConf.CertificateChainData
m.conf.PrivateKey = newConf.PrivateKey
m.conf.PrivateKeyPath = newConf.PrivateKeyPath
m.conf.PrivateKeyData = newConf.PrivateKeyData
m.status = status
return restartHTTPS
}
func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
req, err := unmarshalTLS(r)
tlsConf.PrivateKeyData, err = os.ReadFile(tlsConf.PrivateKeyPath)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err)
return
}
if req.PrivateKeySaved {
req.PrivateKey = m.conf.PrivateKey
}
if req.Enabled {
err = validatePorts(
tcpPort(config.BindPort),
tcpPort(config.BetaBindPort),
tcpPort(req.PortHTTPS),
tcpPort(req.PortDNSOverTLS),
tcpPort(req.PortDNSCrypt),
udpPort(config.DNS.Port),
udpPort(req.PortDNSOverQUIC),
)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return
}
}
// TODO(e.burkov): Investigate and perhaps check other ports.
if !webCheckPortAvailable(req.PortHTTPS) {
aghhttp.Error(
r,
w,
http.StatusBadRequest,
"port %d is not available, cannot enable https on it",
req.PortHTTPS,
)
return
}
status := &tlsConfigStatus{}
err = loadTLSConf(&req.tlsConfigSettings, status)
if err != nil {
resp := tlsConfig{
tlsConfigSettingsExt: req,
tlsConfigStatus: status,
}
marshalTLS(w, r, resp)
return
}
restartHTTPS := m.setConfig(req.tlsConfigSettings, status)
m.setCertFileTime()
onConfigModified()
err = reconfigureDNSServer()
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
return
}
resp := tlsConfig{
tlsConfigSettingsExt: req,
tlsConfigStatus: m.status,
}
marshalTLS(w, r, resp)
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
// The background context is used because the TLSConfigChanged wraps context
// with timeout on its own and shuts down the server, which handles current
// request. It is also should be done in a separate goroutine due to the
// same reason.
if restartHTTPS {
go func() {
Context.web.TLSConfigChanged(context.Background(), req.tlsConfigSettings)
}()
}
}
// validatePorts validates the uniqueness of TCP and UDP ports for AdGuard Home
// DNS protocols.
func validatePorts(
bindPort, betaBindPort, dohPort, dotPort, dnscryptTCPPort tcpPort,
dnsPort, doqPort udpPort,
) (err error) {
tcpPorts := aghalg.UniqChecker[tcpPort]{}
addPorts(
tcpPorts,
tcpPort(bindPort),
tcpPort(betaBindPort),
tcpPort(dohPort),
tcpPort(dotPort),
tcpPort(dnscryptTCPPort),
)
err = tcpPorts.Validate()
if err != nil {
return fmt.Errorf("validating tcp ports: %w", err)
}
udpPorts := aghalg.UniqChecker[udpPort]{}
addPorts(udpPorts, udpPort(dnsPort), udpPort(doqPort))
err = udpPorts.Validate()
if err != nil {
return fmt.Errorf("validating udp ports: %w", err)
return fmt.Errorf("reading key file: %w", err)
}
return nil
@@ -513,6 +252,11 @@ func validateCertChain(certs []*x509.Certificate, srvName string) (err error) {
return nil
}
// errNoIPInCert is the error that is returned from [parseCertChain] if the leaf
// certificate doesn't contain IPs.
const errNoIPInCert errors.Error = `certificates has no IP addresses; ` +
`DNS-over-TLS won't be advertised via DDR`
// parseCertChain parses the certificate chain from raw data, and returns it.
// If ok is true, the returned error, if any, is not critical.
func parseCertChain(chain []byte) (parsedCerts []*x509.Certificate, ok bool, err error) {
@@ -535,8 +279,7 @@ func parseCertChain(chain []byte) (parsedCerts []*x509.Certificate, ok bool, err
log.Info("tls: number of certs: %d", len(parsedCerts))
if !aghtls.CertificateHasIP(parsedCerts[0]) {
err = errors.Error(`certificate has no IP addresses` +
`, this may cause issues with DNS-over-TLS clients`)
err = errNoIPInCert
}
return parsedCerts, true, err
@@ -696,61 +439,3 @@ func parsePrivateKey(der []byte) (key crypto.PrivateKey, typ string, err error)
return nil, "", errors.Error("tls: failed to parse private key")
}
// unmarshalTLS handles base64-encoded certificates transparently
func unmarshalTLS(r *http.Request) (tlsConfigSettingsExt, error) {
data := tlsConfigSettingsExt{}
err := json.NewDecoder(r.Body).Decode(&data)
if err != nil {
return data, fmt.Errorf("failed to parse new TLS config json: %w", err)
}
if data.CertificateChain != "" {
var cert []byte
cert, err = base64.StdEncoding.DecodeString(data.CertificateChain)
if err != nil {
return data, fmt.Errorf("failed to base64-decode certificate chain: %w", err)
}
data.CertificateChain = string(cert)
if data.CertificatePath != "" {
return data, fmt.Errorf("certificate data and file can't be set together")
}
}
if data.PrivateKey != "" {
var key []byte
key, err = base64.StdEncoding.DecodeString(data.PrivateKey)
if err != nil {
return data, fmt.Errorf("failed to base64-decode private key: %w", err)
}
data.PrivateKey = string(key)
if data.PrivateKeyPath != "" {
return data, fmt.Errorf("private key data and file can't be set together")
}
}
return data, nil
}
func marshalTLS(w http.ResponseWriter, r *http.Request, data tlsConfig) {
if data.CertificateChain != "" {
encoded := base64.StdEncoding.EncodeToString([]byte(data.CertificateChain))
data.CertificateChain = encoded
}
if data.PrivateKey != "" {
data.PrivateKeySaved = true
data.PrivateKey = ""
}
_ = aghhttp.WriteJSONResponse(w, r, data)
}
// registerWebHandlers registers HTTP handlers for TLS configuration.
func (m *tlsManager) registerWebHandlers() {
httpRegister(http.MethodGet, "/control/tls/status", m.handleTLSStatus)
httpRegister(http.MethodPost, "/control/tls/configure", m.handleTLSConfigure)
httpRegister(http.MethodPost, "/control/tls/validate", m.handleTLSValidate)
}

362
internal/home/tlshttp.go Normal file
View File

@@ -0,0 +1,362 @@
package home
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/golibs/log"
"github.com/google/go-cmp/cmp"
)
// Encryption Settings HTTP API
// tlsConfigStatus contains the status of a certificate chain and key pair.
type tlsConfigStatus struct {
// Subject is the subject of the first certificate in the chain.
Subject string `json:"subject,omitempty"`
// Issuer is the issuer of the first certificate in the chain.
Issuer string `json:"issuer,omitempty"`
// KeyType is the type of the private key.
KeyType string `json:"key_type,omitempty"`
// NotBefore is the NotBefore field of the first certificate in the chain.
NotBefore time.Time `json:"not_before,omitempty"`
// NotAfter is the NotAfter field of the first certificate in the chain.
NotAfter time.Time `json:"not_after,omitempty"`
// WarningValidation is a validation warning message with the issue
// description.
WarningValidation string `json:"warning_validation,omitempty"`
// DNSNames is the value of SubjectAltNames field of the first certificate
// in the chain.
DNSNames []string `json:"dns_names"`
// ValidCert is true if the specified certificate chain is a valid chain of
// X509 certificates.
ValidCert bool `json:"valid_cert"`
// ValidChain is true if the specified certificate chain is verified and
// issued by a known CA.
ValidChain bool `json:"valid_chain"`
// ValidKey is true if the key is a valid private key.
ValidKey bool `json:"valid_key"`
// ValidPair is true if both certificate and private key are correct for
// each other.
ValidPair bool `json:"valid_pair"`
}
// tlsConfigResp is the TLS configuration and status response.
type tlsConfigResp struct {
*tlsConfigStatus
*tlsConfiguration
// PrivateKeySaved is true if the private key is saved as a string and omit
// key from answer.
PrivateKeySaved bool `yaml:"-" json:"private_key_saved"`
}
// tlsConfigReq is the TLS configuration request.
type tlsConfigReq struct {
tlsConfiguration
// PrivateKeySaved is true if the private key is saved as a string and omit
// key from answer.
PrivateKeySaved bool `yaml:"-" json:"private_key_saved"`
}
// handleTLSStatus is the handler for the GET /control/tls/status HTTP API.
func (m *tlsManager) handleTLSStatus(w http.ResponseWriter, r *http.Request) {
var resp *tlsConfigResp
func() {
m.mu.RLock()
defer m.mu.RUnlock()
resp = &tlsConfigResp{
tlsConfigStatus: m.status,
tlsConfiguration: m.conf.cloneForEncoding(),
}
}()
marshalTLS(w, r, resp)
}
// handleTLSValidate is the handler for the POST /control/tls/validate HTTP API.
func (m *tlsManager) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
req, err := unmarshalTLS(r)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err)
return
}
if req.PrivateKeySaved {
req.PrivateKey = m.confForEncoding().PrivateKey
}
if req.Enabled {
err = validatePorts(
tcpPort(config.BindPort),
tcpPort(config.BetaBindPort),
tcpPort(req.PortHTTPS),
tcpPort(req.PortDNSOverTLS),
tcpPort(req.PortDNSCrypt),
udpPort(config.DNS.Port),
udpPort(req.PortDNSOverQUIC),
)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return
}
}
if !webCheckPortAvailable(req.PortHTTPS) {
aghhttp.Error(
r,
w,
http.StatusBadRequest,
"port %d is not available, cannot enable HTTPS on it",
req.PortHTTPS,
)
return
}
resp := &tlsConfigResp{
tlsConfigStatus: &tlsConfigStatus{},
tlsConfiguration: &req.tlsConfiguration,
}
// Skip the error check, since we are only interested in the value of
// resl.tlsConfigStatus.WarningValidation.
_ = loadTLSConf(resp.tlsConfiguration, resp.tlsConfigStatus)
marshalTLS(w, r, resp)
}
// validatePorts validates the uniqueness of TCP and UDP ports for AdGuard Home
// DNS protocols.
func validatePorts(
bindPort, betaBindPort, dohPort, dotPort, dnscryptTCPPort tcpPort,
dnsPort, doqPort udpPort,
) (err error) {
tcpPorts := aghalg.UniqChecker[tcpPort]{}
addPorts(
tcpPorts,
tcpPort(bindPort),
tcpPort(betaBindPort),
tcpPort(dohPort),
tcpPort(dotPort),
tcpPort(dnscryptTCPPort),
)
err = tcpPorts.Validate()
if err != nil {
return fmt.Errorf("validating tcp ports: %w", err)
}
udpPorts := aghalg.UniqChecker[udpPort]{}
addPorts(udpPorts, udpPort(dnsPort), udpPort(doqPort))
err = udpPorts.Validate()
if err != nil {
return fmt.Errorf("validating udp ports: %w", err)
}
return nil
}
// handleTLSConfigure is the handler for the POST /control/tls/configure HTTP
// API.
func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
req, err := unmarshalTLS(r)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err)
return
}
if req.PrivateKeySaved {
req.PrivateKey = m.confForEncoding().PrivateKey
}
if req.Enabled {
err = validatePorts(
tcpPort(config.BindPort),
tcpPort(config.BetaBindPort),
tcpPort(req.PortHTTPS),
tcpPort(req.PortDNSOverTLS),
tcpPort(req.PortDNSCrypt),
udpPort(config.DNS.Port),
udpPort(req.PortDNSOverQUIC),
)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return
}
}
// TODO(e.burkov): Investigate and perhaps check other ports.
if !webCheckPortAvailable(req.PortHTTPS) {
aghhttp.Error(
r,
w,
http.StatusBadRequest,
"port %d is not available, cannot enable https on it",
req.PortHTTPS,
)
return
}
resp := &tlsConfigResp{
tlsConfigStatus: &tlsConfigStatus{},
tlsConfiguration: &req.tlsConfiguration,
}
err = loadTLSConf(resp.tlsConfiguration, resp.tlsConfigStatus)
if err != nil {
marshalTLS(w, r, resp)
return
}
restartRequired := m.setConf(resp)
onConfigModified()
err = reconfigureDNSServer()
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
return
}
resp.tlsConfiguration = m.confForEncoding()
marshalTLS(w, r, resp)
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
// The background context is used because the TLSConfigChanged wraps context
// with timeout on its own and shuts down the server, which handles current
// request. It is also should be done in a separate goroutine due to the
// same reason.
if restartRequired {
go func() {
Context.web.TLSConfigChanged(context.Background(), resp.tlsConfiguration)
}()
}
}
// setConf sets the necessary values from the new configuration.
func (m *tlsManager) setConf(newConf *tlsConfigResp) (restartRequired bool) {
m.mu.Lock()
defer m.mu.Unlock()
// Reset the DNSCrypt data before comparing, since we currently do not
// accept these from the frontend.
//
// TODO(a.garipov): Define a custom comparer for dnsforward.TLSConfig.
newConf.DNSCryptConfigFile = m.conf.DNSCryptConfigFile
newConf.PortDNSCrypt = m.conf.PortDNSCrypt
if !cmp.Equal(m.conf, newConf, cmp.AllowUnexported(dnsforward.TLSConfig{})) {
log.Info("tls: config has changed, restarting https server")
restartRequired = true
} else {
log.Info("tls: config has not changed")
}
// Do not just write "m.conf = *newConf.tlsConfiguration", because all other
// members of m.conf must be preserved.
m.conf.Enabled = newConf.Enabled
m.conf.ServerName = newConf.ServerName
m.conf.ForceHTTPS = newConf.ForceHTTPS
m.conf.PortHTTPS = newConf.PortHTTPS
m.conf.PortDNSOverTLS = newConf.PortDNSOverTLS
m.conf.PortDNSOverQUIC = newConf.PortDNSOverQUIC
m.conf.CertificateChain = newConf.CertificateChain
m.conf.CertificatePath = newConf.CertificatePath
m.conf.CertificateChainData = newConf.CertificateChainData
m.conf.PrivateKey = newConf.PrivateKey
m.conf.PrivateKeyPath = newConf.PrivateKeyPath
m.conf.PrivateKeyData = newConf.PrivateKeyData
m.setCertFileTime()
m.status = newConf.tlsConfigStatus
return restartRequired
}
// marshalTLS handles Base64-encoded certificates transparently.
func marshalTLS(w http.ResponseWriter, r *http.Request, conf *tlsConfigResp) {
if conf.CertificateChain != "" {
encoded := base64.StdEncoding.EncodeToString([]byte(conf.CertificateChain))
conf.CertificateChain = encoded
}
if conf.PrivateKey != "" {
conf.PrivateKeySaved = true
conf.PrivateKey = ""
}
_ = aghhttp.WriteJSONResponse(w, r, conf)
}
// unmarshalTLS handles Base64-encoded certificates transparently.
func unmarshalTLS(r *http.Request) (req *tlsConfigReq, err error) {
req = &tlsConfigReq{}
err = json.NewDecoder(r.Body).Decode(req)
if err != nil {
return nil, fmt.Errorf("parsing tls config: %w", err)
}
if req.CertificateChain != "" {
var cert []byte
cert, err = base64.StdEncoding.DecodeString(req.CertificateChain)
if err != nil {
return nil, fmt.Errorf("failed to base64-decode certificate chain: %w", err)
}
req.CertificateChain = string(cert)
if req.CertificatePath != "" {
return nil, fmt.Errorf("certificate data and file can't be set together")
}
}
if req.PrivateKey != "" {
var key []byte
key, err = base64.StdEncoding.DecodeString(req.PrivateKey)
if err != nil {
return nil, fmt.Errorf("failed to base64-decode private key: %w", err)
}
req.PrivateKey = string(key)
if req.PrivateKeyPath != "" {
return nil, fmt.Errorf("private key data and file can't be set together")
}
}
return req, nil
}
// registerWebHandlers registers HTTP handlers for TLS configuration.
func (m *tlsManager) registerWebHandlers() {
httpRegister(http.MethodGet, "/control/tls/status", m.handleTLSStatus)
httpRegister(http.MethodPost, "/control/tls/configure", m.handleTLSConfigure)
httpRegister(http.MethodPost, "/control/tls/validate", m.handleTLSValidate)
}

View File

@@ -143,7 +143,7 @@ func webCheckPortAvailable(port int) (ok bool) {
// TLSConfigChanged updates the TLS configuration and restarts the HTTPS server
// if necessary.
func (web *Web) TLSConfigChanged(ctx context.Context, tlsConf tlsConfigSettings) {
func (web *Web) TLSConfigChanged(ctx context.Context, tlsConf *tlsConfiguration) {
log.Debug("web: applying new tls configuration")
web.conf.PortHTTPS = tlsConf.PortHTTPS
web.forceHTTPS = (tlsConf.ForceHTTPS && tlsConf.Enabled && tlsConf.PortHTTPS != 0)

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"io"
"net"
"net/netip"
"strings"
"time"
@@ -26,7 +27,7 @@ const (
// WHOIS - module context
type WHOIS struct {
clients *clientsContainer
ipChan chan net.IP
ipChan chan netip.Addr
// dialContext specifies the dial function for creating unencrypted TCP
// connections.
@@ -51,7 +52,7 @@ func initWHOIS(clients *clientsContainer) *WHOIS {
MaxCount: 10000,
}),
dialContext: customDialContext,
ipChan: make(chan net.IP, 255),
ipChan: make(chan netip.Addr, 255),
}
go w.workerLoop()
@@ -192,7 +193,7 @@ func (w *WHOIS) queryAll(ctx context.Context, target string) (string, error) {
}
// Request WHOIS information
func (w *WHOIS) process(ctx context.Context, ip net.IP) (wi *RuntimeClientWHOISInfo) {
func (w *WHOIS) process(ctx context.Context, ip netip.Addr) (wi *RuntimeClientWHOISInfo) {
resp, err := w.queryAll(ctx, ip.String())
if err != nil {
log.Debug("whois: error: %s IP:%s", err, ip)
@@ -220,24 +221,25 @@ func (w *WHOIS) process(ctx context.Context, ip net.IP) (wi *RuntimeClientWHOISI
}
// Begin - begin requesting WHOIS info
func (w *WHOIS) Begin(ip net.IP) {
func (w *WHOIS) Begin(ip netip.Addr) {
ipBytes := ip.AsSlice()
now := uint64(time.Now().Unix())
expire := w.ipAddrs.Get([]byte(ip))
expire := w.ipAddrs.Get(ipBytes)
if len(expire) != 0 {
exp := binary.BigEndian.Uint64(expire)
if exp > now {
return
}
// TTL expired
}
expire = make([]byte, 8)
binary.BigEndian.PutUint64(expire, now+whoisTTL)
_ = w.ipAddrs.Set([]byte(ip), expire)
_ = w.ipAddrs.Set(ipBytes, expire)
log.Debug("whois: adding %s", ip)
select {
case w.ipChan <- ip:
//
default:
log.Debug("whois: queue is full")
}

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"io"
"net"
"net/netip"
"os"
"sync"
"sync/atomic"
@@ -64,7 +65,7 @@ type Interface interface {
// GetTopClientIP returns at most limit IP addresses corresponding to the
// clients with the most number of requests.
TopClientsIP(limit uint) []net.IP
TopClientsIP(limit uint) []netip.Addr
// WriteDiskConfig puts the Interface's configuration to the dc.
WriteDiskConfig(dc *DiskConfig)
@@ -107,8 +108,6 @@ type StatsCtx struct {
filename string
}
var _ Interface = &StatsCtx{}
// New creates s from conf and properly initializes it. Don't use s before
// calling it's Start method.
func New(conf Config) (s *StatsCtx, err error) {
@@ -178,6 +177,9 @@ func withRecovered(orig *error) {
*orig = errors.WithDeferred(*orig, err)
}
// type check
var _ Interface = (*StatsCtx)(nil)
// Start implements the Interface interface for *StatsCtx.
func (s *StatsCtx) Start() {
s.initWeb()
@@ -250,8 +252,8 @@ func (s *StatsCtx) WriteDiskConfig(dc *DiskConfig) {
dc.Interval = atomic.LoadUint32(&s.limitHours) / 24
}
// TopClientsIP implements the Interface interface for *StatsCtx.
func (s *StatsCtx) TopClientsIP(maxCount uint) (ips []net.IP) {
// TopClientsIP implements the [Interface] interface for *StatsCtx.
func (s *StatsCtx) TopClientsIP(maxCount uint) (ips []netip.Addr) {
limit := atomic.LoadUint32(&s.limitHours)
if limit == 0 {
return nil
@@ -271,10 +273,10 @@ func (s *StatsCtx) TopClientsIP(maxCount uint) (ips []net.IP) {
}
a := convertMapToSlice(m, int(maxCount))
ips = []net.IP{}
ips = []netip.Addr{}
for _, it := range a {
ip := net.ParseIP(it.Name)
if ip != nil {
ip, err := netip.ParseAddr(it.Name)
if err == nil {
ips = append(ips, ip)
}
}

View File

@@ -11,6 +11,7 @@ import (
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/stats"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -45,7 +46,7 @@ func assertSuccessAndUnmarshal(t *testing.T, to any, handler http.Handler, req *
}
func TestStats(t *testing.T) {
cliIP := net.IP{127, 0, 0, 1}
cliIP := netutil.IPv4Localhost()
cliIPStr := cliIP.String()
handlers := map[string]http.Handler{}
@@ -123,7 +124,7 @@ func TestStats(t *testing.T) {
topClients := s.TopClientsIP(2)
require.NotEmpty(t, topClients)
assert.True(t, cliIP.Equal(topClients[0]))
assert.Equal(t, cliIP, topClients[0])
})
t.Run("reset", func(t *testing.T) {