Merge branch 'master' into websvc-confin-manager

This commit is contained in:
Ainar Garipov
2022-09-30 15:27:54 +03:00
162 changed files with 2095 additions and 1755 deletions

View File

@@ -2,6 +2,7 @@
package aghhttp
import (
"encoding/json"
"fmt"
"io"
"net/http"
@@ -10,6 +11,12 @@ import (
"github.com/AdguardTeam/golibs/log"
)
// HTTP scheme constants.
const (
SchemeHTTP = "http"
SchemeHTTPS = "https"
)
// RegisterFunc is the function that sets the handler to handle the URL for the
// method.
//
@@ -26,7 +33,7 @@ func OK(w http.ResponseWriter) {
// Error writes formatted message to w and also logs it.
func Error(r *http.Request, w http.ResponseWriter, code int, format string, args ...any) {
text := fmt.Sprintf(format, args...)
log.Error("%s %s: %s", r.Method, r.URL, text)
log.Error("%s %s %s: %s", r.Method, r.Host, r.URL, text)
http.Error(w, text, code)
}
@@ -35,3 +42,34 @@ func Error(r *http.Request, w http.ResponseWriter, code int, format string, args
func UserAgent() (ua string) {
return fmt.Sprintf("AdGuardHome/%s", version.Version())
}
// textPlainDeprMsg is the message returned to API users when they try to use
// an API that used to accept "text/plain" but doesn't anymore.
const textPlainDeprMsg = `using this api with the text/plain content-type is deprecated; ` +
`use application/json`
// WriteTextPlainDeprecated responds to the request with a message about
// deprecation and removal of a plain-text API if the request is made with the
// "text/plain" content-type.
func WriteTextPlainDeprecated(w http.ResponseWriter, r *http.Request) (isPlainText bool) {
if r.Header.Get(HdrNameContentType) != HdrValTextPlain {
return false
}
Error(r, w, http.StatusUnsupportedMediaType, textPlainDeprMsg)
return true
}
// WriteJSONResponse sets the content-type header in w.Header() to
// "application/json", encodes resp to w, calls Error on any returned error, and
// returns it as well.
func WriteJSONResponse(w http.ResponseWriter, r *http.Request, resp any) (err error) {
w.Header().Set(HdrNameContentType, HdrValApplicationJSON)
err = json.NewEncoder(w).Encode(resp)
if err != nil {
Error(r, w, http.StatusInternalServerError, "encoding resp: %s", err)
}
return err
}

View File

@@ -8,8 +8,8 @@ package aghhttp
const (
HdrNameAcceptEncoding = "Accept-Encoding"
HdrNameAccessControlAllowOrigin = "Access-Control-Allow-Origin"
HdrNameContentType = "Content-Type"
HdrNameContentEncoding = "Content-Encoding"
HdrNameContentType = "Content-Type"
HdrNameServer = "Server"
HdrNameTrailer = "Trailer"
HdrNameUserAgent = "User-Agent"
@@ -18,4 +18,5 @@ const (
// HTTP header value constants.
const (
HdrValApplicationJSON = "application/json"
HdrValTextPlain = "text/plain"
)

View File

@@ -1,5 +1,4 @@
//go:build darwin || freebsd
// +build darwin freebsd
package aghnet

View File

@@ -1,5 +1,4 @@
//go:build darwin || freebsd
// +build darwin freebsd
package aghnet

View File

@@ -1,5 +1,4 @@
//go:build linux
// +build linux
package aghnet

View File

@@ -1,5 +1,4 @@
//go:build linux
// +build linux
package aghnet

View File

@@ -1,5 +1,4 @@
//go:build openbsd
// +build openbsd
package aghnet

View File

@@ -1,5 +1,4 @@
//go:build openbsd
// +build openbsd
package aghnet

View File

@@ -1,5 +1,4 @@
//go:build windows
// +build windows
package aghnet

View File

@@ -1,5 +1,4 @@
//go:build windows
// +build windows
package aghnet

View File

@@ -1,5 +1,4 @@
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
//go:build darwin || freebsd || linux || openbsd
package aghnet

View File

@@ -1,5 +1,4 @@
//go:build windows
// +build windows
package aghnet

View File

@@ -1,5 +1,4 @@
//go:build linux
// +build linux
package aghnet

View File

@@ -1,5 +1,4 @@
//go:build !(windows || linux)
// +build !windows,!linux
package aghnet

View File

@@ -1,5 +1,4 @@
//go:build windows
// +build windows
package aghnet

View File

@@ -1,5 +1,4 @@
//go:build linux
// +build linux
package aghnet

View File

@@ -1,5 +1,4 @@
//go:build aix || darwin || dragonfly || freebsd || netbsd || openbsd || solaris
// +build aix darwin dragonfly freebsd netbsd openbsd solaris
//go:build darwin || freebsd || openbsd
package aghnet

View File

@@ -1,5 +1,4 @@
//go:build windows
// +build windows
package aghnet

View File

@@ -1,5 +1,4 @@
//go:build linux
// +build linux
package aghnet
@@ -19,27 +18,18 @@ import (
// How to test on a real Linux machine:
//
// 1. Run:
// 1. Run "sudo ipset create example_set hash:ip family ipv4".
//
// sudo ipset create example_set hash:ip family ipv4
// 2. Run "sudo ipset list example_set". The Members field should be empty.
//
// 2. Run:
// 3. Add the line "example.com/example_set" to your AdGuardHome.yaml.
//
// sudo ipset list example_set
// 4. Start AdGuardHome.
//
// The Members field should be empty.
// 5. Make requests to example.com and its subdomains.
//
// 3. Add the line "example.com/example_set" to your AdGuardHome.yaml.
//
// 4. Start AdGuardHome.
//
// 5. Make requests to example.com and its subdomains.
//
// 6. Run:
//
// sudo ipset list example_set
//
// The Members field should contain the resolved IP addresses.
// 6. Run "sudo ipset list example_set". The Members field should contain the
// resolved IP addresses.
// newIpsetMgr returns a new Linux ipset manager.
func newIpsetMgr(ipsetConf []string) (set IpsetManager, err error) {

View File

@@ -1,5 +1,4 @@
//go:build linux
// +build linux
package aghnet

View File

@@ -1,5 +1,4 @@
//go:build !linux
// +build !linux
package aghnet

View File

@@ -1,5 +1,4 @@
//go:build darwin || freebsd || openbsd
// +build darwin freebsd openbsd
package aghnet

View File

@@ -1,5 +1,4 @@
//go:build darwin
// +build darwin
package aghnet

View File

@@ -1,5 +1,4 @@
//go:build freebsd
// +build freebsd
package aghnet

View File

@@ -1,5 +1,4 @@
//go:build freebsd
// +build freebsd
package aghnet

View File

@@ -1,5 +1,4 @@
//go:build linux
// +build linux
package aghnet

View File

@@ -1,5 +1,4 @@
//go:build linux
// +build linux
package aghnet

View File

@@ -1,5 +1,4 @@
//go:build openbsd
// +build openbsd
package aghnet

View File

@@ -1,5 +1,4 @@
//go:build openbsd
// +build openbsd
package aghnet

View File

@@ -1,5 +1,4 @@
//go:build openbsd || freebsd || linux || darwin
// +build openbsd freebsd linux darwin
//go:build darwin || freebsd || linux || openbsd
package aghnet

View File

@@ -1,5 +1,4 @@
//go:build windows
// +build windows
package aghnet

View File

@@ -1,5 +1,4 @@
//go:build !windows
// +build !windows
package aghnet

View File

@@ -1,5 +1,4 @@
//go:build !windows
// +build !windows
package aghnet

View File

@@ -1,5 +1,4 @@
//go:build windows
// +build windows
package aghnet

View File

@@ -1,5 +1,4 @@
//go:build windows
// +build windows
package aghnet

View File

@@ -1,5 +1,4 @@
//go:build mips || mips64
// +build mips mips64
// This file is an adapted version of github.com/josharian/native.

View File

@@ -1,5 +1,4 @@
//go:build amd64 || 386 || arm || arm64 || mipsle || mips64le || ppc64le
// +build amd64 386 arm arm64 mipsle mips64le ppc64le
// This file is an adapted version of github.com/josharian/native.

View File

@@ -1,5 +1,4 @@
//go:build darwin || netbsd || openbsd
// +build darwin netbsd openbsd
//go:build darwin || openbsd
package aghos

View File

@@ -1,5 +1,4 @@
//go:build freebsd
// +build freebsd
package aghos

View File

@@ -1,5 +1,4 @@
//go:build linux
// +build linux
package aghos

View File

@@ -1,5 +1,4 @@
//go:build darwin || freebsd || linux || openbsd
// +build darwin freebsd linux openbsd
package aghos

View File

@@ -1,5 +1,4 @@
//go:build windows
// +build windows
package aghos

View File

@@ -1,5 +1,4 @@
//go:build !(windows || plan9)
// +build !windows,!plan9
//go:build !windows
package aghos

View File

@@ -1,5 +1,4 @@
//go:build windows || plan9
// +build windows plan9
//go:build windows
package aghos

View File

@@ -1,5 +1,4 @@
//go:build darwin || freebsd || linux || netbsd || openbsd
// +build darwin freebsd linux netbsd openbsd
//go:build darwin || freebsd || linux || openbsd
package aghos

View File

@@ -1,5 +1,4 @@
//go:build windows
// +build windows
package aghos

View File

@@ -1,5 +1,4 @@
//go:build freebsd || openbsd
// +build freebsd openbsd
package dhcpd
@@ -10,11 +9,10 @@ import (
// broadcast sends resp to the broadcast address specific for network interface.
func (c *dhcpConn) broadcast(respData []byte, peer *net.UDPAddr) (n int, err error) {
// Despite the fact that server4.NewIPv4UDPConn explicitly sets socket
// options to allow broadcasting, it also binds the connection to a
// specific interface. On FreeBSD and OpenBSD net.UDPConn.WriteTo
// causes errors while writing to the addresses that belong to another
// interface. So, use the broadcast address specific for the interface
// bound.
// options to allow broadcasting, it also binds the connection to a specific
// interface. On FreeBSD and OpenBSD net.UDPConn.WriteTo causes errors
// while writing to the addresses that belong to another interface. So, use
// the broadcast address specific for the interface bound.
peer.IP = c.bcastIP
return c.udpConn.WriteTo(respData, peer)

View File

@@ -1,5 +1,4 @@
//go:build freebsd || openbsd
// +build freebsd openbsd
package dhcpd

View File

@@ -1,5 +1,4 @@
//go:build aix || darwin || dragonfly || linux || netbsd || solaris
// +build aix darwin dragonfly linux netbsd solaris
//go:build darwin || linux
package dhcpd

View File

@@ -1,5 +1,4 @@
//go:build aix || darwin || dragonfly || linux || netbsd || solaris
// +build aix darwin dragonfly linux netbsd solaris
//go:build darwin || linux
package dhcpd

View File

@@ -1,10 +1,40 @@
package dhcpd
import (
"fmt"
"net"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
)
// ServerConfig is the configuration for the DHCP server. The order of YAML
// fields is important, since the YAML configuration file follows it.
type ServerConfig struct {
// Called when the configuration is changed by HTTP request
ConfigModified func() `yaml:"-"`
// Register an HTTP handler
HTTPRegister aghhttp.RegisterFunc `yaml:"-"`
Enabled bool `yaml:"enabled"`
InterfaceName string `yaml:"interface_name"`
// LocalDomainName is the domain name used for DHCP hosts. For example,
// a DHCP client with the hostname "myhost" can be addressed as "myhost.lan"
// when LocalDomainName is "lan".
LocalDomainName string `yaml:"local_domain_name"`
Conf4 V4ServerConf `yaml:"dhcpv4"`
Conf6 V6ServerConf `yaml:"dhcpv6"`
WorkDir string `yaml:"-"`
DBFilePath string `yaml:"-"`
}
// DHCPServer - DHCP server interface
type DHCPServer interface {
// ResetLeases resets leases.
@@ -80,6 +110,86 @@ type V4ServerConf struct {
notify func(uint32)
}
// errNilConfig is an error returned by validation method if the config is nil.
const errNilConfig errors.Error = "nil config"
// ensureV4 returns a 4-byte version of ip. An error is returned if the passed
// ip is not an IPv4.
func ensureV4(ip net.IP) (ip4 net.IP, err error) {
if ip == nil {
return nil, fmt.Errorf("%v is not an IP address", ip)
}
ip4 = ip.To4()
if ip4 == nil {
return nil, fmt.Errorf("%v is not an IPv4 address", ip)
}
return ip4, nil
}
// Validate returns an error if c is not a valid configuration.
//
// TODO(e.burkov): Don't set the config fields when the server itself will stop
// containing the config.
func (c *V4ServerConf) Validate() (err error) {
defer func() { err = errors.Annotate(err, "dhcpv4: %w") }()
if c == nil {
return errNilConfig
}
var gatewayIP net.IP
gatewayIP, err = ensureV4(c.GatewayIP)
if err != nil {
// Don't wrap an errors since it's inforative enough as is and there is
// an annotation deferred already.
return err
}
if c.SubnetMask == nil {
return fmt.Errorf("invalid subnet mask: %v", c.SubnetMask)
}
subnetMask := net.IPMask(netutil.CloneIP(c.SubnetMask.To4()))
c.subnet = &net.IPNet{
IP: gatewayIP,
Mask: subnetMask,
}
c.broadcastIP = aghnet.BroadcastFromIPNet(c.subnet)
c.ipRange, err = newIPRange(c.RangeStart, c.RangeEnd)
if err != nil {
// Don't wrap an errors since it's inforative enough as is and there is
// an annotation deferred already.
return err
}
if c.ipRange.contains(gatewayIP) {
return fmt.Errorf("gateway ip %v in the ip range: %v-%v",
gatewayIP,
c.RangeStart,
c.RangeEnd,
)
}
if !c.subnet.Contains(c.RangeStart) {
return fmt.Errorf("range start %v is outside network %v",
c.RangeStart,
c.subnet,
)
}
if !c.subnet.Contains(c.RangeEnd) {
return fmt.Errorf("range end %v is outside network %v",
c.RangeEnd,
c.subnet,
)
}
return nil
}
// V6ServerConf - server configuration
type V6ServerConf struct {
Enabled bool `yaml:"-" json:"-"`

View File

@@ -1,5 +1,4 @@
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
//go:build darwin || freebsd || linux || openbsd
package dhcpd

View File

@@ -1,5 +1,4 @@
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
//go:build darwin || freebsd || linux || openbsd
package dhcpd

View File

@@ -32,7 +32,7 @@ func normalizeIP(ip net.IP) net.IP {
}
// Load lease table from DB
func (s *Server) dbLoad() (err error) {
func (s *server) dbLoad() (err error) {
dynLeases := []*Lease{}
staticLeases := []*Lease{}
v6StaticLeases := []*Lease{}
@@ -132,7 +132,7 @@ func normalizeLeases(staticLeases, dynLeases []*Lease) []*Lease {
}
// Store lease table in DB
func (s *Server) dbStore() (err error) {
func (s *server) dbStore() (err error) {
// Use an empty slice here as opposed to nil so that it doesn't write
// "null" into the database file if leases are empty.
leases := []leaseJSON{}

View File

@@ -6,12 +6,11 @@ import (
"fmt"
"net"
"path/filepath"
"runtime"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/timeutil"
)
const (
@@ -21,9 +20,19 @@ const (
// TODO(e.burkov): Remove it when static leases determining mechanism
// will be improved.
leaseExpireStatic = 1
// DefaultDHCPLeaseTTL is the default time-to-live for leases.
DefaultDHCPLeaseTTL = uint32(timeutil.Day / time.Second)
// DefaultDHCPTimeoutICMP is the default timeout for waiting ICMP responses.
DefaultDHCPTimeoutICMP = 1000
)
var webHandlersRegistered = false
// Currently used defaults for ifaceDNSAddrs.
const (
defaultMaxAttempts int = 10
defaultBackoff time.Duration = 500 * time.Millisecond
)
// Lease contains the necessary information about a DHCP lease
type Lease struct {
@@ -119,30 +128,6 @@ func (l *Lease) UnmarshalJSON(data []byte) (err error) {
return nil
}
// ServerConfig is the configuration for the DHCP server. The order of YAML
// fields is important, since the YAML configuration file follows it.
type ServerConfig struct {
// Called when the configuration is changed by HTTP request
ConfigModified func() `yaml:"-"`
// Register an HTTP handler
HTTPRegister aghhttp.RegisterFunc `yaml:"-"`
Enabled bool `yaml:"enabled"`
InterfaceName string `yaml:"interface_name"`
// LocalDomainName is the domain name used for DHCP hosts. For example,
// a DHCP client with the hostname "myhost" can be addressed as "myhost.lan"
// when LocalDomainName is "lan".
LocalDomainName string `yaml:"local_domain_name"`
Conf4 V4ServerConf `yaml:"dhcpv4"`
Conf6 V6ServerConf `yaml:"dhcpv6"`
WorkDir string `yaml:"-"`
DBFilePath string `yaml:"-"`
}
// OnLeaseChangedT is a callback for lease changes.
type OnLeaseChangedT func(flags int)
@@ -156,8 +141,68 @@ const (
LeaseChangedDBStore
)
// Server - the current state of the DHCP server
type Server struct {
// GetLeasesFlags are the flags for GetLeases.
type GetLeasesFlags uint8
// GetLeasesFlags values
const (
LeasesDynamic GetLeasesFlags = 0b01
LeasesStatic GetLeasesFlags = 0b10
LeasesAll = LeasesDynamic | LeasesStatic
)
// Interface is the DHCP server that deals with both IP address families.
type Interface interface {
Start() (err error)
Stop() (err error)
Enabled() (ok bool)
Leases(flags GetLeasesFlags) (leases []*Lease)
SetOnLeaseChanged(onLeaseChanged OnLeaseChangedT)
FindMACbyIP(ip net.IP) (mac net.HardwareAddr)
WriteDiskConfig(c *ServerConfig)
}
// MockInterface is a mock Interface implementation.
//
// TODO(e.burkov): Move to aghtest when the API stabilized.
type MockInterface struct {
OnStart func() (err error)
OnStop func() (err error)
OnEnabled func() (ok bool)
OnLeases func(flags GetLeasesFlags) (leases []*Lease)
OnSetOnLeaseChanged func(f OnLeaseChangedT)
OnFindMACbyIP func(ip net.IP) (mac net.HardwareAddr)
OnWriteDiskConfig func(c *ServerConfig)
}
var _ Interface = (*MockInterface)(nil)
// Start implements the Interface for *MockInterface.
func (s *MockInterface) Start() (err error) { return s.OnStart() }
// Stop implements the Interface for *MockInterface.
func (s *MockInterface) Stop() (err error) { return s.OnStop() }
// Enabled implements the Interface for *MockInterface.
func (s *MockInterface) Enabled() (ok bool) { return s.OnEnabled() }
// Leases implements the Interface for *MockInterface.
func (s *MockInterface) Leases(flags GetLeasesFlags) (ls []*Lease) { return s.OnLeases(flags) }
// SetOnLeaseChanged implements the Interface for *MockInterface.
func (s *MockInterface) SetOnLeaseChanged(f OnLeaseChangedT) { s.OnSetOnLeaseChanged(f) }
// FindMACbyIP implements the Interface for *MockInterface.
func (s *MockInterface) FindMACbyIP(ip net.IP) (mac net.HardwareAddr) { return s.OnFindMACbyIP(ip) }
// WriteDiskConfig implements the Interface for *MockInterface.
func (s *MockInterface) WriteDiskConfig(c *ServerConfig) { s.OnWriteDiskConfig(c) }
// server is the DHCP service that handles DHCPv4, DHCPv6, and HTTP API.
type server struct {
srv4 DHCPServer
srv6 DHCPServer
@@ -169,27 +214,15 @@ type Server struct {
onLeaseChanged []OnLeaseChangedT
}
// GetLeasesFlags are the flags for GetLeases.
type GetLeasesFlags uint8
// type check
var _ Interface = (*server)(nil)
// GetLeasesFlags values
const (
LeasesDynamic GetLeasesFlags = 0b0001
LeasesStatic GetLeasesFlags = 0b0010
LeasesAll = LeasesDynamic | LeasesStatic
)
// ServerInterface is an interface for servers.
type ServerInterface interface {
Enabled() (ok bool)
Leases(flags GetLeasesFlags) (leases []*Lease)
SetOnLeaseChanged(onLeaseChanged OnLeaseChangedT)
}
// Create - create object
func Create(conf *ServerConfig) (s *Server, err error) {
s = &Server{
// Create initializes and returns the DHCP server handling both address
// families. It also registers the corresponding HTTP API endpoints.
//
// TODO(e.burkov): Don't register handlers, see TODO on [aghhttp.RegisterFunc].
func Create(conf *ServerConfig) (s *server, err error) {
s = &server{
conf: &ServerConfig{
ConfigModified: conf.ConfigModified,
@@ -204,35 +237,20 @@ func Create(conf *ServerConfig) (s *Server, err error) {
},
}
if !webHandlersRegistered && s.conf.HTTPRegister != nil {
if runtime.GOOS == "windows" {
// Our DHCP server doesn't work on Windows yet, so
// signal that to the front with an HTTP 501.
//
// TODO(a.garipov): This needs refactoring. We
// shouldn't even try and initialize a DHCP server on
// Windows, but there are currently too many
// interconnected parts--such as HTTP handlers and
// frontend--to make that work properly.
s.registerNotImplementedHandlers()
} else {
s.registerHandlers()
}
webHandlersRegistered = true
}
s.registerHandlers()
v4conf := conf.Conf4
v4conf.Enabled = s.conf.Enabled
if len(v4conf.RangeStart) == 0 {
v4conf.Enabled = false
}
v4conf.InterfaceName = s.conf.InterfaceName
v4conf.notify = s.onNotify
s.srv4, err = v4Create(v4conf)
v4conf.Enabled = s.conf.Enabled && len(v4conf.RangeStart) != 0
s.srv4, err = v4Create(&v4conf)
if err != nil {
return nil, fmt.Errorf("creating dhcpv4 srv: %w", err)
if v4conf.Enabled {
return nil, fmt.Errorf("creating dhcpv4 srv: %w", err)
}
log.Error("creating dhcpv4 srv: %s", err)
}
v6conf := conf.Conf6
@@ -265,12 +283,12 @@ func Create(conf *ServerConfig) (s *Server, err error) {
}
// Enabled returns true when the server is enabled.
func (s *Server) Enabled() (ok bool) {
func (s *server) Enabled() (ok bool) {
return s.conf.Enabled
}
// resetLeases resets all leases in the lease database.
func (s *Server) resetLeases() (err error) {
func (s *server) resetLeases() (err error) {
err = s.srv4.ResetLeases(nil)
if err != nil {
return err
@@ -287,7 +305,7 @@ func (s *Server) resetLeases() (err error) {
}
// server calls this function after DB is updated
func (s *Server) onNotify(flags uint32) {
func (s *server) onNotify(flags uint32) {
if flags == LeaseChangedDBStore {
err := s.dbStore()
if err != nil {
@@ -301,31 +319,28 @@ func (s *Server) onNotify(flags uint32) {
}
// SetOnLeaseChanged - set callback
func (s *Server) SetOnLeaseChanged(onLeaseChanged OnLeaseChangedT) {
func (s *server) SetOnLeaseChanged(onLeaseChanged OnLeaseChangedT) {
s.onLeaseChanged = append(s.onLeaseChanged, onLeaseChanged)
}
func (s *Server) notify(flags int) {
if len(s.onLeaseChanged) == 0 {
return
}
func (s *server) notify(flags int) {
for _, f := range s.onLeaseChanged {
f(flags)
}
}
// WriteDiskConfig - write configuration
func (s *Server) WriteDiskConfig(c *ServerConfig) {
func (s *server) WriteDiskConfig(c *ServerConfig) {
c.Enabled = s.conf.Enabled
c.InterfaceName = s.conf.InterfaceName
c.LocalDomainName = s.conf.LocalDomainName
s.srv4.WriteDiskConfig4(&c.Conf4)
s.srv6.WriteDiskConfig6(&c.Conf6)
}
// Start will listen on port 67 and serve DHCP requests.
func (s *Server) Start() (err error) {
func (s *server) Start() (err error) {
err = s.srv4.Start()
if err != nil {
return err
@@ -340,7 +355,7 @@ func (s *Server) Start() (err error) {
}
// Stop closes the listening UDP socket
func (s *Server) Stop() (err error) {
func (s *server) Stop() (err error) {
err = s.srv4.Stop()
if err != nil {
return err
@@ -356,12 +371,12 @@ func (s *Server) Stop() (err error) {
// Leases returns the list of active IPv4 and IPv6 DHCP leases. It's safe for
// concurrent use.
func (s *Server) Leases(flags GetLeasesFlags) (leases []*Lease) {
func (s *server) Leases(flags GetLeasesFlags) (leases []*Lease) {
return append(s.srv4.GetLeases(flags), s.srv6.GetLeases(flags)...)
}
// FindMACbyIP - find a MAC address by IP address in the currently active DHCP leases
func (s *Server) FindMACbyIP(ip net.IP) net.HardwareAddr {
func (s *server) FindMACbyIP(ip net.IP) net.HardwareAddr {
if ip.To4() != nil {
return s.srv4.FindMACbyIP(ip)
}
@@ -369,6 +384,6 @@ func (s *Server) FindMACbyIP(ip net.IP) net.HardwareAddr {
}
// AddStaticLease - add static v4 lease
func (s *Server) AddStaticLease(l *Lease) error {
func (s *server) AddStaticLease(l *Lease) error {
return s.srv4.AddStaticLease(l)
}

View File

@@ -1,5 +1,4 @@
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
//go:build darwin || freebsd || linux || openbsd
package dhcpd
@@ -26,13 +25,13 @@ func testNotify(flags uint32) {
// Leases database store/load.
func TestDB(t *testing.T) {
var err error
s := Server{
s := server{
conf: &ServerConfig{
DBFilePath: dbFilename,
},
}
s.srv4, err = v4Create(V4ServerConf{
s.srv4, err = v4Create(&V4ServerConf{
Enabled: true,
RangeStart: net.IP{192, 168, 10, 100},
RangeEnd: net.IP{192, 168, 10, 200},
@@ -88,32 +87,6 @@ func TestDB(t *testing.T) {
assert.Equal(t, leases[0].Expiry.Unix(), ll[1].Expiry.Unix())
}
func TestIsValidSubnetMask(t *testing.T) {
testCases := []struct {
mask net.IP
want bool
}{{
mask: net.IP{255, 255, 255, 0},
want: true,
}, {
mask: net.IP{255, 255, 254, 0},
want: true,
}, {
mask: net.IP{255, 255, 252, 0},
want: true,
}, {
mask: net.IP{255, 255, 253, 0},
}, {
mask: net.IP{255, 255, 255, 1},
}}
for _, tc := range testCases {
t.Run(tc.mask.String(), func(t *testing.T) {
assert.Equal(t, tc.want, isValidSubnetMask(tc.mask))
})
}
}
func TestNormalizeLeases(t *testing.T) {
dynLeases := []*Lease{{
HWAddr: net.HardwareAddr{1, 2, 3, 4},
@@ -174,7 +147,7 @@ func TestV4Server_badRange(t *testing.T) {
notify: testNotify,
}
_, err := v4Create(conf)
_, err := v4Create(&conf)
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
})
}

View File

@@ -1,36 +0,0 @@
package dhcpd
import (
"encoding/binary"
"fmt"
"net"
)
func tryTo4(ip net.IP) (ip4 net.IP, err error) {
if ip == nil {
return nil, fmt.Errorf("%v is not an IP address", ip)
}
ip4 = ip.To4()
if ip4 == nil {
return nil, fmt.Errorf("%v is not an IPv4 address", ip)
}
return ip4, nil
}
// Return TRUE if subnet mask is correct (e.g. 255.255.255.0)
func isValidSubnetMask(mask net.IP) bool {
var n uint32
n = binary.BigEndian.Uint32(mask)
for i := 0; i != 32; i++ {
if n == 0 {
break
}
if (n & 0x80000000) == 0 {
return false
}
n <<= 1
}
return true
}

View File

@@ -1,21 +1,19 @@
//go:build darwin || freebsd || linux || openbsd
package dhcpd
import (
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"os"
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/timeutil"
)
type v4ServerConfJSON struct {
@@ -26,12 +24,12 @@ type v4ServerConfJSON struct {
LeaseDuration uint32 `json:"lease_duration"`
}
func v4JSONToServerConf(j *v4ServerConfJSON) V4ServerConf {
func (j *v4ServerConfJSON) toServerConf() *V4ServerConf {
if j == nil {
return V4ServerConf{}
return &V4ServerConf{}
}
return V4ServerConf{
return &V4ServerConf{
GatewayIP: j.GatewayIP,
SubnetMask: j.SubnetMask,
RangeStart: j.RangeStart,
@@ -66,7 +64,7 @@ type dhcpStatusResponse struct {
Enabled bool `json:"enabled"`
}
func (s *Server) handleDHCPStatus(w http.ResponseWriter, r *http.Request) {
func (s *server) handleDHCPStatus(w http.ResponseWriter, r *http.Request) {
status := &dhcpStatusResponse{
Enabled: s.conf.Enabled,
IfaceName: s.conf.InterfaceName,
@@ -81,6 +79,7 @@ func (s *Server) handleDHCPStatus(w http.ResponseWriter, r *http.Request) {
status.StaticLeases = s.Leases(LeasesStatic)
w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(status)
if err != nil {
aghhttp.Error(
@@ -93,28 +92,26 @@ func (s *Server) handleDHCPStatus(w http.ResponseWriter, r *http.Request) {
}
}
func (s *Server) enableDHCP(ifaceName string) (code int, err error) {
func (s *server) enableDHCP(ifaceName string) (code int, err error) {
var hasStaticIP bool
hasStaticIP, err = aghnet.IfaceHasStaticIP(ifaceName)
if err != nil {
if errors.Is(err, os.ErrPermission) {
// ErrPermission may happen here on Linux systems where
// AdGuard Home is installed using Snap. That doesn't
// necessarily mean that the machine doesn't have
// a static IP, so we can assume that it has and go on.
// If the machine doesn't, we'll get an error later.
// ErrPermission may happen here on Linux systems where AdGuard Home
// is installed using Snap. That doesn't necessarily mean that the
// machine doesn't have a static IP, so we can assume that it has
// and go on. If the machine doesn't, we'll get an error later.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/2667.
//
// TODO(a.garipov): I was thinking about moving this
// into IfaceHasStaticIP, but then we wouldn't be able
// to log it. Think about it more.
// TODO(a.garipov): I was thinking about moving this into
// IfaceHasStaticIP, but then we wouldn't be able to log it. Think
// about it more.
log.Info("error while checking static ip: %s; "+
"assuming machine has static ip and going on", err)
hasStaticIP = true
} else if errors.Is(err, aghnet.ErrNoStaticIPInfo) {
// Couldn't obtain a definitive answer. Assume static
// IP an go on.
// Couldn't obtain a definitive answer. Assume static IP an go on.
log.Info("can't check for static ip; " +
"assuming machine has static ip and going on")
hasStaticIP = true
@@ -149,34 +146,39 @@ type dhcpServerConfigJSON struct {
Enabled aghalg.NullBool `json:"enabled"`
}
func (s *Server) handleDHCPSetConfigV4(
func (s *server) handleDHCPSetConfigV4(
conf *dhcpServerConfigJSON,
) (srv4 DHCPServer, enabled bool, err error) {
) (srv DHCPServer, enabled bool, err error) {
if conf.V4 == nil {
return nil, false, nil
}
v4Conf := v4JSONToServerConf(conf.V4)
v4Conf := conf.V4.toServerConf()
v4Conf.Enabled = conf.Enabled == aghalg.NBTrue
if len(v4Conf.RangeStart) == 0 {
v4Conf.Enabled = false
}
enabled = v4Conf.Enabled
v4Conf.InterfaceName = conf.InterfaceName
c4 := V4ServerConf{}
s.srv4.WriteDiskConfig4(&c4)
// Set the default values for the fields not configurable via web API.
c4 := &V4ServerConf{
notify: s.onNotify,
ICMPTimeout: s.conf.Conf4.ICMPTimeout,
Options: s.conf.Conf4.Options,
}
s.srv4.WriteDiskConfig4(c4)
v4Conf.notify = c4.notify
v4Conf.ICMPTimeout = c4.ICMPTimeout
v4Conf.Options = c4.Options
srv4, err = v4Create(v4Conf)
srv4, err := v4Create(v4Conf)
return srv4, enabled, err
return srv4, srv4.enabled(), err
}
func (s *Server) handleDHCPSetConfigV6(
func (s *server) handleDHCPSetConfigV6(
conf *dhcpServerConfigJSON,
) (srv6 DHCPServer, enabled bool, err error) {
if conf.V6 == nil {
@@ -205,7 +207,7 @@ func (s *Server) handleDHCPSetConfigV6(
return srv6, enabled, err
}
func (s *Server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
func (s *server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
conf := &dhcpServerConfigJSON{}
conf.Enabled = aghalg.BoolToNullBool(s.conf.Enabled)
conf.InterfaceName = s.conf.InterfaceName
@@ -287,7 +289,7 @@ type netInterfaceJSON struct {
Addrs6 []net.IP `json:"ipv6_addresses"`
}
func (s *Server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
func (s *server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
response := map[string]netInterfaceJSON{}
ifaces, err := net.Interfaces()
@@ -406,31 +408,37 @@ type dhcpSearchResult struct {
V6 dhcpSearchV6Result `json:"v6"`
}
// Perform the following tasks:
// . Search for another DHCP server running
// . Check if a static IP is configured for the network interface
// Respond with results
func (s *Server) handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Request) {
// This use of ReadAll is safe, because request's body is now limited.
body, err := io.ReadAll(r.Body)
// findActiveServerReq is the JSON structure for the request to find active DHCP
// servers.
type findActiveServerReq struct {
Interface string `json:"interface"`
}
// handleDHCPFindActiveServer performs the following tasks:
// 1. searches for another DHCP server in the network;
// 2. check if a static IP is configured for the network interface;
// 3. responds with the results.
func (s *server) handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Request) {
if aghhttp.WriteTextPlainDeprecated(w, r) {
return
}
req := &findActiveServerReq{}
err := json.NewDecoder(r.Body).Decode(req)
if err != nil {
msg := fmt.Sprintf("failed to read request body: %s", err)
log.Error(msg)
http.Error(w, msg, http.StatusBadRequest)
aghhttp.Error(r, w, http.StatusBadRequest, "reading req: %s", err)
return
}
ifaceName := strings.TrimSpace(string(body))
ifaceName := req.Interface
if ifaceName == "" {
msg := "empty interface name specified"
log.Error(msg)
http.Error(w, msg, http.StatusBadRequest)
aghhttp.Error(r, w, http.StatusBadRequest, "empty interface name")
return
}
result := dhcpSearchResult{
result := &dhcpSearchResult{
V4: dhcpSearchV4Result{
OtherServer: dhcpSearchOtherResult{
Found: "no",
@@ -455,6 +463,14 @@ func (s *Server) handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Reque
result.V4.StaticIP.IP = aghnet.GetSubnet(ifaceName).String()
}
setOtherDHCPResult(ifaceName, result)
_ = aghhttp.WriteJSONResponse(w, r, result)
}
// setOtherDHCPResult sets the results of the check for another DHCP server in
// result.
func setOtherDHCPResult(ifaceName string, result *dhcpSearchResult) {
found4, found6, err4, err6 := aghnet.CheckOtherDHCP(ifaceName)
if err4 != nil {
result.V4.OtherServer.Found = "error"
@@ -462,27 +478,16 @@ func (s *Server) handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Reque
} else if found4 {
result.V4.OtherServer.Found = "yes"
}
if err6 != nil {
result.V6.OtherServer.Found = "error"
result.V6.OtherServer.Error = err6.Error()
} else if found6 {
result.V6.OtherServer.Found = "yes"
}
w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(result)
if err != nil {
aghhttp.Error(
r,
w,
http.StatusInternalServerError,
"Failed to marshal DHCP found json: %s",
err,
)
}
}
func (s *Server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request) {
func (s *server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request) {
l := &Lease{}
err := json.NewDecoder(r.Body).Decode(l)
if err != nil {
@@ -497,20 +502,16 @@ func (s *Server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request
return
}
ip4 := l.IP.To4()
if ip4 == nil {
var srv DHCPServer
if ip4 := l.IP.To4(); ip4 != nil {
l.IP = ip4
srv = s.srv4
} else {
l.IP = l.IP.To16()
err = s.srv6.AddStaticLease(l)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
}
return
srv = s.srv6
}
l.IP = ip4
err = s.srv4.AddStaticLease(l)
err = srv.AddStaticLease(l)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
@@ -518,7 +519,7 @@ func (s *Server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request
}
}
func (s *Server) handleDHCPRemoveStaticLease(w http.ResponseWriter, r *http.Request) {
func (s *server) handleDHCPRemoveStaticLease(w http.ResponseWriter, r *http.Request) {
l := &Lease{}
err := json.NewDecoder(r.Body).Decode(l)
if err != nil {
@@ -555,14 +556,7 @@ func (s *Server) handleDHCPRemoveStaticLease(w http.ResponseWriter, r *http.Requ
}
}
const (
// DefaultDHCPLeaseTTL is the default time-to-live for leases.
DefaultDHCPLeaseTTL = uint32(timeutil.Day / time.Second)
// DefaultDHCPTimeoutICMP is the default timeout for waiting ICMP responses.
DefaultDHCPTimeoutICMP = 1000
)
func (s *Server) handleReset(w http.ResponseWriter, r *http.Request) {
func (s *server) handleReset(w http.ResponseWriter, r *http.Request) {
err := s.Stop()
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "stopping dhcp: %s", err)
@@ -586,7 +580,7 @@ func (s *Server) handleReset(w http.ResponseWriter, r *http.Request) {
DBFilePath: s.conf.DBFilePath,
}
v4conf := V4ServerConf{
v4conf := &V4ServerConf{
LeaseDuration: DefaultDHCPLeaseTTL,
ICMPTimeout: DefaultDHCPTimeoutICMP,
notify: s.onNotify,
@@ -602,7 +596,7 @@ func (s *Server) handleReset(w http.ResponseWriter, r *http.Request) {
s.conf.ConfigModified()
}
func (s *Server) handleResetLeases(w http.ResponseWriter, r *http.Request) {
func (s *server) handleResetLeases(w http.ResponseWriter, r *http.Request) {
err := s.resetLeases()
if err != nil {
msg := "resetting leases: %s"
@@ -612,7 +606,11 @@ func (s *Server) handleResetLeases(w http.ResponseWriter, r *http.Request) {
}
}
func (s *Server) registerHandlers() {
func (s *server) registerHandlers() {
if s.conf.HTTPRegister == nil {
return
}
s.conf.HTTPRegister(http.MethodGet, "/control/dhcp/status", s.handleDHCPStatus)
s.conf.HTTPRegister(http.MethodGet, "/control/dhcp/interfaces", s.handleDHCPInterfaces)
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/set_config", s.handleDHCPSetConfig)
@@ -622,44 +620,3 @@ func (s *Server) registerHandlers() {
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/reset", s.handleReset)
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/reset_leases", s.handleResetLeases)
}
// jsonError is a generic JSON error response.
//
// TODO(a.garipov): Merge together with the implementations in .../home and
// other packages after refactoring the web handler registering.
type jsonError struct {
// Message is the error message, an opaque string.
Message string `json:"message"`
}
// notImplemented returns a handler that replies to any request with an HTTP 501
// Not Implemented status and a JSON error with the provided message msg.
//
// TODO(a.garipov): Either take the logger from the server after we've
// refactored logging or make this not a method of *Server.
func (s *Server) notImplemented(msg string) (f func(http.ResponseWriter, *http.Request)) {
return func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusNotImplemented)
err := json.NewEncoder(w).Encode(&jsonError{
Message: msg,
})
if err != nil {
log.Debug("writing 501 json response: %s", err)
}
}
}
func (s *Server) registerNotImplementedHandlers() {
h := s.notImplemented("dhcp is not supported on windows")
s.conf.HTTPRegister(http.MethodGet, "/control/dhcp/status", h)
s.conf.HTTPRegister(http.MethodGet, "/control/dhcp/interfaces", h)
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/set_config", h)
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/find_active_dhcp", h)
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/add_static_lease", h)
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/remove_static_lease", h)
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/reset", h)
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/reset_leases", h)
}

View File

@@ -0,0 +1,55 @@
//go:build windows
package dhcpd
import (
"encoding/json"
"net/http"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/log"
)
// jsonError is a generic JSON error response.
//
// TODO(a.garipov): Merge together with the implementations in .../home and
// other packages after refactoring the web handler registering.
type jsonError struct {
// Message is the error message, an opaque string.
Message string `json:"message"`
}
// notImplemented is a handler that replies to any request with an HTTP 501 Not
// Implemented status and a JSON error with the provided message msg.
//
// TODO(a.garipov): Either take the logger from the server after we've
// refactored logging or make this not a method of *Server.
func (s *server) notImplemented(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusNotImplemented)
err := json.NewEncoder(w).Encode(&jsonError{
Message: aghos.Unsupported("dhcp").Error(),
})
if err != nil {
log.Debug("writing 501 json response: %s", err)
}
}
// registerHandlers sets the handlers for DHCP HTTP API that always respond with
// an HTTP 501, since DHCP server doesn't work on Windows yet.
//
// TODO(a.garipov): This needs refactoring. We shouldn't even try and
// initialize a DHCP server on Windows, but there are currently too many
// interconnected parts--such as HTTP handlers and frontend--to make that work
// properly.
func (s *server) registerHandlers() {
s.conf.HTTPRegister(http.MethodGet, "/control/dhcp/status", s.notImplemented)
s.conf.HTTPRegister(http.MethodGet, "/control/dhcp/interfaces", s.notImplemented)
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/set_config", s.notImplemented)
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/find_active_dhcp", s.notImplemented)
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/add_static_lease", s.notImplemented)
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/remove_static_lease", s.notImplemented)
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/reset", s.notImplemented)
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/reset_leases", s.notImplemented)
}

View File

@@ -1,23 +1,28 @@
//go:build windows
package dhcpd
import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestServer_notImplemented(t *testing.T) {
s := &Server{}
h := s.notImplemented("never!")
s := &server{}
w := httptest.NewRecorder()
r, err := http.NewRequest(http.MethodGet, "/unsupported", nil)
require.NoError(t, err)
h(w, r)
s.notImplemented(w, r)
assert.Equal(t, http.StatusNotImplemented, w.Code)
assert.Equal(t, `{"message":"never!"}`+"\n", w.Body.String())
wantStr := fmt.Sprintf("{%q:%q}", "message", aghos.Unsupported("dhcp"))
assert.JSONEq(t, wantStr, w.Body.String())
}

View File

@@ -1,5 +1,4 @@
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
//go:build darwin || freebsd || linux || openbsd
package dhcpd
@@ -197,10 +196,10 @@ func parseDHCPOption(s string) (code dhcpv4.OptionCode, val dhcpv4.OptionValue,
// prepareOptions builds the set of DHCP options according to host requirements
// document and values from conf.
func prepareOptions(conf V4ServerConf) (implicit, explicit dhcpv4.Options) {
func (s *v4Server) prepareOptions() {
// Set default values of host configuration parameters listed in Appendix A
// of RFC-2131.
implicit = dhcpv4.OptionsFromList(
s.implicitOpts = dhcpv4.OptionsFromList(
// IP-Layer Per Host
// An Internet host that includes embedded gateway code MUST have a
@@ -376,14 +375,14 @@ func prepareOptions(conf V4ServerConf) (implicit, explicit dhcpv4.Options) {
// Set the Router Option to working subnet's IP since it's initialized
// with the address of the gateway.
dhcpv4.OptRouter(conf.subnet.IP),
dhcpv4.OptRouter(s.conf.subnet.IP),
dhcpv4.OptSubnetMask(conf.subnet.Mask),
dhcpv4.OptSubnetMask(s.conf.subnet.Mask),
)
// Set values for explicitly configured options.
explicit = dhcpv4.Options{}
for i, o := range conf.Options {
s.explicitOpts = dhcpv4.Options{}
for i, o := range s.conf.Options {
code, val, err := parseDHCPOption(o)
if err != nil {
log.Error("dhcpv4: bad option string at index %d: %s", i, err)
@@ -391,17 +390,15 @@ func prepareOptions(conf V4ServerConf) (implicit, explicit dhcpv4.Options) {
continue
}
explicit.Update(dhcpv4.Option{Code: code, Value: val})
s.explicitOpts.Update(dhcpv4.Option{Code: code, Value: val})
// Remove those from the implicit options.
delete(implicit, code.Code())
delete(s.implicitOpts, code.Code())
}
log.Debug("dhcpv4: implicit options:\n%s", implicit.Summary(nil))
log.Debug("dhcpv4: explicit options:\n%s", explicit.Summary(nil))
log.Debug("dhcpv4: implicit options:\n%s", s.implicitOpts.Summary(nil))
log.Debug("dhcpv4: explicit options:\n%s", s.explicitOpts.Summary(nil))
if len(explicit) == 0 {
explicit = nil
if len(s.explicitOpts) == 0 {
s.explicitOpts = nil
}
return implicit, explicit
}

View File

@@ -1,5 +1,4 @@
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
//go:build darwin || freebsd || linux || openbsd
package dhcpd
@@ -250,17 +249,21 @@ func TestPrepareOptions(t *testing.T) {
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
implicit, explicit := prepareOptions(V4ServerConf{
s := &v4Server{
conf: &V4ServerConf{
// Just to avoid nil pointer dereference.
subnet: &net.IPNet{},
Options: tc.opts,
})
},
}
assert.Equal(t, tc.wantExplicit, explicit)
t.Run(tc.name, func(t *testing.T) {
s.prepareOptions()
for c := range explicit {
assert.NotContains(t, implicit, c)
assert.Equal(t, tc.wantExplicit, s.explicitOpts)
for c := range s.explicitOpts {
assert.NotContains(t, s.implicitOpts, c)
}
})
}

View File

@@ -1,5 +1,4 @@
//go:build windows
// +build windows
package dhcpd

View File

@@ -1,12 +0,0 @@
package dhcpd
import (
"time"
)
// Currently used defaults for ifaceDNSAddrs.
const (
defaultMaxAttempts int = 10
defaultBackoff time.Duration = 500 * time.Millisecond
)

View File

@@ -1,5 +1,4 @@
//go:build windows
// +build windows
package dhcpd
@@ -9,15 +8,19 @@ import "net"
type winServer struct{}
func (s *winServer) ResetLeases(_ []*Lease) (err error) { return nil }
func (s *winServer) GetLeases(_ GetLeasesFlags) (leases []*Lease) { return nil }
func (s *winServer) getLeasesRef() []*Lease { return nil }
func (s *winServer) AddStaticLease(_ *Lease) (err error) { return nil }
func (s *winServer) RemoveStaticLease(_ *Lease) (err error) { return nil }
func (s *winServer) FindMACbyIP(ip net.IP) (mac net.HardwareAddr) { return nil }
func (s *winServer) WriteDiskConfig4(c *V4ServerConf) {}
func (s *winServer) WriteDiskConfig6(c *V6ServerConf) {}
func (s *winServer) Start() (err error) { return nil }
func (s *winServer) Stop() (err error) { return nil }
func v4Create(conf V4ServerConf) (DHCPServer, error) { return &winServer{}, nil }
func v6Create(conf V6ServerConf) (DHCPServer, error) { return &winServer{}, nil }
// type check
var _ DHCPServer = winServer{}
func (winServer) ResetLeases(_ []*Lease) (err error) { return nil }
func (winServer) GetLeases(_ GetLeasesFlags) (leases []*Lease) { return nil }
func (winServer) getLeasesRef() []*Lease { return nil }
func (winServer) AddStaticLease(_ *Lease) (err error) { return nil }
func (winServer) RemoveStaticLease(_ *Lease) (err error) { return nil }
func (winServer) FindMACbyIP(_ net.IP) (mac net.HardwareAddr) { return nil }
func (winServer) WriteDiskConfig4(_ *V4ServerConf) {}
func (winServer) WriteDiskConfig6(_ *V6ServerConf) {}
func (winServer) Start() (err error) { return nil }
func (winServer) Stop() (err error) { return nil }
func v4Create(_ *V4ServerConf) (s DHCPServer, err error) { return winServer{}, nil }
func v6Create(_ V6ServerConf) (s DHCPServer, err error) { return winServer{}, nil }

View File

@@ -1,5 +1,4 @@
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
//go:build darwin || freebsd || linux || openbsd
package dhcpd
@@ -30,8 +29,9 @@ import (
//
// TODO(a.garipov): Think about unifying this and v6Server.
type v4Server struct {
conf V4ServerConf
srv *server4.Server
conf *V4ServerConf
srv *server4.Server
// implicitOpts are the options listed in Appendix A of RFC 2131 initialized
// with default values. It must not have intersections with [explicitOpts].
@@ -55,9 +55,15 @@ type v4Server struct {
leases []*Lease
}
func (s *v4Server) enabled() (ok bool) {
return s.conf != nil && s.conf.Enabled
}
// WriteDiskConfig4 - write configuration
func (s *v4Server) WriteDiskConfig4(c *V4ServerConf) {
*c = s.conf
if s.conf != nil {
*c = *s.conf
}
}
// WriteDiskConfig6 - write configuration
@@ -114,8 +120,8 @@ func (s *v4Server) validHostnameForClient(cliHostname string, ip net.IP) (hostna
func (s *v4Server) ResetLeases(leases []*Lease) (err error) {
defer func() { err = errors.Annotate(err, "dhcpv4: %w") }()
if !s.conf.Enabled {
return
if s.conf == nil {
return nil
}
s.leasedOffsets = newBitSet()
@@ -129,12 +135,7 @@ func (s *v4Server) ResetLeases(leases []*Lease) (err error) {
err = s.addLease(l)
if err != nil {
// TODO(a.garipov): Wrap and bubble up the error.
log.Error(
"dhcpv4: reset: re-adding a lease for %s (%s): %s",
l.IP,
l.HWAddr,
err,
)
log.Error("dhcpv4: reset: re-adding a lease for %s (%s): %s", l.IP, l.HWAddr, err)
continue
}
@@ -336,11 +337,19 @@ func (s *v4Server) rmLease(lease *Lease) (err error) {
return errors.Error("lease not found")
}
// ErrUnconfigured is returned from the server's method when it requires the
// server to be configured and it's not.
const ErrUnconfigured errors.Error = "server is unconfigured"
// AddStaticLease implements the DHCPServer interface for *v4Server. It is safe
// for concurrent use.
func (s *v4Server) AddStaticLease(l *Lease) (err error) {
defer func() { err = errors.Annotate(err, "dhcpv4: adding static lease: %w") }()
if s.conf == nil {
return ErrUnconfigured
}
ip := l.IP.To4()
if ip == nil {
return fmt.Errorf("invalid ip %q, only ipv4 is supported", l.IP)
@@ -414,6 +423,10 @@ func (s *v4Server) AddStaticLease(l *Lease) (err error) {
func (s *v4Server) RemoveStaticLease(l *Lease) (err error) {
defer func() { err = errors.Annotate(err, "dhcpv4: %w") }()
if s.conf == nil {
return ErrUnconfigured
}
if len(l.IP) != 4 {
return fmt.Errorf("invalid IP")
}
@@ -1086,12 +1099,6 @@ func (s *v4Server) packetHandler(conn net.PacketConn, peer net.Addr, req *dhcpv4
s.send(peer, conn, req, resp)
}
// minDHCPMsgSize is the minimum length of the encoded DHCP message in bytes
// according to RFC-2131.
//
// See https://datatracker.ietf.org/doc/html/rfc2131#section-2.
const minDHCPMsgSize = 576
// send writes resp for peer to conn considering the req's parameters according
// to RFC-2131.
//
@@ -1133,16 +1140,6 @@ func (s *v4Server) send(peer net.Addr, conn net.PacketConn, req, resp *dhcpv4.DH
}
pktData := resp.ToBytes()
pktLen := len(pktData)
if pktLen < minDHCPMsgSize {
// Expand the packet to match the minimum DHCP message length. Although
// the dhpcv4 package deals with the BOOTP's lower packet length
// constraint, it seems some clients expecting the length being at least
// 576 bytes as per RFC 2131 (and an obsolete RFC 1533).
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/4337.
pktData = append(pktData, make([]byte, minDHCPMsgSize-pktLen)...)
}
log.Debug("dhcpv4: sending %d bytes to %s: %s", len(pktData), peer, resp.Summary())
@@ -1156,7 +1153,7 @@ func (s *v4Server) send(peer net.Addr, conn net.PacketConn, req, resp *dhcpv4.DH
func (s *v4Server) Start() (err error) {
defer func() { err = errors.Annotate(err, "dhcpv4: %w") }()
if !s.conf.Enabled {
if !s.enabled() {
return nil
}
@@ -1248,62 +1245,20 @@ func (s *v4Server) Stop() (err error) {
}
// Create DHCPv4 server
func v4Create(conf V4ServerConf) (srv DHCPServer, err error) {
s := &v4Server{}
s.conf = conf
s.leaseHosts = stringutil.NewSet()
// TODO(a.garipov): Don't use a disabled server in other places or just
// use an interface.
if !conf.Enabled {
return s, nil
func v4Create(conf *V4ServerConf) (srv *v4Server, err error) {
s := &v4Server{
leaseHosts: stringutil.NewSet(),
}
var routerIP net.IP
routerIP, err = tryTo4(s.conf.GatewayIP)
err = conf.Validate()
if err != nil {
return s, fmt.Errorf("dhcpv4: %w", err)
// TODO(a.garipov): Don't use a disabled server in other places or just
// use an interface.
return s, err
}
if s.conf.SubnetMask == nil {
return s, fmt.Errorf("dhcpv4: invalid subnet mask: %v", s.conf.SubnetMask)
}
subnetMask := make([]byte, 4)
copy(subnetMask, s.conf.SubnetMask.To4())
s.conf.subnet = &net.IPNet{
IP: routerIP,
Mask: subnetMask,
}
s.conf.broadcastIP = aghnet.BroadcastFromIPNet(s.conf.subnet)
s.conf.ipRange, err = newIPRange(conf.RangeStart, conf.RangeEnd)
if err != nil {
return s, fmt.Errorf("dhcpv4: %w", err)
}
if s.conf.ipRange.contains(routerIP) {
return s, fmt.Errorf("dhcpv4: gateway ip %v in the ip range: %v-%v",
routerIP,
conf.RangeStart,
conf.RangeEnd,
)
}
if !s.conf.subnet.Contains(conf.RangeStart) {
return s, fmt.Errorf("dhcpv4: range start %v is outside network %v",
conf.RangeStart,
s.conf.subnet,
)
}
if !s.conf.subnet.Contains(conf.RangeEnd) {
return s, fmt.Errorf("dhcpv4: range end %v is outside network %v",
conf.RangeEnd,
s.conf.subnet,
)
}
s.conf = &V4ServerConf{}
*s.conf = *conf
// TODO(a.garipov, d.seregin): Check that every lease is inside the IPRange.
s.leasedOffsets = newBitSet()
@@ -1315,7 +1270,7 @@ func v4Create(conf V4ServerConf) (srv DHCPServer, err error) {
s.conf.leaseTime = time.Second * time.Duration(conf.LeaseDuration)
}
s.implicitOpts, s.explicitOpts = prepareOptions(s.conf)
s.prepareOptions()
return s, nil
}

View File

@@ -1,5 +1,4 @@
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
//go:build darwin || freebsd || linux || openbsd
package dhcpd
@@ -30,19 +29,16 @@ var (
DefaultSubnetMask = net.IP{255, 255, 255, 0}
)
func notify4(flags uint32) {
}
// defaultV4ServerConf returns the default configuration for *v4Server to use in
// tests.
func defaultV4ServerConf() (conf V4ServerConf) {
return V4ServerConf{
func defaultV4ServerConf() (conf *V4ServerConf) {
return &V4ServerConf{
Enabled: true,
RangeStart: DefaultRangeStart,
RangeEnd: DefaultRangeEnd,
GatewayIP: DefaultGatewayIP,
SubnetMask: DefaultSubnetMask,
notify: notify4,
notify: testNotify,
dnsIPAddrs: []net.IP{DefaultSelfIP},
}
}
@@ -350,13 +346,10 @@ func TestV4Server_handle_optionsPriority(t *testing.T) {
defer func() { s.implicitOpts.Update(dhcpv4.OptDNS(defaultIP)) }()
}
ss, err := v4Create(conf)
var err error
s, err = v4Create(conf)
require.NoError(t, err)
var ok bool
s, ok = ss.(*v4Server)
require.True(t, ok)
s.conf.dnsIPAddrs = []net.IP{defaultIP}
return s
@@ -490,10 +483,9 @@ func TestV4Server_updateOptions(t *testing.T) {
require.NoError(t, err)
require.IsType(t, (*v4Server)(nil), s)
s4, _ := s.(*v4Server)
t.Run(tc.name, func(t *testing.T) {
s4.updateOptions(req, resp)
s.updateOptions(req, resp)
for c, v := range tc.wantOpts {
if v == nil {
@@ -596,13 +588,9 @@ func TestV4DynamicLease_Get(t *testing.T) {
"82 ip 1.2.3.4",
}
var err error
sIface, err := v4Create(conf)
s, err := v4Create(conf)
require.NoError(t, err)
s, ok := sIface.(*v4Server)
require.True(t, ok)
s.conf.dnsIPAddrs = []net.IP{{192, 168, 10, 1}}
s.implicitOpts.Update(dhcpv4.OptDNS(s.conf.dnsIPAddrs...))

View File

@@ -1,5 +1,4 @@
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
//go:build darwin || freebsd || linux || openbsd
package dhcpd
@@ -173,7 +172,7 @@ func (s *v6Server) rmDynamicLease(lease *Lease) (err error) {
func (s *v6Server) AddStaticLease(l *Lease) (err error) {
defer func() { err = errors.Annotate(err, "dhcpv6: %w") }()
if len(l.IP) != 16 {
if len(l.IP) != net.IPv6len {
return fmt.Errorf("invalid IP")
}

View File

@@ -1,5 +1,4 @@
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
//go:build darwin || freebsd || linux || openbsd
package dhcpd

View File

@@ -128,9 +128,14 @@ type FilteringConfig struct {
// IpsetList is the ipset configuration that allows AdGuard Home to add
// IP addresses of the specified domain names to an ipset list. Syntax:
//
// DOMAIN[,DOMAIN].../IPSET_NAME
// DOMAIN[,DOMAIN].../IPSET_NAME
//
// This field is ignored if [IpsetListFileName] is set.
IpsetList []string `yaml:"ipset"`
// IpsetListFileName, if set, points to the file with ipset configuration.
// The format is the same as in [IpsetList].
IpsetListFileName string `yaml:"ipset_file"`
}
// TLSConfig is the TLS configuration for HTTPS, DNS-over-HTTPS, and DNS-over-TLS
@@ -400,6 +405,26 @@ func setProxyUpstreamMode(
}
}
// prepareIpsetListSettings reads and prepares the ipset configuration either
// from a file or from the data in the configuration file.
func (s *Server) prepareIpsetListSettings() (err error) {
fn := s.conf.IpsetListFileName
if fn == "" {
return s.ipset.init(s.conf.IpsetList)
}
data, err := os.ReadFile(fn)
if err != nil {
return err
}
ipsets := stringutil.SplitTrimmed(string(data), "\n")
log.Debug("dns: using %d ipset rules from file %q", len(ipsets), fn)
return s.ipset.init(ipsets)
}
// prepareTLS - prepares TLS configuration for the DNS proxy
func (s *Server) prepareTLS(proxyConfig *proxy.Config) error {
if len(s.conf.CertificateChainData) == 0 || len(s.conf.PrivateKeyData) == 0 {

View File

@@ -296,7 +296,7 @@ func (s *Server) makeDDRResponse(req *dns.Msg) (resp *dns.Msg) {
values := []dns.SVCBKeyValue{
&dns.SVCBAlpn{Alpn: []string{"h2"}},
&dns.SVCBPort{Port: uint16(addr.Port)},
&dns.SVCBDoHPath{Template: "/dns-query?dns"},
&dns.SVCBDoHPath{Template: "/dns-query{?dns}"},
}
ans := &dns.SVCB{

View File

@@ -26,7 +26,7 @@ func TestServer_ProcessDDRQuery(t *testing.T) {
Value: []dns.SVCBKeyValue{
&dns.SVCBAlpn{Alpn: []string{"h2"}},
&dns.SVCBPort{Port: 8044},
&dns.SVCBDoHPath{Template: "/dns-query?dns"},
&dns.SVCBDoHPath{Template: "/dns-query{?dns}"},
},
}
@@ -267,7 +267,7 @@ func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
s := &Server{
dhcpServer: &testDHCP{},
dhcpServer: testDHCP,
localDomainSuffix: defaultLocalDomainSuffix,
tableHostToIP: hostToIPTable{
"example." + defaultLocalDomainSuffix: knownIP,
@@ -378,7 +378,7 @@ func TestServer_ProcessDHCPHosts(t *testing.T) {
for _, tc := range testCases {
s := &Server{
dhcpServer: &testDHCP{},
dhcpServer: testDHCP,
localDomainSuffix: tc.suffix,
tableHostToIP: hostToIPTable{
"example." + tc.suffix: knownIP,

View File

@@ -58,10 +58,10 @@ type hostToIPTable = map[string]net.IP
//
// The zero Server is empty and ready for use.
type Server struct {
dnsProxy *proxy.Proxy // DNS proxy instance
dnsFilter *filtering.DNSFilter // DNS filter instance
dhcpServer dhcpd.ServerInterface // DHCP server instance (optional)
queryLog querylog.QueryLog // Query log instance
dnsProxy *proxy.Proxy // DNS proxy instance
dnsFilter *filtering.DNSFilter // DNS filter instance
dhcpServer dhcpd.Interface // DHCP server instance (optional)
queryLog querylog.QueryLog // Query log instance
stats stats.Interface
access *accessCtx
@@ -110,7 +110,7 @@ type DNSCreateParams struct {
DNSFilter *filtering.DNSFilter
Stats stats.Interface
QueryLog querylog.QueryLog
DHCPServer dhcpd.ServerInterface
DHCPServer dhcpd.Interface
PrivateNets netutil.SubnetSet
Anonymizer *aghnet.IPMut
LocalDomain string
@@ -446,10 +446,10 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
s.initDefaultSettings()
err = s.ipset.init(s.conf.IpsetList)
err = s.prepareIpsetListSettings()
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
return fmt.Errorf("preparing ipset settings: %w", err)
}
err = s.prepareUpstreamSettings()

View File

@@ -67,12 +67,13 @@ func createTestServer(
ID: 0, Data: []byte(rules),
}}
f := filtering.New(filterConf, filters)
f, err := filtering.New(filterConf, filters)
require.NoError(t, err)
f.SetEnabled(true)
var err error
s, err = NewServer(DNSCreateParams{
DHCPServer: &testDHCP{},
DHCPServer: testDHCP,
DNSFilter: f,
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
})
@@ -774,9 +775,11 @@ func TestBlockedCustomIP(t *testing.T) {
Data: []byte(rules),
}}
f := filtering.New(&filtering.Config{}, filters)
f, err := filtering.New(&filtering.Config{}, filters)
require.NoError(t, err)
s, err := NewServer(DNSCreateParams{
DHCPServer: &testDHCP{},
DHCPServer: testDHCP,
DNSFilter: f,
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
})
@@ -906,11 +909,13 @@ func TestRewrite(t *testing.T) {
Type: dns.TypeCNAME,
}},
}
f := filtering.New(c, nil)
f, err := filtering.New(c, nil)
require.NoError(t, err)
f.SetEnabled(true)
s, err := NewServer(DNSCreateParams{
DHCPServer: &testDHCP{},
DHCPServer: testDHCP,
DNSFilter: f,
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
})
@@ -1005,26 +1010,31 @@ func publicKey(priv any) any {
}
}
type testDHCP struct{}
func (d *testDHCP) Enabled() (ok bool) { return true }
func (d *testDHCP) Leases(flags dhcpd.GetLeasesFlags) (leases []*dhcpd.Lease) {
return []*dhcpd.Lease{{
IP: net.IP{192, 168, 12, 34},
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
Hostname: "myhost",
}}
var testDHCP = &dhcpd.MockInterface{
OnStart: func() (err error) { panic("not implemented") },
OnStop: func() (err error) { panic("not implemented") },
OnEnabled: func() (ok bool) { return true },
OnLeases: func(flags dhcpd.GetLeasesFlags) (leases []*dhcpd.Lease) {
return []*dhcpd.Lease{{
IP: net.IP{192, 168, 12, 34},
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
Hostname: "myhost",
}}
},
OnSetOnLeaseChanged: func(olct dhcpd.OnLeaseChangedT) {},
OnFindMACbyIP: func(ip net.IP) (mac net.HardwareAddr) { panic("not implemented") },
OnWriteDiskConfig: func(c *dhcpd.ServerConfig) { panic("not implemented") },
}
func (d *testDHCP) SetOnLeaseChanged(onLeaseChanged dhcpd.OnLeaseChangedT) {}
func TestPTRResponseFromDHCPLeases(t *testing.T) {
const localDomain = "lan"
flt, err := filtering.New(&filtering.Config{}, nil)
require.NoError(t, err)
s, err := NewServer(DNSCreateParams{
DNSFilter: filtering.New(&filtering.Config{}, nil),
DHCPServer: &testDHCP{},
DNSFilter: flt,
DHCPServer: testDHCP,
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
LocalDomain: localDomain,
})
@@ -1090,14 +1100,16 @@ func TestPTRResponseFromHosts(t *testing.T) {
assert.Equal(t, uint32(1), atomic.LoadUint32(&eventsCalledCounter))
})
flt := filtering.New(&filtering.Config{
flt, err := filtering.New(&filtering.Config{
EtcHosts: hc,
}, nil)
require.NoError(t, err)
flt.SetEnabled(true)
var s *Server
s, err = NewServer(DNSCreateParams{
DHCPServer: &testDHCP{},
DHCPServer: testDHCP,
DNSFilter: flt,
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
})

View File

@@ -35,11 +35,12 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
ID: 0, Data: []byte(rules),
}}
f := filtering.New(&filtering.Config{}, filters)
f, err := filtering.New(&filtering.Config{}, filters)
require.NoError(t, err)
f.SetEnabled(true)
s, err := NewServer(DNSCreateParams{
DHCPServer: &testDHCP{},
DHCPServer: testDHCP,
DNSFilter: f,
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
})

View File

@@ -421,31 +421,34 @@ func initBlockedServices() {
}
// BlockedSvcKnown - return TRUE if a blocked service name is known
func BlockedSvcKnown(s string) bool {
_, ok := serviceRules[s]
func BlockedSvcKnown(s string) (ok bool) {
_, ok = serviceRules[s]
return ok
}
// ApplyBlockedServices - set blocked services settings for this DNS request
func (d *DNSFilter) ApplyBlockedServices(setts *Settings, list []string, global bool) {
func (d *DNSFilter) ApplyBlockedServices(setts *Settings, list []string) {
setts.ServicesRules = []ServiceEntry{}
if global {
if list == nil {
d.confLock.RLock()
defer d.confLock.RUnlock()
list = d.Config.BlockedServices
}
for _, name := range list {
rules, ok := serviceRules[name]
if !ok {
log.Error("unknown service name: %s", name)
continue
}
s := ServiceEntry{}
s.Name = name
s.Rules = rules
setts.ServicesRules = append(setts.ServicesRules, s)
setts.ServicesRules = append(setts.ServicesRules, ServiceEntry{
Name: name,
Rules: rules,
})
}
}
@@ -490,10 +493,3 @@ func (d *DNSFilter) handleBlockedServicesSet(w http.ResponseWriter, r *http.Requ
d.ConfigModified()
}
// registerBlockedServicesHandlers - register HTTP handlers
func (d *DNSFilter) registerBlockedServicesHandlers() {
d.Config.HTTPRegister(http.MethodGet, "/control/blocked_services/services", d.handleBlockedServicesAvailableServices)
d.Config.HTTPRegister(http.MethodGet, "/control/blocked_services/list", d.handleBlockedServicesList)
d.Config.HTTPRegister(http.MethodPost, "/control/blocked_services/set", d.handleBlockedServicesSet)
}

View File

@@ -49,7 +49,7 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
|1.2.3.5.in-addr.arpa^$dnsrewrite=NOERROR;PTR;new-ptr-with-dot.
`
f := newForTest(t, nil, []Filter{{ID: 0, Data: []byte(text)}})
f, _ := newForTest(t, nil, []Filter{{ID: 0, Data: []byte(text)}})
setts := &Settings{
FilteringEnabled: true,
}

View File

@@ -1,4 +1,4 @@
package home
package filtering
import (
"bufio"
@@ -8,63 +8,29 @@ import (
"net/http"
"os"
"path/filepath"
"regexp"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/stringutil"
"golang.org/x/exp/slices"
)
var nextFilterID = time.Now().Unix() // semi-stable way to generate an unique ID
// filterDir is the subdirectory of a data directory to store downloaded
// filters.
const filterDir = "filters"
// Filtering - module object
type Filtering struct {
// conf FilteringConf
refreshStatus uint32 // 0:none; 1:in progress
refreshLock sync.Mutex
filterTitleRegexp *regexp.Regexp
}
// nextFilterID is a way to seed a unique ID generation.
//
// TODO(e.burkov): Use more deterministic approach.
var nextFilterID = time.Now().Unix()
// Init - initialize the module
func (f *Filtering) Init() {
f.filterTitleRegexp = regexp.MustCompile(`^! Title: +(.*)$`)
_ = os.MkdirAll(filepath.Join(Context.getDataDir(), filterDir), 0o755)
f.loadFilters(config.Filters)
f.loadFilters(config.WhitelistFilters)
deduplicateFilters()
updateUniqueFilterID(config.Filters)
updateUniqueFilterID(config.WhitelistFilters)
}
// Start - start the module
func (f *Filtering) Start() {
f.RegisterFilteringHandlers()
// Here we should start updating filters,
// but currently we can't wake up the periodic task to do so.
// So for now we just start this periodic task from here.
go f.periodicallyRefreshFilters()
}
// Close - close the module
func (f *Filtering) Close() {
}
func defaultFilters() []filter {
return []filter{
{Filter: filtering.Filter{ID: 1}, Enabled: true, URL: "https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt", Name: "AdGuard DNS filter"},
{Filter: filtering.Filter{ID: 2}, Enabled: false, URL: "https://adaway.org/hosts.txt", Name: "AdAway Default Blocklist"},
}
}
// field ordering is important -- yaml fields will mirror ordering from here
type filter struct {
// FilterYAML respresents a filter list in the configuration file.
//
// TODO(e.burkov): Investigate if the field oredering is important.
type FilterYAML struct {
Enabled bool
URL string // URL or a file path
Name string `yaml:"name"`
@@ -73,91 +39,108 @@ type filter struct {
checksum uint32 // checksum of the file data
white bool
filtering.Filter `yaml:",inline"`
Filter `yaml:",inline"`
}
// Clear filter rules
func (filter *FilterYAML) unload() {
filter.RulesCount = 0
filter.checksum = 0
}
// Path to the filter contents
func (filter *FilterYAML) Path(dataDir string) string {
return filepath.Join(dataDir, filterDir, strconv.FormatInt(filter.ID, 10)+".txt")
}
const (
statusFound = 1
statusEnabledChanged = 2
statusURLChanged = 4
statusURLExists = 8
statusUpdateRequired = 0x10
statusFound = 1 << iota
statusEnabledChanged
statusURLChanged
statusURLExists
statusUpdateRequired
)
// Update properties for a filter specified by its URL
// Return status* flags.
func (f *Filtering) filterSetProperties(url string, newf filter, whitelist bool) int {
func (d *DNSFilter) filterSetProperties(url string, newf FilterYAML, whitelist bool) int {
r := 0
config.Lock()
defer config.Unlock()
d.filtersMu.Lock()
defer d.filtersMu.Unlock()
filters := &config.Filters
filters := d.Filters
if whitelist {
filters = &config.WhitelistFilters
filters = d.WhitelistFilters
}
for i := range *filters {
filt := &(*filters)[i]
if filt.URL != url {
continue
i := slices.IndexFunc(filters, func(filt FilterYAML) bool {
return filt.URL == url
})
if i == -1 {
return 0
}
filt := &filters[i]
log.Debug("filter: set properties: %s: {%s %s %v}", filt.URL, newf.Name, newf.URL, newf.Enabled)
filt.Name = newf.Name
if filt.URL != newf.URL {
r |= statusURLChanged | statusUpdateRequired
if d.filterExistsNoLock(newf.URL) {
return statusURLExists
}
log.Debug("filter: set properties: %s: {%s %s %v}",
filt.URL, newf.Name, newf.URL, newf.Enabled)
filt.Name = newf.Name
filt.URL = newf.URL
filt.unload()
filt.LastUpdated = time.Time{}
filt.checksum = 0
filt.RulesCount = 0
}
if filt.URL != newf.URL {
r |= statusURLChanged | statusUpdateRequired
if filterExistsNoLock(newf.URL) {
return statusURLExists
}
filt.URL = newf.URL
filt.unload()
filt.LastUpdated = time.Time{}
filt.checksum = 0
filt.RulesCount = 0
}
if filt.Enabled != newf.Enabled {
r |= statusEnabledChanged
filt.Enabled = newf.Enabled
if filt.Enabled {
if (r & statusURLChanged) == 0 {
err := d.load(filt)
if err != nil {
// TODO(e.burkov): It seems the error is only returned when
// the file exists and couldn't be open. Investigate and
// improve.
log.Error("loading filter %d: %s", filt.ID, err)
if filt.Enabled != newf.Enabled {
r |= statusEnabledChanged
filt.Enabled = newf.Enabled
if filt.Enabled {
if (r & statusURLChanged) == 0 {
e := f.load(filt)
if e != nil {
// This isn't a fatal error,
// because it may occur when someone removes the file from disk.
filt.LastUpdated = time.Time{}
filt.checksum = 0
filt.RulesCount = 0
r |= statusUpdateRequired
}
filt.LastUpdated = time.Time{}
filt.checksum = 0
filt.RulesCount = 0
r |= statusUpdateRequired
}
} else {
filt.unload()
}
} else {
filt.unload()
}
return r | statusFound
}
return 0
return r | statusFound
}
// Return TRUE if a filter with this URL exists
func filterExists(url string) bool {
config.RLock()
r := filterExistsNoLock(url)
config.RUnlock()
func (d *DNSFilter) filterExists(url string) bool {
d.filtersMu.RLock()
defer d.filtersMu.RUnlock()
r := d.filterExistsNoLock(url)
return r
}
func filterExistsNoLock(url string) bool {
for _, f := range config.Filters {
func (d *DNSFilter) filterExistsNoLock(url string) bool {
for _, f := range d.Filters {
if f.URL == url {
return true
}
}
for _, f := range config.WhitelistFilters {
for _, f := range d.WhitelistFilters {
if f.URL == url {
return true
}
@@ -167,26 +150,26 @@ func filterExistsNoLock(url string) bool {
// Add a filter
// Return FALSE if a filter with this URL exists
func filterAdd(f filter) bool {
config.Lock()
defer config.Unlock()
func (d *DNSFilter) filterAdd(flt FilterYAML) bool {
d.filtersMu.Lock()
defer d.filtersMu.Unlock()
// Check for duplicates
if filterExistsNoLock(f.URL) {
if d.filterExistsNoLock(flt.URL) {
return false
}
if f.white {
config.WhitelistFilters = append(config.WhitelistFilters, f)
if flt.white {
d.WhitelistFilters = append(d.WhitelistFilters, flt)
} else {
config.Filters = append(config.Filters, f)
d.Filters = append(d.Filters, flt)
}
return true
}
// Load filters from the disk
// And if any filter has zero ID, assign a new one
func (f *Filtering) loadFilters(array []filter) {
func (d *DNSFilter) loadFilters(array []FilterYAML) {
for i := range array {
filter := &array[i] // otherwise we're operating on a copy
if filter.ID == 0 {
@@ -198,32 +181,30 @@ func (f *Filtering) loadFilters(array []filter) {
continue
}
err := f.load(filter)
err := d.load(filter)
if err != nil {
log.Error("Couldn't load filter %d contents due to %s", filter.ID, err)
}
}
}
func deduplicateFilters() {
// Deduplicate filters
i := 0 // output index, used for deletion later
urls := map[string]bool{}
for _, filter := range config.Filters {
if _, ok := urls[filter.URL]; !ok {
// we didn't see it before, keep it
urls[filter.URL] = true // remember the URL
config.Filters[i] = filter
i++
func deduplicateFilters(filters []FilterYAML) (deduplicated []FilterYAML) {
urls := stringutil.NewSet()
lastIdx := 0
for _, filter := range filters {
if !urls.Has(filter.URL) {
urls.Add(filter.URL)
filters[lastIdx] = filter
lastIdx++
}
}
// all entries we want to keep are at front, delete the rest
config.Filters = config.Filters[:i]
return filters[:lastIdx]
}
// Set the next filter ID to max(filter.ID) + 1
func updateUniqueFilterID(filters []filter) {
func updateUniqueFilterID(filters []FilterYAML) {
for _, filter := range filters {
if nextFilterID < filter.ID {
nextFilterID = filter.ID + 1
@@ -238,22 +219,19 @@ func assignUniqueFilterID() int64 {
}
// Sets up a timer that will be checking for filters updates periodically
func (f *Filtering) periodicallyRefreshFilters() {
func (d *DNSFilter) periodicallyRefreshFilters() {
const maxInterval = 1 * 60 * 60
intval := 5 // use a dynamically increasing time interval
for {
isNetworkErr := false
if config.DNS.FiltersUpdateIntervalHours != 0 && atomic.CompareAndSwapUint32(&f.refreshStatus, 0, 1) {
f.refreshLock.Lock()
_, isNetworkErr = f.refreshFiltersIfNecessary(filterRefreshBlocklists | filterRefreshAllowlists)
f.refreshLock.Unlock()
f.refreshStatus = 0
if !isNetworkErr {
isNetErr, ok := false, false
if d.FiltersUpdateIntervalHours != 0 {
_, isNetErr, ok = d.tryRefreshFilters(true, true, false)
if ok && !isNetErr {
intval = maxInterval
}
}
if isNetworkErr {
if isNetErr {
intval *= 2
if intval > maxInterval {
intval = maxInterval
@@ -264,51 +242,73 @@ func (f *Filtering) periodicallyRefreshFilters() {
}
}
// Refresh filters
// flags: filterRefresh*
// important:
// tryRefreshFilters is like [refreshFilters], but backs down if the update is
// already going on.
//
// TRUE: ignore the fact that we're currently updating the filters
func (f *Filtering) refreshFilters(flags int, important bool) (int, error) {
set := atomic.CompareAndSwapUint32(&f.refreshStatus, 0, 1)
if !important && !set {
return 0, fmt.Errorf("filters update procedure is already running")
// TODO(e.burkov): Get rid of the concurrency pattern which requires the
// sync.Mutex.TryLock.
func (d *DNSFilter) tryRefreshFilters(block, allow, force bool) (updated int, isNetworkErr, ok bool) {
if ok = d.refreshLock.TryLock(); !ok {
return 0, false, ok
}
defer d.refreshLock.Unlock()
f.refreshLock.Lock()
nUpdated, _ := f.refreshFiltersIfNecessary(flags)
f.refreshLock.Unlock()
f.refreshStatus = 0
return nUpdated, nil
updated, isNetworkErr = d.refreshFiltersIntl(block, allow, force)
return updated, isNetworkErr, ok
}
func (f *Filtering) refreshFiltersArray(filters *[]filter, force bool) (int, []filter, []bool, bool) {
var updateFilters []filter
// refreshFilters updates the lists and returns the number of updated ones.
// It's safe for concurrent use, but blocks at least until the previous
// refreshing is finished.
func (d *DNSFilter) refreshFilters(block, allow, force bool) (updated int) {
d.refreshLock.Lock()
defer d.refreshLock.Unlock()
updated, _ = d.refreshFiltersIntl(block, allow, force)
return updated
}
// listsToUpdate returns the slice of filter lists that could be updated.
func (d *DNSFilter) listsToUpdate(filters *[]FilterYAML, force bool) (toUpd []FilterYAML) {
now := time.Now()
d.filtersMu.RLock()
defer d.filtersMu.RUnlock()
for i := range *filters {
flt := &(*filters)[i] // otherwise we will be operating on a copy
log.Debug("checking list at index %d: %v", i, flt)
if !flt.Enabled {
continue
}
if !force {
exp := flt.LastUpdated.Add(time.Duration(d.FiltersUpdateIntervalHours) * time.Hour)
if now.Before(exp) {
continue
}
}
toUpd = append(toUpd, FilterYAML{
Filter: Filter{
ID: flt.ID,
},
URL: flt.URL,
Name: flt.Name,
checksum: flt.checksum,
})
}
return toUpd
}
func (d *DNSFilter) refreshFiltersArray(filters *[]FilterYAML, force bool) (int, []FilterYAML, []bool, bool) {
var updateFlags []bool // 'true' if filter data has changed
now := time.Now()
config.RLock()
for i := range *filters {
f := &(*filters)[i] // otherwise we will be operating on a copy
if !f.Enabled {
continue
}
expireTime := f.LastUpdated.Unix() + int64(config.DNS.FiltersUpdateIntervalHours)*60*60
if !force && expireTime > now.Unix() {
continue
}
var uf filter
uf.ID = f.ID
uf.URL = f.URL
uf.Name = f.Name
uf.checksum = f.checksum
updateFilters = append(updateFilters, uf)
}
config.RUnlock()
updateFilters := d.listsToUpdate(filters, force)
if len(updateFilters) == 0 {
return 0, nil, nil, false
}
@@ -316,7 +316,7 @@ func (f *Filtering) refreshFiltersArray(filters *[]filter, force bool) (int, []f
nfail := 0
for i := range updateFilters {
uf := &updateFilters[i]
updated, err := f.update(uf)
updated, err := d.update(uf)
updateFlags = append(updateFlags, updated)
if err != nil {
nfail++
@@ -334,7 +334,7 @@ func (f *Filtering) refreshFiltersArray(filters *[]filter, force bool) (int, []f
uf := &updateFilters[i]
updated := updateFlags[i]
config.Lock()
d.filtersMu.Lock()
for k := range *filters {
f := &(*filters)[k]
if f.ID != uf.ID || f.URL != uf.URL {
@@ -352,20 +352,14 @@ func (f *Filtering) refreshFiltersArray(filters *[]filter, force bool) (int, []f
f.checksum = uf.checksum
updateCount++
}
config.Unlock()
d.filtersMu.Unlock()
}
return updateCount, updateFilters, updateFlags, false
}
const (
filterRefreshForce = 1 // ignore last file modification date
filterRefreshAllowlists = 2 // update allow-lists
filterRefreshBlocklists = 4 // update block-lists
)
// refreshFiltersIfNecessary checks filters and updates them if necessary. If
// force is true, it ignores the filter.LastUpdated field value.
// refreshFiltersIntl checks filters and updates them if necessary. If force is
// true, it ignores the filter.LastUpdated field value.
//
// Algorithm:
//
@@ -378,53 +372,49 @@ const (
// that this method works only on Unix systems. On Windows, don't pass
// files to filtering, pass the whole data.
//
// refreshFiltersIfNecessary returns the number of updated filters. It also
// returns true if there was a network error and nothing could be updated.
// refreshFiltersIntl returns the number of updated filters. It also returns
// true if there was a network error and nothing could be updated.
//
// TODO(a.garipov, e.burkov): What the hell?
func (f *Filtering) refreshFiltersIfNecessary(flags int) (int, bool) {
log.Debug("Filters: updating...")
func (d *DNSFilter) refreshFiltersIntl(block, allow, force bool) (int, bool) {
log.Debug("filtering: updating...")
updateCount := 0
var updateFilters []filter
var updateFlags []bool
netError := false
netErrorW := false
force := false
if (flags & filterRefreshForce) != 0 {
force = true
updNum := 0
var lists []FilterYAML
var toUpd []bool
isNetErr := false
if block {
updNum, lists, toUpd, isNetErr = d.refreshFiltersArray(&d.Filters, force)
}
if (flags & filterRefreshBlocklists) != 0 {
updateCount, updateFilters, updateFlags, netError = f.refreshFiltersArray(&config.Filters, force)
if allow {
updNumAl, listsAl, toUpdAl, isNetErrAl := d.refreshFiltersArray(&d.WhitelistFilters, force)
updNum += updNumAl
lists = append(lists, listsAl...)
toUpd = append(toUpd, toUpdAl...)
isNetErr = isNetErr || isNetErrAl
}
if (flags & filterRefreshAllowlists) != 0 {
updateCountW := 0
var updateFiltersW []filter
var updateFlagsW []bool
updateCountW, updateFiltersW, updateFlagsW, netErrorW = f.refreshFiltersArray(&config.WhitelistFilters, force)
updateCount += updateCountW
updateFilters = append(updateFilters, updateFiltersW...)
updateFlags = append(updateFlags, updateFlagsW...)
}
if netError && netErrorW {
if isNetErr {
return 0, true
}
if updateCount != 0 {
enableFilters(false)
if updNum != 0 {
d.EnableFilters(false)
for i := range updateFilters {
uf := &updateFilters[i]
updated := updateFlags[i]
for i := range lists {
uf := &lists[i]
updated := toUpd[i]
if !updated {
continue
}
_ = os.Remove(uf.Path() + ".old")
_ = os.Remove(uf.Path(d.DataDir) + ".old")
}
}
log.Debug("Filters: update finished")
return updateCount, false
log.Debug("filtering: update finished")
return updNum, false
}
// Allows printable UTF-8 text with CR, LF, TAB characters
@@ -440,7 +430,7 @@ func isPrintableText(data []byte, len int) bool {
}
// A helper function that parses filter contents and returns a number of rules and a filter name (if there's any)
func (f *Filtering) parseFilterContents(file io.Reader) (int, uint32, string) {
func (d *DNSFilter) parseFilterContents(file io.Reader) (int, uint32, string) {
rulesCount := 0
name := ""
seenTitle := false
@@ -455,7 +445,7 @@ func (f *Filtering) parseFilterContents(file io.Reader) (int, uint32, string) {
if len(line) == 0 {
//
} else if line[0] == '!' {
m := f.filterTitleRegexp.FindAllStringSubmatch(line, -1)
m := d.filterTitleRegexp.FindAllStringSubmatch(line, -1)
if len(m) > 0 && len(m[0]) >= 2 && !seenTitle {
name = m[0][1]
seenTitle = true
@@ -476,11 +466,11 @@ func (f *Filtering) parseFilterContents(file io.Reader) (int, uint32, string) {
}
// Perform upgrade on a filter and update LastUpdated value
func (f *Filtering) update(filter *filter) (bool, error) {
b, err := f.updateIntl(filter)
func (d *DNSFilter) update(filter *FilterYAML) (bool, error) {
b, err := d.updateIntl(filter)
filter.LastUpdated = time.Now()
if !b {
e := os.Chtimes(filter.Path(), filter.LastUpdated, filter.LastUpdated)
e := os.Chtimes(filter.Path(d.DataDir), filter.LastUpdated, filter.LastUpdated)
if e != nil {
log.Error("os.Chtimes(): %v", e)
}
@@ -488,7 +478,7 @@ func (f *Filtering) update(filter *filter) (bool, error) {
return b, err
}
func (f *Filtering) read(reader io.Reader, tmpFile *os.File, filter *filter) (int, error) {
func (d *DNSFilter) read(reader io.Reader, tmpFile *os.File, filter *FilterYAML) (int, error) {
htmlTest := true
firstChunk := make([]byte, 4*1024)
firstChunkLen := 0
@@ -539,20 +529,20 @@ func (f *Filtering) read(reader io.Reader, tmpFile *os.File, filter *filter) (in
// finalizeUpdate closes and gets rid of temporary file f with filter's content
// according to updated. It also saves new values of flt's name, rules number
// and checksum if sucсeeded.
func finalizeUpdate(
f *os.File,
flt *filter,
func (d *DNSFilter) finalizeUpdate(
file *os.File,
flt *FilterYAML,
updated bool,
name string,
rnum int,
cs uint32,
) (err error) {
tmpFileName := f.Name()
tmpFileName := file.Name()
// Close the file before renaming it because it's required on Windows.
//
// See https://github.com/adguardTeam/adGuardHome/issues/1553.
if err = f.Close(); err != nil {
if err = file.Close(); err != nil {
return fmt.Errorf("closing temporary file: %w", err)
}
@@ -562,9 +552,9 @@ func finalizeUpdate(
return os.Remove(tmpFileName)
}
log.Printf("saving filter %d contents to: %s", flt.ID, flt.Path())
log.Printf("saving filter %d contents to: %s", flt.ID, flt.Path(d.DataDir))
if err = os.Rename(tmpFileName, flt.Path()); err != nil {
if err = os.Rename(tmpFileName, flt.Path(d.DataDir)); err != nil {
return errors.WithDeferred(err, os.Remove(tmpFileName))
}
@@ -578,12 +568,12 @@ func finalizeUpdate(
// processUpdate copies filter's content from src to dst and returns the name,
// rules number, and checksum for it. It also returns the number of bytes read
// from src.
func (f *Filtering) processUpdate(
func (d *DNSFilter) processUpdate(
src io.Reader,
dst *os.File,
flt *filter,
flt *FilterYAML,
) (name string, rnum int, cs uint32, n int, err error) {
if n, err = f.read(src, dst, flt); err != nil {
if n, err = d.read(src, dst, flt); err != nil {
return "", 0, 0, 0, err
}
@@ -591,14 +581,14 @@ func (f *Filtering) processUpdate(
return "", 0, 0, 0, err
}
rnum, cs, name = f.parseFilterContents(dst)
rnum, cs, name = d.parseFilterContents(dst)
return name, rnum, cs, n, nil
}
// updateIntl updates the flt rewriting it's actual file. It returns true if
// the actual update has been performed.
func (f *Filtering) updateIntl(flt *filter) (ok bool, err error) {
func (d *DNSFilter) updateIntl(flt *FilterYAML) (ok bool, err error) {
log.Tracef("downloading update for filter %d from %s", flt.ID, flt.URL)
var name string
@@ -606,12 +596,12 @@ func (f *Filtering) updateIntl(flt *filter) (ok bool, err error) {
var cs uint32
var tmpFile *os.File
tmpFile, err = os.CreateTemp(filepath.Join(Context.getDataDir(), filterDir), "")
tmpFile, err = os.CreateTemp(filepath.Join(d.DataDir, filterDir), "")
if err != nil {
return false, err
}
defer func() {
err = errors.WithDeferred(err, finalizeUpdate(tmpFile, flt, ok, name, rnum, cs))
err = errors.WithDeferred(err, d.finalizeUpdate(tmpFile, flt, ok, name, rnum, cs))
ok = ok && err == nil
if ok {
log.Printf("updated filter %d: %d bytes, %d rules", flt.ID, n, rnum)
@@ -638,7 +628,7 @@ func (f *Filtering) updateIntl(flt *filter) (ok bool, err error) {
r = file
} else {
var resp *http.Response
resp, err = Context.client.Get(flt.URL)
resp, err = d.HTTPClient.Get(flt.URL)
if err != nil {
log.Printf("requesting filter from %s, skip: %s", flt.URL, err)
@@ -655,16 +645,16 @@ func (f *Filtering) updateIntl(flt *filter) (ok bool, err error) {
r = resp.Body
}
name, rnum, cs, n, err = f.processUpdate(r, tmpFile, flt)
name, rnum, cs, n, err = d.processUpdate(r, tmpFile, flt)
return cs != flt.checksum, err
}
// loads filter contents from the file in dataDir
func (f *Filtering) load(filter *filter) (err error) {
filterFilePath := filter.Path()
func (d *DNSFilter) load(filter *FilterYAML) (err error) {
filterFilePath := filter.Path(d.DataDir)
log.Tracef("filtering: loading filter %d contents to: %s", filter.ID, filterFilePath)
log.Tracef("filtering: loading filter %d from %s", filter.ID, filterFilePath)
file, err := os.Open(filterFilePath)
if errors.Is(err, os.ErrNotExist) {
@@ -682,7 +672,7 @@ func (f *Filtering) load(filter *filter) (err error) {
log.Tracef("filtering: File %s, id %d, length %d", filterFilePath, filter.ID, st.Size())
rulesCount, checksum, _ := f.parseFilterContents(file)
rulesCount, checksum, _ := d.parseFilterContents(file)
filter.RulesCount = rulesCount
filter.checksum = checksum
@@ -691,56 +681,45 @@ func (f *Filtering) load(filter *filter) (err error) {
return nil
}
// Clear filter rules
func (filter *filter) unload() {
filter.RulesCount = 0
filter.checksum = 0
func (d *DNSFilter) EnableFilters(async bool) {
d.filtersMu.RLock()
defer d.filtersMu.RUnlock()
d.enableFiltersLocked(async)
}
// Path to the filter contents
func (filter *filter) Path() string {
return filepath.Join(Context.getDataDir(), filterDir, strconv.FormatInt(filter.ID, 10)+".txt")
}
func enableFilters(async bool) {
config.RLock()
defer config.RUnlock()
enableFiltersLocked(async)
}
func enableFiltersLocked(async bool) {
filters := []filtering.Filter{{
ID: filtering.CustomListID,
Data: []byte(strings.Join(config.UserRules, "\n")),
func (d *DNSFilter) enableFiltersLocked(async bool) {
filters := []Filter{{
ID: CustomListID,
Data: []byte(strings.Join(d.UserRules, "\n")),
}}
for _, filter := range config.Filters {
for _, filter := range d.Filters {
if !filter.Enabled {
continue
}
filters = append(filters, filtering.Filter{
filters = append(filters, Filter{
ID: filter.ID,
FilePath: filter.Path(),
FilePath: filter.Path(d.DataDir),
})
}
var allowFilters []filtering.Filter
for _, filter := range config.WhitelistFilters {
var allowFilters []Filter
for _, filter := range d.WhitelistFilters {
if !filter.Enabled {
continue
}
allowFilters = append(allowFilters, filtering.Filter{
allowFilters = append(allowFilters, Filter{
ID: filter.ID,
FilePath: filter.Path(),
FilePath: filter.Path(d.DataDir),
})
}
if err := Context.dnsFilter.SetFilters(filters, allowFilters, async); err != nil {
if err := d.SetFilters(filters, allowFilters, async); err != nil {
log.Debug("enabling filters: %s", err)
}
Context.dnsFilter.SetEnabled(config.DNS.FilteringEnabled)
d.SetEnabled(d.FilteringEnabled)
}

View File

@@ -1,4 +1,4 @@
package home
package filtering
import (
"io/fs"
@@ -51,15 +51,17 @@ func TestFilters(t *testing.T) {
l := testStartFilterListener(t, &fltContent)
Context = homeContext{
workDir: t.TempDir(),
client: &http.Client{
tempDir := t.TempDir()
filters, err := New(&Config{
DataDir: tempDir,
HTTPClient: &http.Client{
Timeout: 5 * time.Second,
},
}
Context.filters.Init()
}, nil)
require.NoError(t, err)
f := &filter{
f := &FilterYAML{
URL: (&url.URL{
Scheme: "http",
Host: (&netutil.IPPort{
@@ -71,21 +73,22 @@ func TestFilters(t *testing.T) {
}
updateAndAssert := func(t *testing.T, want require.BoolAssertionFunc, wantRulesCount int) {
ok, err := Context.filters.update(f)
var ok bool
ok, err = filters.update(f)
require.NoError(t, err)
want(t, ok)
assert.Equal(t, wantRulesCount, f.RulesCount)
var dir []fs.DirEntry
dir, err = os.ReadDir(filepath.Join(Context.getDataDir(), filterDir))
dir, err = os.ReadDir(filepath.Join(tempDir, filterDir))
require.NoError(t, err)
assert.Len(t, dir, 1)
require.FileExists(t, f.Path())
require.FileExists(t, f.Path(tempDir))
err = Context.filters.load(f)
err = filters.load(f)
require.NoError(t, err)
}
@@ -105,11 +108,9 @@ func TestFilters(t *testing.T) {
})
t.Run("load_unload", func(t *testing.T) {
err := Context.filters.load(f)
err = filters.load(f)
require.NoError(t, err)
f.unload()
})
require.NoError(t, os.Remove(f.Path()))
}

View File

@@ -6,7 +6,10 @@ import (
"fmt"
"io/fs"
"net"
"net/http"
"os"
"path/filepath"
"regexp"
"runtime"
"runtime/debug"
"strings"
@@ -24,6 +27,7 @@ import (
"github.com/AdguardTeam/urlfilter/filterlist"
"github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns"
"golang.org/x/exp/slices"
)
// The IDs of built-in filter lists.
@@ -69,8 +73,13 @@ type Config struct {
// enabled is used to be returned within Settings.
//
// It is of type uint32 to be accessed by atomic.
//
// TODO(e.burkov): Use atomic.Bool in Go 1.19.
enabled uint32
FilteringEnabled bool `yaml:"filtering_enabled"` // whether or not use filter lists
FiltersUpdateIntervalHours uint32 `yaml:"filters_update_interval"` // time period to update filters (in hours)
ParentalEnabled bool `yaml:"parental_enabled"`
SafeSearchEnabled bool `yaml:"safesearch_enabled"`
SafeBrowsingEnabled bool `yaml:"safebrowsing_enabled"`
@@ -98,6 +107,24 @@ type Config struct {
// CustomResolver is the resolver used by DNSFilter.
CustomResolver Resolver `yaml:"-"`
// HTTPClient is the client to use for updating the remote filters.
HTTPClient *http.Client `yaml:"-"`
// DataDir is used to store filters' contents.
DataDir string `yaml:"-"`
// filtersMu protects filter lists.
filtersMu *sync.RWMutex
// Filters are the blocking filter lists.
Filters []FilterYAML `yaml:"-"`
// WhitelistFilters are the allowing filter lists.
WhitelistFilters []FilterYAML `yaml:"-"`
// UserRules is the global list of custom rules.
UserRules []string `yaml:"-"`
}
// LookupStats store stats collected during safebrowsing or parental checks
@@ -128,11 +155,13 @@ type hostChecker struct {
// DNSFilter matches hostnames and DNS requests against filtering rules.
type DNSFilter struct {
rulesStorage *filterlist.RuleStorage
filteringEngine *urlfilter.DNSEngine
rulesStorage *filterlist.RuleStorage
filteringEngine *urlfilter.DNSEngine
rulesStorageAllow *filterlist.RuleStorage
filteringEngineAllow *urlfilter.DNSEngine
engineLock sync.RWMutex
engineLock sync.RWMutex
parentalServer string // access via methods
safeBrowsingServer string // access via methods
@@ -156,6 +185,12 @@ type DNSFilter struct {
// TODO(e.burkov): Use upstream that configured in dnsforward instead.
resolver Resolver
refreshLock *sync.Mutex
// filterTitleRegexp is the regular expression to retrieve a name of a
// filter list.
filterTitleRegexp *regexp.Regexp
hostCheckers []hostChecker
}
@@ -168,7 +203,7 @@ type Filter struct {
Data []byte `yaml:"-"`
// ID is automatically assigned when filter is added using nextFilterID.
ID int64
ID int64 `yaml:"id"`
}
// Reason holds an enum detailing why it was filtered or not filtered
@@ -245,15 +280,7 @@ func (r Reason) String() string {
}
// In returns true if reasons include r.
func (r Reason) In(reasons ...Reason) (ok bool) {
for _, reason := range reasons {
if r == reason {
return true
}
}
return false
}
func (r Reason) In(reasons ...Reason) (ok bool) { return slices.Contains(reasons, r) }
// SetEnabled sets the status of the *DNSFilter.
func (d *DNSFilter) SetEnabled(enabled bool) {
@@ -261,6 +288,7 @@ func (d *DNSFilter) SetEnabled(enabled bool) {
if enabled {
i = 1
}
atomic.StoreUint32(&d.enabled, uint32(i))
}
@@ -279,11 +307,20 @@ func (d *DNSFilter) GetConfig() (s Settings) {
// WriteDiskConfig - write configuration
func (d *DNSFilter) WriteDiskConfig(c *Config) {
d.confLock.Lock()
defer d.confLock.Unlock()
func() {
d.confLock.Lock()
defer d.confLock.Unlock()
*c = d.Config
c.Rewrites = cloneRewrites(c.Rewrites)
*c = d.Config
c.Rewrites = cloneRewrites(c.Rewrites)
}()
d.filtersMu.RLock()
defer d.filtersMu.RUnlock()
c.Filters = slices.Clone(d.Filters)
c.WhitelistFilters = slices.Clone(d.WhitelistFilters)
c.UserRules = slices.Clone(d.UserRules)
}
// cloneRewrites returns a deep copy of entries.
@@ -309,6 +346,8 @@ func (d *DNSFilter) SetFilters(blockFilters, allowFilters []Filter, async bool)
}
d.filtersInitializerLock.Lock() // prevent multiple writers from adding more than 1 task
defer d.filtersInitializerLock.Unlock()
// remove all pending tasks
stop := false
for !stop {
@@ -321,7 +360,6 @@ func (d *DNSFilter) SetFilters(blockFilters, allowFilters []Filter, async bool)
}
d.filtersInitializerChan <- params
d.filtersInitializerLock.Unlock()
return nil
}
@@ -350,22 +388,19 @@ func (d *DNSFilter) filtersInitializer() {
func (d *DNSFilter) Close() {
d.engineLock.Lock()
defer d.engineLock.Unlock()
d.reset()
}
func (d *DNSFilter) reset() {
var err error
if d.rulesStorage != nil {
err = d.rulesStorage.Close()
if err != nil {
if err := d.rulesStorage.Close(); err != nil {
log.Error("filtering: rulesStorage.Close: %s", err)
}
}
if d.rulesStorageAllow != nil {
err = d.rulesStorageAllow.Close()
if err != nil {
if err := d.rulesStorageAllow.Close(); err != nil {
log.Error("filtering: rulesStorageAllow.Close: %s", err)
}
}
@@ -885,29 +920,30 @@ func InitModule() {
initBlockedServices()
}
// New creates properly initialized DNS Filter that is ready to be used.
func New(c *Config, blockFilters []Filter) (d *DNSFilter) {
// New creates properly initialized DNS Filter that is ready to be used. c must
// be non-nil.
func New(c *Config, blockFilters []Filter) (d *DNSFilter, err error) {
d = &DNSFilter{
resolver: net.DefaultResolver,
resolver: net.DefaultResolver,
refreshLock: &sync.Mutex{},
filterTitleRegexp: regexp.MustCompile(`^! Title: +(.*)$`),
}
if c != nil {
d.safebrowsingCache = cache.New(cache.Config{
EnableLRU: true,
MaxSize: c.SafeBrowsingCacheSize,
})
d.safeSearchCache = cache.New(cache.Config{
EnableLRU: true,
MaxSize: c.SafeSearchCacheSize,
})
d.parentalCache = cache.New(cache.Config{
EnableLRU: true,
MaxSize: c.ParentalCacheSize,
})
d.safebrowsingCache = cache.New(cache.Config{
EnableLRU: true,
MaxSize: c.SafeBrowsingCacheSize,
})
d.safeSearchCache = cache.New(cache.Config{
EnableLRU: true,
MaxSize: c.SafeSearchCacheSize,
})
d.parentalCache = cache.New(cache.Config{
EnableLRU: true,
MaxSize: c.ParentalCacheSize,
})
if c.CustomResolver != nil {
d.resolver = c.CustomResolver
}
if r := c.CustomResolver; r != nil {
d.resolver = r
}
d.hostCheckers = []hostChecker{{
@@ -930,27 +966,26 @@ func New(c *Config, blockFilters []Filter) (d *DNSFilter) {
name: "safe search",
}}
err := d.initSecurityServices()
if err != nil {
log.Error("filtering: initialize services: %s", err)
defer func() { err = errors.Annotate(err, "filtering: %w") }()
return nil
err = d.initSecurityServices()
if err != nil {
return nil, fmt.Errorf("initializing services: %s", err)
}
if c != nil {
d.Config = *c
err = d.prepareRewrites()
if err != nil {
log.Error("rewrites: preparing: %s", err)
d.Config = *c
d.filtersMu = &sync.RWMutex{}
return nil
}
err = d.prepareRewrites()
if err != nil {
return nil, fmt.Errorf("rewrites: preparing: %s", err)
}
bsvcs := []string{}
for _, s := range d.BlockedServices {
if !BlockedSvcKnown(s) {
log.Debug("skipping unknown blocked-service %q", s)
continue
}
bsvcs = append(bsvcs, s)
@@ -960,13 +995,24 @@ func New(c *Config, blockFilters []Filter) (d *DNSFilter) {
if blockFilters != nil {
err = d.initFiltering(nil, blockFilters)
if err != nil {
log.Error("Can't initialize filtering subsystem: %s", err)
d.Close()
return nil
return nil, fmt.Errorf("initializing filtering subsystem: %s", err)
}
}
return d
_ = os.MkdirAll(filepath.Join(d.DataDir, filterDir), 0o755)
d.loadFilters(d.Filters)
d.loadFilters(d.WhitelistFilters)
d.Filters = deduplicateFilters(d.Filters)
d.WhitelistFilters = deduplicateFilters(d.WhitelistFilters)
updateUniqueFilterID(d.Filters)
updateUniqueFilterID(d.WhitelistFilters)
return d, nil
}
// Start - start the module:
@@ -976,9 +1022,10 @@ func (d *DNSFilter) Start() {
d.filtersInitializerChan = make(chan filtersInitializerParams, 1)
go d.filtersInitializer()
if d.Config.HTTPRegister != nil { // for tests
d.registerSecurityHandlers()
d.registerRewritesHandlers()
d.registerBlockedServicesHandlers()
}
d.RegisterFilteringHandlers()
// Here we should start updating filters,
// but currently we can't wake up the periodic task to do so.
// So for now we just start this periodic task from here.
go d.periodicallyRefreshFilters()
}

View File

@@ -26,10 +26,6 @@ const (
pcBlocked = "pornhub.com"
)
var setts = Settings{
ProtectionEnabled: true,
}
// Helpers.
func purgeCaches(d *DNSFilter) {
@@ -44,8 +40,8 @@ func purgeCaches(d *DNSFilter) {
}
}
func newForTest(t testing.TB, c *Config, filters []Filter) *DNSFilter {
setts = Settings{
func newForTest(t testing.TB, c *Config, filters []Filter) (f *DNSFilter, setts *Settings) {
setts = &Settings{
ProtectionEnabled: true,
FilteringEnabled: true,
}
@@ -57,26 +53,31 @@ func newForTest(t testing.TB, c *Config, filters []Filter) *DNSFilter {
setts.SafeSearchEnabled = c.SafeSearchEnabled
setts.SafeBrowsingEnabled = c.SafeBrowsingEnabled
setts.ParentalEnabled = c.ParentalEnabled
} else {
// It must not be nil.
c = &Config{}
}
d := New(c, filters)
purgeCaches(d)
f, err := New(c, filters)
require.NoError(t, err)
return d
purgeCaches(f)
return f, setts
}
func (d *DNSFilter) checkMatch(t *testing.T, hostname string) {
func (d *DNSFilter) checkMatch(t *testing.T, hostname string, setts *Settings) {
t.Helper()
res, err := d.CheckHost(hostname, dns.TypeA, &setts)
res, err := d.CheckHost(hostname, dns.TypeA, setts)
require.NoErrorf(t, err, "host %q", hostname)
assert.Truef(t, res.IsFiltered, "host %q", hostname)
}
func (d *DNSFilter) checkMatchIP(t *testing.T, hostname, ip string, qtype uint16) {
func (d *DNSFilter) checkMatchIP(t *testing.T, hostname, ip string, qtype uint16, setts *Settings) {
t.Helper()
res, err := d.CheckHost(hostname, qtype, &setts)
res, err := d.CheckHost(hostname, qtype, setts)
require.NoErrorf(t, err, "host %q", hostname, err)
require.NotEmpty(t, res.Rules, "host %q", hostname)
@@ -88,10 +89,10 @@ func (d *DNSFilter) checkMatchIP(t *testing.T, hostname, ip string, qtype uint16
assert.Equalf(t, ip, r.IP.String(), "host %q", hostname)
}
func (d *DNSFilter) checkMatchEmpty(t *testing.T, hostname string) {
func (d *DNSFilter) checkMatchEmpty(t *testing.T, hostname string, setts *Settings) {
t.Helper()
res, err := d.CheckHost(hostname, dns.TypeA, &setts)
res, err := d.CheckHost(hostname, dns.TypeA, setts)
require.NoErrorf(t, err, "host %q", hostname)
assert.Falsef(t, res.IsFiltered, "host %q", hostname)
@@ -111,19 +112,19 @@ func TestEtcHostsMatching(t *testing.T) {
filters := []Filter{{
ID: 0, Data: []byte(text),
}}
d := newForTest(t, nil, filters)
d, setts := newForTest(t, nil, filters)
t.Cleanup(d.Close)
d.checkMatchIP(t, "google.com", addr, dns.TypeA)
d.checkMatchIP(t, "www.google.com", addr, dns.TypeA)
d.checkMatchEmpty(t, "subdomain.google.com")
d.checkMatchEmpty(t, "example.org")
d.checkMatchIP(t, "google.com", addr, dns.TypeA, setts)
d.checkMatchIP(t, "www.google.com", addr, dns.TypeA, setts)
d.checkMatchEmpty(t, "subdomain.google.com", setts)
d.checkMatchEmpty(t, "example.org", setts)
// IPv4 match.
d.checkMatchIP(t, "block.com", "0.0.0.0", dns.TypeA)
d.checkMatchIP(t, "block.com", "0.0.0.0", dns.TypeA, setts)
// Empty IPv6.
res, err := d.CheckHost("block.com", dns.TypeAAAA, &setts)
res, err := d.CheckHost("block.com", dns.TypeAAAA, setts)
require.NoError(t, err)
assert.True(t, res.IsFiltered)
@@ -134,10 +135,10 @@ func TestEtcHostsMatching(t *testing.T) {
assert.Empty(t, res.Rules[0].IP)
// IPv6 match.
d.checkMatchIP(t, "ipv6.com", addr6, dns.TypeAAAA)
d.checkMatchIP(t, "ipv6.com", addr6, dns.TypeAAAA, setts)
// Empty IPv4.
res, err = d.CheckHost("ipv6.com", dns.TypeA, &setts)
res, err = d.CheckHost("ipv6.com", dns.TypeA, setts)
require.NoError(t, err)
assert.True(t, res.IsFiltered)
@@ -148,7 +149,7 @@ func TestEtcHostsMatching(t *testing.T) {
assert.Empty(t, res.Rules[0].IP)
// Two IPv4, both must be returned.
res, err = d.CheckHost("host2", dns.TypeA, &setts)
res, err = d.CheckHost("host2", dns.TypeA, setts)
require.NoError(t, err)
assert.True(t, res.IsFiltered)
@@ -159,7 +160,7 @@ func TestEtcHostsMatching(t *testing.T) {
assert.Equal(t, res.Rules[1].IP, net.IP{0, 0, 0, 2})
// One IPv6 address.
res, err = d.CheckHost("host2", dns.TypeAAAA, &setts)
res, err = d.CheckHost("host2", dns.TypeAAAA, setts)
require.NoError(t, err)
assert.True(t, res.IsFiltered)
@@ -176,27 +177,27 @@ func TestSafeBrowsing(t *testing.T) {
aghtest.ReplaceLogWriter(t, logOutput)
aghtest.ReplaceLogLevel(t, log.DEBUG)
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
d, setts := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
t.Cleanup(d.Close)
d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
d.checkMatch(t, sbBlocked)
d.checkMatch(t, sbBlocked, setts)
require.Contains(t, logOutput.String(), fmt.Sprintf("safebrowsing lookup for %q", sbBlocked))
d.checkMatch(t, "test."+sbBlocked)
d.checkMatchEmpty(t, "yandex.ru")
d.checkMatchEmpty(t, pcBlocked)
d.checkMatch(t, "test."+sbBlocked, setts)
d.checkMatchEmpty(t, "yandex.ru", setts)
d.checkMatchEmpty(t, pcBlocked, setts)
// Cached result.
d.safeBrowsingServer = "127.0.0.1"
d.checkMatch(t, sbBlocked)
d.checkMatchEmpty(t, pcBlocked)
d.checkMatch(t, sbBlocked, setts)
d.checkMatchEmpty(t, pcBlocked, setts)
d.safeBrowsingServer = defaultSafebrowsingServer
}
func TestParallelSB(t *testing.T) {
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
d, setts := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
t.Cleanup(d.Close)
d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
@@ -205,10 +206,10 @@ func TestParallelSB(t *testing.T) {
for i := 0; i < 100; i++ {
t.Run(fmt.Sprintf("aaa%d", i), func(t *testing.T) {
t.Parallel()
d.checkMatch(t, sbBlocked)
d.checkMatch(t, "test."+sbBlocked)
d.checkMatchEmpty(t, "yandex.ru")
d.checkMatchEmpty(t, pcBlocked)
d.checkMatch(t, sbBlocked, setts)
d.checkMatch(t, "test."+sbBlocked, setts)
d.checkMatchEmpty(t, "yandex.ru", setts)
d.checkMatchEmpty(t, pcBlocked, setts)
})
}
})
@@ -217,7 +218,7 @@ func TestParallelSB(t *testing.T) {
// Safe Search.
func TestSafeSearch(t *testing.T) {
d := newForTest(t, &Config{SafeSearchEnabled: true}, nil)
d, _ := newForTest(t, &Config{SafeSearchEnabled: true}, nil)
t.Cleanup(d.Close)
val, ok := d.SafeSearchDomain("www.google.com")
require.True(t, ok)
@@ -226,7 +227,7 @@ func TestSafeSearch(t *testing.T) {
}
func TestCheckHostSafeSearchYandex(t *testing.T) {
d := newForTest(t, &Config{
d, setts := newForTest(t, &Config{
SafeSearchEnabled: true,
}, nil)
t.Cleanup(d.Close)
@@ -243,7 +244,7 @@ func TestCheckHostSafeSearchYandex(t *testing.T) {
"www.yandex.com",
} {
t.Run(strings.ToLower(host), func(t *testing.T) {
res, err := d.CheckHost(host, dns.TypeA, &setts)
res, err := d.CheckHost(host, dns.TypeA, setts)
require.NoError(t, err)
assert.True(t, res.IsFiltered)
@@ -258,7 +259,7 @@ func TestCheckHostSafeSearchYandex(t *testing.T) {
func TestCheckHostSafeSearchGoogle(t *testing.T) {
resolver := &aghtest.TestResolver{}
d := newForTest(t, &Config{
d, setts := newForTest(t, &Config{
SafeSearchEnabled: true,
CustomResolver: resolver,
}, nil)
@@ -277,7 +278,7 @@ func TestCheckHostSafeSearchGoogle(t *testing.T) {
"www.google.je",
} {
t.Run(host, func(t *testing.T) {
res, err := d.CheckHost(host, dns.TypeA, &setts)
res, err := d.CheckHost(host, dns.TypeA, setts)
require.NoError(t, err)
assert.True(t, res.IsFiltered)
@@ -291,12 +292,12 @@ func TestCheckHostSafeSearchGoogle(t *testing.T) {
}
func TestSafeSearchCacheYandex(t *testing.T) {
d := newForTest(t, nil, nil)
d, setts := newForTest(t, nil, nil)
t.Cleanup(d.Close)
const domain = "yandex.ru"
// Check host with disabled safesearch.
res, err := d.CheckHost(domain, dns.TypeA, &setts)
res, err := d.CheckHost(domain, dns.TypeA, setts)
require.NoError(t, err)
assert.False(t, res.IsFiltered)
@@ -305,10 +306,10 @@ func TestSafeSearchCacheYandex(t *testing.T) {
yandexIP := net.IPv4(213, 180, 193, 56)
d = newForTest(t, &Config{SafeSearchEnabled: true}, nil)
d, setts = newForTest(t, &Config{SafeSearchEnabled: true}, nil)
t.Cleanup(d.Close)
res, err = d.CheckHost(domain, dns.TypeA, &setts)
res, err = d.CheckHost(domain, dns.TypeA, setts)
require.NoError(t, err)
// For yandex we already know valid IP.
@@ -325,20 +326,20 @@ func TestSafeSearchCacheYandex(t *testing.T) {
func TestSafeSearchCacheGoogle(t *testing.T) {
resolver := &aghtest.TestResolver{}
d := newForTest(t, &Config{
d, setts := newForTest(t, &Config{
CustomResolver: resolver,
}, nil)
t.Cleanup(d.Close)
const domain = "www.google.ru"
res, err := d.CheckHost(domain, dns.TypeA, &setts)
res, err := d.CheckHost(domain, dns.TypeA, setts)
require.NoError(t, err)
assert.False(t, res.IsFiltered)
require.Empty(t, res.Rules)
d = newForTest(t, &Config{SafeSearchEnabled: true}, nil)
d, setts = newForTest(t, &Config{SafeSearchEnabled: true}, nil)
t.Cleanup(d.Close)
d.resolver = resolver
@@ -358,7 +359,7 @@ func TestSafeSearchCacheGoogle(t *testing.T) {
}
}
res, err = d.CheckHost(domain, dns.TypeA, &setts)
res, err = d.CheckHost(domain, dns.TypeA, setts)
require.NoError(t, err)
require.Len(t, res.Rules, 1)
@@ -379,22 +380,22 @@ func TestParentalControl(t *testing.T) {
aghtest.ReplaceLogWriter(t, logOutput)
aghtest.ReplaceLogLevel(t, log.DEBUG)
d := newForTest(t, &Config{ParentalEnabled: true}, nil)
d, setts := newForTest(t, &Config{ParentalEnabled: true}, nil)
t.Cleanup(d.Close)
d.SetParentalUpstream(aghtest.NewBlockUpstream(pcBlocked, true))
d.checkMatch(t, pcBlocked)
d.checkMatch(t, pcBlocked, setts)
require.Contains(t, logOutput.String(), fmt.Sprintf("parental lookup for %q", pcBlocked))
d.checkMatch(t, "www."+pcBlocked)
d.checkMatchEmpty(t, "www.yandex.ru")
d.checkMatchEmpty(t, "yandex.ru")
d.checkMatchEmpty(t, "api.jquery.com")
d.checkMatch(t, "www."+pcBlocked, setts)
d.checkMatchEmpty(t, "www.yandex.ru", setts)
d.checkMatchEmpty(t, "yandex.ru", setts)
d.checkMatchEmpty(t, "api.jquery.com", setts)
// Test cached result.
d.parentalServer = "127.0.0.1"
d.checkMatch(t, pcBlocked)
d.checkMatchEmpty(t, "yandex.ru")
d.checkMatch(t, pcBlocked, setts)
d.checkMatchEmpty(t, "yandex.ru", setts)
}
// Filtering.
@@ -679,10 +680,10 @@ func TestMatching(t *testing.T) {
for _, tc := range testCases {
t.Run(fmt.Sprintf("%s-%s", tc.name, tc.host), func(t *testing.T) {
filters := []Filter{{ID: 0, Data: []byte(tc.rules)}}
d := newForTest(t, nil, filters)
d, setts := newForTest(t, nil, filters)
t.Cleanup(d.Close)
res, err := d.CheckHost(tc.host, tc.wantDNSType, &setts)
res, err := d.CheckHost(tc.host, tc.wantDNSType, setts)
require.NoError(t, err)
assert.Equalf(t, tc.wantIsFiltered, res.IsFiltered, "Hostname %s has wrong result (%v must be %v)", tc.host, res.IsFiltered, tc.wantIsFiltered)
@@ -705,7 +706,7 @@ func TestWhitelist(t *testing.T) {
whiteFilters := []Filter{{
ID: 0, Data: []byte(whiteRules),
}}
d := newForTest(t, nil, filters)
d, setts := newForTest(t, nil, filters)
err := d.SetFilters(filters, whiteFilters, false)
require.NoError(t, err)
@@ -713,7 +714,7 @@ func TestWhitelist(t *testing.T) {
t.Cleanup(d.Close)
// Matched by white filter.
res, err := d.CheckHost("host1", dns.TypeA, &setts)
res, err := d.CheckHost("host1", dns.TypeA, setts)
require.NoError(t, err)
assert.False(t, res.IsFiltered)
@@ -724,7 +725,7 @@ func TestWhitelist(t *testing.T) {
assert.Equal(t, "||host1^", res.Rules[0].Text)
// Not matched by white filter, but matched by block filter.
res, err = d.CheckHost("host2", dns.TypeA, &setts)
res, err = d.CheckHost("host2", dns.TypeA, setts)
require.NoError(t, err)
assert.True(t, res.IsFiltered)
@@ -750,7 +751,7 @@ func applyClientSettings(setts *Settings) {
}
func TestClientSettings(t *testing.T) {
d := newForTest(t,
d, setts := newForTest(t,
&Config{
ParentalEnabled: true,
SafeBrowsingEnabled: false,
@@ -796,7 +797,7 @@ func TestClientSettings(t *testing.T) {
return func(t *testing.T) {
t.Helper()
r, err := d.CheckHost(tc.host, dns.TypeA, &setts)
r, err := d.CheckHost(tc.host, dns.TypeA, setts)
require.NoError(t, err)
if before {
@@ -814,7 +815,7 @@ func TestClientSettings(t *testing.T) {
t.Run(tc.name, makeTester(tc, tc.before))
}
applyClientSettings(&setts)
applyClientSettings(setts)
for _, tc := range testCases {
t.Run(tc.name, makeTester(tc, !tc.before))
@@ -824,13 +825,13 @@ func TestClientSettings(t *testing.T) {
// Benchmarks.
func BenchmarkSafeBrowsing(b *testing.B) {
d := newForTest(b, &Config{SafeBrowsingEnabled: true}, nil)
d, setts := newForTest(b, &Config{SafeBrowsingEnabled: true}, nil)
b.Cleanup(d.Close)
d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
for n := 0; n < b.N; n++ {
res, err := d.CheckHost(sbBlocked, dns.TypeA, &setts)
res, err := d.CheckHost(sbBlocked, dns.TypeA, setts)
require.NoError(b, err)
assert.Truef(b, res.IsFiltered, "expected hostname %q to match", sbBlocked)
@@ -838,14 +839,14 @@ func BenchmarkSafeBrowsing(b *testing.B) {
}
func BenchmarkSafeBrowsingParallel(b *testing.B) {
d := newForTest(b, &Config{SafeBrowsingEnabled: true}, nil)
d, setts := newForTest(b, &Config{SafeBrowsingEnabled: true}, nil)
b.Cleanup(d.Close)
d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
res, err := d.CheckHost(sbBlocked, dns.TypeA, &setts)
res, err := d.CheckHost(sbBlocked, dns.TypeA, setts)
require.NoError(b, err)
assert.Truef(b, res.IsFiltered, "expected hostname %q to match", sbBlocked)
@@ -854,7 +855,7 @@ func BenchmarkSafeBrowsingParallel(b *testing.B) {
}
func BenchmarkSafeSearch(b *testing.B) {
d := newForTest(b, &Config{SafeSearchEnabled: true}, nil)
d, _ := newForTest(b, &Config{SafeSearchEnabled: true}, nil)
b.Cleanup(d.Close)
for n := 0; n < b.N; n++ {
val, ok := d.SafeSearchDomain("www.google.com")
@@ -865,7 +866,7 @@ func BenchmarkSafeSearch(b *testing.B) {
}
func BenchmarkSafeSearchParallel(b *testing.B) {
d := newForTest(b, &Config{SafeSearchEnabled: true}, nil)
d, _ := newForTest(b, &Config{SafeSearchEnabled: true}, nil)
b.Cleanup(d.Close)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {

View File

@@ -1,15 +1,13 @@
package home
package filtering
import (
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
@@ -34,7 +32,7 @@ func validateFilterURL(urlStr string) (err error) {
return fmt.Errorf("checking filter url: %w", err)
}
if s := url.Scheme; s != schemeHTTP && s != schemeHTTPS {
if s := url.Scheme; s != aghhttp.SchemeHTTP && s != aghhttp.SchemeHTTPS {
return fmt.Errorf("checking filter url: invalid scheme %q", s)
}
@@ -47,7 +45,7 @@ type filterAddJSON struct {
Whitelist bool `json:"whitelist"`
}
func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request) {
func (d *DNSFilter) handleFilteringAddURL(w http.ResponseWriter, r *http.Request) {
fj := filterAddJSON{}
err := json.NewDecoder(r.Body).Decode(&fj)
if err != nil {
@@ -65,14 +63,14 @@ func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request
}
// Check for duplicates
if filterExists(fj.URL) {
if d.filterExists(fj.URL) {
aghhttp.Error(r, w, http.StatusBadRequest, "Filter URL already added -- %s", fj.URL)
return
}
// Set necessary properties
filt := filter{
filt := FilterYAML{
Enabled: true,
URL: fj.URL,
Name: fj.Name,
@@ -81,7 +79,7 @@ func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request
filt.ID = assignUniqueFilterID()
// Download the filter contents
ok, err := f.update(&filt)
ok, err := d.update(&filt)
if err != nil {
aghhttp.Error(
r,
@@ -109,14 +107,14 @@ func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request
// URL is assumed valid so append it to filters, update config, write new
// file and reload it to engines.
if !filterAdd(filt) {
if !d.filterAdd(filt) {
aghhttp.Error(r, w, http.StatusBadRequest, "Filter URL already added -- %s", filt.URL)
return
}
onConfigModified()
enableFilters(true)
d.ConfigModified()
d.EnableFilters(true)
_, err = fmt.Fprintf(w, "OK %d rules\n", filt.RulesCount)
if err != nil {
@@ -124,7 +122,7 @@ func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request
}
}
func (f *Filtering) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Request) {
func (d *DNSFilter) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Request) {
type request struct {
URL string `json:"url"`
Whitelist bool `json:"whitelist"`
@@ -138,23 +136,23 @@ func (f *Filtering) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Requ
return
}
config.Lock()
filters := &config.Filters
d.filtersMu.Lock()
filters := &d.Filters
if req.Whitelist {
filters = &config.WhitelistFilters
filters = &d.WhitelistFilters
}
var deleted filter
var newFilters []filter
for _, f := range *filters {
if f.URL != req.URL {
newFilters = append(newFilters, f)
var deleted FilterYAML
var newFilters []FilterYAML
for _, flt := range *filters {
if flt.URL != req.URL {
newFilters = append(newFilters, flt)
continue
}
deleted = f
path := f.Path()
deleted = flt
path := flt.Path(d.DataDir)
err = os.Rename(path, path+".old")
if err != nil {
log.Error("deleting filter %q: %s", path, err)
@@ -162,10 +160,10 @@ func (f *Filtering) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Requ
}
*filters = newFilters
config.Unlock()
d.filtersMu.Unlock()
onConfigModified()
enableFilters(true)
d.ConfigModified()
d.EnableFilters(true)
// NOTE: The old files "filter.txt.old" aren't deleted. It's not really
// necessary, but will require the additional complicated code to run
@@ -191,55 +189,51 @@ type filterURLReq struct {
Whitelist bool `json:"whitelist"`
}
func (f *Filtering) handleFilteringSetURL(w http.ResponseWriter, r *http.Request) {
func (d *DNSFilter) handleFilteringSetURL(w http.ResponseWriter, r *http.Request) {
fj := filterURLReq{}
err := json.NewDecoder(r.Body).Decode(&fj)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "json decode: %s", err)
aghhttp.Error(r, w, http.StatusBadRequest, "decoding request: %s", err)
return
}
if fj.Data == nil {
err = errors.Error("data cannot be null")
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
aghhttp.Error(r, w, http.StatusBadRequest, "%s", errors.Error("data is absent"))
return
}
err = validateFilterURL(fj.Data.URL)
if err != nil {
err = fmt.Errorf("invalid url: %s", err)
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
aghhttp.Error(r, w, http.StatusBadRequest, "invalid url: %s", err)
return
}
filt := filter{
filt := FilterYAML{
Enabled: fj.Data.Enabled,
Name: fj.Data.Name,
URL: fj.Data.URL,
}
status := f.filterSetProperties(fj.URL, filt, fj.Whitelist)
status := d.filterSetProperties(fj.URL, filt, fj.Whitelist)
if (status & statusFound) == 0 {
http.Error(w, "URL doesn't exist", http.StatusBadRequest)
aghhttp.Error(r, w, http.StatusBadRequest, "URL doesn't exist")
return
}
if (status & statusURLExists) != 0 {
http.Error(w, "URL already exists", http.StatusBadRequest)
aghhttp.Error(r, w, http.StatusBadRequest, "URL already exists")
return
}
onConfigModified()
d.ConfigModified()
restart := (status & statusEnabledChanged) != 0
if (status&statusUpdateRequired) != 0 && fj.Data.Enabled {
// download new filter and apply its rules
flags := filterRefreshBlocklists
if fj.Whitelist {
flags = filterRefreshAllowlists
}
nUpdated, _ := f.refreshFilters(flags, true)
// download new filter and apply its rules.
nUpdated := d.refreshFilters(!fj.Whitelist, fj.Whitelist, false)
// if at least 1 filter has been updated, refreshFilters() restarts the filtering automatically
// if not - we restart the filtering ourselves
restart = false
@@ -249,25 +243,34 @@ func (f *Filtering) handleFilteringSetURL(w http.ResponseWriter, r *http.Request
}
if restart {
enableFilters(true)
d.EnableFilters(true)
}
}
func (f *Filtering) handleFilteringSetRules(w http.ResponseWriter, r *http.Request) {
// This use of ReadAll is safe, because request's body is now limited.
body, err := io.ReadAll(r.Body)
// filteringRulesReq is the JSON structure for settings custom filtering rules.
type filteringRulesReq struct {
Rules []string `json:"rules"`
}
func (d *DNSFilter) handleFilteringSetRules(w http.ResponseWriter, r *http.Request) {
if aghhttp.WriteTextPlainDeprecated(w, r) {
return
}
req := &filteringRulesReq{}
err := json.NewDecoder(r.Body).Decode(req)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "Failed to read request body: %s", err)
aghhttp.Error(r, w, http.StatusBadRequest, "reading req: %s", err)
return
}
config.UserRules = strings.Split(string(body), "\n")
onConfigModified()
enableFilters(true)
d.UserRules = req.Rules
d.ConfigModified()
d.EnableFilters(true)
}
func (f *Filtering) handleFilteringRefresh(w http.ResponseWriter, r *http.Request) {
func (d *DNSFilter) handleFilteringRefresh(w http.ResponseWriter, r *http.Request) {
type Req struct {
White bool `json:"whitelist"`
}
@@ -285,35 +288,27 @@ func (f *Filtering) handleFilteringRefresh(w http.ResponseWriter, r *http.Reques
return
}
flags := filterRefreshBlocklists
if req.White {
flags = filterRefreshAllowlists
}
func() {
// Temporarily unlock the Context.controlLock because the
// f.refreshFilters waits for it to be unlocked but it's
// actually locked in ensure wrapper.
//
// TODO(e.burkov): Reconsider this messy syncing process.
Context.controlLock.Unlock()
defer Context.controlLock.Lock()
resp.Updated, err = f.refreshFilters(flags|filterRefreshForce, false)
}()
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
var ok bool
resp.Updated, _, ok = d.tryRefreshFilters(!req.White, req.White, true)
if !ok {
aghhttp.Error(
r,
w,
http.StatusInternalServerError,
"filters update procedure is already running",
)
return
}
js, err := json.Marshal(resp)
w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(resp)
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "json encode: %s", err)
return
}
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write(js)
}
type filterJSON struct {
@@ -333,7 +328,7 @@ type filteringConfig struct {
Enabled bool `json:"enabled"`
}
func filterToJSON(f filter) filterJSON {
func filterToJSON(f FilterYAML) filterJSON {
fj := filterJSON{
ID: f.ID,
Enabled: f.Enabled,
@@ -350,21 +345,21 @@ func filterToJSON(f filter) filterJSON {
}
// Get filtering configuration
func (f *Filtering) handleFilteringStatus(w http.ResponseWriter, r *http.Request) {
func (d *DNSFilter) handleFilteringStatus(w http.ResponseWriter, r *http.Request) {
resp := filteringConfig{}
config.RLock()
resp.Enabled = config.DNS.FilteringEnabled
resp.Interval = config.DNS.FiltersUpdateIntervalHours
for _, f := range config.Filters {
d.filtersMu.RLock()
resp.Enabled = d.FilteringEnabled
resp.Interval = d.FiltersUpdateIntervalHours
for _, f := range d.Filters {
fj := filterToJSON(f)
resp.Filters = append(resp.Filters, fj)
}
for _, f := range config.WhitelistFilters {
for _, f := range d.WhitelistFilters {
fj := filterToJSON(f)
resp.WhitelistFilters = append(resp.WhitelistFilters, fj)
}
resp.UserRules = config.UserRules
config.RUnlock()
resp.UserRules = d.UserRules
d.filtersMu.RUnlock()
jsonVal, err := json.Marshal(resp)
if err != nil {
@@ -380,7 +375,7 @@ func (f *Filtering) handleFilteringStatus(w http.ResponseWriter, r *http.Request
}
// Set filtering configuration
func (f *Filtering) handleFilteringConfig(w http.ResponseWriter, r *http.Request) {
func (d *DNSFilter) handleFilteringConfig(w http.ResponseWriter, r *http.Request) {
req := filteringConfig{}
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
@@ -389,22 +384,22 @@ func (f *Filtering) handleFilteringConfig(w http.ResponseWriter, r *http.Request
return
}
if !checkFiltersUpdateIntervalHours(req.Interval) {
if !ValidateUpdateIvl(req.Interval) {
aghhttp.Error(r, w, http.StatusBadRequest, "Unsupported interval")
return
}
func() {
config.Lock()
defer config.Unlock()
d.filtersMu.Lock()
defer d.filtersMu.Unlock()
config.DNS.FilteringEnabled = req.Enabled
config.DNS.FiltersUpdateIntervalHours = req.Interval
d.FilteringEnabled = req.Enabled
d.FiltersUpdateIntervalHours = req.Interval
}()
onConfigModified()
enableFilters(true)
d.ConfigModified()
d.EnableFilters(true)
}
type checkHostRespRule struct {
@@ -435,15 +430,15 @@ type checkHostResp struct {
FilterID int64 `json:"filter_id"`
}
func (f *Filtering) handleCheckHost(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query()
host := q.Get("name")
func (d *DNSFilter) handleCheckHost(w http.ResponseWriter, r *http.Request) {
host := r.URL.Query().Get("name")
setts := Context.dnsFilter.GetConfig()
setts := d.GetConfig()
setts.FilteringEnabled = true
setts.ProtectionEnabled = true
Context.dnsFilter.ApplyBlockedServices(&setts, nil, true)
result, err := Context.dnsFilter.CheckHost(host, dns.TypeA, &setts)
d.ApplyBlockedServices(&setts, nil)
result, err := d.CheckHost(host, dns.TypeA, &setts)
if err != nil {
aghhttp.Error(
r,
@@ -457,18 +452,20 @@ func (f *Filtering) handleCheckHost(w http.ResponseWriter, r *http.Request) {
return
}
resp := checkHostResp{}
resp.Reason = result.Reason.String()
resp.SvcName = result.ServiceName
resp.CanonName = result.CanonName
resp.IPList = result.IPList
rulesLen := len(result.Rules)
resp := checkHostResp{
Reason: result.Reason.String(),
SvcName: result.ServiceName,
CanonName: result.CanonName,
IPList: result.IPList,
Rules: make([]*checkHostRespRule, len(result.Rules)),
}
if len(result.Rules) > 0 {
if rulesLen > 0 {
resp.FilterID = result.Rules[0].FilterListID
resp.Rule = result.Rules[0].Text
}
resp.Rules = make([]*checkHostRespRule, len(result.Rules))
for i, r := range result.Rules {
resp.Rules[i] = &checkHostRespRule{
FilterListID: r.FilterListID,
@@ -476,28 +473,51 @@ func (f *Filtering) handleCheckHost(w http.ResponseWriter, r *http.Request) {
}
}
js, err := json.Marshal(resp)
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "json encode: %s", err)
return
}
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write(js)
err = json.NewEncoder(w).Encode(resp)
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "encoding response: %s", err)
}
}
// RegisterFilteringHandlers - register handlers
func (f *Filtering) RegisterFilteringHandlers() {
httpRegister(http.MethodGet, "/control/filtering/status", f.handleFilteringStatus)
httpRegister(http.MethodPost, "/control/filtering/config", f.handleFilteringConfig)
httpRegister(http.MethodPost, "/control/filtering/add_url", f.handleFilteringAddURL)
httpRegister(http.MethodPost, "/control/filtering/remove_url", f.handleFilteringRemoveURL)
httpRegister(http.MethodPost, "/control/filtering/set_url", f.handleFilteringSetURL)
httpRegister(http.MethodPost, "/control/filtering/refresh", f.handleFilteringRefresh)
httpRegister(http.MethodPost, "/control/filtering/set_rules", f.handleFilteringSetRules)
httpRegister(http.MethodGet, "/control/filtering/check_host", f.handleCheckHost)
func (d *DNSFilter) RegisterFilteringHandlers() {
registerHTTP := d.HTTPRegister
if registerHTTP == nil {
return
}
registerHTTP(http.MethodPost, "/control/safebrowsing/enable", d.handleSafeBrowsingEnable)
registerHTTP(http.MethodPost, "/control/safebrowsing/disable", d.handleSafeBrowsingDisable)
registerHTTP(http.MethodGet, "/control/safebrowsing/status", d.handleSafeBrowsingStatus)
registerHTTP(http.MethodPost, "/control/parental/enable", d.handleParentalEnable)
registerHTTP(http.MethodPost, "/control/parental/disable", d.handleParentalDisable)
registerHTTP(http.MethodGet, "/control/parental/status", d.handleParentalStatus)
registerHTTP(http.MethodPost, "/control/safesearch/enable", d.handleSafeSearchEnable)
registerHTTP(http.MethodPost, "/control/safesearch/disable", d.handleSafeSearchDisable)
registerHTTP(http.MethodGet, "/control/safesearch/status", d.handleSafeSearchStatus)
registerHTTP(http.MethodGet, "/control/rewrite/list", d.handleRewriteList)
registerHTTP(http.MethodPost, "/control/rewrite/add", d.handleRewriteAdd)
registerHTTP(http.MethodPost, "/control/rewrite/delete", d.handleRewriteDelete)
registerHTTP(http.MethodGet, "/control/blocked_services/services", d.handleBlockedServicesAvailableServices)
registerHTTP(http.MethodGet, "/control/blocked_services/list", d.handleBlockedServicesList)
registerHTTP(http.MethodPost, "/control/blocked_services/set", d.handleBlockedServicesSet)
registerHTTP(http.MethodGet, "/control/filtering/status", d.handleFilteringStatus)
registerHTTP(http.MethodPost, "/control/filtering/config", d.handleFilteringConfig)
registerHTTP(http.MethodPost, "/control/filtering/add_url", d.handleFilteringAddURL)
registerHTTP(http.MethodPost, "/control/filtering/remove_url", d.handleFilteringRemoveURL)
registerHTTP(http.MethodPost, "/control/filtering/set_url", d.handleFilteringSetURL)
registerHTTP(http.MethodPost, "/control/filtering/refresh", d.handleFilteringRefresh)
registerHTTP(http.MethodPost, "/control/filtering/set_rules", d.handleFilteringSetRules)
registerHTTP(http.MethodGet, "/control/filtering/check_host", d.handleCheckHost)
}
func checkFiltersUpdateIntervalHours(i uint32) bool {
// ValidateUpdateIvl returns false if i is not a valid filters update interval.
func ValidateUpdateIvl(i uint32) bool {
return i == 0 || i == 1 || i == 12 || i == 1*24 || i == 3*24 || i == 7*24
}

View File

@@ -133,34 +133,31 @@ func matchDomainWildcard(host, wildcard string) (ok bool) {
// 1. A and AAAA > CNAME
// 2. wildcard > exact
// 3. lower level wildcard > higher level wildcard
//
// TODO(a.garipov): Replace with slices.Sort.
type rewritesSorted []*LegacyRewrite
// Len implements the sort.Interface interface for legacyRewritesSorted.
// Len implements the sort.Interface interface for rewritesSorted.
func (a rewritesSorted) Len() (l int) { return len(a) }
// Swap implements the sort.Interface interface for legacyRewritesSorted.
// Swap implements the sort.Interface interface for rewritesSorted.
func (a rewritesSorted) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
// Less implements the sort.Interface interface for legacyRewritesSorted.
// Less implements the sort.Interface interface for rewritesSorted.
func (a rewritesSorted) Less(i, j int) (less bool) {
if a[i].Type == dns.TypeCNAME && a[j].Type != dns.TypeCNAME {
ith, jth := a[i], a[j]
if ith.Type == dns.TypeCNAME && jth.Type != dns.TypeCNAME {
return true
} else if a[i].Type != dns.TypeCNAME && a[j].Type == dns.TypeCNAME {
} else if ith.Type != dns.TypeCNAME && jth.Type == dns.TypeCNAME {
return false
}
if isWildcard(a[i].Domain) {
if !isWildcard(a[j].Domain) {
return false
}
} else {
if isWildcard(a[j].Domain) {
return true
}
if iw, jw := isWildcard(ith.Domain), isWildcard(jth.Domain); iw != jw {
return jw
}
// Both are wildcards.
return len(a[i].Domain) > len(a[j].Domain)
// Both are either wildcards or not.
return len(ith.Domain) > len(jth.Domain)
}
// prepareRewrites normalizes and validates all legacy DNS rewrites.
@@ -313,9 +310,3 @@ func (d *DNSFilter) handleRewriteDelete(w http.ResponseWriter, r *http.Request)
d.Config.ConfigModified()
}
func (d *DNSFilter) registerRewritesHandlers() {
d.Config.HTTPRegister(http.MethodGet, "/control/rewrite/list", d.handleRewriteList)
d.Config.HTTPRegister(http.MethodPost, "/control/rewrite/add", d.handleRewriteAdd)
d.Config.HTTPRegister(http.MethodPost, "/control/rewrite/delete", d.handleRewriteDelete)
}

View File

@@ -12,7 +12,7 @@ import (
// TODO(e.burkov): All the tests in this file may and should me merged together.
func TestRewrites(t *testing.T) {
d := newForTest(t, nil, nil)
d, _ := newForTest(t, nil, nil)
t.Cleanup(d.Close)
d.Rewrites = []*LegacyRewrite{{
@@ -188,7 +188,7 @@ func TestRewrites(t *testing.T) {
}
func TestRewritesLevels(t *testing.T) {
d := newForTest(t, nil, nil)
d, _ := newForTest(t, nil, nil)
t.Cleanup(d.Close)
// Exact host, wildcard L2, wildcard L3.
d.Rewrites = []*LegacyRewrite{{
@@ -235,7 +235,7 @@ func TestRewritesLevels(t *testing.T) {
}
func TestRewritesExceptionCNAME(t *testing.T) {
d := newForTest(t, nil, nil)
d, _ := newForTest(t, nil, nil)
t.Cleanup(d.Close)
// Wildcard and exception for a sub-domain.
d.Rewrites = []*LegacyRewrite{{
@@ -286,7 +286,7 @@ func TestRewritesExceptionCNAME(t *testing.T) {
}
func TestRewritesExceptionIP(t *testing.T) {
d := newForTest(t, nil, nil)
d, _ := newForTest(t, nil, nil)
t.Cleanup(d.Close)
// Exception for AAAA record.
d.Rewrites = []*LegacyRewrite{{

View File

@@ -415,17 +415,3 @@ func (d *DNSFilter) handleParentalStatus(w http.ResponseWriter, r *http.Request)
aghhttp.Error(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err)
}
}
func (d *DNSFilter) registerSecurityHandlers() {
d.Config.HTTPRegister(http.MethodPost, "/control/safebrowsing/enable", d.handleSafeBrowsingEnable)
d.Config.HTTPRegister(http.MethodPost, "/control/safebrowsing/disable", d.handleSafeBrowsingDisable)
d.Config.HTTPRegister(http.MethodGet, "/control/safebrowsing/status", d.handleSafeBrowsingStatus)
d.Config.HTTPRegister(http.MethodPost, "/control/parental/enable", d.handleParentalEnable)
d.Config.HTTPRegister(http.MethodPost, "/control/parental/disable", d.handleParentalDisable)
d.Config.HTTPRegister(http.MethodGet, "/control/parental/status", d.handleParentalStatus)
d.Config.HTTPRegister(http.MethodPost, "/control/safesearch/enable", d.handleSafeSearchEnable)
d.Config.HTTPRegister(http.MethodPost, "/control/safesearch/disable", d.handleSafeSearchDisable)
d.Config.HTTPRegister(http.MethodGet, "/control/safesearch/status", d.handleSafeSearchStatus)
}

View File

@@ -107,7 +107,7 @@ func TestSafeBrowsingCache(t *testing.T) {
}
func TestSBPC_checkErrorUpstream(t *testing.T) {
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
d, _ := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
t.Cleanup(d.Close)
ups := aghtest.NewErrorUpstream()
@@ -128,7 +128,7 @@ func TestSBPC_checkErrorUpstream(t *testing.T) {
}
func TestSBPC(t *testing.T) {
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
d, _ := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
t.Cleanup(d.Close)
const hostname = "example.org"

View File

@@ -8,12 +8,14 @@ import (
"fmt"
"net"
"net/http"
"path"
"strconv"
"strings"
"sync"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/timeutil"
@@ -32,7 +34,8 @@ const sessionTokenSize = 16
type session struct {
userName string
expire uint32 // expiration time (in seconds)
// expire is the expiration time, in seconds.
expire uint32
}
func (s *session) serialize() []byte {
@@ -64,29 +67,29 @@ func (s *session) deserialize(data []byte) bool {
// Auth - global object
type Auth struct {
db *bbolt.DB
blocker *authRateLimiter
sessions map[string]*session
users []User
lock sync.Mutex
sessionTTL uint32
db *bbolt.DB
raleLimiter *authRateLimiter
sessions map[string]*session
users []webUser
lock sync.Mutex
sessionTTL uint32
}
// User object
type User struct {
// webUser represents a user of the Web UI.
type webUser struct {
Name string `yaml:"name"`
PasswordHash string `yaml:"password"` // bcrypt hash
PasswordHash string `yaml:"password"`
}
// InitAuth - create a global object
func InitAuth(dbFilename string, users []User, sessionTTL uint32, blocker *authRateLimiter) *Auth {
func InitAuth(dbFilename string, users []webUser, sessionTTL uint32, rateLimiter *authRateLimiter) *Auth {
log.Info("Initializing auth module: %s", dbFilename)
a := &Auth{
sessionTTL: sessionTTL,
blocker: blocker,
sessions: make(map[string]*session),
users: users,
sessionTTL: sessionTTL,
raleLimiter: rateLimiter,
sessions: make(map[string]*session),
users: users,
}
var err error
a.db, err = bbolt.Open(dbFilename, 0o644, nil)
@@ -326,35 +329,25 @@ func newSessionToken() (data []byte, err error) {
return randData, nil
}
// cookieTimeFormat is the format to be used in (time.Time).Format for cookie's
// expiry field.
const cookieTimeFormat = "Mon, 02 Jan 2006 15:04:05 GMT"
// cookieExpiryFormat returns the formatted exp to be used in cookie string.
// It's quite simple for now, but probably will be expanded in the future.
func cookieExpiryFormat(exp time.Time) (formatted string) {
return exp.Format(cookieTimeFormat)
}
func (a *Auth) httpCookie(req loginJSON, addr string) (cookie string, err error) {
blocker := a.blocker
u := a.UserFind(req.Name, req.Password)
if len(u.Name) == 0 {
if blocker != nil {
blocker.inc(addr)
// newCookie creates a new authentication cookie.
func (a *Auth) newCookie(req loginJSON, addr string) (c *http.Cookie, err error) {
rateLimiter := a.raleLimiter
u, ok := a.findUser(req.Name, req.Password)
if !ok {
if rateLimiter != nil {
rateLimiter.inc(addr)
}
return "", err
return nil, errors.Error("invalid username or password")
}
if blocker != nil {
blocker.remove(addr)
if rateLimiter != nil {
rateLimiter.remove(addr)
}
var sess []byte
sess, err = newSessionToken()
sess, err := newSessionToken()
if err != nil {
return "", err
return nil, fmt.Errorf("generating token: %w", err)
}
now := time.Now().UTC()
@@ -364,11 +357,15 @@ func (a *Auth) httpCookie(req loginJSON, addr string) (cookie string, err error)
expire: uint32(now.Unix()) + a.sessionTTL,
})
return fmt.Sprintf(
"%s=%s; Path=/; HttpOnly; Expires=%s",
sessionCookieName, hex.EncodeToString(sess),
cookieExpiryFormat(now.Add(cookieTTL)),
), nil
return &http.Cookie{
Name: sessionCookieName,
Value: hex.EncodeToString(sess),
Path: "/",
Expires: now.Add(cookieTTL),
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}, nil
}
// realIP extracts the real IP address of the client from an HTTP request using
@@ -436,8 +433,8 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
return
}
if blocker := Context.auth.blocker; blocker != nil {
if left := blocker.check(remoteAddr); left > 0 {
if rateLimiter := Context.auth.raleLimiter; rateLimiter != nil {
if left := rateLimiter.check(remoteAddr); left > 0 {
w.Header().Set("Retry-After", strconv.Itoa(int(left.Seconds())))
aghhttp.Error(r, w, http.StatusTooManyRequests, "auth: blocked for %s", left)
@@ -445,10 +442,9 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
}
}
var cookie string
cookie, err = Context.auth.httpCookie(req, remoteAddr)
cookie, err := Context.auth.newCookie(req, remoteAddr)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "crypto rand reader: %s", err)
aghhttp.Error(r, w, http.StatusForbidden, "%s", err)
return
}
@@ -462,20 +458,11 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
log.Error("auth: unknown ip")
}
if len(cookie) == 0 {
log.Info("auth: failed to login user %q from ip %v", req.Name, ip)
time.Sleep(1 * time.Second)
http.Error(w, "invalid username or password", http.StatusBadRequest)
return
}
log.Info("auth: user %q successfully logged in from ip %v", req.Name, ip)
http.SetCookie(w, cookie)
h := w.Header()
h.Set("Set-Cookie", cookie)
h.Set("Cache-Control", "no-store, no-cache, must-revalidate, proxy-revalidate")
h.Set("Pragma", "no-cache")
h.Set("Expires", "0")
@@ -484,17 +471,31 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
}
func handleLogout(w http.ResponseWriter, r *http.Request) {
cookie := r.Header.Get("Cookie")
sess := parseCookie(cookie)
respHdr := w.Header()
c, err := r.Cookie(sessionCookieName)
if err != nil {
// The only error that is returned from r.Cookie is [http.ErrNoCookie].
// The user is already logged out.
respHdr.Set("Location", "/login.html")
w.WriteHeader(http.StatusFound)
Context.auth.RemoveSession(sess)
return
}
w.Header().Set("Location", "/login.html")
Context.auth.RemoveSession(c.Value)
s := fmt.Sprintf("%s=; Path=/; HttpOnly; Expires=Thu, 01 Jan 1970 00:00:00 GMT",
sessionCookieName)
w.Header().Set("Set-Cookie", s)
c = &http.Cookie{
Name: sessionCookieName,
Value: "",
Path: "/",
Expires: time.Unix(0, 0),
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}
respHdr.Set("Location", "/login.html")
respHdr.Set("Set-Cookie", c.String())
w.WriteHeader(http.StatusFound)
}
@@ -504,101 +505,108 @@ func RegisterAuthHandlers() {
httpRegister(http.MethodGet, "/control/logout", handleLogout)
}
func parseCookie(cookie string) string {
pairs := strings.Split(cookie, ";")
for _, pair := range pairs {
pair = strings.TrimSpace(pair)
kv := strings.SplitN(pair, "=", 2)
if len(kv) != 2 {
continue
}
if kv[0] == sessionCookieName {
return kv[1]
}
}
return ""
}
// optionalAuthThird return true if user should authenticate first.
func optionalAuthThird(w http.ResponseWriter, r *http.Request) (authFirst bool) {
authFirst = false
func optionalAuthThird(w http.ResponseWriter, r *http.Request) (mustAuth bool) {
if glProcessCookie(r) {
log.Debug("auth: authentication is handled by GL-Inet submodule")
return false
}
// redirect to login page if not authenticated
ok := false
isAuthenticated := false
cookie, err := r.Cookie(sessionCookieName)
if glProcessCookie(r) {
log.Debug("auth: authentication was handled by GL-Inet submodule")
ok = true
} else if err == nil {
r := Context.auth.checkSession(cookie.Value)
if r == checkSessionOK {
ok = true
} else if r < 0 {
log.Debug("auth: invalid cookie value: %s", cookie)
}
} else {
// there's no Cookie, check Basic authentication
user, pass, ok2 := r.BasicAuth()
if ok2 {
u := Context.auth.UserFind(user, pass)
if len(u.Name) != 0 {
ok = true
} else {
if err != nil {
// The only error that is returned from r.Cookie is [http.ErrNoCookie].
// Check Basic authentication.
user, pass, hasBasic := r.BasicAuth()
if hasBasic {
_, isAuthenticated = Context.auth.findUser(user, pass)
if !isAuthenticated {
log.Info("auth: invalid Basic Authorization value")
}
}
}
if !ok {
if r.URL.Path == "/" || r.URL.Path == "/index.html" {
if glProcessRedirect(w, r) {
log.Debug("auth: redirected to login page by GL-Inet submodule")
} else {
w.Header().Set("Location", "/login.html")
w.WriteHeader(http.StatusFound)
}
} else {
w.WriteHeader(http.StatusForbidden)
_, _ = w.Write([]byte("Forbidden"))
} else {
res := Context.auth.checkSession(cookie.Value)
isAuthenticated = res == checkSessionOK
if !isAuthenticated {
log.Debug("auth: invalid cookie value: %s", cookie)
}
authFirst = true
}
return authFirst
if isAuthenticated {
return false
}
if p := r.URL.Path; p == "/" || p == "/index.html" {
if glProcessRedirect(w, r) {
log.Debug("auth: redirected to login page by GL-Inet submodule")
} else {
log.Debug("auth: redirected to login page")
w.Header().Set("Location", "/login.html")
w.WriteHeader(http.StatusFound)
}
} else {
log.Debug("auth: responded with forbidden to %s %s", r.Method, p)
w.WriteHeader(http.StatusForbidden)
_, _ = w.Write([]byte("Forbidden"))
}
return true
}
func optionalAuth(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
// TODO(a.garipov): Use [http.Handler] consistently everywhere throughout the
// project.
func optionalAuth(
h func(http.ResponseWriter, *http.Request),
) (wrapped func(http.ResponseWriter, *http.Request)) {
return func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/login.html" {
// redirect to dashboard if already authenticated
authRequired := Context.auth != nil && Context.auth.AuthRequired()
p := r.URL.Path
authRequired := Context.auth != nil && Context.auth.AuthRequired()
if p == "/login.html" {
cookie, err := r.Cookie(sessionCookieName)
if authRequired && err == nil {
r := Context.auth.checkSession(cookie.Value)
if r == checkSessionOK {
// Redirect to the dashboard if already authenticated.
res := Context.auth.checkSession(cookie.Value)
if res == checkSessionOK {
w.Header().Set("Location", "/")
w.WriteHeader(http.StatusFound)
return
} else if r == checkSessionNotFound {
log.Debug("auth: invalid cookie value: %s", cookie)
}
}
} else if strings.HasPrefix(r.URL.Path, "/assets/") ||
strings.HasPrefix(r.URL.Path, "/login.") {
// process as usual
// no additional auth requirements
} else if Context.auth != nil && Context.auth.AuthRequired() {
log.Debug("auth: invalid cookie value: %s", cookie)
}
} else if isPublicResource(p) {
// Process as usual, no additional auth requirements.
} else if authRequired {
if optionalAuthThird(w, r) {
return
}
}
handler(w, r)
h(w, r)
}
}
// isPublicResource returns true if p is a path to a public resource.
func isPublicResource(p string) (ok bool) {
isAsset, err := path.Match("/assets/*", p)
if err != nil {
// The only error that is returned from path.Match is
// [path.ErrBadPattern]. This is a programmer error.
panic(fmt.Errorf("bad asset pattern: %w", err))
}
isLogin, err := path.Match("/login.*", p)
if err != nil {
// Same as above.
panic(fmt.Errorf("bad login pattern: %w", err))
}
return isAsset || isLogin
}
type authHandler struct {
handler http.Handler
}
@@ -612,7 +620,7 @@ func optionalAuthHandler(handler http.Handler) http.Handler {
}
// UserAdd - add new user
func (a *Auth) UserAdd(u *User, password string) {
func (a *Auth) UserAdd(u *webUser, password string) {
if len(password) == 0 {
return
}
@@ -631,31 +639,35 @@ func (a *Auth) UserAdd(u *User, password string) {
log.Debug("auth: added user: %s", u.Name)
}
// UserFind - find a user
func (a *Auth) UserFind(login, password string) User {
// findUser returns a user if there is one.
func (a *Auth) findUser(login, password string) (u webUser, ok bool) {
a.lock.Lock()
defer a.lock.Unlock()
for _, u := range a.users {
for _, u = range a.users {
if u.Name == login &&
bcrypt.CompareHashAndPassword([]byte(u.PasswordHash), []byte(password)) == nil {
return u
return u, true
}
}
return User{}
return webUser{}, false
}
// getCurrentUser returns the current user. It returns an empty User if the
// user is not found.
func (a *Auth) getCurrentUser(r *http.Request) User {
func (a *Auth) getCurrentUser(r *http.Request) (u webUser) {
cookie, err := r.Cookie(sessionCookieName)
if err != nil {
// There's no Cookie, check Basic authentication.
user, pass, ok := r.BasicAuth()
if ok {
return Context.auth.UserFind(user, pass)
u, _ = Context.auth.findUser(user, pass)
return u
}
return User{}
return webUser{}
}
a.lock.Lock()
@@ -663,20 +675,20 @@ func (a *Auth) getCurrentUser(r *http.Request) User {
s, ok := a.sessions[cookie.Value]
if !ok {
return User{}
return webUser{}
}
for _, u := range a.users {
for _, u = range a.users {
if u.Name == s.userName {
return u
}
}
return User{}
return webUser{}
}
// GetUsers - get users
func (a *Auth) GetUsers() []User {
func (a *Auth) GetUsers() []webUser {
a.lock.Lock()
users := a.users
a.lock.Unlock()

View File

@@ -43,14 +43,14 @@ func TestAuth(t *testing.T) {
dir := t.TempDir()
fn := filepath.Join(dir, "sessions.db")
users := []User{{
users := []webUser{{
Name: "name",
PasswordHash: "$2y$05$..vyzAECIhJPfaQiOK17IukcQnqEgKJHy0iETyYqxn3YXJl8yZuo2",
}}
a := InitAuth(fn, nil, 60, nil)
s := session{}
user := User{Name: "name"}
user := webUser{Name: "name"}
a.UserAdd(&user, "password")
assert.Equal(t, checkSessionNotFound, a.checkSession("notfound"))
@@ -84,7 +84,8 @@ func TestAuth(t *testing.T) {
a.storeSession(sess, &s)
a.Close()
u := a.UserFind("name", "password")
u, ok := a.findUser("name", "password")
assert.True(t, ok)
assert.NotEmpty(t, u.Name)
time.Sleep(3 * time.Second)
@@ -118,7 +119,7 @@ func TestAuthHTTP(t *testing.T) {
dir := t.TempDir()
fn := filepath.Join(dir, "sessions.db")
users := []User{
users := []webUser{
{Name: "name", PasswordHash: "$2y$05$..vyzAECIhJPfaQiOK17IukcQnqEgKJHy0iETyYqxn3YXJl8yZuo2"},
}
Context.auth = InitAuth(fn, users, 60, nil)
@@ -150,18 +151,19 @@ func TestAuthHTTP(t *testing.T) {
assert.True(t, handlerCalled)
// perform login
cookie, err := Context.auth.httpCookie(loginJSON{Name: "name", Password: "password"}, "")
cookie, err := Context.auth.newCookie(loginJSON{Name: "name", Password: "password"}, "")
require.NoError(t, err)
assert.NotEmpty(t, cookie)
require.NotNil(t, cookie)
// get /
handler2 = optionalAuth(handler)
w.hdr = make(http.Header)
r.Header.Set("Cookie", cookie)
r.Header.Set("Cookie", cookie.String())
r.URL = &url.URL{Path: "/"}
handlerCalled = false
handler2(&w, &r)
assert.True(t, handlerCalled)
r.Header.Del("Cookie")
// get / with basic auth
@@ -177,7 +179,7 @@ func TestAuthHTTP(t *testing.T) {
// get login page with a valid cookie - we're redirected to /
handler2 = optionalAuth(handler)
w.hdr = make(http.Header)
r.Header.Set("Cookie", cookie)
r.Header.Set("Cookie", cookie.String())
r.URL = &url.URL{Path: loginURL}
handlerCalled = false
handler2(&w, &r)

View File

@@ -126,7 +126,7 @@ type clientsContainer struct {
allTags *stringutil.Set
// dhcpServer is used for looking up clients IP addresses by MAC addresses
dhcpServer *dhcpd.Server
dhcpServer dhcpd.Interface
// dnsServer is used for checking clients IP status access list status
dnsServer *dnsforward.Server
@@ -146,7 +146,7 @@ type clientsContainer struct {
// Note: this function must be called only once
func (clients *clientsContainer) Init(
objects []*clientObject,
dhcpServer *dhcpd.Server,
dhcpServer dhcpd.Interface,
etcHosts *aghnet.HostsContainer,
arpdb aghnet.ARPDB,
) {

View File

@@ -279,8 +279,6 @@ func TestClientsAddExisting(t *testing.T) {
t.Skip("skipping dhcp test on windows")
}
var err error
ip := net.IP{1, 2, 3, 4}
// First, init a DHCP server with a single static lease.
@@ -296,13 +294,15 @@ func TestClientsAddExisting(t *testing.T) {
},
}
clients.dhcpServer, err = dhcpd.Create(config)
dhcpServer, err := dhcpd.Create(config)
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, func() (err error) {
return os.Remove("leases.db")
})
err = clients.dhcpServer.AddStaticLease(&dhcpd.Lease{
clients.dhcpServer = dhcpServer
err = dhcpServer.AddStaticLease(&dhcpd.Lease{
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
IP: ip,
Hostname: "testhost",

View File

@@ -93,13 +93,7 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
data.Tags = clientTags
w.Header().Set("Content-Type", "application/json")
e := json.NewEncoder(w).Encode(data)
if e != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "failed to encode to json: %v", e)
return
}
_ = aghhttp.WriteJSONResponse(w, r, data)
}
// Convert JSON object to Client object
@@ -249,11 +243,7 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
})
}
w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(data)
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't write response: %s", err)
}
_ = aghhttp.WriteJSONResponse(w, r, data)
}
// findRuntime looks up the IP in runtime and temporary storages, like

View File

@@ -14,7 +14,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
"github.com/AdguardTeam/AdGuardHome/internal/stats"
"github.com/AdguardTeam/AdGuardHome/internal/version"
"github.com/AdguardTeam/dnsproxy/fastip"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
@@ -23,10 +22,9 @@ import (
yaml "gopkg.in/yaml.v3"
)
const (
dataDir = "data" // data storage
filterDir = "filters" // cache location for downloaded filters, it's under DataDir
)
// dataDir is the name of a directory under the working one to store some
// persistent data.
const dataDir = "data"
// logSettings are the logging settings part of the configuration file.
//
@@ -87,10 +85,10 @@ type configuration struct {
// It's reset after config is parsed
fileData []byte
BindHost net.IP `yaml:"bind_host"` // BindHost is the IP address of the HTTP server to bind to
BindPort int `yaml:"bind_port"` // BindPort is the port the HTTP server
BetaBindPort int `yaml:"beta_bind_port"` // BetaBindPort is the port for new client
Users []User `yaml:"users"` // Users that can access HTTP server
BindHost net.IP `yaml:"bind_host"` // BindHost is the IP address of the HTTP server to bind to
BindPort int `yaml:"bind_port"` // BindPort is the port the HTTP server
BetaBindPort int `yaml:"beta_bind_port"` // BetaBindPort is the port for new client
Users []webUser `yaml:"users"` // Users that can access HTTP server
// AuthAttempts is the maximum number of failed login attempts a user
// can do before being blocked.
AuthAttempts uint `yaml:"auth_attempts"`
@@ -108,9 +106,16 @@ type configuration struct {
DNS dnsConfig `yaml:"dns"`
TLS tlsConfigSettings `yaml:"tls"`
Filters []filter `yaml:"filters"`
WhitelistFilters []filter `yaml:"whitelist_filters"`
UserRules []string `yaml:"user_rules"`
// Filters reflects the filters from [filtering.Config]. It's cloned to the
// config used in the filtering module at the startup. Afterwards it's
// cloned from the filtering module back here.
//
// TODO(e.burkov): Move all the filtering configuration fields into the
// only configuration subsection covering the changes with a single
// migration.
Filters []filtering.FilterYAML `yaml:"filters"`
WhitelistFilters []filtering.FilterYAML `yaml:"whitelist_filters"`
UserRules []string `yaml:"user_rules"`
DHCP *dhcpd.ServerConfig `yaml:"dhcp"`
@@ -145,9 +150,7 @@ type dnsConfig struct {
dnsforward.FilteringConfig `yaml:",inline"`
FilteringEnabled bool `yaml:"filtering_enabled"` // whether or not use filter lists
FiltersUpdateIntervalHours uint32 `yaml:"filters_update_interval"` // time period to update filters (in hours)
DnsfilterConf filtering.Config `yaml:",inline"`
DnsfilterConf *filtering.Config `yaml:",inline"`
// UpstreamTimeout is the timeout for querying upstream servers.
UpstreamTimeout timeutil.Duration `yaml:"upstream_timeout"`
@@ -193,15 +196,20 @@ type tlsConfigSettings struct {
//
// TODO(a.garipov, e.burkov): This global is awful and must be removed.
var config = &configuration{
BindPort: 3000,
BetaBindPort: 0,
BindHost: net.IP{0, 0, 0, 0},
AuthAttempts: 5,
AuthBlockMin: 15,
BindPort: 3000,
BetaBindPort: 0,
BindHost: net.IP{0, 0, 0, 0},
AuthAttempts: 5,
AuthBlockMin: 15,
WebSessionTTLHours: 30 * 24,
DNS: dnsConfig{
BindHosts: []net.IP{{0, 0, 0, 0}},
Port: defaultPortDNS,
StatsInterval: 1,
BindHosts: []net.IP{{0, 0, 0, 0}},
Port: defaultPortDNS,
StatsInterval: 1,
QueryLogEnabled: true,
QueryLogFileEnabled: true,
QueryLogInterval: timeutil.Duration{Duration: 90 * timeutil.Day},
QueryLogMemSize: 1000,
FilteringConfig: dnsforward.FilteringConfig{
ProtectionEnabled: true, // whether or not use any of filtering features
BlockingMode: dnsforward.BlockingModeDefault,
@@ -222,18 +230,42 @@ var config = &configuration{
// was later increased to 300 due to https://github.com/AdguardTeam/AdGuardHome/issues/2257
MaxGoroutines: 300,
},
FilteringEnabled: true, // whether or not use filter lists
FiltersUpdateIntervalHours: 24,
UpstreamTimeout: timeutil.Duration{Duration: dnsforward.DefaultTimeout},
UsePrivateRDNS: true,
DnsfilterConf: &filtering.Config{
SafeBrowsingCacheSize: 1 * 1024 * 1024,
SafeSearchCacheSize: 1 * 1024 * 1024,
ParentalCacheSize: 1 * 1024 * 1024,
CacheTime: 30,
FilteringEnabled: true,
FiltersUpdateIntervalHours: 24,
},
UpstreamTimeout: timeutil.Duration{Duration: dnsforward.DefaultTimeout},
UsePrivateRDNS: true,
},
TLS: tlsConfigSettings{
PortHTTPS: defaultPortHTTPS,
PortDNSOverTLS: defaultPortTLS, // needs to be passed through to dnsproxy
PortDNSOverQUIC: defaultPortQUIC,
},
Filters: []filtering.FilterYAML{{
Filter: filtering.Filter{ID: 1},
Enabled: true,
URL: "https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt",
Name: "AdGuard DNS filter",
}, {
Filter: filtering.Filter{ID: 2},
Enabled: false,
URL: "https://adaway.org/hosts.txt",
Name: "AdAway Default Blocklist",
}},
DHCP: &dhcpd.ServerConfig{
LocalDomainName: "lan",
Conf4: dhcpd.V4ServerConf{
LeaseDuration: dhcpd.DefaultDHCPLeaseTTL,
ICMPTimeout: dhcpd.DefaultDHCPTimeoutICMP,
},
Conf6: dhcpd.V6ServerConf{
LeaseDuration: dhcpd.DefaultDHCPLeaseTTL,
},
},
Clients: &clientsConfig{
Sources: &clientSourcesConf{
@@ -255,31 +287,6 @@ var config = &configuration{
SchemaVersion: currentSchemaVersion,
}
// initConfig initializes default configuration for the current OS&ARCH
func initConfig() {
config.WebSessionTTLHours = 30 * 24
config.DNS.QueryLogEnabled = true
config.DNS.QueryLogFileEnabled = true
config.DNS.QueryLogInterval = timeutil.Duration{Duration: 90 * timeutil.Day}
config.DNS.QueryLogMemSize = 1000
config.DNS.CacheSize = 4 * 1024 * 1024
config.DNS.DnsfilterConf.SafeBrowsingCacheSize = 1 * 1024 * 1024
config.DNS.DnsfilterConf.SafeSearchCacheSize = 1 * 1024 * 1024
config.DNS.DnsfilterConf.ParentalCacheSize = 1 * 1024 * 1024
config.DNS.DnsfilterConf.CacheTime = 30
config.Filters = defaultFilters()
config.DHCP.Conf4.LeaseDuration = dhcpd.DefaultDHCPLeaseTTL
config.DHCP.Conf4.ICMPTimeout = dhcpd.DefaultDHCPTimeoutICMP
config.DHCP.Conf6.LeaseDuration = dhcpd.DefaultDHCPLeaseTTL
if ch := version.Channel(); ch == version.ChannelEdge || ch == version.ChannelDevelopment {
config.BetaBindPort = 3001
}
}
// getConfigFilename returns path to the current config file
func (c *configuration) getConfigFilename() string {
configFile, err := filepath.EvalSymlinks(Context.configFilename)
@@ -348,8 +355,8 @@ func parseConfig() (err error) {
return fmt.Errorf("validating udp ports: %w", err)
}
if !checkFiltersUpdateIntervalHours(config.DNS.FiltersUpdateIntervalHours) {
config.DNS.FiltersUpdateIntervalHours = 24
if !filtering.ValidateUpdateIvl(config.DNS.DnsfilterConf.FiltersUpdateIntervalHours) {
config.DNS.DnsfilterConf.FiltersUpdateIntervalHours = 24
}
if config.DNS.UpstreamTimeout.Duration == 0 {
@@ -418,10 +425,11 @@ func (c *configuration) write() (err error) {
config.DNS.AnonymizeClientIP = dc.AnonymizeClientIP
}
if Context.dnsFilter != nil {
c := filtering.Config{}
Context.dnsFilter.WriteDiskConfig(&c)
config.DNS.DnsfilterConf = c
if Context.filters != nil {
Context.filters.WriteDiskConfig(config.DNS.DnsfilterConf)
config.Filters = config.DNS.DnsfilterConf.Filters
config.WhitelistFilters = config.DNS.DnsfilterConf.WhitelistFilters
config.UserRules = config.DNS.DnsfilterConf.UserRules
}
if s := Context.dnsServer; s != nil {
@@ -433,9 +441,7 @@ func (c *configuration) write() (err error) {
}
if Context.dhcpServer != nil {
c := &dhcpd.ServerConfig{}
Context.dhcpServer.WriteDiskConfig(c)
config.DHCP = c
Context.dhcpServer.WriteDiskConfig(config.DHCP)
}
config.Clients.Persistent = Context.clients.forConfig()

View File

@@ -8,6 +8,7 @@ import (
"net/url"
"runtime"
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
@@ -97,16 +98,16 @@ func collectDNSAddresses() (addrs []string, err error) {
// statusResponse is a response for /control/status endpoint.
type statusResponse struct {
Version string `json:"version"`
Language string `json:"language"`
DNSAddrs []string `json:"dns_addresses"`
DNSPort int `json:"dns_port"`
HTTPPort int `json:"http_port"`
IsProtectionEnabled bool `json:"protection_enabled"`
// TODO(e.burkov): Inspect if front-end doesn't requires this field as
// openapi.yaml declares.
IsDHCPAvailable bool `json:"dhcp_available"`
IsRunning bool `json:"running"`
Version string `json:"version"`
Language string `json:"language"`
IsDHCPAvailable bool `json:"dhcp_available"`
IsRunning bool `json:"running"`
}
func handleStatus(w http.ResponseWriter, r *http.Request) {
@@ -125,12 +126,12 @@ func handleStatus(w http.ResponseWriter, r *http.Request) {
defer config.RUnlock()
resp = statusResponse{
Version: version.Version(),
DNSAddrs: dnsAddrs,
DNSPort: config.DNS.Port,
HTTPPort: config.BindPort,
IsRunning: isRunning(),
Version: version.Version(),
Language: config.Language,
IsRunning: isRunning(),
}
}()
@@ -146,13 +147,7 @@ func handleStatus(w http.ResponseWriter, r *http.Request) {
resp.IsDHCPAvailable = Context.dhcpServer != nil
}
w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(resp)
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err)
return
}
_ = aghhttp.WriteJSONResponse(w, r, resp)
}
type profileJSON struct {
@@ -162,13 +157,16 @@ type profileJSON struct {
func handleGetProfile(w http.ResponseWriter, r *http.Request) {
pj := profileJSON{}
u := Context.auth.getCurrentUser(r)
pj.Name = u.Name
data, err := json.Marshal(pj)
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "json.Marshal: %s", err)
return
}
_, _ = w.Write(data)
}
@@ -199,19 +197,29 @@ func httpRegister(method, url string, handler http.HandlerFunc) {
Context.mux.Handle(url, postInstallHandler(optionalAuthHandler(gziphandler.GzipHandler(ensureHandler(method, handler)))))
}
// ----------------------------------
// helper functions for HTTP handlers
// ----------------------------------
func ensure(method string, handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
// ensure returns a wrapped handler that makes sure that the request has the
// correct method as well as additional method and header checks.
func ensure(
method string,
handler func(http.ResponseWriter, *http.Request),
) (wrapped func(http.ResponseWriter, *http.Request)) {
return func(w http.ResponseWriter, r *http.Request) {
log.Debug("%s %v", r.Method, r.URL)
start := time.Now()
m, u := r.Method, r.URL
log.Debug("started %s %s %s", m, r.Host, u)
defer func() { log.Debug("finished %s %s %s in %s", m, r.Host, u, time.Since(start)) }()
if m != method {
aghhttp.Error(r, w, http.StatusMethodNotAllowed, "only method %s is allowed", method)
if r.Method != method {
http.Error(w, "This request must be "+method, http.StatusMethodNotAllowed)
return
}
if method == http.MethodPost || method == http.MethodPut || method == http.MethodDelete {
if modifiesData(m) {
if !ensureContentType(w, r) {
return
}
Context.controlLock.Lock()
defer Context.controlLock.Unlock()
}
@@ -220,6 +228,42 @@ func ensure(method string, handler func(http.ResponseWriter, *http.Request)) fun
}
}
// modifiesData returns true if m is an HTTP method that can modify data.
func modifiesData(m string) (ok bool) {
return m == http.MethodPost || m == http.MethodPut || m == http.MethodDelete
}
// ensureContentType makes sure that the content type of a data-modifying
// request is set correctly. If it is not, ensureContentType writes a response
// to w, and ok is false.
func ensureContentType(w http.ResponseWriter, r *http.Request) (ok bool) {
const statusUnsup = http.StatusUnsupportedMediaType
cType := r.Header.Get(aghhttp.HdrNameContentType)
if r.ContentLength == 0 {
if cType == "" {
return true
}
// Assume that browsers always send a content type when submitting HTML
// forms and require no content type for requests with no body to make
// sure that the request comes from JavaScript.
aghhttp.Error(r, w, statusUnsup, "empty body with content-type %q not allowed", cType)
return false
}
const wantCType = aghhttp.HdrValApplicationJSON
if cType == wantCType {
return true
}
aghhttp.Error(r, w, statusUnsup, "only content-type %s is allowed", wantCType)
return false
}
func ensurePOST(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
return ensure(http.MethodPost, handler)
}
@@ -291,7 +335,7 @@ func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (ok bool) {
}
httpsURL := &url.URL{
Scheme: schemeHTTPS,
Scheme: aghhttp.SchemeHTTPS,
Host: hostPort,
Path: r.URL.Path,
RawQuery: r.URL.RawQuery,
@@ -307,7 +351,7 @@ func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (ok bool) {
//
// See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin.
originURL := &url.URL{
Scheme: schemeHTTP,
Scheme: aghhttp.SchemeHTTP,
Host: r.Host,
}
w.Header().Set("Access-Control-Allow-Origin", originURL.String())

View File

@@ -59,19 +59,7 @@ func (web *Web) handleInstallGetAddresses(w http.ResponseWriter, r *http.Request
data.Interfaces[iface.Name] = iface
}
w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(data)
if err != nil {
aghhttp.Error(
r,
w,
http.StatusInternalServerError,
"Unable to marshal default addresses to json: %s",
err,
)
return
}
_ = aghhttp.WriteJSONResponse(w, r, data)
}
type checkConfReqEnt struct {
@@ -201,13 +189,7 @@ func (web *Web) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request)
resp.StaticIP = handleStaticIP(req.DNS.IP, req.SetStaticIP)
}
w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(resp)
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "encoding the response: %s", err)
return
}
_ = aghhttp.WriteJSONResponse(w, r, resp)
}
// handleStaticIP - handles static IP request
@@ -424,7 +406,7 @@ func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
return
}
u := &User{
u := &webUser{
Name: req.Username,
}
Context.auth.UserAdd(u, req.Password)
@@ -688,19 +670,7 @@ func (web *Web) handleInstallGetAddressesBeta(w http.ResponseWriter, r *http.Req
data.Interfaces = ifaces
w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(data)
if err != nil {
aghhttp.Error(
r,
w,
http.StatusInternalServerError,
"Unable to marshal default addresses to json: %s",
err,
)
return
}
_ = aghhttp.WriteJSONResponse(w, r, data)
}
// registerBetaInstallHandlers registers the install handlers for new client

View File

@@ -28,8 +28,6 @@ type temporaryError interface {
// Get the latest available version from the Internet
func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
resp := &versionResponse{}
if Context.disableUpdate {
resp.Disabled = true
@@ -71,10 +69,7 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
return
}
err = json.NewEncoder(w).Encode(resp)
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "writing body: %s", err)
}
_ = aghhttp.WriteJSONResponse(w, r, resp)
}
// requestVersionInfo sets the VersionInfo field of resp if it can reach the

View File

@@ -31,7 +31,10 @@ const (
// Called by other modules when configuration is changed
func onConfigModified() {
_ = config.write()
err := config.write()
if err != nil {
log.Error("writing config: %s", err)
}
}
// initDNSServer creates an instance of the dnsforward.Server
@@ -71,11 +74,11 @@ func initDNSServer() (err error) {
}
Context.queryLog = querylog.New(conf)
filterConf := config.DNS.DnsfilterConf
filterConf.EtcHosts = Context.etcHosts
filterConf.ConfigModified = onConfigModified
filterConf.HTTPRegister = httpRegister
Context.dnsFilter = filtering.New(&filterConf, nil)
Context.filters, err = filtering.New(config.DNS.DnsfilterConf, nil)
if err != nil {
// Don't wrap the error, since it's informative enough as is.
return err
}
var privateNets netutil.SubnetSet
switch len(config.DNS.PrivateNets) {
@@ -83,13 +86,10 @@ func initDNSServer() (err error) {
// Use an optimized locally-served matcher.
privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed)
case 1:
var n *net.IPNet
n, err = netutil.ParseSubnet(config.DNS.PrivateNets[0])
privateNets, err = netutil.ParseSubnet(config.DNS.PrivateNets[0])
if err != nil {
return fmt.Errorf("preparing the set of private subnets: %w", err)
}
privateNets = n
default:
var nets []*net.IPNet
nets, err = netutil.ParseSubnets(config.DNS.PrivateNets...)
@@ -101,15 +101,13 @@ func initDNSServer() (err error) {
}
p := dnsforward.DNSCreateParams{
DNSFilter: Context.dnsFilter,
DNSFilter: Context.filters,
Stats: Context.stats,
QueryLog: Context.queryLog,
PrivateNets: privateNets,
Anonymizer: anonymizer,
LocalDomain: config.DHCP.LocalDomainName,
}
if Context.dhcpServer != nil {
p.DHCPServer = Context.dhcpServer
DHCPServer: Context.dhcpServer,
}
Context.dnsServer, err = dnsforward.NewServer(p)
@@ -143,7 +141,6 @@ func initDNSServer() (err error) {
Context.whois = initWHOIS(&Context.clients)
}
Context.filters.Init()
return nil
}
@@ -335,9 +332,12 @@ func getDNSEncryption() (de dnsEncryption) {
// applyAdditionalFiltering adds additional client information and settings if
// the client has them.
func applyAdditionalFiltering(clientIP net.IP, clientID string, setts *filtering.Settings) {
Context.dnsFilter.ApplyBlockedServices(setts, nil, true)
// pref is a prefix for logging messages around the scope.
const pref = "applying filters"
log.Debug("looking up settings for client with ip %s and clientid %q", clientIP, clientID)
Context.filters.ApplyBlockedServices(setts, nil)
log.Debug("%s: looking for client with ip %s and clientid %q", pref, clientIP, clientID)
if clientIP == nil {
return
@@ -349,16 +349,16 @@ func applyAdditionalFiltering(clientIP net.IP, clientID string, setts *filtering
if !ok {
c, ok = Context.clients.Find(clientIP.String())
if !ok {
log.Debug("client with ip %s and clientid %q not found", clientIP, clientID)
log.Debug("%s: no clients with ip %s and clientid %q", pref, clientIP, clientID)
return
}
}
log.Debug("using settings for client %q with ip %s and clientid %q", c.Name, clientIP, clientID)
log.Debug("%s: using settings for client %q (%s; %q)", pref, c.Name, clientIP, clientID)
if c.UseOwnBlockedServices {
Context.dnsFilter.ApplyBlockedServices(setts, c.BlockedServices, false)
Context.filters.ApplyBlockedServices(setts, c.BlockedServices)
}
setts.ClientName = c.Name
@@ -381,7 +381,7 @@ func startDNSServer() error {
return fmt.Errorf("unable to start forwarding DNS server: Already running")
}
enableFiltersLocked(false)
Context.filters.EnableFilters(false)
Context.clients.Start()
@@ -390,7 +390,6 @@ func startDNSServer() error {
return fmt.Errorf("couldn't start forwarding DNS server: %w", err)
}
Context.dnsFilter.Start()
Context.filters.Start()
Context.stats.Start()
Context.queryLog.Start()
@@ -449,10 +448,7 @@ func closeDNSServer() {
Context.dnsServer = nil
}
if Context.dnsFilter != nil {
Context.dnsFilter.Close()
Context.dnsFilter = nil
}
Context.filters.Close()
if Context.stats != nil {
err := Context.stats.Close()
@@ -469,7 +465,5 @@ func closeDNSServer() {
Context.queryLog = nil
}
Context.filters.Close()
log.Debug("Closed all DNS modules")
log.Debug("all dns modules are closed")
}

View File

@@ -20,6 +20,7 @@ import (
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
@@ -33,6 +34,7 @@ import (
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"golang.org/x/exp/slices"
"gopkg.in/natefinch/lumberjack.v2"
)
@@ -52,10 +54,9 @@ type homeContext struct {
dnsServer *dnsforward.Server // DNS module
rdns *RDNS // rDNS module
whois *WHOIS // WHOIS module
dnsFilter *filtering.DNSFilter // DNS filtering module
dhcpServer *dhcpd.Server // DHCP module
dhcpServer dhcpd.Interface // DHCP module
auth *Auth // HTTP authentication module
filters Filtering // DNS filtering module
filters *filtering.DNSFilter // DNS filtering module
web *Web // Web (HTTP, HTTPS) module
tls *TLSMod // TLS module
// etcHosts is an IP-hostname pairs set taken from system configuration
@@ -140,7 +141,12 @@ func setupContext(args options) {
checkPermissions()
}
initConfig()
switch version.Channel() {
case version.ChannelEdge, version.ChannelDevelopment:
config.BetaBindPort = 3001
default:
// Go on.
}
Context.tlsRoots = LoadSystemRootCAs()
Context.transport = &http.Transport{
@@ -265,6 +271,15 @@ func setupHostsContainer() (err error) {
}
func setupConfig(args options) (err error) {
config.DNS.DnsfilterConf.EtcHosts = Context.etcHosts
config.DNS.DnsfilterConf.ConfigModified = onConfigModified
config.DNS.DnsfilterConf.HTTPRegister = httpRegister
config.DNS.DnsfilterConf.DataDir = Context.getDataDir()
config.DNS.DnsfilterConf.Filters = slices.Clone(config.Filters)
config.DNS.DnsfilterConf.WhitelistFilters = slices.Clone(config.WhitelistFilters)
config.DNS.DnsfilterConf.UserRules = slices.Clone(config.UserRules)
config.DNS.DnsfilterConf.HTTPClient = Context.client
config.DHCP.WorkDir = Context.workDir
config.DHCP.HTTPRegister = httpRegister
config.DHCP.ConfigModified = onConfigModified
@@ -384,8 +399,6 @@ func fatalOnError(err error) {
// run configures and starts AdGuard Home.
func run(args options, clientBuildFS fs.FS) {
var err error
// configure config filename
initConfigFilename(args)
@@ -396,7 +409,7 @@ func run(args options, clientBuildFS fs.FS) {
configureLogger(args)
// Print the first message after logger is configured.
log.Println(version.Full())
log.Info(version.Full())
log.Debug("current working directory is %s", Context.workDir)
if args.runningAsService {
log.Info("AdGuard Home is running as a service")
@@ -404,7 +417,7 @@ func run(args options, clientBuildFS fs.FS) {
setupContext(args)
err = configureOS(config)
err := configureOS(config)
fatalOnError(err)
// clients package uses filtering package's static data (filtering.BlockedSvcKnown()),
@@ -442,9 +455,9 @@ func run(args options, clientBuildFS fs.FS) {
sessFilename := filepath.Join(Context.getDataDir(), "sessions.db")
GLMode = args.glinetMode
var arl *authRateLimiter
var rateLimiter *authRateLimiter
if config.AuthAttempts > 0 && config.AuthBlockMin > 0 {
arl = newAuthRateLimiter(
rateLimiter = newAuthRateLimiter(
time.Duration(config.AuthBlockMin)*time.Minute,
config.AuthAttempts,
)
@@ -456,7 +469,7 @@ func run(args options, clientBuildFS fs.FS) {
sessFilename,
config.Users,
config.WebSessionTTLHours*60*60,
arl,
rateLimiter,
)
if Context.auth == nil {
log.Fatalf("Couldn't initialize Auth module")
@@ -641,14 +654,9 @@ func configureLogger(args options) {
log.Fatalf("cannot initialize syslog: %s", err)
}
} else {
logFilePath := filepath.Join(Context.workDir, ls.File)
if filepath.IsAbs(ls.File) {
logFilePath = ls.File
}
_, err := os.OpenFile(logFilePath, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0o644)
if err != nil {
log.Fatalf("cannot create a log file: %s", err)
logFilePath := ls.File
if !filepath.IsAbs(logFilePath) {
logFilePath = filepath.Join(Context.workDir, logFilePath)
}
log.SetOutput(&lumberjack.Logger{
@@ -768,12 +776,12 @@ func printHTTPAddresses(proto string) {
}
port := config.BindPort
if proto == schemeHTTPS {
if proto == aghhttp.SchemeHTTPS {
port = tlsConf.PortHTTPS
}
// TODO(e.burkov): Inspect and perhaps merge with the previous condition.
if proto == schemeHTTPS && tlsConf.ServerName != "" {
if proto == aghhttp.SchemeHTTPS && tlsConf.ServerName != "" {
printWebAddrs(proto, tlsConf.ServerName, tlsConf.PortHTTPS, 0)
return

View File

@@ -1,10 +1,8 @@
package home
import (
"fmt"
"io"
"encoding/json"
"net/http"
"strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/golibs/log"
@@ -51,43 +49,35 @@ var allowedLanguages = stringutil.NewSet(
"zh-tw",
)
func handleI18nCurrentLanguage(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "text/plain")
log.Printf("config.Language is %s", config.Language)
_, err := fmt.Fprintf(w, "%s\n", config.Language)
if err != nil {
msg := fmt.Sprintf("Unable to write response json: %s", err)
log.Println(msg)
http.Error(w, msg, http.StatusInternalServerError)
// languageJSON is the JSON structure for language requests and responses.
type languageJSON struct {
Language string `json:"language"`
}
return
}
func handleI18nCurrentLanguage(w http.ResponseWriter, r *http.Request) {
log.Printf("home: language is %s", config.Language)
_ = aghhttp.WriteJSONResponse(w, r, &languageJSON{
Language: config.Language,
})
}
func handleI18nChangeLanguage(w http.ResponseWriter, r *http.Request) {
// This use of ReadAll is safe, because request's body is now limited.
body, err := io.ReadAll(r.Body)
if aghhttp.WriteTextPlainDeprecated(w, r) {
return
}
langReq := &languageJSON{}
err := json.NewDecoder(r.Body).Decode(langReq)
if err != nil {
msg := fmt.Sprintf("failed to read request body: %s", err)
log.Println(msg)
http.Error(w, msg, http.StatusBadRequest)
aghhttp.Error(r, w, http.StatusInternalServerError, "reading req: %s", err)
return
}
language := strings.TrimSpace(string(body))
if language == "" {
msg := "empty language specified"
log.Println(msg)
http.Error(w, msg, http.StatusBadRequest)
return
}
if !allowedLanguages.Has(language) {
msg := fmt.Sprintf("unknown language specified: %s", language)
log.Println(msg)
http.Error(w, msg, http.StatusBadRequest)
lang := langReq.Language
if !allowedLanguages.Has(lang) {
aghhttp.Error(r, w, http.StatusBadRequest, "unknown language: %q", lang)
return
}
@@ -96,7 +86,8 @@ func handleI18nChangeLanguage(w http.ResponseWriter, r *http.Request) {
config.Lock()
defer config.Unlock()
config.Language = language
config.Language = lang
log.Printf("home: language is set to %s", lang)
}()
onConfigModified()

View File

@@ -8,6 +8,7 @@ import (
"net/url"
"path"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
@@ -82,7 +83,7 @@ func encodeMobileConfig(d *dnsSettings, clientID string) ([]byte, error) {
case dnsProtoHTTPS:
dspName = fmt.Sprintf("%s DoH", d.ServerName)
u := &url.URL{
Scheme: schemeHTTPS,
Scheme: aghhttp.SchemeHTTPS,
Host: d.ServerName,
Path: path.Join("/dns-query", clientID),
}

View File

@@ -11,6 +11,7 @@ import (
"syscall"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/AdGuardHome/internal/version"
"github.com/AdguardTeam/golibs/errors"
@@ -175,7 +176,8 @@ func handleServiceControlAction(opts options, clientBuildFS fs.FS) {
chooseSystem()
action := opts.serviceControlAction
log.Printf("service: control action: %s", action)
log.Info(version.Full())
log.Info("service: control action: %s", action)
if action == "reload" {
sendSigReload()
@@ -277,7 +279,7 @@ AdGuard Home is successfully installed and will automatically start on boot.
There are a few more things that must be configured before you can use it.
Click on the link below and follow the Installation Wizard steps to finish setup.
AdGuard Home is now available at the following addresses:`)
printHTTPAddresses(schemeHTTP)
printHTTPAddresses(aghhttp.SchemeHTTP)
}
}

Some files were not shown because too many files have changed in this diff Show More