MITM proxy

This commit is contained in:
Simon Zolin
2020-08-18 19:23:33 +03:00
parent c3123473cf
commit f85de51452
21 changed files with 2116 additions and 491 deletions

247
filters/filter_file.go Normal file
View File

@@ -0,0 +1,247 @@
package filters
import (
"bufio"
"fmt"
"io"
"io/ioutil"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/util"
"github.com/AdguardTeam/golibs/log"
)
// isPrintableText reports whether data looks like plain text: every byte
// must be printable (including UTF-8 multi-byte sequence bytes) or one of
// the allowed control characters CR, LF, TAB.
func isPrintableText(data []byte) bool {
	for _, b := range data {
		switch {
		case b == '\n', b == '\r', b == '\t':
			// allowed control characters
		case b >= ' ' && b != 0x7f:
			// printable ASCII or a UTF-8 lead/continuation byte
		default:
			return false
		}
	}
	return true
}
// Download filter data
// Return nil on success. Set f.Path to a file path, or "" if the file was not modified
//
// The data is streamed into a temporary file inside FilterDir and only
// renamed to its final name after it has been fully written and validated,
// so a failed download never corrupts an existing filter file.
func (fs *filterStg) downloadFilter(f *Filter) error {
	log.Debug("Filters: Downloading filter from %s", f.URL)

	// create temp file
	tmpFile, err := ioutil.TempFile(filepath.Join(fs.conf.FilterDir), "")
	if err != nil {
		return err
	}
	defer func() {
		// tmpFile is reset to nil on success; if it is still set here,
		// the download failed and the partial file must be removed
		if tmpFile != nil {
			_ = tmpFile.Close()
			_ = os.Remove(tmpFile.Name())
		}
	}()

	// create data reader object
	var reader io.Reader
	if filepath.IsAbs(f.URL) {
		// f.URL is a local file path
		// note: this "f" shadows the *Filter parameter inside this block
		f, err := os.Open(f.URL)
		if err != nil {
			return fmt.Errorf("open file: %s", err)
		}
		defer f.Close()
		reader = f
	} else {
		req, err := http.NewRequest("GET", f.URL, nil)
		if err != nil {
			return err
		}
		if len(f.LastModified) != 0 {
			// ask the server to answer 304 if the data hasn't changed
			req.Header.Add("If-Modified-Since", f.LastModified)
		}
		resp, err := fs.conf.HTTPClient.Do(req)
		if resp != nil && resp.Body != nil {
			defer resp.Body.Close()
		}
		if err != nil {
			// the update task uses this flag to schedule an early retry
			f.networkError = true
			return err
		}
		if resp.StatusCode == 304 { // "NOT_MODIFIED"
			log.Debug("Filters: filter %s isn't modified since %s",
				f.URL, f.LastModified)
			f.LastUpdated = time.Now()
			f.Path = "" // tell the caller that nothing was downloaded
			return nil
		} else if resp.StatusCode != 200 {
			err := fmt.Errorf("Filters: Couldn't download filter from %s: status code: %d",
				f.URL, resp.StatusCode)
			return err
		}
		f.LastModified = resp.Header.Get("Last-Modified")
		reader = resp.Body
	}

	// parse and validate data, write to a file
	err = writeFile(f, reader, tmpFile)
	if err != nil {
		return err
	}

	// Closing the file before renaming it is necessary on Windows
	_ = tmpFile.Close()

	fname := fs.filePath(*f)
	err = os.Rename(tmpFile.Name(), fname)
	if err != nil {
		return err
	}
	tmpFile = nil // prevent from deleting this file in "defer" handler

	log.Debug("Filters: saved filter %s at %s", f.URL, fname)
	f.Path = fname
	f.LastUpdated = time.Now()
	return nil
}
// gatherUntil copies bytes from src into dst starting at offset dstLen,
// never letting the total amount of gathered data exceed "until" bytes.
// It returns the number of bytes actually copied.
func gatherUntil(dst []byte, dstLen int, src []byte, until int) int {
	room := until - dstLen
	take := util.MinInt(room, len(src))
	return copy(dst[dstLen:], src[:take])
}
// isHTML reports whether buf looks like the start of an HTML document
// rather than a plain-text filter list (case-insensitive check).
func isHTML(buf []byte) bool {
	lower := strings.ToLower(string(buf))
	for _, marker := range []string{"<html", "<!doctype"} {
		if strings.Contains(lower, marker) {
			return true
		}
	}
	return false
}
// Read file data and count the number of rules
//
// A rule is any non-empty line that is not a comment (prefix "#" or "!").
// The result is stored in f.RuleCount; the rules themselves aren't kept.
func parseFilter(f *Filter, reader io.Reader) error {
	ruleCount := 0
	r := bufio.NewReader(reader)
	log.Debug("Filters: parsing %s", f.URL)

	var err error
	for err == nil {
		var line string
		// on a final line without a trailing '\n', ReadString returns
		// both the data and io.EOF - that line is still counted below,
		// and the loop condition then terminates the loop
		line, err = r.ReadString('\n')
		if err != nil && err != io.EOF {
			return err
		}

		line = strings.TrimSpace(line)
		if len(line) == 0 ||
			line[0] == '#' ||
			line[0] == '!' {
			// skip empty lines and comments
			continue
		}

		ruleCount++
	}

	log.Debug("Filters: %s: %d rules", f.URL, ruleCount)
	f.RuleCount = uint64(ruleCount)
	return nil
}
// Read data, parse, write to a file
//
// While streaming the data into outFile this function also:
//  . rejects data that contains non-printable characters
//  . rejects data whose first 4KB chunk looks like an HTML page
//  . counts the rules (non-empty lines that aren't comments)
// The number of rules is stored in f.RuleCount.
func writeFile(f *Filter, reader io.Reader, outFile *os.File) error {
	ruleCount := 0
	buf := make([]byte, 64*1024) // read buffer
	total := 0                   // total number of bytes processed
	var chunk []byte             // carries an incomplete last line over to the next iteration
	firstChunk := make([]byte, 4*1024)
	firstChunkLen := 0
	for {
		n, err := reader.Read(buf)
		if err != nil && err != io.EOF {
			return err
		}
		total += n

		if !isPrintableText(buf[:n]) {
			return fmt.Errorf("data contains non-printable characters")
		}

		if firstChunk != nil {
			// gather full buffer firstChunk and perform its data tests
			firstChunkLen += gatherUntil(firstChunk, firstChunkLen, buf[:n], len(firstChunk))

			if firstChunkLen == len(firstChunk) ||
				err == io.EOF {
				if isHTML(firstChunk[:firstChunkLen]) {
					return fmt.Errorf("data is HTML, not plain text")
				}
				firstChunk = nil // check done - stop gathering
			}
		}

		_, err2 := outFile.Write(buf[:n])
		if err2 != nil {
			return err2
		}

		// count the rules in the complete lines accumulated so far
		chunk = append(chunk, buf[:n]...)
		s := string(chunk)
		for len(s) != 0 {
			i, line := splitNext(&s, '\n')
			if i < 0 && err != io.EOF {
				// no more lines in the current chunk
				break
			}
			chunk = []byte(s) // keep the unprocessed remainder
			if len(line) == 0 ||
				line[0] == '#' ||
				line[0] == '!' {
				// empty line or comment
				continue
			}
			ruleCount++
		}

		if err == io.EOF {
			break
		}
	}

	log.Debug("Filters: updated filter %s: %d bytes, %d rules",
		f.URL, total, ruleCount)
	f.RuleCount = uint64(ruleCount)
	return nil
}
// SplitNext - split string by a byte
// Whitespace is trimmed
// Return byte position and the first chunk
//
// *data is advanced past the separator; when the separator is absent, the
// whole remainder is returned as the chunk, *data becomes "" and the
// returned position is -1.
func splitNext(data *string, by byte) (int, string) {
	rest := *data
	pos := strings.IndexByte(rest, by)

	var head string
	if pos >= 0 {
		head, rest = rest[:pos], rest[pos+1:]
	} else {
		head, rest = rest, ""
	}

	*data = rest
	return pos, strings.TrimSpace(head)
}

329
filters/filter_http.go Normal file
View File

@@ -0,0 +1,329 @@
package filters
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/util"
"github.com/AdguardTeam/golibs/jsonutil"
"github.com/AdguardTeam/golibs/log"
)
// Print to log and set HTTP error message
func httpError(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) {
	msg := fmt.Sprintf(format, args...)
	log.Info("Filters: %s %s: %s", r.Method, r.URL, msg)
	http.Error(w, msg, code)
}
// IsValidURL - return TRUE if URL or file path is valid
//
// An absolute file path is valid when the file exists; anything else must
// parse as a URI and carry a scheme.
func IsValidURL(rawurl string) bool {
	if filepath.IsAbs(rawurl) {
		// this is a file path
		return util.FileExists(rawurl)
	}

	// "u" instead of "url": the original shadowed the net/url package name
	u, err := url.ParseRequestURI(rawurl)
	if err != nil {
		return false // couldn't even parse the rawurl
	}
	if len(u.Scheme) == 0 {
		return false // no scheme found
	}
	return true
}
// getFilterModule maps a filter type name to its filter-list module.
// Returns nil (the interface zero value) for an unknown type name.
func (f *Filtering) getFilterModule(t string) Filters {
	modules := map[string]Filters{
		"blocklist": f.dnsBlocklist,
		"whitelist": f.dnsAllowlist,
		"proxylist": f.Proxylist,
	}
	return modules[t]
}
// Fire the "before update" and "after update" events on the selected
// filter module's observer so that it reloads its filter files.
// NOTE(review): assumes t is a valid type name - getFilterModule returns
// nil for unknown types and the calls below would panic; confirm callers
// validate t first.
func (f *Filtering) restartMods(t string) {
	fN := f.getFilterModule(t)
	fN.NotifyObserver(EventBeforeUpdate)
	fN.NotifyObserver(EventAfterUpdate)
}
// handleFilterAdd - handle the "add filter" HTTP request
// Request body: {"name":..., "url":..., "type":"blocklist"|"whitelist"|"proxylist"}
// The new filter is downloaded synchronously before a response is sent.
func (f *Filtering) handleFilterAdd(w http.ResponseWriter, r *http.Request) {
	type reqJSON struct {
		Name string `json:"name"`
		URL  string `json:"url"`
		Type string `json:"type"`
	}
	req := reqJSON{}
	_, err := jsonutil.DecodeObject(&req, r.Body)
	if err != nil {
		httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
		return
	}

	filterN := f.getFilterModule(req.Type)
	if filterN == nil {
		httpError(r, w, http.StatusBadRequest, "invalid type: %s", req.Type)
		return
	}

	// a newly added filter is always enabled
	filt := Filter{
		Enabled: true,
		Name:    req.Name,
		URL:     req.URL,
	}
	err = filterN.Add(filt)
	if err != nil {
		httpError(r, w, http.StatusBadRequest, "add filter: %s", err)
		return
	}

	f.conf.ConfigModified()
	f.restartMods(req.Type)
}

// handleFilterRemove - handle the "remove filter" HTTP request
// Removes the filter from its list and deletes the downloaded data file.
func (f *Filtering) handleFilterRemove(w http.ResponseWriter, r *http.Request) {
	type reqJSON struct {
		URL  string `json:"url"`
		Type string `json:"type"`
	}
	req := reqJSON{}
	_, err := jsonutil.DecodeObject(&req, r.Body)
	if err != nil {
		httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
		return
	}

	filterN := f.getFilterModule(req.Type)
	if filterN == nil {
		httpError(r, w, http.StatusBadRequest, "invalid type: %s", req.Type)
		return
	}

	removed := filterN.Delete(req.URL)
	if removed == nil {
		httpError(r, w, http.StatusInternalServerError, "no filter with such URL")
		return
	}

	f.conf.ConfigModified()
	if removed.Enabled {
		// only an enabled filter affects the active rule set
		f.restartMods(req.Type)
	}

	// delete the filter data file; failure here is only logged
	err = os.Remove(removed.Path)
	if err != nil {
		log.Error("os.Remove: %s", err)
	}
}

// handleFilterModify - handle the "set filter properties" HTTP request
// Supports changing Name, URL and the Enabled flag; the list module
// decides whether a re-download or reload is needed.
func (f *Filtering) handleFilterModify(w http.ResponseWriter, r *http.Request) {
	type propsJSON struct {
		Name    string `json:"name"`
		URL     string `json:"url"`
		Enabled bool   `json:"enabled"`
	}
	type reqJSON struct {
		URL  string    `json:"url"`
		Type string    `json:"type"`
		Data propsJSON `json:"data"`
	}
	req := reqJSON{}
	_, err := jsonutil.DecodeObject(&req, r.Body)
	if err != nil {
		httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
		return
	}

	filterN := f.getFilterModule(req.Type)
	if filterN == nil {
		httpError(r, w, http.StatusBadRequest, "invalid type: %s", req.Type)
		return
	}

	// st is a Status* bit array describing what actually changed
	st, _, err := filterN.Modify(req.URL, req.Data.Enabled, req.Data.Name, req.Data.URL)
	if err != nil {
		httpError(r, w, http.StatusBadRequest, "%s", err)
		return
	}

	f.conf.ConfigModified()

	if st == StatusChangedEnabled ||
		st == StatusChangedURL {
		// TODO StatusChangedURL: delete old file
		f.restartMods(req.Type)
	}
}

// handleFilteringSetRules - handle the "set user rules" HTTP request
// The raw request body is the rule list, one rule per line.
// NOTE(review): the body is split on '\n' only - rules submitted with
// CRLF line endings keep a trailing '\r'; confirm whether clients send LF.
func (f *Filtering) handleFilteringSetRules(w http.ResponseWriter, r *http.Request) {
	body, err := ioutil.ReadAll(r.Body)
	if err != nil {
		httpError(r, w, http.StatusBadRequest, "Failed to read request body: %s", err)
		return
	}

	f.conf.UserRules = strings.Split(string(body), "\n")
	f.conf.ConfigModified()
	f.restartMods("blocklist")
}

// handleFilteringRefresh - handle the "refresh filters" HTTP request
// Begins the update procedure for every filter of the given type.
func (f *Filtering) handleFilteringRefresh(w http.ResponseWriter, r *http.Request) {
	type reqJSON struct {
		Type string `json:"type"`
	}
	req := reqJSON{}
	_, err := jsonutil.DecodeObject(&req, r.Body)
	if err != nil {
		httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
		return
	}

	filterN := f.getFilterModule(req.Type)
	if filterN == nil {
		httpError(r, w, http.StatusBadRequest, "invalid type: %s", req.Type)
		return
	}

	filterN.Refresh(0)
}
// filterJSON is the REST representation of a single filter list.
type filterJSON struct {
	ID          int64  `json:"id"`
	Enabled     bool   `json:"enabled"`
	URL         string `json:"url"`
	Name        string `json:"name"`
	RulesCount  uint32 `json:"rules_count"`
	LastUpdated string `json:"last_updated"`
}

// filterToJSON converts a Filter into its REST representation.
// LastUpdated stays empty for filters that were never downloaded.
func filterToJSON(f Filter) filterJSON {
	fj := filterJSON{
		ID:         int64(f.ID),
		Enabled:    f.Enabled,
		URL:        f.URL,
		Name:       f.Name,
		RulesCount: uint32(f.RuleCount),
	}

	if !f.LastUpdated.IsZero() {
		fj.LastUpdated = f.LastUpdated.Format(time.RFC3339)
	}

	return fj
}

// Get filtering configuration
// GET /control/filtering/status
func (f *Filtering) handleFilteringStatus(w http.ResponseWriter, r *http.Request) {
	type respJSON struct {
		Enabled          bool         `json:"enabled"`
		Interval         uint32       `json:"interval"` // in hours
		Filters          []filterJSON `json:"filters"`
		WhitelistFilters []filterJSON `json:"whitelist_filters"`
		UserRules        []string     `json:"user_rules"`
		Proxylist        []filterJSON `json:"proxy_filters"`
	}
	resp := respJSON{}
	resp.Enabled = f.conf.Enabled
	resp.Interval = f.conf.UpdateIntervalHours
	resp.UserRules = f.conf.UserRules

	// gather the three filter lists
	f0 := f.dnsBlocklist.List(0)
	f1 := f.dnsAllowlist.List(0)
	f2 := f.Proxylist.List(0)
	for _, filt := range f0 {
		fj := filterToJSON(filt)
		resp.Filters = append(resp.Filters, fj)
	}
	for _, filt := range f1 {
		fj := filterToJSON(filt)
		resp.WhitelistFilters = append(resp.WhitelistFilters, fj)
	}
	for _, filt := range f2 {
		fj := filterToJSON(filt)
		resp.Proxylist = append(resp.Proxylist, fj)
	}

	jsonVal, err := json.Marshal(resp)
	if err != nil {
		httpError(r, w, http.StatusInternalServerError, "json encode: %s", err)
		return
	}
	w.Header().Set("Content-Type", "application/json")
	_, _ = w.Write(jsonVal)
}

// Set filtering configuration
// POST /control/filtering/config
// Request body: {"enabled":..., "interval":...}
func (f *Filtering) handleFilteringConfig(w http.ResponseWriter, r *http.Request) {
	type reqJSON struct {
		Enabled  bool   `json:"enabled"`
		Interval uint32 `json:"interval"`
	}
	req := reqJSON{}
	_, err := jsonutil.DecodeObject(&req, r.Body)
	if err != nil {
		httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
		return
	}

	if !CheckFiltersUpdateIntervalHours(req.Interval) {
		httpError(r, w, http.StatusBadRequest, "Unsupported interval")
		return
	}

	// toggling the Enabled flag requires reloading the DNS filters
	restart := false
	if f.conf.Enabled != req.Enabled {
		restart = true
	}
	f.conf.Enabled = req.Enabled
	f.conf.UpdateIntervalHours = req.Interval

	// propagate the new interval to every filter-list module
	c := Conf{}
	c.UpdateIntervalHours = req.Interval
	f.dnsBlocklist.SetConfig(c)
	f.dnsAllowlist.SetConfig(c)
	f.Proxylist.SetConfig(c)

	f.conf.ConfigModified()

	if restart {
		f.restartMods("blocklist")
	}
}

// registerWebHandlers - register handlers
func (f *Filtering) registerWebHandlers() {
	f.conf.HTTPRegister("GET", "/control/filtering/status", f.handleFilteringStatus)
	f.conf.HTTPRegister("POST", "/control/filtering/config", f.handleFilteringConfig)
	f.conf.HTTPRegister("POST", "/control/filtering/add_url", f.handleFilterAdd)
	f.conf.HTTPRegister("POST", "/control/filtering/remove_url", f.handleFilterRemove)
	f.conf.HTTPRegister("POST", "/control/filtering/set_url", f.handleFilterModify)
	f.conf.HTTPRegister("POST", "/control/filtering/refresh", f.handleFilteringRefresh)
	f.conf.HTTPRegister("POST", "/control/filtering/set_rules", f.handleFilteringSetRules)
}
// CheckFiltersUpdateIntervalHours - verify update interval
// Valid values: 0 (disabled), 1, 12 hours; 1, 3 or 7 days.
func CheckFiltersUpdateIntervalHours(i uint32) bool {
	switch i {
	case 0, 1, 12, 1 * 24, 3 * 24, 7 * 24:
		return true
	}
	return false
}

118
filters/filter_module.go Normal file
View File

@@ -0,0 +1,118 @@
package filters
import (
"net/http"
"path/filepath"
)
// Filtering - module object
// Owns the three filter-list sub-modules and the shared configuration.
type Filtering struct {
	dnsBlocklist Filters // DNS blocklist filters
	dnsAllowlist Filters // DNS allowlist filters
	Proxylist    Filters // MITM Proxy filtering module

	conf ModuleConf
}

// ModuleConf - module config
type ModuleConf struct {
	Enabled             bool
	UpdateIntervalHours uint32 // 0: disabled
	HTTPClient          *http.Client
	DataDir             string // base directory; a subdirectory per filter list is created inside

	DNSBlocklist []Filter
	DNSAllowlist []Filter
	Proxylist    []Filter
	UserRules    []string

	// Called when the configuration is changed by HTTP request
	ConfigModified func()

	// Register an HTTP handler: (method, URL, handler function)
	HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request))
}
// NewModule - create module
//
// Creates the three filter-list sub-modules (DNS blocklist, DNS allowlist,
// MITM proxy list), each storing its files in its own subdirectory of
// conf.DataDir.
func NewModule(conf ModuleConf) *Filtering {
	f := Filtering{}
	f.conf = conf
	f.dnsBlocklist = New(newFilterConf(conf, "filters_dnsblock", conf.DNSBlocklist))
	f.dnsAllowlist = New(newFilterConf(conf, "filters_dnsallow", conf.DNSAllowlist))
	f.Proxylist = New(newFilterConf(conf, "filters_mitmproxy", conf.Proxylist))
	return &f
}

// newFilterConf builds the configuration for one filter-list sub-module;
// extracted to remove the triplicated field-by-field setup.
func newFilterConf(mc ModuleConf, subdir string, list []Filter) Conf {
	return Conf{
		FilterDir:           filepath.Join(mc.DataDir, subdir),
		List:                list,
		UpdateIntervalHours: mc.UpdateIntervalHours,
		HTTPClient:          mc.HTTPClient,
	}
}
// Filter-list type IDs for GetList()
const (
	DNSBlocklist = iota // DNS blocklist filters
	DNSAllowlist        // DNS allowlist filters
	Proxylist           // MITM Proxy filters
)
// stringArrayDup returns a newly allocated copy of the string slice, so
// the caller may mutate it without affecting the original.
func stringArrayDup(a []string) []string {
	return append(make([]string, 0, len(a)), a...)
}
// WriteDiskConfig - write configuration data
// Fills mc with a copy of the current settings and the filter lists of
// every sub-module so the caller can serialize them safely.
func (f *Filtering) WriteDiskConfig(mc *ModuleConf) {
	mc.Enabled = f.conf.Enabled
	mc.UpdateIntervalHours = f.conf.UpdateIntervalHours
	mc.UserRules = stringArrayDup(f.conf.UserRules)

	// each sub-module copies its own list under its lock
	c := Conf{}
	f.dnsBlocklist.WriteDiskConfig(&c)
	mc.DNSBlocklist = c.List

	c = Conf{}
	f.dnsAllowlist.WriteDiskConfig(&c)
	mc.DNSAllowlist = c.List

	c = Conf{}
	f.Proxylist.WriteDiskConfig(&c)
	mc.Proxylist = c.List
}
// GetList - get specific filter list
// Accepts one of the DNSBlocklist/DNSAllowlist/Proxylist type IDs and
// returns nil for any other value.
func (f *Filtering) GetList(t uint32) Filters {
	lists := [...]Filters{
		DNSBlocklist: f.dnsBlocklist,
		DNSAllowlist: f.dnsAllowlist,
		Proxylist:    f.Proxylist,
	}
	if t >= uint32(len(lists)) {
		return nil
	}
	return lists[t]
}
// Start - start module
// Starts every filter-list sub-module and registers the HTTP handlers.
func (f *Filtering) Start() {
	f.dnsBlocklist.Start()
	f.dnsAllowlist.Start()
	f.Proxylist.Start()
	f.registerWebHandlers()
}
// Close - close the module
//
// Stops the update goroutines of every filter-list sub-module.
// Previously this was a no-op, leaking the goroutines that Start()
// launched; Close() now mirrors Start().
func (f *Filtering) Close() {
	f.dnsBlocklist.Close()
	f.dnsAllowlist.Close()
	f.Proxylist.Close()
}

246
filters/filter_storage.go Normal file
View File

@@ -0,0 +1,246 @@
package filters
import (
"fmt"
"os"
"path/filepath"
"sync"
"time"
"github.com/AdguardTeam/golibs/log"
"go.uber.org/atomic"
)
// filter storage object
// Manages one list of filters: metadata, data files and background updates.
type filterStg struct {
	updateTaskRunning bool      // the two update goroutines have been started
	updated           []Filter  // list of filters that were downloaded during update procedure
	updateChan        chan bool // signal for the update goroutine; false or close = stop

	conf     *Conf
	confLock sync.Mutex // protects conf.List and the filters' metadata

	nextID atomic.Uint64 // next filter ID

	observer EventHandler // user function that receives notifications
}
// initialize the module
// The filter-ID counter is seeded with the current UNIX time so that IDs
// don't collide across restarts.
func newFiltersObj(conf Conf) Filters {
	fs := &filterStg{
		conf:       &conf, // conf is a by-value copy of the caller's struct
		updateChan: make(chan bool, 2),
	}
	fs.nextID.Store(uint64(time.Now().Unix()))
	return fs
}
// Start - start module
// Loads the metadata of all enabled filters from the existing files on
// disk, then launches the background update goroutines (only once per
// object lifetime).
func (fs *filterStg) Start() {
	_ = os.MkdirAll(fs.conf.FilterDir, 0755)

	// Load all enabled filters
	// On error, RuleCount is set to 0 - users won't try to use such filters
	// and in the future the update procedure will re-download the file
	for i := range fs.conf.List {
		f := &fs.conf.List[i]

		fname := fs.filePath(*f)
		st, err := os.Stat(fname)
		if err != nil {
			// no data file yet - the update task will download it
			log.Debug("Filters: os.Stat: %s %s", fname, err)
			continue
		}
		f.LastUpdated = st.ModTime()

		if !f.Enabled {
			continue
		}

		file, err := os.OpenFile(fname, os.O_RDONLY, 0)
		if err != nil {
			log.Error("Filters: os.OpenFile: %s %s", fname, err)
			continue
		}

		// sets f.RuleCount; errors leave RuleCount at 0 (see above)
		_ = parseFilter(f, file)
		file.Close()

		f.nextUpdate = f.LastUpdated.Add(time.Duration(fs.conf.UpdateIntervalHours) * time.Hour)
	}

	if !fs.updateTaskRunning {
		fs.updateTaskRunning = true
		go fs.updateBySignal()
		go fs.updateByTimer()
	}
}

// Close - close the module
// Sends a stop signal to the update goroutine, then closes the channel.
// NOTE(review): updateByTimer() keeps sending on updateChan forever; a
// send after close() would panic - confirm the shutdown ordering between
// Close() and the timer goroutine.
func (fs *filterStg) Close() {
	fs.updateChan <- false
	close(fs.updateChan)
}
// Duplicate filter array
// Returns a newly allocated copy so the caller can keep it past the lock.
func arrayFilterDup(f []Filter) []Filter {
	return append([]Filter{}, f...)
}
// WriteDiskConfig - write configuration on disk
// Copies the current configuration into c, duplicating the filter list so
// the caller may use it without holding the lock.
func (fs *filterStg) WriteDiskConfig(c *Conf) {
	fs.confLock.Lock()
	defer fs.confLock.Unlock()

	*c = *fs.conf
	c.List = arrayFilterDup(fs.conf.List)
}
// SetConfig - set new configuration settings
// Currently only UpdateIntervalHours is supported
//
// Takes confLock: fs.conf is read and written concurrently by the HTTP
// handlers and the update goroutines, and the original unlocked write was
// a data race.
func (fs *filterStg) SetConfig(c Conf) {
	fs.confLock.Lock()
	fs.conf.UpdateIntervalHours = c.UpdateIntervalHours
	fs.confLock.Unlock()
}
// SetObserver - set user handler for notifications
// No locking here - presumably meant to be called once before Start();
// verify against callers.
func (fs *filterStg) SetObserver(handler EventHandler) {
	fs.observer = handler
}

// NotifyObserver - notify users about the event
// Does nothing when no observer has been set.
func (fs *filterStg) NotifyObserver(flags uint) {
	if fs.observer == nil {
		return
	}
	fs.observer(flags)
}
// List (thread safe)
// Returns a copy of the filter list with each entry's Path filled in.
func (fs *filterStg) List(flags uint) []Filter {
	fs.confLock.Lock()
	defer fs.confLock.Unlock()

	list := make([]Filter, 0, len(fs.conf.List))
	for _, f := range fs.conf.List {
		// f is a by-value copy - safe to modify before appending
		f.Path = fs.filePath(f)
		list = append(list, f)
	}
	return list
}
// Add - add filter (thread safe)
// Downloads the filter data before adding it to the list; on download
// failure the list is left unchanged and the error is returned.
func (fs *filterStg) Add(nf Filter) error {
	fs.confLock.Lock()
	defer fs.confLock.Unlock()

	// reject duplicates by Name or URL
	for _, f := range fs.conf.List {
		if f.Name == nf.Name || f.URL == nf.URL {
			return fmt.Errorf("filter with this Name or URL already exists")
		}
	}

	nf.ID = fs.nextFilterID()
	nf.Enabled = true // a newly added filter is always enabled
	err := fs.downloadFilter(&nf)
	if err != nil {
		log.Debug("%s", err)
		return err
	}

	fs.conf.List = append(fs.conf.List, nf)
	log.Debug("Filters: added filter %s", nf.URL)
	return nil
}

// Delete - remove filter (thread safe)
// Returns the removed filter with Path set so the caller can delete its
// data file, or nil if no filter has the given URL.
func (fs *filterStg) Delete(url string) *Filter {
	fs.confLock.Lock()
	defer fs.confLock.Unlock()

	// rebuild the list without the matching entries
	nf := []Filter{}
	var found *Filter
	for i := range fs.conf.List {
		f := &fs.conf.List[i]
		if f.URL == url {
			// "found" points into the old backing array, which remains
			// valid after fs.conf.List is replaced below
			found = f
			continue
		}
		nf = append(nf, *f)
	}
	if found == nil {
		return nil
	}
	fs.conf.List = nf
	log.Debug("Filters: removed filter %s", url)
	found.Path = fs.filePath(*found) // the caller will delete the file
	return found
}

// Modify - set filter properties (thread safe)
// Return Status* bitarray
// . a URL change assigns a new filter ID and re-downloads the data
// . enabling a filter reloads its file, re-downloading if it is missing
// On download failure the filter's previous state is restored.
func (fs *filterStg) Modify(url string, enabled bool, name string, newURL string) (int, Filter, error) {
	fs.confLock.Lock()
	defer fs.confLock.Unlock()

	st := 0
	for i := range fs.conf.List {
		f := &fs.conf.List[i]
		if f.URL == url {
			backup := *f // saved for rollback and for the return value
			f.Name = name

			if f.Enabled != enabled {
				f.Enabled = enabled
				st |= StatusChangedEnabled
			}
			if f.URL != newURL {
				f.URL = newURL
				st |= StatusChangedURL
			}

			needDownload := false
			if (st & StatusChangedURL) != 0 {
				// new URL: treat it as a brand-new filter
				f.ID = fs.nextFilterID()
				needDownload = true
			} else if (st&StatusChangedEnabled) != 0 && enabled {
				// re-enabled: reload the rules from the existing file
				fname := fs.filePath(*f)
				file, err := os.OpenFile(fname, os.O_RDONLY, 0)
				if err != nil {
					log.Debug("Filters: os.OpenFile: %s %s", fname, err)
					needDownload = true
				} else {
					_ = parseFilter(f, file)
					file.Close()
				}
			}

			if needDownload {
				f.LastModified = ""
				f.RuleCount = 0
				err := fs.downloadFilter(f)
				if err != nil {
					*f = backup // roll back all changes
					return 0, Filter{}, err
				}
			}

			return st, backup, nil
		}
	}

	return 0, Filter{}, fmt.Errorf("filter %s not found", url)
}

// Get filter file name
func (fs *filterStg) filePath(f Filter) string {
	return filepath.Join(fs.conf.FilterDir, fmt.Sprintf("%d.txt", f.ID))
}

// Get next filter ID
// The counter is seeded from the UNIX time at startup (see newFiltersObj).
func (fs *filterStg) nextFilterID() uint64 {
	return fs.nextID.Inc()
}

154
filters/filter_test.go Normal file
View File

@@ -0,0 +1,154 @@
package filters
import (
"fmt"
"net"
"net/http"
"os"
"testing"
"time"
"github.com/stretchr/testify/assert"
"go.uber.org/atomic"
)
// testStartFilterListener starts a local HTTP server serving two small
// filter lists (3 and 4 rules respectively) and increments counter on
// every download request.
func testStartFilterListener(counter *atomic.Uint32) net.Listener {
	mux := http.NewServeMux()
	mux.HandleFunc("/filters/1.txt", func(w http.ResponseWriter, r *http.Request) {
		(*counter).Inc()
		content := `||example.org^$third-party
# Inline comment example
||example.com^$third-party
0.0.0.0 example.com
`
		_, _ = w.Write([]byte(content))
	})
	mux.HandleFunc("/filters/2.txt", func(w http.ResponseWriter, r *http.Request) {
		(*counter).Inc()
		content := `||example.org^$third-party
# Inline comment example
||example.com^$third-party
0.0.0.0 example.com
1.1.1.1 example1.com
`
		_, _ = w.Write([]byte(content))
	})

	// port 0: let the OS pick a free port; the caller reads it via Addr()
	listener, err := net.Listen("tcp", ":0")
	if err != nil {
		panic(err)
	}
	go func() {
		_ = http.Serve(listener, mux)
	}()
	return listener
}

// prepareTestDir creates a fresh, empty ./agh-test directory.
func prepareTestDir() string {
	const dir = "./agh-test"
	_ = os.RemoveAll(dir)
	_ = os.MkdirAll(dir, 0755)
	return dir
}

// updateStatus accumulates observer events received during the test:
// bit 0 - "before update", bit 1 - "after update".
var updateStatus atomic.Uint32

// onFiltersUpdate is the observer installed by TestFilters.
func onFiltersUpdate(flags uint) {
	switch flags {
	case EventBeforeUpdate:
		updateStatus.Store(updateStatus.Load() | 1)

	case EventAfterUpdate:
		updateStatus.Store(updateStatus.Load() | 2)
	}
}
// TestFilters exercises the full filter-storage life cycle:
// add, disable, change URL, enable, refresh, delete.
// Fix: the errors returned by the three Modify() calls were assigned but
// never checked (staticcheck SA4006) - they are now asserted.
func TestFilters(t *testing.T) {
	counter := atomic.Uint32{}
	lhttp := testStartFilterListener(&counter)
	defer func() { _ = lhttp.Close() }()

	dir := prepareTestDir()
	defer func() { _ = os.RemoveAll(dir) }()

	fconf := Conf{}
	fconf.UpdateIntervalHours = 1
	fconf.FilterDir = dir
	fconf.HTTPClient = &http.Client{
		Timeout: 5 * time.Second,
	}
	fs := New(fconf)
	fs.SetObserver(onFiltersUpdate)
	fs.Start()

	port := lhttp.Addr().(*net.TCPAddr).Port
	URL := fmt.Sprintf("http://127.0.0.1:%d/filters/1.txt", port)

	// add and download
	f := Filter{
		URL: URL,
	}
	err := fs.Add(f)
	assert.Equal(t, nil, err)

	// check
	l := fs.List(0)
	assert.Equal(t, 1, len(l))
	assert.Equal(t, URL, l[0].URL)
	assert.True(t, l[0].Enabled)
	assert.Equal(t, uint64(3), l[0].RuleCount)
	assert.True(t, l[0].ID != 0)

	// disable
	st, _, err := fs.Modify(f.URL, false, "name", f.URL)
	assert.Nil(t, err) // was previously unchecked
	assert.Equal(t, StatusChangedEnabled, st)

	// check: disabled
	l = fs.List(0)
	assert.Equal(t, 1, len(l))
	assert.True(t, !l[0].Enabled)

	// modify URL
	newURL := fmt.Sprintf("http://127.0.0.1:%d/filters/2.txt", port)
	st, modified, err := fs.Modify(URL, false, "name", newURL)
	assert.Nil(t, err) // was previously unchecked
	assert.Equal(t, StatusChangedURL, st)
	_ = os.Remove(modified.Path)

	// check: new ID, new URL
	l = fs.List(0)
	assert.Equal(t, 1, len(l))
	assert.Equal(t, newURL, l[0].URL)
	assert.Equal(t, uint64(4), l[0].RuleCount)
	assert.True(t, modified.ID != l[0].ID)

	// enable
	st, _, err = fs.Modify(newURL, true, "name", newURL)
	assert.Nil(t, err) // was previously unchecked
	assert.Equal(t, StatusChangedEnabled, st)

	// update
	cnt := counter.Load()
	fs.Refresh(0)
	for i := 0; ; i++ {
		if i == 2 {
			// no re-download observed within ~2 seconds - fail
			assert.True(t, false)
			break
		}
		if cnt != counter.Load() {
			// filter was updated
			break
		}
		time.Sleep(time.Second)
	}
	assert.Equal(t, uint32(1|2), updateStatus.Load())

	// delete
	removed := fs.Delete(newURL)
	assert.NotNil(t, removed)
	_ = os.Remove(removed.Path)

	fs.Close()
}

176
filters/filter_update.go Normal file
View File

@@ -0,0 +1,176 @@
package filters
import (
"os"
"time"
"github.com/AdguardTeam/golibs/log"
)
// Refresh - begin filters update procedure
// Clears every filter's next-update time so that all of them are
// considered outdated, then wakes up the update goroutine.
func (fs *filterStg) Refresh(flags uint) {
	fs.confLock.Lock()
	defer fs.confLock.Unlock()

	for i := range fs.conf.List {
		f := &fs.conf.List[i]
		f.nextUpdate = time.Time{}
	}

	fs.updateChan <- true
}

// Start update procedure periodically
// The signal interval starts at 5 seconds and doubles up to 1 hour.
// NOTE(review): the interval grows regardless of download success, and
// this goroutine keeps sending on updateChan even after Close() has
// closed it, which would panic - confirm the intended shutdown ordering.
func (fs *filterStg) updateByTimer() {
	const maxPeriod = 1 * 60 * 60
	period := 5 // use a dynamically increasing time interval, while network or DNS is down
	for {
		if fs.conf.UpdateIntervalHours == 0 {
			period = maxPeriod
			// update is disabled
			time.Sleep(time.Duration(period) * time.Second)
			continue
		}

		fs.updateChan <- true

		time.Sleep(time.Duration(period) * time.Second)
		period += period
		if period > maxPeriod {
			period = maxPeriod
		}
	}
}
// Begin update procedure by signal
//
// Runs updateAll() every time true is received on updateChan; returns on
// an explicit false or when the channel is closed.
// Fix: replaced the non-idiomatic single-case "for { select }" loop
// (staticcheck S1000) with a range over the channel - behavior is
// identical, since receiving from a closed channel yielded false and
// returned, and range simply exits on close.
func (fs *filterStg) updateBySignal() {
	for ok := range fs.updateChan {
		if !ok {
			return
		}
		fs.updateAll()
	}
}
// Update filters
// Algorithm:
// . Get next filter to update:
//   . Download data from Internet and store on disk (in a new file)
//   . Add new filter to the special list
//   . Repeat for next filter
//  (All filters are downloaded)
// . Stop modules that use filters
// . For each updated filter:
//   . Rename "new file name" -> "old file name"
//   . Update meta data
// . Restart modules that use filters
func (fs *filterStg) updateAll() {
	log.Debug("Filters: updating...")

	for {
		var uf Filter
		fs.confLock.Lock()
		f := fs.getNextToUpdate()
		if f != nil {
			uf = *f // work on a copy outside the lock
		}
		fs.confLock.Unlock()

		if f == nil {
			// nothing left to update - apply the downloaded files
			fs.applyUpdate()
			return
		}

		// download under a new temporary ID so the old file stays intact
		uf.ID = fs.nextFilterID()
		err := fs.downloadFilter(&uf)
		if err != nil {
			if uf.networkError {
				// schedule an early retry
				// NOTE(review): "f" points into fs.conf.List, which may
				// have been reallocated while the lock was released -
				// confirm this pointer is still the live entry here
				fs.confLock.Lock()
				f.nextUpdate = time.Now().Add(10 * time.Second)
				fs.confLock.Unlock()
			}
			continue
		}

		// add new filter to the list
		fs.updated = append(fs.updated, uf)
	}
}
// Get next filter to update
// Returns the first enabled filter whose next-update time has passed, or
// nil when none qualifies; the returned filter's next-update time is
// advanced by the configured interval.
// Must be called with confLock held.
func (fs *filterStg) getNextToUpdate() *Filter {
	now := time.Now()

	for i := range fs.conf.List {
		f := &fs.conf.List[i]
		if !f.Enabled || f.nextUpdate.Unix() > now.Unix() {
			continue
		}
		// schedule the following update and hand this filter out
		f.nextUpdate = now.Add(time.Duration(fs.conf.UpdateIntervalHours) * time.Hour)
		return f
	}

	return nil
}
// Replace filter files
// For each filter downloaded by updateAll(): move the new file over the
// old one and update the metadata. The observer is notified before the
// files are replaced and again afterwards.
func (fs *filterStg) applyUpdate() {
	if len(fs.updated) == 0 {
		log.Debug("Filters: no filters were updated")
		return
	}

	fs.NotifyObserver(EventBeforeUpdate)

	nUpdated := 0
	fs.confLock.Lock()
	for _, uf := range fs.updated {
		found := false
		// match the downloaded filter back to a live list entry by URL
		for i := range fs.conf.List {
			f := &fs.conf.List[i]
			if uf.URL == f.URL {
				found = true
				fpath := fs.filePath(*f)
				f.LastUpdated = uf.LastUpdated

				if len(uf.Path) == 0 {
					// the data hasn't changed - just update file mod time
					err := os.Chtimes(fpath, f.LastUpdated, f.LastUpdated)
					if err != nil {
						log.Error("Filters: os.Chtimes: %s", err)
					}
					continue
				}

				// replace the old data file with the freshly downloaded one
				err := os.Rename(uf.Path, fpath)
				if err != nil {
					log.Error("Filters: os.Rename:%s", err)
				}

				f.RuleCount = uf.RuleCount
				nUpdated++
				break
			}
		}

		if !found {
			// the updated filter was downloaded,
			// but it's already removed from the main list
			_ = os.Remove(fs.filePath(uf))
		}
	}
	fs.confLock.Unlock()

	log.Debug("Filters: %d filters were updated", nUpdated)
	fs.updated = nil
	fs.NotifyObserver(EventAfterUpdate)
}

93
filters/filters.go Normal file
View File

@@ -0,0 +1,93 @@
package filters
import (
"net/http"
"time"
)
// Filters - main interface
// Implemented by the filter storage (see newFiltersObj).
type Filters interface {
	// Start - start module
	Start()

	// Close - close the module
	Close()

	// WriteDiskConfig - write configuration on disk
	WriteDiskConfig(c *Conf)

	// SetConfig - set new configuration settings
	// Currently only UpdateIntervalHours is supported
	SetConfig(c Conf)

	// SetObserver - set user handler for notifications
	SetObserver(handler EventHandler)

	// NotifyObserver - notify users about the event
	NotifyObserver(flags uint)

	// List (thread safe)
	List(flags uint) []Filter

	// Add - add filter (thread safe)
	Add(nf Filter) error

	// Delete - remove filter (thread safe)
	Delete(url string) *Filter

	// Modify - set filter properties (thread safe)
	// Return Status* bitarray, old filter properties and error
	Modify(url string, enabled bool, name string, newURL string) (int, Filter, error)

	// Refresh - begin filters update procedure
	Refresh(flags uint)
}

// Filter - filter object
// The yaml-tagged fields are persisted in the configuration file; the
// remaining fields are derived at run time.
type Filter struct {
	ID           uint64    `yaml:"id"`
	Enabled      bool      `yaml:"enabled"`
	Name         string    `yaml:"name"`
	URL          string    `yaml:"url"`
	LastModified string    `yaml:"last_modified"` // value of Last-Modified HTTP header field
	Path         string    `yaml:"-"`             // path to the downloaded data file

	// number of rules
	// 0 means the file isn't loaded - user shouldn't use this filter
	RuleCount uint64 `yaml:"-"`

	LastUpdated time.Time `yaml:"-"` // time of the last update (= file modification time)

	nextUpdate   time.Time // time of the next update
	networkError bool      // network error during download
}

// Observer event codes
const (
	// EventBeforeUpdate - this event is signalled before the update procedure renames/removes old filter files
	EventBeforeUpdate = iota
	// EventAfterUpdate - this event is signalled after the update procedure is finished
	EventAfterUpdate
)

// EventHandler - event handler function
type EventHandler func(flags uint)

// Modify() status bits
const (
	// StatusChangedEnabled - changed 'Enabled'
	StatusChangedEnabled = 2
	// StatusChangedURL - changed 'URL'
	StatusChangedURL = 4
)

// Conf - configuration
type Conf struct {
	FilterDir           string // directory for the downloaded filter files
	UpdateIntervalHours uint32 // 0: disabled
	HTTPClient          *http.Client
	List                []Filter
}

// New - create object
func New(conf Conf) Filters {
	return newFiltersObj(conf)
}