diff --git a/internal/dnsforward/config.go b/internal/dnsforward/config.go index 5fe2be3d..3f5ff71e 100644 --- a/internal/dnsforward/config.go +++ b/internal/dnsforward/config.go @@ -648,17 +648,17 @@ func (s *Server) onGetCertificate(ch *tls.ClientHelloInfo) (*tls.Certificate, er // UpdatedProtectionStatus updates protection state, if the protection was // disabled temporarily. Returns the updated state of protection. -func (s *Server) UpdatedProtectionStatus() (enabled bool) { +func (s *Server) UpdatedProtectionStatus() (enabled bool, disabledUntil *time.Time) { s.serverLock.RLock() defer s.serverLock.RUnlock() - disabledUntil := s.conf.ProtectionDisabledUntil + disabledUntil = s.conf.ProtectionDisabledUntil if disabledUntil == nil { - return s.conf.ProtectionEnabled + return s.conf.ProtectionEnabled, nil } if time.Now().Before(*disabledUntil) { - return false + return false, disabledUntil } // Update the values in a separate goroutine, unless an update is already in @@ -671,7 +671,7 @@ func (s *Server) UpdatedProtectionStatus() (enabled bool) { go s.enableProtectionAfterPause() } - return true + return true, nil } // enableProtectionAfterPause sets the protection configuration to enabled diff --git a/internal/dnsforward/dns.go b/internal/dnsforward/dns.go index f8ce7d6f..05ed0d1e 100644 --- a/internal/dnsforward/dns.go +++ b/internal/dnsforward/dns.go @@ -206,7 +206,7 @@ func (s *Server) processInitial(dctx *dnsContext) (rc resultCode) { dctx.clientID = string(s.clientIDCache.Get(key[:])) // Get the client-specific filtering settings. - dctx.protectionEnabled = s.UpdatedProtectionStatus() + dctx.protectionEnabled, _ = s.UpdatedProtectionStatus() dctx.setts = s.getClientRequestFilteringSettings(dctx) return resultCodeSuccess @@ -460,7 +460,7 @@ func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) { } // indexFirstV4Label returns the index at which the reversed IPv4 address -// starts, assuiming the domain is pre-validated ARPA domain having in-addr and +// starts, assuming the domain is pre-validated ARPA domain having in-addr and // arpa labels removed. func indexFirstV4Label(domain string) (idx int) { idx = len(domain) @@ -478,7 +478,7 @@ func indexFirstV4Label(domain string) (idx int) { } // indexFirstV6Label returns the index at which the reversed IPv6 address -// starts, assuiming the domain is pre-validated ARPA domain having ip6 and arpa +// starts, assuming the domain is pre-validated ARPA domain having ip6 and arpa // labels removed. func indexFirstV6Label(domain string) (idx int) { idx = len(domain) diff --git a/internal/dnsforward/http.go b/internal/dnsforward/http.go index d7e238b4..3cf28e4b 100644 --- a/internal/dnsforward/http.go +++ b/internal/dnsforward/http.go @@ -101,7 +101,7 @@ type jsonDNSConfig struct { } func (s *Server) getDNSConfig() (c *jsonDNSConfig) { - protectionEnabled := s.UpdatedProtectionStatus() + protectionEnabled, protectionDisabledUntil := s.UpdatedProtectionStatus() s.serverLock.RLock() defer s.serverLock.RUnlock() @@ -128,12 +128,6 @@ func (s *Server) getDNSConfig() (c *jsonDNSConfig) { usePrivateRDNS := s.conf.UsePrivateRDNS localPTRUpstreams := stringutil.CloneSliceOrEmpty(s.conf.LocalPTRResolvers) - var disabledUntil *time.Time - if s.conf.ProtectionDisabledUntil != nil { - t := *s.conf.ProtectionDisabledUntil - disabledUntil = &t - } - var upstreamMode string if s.conf.FastestAddr { upstreamMode = "fastest_addr" @@ -169,7 +163,7 @@ func (s *Server) getDNSConfig() (c *jsonDNSConfig) { UsePrivateRDNS: &usePrivateRDNS, LocalPTRUpstreams: &localPTRUpstreams, DefaultLocalPTRUpstreams: defLocalPTRUps, - DisabledUntil: disabledUntil, + DisabledUntil: protectionDisabledUntil, } } diff --git a/internal/home/control.go b/internal/home/control.go index 3f654b3c..d8a9f7d8 100644 --- a/internal/home/control.go +++ b/internal/home/control.go @@ -15,6 +15,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/version" "github.com/AdguardTeam/golibs/httphdr" "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/mathutil" "github.com/AdguardTeam/golibs/netutil" "github.com/NYTimes/gziphandler" ) @@ -98,14 +99,17 @@ func collectDNSAddresses() (addrs []string, err error) { // statusResponse is a response for /control/status endpoint. type statusResponse struct { - Version string `json:"version"` - Language string `json:"language"` - DNSAddrs []string `json:"dns_addresses"` - DNSPort int `json:"dns_port"` - HTTPPort int `json:"http_port"` - IsProtectionEnabled bool `json:"protection_enabled"` - // ProtectionDisabledDuration is a pause duration in milliseconds. + Version string `json:"version"` + Language string `json:"language"` + DNSAddrs []string `json:"dns_addresses"` + DNSPort int `json:"dns_port"` + HTTPPort int `json:"http_port"` + + // ProtectionDisabledDuration is the duration of the protection pause in + // milliseconds. ProtectionDisabledDuration int64 `json:"protection_disabled_duration"` + + ProtectionEnabled bool `json:"protection_enabled"` // TODO(e.burkov): Inspect if front-end doesn't requires this field as // openapi.yaml declares. IsDHCPAvailable bool `json:"dhcp_available"` @@ -122,12 +126,15 @@ func handleStatus(w http.ResponseWriter, r *http.Request) { return } - isProtectionEnabled := false - var c *dnsforward.FilteringConfig + var ( + fltConf *dnsforward.FilteringConfig + protectionDisabledUntil *time.Time + protectionEnabled bool + ) if Context.dnsServer != nil { - c = &dnsforward.FilteringConfig{} - Context.dnsServer.WriteDiskConfig(c) - isProtectionEnabled = Context.dnsServer.UpdatedProtectionStatus() + fltConf = &dnsforward.FilteringConfig{} + Context.dnsServer.WriteDiskConfig(fltConf) + protectionEnabled, protectionDisabledUntil = Context.dnsServer.UpdatedProtectionStatus() } var resp statusResponse @@ -135,20 +142,26 @@ func handleStatus(w http.ResponseWriter, r *http.Request) { config.RLock() defer config.RUnlock() - var pauseDuration int64 - if until := config.DNS.ProtectionDisabledUntil; until != nil { - pauseDuration = time.Until(*until).Milliseconds() + var protectionDisabledDuration int64 + if protectionDisabledUntil != nil { + // Make sure that we don't send negative numbers to the frontend, + // since enough time might have passed to make the difference less + // than zero. + protectionDisabledDuration = mathutil.Max( + 0, + time.Until(*protectionDisabledUntil).Milliseconds(), + ) } resp = statusResponse{ Version: version.Version(), + Language: config.Language, DNSAddrs: dnsAddrs, DNSPort: config.DNS.Port, HTTPPort: config.BindPort, - Language: config.Language, + ProtectionDisabledDuration: protectionDisabledDuration, + ProtectionEnabled: protectionEnabled, IsRunning: isRunning(), - ProtectionDisabledDuration: pauseDuration, - IsProtectionEnabled: isProtectionEnabled, } }()