From f85de514520598885a991f03a5f8f27e4ced165d Mon Sep 17 00:00:00 2001 From: Simon Zolin Date: Tue, 18 Aug 2020 19:23:33 +0300 Subject: [PATCH] MITM proxy --- AGHTechDoc.md | 97 ++++++++-- filters/filter_file.go | 247 ++++++++++++++++++++++++ filters/filter_http.go | 329 ++++++++++++++++++++++++++++++++ filters/filter_module.go | 118 ++++++++++++ filters/filter_storage.go | 246 ++++++++++++++++++++++++ filters/filter_test.go | 154 +++++++++++++++ filters/filter_update.go | 176 +++++++++++++++++ filters/filters.go | 93 +++++++++ go.mod | 2 + go.sum | 8 + home/config.go | 72 ++++++- home/control.go | 44 +++++ home/control_filtering.go | 391 -------------------------------------- home/dns.go | 61 +++++- home/filter_test.go | 65 ------- home/home.go | 41 +++- home/home_test.go | 2 +- home/upgrade_test.go | 26 +-- mitmproxy/mitm_http.go | 99 ++++++++++ mitmproxy/mitm_test.go | 57 ++++++ mitmproxy/mitmproxy.go | 279 +++++++++++++++++++++++++++ 21 files changed, 2116 insertions(+), 491 deletions(-) create mode 100644 filters/filter_file.go create mode 100644 filters/filter_http.go create mode 100644 filters/filter_module.go create mode 100644 filters/filter_storage.go create mode 100644 filters/filter_test.go create mode 100644 filters/filter_update.go create mode 100644 filters/filters.go delete mode 100644 home/control_filtering.go delete mode 100644 home/filter_test.go create mode 100644 mitmproxy/mitm_http.go create mode 100644 mitmproxy/mitm_test.go create mode 100644 mitmproxy/mitmproxy.go diff --git a/AGHTechDoc.md b/AGHTechDoc.md index e50982ba..77aed5a4 100644 --- a/AGHTechDoc.md +++ b/AGHTechDoc.md @@ -52,15 +52,21 @@ Contents: * API: Get query log * API: Set querylog parameters * API: Get querylog parameters -* Filtering +* DNS Filtering * Filters update mechanism * API: Get filtering parameters * API: Set filtering parameters * API: Refresh filters * API: Add Filter - * API: Set URL parameters - * API: Delete URL + * API: Set Filter parameters + * API: Delete Filter * API: Domain Check +* HTTP Proxy + * API: Get Proxy settings + * API: Set Proxy settings + * API: Get Proxy filtering parameters + * API: Add Proxy Filter + * API: Delete Proxy Filter * Log-in page * API: Log in * API: Log out @@ -1477,7 +1483,7 @@ Response: } -## Filtering +## DNS Filtering ![](doc/agh-filtering.png) @@ -1548,7 +1554,19 @@ Response: } ... ], - "user_rules":["...", ...] + "user_rules":["...", ...], + + "proxy_filtering_enabled": true | false + "proxy_filters":[ + { + "enabled":true, + "url":"https://...", + "name":"...", + "rules_count":1234, + "last_updated":"2019-09-04T18:29:30+00:00", + } + ... + ], } For both arrays `filters` and `whitelist_filters` there are unique values: id, url. @@ -1563,6 +1581,7 @@ Request: { "enabled": true | false + "proxy_filtering_enabled": true | false "interval": 0 | 1 | 12 | 1*24 || 3*24 || 7*24 } @@ -1578,7 +1597,7 @@ Request: POST /control/filtering/refresh { - "whitelist": true + "type": blocklist | whitelist | proxylist } Response: @@ -1599,7 +1618,7 @@ Request: { "name": "..." "url": "..." // URL or an absolute file path - "whitelist": true + "type": blocklist | whitelist | proxylist } Response: @@ -1607,7 +1626,7 @@ Response: 200 OK -### API: Set URL parameters +### API: Set Filter parameters Request: @@ -1615,11 +1634,11 @@ Request: { "url": "..." - "whitelist": true + "type": blocklist | whitelist | proxylist "data": { "name": "..." "url": "..." - "enabled": true | false + "enabled": true } } @@ -1628,7 +1647,7 @@ Response: 200 OK -### API: Delete URL +### API: Delete Filter Request: @@ -1636,7 +1655,7 @@ Request: { "url": "..." - "whitelist": true + "type": blocklist | whitelist | proxylist } Response: @@ -1668,6 +1687,60 @@ Response: } +## HTTP Proxy + + Browser <-(HTTP)-> AGH Proxy <-(HTTP)-> Internet Server + +HTTPS MITM: + + . Browser --(CONNECT...)-> AGH Proxy --(handshake)-> Internet Server + . Browser <-(handshake,cert/AGH)-- AGH Proxy <-(cert/issuer)-- Internet Server + . Browser <-(TLS/session2)-> AGH Proxy <-(TLS/session1)-> Internet Server + + +### API: Get Proxy settings + +Request: + + GET /control/proxy_info + +Response: + + 200 OK + + { + "enabled": true|false, + "listen_address": "ip", + "listen_port": 12345, + + "auth_username": "", + "auth_password": "" + } + + +### API: Set Proxy settings + +Request: + + POST /control/proxy_config + + { + "enabled": true|false, + "listen_address": "ip", + "listen_port": 12345, + + "auth_username": "", + "auth_password": "", + + "cert_data":"...", // user-specified certificate. "": generate new + "pkey_data":"...", + } + +Response: + + 200 OK + + ## Log-in page After user completes the steps of installation wizard, he must log in into dashboard using his name and password. After user successfully logs in, he gets the Cookie which allows the server to authenticate him next time without password. After the Cookie is expired, user needs to perform log-in operation again. diff --git a/filters/filter_file.go b/filters/filter_file.go new file mode 100644 index 00000000..c1377ae4 --- /dev/null +++ b/filters/filter_file.go @@ -0,0 +1,247 @@ +package filters + +import ( + "bufio" + "fmt" + "io" + "io/ioutil" + "net/http" + "os" + "path/filepath" + "strings" + "time" + + "github.com/AdguardTeam/AdGuardHome/util" + "github.com/AdguardTeam/golibs/log" +) + +// Allows printable UTF-8 text with CR, LF, TAB characters +func isPrintableText(data []byte) bool { + for _, c := range data { + if (c >= ' ' && c != 0x7f) || c == '\n' || c == '\r' || c == '\t' { + continue + } + return false + } + return true +} + +// Download filter data +// Return nil on success. Set f.Path to a file path, or "" if the file was not modified +func (fs *filterStg) downloadFilter(f *Filter) error { + log.Debug("Filters: Downloading filter from %s", f.URL) + + // create temp file + tmpFile, err := ioutil.TempFile(filepath.Join(fs.conf.FilterDir), "") + if err != nil { + return err + } + defer func() { + if tmpFile != nil { + _ = tmpFile.Close() + _ = os.Remove(tmpFile.Name()) + } + }() + + // create data reader object + var reader io.Reader + if filepath.IsAbs(f.URL) { + f, err := os.Open(f.URL) + if err != nil { + return fmt.Errorf("open file: %s", err) + } + defer f.Close() + reader = f + } else { + req, err := http.NewRequest("GET", f.URL, nil) + if err != nil { + return err + } + + if len(f.LastModified) != 0 { + req.Header.Add("If-Modified-Since", f.LastModified) + } + + resp, err := fs.conf.HTTPClient.Do(req) + if resp != nil && resp.Body != nil { + defer resp.Body.Close() + } + if err != nil { + f.networkError = true + return err + } + + if resp.StatusCode == 304 { // "NOT_MODIFIED" + log.Debug("Filters: filter %s isn't modified since %s", + f.URL, f.LastModified) + f.LastUpdated = time.Now() + f.Path = "" + return nil + + } else if resp.StatusCode != 200 { + err := fmt.Errorf("Filters: Couldn't download filter from %s: status code: %d", + f.URL, resp.StatusCode) + return err + } + + f.LastModified = resp.Header.Get("Last-Modified") + + reader = resp.Body + } + + // parse and validate data, write to a file + err = writeFile(f, reader, tmpFile) + if err != nil { + return err + } + + // Closing the file before renaming it is necessary on Windows + _ = tmpFile.Close() + fname := fs.filePath(*f) + err = os.Rename(tmpFile.Name(), fname) + if err != nil { + return err + } + tmpFile = nil // prevent from deleting this file in "defer" handler + + log.Debug("Filters: saved filter %s at %s", f.URL, fname) + f.Path = fname + f.LastUpdated = time.Now() + return nil +} + +func gatherUntil(dst []byte, dstLen int, src []byte, until int) int { + num := util.MinInt(len(src), until-dstLen) + return copy(dst[dstLen:], src[:num]) +} + +func isHTML(buf []byte) bool { + s := strings.ToLower(string(buf)) + return strings.Contains(s, " maxPeriod { + period = maxPeriod + } + } +} + +// Begin update procedure by signal +func (fs *filterStg) updateBySignal() { + for { + select { + case ok := <-fs.updateChan: + if !ok { + return + } + fs.updateAll() + } + } +} + +// Update filters +// Algorithm: +// . Get next filter to update: +// . Download data from Internet and store on disk (in a new file) +// . Add new filter to the special list +// . Repeat for next filter +// (All filters are downloaded) +// . Stop modules that use filters +// . For each updated filter: +// . Rename "new file name" -> "old file name" +// . Update meta data +// . Restart modules that use filters +func (fs *filterStg) updateAll() { + log.Debug("Filters: updating...") + + for { + var uf Filter + fs.confLock.Lock() + f := fs.getNextToUpdate() + if f != nil { + uf = *f + } + fs.confLock.Unlock() + + if f == nil { + fs.applyUpdate() + return + } + + uf.ID = fs.nextFilterID() + err := fs.downloadFilter(&uf) + if err != nil { + if uf.networkError { + fs.confLock.Lock() + f.nextUpdate = time.Now().Add(10 * time.Second) + fs.confLock.Unlock() + } + continue + } + + // add new filter to the list + fs.updated = append(fs.updated, uf) + } +} + +// Get next filter to update +func (fs *filterStg) getNextToUpdate() *Filter { + now := time.Now() + + for i := range fs.conf.List { + f := &fs.conf.List[i] + + if f.Enabled && + f.nextUpdate.Unix() <= now.Unix() { + + f.nextUpdate = now.Add(time.Duration(fs.conf.UpdateIntervalHours) * time.Hour) + return f + } + } + + return nil +} + +// Replace filter files +func (fs *filterStg) applyUpdate() { + if len(fs.updated) == 0 { + log.Debug("Filters: no filters were updated") + return + } + + fs.NotifyObserver(EventBeforeUpdate) + + nUpdated := 0 + + fs.confLock.Lock() + for _, uf := range fs.updated { + found := false + + for i := range fs.conf.List { + f := &fs.conf.List[i] + + if uf.URL == f.URL { + found = true + fpath := fs.filePath(*f) + f.LastUpdated = uf.LastUpdated + + if len(uf.Path) == 0 { + // the data hasn't changed - just update file mod time + err := os.Chtimes(fpath, f.LastUpdated, f.LastUpdated) + if err != nil { + log.Error("Filters: os.Chtimes: %s", err) + } + continue + } + + err := os.Rename(uf.Path, fpath) + if err != nil { + log.Error("Filters: os.Rename:%s", err) + } + + f.RuleCount = uf.RuleCount + nUpdated++ + break + } + } + + if !found { + // the updated filter was downloaded, + // but it's already removed from the main list + _ = os.Remove(fs.filePath(uf)) + } + } + fs.confLock.Unlock() + + log.Debug("Filters: %d filters were updated", nUpdated) + + fs.updated = nil + fs.NotifyObserver(EventAfterUpdate) +} diff --git a/filters/filters.go b/filters/filters.go new file mode 100644 index 00000000..370313af --- /dev/null +++ b/filters/filters.go @@ -0,0 +1,93 @@ +package filters + +import ( + "net/http" + "time" +) + +// Filters - main interface +type Filters interface { + // Start - start module + Start() + + // Close - close the module + Close() + + // WriteDiskConfig - write configuration on disk + WriteDiskConfig(c *Conf) + + // SetConfig - set new configuration settings + // Currently only UpdateIntervalHours is supported + SetConfig(c Conf) + + // SetObserver - set user handler for notifications + SetObserver(handler EventHandler) + + // NotifyObserver - notify users about the event + NotifyObserver(flags uint) + + // List (thread safe) + List(flags uint) []Filter + + // Add - add filter (thread safe) + Add(nf Filter) error + + // Delete - remove filter (thread safe) + Delete(url string) *Filter + + // Modify - set filter properties (thread safe) + // Return Status* bitarray, old filter properties and error + Modify(url string, enabled bool, name string, newURL string) (int, Filter, error) + + // Refresh - begin filters update procedure + Refresh(flags uint) +} + +// Filter - filter object +type Filter struct { + ID uint64 `yaml:"id"` + Enabled bool `yaml:"enabled"` + Name string `yaml:"name"` + URL string `yaml:"url"` + LastModified string `yaml:"last_modified"` // value of Last-Modified HTTP header field + + Path string `yaml:"-"` + + // number of rules + // 0 means the file isn't loaded - user shouldn't use this filter + RuleCount uint64 `yaml:"-"` + + LastUpdated time.Time `yaml:"-"` // time of the last update (= file modification time) + nextUpdate time.Time // time of the next update + networkError bool // network error during download +} + +const ( + // EventBeforeUpdate - this event is signalled before the update procedure renames/removes old filter files + EventBeforeUpdate = iota + // EventAfterUpdate - this event is signalled after the update procedure is finished + EventAfterUpdate +) + +// EventHandler - event handler function +type EventHandler func(flags uint) + +const ( + // StatusChangedEnabled - changed 'Enabled' + StatusChangedEnabled = 2 + // StatusChangedURL - changed 'URL' + StatusChangedURL = 4 +) + +// Conf - configuration +type Conf struct { + FilterDir string + UpdateIntervalHours uint32 // 0: disabled + HTTPClient *http.Client + List []Filter +} + +// New - create object +func New(conf Conf) Filters { + return newFiltersObj(conf) +} diff --git a/go.mod b/go.mod index a9f5a605..33628418 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.14 require ( github.com/AdguardTeam/dnsproxy v0.30.1 github.com/AdguardTeam/golibs v0.4.2 + github.com/AdguardTeam/gomitmproxy v0.2.0 github.com/AdguardTeam/urlfilter v0.11.2 github.com/NYTimes/gziphandler v1.1.1 github.com/fsnotify/fsnotify v1.4.7 @@ -17,6 +18,7 @@ require ( github.com/sparrc/go-ping v0.0.0-20190613174326-4e5b6552494c github.com/stretchr/testify v1.5.1 go.etcd.io/bbolt v1.3.4 + go.uber.org/atomic v1.6.0 golang.org/x/crypto v0.0.0-20200403201458-baeed622b8d8 golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e golang.org/x/sys v0.0.0-20200331124033-c3d80250170d diff --git a/go.sum b/go.sum index 4271f61e..519126ee 100644 --- a/go.sum +++ b/go.sum @@ -3,6 +3,7 @@ github.com/AdguardTeam/dnsproxy v0.30.1/go.mod h1:hOYFV9TW+pd5XKYz7KZf2FFD8SvSPq github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= github.com/AdguardTeam/golibs v0.4.2 h1:7M28oTZFoFwNmp8eGPb3ImmYbxGaJLyQXeIFVHjME0o= github.com/AdguardTeam/golibs v0.4.2/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= +github.com/AdguardTeam/gomitmproxy v0.2.0 h1:rvCOf17pd1/CnMyMQW891zrEiIQBpQ8cIGjKN9pinUU= github.com/AdguardTeam/gomitmproxy v0.2.0/go.mod h1:Qdv0Mktnzer5zpdpi5rAwixNJzW2FN91LjKJCkVbYGU= github.com/AdguardTeam/urlfilter v0.11.2 h1:gCrWGh63Yqw3z4yi9pgikfsbshIEyvAu/KYV3MvTBlc= github.com/AdguardTeam/urlfilter v0.11.2/go.mod h1:aMuejlNxpWppOVjiEV87X6z0eMf7wsXHTAIWQuylfZY= @@ -113,6 +114,8 @@ github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljT github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= go.etcd.io/bbolt v1.3.4 h1:hi1bXHMVrlQh6WwxAy+qZCV/SYIlqo+Ushwdpa4tAKg= go.etcd.io/bbolt v1.3.4/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ= +go.uber.org/atomic v1.6.0 h1:Ezj3JGmsOnG1MoRWQkPBsKLe9DwWD9QeXzTRzzldNVk= +go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190621222207-cc06ce4a13d4/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= @@ -120,6 +123,8 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200403201458-baeed622b8d8 h1:fpnn/HnJONpIu6hkXi1u/7rR0NzilgWr4T0JmWkEitk= golang.org/x/crypto v0.0.0-20200403201458-baeed622b8d8/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/lint v0.0.0-20190930215403-16217165b5de h1:5hukYrvBGR8/eNkX5mdUezrA6JiaEZDtJb9Ei+1LlBs= +golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= @@ -145,9 +150,12 @@ golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190624180213-70d37148ca0c/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= +golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191216052735-49a3e744a425 h1:VvQyQJN0tSuecqgcIxMWnnfG5kSmgy9KZR9sW3W5QeA= golang.org/x/tools v0.0.0-20191216052735-49a3e744a425/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/home/config.go b/home/config.go index 42e08cc2..553ee151 100644 --- a/home/config.go +++ b/home/config.go @@ -9,6 +9,8 @@ import ( "github.com/AdguardTeam/AdGuardHome/dhcpd" "github.com/AdguardTeam/AdGuardHome/dnsfilter" "github.com/AdguardTeam/AdGuardHome/dnsforward" + "github.com/AdguardTeam/AdGuardHome/filters" + "github.com/AdguardTeam/AdGuardHome/mitmproxy" "github.com/AdguardTeam/AdGuardHome/querylog" "github.com/AdguardTeam/AdGuardHome/stats" "github.com/AdguardTeam/golibs/file" @@ -17,8 +19,7 @@ import ( ) const ( - dataDir = "data" // data storage - filterDir = "filters" // cache location for downloaded filters, it's under DataDir + dataDir = "data" // data storage ) // logSettings @@ -54,9 +55,13 @@ type configuration struct { DNS dnsConfig `yaml:"dns"` TLS tlsConfigSettings `yaml:"tls"` - Filters []filter `yaml:"filters"` - WhitelistFilters []filter `yaml:"whitelist_filters"` - UserRules []string `yaml:"user_rules"` + MITM mitmproxy.Config `yaml:"mitmproxy"` + + Filters []filters.Filter `yaml:"filters"` + WhitelistFilters []filters.Filter `yaml:"whitelist_filters"` + UserRules []string `yaml:"user_rules"` + + ProxyFilters []filters.Filter `yaml:"proxy_filters"` DHCP dhcpd.ServerConfig `yaml:"dhcp"` @@ -155,7 +160,43 @@ func initConfig() { config.DNS.DnsfilterConf.SafeSearchCacheSize = 1 * 1024 * 1024 config.DNS.DnsfilterConf.ParentalCacheSize = 1 * 1024 * 1024 config.DNS.DnsfilterConf.CacheTime = 30 - config.Filters = defaultFilters() + config.Filters = defaultDNSBlocklistFilters() + + config.ProxyFilters = defaultContentFilters() +} + +func defaultDNSBlocklistFilters() []filters.Filter { + return []filters.Filter{ + { + ID: 1, + Enabled: true, + URL: "https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt", + Name: "AdGuard Simplified Domain Names filter", + }, + { + ID: 2, + Enabled: false, + URL: "https://adaway.org/hosts.txt", + Name: "AdAway", + }, + { + ID: 3, + Enabled: false, + URL: "https://www.malwaredomainlist.com/hostslist/hosts.txt", + Name: "MalwareDomainList.com Hosts List", + }, + } +} + +func defaultContentFilters() []filters.Filter { + return []filters.Filter{ + { + ID: 1, + Enabled: true, + URL: "https://filters.adtidy.org/extension/chromium/filters/2.txt", + Name: "AdGuard Base filter", + }, + } } // getConfigFilename returns path to the current config file @@ -203,7 +244,7 @@ func parseConfig() error { return err } - if !checkFiltersUpdateIntervalHours(config.DNS.FiltersUpdateIntervalHours) { + if !filters.CheckFiltersUpdateIntervalHours(config.DNS.FiltersUpdateIntervalHours) { config.DNS.FiltersUpdateIntervalHours = 24 } @@ -263,6 +304,17 @@ func (c *configuration) write() error { config.DNS.DnsfilterConf = c } + if Context.filters != nil { + fconf := filters.ModuleConf{} + Context.filters.WriteDiskConfig(&fconf) + config.DNS.FilteringEnabled = fconf.Enabled + config.DNS.FiltersUpdateIntervalHours = fconf.UpdateIntervalHours + config.Filters = fconf.DNSBlocklist + config.WhitelistFilters = fconf.DNSAllowlist + config.ProxyFilters = fconf.Proxylist + config.UserRules = fconf.UserRules + } + if Context.dnsServer != nil { c := dnsforward.FilteringConfig{} Context.dnsServer.WriteDiskConfig(&c) @@ -275,6 +327,12 @@ func (c *configuration) write() error { config.DHCP = c } + if Context.mitmProxy != nil { + c := mitmproxy.Config{} + Context.mitmProxy.WriteDiskConfig(&c) + config.MITM = c + } + configFile := config.getConfigFilename() log.Debug("Writing YAML file: %s", configFile) yamlText, err := yaml.Marshal(&config) diff --git a/home/control.go b/home/control.go index 744fc327..61fab498 100644 --- a/home/control.go +++ b/home/control.go @@ -12,6 +12,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/dnsforward" "github.com/AdguardTeam/golibs/log" "github.com/NYTimes/gziphandler" + "github.com/miekg/dns" ) // ---------------- @@ -87,6 +88,48 @@ func handleGetProfile(w http.ResponseWriter, r *http.Request) { _, _ = w.Write(data) } +type checkHostResp struct { + Reason string `json:"reason"` + FilterID int64 `json:"filter_id"` + Rule string `json:"rule"` + + // for FilteredBlockedService: + SvcName string `json:"service_name"` + + // for ReasonRewrite: + CanonName string `json:"cname"` // CNAME value + IPList []net.IP `json:"ip_addrs"` // list of IP addresses +} + +func handleCheckHost(w http.ResponseWriter, r *http.Request) { + q := r.URL.Query() + host := q.Get("name") + + setts := Context.dnsFilter.GetConfig() + setts.FilteringEnabled = true + Context.dnsFilter.ApplyBlockedServices(&setts, nil, true) + result, err := Context.dnsFilter.CheckHost(host, dns.TypeA, &setts) + if err != nil { + httpError(w, http.StatusInternalServerError, "couldn't apply filtering: %s: %s", host, err) + return + } + + resp := checkHostResp{} + resp.Reason = result.Reason.String() + resp.FilterID = result.FilterID + resp.Rule = result.Rule + resp.SvcName = result.ServiceName + resp.CanonName = result.CanonName + resp.IPList = result.IPList + js, err := json.Marshal(resp) + if err != nil { + httpError(w, http.StatusInternalServerError, "json encode: %s", err) + return + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write(js) +} + // ------------------------ // registration of handlers // ------------------------ @@ -96,6 +139,7 @@ func registerControlHandlers() { httpRegister(http.MethodGet, "/control/i18n/current_language", handleI18nCurrentLanguage) http.HandleFunc("/control/version.json", postInstall(optionalAuth(handleGetVersionJSON))) httpRegister(http.MethodPost, "/control/update", handleUpdate) + httpRegister("GET", "/control/filtering/check_host", handleCheckHost) httpRegister("GET", "/control/profile", handleGetProfile) RegisterAuthHandlers() diff --git a/home/control_filtering.go b/home/control_filtering.go deleted file mode 100644 index e3dd9d46..00000000 --- a/home/control_filtering.go +++ /dev/null @@ -1,391 +0,0 @@ -package home - -import ( - "encoding/json" - "fmt" - "io/ioutil" - "net" - "net/http" - "net/url" - "os" - "path/filepath" - "strings" - "time" - - "github.com/AdguardTeam/AdGuardHome/util" - "github.com/AdguardTeam/golibs/log" - "github.com/miekg/dns" -) - -// isValidURL - return TRUE if URL or file path is valid -func isValidURL(rawurl string) bool { - if filepath.IsAbs(rawurl) { - // this is a file path - return util.FileExists(rawurl) - } - - url, err := url.ParseRequestURI(rawurl) - if err != nil { - return false //Couldn't even parse the rawurl - } - if len(url.Scheme) == 0 { - return false //No Scheme found - } - return true -} - -type filterAddJSON struct { - Name string `json:"name"` - URL string `json:"url"` - Whitelist bool `json:"whitelist"` -} - -func (f *Filtering) handleFilteringAddURL(w http.ResponseWriter, r *http.Request) { - fj := filterAddJSON{} - err := json.NewDecoder(r.Body).Decode(&fj) - if err != nil { - httpError(w, http.StatusBadRequest, "Failed to parse request body json: %s", err) - return - } - - if !isValidURL(fj.URL) { - http.Error(w, "Invalid URL or file path", http.StatusBadRequest) - return - } - - // Check for duplicates - if filterExists(fj.URL) { - httpError(w, http.StatusBadRequest, "Filter URL already added -- %s", fj.URL) - return - } - - // Set necessary properties - filt := filter{ - Enabled: true, - URL: fj.URL, - Name: fj.Name, - white: fj.Whitelist, - } - filt.ID = assignUniqueFilterID() - - // Download the filter contents - ok, err := f.update(&filt) - if err != nil { - httpError(w, http.StatusBadRequest, "Couldn't fetch filter from url %s: %s", filt.URL, err) - return - } - if !ok { - httpError(w, http.StatusBadRequest, "Filter at the url %s is invalid (maybe it points to blank page?)", filt.URL) - return - } - - // URL is deemed valid, append it to filters, update config, write new filter file and tell dns to reload it - if !filterAdd(filt) { - httpError(w, http.StatusBadRequest, "Filter URL already added -- %s", filt.URL) - return - } - - onConfigModified() - enableFilters(true) - - _, err = fmt.Fprintf(w, "OK %d rules\n", filt.RulesCount) - if err != nil { - httpError(w, http.StatusInternalServerError, "Couldn't write body: %s", err) - } -} - -func (f *Filtering) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Request) { - - type request struct { - URL string `json:"url"` - Whitelist bool `json:"whitelist"` - } - req := request{} - err := json.NewDecoder(r.Body).Decode(&req) - if err != nil { - httpError(w, http.StatusBadRequest, "Failed to parse request body json: %s", err) - return - } - - // go through each element and delete if url matches - config.Lock() - newFilters := []filter{} - filters := &config.Filters - if req.Whitelist { - filters = &config.WhitelistFilters - } - for _, filter := range *filters { - if filter.URL != req.URL { - newFilters = append(newFilters, filter) - } else { - err := os.Rename(filter.Path(), filter.Path()+".old") - if err != nil { - log.Error("os.Rename: %s: %s", filter.Path(), err) - } - } - } - // Update the configuration after removing filter files - *filters = newFilters - config.Unlock() - - onConfigModified() - enableFilters(true) - - // Note: the old files "filter.txt.old" aren't deleted - it's not really necessary, - // but will require the additional code to run after enableFilters() is finished: i.e. complicated -} - -type filterURLJSON struct { - Name string `json:"name"` - URL string `json:"url"` - Enabled bool `json:"enabled"` -} - -type filterURLReq struct { - URL string `json:"url"` - Whitelist bool `json:"whitelist"` - Data filterURLJSON `json:"data"` -} - -func (f *Filtering) handleFilteringSetURL(w http.ResponseWriter, r *http.Request) { - fj := filterURLReq{} - err := json.NewDecoder(r.Body).Decode(&fj) - if err != nil { - httpError(w, http.StatusBadRequest, "json decode: %s", err) - return - } - - if !isValidURL(fj.Data.URL) { - http.Error(w, "invalid URL or file path", http.StatusBadRequest) - return - } - - filt := filter{ - Enabled: fj.Data.Enabled, - Name: fj.Data.Name, - URL: fj.Data.URL, - } - status := f.filterSetProperties(fj.URL, filt, fj.Whitelist) - if (status & statusFound) == 0 { - http.Error(w, "URL doesn't exist", http.StatusBadRequest) - return - } - if (status & statusURLExists) != 0 { - http.Error(w, "URL already exists", http.StatusBadRequest) - return - } - - onConfigModified() - restart := false - if (status & statusEnabledChanged) != 0 { - // we must add or remove filter rules - restart = true - } - if (status&statusUpdateRequired) != 0 && fj.Data.Enabled { - // download new filter and apply its rules - flags := FilterRefreshBlocklists - if fj.Whitelist { - flags = FilterRefreshAllowlists - } - nUpdated, _ := f.refreshFilters(flags, true) - // if at least 1 filter has been updated, refreshFilters() restarts the filtering automatically - // if not - we restart the filtering ourselves - restart = false - if nUpdated == 0 { - restart = true - } - } - if restart { - enableFilters(true) - } -} - -func (f *Filtering) handleFilteringSetRules(w http.ResponseWriter, r *http.Request) { - body, err := ioutil.ReadAll(r.Body) - if err != nil { - httpError(w, http.StatusBadRequest, "Failed to read request body: %s", err) - return - } - - config.UserRules = strings.Split(string(body), "\n") - onConfigModified() - enableFilters(true) -} - -func (f *Filtering) handleFilteringRefresh(w http.ResponseWriter, r *http.Request) { - type Req struct { - White bool `json:"whitelist"` - } - type Resp struct { - Updated int `json:"updated"` - } - resp := Resp{} - var err error - - req := Req{} - err = json.NewDecoder(r.Body).Decode(&req) - if err != nil { - httpError(w, http.StatusBadRequest, "json decode: %s", err) - return - } - - Context.controlLock.Unlock() - flags := FilterRefreshBlocklists - if req.White { - flags = FilterRefreshAllowlists - } - resp.Updated, err = f.refreshFilters(flags|FilterRefreshForce, false) - Context.controlLock.Lock() - if err != nil { - httpError(w, http.StatusInternalServerError, "%s", err) - return - } - - js, err := json.Marshal(resp) - if err != nil { - httpError(w, http.StatusInternalServerError, "json encode: %s", err) - return - } - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write(js) -} - -type filterJSON struct { - ID int64 `json:"id"` - Enabled bool `json:"enabled"` - URL string `json:"url"` - Name string `json:"name"` - RulesCount uint32 `json:"rules_count"` - LastUpdated string `json:"last_updated"` -} - -type filteringConfig struct { - Enabled bool `json:"enabled"` - Interval uint32 `json:"interval"` // in hours - Filters []filterJSON `json:"filters"` - WhitelistFilters []filterJSON `json:"whitelist_filters"` - UserRules []string `json:"user_rules"` -} - -func filterToJSON(f filter) filterJSON { - fj := filterJSON{ - ID: f.ID, - Enabled: f.Enabled, - URL: f.URL, - Name: f.Name, - RulesCount: uint32(f.RulesCount), - } - - if !f.LastUpdated.IsZero() { - fj.LastUpdated = f.LastUpdated.Format(time.RFC3339) - } - - return fj -} - -// Get filtering configuration -func (f *Filtering) handleFilteringStatus(w http.ResponseWriter, r *http.Request) { - resp := filteringConfig{} - config.RLock() - resp.Enabled = config.DNS.FilteringEnabled - resp.Interval = config.DNS.FiltersUpdateIntervalHours - for _, f := range config.Filters { - fj := filterToJSON(f) - resp.Filters = append(resp.Filters, fj) - } - for _, f := range config.WhitelistFilters { - fj := filterToJSON(f) - resp.WhitelistFilters = append(resp.WhitelistFilters, fj) - } - resp.UserRules = config.UserRules - config.RUnlock() - - jsonVal, err := json.Marshal(resp) - if err != nil { - httpError(w, http.StatusInternalServerError, "json encode: %s", err) - return - } - w.Header().Set("Content-Type", "application/json") - _, err = w.Write(jsonVal) - if err != nil { - httpError(w, http.StatusInternalServerError, "http write: %s", err) - } -} - -// Set filtering configuration -func (f *Filtering) handleFilteringConfig(w http.ResponseWriter, r *http.Request) { - req := filteringConfig{} - err := json.NewDecoder(r.Body).Decode(&req) - if err != nil { - httpError(w, http.StatusBadRequest, "json decode: %s", err) - return - } - - if !checkFiltersUpdateIntervalHours(req.Interval) { - httpError(w, http.StatusBadRequest, "Unsupported interval") - return - } - - config.DNS.FilteringEnabled = req.Enabled - config.DNS.FiltersUpdateIntervalHours = req.Interval - onConfigModified() - enableFilters(true) -} - -type checkHostResp struct { - Reason string `json:"reason"` - FilterID int64 `json:"filter_id"` - Rule string `json:"rule"` - - // for FilteredBlockedService: - SvcName string `json:"service_name"` - - // for ReasonRewrite: - CanonName string `json:"cname"` // CNAME value - IPList []net.IP `json:"ip_addrs"` // list of IP addresses -} - -func (f *Filtering) handleCheckHost(w http.ResponseWriter, r *http.Request) { - q := r.URL.Query() - host := q.Get("name") - - setts := Context.dnsFilter.GetConfig() - setts.FilteringEnabled = true - Context.dnsFilter.ApplyBlockedServices(&setts, nil, true) - result, err := Context.dnsFilter.CheckHost(host, dns.TypeA, &setts) - if err != nil { - httpError(w, http.StatusInternalServerError, "couldn't apply filtering: %s: %s", host, err) - return - } - - resp := checkHostResp{} - resp.Reason = result.Reason.String() - resp.FilterID = result.FilterID - resp.Rule = result.Rule - resp.SvcName = result.ServiceName - resp.CanonName = result.CanonName - resp.IPList = result.IPList - js, err := json.Marshal(resp) - if err != nil { - httpError(w, http.StatusInternalServerError, "json encode: %s", err) - return - } - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write(js) -} - -// RegisterFilteringHandlers - register handlers -func (f *Filtering) RegisterFilteringHandlers() { - httpRegister("GET", "/control/filtering/status", f.handleFilteringStatus) - httpRegister("POST", "/control/filtering/config", f.handleFilteringConfig) - httpRegister("POST", "/control/filtering/add_url", f.handleFilteringAddURL) - httpRegister("POST", "/control/filtering/remove_url", f.handleFilteringRemoveURL) - httpRegister("POST", "/control/filtering/set_url", f.handleFilteringSetURL) - httpRegister("POST", "/control/filtering/refresh", f.handleFilteringRefresh) - httpRegister("POST", "/control/filtering/set_rules", f.handleFilteringSetRules) - httpRegister("GET", "/control/filtering/check_host", f.handleCheckHost) -} - -func checkFiltersUpdateIntervalHours(i uint32) bool { - return i == 0 || i == 1 || i == 12 || i == 1*24 || i == 3*24 || i == 7*24 -} diff --git a/home/dns.go b/home/dns.go index 1ff0da32..3e5b1b07 100644 --- a/home/dns.go +++ b/home/dns.go @@ -4,9 +4,11 @@ import ( "fmt" "net" "path/filepath" + "strings" "github.com/AdguardTeam/AdGuardHome/dnsfilter" "github.com/AdguardTeam/AdGuardHome/dnsforward" + "github.com/AdguardTeam/AdGuardHome/filters" "github.com/AdguardTeam/AdGuardHome/querylog" "github.com/AdguardTeam/AdGuardHome/stats" "github.com/AdguardTeam/AdGuardHome/util" @@ -77,8 +79,6 @@ func initDNSServer() error { Context.rdns = InitRDNS(Context.dnsServer, &Context.clients) Context.whois = initWhois(&Context.clients) - - Context.filters.Init() return nil } @@ -277,6 +277,8 @@ func startDNSServer() error { Context.dnsFilter.Start() Context.filters.Start() + Context.filters.GetList(filters.DNSBlocklist).SetObserver(onFiltersChanged) + Context.filters.GetList(filters.DNSAllowlist).SetObserver(onFiltersChanged) Context.stats.Start() Context.queryLog.Start() @@ -345,3 +347,58 @@ func closeDNSServer() { log.Debug("Closed all DNS modules") } + +func onFiltersChanged(flags uint) { + switch flags { + case filters.EventBeforeUpdate: + // + + case filters.EventAfterUpdate: + enableFilters(true) + } +} + +// Activate new DNS filters +// async: do it asynchronously (the function returns immediately) +func enableFilters(async bool) { + var blockFilters []dnsfilter.Filter + var allowFilters []dnsfilter.Filter + if config.DNS.FilteringEnabled { + // convert array of filters + + // add user filter + userFilter := dnsfilter.Filter{ + ID: 0, + Data: []byte(strings.Join(config.UserRules, "\n")), + } + blockFilters = append(blockFilters, userFilter) + + // add blocklist filters + list := Context.filters.GetList(filters.DNSBlocklist).List(0) + for _, f := range list { + if !f.Enabled || f.RuleCount == 0 { + continue + } + f := dnsfilter.Filter{ + ID: int64(f.ID), + FilePath: f.Path, + } + blockFilters = append(blockFilters, f) + } + + // add allowlist filters + list = Context.filters.GetList(filters.DNSAllowlist).List(0) + for _, f := range list { + if !f.Enabled || f.RuleCount == 0 { + continue + } + f := dnsfilter.Filter{ + ID: int64(f.ID), + FilePath: f.Path, + } + allowFilters = append(allowFilters, f) + } + } + + _ = Context.dnsFilter.SetFilters(blockFilters, allowFilters, async) +} diff --git a/home/filter_test.go b/home/filter_test.go deleted file mode 100644 index 317741d8..00000000 --- a/home/filter_test.go +++ /dev/null @@ -1,65 +0,0 @@ -package home - -import ( - "fmt" - "net" - "net/http" - "os" - "testing" - "time" - - "github.com/stretchr/testify/assert" -) - -func testStartFilterListener() net.Listener { - http.HandleFunc("/filters/1.txt", func(w http.ResponseWriter, r *http.Request) { - content := `||example.org^$third-party -# Inline comment example -||example.com^$third-party -0.0.0.0 example.com -` - _, _ = w.Write([]byte(content)) - }) - - listener, err := net.Listen("tcp", ":0") - if err != nil { - panic(err) - } - - go func() { _ = http.Serve(listener, nil) }() - return listener -} - -func TestFilters(t *testing.T) { - l := testStartFilterListener() - defer func() { _ = l.Close() }() - - dir := prepareTestDir() - defer func() { _ = os.RemoveAll(dir) }() - Context = homeContext{} - Context.workDir = dir - Context.client = &http.Client{ - Timeout: 5 * time.Second, - } - Context.filters.Init() - - f := filter{ - URL: fmt.Sprintf("http://127.0.0.1:%d/filters/1.txt", l.Addr().(*net.TCPAddr).Port), - } - - // download - ok, err := Context.filters.update(&f) - assert.Equal(t, nil, err) - assert.True(t, ok) - assert.Equal(t, 3, f.RulesCount) - - // refresh - ok, err = Context.filters.update(&f) - assert.True(t, !ok && err == nil) - - err = Context.filters.load(&f) - assert.True(t, err == nil) - - f.unload() - _ = os.Remove(f.Path()) -} diff --git a/home/home.go b/home/home.go index e297173c..afbc0798 100644 --- a/home/home.go +++ b/home/home.go @@ -20,6 +20,7 @@ import ( "gopkg.in/natefinch/lumberjack.v2" + "github.com/AdguardTeam/AdGuardHome/filters" "github.com/AdguardTeam/AdGuardHome/update" "github.com/AdguardTeam/AdGuardHome/util" @@ -30,6 +31,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/dhcpd" "github.com/AdguardTeam/AdGuardHome/dnsfilter" "github.com/AdguardTeam/AdGuardHome/dnsforward" + "github.com/AdguardTeam/AdGuardHome/mitmproxy" "github.com/AdguardTeam/AdGuardHome/querylog" "github.com/AdguardTeam/AdGuardHome/stats" "github.com/AdguardTeam/golibs/log" @@ -62,12 +64,14 @@ type homeContext struct { dnsFilter *dnsfilter.Dnsfilter // DNS filtering module dhcpServer *dhcpd.Server // DHCP module auth *Auth // HTTP authentication module - filters Filtering // DNS filtering module + filters *filters.Filtering // DNS filtering module web *Web // Web (HTTP, HTTPS) module tls *TLSMod // TLS module autoHosts util.AutoHosts // IP-hostname pairs taken from system configuration (e.g. /etc/hosts) files updater *update.Updater + mitmProxy *mitmproxy.MITMProxy // MITM proxy module + // Runtime properties // -- @@ -279,6 +283,28 @@ func run(args options) { log.Fatalf("Cannot create DNS data dir at %s: %s", Context.getDataDir(), err) } + fconf := filters.ModuleConf{} + fconf.Enabled = config.DNS.FilteringEnabled + fconf.UpdateIntervalHours = config.DNS.FiltersUpdateIntervalHours + fconf.DataDir = Context.getDataDir() + fconf.DNSBlocklist = config.Filters + fconf.DNSAllowlist = config.WhitelistFilters + fconf.UserRules = config.UserRules + fconf.Proxylist = config.ProxyFilters + fconf.HTTPClient = Context.client + fconf.ConfigModified = onConfigModified + fconf.HTTPRegister = httpRegister + Context.filters = filters.NewModule(fconf) + + config.MITM.CertDir = Context.getDataDir() + config.MITM.ConfigModified = onConfigModified + config.MITM.HTTPRegister = httpRegister + config.MITM.Filter = Context.filters.GetList(filters.Proxylist) + Context.mitmProxy = mitmproxy.New(config.MITM) + if Context.mitmProxy == nil { + os.Exit(1) + } + sessFilename := filepath.Join(Context.getDataDir(), "sessions.db") GLMode = args.glinetMode Context.auth = InitAuth(sessFilename, config.Users, config.WebSessionTTLHours*60*60) @@ -317,6 +343,13 @@ func run(args options) { } }() + if Context.mitmProxy != nil { + err = Context.mitmProxy.Start() + if err != nil { + log.Fatal(err) + } + } + err = startDHCPServer() if err != nil { log.Fatal(err) @@ -501,10 +534,16 @@ func cleanup() { Context.auth = nil } + if Context.mitmProxy != nil { + Context.mitmProxy.Close() + Context.mitmProxy = nil + } + err := stopDNSServer() if err != nil { log.Error("Couldn't stop DNS server: %s", err) } + err = stopDHCPServer() if err != nil { log.Error("Couldn't stop DHCP server: %s", err) diff --git a/home/home_test.go b/home/home_test.go index 89901387..50d00799 100644 --- a/home/home_test.go +++ b/home/home_test.go @@ -170,7 +170,7 @@ func TestHome(t *testing.T) { assert.True(t, haveIP) for i := 1; ; i++ { - st, err := os.Stat(filepath.Join(dir, "data", "filters", "1.txt")) + st, err := os.Stat(filepath.Join(dir, "data", "filters_dnsblock", "1.txt")) if err == nil && st.Size() != 0 { break } diff --git a/home/upgrade_test.go b/home/upgrade_test.go index f884de04..22a3967b 100644 --- a/home/upgrade_test.go +++ b/home/upgrade_test.go @@ -3,6 +3,8 @@ package home import ( "fmt" "testing" + + "github.com/AdguardTeam/AdGuardHome/filters" ) func TestUpgrade1to2(t *testing.T) { @@ -148,13 +150,13 @@ func compareConfigs(t *testing.T, oldConfig, newConfig *map[string]interface{}) if v != (*oldConfig)[k].(bool) { t.Fatalf("wrong boolean value for %s", k) } - case []filter: - if len((*oldConfig)[k].([]filter)) != len(value) { - t.Fatalf("wrong filters count. Before update: %d; After update: %d", len((*oldConfig)[k].([]filter)), len(value)) + case []filters.Filter: + if len((*oldConfig)[k].([]filters.Filter)) != len(value) { + t.Fatalf("wrong filters count. Before update: %d; After update: %d", len((*oldConfig)[k].([]filters.Filter)), len(value)) } for i, newFilter := range value { - oldFilter := (*oldConfig)[k].([]filter)[i] - if oldFilter.Enabled != newFilter.Enabled || oldFilter.Name != newFilter.Name || oldFilter.RulesCount != newFilter.RulesCount { + oldFilter := (*oldConfig)[k].([]filters.Filter)[i] + if oldFilter.Enabled != newFilter.Enabled || oldFilter.Name != newFilter.Name || oldFilter.RuleCount != newFilter.RuleCount { t.Fatalf("old filter %s not equals new filter %s", oldFilter.Name, newFilter.Name) } } @@ -179,16 +181,16 @@ func compareSchemaVersion(t *testing.T, newSchemaVersion interface{}, schemaVers func createTestDiskConfig(schemaVersion int) (diskConfig map[string]interface{}) { diskConfig = make(map[string]interface{}) diskConfig["language"] = "en" - diskConfig["filters"] = []filter{ + diskConfig["filters"] = []filters.Filter{ { - URL: "https://filters.adtidy.org/android/filters/111_optimized.txt", - Name: "Latvian filter", - RulesCount: 100, + URL: "https://filters.adtidy.org/android/filters/111_optimized.txt", + Name: "Latvian filter", + RuleCount: 100, }, { - URL: "https://easylist.to/easylistgermany/easylistgermany.txt", - Name: "Germany filter", - RulesCount: 200, + URL: "https://easylist.to/easylistgermany/easylistgermany.txt", + Name: "Germany filter", + RuleCount: 200, }, } diskConfig["user_rules"] = []string{} diff --git a/mitmproxy/mitm_http.go b/mitmproxy/mitm_http.go new file mode 100644 index 00000000..ffce88e4 --- /dev/null +++ b/mitmproxy/mitm_http.go @@ -0,0 +1,99 @@ +package mitmproxy + +import ( + "encoding/json" + "fmt" + "net" + "net/http" + "strconv" + + "github.com/AdguardTeam/golibs/jsonutil" + "github.com/AdguardTeam/golibs/log" +) + +// Print to log and set HTTP error message +func httpError(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) { + text := fmt.Sprintf(format, args...) + log.Info("MITM: %s %s: %s", r.Method, r.URL, text) + http.Error(w, text, code) +} + +type mitmConfigJSON struct { + Enabled bool `json:"enabled"` + ListenAddr string `json:"listen_address"` + ListenPort int `json:"listen_port"` + + UserName string `json:"auth_username"` + Password string `json:"auth_password"` + + CertData string `json:"cert_data"` + PKeyData string `json:"pkey_data"` +} + +func (p *MITMProxy) handleGetConfig(w http.ResponseWriter, r *http.Request) { + resp := mitmConfigJSON{} + p.confLock.Lock() + resp.Enabled = p.conf.Enabled + host, port, _ := net.SplitHostPort(p.conf.ListenAddr) + resp.ListenAddr = host + resp.ListenPort, _ = strconv.Atoi(port) + resp.UserName = p.conf.UserName + resp.Password = p.conf.Password + p.confLock.Unlock() + + js, err := json.Marshal(resp) + if err != nil { + httpError(r, w, http.StatusInternalServerError, "json.Marshal: %s", err) + return + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write(js) +} + +func (p *MITMProxy) handleSetConfig(w http.ResponseWriter, r *http.Request) { + req := mitmConfigJSON{} + _, err := jsonutil.DecodeObject(&req, r.Body) + if err != nil { + httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err) + return + } + + if !((len(req.CertData) != 0 && len(req.PKeyData) != 0) || + (len(req.CertData) == 0 && len(req.PKeyData) == 0)) { + httpError(r, w, http.StatusBadRequest, "certificate & private key must be both empty or specified") + return + } + + p.confLock.Lock() + if len(req.CertData) != 0 { + err = p.storeCert([]byte(req.CertData), []byte(req.PKeyData)) + if err != nil { + httpError(r, w, http.StatusInternalServerError, "%s", err) + p.confLock.Unlock() + return + } + p.conf.RegenCert = false + } else { + p.conf.RegenCert = true + } + p.conf.Enabled = req.Enabled + p.conf.ListenAddr = net.JoinHostPort(req.ListenAddr, strconv.Itoa(req.ListenPort)) + p.conf.UserName = req.UserName + p.conf.Password = req.Password + p.confLock.Unlock() + + p.conf.ConfigModified() + + p.Close() + err = p.Restart() + if err != nil { + httpError(r, w, http.StatusInternalServerError, "%s", err) + return + } +} + +// Initialize web handlers +func (p *MITMProxy) initWeb() { + p.conf.HTTPRegister("GET", "/control/proxy_info", p.handleGetConfig) + p.conf.HTTPRegister("POST", "/control/proxy_config", p.handleSetConfig) +} diff --git a/mitmproxy/mitm_test.go b/mitmproxy/mitm_test.go new file mode 100644 index 00000000..86e6926e --- /dev/null +++ b/mitmproxy/mitm_test.go @@ -0,0 +1,57 @@ +package mitmproxy + +import ( + "net/http" + "net/url" + "os" + "testing" + + "github.com/AdguardTeam/AdGuardHome/filters" + "github.com/stretchr/testify/assert" +) + +func prepareTestDir() string { + const dir = "./agh-test" + _ = os.RemoveAll(dir) + _ = os.MkdirAll(dir, 0755) + return dir +} + +func TestMITM(t *testing.T) { + dir := prepareTestDir() + defer func() { _ = os.RemoveAll(dir) }() + + fconf := filters.Conf{} + fconf.FilterDir = dir + fconf.HTTPClient = http.DefaultClient + filters := filters.New(fconf) + + conf := Config{} + conf.Enabled = true + conf.CertDir = dir + conf.RegenCert = true + conf.ListenAddr = "127.0.0.1:8081" + conf.Filter = filters + s := New(conf) + assert.NotNil(t, s) + + err := s.Start() + assert.Nil(t, err) + + proxyURL, _ := url.Parse("http://127.0.0.1:8081") + transport := &http.Transport{ + Proxy: http.ProxyURL(proxyURL), + } + c := http.Client{ + Transport: transport, + } + resp, err := c.Get("http://example.com/") + assert.Nil(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + resp, err = c.Get("http://adguardhome.api/cert.crt") + assert.Nil(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + s.Close() +} diff --git a/mitmproxy/mitmproxy.go b/mitmproxy/mitmproxy.go new file mode 100644 index 00000000..be2ae34d --- /dev/null +++ b/mitmproxy/mitmproxy.go @@ -0,0 +1,279 @@ +package mitmproxy + +import ( + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "encoding/pem" + "fmt" + "io/ioutil" + "net" + "net/http" + "path/filepath" + "strconv" + "sync" + "time" + + "github.com/AdguardTeam/AdGuardHome/filters" + "github.com/AdguardTeam/golibs/file" + "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/gomitmproxy/mitm" + "github.com/AdguardTeam/urlfilter/proxy" +) + +// MITMProxy - MITM proxy structure +type MITMProxy struct { + proxy *proxy.Server + conf Config + confLock sync.Mutex +} + +// Config - module configuration +type Config struct { + Enabled bool `yaml:"enabled"` + ListenAddr string `yaml:"listen_address"` + + UserName string `yaml:"auth_username"` + Password string `yaml:"auth_password"` + + // TLS: + RegenCert bool `yaml:"regenerate_cert"` // Regenerate certificate on cert loading failure + CertDir string `yaml:"-"` // Directory where Root certificate & pkey is stored + certFileName string + pkeyFileName string + certData []byte + pkeyData []byte + + Filter filters.Filters `yaml:"-"` + + // Called when the configuration is changed by HTTP request + ConfigModified func() `yaml:"-"` + + // Register an HTTP handler + HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request)) `yaml:"-"` +} + +// New - create a new instance of the query log +func New(conf Config) *MITMProxy { + p := MITMProxy{} + + p.conf = conf + p.conf.certFileName = filepath.Join(p.conf.CertDir, "/http_proxy.crt") + p.conf.pkeyFileName = filepath.Join(p.conf.CertDir, "/http_proxy.key") + + err := p.create() + if err != nil { + log.Error("MITM: %s", err) + return nil + } + + if p.conf.HTTPRegister != nil { + p.initWeb() + } + + p.conf.Filter.SetObserver(p.onFiltersChanged) + + return &p +} + +// Close - close the object +func (p *MITMProxy) Close() { + if p.proxy != nil { + p.proxy.Close() + p.proxy = nil + log.Debug("MITM: Closed proxy") + } +} + +// WriteDiskConfig - write configuration on disk +func (p *MITMProxy) WriteDiskConfig(c *Config) { + p.confLock.Lock() + *c = p.conf + p.confLock.Unlock() +} + +// Start - start proxy server +func (p *MITMProxy) Start() error { + if !p.conf.Enabled { + return nil + } + + err := p.proxy.Start() + if err != nil { + return err + } + log.Debug("MITM: Running...") + return nil +} + +// Restart - restart proxy server after Close() +func (p *MITMProxy) Restart() error { + err := p.create() + if err != nil { + return err + } + return p.Start() +} + +// Create a gomitmproxy object +func (p *MITMProxy) create() error { + if !p.conf.Enabled { + return nil + } + + c := proxy.Config{} + c.ProxyConfig.APIHost = "adguardhome.api" + addr, port, err := net.SplitHostPort(p.conf.ListenAddr) + if err != nil { + return fmt.Errorf("net.SplitHostPort: %s", err) + } + + c.CompressContentScript = true + c.ProxyConfig.ListenAddr = &net.TCPAddr{} + c.ProxyConfig.ListenAddr.IP = net.ParseIP(addr) + if c.ProxyConfig.ListenAddr.IP == nil { + return fmt.Errorf("invalid IP: %s", addr) + } + c.ProxyConfig.ListenAddr.Port, err = strconv.Atoi(port) + if c.ProxyConfig.ListenAddr.Port < 0 || c.ProxyConfig.ListenAddr.Port > 0xffff || err != nil { + return fmt.Errorf("invalid port number: %s", port) + } + + c.ProxyConfig.Username = p.conf.UserName + c.ProxyConfig.Password = p.conf.Password + + err = p.loadCert() + if err != nil { + if !p.conf.RegenCert { + return err + } + log.Debug("%s", err) + + // certificate or private key file doesn't exist - generate new + err = p.createRootCert() + if err != nil { + return err + } + } + + c.ProxyConfig.MITMConfig, err = p.prepareMITMConfig() + if err != nil { + if !p.conf.RegenCert { + return err + } + + // certificate or private key is invalid - generate new + err = p.createRootCert() + if err != nil { + return err + } + + c.ProxyConfig.MITMConfig, err = p.prepareMITMConfig() + if err != nil { + return err + } + } + + c.FiltersPaths = make(map[int]string) + filtrs := p.conf.Filter.List(0) + i := 0 + for _, f := range filtrs { + if !f.Enabled || + f.RuleCount == 0 { // not loaded + continue + } + + c.FiltersPaths[i] = f.Path + i++ + } + + p.proxy, err = proxy.NewServer(c) + if err != nil { + return fmt.Errorf("proxy.NewServer: %s", err) + } + return nil +} + +// Load cert and pkey from file +func (p *MITMProxy) loadCert() error { + var err error + p.conf.certData, err = ioutil.ReadFile(p.conf.certFileName) + if err != nil { + return err + } + p.conf.pkeyData, err = ioutil.ReadFile(p.conf.pkeyFileName) + if err != nil { + return err + } + return nil +} + +// Create Root certificate and pkey and store it on disk +func (p *MITMProxy) createRootCert() error { + cert, key, err := mitm.NewAuthority("AdGuardHome Root", "AdGuard", 365*24*time.Hour) + if err != nil { + return err + } + + p.conf.certData = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw}) + p.conf.pkeyData = pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}) + log.Debug("MITM: Created root certificate and key") + + err = p.storeCert(p.conf.certData, p.conf.pkeyData) + if err != nil { + return err + } + return nil +} + +// Store cert & pkey on disk +func (p *MITMProxy) storeCert(certData []byte, pkeyData []byte) error { + err := file.SafeWrite(p.conf.certFileName, certData) + if err != nil { + return err + } + + err = file.SafeWrite(p.conf.pkeyFileName, pkeyData) + if err != nil { + return err + } + + log.Debug("MITM: stored root certificate and key: %s, %s", p.conf.certFileName, p.conf.pkeyFileName) + return nil +} + +// Fill TLSConfig & MITMConfig objects +func (p *MITMProxy) prepareMITMConfig() (*mitm.Config, error) { + tlsCert, err := tls.X509KeyPair(p.conf.certData, p.conf.pkeyData) + if err != nil { + return nil, fmt.Errorf("failed to load root CA: %v", err) + } + privateKey := tlsCert.PrivateKey.(*rsa.PrivateKey) + + x509c, err := x509.ParseCertificate(tlsCert.Certificate[0]) + if err != nil { + return nil, fmt.Errorf("invalid certificate: %v", err) + } + + mitmConfig, err := mitm.NewConfig(x509c, privateKey, nil) + if err != nil { + return nil, fmt.Errorf("failed to create MITM config: %v", err) + } + + mitmConfig.SetValidity(time.Hour * 24 * 7) // generate certs valid for 7 days + mitmConfig.SetOrganization("AdGuard") // cert organization + return mitmConfig, nil +} + +func (p *MITMProxy) onFiltersChanged(flags uint) { + switch flags { + case filters.EventBeforeUpdate: + p.Close() + + case filters.EventAfterUpdate: + err := p.Restart() + if err != nil { + log.Error("MITM: %s", err) + } + } +}