Merge branch 'master' into 6902-doc-fix

This commit is contained in:
Ainar Garipov
2024-04-26 18:44:52 +03:00
45 changed files with 1413 additions and 1771 deletions

View File

@@ -27,7 +27,40 @@ NOTE: Add new changes BELOW THIS COMMENT.
- Support for comments in the ipset file ([#5345]). - Support for comments in the ipset file ([#5345]).
### Changed
- Private rDNS resolution now also affects `SOA` and `NS` requests ([#6882]).
- Rewrite rules mechanics was changed due to improve resolving in safe search.
### Fixed
- Domain specifications for top-level domains not considered for requests to
unqualified domains ([#6744]).
- Support for link-local subnets, i.e. `fe80::/16`, as client identifiers
([#6312]).
- Issues with QUIC and HTTP/3 upstreams on older Linux kernel versions
([#6422]).
- YouTube restricted mode is not enforced by HTTPS queries on Firefox.
- Support for link-local subnets, i.e. `fe80::/16`, in the access settings
([#6192]).
- The ability to apply an invalid configuration for private RDNS, which led to
server inoperability.
- Ignoring query log for clients with ClientID set ([#5812]).
- Subdomains of `in-addr.arpa` and `ip6.arpa` containing zero-length prefix
incorrectly considered invalid when specified for private RDNS upstream
servers ([#6854]).
- Unspecified IP addresses aren't checked when using "Fastest IP address" mode
([#6875]).
[#5345]: https://github.com/AdguardTeam/AdGuardHome/issues/5345 [#5345]: https://github.com/AdguardTeam/AdGuardHome/issues/5345
[#5812]: https://github.com/AdguardTeam/AdGuardHome/issues/5812
[#6192]: https://github.com/AdguardTeam/AdGuardHome/issues/6192
[#6312]: https://github.com/AdguardTeam/AdGuardHome/issues/6312
[#6422]: https://github.com/AdguardTeam/AdGuardHome/issues/6422
[#6744]: https://github.com/AdguardTeam/AdGuardHome/issues/6744
[#6854]: https://github.com/AdguardTeam/AdGuardHome/issues/6854
[#6875]: https://github.com/AdguardTeam/AdGuardHome/issues/6875
[#6882]: https://github.com/AdguardTeam/AdGuardHome/issues/6882
<!-- <!--
NOTE: Add new changes ABOVE THIS COMMENT. NOTE: Add new changes ABOVE THIS COMMENT.
@@ -41,7 +74,7 @@ See also the [v0.107.48 GitHub milestone][ms-v0.107.48].
### Fixed ### Fixed
- Access settings not being applied to encrypted protocols ([#6890]) - Access settings not being applied to encrypted protocols ([#6890]).
[#6890]: https://github.com/AdguardTeam/AdGuardHome/issues/6890 [#6890]: https://github.com/AdguardTeam/AdGuardHome/issues/6890

View File

@@ -13,14 +13,14 @@
"fallback_dns_desc": "List of fallback DNS servers used when upstream DNS servers are not responding. The syntax is the same as in the main upstreams field above.", "fallback_dns_desc": "List of fallback DNS servers used when upstream DNS servers are not responding. The syntax is the same as in the main upstreams field above.",
"fallback_dns_placeholder": "Enter one fallback DNS server per line", "fallback_dns_placeholder": "Enter one fallback DNS server per line",
"local_ptr_title": "Private reverse DNS servers", "local_ptr_title": "Private reverse DNS servers",
"local_ptr_desc": "The DNS servers that AdGuard Home uses for local PTR queries. These servers are used to resolve PTR requests for addresses in private IP ranges, for example \"192.168.12.34\", using reverse DNS. If not set, AdGuard Home uses the addresses of the default DNS resolvers of your OS except for the addresses of AdGuard Home itself.", "local_ptr_desc": "DNS servers used by AdGuard Home for private PTR, SOA, and NS requests. A request is considered private if it asks for an ARPA domain containing a subnet within private IP ranges (such as \"192.168.12.34\") and comes from a client with a private IP address. If not set, the default DNS resolvers of your OS will be used, except for the AdGuard Home IP addresses.",
"local_ptr_default_resolver": "By default, AdGuard Home uses the following reverse DNS resolvers: {{ip}}.", "local_ptr_default_resolver": "By default, AdGuard Home uses the following reverse DNS resolvers: {{ip}}.",
"local_ptr_no_default_resolver": "AdGuard Home could not determine suitable private reverse DNS resolvers for this system.", "local_ptr_no_default_resolver": "AdGuard Home could not determine suitable private reverse DNS resolvers for this system.",
"local_ptr_placeholder": "Enter one IP address per line", "local_ptr_placeholder": "Enter one IP address per line",
"resolve_clients_title": "Enable reverse resolving of clients' IP addresses", "resolve_clients_title": "Enable reverse resolving of clients' IP addresses",
"resolve_clients_desc": "Reversely resolve clients' IP addresses into their hostnames by sending PTR queries to corresponding resolvers (private DNS servers for local clients, upstream servers for clients with public IP addresses).", "resolve_clients_desc": "Reversely resolve clients' IP addresses into their hostnames by sending PTR queries to corresponding resolvers (private DNS servers for local clients, upstream servers for clients with public IP addresses).",
"use_private_ptr_resolvers_title": "Use private reverse DNS resolvers", "use_private_ptr_resolvers_title": "Use private reverse DNS resolvers",
"use_private_ptr_resolvers_desc": "Perform reverse DNS lookups for locally served addresses using these upstream servers. If disabled, AdGuard Home responds with NXDOMAIN to all such PTR requests except for clients known from DHCP, /etc/hosts, and so on.", "use_private_ptr_resolvers_desc": "Resolve PTR, SOA, and NS requests for ARPA domains containing private IP addresses through private upstream servers, DHCP, /etc/hosts, etc. If disabled, AdGuard Home will respond to all such requests with NXDOMAIN.",
"check_dhcp_servers": "Check for DHCP servers", "check_dhcp_servers": "Check for DHCP servers",
"save_config": "Save configuration", "save_config": "Save configuration",
"enabled_dhcp": "DHCP server enabled", "enabled_dhcp": "DHCP server enabled",

23
go.mod
View File

@@ -3,8 +3,8 @@ module github.com/AdguardTeam/AdGuardHome
go 1.22.2 go 1.22.2
require ( require (
github.com/AdguardTeam/dnsproxy v0.67.0 github.com/AdguardTeam/dnsproxy v0.71.0
github.com/AdguardTeam/golibs v0.21.0 github.com/AdguardTeam/golibs v0.23.2
github.com/AdguardTeam/urlfilter v0.18.0 github.com/AdguardTeam/urlfilter v0.18.0
github.com/NYTimes/gziphandler v1.1.1 github.com/NYTimes/gziphandler v1.1.1
github.com/ameshkov/dnscrypt/v2 v2.2.7 github.com/ameshkov/dnscrypt/v2 v2.2.7
@@ -28,14 +28,15 @@ require (
// own code for that. Perhaps, use gopacket. // own code for that. Perhaps, use gopacket.
github.com/mdlayher/raw v0.1.0 github.com/mdlayher/raw v0.1.0
github.com/miekg/dns v1.1.58 github.com/miekg/dns v1.1.58
github.com/quic-go/quic-go v0.41.0 // TODO(a.garipov): Use release version.
github.com/stretchr/testify v1.8.4 github.com/quic-go/quic-go v0.42.1-0.20240424141022-12aa63824c7f
github.com/stretchr/testify v1.9.0
github.com/ti-mo/netfilter v0.5.1 github.com/ti-mo/netfilter v0.5.1
go.etcd.io/bbolt v1.3.9 go.etcd.io/bbolt v1.3.9
golang.org/x/crypto v0.21.0 golang.org/x/crypto v0.22.0
golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 golang.org/x/exp v0.0.0-20240409090435-93d18d7e34b8
golang.org/x/net v0.23.0 golang.org/x/net v0.24.0
golang.org/x/sys v0.18.0 golang.org/x/sys v0.19.0
gopkg.in/natefinch/lumberjack.v2 v2.2.1 gopkg.in/natefinch/lumberjack.v2 v2.2.1
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
howett.net/plist v1.0.1 howett.net/plist v1.0.1
@@ -58,9 +59,9 @@ require (
github.com/quic-go/qpack v0.4.0 // indirect github.com/quic-go/qpack v0.4.0 // indirect
github.com/u-root/uio v0.0.0-20240224005618-d2acac8f3701 // indirect github.com/u-root/uio v0.0.0-20240224005618-d2acac8f3701 // indirect
go.uber.org/mock v0.4.0 // indirect go.uber.org/mock v0.4.0 // indirect
golang.org/x/mod v0.16.0 // indirect golang.org/x/mod v0.17.0 // indirect
golang.org/x/sync v0.6.0 // indirect golang.org/x/sync v0.7.0 // indirect
golang.org/x/text v0.14.0 // indirect golang.org/x/text v0.14.0 // indirect
golang.org/x/tools v0.19.0 // indirect golang.org/x/tools v0.20.0 // indirect
gonum.org/v1/gonum v0.14.0 // indirect gonum.org/v1/gonum v0.14.0 // indirect
) )

50
go.sum
View File

@@ -1,7 +1,7 @@
github.com/AdguardTeam/dnsproxy v0.67.0 h1:7oKfcA8sm9d1N4qvhsNmQWBX4+fs3sX4cAnERmBXEbw= github.com/AdguardTeam/dnsproxy v0.71.0 h1:7epvSeTtjTdwbeifRJK6jTXldoJwqd3GWqAYDlSo7zQ=
github.com/AdguardTeam/dnsproxy v0.67.0/go.mod h1:XLfD6IpSplUZZ+f5vhWSJW1mp4wm+KkHWiMo9w7U1Ls= github.com/AdguardTeam/dnsproxy v0.71.0/go.mod h1:rCaCL4m4n63sgwTOyUVdc7MC42PlUYBt11Fz/UjD+kM=
github.com/AdguardTeam/golibs v0.21.0 h1:0swWyNaHTmT7aMwffKd9d54g4wBd8Oaj0fl+5l/PRdE= github.com/AdguardTeam/golibs v0.23.2 h1:rMjYantwtQ39e8G4zBQ6ZLlm4s3XH30Bc9VxhoOHwao=
github.com/AdguardTeam/golibs v0.21.0/go.mod h1:/votX6WK1PdcZ3T2kBOPjPCGmfhlKixhI6ljYrFRPvI= github.com/AdguardTeam/golibs v0.23.2/go.mod h1:o9i55Sx6v7qogRQeqaBfmLbC/pZqeMBWi015U5PTDY0=
github.com/AdguardTeam/urlfilter v0.18.0 h1:ZZzwODC/ADpjJSODxySrrUnt/fvOCfGFaCW6j+wsGfQ= github.com/AdguardTeam/urlfilter v0.18.0 h1:ZZzwODC/ADpjJSODxySrrUnt/fvOCfGFaCW6j+wsGfQ=
github.com/AdguardTeam/urlfilter v0.18.0/go.mod h1:IXxBwedLiZA2viyHkaFxY/8mjub0li2PXRg8a3d9Z1s= github.com/AdguardTeam/urlfilter v0.18.0/go.mod h1:IXxBwedLiZA2viyHkaFxY/8mjub0li2PXRg8a3d9Z1s=
github.com/NYTimes/gziphandler v1.1.1 h1:ZUDjpQae29j0ryrS0u/B8HZfJBtBQHjqw2rQ2cqUQ3I= github.com/NYTimes/gziphandler v1.1.1 h1:ZUDjpQae29j0ryrS0u/B8HZfJBtBQHjqw2rQ2cqUQ3I=
@@ -101,19 +101,19 @@ github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo= github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo=
github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A= github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A=
github.com/quic-go/quic-go v0.41.0 h1:aD8MmHfgqTURWNJy48IYFg2OnxwHT3JL7ahGs73lb4k= github.com/quic-go/quic-go v0.42.1-0.20240424141022-12aa63824c7f h1:L7x60Z6AW2giF/SvbDpMglGHJxtmFJV03khPwXLDScU=
github.com/quic-go/quic-go v0.41.0/go.mod h1:qCkNjqczPEvgsOnxZ0eCD14lv+B2LHlFAB++CNOh9hA= github.com/quic-go/quic-go v0.42.1-0.20240424141022-12aa63824c7f/go.mod h1:132kz4kL3F9vxhW3CtQJLDVwcFe5wdWeJXXijhsO57M=
github.com/shirou/gopsutil/v3 v3.23.7 h1:C+fHO8hfIppoJ1WdsVm1RoI0RwXoNdfTK7yWXV0wVj4= github.com/shirou/gopsutil/v3 v3.23.7 h1:C+fHO8hfIppoJ1WdsVm1RoI0RwXoNdfTK7yWXV0wVj4=
github.com/shirou/gopsutil/v3 v3.23.7/go.mod h1:c4gnmoRC0hQuaLqvxnx1//VXQ0Ms/X9UnJF8pddY5z4= github.com/shirou/gopsutil/v3 v3.23.7/go.mod h1:c4gnmoRC0hQuaLqvxnx1//VXQ0Ms/X9UnJF8pddY5z4=
github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFtM= github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFtM=
github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ= github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/ti-mo/netfilter v0.2.0/go.mod h1:8GbBGsY/8fxtyIdfwy29JiluNcPK4K7wIT+x42ipqUU= github.com/ti-mo/netfilter v0.2.0/go.mod h1:8GbBGsY/8fxtyIdfwy29JiluNcPK4K7wIT+x42ipqUU=
github.com/ti-mo/netfilter v0.5.1 h1:cqamEd1c1zmpfpqvInLOro0Znq/RAfw2QL5wL2rAR/8= github.com/ti-mo/netfilter v0.5.1 h1:cqamEd1c1zmpfpqvInLOro0Znq/RAfw2QL5wL2rAR/8=
github.com/ti-mo/netfilter v0.5.1/go.mod h1:h9UPQ3ZrTZGBitay+LETMxZvNgWGK/efTUcqES2YiLw= github.com/ti-mo/netfilter v0.5.1/go.mod h1:h9UPQ3ZrTZGBitay+LETMxZvNgWGK/efTUcqES2YiLw=
@@ -131,26 +131,26 @@ go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU=
go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30=
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M=
golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 h1:LfspQV/FYTatPTr/3HzIcmiUFH7PGP+OQ6mgDYo3yuQ= golang.org/x/exp v0.0.0-20240409090435-93d18d7e34b8 h1:ESSUROHIBHg7USnszlcdmjBEwdMj9VUvU+OPk4yl2mc=
golang.org/x/exp v0.0.0-20240222234643-814bf88cf225/go.mod h1:CxmFvTBINI24O/j8iY7H1xHzx2i4OsyguNBmN/uPtqc= golang.org/x/exp v0.0.0-20240409090435-93d18d7e34b8/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/mod v0.16.0 h1:QX4fJ0Rr5cPQCF7O9lh9Se4pmwfwskqZfq5moyldzic= golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA=
golang.org/x/mod v0.16.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210316092652-d523dce5a7f4/go.mod h1:RBQZq4jEuRlivfhVLdyRGr576XBO4/greRjx4P4O3yc= golang.org/x/net v0.0.0-20210316092652-d523dce5a7f4/go.mod h1:RBQZq4jEuRlivfhVLdyRGr576XBO4/greRjx4P4O3yc=
golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs= golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w=
golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190322080309-f49334f85ddc/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190322080309-f49334f85ddc/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -161,17 +161,19 @@ golang.org/x/sys v0.0.0-20210315160823-c6e025ad8005/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220209214540-3681064d5158/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220209214540-3681064d5158/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.4.1-0.20230131160137-e7d7f63158de/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.4.1-0.20230131160137-e7d7f63158de/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o=
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/tools v0.19.0 h1:tfGCXNR1OsFG+sVdLAitlpjAvD/I6dHDKnYrpEZUHkw= golang.org/x/tools v0.20.0 h1:hz/CVckiOxybQvFw6h7b/q80NTr9IUQb4s1IIzW7KNY=
golang.org/x/tools v0.19.0/go.mod h1:qoJWxmGSIBmAeriMx19ogtrEPrGtDbPK634QFIcLAhc= golang.org/x/tools v0.20.0/go.mod h1:WvitBU7JJf6A4jOdg4S1tviW9bhUxkgeCui/0JHctQg=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gonum.org/v1/gonum v0.14.0 h1:2NiG67LD1tEH0D7kM+ps2V+fXmsAnpUeec7n8tcr4S0= gonum.org/v1/gonum v0.14.0 h1:2NiG67LD1tEH0D7kM+ps2V+fXmsAnpUeec7n8tcr4S0=

View File

@@ -159,21 +159,11 @@ func NotifyReconfigureSignal(c chan<- os.Signal) {
notifyReconfigureSignal(c) notifyReconfigureSignal(c)
} }
// NotifyShutdownSignal notifies c on receiving shutdown signals.
func NotifyShutdownSignal(c chan<- os.Signal) {
notifyShutdownSignal(c)
}
// IsReconfigureSignal returns true if sig is a reconfigure signal. // IsReconfigureSignal returns true if sig is a reconfigure signal.
func IsReconfigureSignal(sig os.Signal) (ok bool) { func IsReconfigureSignal(sig os.Signal) (ok bool) {
return isReconfigureSignal(sig) return isReconfigureSignal(sig)
} }
// IsShutdownSignal returns true if sig is a shutdown signal.
func IsShutdownSignal(sig os.Signal) (ok bool) {
return isShutdownSignal(sig)
}
// SendShutdownSignal sends the shutdown signal to the channel. // SendShutdownSignal sends the shutdown signal to the channel.
func SendShutdownSignal(c chan<- os.Signal) { func SendShutdownSignal(c chan<- os.Signal) {
sendShutdownSignal(c) sendShutdownSignal(c)

View File

@@ -13,26 +13,10 @@ func notifyReconfigureSignal(c chan<- os.Signal) {
signal.Notify(c, unix.SIGHUP) signal.Notify(c, unix.SIGHUP)
} }
func notifyShutdownSignal(c chan<- os.Signal) {
signal.Notify(c, unix.SIGINT, unix.SIGQUIT, unix.SIGTERM)
}
func isReconfigureSignal(sig os.Signal) (ok bool) { func isReconfigureSignal(sig os.Signal) (ok bool) {
return sig == unix.SIGHUP return sig == unix.SIGHUP
} }
func isShutdownSignal(sig os.Signal) (ok bool) {
switch sig {
case
unix.SIGINT,
unix.SIGQUIT,
unix.SIGTERM:
return true
default:
return false
}
}
func sendShutdownSignal(_ chan<- os.Signal) { func sendShutdownSignal(_ chan<- os.Signal) {
// On Unix we are already notified by the system. // On Unix we are already notified by the system.
} }

View File

@@ -5,7 +5,6 @@ package aghos
import ( import (
"os" "os"
"os/signal" "os/signal"
"syscall"
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
) )
@@ -43,25 +42,10 @@ func notifyReconfigureSignal(c chan<- os.Signal) {
signal.Notify(c, windows.SIGHUP) signal.Notify(c, windows.SIGHUP)
} }
func notifyShutdownSignal(c chan<- os.Signal) {
// syscall.SIGTERM is processed automatically. See go doc os/signal,
// section Windows.
signal.Notify(c, os.Interrupt)
}
func isReconfigureSignal(sig os.Signal) (ok bool) { func isReconfigureSignal(sig os.Signal) (ok bool) {
return sig == windows.SIGHUP return sig == windows.SIGHUP
} }
func isShutdownSignal(sig os.Signal) (ok bool) {
switch sig {
case os.Interrupt, syscall.SIGTERM:
return true
default:
return false
}
}
func sendShutdownSignal(c chan<- os.Signal) { func sendShutdownSignal(c chan<- os.Signal) {
c <- os.Interrupt c <- os.Interrupt
} }

View File

@@ -197,8 +197,10 @@ func (ci *Index) findByIP(ip netip.Addr) (c *Persistent, found bool) {
return ci.uidToClient[uid], true return ci.uidToClient[uid], true
} }
ipWithoutZone := ip.WithZone("")
ci.subnetToUID.Range(func(pref netip.Prefix, id UID) (cont bool) { ci.subnetToUID.Range(func(pref netip.Prefix, id UID) (cont bool) {
if pref.Contains(ip) { // Remove zone before checking because prefixes strip zones.
if pref.Contains(ipWithoutZone) {
uid, found = id, true uid, found = id, true
return false return false
@@ -214,6 +216,26 @@ func (ci *Index) findByIP(ip netip.Addr) (c *Persistent, found bool) {
return nil, false return nil, false
} }
// FindByIPWithoutZone finds a persistent client by IP address without zone. It
// strips the IPv6 zone index from the stored IP addresses before comparing,
// because querylog entries don't have it. See TODO on [querylog.logEntry.IP].
//
// Note that multiple clients can have the same IP address with different zones.
// Therefore, the result of this method is indeterminate.
func (ci *Index) FindByIPWithoutZone(ip netip.Addr) (c *Persistent) {
if (ip == netip.Addr{}) {
return nil
}
for addr, uid := range ci.ipToUID {
if addr.WithZone("") == ip {
return ci.uidToClient[uid]
}
}
return nil
}
// find finds persistent client by MAC. // find finds persistent client by MAC.
func (ci *Index) findByMAC(mac net.HardwareAddr) (c *Persistent, found bool) { func (ci *Index) findByMAC(mac net.HardwareAddr) (c *Persistent, found bool) {
k := macToKey(mac) k := macToKey(mac)

View File

@@ -35,27 +35,49 @@ func TestClientIndex(t *testing.T) {
cliID = "client-id" cliID = "client-id"
cliMAC = "11:11:11:11:11:11" cliMAC = "11:11:11:11:11:11"
linkLocalIP = "fe80::abcd:abcd:abcd:ab%eth0"
linkLocalSubnet = "fe80::/16"
) )
clients := []*Persistent{{ var (
Name: "client1", clientWithBothFams = &Persistent{
IPs: []netip.Addr{ Name: "client1",
netip.MustParseAddr(cliIP1), IPs: []netip.Addr{
netip.MustParseAddr(cliIPv6), netip.MustParseAddr(cliIP1),
}, netip.MustParseAddr(cliIPv6),
}, { },
Name: "client2", }
IPs: []netip.Addr{netip.MustParseAddr(cliIP2)},
Subnets: []netip.Prefix{netip.MustParsePrefix(cliSubnet)},
}, {
Name: "client_with_mac",
MACs: []net.HardwareAddr{mustParseMAC(cliMAC)},
}, {
Name: "client_with_id",
ClientIDs: []string{cliID},
}}
ci := newIDIndex(clients) clientWithSubnet = &Persistent{
Name: "client2",
IPs: []netip.Addr{netip.MustParseAddr(cliIP2)},
Subnets: []netip.Prefix{netip.MustParsePrefix(cliSubnet)},
}
clientWithMAC = &Persistent{
Name: "client_with_mac",
MACs: []net.HardwareAddr{mustParseMAC(cliMAC)},
}
clientWithID = &Persistent{
Name: "client_with_id",
ClientIDs: []string{cliID},
}
clientLinkLocal = &Persistent{
Name: "client_link_local",
Subnets: []netip.Prefix{netip.MustParsePrefix(linkLocalSubnet)},
}
)
ci := newIDIndex([]*Persistent{
clientWithBothFams,
clientWithSubnet,
clientWithMAC,
clientWithID,
clientLinkLocal,
})
testCases := []struct { testCases := []struct {
want *Persistent want *Persistent
@@ -64,19 +86,23 @@ func TestClientIndex(t *testing.T) {
}{{ }{{
name: "ipv4_ipv6", name: "ipv4_ipv6",
ids: []string{cliIP1, cliIPv6}, ids: []string{cliIP1, cliIPv6},
want: clients[0], want: clientWithBothFams,
}, { }, {
name: "ipv4_subnet", name: "ipv4_subnet",
ids: []string{cliIP2, cliSubnetIP}, ids: []string{cliIP2, cliSubnetIP},
want: clients[1], want: clientWithSubnet,
}, { }, {
name: "mac", name: "mac",
ids: []string{cliMAC}, ids: []string{cliMAC},
want: clients[2], want: clientWithMAC,
}, { }, {
name: "client_id", name: "client_id",
ids: []string{cliID}, ids: []string{cliID},
want: clients[3], want: clientWithID,
}, {
name: "client_link_local_subnet",
ids: []string{linkLocalIP},
want: clientLinkLocal,
}} }}
for _, tc := range testCases { for _, tc := range testCases {
@@ -221,3 +247,52 @@ func TestMACToKey(t *testing.T) {
_ = macToKey(mac) _ = macToKey(mac)
}) })
} }
func TestIndex_FindByIPWithoutZone(t *testing.T) {
var (
ip = netip.MustParseAddr("fe80::a098:7654:32ef:ff1")
ipWithZone = netip.MustParseAddr("fe80::1ff:fe23:4567:890a%eth2")
)
var (
clientNoZone = &Persistent{
Name: "client",
IPs: []netip.Addr{ip},
}
clientWithZone = &Persistent{
Name: "client_with_zone",
IPs: []netip.Addr{ipWithZone},
}
)
ci := newIDIndex([]*Persistent{
clientNoZone,
clientWithZone,
})
testCases := []struct {
ip netip.Addr
want *Persistent
name string
}{{
name: "without_zone",
ip: ip,
want: clientNoZone,
}, {
name: "with_zone",
ip: ipWithZone,
want: clientWithZone,
}, {
name: "zero_address",
ip: netip.Addr{},
want: nil,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
c := ci.FindByIPWithoutZone(tc.ip.WithZone(""))
require.Equal(t, tc.want, c)
})
}
}

View File

@@ -64,9 +64,7 @@ type Persistent struct {
// upstream must be used. // upstream must be used.
UpstreamConfig *proxy.CustomUpstreamConfig UpstreamConfig *proxy.CustomUpstreamConfig
// TODO(d.kolyshev): Make SafeSearchConf a pointer. SafeSearch filtering.SafeSearch
SafeSearchConf filtering.SafeSearchConfig
SafeSearch filtering.SafeSearch
// BlockedServices is the configuration of blocked services of a client. // BlockedServices is the configuration of blocked services of a client.
BlockedServices *filtering.BlockedServices BlockedServices *filtering.BlockedServices
@@ -95,6 +93,9 @@ type Persistent struct {
UseOwnBlockedServices bool UseOwnBlockedServices bool
IgnoreQueryLog bool IgnoreQueryLog bool
IgnoreStatistics bool IgnoreStatistics bool
// TODO(d.kolyshev): Make SafeSearchConf a pointer.
SafeSearchConf filtering.SafeSearchConfig
} }
// SetTags sets the tags if they are known, otherwise logs an unknown tag. // SetTags sets the tags if they are known, otherwise logs an unknown tag.

View File

@@ -156,7 +156,10 @@ func (a *accessManager) isBlockedIP(ip netip.Addr) (blocked bool, rule string) {
} }
for _, ipnet := range ipnets { for _, ipnet := range ipnets {
if ipnet.Contains(ip) { // Remove zone before checking because prefixes stip zones.
//
// TODO(d.kolyshev): Cover with tests.
if ipnet.Contains(ip.WithZone("")) {
return blocked, ipnet.String() return blocked, ipnet.String()
} }
} }

View File

@@ -0,0 +1,116 @@
package dnsforward
import (
"encoding/binary"
"fmt"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
)
// type check
var _ proxy.BeforeRequestHandler = (*Server)(nil)
// HandleBefore is the handler that is called before any other processing,
// including logs. It performs access checks and puts the client ID, if there
// is one, into the server's cache.
//
// TODO(d.kolyshev): Extract to separate package.
func (s *Server) HandleBefore(
_ *proxy.Proxy,
pctx *proxy.DNSContext,
) (err error) {
clientID, err := s.clientIDFromDNSContext(pctx)
if err != nil {
return &proxy.BeforeRequestError{
Err: fmt.Errorf("getting clientid: %w", err),
Response: s.NewMsgSERVFAIL(pctx.Req),
}
}
blocked, _ := s.IsBlockedClient(pctx.Addr.Addr(), clientID)
if blocked {
return s.preBlockedResponse(pctx)
}
if len(pctx.Req.Question) == 1 {
q := pctx.Req.Question[0]
qt := q.Qtype
host := aghnet.NormalizeDomain(q.Name)
if s.access.isBlockedHost(host, qt) {
log.Debug("access: request %s %s is in access blocklist", dns.Type(qt), host)
return s.preBlockedResponse(pctx)
}
}
if clientID != "" {
key := [8]byte{}
binary.BigEndian.PutUint64(key[:], pctx.RequestID)
s.clientIDCache.Set(key[:], []byte(clientID))
}
return nil
}
// clientIDFromDNSContext extracts the client's ID from the server name of the
// client's DoT or DoQ request or the path of the client's DoH. If the protocol
// is not one of these, clientID is an empty string and err is nil.
func (s *Server) clientIDFromDNSContext(pctx *proxy.DNSContext) (clientID string, err error) {
proto := pctx.Proto
if proto == proxy.ProtoHTTPS {
clientID, err = clientIDFromDNSContextHTTPS(pctx)
if err != nil {
return "", fmt.Errorf("checking url: %w", err)
} else if clientID != "" {
return clientID, nil
}
// Go on and check the domain name as well.
} else if proto != proxy.ProtoTLS && proto != proxy.ProtoQUIC {
return "", nil
}
hostSrvName := s.conf.ServerName
if hostSrvName == "" {
return "", nil
}
cliSrvName, err := clientServerName(pctx, proto)
if err != nil {
return "", err
}
clientID, err = clientIDFromClientServerName(
hostSrvName,
cliSrvName,
s.conf.StrictSNICheck,
)
if err != nil {
return "", fmt.Errorf("clientid check: %w", err)
}
return clientID, nil
}
// errAccessBlocked is a sentinel error returned when a request is blocked by
// access settings.
var errAccessBlocked errors.Error = "blocked by access settings"
// preBlockedResponse returns a protocol-appropriate response for a request that
// was blocked by access settings.
func (s *Server) preBlockedResponse(pctx *proxy.DNSContext) (err error) {
if pctx.Proto == proxy.ProtoUDP || pctx.Proto == proxy.ProtoDNSCrypt {
// Return nil so that dnsproxy drops the connection and thus
// prevent DNS amplification attacks.
return errAccessBlocked
}
return &proxy.BeforeRequestError{
Err: errAccessBlocked,
Response: s.makeResponseREFUSED(pctx.Req),
}
}

View File

@@ -0,0 +1,299 @@
package dnsforward
import (
"crypto/tls"
"net"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const (
blockedHost = "blockedhost.org"
testFQDN = "example.org."
dnsClientTimeout = 200 * time.Millisecond
)
func TestServer_HandleBefore_tls(t *testing.T) {
t.Parallel()
const clientID = "client-1"
testCases := []struct {
clientSrvName string
name string
host string
allowedClients []string
disallowedClients []string
blockedHosts []string
wantRCode int
}{{
clientSrvName: tlsServerName,
name: "allow_all",
host: testFQDN,
allowedClients: []string{},
disallowedClients: []string{},
blockedHosts: []string{},
wantRCode: dns.RcodeSuccess,
}, {
clientSrvName: "%" + "." + tlsServerName,
name: "invalid_client_id",
host: testFQDN,
allowedClients: []string{},
disallowedClients: []string{},
blockedHosts: []string{},
wantRCode: dns.RcodeServerFailure,
}, {
clientSrvName: clientID + "." + tlsServerName,
name: "allowed_client_allowed",
host: testFQDN,
allowedClients: []string{clientID},
disallowedClients: []string{},
blockedHosts: []string{},
wantRCode: dns.RcodeSuccess,
}, {
clientSrvName: "client-2." + tlsServerName,
name: "allowed_client_rejected",
host: testFQDN,
allowedClients: []string{clientID},
disallowedClients: []string{},
blockedHosts: []string{},
wantRCode: dns.RcodeRefused,
}, {
clientSrvName: tlsServerName,
name: "disallowed_client_allowed",
host: testFQDN,
allowedClients: []string{},
disallowedClients: []string{clientID},
blockedHosts: []string{},
wantRCode: dns.RcodeSuccess,
}, {
clientSrvName: clientID + "." + tlsServerName,
name: "disallowed_client_rejected",
host: testFQDN,
allowedClients: []string{},
disallowedClients: []string{clientID},
blockedHosts: []string{},
wantRCode: dns.RcodeRefused,
}, {
clientSrvName: tlsServerName,
name: "blocked_hosts_allowed",
host: testFQDN,
allowedClients: []string{},
disallowedClients: []string{},
blockedHosts: []string{blockedHost},
wantRCode: dns.RcodeSuccess,
}, {
clientSrvName: tlsServerName,
name: "blocked_hosts_rejected",
host: dns.Fqdn(blockedHost),
allowedClients: []string{},
disallowedClients: []string{},
blockedHosts: []string{blockedHost},
wantRCode: dns.RcodeRefused,
}}
localAns := []dns.RR{&dns.A{
Hdr: dns.RR_Header{
Name: testFQDN,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 3600,
Rdlength: 4,
},
A: net.IP{1, 2, 3, 4},
}}
localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
resp := (&dns.Msg{}).SetReply(req)
resp.Answer = localAns
require.NoError(t, w.WriteMsg(resp))
})
localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String()
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
s, _ := createTestTLS(t, TLSConfig{
TLSListenAddrs: []*net.TCPAddr{{}},
ServerName: tlsServerName,
})
s.conf.UpstreamDNS = []string{localUpsAddr}
s.conf.AllowedClients = tc.allowedClients
s.conf.DisallowedClients = tc.disallowedClients
s.conf.BlockedHosts = tc.blockedHosts
err := s.Prepare(&s.conf)
require.NoError(t, err)
startDeferStop(t, s)
tlsConfig := &tls.Config{
InsecureSkipVerify: true,
ServerName: tc.clientSrvName,
}
client := &dns.Client{
Net: "tcp-tls",
TLSConfig: tlsConfig,
Timeout: dnsClientTimeout,
}
req := createTestMessage(tc.host)
addr := s.dnsProxy.Addr(proxy.ProtoTLS).String()
reply, _, err := client.Exchange(req, addr)
require.NoError(t, err)
assert.Equal(t, tc.wantRCode, reply.Rcode)
if tc.wantRCode == dns.RcodeSuccess {
assert.Equal(t, localAns, reply.Answer)
} else {
assert.Empty(t, reply.Answer)
}
})
}
}
func TestServer_HandleBefore_udp(t *testing.T) {
t.Parallel()
const (
clientIPv4 = "127.0.0.1"
clientIPv6 = "::1"
)
clientIPs := []string{clientIPv4, clientIPv6}
testCases := []struct {
name string
host string
allowedClients []string
disallowedClients []string
blockedHosts []string
wantTimeout bool
}{{
name: "allow_all",
host: testFQDN,
allowedClients: []string{},
disallowedClients: []string{},
blockedHosts: []string{},
wantTimeout: false,
}, {
name: "allowed_client_allowed",
host: testFQDN,
allowedClients: clientIPs,
disallowedClients: []string{},
blockedHosts: []string{},
wantTimeout: false,
}, {
name: "allowed_client_rejected",
host: testFQDN,
allowedClients: []string{"1:2:3::4"},
disallowedClients: []string{},
blockedHosts: []string{},
wantTimeout: true,
}, {
name: "disallowed_client_allowed",
host: testFQDN,
allowedClients: []string{},
disallowedClients: []string{"1:2:3::4"},
blockedHosts: []string{},
wantTimeout: false,
}, {
name: "disallowed_client_rejected",
host: testFQDN,
allowedClients: []string{},
disallowedClients: clientIPs,
blockedHosts: []string{},
wantTimeout: true,
}, {
name: "blocked_hosts_allowed",
host: testFQDN,
allowedClients: []string{},
disallowedClients: []string{},
blockedHosts: []string{blockedHost},
wantTimeout: false,
}, {
name: "blocked_hosts_rejected",
host: dns.Fqdn(blockedHost),
allowedClients: []string{},
disallowedClients: []string{},
blockedHosts: []string{blockedHost},
wantTimeout: true,
}}
localAns := []dns.RR{&dns.A{
Hdr: dns.RR_Header{
Name: testFQDN,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 3600,
Rdlength: 4,
},
A: net.IP{1, 2, 3, 4},
}}
localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
resp := (&dns.Msg{}).SetReply(req)
resp.Answer = localAns
require.NoError(t, w.WriteMsg(resp))
})
localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String()
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
s := createTestServer(t, &filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
}, ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
Config: Config{
AllowedClients: tc.allowedClients,
DisallowedClients: tc.disallowedClients,
BlockedHosts: tc.blockedHosts,
UpstreamDNS: []string{localUpsAddr},
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
ServePlainDNS: true,
})
startDeferStop(t, s)
client := &dns.Client{
Net: "udp",
Timeout: dnsClientTimeout,
}
req := createTestMessage(tc.host)
addr := s.dnsProxy.Addr(proxy.ProtoUDP).String()
reply, _, err := client.Exchange(req, addr)
if tc.wantTimeout {
wantErr := &net.OpError{}
require.ErrorAs(t, err, &wantErr)
assert.True(t, wantErr.Timeout())
assert.Nil(t, reply)
} else {
require.NoError(t, err)
require.NotNil(t, reply)
assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
assert.Equal(t, localAns, reply.Answer)
}
})
}
}

View File

@@ -110,46 +110,6 @@ type quicConnection interface {
ConnectionState() (cs quic.ConnectionState) ConnectionState() (cs quic.ConnectionState)
} }
// clientIDFromDNSContext extracts the client's ID from the server name of the
// client's DoT or DoQ request or the path of the client's DoH. If the protocol
// is not one of these, clientID is an empty string and err is nil.
func (s *Server) clientIDFromDNSContext(pctx *proxy.DNSContext) (clientID string, err error) {
proto := pctx.Proto
if proto == proxy.ProtoHTTPS {
clientID, err = clientIDFromDNSContextHTTPS(pctx)
if err != nil {
return "", fmt.Errorf("checking url: %w", err)
} else if clientID != "" {
return clientID, nil
}
// Go on and check the domain name as well.
} else if proto != proxy.ProtoTLS && proto != proxy.ProtoQUIC {
return "", nil
}
hostSrvName := s.conf.ServerName
if hostSrvName == "" {
return "", nil
}
cliSrvName, err := clientServerName(pctx, proto)
if err != nil {
return "", err
}
clientID, err = clientIDFromClientServerName(
hostSrvName,
cliSrvName,
s.conf.StrictSNICheck,
)
if err != nil {
return "", fmt.Errorf("clientid check: %w", err)
}
return clientID, nil
}
// clientServerName returns the TLS server name based on the protocol. For // clientServerName returns the TLS server name based on the protocol. For
// DNS-over-HTTPS requests, it will return the hostname part of the Host header // DNS-over-HTTPS requests, it will return the hostname part of the Host header
// if there is one. // if there is one.

View File

@@ -235,9 +235,18 @@ type DNSCryptConfig struct {
// ServerConfig represents server configuration. // ServerConfig represents server configuration.
// The zero ServerConfig is empty and ready for use. // The zero ServerConfig is empty and ready for use.
type ServerConfig struct { type ServerConfig struct {
UDPListenAddrs []*net.UDPAddr // UDP listen address // UDPListenAddrs is the list of addresses to listen for DNS-over-UDP.
TCPListenAddrs []*net.TCPAddr // TCP listen address UDPListenAddrs []*net.UDPAddr
UpstreamConfig *proxy.UpstreamConfig // Upstream DNS servers config
// TCPListenAddrs is the list of addresses to listen for DNS-over-TCP.
TCPListenAddrs []*net.TCPAddr
// UpstreamConfig is the general configuration of upstream DNS servers.
UpstreamConfig *proxy.UpstreamConfig
// PrivateRDNSUpstreamConfig is the configuration of upstream DNS servers
// for private reverse DNS.
PrivateRDNSUpstreamConfig *proxy.UpstreamConfig
// AddrProcConf defines the configuration for the client IP processor. // AddrProcConf defines the configuration for the client IP processor.
// If nil, [client.EmptyAddrProc] is used. // If nil, [client.EmptyAddrProc] is used.
@@ -306,24 +315,28 @@ func (s *Server) newProxyConfig() (conf *proxy.Config, err error) {
trustedPrefixes := netutil.UnembedPrefixes(srvConf.TrustedProxies) trustedPrefixes := netutil.UnembedPrefixes(srvConf.TrustedProxies)
conf = &proxy.Config{ conf = &proxy.Config{
HTTP3: srvConf.ServeHTTP3, HTTP3: srvConf.ServeHTTP3,
Ratelimit: int(srvConf.Ratelimit), Ratelimit: int(srvConf.Ratelimit),
RatelimitSubnetLenIPv4: srvConf.RatelimitSubnetLenIPv4, RatelimitSubnetLenIPv4: srvConf.RatelimitSubnetLenIPv4,
RatelimitSubnetLenIPv6: srvConf.RatelimitSubnetLenIPv6, RatelimitSubnetLenIPv6: srvConf.RatelimitSubnetLenIPv6,
RatelimitWhitelist: srvConf.RatelimitWhitelist, RatelimitWhitelist: srvConf.RatelimitWhitelist,
RefuseAny: srvConf.RefuseAny, RefuseAny: srvConf.RefuseAny,
TrustedProxies: netutil.SliceSubnetSet(trustedPrefixes), TrustedProxies: netutil.SliceSubnetSet(trustedPrefixes),
CacheMinTTL: srvConf.CacheMinTTL, CacheMinTTL: srvConf.CacheMinTTL,
CacheMaxTTL: srvConf.CacheMaxTTL, CacheMaxTTL: srvConf.CacheMaxTTL,
CacheOptimistic: srvConf.CacheOptimistic, CacheOptimistic: srvConf.CacheOptimistic,
UpstreamConfig: srvConf.UpstreamConfig, UpstreamConfig: srvConf.UpstreamConfig,
BeforeRequestHandler: s.beforeRequestHandler, PrivateRDNSUpstreamConfig: srvConf.PrivateRDNSUpstreamConfig,
RequestHandler: s.handleDNSRequest, BeforeRequestHandler: s,
HTTPSServerName: aghhttp.UserAgent(), RequestHandler: s.handleDNSRequest,
EnableEDNSClientSubnet: srvConf.EDNSClientSubnet.Enabled, HTTPSServerName: aghhttp.UserAgent(),
MaxGoroutines: srvConf.MaxGoroutines, EnableEDNSClientSubnet: srvConf.EDNSClientSubnet.Enabled,
UseDNS64: srvConf.UseDNS64, MaxGoroutines: srvConf.MaxGoroutines,
DNS64Prefs: srvConf.DNS64Prefixes, UseDNS64: srvConf.UseDNS64,
DNS64Prefs: srvConf.DNS64Prefixes,
UsePrivateRDNS: srvConf.UsePrivateRDNS,
PrivateSubnets: s.privateNets,
MessageConstructor: s,
} }
if srvConf.EDNSClientSubnet.UseCustom { if srvConf.EDNSClientSubnet.UseCustom {
@@ -459,6 +472,26 @@ func (s *Server) prepareIpsetListSettings() (err error) {
return s.ipset.init(ipsets) return s.ipset.init(ipsets)
} }
// loadUpstreams parses upstream DNS servers from the configured file or from
// the configuration itself.
func (conf *ServerConfig) loadUpstreams() (upstreams []string, err error) {
if conf.UpstreamDNSFileName == "" {
return stringutil.FilterOut(conf.UpstreamDNS, IsCommentOrEmpty), nil
}
var data []byte
data, err = os.ReadFile(conf.UpstreamDNSFileName)
if err != nil {
return nil, fmt.Errorf("reading upstream from file: %w", err)
}
upstreams = stringutil.SplitTrimmed(string(data), "\n")
log.Debug("dnsforward: got %d upstreams in %q", len(upstreams), conf.UpstreamDNSFileName)
return stringutil.FilterOut(upstreams, IsCommentOrEmpty), nil
}
// collectListenAddr adds addrPort to addrs. It also adds its port to // collectListenAddr adds addrPort to addrs. It also adds its port to
// unspecPorts if its address is unspecified. // unspecPorts if its address is unspecified.
func collectListenAddr( func collectListenAddr(
@@ -530,8 +563,8 @@ func (m *combinedAddrPortSet) Has(addrPort netip.AddrPort) (ok bool) {
return m.ports.Has(addrPort.Port()) && m.addrs.Has(addrPort.Addr()) return m.ports.Has(addrPort.Port()) && m.addrs.Has(addrPort.Addr())
} }
// filterOut filters out all the upstreams that match um. It returns all the // filterOutAddrs filters out all the upstreams that match um. It returns all
// closing errors joined. // the closing errors joined.
func filterOutAddrs(upsConf *proxy.UpstreamConfig, set addrPortSet) (err error) { func filterOutAddrs(upsConf *proxy.UpstreamConfig, set addrPortSet) (err error) {
var errs []error var errs []error
delFunc := func(u upstream.Upstream) (ok bool) { delFunc := func(u upstream.Upstream) (ok bool) {

View File

@@ -3,7 +3,6 @@ package dnsforward
import ( import (
"net" "net"
"testing" "testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
@@ -11,6 +10,7 @@ import (
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -64,6 +64,8 @@ func newRR(t *testing.T, name string, qtype uint16, ttl uint32, val any) (rr dns
} }
func TestServer_HandleDNSRequest_dns64(t *testing.T) { func TestServer_HandleDNSRequest_dns64(t *testing.T) {
t.Parallel()
const ( const (
ipv4Domain = "ipv4.only." ipv4Domain = "ipv4.only."
ipv6Domain = "ipv6.only." ipv6Domain = "ipv6.only."
@@ -252,32 +254,33 @@ func TestServer_HandleDNSRequest_dns64(t *testing.T) {
localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) { localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
require.Len(pt, m.Question, 1) require.Len(pt, m.Question, 1)
require.Equal(pt, m.Question[0].Name, ptr64Domain) require.Equal(pt, m.Question[0].Name, ptr64Domain)
resp := (&dns.Msg{
Answer: []dns.RR{localRR}, resp := (&dns.Msg{}).SetReply(m)
}).SetReply(m) resp.Answer = []dns.RR{localRR}
require.NoError(t, w.WriteMsg(resp)) require.NoError(t, w.WriteMsg(resp))
}) })
localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String() localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String()
client := &dns.Client{ client := &dns.Client{
Net: "tcp", Net: string(proxy.ProtoTCP),
Timeout: 1 * time.Second, Timeout: testTimeout,
} }
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
t.Parallel()
upsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) { upsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
q := req.Question[0] q := req.Question[0]
require.Contains(pt, tc.upsAns, q.Qtype)
require.Contains(pt, tc.upsAns, q.Qtype)
answer := tc.upsAns[q.Qtype] answer := tc.upsAns[q.Qtype]
resp := (&dns.Msg{ resp := (&dns.Msg{}).SetReply(req)
Answer: answer[sectionAnswer], resp.Answer = answer[sectionAnswer]
Ns: answer[sectionAuthority], resp.Ns = answer[sectionAuthority]
Extra: answer[sectionAdditional], resp.Extra = answer[sectionAdditional]
}).SetReply(req)
require.NoError(pt, w.WriteMsg(resp)) require.NoError(pt, w.WriteMsg(resp))
}) })
@@ -307,10 +310,54 @@ func TestServer_HandleDNSRequest_dns64(t *testing.T) {
req := (&dns.Msg{}).SetQuestion(tc.qname, tc.qtype) req := (&dns.Msg{}).SetQuestion(tc.qname, tc.qtype)
resp, _, excErr := client.Exchange(req, s.dnsProxy.Addr(proxy.ProtoTCP).String()) resp, _, excErr := client.Exchange(req, s.proxy().Addr(proxy.ProtoTCP).String())
require.NoError(t, excErr) require.NoError(t, excErr)
require.Equal(t, tc.wantAns, resp.Answer) require.Equal(t, tc.wantAns, resp.Answer)
}) })
} }
} }
func TestServer_dns64WithDisabledRDNS(t *testing.T) {
t.Parallel()
// Shouldn't go to upstream at all.
panicHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
panic("not implemented")
})
upsAddr := aghtest.StartLocalhostUpstream(t, panicHdlr).String()
localUpsAddr := aghtest.StartLocalhostUpstream(t, panicHdlr).String()
s := createTestServer(t, &filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
}, ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
UseDNS64: true,
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
UpstreamDNS: []string{upsAddr},
},
UsePrivateRDNS: false,
LocalPTRResolvers: []string{localUpsAddr},
ServePlainDNS: true,
})
startDeferStop(t, s)
mappedIPv6 := net.ParseIP("64:ff9b::102:304")
arpa, err := netutil.IPToReversedAddr(mappedIPv6)
require.NoError(t, err)
req := (&dns.Msg{}).SetQuestion(dns.Fqdn(arpa), dns.TypePTR)
cli := &dns.Client{
Net: string(proxy.ProtoTCP),
Timeout: testTimeout,
}
resp, _, err := cli.Exchange(req, s.proxy().Addr(proxy.ProtoTCP).String())
require.NoError(t, err)
assert.Equal(t, dns.RcodeNameError, resp.Rcode)
}

View File

@@ -135,12 +135,6 @@ type Server struct {
// WHOIS, etc. // WHOIS, etc.
addrProc client.AddressProcessor addrProc client.AddressProcessor
// localResolvers is a DNS proxy instance used to resolve PTR records for
// addresses considered private as per the [privateNets].
//
// TODO(e.burkov): Remove once the local resolvers logic moved to dnsproxy.
localResolvers *proxy.Proxy
// sysResolvers used to fetch system resolvers to use by default for private // sysResolvers used to fetch system resolvers to use by default for private
// PTR resolving. // PTR resolving.
sysResolvers SystemResolvers sysResolvers SystemResolvers
@@ -158,12 +152,6 @@ type Server struct {
// [upstream.Resolver] interface. // [upstream.Resolver] interface.
bootResolvers []*upstream.UpstreamResolver bootResolvers []*upstream.UpstreamResolver
// recDetector is a cache for recursive requests. It is used to detect and
// prevent recursive requests only for private upstreams.
//
// See https://github.com/adguardTeam/adGuardHome/issues/3185#issuecomment-851048135.
recDetector *recursionDetector
// dns64Pref is the NAT64 prefix used for DNS64 response mapping. The major // dns64Pref is the NAT64 prefix used for DNS64 response mapping. The major
// part of DNS64 happens inside the [proxy] package, but there still are // part of DNS64 happens inside the [proxy] package, but there still are
// some places where response mapping is needed (e.g. DHCP). // some places where response mapping is needed (e.g. DHCP).
@@ -212,14 +200,6 @@ type DNSCreateParams struct {
LocalDomain string LocalDomain string
} }
const (
// recursionTTL is the time recursive request is cached for.
recursionTTL = 1 * time.Second
// cachedRecurrentReqNum is the maximum number of cached recurrent
// requests.
cachedRecurrentReqNum = 1000
)
// NewServer creates a new instance of the dnsforward.Server // NewServer creates a new instance of the dnsforward.Server
// Note: this function must be called only once // Note: this function must be called only once
// //
@@ -256,7 +236,6 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
// TODO(e.burkov): Use some case-insensitive string comparison. // TODO(e.burkov): Use some case-insensitive string comparison.
localDomainSuffix: strings.ToLower(localDomainSuffix), localDomainSuffix: strings.ToLower(localDomainSuffix),
etcHosts: etcHosts, etcHosts: etcHosts,
recDetector: newRecursionDetector(recursionTTL, cachedRecurrentReqNum),
clientIDCache: cache.New(cache.Config{ clientIDCache: cache.New(cache.Config{
EnableLRU: true, EnableLRU: true,
MaxCount: defaultClientIDCacheCount, MaxCount: defaultClientIDCacheCount,
@@ -366,6 +345,7 @@ func (s *Server) Exchange(ip netip.Addr) (host string, ttl time.Duration, err er
s.serverLock.RLock() s.serverLock.RLock()
defer s.serverLock.RUnlock() defer s.serverLock.RUnlock()
// TODO(e.burkov): Migrate to [netip.Addr] already.
arpa, err := netutil.IPToReversedAddr(ip.AsSlice()) arpa, err := netutil.IPToReversedAddr(ip.AsSlice())
if err != nil { if err != nil {
return "", 0, fmt.Errorf("reversing ip: %w", err) return "", 0, fmt.Errorf("reversing ip: %w", err)
@@ -386,25 +366,23 @@ func (s *Server) Exchange(ip netip.Addr) (host string, ttl time.Duration, err er
} }
dctx := &proxy.DNSContext{ dctx := &proxy.DNSContext{
Proto: "udp", Proto: proxy.ProtoUDP,
Req: req, Req: req,
IsPrivateClient: true,
} }
var resolver *proxy.Proxy
var errMsg string var errMsg string
if s.privateNets.Contains(ip) { if s.privateNets.Contains(ip) {
if !s.conf.UsePrivateRDNS { if !s.conf.UsePrivateRDNS {
return "", 0, nil return "", 0, nil
} }
resolver = s.localResolvers
errMsg = "resolving a private address: %w" errMsg = "resolving a private address: %w"
s.recDetector.add(*req) dctx.RequestedPrivateRDNS = netip.PrefixFrom(ip, ip.BitLen())
} else { } else {
resolver = s.internalProxy
errMsg = "resolving an address: %w" errMsg = "resolving an address: %w"
} }
if err = resolver.Resolve(dctx); err != nil { if err = s.internalProxy.Resolve(dctx); err != nil {
return "", 0, fmt.Errorf(errMsg, err) return "", 0, fmt.Errorf(errMsg, err)
} }
@@ -473,103 +451,6 @@ func (s *Server) startLocked() error {
return err return err
} }
// prepareLocalResolvers initializes the local upstreams configuration using
// boot as bootstrap. It assumes that s.serverLock is locked or s not running.
func (s *Server) prepareLocalResolvers(
boot upstream.Resolver,
) (uc *proxy.UpstreamConfig, err error) {
set, err := s.conf.ourAddrsSet()
if err != nil {
// Don't wrap the error because it's informative enough as is.
return nil, err
}
resolvers := s.conf.LocalPTRResolvers
confNeedsFiltering := len(resolvers) > 0
if confNeedsFiltering {
resolvers = stringutil.FilterOut(resolvers, IsCommentOrEmpty)
} else {
sysResolvers := slices.DeleteFunc(slices.Clone(s.sysResolvers.Addrs()), set.Has)
resolvers = make([]string, 0, len(sysResolvers))
for _, r := range sysResolvers {
resolvers = append(resolvers, r.String())
}
}
log.Debug("dnsforward: upstreams to resolve ptr for local addresses: %v", resolvers)
uc, err = s.prepareUpstreamConfig(resolvers, nil, &upstream.Options{
Bootstrap: boot,
Timeout: defaultLocalTimeout,
// TODO(e.burkov): Should we verify server's certificates?
PreferIPv6: s.conf.BootstrapPreferIPv6,
})
if err != nil {
return nil, fmt.Errorf("preparing private upstreams: %w", err)
}
if confNeedsFiltering {
err = filterOutAddrs(uc, set)
if err != nil {
return nil, fmt.Errorf("filtering private upstreams: %w", err)
}
}
return uc, nil
}
// LocalResolversError is an error type for errors during local resolvers setup.
// This is only needed to distinguish these errors from errors returned by
// creating the proxy.
type LocalResolversError struct {
Err error
}
// type check
var _ error = (*LocalResolversError)(nil)
// Error implements the error interface for *LocalResolversError.
func (err *LocalResolversError) Error() (s string) {
return fmt.Sprintf("creating local resolvers: %s", err.Err)
}
// type check
var _ errors.Wrapper = (*LocalResolversError)(nil)
// Unwrap implements the [errors.Wrapper] interface for *LocalResolversError.
func (err *LocalResolversError) Unwrap() error {
return err.Err
}
// setupLocalResolvers initializes and sets the resolvers for local addresses.
// It assumes s.serverLock is locked or s not running. It returns the upstream
// configuration used for private PTR resolving, or nil if it's disabled. Note,
// that it's safe to put nil into [proxy.Config.PrivateRDNSUpstreamConfig].
func (s *Server) setupLocalResolvers(boot upstream.Resolver) (uc *proxy.UpstreamConfig, err error) {
if !s.conf.UsePrivateRDNS {
// It's safe to put nil into [proxy.Config.PrivateRDNSUpstreamConfig].
return nil, nil
}
uc, err = s.prepareLocalResolvers(boot)
if err != nil {
// Don't wrap the error because it's informative enough as is.
return nil, err
}
localResolvers, err := proxy.New(&proxy.Config{
UpstreamConfig: uc,
})
if err != nil {
return nil, &LocalResolversError{Err: err}
}
s.localResolvers = localResolvers
// TODO(e.burkov): Should we also consider the DNS64 usage?
return uc, nil
}
// Prepare initializes parameters of s using data from conf. conf must not be // Prepare initializes parameters of s using data from conf. conf must not be
// nil. // nil.
func (s *Server) Prepare(conf *ServerConfig) (err error) { func (s *Server) Prepare(conf *ServerConfig) (err error) {
@@ -586,7 +467,7 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
s.initDefaultSettings() s.initDefaultSettings()
boot, err := s.prepareInternalDNS() err = s.prepareInternalDNS()
if err != nil { if err != nil {
// Don't wrap the error, because it's informative enough as is. // Don't wrap the error, because it's informative enough as is.
return err return err
@@ -608,12 +489,6 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
return fmt.Errorf("preparing access: %w", err) return fmt.Errorf("preparing access: %w", err)
} }
// TODO(e.burkov): Remove once the local resolvers logic moved to dnsproxy.
proxyConfig.PrivateRDNSUpstreamConfig, err = s.setupLocalResolvers(boot)
if err != nil {
return fmt.Errorf("setting up resolvers: %w", err)
}
proxyConfig.Fallbacks, err = s.setupFallbackDNS() proxyConfig.Fallbacks, err = s.setupFallbackDNS()
if err != nil { if err != nil {
return fmt.Errorf("setting up fallback dns servers: %w", err) return fmt.Errorf("setting up fallback dns servers: %w", err)
@@ -626,8 +501,6 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
s.dnsProxy = dnsProxy s.dnsProxy = dnsProxy
s.recDetector.clear()
s.setupAddrProc() s.setupAddrProc()
s.registerHandlers() s.registerHandlers()
@@ -635,36 +508,127 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
return nil return nil
} }
// prepareInternalDNS initializes the internal state of s before initializing // prepareUpstreamSettings sets upstream DNS server settings.
// the primary DNS proxy instance. It assumes s.serverLock is locked or the func (s *Server) prepareUpstreamSettings(boot upstream.Resolver) (err error) {
// Server not running. // Load upstreams either from the file, or from the settings
func (s *Server) prepareInternalDNS() (boot upstream.Resolver, err error) { var upstreams []string
err = s.prepareIpsetListSettings() upstreams, err = s.conf.loadUpstreams()
if err != nil { if err != nil {
return nil, fmt.Errorf("preparing ipset settings: %w", err) return fmt.Errorf("loading upstreams: %w", err)
} }
s.bootstrap, s.bootResolvers, err = s.createBootstrap(s.conf.BootstrapDNS, &upstream.Options{ uc, err := newUpstreamConfig(upstreams, defaultDNS, &upstream.Options{
Timeout: DefaultTimeout, Bootstrap: boot,
Timeout: s.conf.UpstreamTimeout,
HTTPVersions: UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams), HTTPVersions: UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams),
PreferIPv6: s.conf.BootstrapPreferIPv6,
// Use a customized set of RootCAs, because Go's default mechanism of
// loading TLS roots does not always work properly on some routers so we're
// loading roots manually and pass it here.
//
// See [aghtls.SystemRootCAs].
//
// TODO(a.garipov): Investigate if that's true.
RootCAs: s.conf.TLSv12Roots,
CipherSuites: s.conf.TLSCiphers,
}) })
if err != nil {
return fmt.Errorf("preparing upstream config: %w", err)
}
s.conf.UpstreamConfig = uc
return nil
}
// PrivateRDNSError is returned when the private rDNS upstreams are
// invalid but enabled.
//
// TODO(e.burkov): Consider allowing to use incomplete private rDNS upstreams
// configuration in proxy when the private rDNS function is enabled. In theory,
// proxy supports the case when no upstreams provided to resolve the private
// request, since it already supports this for DNS64-prefixed PTR requests.
type PrivateRDNSError struct {
err error
}
// Error implements the [errors.Error] interface.
func (e *PrivateRDNSError) Error() (s string) {
return e.err.Error()
}
func (e *PrivateRDNSError) Unwrap() (err error) {
return e.err
}
// prepareLocalResolvers initializes the private RDNS upstream configuration
// according to the server's settings. It assumes s.serverLock is locked or the
// Server not running.
func (s *Server) prepareLocalResolvers() (uc *proxy.UpstreamConfig, err error) {
if !s.conf.UsePrivateRDNS {
return nil, nil
}
var ownAddrs addrPortSet
ownAddrs, err = s.conf.ourAddrsSet()
if err != nil { if err != nil {
// Don't wrap the error, because it's informative enough as is. // Don't wrap the error, because it's informative enough as is.
return nil, err return nil, err
} }
opts := &upstream.Options{
Bootstrap: s.bootstrap,
Timeout: defaultLocalTimeout,
// TODO(e.burkov): Should we verify server's certificates?
PreferIPv6: s.conf.BootstrapPreferIPv6,
}
addrs := s.conf.LocalPTRResolvers
uc, err = newPrivateConfig(addrs, ownAddrs, s.sysResolvers, s.privateNets, opts)
if err != nil {
return nil, fmt.Errorf("preparing resolvers: %w", err)
}
return uc, nil
}
// prepareInternalDNS initializes the internal state of s before initializing
// the primary DNS proxy instance. It assumes s.serverLock is locked or the
// Server not running.
func (s *Server) prepareInternalDNS() (err error) {
err = s.prepareIpsetListSettings()
if err != nil {
return fmt.Errorf("preparing ipset settings: %w", err)
}
bootOpts := &upstream.Options{
Timeout: DefaultTimeout,
HTTPVersions: UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams),
}
s.bootstrap, s.bootResolvers, err = newBootstrap(s.conf.BootstrapDNS, s.etcHosts, bootOpts)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
}
err = s.prepareUpstreamSettings(s.bootstrap) err = s.prepareUpstreamSettings(s.bootstrap)
if err != nil { if err != nil {
// Don't wrap the error, because it's informative enough as is. // Don't wrap the error, because it's informative enough as is.
return s.bootstrap, err return err
}
s.conf.PrivateRDNSUpstreamConfig, err = s.prepareLocalResolvers()
if err != nil {
return err
} }
err = s.prepareInternalProxy() err = s.prepareInternalProxy()
if err != nil { if err != nil {
return s.bootstrap, fmt.Errorf("preparing internal proxy: %w", err) return fmt.Errorf("preparing internal proxy: %w", err)
} }
return s.bootstrap, nil return nil
} }
// setupFallbackDNS initializes the fallback DNS servers. // setupFallbackDNS initializes the fallback DNS servers.
@@ -743,10 +707,16 @@ func validateBlockingMode(
func (s *Server) prepareInternalProxy() (err error) { func (s *Server) prepareInternalProxy() (err error) {
srvConf := s.conf srvConf := s.conf
conf := &proxy.Config{ conf := &proxy.Config{
CacheEnabled: true, CacheEnabled: true,
CacheSizeBytes: 4096, CacheSizeBytes: 4096,
UpstreamConfig: srvConf.UpstreamConfig, PrivateRDNSUpstreamConfig: srvConf.PrivateRDNSUpstreamConfig,
MaxGoroutines: s.conf.MaxGoroutines, UpstreamConfig: srvConf.UpstreamConfig,
MaxGoroutines: srvConf.MaxGoroutines,
UseDNS64: srvConf.UseDNS64,
DNS64Prefs: srvConf.DNS64Prefixes,
UsePrivateRDNS: srvConf.UsePrivateRDNS,
PrivateSubnets: s.privateNets,
MessageConstructor: s,
} }
err = setProxyUpstreamMode(conf, srvConf.UpstreamMode, srvConf.FastestTimeout.Duration) err = setProxyUpstreamMode(conf, srvConf.UpstreamMode, srvConf.FastestTimeout.Duration)
@@ -782,11 +752,6 @@ func (s *Server) stopLocked() (err error) {
} }
} }
logCloserErr(s.internalProxy.UpstreamConfig, "dnsforward: closing internal resolvers: %s")
if s.localResolvers != nil {
logCloserErr(s.localResolvers.UpstreamConfig, "dnsforward: closing local resolvers: %s")
}
for _, b := range s.bootResolvers { for _, b := range s.bootResolvers {
logCloserErr(b, "dnsforward: closing bootstrap %s: %s", b.Address()) logCloserErr(b, "dnsforward: closing bootstrap %s: %s", b.Address())
} }

View File

@@ -2,7 +2,6 @@ package dnsforward
import ( import (
"cmp" "cmp"
"context"
"crypto/ecdsa" "crypto/ecdsa"
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
@@ -491,19 +490,10 @@ func TestServerRace(t *testing.T) {
} }
func TestSafeSearch(t *testing.T) { func TestSafeSearch(t *testing.T) {
resolver := &aghtest.Resolver{
OnLookupIP: func(_ context.Context, _, host string) (ips []net.IP, err error) {
ip4, ip6 := aghtest.HostToIPs(host)
return []net.IP{ip4.AsSlice(), ip6.AsSlice()}, nil
},
}
safeSearchConf := filtering.SafeSearchConfig{ safeSearchConf := filtering.SafeSearchConfig{
Enabled: true, Enabled: true,
Google: true, Google: true,
Yandex: true, Yandex: true,
CustomResolver: resolver,
} }
filterConf := &filtering.Config{ filterConf := &filtering.Config{
@@ -540,7 +530,6 @@ func TestSafeSearch(t *testing.T) {
client := &dns.Client{} client := &dns.Client{}
yandexIP := netip.AddrFrom4([4]byte{213, 180, 193, 56}) yandexIP := netip.AddrFrom4([4]byte{213, 180, 193, 56})
googleIP, _ := aghtest.HostToIPs("forcesafesearch.google.com")
testCases := []struct { testCases := []struct {
host string host string
@@ -564,19 +553,19 @@ func TestSafeSearch(t *testing.T) {
wantCNAME: "", wantCNAME: "",
}, { }, {
host: "www.google.com.", host: "www.google.com.",
want: googleIP, want: netip.Addr{},
wantCNAME: "forcesafesearch.google.com.", wantCNAME: "forcesafesearch.google.com.",
}, { }, {
host: "www.google.com.af.", host: "www.google.com.af.",
want: googleIP, want: netip.Addr{},
wantCNAME: "forcesafesearch.google.com.", wantCNAME: "forcesafesearch.google.com.",
}, { }, {
host: "www.google.be.", host: "www.google.be.",
want: googleIP, want: netip.Addr{},
wantCNAME: "forcesafesearch.google.com.", wantCNAME: "forcesafesearch.google.com.",
}, { }, {
host: "www.google.by.", host: "www.google.by.",
want: googleIP, want: netip.Addr{},
wantCNAME: "forcesafesearch.google.com.", wantCNAME: "forcesafesearch.google.com.",
}} }}
@@ -593,12 +582,15 @@ func TestSafeSearch(t *testing.T) {
cname := testutil.RequireTypeAssert[*dns.CNAME](t, reply.Answer[0]) cname := testutil.RequireTypeAssert[*dns.CNAME](t, reply.Answer[0])
assert.Equal(t, tc.wantCNAME, cname.Target) assert.Equal(t, tc.wantCNAME, cname.Target)
a := testutil.RequireTypeAssert[*dns.A](t, reply.Answer[1])
assert.NotEmpty(t, a.A)
} else { } else {
require.Len(t, reply.Answer, 1) require.Len(t, reply.Answer, 1)
}
a := testutil.RequireTypeAssert[*dns.A](t, reply.Answer[len(reply.Answer)-1]) a := testutil.RequireTypeAssert[*dns.A](t, reply.Answer[0])
assert.Equal(t, net.IP(tc.want.AsSlice()), a.A) assert.Equal(t, net.IP(tc.want.AsSlice()), a.A)
}
}) })
} }
} }

View File

@@ -143,7 +143,7 @@ func (s *Server) filterDNSRewrite(
res *filtering.Result, res *filtering.Result,
pctx *proxy.DNSContext, pctx *proxy.DNSContext,
) (err error) { ) (err error) {
resp := s.makeResponse(req) resp := s.replyCompressed(req)
dnsrr := res.DNSRewriteResult dnsrr := res.DNSRewriteResult
if dnsrr == nil { if dnsrr == nil {
return errors.Error("no dns rewrite rule content") return errors.Error("no dns rewrite rule content")

View File

@@ -1,57 +1,17 @@
package dnsforward package dnsforward
import ( import (
"encoding/binary"
"fmt" "fmt"
"net" "net"
"slices" "slices"
"strings" "strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/urlfilter/rules" "github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
// beforeRequestHandler is the handler that is called before any other
// processing, including logs. It performs access checks and puts the client
// ID, if there is one, into the server's cache.
func (s *Server) beforeRequestHandler(
_ *proxy.Proxy,
pctx *proxy.DNSContext,
) (reply bool, err error) {
clientID, err := s.clientIDFromDNSContext(pctx)
if err != nil {
return false, fmt.Errorf("getting clientid: %w", err)
}
blocked, _ := s.IsBlockedClient(pctx.Addr.Addr(), clientID)
if blocked {
return s.preBlockedResponse(pctx)
}
if len(pctx.Req.Question) == 1 {
q := pctx.Req.Question[0]
qt := q.Qtype
host := aghnet.NormalizeDomain(q.Name)
if s.access.isBlockedHost(host, qt) {
log.Debug("access: request %s %s is in access blocklist", dns.Type(qt), host)
return s.preBlockedResponse(pctx)
}
}
if clientID != "" {
key := [8]byte{}
binary.BigEndian.PutUint64(key[:], pctx.RequestID)
s.clientIDCache.Set(key[:], []byte(clientID))
}
return true, nil
}
// clientRequestFilteringSettings looks up client filtering settings using the // clientRequestFilteringSettings looks up client filtering settings using the
// client's IP address and ID, if any, from dctx. // client's IP address and ID, if any, from dctx.
func (s *Server) clientRequestFilteringSettings(dctx *dnsContext) (setts *filtering.Settings) { func (s *Server) clientRequestFilteringSettings(dctx *dnsContext) (setts *filtering.Settings) {
@@ -71,6 +31,7 @@ func (s *Server) filterDNSRequest(dctx *dnsContext) (res *filtering.Result, err
req := pctx.Req req := pctx.Req
q := req.Question[0] q := req.Question[0]
host := strings.TrimSuffix(q.Name, ".") host := strings.TrimSuffix(q.Name, ".")
resVal, err := s.dnsFilter.CheckHost(host, q.Qtype, dctx.setts) resVal, err := s.dnsFilter.CheckHost(host, q.Qtype, dctx.setts)
if err != nil { if err != nil {
return nil, fmt.Errorf("checking host %q: %w", host, err) return nil, fmt.Errorf("checking host %q: %w", host, err)
@@ -79,22 +40,15 @@ func (s *Server) filterDNSRequest(dctx *dnsContext) (res *filtering.Result, err
// TODO(a.garipov): Make CheckHost return a pointer. // TODO(a.garipov): Make CheckHost return a pointer.
res = &resVal res = &resVal
switch { switch {
case res.IsFiltered: case isRewrittenCNAME(res):
log.Debug(
"dnsforward: host %q is filtered, reason: %q; rule: %q",
host,
res.Reason,
res.Rules[0].Text,
)
pctx.Res = s.genDNSFilterMessage(pctx, res)
case res.Reason.In(filtering.Rewritten, filtering.RewrittenRule) &&
res.CanonName != "" &&
len(res.IPList) == 0:
// Resolve the new canonical name, not the original host name. The // Resolve the new canonical name, not the original host name. The
// original question is readded in processFilteringAfterResponse. // original question is readded in processFilteringAfterResponse.
dctx.origQuestion = q dctx.origQuestion = q
req.Question[0].Name = dns.Fqdn(res.CanonName) req.Question[0].Name = dns.Fqdn(res.CanonName)
case res.Reason == filtering.Rewritten: case res.IsFiltered:
log.Debug("dnsforward: host %q is filtered, reason: %q", host, res.Reason)
pctx.Res = s.genDNSFilterMessage(pctx, res)
case res.Reason.In(filtering.Rewritten, filtering.FilteredSafeSearch):
pctx.Res = s.getCNAMEWithIPs(req, res.IPList, res.CanonName) pctx.Res = s.getCNAMEWithIPs(req, res.IPList, res.CanonName)
case res.Reason.In(filtering.RewrittenRule, filtering.RewrittenAutoHosts): case res.Reason.In(filtering.RewrittenRule, filtering.RewrittenAutoHosts):
if err = s.filterDNSRewrite(req, res, pctx); err != nil { if err = s.filterDNSRewrite(req, res, pctx); err != nil {
@@ -105,6 +59,17 @@ func (s *Server) filterDNSRequest(dctx *dnsContext) (res *filtering.Result, err
return res, err return res, err
} }
// isRewrittenCNAME returns true if the request considered to be rewritten with
// CNAME and has no resolved IPs.
func isRewrittenCNAME(res *filtering.Result) (ok bool) {
return res.Reason.In(
filtering.Rewritten,
filtering.RewrittenRule,
filtering.FilteredSafeSearch) &&
res.CanonName != "" &&
len(res.IPList) == 0
}
// checkHostRules checks the host against filters. It is safe for concurrent // checkHostRules checks the host against filters. It is safe for concurrent
// use. // use.
func (s *Server) checkHostRules( func (s *Server) checkHostRules(

View File

@@ -1,6 +1,7 @@
package dnsforward package dnsforward
import ( import (
"cmp"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@@ -261,55 +262,17 @@ func (req *jsonDNSConfig) checkUpstreamMode() (err error) {
} }
} }
// checkBootstrap returns an error if any bootstrap address is invalid.
func (req *jsonDNSConfig) checkBootstrap() (err error) {
if req.Bootstraps == nil {
return nil
}
var b string
defer func() { err = errors.Annotate(err, "checking bootstrap %s: %w", b) }()
for _, b = range *req.Bootstraps {
if b == "" {
return errors.Error("empty")
}
var resolver *upstream.UpstreamResolver
if resolver, err = upstream.NewUpstreamResolver(b, nil); err != nil {
// Don't wrap the error because it's informative enough as is.
return err
}
if err = resolver.Close(); err != nil {
return fmt.Errorf("closing %s: %w", b, err)
}
}
return nil
}
// checkFallbacks returns an error if any fallback address is invalid.
func (req *jsonDNSConfig) checkFallbacks() (err error) {
if req.Fallbacks == nil {
return nil
}
_, err = proxy.ParseUpstreamsConfig(*req.Fallbacks, &upstream.Options{})
if err != nil {
return fmt.Errorf("fallback servers: %w", err)
}
return nil
}
// validate returns an error if any field of req is invalid. // validate returns an error if any field of req is invalid.
// //
// TODO(s.chzhen): Parse, don't validate. // TODO(s.chzhen): Parse, don't validate.
func (req *jsonDNSConfig) validate(privateNets netutil.SubnetSet) (err error) { func (req *jsonDNSConfig) validate(
ownAddrs addrPortSet,
sysResolvers SystemResolvers,
privateNets netutil.SubnetSet,
) (err error) {
defer func() { err = errors.Annotate(err, "validating dns config: %w") }() defer func() { err = errors.Annotate(err, "validating dns config: %w") }()
err = req.validateUpstreamDNSServers(privateNets) err = req.validateUpstreamDNSServers(ownAddrs, sysResolvers, privateNets)
if err != nil { if err != nil {
// Don't wrap the error since it's informative enough as is. // Don't wrap the error since it's informative enough as is.
return err return err
@@ -342,20 +305,77 @@ func (req *jsonDNSConfig) validate(privateNets netutil.SubnetSet) (err error) {
return nil return nil
} }
// checkBootstrap returns an error if any bootstrap address is invalid.
func (req *jsonDNSConfig) checkBootstrap() (err error) {
if req.Bootstraps == nil {
return nil
}
var b string
defer func() { err = errors.Annotate(err, "checking bootstrap %s: %w", b) }()
for _, b = range *req.Bootstraps {
if b == "" {
return errors.Error("empty")
}
var resolver *upstream.UpstreamResolver
if resolver, err = upstream.NewUpstreamResolver(b, nil); err != nil {
// Don't wrap the error because it's informative enough as is.
return err
}
if err = resolver.Close(); err != nil {
return fmt.Errorf("closing %s: %w", b, err)
}
}
return nil
}
// checkPrivateRDNS returns an error if the configuration of the private RDNS is
// not valid.
func (req *jsonDNSConfig) checkPrivateRDNS(
ownAddrs addrPortSet,
sysResolvers SystemResolvers,
privateNets netutil.SubnetSet,
) (err error) {
if (req.UsePrivateRDNS == nil || !*req.UsePrivateRDNS) && req.LocalPTRUpstreams == nil {
return nil
}
addrs := cmp.Or(req.LocalPTRUpstreams, &[]string{})
uc, err := newPrivateConfig(*addrs, ownAddrs, sysResolvers, privateNets, &upstream.Options{})
err = errors.WithDeferred(err, uc.Close())
if err != nil {
return fmt.Errorf("private upstream servers: %w", err)
}
return nil
}
// validateUpstreamDNSServers returns an error if any field of req is invalid. // validateUpstreamDNSServers returns an error if any field of req is invalid.
func (req *jsonDNSConfig) validateUpstreamDNSServers(privateNets netutil.SubnetSet) (err error) { func (req *jsonDNSConfig) validateUpstreamDNSServers(
ownAddrs addrPortSet,
sysResolvers SystemResolvers,
privateNets netutil.SubnetSet,
) (err error) {
var uc *proxy.UpstreamConfig
opts := &upstream.Options{}
if req.Upstreams != nil { if req.Upstreams != nil {
_, err = proxy.ParseUpstreamsConfig(*req.Upstreams, &upstream.Options{}) uc, err = proxy.ParseUpstreamsConfig(*req.Upstreams, opts)
err = errors.WithDeferred(err, uc.Close())
if err != nil { if err != nil {
return fmt.Errorf("upstream servers: %w", err) return fmt.Errorf("upstream servers: %w", err)
} }
} }
if req.LocalPTRUpstreams != nil { err = req.checkPrivateRDNS(ownAddrs, sysResolvers, privateNets)
err = ValidateUpstreamsPrivate(*req.LocalPTRUpstreams, privateNets) if err != nil {
if err != nil { // Don't wrap the error since it's informative enough as is.
return fmt.Errorf("private upstream servers: %w", err) return err
}
} }
err = req.checkBootstrap() err = req.checkBootstrap()
@@ -364,10 +384,12 @@ func (req *jsonDNSConfig) validateUpstreamDNSServers(privateNets netutil.SubnetS
return err return err
} }
err = req.checkFallbacks() if req.Fallbacks != nil {
if err != nil { uc, err = proxy.ParseUpstreamsConfig(*req.Fallbacks, opts)
// Don't wrap the error since it's informative enough as is. err = errors.WithDeferred(err, uc.Close())
return err if err != nil {
return fmt.Errorf("fallback servers: %w", err)
}
} }
return nil return nil
@@ -436,7 +458,16 @@ func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) {
return return
} }
err = req.validate(s.privateNets) // TODO(e.burkov): Consider prebuilding this set on startup.
ourAddrs, err := s.conf.ourAddrsSet()
if err != nil {
// TODO(e.burkov): Put into openapi.
aghhttp.Error(r, w, http.StatusInternalServerError, "getting our addresses: %s", err)
return
}
err = req.validate(ourAddrs, s.sysResolvers, s.privateNets)
if err != nil { if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
@@ -587,7 +618,7 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
} }
var boots []*upstream.UpstreamResolver var boots []*upstream.UpstreamResolver
opts.Bootstrap, boots, err = s.createBootstrap(req.BootstrapDNS, opts) opts.Bootstrap, boots, err = newBootstrap(req.BootstrapDNS, s.etcHosts, opts)
if err != nil { if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "Failed to parse bootstrap servers: %s", err) aghhttp.Error(r, w, http.StatusBadRequest, "Failed to parse bootstrap servers: %s", err)

View File

@@ -245,9 +245,8 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
wantSet: "", wantSet: "",
}, { }, {
name: "local_ptr_upstreams_bad", name: "local_ptr_upstreams_bad",
wantSet: `validating dns config: ` + wantSet: `validating dns config: private upstream servers: ` +
`private upstream servers: checking domain-specific upstreams: ` + `bad arpa domain name "non.arpa": not a reversed ip network`,
`bad arpa domain name "non.arpa.": not a reversed ip network`,
}, { }, {
name: "local_ptr_upstreams_null", name: "local_ptr_upstreams_null",
wantSet: "", wantSet: "",
@@ -318,58 +317,6 @@ func TestIsCommentOrEmpty(t *testing.T) {
} }
} }
func TestValidateUpstreamsPrivate(t *testing.T) {
ss := netutil.SubnetSetFunc(netutil.IsLocallyServed)
testCases := []struct {
name string
wantErr string
u string
}{{
name: "success_address",
wantErr: ``,
u: "[/1.0.0.127.in-addr.arpa/]#",
}, {
name: "success_subnet",
wantErr: ``,
u: "[/127.in-addr.arpa/]#",
}, {
name: "not_arpa_subnet",
wantErr: `checking domain-specific upstreams: ` +
`bad arpa domain name "hello.world.": not a reversed ip network`,
u: "[/hello.world/]#",
}, {
name: "non-private_arpa_address",
wantErr: `checking domain-specific upstreams: ` +
`arpa domain "1.2.3.4.in-addr.arpa." should point to a locally-served network`,
u: "[/1.2.3.4.in-addr.arpa/]#",
}, {
name: "non-private_arpa_subnet",
wantErr: `checking domain-specific upstreams: ` +
`arpa domain "128.in-addr.arpa." should point to a locally-served network`,
u: "[/128.in-addr.arpa/]#",
}, {
name: "several_bad",
wantErr: `checking domain-specific upstreams: ` +
`arpa domain "1.2.3.4.in-addr.arpa." should point to a locally-served network` + "\n" +
`bad arpa domain name "non.arpa.": not a reversed ip network`,
u: "[/non.arpa/1.2.3.4.in-addr.arpa/127.in-addr.arpa/]#",
}, {
name: "partial_good",
wantErr: "",
u: "[/a.1.2.3.10.in-addr.arpa/a.10.in-addr.arpa/]#",
}}
for _, tc := range testCases {
set := []string{"192.168.0.1", tc.u}
t.Run(tc.name, func(t *testing.T) {
err := ValidateUpstreamsPrivate(set, ss)
testutil.AssertErrorMsg(t, tc.wantErr, err)
})
}
}
func newLocalUpstreamListener(t *testing.T, port uint16, handler dns.Handler) (real netip.AddrPort) { func newLocalUpstreamListener(t *testing.T, port uint16, handler dns.Handler) (real netip.AddrPort) {
t.Helper() t.Helper()

View File

@@ -11,17 +11,21 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
) )
// makeResponse creates a DNS response by req and sets necessary flags. It also // TODO(e.burkov): Name all the methods by a [proxy.MessageConstructor]
// guarantees that req.Question will be not empty. // template. Also extract all the methods to a separate entity.
func (s *Server) makeResponse(req *dns.Msg) (resp *dns.Msg) {
resp = &dns.Msg{
MsgHdr: dns.MsgHdr{
RecursionAvailable: true,
},
Compress: true,
}
resp.SetReply(req) // reply creates a DNS response for req.
func (*Server) reply(req *dns.Msg, code int) (resp *dns.Msg) {
resp = (&dns.Msg{}).SetRcode(req, code)
resp.RecursionAvailable = true
return resp
}
// replyCompressed creates a DNS response for req and sets the compress flag.
func (s *Server) replyCompressed(req *dns.Msg) (resp *dns.Msg) {
resp = s.reply(req, dns.RcodeSuccess)
resp.Compress = true
return resp return resp
} }
@@ -48,10 +52,10 @@ func (s *Server) genDNSFilterMessage(
) (resp *dns.Msg) { ) (resp *dns.Msg) {
req := dctx.Req req := dctx.Req
qt := req.Question[0].Qtype qt := req.Question[0].Qtype
if qt != dns.TypeA && qt != dns.TypeAAAA { if qt != dns.TypeA && qt != dns.TypeAAAA && qt != dns.TypeHTTPS {
m, _, _ := s.dnsFilter.BlockingMode() m, _, _ := s.dnsFilter.BlockingMode()
if m == filtering.BlockingModeNullIP { if m == filtering.BlockingModeNullIP {
return s.makeResponse(req) return s.replyCompressed(req)
} }
return s.newMsgNODATA(req) return s.newMsgNODATA(req)
@@ -75,7 +79,7 @@ func (s *Server) genDNSFilterMessage(
// getCNAMEWithIPs generates a filtered response to req for with CNAME record // getCNAMEWithIPs generates a filtered response to req for with CNAME record
// and provided ips. // and provided ips.
func (s *Server) getCNAMEWithIPs(req *dns.Msg, ips []netip.Addr, cname string) (resp *dns.Msg) { func (s *Server) getCNAMEWithIPs(req *dns.Msg, ips []netip.Addr, cname string) (resp *dns.Msg) {
resp = s.makeResponse(req) resp = s.replyCompressed(req)
originalName := req.Question[0].Name originalName := req.Question[0].Name
@@ -121,13 +125,13 @@ func (s *Server) genForBlockingMode(req *dns.Msg, ips []netip.Addr) (resp *dns.M
case filtering.BlockingModeNullIP: case filtering.BlockingModeNullIP:
return s.makeResponseNullIP(req) return s.makeResponseNullIP(req)
case filtering.BlockingModeNXDOMAIN: case filtering.BlockingModeNXDOMAIN:
return s.genNXDomain(req) return s.NewMsgNXDOMAIN(req)
case filtering.BlockingModeREFUSED: case filtering.BlockingModeREFUSED:
return s.makeResponseREFUSED(req) return s.makeResponseREFUSED(req)
default: default:
log.Error("dnsforward: invalid blocking mode %q", mode) log.Error("dnsforward: invalid blocking mode %q", mode)
return s.makeResponse(req) return s.replyCompressed(req)
} }
} }
@@ -148,25 +152,18 @@ func (s *Server) makeResponseCustomIP(
// genDNSFilterMessage. // genDNSFilterMessage.
log.Error("dnsforward: invalid msg type %s for custom IP blocking mode", dns.Type(qt)) log.Error("dnsforward: invalid msg type %s for custom IP blocking mode", dns.Type(qt))
return s.makeResponse(req) return s.replyCompressed(req)
} }
} }
func (s *Server) genServerFailure(request *dns.Msg) *dns.Msg {
resp := dns.Msg{}
resp.SetRcode(request, dns.RcodeServerFailure)
resp.RecursionAvailable = true
return &resp
}
func (s *Server) genARecord(request *dns.Msg, ip netip.Addr) *dns.Msg { func (s *Server) genARecord(request *dns.Msg, ip netip.Addr) *dns.Msg {
resp := s.makeResponse(request) resp := s.replyCompressed(request)
resp.Answer = append(resp.Answer, s.genAnswerA(request, ip)) resp.Answer = append(resp.Answer, s.genAnswerA(request, ip))
return resp return resp
} }
func (s *Server) genAAAARecord(request *dns.Msg, ip netip.Addr) *dns.Msg { func (s *Server) genAAAARecord(request *dns.Msg, ip netip.Addr) *dns.Msg {
resp := s.makeResponse(request) resp := s.replyCompressed(request)
resp.Answer = append(resp.Answer, s.genAnswerAAAA(request, ip)) resp.Answer = append(resp.Answer, s.genAnswerAAAA(request, ip))
return resp return resp
} }
@@ -252,7 +249,7 @@ func (s *Server) genResponseWithIPs(req *dns.Msg, ips []netip.Addr) (resp *dns.M
// Go on and return an empty response. // Go on and return an empty response.
} }
resp = s.makeResponse(req) resp = s.replyCompressed(req)
resp.Answer = ans resp.Answer = ans
return resp return resp
@@ -288,7 +285,7 @@ func (s *Server) makeResponseNullIP(req *dns.Msg) (resp *dns.Msg) {
case dns.TypeAAAA: case dns.TypeAAAA:
resp = s.genResponseWithIPs(req, []netip.Addr{netip.IPv6Unspecified()}) resp = s.genResponseWithIPs(req, []netip.Addr{netip.IPv6Unspecified()})
default: default:
resp = s.makeResponse(req) resp = s.replyCompressed(req)
} }
return resp return resp
@@ -298,7 +295,7 @@ func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, d *proxy.DNSCo
if newAddr == "" { if newAddr == "" {
log.Info("dnsforward: block host is not specified") log.Info("dnsforward: block host is not specified")
return s.genServerFailure(request) return s.NewMsgSERVFAIL(request)
} }
ip, err := netip.ParseAddr(newAddr) ip, err := netip.ParseAddr(newAddr)
@@ -321,17 +318,17 @@ func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, d *proxy.DNSCo
if prx == nil { if prx == nil {
log.Debug("dnsforward: %s", srvClosedErr) log.Debug("dnsforward: %s", srvClosedErr)
return s.genServerFailure(request) return s.NewMsgSERVFAIL(request)
} }
err = prx.Resolve(newContext) err = prx.Resolve(newContext)
if err != nil { if err != nil {
log.Info("dnsforward: looking up replacement host %q: %s", newAddr, err) log.Info("dnsforward: looking up replacement host %q: %s", newAddr, err)
return s.genServerFailure(request) return s.NewMsgSERVFAIL(request)
} }
resp := s.makeResponse(request) resp := s.replyCompressed(request)
if newContext.Res != nil { if newContext.Res != nil {
for _, answer := range newContext.Res.Answer { for _, answer := range newContext.Res.Answer {
answer.Header().Name = request.Question[0].Name answer.Header().Name = request.Question[0].Name
@@ -342,47 +339,21 @@ func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, d *proxy.DNSCo
return resp return resp
} }
// preBlockedResponse returns a protocol-appropriate response for a request that
// was blocked by access settings.
func (s *Server) preBlockedResponse(pctx *proxy.DNSContext) (reply bool, err error) {
if pctx.Proto == proxy.ProtoUDP || pctx.Proto == proxy.ProtoDNSCrypt {
// Return nil so that dnsproxy drops the connection and thus
// prevent DNS amplification attacks.
return false, nil
}
pctx.Res = s.makeResponseREFUSED(pctx.Req)
return true, nil
}
// Create REFUSED DNS response // Create REFUSED DNS response
func (s *Server) makeResponseREFUSED(request *dns.Msg) *dns.Msg { func (s *Server) makeResponseREFUSED(req *dns.Msg) *dns.Msg {
resp := dns.Msg{} return s.reply(req, dns.RcodeRefused)
resp.SetRcode(request, dns.RcodeRefused)
resp.RecursionAvailable = true
return &resp
} }
// newMsgNODATA returns a properly initialized NODATA response. // newMsgNODATA returns a properly initialized NODATA response.
// //
// See https://www.rfc-editor.org/rfc/rfc2308#section-2.2. // See https://www.rfc-editor.org/rfc/rfc2308#section-2.2.
func (s *Server) newMsgNODATA(req *dns.Msg) (resp *dns.Msg) { func (s *Server) newMsgNODATA(req *dns.Msg) (resp *dns.Msg) {
resp = (&dns.Msg{}).SetRcode(req, dns.RcodeSuccess) resp = s.reply(req, dns.RcodeSuccess)
resp.RecursionAvailable = true
resp.Ns = s.genSOA(req) resp.Ns = s.genSOA(req)
return resp return resp
} }
func (s *Server) genNXDomain(request *dns.Msg) *dns.Msg {
resp := dns.Msg{}
resp.SetRcode(request, dns.RcodeNameError)
resp.RecursionAvailable = true
resp.Ns = s.genSOA(request)
return &resp
}
func (s *Server) genSOA(request *dns.Msg) []dns.RR { func (s *Server) genSOA(request *dns.Msg) []dns.RR {
zone := "" zone := ""
if len(request.Question) > 0 { if len(request.Question) > 0 {
@@ -414,5 +385,43 @@ func (s *Server) genSOA(request *dns.Msg) []dns.RR {
if len(zone) > 0 && zone[0] != '.' { if len(zone) > 0 && zone[0] != '.' {
soa.Mbox += zone soa.Mbox += zone
} }
return []dns.RR{&soa} return []dns.RR{&soa}
} }
// type check
var _ proxy.MessageConstructor = (*Server)(nil)
// NewMsgNXDOMAIN implements the [proxy.MessageConstructor] interface for
// *Server.
func (s *Server) NewMsgNXDOMAIN(req *dns.Msg) (resp *dns.Msg) {
resp = s.reply(req, dns.RcodeNameError)
resp.Ns = s.genSOA(req)
return resp
}
// NewMsgSERVFAIL implements the [proxy.MessageConstructor] interface for
// *Server.
func (s *Server) NewMsgSERVFAIL(req *dns.Msg) (resp *dns.Msg) {
return s.reply(req, dns.RcodeServerFailure)
}
// NewMsgNOTIMPLEMENTED implements the [proxy.MessageConstructor] interface for
// *Server.
func (s *Server) NewMsgNOTIMPLEMENTED(req *dns.Msg) (resp *dns.Msg) {
resp = s.reply(req, dns.RcodeNotImplemented)
// Most of the Internet and especially the inner core has an MTU of at least
// 1500 octets. Maximum DNS/UDP payload size for IPv6 on MTU 1500 ethernet
// is 1452 (1500 minus 40 (IPv6 header size) minus 8 (UDP header size)).
//
// See appendix A of https://datatracker.ietf.org/doc/draft-ietf-dnsop-avoid-fragmentation/17.
const maxUDPPayload = 1452
// NOTIMPLEMENTED without EDNS is treated as 'we don't support EDNS', so
// explicitly set it.
resp.SetEdns0(maxUDPPayload, false)
return resp
}

View File

@@ -1,20 +1,17 @@
package dnsforward package dnsforward
import ( import (
"cmp"
"encoding/binary" "encoding/binary"
"net" "net"
"net/netip" "net/netip"
"strconv"
"strings" "strings"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@@ -34,11 +31,6 @@ type dnsContext struct {
// response is modified by filters. // response is modified by filters.
origResp *dns.Msg origResp *dns.Msg
// unreversedReqIP stores an IP address obtained from a PTR request if it
// was parsed successfully and belongs to one of the locally served IP
// ranges.
unreversedReqIP netip.Addr
// err is the error returned from a processing function. // err is the error returned from a processing function.
err error err error
@@ -63,10 +55,6 @@ type dnsContext struct {
// responseAD shows if the response had the AD bit set. // responseAD shows if the response had the AD bit set.
responseAD bool responseAD bool
// isLocalClient shows if client's IP address is from locally served
// network.
isLocalClient bool
// isDHCPHost is true if the request for a local domain name and the DHCP is // isDHCPHost is true if the request for a local domain name and the DHCP is
// available for this request. // available for this request.
isDHCPHost bool isDHCPHost bool
@@ -109,15 +97,11 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, pctx *proxy.DNSContext) error
// (*proxy.Proxy).handleDNSRequest method performs it before calling the // (*proxy.Proxy).handleDNSRequest method performs it before calling the
// appropriate handler. // appropriate handler.
mods := []modProcessFunc{ mods := []modProcessFunc{
s.processRecursion,
s.processInitial, s.processInitial,
s.processDDRQuery, s.processDDRQuery,
s.processDetermineLocal,
s.processDHCPHosts, s.processDHCPHosts,
s.processRestrictLocal,
s.processDHCPAddrs, s.processDHCPAddrs,
s.processFilteringBeforeRequest, s.processFilteringBeforeRequest,
s.processLocalPTR,
s.processUpstream, s.processUpstream,
s.processFilteringAfterResponse, s.processFilteringAfterResponse,
s.ipset.process, s.ipset.process,
@@ -145,24 +129,6 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, pctx *proxy.DNSContext) error
return nil return nil
} }
// processRecursion checks the incoming request and halts its handling by
// answering NXDOMAIN if s has tried to resolve it recently.
func (s *Server) processRecursion(dctx *dnsContext) (rc resultCode) {
log.Debug("dnsforward: started processing recursion")
defer log.Debug("dnsforward: finished processing recursion")
pctx := dctx.proxyCtx
if msg := pctx.Req; msg != nil && s.recDetector.check(*msg) {
log.Debug("dnsforward: recursion detected resolving %q", msg.Question[0].Name)
pctx.Res = s.genNXDomain(pctx.Req)
return resultCodeFinish
}
return resultCodeSuccess
}
// mozillaFQDN is the domain used to signal the Firefox browser to not use its // mozillaFQDN is the domain used to signal the Firefox browser to not use its
// own DoH server. // own DoH server.
// //
@@ -199,14 +165,14 @@ func (s *Server) processInitial(dctx *dnsContext) (rc resultCode) {
} }
if (qt == dns.TypeA || qt == dns.TypeAAAA) && q.Name == mozillaFQDN { if (qt == dns.TypeA || qt == dns.TypeAAAA) && q.Name == mozillaFQDN {
pctx.Res = s.genNXDomain(pctx.Req) pctx.Res = s.NewMsgNXDOMAIN(pctx.Req)
return resultCodeFinish return resultCodeFinish
} }
if q.Name == healthcheckFQDN { if q.Name == healthcheckFQDN {
// Generate a NODATA negative response to make nslookup exit with 0. // Generate a NODATA negative response to make nslookup exit with 0.
pctx.Res = s.makeResponse(pctx.Req) pctx.Res = s.replyCompressed(pctx.Req)
return resultCodeFinish return resultCodeFinish
} }
@@ -272,7 +238,7 @@ func (s *Server) processDDRQuery(dctx *dnsContext) (rc resultCode) {
// //
// [draft standard]: https://www.ietf.org/archive/id/draft-ietf-add-ddr-10.html. // [draft standard]: https://www.ietf.org/archive/id/draft-ietf-add-ddr-10.html.
func (s *Server) makeDDRResponse(req *dns.Msg) (resp *dns.Msg) { func (s *Server) makeDDRResponse(req *dns.Msg) (resp *dns.Msg) {
resp = s.makeResponse(req) resp = s.replyCompressed(req)
if req.Question[0].Qtype != dns.TypeSVCB { if req.Question[0].Qtype != dns.TypeSVCB {
return resp return resp
} }
@@ -339,19 +305,6 @@ func (s *Server) makeDDRResponse(req *dns.Msg) (resp *dns.Msg) {
return resp return resp
} }
// processDetermineLocal determines if the client's IP address is from locally
// served network and saves the result into the context.
func (s *Server) processDetermineLocal(dctx *dnsContext) (rc resultCode) {
log.Debug("dnsforward: started processing local detection")
defer log.Debug("dnsforward: finished processing local detection")
rc = resultCodeSuccess
dctx.isLocalClient = s.privateNets.Contains(dctx.proxyCtx.Addr.Addr())
return rc
}
// processDHCPHosts respond to A requests if the target hostname is known to // processDHCPHosts respond to A requests if the target hostname is known to
// the server. It responds with a mapped IP address if the DNS64 is enabled and // the server. It responds with a mapped IP address if the DNS64 is enabled and
// the request is for AAAA. // the request is for AAAA.
@@ -370,9 +323,9 @@ func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) {
return resultCodeSuccess return resultCodeSuccess
} }
if !dctx.isLocalClient { if !pctx.IsPrivateClient {
log.Debug("dnsforward: %q requests for dhcp host %q", pctx.Addr, dhcpHost) log.Debug("dnsforward: %q requests for dhcp host %q", pctx.Addr, dhcpHost)
pctx.Res = s.genNXDomain(req) pctx.Res = s.NewMsgNXDOMAIN(req)
// Do not even put into query log. // Do not even put into query log.
return resultCodeFinish return resultCodeFinish
@@ -389,7 +342,7 @@ func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) {
log.Debug("dnsforward: dhcp record for %q is %s", dhcpHost, ip) log.Debug("dnsforward: dhcp record for %q is %s", dhcpHost, ip)
resp := s.makeResponse(req) resp := s.replyCompressed(req)
switch q.Qtype { switch q.Qtype {
case dns.TypeA: case dns.TypeA:
a := &dns.A{ a := &dns.A{
@@ -416,141 +369,6 @@ func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) {
return resultCodeSuccess return resultCodeSuccess
} }
// indexFirstV4Label returns the index at which the reversed IPv4 address
// starts, assuming the domain is pre-validated ARPA domain having in-addr and
// arpa labels removed.
func indexFirstV4Label(domain string) (idx int) {
idx = len(domain)
for labelsNum := 0; labelsNum < net.IPv4len && idx > 0; labelsNum++ {
curIdx := strings.LastIndexByte(domain[:idx-1], '.') + 1
_, parseErr := strconv.ParseUint(domain[curIdx:idx-1], 10, 8)
if parseErr != nil {
return idx
}
idx = curIdx
}
return idx
}
// indexFirstV6Label returns the index at which the reversed IPv6 address
// starts, assuming the domain is pre-validated ARPA domain having ip6 and arpa
// labels removed.
func indexFirstV6Label(domain string) (idx int) {
idx = len(domain)
for labelsNum := 0; labelsNum < net.IPv6len*2 && idx > 0; labelsNum++ {
curIdx := idx - len("a.")
if curIdx > 1 && domain[curIdx-1] != '.' {
return idx
}
nibble := domain[curIdx]
if (nibble < '0' || nibble > '9') && (nibble < 'a' || nibble > 'f') {
return idx
}
idx = curIdx
}
return idx
}
// extractARPASubnet tries to convert a reversed ARPA address being a part of
// domain to an IP network. domain must be an FQDN.
//
// TODO(e.burkov): Move to golibs.
func extractARPASubnet(domain string) (pref netip.Prefix, err error) {
err = netutil.ValidateDomainName(strings.TrimSuffix(domain, "."))
if err != nil {
// Don't wrap the error since it's informative enough as is.
return netip.Prefix{}, err
}
const (
v4Suffix = "in-addr.arpa."
v6Suffix = "ip6.arpa."
)
domain = strings.ToLower(domain)
var idx int
switch {
case strings.HasSuffix(domain, v4Suffix):
idx = indexFirstV4Label(domain[:len(domain)-len(v4Suffix)])
case strings.HasSuffix(domain, v6Suffix):
idx = indexFirstV6Label(domain[:len(domain)-len(v6Suffix)])
default:
return netip.Prefix{}, &netutil.AddrError{
Err: netutil.ErrNotAReversedSubnet,
Kind: netutil.AddrKindARPA,
Addr: domain,
}
}
return netutil.PrefixFromReversedAddr(domain[idx:])
}
// processRestrictLocal responds with NXDOMAIN to PTR requests for IP addresses
// in locally served network from external clients.
func (s *Server) processRestrictLocal(dctx *dnsContext) (rc resultCode) {
log.Debug("dnsforward: started processing local restriction")
defer log.Debug("dnsforward: finished processing local restriction")
pctx := dctx.proxyCtx
req := pctx.Req
q := req.Question[0]
if q.Qtype != dns.TypePTR {
// No need for restriction.
return resultCodeSuccess
}
subnet, err := extractARPASubnet(q.Name)
if err != nil {
if errors.Is(err, netutil.ErrNotAReversedSubnet) {
log.Debug("dnsforward: request is not for arpa domain")
return resultCodeSuccess
}
log.Debug("dnsforward: parsing reversed addr: %s", err)
return resultCodeError
}
// Restrict an access to local addresses for external clients. We also
// assume that all the DHCP leases we give are locally served or at least
// shouldn't be accessible externally.
subnetAddr := subnet.Addr()
if !s.privateNets.Contains(subnetAddr) {
return resultCodeSuccess
}
log.Debug("dnsforward: addr %s is from locally served network", subnetAddr)
if !dctx.isLocalClient {
log.Debug("dnsforward: %q requests an internal ip", pctx.Addr)
pctx.Res = s.genNXDomain(req)
// Do not even put into query log.
return resultCodeFinish
}
// Do not perform unreversing ever again.
dctx.unreversedReqIP = subnetAddr
// There is no need to filter request from external addresses since this
// code is only executed when the request is for locally served ARPA
// hostname so disable redundant filters.
dctx.setts.ParentalEnabled = false
dctx.setts.SafeBrowsingEnabled = false
dctx.setts.SafeSearchEnabled = false
dctx.setts.ServicesRules = nil
// Nothing to restrict.
return resultCodeSuccess
}
// processDHCPAddrs responds to PTR requests if the target IP is leased by the // processDHCPAddrs responds to PTR requests if the target IP is leased by the
// DHCP server. // DHCP server.
func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) { func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) {
@@ -562,23 +380,27 @@ func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) {
return resultCodeSuccess return resultCodeSuccess
} }
ipAddr := dctx.unreversedReqIP req := pctx.Req
if ipAddr == (netip.Addr{}) { q := req.Question[0]
pref := pctx.RequestedPrivateRDNS
// TODO(e.burkov): Consider answering authoritatively for SOA and NS
// queries.
if pref == (netip.Prefix{}) || q.Qtype != dns.TypePTR {
return resultCodeSuccess return resultCodeSuccess
} }
host := s.dhcpServer.HostByIP(ipAddr) addr := pref.Addr()
host := s.dhcpServer.HostByIP(addr)
if host == "" { if host == "" {
return resultCodeSuccess return resultCodeSuccess
} }
log.Debug("dnsforward: dhcp client %s is %q", ipAddr, host) log.Debug("dnsforward: dhcp client %s is %q", addr, host)
req := pctx.Req resp := s.replyCompressed(req)
resp := s.makeResponse(req)
ptr := &dns.PTR{ ptr := &dns.PTR{
Hdr: dns.RR_Header{ Hdr: dns.RR_Header{
Name: req.Question[0].Name, Name: q.Name,
Rrtype: dns.TypePTR, Rrtype: dns.TypePTR,
// TODO(e.burkov): Use [dhcpsvc.Lease.Expiry]. See // TODO(e.burkov): Use [dhcpsvc.Lease.Expiry]. See
// https://github.com/AdguardTeam/AdGuardHome/issues/3932. // https://github.com/AdguardTeam/AdGuardHome/issues/3932.
@@ -593,62 +415,20 @@ func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) {
return resultCodeSuccess return resultCodeSuccess
} }
// processLocalPTR responds to PTR requests if the target IP is detected to be
// inside the local network and the query was not answered from DHCP.
func (s *Server) processLocalPTR(dctx *dnsContext) (rc resultCode) {
log.Debug("dnsforward: started processing local ptr")
defer log.Debug("dnsforward: finished processing local ptr")
pctx := dctx.proxyCtx
if pctx.Res != nil {
return resultCodeSuccess
}
ip := dctx.unreversedReqIP
if ip == (netip.Addr{}) {
return resultCodeSuccess
}
s.serverLock.RLock()
defer s.serverLock.RUnlock()
if s.conf.UsePrivateRDNS {
s.recDetector.add(*pctx.Req)
if err := s.localResolvers.Resolve(pctx); err != nil {
log.Debug("dnsforward: resolving private address: %s", err)
// Generate the server failure if the private upstream configuration
// is empty.
//
// This is a crutch, see TODO at [Server.localResolvers].
if errors.Is(err, upstream.ErrNoUpstreams) {
pctx.Res = s.genServerFailure(pctx.Req)
// Do not even put into query log.
return resultCodeFinish
}
dctx.err = err
return resultCodeError
}
}
if pctx.Res == nil {
pctx.Res = s.genNXDomain(pctx.Req)
// Do not even put into query log.
return resultCodeFinish
}
return resultCodeSuccess
}
// Apply filtering logic // Apply filtering logic
func (s *Server) processFilteringBeforeRequest(dctx *dnsContext) (rc resultCode) { func (s *Server) processFilteringBeforeRequest(dctx *dnsContext) (rc resultCode) {
log.Debug("dnsforward: started processing filtering before req") log.Debug("dnsforward: started processing filtering before req")
defer log.Debug("dnsforward: finished processing filtering before req") defer log.Debug("dnsforward: finished processing filtering before req")
if dctx.proxyCtx.RequestedPrivateRDNS != (netip.Prefix{}) {
// There is no need to filter request for locally served ARPA hostname
// so disable redundant filters.
dctx.setts.ParentalEnabled = false
dctx.setts.SafeBrowsingEnabled = false
dctx.setts.SafeSearchEnabled = false
dctx.setts.ServicesRules = nil
}
if dctx.proxyCtx.Res != nil { if dctx.proxyCtx.Res != nil {
// Go on since the response is already set. // Go on since the response is already set.
return resultCodeSuccess return resultCodeSuccess
@@ -695,7 +475,7 @@ func (s *Server) processUpstream(dctx *dnsContext) (rc resultCode) {
// local domain name if there is one. // local domain name if there is one.
name := req.Question[0].Name name := req.Question[0].Name
log.Debug("dnsforward: dhcp client hostname %q was not filtered", name[:len(name)-1]) log.Debug("dnsforward: dhcp client hostname %q was not filtered", name[:len(name)-1])
pctx.Res = s.genNXDomain(req) pctx.Res = s.NewMsgNXDOMAIN(req)
return resultCodeFinish return resultCodeFinish
} }
@@ -712,21 +492,7 @@ func (s *Server) processUpstream(dctx *dnsContext) (rc resultCode) {
return resultCodeError return resultCodeError
} }
if err := prx.Resolve(pctx); err != nil { if dctx.err = prx.Resolve(pctx); dctx.err != nil {
if errors.Is(err, upstream.ErrNoUpstreams) {
// Do not even put into querylog. Currently this happens either
// when the private resolvers enabled and the request is DNS64 PTR,
// or when the client isn't considered local by prx.
//
// TODO(e.burkov): Make proxy detect local client the same way as
// AGH does.
pctx.Res = s.genNXDomain(req)
return resultCodeFinish
}
dctx.err = err
return resultCodeError return resultCodeError
} }
@@ -810,7 +576,7 @@ func (s *Server) setCustomUpstream(pctx *proxy.DNSContext, clientID string) {
} }
// Use the ClientID first, since it has a higher priority. // Use the ClientID first, since it has a higher priority.
id := stringutil.Coalesce(clientID, pctx.Addr.Addr().String()) id := cmp.Or(clientID, pctx.Addr.Addr().String())
upsConf, err := s.conf.ClientsContainer.UpstreamConfigByID(id, s.bootstrap) upsConf, err := s.conf.ClientsContainer.UpstreamConfigByID(id, s.bootstrap)
if err != nil { if err != nil {
log.Error("dnsforward: getting custom upstreams for client %s: %s", id, err) log.Error("dnsforward: getting custom upstreams for client %s: %s", id, err)
@@ -835,7 +601,8 @@ func (s *Server) processFilteringAfterResponse(dctx *dnsContext) (rc resultCode)
return resultCodeSuccess return resultCodeSuccess
case case
filtering.Rewritten, filtering.Rewritten,
filtering.RewrittenRule: filtering.RewrittenRule,
filtering.FilteredSafeSearch:
if dctx.origQuestion.Name == "" { if dctx.origQuestion.Name == "" {
// origQuestion is set in case we get only CNAME without IP from // origQuestion is set in case we get only CNAME without IP from
@@ -845,11 +612,10 @@ func (s *Server) processFilteringAfterResponse(dctx *dnsContext) (rc resultCode)
pctx := dctx.proxyCtx pctx := dctx.proxyCtx
pctx.Req.Question[0], pctx.Res.Question[0] = dctx.origQuestion, dctx.origQuestion pctx.Req.Question[0], pctx.Res.Question[0] = dctx.origQuestion, dctx.origQuestion
if len(pctx.Res.Answer) > 0 {
rr := s.genAnswerCNAME(pctx.Req, res.CanonName) rr := s.genAnswerCNAME(pctx.Req, res.CanonName)
answer := append([]dns.RR{rr}, pctx.Res.Answer...) answer := append([]dns.RR{rr}, pctx.Res.Answer...)
pctx.Res.Answer = answer pctx.Res.Answer = answer
}
return resultCodeSuccess return resultCodeSuccess
default: default:

View File

@@ -9,6 +9,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/urlfilter/rules" "github.com/AdguardTeam/urlfilter/rules"
@@ -375,44 +376,6 @@ func createTestDNSFilter(t *testing.T) (f *filtering.DNSFilter) {
return f return f
} }
func TestServer_ProcessDetermineLocal(t *testing.T) {
s := &Server{
privateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
}
testCases := []struct {
want assert.BoolAssertionFunc
name string
cliAddr netip.AddrPort
}{{
want: assert.True,
name: "local",
cliAddr: netip.MustParseAddrPort("192.168.0.1:1"),
}, {
want: assert.False,
name: "external",
cliAddr: netip.MustParseAddrPort("250.249.0.1:1"),
}, {
want: assert.False,
name: "invalid",
cliAddr: netip.AddrPort{},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
proxyCtx := &proxy.DNSContext{
Addr: tc.cliAddr,
}
dctx := &dnsContext{
proxyCtx: proxyCtx,
}
s.processDetermineLocal(dctx)
tc.want(t, dctx.isLocalClient)
})
}
}
func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) { func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) {
const ( const (
localDomainSuffix = "lan" localDomainSuffix = "lan"
@@ -482,9 +445,9 @@ func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) {
dctx := &dnsContext{ dctx := &dnsContext{
proxyCtx: &proxy.DNSContext{ proxyCtx: &proxy.DNSContext{
Req: req, Req: req,
IsPrivateClient: tc.isLocalCli,
}, },
isLocalClient: tc.isLocalCli,
} }
res := s.processDHCPHosts(dctx) res := s.processDHCPHosts(dctx)
@@ -617,9 +580,9 @@ func TestServer_ProcessDHCPHosts(t *testing.T) {
dctx := &dnsContext{ dctx := &dnsContext{
proxyCtx: &proxy.DNSContext{ proxyCtx: &proxy.DNSContext{
Req: req, Req: req,
IsPrivateClient: true,
}, },
isLocalClient: true,
} }
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
@@ -654,19 +617,28 @@ func TestServer_ProcessDHCPHosts(t *testing.T) {
} }
} }
func TestServer_ProcessRestrictLocal(t *testing.T) { // TODO(e.burkov): Rewrite this test to use the whole server instead of just
// testing the [handleDNSRequest] method. See comment on
// "from_external_for_local" test case.
func TestServer_HandleDNSRequest_restrictLocal(t *testing.T) {
intAddr := netip.MustParseAddr("192.168.1.1")
intPTRQuestion, err := netutil.IPToReversedAddr(intAddr.AsSlice())
require.NoError(t, err)
extAddr := netip.MustParseAddr("254.253.252.1")
extPTRQuestion, err := netutil.IPToReversedAddr(extAddr.AsSlice())
require.NoError(t, err)
const ( const (
extPTRQuestion = "251.252.253.254.in-addr.arpa." extPTRAnswer = "host1.example.net."
extPTRAnswer = "host1.example.net." intPTRAnswer = "some.local-client."
intPTRQuestion = "1.1.168.192.in-addr.arpa."
intPTRAnswer = "some.local-client."
) )
localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) { localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
resp := cmp.Or( resp := cmp.Or(
aghtest.MatchedResponse(req, dns.TypePTR, extPTRQuestion, extPTRAnswer), aghtest.MatchedResponse(req, dns.TypePTR, extPTRQuestion, extPTRAnswer),
aghtest.MatchedResponse(req, dns.TypePTR, intPTRQuestion, intPTRAnswer), aghtest.MatchedResponse(req, dns.TypePTR, intPTRQuestion, intPTRAnswer),
new(dns.Msg).SetRcode(req, dns.RcodeNameError), (&dns.Msg{}).SetRcode(req, dns.RcodeNameError),
) )
require.NoError(testutil.PanicT{}, w.WriteMsg(resp)) require.NoError(testutil.PanicT{}, w.WriteMsg(resp))
@@ -692,123 +664,165 @@ func TestServer_ProcessRestrictLocal(t *testing.T) {
startDeferStop(t, s) startDeferStop(t, s)
testCases := []struct { testCases := []struct {
name string name string
want string question string
question net.IP wantErr error
cliAddr netip.AddrPort wantAns []dns.RR
wantLen int isPrivate bool
}{{ }{{
name: "from_local_to_external", name: "from_local_for_external",
want: "host1.example.net.", question: extPTRQuestion,
question: net.IP{254, 253, 252, 251}, wantErr: nil,
cliAddr: netip.MustParseAddrPort("192.168.10.10:1"), wantAns: []dns.RR{&dns.PTR{
wantLen: 1, Hdr: dns.RR_Header{
Name: dns.Fqdn(extPTRQuestion),
Rrtype: dns.TypePTR,
Class: dns.ClassINET,
Ttl: 60,
Rdlength: uint16(len(extPTRAnswer) + 1),
},
Ptr: dns.Fqdn(extPTRAnswer),
}},
isPrivate: true,
}, { }, {
name: "from_external_for_local", // In theory this case is not reproducible because [proxy.Proxy] should
want: "", // respond to such queries with NXDOMAIN before they reach
question: net.IP{192, 168, 1, 1}, // [Server.handleDNSRequest].
cliAddr: netip.MustParseAddrPort("254.253.252.251:1"), name: "from_external_for_local",
wantLen: 0, question: intPTRQuestion,
wantErr: upstream.ErrNoUpstreams,
wantAns: nil,
isPrivate: false,
}, { }, {
name: "from_local_for_local", name: "from_local_for_local",
want: "some.local-client.", question: intPTRQuestion,
question: net.IP{192, 168, 1, 1}, wantErr: nil,
cliAddr: netip.MustParseAddrPort("192.168.1.2:1"), wantAns: []dns.RR{&dns.PTR{
wantLen: 1, Hdr: dns.RR_Header{
Name: dns.Fqdn(intPTRQuestion),
Rrtype: dns.TypePTR,
Class: dns.ClassINET,
Ttl: 60,
Rdlength: uint16(len(intPTRAnswer) + 1),
},
Ptr: dns.Fqdn(intPTRAnswer),
}},
isPrivate: true,
}, { }, {
name: "from_external_for_external", name: "from_external_for_external",
want: "host1.example.net.", question: extPTRQuestion,
question: net.IP{254, 253, 252, 251}, wantErr: nil,
cliAddr: netip.MustParseAddrPort("254.253.252.255:1"), wantAns: []dns.RR{&dns.PTR{
wantLen: 1, Hdr: dns.RR_Header{
Name: dns.Fqdn(extPTRQuestion),
Rrtype: dns.TypePTR,
Class: dns.ClassINET,
Ttl: 60,
Rdlength: uint16(len(extPTRAnswer) + 1),
},
Ptr: dns.Fqdn(extPTRAnswer),
}},
isPrivate: false,
}} }}
for _, tc := range testCases { for _, tc := range testCases {
reqAddr, err := dns.ReverseAddr(tc.question.String()) pref, extErr := netutil.ExtractReversedAddr(tc.question)
require.NoError(t, err) require.NoError(t, extErr)
req := createTestMessageWithType(reqAddr, dns.TypePTR)
req := createTestMessageWithType(dns.Fqdn(tc.question), dns.TypePTR)
pctx := &proxy.DNSContext{ pctx := &proxy.DNSContext{
Proto: proxy.ProtoTCP, Req: req,
Req: req, IsPrivateClient: tc.isPrivate,
Addr: tc.cliAddr, }
// TODO(e.burkov): Configure the subnet set properly.
if netutil.IsLocallyServed(pref.Addr()) {
pctx.RequestedPrivateRDNS = pref
} }
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
err = s.handleDNSRequest(nil, pctx) err = s.handleDNSRequest(s.dnsProxy, pctx)
require.NoError(t, err) require.ErrorIs(t, err, tc.wantErr)
require.NotNil(t, pctx.Res)
require.Len(t, pctx.Res.Answer, tc.wantLen)
if tc.wantLen > 0 { require.NotNil(t, pctx.Res)
assert.Equal(t, tc.want, pctx.Res.Answer[0].(*dns.PTR).Ptr) assert.Equal(t, tc.wantAns, pctx.Res.Answer)
}
}) })
} }
} }
func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) { func TestServer_ProcessUpstream_localPTR(t *testing.T) {
const locDomain = "some.local." const locDomain = "some.local."
const reqAddr = "1.1.168.192.in-addr.arpa." const reqAddr = "1.1.168.192.in-addr.arpa."
localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) { localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
resp := cmp.Or( resp := cmp.Or(
aghtest.MatchedResponse(req, dns.TypePTR, reqAddr, locDomain), aghtest.MatchedResponse(req, dns.TypePTR, reqAddr, locDomain),
new(dns.Msg).SetRcode(req, dns.RcodeNameError), (&dns.Msg{}).SetRcode(req, dns.RcodeNameError),
) )
require.NoError(testutil.PanicT{}, w.WriteMsg(resp)) require.NoError(testutil.PanicT{}, w.WriteMsg(resp))
}) })
localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String() localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String()
s := createTestServer( newPrxCtx := func() (prxCtx *proxy.DNSContext) {
t, return &proxy.DNSContext{
&filtering.Config{ Addr: testClientAddrPort,
BlockingMode: filtering.BlockingModeDefault, Req: createTestMessageWithType(reqAddr, dns.TypePTR),
}, IsPrivateClient: true,
ServerConfig{ RequestedPrivateRDNS: netip.MustParsePrefix("192.168.1.1/32"),
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
UsePrivateRDNS: true,
LocalPTRResolvers: []string{localUpsAddr},
ServePlainDNS: true,
},
)
var proxyCtx *proxy.DNSContext
var dnsCtx *dnsContext
setup := func(use bool) {
proxyCtx = &proxy.DNSContext{
Addr: testClientAddrPort,
Req: createTestMessageWithType(reqAddr, dns.TypePTR),
} }
dnsCtx = &dnsContext{
proxyCtx: proxyCtx,
unreversedReqIP: netip.MustParseAddr("192.168.1.1"),
}
s.conf.UsePrivateRDNS = use
} }
t.Run("enabled", func(t *testing.T) { t.Run("enabled", func(t *testing.T) {
setup(true) s := createTestServer(
t,
&filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
},
ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
UsePrivateRDNS: true,
LocalPTRResolvers: []string{localUpsAddr},
ServePlainDNS: true,
},
)
pctx := newPrxCtx()
rc := s.processLocalPTR(dnsCtx) rc := s.processUpstream(&dnsContext{proxyCtx: pctx})
require.Equal(t, resultCodeSuccess, rc) require.Equal(t, resultCodeSuccess, rc)
require.NotEmpty(t, proxyCtx.Res.Answer) require.NotEmpty(t, pctx.Res.Answer)
ptr := testutil.RequireTypeAssert[*dns.PTR](t, pctx.Res.Answer[0])
assert.Equal(t, locDomain, proxyCtx.Res.Answer[0].(*dns.PTR).Ptr) assert.Equal(t, locDomain, ptr.Ptr)
}) })
t.Run("disabled", func(t *testing.T) { t.Run("disabled", func(t *testing.T) {
setup(false) s := createTestServer(
t,
&filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
},
ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
UsePrivateRDNS: false,
LocalPTRResolvers: []string{localUpsAddr},
ServePlainDNS: true,
},
)
pctx := newPrxCtx()
rc := s.processLocalPTR(dnsCtx) rc := s.processUpstream(&dnsContext{proxyCtx: pctx})
require.Equal(t, resultCodeFinish, rc) require.Equal(t, resultCodeError, rc)
require.Empty(t, proxyCtx.Res.Answer) require.Empty(t, pctx.Res.Answer)
}) })
} }
@@ -826,129 +840,3 @@ func TestIPStringFromAddr(t *testing.T) {
assert.Empty(t, ipStringFromAddr(nil)) assert.Empty(t, ipStringFromAddr(nil))
}) })
} }
// TODO(e.burkov): Add fuzzing when moving to golibs.
func TestExtractARPASubnet(t *testing.T) {
const (
v4Suf = `in-addr.arpa.`
v4Part = `2.1.` + v4Suf
v4Whole = `4.3.` + v4Part
v6Suf = `ip6.arpa.`
v6Part = `4.3.2.1.0.0.0.0.0.0.0.0.0.0.0.0.` + v6Suf
v6Whole = `f.e.d.c.0.0.0.0.0.0.0.0.0.0.0.0.` + v6Part
)
v4Pref := netip.MustParsePrefix("1.2.3.4/32")
v4PrefPart := netip.MustParsePrefix("1.2.0.0/16")
v6Pref := netip.MustParsePrefix("::1234:0:0:0:cdef/128")
v6PrefPart := netip.MustParsePrefix("0:0:0:1234::/64")
testCases := []struct {
want netip.Prefix
name string
domain string
wantErr string
}{{
want: netip.Prefix{},
name: "not_an_arpa",
domain: "some.domain.name.",
wantErr: `bad arpa domain name "some.domain.name.": ` +
`not a reversed ip network`,
}, {
want: netip.Prefix{},
name: "bad_domain_name",
domain: "abc.123.",
wantErr: `bad domain name "abc.123": ` +
`bad top-level domain name label "123": all octets are numeric`,
}, {
want: v4Pref,
name: "whole_v4",
domain: v4Whole,
wantErr: "",
}, {
want: v4PrefPart,
name: "partial_v4",
domain: v4Part,
wantErr: "",
}, {
want: v4Pref,
name: "whole_v4_within_domain",
domain: "a." + v4Whole,
wantErr: "",
}, {
want: v4Pref,
name: "whole_v4_additional_label",
domain: "5." + v4Whole,
wantErr: "",
}, {
want: v4PrefPart,
name: "partial_v4_within_domain",
domain: "a." + v4Part,
wantErr: "",
}, {
want: v4PrefPart,
name: "overflow_v4",
domain: "256." + v4Part,
wantErr: "",
}, {
want: v4PrefPart,
name: "overflow_v4_within_domain",
domain: "a.256." + v4Part,
wantErr: "",
}, {
want: netip.Prefix{},
name: "empty_v4",
domain: v4Suf,
wantErr: `bad arpa domain name "in-addr.arpa": ` +
`not a reversed ip network`,
}, {
want: netip.Prefix{},
name: "empty_v4_within_domain",
domain: "a." + v4Suf,
wantErr: `bad arpa domain name "in-addr.arpa": ` +
`not a reversed ip network`,
}, {
want: v6Pref,
name: "whole_v6",
domain: v6Whole,
wantErr: "",
}, {
want: v6PrefPart,
name: "partial_v6",
domain: v6Part,
}, {
want: v6Pref,
name: "whole_v6_within_domain",
domain: "g." + v6Whole,
wantErr: "",
}, {
want: v6Pref,
name: "whole_v6_additional_label",
domain: "1." + v6Whole,
wantErr: "",
}, {
want: v6PrefPart,
name: "partial_v6_within_domain",
domain: "label." + v6Part,
wantErr: "",
}, {
want: netip.Prefix{},
name: "empty_v6",
domain: v6Suf,
wantErr: `bad arpa domain name "ip6.arpa": not a reversed ip network`,
}, {
want: netip.Prefix{},
name: "empty_v6_within_domain",
domain: "g." + v6Suf,
wantErr: `bad arpa domain name "ip6.arpa": not a reversed ip network`,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
subnet, err := extractARPASubnet(tc.domain)
testutil.AssertErrorMsg(t, tc.wantErr, err)
assert.Equal(t, tc.want, subnet)
})
}
}

View File

@@ -1,115 +0,0 @@
package dnsforward
import (
"bytes"
"encoding/binary"
"time"
"github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns"
)
// uint* sizes in bytes to improve readability.
//
// TODO(e.burkov): Remove when there will be a more regardful way to define
// those. See https://github.com/golang/go/issues/29982.
const (
uint16sz = 2
uint64sz = 8
)
// recursionDetector detects recursion in DNS forwarding.
type recursionDetector struct {
recentRequests cache.Cache
ttl time.Duration
}
// check checks if the passed req was already sent by the server.
func (rd *recursionDetector) check(msg dns.Msg) (ok bool) {
if len(msg.Question) == 0 {
return false
}
key := msgToSignature(msg)
expireData := rd.recentRequests.Get(key)
if expireData == nil {
return false
}
expire := time.Unix(0, int64(binary.BigEndian.Uint64(expireData)))
return time.Now().Before(expire)
}
// add caches the msg if it has anything in the questions section.
func (rd *recursionDetector) add(msg dns.Msg) {
now := time.Now()
if len(msg.Question) == 0 {
return
}
key := msgToSignature(msg)
expire64 := uint64(now.Add(rd.ttl).UnixNano())
expire := make([]byte, uint64sz)
binary.BigEndian.PutUint64(expire, expire64)
rd.recentRequests.Set(key, expire)
}
// clear clears the recent requests cache.
func (rd *recursionDetector) clear() {
rd.recentRequests.Clear()
}
// newRecursionDetector returns the initialized *recursionDetector.
func newRecursionDetector(ttl time.Duration, suspectsNum uint) (rd *recursionDetector) {
return &recursionDetector{
recentRequests: cache.New(cache.Config{
EnableLRU: true,
MaxCount: suspectsNum,
}),
ttl: ttl,
}
}
// msgToSignature converts msg into it's signature represented in bytes.
func msgToSignature(msg dns.Msg) (sig []byte) {
sig = make([]byte, uint16sz*2+netutil.MaxDomainNameLen)
// The binary.BigEndian byte order is used everywhere except when the real
// machine's endianness is needed.
byteOrder := binary.BigEndian
byteOrder.PutUint16(sig[0:], msg.Id)
q := msg.Question[0]
byteOrder.PutUint16(sig[uint16sz:], q.Qtype)
copy(sig[2*uint16sz:], []byte(q.Name))
return sig
}
// msgToSignatureSlow converts msg into it's signature represented in bytes in
// the less efficient way.
//
// See BenchmarkMsgToSignature.
func msgToSignatureSlow(msg dns.Msg) (sig []byte) {
type msgSignature struct {
name [netutil.MaxDomainNameLen]byte
id uint16
qtype uint16
}
b := bytes.NewBuffer(sig)
q := msg.Question[0]
signature := msgSignature{
id: msg.Id,
qtype: q.Qtype,
}
copy(signature.name[:], q.Name)
if err := binary.Write(b, binary.BigEndian, signature); err != nil {
log.Debug("writing message signature: %s", err)
}
return b.Bytes()
}

View File

@@ -1,148 +0,0 @@
package dnsforward
import (
"encoding/binary"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
)
func TestRecursionDetector_Check(t *testing.T) {
rd := newRecursionDetector(0, 2)
const (
recID = 1234
recTTL = time.Hour * 100
)
const nonRecID = recID * 2
sampleQuestion := dns.Question{
Name: "some.domain",
Qtype: dns.TypeAAAA,
}
sampleMsg := dns.Msg{
MsgHdr: dns.MsgHdr{
Id: recID,
},
Question: []dns.Question{sampleQuestion},
}
// Manually add the message with big ttl.
key := msgToSignature(sampleMsg)
expire := make([]byte, uint64sz)
binary.BigEndian.PutUint64(expire, uint64(time.Now().Add(recTTL).UnixNano()))
rd.recentRequests.Set(key, expire)
// Add an expired message.
sampleMsg.Id = nonRecID
rd.add(sampleMsg)
testCases := []struct {
name string
questions []dns.Question
id uint16
want bool
}{{
name: "recurrent",
questions: []dns.Question{sampleQuestion},
id: recID,
want: true,
}, {
name: "not_suspected",
questions: []dns.Question{sampleQuestion},
id: recID + 1,
want: false,
}, {
name: "expired",
questions: []dns.Question{sampleQuestion},
id: nonRecID,
want: false,
}, {
name: "empty",
questions: []dns.Question{},
id: nonRecID,
want: false,
}}
for _, tc := range testCases {
sampleMsg.Id = tc.id
sampleMsg.Question = tc.questions
t.Run(tc.name, func(t *testing.T) {
detected := rd.check(sampleMsg)
assert.Equal(t, tc.want, detected)
})
}
}
func TestRecursionDetector_Suspect(t *testing.T) {
rd := newRecursionDetector(0, 1)
testCases := []struct {
name string
msg dns.Msg
want int
}{{
name: "simple",
msg: dns.Msg{
MsgHdr: dns.MsgHdr{
Id: 1234,
},
Question: []dns.Question{{
Name: "some.domain",
Qtype: dns.TypeA,
}},
},
want: 1,
}, {
name: "unencumbered",
msg: dns.Msg{},
want: 0,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Cleanup(rd.clear)
rd.add(tc.msg)
assert.Equal(t, tc.want, rd.recentRequests.Stats().Count)
})
}
}
var sink []byte
func BenchmarkMsgToSignature(b *testing.B) {
const name = "some.not.very.long.host.name"
msg := dns.Msg{
MsgHdr: dns.MsgHdr{
Id: 1234,
},
Question: []dns.Question{{
Name: name,
Qtype: dns.TypeAAAA,
}},
}
b.Run("efficient", func(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
sink = msgToSignature(msg)
}
assert.NotEmpty(b, sink)
})
b.Run("inefficient", func(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
sink = msgToSignatureSlow(msg)
}
assert.NotEmpty(b, sink)
})
}

View File

@@ -29,7 +29,13 @@ func (s *Server) processQueryLogsAndStats(dctx *dnsContext) (rc resultCode) {
log.Debug("dnsforward: client ip for stats and querylog: %s", ipStr) log.Debug("dnsforward: client ip for stats and querylog: %s", ipStr)
ids := []string{ipStr, dctx.clientID} ids := []string{ipStr}
if dctx.clientID != "" {
// Use the ClientID first because it has a higher priority. Filters
// have the same priority, see applyAdditionalFiltering.
ids = []string{dctx.clientID, ipStr}
}
qt, cl := q.Qtype, q.Qclass qt, cl := q.Qtype, q.Qclass
// Synchronize access to s.queryLog and s.stats so they won't be suddenly // Synchronize access to s.queryLog and s.stats so they won't be suddenly

View File

@@ -2,90 +2,77 @@ package dnsforward
import ( import (
"fmt" "fmt"
"net/netip"
"os"
"slices" "slices"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil" "github.com/AdguardTeam/golibs/stringutil"
"golang.org/x/exp/maps"
) )
// loadUpstreams parses upstream DNS servers from the configured file or from // newBootstrap returns a bootstrap resolver based on the configuration of s.
// the configuration itself. // boots are the upstream resolvers that should be closed after use. r is the
func (s *Server) loadUpstreams() (upstreams []string, err error) { // actual bootstrap resolver, which may include the system hosts.
if s.conf.UpstreamDNSFileName == "" { //
return stringutil.FilterOut(s.conf.UpstreamDNS, IsCommentOrEmpty), nil // TODO(e.burkov): This function currently returns a resolver and a slice of
// the upstream resolvers, which are essentially the same. boots are returned
// for being able to close them afterwards, but it introduces an implicit
// contract that r could only be used before that. Anyway, this code should
// improve when the [proxy.UpstreamConfig] will become an [upstream.Resolver]
// and be used here.
func newBootstrap(
addrs []string,
etcHosts upstream.Resolver,
opts *upstream.Options,
) (r upstream.Resolver, boots []*upstream.UpstreamResolver, err error) {
if len(addrs) == 0 {
addrs = defaultBootstrap
} }
var data []byte boots, err = aghnet.ParseBootstraps(addrs, opts)
data, err = os.ReadFile(s.conf.UpstreamDNSFileName)
if err != nil { if err != nil {
return nil, fmt.Errorf("reading upstream from file: %w", err) // Don't wrap the error, since it's informative enough as is.
return nil, nil, err
} }
upstreams = stringutil.SplitTrimmed(string(data), "\n") var parallel upstream.ParallelResolver
for _, b := range boots {
parallel = append(parallel, upstream.NewCachingResolver(b))
}
log.Debug("dnsforward: got %d upstreams in %q", len(upstreams), s.conf.UpstreamDNSFileName) if etcHosts != nil {
r = upstream.ConsequentResolver{etcHosts, parallel}
} else {
r = parallel
}
return stringutil.FilterOut(upstreams, IsCommentOrEmpty), nil return r, boots, nil
} }
// prepareUpstreamSettings sets upstream DNS server settings. // newUpstreamConfig returns the upstream configuration based on upstreams. If
func (s *Server) prepareUpstreamSettings(boot upstream.Resolver) (err error) { // upstreams slice specifies no default upstreams, defaultUpstreams are used to
// Load upstreams either from the file, or from the settings // create upstreams with no domain specifications. opts are used when creating
var upstreams []string // upstream configuration.
upstreams, err = s.loadUpstreams() func newUpstreamConfig(
if err != nil {
return fmt.Errorf("loading upstreams: %w", err)
}
s.conf.UpstreamConfig, err = s.prepareUpstreamConfig(upstreams, defaultDNS, &upstream.Options{
Bootstrap: boot,
Timeout: s.conf.UpstreamTimeout,
HTTPVersions: UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams),
PreferIPv6: s.conf.BootstrapPreferIPv6,
// Use a customized set of RootCAs, because Go's default mechanism of
// loading TLS roots does not always work properly on some routers so we're
// loading roots manually and pass it here.
//
// See [aghtls.SystemRootCAs].
//
// TODO(a.garipov): Investigate if that's true.
RootCAs: s.conf.TLSv12Roots,
CipherSuites: s.conf.TLSCiphers,
})
if err != nil {
return fmt.Errorf("preparing upstream config: %w", err)
}
return nil
}
// prepareUpstreamConfig returns the upstream configuration based on upstreams
// and configuration of s.
func (s *Server) prepareUpstreamConfig(
upstreams []string, upstreams []string,
defaultUpstreams []string, defaultUpstreams []string,
opts *upstream.Options, opts *upstream.Options,
) (uc *proxy.UpstreamConfig, err error) { ) (uc *proxy.UpstreamConfig, err error) {
uc, err = proxy.ParseUpstreamsConfig(upstreams, opts) uc, err = proxy.ParseUpstreamsConfig(upstreams, opts)
if err != nil { if err != nil {
return nil, fmt.Errorf("parsing upstream config: %w", err) return uc, fmt.Errorf("parsing upstreams: %w", err)
} }
if len(uc.Upstreams) == 0 && defaultUpstreams != nil { if len(uc.Upstreams) == 0 && len(defaultUpstreams) > 0 {
log.Info("dnsforward: warning: no default upstreams specified, using %v", defaultUpstreams) log.Info("dnsforward: warning: no default upstreams specified, using %v", defaultUpstreams)
var defaultUpstreamConfig *proxy.UpstreamConfig var defaultUpstreamConfig *proxy.UpstreamConfig
defaultUpstreamConfig, err = proxy.ParseUpstreamsConfig(defaultUpstreams, opts) defaultUpstreamConfig, err = proxy.ParseUpstreamsConfig(defaultUpstreams, opts)
if err != nil { if err != nil {
return nil, fmt.Errorf("parsing default upstreams: %w", err) return uc, fmt.Errorf("parsing default upstreams: %w", err)
} }
uc.Upstreams = defaultUpstreamConfig.Upstreams uc.Upstreams = defaultUpstreamConfig.Upstreams
@@ -94,6 +81,54 @@ func (s *Server) prepareUpstreamConfig(
return uc, nil return uc, nil
} }
// newPrivateConfig creates an upstream configuration for resolving PTR records
// for local addresses. The configuration is built either from the provided
// addresses or from the system resolvers. unwanted filters the resulting
// upstream configuration.
func newPrivateConfig(
addrs []string,
unwanted addrPortSet,
sysResolvers SystemResolvers,
privateNets netutil.SubnetSet,
opts *upstream.Options,
) (uc *proxy.UpstreamConfig, err error) {
confNeedsFiltering := len(addrs) > 0
if confNeedsFiltering {
addrs = stringutil.FilterOut(addrs, IsCommentOrEmpty)
} else {
sysResolvers := slices.DeleteFunc(slices.Clone(sysResolvers.Addrs()), unwanted.Has)
addrs = make([]string, 0, len(sysResolvers))
for _, r := range sysResolvers {
addrs = append(addrs, r.String())
}
}
log.Debug("dnsforward: upstreams to resolve ptr for local addresses: %v", addrs)
uc, err = proxy.ParseUpstreamsConfig(addrs, opts)
if err != nil {
return uc, fmt.Errorf("preparing private upstreams: %w", err)
}
if !confNeedsFiltering {
return uc, nil
}
err = filterOutAddrs(uc, unwanted)
if err != nil {
return uc, fmt.Errorf("filtering private upstreams: %w", err)
}
// Prevalidate the config to catch the exact error before creating proxy.
// See TODO on [PrivateRDNSError].
err = proxy.ValidatePrivateConfig(uc, privateNets)
if err != nil {
return uc, &PrivateRDNSError{err: err}
}
return uc, nil
}
// UpstreamHTTPVersions returns the HTTP versions for upstream configuration // UpstreamHTTPVersions returns the HTTP versions for upstream configuration
// depending on configuration. // depending on configuration.
func UpstreamHTTPVersions(http3 bool) (v []upstream.HTTPVersion) { func UpstreamHTTPVersions(http3 bool) (v []upstream.HTTPVersion) {
@@ -130,85 +165,9 @@ func setProxyUpstreamMode(
return nil return nil
} }
// createBootstrap returns a bootstrap resolver based on the configuration of s.
// boots are the upstream resolvers that should be closed after use. r is the
// actual bootstrap resolver, which may include the system hosts.
//
// TODO(e.burkov): This function currently returns a resolver and a slice of
// the upstream resolvers, which are essentially the same. boots are returned
// for being able to close them afterwards, but it introduces an implicit
// contract that r could only be used before that. Anyway, this code should
// improve when the [proxy.UpstreamConfig] will become an [upstream.Resolver]
// and be used here.
func (s *Server) createBootstrap(
addrs []string,
opts *upstream.Options,
) (r upstream.Resolver, boots []*upstream.UpstreamResolver, err error) {
if len(addrs) == 0 {
addrs = defaultBootstrap
}
boots, err = aghnet.ParseBootstraps(addrs, opts)
if err != nil {
// Don't wrap the error, since it's informative enough as is.
return nil, nil, err
}
var parallel upstream.ParallelResolver
for _, b := range boots {
parallel = append(parallel, upstream.NewCachingResolver(b))
}
if s.etcHosts != nil {
r = upstream.ConsequentResolver{s.etcHosts, parallel}
} else {
r = parallel
}
return r, boots, nil
}
// IsCommentOrEmpty returns true if s starts with a "#" character or is empty. // IsCommentOrEmpty returns true if s starts with a "#" character or is empty.
// This function is useful for filtering out non-upstream lines from upstream // This function is useful for filtering out non-upstream lines from upstream
// configs. // configs.
func IsCommentOrEmpty(s string) (ok bool) { func IsCommentOrEmpty(s string) (ok bool) {
return len(s) == 0 || s[0] == '#' return len(s) == 0 || s[0] == '#'
} }
// ValidateUpstreamsPrivate validates each upstream and returns an error if any
// upstream is invalid or if there are no default upstreams specified. It also
// checks each domain of domain-specific upstreams for being ARPA pointing to
// a locally-served network. privateNets must not be nil.
func ValidateUpstreamsPrivate(upstreams []string, privateNets netutil.SubnetSet) (err error) {
conf, err := proxy.ParseUpstreamsConfig(upstreams, &upstream.Options{})
if err != nil {
return fmt.Errorf("creating config: %w", err)
}
if conf == nil {
return nil
}
keys := maps.Keys(conf.DomainReservedUpstreams)
slices.Sort(keys)
var errs []error
for _, domain := range keys {
var subnet netip.Prefix
subnet, err = extractARPASubnet(domain)
if err != nil {
errs = append(errs, err)
continue
}
if !privateNets.Contains(subnet.Addr()) {
errs = append(
errs,
fmt.Errorf("arpa domain %q should point to a locally-served network", domain),
)
}
}
return errors.Annotate(errors.Join(errs...), "checking domain-specific upstreams: %w")
}

View File

@@ -559,6 +559,8 @@ type Result struct {
Reason Reason `json:",omitempty"` Reason Reason `json:",omitempty"`
// IsFiltered is true if the request is filtered. // IsFiltered is true if the request is filtered.
//
// TODO(d.kolyshev): Get rid of this flag.
IsFiltered bool `json:",omitempty"` IsFiltered bool `json:",omitempty"`
} }

View File

@@ -1,7 +1,6 @@
package rulelist_test package rulelist_test
import ( import (
"context"
"net/http" "net/http"
"testing" "testing"
@@ -28,14 +27,12 @@ func TestEngine_Refresh(t *testing.T) {
require.NotNil(t, eng) require.NotNil(t, eng)
testutil.CleanupAndRequireSuccess(t, eng.Close) testutil.CleanupAndRequireSuccess(t, eng.Close)
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
t.Cleanup(cancel)
buf := make([]byte, rulelist.DefaultRuleBufSize) buf := make([]byte, rulelist.DefaultRuleBufSize)
cli := &http.Client{ cli := &http.Client{
Timeout: testTimeout, Timeout: testTimeout,
} }
ctx := testutil.ContextWithTimeout(t, testTimeout)
err := eng.Refresh(ctx, buf, cli, cacheDir, rulelist.DefaultMaxRuleListSize) err := eng.Refresh(ctx, buf, cli, cacheDir, rulelist.DefaultMaxRuleListSize)
require.NoError(t, err) require.NoError(t, err)

View File

@@ -1,7 +1,6 @@
package rulelist_test package rulelist_test
import ( import (
"context"
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
@@ -67,14 +66,12 @@ func TestFilter_Refresh(t *testing.T) {
require.NotNil(t, f) require.NotNil(t, f)
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
t.Cleanup(cancel)
buf := make([]byte, rulelist.DefaultRuleBufSize) buf := make([]byte, rulelist.DefaultRuleBufSize)
cli := &http.Client{ cli := &http.Client{
Timeout: testTimeout, Timeout: testTimeout,
} }
ctx := testutil.ContextWithTimeout(t, testTimeout)
res, err := f.Refresh(ctx, buf, cli, cacheDir, rulelist.DefaultMaxRuleListSize) res, err := f.Refresh(ctx, buf, cli, cacheDir, rulelist.DefaultMaxRuleListSize)
require.NoError(t, err) require.NoError(t, err)

View File

@@ -1,7 +1,5 @@
package filtering package filtering
import "github.com/miekg/dns"
// SafeSearch interface describes a service for search engines hosts rewrites. // SafeSearch interface describes a service for search engines hosts rewrites.
type SafeSearch interface { type SafeSearch interface {
// CheckHost checks host with safe search filter. CheckHost must be safe // CheckHost checks host with safe search filter. CheckHost must be safe
@@ -16,9 +14,6 @@ type SafeSearch interface {
// SafeSearchConfig is a struct with safe search related settings. // SafeSearchConfig is a struct with safe search related settings.
type SafeSearchConfig struct { type SafeSearchConfig struct {
// CustomResolver is the resolver used by safe search.
CustomResolver Resolver `yaml:"-" json:"-"`
// Enabled indicates if safe search is enabled entirely. // Enabled indicates if safe search is enabled entirely.
Enabled bool `yaml:"enabled" json:"enabled"` Enabled bool `yaml:"enabled" json:"enabled"`
@@ -40,13 +35,7 @@ func (d *DNSFilter) checkSafeSearch(
qtype uint16, qtype uint16,
setts *Settings, setts *Settings,
) (res Result, err error) { ) (res Result, err error) {
if !setts.ProtectionEnabled || if d.safeSearch == nil || !setts.ProtectionEnabled || !setts.SafeSearchEnabled {
!setts.SafeSearchEnabled ||
(qtype != dns.TypeA && qtype != dns.TypeAAAA) {
return Result{}, nil
}
if d.safeSearch == nil {
return Result{}, nil return Result{}, nil
} }

View File

@@ -3,11 +3,9 @@ package safesearch
import ( import (
"bytes" "bytes"
"context"
"encoding/binary" "encoding/binary"
"encoding/gob" "encoding/gob"
"fmt" "fmt"
"net"
"net/netip" "net/netip"
"strings" "strings"
"sync" "sync"
@@ -67,7 +65,6 @@ type Default struct {
engine *urlfilter.DNSEngine engine *urlfilter.DNSEngine
cache cache.Cache cache cache.Cache
resolver filtering.Resolver
logPrefix string logPrefix string
cacheTTL time.Duration cacheTTL time.Duration
} }
@@ -80,11 +77,6 @@ func NewDefault(
cacheSize uint, cacheSize uint,
cacheTTL time.Duration, cacheTTL time.Duration,
) (ss *Default, err error) { ) (ss *Default, err error) {
var resolver filtering.Resolver = net.DefaultResolver
if conf.CustomResolver != nil {
resolver = conf.CustomResolver
}
ss = &Default{ ss = &Default{
mu: &sync.RWMutex{}, mu: &sync.RWMutex{},
@@ -92,7 +84,6 @@ func NewDefault(
EnableLRU: true, EnableLRU: true,
MaxSize: cacheSize, MaxSize: cacheSize,
}), }),
resolver: resolver,
// Use %s, because the client safe-search names already contain double // Use %s, because the client safe-search names already contain double
// quotes. // quotes.
logPrefix: fmt.Sprintf("safesearch %s: ", name), logPrefix: fmt.Sprintf("safesearch %s: ", name),
@@ -170,8 +161,11 @@ func (ss *Default) CheckHost(host string, qtype rules.RRType) (res filtering.Res
ss.log(log.DEBUG, "lookup for %q finished in %s", host, time.Since(start)) ss.log(log.DEBUG, "lookup for %q finished in %s", host, time.Since(start))
}() }()
if qtype != dns.TypeA && qtype != dns.TypeAAAA { switch qtype {
return filtering.Result{}, fmt.Errorf("unsupported question type %s", dns.Type(qtype)) case dns.TypeA, dns.TypeAAAA, dns.TypeHTTPS:
// Go on.
default:
return filtering.Result{}, nil
} }
// Check cache. Return cached result if it was found // Check cache. Return cached result if it was found
@@ -195,6 +189,9 @@ func (ss *Default) CheckHost(host string, qtype rules.RRType) (res filtering.Res
} }
res = *fltRes res = *fltRes
// TODO(a.garipov): Consider switch back to resolving CNAME records IPs and
// saving results to cache.
ss.setCacheResult(host, qtype, res) ss.setCacheResult(host, qtype, res)
return res, nil return res, nil
@@ -223,20 +220,13 @@ func (ss *Default) searchHost(host string, qtype rules.RRType) (res *rules.DNSRe
} }
// newResult creates Result object from rewrite rule. qtype must be either // newResult creates Result object from rewrite rule. qtype must be either
// [dns.TypeA] or [dns.TypeAAAA]. If err is nil, res is never nil, so that the // [dns.TypeA] or [dns.TypeAAAA], or [dns.TypeHTTPS]. If err is nil, res is
// empty result is converted into a NODATA response. // never nil, so that the empty result is converted into a NODATA response.
//
// TODO(a.garipov): Use the main rewrite result mechanism used in
// [dnsforward.Server.filterDNSRequest]. Now we resolve IPs for CNAME to save
// them in the safe search cache.
func (ss *Default) newResult( func (ss *Default) newResult(
rewrite *rules.DNSRewrite, rewrite *rules.DNSRewrite,
qtype rules.RRType, qtype rules.RRType,
) (res *filtering.Result, err error) { ) (res *filtering.Result, err error) {
res = &filtering.Result{ res = &filtering.Result{
Rules: []*filtering.ResultRule{{
FilterListID: rulelist.URLFilterIDSafeSearch,
}},
Reason: filtering.FilteredSafeSearch, Reason: filtering.FilteredSafeSearch,
IsFiltered: true, IsFiltered: true,
} }
@@ -247,69 +237,19 @@ func (ss *Default) newResult(
return nil, fmt.Errorf("expected ip rewrite value, got %T(%[1]v)", rewrite.Value) return nil, fmt.Errorf("expected ip rewrite value, got %T(%[1]v)", rewrite.Value)
} }
res.Rules[0].IP = ip res.Rules = []*filtering.ResultRule{{
FilterListID: rulelist.URLFilterIDSafeSearch,
IP: ip,
}}
return res, nil return res, nil
} }
host := rewrite.NewCNAME res.CanonName = rewrite.NewCNAME
if host == "" {
return res, nil
}
res.CanonName = host
ss.log(log.DEBUG, "resolving %q", host)
ips, err := ss.resolver.LookupIP(context.Background(), qtypeToProto(qtype), host)
if err != nil {
return nil, fmt.Errorf("resolving cname: %w", err)
}
ss.log(log.DEBUG, "resolved %s", ips)
for _, ip := range ips {
// TODO(a.garipov): Remove this filtering once the resolver we use
// actually learns about network.
addr := fitToProto(ip, qtype)
if addr == (netip.Addr{}) {
continue
}
// TODO(e.burkov): Rules[0]?
res.Rules[0].IP = addr
}
return res, nil return res, nil
} }
// qtypeToProto returns "ip4" for [dns.TypeA] and "ip6" for [dns.TypeAAAA].
// It panics for other types.
func qtypeToProto(qtype rules.RRType) (proto string) {
switch qtype {
case dns.TypeA:
return "ip4"
case dns.TypeAAAA:
return "ip6"
default:
panic(fmt.Errorf("safesearch: unsupported question type %s", dns.Type(qtype)))
}
}
// fitToProto returns a non-nil IP address if ip is the correct protocol version
// for qtype. qtype is expected to be either [dns.TypeA] or [dns.TypeAAAA].
func fitToProto(ip net.IP, qtype rules.RRType) (res netip.Addr) {
if ip4 := ip.To4(); qtype == dns.TypeA {
if ip4 != nil {
return netip.AddrFrom4([4]byte(ip4))
}
} else if ip = ip.To16(); ip != nil && qtype == dns.TypeAAAA {
return netip.AddrFrom16([16]byte(ip))
}
return netip.Addr{}
}
// setCacheResult stores data in cache for host. qtype is expected to be either // setCacheResult stores data in cache for host. qtype is expected to be either
// [dns.TypeA] or [dns.TypeAAAA]. // [dns.TypeA] or [dns.TypeAAAA].
func (ss *Default) setCacheResult(host string, qtype rules.RRType, res filtering.Result) { func (ss *Default) setCacheResult(host string, qtype rules.RRType, res filtering.Result) {

View File

@@ -1,13 +1,10 @@
package safesearch package safesearch
import ( import (
"context"
"net"
"net/netip" "net/netip"
"testing" "testing"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/urlfilter/rules" "github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns" "github.com/miekg/dns"
@@ -79,47 +76,6 @@ func TestSafeSearchCacheYandex(t *testing.T) {
assert.Equal(t, cachedValue.Rules[0].IP, yandexIP) assert.Equal(t, cachedValue.Rules[0].IP, yandexIP)
} }
func TestSafeSearchCacheGoogle(t *testing.T) {
const domain = "www.google.ru"
ss := newForTest(t, filtering.SafeSearchConfig{Enabled: false})
res, err := ss.CheckHost(domain, testQType)
require.NoError(t, err)
assert.False(t, res.IsFiltered)
assert.Empty(t, res.Rules)
resolver := &aghtest.Resolver{
OnLookupIP: func(_ context.Context, _, host string) (ips []net.IP, err error) {
ip4, ip6 := aghtest.HostToIPs(host)
return []net.IP{ip4.AsSlice(), ip6.AsSlice()}, nil
},
}
ss = newForTest(t, defaultSafeSearchConf)
ss.resolver = resolver
// Lookup for safesearch domain.
rewrite := ss.searchHost(domain, testQType)
wantIP, _ := aghtest.HostToIPs(rewrite.NewCNAME)
res, err = ss.CheckHost(domain, testQType)
require.NoError(t, err)
require.Len(t, res.Rules, 1)
assert.Equal(t, wantIP, res.Rules[0].IP)
// Check cache.
cachedValue, isFound := ss.getCachedResult(domain, testQType)
require.True(t, isFound)
require.Len(t, cachedValue.Rules, 1)
assert.Equal(t, wantIP, cachedValue.Rules[0].IP)
}
const googleHost = "www.google.com" const googleHost = "www.google.com"
var dnsRewriteSink *rules.DNSRewrite var dnsRewriteSink *rules.DNSRewrite

View File

@@ -7,7 +7,6 @@ import (
"testing" "testing"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist" "github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch" "github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
@@ -31,8 +30,6 @@ const (
// testConf is the default safe search configuration for tests. // testConf is the default safe search configuration for tests.
var testConf = filtering.SafeSearchConfig{ var testConf = filtering.SafeSearchConfig{
CustomResolver: nil,
Enabled: true, Enabled: true,
Bing: true, Bing: true,
@@ -52,61 +49,60 @@ func TestDefault_CheckHost_yandex(t *testing.T) {
ss, err := safesearch.NewDefault(conf, "", testCacheSize, testCacheTTL) ss, err := safesearch.NewDefault(conf, "", testCacheSize, testCacheTTL)
require.NoError(t, err) require.NoError(t, err)
// Check host for each domain. hosts := []string{
for _, host := range []string{
"yandex.ru", "yandex.ru",
"yAndeX.ru", "yAndeX.ru",
"YANdex.COM", "YANdex.COM",
"yandex.by", "yandex.by",
"yandex.kz", "yandex.kz",
"www.yandex.com", "www.yandex.com",
} {
var res filtering.Result
res, err = ss.CheckHost(host, testQType)
require.NoError(t, err)
assert.True(t, res.IsFiltered)
require.Len(t, res.Rules, 1)
assert.Equal(t, yandexIP, res.Rules[0].IP)
assert.Equal(t, rulelist.URLFilterIDSafeSearch, res.Rules[0].FilterListID)
} }
}
func TestDefault_CheckHost_yandexAAAA(t *testing.T) { testCases := []struct {
conf := testConf want netip.Addr
ss, err := safesearch.NewDefault(conf, "", testCacheSize, testCacheTTL) name string
require.NoError(t, err) qt uint16
}{{
want: yandexIP,
name: "a",
qt: dns.TypeA,
}, {
want: netip.Addr{},
name: "aaaa",
qt: dns.TypeAAAA,
}, {
want: netip.Addr{},
name: "https",
qt: dns.TypeHTTPS,
}}
res, err := ss.CheckHost("www.yandex.ru", dns.TypeAAAA) for _, tc := range testCases {
require.NoError(t, err) t.Run(tc.name, func(t *testing.T) {
for _, host := range hosts {
// Check host for each domain.
var res filtering.Result
res, err = ss.CheckHost(host, tc.qt)
require.NoError(t, err)
assert.True(t, res.IsFiltered) assert.True(t, res.IsFiltered)
assert.Equal(t, filtering.FilteredSafeSearch, res.Reason)
// TODO(a.garipov): Currently, the safe-search filter returns a single rule if tc.want == (netip.Addr{}) {
// with a nil IP address. This isn't really necessary and should be changed assert.Empty(t, res.Rules)
// once the TODO in [safesearch.Default.newResult] is resolved. } else {
require.Len(t, res.Rules, 1) require.Len(t, res.Rules, 1)
assert.Empty(t, res.Rules[0].IP) rule := res.Rules[0]
assert.Equal(t, rulelist.URLFilterIDSafeSearch, res.Rules[0].FilterListID) assert.Equal(t, tc.want, rule.IP)
assert.Equal(t, rulelist.URLFilterIDSafeSearch, rule.FilterListID)
}
}
})
}
} }
func TestDefault_CheckHost_google(t *testing.T) { func TestDefault_CheckHost_google(t *testing.T) {
resolver := &aghtest.Resolver{ ss, err := safesearch.NewDefault(testConf, "", testCacheSize, testCacheTTL)
OnLookupIP: func(_ context.Context, _, host string) (ips []net.IP, err error) {
ip4, ip6 := aghtest.HostToIPs(host)
return []net.IP{ip4.AsSlice(), ip6.AsSlice()}, nil
},
}
wantIP, _ := aghtest.HostToIPs("forcesafesearch.google.com")
conf := testConf
conf.CustomResolver = resolver
ss, err := safesearch.NewDefault(conf, "", testCacheSize, testCacheTTL)
require.NoError(t, err) require.NoError(t, err)
// Check host for each domain. // Check host for each domain.
@@ -125,11 +121,9 @@ func TestDefault_CheckHost_google(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.True(t, res.IsFiltered) assert.True(t, res.IsFiltered)
assert.Equal(t, filtering.FilteredSafeSearch, res.Reason)
require.Len(t, res.Rules, 1) assert.Equal(t, "forcesafesearch.google.com", res.CanonName)
assert.Empty(t, res.Rules)
assert.Equal(t, wantIP, res.Rules[0].IP)
assert.Equal(t, rulelist.URLFilterIDSafeSearch, res.Rules[0].FilterListID)
}) })
} }
} }
@@ -154,17 +148,7 @@ func (r *testResolver) LookupIP(
} }
func TestDefault_CheckHost_duckduckgoAAAA(t *testing.T) { func TestDefault_CheckHost_duckduckgoAAAA(t *testing.T) {
conf := testConf ss, err := safesearch.NewDefault(testConf, "", testCacheSize, testCacheTTL)
conf.CustomResolver = &testResolver{
OnLookupIP: func(_ context.Context, network, host string) (ips []net.IP, err error) {
assert.Equal(t, "ip6", network)
assert.Equal(t, "safe.duckduckgo.com", host)
return nil, nil
},
}
ss, err := safesearch.NewDefault(conf, "", testCacheSize, testCacheTTL)
require.NoError(t, err) require.NoError(t, err)
// The DuckDuckGo safe-search addresses are resolved through CNAMEs, but // The DuckDuckGo safe-search addresses are resolved through CNAMEs, but
@@ -174,14 +158,9 @@ func TestDefault_CheckHost_duckduckgoAAAA(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.True(t, res.IsFiltered) assert.True(t, res.IsFiltered)
assert.Equal(t, filtering.FilteredSafeSearch, res.Reason)
// TODO(a.garipov): Currently, the safe-search filter returns a single rule assert.Equal(t, "safe.duckduckgo.com", res.CanonName)
// with a nil IP address. This isn't really necessary and should be changed assert.Empty(t, res.Rules)
// once the TODO in [safesearch.Default.newResult] is resolved.
require.Len(t, res.Rules, 1)
assert.Empty(t, res.Rules[0].IP)
assert.Equal(t, rulelist.URLFilterIDSafeSearch, res.Rules[0].FilterListID)
} }
func TestDefault_Update(t *testing.T) { func TestDefault_Update(t *testing.T) {

View File

@@ -249,8 +249,6 @@ func (o *clientObject) toPersistent(
} }
if o.SafeSearchConf.Enabled { if o.SafeSearchConf.Enabled {
o.SafeSearchConf.CustomResolver = safeSearchResolver{}
err = cli.SetSafeSearch( err = cli.SetSafeSearch(
o.SafeSearchConf, o.SafeSearchConf,
filteringConf.SafeSearchCacheSize, filteringConf.SafeSearchCacheSize,
@@ -414,7 +412,11 @@ func (clients *clientsContainer) clientOrArtificial(
}() }()
cli, ok := clients.find(id) cli, ok := clients.find(id)
if ok { if !ok {
cli = clients.clientIndex.FindByIPWithoutZone(ip)
}
if cli != nil {
return &querylog.Client{ return &querylog.Client{
Name: cli.Name, Name: cli.Name,
IgnoreQueryLog: cli.IgnoreQueryLog, IgnoreQueryLog: cli.IgnoreQueryLog,

View File

@@ -203,15 +203,24 @@ type dnsConfig struct {
// resolver should be used. // resolver should be used.
PrivateNets []netutil.Prefix `yaml:"private_networks"` PrivateNets []netutil.Prefix `yaml:"private_networks"`
// UsePrivateRDNS defines if the PTR requests for unknown addresses from // UsePrivateRDNS enables resolving requests containing a private IP address
// locally-served networks should be resolved via private PTR resolvers. // using private reverse DNS resolvers. See PrivateRDNSResolvers.
//
// TODO(e.burkov): Rename in YAML.
UsePrivateRDNS bool `yaml:"use_private_ptr_resolvers"` UsePrivateRDNS bool `yaml:"use_private_ptr_resolvers"`
// LocalPTRResolvers is the slice of addresses to be used as upstreams // PrivateRDNSResolvers is the slice of addresses to be used as upstreams
// for PTR queries for locally-served networks. // for private requests. It's only used for PTR, SOA, and NS queries,
LocalPTRResolvers []string `yaml:"local_ptr_upstreams"` // containing an ARPA subdomain, came from the the client with private
// address. The address considered private according to PrivateNets.
//
// If empty, the OS-provided resolvers are used for private requests.
PrivateRDNSResolvers []string `yaml:"local_ptr_upstreams"`
// UseDNS64 defines if DNS64 should be used for incoming requests. // UseDNS64 defines if DNS64 should be used for incoming requests. Requests
// of type PTR for addresses within the configured prefixes will be resolved
// via [PrivateRDNSResolvers], so those should be valid and UsePrivateRDNS
// be set to true.
UseDNS64 bool `yaml:"use_dns64"` UseDNS64 bool `yaml:"use_dns64"`
// DNS64Prefixes is the list of NAT64 prefixes to be used for DNS64. // DNS64Prefixes is the list of NAT64 prefixes to be used for DNS64.
@@ -658,7 +667,7 @@ func (c *configuration) write() (err error) {
dns := &config.DNS dns := &config.DNS
dns.Config = c dns.Config = c
dns.LocalPTRResolvers = s.LocalPTRResolvers() dns.PrivateRDNSResolvers = s.LocalPTRResolvers()
addrProcConf := s.AddrProcConfig() addrProcConf := s.AddrProcConfig()
config.Clients.Sources.RDNS = addrProcConf.UseRDNS config.Clients.Sources.RDNS = addrProcConf.UseRDNS

View File

@@ -1,7 +1,6 @@
package home package home
import ( import (
"context"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
@@ -18,7 +17,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/querylog" "github.com/AdguardTeam/AdGuardHome/internal/querylog"
"github.com/AdguardTeam/AdGuardHome/internal/stats" "github.com/AdguardTeam/AdGuardHome/internal/stats"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
@@ -157,14 +155,12 @@ func initDNSServer(
return fmt.Errorf("newServerConfig: %w", err) return fmt.Errorf("newServerConfig: %w", err)
} }
// Try to prepare the server with disabled private RDNS resolution if it
// failed to prepare as is. See TODO on [ErrBadPrivateRDNSUpstreams].
err = Context.dnsServer.Prepare(dnsConf) err = Context.dnsServer.Prepare(dnsConf)
if privRDNSErr := (&dnsforward.PrivateRDNSError{}); errors.As(err, &privRDNSErr) {
log.Info("WARNING: %s; trying to disable private RDNS resolution", err)
// TODO(e.burkov): Recreate the server with private RDNS disabled. This
// should go away once the private RDNS resolution is moved to the proxy.
var locResErr *dnsforward.LocalResolversError
if errors.As(err, &locResErr) && errors.Is(locResErr.Err, upstream.ErrNoUpstreams) {
log.Info("WARNING: no local resolvers configured while private RDNS " +
"resolution enabled, trying to disable")
dnsConf.UsePrivateRDNS = false dnsConf.UsePrivateRDNS = false
err = Context.dnsServer.Prepare(dnsConf) err = Context.dnsServer.Prepare(dnsConf)
} }
@@ -245,7 +241,7 @@ func newServerConfig(
TLSv12Roots: Context.tlsRoots, TLSv12Roots: Context.tlsRoots,
ConfigModified: onConfigModified, ConfigModified: onConfigModified,
HTTPRegister: httpReg, HTTPRegister: httpReg,
LocalPTRResolvers: dnsConf.LocalPTRResolvers, LocalPTRResolvers: dnsConf.PrivateRDNSResolvers,
UseDNS64: dnsConf.UseDNS64, UseDNS64: dnsConf.UseDNS64,
DNS64Prefixes: dnsConf.DNS64Prefixes, DNS64Prefixes: dnsConf.DNS64Prefixes,
UsePrivateRDNS: dnsConf.UsePrivateRDNS, UsePrivateRDNS: dnsConf.UsePrivateRDNS,
@@ -531,36 +527,6 @@ func closeDNSServer() {
log.Debug("all dns modules are closed") log.Debug("all dns modules are closed")
} }
// safeSearchResolver is a [filtering.Resolver] implementation used for safe
// search.
type safeSearchResolver struct{}
// type check
var _ filtering.Resolver = safeSearchResolver{}
// LookupIP implements [filtering.Resolver] interface for safeSearchResolver.
// It returns the slice of net.Addr with IPv4 and IPv6 instances.
func (r safeSearchResolver) LookupIP(
ctx context.Context,
network string,
host string,
) (ips []net.IP, err error) {
addrs, err := Context.dnsServer.Resolve(ctx, network, host)
if err != nil {
return nil, err
}
if len(addrs) == 0 {
return nil, fmt.Errorf("couldn't lookup host: %s", host)
}
for _, a := range addrs {
ips = append(ips, a.AsSlice())
}
return ips, nil
}
// checkStatsAndQuerylogDirs checks and returns directory paths to store // checkStatsAndQuerylogDirs checks and returns directory paths to store
// statistics and query log. // statistics and query log.
func checkStatsAndQuerylogDirs( func checkStatsAndQuerylogDirs(

View File

@@ -439,7 +439,6 @@ func setupDNSFilteringConf(conf *filtering.Config) (err error) {
conf.ParentalBlockHost = host conf.ParentalBlockHost = host
} }
conf.SafeSearchConf.CustomResolver = safeSearchResolver{}
conf.SafeSearch, err = safesearch.NewDefault( conf.SafeSearch, err = safesearch.NewDefault(
conf.SafeSearchConf, conf.SafeSearchConf,
"default", "default",

View File

@@ -649,11 +649,6 @@ status() {
// freeBSDScript is the source of the daemon script for FreeBSD. Keep as close // freeBSDScript is the source of the daemon script for FreeBSD. Keep as close
// as possible to the https://github.com/kardianos/service/blob/18c957a3dc1120a2efe77beb401d476bade9e577/service_freebsd.go#L204. // as possible to the https://github.com/kardianos/service/blob/18c957a3dc1120a2efe77beb401d476bade9e577/service_freebsd.go#L204.
//
// TODO(a.garipov): Don't use .WorkingDirectory here. There are currently no
// guarantees that it will actually be the required directory.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/2614.
const freeBSDScript = `#!/bin/sh const freeBSDScript = `#!/bin/sh
# PROVIDE: {{.Name}} # PROVIDE: {{.Name}}
# REQUIRE: networking # REQUIRE: networking
@@ -667,7 +662,9 @@ name="{{.Name}}"
pidfile_child="/var/run/${name}.pid" pidfile_child="/var/run/${name}.pid"
pidfile="/var/run/${name}_daemon.pid" pidfile="/var/run/${name}_daemon.pid"
command="/usr/sbin/daemon" command="/usr/sbin/daemon"
command_args="-P ${pidfile} -p ${pidfile_child} -T ${name} -r {{.WorkingDirectory}}/{{.Name}}" daemon_args="-P ${pidfile} -p ${pidfile_child} -r -t ${name}"
command_args="${daemon_args} {{.Path}}{{range .Arguments}} {{.}}{{end}}"
run_rc_command "$1" run_rc_command "$1"
` `

View File

@@ -8,6 +8,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/next/agh" "github.com/AdguardTeam/AdGuardHome/internal/next/agh"
"github.com/AdguardTeam/AdGuardHome/internal/next/configmgr" "github.com/AdguardTeam/AdGuardHome/internal/next/configmgr"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/osutil"
"github.com/google/renameio/v2/maybe" "github.com/google/renameio/v2/maybe"
) )
@@ -38,7 +39,7 @@ func (h *signalHandler) handle() {
if aghos.IsReconfigureSignal(sig) { if aghos.IsReconfigureSignal(sig) {
h.reconfigure() h.reconfigure()
} else if aghos.IsShutdownSignal(sig) { } else if osutil.IsShutdownSignal(sig) {
status := h.shutdown() status := h.shutdown()
h.removePID() h.removePID()
@@ -122,7 +123,8 @@ func newSignalHandler(
services: svcs, services: svcs,
} }
aghos.NotifyShutdownSignal(h.signal) notifier := osutil.DefaultSignalNotifier{}
osutil.NotifyShutdownSignal(notifier, h.signal)
aghos.NotifyReconfigureSignal(h.signal) aghos.NotifyReconfigureSignal(h.signal)
return h return h

View File

@@ -1,7 +1,6 @@
package dnssvc_test package dnssvc_test
import ( import (
"context"
"net/netip" "net/netip"
"testing" "testing"
"time" "time"
@@ -94,10 +93,8 @@ func TestService(t *testing.T) {
}}, }},
} }
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
cli := &dns.Client{} cli := &dns.Client{}
ctx := testutil.ContextWithTimeout(t, testTimeout)
var resp *dns.Msg var resp *dns.Msg
require.Eventually(t, func() (ok bool) { require.Eventually(t, func() (ok bool) {
@@ -110,10 +107,8 @@ func TestService(t *testing.T) {
assert.NotNil(t, resp) assert.NotNil(t, resp)
}) })
ctx, cancel := context.WithTimeout(context.Background(), testTimeout) err = svc.Shutdown(testutil.ContextWithTimeout(t, testTimeout))
defer cancel()
err = svc.Shutdown(ctx)
require.NoError(t, err) require.NoError(t, err)
err = upstreamSrv.Shutdown() err = upstreamSrv.Shutdown()

View File

@@ -109,12 +109,8 @@ func newTestServer(
err = svc.Start() err = svc.Start()
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(func() { testutil.CleanupAndRequireSuccess(t, func() (err error) {
ctx, cancel := context.WithTimeout(context.Background(), testTimeout) return svc.Shutdown(testutil.ContextWithTimeout(t, testTimeout))
t.Cleanup(cancel)
err = svc.Shutdown(ctx)
require.NoError(t, err)
}) })
c = svc.Config() c = svc.Config()

View File

@@ -31,6 +31,7 @@ type logEntry struct {
Answer []byte `json:",omitempty"` Answer []byte `json:",omitempty"`
OrigAnswer []byte `json:",omitempty"` OrigAnswer []byte `json:",omitempty"`
// TODO(s.chzhen): Use netip.Addr.
IP net.IP `json:"IP"` IP net.IP `json:"IP"`
Result filtering.Result Result filtering.Result