Merge branch 'master' into websvc-confin-manager
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
package aghhttp
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -10,6 +11,12 @@ import (
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// HTTP scheme constants.
|
||||
const (
|
||||
SchemeHTTP = "http"
|
||||
SchemeHTTPS = "https"
|
||||
)
|
||||
|
||||
// RegisterFunc is the function that sets the handler to handle the URL for the
|
||||
// method.
|
||||
//
|
||||
@@ -26,7 +33,7 @@ func OK(w http.ResponseWriter) {
|
||||
// Error writes formatted message to w and also logs it.
|
||||
func Error(r *http.Request, w http.ResponseWriter, code int, format string, args ...any) {
|
||||
text := fmt.Sprintf(format, args...)
|
||||
log.Error("%s %s: %s", r.Method, r.URL, text)
|
||||
log.Error("%s %s %s: %s", r.Method, r.Host, r.URL, text)
|
||||
http.Error(w, text, code)
|
||||
}
|
||||
|
||||
@@ -35,3 +42,34 @@ func Error(r *http.Request, w http.ResponseWriter, code int, format string, args
|
||||
func UserAgent() (ua string) {
|
||||
return fmt.Sprintf("AdGuardHome/%s", version.Version())
|
||||
}
|
||||
|
||||
// textPlainDeprMsg is the message returned to API users when they try to use
|
||||
// an API that used to accept "text/plain" but doesn't anymore.
|
||||
const textPlainDeprMsg = `using this api with the text/plain content-type is deprecated; ` +
|
||||
`use application/json`
|
||||
|
||||
// WriteTextPlainDeprecated responds to the request with a message about
|
||||
// deprecation and removal of a plain-text API if the request is made with the
|
||||
// "text/plain" content-type.
|
||||
func WriteTextPlainDeprecated(w http.ResponseWriter, r *http.Request) (isPlainText bool) {
|
||||
if r.Header.Get(HdrNameContentType) != HdrValTextPlain {
|
||||
return false
|
||||
}
|
||||
|
||||
Error(r, w, http.StatusUnsupportedMediaType, textPlainDeprMsg)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// WriteJSONResponse sets the content-type header in w.Header() to
|
||||
// "application/json", encodes resp to w, calls Error on any returned error, and
|
||||
// returns it as well.
|
||||
func WriteJSONResponse(w http.ResponseWriter, r *http.Request, resp any) (err error) {
|
||||
w.Header().Set(HdrNameContentType, HdrValApplicationJSON)
|
||||
err = json.NewEncoder(w).Encode(resp)
|
||||
if err != nil {
|
||||
Error(r, w, http.StatusInternalServerError, "encoding resp: %s", err)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -8,8 +8,8 @@ package aghhttp
|
||||
const (
|
||||
HdrNameAcceptEncoding = "Accept-Encoding"
|
||||
HdrNameAccessControlAllowOrigin = "Access-Control-Allow-Origin"
|
||||
HdrNameContentType = "Content-Type"
|
||||
HdrNameContentEncoding = "Content-Encoding"
|
||||
HdrNameContentType = "Content-Type"
|
||||
HdrNameServer = "Server"
|
||||
HdrNameTrailer = "Trailer"
|
||||
HdrNameUserAgent = "User-Agent"
|
||||
@@ -18,4 +18,5 @@ const (
|
||||
// HTTP header value constants.
|
||||
const (
|
||||
HdrValApplicationJSON = "application/json"
|
||||
HdrValTextPlain = "text/plain"
|
||||
)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build darwin || freebsd
|
||||
// +build darwin freebsd
|
||||
|
||||
package aghnet
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build darwin || freebsd
|
||||
// +build darwin freebsd
|
||||
|
||||
package aghnet
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package aghnet
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package aghnet
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build openbsd
|
||||
// +build openbsd
|
||||
|
||||
package aghnet
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build openbsd
|
||||
// +build openbsd
|
||||
|
||||
package aghnet
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package aghnet
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package aghnet
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
|
||||
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
|
||||
//go:build darwin || freebsd || linux || openbsd
|
||||
|
||||
package aghnet
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package aghnet
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package aghnet
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build !(windows || linux)
|
||||
// +build !windows,!linux
|
||||
|
||||
package aghnet
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package aghnet
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package aghnet
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build aix || darwin || dragonfly || freebsd || netbsd || openbsd || solaris
|
||||
// +build aix darwin dragonfly freebsd netbsd openbsd solaris
|
||||
//go:build darwin || freebsd || openbsd
|
||||
|
||||
package aghnet
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package aghnet
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package aghnet
|
||||
|
||||
@@ -19,27 +18,18 @@ import (
|
||||
|
||||
// How to test on a real Linux machine:
|
||||
//
|
||||
// 1. Run:
|
||||
// 1. Run "sudo ipset create example_set hash:ip family ipv4".
|
||||
//
|
||||
// sudo ipset create example_set hash:ip family ipv4
|
||||
// 2. Run "sudo ipset list example_set". The Members field should be empty.
|
||||
//
|
||||
// 2. Run:
|
||||
// 3. Add the line "example.com/example_set" to your AdGuardHome.yaml.
|
||||
//
|
||||
// sudo ipset list example_set
|
||||
// 4. Start AdGuardHome.
|
||||
//
|
||||
// The Members field should be empty.
|
||||
// 5. Make requests to example.com and its subdomains.
|
||||
//
|
||||
// 3. Add the line "example.com/example_set" to your AdGuardHome.yaml.
|
||||
//
|
||||
// 4. Start AdGuardHome.
|
||||
//
|
||||
// 5. Make requests to example.com and its subdomains.
|
||||
//
|
||||
// 6. Run:
|
||||
//
|
||||
// sudo ipset list example_set
|
||||
//
|
||||
// The Members field should contain the resolved IP addresses.
|
||||
// 6. Run "sudo ipset list example_set". The Members field should contain the
|
||||
// resolved IP addresses.
|
||||
|
||||
// newIpsetMgr returns a new Linux ipset manager.
|
||||
func newIpsetMgr(ipsetConf []string) (set IpsetManager, err error) {
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package aghnet
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build !linux
|
||||
// +build !linux
|
||||
|
||||
package aghnet
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build darwin || freebsd || openbsd
|
||||
// +build darwin freebsd openbsd
|
||||
|
||||
package aghnet
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build darwin
|
||||
// +build darwin
|
||||
|
||||
package aghnet
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build freebsd
|
||||
// +build freebsd
|
||||
|
||||
package aghnet
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build freebsd
|
||||
// +build freebsd
|
||||
|
||||
package aghnet
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package aghnet
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package aghnet
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build openbsd
|
||||
// +build openbsd
|
||||
|
||||
package aghnet
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build openbsd
|
||||
// +build openbsd
|
||||
|
||||
package aghnet
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build openbsd || freebsd || linux || darwin
|
||||
// +build openbsd freebsd linux darwin
|
||||
//go:build darwin || freebsd || linux || openbsd
|
||||
|
||||
package aghnet
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package aghnet
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build !windows
|
||||
// +build !windows
|
||||
|
||||
package aghnet
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build !windows
|
||||
// +build !windows
|
||||
|
||||
package aghnet
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package aghnet
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package aghnet
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build mips || mips64
|
||||
// +build mips mips64
|
||||
|
||||
// This file is an adapted version of github.com/josharian/native.
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build amd64 || 386 || arm || arm64 || mipsle || mips64le || ppc64le
|
||||
// +build amd64 386 arm arm64 mipsle mips64le ppc64le
|
||||
|
||||
// This file is an adapted version of github.com/josharian/native.
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build darwin || netbsd || openbsd
|
||||
// +build darwin netbsd openbsd
|
||||
//go:build darwin || openbsd
|
||||
|
||||
package aghos
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build freebsd
|
||||
// +build freebsd
|
||||
|
||||
package aghos
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package aghos
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build darwin || freebsd || linux || openbsd
|
||||
// +build darwin freebsd linux openbsd
|
||||
|
||||
package aghos
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package aghos
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build !(windows || plan9)
|
||||
// +build !windows,!plan9
|
||||
//go:build !windows
|
||||
|
||||
package aghos
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build windows || plan9
|
||||
// +build windows plan9
|
||||
//go:build windows
|
||||
|
||||
package aghos
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build darwin || freebsd || linux || netbsd || openbsd
|
||||
// +build darwin freebsd linux netbsd openbsd
|
||||
//go:build darwin || freebsd || linux || openbsd
|
||||
|
||||
package aghos
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package aghos
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build freebsd || openbsd
|
||||
// +build freebsd openbsd
|
||||
|
||||
package dhcpd
|
||||
|
||||
@@ -10,11 +9,10 @@ import (
|
||||
// broadcast sends resp to the broadcast address specific for network interface.
|
||||
func (c *dhcpConn) broadcast(respData []byte, peer *net.UDPAddr) (n int, err error) {
|
||||
// Despite the fact that server4.NewIPv4UDPConn explicitly sets socket
|
||||
// options to allow broadcasting, it also binds the connection to a
|
||||
// specific interface. On FreeBSD and OpenBSD net.UDPConn.WriteTo
|
||||
// causes errors while writing to the addresses that belong to another
|
||||
// interface. So, use the broadcast address specific for the interface
|
||||
// bound.
|
||||
// options to allow broadcasting, it also binds the connection to a specific
|
||||
// interface. On FreeBSD and OpenBSD net.UDPConn.WriteTo causes errors
|
||||
// while writing to the addresses that belong to another interface. So, use
|
||||
// the broadcast address specific for the interface bound.
|
||||
peer.IP = c.bcastIP
|
||||
|
||||
return c.udpConn.WriteTo(respData, peer)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build freebsd || openbsd
|
||||
// +build freebsd openbsd
|
||||
|
||||
package dhcpd
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build aix || darwin || dragonfly || linux || netbsd || solaris
|
||||
// +build aix darwin dragonfly linux netbsd solaris
|
||||
//go:build darwin || linux
|
||||
|
||||
package dhcpd
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build aix || darwin || dragonfly || linux || netbsd || solaris
|
||||
// +build aix darwin dragonfly linux netbsd solaris
|
||||
//go:build darwin || linux
|
||||
|
||||
package dhcpd
|
||||
|
||||
|
||||
@@ -1,10 +1,40 @@
|
||||
package dhcpd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
)
|
||||
|
||||
// ServerConfig is the configuration for the DHCP server. The order of YAML
|
||||
// fields is important, since the YAML configuration file follows it.
|
||||
type ServerConfig struct {
|
||||
// Called when the configuration is changed by HTTP request
|
||||
ConfigModified func() `yaml:"-"`
|
||||
|
||||
// Register an HTTP handler
|
||||
HTTPRegister aghhttp.RegisterFunc `yaml:"-"`
|
||||
|
||||
Enabled bool `yaml:"enabled"`
|
||||
InterfaceName string `yaml:"interface_name"`
|
||||
|
||||
// LocalDomainName is the domain name used for DHCP hosts. For example,
|
||||
// a DHCP client with the hostname "myhost" can be addressed as "myhost.lan"
|
||||
// when LocalDomainName is "lan".
|
||||
LocalDomainName string `yaml:"local_domain_name"`
|
||||
|
||||
Conf4 V4ServerConf `yaml:"dhcpv4"`
|
||||
Conf6 V6ServerConf `yaml:"dhcpv6"`
|
||||
|
||||
WorkDir string `yaml:"-"`
|
||||
DBFilePath string `yaml:"-"`
|
||||
}
|
||||
|
||||
// DHCPServer - DHCP server interface
|
||||
type DHCPServer interface {
|
||||
// ResetLeases resets leases.
|
||||
@@ -80,6 +110,86 @@ type V4ServerConf struct {
|
||||
notify func(uint32)
|
||||
}
|
||||
|
||||
// errNilConfig is an error returned by validation method if the config is nil.
|
||||
const errNilConfig errors.Error = "nil config"
|
||||
|
||||
// ensureV4 returns a 4-byte version of ip. An error is returned if the passed
|
||||
// ip is not an IPv4.
|
||||
func ensureV4(ip net.IP) (ip4 net.IP, err error) {
|
||||
if ip == nil {
|
||||
return nil, fmt.Errorf("%v is not an IP address", ip)
|
||||
}
|
||||
|
||||
ip4 = ip.To4()
|
||||
if ip4 == nil {
|
||||
return nil, fmt.Errorf("%v is not an IPv4 address", ip)
|
||||
}
|
||||
|
||||
return ip4, nil
|
||||
}
|
||||
|
||||
// Validate returns an error if c is not a valid configuration.
|
||||
//
|
||||
// TODO(e.burkov): Don't set the config fields when the server itself will stop
|
||||
// containing the config.
|
||||
func (c *V4ServerConf) Validate() (err error) {
|
||||
defer func() { err = errors.Annotate(err, "dhcpv4: %w") }()
|
||||
|
||||
if c == nil {
|
||||
return errNilConfig
|
||||
}
|
||||
|
||||
var gatewayIP net.IP
|
||||
gatewayIP, err = ensureV4(c.GatewayIP)
|
||||
if err != nil {
|
||||
// Don't wrap an errors since it's inforative enough as is and there is
|
||||
// an annotation deferred already.
|
||||
return err
|
||||
}
|
||||
|
||||
if c.SubnetMask == nil {
|
||||
return fmt.Errorf("invalid subnet mask: %v", c.SubnetMask)
|
||||
}
|
||||
|
||||
subnetMask := net.IPMask(netutil.CloneIP(c.SubnetMask.To4()))
|
||||
c.subnet = &net.IPNet{
|
||||
IP: gatewayIP,
|
||||
Mask: subnetMask,
|
||||
}
|
||||
c.broadcastIP = aghnet.BroadcastFromIPNet(c.subnet)
|
||||
|
||||
c.ipRange, err = newIPRange(c.RangeStart, c.RangeEnd)
|
||||
if err != nil {
|
||||
// Don't wrap an errors since it's inforative enough as is and there is
|
||||
// an annotation deferred already.
|
||||
return err
|
||||
}
|
||||
|
||||
if c.ipRange.contains(gatewayIP) {
|
||||
return fmt.Errorf("gateway ip %v in the ip range: %v-%v",
|
||||
gatewayIP,
|
||||
c.RangeStart,
|
||||
c.RangeEnd,
|
||||
)
|
||||
}
|
||||
|
||||
if !c.subnet.Contains(c.RangeStart) {
|
||||
return fmt.Errorf("range start %v is outside network %v",
|
||||
c.RangeStart,
|
||||
c.subnet,
|
||||
)
|
||||
}
|
||||
|
||||
if !c.subnet.Contains(c.RangeEnd) {
|
||||
return fmt.Errorf("range end %v is outside network %v",
|
||||
c.RangeEnd,
|
||||
c.subnet,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// V6ServerConf - server configuration
|
||||
type V6ServerConf struct {
|
||||
Enabled bool `yaml:"-" json:"-"`
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
|
||||
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
|
||||
//go:build darwin || freebsd || linux || openbsd
|
||||
|
||||
package dhcpd
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
|
||||
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
|
||||
//go:build darwin || freebsd || linux || openbsd
|
||||
|
||||
package dhcpd
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ func normalizeIP(ip net.IP) net.IP {
|
||||
}
|
||||
|
||||
// Load lease table from DB
|
||||
func (s *Server) dbLoad() (err error) {
|
||||
func (s *server) dbLoad() (err error) {
|
||||
dynLeases := []*Lease{}
|
||||
staticLeases := []*Lease{}
|
||||
v6StaticLeases := []*Lease{}
|
||||
@@ -132,7 +132,7 @@ func normalizeLeases(staticLeases, dynLeases []*Lease) []*Lease {
|
||||
}
|
||||
|
||||
// Store lease table in DB
|
||||
func (s *Server) dbStore() (err error) {
|
||||
func (s *server) dbStore() (err error) {
|
||||
// Use an empty slice here as opposed to nil so that it doesn't write
|
||||
// "null" into the database file if leases are empty.
|
||||
leases := []leaseJSON{}
|
||||
|
||||
@@ -6,12 +6,11 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -21,9 +20,19 @@ const (
|
||||
// TODO(e.burkov): Remove it when static leases determining mechanism
|
||||
// will be improved.
|
||||
leaseExpireStatic = 1
|
||||
|
||||
// DefaultDHCPLeaseTTL is the default time-to-live for leases.
|
||||
DefaultDHCPLeaseTTL = uint32(timeutil.Day / time.Second)
|
||||
|
||||
// DefaultDHCPTimeoutICMP is the default timeout for waiting ICMP responses.
|
||||
DefaultDHCPTimeoutICMP = 1000
|
||||
)
|
||||
|
||||
var webHandlersRegistered = false
|
||||
// Currently used defaults for ifaceDNSAddrs.
|
||||
const (
|
||||
defaultMaxAttempts int = 10
|
||||
defaultBackoff time.Duration = 500 * time.Millisecond
|
||||
)
|
||||
|
||||
// Lease contains the necessary information about a DHCP lease
|
||||
type Lease struct {
|
||||
@@ -119,30 +128,6 @@ func (l *Lease) UnmarshalJSON(data []byte) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ServerConfig is the configuration for the DHCP server. The order of YAML
|
||||
// fields is important, since the YAML configuration file follows it.
|
||||
type ServerConfig struct {
|
||||
// Called when the configuration is changed by HTTP request
|
||||
ConfigModified func() `yaml:"-"`
|
||||
|
||||
// Register an HTTP handler
|
||||
HTTPRegister aghhttp.RegisterFunc `yaml:"-"`
|
||||
|
||||
Enabled bool `yaml:"enabled"`
|
||||
InterfaceName string `yaml:"interface_name"`
|
||||
|
||||
// LocalDomainName is the domain name used for DHCP hosts. For example,
|
||||
// a DHCP client with the hostname "myhost" can be addressed as "myhost.lan"
|
||||
// when LocalDomainName is "lan".
|
||||
LocalDomainName string `yaml:"local_domain_name"`
|
||||
|
||||
Conf4 V4ServerConf `yaml:"dhcpv4"`
|
||||
Conf6 V6ServerConf `yaml:"dhcpv6"`
|
||||
|
||||
WorkDir string `yaml:"-"`
|
||||
DBFilePath string `yaml:"-"`
|
||||
}
|
||||
|
||||
// OnLeaseChangedT is a callback for lease changes.
|
||||
type OnLeaseChangedT func(flags int)
|
||||
|
||||
@@ -156,8 +141,68 @@ const (
|
||||
LeaseChangedDBStore
|
||||
)
|
||||
|
||||
// Server - the current state of the DHCP server
|
||||
type Server struct {
|
||||
// GetLeasesFlags are the flags for GetLeases.
|
||||
type GetLeasesFlags uint8
|
||||
|
||||
// GetLeasesFlags values
|
||||
const (
|
||||
LeasesDynamic GetLeasesFlags = 0b01
|
||||
LeasesStatic GetLeasesFlags = 0b10
|
||||
|
||||
LeasesAll = LeasesDynamic | LeasesStatic
|
||||
)
|
||||
|
||||
// Interface is the DHCP server that deals with both IP address families.
|
||||
type Interface interface {
|
||||
Start() (err error)
|
||||
Stop() (err error)
|
||||
Enabled() (ok bool)
|
||||
|
||||
Leases(flags GetLeasesFlags) (leases []*Lease)
|
||||
SetOnLeaseChanged(onLeaseChanged OnLeaseChangedT)
|
||||
FindMACbyIP(ip net.IP) (mac net.HardwareAddr)
|
||||
|
||||
WriteDiskConfig(c *ServerConfig)
|
||||
}
|
||||
|
||||
// MockInterface is a mock Interface implementation.
|
||||
//
|
||||
// TODO(e.burkov): Move to aghtest when the API stabilized.
|
||||
type MockInterface struct {
|
||||
OnStart func() (err error)
|
||||
OnStop func() (err error)
|
||||
OnEnabled func() (ok bool)
|
||||
OnLeases func(flags GetLeasesFlags) (leases []*Lease)
|
||||
OnSetOnLeaseChanged func(f OnLeaseChangedT)
|
||||
OnFindMACbyIP func(ip net.IP) (mac net.HardwareAddr)
|
||||
OnWriteDiskConfig func(c *ServerConfig)
|
||||
}
|
||||
|
||||
var _ Interface = (*MockInterface)(nil)
|
||||
|
||||
// Start implements the Interface for *MockInterface.
|
||||
func (s *MockInterface) Start() (err error) { return s.OnStart() }
|
||||
|
||||
// Stop implements the Interface for *MockInterface.
|
||||
func (s *MockInterface) Stop() (err error) { return s.OnStop() }
|
||||
|
||||
// Enabled implements the Interface for *MockInterface.
|
||||
func (s *MockInterface) Enabled() (ok bool) { return s.OnEnabled() }
|
||||
|
||||
// Leases implements the Interface for *MockInterface.
|
||||
func (s *MockInterface) Leases(flags GetLeasesFlags) (ls []*Lease) { return s.OnLeases(flags) }
|
||||
|
||||
// SetOnLeaseChanged implements the Interface for *MockInterface.
|
||||
func (s *MockInterface) SetOnLeaseChanged(f OnLeaseChangedT) { s.OnSetOnLeaseChanged(f) }
|
||||
|
||||
// FindMACbyIP implements the Interface for *MockInterface.
|
||||
func (s *MockInterface) FindMACbyIP(ip net.IP) (mac net.HardwareAddr) { return s.OnFindMACbyIP(ip) }
|
||||
|
||||
// WriteDiskConfig implements the Interface for *MockInterface.
|
||||
func (s *MockInterface) WriteDiskConfig(c *ServerConfig) { s.OnWriteDiskConfig(c) }
|
||||
|
||||
// server is the DHCP service that handles DHCPv4, DHCPv6, and HTTP API.
|
||||
type server struct {
|
||||
srv4 DHCPServer
|
||||
srv6 DHCPServer
|
||||
|
||||
@@ -169,27 +214,15 @@ type Server struct {
|
||||
onLeaseChanged []OnLeaseChangedT
|
||||
}
|
||||
|
||||
// GetLeasesFlags are the flags for GetLeases.
|
||||
type GetLeasesFlags uint8
|
||||
// type check
|
||||
var _ Interface = (*server)(nil)
|
||||
|
||||
// GetLeasesFlags values
|
||||
const (
|
||||
LeasesDynamic GetLeasesFlags = 0b0001
|
||||
LeasesStatic GetLeasesFlags = 0b0010
|
||||
|
||||
LeasesAll = LeasesDynamic | LeasesStatic
|
||||
)
|
||||
|
||||
// ServerInterface is an interface for servers.
|
||||
type ServerInterface interface {
|
||||
Enabled() (ok bool)
|
||||
Leases(flags GetLeasesFlags) (leases []*Lease)
|
||||
SetOnLeaseChanged(onLeaseChanged OnLeaseChangedT)
|
||||
}
|
||||
|
||||
// Create - create object
|
||||
func Create(conf *ServerConfig) (s *Server, err error) {
|
||||
s = &Server{
|
||||
// Create initializes and returns the DHCP server handling both address
|
||||
// families. It also registers the corresponding HTTP API endpoints.
|
||||
//
|
||||
// TODO(e.burkov): Don't register handlers, see TODO on [aghhttp.RegisterFunc].
|
||||
func Create(conf *ServerConfig) (s *server, err error) {
|
||||
s = &server{
|
||||
conf: &ServerConfig{
|
||||
ConfigModified: conf.ConfigModified,
|
||||
|
||||
@@ -204,35 +237,20 @@ func Create(conf *ServerConfig) (s *Server, err error) {
|
||||
},
|
||||
}
|
||||
|
||||
if !webHandlersRegistered && s.conf.HTTPRegister != nil {
|
||||
if runtime.GOOS == "windows" {
|
||||
// Our DHCP server doesn't work on Windows yet, so
|
||||
// signal that to the front with an HTTP 501.
|
||||
//
|
||||
// TODO(a.garipov): This needs refactoring. We
|
||||
// shouldn't even try and initialize a DHCP server on
|
||||
// Windows, but there are currently too many
|
||||
// interconnected parts--such as HTTP handlers and
|
||||
// frontend--to make that work properly.
|
||||
s.registerNotImplementedHandlers()
|
||||
} else {
|
||||
s.registerHandlers()
|
||||
}
|
||||
|
||||
webHandlersRegistered = true
|
||||
}
|
||||
s.registerHandlers()
|
||||
|
||||
v4conf := conf.Conf4
|
||||
v4conf.Enabled = s.conf.Enabled
|
||||
if len(v4conf.RangeStart) == 0 {
|
||||
v4conf.Enabled = false
|
||||
}
|
||||
|
||||
v4conf.InterfaceName = s.conf.InterfaceName
|
||||
v4conf.notify = s.onNotify
|
||||
s.srv4, err = v4Create(v4conf)
|
||||
v4conf.Enabled = s.conf.Enabled && len(v4conf.RangeStart) != 0
|
||||
|
||||
s.srv4, err = v4Create(&v4conf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating dhcpv4 srv: %w", err)
|
||||
if v4conf.Enabled {
|
||||
return nil, fmt.Errorf("creating dhcpv4 srv: %w", err)
|
||||
}
|
||||
|
||||
log.Error("creating dhcpv4 srv: %s", err)
|
||||
}
|
||||
|
||||
v6conf := conf.Conf6
|
||||
@@ -265,12 +283,12 @@ func Create(conf *ServerConfig) (s *Server, err error) {
|
||||
}
|
||||
|
||||
// Enabled returns true when the server is enabled.
|
||||
func (s *Server) Enabled() (ok bool) {
|
||||
func (s *server) Enabled() (ok bool) {
|
||||
return s.conf.Enabled
|
||||
}
|
||||
|
||||
// resetLeases resets all leases in the lease database.
|
||||
func (s *Server) resetLeases() (err error) {
|
||||
func (s *server) resetLeases() (err error) {
|
||||
err = s.srv4.ResetLeases(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -287,7 +305,7 @@ func (s *Server) resetLeases() (err error) {
|
||||
}
|
||||
|
||||
// server calls this function after DB is updated
|
||||
func (s *Server) onNotify(flags uint32) {
|
||||
func (s *server) onNotify(flags uint32) {
|
||||
if flags == LeaseChangedDBStore {
|
||||
err := s.dbStore()
|
||||
if err != nil {
|
||||
@@ -301,31 +319,28 @@ func (s *Server) onNotify(flags uint32) {
|
||||
}
|
||||
|
||||
// SetOnLeaseChanged - set callback
|
||||
func (s *Server) SetOnLeaseChanged(onLeaseChanged OnLeaseChangedT) {
|
||||
func (s *server) SetOnLeaseChanged(onLeaseChanged OnLeaseChangedT) {
|
||||
s.onLeaseChanged = append(s.onLeaseChanged, onLeaseChanged)
|
||||
}
|
||||
|
||||
func (s *Server) notify(flags int) {
|
||||
if len(s.onLeaseChanged) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
func (s *server) notify(flags int) {
|
||||
for _, f := range s.onLeaseChanged {
|
||||
f(flags)
|
||||
}
|
||||
}
|
||||
|
||||
// WriteDiskConfig - write configuration
|
||||
func (s *Server) WriteDiskConfig(c *ServerConfig) {
|
||||
func (s *server) WriteDiskConfig(c *ServerConfig) {
|
||||
c.Enabled = s.conf.Enabled
|
||||
c.InterfaceName = s.conf.InterfaceName
|
||||
c.LocalDomainName = s.conf.LocalDomainName
|
||||
|
||||
s.srv4.WriteDiskConfig4(&c.Conf4)
|
||||
s.srv6.WriteDiskConfig6(&c.Conf6)
|
||||
}
|
||||
|
||||
// Start will listen on port 67 and serve DHCP requests.
|
||||
func (s *Server) Start() (err error) {
|
||||
func (s *server) Start() (err error) {
|
||||
err = s.srv4.Start()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -340,7 +355,7 @@ func (s *Server) Start() (err error) {
|
||||
}
|
||||
|
||||
// Stop closes the listening UDP socket
|
||||
func (s *Server) Stop() (err error) {
|
||||
func (s *server) Stop() (err error) {
|
||||
err = s.srv4.Stop()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -356,12 +371,12 @@ func (s *Server) Stop() (err error) {
|
||||
|
||||
// Leases returns the list of active IPv4 and IPv6 DHCP leases. It's safe for
|
||||
// concurrent use.
|
||||
func (s *Server) Leases(flags GetLeasesFlags) (leases []*Lease) {
|
||||
func (s *server) Leases(flags GetLeasesFlags) (leases []*Lease) {
|
||||
return append(s.srv4.GetLeases(flags), s.srv6.GetLeases(flags)...)
|
||||
}
|
||||
|
||||
// FindMACbyIP - find a MAC address by IP address in the currently active DHCP leases
|
||||
func (s *Server) FindMACbyIP(ip net.IP) net.HardwareAddr {
|
||||
func (s *server) FindMACbyIP(ip net.IP) net.HardwareAddr {
|
||||
if ip.To4() != nil {
|
||||
return s.srv4.FindMACbyIP(ip)
|
||||
}
|
||||
@@ -369,6 +384,6 @@ func (s *Server) FindMACbyIP(ip net.IP) net.HardwareAddr {
|
||||
}
|
||||
|
||||
// AddStaticLease - add static v4 lease
|
||||
func (s *Server) AddStaticLease(l *Lease) error {
|
||||
func (s *server) AddStaticLease(l *Lease) error {
|
||||
return s.srv4.AddStaticLease(l)
|
||||
}
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
|
||||
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
|
||||
//go:build darwin || freebsd || linux || openbsd
|
||||
|
||||
package dhcpd
|
||||
|
||||
@@ -26,13 +25,13 @@ func testNotify(flags uint32) {
|
||||
// Leases database store/load.
|
||||
func TestDB(t *testing.T) {
|
||||
var err error
|
||||
s := Server{
|
||||
s := server{
|
||||
conf: &ServerConfig{
|
||||
DBFilePath: dbFilename,
|
||||
},
|
||||
}
|
||||
|
||||
s.srv4, err = v4Create(V4ServerConf{
|
||||
s.srv4, err = v4Create(&V4ServerConf{
|
||||
Enabled: true,
|
||||
RangeStart: net.IP{192, 168, 10, 100},
|
||||
RangeEnd: net.IP{192, 168, 10, 200},
|
||||
@@ -88,32 +87,6 @@ func TestDB(t *testing.T) {
|
||||
assert.Equal(t, leases[0].Expiry.Unix(), ll[1].Expiry.Unix())
|
||||
}
|
||||
|
||||
func TestIsValidSubnetMask(t *testing.T) {
|
||||
testCases := []struct {
|
||||
mask net.IP
|
||||
want bool
|
||||
}{{
|
||||
mask: net.IP{255, 255, 255, 0},
|
||||
want: true,
|
||||
}, {
|
||||
mask: net.IP{255, 255, 254, 0},
|
||||
want: true,
|
||||
}, {
|
||||
mask: net.IP{255, 255, 252, 0},
|
||||
want: true,
|
||||
}, {
|
||||
mask: net.IP{255, 255, 253, 0},
|
||||
}, {
|
||||
mask: net.IP{255, 255, 255, 1},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.mask.String(), func(t *testing.T) {
|
||||
assert.Equal(t, tc.want, isValidSubnetMask(tc.mask))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeLeases(t *testing.T) {
|
||||
dynLeases := []*Lease{{
|
||||
HWAddr: net.HardwareAddr{1, 2, 3, 4},
|
||||
@@ -174,7 +147,7 @@ func TestV4Server_badRange(t *testing.T) {
|
||||
notify: testNotify,
|
||||
}
|
||||
|
||||
_, err := v4Create(conf)
|
||||
_, err := v4Create(&conf)
|
||||
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
||||
})
|
||||
}
|
||||
@@ -1,36 +0,0 @@
|
||||
package dhcpd
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
)
|
||||
|
||||
func tryTo4(ip net.IP) (ip4 net.IP, err error) {
|
||||
if ip == nil {
|
||||
return nil, fmt.Errorf("%v is not an IP address", ip)
|
||||
}
|
||||
|
||||
ip4 = ip.To4()
|
||||
if ip4 == nil {
|
||||
return nil, fmt.Errorf("%v is not an IPv4 address", ip)
|
||||
}
|
||||
|
||||
return ip4, nil
|
||||
}
|
||||
|
||||
// Return TRUE if subnet mask is correct (e.g. 255.255.255.0)
|
||||
func isValidSubnetMask(mask net.IP) bool {
|
||||
var n uint32
|
||||
n = binary.BigEndian.Uint32(mask)
|
||||
for i := 0; i != 32; i++ {
|
||||
if n == 0 {
|
||||
break
|
||||
}
|
||||
if (n & 0x80000000) == 0 {
|
||||
return false
|
||||
}
|
||||
n <<= 1
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -1,21 +1,19 @@
|
||||
//go:build darwin || freebsd || linux || openbsd
|
||||
|
||||
package dhcpd
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
)
|
||||
|
||||
type v4ServerConfJSON struct {
|
||||
@@ -26,12 +24,12 @@ type v4ServerConfJSON struct {
|
||||
LeaseDuration uint32 `json:"lease_duration"`
|
||||
}
|
||||
|
||||
func v4JSONToServerConf(j *v4ServerConfJSON) V4ServerConf {
|
||||
func (j *v4ServerConfJSON) toServerConf() *V4ServerConf {
|
||||
if j == nil {
|
||||
return V4ServerConf{}
|
||||
return &V4ServerConf{}
|
||||
}
|
||||
|
||||
return V4ServerConf{
|
||||
return &V4ServerConf{
|
||||
GatewayIP: j.GatewayIP,
|
||||
SubnetMask: j.SubnetMask,
|
||||
RangeStart: j.RangeStart,
|
||||
@@ -66,7 +64,7 @@ type dhcpStatusResponse struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
func (s *Server) handleDHCPStatus(w http.ResponseWriter, r *http.Request) {
|
||||
func (s *server) handleDHCPStatus(w http.ResponseWriter, r *http.Request) {
|
||||
status := &dhcpStatusResponse{
|
||||
Enabled: s.conf.Enabled,
|
||||
IfaceName: s.conf.InterfaceName,
|
||||
@@ -81,6 +79,7 @@ func (s *Server) handleDHCPStatus(w http.ResponseWriter, r *http.Request) {
|
||||
status.StaticLeases = s.Leases(LeasesStatic)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
err := json.NewEncoder(w).Encode(status)
|
||||
if err != nil {
|
||||
aghhttp.Error(
|
||||
@@ -93,28 +92,26 @@ func (s *Server) handleDHCPStatus(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) enableDHCP(ifaceName string) (code int, err error) {
|
||||
func (s *server) enableDHCP(ifaceName string) (code int, err error) {
|
||||
var hasStaticIP bool
|
||||
hasStaticIP, err = aghnet.IfaceHasStaticIP(ifaceName)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrPermission) {
|
||||
// ErrPermission may happen here on Linux systems where
|
||||
// AdGuard Home is installed using Snap. That doesn't
|
||||
// necessarily mean that the machine doesn't have
|
||||
// a static IP, so we can assume that it has and go on.
|
||||
// If the machine doesn't, we'll get an error later.
|
||||
// ErrPermission may happen here on Linux systems where AdGuard Home
|
||||
// is installed using Snap. That doesn't necessarily mean that the
|
||||
// machine doesn't have a static IP, so we can assume that it has
|
||||
// and go on. If the machine doesn't, we'll get an error later.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/2667.
|
||||
//
|
||||
// TODO(a.garipov): I was thinking about moving this
|
||||
// into IfaceHasStaticIP, but then we wouldn't be able
|
||||
// to log it. Think about it more.
|
||||
// TODO(a.garipov): I was thinking about moving this into
|
||||
// IfaceHasStaticIP, but then we wouldn't be able to log it. Think
|
||||
// about it more.
|
||||
log.Info("error while checking static ip: %s; "+
|
||||
"assuming machine has static ip and going on", err)
|
||||
hasStaticIP = true
|
||||
} else if errors.Is(err, aghnet.ErrNoStaticIPInfo) {
|
||||
// Couldn't obtain a definitive answer. Assume static
|
||||
// IP an go on.
|
||||
// Couldn't obtain a definitive answer. Assume static IP an go on.
|
||||
log.Info("can't check for static ip; " +
|
||||
"assuming machine has static ip and going on")
|
||||
hasStaticIP = true
|
||||
@@ -149,34 +146,39 @@ type dhcpServerConfigJSON struct {
|
||||
Enabled aghalg.NullBool `json:"enabled"`
|
||||
}
|
||||
|
||||
func (s *Server) handleDHCPSetConfigV4(
|
||||
func (s *server) handleDHCPSetConfigV4(
|
||||
conf *dhcpServerConfigJSON,
|
||||
) (srv4 DHCPServer, enabled bool, err error) {
|
||||
) (srv DHCPServer, enabled bool, err error) {
|
||||
if conf.V4 == nil {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
v4Conf := v4JSONToServerConf(conf.V4)
|
||||
v4Conf := conf.V4.toServerConf()
|
||||
v4Conf.Enabled = conf.Enabled == aghalg.NBTrue
|
||||
if len(v4Conf.RangeStart) == 0 {
|
||||
v4Conf.Enabled = false
|
||||
}
|
||||
|
||||
enabled = v4Conf.Enabled
|
||||
v4Conf.InterfaceName = conf.InterfaceName
|
||||
|
||||
c4 := V4ServerConf{}
|
||||
s.srv4.WriteDiskConfig4(&c4)
|
||||
// Set the default values for the fields not configurable via web API.
|
||||
c4 := &V4ServerConf{
|
||||
notify: s.onNotify,
|
||||
ICMPTimeout: s.conf.Conf4.ICMPTimeout,
|
||||
Options: s.conf.Conf4.Options,
|
||||
}
|
||||
|
||||
s.srv4.WriteDiskConfig4(c4)
|
||||
v4Conf.notify = c4.notify
|
||||
v4Conf.ICMPTimeout = c4.ICMPTimeout
|
||||
v4Conf.Options = c4.Options
|
||||
|
||||
srv4, err = v4Create(v4Conf)
|
||||
srv4, err := v4Create(v4Conf)
|
||||
|
||||
return srv4, enabled, err
|
||||
return srv4, srv4.enabled(), err
|
||||
}
|
||||
|
||||
func (s *Server) handleDHCPSetConfigV6(
|
||||
func (s *server) handleDHCPSetConfigV6(
|
||||
conf *dhcpServerConfigJSON,
|
||||
) (srv6 DHCPServer, enabled bool, err error) {
|
||||
if conf.V6 == nil {
|
||||
@@ -205,7 +207,7 @@ func (s *Server) handleDHCPSetConfigV6(
|
||||
return srv6, enabled, err
|
||||
}
|
||||
|
||||
func (s *Server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
func (s *server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
conf := &dhcpServerConfigJSON{}
|
||||
conf.Enabled = aghalg.BoolToNullBool(s.conf.Enabled)
|
||||
conf.InterfaceName = s.conf.InterfaceName
|
||||
@@ -287,7 +289,7 @@ type netInterfaceJSON struct {
|
||||
Addrs6 []net.IP `json:"ipv6_addresses"`
|
||||
}
|
||||
|
||||
func (s *Server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
|
||||
func (s *server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
|
||||
response := map[string]netInterfaceJSON{}
|
||||
|
||||
ifaces, err := net.Interfaces()
|
||||
@@ -406,31 +408,37 @@ type dhcpSearchResult struct {
|
||||
V6 dhcpSearchV6Result `json:"v6"`
|
||||
}
|
||||
|
||||
// Perform the following tasks:
|
||||
// . Search for another DHCP server running
|
||||
// . Check if a static IP is configured for the network interface
|
||||
// Respond with results
|
||||
func (s *Server) handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Request) {
|
||||
// This use of ReadAll is safe, because request's body is now limited.
|
||||
body, err := io.ReadAll(r.Body)
|
||||
// findActiveServerReq is the JSON structure for the request to find active DHCP
|
||||
// servers.
|
||||
type findActiveServerReq struct {
|
||||
Interface string `json:"interface"`
|
||||
}
|
||||
|
||||
// handleDHCPFindActiveServer performs the following tasks:
|
||||
// 1. searches for another DHCP server in the network;
|
||||
// 2. check if a static IP is configured for the network interface;
|
||||
// 3. responds with the results.
|
||||
func (s *server) handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Request) {
|
||||
if aghhttp.WriteTextPlainDeprecated(w, r) {
|
||||
return
|
||||
}
|
||||
|
||||
req := &findActiveServerReq{}
|
||||
err := json.NewDecoder(r.Body).Decode(req)
|
||||
if err != nil {
|
||||
msg := fmt.Sprintf("failed to read request body: %s", err)
|
||||
log.Error(msg)
|
||||
http.Error(w, msg, http.StatusBadRequest)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "reading req: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ifaceName := strings.TrimSpace(string(body))
|
||||
ifaceName := req.Interface
|
||||
if ifaceName == "" {
|
||||
msg := "empty interface name specified"
|
||||
log.Error(msg)
|
||||
http.Error(w, msg, http.StatusBadRequest)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "empty interface name")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
result := dhcpSearchResult{
|
||||
result := &dhcpSearchResult{
|
||||
V4: dhcpSearchV4Result{
|
||||
OtherServer: dhcpSearchOtherResult{
|
||||
Found: "no",
|
||||
@@ -455,6 +463,14 @@ func (s *Server) handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Reque
|
||||
result.V4.StaticIP.IP = aghnet.GetSubnet(ifaceName).String()
|
||||
}
|
||||
|
||||
setOtherDHCPResult(ifaceName, result)
|
||||
|
||||
_ = aghhttp.WriteJSONResponse(w, r, result)
|
||||
}
|
||||
|
||||
// setOtherDHCPResult sets the results of the check for another DHCP server in
|
||||
// result.
|
||||
func setOtherDHCPResult(ifaceName string, result *dhcpSearchResult) {
|
||||
found4, found6, err4, err6 := aghnet.CheckOtherDHCP(ifaceName)
|
||||
if err4 != nil {
|
||||
result.V4.OtherServer.Found = "error"
|
||||
@@ -462,27 +478,16 @@ func (s *Server) handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Reque
|
||||
} else if found4 {
|
||||
result.V4.OtherServer.Found = "yes"
|
||||
}
|
||||
|
||||
if err6 != nil {
|
||||
result.V6.OtherServer.Found = "error"
|
||||
result.V6.OtherServer.Error = err6.Error()
|
||||
} else if found6 {
|
||||
result.V6.OtherServer.Found = "yes"
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w).Encode(result)
|
||||
if err != nil {
|
||||
aghhttp.Error(
|
||||
r,
|
||||
w,
|
||||
http.StatusInternalServerError,
|
||||
"Failed to marshal DHCP found json: %s",
|
||||
err,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request) {
|
||||
func (s *server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request) {
|
||||
l := &Lease{}
|
||||
err := json.NewDecoder(r.Body).Decode(l)
|
||||
if err != nil {
|
||||
@@ -497,20 +502,16 @@ func (s *Server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request
|
||||
return
|
||||
}
|
||||
|
||||
ip4 := l.IP.To4()
|
||||
if ip4 == nil {
|
||||
var srv DHCPServer
|
||||
if ip4 := l.IP.To4(); ip4 != nil {
|
||||
l.IP = ip4
|
||||
srv = s.srv4
|
||||
} else {
|
||||
l.IP = l.IP.To16()
|
||||
|
||||
err = s.srv6.AddStaticLease(l)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
}
|
||||
|
||||
return
|
||||
srv = s.srv6
|
||||
}
|
||||
|
||||
l.IP = ip4
|
||||
err = s.srv4.AddStaticLease(l)
|
||||
err = srv.AddStaticLease(l)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
@@ -518,7 +519,7 @@ func (s *Server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleDHCPRemoveStaticLease(w http.ResponseWriter, r *http.Request) {
|
||||
func (s *server) handleDHCPRemoveStaticLease(w http.ResponseWriter, r *http.Request) {
|
||||
l := &Lease{}
|
||||
err := json.NewDecoder(r.Body).Decode(l)
|
||||
if err != nil {
|
||||
@@ -555,14 +556,7 @@ func (s *Server) handleDHCPRemoveStaticLease(w http.ResponseWriter, r *http.Requ
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
// DefaultDHCPLeaseTTL is the default time-to-live for leases.
|
||||
DefaultDHCPLeaseTTL = uint32(timeutil.Day / time.Second)
|
||||
// DefaultDHCPTimeoutICMP is the default timeout for waiting ICMP responses.
|
||||
DefaultDHCPTimeoutICMP = 1000
|
||||
)
|
||||
|
||||
func (s *Server) handleReset(w http.ResponseWriter, r *http.Request) {
|
||||
func (s *server) handleReset(w http.ResponseWriter, r *http.Request) {
|
||||
err := s.Stop()
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "stopping dhcp: %s", err)
|
||||
@@ -586,7 +580,7 @@ func (s *Server) handleReset(w http.ResponseWriter, r *http.Request) {
|
||||
DBFilePath: s.conf.DBFilePath,
|
||||
}
|
||||
|
||||
v4conf := V4ServerConf{
|
||||
v4conf := &V4ServerConf{
|
||||
LeaseDuration: DefaultDHCPLeaseTTL,
|
||||
ICMPTimeout: DefaultDHCPTimeoutICMP,
|
||||
notify: s.onNotify,
|
||||
@@ -602,7 +596,7 @@ func (s *Server) handleReset(w http.ResponseWriter, r *http.Request) {
|
||||
s.conf.ConfigModified()
|
||||
}
|
||||
|
||||
func (s *Server) handleResetLeases(w http.ResponseWriter, r *http.Request) {
|
||||
func (s *server) handleResetLeases(w http.ResponseWriter, r *http.Request) {
|
||||
err := s.resetLeases()
|
||||
if err != nil {
|
||||
msg := "resetting leases: %s"
|
||||
@@ -612,7 +606,11 @@ func (s *Server) handleResetLeases(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) registerHandlers() {
|
||||
func (s *server) registerHandlers() {
|
||||
if s.conf.HTTPRegister == nil {
|
||||
return
|
||||
}
|
||||
|
||||
s.conf.HTTPRegister(http.MethodGet, "/control/dhcp/status", s.handleDHCPStatus)
|
||||
s.conf.HTTPRegister(http.MethodGet, "/control/dhcp/interfaces", s.handleDHCPInterfaces)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/set_config", s.handleDHCPSetConfig)
|
||||
@@ -622,44 +620,3 @@ func (s *Server) registerHandlers() {
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/reset", s.handleReset)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/reset_leases", s.handleResetLeases)
|
||||
}
|
||||
|
||||
// jsonError is a generic JSON error response.
|
||||
//
|
||||
// TODO(a.garipov): Merge together with the implementations in .../home and
|
||||
// other packages after refactoring the web handler registering.
|
||||
type jsonError struct {
|
||||
// Message is the error message, an opaque string.
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// notImplemented returns a handler that replies to any request with an HTTP 501
|
||||
// Not Implemented status and a JSON error with the provided message msg.
|
||||
//
|
||||
// TODO(a.garipov): Either take the logger from the server after we've
|
||||
// refactored logging or make this not a method of *Server.
|
||||
func (s *Server) notImplemented(msg string) (f func(http.ResponseWriter, *http.Request)) {
|
||||
return func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusNotImplemented)
|
||||
|
||||
err := json.NewEncoder(w).Encode(&jsonError{
|
||||
Message: msg,
|
||||
})
|
||||
if err != nil {
|
||||
log.Debug("writing 501 json response: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) registerNotImplementedHandlers() {
|
||||
h := s.notImplemented("dhcp is not supported on windows")
|
||||
|
||||
s.conf.HTTPRegister(http.MethodGet, "/control/dhcp/status", h)
|
||||
s.conf.HTTPRegister(http.MethodGet, "/control/dhcp/interfaces", h)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/set_config", h)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/find_active_dhcp", h)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/add_static_lease", h)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/remove_static_lease", h)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/reset", h)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/reset_leases", h)
|
||||
}
|
||||
55
internal/dhcpd/http_windows.go
Normal file
55
internal/dhcpd/http_windows.go
Normal file
@@ -0,0 +1,55 @@
|
||||
//go:build windows
|
||||
|
||||
package dhcpd
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// jsonError is a generic JSON error response.
|
||||
//
|
||||
// TODO(a.garipov): Merge together with the implementations in .../home and
|
||||
// other packages after refactoring the web handler registering.
|
||||
type jsonError struct {
|
||||
// Message is the error message, an opaque string.
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// notImplemented is a handler that replies to any request with an HTTP 501 Not
|
||||
// Implemented status and a JSON error with the provided message msg.
|
||||
//
|
||||
// TODO(a.garipov): Either take the logger from the server after we've
|
||||
// refactored logging or make this not a method of *Server.
|
||||
func (s *server) notImplemented(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusNotImplemented)
|
||||
|
||||
err := json.NewEncoder(w).Encode(&jsonError{
|
||||
Message: aghos.Unsupported("dhcp").Error(),
|
||||
})
|
||||
if err != nil {
|
||||
log.Debug("writing 501 json response: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// registerHandlers sets the handlers for DHCP HTTP API that always respond with
|
||||
// an HTTP 501, since DHCP server doesn't work on Windows yet.
|
||||
//
|
||||
// TODO(a.garipov): This needs refactoring. We shouldn't even try and
|
||||
// initialize a DHCP server on Windows, but there are currently too many
|
||||
// interconnected parts--such as HTTP handlers and frontend--to make that work
|
||||
// properly.
|
||||
func (s *server) registerHandlers() {
|
||||
s.conf.HTTPRegister(http.MethodGet, "/control/dhcp/status", s.notImplemented)
|
||||
s.conf.HTTPRegister(http.MethodGet, "/control/dhcp/interfaces", s.notImplemented)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/set_config", s.notImplemented)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/find_active_dhcp", s.notImplemented)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/add_static_lease", s.notImplemented)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/remove_static_lease", s.notImplemented)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/reset", s.notImplemented)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/reset_leases", s.notImplemented)
|
||||
}
|
||||
@@ -1,23 +1,28 @@
|
||||
//go:build windows
|
||||
|
||||
package dhcpd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestServer_notImplemented(t *testing.T) {
|
||||
s := &Server{}
|
||||
h := s.notImplemented("never!")
|
||||
s := &server{}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r, err := http.NewRequest(http.MethodGet, "/unsupported", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
h(w, r)
|
||||
s.notImplemented(w, r)
|
||||
assert.Equal(t, http.StatusNotImplemented, w.Code)
|
||||
assert.Equal(t, `{"message":"never!"}`+"\n", w.Body.String())
|
||||
|
||||
wantStr := fmt.Sprintf("{%q:%q}", "message", aghos.Unsupported("dhcp"))
|
||||
assert.JSONEq(t, wantStr, w.Body.String())
|
||||
}
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
|
||||
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
|
||||
//go:build darwin || freebsd || linux || openbsd
|
||||
|
||||
package dhcpd
|
||||
|
||||
@@ -197,10 +196,10 @@ func parseDHCPOption(s string) (code dhcpv4.OptionCode, val dhcpv4.OptionValue,
|
||||
|
||||
// prepareOptions builds the set of DHCP options according to host requirements
|
||||
// document and values from conf.
|
||||
func prepareOptions(conf V4ServerConf) (implicit, explicit dhcpv4.Options) {
|
||||
func (s *v4Server) prepareOptions() {
|
||||
// Set default values of host configuration parameters listed in Appendix A
|
||||
// of RFC-2131.
|
||||
implicit = dhcpv4.OptionsFromList(
|
||||
s.implicitOpts = dhcpv4.OptionsFromList(
|
||||
// IP-Layer Per Host
|
||||
|
||||
// An Internet host that includes embedded gateway code MUST have a
|
||||
@@ -376,14 +375,14 @@ func prepareOptions(conf V4ServerConf) (implicit, explicit dhcpv4.Options) {
|
||||
|
||||
// Set the Router Option to working subnet's IP since it's initialized
|
||||
// with the address of the gateway.
|
||||
dhcpv4.OptRouter(conf.subnet.IP),
|
||||
dhcpv4.OptRouter(s.conf.subnet.IP),
|
||||
|
||||
dhcpv4.OptSubnetMask(conf.subnet.Mask),
|
||||
dhcpv4.OptSubnetMask(s.conf.subnet.Mask),
|
||||
)
|
||||
|
||||
// Set values for explicitly configured options.
|
||||
explicit = dhcpv4.Options{}
|
||||
for i, o := range conf.Options {
|
||||
s.explicitOpts = dhcpv4.Options{}
|
||||
for i, o := range s.conf.Options {
|
||||
code, val, err := parseDHCPOption(o)
|
||||
if err != nil {
|
||||
log.Error("dhcpv4: bad option string at index %d: %s", i, err)
|
||||
@@ -391,17 +390,15 @@ func prepareOptions(conf V4ServerConf) (implicit, explicit dhcpv4.Options) {
|
||||
continue
|
||||
}
|
||||
|
||||
explicit.Update(dhcpv4.Option{Code: code, Value: val})
|
||||
s.explicitOpts.Update(dhcpv4.Option{Code: code, Value: val})
|
||||
// Remove those from the implicit options.
|
||||
delete(implicit, code.Code())
|
||||
delete(s.implicitOpts, code.Code())
|
||||
}
|
||||
|
||||
log.Debug("dhcpv4: implicit options:\n%s", implicit.Summary(nil))
|
||||
log.Debug("dhcpv4: explicit options:\n%s", explicit.Summary(nil))
|
||||
log.Debug("dhcpv4: implicit options:\n%s", s.implicitOpts.Summary(nil))
|
||||
log.Debug("dhcpv4: explicit options:\n%s", s.explicitOpts.Summary(nil))
|
||||
|
||||
if len(explicit) == 0 {
|
||||
explicit = nil
|
||||
if len(s.explicitOpts) == 0 {
|
||||
s.explicitOpts = nil
|
||||
}
|
||||
|
||||
return implicit, explicit
|
||||
}
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
|
||||
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
|
||||
//go:build darwin || freebsd || linux || openbsd
|
||||
|
||||
package dhcpd
|
||||
|
||||
@@ -250,17 +249,21 @@ func TestPrepareOptions(t *testing.T) {
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
implicit, explicit := prepareOptions(V4ServerConf{
|
||||
s := &v4Server{
|
||||
conf: &V4ServerConf{
|
||||
// Just to avoid nil pointer dereference.
|
||||
subnet: &net.IPNet{},
|
||||
Options: tc.opts,
|
||||
})
|
||||
},
|
||||
}
|
||||
|
||||
assert.Equal(t, tc.wantExplicit, explicit)
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
s.prepareOptions()
|
||||
|
||||
for c := range explicit {
|
||||
assert.NotContains(t, implicit, c)
|
||||
assert.Equal(t, tc.wantExplicit, s.explicitOpts)
|
||||
|
||||
for c := range s.explicitOpts {
|
||||
assert.NotContains(t, s.implicitOpts, c)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package dhcpd
|
||||
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
package dhcpd
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// Currently used defaults for ifaceDNSAddrs.
|
||||
const (
|
||||
defaultMaxAttempts int = 10
|
||||
|
||||
defaultBackoff time.Duration = 500 * time.Millisecond
|
||||
)
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package dhcpd
|
||||
|
||||
@@ -9,15 +8,19 @@ import "net"
|
||||
|
||||
type winServer struct{}
|
||||
|
||||
func (s *winServer) ResetLeases(_ []*Lease) (err error) { return nil }
|
||||
func (s *winServer) GetLeases(_ GetLeasesFlags) (leases []*Lease) { return nil }
|
||||
func (s *winServer) getLeasesRef() []*Lease { return nil }
|
||||
func (s *winServer) AddStaticLease(_ *Lease) (err error) { return nil }
|
||||
func (s *winServer) RemoveStaticLease(_ *Lease) (err error) { return nil }
|
||||
func (s *winServer) FindMACbyIP(ip net.IP) (mac net.HardwareAddr) { return nil }
|
||||
func (s *winServer) WriteDiskConfig4(c *V4ServerConf) {}
|
||||
func (s *winServer) WriteDiskConfig6(c *V6ServerConf) {}
|
||||
func (s *winServer) Start() (err error) { return nil }
|
||||
func (s *winServer) Stop() (err error) { return nil }
|
||||
func v4Create(conf V4ServerConf) (DHCPServer, error) { return &winServer{}, nil }
|
||||
func v6Create(conf V6ServerConf) (DHCPServer, error) { return &winServer{}, nil }
|
||||
// type check
|
||||
var _ DHCPServer = winServer{}
|
||||
|
||||
func (winServer) ResetLeases(_ []*Lease) (err error) { return nil }
|
||||
func (winServer) GetLeases(_ GetLeasesFlags) (leases []*Lease) { return nil }
|
||||
func (winServer) getLeasesRef() []*Lease { return nil }
|
||||
func (winServer) AddStaticLease(_ *Lease) (err error) { return nil }
|
||||
func (winServer) RemoveStaticLease(_ *Lease) (err error) { return nil }
|
||||
func (winServer) FindMACbyIP(_ net.IP) (mac net.HardwareAddr) { return nil }
|
||||
func (winServer) WriteDiskConfig4(_ *V4ServerConf) {}
|
||||
func (winServer) WriteDiskConfig6(_ *V6ServerConf) {}
|
||||
func (winServer) Start() (err error) { return nil }
|
||||
func (winServer) Stop() (err error) { return nil }
|
||||
|
||||
func v4Create(_ *V4ServerConf) (s DHCPServer, err error) { return winServer{}, nil }
|
||||
func v6Create(_ V6ServerConf) (s DHCPServer, err error) { return winServer{}, nil }
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
|
||||
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
|
||||
//go:build darwin || freebsd || linux || openbsd
|
||||
|
||||
package dhcpd
|
||||
|
||||
@@ -30,8 +29,9 @@ import (
|
||||
//
|
||||
// TODO(a.garipov): Think about unifying this and v6Server.
|
||||
type v4Server struct {
|
||||
conf V4ServerConf
|
||||
srv *server4.Server
|
||||
conf *V4ServerConf
|
||||
|
||||
srv *server4.Server
|
||||
|
||||
// implicitOpts are the options listed in Appendix A of RFC 2131 initialized
|
||||
// with default values. It must not have intersections with [explicitOpts].
|
||||
@@ -55,9 +55,15 @@ type v4Server struct {
|
||||
leases []*Lease
|
||||
}
|
||||
|
||||
func (s *v4Server) enabled() (ok bool) {
|
||||
return s.conf != nil && s.conf.Enabled
|
||||
}
|
||||
|
||||
// WriteDiskConfig4 - write configuration
|
||||
func (s *v4Server) WriteDiskConfig4(c *V4ServerConf) {
|
||||
*c = s.conf
|
||||
if s.conf != nil {
|
||||
*c = *s.conf
|
||||
}
|
||||
}
|
||||
|
||||
// WriteDiskConfig6 - write configuration
|
||||
@@ -114,8 +120,8 @@ func (s *v4Server) validHostnameForClient(cliHostname string, ip net.IP) (hostna
|
||||
func (s *v4Server) ResetLeases(leases []*Lease) (err error) {
|
||||
defer func() { err = errors.Annotate(err, "dhcpv4: %w") }()
|
||||
|
||||
if !s.conf.Enabled {
|
||||
return
|
||||
if s.conf == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.leasedOffsets = newBitSet()
|
||||
@@ -129,12 +135,7 @@ func (s *v4Server) ResetLeases(leases []*Lease) (err error) {
|
||||
err = s.addLease(l)
|
||||
if err != nil {
|
||||
// TODO(a.garipov): Wrap and bubble up the error.
|
||||
log.Error(
|
||||
"dhcpv4: reset: re-adding a lease for %s (%s): %s",
|
||||
l.IP,
|
||||
l.HWAddr,
|
||||
err,
|
||||
)
|
||||
log.Error("dhcpv4: reset: re-adding a lease for %s (%s): %s", l.IP, l.HWAddr, err)
|
||||
|
||||
continue
|
||||
}
|
||||
@@ -336,11 +337,19 @@ func (s *v4Server) rmLease(lease *Lease) (err error) {
|
||||
return errors.Error("lease not found")
|
||||
}
|
||||
|
||||
// ErrUnconfigured is returned from the server's method when it requires the
|
||||
// server to be configured and it's not.
|
||||
const ErrUnconfigured errors.Error = "server is unconfigured"
|
||||
|
||||
// AddStaticLease implements the DHCPServer interface for *v4Server. It is safe
|
||||
// for concurrent use.
|
||||
func (s *v4Server) AddStaticLease(l *Lease) (err error) {
|
||||
defer func() { err = errors.Annotate(err, "dhcpv4: adding static lease: %w") }()
|
||||
|
||||
if s.conf == nil {
|
||||
return ErrUnconfigured
|
||||
}
|
||||
|
||||
ip := l.IP.To4()
|
||||
if ip == nil {
|
||||
return fmt.Errorf("invalid ip %q, only ipv4 is supported", l.IP)
|
||||
@@ -414,6 +423,10 @@ func (s *v4Server) AddStaticLease(l *Lease) (err error) {
|
||||
func (s *v4Server) RemoveStaticLease(l *Lease) (err error) {
|
||||
defer func() { err = errors.Annotate(err, "dhcpv4: %w") }()
|
||||
|
||||
if s.conf == nil {
|
||||
return ErrUnconfigured
|
||||
}
|
||||
|
||||
if len(l.IP) != 4 {
|
||||
return fmt.Errorf("invalid IP")
|
||||
}
|
||||
@@ -1086,12 +1099,6 @@ func (s *v4Server) packetHandler(conn net.PacketConn, peer net.Addr, req *dhcpv4
|
||||
s.send(peer, conn, req, resp)
|
||||
}
|
||||
|
||||
// minDHCPMsgSize is the minimum length of the encoded DHCP message in bytes
|
||||
// according to RFC-2131.
|
||||
//
|
||||
// See https://datatracker.ietf.org/doc/html/rfc2131#section-2.
|
||||
const minDHCPMsgSize = 576
|
||||
|
||||
// send writes resp for peer to conn considering the req's parameters according
|
||||
// to RFC-2131.
|
||||
//
|
||||
@@ -1133,16 +1140,6 @@ func (s *v4Server) send(peer net.Addr, conn net.PacketConn, req, resp *dhcpv4.DH
|
||||
}
|
||||
|
||||
pktData := resp.ToBytes()
|
||||
pktLen := len(pktData)
|
||||
if pktLen < minDHCPMsgSize {
|
||||
// Expand the packet to match the minimum DHCP message length. Although
|
||||
// the dhpcv4 package deals with the BOOTP's lower packet length
|
||||
// constraint, it seems some clients expecting the length being at least
|
||||
// 576 bytes as per RFC 2131 (and an obsolete RFC 1533).
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/4337.
|
||||
pktData = append(pktData, make([]byte, minDHCPMsgSize-pktLen)...)
|
||||
}
|
||||
|
||||
log.Debug("dhcpv4: sending %d bytes to %s: %s", len(pktData), peer, resp.Summary())
|
||||
|
||||
@@ -1156,7 +1153,7 @@ func (s *v4Server) send(peer net.Addr, conn net.PacketConn, req, resp *dhcpv4.DH
|
||||
func (s *v4Server) Start() (err error) {
|
||||
defer func() { err = errors.Annotate(err, "dhcpv4: %w") }()
|
||||
|
||||
if !s.conf.Enabled {
|
||||
if !s.enabled() {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1248,62 +1245,20 @@ func (s *v4Server) Stop() (err error) {
|
||||
}
|
||||
|
||||
// Create DHCPv4 server
|
||||
func v4Create(conf V4ServerConf) (srv DHCPServer, err error) {
|
||||
s := &v4Server{}
|
||||
s.conf = conf
|
||||
s.leaseHosts = stringutil.NewSet()
|
||||
|
||||
// TODO(a.garipov): Don't use a disabled server in other places or just
|
||||
// use an interface.
|
||||
if !conf.Enabled {
|
||||
return s, nil
|
||||
func v4Create(conf *V4ServerConf) (srv *v4Server, err error) {
|
||||
s := &v4Server{
|
||||
leaseHosts: stringutil.NewSet(),
|
||||
}
|
||||
|
||||
var routerIP net.IP
|
||||
routerIP, err = tryTo4(s.conf.GatewayIP)
|
||||
err = conf.Validate()
|
||||
if err != nil {
|
||||
return s, fmt.Errorf("dhcpv4: %w", err)
|
||||
// TODO(a.garipov): Don't use a disabled server in other places or just
|
||||
// use an interface.
|
||||
return s, err
|
||||
}
|
||||
|
||||
if s.conf.SubnetMask == nil {
|
||||
return s, fmt.Errorf("dhcpv4: invalid subnet mask: %v", s.conf.SubnetMask)
|
||||
}
|
||||
|
||||
subnetMask := make([]byte, 4)
|
||||
copy(subnetMask, s.conf.SubnetMask.To4())
|
||||
|
||||
s.conf.subnet = &net.IPNet{
|
||||
IP: routerIP,
|
||||
Mask: subnetMask,
|
||||
}
|
||||
s.conf.broadcastIP = aghnet.BroadcastFromIPNet(s.conf.subnet)
|
||||
|
||||
s.conf.ipRange, err = newIPRange(conf.RangeStart, conf.RangeEnd)
|
||||
if err != nil {
|
||||
return s, fmt.Errorf("dhcpv4: %w", err)
|
||||
}
|
||||
|
||||
if s.conf.ipRange.contains(routerIP) {
|
||||
return s, fmt.Errorf("dhcpv4: gateway ip %v in the ip range: %v-%v",
|
||||
routerIP,
|
||||
conf.RangeStart,
|
||||
conf.RangeEnd,
|
||||
)
|
||||
}
|
||||
|
||||
if !s.conf.subnet.Contains(conf.RangeStart) {
|
||||
return s, fmt.Errorf("dhcpv4: range start %v is outside network %v",
|
||||
conf.RangeStart,
|
||||
s.conf.subnet,
|
||||
)
|
||||
}
|
||||
|
||||
if !s.conf.subnet.Contains(conf.RangeEnd) {
|
||||
return s, fmt.Errorf("dhcpv4: range end %v is outside network %v",
|
||||
conf.RangeEnd,
|
||||
s.conf.subnet,
|
||||
)
|
||||
}
|
||||
s.conf = &V4ServerConf{}
|
||||
*s.conf = *conf
|
||||
|
||||
// TODO(a.garipov, d.seregin): Check that every lease is inside the IPRange.
|
||||
s.leasedOffsets = newBitSet()
|
||||
@@ -1315,7 +1270,7 @@ func v4Create(conf V4ServerConf) (srv DHCPServer, err error) {
|
||||
s.conf.leaseTime = time.Second * time.Duration(conf.LeaseDuration)
|
||||
}
|
||||
|
||||
s.implicitOpts, s.explicitOpts = prepareOptions(s.conf)
|
||||
s.prepareOptions()
|
||||
|
||||
return s, nil
|
||||
}
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
|
||||
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
|
||||
//go:build darwin || freebsd || linux || openbsd
|
||||
|
||||
package dhcpd
|
||||
|
||||
@@ -30,19 +29,16 @@ var (
|
||||
DefaultSubnetMask = net.IP{255, 255, 255, 0}
|
||||
)
|
||||
|
||||
func notify4(flags uint32) {
|
||||
}
|
||||
|
||||
// defaultV4ServerConf returns the default configuration for *v4Server to use in
|
||||
// tests.
|
||||
func defaultV4ServerConf() (conf V4ServerConf) {
|
||||
return V4ServerConf{
|
||||
func defaultV4ServerConf() (conf *V4ServerConf) {
|
||||
return &V4ServerConf{
|
||||
Enabled: true,
|
||||
RangeStart: DefaultRangeStart,
|
||||
RangeEnd: DefaultRangeEnd,
|
||||
GatewayIP: DefaultGatewayIP,
|
||||
SubnetMask: DefaultSubnetMask,
|
||||
notify: notify4,
|
||||
notify: testNotify,
|
||||
dnsIPAddrs: []net.IP{DefaultSelfIP},
|
||||
}
|
||||
}
|
||||
@@ -350,13 +346,10 @@ func TestV4Server_handle_optionsPriority(t *testing.T) {
|
||||
defer func() { s.implicitOpts.Update(dhcpv4.OptDNS(defaultIP)) }()
|
||||
}
|
||||
|
||||
ss, err := v4Create(conf)
|
||||
var err error
|
||||
s, err = v4Create(conf)
|
||||
require.NoError(t, err)
|
||||
|
||||
var ok bool
|
||||
s, ok = ss.(*v4Server)
|
||||
require.True(t, ok)
|
||||
|
||||
s.conf.dnsIPAddrs = []net.IP{defaultIP}
|
||||
|
||||
return s
|
||||
@@ -490,10 +483,9 @@ func TestV4Server_updateOptions(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
require.IsType(t, (*v4Server)(nil), s)
|
||||
s4, _ := s.(*v4Server)
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
s4.updateOptions(req, resp)
|
||||
s.updateOptions(req, resp)
|
||||
|
||||
for c, v := range tc.wantOpts {
|
||||
if v == nil {
|
||||
@@ -596,13 +588,9 @@ func TestV4DynamicLease_Get(t *testing.T) {
|
||||
"82 ip 1.2.3.4",
|
||||
}
|
||||
|
||||
var err error
|
||||
sIface, err := v4Create(conf)
|
||||
s, err := v4Create(conf)
|
||||
require.NoError(t, err)
|
||||
|
||||
s, ok := sIface.(*v4Server)
|
||||
require.True(t, ok)
|
||||
|
||||
s.conf.dnsIPAddrs = []net.IP{{192, 168, 10, 1}}
|
||||
s.implicitOpts.Update(dhcpv4.OptDNS(s.conf.dnsIPAddrs...))
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
|
||||
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
|
||||
//go:build darwin || freebsd || linux || openbsd
|
||||
|
||||
package dhcpd
|
||||
|
||||
@@ -173,7 +172,7 @@ func (s *v6Server) rmDynamicLease(lease *Lease) (err error) {
|
||||
func (s *v6Server) AddStaticLease(l *Lease) (err error) {
|
||||
defer func() { err = errors.Annotate(err, "dhcpv6: %w") }()
|
||||
|
||||
if len(l.IP) != 16 {
|
||||
if len(l.IP) != net.IPv6len {
|
||||
return fmt.Errorf("invalid IP")
|
||||
}
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
|
||||
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
|
||||
//go:build darwin || freebsd || linux || openbsd
|
||||
|
||||
package dhcpd
|
||||
|
||||
@@ -128,9 +128,14 @@ type FilteringConfig struct {
|
||||
// IpsetList is the ipset configuration that allows AdGuard Home to add
|
||||
// IP addresses of the specified domain names to an ipset list. Syntax:
|
||||
//
|
||||
// DOMAIN[,DOMAIN].../IPSET_NAME
|
||||
// DOMAIN[,DOMAIN].../IPSET_NAME
|
||||
//
|
||||
// This field is ignored if [IpsetListFileName] is set.
|
||||
IpsetList []string `yaml:"ipset"`
|
||||
|
||||
// IpsetListFileName, if set, points to the file with ipset configuration.
|
||||
// The format is the same as in [IpsetList].
|
||||
IpsetListFileName string `yaml:"ipset_file"`
|
||||
}
|
||||
|
||||
// TLSConfig is the TLS configuration for HTTPS, DNS-over-HTTPS, and DNS-over-TLS
|
||||
@@ -400,6 +405,26 @@ func setProxyUpstreamMode(
|
||||
}
|
||||
}
|
||||
|
||||
// prepareIpsetListSettings reads and prepares the ipset configuration either
|
||||
// from a file or from the data in the configuration file.
|
||||
func (s *Server) prepareIpsetListSettings() (err error) {
|
||||
fn := s.conf.IpsetListFileName
|
||||
if fn == "" {
|
||||
return s.ipset.init(s.conf.IpsetList)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(fn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ipsets := stringutil.SplitTrimmed(string(data), "\n")
|
||||
|
||||
log.Debug("dns: using %d ipset rules from file %q", len(ipsets), fn)
|
||||
|
||||
return s.ipset.init(ipsets)
|
||||
}
|
||||
|
||||
// prepareTLS - prepares TLS configuration for the DNS proxy
|
||||
func (s *Server) prepareTLS(proxyConfig *proxy.Config) error {
|
||||
if len(s.conf.CertificateChainData) == 0 || len(s.conf.PrivateKeyData) == 0 {
|
||||
|
||||
@@ -296,7 +296,7 @@ func (s *Server) makeDDRResponse(req *dns.Msg) (resp *dns.Msg) {
|
||||
values := []dns.SVCBKeyValue{
|
||||
&dns.SVCBAlpn{Alpn: []string{"h2"}},
|
||||
&dns.SVCBPort{Port: uint16(addr.Port)},
|
||||
&dns.SVCBDoHPath{Template: "/dns-query?dns"},
|
||||
&dns.SVCBDoHPath{Template: "/dns-query{?dns}"},
|
||||
}
|
||||
|
||||
ans := &dns.SVCB{
|
||||
|
||||
@@ -26,7 +26,7 @@ func TestServer_ProcessDDRQuery(t *testing.T) {
|
||||
Value: []dns.SVCBKeyValue{
|
||||
&dns.SVCBAlpn{Alpn: []string{"h2"}},
|
||||
&dns.SVCBPort{Port: 8044},
|
||||
&dns.SVCBDoHPath{Template: "/dns-query?dns"},
|
||||
&dns.SVCBDoHPath{Template: "/dns-query{?dns}"},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -267,7 +267,7 @@ func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
s := &Server{
|
||||
dhcpServer: &testDHCP{},
|
||||
dhcpServer: testDHCP,
|
||||
localDomainSuffix: defaultLocalDomainSuffix,
|
||||
tableHostToIP: hostToIPTable{
|
||||
"example." + defaultLocalDomainSuffix: knownIP,
|
||||
@@ -378,7 +378,7 @@ func TestServer_ProcessDHCPHosts(t *testing.T) {
|
||||
|
||||
for _, tc := range testCases {
|
||||
s := &Server{
|
||||
dhcpServer: &testDHCP{},
|
||||
dhcpServer: testDHCP,
|
||||
localDomainSuffix: tc.suffix,
|
||||
tableHostToIP: hostToIPTable{
|
||||
"example." + tc.suffix: knownIP,
|
||||
|
||||
@@ -58,10 +58,10 @@ type hostToIPTable = map[string]net.IP
|
||||
//
|
||||
// The zero Server is empty and ready for use.
|
||||
type Server struct {
|
||||
dnsProxy *proxy.Proxy // DNS proxy instance
|
||||
dnsFilter *filtering.DNSFilter // DNS filter instance
|
||||
dhcpServer dhcpd.ServerInterface // DHCP server instance (optional)
|
||||
queryLog querylog.QueryLog // Query log instance
|
||||
dnsProxy *proxy.Proxy // DNS proxy instance
|
||||
dnsFilter *filtering.DNSFilter // DNS filter instance
|
||||
dhcpServer dhcpd.Interface // DHCP server instance (optional)
|
||||
queryLog querylog.QueryLog // Query log instance
|
||||
stats stats.Interface
|
||||
access *accessCtx
|
||||
|
||||
@@ -110,7 +110,7 @@ type DNSCreateParams struct {
|
||||
DNSFilter *filtering.DNSFilter
|
||||
Stats stats.Interface
|
||||
QueryLog querylog.QueryLog
|
||||
DHCPServer dhcpd.ServerInterface
|
||||
DHCPServer dhcpd.Interface
|
||||
PrivateNets netutil.SubnetSet
|
||||
Anonymizer *aghnet.IPMut
|
||||
LocalDomain string
|
||||
@@ -446,10 +446,10 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
|
||||
|
||||
s.initDefaultSettings()
|
||||
|
||||
err = s.ipset.init(s.conf.IpsetList)
|
||||
err = s.prepareIpsetListSettings()
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return err
|
||||
return fmt.Errorf("preparing ipset settings: %w", err)
|
||||
}
|
||||
|
||||
err = s.prepareUpstreamSettings()
|
||||
|
||||
@@ -67,12 +67,13 @@ func createTestServer(
|
||||
ID: 0, Data: []byte(rules),
|
||||
}}
|
||||
|
||||
f := filtering.New(filterConf, filters)
|
||||
f, err := filtering.New(filterConf, filters)
|
||||
require.NoError(t, err)
|
||||
|
||||
f.SetEnabled(true)
|
||||
|
||||
var err error
|
||||
s, err = NewServer(DNSCreateParams{
|
||||
DHCPServer: &testDHCP{},
|
||||
DHCPServer: testDHCP,
|
||||
DNSFilter: f,
|
||||
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
})
|
||||
@@ -774,9 +775,11 @@ func TestBlockedCustomIP(t *testing.T) {
|
||||
Data: []byte(rules),
|
||||
}}
|
||||
|
||||
f := filtering.New(&filtering.Config{}, filters)
|
||||
f, err := filtering.New(&filtering.Config{}, filters)
|
||||
require.NoError(t, err)
|
||||
|
||||
s, err := NewServer(DNSCreateParams{
|
||||
DHCPServer: &testDHCP{},
|
||||
DHCPServer: testDHCP,
|
||||
DNSFilter: f,
|
||||
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
})
|
||||
@@ -906,11 +909,13 @@ func TestRewrite(t *testing.T) {
|
||||
Type: dns.TypeCNAME,
|
||||
}},
|
||||
}
|
||||
f := filtering.New(c, nil)
|
||||
f, err := filtering.New(c, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
f.SetEnabled(true)
|
||||
|
||||
s, err := NewServer(DNSCreateParams{
|
||||
DHCPServer: &testDHCP{},
|
||||
DHCPServer: testDHCP,
|
||||
DNSFilter: f,
|
||||
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
})
|
||||
@@ -1005,26 +1010,31 @@ func publicKey(priv any) any {
|
||||
}
|
||||
}
|
||||
|
||||
type testDHCP struct{}
|
||||
|
||||
func (d *testDHCP) Enabled() (ok bool) { return true }
|
||||
|
||||
func (d *testDHCP) Leases(flags dhcpd.GetLeasesFlags) (leases []*dhcpd.Lease) {
|
||||
return []*dhcpd.Lease{{
|
||||
IP: net.IP{192, 168, 12, 34},
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
Hostname: "myhost",
|
||||
}}
|
||||
var testDHCP = &dhcpd.MockInterface{
|
||||
OnStart: func() (err error) { panic("not implemented") },
|
||||
OnStop: func() (err error) { panic("not implemented") },
|
||||
OnEnabled: func() (ok bool) { return true },
|
||||
OnLeases: func(flags dhcpd.GetLeasesFlags) (leases []*dhcpd.Lease) {
|
||||
return []*dhcpd.Lease{{
|
||||
IP: net.IP{192, 168, 12, 34},
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
Hostname: "myhost",
|
||||
}}
|
||||
},
|
||||
OnSetOnLeaseChanged: func(olct dhcpd.OnLeaseChangedT) {},
|
||||
OnFindMACbyIP: func(ip net.IP) (mac net.HardwareAddr) { panic("not implemented") },
|
||||
OnWriteDiskConfig: func(c *dhcpd.ServerConfig) { panic("not implemented") },
|
||||
}
|
||||
|
||||
func (d *testDHCP) SetOnLeaseChanged(onLeaseChanged dhcpd.OnLeaseChangedT) {}
|
||||
|
||||
func TestPTRResponseFromDHCPLeases(t *testing.T) {
|
||||
const localDomain = "lan"
|
||||
|
||||
flt, err := filtering.New(&filtering.Config{}, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
s, err := NewServer(DNSCreateParams{
|
||||
DNSFilter: filtering.New(&filtering.Config{}, nil),
|
||||
DHCPServer: &testDHCP{},
|
||||
DNSFilter: flt,
|
||||
DHCPServer: testDHCP,
|
||||
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
LocalDomain: localDomain,
|
||||
})
|
||||
@@ -1090,14 +1100,16 @@ func TestPTRResponseFromHosts(t *testing.T) {
|
||||
assert.Equal(t, uint32(1), atomic.LoadUint32(&eventsCalledCounter))
|
||||
})
|
||||
|
||||
flt := filtering.New(&filtering.Config{
|
||||
flt, err := filtering.New(&filtering.Config{
|
||||
EtcHosts: hc,
|
||||
}, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
flt.SetEnabled(true)
|
||||
|
||||
var s *Server
|
||||
s, err = NewServer(DNSCreateParams{
|
||||
DHCPServer: &testDHCP{},
|
||||
DHCPServer: testDHCP,
|
||||
DNSFilter: flt,
|
||||
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
})
|
||||
|
||||
@@ -35,11 +35,12 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
|
||||
ID: 0, Data: []byte(rules),
|
||||
}}
|
||||
|
||||
f := filtering.New(&filtering.Config{}, filters)
|
||||
f, err := filtering.New(&filtering.Config{}, filters)
|
||||
require.NoError(t, err)
|
||||
f.SetEnabled(true)
|
||||
|
||||
s, err := NewServer(DNSCreateParams{
|
||||
DHCPServer: &testDHCP{},
|
||||
DHCPServer: testDHCP,
|
||||
DNSFilter: f,
|
||||
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
})
|
||||
|
||||
@@ -421,31 +421,34 @@ func initBlockedServices() {
|
||||
}
|
||||
|
||||
// BlockedSvcKnown - return TRUE if a blocked service name is known
|
||||
func BlockedSvcKnown(s string) bool {
|
||||
_, ok := serviceRules[s]
|
||||
func BlockedSvcKnown(s string) (ok bool) {
|
||||
_, ok = serviceRules[s]
|
||||
|
||||
return ok
|
||||
}
|
||||
|
||||
// ApplyBlockedServices - set blocked services settings for this DNS request
|
||||
func (d *DNSFilter) ApplyBlockedServices(setts *Settings, list []string, global bool) {
|
||||
func (d *DNSFilter) ApplyBlockedServices(setts *Settings, list []string) {
|
||||
setts.ServicesRules = []ServiceEntry{}
|
||||
if global {
|
||||
if list == nil {
|
||||
d.confLock.RLock()
|
||||
defer d.confLock.RUnlock()
|
||||
|
||||
list = d.Config.BlockedServices
|
||||
}
|
||||
|
||||
for _, name := range list {
|
||||
rules, ok := serviceRules[name]
|
||||
|
||||
if !ok {
|
||||
log.Error("unknown service name: %s", name)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
s := ServiceEntry{}
|
||||
s.Name = name
|
||||
s.Rules = rules
|
||||
setts.ServicesRules = append(setts.ServicesRules, s)
|
||||
setts.ServicesRules = append(setts.ServicesRules, ServiceEntry{
|
||||
Name: name,
|
||||
Rules: rules,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -490,10 +493,3 @@ func (d *DNSFilter) handleBlockedServicesSet(w http.ResponseWriter, r *http.Requ
|
||||
|
||||
d.ConfigModified()
|
||||
}
|
||||
|
||||
// registerBlockedServicesHandlers - register HTTP handlers
|
||||
func (d *DNSFilter) registerBlockedServicesHandlers() {
|
||||
d.Config.HTTPRegister(http.MethodGet, "/control/blocked_services/services", d.handleBlockedServicesAvailableServices)
|
||||
d.Config.HTTPRegister(http.MethodGet, "/control/blocked_services/list", d.handleBlockedServicesList)
|
||||
d.Config.HTTPRegister(http.MethodPost, "/control/blocked_services/set", d.handleBlockedServicesSet)
|
||||
}
|
||||
|
||||
@@ -49,7 +49,7 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
|
||||
|1.2.3.5.in-addr.arpa^$dnsrewrite=NOERROR;PTR;new-ptr-with-dot.
|
||||
`
|
||||
|
||||
f := newForTest(t, nil, []Filter{{ID: 0, Data: []byte(text)}})
|
||||
f, _ := newForTest(t, nil, []Filter{{ID: 0, Data: []byte(text)}})
|
||||
setts := &Settings{
|
||||
FilteringEnabled: true,
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package home
|
||||
package filtering
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
@@ -8,63 +8,29 @@ import (
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
var nextFilterID = time.Now().Unix() // semi-stable way to generate an unique ID
|
||||
// filterDir is the subdirectory of a data directory to store downloaded
|
||||
// filters.
|
||||
const filterDir = "filters"
|
||||
|
||||
// Filtering - module object
|
||||
type Filtering struct {
|
||||
// conf FilteringConf
|
||||
refreshStatus uint32 // 0:none; 1:in progress
|
||||
refreshLock sync.Mutex
|
||||
filterTitleRegexp *regexp.Regexp
|
||||
}
|
||||
// nextFilterID is a way to seed a unique ID generation.
|
||||
//
|
||||
// TODO(e.burkov): Use more deterministic approach.
|
||||
var nextFilterID = time.Now().Unix()
|
||||
|
||||
// Init - initialize the module
|
||||
func (f *Filtering) Init() {
|
||||
f.filterTitleRegexp = regexp.MustCompile(`^! Title: +(.*)$`)
|
||||
_ = os.MkdirAll(filepath.Join(Context.getDataDir(), filterDir), 0o755)
|
||||
f.loadFilters(config.Filters)
|
||||
f.loadFilters(config.WhitelistFilters)
|
||||
deduplicateFilters()
|
||||
updateUniqueFilterID(config.Filters)
|
||||
updateUniqueFilterID(config.WhitelistFilters)
|
||||
}
|
||||
|
||||
// Start - start the module
|
||||
func (f *Filtering) Start() {
|
||||
f.RegisterFilteringHandlers()
|
||||
|
||||
// Here we should start updating filters,
|
||||
// but currently we can't wake up the periodic task to do so.
|
||||
// So for now we just start this periodic task from here.
|
||||
go f.periodicallyRefreshFilters()
|
||||
}
|
||||
|
||||
// Close - close the module
|
||||
func (f *Filtering) Close() {
|
||||
}
|
||||
|
||||
func defaultFilters() []filter {
|
||||
return []filter{
|
||||
{Filter: filtering.Filter{ID: 1}, Enabled: true, URL: "https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt", Name: "AdGuard DNS filter"},
|
||||
{Filter: filtering.Filter{ID: 2}, Enabled: false, URL: "https://adaway.org/hosts.txt", Name: "AdAway Default Blocklist"},
|
||||
}
|
||||
}
|
||||
|
||||
// field ordering is important -- yaml fields will mirror ordering from here
|
||||
type filter struct {
|
||||
// FilterYAML respresents a filter list in the configuration file.
|
||||
//
|
||||
// TODO(e.burkov): Investigate if the field oredering is important.
|
||||
type FilterYAML struct {
|
||||
Enabled bool
|
||||
URL string // URL or a file path
|
||||
Name string `yaml:"name"`
|
||||
@@ -73,91 +39,108 @@ type filter struct {
|
||||
checksum uint32 // checksum of the file data
|
||||
white bool
|
||||
|
||||
filtering.Filter `yaml:",inline"`
|
||||
Filter `yaml:",inline"`
|
||||
}
|
||||
|
||||
// Clear filter rules
|
||||
func (filter *FilterYAML) unload() {
|
||||
filter.RulesCount = 0
|
||||
filter.checksum = 0
|
||||
}
|
||||
|
||||
// Path to the filter contents
|
||||
func (filter *FilterYAML) Path(dataDir string) string {
|
||||
return filepath.Join(dataDir, filterDir, strconv.FormatInt(filter.ID, 10)+".txt")
|
||||
}
|
||||
|
||||
const (
|
||||
statusFound = 1
|
||||
statusEnabledChanged = 2
|
||||
statusURLChanged = 4
|
||||
statusURLExists = 8
|
||||
statusUpdateRequired = 0x10
|
||||
statusFound = 1 << iota
|
||||
statusEnabledChanged
|
||||
statusURLChanged
|
||||
statusURLExists
|
||||
statusUpdateRequired
|
||||
)
|
||||
|
||||
// Update properties for a filter specified by its URL
|
||||
// Return status* flags.
|
||||
func (f *Filtering) filterSetProperties(url string, newf filter, whitelist bool) int {
|
||||
func (d *DNSFilter) filterSetProperties(url string, newf FilterYAML, whitelist bool) int {
|
||||
r := 0
|
||||
config.Lock()
|
||||
defer config.Unlock()
|
||||
d.filtersMu.Lock()
|
||||
defer d.filtersMu.Unlock()
|
||||
|
||||
filters := &config.Filters
|
||||
filters := d.Filters
|
||||
if whitelist {
|
||||
filters = &config.WhitelistFilters
|
||||
filters = d.WhitelistFilters
|
||||
}
|
||||
|
||||
for i := range *filters {
|
||||
filt := &(*filters)[i]
|
||||
if filt.URL != url {
|
||||
continue
|
||||
i := slices.IndexFunc(filters, func(filt FilterYAML) bool {
|
||||
return filt.URL == url
|
||||
})
|
||||
if i == -1 {
|
||||
return 0
|
||||
}
|
||||
|
||||
filt := &filters[i]
|
||||
|
||||
log.Debug("filter: set properties: %s: {%s %s %v}", filt.URL, newf.Name, newf.URL, newf.Enabled)
|
||||
filt.Name = newf.Name
|
||||
|
||||
if filt.URL != newf.URL {
|
||||
r |= statusURLChanged | statusUpdateRequired
|
||||
if d.filterExistsNoLock(newf.URL) {
|
||||
return statusURLExists
|
||||
}
|
||||
|
||||
log.Debug("filter: set properties: %s: {%s %s %v}",
|
||||
filt.URL, newf.Name, newf.URL, newf.Enabled)
|
||||
filt.Name = newf.Name
|
||||
filt.URL = newf.URL
|
||||
filt.unload()
|
||||
filt.LastUpdated = time.Time{}
|
||||
filt.checksum = 0
|
||||
filt.RulesCount = 0
|
||||
}
|
||||
|
||||
if filt.URL != newf.URL {
|
||||
r |= statusURLChanged | statusUpdateRequired
|
||||
if filterExistsNoLock(newf.URL) {
|
||||
return statusURLExists
|
||||
}
|
||||
filt.URL = newf.URL
|
||||
filt.unload()
|
||||
filt.LastUpdated = time.Time{}
|
||||
filt.checksum = 0
|
||||
filt.RulesCount = 0
|
||||
}
|
||||
if filt.Enabled != newf.Enabled {
|
||||
r |= statusEnabledChanged
|
||||
filt.Enabled = newf.Enabled
|
||||
if filt.Enabled {
|
||||
if (r & statusURLChanged) == 0 {
|
||||
err := d.load(filt)
|
||||
if err != nil {
|
||||
// TODO(e.burkov): It seems the error is only returned when
|
||||
// the file exists and couldn't be open. Investigate and
|
||||
// improve.
|
||||
log.Error("loading filter %d: %s", filt.ID, err)
|
||||
|
||||
if filt.Enabled != newf.Enabled {
|
||||
r |= statusEnabledChanged
|
||||
filt.Enabled = newf.Enabled
|
||||
if filt.Enabled {
|
||||
if (r & statusURLChanged) == 0 {
|
||||
e := f.load(filt)
|
||||
if e != nil {
|
||||
// This isn't a fatal error,
|
||||
// because it may occur when someone removes the file from disk.
|
||||
filt.LastUpdated = time.Time{}
|
||||
filt.checksum = 0
|
||||
filt.RulesCount = 0
|
||||
r |= statusUpdateRequired
|
||||
}
|
||||
filt.LastUpdated = time.Time{}
|
||||
filt.checksum = 0
|
||||
filt.RulesCount = 0
|
||||
r |= statusUpdateRequired
|
||||
}
|
||||
} else {
|
||||
filt.unload()
|
||||
}
|
||||
} else {
|
||||
filt.unload()
|
||||
}
|
||||
|
||||
return r | statusFound
|
||||
}
|
||||
return 0
|
||||
|
||||
return r | statusFound
|
||||
}
|
||||
|
||||
// Return TRUE if a filter with this URL exists
|
||||
func filterExists(url string) bool {
|
||||
config.RLock()
|
||||
r := filterExistsNoLock(url)
|
||||
config.RUnlock()
|
||||
func (d *DNSFilter) filterExists(url string) bool {
|
||||
d.filtersMu.RLock()
|
||||
defer d.filtersMu.RUnlock()
|
||||
|
||||
r := d.filterExistsNoLock(url)
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
func filterExistsNoLock(url string) bool {
|
||||
for _, f := range config.Filters {
|
||||
func (d *DNSFilter) filterExistsNoLock(url string) bool {
|
||||
for _, f := range d.Filters {
|
||||
if f.URL == url {
|
||||
return true
|
||||
}
|
||||
}
|
||||
for _, f := range config.WhitelistFilters {
|
||||
for _, f := range d.WhitelistFilters {
|
||||
if f.URL == url {
|
||||
return true
|
||||
}
|
||||
@@ -167,26 +150,26 @@ func filterExistsNoLock(url string) bool {
|
||||
|
||||
// Add a filter
|
||||
// Return FALSE if a filter with this URL exists
|
||||
func filterAdd(f filter) bool {
|
||||
config.Lock()
|
||||
defer config.Unlock()
|
||||
func (d *DNSFilter) filterAdd(flt FilterYAML) bool {
|
||||
d.filtersMu.Lock()
|
||||
defer d.filtersMu.Unlock()
|
||||
|
||||
// Check for duplicates
|
||||
if filterExistsNoLock(f.URL) {
|
||||
if d.filterExistsNoLock(flt.URL) {
|
||||
return false
|
||||
}
|
||||
|
||||
if f.white {
|
||||
config.WhitelistFilters = append(config.WhitelistFilters, f)
|
||||
if flt.white {
|
||||
d.WhitelistFilters = append(d.WhitelistFilters, flt)
|
||||
} else {
|
||||
config.Filters = append(config.Filters, f)
|
||||
d.Filters = append(d.Filters, flt)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Load filters from the disk
|
||||
// And if any filter has zero ID, assign a new one
|
||||
func (f *Filtering) loadFilters(array []filter) {
|
||||
func (d *DNSFilter) loadFilters(array []FilterYAML) {
|
||||
for i := range array {
|
||||
filter := &array[i] // otherwise we're operating on a copy
|
||||
if filter.ID == 0 {
|
||||
@@ -198,32 +181,30 @@ func (f *Filtering) loadFilters(array []filter) {
|
||||
continue
|
||||
}
|
||||
|
||||
err := f.load(filter)
|
||||
err := d.load(filter)
|
||||
if err != nil {
|
||||
log.Error("Couldn't load filter %d contents due to %s", filter.ID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func deduplicateFilters() {
|
||||
// Deduplicate filters
|
||||
i := 0 // output index, used for deletion later
|
||||
urls := map[string]bool{}
|
||||
for _, filter := range config.Filters {
|
||||
if _, ok := urls[filter.URL]; !ok {
|
||||
// we didn't see it before, keep it
|
||||
urls[filter.URL] = true // remember the URL
|
||||
config.Filters[i] = filter
|
||||
i++
|
||||
func deduplicateFilters(filters []FilterYAML) (deduplicated []FilterYAML) {
|
||||
urls := stringutil.NewSet()
|
||||
lastIdx := 0
|
||||
|
||||
for _, filter := range filters {
|
||||
if !urls.Has(filter.URL) {
|
||||
urls.Add(filter.URL)
|
||||
filters[lastIdx] = filter
|
||||
lastIdx++
|
||||
}
|
||||
}
|
||||
|
||||
// all entries we want to keep are at front, delete the rest
|
||||
config.Filters = config.Filters[:i]
|
||||
return filters[:lastIdx]
|
||||
}
|
||||
|
||||
// Set the next filter ID to max(filter.ID) + 1
|
||||
func updateUniqueFilterID(filters []filter) {
|
||||
func updateUniqueFilterID(filters []FilterYAML) {
|
||||
for _, filter := range filters {
|
||||
if nextFilterID < filter.ID {
|
||||
nextFilterID = filter.ID + 1
|
||||
@@ -238,22 +219,19 @@ func assignUniqueFilterID() int64 {
|
||||
}
|
||||
|
||||
// Sets up a timer that will be checking for filters updates periodically
|
||||
func (f *Filtering) periodicallyRefreshFilters() {
|
||||
func (d *DNSFilter) periodicallyRefreshFilters() {
|
||||
const maxInterval = 1 * 60 * 60
|
||||
intval := 5 // use a dynamically increasing time interval
|
||||
for {
|
||||
isNetworkErr := false
|
||||
if config.DNS.FiltersUpdateIntervalHours != 0 && atomic.CompareAndSwapUint32(&f.refreshStatus, 0, 1) {
|
||||
f.refreshLock.Lock()
|
||||
_, isNetworkErr = f.refreshFiltersIfNecessary(filterRefreshBlocklists | filterRefreshAllowlists)
|
||||
f.refreshLock.Unlock()
|
||||
f.refreshStatus = 0
|
||||
if !isNetworkErr {
|
||||
isNetErr, ok := false, false
|
||||
if d.FiltersUpdateIntervalHours != 0 {
|
||||
_, isNetErr, ok = d.tryRefreshFilters(true, true, false)
|
||||
if ok && !isNetErr {
|
||||
intval = maxInterval
|
||||
}
|
||||
}
|
||||
|
||||
if isNetworkErr {
|
||||
if isNetErr {
|
||||
intval *= 2
|
||||
if intval > maxInterval {
|
||||
intval = maxInterval
|
||||
@@ -264,51 +242,73 @@ func (f *Filtering) periodicallyRefreshFilters() {
|
||||
}
|
||||
}
|
||||
|
||||
// Refresh filters
|
||||
// flags: filterRefresh*
|
||||
// important:
|
||||
// tryRefreshFilters is like [refreshFilters], but backs down if the update is
|
||||
// already going on.
|
||||
//
|
||||
// TRUE: ignore the fact that we're currently updating the filters
|
||||
func (f *Filtering) refreshFilters(flags int, important bool) (int, error) {
|
||||
set := atomic.CompareAndSwapUint32(&f.refreshStatus, 0, 1)
|
||||
if !important && !set {
|
||||
return 0, fmt.Errorf("filters update procedure is already running")
|
||||
// TODO(e.burkov): Get rid of the concurrency pattern which requires the
|
||||
// sync.Mutex.TryLock.
|
||||
func (d *DNSFilter) tryRefreshFilters(block, allow, force bool) (updated int, isNetworkErr, ok bool) {
|
||||
if ok = d.refreshLock.TryLock(); !ok {
|
||||
return 0, false, ok
|
||||
}
|
||||
defer d.refreshLock.Unlock()
|
||||
|
||||
f.refreshLock.Lock()
|
||||
nUpdated, _ := f.refreshFiltersIfNecessary(flags)
|
||||
f.refreshLock.Unlock()
|
||||
f.refreshStatus = 0
|
||||
return nUpdated, nil
|
||||
updated, isNetworkErr = d.refreshFiltersIntl(block, allow, force)
|
||||
|
||||
return updated, isNetworkErr, ok
|
||||
}
|
||||
|
||||
func (f *Filtering) refreshFiltersArray(filters *[]filter, force bool) (int, []filter, []bool, bool) {
|
||||
var updateFilters []filter
|
||||
// refreshFilters updates the lists and returns the number of updated ones.
|
||||
// It's safe for concurrent use, but blocks at least until the previous
|
||||
// refreshing is finished.
|
||||
func (d *DNSFilter) refreshFilters(block, allow, force bool) (updated int) {
|
||||
d.refreshLock.Lock()
|
||||
defer d.refreshLock.Unlock()
|
||||
|
||||
updated, _ = d.refreshFiltersIntl(block, allow, force)
|
||||
|
||||
return updated
|
||||
}
|
||||
|
||||
// listsToUpdate returns the slice of filter lists that could be updated.
|
||||
func (d *DNSFilter) listsToUpdate(filters *[]FilterYAML, force bool) (toUpd []FilterYAML) {
|
||||
now := time.Now()
|
||||
|
||||
d.filtersMu.RLock()
|
||||
defer d.filtersMu.RUnlock()
|
||||
|
||||
for i := range *filters {
|
||||
flt := &(*filters)[i] // otherwise we will be operating on a copy
|
||||
log.Debug("checking list at index %d: %v", i, flt)
|
||||
|
||||
if !flt.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
if !force {
|
||||
exp := flt.LastUpdated.Add(time.Duration(d.FiltersUpdateIntervalHours) * time.Hour)
|
||||
if now.Before(exp) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
toUpd = append(toUpd, FilterYAML{
|
||||
Filter: Filter{
|
||||
ID: flt.ID,
|
||||
},
|
||||
URL: flt.URL,
|
||||
Name: flt.Name,
|
||||
checksum: flt.checksum,
|
||||
})
|
||||
}
|
||||
|
||||
return toUpd
|
||||
}
|
||||
|
||||
func (d *DNSFilter) refreshFiltersArray(filters *[]FilterYAML, force bool) (int, []FilterYAML, []bool, bool) {
|
||||
var updateFlags []bool // 'true' if filter data has changed
|
||||
|
||||
now := time.Now()
|
||||
config.RLock()
|
||||
for i := range *filters {
|
||||
f := &(*filters)[i] // otherwise we will be operating on a copy
|
||||
|
||||
if !f.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
expireTime := f.LastUpdated.Unix() + int64(config.DNS.FiltersUpdateIntervalHours)*60*60
|
||||
if !force && expireTime > now.Unix() {
|
||||
continue
|
||||
}
|
||||
|
||||
var uf filter
|
||||
uf.ID = f.ID
|
||||
uf.URL = f.URL
|
||||
uf.Name = f.Name
|
||||
uf.checksum = f.checksum
|
||||
updateFilters = append(updateFilters, uf)
|
||||
}
|
||||
config.RUnlock()
|
||||
|
||||
updateFilters := d.listsToUpdate(filters, force)
|
||||
if len(updateFilters) == 0 {
|
||||
return 0, nil, nil, false
|
||||
}
|
||||
@@ -316,7 +316,7 @@ func (f *Filtering) refreshFiltersArray(filters *[]filter, force bool) (int, []f
|
||||
nfail := 0
|
||||
for i := range updateFilters {
|
||||
uf := &updateFilters[i]
|
||||
updated, err := f.update(uf)
|
||||
updated, err := d.update(uf)
|
||||
updateFlags = append(updateFlags, updated)
|
||||
if err != nil {
|
||||
nfail++
|
||||
@@ -334,7 +334,7 @@ func (f *Filtering) refreshFiltersArray(filters *[]filter, force bool) (int, []f
|
||||
uf := &updateFilters[i]
|
||||
updated := updateFlags[i]
|
||||
|
||||
config.Lock()
|
||||
d.filtersMu.Lock()
|
||||
for k := range *filters {
|
||||
f := &(*filters)[k]
|
||||
if f.ID != uf.ID || f.URL != uf.URL {
|
||||
@@ -352,20 +352,14 @@ func (f *Filtering) refreshFiltersArray(filters *[]filter, force bool) (int, []f
|
||||
f.checksum = uf.checksum
|
||||
updateCount++
|
||||
}
|
||||
config.Unlock()
|
||||
d.filtersMu.Unlock()
|
||||
}
|
||||
|
||||
return updateCount, updateFilters, updateFlags, false
|
||||
}
|
||||
|
||||
const (
|
||||
filterRefreshForce = 1 // ignore last file modification date
|
||||
filterRefreshAllowlists = 2 // update allow-lists
|
||||
filterRefreshBlocklists = 4 // update block-lists
|
||||
)
|
||||
|
||||
// refreshFiltersIfNecessary checks filters and updates them if necessary. If
|
||||
// force is true, it ignores the filter.LastUpdated field value.
|
||||
// refreshFiltersIntl checks filters and updates them if necessary. If force is
|
||||
// true, it ignores the filter.LastUpdated field value.
|
||||
//
|
||||
// Algorithm:
|
||||
//
|
||||
@@ -378,53 +372,49 @@ const (
|
||||
// that this method works only on Unix systems. On Windows, don't pass
|
||||
// files to filtering, pass the whole data.
|
||||
//
|
||||
// refreshFiltersIfNecessary returns the number of updated filters. It also
|
||||
// returns true if there was a network error and nothing could be updated.
|
||||
// refreshFiltersIntl returns the number of updated filters. It also returns
|
||||
// true if there was a network error and nothing could be updated.
|
||||
//
|
||||
// TODO(a.garipov, e.burkov): What the hell?
|
||||
func (f *Filtering) refreshFiltersIfNecessary(flags int) (int, bool) {
|
||||
log.Debug("Filters: updating...")
|
||||
func (d *DNSFilter) refreshFiltersIntl(block, allow, force bool) (int, bool) {
|
||||
log.Debug("filtering: updating...")
|
||||
|
||||
updateCount := 0
|
||||
var updateFilters []filter
|
||||
var updateFlags []bool
|
||||
netError := false
|
||||
netErrorW := false
|
||||
force := false
|
||||
if (flags & filterRefreshForce) != 0 {
|
||||
force = true
|
||||
updNum := 0
|
||||
var lists []FilterYAML
|
||||
var toUpd []bool
|
||||
isNetErr := false
|
||||
|
||||
if block {
|
||||
updNum, lists, toUpd, isNetErr = d.refreshFiltersArray(&d.Filters, force)
|
||||
}
|
||||
if (flags & filterRefreshBlocklists) != 0 {
|
||||
updateCount, updateFilters, updateFlags, netError = f.refreshFiltersArray(&config.Filters, force)
|
||||
if allow {
|
||||
updNumAl, listsAl, toUpdAl, isNetErrAl := d.refreshFiltersArray(&d.WhitelistFilters, force)
|
||||
|
||||
updNum += updNumAl
|
||||
lists = append(lists, listsAl...)
|
||||
toUpd = append(toUpd, toUpdAl...)
|
||||
isNetErr = isNetErr || isNetErrAl
|
||||
}
|
||||
if (flags & filterRefreshAllowlists) != 0 {
|
||||
updateCountW := 0
|
||||
var updateFiltersW []filter
|
||||
var updateFlagsW []bool
|
||||
updateCountW, updateFiltersW, updateFlagsW, netErrorW = f.refreshFiltersArray(&config.WhitelistFilters, force)
|
||||
updateCount += updateCountW
|
||||
updateFilters = append(updateFilters, updateFiltersW...)
|
||||
updateFlags = append(updateFlags, updateFlagsW...)
|
||||
}
|
||||
if netError && netErrorW {
|
||||
if isNetErr {
|
||||
return 0, true
|
||||
}
|
||||
|
||||
if updateCount != 0 {
|
||||
enableFilters(false)
|
||||
if updNum != 0 {
|
||||
d.EnableFilters(false)
|
||||
|
||||
for i := range updateFilters {
|
||||
uf := &updateFilters[i]
|
||||
updated := updateFlags[i]
|
||||
for i := range lists {
|
||||
uf := &lists[i]
|
||||
updated := toUpd[i]
|
||||
if !updated {
|
||||
continue
|
||||
}
|
||||
_ = os.Remove(uf.Path() + ".old")
|
||||
_ = os.Remove(uf.Path(d.DataDir) + ".old")
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug("Filters: update finished")
|
||||
return updateCount, false
|
||||
log.Debug("filtering: update finished")
|
||||
|
||||
return updNum, false
|
||||
}
|
||||
|
||||
// Allows printable UTF-8 text with CR, LF, TAB characters
|
||||
@@ -440,7 +430,7 @@ func isPrintableText(data []byte, len int) bool {
|
||||
}
|
||||
|
||||
// A helper function that parses filter contents and returns a number of rules and a filter name (if there's any)
|
||||
func (f *Filtering) parseFilterContents(file io.Reader) (int, uint32, string) {
|
||||
func (d *DNSFilter) parseFilterContents(file io.Reader) (int, uint32, string) {
|
||||
rulesCount := 0
|
||||
name := ""
|
||||
seenTitle := false
|
||||
@@ -455,7 +445,7 @@ func (f *Filtering) parseFilterContents(file io.Reader) (int, uint32, string) {
|
||||
if len(line) == 0 {
|
||||
//
|
||||
} else if line[0] == '!' {
|
||||
m := f.filterTitleRegexp.FindAllStringSubmatch(line, -1)
|
||||
m := d.filterTitleRegexp.FindAllStringSubmatch(line, -1)
|
||||
if len(m) > 0 && len(m[0]) >= 2 && !seenTitle {
|
||||
name = m[0][1]
|
||||
seenTitle = true
|
||||
@@ -476,11 +466,11 @@ func (f *Filtering) parseFilterContents(file io.Reader) (int, uint32, string) {
|
||||
}
|
||||
|
||||
// Perform upgrade on a filter and update LastUpdated value
|
||||
func (f *Filtering) update(filter *filter) (bool, error) {
|
||||
b, err := f.updateIntl(filter)
|
||||
func (d *DNSFilter) update(filter *FilterYAML) (bool, error) {
|
||||
b, err := d.updateIntl(filter)
|
||||
filter.LastUpdated = time.Now()
|
||||
if !b {
|
||||
e := os.Chtimes(filter.Path(), filter.LastUpdated, filter.LastUpdated)
|
||||
e := os.Chtimes(filter.Path(d.DataDir), filter.LastUpdated, filter.LastUpdated)
|
||||
if e != nil {
|
||||
log.Error("os.Chtimes(): %v", e)
|
||||
}
|
||||
@@ -488,7 +478,7 @@ func (f *Filtering) update(filter *filter) (bool, error) {
|
||||
return b, err
|
||||
}
|
||||
|
||||
func (f *Filtering) read(reader io.Reader, tmpFile *os.File, filter *filter) (int, error) {
|
||||
func (d *DNSFilter) read(reader io.Reader, tmpFile *os.File, filter *FilterYAML) (int, error) {
|
||||
htmlTest := true
|
||||
firstChunk := make([]byte, 4*1024)
|
||||
firstChunkLen := 0
|
||||
@@ -539,20 +529,20 @@ func (f *Filtering) read(reader io.Reader, tmpFile *os.File, filter *filter) (in
|
||||
// finalizeUpdate closes and gets rid of temporary file f with filter's content
|
||||
// according to updated. It also saves new values of flt's name, rules number
|
||||
// and checksum if sucсeeded.
|
||||
func finalizeUpdate(
|
||||
f *os.File,
|
||||
flt *filter,
|
||||
func (d *DNSFilter) finalizeUpdate(
|
||||
file *os.File,
|
||||
flt *FilterYAML,
|
||||
updated bool,
|
||||
name string,
|
||||
rnum int,
|
||||
cs uint32,
|
||||
) (err error) {
|
||||
tmpFileName := f.Name()
|
||||
tmpFileName := file.Name()
|
||||
|
||||
// Close the file before renaming it because it's required on Windows.
|
||||
//
|
||||
// See https://github.com/adguardTeam/adGuardHome/issues/1553.
|
||||
if err = f.Close(); err != nil {
|
||||
if err = file.Close(); err != nil {
|
||||
return fmt.Errorf("closing temporary file: %w", err)
|
||||
}
|
||||
|
||||
@@ -562,9 +552,9 @@ func finalizeUpdate(
|
||||
return os.Remove(tmpFileName)
|
||||
}
|
||||
|
||||
log.Printf("saving filter %d contents to: %s", flt.ID, flt.Path())
|
||||
log.Printf("saving filter %d contents to: %s", flt.ID, flt.Path(d.DataDir))
|
||||
|
||||
if err = os.Rename(tmpFileName, flt.Path()); err != nil {
|
||||
if err = os.Rename(tmpFileName, flt.Path(d.DataDir)); err != nil {
|
||||
return errors.WithDeferred(err, os.Remove(tmpFileName))
|
||||
}
|
||||
|
||||
@@ -578,12 +568,12 @@ func finalizeUpdate(
|
||||
// processUpdate copies filter's content from src to dst and returns the name,
|
||||
// rules number, and checksum for it. It also returns the number of bytes read
|
||||
// from src.
|
||||
func (f *Filtering) processUpdate(
|
||||
func (d *DNSFilter) processUpdate(
|
||||
src io.Reader,
|
||||
dst *os.File,
|
||||
flt *filter,
|
||||
flt *FilterYAML,
|
||||
) (name string, rnum int, cs uint32, n int, err error) {
|
||||
if n, err = f.read(src, dst, flt); err != nil {
|
||||
if n, err = d.read(src, dst, flt); err != nil {
|
||||
return "", 0, 0, 0, err
|
||||
}
|
||||
|
||||
@@ -591,14 +581,14 @@ func (f *Filtering) processUpdate(
|
||||
return "", 0, 0, 0, err
|
||||
}
|
||||
|
||||
rnum, cs, name = f.parseFilterContents(dst)
|
||||
rnum, cs, name = d.parseFilterContents(dst)
|
||||
|
||||
return name, rnum, cs, n, nil
|
||||
}
|
||||
|
||||
// updateIntl updates the flt rewriting it's actual file. It returns true if
|
||||
// the actual update has been performed.
|
||||
func (f *Filtering) updateIntl(flt *filter) (ok bool, err error) {
|
||||
func (d *DNSFilter) updateIntl(flt *FilterYAML) (ok bool, err error) {
|
||||
log.Tracef("downloading update for filter %d from %s", flt.ID, flt.URL)
|
||||
|
||||
var name string
|
||||
@@ -606,12 +596,12 @@ func (f *Filtering) updateIntl(flt *filter) (ok bool, err error) {
|
||||
var cs uint32
|
||||
|
||||
var tmpFile *os.File
|
||||
tmpFile, err = os.CreateTemp(filepath.Join(Context.getDataDir(), filterDir), "")
|
||||
tmpFile, err = os.CreateTemp(filepath.Join(d.DataDir, filterDir), "")
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer func() {
|
||||
err = errors.WithDeferred(err, finalizeUpdate(tmpFile, flt, ok, name, rnum, cs))
|
||||
err = errors.WithDeferred(err, d.finalizeUpdate(tmpFile, flt, ok, name, rnum, cs))
|
||||
ok = ok && err == nil
|
||||
if ok {
|
||||
log.Printf("updated filter %d: %d bytes, %d rules", flt.ID, n, rnum)
|
||||
@@ -638,7 +628,7 @@ func (f *Filtering) updateIntl(flt *filter) (ok bool, err error) {
|
||||
r = file
|
||||
} else {
|
||||
var resp *http.Response
|
||||
resp, err = Context.client.Get(flt.URL)
|
||||
resp, err = d.HTTPClient.Get(flt.URL)
|
||||
if err != nil {
|
||||
log.Printf("requesting filter from %s, skip: %s", flt.URL, err)
|
||||
|
||||
@@ -655,16 +645,16 @@ func (f *Filtering) updateIntl(flt *filter) (ok bool, err error) {
|
||||
r = resp.Body
|
||||
}
|
||||
|
||||
name, rnum, cs, n, err = f.processUpdate(r, tmpFile, flt)
|
||||
name, rnum, cs, n, err = d.processUpdate(r, tmpFile, flt)
|
||||
|
||||
return cs != flt.checksum, err
|
||||
}
|
||||
|
||||
// loads filter contents from the file in dataDir
|
||||
func (f *Filtering) load(filter *filter) (err error) {
|
||||
filterFilePath := filter.Path()
|
||||
func (d *DNSFilter) load(filter *FilterYAML) (err error) {
|
||||
filterFilePath := filter.Path(d.DataDir)
|
||||
|
||||
log.Tracef("filtering: loading filter %d contents to: %s", filter.ID, filterFilePath)
|
||||
log.Tracef("filtering: loading filter %d from %s", filter.ID, filterFilePath)
|
||||
|
||||
file, err := os.Open(filterFilePath)
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
@@ -682,7 +672,7 @@ func (f *Filtering) load(filter *filter) (err error) {
|
||||
|
||||
log.Tracef("filtering: File %s, id %d, length %d", filterFilePath, filter.ID, st.Size())
|
||||
|
||||
rulesCount, checksum, _ := f.parseFilterContents(file)
|
||||
rulesCount, checksum, _ := d.parseFilterContents(file)
|
||||
|
||||
filter.RulesCount = rulesCount
|
||||
filter.checksum = checksum
|
||||
@@ -691,56 +681,45 @@ func (f *Filtering) load(filter *filter) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Clear filter rules
|
||||
func (filter *filter) unload() {
|
||||
filter.RulesCount = 0
|
||||
filter.checksum = 0
|
||||
func (d *DNSFilter) EnableFilters(async bool) {
|
||||
d.filtersMu.RLock()
|
||||
defer d.filtersMu.RUnlock()
|
||||
|
||||
d.enableFiltersLocked(async)
|
||||
}
|
||||
|
||||
// Path to the filter contents
|
||||
func (filter *filter) Path() string {
|
||||
return filepath.Join(Context.getDataDir(), filterDir, strconv.FormatInt(filter.ID, 10)+".txt")
|
||||
}
|
||||
|
||||
func enableFilters(async bool) {
|
||||
config.RLock()
|
||||
defer config.RUnlock()
|
||||
|
||||
enableFiltersLocked(async)
|
||||
}
|
||||
|
||||
func enableFiltersLocked(async bool) {
|
||||
filters := []filtering.Filter{{
|
||||
ID: filtering.CustomListID,
|
||||
Data: []byte(strings.Join(config.UserRules, "\n")),
|
||||
func (d *DNSFilter) enableFiltersLocked(async bool) {
|
||||
filters := []Filter{{
|
||||
ID: CustomListID,
|
||||
Data: []byte(strings.Join(d.UserRules, "\n")),
|
||||
}}
|
||||
|
||||
for _, filter := range config.Filters {
|
||||
for _, filter := range d.Filters {
|
||||
if !filter.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
filters = append(filters, filtering.Filter{
|
||||
filters = append(filters, Filter{
|
||||
ID: filter.ID,
|
||||
FilePath: filter.Path(),
|
||||
FilePath: filter.Path(d.DataDir),
|
||||
})
|
||||
}
|
||||
|
||||
var allowFilters []filtering.Filter
|
||||
for _, filter := range config.WhitelistFilters {
|
||||
var allowFilters []Filter
|
||||
for _, filter := range d.WhitelistFilters {
|
||||
if !filter.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
allowFilters = append(allowFilters, filtering.Filter{
|
||||
allowFilters = append(allowFilters, Filter{
|
||||
ID: filter.ID,
|
||||
FilePath: filter.Path(),
|
||||
FilePath: filter.Path(d.DataDir),
|
||||
})
|
||||
}
|
||||
|
||||
if err := Context.dnsFilter.SetFilters(filters, allowFilters, async); err != nil {
|
||||
if err := d.SetFilters(filters, allowFilters, async); err != nil {
|
||||
log.Debug("enabling filters: %s", err)
|
||||
}
|
||||
|
||||
Context.dnsFilter.SetEnabled(config.DNS.FilteringEnabled)
|
||||
d.SetEnabled(d.FilteringEnabled)
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package home
|
||||
package filtering
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
@@ -51,15 +51,17 @@ func TestFilters(t *testing.T) {
|
||||
|
||||
l := testStartFilterListener(t, &fltContent)
|
||||
|
||||
Context = homeContext{
|
||||
workDir: t.TempDir(),
|
||||
client: &http.Client{
|
||||
tempDir := t.TempDir()
|
||||
|
||||
filters, err := New(&Config{
|
||||
DataDir: tempDir,
|
||||
HTTPClient: &http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
},
|
||||
}
|
||||
Context.filters.Init()
|
||||
}, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
f := &filter{
|
||||
f := &FilterYAML{
|
||||
URL: (&url.URL{
|
||||
Scheme: "http",
|
||||
Host: (&netutil.IPPort{
|
||||
@@ -71,21 +73,22 @@ func TestFilters(t *testing.T) {
|
||||
}
|
||||
|
||||
updateAndAssert := func(t *testing.T, want require.BoolAssertionFunc, wantRulesCount int) {
|
||||
ok, err := Context.filters.update(f)
|
||||
var ok bool
|
||||
ok, err = filters.update(f)
|
||||
require.NoError(t, err)
|
||||
want(t, ok)
|
||||
|
||||
assert.Equal(t, wantRulesCount, f.RulesCount)
|
||||
|
||||
var dir []fs.DirEntry
|
||||
dir, err = os.ReadDir(filepath.Join(Context.getDataDir(), filterDir))
|
||||
dir, err = os.ReadDir(filepath.Join(tempDir, filterDir))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Len(t, dir, 1)
|
||||
|
||||
require.FileExists(t, f.Path())
|
||||
require.FileExists(t, f.Path(tempDir))
|
||||
|
||||
err = Context.filters.load(f)
|
||||
err = filters.load(f)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
@@ -105,11 +108,9 @@ func TestFilters(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("load_unload", func(t *testing.T) {
|
||||
err := Context.filters.load(f)
|
||||
err = filters.load(f)
|
||||
require.NoError(t, err)
|
||||
|
||||
f.unload()
|
||||
})
|
||||
|
||||
require.NoError(t, os.Remove(f.Path()))
|
||||
}
|
||||
@@ -6,7 +6,10 @@ import (
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
@@ -24,6 +27,7 @@ import (
|
||||
"github.com/AdguardTeam/urlfilter/filterlist"
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
// The IDs of built-in filter lists.
|
||||
@@ -69,8 +73,13 @@ type Config struct {
|
||||
// enabled is used to be returned within Settings.
|
||||
//
|
||||
// It is of type uint32 to be accessed by atomic.
|
||||
//
|
||||
// TODO(e.burkov): Use atomic.Bool in Go 1.19.
|
||||
enabled uint32
|
||||
|
||||
FilteringEnabled bool `yaml:"filtering_enabled"` // whether or not use filter lists
|
||||
FiltersUpdateIntervalHours uint32 `yaml:"filters_update_interval"` // time period to update filters (in hours)
|
||||
|
||||
ParentalEnabled bool `yaml:"parental_enabled"`
|
||||
SafeSearchEnabled bool `yaml:"safesearch_enabled"`
|
||||
SafeBrowsingEnabled bool `yaml:"safebrowsing_enabled"`
|
||||
@@ -98,6 +107,24 @@ type Config struct {
|
||||
|
||||
// CustomResolver is the resolver used by DNSFilter.
|
||||
CustomResolver Resolver `yaml:"-"`
|
||||
|
||||
// HTTPClient is the client to use for updating the remote filters.
|
||||
HTTPClient *http.Client `yaml:"-"`
|
||||
|
||||
// DataDir is used to store filters' contents.
|
||||
DataDir string `yaml:"-"`
|
||||
|
||||
// filtersMu protects filter lists.
|
||||
filtersMu *sync.RWMutex
|
||||
|
||||
// Filters are the blocking filter lists.
|
||||
Filters []FilterYAML `yaml:"-"`
|
||||
|
||||
// WhitelistFilters are the allowing filter lists.
|
||||
WhitelistFilters []FilterYAML `yaml:"-"`
|
||||
|
||||
// UserRules is the global list of custom rules.
|
||||
UserRules []string `yaml:"-"`
|
||||
}
|
||||
|
||||
// LookupStats store stats collected during safebrowsing or parental checks
|
||||
@@ -128,11 +155,13 @@ type hostChecker struct {
|
||||
|
||||
// DNSFilter matches hostnames and DNS requests against filtering rules.
|
||||
type DNSFilter struct {
|
||||
rulesStorage *filterlist.RuleStorage
|
||||
filteringEngine *urlfilter.DNSEngine
|
||||
rulesStorage *filterlist.RuleStorage
|
||||
filteringEngine *urlfilter.DNSEngine
|
||||
|
||||
rulesStorageAllow *filterlist.RuleStorage
|
||||
filteringEngineAllow *urlfilter.DNSEngine
|
||||
engineLock sync.RWMutex
|
||||
|
||||
engineLock sync.RWMutex
|
||||
|
||||
parentalServer string // access via methods
|
||||
safeBrowsingServer string // access via methods
|
||||
@@ -156,6 +185,12 @@ type DNSFilter struct {
|
||||
// TODO(e.burkov): Use upstream that configured in dnsforward instead.
|
||||
resolver Resolver
|
||||
|
||||
refreshLock *sync.Mutex
|
||||
|
||||
// filterTitleRegexp is the regular expression to retrieve a name of a
|
||||
// filter list.
|
||||
filterTitleRegexp *regexp.Regexp
|
||||
|
||||
hostCheckers []hostChecker
|
||||
}
|
||||
|
||||
@@ -168,7 +203,7 @@ type Filter struct {
|
||||
Data []byte `yaml:"-"`
|
||||
|
||||
// ID is automatically assigned when filter is added using nextFilterID.
|
||||
ID int64
|
||||
ID int64 `yaml:"id"`
|
||||
}
|
||||
|
||||
// Reason holds an enum detailing why it was filtered or not filtered
|
||||
@@ -245,15 +280,7 @@ func (r Reason) String() string {
|
||||
}
|
||||
|
||||
// In returns true if reasons include r.
|
||||
func (r Reason) In(reasons ...Reason) (ok bool) {
|
||||
for _, reason := range reasons {
|
||||
if r == reason {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
func (r Reason) In(reasons ...Reason) (ok bool) { return slices.Contains(reasons, r) }
|
||||
|
||||
// SetEnabled sets the status of the *DNSFilter.
|
||||
func (d *DNSFilter) SetEnabled(enabled bool) {
|
||||
@@ -261,6 +288,7 @@ func (d *DNSFilter) SetEnabled(enabled bool) {
|
||||
if enabled {
|
||||
i = 1
|
||||
}
|
||||
|
||||
atomic.StoreUint32(&d.enabled, uint32(i))
|
||||
}
|
||||
|
||||
@@ -279,11 +307,20 @@ func (d *DNSFilter) GetConfig() (s Settings) {
|
||||
|
||||
// WriteDiskConfig - write configuration
|
||||
func (d *DNSFilter) WriteDiskConfig(c *Config) {
|
||||
d.confLock.Lock()
|
||||
defer d.confLock.Unlock()
|
||||
func() {
|
||||
d.confLock.Lock()
|
||||
defer d.confLock.Unlock()
|
||||
|
||||
*c = d.Config
|
||||
c.Rewrites = cloneRewrites(c.Rewrites)
|
||||
*c = d.Config
|
||||
c.Rewrites = cloneRewrites(c.Rewrites)
|
||||
}()
|
||||
|
||||
d.filtersMu.RLock()
|
||||
defer d.filtersMu.RUnlock()
|
||||
|
||||
c.Filters = slices.Clone(d.Filters)
|
||||
c.WhitelistFilters = slices.Clone(d.WhitelistFilters)
|
||||
c.UserRules = slices.Clone(d.UserRules)
|
||||
}
|
||||
|
||||
// cloneRewrites returns a deep copy of entries.
|
||||
@@ -309,6 +346,8 @@ func (d *DNSFilter) SetFilters(blockFilters, allowFilters []Filter, async bool)
|
||||
}
|
||||
|
||||
d.filtersInitializerLock.Lock() // prevent multiple writers from adding more than 1 task
|
||||
defer d.filtersInitializerLock.Unlock()
|
||||
|
||||
// remove all pending tasks
|
||||
stop := false
|
||||
for !stop {
|
||||
@@ -321,7 +360,6 @@ func (d *DNSFilter) SetFilters(blockFilters, allowFilters []Filter, async bool)
|
||||
}
|
||||
|
||||
d.filtersInitializerChan <- params
|
||||
d.filtersInitializerLock.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -350,22 +388,19 @@ func (d *DNSFilter) filtersInitializer() {
|
||||
func (d *DNSFilter) Close() {
|
||||
d.engineLock.Lock()
|
||||
defer d.engineLock.Unlock()
|
||||
|
||||
d.reset()
|
||||
}
|
||||
|
||||
func (d *DNSFilter) reset() {
|
||||
var err error
|
||||
|
||||
if d.rulesStorage != nil {
|
||||
err = d.rulesStorage.Close()
|
||||
if err != nil {
|
||||
if err := d.rulesStorage.Close(); err != nil {
|
||||
log.Error("filtering: rulesStorage.Close: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
if d.rulesStorageAllow != nil {
|
||||
err = d.rulesStorageAllow.Close()
|
||||
if err != nil {
|
||||
if err := d.rulesStorageAllow.Close(); err != nil {
|
||||
log.Error("filtering: rulesStorageAllow.Close: %s", err)
|
||||
}
|
||||
}
|
||||
@@ -885,29 +920,30 @@ func InitModule() {
|
||||
initBlockedServices()
|
||||
}
|
||||
|
||||
// New creates properly initialized DNS Filter that is ready to be used.
|
||||
func New(c *Config, blockFilters []Filter) (d *DNSFilter) {
|
||||
// New creates properly initialized DNS Filter that is ready to be used. c must
|
||||
// be non-nil.
|
||||
func New(c *Config, blockFilters []Filter) (d *DNSFilter, err error) {
|
||||
d = &DNSFilter{
|
||||
resolver: net.DefaultResolver,
|
||||
resolver: net.DefaultResolver,
|
||||
refreshLock: &sync.Mutex{},
|
||||
filterTitleRegexp: regexp.MustCompile(`^! Title: +(.*)$`),
|
||||
}
|
||||
if c != nil {
|
||||
|
||||
d.safebrowsingCache = cache.New(cache.Config{
|
||||
EnableLRU: true,
|
||||
MaxSize: c.SafeBrowsingCacheSize,
|
||||
})
|
||||
d.safeSearchCache = cache.New(cache.Config{
|
||||
EnableLRU: true,
|
||||
MaxSize: c.SafeSearchCacheSize,
|
||||
})
|
||||
d.parentalCache = cache.New(cache.Config{
|
||||
EnableLRU: true,
|
||||
MaxSize: c.ParentalCacheSize,
|
||||
})
|
||||
d.safebrowsingCache = cache.New(cache.Config{
|
||||
EnableLRU: true,
|
||||
MaxSize: c.SafeBrowsingCacheSize,
|
||||
})
|
||||
d.safeSearchCache = cache.New(cache.Config{
|
||||
EnableLRU: true,
|
||||
MaxSize: c.SafeSearchCacheSize,
|
||||
})
|
||||
d.parentalCache = cache.New(cache.Config{
|
||||
EnableLRU: true,
|
||||
MaxSize: c.ParentalCacheSize,
|
||||
})
|
||||
|
||||
if c.CustomResolver != nil {
|
||||
d.resolver = c.CustomResolver
|
||||
}
|
||||
if r := c.CustomResolver; r != nil {
|
||||
d.resolver = r
|
||||
}
|
||||
|
||||
d.hostCheckers = []hostChecker{{
|
||||
@@ -930,27 +966,26 @@ func New(c *Config, blockFilters []Filter) (d *DNSFilter) {
|
||||
name: "safe search",
|
||||
}}
|
||||
|
||||
err := d.initSecurityServices()
|
||||
if err != nil {
|
||||
log.Error("filtering: initialize services: %s", err)
|
||||
defer func() { err = errors.Annotate(err, "filtering: %w") }()
|
||||
|
||||
return nil
|
||||
err = d.initSecurityServices()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("initializing services: %s", err)
|
||||
}
|
||||
|
||||
if c != nil {
|
||||
d.Config = *c
|
||||
err = d.prepareRewrites()
|
||||
if err != nil {
|
||||
log.Error("rewrites: preparing: %s", err)
|
||||
d.Config = *c
|
||||
d.filtersMu = &sync.RWMutex{}
|
||||
|
||||
return nil
|
||||
}
|
||||
err = d.prepareRewrites()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("rewrites: preparing: %s", err)
|
||||
}
|
||||
|
||||
bsvcs := []string{}
|
||||
for _, s := range d.BlockedServices {
|
||||
if !BlockedSvcKnown(s) {
|
||||
log.Debug("skipping unknown blocked-service %q", s)
|
||||
|
||||
continue
|
||||
}
|
||||
bsvcs = append(bsvcs, s)
|
||||
@@ -960,13 +995,24 @@ func New(c *Config, blockFilters []Filter) (d *DNSFilter) {
|
||||
if blockFilters != nil {
|
||||
err = d.initFiltering(nil, blockFilters)
|
||||
if err != nil {
|
||||
log.Error("Can't initialize filtering subsystem: %s", err)
|
||||
d.Close()
|
||||
return nil
|
||||
|
||||
return nil, fmt.Errorf("initializing filtering subsystem: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
return d
|
||||
_ = os.MkdirAll(filepath.Join(d.DataDir, filterDir), 0o755)
|
||||
|
||||
d.loadFilters(d.Filters)
|
||||
d.loadFilters(d.WhitelistFilters)
|
||||
|
||||
d.Filters = deduplicateFilters(d.Filters)
|
||||
d.WhitelistFilters = deduplicateFilters(d.WhitelistFilters)
|
||||
|
||||
updateUniqueFilterID(d.Filters)
|
||||
updateUniqueFilterID(d.WhitelistFilters)
|
||||
|
||||
return d, nil
|
||||
}
|
||||
|
||||
// Start - start the module:
|
||||
@@ -976,9 +1022,10 @@ func (d *DNSFilter) Start() {
|
||||
d.filtersInitializerChan = make(chan filtersInitializerParams, 1)
|
||||
go d.filtersInitializer()
|
||||
|
||||
if d.Config.HTTPRegister != nil { // for tests
|
||||
d.registerSecurityHandlers()
|
||||
d.registerRewritesHandlers()
|
||||
d.registerBlockedServicesHandlers()
|
||||
}
|
||||
d.RegisterFilteringHandlers()
|
||||
|
||||
// Here we should start updating filters,
|
||||
// but currently we can't wake up the periodic task to do so.
|
||||
// So for now we just start this periodic task from here.
|
||||
go d.periodicallyRefreshFilters()
|
||||
}
|
||||
|
||||
@@ -26,10 +26,6 @@ const (
|
||||
pcBlocked = "pornhub.com"
|
||||
)
|
||||
|
||||
var setts = Settings{
|
||||
ProtectionEnabled: true,
|
||||
}
|
||||
|
||||
// Helpers.
|
||||
|
||||
func purgeCaches(d *DNSFilter) {
|
||||
@@ -44,8 +40,8 @@ func purgeCaches(d *DNSFilter) {
|
||||
}
|
||||
}
|
||||
|
||||
func newForTest(t testing.TB, c *Config, filters []Filter) *DNSFilter {
|
||||
setts = Settings{
|
||||
func newForTest(t testing.TB, c *Config, filters []Filter) (f *DNSFilter, setts *Settings) {
|
||||
setts = &Settings{
|
||||
ProtectionEnabled: true,
|
||||
FilteringEnabled: true,
|
||||
}
|
||||
@@ -57,26 +53,31 @@ func newForTest(t testing.TB, c *Config, filters []Filter) *DNSFilter {
|
||||
setts.SafeSearchEnabled = c.SafeSearchEnabled
|
||||
setts.SafeBrowsingEnabled = c.SafeBrowsingEnabled
|
||||
setts.ParentalEnabled = c.ParentalEnabled
|
||||
} else {
|
||||
// It must not be nil.
|
||||
c = &Config{}
|
||||
}
|
||||
d := New(c, filters)
|
||||
purgeCaches(d)
|
||||
f, err := New(c, filters)
|
||||
require.NoError(t, err)
|
||||
|
||||
return d
|
||||
purgeCaches(f)
|
||||
|
||||
return f, setts
|
||||
}
|
||||
|
||||
func (d *DNSFilter) checkMatch(t *testing.T, hostname string) {
|
||||
func (d *DNSFilter) checkMatch(t *testing.T, hostname string, setts *Settings) {
|
||||
t.Helper()
|
||||
|
||||
res, err := d.CheckHost(hostname, dns.TypeA, &setts)
|
||||
res, err := d.CheckHost(hostname, dns.TypeA, setts)
|
||||
require.NoErrorf(t, err, "host %q", hostname)
|
||||
|
||||
assert.Truef(t, res.IsFiltered, "host %q", hostname)
|
||||
}
|
||||
|
||||
func (d *DNSFilter) checkMatchIP(t *testing.T, hostname, ip string, qtype uint16) {
|
||||
func (d *DNSFilter) checkMatchIP(t *testing.T, hostname, ip string, qtype uint16, setts *Settings) {
|
||||
t.Helper()
|
||||
|
||||
res, err := d.CheckHost(hostname, qtype, &setts)
|
||||
res, err := d.CheckHost(hostname, qtype, setts)
|
||||
require.NoErrorf(t, err, "host %q", hostname, err)
|
||||
require.NotEmpty(t, res.Rules, "host %q", hostname)
|
||||
|
||||
@@ -88,10 +89,10 @@ func (d *DNSFilter) checkMatchIP(t *testing.T, hostname, ip string, qtype uint16
|
||||
assert.Equalf(t, ip, r.IP.String(), "host %q", hostname)
|
||||
}
|
||||
|
||||
func (d *DNSFilter) checkMatchEmpty(t *testing.T, hostname string) {
|
||||
func (d *DNSFilter) checkMatchEmpty(t *testing.T, hostname string, setts *Settings) {
|
||||
t.Helper()
|
||||
|
||||
res, err := d.CheckHost(hostname, dns.TypeA, &setts)
|
||||
res, err := d.CheckHost(hostname, dns.TypeA, setts)
|
||||
require.NoErrorf(t, err, "host %q", hostname)
|
||||
|
||||
assert.Falsef(t, res.IsFiltered, "host %q", hostname)
|
||||
@@ -111,19 +112,19 @@ func TestEtcHostsMatching(t *testing.T) {
|
||||
filters := []Filter{{
|
||||
ID: 0, Data: []byte(text),
|
||||
}}
|
||||
d := newForTest(t, nil, filters)
|
||||
d, setts := newForTest(t, nil, filters)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
d.checkMatchIP(t, "google.com", addr, dns.TypeA)
|
||||
d.checkMatchIP(t, "www.google.com", addr, dns.TypeA)
|
||||
d.checkMatchEmpty(t, "subdomain.google.com")
|
||||
d.checkMatchEmpty(t, "example.org")
|
||||
d.checkMatchIP(t, "google.com", addr, dns.TypeA, setts)
|
||||
d.checkMatchIP(t, "www.google.com", addr, dns.TypeA, setts)
|
||||
d.checkMatchEmpty(t, "subdomain.google.com", setts)
|
||||
d.checkMatchEmpty(t, "example.org", setts)
|
||||
|
||||
// IPv4 match.
|
||||
d.checkMatchIP(t, "block.com", "0.0.0.0", dns.TypeA)
|
||||
d.checkMatchIP(t, "block.com", "0.0.0.0", dns.TypeA, setts)
|
||||
|
||||
// Empty IPv6.
|
||||
res, err := d.CheckHost("block.com", dns.TypeAAAA, &setts)
|
||||
res, err := d.CheckHost("block.com", dns.TypeAAAA, setts)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, res.IsFiltered)
|
||||
@@ -134,10 +135,10 @@ func TestEtcHostsMatching(t *testing.T) {
|
||||
assert.Empty(t, res.Rules[0].IP)
|
||||
|
||||
// IPv6 match.
|
||||
d.checkMatchIP(t, "ipv6.com", addr6, dns.TypeAAAA)
|
||||
d.checkMatchIP(t, "ipv6.com", addr6, dns.TypeAAAA, setts)
|
||||
|
||||
// Empty IPv4.
|
||||
res, err = d.CheckHost("ipv6.com", dns.TypeA, &setts)
|
||||
res, err = d.CheckHost("ipv6.com", dns.TypeA, setts)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, res.IsFiltered)
|
||||
@@ -148,7 +149,7 @@ func TestEtcHostsMatching(t *testing.T) {
|
||||
assert.Empty(t, res.Rules[0].IP)
|
||||
|
||||
// Two IPv4, both must be returned.
|
||||
res, err = d.CheckHost("host2", dns.TypeA, &setts)
|
||||
res, err = d.CheckHost("host2", dns.TypeA, setts)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, res.IsFiltered)
|
||||
@@ -159,7 +160,7 @@ func TestEtcHostsMatching(t *testing.T) {
|
||||
assert.Equal(t, res.Rules[1].IP, net.IP{0, 0, 0, 2})
|
||||
|
||||
// One IPv6 address.
|
||||
res, err = d.CheckHost("host2", dns.TypeAAAA, &setts)
|
||||
res, err = d.CheckHost("host2", dns.TypeAAAA, setts)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, res.IsFiltered)
|
||||
@@ -176,27 +177,27 @@ func TestSafeBrowsing(t *testing.T) {
|
||||
aghtest.ReplaceLogWriter(t, logOutput)
|
||||
aghtest.ReplaceLogLevel(t, log.DEBUG)
|
||||
|
||||
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
|
||||
d, setts := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
|
||||
d.checkMatch(t, sbBlocked)
|
||||
d.checkMatch(t, sbBlocked, setts)
|
||||
|
||||
require.Contains(t, logOutput.String(), fmt.Sprintf("safebrowsing lookup for %q", sbBlocked))
|
||||
|
||||
d.checkMatch(t, "test."+sbBlocked)
|
||||
d.checkMatchEmpty(t, "yandex.ru")
|
||||
d.checkMatchEmpty(t, pcBlocked)
|
||||
d.checkMatch(t, "test."+sbBlocked, setts)
|
||||
d.checkMatchEmpty(t, "yandex.ru", setts)
|
||||
d.checkMatchEmpty(t, pcBlocked, setts)
|
||||
|
||||
// Cached result.
|
||||
d.safeBrowsingServer = "127.0.0.1"
|
||||
d.checkMatch(t, sbBlocked)
|
||||
d.checkMatchEmpty(t, pcBlocked)
|
||||
d.checkMatch(t, sbBlocked, setts)
|
||||
d.checkMatchEmpty(t, pcBlocked, setts)
|
||||
d.safeBrowsingServer = defaultSafebrowsingServer
|
||||
}
|
||||
|
||||
func TestParallelSB(t *testing.T) {
|
||||
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
|
||||
d, setts := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
|
||||
@@ -205,10 +206,10 @@ func TestParallelSB(t *testing.T) {
|
||||
for i := 0; i < 100; i++ {
|
||||
t.Run(fmt.Sprintf("aaa%d", i), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
d.checkMatch(t, sbBlocked)
|
||||
d.checkMatch(t, "test."+sbBlocked)
|
||||
d.checkMatchEmpty(t, "yandex.ru")
|
||||
d.checkMatchEmpty(t, pcBlocked)
|
||||
d.checkMatch(t, sbBlocked, setts)
|
||||
d.checkMatch(t, "test."+sbBlocked, setts)
|
||||
d.checkMatchEmpty(t, "yandex.ru", setts)
|
||||
d.checkMatchEmpty(t, pcBlocked, setts)
|
||||
})
|
||||
}
|
||||
})
|
||||
@@ -217,7 +218,7 @@ func TestParallelSB(t *testing.T) {
|
||||
// Safe Search.
|
||||
|
||||
func TestSafeSearch(t *testing.T) {
|
||||
d := newForTest(t, &Config{SafeSearchEnabled: true}, nil)
|
||||
d, _ := newForTest(t, &Config{SafeSearchEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
val, ok := d.SafeSearchDomain("www.google.com")
|
||||
require.True(t, ok)
|
||||
@@ -226,7 +227,7 @@ func TestSafeSearch(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCheckHostSafeSearchYandex(t *testing.T) {
|
||||
d := newForTest(t, &Config{
|
||||
d, setts := newForTest(t, &Config{
|
||||
SafeSearchEnabled: true,
|
||||
}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
@@ -243,7 +244,7 @@ func TestCheckHostSafeSearchYandex(t *testing.T) {
|
||||
"www.yandex.com",
|
||||
} {
|
||||
t.Run(strings.ToLower(host), func(t *testing.T) {
|
||||
res, err := d.CheckHost(host, dns.TypeA, &setts)
|
||||
res, err := d.CheckHost(host, dns.TypeA, setts)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, res.IsFiltered)
|
||||
@@ -258,7 +259,7 @@ func TestCheckHostSafeSearchYandex(t *testing.T) {
|
||||
|
||||
func TestCheckHostSafeSearchGoogle(t *testing.T) {
|
||||
resolver := &aghtest.TestResolver{}
|
||||
d := newForTest(t, &Config{
|
||||
d, setts := newForTest(t, &Config{
|
||||
SafeSearchEnabled: true,
|
||||
CustomResolver: resolver,
|
||||
}, nil)
|
||||
@@ -277,7 +278,7 @@ func TestCheckHostSafeSearchGoogle(t *testing.T) {
|
||||
"www.google.je",
|
||||
} {
|
||||
t.Run(host, func(t *testing.T) {
|
||||
res, err := d.CheckHost(host, dns.TypeA, &setts)
|
||||
res, err := d.CheckHost(host, dns.TypeA, setts)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, res.IsFiltered)
|
||||
@@ -291,12 +292,12 @@ func TestCheckHostSafeSearchGoogle(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSafeSearchCacheYandex(t *testing.T) {
|
||||
d := newForTest(t, nil, nil)
|
||||
d, setts := newForTest(t, nil, nil)
|
||||
t.Cleanup(d.Close)
|
||||
const domain = "yandex.ru"
|
||||
|
||||
// Check host with disabled safesearch.
|
||||
res, err := d.CheckHost(domain, dns.TypeA, &setts)
|
||||
res, err := d.CheckHost(domain, dns.TypeA, setts)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.False(t, res.IsFiltered)
|
||||
@@ -305,10 +306,10 @@ func TestSafeSearchCacheYandex(t *testing.T) {
|
||||
|
||||
yandexIP := net.IPv4(213, 180, 193, 56)
|
||||
|
||||
d = newForTest(t, &Config{SafeSearchEnabled: true}, nil)
|
||||
d, setts = newForTest(t, &Config{SafeSearchEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
res, err = d.CheckHost(domain, dns.TypeA, &setts)
|
||||
res, err = d.CheckHost(domain, dns.TypeA, setts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// For yandex we already know valid IP.
|
||||
@@ -325,20 +326,20 @@ func TestSafeSearchCacheYandex(t *testing.T) {
|
||||
|
||||
func TestSafeSearchCacheGoogle(t *testing.T) {
|
||||
resolver := &aghtest.TestResolver{}
|
||||
d := newForTest(t, &Config{
|
||||
d, setts := newForTest(t, &Config{
|
||||
CustomResolver: resolver,
|
||||
}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
const domain = "www.google.ru"
|
||||
res, err := d.CheckHost(domain, dns.TypeA, &setts)
|
||||
res, err := d.CheckHost(domain, dns.TypeA, setts)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.False(t, res.IsFiltered)
|
||||
|
||||
require.Empty(t, res.Rules)
|
||||
|
||||
d = newForTest(t, &Config{SafeSearchEnabled: true}, nil)
|
||||
d, setts = newForTest(t, &Config{SafeSearchEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
d.resolver = resolver
|
||||
|
||||
@@ -358,7 +359,7 @@ func TestSafeSearchCacheGoogle(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
res, err = d.CheckHost(domain, dns.TypeA, &setts)
|
||||
res, err = d.CheckHost(domain, dns.TypeA, setts)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, res.Rules, 1)
|
||||
|
||||
@@ -379,22 +380,22 @@ func TestParentalControl(t *testing.T) {
|
||||
aghtest.ReplaceLogWriter(t, logOutput)
|
||||
aghtest.ReplaceLogLevel(t, log.DEBUG)
|
||||
|
||||
d := newForTest(t, &Config{ParentalEnabled: true}, nil)
|
||||
d, setts := newForTest(t, &Config{ParentalEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
d.SetParentalUpstream(aghtest.NewBlockUpstream(pcBlocked, true))
|
||||
d.checkMatch(t, pcBlocked)
|
||||
d.checkMatch(t, pcBlocked, setts)
|
||||
require.Contains(t, logOutput.String(), fmt.Sprintf("parental lookup for %q", pcBlocked))
|
||||
|
||||
d.checkMatch(t, "www."+pcBlocked)
|
||||
d.checkMatchEmpty(t, "www.yandex.ru")
|
||||
d.checkMatchEmpty(t, "yandex.ru")
|
||||
d.checkMatchEmpty(t, "api.jquery.com")
|
||||
d.checkMatch(t, "www."+pcBlocked, setts)
|
||||
d.checkMatchEmpty(t, "www.yandex.ru", setts)
|
||||
d.checkMatchEmpty(t, "yandex.ru", setts)
|
||||
d.checkMatchEmpty(t, "api.jquery.com", setts)
|
||||
|
||||
// Test cached result.
|
||||
d.parentalServer = "127.0.0.1"
|
||||
d.checkMatch(t, pcBlocked)
|
||||
d.checkMatchEmpty(t, "yandex.ru")
|
||||
d.checkMatch(t, pcBlocked, setts)
|
||||
d.checkMatchEmpty(t, "yandex.ru", setts)
|
||||
}
|
||||
|
||||
// Filtering.
|
||||
@@ -679,10 +680,10 @@ func TestMatching(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
t.Run(fmt.Sprintf("%s-%s", tc.name, tc.host), func(t *testing.T) {
|
||||
filters := []Filter{{ID: 0, Data: []byte(tc.rules)}}
|
||||
d := newForTest(t, nil, filters)
|
||||
d, setts := newForTest(t, nil, filters)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
res, err := d.CheckHost(tc.host, tc.wantDNSType, &setts)
|
||||
res, err := d.CheckHost(tc.host, tc.wantDNSType, setts)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equalf(t, tc.wantIsFiltered, res.IsFiltered, "Hostname %s has wrong result (%v must be %v)", tc.host, res.IsFiltered, tc.wantIsFiltered)
|
||||
@@ -705,7 +706,7 @@ func TestWhitelist(t *testing.T) {
|
||||
whiteFilters := []Filter{{
|
||||
ID: 0, Data: []byte(whiteRules),
|
||||
}}
|
||||
d := newForTest(t, nil, filters)
|
||||
d, setts := newForTest(t, nil, filters)
|
||||
|
||||
err := d.SetFilters(filters, whiteFilters, false)
|
||||
require.NoError(t, err)
|
||||
@@ -713,7 +714,7 @@ func TestWhitelist(t *testing.T) {
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
// Matched by white filter.
|
||||
res, err := d.CheckHost("host1", dns.TypeA, &setts)
|
||||
res, err := d.CheckHost("host1", dns.TypeA, setts)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.False(t, res.IsFiltered)
|
||||
@@ -724,7 +725,7 @@ func TestWhitelist(t *testing.T) {
|
||||
assert.Equal(t, "||host1^", res.Rules[0].Text)
|
||||
|
||||
// Not matched by white filter, but matched by block filter.
|
||||
res, err = d.CheckHost("host2", dns.TypeA, &setts)
|
||||
res, err = d.CheckHost("host2", dns.TypeA, setts)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, res.IsFiltered)
|
||||
@@ -750,7 +751,7 @@ func applyClientSettings(setts *Settings) {
|
||||
}
|
||||
|
||||
func TestClientSettings(t *testing.T) {
|
||||
d := newForTest(t,
|
||||
d, setts := newForTest(t,
|
||||
&Config{
|
||||
ParentalEnabled: true,
|
||||
SafeBrowsingEnabled: false,
|
||||
@@ -796,7 +797,7 @@ func TestClientSettings(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
r, err := d.CheckHost(tc.host, dns.TypeA, &setts)
|
||||
r, err := d.CheckHost(tc.host, dns.TypeA, setts)
|
||||
require.NoError(t, err)
|
||||
|
||||
if before {
|
||||
@@ -814,7 +815,7 @@ func TestClientSettings(t *testing.T) {
|
||||
t.Run(tc.name, makeTester(tc, tc.before))
|
||||
}
|
||||
|
||||
applyClientSettings(&setts)
|
||||
applyClientSettings(setts)
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, makeTester(tc, !tc.before))
|
||||
@@ -824,13 +825,13 @@ func TestClientSettings(t *testing.T) {
|
||||
// Benchmarks.
|
||||
|
||||
func BenchmarkSafeBrowsing(b *testing.B) {
|
||||
d := newForTest(b, &Config{SafeBrowsingEnabled: true}, nil)
|
||||
d, setts := newForTest(b, &Config{SafeBrowsingEnabled: true}, nil)
|
||||
b.Cleanup(d.Close)
|
||||
|
||||
d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
res, err := d.CheckHost(sbBlocked, dns.TypeA, &setts)
|
||||
res, err := d.CheckHost(sbBlocked, dns.TypeA, setts)
|
||||
require.NoError(b, err)
|
||||
|
||||
assert.Truef(b, res.IsFiltered, "expected hostname %q to match", sbBlocked)
|
||||
@@ -838,14 +839,14 @@ func BenchmarkSafeBrowsing(b *testing.B) {
|
||||
}
|
||||
|
||||
func BenchmarkSafeBrowsingParallel(b *testing.B) {
|
||||
d := newForTest(b, &Config{SafeBrowsingEnabled: true}, nil)
|
||||
d, setts := newForTest(b, &Config{SafeBrowsingEnabled: true}, nil)
|
||||
b.Cleanup(d.Close)
|
||||
|
||||
d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
res, err := d.CheckHost(sbBlocked, dns.TypeA, &setts)
|
||||
res, err := d.CheckHost(sbBlocked, dns.TypeA, setts)
|
||||
require.NoError(b, err)
|
||||
|
||||
assert.Truef(b, res.IsFiltered, "expected hostname %q to match", sbBlocked)
|
||||
@@ -854,7 +855,7 @@ func BenchmarkSafeBrowsingParallel(b *testing.B) {
|
||||
}
|
||||
|
||||
func BenchmarkSafeSearch(b *testing.B) {
|
||||
d := newForTest(b, &Config{SafeSearchEnabled: true}, nil)
|
||||
d, _ := newForTest(b, &Config{SafeSearchEnabled: true}, nil)
|
||||
b.Cleanup(d.Close)
|
||||
for n := 0; n < b.N; n++ {
|
||||
val, ok := d.SafeSearchDomain("www.google.com")
|
||||
@@ -865,7 +866,7 @@ func BenchmarkSafeSearch(b *testing.B) {
|
||||
}
|
||||
|
||||
func BenchmarkSafeSearchParallel(b *testing.B) {
|
||||
d := newForTest(b, &Config{SafeSearchEnabled: true}, nil)
|
||||
d, _ := newForTest(b, &Config{SafeSearchEnabled: true}, nil)
|
||||
b.Cleanup(d.Close)
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
package home
|
||||
package filtering
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
@@ -34,7 +32,7 @@ func validateFilterURL(urlStr string) (err error) {
|
||||
return fmt.Errorf("checking filter url: %w", err)
|
||||
}
|
||||
|
||||
if s := url.Scheme; s != schemeHTTP && s != schemeHTTPS {
|
||||
if s := url.Scheme; s != aghhttp.SchemeHTTP && s != aghhttp.SchemeHTTPS {
|
||||
return fmt.Errorf("checking filter url: invalid scheme %q", s)
|
||||
}
|
||||
|
||||
@@ -47,7 +45,7 @@ type filterAddJSON struct {
|
||||
Whitelist bool `json:"whitelist"`
|
||||
}
|
||||
|
||||
func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request) {
|
||||
func (d *DNSFilter) handleFilteringAddURL(w http.ResponseWriter, r *http.Request) {
|
||||
fj := filterAddJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&fj)
|
||||
if err != nil {
|
||||
@@ -65,14 +63,14 @@ func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request
|
||||
}
|
||||
|
||||
// Check for duplicates
|
||||
if filterExists(fj.URL) {
|
||||
if d.filterExists(fj.URL) {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "Filter URL already added -- %s", fj.URL)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Set necessary properties
|
||||
filt := filter{
|
||||
filt := FilterYAML{
|
||||
Enabled: true,
|
||||
URL: fj.URL,
|
||||
Name: fj.Name,
|
||||
@@ -81,7 +79,7 @@ func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request
|
||||
filt.ID = assignUniqueFilterID()
|
||||
|
||||
// Download the filter contents
|
||||
ok, err := f.update(&filt)
|
||||
ok, err := d.update(&filt)
|
||||
if err != nil {
|
||||
aghhttp.Error(
|
||||
r,
|
||||
@@ -109,14 +107,14 @@ func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request
|
||||
|
||||
// URL is assumed valid so append it to filters, update config, write new
|
||||
// file and reload it to engines.
|
||||
if !filterAdd(filt) {
|
||||
if !d.filterAdd(filt) {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "Filter URL already added -- %s", filt.URL)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
onConfigModified()
|
||||
enableFilters(true)
|
||||
d.ConfigModified()
|
||||
d.EnableFilters(true)
|
||||
|
||||
_, err = fmt.Fprintf(w, "OK %d rules\n", filt.RulesCount)
|
||||
if err != nil {
|
||||
@@ -124,7 +122,7 @@ func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Filtering) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Request) {
|
||||
func (d *DNSFilter) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Request) {
|
||||
type request struct {
|
||||
URL string `json:"url"`
|
||||
Whitelist bool `json:"whitelist"`
|
||||
@@ -138,23 +136,23 @@ func (f *Filtering) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Requ
|
||||
return
|
||||
}
|
||||
|
||||
config.Lock()
|
||||
filters := &config.Filters
|
||||
d.filtersMu.Lock()
|
||||
filters := &d.Filters
|
||||
if req.Whitelist {
|
||||
filters = &config.WhitelistFilters
|
||||
filters = &d.WhitelistFilters
|
||||
}
|
||||
|
||||
var deleted filter
|
||||
var newFilters []filter
|
||||
for _, f := range *filters {
|
||||
if f.URL != req.URL {
|
||||
newFilters = append(newFilters, f)
|
||||
var deleted FilterYAML
|
||||
var newFilters []FilterYAML
|
||||
for _, flt := range *filters {
|
||||
if flt.URL != req.URL {
|
||||
newFilters = append(newFilters, flt)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
deleted = f
|
||||
path := f.Path()
|
||||
deleted = flt
|
||||
path := flt.Path(d.DataDir)
|
||||
err = os.Rename(path, path+".old")
|
||||
if err != nil {
|
||||
log.Error("deleting filter %q: %s", path, err)
|
||||
@@ -162,10 +160,10 @@ func (f *Filtering) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Requ
|
||||
}
|
||||
|
||||
*filters = newFilters
|
||||
config.Unlock()
|
||||
d.filtersMu.Unlock()
|
||||
|
||||
onConfigModified()
|
||||
enableFilters(true)
|
||||
d.ConfigModified()
|
||||
d.EnableFilters(true)
|
||||
|
||||
// NOTE: The old files "filter.txt.old" aren't deleted. It's not really
|
||||
// necessary, but will require the additional complicated code to run
|
||||
@@ -191,55 +189,51 @@ type filterURLReq struct {
|
||||
Whitelist bool `json:"whitelist"`
|
||||
}
|
||||
|
||||
func (f *Filtering) handleFilteringSetURL(w http.ResponseWriter, r *http.Request) {
|
||||
func (d *DNSFilter) handleFilteringSetURL(w http.ResponseWriter, r *http.Request) {
|
||||
fj := filterURLReq{}
|
||||
err := json.NewDecoder(r.Body).Decode(&fj)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "json decode: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "decoding request: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if fj.Data == nil {
|
||||
err = errors.Error("data cannot be null")
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", errors.Error("data is absent"))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = validateFilterURL(fj.Data.URL)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("invalid url: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "invalid url: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
filt := filter{
|
||||
filt := FilterYAML{
|
||||
Enabled: fj.Data.Enabled,
|
||||
Name: fj.Data.Name,
|
||||
URL: fj.Data.URL,
|
||||
}
|
||||
status := f.filterSetProperties(fj.URL, filt, fj.Whitelist)
|
||||
status := d.filterSetProperties(fj.URL, filt, fj.Whitelist)
|
||||
if (status & statusFound) == 0 {
|
||||
http.Error(w, "URL doesn't exist", http.StatusBadRequest)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "URL doesn't exist")
|
||||
|
||||
return
|
||||
}
|
||||
if (status & statusURLExists) != 0 {
|
||||
http.Error(w, "URL already exists", http.StatusBadRequest)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "URL already exists")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
onConfigModified()
|
||||
d.ConfigModified()
|
||||
|
||||
restart := (status & statusEnabledChanged) != 0
|
||||
if (status&statusUpdateRequired) != 0 && fj.Data.Enabled {
|
||||
// download new filter and apply its rules
|
||||
flags := filterRefreshBlocklists
|
||||
if fj.Whitelist {
|
||||
flags = filterRefreshAllowlists
|
||||
}
|
||||
nUpdated, _ := f.refreshFilters(flags, true)
|
||||
// download new filter and apply its rules.
|
||||
nUpdated := d.refreshFilters(!fj.Whitelist, fj.Whitelist, false)
|
||||
// if at least 1 filter has been updated, refreshFilters() restarts the filtering automatically
|
||||
// if not - we restart the filtering ourselves
|
||||
restart = false
|
||||
@@ -249,25 +243,34 @@ func (f *Filtering) handleFilteringSetURL(w http.ResponseWriter, r *http.Request
|
||||
}
|
||||
|
||||
if restart {
|
||||
enableFilters(true)
|
||||
d.EnableFilters(true)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Filtering) handleFilteringSetRules(w http.ResponseWriter, r *http.Request) {
|
||||
// This use of ReadAll is safe, because request's body is now limited.
|
||||
body, err := io.ReadAll(r.Body)
|
||||
// filteringRulesReq is the JSON structure for settings custom filtering rules.
|
||||
type filteringRulesReq struct {
|
||||
Rules []string `json:"rules"`
|
||||
}
|
||||
|
||||
func (d *DNSFilter) handleFilteringSetRules(w http.ResponseWriter, r *http.Request) {
|
||||
if aghhttp.WriteTextPlainDeprecated(w, r) {
|
||||
return
|
||||
}
|
||||
|
||||
req := &filteringRulesReq{}
|
||||
err := json.NewDecoder(r.Body).Decode(req)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "Failed to read request body: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "reading req: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
config.UserRules = strings.Split(string(body), "\n")
|
||||
onConfigModified()
|
||||
enableFilters(true)
|
||||
d.UserRules = req.Rules
|
||||
d.ConfigModified()
|
||||
d.EnableFilters(true)
|
||||
}
|
||||
|
||||
func (f *Filtering) handleFilteringRefresh(w http.ResponseWriter, r *http.Request) {
|
||||
func (d *DNSFilter) handleFilteringRefresh(w http.ResponseWriter, r *http.Request) {
|
||||
type Req struct {
|
||||
White bool `json:"whitelist"`
|
||||
}
|
||||
@@ -285,35 +288,27 @@ func (f *Filtering) handleFilteringRefresh(w http.ResponseWriter, r *http.Reques
|
||||
return
|
||||
}
|
||||
|
||||
flags := filterRefreshBlocklists
|
||||
if req.White {
|
||||
flags = filterRefreshAllowlists
|
||||
}
|
||||
func() {
|
||||
// Temporarily unlock the Context.controlLock because the
|
||||
// f.refreshFilters waits for it to be unlocked but it's
|
||||
// actually locked in ensure wrapper.
|
||||
//
|
||||
// TODO(e.burkov): Reconsider this messy syncing process.
|
||||
Context.controlLock.Unlock()
|
||||
defer Context.controlLock.Lock()
|
||||
|
||||
resp.Updated, err = f.refreshFilters(flags|filterRefreshForce, false)
|
||||
}()
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
|
||||
var ok bool
|
||||
resp.Updated, _, ok = d.tryRefreshFilters(!req.White, req.White, true)
|
||||
if !ok {
|
||||
aghhttp.Error(
|
||||
r,
|
||||
w,
|
||||
http.StatusInternalServerError,
|
||||
"filters update procedure is already running",
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
js, err := json.Marshal(resp)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
err = json.NewEncoder(w).Encode(resp)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "json encode: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write(js)
|
||||
}
|
||||
|
||||
type filterJSON struct {
|
||||
@@ -333,7 +328,7 @@ type filteringConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
func filterToJSON(f filter) filterJSON {
|
||||
func filterToJSON(f FilterYAML) filterJSON {
|
||||
fj := filterJSON{
|
||||
ID: f.ID,
|
||||
Enabled: f.Enabled,
|
||||
@@ -350,21 +345,21 @@ func filterToJSON(f filter) filterJSON {
|
||||
}
|
||||
|
||||
// Get filtering configuration
|
||||
func (f *Filtering) handleFilteringStatus(w http.ResponseWriter, r *http.Request) {
|
||||
func (d *DNSFilter) handleFilteringStatus(w http.ResponseWriter, r *http.Request) {
|
||||
resp := filteringConfig{}
|
||||
config.RLock()
|
||||
resp.Enabled = config.DNS.FilteringEnabled
|
||||
resp.Interval = config.DNS.FiltersUpdateIntervalHours
|
||||
for _, f := range config.Filters {
|
||||
d.filtersMu.RLock()
|
||||
resp.Enabled = d.FilteringEnabled
|
||||
resp.Interval = d.FiltersUpdateIntervalHours
|
||||
for _, f := range d.Filters {
|
||||
fj := filterToJSON(f)
|
||||
resp.Filters = append(resp.Filters, fj)
|
||||
}
|
||||
for _, f := range config.WhitelistFilters {
|
||||
for _, f := range d.WhitelistFilters {
|
||||
fj := filterToJSON(f)
|
||||
resp.WhitelistFilters = append(resp.WhitelistFilters, fj)
|
||||
}
|
||||
resp.UserRules = config.UserRules
|
||||
config.RUnlock()
|
||||
resp.UserRules = d.UserRules
|
||||
d.filtersMu.RUnlock()
|
||||
|
||||
jsonVal, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
@@ -380,7 +375,7 @@ func (f *Filtering) handleFilteringStatus(w http.ResponseWriter, r *http.Request
|
||||
}
|
||||
|
||||
// Set filtering configuration
|
||||
func (f *Filtering) handleFilteringConfig(w http.ResponseWriter, r *http.Request) {
|
||||
func (d *DNSFilter) handleFilteringConfig(w http.ResponseWriter, r *http.Request) {
|
||||
req := filteringConfig{}
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
@@ -389,22 +384,22 @@ func (f *Filtering) handleFilteringConfig(w http.ResponseWriter, r *http.Request
|
||||
return
|
||||
}
|
||||
|
||||
if !checkFiltersUpdateIntervalHours(req.Interval) {
|
||||
if !ValidateUpdateIvl(req.Interval) {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "Unsupported interval")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func() {
|
||||
config.Lock()
|
||||
defer config.Unlock()
|
||||
d.filtersMu.Lock()
|
||||
defer d.filtersMu.Unlock()
|
||||
|
||||
config.DNS.FilteringEnabled = req.Enabled
|
||||
config.DNS.FiltersUpdateIntervalHours = req.Interval
|
||||
d.FilteringEnabled = req.Enabled
|
||||
d.FiltersUpdateIntervalHours = req.Interval
|
||||
}()
|
||||
|
||||
onConfigModified()
|
||||
enableFilters(true)
|
||||
d.ConfigModified()
|
||||
d.EnableFilters(true)
|
||||
}
|
||||
|
||||
type checkHostRespRule struct {
|
||||
@@ -435,15 +430,15 @@ type checkHostResp struct {
|
||||
FilterID int64 `json:"filter_id"`
|
||||
}
|
||||
|
||||
func (f *Filtering) handleCheckHost(w http.ResponseWriter, r *http.Request) {
|
||||
q := r.URL.Query()
|
||||
host := q.Get("name")
|
||||
func (d *DNSFilter) handleCheckHost(w http.ResponseWriter, r *http.Request) {
|
||||
host := r.URL.Query().Get("name")
|
||||
|
||||
setts := Context.dnsFilter.GetConfig()
|
||||
setts := d.GetConfig()
|
||||
setts.FilteringEnabled = true
|
||||
setts.ProtectionEnabled = true
|
||||
Context.dnsFilter.ApplyBlockedServices(&setts, nil, true)
|
||||
result, err := Context.dnsFilter.CheckHost(host, dns.TypeA, &setts)
|
||||
|
||||
d.ApplyBlockedServices(&setts, nil)
|
||||
result, err := d.CheckHost(host, dns.TypeA, &setts)
|
||||
if err != nil {
|
||||
aghhttp.Error(
|
||||
r,
|
||||
@@ -457,18 +452,20 @@ func (f *Filtering) handleCheckHost(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
resp := checkHostResp{}
|
||||
resp.Reason = result.Reason.String()
|
||||
resp.SvcName = result.ServiceName
|
||||
resp.CanonName = result.CanonName
|
||||
resp.IPList = result.IPList
|
||||
rulesLen := len(result.Rules)
|
||||
resp := checkHostResp{
|
||||
Reason: result.Reason.String(),
|
||||
SvcName: result.ServiceName,
|
||||
CanonName: result.CanonName,
|
||||
IPList: result.IPList,
|
||||
Rules: make([]*checkHostRespRule, len(result.Rules)),
|
||||
}
|
||||
|
||||
if len(result.Rules) > 0 {
|
||||
if rulesLen > 0 {
|
||||
resp.FilterID = result.Rules[0].FilterListID
|
||||
resp.Rule = result.Rules[0].Text
|
||||
}
|
||||
|
||||
resp.Rules = make([]*checkHostRespRule, len(result.Rules))
|
||||
for i, r := range result.Rules {
|
||||
resp.Rules[i] = &checkHostRespRule{
|
||||
FilterListID: r.FilterListID,
|
||||
@@ -476,28 +473,51 @@ func (f *Filtering) handleCheckHost(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
js, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "json encode: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write(js)
|
||||
err = json.NewEncoder(w).Encode(resp)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "encoding response: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterFilteringHandlers - register handlers
|
||||
func (f *Filtering) RegisterFilteringHandlers() {
|
||||
httpRegister(http.MethodGet, "/control/filtering/status", f.handleFilteringStatus)
|
||||
httpRegister(http.MethodPost, "/control/filtering/config", f.handleFilteringConfig)
|
||||
httpRegister(http.MethodPost, "/control/filtering/add_url", f.handleFilteringAddURL)
|
||||
httpRegister(http.MethodPost, "/control/filtering/remove_url", f.handleFilteringRemoveURL)
|
||||
httpRegister(http.MethodPost, "/control/filtering/set_url", f.handleFilteringSetURL)
|
||||
httpRegister(http.MethodPost, "/control/filtering/refresh", f.handleFilteringRefresh)
|
||||
httpRegister(http.MethodPost, "/control/filtering/set_rules", f.handleFilteringSetRules)
|
||||
httpRegister(http.MethodGet, "/control/filtering/check_host", f.handleCheckHost)
|
||||
func (d *DNSFilter) RegisterFilteringHandlers() {
|
||||
registerHTTP := d.HTTPRegister
|
||||
if registerHTTP == nil {
|
||||
return
|
||||
}
|
||||
|
||||
registerHTTP(http.MethodPost, "/control/safebrowsing/enable", d.handleSafeBrowsingEnable)
|
||||
registerHTTP(http.MethodPost, "/control/safebrowsing/disable", d.handleSafeBrowsingDisable)
|
||||
registerHTTP(http.MethodGet, "/control/safebrowsing/status", d.handleSafeBrowsingStatus)
|
||||
|
||||
registerHTTP(http.MethodPost, "/control/parental/enable", d.handleParentalEnable)
|
||||
registerHTTP(http.MethodPost, "/control/parental/disable", d.handleParentalDisable)
|
||||
registerHTTP(http.MethodGet, "/control/parental/status", d.handleParentalStatus)
|
||||
|
||||
registerHTTP(http.MethodPost, "/control/safesearch/enable", d.handleSafeSearchEnable)
|
||||
registerHTTP(http.MethodPost, "/control/safesearch/disable", d.handleSafeSearchDisable)
|
||||
registerHTTP(http.MethodGet, "/control/safesearch/status", d.handleSafeSearchStatus)
|
||||
|
||||
registerHTTP(http.MethodGet, "/control/rewrite/list", d.handleRewriteList)
|
||||
registerHTTP(http.MethodPost, "/control/rewrite/add", d.handleRewriteAdd)
|
||||
registerHTTP(http.MethodPost, "/control/rewrite/delete", d.handleRewriteDelete)
|
||||
|
||||
registerHTTP(http.MethodGet, "/control/blocked_services/services", d.handleBlockedServicesAvailableServices)
|
||||
registerHTTP(http.MethodGet, "/control/blocked_services/list", d.handleBlockedServicesList)
|
||||
registerHTTP(http.MethodPost, "/control/blocked_services/set", d.handleBlockedServicesSet)
|
||||
|
||||
registerHTTP(http.MethodGet, "/control/filtering/status", d.handleFilteringStatus)
|
||||
registerHTTP(http.MethodPost, "/control/filtering/config", d.handleFilteringConfig)
|
||||
registerHTTP(http.MethodPost, "/control/filtering/add_url", d.handleFilteringAddURL)
|
||||
registerHTTP(http.MethodPost, "/control/filtering/remove_url", d.handleFilteringRemoveURL)
|
||||
registerHTTP(http.MethodPost, "/control/filtering/set_url", d.handleFilteringSetURL)
|
||||
registerHTTP(http.MethodPost, "/control/filtering/refresh", d.handleFilteringRefresh)
|
||||
registerHTTP(http.MethodPost, "/control/filtering/set_rules", d.handleFilteringSetRules)
|
||||
registerHTTP(http.MethodGet, "/control/filtering/check_host", d.handleCheckHost)
|
||||
}
|
||||
|
||||
func checkFiltersUpdateIntervalHours(i uint32) bool {
|
||||
// ValidateUpdateIvl returns false if i is not a valid filters update interval.
|
||||
func ValidateUpdateIvl(i uint32) bool {
|
||||
return i == 0 || i == 1 || i == 12 || i == 1*24 || i == 3*24 || i == 7*24
|
||||
}
|
||||
@@ -133,34 +133,31 @@ func matchDomainWildcard(host, wildcard string) (ok bool) {
|
||||
// 1. A and AAAA > CNAME
|
||||
// 2. wildcard > exact
|
||||
// 3. lower level wildcard > higher level wildcard
|
||||
//
|
||||
// TODO(a.garipov): Replace with slices.Sort.
|
||||
type rewritesSorted []*LegacyRewrite
|
||||
|
||||
// Len implements the sort.Interface interface for legacyRewritesSorted.
|
||||
// Len implements the sort.Interface interface for rewritesSorted.
|
||||
func (a rewritesSorted) Len() (l int) { return len(a) }
|
||||
|
||||
// Swap implements the sort.Interface interface for legacyRewritesSorted.
|
||||
// Swap implements the sort.Interface interface for rewritesSorted.
|
||||
func (a rewritesSorted) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
|
||||
|
||||
// Less implements the sort.Interface interface for legacyRewritesSorted.
|
||||
// Less implements the sort.Interface interface for rewritesSorted.
|
||||
func (a rewritesSorted) Less(i, j int) (less bool) {
|
||||
if a[i].Type == dns.TypeCNAME && a[j].Type != dns.TypeCNAME {
|
||||
ith, jth := a[i], a[j]
|
||||
if ith.Type == dns.TypeCNAME && jth.Type != dns.TypeCNAME {
|
||||
return true
|
||||
} else if a[i].Type != dns.TypeCNAME && a[j].Type == dns.TypeCNAME {
|
||||
} else if ith.Type != dns.TypeCNAME && jth.Type == dns.TypeCNAME {
|
||||
return false
|
||||
}
|
||||
|
||||
if isWildcard(a[i].Domain) {
|
||||
if !isWildcard(a[j].Domain) {
|
||||
return false
|
||||
}
|
||||
} else {
|
||||
if isWildcard(a[j].Domain) {
|
||||
return true
|
||||
}
|
||||
if iw, jw := isWildcard(ith.Domain), isWildcard(jth.Domain); iw != jw {
|
||||
return jw
|
||||
}
|
||||
|
||||
// Both are wildcards.
|
||||
return len(a[i].Domain) > len(a[j].Domain)
|
||||
// Both are either wildcards or not.
|
||||
return len(ith.Domain) > len(jth.Domain)
|
||||
}
|
||||
|
||||
// prepareRewrites normalizes and validates all legacy DNS rewrites.
|
||||
@@ -313,9 +310,3 @@ func (d *DNSFilter) handleRewriteDelete(w http.ResponseWriter, r *http.Request)
|
||||
|
||||
d.Config.ConfigModified()
|
||||
}
|
||||
|
||||
func (d *DNSFilter) registerRewritesHandlers() {
|
||||
d.Config.HTTPRegister(http.MethodGet, "/control/rewrite/list", d.handleRewriteList)
|
||||
d.Config.HTTPRegister(http.MethodPost, "/control/rewrite/add", d.handleRewriteAdd)
|
||||
d.Config.HTTPRegister(http.MethodPost, "/control/rewrite/delete", d.handleRewriteDelete)
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
// TODO(e.burkov): All the tests in this file may and should me merged together.
|
||||
|
||||
func TestRewrites(t *testing.T) {
|
||||
d := newForTest(t, nil, nil)
|
||||
d, _ := newForTest(t, nil, nil)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
d.Rewrites = []*LegacyRewrite{{
|
||||
@@ -188,7 +188,7 @@ func TestRewrites(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRewritesLevels(t *testing.T) {
|
||||
d := newForTest(t, nil, nil)
|
||||
d, _ := newForTest(t, nil, nil)
|
||||
t.Cleanup(d.Close)
|
||||
// Exact host, wildcard L2, wildcard L3.
|
||||
d.Rewrites = []*LegacyRewrite{{
|
||||
@@ -235,7 +235,7 @@ func TestRewritesLevels(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRewritesExceptionCNAME(t *testing.T) {
|
||||
d := newForTest(t, nil, nil)
|
||||
d, _ := newForTest(t, nil, nil)
|
||||
t.Cleanup(d.Close)
|
||||
// Wildcard and exception for a sub-domain.
|
||||
d.Rewrites = []*LegacyRewrite{{
|
||||
@@ -286,7 +286,7 @@ func TestRewritesExceptionCNAME(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRewritesExceptionIP(t *testing.T) {
|
||||
d := newForTest(t, nil, nil)
|
||||
d, _ := newForTest(t, nil, nil)
|
||||
t.Cleanup(d.Close)
|
||||
// Exception for AAAA record.
|
||||
d.Rewrites = []*LegacyRewrite{{
|
||||
|
||||
@@ -415,17 +415,3 @@ func (d *DNSFilter) handleParentalStatus(w http.ResponseWriter, r *http.Request)
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DNSFilter) registerSecurityHandlers() {
|
||||
d.Config.HTTPRegister(http.MethodPost, "/control/safebrowsing/enable", d.handleSafeBrowsingEnable)
|
||||
d.Config.HTTPRegister(http.MethodPost, "/control/safebrowsing/disable", d.handleSafeBrowsingDisable)
|
||||
d.Config.HTTPRegister(http.MethodGet, "/control/safebrowsing/status", d.handleSafeBrowsingStatus)
|
||||
|
||||
d.Config.HTTPRegister(http.MethodPost, "/control/parental/enable", d.handleParentalEnable)
|
||||
d.Config.HTTPRegister(http.MethodPost, "/control/parental/disable", d.handleParentalDisable)
|
||||
d.Config.HTTPRegister(http.MethodGet, "/control/parental/status", d.handleParentalStatus)
|
||||
|
||||
d.Config.HTTPRegister(http.MethodPost, "/control/safesearch/enable", d.handleSafeSearchEnable)
|
||||
d.Config.HTTPRegister(http.MethodPost, "/control/safesearch/disable", d.handleSafeSearchDisable)
|
||||
d.Config.HTTPRegister(http.MethodGet, "/control/safesearch/status", d.handleSafeSearchStatus)
|
||||
}
|
||||
|
||||
@@ -107,7 +107,7 @@ func TestSafeBrowsingCache(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSBPC_checkErrorUpstream(t *testing.T) {
|
||||
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
|
||||
d, _ := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
ups := aghtest.NewErrorUpstream()
|
||||
@@ -128,7 +128,7 @@ func TestSBPC_checkErrorUpstream(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSBPC(t *testing.T) {
|
||||
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
|
||||
d, _ := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
|
||||
t.Cleanup(d.Close)
|
||||
|
||||
const hostname = "example.org"
|
||||
|
||||
@@ -8,12 +8,14 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"path"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
@@ -32,7 +34,8 @@ const sessionTokenSize = 16
|
||||
|
||||
type session struct {
|
||||
userName string
|
||||
expire uint32 // expiration time (in seconds)
|
||||
// expire is the expiration time, in seconds.
|
||||
expire uint32
|
||||
}
|
||||
|
||||
func (s *session) serialize() []byte {
|
||||
@@ -64,29 +67,29 @@ func (s *session) deserialize(data []byte) bool {
|
||||
|
||||
// Auth - global object
|
||||
type Auth struct {
|
||||
db *bbolt.DB
|
||||
blocker *authRateLimiter
|
||||
sessions map[string]*session
|
||||
users []User
|
||||
lock sync.Mutex
|
||||
sessionTTL uint32
|
||||
db *bbolt.DB
|
||||
raleLimiter *authRateLimiter
|
||||
sessions map[string]*session
|
||||
users []webUser
|
||||
lock sync.Mutex
|
||||
sessionTTL uint32
|
||||
}
|
||||
|
||||
// User object
|
||||
type User struct {
|
||||
// webUser represents a user of the Web UI.
|
||||
type webUser struct {
|
||||
Name string `yaml:"name"`
|
||||
PasswordHash string `yaml:"password"` // bcrypt hash
|
||||
PasswordHash string `yaml:"password"`
|
||||
}
|
||||
|
||||
// InitAuth - create a global object
|
||||
func InitAuth(dbFilename string, users []User, sessionTTL uint32, blocker *authRateLimiter) *Auth {
|
||||
func InitAuth(dbFilename string, users []webUser, sessionTTL uint32, rateLimiter *authRateLimiter) *Auth {
|
||||
log.Info("Initializing auth module: %s", dbFilename)
|
||||
|
||||
a := &Auth{
|
||||
sessionTTL: sessionTTL,
|
||||
blocker: blocker,
|
||||
sessions: make(map[string]*session),
|
||||
users: users,
|
||||
sessionTTL: sessionTTL,
|
||||
raleLimiter: rateLimiter,
|
||||
sessions: make(map[string]*session),
|
||||
users: users,
|
||||
}
|
||||
var err error
|
||||
a.db, err = bbolt.Open(dbFilename, 0o644, nil)
|
||||
@@ -326,35 +329,25 @@ func newSessionToken() (data []byte, err error) {
|
||||
return randData, nil
|
||||
}
|
||||
|
||||
// cookieTimeFormat is the format to be used in (time.Time).Format for cookie's
|
||||
// expiry field.
|
||||
const cookieTimeFormat = "Mon, 02 Jan 2006 15:04:05 GMT"
|
||||
|
||||
// cookieExpiryFormat returns the formatted exp to be used in cookie string.
|
||||
// It's quite simple for now, but probably will be expanded in the future.
|
||||
func cookieExpiryFormat(exp time.Time) (formatted string) {
|
||||
return exp.Format(cookieTimeFormat)
|
||||
}
|
||||
|
||||
func (a *Auth) httpCookie(req loginJSON, addr string) (cookie string, err error) {
|
||||
blocker := a.blocker
|
||||
u := a.UserFind(req.Name, req.Password)
|
||||
if len(u.Name) == 0 {
|
||||
if blocker != nil {
|
||||
blocker.inc(addr)
|
||||
// newCookie creates a new authentication cookie.
|
||||
func (a *Auth) newCookie(req loginJSON, addr string) (c *http.Cookie, err error) {
|
||||
rateLimiter := a.raleLimiter
|
||||
u, ok := a.findUser(req.Name, req.Password)
|
||||
if !ok {
|
||||
if rateLimiter != nil {
|
||||
rateLimiter.inc(addr)
|
||||
}
|
||||
|
||||
return "", err
|
||||
return nil, errors.Error("invalid username or password")
|
||||
}
|
||||
|
||||
if blocker != nil {
|
||||
blocker.remove(addr)
|
||||
if rateLimiter != nil {
|
||||
rateLimiter.remove(addr)
|
||||
}
|
||||
|
||||
var sess []byte
|
||||
sess, err = newSessionToken()
|
||||
sess, err := newSessionToken()
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, fmt.Errorf("generating token: %w", err)
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
@@ -364,11 +357,15 @@ func (a *Auth) httpCookie(req loginJSON, addr string) (cookie string, err error)
|
||||
expire: uint32(now.Unix()) + a.sessionTTL,
|
||||
})
|
||||
|
||||
return fmt.Sprintf(
|
||||
"%s=%s; Path=/; HttpOnly; Expires=%s",
|
||||
sessionCookieName, hex.EncodeToString(sess),
|
||||
cookieExpiryFormat(now.Add(cookieTTL)),
|
||||
), nil
|
||||
return &http.Cookie{
|
||||
Name: sessionCookieName,
|
||||
Value: hex.EncodeToString(sess),
|
||||
Path: "/",
|
||||
Expires: now.Add(cookieTTL),
|
||||
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// realIP extracts the real IP address of the client from an HTTP request using
|
||||
@@ -436,8 +433,8 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if blocker := Context.auth.blocker; blocker != nil {
|
||||
if left := blocker.check(remoteAddr); left > 0 {
|
||||
if rateLimiter := Context.auth.raleLimiter; rateLimiter != nil {
|
||||
if left := rateLimiter.check(remoteAddr); left > 0 {
|
||||
w.Header().Set("Retry-After", strconv.Itoa(int(left.Seconds())))
|
||||
aghhttp.Error(r, w, http.StatusTooManyRequests, "auth: blocked for %s", left)
|
||||
|
||||
@@ -445,10 +442,9 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
var cookie string
|
||||
cookie, err = Context.auth.httpCookie(req, remoteAddr)
|
||||
cookie, err := Context.auth.newCookie(req, remoteAddr)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "crypto rand reader: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusForbidden, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -462,20 +458,11 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
|
||||
log.Error("auth: unknown ip")
|
||||
}
|
||||
|
||||
if len(cookie) == 0 {
|
||||
log.Info("auth: failed to login user %q from ip %v", req.Name, ip)
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
http.Error(w, "invalid username or password", http.StatusBadRequest)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
log.Info("auth: user %q successfully logged in from ip %v", req.Name, ip)
|
||||
|
||||
http.SetCookie(w, cookie)
|
||||
|
||||
h := w.Header()
|
||||
h.Set("Set-Cookie", cookie)
|
||||
h.Set("Cache-Control", "no-store, no-cache, must-revalidate, proxy-revalidate")
|
||||
h.Set("Pragma", "no-cache")
|
||||
h.Set("Expires", "0")
|
||||
@@ -484,17 +471,31 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func handleLogout(w http.ResponseWriter, r *http.Request) {
|
||||
cookie := r.Header.Get("Cookie")
|
||||
sess := parseCookie(cookie)
|
||||
respHdr := w.Header()
|
||||
c, err := r.Cookie(sessionCookieName)
|
||||
if err != nil {
|
||||
// The only error that is returned from r.Cookie is [http.ErrNoCookie].
|
||||
// The user is already logged out.
|
||||
respHdr.Set("Location", "/login.html")
|
||||
w.WriteHeader(http.StatusFound)
|
||||
|
||||
Context.auth.RemoveSession(sess)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Location", "/login.html")
|
||||
Context.auth.RemoveSession(c.Value)
|
||||
|
||||
s := fmt.Sprintf("%s=; Path=/; HttpOnly; Expires=Thu, 01 Jan 1970 00:00:00 GMT",
|
||||
sessionCookieName)
|
||||
w.Header().Set("Set-Cookie", s)
|
||||
c = &http.Cookie{
|
||||
Name: sessionCookieName,
|
||||
Value: "",
|
||||
Path: "/",
|
||||
Expires: time.Unix(0, 0),
|
||||
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
}
|
||||
|
||||
respHdr.Set("Location", "/login.html")
|
||||
respHdr.Set("Set-Cookie", c.String())
|
||||
w.WriteHeader(http.StatusFound)
|
||||
}
|
||||
|
||||
@@ -504,101 +505,108 @@ func RegisterAuthHandlers() {
|
||||
httpRegister(http.MethodGet, "/control/logout", handleLogout)
|
||||
}
|
||||
|
||||
func parseCookie(cookie string) string {
|
||||
pairs := strings.Split(cookie, ";")
|
||||
for _, pair := range pairs {
|
||||
pair = strings.TrimSpace(pair)
|
||||
kv := strings.SplitN(pair, "=", 2)
|
||||
if len(kv) != 2 {
|
||||
continue
|
||||
}
|
||||
if kv[0] == sessionCookieName {
|
||||
return kv[1]
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// optionalAuthThird return true if user should authenticate first.
|
||||
func optionalAuthThird(w http.ResponseWriter, r *http.Request) (authFirst bool) {
|
||||
authFirst = false
|
||||
func optionalAuthThird(w http.ResponseWriter, r *http.Request) (mustAuth bool) {
|
||||
if glProcessCookie(r) {
|
||||
log.Debug("auth: authentication is handled by GL-Inet submodule")
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// redirect to login page if not authenticated
|
||||
ok := false
|
||||
isAuthenticated := false
|
||||
cookie, err := r.Cookie(sessionCookieName)
|
||||
|
||||
if glProcessCookie(r) {
|
||||
log.Debug("auth: authentication was handled by GL-Inet submodule")
|
||||
ok = true
|
||||
} else if err == nil {
|
||||
r := Context.auth.checkSession(cookie.Value)
|
||||
if r == checkSessionOK {
|
||||
ok = true
|
||||
} else if r < 0 {
|
||||
log.Debug("auth: invalid cookie value: %s", cookie)
|
||||
}
|
||||
} else {
|
||||
// there's no Cookie, check Basic authentication
|
||||
user, pass, ok2 := r.BasicAuth()
|
||||
if ok2 {
|
||||
u := Context.auth.UserFind(user, pass)
|
||||
if len(u.Name) != 0 {
|
||||
ok = true
|
||||
} else {
|
||||
if err != nil {
|
||||
// The only error that is returned from r.Cookie is [http.ErrNoCookie].
|
||||
// Check Basic authentication.
|
||||
user, pass, hasBasic := r.BasicAuth()
|
||||
if hasBasic {
|
||||
_, isAuthenticated = Context.auth.findUser(user, pass)
|
||||
if !isAuthenticated {
|
||||
log.Info("auth: invalid Basic Authorization value")
|
||||
}
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
if r.URL.Path == "/" || r.URL.Path == "/index.html" {
|
||||
if glProcessRedirect(w, r) {
|
||||
log.Debug("auth: redirected to login page by GL-Inet submodule")
|
||||
} else {
|
||||
w.Header().Set("Location", "/login.html")
|
||||
w.WriteHeader(http.StatusFound)
|
||||
}
|
||||
} else {
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
_, _ = w.Write([]byte("Forbidden"))
|
||||
} else {
|
||||
res := Context.auth.checkSession(cookie.Value)
|
||||
isAuthenticated = res == checkSessionOK
|
||||
if !isAuthenticated {
|
||||
log.Debug("auth: invalid cookie value: %s", cookie)
|
||||
}
|
||||
authFirst = true
|
||||
}
|
||||
|
||||
return authFirst
|
||||
if isAuthenticated {
|
||||
return false
|
||||
}
|
||||
|
||||
if p := r.URL.Path; p == "/" || p == "/index.html" {
|
||||
if glProcessRedirect(w, r) {
|
||||
log.Debug("auth: redirected to login page by GL-Inet submodule")
|
||||
} else {
|
||||
log.Debug("auth: redirected to login page")
|
||||
w.Header().Set("Location", "/login.html")
|
||||
w.WriteHeader(http.StatusFound)
|
||||
}
|
||||
} else {
|
||||
log.Debug("auth: responded with forbidden to %s %s", r.Method, p)
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
_, _ = w.Write([]byte("Forbidden"))
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func optionalAuth(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
|
||||
// TODO(a.garipov): Use [http.Handler] consistently everywhere throughout the
|
||||
// project.
|
||||
func optionalAuth(
|
||||
h func(http.ResponseWriter, *http.Request),
|
||||
) (wrapped func(http.ResponseWriter, *http.Request)) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/login.html" {
|
||||
// redirect to dashboard if already authenticated
|
||||
authRequired := Context.auth != nil && Context.auth.AuthRequired()
|
||||
p := r.URL.Path
|
||||
authRequired := Context.auth != nil && Context.auth.AuthRequired()
|
||||
if p == "/login.html" {
|
||||
cookie, err := r.Cookie(sessionCookieName)
|
||||
if authRequired && err == nil {
|
||||
r := Context.auth.checkSession(cookie.Value)
|
||||
if r == checkSessionOK {
|
||||
// Redirect to the dashboard if already authenticated.
|
||||
res := Context.auth.checkSession(cookie.Value)
|
||||
if res == checkSessionOK {
|
||||
w.Header().Set("Location", "/")
|
||||
w.WriteHeader(http.StatusFound)
|
||||
|
||||
return
|
||||
} else if r == checkSessionNotFound {
|
||||
log.Debug("auth: invalid cookie value: %s", cookie)
|
||||
}
|
||||
}
|
||||
|
||||
} else if strings.HasPrefix(r.URL.Path, "/assets/") ||
|
||||
strings.HasPrefix(r.URL.Path, "/login.") {
|
||||
// process as usual
|
||||
// no additional auth requirements
|
||||
} else if Context.auth != nil && Context.auth.AuthRequired() {
|
||||
log.Debug("auth: invalid cookie value: %s", cookie)
|
||||
}
|
||||
} else if isPublicResource(p) {
|
||||
// Process as usual, no additional auth requirements.
|
||||
} else if authRequired {
|
||||
if optionalAuthThird(w, r) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
handler(w, r)
|
||||
h(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// isPublicResource returns true if p is a path to a public resource.
|
||||
func isPublicResource(p string) (ok bool) {
|
||||
isAsset, err := path.Match("/assets/*", p)
|
||||
if err != nil {
|
||||
// The only error that is returned from path.Match is
|
||||
// [path.ErrBadPattern]. This is a programmer error.
|
||||
panic(fmt.Errorf("bad asset pattern: %w", err))
|
||||
}
|
||||
|
||||
isLogin, err := path.Match("/login.*", p)
|
||||
if err != nil {
|
||||
// Same as above.
|
||||
panic(fmt.Errorf("bad login pattern: %w", err))
|
||||
}
|
||||
|
||||
return isAsset || isLogin
|
||||
}
|
||||
|
||||
type authHandler struct {
|
||||
handler http.Handler
|
||||
}
|
||||
@@ -612,7 +620,7 @@ func optionalAuthHandler(handler http.Handler) http.Handler {
|
||||
}
|
||||
|
||||
// UserAdd - add new user
|
||||
func (a *Auth) UserAdd(u *User, password string) {
|
||||
func (a *Auth) UserAdd(u *webUser, password string) {
|
||||
if len(password) == 0 {
|
||||
return
|
||||
}
|
||||
@@ -631,31 +639,35 @@ func (a *Auth) UserAdd(u *User, password string) {
|
||||
log.Debug("auth: added user: %s", u.Name)
|
||||
}
|
||||
|
||||
// UserFind - find a user
|
||||
func (a *Auth) UserFind(login, password string) User {
|
||||
// findUser returns a user if there is one.
|
||||
func (a *Auth) findUser(login, password string) (u webUser, ok bool) {
|
||||
a.lock.Lock()
|
||||
defer a.lock.Unlock()
|
||||
for _, u := range a.users {
|
||||
|
||||
for _, u = range a.users {
|
||||
if u.Name == login &&
|
||||
bcrypt.CompareHashAndPassword([]byte(u.PasswordHash), []byte(password)) == nil {
|
||||
return u
|
||||
return u, true
|
||||
}
|
||||
}
|
||||
return User{}
|
||||
|
||||
return webUser{}, false
|
||||
}
|
||||
|
||||
// getCurrentUser returns the current user. It returns an empty User if the
|
||||
// user is not found.
|
||||
func (a *Auth) getCurrentUser(r *http.Request) User {
|
||||
func (a *Auth) getCurrentUser(r *http.Request) (u webUser) {
|
||||
cookie, err := r.Cookie(sessionCookieName)
|
||||
if err != nil {
|
||||
// There's no Cookie, check Basic authentication.
|
||||
user, pass, ok := r.BasicAuth()
|
||||
if ok {
|
||||
return Context.auth.UserFind(user, pass)
|
||||
u, _ = Context.auth.findUser(user, pass)
|
||||
|
||||
return u
|
||||
}
|
||||
|
||||
return User{}
|
||||
return webUser{}
|
||||
}
|
||||
|
||||
a.lock.Lock()
|
||||
@@ -663,20 +675,20 @@ func (a *Auth) getCurrentUser(r *http.Request) User {
|
||||
|
||||
s, ok := a.sessions[cookie.Value]
|
||||
if !ok {
|
||||
return User{}
|
||||
return webUser{}
|
||||
}
|
||||
|
||||
for _, u := range a.users {
|
||||
for _, u = range a.users {
|
||||
if u.Name == s.userName {
|
||||
return u
|
||||
}
|
||||
}
|
||||
|
||||
return User{}
|
||||
return webUser{}
|
||||
}
|
||||
|
||||
// GetUsers - get users
|
||||
func (a *Auth) GetUsers() []User {
|
||||
func (a *Auth) GetUsers() []webUser {
|
||||
a.lock.Lock()
|
||||
users := a.users
|
||||
a.lock.Unlock()
|
||||
|
||||
@@ -43,14 +43,14 @@ func TestAuth(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
fn := filepath.Join(dir, "sessions.db")
|
||||
|
||||
users := []User{{
|
||||
users := []webUser{{
|
||||
Name: "name",
|
||||
PasswordHash: "$2y$05$..vyzAECIhJPfaQiOK17IukcQnqEgKJHy0iETyYqxn3YXJl8yZuo2",
|
||||
}}
|
||||
a := InitAuth(fn, nil, 60, nil)
|
||||
s := session{}
|
||||
|
||||
user := User{Name: "name"}
|
||||
user := webUser{Name: "name"}
|
||||
a.UserAdd(&user, "password")
|
||||
|
||||
assert.Equal(t, checkSessionNotFound, a.checkSession("notfound"))
|
||||
@@ -84,7 +84,8 @@ func TestAuth(t *testing.T) {
|
||||
a.storeSession(sess, &s)
|
||||
a.Close()
|
||||
|
||||
u := a.UserFind("name", "password")
|
||||
u, ok := a.findUser("name", "password")
|
||||
assert.True(t, ok)
|
||||
assert.NotEmpty(t, u.Name)
|
||||
|
||||
time.Sleep(3 * time.Second)
|
||||
@@ -118,7 +119,7 @@ func TestAuthHTTP(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
fn := filepath.Join(dir, "sessions.db")
|
||||
|
||||
users := []User{
|
||||
users := []webUser{
|
||||
{Name: "name", PasswordHash: "$2y$05$..vyzAECIhJPfaQiOK17IukcQnqEgKJHy0iETyYqxn3YXJl8yZuo2"},
|
||||
}
|
||||
Context.auth = InitAuth(fn, users, 60, nil)
|
||||
@@ -150,18 +151,19 @@ func TestAuthHTTP(t *testing.T) {
|
||||
assert.True(t, handlerCalled)
|
||||
|
||||
// perform login
|
||||
cookie, err := Context.auth.httpCookie(loginJSON{Name: "name", Password: "password"}, "")
|
||||
cookie, err := Context.auth.newCookie(loginJSON{Name: "name", Password: "password"}, "")
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, cookie)
|
||||
require.NotNil(t, cookie)
|
||||
|
||||
// get /
|
||||
handler2 = optionalAuth(handler)
|
||||
w.hdr = make(http.Header)
|
||||
r.Header.Set("Cookie", cookie)
|
||||
r.Header.Set("Cookie", cookie.String())
|
||||
r.URL = &url.URL{Path: "/"}
|
||||
handlerCalled = false
|
||||
handler2(&w, &r)
|
||||
assert.True(t, handlerCalled)
|
||||
|
||||
r.Header.Del("Cookie")
|
||||
|
||||
// get / with basic auth
|
||||
@@ -177,7 +179,7 @@ func TestAuthHTTP(t *testing.T) {
|
||||
// get login page with a valid cookie - we're redirected to /
|
||||
handler2 = optionalAuth(handler)
|
||||
w.hdr = make(http.Header)
|
||||
r.Header.Set("Cookie", cookie)
|
||||
r.Header.Set("Cookie", cookie.String())
|
||||
r.URL = &url.URL{Path: loginURL}
|
||||
handlerCalled = false
|
||||
handler2(&w, &r)
|
||||
|
||||
@@ -126,7 +126,7 @@ type clientsContainer struct {
|
||||
allTags *stringutil.Set
|
||||
|
||||
// dhcpServer is used for looking up clients IP addresses by MAC addresses
|
||||
dhcpServer *dhcpd.Server
|
||||
dhcpServer dhcpd.Interface
|
||||
|
||||
// dnsServer is used for checking clients IP status access list status
|
||||
dnsServer *dnsforward.Server
|
||||
@@ -146,7 +146,7 @@ type clientsContainer struct {
|
||||
// Note: this function must be called only once
|
||||
func (clients *clientsContainer) Init(
|
||||
objects []*clientObject,
|
||||
dhcpServer *dhcpd.Server,
|
||||
dhcpServer dhcpd.Interface,
|
||||
etcHosts *aghnet.HostsContainer,
|
||||
arpdb aghnet.ARPDB,
|
||||
) {
|
||||
|
||||
@@ -279,8 +279,6 @@ func TestClientsAddExisting(t *testing.T) {
|
||||
t.Skip("skipping dhcp test on windows")
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
ip := net.IP{1, 2, 3, 4}
|
||||
|
||||
// First, init a DHCP server with a single static lease.
|
||||
@@ -296,13 +294,15 @@ func TestClientsAddExisting(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
clients.dhcpServer, err = dhcpd.Create(config)
|
||||
dhcpServer, err := dhcpd.Create(config)
|
||||
require.NoError(t, err)
|
||||
testutil.CleanupAndRequireSuccess(t, func() (err error) {
|
||||
return os.Remove("leases.db")
|
||||
})
|
||||
|
||||
err = clients.dhcpServer.AddStaticLease(&dhcpd.Lease{
|
||||
clients.dhcpServer = dhcpServer
|
||||
|
||||
err = dhcpServer.AddStaticLease(&dhcpd.Lease{
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
IP: ip,
|
||||
Hostname: "testhost",
|
||||
|
||||
@@ -93,13 +93,7 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
|
||||
|
||||
data.Tags = clientTags
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
e := json.NewEncoder(w).Encode(data)
|
||||
if e != nil {
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "failed to encode to json: %v", e)
|
||||
|
||||
return
|
||||
}
|
||||
_ = aghhttp.WriteJSONResponse(w, r, data)
|
||||
}
|
||||
|
||||
// Convert JSON object to Client object
|
||||
@@ -249,11 +243,7 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
|
||||
})
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err := json.NewEncoder(w).Encode(data)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't write response: %s", err)
|
||||
}
|
||||
_ = aghhttp.WriteJSONResponse(w, r, data)
|
||||
}
|
||||
|
||||
// findRuntime looks up the IP in runtime and temporary storages, like
|
||||
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
||||
"github.com/AdguardTeam/dnsproxy/fastip"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
@@ -23,10 +22,9 @@ import (
|
||||
yaml "gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
const (
|
||||
dataDir = "data" // data storage
|
||||
filterDir = "filters" // cache location for downloaded filters, it's under DataDir
|
||||
)
|
||||
// dataDir is the name of a directory under the working one to store some
|
||||
// persistent data.
|
||||
const dataDir = "data"
|
||||
|
||||
// logSettings are the logging settings part of the configuration file.
|
||||
//
|
||||
@@ -87,10 +85,10 @@ type configuration struct {
|
||||
// It's reset after config is parsed
|
||||
fileData []byte
|
||||
|
||||
BindHost net.IP `yaml:"bind_host"` // BindHost is the IP address of the HTTP server to bind to
|
||||
BindPort int `yaml:"bind_port"` // BindPort is the port the HTTP server
|
||||
BetaBindPort int `yaml:"beta_bind_port"` // BetaBindPort is the port for new client
|
||||
Users []User `yaml:"users"` // Users that can access HTTP server
|
||||
BindHost net.IP `yaml:"bind_host"` // BindHost is the IP address of the HTTP server to bind to
|
||||
BindPort int `yaml:"bind_port"` // BindPort is the port the HTTP server
|
||||
BetaBindPort int `yaml:"beta_bind_port"` // BetaBindPort is the port for new client
|
||||
Users []webUser `yaml:"users"` // Users that can access HTTP server
|
||||
// AuthAttempts is the maximum number of failed login attempts a user
|
||||
// can do before being blocked.
|
||||
AuthAttempts uint `yaml:"auth_attempts"`
|
||||
@@ -108,9 +106,16 @@ type configuration struct {
|
||||
DNS dnsConfig `yaml:"dns"`
|
||||
TLS tlsConfigSettings `yaml:"tls"`
|
||||
|
||||
Filters []filter `yaml:"filters"`
|
||||
WhitelistFilters []filter `yaml:"whitelist_filters"`
|
||||
UserRules []string `yaml:"user_rules"`
|
||||
// Filters reflects the filters from [filtering.Config]. It's cloned to the
|
||||
// config used in the filtering module at the startup. Afterwards it's
|
||||
// cloned from the filtering module back here.
|
||||
//
|
||||
// TODO(e.burkov): Move all the filtering configuration fields into the
|
||||
// only configuration subsection covering the changes with a single
|
||||
// migration.
|
||||
Filters []filtering.FilterYAML `yaml:"filters"`
|
||||
WhitelistFilters []filtering.FilterYAML `yaml:"whitelist_filters"`
|
||||
UserRules []string `yaml:"user_rules"`
|
||||
|
||||
DHCP *dhcpd.ServerConfig `yaml:"dhcp"`
|
||||
|
||||
@@ -145,9 +150,7 @@ type dnsConfig struct {
|
||||
|
||||
dnsforward.FilteringConfig `yaml:",inline"`
|
||||
|
||||
FilteringEnabled bool `yaml:"filtering_enabled"` // whether or not use filter lists
|
||||
FiltersUpdateIntervalHours uint32 `yaml:"filters_update_interval"` // time period to update filters (in hours)
|
||||
DnsfilterConf filtering.Config `yaml:",inline"`
|
||||
DnsfilterConf *filtering.Config `yaml:",inline"`
|
||||
|
||||
// UpstreamTimeout is the timeout for querying upstream servers.
|
||||
UpstreamTimeout timeutil.Duration `yaml:"upstream_timeout"`
|
||||
@@ -193,15 +196,20 @@ type tlsConfigSettings struct {
|
||||
//
|
||||
// TODO(a.garipov, e.burkov): This global is awful and must be removed.
|
||||
var config = &configuration{
|
||||
BindPort: 3000,
|
||||
BetaBindPort: 0,
|
||||
BindHost: net.IP{0, 0, 0, 0},
|
||||
AuthAttempts: 5,
|
||||
AuthBlockMin: 15,
|
||||
BindPort: 3000,
|
||||
BetaBindPort: 0,
|
||||
BindHost: net.IP{0, 0, 0, 0},
|
||||
AuthAttempts: 5,
|
||||
AuthBlockMin: 15,
|
||||
WebSessionTTLHours: 30 * 24,
|
||||
DNS: dnsConfig{
|
||||
BindHosts: []net.IP{{0, 0, 0, 0}},
|
||||
Port: defaultPortDNS,
|
||||
StatsInterval: 1,
|
||||
BindHosts: []net.IP{{0, 0, 0, 0}},
|
||||
Port: defaultPortDNS,
|
||||
StatsInterval: 1,
|
||||
QueryLogEnabled: true,
|
||||
QueryLogFileEnabled: true,
|
||||
QueryLogInterval: timeutil.Duration{Duration: 90 * timeutil.Day},
|
||||
QueryLogMemSize: 1000,
|
||||
FilteringConfig: dnsforward.FilteringConfig{
|
||||
ProtectionEnabled: true, // whether or not use any of filtering features
|
||||
BlockingMode: dnsforward.BlockingModeDefault,
|
||||
@@ -222,18 +230,42 @@ var config = &configuration{
|
||||
// was later increased to 300 due to https://github.com/AdguardTeam/AdGuardHome/issues/2257
|
||||
MaxGoroutines: 300,
|
||||
},
|
||||
FilteringEnabled: true, // whether or not use filter lists
|
||||
FiltersUpdateIntervalHours: 24,
|
||||
UpstreamTimeout: timeutil.Duration{Duration: dnsforward.DefaultTimeout},
|
||||
UsePrivateRDNS: true,
|
||||
DnsfilterConf: &filtering.Config{
|
||||
SafeBrowsingCacheSize: 1 * 1024 * 1024,
|
||||
SafeSearchCacheSize: 1 * 1024 * 1024,
|
||||
ParentalCacheSize: 1 * 1024 * 1024,
|
||||
CacheTime: 30,
|
||||
FilteringEnabled: true,
|
||||
FiltersUpdateIntervalHours: 24,
|
||||
},
|
||||
UpstreamTimeout: timeutil.Duration{Duration: dnsforward.DefaultTimeout},
|
||||
UsePrivateRDNS: true,
|
||||
},
|
||||
TLS: tlsConfigSettings{
|
||||
PortHTTPS: defaultPortHTTPS,
|
||||
PortDNSOverTLS: defaultPortTLS, // needs to be passed through to dnsproxy
|
||||
PortDNSOverQUIC: defaultPortQUIC,
|
||||
},
|
||||
Filters: []filtering.FilterYAML{{
|
||||
Filter: filtering.Filter{ID: 1},
|
||||
Enabled: true,
|
||||
URL: "https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt",
|
||||
Name: "AdGuard DNS filter",
|
||||
}, {
|
||||
Filter: filtering.Filter{ID: 2},
|
||||
Enabled: false,
|
||||
URL: "https://adaway.org/hosts.txt",
|
||||
Name: "AdAway Default Blocklist",
|
||||
}},
|
||||
DHCP: &dhcpd.ServerConfig{
|
||||
LocalDomainName: "lan",
|
||||
Conf4: dhcpd.V4ServerConf{
|
||||
LeaseDuration: dhcpd.DefaultDHCPLeaseTTL,
|
||||
ICMPTimeout: dhcpd.DefaultDHCPTimeoutICMP,
|
||||
},
|
||||
Conf6: dhcpd.V6ServerConf{
|
||||
LeaseDuration: dhcpd.DefaultDHCPLeaseTTL,
|
||||
},
|
||||
},
|
||||
Clients: &clientsConfig{
|
||||
Sources: &clientSourcesConf{
|
||||
@@ -255,31 +287,6 @@ var config = &configuration{
|
||||
SchemaVersion: currentSchemaVersion,
|
||||
}
|
||||
|
||||
// initConfig initializes default configuration for the current OS&ARCH
|
||||
func initConfig() {
|
||||
config.WebSessionTTLHours = 30 * 24
|
||||
|
||||
config.DNS.QueryLogEnabled = true
|
||||
config.DNS.QueryLogFileEnabled = true
|
||||
config.DNS.QueryLogInterval = timeutil.Duration{Duration: 90 * timeutil.Day}
|
||||
config.DNS.QueryLogMemSize = 1000
|
||||
|
||||
config.DNS.CacheSize = 4 * 1024 * 1024
|
||||
config.DNS.DnsfilterConf.SafeBrowsingCacheSize = 1 * 1024 * 1024
|
||||
config.DNS.DnsfilterConf.SafeSearchCacheSize = 1 * 1024 * 1024
|
||||
config.DNS.DnsfilterConf.ParentalCacheSize = 1 * 1024 * 1024
|
||||
config.DNS.DnsfilterConf.CacheTime = 30
|
||||
config.Filters = defaultFilters()
|
||||
|
||||
config.DHCP.Conf4.LeaseDuration = dhcpd.DefaultDHCPLeaseTTL
|
||||
config.DHCP.Conf4.ICMPTimeout = dhcpd.DefaultDHCPTimeoutICMP
|
||||
config.DHCP.Conf6.LeaseDuration = dhcpd.DefaultDHCPLeaseTTL
|
||||
|
||||
if ch := version.Channel(); ch == version.ChannelEdge || ch == version.ChannelDevelopment {
|
||||
config.BetaBindPort = 3001
|
||||
}
|
||||
}
|
||||
|
||||
// getConfigFilename returns path to the current config file
|
||||
func (c *configuration) getConfigFilename() string {
|
||||
configFile, err := filepath.EvalSymlinks(Context.configFilename)
|
||||
@@ -348,8 +355,8 @@ func parseConfig() (err error) {
|
||||
return fmt.Errorf("validating udp ports: %w", err)
|
||||
}
|
||||
|
||||
if !checkFiltersUpdateIntervalHours(config.DNS.FiltersUpdateIntervalHours) {
|
||||
config.DNS.FiltersUpdateIntervalHours = 24
|
||||
if !filtering.ValidateUpdateIvl(config.DNS.DnsfilterConf.FiltersUpdateIntervalHours) {
|
||||
config.DNS.DnsfilterConf.FiltersUpdateIntervalHours = 24
|
||||
}
|
||||
|
||||
if config.DNS.UpstreamTimeout.Duration == 0 {
|
||||
@@ -418,10 +425,11 @@ func (c *configuration) write() (err error) {
|
||||
config.DNS.AnonymizeClientIP = dc.AnonymizeClientIP
|
||||
}
|
||||
|
||||
if Context.dnsFilter != nil {
|
||||
c := filtering.Config{}
|
||||
Context.dnsFilter.WriteDiskConfig(&c)
|
||||
config.DNS.DnsfilterConf = c
|
||||
if Context.filters != nil {
|
||||
Context.filters.WriteDiskConfig(config.DNS.DnsfilterConf)
|
||||
config.Filters = config.DNS.DnsfilterConf.Filters
|
||||
config.WhitelistFilters = config.DNS.DnsfilterConf.WhitelistFilters
|
||||
config.UserRules = config.DNS.DnsfilterConf.UserRules
|
||||
}
|
||||
|
||||
if s := Context.dnsServer; s != nil {
|
||||
@@ -433,9 +441,7 @@ func (c *configuration) write() (err error) {
|
||||
}
|
||||
|
||||
if Context.dhcpServer != nil {
|
||||
c := &dhcpd.ServerConfig{}
|
||||
Context.dhcpServer.WriteDiskConfig(c)
|
||||
config.DHCP = c
|
||||
Context.dhcpServer.WriteDiskConfig(config.DHCP)
|
||||
}
|
||||
|
||||
config.Clients.Persistent = Context.clients.forConfig()
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"net/url"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
@@ -97,16 +98,16 @@ func collectDNSAddresses() (addrs []string, err error) {
|
||||
|
||||
// statusResponse is a response for /control/status endpoint.
|
||||
type statusResponse struct {
|
||||
Version string `json:"version"`
|
||||
Language string `json:"language"`
|
||||
DNSAddrs []string `json:"dns_addresses"`
|
||||
DNSPort int `json:"dns_port"`
|
||||
HTTPPort int `json:"http_port"`
|
||||
IsProtectionEnabled bool `json:"protection_enabled"`
|
||||
// TODO(e.burkov): Inspect if front-end doesn't requires this field as
|
||||
// openapi.yaml declares.
|
||||
IsDHCPAvailable bool `json:"dhcp_available"`
|
||||
IsRunning bool `json:"running"`
|
||||
Version string `json:"version"`
|
||||
Language string `json:"language"`
|
||||
IsDHCPAvailable bool `json:"dhcp_available"`
|
||||
IsRunning bool `json:"running"`
|
||||
}
|
||||
|
||||
func handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -125,12 +126,12 @@ func handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||
defer config.RUnlock()
|
||||
|
||||
resp = statusResponse{
|
||||
Version: version.Version(),
|
||||
DNSAddrs: dnsAddrs,
|
||||
DNSPort: config.DNS.Port,
|
||||
HTTPPort: config.BindPort,
|
||||
IsRunning: isRunning(),
|
||||
Version: version.Version(),
|
||||
Language: config.Language,
|
||||
IsRunning: isRunning(),
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -146,13 +147,7 @@ func handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||
resp.IsDHCPAvailable = Context.dhcpServer != nil
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w).Encode(resp)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
_ = aghhttp.WriteJSONResponse(w, r, resp)
|
||||
}
|
||||
|
||||
type profileJSON struct {
|
||||
@@ -162,13 +157,16 @@ type profileJSON struct {
|
||||
func handleGetProfile(w http.ResponseWriter, r *http.Request) {
|
||||
pj := profileJSON{}
|
||||
u := Context.auth.getCurrentUser(r)
|
||||
|
||||
pj.Name = u.Name
|
||||
|
||||
data, err := json.Marshal(pj)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "json.Marshal: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
_, _ = w.Write(data)
|
||||
}
|
||||
|
||||
@@ -199,19 +197,29 @@ func httpRegister(method, url string, handler http.HandlerFunc) {
|
||||
Context.mux.Handle(url, postInstallHandler(optionalAuthHandler(gziphandler.GzipHandler(ensureHandler(method, handler)))))
|
||||
}
|
||||
|
||||
// ----------------------------------
|
||||
// helper functions for HTTP handlers
|
||||
// ----------------------------------
|
||||
func ensure(method string, handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
|
||||
// ensure returns a wrapped handler that makes sure that the request has the
|
||||
// correct method as well as additional method and header checks.
|
||||
func ensure(
|
||||
method string,
|
||||
handler func(http.ResponseWriter, *http.Request),
|
||||
) (wrapped func(http.ResponseWriter, *http.Request)) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
log.Debug("%s %v", r.Method, r.URL)
|
||||
start := time.Now()
|
||||
m, u := r.Method, r.URL
|
||||
log.Debug("started %s %s %s", m, r.Host, u)
|
||||
defer func() { log.Debug("finished %s %s %s in %s", m, r.Host, u, time.Since(start)) }()
|
||||
|
||||
if m != method {
|
||||
aghhttp.Error(r, w, http.StatusMethodNotAllowed, "only method %s is allowed", method)
|
||||
|
||||
if r.Method != method {
|
||||
http.Error(w, "This request must be "+method, http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
if method == http.MethodPost || method == http.MethodPut || method == http.MethodDelete {
|
||||
if modifiesData(m) {
|
||||
if !ensureContentType(w, r) {
|
||||
return
|
||||
}
|
||||
|
||||
Context.controlLock.Lock()
|
||||
defer Context.controlLock.Unlock()
|
||||
}
|
||||
@@ -220,6 +228,42 @@ func ensure(method string, handler func(http.ResponseWriter, *http.Request)) fun
|
||||
}
|
||||
}
|
||||
|
||||
// modifiesData returns true if m is an HTTP method that can modify data.
|
||||
func modifiesData(m string) (ok bool) {
|
||||
return m == http.MethodPost || m == http.MethodPut || m == http.MethodDelete
|
||||
}
|
||||
|
||||
// ensureContentType makes sure that the content type of a data-modifying
|
||||
// request is set correctly. If it is not, ensureContentType writes a response
|
||||
// to w, and ok is false.
|
||||
func ensureContentType(w http.ResponseWriter, r *http.Request) (ok bool) {
|
||||
const statusUnsup = http.StatusUnsupportedMediaType
|
||||
|
||||
cType := r.Header.Get(aghhttp.HdrNameContentType)
|
||||
if r.ContentLength == 0 {
|
||||
if cType == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Assume that browsers always send a content type when submitting HTML
|
||||
// forms and require no content type for requests with no body to make
|
||||
// sure that the request comes from JavaScript.
|
||||
aghhttp.Error(r, w, statusUnsup, "empty body with content-type %q not allowed", cType)
|
||||
|
||||
return false
|
||||
|
||||
}
|
||||
|
||||
const wantCType = aghhttp.HdrValApplicationJSON
|
||||
if cType == wantCType {
|
||||
return true
|
||||
}
|
||||
|
||||
aghhttp.Error(r, w, statusUnsup, "only content-type %s is allowed", wantCType)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func ensurePOST(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
|
||||
return ensure(http.MethodPost, handler)
|
||||
}
|
||||
@@ -291,7 +335,7 @@ func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (ok bool) {
|
||||
}
|
||||
|
||||
httpsURL := &url.URL{
|
||||
Scheme: schemeHTTPS,
|
||||
Scheme: aghhttp.SchemeHTTPS,
|
||||
Host: hostPort,
|
||||
Path: r.URL.Path,
|
||||
RawQuery: r.URL.RawQuery,
|
||||
@@ -307,7 +351,7 @@ func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (ok bool) {
|
||||
//
|
||||
// See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin.
|
||||
originURL := &url.URL{
|
||||
Scheme: schemeHTTP,
|
||||
Scheme: aghhttp.SchemeHTTP,
|
||||
Host: r.Host,
|
||||
}
|
||||
w.Header().Set("Access-Control-Allow-Origin", originURL.String())
|
||||
|
||||
@@ -59,19 +59,7 @@ func (web *Web) handleInstallGetAddresses(w http.ResponseWriter, r *http.Request
|
||||
data.Interfaces[iface.Name] = iface
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w).Encode(data)
|
||||
if err != nil {
|
||||
aghhttp.Error(
|
||||
r,
|
||||
w,
|
||||
http.StatusInternalServerError,
|
||||
"Unable to marshal default addresses to json: %s",
|
||||
err,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
_ = aghhttp.WriteJSONResponse(w, r, data)
|
||||
}
|
||||
|
||||
type checkConfReqEnt struct {
|
||||
@@ -201,13 +189,7 @@ func (web *Web) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request)
|
||||
resp.StaticIP = handleStaticIP(req.DNS.IP, req.SetStaticIP)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w).Encode(resp)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "encoding the response: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
_ = aghhttp.WriteJSONResponse(w, r, resp)
|
||||
}
|
||||
|
||||
// handleStaticIP - handles static IP request
|
||||
@@ -424,7 +406,7 @@ func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
u := &User{
|
||||
u := &webUser{
|
||||
Name: req.Username,
|
||||
}
|
||||
Context.auth.UserAdd(u, req.Password)
|
||||
@@ -688,19 +670,7 @@ func (web *Web) handleInstallGetAddressesBeta(w http.ResponseWriter, r *http.Req
|
||||
|
||||
data.Interfaces = ifaces
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w).Encode(data)
|
||||
if err != nil {
|
||||
aghhttp.Error(
|
||||
r,
|
||||
w,
|
||||
http.StatusInternalServerError,
|
||||
"Unable to marshal default addresses to json: %s",
|
||||
err,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
_ = aghhttp.WriteJSONResponse(w, r, data)
|
||||
}
|
||||
|
||||
// registerBetaInstallHandlers registers the install handlers for new client
|
||||
|
||||
@@ -28,8 +28,6 @@ type temporaryError interface {
|
||||
|
||||
// Get the latest available version from the Internet
|
||||
func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
resp := &versionResponse{}
|
||||
if Context.disableUpdate {
|
||||
resp.Disabled = true
|
||||
@@ -71,10 +69,7 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
err = json.NewEncoder(w).Encode(resp)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "writing body: %s", err)
|
||||
}
|
||||
_ = aghhttp.WriteJSONResponse(w, r, resp)
|
||||
}
|
||||
|
||||
// requestVersionInfo sets the VersionInfo field of resp if it can reach the
|
||||
|
||||
@@ -31,7 +31,10 @@ const (
|
||||
|
||||
// Called by other modules when configuration is changed
|
||||
func onConfigModified() {
|
||||
_ = config.write()
|
||||
err := config.write()
|
||||
if err != nil {
|
||||
log.Error("writing config: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// initDNSServer creates an instance of the dnsforward.Server
|
||||
@@ -71,11 +74,11 @@ func initDNSServer() (err error) {
|
||||
}
|
||||
Context.queryLog = querylog.New(conf)
|
||||
|
||||
filterConf := config.DNS.DnsfilterConf
|
||||
filterConf.EtcHosts = Context.etcHosts
|
||||
filterConf.ConfigModified = onConfigModified
|
||||
filterConf.HTTPRegister = httpRegister
|
||||
Context.dnsFilter = filtering.New(&filterConf, nil)
|
||||
Context.filters, err = filtering.New(config.DNS.DnsfilterConf, nil)
|
||||
if err != nil {
|
||||
// Don't wrap the error, since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
var privateNets netutil.SubnetSet
|
||||
switch len(config.DNS.PrivateNets) {
|
||||
@@ -83,13 +86,10 @@ func initDNSServer() (err error) {
|
||||
// Use an optimized locally-served matcher.
|
||||
privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed)
|
||||
case 1:
|
||||
var n *net.IPNet
|
||||
n, err = netutil.ParseSubnet(config.DNS.PrivateNets[0])
|
||||
privateNets, err = netutil.ParseSubnet(config.DNS.PrivateNets[0])
|
||||
if err != nil {
|
||||
return fmt.Errorf("preparing the set of private subnets: %w", err)
|
||||
}
|
||||
|
||||
privateNets = n
|
||||
default:
|
||||
var nets []*net.IPNet
|
||||
nets, err = netutil.ParseSubnets(config.DNS.PrivateNets...)
|
||||
@@ -101,15 +101,13 @@ func initDNSServer() (err error) {
|
||||
}
|
||||
|
||||
p := dnsforward.DNSCreateParams{
|
||||
DNSFilter: Context.dnsFilter,
|
||||
DNSFilter: Context.filters,
|
||||
Stats: Context.stats,
|
||||
QueryLog: Context.queryLog,
|
||||
PrivateNets: privateNets,
|
||||
Anonymizer: anonymizer,
|
||||
LocalDomain: config.DHCP.LocalDomainName,
|
||||
}
|
||||
if Context.dhcpServer != nil {
|
||||
p.DHCPServer = Context.dhcpServer
|
||||
DHCPServer: Context.dhcpServer,
|
||||
}
|
||||
|
||||
Context.dnsServer, err = dnsforward.NewServer(p)
|
||||
@@ -143,7 +141,6 @@ func initDNSServer() (err error) {
|
||||
Context.whois = initWHOIS(&Context.clients)
|
||||
}
|
||||
|
||||
Context.filters.Init()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -335,9 +332,12 @@ func getDNSEncryption() (de dnsEncryption) {
|
||||
// applyAdditionalFiltering adds additional client information and settings if
|
||||
// the client has them.
|
||||
func applyAdditionalFiltering(clientIP net.IP, clientID string, setts *filtering.Settings) {
|
||||
Context.dnsFilter.ApplyBlockedServices(setts, nil, true)
|
||||
// pref is a prefix for logging messages around the scope.
|
||||
const pref = "applying filters"
|
||||
|
||||
log.Debug("looking up settings for client with ip %s and clientid %q", clientIP, clientID)
|
||||
Context.filters.ApplyBlockedServices(setts, nil)
|
||||
|
||||
log.Debug("%s: looking for client with ip %s and clientid %q", pref, clientIP, clientID)
|
||||
|
||||
if clientIP == nil {
|
||||
return
|
||||
@@ -349,16 +349,16 @@ func applyAdditionalFiltering(clientIP net.IP, clientID string, setts *filtering
|
||||
if !ok {
|
||||
c, ok = Context.clients.Find(clientIP.String())
|
||||
if !ok {
|
||||
log.Debug("client with ip %s and clientid %q not found", clientIP, clientID)
|
||||
log.Debug("%s: no clients with ip %s and clientid %q", pref, clientIP, clientID)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug("using settings for client %q with ip %s and clientid %q", c.Name, clientIP, clientID)
|
||||
log.Debug("%s: using settings for client %q (%s; %q)", pref, c.Name, clientIP, clientID)
|
||||
|
||||
if c.UseOwnBlockedServices {
|
||||
Context.dnsFilter.ApplyBlockedServices(setts, c.BlockedServices, false)
|
||||
Context.filters.ApplyBlockedServices(setts, c.BlockedServices)
|
||||
}
|
||||
|
||||
setts.ClientName = c.Name
|
||||
@@ -381,7 +381,7 @@ func startDNSServer() error {
|
||||
return fmt.Errorf("unable to start forwarding DNS server: Already running")
|
||||
}
|
||||
|
||||
enableFiltersLocked(false)
|
||||
Context.filters.EnableFilters(false)
|
||||
|
||||
Context.clients.Start()
|
||||
|
||||
@@ -390,7 +390,6 @@ func startDNSServer() error {
|
||||
return fmt.Errorf("couldn't start forwarding DNS server: %w", err)
|
||||
}
|
||||
|
||||
Context.dnsFilter.Start()
|
||||
Context.filters.Start()
|
||||
Context.stats.Start()
|
||||
Context.queryLog.Start()
|
||||
@@ -449,10 +448,7 @@ func closeDNSServer() {
|
||||
Context.dnsServer = nil
|
||||
}
|
||||
|
||||
if Context.dnsFilter != nil {
|
||||
Context.dnsFilter.Close()
|
||||
Context.dnsFilter = nil
|
||||
}
|
||||
Context.filters.Close()
|
||||
|
||||
if Context.stats != nil {
|
||||
err := Context.stats.Close()
|
||||
@@ -469,7 +465,5 @@ func closeDNSServer() {
|
||||
Context.queryLog = nil
|
||||
}
|
||||
|
||||
Context.filters.Close()
|
||||
|
||||
log.Debug("Closed all DNS modules")
|
||||
log.Debug("all dns modules are closed")
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
|
||||
@@ -33,6 +34,7 @@ import (
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"golang.org/x/exp/slices"
|
||||
"gopkg.in/natefinch/lumberjack.v2"
|
||||
)
|
||||
|
||||
@@ -52,10 +54,9 @@ type homeContext struct {
|
||||
dnsServer *dnsforward.Server // DNS module
|
||||
rdns *RDNS // rDNS module
|
||||
whois *WHOIS // WHOIS module
|
||||
dnsFilter *filtering.DNSFilter // DNS filtering module
|
||||
dhcpServer *dhcpd.Server // DHCP module
|
||||
dhcpServer dhcpd.Interface // DHCP module
|
||||
auth *Auth // HTTP authentication module
|
||||
filters Filtering // DNS filtering module
|
||||
filters *filtering.DNSFilter // DNS filtering module
|
||||
web *Web // Web (HTTP, HTTPS) module
|
||||
tls *TLSMod // TLS module
|
||||
// etcHosts is an IP-hostname pairs set taken from system configuration
|
||||
@@ -140,7 +141,12 @@ func setupContext(args options) {
|
||||
checkPermissions()
|
||||
}
|
||||
|
||||
initConfig()
|
||||
switch version.Channel() {
|
||||
case version.ChannelEdge, version.ChannelDevelopment:
|
||||
config.BetaBindPort = 3001
|
||||
default:
|
||||
// Go on.
|
||||
}
|
||||
|
||||
Context.tlsRoots = LoadSystemRootCAs()
|
||||
Context.transport = &http.Transport{
|
||||
@@ -265,6 +271,15 @@ func setupHostsContainer() (err error) {
|
||||
}
|
||||
|
||||
func setupConfig(args options) (err error) {
|
||||
config.DNS.DnsfilterConf.EtcHosts = Context.etcHosts
|
||||
config.DNS.DnsfilterConf.ConfigModified = onConfigModified
|
||||
config.DNS.DnsfilterConf.HTTPRegister = httpRegister
|
||||
config.DNS.DnsfilterConf.DataDir = Context.getDataDir()
|
||||
config.DNS.DnsfilterConf.Filters = slices.Clone(config.Filters)
|
||||
config.DNS.DnsfilterConf.WhitelistFilters = slices.Clone(config.WhitelistFilters)
|
||||
config.DNS.DnsfilterConf.UserRules = slices.Clone(config.UserRules)
|
||||
config.DNS.DnsfilterConf.HTTPClient = Context.client
|
||||
|
||||
config.DHCP.WorkDir = Context.workDir
|
||||
config.DHCP.HTTPRegister = httpRegister
|
||||
config.DHCP.ConfigModified = onConfigModified
|
||||
@@ -384,8 +399,6 @@ func fatalOnError(err error) {
|
||||
|
||||
// run configures and starts AdGuard Home.
|
||||
func run(args options, clientBuildFS fs.FS) {
|
||||
var err error
|
||||
|
||||
// configure config filename
|
||||
initConfigFilename(args)
|
||||
|
||||
@@ -396,7 +409,7 @@ func run(args options, clientBuildFS fs.FS) {
|
||||
configureLogger(args)
|
||||
|
||||
// Print the first message after logger is configured.
|
||||
log.Println(version.Full())
|
||||
log.Info(version.Full())
|
||||
log.Debug("current working directory is %s", Context.workDir)
|
||||
if args.runningAsService {
|
||||
log.Info("AdGuard Home is running as a service")
|
||||
@@ -404,7 +417,7 @@ func run(args options, clientBuildFS fs.FS) {
|
||||
|
||||
setupContext(args)
|
||||
|
||||
err = configureOS(config)
|
||||
err := configureOS(config)
|
||||
fatalOnError(err)
|
||||
|
||||
// clients package uses filtering package's static data (filtering.BlockedSvcKnown()),
|
||||
@@ -442,9 +455,9 @@ func run(args options, clientBuildFS fs.FS) {
|
||||
|
||||
sessFilename := filepath.Join(Context.getDataDir(), "sessions.db")
|
||||
GLMode = args.glinetMode
|
||||
var arl *authRateLimiter
|
||||
var rateLimiter *authRateLimiter
|
||||
if config.AuthAttempts > 0 && config.AuthBlockMin > 0 {
|
||||
arl = newAuthRateLimiter(
|
||||
rateLimiter = newAuthRateLimiter(
|
||||
time.Duration(config.AuthBlockMin)*time.Minute,
|
||||
config.AuthAttempts,
|
||||
)
|
||||
@@ -456,7 +469,7 @@ func run(args options, clientBuildFS fs.FS) {
|
||||
sessFilename,
|
||||
config.Users,
|
||||
config.WebSessionTTLHours*60*60,
|
||||
arl,
|
||||
rateLimiter,
|
||||
)
|
||||
if Context.auth == nil {
|
||||
log.Fatalf("Couldn't initialize Auth module")
|
||||
@@ -641,14 +654,9 @@ func configureLogger(args options) {
|
||||
log.Fatalf("cannot initialize syslog: %s", err)
|
||||
}
|
||||
} else {
|
||||
logFilePath := filepath.Join(Context.workDir, ls.File)
|
||||
if filepath.IsAbs(ls.File) {
|
||||
logFilePath = ls.File
|
||||
}
|
||||
|
||||
_, err := os.OpenFile(logFilePath, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0o644)
|
||||
if err != nil {
|
||||
log.Fatalf("cannot create a log file: %s", err)
|
||||
logFilePath := ls.File
|
||||
if !filepath.IsAbs(logFilePath) {
|
||||
logFilePath = filepath.Join(Context.workDir, logFilePath)
|
||||
}
|
||||
|
||||
log.SetOutput(&lumberjack.Logger{
|
||||
@@ -768,12 +776,12 @@ func printHTTPAddresses(proto string) {
|
||||
}
|
||||
|
||||
port := config.BindPort
|
||||
if proto == schemeHTTPS {
|
||||
if proto == aghhttp.SchemeHTTPS {
|
||||
port = tlsConf.PortHTTPS
|
||||
}
|
||||
|
||||
// TODO(e.burkov): Inspect and perhaps merge with the previous condition.
|
||||
if proto == schemeHTTPS && tlsConf.ServerName != "" {
|
||||
if proto == aghhttp.SchemeHTTPS && tlsConf.ServerName != "" {
|
||||
printWebAddrs(proto, tlsConf.ServerName, tlsConf.PortHTTPS, 0)
|
||||
|
||||
return
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
@@ -51,43 +49,35 @@ var allowedLanguages = stringutil.NewSet(
|
||||
"zh-tw",
|
||||
)
|
||||
|
||||
func handleI18nCurrentLanguage(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
log.Printf("config.Language is %s", config.Language)
|
||||
_, err := fmt.Fprintf(w, "%s\n", config.Language)
|
||||
if err != nil {
|
||||
msg := fmt.Sprintf("Unable to write response json: %s", err)
|
||||
log.Println(msg)
|
||||
http.Error(w, msg, http.StatusInternalServerError)
|
||||
// languageJSON is the JSON structure for language requests and responses.
|
||||
type languageJSON struct {
|
||||
Language string `json:"language"`
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
func handleI18nCurrentLanguage(w http.ResponseWriter, r *http.Request) {
|
||||
log.Printf("home: language is %s", config.Language)
|
||||
|
||||
_ = aghhttp.WriteJSONResponse(w, r, &languageJSON{
|
||||
Language: config.Language,
|
||||
})
|
||||
}
|
||||
|
||||
func handleI18nChangeLanguage(w http.ResponseWriter, r *http.Request) {
|
||||
// This use of ReadAll is safe, because request's body is now limited.
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if aghhttp.WriteTextPlainDeprecated(w, r) {
|
||||
return
|
||||
}
|
||||
|
||||
langReq := &languageJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(langReq)
|
||||
if err != nil {
|
||||
msg := fmt.Sprintf("failed to read request body: %s", err)
|
||||
log.Println(msg)
|
||||
http.Error(w, msg, http.StatusBadRequest)
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "reading req: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
language := strings.TrimSpace(string(body))
|
||||
if language == "" {
|
||||
msg := "empty language specified"
|
||||
log.Println(msg)
|
||||
http.Error(w, msg, http.StatusBadRequest)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if !allowedLanguages.Has(language) {
|
||||
msg := fmt.Sprintf("unknown language specified: %s", language)
|
||||
log.Println(msg)
|
||||
http.Error(w, msg, http.StatusBadRequest)
|
||||
lang := langReq.Language
|
||||
if !allowedLanguages.Has(lang) {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "unknown language: %q", lang)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -96,7 +86,8 @@ func handleI18nChangeLanguage(w http.ResponseWriter, r *http.Request) {
|
||||
config.Lock()
|
||||
defer config.Unlock()
|
||||
|
||||
config.Language = language
|
||||
config.Language = lang
|
||||
log.Printf("home: language is set to %s", lang)
|
||||
}()
|
||||
|
||||
onConfigModified()
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"net/url"
|
||||
"path"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
@@ -82,7 +83,7 @@ func encodeMobileConfig(d *dnsSettings, clientID string) ([]byte, error) {
|
||||
case dnsProtoHTTPS:
|
||||
dspName = fmt.Sprintf("%s DoH", d.ServerName)
|
||||
u := &url.URL{
|
||||
Scheme: schemeHTTPS,
|
||||
Scheme: aghhttp.SchemeHTTPS,
|
||||
Host: d.ServerName,
|
||||
Path: path.Join("/dns-query", clientID),
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
@@ -175,7 +176,8 @@ func handleServiceControlAction(opts options, clientBuildFS fs.FS) {
|
||||
chooseSystem()
|
||||
|
||||
action := opts.serviceControlAction
|
||||
log.Printf("service: control action: %s", action)
|
||||
log.Info(version.Full())
|
||||
log.Info("service: control action: %s", action)
|
||||
|
||||
if action == "reload" {
|
||||
sendSigReload()
|
||||
@@ -277,7 +279,7 @@ AdGuard Home is successfully installed and will automatically start on boot.
|
||||
There are a few more things that must be configured before you can use it.
|
||||
Click on the link below and follow the Installation Wizard steps to finish setup.
|
||||
AdGuard Home is now available at the following addresses:`)
|
||||
printHTTPAddresses(schemeHTTP)
|
||||
printHTTPAddresses(aghhttp.SchemeHTTP)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user