1. Added --workdir command-line argument that lets configure the working dir.
2. Made "dnsforward" use this workdir parameter when saving/reading querylog.
3. Reworked "dnsforward" -- moved http handlers out of there to control.go
This commit is contained in:
Andrey Meshkov
2019-02-10 20:47:43 +03:00
parent 6b6eacaa2b
commit 9a03190a62
15 changed files with 630 additions and 418 deletions

View File

@@ -1,6 +1,7 @@
package main
import (
"bytes"
"context"
"encoding/json"
"fmt"
@@ -8,6 +9,7 @@ import (
"net"
"net/http"
"os"
"sort"
"strconv"
"strings"
"time"
@@ -32,9 +34,28 @@ var client = &http.Client{
Timeout: time.Second * 30,
}
// -------------------
// ----------------
// helper functions
// ----------------
func returnOK(w http.ResponseWriter) {
_, err := fmt.Fprintf(w, "OK\n")
if err != nil {
errorText := fmt.Sprintf("Couldn't write body: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
}
}
func httpError(w http.ResponseWriter, code int, format string, args ...interface{}) {
text := fmt.Sprintf(format, args...)
log.Println(text)
http.Error(w, text, code)
}
// ---------------
// dns run control
// -------------------
// ---------------
func writeAllConfigsAndReloadDNS() error {
err := writeAllConfigs()
if err != nil {
@@ -55,15 +76,6 @@ func httpUpdateConfigReloadDNSReturnOK(w http.ResponseWriter, r *http.Request) {
returnOK(w)
}
func returnOK(w http.ResponseWriter) {
_, err := fmt.Fprintf(w, "OK\n")
if err != nil {
errorText := fmt.Sprintf("Couldn't write body: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
}
}
func handleStatus(w http.ResponseWriter, r *http.Request) {
data := map[string]interface{}{
"dns_address": config.BindHost,
@@ -117,12 +129,190 @@ func handleQueryLogDisable(w http.ResponseWriter, r *http.Request) {
httpUpdateConfigReloadDNSReturnOK(w, r)
}
func httpError(w http.ResponseWriter, code int, format string, args ...interface{}) {
text := fmt.Sprintf(format, args...)
log.Println(text)
http.Error(w, text, code)
func handleQueryLog(w http.ResponseWriter, r *http.Request) {
data := dnsServer.GetQueryLog()
jsonVal, err := json.Marshal(data)
if err != nil {
errorText := fmt.Sprintf("Couldn't marshal data into json: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal)
if err != nil {
errorText := fmt.Sprintf("Unable to write response json: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
}
}
func handleStatsTop(w http.ResponseWriter, r *http.Request) {
s := dnsServer.GetStatsTop()
// use manual json marshalling because we want maps to be sorted by value
statsJSON := bytes.Buffer{}
statsJSON.WriteString("{\n")
gen := func(json *bytes.Buffer, name string, top map[string]int, addComma bool) {
json.WriteString(" ")
json.WriteString(fmt.Sprintf("%q", name))
json.WriteString(": {\n")
sorted := sortByValue(top)
// no more than 50 entries
if len(sorted) > 50 {
sorted = sorted[:50]
}
for i, key := range sorted {
json.WriteString(" ")
json.WriteString(fmt.Sprintf("%q", key))
json.WriteString(": ")
json.WriteString(strconv.Itoa(top[key]))
if i+1 != len(sorted) {
json.WriteByte(',')
}
json.WriteByte('\n')
}
json.WriteString(" }")
if addComma {
json.WriteByte(',')
}
json.WriteByte('\n')
}
gen(&statsJSON, "top_queried_domains", s.Domains, true)
gen(&statsJSON, "top_blocked_domains", s.Blocked, true)
gen(&statsJSON, "top_clients", s.Clients, true)
statsJSON.WriteString(" \"stats_period\": \"24 hours\"\n")
statsJSON.WriteString("}\n")
w.Header().Set("Content-Type", "application/json")
_, err := w.Write(statsJSON.Bytes())
if err != nil {
errorText := fmt.Sprintf("Couldn't write body: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
}
}
// handleStatsReset resets the stats caches
func handleStatsReset(w http.ResponseWriter, r *http.Request) {
dnsServer.ResetStats()
_, err := fmt.Fprintf(w, "OK\n")
if err != nil {
errorText := fmt.Sprintf("Couldn't write body: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
}
}
// handleStats returns aggregated stats data for the 24 hours
func handleStats(w http.ResponseWriter, r *http.Request) {
summed := dnsServer.GetAggregatedStats()
statsJSON, err := json.Marshal(summed)
if err != nil {
errorText := fmt.Sprintf("Unable to marshal status json: %s", err)
log.Println(errorText)
http.Error(w, errorText, 500)
return
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(statsJSON)
if err != nil {
errorText := fmt.Sprintf("Unable to write response json: %s", err)
log.Println(errorText)
http.Error(w, errorText, 500)
return
}
}
// HandleStatsHistory returns historical stats data for the 24 hours
func handleStatsHistory(w http.ResponseWriter, r *http.Request) {
// handle time unit and prepare our time window size
timeUnitString := r.URL.Query().Get("time_unit")
var timeUnit time.Duration
switch timeUnitString {
case "seconds":
timeUnit = time.Second
case "minutes":
timeUnit = time.Minute
case "hours":
timeUnit = time.Hour
case "days":
timeUnit = time.Hour * 24
default:
http.Error(w, "Must specify valid time_unit parameter", http.StatusBadRequest)
return
}
// parse start and end time
startTime, err := time.Parse(time.RFC3339, r.URL.Query().Get("start_time"))
if err != nil {
errorText := fmt.Sprintf("Must specify valid start_time parameter: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusBadRequest)
return
}
endTime, err := time.Parse(time.RFC3339, r.URL.Query().Get("end_time"))
if err != nil {
errorText := fmt.Sprintf("Must specify valid end_time parameter: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusBadRequest)
return
}
data, err := dnsServer.GetStatsHistory(timeUnit, startTime, endTime)
if err != nil {
errorText := fmt.Sprintf("Cannot get stats history: %s", err)
http.Error(w, errorText, http.StatusBadRequest)
return
}
statsJSON, err := json.Marshal(data)
if err != nil {
errorText := fmt.Sprintf("Unable to marshal status json: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(statsJSON)
if err != nil {
errorText := fmt.Sprintf("Unable to write response json: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
return
}
}
// sortByValue is a helper function for querylog API
func sortByValue(m map[string]int) []string {
type kv struct {
k string
v int
}
var ss []kv
for k, v := range m {
ss = append(ss, kv{k, v})
}
sort.Slice(ss, func(l, r int) bool {
return ss[l].v > ss[r].v
})
sorted := []string{}
for _, v := range ss {
sorted = append(sorted, v.k)
}
return sorted
}
// -----------------------
// upstreams configuration
// -----------------------
func handleSetUpstreamDNS(w http.ResponseWriter, r *http.Request) {
body, err := ioutil.ReadAll(r.Body)
if err != nil {
@@ -737,8 +927,8 @@ func handleInstallGetAddresses(w http.ResponseWriter, r *http.Request) {
data.Interfaces = make(map[string]interface{})
for _, iface := range ifaces {
addrs, err := iface.Addrs()
if err != nil {
addrs, e := iface.Addrs()
if e != nil {
httpError(w, http.StatusInternalServerError, "Failed to get addresses for interface %s: %s", iface.Name, err)
return
}
@@ -844,17 +1034,17 @@ func registerControlHandlers() {
http.HandleFunc("/control/status", postInstall(optionalAuth(ensureGET(handleStatus))))
http.HandleFunc("/control/enable_protection", postInstall(optionalAuth(ensurePOST(handleProtectionEnable))))
http.HandleFunc("/control/disable_protection", postInstall(optionalAuth(ensurePOST(handleProtectionDisable))))
http.HandleFunc("/control/querylog", postInstall(optionalAuth(ensureGET(dnsforward.HandleQueryLog))))
http.HandleFunc("/control/querylog", postInstall(optionalAuth(ensureGET(handleQueryLog))))
http.HandleFunc("/control/querylog_enable", postInstall(optionalAuth(ensurePOST(handleQueryLogEnable))))
http.HandleFunc("/control/querylog_disable", postInstall(optionalAuth(ensurePOST(handleQueryLogDisable))))
http.HandleFunc("/control/set_upstream_dns", postInstall(optionalAuth(ensurePOST(handleSetUpstreamDNS))))
http.HandleFunc("/control/test_upstream_dns", postInstall(optionalAuth(ensurePOST(handleTestUpstreamDNS))))
http.HandleFunc("/control/i18n/change_language", postInstall(optionalAuth(ensurePOST(handleI18nChangeLanguage))))
http.HandleFunc("/control/i18n/current_language", postInstall(optionalAuth(ensureGET(handleI18nCurrentLanguage))))
http.HandleFunc("/control/stats_top", postInstall(optionalAuth(ensureGET(dnsforward.HandleStatsTop))))
http.HandleFunc("/control/stats", postInstall(optionalAuth(ensureGET(dnsforward.HandleStats))))
http.HandleFunc("/control/stats_history", postInstall(optionalAuth(ensureGET(dnsforward.HandleStatsHistory))))
http.HandleFunc("/control/stats_reset", postInstall(optionalAuth(ensurePOST(dnsforward.HandleStatsReset))))
http.HandleFunc("/control/stats_top", postInstall(optionalAuth(ensureGET(handleStatsTop))))
http.HandleFunc("/control/stats", postInstall(optionalAuth(ensureGET(handleStats))))
http.HandleFunc("/control/stats_history", postInstall(optionalAuth(ensureGET(handleStatsHistory))))
http.HandleFunc("/control/stats_reset", postInstall(optionalAuth(ensurePOST(handleStatsReset))))
http.HandleFunc("/control/version.json", postInstall(optionalAuth(handleGetVersionJSON)))
http.HandleFunc("/control/filtering/enable", postInstall(optionalAuth(ensurePOST(handleFilteringEnable))))
http.HandleFunc("/control/filtering/disable", postInstall(optionalAuth(ensurePOST(handleFilteringDisable))))