From 1a3853d52af307509c2b9f7081f4b568d9c4c837 Mon Sep 17 00:00:00 2001 From: Stanislav Chzhen Date: Fri, 14 Mar 2025 13:51:45 +0300 Subject: [PATCH] Pull request 2353: AGDNS-2688-check-host Merge in DNS/adguard-home from AGDNS-2688-check-host to master Squashed commit of the following: commit bd9ed498b0e36fa044e6921fa946062ac40fe616 Merge: 8dffd94a3 c41af2763 Author: Eugene Burkov Date: Fri Mar 14 13:42:34 2025 +0300 Merge branch 'master' into AGDNS-2688-check-host commit 8dffd94a3bc700cf014cbb16aee9c6339bdc7ffa Author: Stanislav Chzhen Date: Wed Mar 12 17:12:56 2025 +0300 filtering: imp code commit d9a01c8fa60c70e3fd19c40c1a58aec00ae64a6a Author: Stanislav Chzhen Date: Tue Mar 11 20:33:18 2025 +0300 all: imp code commit f1aca5f2eb71a1d8bb155a309c618e7a80f8fde5 Author: Ildar Kamalov Date: Tue Mar 11 16:05:32 2025 +0300 ADG-9783 update check form commit a8ebb0401dbaa08fdd04171b1ac66b87d0228c7b Author: Stanislav Chzhen Date: Mon Mar 10 16:41:55 2025 +0300 dnsforward: imp docs commit 36f5db9075cc525c13905e0318dfbc4089355523 Merge: 9a746ee9a 66fba942c Author: Stanislav Chzhen Date: Mon Mar 10 16:09:22 2025 +0300 Merge branch 'master' into AGDNS-2688-check-host commit 9a746ee9a05895676a60980eb4bd1381fe8d8e4b Author: Stanislav Chzhen Date: Mon Mar 10 16:06:48 2025 +0300 all: imp docs commit 0a25e1e8f3536053e30049497bb42a58c6a153d6 Author: Stanislav Chzhen Date: Thu Mar 6 21:48:44 2025 +0300 all: imp code commit ec618bc484190dde52a0dc57d144bade8dfc22e2 Author: Stanislav Chzhen Date: Thu Mar 6 17:38:35 2025 +0300 all: imp code commit 979c5cfd4c34e2aac46ea11b7fcba8d2929966b8 Author: Stanislav Chzhen Date: Wed Mar 5 21:22:54 2025 +0300 all: add tests commit ce0d6117ad7f341edcc018a68acedaa0b718bef1 Author: Stanislav Chzhen Date: Tue Mar 4 15:13:06 2025 +0300 all: check host --- CHANGELOG.md | 5 + client/src/__locales/en.json | 4 + client/src/components/Filters/Check/index.tsx | 70 ++++-- client/src/components/Filters/CustomRules.tsx | 16 +- client/src/helpers/constants.ts | 9 + internal/client/storage.go | 36 +++ internal/dnsforward/config.go | 4 - .../dnsforward/dnsforward_internal_test.go | 49 ++++- internal/dnsforward/filter.go | 4 +- internal/dnsforward/filter_internal_test.go | 6 +- internal/filtering/blocked.go | 3 + internal/filtering/filter.go | 17 ++ internal/filtering/filtering.go | 19 ++ internal/filtering/http.go | 73 ++++++- internal/filtering/http_internal_test.go | 169 ++++++++++++++ internal/filtering/safesearch.go | 3 + internal/home/clients.go | 2 + internal/home/dns.go | 52 ----- internal/home/dns_internal_test.go | 206 ------------------ openapi/CHANGELOG.md | 6 + openapi/openapi.yaml | 14 ++ 21 files changed, 467 insertions(+), 300 deletions(-) delete mode 100644 internal/home/dns_internal_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 3755ae41..162fa516 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,11 @@ NOTE: Add new changes BELOW THIS COMMENT. - Go version has been updated to prevent the possibility of exploiting the Go vulnerabilities fixed in [1.24.1][go-1.24.1]. +### Added + +- The ability to check filtering rules for host names using an optional query type and optional ClientID or client IP address. +- Optional `client` and `qtype` URL query parameters to the `GET /control/check_host` HTTP API. + ### Fixed - Clearing the DNS cache on the *DNS settings* page now includes both global cache and custom client cache. diff --git a/client/src/__locales/en.json b/client/src/__locales/en.json index 9634cf70..67c9f55f 100644 --- a/client/src/__locales/en.json +++ b/client/src/__locales/en.json @@ -620,6 +620,10 @@ "check_cname": "CNAME: {{cname}}", "check_reason": "Reason: {{reason}}", "check_service": "Service name: {{service}}", + "check_hostname": "Hostname or domain name", + "check_client_id": "Client identifier (ClientID or IP address)", + "check_enter_client_id": "Enter client identifier", + "check_dns_record": "Select DNS record type", "service_name": "Service name", "check_not_found": "Not found in your filter lists", "client_confirm_block": "Are you sure you want to block the client \"{{ip}}\"?", diff --git a/client/src/components/Filters/Check/index.tsx b/client/src/components/Filters/Check/index.tsx index 996d6b75..2f578ade 100644 --- a/client/src/components/Filters/Check/index.tsx +++ b/client/src/components/Filters/Check/index.tsx @@ -9,13 +9,17 @@ import Info from './Info'; import { RootState } from '../../../initialState'; import { validateRequiredValue } from '../../../helpers/validators'; import { Input } from '../../ui/Controls/Input'; +import { DNS_RECORD_TYPES } from '../../../helpers/constants'; +import { Select } from '../../ui/Controls/Select'; -interface FormValues { +export type FilteringCheckFormValues = { name: string; + client?: string; + qtype?: string; } type Props = { - onSubmit?: (data: FormValues) => void; + onSubmit?: (data: FilteringCheckFormValues) => void; }; const Check = ({ onSubmit }: Props) => { @@ -27,11 +31,13 @@ const Check = ({ onSubmit }: Props) => { const { control, handleSubmit, - formState: { isDirty, isValid }, - } = useForm({ + formState: { isValid }, + } = useForm({ mode: 'onBlur', defaultValues: { name: '', + client: '', + qtype: DNS_RECORD_TYPES[0], }, }); @@ -48,24 +54,56 @@ const Check = ({ onSubmit }: Props) => { - - - } /> )} /> + ( + + )} + /> + + ( + + )} + /> + + + {hostname && ( <>
diff --git a/client/src/components/Filters/CustomRules.tsx b/client/src/components/Filters/CustomRules.tsx index db7fa92d..cc378e99 100644 --- a/client/src/components/Filters/CustomRules.tsx +++ b/client/src/components/Filters/CustomRules.tsx @@ -7,7 +7,7 @@ import PageTitle from '../ui/PageTitle'; import Examples from './Examples'; -import Check from './Check'; +import Check, { FilteringCheckFormValues } from './Check'; import { getTextareaCommentsHighlight, syncScroll } from '../../helpers/highlightTextareaComments'; import { COMMENT_LINE_DEFAULT_TOKEN } from '../../helpers/constants'; @@ -48,8 +48,18 @@ class CustomRules extends Component { this.props.setRules(this.props.filtering.userRules); }; - handleCheck = (values: any) => { - this.props.checkHost(values); + handleCheck = (values: FilteringCheckFormValues) => { + const params: FilteringCheckFormValues = { name: values.name }; + + if (values.client) { + params.client = values.client; + } + + if (values.qtype) { + params.qtype = values.qtype; + } + + this.props.checkHost(params); }; onScroll = (e: any) => syncScroll(e, this.ref); diff --git a/client/src/helpers/constants.ts b/client/src/helpers/constants.ts index 7b54ae8d..d4e7c940 100644 --- a/client/src/helpers/constants.ts +++ b/client/src/helpers/constants.ts @@ -523,3 +523,12 @@ export const TIME_UNITS = { HOURS: 'hours', DAYS: 'days', }; + +export const DNS_RECORD_TYPES = [ + "A", "AAAA", "AFSDB", "APL", "CAA", "CDNSKEY", "CDS", "CERT", "CNAME", + "CSYNC", "DHCID", "DLV", "DNAME", "DNSKEY", "DS", "EUI48", "EUI64", + "HINFO", "HIP", "HTTPS", "IPSECKEY", "KEY", "KX", "LOC", "MX", "NAPTR", + "NS", "NSEC", "NSEC3", "NSEC3PARAM", "OPENPGPKEY", "PTR", "RP", "RRSIG", + "SIG", "SMIMEA", "SOA", "SRV", "SSHFP", "SVCB", "TA", "TKEY", + "TLSA", "TSIG", "TXT", "URI", "ZONEMD" +]; diff --git a/internal/client/storage.go b/internal/client/storage.go index e3bf58e0..3bb839b4 100644 --- a/internal/client/storage.go +++ b/internal/client/storage.go @@ -12,6 +12,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/arpdb" "github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc" + "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/whois" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/golibs/errors" @@ -671,3 +672,38 @@ func (s *Storage) ClearUpstreamCache() { s.upstreamManager.clearUpstreamCache() } + +// ApplyClientFiltering retrieves persistent client information using the +// ClientID or client IP address, and applies it to the filtering settings. +func (s *Storage) ApplyClientFiltering(id string, addr netip.Addr, setts *filtering.Settings) { + c, ok := s.index.findByClientID(id) + if !ok { + c, ok = s.index.findByIP(addr) + } + + if !ok { + s.logger.Debug("no client filtering settings found", "clientid", id, "addr", addr) + + return + } + + s.logger.Debug("applying custom client filtering settings", "client_name", c.Name) + + setts.ClientIP = addr + + if c.UseOwnBlockedServices { + setts.BlockedServices = c.BlockedServices.Clone() + } + + setts.ClientName = c.Name + setts.ClientTags = slices.Clone(c.Tags) + if !c.UseOwnSettings { + return + } + + setts.FilteringEnabled = c.FilteringEnabled + setts.SafeSearchEnabled = c.SafeSearchConf.Enabled + setts.ClientSafeSearch = c.SafeSearch + setts.SafeBrowsingEnabled = c.SafeBrowsingEnabled + setts.ParentalEnabled = c.ParentalEnabled +} diff --git a/internal/dnsforward/config.go b/internal/dnsforward/config.go index c549a07e..37e8ce98 100644 --- a/internal/dnsforward/config.go +++ b/internal/dnsforward/config.go @@ -16,7 +16,6 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghtls" "github.com/AdguardTeam/AdGuardHome/internal/client" - "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/container" @@ -34,9 +33,6 @@ import ( type Config struct { // Callbacks for other modules - // FilterHandler is an optional additional filtering callback. - FilterHandler func(cliAddr netip.Addr, clientID string, settings *filtering.Settings) `yaml:"-"` - // ClientsContainer stores the information about special handling of some // DNS clients. ClientsContainer ClientsContainer `yaml:"-"` diff --git a/internal/dnsforward/dnsforward_internal_test.go b/internal/dnsforward/dnsforward_internal_test.go index 0ced288d..5f035f84 100644 --- a/internal/dnsforward/dnsforward_internal_test.go +++ b/internal/dnsforward/dnsforward_internal_test.go @@ -27,6 +27,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering/hashprefix" "github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch" + "github.com/AdguardTeam/AdGuardHome/internal/schedule" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/logutil/slogutil" @@ -106,6 +107,21 @@ func startDeferStop(t *testing.T, s *Server) { testutil.CleanupAndRequireSuccess(t, s.Stop) } +// applyEmptyClientFiltering is a helper function for tests with +// [filtering.Config] that does nothing. +func applyEmptyClientFiltering(_ string, _ netip.Addr, _ *filtering.Settings) {} + +// emptyFilteringBlockedServices is a helper function that returns an empty +// filtering blocked services for tests. +func emptyFilteringBlockedServices() (bsvc *filtering.BlockedServices) { + return &filtering.BlockedServices{ + Schedule: schedule.EmptyWeekly(), + } +} + +// createTestServer is a helper function that returns a properly initialized +// *Server for use in tests, given the provided parameters. It also populates +// the filtering configuration with default parameters. func createTestServer( t *testing.T, filterConf *filtering.Config, @@ -123,6 +139,12 @@ func createTestServer( Data: []byte(rules), }} + filterConf.BlockedServices = cmp.Or(filterConf.BlockedServices, emptyFilteringBlockedServices()) + + if filterConf.ApplyClientFiltering == nil { + filterConf.ApplyClientFiltering = applyEmptyClientFiltering + } + f, err := filtering.New(filterConf, filters) require.NoError(t, err) @@ -926,9 +948,6 @@ func TestClientRulesForCNAMEMatching(t *testing.T) { UDPListenAddrs: []*net.UDPAddr{{}}, TCPListenAddrs: []*net.TCPAddr{{}}, Config: Config{ - FilterHandler: func(_ netip.Addr, _ string, settings *filtering.Settings) { - settings.FilteringEnabled = false - }, UpstreamMode: UpstreamModeLoadBalance, EDNSClientSubnet: &EDNSClientSubnet{ Enabled: false, @@ -1020,10 +1039,12 @@ func TestBlockedCustomIP(t *testing.T) { }} f, err := filtering.New(&filtering.Config{ - ProtectionEnabled: true, - BlockingMode: filtering.BlockingModeCustomIP, - BlockingIPv4: netip.Addr{}, - BlockingIPv6: netip.Addr{}, + ProtectionEnabled: true, + ApplyClientFiltering: applyEmptyClientFiltering, + BlockedServices: emptyFilteringBlockedServices(), + BlockingMode: filtering.BlockingModeCustomIP, + BlockingIPv4: netip.Addr{}, + BlockingIPv6: netip.Addr{}, }, filters) require.NoError(t, err) @@ -1176,7 +1197,9 @@ func TestBlockedBySafeBrowsing(t *testing.T) { func TestRewrite(t *testing.T) { c := &filtering.Config{ - BlockingMode: filtering.BlockingModeDefault, + ApplyClientFiltering: applyEmptyClientFiltering, + BlockedServices: emptyFilteringBlockedServices(), + BlockingMode: filtering.BlockingModeDefault, Rewrites: []*filtering.LegacyRewrite{{ Domain: "test.com", Answer: "1.2.3.4", @@ -1322,7 +1345,9 @@ func TestPTRResponseFromDHCPLeases(t *testing.T) { const localDomain = "lan" flt, err := filtering.New(&filtering.Config{ - BlockingMode: filtering.BlockingModeDefault, + ApplyClientFiltering: applyEmptyClientFiltering, + BlockedServices: emptyFilteringBlockedServices(), + BlockingMode: filtering.BlockingModeDefault, }, nil) require.NoError(t, err) @@ -1411,8 +1436,10 @@ func TestPTRResponseFromHosts(t *testing.T) { }) flt, err := filtering.New(&filtering.Config{ - BlockingMode: filtering.BlockingModeDefault, - EtcHosts: hc, + ApplyClientFiltering: applyEmptyClientFiltering, + BlockedServices: emptyFilteringBlockedServices(), + BlockingMode: filtering.BlockingModeDefault, + EtcHosts: hc, }, nil) require.NoError(t, err) diff --git a/internal/dnsforward/filter.go b/internal/dnsforward/filter.go index f6cd319d..6cfd7bea 100644 --- a/internal/dnsforward/filter.go +++ b/internal/dnsforward/filter.go @@ -17,9 +17,7 @@ import ( func (s *Server) clientRequestFilteringSettings(dctx *dnsContext) (setts *filtering.Settings) { setts = s.dnsFilter.Settings() setts.ProtectionEnabled = dctx.protectionEnabled - if s.conf.FilterHandler != nil { - s.conf.FilterHandler(dctx.proxyCtx.Addr.Addr(), dctx.clientID, setts) - } + s.dnsFilter.ApplyAdditionalFiltering(dctx.proxyCtx.Addr.Addr(), dctx.clientID, setts) return setts } diff --git a/internal/dnsforward/filter_internal_test.go b/internal/dnsforward/filter_internal_test.go index 922213c4..7f1ab293 100644 --- a/internal/dnsforward/filter_internal_test.go +++ b/internal/dnsforward/filter_internal_test.go @@ -45,8 +45,10 @@ func TestHandleDNSRequest_handleDNSRequest(t *testing.T) { }} f, err := filtering.New(&filtering.Config{ - ProtectionEnabled: true, - BlockingMode: filtering.BlockingModeDefault, + ProtectionEnabled: true, + ApplyClientFiltering: applyEmptyClientFiltering, + BlockedServices: emptyFilteringBlockedServices(), + BlockingMode: filtering.BlockingModeDefault, }, filters) require.NoError(t, err) f.SetEnabled(true) diff --git a/internal/filtering/blocked.go b/internal/filtering/blocked.go index aec5855e..ca59a1b8 100644 --- a/internal/filtering/blocked.go +++ b/internal/filtering/blocked.go @@ -49,6 +49,9 @@ func initBlockedServices() { } // BlockedServices is the configuration of blocked services. +// +// TODO(s.chzhen): Move to a higher-level package to allow importing the client +// package into the filtering package. type BlockedServices struct { // Schedule is blocked services schedule for every day of the week. Schedule *schedule.Weekly `json:"schedule" yaml:"schedule"` diff --git a/internal/filtering/filter.go b/internal/filtering/filter.go index 0dd3471c..5242082f 100644 --- a/internal/filtering/filter.go +++ b/internal/filtering/filter.go @@ -4,6 +4,7 @@ import ( "fmt" "io" "net/http" + "net/netip" "os" "path/filepath" "slices" @@ -629,3 +630,19 @@ func (d *DNSFilter) enableFiltersLocked(async bool) { d.SetEnabled(d.conf.FilteringEnabled) } + +// ApplyAdditionalFiltering enhances the provided filtering settings with +// blocked services and client-specific configurations. +func (d *DNSFilter) ApplyAdditionalFiltering(cliAddr netip.Addr, clientID string, setts *Settings) { + d.ApplyBlockedServices(setts) + + d.applyClientFiltering(clientID, cliAddr, setts) + if setts.BlockedServices != nil { + // TODO(e.burkov): Get rid of this crutch. + setts.ServicesRules = nil + svcs := setts.BlockedServices.IDs + if !setts.BlockedServices.Schedule.Contains(time.Now()) { + d.ApplyBlockedServicesList(setts, svcs) + } + } +} diff --git a/internal/filtering/filtering.go b/internal/filtering/filtering.go index 8836515c..d5fd49d5 100644 --- a/internal/filtering/filtering.go +++ b/internal/filtering/filtering.go @@ -40,6 +40,8 @@ type ServiceEntry struct { } // Settings are custom filtering settings for a client. +// +// TODO(s.chzhen): Move to the client package. type Settings struct { ClientName string ClientIP netip.Addr @@ -47,6 +49,10 @@ type Settings struct { ServicesRules []ServiceEntry + // BlockedServices is the configuration of blocked services of a client. It + // is nil if the client does not have any blocked services. + BlockedServices *BlockedServices + ProtectionEnabled bool FilteringEnabled bool SafeSearchEnabled bool @@ -78,6 +84,11 @@ type Config struct { SafeSearch SafeSearch `yaml:"-"` + // ApplyClientFiltering retrieves persistent client information using the + // ClientID or client IP address, and applies it to the filtering settings. + // It must not be nil. + ApplyClientFiltering func(clientID string, cliAddr netip.Addr, setts *Settings) `yaml:"-"` + // BlockedServices is the configuration of blocked services. // Per-client settings can override this configuration. BlockedServices *BlockedServices `yaml:"blocked_services"` @@ -244,6 +255,13 @@ type DNSFilter struct { // parentalControl is the parental control hash-prefix checker. parentalControlChecker Checker + // applyClientFiltering retrieves persistent client information using the + // ClientID or client IP address, and applies it to the filtering settings. + // + // TODO(s.chzhen): Consider finding a better approach while taking an + // import cycle into account. + applyClientFiltering func(clientID string, cliAddr netip.Addr, setts *Settings) + engineLock sync.RWMutex // confMu protects conf. @@ -998,6 +1016,7 @@ func New(c *Config, blockFilters []Filter) (d *DNSFilter, err error) { refreshLock: &sync.Mutex{}, safeBrowsingChecker: c.SafeBrowsingChecker, parentalControlChecker: c.ParentalControlChecker, + applyClientFiltering: c.ApplyClientFiltering, confMu: &sync.RWMutex{}, } diff --git a/internal/filtering/http.go b/internal/filtering/http.go index b99bc253..266a422e 100644 --- a/internal/filtering/http.go +++ b/internal/filtering/http.go @@ -9,6 +9,8 @@ import ( "os" "path/filepath" "slices" + "strconv" + "strings" "sync" "time" @@ -420,15 +422,53 @@ type checkHostResp struct { FilterID rulelist.URLFilterID `json:"filter_id"` } +// handleCheckHost is the handler for the GET /control/filtering/check_host HTTP +// API. func (d *DNSFilter) handleCheckHost(w http.ResponseWriter, r *http.Request) { - host := r.URL.Query().Get("name") + query := r.URL.Query() + host := query.Get("name") + if host == "" { + aghhttp.Error( + r, + w, + http.StatusBadRequest, + `query parameter "name" is required`, + ) + + return + } + + cli := query.Get("client") + qTypeStr := query.Get("qtype") + qType, err := stringToDNSType(qTypeStr) + if err != nil { + aghhttp.Error( + r, + w, + http.StatusUnprocessableEntity, + "bad qtype query parameter: %q", + qTypeStr, + ) + + return + } setts := d.Settings() setts.FilteringEnabled = true setts.ProtectionEnabled = true - d.ApplyBlockedServices(setts) - result, err := d.CheckHost(host, dns.TypeA, setts) + addr, err := netip.ParseAddr(cli) + if err == nil { + setts.ClientIP = addr + d.ApplyAdditionalFiltering(addr, "", setts) + } else if cli != "" { + // TODO(s.chzhen): Set [Settings.ClientName] once urlfilter supports + // multiple client names. This will handle the case when a rule exists + // but the persistent client does not. + d.ApplyAdditionalFiltering(netip.Addr{}, cli, setts) + } + + result, err := d.CheckHost(host, qType, setts) if err != nil { aghhttp.Error( r, @@ -466,6 +506,33 @@ func (d *DNSFilter) handleCheckHost(w http.ResponseWriter, r *http.Request) { aghhttp.WriteJSONResponseOK(w, r, resp) } +// stringToDNSType is a helper function that converts a string to DNS type. If +// the string is empty, it returns the default value [dns.TypeA]. +func stringToDNSType(str string) (qtype uint16, err error) { + if str == "" { + return dns.TypeA, nil + } + + qtype, ok := dns.StringToType[str] + if ok { + return qtype, nil + } + + // typePref is a prefix for DNS types from experimental RFCs. + const typePref = "TYPE" + + if !strings.HasPrefix(str, typePref) { + return 0, errors.ErrBadEnumValue + } + + val, err := strconv.ParseUint(str[len(typePref):], 10, 16) + if err != nil { + return 0, errors.ErrBadEnumValue + } + + return uint16(val), nil +} + // setProtectedBool sets the value of a boolean pointer under a lock. l must // protect the value under ptr. // diff --git a/internal/filtering/http_internal_test.go b/internal/filtering/http_internal_test.go index 8330dac6..a46d5d7b 100644 --- a/internal/filtering/http_internal_test.go +++ b/internal/filtering/http_internal_test.go @@ -3,11 +3,15 @@ package filtering import ( "bytes" "encoding/json" + "fmt" "net/http" "net/http/httptest" + "net/netip" + "strings" "testing" "time" + "github.com/AdguardTeam/AdGuardHome/internal/schedule" "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -305,3 +309,168 @@ func TestDNSFilter_handleParentalStatus(t *testing.T) { }) } } + +func TestDNSFilter_HandleCheckHost(t *testing.T) { + const ( + cliName = "client_name" + cliID = "client_id" + + notFilteredHost = "not.filterd.example" + allowedHost = "allowed.example" + blockedHost = "blocked.example" + cliHost = "client.example" + qTypeHost = "qtype.example" + cliQTypeHost = "cli.qtype.example" + + target = "/control/check_host" + hostFmt = target + "?name=%s" + hostCliFmt = hostFmt + "&client=%s" + hostQTypeFmt = hostFmt + "&qtype=%s" + hostCliQTypeFmt = hostCliFmt + "&qtype=%s" + + allowedRuleFmt = "@@||%s^" + blockedRuleFmt = "||%s^" + blockedRuleCliFmt = blockedRuleFmt + "$client=%s" + blockedRuleQTypeFmt = blockedRuleFmt + "$dnstype=%s" + blockedRuleCliQTypeFmt = blockedRuleCliFmt + ",dnstype=%s" + ) + + var ( + allowedRule = fmt.Sprintf(allowedRuleFmt, allowedHost) + blockedRule = fmt.Sprintf(blockedRuleFmt, blockedHost) + blockedClientRule = fmt.Sprintf(blockedRuleCliFmt, cliHost, cliName) + blockedQTypeRule = fmt.Sprintf(blockedRuleQTypeFmt, qTypeHost, "CNAME") + blockedClientQTypeRule = fmt.Sprintf(blockedRuleCliQTypeFmt, cliQTypeHost, cliName, "CNAME") + + notFilteredURL = fmt.Sprintf(hostFmt, notFilteredHost) + allowedURL = fmt.Sprintf(hostFmt, allowedHost) + blockedURL = fmt.Sprintf(hostFmt, blockedHost) + blockedClientURL = fmt.Sprintf(hostCliFmt, cliHost, cliID) + allowedQTypeURL = fmt.Sprintf(hostQTypeFmt, qTypeHost, "AAAA") + blockedQTypeURL = fmt.Sprintf(hostQTypeFmt, qTypeHost, "CNAME") + allowedClientQTypeURL = fmt.Sprintf(hostCliQTypeFmt, cliQTypeHost, cliID, "AAAA") + blockedClientQTypeURL = fmt.Sprintf(hostCliQTypeFmt, cliQTypeHost, cliID, "CNAME") + ) + + rules := []string{ + allowedRule, + blockedRule, + blockedClientRule, + blockedQTypeRule, + blockedClientQTypeRule, + } + rulesData := strings.Join(rules, "\n") + + filters := []Filter{{ + ID: 0, Data: []byte(rulesData), + }} + + clientNames := map[string]string{ + cliID: cliName, + } + + dnsFilter, err := New(&Config{ + BlockedServices: &BlockedServices{ + Schedule: schedule.EmptyWeekly(), + }, + ApplyClientFiltering: func(clientID string, cliAddr netip.Addr, setts *Settings) { + setts.ClientName = clientNames[clientID] + }, + }, filters) + require.NoError(t, err) + + testCases := []struct { + name string + url string + want *checkHostResp + }{{ + name: "not_filtered", + url: notFilteredURL, + want: &checkHostResp{ + Reason: reasonNames[NotFilteredNotFound], + Rule: "", + Rules: []*checkHostRespRule{}, + }, + }, { + name: "allowed", + url: allowedURL, + want: &checkHostResp{ + Reason: reasonNames[NotFilteredAllowList], + Rule: allowedRule, + Rules: []*checkHostRespRule{{ + Text: allowedRule, + }}, + }, + }, { + name: "blocked", + url: blockedURL, + want: &checkHostResp{ + Reason: reasonNames[FilteredBlockList], + Rule: blockedRule, + Rules: []*checkHostRespRule{{ + Text: blockedRule, + }}, + }, + }, { + name: "blocked_client", + url: blockedClientURL, + want: &checkHostResp{ + Reason: reasonNames[FilteredBlockList], + Rule: blockedClientRule, + Rules: []*checkHostRespRule{{ + Text: blockedClientRule, + }}, + }, + }, { + name: "allowed_qtype", + url: allowedQTypeURL, + want: &checkHostResp{ + Reason: reasonNames[NotFilteredNotFound], + Rule: "", + Rules: []*checkHostRespRule{}, + }, + }, { + name: "blocked_qtype", + url: blockedQTypeURL, + want: &checkHostResp{ + Reason: reasonNames[FilteredBlockList], + Rule: blockedQTypeRule, + Rules: []*checkHostRespRule{{ + Text: blockedQTypeRule, + }}, + }, + }, { + name: "blocked_client_qtype", + url: blockedClientQTypeURL, + want: &checkHostResp{ + Reason: reasonNames[FilteredBlockList], + Rule: blockedClientQTypeRule, + Rules: []*checkHostRespRule{{ + Text: blockedClientQTypeRule, + }}, + }, + }, { + name: "allowed_client_qtype", + url: allowedClientQTypeURL, + want: &checkHostResp{ + Reason: reasonNames[NotFilteredNotFound], + Rule: "", + Rules: []*checkHostRespRule{}, + }, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, tc.url, nil) + w := httptest.NewRecorder() + + dnsFilter.handleCheckHost(w, r) + + res := &checkHostResp{} + err = json.NewDecoder(w.Body).Decode(res) + require.NoError(t, err) + + assert.Equal(t, tc.want, res) + }) + } +} diff --git a/internal/filtering/safesearch.go b/internal/filtering/safesearch.go index b389573a..92df7062 100644 --- a/internal/filtering/safesearch.go +++ b/internal/filtering/safesearch.go @@ -3,6 +3,9 @@ package filtering import "context" // SafeSearch interface describes a service for search engines hosts rewrites. +// +// TODO(s.chzhen): Move to a higher-level package to allow importing the client +// package into the filtering package. type SafeSearch interface { // CheckHost checks host with safe search filter. CheckHost must be safe // for concurrent use. qtype must be either [dns.TypeA] or [dns.TypeAAAA]. diff --git a/internal/home/clients.go b/internal/home/clients.go index 4a5a4d00..781e5e9d 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -121,6 +121,8 @@ func (clients *clientsContainer) Init( sigHdlr.addClientStorage(clients.storage) + filteringConf.ApplyClientFiltering = clients.storage.ApplyClientFiltering + return nil } diff --git a/internal/home/dns.go b/internal/home/dns.go index 4a89131c..deca502f 100644 --- a/internal/home/dns.go +++ b/internal/home/dns.go @@ -247,7 +247,6 @@ func newServerConfig( hosts := aghalg.CoalesceSlice(dnsConf.BindHosts, []netip.Addr{netutil.IPv4Localhost()}) fwdConf := dnsConf.Config - fwdConf.FilterHandler = applyAdditionalFiltering fwdConf.ClientsContainer = clientsContainer newConf = &dnsforward.ServerConfig{ @@ -411,57 +410,6 @@ func getDNSEncryption(tlsMgr *tlsManager) (de dnsEncryption) { return de } -// applyAdditionalFiltering adds additional client information and settings if -// the client has them. -func applyAdditionalFiltering(clientIP netip.Addr, clientID string, setts *filtering.Settings) { - // pref is a prefix for logging messages around the scope. - const pref = "applying filters" - - globalContext.filters.ApplyBlockedServices(setts) - - log.Debug("%s: looking for client with ip %s and clientid %q", pref, clientIP, clientID) - - if !clientIP.IsValid() { - return - } - - setts.ClientIP = clientIP - - c, ok := globalContext.clients.storage.Find(clientID) - if !ok { - c, ok = globalContext.clients.storage.Find(clientIP.String()) - if !ok { - log.Debug("%s: no clients with ip %s and clientid %q", pref, clientIP, clientID) - - return - } - } - - log.Debug("%s: using settings for client %q (%s; %q)", pref, c.Name, clientIP, clientID) - - if c.UseOwnBlockedServices { - // TODO(e.burkov): Get rid of this crutch. - setts.ServicesRules = nil - svcs := c.BlockedServices.IDs - if !c.BlockedServices.Schedule.Contains(time.Now()) { - globalContext.filters.ApplyBlockedServicesList(setts, svcs) - log.Debug("%s: services for client %q set: %s", pref, c.Name, svcs) - } - } - - setts.ClientName = c.Name - setts.ClientTags = c.Tags - if !c.UseOwnSettings { - return - } - - setts.FilteringEnabled = c.FilteringEnabled - setts.SafeSearchEnabled = c.SafeSearchConf.Enabled - setts.ClientSafeSearch = c.SafeSearch - setts.SafeBrowsingEnabled = c.SafeBrowsingEnabled - setts.ParentalEnabled = c.ParentalEnabled -} - func startDNSServer() error { config.RLock() defer config.RUnlock() diff --git a/internal/home/dns_internal_test.go b/internal/home/dns_internal_test.go deleted file mode 100644 index d7c627f7..00000000 --- a/internal/home/dns_internal_test.go +++ /dev/null @@ -1,206 +0,0 @@ -package home - -import ( - "net/netip" - "testing" - - "github.com/AdguardTeam/AdGuardHome/internal/client" - "github.com/AdguardTeam/AdGuardHome/internal/filtering" - "github.com/AdguardTeam/AdGuardHome/internal/schedule" - "github.com/AdguardTeam/golibs/logutil/slogutil" - "github.com/AdguardTeam/golibs/testutil" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -var testIPv4 = netip.AddrFrom4([4]byte{1, 2, 3, 4}) - -// newStorage is a helper function that returns a client storage filled with -// persistent clients. It also generates a UID for each client. -func newStorage(tb testing.TB, clients []*client.Persistent) (s *client.Storage) { - tb.Helper() - - ctx := testutil.ContextWithTimeout(tb, testTimeout) - s, err := client.NewStorage(ctx, &client.StorageConfig{ - Logger: slogutil.NewDiscardLogger(), - }) - require.NoError(tb, err) - - for _, p := range clients { - p.UID = client.MustNewUID() - require.NoError(tb, s.Add(ctx, p)) - } - - return s -} - -func TestApplyAdditionalFiltering(t *testing.T) { - var err error - - globalContext.filters, err = filtering.New(&filtering.Config{ - BlockedServices: &filtering.BlockedServices{ - Schedule: schedule.EmptyWeekly(), - }, - }, nil) - require.NoError(t, err) - - globalContext.clients.storage = newStorage(t, []*client.Persistent{{ - Name: "default", - ClientIDs: []string{"default"}, - UseOwnSettings: false, - SafeSearchConf: filtering.SafeSearchConfig{Enabled: false}, - FilteringEnabled: false, - SafeBrowsingEnabled: false, - ParentalEnabled: false, - }, { - Name: "custom_filtering", - ClientIDs: []string{"custom_filtering"}, - UseOwnSettings: true, - SafeSearchConf: filtering.SafeSearchConfig{Enabled: true}, - FilteringEnabled: true, - SafeBrowsingEnabled: true, - ParentalEnabled: true, - }, { - Name: "partial_custom_filtering", - ClientIDs: []string{"partial_custom_filtering"}, - UseOwnSettings: true, - SafeSearchConf: filtering.SafeSearchConfig{Enabled: true}, - FilteringEnabled: true, - SafeBrowsingEnabled: false, - ParentalEnabled: false, - }}) - - testCases := []struct { - name string - id string - FilteringEnabled assert.BoolAssertionFunc - SafeSearchEnabled assert.BoolAssertionFunc - SafeBrowsingEnabled assert.BoolAssertionFunc - ParentalEnabled assert.BoolAssertionFunc - }{{ - name: "global_settings", - id: "default", - FilteringEnabled: assert.False, - SafeSearchEnabled: assert.False, - SafeBrowsingEnabled: assert.False, - ParentalEnabled: assert.False, - }, { - name: "custom_settings", - id: "custom_filtering", - FilteringEnabled: assert.True, - SafeSearchEnabled: assert.True, - SafeBrowsingEnabled: assert.True, - ParentalEnabled: assert.True, - }, { - name: "partial", - id: "partial_custom_filtering", - FilteringEnabled: assert.True, - SafeSearchEnabled: assert.True, - SafeBrowsingEnabled: assert.False, - ParentalEnabled: assert.False, - }} - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - setts := &filtering.Settings{} - - applyAdditionalFiltering(testIPv4, tc.id, setts) - tc.FilteringEnabled(t, setts.FilteringEnabled) - tc.SafeSearchEnabled(t, setts.SafeSearchEnabled) - tc.SafeBrowsingEnabled(t, setts.SafeBrowsingEnabled) - tc.ParentalEnabled(t, setts.ParentalEnabled) - }) - } -} - -func TestApplyAdditionalFiltering_blockedServices(t *testing.T) { - filtering.InitModule() - - var ( - globalBlockedServices = []string{"ok"} - clientBlockedServices = []string{"ok", "mail_ru", "vk"} - invalidBlockedServices = []string{"invalid"} - - err error - ) - - globalContext.filters, err = filtering.New(&filtering.Config{ - BlockedServices: &filtering.BlockedServices{ - Schedule: schedule.EmptyWeekly(), - IDs: globalBlockedServices, - }, - }, nil) - require.NoError(t, err) - - globalContext.clients.storage = newStorage(t, []*client.Persistent{{ - Name: "default", - ClientIDs: []string{"default"}, - UseOwnBlockedServices: false, - }, { - Name: "no_services", - ClientIDs: []string{"no_services"}, - BlockedServices: &filtering.BlockedServices{ - Schedule: schedule.EmptyWeekly(), - }, - UseOwnBlockedServices: true, - }, { - Name: "services", - ClientIDs: []string{"services"}, - BlockedServices: &filtering.BlockedServices{ - Schedule: schedule.EmptyWeekly(), - IDs: clientBlockedServices, - }, - UseOwnBlockedServices: true, - }, { - Name: "invalid_services", - ClientIDs: []string{"invalid_services"}, - BlockedServices: &filtering.BlockedServices{ - Schedule: schedule.EmptyWeekly(), - IDs: invalidBlockedServices, - }, - UseOwnBlockedServices: true, - }, { - Name: "allow_all", - ClientIDs: []string{"allow_all"}, - BlockedServices: &filtering.BlockedServices{ - Schedule: schedule.FullWeekly(), - IDs: clientBlockedServices, - }, - UseOwnBlockedServices: true, - }}) - - testCases := []struct { - name string - id string - wantLen int - }{{ - name: "global_settings", - id: "default", - wantLen: len(globalBlockedServices), - }, { - name: "custom_settings", - id: "no_services", - wantLen: 0, - }, { - name: "custom_settings_block", - id: "services", - wantLen: len(clientBlockedServices), - }, { - name: "custom_settings_invalid", - id: "invalid_services", - wantLen: 0, - }, { - name: "custom_settings_inactive_schedule", - id: "allow_all", - wantLen: 0, - }} - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - setts := &filtering.Settings{} - - applyAdditionalFiltering(testIPv4, tc.id, setts) - require.Len(t, setts.ServicesRules, tc.wantLen) - }) - } -} diff --git a/openapi/CHANGELOG.md b/openapi/CHANGELOG.md index 0e6dbed0..20132c42 100644 --- a/openapi/CHANGELOG.md +++ b/openapi/CHANGELOG.md @@ -4,6 +4,12 @@ ## v0.108.0: API changes +## v0.107.58: API changes + +### The ability to check rules for query types and/or clients: GET /control/check_host + +- Added optional `client` and `qtype` URL query parameters. + ## v0.107.57: API changes - The new field `"upstream_timeout"` in `GET /control/dns_info` and `POST /control/dns_config` is the number of seconds to wait for a response from the upstream server. diff --git a/openapi/openapi.yaml b/openapi/openapi.yaml index 633138c4..d6c47ce2 100644 --- a/openapi/openapi.yaml +++ b/openapi/openapi.yaml @@ -739,6 +739,20 @@ - 'name': 'name' 'in': 'query' 'description': 'Filter by host name' + 'required': true + 'example': 'google.com' + 'schema': + 'type': 'string' + - 'name': 'client' + 'in': 'query' + 'description': 'Optional ClientID or client IP address' + 'example': '192.0.2.1' + 'schema': + 'type': 'string' + - 'name': 'qtype' + 'in': 'query' + 'description': 'Optional DNS type' + 'example': 'AAAA' 'schema': 'type': 'string' 'responses':