Pull request: 2639 use testify require vol.4

Merge in DNS/adguard-home from 2639-testify-require-4 to master

Closes #2639.

Squashed commit of the following:

commit 0bb9125f42ab6d2511c1b8e481112aa5edd581d9
Merge: 0e9e9ed1 2c9992e0
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Thu Mar 11 15:47:21 2021 +0300

    Merge branch 'master' into 2639-testify-require-4

commit 0e9e9ed16ae13ce648b5e1da6ffd123df911c2d7
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Wed Mar 10 12:43:15 2021 +0300

    home: rm deletion error check

commit 6bfbbcd2b7f9197a06856f9e6b959c2e1c4b8353
Merge: c8ebe541 8811c881
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Wed Mar 10 12:30:07 2021 +0300

    Merge branch 'master' into 2639-testify-require-4

commit c8ebe54142bba780226f76ddb72e33664ed28f30
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Wed Mar 10 12:28:43 2021 +0300

    home: imp tests

commit f0e1db456f02df5f5f56ca93e7bd40a48475b38c
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Fri Mar 5 14:06:41 2021 +0300

    dnsforward: imp tests

commit 4528246105ed06471a8778abbe8e5c30fc5483d5
Merge: 54b08d9c 90ebc4d8
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Thu Mar 4 18:17:52 2021 +0300

    Merge branch 'master' into 2639-testify-require-4

commit 54b08d9c980b8d69d019a1a1b3931aa048275691
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Thu Feb 11 13:17:05 2021 +0300

    dnsfilter: imp tests
This commit is contained in:
Eugene Burkov
2021-03-11 17:32:58 +03:00
parent 2c9992e0cc
commit dfdbfee4fd
19 changed files with 1375 additions and 1267 deletions

View File

@@ -20,10 +20,18 @@ func TestMain(m *testing.M) {
aghtest.DiscardLogOutput(m)
}
func prepareTestDir() string {
func prepareTestDir(t *testing.T) string {
t.Helper()
const dir = "./agh-test"
_ = os.RemoveAll(dir)
_ = os.MkdirAll(dir, 0o755)
require.Nil(t, os.RemoveAll(dir))
// TODO(e.burkov): Replace with testing.TempDir after updating Go
// version to 1.16.
require.Nil(t, os.MkdirAll(dir, 0o755))
t.Cleanup(func() { require.Nil(t, os.RemoveAll(dir)) })
return dir
}
@@ -47,8 +55,7 @@ func TestNewSessionToken(t *testing.T) {
}
func TestAuth(t *testing.T) {
dir := prepareTestDir()
t.Cleanup(func() { _ = os.RemoveAll(dir) })
dir := prepareTestDir(t)
fn := filepath.Join(dir, "sessions.db")
users := []User{{
@@ -123,8 +130,7 @@ func (w *testResponseWriter) WriteHeader(statusCode int) {
}
func TestAuthHTTP(t *testing.T) {
dir := prepareTestDir()
defer func() { _ = os.RemoveAll(dir) }()
dir := prepareTestDir(t)
fn := filepath.Join(dir, "sessions.db")
users := []User{

View File

@@ -4,40 +4,38 @@ import (
"encoding/binary"
"io/ioutil"
"net/http"
"os"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestAuthGL(t *testing.T) {
dir := prepareTestDir()
defer func() { _ = os.RemoveAll(dir) }()
dir := prepareTestDir(t)
GLMode = true
t.Cleanup(func() {
GLMode = false
})
glFilePrefix = dir + "/gl_token_"
tval := uint32(1)
data := make([]byte, 4)
putFunc := binary.BigEndian.PutUint32
if archIsLittleEndian() {
binary.LittleEndian.PutUint32(data, tval)
} else {
binary.BigEndian.PutUint32(data, tval)
putFunc = binary.LittleEndian.PutUint32
}
assert.Nil(t, ioutil.WriteFile(glFilePrefix+"test", data, 0o644))
data := make([]byte, 4)
putFunc(data, 1)
require.Nil(t, ioutil.WriteFile(glFilePrefix+"test", data, 0o644))
assert.False(t, glCheckToken("test"))
tval = uint32(time.Now().UTC().Unix() + 60)
data = make([]byte, 4)
if archIsLittleEndian() {
binary.LittleEndian.PutUint32(data, tval)
} else {
binary.BigEndian.PutUint32(data, tval)
}
assert.Nil(t, ioutil.WriteFile(glFilePrefix+"test", data, 0o644))
putFunc(data, uint32(time.Now().UTC().Unix()+60))
require.Nil(t, ioutil.WriteFile(glFilePrefix+"test", data, 0o644))
r, _ := http.NewRequest(http.MethodGet, "http://localhost/", nil)
r.AddCookie(&http.Cookie{Name: glCookieName, Value: "test"})
assert.True(t, glProcessCookie(r))
GLMode = false
}

View File

@@ -9,6 +9,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestClients(t *testing.T) {
@@ -24,8 +25,8 @@ func TestClients(t *testing.T) {
}
ok, err := clients.Add(c)
require.Nil(t, err)
assert.True(t, ok)
assert.Nil(t, err)
c = &Client{
IDs: []string{"2.2.2.2"},
@@ -33,110 +34,99 @@ func TestClients(t *testing.T) {
}
ok, err = clients.Add(c)
require.Nil(t, err)
assert.True(t, ok)
assert.Nil(t, err)
c, ok = clients.Find("1.1.1.1")
assert.True(t, ok)
require.True(t, ok)
assert.Equal(t, "client1", c.Name)
c, ok = clients.Find("1:2:3::4")
assert.True(t, ok)
require.True(t, ok)
assert.Equal(t, "client1", c.Name)
c, ok = clients.Find("2.2.2.2")
assert.True(t, ok)
require.True(t, ok)
assert.Equal(t, "client2", c.Name)
assert.True(t, !clients.Exists("1.2.3.4", ClientSourceHostsFile))
assert.False(t, clients.Exists("1.2.3.4", ClientSourceHostsFile))
assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile))
assert.True(t, clients.Exists("2.2.2.2", ClientSourceHostsFile))
})
t.Run("add_fail_name", func(t *testing.T) {
c := &Client{
ok, err := clients.Add(&Client{
IDs: []string{"1.2.3.5"},
Name: "client1",
}
ok, err := clients.Add(c)
})
require.Nil(t, err)
assert.False(t, ok)
assert.Nil(t, err)
})
t.Run("add_fail_ip", func(t *testing.T) {
c := &Client{
ok, err := clients.Add(&Client{
IDs: []string{"2.2.2.2"},
Name: "client3",
}
ok, err := clients.Add(c)
})
require.NotNil(t, err)
assert.False(t, ok)
assert.NotNil(t, err)
})
t.Run("update_fail_name", func(t *testing.T) {
c := &Client{
err := clients.Update("client3", &Client{
IDs: []string{"1.2.3.0"},
Name: "client3",
}
})
require.NotNil(t, err)
err := clients.Update("client3", c)
assert.NotNil(t, err)
c = &Client{
err = clients.Update("client3", &Client{
IDs: []string{"1.2.3.0"},
Name: "client2",
}
err = clients.Update("client3", c)
})
assert.NotNil(t, err)
})
t.Run("update_fail_ip", func(t *testing.T) {
c := &Client{
err := clients.Update("client1", &Client{
IDs: []string{"2.2.2.2"},
Name: "client1",
}
err := clients.Update("client1", c)
})
assert.NotNil(t, err)
})
t.Run("update_success", func(t *testing.T) {
c := &Client{
err := clients.Update("client1", &Client{
IDs: []string{"1.1.1.2"},
Name: "client1",
}
})
require.Nil(t, err)
err := clients.Update("client1", c)
assert.Nil(t, err)
assert.True(t, !clients.Exists("1.1.1.1", ClientSourceHostsFile))
assert.False(t, clients.Exists("1.1.1.1", ClientSourceHostsFile))
assert.True(t, clients.Exists("1.1.1.2", ClientSourceHostsFile))
c = &Client{
err = clients.Update("client1", &Client{
IDs: []string{"1.1.1.2"},
Name: "client1-renamed",
UseOwnSettings: true,
}
err = clients.Update("client1", c)
assert.Nil(t, err)
})
require.Nil(t, err)
c, ok := clients.Find("1.1.1.2")
assert.True(t, ok)
require.True(t, ok)
assert.Equal(t, "client1-renamed", c.Name)
assert.True(t, c.UseOwnSettings)
assert.Nil(t, clients.list["client1"])
if assert.Len(t, c.IDs, 1) {
assert.Equal(t, "1.1.1.2", c.IDs[0])
}
nilCli, ok := clients.list["client1"]
require.False(t, ok)
assert.Nil(t, nilCli)
require.Len(t, c.IDs, 1)
assert.Equal(t, "1.1.1.2", c.IDs[0])
})
t.Run("del_success", func(t *testing.T) {
ok := clients.Del("client1-renamed")
assert.True(t, ok)
require.True(t, ok)
assert.False(t, clients.Exists("1.1.1.2", ClientSourceHostsFile))
})
@@ -147,146 +137,155 @@ func TestClients(t *testing.T) {
t.Run("addhost_success", func(t *testing.T) {
ok, err := clients.AddHost("1.1.1.1", "host", ClientSourceARP)
require.Nil(t, err)
assert.True(t, ok)
assert.Nil(t, err)
ok, err = clients.AddHost("1.1.1.1", "host2", ClientSourceARP)
require.Nil(t, err)
assert.True(t, ok)
assert.Nil(t, err)
ok, err = clients.AddHost("1.1.1.1", "host3", ClientSourceHostsFile)
require.Nil(t, err)
assert.True(t, ok)
assert.Nil(t, err)
assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile))
})
t.Run("addhost_fail", func(t *testing.T) {
ok, err := clients.AddHost("1.1.1.1", "host1", ClientSourceRDNS)
require.Nil(t, err)
assert.False(t, ok)
assert.Nil(t, err)
})
}
func TestClientsWhois(t *testing.T) {
var c *Client
clients := clientsContainer{}
clients.testing = true
clients := clientsContainer{
testing: true,
}
clients.Init(nil, nil, nil)
whois := [][]string{{"orgname", "orgname-val"}, {"country", "country-val"}}
// set whois info on new client
clients.SetWhoisInfo("1.1.1.255", whois)
if assert.NotNil(t, clients.ipHost["1.1.1.255"]) {
t.Run("new_client", func(t *testing.T) {
clients.SetWhoisInfo("1.1.1.255", whois)
require.NotNil(t, clients.ipHost["1.1.1.255"])
h := clients.ipHost["1.1.1.255"]
if assert.Len(t, h.WhoisInfo, 2) && assert.Len(t, h.WhoisInfo[0], 2) {
assert.Equal(t, "orgname-val", h.WhoisInfo[0][1])
}
}
// set whois info on existing auto-client
_, _ = clients.AddHost("1.1.1.1", "host", ClientSourceRDNS)
clients.SetWhoisInfo("1.1.1.1", whois)
if assert.NotNil(t, clients.ipHost["1.1.1.1"]) {
require.Len(t, h.WhoisInfo, 2)
require.Len(t, h.WhoisInfo[0], 2)
assert.Equal(t, "orgname-val", h.WhoisInfo[0][1])
})
t.Run("existing_auto-client", func(t *testing.T) {
ok, err := clients.AddHost("1.1.1.1", "host", ClientSourceRDNS)
require.Nil(t, err)
assert.True(t, ok)
clients.SetWhoisInfo("1.1.1.1", whois)
require.NotNil(t, clients.ipHost["1.1.1.1"])
h := clients.ipHost["1.1.1.1"]
if assert.Len(t, h.WhoisInfo, 2) && assert.Len(t, h.WhoisInfo[0], 2) {
assert.Equal(t, "orgname-val", h.WhoisInfo[0][1])
}
}
// Check that we cannot set whois info on a manually-added client
c = &Client{
IDs: []string{"1.1.1.2"},
Name: "client1",
}
_, _ = clients.Add(c)
clients.SetWhoisInfo("1.1.1.2", whois)
assert.Nil(t, clients.ipHost["1.1.1.2"])
_ = clients.Del("client1")
require.Len(t, h.WhoisInfo, 2)
require.Len(t, h.WhoisInfo[0], 2)
assert.Equal(t, "orgname-val", h.WhoisInfo[0][1])
})
t.Run("can't_set_manually-added", func(t *testing.T) {
ok, err := clients.Add(&Client{
IDs: []string{"1.1.1.2"},
Name: "client1",
})
require.Nil(t, err)
assert.True(t, ok)
clients.SetWhoisInfo("1.1.1.2", whois)
require.Nil(t, clients.ipHost["1.1.1.2"])
assert.True(t, clients.Del("client1"))
})
}
func TestClientsAddExisting(t *testing.T) {
var c *Client
clients := clientsContainer{}
clients.testing = true
clients := clientsContainer{
testing: true,
}
clients.Init(nil, nil, nil)
// some test variables
mac, _ := net.ParseMAC("aa:aa:aa:aa:aa:aa")
testIP := "1.2.3.4"
t.Run("simple", func(t *testing.T) {
// Add a client.
ok, err := clients.Add(&Client{
IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa", "2.2.2.0/24"},
Name: "client1",
})
require.Nil(t, err)
assert.True(t, ok)
// add a client
c = &Client{
IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa", "2.2.2.0/24"},
Name: "client1",
}
ok, err := clients.Add(c)
assert.True(t, ok)
assert.Nil(t, err)
// add an auto-client with the same IP - it's allowed
ok, err = clients.AddHost("1.1.1.1", "test", ClientSourceRDNS)
assert.True(t, ok)
assert.Nil(t, err)
// now some more complicated stuff
// first, init a DHCP server with a single static lease
config := dhcpd.ServerConfig{
DBFilePath: "leases.db",
}
defer func() { _ = os.Remove("leases.db") }()
clients.dhcpServer = dhcpd.Create(config)
err = clients.dhcpServer.AddStaticLease(dhcpd.Lease{
HWAddr: mac,
IP: net.ParseIP(testIP).To4(),
Hostname: "testhost",
Expiry: time.Now().Add(time.Hour),
// Now add an auto-client with the same IP.
ok, err = clients.AddHost("1.1.1.1", "test", ClientSourceRDNS)
require.Nil(t, err)
assert.True(t, ok)
})
assert.Nil(t, err)
// add a new client with the same IP as for a client with MAC
c = &Client{
IDs: []string{testIP},
Name: "client2",
}
ok, err = clients.Add(c)
assert.True(t, ok)
assert.Nil(t, err)
t.Run("complicated", func(t *testing.T) {
testIP := net.IP{1, 2, 3, 4}
// add a new client with the IP from the client1's IP range
c = &Client{
IDs: []string{"2.2.2.2"},
Name: "client3",
}
ok, err = clients.Add(c)
assert.True(t, ok)
assert.Nil(t, err)
// First, init a DHCP server with a single static lease.
config := dhcpd.ServerConfig{
DBFilePath: "leases.db",
}
clients.dhcpServer = dhcpd.Create(config)
t.Cleanup(func() { _ = os.Remove("leases.db") })
err := clients.dhcpServer.AddStaticLease(dhcpd.Lease{
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
IP: testIP,
Hostname: "testhost",
Expiry: time.Now().Add(time.Hour),
})
require.Nil(t, err)
// Add a new client with the same IP as for a client with MAC.
ok, err := clients.Add(&Client{
IDs: []string{testIP.String()},
Name: "client2",
})
require.Nil(t, err)
assert.True(t, ok)
// Add a new client with the IP from the first client's IP
// range.
ok, err = clients.Add(&Client{
IDs: []string{"2.2.2.2"},
Name: "client3",
})
require.Nil(t, err)
assert.True(t, ok)
})
}
func TestClientsCustomUpstream(t *testing.T) {
clients := clientsContainer{}
clients.testing = true
clients := clientsContainer{
testing: true,
}
clients.Init(nil, nil, nil)
// add client with upstreams
c := &Client{
// Add client with upstreams.
ok, err := clients.Add(&Client{
IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa"},
Name: "client1",
Upstreams: []string{
"1.1.1.1",
"[/example.org/]8.8.8.8",
},
}
ok, err := clients.Add(c)
assert.Nil(t, err)
})
require.Nil(t, err)
assert.True(t, ok)
config := clients.FindUpstreams("1.2.3.4")
assert.Nil(t, config)
config = clients.FindUpstreams("1.1.1.1")
assert.NotNil(t, config)
assert.Equal(t, 1, len(config.Upstreams))
assert.Equal(t, 1, len(config.DomainReservedUpstreams))
require.NotNil(t, config)
assert.Len(t, config.Upstreams, 1)
assert.Len(t, config.DomainReservedUpstreams, 1)
}

View File

@@ -3,32 +3,12 @@ package home
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
/* Tests performed:
. Bad certificate
. Bad private key
. Valid certificate & private key */
func TestValidateCertificates(t *testing.T) {
var data tlsConfigStatus
// bad cert
data = validateCertificates("bad cert", "", "")
if !(data.WarningValidation != "" &&
!data.ValidCert &&
!data.ValidChain) {
t.Fatalf("bad cert: validateCertificates(): %v", data)
}
// bad priv key
data = validateCertificates("", "bad priv key", "")
if !(data.WarningValidation != "" &&
!data.ValidKey) {
t.Fatalf("bad priv key: validateCertificates(): %v", data)
}
// valid cert & priv key
CertificateChain := `-----BEGIN CERTIFICATE-----
const (
CertificateChain = `-----BEGIN CERTIFICATE-----
MIICKzCCAZSgAwIBAgIJAMT9kPVJdM7LMA0GCSqGSIb3DQEBCwUAMC0xFDASBgNV
BAoMC0FkR3VhcmQgTHRkMRUwEwYDVQQDDAxBZEd1YXJkIEhvbWUwHhcNMTkwMjI3
MDkyNDIzWhcNNDYwNzE0MDkyNDIzWjAtMRQwEgYDVQQKDAtBZEd1YXJkIEx0ZDEV
@@ -42,7 +22,7 @@ LwlXfbakf7qkVTlCNXgoY7RaJ8rJdPgOZPoCTVToEhT6u/cb1c2qp8QB0dNExDna
b0Z+dnODTZqQOJo6z/wIXlcUrnR4cQVvytXt8lFn+26l6Y6EMI26twC/xWr+1swq
Muj4FeWHVDerquH4yMr1jsYLD3ci+kc5sbIX6TfVxQ==
-----END CERTIFICATE-----`
PrivateKey := `-----BEGIN PRIVATE KEY-----
PrivateKey = `-----BEGIN PRIVATE KEY-----
MIICeAIBADANBgkqhkiG9w0BAQEFAASCAmIwggJeAgEAAoGBALC/BSc8mI68tw5p
aYa7pjrySwWvXeetcFywOWHGVfLw9qiFWLdfESa3Y6tWMpZAXD9t1Xh9n211YUBV
FGSB4ZshnM/tgEPU6t787lJD4NsIIRp++MkJxdAitN4oUTqL0bdpIwezQ/CrYuBX
@@ -58,20 +38,35 @@ O5EX70gpeGQMPDK0QSWpaazg956njJSDbNCFM4BccrdQbJu1cW4qOsfBAkAMgZuG
O88slmgTRHX4JGFmy3rrLiHNI2BbJSuJ++Yllz8beVzh6NfvuY+HKRCmPqoBPATU
kXS9jgARhhiWXJrk
-----END PRIVATE KEY-----`
data = validateCertificates(CertificateChain, PrivateKey, "")
notBefore, _ := time.Parse(time.RFC3339, "2019-02-27T09:24:23Z")
notAfter, _ := time.Parse(time.RFC3339, "2046-07-14T09:24:23Z")
if !(data.WarningValidation != "" /* self signed */ &&
data.ValidCert &&
!data.ValidChain &&
data.ValidKey &&
data.KeyType == "RSA" &&
data.Subject == "CN=AdGuard Home,O=AdGuard Ltd" &&
data.Issuer == "CN=AdGuard Home,O=AdGuard Ltd" &&
data.NotBefore.Equal(notBefore) &&
data.NotAfter.Equal(notAfter) &&
// data.DNSNames[0] == &&
data.ValidPair) {
t.Fatalf("valid cert & priv key: validateCertificates(): %v", data)
}
)
func TestValidateCertificates(t *testing.T) {
t.Run("bad_certificate", func(t *testing.T) {
data := validateCertificates("bad cert", "", "")
assert.NotEmpty(t, data.WarningValidation)
assert.False(t, data.ValidCert)
assert.False(t, data.ValidChain)
})
t.Run("bad_private_key", func(t *testing.T) {
data := validateCertificates("", "bad priv key", "")
assert.NotEmpty(t, data.WarningValidation)
assert.False(t, data.ValidKey)
})
t.Run("valid", func(t *testing.T) {
data := validateCertificates(CertificateChain, PrivateKey, "")
notBefore, _ := time.Parse(time.RFC3339, "2019-02-27T09:24:23Z")
notAfter, _ := time.Parse(time.RFC3339, "2046-07-14T09:24:23Z")
assert.NotEmpty(t, data.WarningValidation)
assert.True(t, data.ValidCert)
assert.False(t, data.ValidChain)
assert.True(t, data.ValidKey)
assert.Equal(t, "RSA", data.KeyType)
assert.Equal(t, "CN=AdGuard Home,O=AdGuard Ltd", data.Subject)
assert.Equal(t, "CN=AdGuard Home,O=AdGuard Ltd", data.Issuer)
assert.Equal(t, notBefore, data.NotBefore)
assert.Equal(t, notAfter, data.NotAfter)
assert.True(t, data.ValidPair)
})
}

View File

@@ -9,38 +9,47 @@ import (
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func testStartFilterListener() net.Listener {
func testStartFilterListener(t *testing.T) net.Listener {
t.Helper()
const content = `||example.org^$third-party
# Inline comment example
||example.com^$third-party
0.0.0.0 example.com
`
mux := http.NewServeMux()
mux.HandleFunc("/filters/1.txt", func(w http.ResponseWriter, r *http.Request) {
content := `||example.org^$third-party
# Inline comment example
||example.com^$third-party
0.0.0.0 example.com
`
_, _ = w.Write([]byte(content))
_, werr := w.Write([]byte(content))
assert.Nil(t, werr)
})
listener, err := net.Listen("tcp", ":0")
if err != nil {
panic(err)
}
require.Nil(t, err)
go func() {
_ = http.Serve(listener, mux)
}()
t.Cleanup(func() {
assert.Nil(t, listener.Close())
})
go func() { _ = http.Serve(listener, mux) }()
return listener
}
func TestFilters(t *testing.T) {
l := testStartFilterListener()
defer func() { _ = l.Close() }()
l := testStartFilterListener(t)
dir := prepareTestDir(t)
dir := prepareTestDir()
defer func() { _ = os.RemoveAll(dir) }()
Context = homeContext{}
Context.workDir = dir
Context.client = &http.Client{
Timeout: 5 * time.Second,
Context = homeContext{
workDir: dir,
client: &http.Client{
Timeout: 5 * time.Second,
},
}
Context.filters.Init()
@@ -48,20 +57,20 @@ func TestFilters(t *testing.T) {
URL: fmt.Sprintf("http://127.0.0.1:%d/filters/1.txt", l.Addr().(*net.TCPAddr).Port),
}
// download
// Download.
ok, err := Context.filters.update(&f)
assert.Nil(t, err)
assert.True(t, ok)
require.Nil(t, err)
require.True(t, ok)
assert.Equal(t, 3, f.RulesCount)
// refresh
// Refresh.
ok, err = Context.filters.update(&f)
assert.False(t, ok)
assert.Nil(t, err)
require.Nil(t, err)
require.False(t, ok)
err = Context.filters.load(&f)
assert.Nil(t, err)
require.Nil(t, err)
f.unload()
_ = os.Remove(f.Path())
require.Nil(t, os.Remove(f.Path()))
}

View File

@@ -114,8 +114,7 @@ func TestHome(t *testing.T) {
// Init new context
Context = homeContext{}
dir := prepareTestDir()
defer func() { _ = os.RemoveAll(dir) }()
dir := prepareTestDir(t)
fn := filepath.Join(dir, "AdGuardHome.yaml")
// Prepare the test config

View File

@@ -39,21 +39,21 @@ func TestLimitRequestBody(t *testing.T) {
wantErr: nil,
}}
makeHandler := func(err *error) http.HandlerFunc {
makeHandler := func(t *testing.T, err *error) http.HandlerFunc {
t.Helper()
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var b []byte
b, *err = ioutil.ReadAll(r.Body)
_, werr := w.Write(b)
if werr != nil {
panic(werr)
}
require.Nil(t, werr)
})
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var err error
handler := makeHandler(&err)
handler := makeHandler(t, &err)
lim := limitRequestBody(handler)
req := httptest.NewRequest(http.MethodPost, "https://www.example.com", strings.NewReader(tc.body))
@@ -61,7 +61,7 @@ func TestLimitRequestBody(t *testing.T) {
lim.ServeHTTP(res, req)
require.Equal(t, tc.wantErr, err)
assert.Equal(t, tc.wantErr, err)
assert.Equal(t, tc.want, res.Body.Bytes())
})
}

View File

@@ -6,29 +6,29 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"howett.net/plist"
)
func TestHandleMobileConfigDOH(t *testing.T) {
t.Run("success", func(t *testing.T) {
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/doh.mobileconfig?host=example.org", nil)
assert.Nil(t, err)
require.Nil(t, err)
w := httptest.NewRecorder()
handleMobileConfigDOH(w, r)
assert.Equal(t, http.StatusOK, w.Code)
require.Equal(t, http.StatusOK, w.Code)
var mc mobileConfig
_, err = plist.Unmarshal(w.Body.Bytes(), &mc)
assert.Nil(t, err)
require.Nil(t, err)
if assert.Len(t, mc.PayloadContent, 1) {
assert.Equal(t, "example.org DoH", mc.PayloadContent[0].Name)
assert.Equal(t, "example.org DoH", mc.PayloadContent[0].PayloadDisplayName)
assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName)
assert.Equal(t, "https://example.org/dns-query", mc.PayloadContent[0].DNSSettings.ServerURL)
}
require.Len(t, mc.PayloadContent, 1)
assert.Equal(t, "example.org DoH", mc.PayloadContent[0].Name)
assert.Equal(t, "example.org DoH", mc.PayloadContent[0].PayloadDisplayName)
assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName)
assert.Equal(t, "https://example.org/dns-query", mc.PayloadContent[0].DNSSettings.ServerURL)
})
t.Run("success_no_host", func(t *testing.T) {
@@ -40,23 +40,22 @@ func TestHandleMobileConfigDOH(t *testing.T) {
}
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/doh.mobileconfig", nil)
assert.Nil(t, err)
require.Nil(t, err)
w := httptest.NewRecorder()
handleMobileConfigDOH(w, r)
assert.Equal(t, http.StatusOK, w.Code)
require.Equal(t, http.StatusOK, w.Code)
var mc mobileConfig
_, err = plist.Unmarshal(w.Body.Bytes(), &mc)
assert.Nil(t, err)
require.Nil(t, err)
if assert.Len(t, mc.PayloadContent, 1) {
assert.Equal(t, "example.org DoH", mc.PayloadContent[0].Name)
assert.Equal(t, "example.org DoH", mc.PayloadContent[0].PayloadDisplayName)
assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName)
assert.Equal(t, "https://example.org/dns-query", mc.PayloadContent[0].DNSSettings.ServerURL)
}
require.Len(t, mc.PayloadContent, 1)
assert.Equal(t, "example.org DoH", mc.PayloadContent[0].Name)
assert.Equal(t, "example.org DoH", mc.PayloadContent[0].PayloadDisplayName)
assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName)
assert.Equal(t, "https://example.org/dns-query", mc.PayloadContent[0].DNSSettings.ServerURL)
})
t.Run("error_no_host", func(t *testing.T) {
@@ -66,7 +65,7 @@ func TestHandleMobileConfigDOH(t *testing.T) {
Context.tls = &TLSMod{conf: tlsConfigSettings{}}
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/doh.mobileconfig", nil)
assert.Nil(t, err)
require.Nil(t, err)
w := httptest.NewRecorder()
@@ -76,45 +75,43 @@ func TestHandleMobileConfigDOH(t *testing.T) {
t.Run("client_id", func(t *testing.T) {
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/doh.mobileconfig?host=example.org&client_id=cli42", nil)
assert.Nil(t, err)
require.Nil(t, err)
w := httptest.NewRecorder()
handleMobileConfigDOH(w, r)
assert.Equal(t, http.StatusOK, w.Code)
require.Equal(t, http.StatusOK, w.Code)
var mc mobileConfig
_, err = plist.Unmarshal(w.Body.Bytes(), &mc)
assert.Nil(t, err)
require.Nil(t, err)
if assert.Len(t, mc.PayloadContent, 1) {
assert.Equal(t, "example.org DoH", mc.PayloadContent[0].Name)
assert.Equal(t, "example.org DoH", mc.PayloadContent[0].PayloadDisplayName)
assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName)
assert.Equal(t, "https://example.org/dns-query/cli42", mc.PayloadContent[0].DNSSettings.ServerURL)
}
require.Len(t, mc.PayloadContent, 1)
assert.Equal(t, "example.org DoH", mc.PayloadContent[0].Name)
assert.Equal(t, "example.org DoH", mc.PayloadContent[0].PayloadDisplayName)
assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName)
assert.Equal(t, "https://example.org/dns-query/cli42", mc.PayloadContent[0].DNSSettings.ServerURL)
})
}
func TestHandleMobileConfigDOT(t *testing.T) {
t.Run("success", func(t *testing.T) {
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/dot.mobileconfig?host=example.org", nil)
assert.Nil(t, err)
require.Nil(t, err)
w := httptest.NewRecorder()
handleMobileConfigDOT(w, r)
assert.Equal(t, http.StatusOK, w.Code)
require.Equal(t, http.StatusOK, w.Code)
var mc mobileConfig
_, err = plist.Unmarshal(w.Body.Bytes(), &mc)
assert.Nil(t, err)
require.Nil(t, err)
if assert.Len(t, mc.PayloadContent, 1) {
assert.Equal(t, "example.org DoT", mc.PayloadContent[0].Name)
assert.Equal(t, "example.org DoT", mc.PayloadContent[0].PayloadDisplayName)
assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName)
}
require.Len(t, mc.PayloadContent, 1)
assert.Equal(t, "example.org DoT", mc.PayloadContent[0].Name)
assert.Equal(t, "example.org DoT", mc.PayloadContent[0].PayloadDisplayName)
assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName)
})
t.Run("success_no_host", func(t *testing.T) {
@@ -126,22 +123,21 @@ func TestHandleMobileConfigDOT(t *testing.T) {
}
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/dot.mobileconfig", nil)
assert.Nil(t, err)
require.Nil(t, err)
w := httptest.NewRecorder()
handleMobileConfigDOT(w, r)
assert.Equal(t, http.StatusOK, w.Code)
require.Equal(t, http.StatusOK, w.Code)
var mc mobileConfig
_, err = plist.Unmarshal(w.Body.Bytes(), &mc)
assert.Nil(t, err)
require.Nil(t, err)
if assert.Len(t, mc.PayloadContent, 1) {
assert.Equal(t, "example.org DoT", mc.PayloadContent[0].Name)
assert.Equal(t, "example.org DoT", mc.PayloadContent[0].PayloadDisplayName)
assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName)
}
require.Len(t, mc.PayloadContent, 1)
assert.Equal(t, "example.org DoT", mc.PayloadContent[0].Name)
assert.Equal(t, "example.org DoT", mc.PayloadContent[0].PayloadDisplayName)
assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName)
})
t.Run("error_no_host", func(t *testing.T) {
@@ -151,7 +147,7 @@ func TestHandleMobileConfigDOT(t *testing.T) {
Context.tls = &TLSMod{conf: tlsConfigSettings{}}
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/dot.mobileconfig", nil)
assert.Nil(t, err)
require.Nil(t, err)
w := httptest.NewRecorder()
@@ -161,21 +157,20 @@ func TestHandleMobileConfigDOT(t *testing.T) {
t.Run("client_id", func(t *testing.T) {
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/dot.mobileconfig?host=example.org&client_id=cli42", nil)
assert.Nil(t, err)
require.Nil(t, err)
w := httptest.NewRecorder()
handleMobileConfigDOT(w, r)
assert.Equal(t, http.StatusOK, w.Code)
require.Equal(t, http.StatusOK, w.Code)
var mc mobileConfig
_, err = plist.Unmarshal(w.Body.Bytes(), &mc)
assert.Nil(t, err)
require.Nil(t, err)
if assert.Len(t, mc.PayloadContent, 1) {
assert.Equal(t, "example.org DoT", mc.PayloadContent[0].Name)
assert.Equal(t, "example.org DoT", mc.PayloadContent[0].PayloadDisplayName)
assert.Equal(t, "cli42.example.org", mc.PayloadContent[0].DNSSettings.ServerName)
}
require.Len(t, mc.PayloadContent, 1)
assert.Equal(t, "example.org DoT", mc.PayloadContent[0].Name)
assert.Equal(t, "example.org DoT", mc.PayloadContent[0].PayloadDisplayName)
assert.Equal(t, "cli42.example.org", mc.PayloadContent[0].DNSSettings.ServerName)
})
}

View File

@@ -4,96 +4,74 @@ import (
"fmt"
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func testParseOk(t *testing.T, ss ...string) options {
func testParseOK(t *testing.T, ss ...string) options {
t.Helper()
o, _, err := parse("", ss)
if err != nil {
t.Fatal(err.Error())
}
require.Nil(t, err)
return o
}
func testParseErr(t *testing.T, descr string, ss ...string) {
t.Helper()
_, _, err := parse("", ss)
if err == nil {
t.Fatalf("expected an error because %s but no error returned", descr)
}
require.NotNilf(t, err, "expected an error because %s but no error returned", descr)
}
func testParseParamMissing(t *testing.T, param string) {
t.Helper()
testParseErr(t, fmt.Sprintf("%s parameter missing", param), param)
}
func TestParseVerbose(t *testing.T) {
if testParseOk(t).verbose {
t.Fatal("empty is not verbose")
}
if !testParseOk(t, "-v").verbose {
t.Fatal("-v is verbose")
}
if !testParseOk(t, "--verbose").verbose {
t.Fatal("--verbose is verbose")
}
assert.False(t, testParseOK(t).verbose, "empty is not verbose")
assert.True(t, testParseOK(t, "-v").verbose, "-v is verbose")
assert.True(t, testParseOK(t, "--verbose").verbose, "--verbose is verbose")
}
func TestParseConfigFilename(t *testing.T) {
if testParseOk(t).configFilename != "" {
t.Fatal("empty is no config filename")
}
if testParseOk(t, "-c", "path").configFilename != "path" {
t.Fatal("-c is config filename")
}
assert.Equal(t, "", testParseOK(t).configFilename, "empty is no config filename")
assert.Equal(t, "path", testParseOK(t, "-c", "path").configFilename, "-c is config filename")
testParseParamMissing(t, "-c")
if testParseOk(t, "--config", "path").configFilename != "path" {
t.Fatal("--configFilename is config filename")
}
assert.Equal(t, "path", testParseOK(t, "--config", "path").configFilename, "--config is config filename")
testParseParamMissing(t, "--config")
}
func TestParseWorkDir(t *testing.T) {
if testParseOk(t).workDir != "" {
t.Fatal("empty is no work dir")
}
if testParseOk(t, "-w", "path").workDir != "path" {
t.Fatal("-w is work dir")
}
assert.Equal(t, "", testParseOK(t).workDir, "empty is no work dir")
assert.Equal(t, "path", testParseOK(t, "-w", "path").workDir, "-w is work dir")
testParseParamMissing(t, "-w")
if testParseOk(t, "--work-dir", "path").workDir != "path" {
t.Fatal("--work-dir is work dir")
}
assert.Equal(t, "path", testParseOK(t, "--work-dir", "path").workDir, "--work-dir is work dir")
testParseParamMissing(t, "--work-dir")
}
func TestParseBindHost(t *testing.T) {
if testParseOk(t).bindHost != nil {
t.Fatal("empty is no host")
}
if !testParseOk(t, "-h", "1.2.3.4").bindHost.Equal(net.IP{1, 2, 3, 4}) {
t.Fatal("-h is host")
}
assert.Nil(t, testParseOK(t).bindHost, "empty is not host")
assert.Equal(t, net.IPv4(1, 2, 3, 4), testParseOK(t, "-h", "1.2.3.4").bindHost, "-h is host")
testParseParamMissing(t, "-h")
if !testParseOk(t, "--host", "1.2.3.4").bindHost.Equal(net.IP{1, 2, 3, 4}) {
t.Fatal("--host is host")
}
assert.Equal(t, net.IPv4(1, 2, 3, 4), testParseOK(t, "--host", "1.2.3.4").bindHost, "--host is host")
testParseParamMissing(t, "--host")
}
func TestParseBindPort(t *testing.T) {
if testParseOk(t).bindPort != 0 {
t.Fatal("empty is port 0")
}
if testParseOk(t, "-p", "65535").bindPort != 65535 {
t.Fatal("-p is port")
}
assert.Equal(t, 0, testParseOK(t).bindPort, "empty is port 0")
assert.Equal(t, 65535, testParseOK(t, "-p", "65535").bindPort, "-p is port")
testParseParamMissing(t, "-p")
if testParseOk(t, "--port", "65535").bindPort != 65535 {
t.Fatal("--port is port")
}
testParseParamMissing(t, "--port")
}
func TestParseBindPortBad(t *testing.T) {
assert.Equal(t, 65535, testParseOK(t, "--port", "65535").bindPort, "--port is port")
testParseParamMissing(t, "--port")
testParseErr(t, "not an int", "-p", "x")
testParseErr(t, "hex not supported", "-p", "0x100")
testParseErr(t, "port negative", "-p", "-1")
@@ -103,72 +81,40 @@ func TestParseBindPortBad(t *testing.T) {
}
func TestParseLogfile(t *testing.T) {
if testParseOk(t).logFile != "" {
t.Fatal("empty is no log file")
}
if testParseOk(t, "-l", "path").logFile != "path" {
t.Fatal("-l is log file")
}
if testParseOk(t, "--logfile", "path").logFile != "path" {
t.Fatal("--logfile is log file")
}
assert.Equal(t, "", testParseOK(t).logFile, "empty is no log file")
assert.Equal(t, "path", testParseOK(t, "-l", "path").logFile, "-l is log file")
assert.Equal(t, "path", testParseOK(t, "--logfile", "path").logFile, "--logfile is log file")
}
func TestParsePidfile(t *testing.T) {
if testParseOk(t).pidFile != "" {
t.Fatal("empty is no pid file")
}
if testParseOk(t, "--pidfile", "path").pidFile != "path" {
t.Fatal("--pidfile is pid file")
}
assert.Equal(t, "", testParseOK(t).pidFile, "empty is no pid file")
assert.Equal(t, "path", testParseOK(t, "--pidfile", "path").pidFile, "--pidfile is pid file")
}
func TestParseCheckConfig(t *testing.T) {
if testParseOk(t).checkConfig {
t.Fatal("empty is not check config")
}
if !testParseOk(t, "--check-config").checkConfig {
t.Fatal("--check-config is check config")
}
assert.False(t, testParseOK(t).checkConfig, "empty is not check config")
assert.True(t, testParseOK(t, "--check-config").checkConfig, "--check-config is check config")
}
func TestParseDisableUpdate(t *testing.T) {
if testParseOk(t).disableUpdate {
t.Fatal("empty is not disable update")
}
if !testParseOk(t, "--no-check-update").disableUpdate {
t.Fatal("--no-check-update is disable update")
}
assert.False(t, testParseOK(t).disableUpdate, "empty is not disable update")
assert.True(t, testParseOK(t, "--no-check-update").disableUpdate, "--no-check-update is disable update")
}
func TestParseDisableMemoryOptimization(t *testing.T) {
if testParseOk(t).disableMemoryOptimization {
t.Fatal("empty is not disable update")
}
if !testParseOk(t, "--no-mem-optimization").disableMemoryOptimization {
t.Fatal("--no-mem-optimization is disable update")
}
assert.False(t, testParseOK(t).disableMemoryOptimization, "empty is not disable update")
assert.True(t, testParseOK(t, "--no-mem-optimization").disableMemoryOptimization, "--no-mem-optimization is disable update")
}
func TestParseService(t *testing.T) {
if testParseOk(t).serviceControlAction != "" {
t.Fatal("empty is no service command")
}
if testParseOk(t, "-s", "command").serviceControlAction != "command" {
t.Fatal("-s is service command")
}
if testParseOk(t, "--service", "command").serviceControlAction != "command" {
t.Fatal("--service is service command")
}
assert.Equal(t, "", testParseOK(t).serviceControlAction, "empty is not service cmd")
assert.Equal(t, "cmd", testParseOK(t, "-s", "cmd").serviceControlAction, "-s is service cmd")
assert.Equal(t, "cmd", testParseOK(t, "--service", "cmd").serviceControlAction, "--service is service cmd")
}
func TestParseGLInet(t *testing.T) {
if testParseOk(t).glinetMode {
t.Fatal("empty is not GL-Inet mode")
}
if !testParseOk(t, "--glinet").glinetMode {
t.Fatal("--glinet is GL-Inet mode")
}
assert.False(t, testParseOK(t).glinetMode, "empty is not GL-Inet mode")
assert.True(t, testParseOK(t, "--glinet").glinetMode, "--glinet is GL-Inet mode")
}
func TestParseUnknown(t *testing.T) {
@@ -180,73 +126,85 @@ func TestParseUnknown(t *testing.T) {
testParseErr(t, "unknown dash", "-")
}
func testSerialize(t *testing.T, o options, ss ...string) {
result := serialize(o)
if len(result) != len(ss) {
t.Fatalf("expected %s but got %s", ss, result)
}
for i, r := range result {
if r != ss[i] {
t.Fatalf("expected %s but got %s", ss, result)
}
func TestSerialize(t *testing.T) {
const reportFmt = "expected %s but got %s"
testCases := []struct {
name string
opts options
ss []string
}{{
name: "empty",
opts: options{},
ss: []string{},
}, {
name: "config_filename",
opts: options{configFilename: "path"},
ss: []string{"-c", "path"},
}, {
name: "work_dir",
opts: options{workDir: "path"},
ss: []string{"-w", "path"},
}, {
name: "bind_host",
opts: options{bindHost: net.IP{1, 2, 3, 4}},
ss: []string{"-h", "1.2.3.4"},
}, {
name: "bind_port",
opts: options{bindPort: 666},
ss: []string{"-p", "666"},
}, {
name: "log_file",
opts: options{logFile: "path"},
ss: []string{"-l", "path"},
}, {
name: "pid_file",
opts: options{pidFile: "path"},
ss: []string{"--pidfile", "path"},
}, {
name: "disable_update",
opts: options{disableUpdate: true},
ss: []string{"--no-check-update"},
}, {
name: "control_action",
opts: options{serviceControlAction: "run"},
ss: []string{"-s", "run"},
}, {
name: "glinet_mode",
opts: options{glinetMode: true},
ss: []string{"--glinet"},
}, {
name: "disable_mem_opt",
opts: options{disableMemoryOptimization: true},
ss: []string{"--no-mem-optimization"},
}, {
name: "multiple",
opts: options{
serviceControlAction: "run",
configFilename: "config",
workDir: "work",
pidFile: "pid",
disableUpdate: true,
disableMemoryOptimization: true,
},
ss: []string{
"-c", "config",
"-w", "work",
"-s", "run",
"--pidfile", "pid",
"--no-check-update",
"--no-mem-optimization",
},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := serialize(tc.opts)
require.Lenf(t, result, len(tc.ss), reportFmt, tc.ss, result)
for i, r := range result {
assert.Equalf(t, tc.ss[i], r, reportFmt, tc.ss, result)
}
})
}
}
func TestSerializeEmpty(t *testing.T) {
testSerialize(t, options{})
}
func TestSerializeConfigFilename(t *testing.T) {
testSerialize(t, options{configFilename: "path"}, "-c", "path")
}
func TestSerializeWorkDir(t *testing.T) {
testSerialize(t, options{workDir: "path"}, "-w", "path")
}
func TestSerializeBindHost(t *testing.T) {
testSerialize(t, options{bindHost: net.IP{1, 2, 3, 4}}, "-h", "1.2.3.4")
}
func TestSerializeBindPort(t *testing.T) {
testSerialize(t, options{bindPort: 666}, "-p", "666")
}
func TestSerializeLogfile(t *testing.T) {
testSerialize(t, options{logFile: "path"}, "-l", "path")
}
func TestSerializePidfile(t *testing.T) {
testSerialize(t, options{pidFile: "path"}, "--pidfile", "path")
}
func TestSerializeCheckConfig(t *testing.T) {
testSerialize(t, options{checkConfig: true}, "--check-config")
}
func TestSerializeDisableUpdate(t *testing.T) {
testSerialize(t, options{disableUpdate: true}, "--no-check-update")
}
func TestSerializeService(t *testing.T) {
testSerialize(t, options{serviceControlAction: "run"}, "-s", "run")
}
func TestSerializeGLInet(t *testing.T) {
testSerialize(t, options{glinetMode: true}, "--glinet")
}
func TestSerializeDisableMemoryOptimization(t *testing.T) {
testSerialize(t, options{disableMemoryOptimization: true}, "--no-mem-optimization")
}
func TestSerializeMultiple(t *testing.T) {
testSerialize(t, options{
serviceControlAction: "run",
configFilename: "config",
workDir: "work",
pidFile: "pid",
disableUpdate: true,
disableMemoryOptimization: true,
}, "-c", "config", "-w", "work", "-s", "run", "--pidfile", "pid", "--no-check-update", "--no-mem-optimization")
}