Pull request 2100: v0.107.42-rc

Squashed commit of the following:

commit 284190f748345c7556e60b67f051ec5f6f080948
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Wed Dec 6 19:36:00 2023 +0300

    all: sync with master; upd chlog
This commit is contained in:
Ainar Garipov
2023-12-07 17:23:00 +03:00
parent df91f016f2
commit 25918e56fa
148 changed files with 7204 additions and 1705 deletions

View File

@@ -0,0 +1,30 @@
# This is a file showing example configuration for AdGuard Home.
#
# TODO(a.garipov): Move to the top level once the rewrite is over.
dns:
addresses:
- '0.0.0.0:53'
bootstrap_dns:
- '9.9.9.10'
- '149.112.112.10'
- '2620:fe::10'
- '2620:fe::fe:10'
upstream_dns:
- '8.8.8.8'
dns64_prefixes:
- '1234::/64'
upstream_timeout: 1s
bootstrap_prefer_ipv6: true
use_dns64: true
http:
pprof:
enabled: true
port: 6060
addresses:
- '0.0.0.0:3000'
secure_addresses: []
timeout: 5s
force_https: true
log:
verbose: true

View File

@@ -0,0 +1,42 @@
# AdGuard Home v0.108.0 Changelog DRAFT
This changelog should be merged into the main one once the next API matures
enough.
## [v0.108.0] - TODO
### Added
- The ability to change the port of the pprof debug API.
- The ability to log to stderr using `--logFile=stderr`.
- The new `--web-addr` flag to set the Web UI address in a `host:port` form.
- `SIGHUP` now reloads all configuration from the configuration file ([#5676]).
### Changed
#### New HTTP API
**TODO(a.garipov):** Describe the new API and add a link to the new OpenAPI doc.
#### Other changes
- `-h` is now an alias for `--help` instead of the removed `--host`, see below.
Use `--web-addr=host:port` to set an address on which to serve the Web UI.
### Fixed
- `--check-config` breaking the configuration file ([#4067]).
- Inconsistent application of `--work-dir/-w` ([#2598], [#2902]).
- The order of `-v/--verbose` and `--version` being significant ([#2893]).
### Removed
- The deprecated `--no-mem-optimization` and `--no-etc-hosts` flags.
- `--host` and `-p/--port` flags. Use `--web-addr=host:port` to set an address
on which to serve the Web UI. `-h` is now an alias for `--help`, see above.
[#2598]: https://github.com/AdguardTeam/AdGuardHome/issues/2598
[#2893]: https://github.com/AdguardTeam/AdGuardHome/issues/2893
[#2902]: https://github.com/AdguardTeam/AdGuardHome/issues/2902
[#4067]: https://github.com/AdguardTeam/AdGuardHome/issues/4067
[#5676]: https://github.com/AdguardTeam/AdGuardHome/issues/5676

95
internal/next/cmd/cmd.go Normal file
View File

@@ -0,0 +1,95 @@
// Package cmd is the AdGuard Home entry point. It assembles the configuration
// file manager, sets up signal processing logic, and so on.
//
// TODO(a.garipov): Move to the upper-level internal/.
package cmd
import (
"context"
"io/fs"
"os"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/next/configmgr"
"github.com/AdguardTeam/AdGuardHome/internal/version"
"github.com/AdguardTeam/golibs/log"
)
// Main is the entry point of AdGuard Home.
func Main(embeddedFrontend fs.FS) {
start := time.Now()
cmdName := os.Args[0]
opts, err := parseOptions(cmdName, os.Args[1:])
exitCode, needExit := processOptions(opts, cmdName, err)
if needExit {
os.Exit(exitCode)
}
err = setLog(opts)
check(err)
log.Info("starting adguard home, version %s, pid %d", version.Version(), os.Getpid())
if opts.workDir != "" {
log.Info("changing working directory to %q", opts.workDir)
err = os.Chdir(opts.workDir)
check(err)
}
frontend, err := frontendFromOpts(opts, embeddedFrontend)
check(err)
confMgrConf := &configmgr.Config{
Frontend: frontend,
WebAddr: opts.webAddr,
Start: start,
FileName: opts.confFile,
}
confMgr, err := newConfigMgr(confMgrConf)
check(err)
web := confMgr.Web()
err = web.Start()
check(err)
dns := confMgr.DNS()
err = dns.Start()
check(err)
sigHdlr := newSignalHandler(
confMgrConf,
opts.pidFile,
web,
dns,
)
sigHdlr.handle()
}
// defaultTimeout is the timeout used for some operations where another timeout
// hasn't been defined yet.
const defaultTimeout = 5 * time.Second
// ctxWithDefaultTimeout is a helper function that returns a context with
// timeout set to defaultTimeout.
func ctxWithDefaultTimeout() (ctx context.Context, cancel context.CancelFunc) {
return context.WithTimeout(context.Background(), defaultTimeout)
}
// newConfigMgr returns a new configuration manager using defaultTimeout as the
// context timeout.
func newConfigMgr(c *configmgr.Config) (m *configmgr.Manager, err error) {
ctx, cancel := ctxWithDefaultTimeout()
defer cancel()
return configmgr.New(ctx, c)
}
// check is a simple error-checking helper. It must only be used within Main.
func check(err error) {
if err != nil {
panic(err)
}
}

39
internal/next/cmd/log.go Normal file
View File

@@ -0,0 +1,39 @@
package cmd
import (
"fmt"
"os"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/log"
)
// syslogServiceName is the name of the AdGuard Home service used for writing
// logs to the system log.
const syslogServiceName = "AdGuardHome"
// setLog sets up the text logging.
//
// TODO(a.garipov): Add parameters from configuration file.
func setLog(opts *options) (err error) {
switch opts.confFile {
case "stdout":
log.SetOutput(os.Stdout)
case "stderr":
log.SetOutput(os.Stderr)
case "syslog":
err = aghos.ConfigureSyslog(syslogServiceName)
if err != nil {
return fmt.Errorf("initializing syslog: %w", err)
}
default:
// TODO(a.garipov): Use the path.
}
if opts.verbose {
log.SetLevel(log.DEBUG)
log.Debug("verbose logging enabled")
}
return nil
}

418
internal/next/cmd/opt.go Normal file
View File

@@ -0,0 +1,418 @@
package cmd
import (
"encoding"
"flag"
"fmt"
"io"
"io/fs"
"net/netip"
"os"
"strings"
"github.com/AdguardTeam/AdGuardHome/internal/next/configmgr"
"github.com/AdguardTeam/AdGuardHome/internal/version"
"github.com/AdguardTeam/golibs/log"
"golang.org/x/exp/slices"
)
// options contains all command-line options for the AdGuardHome(.exe) binary.
type options struct {
// confFile is the path to the configuration file.
confFile string
// logFile is the path to the log file. Special values:
//
// - "stdout": Write to stdout (the default).
// - "stderr": Write to stderr.
// - "syslog": Write to the system log.
logFile string
// pidFile is the path to the file where to store the PID.
pidFile string
// serviceAction is the service control action to perform:
//
// - "install": Installs AdGuard Home as a system service.
// - "uninstall": Uninstalls it.
// - "status": Prints the service status.
// - "start": Starts the previously installed service.
// - "stop": Stops the previously installed service.
// - "restart": Restarts the previously installed service.
// - "reload": Reloads the configuration.
// - "run": This is a special command that is not supposed to be used
// directly it is specified when we register a service, and it indicates
// to the app that it is being run as a service.
//
// TODO(a.garipov): Use.
serviceAction string
// workDir is the path to the working directory. It is applied before all
// other configuration is read, so all relative paths are relative to it.
workDir string
// webAddr contains the address on which to serve the web UI.
webAddr netip.AddrPort
// checkConfig, if true, instructs AdGuard Home to check the configuration
// file, optionally print an error message to stdout, and exit with a
// corresponding exit code.
checkConfig bool
// disableUpdate, if true, prevents AdGuard Home from automatically checking
// for updates.
//
// TODO(a.garipov): Use.
disableUpdate bool
// glinetMode enables the GL-Inet compatibility mode.
//
// TODO(a.garipov): Use.
glinetMode bool
// help, if true, instructs AdGuard Home to print the command-line option
// help message and quit with a successful exit-code.
help bool
// localFrontend, if true, instructs AdGuard Home to use the local frontend
// directory instead of the files compiled into the binary.
//
// TODO(a.garipov): Use.
localFrontend bool
// performUpdate, if true, instructs AdGuard Home to update the current
// binary and restart the service in case it's installed.
//
// TODO(a.garipov): Use.
performUpdate bool
// verbose, if true, instructs AdGuard Home to enable verbose logging.
verbose bool
// version, if true, instructs AdGuard Home to print the version to stdout
// and quit with a successful exit-code. If verbose is also true, print a
// more detailed version description.
version bool
}
// Indexes to help with the [commandLineOptions] initialization.
const (
confFileIdx = iota
logFileIdx
pidFileIdx
serviceActionIdx
workDirIdx
webAddrIdx
checkConfigIdx
disableUpdateIdx
glinetModeIdx
helpIdx
localFrontend
performUpdateIdx
verboseIdx
versionIdx
)
// commandLineOption contains information about a command-line option: its long
// and, if there is one, short forms, the value type, the description, and the
// default value.
type commandLineOption struct {
defaultValue any
description string
long string
short string
valueType string
}
// commandLineOptions are all command-line options currently supported by
// AdGuard Home.
var commandLineOptions = []*commandLineOption{
confFileIdx: {
// TODO(a.garipov): Remove the directory when the new code is ready.
defaultValue: "internal/next/AdGuardHome.yaml",
description: "Path to the config file.",
long: "config",
short: "c",
valueType: "path",
},
logFileIdx: {
defaultValue: "stdout",
description: `Path to log file. Special values include "stdout", "stderr", and "syslog".`,
long: "logfile",
short: "l",
valueType: "path",
},
pidFileIdx: {
defaultValue: "",
description: "Path to the file where to store the PID.",
long: "pidfile",
short: "",
valueType: "path",
},
serviceActionIdx: {
defaultValue: "",
description: `Service control action: "status", "install" (as a service), ` +
`"uninstall" (as a service), "start", "stop", "restart", "reload" (configuration).`,
long: "service",
short: "s",
valueType: "action",
},
workDirIdx: {
defaultValue: "",
description: `Path to the working directory. ` +
`It is applied before all other configuration is read, ` +
`so all relative paths are relative to it.`,
long: "work-dir",
short: "w",
valueType: "path",
},
webAddrIdx: {
defaultValue: netip.AddrPort{},
description: `Address to serve the web UI on, in the host:port format.`,
long: "web-addr",
short: "",
valueType: "host:port",
},
checkConfigIdx: {
defaultValue: false,
description: "Check configuration, print errors to stdout, and quit.",
long: "check-config",
short: "",
valueType: "",
},
disableUpdateIdx: {
defaultValue: false,
description: "Disable automatic update checking.",
long: "no-check-update",
short: "",
valueType: "",
},
glinetModeIdx: {
defaultValue: false,
description: "Run in GL-Inet compatibility mode.",
long: "glinet",
short: "",
valueType: "",
},
helpIdx: {
defaultValue: false,
description: "Print this help message and quit.",
long: "help",
short: "h",
valueType: "",
},
localFrontend: {
defaultValue: false,
description: "Use local frontend directories.",
long: "local-frontend",
short: "",
valueType: "",
},
performUpdateIdx: {
defaultValue: false,
description: "Update the current binary and restart the service in case it's installed.",
long: "update",
short: "",
valueType: "",
},
verboseIdx: {
defaultValue: false,
description: "Enable verbose logging.",
long: "verbose",
short: "v",
valueType: "",
},
versionIdx: {
defaultValue: false,
description: `Print the version to stdout and quit. ` +
`Print a more detailed version description with -v.`,
long: "version",
short: "",
valueType: "",
},
}
// parseOptions parses the command-line options for AdGuardHome.
func parseOptions(cmdName string, args []string) (opts *options, err error) {
flags := flag.NewFlagSet(cmdName, flag.ContinueOnError)
opts = &options{}
for i, fieldPtr := range []any{
confFileIdx: &opts.confFile,
logFileIdx: &opts.logFile,
pidFileIdx: &opts.pidFile,
serviceActionIdx: &opts.serviceAction,
workDirIdx: &opts.workDir,
webAddrIdx: &opts.webAddr,
checkConfigIdx: &opts.checkConfig,
disableUpdateIdx: &opts.disableUpdate,
glinetModeIdx: &opts.glinetMode,
helpIdx: &opts.help,
localFrontend: &opts.localFrontend,
performUpdateIdx: &opts.performUpdate,
verboseIdx: &opts.verbose,
versionIdx: &opts.version,
} {
addOption(flags, fieldPtr, commandLineOptions[i])
}
flags.Usage = func() { usage(cmdName, os.Stderr) }
err = flags.Parse(args)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return nil, err
}
return opts, nil
}
// addOption adds the command-line option described by o to flags using fieldPtr
// as the pointer to the value.
func addOption(flags *flag.FlagSet, fieldPtr any, o *commandLineOption) {
switch fieldPtr := fieldPtr.(type) {
case *string:
flags.StringVar(fieldPtr, o.long, o.defaultValue.(string), o.description)
if o.short != "" {
flags.StringVar(fieldPtr, o.short, o.defaultValue.(string), o.description)
}
case *bool:
flags.BoolVar(fieldPtr, o.long, o.defaultValue.(bool), o.description)
if o.short != "" {
flags.BoolVar(fieldPtr, o.short, o.defaultValue.(bool), o.description)
}
case encoding.TextUnmarshaler:
flags.TextVar(fieldPtr, o.long, o.defaultValue.(encoding.TextMarshaler), o.description)
if o.short != "" {
flags.TextVar(fieldPtr, o.short, o.defaultValue.(encoding.TextMarshaler), o.description)
}
default:
panic(fmt.Errorf("unexpected field pointer type %T", fieldPtr))
}
}
// usage prints a usage message similar to the one printed by package flag but
// taking long vs. short versions into account as well as using more informative
// value hints.
func usage(cmdName string, output io.Writer) {
options := slices.Clone(commandLineOptions)
slices.SortStableFunc(options, func(a, b *commandLineOption) (res int) {
return strings.Compare(a.long, b.long)
})
b := &strings.Builder{}
_, _ = fmt.Fprintf(b, "Usage of %s:\n", cmdName)
for _, o := range options {
writeUsageLine(b, o)
// Use four spaces before the tab to trigger good alignment for both 4-
// and 8-space tab stops.
if shouldIncludeDefault(o.defaultValue) {
_, _ = fmt.Fprintf(b, " \t%s (Default value: %q)\n", o.description, o.defaultValue)
} else {
_, _ = fmt.Fprintf(b, " \t%s\n", o.description)
}
}
_, _ = io.WriteString(output, b.String())
}
// shouldIncludeDefault returns true if this default value should be printed.
func shouldIncludeDefault(v any) (ok bool) {
switch v := v.(type) {
case bool:
return v
case string:
return v != ""
default:
return v == nil
}
}
// writeUsageLine writes the usage line for the provided command-line option.
func writeUsageLine(b *strings.Builder, o *commandLineOption) {
if o.short == "" {
if o.valueType == "" {
_, _ = fmt.Fprintf(b, " --%s\n", o.long)
} else {
_, _ = fmt.Fprintf(b, " --%s=%s\n", o.long, o.valueType)
}
return
}
if o.valueType == "" {
_, _ = fmt.Fprintf(b, " --%s/-%s\n", o.long, o.short)
} else {
_, _ = fmt.Fprintf(b, " --%[1]s=%[3]s/-%[2]s %[3]s\n", o.long, o.short, o.valueType)
}
}
// processOptions decides if AdGuard Home should exit depending on the results
// of command-line option parsing.
func processOptions(
opts *options,
cmdName string,
parseErr error,
) (exitCode int, needExit bool) {
if parseErr != nil {
// Assume that usage has already been printed.
return statusArgumentError, true
}
if opts.help {
usage(cmdName, os.Stdout)
return statusSuccess, true
}
if opts.version {
if opts.verbose {
fmt.Println(version.Verbose())
} else {
fmt.Printf("AdGuard Home %s\n", version.Version())
}
return statusSuccess, true
}
if opts.checkConfig {
err := configmgr.Validate(opts.confFile)
if err != nil {
_, _ = io.WriteString(os.Stdout, err.Error()+"\n")
return statusError, true
}
return statusSuccess, true
}
return 0, false
}
// frontendFromOpts returns the frontend to use based on the options.
func frontendFromOpts(opts *options, embeddedFrontend fs.FS) (frontend fs.FS, err error) {
const frontendSubdir = "build/static"
if opts.localFrontend {
log.Info("warning: using local frontend files")
return os.DirFS(frontendSubdir), nil
}
return fs.Sub(embeddedFrontend, frontendSubdir)
}

167
internal/next/cmd/signal.go Normal file
View File

@@ -0,0 +1,167 @@
package cmd
import (
"os"
"strconv"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
"github.com/AdguardTeam/AdGuardHome/internal/next/configmgr"
"github.com/AdguardTeam/golibs/log"
"github.com/google/renameio/v2/maybe"
)
// signalHandler processes incoming signals and shuts services down.
type signalHandler struct {
// confMgrConf contains the configuration parameters for the configuration
// manager.
confMgrConf *configmgr.Config
// signal is the channel to which OS signals are sent.
signal chan os.Signal
// pidFile is the path to the file where to store the PID, if any.
pidFile string
// services are the services that are shut down before application exiting.
services []agh.Service
}
// handle processes OS signals.
func (h *signalHandler) handle() {
defer log.OnPanic("signalHandler.handle")
h.writePID()
for sig := range h.signal {
log.Info("sighdlr: received signal %q", sig)
if aghos.IsReconfigureSignal(sig) {
h.reconfigure()
} else if aghos.IsShutdownSignal(sig) {
status := h.shutdown()
h.removePID()
log.Info("sighdlr: exiting with status %d", status)
os.Exit(status)
}
}
}
// reconfigure rereads the configuration file and updates and restarts services.
func (h *signalHandler) reconfigure() {
log.Info("sighdlr: reconfiguring adguard home")
status := h.shutdown()
if status != statusSuccess {
log.Info("sighdlr: reconfiguring: exiting with status %d", status)
os.Exit(status)
}
// TODO(a.garipov): This is a very rough way to do it. Some services can be
// reconfigured without the full shutdown, and the error handling is
// currently not the best.
confMgr, err := newConfigMgr(h.confMgrConf)
check(err)
web := confMgr.Web()
err = web.Start()
check(err)
dns := confMgr.DNS()
err = dns.Start()
check(err)
h.services = []agh.Service{
dns,
web,
}
log.Info("sighdlr: successfully reconfigured adguard home")
}
// Exit status constants.
const (
statusSuccess = 0
statusError = 1
statusArgumentError = 2
)
// shutdown gracefully shuts down all services.
func (h *signalHandler) shutdown() (status int) {
ctx, cancel := ctxWithDefaultTimeout()
defer cancel()
status = statusSuccess
log.Info("sighdlr: shutting down services")
for i, service := range h.services {
err := service.Shutdown(ctx)
if err != nil {
log.Error("sighdlr: shutting down service at index %d: %s", i, err)
status = statusError
}
}
return status
}
// newSignalHandler returns a new signalHandler that shuts down svcs.
func newSignalHandler(
confMgrConf *configmgr.Config,
pidFile string,
svcs ...agh.Service,
) (h *signalHandler) {
h = &signalHandler{
confMgrConf: confMgrConf,
signal: make(chan os.Signal, 1),
pidFile: pidFile,
services: svcs,
}
aghos.NotifyShutdownSignal(h.signal)
aghos.NotifyReconfigureSignal(h.signal)
return h
}
// writePID writes the PID to the file, if needed. Any errors are reported to
// log.
func (h *signalHandler) writePID() {
if h.pidFile == "" {
return
}
// Use 8, since most PIDs will fit.
data := make([]byte, 0, 8)
data = strconv.AppendInt(data, int64(os.Getpid()), 10)
data = append(data, '\n')
err := maybe.WriteFile(h.pidFile, data, 0o644)
if err != nil {
log.Error("sighdlr: writing pidfile: %s", err)
return
}
log.Debug("sighdlr: wrote pid to %q", h.pidFile)
}
// removePID removes the PID file, if any.
func (h *signalHandler) removePID() {
if h.pidFile == "" {
return
}
err := os.Remove(h.pidFile)
if err != nil {
log.Error("sighdlr: removing pidfile: %s", err)
return
}
log.Debug("sighdlr: removed pid at %q", h.pidFile)
}

View File

@@ -0,0 +1,138 @@
package configmgr
import (
"fmt"
"net/netip"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/timeutil"
)
// Configuration Structures
// config is the top-level on-disk configuration structure.
type config struct {
DNS *dnsConfig `yaml:"dns"`
HTTP *httpConfig `yaml:"http"`
Log *logConfig `yaml:"log"`
// TODO(a.garipov): Use.
SchemaVersion int `yaml:"schema_version"`
}
const errNoConf errors.Error = "configuration not found"
// validate returns an error if the configuration structure is invalid.
func (c *config) validate() (err error) {
if c == nil {
return errNoConf
}
// TODO(a.garipov): Add more validations.
// Keep this in the same order as the fields in the config.
validators := []struct {
validate func() (err error)
name string
}{{
validate: c.DNS.validate,
name: "dns",
}, {
validate: c.HTTP.validate,
name: "http",
}, {
validate: c.Log.validate,
name: "log",
}}
for _, v := range validators {
err = v.validate()
if err != nil {
return fmt.Errorf("%s: %w", v.name, err)
}
}
return nil
}
// dnsConfig is the on-disk DNS configuration.
type dnsConfig struct {
Addresses []netip.AddrPort `yaml:"addresses"`
BootstrapDNS []string `yaml:"bootstrap_dns"`
UpstreamDNS []string `yaml:"upstream_dns"`
DNS64Prefixes []netip.Prefix `yaml:"dns64_prefixes"`
UpstreamTimeout timeutil.Duration `yaml:"upstream_timeout"`
BootstrapPreferIPv6 bool `yaml:"bootstrap_prefer_ipv6"`
UseDNS64 bool `yaml:"use_dns64"`
}
// validate returns an error if the DNS configuration structure is invalid.
//
// TODO(a.garipov): Add more validations.
func (c *dnsConfig) validate() (err error) {
// TODO(a.garipov): Add more validations.
switch {
case c == nil:
return errNoConf
case c.UpstreamTimeout.Duration <= 0:
return newMustBePositiveError("upstream_timeout", c.UpstreamTimeout)
default:
return nil
}
}
// httpConfig is the on-disk web API configuration.
type httpConfig struct {
Pprof *httpPprofConfig `yaml:"pprof"`
// TODO(a.garipov): Document the configuration change.
Addresses []netip.AddrPort `yaml:"addresses"`
SecureAddresses []netip.AddrPort `yaml:"secure_addresses"`
Timeout timeutil.Duration `yaml:"timeout"`
ForceHTTPS bool `yaml:"force_https"`
}
// validate returns an error if the HTTP configuration structure is invalid.
//
// TODO(a.garipov): Add more validations.
func (c *httpConfig) validate() (err error) {
switch {
case c == nil:
return errNoConf
case c.Timeout.Duration <= 0:
return newMustBePositiveError("timeout", c.Timeout)
default:
return c.Pprof.validate()
}
}
// httpPprofConfig is the on-disk pprof configuration.
type httpPprofConfig struct {
Port uint16 `yaml:"port"`
Enabled bool `yaml:"enabled"`
}
// validate returns an error if the pprof configuration structure is invalid.
func (c *httpPprofConfig) validate() (err error) {
if c == nil {
return errNoConf
}
return nil
}
// logConfig is the on-disk web API configuration.
type logConfig struct {
// TODO(a.garipov): Use.
Verbose bool `yaml:"verbose"`
}
// validate returns an error if the HTTP configuration structure is invalid.
//
// TODO(a.garipov): Add more validations.
func (c *logConfig) validate() (err error) {
if c == nil {
return errNoConf
}
return nil
}

View File

@@ -0,0 +1,301 @@
// Package configmgr defines the AdGuard Home on-disk configuration entities and
// configuration manager.
//
// TODO(a.garipov): Add tests.
package configmgr
import (
"context"
"fmt"
"io/fs"
"net/netip"
"os"
"sync"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
"github.com/AdguardTeam/AdGuardHome/internal/next/websvc"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/timeutil"
"github.com/google/renameio/v2/maybe"
"golang.org/x/exp/slices"
"gopkg.in/yaml.v3"
)
// Configuration Manager
// Manager handles full and partial changes in the configuration, persisting
// them to disk if necessary.
//
// TODO(a.garipov): Support missing configs and default values.
type Manager struct {
// updMu makes sure that at most one reconfiguration is performed at a time.
// updMu protects all fields below.
updMu *sync.RWMutex
// dns is the DNS service.
dns *dnssvc.Service
// Web is the Web API service.
web *websvc.Service
// current is the current configuration.
current *config
// fileName is the name of the configuration file.
fileName string
}
// Validate returns an error if the configuration file with the given name does
// not exist or is invalid.
func Validate(fileName string) (err error) {
conf, err := read(fileName)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
}
// Don't wrap the error, because it's informative enough as is.
return conf.validate()
}
// Config contains the configuration parameters for the configuration manager.
type Config struct {
// Frontend is the filesystem with the frontend files.
Frontend fs.FS
// WebAddr is the initial or override address for the Web UI. It is not
// written to the configuration file.
WebAddr netip.AddrPort
// Start is the time of start of AdGuard Home.
Start time.Time
// FileName is the path to the configuration file.
FileName string
}
// New creates a new *Manager that persists changes to the file pointed to by
// c.FileName. It reads the configuration file and populates the service
// fields. c must not be nil.
func New(ctx context.Context, c *Config) (m *Manager, err error) {
conf, err := read(c.FileName)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return nil, err
}
err = conf.validate()
if err != nil {
return nil, fmt.Errorf("validating config: %w", err)
}
m = &Manager{
updMu: &sync.RWMutex{},
current: conf,
fileName: c.FileName,
}
err = m.assemble(ctx, conf, c.Frontend, c.WebAddr, c.Start)
if err != nil {
return nil, fmt.Errorf("creating config manager: %w", err)
}
return m, nil
}
// read reads and decodes configuration from the provided filename.
func read(fileName string) (conf *config, err error) {
defer func() { err = errors.Annotate(err, "reading config: %w") }()
conf = &config{}
f, err := os.Open(fileName)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return nil, err
}
defer func() { err = errors.WithDeferred(err, f.Close()) }()
err = yaml.NewDecoder(f).Decode(conf)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return nil, err
}
return conf, nil
}
// assemble creates all services and puts them into the corresponding fields.
// The fields of conf must not be modified after calling assemble.
func (m *Manager) assemble(
ctx context.Context,
conf *config,
frontend fs.FS,
webAddr netip.AddrPort,
start time.Time,
) (err error) {
dnsConf := &dnssvc.Config{
Addresses: conf.DNS.Addresses,
BootstrapServers: conf.DNS.BootstrapDNS,
UpstreamServers: conf.DNS.UpstreamDNS,
DNS64Prefixes: conf.DNS.DNS64Prefixes,
UpstreamTimeout: conf.DNS.UpstreamTimeout.Duration,
BootstrapPreferIPv6: conf.DNS.BootstrapPreferIPv6,
UseDNS64: conf.DNS.UseDNS64,
}
err = m.updateDNS(ctx, dnsConf)
if err != nil {
return fmt.Errorf("assembling dnssvc: %w", err)
}
webSvcConf := &websvc.Config{
Pprof: &websvc.PprofConfig{
Port: conf.HTTP.Pprof.Port,
Enabled: conf.HTTP.Pprof.Enabled,
},
ConfigManager: m,
Frontend: frontend,
// TODO(a.garipov): Fill from config file.
TLS: nil,
Start: start,
Addresses: conf.HTTP.Addresses,
SecureAddresses: conf.HTTP.SecureAddresses,
OverrideAddress: webAddr,
Timeout: conf.HTTP.Timeout.Duration,
ForceHTTPS: conf.HTTP.ForceHTTPS,
}
err = m.updateWeb(ctx, webSvcConf)
if err != nil {
return fmt.Errorf("assembling websvc: %w", err)
}
return nil
}
// write writes the current configuration to disk.
func (m *Manager) write() (err error) {
b, err := yaml.Marshal(m.current)
if err != nil {
return fmt.Errorf("encoding: %w", err)
}
err = maybe.WriteFile(m.fileName, b, 0o644)
if err != nil {
return fmt.Errorf("writing: %w", err)
}
log.Info("configmgr: written to %q", m.fileName)
return nil
}
// DNS returns the current DNS service. It is safe for concurrent use.
func (m *Manager) DNS() (dns agh.ServiceWithConfig[*dnssvc.Config]) {
m.updMu.RLock()
defer m.updMu.RUnlock()
return m.dns
}
// UpdateDNS implements the [websvc.ConfigManager] interface for *Manager. The
// fields of c must not be modified after calling UpdateDNS.
func (m *Manager) UpdateDNS(ctx context.Context, c *dnssvc.Config) (err error) {
m.updMu.Lock()
defer m.updMu.Unlock()
// TODO(a.garipov): Update and write the configuration file. Return an
// error if something went wrong.
err = m.updateDNS(ctx, c)
if err != nil {
return fmt.Errorf("reassembling dnssvc: %w", err)
}
m.updateCurrentDNS(c)
return m.write()
}
// updateDNS recreates the DNS service. m.updMu is expected to be locked.
func (m *Manager) updateDNS(ctx context.Context, c *dnssvc.Config) (err error) {
if prev := m.dns; prev != nil {
err = prev.Shutdown(ctx)
if err != nil {
return fmt.Errorf("shutting down dns svc: %w", err)
}
}
svc, err := dnssvc.New(c)
if err != nil {
return fmt.Errorf("creating dns svc: %w", err)
}
m.dns = svc
return nil
}
// updateCurrentDNS updates the DNS configuration in the current config.
func (m *Manager) updateCurrentDNS(c *dnssvc.Config) {
m.current.DNS.Addresses = slices.Clone(c.Addresses)
m.current.DNS.BootstrapDNS = slices.Clone(c.BootstrapServers)
m.current.DNS.UpstreamDNS = slices.Clone(c.UpstreamServers)
m.current.DNS.DNS64Prefixes = slices.Clone(c.DNS64Prefixes)
m.current.DNS.UpstreamTimeout = timeutil.Duration{Duration: c.UpstreamTimeout}
m.current.DNS.BootstrapPreferIPv6 = c.BootstrapPreferIPv6
m.current.DNS.UseDNS64 = c.UseDNS64
}
// Web returns the current web service. It is safe for concurrent use.
func (m *Manager) Web() (web agh.ServiceWithConfig[*websvc.Config]) {
m.updMu.RLock()
defer m.updMu.RUnlock()
return m.web
}
// UpdateWeb implements the [websvc.ConfigManager] interface for *Manager. The
// fields of c must not be modified after calling UpdateWeb.
func (m *Manager) UpdateWeb(ctx context.Context, c *websvc.Config) (err error) {
m.updMu.Lock()
defer m.updMu.Unlock()
err = m.updateWeb(ctx, c)
if err != nil {
return fmt.Errorf("reassembling websvc: %w", err)
}
m.updateCurrentWeb(c)
return m.write()
}
// updateWeb recreates the web service. m.upd is expected to be locked.
func (m *Manager) updateWeb(ctx context.Context, c *websvc.Config) (err error) {
if prev := m.web; prev != nil {
err = prev.Shutdown(ctx)
if err != nil {
return fmt.Errorf("shutting down web svc: %w", err)
}
}
m.web, err = websvc.New(c)
if err != nil {
return fmt.Errorf("creating web svc: %w", err)
}
return nil
}
// updateCurrentWeb updates the web configuration in the current config.
func (m *Manager) updateCurrentWeb(c *websvc.Config) {
// TODO(a.garipov): Update pprof from API?
m.current.HTTP.Addresses = slices.Clone(c.Addresses)
m.current.HTTP.SecureAddresses = slices.Clone(c.SecureAddresses)
m.current.HTTP.Timeout = timeutil.Duration{Duration: c.Timeout}
m.current.HTTP.ForceHTTPS = c.ForceHTTPS
}

View File

@@ -0,0 +1,27 @@
package configmgr
import (
"fmt"
"github.com/AdguardTeam/golibs/timeutil"
"golang.org/x/exp/constraints"
)
// numberOrDuration is the constraint for integer types along with
// timeutil.Duration.
type numberOrDuration interface {
constraints.Integer | timeutil.Duration
}
// newMustBePositiveError returns an error about the value that must be positive
// but isn't. prop is the name of the property to mention in the error message.
//
// TODO(a.garipov): Consider moving such helpers to golibs and use in AdGuardDNS
// as well.
func newMustBePositiveError[T numberOrDuration](prop string, v T) (err error) {
if s, ok := any(v).(fmt.Stringer); ok {
return fmt.Errorf("%s must be positive, got %s", prop, s)
}
return fmt.Errorf("%s must be positive, got %d", prop, v)
}

View File

@@ -0,0 +1,35 @@
package dnssvc
import (
"net/netip"
"time"
)
// Config is the AdGuard Home DNS service configuration structure.
//
// TODO(a.garipov): Add timeout for incoming requests.
type Config struct {
// Addresses are the addresses on which to serve plain DNS queries.
Addresses []netip.AddrPort
// BootstrapServers are the addresses of DNS servers used for bootstrapping
// the upstream DNS server addresses.
BootstrapServers []string
// UpstreamServers are the upstream DNS server addresses to use.
UpstreamServers []string
// DNS64Prefixes is a slice of NAT64 prefixes to be used for DNS64. See
// also [Config.UseDNS64].
DNS64Prefixes []netip.Prefix
// UpstreamTimeout is the timeout for upstream requests.
UpstreamTimeout time.Duration
// BootstrapPreferIPv6, if true, instructs the bootstrapper to prefer IPv6
// addresses to IPv4 ones when bootstrapping.
BootstrapPreferIPv6 bool
// UseDNS64, if true, enables DNS64 protection for incoming requests.
UseDNS64 bool
}

View File

@@ -0,0 +1,232 @@
// Package dnssvc contains the AdGuard Home DNS service.
//
// TODO(a.garipov): Define, if all methods of a *Service should work with a nil
// receiver.
package dnssvc
import (
"context"
"fmt"
"net"
"net/netip"
"sync/atomic"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
// TODO(a.garipov): Add a “dnsproxy proxy” package to shield us from changes
// and replacement of module dnsproxy.
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
)
// Service is the AdGuard Home DNS service. A nil *Service is a valid
// [agh.Service] that does nothing.
//
// TODO(a.garipov): Consider saving a [*proxy.Config] instance for those
// fields that are only used in [New] and [Service.Config].
type Service struct {
proxy *proxy.Proxy
bootstraps []string
bootstrapResolvers []*upstream.UpstreamResolver
upstreams []string
dns64Prefixes []netip.Prefix
upsTimeout time.Duration
running atomic.Bool
bootstrapPreferIPv6 bool
useDNS64 bool
}
// New returns a new properly initialized *Service. If c is nil, svc is a nil
// *Service that does nothing. The fields of c must not be modified after
// calling New.
func New(c *Config) (svc *Service, err error) {
if c == nil {
return nil, nil
}
svc = &Service{
bootstraps: c.BootstrapServers,
upstreams: c.UpstreamServers,
dns64Prefixes: c.DNS64Prefixes,
upsTimeout: c.UpstreamTimeout,
bootstrapPreferIPv6: c.BootstrapPreferIPv6,
useDNS64: c.UseDNS64,
}
upstreams, resolvers, err := addressesToUpstreams(
c.UpstreamServers,
c.BootstrapServers,
c.UpstreamTimeout,
c.BootstrapPreferIPv6,
)
if err != nil {
return nil, fmt.Errorf("converting upstreams: %w", err)
}
svc.bootstrapResolvers = resolvers
svc.proxy = &proxy.Proxy{
Config: proxy.Config{
UDPListenAddr: udpAddrs(c.Addresses),
TCPListenAddr: tcpAddrs(c.Addresses),
UpstreamConfig: &proxy.UpstreamConfig{
Upstreams: upstreams,
},
UseDNS64: c.UseDNS64,
DNS64Prefs: c.DNS64Prefixes,
},
}
err = svc.proxy.Init()
if err != nil {
return nil, fmt.Errorf("proxy: %w", err)
}
return svc, nil
}
// addressesToUpstreams is a wrapper around [upstream.AddressToUpstream]. It
// accepts a slice of addresses and other upstream parameters, and returns a
// slice of upstreams.
func addressesToUpstreams(
upsStrs []string,
bootstraps []string,
timeout time.Duration,
preferIPv6 bool,
) (upstreams []upstream.Upstream, boots []*upstream.UpstreamResolver, err error) {
opts := &upstream.Options{
Timeout: timeout,
PreferIPv6: preferIPv6,
}
boots, err = aghnet.ParseBootstraps(bootstraps, opts)
if err != nil {
// Don't wrap the error, since it's informative enough as is.
return nil, nil, err
}
// TODO(e.burkov): Add system hosts resolver here.
var bootstrap upstream.ParallelResolver
for _, r := range boots {
bootstrap = append(bootstrap, r)
}
upstreams = make([]upstream.Upstream, len(upsStrs))
for i, upsStr := range upsStrs {
upstreams[i], err = upstream.AddressToUpstream(upsStr, &upstream.Options{
Bootstrap: bootstrap,
Timeout: timeout,
PreferIPv6: preferIPv6,
})
if err != nil {
return nil, boots, fmt.Errorf("upstream at index %d: %w", i, err)
}
}
return upstreams, boots, nil
}
// tcpAddrs converts []netip.AddrPort into []*net.TCPAddr.
func tcpAddrs(addrPorts []netip.AddrPort) (tcpAddrs []*net.TCPAddr) {
if addrPorts == nil {
return nil
}
tcpAddrs = make([]*net.TCPAddr, len(addrPorts))
for i, a := range addrPorts {
tcpAddrs[i] = net.TCPAddrFromAddrPort(a)
}
return tcpAddrs
}
// udpAddrs converts []netip.AddrPort into []*net.UDPAddr.
func udpAddrs(addrPorts []netip.AddrPort) (udpAddrs []*net.UDPAddr) {
if addrPorts == nil {
return nil
}
udpAddrs = make([]*net.UDPAddr, len(addrPorts))
for i, a := range addrPorts {
udpAddrs[i] = net.UDPAddrFromAddrPort(a)
}
return udpAddrs
}
// type check
var _ agh.Service = (*Service)(nil)
// Start implements the [agh.Service] interface for *Service. svc may be nil.
// After Start exits, all DNS servers have tried to start, but there is no
// guarantee that they did. Errors from the servers are written to the log.
func (svc *Service) Start() (err error) {
if svc == nil {
return nil
}
defer func() {
// TODO(a.garipov): [proxy.Proxy.Start] doesn't actually have any way to
// tell when all servers are actually up, so at best this is merely an
// assumption.
svc.running.Store(err == nil)
}()
return svc.proxy.Start()
}
// Shutdown implements the [agh.Service] interface for *Service. svc may be
// nil.
func (svc *Service) Shutdown(ctx context.Context) (err error) {
if svc == nil {
return nil
}
errs := []error{
svc.proxy.Stop(),
}
for _, b := range svc.bootstrapResolvers {
errs = append(errs, errors.Annotate(b.Close(), "closing bootstrap %s: %w", b.Address()))
}
return errors.Join(errs...)
}
// Config returns the current configuration of the web service. Config must not
// be called simultaneously with Start. If svc was initialized with ":0"
// addresses, addrs will not return the actual bound ports until Start is
// finished.
func (svc *Service) Config() (c *Config) {
// TODO(a.garipov): Do we need to get the TCP addresses separately?
var addrs []netip.AddrPort
if svc.running.Load() {
udpAddrs := svc.proxy.Addrs(proxy.ProtoUDP)
addrs = make([]netip.AddrPort, len(udpAddrs))
for i, a := range udpAddrs {
addrs[i] = a.(*net.UDPAddr).AddrPort()
}
} else {
conf := svc.proxy.Config
udpAddrs := conf.UDPListenAddr
addrs = make([]netip.AddrPort, len(udpAddrs))
for i, a := range udpAddrs {
addrs[i] = a.AddrPort()
}
}
c = &Config{
Addresses: addrs,
BootstrapServers: svc.bootstraps,
UpstreamServers: svc.upstreams,
DNS64Prefixes: svc.dns64Prefixes,
UpstreamTimeout: svc.upsTimeout,
BootstrapPreferIPv6: svc.bootstrapPreferIPv6,
UseDNS64: svc.useDNS64,
}
return c
}

View File

@@ -0,0 +1,125 @@
package dnssvc_test
import (
"context"
"net/netip"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestMain(m *testing.M) {
testutil.DiscardLogOutput(m)
}
// testTimeout is the common timeout for tests.
const testTimeout = 1 * time.Second
func TestService(t *testing.T) {
const (
listenAddr = "127.0.0.1:0"
bootstrapAddr = "127.0.0.1:0"
upstreamAddr = "upstream.example"
)
upstreamErrCh := make(chan error, 1)
upstreamStartedCh := make(chan struct{})
upstreamSrv := &dns.Server{
Addr: bootstrapAddr,
Net: "udp",
Handler: dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
pt := testutil.PanicT{}
resp := (&dns.Msg{}).SetReply(req)
resp.Answer = append(resp.Answer, &dns.A{
Hdr: dns.RR_Header{},
A: netip.MustParseAddrPort(bootstrapAddr).Addr().AsSlice(),
})
writeErr := w.WriteMsg(resp)
require.NoError(pt, writeErr)
}),
NotifyStartedFunc: func() { close(upstreamStartedCh) },
}
go func() {
listenErr := upstreamSrv.ListenAndServe()
if listenErr != nil {
// Log these immediately to see what happens.
t.Logf("upstream listen error: %s", listenErr)
}
upstreamErrCh <- listenErr
}()
_, _ = testutil.RequireReceive(t, upstreamStartedCh, testTimeout)
c := &dnssvc.Config{
Addresses: []netip.AddrPort{netip.MustParseAddrPort(listenAddr)},
BootstrapServers: []string{upstreamSrv.PacketConn.LocalAddr().String()},
UpstreamServers: []string{upstreamAddr},
DNS64Prefixes: nil,
UpstreamTimeout: testTimeout,
BootstrapPreferIPv6: false,
UseDNS64: false,
}
svc, err := dnssvc.New(c)
require.NoError(t, err)
err = svc.Start()
require.NoError(t, err)
gotConf := svc.Config()
require.NotNil(t, gotConf)
require.Len(t, gotConf.Addresses, 1)
addr := gotConf.Addresses[0]
t.Run("dns", func(t *testing.T) {
req := &dns.Msg{
MsgHdr: dns.MsgHdr{
Id: dns.Id(),
RecursionDesired: true,
},
Question: []dns.Question{{
Name: "example.com.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}},
}
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
cli := &dns.Client{}
var resp *dns.Msg
require.Eventually(t, func() (ok bool) {
var excErr error
resp, _, excErr = cli.ExchangeContext(ctx, req, addr.String())
return excErr == nil
}, testTimeout, testTimeout/10)
assert.NotNil(t, resp)
})
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
err = svc.Shutdown(ctx)
require.NoError(t, err)
err = upstreamSrv.Shutdown()
require.NoError(t, err)
err, ok := testutil.RequireReceive(t, upstreamErrCh, testTimeout)
require.True(t, ok)
require.NoError(t, err)
}

View File

@@ -0,0 +1,79 @@
package websvc
import (
"crypto/tls"
"io/fs"
"net/netip"
"time"
)
// Config is the AdGuard Home web service configuration structure.
type Config struct {
// Pprof is the configuration for the pprof debug API. It must not be nil.
Pprof *PprofConfig
// ConfigManager is used to show information about services as well as
// dynamically reconfigure them.
ConfigManager ConfigManager
// Frontend is the filesystem with the frontend and other statically
// compiled files.
Frontend fs.FS
// TLS is the optional TLS configuration. If TLS is not nil,
// SecureAddresses must not be empty.
TLS *tls.Config
// Start is the time of start of AdGuard Home.
Start time.Time
// OverrideAddress is the initial or override address for the HTTP API. If
// set, it is used instead of [Addresses] and [SecureAddresses].
OverrideAddress netip.AddrPort
// Addresses are the addresses on which to serve the plain HTTP API.
Addresses []netip.AddrPort
// SecureAddresses are the addresses on which to serve the HTTPS API. If
// SecureAddresses is not empty, TLS must not be nil.
SecureAddresses []netip.AddrPort
// Timeout is the timeout for all server operations.
Timeout time.Duration
// ForceHTTPS tells if all requests to Addresses should be redirected to a
// secure address instead.
//
// TODO(a.garipov): Use; define rules, which address to redirect to.
ForceHTTPS bool
}
// PprofConfig is the configuration for the pprof debug API.
type PprofConfig struct {
Port uint16 `yaml:"port"`
Enabled bool `yaml:"enabled"`
}
// Config returns the current configuration of the web service. Config must not
// be called simultaneously with Start. If svc was initialized with ":0"
// addresses, addrs will not return the actual bound ports until Start is
// finished.
func (svc *Service) Config() (c *Config) {
c = &Config{
Pprof: &PprofConfig{
Port: svc.pprofPort,
Enabled: svc.pprof != nil,
},
ConfigManager: svc.confMgr,
TLS: svc.tls,
// Leave Addresses and SecureAddresses empty and get the actual
// addresses that include the :0 ones later.
Start: svc.start,
Timeout: svc.timeout,
ForceHTTPS: svc.forceHTTPS,
}
c.Addresses, c.SecureAddresses = svc.addrs()
return c
}

View File

@@ -0,0 +1,97 @@
package websvc
import (
"encoding/json"
"fmt"
"net/http"
"net/netip"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
)
// DNS Settings Handlers
// ReqPatchSettingsDNS describes the request to the PATCH /api/v1/settings/dns
// HTTP API.
type ReqPatchSettingsDNS struct {
// TODO(a.garipov): Add more as we go.
Addresses []netip.AddrPort `json:"addresses"`
BootstrapServers []string `json:"bootstrap_servers"`
UpstreamServers []string `json:"upstream_servers"`
DNS64Prefixes []netip.Prefix `json:"dns64_prefixes"`
UpstreamTimeout aghhttp.JSONDuration `json:"upstream_timeout"`
BootstrapPreferIPv6 bool `json:"bootstrap_prefer_ipv6"`
UseDNS64 bool `json:"use_dns64"`
}
// HTTPAPIDNSSettings are the DNS settings as used by the HTTP API. See the
// DnsSettings object in the OpenAPI specification.
type HTTPAPIDNSSettings struct {
// TODO(a.garipov): Add more as we go.
Addresses []netip.AddrPort `json:"addresses"`
BootstrapServers []string `json:"bootstrap_servers"`
UpstreamServers []string `json:"upstream_servers"`
DNS64Prefixes []netip.Prefix `json:"dns64_prefixes"`
UpstreamTimeout aghhttp.JSONDuration `json:"upstream_timeout"`
BootstrapPreferIPv6 bool `json:"bootstrap_prefer_ipv6"`
UseDNS64 bool `json:"use_dns64"`
}
// handlePatchSettingsDNS is the handler for the PATCH /api/v1/settings/dns HTTP
// API.
func (svc *Service) handlePatchSettingsDNS(w http.ResponseWriter, r *http.Request) {
req := &ReqPatchSettingsDNS{
Addresses: []netip.AddrPort{},
BootstrapServers: []string{},
UpstreamServers: []string{},
}
// TODO(a.garipov): Validate nulls and proper JSON patch.
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
aghhttp.WriteJSONResponseError(w, r, fmt.Errorf("decoding: %w", err))
return
}
newConf := &dnssvc.Config{
Addresses: req.Addresses,
BootstrapServers: req.BootstrapServers,
UpstreamServers: req.UpstreamServers,
DNS64Prefixes: req.DNS64Prefixes,
UpstreamTimeout: time.Duration(req.UpstreamTimeout),
BootstrapPreferIPv6: req.BootstrapPreferIPv6,
UseDNS64: req.UseDNS64,
}
ctx := r.Context()
err = svc.confMgr.UpdateDNS(ctx, newConf)
if err != nil {
aghhttp.WriteJSONResponseError(w, r, fmt.Errorf("updating: %w", err))
return
}
newSvc := svc.confMgr.DNS()
err = newSvc.Start()
if err != nil {
aghhttp.WriteJSONResponseError(w, r, fmt.Errorf("starting new service: %w", err))
return
}
aghhttp.WriteJSONResponseOK(w, r, &HTTPAPIDNSSettings{
Addresses: newConf.Addresses,
BootstrapServers: newConf.BootstrapServers,
UpstreamServers: newConf.UpstreamServers,
DNS64Prefixes: newConf.DNS64Prefixes,
UpstreamTimeout: aghhttp.JSONDuration(newConf.UpstreamTimeout),
BootstrapPreferIPv6: newConf.BootstrapPreferIPv6,
UseDNS64: newConf.UseDNS64,
})
}

View File

@@ -0,0 +1,75 @@
package websvc_test
import (
"context"
"encoding/json"
"net/http"
"net/netip"
"net/url"
"sync/atomic"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
"github.com/AdguardTeam/AdGuardHome/internal/next/websvc"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestService_HandlePatchSettingsDNS(t *testing.T) {
wantDNS := &websvc.HTTPAPIDNSSettings{
Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.1.1:53")},
BootstrapServers: []string{"1.0.0.1"},
UpstreamServers: []string{"1.1.1.1"},
DNS64Prefixes: []netip.Prefix{netip.MustParsePrefix("1234::/64")},
UpstreamTimeout: aghhttp.JSONDuration(2 * time.Second),
BootstrapPreferIPv6: true,
UseDNS64: true,
}
var started atomic.Bool
confMgr := newConfigManager()
confMgr.onDNS = func() (s agh.ServiceWithConfig[*dnssvc.Config]) {
return &aghtest.ServiceWithConfig[*dnssvc.Config]{
OnStart: func() (err error) {
started.Store(true)
return nil
},
OnShutdown: func(_ context.Context) (err error) { panic("not implemented") },
OnConfig: func() (c *dnssvc.Config) { panic("not implemented") },
}
}
confMgr.onUpdateDNS = func(ctx context.Context, c *dnssvc.Config) (err error) {
return nil
}
_, addr := newTestServer(t, confMgr)
u := &url.URL{
Scheme: "http",
Host: addr.String(),
Path: websvc.PathV1SettingsDNS,
}
req := jobj{
"addresses": wantDNS.Addresses,
"bootstrap_servers": wantDNS.BootstrapServers,
"upstream_servers": wantDNS.UpstreamServers,
"dns64_prefixes": wantDNS.DNS64Prefixes,
"upstream_timeout": wantDNS.UpstreamTimeout,
"bootstrap_prefer_ipv6": wantDNS.BootstrapPreferIPv6,
"use_dns64": wantDNS.UseDNS64,
}
respBody := httpPatch(t, u, req, http.StatusOK)
resp := &websvc.HTTPAPIDNSSettings{}
err := json.Unmarshal(respBody, resp)
require.NoError(t, err)
assert.True(t, started.Load())
assert.Equal(t, wantDNS, resp)
assert.Equal(t, wantDNS, resp)
}

View File

@@ -0,0 +1,123 @@
package websvc
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/netip"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
"github.com/AdguardTeam/golibs/log"
)
// HTTP Settings Handlers
// ReqPatchSettingsHTTP describes the request to the PATCH /api/v1/settings/http
// HTTP API.
type ReqPatchSettingsHTTP struct {
// TODO(a.garipov): Add more as we go.
//
// TODO(a.garipov): Add wait time.
Addresses []netip.AddrPort `json:"addresses"`
SecureAddresses []netip.AddrPort `json:"secure_addresses"`
Timeout aghhttp.JSONDuration `json:"timeout"`
}
// HTTPAPIHTTPSettings are the HTTP settings as used by the HTTP API. See the
// HttpSettings object in the OpenAPI specification.
type HTTPAPIHTTPSettings struct {
// TODO(a.garipov): Add more as we go.
Addresses []netip.AddrPort `json:"addresses"`
SecureAddresses []netip.AddrPort `json:"secure_addresses"`
Timeout aghhttp.JSONDuration `json:"timeout"`
ForceHTTPS bool `json:"force_https"`
}
// handlePatchSettingsHTTP is the handler for the PATCH /api/v1/settings/http
// HTTP API.
func (svc *Service) handlePatchSettingsHTTP(w http.ResponseWriter, r *http.Request) {
req := &ReqPatchSettingsHTTP{}
// TODO(a.garipov): Validate nulls and proper JSON patch.
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
aghhttp.WriteJSONResponseError(w, r, fmt.Errorf("decoding: %w", err))
return
}
newConf := &Config{
Pprof: &PprofConfig{
Port: svc.pprofPort,
Enabled: svc.pprof != nil,
},
ConfigManager: svc.confMgr,
Frontend: svc.frontend,
TLS: svc.tls,
Addresses: req.Addresses,
SecureAddresses: req.SecureAddresses,
Timeout: time.Duration(req.Timeout),
ForceHTTPS: svc.forceHTTPS,
}
aghhttp.WriteJSONResponseOK(w, r, &HTTPAPIHTTPSettings{
Addresses: newConf.Addresses,
SecureAddresses: newConf.SecureAddresses,
Timeout: aghhttp.JSONDuration(newConf.Timeout),
ForceHTTPS: newConf.ForceHTTPS,
})
cancelUpd := func() {}
updCtx := context.Background()
ctx := r.Context()
if deadline, ok := ctx.Deadline(); ok {
updCtx, cancelUpd = context.WithDeadline(updCtx, deadline)
}
// Launch the new HTTP service in a separate goroutine to let this handler
// finish and thus, this server to shutdown.
go svc.relaunch(updCtx, cancelUpd, newConf)
}
// relaunch updates the web service in the configuration manager and starts it.
// It is intended to be used as a goroutine.
func (svc *Service) relaunch(ctx context.Context, cancel context.CancelFunc, newConf *Config) {
defer log.OnPanic("websvc: relaunching")
defer cancel()
err := svc.confMgr.UpdateWeb(ctx, newConf)
if err != nil {
log.Error("websvc: updating web: %s", err)
return
}
// TODO(a.garipov): Consider better ways to do this.
const maxUpdDur = 5 * time.Second
updStart := time.Now()
var newSvc agh.ServiceWithConfig[*Config]
for newSvc = svc.confMgr.Web(); newSvc == svc; {
if time.Since(updStart) >= maxUpdDur {
log.Error("websvc: failed to update svc after %s", maxUpdDur)
return
}
log.Debug("websvc: waiting for new websvc to be configured")
time.Sleep(100 * time.Millisecond)
}
err = newSvc.Start()
if err != nil {
log.Error("websvc: new svc failed to start with error: %s", err)
}
}

View File

@@ -0,0 +1,66 @@
package websvc_test
import (
"context"
"crypto/tls"
"encoding/json"
"net/http"
"net/netip"
"net/url"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
"github.com/AdguardTeam/AdGuardHome/internal/next/websvc"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestService_HandlePatchSettingsHTTP(t *testing.T) {
wantWeb := &websvc.HTTPAPIHTTPSettings{
Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.1.1:80")},
SecureAddresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.1.1:443")},
Timeout: aghhttp.JSONDuration(10 * time.Second),
ForceHTTPS: false,
}
svc, err := websvc.New(&websvc.Config{
Pprof: &websvc.PprofConfig{
Enabled: false,
},
TLS: &tls.Config{
Certificates: []tls.Certificate{{}},
},
Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:80")},
SecureAddresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:443")},
Timeout: 5 * time.Second,
ForceHTTPS: true,
})
require.NoError(t, err)
confMgr := newConfigManager()
confMgr.onWeb = func() (s agh.ServiceWithConfig[*websvc.Config]) { return svc }
confMgr.onUpdateWeb = func(ctx context.Context, c *websvc.Config) (err error) { return nil }
_, addr := newTestServer(t, confMgr)
u := &url.URL{
Scheme: "http",
Host: addr.String(),
Path: websvc.PathV1SettingsHTTP,
}
req := jobj{
"addresses": wantWeb.Addresses,
"secure_addresses": wantWeb.SecureAddresses,
"timeout": wantWeb.Timeout,
"force_https": wantWeb.ForceHTTPS,
}
respBody := httpPatch(t, u, req, http.StatusOK)
resp := &websvc.HTTPAPIHTTPSettings{}
err = json.Unmarshal(respBody, resp)
require.NoError(t, err)
assert.Equal(t, wantWeb, resp)
}

View File

@@ -0,0 +1,38 @@
package websvc
import (
"net/http"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/golibs/httphdr"
"github.com/AdguardTeam/golibs/log"
)
// Middlewares
// jsonMw sets the content type of the response to application/json.
func jsonMw(h http.Handler) (wrapped http.HandlerFunc) {
f := func(w http.ResponseWriter, r *http.Request) {
w.Header().Set(httphdr.ContentType, aghhttp.HdrValApplicationJSON)
h.ServeHTTP(w, r)
}
return http.HandlerFunc(f)
}
// logMw logs the queries with level debug.
func logMw(h http.Handler) (wrapped http.HandlerFunc) {
f := func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
m, u := r.Method, r.RequestURI
log.Debug("websvc: %s %s started", m, u)
defer func() { log.Debug("websvc: %s %s finished in %s", m, u, time.Since(start)) }()
h.ServeHTTP(w, r)
}
return http.HandlerFunc(f)
}

View File

@@ -0,0 +1,14 @@
package websvc
// Path constants
const (
PathRoot = "/"
PathFrontend = "/*filepath"
PathHealthCheck = "/health-check"
PathV1SettingsAll = "/api/v1/settings/all"
PathV1SettingsDNS = "/api/v1/settings/dns"
PathV1SettingsHTTP = "/api/v1/settings/http"
PathV1SystemInfo = "/api/v1/system/info"
)

View File

@@ -0,0 +1,47 @@
package websvc
import (
"net/http"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
)
// All Settings Handlers
// RespGetV1SettingsAll describes the response of the GET /api/v1/settings/all
// HTTP API.
type RespGetV1SettingsAll struct {
// TODO(a.garipov): Add more as we go.
DNS *HTTPAPIDNSSettings `json:"dns"`
HTTP *HTTPAPIHTTPSettings `json:"http"`
}
// handleGetSettingsAll is the handler for the GET /api/v1/settings/all HTTP
// API.
func (svc *Service) handleGetSettingsAll(w http.ResponseWriter, r *http.Request) {
dnsSvc := svc.confMgr.DNS()
dnsConf := dnsSvc.Config()
webSvc := svc.confMgr.Web()
httpConf := webSvc.Config()
// TODO(a.garipov): Add all currently supported parameters.
aghhttp.WriteJSONResponseOK(w, r, &RespGetV1SettingsAll{
DNS: &HTTPAPIDNSSettings{
Addresses: dnsConf.Addresses,
BootstrapServers: dnsConf.BootstrapServers,
UpstreamServers: dnsConf.UpstreamServers,
DNS64Prefixes: dnsConf.DNS64Prefixes,
UpstreamTimeout: aghhttp.JSONDuration(dnsConf.UpstreamTimeout),
BootstrapPreferIPv6: dnsConf.BootstrapPreferIPv6,
UseDNS64: dnsConf.UseDNS64,
},
HTTP: &HTTPAPIHTTPSettings{
Addresses: httpConf.Addresses,
SecureAddresses: httpConf.SecureAddresses,
Timeout: aghhttp.JSONDuration(httpConf.Timeout),
ForceHTTPS: httpConf.ForceHTTPS,
},
})
}

View File

@@ -0,0 +1,84 @@
package websvc_test
import (
"crypto/tls"
"encoding/json"
"net/http"
"net/netip"
"net/url"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
"github.com/AdguardTeam/AdGuardHome/internal/next/websvc"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestService_HandleGetSettingsAll(t *testing.T) {
// TODO(a.garipov): Add all currently supported parameters.
wantDNS := &websvc.HTTPAPIDNSSettings{
Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:53")},
BootstrapServers: []string{"94.140.14.140", "94.140.14.141"},
UpstreamServers: []string{"94.140.14.14", "1.1.1.1"},
UpstreamTimeout: aghhttp.JSONDuration(1 * time.Second),
BootstrapPreferIPv6: true,
}
wantWeb := &websvc.HTTPAPIHTTPSettings{
Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:80")},
SecureAddresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:443")},
Timeout: aghhttp.JSONDuration(5 * time.Second),
ForceHTTPS: true,
}
confMgr := newConfigManager()
confMgr.onDNS = func() (s agh.ServiceWithConfig[*dnssvc.Config]) {
c, err := dnssvc.New(&dnssvc.Config{
Addresses: wantDNS.Addresses,
UpstreamServers: wantDNS.UpstreamServers,
BootstrapServers: wantDNS.BootstrapServers,
UpstreamTimeout: time.Duration(wantDNS.UpstreamTimeout),
BootstrapPreferIPv6: true,
})
require.NoError(t, err)
return c
}
svc, err := websvc.New(&websvc.Config{
Pprof: &websvc.PprofConfig{
Enabled: false,
},
TLS: &tls.Config{
Certificates: []tls.Certificate{{}},
},
Addresses: wantWeb.Addresses,
SecureAddresses: wantWeb.SecureAddresses,
Timeout: time.Duration(wantWeb.Timeout),
ForceHTTPS: true,
})
require.NoError(t, err)
confMgr.onWeb = func() (s agh.ServiceWithConfig[*websvc.Config]) {
return svc
}
_, addr := newTestServer(t, confMgr)
u := &url.URL{
Scheme: "http",
Host: addr.String(),
Path: websvc.PathV1SettingsAll,
}
body := httpGet(t, u, http.StatusOK)
resp := &websvc.RespGetV1SettingsAll{}
err = json.Unmarshal(body, resp)
require.NoError(t, err)
assert.Equal(t, wantDNS, resp.DNS)
assert.Equal(t, wantWeb, resp.HTTP)
}

View File

@@ -0,0 +1,36 @@
package websvc
import (
"net/http"
"runtime"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/version"
)
// System Handlers
// RespGetV1SystemInfo describes the response of the GET /api/v1/system/info
// HTTP API.
type RespGetV1SystemInfo struct {
Arch string `json:"arch"`
Channel string `json:"channel"`
OS string `json:"os"`
NewVersion string `json:"new_version,omitempty"`
Start aghhttp.JSONTime `json:"start"`
Version string `json:"version"`
}
// handleGetV1SystemInfo is the handler for the GET /api/v1/system/info HTTP
// API.
func (svc *Service) handleGetV1SystemInfo(w http.ResponseWriter, r *http.Request) {
aghhttp.WriteJSONResponseOK(w, r, &RespGetV1SystemInfo{
Arch: runtime.GOARCH,
Channel: version.Channel(),
OS: runtime.GOOS,
// TODO(a.garipov): Fill this when we have an updater.
NewVersion: "",
Start: aghhttp.JSONTime(svc.start),
Version: version.Version(),
})
}

View File

@@ -0,0 +1,37 @@
package websvc_test
import (
"encoding/json"
"net/http"
"net/url"
"runtime"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/next/websvc"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestService_handleGetV1SystemInfo(t *testing.T) {
confMgr := newConfigManager()
_, addr := newTestServer(t, confMgr)
u := &url.URL{
Scheme: "http",
Host: addr.String(),
Path: websvc.PathV1SystemInfo,
}
body := httpGet(t, u, http.StatusOK)
resp := &websvc.RespGetV1SystemInfo{}
err := json.Unmarshal(body, resp)
require.NoError(t, err)
// TODO(a.garipov): Consider making version.Channel and version.Version
// testable and test these better.
assert.NotEmpty(t, resp.Channel)
assert.Equal(t, resp.Arch, runtime.GOARCH)
assert.Equal(t, resp.OS, runtime.GOOS)
assert.Equal(t, testStart, time.Time(resp.Start))
}

View File

@@ -0,0 +1,31 @@
package websvc
import (
"net"
"sync"
)
// Wait Listener
// waitListener is a wrapper around a listener that also calls wg.Done() on the
// first call to Accept. It is useful in situations where it is important to
// catch the precise moment of the first call to Accept, for example when
// starting an HTTP server.
//
// TODO(a.garipov): Move to aghnet?
type waitListener struct {
net.Listener
firstAcceptWG *sync.WaitGroup
firstAcceptOnce sync.Once
}
// type check
var _ net.Listener = (*waitListener)(nil)
// Accept implements the [net.Listener] interface for *waitListener.
func (l *waitListener) Accept() (conn net.Conn, err error) {
l.firstAcceptOnce.Do(l.firstAcceptWG.Done)
return l.Listener.Accept()
}

View File

@@ -0,0 +1,40 @@
package websvc
import (
"net"
"sync"
"sync/atomic"
"testing"
"github.com/AdguardTeam/golibs/testutil/fakenet"
"github.com/stretchr/testify/assert"
)
func TestWaitListener_Accept(t *testing.T) {
var accepted atomic.Bool
var l net.Listener = &fakenet.Listener{
OnAccept: func() (conn net.Conn, err error) {
accepted.Store(true)
return nil, nil
},
OnAddr: func() (addr net.Addr) { panic("not implemented") },
OnClose: func() (err error) { panic("not implemented") },
}
wg := &sync.WaitGroup{}
wg.Add(1)
go func() {
var wrapper net.Listener = &waitListener{
Listener: l,
firstAcceptWG: wg,
}
_, _ = wrapper.Accept()
}()
wg.Wait()
assert.Eventually(t, accepted.Load, testTimeout, testTimeout/10)
}

View File

@@ -0,0 +1,329 @@
// Package websvc contains the AdGuard Home HTTP API service.
//
// NOTE: Packages other than cmd must not import this package, as it imports
// most other packages.
//
// TODO(a.garipov): Add tests.
package websvc
import (
"context"
"crypto/tls"
"fmt"
"io"
"io/fs"
"net"
"net/http"
"net/netip"
"runtime"
"sync"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/mathutil"
"github.com/AdguardTeam/golibs/pprofutil"
httptreemux "github.com/dimfeld/httptreemux/v5"
)
// ConfigManager is the configuration manager interface.
type ConfigManager interface {
DNS() (svc agh.ServiceWithConfig[*dnssvc.Config])
Web() (svc agh.ServiceWithConfig[*Config])
UpdateDNS(ctx context.Context, c *dnssvc.Config) (err error)
UpdateWeb(ctx context.Context, c *Config) (err error)
}
// Service is the AdGuard Home web service. A nil *Service is a valid
// [agh.Service] that does nothing.
type Service struct {
confMgr ConfigManager
frontend fs.FS
tls *tls.Config
pprof *http.Server
start time.Time
overrideAddr netip.AddrPort
servers []*http.Server
timeout time.Duration
pprofPort uint16
forceHTTPS bool
}
// New returns a new properly initialized *Service. If c is nil, svc is a nil
// *Service that does nothing. The fields of c must not be modified after
// calling New.
//
// TODO(a.garipov): Get rid of this special handling of nil or explain it
// better.
func New(c *Config) (svc *Service, err error) {
if c == nil {
return nil, nil
}
svc = &Service{
confMgr: c.ConfigManager,
frontend: c.Frontend,
tls: c.TLS,
start: c.Start,
overrideAddr: c.OverrideAddress,
timeout: c.Timeout,
forceHTTPS: c.ForceHTTPS,
}
mux := newMux(svc)
if svc.overrideAddr != (netip.AddrPort{}) {
svc.servers = []*http.Server{newSrv(svc.overrideAddr, nil, mux, c.Timeout)}
} else {
for _, a := range c.Addresses {
svc.servers = append(svc.servers, newSrv(a, nil, mux, c.Timeout))
}
for _, a := range c.SecureAddresses {
svc.servers = append(svc.servers, newSrv(a, c.TLS, mux, c.Timeout))
}
}
svc.setupPprof(c.Pprof)
return svc, nil
}
// setupPprof sets the pprof properties of svc.
func (svc *Service) setupPprof(c *PprofConfig) {
if !c.Enabled {
// Set to zero explicitly in case pprof used to be enabled before a
// reconfiguration took place.
runtime.SetBlockProfileRate(0)
runtime.SetMutexProfileFraction(0)
return
}
runtime.SetBlockProfileRate(1)
runtime.SetMutexProfileFraction(1)
pprofMux := http.NewServeMux()
pprofutil.RoutePprof(pprofMux)
svc.pprofPort = c.Port
addr := netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), c.Port)
// TODO(a.garipov): Consider making pprof timeout configurable.
svc.pprof = newSrv(addr, nil, pprofMux, 10*time.Minute)
}
// newSrv returns a new *http.Server with the given parameters.
func newSrv(
addr netip.AddrPort,
tlsConf *tls.Config,
h http.Handler,
timeout time.Duration,
) (srv *http.Server) {
addrStr := addr.String()
srv = &http.Server{
Addr: addrStr,
Handler: h,
TLSConfig: tlsConf,
ReadTimeout: timeout,
WriteTimeout: timeout,
IdleTimeout: timeout,
ReadHeaderTimeout: timeout,
}
if tlsConf == nil {
srv.ErrorLog = log.StdLog("websvc: plain http: "+addrStr, log.ERROR)
} else {
srv.ErrorLog = log.StdLog("websvc: https: "+addrStr, log.ERROR)
}
return srv
}
// newMux returns a new HTTP request multiplexer for the AdGuard Home web
// service.
func newMux(svc *Service) (mux *httptreemux.ContextMux) {
mux = httptreemux.NewContextMux()
routes := []struct {
handler http.HandlerFunc
method string
pattern string
isJSON bool
}{{
handler: svc.handleGetHealthCheck,
method: http.MethodGet,
pattern: PathHealthCheck,
isJSON: false,
}, {
handler: http.FileServer(http.FS(svc.frontend)).ServeHTTP,
method: http.MethodGet,
pattern: PathFrontend,
isJSON: false,
}, {
handler: http.FileServer(http.FS(svc.frontend)).ServeHTTP,
method: http.MethodGet,
pattern: PathRoot,
isJSON: false,
}, {
handler: svc.handleGetSettingsAll,
method: http.MethodGet,
pattern: PathV1SettingsAll,
isJSON: true,
}, {
handler: svc.handlePatchSettingsDNS,
method: http.MethodPatch,
pattern: PathV1SettingsDNS,
isJSON: true,
}, {
handler: svc.handlePatchSettingsHTTP,
method: http.MethodPatch,
pattern: PathV1SettingsHTTP,
isJSON: true,
}, {
handler: svc.handleGetV1SystemInfo,
method: http.MethodGet,
pattern: PathV1SystemInfo,
isJSON: true,
}}
for _, r := range routes {
var hdlr http.Handler
if r.isJSON {
hdlr = jsonMw(r.handler)
} else {
hdlr = r.handler
}
mux.Handle(r.method, r.pattern, logMw(hdlr))
}
return mux
}
// addrs returns all addresses on which this server serves the HTTP API. addrs
// must not be called simultaneously with Start. If svc was initialized with
// ":0" addresses, addrs will not return the actual bound ports until Start is
// finished.
func (svc *Service) addrs() (addrs, secureAddrs []netip.AddrPort) {
if svc.overrideAddr != (netip.AddrPort{}) {
return []netip.AddrPort{svc.overrideAddr}, nil
}
for _, srv := range svc.servers {
// Use MustParseAddrPort, since no errors should technically happen
// here, because all servers must have a valid address.
addrPort := netip.MustParseAddrPort(srv.Addr)
// [srv.Serve] will set TLSConfig to an almost empty value, so, instead
// of relying only on the nilness of TLSConfig, check the length of the
// certificates field as well.
if srv.TLSConfig == nil || len(srv.TLSConfig.Certificates) == 0 {
addrs = append(addrs, addrPort)
} else {
secureAddrs = append(secureAddrs, addrPort)
}
}
return addrs, secureAddrs
}
// handleGetHealthCheck is the handler for the GET /health-check HTTP API.
func (svc *Service) handleGetHealthCheck(w http.ResponseWriter, _ *http.Request) {
_, _ = io.WriteString(w, "OK")
}
// type check
var _ agh.Service = (*Service)(nil)
// Start implements the [agh.Service] interface for *Service. svc may be nil.
// After Start exits, all HTTP servers have tried to start, possibly failing and
// writing error messages to the log.
func (svc *Service) Start() (err error) {
if svc == nil {
return nil
}
pprofEnabled := svc.pprof != nil
srvNum := len(svc.servers) + mathutil.BoolToNumber[int](pprofEnabled)
wg := &sync.WaitGroup{}
wg.Add(srvNum)
for _, srv := range svc.servers {
go serve(srv, wg)
}
if pprofEnabled {
go serve(svc.pprof, wg)
}
wg.Wait()
return nil
}
// serve starts and runs srv and writes all errors into its log.
func serve(srv *http.Server, wg *sync.WaitGroup) {
addr := srv.Addr
defer log.OnPanic(addr)
var proto string
var l net.Listener
var err error
if srv.TLSConfig == nil {
proto = "http"
l, err = net.Listen("tcp", addr)
} else {
proto = "https"
l, err = tls.Listen("tcp", addr, srv.TLSConfig)
}
if err != nil {
srv.ErrorLog.Printf("starting srv %s: binding: %s", addr, err)
}
// Update the server's address in case the address had the port zero, which
// would mean that a random available port was automatically chosen.
srv.Addr = l.Addr().String()
log.Info("websvc: starting srv %s://%s", proto, srv.Addr)
l = &waitListener{
Listener: l,
firstAcceptWG: wg,
}
err = srv.Serve(l)
if err != nil && !errors.Is(err, http.ErrServerClosed) {
srv.ErrorLog.Printf("starting srv %s: %s", addr, err)
}
}
// Shutdown implements the [agh.Service] interface for *Service. svc may be
// nil.
func (svc *Service) Shutdown(ctx context.Context) (err error) {
if svc == nil {
return nil
}
defer func() { err = errors.Annotate(err, "shutting down: %w") }()
var errs []error
for _, srv := range svc.servers {
shutdownErr := srv.Shutdown(ctx)
if shutdownErr != nil {
errs = append(errs, fmt.Errorf("srv %s: %w", srv.Addr, shutdownErr))
}
}
if svc.pprof != nil {
shutdownErr := svc.pprof.Shutdown(ctx)
if shutdownErr != nil {
errs = append(errs, fmt.Errorf("pprof srv %s: %w", svc.pprof.Addr, shutdownErr))
}
}
return errors.Join(errs...)
}

View File

@@ -0,0 +1,6 @@
package websvc
import "time"
// testTimeout is the common timeout for tests.
const testTimeout = 1 * time.Second

View File

@@ -0,0 +1,196 @@
package websvc_test
import (
"bytes"
"context"
"encoding/json"
"io"
"io/fs"
"net/http"
"net/netip"
"net/url"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
"github.com/AdguardTeam/AdGuardHome/internal/next/websvc"
"github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/golibs/testutil/fakefs"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestMain(m *testing.M) {
testutil.DiscardLogOutput(m)
}
// testTimeout is the common timeout for tests.
const testTimeout = 1 * time.Second
// testStart is the server start value for tests.
var testStart = time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC)
// type check
var _ websvc.ConfigManager = (*configManager)(nil)
// configManager is a [websvc.ConfigManager] for tests.
type configManager struct {
onDNS func() (svc agh.ServiceWithConfig[*dnssvc.Config])
onWeb func() (svc agh.ServiceWithConfig[*websvc.Config])
onUpdateDNS func(ctx context.Context, c *dnssvc.Config) (err error)
onUpdateWeb func(ctx context.Context, c *websvc.Config) (err error)
}
// DNS implements the [websvc.ConfigManager] interface for *configManager.
func (m *configManager) DNS() (svc agh.ServiceWithConfig[*dnssvc.Config]) {
return m.onDNS()
}
// Web implements the [websvc.ConfigManager] interface for *configManager.
func (m *configManager) Web() (svc agh.ServiceWithConfig[*websvc.Config]) {
return m.onWeb()
}
// UpdateDNS implements the [websvc.ConfigManager] interface for *configManager.
func (m *configManager) UpdateDNS(ctx context.Context, c *dnssvc.Config) (err error) {
return m.onUpdateDNS(ctx, c)
}
// UpdateWeb implements the [websvc.ConfigManager] interface for *configManager.
func (m *configManager) UpdateWeb(ctx context.Context, c *websvc.Config) (err error) {
return m.onUpdateWeb(ctx, c)
}
// newConfigManager returns a *configManager all methods of which panic.
func newConfigManager() (m *configManager) {
return &configManager{
onDNS: func() (svc agh.ServiceWithConfig[*dnssvc.Config]) { panic("not implemented") },
onWeb: func() (svc agh.ServiceWithConfig[*websvc.Config]) { panic("not implemented") },
onUpdateDNS: func(_ context.Context, _ *dnssvc.Config) (err error) {
panic("not implemented")
},
onUpdateWeb: func(_ context.Context, _ *websvc.Config) (err error) {
panic("not implemented")
},
}
}
// newTestServer creates and starts a new web service instance as well as its
// sole address. It also registers a cleanup procedure, which shuts the
// instance down.
//
// TODO(a.garipov): Use svc or remove it.
func newTestServer(
t testing.TB,
confMgr websvc.ConfigManager,
) (svc *websvc.Service, addr netip.AddrPort) {
t.Helper()
c := &websvc.Config{
Pprof: &websvc.PprofConfig{
Enabled: false,
},
ConfigManager: confMgr,
Frontend: &fakefs.FS{
OnOpen: func(_ string) (_ fs.File, _ error) { return nil, fs.ErrNotExist },
},
TLS: nil,
Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:0")},
SecureAddresses: nil,
Timeout: testTimeout,
Start: testStart,
ForceHTTPS: false,
}
svc, err := websvc.New(c)
require.NoError(t, err)
err = svc.Start()
require.NoError(t, err)
t.Cleanup(func() {
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
t.Cleanup(cancel)
err = svc.Shutdown(ctx)
require.NoError(t, err)
})
c = svc.Config()
require.NotNil(t, c)
require.Len(t, c.Addresses, 1)
return svc, c.Addresses[0]
}
// jobj is a utility alias for JSON objects.
type jobj map[string]any
// httpGet is a helper that performs an HTTP GET request and returns the body of
// the response as well as checks that the status code is correct.
//
// TODO(a.garipov): Add helpers for other methods.
func httpGet(t testing.TB, u *url.URL, wantCode int) (body []byte) {
t.Helper()
req, err := http.NewRequest(http.MethodGet, u.String(), nil)
require.NoErrorf(t, err, "creating req")
httpCli := &http.Client{
Timeout: testTimeout,
}
resp, err := httpCli.Do(req)
require.NoErrorf(t, err, "performing req")
require.Equal(t, wantCode, resp.StatusCode)
testutil.CleanupAndRequireSuccess(t, resp.Body.Close)
body, err = io.ReadAll(resp.Body)
require.NoErrorf(t, err, "reading body")
return body
}
// httpPatch is a helper that performs an HTTP PATCH request with JSON-encoded
// reqBody as the request body and returns the body of the response as well as
// checks that the status code is correct.
//
// TODO(a.garipov): Add helpers for other methods.
func httpPatch(t testing.TB, u *url.URL, reqBody any, wantCode int) (body []byte) {
t.Helper()
b, err := json.Marshal(reqBody)
require.NoErrorf(t, err, "marshaling reqBody")
req, err := http.NewRequest(http.MethodPatch, u.String(), bytes.NewReader(b))
require.NoErrorf(t, err, "creating req")
httpCli := &http.Client{
Timeout: testTimeout,
}
resp, err := httpCli.Do(req)
require.NoErrorf(t, err, "performing req")
require.Equal(t, wantCode, resp.StatusCode)
testutil.CleanupAndRequireSuccess(t, resp.Body.Close)
body, err = io.ReadAll(resp.Body)
require.NoErrorf(t, err, "reading body")
return body
}
func TestService_Start_getHealthCheck(t *testing.T) {
confMgr := newConfigManager()
_, addr := newTestServer(t, confMgr)
u := &url.URL{
Scheme: "http",
Host: addr.String(),
Path: websvc.PathHealthCheck,
}
body := httpGet(t, u, http.StatusOK)
assert.Equal(t, []byte("OK"), body)
}