From 1989c91c078c81836636f18d2b78b3f66369abca Mon Sep 17 00:00:00 2001 From: Ainar Garipov Date: Fri, 9 Sep 2022 15:05:33 +0300 Subject: [PATCH] websvc: imp tests --- internal/next/dnssvc/dnssvc.go | 6 +- internal/next/websvc/dns.go | 22 +++---- internal/next/websvc/http.go | 21 +++---- internal/next/websvc/http_test.go | 62 +++++++++++++++++++ internal/next/websvc/json.go | 55 +++++++++++++--- .../{json_internal_test.go => json_test.go} | 23 +++---- internal/next/websvc/settings.go | 6 +- internal/next/websvc/settings_test.go | 9 ++- internal/next/websvc/system.go | 4 +- .../next/websvc/waitlistener_internal_test.go | 8 +-- internal/next/websvc/websvc_test.go | 34 ++++++++++ openapi/v1.yaml | 18 ++++-- 12 files changed, 201 insertions(+), 67 deletions(-) create mode 100644 internal/next/websvc/http_test.go rename internal/next/websvc/{json_internal_test.go => json_test.go} (82%) diff --git a/internal/next/dnssvc/dnssvc.go b/internal/next/dnssvc/dnssvc.go index b62d3e51..f25fa294 100644 --- a/internal/next/dnssvc/dnssvc.go +++ b/internal/next/dnssvc/dnssvc.go @@ -173,7 +173,11 @@ func (svc *Service) Start() (err error) { // TODO(a.garipov): [proxy.Proxy.Start] doesn't actually have any way to // tell when all servers are actually up, so at best this is merely an // assumption. - atomic.StoreUint64(&svc.running, 1) + if err != nil { + atomic.StoreUint64(&svc.running, 0) + } else { + atomic.StoreUint64(&svc.running, 1) + } }() return svc.proxy.Start() diff --git a/internal/next/websvc/dns.go b/internal/next/websvc/dns.go index 4b86dfc1..3ec98692 100644 --- a/internal/next/websvc/dns.go +++ b/internal/next/websvc/dns.go @@ -5,9 +5,9 @@ import ( "fmt" "net/http" "net/netip" + "time" "github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc" - "github.com/AdguardTeam/golibs/timeutil" ) // DNS Settings Handlers @@ -19,10 +19,10 @@ import ( type ReqPatchSettingsDNS struct { // TODO(a.garipov): Add more as we go. - Addresses []netip.AddrPort `json:"addresses"` - BootstrapServers []string `json:"bootstrap_servers"` - UpstreamServers []string `json:"upstream_servers"` - UpstreamTimeout timeutil.Duration `json:"upstream_timeout"` + Addresses []netip.AddrPort `json:"addresses"` + BootstrapServers []string `json:"bootstrap_servers"` + UpstreamServers []string `json:"upstream_servers"` + UpstreamTimeout JSONDuration `json:"upstream_timeout"` } // HTTPAPIDNSSettings are the DNS settings as used by the HTTP API. See the @@ -30,10 +30,10 @@ type ReqPatchSettingsDNS struct { type HTTPAPIDNSSettings struct { // TODO(a.garipov): Add more as we go. - Addresses []netip.AddrPort `json:"addresses"` - BootstrapServers []string `json:"bootstrap_servers"` - UpstreamServers []string `json:"upstream_servers"` - UpstreamTimeout timeutil.Duration `json:"upstream_timeout"` + Addresses []netip.AddrPort `json:"addresses"` + BootstrapServers []string `json:"bootstrap_servers"` + UpstreamServers []string `json:"upstream_servers"` + UpstreamTimeout JSONDuration `json:"upstream_timeout"` } // handlePatchSettingsDNS is the handler for the PATCH /api/v1/settings/dns HTTP @@ -58,7 +58,7 @@ func (svc *Service) handlePatchSettingsDNS(w http.ResponseWriter, r *http.Reques Addresses: req.Addresses, BootstrapServers: req.BootstrapServers, UpstreamServers: req.UpstreamServers, - UpstreamTimeout: req.UpstreamTimeout.Duration, + UpstreamTimeout: time.Duration(req.UpstreamTimeout), } ctx := r.Context() @@ -81,6 +81,6 @@ func (svc *Service) handlePatchSettingsDNS(w http.ResponseWriter, r *http.Reques Addresses: newConf.Addresses, BootstrapServers: newConf.BootstrapServers, UpstreamServers: newConf.UpstreamServers, - UpstreamTimeout: timeutil.Duration{Duration: newConf.UpstreamTimeout}, + UpstreamTimeout: JSONDuration(newConf.UpstreamTimeout), }) } diff --git a/internal/next/websvc/http.go b/internal/next/websvc/http.go index f7145bdd..928b37e0 100644 --- a/internal/next/websvc/http.go +++ b/internal/next/websvc/http.go @@ -9,13 +9,10 @@ import ( "time" "github.com/AdguardTeam/golibs/log" - "github.com/AdguardTeam/golibs/timeutil" ) // HTTP Settings Handlers -// TODO(a.garipov): !! Write tests! - // ReqPatchSettingsHTTP describes the request to the PATCH /api/v1/settings/http // HTTP API. type ReqPatchSettingsHTTP struct { @@ -23,9 +20,9 @@ type ReqPatchSettingsHTTP struct { // // TODO(a.garipov): Add wait time. - Addresses []netip.AddrPort `json:"addresses"` - SecureAddresses []netip.AddrPort `json:"secure_addresses"` - Timeout timeutil.Duration `json:"timeout"` + Addresses []netip.AddrPort `json:"addresses"` + SecureAddresses []netip.AddrPort `json:"secure_addresses"` + Timeout JSONDuration `json:"timeout"` } // HTTPAPIHTTPSettings are the HTTP settings as used by the HTTP API. See the @@ -33,10 +30,10 @@ type ReqPatchSettingsHTTP struct { type HTTPAPIHTTPSettings struct { // TODO(a.garipov): Add more as we go. - Addresses []netip.AddrPort `json:"addresses"` - SecureAddresses []netip.AddrPort `json:"secure_addresses"` - Timeout timeutil.Duration `json:"timeout"` - ForceHTTPS bool `json:"force_https"` + Addresses []netip.AddrPort `json:"addresses"` + SecureAddresses []netip.AddrPort `json:"secure_addresses"` + Timeout JSONDuration `json:"timeout"` + ForceHTTPS bool `json:"force_https"` } // handlePatchSettingsHTTP is the handler for the PATCH /api/v1/settings/http @@ -58,14 +55,14 @@ func (svc *Service) handlePatchSettingsHTTP(w http.ResponseWriter, r *http.Reque TLS: svc.tls, Addresses: req.Addresses, SecureAddresses: req.SecureAddresses, - Timeout: req.Timeout.Duration, + Timeout: time.Duration(req.Timeout), ForceHTTPS: svc.forceHTTPS, } writeJSONResponse(w, r, &HTTPAPIHTTPSettings{ Addresses: newConf.Addresses, SecureAddresses: newConf.SecureAddresses, - Timeout: timeutil.Duration{Duration: newConf.Timeout}, + Timeout: JSONDuration(newConf.Timeout), ForceHTTPS: newConf.ForceHTTPS, }) diff --git a/internal/next/websvc/http_test.go b/internal/next/websvc/http_test.go new file mode 100644 index 00000000..6831e8e4 --- /dev/null +++ b/internal/next/websvc/http_test.go @@ -0,0 +1,62 @@ +package websvc_test + +import ( + "context" + "crypto/tls" + "encoding/json" + "net/http" + "net/netip" + "net/url" + "testing" + "time" + + "github.com/AdguardTeam/AdGuardHome/internal/next/websvc" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestService_HandlePatchSettingsHTTP(t *testing.T) { + wantWeb := &websvc.HTTPAPIHTTPSettings{ + Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.1.1:80")}, + SecureAddresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.1.1:443")}, + Timeout: websvc.JSONDuration(10 * time.Second), + ForceHTTPS: false, + } + + confMgr := newConfigManager() + confMgr.onWeb = func() (c *websvc.Service) { + return websvc.New(&websvc.Config{ + TLS: &tls.Config{ + Certificates: []tls.Certificate{{}}, + }, + Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:80")}, + SecureAddresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:443")}, + Timeout: 5 * time.Second, + ForceHTTPS: true, + }) + } + confMgr.onUpdateWeb = func(ctx context.Context, c *websvc.Config) (err error) { + return nil + } + + _, addr := newTestServer(t, confMgr) + u := &url.URL{ + Scheme: "http", + Host: addr.String(), + Path: websvc.PathV1SettingsHTTP, + } + + req := jobj{ + "addresses": wantWeb.Addresses, + "secure_addresses": wantWeb.SecureAddresses, + "timeout": wantWeb.Timeout, + "ForceHTTPS": wantWeb.ForceHTTPS, + } + + respBody := httpPatch(t, u, req, http.StatusOK) + resp := &websvc.HTTPAPIHTTPSettings{} + err := json.Unmarshal(respBody, resp) + require.NoError(t, err) + + assert.Equal(t, wantWeb, resp) +} diff --git a/internal/next/websvc/json.go b/internal/next/websvc/json.go index e08d7e2e..15b2f7e2 100644 --- a/internal/next/websvc/json.go +++ b/internal/next/websvc/json.go @@ -13,19 +13,54 @@ import ( // JSON Utilities -// jsonTime is a time.Time that can be decoded from JSON and encoded into JSON -// according to our API conventions. -type jsonTime time.Time +// JSONDuration is a time.Duration that can be decoded from JSON and encoded +// into JSON according to our API conventions. +type JSONDuration time.Duration // type check -var _ json.Marshaler = jsonTime{} +var _ json.Marshaler = JSONDuration(0) + +// MarshalJSON implements the json.Marshaler interface for JSONDuration. err is +// always nil. +func (d JSONDuration) MarshalJSON() (b []byte, err error) { + msec := float64(time.Duration(d)) / nsecPerMsec + b = strconv.AppendFloat(nil, msec, 'f', -1, 64) + + return b, nil +} + +// type check +var _ json.Unmarshaler = (*JSONDuration)(nil) + +// UnmarshalJSON implements the json.Marshaler interface for *JSONDuration. +func (d *JSONDuration) UnmarshalJSON(b []byte) (err error) { + if d == nil { + return fmt.Errorf("json duration is nil") + } + + msec, err := strconv.ParseFloat(string(b), 64) + if err != nil { + return fmt.Errorf("parsing json time: %w", err) + } + + *d = JSONDuration(int64(msec * nsecPerMsec)) + + return nil +} + +// JSONTime is a time.Time that can be decoded from JSON and encoded into JSON +// according to our API conventions. +type JSONTime time.Time + +// type check +var _ json.Marshaler = JSONTime{} // nsecPerMsec is the number of nanoseconds in a millisecond. const nsecPerMsec = float64(time.Millisecond / time.Nanosecond) -// MarshalJSON implements the json.Marshaler interface for jsonTime. err is +// MarshalJSON implements the json.Marshaler interface for JSONTime. err is // always nil. -func (t jsonTime) MarshalJSON() (b []byte, err error) { +func (t JSONTime) MarshalJSON() (b []byte, err error) { msec := float64(time.Time(t).UnixNano()) / nsecPerMsec b = strconv.AppendFloat(nil, msec, 'f', -1, 64) @@ -33,10 +68,10 @@ func (t jsonTime) MarshalJSON() (b []byte, err error) { } // type check -var _ json.Unmarshaler = (*jsonTime)(nil) +var _ json.Unmarshaler = (*JSONTime)(nil) -// UnmarshalJSON implements the json.Marshaler interface for *jsonTime. -func (t *jsonTime) UnmarshalJSON(b []byte) (err error) { +// UnmarshalJSON implements the json.Marshaler interface for *JSONTime. +func (t *JSONTime) UnmarshalJSON(b []byte) (err error) { if t == nil { return fmt.Errorf("json time is nil") } @@ -46,7 +81,7 @@ func (t *jsonTime) UnmarshalJSON(b []byte) (err error) { return fmt.Errorf("parsing json time: %w", err) } - *t = jsonTime(time.Unix(0, int64(msec*nsecPerMsec)).UTC()) + *t = JSONTime(time.Unix(0, int64(msec*nsecPerMsec)).UTC()) return nil } diff --git a/internal/next/websvc/json_internal_test.go b/internal/next/websvc/json_test.go similarity index 82% rename from internal/next/websvc/json_internal_test.go rename to internal/next/websvc/json_test.go index 69810736..90874958 100644 --- a/internal/next/websvc/json_internal_test.go +++ b/internal/next/websvc/json_test.go @@ -1,17 +1,18 @@ -package websvc +package websvc_test import ( "encoding/json" "testing" "time" + "github.com/AdguardTeam/AdGuardHome/internal/next/websvc" "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) // testJSONTime is the JSON time for tests. -var testJSONTime = jsonTime(time.Unix(1_234_567_890, 123_456_000).UTC()) +var testJSONTime = websvc.JSONTime(time.Unix(1_234_567_890, 123_456_000).UTC()) // testJSONTimeStr is the string with the JSON encoding of testJSONTime. const testJSONTimeStr = "1234567890123.456" @@ -20,17 +21,17 @@ func TestJSONTime_MarshalJSON(t *testing.T) { testCases := []struct { name string wantErrMsg string - in jsonTime + in websvc.JSONTime want []byte }{{ name: "unix_zero", wantErrMsg: "", - in: jsonTime(time.Unix(0, 0)), + in: websvc.JSONTime(time.Unix(0, 0)), want: []byte("0"), }, { name: "empty", wantErrMsg: "", - in: jsonTime{}, + in: websvc.JSONTime{}, want: []byte("-6795364578871.345"), }, { name: "time", @@ -50,7 +51,7 @@ func TestJSONTime_MarshalJSON(t *testing.T) { t.Run("json", func(t *testing.T) { in := &struct { - A jsonTime + A websvc.JSONTime }{ A: testJSONTime, } @@ -66,7 +67,7 @@ func TestJSONTime_UnmarshalJSON(t *testing.T) { testCases := []struct { name string wantErrMsg string - want jsonTime + want websvc.JSONTime data []byte }{{ name: "time", @@ -77,13 +78,13 @@ func TestJSONTime_UnmarshalJSON(t *testing.T) { name: "bad", wantErrMsg: `parsing json time: strconv.ParseFloat: parsing "{}": ` + `invalid syntax`, - want: jsonTime{}, + want: websvc.JSONTime{}, data: []byte(`{}`), }} for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - var got jsonTime + var got websvc.JSONTime err := got.UnmarshalJSON(tc.data) testutil.AssertErrorMsg(t, tc.wantErrMsg, err) @@ -92,7 +93,7 @@ func TestJSONTime_UnmarshalJSON(t *testing.T) { } t.Run("nil", func(t *testing.T) { - err := (*jsonTime)(nil).UnmarshalJSON([]byte("0")) + err := (*websvc.JSONTime)(nil).UnmarshalJSON([]byte("0")) require.Error(t, err) msg := err.Error() @@ -102,7 +103,7 @@ func TestJSONTime_UnmarshalJSON(t *testing.T) { t.Run("json", func(t *testing.T) { want := testJSONTime var got struct { - A jsonTime + A websvc.JSONTime } err := json.Unmarshal([]byte(`{"A":`+testJSONTimeStr+`}`), &got) diff --git a/internal/next/websvc/settings.go b/internal/next/websvc/settings.go index 9052e471..1c4267ef 100644 --- a/internal/next/websvc/settings.go +++ b/internal/next/websvc/settings.go @@ -2,8 +2,6 @@ package websvc import ( "net/http" - - "github.com/AdguardTeam/golibs/timeutil" ) // All Settings Handlers @@ -32,12 +30,12 @@ func (svc *Service) handleGetSettingsAll(w http.ResponseWriter, r *http.Request) Addresses: dnsConf.Addresses, BootstrapServers: dnsConf.BootstrapServers, UpstreamServers: dnsConf.UpstreamServers, - UpstreamTimeout: timeutil.Duration{Duration: dnsConf.UpstreamTimeout}, + UpstreamTimeout: JSONDuration(dnsConf.UpstreamTimeout), }, HTTP: &HTTPAPIHTTPSettings{ Addresses: httpConf.Addresses, SecureAddresses: httpConf.SecureAddresses, - Timeout: timeutil.Duration{Duration: httpConf.Timeout}, + Timeout: JSONDuration(httpConf.Timeout), ForceHTTPS: httpConf.ForceHTTPS, }, }) diff --git a/internal/next/websvc/settings_test.go b/internal/next/websvc/settings_test.go index a1652230..519edb6f 100644 --- a/internal/next/websvc/settings_test.go +++ b/internal/next/websvc/settings_test.go @@ -11,7 +11,6 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc" "github.com/AdguardTeam/AdGuardHome/internal/next/websvc" - "github.com/AdguardTeam/golibs/timeutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -23,13 +22,13 @@ func TestService_HandleGetSettingsAll(t *testing.T) { Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:53")}, BootstrapServers: []string{"94.140.14.140", "94.140.14.141"}, UpstreamServers: []string{"94.140.14.14", "1.1.1.1"}, - UpstreamTimeout: timeutil.Duration{Duration: 1 * time.Second}, + UpstreamTimeout: websvc.JSONDuration(1 * time.Second), } wantWeb := &websvc.HTTPAPIHTTPSettings{ Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:80")}, SecureAddresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:443")}, - Timeout: timeutil.Duration{Duration: 5 * time.Second}, + Timeout: websvc.JSONDuration(5 * time.Second), ForceHTTPS: true, } @@ -39,7 +38,7 @@ func TestService_HandleGetSettingsAll(t *testing.T) { Addresses: wantDNS.Addresses, UpstreamServers: wantDNS.UpstreamServers, BootstrapServers: wantDNS.BootstrapServers, - UpstreamTimeout: wantDNS.UpstreamTimeout.Duration, + UpstreamTimeout: time.Duration(wantDNS.UpstreamTimeout), }) require.NoError(t, err) @@ -53,7 +52,7 @@ func TestService_HandleGetSettingsAll(t *testing.T) { }, Addresses: wantWeb.Addresses, SecureAddresses: wantWeb.SecureAddresses, - Timeout: wantWeb.Timeout.Duration, + Timeout: time.Duration(wantWeb.Timeout), ForceHTTPS: true, }) } diff --git a/internal/next/websvc/system.go b/internal/next/websvc/system.go index 47d0c63c..1d329db8 100644 --- a/internal/next/websvc/system.go +++ b/internal/next/websvc/system.go @@ -16,7 +16,7 @@ type RespGetV1SystemInfo struct { Channel string `json:"channel"` OS string `json:"os"` NewVersion string `json:"new_version,omitempty"` - Start jsonTime `json:"start"` + Start JSONTime `json:"start"` Version string `json:"version"` } @@ -29,7 +29,7 @@ func (svc *Service) handleGetV1SystemInfo(w http.ResponseWriter, r *http.Request OS: runtime.GOOS, // TODO(a.garipov): Fill this when we have an updater. NewVersion: "", - Start: jsonTime(svc.start), + Start: JSONTime(svc.start), Version: version.Version(), }) } diff --git a/internal/next/websvc/waitlistener_internal_test.go b/internal/next/websvc/waitlistener_internal_test.go index 3d51baa6..e151341b 100644 --- a/internal/next/websvc/waitlistener_internal_test.go +++ b/internal/next/websvc/waitlistener_internal_test.go @@ -20,12 +20,8 @@ func TestWaitListener_Accept(t *testing.T) { return nil, nil }, - OnAddr: func() (addr net.Addr) { - panic("not implemented") - }, - OnClose: func() (err error) { - panic("not implemented") - }, + OnAddr: func() (addr net.Addr) { panic("not implemented") }, + OnClose: func() (err error) { panic("not implemented") }, } wg := &sync.WaitGroup{} diff --git a/internal/next/websvc/websvc_test.go b/internal/next/websvc/websvc_test.go index 476fbc01..c58010c4 100644 --- a/internal/next/websvc/websvc_test.go +++ b/internal/next/websvc/websvc_test.go @@ -1,7 +1,9 @@ package websvc_test import ( + "bytes" "context" + "encoding/json" "io" "net/http" "net/netip" @@ -113,6 +115,9 @@ func newTestServer( return svc, c.Addresses[0] } +// jobj is a utility alias for JSON objects. +type jobj map[string]any + // httpGet is a helper that performs an HTTP GET request and returns the body of // the response as well as checks that the status code is correct. // @@ -138,6 +143,35 @@ func httpGet(t testing.TB, u *url.URL, wantCode int) (body []byte) { return body } +// httpPatch is a helper that performs an HTTP PATCH request with JSON-encoded +// reqBody as the request body and returns the body of the response as well as +// checks that the status code is correct. +// +// TODO(a.garipov): Add helpers for other methods. +func httpPatch(t testing.TB, u *url.URL, reqBody any, wantCode int) (body []byte) { + t.Helper() + + b, err := json.Marshal(reqBody) + require.NoErrorf(t, err, "marshaling reqBody") + + req, err := http.NewRequest(http.MethodPatch, u.String(), bytes.NewReader(b)) + require.NoErrorf(t, err, "creating req") + + httpCli := &http.Client{ + Timeout: testTimeout, + } + resp, err := httpCli.Do(req) + require.NoErrorf(t, err, "performing req") + require.Equal(t, wantCode, resp.StatusCode) + + testutil.CleanupAndRequireSuccess(t, resp.Body.Close) + + body, err = io.ReadAll(resp.Body) + require.NoErrorf(t, err, "reading body") + + return body +} + func TestService_Start_getHealthCheck(t *testing.T) { confMgr := newConfigManager() _, addr := newTestServer(t, confMgr) diff --git a/openapi/v1.yaml b/openapi/v1.yaml index 77eb1a09..adab6d4d 100644 --- a/openapi/v1.yaml +++ b/openapi/v1.yaml @@ -2289,7 +2289,7 @@ 'upstream_servers': - '1.1.1.1' - '8.8.8.8' - 'upstream_timeout': '1s' + 'upstream_timeout': 1000 'required': - 'addresses' - 'blocking_mode' @@ -2397,8 +2397,9 @@ 'type': 'array' 'upstream_timeout': 'description': > - Upstream request timeout, as a human readable duration. - 'type': 'string' + Upstream request timeout, in milliseconds. + 'format': 'double' + 'type': 'number' 'type': 'object' 'DnsType': @@ -3505,14 +3506,16 @@ 'addresses': - '127.0.0.1:80' - '192.168.1.1:80' + 'force_https': true 'secure_addresses': - '127.0.0.1:443' - '192.168.1.1:443' - 'force_https': true + 'timeout': 10000 'required': - 'addresses' - - 'secure_addresses' - 'force_https' + - 'secure_addresses' + - 'timeout' 'HttpSettingsPatch': 'description': > @@ -3539,6 +3542,11 @@ 'items': 'type': 'string' 'type': 'array' + 'timeout': + 'description': > + HTTP request timeout, in milliseconds. + 'format': 'double' + 'type': 'number' 'type': 'object' 'InternalServerErrorResp':