all: sync with master

This commit is contained in:
Ainar Garipov
2022-09-29 17:36:01 +03:00
parent 77d04d44eb
commit 21f6ed36fe
74 changed files with 1004 additions and 891 deletions

View File

@@ -14,7 +14,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
"github.com/AdguardTeam/AdGuardHome/internal/stats"
"github.com/AdguardTeam/AdGuardHome/internal/version"
"github.com/AdguardTeam/dnsproxy/fastip"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
@@ -23,10 +22,9 @@ import (
yaml "gopkg.in/yaml.v3"
)
const (
dataDir = "data" // data storage
filterDir = "filters" // cache location for downloaded filters, it's under DataDir
)
// dataDir is the name of a directory under the working one to store some
// persistent data.
const dataDir = "data"
// logSettings are the logging settings part of the configuration file.
//
@@ -108,9 +106,16 @@ type configuration struct {
DNS dnsConfig `yaml:"dns"`
TLS tlsConfigSettings `yaml:"tls"`
Filters []filter `yaml:"filters"`
WhitelistFilters []filter `yaml:"whitelist_filters"`
UserRules []string `yaml:"user_rules"`
// Filters reflects the filters from [filtering.Config]. It's cloned to the
// config used in the filtering module at the startup. Afterwards it's
// cloned from the filtering module back here.
//
// TODO(e.burkov): Move all the filtering configuration fields into the
// only configuration subsection covering the changes with a single
// migration.
Filters []filtering.FilterYAML `yaml:"filters"`
WhitelistFilters []filtering.FilterYAML `yaml:"whitelist_filters"`
UserRules []string `yaml:"user_rules"`
DHCP *dhcpd.ServerConfig `yaml:"dhcp"`
@@ -145,9 +150,7 @@ type dnsConfig struct {
dnsforward.FilteringConfig `yaml:",inline"`
FilteringEnabled bool `yaml:"filtering_enabled"` // whether or not use filter lists
FiltersUpdateIntervalHours uint32 `yaml:"filters_update_interval"` // time period to update filters (in hours)
DnsfilterConf filtering.Config `yaml:",inline"`
DnsfilterConf *filtering.Config `yaml:",inline"`
// UpstreamTimeout is the timeout for querying upstream servers.
UpstreamTimeout timeutil.Duration `yaml:"upstream_timeout"`
@@ -193,15 +196,20 @@ type tlsConfigSettings struct {
//
// TODO(a.garipov, e.burkov): This global is awful and must be removed.
var config = &configuration{
BindPort: 3000,
BetaBindPort: 0,
BindHost: net.IP{0, 0, 0, 0},
AuthAttempts: 5,
AuthBlockMin: 15,
BindPort: 3000,
BetaBindPort: 0,
BindHost: net.IP{0, 0, 0, 0},
AuthAttempts: 5,
AuthBlockMin: 15,
WebSessionTTLHours: 30 * 24,
DNS: dnsConfig{
BindHosts: []net.IP{{0, 0, 0, 0}},
Port: defaultPortDNS,
StatsInterval: 1,
BindHosts: []net.IP{{0, 0, 0, 0}},
Port: defaultPortDNS,
StatsInterval: 1,
QueryLogEnabled: true,
QueryLogFileEnabled: true,
QueryLogInterval: timeutil.Duration{Duration: 90 * timeutil.Day},
QueryLogMemSize: 1000,
FilteringConfig: dnsforward.FilteringConfig{
ProtectionEnabled: true, // whether or not use any of filtering features
BlockingMode: dnsforward.BlockingModeDefault,
@@ -222,18 +230,42 @@ var config = &configuration{
// was later increased to 300 due to https://github.com/AdguardTeam/AdGuardHome/issues/2257
MaxGoroutines: 300,
},
FilteringEnabled: true, // whether or not use filter lists
FiltersUpdateIntervalHours: 24,
UpstreamTimeout: timeutil.Duration{Duration: dnsforward.DefaultTimeout},
UsePrivateRDNS: true,
DnsfilterConf: &filtering.Config{
SafeBrowsingCacheSize: 1 * 1024 * 1024,
SafeSearchCacheSize: 1 * 1024 * 1024,
ParentalCacheSize: 1 * 1024 * 1024,
CacheTime: 30,
FilteringEnabled: true,
FiltersUpdateIntervalHours: 24,
},
UpstreamTimeout: timeutil.Duration{Duration: dnsforward.DefaultTimeout},
UsePrivateRDNS: true,
},
TLS: tlsConfigSettings{
PortHTTPS: defaultPortHTTPS,
PortDNSOverTLS: defaultPortTLS, // needs to be passed through to dnsproxy
PortDNSOverQUIC: defaultPortQUIC,
},
Filters: []filtering.FilterYAML{{
Filter: filtering.Filter{ID: 1},
Enabled: true,
URL: "https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt",
Name: "AdGuard DNS filter",
}, {
Filter: filtering.Filter{ID: 2},
Enabled: false,
URL: "https://adaway.org/hosts.txt",
Name: "AdAway Default Blocklist",
}},
DHCP: &dhcpd.ServerConfig{
LocalDomainName: "lan",
Conf4: dhcpd.V4ServerConf{
LeaseDuration: dhcpd.DefaultDHCPLeaseTTL,
ICMPTimeout: dhcpd.DefaultDHCPTimeoutICMP,
},
Conf6: dhcpd.V6ServerConf{
LeaseDuration: dhcpd.DefaultDHCPLeaseTTL,
},
},
Clients: &clientsConfig{
Sources: &clientSourcesConf{
@@ -255,31 +287,6 @@ var config = &configuration{
SchemaVersion: currentSchemaVersion,
}
// initConfig initializes default configuration for the current OS&ARCH
func initConfig() {
config.WebSessionTTLHours = 30 * 24
config.DNS.QueryLogEnabled = true
config.DNS.QueryLogFileEnabled = true
config.DNS.QueryLogInterval = timeutil.Duration{Duration: 90 * timeutil.Day}
config.DNS.QueryLogMemSize = 1000
config.DNS.CacheSize = 4 * 1024 * 1024
config.DNS.DnsfilterConf.SafeBrowsingCacheSize = 1 * 1024 * 1024
config.DNS.DnsfilterConf.SafeSearchCacheSize = 1 * 1024 * 1024
config.DNS.DnsfilterConf.ParentalCacheSize = 1 * 1024 * 1024
config.DNS.DnsfilterConf.CacheTime = 30
config.Filters = defaultFilters()
config.DHCP.Conf4.LeaseDuration = dhcpd.DefaultDHCPLeaseTTL
config.DHCP.Conf4.ICMPTimeout = dhcpd.DefaultDHCPTimeoutICMP
config.DHCP.Conf6.LeaseDuration = dhcpd.DefaultDHCPLeaseTTL
if ch := version.Channel(); ch == version.ChannelEdge || ch == version.ChannelDevelopment {
config.BetaBindPort = 3001
}
}
// getConfigFilename returns path to the current config file
func (c *configuration) getConfigFilename() string {
configFile, err := filepath.EvalSymlinks(Context.configFilename)
@@ -348,8 +355,8 @@ func parseConfig() (err error) {
return fmt.Errorf("validating udp ports: %w", err)
}
if !checkFiltersUpdateIntervalHours(config.DNS.FiltersUpdateIntervalHours) {
config.DNS.FiltersUpdateIntervalHours = 24
if !filtering.ValidateUpdateIvl(config.DNS.DnsfilterConf.FiltersUpdateIntervalHours) {
config.DNS.DnsfilterConf.FiltersUpdateIntervalHours = 24
}
if config.DNS.UpstreamTimeout.Duration == 0 {
@@ -418,10 +425,11 @@ func (c *configuration) write() (err error) {
config.DNS.AnonymizeClientIP = dc.AnonymizeClientIP
}
if Context.dnsFilter != nil {
c := filtering.Config{}
Context.dnsFilter.WriteDiskConfig(&c)
config.DNS.DnsfilterConf = c
if Context.filters != nil {
Context.filters.WriteDiskConfig(config.DNS.DnsfilterConf)
config.Filters = config.DNS.DnsfilterConf.Filters
config.WhitelistFilters = config.DNS.DnsfilterConf.WhitelistFilters
config.UserRules = config.DNS.DnsfilterConf.UserRules
}
if s := Context.dnsServer; s != nil {

View File

@@ -291,7 +291,7 @@ func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (ok bool) {
}
httpsURL := &url.URL{
Scheme: schemeHTTPS,
Scheme: aghhttp.SchemeHTTPS,
Host: hostPort,
Path: r.URL.Path,
RawQuery: r.URL.RawQuery,
@@ -307,7 +307,7 @@ func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (ok bool) {
//
// See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin.
originURL := &url.URL{
Scheme: schemeHTTP,
Scheme: aghhttp.SchemeHTTP,
Host: r.Host,
}
w.Header().Set("Access-Control-Allow-Origin", originURL.String())

View File

@@ -1,503 +0,0 @@
package home
import (
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
)
// validateFilterURL validates the filter list URL or file name.
func validateFilterURL(urlStr string) (err error) {
if filepath.IsAbs(urlStr) {
_, err = os.Stat(urlStr)
if err != nil {
return fmt.Errorf("checking filter file: %w", err)
}
return nil
}
url, err := url.ParseRequestURI(urlStr)
if err != nil {
return fmt.Errorf("checking filter url: %w", err)
}
if s := url.Scheme; s != schemeHTTP && s != schemeHTTPS {
return fmt.Errorf("checking filter url: invalid scheme %q", s)
}
return nil
}
type filterAddJSON struct {
Name string `json:"name"`
URL string `json:"url"`
Whitelist bool `json:"whitelist"`
}
func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request) {
fj := filterAddJSON{}
err := json.NewDecoder(r.Body).Decode(&fj)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "Failed to parse request body json: %s", err)
return
}
err = validateFilterURL(fj.URL)
if err != nil {
err = fmt.Errorf("invalid url: %s", err)
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return
}
// Check for duplicates
if filterExists(fj.URL) {
aghhttp.Error(r, w, http.StatusBadRequest, "Filter URL already added -- %s", fj.URL)
return
}
// Set necessary properties
filt := filter{
Enabled: true,
URL: fj.URL,
Name: fj.Name,
white: fj.Whitelist,
}
filt.ID = assignUniqueFilterID()
// Download the filter contents
ok, err := f.update(&filt)
if err != nil {
aghhttp.Error(
r,
w,
http.StatusBadRequest,
"Couldn't fetch filter from url %s: %s",
filt.URL,
err,
)
return
}
if !ok {
aghhttp.Error(
r,
w,
http.StatusBadRequest,
"Filter at the url %s is invalid (maybe it points to blank page?)",
filt.URL,
)
return
}
// URL is assumed valid so append it to filters, update config, write new
// file and reload it to engines.
if !filterAdd(filt) {
aghhttp.Error(r, w, http.StatusBadRequest, "Filter URL already added -- %s", filt.URL)
return
}
onConfigModified()
enableFilters(true)
_, err = fmt.Fprintf(w, "OK %d rules\n", filt.RulesCount)
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't write body: %s", err)
}
}
func (f *Filtering) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Request) {
type request struct {
URL string `json:"url"`
Whitelist bool `json:"whitelist"`
}
req := request{}
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "failed to parse request body json: %s", err)
return
}
config.Lock()
filters := &config.Filters
if req.Whitelist {
filters = &config.WhitelistFilters
}
var deleted filter
var newFilters []filter
for _, f := range *filters {
if f.URL != req.URL {
newFilters = append(newFilters, f)
continue
}
deleted = f
path := f.Path()
err = os.Rename(path, path+".old")
if err != nil {
log.Error("deleting filter %q: %s", path, err)
}
}
*filters = newFilters
config.Unlock()
onConfigModified()
enableFilters(true)
// NOTE: The old files "filter.txt.old" aren't deleted. It's not really
// necessary, but will require the additional complicated code to run
// after enableFilters is done.
//
// TODO(a.garipov): Make sure the above comment is true.
_, err = fmt.Fprintf(w, "OK %d rules\n", deleted.RulesCount)
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "couldn't write body: %s", err)
}
}
type filterURLReqData struct {
Name string `json:"name"`
URL string `json:"url"`
Enabled bool `json:"enabled"`
}
type filterURLReq struct {
Data *filterURLReqData `json:"data"`
URL string `json:"url"`
Whitelist bool `json:"whitelist"`
}
func (f *Filtering) handleFilteringSetURL(w http.ResponseWriter, r *http.Request) {
fj := filterURLReq{}
err := json.NewDecoder(r.Body).Decode(&fj)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "json decode: %s", err)
return
}
if fj.Data == nil {
err = errors.Error("data cannot be null")
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return
}
err = validateFilterURL(fj.Data.URL)
if err != nil {
err = fmt.Errorf("invalid url: %s", err)
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return
}
filt := filter{
Enabled: fj.Data.Enabled,
Name: fj.Data.Name,
URL: fj.Data.URL,
}
status := f.filterSetProperties(fj.URL, filt, fj.Whitelist)
if (status & statusFound) == 0 {
http.Error(w, "URL doesn't exist", http.StatusBadRequest)
return
}
if (status & statusURLExists) != 0 {
http.Error(w, "URL already exists", http.StatusBadRequest)
return
}
onConfigModified()
restart := (status & statusEnabledChanged) != 0
if (status&statusUpdateRequired) != 0 && fj.Data.Enabled {
// download new filter and apply its rules
flags := filterRefreshBlocklists
if fj.Whitelist {
flags = filterRefreshAllowlists
}
nUpdated, _ := f.refreshFilters(flags, true)
// if at least 1 filter has been updated, refreshFilters() restarts the filtering automatically
// if not - we restart the filtering ourselves
restart = false
if nUpdated == 0 {
restart = true
}
}
if restart {
enableFilters(true)
}
}
func (f *Filtering) handleFilteringSetRules(w http.ResponseWriter, r *http.Request) {
// This use of ReadAll is safe, because request's body is now limited.
body, err := io.ReadAll(r.Body)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "Failed to read request body: %s", err)
return
}
config.UserRules = strings.Split(string(body), "\n")
onConfigModified()
enableFilters(true)
}
func (f *Filtering) handleFilteringRefresh(w http.ResponseWriter, r *http.Request) {
type Req struct {
White bool `json:"whitelist"`
}
type Resp struct {
Updated int `json:"updated"`
}
resp := Resp{}
var err error
req := Req{}
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "json decode: %s", err)
return
}
flags := filterRefreshBlocklists
if req.White {
flags = filterRefreshAllowlists
}
func() {
// Temporarily unlock the Context.controlLock because the
// f.refreshFilters waits for it to be unlocked but it's
// actually locked in ensure wrapper.
//
// TODO(e.burkov): Reconsider this messy syncing process.
Context.controlLock.Unlock()
defer Context.controlLock.Lock()
resp.Updated, err = f.refreshFilters(flags|filterRefreshForce, false)
}()
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
return
}
js, err := json.Marshal(resp)
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "json encode: %s", err)
return
}
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write(js)
}
type filterJSON struct {
URL string `json:"url"`
Name string `json:"name"`
LastUpdated string `json:"last_updated,omitempty"`
ID int64 `json:"id"`
RulesCount uint32 `json:"rules_count"`
Enabled bool `json:"enabled"`
}
type filteringConfig struct {
Filters []filterJSON `json:"filters"`
WhitelistFilters []filterJSON `json:"whitelist_filters"`
UserRules []string `json:"user_rules"`
Interval uint32 `json:"interval"` // in hours
Enabled bool `json:"enabled"`
}
func filterToJSON(f filter) filterJSON {
fj := filterJSON{
ID: f.ID,
Enabled: f.Enabled,
URL: f.URL,
Name: f.Name,
RulesCount: uint32(f.RulesCount),
}
if !f.LastUpdated.IsZero() {
fj.LastUpdated = f.LastUpdated.Format(time.RFC3339)
}
return fj
}
// Get filtering configuration
func (f *Filtering) handleFilteringStatus(w http.ResponseWriter, r *http.Request) {
resp := filteringConfig{}
config.RLock()
resp.Enabled = config.DNS.FilteringEnabled
resp.Interval = config.DNS.FiltersUpdateIntervalHours
for _, f := range config.Filters {
fj := filterToJSON(f)
resp.Filters = append(resp.Filters, fj)
}
for _, f := range config.WhitelistFilters {
fj := filterToJSON(f)
resp.WhitelistFilters = append(resp.WhitelistFilters, fj)
}
resp.UserRules = config.UserRules
config.RUnlock()
jsonVal, err := json.Marshal(resp)
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "json encode: %s", err)
return
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal)
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "http write: %s", err)
}
}
// Set filtering configuration
func (f *Filtering) handleFilteringConfig(w http.ResponseWriter, r *http.Request) {
req := filteringConfig{}
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "json decode: %s", err)
return
}
if !checkFiltersUpdateIntervalHours(req.Interval) {
aghhttp.Error(r, w, http.StatusBadRequest, "Unsupported interval")
return
}
func() {
config.Lock()
defer config.Unlock()
config.DNS.FilteringEnabled = req.Enabled
config.DNS.FiltersUpdateIntervalHours = req.Interval
}()
onConfigModified()
enableFilters(true)
}
type checkHostRespRule struct {
Text string `json:"text"`
FilterListID int64 `json:"filter_list_id"`
}
type checkHostResp struct {
Reason string `json:"reason"`
// Rule is the text of the matched rule.
//
// Deprecated: Use Rules[*].Text.
Rule string `json:"rule"`
Rules []*checkHostRespRule `json:"rules"`
// for FilteredBlockedService:
SvcName string `json:"service_name"`
// for Rewrite:
CanonName string `json:"cname"` // CNAME value
IPList []net.IP `json:"ip_addrs"` // list of IP addresses
// FilterID is the ID of the rule's filter list.
//
// Deprecated: Use Rules[*].FilterListID.
FilterID int64 `json:"filter_id"`
}
func (f *Filtering) handleCheckHost(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query()
host := q.Get("name")
setts := Context.dnsFilter.GetConfig()
setts.FilteringEnabled = true
setts.ProtectionEnabled = true
Context.dnsFilter.ApplyBlockedServices(&setts, nil, true)
result, err := Context.dnsFilter.CheckHost(host, dns.TypeA, &setts)
if err != nil {
aghhttp.Error(
r,
w,
http.StatusInternalServerError,
"couldn't apply filtering: %s: %s",
host,
err,
)
return
}
resp := checkHostResp{}
resp.Reason = result.Reason.String()
resp.SvcName = result.ServiceName
resp.CanonName = result.CanonName
resp.IPList = result.IPList
if len(result.Rules) > 0 {
resp.FilterID = result.Rules[0].FilterListID
resp.Rule = result.Rules[0].Text
}
resp.Rules = make([]*checkHostRespRule, len(result.Rules))
for i, r := range result.Rules {
resp.Rules[i] = &checkHostRespRule{
FilterListID: r.FilterListID,
Text: r.Text,
}
}
js, err := json.Marshal(resp)
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "json encode: %s", err)
return
}
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write(js)
}
// RegisterFilteringHandlers - register handlers
func (f *Filtering) RegisterFilteringHandlers() {
httpRegister(http.MethodGet, "/control/filtering/status", f.handleFilteringStatus)
httpRegister(http.MethodPost, "/control/filtering/config", f.handleFilteringConfig)
httpRegister(http.MethodPost, "/control/filtering/add_url", f.handleFilteringAddURL)
httpRegister(http.MethodPost, "/control/filtering/remove_url", f.handleFilteringRemoveURL)
httpRegister(http.MethodPost, "/control/filtering/set_url", f.handleFilteringSetURL)
httpRegister(http.MethodPost, "/control/filtering/refresh", f.handleFilteringRefresh)
httpRegister(http.MethodPost, "/control/filtering/set_rules", f.handleFilteringSetRules)
httpRegister(http.MethodGet, "/control/filtering/check_host", f.handleCheckHost)
}
func checkFiltersUpdateIntervalHours(i uint32) bool {
return i == 0 || i == 1 || i == 12 || i == 1*24 || i == 3*24 || i == 7*24
}

View File

@@ -31,7 +31,10 @@ const (
// Called by other modules when configuration is changed
func onConfigModified() {
_ = config.write()
err := config.write()
if err != nil {
log.Error("writing config: %s", err)
}
}
// initDNSServer creates an instance of the dnsforward.Server
@@ -71,11 +74,11 @@ func initDNSServer() (err error) {
}
Context.queryLog = querylog.New(conf)
filterConf := config.DNS.DnsfilterConf
filterConf.EtcHosts = Context.etcHosts
filterConf.ConfigModified = onConfigModified
filterConf.HTTPRegister = httpRegister
Context.dnsFilter = filtering.New(&filterConf, nil)
Context.filters, err = filtering.New(config.DNS.DnsfilterConf, nil)
if err != nil {
// Don't wrap the error, since it's informative enough as is.
return err
}
var privateNets netutil.SubnetSet
switch len(config.DNS.PrivateNets) {
@@ -83,13 +86,10 @@ func initDNSServer() (err error) {
// Use an optimized locally-served matcher.
privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed)
case 1:
var n *net.IPNet
n, err = netutil.ParseSubnet(config.DNS.PrivateNets[0])
privateNets, err = netutil.ParseSubnet(config.DNS.PrivateNets[0])
if err != nil {
return fmt.Errorf("preparing the set of private subnets: %w", err)
}
privateNets = n
default:
var nets []*net.IPNet
nets, err = netutil.ParseSubnets(config.DNS.PrivateNets...)
@@ -101,15 +101,13 @@ func initDNSServer() (err error) {
}
p := dnsforward.DNSCreateParams{
DNSFilter: Context.dnsFilter,
DNSFilter: Context.filters,
Stats: Context.stats,
QueryLog: Context.queryLog,
PrivateNets: privateNets,
Anonymizer: anonymizer,
LocalDomain: config.DHCP.LocalDomainName,
}
if Context.dhcpServer != nil {
p.DHCPServer = Context.dhcpServer
DHCPServer: Context.dhcpServer,
}
Context.dnsServer, err = dnsforward.NewServer(p)
@@ -143,7 +141,6 @@ func initDNSServer() (err error) {
Context.whois = initWHOIS(&Context.clients)
}
Context.filters.Init()
return nil
}
@@ -244,7 +241,6 @@ func generateServerConfig() (newConf dnsforward.ServerConfig, err error) {
}
newConf.TLSv12Roots = Context.tlsRoots
newConf.TLSCiphers = Context.tlsCiphers
newConf.TLSAllowUnencryptedDoH = tlsConf.AllowUnencryptedDoH
newConf.FilterHandler = applyAdditionalFiltering
@@ -336,9 +332,12 @@ func getDNSEncryption() (de dnsEncryption) {
// applyAdditionalFiltering adds additional client information and settings if
// the client has them.
func applyAdditionalFiltering(clientIP net.IP, clientID string, setts *filtering.Settings) {
Context.dnsFilter.ApplyBlockedServices(setts, nil, true)
// pref is a prefix for logging messages around the scope.
const pref = "applying filters"
log.Debug("looking up settings for client with ip %s and clientid %q", clientIP, clientID)
Context.filters.ApplyBlockedServices(setts, nil)
log.Debug("%s: looking for client with ip %s and clientid %q", pref, clientIP, clientID)
if clientIP == nil {
return
@@ -350,16 +349,16 @@ func applyAdditionalFiltering(clientIP net.IP, clientID string, setts *filtering
if !ok {
c, ok = Context.clients.Find(clientIP.String())
if !ok {
log.Debug("client with ip %s and clientid %q not found", clientIP, clientID)
log.Debug("%s: no clients with ip %s and clientid %q", pref, clientIP, clientID)
return
}
}
log.Debug("using settings for client %q with ip %s and clientid %q", c.Name, clientIP, clientID)
log.Debug("%s: using settings for client %q (%s; %q)", pref, c.Name, clientIP, clientID)
if c.UseOwnBlockedServices {
Context.dnsFilter.ApplyBlockedServices(setts, c.BlockedServices, false)
Context.filters.ApplyBlockedServices(setts, c.BlockedServices)
}
setts.ClientName = c.Name
@@ -382,7 +381,7 @@ func startDNSServer() error {
return fmt.Errorf("unable to start forwarding DNS server: Already running")
}
enableFiltersLocked(false)
Context.filters.EnableFilters(false)
Context.clients.Start()
@@ -391,7 +390,6 @@ func startDNSServer() error {
return fmt.Errorf("couldn't start forwarding DNS server: %w", err)
}
Context.dnsFilter.Start()
Context.filters.Start()
Context.stats.Start()
Context.queryLog.Start()
@@ -450,10 +448,7 @@ func closeDNSServer() {
Context.dnsServer = nil
}
if Context.dnsFilter != nil {
Context.dnsFilter.Close()
Context.dnsFilter = nil
}
Context.filters.Close()
if Context.stats != nil {
err := Context.stats.Close()
@@ -470,7 +465,5 @@ func closeDNSServer() {
Context.queryLog = nil
}
Context.filters.Close()
log.Debug("Closed all DNS modules")
log.Debug("all dns modules are closed")
}

View File

@@ -1,746 +0,0 @@
package home
import (
"bufio"
"fmt"
"hash/crc32"
"io"
"net/http"
"os"
"path/filepath"
"regexp"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/stringutil"
)
var nextFilterID = time.Now().Unix() // semi-stable way to generate an unique ID
// Filtering - module object
type Filtering struct {
// conf FilteringConf
refreshStatus uint32 // 0:none; 1:in progress
refreshLock sync.Mutex
filterTitleRegexp *regexp.Regexp
}
// Init - initialize the module
func (f *Filtering) Init() {
f.filterTitleRegexp = regexp.MustCompile(`^! Title: +(.*)$`)
_ = os.MkdirAll(filepath.Join(Context.getDataDir(), filterDir), 0o755)
f.loadFilters(config.Filters)
f.loadFilters(config.WhitelistFilters)
deduplicateFilters()
updateUniqueFilterID(config.Filters)
updateUniqueFilterID(config.WhitelistFilters)
}
// Start - start the module
func (f *Filtering) Start() {
f.RegisterFilteringHandlers()
// Here we should start updating filters,
// but currently we can't wake up the periodic task to do so.
// So for now we just start this periodic task from here.
go f.periodicallyRefreshFilters()
}
// Close - close the module
func (f *Filtering) Close() {
}
func defaultFilters() []filter {
return []filter{
{Filter: filtering.Filter{ID: 1}, Enabled: true, URL: "https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt", Name: "AdGuard DNS filter"},
{Filter: filtering.Filter{ID: 2}, Enabled: false, URL: "https://adaway.org/hosts.txt", Name: "AdAway Default Blocklist"},
}
}
// field ordering is important -- yaml fields will mirror ordering from here
type filter struct {
Enabled bool
URL string // URL or a file path
Name string `yaml:"name"`
RulesCount int `yaml:"-"`
LastUpdated time.Time `yaml:"-"`
checksum uint32 // checksum of the file data
white bool
filtering.Filter `yaml:",inline"`
}
const (
statusFound = 1
statusEnabledChanged = 2
statusURLChanged = 4
statusURLExists = 8
statusUpdateRequired = 0x10
)
// Update properties for a filter specified by its URL
// Return status* flags.
func (f *Filtering) filterSetProperties(url string, newf filter, whitelist bool) int {
r := 0
config.Lock()
defer config.Unlock()
filters := &config.Filters
if whitelist {
filters = &config.WhitelistFilters
}
for i := range *filters {
filt := &(*filters)[i]
if filt.URL != url {
continue
}
log.Debug("filter: set properties: %s: {%s %s %v}",
filt.URL, newf.Name, newf.URL, newf.Enabled)
filt.Name = newf.Name
if filt.URL != newf.URL {
r |= statusURLChanged | statusUpdateRequired
if filterExistsNoLock(newf.URL) {
return statusURLExists
}
filt.URL = newf.URL
filt.unload()
filt.LastUpdated = time.Time{}
filt.checksum = 0
filt.RulesCount = 0
}
if filt.Enabled != newf.Enabled {
r |= statusEnabledChanged
filt.Enabled = newf.Enabled
if filt.Enabled {
if (r & statusURLChanged) == 0 {
e := f.load(filt)
if e != nil {
// This isn't a fatal error,
// because it may occur when someone removes the file from disk.
filt.LastUpdated = time.Time{}
filt.checksum = 0
filt.RulesCount = 0
r |= statusUpdateRequired
}
}
} else {
filt.unload()
}
}
return r | statusFound
}
return 0
}
// Return TRUE if a filter with this URL exists
func filterExists(url string) bool {
config.RLock()
r := filterExistsNoLock(url)
config.RUnlock()
return r
}
func filterExistsNoLock(url string) bool {
for _, f := range config.Filters {
if f.URL == url {
return true
}
}
for _, f := range config.WhitelistFilters {
if f.URL == url {
return true
}
}
return false
}
// Add a filter
// Return FALSE if a filter with this URL exists
func filterAdd(f filter) bool {
config.Lock()
defer config.Unlock()
// Check for duplicates
if filterExistsNoLock(f.URL) {
return false
}
if f.white {
config.WhitelistFilters = append(config.WhitelistFilters, f)
} else {
config.Filters = append(config.Filters, f)
}
return true
}
// Load filters from the disk
// And if any filter has zero ID, assign a new one
func (f *Filtering) loadFilters(array []filter) {
for i := range array {
filter := &array[i] // otherwise we're operating on a copy
if filter.ID == 0 {
filter.ID = assignUniqueFilterID()
}
if !filter.Enabled {
// No need to load a filter that is not enabled
continue
}
err := f.load(filter)
if err != nil {
log.Error("Couldn't load filter %d contents due to %s", filter.ID, err)
}
}
}
func deduplicateFilters() {
// Deduplicate filters
i := 0 // output index, used for deletion later
urls := map[string]bool{}
for _, filter := range config.Filters {
if _, ok := urls[filter.URL]; !ok {
// we didn't see it before, keep it
urls[filter.URL] = true // remember the URL
config.Filters[i] = filter
i++
}
}
// all entries we want to keep are at front, delete the rest
config.Filters = config.Filters[:i]
}
// Set the next filter ID to max(filter.ID) + 1
func updateUniqueFilterID(filters []filter) {
for _, filter := range filters {
if nextFilterID < filter.ID {
nextFilterID = filter.ID + 1
}
}
}
func assignUniqueFilterID() int64 {
value := nextFilterID
nextFilterID++
return value
}
// Sets up a timer that will be checking for filters updates periodically
func (f *Filtering) periodicallyRefreshFilters() {
const maxInterval = 1 * 60 * 60
intval := 5 // use a dynamically increasing time interval
for {
isNetworkErr := false
if config.DNS.FiltersUpdateIntervalHours != 0 && atomic.CompareAndSwapUint32(&f.refreshStatus, 0, 1) {
f.refreshLock.Lock()
_, isNetworkErr = f.refreshFiltersIfNecessary(filterRefreshBlocklists | filterRefreshAllowlists)
f.refreshLock.Unlock()
f.refreshStatus = 0
if !isNetworkErr {
intval = maxInterval
}
}
if isNetworkErr {
intval *= 2
if intval > maxInterval {
intval = maxInterval
}
}
time.Sleep(time.Duration(intval) * time.Second)
}
}
// Refresh filters
// flags: filterRefresh*
// important:
//
// TRUE: ignore the fact that we're currently updating the filters
func (f *Filtering) refreshFilters(flags int, important bool) (int, error) {
set := atomic.CompareAndSwapUint32(&f.refreshStatus, 0, 1)
if !important && !set {
return 0, fmt.Errorf("filters update procedure is already running")
}
f.refreshLock.Lock()
nUpdated, _ := f.refreshFiltersIfNecessary(flags)
f.refreshLock.Unlock()
f.refreshStatus = 0
return nUpdated, nil
}
func (f *Filtering) refreshFiltersArray(filters *[]filter, force bool) (int, []filter, []bool, bool) {
var updateFilters []filter
var updateFlags []bool // 'true' if filter data has changed
now := time.Now()
config.RLock()
for i := range *filters {
f := &(*filters)[i] // otherwise we will be operating on a copy
if !f.Enabled {
continue
}
expireTime := f.LastUpdated.Unix() + int64(config.DNS.FiltersUpdateIntervalHours)*60*60
if !force && expireTime > now.Unix() {
continue
}
var uf filter
uf.ID = f.ID
uf.URL = f.URL
uf.Name = f.Name
uf.checksum = f.checksum
updateFilters = append(updateFilters, uf)
}
config.RUnlock()
if len(updateFilters) == 0 {
return 0, nil, nil, false
}
nfail := 0
for i := range updateFilters {
uf := &updateFilters[i]
updated, err := f.update(uf)
updateFlags = append(updateFlags, updated)
if err != nil {
nfail++
log.Printf("Failed to update filter %s: %s\n", uf.URL, err)
continue
}
}
if nfail == len(updateFilters) {
return 0, nil, nil, true
}
updateCount := 0
for i := range updateFilters {
uf := &updateFilters[i]
updated := updateFlags[i]
config.Lock()
for k := range *filters {
f := &(*filters)[k]
if f.ID != uf.ID || f.URL != uf.URL {
continue
}
f.LastUpdated = uf.LastUpdated
if !updated {
continue
}
log.Info("Updated filter #%d. Rules: %d -> %d",
f.ID, f.RulesCount, uf.RulesCount)
f.Name = uf.Name
f.RulesCount = uf.RulesCount
f.checksum = uf.checksum
updateCount++
}
config.Unlock()
}
return updateCount, updateFilters, updateFlags, false
}
const (
filterRefreshForce = 1 // ignore last file modification date
filterRefreshAllowlists = 2 // update allow-lists
filterRefreshBlocklists = 4 // update block-lists
)
// refreshFiltersIfNecessary checks filters and updates them if necessary. If
// force is true, it ignores the filter.LastUpdated field value.
//
// Algorithm:
//
// 1. Get the list of filters to be updated. For each filter, run the download
// and checksum check operation. Store downloaded data in a temporary file
// inside data/filters directory
//
// 2. For each filter, if filter data hasn't changed, just set new update time
// on file. Otherwise, rename the temporary file (<temp> -> 1.txt). Note
// that this method works only on Unix systems. On Windows, don't pass
// files to filtering, pass the whole data.
//
// refreshFiltersIfNecessary returns the number of updated filters. It also
// returns true if there was a network error and nothing could be updated.
//
// TODO(a.garipov, e.burkov): What the hell?
func (f *Filtering) refreshFiltersIfNecessary(flags int) (int, bool) {
log.Debug("Filters: updating...")
updateCount := 0
var updateFilters []filter
var updateFlags []bool
netError := false
netErrorW := false
force := false
if (flags & filterRefreshForce) != 0 {
force = true
}
if (flags & filterRefreshBlocklists) != 0 {
updateCount, updateFilters, updateFlags, netError = f.refreshFiltersArray(&config.Filters, force)
}
if (flags & filterRefreshAllowlists) != 0 {
updateCountW := 0
var updateFiltersW []filter
var updateFlagsW []bool
updateCountW, updateFiltersW, updateFlagsW, netErrorW = f.refreshFiltersArray(&config.WhitelistFilters, force)
updateCount += updateCountW
updateFilters = append(updateFilters, updateFiltersW...)
updateFlags = append(updateFlags, updateFlagsW...)
}
if netError && netErrorW {
return 0, true
}
if updateCount != 0 {
enableFilters(false)
for i := range updateFilters {
uf := &updateFilters[i]
updated := updateFlags[i]
if !updated {
continue
}
_ = os.Remove(uf.Path() + ".old")
}
}
log.Debug("Filters: update finished")
return updateCount, false
}
// Allows printable UTF-8 text with CR, LF, TAB characters
func isPrintableText(data []byte, len int) bool {
for i := 0; i < len; i++ {
c := data[i]
if (c >= ' ' && c != 0x7f) || c == '\n' || c == '\r' || c == '\t' {
continue
}
return false
}
return true
}
// A helper function that parses filter contents and returns a number of rules and a filter name (if there's any)
func (f *Filtering) parseFilterContents(file io.Reader) (int, uint32, string) {
rulesCount := 0
name := ""
seenTitle := false
r := bufio.NewReader(file)
checksum := uint32(0)
for {
line, err := r.ReadString('\n')
checksum = crc32.Update(checksum, crc32.IEEETable, []byte(line))
line = strings.TrimSpace(line)
if len(line) == 0 {
//
} else if line[0] == '!' {
m := f.filterTitleRegexp.FindAllStringSubmatch(line, -1)
if len(m) > 0 && len(m[0]) >= 2 && !seenTitle {
name = m[0][1]
seenTitle = true
}
} else if line[0] == '#' {
//
} else {
rulesCount++
}
if err != nil {
break
}
}
return rulesCount, checksum, name
}
// Perform upgrade on a filter and update LastUpdated value
func (f *Filtering) update(filter *filter) (bool, error) {
b, err := f.updateIntl(filter)
filter.LastUpdated = time.Now()
if !b {
e := os.Chtimes(filter.Path(), filter.LastUpdated, filter.LastUpdated)
if e != nil {
log.Error("os.Chtimes(): %v", e)
}
}
return b, err
}
func (f *Filtering) read(reader io.Reader, tmpFile *os.File, filter *filter) (int, error) {
htmlTest := true
firstChunk := make([]byte, 4*1024)
firstChunkLen := 0
buf := make([]byte, 64*1024)
total := 0
for {
n, err := reader.Read(buf)
total += n
if htmlTest {
num := len(firstChunk) - firstChunkLen
if n < num {
num = n
}
copied := copy(firstChunk[firstChunkLen:], buf[:num])
firstChunkLen += copied
if firstChunkLen == len(firstChunk) || err == io.EOF {
if !isPrintableText(firstChunk, firstChunkLen) {
return total, fmt.Errorf("data contains non-printable characters")
}
s := strings.ToLower(string(firstChunk))
if strings.Contains(s, "<html") || strings.Contains(s, "<!doctype") {
return total, fmt.Errorf("data is HTML, not plain text")
}
htmlTest = false
firstChunk = nil
}
}
_, err2 := tmpFile.Write(buf[:n])
if err2 != nil {
return total, err2
}
if err == io.EOF {
return total, nil
}
if err != nil {
log.Printf("Couldn't fetch filter contents from URL %s, skipping: %s", filter.URL, err)
return total, err
}
}
}
// finalizeUpdate closes and gets rid of temporary file f with filter's content
// according to updated. It also saves new values of flt's name, rules number
// and checksum if sucсeeded.
func finalizeUpdate(
f *os.File,
flt *filter,
updated bool,
name string,
rnum int,
cs uint32,
) (err error) {
tmpFileName := f.Name()
// Close the file before renaming it because it's required on Windows.
//
// See https://github.com/adguardTeam/adGuardHome/issues/1553.
if err = f.Close(); err != nil {
return fmt.Errorf("closing temporary file: %w", err)
}
if !updated {
log.Tracef("filter #%d from %s has no changes, skip", flt.ID, flt.URL)
return os.Remove(tmpFileName)
}
log.Printf("saving filter %d contents to: %s", flt.ID, flt.Path())
if err = os.Rename(tmpFileName, flt.Path()); err != nil {
return errors.WithDeferred(err, os.Remove(tmpFileName))
}
flt.Name = stringutil.Coalesce(flt.Name, name)
flt.checksum = cs
flt.RulesCount = rnum
return nil
}
// processUpdate copies filter's content from src to dst and returns the name,
// rules number, and checksum for it. It also returns the number of bytes read
// from src.
func (f *Filtering) processUpdate(
src io.Reader,
dst *os.File,
flt *filter,
) (name string, rnum int, cs uint32, n int, err error) {
if n, err = f.read(src, dst, flt); err != nil {
return "", 0, 0, 0, err
}
if _, err = dst.Seek(0, io.SeekStart); err != nil {
return "", 0, 0, 0, err
}
rnum, cs, name = f.parseFilterContents(dst)
return name, rnum, cs, n, nil
}
// updateIntl updates the flt rewriting it's actual file. It returns true if
// the actual update has been performed.
func (f *Filtering) updateIntl(flt *filter) (ok bool, err error) {
log.Tracef("downloading update for filter %d from %s", flt.ID, flt.URL)
var name string
var rnum, n int
var cs uint32
var tmpFile *os.File
tmpFile, err = os.CreateTemp(filepath.Join(Context.getDataDir(), filterDir), "")
if err != nil {
return false, err
}
defer func() {
err = errors.WithDeferred(err, finalizeUpdate(tmpFile, flt, ok, name, rnum, cs))
ok = ok && err == nil
if ok {
log.Printf("updated filter %d: %d bytes, %d rules", flt.ID, n, rnum)
}
}()
// Change the default 0o600 permission to something more acceptable by
// end users.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/3198.
if err = tmpFile.Chmod(0o644); err != nil {
return false, fmt.Errorf("changing file mode: %w", err)
}
var r io.Reader
if filepath.IsAbs(flt.URL) {
var file io.ReadCloser
file, err = os.Open(flt.URL)
if err != nil {
return false, fmt.Errorf("open file: %w", err)
}
defer func() { err = errors.WithDeferred(err, file.Close()) }()
r = file
} else {
var resp *http.Response
resp, err = Context.client.Get(flt.URL)
if err != nil {
log.Printf("requesting filter from %s, skip: %s", flt.URL, err)
return false, err
}
defer func() { err = errors.WithDeferred(err, resp.Body.Close()) }()
if resp.StatusCode != http.StatusOK {
log.Printf("got status code %d from %s, skip", resp.StatusCode, flt.URL)
return false, fmt.Errorf("got status code != 200: %d", resp.StatusCode)
}
r = resp.Body
}
name, rnum, cs, n, err = f.processUpdate(r, tmpFile, flt)
return cs != flt.checksum, err
}
// loads filter contents from the file in dataDir
func (f *Filtering) load(filter *filter) (err error) {
filterFilePath := filter.Path()
log.Tracef("filtering: loading filter %d contents to: %s", filter.ID, filterFilePath)
file, err := os.Open(filterFilePath)
if errors.Is(err, os.ErrNotExist) {
// Do nothing, file doesn't exist.
return nil
} else if err != nil {
return fmt.Errorf("opening filter file: %w", err)
}
defer func() { err = errors.WithDeferred(err, file.Close()) }()
st, err := file.Stat()
if err != nil {
return fmt.Errorf("getting filter file stat: %w", err)
}
log.Tracef("filtering: File %s, id %d, length %d", filterFilePath, filter.ID, st.Size())
rulesCount, checksum, _ := f.parseFilterContents(file)
filter.RulesCount = rulesCount
filter.checksum = checksum
filter.LastUpdated = st.ModTime()
return nil
}
// Clear filter rules
func (filter *filter) unload() {
filter.RulesCount = 0
filter.checksum = 0
}
// Path to the filter contents
func (filter *filter) Path() string {
return filepath.Join(Context.getDataDir(), filterDir, strconv.FormatInt(filter.ID, 10)+".txt")
}
func enableFilters(async bool) {
config.RLock()
defer config.RUnlock()
enableFiltersLocked(async)
}
func enableFiltersLocked(async bool) {
filters := []filtering.Filter{{
ID: filtering.CustomListID,
Data: []byte(strings.Join(config.UserRules, "\n")),
}}
for _, filter := range config.Filters {
if !filter.Enabled {
continue
}
filters = append(filters, filtering.Filter{
ID: filter.ID,
FilePath: filter.Path(),
})
}
var allowFilters []filtering.Filter
for _, filter := range config.WhitelistFilters {
if !filter.Enabled {
continue
}
allowFilters = append(allowFilters, filtering.Filter{
ID: filter.ID,
FilePath: filter.Path(),
})
}
if err := Context.dnsFilter.SetFilters(filters, allowFilters, async); err != nil {
log.Debug("enabling filters: %s", err)
}
Context.dnsFilter.SetEnabled(config.DNS.FilteringEnabled)
}

View File

@@ -1,115 +0,0 @@
package home
import (
"io/fs"
"net"
"net/http"
"net/url"
"os"
"path"
"path/filepath"
"testing"
"time"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const testFltsFileName = "1.txt"
func testStartFilterListener(t *testing.T, fltContent *[]byte) (l net.Listener) {
t.Helper()
h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
n, werr := w.Write(*fltContent)
require.NoError(t, werr)
require.Equal(t, len(*fltContent), n)
})
var err error
l, err = net.Listen("tcp", ":0")
require.NoError(t, err)
go func() {
_ = http.Serve(l, h)
}()
testutil.CleanupAndRequireSuccess(t, l.Close)
return l
}
func TestFilters(t *testing.T) {
const content = `||example.org^$third-party
# Inline comment example
||example.com^$third-party
0.0.0.0 example.com
`
fltContent := []byte(content)
l := testStartFilterListener(t, &fltContent)
Context = homeContext{
workDir: t.TempDir(),
client: &http.Client{
Timeout: 5 * time.Second,
},
}
Context.filters.Init()
f := &filter{
URL: (&url.URL{
Scheme: "http",
Host: (&netutil.IPPort{
IP: net.IP{127, 0, 0, 1},
Port: l.Addr().(*net.TCPAddr).Port,
}).String(),
Path: path.Join(filterDir, testFltsFileName),
}).String(),
}
updateAndAssert := func(t *testing.T, want require.BoolAssertionFunc, wantRulesCount int) {
ok, err := Context.filters.update(f)
require.NoError(t, err)
want(t, ok)
assert.Equal(t, wantRulesCount, f.RulesCount)
var dir []fs.DirEntry
dir, err = os.ReadDir(filepath.Join(Context.getDataDir(), filterDir))
require.NoError(t, err)
assert.Len(t, dir, 1)
require.FileExists(t, f.Path())
err = Context.filters.load(f)
require.NoError(t, err)
}
t.Run("download", func(t *testing.T) {
updateAndAssert(t, require.True, 3)
})
t.Run("refresh_idle", func(t *testing.T) {
updateAndAssert(t, require.False, 3)
})
t.Run("refresh_actually", func(t *testing.T) {
fltContent = []byte(`||example.com^`)
t.Cleanup(func() { fltContent = []byte(content) })
updateAndAssert(t, require.True, 1)
})
t.Run("load_unload", func(t *testing.T) {
err := Context.filters.load(f)
require.NoError(t, err)
f.unload()
})
require.NoError(t, os.Remove(f.Path()))
}

View File

@@ -20,8 +20,10 @@ import (
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
@@ -32,6 +34,7 @@ import (
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"golang.org/x/exp/slices"
"gopkg.in/natefinch/lumberjack.v2"
)
@@ -51,10 +54,9 @@ type homeContext struct {
dnsServer *dnsforward.Server // DNS module
rdns *RDNS // rDNS module
whois *WHOIS // WHOIS module
dnsFilter *filtering.DNSFilter // DNS filtering module
dhcpServer dhcpd.Interface // DHCP module
auth *Auth // HTTP authentication module
filters Filtering // DNS filtering module
filters *filtering.DNSFilter // DNS filtering module
web *Web // Web (HTTP, HTTPS) module
tls *TLSMod // TLS module
// etcHosts is an IP-hostname pairs set taken from system configuration
@@ -78,7 +80,6 @@ type homeContext struct {
disableUpdate bool // If set, don't check for updates
controlLock sync.Mutex
tlsRoots *x509.CertPool // list of root CAs for TLSv1.2
tlsCiphers []uint16 // list of TLS ciphers to use
transport *http.Transport
client *http.Client
appSignalChannel chan os.Signal // Channel for receiving OS signals by the console app
@@ -140,16 +141,21 @@ func setupContext(args options) {
checkPermissions()
}
initConfig()
switch version.Channel() {
case version.ChannelEdge, version.ChannelDevelopment:
config.BetaBindPort = 3001
default:
// Go on.
}
Context.tlsRoots = LoadSystemRootCAs()
Context.tlsCiphers = InitTLSCiphers()
Context.transport = &http.Transport{
DialContext: customDialContext,
Proxy: getHTTPProxy,
TLSClientConfig: &tls.Config{
RootCAs: Context.tlsRoots,
MinVersion: tls.VersionTLS12,
RootCAs: Context.tlsRoots,
CipherSuites: aghtls.SaferCipherSuites(),
MinVersion: tls.VersionTLS12,
},
}
Context.client = &http.Client{
@@ -265,6 +271,14 @@ func setupHostsContainer() (err error) {
}
func setupConfig(args options) (err error) {
config.DNS.DnsfilterConf.EtcHosts = Context.etcHosts
config.DNS.DnsfilterConf.ConfigModified = onConfigModified
config.DNS.DnsfilterConf.HTTPRegister = httpRegister
config.DNS.DnsfilterConf.DataDir = Context.getDataDir()
config.DNS.DnsfilterConf.Filters = slices.Clone(config.Filters)
config.DNS.DnsfilterConf.WhitelistFilters = slices.Clone(config.WhitelistFilters)
config.DNS.DnsfilterConf.HTTPClient = Context.client
config.DHCP.WorkDir = Context.workDir
config.DHCP.HTTPRegister = httpRegister
config.DHCP.ConfigModified = onConfigModified
@@ -384,8 +398,6 @@ func fatalOnError(err error) {
// run configures and starts AdGuard Home.
func run(args options, clientBuildFS fs.FS) {
var err error
// configure config filename
initConfigFilename(args)
@@ -404,7 +416,7 @@ func run(args options, clientBuildFS fs.FS) {
setupContext(args)
err = configureOS(config)
err := configureOS(config)
fatalOnError(err)
// clients package uses filtering package's static data (filtering.BlockedSvcKnown()),
@@ -763,12 +775,12 @@ func printHTTPAddresses(proto string) {
}
port := config.BindPort
if proto == schemeHTTPS {
if proto == aghhttp.SchemeHTTPS {
port = tlsConf.PortHTTPS
}
// TODO(e.burkov): Inspect and perhaps merge with the previous condition.
if proto == schemeHTTPS && tlsConf.ServerName != "" {
if proto == aghhttp.SchemeHTTPS && tlsConf.ServerName != "" {
printWebAddrs(proto, tlsConf.ServerName, tlsConf.PortHTTPS, 0)
return

View File

@@ -8,6 +8,7 @@ import (
"net/url"
"path"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
@@ -82,7 +83,7 @@ func encodeMobileConfig(d *dnsSettings, clientID string) ([]byte, error) {
case dnsProtoHTTPS:
dspName = fmt.Sprintf("%s DoH", d.ServerName)
u := &url.URL{
Scheme: schemeHTTPS,
Scheme: aghhttp.SchemeHTTPS,
Host: d.ServerName,
Path: path.Join("/dns-query", clientID),
}

View File

@@ -11,6 +11,7 @@ import (
"syscall"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/AdGuardHome/internal/version"
"github.com/AdguardTeam/golibs/errors"
@@ -277,7 +278,7 @@ AdGuard Home is successfully installed and will automatically start on boot.
There are a few more things that must be configured before you can use it.
Click on the link below and follow the Installation Wizard steps to finish setup.
AdGuard Home is now available at the following addresses:`)
printHTTPAddresses(schemeHTTP)
printHTTPAddresses(aghhttp.SchemeHTTP)
}
}

View File

@@ -26,7 +26,6 @@ import (
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/google/go-cmp/cmp"
"golang.org/x/sys/cpu"
)
var tlsWebHandlersRegistered = false
@@ -754,52 +753,3 @@ func LoadSystemRootCAs() (roots *x509.CertPool) {
return nil
}
// InitTLSCiphers performs the same work as initDefaultCipherSuites() from
// crypto/tls/common.go but don't uses lots of other default ciphers.
func InitTLSCiphers() (ciphers []uint16) {
// Check the cpu flags for each platform that has optimized GCM
// implementations. The worst case is when all these variables are
// false.
var (
hasGCMAsmAMD64 = cpu.X86.HasAES && cpu.X86.HasPCLMULQDQ
hasGCMAsmARM64 = cpu.ARM64.HasAES && cpu.ARM64.HasPMULL
// Keep in sync with crypto/aes/cipher_s390x.go.
hasGCMAsmS390X = cpu.S390X.HasAES &&
cpu.S390X.HasAESCBC &&
cpu.S390X.HasAESCTR &&
(cpu.S390X.HasGHASH || cpu.S390X.HasAESGCM)
hasGCMAsm = hasGCMAsmAMD64 || hasGCMAsmARM64 || hasGCMAsmS390X
)
if hasGCMAsm {
// If AES-GCM hardware is provided then prioritize AES-GCM
// cipher suites.
ciphers = []uint16{
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
}
} else {
// Without AES-GCM hardware, we put the ChaCha20-Poly1305 cipher
// suites first.
ciphers = []uint16{
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
}
}
return append(
ciphers,
tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256,
tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256,
)
}

View File

@@ -4,6 +4,7 @@ import (
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/golibs/timeutil"
"github.com/stretchr/testify/assert"
@@ -160,7 +161,7 @@ func assertEqualExcept(t *testing.T, oldConf, newConf yobj, oldKeys, newKeys []s
}
func testDiskConf(schemaVersion int) (diskConf yobj) {
filters := []filter{{
filters := []filtering.FilterYAML{{
URL: "https://filters.adtidy.org/android/filters/111_optimized.txt",
Name: "Latvian filter",
RulesCount: 100,

View File

@@ -9,17 +9,15 @@ import (
"sync"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/NYTimes/gziphandler"
)
// HTTP scheme constants.
const (
schemeHTTP = "http"
schemeHTTPS = "https"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
)
const (
@@ -163,15 +161,18 @@ func (web *Web) Start() {
// this loop is used as an ability to change listening host and/or port
for !web.httpsServer.shutdown {
printHTTPAddresses(schemeHTTP)
printHTTPAddresses(aghhttp.SchemeHTTP)
errs := make(chan error, 2)
// Use an h2c handler to support unencrypted HTTP/2, e.g. for proxies.
hdlr := h2c.NewHandler(withMiddlewares(Context.mux, limitRequestBody), &http2.Server{})
// Create a new instance, because the Web is not usable after Shutdown.
hostStr := web.conf.BindHost.String()
// we need to have new instance, because after Shutdown() the Server is not usable
web.httpServer = &http.Server{
ErrorLog: log.StdLog("web: plain", log.DEBUG),
Addr: netutil.JoinHostPort(hostStr, web.conf.BindPort),
Handler: withMiddlewares(Context.mux, limitRequestBody),
Handler: hdlr,
ReadTimeout: web.conf.ReadTimeout,
ReadHeaderTimeout: web.conf.ReadHeaderTimeout,
WriteTimeout: web.conf.WriteTimeout,
@@ -201,10 +202,16 @@ func (web *Web) startBetaServer(hostStr string) {
return
}
// Use an h2c handler to support unencrypted HTTP/2, e.g. for proxies.
hdlr := h2c.NewHandler(
withMiddlewares(Context.mux, limitRequestBody, web.wrapIndexBeta),
&http2.Server{},
)
web.httpServerBeta = &http.Server{
ErrorLog: log.StdLog("web: plain: beta", log.DEBUG),
Addr: netutil.JoinHostPort(hostStr, web.conf.BetaBindPort),
Handler: withMiddlewares(Context.mux, limitRequestBody, web.wrapIndexBeta),
Handler: hdlr,
ReadTimeout: web.conf.ReadTimeout,
ReadHeaderTimeout: web.conf.ReadHeaderTimeout,
WriteTimeout: web.conf.WriteTimeout,
@@ -264,9 +271,9 @@ func (web *Web) tlsServerLoop() {
Addr: address,
TLSConfig: &tls.Config{
Certificates: []tls.Certificate{web.httpsServer.cert},
MinVersion: tls.VersionTLS12,
RootCAs: Context.tlsRoots,
CipherSuites: Context.tlsCiphers,
CipherSuites: aghtls.SaferCipherSuites(),
MinVersion: tls.VersionTLS12,
},
Handler: withMiddlewares(Context.mux, limitRequestBody),
ReadTimeout: web.conf.ReadTimeout,
@@ -274,7 +281,7 @@ func (web *Web) tlsServerLoop() {
WriteTimeout: web.conf.WriteTimeout,
}
printHTTPAddresses(schemeHTTPS)
printHTTPAddresses(aghhttp.SchemeHTTPS)
err := web.httpsServer.server.ListenAndServeTLS("", "")
if err != http.ErrServerClosed {
cleanupAlways()