all: sync with master, upd chlog

This commit is contained in:
Eugene Burkov
2025-03-11 13:36:04 +03:00
parent 805de59805
commit 474cba52f0
166 changed files with 8809 additions and 10440 deletions

View File

@@ -356,7 +356,7 @@ func (a *Auth) getCurrentUser(r *http.Request) (u webUser) {
// There's no Cookie, check Basic authentication.
user, pass, ok := r.BasicAuth()
if ok {
u, _ = Context.auth.findUser(user, pass)
u, _ = globalContext.auth.findUser(user, pass)
return u
}
@@ -408,13 +408,12 @@ func (a *Auth) authRequired() bool {
// bytes of sessionTokenSize length.
//
// TODO(e.burkov): Think about using byte array instead of byte slice.
func newSessionToken() (data []byte, err error) {
func newSessionToken() (data []byte) {
randData := make([]byte, sessionTokenSize)
_, err = rand.Read(randData)
if err != nil {
return nil, err
}
// Since Go 1.24, crypto/rand.Read doesn't return an error and crashes
// unrecoverably instead.
_, _ = rand.Read(randData)
return randData, nil
return randData
}

View File

@@ -1,8 +1,6 @@
package home
import (
"bytes"
"crypto/rand"
"encoding/hex"
"path/filepath"
"testing"
@@ -12,23 +10,6 @@ import (
"github.com/stretchr/testify/require"
)
func TestNewSessionToken(t *testing.T) {
// Successful case.
token, err := newSessionToken()
require.NoError(t, err)
assert.Len(t, token, sessionTokenSize)
// Break the rand.Reader.
prevReader := rand.Reader
t.Cleanup(func() { rand.Reader = prevReader })
rand.Reader = &bytes.Buffer{}
// Unsuccessful case.
token, err = newSessionToken()
require.Error(t, err)
assert.Empty(t, token)
}
func TestAuth(t *testing.T) {
dir := t.TempDir()
fn := filepath.Join(dir, "sessions.db")
@@ -47,8 +28,7 @@ func TestAuth(t *testing.T) {
assert.Equal(t, checkSessionNotFound, a.checkSession("notfound"))
a.removeSession("notfound")
sess, err := newSessionToken()
require.NoError(t, err)
sess := newSessionToken()
sessStr := hex.EncodeToString(sess)
now := time.Now().UTC().Unix()

View File

@@ -47,11 +47,7 @@ func (a *Auth) newCookie(req loginJSON, addr string) (c *http.Cookie, err error)
rateLimiter.remove(addr)
}
sess, err := newSessionToken()
if err != nil {
return nil, fmt.Errorf("generating token: %w", err)
}
sess := newSessionToken()
now := time.Now().UTC()
a.addSession(sess, &session{
@@ -155,7 +151,7 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
return
}
if rateLimiter := Context.auth.rateLimiter; rateLimiter != nil {
if rateLimiter := globalContext.auth.rateLimiter; rateLimiter != nil {
if left := rateLimiter.check(remoteIP); left > 0 {
w.Header().Set(httphdr.RetryAfter, strconv.Itoa(int(left.Seconds())))
writeErrorWithIP(
@@ -176,10 +172,10 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
log.Error("auth: getting real ip from request with remote ip %s: %s", remoteIP, err)
}
cookie, err := Context.auth.newCookie(req, remoteIP)
cookie, err := globalContext.auth.newCookie(req, remoteIP)
if err != nil {
logIP := remoteIP
if Context.auth.trustedProxies.Contains(ip.Unmap()) {
if globalContext.auth.trustedProxies.Contains(ip.Unmap()) {
logIP = ip.String()
}
@@ -213,7 +209,7 @@ func handleLogout(w http.ResponseWriter, r *http.Request) {
return
}
Context.auth.removeSession(c.Value)
globalContext.auth.removeSession(c.Value)
c = &http.Cookie{
Name: sessionCookieName,
@@ -232,7 +228,7 @@ func handleLogout(w http.ResponseWriter, r *http.Request) {
// RegisterAuthHandlers - register handlers
func RegisterAuthHandlers() {
Context.mux.Handle("/control/login", postInstallHandler(ensureHandler(http.MethodPost, handleLogin)))
globalContext.mux.Handle("/control/login", postInstallHandler(ensureHandler(http.MethodPost, handleLogin)))
httpRegister(http.MethodGet, "/control/logout", handleLogout)
}
@@ -254,13 +250,13 @@ func optionalAuthThird(w http.ResponseWriter, r *http.Request) (mustAuth bool) {
// Check Basic authentication.
user, pass, hasBasic := r.BasicAuth()
if hasBasic {
_, isAuthenticated = Context.auth.findUser(user, pass)
_, isAuthenticated = globalContext.auth.findUser(user, pass)
if !isAuthenticated {
log.Info("%s: invalid basic authorization value", pref)
}
}
} else {
res := Context.auth.checkSession(cookie.Value)
res := globalContext.auth.checkSession(cookie.Value)
isAuthenticated = res == checkSessionOK
if !isAuthenticated {
log.Debug("%s: invalid cookie value: %q", pref, cookie)
@@ -294,12 +290,12 @@ func optionalAuth(
) (wrapped func(http.ResponseWriter, *http.Request)) {
return func(w http.ResponseWriter, r *http.Request) {
p := r.URL.Path
authRequired := Context.auth != nil && Context.auth.authRequired()
authRequired := globalContext.auth != nil && globalContext.auth.authRequired()
if p == "/login.html" {
cookie, err := r.Cookie(sessionCookieName)
if authRequired && err == nil {
// Redirect to the dashboard if already authenticated.
res := Context.auth.checkSession(cookie.Value)
res := globalContext.auth.checkSession(cookie.Value)
if res == checkSessionOK {
http.Redirect(w, r, "", http.StatusFound)

View File

@@ -39,7 +39,7 @@ func TestAuthHTTP(t *testing.T) {
users := []webUser{
{Name: "name", PasswordHash: "$2y$05$..vyzAECIhJPfaQiOK17IukcQnqEgKJHy0iETyYqxn3YXJl8yZuo2"},
}
Context.auth = InitAuth(fn, users, 60, nil, nil)
globalContext.auth = InitAuth(fn, users, 60, nil, nil)
handlerCalled := false
handler := func(_ http.ResponseWriter, _ *http.Request) {
@@ -68,7 +68,7 @@ func TestAuthHTTP(t *testing.T) {
assert.True(t, handlerCalled)
// perform login
cookie, err := Context.auth.newCookie(loginJSON{Name: "name", Password: "password"}, "")
cookie, err := globalContext.auth.newCookie(loginJSON{Name: "name", Password: "password"}, "")
require.NoError(t, err)
require.NotNil(t, cookie)
@@ -114,7 +114,7 @@ func TestAuthHTTP(t *testing.T) {
assert.True(t, handlerCalled)
r.Header.Del(httphdr.Cookie)
Context.auth.Close()
globalContext.auth.Close()
}
func TestRealIP(t *testing.T) {

View File

@@ -12,17 +12,14 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/arpdb"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
"github.com/AdguardTeam/AdGuardHome/internal/whois"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/AdguardTeam/golibs/timeutil"
)
// clientsContainer is the storage of all runtime and persistent clients.
@@ -75,6 +72,7 @@ func (clients *clientsContainer) Init(
etcHosts *aghnet.HostsContainer,
arpDB arpdb.Interface,
filteringConf *filtering.Config,
sigHdlr *signalHandler,
) (err error) {
// TODO(s.chzhen): Refactor it.
if clients.storage != nil {
@@ -109,6 +107,7 @@ func (clients *clientsContainer) Init(
clients.storage, err = client.NewStorage(ctx, &client.StorageConfig{
Logger: baseLogger.With(slogutil.KeyPrefix, "client_storage"),
Clock: timeutil.SystemClock{},
InitialClients: confClients,
DHCP: dhcpServer,
EtcHosts: hosts,
@@ -120,6 +119,8 @@ func (clients *clientsContainer) Init(
return fmt.Errorf("init client storage: %w", err)
}
sigHdlr.addClientStorage(clients.storage)
return nil
}
@@ -370,63 +371,6 @@ func (clients *clientsContainer) shouldCountClient(ids []string) (y bool) {
return true
}
// type check
var _ dnsforward.ClientsContainer = (*clientsContainer)(nil)
// UpstreamConfigByID implements the [dnsforward.ClientsContainer] interface for
// *clientsContainer. upsConf is nil if the client isn't found or if the client
// has no custom upstreams.
func (clients *clientsContainer) UpstreamConfigByID(
id string,
bootstrap upstream.Resolver,
) (conf *proxy.CustomUpstreamConfig, err error) {
clients.lock.Lock()
defer clients.lock.Unlock()
c, ok := clients.storage.Find(id)
if !ok {
return nil, nil
} else if c.UpstreamConfig != nil {
return c.UpstreamConfig, nil
}
upstreams := stringutil.FilterOut(c.Upstreams, dnsforward.IsCommentOrEmpty)
if len(upstreams) == 0 {
return nil, nil
}
var upsConf *proxy.UpstreamConfig
upsConf, err = proxy.ParseUpstreamsConfig(
upstreams,
&upstream.Options{
Bootstrap: bootstrap,
Timeout: time.Duration(config.DNS.UpstreamTimeout),
HTTPVersions: dnsforward.UpstreamHTTPVersions(config.DNS.UseHTTP3Upstreams),
PreferIPv6: config.DNS.BootstrapPreferIPv6,
},
)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return nil, err
}
conf = proxy.NewCustomUpstreamConfig(
upsConf,
c.UpstreamsCacheEnabled,
int(c.UpstreamsCacheSize),
config.DNS.EDNSClientSubnet.Enabled,
)
c.UpstreamConfig = conf
// TODO(s.chzhen): Pass context.
err = clients.storage.Update(context.TODO(), c.Name, c)
if err != nil {
return nil, fmt.Errorf("setting upstream config: %w", err)
}
return conf, nil
}
// type check
var _ client.AddressUpdater = (*clientsContainer)(nil)

View File

@@ -1,15 +1,12 @@
package home
import (
"net"
"net/netip"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -31,34 +28,10 @@ func newClientsContainer(t *testing.T) (c *clientsContainer) {
nil,
nil,
&filtering.Config{},
newSignalHandler(nil, nil),
)
require.NoError(t, err)
return c
}
func TestClientsCustomUpstream(t *testing.T) {
clients := newClientsContainer(t)
ctx := testutil.ContextWithTimeout(t, testTimeout)
// Add client with upstreams.
err := clients.storage.Add(ctx, &client.Persistent{
Name: "client1",
UID: client.MustNewUID(),
IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1"), netip.MustParseAddr("1:2:3::4")},
Upstreams: []string{
"1.1.1.1",
"[/example.org/]8.8.8.8",
},
})
require.NoError(t, err)
upsConf, err := clients.UpstreamConfigByID("1.2.3.4", net.DefaultResolver)
assert.Nil(t, upsConf)
assert.NoError(t, err)
upsConf, err = clients.UpstreamConfigByID("1.1.1.1", net.DefaultResolver)
require.NotNil(t, upsConf)
assert.NoError(t, err)
}

View File

@@ -486,9 +486,9 @@ var config = &configuration{
// configFilePath returns the absolute path to the symlink-evaluated path to the
// current config file.
func configFilePath() (confPath string) {
confPath, err := filepath.EvalSymlinks(Context.confFilePath)
confPath, err := filepath.EvalSymlinks(globalContext.confFilePath)
if err != nil {
confPath = Context.confFilePath
confPath = globalContext.confFilePath
logFunc := log.Error
if errors.Is(err, os.ErrNotExist) {
logFunc = log.Debug
@@ -498,7 +498,7 @@ func configFilePath() (confPath string) {
}
if !filepath.IsAbs(confPath) {
confPath = filepath.Join(Context.workDir, confPath)
confPath = filepath.Join(globalContext.workDir, confPath)
}
return confPath
@@ -530,8 +530,8 @@ func parseConfig() (err error) {
}
migrator := configmigrate.New(&configmigrate.Config{
WorkingDir: Context.workDir,
DataDir: Context.getDataDir(),
WorkingDir: globalContext.workDir,
DataDir: globalContext.getDataDir(),
})
var upgraded bool
@@ -640,31 +640,31 @@ func readConfigFile() (fileData []byte, err error) {
}
// Saves configuration to the YAML file and also saves the user filter contents to a file
func (c *configuration) write() (err error) {
func (c *configuration) write(tlsMgr *tlsManager) (err error) {
c.Lock()
defer c.Unlock()
if Context.auth != nil {
config.Users = Context.auth.usersList()
if globalContext.auth != nil {
config.Users = globalContext.auth.usersList()
}
if Context.tls != nil {
if tlsMgr != nil {
tlsConf := tlsConfigSettings{}
Context.tls.WriteDiskConfig(&tlsConf)
tlsMgr.WriteDiskConfig(&tlsConf)
config.TLS = tlsConf
}
if Context.stats != nil {
if globalContext.stats != nil {
statsConf := stats.Config{}
Context.stats.WriteDiskConfig(&statsConf)
globalContext.stats.WriteDiskConfig(&statsConf)
config.Stats.Interval = timeutil.Duration(statsConf.Limit)
config.Stats.Enabled = statsConf.Enabled
config.Stats.Ignored = statsConf.Ignored.Values()
}
if Context.queryLog != nil {
if globalContext.queryLog != nil {
dc := querylog.Config{}
Context.queryLog.WriteDiskConfig(&dc)
globalContext.queryLog.WriteDiskConfig(&dc)
config.DNS.AnonymizeClientIP = dc.AnonymizeClientIP
config.QueryLog.Enabled = dc.Enabled
config.QueryLog.FileEnabled = dc.FileEnabled
@@ -673,14 +673,14 @@ func (c *configuration) write() (err error) {
config.QueryLog.Ignored = dc.Ignored.Values()
}
if Context.filters != nil {
Context.filters.WriteDiskConfig(config.Filtering)
if globalContext.filters != nil {
globalContext.filters.WriteDiskConfig(config.Filtering)
config.Filters = config.Filtering.Filters
config.WhitelistFilters = config.Filtering.WhitelistFilters
config.UserRules = config.Filtering.UserRules
}
if s := Context.dnsServer; s != nil {
if s := globalContext.dnsServer; s != nil {
c := dnsforward.Config{}
s.WriteDiskConfig(&c)
dns := &config.DNS
@@ -695,11 +695,11 @@ func (c *configuration) write() (err error) {
dns.UpstreamTimeout = timeutil.Duration(s.UpstreamTimeout())
}
if Context.dhcpServer != nil {
Context.dhcpServer.WriteDiskConfig(config.DHCP)
if globalContext.dhcpServer != nil {
globalContext.dhcpServer.WriteDiskConfig(config.DHCP)
}
config.Clients.Persistent = Context.clients.forConfig()
config.Clients.Persistent = globalContext.clients.forConfig()
confPath := configFilePath()
log.Debug("writing config file %q", confPath)
@@ -726,14 +726,14 @@ func setContextTLSCipherIDs() (err error) {
if len(config.TLS.OverrideTLSCiphers) == 0 {
log.Info("tls: using default ciphers")
Context.tlsCipherIDs = aghtls.SaferCipherSuites()
globalContext.tlsCipherIDs = aghtls.SaferCipherSuites()
return nil
}
log.Info("tls: overriding ciphers: %s", config.TLS.OverrideTLSCiphers)
Context.tlsCipherIDs, err = aghtls.ParseCiphers(config.TLS.OverrideTLSCiphers)
globalContext.tlsCipherIDs, err = aghtls.ParseCiphers(config.TLS.OverrideTLSCiphers)
if err != nil {
return fmt.Errorf("parsing override ciphers: %w", err)
}

View File

@@ -14,7 +14,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/version"
"github.com/AdguardTeam/golibs/httphdr"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/netutil/urlutil"
"github.com/NYTimes/gziphandler"
@@ -69,7 +68,8 @@ func appendDNSAddrsWithIfaces(dst []string, src []netip.Addr) (res []string, err
// collectDNSAddresses returns the list of DNS addresses the server is listening
// on, including the addresses on all interfaces in cases of unspecified IPs.
func collectDNSAddresses() (addrs []string, err error) {
// tlsMgr must not be nil.
func collectDNSAddresses(tlsMgr *tlsManager) (addrs []string, err error) {
if hosts := config.DNS.BindHosts; len(hosts) == 0 {
addrs = appendDNSAddrs(addrs, netutil.IPv4Localhost())
} else {
@@ -79,7 +79,7 @@ func collectDNSAddresses() (addrs []string, err error) {
}
}
de := getDNSEncryption()
de := getDNSEncryption(tlsMgr)
if de.https != "" {
addrs = append(addrs, de.https)
}
@@ -114,8 +114,8 @@ type statusResponse struct {
IsRunning bool `json:"running"`
}
func handleStatus(w http.ResponseWriter, r *http.Request) {
dnsAddrs, err := collectDNSAddresses()
func (web *webAPI) handleStatus(w http.ResponseWriter, r *http.Request) {
dnsAddrs, err := collectDNSAddresses(web.tlsManager)
if err != nil {
// Don't add a lot of formatting, since the error is already
// wrapped by collectDNSAddresses.
@@ -129,10 +129,10 @@ func handleStatus(w http.ResponseWriter, r *http.Request) {
protectionDisabledUntil *time.Time
protectionEnabled bool
)
if Context.dnsServer != nil {
if globalContext.dnsServer != nil {
fltConf = &dnsforward.Config{}
Context.dnsServer.WriteDiskConfig(fltConf)
protectionEnabled, protectionDisabledUntil = Context.dnsServer.UpdatedProtectionStatus()
globalContext.dnsServer.WriteDiskConfig(fltConf)
protectionEnabled, protectionDisabledUntil = globalContext.dnsServer.UpdatedProtectionStatus()
}
var resp statusResponse
@@ -162,42 +162,42 @@ func handleStatus(w http.ResponseWriter, r *http.Request) {
// IsDHCPAvailable field is now false by default for Windows.
if runtime.GOOS != "windows" {
resp.IsDHCPAvailable = Context.dhcpServer != nil
resp.IsDHCPAvailable = globalContext.dhcpServer != nil
}
aghhttp.WriteJSONResponseOK(w, r, resp)
}
// ------------------------
// registration of handlers
// ------------------------
// registerControlHandlers sets up HTTP handlers for various control endpoints.
// web must not be nil.
func registerControlHandlers(web *webAPI) {
Context.mux.HandleFunc(
globalContext.mux.HandleFunc(
"/control/version.json",
postInstall(optionalAuth(web.handleVersionJSON)),
)
httpRegister(http.MethodPost, "/control/update", web.handleUpdate)
httpRegister(http.MethodGet, "/control/status", handleStatus)
httpRegister(http.MethodGet, "/control/status", web.handleStatus)
httpRegister(http.MethodPost, "/control/i18n/change_language", handleI18nChangeLanguage)
httpRegister(http.MethodGet, "/control/i18n/current_language", handleI18nCurrentLanguage)
httpRegister(http.MethodGet, "/control/profile", handleGetProfile)
httpRegister(http.MethodPut, "/control/profile/update", handlePutProfile)
// No auth is necessary for DoH/DoT configurations
Context.mux.HandleFunc("/apple/doh.mobileconfig", postInstall(handleMobileConfigDoH))
Context.mux.HandleFunc("/apple/dot.mobileconfig", postInstall(handleMobileConfigDoT))
globalContext.mux.HandleFunc("/apple/doh.mobileconfig", postInstall(handleMobileConfigDoH))
globalContext.mux.HandleFunc("/apple/dot.mobileconfig", postInstall(handleMobileConfigDoT))
RegisterAuthHandlers()
}
// httpRegister registers an HTTP handler.
func httpRegister(method, url string, handler http.HandlerFunc) {
if method == "" {
// "/dns-query" handler doesn't need auth, gzip and isn't restricted by 1 HTTP method
Context.mux.HandleFunc(url, postInstall(handler))
globalContext.mux.HandleFunc(url, postInstall(handler))
return
}
Context.mux.Handle(url, postInstallHandler(optionalAuthHandler(gziphandler.GzipHandler(ensureHandler(method, handler)))))
globalContext.mux.Handle(url, postInstallHandler(optionalAuthHandler(gziphandler.GzipHandler(ensureHandler(method, handler)))))
}
// ensure returns a wrapped handler that makes sure that the request has the
@@ -207,11 +207,7 @@ func ensure(
handler func(http.ResponseWriter, *http.Request),
) (wrapped func(http.ResponseWriter, *http.Request)) {
return func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
m, u := r.Method, r.URL
log.Debug("started %s %s %s", m, r.Host, u)
defer func() { log.Debug("finished %s %s %s in %s", m, r.Host, u, time.Since(start)) }()
m := r.Method
if m != method {
aghhttp.Error(r, w, http.StatusMethodNotAllowed, "only method %s is allowed", method)
@@ -223,8 +219,8 @@ func ensure(
return
}
Context.controlLock.Lock()
defer Context.controlLock.Unlock()
globalContext.controlLock.Lock()
defer globalContext.controlLock.Unlock()
}
handler(w, r)
@@ -293,7 +289,7 @@ func ensureHandler(method string, handler func(http.ResponseWriter, *http.Reques
// preInstall lets the handler run only if firstRun is true, no redirects
func preInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
if !Context.firstRun {
if !globalContext.firstRun {
// if it's not first run, don't let users access it (for example /install.html when configuration is done)
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
return
@@ -320,7 +316,7 @@ func preInstallHandler(handler http.Handler) http.Handler {
// HTTPS-related headers. If proceed is true, the middleware must continue
// handling the request.
func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (proceed bool) {
web := Context.web
web := globalContext.web
if web.httpsServer.server == nil {
return true
}
@@ -409,7 +405,7 @@ func httpsURL(u *url.URL, host string, portHTTPS uint16) (redirectURL *url.URL)
func postInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
path := r.URL.Path
if Context.firstRun && !strings.HasPrefix(path, "/install.") &&
if globalContext.firstRun && !strings.HasPrefix(path, "/install.") &&
!strings.HasPrefix(path, "/assets/") {
http.Redirect(w, r, "install.html", http.StatusFound)

View File

@@ -428,20 +428,20 @@ func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request
curConfig := &configuration{}
copyInstallSettings(curConfig, config)
Context.firstRun = false
globalContext.firstRun = false
config.DNS.BindHosts = []netip.Addr{req.DNS.IP}
config.DNS.Port = req.DNS.Port
config.Filtering.SafeFSPatterns = []string{
filepath.Join(Context.workDir, userFilterDataDir, "*"),
filepath.Join(globalContext.workDir, userFilterDataDir, "*"),
}
config.HTTPConfig.Address = netip.AddrPortFrom(req.Web.IP, req.Web.Port)
u := &webUser{
Name: req.Username,
}
err = Context.auth.addUser(u, req.Password)
err = globalContext.auth.addUser(u, req.Password)
if err != nil {
Context.firstRun = true
globalContext.firstRun = true
copyInstallSettings(config, curConfig)
aghhttp.Error(r, w, http.StatusUnprocessableEntity, "%s", err)
@@ -452,18 +452,18 @@ func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request
// moment we'll allow setting up TLS in the initial configuration or the
// configuration itself will use HTTPS protocol, because the underlying
// functions potentially restart the HTTPS server.
err = startMods(web.baseLogger)
err = startMods(r.Context(), web.baseLogger, web.tlsManager)
if err != nil {
Context.firstRun = true
globalContext.firstRun = true
copyInstallSettings(config, curConfig)
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
return
}
err = config.write()
err = config.write(web.tlsManager)
if err != nil {
Context.firstRun = true
globalContext.firstRun = true
copyInstallSettings(config, curConfig)
aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't write config: %s", err)
@@ -527,8 +527,33 @@ func decodeApplyConfigReq(r io.Reader) (req *applyConfigReq, restartHTTP bool, e
return req, restartHTTP, err
}
func (web *webAPI) registerInstallHandlers() {
Context.mux.HandleFunc("/control/install/get_addresses", preInstall(ensureGET(web.handleInstallGetAddresses)))
Context.mux.HandleFunc("/control/install/check_config", preInstall(ensurePOST(web.handleInstallCheckConfig)))
Context.mux.HandleFunc("/control/install/configure", preInstall(ensurePOST(web.handleInstallConfigure)))
// startMods initializes and starts the DNS server after installation.
// baseLogger and tlsMgr must not be nil.
func startMods(ctx context.Context, baseLogger *slog.Logger, tlsMgr *tlsManager) (err error) {
statsDir, querylogDir, err := checkStatsAndQuerylogDirs(&globalContext, config)
if err != nil {
return err
}
err = initDNS(baseLogger, tlsMgr, statsDir, querylogDir)
if err != nil {
return err
}
tlsMgr.start(ctx)
err = startDNSServer()
if err != nil {
closeDNSServer()
return err
}
return nil
}
func (web *webAPI) registerInstallHandlers() {
globalContext.mux.HandleFunc("/control/install/get_addresses", preInstall(ensureGET(web.handleInstallGetAddresses)))
globalContext.mux.HandleFunc("/control/install/check_config", preInstall(ensurePOST(web.handleInstallCheckConfig)))
globalContext.mux.HandleFunc("/control/install/configure", preInstall(ensurePOST(web.handleInstallConfigure)))
}

View File

@@ -62,7 +62,7 @@ func (web *webAPI) handleVersionJSON(w http.ResponseWriter, r *http.Request) {
return
}
err = resp.setAllowedToAutoUpdate()
err = resp.setAllowedToAutoUpdate(web.tlsManager)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
@@ -158,14 +158,14 @@ type versionResponse struct {
}
// setAllowedToAutoUpdate sets CanAutoUpdate to true if AdGuard Home is actually
// allowed to perform an automatic update by the OS.
func (vr *versionResponse) setAllowedToAutoUpdate() (err error) {
// allowed to perform an automatic update by the OS. tlsMgr must not be nil.
func (vr *versionResponse) setAllowedToAutoUpdate(tlsMgr *tlsManager) (err error) {
if vr.CanAutoUpdate != aghalg.NBTrue {
return nil
}
tlsConf := &tlsConfigSettings{}
Context.tls.WriteDiskConfig(tlsConf)
tlsMgr.WriteDiskConfig(tlsConf)
canUpdate := true
if tlsConfUsesPrivilegedPorts(tlsConf) ||

View File

@@ -39,16 +39,22 @@ const (
// Called by other modules when configuration is changed
func onConfigModified() {
err := config.write()
err := config.write(globalContext.tls)
if err != nil {
log.Error("writing config: %s", err)
}
}
// initDNS updates all the fields of the [Context] needed to initialize the DNS
// server and initializes it at last. It also must not be called unless
// [config] and [Context] are initialized. baseLogger must not be nil.
func initDNS(baseLogger *slog.Logger, statsDir, querylogDir string) (err error) {
// initDNS updates all the fields of the [globalContext] needed to initialize
// the DNS server and initializes it at last. It also must not be called unless
// [config] and [globalContext] are initialized. baseLogger and tlsMgr must not
// be nil.
func initDNS(
baseLogger *slog.Logger,
tlsMgr *tlsManager,
statsDir string,
querylogDir string,
) (err error) {
anonymizer := config.anonymizer()
statsConf := stats.Config{
@@ -58,7 +64,7 @@ func initDNS(baseLogger *slog.Logger, statsDir, querylogDir string) (err error)
ConfigModified: onConfigModified,
HTTPRegister: httpRegister,
Enabled: config.Stats.Enabled,
ShouldCountClient: Context.clients.shouldCountClient,
ShouldCountClient: globalContext.clients.shouldCountClient,
}
engine, err := aghnet.NewIgnoreEngine(config.Stats.Ignored)
@@ -67,7 +73,7 @@ func initDNS(baseLogger *slog.Logger, statsDir, querylogDir string) (err error)
}
statsConf.Ignored = engine
Context.stats, err = stats.New(statsConf)
globalContext.stats, err = stats.New(statsConf)
if err != nil {
return fmt.Errorf("init stats: %w", err)
}
@@ -77,7 +83,7 @@ func initDNS(baseLogger *slog.Logger, statsDir, querylogDir string) (err error)
Anonymizer: anonymizer,
ConfigModified: onConfigModified,
HTTPRegister: httpRegister,
FindClient: Context.clients.findMultiple,
FindClient: globalContext.clients.findMultiple,
BaseDir: querylogDir,
AnonymizeClientIP: config.DNS.AnonymizeClientIP,
RotationIvl: time.Duration(config.QueryLog.Interval),
@@ -92,25 +98,25 @@ func initDNS(baseLogger *slog.Logger, statsDir, querylogDir string) (err error)
}
conf.Ignored = engine
Context.queryLog, err = querylog.New(conf)
globalContext.queryLog, err = querylog.New(conf)
if err != nil {
return fmt.Errorf("init querylog: %w", err)
}
Context.filters, err = filtering.New(config.Filtering, nil)
globalContext.filters, err = filtering.New(config.Filtering, nil)
if err != nil {
// Don't wrap the error, since it's informative enough as is.
return err
}
tlsConf := &tlsConfigSettings{}
Context.tls.WriteDiskConfig(tlsConf)
tlsMgr.WriteDiskConfig(tlsConf)
return initDNSServer(
Context.filters,
Context.stats,
Context.queryLog,
Context.dhcpServer,
globalContext.filters,
globalContext.stats,
globalContext.queryLog,
globalContext.dhcpServer,
anonymizer,
httpRegister,
tlsConf,
@@ -121,7 +127,7 @@ func initDNS(baseLogger *slog.Logger, statsDir, querylogDir string) (err error)
// initDNSServer initializes the [context.dnsServer]. To only use the internal
// proxy, none of the arguments are required, but tlsConf and l still must not
// be nil, in other cases all the arguments also must not be nil. It also must
// not be called unless [config] and [Context] are initialized.
// not be called unless [config] and [globalContext] are initialized.
//
// TODO(e.burkov): Use [dnsforward.DNSCreateParams] as a parameter.
func initDNSServer(
@@ -134,7 +140,7 @@ func initDNSServer(
tlsConf *tlsConfigSettings,
l *slog.Logger,
) (err error) {
Context.dnsServer, err = dnsforward.NewServer(dnsforward.DNSCreateParams{
globalContext.dnsServer, err = dnsforward.NewServer(dnsforward.DNSCreateParams{
Logger: l,
DNSFilter: filters,
Stats: sts,
@@ -142,7 +148,7 @@ func initDNSServer(
PrivateNets: parseSubnetSet(config.DNS.PrivateNets),
Anonymizer: anonymizer,
DHCPServer: dhcpSrv,
EtcHosts: Context.etcHosts,
EtcHosts: globalContext.etcHosts,
LocalDomain: config.DHCP.LocalDomainName,
})
defer func() {
@@ -154,21 +160,27 @@ func initDNSServer(
return fmt.Errorf("dnsforward.NewServer: %w", err)
}
Context.clients.clientChecker = Context.dnsServer
globalContext.clients.clientChecker = globalContext.dnsServer
dnsConf, err := newServerConfig(&config.DNS, config.Clients.Sources, tlsConf, httpReg)
dnsConf, err := newServerConfig(
&config.DNS,
config.Clients.Sources,
tlsConf,
httpReg,
globalContext.clients.storage,
)
if err != nil {
return fmt.Errorf("newServerConfig: %w", err)
}
// Try to prepare the server with disabled private RDNS resolution if it
// failed to prepare as is. See TODO on [dnsforward.PrivateRDNSError].
err = Context.dnsServer.Prepare(dnsConf)
err = globalContext.dnsServer.Prepare(dnsConf)
if privRDNSErr := (&dnsforward.PrivateRDNSError{}); errors.As(err, &privRDNSErr) {
log.Info("WARNING: %s; trying to disable private RDNS resolution", err)
dnsConf.UsePrivateRDNS = false
err = Context.dnsServer.Prepare(dnsConf)
err = globalContext.dnsServer.Prepare(dnsConf)
}
if err != nil {
@@ -194,7 +206,7 @@ func parseSubnetSet(nets []netutil.Prefix) (s netutil.SubnetSet) {
}
func isRunning() bool {
return Context.dnsServer != nil && Context.dnsServer.IsRunning()
return globalContext.dnsServer != nil && globalContext.dnsServer.IsRunning()
}
func ipsToTCPAddrs(ips []netip.Addr, port uint16) (tcpAddrs []*net.TCPAddr) {
@@ -230,12 +242,13 @@ func newServerConfig(
clientSrcConf *clientSourcesConfig,
tlsConf *tlsConfigSettings,
httpReg aghhttp.RegisterFunc,
clientsContainer dnsforward.ClientsContainer,
) (newConf *dnsforward.ServerConfig, err error) {
hosts := aghalg.CoalesceSlice(dnsConf.BindHosts, []netip.Addr{netutil.IPv4Localhost()})
fwdConf := dnsConf.Config
fwdConf.FilterHandler = applyAdditionalFiltering
fwdConf.ClientsContainer = &Context.clients
fwdConf.ClientsContainer = clientsContainer
newConf = &dnsforward.ServerConfig{
UDPListenAddrs: ipsToUDPAddrs(hosts, dnsConf.Port),
@@ -244,7 +257,7 @@ func newServerConfig(
TLSConfig: newDNSTLSConfig(tlsConf, hosts),
TLSAllowUnencryptedDoH: tlsConf.AllowUnencryptedDoH,
UpstreamTimeout: time.Duration(dnsConf.UpstreamTimeout),
TLSv12Roots: Context.tlsRoots,
TLSv12Roots: globalContext.tlsRoots,
ConfigModified: onConfigModified,
HTTPRegister: httpReg,
LocalPTRResolvers: dnsConf.PrivateRDNSResolvers,
@@ -259,16 +272,16 @@ func newServerConfig(
var initialAddresses []netip.Addr
// Context.stats may be nil here if initDNSServer is called from
// [cmdlineUpdate].
if sts := Context.stats; sts != nil {
if sts := globalContext.stats; sts != nil {
const initialClientsNum = 100
initialAddresses = Context.stats.TopClientsIP(initialClientsNum)
initialAddresses = globalContext.stats.TopClientsIP(initialClientsNum)
}
// Do not set DialContext, PrivateSubnets, and UsePrivateRDNS, because they
// are set by [dnsforward.Server.Prepare].
newConf.AddrProcConf = &client.DefaultAddrProcConfig{
Exchanger: Context.dnsServer,
AddressUpdater: &Context.clients,
Exchanger: globalContext.dnsServer,
AddressUpdater: &globalContext.clients,
InitialAddresses: initialAddresses,
CatchPanics: true,
UseRDNS: clientSrcConf.RDNS,
@@ -350,16 +363,18 @@ func newDNSCryptConfig(
}, nil
}
// dnsEncryption contains different types of TLS encryption addresses.
type dnsEncryption struct {
https string
tls string
quic string
}
func getDNSEncryption() (de dnsEncryption) {
// getDNSEncryption returns the TLS encryption addresses that AdGuard Home
// listens on. tlsMgr must not be nil.
func getDNSEncryption(tlsMgr *tlsManager) (de dnsEncryption) {
tlsConf := tlsConfigSettings{}
Context.tls.WriteDiskConfig(&tlsConf)
tlsMgr.WriteDiskConfig(&tlsConf)
if !tlsConf.Enabled || len(tlsConf.ServerName) == 0 {
return dnsEncryption{}
@@ -402,7 +417,7 @@ func applyAdditionalFiltering(clientIP netip.Addr, clientID string, setts *filte
// pref is a prefix for logging messages around the scope.
const pref = "applying filters"
Context.filters.ApplyBlockedServices(setts)
globalContext.filters.ApplyBlockedServices(setts)
log.Debug("%s: looking for client with ip %s and clientid %q", pref, clientIP, clientID)
@@ -412,9 +427,9 @@ func applyAdditionalFiltering(clientIP netip.Addr, clientID string, setts *filte
setts.ClientIP = clientIP
c, ok := Context.clients.storage.Find(clientID)
c, ok := globalContext.clients.storage.Find(clientID)
if !ok {
c, ok = Context.clients.storage.Find(clientIP.String())
c, ok = globalContext.clients.storage.Find(clientIP.String())
if !ok {
log.Debug("%s: no clients with ip %s and clientid %q", pref, clientIP, clientID)
@@ -429,7 +444,7 @@ func applyAdditionalFiltering(clientIP netip.Addr, clientID string, setts *filte
setts.ServicesRules = nil
svcs := c.BlockedServices.IDs
if !c.BlockedServices.Schedule.Contains(time.Now()) {
Context.filters.ApplyBlockedServicesList(setts, svcs)
globalContext.filters.ApplyBlockedServicesList(setts, svcs)
log.Debug("%s: services for client %q set: %s", pref, c.Name, svcs)
}
}
@@ -455,24 +470,24 @@ func startDNSServer() error {
return fmt.Errorf("unable to start forwarding DNS server: Already running")
}
Context.filters.EnableFilters(false)
globalContext.filters.EnableFilters(false)
// TODO(s.chzhen): Pass context.
ctx := context.TODO()
err := Context.clients.Start(ctx)
err := globalContext.clients.Start(ctx)
if err != nil {
return fmt.Errorf("starting clients container: %w", err)
}
err = Context.dnsServer.Start()
err = globalContext.dnsServer.Start()
if err != nil {
return fmt.Errorf("starting dns server: %w", err)
}
Context.filters.Start()
Context.stats.Start()
globalContext.filters.Start()
globalContext.stats.Start()
err = Context.queryLog.Start(ctx)
err = globalContext.queryLog.Start(ctx)
if err != nil {
return fmt.Errorf("starting query log: %w", err)
}
@@ -480,16 +495,24 @@ func startDNSServer() error {
return nil
}
func reconfigureDNSServer() (err error) {
// reconfigureDNSServer updates the DNS server configuration using the provided
// TLS settings. tlsMgr must not be nil.
func reconfigureDNSServer(tlsMgr *tlsManager) (err error) {
tlsConf := &tlsConfigSettings{}
Context.tls.WriteDiskConfig(tlsConf)
tlsMgr.WriteDiskConfig(tlsConf)
newConf, err := newServerConfig(&config.DNS, config.Clients.Sources, tlsConf, httpRegister)
newConf, err := newServerConfig(
&config.DNS,
config.Clients.Sources,
tlsConf,
httpRegister,
globalContext.clients.storage,
)
if err != nil {
return fmt.Errorf("generating forwarding dns server config: %w", err)
}
err = Context.dnsServer.Reconfigure(newConf)
err = globalContext.dnsServer.Reconfigure(newConf)
if err != nil {
return fmt.Errorf("starting forwarding dns server: %w", err)
}
@@ -502,12 +525,12 @@ func stopDNSServer() (err error) {
return nil
}
err = Context.dnsServer.Stop()
err = globalContext.dnsServer.Stop()
if err != nil {
return fmt.Errorf("stopping forwarding dns server: %w", err)
}
err = Context.clients.close(context.TODO())
err = globalContext.clients.close(context.TODO())
if err != nil {
return fmt.Errorf("closing clients container: %w", err)
}
@@ -519,25 +542,25 @@ func stopDNSServer() (err error) {
func closeDNSServer() {
// DNS forward module must be closed BEFORE stats or queryLog because it depends on them
if Context.dnsServer != nil {
Context.dnsServer.Close()
Context.dnsServer = nil
if globalContext.dnsServer != nil {
globalContext.dnsServer.Close()
globalContext.dnsServer = nil
}
if Context.filters != nil {
Context.filters.Close()
if globalContext.filters != nil {
globalContext.filters.Close()
}
if Context.stats != nil {
err := Context.stats.Close()
if globalContext.stats != nil {
err := globalContext.stats.Close()
if err != nil {
log.Error("closing stats: %s", err)
}
}
if Context.queryLog != nil {
if globalContext.queryLog != nil {
// TODO(s.chzhen): Pass context.
err := Context.queryLog.Shutdown(context.TODO())
err := globalContext.queryLog.Shutdown(context.TODO())
if err != nil {
log.Error("closing query log: %s", err)
}

View File

@@ -37,14 +37,14 @@ func newStorage(tb testing.TB, clients []*client.Persistent) (s *client.Storage)
func TestApplyAdditionalFiltering(t *testing.T) {
var err error
Context.filters, err = filtering.New(&filtering.Config{
globalContext.filters, err = filtering.New(&filtering.Config{
BlockedServices: &filtering.BlockedServices{
Schedule: schedule.EmptyWeekly(),
},
}, nil)
require.NoError(t, err)
Context.clients.storage = newStorage(t, []*client.Persistent{{
globalContext.clients.storage = newStorage(t, []*client.Persistent{{
Name: "default",
ClientIDs: []string{"default"},
UseOwnSettings: false,
@@ -124,7 +124,7 @@ func TestApplyAdditionalFiltering_blockedServices(t *testing.T) {
err error
)
Context.filters, err = filtering.New(&filtering.Config{
globalContext.filters, err = filtering.New(&filtering.Config{
BlockedServices: &filtering.BlockedServices{
Schedule: schedule.EmptyWeekly(),
IDs: globalBlockedServices,
@@ -132,7 +132,7 @@ func TestApplyAdditionalFiltering_blockedServices(t *testing.T) {
}, nil)
require.NoError(t, err)
Context.clients.storage = newStorage(t, []*client.Persistent{{
globalContext.clients.storage = newStorage(t, []*client.Persistent{{
Name: "default",
ClientIDs: []string{"default"},
UseOwnBlockedServices: false,

View File

@@ -57,7 +57,12 @@ type homeContext struct {
auth *Auth // HTTP authentication module
filters *filtering.DNSFilter // DNS filtering module
web *webAPI // Web (HTTP, HTTPS) module
tls *tlsManager // TLS module
// tls contains the current configuration and state of TLS encryption.
//
// TODO(s.chzhen): Remove once it is no longer called from different
// modules. See [onConfigModified].
tls *tlsManager
// etcHosts contains IP-hostname mappings taken from the OS-specific hosts
// configuration files, for example /etc/hosts.
@@ -91,10 +96,10 @@ func (c *homeContext) getDataDir() string {
return filepath.Join(c.workDir, dataDir)
}
// Context - a global context object
// globalContext is a global context object.
//
// TODO(a.garipov): Refactor.
var Context homeContext
var globalContext homeContext
// Main is the entry point
func Main(clientBuildFS fs.FS) {
@@ -113,40 +118,32 @@ func Main(clientBuildFS fs.FS) {
signals := make(chan os.Signal, 1)
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT)
go func() {
ctx := context.Background()
for {
sig := <-signals
log.Info("Received signal %q", sig)
switch sig {
case syscall.SIGHUP:
Context.clients.storage.ReloadARP(ctx)
Context.tls.reload()
default:
cleanup(ctx)
cleanupAlways()
close(done)
}
}
}()
ctx := context.Background()
sigHdlr := newSignalHandler(signals, func(ctx context.Context) {
cleanup(ctx)
cleanupAlways()
close(done)
})
go sigHdlr.handle(ctx)
if opts.serviceControlAction != "" {
handleServiceControlAction(opts, clientBuildFS, signals, done)
handleServiceControlAction(opts, clientBuildFS, signals, done, sigHdlr)
return
}
// run the protection
run(opts, clientBuildFS, done)
run(opts, clientBuildFS, done, sigHdlr)
}
// setupContext initializes [Context] fields. It also reads and upgrades
// setupContext initializes [globalContext] fields. It also reads and upgrades
// config file if necessary.
func setupContext(opts options) (err error) {
Context.firstRun = detectFirstRun()
globalContext.firstRun = detectFirstRun()
Context.tlsRoots = aghtls.SystemRootCAs()
Context.mux = http.NewServeMux()
globalContext.tlsRoots = aghtls.SystemRootCAs()
globalContext.mux = http.NewServeMux()
if !opts.noEtcHosts {
err = setupHostsContainer()
@@ -156,7 +153,7 @@ func setupContext(opts options) (err error) {
}
}
if Context.firstRun {
if globalContext.firstRun {
log.Info("This is the first time AdGuard Home is launched")
checkNetworkPermissions()
@@ -247,7 +244,7 @@ func setupHostsContainer() (err error) {
return fmt.Errorf("getting default system hosts paths: %w", err)
}
Context.etcHosts, err = aghnet.NewHostsContainer(osutil.RootDirFS(), hostsWatcher, paths...)
globalContext.etcHosts, err = aghnet.NewHostsContainer(osutil.RootDirFS(), hostsWatcher, paths...)
if err != nil {
closeErr := hostsWatcher.Close()
if errors.Is(err, aghnet.ErrNoHostsPaths) {
@@ -271,14 +268,18 @@ func setupOpts(opts options) (err error) {
}
if len(opts.pidFile) != 0 && writePIDFile(opts.pidFile) {
Context.pidFileName = opts.pidFile
globalContext.pidFileName = opts.pidFile
}
return nil
}
// initContextClients initializes Context clients and related fields.
func initContextClients(ctx context.Context, logger *slog.Logger) (err error) {
func initContextClients(
ctx context.Context,
logger *slog.Logger,
sigHdlr *signalHandler,
) (err error) {
err = setupDNSFilteringConf(ctx, logger, config.Filtering)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
@@ -286,13 +287,13 @@ func initContextClients(ctx context.Context, logger *slog.Logger) (err error) {
}
//lint:ignore SA1019 Migration is not over.
config.DHCP.WorkDir = Context.workDir
config.DHCP.DataDir = Context.getDataDir()
config.DHCP.WorkDir = globalContext.workDir
config.DHCP.DataDir = globalContext.getDataDir()
config.DHCP.HTTPRegister = httpRegister
config.DHCP.ConfigModified = onConfigModified
Context.dhcpServer, err = dhcpd.Create(config.DHCP)
if Context.dhcpServer == nil || err != nil {
globalContext.dhcpServer, err = dhcpd.Create(config.DHCP)
if globalContext.dhcpServer == nil || err != nil {
// TODO(a.garipov): There are a lot of places in the code right
// now which assume that the DHCP server can be nil despite this
// condition. Inspect them and perhaps rewrite them to use
@@ -305,14 +306,15 @@ func initContextClients(ctx context.Context, logger *slog.Logger) (err error) {
arpDB = arpdb.New(logger.With(slogutil.KeyError, "arpdb"))
}
return Context.clients.Init(
return globalContext.clients.Init(
ctx,
logger,
config.Clients.Persistent,
Context.dhcpServer,
Context.etcHosts,
globalContext.dhcpServer,
globalContext.etcHosts,
arpDB,
config.Filtering,
sigHdlr,
)
}
@@ -374,15 +376,15 @@ func setupDNSFilteringConf(
pcTXTSuffix = `pc.dns.adguard.com.`
)
conf.EtcHosts = Context.etcHosts
conf.EtcHosts = globalContext.etcHosts
// TODO(s.chzhen): Use empty interface.
if Context.etcHosts == nil || !config.DNS.HostsFileEnabled {
if globalContext.etcHosts == nil || !config.DNS.HostsFileEnabled {
conf.EtcHosts = nil
}
conf.ConfigModified = onConfigModified
conf.HTTPRegister = httpRegister
conf.DataDir = Context.getDataDir()
conf.DataDir = globalContext.getDataDir()
conf.Filters = slices.Clone(config.Filters)
conf.WhitelistFilters = slices.Clone(config.WhitelistFilters)
conf.UserRules = slices.Clone(config.UserRules)
@@ -522,13 +524,15 @@ func isUpdateEnabled(ctx context.Context, l *slog.Logger, opts *options, customU
}
}
// initWeb initializes the web module. upd and baseLogger must not be nil.
// initWeb initializes the web module. upd, baseLogger, and tlsMgr must not be
// nil.
func initWeb(
ctx context.Context,
opts options,
clientBuildFS fs.FS,
upd *updater.Updater,
baseLogger *slog.Logger,
tlsMgr *tlsManager,
customURL bool,
) (web *webAPI, err error) {
logger := baseLogger.With(slogutil.KeyPrefix, "webapi")
@@ -551,6 +555,7 @@ func initWeb(
updater: upd,
logger: logger,
baseLogger: baseLogger,
tlsManager: tlsMgr,
clientFS: clientFS,
@@ -560,7 +565,7 @@ func initWeb(
ReadHeaderTimeout: readHdrTimeout,
WriteTimeout: writeTimeout,
firstRun: Context.firstRun,
firstRun: globalContext.firstRun,
disableUpdate: disableUpdate,
runningAsService: opts.runningAsService,
serveHTTP3: config.DNS.ServeHTTP3,
@@ -583,7 +588,7 @@ func fatalOnError(err error) {
// run configures and starts AdGuard Home.
//
// TODO(e.burkov): Make opts a pointer.
func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
func run(opts options, clientBuildFS fs.FS, done chan struct{}, sigHdlr *signalHandler) {
// Configure working dir.
err := initWorkingDir(opts)
fatalOnError(err)
@@ -599,10 +604,11 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
// TODO(a.garipov): Use slog everywhere.
slogLogger := newSlogLogger(ls)
sigHdlr.swapLogger(slogLogger)
// Print the first message after logger is configured.
log.Info(version.Full())
log.Debug("current working directory is %s", Context.workDir)
log.Info("%s", version.Full())
log.Debug("current working directory is %s", globalContext.workDir)
if opts.runningAsService {
log.Info("AdGuard Home is running as a service")
}
@@ -621,7 +627,7 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
// TODO(s.chzhen): Use it for the entire initialization process.
ctx := context.Background()
err = initContextClients(ctx, slogLogger)
err = initContextClients(ctx, slogLogger, sigHdlr)
fatalOnError(err)
err = setupOpts(opts)
@@ -632,15 +638,15 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
confPath := configFilePath()
upd, customURL := newUpdater(ctx, slogLogger, Context.workDir, confPath, execPath, config)
upd, customURL := newUpdater(ctx, slogLogger, globalContext.workDir, confPath, execPath, config)
// TODO(e.burkov): This could be made earlier, probably as the option's
// effect.
cmdlineUpdate(ctx, slogLogger, opts, upd)
if !Context.firstRun {
if !globalContext.firstRun {
// Save the updated config.
err = config.write()
err = config.write(nil)
fatalOnError(err)
if config.HTTPConfig.Pprof.Enabled {
@@ -648,33 +654,36 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
}
}
dataDir := Context.getDataDir()
dataDir := globalContext.getDataDir()
err = os.MkdirAll(dataDir, aghos.DefaultPermDir)
fatalOnError(errors.Annotate(err, "creating DNS data dir at %s: %w", dataDir))
GLMode = opts.glinetMode
// Init auth module.
Context.auth, err = initUsers()
globalContext.auth, err = initUsers()
fatalOnError(err)
Context.tls, err = newTLSManager(config.TLS, config.DNS.ServePlainDNS)
tlsMgr, err := newTLSManager(config.TLS, config.DNS.ServePlainDNS)
if err != nil {
log.Error("initializing tls: %s", err)
onConfigModified()
}
Context.web, err = initWeb(ctx, opts, clientBuildFS, upd, slogLogger, customURL)
globalContext.tls = tlsMgr
sigHdlr.addTLSManager(tlsMgr)
globalContext.web, err = initWeb(ctx, opts, clientBuildFS, upd, slogLogger, tlsMgr, customURL)
fatalOnError(err)
statsDir, querylogDir, err := checkStatsAndQuerylogDirs(&Context, config)
statsDir, querylogDir, err := checkStatsAndQuerylogDirs(&globalContext, config)
fatalOnError(err)
if !Context.firstRun {
err = initDNS(slogLogger, statsDir, querylogDir)
if !globalContext.firstRun {
err = initDNS(slogLogger, tlsMgr, statsDir, querylogDir)
fatalOnError(err)
Context.tls.start()
tlsMgr.start(ctx)
go func() {
startErr := startDNSServer()
@@ -684,8 +693,8 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
}
}()
if Context.dhcpServer != nil {
err = Context.dhcpServer.Start()
if globalContext.dhcpServer != nil {
err = globalContext.dhcpServer.Start()
if err != nil {
log.Error("starting dhcp server: %s", err)
}
@@ -693,10 +702,10 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
}
if !opts.noPermCheck {
checkPermissions(ctx, slogLogger, Context.workDir, confPath, dataDir, statsDir, querylogDir)
checkPermissions(ctx, slogLogger, globalContext.workDir, confPath, dataDir, statsDir, querylogDir)
}
Context.web.start(ctx)
globalContext.web.start(ctx)
// Wait for other goroutines to complete their job.
<-done
@@ -775,7 +784,7 @@ func checkPermissions(
// initUsers initializes context auth module. Clears config users field.
func initUsers() (auth *Auth, err error) {
sessFilename := filepath.Join(Context.getDataDir(), "sessions.db")
sessFilename := filepath.Join(globalContext.getDataDir(), "sessions.db")
var rateLimiter *authRateLimiter
if config.AuthAttempts > 0 && config.AuthBlockMin > 0 {
@@ -807,31 +816,6 @@ func (c *configuration) anonymizer() (ipmut *aghnet.IPMut) {
return aghnet.NewIPMut(anonFunc)
}
// startMods initializes and starts the DNS server after installation.
// baseLogger must not be nil.
func startMods(baseLogger *slog.Logger) (err error) {
statsDir, querylogDir, err := checkStatsAndQuerylogDirs(&Context, config)
if err != nil {
return err
}
err = initDNS(baseLogger, statsDir, querylogDir)
if err != nil {
return err
}
Context.tls.start()
err = startDNSServer()
if err != nil {
closeDNSServer()
return err
}
return nil
}
// checkNetworkPermissions checks if the current user permissions are enough to
// use the required networking functionality.
func checkNetworkPermissions() {
@@ -883,14 +867,14 @@ func writePIDFile(fn string) bool {
func initConfigFilename(opts options) {
confPath := opts.confFilename
if confPath == "" {
Context.confFilePath = filepath.Join(Context.workDir, "AdGuardHome.yaml")
globalContext.confFilePath = filepath.Join(globalContext.workDir, "AdGuardHome.yaml")
return
}
log.Debug("config path overridden to %q from cmdline", confPath)
Context.confFilePath = confPath
globalContext.confFilePath = confPath
}
// initWorkingDir initializes the workDir. If no command-line arguments are
@@ -904,18 +888,18 @@ func initWorkingDir(opts options) (err error) {
if opts.workDir != "" {
// If there is a custom config file, use it's directory as our working dir
Context.workDir = opts.workDir
globalContext.workDir = opts.workDir
} else {
Context.workDir = filepath.Dir(execPath)
globalContext.workDir = filepath.Dir(execPath)
}
workDir, err := filepath.EvalSymlinks(Context.workDir)
workDir, err := filepath.EvalSymlinks(globalContext.workDir)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
}
Context.workDir = workDir
globalContext.workDir = workDir
return nil
}
@@ -924,13 +908,13 @@ func initWorkingDir(opts options) (err error) {
func cleanup(ctx context.Context) {
log.Info("stopping AdGuard Home")
if Context.web != nil {
Context.web.close(ctx)
Context.web = nil
if globalContext.web != nil {
globalContext.web.close(ctx)
globalContext.web = nil
}
if Context.auth != nil {
Context.auth.Close()
Context.auth = nil
if globalContext.auth != nil {
globalContext.auth.Close()
globalContext.auth = nil
}
err := stopDNSServer()
@@ -938,28 +922,24 @@ func cleanup(ctx context.Context) {
log.Error("stopping dns server: %s", err)
}
if Context.dhcpServer != nil {
err = Context.dhcpServer.Stop()
if globalContext.dhcpServer != nil {
err = globalContext.dhcpServer.Stop()
if err != nil {
log.Error("stopping dhcp server: %s", err)
}
}
if Context.etcHosts != nil {
if err = Context.etcHosts.Close(); err != nil {
if globalContext.etcHosts != nil {
if err = globalContext.etcHosts.Close(); err != nil {
log.Error("closing hosts container: %s", err)
}
}
if Context.tls != nil {
Context.tls = nil
}
}
// This function is called before application exits
func cleanupAlways() {
if len(Context.pidFileName) != 0 {
_ = os.Remove(Context.pidFileName)
if len(globalContext.pidFileName) != 0 {
_ = os.Remove(globalContext.pidFileName)
}
log.Info("stopped")
@@ -975,7 +955,7 @@ func exitWithError() {
func loadCmdLineOpts() (opts options) {
opts, eff, err := parseCmdOpts(os.Args[0], os.Args[1:])
if err != nil {
log.Error(err.Error())
log.Error("%s", err)
printHelp(os.Args[0])
exitWithError()
@@ -984,7 +964,7 @@ func loadCmdLineOpts() (opts options) {
if eff != nil {
err = eff()
if err != nil {
log.Error(err.Error())
log.Error("%s", err)
exitWithError()
}
@@ -1005,10 +985,12 @@ func printWebAddrs(proto, addr string, port uint16) {
// printHTTPAddresses prints the IP addresses which user can use to access the
// admin interface. proto is either schemeHTTP or schemeHTTPS.
func printHTTPAddresses(proto string) {
//
// TODO(s.chzhen): Implement separate functions for HTTP and HTTPS.
func printHTTPAddresses(proto string, tlsMgr *tlsManager) {
tlsConf := tlsConfigSettings{}
if Context.tls != nil {
Context.tls.WriteDiskConfig(&tlsConf)
if tlsMgr != nil {
tlsMgr.WriteDiskConfig(&tlsConf)
}
port := config.HTTPConfig.Address.Port()
@@ -1016,7 +998,6 @@ func printHTTPAddresses(proto string) {
port = tlsConf.PortHTTPS
}
// TODO(e.burkov): Inspect and perhaps merge with the previous condition.
if proto == urlutil.SchemeHTTPS && tlsConf.ServerName != "" {
printWebAddrs(proto, tlsConf.ServerName, tlsConf.PortHTTPS)
@@ -1050,9 +1031,9 @@ func printHTTPAddresses(proto string) {
// detectFirstRun returns true if this is the first run of AdGuard Home.
func detectFirstRun() (ok bool) {
confPath := Context.confFilePath
confPath := globalContext.confFilePath
if !filepath.IsAbs(confPath) {
confPath = filepath.Join(Context.workDir, Context.confFilePath)
confPath = filepath.Join(globalContext.workDir, globalContext.confFilePath)
}
_, err := os.Stat(confPath)
@@ -1105,7 +1086,7 @@ func cmdlineUpdate(ctx context.Context, l *slog.Logger, opts options, upd *updat
os.Exit(osutil.ExitCodeSuccess)
}
err = upd.Update(Context.firstRun)
err = upd.Update(globalContext.firstRun)
fatalOnError(err)
err = restartService()

View File

@@ -17,7 +17,7 @@ func httpClient() (c *http.Client) {
// Do not use Context.dnsServer.DialContext directly in the struct literal
// below, since Context.dnsServer may be nil when this function is called.
dialContext := func(ctx context.Context, network, addr string) (conn net.Conn, err error) {
return Context.dnsServer.DialContext(ctx, network, addr)
return globalContext.dnsServer.DialContext(ctx, network, addr)
}
return &http.Client{
@@ -27,8 +27,8 @@ func httpClient() (c *http.Client) {
DialContext: dialContext,
Proxy: httpProxy,
TLSClientConfig: &tls.Config{
RootCAs: Context.tlsRoots,
CipherSuites: Context.tlsCipherIDs,
RootCAs: globalContext.tlsRoots,
CipherSuites: globalContext.tlsCipherIDs,
MinVersion: tls.VersionTLS12,
},
},

View File

@@ -66,7 +66,7 @@ func configureLogger(ls *logSettings) (err error) {
logFilePath := ls.File
if !filepath.IsAbs(logFilePath) {
logFilePath = filepath.Join(Context.workDir, logFilePath)
logFilePath = filepath.Join(globalContext.workDir, logFilePath)
}
log.SetOutput(&lumberjack.Logger{

View File

@@ -19,10 +19,8 @@ func setupDNSIPs(t testing.TB) {
t.Helper()
prevConfig := config
prevTLS := Context.tls
t.Cleanup(func() {
config = prevConfig
Context.tls = prevTLS
})
config = &configuration{
@@ -31,8 +29,6 @@ func setupDNSIPs(t testing.TB) {
Port: defaultPortDNS,
},
}
Context.tls = &tlsManager{}
}
func TestHandleMobileConfigDoH(t *testing.T) {
@@ -62,11 +58,6 @@ func TestHandleMobileConfigDoH(t *testing.T) {
})
t.Run("error_no_host", func(t *testing.T) {
oldTLSConf := Context.tls
t.Cleanup(func() { Context.tls = oldTLSConf })
Context.tls = &tlsManager{conf: tlsConfigSettings{}}
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/doh.mobileconfig", nil)
require.NoError(t, err)
@@ -134,11 +125,6 @@ func TestHandleMobileConfigDoT(t *testing.T) {
})
t.Run("error_no_host", func(t *testing.T) {
oldTLSConf := Context.tls
t.Cleanup(func() { Context.tls = oldTLSConf })
Context.tls = &tlsManager{conf: tlsConfigSettings{}}
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/dot.mobileconfig", nil)
require.NoError(t, err)

View File

@@ -47,7 +47,7 @@ type profileJSON struct {
// handleGetProfile is the handler for GET /control/profile endpoint.
func handleGetProfile(w http.ResponseWriter, r *http.Request) {
u := Context.auth.getCurrentUser(r)
u := globalContext.auth.getCurrentUser(r)
var resp profileJSON
func() {

View File

@@ -36,6 +36,7 @@ type program struct {
signals chan os.Signal
done chan struct{}
opts options
sigHdlr *signalHandler
}
// type check
@@ -47,7 +48,7 @@ func (p *program) Start(_ service.Service) (err error) {
args := p.opts
args.runningAsService = true
go run(args, p.clientBuildFS, p.done)
go run(args, p.clientBuildFS, p.done, p.sigHdlr)
return nil
}
@@ -204,13 +205,14 @@ func handleServiceControlAction(
clientBuildFS fs.FS,
signals chan os.Signal,
done chan struct{},
sigHdlr *signalHandler,
) {
// Call chooseSystem explicitly to introduce OpenBSD support for service
// package. It's a noop for other GOOS values.
chooseSystem()
action := opts.serviceControlAction
log.Info(version.Full())
log.Info("%s", version.Full())
log.Info("service: control action: %s", action)
if action == "reload" {
@@ -244,6 +246,7 @@ func handleServiceControlAction(
signals: signals,
done: done,
opts: runOpts,
sigHdlr: sigHdlr,
}, svcConfig)
if err != nil {
log.Fatalf("service: initializing service: %s", err)
@@ -336,7 +339,7 @@ AdGuard Home is successfully installed and will automatically start on boot.
There are a few more things that must be configured before you can use it.
Click on the link below and follow the Installation Wizard steps to finish setup.
AdGuard Home is now available at the following addresses:`)
printHTTPAddresses(urlutil.SchemeHTTP)
printHTTPAddresses(urlutil.SchemeHTTP, nil)
}
}

View File

@@ -392,7 +392,7 @@ type sysLogger struct{}
// Error implements service.Logger interface for sysLogger.
func (sysLogger) Error(v ...any) error {
log.Error(fmt.Sprint(v...))
log.Error("%s", fmt.Sprint(v...))
return nil
}
@@ -406,7 +406,7 @@ func (sysLogger) Warning(v ...any) error {
// Info implements service.Logger interface for sysLogger.
func (sysLogger) Info(v ...any) error {
log.Info(fmt.Sprint(v...))
log.Info("%s", fmt.Sprint(v...))
return nil
}

121
internal/home/signal.go Normal file
View File

@@ -0,0 +1,121 @@
package home
import (
"context"
"log/slog"
"os"
"sync"
"sync/atomic"
"syscall"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/osutil"
)
// signalHandler processes incoming signals. It reloads configurations of
// stored entities on SIGHUP and performs cleanup on all other signals.
type signalHandler struct {
// logger is used to log the operation of the signal handler. Initially,
// [slog.Default] is used, but it should be swapped later using
// [signalHandler.swapLogger].
logger *atomic.Pointer[slog.Logger]
// mu protects clientStorage and tlsManager.
mu *sync.Mutex
// clientStorage is used to reload information about runtime clients with an
// ARP source.
clientStorage *client.Storage
// tlsManager is used to reload the TLS configuration.
tlsManager *tlsManager
// signals receives incoming signals.
signals <-chan os.Signal
// cleanup is called to perform cleanup on all incoming signals, except
// SIGHUP.
cleanup func(ctx context.Context)
}
// newSignalHandler returns a new properly initialized *signalHandler.
func newSignalHandler(
signals <-chan os.Signal,
cleanup func(ctx context.Context),
) (h *signalHandler) {
h = &signalHandler{
logger: &atomic.Pointer[slog.Logger]{},
mu: &sync.Mutex{},
signals: signals,
cleanup: cleanup,
}
h.logger.Store(slog.Default())
return h
}
// swapLogger replaces the stored logger with the given logger.
func (h *signalHandler) swapLogger(logger *slog.Logger) {
h.logger.Swap(logger)
}
// addClientStorage stores the client storage.
func (h *signalHandler) addClientStorage(s *client.Storage) {
h.mu.Lock()
defer h.mu.Unlock()
h.clientStorage = s
}
// addTLSManager stores the TLS manager.
func (h *signalHandler) addTLSManager(m *tlsManager) {
h.mu.Lock()
defer h.mu.Unlock()
h.tlsManager = m
}
// handle processes incoming signals. It blocks until a signal is received. It
// reloads configurations of stored entities on SIGHUP, or performs cleanup on
// all other signals. It is intended to be used as a goroutine.
func (h *signalHandler) handle(ctx context.Context) {
// NOTE: Avoid using [slogutil.RecoverAndExit] to prevent immediate
// evaluation of the logger.
defer func() {
v := recover()
if v == nil {
return
}
slogutil.PrintRecovered(ctx, h.logger.Load(), v)
os.Exit(osutil.ExitCodeFailure)
}()
for {
sig := <-h.signals
h.logger.Load().InfoContext(ctx, "received signal", "signal", sig)
switch sig {
case syscall.SIGHUP:
h.reloadConfig(ctx)
default:
h.cleanup(ctx)
}
}
}
// reloadConfig refreshes configurations of stored entities.
func (h *signalHandler) reloadConfig(ctx context.Context) {
h.mu.Lock()
defer h.mu.Unlock()
if h.clientStorage != nil {
h.clientStorage.ReloadARP(ctx)
}
if h.tlsManager != nil {
h.tlsManager.reload()
}
}

View File

@@ -102,7 +102,9 @@ func (m *tlsManager) setCertFileTime() {
}
// start updates the configuration of t and starts it.
func (m *tlsManager) start() {
//
// TODO(s.chzhen): Use context.
func (m *tlsManager) start(_ context.Context) {
m.registerWebHandlers()
m.confLock.Lock()
@@ -112,7 +114,7 @@ func (m *tlsManager) start() {
// The background context is used because the TLSConfigChanged wraps context
// with timeout on its own and shuts down the server, which handles current
// request.
Context.web.tlsConfigChanged(context.Background(), tlsConf)
globalContext.web.tlsConfigChanged(context.Background(), tlsConf)
}
// reload updates the configuration and restarts t.
@@ -151,7 +153,7 @@ func (m *tlsManager) reload() {
m.certLastMod = fi.ModTime().UTC()
_ = reconfigureDNSServer()
_ = reconfigureDNSServer(m)
m.confLock.Lock()
tlsConf = m.conf
@@ -160,7 +162,7 @@ func (m *tlsManager) reload() {
// The background context is used because the TLSConfigChanged wraps context
// with timeout on its own and shuts down the server, which handles current
// request.
Context.web.tlsConfigChanged(context.Background(), tlsConf)
globalContext.web.tlsConfigChanged(context.Background(), tlsConf)
}
// loadTLSConf loads and validates the TLS configuration. The returned error is
@@ -440,7 +442,7 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request)
onConfigModified()
err = reconfigureDNSServer()
err = reconfigureDNSServer(m)
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
@@ -463,7 +465,7 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request)
// same reason.
if restartHTTPS {
go func() {
Context.web.tlsConfigChanged(context.Background(), req.tlsConfigSettings)
globalContext.web.tlsConfigChanged(context.Background(), req.tlsConfigSettings)
}()
}
}
@@ -539,7 +541,7 @@ func validateCertChain(certs []*x509.Certificate, srvName string) (err error) {
opts := x509.VerifyOptions{
DNSName: srvName,
Roots: Context.tlsRoots,
Roots: globalContext.tlsRoots,
Intermediates: pool,
}
_, err = main.Verify(opts)

View File

@@ -49,6 +49,10 @@ type webConfig struct {
// nil.
baseLogger *slog.Logger
// tlsManager contains the current configuration and state of TLS
// encryption. It must not be nil.
tlsManager *tlsManager
clientFS fs.FS
// BindAddr is the binding address with port for plain HTTP web interface.
@@ -108,6 +112,10 @@ type webAPI struct {
// nil.
baseLogger *slog.Logger
// tlsManager contains the current configuration and state of TLS
// encryption.
tlsManager *tlsManager
// httpsServer is the server that handles HTTPS traffic. If it is not nil,
// [Web.http3Server] must also not be nil.
httpsServer httpsServer
@@ -124,12 +132,13 @@ func newWebAPI(ctx context.Context, conf *webConfig) (w *webAPI) {
conf: conf,
logger: conf.logger,
baseLogger: conf.baseLogger,
tlsManager: conf.tlsManager,
}
clientFS := http.FileServer(http.FS(conf.clientFS))
// if not configured, redirect / to /install.html, otherwise redirect /install.html to /
Context.mux.Handle("/", withMiddlewares(clientFS, gziphandler.GzipHandler, optionalAuthHandler, postInstallHandler))
globalContext.mux.Handle("/", withMiddlewares(clientFS, gziphandler.GzipHandler, optionalAuthHandler, postInstallHandler))
// add handlers for /install paths, we only need them when we're not configured yet
if conf.firstRun {
@@ -138,7 +147,7 @@ func newWebAPI(ctx context.Context, conf *webConfig) (w *webAPI) {
"This is the first launch of AdGuard Home, redirecting everything to /install.html",
)
Context.mux.Handle("/install.html", preInstallHandler(clientFS))
globalContext.mux.Handle("/install.html", preInstallHandler(clientFS))
w.registerInstallHandlers()
} else {
registerControlHandlers(w)
@@ -154,7 +163,7 @@ func newWebAPI(ctx context.Context, conf *webConfig) (w *webAPI) {
//
// TODO(a.garipov): Adapt for HTTP/3.
func webCheckPortAvailable(port uint16) (ok bool) {
if Context.web.httpsServer.server != nil {
if globalContext.web.httpsServer.server != nil {
return true
}
@@ -220,14 +229,18 @@ func (web *webAPI) start(ctx context.Context) {
// this loop is used as an ability to change listening host and/or port
for !web.httpsServer.inShutdown {
printHTTPAddresses(urlutil.SchemeHTTP)
printHTTPAddresses(urlutil.SchemeHTTP, web.tlsManager)
errs := make(chan error, 2)
// Use an h2c handler to support unencrypted HTTP/2, e.g. for proxies.
hdlr := h2c.NewHandler(withMiddlewares(Context.mux, limitRequestBody), &http2.Server{})
hdlr := h2c.NewHandler(withMiddlewares(globalContext.mux, limitRequestBody), &http2.Server{})
logger := web.baseLogger.With(loggerKeyServer, "plain")
// TODO(a.garipov): Remove other logs like this in other code.
logMw := httputil.NewLogMiddleware(logger, slog.LevelDebug)
hdlr = logMw.Wrap(hdlr)
// Create a new instance, because the Web is not usable after Shutdown.
web.httpServer = &http.Server{
Addr: web.conf.BindAddr.String(),
@@ -238,7 +251,9 @@ func (web *webAPI) start(ctx context.Context) {
ErrorLog: slog.NewLogLogger(logger.Handler(), slog.LevelError),
}
go func() {
defer slogutil.RecoverAndLog(ctx, web.logger)
defer slogutil.RecoverAndLog(ctx, logger)
logger.InfoContext(ctx, "starting plain server", "addr", web.httpServer.Addr)
errs <- web.httpServer.ListenAndServe()
}()
@@ -305,13 +320,17 @@ func (web *webAPI) tlsServerLoop(ctx context.Context) {
addr := netip.AddrPortFrom(web.conf.BindAddr.Addr(), portHTTPS).String()
logger := web.baseLogger.With(loggerKeyServer, "https")
// TODO(a.garipov): Remove other logs like this in other code.
logMw := httputil.NewLogMiddleware(logger, slog.LevelDebug)
hdlr := logMw.Wrap(withMiddlewares(globalContext.mux, limitRequestBody))
web.httpsServer.server = &http.Server{
Addr: addr,
Handler: withMiddlewares(Context.mux, limitRequestBody),
Handler: hdlr,
TLSConfig: &tls.Config{
Certificates: []tls.Certificate{web.httpsServer.cert},
RootCAs: Context.tlsRoots,
CipherSuites: Context.tlsCipherIDs,
RootCAs: globalContext.tlsRoots,
CipherSuites: globalContext.tlsCipherIDs,
MinVersion: tls.VersionTLS12,
},
ReadTimeout: web.conf.ReadTimeout,
@@ -320,13 +339,13 @@ func (web *webAPI) tlsServerLoop(ctx context.Context) {
ErrorLog: slog.NewLogLogger(logger.Handler(), slog.LevelError),
}
printHTTPAddresses(urlutil.SchemeHTTPS)
printHTTPAddresses(urlutil.SchemeHTTPS, web.tlsManager)
if web.conf.serveHTTP3 {
go web.mustStartHTTP3(ctx, addr)
}
web.logger.DebugContext(ctx, "starting https server")
logger.InfoContext(ctx, "starting https server")
err := web.httpsServer.server.ListenAndServeTLS("", "")
if !errors.Is(err, http.ErrServerClosed) {
cleanupAlways()
@@ -344,11 +363,11 @@ func (web *webAPI) mustStartHTTP3(ctx context.Context, address string) {
Addr: address,
TLSConfig: &tls.Config{
Certificates: []tls.Certificate{web.httpsServer.cert},
RootCAs: Context.tlsRoots,
CipherSuites: Context.tlsCipherIDs,
RootCAs: globalContext.tlsRoots,
CipherSuites: globalContext.tlsCipherIDs,
MinVersion: tls.VersionTLS12,
},
Handler: withMiddlewares(Context.mux, limitRequestBody),
Handler: withMiddlewares(globalContext.mux, limitRequestBody),
}
web.logger.DebugContext(ctx, "starting http/3 server")