diff --git a/internal/aghchan/aghchan.go b/internal/aghchan/aghchan.go new file mode 100644 index 00000000..5e504e45 --- /dev/null +++ b/internal/aghchan/aghchan.go @@ -0,0 +1,19 @@ +// Package aghchan contains channel utilities. +package aghchan + +import ( + "fmt" + "time" +) + +// MustReceive panics if it cannot receive a value form c before timeout runs +// out. +func MustReceive[T any](c <-chan T, timeout time.Duration) (v T, ok bool) { + timeoutCh := time.After(timeout) + select { + case <-timeoutCh: + panic(fmt.Errorf("did not receive after %s", timeout)) + case v, ok = <-c: + return v, ok + } +} diff --git a/internal/aghnet/hostscontainer_test.go b/internal/aghnet/hostscontainer_test.go index 1f75a3c9..d2637d85 100644 --- a/internal/aghnet/hostscontainer_test.go +++ b/internal/aghnet/hostscontainer_test.go @@ -10,9 +10,9 @@ import ( "testing/fstest" "time" + "github.com/AdguardTeam/AdGuardHome/internal/aghchan" "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/golibs/errors" - "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/stringutil" "github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/urlfilter" @@ -163,15 +163,9 @@ func TestHostsContainer_refresh(t *testing.T) { checkRefresh := func(t *testing.T, want *HostsRecord) { t.Helper() - var ok bool - var upd *netutil.IPMap - select { - case upd, ok = <-hc.Upd(): - require.True(t, ok) - require.NotNil(t, upd) - case <-time.After(1 * time.Second): - t.Fatal("did not receive after 1s") - } + upd, ok := aghchan.MustReceive(hc.Upd(), 1*time.Second) + require.True(t, ok) + require.NotNil(t, upd) assert.Equal(t, 1, upd.Len()) diff --git a/internal/v1/dnssvc/dnssvc.go b/internal/v1/dnssvc/dnssvc.go index ffe5b080..31860fa0 100644 --- a/internal/v1/dnssvc/dnssvc.go +++ b/internal/v1/dnssvc/dnssvc.go @@ -9,6 +9,7 @@ import ( "fmt" "net" "net/netip" + "sync/atomic" "time" "github.com/AdguardTeam/AdGuardHome/internal/v1/agh" @@ -47,6 +48,11 @@ type Config struct { // Service is the AdGuard Home DNS service. A nil *Service is a valid // [agh.Service] that does nothing. type Service struct { + // running is an atomic boolean value. Keep it the first value in the + // struct to ensure atomic alignment. 0 means that the service is not + // running, 1 means that it is running. + running uint64 + proxy *proxy.Proxy bootstraps []string upstreams []string @@ -160,6 +166,13 @@ func (svc *Service) Start() (err error) { return nil } + defer func() { + // TODO(a.garipov): [proxy.Proxy.Start] doesn't actually have any way to + // tell when all servers are actually up, so at best this is merely an + // assumption. + atomic.StoreUint64(&svc.running, 1) + }() + return svc.proxy.Start() } @@ -173,13 +186,27 @@ func (svc *Service) Shutdown(ctx context.Context) (err error) { return svc.proxy.Stop() } -// Config returns the current configuration of the web service. +// Config returns the current configuration of the web service. Config must not +// be called simultaneously with Start. If svc was initialized with ":0" +// addresses, addrs will not return the actual bound ports until Start is +// finished. func (svc *Service) Config() (c *Config) { // TODO(a.garipov): Do we need to get the TCP addresses separately? - udpAddrs := svc.proxy.Addrs(proxy.ProtoUDP) - addrs := make([]netip.AddrPort, len(udpAddrs)) - for i, a := range udpAddrs { - addrs[i] = a.(*net.UDPAddr).AddrPort() + + var addrs []netip.AddrPort + if atomic.LoadUint64(&svc.running) == 1 { + udpAddrs := svc.proxy.Addrs(proxy.ProtoUDP) + addrs = make([]netip.AddrPort, len(udpAddrs)) + for i, a := range udpAddrs { + addrs[i] = a.(*net.UDPAddr).AddrPort() + } + } else { + conf := svc.proxy.Config + udpAddrs := conf.UDPListenAddr + addrs = make([]netip.AddrPort, len(udpAddrs)) + for i, a := range udpAddrs { + addrs[i] = a.AddrPort() + } } c = &Config{ diff --git a/internal/v1/websvc/dns.go b/internal/v1/websvc/dns.go index 670a5b31..b536aedd 100644 --- a/internal/v1/websvc/dns.go +++ b/internal/v1/websvc/dns.go @@ -25,8 +25,9 @@ type ReqPatchSettingsDNS struct { UpstreamTimeout timeutil.Duration `json:"upstream_timeout"` } -// httpAPIDNSSettings are the DNS settings as used by the HTTP API. -type httpAPIDNSSettings struct { +// HTTPAPIDNSSettings are the DNS settings as used by the HTTP API. See the +// DnsSettings object in the OpenAPI specification. +type HTTPAPIDNSSettings struct { // TODO(a.garipov): Add more as we go. Addresses []netip.AddrPort `json:"addresses"` @@ -76,7 +77,7 @@ func (svc *Service) handlePatchSettingsDNS(w http.ResponseWriter, r *http.Reques return } - writeJSONResponse(w, r, &httpAPIDNSSettings{ + writeJSONResponse(w, r, &HTTPAPIDNSSettings{ Addresses: newConf.Addresses, BootstrapServers: newConf.BootstrapServers, UpstreamServers: newConf.UpstreamServers, diff --git a/internal/v1/websvc/http.go b/internal/v1/websvc/http.go index 279043e4..6a2b206d 100644 --- a/internal/v1/websvc/http.go +++ b/internal/v1/websvc/http.go @@ -9,6 +9,7 @@ import ( "time" "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/timeutil" ) // HTTP Settings Handlers @@ -22,25 +23,25 @@ type ReqPatchSettingsHTTP struct { // // TODO(a.garipov): Add wait time. - Addresses []netip.AddrPort `json:"addresses"` - SecureAddresses []netip.AddrPort `json:"secure_addresses"` + Addresses []netip.AddrPort `json:"addresses"` + SecureAddresses []netip.AddrPort `json:"secure_addresses"` + Timeout timeutil.Duration `json:"timeout"` } -// httpAPIDNSSettings are the HTTP settings as used by the HTTP API. -type httpAPIHTTPSettings struct { +// HTTPAPIHTTPSettings are the HTTP settings as used by the HTTP API. See the +// HttpSettings object in the OpenAPI specification. +type HTTPAPIHTTPSettings struct { // TODO(a.garipov): Add more as we go. - Addresses []netip.AddrPort `json:"addresses"` - SecureAddresses []netip.AddrPort `json:"secure_addresses"` + Addresses []netip.AddrPort `json:"addresses"` + SecureAddresses []netip.AddrPort `json:"secure_addresses"` + Timeout timeutil.Duration `json:"timeout"` } // handlePatchSettingsHTTP is the handler for the PATCH /api/v1/settings/http // HTTP API. func (svc *Service) handlePatchSettingsHTTP(w http.ResponseWriter, r *http.Request) { - req := &ReqPatchSettingsHTTP{ - Addresses: []netip.AddrPort{}, - SecureAddresses: []netip.AddrPort{}, - } + req := &ReqPatchSettingsHTTP{} // TODO(a.garipov): Validate nulls and proper JSON patch. @@ -56,13 +57,14 @@ func (svc *Service) handlePatchSettingsHTTP(w http.ResponseWriter, r *http.Reque TLS: svc.tls, Addresses: req.Addresses, SecureAddresses: req.SecureAddresses, - Timeout: svc.timeout, + Timeout: req.Timeout.Duration, ForceHTTPS: svc.forceHTTPS, } - writeJSONResponse(w, r, &httpAPIHTTPSettings{ + writeJSONResponse(w, r, &HTTPAPIHTTPSettings{ Addresses: newConf.Addresses, SecureAddresses: newConf.SecureAddresses, + Timeout: timeutil.Duration{Duration: newConf.Timeout}, }) cancelUpd := func() {} diff --git a/internal/v1/websvc/settings.go b/internal/v1/websvc/settings.go index c9bd922c..1d076263 100644 --- a/internal/v1/websvc/settings.go +++ b/internal/v1/websvc/settings.go @@ -15,8 +15,8 @@ import ( type RespGetV1SettingsAll struct { // TODO(a.garipov): Add more as we go. - DNS *httpAPIDNSSettings `json:"dns"` - HTTP *httpAPIHTTPSettings `json:"http"` + DNS *HTTPAPIDNSSettings `json:"dns"` + HTTP *HTTPAPIHTTPSettings `json:"http"` } // handleGetSettingsAll is the handler for the GET /api/v1/settings/all HTTP @@ -25,18 +25,21 @@ func (svc *Service) handleGetSettingsAll(w http.ResponseWriter, r *http.Request) dnsSvc := svc.confMgr.DNS() dnsConf := dnsSvc.Config() - httpConf := svc.Config() + webSvc := svc.confMgr.Web() + httpConf := webSvc.Config() + // TODO(a.garipov): Add all currently supported parameters. writeJSONResponse(w, r, &RespGetV1SettingsAll{ - DNS: &httpAPIDNSSettings{ + DNS: &HTTPAPIDNSSettings{ Addresses: dnsConf.Addresses, BootstrapServers: dnsConf.BootstrapServers, UpstreamServers: dnsConf.UpstreamServers, UpstreamTimeout: timeutil.Duration{Duration: dnsConf.UpstreamTimeout}, }, - HTTP: &httpAPIHTTPSettings{ + HTTP: &HTTPAPIHTTPSettings{ Addresses: httpConf.Addresses, SecureAddresses: httpConf.SecureAddresses, + Timeout: timeutil.Duration{Duration: httpConf.Timeout}, }, }) } diff --git a/internal/v1/websvc/settings_test.go b/internal/v1/websvc/settings_test.go new file mode 100644 index 00000000..a92b5b32 --- /dev/null +++ b/internal/v1/websvc/settings_test.go @@ -0,0 +1,61 @@ +package websvc_test + +import ( + "encoding/json" + "net/http" + "net/netip" + "net/url" + "testing" + "time" + + "github.com/AdguardTeam/AdGuardHome/internal/v1/dnssvc" + "github.com/AdguardTeam/AdGuardHome/internal/v1/websvc" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestService_HandleGetSettingsAll(t *testing.T) { + // TODO(a.garipov): Add all currently supported parameters. + + dnsAddrs := []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:53")} + upsSrvs := []string{"94.140.14.14", "1.1.1.1"} + + webAddrs := []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:80")} + const webTimeout = 5 * time.Second + + confMgr := newConfigManager() + confMgr.onDNS = func() (c *dnssvc.Service) { + c, err := dnssvc.New(&dnssvc.Config{ + Addresses: dnsAddrs, + UpstreamServers: upsSrvs, + }) + require.NoError(t, err) + + return c + } + + confMgr.onWeb = func() (c *websvc.Service) { + return websvc.New(&websvc.Config{ + Addresses: webAddrs, + Timeout: webTimeout, + }) + } + + _, addr := newTestServer(t, confMgr) + u := &url.URL{ + Scheme: "http", + Host: addr.String(), + Path: websvc.PathV1SettingsAll, + } + + body := httpGet(t, u, http.StatusOK) + resp := &websvc.RespGetV1SettingsAll{} + err := json.Unmarshal(body, resp) + require.NoError(t, err) + + assert.Equal(t, dnsAddrs, resp.DNS.Addresses) + assert.Equal(t, upsSrvs, resp.DNS.UpstreamServers) + + assert.Equal(t, webAddrs, resp.HTTP.Addresses) + assert.Equal(t, webTimeout, resp.HTTP.Timeout.Duration) +} diff --git a/internal/v1/websvc/system_test.go b/internal/v1/websvc/system_test.go index c267e5ac..ad81637f 100644 --- a/internal/v1/websvc/system_test.go +++ b/internal/v1/websvc/system_test.go @@ -14,7 +14,8 @@ import ( ) func TestService_handleGetV1SystemInfo(t *testing.T) { - _, addr := newTestServer(t) + confMgr := newConfigManager() + _, addr := newTestServer(t, confMgr) u := &url.URL{ Scheme: "http", Host: addr.String(), diff --git a/internal/v1/websvc/waitlistener_internal_test.go b/internal/v1/websvc/waitlistener_internal_test.go index 74c0bf80..3d51baa6 100644 --- a/internal/v1/websvc/waitlistener_internal_test.go +++ b/internal/v1/websvc/waitlistener_internal_test.go @@ -1,13 +1,12 @@ package websvc import ( - "fmt" "net" "sync" "sync/atomic" "testing" - "time" + "github.com/AdguardTeam/AdGuardHome/internal/aghchan" "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/stretchr/testify/assert" ) @@ -33,15 +32,7 @@ func TestWaitListener_Accept(t *testing.T) { wg.Add(1) done := make(chan struct{}) - a := time.After(testTimeout) - go func() { - select { - case <-a: - panic(fmt.Errorf("did not finish after %s", testTimeout)) - case <-done: - // Success. - } - }() + go aghchan.MustReceive(done, testTimeout) go func() { var wrapper net.Listener = &waitListener{ diff --git a/internal/v1/websvc/websvc.go b/internal/v1/websvc/websvc.go index 8183d92c..6bbdf7ec 100644 --- a/internal/v1/websvc/websvc.go +++ b/internal/v1/websvc/websvc.go @@ -170,10 +170,12 @@ func newMux(svc *Service) (mux *httptreemux.ContextMux) { } // addrs returns all addresses on which this server serves the HTTP API. addrs -// must not be called until Start returns. -func (svc *Service) addrs() (addrs, secAddrs []netip.AddrPort) { +// must not be called simultaneously with Start. If svc was initialized with +// ":0" addresses, addrs will not return the actual bound ports until Start is +// finished. +func (svc *Service) addrs() (addrs, secureAddrs []netip.AddrPort) { for _, srv := range svc.servers { - ipp, err := netip.ParseAddrPort(srv.Addr) + addrPort, err := netip.ParseAddrPort(srv.Addr) if err != nil { // Technically shouldn't happen, since all servers must have a valid // address. @@ -184,14 +186,14 @@ func (svc *Service) addrs() (addrs, secAddrs []netip.AddrPort) { // relying only on the nilness of TLSConfig, check the length of the // certificates field as well. if srv.TLSConfig == nil || len(srv.TLSConfig.Certificates) == 0 { - addrs = append(addrs, ipp) + addrs = append(addrs, addrPort) } else { - secAddrs = append(secAddrs, ipp) + secureAddrs = append(secureAddrs, addrPort) } } - return addrs, secAddrs + return addrs, secureAddrs } // handleGetHealthCheck is the handler for the GET /health-check HTTP API. @@ -279,8 +281,10 @@ func (svc *Service) Shutdown(ctx context.Context) (err error) { return nil } -// Config returns the current configuration of the web service. Currently, only -// the Addresses and SecureAddresses fields are filled in c. +// Config returns the current configuration of the web service. Config must not +// be called simultaneously with Start. If svc was initialized with ":0" +// addresses, addrs will not return the actual bound ports until Start is +// finished. func (svc *Service) Config() (c *Config) { c = &Config{ ConfigManager: svc.confMgr, diff --git a/internal/v1/websvc/websvc_test.go b/internal/v1/websvc/websvc_test.go index 8b73a015..affb9ad5 100644 --- a/internal/v1/websvc/websvc_test.go +++ b/internal/v1/websvc/websvc_test.go @@ -9,32 +9,89 @@ import ( "testing" "time" + "github.com/AdguardTeam/AdGuardHome/internal/aghtest" + "github.com/AdguardTeam/AdGuardHome/internal/v1/dnssvc" "github.com/AdguardTeam/AdGuardHome/internal/v1/websvc" "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +func TestMain(m *testing.M) { + aghtest.DiscardLogOutput(m) +} + // testTimeout is the common timeout for tests. const testTimeout = 1 * time.Second // testStart is the server start value for tests. var testStart = time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC) +// type check +var _ websvc.ConfigManager = (*configManager)(nil) + +// configManager is a [websvc.ConfigManager] for tests. +type configManager struct { + onDNS func() (svc *dnssvc.Service) + onWeb func() (svc *websvc.Service) + + onUpdateDNS func(ctx context.Context, c *dnssvc.Config) (err error) + onUpdateWeb func(ctx context.Context, c *websvc.Config) (err error) +} + +// DNS implements the [websvc.ConfigManager] interface for *configManager. +func (m *configManager) DNS() (svc *dnssvc.Service) { + return m.onDNS() +} + +// Web implements the [websvc.ConfigManager] interface for *configManager. +func (m *configManager) Web() (svc *websvc.Service) { + return m.onWeb() +} + +// UpdateDNS implements the [websvc.ConfigManager] interface for *configManager. +func (m *configManager) UpdateDNS(ctx context.Context, c *dnssvc.Config) (err error) { + return m.onUpdateDNS(ctx, c) +} + +// UpdateWeb implements the [websvc.ConfigManager] interface for *configManager. +func (m *configManager) UpdateWeb(ctx context.Context, c *websvc.Config) (err error) { + return m.onUpdateWeb(ctx, c) +} + +// newConfigManager returns a *configManager all methods of which panic. +func newConfigManager() (m *configManager) { + return &configManager{ + onDNS: func() (svc *dnssvc.Service) { panic("not implemented") }, + onWeb: func() (svc *websvc.Service) { panic("not implemented") }, + onUpdateDNS: func(_ context.Context, _ *dnssvc.Config) (err error) { + panic("not implemented") + }, + onUpdateWeb: func(_ context.Context, _ *websvc.Config) (err error) { + panic("not implemented") + }, + } +} + // newTestServer creates and starts a new web service instance as well as its // sole address. It also registers a cleanup procedure, which shuts the // instance down. // // TODO(a.garipov): Use svc or remove it. -func newTestServer(t testing.TB) (svc *websvc.Service, addr netip.AddrPort) { +func newTestServer( + t testing.TB, + confMgr websvc.ConfigManager, +) (svc *websvc.Service, addr netip.AddrPort) { t.Helper() c := &websvc.Config{ + ConfigManager: confMgr, TLS: nil, Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:0")}, SecureAddresses: nil, Timeout: testTimeout, Start: testStart, + ForceHTTPS: false, } svc = websvc.New(c) @@ -82,7 +139,8 @@ func httpGet(t testing.TB, u *url.URL, wantCode int) (body []byte) { } func TestService_Start_getHealthCheck(t *testing.T) { - _, addr := newTestServer(t) + confMgr := newConfigManager() + _, addr := newTestServer(t, confMgr) u := &url.URL{ Scheme: "http", Host: addr.String(),