MITM proxy

This commit is contained in:
Simon Zolin
2020-08-18 19:23:33 +03:00
parent c3123473cf
commit f85de51452
21 changed files with 2116 additions and 491 deletions

View File

@@ -9,6 +9,8 @@ import (
"github.com/AdguardTeam/AdGuardHome/dhcpd"
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
"github.com/AdguardTeam/AdGuardHome/dnsforward"
"github.com/AdguardTeam/AdGuardHome/filters"
"github.com/AdguardTeam/AdGuardHome/mitmproxy"
"github.com/AdguardTeam/AdGuardHome/querylog"
"github.com/AdguardTeam/AdGuardHome/stats"
"github.com/AdguardTeam/golibs/file"
@@ -17,8 +19,7 @@ import (
)
const (
dataDir = "data" // data storage
filterDir = "filters" // cache location for downloaded filters, it's under DataDir
dataDir = "data" // data storage
)
// logSettings
@@ -54,9 +55,13 @@ type configuration struct {
DNS dnsConfig `yaml:"dns"`
TLS tlsConfigSettings `yaml:"tls"`
Filters []filter `yaml:"filters"`
WhitelistFilters []filter `yaml:"whitelist_filters"`
UserRules []string `yaml:"user_rules"`
MITM mitmproxy.Config `yaml:"mitmproxy"`
Filters []filters.Filter `yaml:"filters"`
WhitelistFilters []filters.Filter `yaml:"whitelist_filters"`
UserRules []string `yaml:"user_rules"`
ProxyFilters []filters.Filter `yaml:"proxy_filters"`
DHCP dhcpd.ServerConfig `yaml:"dhcp"`
@@ -155,7 +160,43 @@ func initConfig() {
config.DNS.DnsfilterConf.SafeSearchCacheSize = 1 * 1024 * 1024
config.DNS.DnsfilterConf.ParentalCacheSize = 1 * 1024 * 1024
config.DNS.DnsfilterConf.CacheTime = 30
config.Filters = defaultFilters()
config.Filters = defaultDNSBlocklistFilters()
config.ProxyFilters = defaultContentFilters()
}
func defaultDNSBlocklistFilters() []filters.Filter {
return []filters.Filter{
{
ID: 1,
Enabled: true,
URL: "https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt",
Name: "AdGuard Simplified Domain Names filter",
},
{
ID: 2,
Enabled: false,
URL: "https://adaway.org/hosts.txt",
Name: "AdAway",
},
{
ID: 3,
Enabled: false,
URL: "https://www.malwaredomainlist.com/hostslist/hosts.txt",
Name: "MalwareDomainList.com Hosts List",
},
}
}
func defaultContentFilters() []filters.Filter {
return []filters.Filter{
{
ID: 1,
Enabled: true,
URL: "https://filters.adtidy.org/extension/chromium/filters/2.txt",
Name: "AdGuard Base filter",
},
}
}
// getConfigFilename returns path to the current config file
@@ -203,7 +244,7 @@ func parseConfig() error {
return err
}
if !checkFiltersUpdateIntervalHours(config.DNS.FiltersUpdateIntervalHours) {
if !filters.CheckFiltersUpdateIntervalHours(config.DNS.FiltersUpdateIntervalHours) {
config.DNS.FiltersUpdateIntervalHours = 24
}
@@ -263,6 +304,17 @@ func (c *configuration) write() error {
config.DNS.DnsfilterConf = c
}
if Context.filters != nil {
fconf := filters.ModuleConf{}
Context.filters.WriteDiskConfig(&fconf)
config.DNS.FilteringEnabled = fconf.Enabled
config.DNS.FiltersUpdateIntervalHours = fconf.UpdateIntervalHours
config.Filters = fconf.DNSBlocklist
config.WhitelistFilters = fconf.DNSAllowlist
config.ProxyFilters = fconf.Proxylist
config.UserRules = fconf.UserRules
}
if Context.dnsServer != nil {
c := dnsforward.FilteringConfig{}
Context.dnsServer.WriteDiskConfig(&c)
@@ -275,6 +327,12 @@ func (c *configuration) write() error {
config.DHCP = c
}
if Context.mitmProxy != nil {
c := mitmproxy.Config{}
Context.mitmProxy.WriteDiskConfig(&c)
config.MITM = c
}
configFile := config.getConfigFilename()
log.Debug("Writing YAML file: %s", configFile)
yamlText, err := yaml.Marshal(&config)

View File

@@ -12,6 +12,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/dnsforward"
"github.com/AdguardTeam/golibs/log"
"github.com/NYTimes/gziphandler"
"github.com/miekg/dns"
)
// ----------------
@@ -87,6 +88,48 @@ func handleGetProfile(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write(data)
}
type checkHostResp struct {
Reason string `json:"reason"`
FilterID int64 `json:"filter_id"`
Rule string `json:"rule"`
// for FilteredBlockedService:
SvcName string `json:"service_name"`
// for ReasonRewrite:
CanonName string `json:"cname"` // CNAME value
IPList []net.IP `json:"ip_addrs"` // list of IP addresses
}
func handleCheckHost(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query()
host := q.Get("name")
setts := Context.dnsFilter.GetConfig()
setts.FilteringEnabled = true
Context.dnsFilter.ApplyBlockedServices(&setts, nil, true)
result, err := Context.dnsFilter.CheckHost(host, dns.TypeA, &setts)
if err != nil {
httpError(w, http.StatusInternalServerError, "couldn't apply filtering: %s: %s", host, err)
return
}
resp := checkHostResp{}
resp.Reason = result.Reason.String()
resp.FilterID = result.FilterID
resp.Rule = result.Rule
resp.SvcName = result.ServiceName
resp.CanonName = result.CanonName
resp.IPList = result.IPList
js, err := json.Marshal(resp)
if err != nil {
httpError(w, http.StatusInternalServerError, "json encode: %s", err)
return
}
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write(js)
}
// ------------------------
// registration of handlers
// ------------------------
@@ -96,6 +139,7 @@ func registerControlHandlers() {
httpRegister(http.MethodGet, "/control/i18n/current_language", handleI18nCurrentLanguage)
http.HandleFunc("/control/version.json", postInstall(optionalAuth(handleGetVersionJSON)))
httpRegister(http.MethodPost, "/control/update", handleUpdate)
httpRegister("GET", "/control/filtering/check_host", handleCheckHost)
httpRegister("GET", "/control/profile", handleGetProfile)
RegisterAuthHandlers()

View File

@@ -1,391 +0,0 @@
package home
import (
"encoding/json"
"fmt"
"io/ioutil"
"net"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/util"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
)
// isValidURL - return TRUE if URL or file path is valid
func isValidURL(rawurl string) bool {
if filepath.IsAbs(rawurl) {
// this is a file path
return util.FileExists(rawurl)
}
url, err := url.ParseRequestURI(rawurl)
if err != nil {
return false //Couldn't even parse the rawurl
}
if len(url.Scheme) == 0 {
return false //No Scheme found
}
return true
}
type filterAddJSON struct {
Name string `json:"name"`
URL string `json:"url"`
Whitelist bool `json:"whitelist"`
}
func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request) {
fj := filterAddJSON{}
err := json.NewDecoder(r.Body).Decode(&fj)
if err != nil {
httpError(w, http.StatusBadRequest, "Failed to parse request body json: %s", err)
return
}
if !isValidURL(fj.URL) {
http.Error(w, "Invalid URL or file path", http.StatusBadRequest)
return
}
// Check for duplicates
if filterExists(fj.URL) {
httpError(w, http.StatusBadRequest, "Filter URL already added -- %s", fj.URL)
return
}
// Set necessary properties
filt := filter{
Enabled: true,
URL: fj.URL,
Name: fj.Name,
white: fj.Whitelist,
}
filt.ID = assignUniqueFilterID()
// Download the filter contents
ok, err := f.update(&filt)
if err != nil {
httpError(w, http.StatusBadRequest, "Couldn't fetch filter from url %s: %s", filt.URL, err)
return
}
if !ok {
httpError(w, http.StatusBadRequest, "Filter at the url %s is invalid (maybe it points to blank page?)", filt.URL)
return
}
// URL is deemed valid, append it to filters, update config, write new filter file and tell dns to reload it
if !filterAdd(filt) {
httpError(w, http.StatusBadRequest, "Filter URL already added -- %s", filt.URL)
return
}
onConfigModified()
enableFilters(true)
_, err = fmt.Fprintf(w, "OK %d rules\n", filt.RulesCount)
if err != nil {
httpError(w, http.StatusInternalServerError, "Couldn't write body: %s", err)
}
}
func (f *Filtering) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Request) {
type request struct {
URL string `json:"url"`
Whitelist bool `json:"whitelist"`
}
req := request{}
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
httpError(w, http.StatusBadRequest, "Failed to parse request body json: %s", err)
return
}
// go through each element and delete if url matches
config.Lock()
newFilters := []filter{}
filters := &config.Filters
if req.Whitelist {
filters = &config.WhitelistFilters
}
for _, filter := range *filters {
if filter.URL != req.URL {
newFilters = append(newFilters, filter)
} else {
err := os.Rename(filter.Path(), filter.Path()+".old")
if err != nil {
log.Error("os.Rename: %s: %s", filter.Path(), err)
}
}
}
// Update the configuration after removing filter files
*filters = newFilters
config.Unlock()
onConfigModified()
enableFilters(true)
// Note: the old files "filter.txt.old" aren't deleted - it's not really necessary,
// but will require the additional code to run after enableFilters() is finished: i.e. complicated
}
type filterURLJSON struct {
Name string `json:"name"`
URL string `json:"url"`
Enabled bool `json:"enabled"`
}
type filterURLReq struct {
URL string `json:"url"`
Whitelist bool `json:"whitelist"`
Data filterURLJSON `json:"data"`
}
func (f *Filtering) handleFilteringSetURL(w http.ResponseWriter, r *http.Request) {
fj := filterURLReq{}
err := json.NewDecoder(r.Body).Decode(&fj)
if err != nil {
httpError(w, http.StatusBadRequest, "json decode: %s", err)
return
}
if !isValidURL(fj.Data.URL) {
http.Error(w, "invalid URL or file path", http.StatusBadRequest)
return
}
filt := filter{
Enabled: fj.Data.Enabled,
Name: fj.Data.Name,
URL: fj.Data.URL,
}
status := f.filterSetProperties(fj.URL, filt, fj.Whitelist)
if (status & statusFound) == 0 {
http.Error(w, "URL doesn't exist", http.StatusBadRequest)
return
}
if (status & statusURLExists) != 0 {
http.Error(w, "URL already exists", http.StatusBadRequest)
return
}
onConfigModified()
restart := false
if (status & statusEnabledChanged) != 0 {
// we must add or remove filter rules
restart = true
}
if (status&statusUpdateRequired) != 0 && fj.Data.Enabled {
// download new filter and apply its rules
flags := FilterRefreshBlocklists
if fj.Whitelist {
flags = FilterRefreshAllowlists
}
nUpdated, _ := f.refreshFilters(flags, true)
// if at least 1 filter has been updated, refreshFilters() restarts the filtering automatically
// if not - we restart the filtering ourselves
restart = false
if nUpdated == 0 {
restart = true
}
}
if restart {
enableFilters(true)
}
}
func (f *Filtering) handleFilteringSetRules(w http.ResponseWriter, r *http.Request) {
body, err := ioutil.ReadAll(r.Body)
if err != nil {
httpError(w, http.StatusBadRequest, "Failed to read request body: %s", err)
return
}
config.UserRules = strings.Split(string(body), "\n")
onConfigModified()
enableFilters(true)
}
func (f *Filtering) handleFilteringRefresh(w http.ResponseWriter, r *http.Request) {
type Req struct {
White bool `json:"whitelist"`
}
type Resp struct {
Updated int `json:"updated"`
}
resp := Resp{}
var err error
req := Req{}
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
httpError(w, http.StatusBadRequest, "json decode: %s", err)
return
}
Context.controlLock.Unlock()
flags := FilterRefreshBlocklists
if req.White {
flags = FilterRefreshAllowlists
}
resp.Updated, err = f.refreshFilters(flags|FilterRefreshForce, false)
Context.controlLock.Lock()
if err != nil {
httpError(w, http.StatusInternalServerError, "%s", err)
return
}
js, err := json.Marshal(resp)
if err != nil {
httpError(w, http.StatusInternalServerError, "json encode: %s", err)
return
}
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write(js)
}
type filterJSON struct {
ID int64 `json:"id"`
Enabled bool `json:"enabled"`
URL string `json:"url"`
Name string `json:"name"`
RulesCount uint32 `json:"rules_count"`
LastUpdated string `json:"last_updated"`
}
type filteringConfig struct {
Enabled bool `json:"enabled"`
Interval uint32 `json:"interval"` // in hours
Filters []filterJSON `json:"filters"`
WhitelistFilters []filterJSON `json:"whitelist_filters"`
UserRules []string `json:"user_rules"`
}
func filterToJSON(f filter) filterJSON {
fj := filterJSON{
ID: f.ID,
Enabled: f.Enabled,
URL: f.URL,
Name: f.Name,
RulesCount: uint32(f.RulesCount),
}
if !f.LastUpdated.IsZero() {
fj.LastUpdated = f.LastUpdated.Format(time.RFC3339)
}
return fj
}
// Get filtering configuration
func (f *Filtering) handleFilteringStatus(w http.ResponseWriter, r *http.Request) {
resp := filteringConfig{}
config.RLock()
resp.Enabled = config.DNS.FilteringEnabled
resp.Interval = config.DNS.FiltersUpdateIntervalHours
for _, f := range config.Filters {
fj := filterToJSON(f)
resp.Filters = append(resp.Filters, fj)
}
for _, f := range config.WhitelistFilters {
fj := filterToJSON(f)
resp.WhitelistFilters = append(resp.WhitelistFilters, fj)
}
resp.UserRules = config.UserRules
config.RUnlock()
jsonVal, err := json.Marshal(resp)
if err != nil {
httpError(w, http.StatusInternalServerError, "json encode: %s", err)
return
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal)
if err != nil {
httpError(w, http.StatusInternalServerError, "http write: %s", err)
}
}
// Set filtering configuration
func (f *Filtering) handleFilteringConfig(w http.ResponseWriter, r *http.Request) {
req := filteringConfig{}
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
httpError(w, http.StatusBadRequest, "json decode: %s", err)
return
}
if !checkFiltersUpdateIntervalHours(req.Interval) {
httpError(w, http.StatusBadRequest, "Unsupported interval")
return
}
config.DNS.FilteringEnabled = req.Enabled
config.DNS.FiltersUpdateIntervalHours = req.Interval
onConfigModified()
enableFilters(true)
}
type checkHostResp struct {
Reason string `json:"reason"`
FilterID int64 `json:"filter_id"`
Rule string `json:"rule"`
// for FilteredBlockedService:
SvcName string `json:"service_name"`
// for ReasonRewrite:
CanonName string `json:"cname"` // CNAME value
IPList []net.IP `json:"ip_addrs"` // list of IP addresses
}
func (f *Filtering) handleCheckHost(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query()
host := q.Get("name")
setts := Context.dnsFilter.GetConfig()
setts.FilteringEnabled = true
Context.dnsFilter.ApplyBlockedServices(&setts, nil, true)
result, err := Context.dnsFilter.CheckHost(host, dns.TypeA, &setts)
if err != nil {
httpError(w, http.StatusInternalServerError, "couldn't apply filtering: %s: %s", host, err)
return
}
resp := checkHostResp{}
resp.Reason = result.Reason.String()
resp.FilterID = result.FilterID
resp.Rule = result.Rule
resp.SvcName = result.ServiceName
resp.CanonName = result.CanonName
resp.IPList = result.IPList
js, err := json.Marshal(resp)
if err != nil {
httpError(w, http.StatusInternalServerError, "json encode: %s", err)
return
}
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write(js)
}
// RegisterFilteringHandlers - register handlers
func (f *Filtering) RegisterFilteringHandlers() {
httpRegister("GET", "/control/filtering/status", f.handleFilteringStatus)
httpRegister("POST", "/control/filtering/config", f.handleFilteringConfig)
httpRegister("POST", "/control/filtering/add_url", f.handleFilteringAddURL)
httpRegister("POST", "/control/filtering/remove_url", f.handleFilteringRemoveURL)
httpRegister("POST", "/control/filtering/set_url", f.handleFilteringSetURL)
httpRegister("POST", "/control/filtering/refresh", f.handleFilteringRefresh)
httpRegister("POST", "/control/filtering/set_rules", f.handleFilteringSetRules)
httpRegister("GET", "/control/filtering/check_host", f.handleCheckHost)
}
func checkFiltersUpdateIntervalHours(i uint32) bool {
return i == 0 || i == 1 || i == 12 || i == 1*24 || i == 3*24 || i == 7*24
}

View File

@@ -4,9 +4,11 @@ import (
"fmt"
"net"
"path/filepath"
"strings"
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
"github.com/AdguardTeam/AdGuardHome/dnsforward"
"github.com/AdguardTeam/AdGuardHome/filters"
"github.com/AdguardTeam/AdGuardHome/querylog"
"github.com/AdguardTeam/AdGuardHome/stats"
"github.com/AdguardTeam/AdGuardHome/util"
@@ -77,8 +79,6 @@ func initDNSServer() error {
Context.rdns = InitRDNS(Context.dnsServer, &Context.clients)
Context.whois = initWhois(&Context.clients)
Context.filters.Init()
return nil
}
@@ -277,6 +277,8 @@ func startDNSServer() error {
Context.dnsFilter.Start()
Context.filters.Start()
Context.filters.GetList(filters.DNSBlocklist).SetObserver(onFiltersChanged)
Context.filters.GetList(filters.DNSAllowlist).SetObserver(onFiltersChanged)
Context.stats.Start()
Context.queryLog.Start()
@@ -345,3 +347,58 @@ func closeDNSServer() {
log.Debug("Closed all DNS modules")
}
func onFiltersChanged(flags uint) {
switch flags {
case filters.EventBeforeUpdate:
//
case filters.EventAfterUpdate:
enableFilters(true)
}
}
// Activate new DNS filters
// async: do it asynchronously (the function returns immediately)
func enableFilters(async bool) {
var blockFilters []dnsfilter.Filter
var allowFilters []dnsfilter.Filter
if config.DNS.FilteringEnabled {
// convert array of filters
// add user filter
userFilter := dnsfilter.Filter{
ID: 0,
Data: []byte(strings.Join(config.UserRules, "\n")),
}
blockFilters = append(blockFilters, userFilter)
// add blocklist filters
list := Context.filters.GetList(filters.DNSBlocklist).List(0)
for _, f := range list {
if !f.Enabled || f.RuleCount == 0 {
continue
}
f := dnsfilter.Filter{
ID: int64(f.ID),
FilePath: f.Path,
}
blockFilters = append(blockFilters, f)
}
// add allowlist filters
list = Context.filters.GetList(filters.DNSAllowlist).List(0)
for _, f := range list {
if !f.Enabled || f.RuleCount == 0 {
continue
}
f := dnsfilter.Filter{
ID: int64(f.ID),
FilePath: f.Path,
}
allowFilters = append(allowFilters, f)
}
}
_ = Context.dnsFilter.SetFilters(blockFilters, allowFilters, async)
}

View File

@@ -1,65 +0,0 @@
package home
import (
"fmt"
"net"
"net/http"
"os"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func testStartFilterListener() net.Listener {
http.HandleFunc("/filters/1.txt", func(w http.ResponseWriter, r *http.Request) {
content := `||example.org^$third-party
# Inline comment example
||example.com^$third-party
0.0.0.0 example.com
`
_, _ = w.Write([]byte(content))
})
listener, err := net.Listen("tcp", ":0")
if err != nil {
panic(err)
}
go func() { _ = http.Serve(listener, nil) }()
return listener
}
func TestFilters(t *testing.T) {
l := testStartFilterListener()
defer func() { _ = l.Close() }()
dir := prepareTestDir()
defer func() { _ = os.RemoveAll(dir) }()
Context = homeContext{}
Context.workDir = dir
Context.client = &http.Client{
Timeout: 5 * time.Second,
}
Context.filters.Init()
f := filter{
URL: fmt.Sprintf("http://127.0.0.1:%d/filters/1.txt", l.Addr().(*net.TCPAddr).Port),
}
// download
ok, err := Context.filters.update(&f)
assert.Equal(t, nil, err)
assert.True(t, ok)
assert.Equal(t, 3, f.RulesCount)
// refresh
ok, err = Context.filters.update(&f)
assert.True(t, !ok && err == nil)
err = Context.filters.load(&f)
assert.True(t, err == nil)
f.unload()
_ = os.Remove(f.Path())
}

View File

@@ -20,6 +20,7 @@ import (
"gopkg.in/natefinch/lumberjack.v2"
"github.com/AdguardTeam/AdGuardHome/filters"
"github.com/AdguardTeam/AdGuardHome/update"
"github.com/AdguardTeam/AdGuardHome/util"
@@ -30,6 +31,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/dhcpd"
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
"github.com/AdguardTeam/AdGuardHome/dnsforward"
"github.com/AdguardTeam/AdGuardHome/mitmproxy"
"github.com/AdguardTeam/AdGuardHome/querylog"
"github.com/AdguardTeam/AdGuardHome/stats"
"github.com/AdguardTeam/golibs/log"
@@ -62,12 +64,14 @@ type homeContext struct {
dnsFilter *dnsfilter.Dnsfilter // DNS filtering module
dhcpServer *dhcpd.Server // DHCP module
auth *Auth // HTTP authentication module
filters Filtering // DNS filtering module
filters *filters.Filtering // DNS filtering module
web *Web // Web (HTTP, HTTPS) module
tls *TLSMod // TLS module
autoHosts util.AutoHosts // IP-hostname pairs taken from system configuration (e.g. /etc/hosts) files
updater *update.Updater
mitmProxy *mitmproxy.MITMProxy // MITM proxy module
// Runtime properties
// --
@@ -279,6 +283,28 @@ func run(args options) {
log.Fatalf("Cannot create DNS data dir at %s: %s", Context.getDataDir(), err)
}
fconf := filters.ModuleConf{}
fconf.Enabled = config.DNS.FilteringEnabled
fconf.UpdateIntervalHours = config.DNS.FiltersUpdateIntervalHours
fconf.DataDir = Context.getDataDir()
fconf.DNSBlocklist = config.Filters
fconf.DNSAllowlist = config.WhitelistFilters
fconf.UserRules = config.UserRules
fconf.Proxylist = config.ProxyFilters
fconf.HTTPClient = Context.client
fconf.ConfigModified = onConfigModified
fconf.HTTPRegister = httpRegister
Context.filters = filters.NewModule(fconf)
config.MITM.CertDir = Context.getDataDir()
config.MITM.ConfigModified = onConfigModified
config.MITM.HTTPRegister = httpRegister
config.MITM.Filter = Context.filters.GetList(filters.Proxylist)
Context.mitmProxy = mitmproxy.New(config.MITM)
if Context.mitmProxy == nil {
os.Exit(1)
}
sessFilename := filepath.Join(Context.getDataDir(), "sessions.db")
GLMode = args.glinetMode
Context.auth = InitAuth(sessFilename, config.Users, config.WebSessionTTLHours*60*60)
@@ -317,6 +343,13 @@ func run(args options) {
}
}()
if Context.mitmProxy != nil {
err = Context.mitmProxy.Start()
if err != nil {
log.Fatal(err)
}
}
err = startDHCPServer()
if err != nil {
log.Fatal(err)
@@ -501,10 +534,16 @@ func cleanup() {
Context.auth = nil
}
if Context.mitmProxy != nil {
Context.mitmProxy.Close()
Context.mitmProxy = nil
}
err := stopDNSServer()
if err != nil {
log.Error("Couldn't stop DNS server: %s", err)
}
err = stopDHCPServer()
if err != nil {
log.Error("Couldn't stop DHCP server: %s", err)

View File

@@ -170,7 +170,7 @@ func TestHome(t *testing.T) {
assert.True(t, haveIP)
for i := 1; ; i++ {
st, err := os.Stat(filepath.Join(dir, "data", "filters", "1.txt"))
st, err := os.Stat(filepath.Join(dir, "data", "filters_dnsblock", "1.txt"))
if err == nil && st.Size() != 0 {
break
}

View File

@@ -3,6 +3,8 @@ package home
import (
"fmt"
"testing"
"github.com/AdguardTeam/AdGuardHome/filters"
)
func TestUpgrade1to2(t *testing.T) {
@@ -148,13 +150,13 @@ func compareConfigs(t *testing.T, oldConfig, newConfig *map[string]interface{})
if v != (*oldConfig)[k].(bool) {
t.Fatalf("wrong boolean value for %s", k)
}
case []filter:
if len((*oldConfig)[k].([]filter)) != len(value) {
t.Fatalf("wrong filters count. Before update: %d; After update: %d", len((*oldConfig)[k].([]filter)), len(value))
case []filters.Filter:
if len((*oldConfig)[k].([]filters.Filter)) != len(value) {
t.Fatalf("wrong filters count. Before update: %d; After update: %d", len((*oldConfig)[k].([]filters.Filter)), len(value))
}
for i, newFilter := range value {
oldFilter := (*oldConfig)[k].([]filter)[i]
if oldFilter.Enabled != newFilter.Enabled || oldFilter.Name != newFilter.Name || oldFilter.RulesCount != newFilter.RulesCount {
oldFilter := (*oldConfig)[k].([]filters.Filter)[i]
if oldFilter.Enabled != newFilter.Enabled || oldFilter.Name != newFilter.Name || oldFilter.RuleCount != newFilter.RuleCount {
t.Fatalf("old filter %s not equals new filter %s", oldFilter.Name, newFilter.Name)
}
}
@@ -179,16 +181,16 @@ func compareSchemaVersion(t *testing.T, newSchemaVersion interface{}, schemaVers
func createTestDiskConfig(schemaVersion int) (diskConfig map[string]interface{}) {
diskConfig = make(map[string]interface{})
diskConfig["language"] = "en"
diskConfig["filters"] = []filter{
diskConfig["filters"] = []filters.Filter{
{
URL: "https://filters.adtidy.org/android/filters/111_optimized.txt",
Name: "Latvian filter",
RulesCount: 100,
URL: "https://filters.adtidy.org/android/filters/111_optimized.txt",
Name: "Latvian filter",
RuleCount: 100,
},
{
URL: "https://easylist.to/easylistgermany/easylistgermany.txt",
Name: "Germany filter",
RulesCount: 200,
URL: "https://easylist.to/easylistgermany/easylistgermany.txt",
Name: "Germany filter",
RuleCount: 200,
},
}
diskConfig["user_rules"] = []string{}