Pull request 2303: AGDNS-2505-upd-next

Squashed commit of the following:

commit 586b0eb180afc22d06d673756dd916c17a173361
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Tue Nov 12 19:58:56 2024 +0300

    next: upd more

commit d729aa150f7ac367255830cceca40b8880c51015
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Tue Nov 12 16:53:15 2024 +0300

    next/websvc: upd more

commit 0c64e6cfc66b9212f077b2de7450586fd4d02802
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Mon Nov 11 21:08:51 2024 +0300

    next: upd more

commit 05eec75222360708621c99d3eeac7c8d9f2a5080
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Fri Nov 8 19:20:02 2024 +0300

    next: upd code
This commit is contained in:
Ainar Garipov
2024-11-13 15:44:21 +03:00
parent ac5a96fada
commit 1d6d85cff4
34 changed files with 637 additions and 601 deletions

View File

@@ -3,12 +3,17 @@ package websvc
import (
"crypto/tls"
"io/fs"
"log/slog"
"net/netip"
"time"
)
// Config is the AdGuard Home web service configuration structure.
type Config struct {
// Logger is used for logging the operation of the web API service. It must
// not be nil.
Logger *slog.Logger
// Pprof is the configuration for the pprof debug API. It must not be nil.
Pprof *PprofConfig
@@ -60,17 +65,20 @@ type PprofConfig struct {
// finished.
func (svc *Service) Config() (c *Config) {
c = &Config{
Logger: svc.logger,
Pprof: &PprofConfig{
Port: svc.pprofPort,
Enabled: svc.pprof != nil,
},
ConfigManager: svc.confMgr,
Frontend: svc.frontend,
TLS: svc.tls,
// Leave Addresses and SecureAddresses empty and get the actual
// addresses that include the :0 ones later.
Start: svc.start,
Timeout: svc.timeout,
ForceHTTPS: svc.forceHTTPS,
Start: svc.start,
OverrideAddress: svc.overrideAddr,
Timeout: svc.timeout,
ForceHTTPS: svc.forceHTTPS,
}
c.Addresses, c.SecureAddresses = svc.addrs()

View File

@@ -11,8 +11,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
)
// DNS Settings Handlers
// ReqPatchSettingsDNS describes the request to the PATCH /api/v1/settings/dns
// HTTP API.
type ReqPatchSettingsDNS struct {
@@ -60,6 +58,7 @@ func (svc *Service) handlePatchSettingsDNS(w http.ResponseWriter, r *http.Reques
}
newConf := &dnssvc.Config{
Logger: svc.logger,
Addresses: req.Addresses,
BootstrapServers: req.BootstrapServers,
UpstreamServers: req.UpstreamServers,
@@ -78,7 +77,7 @@ func (svc *Service) handlePatchSettingsDNS(w http.ResponseWriter, r *http.Reques
}
newSvc := svc.confMgr.DNS()
err = newSvc.Start()
err = newSvc.Start(ctx)
if err != nil {
aghhttp.WriteJSONResponseError(w, r, fmt.Errorf("starting new service: %w", err))

View File

@@ -35,7 +35,7 @@ func TestService_HandlePatchSettingsDNS(t *testing.T) {
confMgr := newConfigManager()
confMgr.onDNS = func() (s agh.ServiceWithConfig[*dnssvc.Config]) {
return &aghtest.ServiceWithConfig[*dnssvc.Config]{
OnStart: func() (err error) {
OnStart: func(_ context.Context) (err error) {
started.Store(true)
return nil
@@ -52,7 +52,7 @@ func TestService_HandlePatchSettingsDNS(t *testing.T) {
u := &url.URL{
Scheme: urlutil.SchemeHTTP,
Host: addr.String(),
Path: websvc.PathV1SettingsDNS,
Path: websvc.PathPatternV1SettingsDNS,
}
req := jobj{

View File

@@ -10,11 +10,9 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/logutil/slogutil"
)
// HTTP Settings Handlers
// ReqPatchSettingsHTTP describes the request to the PATCH /api/v1/settings/http
// HTTP API.
type ReqPatchSettingsHTTP struct {
@@ -53,6 +51,7 @@ func (svc *Service) handlePatchSettingsHTTP(w http.ResponseWriter, r *http.Reque
}
newConf := &Config{
Logger: svc.logger,
Pprof: &PprofConfig{
Port: svc.pprofPort,
Enabled: svc.pprof != nil,
@@ -89,13 +88,13 @@ func (svc *Service) handlePatchSettingsHTTP(w http.ResponseWriter, r *http.Reque
// relaunch updates the web service in the configuration manager and starts it.
// It is intended to be used as a goroutine.
func (svc *Service) relaunch(ctx context.Context, cancel context.CancelFunc, newConf *Config) {
defer log.OnPanic("websvc: relaunching")
defer slogutil.RecoverAndLog(ctx, svc.logger)
defer cancel()
err := svc.confMgr.UpdateWeb(ctx, newConf)
if err != nil {
log.Error("websvc: updating web: %s", err)
svc.logger.ErrorContext(ctx, "updating web", slogutil.KeyError, err)
return
}
@@ -106,18 +105,18 @@ func (svc *Service) relaunch(ctx context.Context, cancel context.CancelFunc, new
var newSvc agh.ServiceWithConfig[*Config]
for newSvc = svc.confMgr.Web(); newSvc == svc; {
if time.Since(updStart) >= maxUpdDur {
log.Error("websvc: failed to update svc after %s", maxUpdDur)
svc.logger.ErrorContext(ctx, "failed to update service on time", "duration", maxUpdDur)
return
}
log.Debug("websvc: waiting for new websvc to be configured")
svc.logger.DebugContext(ctx, "waiting for new service")
time.Sleep(100 * time.Millisecond)
}
err = newSvc.Start()
err = newSvc.Start(ctx)
if err != nil {
log.Error("websvc: new svc failed to start with error: %s", err)
svc.logger.ErrorContext(ctx, "new service failed", slogutil.KeyError, err)
}
}

View File

@@ -13,6 +13,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
"github.com/AdguardTeam/AdGuardHome/internal/next/websvc"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil/urlutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -27,14 +28,15 @@ func TestService_HandlePatchSettingsHTTP(t *testing.T) {
}
svc, err := websvc.New(&websvc.Config{
Logger: slogutil.NewDiscardLogger(),
Pprof: &websvc.PprofConfig{
Enabled: false,
},
TLS: &tls.Config{
Certificates: []tls.Certificate{{}},
},
Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:80")},
SecureAddresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:443")},
Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:0")},
SecureAddresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:0")},
Timeout: 5 * time.Second,
ForceHTTPS: true,
})
@@ -48,7 +50,7 @@ func TestService_HandlePatchSettingsHTTP(t *testing.T) {
u := &url.URL{
Scheme: urlutil.SchemeHTTP,
Host: addr.String(),
Path: websvc.PathV1SettingsHTTP,
Path: websvc.PathPatternV1SettingsHTTP,
}
req := jobj{

View File

@@ -2,15 +2,11 @@ package websvc
import (
"net/http"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/golibs/httphdr"
"github.com/AdguardTeam/golibs/log"
)
// Middlewares
// jsonMw sets the content type of the response to application/json.
func jsonMw(h http.Handler) (wrapped http.HandlerFunc) {
f := func(w http.ResponseWriter, r *http.Request) {
@@ -21,18 +17,3 @@ func jsonMw(h http.Handler) (wrapped http.HandlerFunc) {
return http.HandlerFunc(f)
}
// logMw logs the queries with level debug.
func logMw(h http.Handler) (wrapped http.HandlerFunc) {
f := func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
m, u := r.Method, r.RequestURI
log.Debug("websvc: %s %s started", m, u)
defer func() { log.Debug("websvc: %s %s finished in %s", m, u, time.Since(start)) }()
h.ServeHTTP(w, r)
}
return http.HandlerFunc(f)
}

View File

@@ -1,14 +0,0 @@
package websvc
// Path constants
const (
PathRoot = "/"
PathFrontend = "/*filepath"
PathHealthCheck = "/health-check"
PathV1SettingsAll = "/api/v1/settings/all"
PathV1SettingsDNS = "/api/v1/settings/dns"
PathV1SettingsHTTP = "/api/v1/settings/http"
PathV1SystemInfo = "/api/v1/system/info"
)

View File

@@ -0,0 +1,73 @@
package websvc
import (
"log/slog"
"net/http"
"github.com/AdguardTeam/golibs/netutil/httputil"
)
// Path pattern constants.
const (
PathPatternFrontend = "/"
PathPatternHealthCheck = "/health-check"
PathPatternV1SettingsAll = "/api/v1/settings/all"
PathPatternV1SettingsDNS = "/api/v1/settings/dns"
PathPatternV1SettingsHTTP = "/api/v1/settings/http"
PathPatternV1SystemInfo = "/api/v1/system/info"
)
// Route pattern constants.
const (
routePatternFrontend = http.MethodGet + " " + PathPatternFrontend
routePatternGetV1SettingsAll = http.MethodGet + " " + PathPatternV1SettingsAll
routePatternGetV1SystemInfo = http.MethodGet + " " + PathPatternV1SystemInfo
routePatternHealthCheck = http.MethodGet + " " + PathPatternHealthCheck
routePatternPatchV1SettingsDNS = http.MethodPatch + " " + PathPatternV1SettingsDNS
routePatternPatchV1SettingsHTTP = http.MethodPatch + " " + PathPatternV1SettingsHTTP
)
// route registers all necessary handlers in mux.
func (svc *Service) route(mux *http.ServeMux) {
routes := []struct {
handler http.Handler
pattern string
isJSON bool
}{{
handler: httputil.HealthCheckHandler,
pattern: routePatternHealthCheck,
isJSON: false,
}, {
handler: http.FileServer(http.FS(svc.frontend)),
pattern: routePatternFrontend,
isJSON: false,
}, {
handler: http.HandlerFunc(svc.handleGetSettingsAll),
pattern: routePatternGetV1SettingsAll,
isJSON: true,
}, {
handler: http.HandlerFunc(svc.handlePatchSettingsDNS),
pattern: routePatternPatchV1SettingsDNS,
isJSON: true,
}, {
handler: http.HandlerFunc(svc.handlePatchSettingsHTTP),
pattern: routePatternPatchV1SettingsHTTP,
isJSON: true,
}, {
handler: http.HandlerFunc(svc.handleGetV1SystemInfo),
pattern: routePatternGetV1SystemInfo,
isJSON: true,
}}
logMw := httputil.NewLogMiddleware(svc.logger, slog.LevelDebug)
for _, r := range routes {
var hdlr http.Handler
if r.isJSON {
hdlr = jsonMw(r.handler)
} else {
hdlr = r.handler
}
mux.Handle(r.pattern, logMw.Wrap(hdlr))
}
}

View File

@@ -0,0 +1,156 @@
package websvc
import (
"context"
"crypto/tls"
"fmt"
"log/slog"
"net"
"net/http"
"net/netip"
"net/url"
"sync"
"time"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil/urlutil"
)
// server contains an *http.Server as well as entities and data associated with
// it.
//
// TODO(a.garipov): Join with similar structs in other projects and move to
// golibs/netutil/httputil.
//
// TODO(a.garipov): Once the above standardization is complete, consider
// merging debugsvc and websvc into a single httpsvc.
type server struct {
// mu protects http, logger, tcpListener, and url.
mu *sync.Mutex
http *http.Server
logger *slog.Logger
tcpListener *net.TCPListener
url *url.URL
tlsConf *tls.Config
initialAddr netip.AddrPort
}
// loggerKeyServer is the key used by [server] to identify itself.
const loggerKeyServer = "server"
// newServer returns a *server that is ready to serve HTTP queries. The TCP
// listener is not started. handler must not be nil.
func newServer(
baseLogger *slog.Logger,
initialAddr netip.AddrPort,
tlsConf *tls.Config,
handler http.Handler,
timeout time.Duration,
) (s *server) {
u := &url.URL{
Scheme: urlutil.SchemeHTTP,
Host: initialAddr.String(),
}
if tlsConf != nil {
u.Scheme = urlutil.SchemeHTTPS
}
logger := baseLogger.With(loggerKeyServer, u)
return &server{
mu: &sync.Mutex{},
http: &http.Server{
Handler: handler,
ReadTimeout: timeout,
ReadHeaderTimeout: timeout,
WriteTimeout: timeout,
IdleTimeout: timeout,
ErrorLog: slog.NewLogLogger(logger.Handler(), slog.LevelError),
},
logger: logger,
url: u,
tlsConf: tlsConf,
initialAddr: initialAddr,
}
}
// localAddr returns the local address of the server if the server has started
// listening; otherwise, it returns nil.
func (s *server) localAddr() (addr net.Addr) {
s.mu.Lock()
defer s.mu.Unlock()
if l := s.tcpListener; l != nil {
return l.Addr()
}
return nil
}
// serve starts s. baseLogger is used as a base logger for s. If s fails to
// serve with anything other than [http.ErrServerClosed], it causes an unhandled
// panic. It is intended to be used as a goroutine.
//
// TODO(a.garipov): Improve error handling.
func (s *server) serve(ctx context.Context, baseLogger *slog.Logger) {
l, err := net.ListenTCP("tcp", net.TCPAddrFromAddrPort(s.initialAddr))
if err != nil {
s.logger.ErrorContext(ctx, "listening tcp", slogutil.KeyError, err)
panic(fmt.Errorf("websvc: listening tcp: %w", err))
}
func() {
s.mu.Lock()
defer s.mu.Unlock()
s.tcpListener = l
// Reassign the address in case the port was zero.
s.url.Host = l.Addr().String()
s.logger = baseLogger.With(loggerKeyServer, s.url)
s.http.ErrorLog = slog.NewLogLogger(s.logger.Handler(), slog.LevelError)
}()
s.logger.InfoContext(ctx, "starting")
defer s.logger.InfoContext(ctx, "started")
err = s.http.Serve(l)
if err == nil || errors.Is(err, http.ErrServerClosed) {
return
}
s.logger.ErrorContext(ctx, "serving", slogutil.KeyError, err)
panic(fmt.Errorf("websvc: serving: %w", err))
}
// shutdown shuts s down.
func (s *server) shutdown(ctx context.Context) (err error) {
s.mu.Lock()
defer s.mu.Unlock()
var errs []error
err = s.http.Shutdown(ctx)
if err != nil {
errs = append(errs, fmt.Errorf("shutting down server %s: %w", s.url, err))
}
// Close the listener separately, as it might not have been closed if the
// context has been canceled.
//
// NOTE: The listener could remain uninitialized if [net.ListenTCP] failed
// in [s.serve].
if l := s.tcpListener; l != nil {
err = l.Close()
if err != nil && !errors.Is(err, net.ErrClosed) {
errs = append(errs, fmt.Errorf("closing listener for server %s: %w", s.url, err))
}
}
return errors.Join(errs...)
}

View File

@@ -1,7 +1,6 @@
package websvc_test
import (
"crypto/tls"
"encoding/json"
"net/http"
"net/netip"
@@ -13,6 +12,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
"github.com/AdguardTeam/AdGuardHome/internal/next/websvc"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil/urlutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -29,16 +29,10 @@ func TestService_HandleGetSettingsAll(t *testing.T) {
BootstrapPreferIPv6: true,
}
wantWeb := &websvc.HTTPAPIHTTPSettings{
Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:80")},
SecureAddresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:443")},
Timeout: aghhttp.JSONDuration(5 * time.Second),
ForceHTTPS: true,
}
confMgr := newConfigManager()
confMgr.onDNS = func() (s agh.ServiceWithConfig[*dnssvc.Config]) {
c, err := dnssvc.New(&dnssvc.Config{
Logger: slogutil.NewDiscardLogger(),
Addresses: wantDNS.Addresses,
UpstreamServers: wantDNS.UpstreamServers,
BootstrapServers: wantDNS.BootstrapServers,
@@ -50,34 +44,27 @@ func TestService_HandleGetSettingsAll(t *testing.T) {
return c
}
svc, err := websvc.New(&websvc.Config{
Pprof: &websvc.PprofConfig{
Enabled: false,
},
TLS: &tls.Config{
Certificates: []tls.Certificate{{}},
},
Addresses: wantWeb.Addresses,
SecureAddresses: wantWeb.SecureAddresses,
Timeout: time.Duration(wantWeb.Timeout),
ForceHTTPS: true,
})
require.NoError(t, err)
svc, addr := newTestServer(t, confMgr)
u := &url.URL{
Scheme: urlutil.SchemeHTTP,
Host: addr.String(),
Path: websvc.PathPatternV1SettingsAll,
}
confMgr.onWeb = func() (s agh.ServiceWithConfig[*websvc.Config]) {
return svc
}
_, addr := newTestServer(t, confMgr)
u := &url.URL{
Scheme: urlutil.SchemeHTTP,
Host: addr.String(),
Path: websvc.PathV1SettingsAll,
wantWeb := &websvc.HTTPAPIHTTPSettings{
Addresses: []netip.AddrPort{addr},
SecureAddresses: nil,
Timeout: aghhttp.JSONDuration(testTimeout),
ForceHTTPS: false,
}
body := httpGet(t, u, http.StatusOK)
resp := &websvc.RespGetV1SettingsAll{}
err = json.Unmarshal(body, resp)
err := json.Unmarshal(body, resp)
require.NoError(t, err)
assert.Equal(t, wantDNS, resp.DNS)

View File

@@ -20,7 +20,7 @@ func TestService_handleGetV1SystemInfo(t *testing.T) {
u := &url.URL{
Scheme: urlutil.SchemeHTTP,
Host: addr.String(),
Path: websvc.PathV1SystemInfo,
Path: websvc.PathPatternV1SystemInfo,
}
body := httpGet(t, u, http.StatusOK)

View File

@@ -1,31 +0,0 @@
package websvc
import (
"net"
"sync"
)
// Wait Listener
// waitListener is a wrapper around a listener that also calls wg.Done() on the
// first call to Accept. It is useful in situations where it is important to
// catch the precise moment of the first call to Accept, for example when
// starting an HTTP server.
//
// TODO(a.garipov): Move to aghnet?
type waitListener struct {
net.Listener
firstAcceptWG *sync.WaitGroup
firstAcceptOnce sync.Once
}
// type check
var _ net.Listener = (*waitListener)(nil)
// Accept implements the [net.Listener] interface for *waitListener.
func (l *waitListener) Accept() (conn net.Conn, err error) {
l.firstAcceptOnce.Do(l.firstAcceptWG.Done)
return l.Listener.Accept()
}

View File

@@ -1,40 +0,0 @@
package websvc
import (
"net"
"sync"
"sync/atomic"
"testing"
"github.com/AdguardTeam/golibs/testutil/fakenet"
"github.com/stretchr/testify/assert"
)
func TestWaitListener_Accept(t *testing.T) {
var accepted atomic.Bool
var l net.Listener = &fakenet.Listener{
OnAccept: func() (conn net.Conn, err error) {
accepted.Store(true)
return nil, nil
},
OnAddr: func() (addr net.Addr) { panic("not implemented") },
OnClose: func() (err error) { panic("not implemented") },
}
wg := &sync.WaitGroup{}
wg.Add(1)
go func() {
var wrapper net.Listener = &waitListener{
Listener: l,
firstAcceptWG: wg,
}
_, _ = wrapper.Accept()
}()
wg.Wait()
assert.Eventually(t, accepted.Load, testTimeout, testTimeout/10)
}

View File

@@ -10,22 +10,18 @@ import (
"context"
"crypto/tls"
"fmt"
"io"
"io/fs"
"net"
"log/slog"
"net/http"
"net/netip"
"runtime"
"sync"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/mathutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/netutil/httputil"
httptreemux "github.com/dimfeld/httptreemux/v5"
)
// ConfigManager is the configuration manager interface.
@@ -40,13 +36,14 @@ type ConfigManager interface {
// Service is the AdGuard Home web service. A nil *Service is a valid
// [agh.Service] that does nothing.
type Service struct {
logger *slog.Logger
confMgr ConfigManager
frontend fs.FS
tls *tls.Config
pprof *http.Server
pprof *server
start time.Time
overrideAddr netip.AddrPort
servers []*http.Server
servers []*server
timeout time.Duration
pprofPort uint16
forceHTTPS bool
@@ -64,6 +61,7 @@ func New(c *Config) (svc *Service, err error) {
}
svc = &Service{
logger: c.Logger,
confMgr: c.ConfigManager,
frontend: c.Frontend,
tls: c.TLS,
@@ -73,17 +71,18 @@ func New(c *Config) (svc *Service, err error) {
forceHTTPS: c.ForceHTTPS,
}
mux := newMux(svc)
mux := http.NewServeMux()
svc.route(mux)
if svc.overrideAddr != (netip.AddrPort{}) {
svc.servers = []*http.Server{newSrv(svc.overrideAddr, nil, mux, c.Timeout)}
svc.servers = []*server{newServer(svc.logger, svc.overrideAddr, nil, mux, c.Timeout)}
} else {
for _, a := range c.Addresses {
svc.servers = append(svc.servers, newSrv(a, nil, mux, c.Timeout))
svc.servers = append(svc.servers, newServer(svc.logger, a, nil, mux, c.Timeout))
}
for _, a := range c.SecureAddresses {
svc.servers = append(svc.servers, newSrv(a, c.TLS, mux, c.Timeout))
svc.servers = append(svc.servers, newServer(svc.logger, a, c.TLS, mux, c.Timeout))
}
}
@@ -112,96 +111,7 @@ func (svc *Service) setupPprof(c *PprofConfig) {
svc.pprofPort = c.Port
addr := netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), c.Port)
// TODO(a.garipov): Consider making pprof timeout configurable.
svc.pprof = newSrv(addr, nil, pprofMux, 10*time.Minute)
}
// newSrv returns a new *http.Server with the given parameters.
func newSrv(
addr netip.AddrPort,
tlsConf *tls.Config,
h http.Handler,
timeout time.Duration,
) (srv *http.Server) {
addrStr := addr.String()
srv = &http.Server{
Addr: addrStr,
Handler: h,
TLSConfig: tlsConf,
ReadTimeout: timeout,
WriteTimeout: timeout,
IdleTimeout: timeout,
ReadHeaderTimeout: timeout,
}
if tlsConf == nil {
srv.ErrorLog = log.StdLog("websvc: plain http: "+addrStr, log.ERROR)
} else {
srv.ErrorLog = log.StdLog("websvc: https: "+addrStr, log.ERROR)
}
return srv
}
// newMux returns a new HTTP request multiplexer for the AdGuard Home web
// service.
func newMux(svc *Service) (mux *httptreemux.ContextMux) {
mux = httptreemux.NewContextMux()
routes := []struct {
handler http.HandlerFunc
method string
pattern string
isJSON bool
}{{
handler: svc.handleGetHealthCheck,
method: http.MethodGet,
pattern: PathHealthCheck,
isJSON: false,
}, {
handler: http.FileServer(http.FS(svc.frontend)).ServeHTTP,
method: http.MethodGet,
pattern: PathFrontend,
isJSON: false,
}, {
handler: http.FileServer(http.FS(svc.frontend)).ServeHTTP,
method: http.MethodGet,
pattern: PathRoot,
isJSON: false,
}, {
handler: svc.handleGetSettingsAll,
method: http.MethodGet,
pattern: PathV1SettingsAll,
isJSON: true,
}, {
handler: svc.handlePatchSettingsDNS,
method: http.MethodPatch,
pattern: PathV1SettingsDNS,
isJSON: true,
}, {
handler: svc.handlePatchSettingsHTTP,
method: http.MethodPatch,
pattern: PathV1SettingsHTTP,
isJSON: true,
}, {
handler: svc.handleGetV1SystemInfo,
method: http.MethodGet,
pattern: PathV1SystemInfo,
isJSON: true,
}}
for _, r := range routes {
var hdlr http.Handler
if r.isJSON {
hdlr = jsonMw(r.handler)
} else {
hdlr = r.handler
}
mux.Handle(r.method, r.pattern, logMw(hdlr))
}
return mux
svc.pprof = newServer(svc.logger, addr, nil, pprofMux, 10*time.Minute)
}
// addrs returns all addresses on which this server serves the HTTP API. addrs
@@ -214,14 +124,12 @@ func (svc *Service) addrs() (addrs, secureAddrs []netip.AddrPort) {
}
for _, srv := range svc.servers {
// Use MustParseAddrPort, since no errors should technically happen
// here, because all servers must have a valid address.
addrPort := netip.MustParseAddrPort(srv.Addr)
addrPort := netutil.NetAddrToAddrPort(srv.localAddr())
if addrPort == (netip.AddrPort{}) {
continue
}
// [srv.Serve] will set TLSConfig to an almost empty value, so, instead
// of relying only on the nilness of TLSConfig, check the length of the
// certificates field as well.
if srv.TLSConfig == nil || len(srv.TLSConfig.Certificates) == 0 {
if srv.tlsConf == nil {
addrs = append(addrs, addrPort)
} else {
secureAddrs = append(secureAddrs, addrPort)
@@ -231,74 +139,60 @@ func (svc *Service) addrs() (addrs, secureAddrs []netip.AddrPort) {
return addrs, secureAddrs
}
// handleGetHealthCheck is the handler for the GET /health-check HTTP API.
func (svc *Service) handleGetHealthCheck(w http.ResponseWriter, _ *http.Request) {
_, _ = io.WriteString(w, "OK")
}
// type check
var _ agh.Service = (*Service)(nil)
var _ agh.ServiceWithConfig[*Config] = (*Service)(nil)
// Start implements the [agh.Service] interface for *Service. svc may be nil.
// After Start exits, all HTTP servers have tried to start, possibly failing and
// writing error messages to the log.
func (svc *Service) Start() (err error) {
//
// TODO(a.garipov): Use the context for cancelation as well.
func (svc *Service) Start(ctx context.Context) (err error) {
if svc == nil {
return nil
}
pprofEnabled := svc.pprof != nil
srvNum := len(svc.servers) + mathutil.BoolToNumber[int](pprofEnabled)
svc.logger.InfoContext(ctx, "starting")
defer svc.logger.InfoContext(ctx, "started")
wg := &sync.WaitGroup{}
wg.Add(srvNum)
for _, srv := range svc.servers {
go serve(srv, wg)
go srv.serve(ctx, svc.logger)
}
if pprofEnabled {
go serve(svc.pprof, wg)
if svc.pprof != nil {
go svc.pprof.serve(ctx, svc.logger)
}
wg.Wait()
return svc.wait(ctx)
}
// wait waits until either the context is canceled or all servers have started.
func (svc *Service) wait(ctx context.Context) (err error) {
for !svc.serversHaveStarted() {
select {
case <-ctx.Done():
return ctx.Err()
default:
// Wait and let the other goroutines do their job.
runtime.Gosched()
}
}
return nil
}
// serve starts and runs srv and writes all errors into its log.
func serve(srv *http.Server, wg *sync.WaitGroup) {
addr := srv.Addr
defer log.OnPanic(addr)
var proto string
var l net.Listener
var err error
if srv.TLSConfig == nil {
proto = "http"
l, err = net.Listen("tcp", addr)
} else {
proto = "https"
l, err = tls.Listen("tcp", addr, srv.TLSConfig)
}
if err != nil {
srv.ErrorLog.Printf("starting srv %s: binding: %s", addr, err)
// serversHaveStarted returns true if all servers have started serving.
func (svc *Service) serversHaveStarted() (started bool) {
started = len(svc.servers) != 0
for _, srv := range svc.servers {
started = started && srv.localAddr() != nil
}
// Update the server's address in case the address had the port zero, which
// would mean that a random available port was automatically chosen.
srv.Addr = l.Addr().String()
log.Info("websvc: starting srv %s://%s", proto, srv.Addr)
l = &waitListener{
Listener: l,
firstAcceptWG: wg,
if svc.pprof != nil {
started = started && svc.pprof.localAddr() != nil
}
err = srv.Serve(l)
if err != nil && !errors.Is(err, http.ErrServerClosed) {
srv.ErrorLog.Printf("starting srv %s: %s", addr, err)
}
return started
}
// Shutdown implements the [agh.Service] interface for *Service. svc may be
@@ -308,20 +202,24 @@ func (svc *Service) Shutdown(ctx context.Context) (err error) {
return nil
}
svc.logger.InfoContext(ctx, "shutting down")
defer svc.logger.InfoContext(ctx, "shut down")
defer func() { err = errors.Annotate(err, "shutting down: %w") }()
var errs []error
for _, srv := range svc.servers {
shutdownErr := srv.Shutdown(ctx)
shutdownErr := srv.shutdown(ctx)
if shutdownErr != nil {
errs = append(errs, fmt.Errorf("srv %s: %w", srv.Addr, shutdownErr))
// Don't wrap the error, because it's informative enough as is.
errs = append(errs, err)
}
}
if svc.pprof != nil {
shutdownErr := svc.pprof.Shutdown(ctx)
shutdownErr := svc.pprof.shutdown(ctx)
if shutdownErr != nil {
errs = append(errs, fmt.Errorf("pprof srv %s: %w", svc.pprof.Addr, shutdownErr))
errs = append(errs, fmt.Errorf("pprof: %w", shutdownErr))
}
}

View File

@@ -1,6 +0,0 @@
package websvc
import "time"
// testTimeout is the common timeout for tests.
const testTimeout = 1 * time.Second

View File

@@ -15,6 +15,8 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
"github.com/AdguardTeam/AdGuardHome/internal/next/websvc"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil/httputil"
"github.com/AdguardTeam/golibs/netutil/urlutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/golibs/testutil/fakefs"
@@ -22,10 +24,6 @@ import (
"github.com/stretchr/testify/require"
)
func TestMain(m *testing.M) {
testutil.DiscardLogOutput(m)
}
// testTimeout is the common timeout for tests.
const testTimeout = 1 * time.Second
@@ -81,8 +79,6 @@ func newConfigManager() (m *configManager) {
// newTestServer creates and starts a new web service instance as well as its
// sole address. It also registers a cleanup procedure, which shuts the
// instance down.
//
// TODO(a.garipov): Use svc or remove it.
func newTestServer(
t testing.TB,
confMgr websvc.ConfigManager,
@@ -90,6 +86,7 @@ func newTestServer(
t.Helper()
c := &websvc.Config{
Logger: slogutil.NewDiscardLogger(),
Pprof: &websvc.PprofConfig{
Enabled: false,
},
@@ -108,7 +105,7 @@ func newTestServer(
svc, err := websvc.New(c)
require.NoError(t, err)
err = svc.Start()
err = svc.Start(testutil.ContextWithTimeout(t, testTimeout))
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, func() (err error) {
return svc.Shutdown(testutil.ContextWithTimeout(t, testTimeout))
@@ -184,10 +181,10 @@ func TestService_Start_getHealthCheck(t *testing.T) {
u := &url.URL{
Scheme: urlutil.SchemeHTTP,
Host: addr.String(),
Path: websvc.PathHealthCheck,
Path: websvc.PathPatternHealthCheck,
}
body := httpGet(t, u, http.StatusOK)
assert.Equal(t, []byte("OK"), body)
assert.Equal(t, []byte(httputil.HealthCheckHandler), body)
}