Pull request: * all: move internal Go packages to internal/

Merge in DNS/adguard-home from 2234-move-to-internal to master

Squashed commit of the following:

commit d26a288cabeac86f9483fab307677b1027c78524
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Fri Oct 30 12:44:18 2020 +0300

    * all: move internal Go packages to internal/

    Closes #2234.
This commit is contained in:
Ainar Garipov
2020-10-30 13:32:02 +03:00
parent df3fa595a2
commit ae8de95d89
125 changed files with 85 additions and 85 deletions

internal/querylog/decode.go

@@ -0,0 +1,180 @@
package querylog
import (
"encoding/base64"
"strconv"
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/dnsfilter"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
)
// decodeLogEntry - decodes a query log entry from a JSON line
// nolint(gocyclo)
func decodeLogEntry(ent *logEntry, str string) {
var b bool
var i int
var err error
for {
k, v, t := readJSON(&str)
if t == jsonTErr {
break
}
switch k {
case "IP":
if len(ent.IP) == 0 {
ent.IP = v
}
case "T":
ent.Time, err = time.Parse(time.RFC3339, v)
case "QH":
ent.QHost = v
case "QT":
ent.QType = v
case "QC":
ent.QClass = v
case "CP":
ent.ClientProto = v
case "Answer":
ent.Answer, err = base64.StdEncoding.DecodeString(v)
case "OrigAnswer":
ent.OrigAnswer, err = base64.StdEncoding.DecodeString(v)
case "IsFiltered":
b, err = strconv.ParseBool(v)
ent.Result.IsFiltered = b
case "Rule":
ent.Result.Rule = v
case "FilterID":
i, err = strconv.Atoi(v)
ent.Result.FilterID = int64(i)
case "Reason":
i, err = strconv.Atoi(v)
ent.Result.Reason = dnsfilter.Reason(i)
case "ServiceName":
ent.Result.ServiceName = v
case "Upstream":
ent.Upstream = v
case "Elapsed":
i, err = strconv.Atoi(v)
ent.Elapsed = time.Duration(i)
// pre-v0.99.3 compatibility:
case "Question":
var qstr []byte
qstr, err = base64.StdEncoding.DecodeString(v)
if err != nil {
break
}
q := new(dns.Msg)
err = q.Unpack(qstr)
if err != nil {
break
}
ent.QHost = q.Question[0].Name
if len(ent.QHost) == 0 {
break
}
ent.QHost = ent.QHost[:len(ent.QHost)-1]
ent.QType = dns.TypeToString[q.Question[0].Qtype]
ent.QClass = dns.ClassToString[q.Question[0].Qclass]
case "Time":
ent.Time, err = time.Parse(time.RFC3339, v)
}
if err != nil {
log.Debug("decodeLogEntry err: %s", err)
break
}
}
}
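// For illustration, a minimal in-package sketch of decoding a single line
// (sample values are hypothetical; "fmt" would need to be imported):
//
//	func ExampleDecodeLogEntry() {
//		line := `{"IP":"127.0.0.1","T":"2020-10-30T12:44:18+03:00","QH":"example.org","QT":"A","QC":"IN","Elapsed":1234567}`
//		ent := logEntry{}
//		decodeLogEntry(&ent, line)
//		fmt.Println(ent.QHost, ent.QType, ent.Elapsed)
//		// Output: example.org A 1.234567ms
//	}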
// Get value from "key":"value"
func readJSONValue(s, name string) string {
i := strings.Index(s, "\""+name+"\":\"")
if i == -1 {
return ""
}
start := i + 1 + len(name) + 3
i = strings.IndexByte(s[start:], '"')
if i == -1 {
return ""
}
end := start + i
return s[start:end]
}
const (
jsonTErr = iota
jsonTObj
jsonTStr
jsonTNum
jsonTBool
)
// Parse JSON key-value pair
// e.g.: "key":VALUE where VALUE is "string", true|false (boolean), or 123.456 (number)
// Note the limitations:
// . doesn't support whitespace
// . doesn't support "null"
// . doesn't validate boolean or number
// . no proper handling of {} braces
// . no handling of [] brackets
// Return (key, value, type)
func readJSON(ps *string) (string, string, int32) {
s := *ps
k := ""
v := ""
t := int32(jsonTErr)
q1 := strings.IndexByte(s, '"')
if q1 == -1 {
return k, v, t
}
q2 := strings.IndexByte(s[q1+1:], '"')
if q2 == -1 {
return k, v, t
}
k = s[q1+1 : q1+1+q2]
s = s[q1+1+q2+1:]
if len(s) < 2 || s[0] != ':' {
return k, v, t
}
if s[1] == '"' {
q2 = strings.IndexByte(s[2:], '"')
if q2 == -1 {
return k, v, t
}
v = s[2 : 2+q2]
t = jsonTStr
s = s[2+q2+1:]
} else if s[1] == '{' {
t = jsonTObj
s = s[1+1:]
} else {
sep := strings.IndexAny(s[1:], ",}")
if sep == -1 {
return k, v, t
}
v = s[1 : 1+sep]
if s[1] == 't' || s[1] == 'f' {
t = jsonTBool
} else if s[1] == '.' || (s[1] >= '0' && s[1] <= '9') {
t = jsonTNum
}
s = s[1+sep+1:]
}
*ps = s
return k, v, t
}


@@ -0,0 +1,34 @@
package querylog
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestJSON(t *testing.T) {
s := `
{"keystr":"val","obj":{"keybool":true,"keyint":123456}}
`
k, v, jtype := readJSON(&s)
assert.Equal(t, jtype, int32(jsonTStr))
assert.Equal(t, "keystr", k)
assert.Equal(t, "val", v)
k, v, jtype = readJSON(&s)
assert.Equal(t, jtype, int32(jsonTObj))
assert.Equal(t, "obj", k)
k, v, jtype = readJSON(&s)
assert.Equal(t, jtype, int32(jsonTBool))
assert.Equal(t, "keybool", k)
assert.Equal(t, "true", v)
k, v, jtype = readJSON(&s)
assert.Equal(t, jtype, int32(jsonTNum))
assert.Equal(t, "keyint", k)
assert.Equal(t, "123456", v)
k, v, jtype = readJSON(&s)
assert.True(t, jtype == jsonTErr)
}

internal/querylog/json.go

@@ -0,0 +1,167 @@
package querylog
import (
"fmt"
"net"
"strconv"
"time"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
)
// getClientIP returns the client IP address, anonymized if configured
func (l *queryLog) getClientIP(clientIP string) string {
if l.conf.AnonymizeClientIP {
ip := net.ParseIP(clientIP)
if ip != nil {
ip4 := ip.To4()
const AnonymizeClientIP4Mask = 16
const AnonymizeClientIP6Mask = 112
if ip4 != nil {
clientIP = ip4.Mask(net.CIDRMask(AnonymizeClientIP4Mask, 32)).String()
} else {
clientIP = ip.Mask(net.CIDRMask(AnonymizeClientIP6Mask, 128)).String()
}
}
}
return clientIP
}
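// For example (hypothetical addresses), with AnonymizeClientIP enabled:
//
//	"192.168.10.20"       -> "192.168.0.0"      (IPv4, /16 mask)
//	"2001:db8::1234:5678" -> "2001:db8::1234:0" (IPv6, /112 mask)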
// entriesToJSON - converts log entries to JSON
func (l *queryLog) entriesToJSON(entries []*logEntry, oldest time.Time) map[string]interface{} {
// init the response object
var data = []map[string]interface{}{}
// the elements order is already reversed (from newer to older)
for i := 0; i < len(entries); i++ {
entry := entries[i]
jsonEntry := l.logEntryToJSONEntry(entry)
data = append(data, jsonEntry)
}
var result = map[string]interface{}{}
result["oldest"] = ""
if !oldest.IsZero() {
result["oldest"] = oldest.Format(time.RFC3339Nano)
}
result["data"] = data
return result
}
func (l *queryLog) logEntryToJSONEntry(entry *logEntry) map[string]interface{} {
var msg *dns.Msg
if len(entry.Answer) > 0 {
msg = new(dns.Msg)
if err := msg.Unpack(entry.Answer); err != nil {
log.Debug("Failed to unpack dns message answer: %s: %s", err, string(entry.Answer))
msg = nil
}
}
jsonEntry := map[string]interface{}{
"reason": entry.Result.Reason.String(),
"elapsedMs": strconv.FormatFloat(entry.Elapsed.Seconds()*1000, 'f', -1, 64),
"time": entry.Time.Format(time.RFC3339Nano),
"client": l.getClientIP(entry.IP),
"client_proto": entry.ClientProto,
}
jsonEntry["question"] = map[string]interface{}{
"host": entry.QHost,
"type": entry.QType,
"class": entry.QClass,
}
if msg != nil {
jsonEntry["status"] = dns.RcodeToString[msg.Rcode]
opt := msg.IsEdns0()
dnssecOk := false
if opt != nil {
dnssecOk = opt.Do()
}
jsonEntry["answer_dnssec"] = dnssecOk
}
if len(entry.Result.Rule) > 0 {
jsonEntry["rule"] = entry.Result.Rule
jsonEntry["filterId"] = entry.Result.FilterID
}
if len(entry.Result.ServiceName) != 0 {
jsonEntry["service_name"] = entry.Result.ServiceName
}
answers := answerToMap(msg)
if answers != nil {
jsonEntry["answer"] = answers
}
if len(entry.OrigAnswer) != 0 {
a := new(dns.Msg)
err := a.Unpack(entry.OrigAnswer)
if err == nil {
answers = answerToMap(a)
if answers != nil {
jsonEntry["original_answer"] = answers
}
} else {
log.Debug("Querylog: msg.Unpack(entry.OrigAnswer): %s: %s", err, string(entry.OrigAnswer))
}
}
jsonEntry["upstream"] = entry.Upstream
return jsonEntry
}
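// For reference, a single serialized entry looks roughly like this
// (all values are hypothetical):
//
//	{
//	  "reason": "NotFilteredNotFound",
//	  "elapsedMs": "0.42",
//	  "time": "2020-10-30T12:44:18+03:00",
//	  "client": "127.0.0.1",
//	  "client_proto": "",
//	  "question": {"host": "example.org", "type": "A", "class": "IN"},
//	  "status": "NOERROR",
//	  "answer_dnssec": false,
//	  "answer": [{"type": "A", "ttl": 10, "value": "93.184.216.34"}],
//	  "upstream": "8.8.8.8:53"
//	}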
func answerToMap(a *dns.Msg) []map[string]interface{} {
if a == nil || len(a.Answer) == 0 {
return nil
}
var answers = []map[string]interface{}{}
for _, k := range a.Answer {
header := k.Header()
answer := map[string]interface{}{
"type": dns.TypeToString[header.Rrtype],
"ttl": header.Ttl,
}
// try most common record types
switch v := k.(type) {
case *dns.A:
answer["value"] = v.A.String()
case *dns.AAAA:
answer["value"] = v.AAAA.String()
case *dns.MX:
answer["value"] = fmt.Sprintf("%v %v", v.Preference, v.Mx)
case *dns.CNAME:
answer["value"] = v.Target
case *dns.NS:
answer["value"] = v.Ns
case *dns.SPF:
answer["value"] = v.Txt
case *dns.TXT:
answer["value"] = v.Txt
case *dns.PTR:
answer["value"] = v.Ptr
case *dns.SOA:
answer["value"] = fmt.Sprintf("%v %v %v %v %v %v %v", v.Ns, v.Mbox, v.Serial, v.Refresh, v.Retry, v.Expire, v.Minttl)
case *dns.CAA:
answer["value"] = fmt.Sprintf("%v %v \"%v\"", v.Flag, v.Tag, v.Value)
case *dns.HINFO:
answer["value"] = fmt.Sprintf("\"%v\" \"%v\"", v.Cpu, v.Os)
case *dns.RRSIG:
answer["value"] = fmt.Sprintf("%v %v %v %v %v %v %v %v %v", dns.TypeToString[v.TypeCovered], v.Algorithm, v.Labels, v.OrigTtl, v.Expiration, v.Inception, v.KeyTag, v.SignerName, v.Signature)
default:
// type unknown, marshal it as-is
answer["value"] = v
}
answers = append(answers, answer)
}
return answers
}

internal/querylog/qlog.go

@@ -0,0 +1,176 @@
package querylog
import (
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/dnsfilter"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
)
const (
queryLogFileName = "querylog.json" // .gz added during compression
)
// queryLog is a structure that writes and reads the DNS query log
type queryLog struct {
conf *Config
lock sync.Mutex
logFile string // path to the log file
bufferLock sync.RWMutex
buffer []*logEntry
fileFlushLock sync.Mutex // synchronize a file-flushing goroutine and main thread
flushPending bool // don't start another goroutine while the previous one is still running
fileWriteLock sync.Mutex
}
// logEntry - represents a single log entry
type logEntry struct {
IP string `json:"IP"` // Client IP
Time time.Time `json:"T"`
QHost string `json:"QH"`
QType string `json:"QT"`
QClass string `json:"QC"`
ClientProto string `json:"CP"` // "" or "doh"
Answer []byte `json:",omitempty"` // sometimes empty answers happen, e.g. for binerdunt.top or rev2.globalrootservers.net
OrigAnswer []byte `json:",omitempty"`
Result dnsfilter.Result
Elapsed time.Duration
Upstream string `json:",omitempty"` // if empty, means it was cached
}
// create a new instance of the query log
func newQueryLog(conf Config) *queryLog {
l := queryLog{}
l.logFile = filepath.Join(conf.BaseDir, queryLogFileName)
l.conf = &Config{}
*l.conf = conf
if !checkInterval(l.conf.Interval) {
l.conf.Interval = 1
}
return &l
}
func (l *queryLog) Start() {
if l.conf.HTTPRegister != nil {
l.initWeb()
}
go l.periodicRotate()
}
func (l *queryLog) Close() {
_ = l.flushLogBuffer(true)
}
func checkInterval(days uint32) bool {
return days == 1 || days == 7 || days == 30 || days == 90
}
func (l *queryLog) WriteDiskConfig(c *Config) {
*c = *l.conf
}
// Clear memory buffer and remove log files
func (l *queryLog) clear() {
l.fileFlushLock.Lock()
defer l.fileFlushLock.Unlock()
l.bufferLock.Lock()
l.buffer = nil
l.flushPending = false
l.bufferLock.Unlock()
err := os.Remove(l.logFile + ".1")
if err != nil && !os.IsNotExist(err) {
log.Error("file remove: %s: %s", l.logFile+".1", err)
}
err = os.Remove(l.logFile)
if err != nil && !os.IsNotExist(err) {
log.Error("file remove: %s: %s", l.logFile, err)
}
log.Debug("Query log: cleared")
}
func (l *queryLog) Add(params AddParams) {
if !l.conf.Enabled {
return
}
if params.Question == nil || len(params.Question.Question) != 1 || len(params.Question.Question[0].Name) == 0 ||
params.ClientIP == nil {
return
}
if params.Result == nil {
params.Result = &dnsfilter.Result{}
}
now := time.Now()
entry := logEntry{
IP: l.getClientIP(params.ClientIP.String()),
Time: now,
Result: *params.Result,
Elapsed: params.Elapsed,
Upstream: params.Upstream,
ClientProto: params.ClientProto,
}
q := params.Question.Question[0]
entry.QHost = strings.ToLower(q.Name[:len(q.Name)-1]) // remove the last dot
entry.QType = dns.Type(q.Qtype).String()
entry.QClass = dns.Class(q.Qclass).String()
if params.Answer != nil {
a, err := params.Answer.Pack()
if err != nil {
log.Info("Querylog: Answer.Pack(): %s", err)
return
}
entry.Answer = a
}
if params.OrigAnswer != nil {
a, err := params.OrigAnswer.Pack()
if err != nil {
log.Info("Querylog: OrigAnswer.Pack(): %s", err)
return
}
entry.OrigAnswer = a
}
l.bufferLock.Lock()
l.buffer = append(l.buffer, &entry)
needFlush := false
if !l.conf.FileEnabled {
if len(l.buffer) > int(l.conf.MemSize) {
// writing to file is disabled - just remove the oldest entry from array
l.buffer = l.buffer[1:]
}
} else if !l.flushPending {
needFlush = len(l.buffer) >= int(l.conf.MemSize)
if needFlush {
l.flushPending = true
}
}
l.bufferLock.Unlock()
// if buffer needs to be flushed to disk, do it now
if needFlush {
go func() {
_ = l.flushLogBuffer(false)
}()
}
}


@@ -0,0 +1,336 @@
package querylog
import (
"errors"
"io"
"os"
"sync"
"time"
"github.com/AdguardTeam/golibs/log"
)
// ErrSeekNotFound is returned from the Seek method
// if we failed to find the desired record
var ErrSeekNotFound = errors.New("seek: record not found")
// TODO: Find a way to grow buffer instead of relying on this value when reading strings
const maxEntrySize = 16 * 1024
// bufferSize should be large enough to fit at least 100 max-size entries
const bufferSize = 100 * maxEntrySize
// QLogFile represents a single query log file
// It allows reading from the file in the reverse order
//
// Please note that this is a stateful object.
// Internally, it contains a pointer to a specific position in the file,
// and it reads lines in reverse order starting from that position.
type QLogFile struct {
file *os.File // the query log file
position int64 // current position in the file
buffer []byte // buffer that we've read from the file
bufferStart int64 // start of the buffer (in the file)
bufferLen int // buffer len
lock sync.Mutex // We use mutex to make it thread-safe
}
// NewQLogFile initializes a new instance of the QLogFile
func NewQLogFile(path string) (*QLogFile, error) {
f, err := os.OpenFile(path, os.O_RDONLY, 0644)
if err != nil {
return nil, err
}
return &QLogFile{
file: f,
}, nil
}
// Seek performs binary search in the query log file looking for a record
// with the specified timestamp. Once the record is found, it sets
// "position" so that the next ReadNext call returns that record.
//
// The algorithm is rather simple:
// 1. It starts with the position in the middle of the file
// 2. Shifts back to the beginning of the line
// 3. Checks the log record timestamp
// 4. If it is lower than the timestamp we are looking for,
// it shifts the seek position to 3/4 of the file. Otherwise, to 1/4 of the file.
// 5. It performs the search again; each iteration halves the search scope.
//
// Returns:
// * the position of the line with the timestamp we were looking for,
// so that the next "ReadNext" call returns that line.
// * the depth of the search (how many times we compared timestamps).
// * ErrSeekNotFound if the record could not be found.
func (q *QLogFile) Seek(timestamp int64) (int64, int, error) {
q.lock.Lock()
defer q.lock.Unlock()
// Empty the buffer
q.buffer = nil
// First of all, check the file size
fileInfo, err := q.file.Stat()
if err != nil {
return 0, 0, err
}
// Define the search scope
start := int64(0) // start of the search interval (position in the file)
end := fileInfo.Size() // end of the search interval (position in the file)
probe := (end - start) / 2 // probe -- approximate index of the line we'll try to check
var line string
var lineIdx int64 // index of the probe line in the file
var lineEndIdx int64
var lastProbeLineIdx int64 // index of the last probe line
lastProbeLineIdx = -1
// Count seek depth in order to detect mistakes
// If depth is too large, we should stop the search
depth := 0
for {
// Get the line at the specified position
line, lineIdx, lineEndIdx, err = q.readProbeLine(probe)
if err != nil {
return 0, depth, err
}
if lineIdx < start || lineEndIdx > end || lineIdx == lastProbeLineIdx {
// If we're testing the same line twice then most likely
// the scope is too narrow and we won't find anything anymore
log.Error("querylog: didn't find timestamp:%v", timestamp)
return 0, depth, ErrSeekNotFound
}
// Save the last found idx
lastProbeLineIdx = lineIdx
// Get the timestamp from the query log record
ts := readQLogTimestamp(line)
if ts == 0 {
return 0, depth, ErrSeekNotFound
}
if ts == timestamp {
// Hurray, returning the result
break
}
// Narrow the scope and repeat the search
if ts > timestamp {
// If the timestamp we're looking for is OLDER than what we found
// Then the line is somewhere on the LEFT side from the current probe position
end = lineIdx
} else {
// If the timestamp we're looking for is NEWER than what we found
// Then the line is somewhere on the RIGHT side from the current probe position
start = lineEndIdx
}
probe = start + (end-start)/2
depth++
if depth >= 100 {
log.Error("Seek depth is too high, aborting. File %s, ts %v", q.file.Name(), timestamp)
return 0, depth, ErrSeekNotFound
}
}
q.position = lineIdx + int64(len(line))
return q.position, depth, nil
}
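// A minimal usage sketch (error handling elided; the path and ts are hypothetical):
//
//	q, _ := NewQLogFile("/path/to/querylog.json")
//	defer q.Close()
//	if _, _, err := q.Seek(ts); err == nil {
//		line, _ := q.ReadNext() // the record whose timestamp equals ts
//	}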
// SeekStart changes the current position to the end of the file.
// Please note that we read the query log in reverse order,
// which is why the log start is actually the end of the file.
//
// Returns nil if we were able to change the current position.
// Returns error in any other case.
func (q *QLogFile) SeekStart() (int64, error) {
q.lock.Lock()
defer q.lock.Unlock()
// Empty the buffer
q.buffer = nil
// First of all, check the file size
fileInfo, err := q.file.Stat()
if err != nil {
return 0, err
}
// Place the position at the very end of the file
q.position = fileInfo.Size() - 1
if q.position < 0 {
q.position = 0
}
return q.position, nil
}
// ReadNext reads the next line (in reverse order) from the file
// and shifts the current position left to the next (actually previous) line.
// Returns io.EOF when there is nothing more to read.
func (q *QLogFile) ReadNext() (string, error) {
q.lock.Lock()
defer q.lock.Unlock()
if q.position == 0 {
return "", io.EOF
}
line, lineIdx, err := q.readNextLine(q.position)
if err != nil {
return "", err
}
// Shift position
if lineIdx == 0 {
q.position = 0
} else {
// there's usually a line break before the line,
// so shift one more character left past it:
// line\nline
q.position = lineIdx - 1
}
return line, err
}
// Close frees the underlying resources
func (q *QLogFile) Close() error {
return q.file.Close()
}
// readNextLine reads the next line from the specified position.
// This line actually has to end at that position.
//
// the algorithm is:
// 1. check if we have the buffer initialized
// 2. if it is, scan it and look for the line there
// 3. if we cannot find the line there, read the prev chunk into the buffer
// 4. read the line from the buffer
func (q *QLogFile) readNextLine(position int64) (string, int64, error) {
relativePos := position - q.bufferStart
if q.buffer == nil || (relativePos < maxEntrySize && q.bufferStart != 0) {
// Time to re-init the buffer
err := q.initBuffer(position)
if err != nil {
return "", 0, err
}
relativePos = position - q.bufferStart
}
// Look for the end of the prev line
// This is where we'll read from
var startLine = int64(0)
for i := relativePos - 1; i >= 0; i-- {
if q.buffer[i] == '\n' {
startLine = i + 1
break
}
}
line := string(q.buffer[startLine:relativePos])
lineIdx := q.bufferStart + startLine
return line, lineIdx, nil
}
// initBuffer initializes the QLogFile buffer.
// The goal is to read a chunk of the file that includes the line at the specified position.
func (q *QLogFile) initBuffer(position int64) error {
q.bufferStart = int64(0)
if (position - bufferSize) > 0 {
q.bufferStart = position - bufferSize
}
// Seek to this position
_, err := q.file.Seek(q.bufferStart, io.SeekStart)
if err != nil {
return err
}
if q.buffer == nil {
q.buffer = make([]byte, bufferSize)
}
q.bufferLen, err = q.file.Read(q.buffer)
if err != nil {
return err
}
return nil
}
// readProbeLine reads a line that includes the specified position.
// This method is meant for the binary search in the Seek method;
// for consecutive reads, use readNextLine (it uses a better buffer).
func (q *QLogFile) readProbeLine(position int64) (string, int64, int64, error) {
// First of all, we should read a buffer that will include the query log line
// In order to do this, we'll define the boundaries
seekPosition := int64(0)
relativePos := position // position relative to the buffer we're going to read
if (position - maxEntrySize) > 0 {
seekPosition = position - maxEntrySize
relativePos = maxEntrySize
}
// Seek to this position
_, err := q.file.Seek(seekPosition, io.SeekStart)
if err != nil {
return "", 0, 0, err
}
// The buffer size is 2*maxEntrySize
buffer := make([]byte, maxEntrySize*2)
bufferLen, err := q.file.Read(buffer)
if err != nil {
return "", 0, 0, err
}
// Now start looking for the new line character starting
// from the relativePos and going left
var startLine = int64(0)
for i := relativePos - 1; i >= 0; i-- {
if buffer[i] == '\n' {
startLine = i + 1
break
}
}
// Looking for the end of line now
var endLine = int64(bufferLen)
lineEndIdx := endLine + seekPosition
for i := relativePos; i < int64(bufferLen); i++ {
if buffer[i] == '\n' {
endLine = i
lineEndIdx = endLine + seekPosition + 1
break
}
}
// Finally we can return the string we were looking for
lineIdx := startLine + seekPosition
return string(buffer[startLine:endLine]), lineIdx, lineEndIdx, nil
}
// readQLogTimestamp reads the timestamp field from the query log line
func readQLogTimestamp(str string) int64 {
val := readJSONValue(str, "T")
if len(val) == 0 {
val = readJSONValue(str, "Time")
}
if len(val) == 0 {
log.Error("Couldn't find timestamp: %s", str)
return 0
}
tm, err := time.Parse(time.RFC3339Nano, val)
if err != nil {
log.Error("Couldn't parse timestamp: %s", val)
return 0
}
return tm.UnixNano()
}
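// For example (hypothetical line):
//
//	readQLogTimestamp(`{"T":"2020-08-31T18:44:25.376690873+03:00","QH":"example.org"}`)
//
// returns the Unix nanoseconds of that RFC3339Nano timestamp, and 0 on any
// parse failure.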


@@ -0,0 +1,290 @@
package querylog
import (
"encoding/binary"
"io"
"io/ioutil"
"math"
"net"
"os"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestQLogFileEmpty(t *testing.T) {
testDir := prepareTestDir()
defer func() { _ = os.RemoveAll(testDir) }()
testFile := prepareTestFile(testDir, 0)
// create the new QLogFile instance
q, err := NewQLogFile(testFile)
assert.Nil(t, err)
assert.NotNil(t, q)
defer q.Close()
// seek to the start
pos, err := q.SeekStart()
assert.Nil(t, err)
assert.Equal(t, int64(0), pos)
// try reading anyway
line, err := q.ReadNext()
assert.Equal(t, io.EOF, err)
assert.Equal(t, "", line)
}
func TestQLogFileLarge(t *testing.T) {
// should be large enough
count := 50000
testDir := prepareTestDir()
defer func() { _ = os.RemoveAll(testDir) }()
testFile := prepareTestFile(testDir, count)
// create the new QLogFile instance
q, err := NewQLogFile(testFile)
assert.Nil(t, err)
assert.NotNil(t, q)
defer q.Close()
// seek to the start
pos, err := q.SeekStart()
assert.Nil(t, err)
assert.NotEqual(t, int64(0), pos)
read := 0
var line string
for err == nil {
line, err = q.ReadNext()
if err == nil {
assert.True(t, len(line) > 0)
read += 1
}
}
assert.Equal(t, count, read)
assert.Equal(t, io.EOF, err)
}
func TestQLogFileSeekLargeFile(t *testing.T) {
// a reasonably large file
count := 10000
testDir := prepareTestDir()
defer func() { _ = os.RemoveAll(testDir) }()
testFile := prepareTestFile(testDir, count)
// create the new QLogFile instance
q, err := NewQLogFile(testFile)
assert.Nil(t, err)
assert.NotNil(t, q)
defer q.Close()
// CASE 1: NOT TOO OLD LINE
testSeekLineQLogFile(t, q, 300)
// CASE 2: OLD LINE
testSeekLineQLogFile(t, q, count-300)
// CASE 3: FIRST LINE
testSeekLineQLogFile(t, q, 0)
// CASE 4: LAST LINE
testSeekLineQLogFile(t, q, count)
// CASE 5: Seek non-existent (too low)
_, _, err = q.Seek(123)
assert.NotNil(t, err)
// CASE 6: Seek non-existent (too high)
ts, _ := time.Parse(time.RFC3339, "2100-01-02T15:04:05Z07:00")
_, _, err = q.Seek(ts.UnixNano())
assert.NotNil(t, err)
// CASE 7: "Almost" found
line, err := getQLogFileLine(q, count/2)
assert.Nil(t, err)
// ALMOST the record we need
timestamp := readQLogTimestamp(line) - 1
assert.NotEqual(t, int64(0), timestamp)
_, depth, err := q.Seek(timestamp)
assert.NotNil(t, err)
assert.True(t, depth <= int(math.Log2(float64(count))+3))
}
func TestQLogFileSeekSmallFile(t *testing.T) {
// a small file
count := 10
testDir := prepareTestDir()
defer func() { _ = os.RemoveAll(testDir) }()
testFile := prepareTestFile(testDir, count)
// create the new QLogFile instance
q, err := NewQLogFile(testFile)
assert.Nil(t, err)
assert.NotNil(t, q)
defer q.Close()
// CASE 1: NOT TOO OLD LINE
testSeekLineQLogFile(t, q, 2)
// CASE 2: OLD LINE
testSeekLineQLogFile(t, q, count-2)
// CASE 3: FIRST LINE
testSeekLineQLogFile(t, q, 0)
// CASE 4: LAST LINE
testSeekLineQLogFile(t, q, count)
// CASE 5: Seek non-existent (too low)
_, _, err = q.Seek(123)
assert.NotNil(t, err)
// CASE 6: Seek non-existent (too high)
ts, _ := time.Parse(time.RFC3339, "2100-01-02T15:04:05Z07:00")
_, _, err = q.Seek(ts.UnixNano())
assert.NotNil(t, err)
// CASE 7: "Almost" found
line, err := getQLogFileLine(q, count/2)
assert.Nil(t, err)
// ALMOST the record we need
timestamp := readQLogTimestamp(line) - 1
assert.NotEqual(t, int64(0), timestamp)
_, depth, err := q.Seek(timestamp)
assert.NotNil(t, err)
assert.True(t, depth <= int(math.Log2(float64(count))+3))
}
func testSeekLineQLogFile(t *testing.T, q *QLogFile, lineNumber int) {
line, err := getQLogFileLine(q, lineNumber)
assert.Nil(t, err)
ts := readQLogTimestamp(line)
assert.NotEqual(t, int64(0), ts)
// try seeking to that line now
pos, _, err := q.Seek(ts)
assert.Nil(t, err)
assert.NotEqual(t, int64(0), pos)
testLine, err := q.ReadNext()
assert.Nil(t, err)
assert.Equal(t, line, testLine)
}
func getQLogFileLine(q *QLogFile, lineNumber int) (string, error) {
_, err := q.SeekStart()
if err != nil {
return "", err
}
for i := 1; i < lineNumber; i++ {
_, err := q.ReadNext()
if err != nil {
return "", err
}
}
return q.ReadNext()
}
// Check reading the log file line by line in reverse order
func TestQLogFile(t *testing.T) {
testDir := prepareTestDir()
defer func() { _ = os.RemoveAll(testDir) }()
testFile := prepareTestFile(testDir, 2)
// create the new QLogFile instance
q, err := NewQLogFile(testFile)
assert.Nil(t, err)
assert.NotNil(t, q)
defer q.Close()
// seek to the start
pos, err := q.SeekStart()
assert.Nil(t, err)
assert.True(t, pos > 0)
// read first line
line, err := q.ReadNext()
assert.Nil(t, err)
assert.True(t, strings.Contains(line, "0.0.0.2"), line)
assert.True(t, strings.HasPrefix(line, "{"), line)
assert.True(t, strings.HasSuffix(line, "}"), line)
// read second line
line, err = q.ReadNext()
assert.Nil(t, err)
assert.Equal(t, int64(0), q.position)
assert.True(t, strings.Contains(line, "0.0.0.1"), line)
assert.True(t, strings.HasPrefix(line, "{"), line)
assert.True(t, strings.HasSuffix(line, "}"), line)
// try reading again (there's nothing to read anymore)
line, err = q.ReadNext()
assert.Equal(t, io.EOF, err)
assert.Equal(t, "", line)
}
// prepareTestFile - prepares a test query log file with the specified number of lines
func prepareTestFile(dir string, linesCount int) string {
return prepareTestFiles(dir, 1, linesCount)[0]
}
// prepareTestFiles - prepares several test query log files,
// each with the specified linesCount
func prepareTestFiles(dir string, filesCount, linesCount int) []string {
format := `{"IP":"${IP}","T":"${TIMESTAMP}","QH":"example.org","QT":"A","QC":"IN","Answer":"AAAAAAABAAEAAAAAB2V4YW1wbGUDb3JnAAABAAEHZXhhbXBsZQNvcmcAAAEAAQAAAAAABAECAwQ=","Result":{},"Elapsed":0,"Upstream":"upstream"}`
lineTime, _ := time.Parse(time.RFC3339Nano, "2020-02-18T22:36:35.920973+03:00")
lineIP := uint32(0)
files := make([]string, 0)
for j := 0; j < filesCount; j++ {
f, _ := ioutil.TempFile(dir, "*.txt")
files = append(files, f.Name())
for i := 0; i < linesCount; i++ {
lineIP += 1
lineTime = lineTime.Add(time.Second)
ip := make(net.IP, 4)
binary.BigEndian.PutUint32(ip, lineIP)
line := format
line = strings.ReplaceAll(line, "${IP}", ip.String())
line = strings.ReplaceAll(line, "${TIMESTAMP}", lineTime.Format(time.RFC3339Nano))
_, _ = f.WriteString(line)
_, _ = f.WriteString("\n")
}
}
return files
}
func TestQLogSeek(t *testing.T) {
testDir := prepareTestDir()
defer func() { _ = os.RemoveAll(testDir) }()
d := `{"T":"2020-08-31T18:44:23.911246629+03:00","QH":"wfqvjymurpwegyv","QT":"A","QC":"IN","CP":"","Answer":"","Result":{},"Elapsed":66286385,"Upstream":"tls://dns-unfiltered.adguard.com:853"}
{"T":"2020-08-31T18:44:25.376690873+03:00"}
{"T":"2020-08-31T18:44:25.382540454+03:00"}`
f, _ := ioutil.TempFile(testDir, "*.txt")
_, _ = f.WriteString(d)
defer f.Close()
q, err := NewQLogFile(f.Name())
assert.Nil(t, err)
defer q.Close()
target, _ := time.Parse(time.RFC3339, "2020-08-31T18:44:25.376690873+03:00")
_, depth, err := q.Seek(target.UnixNano())
assert.Nil(t, err)
assert.Equal(t, 1, depth)
}


@@ -0,0 +1,194 @@
package querylog
import (
"encoding/json"
"fmt"
"net/http"
"net/url"
"strconv"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/util"
"github.com/AdguardTeam/golibs/jsonutil"
"github.com/AdguardTeam/golibs/log"
)
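// qlogConfig - JSON object used by the query log configuration HTTP API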
type qlogConfig struct {
Enabled bool `json:"enabled"`
Interval uint32 `json:"interval"`
AnonymizeClientIP bool `json:"anonymize_client_ip"`
}
// Register web handlers
func (l *queryLog) initWeb() {
l.conf.HTTPRegister("GET", "/control/querylog", l.handleQueryLog)
l.conf.HTTPRegister("GET", "/control/querylog_info", l.handleQueryLogInfo)
l.conf.HTTPRegister("POST", "/control/querylog_clear", l.handleQueryLogClear)
l.conf.HTTPRegister("POST", "/control/querylog_config", l.handleQueryLogConfig)
}
func httpError(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) {
text := fmt.Sprintf(format, args...)
log.Info("QueryLog: %s %s: %s", r.Method, r.URL, text)
http.Error(w, text, code)
}
func (l *queryLog) handleQueryLog(w http.ResponseWriter, r *http.Request) {
params, err := l.parseSearchParams(r)
if err != nil {
httpError(r, w, http.StatusBadRequest, "failed to parse params: %s", err)
return
}
// search for the log entries
entries, oldest := l.search(params)
// convert log entries to JSON
var data = l.entriesToJSON(entries, oldest)
jsonVal, err := json.Marshal(data)
if err != nil {
httpError(r, w, http.StatusInternalServerError, "Couldn't marshal data into json: %s", err)
return
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal)
if err != nil {
httpError(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err)
}
}
func (l *queryLog) handleQueryLogClear(_ http.ResponseWriter, _ *http.Request) {
l.clear()
}
// Get configuration
func (l *queryLog) handleQueryLogInfo(w http.ResponseWriter, r *http.Request) {
resp := qlogConfig{}
resp.Enabled = l.conf.Enabled
resp.Interval = l.conf.Interval
resp.AnonymizeClientIP = l.conf.AnonymizeClientIP
jsonVal, err := json.Marshal(resp)
if err != nil {
httpError(r, w, http.StatusInternalServerError, "json encode: %s", err)
return
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal)
if err != nil {
httpError(r, w, http.StatusInternalServerError, "http write: %s", err)
}
}
// Set configuration
func (l *queryLog) handleQueryLogConfig(w http.ResponseWriter, r *http.Request) {
d := qlogConfig{}
req, err := jsonutil.DecodeObject(&d, r.Body)
if err != nil {
httpError(r, w, http.StatusBadRequest, "%s", err)
return
}
if req.Exists("interval") && !checkInterval(d.Interval) {
httpError(r, w, http.StatusBadRequest, "Unsupported interval")
return
}
l.lock.Lock()
// copy data, modify it, then activate. Other threads (readers) don't need to use this lock.
conf := *l.conf
if req.Exists("enabled") {
conf.Enabled = d.Enabled
}
if req.Exists("interval") {
conf.Interval = d.Interval
}
if req.Exists("anonymize_client_ip") {
conf.AnonymizeClientIP = d.AnonymizeClientIP
}
l.conf = &conf
l.lock.Unlock()
l.conf.ConfigModified()
}
// "value" -> value, return TRUE
func getDoubleQuotesEnclosedValue(s *string) bool {
t := *s
if len(t) >= 2 && t[0] == '"' && t[len(t)-1] == '"' {
*s = t[1 : len(t)-1]
return true
}
return false
}
// parseSearchCriteria - parses "searchCriteria" from the specified query parameter
func (l *queryLog) parseSearchCriteria(q url.Values, name string, ct criteriaType) (bool, searchCriteria, error) {
val := q.Get(name)
if len(val) == 0 {
return false, searchCriteria{}, nil
}
c := searchCriteria{
criteriaType: ct,
value: val,
}
if getDoubleQuotesEnclosedValue(&c.value) {
c.strict = true
}
if ct == ctFilteringStatus && !util.ContainsString(filteringStatusValues, c.value) {
return false, c, fmt.Errorf("invalid value %s", c.value)
}
return true, c, nil
}
// parseSearchParams - parses "searchParams" from the HTTP request's query string
func (l *queryLog) parseSearchParams(r *http.Request) (*searchParams, error) {
p := newSearchParams()
var err error
q := r.URL.Query()
olderThan := q.Get("older_than")
if len(olderThan) != 0 {
p.olderThan, err = time.Parse(time.RFC3339Nano, olderThan)
if err != nil {
return nil, err
}
}
if limit, err := strconv.ParseInt(q.Get("limit"), 10, 64); err == nil {
p.limit = int(limit)
}
if offset, err := strconv.ParseInt(q.Get("offset"), 10, 64); err == nil {
p.offset = int(offset)
// If we don't use "olderThan" and use offset/limit instead, we should change the default behavior
// and scan all log records until we found enough log entries
p.maxFileScanEntries = 0
}
paramNames := map[string]criteriaType{
"search": ctDomainOrClient,
"response_status": ctFilteringStatus,
}
for k, v := range paramNames {
ok, c, err := l.parseSearchCriteria(q, k, v)
if err != nil {
return nil, err
}
if ok {
p.searchCriteria = append(p.searchCriteria, c)
}
}
return p, nil
}
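// For illustration, a request such as (values hypothetical):
//
//	GET /control/querylog?older_than=2020-10-30T12:00:00Z&search="example.org"&limit=20
//
// yields a searchParams with olderThan set, limit = 20, and a single strict
// ctDomainOrClient criterion for "example.org" (the enclosing double quotes
// trigger strict matching).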


@@ -0,0 +1,139 @@
package querylog
import (
"io"
"github.com/joomcode/errorx"
)
// QLogReader allows reading from multiple query log files in the reverse order.
//
// Please note that this is a stateful object.
// Internally, it contains a pointer to a particular query log file, and
// to a specific position in this file, and it reads lines in reverse order
// starting from that position.
type QLogReader struct {
// qFiles - array with the query log files
// The order is from oldest to newest
qFiles []*QLogFile
currentFile int // Index of the current file
}
// NewQLogReader initializes a QLogReader instance
// with the specified files
func NewQLogReader(files []string) (*QLogReader, error) {
qFiles := make([]*QLogFile, 0)
for _, f := range files {
q, err := NewQLogFile(f)
if err != nil {
// Close what we've already opened
_ = closeQFiles(qFiles)
return nil, err
}
qFiles = append(qFiles, q)
}
return &QLogReader{
qFiles: qFiles,
currentFile: (len(qFiles) - 1),
}, nil
}
// Seek performs binary search of a query log record with the specified timestamp.
// If the record is found, it sets QLogReader's position to point to that line,
// so that the next ReadNext call returns this line.
//
// Returns nil if the record is successfully found.
// Returns an error if for some reason we could not find a record with the specified timestamp.
func (r *QLogReader) Seek(timestamp int64) error {
for i := len(r.qFiles) - 1; i >= 0; i-- {
q := r.qFiles[i]
_, _, err := q.Seek(timestamp)
if err == nil {
// Our search is finished, we found the element we were looking for
// Update currentFile only, position is already set properly in the QLogFile
r.currentFile = i
return nil
}
}
return ErrSeekNotFound
}
// SeekStart changes the current position to the end of the newest file.
// Please note that we read the query log in reverse order,
// which is why the log start is actually the end of the file.
//
// Returns nil if we were able to change the current position.
// Returns error in any other case.
func (r *QLogReader) SeekStart() error {
if len(r.qFiles) == 0 {
return nil
}
r.currentFile = len(r.qFiles) - 1
_, err := r.qFiles[r.currentFile].SeekStart()
return err
}
// ReadNext reads the next line (in reverse order) from the query log files
// and shifts the current position left to the next (actually previous) line, or to the next file.
// Returns io.EOF when there is nothing more to read.
func (r *QLogReader) ReadNext() (string, error) {
if len(r.qFiles) == 0 {
return "", io.EOF
}
for r.currentFile >= 0 {
q := r.qFiles[r.currentFile]
line, err := q.ReadNext()
if err != nil {
// Shift to the older file
r.currentFile--
if r.currentFile < 0 {
break
}
q = r.qFiles[r.currentFile]
// Set its position to the start right away
_, err = q.SeekStart()
// This is unexpected, return an error right away
if err != nil {
return "", err
}
} else {
return line, nil
}
}
// Nothing to read anymore
return "", io.EOF
}
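// A minimal sketch of draining the reader, newest record first
// (file paths hypothetical; pass them from oldest to newest):
//
//	r, _ := NewQLogReader([]string{"querylog.json.1", "querylog.json"})
//	defer r.Close()
//	_ = r.SeekStart()
//	for {
//		line, err := r.ReadNext()
//		if err == io.EOF {
//			break
//		}
//		// process line
//	}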
// Close closes the QLogReader
func (r *QLogReader) Close() error {
return closeQFiles(r.qFiles)
}
// closeQFiles - helper method to close multiple QLogFile instances
func closeQFiles(qFiles []*QLogFile) error {
var errs []error
for _, q := range qFiles {
err := q.Close()
if err != nil {
errs = append(errs, err)
}
}
if len(errs) > 0 {
return errorx.DecorateMany("Error while closing QLogReader", errs...)
}
return nil
}


@@ -0,0 +1,157 @@
package querylog
import (
"io"
"os"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestQLogReaderEmpty(t *testing.T) {
r, err := NewQLogReader([]string{})
assert.Nil(t, err)
assert.NotNil(t, r)
defer r.Close()
// seek to the start
err = r.SeekStart()
assert.Nil(t, err)
line, err := r.ReadNext()
assert.Equal(t, "", line)
assert.Equal(t, io.EOF, err)
}
func TestQLogReaderOneFile(t *testing.T) {
// let's do one small file
count := 10
filesCount := 1
testDir := prepareTestDir()
defer func() { _ = os.RemoveAll(testDir) }()
testFiles := prepareTestFiles(testDir, filesCount, count)
r, err := NewQLogReader(testFiles)
assert.Nil(t, err)
assert.NotNil(t, r)
defer r.Close()
// seek to the start
err = r.SeekStart()
assert.Nil(t, err)
// read everything
read := 0
var line string
for err == nil {
line, err = r.ReadNext()
if err == nil {
assert.True(t, len(line) > 0)
read += 1
}
}
assert.Equal(t, count*filesCount, read)
assert.Equal(t, io.EOF, err)
}
func TestQLogReaderMultipleFiles(t *testing.T) {
// should be large enough
count := 10000
filesCount := 5
testDir := prepareTestDir()
defer func() { _ = os.RemoveAll(testDir) }()
testFiles := prepareTestFiles(testDir, filesCount, count)
r, err := NewQLogReader(testFiles)
assert.Nil(t, err)
assert.NotNil(t, r)
defer r.Close()
// seek to the start
err = r.SeekStart()
assert.Nil(t, err)
// read everything
read := 0
var line string
for err == nil {
line, err = r.ReadNext()
if err == nil {
assert.True(t, len(line) > 0)
read += 1
}
}
assert.Equal(t, count*filesCount, read)
assert.Equal(t, io.EOF, err)
}
func TestQLogReaderSeek(t *testing.T) {
// a reasonably large file
count := 10000
filesCount := 2
testDir := prepareTestDir()
defer func() { _ = os.RemoveAll(testDir) }()
testFiles := prepareTestFiles(testDir, filesCount, count)
r, err := NewQLogReader(testFiles)
assert.Nil(t, err)
assert.NotNil(t, r)
defer r.Close()
// CASE 1: NOT TOO OLD LINE
testSeekLineQLogReader(t, r, 300)
// CASE 2: OLD LINE
testSeekLineQLogReader(t, r, count-300)
// CASE 3: FIRST LINE
testSeekLineQLogReader(t, r, 0)
// CASE 4: LAST LINE
testSeekLineQLogReader(t, r, count)
// CASE 5: Seek non-existent (too low)
err = r.Seek(123)
assert.NotNil(t, err)
// CASE 6: Seek non-existent (too high)
ts, _ := time.Parse(time.RFC3339, "2100-01-02T15:04:05Z07:00")
err = r.Seek(ts.UnixNano())
assert.NotNil(t, err)
}
func testSeekLineQLogReader(t *testing.T, r *QLogReader, lineNumber int) {
line, err := getQLogReaderLine(r, lineNumber)
assert.Nil(t, err)
ts := readQLogTimestamp(line)
assert.NotEqual(t, int64(0), ts)
// try seeking to that line now
err = r.Seek(ts)
assert.Nil(t, err)
testLine, err := r.ReadNext()
assert.Nil(t, err)
assert.Equal(t, line, testLine)
}
func getQLogReaderLine(r *QLogReader, lineNumber int) (string, error) {
err := r.SeekStart()
if err != nil {
return "", err
}
for i := 1; i < lineNumber; i++ {
_, err := r.ReadNext()
if err != nil {
return "", err
}
}
return r.ReadNext()
}


@@ -0,0 +1,253 @@
package querylog
import (
"net"
"os"
"testing"
"github.com/AdguardTeam/dnsproxy/proxyutil"
"github.com/AdguardTeam/AdGuardHome/internal/dnsfilter"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
)
func prepareTestDir() string {
const dir = "./agh-test"
_ = os.RemoveAll(dir)
_ = os.MkdirAll(dir, 0755)
return dir
}
// Check adding and loading (with filtering) entries from disk and memory
func TestQueryLog(t *testing.T) {
conf := Config{
Enabled: true,
FileEnabled: true,
Interval: 1,
MemSize: 100,
}
conf.BaseDir = prepareTestDir()
defer func() { _ = os.RemoveAll(conf.BaseDir) }()
l := newQueryLog(conf)
// add disk entries
addEntry(l, "example.org", "1.1.1.1", "2.2.2.1")
// write to disk (first file)
_ = l.flushLogBuffer(true)
// start writing to the second file
_ = l.rotate()
// add disk entries
addEntry(l, "example.org", "1.1.1.2", "2.2.2.2")
// write to disk
_ = l.flushLogBuffer(true)
// add memory entries
addEntry(l, "test.example.org", "1.1.1.3", "2.2.2.3")
addEntry(l, "example.com", "1.1.1.4", "2.2.2.4")
// get all entries
params := newSearchParams()
entries, _ := l.search(params)
assert.Equal(t, 4, len(entries))
assertLogEntry(t, entries[0], "example.com", "1.1.1.4", "2.2.2.4")
assertLogEntry(t, entries[1], "test.example.org", "1.1.1.3", "2.2.2.3")
assertLogEntry(t, entries[2], "example.org", "1.1.1.2", "2.2.2.2")
assertLogEntry(t, entries[3], "example.org", "1.1.1.1", "2.2.2.1")
// search by domain (strict)
params = newSearchParams()
params.searchCriteria = append(params.searchCriteria, searchCriteria{
criteriaType: ctDomainOrClient,
strict: true,
value: "TEST.example.org",
})
entries, _ = l.search(params)
assert.Equal(t, 1, len(entries))
assertLogEntry(t, entries[0], "test.example.org", "1.1.1.3", "2.2.2.3")
// search by domain (not strict)
params = newSearchParams()
params.searchCriteria = append(params.searchCriteria, searchCriteria{
criteriaType: ctDomainOrClient,
strict: false,
value: "example.ORG",
})
entries, _ = l.search(params)
assert.Equal(t, 3, len(entries))
assertLogEntry(t, entries[0], "test.example.org", "1.1.1.3", "2.2.2.3")
assertLogEntry(t, entries[1], "example.org", "1.1.1.2", "2.2.2.2")
assertLogEntry(t, entries[2], "example.org", "1.1.1.1", "2.2.2.1")
// search by client IP (strict)
params = newSearchParams()
params.searchCriteria = append(params.searchCriteria, searchCriteria{
criteriaType: ctDomainOrClient,
strict: true,
value: "2.2.2.2",
})
entries, _ = l.search(params)
assert.Equal(t, 1, len(entries))
assertLogEntry(t, entries[0], "example.org", "1.1.1.2", "2.2.2.2")
// search by client IP (part of)
params = newSearchParams()
params.searchCriteria = append(params.searchCriteria, searchCriteria{
criteriaType: ctDomainOrClient,
strict: false,
value: "2.2.2",
})
entries, _ = l.search(params)
assert.Equal(t, 4, len(entries))
assertLogEntry(t, entries[0], "example.com", "1.1.1.4", "2.2.2.4")
assertLogEntry(t, entries[1], "test.example.org", "1.1.1.3", "2.2.2.3")
assertLogEntry(t, entries[2], "example.org", "1.1.1.2", "2.2.2.2")
assertLogEntry(t, entries[3], "example.org", "1.1.1.1", "2.2.2.1")
}
func TestQueryLogOffsetLimit(t *testing.T) {
conf := Config{
Enabled: true,
Interval: 1,
MemSize: 100,
}
conf.BaseDir = prepareTestDir()
defer func() { _ = os.RemoveAll(conf.BaseDir) }()
l := newQueryLog(conf)
// add 10 entries to the log
for i := 0; i < 10; i++ {
addEntry(l, "second.example.org", "1.1.1.1", "2.2.2.1")
}
// write them to disk (first file)
_ = l.flushLogBuffer(true)
// add 10 more entries to the log (memory)
for i := 0; i < 10; i++ {
addEntry(l, "first.example.org", "1.1.1.1", "2.2.2.1")
}
// First page
params := newSearchParams()
params.offset = 0
params.limit = 10
entries, _ := l.search(params)
assert.Equal(t, 10, len(entries))
assert.Equal(t, entries[0].QHost, "first.example.org")
assert.Equal(t, entries[9].QHost, "first.example.org")
// Second page
params.offset = 10
params.limit = 10
entries, _ = l.search(params)
assert.Equal(t, 10, len(entries))
assert.Equal(t, entries[0].QHost, "second.example.org")
assert.Equal(t, entries[9].QHost, "second.example.org")
// Second and a half page
params.offset = 15
params.limit = 10
entries, _ = l.search(params)
assert.Equal(t, 5, len(entries))
assert.Equal(t, entries[0].QHost, "second.example.org")
assert.Equal(t, entries[4].QHost, "second.example.org")
// Third page
params.offset = 20
params.limit = 10
entries, _ = l.search(params)
assert.Equal(t, 0, len(entries))
}
func TestQueryLogMaxFileScanEntries(t *testing.T) {
conf := Config{
Enabled: true,
FileEnabled: true,
Interval: 1,
MemSize: 100,
}
conf.BaseDir = prepareTestDir()
defer func() { _ = os.RemoveAll(conf.BaseDir) }()
l := newQueryLog(conf)
// add 10 entries to the log
for i := 0; i < 10; i++ {
addEntry(l, "example.org", "1.1.1.1", "2.2.2.1")
}
// write them to disk (first file)
_ = l.flushLogBuffer(true)
params := newSearchParams()
params.maxFileScanEntries = 5 // do not scan more than 5 records
entries, _ := l.search(params)
assert.Equal(t, 5, len(entries))
params.maxFileScanEntries = 0 // disable the limit
entries, _ = l.search(params)
assert.Equal(t, 10, len(entries))
}
func TestQueryLogFileDisabled(t *testing.T) {
conf := Config{
Enabled: true,
FileEnabled: false,
Interval: 1,
MemSize: 2,
}
conf.BaseDir = prepareTestDir()
defer func() { _ = os.RemoveAll(conf.BaseDir) }()
l := newQueryLog(conf)
addEntry(l, "example1.org", "1.1.1.1", "2.2.2.1")
addEntry(l, "example2.org", "1.1.1.1", "2.2.2.1")
addEntry(l, "example3.org", "1.1.1.1", "2.2.2.1")
// the oldest entry is now removed from mem buffer
params := newSearchParams()
ll, _ := l.search(params)
assert.Equal(t, 2, len(ll))
assert.Equal(t, "example3.org", ll[0].QHost)
assert.Equal(t, "example2.org", ll[1].QHost)
}
func addEntry(l *queryLog, host, answerStr, client string) {
q := dns.Msg{}
q.Question = append(q.Question, dns.Question{
Name: host + ".",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
})
a := dns.Msg{}
a.Question = append(a.Question, q.Question[0])
answer := new(dns.A)
answer.Hdr = dns.RR_Header{
Name: q.Question[0].Name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
}
answer.A = net.ParseIP(answerStr)
a.Answer = append(a.Answer, answer)
res := dnsfilter.Result{}
params := AddParams{
Question: &q,
Answer: &a,
Result: &res,
ClientIP: net.ParseIP(client),
Upstream: "upstream",
}
l.Add(params)
}
func assertLogEntry(t *testing.T, entry *logEntry, host, answer, client string) bool {
assert.Equal(t, host, entry.QHost)
assert.Equal(t, client, entry.IP)
assert.Equal(t, "A", entry.QType)
assert.Equal(t, "IN", entry.QClass)
msg := new(dns.Msg)
assert.Nil(t, msg.Unpack(entry.Answer))
assert.Equal(t, 1, len(msg.Answer))
ip := proxyutil.GetIPFromDNSRecord(msg.Answer[0])
assert.NotNil(t, ip)
assert.Equal(t, answer, ip.String())
return true
}


@@ -0,0 +1,57 @@
package querylog
import (
"net"
"net/http"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/dnsfilter"
"github.com/miekg/dns"
)
// QueryLog - main interface
type QueryLog interface {
Start()
// Close query log object
Close()
// Add a log entry
Add(params AddParams)
// WriteDiskConfig - write configuration
WriteDiskConfig(c *Config)
}
// Config - configuration object
type Config struct {
Enabled bool // enable the module
FileEnabled bool // write logs to file
BaseDir string // directory where log file is stored
Interval uint32 // interval to rotate logs (in days)
MemSize uint32 // number of entries kept in memory before they are flushed to disk
AnonymizeClientIP bool // anonymize clients' IP addresses
// Called when the configuration is changed by HTTP request
ConfigModified func()
// Register an HTTP handler
HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request))
}
// AddParams - parameters for Add()
type AddParams struct {
Question *dns.Msg
Answer *dns.Msg // The response we sent to the client (optional)
OrigAnswer *dns.Msg // The response from an upstream server (optional)
Result *dnsfilter.Result // Filtering result (optional)
Elapsed time.Duration // Time spent for processing the request
ClientIP net.IP
Upstream string // Upstream server URL
ClientProto string // Protocol for the client connection: "" (plain), "doh", "dot"
}
// New - create a new instance of the query log
func New(conf Config) QueryLog {
return newQueryLog(conf)
}
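// A minimal usage sketch of the public API (the BaseDir value is hypothetical):
//
//	ql := querylog.New(querylog.Config{
//		Enabled:     true,
//		FileEnabled: true,
//		BaseDir:     "/var/lib/AdGuardHome",
//		Interval:    1,
//		MemSize:     100,
//	})
//	ql.Start()
//	defer ql.Close()
//	ql.Add(querylog.AddParams{
//		// Question, Answer, Result, ClientIP, Elapsed, Upstream, ...
//	})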


@@ -0,0 +1,139 @@
package querylog
import (
"bytes"
"encoding/json"
"os"
"time"
"github.com/AdguardTeam/golibs/log"
)
// flushLogBuffer flushes the current buffer to file and resets the current buffer
func (l *queryLog) flushLogBuffer(fullFlush bool) error {
if !l.conf.FileEnabled {
return nil
}
l.fileFlushLock.Lock()
defer l.fileFlushLock.Unlock()
// flush remainder to file
l.bufferLock.Lock()
needFlush := len(l.buffer) >= int(l.conf.MemSize)
if !needFlush && !fullFlush {
l.bufferLock.Unlock()
return nil
}
flushBuffer := l.buffer
l.buffer = nil
l.flushPending = false
l.bufferLock.Unlock()
err := l.flushToFile(flushBuffer)
if err != nil {
log.Error("Saving querylog to file failed: %s", err)
return err
}
return nil
}
// flushToFile saves the specified log entries to the query log file
func (l *queryLog) flushToFile(buffer []*logEntry) error {
if len(buffer) == 0 {
log.Debug("querylog: there's nothing to write to a file")
return nil
}
start := time.Now()
var b bytes.Buffer
e := json.NewEncoder(&b)
for _, entry := range buffer {
err := e.Encode(entry)
if err != nil {
log.Error("Failed to marshal entry: %s", err)
return err
}
}
elapsed := time.Since(start)
log.Debug("%d elements serialized via json in %v: %d kB, %v/entry, %v/entry", len(buffer), elapsed, b.Len()/1024, float64(b.Len())/float64(len(buffer)), elapsed/time.Duration(len(buffer)))
filename := l.logFile
l.fileWriteLock.Lock()
defer l.fileWriteLock.Unlock()
f, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644)
if err != nil {
log.Error("failed to create file \"%s\": %s", filename, err)
return err
}
defer f.Close()
n, err := f.Write(b.Bytes())
if err != nil {
log.Error("Couldn't write to file: %s", err)
return err
}
log.Debug("querylog: ok \"%s\": %v bytes written", filename, n)
return nil
}
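// rotate renames the current log file to <name>.1, replacing any
// previous rotation; it is a no-op if the log file does not exist yet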
func (l *queryLog) rotate() error {
from := l.logFile
to := l.logFile + ".1"
if _, err := os.Stat(from); os.IsNotExist(err) {
// do nothing, file doesn't exist
return nil
}
err := os.Rename(from, to)
if err != nil {
log.Error("querylog: failed to rename file: %s", err)
return err
}
log.Debug("querylog: renamed %s -> %s", from, to)
return nil
}
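// readFileFirstTimeValue returns the Unix time (in seconds) of the first
// log entry in the current log file, or -1 on any error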
func (l *queryLog) readFileFirstTimeValue() int64 {
f, err := os.Open(l.logFile)
if err != nil {
return -1
}
defer f.Close()
buf := make([]byte, 500)
r, err := f.Read(buf)
if err != nil {
return -1
}
buf = buf[:r]
val := readJSONValue(string(buf), "T")
t, err := time.Parse(time.RFC3339Nano, val)
if err != nil {
return -1
}
log.Debug("querylog: the oldest log entry: %s", val)
return t.Unix()
}
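// periodicRotate checks the age of the oldest entry once a day and
// rotates the log file when it is older than conf.Interval days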
func (l *queryLog) periodicRotate() {
intervalSeconds := uint64(l.conf.Interval) * 24 * 60 * 60
for {
oldest := l.readFileFirstTimeValue()
if uint64(oldest)+intervalSeconds <= uint64(time.Now().Unix()) {
_ = l.rotate()
}
time.Sleep(24 * time.Hour)
}
}


@@ -0,0 +1,181 @@
package querylog
import (
"io"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/util"
"github.com/AdguardTeam/golibs/log"
)
// search - searches log entries in the query log using specified parameters
// returns the list of entries found + time of the oldest entry
func (l *queryLog) search(params *searchParams) ([]*logEntry, time.Time) {
now := time.Now()
if params.limit == 0 {
return []*logEntry{}, time.Time{}
}
// add from file
fileEntries, oldest, total := l.searchFiles(params)
// add from memory buffer
l.bufferLock.Lock()
total += len(l.buffer)
memoryEntries := make([]*logEntry, 0)
// go through the buffer in the reverse order
// from NEWER to OLDER
for i := len(l.buffer) - 1; i >= 0; i-- {
entry := l.buffer[i]
if !params.match(entry) {
continue
}
memoryEntries = append(memoryEntries, entry)
}
l.bufferLock.Unlock()
// limits
totalLimit := params.offset + params.limit
// now let's get a unified collection
entries := append(memoryEntries, fileEntries...)
if len(entries) > totalLimit {
// remove extra records
entries = entries[:totalLimit]
}
if params.offset > 0 {
if len(entries) > params.offset {
entries = entries[params.offset:]
} else {
entries = make([]*logEntry, 0)
oldest = time.Time{}
}
}
if len(entries) == totalLimit {
// change the "oldest" value here.
// we cannot use the "oldest" we got from "searchFiles" anymore
// because after adding in-memory records and removing extra records
// the situation has changed
oldest = entries[len(entries)-1].Time
}
log.Debug("QueryLog: prepared data (%d/%d) older than %s in %s",
len(entries), total, params.olderThan, time.Since(now))
return entries, oldest
}
// searchFiles reads log entries from all log files and applies the specified search criteria.
// IMPORTANT: this method does not scan more than "maxSearchEntries" so you
// may need to call it many times.
//
// it returns:
// * an array of log entries that we have read
// * time of the oldest processed entry (even if it was discarded)
// * total number of processed entries (including discarded).
func (l *queryLog) searchFiles(params *searchParams) ([]*logEntry, time.Time, int) {
entries := make([]*logEntry, 0)
oldest := time.Time{}
r, err := l.openReader()
if err != nil {
log.Error("Failed to open qlog reader: %v", err)
return entries, oldest, 0
}
defer r.Close()
if params.olderThan.IsZero() {
err = r.SeekStart()
} else {
err = r.Seek(params.olderThan.UnixNano())
if err == nil {
// Move to the next record right away:
// the one matching the "olderThan" parameter itself is not needed,
// we need only the one next to it
_, err = r.ReadNext()
}
}
if err != nil {
log.Debug("Cannot Seek() to %v: %v", params.olderThan, err)
return entries, oldest, 0
}
totalLimit := params.offset + params.limit
total := 0
oldestNano := int64(0)
// By default, we do not scan more than "maxFileScanEntries" at once
// The idea is to make search calls faster so that the UI could handle it and show something
// This behavior can be overridden if "maxFileScanEntries" is set to 0
for total < params.maxFileScanEntries || params.maxFileScanEntries <= 0 {
entry, ts, err := l.readNextEntry(r, params)
if err == io.EOF {
// there's nothing to read anymore
break
}
oldestNano = ts
total++
if entry != nil {
entries = append(entries, entry)
if len(entries) == totalLimit {
// Do not read more than "totalLimit" records at once
break
}
}
}
if oldestNano != 0 {
oldest = time.Unix(0, oldestNano)
}
return entries, oldest, total
}
// readNextEntry - reads the next log entry and checks if it matches the search criteria (searchParams)
//
// returns:
// * the log entry that matches the search criteria, or nil if it was discarded (or if there's nothing left to read)
// * the timestamp of the processed log entry
// * an error if we cannot read anymore
func (l *queryLog) readNextEntry(r *QLogReader, params *searchParams) (*logEntry, int64, error) {
line, err := r.ReadNext()
if err != nil {
return nil, 0, err
}
// Read the log record timestamp right away
timestamp := readQLogTimestamp(line)
// Quick check without deserializing log entry
if !params.quickMatch(line) {
return nil, timestamp, nil
}
entry := logEntry{}
decodeLogEntry(&entry, line)
// Full check of the deserialized log entry
if !params.match(&entry) {
return nil, timestamp, nil
}
return &entry, timestamp, nil
}
// openReader - opens QLogReader instance
func (l *queryLog) openReader() (*QLogReader, error) {
files := make([]string, 0)
if util.FileExists(l.logFile + ".1") {
files = append(files, l.logFile+".1")
}
if util.FileExists(l.logFile) {
files = append(files, l.logFile)
}
return NewQLogReader(files)
}


@@ -0,0 +1,142 @@
package querylog
import (
"strings"
"github.com/AdguardTeam/AdGuardHome/internal/dnsfilter"
)
type criteriaType int
const (
ctDomainOrClient criteriaType = iota // domain name or client IP address
ctFilteringStatus // filtering status
)
const (
filteringStatusAll = "all"
filteringStatusFiltered = "filtered" // all kinds of filtering
filteringStatusBlocked = "blocked" // blocked or blocked services
filteringStatusBlockedService = "blocked_services" // blocked
filteringStatusBlockedSafebrowsing = "blocked_safebrowsing" // blocked by safebrowsing
filteringStatusBlockedParental = "blocked_parental" // blocked by parental control
filteringStatusWhitelisted = "whitelisted" // whitelisted
filteringStatusRewritten = "rewritten" // all kinds of rewrites
filteringStatusSafeSearch = "safe_search" // enforced safe search
filteringStatusProcessed = "processed" // not blocked, not white-listed entries
)
// filteringStatusValues -- array with all possible filteringStatus values
var filteringStatusValues = []string{
filteringStatusAll, filteringStatusFiltered, filteringStatusBlocked,
filteringStatusBlockedService, filteringStatusBlockedSafebrowsing, filteringStatusBlockedParental,
filteringStatusWhitelisted, filteringStatusRewritten, filteringStatusSafeSearch,
filteringStatusProcessed,
}
// searchCriteria - a single search criterion; every search request may contain
// a list of them, and we use each one to match the query
type searchCriteria struct {
criteriaType criteriaType // type of the criteria
strict bool // should we strictly match (equality) or not (indexOf)
value string // search criteria value
}
// quickMatch - quickly checks if the log entry matches this search criteria.
// The goal is to do it as quickly as possible without de-serializing the entry.
func (c *searchCriteria) quickMatch(line string) bool {
// note that we do this only for a limited set of criteria
switch c.criteriaType {
case ctDomainOrClient:
return c.quickMatchJSONValue(line, "QH") ||
c.quickMatchJSONValue(line, "IP")
default:
return true
}
}
// quickMatchJSONValue - helper used by quickMatch
func (c *searchCriteria) quickMatchJSONValue(line string, propertyName string) bool {
val := readJSONValue(line, propertyName)
if len(val) == 0 {
return false
}
val = strings.ToLower(val)
searchVal := strings.ToLower(c.value)
if c.strict && searchVal == val {
return true
}
if !c.strict && strings.Contains(val, searchVal) {
return true
}
return false
}
// match - checks if the log entry matches this search criteria
// nolint(gocyclo)
func (c *searchCriteria) match(entry *logEntry) bool {
switch c.criteriaType {
case ctDomainOrClient:
qhost := strings.ToLower(entry.QHost)
searchVal := strings.ToLower(c.value)
if c.strict && qhost == searchVal {
return true
}
if !c.strict && strings.Contains(qhost, searchVal) {
return true
}
if c.strict && entry.IP == c.value {
return true
}
if !c.strict && strings.Contains(entry.IP, c.value) {
return true
}
return false
case ctFilteringStatus:
res := entry.Result
switch c.value {
case filteringStatusAll:
return true
case filteringStatusFiltered:
return res.IsFiltered ||
res.Reason == dnsfilter.NotFilteredWhiteList ||
res.Reason == dnsfilter.ReasonRewrite ||
res.Reason == dnsfilter.RewriteEtcHosts
case filteringStatusBlocked:
return res.IsFiltered &&
(res.Reason == dnsfilter.FilteredBlackList ||
res.Reason == dnsfilter.FilteredBlockedService)
case filteringStatusBlockedService:
return res.IsFiltered && res.Reason == dnsfilter.FilteredBlockedService
case filteringStatusBlockedParental:
return res.IsFiltered && res.Reason == dnsfilter.FilteredParental
case filteringStatusBlockedSafebrowsing:
return res.IsFiltered && res.Reason == dnsfilter.FilteredSafeBrowsing
case filteringStatusWhitelisted:
return res.Reason == dnsfilter.NotFilteredWhiteList
case filteringStatusRewritten:
return res.Reason == dnsfilter.ReasonRewrite ||
res.Reason == dnsfilter.RewriteEtcHosts
case filteringStatusSafeSearch:
return res.IsFiltered && res.Reason == dnsfilter.FilteredSafeSearch
case filteringStatusProcessed:
return !(res.Reason == dnsfilter.FilteredBlackList ||
res.Reason == dnsfilter.FilteredBlockedService ||
res.Reason == dnsfilter.NotFilteredWhiteList)
default:
return false
}
}
return false
}


@@ -0,0 +1,57 @@
package querylog
import "time"
// searchParams represent the search query sent by the client
type searchParams struct {
// searchCriteria - list of search criteria that we use to get filter results
searchCriteria []searchCriteria
// olderThan - return entries that are older than this value;
// if not set, disregard it and return any entry
olderThan time.Time
offset int // offset for the search
limit int // limit the number of records returned
maxFileScanEntries int // maximum log entries to scan in query log files. if 0 - no limit
}
// newSearchParams - creates an empty instance of searchParams
func newSearchParams() *searchParams {
return &searchParams{
// default max log entries to return
limit: 500,
// by default, we scan up to 50k entries at once
maxFileScanEntries: 50000,
}
}
// quickMatch - quickly checks if the line matches the searchParams.
// This method does not guarantee anything; the goal is to do a quick check
// without deserializing anything.
func (s *searchParams) quickMatch(line string) bool {
for _, c := range s.searchCriteria {
if !c.quickMatch(line) {
return false
}
}
return true
}
// match - checks if the logEntry matches the searchParams
func (s *searchParams) match(entry *logEntry) bool {
if !s.olderThan.IsZero() && entry.Time.UnixNano() >= s.olderThan.UnixNano() {
// Ignore entries newer than what was requested
return false
}
for _, c := range s.searchCriteria {
if !c.match(entry) {
return false
}
}
return true
}