Filters are now saved to a file
Also, they're loaded from the file on startup
Filter ID is not passed to the CoreDNS plugin config (server-side AG DNS must be changed accordingly)
Some minor refactoring, unused functions removed
This commit is contained in:
Andrey Meshkov
2018-10-30 02:17:24 +03:00
parent 30f3eb446c
commit 32d4e80c93
8 changed files with 339 additions and 190 deletions

View File

@@ -16,7 +16,6 @@ import (
"time"
coredns_plugin "github.com/AdguardTeam/AdGuardHome/coredns_plugin"
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
"github.com/miekg/dns"
"gopkg.in/asaskevich/govalidator.v4"
)
@@ -423,7 +422,7 @@ func handleFilteringAddURL(w http.ResponseWriter, r *http.Request) {
}
}
ok, err := filter.update(time.Now())
ok, err := filter.update(true)
if err != nil {
errortext := fmt.Sprintf("Couldn't fetch filter from url %s: %s", filter.URL, err)
log.Println(errortext)
@@ -452,14 +451,9 @@ func handleFilteringAddURL(w http.ResponseWriter, r *http.Request) {
http.Error(w, errortext, http.StatusInternalServerError)
return
}
err = writeFilterFile()
if err != nil {
errortext := fmt.Sprintf("Couldn't write filter file: %s", err)
log.Println(errortext)
http.Error(w, errortext, http.StatusInternalServerError)
return
}
tellCoreDNSToReload()
_, err = fmt.Fprintf(w, "OK %d rules\n", filter.RulesCount)
if err != nil {
errortext := fmt.Sprintf("Couldn't write body: %s", err)
@@ -468,6 +462,7 @@ func handleFilteringAddURL(w http.ResponseWriter, r *http.Request) {
}
}
// TODO: Start using filter ID
func handleFilteringRemoveURL(w http.ResponseWriter, r *http.Request) {
parameters, err := parseParametersFromBody(r.Body)
if err != nil {
@@ -493,19 +488,22 @@ func handleFilteringRemoveURL(w http.ResponseWriter, r *http.Request) {
for _, filter := range config.Filters {
if filter.URL != url {
newFilters = append(newFilters, filter)
} else {
// Remove the filter file
err := os.Remove(filter.getFilterFilePath())
if err != nil {
errortext := fmt.Sprintf("Couldn't remove the filter file: %s", err)
http.Error(w, errortext, http.StatusInternalServerError)
return
}
}
}
// Update the configuration after removing filter files
config.Filters = newFilters
err = writeFilterFile()
if err != nil {
errortext := fmt.Sprintf("Couldn't write filter file: %s", err)
log.Println(errortext)
http.Error(w, errortext, http.StatusInternalServerError)
return
}
httpUpdateConfigReloadDNSReturnOK(w, r)
}
// TODO: Start using filter ID
func handleFilteringEnableURL(w http.ResponseWriter, r *http.Request) {
parameters, err := parseParametersFromBody(r.Body)
if err != nil {
@@ -542,16 +540,10 @@ func handleFilteringEnableURL(w http.ResponseWriter, r *http.Request) {
// kick off refresh of rules from new URLs
refreshFiltersIfNeccessary()
err = writeFilterFile()
if err != nil {
errortext := fmt.Sprintf("Couldn't write filter file: %s", err)
log.Println(errortext)
http.Error(w, errortext, http.StatusInternalServerError)
return
}
httpUpdateConfigReloadDNSReturnOK(w, r)
}
// TODO: Start using filter ID
func handleFilteringDisableURL(w http.ResponseWriter, r *http.Request) {
parameters, err := parseParametersFromBody(r.Body)
if err != nil {
@@ -586,13 +578,6 @@ func handleFilteringDisableURL(w http.ResponseWriter, r *http.Request) {
return
}
err = writeFilterFile()
if err != nil {
errortext := fmt.Sprintf("Couldn't write filter file: %s", err)
log.Println(errortext)
http.Error(w, errortext, http.StatusInternalServerError)
return
}
httpUpdateConfigReloadDNSReturnOK(w, r)
}
@@ -606,13 +591,6 @@ func handleFilteringSetRules(w http.ResponseWriter, r *http.Request) {
}
config.UserRules = strings.Split(string(body), "\n")
err = writeFilterFile()
if err != nil {
errortext := fmt.Sprintf("Couldn't write filter file: %s", err)
log.Println(errortext)
http.Error(w, errortext, http.StatusInternalServerError)
return
}
httpUpdateConfigReloadDNSReturnOK(w, r)
}
@@ -639,7 +617,6 @@ func runFilterRefreshers() {
}
func refreshFiltersIfNeccessary() int {
now := time.Now()
config.Lock()
// deduplicate
@@ -663,7 +640,7 @@ func refreshFiltersIfNeccessary() int {
updateCount := 0
for i := range config.Filters {
filter := &config.Filters[i] // otherwise we will be operating on a copy
updated, err := filter.update(now)
updated, err := filter.update(false)
if err != nil {
log.Printf("Failed to update filter %s: %s\n", filter.URL, err)
continue
@@ -675,27 +652,25 @@ func refreshFiltersIfNeccessary() int {
config.Unlock()
if updateCount > 0 {
err := writeFilterFile()
if err != nil {
errortext := fmt.Sprintf("Couldn't write filter file: %s", err)
log.Println(errortext)
}
tellCoreDNSToReload()
}
return updateCount
}
func (filter *filter) update(now time.Time) (bool, error) {
// Checks for filters updates
// If "force" is true -- does not check the filter's LastUpdated field
func (filter *filter) update(force bool) (bool, error) {
if !filter.Enabled {
return false, nil
}
elapsed := time.Since(filter.LastUpdated)
if elapsed <= updatePeriod {
if !force && time.Since(filter.LastUpdated) <= updatePeriod {
return false, nil
}
log.Printf("Downloading update for filter %d", filter.ID)
// use same update period for failed filter downloads to avoid flooding with requests
filter.LastUpdated = now
filter.LastUpdated = time.Now()
resp, err := client.Get(filter.URL)
if resp != nil && resp.Body != nil {
@@ -706,9 +681,15 @@ func (filter *filter) update(now time.Time) (bool, error) {
return false, err
}
if resp.StatusCode >= 400 {
if resp.StatusCode != 200 {
log.Printf("Got status code %d from URL %s, skipping", resp.StatusCode, filter.URL)
return false, fmt.Errorf("Got status code >= 400: %d", resp.StatusCode)
return false, fmt.Errorf("got status code != 200: %d", resp.StatusCode)
}
contentType := strings.ToLower(resp.Header.Get("content-type"))
if !strings.HasPrefix(contentType, "text/plain") {
log.Printf("Non-text response %s from %s, skipping", contentType, filter.URL)
return false, fmt.Errorf("non-text response %s", contentType)
}
body, err := ioutil.ReadAll(resp.Body)
@@ -717,11 +698,12 @@ func (filter *filter) update(now time.Time) (bool, error) {
return false, err
}
// extract filter name and count number of rules
// Extract filter name and count number of rules
lines := strings.Split(string(body), "\n")
rulesCount := 0
seenTitle := false
d := dnsfilter.New()
// Count lines in the filter
for _, line := range lines {
line = strings.TrimSpace(line)
if len(line) > 0 && line[0] == '!' {
@@ -730,61 +712,73 @@ func (filter *filter) update(now time.Time) (bool, error) {
seenTitle = true
}
} else if len(line) != 0 {
err = d.AddRule(line, 0)
if err == dnsfilter.ErrAlreadyExists || err == dnsfilter.ErrInvalidSyntax {
continue
}
if err != nil {
log.Printf("Cannot add rule %s from %s: %s", line, filter.URL, err)
// Just ignore invalid rules
continue
}
rulesCount++
}
}
// Check if the filter was really changed
if bytes.Equal(filter.contents, body) {
return false, nil
}
log.Printf("Filter %s updated: %d bytes, %d rules", filter.URL, len(body), rulesCount)
filter.RulesCount = rulesCount
filter.contents = body
// Saving it to the filters dir now
err = filter.save()
if err != nil {
return false, nil
}
return true, nil
}
// write filter file
func writeFilterFile() error {
filterpath := filepath.Join(config.ourBinaryDir, config.CoreDNS.FilterFile)
log.Printf("Writing filter file: %s", filterpath)
// TODO: check if file contents have modified
data := []byte{}
config.RLock()
filters := config.Filters
for _, filter := range filters {
if !filter.Enabled {
continue
}
data = append(data, filter.contents...)
data = append(data, '\n')
}
for _, rule := range config.UserRules {
data = append(data, []byte(rule)...)
data = append(data, '\n')
}
config.RUnlock()
err := ioutil.WriteFile(filterpath+".tmp", data, 0644)
// saves filter contents to the file in config.ourDataDir
func (filter *filter) save() error {
filterFilePath := filter.getFilterFilePath()
log.Printf("Saving filter %d contents to: %s", filter.ID, filterFilePath)
err := writeFileSafe(filterFilePath, filter.contents)
if err != nil {
log.Printf("Couldn't write filter file: %s", err)
return err
}
err = os.Rename(filterpath+".tmp", filterpath)
if err != nil {
log.Printf("Couldn't rename filter file: %s", err)
return nil;
}
// loads filter contents from the file in config.ourDataDir
func (filter *filter) load() error {
if !filter.Enabled {
// No need to load a filter that is not enabled
return nil
}
filterFilePath := filter.getFilterFilePath()
log.Printf("Loading filter %d contents to: %s", filter.ID, filterFilePath)
if _, err := os.Stat(filterFilePath); os.IsNotExist(err) {
// do nothing, file doesn't exist
return err
}
filterFile, err := ioutil.ReadFile(filterFilePath)
if err != nil {
return err
}
log.Printf("Filter %d length is %d", filter.ID, len(filterFile))
filter.contents = filterFile
return nil
}
// Path to the filter contents
func (filter *filter) getFilterFilePath() string {
return filepath.Join(config.ourBinaryDir, config.ourDataDir, FiltersDir, strconv.Itoa(filter.ID) + ".txt")
}
// ------------
// safebrowsing
// ------------