websvc: add test; imp names, docs

This commit is contained in:
Ainar Garipov
2022-09-02 18:52:22 +03:00
parent 8a65848da4
commit abcbdbed29
11 changed files with 218 additions and 57 deletions

View File

@@ -0,0 +1,19 @@
// Package aghchan contains channel utilities.
package aghchan
import (
"fmt"
"time"
)
// MustReceive panics if it cannot receive a value form c before timeout runs
// out.
func MustReceive[T any](c <-chan T, timeout time.Duration) (v T, ok bool) {
timeoutCh := time.After(timeout)
select {
case <-timeoutCh:
panic(fmt.Errorf("did not receive after %s", timeout))
case v, ok = <-c:
return v, ok
}
}

View File

@@ -10,9 +10,9 @@ import (
"testing/fstest"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghchan"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/urlfilter"
@@ -163,15 +163,9 @@ func TestHostsContainer_refresh(t *testing.T) {
checkRefresh := func(t *testing.T, want *HostsRecord) {
t.Helper()
var ok bool
var upd *netutil.IPMap
select {
case upd, ok = <-hc.Upd():
require.True(t, ok)
require.NotNil(t, upd)
case <-time.After(1 * time.Second):
t.Fatal("did not receive after 1s")
}
upd, ok := aghchan.MustReceive(hc.Upd(), 1*time.Second)
require.True(t, ok)
require.NotNil(t, upd)
assert.Equal(t, 1, upd.Len())

View File

@@ -9,6 +9,7 @@ import (
"fmt"
"net"
"net/netip"
"sync/atomic"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/v1/agh"
@@ -47,6 +48,11 @@ type Config struct {
// Service is the AdGuard Home DNS service. A nil *Service is a valid
// [agh.Service] that does nothing.
type Service struct {
// running is an atomic boolean value. Keep it the first value in the
// struct to ensure atomic alignment. 0 means that the service is not
// running, 1 means that it is running.
running uint64
proxy *proxy.Proxy
bootstraps []string
upstreams []string
@@ -160,6 +166,13 @@ func (svc *Service) Start() (err error) {
return nil
}
defer func() {
// TODO(a.garipov): [proxy.Proxy.Start] doesn't actually have any way to
// tell when all servers are actually up, so at best this is merely an
// assumption.
atomic.StoreUint64(&svc.running, 1)
}()
return svc.proxy.Start()
}
@@ -173,13 +186,27 @@ func (svc *Service) Shutdown(ctx context.Context) (err error) {
return svc.proxy.Stop()
}
// Config returns the current configuration of the web service.
// Config returns the current configuration of the web service. Config must not
// be called simultaneously with Start. If svc was initialized with ":0"
// addresses, addrs will not return the actual bound ports until Start is
// finished.
func (svc *Service) Config() (c *Config) {
// TODO(a.garipov): Do we need to get the TCP addresses separately?
udpAddrs := svc.proxy.Addrs(proxy.ProtoUDP)
addrs := make([]netip.AddrPort, len(udpAddrs))
for i, a := range udpAddrs {
addrs[i] = a.(*net.UDPAddr).AddrPort()
var addrs []netip.AddrPort
if atomic.LoadUint64(&svc.running) == 1 {
udpAddrs := svc.proxy.Addrs(proxy.ProtoUDP)
addrs = make([]netip.AddrPort, len(udpAddrs))
for i, a := range udpAddrs {
addrs[i] = a.(*net.UDPAddr).AddrPort()
}
} else {
conf := svc.proxy.Config
udpAddrs := conf.UDPListenAddr
addrs = make([]netip.AddrPort, len(udpAddrs))
for i, a := range udpAddrs {
addrs[i] = a.AddrPort()
}
}
c = &Config{

View File

@@ -25,8 +25,9 @@ type ReqPatchSettingsDNS struct {
UpstreamTimeout timeutil.Duration `json:"upstream_timeout"`
}
// httpAPIDNSSettings are the DNS settings as used by the HTTP API.
type httpAPIDNSSettings struct {
// HTTPAPIDNSSettings are the DNS settings as used by the HTTP API. See the
// DnsSettings object in the OpenAPI specification.
type HTTPAPIDNSSettings struct {
// TODO(a.garipov): Add more as we go.
Addresses []netip.AddrPort `json:"addresses"`
@@ -76,7 +77,7 @@ func (svc *Service) handlePatchSettingsDNS(w http.ResponseWriter, r *http.Reques
return
}
writeJSONResponse(w, r, &httpAPIDNSSettings{
writeJSONResponse(w, r, &HTTPAPIDNSSettings{
Addresses: newConf.Addresses,
BootstrapServers: newConf.BootstrapServers,
UpstreamServers: newConf.UpstreamServers,

View File

@@ -9,6 +9,7 @@ import (
"time"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/timeutil"
)
// HTTP Settings Handlers
@@ -22,25 +23,25 @@ type ReqPatchSettingsHTTP struct {
//
// TODO(a.garipov): Add wait time.
Addresses []netip.AddrPort `json:"addresses"`
SecureAddresses []netip.AddrPort `json:"secure_addresses"`
Addresses []netip.AddrPort `json:"addresses"`
SecureAddresses []netip.AddrPort `json:"secure_addresses"`
Timeout timeutil.Duration `json:"timeout"`
}
// httpAPIDNSSettings are the HTTP settings as used by the HTTP API.
type httpAPIHTTPSettings struct {
// HTTPAPIHTTPSettings are the HTTP settings as used by the HTTP API. See the
// HttpSettings object in the OpenAPI specification.
type HTTPAPIHTTPSettings struct {
// TODO(a.garipov): Add more as we go.
Addresses []netip.AddrPort `json:"addresses"`
SecureAddresses []netip.AddrPort `json:"secure_addresses"`
Addresses []netip.AddrPort `json:"addresses"`
SecureAddresses []netip.AddrPort `json:"secure_addresses"`
Timeout timeutil.Duration `json:"timeout"`
}
// handlePatchSettingsHTTP is the handler for the PATCH /api/v1/settings/http
// HTTP API.
func (svc *Service) handlePatchSettingsHTTP(w http.ResponseWriter, r *http.Request) {
req := &ReqPatchSettingsHTTP{
Addresses: []netip.AddrPort{},
SecureAddresses: []netip.AddrPort{},
}
req := &ReqPatchSettingsHTTP{}
// TODO(a.garipov): Validate nulls and proper JSON patch.
@@ -56,13 +57,14 @@ func (svc *Service) handlePatchSettingsHTTP(w http.ResponseWriter, r *http.Reque
TLS: svc.tls,
Addresses: req.Addresses,
SecureAddresses: req.SecureAddresses,
Timeout: svc.timeout,
Timeout: req.Timeout.Duration,
ForceHTTPS: svc.forceHTTPS,
}
writeJSONResponse(w, r, &httpAPIHTTPSettings{
writeJSONResponse(w, r, &HTTPAPIHTTPSettings{
Addresses: newConf.Addresses,
SecureAddresses: newConf.SecureAddresses,
Timeout: timeutil.Duration{Duration: newConf.Timeout},
})
cancelUpd := func() {}

View File

@@ -15,8 +15,8 @@ import (
type RespGetV1SettingsAll struct {
// TODO(a.garipov): Add more as we go.
DNS *httpAPIDNSSettings `json:"dns"`
HTTP *httpAPIHTTPSettings `json:"http"`
DNS *HTTPAPIDNSSettings `json:"dns"`
HTTP *HTTPAPIHTTPSettings `json:"http"`
}
// handleGetSettingsAll is the handler for the GET /api/v1/settings/all HTTP
@@ -25,18 +25,21 @@ func (svc *Service) handleGetSettingsAll(w http.ResponseWriter, r *http.Request)
dnsSvc := svc.confMgr.DNS()
dnsConf := dnsSvc.Config()
httpConf := svc.Config()
webSvc := svc.confMgr.Web()
httpConf := webSvc.Config()
// TODO(a.garipov): Add all currently supported parameters.
writeJSONResponse(w, r, &RespGetV1SettingsAll{
DNS: &httpAPIDNSSettings{
DNS: &HTTPAPIDNSSettings{
Addresses: dnsConf.Addresses,
BootstrapServers: dnsConf.BootstrapServers,
UpstreamServers: dnsConf.UpstreamServers,
UpstreamTimeout: timeutil.Duration{Duration: dnsConf.UpstreamTimeout},
},
HTTP: &httpAPIHTTPSettings{
HTTP: &HTTPAPIHTTPSettings{
Addresses: httpConf.Addresses,
SecureAddresses: httpConf.SecureAddresses,
Timeout: timeutil.Duration{Duration: httpConf.Timeout},
},
})
}

View File

@@ -0,0 +1,61 @@
package websvc_test
import (
"encoding/json"
"net/http"
"net/netip"
"net/url"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/v1/dnssvc"
"github.com/AdguardTeam/AdGuardHome/internal/v1/websvc"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestService_HandleGetSettingsAll(t *testing.T) {
// TODO(a.garipov): Add all currently supported parameters.
dnsAddrs := []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:53")}
upsSrvs := []string{"94.140.14.14", "1.1.1.1"}
webAddrs := []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:80")}
const webTimeout = 5 * time.Second
confMgr := newConfigManager()
confMgr.onDNS = func() (c *dnssvc.Service) {
c, err := dnssvc.New(&dnssvc.Config{
Addresses: dnsAddrs,
UpstreamServers: upsSrvs,
})
require.NoError(t, err)
return c
}
confMgr.onWeb = func() (c *websvc.Service) {
return websvc.New(&websvc.Config{
Addresses: webAddrs,
Timeout: webTimeout,
})
}
_, addr := newTestServer(t, confMgr)
u := &url.URL{
Scheme: "http",
Host: addr.String(),
Path: websvc.PathV1SettingsAll,
}
body := httpGet(t, u, http.StatusOK)
resp := &websvc.RespGetV1SettingsAll{}
err := json.Unmarshal(body, resp)
require.NoError(t, err)
assert.Equal(t, dnsAddrs, resp.DNS.Addresses)
assert.Equal(t, upsSrvs, resp.DNS.UpstreamServers)
assert.Equal(t, webAddrs, resp.HTTP.Addresses)
assert.Equal(t, webTimeout, resp.HTTP.Timeout.Duration)
}

View File

@@ -14,7 +14,8 @@ import (
)
func TestService_handleGetV1SystemInfo(t *testing.T) {
_, addr := newTestServer(t)
confMgr := newConfigManager()
_, addr := newTestServer(t, confMgr)
u := &url.URL{
Scheme: "http",
Host: addr.String(),

View File

@@ -1,13 +1,12 @@
package websvc
import (
"fmt"
"net"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghchan"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/stretchr/testify/assert"
)
@@ -33,15 +32,7 @@ func TestWaitListener_Accept(t *testing.T) {
wg.Add(1)
done := make(chan struct{})
a := time.After(testTimeout)
go func() {
select {
case <-a:
panic(fmt.Errorf("did not finish after %s", testTimeout))
case <-done:
// Success.
}
}()
go aghchan.MustReceive(done, testTimeout)
go func() {
var wrapper net.Listener = &waitListener{

View File

@@ -170,10 +170,12 @@ func newMux(svc *Service) (mux *httptreemux.ContextMux) {
}
// addrs returns all addresses on which this server serves the HTTP API. addrs
// must not be called until Start returns.
func (svc *Service) addrs() (addrs, secAddrs []netip.AddrPort) {
// must not be called simultaneously with Start. If svc was initialized with
// ":0" addresses, addrs will not return the actual bound ports until Start is
// finished.
func (svc *Service) addrs() (addrs, secureAddrs []netip.AddrPort) {
for _, srv := range svc.servers {
ipp, err := netip.ParseAddrPort(srv.Addr)
addrPort, err := netip.ParseAddrPort(srv.Addr)
if err != nil {
// Technically shouldn't happen, since all servers must have a valid
// address.
@@ -184,14 +186,14 @@ func (svc *Service) addrs() (addrs, secAddrs []netip.AddrPort) {
// relying only on the nilness of TLSConfig, check the length of the
// certificates field as well.
if srv.TLSConfig == nil || len(srv.TLSConfig.Certificates) == 0 {
addrs = append(addrs, ipp)
addrs = append(addrs, addrPort)
} else {
secAddrs = append(secAddrs, ipp)
secureAddrs = append(secureAddrs, addrPort)
}
}
return addrs, secAddrs
return addrs, secureAddrs
}
// handleGetHealthCheck is the handler for the GET /health-check HTTP API.
@@ -279,8 +281,10 @@ func (svc *Service) Shutdown(ctx context.Context) (err error) {
return nil
}
// Config returns the current configuration of the web service. Currently, only
// the Addresses and SecureAddresses fields are filled in c.
// Config returns the current configuration of the web service. Config must not
// be called simultaneously with Start. If svc was initialized with ":0"
// addresses, addrs will not return the actual bound ports until Start is
// finished.
func (svc *Service) Config() (c *Config) {
c = &Config{
ConfigManager: svc.confMgr,

View File

@@ -9,32 +9,89 @@ import (
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/v1/dnssvc"
"github.com/AdguardTeam/AdGuardHome/internal/v1/websvc"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestMain(m *testing.M) {
aghtest.DiscardLogOutput(m)
}
// testTimeout is the common timeout for tests.
const testTimeout = 1 * time.Second
// testStart is the server start value for tests.
var testStart = time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC)
// type check
var _ websvc.ConfigManager = (*configManager)(nil)
// configManager is a [websvc.ConfigManager] for tests.
type configManager struct {
onDNS func() (svc *dnssvc.Service)
onWeb func() (svc *websvc.Service)
onUpdateDNS func(ctx context.Context, c *dnssvc.Config) (err error)
onUpdateWeb func(ctx context.Context, c *websvc.Config) (err error)
}
// DNS implements the [websvc.ConfigManager] interface for *configManager.
func (m *configManager) DNS() (svc *dnssvc.Service) {
return m.onDNS()
}
// Web implements the [websvc.ConfigManager] interface for *configManager.
func (m *configManager) Web() (svc *websvc.Service) {
return m.onWeb()
}
// UpdateDNS implements the [websvc.ConfigManager] interface for *configManager.
func (m *configManager) UpdateDNS(ctx context.Context, c *dnssvc.Config) (err error) {
return m.onUpdateDNS(ctx, c)
}
// UpdateWeb implements the [websvc.ConfigManager] interface for *configManager.
func (m *configManager) UpdateWeb(ctx context.Context, c *websvc.Config) (err error) {
return m.onUpdateWeb(ctx, c)
}
// newConfigManager returns a *configManager all methods of which panic.
func newConfigManager() (m *configManager) {
return &configManager{
onDNS: func() (svc *dnssvc.Service) { panic("not implemented") },
onWeb: func() (svc *websvc.Service) { panic("not implemented") },
onUpdateDNS: func(_ context.Context, _ *dnssvc.Config) (err error) {
panic("not implemented")
},
onUpdateWeb: func(_ context.Context, _ *websvc.Config) (err error) {
panic("not implemented")
},
}
}
// newTestServer creates and starts a new web service instance as well as its
// sole address. It also registers a cleanup procedure, which shuts the
// instance down.
//
// TODO(a.garipov): Use svc or remove it.
func newTestServer(t testing.TB) (svc *websvc.Service, addr netip.AddrPort) {
func newTestServer(
t testing.TB,
confMgr websvc.ConfigManager,
) (svc *websvc.Service, addr netip.AddrPort) {
t.Helper()
c := &websvc.Config{
ConfigManager: confMgr,
TLS: nil,
Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:0")},
SecureAddresses: nil,
Timeout: testTimeout,
Start: testStart,
ForceHTTPS: false,
}
svc = websvc.New(c)
@@ -82,7 +139,8 @@ func httpGet(t testing.TB, u *url.URL, wantCode int) (body []byte) {
}
func TestService_Start_getHealthCheck(t *testing.T) {
_, addr := newTestServer(t)
confMgr := newConfigManager()
_, addr := newTestServer(t, confMgr)
u := &url.URL{
Scheme: "http",
Host: addr.String(),