From a829adad10e4e746d9c4e100b00ad68e9c950e3c Mon Sep 17 00:00:00 2001 From: Eugene Burkov Date: Mon, 17 Mar 2025 20:56:05 +0300 Subject: [PATCH] all: resync with master --- CHANGELOG.md | 10 +- client/src/__locales/en.json | 4 + client/src/components/Filters/Check/index.tsx | 70 ++- client/src/components/Filters/CustomRules.tsx | 16 +- client/src/helpers/constants.ts | 9 + internal/aghalg/sortedmap_test.go | 9 +- internal/aghnet/hostgen_test.go | 7 +- internal/aghnet/interfaces_test.go | 68 +-- internal/aghnet/ipmut_test.go | 7 +- ...in_test.go => net_darwin_internal_test.go} | 2 + ...d_test.go => net_freebsd_internal_test.go} | 0 ...nux_test.go => net_linux_internal_test.go} | 0 ...d_test.go => net_openbsd_internal_test.go} | 0 .../aghos/{os_test.go => os_internal_test.go} | 0 internal/client/storage.go | 36 ++ internal/client/upstreammanager.go | 4 +- ...bitset_test.go => bitset_internal_test.go} | 0 ...test.go => broadcast_bsd_internal_test.go} | 0 ...t.go => broadcast_others_internal_test.go} | 0 ...ix_test.go => dhcpd_unix_internal_test.go} | 0 ..._test.go => http_windows_internal_test.go} | 0 ...range_test.go => iprange_internal_test.go} | 0 ..._test.go => options_unix_internal_test.go} | 0 ..._unix_test.go => v4_unix_internal_test.go} | 0 ..._unix_test.go => v6_unix_internal_test.go} | 0 ...access_test.go => access_internal_test.go} | 0 ...ntid_test.go => clientid_internal_test.go} | 0 internal/dnsforward/config.go | 4 - ...config_test.go => config_internal_test.go} | 0 .../{dns64_test.go => dns64_internal_test.go} | 0 ...rd_test.go => dnsforward_internal_test.go} | 49 +- ...te_test.go => dnsrewrite_internal_test.go} | 0 internal/dnsforward/filter.go | 4 +- ...filter_test.go => filter_internal_test.go} | 6 +- .../{http_test.go => http_internal_test.go} | 0 .../{stats_test.go => stats_internal_test.go} | 0 ...cbmsg_test.go => svcbmsg_internal_test.go} | 0 internal/filtering/blocked.go | 3 + internal/filtering/dnsrewrite_test.go | 34 +- internal/filtering/filter.go | 17 + ...filter_test.go => filter_internal_test.go} | 0 internal/filtering/filtering.go | 19 + ...ing_test.go => filtering_internal_test.go} | 0 internal/filtering/hosts_test.go | 29 +- internal/filtering/http.go | 73 ++- .../{http_test.go => http_internal_test.go} | 169 +++++++ ...orage_test.go => storage_internal_test.go} | 0 ...ites_test.go => rewrites_internal_test.go} | 0 internal/filtering/safesearch.go | 3 + ...et_test.go => authglinet_internal_test.go} | 0 ...st.go => authratelimiter_internal_test.go} | 0 internal/home/clients.go | 2 + internal/home/dns.go | 77 --- internal/home/dns_internal_test.go | 206 -------- internal/home/home.go | 3 +- .../{home_test.go => home_internal_test.go} | 0 ...s_test.go => middlewares_internal_test.go} | 0 ...tions_test.go => options_internal_test.go} | 0 internal/home/signal.go | 2 +- internal/home/tls.go | 145 ++++-- internal/home/tls_internal_test.go | 453 +++++++++++++++++- ...decode_test.go => decode_internal_test.go} | 0 .../{qlog_test.go => qlog_internal_test.go} | 0 ...file_test.go => qlogfile_internal_test.go} | 0 ...er_test.go => qlogreader_internal_test.go} | 0 ...search_test.go => search_internal_test.go} | 0 .../{http_test.go => http_internal_test.go} | 0 openapi/CHANGELOG.md | 6 + openapi/openapi.yaml | 14 + 69 files changed, 1126 insertions(+), 434 deletions(-) rename internal/aghnet/{net_darwin_test.go => net_darwin_internal_test.go} (99%) rename internal/aghnet/{net_freebsd_test.go => net_freebsd_internal_test.go} (100%) rename internal/aghnet/{net_linux_test.go => net_linux_internal_test.go} (100%) rename internal/aghnet/{net_openbsd_test.go => net_openbsd_internal_test.go} (100%) rename internal/aghos/{os_test.go => os_internal_test.go} (100%) rename internal/dhcpd/{bitset_test.go => bitset_internal_test.go} (100%) rename internal/dhcpd/{broadcast_bsd_test.go => broadcast_bsd_internal_test.go} (100%) rename internal/dhcpd/{broadcast_others_test.go => broadcast_others_internal_test.go} (100%) rename internal/dhcpd/{dhcpd_unix_test.go => dhcpd_unix_internal_test.go} (100%) rename internal/dhcpd/{http_windows_test.go => http_windows_internal_test.go} (100%) rename internal/dhcpd/{iprange_test.go => iprange_internal_test.go} (100%) rename internal/dhcpd/{options_unix_test.go => options_unix_internal_test.go} (100%) rename internal/dhcpd/{v4_unix_test.go => v4_unix_internal_test.go} (100%) rename internal/dhcpd/{v6_unix_test.go => v6_unix_internal_test.go} (100%) rename internal/dnsforward/{access_test.go => access_internal_test.go} (100%) rename internal/dnsforward/{clientid_test.go => clientid_internal_test.go} (100%) rename internal/dnsforward/{config_test.go => config_internal_test.go} (100%) rename internal/dnsforward/{dns64_test.go => dns64_internal_test.go} (100%) rename internal/dnsforward/{dnsforward_test.go => dnsforward_internal_test.go} (96%) rename internal/dnsforward/{dnsrewrite_test.go => dnsrewrite_internal_test.go} (100%) rename internal/dnsforward/{filter_test.go => filter_internal_test.go} (97%) rename internal/dnsforward/{http_test.go => http_internal_test.go} (100%) rename internal/dnsforward/{stats_test.go => stats_internal_test.go} (100%) rename internal/dnsforward/{svcbmsg_test.go => svcbmsg_internal_test.go} (100%) rename internal/filtering/{filter_test.go => filter_internal_test.go} (100%) rename internal/filtering/{filtering_test.go => filtering_internal_test.go} (100%) rename internal/filtering/{http_test.go => http_internal_test.go} (61%) rename internal/filtering/rewrite/{storage_test.go => storage_internal_test.go} (100%) rename internal/filtering/{rewrites_test.go => rewrites_internal_test.go} (100%) rename internal/home/{authglinet_test.go => authglinet_internal_test.go} (100%) rename internal/home/{authratelimiter_test.go => authratelimiter_internal_test.go} (100%) delete mode 100644 internal/home/dns_internal_test.go rename internal/home/{home_test.go => home_internal_test.go} (100%) rename internal/home/{middlewares_test.go => middlewares_internal_test.go} (100%) rename internal/home/{options_test.go => options_internal_test.go} (100%) rename internal/querylog/{decode_test.go => decode_internal_test.go} (100%) rename internal/querylog/{qlog_test.go => qlog_internal_test.go} (100%) rename internal/querylog/{qlogfile_test.go => qlogfile_internal_test.go} (100%) rename internal/querylog/{qlogreader_test.go => qlogreader_internal_test.go} (100%) rename internal/querylog/{search_test.go => search_internal_test.go} (100%) rename internal/stats/{http_test.go => http_internal_test.go} (100%) diff --git a/CHANGELOG.md b/CHANGELOG.md index 09564f71..e510677b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,7 +22,7 @@ NOTE: Add new changes BELOW THIS COMMENT. NOTE: Add new changes ABOVE THIS COMMENT. --> -## [v0.107.58] - 2025-03-13 +## [v0.107.58] - 2025-03-18 See also the [v0.107.58 GitHub milestone][ms-v0.107.58]. @@ -30,8 +30,16 @@ See also the [v0.107.58 GitHub milestone][ms-v0.107.58]. - Go version has been updated to prevent the possibility of exploiting the Go vulnerabilities fixed in [1.24.1][go-1.24.1]. +### Added + +- The ability to check filtering rules for host names using an optional query type and optional ClientID or client IP address. + +- Optional `client` and `qtype` URL query parameters to the `GET /control/check_host` HTTP API. + ### Fixed +- Clearing the DNS cache on the *DNS settings* page now includes both global cache and custom client cache. + - Invalid ICMPv6 Router Advertisement messages ([#7547]). - Disabled button for autofilled login form. diff --git a/client/src/__locales/en.json b/client/src/__locales/en.json index 9634cf70..67c9f55f 100644 --- a/client/src/__locales/en.json +++ b/client/src/__locales/en.json @@ -620,6 +620,10 @@ "check_cname": "CNAME: {{cname}}", "check_reason": "Reason: {{reason}}", "check_service": "Service name: {{service}}", + "check_hostname": "Hostname or domain name", + "check_client_id": "Client identifier (ClientID or IP address)", + "check_enter_client_id": "Enter client identifier", + "check_dns_record": "Select DNS record type", "service_name": "Service name", "check_not_found": "Not found in your filter lists", "client_confirm_block": "Are you sure you want to block the client \"{{ip}}\"?", diff --git a/client/src/components/Filters/Check/index.tsx b/client/src/components/Filters/Check/index.tsx index 996d6b75..2f578ade 100644 --- a/client/src/components/Filters/Check/index.tsx +++ b/client/src/components/Filters/Check/index.tsx @@ -9,13 +9,17 @@ import Info from './Info'; import { RootState } from '../../../initialState'; import { validateRequiredValue } from '../../../helpers/validators'; import { Input } from '../../ui/Controls/Input'; +import { DNS_RECORD_TYPES } from '../../../helpers/constants'; +import { Select } from '../../ui/Controls/Select'; -interface FormValues { +export type FilteringCheckFormValues = { name: string; + client?: string; + qtype?: string; } type Props = { - onSubmit?: (data: FormValues) => void; + onSubmit?: (data: FilteringCheckFormValues) => void; }; const Check = ({ onSubmit }: Props) => { @@ -27,11 +31,13 @@ const Check = ({ onSubmit }: Props) => { const { control, handleSubmit, - formState: { isDirty, isValid }, - } = useForm({ + formState: { isValid }, + } = useForm({ mode: 'onBlur', defaultValues: { name: '', + client: '', + qtype: DNS_RECORD_TYPES[0], }, }); @@ -48,24 +54,56 @@ const Check = ({ onSubmit }: Props) => { - - - } /> )} /> + ( + + )} + /> + + ( + + )} + /> + + + {hostname && ( <>
diff --git a/client/src/components/Filters/CustomRules.tsx b/client/src/components/Filters/CustomRules.tsx index db7fa92d..cc378e99 100644 --- a/client/src/components/Filters/CustomRules.tsx +++ b/client/src/components/Filters/CustomRules.tsx @@ -7,7 +7,7 @@ import PageTitle from '../ui/PageTitle'; import Examples from './Examples'; -import Check from './Check'; +import Check, { FilteringCheckFormValues } from './Check'; import { getTextareaCommentsHighlight, syncScroll } from '../../helpers/highlightTextareaComments'; import { COMMENT_LINE_DEFAULT_TOKEN } from '../../helpers/constants'; @@ -48,8 +48,18 @@ class CustomRules extends Component { this.props.setRules(this.props.filtering.userRules); }; - handleCheck = (values: any) => { - this.props.checkHost(values); + handleCheck = (values: FilteringCheckFormValues) => { + const params: FilteringCheckFormValues = { name: values.name }; + + if (values.client) { + params.client = values.client; + } + + if (values.qtype) { + params.qtype = values.qtype; + } + + this.props.checkHost(params); }; onScroll = (e: any) => syncScroll(e, this.ref); diff --git a/client/src/helpers/constants.ts b/client/src/helpers/constants.ts index 7b54ae8d..d4e7c940 100644 --- a/client/src/helpers/constants.ts +++ b/client/src/helpers/constants.ts @@ -523,3 +523,12 @@ export const TIME_UNITS = { HOURS: 'hours', DAYS: 'days', }; + +export const DNS_RECORD_TYPES = [ + "A", "AAAA", "AFSDB", "APL", "CAA", "CDNSKEY", "CDS", "CERT", "CNAME", + "CSYNC", "DHCID", "DLV", "DNAME", "DNSKEY", "DS", "EUI48", "EUI64", + "HINFO", "HIP", "HTTPS", "IPSECKEY", "KEY", "KX", "LOC", "MX", "NAPTR", + "NS", "NSEC", "NSEC3", "NSEC3PARAM", "OPENPGPKEY", "PTR", "RP", "RRSIG", + "SIG", "SMIMEA", "SOA", "SRV", "SSHFP", "SVCB", "TA", "TKEY", + "TLSA", "TSIG", "TXT", "URI", "ZONEMD" +]; diff --git a/internal/aghalg/sortedmap_test.go b/internal/aghalg/sortedmap_test.go index 6e563802..a3806639 100644 --- a/internal/aghalg/sortedmap_test.go +++ b/internal/aghalg/sortedmap_test.go @@ -1,14 +1,15 @@ -package aghalg +package aghalg_test import ( "strings" "testing" + "github.com/AdguardTeam/AdGuardHome/internal/aghalg" "github.com/stretchr/testify/assert" ) func TestNewSortedMap(t *testing.T) { - var m SortedMap[string, int] + var m aghalg.SortedMap[string, int] letters := []string{} for i := range 10 { @@ -17,7 +18,7 @@ func TestNewSortedMap(t *testing.T) { } t.Run("create_and_fill", func(t *testing.T) { - m = NewSortedMap[string, int](strings.Compare) + m = aghalg.NewSortedMap[string, int](strings.Compare) nums := []int{} for i, r := range letters { @@ -68,7 +69,7 @@ func TestNewSortedMap_nil(t *testing.T) { val = "val" ) - var m SortedMap[string, string] + var m aghalg.SortedMap[string, string] assert.Panics(t, func() { m.Set(key, val) diff --git a/internal/aghnet/hostgen_test.go b/internal/aghnet/hostgen_test.go index 5896720d..8ac43872 100644 --- a/internal/aghnet/hostgen_test.go +++ b/internal/aghnet/hostgen_test.go @@ -1,9 +1,10 @@ -package aghnet +package aghnet_test import ( "net/netip" "testing" + "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/stretchr/testify/assert" ) @@ -29,13 +30,13 @@ func TestGenerateHostName(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - hostname := GenerateHostname(tc.ip) + hostname := aghnet.GenerateHostname(tc.ip) assert.Equal(t, tc.want, hostname) }) } }) t.Run("invalid", func(t *testing.T) { - assert.Panics(t, func() { GenerateHostname(netip.Addr{}) }) + assert.Panics(t, func() { aghnet.GenerateHostname(netip.Addr{}) }) }) } diff --git a/internal/aghnet/interfaces_test.go b/internal/aghnet/interfaces_test.go index ca829fb1..83bb81d5 100644 --- a/internal/aghnet/interfaces_test.go +++ b/internal/aghnet/interfaces_test.go @@ -1,22 +1,24 @@ -package aghnet +package aghnet_test import ( "net" "testing" + "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -// fakeIface is a stub implementation of aghnet.NetIface to simplify testing. +// fakeIface is a stub implementation of [aghnet.NetIface] interface to simplify +// testing. type fakeIface struct { err error addrs []net.Addr } -// Addrs implements the NetIface interface for *fakeIface. +// Addrs implements the [aghnet.NetIface] interface for *fakeIface. func (iface *fakeIface) Addrs() (addrs []net.Addr, err error) { if iface.err != nil { return nil, iface.err @@ -25,6 +27,9 @@ func (iface *fakeIface) Addrs() (addrs []net.Addr, err error) { return iface.addrs, nil } +// type check +var _ aghnet.NetIface = (*fakeIface)(nil) + func TestIfaceIPAddrs(t *testing.T) { const errTest errors.Error = "test error" @@ -35,76 +40,76 @@ func TestIfaceIPAddrs(t *testing.T) { addr6 := &net.IPNet{IP: ip6} testCases := []struct { - iface NetIface + iface aghnet.NetIface name string wantErrMsg string want []net.IP - ipv IPVersion + ipv aghnet.IPVersion }{{ iface: &fakeIface{addrs: []net.Addr{addr4}, err: nil}, name: "ipv4_success", wantErrMsg: "", want: []net.IP{ip4}, - ipv: IPVersion4, + ipv: aghnet.IPVersion4, }, { iface: &fakeIface{addrs: []net.Addr{addr6, addr4}, err: nil}, name: "ipv4_success_with_ipv6", wantErrMsg: "", want: []net.IP{ip4}, - ipv: IPVersion4, + ipv: aghnet.IPVersion4, }, { iface: &fakeIface{addrs: []net.Addr{addr4}, err: errTest}, name: "ipv4_error", wantErrMsg: errTest.Error(), want: nil, - ipv: IPVersion4, + ipv: aghnet.IPVersion4, }, { iface: &fakeIface{addrs: []net.Addr{addr6}, err: nil}, name: "ipv6_success", wantErrMsg: "", want: []net.IP{ip6}, - ipv: IPVersion6, + ipv: aghnet.IPVersion6, }, { iface: &fakeIface{addrs: []net.Addr{addr6, addr4}, err: nil}, name: "ipv6_success_with_ipv4", wantErrMsg: "", want: []net.IP{ip6}, - ipv: IPVersion6, + ipv: aghnet.IPVersion6, }, { iface: &fakeIface{addrs: []net.Addr{addr6}, err: errTest}, name: "ipv6_error", wantErrMsg: errTest.Error(), want: nil, - ipv: IPVersion6, + ipv: aghnet.IPVersion6, }, { iface: &fakeIface{addrs: nil, err: nil}, name: "bad_proto", wantErrMsg: "invalid ip version 10", want: nil, - ipv: IPVersion6 + IPVersion4, + ipv: aghnet.IPVersion6 + aghnet.IPVersion4, }, { iface: &fakeIface{addrs: []net.Addr{&net.IPAddr{IP: ip4}}, err: nil}, name: "ipaddr_v4", wantErrMsg: "", want: []net.IP{ip4}, - ipv: IPVersion4, + ipv: aghnet.IPVersion4, }, { iface: &fakeIface{addrs: []net.Addr{&net.IPAddr{IP: ip6, Zone: ""}}, err: nil}, name: "ipaddr_v6", wantErrMsg: "", want: []net.IP{ip6}, - ipv: IPVersion6, + ipv: aghnet.IPVersion6, }, { iface: &fakeIface{addrs: []net.Addr{&net.UnixAddr{}}, err: nil}, name: "non-ipv4", wantErrMsg: "", want: nil, - ipv: IPVersion4, + ipv: aghnet.IPVersion4, }} for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - got, err := IfaceIPAddrs(tc.iface, tc.ipv) + got, err := aghnet.IfaceIPAddrs(tc.iface, tc.ipv) testutil.AssertErrorMsg(t, tc.wantErrMsg, err) assert.Equal(t, tc.want, got) @@ -118,7 +123,10 @@ type waitingFakeIface struct { n int } -// Addrs implements the NetIface interface for *waitingFakeIface. +// type check +var _ aghnet.NetIface = (*waitingFakeIface)(nil) + +// Addrs implements the [aghnet.NetIface] interface for *waitingFakeIface. func (iface *waitingFakeIface) Addrs() (addrs []net.Addr, err error) { if iface.err != nil { return nil, iface.err @@ -143,76 +151,76 @@ func TestIfaceDNSIPAddrs(t *testing.T) { addr6 := &net.IPNet{IP: ip6} testCases := []struct { - iface NetIface + iface aghnet.NetIface wantErr error name string want []net.IP - ipv IPVersion + ipv aghnet.IPVersion }{{ name: "ipv4_success", iface: &fakeIface{addrs: []net.Addr{addr4}, err: nil}, - ipv: IPVersion4, + ipv: aghnet.IPVersion4, want: []net.IP{ip4, ip4}, wantErr: nil, }, { name: "ipv4_success_with_ipv6", iface: &fakeIface{addrs: []net.Addr{addr6, addr4}, err: nil}, - ipv: IPVersion4, + ipv: aghnet.IPVersion4, want: []net.IP{ip4, ip4}, wantErr: nil, }, { name: "ipv4_error", iface: &fakeIface{addrs: []net.Addr{addr4}, err: errTest}, - ipv: IPVersion4, + ipv: aghnet.IPVersion4, want: nil, wantErr: errTest, }, { name: "ipv4_wait", iface: &waitingFakeIface{addrs: []net.Addr{addr4}, err: nil, n: 1}, - ipv: IPVersion4, + ipv: aghnet.IPVersion4, want: []net.IP{ip4, ip4}, wantErr: nil, }, { name: "ipv6_success", iface: &fakeIface{addrs: []net.Addr{addr6}, err: nil}, - ipv: IPVersion6, + ipv: aghnet.IPVersion6, want: []net.IP{ip6, ip6}, wantErr: nil, }, { name: "ipv6_success_with_ipv4", iface: &fakeIface{addrs: []net.Addr{addr6, addr4}, err: nil}, - ipv: IPVersion6, + ipv: aghnet.IPVersion6, want: []net.IP{ip6, ip6}, wantErr: nil, }, { name: "ipv6_error", iface: &fakeIface{addrs: []net.Addr{addr6}, err: errTest}, - ipv: IPVersion6, + ipv: aghnet.IPVersion6, want: nil, wantErr: errTest, }, { name: "ipv6_wait", iface: &waitingFakeIface{addrs: []net.Addr{addr6}, err: nil, n: 1}, - ipv: IPVersion6, + ipv: aghnet.IPVersion6, want: []net.IP{ip6, ip6}, wantErr: nil, }, { name: "empty", iface: &fakeIface{addrs: nil, err: nil}, - ipv: IPVersion4, + ipv: aghnet.IPVersion4, want: nil, wantErr: nil, }, { name: "many", iface: &fakeIface{addrs: []net.Addr{addr4, addr4}}, - ipv: IPVersion4, + ipv: aghnet.IPVersion4, want: []net.IP{ip4, ip4}, wantErr: nil, }} for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - got, err := IfaceDNSIPAddrs(tc.iface, tc.ipv, 2, 0) + got, err := aghnet.IfaceDNSIPAddrs(tc.iface, tc.ipv, 2, 0) require.ErrorIs(t, err, tc.wantErr) assert.Equal(t, tc.want, got) diff --git a/internal/aghnet/ipmut_test.go b/internal/aghnet/ipmut_test.go index 51fc16ba..10beee3f 100644 --- a/internal/aghnet/ipmut_test.go +++ b/internal/aghnet/ipmut_test.go @@ -1,9 +1,10 @@ -package aghnet +package aghnet_test import ( "net" "testing" + "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/golibs/netutil" "github.com/stretchr/testify/assert" ) @@ -18,7 +19,7 @@ func TestIPMut(t *testing.T) { }} t.Run("nil_no_mut", func(t *testing.T) { - ipmut := NewIPMut(nil) + ipmut := aghnet.NewIPMut(nil) ips := netutil.CloneIPs(testIPs) for i := range ips { @@ -28,7 +29,7 @@ func TestIPMut(t *testing.T) { }) t.Run("not_nil_mut", func(t *testing.T) { - ipmut := NewIPMut(func(ip net.IP) { + ipmut := aghnet.NewIPMut(func(ip net.IP) { for i := range ip { ip[i] = 0 } diff --git a/internal/aghnet/net_darwin_test.go b/internal/aghnet/net_darwin_internal_test.go similarity index 99% rename from internal/aghnet/net_darwin_test.go rename to internal/aghnet/net_darwin_internal_test.go index 06e7eeaf..7593a969 100644 --- a/internal/aghnet/net_darwin_test.go +++ b/internal/aghnet/net_darwin_internal_test.go @@ -1,3 +1,5 @@ +//go:build darwin + package aghnet import ( diff --git a/internal/aghnet/net_freebsd_test.go b/internal/aghnet/net_freebsd_internal_test.go similarity index 100% rename from internal/aghnet/net_freebsd_test.go rename to internal/aghnet/net_freebsd_internal_test.go diff --git a/internal/aghnet/net_linux_test.go b/internal/aghnet/net_linux_internal_test.go similarity index 100% rename from internal/aghnet/net_linux_test.go rename to internal/aghnet/net_linux_internal_test.go diff --git a/internal/aghnet/net_openbsd_test.go b/internal/aghnet/net_openbsd_internal_test.go similarity index 100% rename from internal/aghnet/net_openbsd_test.go rename to internal/aghnet/net_openbsd_internal_test.go diff --git a/internal/aghos/os_test.go b/internal/aghos/os_internal_test.go similarity index 100% rename from internal/aghos/os_test.go rename to internal/aghos/os_internal_test.go diff --git a/internal/client/storage.go b/internal/client/storage.go index e3bf58e0..3bb839b4 100644 --- a/internal/client/storage.go +++ b/internal/client/storage.go @@ -12,6 +12,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/arpdb" "github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc" + "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/whois" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/golibs/errors" @@ -671,3 +672,38 @@ func (s *Storage) ClearUpstreamCache() { s.upstreamManager.clearUpstreamCache() } + +// ApplyClientFiltering retrieves persistent client information using the +// ClientID or client IP address, and applies it to the filtering settings. +func (s *Storage) ApplyClientFiltering(id string, addr netip.Addr, setts *filtering.Settings) { + c, ok := s.index.findByClientID(id) + if !ok { + c, ok = s.index.findByIP(addr) + } + + if !ok { + s.logger.Debug("no client filtering settings found", "clientid", id, "addr", addr) + + return + } + + s.logger.Debug("applying custom client filtering settings", "client_name", c.Name) + + setts.ClientIP = addr + + if c.UseOwnBlockedServices { + setts.BlockedServices = c.BlockedServices.Clone() + } + + setts.ClientName = c.Name + setts.ClientTags = slices.Clone(c.Tags) + if !c.UseOwnSettings { + return + } + + setts.FilteringEnabled = c.FilteringEnabled + setts.SafeSearchEnabled = c.SafeSearchConf.Enabled + setts.ClientSafeSearch = c.SafeSearch + setts.SafeBrowsingEnabled = c.SafeBrowsingEnabled + setts.ParentalEnabled = c.ParentalEnabled +} diff --git a/internal/client/upstreammanager.go b/internal/client/upstreammanager.go index f6cad878..b804dbdc 100644 --- a/internal/client/upstreammanager.go +++ b/internal/client/upstreammanager.go @@ -153,7 +153,9 @@ func (m *upstreamManager) isConfigChanged(cliConf *customUpstreamConfig) (ok boo // upstream configuration. func (m *upstreamManager) clearUpstreamCache() { for _, c := range m.uidToCustomConf { - c.proxyConf.ClearCache() + if c.proxyConf != nil { + c.proxyConf.ClearCache() + } } } diff --git a/internal/dhcpd/bitset_test.go b/internal/dhcpd/bitset_internal_test.go similarity index 100% rename from internal/dhcpd/bitset_test.go rename to internal/dhcpd/bitset_internal_test.go diff --git a/internal/dhcpd/broadcast_bsd_test.go b/internal/dhcpd/broadcast_bsd_internal_test.go similarity index 100% rename from internal/dhcpd/broadcast_bsd_test.go rename to internal/dhcpd/broadcast_bsd_internal_test.go diff --git a/internal/dhcpd/broadcast_others_test.go b/internal/dhcpd/broadcast_others_internal_test.go similarity index 100% rename from internal/dhcpd/broadcast_others_test.go rename to internal/dhcpd/broadcast_others_internal_test.go diff --git a/internal/dhcpd/dhcpd_unix_test.go b/internal/dhcpd/dhcpd_unix_internal_test.go similarity index 100% rename from internal/dhcpd/dhcpd_unix_test.go rename to internal/dhcpd/dhcpd_unix_internal_test.go diff --git a/internal/dhcpd/http_windows_test.go b/internal/dhcpd/http_windows_internal_test.go similarity index 100% rename from internal/dhcpd/http_windows_test.go rename to internal/dhcpd/http_windows_internal_test.go diff --git a/internal/dhcpd/iprange_test.go b/internal/dhcpd/iprange_internal_test.go similarity index 100% rename from internal/dhcpd/iprange_test.go rename to internal/dhcpd/iprange_internal_test.go diff --git a/internal/dhcpd/options_unix_test.go b/internal/dhcpd/options_unix_internal_test.go similarity index 100% rename from internal/dhcpd/options_unix_test.go rename to internal/dhcpd/options_unix_internal_test.go diff --git a/internal/dhcpd/v4_unix_test.go b/internal/dhcpd/v4_unix_internal_test.go similarity index 100% rename from internal/dhcpd/v4_unix_test.go rename to internal/dhcpd/v4_unix_internal_test.go diff --git a/internal/dhcpd/v6_unix_test.go b/internal/dhcpd/v6_unix_internal_test.go similarity index 100% rename from internal/dhcpd/v6_unix_test.go rename to internal/dhcpd/v6_unix_internal_test.go diff --git a/internal/dnsforward/access_test.go b/internal/dnsforward/access_internal_test.go similarity index 100% rename from internal/dnsforward/access_test.go rename to internal/dnsforward/access_internal_test.go diff --git a/internal/dnsforward/clientid_test.go b/internal/dnsforward/clientid_internal_test.go similarity index 100% rename from internal/dnsforward/clientid_test.go rename to internal/dnsforward/clientid_internal_test.go diff --git a/internal/dnsforward/config.go b/internal/dnsforward/config.go index c549a07e..37e8ce98 100644 --- a/internal/dnsforward/config.go +++ b/internal/dnsforward/config.go @@ -16,7 +16,6 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghtls" "github.com/AdguardTeam/AdGuardHome/internal/client" - "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/container" @@ -34,9 +33,6 @@ import ( type Config struct { // Callbacks for other modules - // FilterHandler is an optional additional filtering callback. - FilterHandler func(cliAddr netip.Addr, clientID string, settings *filtering.Settings) `yaml:"-"` - // ClientsContainer stores the information about special handling of some // DNS clients. ClientsContainer ClientsContainer `yaml:"-"` diff --git a/internal/dnsforward/config_test.go b/internal/dnsforward/config_internal_test.go similarity index 100% rename from internal/dnsforward/config_test.go rename to internal/dnsforward/config_internal_test.go diff --git a/internal/dnsforward/dns64_test.go b/internal/dnsforward/dns64_internal_test.go similarity index 100% rename from internal/dnsforward/dns64_test.go rename to internal/dnsforward/dns64_internal_test.go diff --git a/internal/dnsforward/dnsforward_test.go b/internal/dnsforward/dnsforward_internal_test.go similarity index 96% rename from internal/dnsforward/dnsforward_test.go rename to internal/dnsforward/dnsforward_internal_test.go index 0ced288d..5f035f84 100644 --- a/internal/dnsforward/dnsforward_test.go +++ b/internal/dnsforward/dnsforward_internal_test.go @@ -27,6 +27,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering/hashprefix" "github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch" + "github.com/AdguardTeam/AdGuardHome/internal/schedule" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/logutil/slogutil" @@ -106,6 +107,21 @@ func startDeferStop(t *testing.T, s *Server) { testutil.CleanupAndRequireSuccess(t, s.Stop) } +// applyEmptyClientFiltering is a helper function for tests with +// [filtering.Config] that does nothing. +func applyEmptyClientFiltering(_ string, _ netip.Addr, _ *filtering.Settings) {} + +// emptyFilteringBlockedServices is a helper function that returns an empty +// filtering blocked services for tests. +func emptyFilteringBlockedServices() (bsvc *filtering.BlockedServices) { + return &filtering.BlockedServices{ + Schedule: schedule.EmptyWeekly(), + } +} + +// createTestServer is a helper function that returns a properly initialized +// *Server for use in tests, given the provided parameters. It also populates +// the filtering configuration with default parameters. func createTestServer( t *testing.T, filterConf *filtering.Config, @@ -123,6 +139,12 @@ func createTestServer( Data: []byte(rules), }} + filterConf.BlockedServices = cmp.Or(filterConf.BlockedServices, emptyFilteringBlockedServices()) + + if filterConf.ApplyClientFiltering == nil { + filterConf.ApplyClientFiltering = applyEmptyClientFiltering + } + f, err := filtering.New(filterConf, filters) require.NoError(t, err) @@ -926,9 +948,6 @@ func TestClientRulesForCNAMEMatching(t *testing.T) { UDPListenAddrs: []*net.UDPAddr{{}}, TCPListenAddrs: []*net.TCPAddr{{}}, Config: Config{ - FilterHandler: func(_ netip.Addr, _ string, settings *filtering.Settings) { - settings.FilteringEnabled = false - }, UpstreamMode: UpstreamModeLoadBalance, EDNSClientSubnet: &EDNSClientSubnet{ Enabled: false, @@ -1020,10 +1039,12 @@ func TestBlockedCustomIP(t *testing.T) { }} f, err := filtering.New(&filtering.Config{ - ProtectionEnabled: true, - BlockingMode: filtering.BlockingModeCustomIP, - BlockingIPv4: netip.Addr{}, - BlockingIPv6: netip.Addr{}, + ProtectionEnabled: true, + ApplyClientFiltering: applyEmptyClientFiltering, + BlockedServices: emptyFilteringBlockedServices(), + BlockingMode: filtering.BlockingModeCustomIP, + BlockingIPv4: netip.Addr{}, + BlockingIPv6: netip.Addr{}, }, filters) require.NoError(t, err) @@ -1176,7 +1197,9 @@ func TestBlockedBySafeBrowsing(t *testing.T) { func TestRewrite(t *testing.T) { c := &filtering.Config{ - BlockingMode: filtering.BlockingModeDefault, + ApplyClientFiltering: applyEmptyClientFiltering, + BlockedServices: emptyFilteringBlockedServices(), + BlockingMode: filtering.BlockingModeDefault, Rewrites: []*filtering.LegacyRewrite{{ Domain: "test.com", Answer: "1.2.3.4", @@ -1322,7 +1345,9 @@ func TestPTRResponseFromDHCPLeases(t *testing.T) { const localDomain = "lan" flt, err := filtering.New(&filtering.Config{ - BlockingMode: filtering.BlockingModeDefault, + ApplyClientFiltering: applyEmptyClientFiltering, + BlockedServices: emptyFilteringBlockedServices(), + BlockingMode: filtering.BlockingModeDefault, }, nil) require.NoError(t, err) @@ -1411,8 +1436,10 @@ func TestPTRResponseFromHosts(t *testing.T) { }) flt, err := filtering.New(&filtering.Config{ - BlockingMode: filtering.BlockingModeDefault, - EtcHosts: hc, + ApplyClientFiltering: applyEmptyClientFiltering, + BlockedServices: emptyFilteringBlockedServices(), + BlockingMode: filtering.BlockingModeDefault, + EtcHosts: hc, }, nil) require.NoError(t, err) diff --git a/internal/dnsforward/dnsrewrite_test.go b/internal/dnsforward/dnsrewrite_internal_test.go similarity index 100% rename from internal/dnsforward/dnsrewrite_test.go rename to internal/dnsforward/dnsrewrite_internal_test.go diff --git a/internal/dnsforward/filter.go b/internal/dnsforward/filter.go index f6cd319d..6cfd7bea 100644 --- a/internal/dnsforward/filter.go +++ b/internal/dnsforward/filter.go @@ -17,9 +17,7 @@ import ( func (s *Server) clientRequestFilteringSettings(dctx *dnsContext) (setts *filtering.Settings) { setts = s.dnsFilter.Settings() setts.ProtectionEnabled = dctx.protectionEnabled - if s.conf.FilterHandler != nil { - s.conf.FilterHandler(dctx.proxyCtx.Addr.Addr(), dctx.clientID, setts) - } + s.dnsFilter.ApplyAdditionalFiltering(dctx.proxyCtx.Addr.Addr(), dctx.clientID, setts) return setts } diff --git a/internal/dnsforward/filter_test.go b/internal/dnsforward/filter_internal_test.go similarity index 97% rename from internal/dnsforward/filter_test.go rename to internal/dnsforward/filter_internal_test.go index 922213c4..7f1ab293 100644 --- a/internal/dnsforward/filter_test.go +++ b/internal/dnsforward/filter_internal_test.go @@ -45,8 +45,10 @@ func TestHandleDNSRequest_handleDNSRequest(t *testing.T) { }} f, err := filtering.New(&filtering.Config{ - ProtectionEnabled: true, - BlockingMode: filtering.BlockingModeDefault, + ProtectionEnabled: true, + ApplyClientFiltering: applyEmptyClientFiltering, + BlockedServices: emptyFilteringBlockedServices(), + BlockingMode: filtering.BlockingModeDefault, }, filters) require.NoError(t, err) f.SetEnabled(true) diff --git a/internal/dnsforward/http_test.go b/internal/dnsforward/http_internal_test.go similarity index 100% rename from internal/dnsforward/http_test.go rename to internal/dnsforward/http_internal_test.go diff --git a/internal/dnsforward/stats_test.go b/internal/dnsforward/stats_internal_test.go similarity index 100% rename from internal/dnsforward/stats_test.go rename to internal/dnsforward/stats_internal_test.go diff --git a/internal/dnsforward/svcbmsg_test.go b/internal/dnsforward/svcbmsg_internal_test.go similarity index 100% rename from internal/dnsforward/svcbmsg_test.go rename to internal/dnsforward/svcbmsg_internal_test.go diff --git a/internal/filtering/blocked.go b/internal/filtering/blocked.go index aec5855e..ca59a1b8 100644 --- a/internal/filtering/blocked.go +++ b/internal/filtering/blocked.go @@ -49,6 +49,9 @@ func initBlockedServices() { } // BlockedServices is the configuration of blocked services. +// +// TODO(s.chzhen): Move to a higher-level package to allow importing the client +// package into the filtering package. type BlockedServices struct { // Schedule is blocked services schedule for every day of the week. Schedule *schedule.Weekly `json:"schedule" yaml:"schedule"` diff --git a/internal/filtering/dnsrewrite_test.go b/internal/filtering/dnsrewrite_test.go index 06cd921b..89b6b30d 100644 --- a/internal/filtering/dnsrewrite_test.go +++ b/internal/filtering/dnsrewrite_test.go @@ -1,10 +1,11 @@ -package filtering +package filtering_test import ( "net/netip" "path" "testing" + "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/golibs/netutil" "github.com/miekg/dns" "github.com/stretchr/testify/assert" @@ -50,8 +51,17 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) { |1.2.3.5.in-addr.arpa^$dnsrewrite=NOERROR;PTR;new-ptr-with-dot. ` - f, _ := newForTest(t, nil, []Filter{{ID: 0, Data: []byte(text)}}) - setts := &Settings{ + conf := &filtering.Config{ + SafeBrowsingCacheSize: 10000, + ParentalCacheSize: 10000, + SafeSearchCacheSize: 1000, + CacheTime: 30, + } + + f, err := filtering.New(conf, []filtering.Filter{{ID: 0, Data: []byte(text)}}) + require.NoError(t, err) + + setts := &filtering.Settings{ FilteringEnabled: true, } @@ -117,7 +127,8 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) { t.Run(tc.name, func(t *testing.T) { host := path.Base(tc.name) - res, err := f.CheckHostRules(host, tc.dtyp, setts) + var res filtering.Result + res, err = f.CheckHostRules(host, tc.dtyp, setts) require.NoError(t, err) dnsrr := res.DNSRewriteResult @@ -141,7 +152,8 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) { dtyp := dns.TypeA host := path.Base(t.Name()) - res, err := f.CheckHostRules(host, dtyp, setts) + var res filtering.Result + res, err = f.CheckHostRules(host, dtyp, setts) require.NoError(t, err) assert.Equal(t, "new-cname", res.CanonName) @@ -151,7 +163,8 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) { dtyp := dns.TypeA host := path.Base(t.Name()) - res, err := f.CheckHostRules(host, dtyp, setts) + var res filtering.Result + res, err = f.CheckHostRules(host, dtyp, setts) require.NoError(t, err) assert.Equal(t, "new-cname-2", res.CanonName) @@ -162,7 +175,8 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) { dtyp := dns.TypeA host := path.Base(t.Name()) - res, err := f.CheckHostRules(host, dtyp, setts) + var res filtering.Result + res, err = f.CheckHostRules(host, dtyp, setts) require.NoError(t, err) assert.Empty(t, res.CanonName) @@ -173,7 +187,8 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) { dtyp := dns.TypePTR host := path.Base(t.Name()) - res, err := f.CheckHostRules(host, dtyp, setts) + var res filtering.Result + res, err = f.CheckHostRules(host, dtyp, setts) require.NoError(t, err) require.NotNil(t, res.DNSRewriteResult) @@ -193,7 +208,8 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) { dtyp := dns.TypePTR host := path.Base(t.Name()) - res, err := f.CheckHostRules(host, dtyp, setts) + var res filtering.Result + res, err = f.CheckHostRules(host, dtyp, setts) require.NoError(t, err) require.NotNil(t, res.DNSRewriteResult) diff --git a/internal/filtering/filter.go b/internal/filtering/filter.go index 0dd3471c..5242082f 100644 --- a/internal/filtering/filter.go +++ b/internal/filtering/filter.go @@ -4,6 +4,7 @@ import ( "fmt" "io" "net/http" + "net/netip" "os" "path/filepath" "slices" @@ -629,3 +630,19 @@ func (d *DNSFilter) enableFiltersLocked(async bool) { d.SetEnabled(d.conf.FilteringEnabled) } + +// ApplyAdditionalFiltering enhances the provided filtering settings with +// blocked services and client-specific configurations. +func (d *DNSFilter) ApplyAdditionalFiltering(cliAddr netip.Addr, clientID string, setts *Settings) { + d.ApplyBlockedServices(setts) + + d.applyClientFiltering(clientID, cliAddr, setts) + if setts.BlockedServices != nil { + // TODO(e.burkov): Get rid of this crutch. + setts.ServicesRules = nil + svcs := setts.BlockedServices.IDs + if !setts.BlockedServices.Schedule.Contains(time.Now()) { + d.ApplyBlockedServicesList(setts, svcs) + } + } +} diff --git a/internal/filtering/filter_test.go b/internal/filtering/filter_internal_test.go similarity index 100% rename from internal/filtering/filter_test.go rename to internal/filtering/filter_internal_test.go diff --git a/internal/filtering/filtering.go b/internal/filtering/filtering.go index 8836515c..d5fd49d5 100644 --- a/internal/filtering/filtering.go +++ b/internal/filtering/filtering.go @@ -40,6 +40,8 @@ type ServiceEntry struct { } // Settings are custom filtering settings for a client. +// +// TODO(s.chzhen): Move to the client package. type Settings struct { ClientName string ClientIP netip.Addr @@ -47,6 +49,10 @@ type Settings struct { ServicesRules []ServiceEntry + // BlockedServices is the configuration of blocked services of a client. It + // is nil if the client does not have any blocked services. + BlockedServices *BlockedServices + ProtectionEnabled bool FilteringEnabled bool SafeSearchEnabled bool @@ -78,6 +84,11 @@ type Config struct { SafeSearch SafeSearch `yaml:"-"` + // ApplyClientFiltering retrieves persistent client information using the + // ClientID or client IP address, and applies it to the filtering settings. + // It must not be nil. + ApplyClientFiltering func(clientID string, cliAddr netip.Addr, setts *Settings) `yaml:"-"` + // BlockedServices is the configuration of blocked services. // Per-client settings can override this configuration. BlockedServices *BlockedServices `yaml:"blocked_services"` @@ -244,6 +255,13 @@ type DNSFilter struct { // parentalControl is the parental control hash-prefix checker. parentalControlChecker Checker + // applyClientFiltering retrieves persistent client information using the + // ClientID or client IP address, and applies it to the filtering settings. + // + // TODO(s.chzhen): Consider finding a better approach while taking an + // import cycle into account. + applyClientFiltering func(clientID string, cliAddr netip.Addr, setts *Settings) + engineLock sync.RWMutex // confMu protects conf. @@ -998,6 +1016,7 @@ func New(c *Config, blockFilters []Filter) (d *DNSFilter, err error) { refreshLock: &sync.Mutex{}, safeBrowsingChecker: c.SafeBrowsingChecker, parentalControlChecker: c.ParentalControlChecker, + applyClientFiltering: c.ApplyClientFiltering, confMu: &sync.RWMutex{}, } diff --git a/internal/filtering/filtering_test.go b/internal/filtering/filtering_internal_test.go similarity index 100% rename from internal/filtering/filtering_test.go rename to internal/filtering/filtering_internal_test.go diff --git a/internal/filtering/hosts_test.go b/internal/filtering/hosts_test.go index e94603a0..14e20adc 100644 --- a/internal/filtering/hosts_test.go +++ b/internal/filtering/hosts_test.go @@ -1,4 +1,4 @@ -package filtering +package filtering_test import ( "fmt" @@ -8,6 +8,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghtest" + "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist" "github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/urlfilter/rules" @@ -50,27 +51,27 @@ func TestDNSFilter_CheckHost_hostsContainer(t *testing.T) { require.NoError(t, err) testutil.CleanupAndRequireSuccess(t, hc.Close) - conf := &Config{ + conf := &filtering.Config{ EtcHosts: hc, } - f, err := New(conf, nil) + f, err := filtering.New(conf, nil) require.NoError(t, err) - setts := &Settings{ + setts := &filtering.Settings{ FilteringEnabled: true, } testCases := []struct { name string host string - wantRules []*ResultRule + wantRules []*filtering.ResultRule wantResps []rules.RRValue dtyp uint16 }{{ name: "v4", host: "v4.host.example", dtyp: dns.TypeA, - wantRules: []*ResultRule{{ + wantRules: []*filtering.ResultRule{{ Text: "1.2.3.4 v4.host.example", FilterListID: rulelist.URLFilterIDEtcHosts, }}, @@ -79,7 +80,7 @@ func TestDNSFilter_CheckHost_hostsContainer(t *testing.T) { name: "v6", host: "v6.host.example", dtyp: dns.TypeAAAA, - wantRules: []*ResultRule{{ + wantRules: []*filtering.ResultRule{{ Text: "::1 v6.host.example", FilterListID: rulelist.URLFilterIDEtcHosts, }}, @@ -88,7 +89,7 @@ func TestDNSFilter_CheckHost_hostsContainer(t *testing.T) { name: "mapped", host: "mapped.host.example", dtyp: dns.TypeAAAA, - wantRules: []*ResultRule{{ + wantRules: []*filtering.ResultRule{{ Text: "::ffff:1.2.3.4 mapped.host.example", FilterListID: rulelist.URLFilterIDEtcHosts, }}, @@ -97,7 +98,7 @@ func TestDNSFilter_CheckHost_hostsContainer(t *testing.T) { name: "ptr", host: "4.3.2.1.in-addr.arpa", dtyp: dns.TypePTR, - wantRules: []*ResultRule{{ + wantRules: []*filtering.ResultRule{{ Text: "1.2.3.4 v4.host.example", FilterListID: rulelist.URLFilterIDEtcHosts, }}, @@ -106,7 +107,7 @@ func TestDNSFilter_CheckHost_hostsContainer(t *testing.T) { name: "ptr-mapped", host: "4.0.3.0.2.0.1.0.f.f.f.f.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa", dtyp: dns.TypePTR, - wantRules: []*ResultRule{{ + wantRules: []*filtering.ResultRule{{ Text: "::ffff:1.2.3.4 mapped.host.example", FilterListID: rulelist.URLFilterIDEtcHosts, }}, @@ -133,7 +134,7 @@ func TestDNSFilter_CheckHost_hostsContainer(t *testing.T) { name: "v4_mismatch", host: "v4.host.example", dtyp: dns.TypeAAAA, - wantRules: []*ResultRule{{ + wantRules: []*filtering.ResultRule{{ Text: fmt.Sprintf("%s v4.host.example", addrv4), FilterListID: rulelist.URLFilterIDEtcHosts, }}, @@ -142,7 +143,7 @@ func TestDNSFilter_CheckHost_hostsContainer(t *testing.T) { name: "v6_mismatch", host: "v6.host.example", dtyp: dns.TypeA, - wantRules: []*ResultRule{{ + wantRules: []*filtering.ResultRule{{ Text: fmt.Sprintf("%s v6.host.example", addrv6), FilterListID: rulelist.URLFilterIDEtcHosts, }}, @@ -163,7 +164,7 @@ func TestDNSFilter_CheckHost_hostsContainer(t *testing.T) { name: "v4_dup", host: "v4.host.with-dup", dtyp: dns.TypeA, - wantRules: []*ResultRule{{ + wantRules: []*filtering.ResultRule{{ Text: "4.3.2.1 v4.host.with-dup", FilterListID: rulelist.URLFilterIDEtcHosts, }}, @@ -172,7 +173,7 @@ func TestDNSFilter_CheckHost_hostsContainer(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - var res Result + var res filtering.Result res, err = f.CheckHost(tc.host, tc.dtyp, setts) require.NoError(t, err) diff --git a/internal/filtering/http.go b/internal/filtering/http.go index b99bc253..266a422e 100644 --- a/internal/filtering/http.go +++ b/internal/filtering/http.go @@ -9,6 +9,8 @@ import ( "os" "path/filepath" "slices" + "strconv" + "strings" "sync" "time" @@ -420,15 +422,53 @@ type checkHostResp struct { FilterID rulelist.URLFilterID `json:"filter_id"` } +// handleCheckHost is the handler for the GET /control/filtering/check_host HTTP +// API. func (d *DNSFilter) handleCheckHost(w http.ResponseWriter, r *http.Request) { - host := r.URL.Query().Get("name") + query := r.URL.Query() + host := query.Get("name") + if host == "" { + aghhttp.Error( + r, + w, + http.StatusBadRequest, + `query parameter "name" is required`, + ) + + return + } + + cli := query.Get("client") + qTypeStr := query.Get("qtype") + qType, err := stringToDNSType(qTypeStr) + if err != nil { + aghhttp.Error( + r, + w, + http.StatusUnprocessableEntity, + "bad qtype query parameter: %q", + qTypeStr, + ) + + return + } setts := d.Settings() setts.FilteringEnabled = true setts.ProtectionEnabled = true - d.ApplyBlockedServices(setts) - result, err := d.CheckHost(host, dns.TypeA, setts) + addr, err := netip.ParseAddr(cli) + if err == nil { + setts.ClientIP = addr + d.ApplyAdditionalFiltering(addr, "", setts) + } else if cli != "" { + // TODO(s.chzhen): Set [Settings.ClientName] once urlfilter supports + // multiple client names. This will handle the case when a rule exists + // but the persistent client does not. + d.ApplyAdditionalFiltering(netip.Addr{}, cli, setts) + } + + result, err := d.CheckHost(host, qType, setts) if err != nil { aghhttp.Error( r, @@ -466,6 +506,33 @@ func (d *DNSFilter) handleCheckHost(w http.ResponseWriter, r *http.Request) { aghhttp.WriteJSONResponseOK(w, r, resp) } +// stringToDNSType is a helper function that converts a string to DNS type. If +// the string is empty, it returns the default value [dns.TypeA]. +func stringToDNSType(str string) (qtype uint16, err error) { + if str == "" { + return dns.TypeA, nil + } + + qtype, ok := dns.StringToType[str] + if ok { + return qtype, nil + } + + // typePref is a prefix for DNS types from experimental RFCs. + const typePref = "TYPE" + + if !strings.HasPrefix(str, typePref) { + return 0, errors.ErrBadEnumValue + } + + val, err := strconv.ParseUint(str[len(typePref):], 10, 16) + if err != nil { + return 0, errors.ErrBadEnumValue + } + + return uint16(val), nil +} + // setProtectedBool sets the value of a boolean pointer under a lock. l must // protect the value under ptr. // diff --git a/internal/filtering/http_test.go b/internal/filtering/http_internal_test.go similarity index 61% rename from internal/filtering/http_test.go rename to internal/filtering/http_internal_test.go index 8330dac6..a46d5d7b 100644 --- a/internal/filtering/http_test.go +++ b/internal/filtering/http_internal_test.go @@ -3,11 +3,15 @@ package filtering import ( "bytes" "encoding/json" + "fmt" "net/http" "net/http/httptest" + "net/netip" + "strings" "testing" "time" + "github.com/AdguardTeam/AdGuardHome/internal/schedule" "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -305,3 +309,168 @@ func TestDNSFilter_handleParentalStatus(t *testing.T) { }) } } + +func TestDNSFilter_HandleCheckHost(t *testing.T) { + const ( + cliName = "client_name" + cliID = "client_id" + + notFilteredHost = "not.filterd.example" + allowedHost = "allowed.example" + blockedHost = "blocked.example" + cliHost = "client.example" + qTypeHost = "qtype.example" + cliQTypeHost = "cli.qtype.example" + + target = "/control/check_host" + hostFmt = target + "?name=%s" + hostCliFmt = hostFmt + "&client=%s" + hostQTypeFmt = hostFmt + "&qtype=%s" + hostCliQTypeFmt = hostCliFmt + "&qtype=%s" + + allowedRuleFmt = "@@||%s^" + blockedRuleFmt = "||%s^" + blockedRuleCliFmt = blockedRuleFmt + "$client=%s" + blockedRuleQTypeFmt = blockedRuleFmt + "$dnstype=%s" + blockedRuleCliQTypeFmt = blockedRuleCliFmt + ",dnstype=%s" + ) + + var ( + allowedRule = fmt.Sprintf(allowedRuleFmt, allowedHost) + blockedRule = fmt.Sprintf(blockedRuleFmt, blockedHost) + blockedClientRule = fmt.Sprintf(blockedRuleCliFmt, cliHost, cliName) + blockedQTypeRule = fmt.Sprintf(blockedRuleQTypeFmt, qTypeHost, "CNAME") + blockedClientQTypeRule = fmt.Sprintf(blockedRuleCliQTypeFmt, cliQTypeHost, cliName, "CNAME") + + notFilteredURL = fmt.Sprintf(hostFmt, notFilteredHost) + allowedURL = fmt.Sprintf(hostFmt, allowedHost) + blockedURL = fmt.Sprintf(hostFmt, blockedHost) + blockedClientURL = fmt.Sprintf(hostCliFmt, cliHost, cliID) + allowedQTypeURL = fmt.Sprintf(hostQTypeFmt, qTypeHost, "AAAA") + blockedQTypeURL = fmt.Sprintf(hostQTypeFmt, qTypeHost, "CNAME") + allowedClientQTypeURL = fmt.Sprintf(hostCliQTypeFmt, cliQTypeHost, cliID, "AAAA") + blockedClientQTypeURL = fmt.Sprintf(hostCliQTypeFmt, cliQTypeHost, cliID, "CNAME") + ) + + rules := []string{ + allowedRule, + blockedRule, + blockedClientRule, + blockedQTypeRule, + blockedClientQTypeRule, + } + rulesData := strings.Join(rules, "\n") + + filters := []Filter{{ + ID: 0, Data: []byte(rulesData), + }} + + clientNames := map[string]string{ + cliID: cliName, + } + + dnsFilter, err := New(&Config{ + BlockedServices: &BlockedServices{ + Schedule: schedule.EmptyWeekly(), + }, + ApplyClientFiltering: func(clientID string, cliAddr netip.Addr, setts *Settings) { + setts.ClientName = clientNames[clientID] + }, + }, filters) + require.NoError(t, err) + + testCases := []struct { + name string + url string + want *checkHostResp + }{{ + name: "not_filtered", + url: notFilteredURL, + want: &checkHostResp{ + Reason: reasonNames[NotFilteredNotFound], + Rule: "", + Rules: []*checkHostRespRule{}, + }, + }, { + name: "allowed", + url: allowedURL, + want: &checkHostResp{ + Reason: reasonNames[NotFilteredAllowList], + Rule: allowedRule, + Rules: []*checkHostRespRule{{ + Text: allowedRule, + }}, + }, + }, { + name: "blocked", + url: blockedURL, + want: &checkHostResp{ + Reason: reasonNames[FilteredBlockList], + Rule: blockedRule, + Rules: []*checkHostRespRule{{ + Text: blockedRule, + }}, + }, + }, { + name: "blocked_client", + url: blockedClientURL, + want: &checkHostResp{ + Reason: reasonNames[FilteredBlockList], + Rule: blockedClientRule, + Rules: []*checkHostRespRule{{ + Text: blockedClientRule, + }}, + }, + }, { + name: "allowed_qtype", + url: allowedQTypeURL, + want: &checkHostResp{ + Reason: reasonNames[NotFilteredNotFound], + Rule: "", + Rules: []*checkHostRespRule{}, + }, + }, { + name: "blocked_qtype", + url: blockedQTypeURL, + want: &checkHostResp{ + Reason: reasonNames[FilteredBlockList], + Rule: blockedQTypeRule, + Rules: []*checkHostRespRule{{ + Text: blockedQTypeRule, + }}, + }, + }, { + name: "blocked_client_qtype", + url: blockedClientQTypeURL, + want: &checkHostResp{ + Reason: reasonNames[FilteredBlockList], + Rule: blockedClientQTypeRule, + Rules: []*checkHostRespRule{{ + Text: blockedClientQTypeRule, + }}, + }, + }, { + name: "allowed_client_qtype", + url: allowedClientQTypeURL, + want: &checkHostResp{ + Reason: reasonNames[NotFilteredNotFound], + Rule: "", + Rules: []*checkHostRespRule{}, + }, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, tc.url, nil) + w := httptest.NewRecorder() + + dnsFilter.handleCheckHost(w, r) + + res := &checkHostResp{} + err = json.NewDecoder(w.Body).Decode(res) + require.NoError(t, err) + + assert.Equal(t, tc.want, res) + }) + } +} diff --git a/internal/filtering/rewrite/storage_test.go b/internal/filtering/rewrite/storage_internal_test.go similarity index 100% rename from internal/filtering/rewrite/storage_test.go rename to internal/filtering/rewrite/storage_internal_test.go diff --git a/internal/filtering/rewrites_test.go b/internal/filtering/rewrites_internal_test.go similarity index 100% rename from internal/filtering/rewrites_test.go rename to internal/filtering/rewrites_internal_test.go diff --git a/internal/filtering/safesearch.go b/internal/filtering/safesearch.go index b389573a..92df7062 100644 --- a/internal/filtering/safesearch.go +++ b/internal/filtering/safesearch.go @@ -3,6 +3,9 @@ package filtering import "context" // SafeSearch interface describes a service for search engines hosts rewrites. +// +// TODO(s.chzhen): Move to a higher-level package to allow importing the client +// package into the filtering package. type SafeSearch interface { // CheckHost checks host with safe search filter. CheckHost must be safe // for concurrent use. qtype must be either [dns.TypeA] or [dns.TypeAAAA]. diff --git a/internal/home/authglinet_test.go b/internal/home/authglinet_internal_test.go similarity index 100% rename from internal/home/authglinet_test.go rename to internal/home/authglinet_internal_test.go diff --git a/internal/home/authratelimiter_test.go b/internal/home/authratelimiter_internal_test.go similarity index 100% rename from internal/home/authratelimiter_test.go rename to internal/home/authratelimiter_internal_test.go diff --git a/internal/home/clients.go b/internal/home/clients.go index 4a5a4d00..781e5e9d 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -121,6 +121,8 @@ func (clients *clientsContainer) Init( sigHdlr.addClientStorage(clients.storage) + filteringConf.ApplyClientFiltering = clients.storage.ApplyClientFiltering + return nil } diff --git a/internal/home/dns.go b/internal/home/dns.go index 4a89131c..4cfc63f8 100644 --- a/internal/home/dns.go +++ b/internal/home/dns.go @@ -247,7 +247,6 @@ func newServerConfig( hosts := aghalg.CoalesceSlice(dnsConf.BindHosts, []netip.Addr{netutil.IPv4Localhost()}) fwdConf := dnsConf.Config - fwdConf.FilterHandler = applyAdditionalFiltering fwdConf.ClientsContainer = clientsContainer newConf = &dnsforward.ServerConfig{ @@ -411,57 +410,6 @@ func getDNSEncryption(tlsMgr *tlsManager) (de dnsEncryption) { return de } -// applyAdditionalFiltering adds additional client information and settings if -// the client has them. -func applyAdditionalFiltering(clientIP netip.Addr, clientID string, setts *filtering.Settings) { - // pref is a prefix for logging messages around the scope. - const pref = "applying filters" - - globalContext.filters.ApplyBlockedServices(setts) - - log.Debug("%s: looking for client with ip %s and clientid %q", pref, clientIP, clientID) - - if !clientIP.IsValid() { - return - } - - setts.ClientIP = clientIP - - c, ok := globalContext.clients.storage.Find(clientID) - if !ok { - c, ok = globalContext.clients.storage.Find(clientIP.String()) - if !ok { - log.Debug("%s: no clients with ip %s and clientid %q", pref, clientIP, clientID) - - return - } - } - - log.Debug("%s: using settings for client %q (%s; %q)", pref, c.Name, clientIP, clientID) - - if c.UseOwnBlockedServices { - // TODO(e.burkov): Get rid of this crutch. - setts.ServicesRules = nil - svcs := c.BlockedServices.IDs - if !c.BlockedServices.Schedule.Contains(time.Now()) { - globalContext.filters.ApplyBlockedServicesList(setts, svcs) - log.Debug("%s: services for client %q set: %s", pref, c.Name, svcs) - } - } - - setts.ClientName = c.Name - setts.ClientTags = c.Tags - if !c.UseOwnSettings { - return - } - - setts.FilteringEnabled = c.FilteringEnabled - setts.SafeSearchEnabled = c.SafeSearchConf.Enabled - setts.ClientSafeSearch = c.SafeSearch - setts.SafeBrowsingEnabled = c.SafeBrowsingEnabled - setts.ParentalEnabled = c.ParentalEnabled -} - func startDNSServer() error { config.RLock() defer config.RUnlock() @@ -495,31 +443,6 @@ func startDNSServer() error { return nil } -// reconfigureDNSServer updates the DNS server configuration using the provided -// TLS settings. tlsMgr must not be nil. -func reconfigureDNSServer(tlsMgr *tlsManager) (err error) { - tlsConf := &tlsConfigSettings{} - tlsMgr.WriteDiskConfig(tlsConf) - - newConf, err := newServerConfig( - &config.DNS, - config.Clients.Sources, - tlsConf, - httpRegister, - globalContext.clients.storage, - ) - if err != nil { - return fmt.Errorf("generating forwarding dns server config: %w", err) - } - - err = globalContext.dnsServer.Reconfigure(newConf) - if err != nil { - return fmt.Errorf("starting forwarding dns server: %w", err) - } - - return nil -} - func stopDNSServer() (err error) { if !isRunning() { return nil diff --git a/internal/home/dns_internal_test.go b/internal/home/dns_internal_test.go deleted file mode 100644 index d7c627f7..00000000 --- a/internal/home/dns_internal_test.go +++ /dev/null @@ -1,206 +0,0 @@ -package home - -import ( - "net/netip" - "testing" - - "github.com/AdguardTeam/AdGuardHome/internal/client" - "github.com/AdguardTeam/AdGuardHome/internal/filtering" - "github.com/AdguardTeam/AdGuardHome/internal/schedule" - "github.com/AdguardTeam/golibs/logutil/slogutil" - "github.com/AdguardTeam/golibs/testutil" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -var testIPv4 = netip.AddrFrom4([4]byte{1, 2, 3, 4}) - -// newStorage is a helper function that returns a client storage filled with -// persistent clients. It also generates a UID for each client. -func newStorage(tb testing.TB, clients []*client.Persistent) (s *client.Storage) { - tb.Helper() - - ctx := testutil.ContextWithTimeout(tb, testTimeout) - s, err := client.NewStorage(ctx, &client.StorageConfig{ - Logger: slogutil.NewDiscardLogger(), - }) - require.NoError(tb, err) - - for _, p := range clients { - p.UID = client.MustNewUID() - require.NoError(tb, s.Add(ctx, p)) - } - - return s -} - -func TestApplyAdditionalFiltering(t *testing.T) { - var err error - - globalContext.filters, err = filtering.New(&filtering.Config{ - BlockedServices: &filtering.BlockedServices{ - Schedule: schedule.EmptyWeekly(), - }, - }, nil) - require.NoError(t, err) - - globalContext.clients.storage = newStorage(t, []*client.Persistent{{ - Name: "default", - ClientIDs: []string{"default"}, - UseOwnSettings: false, - SafeSearchConf: filtering.SafeSearchConfig{Enabled: false}, - FilteringEnabled: false, - SafeBrowsingEnabled: false, - ParentalEnabled: false, - }, { - Name: "custom_filtering", - ClientIDs: []string{"custom_filtering"}, - UseOwnSettings: true, - SafeSearchConf: filtering.SafeSearchConfig{Enabled: true}, - FilteringEnabled: true, - SafeBrowsingEnabled: true, - ParentalEnabled: true, - }, { - Name: "partial_custom_filtering", - ClientIDs: []string{"partial_custom_filtering"}, - UseOwnSettings: true, - SafeSearchConf: filtering.SafeSearchConfig{Enabled: true}, - FilteringEnabled: true, - SafeBrowsingEnabled: false, - ParentalEnabled: false, - }}) - - testCases := []struct { - name string - id string - FilteringEnabled assert.BoolAssertionFunc - SafeSearchEnabled assert.BoolAssertionFunc - SafeBrowsingEnabled assert.BoolAssertionFunc - ParentalEnabled assert.BoolAssertionFunc - }{{ - name: "global_settings", - id: "default", - FilteringEnabled: assert.False, - SafeSearchEnabled: assert.False, - SafeBrowsingEnabled: assert.False, - ParentalEnabled: assert.False, - }, { - name: "custom_settings", - id: "custom_filtering", - FilteringEnabled: assert.True, - SafeSearchEnabled: assert.True, - SafeBrowsingEnabled: assert.True, - ParentalEnabled: assert.True, - }, { - name: "partial", - id: "partial_custom_filtering", - FilteringEnabled: assert.True, - SafeSearchEnabled: assert.True, - SafeBrowsingEnabled: assert.False, - ParentalEnabled: assert.False, - }} - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - setts := &filtering.Settings{} - - applyAdditionalFiltering(testIPv4, tc.id, setts) - tc.FilteringEnabled(t, setts.FilteringEnabled) - tc.SafeSearchEnabled(t, setts.SafeSearchEnabled) - tc.SafeBrowsingEnabled(t, setts.SafeBrowsingEnabled) - tc.ParentalEnabled(t, setts.ParentalEnabled) - }) - } -} - -func TestApplyAdditionalFiltering_blockedServices(t *testing.T) { - filtering.InitModule() - - var ( - globalBlockedServices = []string{"ok"} - clientBlockedServices = []string{"ok", "mail_ru", "vk"} - invalidBlockedServices = []string{"invalid"} - - err error - ) - - globalContext.filters, err = filtering.New(&filtering.Config{ - BlockedServices: &filtering.BlockedServices{ - Schedule: schedule.EmptyWeekly(), - IDs: globalBlockedServices, - }, - }, nil) - require.NoError(t, err) - - globalContext.clients.storage = newStorage(t, []*client.Persistent{{ - Name: "default", - ClientIDs: []string{"default"}, - UseOwnBlockedServices: false, - }, { - Name: "no_services", - ClientIDs: []string{"no_services"}, - BlockedServices: &filtering.BlockedServices{ - Schedule: schedule.EmptyWeekly(), - }, - UseOwnBlockedServices: true, - }, { - Name: "services", - ClientIDs: []string{"services"}, - BlockedServices: &filtering.BlockedServices{ - Schedule: schedule.EmptyWeekly(), - IDs: clientBlockedServices, - }, - UseOwnBlockedServices: true, - }, { - Name: "invalid_services", - ClientIDs: []string{"invalid_services"}, - BlockedServices: &filtering.BlockedServices{ - Schedule: schedule.EmptyWeekly(), - IDs: invalidBlockedServices, - }, - UseOwnBlockedServices: true, - }, { - Name: "allow_all", - ClientIDs: []string{"allow_all"}, - BlockedServices: &filtering.BlockedServices{ - Schedule: schedule.FullWeekly(), - IDs: clientBlockedServices, - }, - UseOwnBlockedServices: true, - }}) - - testCases := []struct { - name string - id string - wantLen int - }{{ - name: "global_settings", - id: "default", - wantLen: len(globalBlockedServices), - }, { - name: "custom_settings", - id: "no_services", - wantLen: 0, - }, { - name: "custom_settings_block", - id: "services", - wantLen: len(clientBlockedServices), - }, { - name: "custom_settings_invalid", - id: "invalid_services", - wantLen: 0, - }, { - name: "custom_settings_inactive_schedule", - id: "allow_all", - wantLen: 0, - }} - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - setts := &filtering.Settings{} - - applyAdditionalFiltering(testIPv4, tc.id, setts) - require.Len(t, setts.ServicesRules, tc.wantLen) - }) - } -} diff --git a/internal/home/home.go b/internal/home/home.go index 7777e6dd..f774a06a 100644 --- a/internal/home/home.go +++ b/internal/home/home.go @@ -664,7 +664,8 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}, sigHdlr *signalH globalContext.auth, err = initUsers() fatalOnError(err) - tlsMgr, err := newTLSManager(config.TLS, config.DNS.ServePlainDNS) + tlsMgrLogger := slogLogger.With(slogutil.KeyPrefix, "tls_manager") + tlsMgr, err := newTLSManager(ctx, tlsMgrLogger, config.TLS, config.DNS.ServePlainDNS) if err != nil { log.Error("initializing tls: %s", err) onConfigModified() diff --git a/internal/home/home_test.go b/internal/home/home_internal_test.go similarity index 100% rename from internal/home/home_test.go rename to internal/home/home_internal_test.go diff --git a/internal/home/middlewares_test.go b/internal/home/middlewares_internal_test.go similarity index 100% rename from internal/home/middlewares_test.go rename to internal/home/middlewares_internal_test.go diff --git a/internal/home/options_test.go b/internal/home/options_internal_test.go similarity index 100% rename from internal/home/options_test.go rename to internal/home/options_internal_test.go diff --git a/internal/home/signal.go b/internal/home/signal.go index 824e62dd..638d3632 100644 --- a/internal/home/signal.go +++ b/internal/home/signal.go @@ -116,6 +116,6 @@ func (h *signalHandler) reloadConfig(ctx context.Context) { } if h.tlsManager != nil { - h.tlsManager.reload() + h.tlsManager.reload(ctx) } } diff --git a/internal/home/tls.go b/internal/home/tls.go index 0e8d5c62..778c674b 100644 --- a/internal/home/tls.go +++ b/internal/home/tls.go @@ -12,6 +12,7 @@ import ( "encoding/json" "encoding/pem" "fmt" + "log/slog" "net/http" "os" "strings" @@ -23,13 +24,17 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghtls" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/golibs/errors" - "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/logutil/slogutil" + "github.com/c2h5oh/datasize" "github.com/google/go-cmp/cmp" ) // tlsManager contains the current configuration and state of AdGuard Home TLS // encryption. type tlsManager struct { + // logger is used for logging the operation of the TLS Manager. + logger *slog.Logger + // status is the current status of the configuration. It is never nil. status *tlsConfigStatus @@ -45,31 +50,38 @@ type tlsManager struct { // newTLSManager initializes the manager of TLS configuration. m is always // non-nil while any returned error indicates that the TLS configuration isn't -// valid. Thus TLS may be initialized later, e.g. via the web UI. -func newTLSManager(conf tlsConfigSettings, servePlainDNS bool) (m *tlsManager, err error) { +// valid. Thus TLS may be initialized later, e.g. via the web UI. logger must +// not be nil. +func newTLSManager( + ctx context.Context, + logger *slog.Logger, + conf tlsConfigSettings, + servePlainDNS bool, +) (m *tlsManager, err error) { m = &tlsManager{ + logger: logger, status: &tlsConfigStatus{}, conf: conf, servePlainDNS: servePlainDNS, } if m.conf.Enabled { - err = m.load() + err = m.load(ctx) if err != nil { m.conf.Enabled = false return m, err } - m.setCertFileTime() + m.setCertFileTime(ctx) } return m, nil } // load reloads the TLS configuration from files or data from the config file. -func (m *tlsManager) load() (err error) { - err = loadTLSConf(&m.conf, m.status) +func (m *tlsManager) load(ctx context.Context) (err error) { + err = m.loadTLSConf(ctx, &m.conf, m.status) if err != nil { return fmt.Errorf("loading config: %w", err) } @@ -84,16 +96,16 @@ func (m *tlsManager) WriteDiskConfig(conf *tlsConfigSettings) { m.confLock.Unlock() } -// setCertFileTime sets t.certLastMod from the certificate. If there are -// errors, setCertFileTime logs them. -func (m *tlsManager) setCertFileTime() { +// setCertFileTime sets [tlsManager.certLastMod] from the certificate. If there +// are errors, setCertFileTime logs them. +func (m *tlsManager) setCertFileTime(ctx context.Context) { if len(m.conf.CertificatePath) == 0 { return } fi, err := os.Stat(m.conf.CertificatePath) if err != nil { - log.Error("tls: looking up certificate path: %s", err) + m.logger.ErrorContext(ctx, "looking up certificate path", slogutil.KeyError, err) return } @@ -117,8 +129,8 @@ func (m *tlsManager) start(_ context.Context) { globalContext.web.tlsConfigChanged(context.Background(), tlsConf) } -// reload updates the configuration and restarts t. -func (m *tlsManager) reload() { +// reload updates the configuration and restarts the TLS manager. +func (m *tlsManager) reload(ctx context.Context) { m.confLock.Lock() tlsConf := m.conf m.confLock.Unlock() @@ -127,33 +139,37 @@ func (m *tlsManager) reload() { return } - fi, err := os.Stat(tlsConf.CertificatePath) + certPath := tlsConf.CertificatePath + fi, err := os.Stat(certPath) if err != nil { - log.Error("tls: %s", err) + m.logger.ErrorContext(ctx, "checking certificate file", slogutil.KeyError, err) return } if fi.ModTime().UTC().Equal(m.certLastMod) { - log.Debug("tls: certificate file isn't modified") + m.logger.InfoContext(ctx, "certificate file is not modified") return } - log.Debug("tls: certificate file is modified") + m.logger.InfoContext(ctx, "certificate file is modified") m.confLock.Lock() - err = m.load() + err = m.load(ctx) m.confLock.Unlock() if err != nil { - log.Error("tls: reloading: %s", err) + m.logger.ErrorContext(ctx, "reloading", slogutil.KeyError, err) return } m.certLastMod = fi.ModTime().UTC() - _ = reconfigureDNSServer(m) + err = m.reconfigureDNSServer() + if err != nil { + m.logger.ErrorContext(ctx, "reconfiguring dns server", slogutil.KeyError, err) + } m.confLock.Lock() tlsConf = m.conf @@ -165,9 +181,38 @@ func (m *tlsManager) reload() { globalContext.web.tlsConfigChanged(context.Background(), tlsConf) } +// reconfigureDNSServer updates the DNS server configuration using the stored +// TLS settings. +func (m *tlsManager) reconfigureDNSServer() (err error) { + tlsConf := &tlsConfigSettings{} + m.WriteDiskConfig(tlsConf) + + newConf, err := newServerConfig( + &config.DNS, + config.Clients.Sources, + tlsConf, + httpRegister, + globalContext.clients.storage, + ) + if err != nil { + return fmt.Errorf("generating forwarding dns server config: %w", err) + } + + err = globalContext.dnsServer.Reconfigure(newConf) + if err != nil { + return fmt.Errorf("starting forwarding dns server: %w", err) + } + + return nil +} + // loadTLSConf loads and validates the TLS configuration. The returned error is // also set in status.WarningValidation. -func loadTLSConf(tlsConf *tlsConfigSettings, status *tlsConfigStatus) (err error) { +func (m *tlsManager) loadTLSConf( + ctx context.Context, + tlsConf *tlsConfigSettings, + status *tlsConfigStatus, +) (err error) { defer func() { if err != nil { status.WarningValidation = err.Error() @@ -190,7 +235,8 @@ func loadTLSConf(tlsConf *tlsConfigSettings, status *tlsConfigStatus) (err error return err } - err = validateCertificates( + err = m.validateCertificates( + ctx, status, tlsConf.CertificateChainData, tlsConf.PrivateKeyData, @@ -342,7 +388,7 @@ func (m *tlsManager) handleTLSValidate(w http.ResponseWriter, r *http.Request) { // Skip the error check, since we are only interested in the value of // status.WarningValidation. status := &tlsConfigStatus{} - _ = loadTLSConf(&setts.tlsConfigSettings, status) + _ = m.loadTLSConf(r.Context(), &setts.tlsConfigSettings, status) resp := tlsConfig{ tlsConfigSettingsExt: setts, tlsConfigStatus: status, @@ -353,6 +399,7 @@ func (m *tlsManager) handleTLSValidate(w http.ResponseWriter, r *http.Request) { // setConfig updates manager conf with the given one. func (m *tlsManager) setConfig( + ctx context.Context, newConf tlsConfigSettings, status *tlsConfigStatus, servePlain aghalg.NullBool, @@ -367,10 +414,10 @@ func (m *tlsManager) setConfig( newConf.DNSCryptConfigFile = m.conf.DNSCryptConfigFile newConf.PortDNSCrypt = m.conf.PortDNSCrypt if !cmp.Equal(m.conf, newConf, cmp.AllowUnexported(dnsforward.TLSConfig{})) { - log.Info("tls config has changed, restarting https server") + m.logger.InfoContext(ctx, "config has changed, restarting https server") restartHTTPS = true } else { - log.Info("tls: config has not changed") + m.logger.InfoContext(ctx, "config has not changed") } // Note: don't do just `t.conf = data` because we must preserve all other members of t.conf @@ -398,6 +445,8 @@ func (m *tlsManager) setConfig( // handleTLSConfigure is the handler for the POST /control/tls/configure HTTP // API. func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + req, err := unmarshalTLS(r) if err != nil { aghhttp.Error(r, w, http.StatusBadRequest, "Failed to unmarshal TLS config: %s", err) @@ -416,7 +465,7 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request) } status := &tlsConfigStatus{} - err = loadTLSConf(&req.tlsConfigSettings, status) + err = m.loadTLSConf(ctx, &req.tlsConfigSettings, status) if err != nil { resp := tlsConfig{ tlsConfigSettingsExt: req, @@ -428,8 +477,8 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request) return } - restartHTTPS := m.setConfig(req.tlsConfigSettings, status, req.ServePlainDNS) - m.setCertFileTime() + restartHTTPS := m.setConfig(ctx, req.tlsConfigSettings, status, req.ServePlainDNS) + m.setCertFileTime(ctx) if req.ServePlainDNS != aghalg.NBNull { func() { @@ -442,8 +491,10 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request) onConfigModified() - err = reconfigureDNSServer(m) + err = m.reconfigureDNSServer() if err != nil { + m.logger.ErrorContext(ctx, "reconfiguring dns server", slogutil.KeyError, err) + aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err) return @@ -530,15 +581,27 @@ func validatePorts( // validateCertChain verifies certs using the first as the main one and others // as intermediate. srvName stands for the expected DNS name. -func validateCertChain(certs []*x509.Certificate, srvName string) (err error) { +func (m *tlsManager) validateCertChain( + ctx context.Context, + certs []*x509.Certificate, + srvName string, +) (err error) { main, others := certs[0], certs[1:] pool := x509.NewCertPool() for _, cert := range others { - log.Info("tls: got an intermediate cert") pool.AddCert(cert) } + othersLen := len(others) + if othersLen > 0 { + m.logger.InfoContext( + ctx, + "verifying certificate chain: got an intermediate cert", + "num", othersLen, + ) + } + opts := x509.VerifyOptions{ DNSName: srvName, Roots: globalContext.tlsRoots, @@ -552,15 +615,18 @@ func validateCertChain(certs []*x509.Certificate, srvName string) (err error) { return nil } -// errNoIPInCert is the error that is returned from [parseCertChain] if the leaf -// certificate doesn't contain IPs. +// errNoIPInCert is the error that is returned from [tlsManager.parseCertChain] +// if the leaf certificate doesn't contain IPs. const errNoIPInCert errors.Error = `certificates has no IP addresses; ` + `DNS-over-TLS won't be advertised via DDR` // parseCertChain parses the certificate chain from raw data, and returns it. // If ok is true, the returned error, if any, is not critical. -func parseCertChain(chain []byte) (parsedCerts []*x509.Certificate, ok bool, err error) { - log.Debug("tls: got certificate chain: %d bytes", len(chain)) +func (m *tlsManager) parseCertChain( + ctx context.Context, + chain []byte, +) (parsedCerts []*x509.Certificate, ok bool, err error) { + m.logger.DebugContext(ctx, "parsing certificate chain", "size", datasize.ByteSize(len(chain))) var certs []*pem.Block for decoded, pemblock := pem.Decode(chain); decoded != nil; { @@ -576,7 +642,7 @@ func parseCertChain(chain []byte) (parsedCerts []*x509.Certificate, ok bool, err return nil, false, err } - log.Info("tls: number of certs: %d", len(parsedCerts)) + m.logger.InfoContext(ctx, "parsing multiple pem certificates", "num", len(parsedCerts)) if !aghtls.CertificateHasIP(parsedCerts[0]) { err = errNoIPInCert @@ -643,7 +709,8 @@ func validatePKey(pkey []byte) (keyType string, err error) { // validateCertificates processes certificate data and its private key. status // must not be nil, since it's used to accumulate the validation results. Other // parameters are optional. -func validateCertificates( +func (m *tlsManager) validateCertificates( + ctx context.Context, status *tlsConfigStatus, certChain []byte, pkey []byte, @@ -652,7 +719,7 @@ func validateCertificates( // Check only the public certificate separately from the key. if len(certChain) > 0 { var certs []*x509.Certificate - certs, status.ValidCert, err = parseCertChain(certChain) + certs, status.ValidCert, err = m.parseCertChain(ctx, certChain) if !status.ValidCert { // Don't wrap the error, since it's informative enough as is. return err @@ -665,7 +732,7 @@ func validateCertificates( status.NotBefore = mainCert.NotBefore status.DNSNames = mainCert.DNSNames - if chainErr := validateCertChain(certs, serverName); chainErr != nil { + if chainErr := m.validateCertChain(ctx, certs, serverName); chainErr != nil { // Let self-signed certs through and don't return this error to set // its message into the status.WarningValidation afterwards. err = chainErr diff --git a/internal/home/tls_internal_test.go b/internal/home/tls_internal_test.go index 201cf551..e67393b4 100644 --- a/internal/home/tls_internal_test.go +++ b/internal/home/tls_internal_test.go @@ -1,11 +1,33 @@ package home import ( + "bytes" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "encoding/base64" + "encoding/json" + "encoding/pem" + "fmt" + "math/big" + "net" + "net/http" + "net/http/httptest" + "net/netip" + "os" + "path/filepath" "testing" "time" + "github.com/AdguardTeam/AdGuardHome/internal/aghalg" + "github.com/AdguardTeam/AdGuardHome/internal/client" + "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" + "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/testutil" + "github.com/AdguardTeam/golibs/timeutil" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) var testCertChainData = []byte(`-----BEGIN CERTIFICATE----- @@ -41,9 +63,15 @@ kXS9jgARhhiWXJrk -----END PRIVATE KEY-----`) func TestValidateCertificates(t *testing.T) { + ctx := testutil.ContextWithTimeout(t, testTimeout) + logger := slogutil.NewDiscardLogger() + + m, err := newTLSManager(ctx, logger, tlsConfigSettings{}, false) + require.NoError(t, err) + t.Run("bad_certificate", func(t *testing.T) { status := &tlsConfigStatus{} - err := validateCertificates(status, []byte("bad cert"), nil, "") + err = m.validateCertificates(ctx, status, []byte("bad cert"), nil, "") testutil.AssertErrorMsg(t, "empty certificate", err) assert.False(t, status.ValidCert) assert.False(t, status.ValidChain) @@ -51,14 +79,14 @@ func TestValidateCertificates(t *testing.T) { t.Run("bad_private_key", func(t *testing.T) { status := &tlsConfigStatus{} - err := validateCertificates(status, nil, []byte("bad priv key"), "") + err = m.validateCertificates(ctx, status, nil, []byte("bad priv key"), "") testutil.AssertErrorMsg(t, "no valid keys were found", err) assert.False(t, status.ValidKey) }) t.Run("valid", func(t *testing.T) { status := &tlsConfigStatus{} - err := validateCertificates(status, testCertChainData, testPrivateKeyData, "") + err = m.validateCertificates(ctx, status, testCertChainData, testPrivateKeyData, "") assert.Error(t, err) notBefore := time.Date(2019, 2, 27, 9, 24, 23, 0, time.UTC) @@ -75,3 +103,422 @@ func TestValidateCertificates(t *testing.T) { assert.True(t, status.ValidPair) }) } + +// storeGlobals is a test helper function that saves global variables and +// restores them once the test is complete. +// +// The global variables are: +// - [configuration.dns] +// - [homeContext.clients.storage] +// - [homeContext.dnsServer] +// - [homeContext.mux] +// - [homeContext.web] +// +// TODO(s.chzhen): Remove this once the TLS manager no longer accesses global +// variables. Make tests that use this helper concurrent. +func storeGlobals(tb testing.TB) { + tb.Helper() + + prevConfig := config + storage := globalContext.clients.storage + dnsServer := globalContext.dnsServer + mux := globalContext.mux + web := globalContext.web + + tb.Cleanup(func() { + config = prevConfig + globalContext.clients.storage = storage + globalContext.dnsServer = dnsServer + globalContext.mux = mux + globalContext.web = web + }) +} + +// newCertAndKey is a helper function that generates certificate and key. +func newCertAndKey(tb testing.TB, n int64) (certDER []byte, key *rsa.PrivateKey) { + tb.Helper() + + key, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(tb, err) + + certTmpl := &x509.Certificate{ + SerialNumber: big.NewInt(n), + } + + certDER, err = x509.CreateCertificate(rand.Reader, certTmpl, certTmpl, &key.PublicKey, key) + require.NoError(tb, err) + + return certDER, key +} + +// writeCertAndKey is a helper function that writes certificate and key to +// specified paths. +func writeCertAndKey( + tb testing.TB, + certDER []byte, + certPath string, + key *rsa.PrivateKey, + keyPath string, +) { + tb.Helper() + + certFile, err := os.OpenFile(certPath, os.O_WRONLY|os.O_CREATE, 0o600) + require.NoError(tb, err) + + defer func() { + err = certFile.Close() + require.NoError(tb, err) + }() + + err = pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}) + require.NoError(tb, err) + + keyFile, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE, 0o600) + require.NoError(tb, err) + + defer func() { + err = keyFile.Close() + require.NoError(tb, err) + }() + + err = pem.Encode(keyFile, &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(key), + }) + require.NoError(tb, err) +} + +// assertCertSerialNumber is a helper function that checks serial number of the +// TLS certificate. +func assertCertSerialNumber(tb testing.TB, conf *tlsConfigSettings, wantSN int64) { + tb.Helper() + + cert, err := tls.X509KeyPair(conf.CertificateChainData, conf.PrivateKeyData) + require.NoError(tb, err) + + assert.Equal(tb, wantSN, cert.Leaf.SerialNumber.Int64()) +} + +func TestTLSManager_Reload(t *testing.T) { + storeGlobals(t) + + var ( + logger = slogutil.NewDiscardLogger() + ctx = testutil.ContextWithTimeout(t, testTimeout) + err error + ) + + globalContext.dnsServer, err = dnsforward.NewServer(dnsforward.DNSCreateParams{ + Logger: logger, + }) + require.NoError(t, err) + + globalContext.clients.storage, err = client.NewStorage(ctx, &client.StorageConfig{ + Logger: logger, + Clock: timeutil.SystemClock{}, + }) + require.NoError(t, err) + + globalContext.mux = http.NewServeMux() + + globalContext.web, err = initWeb(ctx, options{}, nil, nil, logger, nil, false) + require.NoError(t, err) + + const ( + snBefore int64 = 1 + snAfter int64 = 2 + ) + + tmpDir := t.TempDir() + certPath := filepath.Join(tmpDir, "cert.pem") + keyPath := filepath.Join(tmpDir, "key.pem") + + certDER, key := newCertAndKey(t, snBefore) + writeCertAndKey(t, certDER, certPath, key, keyPath) + + m, err := newTLSManager(ctx, logger, tlsConfigSettings{ + Enabled: true, + TLSConfig: dnsforward.TLSConfig{ + CertificatePath: certPath, + PrivateKeyPath: keyPath, + }, + }, false) + require.NoError(t, err) + + conf := &tlsConfigSettings{} + m.WriteDiskConfig(conf) + assertCertSerialNumber(t, conf, snBefore) + + certDER, key = newCertAndKey(t, snAfter) + writeCertAndKey(t, certDER, certPath, key, keyPath) + + m.reload(ctx) + + m.WriteDiskConfig(conf) + assertCertSerialNumber(t, conf, snAfter) +} + +func TestTLSManager_HandleTLSStatus(t *testing.T) { + var ( + logger = slogutil.NewDiscardLogger() + ctx = testutil.ContextWithTimeout(t, testTimeout) + err error + ) + + m, err := newTLSManager(ctx, logger, tlsConfigSettings{ + Enabled: true, + TLSConfig: dnsforward.TLSConfig{ + CertificateChain: string(testCertChainData), + PrivateKey: string(testPrivateKeyData), + }, + }, false) + require.NoError(t, err) + + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/control/tls/status", nil) + m.handleTLSStatus(w, r) + + res := &tlsConfigSettingsExt{} + err = json.NewDecoder(w.Body).Decode(res) + require.NoError(t, err) + + wantCertificateChain := base64.StdEncoding.EncodeToString(testCertChainData) + assert.True(t, res.Enabled) + assert.Equal(t, wantCertificateChain, res.CertificateChain) + assert.True(t, res.PrivateKeySaved) +} + +func TestValidateTLSSettings(t *testing.T) { + storeGlobals(t) + + var ( + logger = slogutil.NewDiscardLogger() + ctx = testutil.ContextWithTimeout(t, testTimeout) + err error + ) + + ln, err := net.Listen("tcp", ":0") + require.NoError(t, err) + + testutil.CleanupAndRequireSuccess(t, ln.Close) + + addr := testutil.RequireTypeAssert[*net.TCPAddr](t, ln.Addr()) + + busyPort := addr.Port + + globalContext.mux = http.NewServeMux() + + globalContext.web, err = initWeb(ctx, options{}, nil, nil, logger, nil, false) + require.NoError(t, err) + + testCases := []struct { + setts tlsConfigSettingsExt + name string + wantErr string + }{{ + name: "basic", + setts: tlsConfigSettingsExt{}, + wantErr: "", + }, { + setts: tlsConfigSettingsExt{ + ServePlainDNS: aghalg.NBFalse, + }, + name: "disabled_all", + wantErr: "plain DNS is required in case encryption protocols are disabled", + }, { + setts: tlsConfigSettingsExt{ + tlsConfigSettings: tlsConfigSettings{ + Enabled: true, + PortHTTPS: uint16(busyPort), + }, + }, + name: "busy_port", + wantErr: fmt.Sprintf("port %d is not available, cannot enable HTTPS on it", busyPort), + }, { + setts: tlsConfigSettingsExt{ + tlsConfigSettings: tlsConfigSettings{ + Enabled: true, + PortHTTPS: 4433, + PortDNSOverTLS: 4433, + }, + }, + name: "duplicate_port", + wantErr: "validating tcp ports: duplicated values: [4433]", + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err = validateTLSSettings(tc.setts) + testutil.AssertErrorMsg(t, tc.wantErr, err) + }) + } +} + +func TestTLSManager_HandleTLSValidate(t *testing.T) { + storeGlobals(t) + + var ( + logger = slogutil.NewDiscardLogger() + ctx = testutil.ContextWithTimeout(t, testTimeout) + err error + ) + + globalContext.mux = http.NewServeMux() + + globalContext.web, err = initWeb(ctx, options{}, nil, nil, logger, nil, false) + require.NoError(t, err) + + m, err := newTLSManager(ctx, logger, tlsConfigSettings{ + Enabled: true, + TLSConfig: dnsforward.TLSConfig{ + CertificateChain: string(testCertChainData), + PrivateKey: string(testPrivateKeyData), + }, + }, false) + require.NoError(t, err) + + setts := &tlsConfigSettingsExt{ + tlsConfigSettings: tlsConfigSettings{ + Enabled: true, + TLSConfig: dnsforward.TLSConfig{ + CertificateChain: base64.StdEncoding.EncodeToString(testCertChainData), + PrivateKey: base64.StdEncoding.EncodeToString(testPrivateKeyData), + }, + }, + } + + req, err := json.Marshal(setts) + require.NoError(t, err) + + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodPost, "/control/tls/validate", bytes.NewReader(req)) + m.handleTLSValidate(w, r) + + res := &tlsConfigStatus{} + err = json.NewDecoder(w.Body).Decode(res) + require.NoError(t, err) + + cert, err := tls.X509KeyPair(testCertChainData, testPrivateKeyData) + require.NoError(t, err) + + wantIssuer := cert.Leaf.Issuer.String() + assert.Equal(t, wantIssuer, res.Issuer) +} + +func TestTLSManager_HandleTLSConfigure(t *testing.T) { + // Store the global state before making any changes. + storeGlobals(t) + + var ( + logger = slogutil.NewDiscardLogger() + ctx = testutil.ContextWithTimeout(t, testTimeout) + err error + ) + + globalContext.dnsServer, err = dnsforward.NewServer(dnsforward.DNSCreateParams{ + Logger: logger, + }) + require.NoError(t, err) + + err = globalContext.dnsServer.Prepare(&dnsforward.ServerConfig{ + Config: dnsforward.Config{ + UpstreamMode: dnsforward.UpstreamModeLoadBalance, + EDNSClientSubnet: &dnsforward.EDNSClientSubnet{Enabled: false}, + ClientsContainer: dnsforward.EmptyClientsContainer{}, + }, + ServePlainDNS: true, + }) + require.NoError(t, err) + + globalContext.clients.storage, err = client.NewStorage(ctx, &client.StorageConfig{ + Logger: logger, + Clock: timeutil.SystemClock{}, + }) + require.NoError(t, err) + + globalContext.mux = http.NewServeMux() + + globalContext.web, err = initWeb(ctx, options{}, nil, nil, logger, nil, false) + require.NoError(t, err) + + config.DNS.BindHosts = []netip.Addr{netip.MustParseAddr("127.0.0.1")} + config.DNS.Port = 0 + + const wantSerialNumber int64 = 1 + + // Prepare the TLS manager configuration. + tmpDir := t.TempDir() + certPath := filepath.Join(tmpDir, "cert.pem") + keyPath := filepath.Join(tmpDir, "key.pem") + + certDER, key := newCertAndKey(t, wantSerialNumber) + writeCertAndKey(t, certDER, certPath, key, keyPath) + + // Initialize the TLS manager and assert its configuration. + m, err := newTLSManager(ctx, logger, tlsConfigSettings{ + Enabled: true, + TLSConfig: dnsforward.TLSConfig{ + CertificatePath: certPath, + PrivateKeyPath: keyPath, + }, + }, true) + require.NoError(t, err) + + conf := &tlsConfigSettings{} + m.WriteDiskConfig(conf) + assertCertSerialNumber(t, conf, wantSerialNumber) + + // Prepare a request with the new TLS configuration. + setts := &tlsConfigSettingsExt{ + tlsConfigSettings: tlsConfigSettings{ + Enabled: true, + PortHTTPS: 4433, + TLSConfig: dnsforward.TLSConfig{ + CertificateChain: base64.StdEncoding.EncodeToString(testCertChainData), + PrivateKey: base64.StdEncoding.EncodeToString(testPrivateKeyData), + }, + }, + } + + req, err := json.Marshal(setts) + require.NoError(t, err) + + r := httptest.NewRequest(http.MethodPost, "/control/tls/configure", bytes.NewReader(req)) + w := httptest.NewRecorder() + + // Reconfigure the TLS manager. + m.handleTLSConfigure(w, r) + + // The [tlsManager.handleTLSConfigure] method will start the DNS server and + // it should be stopped after the test ends. + testutil.CleanupAndRequireSuccess(t, globalContext.dnsServer.Stop) + + res := &tlsConfig{ + tlsConfigStatus: &tlsConfigStatus{}, + } + err = json.NewDecoder(w.Body).Decode(res) + require.NoError(t, err) + + cert, err := tls.X509KeyPair(testCertChainData, testPrivateKeyData) + require.NoError(t, err) + + wantIssuer := cert.Leaf.Issuer.String() + assert.Equal(t, wantIssuer, res.tlsConfigStatus.Issuer) + + // Assert that the Web API's TLS configuration has been updated. + // + // TODO(s.chzhen): Remove when [httpsServer.cond] is removed. + assert.Eventually(t, func() bool { + globalContext.web.httpsServer.condLock.Lock() + defer globalContext.web.httpsServer.condLock.Unlock() + + cert = globalContext.web.httpsServer.cert + if cert.Leaf == nil { + return false + } + + assert.Equal(t, wantIssuer, cert.Leaf.Issuer.String()) + + return true + }, testTimeout, testTimeout/10) +} diff --git a/internal/querylog/decode_test.go b/internal/querylog/decode_internal_test.go similarity index 100% rename from internal/querylog/decode_test.go rename to internal/querylog/decode_internal_test.go diff --git a/internal/querylog/qlog_test.go b/internal/querylog/qlog_internal_test.go similarity index 100% rename from internal/querylog/qlog_test.go rename to internal/querylog/qlog_internal_test.go diff --git a/internal/querylog/qlogfile_test.go b/internal/querylog/qlogfile_internal_test.go similarity index 100% rename from internal/querylog/qlogfile_test.go rename to internal/querylog/qlogfile_internal_test.go diff --git a/internal/querylog/qlogreader_test.go b/internal/querylog/qlogreader_internal_test.go similarity index 100% rename from internal/querylog/qlogreader_test.go rename to internal/querylog/qlogreader_internal_test.go diff --git a/internal/querylog/search_test.go b/internal/querylog/search_internal_test.go similarity index 100% rename from internal/querylog/search_test.go rename to internal/querylog/search_internal_test.go diff --git a/internal/stats/http_test.go b/internal/stats/http_internal_test.go similarity index 100% rename from internal/stats/http_test.go rename to internal/stats/http_internal_test.go diff --git a/openapi/CHANGELOG.md b/openapi/CHANGELOG.md index 0e6dbed0..20132c42 100644 --- a/openapi/CHANGELOG.md +++ b/openapi/CHANGELOG.md @@ -4,6 +4,12 @@ ## v0.108.0: API changes +## v0.107.58: API changes + +### The ability to check rules for query types and/or clients: GET /control/check_host + +- Added optional `client` and `qtype` URL query parameters. + ## v0.107.57: API changes - The new field `"upstream_timeout"` in `GET /control/dns_info` and `POST /control/dns_config` is the number of seconds to wait for a response from the upstream server. diff --git a/openapi/openapi.yaml b/openapi/openapi.yaml index 633138c4..d6c47ce2 100644 --- a/openapi/openapi.yaml +++ b/openapi/openapi.yaml @@ -739,6 +739,20 @@ - 'name': 'name' 'in': 'query' 'description': 'Filter by host name' + 'required': true + 'example': 'google.com' + 'schema': + 'type': 'string' + - 'name': 'client' + 'in': 'query' + 'description': 'Optional ClientID or client IP address' + 'example': '192.0.2.1' + 'schema': + 'type': 'string' + - 'name': 'qtype' + 'in': 'query' + 'description': 'Optional DNS type' + 'example': 'AAAA' 'schema': 'type': 'string' 'responses':