websvc: add test; imp names, docs
This commit is contained in:
19
internal/aghchan/aghchan.go
Normal file
19
internal/aghchan/aghchan.go
Normal file
@@ -0,0 +1,19 @@
|
||||
// Package aghchan contains channel utilities.
|
||||
package aghchan
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// MustReceive panics if it cannot receive a value form c before timeout runs
|
||||
// out.
|
||||
func MustReceive[T any](c <-chan T, timeout time.Duration) (v T, ok bool) {
|
||||
timeoutCh := time.After(timeout)
|
||||
select {
|
||||
case <-timeoutCh:
|
||||
panic(fmt.Errorf("did not receive after %s", timeout))
|
||||
case v, ok = <-c:
|
||||
return v, ok
|
||||
}
|
||||
}
|
||||
@@ -10,9 +10,9 @@ import (
|
||||
"testing/fstest"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghchan"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/AdguardTeam/urlfilter"
|
||||
@@ -163,15 +163,9 @@ func TestHostsContainer_refresh(t *testing.T) {
|
||||
checkRefresh := func(t *testing.T, want *HostsRecord) {
|
||||
t.Helper()
|
||||
|
||||
var ok bool
|
||||
var upd *netutil.IPMap
|
||||
select {
|
||||
case upd, ok = <-hc.Upd():
|
||||
require.True(t, ok)
|
||||
require.NotNil(t, upd)
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Fatal("did not receive after 1s")
|
||||
}
|
||||
upd, ok := aghchan.MustReceive(hc.Upd(), 1*time.Second)
|
||||
require.True(t, ok)
|
||||
require.NotNil(t, upd)
|
||||
|
||||
assert.Equal(t, 1, upd.Len())
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/v1/agh"
|
||||
@@ -47,6 +48,11 @@ type Config struct {
|
||||
// Service is the AdGuard Home DNS service. A nil *Service is a valid
|
||||
// [agh.Service] that does nothing.
|
||||
type Service struct {
|
||||
// running is an atomic boolean value. Keep it the first value in the
|
||||
// struct to ensure atomic alignment. 0 means that the service is not
|
||||
// running, 1 means that it is running.
|
||||
running uint64
|
||||
|
||||
proxy *proxy.Proxy
|
||||
bootstraps []string
|
||||
upstreams []string
|
||||
@@ -160,6 +166,13 @@ func (svc *Service) Start() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
defer func() {
|
||||
// TODO(a.garipov): [proxy.Proxy.Start] doesn't actually have any way to
|
||||
// tell when all servers are actually up, so at best this is merely an
|
||||
// assumption.
|
||||
atomic.StoreUint64(&svc.running, 1)
|
||||
}()
|
||||
|
||||
return svc.proxy.Start()
|
||||
}
|
||||
|
||||
@@ -173,13 +186,27 @@ func (svc *Service) Shutdown(ctx context.Context) (err error) {
|
||||
return svc.proxy.Stop()
|
||||
}
|
||||
|
||||
// Config returns the current configuration of the web service.
|
||||
// Config returns the current configuration of the web service. Config must not
|
||||
// be called simultaneously with Start. If svc was initialized with ":0"
|
||||
// addresses, addrs will not return the actual bound ports until Start is
|
||||
// finished.
|
||||
func (svc *Service) Config() (c *Config) {
|
||||
// TODO(a.garipov): Do we need to get the TCP addresses separately?
|
||||
udpAddrs := svc.proxy.Addrs(proxy.ProtoUDP)
|
||||
addrs := make([]netip.AddrPort, len(udpAddrs))
|
||||
for i, a := range udpAddrs {
|
||||
addrs[i] = a.(*net.UDPAddr).AddrPort()
|
||||
|
||||
var addrs []netip.AddrPort
|
||||
if atomic.LoadUint64(&svc.running) == 1 {
|
||||
udpAddrs := svc.proxy.Addrs(proxy.ProtoUDP)
|
||||
addrs = make([]netip.AddrPort, len(udpAddrs))
|
||||
for i, a := range udpAddrs {
|
||||
addrs[i] = a.(*net.UDPAddr).AddrPort()
|
||||
}
|
||||
} else {
|
||||
conf := svc.proxy.Config
|
||||
udpAddrs := conf.UDPListenAddr
|
||||
addrs = make([]netip.AddrPort, len(udpAddrs))
|
||||
for i, a := range udpAddrs {
|
||||
addrs[i] = a.AddrPort()
|
||||
}
|
||||
}
|
||||
|
||||
c = &Config{
|
||||
|
||||
@@ -25,8 +25,9 @@ type ReqPatchSettingsDNS struct {
|
||||
UpstreamTimeout timeutil.Duration `json:"upstream_timeout"`
|
||||
}
|
||||
|
||||
// httpAPIDNSSettings are the DNS settings as used by the HTTP API.
|
||||
type httpAPIDNSSettings struct {
|
||||
// HTTPAPIDNSSettings are the DNS settings as used by the HTTP API. See the
|
||||
// DnsSettings object in the OpenAPI specification.
|
||||
type HTTPAPIDNSSettings struct {
|
||||
// TODO(a.garipov): Add more as we go.
|
||||
|
||||
Addresses []netip.AddrPort `json:"addresses"`
|
||||
@@ -76,7 +77,7 @@ func (svc *Service) handlePatchSettingsDNS(w http.ResponseWriter, r *http.Reques
|
||||
return
|
||||
}
|
||||
|
||||
writeJSONResponse(w, r, &httpAPIDNSSettings{
|
||||
writeJSONResponse(w, r, &HTTPAPIDNSSettings{
|
||||
Addresses: newConf.Addresses,
|
||||
BootstrapServers: newConf.BootstrapServers,
|
||||
UpstreamServers: newConf.UpstreamServers,
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
)
|
||||
|
||||
// HTTP Settings Handlers
|
||||
@@ -22,25 +23,25 @@ type ReqPatchSettingsHTTP struct {
|
||||
//
|
||||
// TODO(a.garipov): Add wait time.
|
||||
|
||||
Addresses []netip.AddrPort `json:"addresses"`
|
||||
SecureAddresses []netip.AddrPort `json:"secure_addresses"`
|
||||
Addresses []netip.AddrPort `json:"addresses"`
|
||||
SecureAddresses []netip.AddrPort `json:"secure_addresses"`
|
||||
Timeout timeutil.Duration `json:"timeout"`
|
||||
}
|
||||
|
||||
// httpAPIDNSSettings are the HTTP settings as used by the HTTP API.
|
||||
type httpAPIHTTPSettings struct {
|
||||
// HTTPAPIHTTPSettings are the HTTP settings as used by the HTTP API. See the
|
||||
// HttpSettings object in the OpenAPI specification.
|
||||
type HTTPAPIHTTPSettings struct {
|
||||
// TODO(a.garipov): Add more as we go.
|
||||
|
||||
Addresses []netip.AddrPort `json:"addresses"`
|
||||
SecureAddresses []netip.AddrPort `json:"secure_addresses"`
|
||||
Addresses []netip.AddrPort `json:"addresses"`
|
||||
SecureAddresses []netip.AddrPort `json:"secure_addresses"`
|
||||
Timeout timeutil.Duration `json:"timeout"`
|
||||
}
|
||||
|
||||
// handlePatchSettingsHTTP is the handler for the PATCH /api/v1/settings/http
|
||||
// HTTP API.
|
||||
func (svc *Service) handlePatchSettingsHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
req := &ReqPatchSettingsHTTP{
|
||||
Addresses: []netip.AddrPort{},
|
||||
SecureAddresses: []netip.AddrPort{},
|
||||
}
|
||||
req := &ReqPatchSettingsHTTP{}
|
||||
|
||||
// TODO(a.garipov): Validate nulls and proper JSON patch.
|
||||
|
||||
@@ -56,13 +57,14 @@ func (svc *Service) handlePatchSettingsHTTP(w http.ResponseWriter, r *http.Reque
|
||||
TLS: svc.tls,
|
||||
Addresses: req.Addresses,
|
||||
SecureAddresses: req.SecureAddresses,
|
||||
Timeout: svc.timeout,
|
||||
Timeout: req.Timeout.Duration,
|
||||
ForceHTTPS: svc.forceHTTPS,
|
||||
}
|
||||
|
||||
writeJSONResponse(w, r, &httpAPIHTTPSettings{
|
||||
writeJSONResponse(w, r, &HTTPAPIHTTPSettings{
|
||||
Addresses: newConf.Addresses,
|
||||
SecureAddresses: newConf.SecureAddresses,
|
||||
Timeout: timeutil.Duration{Duration: newConf.Timeout},
|
||||
})
|
||||
|
||||
cancelUpd := func() {}
|
||||
|
||||
@@ -15,8 +15,8 @@ import (
|
||||
type RespGetV1SettingsAll struct {
|
||||
// TODO(a.garipov): Add more as we go.
|
||||
|
||||
DNS *httpAPIDNSSettings `json:"dns"`
|
||||
HTTP *httpAPIHTTPSettings `json:"http"`
|
||||
DNS *HTTPAPIDNSSettings `json:"dns"`
|
||||
HTTP *HTTPAPIHTTPSettings `json:"http"`
|
||||
}
|
||||
|
||||
// handleGetSettingsAll is the handler for the GET /api/v1/settings/all HTTP
|
||||
@@ -25,18 +25,21 @@ func (svc *Service) handleGetSettingsAll(w http.ResponseWriter, r *http.Request)
|
||||
dnsSvc := svc.confMgr.DNS()
|
||||
dnsConf := dnsSvc.Config()
|
||||
|
||||
httpConf := svc.Config()
|
||||
webSvc := svc.confMgr.Web()
|
||||
httpConf := webSvc.Config()
|
||||
|
||||
// TODO(a.garipov): Add all currently supported parameters.
|
||||
writeJSONResponse(w, r, &RespGetV1SettingsAll{
|
||||
DNS: &httpAPIDNSSettings{
|
||||
DNS: &HTTPAPIDNSSettings{
|
||||
Addresses: dnsConf.Addresses,
|
||||
BootstrapServers: dnsConf.BootstrapServers,
|
||||
UpstreamServers: dnsConf.UpstreamServers,
|
||||
UpstreamTimeout: timeutil.Duration{Duration: dnsConf.UpstreamTimeout},
|
||||
},
|
||||
HTTP: &httpAPIHTTPSettings{
|
||||
HTTP: &HTTPAPIHTTPSettings{
|
||||
Addresses: httpConf.Addresses,
|
||||
SecureAddresses: httpConf.SecureAddresses,
|
||||
Timeout: timeutil.Duration{Duration: httpConf.Timeout},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
61
internal/v1/websvc/settings_test.go
Normal file
61
internal/v1/websvc/settings_test.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package websvc_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/v1/dnssvc"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/v1/websvc"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestService_HandleGetSettingsAll(t *testing.T) {
|
||||
// TODO(a.garipov): Add all currently supported parameters.
|
||||
|
||||
dnsAddrs := []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:53")}
|
||||
upsSrvs := []string{"94.140.14.14", "1.1.1.1"}
|
||||
|
||||
webAddrs := []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:80")}
|
||||
const webTimeout = 5 * time.Second
|
||||
|
||||
confMgr := newConfigManager()
|
||||
confMgr.onDNS = func() (c *dnssvc.Service) {
|
||||
c, err := dnssvc.New(&dnssvc.Config{
|
||||
Addresses: dnsAddrs,
|
||||
UpstreamServers: upsSrvs,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
confMgr.onWeb = func() (c *websvc.Service) {
|
||||
return websvc.New(&websvc.Config{
|
||||
Addresses: webAddrs,
|
||||
Timeout: webTimeout,
|
||||
})
|
||||
}
|
||||
|
||||
_, addr := newTestServer(t, confMgr)
|
||||
u := &url.URL{
|
||||
Scheme: "http",
|
||||
Host: addr.String(),
|
||||
Path: websvc.PathV1SettingsAll,
|
||||
}
|
||||
|
||||
body := httpGet(t, u, http.StatusOK)
|
||||
resp := &websvc.RespGetV1SettingsAll{}
|
||||
err := json.Unmarshal(body, resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, dnsAddrs, resp.DNS.Addresses)
|
||||
assert.Equal(t, upsSrvs, resp.DNS.UpstreamServers)
|
||||
|
||||
assert.Equal(t, webAddrs, resp.HTTP.Addresses)
|
||||
assert.Equal(t, webTimeout, resp.HTTP.Timeout.Duration)
|
||||
}
|
||||
@@ -14,7 +14,8 @@ import (
|
||||
)
|
||||
|
||||
func TestService_handleGetV1SystemInfo(t *testing.T) {
|
||||
_, addr := newTestServer(t)
|
||||
confMgr := newConfigManager()
|
||||
_, addr := newTestServer(t, confMgr)
|
||||
u := &url.URL{
|
||||
Scheme: "http",
|
||||
Host: addr.String(),
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
package websvc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghchan"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
@@ -33,15 +32,7 @@ func TestWaitListener_Accept(t *testing.T) {
|
||||
wg.Add(1)
|
||||
|
||||
done := make(chan struct{})
|
||||
a := time.After(testTimeout)
|
||||
go func() {
|
||||
select {
|
||||
case <-a:
|
||||
panic(fmt.Errorf("did not finish after %s", testTimeout))
|
||||
case <-done:
|
||||
// Success.
|
||||
}
|
||||
}()
|
||||
go aghchan.MustReceive(done, testTimeout)
|
||||
|
||||
go func() {
|
||||
var wrapper net.Listener = &waitListener{
|
||||
|
||||
@@ -170,10 +170,12 @@ func newMux(svc *Service) (mux *httptreemux.ContextMux) {
|
||||
}
|
||||
|
||||
// addrs returns all addresses on which this server serves the HTTP API. addrs
|
||||
// must not be called until Start returns.
|
||||
func (svc *Service) addrs() (addrs, secAddrs []netip.AddrPort) {
|
||||
// must not be called simultaneously with Start. If svc was initialized with
|
||||
// ":0" addresses, addrs will not return the actual bound ports until Start is
|
||||
// finished.
|
||||
func (svc *Service) addrs() (addrs, secureAddrs []netip.AddrPort) {
|
||||
for _, srv := range svc.servers {
|
||||
ipp, err := netip.ParseAddrPort(srv.Addr)
|
||||
addrPort, err := netip.ParseAddrPort(srv.Addr)
|
||||
if err != nil {
|
||||
// Technically shouldn't happen, since all servers must have a valid
|
||||
// address.
|
||||
@@ -184,14 +186,14 @@ func (svc *Service) addrs() (addrs, secAddrs []netip.AddrPort) {
|
||||
// relying only on the nilness of TLSConfig, check the length of the
|
||||
// certificates field as well.
|
||||
if srv.TLSConfig == nil || len(srv.TLSConfig.Certificates) == 0 {
|
||||
addrs = append(addrs, ipp)
|
||||
addrs = append(addrs, addrPort)
|
||||
} else {
|
||||
secAddrs = append(secAddrs, ipp)
|
||||
secureAddrs = append(secureAddrs, addrPort)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return addrs, secAddrs
|
||||
return addrs, secureAddrs
|
||||
}
|
||||
|
||||
// handleGetHealthCheck is the handler for the GET /health-check HTTP API.
|
||||
@@ -279,8 +281,10 @@ func (svc *Service) Shutdown(ctx context.Context) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Config returns the current configuration of the web service. Currently, only
|
||||
// the Addresses and SecureAddresses fields are filled in c.
|
||||
// Config returns the current configuration of the web service. Config must not
|
||||
// be called simultaneously with Start. If svc was initialized with ":0"
|
||||
// addresses, addrs will not return the actual bound ports until Start is
|
||||
// finished.
|
||||
func (svc *Service) Config() (c *Config) {
|
||||
c = &Config{
|
||||
ConfigManager: svc.confMgr,
|
||||
|
||||
@@ -9,32 +9,89 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/v1/dnssvc"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/v1/websvc"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
aghtest.DiscardLogOutput(m)
|
||||
}
|
||||
|
||||
// testTimeout is the common timeout for tests.
|
||||
const testTimeout = 1 * time.Second
|
||||
|
||||
// testStart is the server start value for tests.
|
||||
var testStart = time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
|
||||
// type check
|
||||
var _ websvc.ConfigManager = (*configManager)(nil)
|
||||
|
||||
// configManager is a [websvc.ConfigManager] for tests.
|
||||
type configManager struct {
|
||||
onDNS func() (svc *dnssvc.Service)
|
||||
onWeb func() (svc *websvc.Service)
|
||||
|
||||
onUpdateDNS func(ctx context.Context, c *dnssvc.Config) (err error)
|
||||
onUpdateWeb func(ctx context.Context, c *websvc.Config) (err error)
|
||||
}
|
||||
|
||||
// DNS implements the [websvc.ConfigManager] interface for *configManager.
|
||||
func (m *configManager) DNS() (svc *dnssvc.Service) {
|
||||
return m.onDNS()
|
||||
}
|
||||
|
||||
// Web implements the [websvc.ConfigManager] interface for *configManager.
|
||||
func (m *configManager) Web() (svc *websvc.Service) {
|
||||
return m.onWeb()
|
||||
}
|
||||
|
||||
// UpdateDNS implements the [websvc.ConfigManager] interface for *configManager.
|
||||
func (m *configManager) UpdateDNS(ctx context.Context, c *dnssvc.Config) (err error) {
|
||||
return m.onUpdateDNS(ctx, c)
|
||||
}
|
||||
|
||||
// UpdateWeb implements the [websvc.ConfigManager] interface for *configManager.
|
||||
func (m *configManager) UpdateWeb(ctx context.Context, c *websvc.Config) (err error) {
|
||||
return m.onUpdateWeb(ctx, c)
|
||||
}
|
||||
|
||||
// newConfigManager returns a *configManager all methods of which panic.
|
||||
func newConfigManager() (m *configManager) {
|
||||
return &configManager{
|
||||
onDNS: func() (svc *dnssvc.Service) { panic("not implemented") },
|
||||
onWeb: func() (svc *websvc.Service) { panic("not implemented") },
|
||||
onUpdateDNS: func(_ context.Context, _ *dnssvc.Config) (err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
onUpdateWeb: func(_ context.Context, _ *websvc.Config) (err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// newTestServer creates and starts a new web service instance as well as its
|
||||
// sole address. It also registers a cleanup procedure, which shuts the
|
||||
// instance down.
|
||||
//
|
||||
// TODO(a.garipov): Use svc or remove it.
|
||||
func newTestServer(t testing.TB) (svc *websvc.Service, addr netip.AddrPort) {
|
||||
func newTestServer(
|
||||
t testing.TB,
|
||||
confMgr websvc.ConfigManager,
|
||||
) (svc *websvc.Service, addr netip.AddrPort) {
|
||||
t.Helper()
|
||||
|
||||
c := &websvc.Config{
|
||||
ConfigManager: confMgr,
|
||||
TLS: nil,
|
||||
Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:0")},
|
||||
SecureAddresses: nil,
|
||||
Timeout: testTimeout,
|
||||
Start: testStart,
|
||||
ForceHTTPS: false,
|
||||
}
|
||||
|
||||
svc = websvc.New(c)
|
||||
@@ -82,7 +139,8 @@ func httpGet(t testing.TB, u *url.URL, wantCode int) (body []byte) {
|
||||
}
|
||||
|
||||
func TestService_Start_getHealthCheck(t *testing.T) {
|
||||
_, addr := newTestServer(t)
|
||||
confMgr := newConfigManager()
|
||||
_, addr := newTestServer(t, confMgr)
|
||||
u := &url.URL{
|
||||
Scheme: "http",
|
||||
Host: addr.String(),
|
||||
|
||||
Reference in New Issue
Block a user