all: mv v1 to next; imp tests, docs

This commit is contained in:
Ainar Garipov
2022-09-09 14:25:48 +03:00
parent d74ba3cb9d
commit dbfc8ae362
26 changed files with 122 additions and 87 deletions

33
internal/next/agh/agh.go Normal file
View File

@@ -0,0 +1,33 @@
// Package agh contains common entities and interfaces of AdGuard Home.
//
// TODO(a.garipov): Move to the upper-level internal/.
package agh
import "context"
// Service is the interface for API servers.
//
// TODO(a.garipov): Consider adding a context to Start.
//
// TODO(a.garipov): Consider adding a Wait method or making an extension
// interface for that.
type Service interface {
// Start starts the service. It does not block.
Start() (err error)
// Shutdown gracefully stops the service. ctx is used to determine
// a timeout before trying to stop the service less gracefully.
Shutdown(ctx context.Context) (err error)
}
// type check
var _ Service = EmptyService{}
// EmptyService is a Service that does nothing.
type EmptyService struct{}
// Start implements the Service interface for EmptyService.
func (EmptyService) Start() (err error) { return nil }
// Shutdown implements the Service interface for EmptyService.
func (EmptyService) Shutdown(_ context.Context) (err error) { return nil }

70
internal/next/cmd/cmd.go Normal file
View File

@@ -0,0 +1,70 @@
// Package cmd is the AdGuard Home entry point. It contains the on-disk
// configuration file utilities, signal processing logic, and so on.
//
// TODO(a.garipov): Move to the upper-level internal/.
package cmd
import (
"context"
"io/fs"
"math/rand"
"net/netip"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/next/websvc"
"github.com/AdguardTeam/golibs/log"
)
// Main is the entry point of application.
func Main(clientBuildFS fs.FS) {
// Initial Configuration
start := time.Now()
rand.Seed(start.UnixNano())
// TODO(a.garipov): Set up logging.
// Web Service
// TODO(a.garipov): Use in the Web service.
_ = clientBuildFS
// TODO(a.garipov): Make configurable.
web := websvc.New(&websvc.Config{
// TODO(a.garipov): Use an actual implementation.
ConfigManager: nil,
Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:3001")},
Start: start,
Timeout: 60 * time.Second,
ForceHTTPS: false,
})
err := web.Start()
fatalOnError(err)
sigHdlr := newSignalHandler(
web,
)
go sigHdlr.handle()
select {}
}
// defaultTimeout is the timeout used for some operations where another timeout
// hasn't been defined yet.
const defaultTimeout = 15 * time.Second
// ctxWithDefaultTimeout is a helper function that returns a context with
// timeout set to defaultTimeout.
func ctxWithDefaultTimeout() (ctx context.Context, cancel context.CancelFunc) {
return context.WithTimeout(context.Background(), defaultTimeout)
}
// fatalOnError is a helper that exits the program with an error code if err is
// not nil. It must only be used within Main.
func fatalOnError(err error) {
if err != nil {
log.Fatal(err)
}
}

View File

@@ -0,0 +1,70 @@
package cmd
import (
"os"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
"github.com/AdguardTeam/golibs/log"
)
// signalHandler processes incoming signals and shuts services down.
type signalHandler struct {
signal chan os.Signal
// services are the services that are shut down before application
// exiting.
services []agh.Service
}
// handle processes OS signals.
func (h *signalHandler) handle() {
defer log.OnPanic("signalHandler.handle")
for sig := range h.signal {
log.Info("sighdlr: received signal %q", sig)
if aghos.IsShutdownSignal(sig) {
h.shutdown()
}
}
}
// Exit status constants.
const (
statusSuccess = 0
statusError = 1
)
// shutdown gracefully shuts down all services.
func (h *signalHandler) shutdown() {
ctx, cancel := ctxWithDefaultTimeout()
defer cancel()
status := statusSuccess
log.Info("sighdlr: shutting down services")
for i, service := range h.services {
err := service.Shutdown(ctx)
if err != nil {
log.Error("sighdlr: shutting down service at index %d: %s", i, err)
status = statusError
}
}
log.Info("sighdlr: shutting down adguard home")
os.Exit(status)
}
// newSignalHandler returns a new signalHandler that shuts down svcs.
func newSignalHandler(svcs ...agh.Service) (h *signalHandler) {
h = &signalHandler{
signal: make(chan os.Signal, 1),
services: svcs,
}
aghos.NotifyShutdownSignal(h.signal)
return h
}

View File

@@ -0,0 +1,223 @@
// Package dnssvc contains the AdGuard Home DNS service.
//
// TODO(a.garipov): Define, if all methods of a *Service should work with a nil
// receiver.
package dnssvc
import (
"context"
"fmt"
"net"
"net/netip"
"sync/atomic"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
// TODO(a.garipov): Add a “dnsproxy proxy” package to shield us from changes
// and replacement of module dnsproxy.
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
)
// Config is the AdGuard Home DNS service configuration structure.
//
// TODO(a.garipov): Add timeout for incoming requests.
type Config struct {
// Addresses are the addresses on which to serve plain DNS queries.
Addresses []netip.AddrPort
// Upstreams are the DNS upstreams to use. If not set, upstreams are
// created using data from BootstrapServers, UpstreamServers, and
// UpstreamTimeout.
//
// TODO(a.garipov): Think of a better scheme. Those other three parameters
// are here only to make Config work properly.
Upstreams []upstream.Upstream
// BootstrapServers are the addresses for bootstrapping the upstream DNS
// server addresses.
BootstrapServers []string
// UpstreamServers are the upstream DNS server addresses to use.
UpstreamServers []string
// UpstreamTimeout is the timeout for upstream requests.
UpstreamTimeout time.Duration
}
// Service is the AdGuard Home DNS service. A nil *Service is a valid
// [agh.Service] that does nothing.
type Service struct {
// running is an atomic boolean value. Keep it the first value in the
// struct to ensure atomic alignment. 0 means that the service is not
// running, 1 means that it is running.
//
// TODO(a.garipov): Use [atomic.Bool] in Go 1.19 or get rid of it
// completely.
running uint64
proxy *proxy.Proxy
bootstraps []string
upstreams []string
upsTimeout time.Duration
}
// New returns a new properly initialized *Service. If c is nil, svc is a nil
// *Service that does nothing. The fields of c must not be modified after
// calling New.
func New(c *Config) (svc *Service, err error) {
if c == nil {
return nil, nil
}
svc = &Service{
bootstraps: c.BootstrapServers,
upstreams: c.UpstreamServers,
upsTimeout: c.UpstreamTimeout,
}
var upstreams []upstream.Upstream
if len(c.Upstreams) > 0 {
upstreams = c.Upstreams
} else {
upstreams, err = addressesToUpstreams(
c.UpstreamServers,
c.BootstrapServers,
c.UpstreamTimeout,
)
if err != nil {
return nil, fmt.Errorf("converting upstreams: %w", err)
}
}
svc.proxy = &proxy.Proxy{
Config: proxy.Config{
UDPListenAddr: udpAddrs(c.Addresses),
TCPListenAddr: tcpAddrs(c.Addresses),
UpstreamConfig: &proxy.UpstreamConfig{
Upstreams: upstreams,
},
},
}
err = svc.proxy.Init()
if err != nil {
return nil, fmt.Errorf("proxy: %w", err)
}
return svc, nil
}
// addressesToUpstreams is a wrapper around [upstream.AddressToUpstream]. It
// accepts a slice of addresses and other upstream parameters, and returns a
// slice of upstreams.
func addressesToUpstreams(
upsStrs []string,
bootstraps []string,
timeout time.Duration,
) (upstreams []upstream.Upstream, err error) {
upstreams = make([]upstream.Upstream, len(upsStrs))
for i, upsStr := range upsStrs {
upstreams[i], err = upstream.AddressToUpstream(upsStr, &upstream.Options{
Bootstrap: bootstraps,
Timeout: timeout,
})
if err != nil {
return nil, fmt.Errorf("upstream at index %d: %w", i, err)
}
}
return upstreams, nil
}
// tcpAddrs converts []netip.AddrPort into []*net.TCPAddr.
func tcpAddrs(addrPorts []netip.AddrPort) (tcpAddrs []*net.TCPAddr) {
if addrPorts == nil {
return nil
}
tcpAddrs = make([]*net.TCPAddr, len(addrPorts))
for i, a := range addrPorts {
tcpAddrs[i] = net.TCPAddrFromAddrPort(a)
}
return tcpAddrs
}
// udpAddrs converts []netip.AddrPort into []*net.UDPAddr.
func udpAddrs(addrPorts []netip.AddrPort) (udpAddrs []*net.UDPAddr) {
if addrPorts == nil {
return nil
}
udpAddrs = make([]*net.UDPAddr, len(addrPorts))
for i, a := range addrPorts {
udpAddrs[i] = net.UDPAddrFromAddrPort(a)
}
return udpAddrs
}
// type check
var _ agh.Service = (*Service)(nil)
// Start implements the [agh.Service] interface for *Service. svc may be nil.
// After Start exits, all DNS servers have tried to start, but there is no
// guarantee that they did. Errors from the servers are written to the log.
func (svc *Service) Start() (err error) {
if svc == nil {
return nil
}
defer func() {
// TODO(a.garipov): [proxy.Proxy.Start] doesn't actually have any way to
// tell when all servers are actually up, so at best this is merely an
// assumption.
atomic.StoreUint64(&svc.running, 1)
}()
return svc.proxy.Start()
}
// Shutdown implements the [agh.Service] interface for *Service. svc may be
// nil.
func (svc *Service) Shutdown(ctx context.Context) (err error) {
if svc == nil {
return nil
}
return svc.proxy.Stop()
}
// Config returns the current configuration of the web service. Config must not
// be called simultaneously with Start. If svc was initialized with ":0"
// addresses, addrs will not return the actual bound ports until Start is
// finished.
func (svc *Service) Config() (c *Config) {
// TODO(a.garipov): Do we need to get the TCP addresses separately?
var addrs []netip.AddrPort
if atomic.LoadUint64(&svc.running) == 1 {
udpAddrs := svc.proxy.Addrs(proxy.ProtoUDP)
addrs = make([]netip.AddrPort, len(udpAddrs))
for i, a := range udpAddrs {
addrs[i] = a.(*net.UDPAddr).AddrPort()
}
} else {
conf := svc.proxy.Config
udpAddrs := conf.UDPListenAddr
addrs = make([]netip.AddrPort, len(udpAddrs))
for i, a := range udpAddrs {
addrs[i] = a.AddrPort()
}
}
c = &Config{
Addresses: addrs,
BootstrapServers: svc.bootstraps,
UpstreamServers: svc.upstreams,
UpstreamTimeout: svc.upsTimeout,
}
return c
}

View File

@@ -0,0 +1,89 @@
package dnssvc_test
import (
"context"
"net/netip"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestMain(m *testing.M) {
aghtest.DiscardLogOutput(m)
}
// testTimeout is the common timeout for tests.
const testTimeout = 100 * time.Millisecond
func TestService(t *testing.T) {
const (
bootstrapAddr = "bootstrap.example"
upstreamAddr = "upstream.example"
)
ups := &aghtest.UpstreamMock{
OnAddress: func() (addr string) {
return upstreamAddr
},
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
resp = (&dns.Msg{}).SetReply(req)
return resp, nil
},
}
c := &dnssvc.Config{
Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:0")},
Upstreams: []upstream.Upstream{ups},
BootstrapServers: []string{bootstrapAddr},
UpstreamServers: []string{upstreamAddr},
UpstreamTimeout: testTimeout,
}
svc, err := dnssvc.New(c)
require.NoError(t, err)
err = svc.Start()
require.NoError(t, err)
gotConf := svc.Config()
require.NotNil(t, gotConf)
require.Len(t, gotConf.Addresses, 1)
addr := gotConf.Addresses[0]
t.Run("dns", func(t *testing.T) {
req := &dns.Msg{
MsgHdr: dns.MsgHdr{
Id: dns.Id(),
RecursionDesired: true,
},
Question: []dns.Question{{
Name: "example.com.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}},
}
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
cli := &dns.Client{}
resp, _, excErr := cli.ExchangeContext(ctx, req, addr.String())
require.NoError(t, excErr)
assert.NotNil(t, resp)
})
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
err = svc.Shutdown(ctx)
require.NoError(t, err)
}

View File

@@ -0,0 +1,86 @@
package websvc
import (
"encoding/json"
"fmt"
"net/http"
"net/netip"
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
"github.com/AdguardTeam/golibs/timeutil"
)
// DNS Settings Handlers
// TODO(a.garipov): !! Write tests!
// ReqPatchSettingsDNS describes the request to the PATCH /api/v1/settings/dns
// HTTP API.
type ReqPatchSettingsDNS struct {
// TODO(a.garipov): Add more as we go.
Addresses []netip.AddrPort `json:"addresses"`
BootstrapServers []string `json:"bootstrap_servers"`
UpstreamServers []string `json:"upstream_servers"`
UpstreamTimeout timeutil.Duration `json:"upstream_timeout"`
}
// HTTPAPIDNSSettings are the DNS settings as used by the HTTP API. See the
// DnsSettings object in the OpenAPI specification.
type HTTPAPIDNSSettings struct {
// TODO(a.garipov): Add more as we go.
Addresses []netip.AddrPort `json:"addresses"`
BootstrapServers []string `json:"bootstrap_servers"`
UpstreamServers []string `json:"upstream_servers"`
UpstreamTimeout timeutil.Duration `json:"upstream_timeout"`
}
// handlePatchSettingsDNS is the handler for the PATCH /api/v1/settings/dns HTTP
// API.
func (svc *Service) handlePatchSettingsDNS(w http.ResponseWriter, r *http.Request) {
req := &ReqPatchSettingsDNS{
Addresses: []netip.AddrPort{},
BootstrapServers: []string{},
UpstreamServers: []string{},
}
// TODO(a.garipov): Validate nulls and proper JSON patch.
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
writeHTTPError(w, r, fmt.Errorf("decoding: %w", err))
return
}
newConf := &dnssvc.Config{
Addresses: req.Addresses,
BootstrapServers: req.BootstrapServers,
UpstreamServers: req.UpstreamServers,
UpstreamTimeout: req.UpstreamTimeout.Duration,
}
ctx := r.Context()
err = svc.confMgr.UpdateDNS(ctx, newConf)
if err != nil {
writeHTTPError(w, r, fmt.Errorf("updating: %w", err))
return
}
newSvc := svc.confMgr.DNS()
err = newSvc.Start()
if err != nil {
writeHTTPError(w, r, fmt.Errorf("starting new service: %w", err))
return
}
writeJSONResponse(w, r, &HTTPAPIDNSSettings{
Addresses: newConf.Addresses,
BootstrapServers: newConf.BootstrapServers,
UpstreamServers: newConf.UpstreamServers,
UpstreamTimeout: timeutil.Duration{Duration: newConf.UpstreamTimeout},
})
}

View File

@@ -0,0 +1,104 @@
package websvc
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/netip"
"time"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/timeutil"
)
// HTTP Settings Handlers
// TODO(a.garipov): !! Write tests!
// ReqPatchSettingsHTTP describes the request to the PATCH /api/v1/settings/http
// HTTP API.
type ReqPatchSettingsHTTP struct {
// TODO(a.garipov): Add more as we go.
//
// TODO(a.garipov): Add wait time.
Addresses []netip.AddrPort `json:"addresses"`
SecureAddresses []netip.AddrPort `json:"secure_addresses"`
Timeout timeutil.Duration `json:"timeout"`
}
// HTTPAPIHTTPSettings are the HTTP settings as used by the HTTP API. See the
// HttpSettings object in the OpenAPI specification.
type HTTPAPIHTTPSettings struct {
// TODO(a.garipov): Add more as we go.
Addresses []netip.AddrPort `json:"addresses"`
SecureAddresses []netip.AddrPort `json:"secure_addresses"`
Timeout timeutil.Duration `json:"timeout"`
ForceHTTPS bool `json:"force_https"`
}
// handlePatchSettingsHTTP is the handler for the PATCH /api/v1/settings/http
// HTTP API.
func (svc *Service) handlePatchSettingsHTTP(w http.ResponseWriter, r *http.Request) {
req := &ReqPatchSettingsHTTP{}
// TODO(a.garipov): Validate nulls and proper JSON patch.
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
writeHTTPError(w, r, fmt.Errorf("decoding: %w", err))
return
}
newConf := &Config{
ConfigManager: svc.confMgr,
TLS: svc.tls,
Addresses: req.Addresses,
SecureAddresses: req.SecureAddresses,
Timeout: req.Timeout.Duration,
ForceHTTPS: svc.forceHTTPS,
}
writeJSONResponse(w, r, &HTTPAPIHTTPSettings{
Addresses: newConf.Addresses,
SecureAddresses: newConf.SecureAddresses,
Timeout: timeutil.Duration{Duration: newConf.Timeout},
ForceHTTPS: newConf.ForceHTTPS,
})
cancelUpd := func() {}
updCtx := context.Background()
ctx := r.Context()
if deadline, ok := ctx.Deadline(); ok {
updCtx, cancelUpd = context.WithDeadline(updCtx, deadline)
}
// Launch the new HTTP service in a separate goroutine to let this handler
// finish and thus, this server to shutdown.
go func() {
defer cancelUpd()
updErr := svc.confMgr.UpdateWeb(updCtx, newConf)
if updErr != nil {
writeHTTPError(w, r, fmt.Errorf("updating: %w", updErr))
return
}
// TODO(a.garipov): !! Add some kind of timeout? Context?
var newSvc *Service
for newSvc = svc.confMgr.Web(); newSvc == svc; {
log.Debug("websvc: waiting for new websvc to be configured")
time.Sleep(1 * time.Second)
}
updErr = newSvc.Start()
if updErr != nil {
log.Error("websvc: new svc failed to start: %s", updErr)
}
}()
}

View File

@@ -0,0 +1,74 @@
package websvc
import (
"encoding/json"
"fmt"
"io"
"net/http"
"strconv"
"time"
"github.com/AdguardTeam/golibs/log"
)
// JSON Utilities
// jsonTime is a time.Time that can be decoded from JSON and encoded into JSON
// according to our API conventions.
type jsonTime time.Time
// type check
var _ json.Marshaler = jsonTime{}
// nsecPerMsec is the number of nanoseconds in a millisecond.
const nsecPerMsec = float64(time.Millisecond / time.Nanosecond)
// MarshalJSON implements the json.Marshaler interface for jsonTime. err is
// always nil.
func (t jsonTime) MarshalJSON() (b []byte, err error) {
msec := float64(time.Time(t).UnixNano()) / nsecPerMsec
b = strconv.AppendFloat(nil, msec, 'f', -1, 64)
return b, nil
}
// type check
var _ json.Unmarshaler = (*jsonTime)(nil)
// UnmarshalJSON implements the json.Marshaler interface for *jsonTime.
func (t *jsonTime) UnmarshalJSON(b []byte) (err error) {
if t == nil {
return fmt.Errorf("json time is nil")
}
msec, err := strconv.ParseFloat(string(b), 64)
if err != nil {
return fmt.Errorf("parsing json time: %w", err)
}
*t = jsonTime(time.Unix(0, int64(msec*nsecPerMsec)).UTC())
return nil
}
// writeJSONResponse encodes v into w and logs any errors it encounters. r is
// used to get additional information from the request.
func writeJSONResponse(w io.Writer, r *http.Request, v any) {
err := json.NewEncoder(w).Encode(v)
if err != nil {
log.Error("websvc: writing resp to %s %s: %s", r.Method, r.URL.Path, err)
}
}
// writeHTTPError is a helper for logging and writing HTTP errors.
//
// TODO(a.garipov): Improve codes, and add JSON error codes.
func writeHTTPError(w http.ResponseWriter, r *http.Request, err error) {
log.Error("websvc: %s %s: %s", r.Method, r.URL.Path, err)
w.WriteHeader(http.StatusUnprocessableEntity)
_, werr := io.WriteString(w, err.Error())
if werr != nil {
log.Debug("websvc: writing error resp to %s %s: %s", r.Method, r.URL.Path, werr)
}
}

View File

@@ -0,0 +1,113 @@
package websvc
import (
"encoding/json"
"testing"
"time"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// testJSONTime is the JSON time for tests.
var testJSONTime = jsonTime(time.Unix(1_234_567_890, 123_456_000).UTC())
// testJSONTimeStr is the string with the JSON encoding of testJSONTime.
const testJSONTimeStr = "1234567890123.456"
func TestJSONTime_MarshalJSON(t *testing.T) {
testCases := []struct {
name string
wantErrMsg string
in jsonTime
want []byte
}{{
name: "unix_zero",
wantErrMsg: "",
in: jsonTime(time.Unix(0, 0)),
want: []byte("0"),
}, {
name: "empty",
wantErrMsg: "",
in: jsonTime{},
want: []byte("-6795364578871.345"),
}, {
name: "time",
wantErrMsg: "",
in: testJSONTime,
want: []byte(testJSONTimeStr),
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
got, err := tc.in.MarshalJSON()
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
assert.Equal(t, tc.want, got)
})
}
t.Run("json", func(t *testing.T) {
in := &struct {
A jsonTime
}{
A: testJSONTime,
}
got, err := json.Marshal(in)
require.NoError(t, err)
assert.Equal(t, []byte(`{"A":`+testJSONTimeStr+`}`), got)
})
}
func TestJSONTime_UnmarshalJSON(t *testing.T) {
testCases := []struct {
name string
wantErrMsg string
want jsonTime
data []byte
}{{
name: "time",
wantErrMsg: "",
want: testJSONTime,
data: []byte(testJSONTimeStr),
}, {
name: "bad",
wantErrMsg: `parsing json time: strconv.ParseFloat: parsing "{}": ` +
`invalid syntax`,
want: jsonTime{},
data: []byte(`{}`),
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var got jsonTime
err := got.UnmarshalJSON(tc.data)
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
assert.Equal(t, tc.want, got)
})
}
t.Run("nil", func(t *testing.T) {
err := (*jsonTime)(nil).UnmarshalJSON([]byte("0"))
require.Error(t, err)
msg := err.Error()
assert.Equal(t, "json time is nil", msg)
})
t.Run("json", func(t *testing.T) {
want := testJSONTime
var got struct {
A jsonTime
}
err := json.Unmarshal([]byte(`{"A":`+testJSONTimeStr+`}`), &got)
require.NoError(t, err)
assert.Equal(t, want, got.A)
})
}

View File

@@ -0,0 +1,16 @@
package websvc
import "net/http"
// Middlewares
// jsonMw sets the content type of the response to application/json.
func jsonMw(h http.Handler) (wrapped http.HandlerFunc) {
f := func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
h.ServeHTTP(w, r)
}
return http.HandlerFunc(f)
}

View File

@@ -0,0 +1,11 @@
package websvc
// Path constants
const (
PathHealthCheck = "/health-check"
PathV1SettingsAll = "/api/v1/settings/all"
PathV1SettingsDNS = "/api/v1/settings/dns"
PathV1SettingsHTTP = "/api/v1/settings/http"
PathV1SystemInfo = "/api/v1/system/info"
)

View File

@@ -0,0 +1,44 @@
package websvc
import (
"net/http"
"github.com/AdguardTeam/golibs/timeutil"
)
// All Settings Handlers
// RespGetV1SettingsAll describes the response of the GET /api/v1/settings/all
// HTTP API.
type RespGetV1SettingsAll struct {
// TODO(a.garipov): Add more as we go.
DNS *HTTPAPIDNSSettings `json:"dns"`
HTTP *HTTPAPIHTTPSettings `json:"http"`
}
// handleGetSettingsAll is the handler for the GET /api/v1/settings/all HTTP
// API.
func (svc *Service) handleGetSettingsAll(w http.ResponseWriter, r *http.Request) {
dnsSvc := svc.confMgr.DNS()
dnsConf := dnsSvc.Config()
webSvc := svc.confMgr.Web()
httpConf := webSvc.Config()
// TODO(a.garipov): Add all currently supported parameters.
writeJSONResponse(w, r, &RespGetV1SettingsAll{
DNS: &HTTPAPIDNSSettings{
Addresses: dnsConf.Addresses,
BootstrapServers: dnsConf.BootstrapServers,
UpstreamServers: dnsConf.UpstreamServers,
UpstreamTimeout: timeutil.Duration{Duration: dnsConf.UpstreamTimeout},
},
HTTP: &HTTPAPIHTTPSettings{
Addresses: httpConf.Addresses,
SecureAddresses: httpConf.SecureAddresses,
Timeout: timeutil.Duration{Duration: httpConf.Timeout},
ForceHTTPS: httpConf.ForceHTTPS,
},
})
}

View File

@@ -0,0 +1,75 @@
package websvc_test
import (
"crypto/tls"
"encoding/json"
"net/http"
"net/netip"
"net/url"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
"github.com/AdguardTeam/AdGuardHome/internal/next/websvc"
"github.com/AdguardTeam/golibs/timeutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestService_HandleGetSettingsAll(t *testing.T) {
// TODO(a.garipov): Add all currently supported parameters.
wantDNS := &websvc.HTTPAPIDNSSettings{
Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:53")},
BootstrapServers: []string{"94.140.14.140", "94.140.14.141"},
UpstreamServers: []string{"94.140.14.14", "1.1.1.1"},
UpstreamTimeout: timeutil.Duration{Duration: 1 * time.Second},
}
wantWeb := &websvc.HTTPAPIHTTPSettings{
Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:80")},
SecureAddresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:443")},
Timeout: timeutil.Duration{Duration: 5 * time.Second},
ForceHTTPS: true,
}
confMgr := newConfigManager()
confMgr.onDNS = func() (c *dnssvc.Service) {
c, err := dnssvc.New(&dnssvc.Config{
Addresses: wantDNS.Addresses,
UpstreamServers: wantDNS.UpstreamServers,
BootstrapServers: wantDNS.BootstrapServers,
UpstreamTimeout: wantDNS.UpstreamTimeout.Duration,
})
require.NoError(t, err)
return c
}
confMgr.onWeb = func() (c *websvc.Service) {
return websvc.New(&websvc.Config{
TLS: &tls.Config{
Certificates: []tls.Certificate{{}},
},
Addresses: wantWeb.Addresses,
SecureAddresses: wantWeb.SecureAddresses,
Timeout: wantWeb.Timeout.Duration,
ForceHTTPS: true,
})
}
_, addr := newTestServer(t, confMgr)
u := &url.URL{
Scheme: "http",
Host: addr.String(),
Path: websvc.PathV1SettingsAll,
}
body := httpGet(t, u, http.StatusOK)
resp := &websvc.RespGetV1SettingsAll{}
err := json.Unmarshal(body, resp)
require.NoError(t, err)
assert.Equal(t, wantDNS, resp.DNS)
assert.Equal(t, wantWeb, resp.HTTP)
}

View File

@@ -0,0 +1,35 @@
package websvc
import (
"net/http"
"runtime"
"github.com/AdguardTeam/AdGuardHome/internal/version"
)
// System Handlers
// RespGetV1SystemInfo describes the response of the GET /api/v1/system/info
// HTTP API.
type RespGetV1SystemInfo struct {
Arch string `json:"arch"`
Channel string `json:"channel"`
OS string `json:"os"`
NewVersion string `json:"new_version,omitempty"`
Start jsonTime `json:"start"`
Version string `json:"version"`
}
// handleGetV1SystemInfo is the handler for the GET /api/v1/system/info HTTP
// API.
func (svc *Service) handleGetV1SystemInfo(w http.ResponseWriter, r *http.Request) {
writeJSONResponse(w, r, &RespGetV1SystemInfo{
Arch: runtime.GOARCH,
Channel: version.Channel(),
OS: runtime.GOOS,
// TODO(a.garipov): Fill this when we have an updater.
NewVersion: "",
Start: jsonTime(svc.start),
Version: version.Version(),
})
}

View File

@@ -0,0 +1,37 @@
package websvc_test
import (
"encoding/json"
"net/http"
"net/url"
"runtime"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/next/websvc"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestService_handleGetV1SystemInfo(t *testing.T) {
confMgr := newConfigManager()
_, addr := newTestServer(t, confMgr)
u := &url.URL{
Scheme: "http",
Host: addr.String(),
Path: websvc.PathV1SystemInfo,
}
body := httpGet(t, u, http.StatusOK)
resp := &websvc.RespGetV1SystemInfo{}
err := json.Unmarshal(body, resp)
require.NoError(t, err)
// TODO(a.garipov): Consider making version.Channel and version.Version
// testable and test these better.
assert.NotEmpty(t, resp.Channel)
assert.Equal(t, resp.Arch, runtime.GOARCH)
assert.Equal(t, resp.OS, runtime.GOOS)
assert.Equal(t, testStart, time.Time(resp.Start))
}

View File

@@ -0,0 +1,31 @@
package websvc
import (
"net"
"sync"
)
// Wait Listener
// waitListener is a wrapper around a listener that also calls wg.Done() on the
// first call to Accept. It is useful in situations where it is important to
// catch the precise moment of the first call to Accept, for example when
// starting an HTTP server.
//
// TODO(a.garipov): Move to aghnet?
type waitListener struct {
net.Listener
firstAcceptWG *sync.WaitGroup
firstAcceptOnce sync.Once
}
// type check
var _ net.Listener = (*waitListener)(nil)
// Accept implements the [net.Listener] interface for *waitListener.
func (l *waitListener) Accept() (conn net.Conn, err error) {
l.firstAcceptOnce.Do(l.firstAcceptWG.Done)
return l.Listener.Accept()
}

View File

@@ -0,0 +1,50 @@
package websvc
import (
"net"
"sync"
"sync/atomic"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghchan"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/stretchr/testify/assert"
)
func TestWaitListener_Accept(t *testing.T) {
// TODO(a.garipov): use atomic.Bool in Go 1.19.
var numAcceptCalls uint32
var l net.Listener = &aghtest.Listener{
OnAccept: func() (conn net.Conn, err error) {
atomic.AddUint32(&numAcceptCalls, 1)
return nil, nil
},
OnAddr: func() (addr net.Addr) {
panic("not implemented")
},
OnClose: func() (err error) {
panic("not implemented")
},
}
wg := &sync.WaitGroup{}
wg.Add(1)
done := make(chan struct{})
go aghchan.MustReceive(done, testTimeout)
go func() {
var wrapper net.Listener = &waitListener{
Listener: l,
firstAcceptWG: wg,
}
_, _ = wrapper.Accept()
}()
wg.Wait()
close(done)
assert.Equal(t, uint32(1), atomic.LoadUint32(&numAcceptCalls))
}

View File

@@ -0,0 +1,305 @@
// Package websvc contains the AdGuard Home HTTP API service.
//
// NOTE: Packages other than cmd must not import this package, as it imports
// most other packages.
//
// TODO(a.garipov): Add tests.
package websvc
import (
"context"
"crypto/tls"
"fmt"
"io"
"net"
"net/http"
"net/netip"
"sync"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
httptreemux "github.com/dimfeld/httptreemux/v5"
)
// ConfigManager is the configuration manager interface.
type ConfigManager interface {
DNS() (svc *dnssvc.Service)
Web() (svc *Service)
UpdateDNS(ctx context.Context, c *dnssvc.Config) (err error)
UpdateWeb(ctx context.Context, c *Config) (err error)
}
// Config is the AdGuard Home web service configuration structure.
type Config struct {
// ConfigManager is used to show information about services as well as
// dynamically reconfigure them.
ConfigManager ConfigManager
// TLS is the optional TLS configuration. If TLS is not nil,
// SecureAddresses must not be empty.
TLS *tls.Config
// Start is the time of start of AdGuard Home.
Start time.Time
// Addresses are the addresses on which to serve the plain HTTP API.
Addresses []netip.AddrPort
// SecureAddresses are the addresses on which to serve the HTTPS API. If
// SecureAddresses is not empty, TLS must not be nil.
SecureAddresses []netip.AddrPort
// Timeout is the timeout for all server operations.
Timeout time.Duration
// ForceHTTPS tells if all requests to Addresses should be redirected to a
// secure address instead.
//
// TODO(a.garipov): Use; define rules, which address to redirect to.
ForceHTTPS bool
}
// Service is the AdGuard Home web service. A nil *Service is a valid
// [agh.Service] that does nothing.
type Service struct {
confMgr ConfigManager
tls *tls.Config
start time.Time
servers []*http.Server
timeout time.Duration
forceHTTPS bool
}
// New returns a new properly initialized *Service. If c is nil, svc is a nil
// *Service that does nothing. The fields of c must not be modified after
// calling New.
func New(c *Config) (svc *Service) {
if c == nil {
return nil
}
svc = &Service{
confMgr: c.ConfigManager,
tls: c.TLS,
start: c.Start,
timeout: c.Timeout,
forceHTTPS: c.ForceHTTPS,
}
mux := newMux(svc)
for _, a := range c.Addresses {
addr := a.String()
errLog := log.StdLog("websvc: plain http: "+addr, log.ERROR)
svc.servers = append(svc.servers, &http.Server{
Addr: addr,
Handler: mux,
ErrorLog: errLog,
ReadTimeout: c.Timeout,
WriteTimeout: c.Timeout,
IdleTimeout: c.Timeout,
ReadHeaderTimeout: c.Timeout,
})
}
for _, a := range c.SecureAddresses {
addr := a.String()
errLog := log.StdLog("websvc: https: "+addr, log.ERROR)
svc.servers = append(svc.servers, &http.Server{
Addr: addr,
Handler: mux,
TLSConfig: c.TLS,
ErrorLog: errLog,
ReadTimeout: c.Timeout,
WriteTimeout: c.Timeout,
IdleTimeout: c.Timeout,
ReadHeaderTimeout: c.Timeout,
})
}
return svc
}
// newMux returns a new HTTP request multiplexor for the AdGuard Home web
// service.
func newMux(svc *Service) (mux *httptreemux.ContextMux) {
mux = httptreemux.NewContextMux()
routes := []struct {
handler http.HandlerFunc
method string
path string
isJSON bool
}{{
handler: svc.handleGetHealthCheck,
method: http.MethodGet,
path: PathHealthCheck,
isJSON: false,
}, {
handler: svc.handleGetSettingsAll,
method: http.MethodGet,
path: PathV1SettingsAll,
isJSON: true,
}, {
handler: svc.handlePatchSettingsDNS,
method: http.MethodPatch,
path: PathV1SettingsDNS,
isJSON: true,
}, {
handler: svc.handlePatchSettingsHTTP,
method: http.MethodPatch,
path: PathV1SettingsHTTP,
isJSON: true,
}, {
handler: svc.handleGetV1SystemInfo,
method: http.MethodGet,
path: PathV1SystemInfo,
isJSON: true,
}}
for _, r := range routes {
if r.isJSON {
mux.Handle(r.method, r.path, jsonMw(r.handler))
} else {
mux.Handle(r.method, r.path, r.handler)
}
}
return mux
}
// addrs returns all addresses on which this server serves the HTTP API. addrs
// must not be called simultaneously with Start. If svc was initialized with
// ":0" addresses, addrs will not return the actual bound ports until Start is
// finished.
func (svc *Service) addrs() (addrs, secureAddrs []netip.AddrPort) {
for _, srv := range svc.servers {
addrPort, err := netip.ParseAddrPort(srv.Addr)
if err != nil {
// Technically shouldn't happen, since all servers must have a valid
// address.
panic(fmt.Errorf("websvc: server %q: bad address: %w", srv.Addr, err))
}
// srv.Serve will set TLSConfig to an almost empty value, so, instead of
// relying only on the nilness of TLSConfig, check the length of the
// certificates field as well.
if srv.TLSConfig == nil || len(srv.TLSConfig.Certificates) == 0 {
addrs = append(addrs, addrPort)
} else {
secureAddrs = append(secureAddrs, addrPort)
}
}
return addrs, secureAddrs
}
// handleGetHealthCheck is the handler for the GET /health-check HTTP API.
func (svc *Service) handleGetHealthCheck(w http.ResponseWriter, _ *http.Request) {
_, _ = io.WriteString(w, "OK")
}
// type check
var _ agh.Service = (*Service)(nil)
// Start implements the [agh.Service] interface for *Service. svc may be nil.
// After Start exits, all HTTP servers have tried to start, possibly failing and
// writing error messages to the log.
func (svc *Service) Start() (err error) {
if svc == nil {
return nil
}
wg := &sync.WaitGroup{}
wg.Add(len(svc.servers))
for _, srv := range svc.servers {
go serve(srv, wg)
}
wg.Wait()
return nil
}
// serve starts and runs srv and writes all errors into its log.
func serve(srv *http.Server, wg *sync.WaitGroup) {
addr := srv.Addr
defer log.OnPanic(addr)
var proto string
var l net.Listener
var err error
if srv.TLSConfig == nil {
proto = "http"
l, err = net.Listen("tcp", addr)
} else {
proto = "https"
l, err = tls.Listen("tcp", addr, srv.TLSConfig)
}
if err != nil {
srv.ErrorLog.Printf("starting srv %s: binding: %s", addr, err)
}
// Update the server's address in case the address had the port zero, which
// would mean that a random available port was automatically chosen.
srv.Addr = l.Addr().String()
log.Info("websvc: starting srv %s://%s", proto, srv.Addr)
l = &waitListener{
Listener: l,
firstAcceptWG: wg,
}
err = srv.Serve(l)
if err != nil && !errors.Is(err, http.ErrServerClosed) {
srv.ErrorLog.Printf("starting srv %s: %s", addr, err)
}
}
// Shutdown implements the [agh.Service] interface for *Service. svc may be
// nil.
func (svc *Service) Shutdown(ctx context.Context) (err error) {
if svc == nil {
return nil
}
var errs []error
for _, srv := range svc.servers {
serr := srv.Shutdown(ctx)
if serr != nil {
errs = append(errs, fmt.Errorf("shutting down srv %s: %w", srv.Addr, serr))
}
}
if len(errs) > 0 {
return errors.List("shutting down", errs...)
}
return nil
}
// Config returns the current configuration of the web service. Config must not
// be called simultaneously with Start. If svc was initialized with ":0"
// addresses, addrs will not return the actual bound ports until Start is
// finished.
func (svc *Service) Config() (c *Config) {
c = &Config{
ConfigManager: svc.confMgr,
TLS: svc.tls,
// Leave Addresses and SecureAddresses empty and get the actual
// addresses that include the :0 ones later.
Start: svc.start,
Timeout: svc.timeout,
ForceHTTPS: svc.forceHTTPS,
}
c.Addresses, c.SecureAddresses = svc.addrs()
return c
}

View File

@@ -0,0 +1,6 @@
package websvc
import "time"
// testTimeout is the common timeout for tests.
const testTimeout = 1 * time.Second

View File

@@ -0,0 +1,153 @@
package websvc_test
import (
"context"
"io"
"net/http"
"net/netip"
"net/url"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
"github.com/AdguardTeam/AdGuardHome/internal/next/websvc"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestMain(m *testing.M) {
aghtest.DiscardLogOutput(m)
}
// testTimeout is the common timeout for tests.
const testTimeout = 1 * time.Second
// testStart is the server start value for tests.
var testStart = time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC)
// type check
var _ websvc.ConfigManager = (*configManager)(nil)
// configManager is a [websvc.ConfigManager] for tests.
type configManager struct {
onDNS func() (svc *dnssvc.Service)
onWeb func() (svc *websvc.Service)
onUpdateDNS func(ctx context.Context, c *dnssvc.Config) (err error)
onUpdateWeb func(ctx context.Context, c *websvc.Config) (err error)
}
// DNS implements the [websvc.ConfigManager] interface for *configManager.
func (m *configManager) DNS() (svc *dnssvc.Service) {
return m.onDNS()
}
// Web implements the [websvc.ConfigManager] interface for *configManager.
func (m *configManager) Web() (svc *websvc.Service) {
return m.onWeb()
}
// UpdateDNS implements the [websvc.ConfigManager] interface for *configManager.
func (m *configManager) UpdateDNS(ctx context.Context, c *dnssvc.Config) (err error) {
return m.onUpdateDNS(ctx, c)
}
// UpdateWeb implements the [websvc.ConfigManager] interface for *configManager.
func (m *configManager) UpdateWeb(ctx context.Context, c *websvc.Config) (err error) {
return m.onUpdateWeb(ctx, c)
}
// newConfigManager returns a *configManager all methods of which panic.
func newConfigManager() (m *configManager) {
return &configManager{
onDNS: func() (svc *dnssvc.Service) { panic("not implemented") },
onWeb: func() (svc *websvc.Service) { panic("not implemented") },
onUpdateDNS: func(_ context.Context, _ *dnssvc.Config) (err error) {
panic("not implemented")
},
onUpdateWeb: func(_ context.Context, _ *websvc.Config) (err error) {
panic("not implemented")
},
}
}
// newTestServer creates and starts a new web service instance as well as its
// sole address. It also registers a cleanup procedure, which shuts the
// instance down.
//
// TODO(a.garipov): Use svc or remove it.
func newTestServer(
t testing.TB,
confMgr websvc.ConfigManager,
) (svc *websvc.Service, addr netip.AddrPort) {
t.Helper()
c := &websvc.Config{
ConfigManager: confMgr,
TLS: nil,
Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:0")},
SecureAddresses: nil,
Timeout: testTimeout,
Start: testStart,
ForceHTTPS: false,
}
svc = websvc.New(c)
err := svc.Start()
require.NoError(t, err)
t.Cleanup(func() {
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
t.Cleanup(cancel)
err = svc.Shutdown(ctx)
require.NoError(t, err)
})
c = svc.Config()
require.NotNil(t, c)
require.Len(t, c.Addresses, 1)
return svc, c.Addresses[0]
}
// httpGet is a helper that performs an HTTP GET request and returns the body of
// the response as well as checks that the status code is correct.
//
// TODO(a.garipov): Add helpers for other methods.
func httpGet(t testing.TB, u *url.URL, wantCode int) (body []byte) {
t.Helper()
req, err := http.NewRequest(http.MethodGet, u.String(), nil)
require.NoErrorf(t, err, "creating req")
httpCli := &http.Client{
Timeout: testTimeout,
}
resp, err := httpCli.Do(req)
require.NoErrorf(t, err, "performing req")
require.Equal(t, wantCode, resp.StatusCode)
testutil.CleanupAndRequireSuccess(t, resp.Body.Close)
body, err = io.ReadAll(resp.Body)
require.NoErrorf(t, err, "reading body")
return body
}
func TestService_Start_getHealthCheck(t *testing.T) {
confMgr := newConfigManager()
_, addr := newTestServer(t, confMgr)
u := &url.URL{
Scheme: "http",
Host: addr.String(),
Path: websvc.PathHealthCheck,
}
body := httpGet(t, u, http.StatusOK)
assert.Equal(t, []byte("OK"), body)
}