diff --git a/internal/aghhttp/aghhttp.go b/internal/aghhttp/aghhttp.go index 64e31075..41a59d26 100644 --- a/internal/aghhttp/aghhttp.go +++ b/internal/aghhttp/aghhttp.go @@ -33,5 +33,5 @@ func Error(r *http.Request, w http.ResponseWriter, code int, format string, args // UserAgent returns the ID of the service as a User-Agent string. It can also // be used as the value of the Server HTTP header. func UserAgent() (ua string) { - return fmt.Sprintf("AdGuardDNS/%s", version.Version()) + return fmt.Sprintf("AdGuardHome/%s", version.Version()) } diff --git a/internal/aghtest/interface.go b/internal/aghtest/interface.go index 2de9d372..7aae35ee 100644 --- a/internal/aghtest/interface.go +++ b/internal/aghtest/interface.go @@ -1,6 +1,7 @@ package aghtest import ( + "context" "io/fs" "net" @@ -15,6 +16,8 @@ import ( // Standard Library +// Package fs + // type check var _ fs.FS = &FS{} @@ -58,6 +61,8 @@ func (fsys *StatFS) Stat(name string) (fs.FileInfo, error) { return fsys.OnStat(name) } +// Package net + // type check var _ net.Listener = (*Listener)(nil) @@ -83,32 +88,10 @@ func (l *Listener) Close() (err error) { return l.OnClose() } -// Module dnsproxy - -// type check -var _ upstream.Upstream = (*UpstreamMock)(nil) - -// UpstreamMock is a mock [upstream.Upstream] implementation for tests. -// -// TODO(a.garipov): Replace with all uses of Upstream with UpstreamMock and -// rename it to just Upstream. -type UpstreamMock struct { - OnAddress func() (addr string) - OnExchange func(req *dns.Msg) (resp *dns.Msg, err error) -} - -// Address implements the [upstream.Upstream] interface for *UpstreamMock. -func (u *UpstreamMock) Address() (addr string) { - return u.OnAddress() -} - -// Exchange implements the [upstream.Upstream] interface for *UpstreamMock. -func (u *UpstreamMock) Exchange(req *dns.Msg) (resp *dns.Msg, err error) { - return u.OnExchange(req) -} - // Module AdGuardHome +// Package aghos + // type check var _ aghos.FSWatcher = (*FSWatcher)(nil) @@ -133,3 +116,57 @@ func (w *FSWatcher) Add(name string) (err error) { func (w *FSWatcher) Close() (err error) { return w.OnClose() } + +// Package websvc + +// ServiceWithConfig is a mock [websvc.ServiceWithConfig] implementation for +// tests. +type ServiceWithConfig[ConfigType any] struct { + OnStart func() (err error) + OnShutdown func(ctx context.Context) (err error) + OnConfig func() (c ConfigType) +} + +// Start implements the [websvc.ServiceWithConfig] interface for +// *ServiceWithConfig. +func (s *ServiceWithConfig[_]) Start() (err error) { + return s.OnStart() +} + +// Shutdown implements the [websvc.ServiceWithConfig] interface for +// *ServiceWithConfig. +func (s *ServiceWithConfig[_]) Shutdown(ctx context.Context) (err error) { + return s.OnShutdown(ctx) +} + +// Config implements the [websvc.ServiceWithConfig] interface for +// *ServiceWithConfig. +func (s *ServiceWithConfig[ConfigType]) Config() (c ConfigType) { + return s.OnConfig() +} + +// Module dnsproxy + +// Package upstream + +// type check +var _ upstream.Upstream = (*UpstreamMock)(nil) + +// UpstreamMock is a mock [upstream.Upstream] implementation for tests. +// +// TODO(a.garipov): Replace with all uses of Upstream with UpstreamMock and +// rename it to just Upstream. +type UpstreamMock struct { + OnAddress func() (addr string) + OnExchange func(req *dns.Msg) (resp *dns.Msg, err error) +} + +// Address implements the [upstream.Upstream] interface for *UpstreamMock. +func (u *UpstreamMock) Address() (addr string) { + return u.OnAddress() +} + +// Exchange implements the [upstream.Upstream] interface for *UpstreamMock. +func (u *UpstreamMock) Exchange(req *dns.Msg) (resp *dns.Msg, err error) { + return u.OnExchange(req) +} diff --git a/internal/aghtest/interface_test.go b/internal/aghtest/interface_test.go index 5a465c2c..bd2c0823 100644 --- a/internal/aghtest/interface_test.go +++ b/internal/aghtest/interface_test.go @@ -1,9 +1,9 @@ package aghtest_test import ( - "github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/AdGuardHome/internal/aghtest" + "github.com/AdguardTeam/AdGuardHome/internal/next/websvc" ) // type check -var _ aghos.FSWatcher = (*aghtest.FSWatcher)(nil) +var _ websvc.ServiceWithConfig[struct{}] = (*aghtest.ServiceWithConfig[struct{}])(nil) diff --git a/internal/next/websvc/dns.go b/internal/next/websvc/dns.go index 3ec98692..8846813d 100644 --- a/internal/next/websvc/dns.go +++ b/internal/next/websvc/dns.go @@ -12,8 +12,6 @@ import ( // DNS Settings Handlers -// TODO(a.garipov): !! Write tests! - // ReqPatchSettingsDNS describes the request to the PATCH /api/v1/settings/dns // HTTP API. type ReqPatchSettingsDNS struct { @@ -49,7 +47,7 @@ func (svc *Service) handlePatchSettingsDNS(w http.ResponseWriter, r *http.Reques err := json.NewDecoder(r.Body).Decode(&req) if err != nil { - writeHTTPError(w, r, fmt.Errorf("decoding: %w", err)) + writeJSONErrorResponse(w, r, fmt.Errorf("decoding: %w", err)) return } @@ -64,7 +62,7 @@ func (svc *Service) handlePatchSettingsDNS(w http.ResponseWriter, r *http.Reques ctx := r.Context() err = svc.confMgr.UpdateDNS(ctx, newConf) if err != nil { - writeHTTPError(w, r, fmt.Errorf("updating: %w", err)) + writeJSONErrorResponse(w, r, fmt.Errorf("updating: %w", err)) return } @@ -72,12 +70,12 @@ func (svc *Service) handlePatchSettingsDNS(w http.ResponseWriter, r *http.Reques newSvc := svc.confMgr.DNS() err = newSvc.Start() if err != nil { - writeHTTPError(w, r, fmt.Errorf("starting new service: %w", err)) + writeJSONErrorResponse(w, r, fmt.Errorf("starting new service: %w", err)) return } - writeJSONResponse(w, r, &HTTPAPIDNSSettings{ + writeJSONOKResponse(w, r, &HTTPAPIDNSSettings{ Addresses: newConf.Addresses, BootstrapServers: newConf.BootstrapServers, UpstreamServers: newConf.UpstreamServers, diff --git a/internal/next/websvc/dns_test.go b/internal/next/websvc/dns_test.go new file mode 100644 index 00000000..f774c3d8 --- /dev/null +++ b/internal/next/websvc/dns_test.go @@ -0,0 +1,68 @@ +package websvc_test + +import ( + "context" + "encoding/json" + "net/http" + "net/netip" + "net/url" + "sync/atomic" + "testing" + "time" + + "github.com/AdguardTeam/AdGuardHome/internal/aghtest" + "github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc" + "github.com/AdguardTeam/AdGuardHome/internal/next/websvc" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestService_HandlePatchSettingsDNS(t *testing.T) { + wantDNS := &websvc.HTTPAPIDNSSettings{ + Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.1.1:53")}, + BootstrapServers: []string{"1.0.0.1"}, + UpstreamServers: []string{"1.1.1.1"}, + UpstreamTimeout: websvc.JSONDuration(2 * time.Second), + } + + // TODO(a.garipov): Use [atomic.Bool] in Go 1.19. + var numStarted uint64 + confMgr := newConfigManager() + confMgr.onDNS = func() (s websvc.ServiceWithConfig[*dnssvc.Config]) { + return &aghtest.ServiceWithConfig[*dnssvc.Config]{ + OnStart: func() (err error) { + atomic.AddUint64(&numStarted, 1) + + return nil + }, + OnShutdown: func(_ context.Context) (err error) { panic("not implemented") }, + OnConfig: func() (c *dnssvc.Config) { panic("not implemented") }, + } + } + confMgr.onUpdateDNS = func(ctx context.Context, c *dnssvc.Config) (err error) { + return nil + } + + _, addr := newTestServer(t, confMgr) + u := &url.URL{ + Scheme: "http", + Host: addr.String(), + Path: websvc.PathV1SettingsDNS, + } + + req := jobj{ + "addresses": wantDNS.Addresses, + "bootstrap_servers": wantDNS.BootstrapServers, + "upstream_servers": wantDNS.UpstreamServers, + "upstream_timeout": wantDNS.UpstreamTimeout, + } + + respBody := httpPatch(t, u, req, http.StatusOK) + resp := &websvc.HTTPAPIDNSSettings{} + err := json.Unmarshal(respBody, resp) + require.NoError(t, err) + + assert.Equal(t, uint64(1), numStarted) + assert.Equal(t, wantDNS, resp) + assert.Equal(t, wantDNS, resp) +} diff --git a/internal/next/websvc/http.go b/internal/next/websvc/http.go index 928b37e0..b00a4e70 100644 --- a/internal/next/websvc/http.go +++ b/internal/next/websvc/http.go @@ -45,7 +45,7 @@ func (svc *Service) handlePatchSettingsHTTP(w http.ResponseWriter, r *http.Reque err := json.NewDecoder(r.Body).Decode(&req) if err != nil { - writeHTTPError(w, r, fmt.Errorf("decoding: %w", err)) + writeJSONErrorResponse(w, r, fmt.Errorf("decoding: %w", err)) return } @@ -59,7 +59,7 @@ func (svc *Service) handlePatchSettingsHTTP(w http.ResponseWriter, r *http.Reque ForceHTTPS: svc.forceHTTPS, } - writeJSONResponse(w, r, &HTTPAPIHTTPSettings{ + writeJSONOKResponse(w, r, &HTTPAPIHTTPSettings{ Addresses: newConf.Addresses, SecureAddresses: newConf.SecureAddresses, Timeout: JSONDuration(newConf.Timeout), @@ -81,13 +81,13 @@ func (svc *Service) handlePatchSettingsHTTP(w http.ResponseWriter, r *http.Reque updErr := svc.confMgr.UpdateWeb(updCtx, newConf) if updErr != nil { - writeHTTPError(w, r, fmt.Errorf("updating: %w", updErr)) + writeJSONErrorResponse(w, r, fmt.Errorf("updating: %w", updErr)) return } // TODO(a.garipov): !! Add some kind of timeout? Context? - var newSvc *Service + var newSvc ServiceWithConfig[*Config] for newSvc = svc.confMgr.Web(); newSvc == svc; { log.Debug("websvc: waiting for new websvc to be configured") time.Sleep(1 * time.Second) diff --git a/internal/next/websvc/http_test.go b/internal/next/websvc/http_test.go index 6831e8e4..baf384da 100644 --- a/internal/next/websvc/http_test.go +++ b/internal/next/websvc/http_test.go @@ -24,7 +24,7 @@ func TestService_HandlePatchSettingsHTTP(t *testing.T) { } confMgr := newConfigManager() - confMgr.onWeb = func() (c *websvc.Service) { + confMgr.onWeb = func() (s websvc.ServiceWithConfig[*websvc.Config]) { return websvc.New(&websvc.Config{ TLS: &tls.Config{ Certificates: []tls.Certificate{{}}, @@ -50,7 +50,7 @@ func TestService_HandlePatchSettingsHTTP(t *testing.T) { "addresses": wantWeb.Addresses, "secure_addresses": wantWeb.SecureAddresses, "timeout": wantWeb.Timeout, - "ForceHTTPS": wantWeb.ForceHTTPS, + "force_https": wantWeb.ForceHTTPS, } respBody := httpPatch(t, u, req, http.StatusOK) diff --git a/internal/next/websvc/json.go b/internal/next/websvc/json.go index 4b35c694..fa2010a8 100644 --- a/internal/next/websvc/json.go +++ b/internal/next/websvc/json.go @@ -3,7 +3,6 @@ package websvc import ( "encoding/json" "fmt" - "io" "net/http" "strconv" "time" @@ -87,29 +86,58 @@ func (t *JSONTime) UnmarshalJSON(b []byte) (err error) { return nil } -// writeJSONResponse encodes v into w and logs any errors it encounters. r is -// used to get additional information from the request. -func writeJSONResponse(w http.ResponseWriter, r *http.Request, v any) { +// writeJSONOKResponse writes headers with the code 200 OK, encodes v into w, +// and logs any errors it encounters. r is used to get additional information +// from the request. +func writeJSONOKResponse(w http.ResponseWriter, r *http.Request, v any) { + writeJSONResponse(w, r, v, http.StatusOK) +} + +// writeJSONResponse writes headers with code, encodes v into w, and logs any +// errors it encounters. r is used to get additional information from the +// request. +func writeJSONResponse(w http.ResponseWriter, r *http.Request, v any, code int) { // TODO(a.garipov): Put some of these to a middleware. h := w.Header() h.Set(aghhttp.HdrNameContentType, aghhttp.HdrValApplicationJSON) h.Set(aghhttp.HdrNameServer, aghhttp.UserAgent()) + w.WriteHeader(code) + err := json.NewEncoder(w).Encode(v) if err != nil { log.Error("websvc: writing resp to %s %s: %s", r.Method, r.URL.Path, err) } } -// writeHTTPError is a helper for logging and writing HTTP errors. +// ErrorCode is the error code as used by the HTTP API. See the ErrorCode +// definition in the OpenAPI specification. +type ErrorCode string + +// ErrorCode constants. // -// TODO(a.garipov): Improve codes, and add JSON error codes. -func writeHTTPError(w http.ResponseWriter, r *http.Request, err error) { +// TODO(a.garipov): Expand and document codes. +const ( + // ErrorCodeTMP000 is the temporary error code used for all errors. + ErrorCodeTMP000 = "" +) + +// HTTPAPIErrorResp is the error response as used by the HTTP API. See the +// BadRequestResp, InternalServerErrorResp, and similar objects in the OpenAPI +// specification. +type HTTPAPIErrorResp struct { + Code ErrorCode `json:"code"` + Msg string `json:"msg"` +} + +// writeJSONErrorResponse encodes err as a JSON error into w, and logs any +// errors it encounters. r is used to get additional information from the +// request. +func writeJSONErrorResponse(w http.ResponseWriter, r *http.Request, err error) { log.Error("websvc: %s %s: %s", r.Method, r.URL.Path, err) - w.WriteHeader(http.StatusUnprocessableEntity) - _, werr := io.WriteString(w, err.Error()) - if werr != nil { - log.Debug("websvc: writing error resp to %s %s: %s", r.Method, r.URL.Path, werr) - } + writeJSONResponse(w, r, &HTTPAPIErrorResp{ + Code: ErrorCodeTMP000, + Msg: err.Error(), + }, http.StatusUnprocessableEntity) } diff --git a/internal/next/websvc/settings.go b/internal/next/websvc/settings.go index 1c4267ef..b6c5a80a 100644 --- a/internal/next/websvc/settings.go +++ b/internal/next/websvc/settings.go @@ -25,7 +25,7 @@ func (svc *Service) handleGetSettingsAll(w http.ResponseWriter, r *http.Request) httpConf := webSvc.Config() // TODO(a.garipov): Add all currently supported parameters. - writeJSONResponse(w, r, &RespGetV1SettingsAll{ + writeJSONOKResponse(w, r, &RespGetV1SettingsAll{ DNS: &HTTPAPIDNSSettings{ Addresses: dnsConf.Addresses, BootstrapServers: dnsConf.BootstrapServers, diff --git a/internal/next/websvc/settings_test.go b/internal/next/websvc/settings_test.go index 519edb6f..dadb4b55 100644 --- a/internal/next/websvc/settings_test.go +++ b/internal/next/websvc/settings_test.go @@ -33,7 +33,7 @@ func TestService_HandleGetSettingsAll(t *testing.T) { } confMgr := newConfigManager() - confMgr.onDNS = func() (c *dnssvc.Service) { + confMgr.onDNS = func() (s websvc.ServiceWithConfig[*dnssvc.Config]) { c, err := dnssvc.New(&dnssvc.Config{ Addresses: wantDNS.Addresses, UpstreamServers: wantDNS.UpstreamServers, @@ -45,7 +45,7 @@ func TestService_HandleGetSettingsAll(t *testing.T) { return c } - confMgr.onWeb = func() (c *websvc.Service) { + confMgr.onWeb = func() (s websvc.ServiceWithConfig[*websvc.Config]) { return websvc.New(&websvc.Config{ TLS: &tls.Config{ Certificates: []tls.Certificate{{}}, diff --git a/internal/next/websvc/system.go b/internal/next/websvc/system.go index 1d329db8..fbf60fe4 100644 --- a/internal/next/websvc/system.go +++ b/internal/next/websvc/system.go @@ -23,7 +23,7 @@ type RespGetV1SystemInfo struct { // handleGetV1SystemInfo is the handler for the GET /api/v1/system/info HTTP // API. func (svc *Service) handleGetV1SystemInfo(w http.ResponseWriter, r *http.Request) { - writeJSONResponse(w, r, &RespGetV1SystemInfo{ + writeJSONOKResponse(w, r, &RespGetV1SystemInfo{ Arch: runtime.GOARCH, Channel: version.Channel(), OS: runtime.GOOS, diff --git a/internal/next/websvc/websvc.go b/internal/next/websvc/websvc.go index 5247dbf1..75f7d001 100644 --- a/internal/next/websvc/websvc.go +++ b/internal/next/websvc/websvc.go @@ -24,10 +24,21 @@ import ( httptreemux "github.com/dimfeld/httptreemux/v5" ) +// ServiceWithConfig is an extension of the [agh.Service] interface for services +// that can return their configuration. +// +// TODO(a.garipov): Consider removing this generic interface if we figure out +// how to make it testable in a better way. +type ServiceWithConfig[ConfigType any] interface { + agh.Service + + Config() (c ConfigType) +} + // ConfigManager is the configuration manager interface. type ConfigManager interface { - DNS() (svc *dnssvc.Service) - Web() (svc *Service) + DNS() (svc ServiceWithConfig[*dnssvc.Config]) + Web() (svc ServiceWithConfig[*Config]) UpdateDNS(ctx context.Context, c *dnssvc.Config) (err error) UpdateWeb(ctx context.Context, c *Config) (err error) diff --git a/internal/next/websvc/websvc_test.go b/internal/next/websvc/websvc_test.go index c58010c4..dbce77d5 100644 --- a/internal/next/websvc/websvc_test.go +++ b/internal/next/websvc/websvc_test.go @@ -34,20 +34,20 @@ var _ websvc.ConfigManager = (*configManager)(nil) // configManager is a [websvc.ConfigManager] for tests. type configManager struct { - onDNS func() (svc *dnssvc.Service) - onWeb func() (svc *websvc.Service) + onDNS func() (svc websvc.ServiceWithConfig[*dnssvc.Config]) + onWeb func() (svc websvc.ServiceWithConfig[*websvc.Config]) onUpdateDNS func(ctx context.Context, c *dnssvc.Config) (err error) onUpdateWeb func(ctx context.Context, c *websvc.Config) (err error) } // DNS implements the [websvc.ConfigManager] interface for *configManager. -func (m *configManager) DNS() (svc *dnssvc.Service) { +func (m *configManager) DNS() (svc websvc.ServiceWithConfig[*dnssvc.Config]) { return m.onDNS() } // Web implements the [websvc.ConfigManager] interface for *configManager. -func (m *configManager) Web() (svc *websvc.Service) { +func (m *configManager) Web() (svc websvc.ServiceWithConfig[*websvc.Config]) { return m.onWeb() } @@ -64,8 +64,8 @@ func (m *configManager) UpdateWeb(ctx context.Context, c *websvc.Config) (err er // newConfigManager returns a *configManager all methods of which panic. func newConfigManager() (m *configManager) { return &configManager{ - onDNS: func() (svc *dnssvc.Service) { panic("not implemented") }, - onWeb: func() (svc *websvc.Service) { panic("not implemented") }, + onDNS: func() (svc websvc.ServiceWithConfig[*dnssvc.Config]) { panic("not implemented") }, + onWeb: func() (svc websvc.ServiceWithConfig[*websvc.Config]) { panic("not implemented") }, onUpdateDNS: func(_ context.Context, _ *dnssvc.Config) (err error) { panic("not implemented") },