websvc: imp tests

This commit is contained in:
Ainar Garipov
2022-09-09 15:05:33 +03:00
parent dbfc8ae362
commit 1989c91c07
12 changed files with 201 additions and 67 deletions

View File

@@ -173,7 +173,11 @@ func (svc *Service) Start() (err error) {
// TODO(a.garipov): [proxy.Proxy.Start] doesn't actually have any way to
// tell when all servers are actually up, so at best this is merely an
// assumption.
atomic.StoreUint64(&svc.running, 1)
if err != nil {
atomic.StoreUint64(&svc.running, 0)
} else {
atomic.StoreUint64(&svc.running, 1)
}
}()
return svc.proxy.Start()

View File

@@ -5,9 +5,9 @@ import (
"fmt"
"net/http"
"net/netip"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
"github.com/AdguardTeam/golibs/timeutil"
)
// DNS Settings Handlers
@@ -19,10 +19,10 @@ import (
type ReqPatchSettingsDNS struct {
// TODO(a.garipov): Add more as we go.
Addresses []netip.AddrPort `json:"addresses"`
BootstrapServers []string `json:"bootstrap_servers"`
UpstreamServers []string `json:"upstream_servers"`
UpstreamTimeout timeutil.Duration `json:"upstream_timeout"`
Addresses []netip.AddrPort `json:"addresses"`
BootstrapServers []string `json:"bootstrap_servers"`
UpstreamServers []string `json:"upstream_servers"`
UpstreamTimeout JSONDuration `json:"upstream_timeout"`
}
// HTTPAPIDNSSettings are the DNS settings as used by the HTTP API. See the
@@ -30,10 +30,10 @@ type ReqPatchSettingsDNS struct {
type HTTPAPIDNSSettings struct {
// TODO(a.garipov): Add more as we go.
Addresses []netip.AddrPort `json:"addresses"`
BootstrapServers []string `json:"bootstrap_servers"`
UpstreamServers []string `json:"upstream_servers"`
UpstreamTimeout timeutil.Duration `json:"upstream_timeout"`
Addresses []netip.AddrPort `json:"addresses"`
BootstrapServers []string `json:"bootstrap_servers"`
UpstreamServers []string `json:"upstream_servers"`
UpstreamTimeout JSONDuration `json:"upstream_timeout"`
}
// handlePatchSettingsDNS is the handler for the PATCH /api/v1/settings/dns HTTP
@@ -58,7 +58,7 @@ func (svc *Service) handlePatchSettingsDNS(w http.ResponseWriter, r *http.Reques
Addresses: req.Addresses,
BootstrapServers: req.BootstrapServers,
UpstreamServers: req.UpstreamServers,
UpstreamTimeout: req.UpstreamTimeout.Duration,
UpstreamTimeout: time.Duration(req.UpstreamTimeout),
}
ctx := r.Context()
@@ -81,6 +81,6 @@ func (svc *Service) handlePatchSettingsDNS(w http.ResponseWriter, r *http.Reques
Addresses: newConf.Addresses,
BootstrapServers: newConf.BootstrapServers,
UpstreamServers: newConf.UpstreamServers,
UpstreamTimeout: timeutil.Duration{Duration: newConf.UpstreamTimeout},
UpstreamTimeout: JSONDuration(newConf.UpstreamTimeout),
})
}

View File

@@ -9,13 +9,10 @@ import (
"time"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/timeutil"
)
// HTTP Settings Handlers
// TODO(a.garipov): !! Write tests!
// ReqPatchSettingsHTTP describes the request to the PATCH /api/v1/settings/http
// HTTP API.
type ReqPatchSettingsHTTP struct {
@@ -23,9 +20,9 @@ type ReqPatchSettingsHTTP struct {
//
// TODO(a.garipov): Add wait time.
Addresses []netip.AddrPort `json:"addresses"`
SecureAddresses []netip.AddrPort `json:"secure_addresses"`
Timeout timeutil.Duration `json:"timeout"`
Addresses []netip.AddrPort `json:"addresses"`
SecureAddresses []netip.AddrPort `json:"secure_addresses"`
Timeout JSONDuration `json:"timeout"`
}
// HTTPAPIHTTPSettings are the HTTP settings as used by the HTTP API. See the
@@ -33,10 +30,10 @@ type ReqPatchSettingsHTTP struct {
type HTTPAPIHTTPSettings struct {
// TODO(a.garipov): Add more as we go.
Addresses []netip.AddrPort `json:"addresses"`
SecureAddresses []netip.AddrPort `json:"secure_addresses"`
Timeout timeutil.Duration `json:"timeout"`
ForceHTTPS bool `json:"force_https"`
Addresses []netip.AddrPort `json:"addresses"`
SecureAddresses []netip.AddrPort `json:"secure_addresses"`
Timeout JSONDuration `json:"timeout"`
ForceHTTPS bool `json:"force_https"`
}
// handlePatchSettingsHTTP is the handler for the PATCH /api/v1/settings/http
@@ -58,14 +55,14 @@ func (svc *Service) handlePatchSettingsHTTP(w http.ResponseWriter, r *http.Reque
TLS: svc.tls,
Addresses: req.Addresses,
SecureAddresses: req.SecureAddresses,
Timeout: req.Timeout.Duration,
Timeout: time.Duration(req.Timeout),
ForceHTTPS: svc.forceHTTPS,
}
writeJSONResponse(w, r, &HTTPAPIHTTPSettings{
Addresses: newConf.Addresses,
SecureAddresses: newConf.SecureAddresses,
Timeout: timeutil.Duration{Duration: newConf.Timeout},
Timeout: JSONDuration(newConf.Timeout),
ForceHTTPS: newConf.ForceHTTPS,
})

View File

@@ -0,0 +1,62 @@
package websvc_test
import (
"context"
"crypto/tls"
"encoding/json"
"net/http"
"net/netip"
"net/url"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/next/websvc"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestService_HandlePatchSettingsHTTP(t *testing.T) {
wantWeb := &websvc.HTTPAPIHTTPSettings{
Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.1.1:80")},
SecureAddresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.1.1:443")},
Timeout: websvc.JSONDuration(10 * time.Second),
ForceHTTPS: false,
}
confMgr := newConfigManager()
confMgr.onWeb = func() (c *websvc.Service) {
return websvc.New(&websvc.Config{
TLS: &tls.Config{
Certificates: []tls.Certificate{{}},
},
Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:80")},
SecureAddresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:443")},
Timeout: 5 * time.Second,
ForceHTTPS: true,
})
}
confMgr.onUpdateWeb = func(ctx context.Context, c *websvc.Config) (err error) {
return nil
}
_, addr := newTestServer(t, confMgr)
u := &url.URL{
Scheme: "http",
Host: addr.String(),
Path: websvc.PathV1SettingsHTTP,
}
req := jobj{
"addresses": wantWeb.Addresses,
"secure_addresses": wantWeb.SecureAddresses,
"timeout": wantWeb.Timeout,
"ForceHTTPS": wantWeb.ForceHTTPS,
}
respBody := httpPatch(t, u, req, http.StatusOK)
resp := &websvc.HTTPAPIHTTPSettings{}
err := json.Unmarshal(respBody, resp)
require.NoError(t, err)
assert.Equal(t, wantWeb, resp)
}

View File

@@ -13,19 +13,54 @@ import (
// JSON Utilities
// jsonTime is a time.Time that can be decoded from JSON and encoded into JSON
// according to our API conventions.
type jsonTime time.Time
// JSONDuration is a time.Duration that can be decoded from JSON and encoded
// into JSON according to our API conventions.
type JSONDuration time.Duration
// type check
var _ json.Marshaler = jsonTime{}
var _ json.Marshaler = JSONDuration(0)
// MarshalJSON implements the json.Marshaler interface for JSONDuration. err is
// always nil.
func (d JSONDuration) MarshalJSON() (b []byte, err error) {
msec := float64(time.Duration(d)) / nsecPerMsec
b = strconv.AppendFloat(nil, msec, 'f', -1, 64)
return b, nil
}
// type check
var _ json.Unmarshaler = (*JSONDuration)(nil)
// UnmarshalJSON implements the json.Marshaler interface for *JSONDuration.
func (d *JSONDuration) UnmarshalJSON(b []byte) (err error) {
if d == nil {
return fmt.Errorf("json duration is nil")
}
msec, err := strconv.ParseFloat(string(b), 64)
if err != nil {
return fmt.Errorf("parsing json time: %w", err)
}
*d = JSONDuration(int64(msec * nsecPerMsec))
return nil
}
// JSONTime is a time.Time that can be decoded from JSON and encoded into JSON
// according to our API conventions.
type JSONTime time.Time
// type check
var _ json.Marshaler = JSONTime{}
// nsecPerMsec is the number of nanoseconds in a millisecond.
const nsecPerMsec = float64(time.Millisecond / time.Nanosecond)
// MarshalJSON implements the json.Marshaler interface for jsonTime. err is
// MarshalJSON implements the json.Marshaler interface for JSONTime. err is
// always nil.
func (t jsonTime) MarshalJSON() (b []byte, err error) {
func (t JSONTime) MarshalJSON() (b []byte, err error) {
msec := float64(time.Time(t).UnixNano()) / nsecPerMsec
b = strconv.AppendFloat(nil, msec, 'f', -1, 64)
@@ -33,10 +68,10 @@ func (t jsonTime) MarshalJSON() (b []byte, err error) {
}
// type check
var _ json.Unmarshaler = (*jsonTime)(nil)
var _ json.Unmarshaler = (*JSONTime)(nil)
// UnmarshalJSON implements the json.Marshaler interface for *jsonTime.
func (t *jsonTime) UnmarshalJSON(b []byte) (err error) {
// UnmarshalJSON implements the json.Marshaler interface for *JSONTime.
func (t *JSONTime) UnmarshalJSON(b []byte) (err error) {
if t == nil {
return fmt.Errorf("json time is nil")
}
@@ -46,7 +81,7 @@ func (t *jsonTime) UnmarshalJSON(b []byte) (err error) {
return fmt.Errorf("parsing json time: %w", err)
}
*t = jsonTime(time.Unix(0, int64(msec*nsecPerMsec)).UTC())
*t = JSONTime(time.Unix(0, int64(msec*nsecPerMsec)).UTC())
return nil
}

View File

@@ -1,17 +1,18 @@
package websvc
package websvc_test
import (
"encoding/json"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/next/websvc"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// testJSONTime is the JSON time for tests.
var testJSONTime = jsonTime(time.Unix(1_234_567_890, 123_456_000).UTC())
var testJSONTime = websvc.JSONTime(time.Unix(1_234_567_890, 123_456_000).UTC())
// testJSONTimeStr is the string with the JSON encoding of testJSONTime.
const testJSONTimeStr = "1234567890123.456"
@@ -20,17 +21,17 @@ func TestJSONTime_MarshalJSON(t *testing.T) {
testCases := []struct {
name string
wantErrMsg string
in jsonTime
in websvc.JSONTime
want []byte
}{{
name: "unix_zero",
wantErrMsg: "",
in: jsonTime(time.Unix(0, 0)),
in: websvc.JSONTime(time.Unix(0, 0)),
want: []byte("0"),
}, {
name: "empty",
wantErrMsg: "",
in: jsonTime{},
in: websvc.JSONTime{},
want: []byte("-6795364578871.345"),
}, {
name: "time",
@@ -50,7 +51,7 @@ func TestJSONTime_MarshalJSON(t *testing.T) {
t.Run("json", func(t *testing.T) {
in := &struct {
A jsonTime
A websvc.JSONTime
}{
A: testJSONTime,
}
@@ -66,7 +67,7 @@ func TestJSONTime_UnmarshalJSON(t *testing.T) {
testCases := []struct {
name string
wantErrMsg string
want jsonTime
want websvc.JSONTime
data []byte
}{{
name: "time",
@@ -77,13 +78,13 @@ func TestJSONTime_UnmarshalJSON(t *testing.T) {
name: "bad",
wantErrMsg: `parsing json time: strconv.ParseFloat: parsing "{}": ` +
`invalid syntax`,
want: jsonTime{},
want: websvc.JSONTime{},
data: []byte(`{}`),
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var got jsonTime
var got websvc.JSONTime
err := got.UnmarshalJSON(tc.data)
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
@@ -92,7 +93,7 @@ func TestJSONTime_UnmarshalJSON(t *testing.T) {
}
t.Run("nil", func(t *testing.T) {
err := (*jsonTime)(nil).UnmarshalJSON([]byte("0"))
err := (*websvc.JSONTime)(nil).UnmarshalJSON([]byte("0"))
require.Error(t, err)
msg := err.Error()
@@ -102,7 +103,7 @@ func TestJSONTime_UnmarshalJSON(t *testing.T) {
t.Run("json", func(t *testing.T) {
want := testJSONTime
var got struct {
A jsonTime
A websvc.JSONTime
}
err := json.Unmarshal([]byte(`{"A":`+testJSONTimeStr+`}`), &got)

View File

@@ -2,8 +2,6 @@ package websvc
import (
"net/http"
"github.com/AdguardTeam/golibs/timeutil"
)
// All Settings Handlers
@@ -32,12 +30,12 @@ func (svc *Service) handleGetSettingsAll(w http.ResponseWriter, r *http.Request)
Addresses: dnsConf.Addresses,
BootstrapServers: dnsConf.BootstrapServers,
UpstreamServers: dnsConf.UpstreamServers,
UpstreamTimeout: timeutil.Duration{Duration: dnsConf.UpstreamTimeout},
UpstreamTimeout: JSONDuration(dnsConf.UpstreamTimeout),
},
HTTP: &HTTPAPIHTTPSettings{
Addresses: httpConf.Addresses,
SecureAddresses: httpConf.SecureAddresses,
Timeout: timeutil.Duration{Duration: httpConf.Timeout},
Timeout: JSONDuration(httpConf.Timeout),
ForceHTTPS: httpConf.ForceHTTPS,
},
})

View File

@@ -11,7 +11,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
"github.com/AdguardTeam/AdGuardHome/internal/next/websvc"
"github.com/AdguardTeam/golibs/timeutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -23,13 +22,13 @@ func TestService_HandleGetSettingsAll(t *testing.T) {
Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:53")},
BootstrapServers: []string{"94.140.14.140", "94.140.14.141"},
UpstreamServers: []string{"94.140.14.14", "1.1.1.1"},
UpstreamTimeout: timeutil.Duration{Duration: 1 * time.Second},
UpstreamTimeout: websvc.JSONDuration(1 * time.Second),
}
wantWeb := &websvc.HTTPAPIHTTPSettings{
Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:80")},
SecureAddresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:443")},
Timeout: timeutil.Duration{Duration: 5 * time.Second},
Timeout: websvc.JSONDuration(5 * time.Second),
ForceHTTPS: true,
}
@@ -39,7 +38,7 @@ func TestService_HandleGetSettingsAll(t *testing.T) {
Addresses: wantDNS.Addresses,
UpstreamServers: wantDNS.UpstreamServers,
BootstrapServers: wantDNS.BootstrapServers,
UpstreamTimeout: wantDNS.UpstreamTimeout.Duration,
UpstreamTimeout: time.Duration(wantDNS.UpstreamTimeout),
})
require.NoError(t, err)
@@ -53,7 +52,7 @@ func TestService_HandleGetSettingsAll(t *testing.T) {
},
Addresses: wantWeb.Addresses,
SecureAddresses: wantWeb.SecureAddresses,
Timeout: wantWeb.Timeout.Duration,
Timeout: time.Duration(wantWeb.Timeout),
ForceHTTPS: true,
})
}

View File

@@ -16,7 +16,7 @@ type RespGetV1SystemInfo struct {
Channel string `json:"channel"`
OS string `json:"os"`
NewVersion string `json:"new_version,omitempty"`
Start jsonTime `json:"start"`
Start JSONTime `json:"start"`
Version string `json:"version"`
}
@@ -29,7 +29,7 @@ func (svc *Service) handleGetV1SystemInfo(w http.ResponseWriter, r *http.Request
OS: runtime.GOOS,
// TODO(a.garipov): Fill this when we have an updater.
NewVersion: "",
Start: jsonTime(svc.start),
Start: JSONTime(svc.start),
Version: version.Version(),
})
}

View File

@@ -20,12 +20,8 @@ func TestWaitListener_Accept(t *testing.T) {
return nil, nil
},
OnAddr: func() (addr net.Addr) {
panic("not implemented")
},
OnClose: func() (err error) {
panic("not implemented")
},
OnAddr: func() (addr net.Addr) { panic("not implemented") },
OnClose: func() (err error) { panic("not implemented") },
}
wg := &sync.WaitGroup{}

View File

@@ -1,7 +1,9 @@
package websvc_test
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/netip"
@@ -113,6 +115,9 @@ func newTestServer(
return svc, c.Addresses[0]
}
// jobj is a utility alias for JSON objects.
type jobj map[string]any
// httpGet is a helper that performs an HTTP GET request and returns the body of
// the response as well as checks that the status code is correct.
//
@@ -138,6 +143,35 @@ func httpGet(t testing.TB, u *url.URL, wantCode int) (body []byte) {
return body
}
// httpPatch is a helper that performs an HTTP PATCH request with JSON-encoded
// reqBody as the request body and returns the body of the response as well as
// checks that the status code is correct.
//
// TODO(a.garipov): Add helpers for other methods.
func httpPatch(t testing.TB, u *url.URL, reqBody any, wantCode int) (body []byte) {
t.Helper()
b, err := json.Marshal(reqBody)
require.NoErrorf(t, err, "marshaling reqBody")
req, err := http.NewRequest(http.MethodPatch, u.String(), bytes.NewReader(b))
require.NoErrorf(t, err, "creating req")
httpCli := &http.Client{
Timeout: testTimeout,
}
resp, err := httpCli.Do(req)
require.NoErrorf(t, err, "performing req")
require.Equal(t, wantCode, resp.StatusCode)
testutil.CleanupAndRequireSuccess(t, resp.Body.Close)
body, err = io.ReadAll(resp.Body)
require.NoErrorf(t, err, "reading body")
return body
}
func TestService_Start_getHealthCheck(t *testing.T) {
confMgr := newConfigManager()
_, addr := newTestServer(t, confMgr)

View File

@@ -2289,7 +2289,7 @@
'upstream_servers':
- '1.1.1.1'
- '8.8.8.8'
'upstream_timeout': '1s'
'upstream_timeout': 1000
'required':
- 'addresses'
- 'blocking_mode'
@@ -2397,8 +2397,9 @@
'type': 'array'
'upstream_timeout':
'description': >
Upstream request timeout, as a human readable duration.
'type': 'string'
Upstream request timeout, in milliseconds.
'format': 'double'
'type': 'number'
'type': 'object'
'DnsType':
@@ -3505,14 +3506,16 @@
'addresses':
- '127.0.0.1:80'
- '192.168.1.1:80'
'force_https': true
'secure_addresses':
- '127.0.0.1:443'
- '192.168.1.1:443'
'force_https': true
'timeout': 10000
'required':
- 'addresses'
- 'secure_addresses'
- 'force_https'
- 'secure_addresses'
- 'timeout'
'HttpSettingsPatch':
'description': >
@@ -3539,6 +3542,11 @@
'items':
'type': 'string'
'type': 'array'
'timeout':
'description': >
HTTP request timeout, in milliseconds.
'format': 'double'
'type': 'number'
'type': 'object'
'InternalServerErrorResp':