Filters are now saved to a file Also, they're loaded from the file on startup Filter ID is not passed to the CoreDNS plugin config (server-side AG DNS must be changed accordingly) Some minor refactoring, unused functions removed
This commit is contained in:
166
control.go
166
control.go
@@ -16,7 +16,6 @@ import (
|
||||
"time"
|
||||
|
||||
coredns_plugin "github.com/AdguardTeam/AdGuardHome/coredns_plugin"
|
||||
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
|
||||
"github.com/miekg/dns"
|
||||
"gopkg.in/asaskevich/govalidator.v4"
|
||||
)
|
||||
@@ -423,7 +422,7 @@ func handleFilteringAddURL(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
ok, err := filter.update(time.Now())
|
||||
ok, err := filter.update(true)
|
||||
if err != nil {
|
||||
errortext := fmt.Sprintf("Couldn't fetch filter from url %s: %s", filter.URL, err)
|
||||
log.Println(errortext)
|
||||
@@ -452,14 +451,9 @@ func handleFilteringAddURL(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, errortext, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
err = writeFilterFile()
|
||||
if err != nil {
|
||||
errortext := fmt.Sprintf("Couldn't write filter file: %s", err)
|
||||
log.Println(errortext)
|
||||
http.Error(w, errortext, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
tellCoreDNSToReload()
|
||||
|
||||
_, err = fmt.Fprintf(w, "OK %d rules\n", filter.RulesCount)
|
||||
if err != nil {
|
||||
errortext := fmt.Sprintf("Couldn't write body: %s", err)
|
||||
@@ -468,6 +462,7 @@ func handleFilteringAddURL(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Start using filter ID
|
||||
func handleFilteringRemoveURL(w http.ResponseWriter, r *http.Request) {
|
||||
parameters, err := parseParametersFromBody(r.Body)
|
||||
if err != nil {
|
||||
@@ -493,19 +488,22 @@ func handleFilteringRemoveURL(w http.ResponseWriter, r *http.Request) {
|
||||
for _, filter := range config.Filters {
|
||||
if filter.URL != url {
|
||||
newFilters = append(newFilters, filter)
|
||||
} else {
|
||||
// Remove the filter file
|
||||
err := os.Remove(filter.getFilterFilePath())
|
||||
if err != nil {
|
||||
errortext := fmt.Sprintf("Couldn't remove the filter file: %s", err)
|
||||
http.Error(w, errortext, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
// Update the configuration after removing filter files
|
||||
config.Filters = newFilters
|
||||
err = writeFilterFile()
|
||||
if err != nil {
|
||||
errortext := fmt.Sprintf("Couldn't write filter file: %s", err)
|
||||
log.Println(errortext)
|
||||
http.Error(w, errortext, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
httpUpdateConfigReloadDNSReturnOK(w, r)
|
||||
}
|
||||
|
||||
// TODO: Start using filter ID
|
||||
func handleFilteringEnableURL(w http.ResponseWriter, r *http.Request) {
|
||||
parameters, err := parseParametersFromBody(r.Body)
|
||||
if err != nil {
|
||||
@@ -542,16 +540,10 @@ func handleFilteringEnableURL(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// kick off refresh of rules from new URLs
|
||||
refreshFiltersIfNeccessary()
|
||||
err = writeFilterFile()
|
||||
if err != nil {
|
||||
errortext := fmt.Sprintf("Couldn't write filter file: %s", err)
|
||||
log.Println(errortext)
|
||||
http.Error(w, errortext, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
httpUpdateConfigReloadDNSReturnOK(w, r)
|
||||
}
|
||||
|
||||
// TODO: Start using filter ID
|
||||
func handleFilteringDisableURL(w http.ResponseWriter, r *http.Request) {
|
||||
parameters, err := parseParametersFromBody(r.Body)
|
||||
if err != nil {
|
||||
@@ -586,13 +578,6 @@ func handleFilteringDisableURL(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
err = writeFilterFile()
|
||||
if err != nil {
|
||||
errortext := fmt.Sprintf("Couldn't write filter file: %s", err)
|
||||
log.Println(errortext)
|
||||
http.Error(w, errortext, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
httpUpdateConfigReloadDNSReturnOK(w, r)
|
||||
}
|
||||
|
||||
@@ -606,13 +591,6 @@ func handleFilteringSetRules(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
config.UserRules = strings.Split(string(body), "\n")
|
||||
err = writeFilterFile()
|
||||
if err != nil {
|
||||
errortext := fmt.Sprintf("Couldn't write filter file: %s", err)
|
||||
log.Println(errortext)
|
||||
http.Error(w, errortext, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
httpUpdateConfigReloadDNSReturnOK(w, r)
|
||||
}
|
||||
|
||||
@@ -639,7 +617,6 @@ func runFilterRefreshers() {
|
||||
}
|
||||
|
||||
func refreshFiltersIfNeccessary() int {
|
||||
now := time.Now()
|
||||
config.Lock()
|
||||
|
||||
// deduplicate
|
||||
@@ -663,7 +640,7 @@ func refreshFiltersIfNeccessary() int {
|
||||
updateCount := 0
|
||||
for i := range config.Filters {
|
||||
filter := &config.Filters[i] // otherwise we will be operating on a copy
|
||||
updated, err := filter.update(now)
|
||||
updated, err := filter.update(false)
|
||||
if err != nil {
|
||||
log.Printf("Failed to update filter %s: %s\n", filter.URL, err)
|
||||
continue
|
||||
@@ -675,27 +652,25 @@ func refreshFiltersIfNeccessary() int {
|
||||
config.Unlock()
|
||||
|
||||
if updateCount > 0 {
|
||||
err := writeFilterFile()
|
||||
if err != nil {
|
||||
errortext := fmt.Sprintf("Couldn't write filter file: %s", err)
|
||||
log.Println(errortext)
|
||||
}
|
||||
tellCoreDNSToReload()
|
||||
}
|
||||
return updateCount
|
||||
}
|
||||
|
||||
func (filter *filter) update(now time.Time) (bool, error) {
|
||||
// Checks for filters updates
|
||||
// If "force" is true -- does not check the filter's LastUpdated field
|
||||
func (filter *filter) update(force bool) (bool, error) {
|
||||
if !filter.Enabled {
|
||||
return false, nil
|
||||
}
|
||||
elapsed := time.Since(filter.LastUpdated)
|
||||
if elapsed <= updatePeriod {
|
||||
if !force && time.Since(filter.LastUpdated) <= updatePeriod {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
log.Printf("Downloading update for filter %d", filter.ID)
|
||||
|
||||
// use same update period for failed filter downloads to avoid flooding with requests
|
||||
filter.LastUpdated = now
|
||||
filter.LastUpdated = time.Now()
|
||||
|
||||
resp, err := client.Get(filter.URL)
|
||||
if resp != nil && resp.Body != nil {
|
||||
@@ -706,9 +681,15 @@ func (filter *filter) update(now time.Time) (bool, error) {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
if resp.StatusCode != 200 {
|
||||
log.Printf("Got status code %d from URL %s, skipping", resp.StatusCode, filter.URL)
|
||||
return false, fmt.Errorf("Got status code >= 400: %d", resp.StatusCode)
|
||||
return false, fmt.Errorf("got status code != 200: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
contentType := strings.ToLower(resp.Header.Get("content-type"))
|
||||
if !strings.HasPrefix(contentType, "text/plain") {
|
||||
log.Printf("Non-text response %s from %s, skipping", contentType, filter.URL)
|
||||
return false, fmt.Errorf("non-text response %s", contentType)
|
||||
}
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
@@ -717,11 +698,12 @@ func (filter *filter) update(now time.Time) (bool, error) {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// extract filter name and count number of rules
|
||||
// Extract filter name and count number of rules
|
||||
lines := strings.Split(string(body), "\n")
|
||||
rulesCount := 0
|
||||
seenTitle := false
|
||||
d := dnsfilter.New()
|
||||
|
||||
// Count lines in the filter
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if len(line) > 0 && line[0] == '!' {
|
||||
@@ -730,61 +712,73 @@ func (filter *filter) update(now time.Time) (bool, error) {
|
||||
seenTitle = true
|
||||
}
|
||||
} else if len(line) != 0 {
|
||||
err = d.AddRule(line, 0)
|
||||
if err == dnsfilter.ErrAlreadyExists || err == dnsfilter.ErrInvalidSyntax {
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
log.Printf("Cannot add rule %s from %s: %s", line, filter.URL, err)
|
||||
// Just ignore invalid rules
|
||||
continue
|
||||
}
|
||||
rulesCount++
|
||||
}
|
||||
}
|
||||
|
||||
// Check if the filter was really changed
|
||||
if bytes.Equal(filter.contents, body) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
log.Printf("Filter %s updated: %d bytes, %d rules", filter.URL, len(body), rulesCount)
|
||||
filter.RulesCount = rulesCount
|
||||
filter.contents = body
|
||||
|
||||
// Saving it to the filters dir now
|
||||
err = filter.save()
|
||||
if err != nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// write filter file
|
||||
func writeFilterFile() error {
|
||||
filterpath := filepath.Join(config.ourBinaryDir, config.CoreDNS.FilterFile)
|
||||
log.Printf("Writing filter file: %s", filterpath)
|
||||
// TODO: check if file contents have modified
|
||||
data := []byte{}
|
||||
config.RLock()
|
||||
filters := config.Filters
|
||||
for _, filter := range filters {
|
||||
if !filter.Enabled {
|
||||
continue
|
||||
}
|
||||
data = append(data, filter.contents...)
|
||||
data = append(data, '\n')
|
||||
}
|
||||
for _, rule := range config.UserRules {
|
||||
data = append(data, []byte(rule)...)
|
||||
data = append(data, '\n')
|
||||
}
|
||||
config.RUnlock()
|
||||
err := ioutil.WriteFile(filterpath+".tmp", data, 0644)
|
||||
// saves filter contents to the file in config.ourDataDir
|
||||
func (filter *filter) save() error {
|
||||
|
||||
filterFilePath := filter.getFilterFilePath()
|
||||
log.Printf("Saving filter %d contents to: %s", filter.ID, filterFilePath)
|
||||
|
||||
err := writeFileSafe(filterFilePath, filter.contents)
|
||||
if err != nil {
|
||||
log.Printf("Couldn't write filter file: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
err = os.Rename(filterpath+".tmp", filterpath)
|
||||
if err != nil {
|
||||
log.Printf("Couldn't rename filter file: %s", err)
|
||||
return nil;
|
||||
}
|
||||
|
||||
// loads filter contents from the file in config.ourDataDir
|
||||
func (filter *filter) load() error {
|
||||
|
||||
if !filter.Enabled {
|
||||
// No need to load a filter that is not enabled
|
||||
return nil
|
||||
}
|
||||
|
||||
filterFilePath := filter.getFilterFilePath()
|
||||
log.Printf("Loading filter %d contents to: %s", filter.ID, filterFilePath)
|
||||
|
||||
if _, err := os.Stat(filterFilePath); os.IsNotExist(err) {
|
||||
// do nothing, file doesn't exist
|
||||
return err
|
||||
}
|
||||
|
||||
filterFile, err := ioutil.ReadFile(filterFilePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("Filter %d length is %d", filter.ID, len(filterFile))
|
||||
filter.contents = filterFile
|
||||
return nil
|
||||
}
|
||||
|
||||
// Path to the filter contents
|
||||
func (filter *filter) getFilterFilePath() string {
|
||||
return filepath.Join(config.ourBinaryDir, config.ourDataDir, FiltersDir, strconv.Itoa(filter.ID) + ".txt")
|
||||
}
|
||||
|
||||
// ------------
|
||||
// safebrowsing
|
||||
// ------------
|
||||
|
||||
Reference in New Issue
Block a user