Compare commits

...

7 Commits

Author SHA1 Message Date
Eugene Burkov
de837e4eec all: use netip for web 2022-09-29 14:56:37 +03:00
Eugene Burkov
e528d2f23b add basic lla 2022-09-28 20:30:40 +03:00
Eugene Burkov
47c9c946a3 Pull request: 4871 imp filtering
Merge in DNS/adguard-home from 4871-imp-filtering to master

Closes #4871.

Squashed commit of the following:

commit 618e7c558447703c114332708c94ef1b34362cf9
Merge: 41ff8ab7 11e4f091
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Thu Sep 22 19:27:08 2022 +0300

    Merge branch 'master' into 4871-imp-filtering

commit 41ff8ab755a87170e7334dedcae00f01dcca238a
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Thu Sep 22 19:26:11 2022 +0300

    filtering: imp code, log

commit e4ae1d1788406ffd7ef0fcc6df896a22b0c2db37
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Thu Sep 22 14:11:07 2022 +0300

    filtering: move handlers into single func

commit f7a340b4c10980f512ae935a156f02b0133a1627
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Wed Sep 21 19:21:09 2022 +0300

    all: imp code

commit e064bf4d3de0283e4bda2aaf5b9822bb8a08f4a6
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Tue Sep 20 20:12:16 2022 +0300

    all: imp name

commit e7eda3905762f0821e1be1ac3cf77e0ecbedeff4
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Tue Sep 20 17:51:23 2022 +0300

    all: finally get rid of filtering

commit 188550d873e625cc2951583bb3a2eaad036745f5
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Tue Sep 20 17:36:03 2022 +0300

    filtering: merge refresh

commit e54ed9c7952b17e66b790c835269b28fbc26f9ca
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Tue Sep 20 17:16:23 2022 +0300

    filtering: merge filters

commit 32da31b754a319487d5f9d5e81e607d349b90180
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Tue Sep 20 14:48:13 2022 +0300

    filtering: imp docs

commit 43b0cafa7a27bb9b620c2ba50ccdddcf32cfcecc
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Tue Sep 20 14:38:04 2022 +0300

    all: imp code

commit 253a2ea6c92815d364546e34d631e406dd604644
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Mon Sep 19 20:43:15 2022 +0300

    filtering: rm important flag

commit 1b87f08f946389d410f13412c7e486290d5e752d
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Mon Sep 19 17:05:40 2022 +0300

    all: move filtering to the package

commit daa13499f1dd4fe475c4b75769e34f1eb0915bdf
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Mon Sep 19 15:13:55 2022 +0300

    all: finish merging

commit d6db75eb2e1f23528e9200ea51507eb793eefa3c
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Fri Sep 16 18:18:14 2022 +0300

    all: continue merging

commit 45b4c484deb7198a469aa18d719bb9dbe81e5d22
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Wed Sep 14 15:44:22 2022 +0300

    all: merge filtering types
2022-09-23 13:23:35 +03:00
Ainar Garipov
11e4f09165 Pull request: imp-scripts
Merge in DNS/adguard-home from imp-scripts to master

Squashed commit of the following:

commit ab63a8a2dd1b64287e00a2a6f747fd48b530709e
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Wed Sep 21 19:15:06 2022 +0300

    all: imp scripts; upd tools; doc
2022-09-21 19:21:13 +03:00
Ainar Garipov
c45c02de29 Pull request: imp-stalebot
Merge in DNS/adguard-home from imp-stalebot to master

Squashed commit of the following:

commit d1fb5c6da25eeb168c53abfc7af714827a5242cd
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Wed Sep 21 14:31:50 2022 +0300

    all: imp stalebot
2022-09-21 15:02:35 +03:00
Ainar Garipov
4fc045de11 Pull request: 4927-ddr-template
Updates #4927.

Squashed commit of the following:

commit 8cf080d5355261ced7e8b10de607cbf37e1d663d
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Tue Sep 20 15:18:48 2022 +0300

    dnsforward: fix doh template
2022-09-20 15:26:10 +03:00
Ildar Kamalov
cc2388e0c8 Pull request: 4815 fix query log modal on tablet
Updates #4815

Squashed commit of the following:

commit 148c39ac40963a593885b86a0c851b4010b68ab0
Merge: 3447611d ab6da05b
Author: Ildar Kamalov <ik@adguard.com>
Date:   Tue Sep 20 13:21:06 2022 +0300

    Merge branch 'master' into 4815-tablet-view

commit 3447611dc0b1c7d2cc1f8235d1c469dd92736166
Author: Ildar Kamalov <ik@adguard.com>
Date:   Fri Sep 16 17:01:05 2022 +0300

    client: fix query log modal on tablet
2022-09-20 13:48:57 +03:00
45 changed files with 1051 additions and 959 deletions

22
.github/stale.yml vendored
View File

@@ -4,15 +4,17 @@
'daysUntilClose': 15 'daysUntilClose': 15
# Issues with these labels will never be considered stale. # Issues with these labels will never be considered stale.
'exemptLabels': 'exemptLabels':
- 'bug' - 'bug'
- 'documentation' - 'documentation'
- 'enhancement' - 'enhancement'
- 'feature request' - 'feature request'
- 'help wanted' - 'help wanted'
- 'localization' - 'localization'
- 'needs investigation' - 'needs investigation'
- 'recurrent' - 'recurrent'
- 'research' - 'research'
# Set to true to ignore issues in a milestone.
'exemptMilestones': true
# Label to use when marking an issue as stale. # Label to use when marking an issue as stale.
'staleLabel': 'wontfix' 'staleLabel': 'wontfix'
# Comment to post when marking an issue as stale. Set to `false` to disable. # Comment to post when marking an issue as stale. Set to `false` to disable.
@@ -22,3 +24,5 @@
for your contributions. for your contributions.
# Comment to post when closing a stale issue. Set to `false` to disable. # Comment to post when closing a stale issue. Set to `false` to disable.
'closeComment': false 'closeComment': false
# Limit the number of actions per hour.
'limitPerRun': 1

View File

@@ -25,7 +25,12 @@ and this project adheres to
- Support for plain (unencrypted) HTTP/2 ([#4930]). This is useful for AdGuard - Support for plain (unencrypted) HTTP/2 ([#4930]). This is useful for AdGuard
Home installations behind a reverse proxy. Home installations behind a reverse proxy.
### Fixed
- Incorrect path template in DDR responses ([#4927]).
[#2993]: https://github.com/AdguardTeam/AdGuardHome/issues/2993 [#2993]: https://github.com/AdguardTeam/AdGuardHome/issues/2993
[#4927]: https://github.com/AdguardTeam/AdGuardHome/issues/4927
[#4930]: https://github.com/AdguardTeam/AdGuardHome/issues/4930 [#4930]: https://github.com/AdguardTeam/AdGuardHome/issues/4930

View File

@@ -62,7 +62,7 @@ const ClientCell = ({
'white-space--nowrap': isDetailed, 'white-space--nowrap': isDetailed,
}); });
const hintClass = classNames('icons mr-4 icon--24 icon--lightgray', { const hintClass = classNames('icons mr-4 icon--24 logs__question icon--lightgray', {
'my-3': isDetailed, 'my-3': isDetailed,
}); });

View File

@@ -34,7 +34,7 @@ const DomainCell = ({
'my-3': isDetailed, 'my-3': isDetailed,
}); });
const privacyIconClass = classNames('icons mx-2 icon--24 d-none d-sm-block', { const privacyIconClass = classNames('icons mx-2 icon--24 d-none d-sm-block logs__question', {
'icon--green': hasTracker, 'icon--green': hasTracker,
'icon--disabled': !hasTracker, 'icon--disabled': !hasTracker,
'my-3': isDetailed, 'my-3': isDetailed,

View File

@@ -49,6 +49,12 @@
padding-top: 1rem; padding-top: 1rem;
} }
@media (max-width: 1024px) {
.grid .key-colon, .grid .title--border {
font-weight: 600;
}
}
@media (max-width: 767.98px) { @media (max-width: 767.98px) {
.grid { .grid {
grid-template-columns: 35% 55%; grid-template-columns: 35% 55%;
@@ -70,10 +76,6 @@
grid-column: 2 / span 1; grid-column: 2 / span 1;
margin: 0 !important; margin: 0 !important;
} }
.grid .key-colon, .grid .title--border {
font-weight: 600;
}
} }
.grid .key-colon:nth-child(odd)::after { .grid .key-colon:nth-child(odd)::after {

View File

@@ -97,7 +97,7 @@ const ResponseCell = ({
return ( return (
<div className="logs__cell logs__cell--response" role="gridcell"> <div className="logs__cell logs__cell--response" role="gridcell">
<IconTooltip <IconTooltip
className={classNames('icons mr-4 icon--24 icon--lightgray', { 'my-3': isDetailed })} className={classNames('icons mr-4 icon--24 icon--lightgray logs__question', { 'my-3': isDetailed })}
columnClass='grid grid--limited' columnClass='grid grid--limited'
tooltipClass='px-5 pb-5 pt-4 mw-75 custom-tooltip__response-details' tooltipClass='px-5 pb-5 pt-4 mw-75 custom-tooltip__response-details'
contentItemClass='text-truncate key-colon o-hidden' contentItemClass='text-truncate key-colon o-hidden'

View File

@@ -485,3 +485,13 @@
.bg--green { .bg--green {
color: var(--green79); color: var(--green79);
} }
@media (max-width: 1024px) {
.logs__question {
display: none;
}
}
.logs__modal {
max-width: 720px;
}

View File

@@ -184,27 +184,34 @@ const Logs = () => {
setButtonType={setButtonType} setButtonType={setButtonType}
setModalOpened={setModalOpened} setModalOpened={setModalOpened}
/> />
<Modal portalClassName='grid' isOpen={isSmallScreen && isModalOpened} <Modal
onRequestClose={closeModal} portalClassName='grid'
style={{ isOpen={isSmallScreen && isModalOpened}
content: { onRequestClose={closeModal}
width: '100%', style={{
height: 'fit-content', content: {
left: 0, width: '100%',
top: 47, height: 'fit-content',
padding: '1rem 1.5rem 1rem', left: '50%',
}, top: 47,
overlay: { padding: '1rem 1.5rem 1rem',
backgroundColor: 'rgba(0,0,0,0.5)', maxWidth: '720px',
}, transform: 'translateX(-50%)',
}} },
overlay: {
backgroundColor: 'rgba(0,0,0,0.5)',
},
}}
> >
<svg <div className="logs__modal-wrap">
className="icon icon--24 icon-cross d-block d-md-none cursor--pointer" <svg
onClick={closeModal}> className="icon icon--24 icon-cross d-block cursor--pointer"
<use xlinkHref="#cross" /> onClick={closeModal}
</svg> >
{processContent(detailedDataCurrent, buttonType)} <use xlinkHref="#cross" />
</svg>
{processContent(detailedDataCurrent, buttonType)}
</div>
</Modal> </Modal>
</>; </>;

View File

@@ -9,6 +9,12 @@ import (
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
) )
// HTTP scheme constants.
const (
SchemeHTTP = "http"
SchemeHTTPS = "https"
)
// RegisterFunc is the function that sets the handler to handle the URL for the // RegisterFunc is the function that sets the handler to handle the URL for the
// method. // method.
// //

View File

@@ -18,27 +18,18 @@ import (
// How to test on a real Linux machine: // How to test on a real Linux machine:
// //
// 1. Run: // 1. Run "sudo ipset create example_set hash:ip family ipv4".
// //
// sudo ipset create example_set hash:ip family ipv4 // 2. Run "sudo ipset list example_set". The Members field should be empty.
// //
// 2. Run: // 3. Add the line "example.com/example_set" to your AdGuardHome.yaml.
// //
// sudo ipset list example_set // 4. Start AdGuardHome.
// //
// The Members field should be empty. // 5. Make requests to example.com and its subdomains.
// //
// 3. Add the line "example.com/example_set" to your AdGuardHome.yaml. // 6. Run "sudo ipset list example_set". The Members field should contain the
// // resolved IP addresses.
// 4. Start AdGuardHome.
//
// 5. Make requests to example.com and its subdomains.
//
// 6. Run:
//
// sudo ipset list example_set
//
// The Members field should contain the resolved IP addresses.
// newIpsetMgr returns a new Linux ipset manager. // newIpsetMgr returns a new Linux ipset manager.
func newIpsetMgr(ipsetConf []string) (set IpsetManager, err error) { func newIpsetMgr(ipsetConf []string) (set IpsetManager, err error) {

View File

@@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"net/netip"
"syscall" "syscall"
"github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/AdGuardHome/internal/aghos"
@@ -31,6 +32,12 @@ var (
// the IP being static is available. // the IP being static is available.
const ErrNoStaticIPInfo errors.Error = "no information about static ip" const ErrNoStaticIPInfo errors.Error = "no information about static ip"
// IPv4Localhost returns 127.0.0.1, which returns true for [netip.Addr.Is4].
func IPv4Localhost() (ip netip.Addr) { return netip.AddrFrom4([4]byte{0: 127, 3: 1}) }
// IPv6Localhost returns ::1, which returns true for [netip.Addr.Is6].
func IPv6Localhost() (ip netip.Addr) { return netip.AddrFrom16([16]byte{15: 1}) }
// IfaceHasStaticIP checks if interface is configured to have static IP address. // IfaceHasStaticIP checks if interface is configured to have static IP address.
// If it can't give a definitive answer, it returns false and an error for which // If it can't give a definitive answer, it returns false and an error for which
// errors.Is(err, ErrNoStaticIPInfo) is true. // errors.Is(err, ErrNoStaticIPInfo) is true.
@@ -47,26 +54,31 @@ func IfaceSetStaticIP(ifaceName string) (err error) {
// //
// TODO(e.burkov): Investigate if the gateway address may be fetched in another // TODO(e.burkov): Investigate if the gateway address may be fetched in another
// way since not every machine has the software installed. // way since not every machine has the software installed.
func GatewayIP(ifaceName string) (ip net.IP) { func GatewayIP(ifaceName string) (ip netip.Addr) {
code, out, err := aghosRunCommand("ip", "route", "show", "dev", ifaceName) code, out, err := aghosRunCommand("ip", "route", "show", "dev", ifaceName)
if err != nil { if err != nil {
log.Debug("%s", err) log.Debug("%s", err)
return nil return ip
} else if code != 0 { } else if code != 0 {
log.Debug("fetching gateway ip: unexpected exit code: %d", code) log.Debug("fetching gateway ip: unexpected exit code: %d", code)
return nil return ip
} }
fields := bytes.Fields(out) fields := bytes.Fields(out)
// The meaningful "ip route" command output should contain the word // The meaningful "ip route" command output should contain the word
// "default" at first field and default gateway IP address at third field. // "default" at first field and default gateway IP address at third field.
if len(fields) < 3 || string(fields[0]) != "default" { if len(fields) < 3 || string(fields[0]) != "default" {
return nil return ip
} }
return net.ParseIP(string(fields[2])) ip, err = netip.ParseAddr(string(fields[2]))
if err != nil {
return netip.Addr{}
}
return ip
} }
// CanBindPrivilegedPorts checks if current process can bind to privileged // CanBindPrivilegedPorts checks if current process can bind to privileged
@@ -78,9 +90,9 @@ func CanBindPrivilegedPorts() (can bool, err error) {
// NetInterface represents an entry of network interfaces map. // NetInterface represents an entry of network interfaces map.
type NetInterface struct { type NetInterface struct {
// Addresses are the network interface addresses. // Addresses are the network interface addresses.
Addresses []net.IP `json:"ip_addresses,omitempty"` Addresses []netip.Addr `json:"ip_addresses,omitempty"`
// Subnets are the IP networks for this network interface. // Subnets are the IP networks for this network interface.
Subnets []*net.IPNet `json:"-"` Subnets []netip.Prefix `json:"-"`
Name string `json:"name"` Name string `json:"name"`
HardwareAddr net.HardwareAddr `json:"hardware_address"` HardwareAddr net.HardwareAddr `json:"hardware_address"`
Flags net.Flags `json:"flags"` Flags net.Flags `json:"flags"`
@@ -101,57 +113,78 @@ func (iface NetInterface) MarshalJSON() ([]byte, error) {
}) })
} }
func NetInterfaceFrom(iface *net.Interface) (niface *NetInterface, err error) {
niface = &NetInterface{
Name: iface.Name,
HardwareAddr: iface.HardwareAddr,
Flags: iface.Flags,
MTU: iface.MTU,
}
addrs, err := iface.Addrs()
if err != nil {
return nil, fmt.Errorf("failed to get addresses for interface %s: %w", iface.Name, err)
}
// Collect network interface addresses.
for _, addr := range addrs {
n, ok := addr.(*net.IPNet)
if !ok {
// Should be *net.IPNet, this is weird.
return nil, fmt.Errorf("expected %[2]s to be %[1]T, got %[2]T", n, addr)
} else if ip4 := n.IP.To4(); ip4 != nil {
n.IP = ip4
}
ip, ok := netip.AddrFromSlice(n.IP)
if !ok {
return nil, fmt.Errorf("bad address %s", n.IP)
}
if ip.IsLinkLocalUnicast() {
// Ignore link-local IPv4.
if ip.Is4() {
continue
}
ip = ip.WithZone(iface.Name)
}
ones, _ := n.Mask.Size()
p := netip.PrefixFrom(ip, ones)
niface.Addresses = append(niface.Addresses, ip)
niface.Subnets = append(niface.Subnets, p)
}
return niface, nil
}
// GetValidNetInterfacesForWeb returns interfaces that are eligible for DNS and // GetValidNetInterfacesForWeb returns interfaces that are eligible for DNS and
// WEB only we do not return link-local addresses here. // WEB only we do not return link-local addresses here.
// //
// TODO(e.burkov): Can't properly test the function since it's nontrivial to // TODO(e.burkov): Can't properly test the function since it's nontrivial to
// substitute net.Interface.Addrs and the net.InterfaceAddrs can't be used. // substitute net.Interface.Addrs and the net.InterfaceAddrs can't be used.
func GetValidNetInterfacesForWeb() (netIfaces []*NetInterface, err error) { func GetValidNetInterfacesForWeb() (nifaces []*NetInterface, err error) {
ifaces, err := net.Interfaces() ifaces, err := net.Interfaces()
if err != nil { if err != nil {
return nil, fmt.Errorf("couldn't get interfaces: %w", err) return nil, fmt.Errorf("getting interfaces: %w", err)
} else if len(ifaces) == 0 { } else if len(ifaces) == 0 {
return nil, errors.Error("couldn't find any legible interface") return nil, errors.Error("no legible interfaces")
} }
for _, iface := range ifaces { for i := range ifaces {
var addrs []net.Addr var niface *NetInterface
addrs, err = iface.Addrs() niface, err = NetInterfaceFrom(&ifaces[i])
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get addresses for interface %s: %w", iface.Name, err) return nil, err
} } else if len(niface.Addresses) != 0 {
// Discard interfaces with no addresses.
netIface := &NetInterface{ nifaces = append(nifaces, niface)
MTU: iface.MTU,
Name: iface.Name,
HardwareAddr: iface.HardwareAddr,
Flags: iface.Flags,
}
// Collect network interface addresses.
for _, addr := range addrs {
ipNet, ok := addr.(*net.IPNet)
if !ok {
// Should be net.IPNet, this is weird.
return nil, fmt.Errorf("got %s that is not net.IPNet, it is %T", addr, addr)
}
// Ignore link-local.
if ipNet.IP.IsLinkLocalUnicast() {
continue
}
netIface.Addresses = append(netIface.Addresses, ipNet.IP)
netIface.Subnets = append(netIface.Subnets, ipNet)
}
// Discard interfaces with no addresses.
if len(netIface.Addresses) != 0 {
netIfaces = append(netIfaces, netIface)
} }
} }
return netIfaces, nil return nifaces, nil
} }
// InterfaceByIP returns the name of the interface bound to ip. // InterfaceByIP returns the name of the interface bound to ip.
@@ -160,7 +193,7 @@ func GetValidNetInterfacesForWeb() (netIfaces []*NetInterface, err error) {
// IP address can be shared by multiple interfaces in some configurations. // IP address can be shared by multiple interfaces in some configurations.
// //
// TODO(e.burkov): See TODO on GetValidNetInterfacesForWeb. // TODO(e.burkov): See TODO on GetValidNetInterfacesForWeb.
func InterfaceByIP(ip net.IP) (ifaceName string) { func InterfaceByIP(ip netip.Addr) (ifaceName string) {
ifaces, err := GetValidNetInterfacesForWeb() ifaces, err := GetValidNetInterfacesForWeb()
if err != nil { if err != nil {
return "" return ""
@@ -168,7 +201,7 @@ func InterfaceByIP(ip net.IP) (ifaceName string) {
for _, iface := range ifaces { for _, iface := range ifaces {
for _, addr := range iface.Addresses { for _, addr := range iface.Addresses {
if ip.Equal(addr) { if ip == addr {
return iface.Name return iface.Name
} }
} }
@@ -177,15 +210,16 @@ func InterfaceByIP(ip net.IP) (ifaceName string) {
return "" return ""
} }
// GetSubnet returns pointer to net.IPNet for the specified interface or nil if // GetSubnet returns the subnet corresponding to the interface of zero prefix if
// the search fails. // the search fails.
// //
// TODO(e.burkov): See TODO on GetValidNetInterfacesForWeb. // TODO(e.burkov): See TODO on GetValidNetInterfacesForWeb.
func GetSubnet(ifaceName string) *net.IPNet { func GetSubnet(ifaceName string) (p netip.Prefix) {
netIfaces, err := GetValidNetInterfacesForWeb() netIfaces, err := GetValidNetInterfacesForWeb()
if err != nil { if err != nil {
log.Error("Could not get network interfaces info: %v", err) log.Error("Could not get network interfaces info: %v", err)
return nil
return p
} }
for _, netIface := range netIfaces { for _, netIface := range netIfaces {
@@ -194,14 +228,14 @@ func GetSubnet(ifaceName string) *net.IPNet {
} }
} }
return nil return p
} }
// CheckPort checks if the port is available for binding. network is expected // CheckPort checks if the port is available for binding. network is expected
// to be one of "udp" and "tcp". // to be one of "udp" and "tcp".
func CheckPort(network string, ip net.IP, port int) (err error) { func CheckPort(network string, ipp netip.AddrPort) (err error) {
var c io.Closer var c io.Closer
addr := netutil.IPPort{IP: ip, Port: port}.String() addr := ipp.String()
switch network { switch network {
case "tcp": case "tcp":
c, err = net.Listen(network, addr) c, err = net.Listen(network, addr)

View File

@@ -6,7 +6,7 @@ import (
"bufio" "bufio"
"fmt" "fmt"
"io" "io"
"net" "net/netip"
"os" "os"
"strings" "strings"
@@ -151,7 +151,7 @@ func findIfaceLine(s *bufio.Scanner, name string) (ok bool) {
// interface through dhcpcd.conf. // interface through dhcpcd.conf.
func ifaceSetStaticIP(ifaceName string) (err error) { func ifaceSetStaticIP(ifaceName string) (err error) {
ipNet := GetSubnet(ifaceName) ipNet := GetSubnet(ifaceName)
if ipNet.IP == nil { if !ipNet.Addr().IsValid() {
return errors.Error("can't get IP address") return errors.Error("can't get IP address")
} }
@@ -174,7 +174,7 @@ func ifaceSetStaticIP(ifaceName string) (err error) {
// dhcpcdConfIface returns configuration lines for the dhcpdc.conf files that // dhcpcdConfIface returns configuration lines for the dhcpdc.conf files that
// configure the interface to have a static IP. // configure the interface to have a static IP.
func dhcpcdConfIface(ifaceName string, ipNet *net.IPNet, gwIP net.IP) (conf string) { func dhcpcdConfIface(ifaceName string, subnet netip.Prefix, gateway netip.Addr) (conf string) {
b := &strings.Builder{} b := &strings.Builder{}
stringutil.WriteToBuilder( stringutil.WriteToBuilder(
b, b,
@@ -183,15 +183,15 @@ func dhcpcdConfIface(ifaceName string, ipNet *net.IPNet, gwIP net.IP) (conf stri
" added by AdGuard Home.\ninterface ", " added by AdGuard Home.\ninterface ",
ifaceName, ifaceName,
"\nstatic ip_address=", "\nstatic ip_address=",
ipNet.String(), subnet.String(),
"\n", "\n",
) )
if gwIP != nil { if gateway.IsValid() {
stringutil.WriteToBuilder(b, "static routers=", gwIP.String(), "\n") stringutil.WriteToBuilder(b, "static routers=", gateway.String(), "\n")
} }
stringutil.WriteToBuilder(b, "static domain_name_servers=", ipNet.IP.String(), "\n\n") stringutil.WriteToBuilder(b, "static domain_name_servers=", subnet.Addr().String(), "\n\n")
return b.String() return b.String()
} }

View File

@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"io/fs" "io/fs"
"net" "net"
"net/netip"
"os" "os"
"strings" "strings"
"testing" "testing"
@@ -93,34 +94,34 @@ func TestGatewayIP(t *testing.T) {
const cmd = "ip route show dev " + ifaceName const cmd = "ip route show dev " + ifaceName
testCases := []struct { testCases := []struct {
name string
shell mapShell shell mapShell
want net.IP want netip.Addr
name string
}{{ }{{
name: "success_v4",
shell: theOnlyCmd(cmd, 0, `default via 1.2.3.4 onlink`, nil), shell: theOnlyCmd(cmd, 0, `default via 1.2.3.4 onlink`, nil),
want: net.IP{1, 2, 3, 4}.To16(), want: netip.AddrFrom4([4]byte{1, 2, 3, 4}),
name: "success_v4",
}, { }, {
name: "success_v6",
shell: theOnlyCmd(cmd, 0, `default via ::ffff onlink`, nil), shell: theOnlyCmd(cmd, 0, `default via ::ffff onlink`, nil),
want: net.IP{ want: netip.AddrFrom16([16]byte{
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
0x0, 0x0, 0xFF, 0xFF, 0x0, 0x0, 0xFF, 0xFF,
}, }),
name: "success_v6",
}, { }, {
name: "bad_output",
shell: theOnlyCmd(cmd, 0, `non-default via 1.2.3.4 onlink`, nil), shell: theOnlyCmd(cmd, 0, `non-default via 1.2.3.4 onlink`, nil),
want: nil, want: netip.Addr{},
name: "bad_output",
}, { }, {
name: "err_runcmd",
shell: theOnlyCmd(cmd, 0, "", errors.Error("can't run command")), shell: theOnlyCmd(cmd, 0, "", errors.Error("can't run command")),
want: nil, want: netip.Addr{},
name: "err_runcmd",
}, { }, {
name: "bad_code",
shell: theOnlyCmd(cmd, 1, "", nil), shell: theOnlyCmd(cmd, 1, "", nil),
want: nil, want: netip.Addr{},
name: "bad_code",
}} }}
for _, tc := range testCases { for _, tc := range testCases {
@@ -198,17 +199,21 @@ func TestBroadcastFromIPNet(t *testing.T) {
} }
func TestCheckPort(t *testing.T) { func TestCheckPort(t *testing.T) {
laddr := netip.AddrPortFrom(IPv4Localhost(), 0)
t.Run("tcp_bound", func(t *testing.T) { t.Run("tcp_bound", func(t *testing.T) {
l, err := net.Listen("tcp", "127.0.0.1:") l, err := net.Listen("tcp", laddr.String())
require.NoError(t, err) require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, l.Close) testutil.CleanupAndRequireSuccess(t, l.Close)
ipp := netutil.IPPortFromAddr(l.Addr()) addr := l.Addr()
require.NotNil(t, ipp) require.IsType(t, new(net.TCPAddr), addr)
require.NotNil(t, ipp.IP)
require.NotZero(t, ipp.Port)
err = CheckPort("tcp", ipp.IP, ipp.Port) ipp := addr.(*net.TCPAddr).AddrPort()
require.Equal(t, laddr.Addr(), ipp.Addr())
require.NotZero(t, ipp.Port())
err = CheckPort("tcp", ipp)
target := &net.OpError{} target := &net.OpError{}
require.ErrorAs(t, err, &target) require.ErrorAs(t, err, &target)
@@ -216,16 +221,18 @@ func TestCheckPort(t *testing.T) {
}) })
t.Run("udp_bound", func(t *testing.T) { t.Run("udp_bound", func(t *testing.T) {
conn, err := net.ListenPacket("udp", "127.0.0.1:") conn, err := net.ListenPacket("udp", laddr.String())
require.NoError(t, err) require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, conn.Close) testutil.CleanupAndRequireSuccess(t, conn.Close)
ipp := netutil.IPPortFromAddr(conn.LocalAddr()) addr := conn.LocalAddr()
require.NotNil(t, ipp) require.IsType(t, new(net.UDPAddr), addr)
require.NotNil(t, ipp.IP)
require.NotZero(t, ipp.Port)
err = CheckPort("udp", ipp.IP, ipp.Port) ipp := addr.(*net.UDPAddr).AddrPort()
require.Equal(t, laddr.Addr(), ipp.Addr())
require.NotZero(t, ipp.Port())
err = CheckPort("udp", ipp)
target := &net.OpError{} target := &net.OpError{}
require.ErrorAs(t, err, &target) require.ErrorAs(t, err, &target)
@@ -233,12 +240,12 @@ func TestCheckPort(t *testing.T) {
}) })
t.Run("bad_network", func(t *testing.T) { t.Run("bad_network", func(t *testing.T) {
err := CheckPort("bad_network", nil, 0) err := CheckPort("bad_network", netip.AddrPortFrom(netip.Addr{}, 0))
assert.NoError(t, err) assert.NoError(t, err)
}) })
t.Run("can_bind", func(t *testing.T) { t.Run("can_bind", func(t *testing.T) {
err := CheckPort("udp", net.IP{0, 0, 0, 0}, 0) err := CheckPort("udp", netip.AddrPortFrom(netip.IPv4Unspecified(), 0))
assert.NoError(t, err) assert.NoError(t, err)
}) })
} }
@@ -322,18 +329,18 @@ func TestNetInterface_MarshalJSON(t *testing.T) {
`"mtu":1500` + `"mtu":1500` +
`}` + "\n" `}` + "\n"
ip4, ip6 := net.IP{1, 2, 3, 4}, net.IP{0xAA, 0xAA, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1} ip4, ok := netip.AddrFromSlice([]byte{1, 2, 3, 4})
mask4, mask6 := net.CIDRMask(24, netutil.IPv4BitLen), net.CIDRMask(8, netutil.IPv6BitLen) require.True(t, ok)
ip6, ok := netip.AddrFromSlice([]byte{0xAA, 0xAA, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1})
require.True(t, ok)
net4 := netip.PrefixFrom(ip4, 24)
net6 := netip.PrefixFrom(ip6, 8)
iface := &NetInterface{ iface := &NetInterface{
Addresses: []net.IP{ip4, ip6}, Addresses: []netip.Addr{ip4, ip6},
Subnets: []*net.IPNet{{ Subnets: []netip.Prefix{net4, net6},
IP: ip4.Mask(mask4),
Mask: mask4,
}, {
IP: ip6.Mask(mask6),
Mask: mask6,
}},
Name: "iface0", Name: "iface0",
HardwareAddr: net.HardwareAddr{0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF}, HardwareAddr: net.HardwareAddr{0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF},
Flags: net.FlagUp | net.FlagMulticast, Flags: net.FlagUp | net.FlagMulticast,

View File

@@ -349,6 +349,8 @@ func (s *server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
return return
} }
// ignore link-local // ignore link-local
//
// TODO(e.burkov): Try to listen DHCP on LLA as well.
if ipnet.IP.IsLinkLocalUnicast() { if ipnet.IP.IsLinkLocalUnicast() {
continue continue
} }
@@ -359,7 +361,7 @@ func (s *server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
} }
} }
if len(jsonIface.Addrs4)+len(jsonIface.Addrs6) != 0 { if len(jsonIface.Addrs4)+len(jsonIface.Addrs6) != 0 {
jsonIface.GatewayIP = aghnet.GatewayIP(iface.Name) jsonIface.GatewayIP = aghnet.GatewayIP(iface.Name).AsSlice()
response[iface.Name] = jsonIface response[iface.Name] = jsonIface
} }
} }

View File

@@ -296,7 +296,7 @@ func (s *Server) makeDDRResponse(req *dns.Msg) (resp *dns.Msg) {
values := []dns.SVCBKeyValue{ values := []dns.SVCBKeyValue{
&dns.SVCBAlpn{Alpn: []string{"h2"}}, &dns.SVCBAlpn{Alpn: []string{"h2"}},
&dns.SVCBPort{Port: uint16(addr.Port)}, &dns.SVCBPort{Port: uint16(addr.Port)},
&dns.SVCBDoHPath{Template: "/dns-query?dns"}, &dns.SVCBDoHPath{Template: "/dns-query{?dns}"},
} }
ans := &dns.SVCB{ ans := &dns.SVCB{

View File

@@ -26,7 +26,7 @@ func TestServer_ProcessDDRQuery(t *testing.T) {
Value: []dns.SVCBKeyValue{ Value: []dns.SVCBKeyValue{
&dns.SVCBAlpn{Alpn: []string{"h2"}}, &dns.SVCBAlpn{Alpn: []string{"h2"}},
&dns.SVCBPort{Port: 8044}, &dns.SVCBPort{Port: 8044},
&dns.SVCBDoHPath{Template: "/dns-query?dns"}, &dns.SVCBDoHPath{Template: "/dns-query{?dns}"},
}, },
} }

View File

@@ -67,10 +67,11 @@ func createTestServer(
ID: 0, Data: []byte(rules), ID: 0, Data: []byte(rules),
}} }}
f := filtering.New(filterConf, filters) f, err := filtering.New(filterConf, filters)
require.NoError(t, err)
f.SetEnabled(true) f.SetEnabled(true)
var err error
s, err = NewServer(DNSCreateParams{ s, err = NewServer(DNSCreateParams{
DHCPServer: testDHCP, DHCPServer: testDHCP,
DNSFilter: f, DNSFilter: f,
@@ -774,7 +775,9 @@ func TestBlockedCustomIP(t *testing.T) {
Data: []byte(rules), Data: []byte(rules),
}} }}
f := filtering.New(&filtering.Config{}, filters) f, err := filtering.New(&filtering.Config{}, filters)
require.NoError(t, err)
s, err := NewServer(DNSCreateParams{ s, err := NewServer(DNSCreateParams{
DHCPServer: testDHCP, DHCPServer: testDHCP,
DNSFilter: f, DNSFilter: f,
@@ -906,7 +909,9 @@ func TestRewrite(t *testing.T) {
Type: dns.TypeCNAME, Type: dns.TypeCNAME,
}}, }},
} }
f := filtering.New(c, nil) f, err := filtering.New(c, nil)
require.NoError(t, err)
f.SetEnabled(true) f.SetEnabled(true)
s, err := NewServer(DNSCreateParams{ s, err := NewServer(DNSCreateParams{
@@ -1021,19 +1026,14 @@ var testDHCP = &dhcpd.MockInterface{
OnWriteDiskConfig: func(c *dhcpd.ServerConfig) { panic("not implemented") }, OnWriteDiskConfig: func(c *dhcpd.ServerConfig) { panic("not implemented") },
} }
// func (*testDHCP) Leases(flags dhcpd.GetLeasesFlags) (leases []*dhcpd.Lease) {
// return []*dhcpd.Lease{{
// IP: net.IP{192, 168, 12, 34},
// HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
// Hostname: "myhost",
// }}
// }
func TestPTRResponseFromDHCPLeases(t *testing.T) { func TestPTRResponseFromDHCPLeases(t *testing.T) {
const localDomain = "lan" const localDomain = "lan"
flt, err := filtering.New(&filtering.Config{}, nil)
require.NoError(t, err)
s, err := NewServer(DNSCreateParams{ s, err := NewServer(DNSCreateParams{
DNSFilter: filtering.New(&filtering.Config{}, nil), DNSFilter: flt,
DHCPServer: testDHCP, DHCPServer: testDHCP,
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed), PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
LocalDomain: localDomain, LocalDomain: localDomain,
@@ -1100,9 +1100,11 @@ func TestPTRResponseFromHosts(t *testing.T) {
assert.Equal(t, uint32(1), atomic.LoadUint32(&eventsCalledCounter)) assert.Equal(t, uint32(1), atomic.LoadUint32(&eventsCalledCounter))
}) })
flt := filtering.New(&filtering.Config{ flt, err := filtering.New(&filtering.Config{
EtcHosts: hc, EtcHosts: hc,
}, nil) }, nil)
require.NoError(t, err)
flt.SetEnabled(true) flt.SetEnabled(true)
var s *Server var s *Server

View File

@@ -35,7 +35,8 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
ID: 0, Data: []byte(rules), ID: 0, Data: []byte(rules),
}} }}
f := filtering.New(&filtering.Config{}, filters) f, err := filtering.New(&filtering.Config{}, filters)
require.NoError(t, err)
f.SetEnabled(true) f.SetEnabled(true)
s, err := NewServer(DNSCreateParams{ s, err := NewServer(DNSCreateParams{

View File

@@ -421,31 +421,34 @@ func initBlockedServices() {
} }
// BlockedSvcKnown - return TRUE if a blocked service name is known // BlockedSvcKnown - return TRUE if a blocked service name is known
func BlockedSvcKnown(s string) bool { func BlockedSvcKnown(s string) (ok bool) {
_, ok := serviceRules[s] _, ok = serviceRules[s]
return ok return ok
} }
// ApplyBlockedServices - set blocked services settings for this DNS request // ApplyBlockedServices - set blocked services settings for this DNS request
func (d *DNSFilter) ApplyBlockedServices(setts *Settings, list []string, global bool) { func (d *DNSFilter) ApplyBlockedServices(setts *Settings, list []string) {
setts.ServicesRules = []ServiceEntry{} setts.ServicesRules = []ServiceEntry{}
if global { if list == nil {
d.confLock.RLock() d.confLock.RLock()
defer d.confLock.RUnlock() defer d.confLock.RUnlock()
list = d.Config.BlockedServices list = d.Config.BlockedServices
} }
for _, name := range list { for _, name := range list {
rules, ok := serviceRules[name] rules, ok := serviceRules[name]
if !ok { if !ok {
log.Error("unknown service name: %s", name) log.Error("unknown service name: %s", name)
continue continue
} }
s := ServiceEntry{} setts.ServicesRules = append(setts.ServicesRules, ServiceEntry{
s.Name = name Name: name,
s.Rules = rules Rules: rules,
setts.ServicesRules = append(setts.ServicesRules, s) })
} }
} }
@@ -490,10 +493,3 @@ func (d *DNSFilter) handleBlockedServicesSet(w http.ResponseWriter, r *http.Requ
d.ConfigModified() d.ConfigModified()
} }
// registerBlockedServicesHandlers - register HTTP handlers
func (d *DNSFilter) registerBlockedServicesHandlers() {
d.Config.HTTPRegister(http.MethodGet, "/control/blocked_services/services", d.handleBlockedServicesAvailableServices)
d.Config.HTTPRegister(http.MethodGet, "/control/blocked_services/list", d.handleBlockedServicesList)
d.Config.HTTPRegister(http.MethodPost, "/control/blocked_services/set", d.handleBlockedServicesSet)
}

View File

@@ -1,4 +1,4 @@
package home package filtering
import ( import (
"encoding/json" "encoding/json"
@@ -34,7 +34,7 @@ func validateFilterURL(urlStr string) (err error) {
return fmt.Errorf("checking filter url: %w", err) return fmt.Errorf("checking filter url: %w", err)
} }
if s := url.Scheme; s != schemeHTTP && s != schemeHTTPS { if s := url.Scheme; s != aghhttp.SchemeHTTP && s != aghhttp.SchemeHTTPS {
return fmt.Errorf("checking filter url: invalid scheme %q", s) return fmt.Errorf("checking filter url: invalid scheme %q", s)
} }
@@ -47,7 +47,7 @@ type filterAddJSON struct {
Whitelist bool `json:"whitelist"` Whitelist bool `json:"whitelist"`
} }
func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request) { func (d *DNSFilter) handleFilteringAddURL(w http.ResponseWriter, r *http.Request) {
fj := filterAddJSON{} fj := filterAddJSON{}
err := json.NewDecoder(r.Body).Decode(&fj) err := json.NewDecoder(r.Body).Decode(&fj)
if err != nil { if err != nil {
@@ -65,14 +65,14 @@ func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request
} }
// Check for duplicates // Check for duplicates
if filterExists(fj.URL) { if d.filterExists(fj.URL) {
aghhttp.Error(r, w, http.StatusBadRequest, "Filter URL already added -- %s", fj.URL) aghhttp.Error(r, w, http.StatusBadRequest, "Filter URL already added -- %s", fj.URL)
return return
} }
// Set necessary properties // Set necessary properties
filt := filter{ filt := FilterYAML{
Enabled: true, Enabled: true,
URL: fj.URL, URL: fj.URL,
Name: fj.Name, Name: fj.Name,
@@ -81,7 +81,7 @@ func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request
filt.ID = assignUniqueFilterID() filt.ID = assignUniqueFilterID()
// Download the filter contents // Download the filter contents
ok, err := f.update(&filt) ok, err := d.update(&filt)
if err != nil { if err != nil {
aghhttp.Error( aghhttp.Error(
r, r,
@@ -109,14 +109,14 @@ func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request
// URL is assumed valid so append it to filters, update config, write new // URL is assumed valid so append it to filters, update config, write new
// file and reload it to engines. // file and reload it to engines.
if !filterAdd(filt) { if !d.filterAdd(filt) {
aghhttp.Error(r, w, http.StatusBadRequest, "Filter URL already added -- %s", filt.URL) aghhttp.Error(r, w, http.StatusBadRequest, "Filter URL already added -- %s", filt.URL)
return return
} }
onConfigModified() d.ConfigModified()
enableFilters(true) d.EnableFilters(true)
_, err = fmt.Fprintf(w, "OK %d rules\n", filt.RulesCount) _, err = fmt.Fprintf(w, "OK %d rules\n", filt.RulesCount)
if err != nil { if err != nil {
@@ -124,7 +124,7 @@ func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request
} }
} }
func (f *Filtering) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Request) { func (d *DNSFilter) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Request) {
type request struct { type request struct {
URL string `json:"url"` URL string `json:"url"`
Whitelist bool `json:"whitelist"` Whitelist bool `json:"whitelist"`
@@ -138,23 +138,23 @@ func (f *Filtering) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Requ
return return
} }
config.Lock() d.filtersMu.Lock()
filters := &config.Filters filters := &d.Filters
if req.Whitelist { if req.Whitelist {
filters = &config.WhitelistFilters filters = &d.WhitelistFilters
} }
var deleted filter var deleted FilterYAML
var newFilters []filter var newFilters []FilterYAML
for _, f := range *filters { for _, flt := range *filters {
if f.URL != req.URL { if flt.URL != req.URL {
newFilters = append(newFilters, f) newFilters = append(newFilters, flt)
continue continue
} }
deleted = f deleted = flt
path := f.Path() path := flt.Path(d.DataDir)
err = os.Rename(path, path+".old") err = os.Rename(path, path+".old")
if err != nil { if err != nil {
log.Error("deleting filter %q: %s", path, err) log.Error("deleting filter %q: %s", path, err)
@@ -162,10 +162,10 @@ func (f *Filtering) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Requ
} }
*filters = newFilters *filters = newFilters
config.Unlock() d.filtersMu.Unlock()
onConfigModified() d.ConfigModified()
enableFilters(true) d.EnableFilters(true)
// NOTE: The old files "filter.txt.old" aren't deleted. It's not really // NOTE: The old files "filter.txt.old" aren't deleted. It's not really
// necessary, but will require the additional complicated code to run // necessary, but will require the additional complicated code to run
@@ -191,55 +191,51 @@ type filterURLReq struct {
Whitelist bool `json:"whitelist"` Whitelist bool `json:"whitelist"`
} }
func (f *Filtering) handleFilteringSetURL(w http.ResponseWriter, r *http.Request) { func (d *DNSFilter) handleFilteringSetURL(w http.ResponseWriter, r *http.Request) {
fj := filterURLReq{} fj := filterURLReq{}
err := json.NewDecoder(r.Body).Decode(&fj) err := json.NewDecoder(r.Body).Decode(&fj)
if err != nil { if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "json decode: %s", err) aghhttp.Error(r, w, http.StatusBadRequest, "decoding request: %s", err)
return return
} }
if fj.Data == nil { if fj.Data == nil {
err = errors.Error("data cannot be null") aghhttp.Error(r, w, http.StatusBadRequest, "%s", errors.Error("data is absent"))
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return return
} }
err = validateFilterURL(fj.Data.URL) err = validateFilterURL(fj.Data.URL)
if err != nil { if err != nil {
err = fmt.Errorf("invalid url: %s", err) aghhttp.Error(r, w, http.StatusBadRequest, "invalid url: %s", err)
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return return
} }
filt := filter{ filt := FilterYAML{
Enabled: fj.Data.Enabled, Enabled: fj.Data.Enabled,
Name: fj.Data.Name, Name: fj.Data.Name,
URL: fj.Data.URL, URL: fj.Data.URL,
} }
status := f.filterSetProperties(fj.URL, filt, fj.Whitelist) status := d.filterSetProperties(fj.URL, filt, fj.Whitelist)
if (status & statusFound) == 0 { if (status & statusFound) == 0 {
http.Error(w, "URL doesn't exist", http.StatusBadRequest) aghhttp.Error(r, w, http.StatusBadRequest, "URL doesn't exist")
return return
} }
if (status & statusURLExists) != 0 { if (status & statusURLExists) != 0 {
http.Error(w, "URL already exists", http.StatusBadRequest) aghhttp.Error(r, w, http.StatusBadRequest, "URL already exists")
return return
} }
onConfigModified() d.ConfigModified()
restart := (status & statusEnabledChanged) != 0 restart := (status & statusEnabledChanged) != 0
if (status&statusUpdateRequired) != 0 && fj.Data.Enabled { if (status&statusUpdateRequired) != 0 && fj.Data.Enabled {
// download new filter and apply its rules // download new filter and apply its rules.
flags := filterRefreshBlocklists nUpdated := d.refreshFilters(!fj.Whitelist, fj.Whitelist, false)
if fj.Whitelist {
flags = filterRefreshAllowlists
}
nUpdated, _ := f.refreshFilters(flags, true)
// if at least 1 filter has been updated, refreshFilters() restarts the filtering automatically // if at least 1 filter has been updated, refreshFilters() restarts the filtering automatically
// if not - we restart the filtering ourselves // if not - we restart the filtering ourselves
restart = false restart = false
@@ -249,11 +245,11 @@ func (f *Filtering) handleFilteringSetURL(w http.ResponseWriter, r *http.Request
} }
if restart { if restart {
enableFilters(true) d.EnableFilters(true)
} }
} }
func (f *Filtering) handleFilteringSetRules(w http.ResponseWriter, r *http.Request) { func (d *DNSFilter) handleFilteringSetRules(w http.ResponseWriter, r *http.Request) {
// This use of ReadAll is safe, because request's body is now limited. // This use of ReadAll is safe, because request's body is now limited.
body, err := io.ReadAll(r.Body) body, err := io.ReadAll(r.Body)
if err != nil { if err != nil {
@@ -262,12 +258,12 @@ func (f *Filtering) handleFilteringSetRules(w http.ResponseWriter, r *http.Reque
return return
} }
config.UserRules = strings.Split(string(body), "\n") d.UserRules = strings.Split(string(body), "\n")
onConfigModified() d.ConfigModified()
enableFilters(true) d.EnableFilters(true)
} }
func (f *Filtering) handleFilteringRefresh(w http.ResponseWriter, r *http.Request) { func (d *DNSFilter) handleFilteringRefresh(w http.ResponseWriter, r *http.Request) {
type Req struct { type Req struct {
White bool `json:"whitelist"` White bool `json:"whitelist"`
} }
@@ -285,35 +281,27 @@ func (f *Filtering) handleFilteringRefresh(w http.ResponseWriter, r *http.Reques
return return
} }
flags := filterRefreshBlocklists var ok bool
if req.White { resp.Updated, _, ok = d.tryRefreshFilters(!req.White, req.White, true)
flags = filterRefreshAllowlists if !ok {
} aghhttp.Error(
func() { r,
// Temporarily unlock the Context.controlLock because the w,
// f.refreshFilters waits for it to be unlocked but it's http.StatusInternalServerError,
// actually locked in ensure wrapper. "filters update procedure is already running",
// )
// TODO(e.burkov): Reconsider this messy syncing process.
Context.controlLock.Unlock()
defer Context.controlLock.Lock()
resp.Updated, err = f.refreshFilters(flags|filterRefreshForce, false)
}()
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
return return
} }
js, err := json.Marshal(resp) w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(resp)
if err != nil { if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "json encode: %s", err) aghhttp.Error(r, w, http.StatusInternalServerError, "json encode: %s", err)
return return
} }
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write(js)
} }
type filterJSON struct { type filterJSON struct {
@@ -333,7 +321,7 @@ type filteringConfig struct {
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
} }
func filterToJSON(f filter) filterJSON { func filterToJSON(f FilterYAML) filterJSON {
fj := filterJSON{ fj := filterJSON{
ID: f.ID, ID: f.ID,
Enabled: f.Enabled, Enabled: f.Enabled,
@@ -350,21 +338,21 @@ func filterToJSON(f filter) filterJSON {
} }
// Get filtering configuration // Get filtering configuration
func (f *Filtering) handleFilteringStatus(w http.ResponseWriter, r *http.Request) { func (d *DNSFilter) handleFilteringStatus(w http.ResponseWriter, r *http.Request) {
resp := filteringConfig{} resp := filteringConfig{}
config.RLock() d.filtersMu.RLock()
resp.Enabled = config.DNS.FilteringEnabled resp.Enabled = d.FilteringEnabled
resp.Interval = config.DNS.FiltersUpdateIntervalHours resp.Interval = d.FiltersUpdateIntervalHours
for _, f := range config.Filters { for _, f := range d.Filters {
fj := filterToJSON(f) fj := filterToJSON(f)
resp.Filters = append(resp.Filters, fj) resp.Filters = append(resp.Filters, fj)
} }
for _, f := range config.WhitelistFilters { for _, f := range d.WhitelistFilters {
fj := filterToJSON(f) fj := filterToJSON(f)
resp.WhitelistFilters = append(resp.WhitelistFilters, fj) resp.WhitelistFilters = append(resp.WhitelistFilters, fj)
} }
resp.UserRules = config.UserRules resp.UserRules = d.UserRules
config.RUnlock() d.filtersMu.RUnlock()
jsonVal, err := json.Marshal(resp) jsonVal, err := json.Marshal(resp)
if err != nil { if err != nil {
@@ -380,7 +368,7 @@ func (f *Filtering) handleFilteringStatus(w http.ResponseWriter, r *http.Request
} }
// Set filtering configuration // Set filtering configuration
func (f *Filtering) handleFilteringConfig(w http.ResponseWriter, r *http.Request) { func (d *DNSFilter) handleFilteringConfig(w http.ResponseWriter, r *http.Request) {
req := filteringConfig{} req := filteringConfig{}
err := json.NewDecoder(r.Body).Decode(&req) err := json.NewDecoder(r.Body).Decode(&req)
if err != nil { if err != nil {
@@ -389,22 +377,22 @@ func (f *Filtering) handleFilteringConfig(w http.ResponseWriter, r *http.Request
return return
} }
if !checkFiltersUpdateIntervalHours(req.Interval) { if !ValidateUpdateIvl(req.Interval) {
aghhttp.Error(r, w, http.StatusBadRequest, "Unsupported interval") aghhttp.Error(r, w, http.StatusBadRequest, "Unsupported interval")
return return
} }
func() { func() {
config.Lock() d.filtersMu.Lock()
defer config.Unlock() defer d.filtersMu.Unlock()
config.DNS.FilteringEnabled = req.Enabled d.FilteringEnabled = req.Enabled
config.DNS.FiltersUpdateIntervalHours = req.Interval d.FiltersUpdateIntervalHours = req.Interval
}() }()
onConfigModified() d.ConfigModified()
enableFilters(true) d.EnableFilters(true)
} }
type checkHostRespRule struct { type checkHostRespRule struct {
@@ -435,15 +423,15 @@ type checkHostResp struct {
FilterID int64 `json:"filter_id"` FilterID int64 `json:"filter_id"`
} }
func (f *Filtering) handleCheckHost(w http.ResponseWriter, r *http.Request) { func (d *DNSFilter) handleCheckHost(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query() host := r.URL.Query().Get("name")
host := q.Get("name")
setts := Context.dnsFilter.GetConfig() setts := d.GetConfig()
setts.FilteringEnabled = true setts.FilteringEnabled = true
setts.ProtectionEnabled = true setts.ProtectionEnabled = true
Context.dnsFilter.ApplyBlockedServices(&setts, nil, true)
result, err := Context.dnsFilter.CheckHost(host, dns.TypeA, &setts) d.ApplyBlockedServices(&setts, nil)
result, err := d.CheckHost(host, dns.TypeA, &setts)
if err != nil { if err != nil {
aghhttp.Error( aghhttp.Error(
r, r,
@@ -457,18 +445,20 @@ func (f *Filtering) handleCheckHost(w http.ResponseWriter, r *http.Request) {
return return
} }
resp := checkHostResp{} rulesLen := len(result.Rules)
resp.Reason = result.Reason.String() resp := checkHostResp{
resp.SvcName = result.ServiceName Reason: result.Reason.String(),
resp.CanonName = result.CanonName SvcName: result.ServiceName,
resp.IPList = result.IPList CanonName: result.CanonName,
IPList: result.IPList,
Rules: make([]*checkHostRespRule, len(result.Rules)),
}
if len(result.Rules) > 0 { if rulesLen > 0 {
resp.FilterID = result.Rules[0].FilterListID resp.FilterID = result.Rules[0].FilterListID
resp.Rule = result.Rules[0].Text resp.Rule = result.Rules[0].Text
} }
resp.Rules = make([]*checkHostRespRule, len(result.Rules))
for i, r := range result.Rules { for i, r := range result.Rules {
resp.Rules[i] = &checkHostRespRule{ resp.Rules[i] = &checkHostRespRule{
FilterListID: r.FilterListID, FilterListID: r.FilterListID,
@@ -476,28 +466,51 @@ func (f *Filtering) handleCheckHost(w http.ResponseWriter, r *http.Request) {
} }
} }
js, err := json.Marshal(resp)
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "json encode: %s", err)
return
}
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
_, _ = w.Write(js) err = json.NewEncoder(w).Encode(resp)
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "encoding response: %s", err)
}
} }
// RegisterFilteringHandlers - register handlers // RegisterFilteringHandlers - register handlers
func (f *Filtering) RegisterFilteringHandlers() { func (d *DNSFilter) RegisterFilteringHandlers() {
httpRegister(http.MethodGet, "/control/filtering/status", f.handleFilteringStatus) registerHTTP := d.HTTPRegister
httpRegister(http.MethodPost, "/control/filtering/config", f.handleFilteringConfig) if registerHTTP == nil {
httpRegister(http.MethodPost, "/control/filtering/add_url", f.handleFilteringAddURL) return
httpRegister(http.MethodPost, "/control/filtering/remove_url", f.handleFilteringRemoveURL) }
httpRegister(http.MethodPost, "/control/filtering/set_url", f.handleFilteringSetURL)
httpRegister(http.MethodPost, "/control/filtering/refresh", f.handleFilteringRefresh) registerHTTP(http.MethodPost, "/control/safebrowsing/enable", d.handleSafeBrowsingEnable)
httpRegister(http.MethodPost, "/control/filtering/set_rules", f.handleFilteringSetRules) registerHTTP(http.MethodPost, "/control/safebrowsing/disable", d.handleSafeBrowsingDisable)
httpRegister(http.MethodGet, "/control/filtering/check_host", f.handleCheckHost) registerHTTP(http.MethodGet, "/control/safebrowsing/status", d.handleSafeBrowsingStatus)
registerHTTP(http.MethodPost, "/control/parental/enable", d.handleParentalEnable)
registerHTTP(http.MethodPost, "/control/parental/disable", d.handleParentalDisable)
registerHTTP(http.MethodGet, "/control/parental/status", d.handleParentalStatus)
registerHTTP(http.MethodPost, "/control/safesearch/enable", d.handleSafeSearchEnable)
registerHTTP(http.MethodPost, "/control/safesearch/disable", d.handleSafeSearchDisable)
registerHTTP(http.MethodGet, "/control/safesearch/status", d.handleSafeSearchStatus)
registerHTTP(http.MethodGet, "/control/rewrite/list", d.handleRewriteList)
registerHTTP(http.MethodPost, "/control/rewrite/add", d.handleRewriteAdd)
registerHTTP(http.MethodPost, "/control/rewrite/delete", d.handleRewriteDelete)
registerHTTP(http.MethodGet, "/control/blocked_services/services", d.handleBlockedServicesAvailableServices)
registerHTTP(http.MethodGet, "/control/blocked_services/list", d.handleBlockedServicesList)
registerHTTP(http.MethodPost, "/control/blocked_services/set", d.handleBlockedServicesSet)
registerHTTP(http.MethodGet, "/control/filtering/status", d.handleFilteringStatus)
registerHTTP(http.MethodPost, "/control/filtering/config", d.handleFilteringConfig)
registerHTTP(http.MethodPost, "/control/filtering/add_url", d.handleFilteringAddURL)
registerHTTP(http.MethodPost, "/control/filtering/remove_url", d.handleFilteringRemoveURL)
registerHTTP(http.MethodPost, "/control/filtering/set_url", d.handleFilteringSetURL)
registerHTTP(http.MethodPost, "/control/filtering/refresh", d.handleFilteringRefresh)
registerHTTP(http.MethodPost, "/control/filtering/set_rules", d.handleFilteringSetRules)
registerHTTP(http.MethodGet, "/control/filtering/check_host", d.handleCheckHost)
} }
func checkFiltersUpdateIntervalHours(i uint32) bool { // ValidateUpdateIvl returns false if i is not a valid filters update interval.
func ValidateUpdateIvl(i uint32) bool {
return i == 0 || i == 1 || i == 12 || i == 1*24 || i == 3*24 || i == 7*24 return i == 0 || i == 1 || i == 12 || i == 1*24 || i == 3*24 || i == 7*24
} }

View File

@@ -49,7 +49,7 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
|1.2.3.5.in-addr.arpa^$dnsrewrite=NOERROR;PTR;new-ptr-with-dot. |1.2.3.5.in-addr.arpa^$dnsrewrite=NOERROR;PTR;new-ptr-with-dot.
` `
f := newForTest(t, nil, []Filter{{ID: 0, Data: []byte(text)}}) f, _ := newForTest(t, nil, []Filter{{ID: 0, Data: []byte(text)}})
setts := &Settings{ setts := &Settings{
FilteringEnabled: true, FilteringEnabled: true,
} }

View File

@@ -1,4 +1,4 @@
package home package filtering
import ( import (
"bufio" "bufio"
@@ -8,63 +8,29 @@ import (
"net/http" "net/http"
"os" "os"
"path/filepath" "path/filepath"
"regexp"
"strconv" "strconv"
"strings" "strings"
"sync"
"sync/atomic"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/stringutil" "github.com/AdguardTeam/golibs/stringutil"
"golang.org/x/exp/slices"
) )
var nextFilterID = time.Now().Unix() // semi-stable way to generate an unique ID // filterDir is the subdirectory of a data directory to store downloaded
// filters.
const filterDir = "filters"
// Filtering - module object // nextFilterID is a way to seed a unique ID generation.
type Filtering struct { //
// conf FilteringConf // TODO(e.burkov): Use more deterministic approach.
refreshStatus uint32 // 0:none; 1:in progress var nextFilterID = time.Now().Unix()
refreshLock sync.Mutex
filterTitleRegexp *regexp.Regexp
}
// Init - initialize the module // FilterYAML respresents a filter list in the configuration file.
func (f *Filtering) Init() { //
f.filterTitleRegexp = regexp.MustCompile(`^! Title: +(.*)$`) // TODO(e.burkov): Investigate if the field oredering is important.
_ = os.MkdirAll(filepath.Join(Context.getDataDir(), filterDir), 0o755) type FilterYAML struct {
f.loadFilters(config.Filters)
f.loadFilters(config.WhitelistFilters)
deduplicateFilters()
updateUniqueFilterID(config.Filters)
updateUniqueFilterID(config.WhitelistFilters)
}
// Start - start the module
func (f *Filtering) Start() {
f.RegisterFilteringHandlers()
// Here we should start updating filters,
// but currently we can't wake up the periodic task to do so.
// So for now we just start this periodic task from here.
go f.periodicallyRefreshFilters()
}
// Close - close the module
func (f *Filtering) Close() {
}
func defaultFilters() []filter {
return []filter{
{Filter: filtering.Filter{ID: 1}, Enabled: true, URL: "https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt", Name: "AdGuard DNS filter"},
{Filter: filtering.Filter{ID: 2}, Enabled: false, URL: "https://adaway.org/hosts.txt", Name: "AdAway Default Blocklist"},
}
}
// field ordering is important -- yaml fields will mirror ordering from here
type filter struct {
Enabled bool Enabled bool
URL string // URL or a file path URL string // URL or a file path
Name string `yaml:"name"` Name string `yaml:"name"`
@@ -73,91 +39,108 @@ type filter struct {
checksum uint32 // checksum of the file data checksum uint32 // checksum of the file data
white bool white bool
filtering.Filter `yaml:",inline"` Filter `yaml:",inline"`
}
// Clear filter rules
func (filter *FilterYAML) unload() {
filter.RulesCount = 0
filter.checksum = 0
}
// Path to the filter contents
func (filter *FilterYAML) Path(dataDir string) string {
return filepath.Join(dataDir, filterDir, strconv.FormatInt(filter.ID, 10)+".txt")
} }
const ( const (
statusFound = 1 statusFound = 1 << iota
statusEnabledChanged = 2 statusEnabledChanged
statusURLChanged = 4 statusURLChanged
statusURLExists = 8 statusURLExists
statusUpdateRequired = 0x10 statusUpdateRequired
) )
// Update properties for a filter specified by its URL // Update properties for a filter specified by its URL
// Return status* flags. // Return status* flags.
func (f *Filtering) filterSetProperties(url string, newf filter, whitelist bool) int { func (d *DNSFilter) filterSetProperties(url string, newf FilterYAML, whitelist bool) int {
r := 0 r := 0
config.Lock() d.filtersMu.Lock()
defer config.Unlock() defer d.filtersMu.Unlock()
filters := &config.Filters filters := d.Filters
if whitelist { if whitelist {
filters = &config.WhitelistFilters filters = d.WhitelistFilters
} }
for i := range *filters { i := slices.IndexFunc(filters, func(filt FilterYAML) bool {
filt := &(*filters)[i] return filt.URL == url
if filt.URL != url { })
continue if i == -1 {
return 0
}
filt := &filters[i]
log.Debug("filter: set properties: %s: {%s %s %v}", filt.URL, newf.Name, newf.URL, newf.Enabled)
filt.Name = newf.Name
if filt.URL != newf.URL {
r |= statusURLChanged | statusUpdateRequired
if d.filterExistsNoLock(newf.URL) {
return statusURLExists
} }
log.Debug("filter: set properties: %s: {%s %s %v}", filt.URL = newf.URL
filt.URL, newf.Name, newf.URL, newf.Enabled) filt.unload()
filt.Name = newf.Name filt.LastUpdated = time.Time{}
filt.checksum = 0
filt.RulesCount = 0
}
if filt.URL != newf.URL { if filt.Enabled != newf.Enabled {
r |= statusURLChanged | statusUpdateRequired r |= statusEnabledChanged
if filterExistsNoLock(newf.URL) { filt.Enabled = newf.Enabled
return statusURLExists if filt.Enabled {
} if (r & statusURLChanged) == 0 {
filt.URL = newf.URL err := d.load(filt)
filt.unload() if err != nil {
filt.LastUpdated = time.Time{} // TODO(e.burkov): It seems the error is only returned when
filt.checksum = 0 // the file exists and couldn't be open. Investigate and
filt.RulesCount = 0 // improve.
} log.Error("loading filter %d: %s", filt.ID, err)
if filt.Enabled != newf.Enabled { filt.LastUpdated = time.Time{}
r |= statusEnabledChanged filt.checksum = 0
filt.Enabled = newf.Enabled filt.RulesCount = 0
if filt.Enabled { r |= statusUpdateRequired
if (r & statusURLChanged) == 0 {
e := f.load(filt)
if e != nil {
// This isn't a fatal error,
// because it may occur when someone removes the file from disk.
filt.LastUpdated = time.Time{}
filt.checksum = 0
filt.RulesCount = 0
r |= statusUpdateRequired
}
} }
} else {
filt.unload()
} }
} else {
filt.unload()
} }
return r | statusFound
} }
return 0
return r | statusFound
} }
// Return TRUE if a filter with this URL exists // Return TRUE if a filter with this URL exists
func filterExists(url string) bool { func (d *DNSFilter) filterExists(url string) bool {
config.RLock() d.filtersMu.RLock()
r := filterExistsNoLock(url) defer d.filtersMu.RUnlock()
config.RUnlock()
r := d.filterExistsNoLock(url)
return r return r
} }
func filterExistsNoLock(url string) bool { func (d *DNSFilter) filterExistsNoLock(url string) bool {
for _, f := range config.Filters { for _, f := range d.Filters {
if f.URL == url { if f.URL == url {
return true return true
} }
} }
for _, f := range config.WhitelistFilters { for _, f := range d.WhitelistFilters {
if f.URL == url { if f.URL == url {
return true return true
} }
@@ -167,26 +150,26 @@ func filterExistsNoLock(url string) bool {
// Add a filter // Add a filter
// Return FALSE if a filter with this URL exists // Return FALSE if a filter with this URL exists
func filterAdd(f filter) bool { func (d *DNSFilter) filterAdd(flt FilterYAML) bool {
config.Lock() d.filtersMu.Lock()
defer config.Unlock() defer d.filtersMu.Unlock()
// Check for duplicates // Check for duplicates
if filterExistsNoLock(f.URL) { if d.filterExistsNoLock(flt.URL) {
return false return false
} }
if f.white { if flt.white {
config.WhitelistFilters = append(config.WhitelistFilters, f) d.WhitelistFilters = append(d.WhitelistFilters, flt)
} else { } else {
config.Filters = append(config.Filters, f) d.Filters = append(d.Filters, flt)
} }
return true return true
} }
// Load filters from the disk // Load filters from the disk
// And if any filter has zero ID, assign a new one // And if any filter has zero ID, assign a new one
func (f *Filtering) loadFilters(array []filter) { func (d *DNSFilter) loadFilters(array []FilterYAML) {
for i := range array { for i := range array {
filter := &array[i] // otherwise we're operating on a copy filter := &array[i] // otherwise we're operating on a copy
if filter.ID == 0 { if filter.ID == 0 {
@@ -198,32 +181,30 @@ func (f *Filtering) loadFilters(array []filter) {
continue continue
} }
err := f.load(filter) err := d.load(filter)
if err != nil { if err != nil {
log.Error("Couldn't load filter %d contents due to %s", filter.ID, err) log.Error("Couldn't load filter %d contents due to %s", filter.ID, err)
} }
} }
} }
func deduplicateFilters() { func deduplicateFilters(filters []FilterYAML) (deduplicated []FilterYAML) {
// Deduplicate filters urls := stringutil.NewSet()
i := 0 // output index, used for deletion later lastIdx := 0
urls := map[string]bool{}
for _, filter := range config.Filters { for _, filter := range filters {
if _, ok := urls[filter.URL]; !ok { if !urls.Has(filter.URL) {
// we didn't see it before, keep it urls.Add(filter.URL)
urls[filter.URL] = true // remember the URL filters[lastIdx] = filter
config.Filters[i] = filter lastIdx++
i++
} }
} }
// all entries we want to keep are at front, delete the rest return filters[:lastIdx]
config.Filters = config.Filters[:i]
} }
// Set the next filter ID to max(filter.ID) + 1 // Set the next filter ID to max(filter.ID) + 1
func updateUniqueFilterID(filters []filter) { func updateUniqueFilterID(filters []FilterYAML) {
for _, filter := range filters { for _, filter := range filters {
if nextFilterID < filter.ID { if nextFilterID < filter.ID {
nextFilterID = filter.ID + 1 nextFilterID = filter.ID + 1
@@ -238,22 +219,19 @@ func assignUniqueFilterID() int64 {
} }
// Sets up a timer that will be checking for filters updates periodically // Sets up a timer that will be checking for filters updates periodically
func (f *Filtering) periodicallyRefreshFilters() { func (d *DNSFilter) periodicallyRefreshFilters() {
const maxInterval = 1 * 60 * 60 const maxInterval = 1 * 60 * 60
intval := 5 // use a dynamically increasing time interval intval := 5 // use a dynamically increasing time interval
for { for {
isNetworkErr := false isNetErr, ok := false, false
if config.DNS.FiltersUpdateIntervalHours != 0 && atomic.CompareAndSwapUint32(&f.refreshStatus, 0, 1) { if d.FiltersUpdateIntervalHours != 0 {
f.refreshLock.Lock() _, isNetErr, ok = d.tryRefreshFilters(true, true, false)
_, isNetworkErr = f.refreshFiltersIfNecessary(filterRefreshBlocklists | filterRefreshAllowlists) if ok && !isNetErr {
f.refreshLock.Unlock()
f.refreshStatus = 0
if !isNetworkErr {
intval = maxInterval intval = maxInterval
} }
} }
if isNetworkErr { if isNetErr {
intval *= 2 intval *= 2
if intval > maxInterval { if intval > maxInterval {
intval = maxInterval intval = maxInterval
@@ -264,51 +242,73 @@ func (f *Filtering) periodicallyRefreshFilters() {
} }
} }
// Refresh filters // tryRefreshFilters is like [refreshFilters], but backs down if the update is
// flags: filterRefresh* // already going on.
// important:
// //
// TRUE: ignore the fact that we're currently updating the filters // TODO(e.burkov): Get rid of the concurrency pattern which requires the
func (f *Filtering) refreshFilters(flags int, important bool) (int, error) { // sync.Mutex.TryLock.
set := atomic.CompareAndSwapUint32(&f.refreshStatus, 0, 1) func (d *DNSFilter) tryRefreshFilters(block, allow, force bool) (updated int, isNetworkErr, ok bool) {
if !important && !set { if ok = d.refreshLock.TryLock(); !ok {
return 0, fmt.Errorf("filters update procedure is already running") return 0, false, ok
} }
defer d.refreshLock.Unlock()
f.refreshLock.Lock() updated, isNetworkErr = d.refreshFiltersIntl(block, allow, force)
nUpdated, _ := f.refreshFiltersIfNecessary(flags)
f.refreshLock.Unlock() return updated, isNetworkErr, ok
f.refreshStatus = 0
return nUpdated, nil
} }
func (f *Filtering) refreshFiltersArray(filters *[]filter, force bool) (int, []filter, []bool, bool) { // refreshFilters updates the lists and returns the number of updated ones.
var updateFilters []filter // It's safe for concurrent use, but blocks at least until the previous
// refreshing is finished.
func (d *DNSFilter) refreshFilters(block, allow, force bool) (updated int) {
d.refreshLock.Lock()
defer d.refreshLock.Unlock()
updated, _ = d.refreshFiltersIntl(block, allow, force)
return updated
}
// listsToUpdate returns the slice of filter lists that could be updated.
func (d *DNSFilter) listsToUpdate(filters *[]FilterYAML, force bool) (toUpd []FilterYAML) {
now := time.Now()
d.filtersMu.RLock()
defer d.filtersMu.RUnlock()
for i := range *filters {
flt := &(*filters)[i] // otherwise we will be operating on a copy
log.Debug("checking list at index %d: %v", i, flt)
if !flt.Enabled {
continue
}
if !force {
exp := flt.LastUpdated.Add(time.Duration(d.FiltersUpdateIntervalHours) * time.Hour)
if now.Before(exp) {
continue
}
}
toUpd = append(toUpd, FilterYAML{
Filter: Filter{
ID: flt.ID,
},
URL: flt.URL,
Name: flt.Name,
checksum: flt.checksum,
})
}
return toUpd
}
func (d *DNSFilter) refreshFiltersArray(filters *[]FilterYAML, force bool) (int, []FilterYAML, []bool, bool) {
var updateFlags []bool // 'true' if filter data has changed var updateFlags []bool // 'true' if filter data has changed
now := time.Now() updateFilters := d.listsToUpdate(filters, force)
config.RLock()
for i := range *filters {
f := &(*filters)[i] // otherwise we will be operating on a copy
if !f.Enabled {
continue
}
expireTime := f.LastUpdated.Unix() + int64(config.DNS.FiltersUpdateIntervalHours)*60*60
if !force && expireTime > now.Unix() {
continue
}
var uf filter
uf.ID = f.ID
uf.URL = f.URL
uf.Name = f.Name
uf.checksum = f.checksum
updateFilters = append(updateFilters, uf)
}
config.RUnlock()
if len(updateFilters) == 0 { if len(updateFilters) == 0 {
return 0, nil, nil, false return 0, nil, nil, false
} }
@@ -316,7 +316,7 @@ func (f *Filtering) refreshFiltersArray(filters *[]filter, force bool) (int, []f
nfail := 0 nfail := 0
for i := range updateFilters { for i := range updateFilters {
uf := &updateFilters[i] uf := &updateFilters[i]
updated, err := f.update(uf) updated, err := d.update(uf)
updateFlags = append(updateFlags, updated) updateFlags = append(updateFlags, updated)
if err != nil { if err != nil {
nfail++ nfail++
@@ -334,7 +334,7 @@ func (f *Filtering) refreshFiltersArray(filters *[]filter, force bool) (int, []f
uf := &updateFilters[i] uf := &updateFilters[i]
updated := updateFlags[i] updated := updateFlags[i]
config.Lock() d.filtersMu.Lock()
for k := range *filters { for k := range *filters {
f := &(*filters)[k] f := &(*filters)[k]
if f.ID != uf.ID || f.URL != uf.URL { if f.ID != uf.ID || f.URL != uf.URL {
@@ -352,20 +352,14 @@ func (f *Filtering) refreshFiltersArray(filters *[]filter, force bool) (int, []f
f.checksum = uf.checksum f.checksum = uf.checksum
updateCount++ updateCount++
} }
config.Unlock() d.filtersMu.Unlock()
} }
return updateCount, updateFilters, updateFlags, false return updateCount, updateFilters, updateFlags, false
} }
const ( // refreshFiltersIntl checks filters and updates them if necessary. If force is
filterRefreshForce = 1 // ignore last file modification date // true, it ignores the filter.LastUpdated field value.
filterRefreshAllowlists = 2 // update allow-lists
filterRefreshBlocklists = 4 // update block-lists
)
// refreshFiltersIfNecessary checks filters and updates them if necessary. If
// force is true, it ignores the filter.LastUpdated field value.
// //
// Algorithm: // Algorithm:
// //
@@ -378,53 +372,49 @@ const (
// that this method works only on Unix systems. On Windows, don't pass // that this method works only on Unix systems. On Windows, don't pass
// files to filtering, pass the whole data. // files to filtering, pass the whole data.
// //
// refreshFiltersIfNecessary returns the number of updated filters. It also // refreshFiltersIntl returns the number of updated filters. It also returns
// returns true if there was a network error and nothing could be updated. // true if there was a network error and nothing could be updated.
// //
// TODO(a.garipov, e.burkov): What the hell? // TODO(a.garipov, e.burkov): What the hell?
func (f *Filtering) refreshFiltersIfNecessary(flags int) (int, bool) { func (d *DNSFilter) refreshFiltersIntl(block, allow, force bool) (int, bool) {
log.Debug("Filters: updating...") log.Debug("filtering: updating...")
updateCount := 0 updNum := 0
var updateFilters []filter var lists []FilterYAML
var updateFlags []bool var toUpd []bool
netError := false isNetErr := false
netErrorW := false
force := false if block {
if (flags & filterRefreshForce) != 0 { updNum, lists, toUpd, isNetErr = d.refreshFiltersArray(&d.Filters, force)
force = true
} }
if (flags & filterRefreshBlocklists) != 0 { if allow {
updateCount, updateFilters, updateFlags, netError = f.refreshFiltersArray(&config.Filters, force) updNumAl, listsAl, toUpdAl, isNetErrAl := d.refreshFiltersArray(&d.WhitelistFilters, force)
updNum += updNumAl
lists = append(lists, listsAl...)
toUpd = append(toUpd, toUpdAl...)
isNetErr = isNetErr || isNetErrAl
} }
if (flags & filterRefreshAllowlists) != 0 { if isNetErr {
updateCountW := 0
var updateFiltersW []filter
var updateFlagsW []bool
updateCountW, updateFiltersW, updateFlagsW, netErrorW = f.refreshFiltersArray(&config.WhitelistFilters, force)
updateCount += updateCountW
updateFilters = append(updateFilters, updateFiltersW...)
updateFlags = append(updateFlags, updateFlagsW...)
}
if netError && netErrorW {
return 0, true return 0, true
} }
if updateCount != 0 { if updNum != 0 {
enableFilters(false) d.EnableFilters(false)
for i := range updateFilters { for i := range lists {
uf := &updateFilters[i] uf := &lists[i]
updated := updateFlags[i] updated := toUpd[i]
if !updated { if !updated {
continue continue
} }
_ = os.Remove(uf.Path() + ".old") _ = os.Remove(uf.Path(d.DataDir) + ".old")
} }
} }
log.Debug("Filters: update finished") log.Debug("filtering: update finished")
return updateCount, false
return updNum, false
} }
// Allows printable UTF-8 text with CR, LF, TAB characters // Allows printable UTF-8 text with CR, LF, TAB characters
@@ -440,7 +430,7 @@ func isPrintableText(data []byte, len int) bool {
} }
// A helper function that parses filter contents and returns a number of rules and a filter name (if there's any) // A helper function that parses filter contents and returns a number of rules and a filter name (if there's any)
func (f *Filtering) parseFilterContents(file io.Reader) (int, uint32, string) { func (d *DNSFilter) parseFilterContents(file io.Reader) (int, uint32, string) {
rulesCount := 0 rulesCount := 0
name := "" name := ""
seenTitle := false seenTitle := false
@@ -455,7 +445,7 @@ func (f *Filtering) parseFilterContents(file io.Reader) (int, uint32, string) {
if len(line) == 0 { if len(line) == 0 {
// //
} else if line[0] == '!' { } else if line[0] == '!' {
m := f.filterTitleRegexp.FindAllStringSubmatch(line, -1) m := d.filterTitleRegexp.FindAllStringSubmatch(line, -1)
if len(m) > 0 && len(m[0]) >= 2 && !seenTitle { if len(m) > 0 && len(m[0]) >= 2 && !seenTitle {
name = m[0][1] name = m[0][1]
seenTitle = true seenTitle = true
@@ -476,11 +466,11 @@ func (f *Filtering) parseFilterContents(file io.Reader) (int, uint32, string) {
} }
// Perform upgrade on a filter and update LastUpdated value // Perform upgrade on a filter and update LastUpdated value
func (f *Filtering) update(filter *filter) (bool, error) { func (d *DNSFilter) update(filter *FilterYAML) (bool, error) {
b, err := f.updateIntl(filter) b, err := d.updateIntl(filter)
filter.LastUpdated = time.Now() filter.LastUpdated = time.Now()
if !b { if !b {
e := os.Chtimes(filter.Path(), filter.LastUpdated, filter.LastUpdated) e := os.Chtimes(filter.Path(d.DataDir), filter.LastUpdated, filter.LastUpdated)
if e != nil { if e != nil {
log.Error("os.Chtimes(): %v", e) log.Error("os.Chtimes(): %v", e)
} }
@@ -488,7 +478,7 @@ func (f *Filtering) update(filter *filter) (bool, error) {
return b, err return b, err
} }
func (f *Filtering) read(reader io.Reader, tmpFile *os.File, filter *filter) (int, error) { func (d *DNSFilter) read(reader io.Reader, tmpFile *os.File, filter *FilterYAML) (int, error) {
htmlTest := true htmlTest := true
firstChunk := make([]byte, 4*1024) firstChunk := make([]byte, 4*1024)
firstChunkLen := 0 firstChunkLen := 0
@@ -539,20 +529,20 @@ func (f *Filtering) read(reader io.Reader, tmpFile *os.File, filter *filter) (in
// finalizeUpdate closes and gets rid of temporary file f with filter's content // finalizeUpdate closes and gets rid of temporary file f with filter's content
// according to updated. It also saves new values of flt's name, rules number // according to updated. It also saves new values of flt's name, rules number
// and checksum if sucсeeded. // and checksum if sucсeeded.
func finalizeUpdate( func (d *DNSFilter) finalizeUpdate(
f *os.File, file *os.File,
flt *filter, flt *FilterYAML,
updated bool, updated bool,
name string, name string,
rnum int, rnum int,
cs uint32, cs uint32,
) (err error) { ) (err error) {
tmpFileName := f.Name() tmpFileName := file.Name()
// Close the file before renaming it because it's required on Windows. // Close the file before renaming it because it's required on Windows.
// //
// See https://github.com/adguardTeam/adGuardHome/issues/1553. // See https://github.com/adguardTeam/adGuardHome/issues/1553.
if err = f.Close(); err != nil { if err = file.Close(); err != nil {
return fmt.Errorf("closing temporary file: %w", err) return fmt.Errorf("closing temporary file: %w", err)
} }
@@ -562,9 +552,9 @@ func finalizeUpdate(
return os.Remove(tmpFileName) return os.Remove(tmpFileName)
} }
log.Printf("saving filter %d contents to: %s", flt.ID, flt.Path()) log.Printf("saving filter %d contents to: %s", flt.ID, flt.Path(d.DataDir))
if err = os.Rename(tmpFileName, flt.Path()); err != nil { if err = os.Rename(tmpFileName, flt.Path(d.DataDir)); err != nil {
return errors.WithDeferred(err, os.Remove(tmpFileName)) return errors.WithDeferred(err, os.Remove(tmpFileName))
} }
@@ -578,12 +568,12 @@ func finalizeUpdate(
// processUpdate copies filter's content from src to dst and returns the name, // processUpdate copies filter's content from src to dst and returns the name,
// rules number, and checksum for it. It also returns the number of bytes read // rules number, and checksum for it. It also returns the number of bytes read
// from src. // from src.
func (f *Filtering) processUpdate( func (d *DNSFilter) processUpdate(
src io.Reader, src io.Reader,
dst *os.File, dst *os.File,
flt *filter, flt *FilterYAML,
) (name string, rnum int, cs uint32, n int, err error) { ) (name string, rnum int, cs uint32, n int, err error) {
if n, err = f.read(src, dst, flt); err != nil { if n, err = d.read(src, dst, flt); err != nil {
return "", 0, 0, 0, err return "", 0, 0, 0, err
} }
@@ -591,14 +581,14 @@ func (f *Filtering) processUpdate(
return "", 0, 0, 0, err return "", 0, 0, 0, err
} }
rnum, cs, name = f.parseFilterContents(dst) rnum, cs, name = d.parseFilterContents(dst)
return name, rnum, cs, n, nil return name, rnum, cs, n, nil
} }
// updateIntl updates the flt rewriting it's actual file. It returns true if // updateIntl updates the flt rewriting it's actual file. It returns true if
// the actual update has been performed. // the actual update has been performed.
func (f *Filtering) updateIntl(flt *filter) (ok bool, err error) { func (d *DNSFilter) updateIntl(flt *FilterYAML) (ok bool, err error) {
log.Tracef("downloading update for filter %d from %s", flt.ID, flt.URL) log.Tracef("downloading update for filter %d from %s", flt.ID, flt.URL)
var name string var name string
@@ -606,12 +596,12 @@ func (f *Filtering) updateIntl(flt *filter) (ok bool, err error) {
var cs uint32 var cs uint32
var tmpFile *os.File var tmpFile *os.File
tmpFile, err = os.CreateTemp(filepath.Join(Context.getDataDir(), filterDir), "") tmpFile, err = os.CreateTemp(filepath.Join(d.DataDir, filterDir), "")
if err != nil { if err != nil {
return false, err return false, err
} }
defer func() { defer func() {
err = errors.WithDeferred(err, finalizeUpdate(tmpFile, flt, ok, name, rnum, cs)) err = errors.WithDeferred(err, d.finalizeUpdate(tmpFile, flt, ok, name, rnum, cs))
ok = ok && err == nil ok = ok && err == nil
if ok { if ok {
log.Printf("updated filter %d: %d bytes, %d rules", flt.ID, n, rnum) log.Printf("updated filter %d: %d bytes, %d rules", flt.ID, n, rnum)
@@ -638,7 +628,7 @@ func (f *Filtering) updateIntl(flt *filter) (ok bool, err error) {
r = file r = file
} else { } else {
var resp *http.Response var resp *http.Response
resp, err = Context.client.Get(flt.URL) resp, err = d.HTTPClient.Get(flt.URL)
if err != nil { if err != nil {
log.Printf("requesting filter from %s, skip: %s", flt.URL, err) log.Printf("requesting filter from %s, skip: %s", flt.URL, err)
@@ -655,16 +645,16 @@ func (f *Filtering) updateIntl(flt *filter) (ok bool, err error) {
r = resp.Body r = resp.Body
} }
name, rnum, cs, n, err = f.processUpdate(r, tmpFile, flt) name, rnum, cs, n, err = d.processUpdate(r, tmpFile, flt)
return cs != flt.checksum, err return cs != flt.checksum, err
} }
// loads filter contents from the file in dataDir // loads filter contents from the file in dataDir
func (f *Filtering) load(filter *filter) (err error) { func (d *DNSFilter) load(filter *FilterYAML) (err error) {
filterFilePath := filter.Path() filterFilePath := filter.Path(d.DataDir)
log.Tracef("filtering: loading filter %d contents to: %s", filter.ID, filterFilePath) log.Tracef("filtering: loading filter %d from %s", filter.ID, filterFilePath)
file, err := os.Open(filterFilePath) file, err := os.Open(filterFilePath)
if errors.Is(err, os.ErrNotExist) { if errors.Is(err, os.ErrNotExist) {
@@ -682,7 +672,7 @@ func (f *Filtering) load(filter *filter) (err error) {
log.Tracef("filtering: File %s, id %d, length %d", filterFilePath, filter.ID, st.Size()) log.Tracef("filtering: File %s, id %d, length %d", filterFilePath, filter.ID, st.Size())
rulesCount, checksum, _ := f.parseFilterContents(file) rulesCount, checksum, _ := d.parseFilterContents(file)
filter.RulesCount = rulesCount filter.RulesCount = rulesCount
filter.checksum = checksum filter.checksum = checksum
@@ -691,56 +681,45 @@ func (f *Filtering) load(filter *filter) (err error) {
return nil return nil
} }
// Clear filter rules func (d *DNSFilter) EnableFilters(async bool) {
func (filter *filter) unload() { d.filtersMu.RLock()
filter.RulesCount = 0 defer d.filtersMu.RUnlock()
filter.checksum = 0
d.enableFiltersLocked(async)
} }
// Path to the filter contents func (d *DNSFilter) enableFiltersLocked(async bool) {
func (filter *filter) Path() string { filters := []Filter{{
return filepath.Join(Context.getDataDir(), filterDir, strconv.FormatInt(filter.ID, 10)+".txt") ID: CustomListID,
} Data: []byte(strings.Join(d.UserRules, "\n")),
func enableFilters(async bool) {
config.RLock()
defer config.RUnlock()
enableFiltersLocked(async)
}
func enableFiltersLocked(async bool) {
filters := []filtering.Filter{{
ID: filtering.CustomListID,
Data: []byte(strings.Join(config.UserRules, "\n")),
}} }}
for _, filter := range config.Filters { for _, filter := range d.Filters {
if !filter.Enabled { if !filter.Enabled {
continue continue
} }
filters = append(filters, filtering.Filter{ filters = append(filters, Filter{
ID: filter.ID, ID: filter.ID,
FilePath: filter.Path(), FilePath: filter.Path(d.DataDir),
}) })
} }
var allowFilters []filtering.Filter var allowFilters []Filter
for _, filter := range config.WhitelistFilters { for _, filter := range d.WhitelistFilters {
if !filter.Enabled { if !filter.Enabled {
continue continue
} }
allowFilters = append(allowFilters, filtering.Filter{ allowFilters = append(allowFilters, Filter{
ID: filter.ID, ID: filter.ID,
FilePath: filter.Path(), FilePath: filter.Path(d.DataDir),
}) })
} }
if err := Context.dnsFilter.SetFilters(filters, allowFilters, async); err != nil { if err := d.SetFilters(filters, allowFilters, async); err != nil {
log.Debug("enabling filters: %s", err) log.Debug("enabling filters: %s", err)
} }
Context.dnsFilter.SetEnabled(config.DNS.FilteringEnabled) d.SetEnabled(d.FilteringEnabled)
} }

View File

@@ -1,4 +1,4 @@
package home package filtering
import ( import (
"io/fs" "io/fs"
@@ -51,15 +51,17 @@ func TestFilters(t *testing.T) {
l := testStartFilterListener(t, &fltContent) l := testStartFilterListener(t, &fltContent)
Context = homeContext{ tempDir := t.TempDir()
workDir: t.TempDir(),
client: &http.Client{ filters, err := New(&Config{
DataDir: tempDir,
HTTPClient: &http.Client{
Timeout: 5 * time.Second, Timeout: 5 * time.Second,
}, },
} }, nil)
Context.filters.Init() require.NoError(t, err)
f := &filter{ f := &FilterYAML{
URL: (&url.URL{ URL: (&url.URL{
Scheme: "http", Scheme: "http",
Host: (&netutil.IPPort{ Host: (&netutil.IPPort{
@@ -71,21 +73,22 @@ func TestFilters(t *testing.T) {
} }
updateAndAssert := func(t *testing.T, want require.BoolAssertionFunc, wantRulesCount int) { updateAndAssert := func(t *testing.T, want require.BoolAssertionFunc, wantRulesCount int) {
ok, err := Context.filters.update(f) var ok bool
ok, err = filters.update(f)
require.NoError(t, err) require.NoError(t, err)
want(t, ok) want(t, ok)
assert.Equal(t, wantRulesCount, f.RulesCount) assert.Equal(t, wantRulesCount, f.RulesCount)
var dir []fs.DirEntry var dir []fs.DirEntry
dir, err = os.ReadDir(filepath.Join(Context.getDataDir(), filterDir)) dir, err = os.ReadDir(filepath.Join(tempDir, filterDir))
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, dir, 1) assert.Len(t, dir, 1)
require.FileExists(t, f.Path()) require.FileExists(t, f.Path(tempDir))
err = Context.filters.load(f) err = filters.load(f)
require.NoError(t, err) require.NoError(t, err)
} }
@@ -105,11 +108,9 @@ func TestFilters(t *testing.T) {
}) })
t.Run("load_unload", func(t *testing.T) { t.Run("load_unload", func(t *testing.T) {
err := Context.filters.load(f) err = filters.load(f)
require.NoError(t, err) require.NoError(t, err)
f.unload() f.unload()
}) })
require.NoError(t, os.Remove(f.Path()))
} }

View File

@@ -6,7 +6,10 @@ import (
"fmt" "fmt"
"io/fs" "io/fs"
"net" "net"
"net/http"
"os" "os"
"path/filepath"
"regexp"
"runtime" "runtime"
"runtime/debug" "runtime/debug"
"strings" "strings"
@@ -24,6 +27,7 @@ import (
"github.com/AdguardTeam/urlfilter/filterlist" "github.com/AdguardTeam/urlfilter/filterlist"
"github.com/AdguardTeam/urlfilter/rules" "github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns" "github.com/miekg/dns"
"golang.org/x/exp/slices"
) )
// The IDs of built-in filter lists. // The IDs of built-in filter lists.
@@ -69,8 +73,13 @@ type Config struct {
// enabled is used to be returned within Settings. // enabled is used to be returned within Settings.
// //
// It is of type uint32 to be accessed by atomic. // It is of type uint32 to be accessed by atomic.
//
// TODO(e.burkov): Use atomic.Bool in Go 1.19.
enabled uint32 enabled uint32
FilteringEnabled bool `yaml:"filtering_enabled"` // whether or not use filter lists
FiltersUpdateIntervalHours uint32 `yaml:"filters_update_interval"` // time period to update filters (in hours)
ParentalEnabled bool `yaml:"parental_enabled"` ParentalEnabled bool `yaml:"parental_enabled"`
SafeSearchEnabled bool `yaml:"safesearch_enabled"` SafeSearchEnabled bool `yaml:"safesearch_enabled"`
SafeBrowsingEnabled bool `yaml:"safebrowsing_enabled"` SafeBrowsingEnabled bool `yaml:"safebrowsing_enabled"`
@@ -98,6 +107,24 @@ type Config struct {
// CustomResolver is the resolver used by DNSFilter. // CustomResolver is the resolver used by DNSFilter.
CustomResolver Resolver `yaml:"-"` CustomResolver Resolver `yaml:"-"`
// HTTPClient is the client to use for updating the remote filters.
HTTPClient *http.Client `yaml:"-"`
// DataDir is used to store filters' contents.
DataDir string `yaml:"-"`
// filtersMu protects filter lists.
filtersMu *sync.RWMutex
// Filters are the blocking filter lists.
Filters []FilterYAML `yaml:"-"`
// WhitelistFilters are the allowing filter lists.
WhitelistFilters []FilterYAML `yaml:"-"`
// UserRules is the global list of custom rules.
UserRules []string `yaml:"-"`
} }
// LookupStats store stats collected during safebrowsing or parental checks // LookupStats store stats collected during safebrowsing or parental checks
@@ -128,11 +155,13 @@ type hostChecker struct {
// DNSFilter matches hostnames and DNS requests against filtering rules. // DNSFilter matches hostnames and DNS requests against filtering rules.
type DNSFilter struct { type DNSFilter struct {
rulesStorage *filterlist.RuleStorage rulesStorage *filterlist.RuleStorage
filteringEngine *urlfilter.DNSEngine filteringEngine *urlfilter.DNSEngine
rulesStorageAllow *filterlist.RuleStorage rulesStorageAllow *filterlist.RuleStorage
filteringEngineAllow *urlfilter.DNSEngine filteringEngineAllow *urlfilter.DNSEngine
engineLock sync.RWMutex
engineLock sync.RWMutex
parentalServer string // access via methods parentalServer string // access via methods
safeBrowsingServer string // access via methods safeBrowsingServer string // access via methods
@@ -156,6 +185,12 @@ type DNSFilter struct {
// TODO(e.burkov): Use upstream that configured in dnsforward instead. // TODO(e.burkov): Use upstream that configured in dnsforward instead.
resolver Resolver resolver Resolver
refreshLock *sync.Mutex
// filterTitleRegexp is the regular expression to retrieve a name of a
// filter list.
filterTitleRegexp *regexp.Regexp
hostCheckers []hostChecker hostCheckers []hostChecker
} }
@@ -168,7 +203,7 @@ type Filter struct {
Data []byte `yaml:"-"` Data []byte `yaml:"-"`
// ID is automatically assigned when filter is added using nextFilterID. // ID is automatically assigned when filter is added using nextFilterID.
ID int64 ID int64 `yaml:"id"`
} }
// Reason holds an enum detailing why it was filtered or not filtered // Reason holds an enum detailing why it was filtered or not filtered
@@ -245,15 +280,7 @@ func (r Reason) String() string {
} }
// In returns true if reasons include r. // In returns true if reasons include r.
func (r Reason) In(reasons ...Reason) (ok bool) { func (r Reason) In(reasons ...Reason) (ok bool) { return slices.Contains(reasons, r) }
for _, reason := range reasons {
if r == reason {
return true
}
}
return false
}
// SetEnabled sets the status of the *DNSFilter. // SetEnabled sets the status of the *DNSFilter.
func (d *DNSFilter) SetEnabled(enabled bool) { func (d *DNSFilter) SetEnabled(enabled bool) {
@@ -261,6 +288,7 @@ func (d *DNSFilter) SetEnabled(enabled bool) {
if enabled { if enabled {
i = 1 i = 1
} }
atomic.StoreUint32(&d.enabled, uint32(i)) atomic.StoreUint32(&d.enabled, uint32(i))
} }
@@ -279,11 +307,20 @@ func (d *DNSFilter) GetConfig() (s Settings) {
// WriteDiskConfig - write configuration // WriteDiskConfig - write configuration
func (d *DNSFilter) WriteDiskConfig(c *Config) { func (d *DNSFilter) WriteDiskConfig(c *Config) {
d.confLock.Lock() func() {
defer d.confLock.Unlock() d.confLock.Lock()
defer d.confLock.Unlock()
*c = d.Config *c = d.Config
c.Rewrites = cloneRewrites(c.Rewrites) c.Rewrites = cloneRewrites(c.Rewrites)
}()
d.filtersMu.RLock()
defer d.filtersMu.RUnlock()
c.Filters = slices.Clone(d.Filters)
c.WhitelistFilters = slices.Clone(d.WhitelistFilters)
c.UserRules = slices.Clone(d.UserRules)
} }
// cloneRewrites returns a deep copy of entries. // cloneRewrites returns a deep copy of entries.
@@ -309,6 +346,8 @@ func (d *DNSFilter) SetFilters(blockFilters, allowFilters []Filter, async bool)
} }
d.filtersInitializerLock.Lock() // prevent multiple writers from adding more than 1 task d.filtersInitializerLock.Lock() // prevent multiple writers from adding more than 1 task
defer d.filtersInitializerLock.Unlock()
// remove all pending tasks // remove all pending tasks
stop := false stop := false
for !stop { for !stop {
@@ -321,7 +360,6 @@ func (d *DNSFilter) SetFilters(blockFilters, allowFilters []Filter, async bool)
} }
d.filtersInitializerChan <- params d.filtersInitializerChan <- params
d.filtersInitializerLock.Unlock()
return nil return nil
} }
@@ -350,22 +388,19 @@ func (d *DNSFilter) filtersInitializer() {
func (d *DNSFilter) Close() { func (d *DNSFilter) Close() {
d.engineLock.Lock() d.engineLock.Lock()
defer d.engineLock.Unlock() defer d.engineLock.Unlock()
d.reset() d.reset()
} }
func (d *DNSFilter) reset() { func (d *DNSFilter) reset() {
var err error
if d.rulesStorage != nil { if d.rulesStorage != nil {
err = d.rulesStorage.Close() if err := d.rulesStorage.Close(); err != nil {
if err != nil {
log.Error("filtering: rulesStorage.Close: %s", err) log.Error("filtering: rulesStorage.Close: %s", err)
} }
} }
if d.rulesStorageAllow != nil { if d.rulesStorageAllow != nil {
err = d.rulesStorageAllow.Close() if err := d.rulesStorageAllow.Close(); err != nil {
if err != nil {
log.Error("filtering: rulesStorageAllow.Close: %s", err) log.Error("filtering: rulesStorageAllow.Close: %s", err)
} }
} }
@@ -885,29 +920,30 @@ func InitModule() {
initBlockedServices() initBlockedServices()
} }
// New creates properly initialized DNS Filter that is ready to be used. // New creates properly initialized DNS Filter that is ready to be used. c must
func New(c *Config, blockFilters []Filter) (d *DNSFilter) { // be non-nil.
func New(c *Config, blockFilters []Filter) (d *DNSFilter, err error) {
d = &DNSFilter{ d = &DNSFilter{
resolver: net.DefaultResolver, resolver: net.DefaultResolver,
refreshLock: &sync.Mutex{},
filterTitleRegexp: regexp.MustCompile(`^! Title: +(.*)$`),
} }
if c != nil {
d.safebrowsingCache = cache.New(cache.Config{ d.safebrowsingCache = cache.New(cache.Config{
EnableLRU: true, EnableLRU: true,
MaxSize: c.SafeBrowsingCacheSize, MaxSize: c.SafeBrowsingCacheSize,
}) })
d.safeSearchCache = cache.New(cache.Config{ d.safeSearchCache = cache.New(cache.Config{
EnableLRU: true, EnableLRU: true,
MaxSize: c.SafeSearchCacheSize, MaxSize: c.SafeSearchCacheSize,
}) })
d.parentalCache = cache.New(cache.Config{ d.parentalCache = cache.New(cache.Config{
EnableLRU: true, EnableLRU: true,
MaxSize: c.ParentalCacheSize, MaxSize: c.ParentalCacheSize,
}) })
if c.CustomResolver != nil { if r := c.CustomResolver; r != nil {
d.resolver = c.CustomResolver d.resolver = r
}
} }
d.hostCheckers = []hostChecker{{ d.hostCheckers = []hostChecker{{
@@ -930,27 +966,26 @@ func New(c *Config, blockFilters []Filter) (d *DNSFilter) {
name: "safe search", name: "safe search",
}} }}
err := d.initSecurityServices() defer func() { err = errors.Annotate(err, "filtering: %w") }()
if err != nil {
log.Error("filtering: initialize services: %s", err)
return nil err = d.initSecurityServices()
if err != nil {
return nil, fmt.Errorf("initializing services: %s", err)
} }
if c != nil { d.Config = *c
d.Config = *c d.filtersMu = &sync.RWMutex{}
err = d.prepareRewrites()
if err != nil {
log.Error("rewrites: preparing: %s", err)
return nil err = d.prepareRewrites()
} if err != nil {
return nil, fmt.Errorf("rewrites: preparing: %s", err)
} }
bsvcs := []string{} bsvcs := []string{}
for _, s := range d.BlockedServices { for _, s := range d.BlockedServices {
if !BlockedSvcKnown(s) { if !BlockedSvcKnown(s) {
log.Debug("skipping unknown blocked-service %q", s) log.Debug("skipping unknown blocked-service %q", s)
continue continue
} }
bsvcs = append(bsvcs, s) bsvcs = append(bsvcs, s)
@@ -960,13 +995,24 @@ func New(c *Config, blockFilters []Filter) (d *DNSFilter) {
if blockFilters != nil { if blockFilters != nil {
err = d.initFiltering(nil, blockFilters) err = d.initFiltering(nil, blockFilters)
if err != nil { if err != nil {
log.Error("Can't initialize filtering subsystem: %s", err)
d.Close() d.Close()
return nil
return nil, fmt.Errorf("initializing filtering subsystem: %s", err)
} }
} }
return d _ = os.MkdirAll(filepath.Join(d.DataDir, filterDir), 0o755)
d.loadFilters(d.Filters)
d.loadFilters(d.WhitelistFilters)
d.Filters = deduplicateFilters(d.Filters)
d.WhitelistFilters = deduplicateFilters(d.WhitelistFilters)
updateUniqueFilterID(d.Filters)
updateUniqueFilterID(d.WhitelistFilters)
return d, nil
} }
// Start - start the module: // Start - start the module:
@@ -976,9 +1022,10 @@ func (d *DNSFilter) Start() {
d.filtersInitializerChan = make(chan filtersInitializerParams, 1) d.filtersInitializerChan = make(chan filtersInitializerParams, 1)
go d.filtersInitializer() go d.filtersInitializer()
if d.Config.HTTPRegister != nil { // for tests d.RegisterFilteringHandlers()
d.registerSecurityHandlers()
d.registerRewritesHandlers() // Here we should start updating filters,
d.registerBlockedServicesHandlers() // but currently we can't wake up the periodic task to do so.
} // So for now we just start this periodic task from here.
go d.periodicallyRefreshFilters()
} }

View File

@@ -26,10 +26,6 @@ const (
pcBlocked = "pornhub.com" pcBlocked = "pornhub.com"
) )
var setts = Settings{
ProtectionEnabled: true,
}
// Helpers. // Helpers.
func purgeCaches(d *DNSFilter) { func purgeCaches(d *DNSFilter) {
@@ -44,8 +40,8 @@ func purgeCaches(d *DNSFilter) {
} }
} }
func newForTest(t testing.TB, c *Config, filters []Filter) *DNSFilter { func newForTest(t testing.TB, c *Config, filters []Filter) (f *DNSFilter, setts *Settings) {
setts = Settings{ setts = &Settings{
ProtectionEnabled: true, ProtectionEnabled: true,
FilteringEnabled: true, FilteringEnabled: true,
} }
@@ -57,26 +53,31 @@ func newForTest(t testing.TB, c *Config, filters []Filter) *DNSFilter {
setts.SafeSearchEnabled = c.SafeSearchEnabled setts.SafeSearchEnabled = c.SafeSearchEnabled
setts.SafeBrowsingEnabled = c.SafeBrowsingEnabled setts.SafeBrowsingEnabled = c.SafeBrowsingEnabled
setts.ParentalEnabled = c.ParentalEnabled setts.ParentalEnabled = c.ParentalEnabled
} else {
// It must not be nil.
c = &Config{}
} }
d := New(c, filters) f, err := New(c, filters)
purgeCaches(d) require.NoError(t, err)
return d purgeCaches(f)
return f, setts
} }
func (d *DNSFilter) checkMatch(t *testing.T, hostname string) { func (d *DNSFilter) checkMatch(t *testing.T, hostname string, setts *Settings) {
t.Helper() t.Helper()
res, err := d.CheckHost(hostname, dns.TypeA, &setts) res, err := d.CheckHost(hostname, dns.TypeA, setts)
require.NoErrorf(t, err, "host %q", hostname) require.NoErrorf(t, err, "host %q", hostname)
assert.Truef(t, res.IsFiltered, "host %q", hostname) assert.Truef(t, res.IsFiltered, "host %q", hostname)
} }
func (d *DNSFilter) checkMatchIP(t *testing.T, hostname, ip string, qtype uint16) { func (d *DNSFilter) checkMatchIP(t *testing.T, hostname, ip string, qtype uint16, setts *Settings) {
t.Helper() t.Helper()
res, err := d.CheckHost(hostname, qtype, &setts) res, err := d.CheckHost(hostname, qtype, setts)
require.NoErrorf(t, err, "host %q", hostname, err) require.NoErrorf(t, err, "host %q", hostname, err)
require.NotEmpty(t, res.Rules, "host %q", hostname) require.NotEmpty(t, res.Rules, "host %q", hostname)
@@ -88,10 +89,10 @@ func (d *DNSFilter) checkMatchIP(t *testing.T, hostname, ip string, qtype uint16
assert.Equalf(t, ip, r.IP.String(), "host %q", hostname) assert.Equalf(t, ip, r.IP.String(), "host %q", hostname)
} }
func (d *DNSFilter) checkMatchEmpty(t *testing.T, hostname string) { func (d *DNSFilter) checkMatchEmpty(t *testing.T, hostname string, setts *Settings) {
t.Helper() t.Helper()
res, err := d.CheckHost(hostname, dns.TypeA, &setts) res, err := d.CheckHost(hostname, dns.TypeA, setts)
require.NoErrorf(t, err, "host %q", hostname) require.NoErrorf(t, err, "host %q", hostname)
assert.Falsef(t, res.IsFiltered, "host %q", hostname) assert.Falsef(t, res.IsFiltered, "host %q", hostname)
@@ -111,19 +112,19 @@ func TestEtcHostsMatching(t *testing.T) {
filters := []Filter{{ filters := []Filter{{
ID: 0, Data: []byte(text), ID: 0, Data: []byte(text),
}} }}
d := newForTest(t, nil, filters) d, setts := newForTest(t, nil, filters)
t.Cleanup(d.Close) t.Cleanup(d.Close)
d.checkMatchIP(t, "google.com", addr, dns.TypeA) d.checkMatchIP(t, "google.com", addr, dns.TypeA, setts)
d.checkMatchIP(t, "www.google.com", addr, dns.TypeA) d.checkMatchIP(t, "www.google.com", addr, dns.TypeA, setts)
d.checkMatchEmpty(t, "subdomain.google.com") d.checkMatchEmpty(t, "subdomain.google.com", setts)
d.checkMatchEmpty(t, "example.org") d.checkMatchEmpty(t, "example.org", setts)
// IPv4 match. // IPv4 match.
d.checkMatchIP(t, "block.com", "0.0.0.0", dns.TypeA) d.checkMatchIP(t, "block.com", "0.0.0.0", dns.TypeA, setts)
// Empty IPv6. // Empty IPv6.
res, err := d.CheckHost("block.com", dns.TypeAAAA, &setts) res, err := d.CheckHost("block.com", dns.TypeAAAA, setts)
require.NoError(t, err) require.NoError(t, err)
assert.True(t, res.IsFiltered) assert.True(t, res.IsFiltered)
@@ -134,10 +135,10 @@ func TestEtcHostsMatching(t *testing.T) {
assert.Empty(t, res.Rules[0].IP) assert.Empty(t, res.Rules[0].IP)
// IPv6 match. // IPv6 match.
d.checkMatchIP(t, "ipv6.com", addr6, dns.TypeAAAA) d.checkMatchIP(t, "ipv6.com", addr6, dns.TypeAAAA, setts)
// Empty IPv4. // Empty IPv4.
res, err = d.CheckHost("ipv6.com", dns.TypeA, &setts) res, err = d.CheckHost("ipv6.com", dns.TypeA, setts)
require.NoError(t, err) require.NoError(t, err)
assert.True(t, res.IsFiltered) assert.True(t, res.IsFiltered)
@@ -148,7 +149,7 @@ func TestEtcHostsMatching(t *testing.T) {
assert.Empty(t, res.Rules[0].IP) assert.Empty(t, res.Rules[0].IP)
// Two IPv4, both must be returned. // Two IPv4, both must be returned.
res, err = d.CheckHost("host2", dns.TypeA, &setts) res, err = d.CheckHost("host2", dns.TypeA, setts)
require.NoError(t, err) require.NoError(t, err)
assert.True(t, res.IsFiltered) assert.True(t, res.IsFiltered)
@@ -159,7 +160,7 @@ func TestEtcHostsMatching(t *testing.T) {
assert.Equal(t, res.Rules[1].IP, net.IP{0, 0, 0, 2}) assert.Equal(t, res.Rules[1].IP, net.IP{0, 0, 0, 2})
// One IPv6 address. // One IPv6 address.
res, err = d.CheckHost("host2", dns.TypeAAAA, &setts) res, err = d.CheckHost("host2", dns.TypeAAAA, setts)
require.NoError(t, err) require.NoError(t, err)
assert.True(t, res.IsFiltered) assert.True(t, res.IsFiltered)
@@ -176,27 +177,27 @@ func TestSafeBrowsing(t *testing.T) {
aghtest.ReplaceLogWriter(t, logOutput) aghtest.ReplaceLogWriter(t, logOutput)
aghtest.ReplaceLogLevel(t, log.DEBUG) aghtest.ReplaceLogLevel(t, log.DEBUG)
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil) d, setts := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
t.Cleanup(d.Close) t.Cleanup(d.Close)
d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true)) d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
d.checkMatch(t, sbBlocked) d.checkMatch(t, sbBlocked, setts)
require.Contains(t, logOutput.String(), fmt.Sprintf("safebrowsing lookup for %q", sbBlocked)) require.Contains(t, logOutput.String(), fmt.Sprintf("safebrowsing lookup for %q", sbBlocked))
d.checkMatch(t, "test."+sbBlocked) d.checkMatch(t, "test."+sbBlocked, setts)
d.checkMatchEmpty(t, "yandex.ru") d.checkMatchEmpty(t, "yandex.ru", setts)
d.checkMatchEmpty(t, pcBlocked) d.checkMatchEmpty(t, pcBlocked, setts)
// Cached result. // Cached result.
d.safeBrowsingServer = "127.0.0.1" d.safeBrowsingServer = "127.0.0.1"
d.checkMatch(t, sbBlocked) d.checkMatch(t, sbBlocked, setts)
d.checkMatchEmpty(t, pcBlocked) d.checkMatchEmpty(t, pcBlocked, setts)
d.safeBrowsingServer = defaultSafebrowsingServer d.safeBrowsingServer = defaultSafebrowsingServer
} }
func TestParallelSB(t *testing.T) { func TestParallelSB(t *testing.T) {
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil) d, setts := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
t.Cleanup(d.Close) t.Cleanup(d.Close)
d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true)) d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
@@ -205,10 +206,10 @@ func TestParallelSB(t *testing.T) {
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
t.Run(fmt.Sprintf("aaa%d", i), func(t *testing.T) { t.Run(fmt.Sprintf("aaa%d", i), func(t *testing.T) {
t.Parallel() t.Parallel()
d.checkMatch(t, sbBlocked) d.checkMatch(t, sbBlocked, setts)
d.checkMatch(t, "test."+sbBlocked) d.checkMatch(t, "test."+sbBlocked, setts)
d.checkMatchEmpty(t, "yandex.ru") d.checkMatchEmpty(t, "yandex.ru", setts)
d.checkMatchEmpty(t, pcBlocked) d.checkMatchEmpty(t, pcBlocked, setts)
}) })
} }
}) })
@@ -217,7 +218,7 @@ func TestParallelSB(t *testing.T) {
// Safe Search. // Safe Search.
func TestSafeSearch(t *testing.T) { func TestSafeSearch(t *testing.T) {
d := newForTest(t, &Config{SafeSearchEnabled: true}, nil) d, _ := newForTest(t, &Config{SafeSearchEnabled: true}, nil)
t.Cleanup(d.Close) t.Cleanup(d.Close)
val, ok := d.SafeSearchDomain("www.google.com") val, ok := d.SafeSearchDomain("www.google.com")
require.True(t, ok) require.True(t, ok)
@@ -226,7 +227,7 @@ func TestSafeSearch(t *testing.T) {
} }
func TestCheckHostSafeSearchYandex(t *testing.T) { func TestCheckHostSafeSearchYandex(t *testing.T) {
d := newForTest(t, &Config{ d, setts := newForTest(t, &Config{
SafeSearchEnabled: true, SafeSearchEnabled: true,
}, nil) }, nil)
t.Cleanup(d.Close) t.Cleanup(d.Close)
@@ -243,7 +244,7 @@ func TestCheckHostSafeSearchYandex(t *testing.T) {
"www.yandex.com", "www.yandex.com",
} { } {
t.Run(strings.ToLower(host), func(t *testing.T) { t.Run(strings.ToLower(host), func(t *testing.T) {
res, err := d.CheckHost(host, dns.TypeA, &setts) res, err := d.CheckHost(host, dns.TypeA, setts)
require.NoError(t, err) require.NoError(t, err)
assert.True(t, res.IsFiltered) assert.True(t, res.IsFiltered)
@@ -258,7 +259,7 @@ func TestCheckHostSafeSearchYandex(t *testing.T) {
func TestCheckHostSafeSearchGoogle(t *testing.T) { func TestCheckHostSafeSearchGoogle(t *testing.T) {
resolver := &aghtest.TestResolver{} resolver := &aghtest.TestResolver{}
d := newForTest(t, &Config{ d, setts := newForTest(t, &Config{
SafeSearchEnabled: true, SafeSearchEnabled: true,
CustomResolver: resolver, CustomResolver: resolver,
}, nil) }, nil)
@@ -277,7 +278,7 @@ func TestCheckHostSafeSearchGoogle(t *testing.T) {
"www.google.je", "www.google.je",
} { } {
t.Run(host, func(t *testing.T) { t.Run(host, func(t *testing.T) {
res, err := d.CheckHost(host, dns.TypeA, &setts) res, err := d.CheckHost(host, dns.TypeA, setts)
require.NoError(t, err) require.NoError(t, err)
assert.True(t, res.IsFiltered) assert.True(t, res.IsFiltered)
@@ -291,12 +292,12 @@ func TestCheckHostSafeSearchGoogle(t *testing.T) {
} }
func TestSafeSearchCacheYandex(t *testing.T) { func TestSafeSearchCacheYandex(t *testing.T) {
d := newForTest(t, nil, nil) d, setts := newForTest(t, nil, nil)
t.Cleanup(d.Close) t.Cleanup(d.Close)
const domain = "yandex.ru" const domain = "yandex.ru"
// Check host with disabled safesearch. // Check host with disabled safesearch.
res, err := d.CheckHost(domain, dns.TypeA, &setts) res, err := d.CheckHost(domain, dns.TypeA, setts)
require.NoError(t, err) require.NoError(t, err)
assert.False(t, res.IsFiltered) assert.False(t, res.IsFiltered)
@@ -305,10 +306,10 @@ func TestSafeSearchCacheYandex(t *testing.T) {
yandexIP := net.IPv4(213, 180, 193, 56) yandexIP := net.IPv4(213, 180, 193, 56)
d = newForTest(t, &Config{SafeSearchEnabled: true}, nil) d, setts = newForTest(t, &Config{SafeSearchEnabled: true}, nil)
t.Cleanup(d.Close) t.Cleanup(d.Close)
res, err = d.CheckHost(domain, dns.TypeA, &setts) res, err = d.CheckHost(domain, dns.TypeA, setts)
require.NoError(t, err) require.NoError(t, err)
// For yandex we already know valid IP. // For yandex we already know valid IP.
@@ -325,20 +326,20 @@ func TestSafeSearchCacheYandex(t *testing.T) {
func TestSafeSearchCacheGoogle(t *testing.T) { func TestSafeSearchCacheGoogle(t *testing.T) {
resolver := &aghtest.TestResolver{} resolver := &aghtest.TestResolver{}
d := newForTest(t, &Config{ d, setts := newForTest(t, &Config{
CustomResolver: resolver, CustomResolver: resolver,
}, nil) }, nil)
t.Cleanup(d.Close) t.Cleanup(d.Close)
const domain = "www.google.ru" const domain = "www.google.ru"
res, err := d.CheckHost(domain, dns.TypeA, &setts) res, err := d.CheckHost(domain, dns.TypeA, setts)
require.NoError(t, err) require.NoError(t, err)
assert.False(t, res.IsFiltered) assert.False(t, res.IsFiltered)
require.Empty(t, res.Rules) require.Empty(t, res.Rules)
d = newForTest(t, &Config{SafeSearchEnabled: true}, nil) d, setts = newForTest(t, &Config{SafeSearchEnabled: true}, nil)
t.Cleanup(d.Close) t.Cleanup(d.Close)
d.resolver = resolver d.resolver = resolver
@@ -358,7 +359,7 @@ func TestSafeSearchCacheGoogle(t *testing.T) {
} }
} }
res, err = d.CheckHost(domain, dns.TypeA, &setts) res, err = d.CheckHost(domain, dns.TypeA, setts)
require.NoError(t, err) require.NoError(t, err)
require.Len(t, res.Rules, 1) require.Len(t, res.Rules, 1)
@@ -379,22 +380,22 @@ func TestParentalControl(t *testing.T) {
aghtest.ReplaceLogWriter(t, logOutput) aghtest.ReplaceLogWriter(t, logOutput)
aghtest.ReplaceLogLevel(t, log.DEBUG) aghtest.ReplaceLogLevel(t, log.DEBUG)
d := newForTest(t, &Config{ParentalEnabled: true}, nil) d, setts := newForTest(t, &Config{ParentalEnabled: true}, nil)
t.Cleanup(d.Close) t.Cleanup(d.Close)
d.SetParentalUpstream(aghtest.NewBlockUpstream(pcBlocked, true)) d.SetParentalUpstream(aghtest.NewBlockUpstream(pcBlocked, true))
d.checkMatch(t, pcBlocked) d.checkMatch(t, pcBlocked, setts)
require.Contains(t, logOutput.String(), fmt.Sprintf("parental lookup for %q", pcBlocked)) require.Contains(t, logOutput.String(), fmt.Sprintf("parental lookup for %q", pcBlocked))
d.checkMatch(t, "www."+pcBlocked) d.checkMatch(t, "www."+pcBlocked, setts)
d.checkMatchEmpty(t, "www.yandex.ru") d.checkMatchEmpty(t, "www.yandex.ru", setts)
d.checkMatchEmpty(t, "yandex.ru") d.checkMatchEmpty(t, "yandex.ru", setts)
d.checkMatchEmpty(t, "api.jquery.com") d.checkMatchEmpty(t, "api.jquery.com", setts)
// Test cached result. // Test cached result.
d.parentalServer = "127.0.0.1" d.parentalServer = "127.0.0.1"
d.checkMatch(t, pcBlocked) d.checkMatch(t, pcBlocked, setts)
d.checkMatchEmpty(t, "yandex.ru") d.checkMatchEmpty(t, "yandex.ru", setts)
} }
// Filtering. // Filtering.
@@ -679,10 +680,10 @@ func TestMatching(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(fmt.Sprintf("%s-%s", tc.name, tc.host), func(t *testing.T) { t.Run(fmt.Sprintf("%s-%s", tc.name, tc.host), func(t *testing.T) {
filters := []Filter{{ID: 0, Data: []byte(tc.rules)}} filters := []Filter{{ID: 0, Data: []byte(tc.rules)}}
d := newForTest(t, nil, filters) d, setts := newForTest(t, nil, filters)
t.Cleanup(d.Close) t.Cleanup(d.Close)
res, err := d.CheckHost(tc.host, tc.wantDNSType, &setts) res, err := d.CheckHost(tc.host, tc.wantDNSType, setts)
require.NoError(t, err) require.NoError(t, err)
assert.Equalf(t, tc.wantIsFiltered, res.IsFiltered, "Hostname %s has wrong result (%v must be %v)", tc.host, res.IsFiltered, tc.wantIsFiltered) assert.Equalf(t, tc.wantIsFiltered, res.IsFiltered, "Hostname %s has wrong result (%v must be %v)", tc.host, res.IsFiltered, tc.wantIsFiltered)
@@ -705,7 +706,7 @@ func TestWhitelist(t *testing.T) {
whiteFilters := []Filter{{ whiteFilters := []Filter{{
ID: 0, Data: []byte(whiteRules), ID: 0, Data: []byte(whiteRules),
}} }}
d := newForTest(t, nil, filters) d, setts := newForTest(t, nil, filters)
err := d.SetFilters(filters, whiteFilters, false) err := d.SetFilters(filters, whiteFilters, false)
require.NoError(t, err) require.NoError(t, err)
@@ -713,7 +714,7 @@ func TestWhitelist(t *testing.T) {
t.Cleanup(d.Close) t.Cleanup(d.Close)
// Matched by white filter. // Matched by white filter.
res, err := d.CheckHost("host1", dns.TypeA, &setts) res, err := d.CheckHost("host1", dns.TypeA, setts)
require.NoError(t, err) require.NoError(t, err)
assert.False(t, res.IsFiltered) assert.False(t, res.IsFiltered)
@@ -724,7 +725,7 @@ func TestWhitelist(t *testing.T) {
assert.Equal(t, "||host1^", res.Rules[0].Text) assert.Equal(t, "||host1^", res.Rules[0].Text)
// Not matched by white filter, but matched by block filter. // Not matched by white filter, but matched by block filter.
res, err = d.CheckHost("host2", dns.TypeA, &setts) res, err = d.CheckHost("host2", dns.TypeA, setts)
require.NoError(t, err) require.NoError(t, err)
assert.True(t, res.IsFiltered) assert.True(t, res.IsFiltered)
@@ -750,7 +751,7 @@ func applyClientSettings(setts *Settings) {
} }
func TestClientSettings(t *testing.T) { func TestClientSettings(t *testing.T) {
d := newForTest(t, d, setts := newForTest(t,
&Config{ &Config{
ParentalEnabled: true, ParentalEnabled: true,
SafeBrowsingEnabled: false, SafeBrowsingEnabled: false,
@@ -796,7 +797,7 @@ func TestClientSettings(t *testing.T) {
return func(t *testing.T) { return func(t *testing.T) {
t.Helper() t.Helper()
r, err := d.CheckHost(tc.host, dns.TypeA, &setts) r, err := d.CheckHost(tc.host, dns.TypeA, setts)
require.NoError(t, err) require.NoError(t, err)
if before { if before {
@@ -814,7 +815,7 @@ func TestClientSettings(t *testing.T) {
t.Run(tc.name, makeTester(tc, tc.before)) t.Run(tc.name, makeTester(tc, tc.before))
} }
applyClientSettings(&setts) applyClientSettings(setts)
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, makeTester(tc, !tc.before)) t.Run(tc.name, makeTester(tc, !tc.before))
@@ -824,13 +825,13 @@ func TestClientSettings(t *testing.T) {
// Benchmarks. // Benchmarks.
func BenchmarkSafeBrowsing(b *testing.B) { func BenchmarkSafeBrowsing(b *testing.B) {
d := newForTest(b, &Config{SafeBrowsingEnabled: true}, nil) d, setts := newForTest(b, &Config{SafeBrowsingEnabled: true}, nil)
b.Cleanup(d.Close) b.Cleanup(d.Close)
d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true)) d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
res, err := d.CheckHost(sbBlocked, dns.TypeA, &setts) res, err := d.CheckHost(sbBlocked, dns.TypeA, setts)
require.NoError(b, err) require.NoError(b, err)
assert.Truef(b, res.IsFiltered, "expected hostname %q to match", sbBlocked) assert.Truef(b, res.IsFiltered, "expected hostname %q to match", sbBlocked)
@@ -838,14 +839,14 @@ func BenchmarkSafeBrowsing(b *testing.B) {
} }
func BenchmarkSafeBrowsingParallel(b *testing.B) { func BenchmarkSafeBrowsingParallel(b *testing.B) {
d := newForTest(b, &Config{SafeBrowsingEnabled: true}, nil) d, setts := newForTest(b, &Config{SafeBrowsingEnabled: true}, nil)
b.Cleanup(d.Close) b.Cleanup(d.Close)
d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true)) d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
b.RunParallel(func(pb *testing.PB) { b.RunParallel(func(pb *testing.PB) {
for pb.Next() { for pb.Next() {
res, err := d.CheckHost(sbBlocked, dns.TypeA, &setts) res, err := d.CheckHost(sbBlocked, dns.TypeA, setts)
require.NoError(b, err) require.NoError(b, err)
assert.Truef(b, res.IsFiltered, "expected hostname %q to match", sbBlocked) assert.Truef(b, res.IsFiltered, "expected hostname %q to match", sbBlocked)
@@ -854,7 +855,7 @@ func BenchmarkSafeBrowsingParallel(b *testing.B) {
} }
func BenchmarkSafeSearch(b *testing.B) { func BenchmarkSafeSearch(b *testing.B) {
d := newForTest(b, &Config{SafeSearchEnabled: true}, nil) d, _ := newForTest(b, &Config{SafeSearchEnabled: true}, nil)
b.Cleanup(d.Close) b.Cleanup(d.Close)
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
val, ok := d.SafeSearchDomain("www.google.com") val, ok := d.SafeSearchDomain("www.google.com")
@@ -865,7 +866,7 @@ func BenchmarkSafeSearch(b *testing.B) {
} }
func BenchmarkSafeSearchParallel(b *testing.B) { func BenchmarkSafeSearchParallel(b *testing.B) {
d := newForTest(b, &Config{SafeSearchEnabled: true}, nil) d, _ := newForTest(b, &Config{SafeSearchEnabled: true}, nil)
b.Cleanup(d.Close) b.Cleanup(d.Close)
b.RunParallel(func(pb *testing.PB) { b.RunParallel(func(pb *testing.PB) {
for pb.Next() { for pb.Next() {

View File

@@ -133,34 +133,31 @@ func matchDomainWildcard(host, wildcard string) (ok bool) {
// 1. A and AAAA > CNAME // 1. A and AAAA > CNAME
// 2. wildcard > exact // 2. wildcard > exact
// 3. lower level wildcard > higher level wildcard // 3. lower level wildcard > higher level wildcard
//
// TODO(a.garipov): Replace with slices.Sort.
type rewritesSorted []*LegacyRewrite type rewritesSorted []*LegacyRewrite
// Len implements the sort.Interface interface for legacyRewritesSorted. // Len implements the sort.Interface interface for rewritesSorted.
func (a rewritesSorted) Len() (l int) { return len(a) } func (a rewritesSorted) Len() (l int) { return len(a) }
// Swap implements the sort.Interface interface for legacyRewritesSorted. // Swap implements the sort.Interface interface for rewritesSorted.
func (a rewritesSorted) Swap(i, j int) { a[i], a[j] = a[j], a[i] } func (a rewritesSorted) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
// Less implements the sort.Interface interface for legacyRewritesSorted. // Less implements the sort.Interface interface for rewritesSorted.
func (a rewritesSorted) Less(i, j int) (less bool) { func (a rewritesSorted) Less(i, j int) (less bool) {
if a[i].Type == dns.TypeCNAME && a[j].Type != dns.TypeCNAME { ith, jth := a[i], a[j]
if ith.Type == dns.TypeCNAME && jth.Type != dns.TypeCNAME {
return true return true
} else if a[i].Type != dns.TypeCNAME && a[j].Type == dns.TypeCNAME { } else if ith.Type != dns.TypeCNAME && jth.Type == dns.TypeCNAME {
return false return false
} }
if isWildcard(a[i].Domain) { if iw, jw := isWildcard(ith.Domain), isWildcard(jth.Domain); iw != jw {
if !isWildcard(a[j].Domain) { return jw
return false
}
} else {
if isWildcard(a[j].Domain) {
return true
}
} }
// Both are wildcards. // Both are either wildcards or not.
return len(a[i].Domain) > len(a[j].Domain) return len(ith.Domain) > len(jth.Domain)
} }
// prepareRewrites normalizes and validates all legacy DNS rewrites. // prepareRewrites normalizes and validates all legacy DNS rewrites.
@@ -313,9 +310,3 @@ func (d *DNSFilter) handleRewriteDelete(w http.ResponseWriter, r *http.Request)
d.Config.ConfigModified() d.Config.ConfigModified()
} }
func (d *DNSFilter) registerRewritesHandlers() {
d.Config.HTTPRegister(http.MethodGet, "/control/rewrite/list", d.handleRewriteList)
d.Config.HTTPRegister(http.MethodPost, "/control/rewrite/add", d.handleRewriteAdd)
d.Config.HTTPRegister(http.MethodPost, "/control/rewrite/delete", d.handleRewriteDelete)
}

View File

@@ -12,7 +12,7 @@ import (
// TODO(e.burkov): All the tests in this file may and should me merged together. // TODO(e.burkov): All the tests in this file may and should me merged together.
func TestRewrites(t *testing.T) { func TestRewrites(t *testing.T) {
d := newForTest(t, nil, nil) d, _ := newForTest(t, nil, nil)
t.Cleanup(d.Close) t.Cleanup(d.Close)
d.Rewrites = []*LegacyRewrite{{ d.Rewrites = []*LegacyRewrite{{
@@ -188,7 +188,7 @@ func TestRewrites(t *testing.T) {
} }
func TestRewritesLevels(t *testing.T) { func TestRewritesLevels(t *testing.T) {
d := newForTest(t, nil, nil) d, _ := newForTest(t, nil, nil)
t.Cleanup(d.Close) t.Cleanup(d.Close)
// Exact host, wildcard L2, wildcard L3. // Exact host, wildcard L2, wildcard L3.
d.Rewrites = []*LegacyRewrite{{ d.Rewrites = []*LegacyRewrite{{
@@ -235,7 +235,7 @@ func TestRewritesLevels(t *testing.T) {
} }
func TestRewritesExceptionCNAME(t *testing.T) { func TestRewritesExceptionCNAME(t *testing.T) {
d := newForTest(t, nil, nil) d, _ := newForTest(t, nil, nil)
t.Cleanup(d.Close) t.Cleanup(d.Close)
// Wildcard and exception for a sub-domain. // Wildcard and exception for a sub-domain.
d.Rewrites = []*LegacyRewrite{{ d.Rewrites = []*LegacyRewrite{{
@@ -286,7 +286,7 @@ func TestRewritesExceptionCNAME(t *testing.T) {
} }
func TestRewritesExceptionIP(t *testing.T) { func TestRewritesExceptionIP(t *testing.T) {
d := newForTest(t, nil, nil) d, _ := newForTest(t, nil, nil)
t.Cleanup(d.Close) t.Cleanup(d.Close)
// Exception for AAAA record. // Exception for AAAA record.
d.Rewrites = []*LegacyRewrite{{ d.Rewrites = []*LegacyRewrite{{

View File

@@ -415,17 +415,3 @@ func (d *DNSFilter) handleParentalStatus(w http.ResponseWriter, r *http.Request)
aghhttp.Error(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err) aghhttp.Error(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err)
} }
} }
func (d *DNSFilter) registerSecurityHandlers() {
d.Config.HTTPRegister(http.MethodPost, "/control/safebrowsing/enable", d.handleSafeBrowsingEnable)
d.Config.HTTPRegister(http.MethodPost, "/control/safebrowsing/disable", d.handleSafeBrowsingDisable)
d.Config.HTTPRegister(http.MethodGet, "/control/safebrowsing/status", d.handleSafeBrowsingStatus)
d.Config.HTTPRegister(http.MethodPost, "/control/parental/enable", d.handleParentalEnable)
d.Config.HTTPRegister(http.MethodPost, "/control/parental/disable", d.handleParentalDisable)
d.Config.HTTPRegister(http.MethodGet, "/control/parental/status", d.handleParentalStatus)
d.Config.HTTPRegister(http.MethodPost, "/control/safesearch/enable", d.handleSafeSearchEnable)
d.Config.HTTPRegister(http.MethodPost, "/control/safesearch/disable", d.handleSafeSearchDisable)
d.Config.HTTPRegister(http.MethodGet, "/control/safesearch/status", d.handleSafeSearchStatus)
}

View File

@@ -107,7 +107,7 @@ func TestSafeBrowsingCache(t *testing.T) {
} }
func TestSBPC_checkErrorUpstream(t *testing.T) { func TestSBPC_checkErrorUpstream(t *testing.T) {
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil) d, _ := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
t.Cleanup(d.Close) t.Cleanup(d.Close)
ups := aghtest.NewErrorUpstream() ups := aghtest.NewErrorUpstream()
@@ -128,7 +128,7 @@ func TestSBPC_checkErrorUpstream(t *testing.T) {
} }
func TestSBPC(t *testing.T) { func TestSBPC(t *testing.T) {
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil) d, _ := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
t.Cleanup(d.Close) t.Cleanup(d.Close)
const hostname = "example.org" const hostname = "example.org"

View File

@@ -3,7 +3,7 @@ package home
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"net" "net/netip"
"os" "os"
"path/filepath" "path/filepath"
"sync" "sync"
@@ -14,7 +14,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/querylog" "github.com/AdguardTeam/AdGuardHome/internal/querylog"
"github.com/AdguardTeam/AdGuardHome/internal/stats" "github.com/AdguardTeam/AdGuardHome/internal/stats"
"github.com/AdguardTeam/AdGuardHome/internal/version"
"github.com/AdguardTeam/dnsproxy/fastip" "github.com/AdguardTeam/dnsproxy/fastip"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
@@ -23,10 +22,9 @@ import (
yaml "gopkg.in/yaml.v3" yaml "gopkg.in/yaml.v3"
) )
const ( // dataDir is the name of a directory under the working one to store some
dataDir = "data" // data storage // persistent data.
filterDir = "filters" // cache location for downloaded filters, it's under DataDir const dataDir = "data"
)
// logSettings are the logging settings part of the configuration file. // logSettings are the logging settings part of the configuration file.
// //
@@ -87,10 +85,10 @@ type configuration struct {
// It's reset after config is parsed // It's reset after config is parsed
fileData []byte fileData []byte
BindHost net.IP `yaml:"bind_host"` // BindHost is the IP address of the HTTP server to bind to BindHost netip.Addr `yaml:"bind_host"` // BindHost is the IP address of the HTTP server to bind to
BindPort int `yaml:"bind_port"` // BindPort is the port the HTTP server BindPort int `yaml:"bind_port"` // BindPort is the port the HTTP server
BetaBindPort int `yaml:"beta_bind_port"` // BetaBindPort is the port for new client BetaBindPort int `yaml:"beta_bind_port"` // BetaBindPort is the port for new client
Users []User `yaml:"users"` // Users that can access HTTP server Users []User `yaml:"users"` // Users that can access HTTP server
// AuthAttempts is the maximum number of failed login attempts a user // AuthAttempts is the maximum number of failed login attempts a user
// can do before being blocked. // can do before being blocked.
AuthAttempts uint `yaml:"auth_attempts"` AuthAttempts uint `yaml:"auth_attempts"`
@@ -108,9 +106,16 @@ type configuration struct {
DNS dnsConfig `yaml:"dns"` DNS dnsConfig `yaml:"dns"`
TLS tlsConfigSettings `yaml:"tls"` TLS tlsConfigSettings `yaml:"tls"`
Filters []filter `yaml:"filters"` // Filters reflects the filters from [filtering.Config]. It's cloned to the
WhitelistFilters []filter `yaml:"whitelist_filters"` // config used in the filtering module at the startup. Afterwards it's
UserRules []string `yaml:"user_rules"` // cloned from the filtering module back here.
//
// TODO(e.burkov): Move all the filtering configuration fields into the
// only configuration subsection covering the changes with a single
// migration.
Filters []filtering.FilterYAML `yaml:"filters"`
WhitelistFilters []filtering.FilterYAML `yaml:"whitelist_filters"`
UserRules []string `yaml:"user_rules"`
DHCP *dhcpd.ServerConfig `yaml:"dhcp"` DHCP *dhcpd.ServerConfig `yaml:"dhcp"`
@@ -130,8 +135,8 @@ type configuration struct {
// field ordering is important -- yaml fields will mirror ordering from here // field ordering is important -- yaml fields will mirror ordering from here
type dnsConfig struct { type dnsConfig struct {
BindHosts []net.IP `yaml:"bind_hosts"` BindHosts []netip.Addr `yaml:"bind_hosts"`
Port int `yaml:"port"` Port int `yaml:"port"`
// time interval for statistics (in days) // time interval for statistics (in days)
StatsInterval uint32 `yaml:"statistics_interval"` StatsInterval uint32 `yaml:"statistics_interval"`
@@ -145,9 +150,7 @@ type dnsConfig struct {
dnsforward.FilteringConfig `yaml:",inline"` dnsforward.FilteringConfig `yaml:",inline"`
FilteringEnabled bool `yaml:"filtering_enabled"` // whether or not use filter lists DnsfilterConf *filtering.Config `yaml:",inline"`
FiltersUpdateIntervalHours uint32 `yaml:"filters_update_interval"` // time period to update filters (in hours)
DnsfilterConf filtering.Config `yaml:",inline"`
// UpstreamTimeout is the timeout for querying upstream servers. // UpstreamTimeout is the timeout for querying upstream servers.
UpstreamTimeout timeutil.Duration `yaml:"upstream_timeout"` UpstreamTimeout timeutil.Duration `yaml:"upstream_timeout"`
@@ -193,15 +196,20 @@ type tlsConfigSettings struct {
// //
// TODO(a.garipov, e.burkov): This global is awful and must be removed. // TODO(a.garipov, e.burkov): This global is awful and must be removed.
var config = &configuration{ var config = &configuration{
BindPort: 3000, BindPort: 3000,
BetaBindPort: 0, BetaBindPort: 0,
BindHost: net.IP{0, 0, 0, 0}, BindHost: netip.IPv4Unspecified(),
AuthAttempts: 5, AuthAttempts: 5,
AuthBlockMin: 15, AuthBlockMin: 15,
WebSessionTTLHours: 30 * 24,
DNS: dnsConfig{ DNS: dnsConfig{
BindHosts: []net.IP{{0, 0, 0, 0}}, BindHosts: []netip.Addr{netip.IPv4Unspecified()},
Port: defaultPortDNS, Port: defaultPortDNS,
StatsInterval: 1, StatsInterval: 1,
QueryLogEnabled: true,
QueryLogFileEnabled: true,
QueryLogInterval: timeutil.Duration{Duration: 90 * timeutil.Day},
QueryLogMemSize: 1000,
FilteringConfig: dnsforward.FilteringConfig{ FilteringConfig: dnsforward.FilteringConfig{
ProtectionEnabled: true, // whether or not use any of filtering features ProtectionEnabled: true, // whether or not use any of filtering features
BlockingMode: dnsforward.BlockingModeDefault, BlockingMode: dnsforward.BlockingModeDefault,
@@ -222,18 +230,42 @@ var config = &configuration{
// was later increased to 300 due to https://github.com/AdguardTeam/AdGuardHome/issues/2257 // was later increased to 300 due to https://github.com/AdguardTeam/AdGuardHome/issues/2257
MaxGoroutines: 300, MaxGoroutines: 300,
}, },
FilteringEnabled: true, // whether or not use filter lists DnsfilterConf: &filtering.Config{
FiltersUpdateIntervalHours: 24, SafeBrowsingCacheSize: 1 * 1024 * 1024,
UpstreamTimeout: timeutil.Duration{Duration: dnsforward.DefaultTimeout}, SafeSearchCacheSize: 1 * 1024 * 1024,
UsePrivateRDNS: true, ParentalCacheSize: 1 * 1024 * 1024,
CacheTime: 30,
FilteringEnabled: true,
FiltersUpdateIntervalHours: 24,
},
UpstreamTimeout: timeutil.Duration{Duration: dnsforward.DefaultTimeout},
UsePrivateRDNS: true,
}, },
TLS: tlsConfigSettings{ TLS: tlsConfigSettings{
PortHTTPS: defaultPortHTTPS, PortHTTPS: defaultPortHTTPS,
PortDNSOverTLS: defaultPortTLS, // needs to be passed through to dnsproxy PortDNSOverTLS: defaultPortTLS, // needs to be passed through to dnsproxy
PortDNSOverQUIC: defaultPortQUIC, PortDNSOverQUIC: defaultPortQUIC,
}, },
Filters: []filtering.FilterYAML{{
Filter: filtering.Filter{ID: 1},
Enabled: true,
URL: "https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt",
Name: "AdGuard DNS filter",
}, {
Filter: filtering.Filter{ID: 2},
Enabled: false,
URL: "https://adaway.org/hosts.txt",
Name: "AdAway Default Blocklist",
}},
DHCP: &dhcpd.ServerConfig{ DHCP: &dhcpd.ServerConfig{
LocalDomainName: "lan", LocalDomainName: "lan",
Conf4: dhcpd.V4ServerConf{
LeaseDuration: dhcpd.DefaultDHCPLeaseTTL,
ICMPTimeout: dhcpd.DefaultDHCPTimeoutICMP,
},
Conf6: dhcpd.V6ServerConf{
LeaseDuration: dhcpd.DefaultDHCPLeaseTTL,
},
}, },
Clients: &clientsConfig{ Clients: &clientsConfig{
Sources: &clientSourcesConf{ Sources: &clientSourcesConf{
@@ -255,31 +287,6 @@ var config = &configuration{
SchemaVersion: currentSchemaVersion, SchemaVersion: currentSchemaVersion,
} }
// initConfig initializes default configuration for the current OS&ARCH
func initConfig() {
config.WebSessionTTLHours = 30 * 24
config.DNS.QueryLogEnabled = true
config.DNS.QueryLogFileEnabled = true
config.DNS.QueryLogInterval = timeutil.Duration{Duration: 90 * timeutil.Day}
config.DNS.QueryLogMemSize = 1000
config.DNS.CacheSize = 4 * 1024 * 1024
config.DNS.DnsfilterConf.SafeBrowsingCacheSize = 1 * 1024 * 1024
config.DNS.DnsfilterConf.SafeSearchCacheSize = 1 * 1024 * 1024
config.DNS.DnsfilterConf.ParentalCacheSize = 1 * 1024 * 1024
config.DNS.DnsfilterConf.CacheTime = 30
config.Filters = defaultFilters()
config.DHCP.Conf4.LeaseDuration = dhcpd.DefaultDHCPLeaseTTL
config.DHCP.Conf4.ICMPTimeout = dhcpd.DefaultDHCPTimeoutICMP
config.DHCP.Conf6.LeaseDuration = dhcpd.DefaultDHCPLeaseTTL
if ch := version.Channel(); ch == version.ChannelEdge || ch == version.ChannelDevelopment {
config.BetaBindPort = 3001
}
}
// getConfigFilename returns path to the current config file // getConfigFilename returns path to the current config file
func (c *configuration) getConfigFilename() string { func (c *configuration) getConfigFilename() string {
configFile, err := filepath.EvalSymlinks(Context.configFilename) configFile, err := filepath.EvalSymlinks(Context.configFilename)
@@ -348,8 +355,8 @@ func parseConfig() (err error) {
return fmt.Errorf("validating udp ports: %w", err) return fmt.Errorf("validating udp ports: %w", err)
} }
if !checkFiltersUpdateIntervalHours(config.DNS.FiltersUpdateIntervalHours) { if !filtering.ValidateUpdateIvl(config.DNS.DnsfilterConf.FiltersUpdateIntervalHours) {
config.DNS.FiltersUpdateIntervalHours = 24 config.DNS.DnsfilterConf.FiltersUpdateIntervalHours = 24
} }
if config.DNS.UpstreamTimeout.Duration == 0 { if config.DNS.UpstreamTimeout.Duration == 0 {
@@ -418,10 +425,11 @@ func (c *configuration) write() (err error) {
config.DNS.AnonymizeClientIP = dc.AnonymizeClientIP config.DNS.AnonymizeClientIP = dc.AnonymizeClientIP
} }
if Context.dnsFilter != nil { if Context.filters != nil {
c := filtering.Config{} Context.filters.WriteDiskConfig(config.DNS.DnsfilterConf)
Context.dnsFilter.WriteDiskConfig(&c) config.Filters = config.DNS.DnsfilterConf.Filters
config.DNS.DnsfilterConf = c config.WhitelistFilters = config.DNS.DnsfilterConf.WhitelistFilters
config.UserRules = config.DNS.DnsfilterConf.UserRules
} }
if s := Context.dnsServer; s != nil { if s := Context.dnsServer; s != nil {

View File

@@ -3,8 +3,8 @@ package home
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net"
"net/http" "net/http"
"net/netip"
"net/url" "net/url"
"runtime" "runtime"
"strings" "strings"
@@ -20,11 +20,11 @@ import (
// appendDNSAddrs is a convenient helper for appending a formatted form of DNS // appendDNSAddrs is a convenient helper for appending a formatted form of DNS
// addresses to a slice of strings. // addresses to a slice of strings.
func appendDNSAddrs(dst []string, addrs ...net.IP) (res []string) { func appendDNSAddrs(dst []string, addrs ...netip.Addr) (res []string) {
for _, addr := range addrs { for _, addr := range addrs {
var hostport string var hostport string
if config.DNS.Port != defaultPortDNS { if config.DNS.Port != defaultPortDNS {
hostport = netutil.JoinHostPort(addr.String(), config.DNS.Port) hostport = netip.AddrPortFrom(addr, uint16(config.DNS.Port)).String()
} else { } else {
hostport = addr.String() hostport = addr.String()
} }
@@ -38,7 +38,7 @@ func appendDNSAddrs(dst []string, addrs ...net.IP) (res []string) {
// appendDNSAddrsWithIfaces formats and appends all DNS addresses from src to // appendDNSAddrsWithIfaces formats and appends all DNS addresses from src to
// dst. It also adds the IP addresses of all network interfaces if src contains // dst. It also adds the IP addresses of all network interfaces if src contains
// an unspecified IP address. // an unspecified IP address.
func appendDNSAddrsWithIfaces(dst []string, src []net.IP) (res []string, err error) { func appendDNSAddrsWithIfaces(dst []string, src []netip.Addr) (res []string, err error) {
ifacesAdded := false ifacesAdded := false
for _, h := range src { for _, h := range src {
if !h.IsUnspecified() { if !h.IsUnspecified() {
@@ -71,7 +71,9 @@ func appendDNSAddrsWithIfaces(dst []string, src []net.IP) (res []string, err err
// on, including the addresses on all interfaces in cases of unspecified IPs. // on, including the addresses on all interfaces in cases of unspecified IPs.
func collectDNSAddresses() (addrs []string, err error) { func collectDNSAddresses() (addrs []string, err error) {
if hosts := config.DNS.BindHosts; len(hosts) == 0 { if hosts := config.DNS.BindHosts; len(hosts) == 0 {
addrs = appendDNSAddrs(addrs, net.IP{127, 0, 0, 1}) addr := aghnet.IPv4Localhost()
addrs = appendDNSAddrs(addrs, addr)
} else { } else {
addrs, err = appendDNSAddrsWithIfaces(addrs, hosts) addrs, err = appendDNSAddrsWithIfaces(addrs, hosts)
if err != nil { if err != nil {
@@ -291,7 +293,7 @@ func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (ok bool) {
} }
httpsURL := &url.URL{ httpsURL := &url.URL{
Scheme: schemeHTTPS, Scheme: aghhttp.SchemeHTTPS,
Host: hostPort, Host: hostPort,
Path: r.URL.Path, Path: r.URL.Path,
RawQuery: r.URL.RawQuery, RawQuery: r.URL.RawQuery,
@@ -307,7 +309,7 @@ func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (ok bool) {
// //
// See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin. // See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin.
originURL := &url.URL{ originURL := &url.URL{
Scheme: schemeHTTP, Scheme: aghhttp.SchemeHTTP,
Host: r.Host, Host: r.Host,
} }
w.Header().Set("Access-Control-Allow-Origin", originURL.String()) w.Header().Set("Access-Control-Allow-Origin", originURL.String())

View File

@@ -5,8 +5,8 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"net"
"net/http" "net/http"
"net/netip"
"os" "os"
"os/exec" "os/exec"
"path/filepath" "path/filepath"
@@ -75,9 +75,9 @@ func (web *Web) handleInstallGetAddresses(w http.ResponseWriter, r *http.Request
} }
type checkConfReqEnt struct { type checkConfReqEnt struct {
IP net.IP `json:"ip"` IP netip.Addr `json:"ip"`
Port int `json:"port"` Port int `json:"port"`
Autofix bool `json:"autofix"` Autofix bool `json:"autofix"`
} }
type checkConfReq struct { type checkConfReq struct {
@@ -128,7 +128,7 @@ func (req *checkConfReq) validateWeb(tcpPorts aghalg.UniqChecker[tcpPort]) (err
// unbound after install. // unbound after install.
} }
return aghnet.CheckPort("tcp", req.Web.IP, portInt) return aghnet.CheckPort("tcp", netip.AddrPortFrom(req.Web.IP, uint16(portInt)))
} }
// validateDNS returns error if the DNS part of the initial configuration can't // validateDNS returns error if the DNS part of the initial configuration can't
@@ -153,13 +153,13 @@ func (req *checkConfReq) validateDNS(
return false, err return false, err
} }
err = aghnet.CheckPort("tcp", req.DNS.IP, port) err = aghnet.CheckPort("tcp", netip.AddrPortFrom(req.DNS.IP, uint16(port)))
if err != nil { if err != nil {
return false, err return false, err
} }
} }
err = aghnet.CheckPort("udp", req.DNS.IP, port) err = aghnet.CheckPort("udp", netip.AddrPortFrom(req.DNS.IP, uint16(port)))
if !aghnet.IsAddrInUse(err) { if !aghnet.IsAddrInUse(err) {
return false, err return false, err
} }
@@ -171,7 +171,7 @@ func (req *checkConfReq) validateDNS(
log.Error("disabling DNSStubListener: %s", err) log.Error("disabling DNSStubListener: %s", err)
} }
err = aghnet.CheckPort("udp", req.DNS.IP, port) err = aghnet.CheckPort("udp", netip.AddrPortFrom(req.DNS.IP, uint16(port)))
canAutofix = false canAutofix = false
} }
@@ -213,7 +213,7 @@ func (web *Web) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request)
// handleStaticIP - handles static IP request // handleStaticIP - handles static IP request
// It either checks if we have a static IP // It either checks if we have a static IP
// Or if set=true, it tries to set it // Or if set=true, it tries to set it
func handleStaticIP(ip net.IP, set bool) staticIPJSON { func handleStaticIP(ip netip.Addr, set bool) staticIPJSON {
resp := staticIPJSON{} resp := staticIPJSON{}
interfaceName := aghnet.InterfaceByIP(ip) interfaceName := aghnet.InterfaceByIP(ip)
@@ -321,8 +321,8 @@ func disableDNSStubListener() error {
} }
type applyConfigReqEnt struct { type applyConfigReqEnt struct {
IP net.IP `json:"ip"` IP netip.Addr `json:"ip"`
Port int `json:"port"` Port int `json:"port"`
} }
type applyConfigReq struct { type applyConfigReq struct {
@@ -388,14 +388,14 @@ func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
return return
} }
err = aghnet.CheckPort("udp", req.DNS.IP, req.DNS.Port) err = aghnet.CheckPort("udp", netip.AddrPortFrom(req.DNS.IP, uint16(req.DNS.Port)))
if err != nil { if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return return
} }
err = aghnet.CheckPort("tcp", req.DNS.IP, req.DNS.Port) err = aghnet.CheckPort("tcp", netip.AddrPortFrom(req.DNS.IP, uint16(req.DNS.Port)))
if err != nil { if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
@@ -408,7 +408,7 @@ func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
Context.firstRun = false Context.firstRun = false
config.BindHost = req.Web.IP config.BindHost = req.Web.IP
config.BindPort = req.Web.Port config.BindPort = req.Web.Port
config.DNS.BindHosts = []net.IP{req.DNS.IP} config.DNS.BindHosts = []netip.Addr{req.DNS.IP}
config.DNS.Port = req.DNS.Port config.DNS.Port = req.DNS.Port
// TODO(e.burkov): StartMods() should be put in a separate goroutine at the // TODO(e.burkov): StartMods() should be put in a separate goroutine at the
@@ -481,9 +481,9 @@ func decodeApplyConfigReq(r io.Reader) (req *applyConfigReq, restartHTTP bool, e
return nil, false, errors.Error("ports cannot be 0") return nil, false, errors.Error("ports cannot be 0")
} }
restartHTTP = !config.BindHost.Equal(req.Web.IP) || config.BindPort != req.Web.Port restartHTTP = config.BindHost != req.Web.IP || config.BindPort != req.Web.Port
if restartHTTP { if restartHTTP {
err = aghnet.CheckPort("tcp", req.Web.IP, req.Web.Port) err = aghnet.CheckPort("tcp", netip.AddrPortFrom(req.Web.IP, uint16(req.Web.Port)))
if err != nil { if err != nil {
return nil, false, fmt.Errorf( return nil, false, fmt.Errorf(
"checking address %s:%d: %w", "checking address %s:%d: %w",
@@ -509,9 +509,9 @@ func (web *Web) registerInstallHandlers() {
// TODO(e.burkov): This should removed with the API v1 when the appropriate // TODO(e.burkov): This should removed with the API v1 when the appropriate
// functionality will appear in default checkConfigReqEnt. // functionality will appear in default checkConfigReqEnt.
type checkConfigReqEntBeta struct { type checkConfigReqEntBeta struct {
IP []net.IP `json:"ip"` IP []netip.Addr `json:"ip"`
Port int `json:"port"` Port int `json:"port"`
Autofix bool `json:"autofix"` Autofix bool `json:"autofix"`
} }
// checkConfigReqBeta is a struct representing new client's config check request // checkConfigReqBeta is a struct representing new client's config check request
@@ -586,8 +586,8 @@ func (web *Web) handleInstallCheckConfigBeta(w http.ResponseWriter, r *http.Requ
// TODO(e.burkov): This should removed with the API v1 when the appropriate // TODO(e.burkov): This should removed with the API v1 when the appropriate
// functionality will appear in default applyConfigReqEnt. // functionality will appear in default applyConfigReqEnt.
type applyConfigReqEntBeta struct { type applyConfigReqEntBeta struct {
IP []net.IP `json:"ip"` IP []netip.Addr `json:"ip"`
Port int `json:"port"` Port int `json:"port"`
} }
// applyConfigReqBeta is a struct representing new client's config setting // applyConfigReqBeta is a struct representing new client's config setting

View File

@@ -3,6 +3,7 @@ package home
import ( import (
"fmt" "fmt"
"net" "net"
"net/netip"
"net/url" "net/url"
"os" "os"
"path/filepath" "path/filepath"
@@ -31,7 +32,10 @@ const (
// Called by other modules when configuration is changed // Called by other modules when configuration is changed
func onConfigModified() { func onConfigModified() {
_ = config.write() err := config.write()
if err != nil {
log.Error("writing config: %s", err)
}
} }
// initDNSServer creates an instance of the dnsforward.Server // initDNSServer creates an instance of the dnsforward.Server
@@ -71,11 +75,11 @@ func initDNSServer() (err error) {
} }
Context.queryLog = querylog.New(conf) Context.queryLog = querylog.New(conf)
filterConf := config.DNS.DnsfilterConf Context.filters, err = filtering.New(config.DNS.DnsfilterConf, nil)
filterConf.EtcHosts = Context.etcHosts if err != nil {
filterConf.ConfigModified = onConfigModified // Don't wrap the error, since it's informative enough as is.
filterConf.HTTPRegister = httpRegister return err
Context.dnsFilter = filtering.New(&filterConf, nil) }
var privateNets netutil.SubnetSet var privateNets netutil.SubnetSet
switch len(config.DNS.PrivateNets) { switch len(config.DNS.PrivateNets) {
@@ -83,13 +87,10 @@ func initDNSServer() (err error) {
// Use an optimized locally-served matcher. // Use an optimized locally-served matcher.
privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed) privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed)
case 1: case 1:
var n *net.IPNet privateNets, err = netutil.ParseSubnet(config.DNS.PrivateNets[0])
n, err = netutil.ParseSubnet(config.DNS.PrivateNets[0])
if err != nil { if err != nil {
return fmt.Errorf("preparing the set of private subnets: %w", err) return fmt.Errorf("preparing the set of private subnets: %w", err)
} }
privateNets = n
default: default:
var nets []*net.IPNet var nets []*net.IPNet
nets, err = netutil.ParseSubnets(config.DNS.PrivateNets...) nets, err = netutil.ParseSubnets(config.DNS.PrivateNets...)
@@ -101,15 +102,13 @@ func initDNSServer() (err error) {
} }
p := dnsforward.DNSCreateParams{ p := dnsforward.DNSCreateParams{
DNSFilter: Context.dnsFilter, DNSFilter: Context.filters,
Stats: Context.stats, Stats: Context.stats,
QueryLog: Context.queryLog, QueryLog: Context.queryLog,
PrivateNets: privateNets, PrivateNets: privateNets,
Anonymizer: anonymizer, Anonymizer: anonymizer,
LocalDomain: config.DHCP.LocalDomainName, LocalDomain: config.DHCP.LocalDomainName,
} DHCPServer: Context.dhcpServer,
if Context.dhcpServer != nil {
p.DHCPServer = Context.dhcpServer
} }
Context.dnsServer, err = dnsforward.NewServer(p) Context.dnsServer, err = dnsforward.NewServer(p)
@@ -143,7 +142,6 @@ func initDNSServer() (err error) {
Context.whois = initWHOIS(&Context.clients) Context.whois = initWHOIS(&Context.clients)
} }
Context.filters.Init()
return nil return nil
} }
@@ -167,33 +165,27 @@ func onDNSRequest(pctx *proxy.DNSContext) {
} }
} }
func ipsToTCPAddrs(ips []net.IP, port int) (tcpAddrs []*net.TCPAddr) { func ipsToTCPAddrs(ips []netip.Addr, port int) (tcpAddrs []*net.TCPAddr) {
if ips == nil { if ips == nil {
return nil return nil
} }
tcpAddrs = make([]*net.TCPAddr, len(ips)) tcpAddrs = make([]*net.TCPAddr, 0, len(ips))
for i, ip := range ips { for _, ip := range ips {
tcpAddrs[i] = &net.TCPAddr{ tcpAddrs = append(tcpAddrs, net.TCPAddrFromAddrPort(netip.AddrPortFrom(ip, uint16(port))))
IP: ip,
Port: port,
}
} }
return tcpAddrs return tcpAddrs
} }
func ipsToUDPAddrs(ips []net.IP, port int) (udpAddrs []*net.UDPAddr) { func ipsToUDPAddrs(ips []netip.Addr, port int) (udpAddrs []*net.UDPAddr) {
if ips == nil { if ips == nil {
return nil return nil
} }
udpAddrs = make([]*net.UDPAddr, len(ips)) udpAddrs = make([]*net.UDPAddr, 0, len(ips))
for i, ip := range ips { for _, ip := range ips {
udpAddrs[i] = &net.UDPAddr{ udpAddrs = append(udpAddrs, net.UDPAddrFromAddrPort(netip.AddrPortFrom(ip, uint16(port))))
IP: ip,
Port: port,
}
} }
return udpAddrs return udpAddrs
@@ -203,7 +195,7 @@ func generateServerConfig() (newConf dnsforward.ServerConfig, err error) {
dnsConf := config.DNS dnsConf := config.DNS
hosts := dnsConf.BindHosts hosts := dnsConf.BindHosts
if len(hosts) == 0 { if len(hosts) == 0 {
hosts = []net.IP{{127, 0, 0, 1}} hosts = []netip.Addr{aghnet.IPv4Localhost()}
} }
newConf = dnsforward.ServerConfig{ newConf = dnsforward.ServerConfig{
@@ -257,7 +249,7 @@ func generateServerConfig() (newConf dnsforward.ServerConfig, err error) {
return newConf, nil return newConf, nil
} }
func newDNSCrypt(hosts []net.IP, tlsConf tlsConfigSettings) (dnscc dnsforward.DNSCryptConfig, err error) { func newDNSCrypt(hosts []netip.Addr, tlsConf tlsConfigSettings) (dnscc dnsforward.DNSCryptConfig, err error) {
if tlsConf.DNSCryptConfigFile == "" { if tlsConf.DNSCryptConfigFile == "" {
return dnscc, errors.Error("no dnscrypt_config_file") return dnscc, errors.Error("no dnscrypt_config_file")
} }
@@ -335,9 +327,12 @@ func getDNSEncryption() (de dnsEncryption) {
// applyAdditionalFiltering adds additional client information and settings if // applyAdditionalFiltering adds additional client information and settings if
// the client has them. // the client has them.
func applyAdditionalFiltering(clientIP net.IP, clientID string, setts *filtering.Settings) { func applyAdditionalFiltering(clientIP net.IP, clientID string, setts *filtering.Settings) {
Context.dnsFilter.ApplyBlockedServices(setts, nil, true) // pref is a prefix for logging messages around the scope.
const pref = "applying filters"
log.Debug("looking up settings for client with ip %s and clientid %q", clientIP, clientID) Context.filters.ApplyBlockedServices(setts, nil)
log.Debug("%s: looking for client with ip %s and clientid %q", pref, clientIP, clientID)
if clientIP == nil { if clientIP == nil {
return return
@@ -349,16 +344,16 @@ func applyAdditionalFiltering(clientIP net.IP, clientID string, setts *filtering
if !ok { if !ok {
c, ok = Context.clients.Find(clientIP.String()) c, ok = Context.clients.Find(clientIP.String())
if !ok { if !ok {
log.Debug("client with ip %s and clientid %q not found", clientIP, clientID) log.Debug("%s: no clients with ip %s and clientid %q", pref, clientIP, clientID)
return return
} }
} }
log.Debug("using settings for client %q with ip %s and clientid %q", c.Name, clientIP, clientID) log.Debug("%s: using settings for client %q (%s; %q)", pref, c.Name, clientIP, clientID)
if c.UseOwnBlockedServices { if c.UseOwnBlockedServices {
Context.dnsFilter.ApplyBlockedServices(setts, c.BlockedServices, false) Context.filters.ApplyBlockedServices(setts, c.BlockedServices)
} }
setts.ClientName = c.Name setts.ClientName = c.Name
@@ -381,7 +376,7 @@ func startDNSServer() error {
return fmt.Errorf("unable to start forwarding DNS server: Already running") return fmt.Errorf("unable to start forwarding DNS server: Already running")
} }
enableFiltersLocked(false) Context.filters.EnableFilters(false)
Context.clients.Start() Context.clients.Start()
@@ -390,7 +385,6 @@ func startDNSServer() error {
return fmt.Errorf("couldn't start forwarding DNS server: %w", err) return fmt.Errorf("couldn't start forwarding DNS server: %w", err)
} }
Context.dnsFilter.Start()
Context.filters.Start() Context.filters.Start()
Context.stats.Start() Context.stats.Start()
Context.queryLog.Start() Context.queryLog.Start()
@@ -449,10 +443,7 @@ func closeDNSServer() {
Context.dnsServer = nil Context.dnsServer = nil
} }
if Context.dnsFilter != nil { Context.filters.Close()
Context.dnsFilter.Close()
Context.dnsFilter = nil
}
if Context.stats != nil { if Context.stats != nil {
err := Context.stats.Close() err := Context.stats.Close()
@@ -469,7 +460,5 @@ func closeDNSServer() {
Context.queryLog = nil Context.queryLog = nil
} }
Context.filters.Close() log.Debug("all dns modules are closed")
log.Debug("Closed all DNS modules")
} }

View File

@@ -10,6 +10,7 @@ import (
"net" "net"
"net/http" "net/http"
"net/http/pprof" "net/http/pprof"
"net/netip"
"net/url" "net/url"
"os" "os"
"os/signal" "os/signal"
@@ -20,6 +21,7 @@ import (
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg" "github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/AdGuardHome/internal/aghtls" "github.com/AdguardTeam/AdGuardHome/internal/aghtls"
@@ -33,6 +35,7 @@ import (
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
"golang.org/x/exp/slices"
"gopkg.in/natefinch/lumberjack.v2" "gopkg.in/natefinch/lumberjack.v2"
) )
@@ -52,10 +55,9 @@ type homeContext struct {
dnsServer *dnsforward.Server // DNS module dnsServer *dnsforward.Server // DNS module
rdns *RDNS // rDNS module rdns *RDNS // rDNS module
whois *WHOIS // WHOIS module whois *WHOIS // WHOIS module
dnsFilter *filtering.DNSFilter // DNS filtering module
dhcpServer dhcpd.Interface // DHCP module dhcpServer dhcpd.Interface // DHCP module
auth *Auth // HTTP authentication module auth *Auth // HTTP authentication module
filters Filtering // DNS filtering module filters *filtering.DNSFilter // DNS filtering module
web *Web // Web (HTTP, HTTPS) module web *Web // Web (HTTP, HTTPS) module
tls *TLSMod // TLS module tls *TLSMod // TLS module
// etcHosts is an IP-hostname pairs set taken from system configuration // etcHosts is an IP-hostname pairs set taken from system configuration
@@ -140,7 +142,12 @@ func setupContext(args options) {
checkPermissions() checkPermissions()
} }
initConfig() switch version.Channel() {
case version.ChannelEdge, version.ChannelDevelopment:
config.BetaBindPort = 3001
default:
// Go on.
}
Context.tlsRoots = LoadSystemRootCAs() Context.tlsRoots = LoadSystemRootCAs()
Context.transport = &http.Transport{ Context.transport = &http.Transport{
@@ -265,6 +272,14 @@ func setupHostsContainer() (err error) {
} }
func setupConfig(args options) (err error) { func setupConfig(args options) (err error) {
config.DNS.DnsfilterConf.EtcHosts = Context.etcHosts
config.DNS.DnsfilterConf.ConfigModified = onConfigModified
config.DNS.DnsfilterConf.HTTPRegister = httpRegister
config.DNS.DnsfilterConf.DataDir = Context.getDataDir()
config.DNS.DnsfilterConf.Filters = slices.Clone(config.Filters)
config.DNS.DnsfilterConf.WhitelistFilters = slices.Clone(config.WhitelistFilters)
config.DNS.DnsfilterConf.HTTPClient = Context.client
config.DHCP.WorkDir = Context.workDir config.DHCP.WorkDir = Context.workDir
config.DHCP.HTTPRegister = httpRegister config.DHCP.HTTPRegister = httpRegister
config.DHCP.ConfigModified = onConfigModified config.DHCP.ConfigModified = onConfigModified
@@ -325,7 +340,7 @@ func setupConfig(args options) (err error) {
} }
// override bind host/port from the console // override bind host/port from the console
if args.bindHost != nil { if args.bindHost.IsValid() {
config.BindHost = args.bindHost config.BindHost = args.bindHost
} }
if len(args.pidFile) != 0 && writePIDFile(args.pidFile) { if len(args.pidFile) != 0 && writePIDFile(args.pidFile) {
@@ -384,8 +399,6 @@ func fatalOnError(err error) {
// run configures and starts AdGuard Home. // run configures and starts AdGuard Home.
func run(args options, clientBuildFS fs.FS) { func run(args options, clientBuildFS fs.FS) {
var err error
// configure config filename // configure config filename
initConfigFilename(args) initConfigFilename(args)
@@ -404,7 +417,7 @@ func run(args options, clientBuildFS fs.FS) {
setupContext(args) setupContext(args)
err = configureOS(config) err := configureOS(config)
fatalOnError(err) fatalOnError(err)
// clients package uses filtering package's static data (filtering.BlockedSvcKnown()), // clients package uses filtering package's static data (filtering.BlockedSvcKnown()),
@@ -525,7 +538,7 @@ func checkPermissions() {
} }
// We should check if AdGuard Home is able to bind to port 53 // We should check if AdGuard Home is able to bind to port 53
err := aghnet.CheckPort("tcp", net.IP{127, 0, 0, 1}, defaultPortDNS) err := aghnet.CheckPort("tcp", netip.AddrPortFrom(aghnet.IPv4Localhost(), defaultPortDNS))
if err != nil { if err != nil {
if errors.Is(err, os.ErrPermission) { if errors.Is(err, os.ErrPermission) {
log.Fatal(`Permission check failed. log.Fatal(`Permission check failed.
@@ -763,12 +776,12 @@ func printHTTPAddresses(proto string) {
} }
port := config.BindPort port := config.BindPort
if proto == schemeHTTPS { if proto == aghhttp.SchemeHTTPS {
port = tlsConf.PortHTTPS port = tlsConf.PortHTTPS
} }
// TODO(e.burkov): Inspect and perhaps merge with the previous condition. // TODO(e.burkov): Inspect and perhaps merge with the previous condition.
if proto == schemeHTTPS && tlsConf.ServerName != "" { if proto == aghhttp.SchemeHTTPS && tlsConf.ServerName != "" {
printWebAddrs(proto, tlsConf.ServerName, tlsConf.PortHTTPS, 0) printWebAddrs(proto, tlsConf.ServerName, tlsConf.PortHTTPS, 0)
return return

View File

@@ -8,6 +8,7 @@ import (
"net/url" "net/url"
"path" "path"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
@@ -82,7 +83,7 @@ func encodeMobileConfig(d *dnsSettings, clientID string) ([]byte, error) {
case dnsProtoHTTPS: case dnsProtoHTTPS:
dspName = fmt.Sprintf("%s DoH", d.ServerName) dspName = fmt.Sprintf("%s DoH", d.ServerName)
u := &url.URL{ u := &url.URL{
Scheme: schemeHTTPS, Scheme: aghhttp.SchemeHTTPS,
Host: d.ServerName, Host: d.ServerName,
Path: path.Join("/dns-query", clientID), Path: path.Join("/dns-query", clientID),
} }

View File

@@ -3,12 +3,11 @@ package home
import ( import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/netip"
"testing" "testing"
"github.com/AdguardTeam/golibs/netutil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"howett.net/plist" "howett.net/plist"
@@ -28,7 +27,7 @@ func setupDNSIPs(t testing.TB) {
config = &configuration{ config = &configuration{
DNS: dnsConfig{ DNS: dnsConfig{
BindHosts: []net.IP{netutil.IPv4Zero()}, BindHosts: []netip.Addr{netip.IPv4Unspecified()},
Port: defaultPortDNS, Port: defaultPortDNS,
}, },
} }

View File

@@ -2,7 +2,7 @@ package home
import ( import (
"fmt" "fmt"
"net" "net/netip"
"os" "os"
"strconv" "strconv"
@@ -12,15 +12,15 @@ import (
// options passed from command-line arguments // options passed from command-line arguments
type options struct { type options struct {
verbose bool // is verbose logging enabled verbose bool // is verbose logging enabled
configFilename string // path to the config file configFilename string // path to the config file
workDir string // path to the working directory where we will store the filters data and the querylog workDir string // path to the working directory where we will store the filters data and the querylog
bindHost net.IP // host address to bind HTTP server on bindHost netip.Addr // host address to bind HTTP server on
bindPort int // port to serve HTTP pages on bindPort int // port to serve HTTP pages on
logFile string // Path to the log file. If empty, write to stdout. If "syslog", writes to syslog logFile string // Path to the log file. If empty, write to stdout. If "syslog", writes to syslog
pidFile string // File name to save PID to pidFile string // File name to save PID to
checkConfig bool // Check configuration and exit checkConfig bool // Check configuration and exit
disableUpdate bool // If set, don't check for updates disableUpdate bool // If set, don't check for updates
// service control action (see service.ControlAction array + "status" command) // service control action (see service.ControlAction array + "status" command)
serviceControlAction string serviceControlAction string
@@ -60,8 +60,8 @@ type arg struct {
// against its zero value and return nil if the parameter value is // against its zero value and return nil if the parameter value is
// zero otherwise they return a string slice of the parameter // zero otherwise they return a string slice of the parameter
func ipSliceOrNil(ip net.IP) []string { func ipSliceOrNil(ip netip.Addr) []string {
if ip == nil { if !ip.IsValid() {
return nil return nil
} }
@@ -113,7 +113,7 @@ var workDirArg = arg{
var hostArg = arg{ var hostArg = arg{
"Host address to bind HTTP server on.", "Host address to bind HTTP server on.",
"host", "h", "host", "h",
func(o options, v string) (options, error) { o.bindHost = net.ParseIP(v); return o, nil }, nil, nil, func(o options, v string) (options, error) { o.bindHost, _ = netip.ParseAddr(v); return o, nil }, nil, nil,
func(o options) []string { return ipSliceOrNil(o.bindHost) }, func(o options) []string { return ipSliceOrNil(o.bindHost) },
} }

View File

@@ -2,7 +2,7 @@ package home
import ( import (
"fmt" "fmt"
"net" "net/netip"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -56,11 +56,13 @@ func TestParseWorkDir(t *testing.T) {
} }
func TestParseBindHost(t *testing.T) { func TestParseBindHost(t *testing.T) {
assert.Nil(t, testParseOK(t).bindHost, "empty is not host") wantAddr := netip.AddrFrom4([4]byte{1, 2, 3, 4})
assert.Equal(t, net.IPv4(1, 2, 3, 4), testParseOK(t, "-h", "1.2.3.4").bindHost, "-h is host")
assert.Zero(t, testParseOK(t).bindHost, "empty is not host")
assert.Equal(t, wantAddr, testParseOK(t, "-h", "1.2.3.4").bindHost, "-h is host")
testParseParamMissing(t, "-h") testParseParamMissing(t, "-h")
assert.Equal(t, net.IPv4(1, 2, 3, 4), testParseOK(t, "--host", "1.2.3.4").bindHost, "--host is host") assert.Equal(t, wantAddr, testParseOK(t, "--host", "1.2.3.4").bindHost, "--host is host")
testParseParamMissing(t, "--host") testParseParamMissing(t, "--host")
} }
@@ -149,7 +151,7 @@ func TestSerialize(t *testing.T) {
ss: []string{"-w", "path"}, ss: []string{"-w", "path"},
}, { }, {
name: "bind_host", name: "bind_host",
opts: options{bindHost: net.IP{1, 2, 3, 4}}, opts: options{bindHost: netip.AddrFrom4([4]byte{1, 2, 3, 4})},
ss: []string{"-h", "1.2.3.4"}, ss: []string{"-h", "1.2.3.4"},
}, { }, {
name: "bind_port", name: "bind_port",

View File

@@ -11,6 +11,7 @@ import (
"syscall" "syscall"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/AdGuardHome/internal/version" "github.com/AdguardTeam/AdGuardHome/internal/version"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
@@ -277,7 +278,7 @@ AdGuard Home is successfully installed and will automatically start on boot.
There are a few more things that must be configured before you can use it. There are a few more things that must be configured before you can use it.
Click on the link below and follow the Installation Wizard steps to finish setup. Click on the link below and follow the Installation Wizard steps to finish setup.
AdGuard Home is now available at the following addresses:`) AdGuard Home is now available at the following addresses:`)
printHTTPAddresses(schemeHTTP) printHTTPAddresses(aghhttp.SchemeHTTP)
} }
} }

View File

@@ -4,6 +4,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/golibs/timeutil" "github.com/AdguardTeam/golibs/timeutil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -160,7 +161,7 @@ func assertEqualExcept(t *testing.T, oldConf, newConf yobj, oldKeys, newKeys []s
} }
func testDiskConf(schemaVersion int) (diskConf yobj) { func testDiskConf(schemaVersion int) (diskConf yobj) {
filters := []filter{{ filters := []filtering.FilterYAML{{
URL: "https://filters.adtidy.org/android/filters/111_optimized.txt", URL: "https://filters.adtidy.org/android/filters/111_optimized.txt",
Name: "Latvian filter", Name: "Latvian filter",
RulesCount: 100, RulesCount: 100,

View File

@@ -4,11 +4,12 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"io/fs" "io/fs"
"net"
"net/http" "net/http"
"net/netip"
"sync" "sync"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghtls" "github.com/AdguardTeam/AdGuardHome/internal/aghtls"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
@@ -19,12 +20,6 @@ import (
"golang.org/x/net/http2/h2c" "golang.org/x/net/http2/h2c"
) )
// HTTP scheme constants.
const (
schemeHTTP = "http"
schemeHTTPS = "https"
)
const ( const (
// readTimeout is the maximum duration for reading the entire request, // readTimeout is the maximum duration for reading the entire request,
// including the body. // including the body.
@@ -40,7 +35,7 @@ type webConfig struct {
clientFS fs.FS clientFS fs.FS
clientBetaFS fs.FS clientBetaFS fs.FS
BindHost net.IP BindHost netip.Addr
BindPort int BindPort int
BetaBindPort int BetaBindPort int
PortHTTPS int PortHTTPS int
@@ -119,8 +114,11 @@ func CreateWeb(conf *webConfig) *Web {
// WebCheckPortAvailable - check if port is available // WebCheckPortAvailable - check if port is available
// BUT: if we are already using this port, no need // BUT: if we are already using this port, no need
func WebCheckPortAvailable(port int) bool { func WebCheckPortAvailable(port int) bool {
return Context.web.httpsServer.server != nil || if Context.web.httpsServer.server != nil {
aghnet.CheckPort("tcp", config.BindHost, port) == nil return true
}
return aghnet.CheckPort("tcp", netip.AddrPortFrom(config.BindHost, uint16(port))) == nil
} }
// TLSConfigChanged updates the TLS configuration and restarts the HTTPS server // TLSConfigChanged updates the TLS configuration and restarts the HTTPS server
@@ -166,7 +164,7 @@ func (web *Web) Start() {
// this loop is used as an ability to change listening host and/or port // this loop is used as an ability to change listening host and/or port
for !web.httpsServer.shutdown { for !web.httpsServer.shutdown {
printHTTPAddresses(schemeHTTP) printHTTPAddresses(aghhttp.SchemeHTTP)
errs := make(chan error, 2) errs := make(chan error, 2)
// Use an h2c handler to support unencrypted HTTP/2, e.g. for proxies. // Use an h2c handler to support unencrypted HTTP/2, e.g. for proxies.
@@ -286,7 +284,7 @@ func (web *Web) tlsServerLoop() {
WriteTimeout: web.conf.WriteTimeout, WriteTimeout: web.conf.WriteTimeout,
} }
printHTTPAddresses(schemeHTTPS) printHTTPAddresses(aghhttp.SchemeHTTPS)
err := web.httpsServer.server.ListenAndServeTLS("", "") err := web.httpsServer.server.ListenAndServeTLS("", "")
if err != http.ErrServerClosed { if err != http.ErrServerClosed {
cleanupAlways() cleanupAlways()

View File

@@ -9,8 +9,8 @@ require (
github.com/kisielk/errcheck v1.6.2 github.com/kisielk/errcheck v1.6.2
github.com/kyoh86/looppointer v0.1.7 github.com/kyoh86/looppointer v0.1.7
github.com/securego/gosec/v2 v2.13.1 github.com/securego/gosec/v2 v2.13.1
golang.org/x/tools v0.1.13-0.20220803210227-8b9a1fbdf5c3 golang.org/x/tools v0.1.13-0.20220921142454-16b974289fe5
golang.org/x/vuln v0.0.0-20220912202342-0ed43f12cb05 golang.org/x/vuln v0.0.0-20220921153644-d9be10b6cc84
honnef.co/go/tools v0.3.3 honnef.co/go/tools v0.3.3
mvdan.cc/gofumpt v0.3.1 mvdan.cc/gofumpt v0.3.1
mvdan.cc/unparam v0.0.0-20220831102321-2fc90a84c7ec mvdan.cc/unparam v0.0.0-20220831102321-2fc90a84c7ec
@@ -25,10 +25,10 @@ require (
github.com/kyoh86/nolint v0.0.1 // indirect github.com/kyoh86/nolint v0.0.1 // indirect
github.com/nbutton23/zxcvbn-go v0.0.0-20210217022336-fa2cb2858354 // indirect github.com/nbutton23/zxcvbn-go v0.0.0-20210217022336-fa2cb2858354 // indirect
github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778 // indirect github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778 // indirect
golang.org/x/exp v0.0.0-20220909182711-5c715a9e8561 // indirect golang.org/x/exp v0.0.0-20220921023135-46d9e7742f1e // indirect
golang.org/x/exp/typeparams v0.0.0-20220827204233-334a2380cb91 // indirect golang.org/x/exp/typeparams v0.0.0-20220827204233-334a2380cb91 // indirect
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 // indirect golang.org/x/mod v0.6.0-dev.0.20220907135952-02c991387e35 // indirect
golang.org/x/sync v0.0.0-20220819030929-7fc1605a5dde // indirect golang.org/x/sync v0.0.0-20220819030929-7fc1605a5dde // indirect
golang.org/x/sys v0.0.0-20220909162455-aba9fc2a8ff2 // indirect golang.org/x/sys v0.0.0-20220919091848-fb04ddd9f9c8 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect
) )

View File

@@ -55,15 +55,15 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/exp v0.0.0-20220909182711-5c715a9e8561 h1:MDc5xs78ZrZr3HMQugiXOAkSZtfTpbJLDr/lwfgO53E= golang.org/x/exp v0.0.0-20220921023135-46d9e7742f1e h1:Ctm9yurWsg7aWwIpH9Bnap/IdSVxixymIb3MhiMEQQA=
golang.org/x/exp v0.0.0-20220909182711-5c715a9e8561/go.mod h1:cyybsKvd6eL0RnXn6p/Grxp8F5bW7iYuBgsNCOHpMYE= golang.org/x/exp v0.0.0-20220921023135-46d9e7742f1e/go.mod h1:cyybsKvd6eL0RnXn6p/Grxp8F5bW7iYuBgsNCOHpMYE=
golang.org/x/exp/typeparams v0.0.0-20220827204233-334a2380cb91 h1:Ic/qN6TEifvObMGQy72k0n1LlJr7DjWWEi+MOsDOiSk= golang.org/x/exp/typeparams v0.0.0-20220827204233-334a2380cb91 h1:Ic/qN6TEifvObMGQy72k0n1LlJr7DjWWEi+MOsDOiSk=
golang.org/x/exp/typeparams v0.0.0-20220827204233-334a2380cb91/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk= golang.org/x/exp/typeparams v0.0.0-20220827204233-334a2380cb91/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY= golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 h1:6zppjxzCulZykYSLyVDYbneBfbaBIQPYMevg0bEwv2s= golang.org/x/mod v0.6.0-dev.0.20220907135952-02c991387e35 h1:CZP0Rbk/s1EIiUMx5DS2MhK2ct52xpQxqddVD0FmF+o=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.6.0-dev.0.20220907135952-02c991387e35/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
@@ -86,8 +86,8 @@ golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220909162455-aba9fc2a8ff2 h1:wM1k/lXfpc5HdkJJyW9GELpd8ERGdnh8sMGL6Gzq3Ho= golang.org/x/sys v0.0.0-20220919091848-fb04ddd9f9c8 h1:h+EGohizhe9XlX18rfpa8k8RAc5XyaeamM+0VHRd4lc=
golang.org/x/sys v0.0.0-20220909162455-aba9fc2a8ff2/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220919091848-fb04ddd9f9c8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
@@ -100,10 +100,10 @@ golang.org/x/tools v0.0.0-20200710042808-f1c4188a97a1/go.mod h1:njjCfa9FT2d7l9Bc
golang.org/x/tools v0.0.0-20201007032633-0806396f153e/go.mod h1:z6u4i615ZeAfBE4XtMziQW1fSVJXACjjbWkB/mvPzlU= golang.org/x/tools v0.0.0-20201007032633-0806396f153e/go.mod h1:z6u4i615ZeAfBE4XtMziQW1fSVJXACjjbWkB/mvPzlU=
golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E= golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E=
golang.org/x/tools v0.1.13-0.20220803210227-8b9a1fbdf5c3 h1:aE4T3aJwdCNz+s35ScSQYUzeGu7BOLDHZ1bBHVurqqY= golang.org/x/tools v0.1.13-0.20220921142454-16b974289fe5 h1:o1LhIiY5L+hLK9DWqfFlilCrpZnw/s7WU4iCUkb/bao=
golang.org/x/tools v0.1.13-0.20220803210227-8b9a1fbdf5c3/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.1.13-0.20220921142454-16b974289fe5/go.mod h1:VsjNM1dMo+Ofkp5d7y7fOdQZD8MTXSQ4w3EPk65AvKU=
golang.org/x/vuln v0.0.0-20220912202342-0ed43f12cb05 h1:NWQHMTdThZhCArzUbnu1Bh+l3LdwUfjZws+ivBR2sxM= golang.org/x/vuln v0.0.0-20220921153644-d9be10b6cc84 h1:L0qUjdplndgX880fozFRGC242wAtfsViyRXWGlpZQ54=
golang.org/x/vuln v0.0.0-20220912202342-0ed43f12cb05/go.mod h1:7tDfEDtOLlzHQRi4Yzfg5seVBSvouUIjyPzBx4q5CxQ= golang.org/x/vuln v0.0.0-20220921153644-d9be10b6cc84/go.mod h1:7tDfEDtOLlzHQRi4Yzfg5seVBSvouUIjyPzBx4q5CxQ=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

View File

@@ -219,10 +219,7 @@ exit_on_output gofumpt --extra -e -l .
"$GO" vet ./... "$GO" vet ./...
# TODO(a.garipov): Reenable this once https://github.com/golang/go/issues/55035 govulncheck ./...
# is fixed.
#
# govulncheck ./...
# Apply more lax standards to the code we haven't properly refactored yet. # Apply more lax standards to the code we haven't properly refactored yet.
gocyclo --over 17 ./internal/querylog/ gocyclo --over 17 ./internal/querylog/

View File

@@ -85,11 +85,7 @@ in
# num_commits_since_minor is the number of commits since the last new # num_commits_since_minor is the number of commits since the last new
# minor release. If the current commit is the new minor release, # minor release. If the current commit is the new minor release,
# num_commits_since_minor is zero. # num_commits_since_minor is zero.
num_commits_since_minor="$( git rev-list "${last_minor_zero}..HEAD" | wc -l )" num_commits_since_minor="$( git rev-list --count "${last_minor_zero}..HEAD" )"
# The output of darwin's implementation of wc needs to be trimmed from
# redundant spaces.
num_commits_since_minor="$( echo "$num_commits_since_minor" | tr -d '[:space:]' )"
readonly num_commits_since_minor readonly num_commits_since_minor
# next_minor is the next minor release version. # next_minor is the next minor release version.