all: sync with master
This commit is contained in:
@@ -400,7 +400,7 @@ func (clients *clientsContainer) UpstreamConfigByID(
|
||||
upstreams,
|
||||
&upstream.Options{
|
||||
Bootstrap: bootstrap,
|
||||
Timeout: config.DNS.UpstreamTimeout.Duration,
|
||||
Timeout: time.Duration(config.DNS.UpstreamTimeout),
|
||||
HTTPVersions: dnsforward.UpstreamHTTPVersions(config.DNS.UseHTTP3Upstreams),
|
||||
PreferIPv6: config.DNS.BootstrapPreferIPv6,
|
||||
},
|
||||
|
||||
@@ -424,6 +424,8 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht
|
||||
}
|
||||
|
||||
// handleFindClient is the handler for GET /control/clients/find HTTP API.
|
||||
//
|
||||
// Deprecated: Remove it when migration to the new API is over.
|
||||
func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http.Request) {
|
||||
q := r.URL.Query()
|
||||
data := []map[string]*clientJSON{}
|
||||
@@ -433,19 +435,58 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
|
||||
break
|
||||
}
|
||||
|
||||
ip, _ := netip.ParseAddr(idStr)
|
||||
c, ok := clients.storage.Find(idStr)
|
||||
var cj *clientJSON
|
||||
if !ok {
|
||||
cj = clients.findRuntime(ip, idStr)
|
||||
} else {
|
||||
cj = clientToJSON(c)
|
||||
disallowed, rule := clients.clientChecker.IsBlockedClient(ip, idStr)
|
||||
cj.Disallowed, cj.DisallowedRule = &disallowed, &rule
|
||||
}
|
||||
|
||||
data = append(data, map[string]*clientJSON{
|
||||
idStr: cj,
|
||||
idStr: clients.findClient(idStr),
|
||||
})
|
||||
}
|
||||
|
||||
aghhttp.WriteJSONResponseOK(w, r, data)
|
||||
}
|
||||
|
||||
// findClient returns available information about a client by idStr from the
|
||||
// client's storage or access settings. cj is guaranteed to be non-nil.
|
||||
func (clients *clientsContainer) findClient(idStr string) (cj *clientJSON) {
|
||||
ip, _ := netip.ParseAddr(idStr)
|
||||
c, ok := clients.storage.Find(idStr)
|
||||
if !ok {
|
||||
return clients.findRuntime(ip, idStr)
|
||||
}
|
||||
|
||||
cj = clientToJSON(c)
|
||||
disallowed, rule := clients.clientChecker.IsBlockedClient(ip, idStr)
|
||||
cj.Disallowed, cj.DisallowedRule = &disallowed, &rule
|
||||
|
||||
return cj
|
||||
}
|
||||
|
||||
// searchQueryJSON is a request to the POST /control/clients/search HTTP API.
|
||||
//
|
||||
// TODO(s.chzhen): Add UIDs.
|
||||
type searchQueryJSON struct {
|
||||
Clients []searchClientJSON `json:"clients"`
|
||||
}
|
||||
|
||||
// searchClientJSON is a part of [searchQueryJSON] that contains a string
|
||||
// representation of the client's IP address, CIDR, MAC address, or ClientID.
|
||||
type searchClientJSON struct {
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
// handleSearchClient is the handler for the POST /control/clients/search HTTP API.
|
||||
func (clients *clientsContainer) handleSearchClient(w http.ResponseWriter, r *http.Request) {
|
||||
q := searchQueryJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&q)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "failed to process request body: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
data := []map[string]*clientJSON{}
|
||||
for _, c := range q.Clients {
|
||||
idStr := c.ID
|
||||
data = append(data, map[string]*clientJSON{
|
||||
idStr: clients.findClient(idStr),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -493,5 +534,8 @@ func (clients *clientsContainer) registerWebHandlers() {
|
||||
httpRegister(http.MethodPost, "/control/clients/add", clients.handleAddClient)
|
||||
httpRegister(http.MethodPost, "/control/clients/delete", clients.handleDelClient)
|
||||
httpRegister(http.MethodPost, "/control/clients/update", clients.handleUpdateClient)
|
||||
httpRegister(http.MethodPost, "/control/clients/search", clients.handleSearchClient)
|
||||
|
||||
// Deprecated handler.
|
||||
httpRegister(http.MethodGet, "/control/clients/find", clients.handleFindClient)
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -408,3 +409,145 @@ func TestClientsContainer_HandleFindClient(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientsContainer_HandleSearchClient(t *testing.T) {
|
||||
var (
|
||||
runtimeCli = "runtime_client1"
|
||||
|
||||
runtimeCliIP = "3.3.3.3"
|
||||
blockedCliIP = "4.4.4.4"
|
||||
nonExistentCliIP = "5.5.5.5"
|
||||
|
||||
allowed = false
|
||||
dissallowed = true
|
||||
|
||||
emptyRule = ""
|
||||
disallowedRule = "disallowed_rule"
|
||||
)
|
||||
|
||||
clients := newClientsContainer(t)
|
||||
clients.clientChecker = &testBlockedClientChecker{
|
||||
onIsBlockedClient: func(ip netip.Addr, _ string) (ok bool, rule string) {
|
||||
if ip == netip.MustParseAddr(blockedCliIP) {
|
||||
return true, disallowedRule
|
||||
}
|
||||
|
||||
return false, emptyRule
|
||||
},
|
||||
}
|
||||
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
|
||||
clientOne := newPersistentClientWithIDs(t, "client1", []string{testClientIP1})
|
||||
err := clients.storage.Add(ctx, clientOne)
|
||||
require.NoError(t, err)
|
||||
|
||||
clientTwo := newPersistentClientWithIDs(t, "client2", []string{testClientIP2})
|
||||
err = clients.storage.Add(ctx, clientTwo)
|
||||
require.NoError(t, err)
|
||||
|
||||
assertPersistentClients(t, clients, []*client.Persistent{clientOne, clientTwo})
|
||||
|
||||
clients.UpdateAddress(ctx, netip.MustParseAddr(runtimeCliIP), runtimeCli, nil)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
query *searchQueryJSON
|
||||
wantPersistent []*client.Persistent
|
||||
wantRuntime *clientJSON
|
||||
}{{
|
||||
name: "single",
|
||||
query: &searchQueryJSON{
|
||||
Clients: []searchClientJSON{{
|
||||
ID: testClientIP1,
|
||||
}},
|
||||
},
|
||||
wantPersistent: []*client.Persistent{clientOne},
|
||||
}, {
|
||||
name: "multiple",
|
||||
query: &searchQueryJSON{
|
||||
Clients: []searchClientJSON{{
|
||||
ID: testClientIP1,
|
||||
}, {
|
||||
ID: testClientIP2,
|
||||
}},
|
||||
},
|
||||
wantPersistent: []*client.Persistent{clientOne, clientTwo},
|
||||
}, {
|
||||
name: "runtime",
|
||||
query: &searchQueryJSON{
|
||||
Clients: []searchClientJSON{{
|
||||
ID: runtimeCliIP,
|
||||
}},
|
||||
},
|
||||
wantRuntime: &clientJSON{
|
||||
Name: runtimeCli,
|
||||
IDs: []string{runtimeCliIP},
|
||||
Disallowed: &allowed,
|
||||
DisallowedRule: &emptyRule,
|
||||
WHOIS: &whois.Info{},
|
||||
},
|
||||
}, {
|
||||
name: "blocked_access",
|
||||
query: &searchQueryJSON{
|
||||
Clients: []searchClientJSON{{
|
||||
ID: blockedCliIP,
|
||||
}},
|
||||
},
|
||||
wantRuntime: &clientJSON{
|
||||
IDs: []string{blockedCliIP},
|
||||
Disallowed: &dissallowed,
|
||||
DisallowedRule: &disallowedRule,
|
||||
WHOIS: &whois.Info{},
|
||||
},
|
||||
}, {
|
||||
name: "non_existing_client",
|
||||
query: &searchQueryJSON{
|
||||
Clients: []searchClientJSON{{
|
||||
ID: nonExistentCliIP,
|
||||
}},
|
||||
},
|
||||
wantRuntime: &clientJSON{
|
||||
IDs: []string{nonExistentCliIP},
|
||||
Disallowed: &allowed,
|
||||
DisallowedRule: &emptyRule,
|
||||
WHOIS: &whois.Info{},
|
||||
},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var body []byte
|
||||
body, err = json.Marshal(tc.query)
|
||||
require.NoError(t, err)
|
||||
|
||||
var r *http.Request
|
||||
r, err = http.NewRequest(http.MethodPost, "", bytes.NewReader(body))
|
||||
require.NoError(t, err)
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
clients.handleSearchClient(rw, r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusOK, rw.Code)
|
||||
|
||||
body, err = io.ReadAll(rw.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
clientData := []map[string]*clientJSON{}
|
||||
err = json.Unmarshal(body, &clientData)
|
||||
require.NoError(t, err)
|
||||
|
||||
if tc.wantPersistent != nil {
|
||||
assertPersistentClientsData(t, clients, clientData, tc.wantPersistent)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
require.Len(t, clientData, 1)
|
||||
require.Len(t, clientData[0], 1)
|
||||
|
||||
rc := clientData[0][tc.wantRuntime.IDs[0]]
|
||||
assert.Equal(t, tc.wantRuntime, rc)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -339,7 +339,7 @@ var config = &configuration{
|
||||
AuthBlockMin: 15,
|
||||
HTTPConfig: httpConfig{
|
||||
Address: netip.AddrPortFrom(netip.IPv4Unspecified(), 3000),
|
||||
SessionTTL: timeutil.Duration{Duration: 30 * timeutil.Day},
|
||||
SessionTTL: timeutil.Duration(30 * timeutil.Day),
|
||||
Pprof: &httpPprofConfig{
|
||||
Enabled: false,
|
||||
Port: 6060,
|
||||
@@ -355,9 +355,7 @@ var config = &configuration{
|
||||
RefuseAny: true,
|
||||
UpstreamMode: dnsforward.UpstreamModeLoadBalance,
|
||||
HandleDDR: true,
|
||||
FastestTimeout: timeutil.Duration{
|
||||
Duration: fastip.DefaultPingWaitTimeout,
|
||||
},
|
||||
FastestTimeout: timeutil.Duration(fastip.DefaultPingWaitTimeout),
|
||||
|
||||
TrustedProxies: []netutil.Prefix{{
|
||||
Prefix: netip.MustParsePrefix("127.0.0.0/8"),
|
||||
@@ -378,7 +376,7 @@ var config = &configuration{
|
||||
// was later increased to 300 due to https://github.com/AdguardTeam/AdGuardHome/issues/2257
|
||||
MaxGoroutines: 300,
|
||||
},
|
||||
UpstreamTimeout: timeutil.Duration{Duration: dnsforward.DefaultTimeout},
|
||||
UpstreamTimeout: timeutil.Duration(dnsforward.DefaultTimeout),
|
||||
UsePrivateRDNS: true,
|
||||
ServePlainDNS: true,
|
||||
HostsFileEnabled: true,
|
||||
@@ -391,13 +389,13 @@ var config = &configuration{
|
||||
QueryLog: queryLogConfig{
|
||||
Enabled: true,
|
||||
FileEnabled: true,
|
||||
Interval: timeutil.Duration{Duration: 90 * timeutil.Day},
|
||||
Interval: timeutil.Duration(90 * timeutil.Day),
|
||||
MemSize: 1000,
|
||||
Ignored: []string{},
|
||||
},
|
||||
Stats: statsConfig{
|
||||
Enabled: true,
|
||||
Interval: timeutil.Duration{Duration: 1 * timeutil.Day},
|
||||
Interval: timeutil.Duration(1 * timeutil.Day),
|
||||
Ignored: []string{},
|
||||
},
|
||||
// NOTE: Keep these parameters in sync with the one put into
|
||||
@@ -565,8 +563,8 @@ func parseConfig() (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
if config.DNS.UpstreamTimeout.Duration == 0 {
|
||||
config.DNS.UpstreamTimeout = timeutil.Duration{Duration: dnsforward.DefaultTimeout}
|
||||
if config.DNS.UpstreamTimeout == 0 {
|
||||
config.DNS.UpstreamTimeout = timeutil.Duration(dnsforward.DefaultTimeout)
|
||||
}
|
||||
|
||||
// Do not wrap the error because it's informative enough as is.
|
||||
@@ -659,7 +657,7 @@ func (c *configuration) write() (err error) {
|
||||
if Context.stats != nil {
|
||||
statsConf := stats.Config{}
|
||||
Context.stats.WriteDiskConfig(&statsConf)
|
||||
config.Stats.Interval = timeutil.Duration{Duration: statsConf.Limit}
|
||||
config.Stats.Interval = timeutil.Duration(statsConf.Limit)
|
||||
config.Stats.Enabled = statsConf.Enabled
|
||||
config.Stats.Ignored = statsConf.Ignored.Values()
|
||||
}
|
||||
@@ -670,7 +668,7 @@ func (c *configuration) write() (err error) {
|
||||
config.DNS.AnonymizeClientIP = dc.AnonymizeClientIP
|
||||
config.QueryLog.Enabled = dc.Enabled
|
||||
config.QueryLog.FileEnabled = dc.FileEnabled
|
||||
config.QueryLog.Interval = timeutil.Duration{Duration: dc.RotationIvl}
|
||||
config.QueryLog.Interval = timeutil.Duration(dc.RotationIvl)
|
||||
config.QueryLog.MemSize = dc.MemSize
|
||||
config.QueryLog.Ignored = dc.Ignored.Values()
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"os"
|
||||
@@ -19,7 +20,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/quic-go/quic-go/http3"
|
||||
)
|
||||
|
||||
@@ -124,6 +125,8 @@ func (req *checkConfReq) validateWeb(tcpPorts aghalg.UniqChecker[tcpPort]) (err
|
||||
// be set. canAutofix is true if the port can be unbound by AdGuard Home
|
||||
// automatically.
|
||||
func (req *checkConfReq) validateDNS(
|
||||
ctx context.Context,
|
||||
l *slog.Logger,
|
||||
tcpPorts aghalg.UniqChecker[tcpPort],
|
||||
) (canAutofix bool, err error) {
|
||||
defer func() { err = errors.Annotate(err, "validating ports: %w") }()
|
||||
@@ -154,10 +157,10 @@ func (req *checkConfReq) validateDNS(
|
||||
}
|
||||
|
||||
// Try to fix automatically.
|
||||
canAutofix = checkDNSStubListener()
|
||||
canAutofix = checkDNSStubListener(ctx, l)
|
||||
if canAutofix && req.DNS.Autofix {
|
||||
if derr := disableDNSStubListener(); derr != nil {
|
||||
log.Error("disabling DNSStubListener: %s", err)
|
||||
if derr := disableDNSStubListener(ctx, l); derr != nil {
|
||||
l.ErrorContext(ctx, "disabling DNSStubListener", slogutil.KeyError, err)
|
||||
}
|
||||
|
||||
err = aghnet.CheckPort("udp", netip.AddrPortFrom(req.DNS.IP, port))
|
||||
@@ -184,7 +187,7 @@ func (web *webAPI) handleInstallCheckConfig(w http.ResponseWriter, r *http.Reque
|
||||
resp.Web.Status = err.Error()
|
||||
}
|
||||
|
||||
if resp.DNS.CanAutofix, err = req.validateDNS(tcpPorts); err != nil {
|
||||
if resp.DNS.CanAutofix, err = req.validateDNS(r.Context(), web.logger, tcpPorts); err != nil {
|
||||
resp.DNS.Status = err.Error()
|
||||
} else if !req.DNS.IP.IsUnspecified() {
|
||||
resp.StaticIP = handleStaticIP(req.DNS.IP, req.SetStaticIP)
|
||||
@@ -233,27 +236,39 @@ func handleStaticIP(ip netip.Addr, set bool) staticIPJSON {
|
||||
return resp
|
||||
}
|
||||
|
||||
// Check if DNSStubListener is active
|
||||
func checkDNSStubListener() bool {
|
||||
// checkDNSStubListener returns true if DNSStubListener is active.
|
||||
func checkDNSStubListener(ctx context.Context, l *slog.Logger) (ok bool) {
|
||||
if runtime.GOOS != "linux" {
|
||||
return false
|
||||
}
|
||||
|
||||
cmd := exec.Command("systemctl", "is-enabled", "systemd-resolved")
|
||||
log.Tracef("executing %s %v", cmd.Path, cmd.Args)
|
||||
l.DebugContext(ctx, "executing", "cmd", cmd.Path, "args", cmd.Args)
|
||||
_, err := cmd.Output()
|
||||
if err != nil || cmd.ProcessState.ExitCode() != 0 {
|
||||
log.Info("command %s has failed: %v code:%d",
|
||||
cmd.Path, err, cmd.ProcessState.ExitCode())
|
||||
l.InfoContext(
|
||||
ctx,
|
||||
"execution failed",
|
||||
"cmd", cmd.Path,
|
||||
"code", cmd.ProcessState.ExitCode(),
|
||||
slogutil.KeyError, err,
|
||||
)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
cmd = exec.Command("grep", "-E", "#?DNSStubListener=yes", "/etc/systemd/resolved.conf")
|
||||
log.Tracef("executing %s %v", cmd.Path, cmd.Args)
|
||||
l.DebugContext(ctx, "executing", "cmd", cmd.Path, "args", cmd.Args)
|
||||
_, err = cmd.Output()
|
||||
if err != nil || cmd.ProcessState.ExitCode() != 0 {
|
||||
log.Info("command %s has failed: %v code:%d",
|
||||
cmd.Path, err, cmd.ProcessState.ExitCode())
|
||||
l.InfoContext(
|
||||
ctx,
|
||||
"execution failed",
|
||||
"cmd", cmd.Path,
|
||||
"code", cmd.ProcessState.ExitCode(),
|
||||
slogutil.KeyError, err,
|
||||
)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -269,8 +284,9 @@ DNSStubListener=no
|
||||
)
|
||||
const resolvConfPath = "/etc/resolv.conf"
|
||||
|
||||
// Deactivate DNSStubListener
|
||||
func disableDNSStubListener() (err error) {
|
||||
// disableDNSStubListener deactivates DNSStubListerner and returns an error, if
|
||||
// any.
|
||||
func disableDNSStubListener(ctx context.Context, l *slog.Logger) (err error) {
|
||||
dir := filepath.Dir(resolvedConfPath)
|
||||
err = os.MkdirAll(dir, 0o755)
|
||||
if err != nil {
|
||||
@@ -290,7 +306,7 @@ func disableDNSStubListener() (err error) {
|
||||
}
|
||||
|
||||
cmd := exec.Command("systemctl", "reload-or-restart", "systemd-resolved")
|
||||
log.Tracef("executing %s %v", cmd.Path, cmd.Args)
|
||||
l.DebugContext(ctx, "executing", "cmd", cmd.Path, "args", cmd.Args)
|
||||
_, err = cmd.Output()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -327,9 +343,9 @@ func copyInstallSettings(dst, src *configuration) {
|
||||
// shutdownTimeout is the timeout for shutting HTTP server down operation.
|
||||
const shutdownTimeout = 5 * time.Second
|
||||
|
||||
// shutdownSrv shuts srv down and prints error messages to the log.
|
||||
func shutdownSrv(ctx context.Context, srv *http.Server) {
|
||||
defer log.OnPanic("")
|
||||
// shutdownSrv shuts down srv and logs the error, if any. l must not be nil.
|
||||
func shutdownSrv(ctx context.Context, l *slog.Logger, srv *http.Server) {
|
||||
defer slogutil.RecoverAndLog(ctx, l)
|
||||
|
||||
if srv == nil {
|
||||
return
|
||||
@@ -340,19 +356,19 @@ func shutdownSrv(ctx context.Context, srv *http.Server) {
|
||||
return
|
||||
}
|
||||
|
||||
const msgFmt = "shutting down http server %q: %s"
|
||||
if errors.Is(err, context.Canceled) {
|
||||
log.Debug(msgFmt, srv.Addr, err)
|
||||
} else {
|
||||
log.Error(msgFmt, srv.Addr, err)
|
||||
lvl := slog.LevelDebug
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
lvl = slog.LevelError
|
||||
}
|
||||
|
||||
l.Log(ctx, lvl, "shutting down http server", "addr", srv.Addr, slogutil.KeyError, err)
|
||||
}
|
||||
|
||||
// shutdownSrv3 shuts srv down and prints error messages to the log.
|
||||
// shutdownSrv3 shuts down srv and logs the error, if any. l must not be nil.
|
||||
//
|
||||
// TODO(a.garipov): Think of a good way to merge with [shutdownSrv].
|
||||
func shutdownSrv3(srv *http3.Server) {
|
||||
defer log.OnPanic("")
|
||||
func shutdownSrv3(ctx context.Context, l *slog.Logger, srv *http3.Server) {
|
||||
defer slogutil.RecoverAndLog(ctx, l)
|
||||
|
||||
if srv == nil {
|
||||
return
|
||||
@@ -363,12 +379,12 @@ func shutdownSrv3(srv *http3.Server) {
|
||||
return
|
||||
}
|
||||
|
||||
const msgFmt = "shutting down http/3 server %q: %s"
|
||||
if errors.Is(err, context.Canceled) {
|
||||
log.Debug(msgFmt, srv.Addr, err)
|
||||
} else {
|
||||
log.Error(msgFmt, srv.Addr, err)
|
||||
lvl := slog.LevelDebug
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
lvl = slog.LevelError
|
||||
}
|
||||
|
||||
l.Log(ctx, lvl, "shutting down http/3 server", "addr", srv.Addr, slogutil.KeyError, err)
|
||||
}
|
||||
|
||||
// PasswordMinRunes is the minimum length of user's password in runes.
|
||||
@@ -436,7 +452,7 @@ func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request
|
||||
// moment we'll allow setting up TLS in the initial configuration or the
|
||||
// configuration itself will use HTTPS protocol, because the underlying
|
||||
// functions potentially restart the HTTPS server.
|
||||
err = startMods(web.logger)
|
||||
err = startMods(web.baseLogger)
|
||||
if err != nil {
|
||||
Context.firstRun = true
|
||||
copyInstallSettings(config, curConfig)
|
||||
@@ -472,12 +488,11 @@ func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request
|
||||
// and with its own context, because it waits until all requests are handled
|
||||
// and will be blocked by it's own caller.
|
||||
go func(timeout time.Duration) {
|
||||
defer log.OnPanic("web")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer slogutil.RecoverAndLog(ctx, web.logger)
|
||||
defer cancel()
|
||||
|
||||
shutdownSrv(ctx, web.httpServer)
|
||||
shutdownSrv(ctx, web.logger, web.httpServer)
|
||||
}(shutdownTimeout)
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
@@ -16,7 +17,8 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/updater"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/osutil"
|
||||
)
|
||||
|
||||
// temporaryError is the interface for temporary errors from the Go standard
|
||||
@@ -52,7 +54,7 @@ func (web *webAPI) handleVersionJSON(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
err = web.requestVersionInfo(resp, req.Recheck)
|
||||
err = web.requestVersionInfo(r.Context(), resp, req.Recheck)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
aghhttp.Error(r, w, http.StatusBadGateway, "%s", err)
|
||||
@@ -73,7 +75,11 @@ func (web *webAPI) handleVersionJSON(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// requestVersionInfo sets the VersionInfo field of resp if it can reach the
|
||||
// update server.
|
||||
func (web *webAPI) requestVersionInfo(resp *versionResponse, recheck bool) (err error) {
|
||||
func (web *webAPI) requestVersionInfo(
|
||||
ctx context.Context,
|
||||
resp *versionResponse,
|
||||
recheck bool,
|
||||
) (err error) {
|
||||
updater := web.conf.updater
|
||||
for range 3 {
|
||||
resp.VersionInfo, err = updater.VersionInfo(recheck)
|
||||
@@ -89,7 +95,9 @@ func (web *webAPI) requestVersionInfo(resp *versionResponse, recheck bool) (err
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/934.
|
||||
const sleepTime = 2 * time.Second
|
||||
|
||||
log.Info("update: temp net error: %v; sleeping for %s and retrying", err, sleepTime)
|
||||
err = fmt.Errorf("temp net error: %w; sleeping for %s and retrying", err, sleepTime)
|
||||
web.logger.InfoContext(ctx, "updating version info", slogutil.KeyError, err)
|
||||
|
||||
time.Sleep(sleepTime)
|
||||
|
||||
continue
|
||||
@@ -140,7 +148,7 @@ func (web *webAPI) handleUpdate(w http.ResponseWriter, r *http.Request) {
|
||||
// The background context is used because the underlying functions wrap it
|
||||
// with timeout and shut down the server, which handles current request. It
|
||||
// also should be done in a separate goroutine for the same reason.
|
||||
go finishUpdate(context.Background(), execPath, web.conf.runningAsService)
|
||||
go finishUpdate(context.Background(), web.logger, execPath, web.conf.runningAsService)
|
||||
}
|
||||
|
||||
// versionResponse is the response for /control/version.json endpoint.
|
||||
@@ -180,15 +188,17 @@ func tlsConfUsesPrivilegedPorts(c *tlsConfigSettings) (ok bool) {
|
||||
return c.Enabled && (c.PortHTTPS < 1024 || c.PortDNSOverTLS < 1024 || c.PortDNSOverQUIC < 1024)
|
||||
}
|
||||
|
||||
// finishUpdate completes an update procedure.
|
||||
func finishUpdate(ctx context.Context, execPath string, runningAsService bool) {
|
||||
var err error
|
||||
// finishUpdate completes an update procedure. It is intended to be used as a
|
||||
// goroutine.
|
||||
func finishUpdate(ctx context.Context, l *slog.Logger, execPath string, runningAsService bool) {
|
||||
defer slogutil.RecoverAndExit(ctx, l, osutil.ExitCodeFailure)
|
||||
|
||||
log.Info("stopping all tasks")
|
||||
l.InfoContext(ctx, "stopping all tasks")
|
||||
|
||||
cleanup(ctx)
|
||||
cleanupAlways()
|
||||
|
||||
var err error
|
||||
if runtime.GOOS == "windows" {
|
||||
if runningAsService {
|
||||
// NOTE: We can't restart the service via "kardianos/service"
|
||||
@@ -199,28 +209,28 @@ func finishUpdate(ctx context.Context, execPath string, runningAsService bool) {
|
||||
cmd := exec.Command("cmd", "/c", "net stop AdGuardHome & net start AdGuardHome")
|
||||
err = cmd.Start()
|
||||
if err != nil {
|
||||
log.Fatalf("restarting: stopping: %s", err)
|
||||
panic(fmt.Errorf("restarting service: %w", err))
|
||||
}
|
||||
|
||||
os.Exit(0)
|
||||
os.Exit(osutil.ExitCodeSuccess)
|
||||
}
|
||||
|
||||
cmd := exec.Command(execPath, os.Args[1:]...)
|
||||
log.Info("restarting: %q %q", execPath, os.Args[1:])
|
||||
l.InfoContext(ctx, "restarting", "exec_path", execPath, "args", os.Args[1:])
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
err = cmd.Start()
|
||||
if err != nil {
|
||||
log.Fatalf("restarting:: %s", err)
|
||||
panic(fmt.Errorf("restarting: %w", err))
|
||||
}
|
||||
|
||||
os.Exit(0)
|
||||
os.Exit(osutil.ExitCodeSuccess)
|
||||
}
|
||||
|
||||
log.Info("restarting: %q %q", execPath, os.Args[1:])
|
||||
l.InfoContext(ctx, "restarting", "exec_path", execPath, "args", os.Args[1:])
|
||||
err = syscall.Exec(execPath, os.Args, os.Environ())
|
||||
if err != nil {
|
||||
log.Fatalf("restarting: %s", err)
|
||||
panic(fmt.Errorf("restarting: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -47,14 +47,14 @@ func onConfigModified() {
|
||||
|
||||
// initDNS updates all the fields of the [Context] needed to initialize the DNS
|
||||
// server and initializes it at last. It also must not be called unless
|
||||
// [config] and [Context] are initialized. l must not be nil.
|
||||
// [config] and [Context] are initialized. baseLogger must not be nil.
|
||||
func initDNS(baseLogger *slog.Logger, statsDir, querylogDir string) (err error) {
|
||||
anonymizer := config.anonymizer()
|
||||
|
||||
statsConf := stats.Config{
|
||||
Logger: baseLogger.With(slogutil.KeyPrefix, "stats"),
|
||||
Filename: filepath.Join(statsDir, "stats.db"),
|
||||
Limit: config.Stats.Interval.Duration,
|
||||
Limit: time.Duration(config.Stats.Interval),
|
||||
ConfigModified: onConfigModified,
|
||||
HTTPRegister: httpRegister,
|
||||
Enabled: config.Stats.Enabled,
|
||||
@@ -80,7 +80,7 @@ func initDNS(baseLogger *slog.Logger, statsDir, querylogDir string) (err error)
|
||||
FindClient: Context.clients.findMultiple,
|
||||
BaseDir: querylogDir,
|
||||
AnonymizeClientIP: config.DNS.AnonymizeClientIP,
|
||||
RotationIvl: config.QueryLog.Interval.Duration,
|
||||
RotationIvl: time.Duration(config.QueryLog.Interval),
|
||||
MemSize: config.QueryLog.MemSize,
|
||||
Enabled: config.QueryLog.Enabled,
|
||||
FileEnabled: config.QueryLog.FileEnabled,
|
||||
@@ -243,7 +243,7 @@ func newServerConfig(
|
||||
Config: fwdConf,
|
||||
TLSConfig: newDNSTLSConfig(tlsConf, hosts),
|
||||
TLSAllowUnencryptedDoH: tlsConf.AllowUnencryptedDoH,
|
||||
UpstreamTimeout: dnsConf.UpstreamTimeout.Duration,
|
||||
UpstreamTimeout: time.Duration(dnsConf.UpstreamTimeout),
|
||||
TLSv12Roots: Context.tlsRoots,
|
||||
ConfigModified: onConfigModified,
|
||||
HTTPRegister: httpReg,
|
||||
|
||||
@@ -167,13 +167,13 @@ func setupContext(opts options) (err error) {
|
||||
if err != nil {
|
||||
log.Error("parsing configuration file: %s", err)
|
||||
|
||||
os.Exit(1)
|
||||
os.Exit(osutil.ExitCodeFailure)
|
||||
}
|
||||
|
||||
if opts.checkConfig {
|
||||
log.Info("configuration file is ok")
|
||||
|
||||
os.Exit(0)
|
||||
os.Exit(osutil.ExitCodeSuccess)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -522,18 +522,20 @@ func isUpdateEnabled(ctx context.Context, l *slog.Logger, opts *options, customU
|
||||
}
|
||||
}
|
||||
|
||||
// initWeb initializes the web module.
|
||||
// initWeb initializes the web module. upd and baseLogger must not be nil.
|
||||
func initWeb(
|
||||
ctx context.Context,
|
||||
opts options,
|
||||
clientBuildFS fs.FS,
|
||||
upd *updater.Updater,
|
||||
l *slog.Logger,
|
||||
baseLogger *slog.Logger,
|
||||
customURL bool,
|
||||
) (web *webAPI, err error) {
|
||||
logger := baseLogger.With(slogutil.KeyPrefix, "webapi")
|
||||
|
||||
var clientFS fs.FS
|
||||
if opts.localFrontend {
|
||||
log.Info("warning: using local frontend files")
|
||||
logger.WarnContext(ctx, "using local frontend files")
|
||||
|
||||
clientFS = os.DirFS("build/static")
|
||||
} else {
|
||||
@@ -543,10 +545,12 @@ func initWeb(
|
||||
}
|
||||
}
|
||||
|
||||
disableUpdate := !isUpdateEnabled(ctx, l, &opts, customURL)
|
||||
disableUpdate := !isUpdateEnabled(ctx, baseLogger, &opts, customURL)
|
||||
|
||||
webConf := &webConfig{
|
||||
updater: upd,
|
||||
updater: upd,
|
||||
logger: logger,
|
||||
baseLogger: baseLogger,
|
||||
|
||||
clientFS: clientFS,
|
||||
|
||||
@@ -562,7 +566,7 @@ func initWeb(
|
||||
serveHTTP3: config.DNS.ServeHTTP3,
|
||||
}
|
||||
|
||||
web = newWebAPI(webConf, l)
|
||||
web = newWebAPI(ctx, webConf)
|
||||
if web == nil {
|
||||
return nil, errors.Error("can not initialize web")
|
||||
}
|
||||
@@ -640,7 +644,7 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
|
||||
fatalOnError(err)
|
||||
|
||||
if config.HTTPConfig.Pprof.Enabled {
|
||||
startPprof(config.HTTPConfig.Pprof.Port)
|
||||
startPprof(slogLogger, config.HTTPConfig.Pprof.Port)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -692,7 +696,7 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
|
||||
checkPermissions(ctx, slogLogger, Context.workDir, confPath, dataDir, statsDir, querylogDir)
|
||||
}
|
||||
|
||||
Context.web.start()
|
||||
Context.web.start(ctx)
|
||||
|
||||
// Wait for other goroutines to complete their job.
|
||||
<-done
|
||||
@@ -783,7 +787,7 @@ func initUsers() (auth *Auth, err error) {
|
||||
|
||||
trustedProxies := netutil.SliceSubnetSet(netutil.UnembedPrefixes(config.DNS.TrustedProxies))
|
||||
|
||||
sessionTTL := config.HTTPConfig.SessionTTL.Seconds()
|
||||
sessionTTL := time.Duration(config.HTTPConfig.SessionTTL).Seconds()
|
||||
auth = InitAuth(sessFilename, config.Users, uint32(sessionTTL), rateLimiter, trustedProxies)
|
||||
if auth == nil {
|
||||
return nil, errors.Error("initializing auth module failed")
|
||||
@@ -803,15 +807,15 @@ func (c *configuration) anonymizer() (ipmut *aghnet.IPMut) {
|
||||
return aghnet.NewIPMut(anonFunc)
|
||||
}
|
||||
|
||||
// startMods initializes and starts the DNS server after installation. l must
|
||||
// not be nil.
|
||||
func startMods(l *slog.Logger) (err error) {
|
||||
// startMods initializes and starts the DNS server after installation.
|
||||
// baseLogger must not be nil.
|
||||
func startMods(baseLogger *slog.Logger) (err error) {
|
||||
statsDir, querylogDir, err := checkStatsAndQuerylogDirs(&Context, config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = initDNS(l, statsDir, querylogDir)
|
||||
err = initDNS(baseLogger, statsDir, querylogDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -984,7 +988,7 @@ func loadCmdLineOpts() (opts options) {
|
||||
exitWithError()
|
||||
}
|
||||
|
||||
os.Exit(0)
|
||||
os.Exit(osutil.ExitCodeSuccess)
|
||||
}
|
||||
|
||||
return opts
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/configmigrate"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/osutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
)
|
||||
|
||||
@@ -329,7 +330,7 @@ var cmdLineOpts = []cmdLineOpt{{
|
||||
fmt.Println(version.Full())
|
||||
}
|
||||
|
||||
os.Exit(0)
|
||||
os.Exit(osutil.ExitCodeSuccess)
|
||||
|
||||
return nil
|
||||
}, nil
|
||||
|
||||
@@ -3,6 +3,7 @@ package home
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
@@ -15,9 +16,11 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/updater"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/netutil/httputil"
|
||||
"github.com/AdguardTeam/golibs/netutil/urlutil"
|
||||
"github.com/AdguardTeam/golibs/osutil"
|
||||
"github.com/NYTimes/gziphandler"
|
||||
"github.com/quic-go/quic-go/http3"
|
||||
"golang.org/x/net/http2"
|
||||
@@ -39,6 +42,13 @@ const (
|
||||
type webConfig struct {
|
||||
updater *updater.Updater
|
||||
|
||||
// logger is a slog logger used in webAPI. It must not be nil.
|
||||
logger *slog.Logger
|
||||
|
||||
// baseLogger is used to create loggers for other entities. It must not be
|
||||
// nil.
|
||||
baseLogger *slog.Logger
|
||||
|
||||
clientFS fs.FS
|
||||
|
||||
// BindAddr is the binding address with port for plain HTTP web interface.
|
||||
@@ -94,21 +104,26 @@ type webAPI struct {
|
||||
// logger is a slog logger used in webAPI. It must not be nil.
|
||||
logger *slog.Logger
|
||||
|
||||
// baseLogger is used to create loggers for other entities. It must not be
|
||||
// nil.
|
||||
baseLogger *slog.Logger
|
||||
|
||||
// httpsServer is the server that handles HTTPS traffic. If it is not nil,
|
||||
// [Web.http3Server] must also not be nil.
|
||||
httpsServer httpsServer
|
||||
}
|
||||
|
||||
// newWebAPI creates a new instance of the web UI and API server. l must not be
|
||||
// nil.
|
||||
// newWebAPI creates a new instance of the web UI and API server. conf must be
|
||||
// valid.
|
||||
//
|
||||
// TODO(a.garipov): Return a proper error.
|
||||
func newWebAPI(conf *webConfig, l *slog.Logger) (w *webAPI) {
|
||||
log.Info("web: initializing")
|
||||
func newWebAPI(ctx context.Context, conf *webConfig) (w *webAPI) {
|
||||
conf.logger.InfoContext(ctx, "initializing")
|
||||
|
||||
w = &webAPI{
|
||||
conf: conf,
|
||||
logger: l,
|
||||
conf: conf,
|
||||
logger: conf.logger,
|
||||
baseLogger: conf.baseLogger,
|
||||
}
|
||||
|
||||
clientFS := http.FileServer(http.FS(conf.clientFS))
|
||||
@@ -118,7 +133,11 @@ func newWebAPI(conf *webConfig, l *slog.Logger) (w *webAPI) {
|
||||
|
||||
// add handlers for /install paths, we only need them when we're not configured yet
|
||||
if conf.firstRun {
|
||||
log.Info("This is the first launch of AdGuard Home, redirecting everything to /install.html ")
|
||||
conf.logger.InfoContext(
|
||||
ctx,
|
||||
"This is the first launch of AdGuard Home, redirecting everything to /install.html",
|
||||
)
|
||||
|
||||
Context.mux.Handle("/install.html", preInstallHandler(clientFS))
|
||||
w.registerInstallHandlers()
|
||||
} else {
|
||||
@@ -154,7 +173,9 @@ func webCheckPortAvailable(port uint16) (ok bool) {
|
||||
// tlsConfigChanged updates the TLS configuration and restarts the HTTPS server
|
||||
// if necessary.
|
||||
func (web *webAPI) tlsConfigChanged(ctx context.Context, tlsConf tlsConfigSettings) {
|
||||
log.Debug("web: applying new tls configuration")
|
||||
defer slogutil.RecoverAndExit(ctx, web.logger, osutil.ExitCodeFailure)
|
||||
|
||||
web.logger.DebugContext(ctx, "applying new tls configuration")
|
||||
|
||||
enabled := tlsConf.Enabled &&
|
||||
tlsConf.PortHTTPS != 0 &&
|
||||
@@ -165,7 +186,7 @@ func (web *webAPI) tlsConfigChanged(ctx context.Context, tlsConf tlsConfigSettin
|
||||
if enabled {
|
||||
cert, err = tls.X509KeyPair(tlsConf.CertificateChainData, tlsConf.PrivateKeyData)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -173,8 +194,8 @@ func (web *webAPI) tlsConfigChanged(ctx context.Context, tlsConf tlsConfigSettin
|
||||
if web.httpsServer.server != nil {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, shutdownTimeout)
|
||||
shutdownSrv(ctx, web.httpsServer.server)
|
||||
shutdownSrv3(web.httpsServer.server3)
|
||||
shutdownSrv(ctx, web.logger, web.httpsServer.server)
|
||||
shutdownSrv3(ctx, web.logger, web.httpsServer.server3)
|
||||
|
||||
cancel()
|
||||
}
|
||||
@@ -185,12 +206,17 @@ func (web *webAPI) tlsConfigChanged(ctx context.Context, tlsConf tlsConfigSettin
|
||||
web.httpsServer.cond.L.Unlock()
|
||||
}
|
||||
|
||||
// loggerKeyServer is the key used by [webAPI] to identify servers.
|
||||
const loggerKeyServer = "server"
|
||||
|
||||
// start - start serving HTTP requests
|
||||
func (web *webAPI) start() {
|
||||
log.Println("AdGuard Home is available at the following addresses:")
|
||||
func (web *webAPI) start(ctx context.Context) {
|
||||
defer slogutil.RecoverAndExit(ctx, web.logger, osutil.ExitCodeFailure)
|
||||
|
||||
web.logger.InfoContext(ctx, "AdGuard Home is available at the following addresses:")
|
||||
|
||||
// for https, we have a separate goroutine loop
|
||||
go web.tlsServerLoop()
|
||||
go web.tlsServerLoop(ctx)
|
||||
|
||||
// this loop is used as an ability to change listening host and/or port
|
||||
for !web.httpsServer.inShutdown {
|
||||
@@ -200,17 +226,19 @@ func (web *webAPI) start() {
|
||||
// Use an h2c handler to support unencrypted HTTP/2, e.g. for proxies.
|
||||
hdlr := h2c.NewHandler(withMiddlewares(Context.mux, limitRequestBody), &http2.Server{})
|
||||
|
||||
logger := web.baseLogger.With(loggerKeyServer, "plain")
|
||||
|
||||
// Create a new instance, because the Web is not usable after Shutdown.
|
||||
web.httpServer = &http.Server{
|
||||
ErrorLog: log.StdLog("web: plain", log.DEBUG),
|
||||
Addr: web.conf.BindAddr.String(),
|
||||
Handler: hdlr,
|
||||
ReadTimeout: web.conf.ReadTimeout,
|
||||
ReadHeaderTimeout: web.conf.ReadHeaderTimeout,
|
||||
WriteTimeout: web.conf.WriteTimeout,
|
||||
ErrorLog: slog.NewLogLogger(logger.Handler(), slog.LevelError),
|
||||
}
|
||||
go func() {
|
||||
defer log.OnPanic("web: plain")
|
||||
defer slogutil.RecoverAndLog(ctx, web.logger)
|
||||
|
||||
errs <- web.httpServer.ListenAndServe()
|
||||
}()
|
||||
@@ -218,7 +246,7 @@ func (web *webAPI) start() {
|
||||
err := <-errs
|
||||
if !errors.Is(err, http.ErrServerClosed) {
|
||||
cleanupAlways()
|
||||
log.Fatal(err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// We use ErrServerClosed as a sign that we need to rebind on a new
|
||||
@@ -228,7 +256,7 @@ func (web *webAPI) start() {
|
||||
|
||||
// close gracefully shuts down the HTTP servers.
|
||||
func (web *webAPI) close(ctx context.Context) {
|
||||
log.Info("stopping http server...")
|
||||
web.logger.InfoContext(ctx, "stopping http server")
|
||||
|
||||
web.httpsServer.cond.L.Lock()
|
||||
web.httpsServer.inShutdown = true
|
||||
@@ -238,14 +266,16 @@ func (web *webAPI) close(ctx context.Context) {
|
||||
ctx, cancel = context.WithTimeout(ctx, shutdownTimeout)
|
||||
defer cancel()
|
||||
|
||||
shutdownSrv(ctx, web.httpsServer.server)
|
||||
shutdownSrv3(web.httpsServer.server3)
|
||||
shutdownSrv(ctx, web.httpServer)
|
||||
shutdownSrv(ctx, web.logger, web.httpsServer.server)
|
||||
shutdownSrv3(ctx, web.logger, web.httpsServer.server3)
|
||||
shutdownSrv(ctx, web.logger, web.httpServer)
|
||||
|
||||
log.Info("stopped http server")
|
||||
web.logger.InfoContext(ctx, "stopped http server")
|
||||
}
|
||||
|
||||
func (web *webAPI) tlsServerLoop() {
|
||||
func (web *webAPI) tlsServerLoop(ctx context.Context) {
|
||||
defer slogutil.RecoverAndExit(ctx, web.logger, osutil.ExitCodeFailure)
|
||||
|
||||
for {
|
||||
web.httpsServer.cond.L.Lock()
|
||||
if web.httpsServer.inShutdown {
|
||||
@@ -273,38 +303,40 @@ func (web *webAPI) tlsServerLoop() {
|
||||
}()
|
||||
|
||||
addr := netip.AddrPortFrom(web.conf.BindAddr.Addr(), portHTTPS).String()
|
||||
logger := web.baseLogger.With(loggerKeyServer, "https")
|
||||
|
||||
web.httpsServer.server = &http.Server{
|
||||
ErrorLog: log.StdLog("web: https", log.DEBUG),
|
||||
Addr: addr,
|
||||
Addr: addr,
|
||||
Handler: withMiddlewares(Context.mux, limitRequestBody),
|
||||
TLSConfig: &tls.Config{
|
||||
Certificates: []tls.Certificate{web.httpsServer.cert},
|
||||
RootCAs: Context.tlsRoots,
|
||||
CipherSuites: Context.tlsCipherIDs,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
},
|
||||
Handler: withMiddlewares(Context.mux, limitRequestBody),
|
||||
ReadTimeout: web.conf.ReadTimeout,
|
||||
ReadHeaderTimeout: web.conf.ReadHeaderTimeout,
|
||||
WriteTimeout: web.conf.WriteTimeout,
|
||||
ErrorLog: slog.NewLogLogger(logger.Handler(), slog.LevelError),
|
||||
}
|
||||
|
||||
printHTTPAddresses(urlutil.SchemeHTTPS)
|
||||
|
||||
if web.conf.serveHTTP3 {
|
||||
go web.mustStartHTTP3(addr)
|
||||
go web.mustStartHTTP3(ctx, addr)
|
||||
}
|
||||
|
||||
log.Debug("web: starting https server")
|
||||
web.logger.DebugContext(ctx, "starting https server")
|
||||
err := web.httpsServer.server.ListenAndServeTLS("", "")
|
||||
if !errors.Is(err, http.ErrServerClosed) {
|
||||
cleanupAlways()
|
||||
log.Fatalf("web: https: %s", err)
|
||||
panic(fmt.Errorf("https: %w", err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (web *webAPI) mustStartHTTP3(address string) {
|
||||
defer log.OnPanic("web: http3")
|
||||
func (web *webAPI) mustStartHTTP3(ctx context.Context, address string) {
|
||||
defer slogutil.RecoverAndExit(ctx, web.logger, osutil.ExitCodeFailure)
|
||||
|
||||
web.httpsServer.server3 = &http3.Server{
|
||||
// TODO(a.garipov): See if there is a way to use the error log as
|
||||
@@ -319,16 +351,16 @@ func (web *webAPI) mustStartHTTP3(address string) {
|
||||
Handler: withMiddlewares(Context.mux, limitRequestBody),
|
||||
}
|
||||
|
||||
log.Debug("web: starting http/3 server")
|
||||
web.logger.DebugContext(ctx, "starting http/3 server")
|
||||
err := web.httpsServer.server3.ListenAndServe()
|
||||
if !errors.Is(err, http.ErrServerClosed) {
|
||||
cleanupAlways()
|
||||
log.Fatalf("web: http3: %s", err)
|
||||
panic(fmt.Errorf("http3: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
// startPprof launches the debug and profiling server on the provided port.
|
||||
func startPprof(port uint16) {
|
||||
func startPprof(baseLogger *slog.Logger, port uint16) {
|
||||
addr := netip.AddrPortFrom(netutil.IPv4Localhost(), port)
|
||||
|
||||
runtime.SetBlockProfileRate(1)
|
||||
@@ -337,13 +369,16 @@ func startPprof(port uint16) {
|
||||
mux := http.NewServeMux()
|
||||
httputil.RoutePprof(mux)
|
||||
|
||||
go func() {
|
||||
defer log.OnPanic("pprof server")
|
||||
ctx := context.Background()
|
||||
logger := baseLogger.With(slogutil.KeyPrefix, "pprof")
|
||||
|
||||
log.Info("pprof: listening on %q", addr)
|
||||
go func() {
|
||||
defer slogutil.RecoverAndLog(ctx, logger)
|
||||
|
||||
logger.InfoContext(ctx, "listening", "addr", addr)
|
||||
err := http.ListenAndServe(addr.String(), mux)
|
||||
if !errors.Is(err, http.ErrServerClosed) {
|
||||
log.Error("pprof: shutting down: %s", err)
|
||||
logger.ErrorContext(ctx, "shutting down", slogutil.KeyError, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user