Compare commits
247 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c7d8b9ede1 | ||
|
|
5c6bb33e3a | ||
|
|
158d4f0249 | ||
|
|
f73717ec08 | ||
|
|
1807198a9b | ||
|
|
1ccf8fe116 | ||
|
|
d22f0eefe2 | ||
|
|
344c66f7ab | ||
|
|
83be002b41 | ||
|
|
9945cd3991 | ||
|
|
667263a3a8 | ||
|
|
6318fc424b | ||
|
|
e32a37a747 | ||
|
|
7805a71332 | ||
|
|
6fb2aee210 | ||
|
|
ce9bb588ed | ||
|
|
55fb914537 | ||
|
|
6f7bfd6c9c | ||
|
|
fbc0d981ba | ||
|
|
48d1c673a9 | ||
|
|
889a0eb8b3 | ||
|
|
b01c10b73e | ||
|
|
f6ad64bf69 | ||
|
|
a5e8443735 | ||
|
|
2860929a47 | ||
|
|
ecdac56616 | ||
|
|
25918e56fa | ||
|
|
df91f016f2 | ||
|
|
f7d259f653 | ||
|
|
82ab4328d4 | ||
|
|
b21e19a223 | ||
|
|
c6aed4eb57 | ||
|
|
760d466b38 | ||
|
|
258eecc55b | ||
|
|
7b93f5d7cf | ||
|
|
3be7676970 | ||
|
|
48ee2f8a42 | ||
|
|
ec83d0eb86 | ||
|
|
19347d263a | ||
|
|
b22b16d98c | ||
|
|
cadb765b7d | ||
|
|
1116da8b83 | ||
|
|
c65700923a | ||
|
|
7030c7c24c | ||
|
|
09718a2170 | ||
|
|
77cda2c2c5 | ||
|
|
d9c57cdd9a | ||
|
|
0dad53b5f7 | ||
|
|
9a7315dbea | ||
|
|
a21558f418 | ||
|
|
4f928be393 | ||
|
|
f543b47261 | ||
|
|
66b831072c | ||
|
|
80eb339896 | ||
|
|
c69639c013 | ||
|
|
5f6fbe8e08 | ||
|
|
b40bbf0260 | ||
|
|
a11c8e91ab | ||
|
|
618d0e596c | ||
|
|
fde9ea5cb1 | ||
|
|
03d9803238 | ||
|
|
bd64b8b014 | ||
|
|
67fe064fcf | ||
|
|
471668d19a | ||
|
|
42762dfe54 | ||
|
|
c9314610d4 | ||
|
|
16755c37d8 | ||
|
|
73fcbd6ea2 | ||
|
|
30244f361f | ||
|
|
083991fb21 | ||
|
|
e3200d5046 | ||
|
|
21f6ed36fe | ||
|
|
77d04d44eb | ||
|
|
b34d119255 | ||
|
|
63bd71a10c | ||
|
|
faf2b32389 | ||
|
|
d23da1b757 | ||
|
|
beb8e36eee | ||
|
|
fe70161c01 | ||
|
|
39fa4b1f8e | ||
|
|
c7a8883201 | ||
|
|
3fd467413c | ||
|
|
9728dd856f | ||
|
|
ecadf78d60 | ||
|
|
eba4612d72 | ||
|
|
9200163f85 | ||
|
|
3c17853344 | ||
|
|
993a3fc42c | ||
|
|
7bb9b2416b | ||
|
|
2de321ce24 | ||
|
|
30b2b85ff1 | ||
|
|
6ea4788f56 | ||
|
|
3c52a021b9 | ||
|
|
0ceea9af5f | ||
|
|
39b404be19 | ||
|
|
56dc3eab02 | ||
|
|
554a38eeb1 | ||
|
|
c8d3afe869 | ||
|
|
44222c604c | ||
|
|
cbf221585e | ||
|
|
48322f6d0d | ||
|
|
d5a213c639 | ||
|
|
8166c4bc33 | ||
|
|
133cd9ef6b | ||
|
|
11146f73ed | ||
|
|
1beb18db47 | ||
|
|
f7bc2273a7 | ||
|
|
d1e735a003 | ||
|
|
af4ff5c748 | ||
|
|
fc951c1226 | ||
|
|
f81fd42472 | ||
|
|
1029ea5966 | ||
|
|
c0abdb4bc7 | ||
|
|
6681178ad3 | ||
|
|
e73605c4c5 | ||
|
|
c7017d49aa | ||
|
|
191d3bde49 | ||
|
|
18876a8e5c | ||
|
|
aa4a0d9880 | ||
|
|
d03d731d65 | ||
|
|
33b58a42fe | ||
|
|
2e9e708647 | ||
|
|
8ad22841ab | ||
|
|
32cf02264c | ||
|
|
0e8445b38f | ||
|
|
cb27ecd6c0 | ||
|
|
535220b3df | ||
|
|
7b9cfa94f8 | ||
|
|
b3f2e88e9c | ||
|
|
aa7a8d45e4 | ||
|
|
49cdef3d6a | ||
|
|
fecd146552 | ||
|
|
b01efd8c98 | ||
|
|
bd4dfb261c | ||
|
|
e754e4d2f6 | ||
|
|
b220e35c99 | ||
|
|
4f5131f423 | ||
|
|
dcb043df5f | ||
|
|
86e5756262 | ||
|
|
ba0cf5739b | ||
|
|
c4a13b92d2 | ||
|
|
723279121a | ||
|
|
3ad7649f7d | ||
|
|
2898a49d86 | ||
|
|
1547f9d35e | ||
|
|
adadd55c42 | ||
|
|
33b0225aa4 | ||
|
|
97d4058d80 | ||
|
|
86207e719d | ||
|
|
113f94ff46 | ||
|
|
5673deb391 | ||
|
|
3548a393ed | ||
|
|
254515f274 | ||
|
|
bccbecc6ea | ||
|
|
66f53803af | ||
|
|
faef005ce7 | ||
|
|
941cd2a562 | ||
|
|
6a4a9a0239 | ||
|
|
b9dbe6f1b6 | ||
|
|
7fec111ef8 | ||
|
|
5e1bd99718 | ||
|
|
9d75f72ceb | ||
|
|
d98d96db1a | ||
|
|
6a0ef2df15 | ||
|
|
75c2eb4c8a | ||
|
|
d021a67d66 | ||
|
|
4ed97cab12 | ||
|
|
a38742eed7 | ||
|
|
5efa95ed26 | ||
|
|
04db7db607 | ||
|
|
d17c6c6bb3 | ||
|
|
b2052f2ef1 | ||
|
|
cddcf852c2 | ||
|
|
1def426b45 | ||
|
|
b114fd5279 | ||
|
|
d27c3284f6 | ||
|
|
ba24a26b53 | ||
|
|
3e6678b6b4 | ||
|
|
83fd6f9782 | ||
|
|
52bc1b3f10 | ||
|
|
dd2153b7ac | ||
|
|
dd96a34861 | ||
|
|
daf26ee25a | ||
|
|
7e140eaaac | ||
|
|
d07a712988 | ||
|
|
95863288bf | ||
|
|
ea12be658b | ||
|
|
faa7c9aae5 | ||
|
|
e3653e8c25 | ||
|
|
b40cb24822 | ||
|
|
74004c1aa0 | ||
|
|
3e240741f1 | ||
|
|
6cfdbef1a5 | ||
|
|
d9bde6425b | ||
|
|
e2ae9e1591 | ||
|
|
5ebcbfa9ad | ||
|
|
e276bd7a31 | ||
|
|
659b2529bf | ||
|
|
97b3ed43ab | ||
|
|
767d6d3f28 | ||
|
|
31fc9bfc52 | ||
|
|
3f06b02409 | ||
|
|
5bf958ec6b | ||
|
|
959d9ff9a0 | ||
|
|
4813b4de25 | ||
|
|
119100924c | ||
|
|
bd584de4ee | ||
|
|
ede85ab2f2 | ||
|
|
12c20288e4 | ||
|
|
5bbbf89c10 | ||
|
|
d55393ecd5 | ||
|
|
2b5927306f | ||
|
|
4f016b6ed7 | ||
|
|
3a2a6d10ec | ||
|
|
2491426b09 | ||
|
|
5ebdd1390e | ||
|
|
b7f0247575 | ||
|
|
e28186a28a | ||
|
|
de1a7ce48f | ||
|
|
48480fb33b | ||
|
|
f41332fe6b | ||
|
|
1f8b340b8f | ||
|
|
fdaf1d09d3 | ||
|
|
b9682c4f10 | ||
|
|
69dcb4effd | ||
|
|
d50fd0ba91 | ||
|
|
c2c7b4c731 | ||
|
|
952d5f3a3d | ||
|
|
3f126c9ec9 | ||
|
|
0be58ef918 | ||
|
|
8f9053e2fc | ||
|
|
68452e5330 | ||
|
|
2eacc46eaa | ||
|
|
74dcc91ea7 | ||
|
|
dd7bf61323 | ||
|
|
2819d6cace | ||
|
|
75355a6883 | ||
|
|
e9c007d56b | ||
|
|
84c9085516 | ||
|
|
9f36e57c1e | ||
|
|
7528699fc2 | ||
|
|
d280151c18 | ||
|
|
b44c755d25 | ||
|
|
e4078e87a1 | ||
|
|
be36204756 | ||
|
|
b5409d6d00 | ||
|
|
f3d6bce03e |
2
.github/workflows/build.yml
vendored
2
.github/workflows/build.yml
vendored
@@ -1,7 +1,7 @@
|
||||
'name': 'build'
|
||||
|
||||
'env':
|
||||
'GO_VERSION': '1.23.1'
|
||||
'GO_VERSION': '1.22.5'
|
||||
'NODE_VERSION': '16'
|
||||
|
||||
'on':
|
||||
|
||||
2
.github/workflows/lint.yml
vendored
2
.github/workflows/lint.yml
vendored
@@ -1,7 +1,7 @@
|
||||
'name': 'lint'
|
||||
|
||||
'env':
|
||||
'GO_VERSION': '1.23.1'
|
||||
'GO_VERSION': '1.22.5'
|
||||
|
||||
'on':
|
||||
'push':
|
||||
|
||||
28
CHANGELOG.md
28
CHANGELOG.md
@@ -27,34 +27,6 @@ See also the [v0.107.53 GitHub milestone][ms-v0.107.53].
|
||||
NOTE: Add new changes BELOW THIS COMMENT.
|
||||
-->
|
||||
|
||||
### Security
|
||||
|
||||
- Go version has been updated to prevent the possibility of exploiting the Go
|
||||
vulnerabilities fixed in [1.23.1][go-1.23.1].
|
||||
|
||||
### Added
|
||||
|
||||
- Support for 64-bit RISC-V architecture ([#5704]).
|
||||
- Ecosia search engine is now supported in safe search ([#5009]).
|
||||
|
||||
### Changed
|
||||
|
||||
- Upstream server URL domain names requirements has been relaxed and now follow
|
||||
the same rules as their domain specifications.
|
||||
|
||||
### Fixed
|
||||
|
||||
- Update Google safe search domains list ([#7155]).
|
||||
- Enforce Bing safe search from Edge sidebar ([#7154]).
|
||||
- Text overflow on the query log page ([#7119]).
|
||||
|
||||
[#5009]: https://github.com/AdguardTeam/AdGuardHome/issues/5009
|
||||
[#7119]: https://github.com/AdguardTeam/AdGuardHome/issues/7119
|
||||
[#7154]: https://github.com/AdguardTeam/AdGuardHome/pull/7154
|
||||
[#7155]: https://github.com/AdguardTeam/AdGuardHome/pull/7155
|
||||
|
||||
[go-1.23.1]: https://groups.google.com/g/golang-announce/c/K-cEzDeCtpc
|
||||
|
||||
<!--
|
||||
NOTE: Add new changes ABOVE THIS COMMENT.
|
||||
-->
|
||||
|
||||
48
Makefile
48
Makefile
@@ -8,7 +8,7 @@
|
||||
# Makefile. Bump this number every time a significant change is made to
|
||||
# this Makefile.
|
||||
#
|
||||
# AdGuard-Project-Version: 6
|
||||
# AdGuard-Project-Version: 4
|
||||
|
||||
# Don't name these macros "GO" etc., because GNU Make apparently makes
|
||||
# them exported environment variables with the literal value of
|
||||
@@ -23,13 +23,11 @@ VERBOSE.MACRO = $${VERBOSE:-0}
|
||||
CHANNEL = development
|
||||
CLIENT_DIR = client
|
||||
COMMIT = $$( git rev-parse --short HEAD )
|
||||
DEPLOY_SCRIPT_PATH = not/a/real/path
|
||||
DIST_DIR = dist
|
||||
GOAMD64 = v1
|
||||
GOPROXY = https://proxy.golang.org|direct
|
||||
GOPROXY = https://goproxy.cn|https://proxy.golang.org|direct
|
||||
GOSUMDB = sum.golang.google.cn
|
||||
GOTOOLCHAIN = go1.23.1
|
||||
GOTELEMETRY = off
|
||||
GOTOOLCHAIN = go1.22.5
|
||||
GPG_KEY = devteam@adguard.com
|
||||
GPG_KEY_PASSPHRASE = not-a-real-password
|
||||
NPM = npm
|
||||
@@ -38,7 +36,6 @@ NPM_INSTALL_FLAGS = $(NPM_FLAGS) --quiet --no-progress --ignore-engines\
|
||||
--ignore-optional --ignore-platform --ignore-scripts
|
||||
RACE = 0
|
||||
SIGN = 1
|
||||
SIGNER_API_KEY = not-a-real-key
|
||||
VERSION = v0.0.0
|
||||
YARN = yarn
|
||||
|
||||
@@ -62,28 +59,20 @@ BUILD_RELEASE_DEPS_1 = go-deps
|
||||
ENV = env\
|
||||
CHANNEL='$(CHANNEL)'\
|
||||
COMMIT='$(COMMIT)'\
|
||||
DEPLOY_SCRIPT_PATH='$(DEPLOY_SCRIPT_PATH)' \
|
||||
DIST_DIR='$(DIST_DIR)'\
|
||||
GO="$(GO.MACRO)"\
|
||||
GOAMD64='$(GOAMD64)'\
|
||||
GOAMD64="$(GOAMD64)"\
|
||||
GOPROXY='$(GOPROXY)'\
|
||||
GOSUMDB='$(GOSUMDB)'\
|
||||
GOTELEMETRY='$(GOTELEMETRY)'\
|
||||
GOTOOLCHAIN='$(GOTOOLCHAIN)'\
|
||||
GPG_KEY='$(GPG_KEY)'\
|
||||
GPG_KEY_PASSPHRASE='$(GPG_KEY_PASSPHRASE)'\
|
||||
PATH="$${PWD}/bin:$$( "$(GO.MACRO)" env GOPATH )/bin:$${PATH}"\
|
||||
RACE='$(RACE)'\
|
||||
SIGN='$(SIGN)'\
|
||||
SIGNER_API_KEY='$(SIGNER_API_KEY)' \
|
||||
NEXTAPI='$(NEXTAPI)'\
|
||||
VERBOSE="$(VERBOSE.MACRO)"\
|
||||
VERSION="$(VERSION)"\
|
||||
|
||||
# Keep the line above blank.
|
||||
|
||||
ENV_MISC = env\
|
||||
VERBOSE="$(VERBOSE.MACRO)"\
|
||||
VERSION='$(VERSION)'\
|
||||
|
||||
# Keep the line above blank.
|
||||
|
||||
@@ -112,22 +101,23 @@ js-deps: ; $(NPM) $(NPM_INSTALL_FLAGS) ci
|
||||
js-lint: ; $(NPM) $(NPM_FLAGS) run lint
|
||||
js-test: ; $(NPM) $(NPM_FLAGS) run test
|
||||
|
||||
go-bench: ; $(ENV) "$(SHELL)" ./scripts/make/go-bench.sh
|
||||
go-build: ; $(ENV) "$(SHELL)" ./scripts/make/go-build.sh
|
||||
go-deps: ; $(ENV) "$(SHELL)" ./scripts/make/go-deps.sh
|
||||
go-env: ; $(ENV) "$(GO.MACRO)" env
|
||||
go-fuzz: ; $(ENV) "$(SHELL)" ./scripts/make/go-fuzz.sh
|
||||
go-lint: ; $(ENV) "$(SHELL)" ./scripts/make/go-lint.sh
|
||||
go-bench: ; $(ENV) "$(SHELL)" ./scripts/make/go-bench.sh
|
||||
go-build: ; $(ENV) "$(SHELL)" ./scripts/make/go-build.sh
|
||||
go-deps: ; $(ENV) "$(SHELL)" ./scripts/make/go-deps.sh
|
||||
go-fuzz: ; $(ENV) "$(SHELL)" ./scripts/make/go-fuzz.sh
|
||||
go-lint: ; $(ENV) "$(SHELL)" ./scripts/make/go-lint.sh
|
||||
go-tools: ; $(ENV) "$(SHELL)" ./scripts/make/go-tools.sh
|
||||
|
||||
# TODO(a.garipov): Think about making RACE='1' the default for all
|
||||
# targets.
|
||||
go-test: ; $(ENV) RACE='1' "$(SHELL)" ./scripts/make/go-test.sh
|
||||
go-tools: ; $(ENV) "$(SHELL)" ./scripts/make/go-tools.sh
|
||||
go-upd-tools: ; $(ENV) "$(SHELL)" ./scripts/make/go-upd-tools.sh
|
||||
go-test: ; $(ENV) RACE='1' "$(SHELL)" ./scripts/make/go-test.sh
|
||||
|
||||
go-upd-tools: ; $(ENV) "$(SHELL)" ./scripts/make/go-upd-tools.sh
|
||||
|
||||
go-check: go-tools go-lint go-test
|
||||
|
||||
# A quick check to make sure that all operating systems relevant to the
|
||||
# development of the project can be typechecked and built successfully.
|
||||
# A quick check to make sure that all supported operating systems can be
|
||||
# typechecked and built successfully.
|
||||
go-os-check:
|
||||
env GOOS='darwin' "$(GO.MACRO)" vet ./internal/...
|
||||
env GOOS='freebsd' "$(GO.MACRO)" vet ./internal/...
|
||||
@@ -135,11 +125,7 @@ go-os-check:
|
||||
env GOOS='linux' "$(GO.MACRO)" vet ./internal/...
|
||||
env GOOS='windows' "$(GO.MACRO)" vet ./internal/...
|
||||
|
||||
|
||||
openapi-lint: ; cd ./openapi/ && $(YARN) test
|
||||
openapi-show: ; cd ./openapi/ && $(YARN) start
|
||||
|
||||
txt-lint: ; $(ENV) "$(SHELL)" ./scripts/make/txt-lint.sh
|
||||
|
||||
md-lint: ; $(ENV_MISC) "$(SHELL)" ./scripts/make/md-lint.sh
|
||||
sh-lint: ; $(ENV_MISC) "$(SHELL)" ./scripts/make/sh-lint.sh
|
||||
|
||||
@@ -205,7 +205,7 @@ Run `make init` to prepare the development environment.
|
||||
|
||||
You will need this to build AdGuard Home:
|
||||
|
||||
- [Go](https://golang.org/dl/) v1.23 or later;
|
||||
- [Go](https://golang.org/dl/) v1.22 or later;
|
||||
- [Node.js](https://nodejs.org/en/download/) v18.18 or later;
|
||||
- [npm](https://www.npmjs.com/) v8 or later;
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
'variables':
|
||||
'channel': 'edge'
|
||||
'dockerFrontend': 'adguard/home-js-builder:2.0'
|
||||
'dockerGo': 'adguard/go-builder:1.23.1--1'
|
||||
'dockerGo': 'adguard/go-builder:1.22.5--1'
|
||||
|
||||
'stages':
|
||||
- 'Build frontend':
|
||||
@@ -91,11 +91,6 @@
|
||||
'tasks':
|
||||
- 'checkout':
|
||||
'force-clean-build': true
|
||||
- 'checkout':
|
||||
'repository': 'bamboo-deploy-publisher'
|
||||
# The paths are always relative to the working directory.
|
||||
'path': 'bamboo-deploy-publisher'
|
||||
'force-clean-build': true
|
||||
- 'script':
|
||||
'interpreter': 'SHELL'
|
||||
'scripts':
|
||||
@@ -104,9 +99,6 @@
|
||||
|
||||
set -e -f -u -x
|
||||
|
||||
# Explicitly checkout the revision that we need.
|
||||
git checkout "${bamboo.repository.revision.number}"
|
||||
|
||||
# Run the build with the specified channel.
|
||||
echo "${bamboo.gpgSecretKeyPart1}${bamboo.gpgSecretKeyPart2}"\
|
||||
| awk '{ gsub(/\\n/, "\n"); print; }'\
|
||||
@@ -115,8 +107,6 @@
|
||||
make\
|
||||
CHANNEL=${bamboo.channel}\
|
||||
GPG_KEY_PASSPHRASE=${bamboo.gpgPassword}\
|
||||
DEPLOY_SCRIPT_PATH="./bamboo-deploy-publisher/deploy.sh"\
|
||||
SIGNER_API_KEY="${bamboo.adguardHomeWinSignerSecretApiKey}"\
|
||||
FRONTEND_PREBUILT=1\
|
||||
PARALLELISM=1\
|
||||
VERBOSE=2\
|
||||
@@ -276,7 +266,7 @@
|
||||
'variables':
|
||||
'channel': 'beta'
|
||||
'dockerFrontend': 'adguard/home-js-builder:2.0'
|
||||
'dockerGo': 'adguard/go-builder:1.23.1--1'
|
||||
'dockerGo': 'adguard/go-builder:1.22.5--1'
|
||||
# release-vX.Y.Z branches are the branches from which the actual final
|
||||
# release is built.
|
||||
- '^release-v[0-9]+\.[0-9]+\.[0-9]+':
|
||||
@@ -292,4 +282,4 @@
|
||||
'variables':
|
||||
'channel': 'release'
|
||||
'dockerFrontend': 'adguard/home-js-builder:2.0'
|
||||
'dockerGo': 'adguard/go-builder:1.23.1--1'
|
||||
'dockerGo': 'adguard/go-builder:1.22.5--1'
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
'name': 'AdGuard Home - Build and run tests'
|
||||
'variables':
|
||||
'dockerFrontend': 'adguard/home-js-builder:2.0'
|
||||
'dockerGo': 'adguard/go-builder:1.23.1--1'
|
||||
'dockerGo': 'adguard/go-builder:1.22.5--1'
|
||||
'channel': 'development'
|
||||
|
||||
'stages':
|
||||
@@ -54,7 +54,6 @@
|
||||
'requirements':
|
||||
- 'adg-docker': 'true'
|
||||
|
||||
# TODO(e.burkov): Add the linting stage for markdown docs and shell scripts.
|
||||
'Test backend':
|
||||
'docker':
|
||||
'image': '${bamboo.dockerGo}'
|
||||
@@ -196,5 +195,5 @@
|
||||
# may need to build a few of these.
|
||||
'variables':
|
||||
'dockerFrontend': 'adguard/home-js-builder:2.0'
|
||||
'dockerGo': 'adguard/go-builder:1.23.1--1'
|
||||
'dockerGo': 'adguard/go-builder:1.22.5--1'
|
||||
'channel': 'candidate'
|
||||
|
||||
@@ -154,7 +154,7 @@
|
||||
"use_adguard_parental": "Use AdGuard parental control web service",
|
||||
"use_adguard_parental_hint": "AdGuard Home will check if domain contains adult materials. It uses the same privacy-friendly API as the browsing security web service.",
|
||||
"enforce_safe_search": "Use Safe Search",
|
||||
"enforce_save_search_hint": "AdGuard Home will enforce safe search in the following search engines: Google, YouTube, Bing, DuckDuckGo, Ecosia, Yandex, Pixabay.",
|
||||
"enforce_save_search_hint": "AdGuard Home will enforce safe search in the following search engines: Google, YouTube, Bing, DuckDuckGo, Yandex, Pixabay.",
|
||||
"no_servers_specified": "No servers specified",
|
||||
"general_settings": "General settings",
|
||||
"dns_settings": "DNS settings",
|
||||
|
||||
@@ -66,7 +66,7 @@ export const renderFormattedClientCell = (value: any, info: any, isDetailed = fa
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="logs__text logs__text--client mw-100" title={value}>
|
||||
<div className="logs__text mw-100" title={value}>
|
||||
<Link to={`logs?search="${encodeURIComponent(value)}"`}>{nameContainer}</Link>
|
||||
{whoisContainer}
|
||||
</div>
|
||||
|
||||
22
go.mod
22
go.mod
@@ -1,10 +1,10 @@
|
||||
module github.com/AdguardTeam/AdGuardHome
|
||||
|
||||
go 1.23.1
|
||||
go 1.22.5
|
||||
|
||||
require (
|
||||
github.com/AdguardTeam/dnsproxy v0.73.0
|
||||
github.com/AdguardTeam/golibs v0.26.0
|
||||
github.com/AdguardTeam/dnsproxy v0.71.2
|
||||
github.com/AdguardTeam/golibs v0.24.0
|
||||
github.com/AdguardTeam/urlfilter v0.19.0
|
||||
github.com/NYTimes/gziphandler v1.1.1
|
||||
github.com/ameshkov/dnscrypt/v2 v2.3.0
|
||||
@@ -32,10 +32,10 @@ require (
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/ti-mo/netfilter v0.5.2
|
||||
go.etcd.io/bbolt v1.3.10
|
||||
golang.org/x/crypto v0.26.0
|
||||
golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa
|
||||
golang.org/x/net v0.28.0
|
||||
golang.org/x/sys v0.24.0
|
||||
golang.org/x/crypto v0.24.0
|
||||
golang.org/x/exp v0.0.0-20240604190554-fc45aab8b7f8
|
||||
golang.org/x/net v0.26.0
|
||||
golang.org/x/sys v0.21.0
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
howett.net/plist v1.0.1
|
||||
@@ -58,9 +58,9 @@ require (
|
||||
github.com/quic-go/qpack v0.4.0 // indirect
|
||||
github.com/u-root/uio v0.0.0-20240224005618-d2acac8f3701 // indirect
|
||||
go.uber.org/mock v0.4.0 // indirect
|
||||
golang.org/x/mod v0.20.0 // indirect
|
||||
golang.org/x/sync v0.8.0 // indirect
|
||||
golang.org/x/text v0.17.0 // indirect
|
||||
golang.org/x/tools v0.24.0 // indirect
|
||||
golang.org/x/mod v0.18.0 // indirect
|
||||
golang.org/x/sync v0.7.0 // indirect
|
||||
golang.org/x/text v0.16.0 // indirect
|
||||
golang.org/x/tools v0.22.0 // indirect
|
||||
gonum.org/v1/gonum v0.15.0 // indirect
|
||||
)
|
||||
|
||||
40
go.sum
40
go.sum
@@ -1,7 +1,7 @@
|
||||
github.com/AdguardTeam/dnsproxy v0.73.0 h1:E1fxzosMqExZH8h7OJnKXLxyktcAFRJapLF4+nKULms=
|
||||
github.com/AdguardTeam/dnsproxy v0.73.0/go.mod h1:ZcvmyQY2EiX5B0yCTkiYTgtm+1lBWA0lajbEI9dOhW4=
|
||||
github.com/AdguardTeam/golibs v0.26.0 h1:uLL0XggEjB+87lL1tPpEAQNoKAlHDq5AyBUVWEgf63E=
|
||||
github.com/AdguardTeam/golibs v0.26.0/go.mod h1:iWdjXPCwmK2g2FKIb/OwEPnovSXeMqRhI8FWLxF5oxE=
|
||||
github.com/AdguardTeam/dnsproxy v0.71.2 h1:dFG2wga4GDdj1eI3rU2wqjQ6QGQm9MjLRb5ZzyH3Vgg=
|
||||
github.com/AdguardTeam/dnsproxy v0.71.2/go.mod h1:huI5zyWhlimHBhg0jt2CMinXzsEHymI+WlvxIfmfEGA=
|
||||
github.com/AdguardTeam/golibs v0.24.0 h1:qAnOq7BQtwSVo7Co9q703/n+nZ2Ap6smkugU9G9MomY=
|
||||
github.com/AdguardTeam/golibs v0.24.0/go.mod h1:9/vJcYznW7RlmCT/Qzi8XNZGj+ZbWfHZJmEXKnRpCAU=
|
||||
github.com/AdguardTeam/urlfilter v0.19.0 h1:q7eH13+yNETlpD/VD3u5rLQOripcUdEktqZFy+KiQLk=
|
||||
github.com/AdguardTeam/urlfilter v0.19.0/go.mod h1:+N54ZvxqXYLnXuvpaUhK2exDQW+djZBRSb6F6j0rkBY=
|
||||
github.com/NYTimes/gziphandler v1.1.1 h1:ZUDjpQae29j0ryrS0u/B8HZfJBtBQHjqw2rQ2cqUQ3I=
|
||||
@@ -128,26 +128,26 @@ go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU=
|
||||
go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw=
|
||||
golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54=
|
||||
golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa h1:ELnwvuAXPNtPk1TJRuGkI9fDTwym6AYBu0qzT8AcHdI=
|
||||
golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa/go.mod h1:akd2r19cwCdwSwWeIdzYQGa/EZZyqcOdwWiwj5L5eKQ=
|
||||
golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI=
|
||||
golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM=
|
||||
golang.org/x/exp v0.0.0-20240604190554-fc45aab8b7f8 h1:LoYXNGAShUG3m/ehNk4iFctuhGX/+R1ZpfJ4/ia80JM=
|
||||
golang.org/x/exp v0.0.0-20240604190554-fc45aab8b7f8/go.mod h1:jj3sYF3dwk5D+ghuXyeI3r5MFf+NT2An6/9dOA95KSI=
|
||||
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
|
||||
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
|
||||
golang.org/x/mod v0.20.0 h1:utOm6MM3R3dnawAiJgn0y+xvuYRsm1RKM/4giyfDgV0=
|
||||
golang.org/x/mod v0.20.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||
golang.org/x/mod v0.18.0 h1:5+9lSbEzPSdWkH32vYPBwEpX8KwDbM52Ud9xBUvNlb0=
|
||||
golang.org/x/mod v0.18.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20210316092652-d523dce5a7f4/go.mod h1:RBQZq4jEuRlivfhVLdyRGr576XBO4/greRjx4P4O3yc=
|
||||
golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE=
|
||||
golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg=
|
||||
golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ=
|
||||
golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
|
||||
golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
|
||||
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190322080309-f49334f85ddc/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
@@ -158,19 +158,19 @@ golang.org/x/sys v0.0.0-20210315160823-c6e025ad8005/go.mod h1:h1NjWce9XRLGQEsW7w
|
||||
golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220209214540-3681064d5158/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.4.1-0.20230131160137-e7d7f63158de/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg=
|
||||
golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws=
|
||||
golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc=
|
||||
golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
|
||||
golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
|
||||
golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
|
||||
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
|
||||
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
|
||||
golang.org/x/tools v0.24.0 h1:J1shsA93PJUEVaUSaay7UXAyE8aimq3GW0pjlolpa24=
|
||||
golang.org/x/tools v0.24.0/go.mod h1:YhNqVBIfWHdzvTLs0d8LCuMhkKUgSUKldakyV7W/WDQ=
|
||||
golang.org/x/tools v0.22.0 h1:gqSGLZqv+AI9lIQzniJ0nZDRG5GBPsSi+DRNHWNz6yA=
|
||||
golang.org/x/tools v0.22.0/go.mod h1:aCwcsjqvq7Yqt6TNyX7QMU2enbQ/Gt0bo6krSeEri+c=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
gonum.org/v1/gonum v0.15.0 h1:2lYxjRbTYyxkJxlhC+LvJIx3SsANPdRybu1tGj9/OrQ=
|
||||
|
||||
94
internal/aghalg/ringbuffer.go
Normal file
94
internal/aghalg/ringbuffer.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package aghalg
|
||||
|
||||
// RingBuffer is the implementation of ring buffer data structure.
|
||||
type RingBuffer[T any] struct {
|
||||
buf []T
|
||||
cur uint
|
||||
full bool
|
||||
}
|
||||
|
||||
// NewRingBuffer initializes the new instance of ring buffer. size must be
|
||||
// greater or equal to zero.
|
||||
func NewRingBuffer[T any](size uint) (rb *RingBuffer[T]) {
|
||||
return &RingBuffer[T]{
|
||||
buf: make([]T, size),
|
||||
}
|
||||
}
|
||||
|
||||
// Append appends an element to the buffer.
|
||||
func (rb *RingBuffer[T]) Append(e T) {
|
||||
if len(rb.buf) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
rb.buf[rb.cur] = e
|
||||
rb.cur = (rb.cur + 1) % uint(cap(rb.buf))
|
||||
if rb.cur == 0 {
|
||||
rb.full = true
|
||||
}
|
||||
}
|
||||
|
||||
// Range calls cb for each element of the buffer. If cb returns false it stops.
|
||||
func (rb *RingBuffer[T]) Range(cb func(T) (cont bool)) {
|
||||
before, after := rb.splitCur()
|
||||
|
||||
for _, e := range before {
|
||||
if !cb(e) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
for _, e := range after {
|
||||
if !cb(e) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ReverseRange calls cb for each element of the buffer in reverse order. If
|
||||
// cb returns false it stops.
|
||||
func (rb *RingBuffer[T]) ReverseRange(cb func(T) (cont bool)) {
|
||||
before, after := rb.splitCur()
|
||||
|
||||
for i := len(after) - 1; i >= 0; i-- {
|
||||
if !cb(after[i]) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
for i := len(before) - 1; i >= 0; i-- {
|
||||
if !cb(before[i]) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// splitCur splits the buffer in two, before and after current position in
|
||||
// chronological order. If buffer is not full, after is nil.
|
||||
func (rb *RingBuffer[T]) splitCur() (before, after []T) {
|
||||
if len(rb.buf) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
cur := rb.cur
|
||||
if !rb.full {
|
||||
return rb.buf[:cur], nil
|
||||
}
|
||||
|
||||
return rb.buf[cur:], rb.buf[:cur]
|
||||
}
|
||||
|
||||
// Len returns a length of the buffer.
|
||||
func (rb *RingBuffer[T]) Len() (l uint) {
|
||||
if !rb.full {
|
||||
return rb.cur
|
||||
}
|
||||
|
||||
return uint(cap(rb.buf))
|
||||
}
|
||||
|
||||
// Clear clears the buffer.
|
||||
func (rb *RingBuffer[T]) Clear() {
|
||||
rb.full = false
|
||||
rb.cur = 0
|
||||
}
|
||||
169
internal/aghalg/ringbuffer_test.go
Normal file
169
internal/aghalg/ringbuffer_test.go
Normal file
@@ -0,0 +1,169 @@
|
||||
package aghalg_test
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// elements is a helper function that returns n elements of the buffer.
|
||||
func elements(b *aghalg.RingBuffer[int], n uint, reverse bool) (es []int) {
|
||||
fn := b.Range
|
||||
if reverse {
|
||||
fn = b.ReverseRange
|
||||
}
|
||||
|
||||
var i uint
|
||||
fn(func(e int) (cont bool) {
|
||||
if i >= n {
|
||||
return false
|
||||
}
|
||||
|
||||
es = append(es, e)
|
||||
i++
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
return es
|
||||
}
|
||||
|
||||
func TestNewRingBuffer(t *testing.T) {
|
||||
t.Run("success_and_clear", func(t *testing.T) {
|
||||
b := aghalg.NewRingBuffer[int](5)
|
||||
for i := range 10 {
|
||||
b.Append(i)
|
||||
}
|
||||
assert.Equal(t, []int{5, 6, 7, 8, 9}, elements(b, b.Len(), false))
|
||||
|
||||
b.Clear()
|
||||
assert.Zero(t, b.Len())
|
||||
})
|
||||
|
||||
t.Run("zero", func(t *testing.T) {
|
||||
b := aghalg.NewRingBuffer[int](0)
|
||||
for i := range 10 {
|
||||
b.Append(i)
|
||||
bufLen := b.Len()
|
||||
assert.EqualValues(t, 0, bufLen)
|
||||
assert.Empty(t, elements(b, bufLen, false))
|
||||
assert.Empty(t, elements(b, bufLen, true))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("single", func(t *testing.T) {
|
||||
b := aghalg.NewRingBuffer[int](1)
|
||||
for i := range 10 {
|
||||
b.Append(i)
|
||||
bufLen := b.Len()
|
||||
assert.EqualValues(t, 1, bufLen)
|
||||
assert.Equal(t, []int{i}, elements(b, bufLen, false))
|
||||
assert.Equal(t, []int{i}, elements(b, bufLen, true))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRingBuffer_Range(t *testing.T) {
|
||||
const size = 5
|
||||
|
||||
b := aghalg.NewRingBuffer[int](size)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
want []int
|
||||
count int
|
||||
length uint
|
||||
}{{
|
||||
name: "three",
|
||||
count: 3,
|
||||
length: 3,
|
||||
want: []int{0, 1, 2},
|
||||
}, {
|
||||
name: "ten",
|
||||
count: 10,
|
||||
length: size,
|
||||
want: []int{5, 6, 7, 8, 9},
|
||||
}, {
|
||||
name: "hundred",
|
||||
count: 100,
|
||||
length: size,
|
||||
want: []int{95, 96, 97, 98, 99},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
for i := range tc.count {
|
||||
b.Append(i)
|
||||
}
|
||||
|
||||
bufLen := b.Len()
|
||||
assert.Equal(t, tc.length, bufLen)
|
||||
|
||||
want := tc.want
|
||||
assert.Equal(t, want, elements(b, bufLen, false))
|
||||
assert.Equal(t, want[:len(want)-1], elements(b, bufLen-1, false))
|
||||
assert.Equal(t, want[:len(want)/2], elements(b, bufLen/2, false))
|
||||
|
||||
want = want[:cap(want)]
|
||||
slices.Reverse(want)
|
||||
|
||||
assert.Equal(t, want, elements(b, bufLen, true))
|
||||
assert.Equal(t, want[:len(want)-1], elements(b, bufLen-1, true))
|
||||
assert.Equal(t, want[:len(want)/2], elements(b, bufLen/2, true))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRingBuffer_Range_increment(t *testing.T) {
|
||||
const size = 5
|
||||
|
||||
b := aghalg.NewRingBuffer[int](size)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
want []int
|
||||
}{{
|
||||
name: "one",
|
||||
want: []int{0},
|
||||
}, {
|
||||
name: "two",
|
||||
want: []int{0, 1},
|
||||
}, {
|
||||
name: "three",
|
||||
want: []int{0, 1, 2},
|
||||
}, {
|
||||
name: "four",
|
||||
want: []int{0, 1, 2, 3},
|
||||
}, {
|
||||
name: "five",
|
||||
want: []int{0, 1, 2, 3, 4},
|
||||
}, {
|
||||
name: "six",
|
||||
want: []int{1, 2, 3, 4, 5},
|
||||
}, {
|
||||
name: "seven",
|
||||
want: []int{2, 3, 4, 5, 6},
|
||||
}, {
|
||||
name: "eight",
|
||||
want: []int{3, 4, 5, 6, 7},
|
||||
}, {
|
||||
name: "nine",
|
||||
want: []int{4, 5, 6, 7, 8},
|
||||
}, {
|
||||
name: "ten",
|
||||
want: []int{5, 6, 7, 8, 9},
|
||||
}}
|
||||
|
||||
for i, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
b.Append(i)
|
||||
bufLen := b.Len()
|
||||
assert.Equal(t, tc.want, elements(b, bufLen, false))
|
||||
|
||||
slices.Reverse(tc.want)
|
||||
assert.Equal(t, tc.want, elements(b, bufLen, true))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -2,16 +2,13 @@
|
||||
package aghhttp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
||||
"github.com/AdguardTeam/golibs/httphdr"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
)
|
||||
|
||||
// HTTP scheme constants.
|
||||
@@ -34,39 +31,12 @@ func OK(w http.ResponseWriter) {
|
||||
}
|
||||
|
||||
// Error writes formatted message to w and also logs it.
|
||||
//
|
||||
// TODO(s.chzhen): Remove it.
|
||||
func Error(r *http.Request, w http.ResponseWriter, code int, format string, args ...any) {
|
||||
text := fmt.Sprintf(format, args...)
|
||||
log.Error("%s %s %s: %s", r.Method, r.Host, r.URL, text)
|
||||
http.Error(w, text, code)
|
||||
}
|
||||
|
||||
// ErrorAndLog writes formatted message to w and also logs it with the specified
|
||||
// logging level.
|
||||
func ErrorAndLog(
|
||||
ctx context.Context,
|
||||
l *slog.Logger,
|
||||
r *http.Request,
|
||||
w http.ResponseWriter,
|
||||
code int,
|
||||
format string,
|
||||
args ...any,
|
||||
) {
|
||||
text := fmt.Sprintf(format, args...)
|
||||
l.ErrorContext(
|
||||
ctx,
|
||||
"http error",
|
||||
"host", r.Host,
|
||||
"method", r.Method,
|
||||
"raddr", r.RemoteAddr,
|
||||
"request_uri", r.RequestURI,
|
||||
slogutil.KeyError, text,
|
||||
)
|
||||
|
||||
http.Error(w, text, code)
|
||||
}
|
||||
|
||||
// UserAgent returns the ID of the service as a User-Agent string. It can also
|
||||
// be used as the value of the Server HTTP header.
|
||||
func UserAgent() (ua string) {
|
||||
|
||||
@@ -6,10 +6,10 @@ import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
)
|
||||
|
||||
func ifaceHasStaticIP(ifaceName string) (ok bool, err error) {
|
||||
@@ -38,13 +38,9 @@ func (n interfaceName) rcConfStaticConfig(r io.Reader) (_ []string, cont bool, e
|
||||
// TODO(e.burkov): Expand the check to cover possible
|
||||
// configurations from man rc.conf(5).
|
||||
fields := strings.Fields(line[cfgLeft:cfgRight])
|
||||
switch {
|
||||
case
|
||||
len(fields) < 2,
|
||||
!strings.EqualFold(fields[0], "inet"),
|
||||
!netutil.IsValidIPString(fields[1]):
|
||||
continue
|
||||
default:
|
||||
if len(fields) >= 2 &&
|
||||
strings.EqualFold(fields[0], "inet") &&
|
||||
net.ParseIP(fields[1]) != nil {
|
||||
return nil, false, s.Err()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,10 +6,10 @@ import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
)
|
||||
|
||||
func ifaceHasStaticIP(ifaceName string) (ok bool, err error) {
|
||||
@@ -25,13 +25,7 @@ func hostnameIfStaticConfig(r io.Reader) (_ []string, ok bool, err error) {
|
||||
for s.Scan() {
|
||||
line := strings.TrimSpace(s.Text())
|
||||
fields := strings.Fields(line)
|
||||
switch {
|
||||
case
|
||||
len(fields) < 2,
|
||||
fields[0] != "inet",
|
||||
!netutil.IsValidIPString(fields[1]):
|
||||
continue
|
||||
default:
|
||||
if len(fields) >= 2 && fields[0] == "inet" && net.ParseIP(fields[1]) != nil {
|
||||
return nil, false, s.Err()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package aghos
|
||||
|
||||
// ConfigureSyslog reroutes standard logger output to syslog.
|
||||
func ConfigureSyslog(serviceName string) (err error) {
|
||||
func ConfigureSyslog(serviceName string) error {
|
||||
return configureSyslog(serviceName)
|
||||
}
|
||||
|
||||
@@ -8,15 +8,11 @@ import (
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// configureSyslog sets standard log output to syslog.
|
||||
func configureSyslog(serviceName string) (err error) {
|
||||
func configureSyslog(serviceName string) error {
|
||||
w, err := syslog.New(syslog.LOG_NOTICE|syslog.LOG_USER, serviceName)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
log.SetOutput(w)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -19,30 +19,23 @@ func (w *eventLogWriter) Write(b []byte) (int, error) {
|
||||
return len(b), w.el.Info(1, string(b))
|
||||
}
|
||||
|
||||
// configureSyslog sets standard log output to event log.
|
||||
func configureSyslog(serviceName string) (err error) {
|
||||
// Note that the eventlog src is the same as the service name, otherwise we
|
||||
// will get "the description for event id cannot be found" warning in every
|
||||
// log record.
|
||||
func configureSyslog(serviceName string) error {
|
||||
// Note that the eventlog src is the same as the service name
|
||||
// Otherwise, we will get "the description for event id cannot be found" warning in every log record
|
||||
|
||||
// Continue if we receive "registry key already exists" or if we get
|
||||
// ERROR_ACCESS_DENIED so that we can log without administrative permissions
|
||||
// for pre-existing eventlog sources.
|
||||
err = eventlog.InstallAsEventCreate(serviceName, eventlog.Info|eventlog.Warning|eventlog.Error)
|
||||
if err != nil &&
|
||||
!strings.Contains(err.Error(), "registry key already exists") &&
|
||||
err != windows.ERROR_ACCESS_DENIED {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return err
|
||||
if err := eventlog.InstallAsEventCreate(serviceName, eventlog.Info|eventlog.Warning|eventlog.Error); err != nil {
|
||||
if !strings.Contains(err.Error(), "registry key already exists") && err != windows.ERROR_ACCESS_DENIED {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
el, err := eventlog.Open(serviceName)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
log.SetOutput(&eventLogWriter{el: el})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -5,10 +5,9 @@ import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"slices"
|
||||
"net/netip"
|
||||
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
)
|
||||
|
||||
// init makes sure that the cipher name map is filled.
|
||||
@@ -76,5 +75,15 @@ func SaferCipherSuites() (safe []uint16) {
|
||||
// CertificateHasIP returns true if cert has at least a single IP address among
|
||||
// its subjectAltNames.
|
||||
func CertificateHasIP(cert *x509.Certificate) (ok bool) {
|
||||
return len(cert.IPAddresses) > 0 || slices.ContainsFunc(cert.DNSNames, netutil.IsValidIPString)
|
||||
if len(cert.IPAddresses) > 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
for _, name := range cert.DNSNames {
|
||||
if _, err := netip.ParseAddr(name); err == nil {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
@@ -13,7 +12,7 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/osutil"
|
||||
)
|
||||
@@ -39,8 +38,8 @@ type Interface interface {
|
||||
}
|
||||
|
||||
// New returns the [Interface] properly initialized for the OS.
|
||||
func New(logger *slog.Logger) (arp Interface) {
|
||||
return newARPDB(logger)
|
||||
func New() (arp Interface) {
|
||||
return newARPDB()
|
||||
}
|
||||
|
||||
// Empty is the [Interface] implementation that does nothing.
|
||||
@@ -70,30 +69,6 @@ type Neighbor struct {
|
||||
MAC net.HardwareAddr
|
||||
}
|
||||
|
||||
// newNeighbor returns the new initialized [Neighbor] by parsing string
|
||||
// representations of IP and MAC addresses.
|
||||
func newNeighbor(host, ipStr, macStr string) (n *Neighbor, err error) {
|
||||
defer func() { err = errors.Annotate(err, "getting arp neighbor: %w") }()
|
||||
|
||||
ip, err := netip.ParseAddr(ipStr)
|
||||
if err != nil {
|
||||
// Don't wrap the error, as it will get annotated.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mac, err := net.ParseMAC(macStr)
|
||||
if err != nil {
|
||||
// Don't wrap the error, as it will get annotated.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Neighbor{
|
||||
Name: host,
|
||||
IP: ip,
|
||||
MAC: mac,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Clone returns the deep copy of n.
|
||||
func (n Neighbor) Clone() (clone Neighbor) {
|
||||
return Neighbor{
|
||||
@@ -105,10 +80,10 @@ func (n Neighbor) Clone() (clone Neighbor) {
|
||||
|
||||
// validatedHostname returns h if it's a valid hostname, or an empty string
|
||||
// otherwise, logging the validation error.
|
||||
func validatedHostname(logger *slog.Logger, h string) (host string) {
|
||||
func validatedHostname(h string) (host string) {
|
||||
err := netutil.ValidateHostname(h)
|
||||
if err != nil {
|
||||
logger.Debug("parsing host of arp output", slogutil.KeyError, err)
|
||||
log.Debug("arpdb: parsing arp output: host: %s", err)
|
||||
|
||||
return ""
|
||||
}
|
||||
@@ -157,18 +132,15 @@ func (ns *neighs) reset(with []Neighbor) {
|
||||
// parseNeighsFunc parses the text from sc as if it'd be an output of some
|
||||
// ARP-related command. lenHint is a hint for the size of the allocated slice
|
||||
// of Neighbors.
|
||||
//
|
||||
// TODO(s.chzhen): Return []*Neighbor instead.
|
||||
type parseNeighsFunc func(logger *slog.Logger, sc *bufio.Scanner, lenHint int) (ns []Neighbor)
|
||||
type parseNeighsFunc func(sc *bufio.Scanner, lenHint int) (ns []Neighbor)
|
||||
|
||||
// cmdARPDB is the implementation of the [Interface] that uses command line to
|
||||
// retrieve data.
|
||||
type cmdARPDB struct {
|
||||
logger *slog.Logger
|
||||
parse parseNeighsFunc
|
||||
ns *neighs
|
||||
cmd string
|
||||
args []string
|
||||
parse parseNeighsFunc
|
||||
ns *neighs
|
||||
cmd string
|
||||
args []string
|
||||
}
|
||||
|
||||
// type check
|
||||
@@ -186,7 +158,7 @@ func (arp *cmdARPDB) Refresh() (err error) {
|
||||
}
|
||||
|
||||
sc := bufio.NewScanner(bytes.NewReader(out))
|
||||
ns := arp.parse(arp.logger, sc, arp.ns.len())
|
||||
ns := arp.parse(sc, arp.ns.len())
|
||||
if err = sc.Err(); err != nil {
|
||||
// TODO(e.burkov): This error seems unreachable. Investigate.
|
||||
return fmt.Errorf("scanning the output: %w", err)
|
||||
|
||||
@@ -4,17 +4,17 @@ package arpdb
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
func newARPDB(logger *slog.Logger) (arp *cmdARPDB) {
|
||||
func newARPDB() (arp *cmdARPDB) {
|
||||
return &cmdARPDB{
|
||||
logger: logger,
|
||||
parse: parseArpA,
|
||||
parse: parseArpA,
|
||||
ns: &neighs{
|
||||
mu: &sync.RWMutex{},
|
||||
ns: make([]Neighbor, 0),
|
||||
@@ -33,7 +33,7 @@ func newARPDB(logger *slog.Logger) (arp *cmdARPDB) {
|
||||
// The expected input format:
|
||||
//
|
||||
// host.name (192.168.0.1) at ff:ff:ff:ff:ff:ff on en0 ifscope [ethernet]
|
||||
func parseArpA(logger *slog.Logger, sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
|
||||
func parseArpA(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
|
||||
ns = make([]Neighbor, 0, lenHint)
|
||||
for sc.Scan() {
|
||||
ln := sc.Text()
|
||||
@@ -48,15 +48,26 @@ func parseArpA(logger *slog.Logger, sc *bufio.Scanner, lenHint int) (ns []Neighb
|
||||
continue
|
||||
}
|
||||
|
||||
host := validatedHostname(logger, fields[0])
|
||||
n, err := newNeighbor(host, ipStr[1:len(ipStr)-1], fields[3])
|
||||
ip, err := netip.ParseAddr(ipStr[1 : len(ipStr)-1])
|
||||
if err != nil {
|
||||
logger.Debug("parsing arp output", "line", ln, slogutil.KeyError, err)
|
||||
log.Debug("arpdb: parsing arp output: ip: %s", err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
ns = append(ns, *n)
|
||||
hwStr := fields[3]
|
||||
mac, err := net.ParseMAC(hwStr)
|
||||
if err != nil {
|
||||
log.Debug("arpdb: parsing arp output: mac: %s", err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
ns = append(ns, Neighbor{
|
||||
IP: ip,
|
||||
MAC: mac,
|
||||
Name: validatedHostname(fields[0]),
|
||||
})
|
||||
}
|
||||
|
||||
return ns
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -62,7 +61,7 @@ func (s mapShell) RunCmd(cmd string, args ...string) (code int, out []byte, err
|
||||
|
||||
func Test_New(t *testing.T) {
|
||||
var a Interface
|
||||
require.NotPanics(t, func() { a = New(slogutil.NewDiscardLogger()) })
|
||||
require.NotPanics(t, func() { a = New() })
|
||||
|
||||
assert.NotNil(t, a)
|
||||
}
|
||||
@@ -202,9 +201,8 @@ func Test_NewARPDBs(t *testing.T) {
|
||||
|
||||
func TestCmdARPDB_arpa(t *testing.T) {
|
||||
a := &cmdARPDB{
|
||||
logger: slogutil.NewDiscardLogger(),
|
||||
cmd: "cmd",
|
||||
parse: parseArpA,
|
||||
cmd: "cmd",
|
||||
parse: parseArpA,
|
||||
ns: &neighs{
|
||||
mu: &sync.RWMutex{},
|
||||
ns: make([]Neighbor, 0),
|
||||
|
||||
@@ -6,18 +6,17 @@ import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
)
|
||||
|
||||
func newARPDB(logger *slog.Logger) (arp *arpdbs) {
|
||||
func newARPDB() (arp *arpdbs) {
|
||||
// Use the common storage among the implementations.
|
||||
ns := &neighs{
|
||||
mu: &sync.RWMutex{},
|
||||
@@ -40,10 +39,9 @@ func newARPDB(logger *slog.Logger) (arp *arpdbs) {
|
||||
},
|
||||
// Then, try "arp -a -n".
|
||||
&cmdARPDB{
|
||||
logger: logger,
|
||||
parse: parseF,
|
||||
ns: ns,
|
||||
cmd: "arp",
|
||||
parse: parseF,
|
||||
ns: ns,
|
||||
cmd: "arp",
|
||||
// Use -n flag to avoid resolving the hostnames of the neighbors.
|
||||
// By default ARP attempts to resolve the hostnames via DNS. See
|
||||
// man 8 arp.
|
||||
@@ -53,11 +51,10 @@ func newARPDB(logger *slog.Logger) (arp *arpdbs) {
|
||||
},
|
||||
// Finally, try "ip neigh".
|
||||
&cmdARPDB{
|
||||
logger: logger,
|
||||
parse: parseIPNeigh,
|
||||
ns: ns,
|
||||
cmd: "ip",
|
||||
args: []string{"neigh"},
|
||||
parse: parseIPNeigh,
|
||||
ns: ns,
|
||||
cmd: "ip",
|
||||
args: []string{"neigh"},
|
||||
},
|
||||
)
|
||||
}
|
||||
@@ -134,7 +131,7 @@ func (arp *fsysARPDB) Neighbors() (ns []Neighbor) {
|
||||
//
|
||||
// IP address HW type Flags HW address Mask Device
|
||||
// 192.168.11.98 0x1 0x2 5a:92:df:a9:7e:28 * wan
|
||||
func parseArpAWrt(logger *slog.Logger, sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
|
||||
func parseArpAWrt(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
|
||||
if !sc.Scan() {
|
||||
// Skip the header.
|
||||
return
|
||||
@@ -149,14 +146,25 @@ func parseArpAWrt(logger *slog.Logger, sc *bufio.Scanner, lenHint int) (ns []Nei
|
||||
continue
|
||||
}
|
||||
|
||||
n, err := newNeighbor("", fields[0], fields[3])
|
||||
ip, err := netip.ParseAddr(fields[0])
|
||||
if err != nil {
|
||||
logger.Debug("parsing arp output", "line", ln, slogutil.KeyError, err)
|
||||
log.Debug("arpdb: parsing arp output: ip: %s", err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
ns = append(ns, *n)
|
||||
hwStr := fields[3]
|
||||
mac, err := net.ParseMAC(hwStr)
|
||||
if err != nil {
|
||||
log.Debug("arpdb: parsing arp output: mac: %s", err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
ns = append(ns, Neighbor{
|
||||
IP: ip,
|
||||
MAC: mac,
|
||||
})
|
||||
}
|
||||
|
||||
return ns
|
||||
@@ -166,7 +174,7 @@ func parseArpAWrt(logger *slog.Logger, sc *bufio.Scanner, lenHint int) (ns []Nei
|
||||
// expected input format:
|
||||
//
|
||||
// hostname (192.168.1.1) at ab:cd:ef:ab:cd:ef [ether] on enp0s3
|
||||
func parseArpA(logger *slog.Logger, sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
|
||||
func parseArpA(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
|
||||
ns = make([]Neighbor, 0, lenHint)
|
||||
for sc.Scan() {
|
||||
ln := sc.Text()
|
||||
@@ -181,15 +189,26 @@ func parseArpA(logger *slog.Logger, sc *bufio.Scanner, lenHint int) (ns []Neighb
|
||||
continue
|
||||
}
|
||||
|
||||
host := validatedHostname(logger, fields[0])
|
||||
n, err := newNeighbor(host, ipStr[1:len(ipStr)-1], fields[3])
|
||||
ip, err := netip.ParseAddr(ipStr[1 : len(ipStr)-1])
|
||||
if err != nil {
|
||||
logger.Debug("parsing arp output", "line", ln, slogutil.KeyError, err)
|
||||
log.Debug("arpdb: parsing arp output: ip: %s", err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
ns = append(ns, *n)
|
||||
hwStr := fields[3]
|
||||
mac, err := net.ParseMAC(hwStr)
|
||||
if err != nil {
|
||||
log.Debug("arpdb: parsing arp output: mac: %s", err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
ns = append(ns, Neighbor{
|
||||
IP: ip,
|
||||
MAC: mac,
|
||||
Name: validatedHostname(fields[0]),
|
||||
})
|
||||
}
|
||||
|
||||
return ns
|
||||
@@ -199,7 +218,7 @@ func parseArpA(logger *slog.Logger, sc *bufio.Scanner, lenHint int) (ns []Neighb
|
||||
// expected input format:
|
||||
//
|
||||
// 192.168.1.1 dev enp0s3 lladdr ab:cd:ef:ab:cd:ef REACHABLE
|
||||
func parseIPNeigh(logger *slog.Logger, sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
|
||||
func parseIPNeigh(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
|
||||
ns = make([]Neighbor, 0, lenHint)
|
||||
for sc.Scan() {
|
||||
ln := sc.Text()
|
||||
@@ -209,14 +228,27 @@ func parseIPNeigh(logger *slog.Logger, sc *bufio.Scanner, lenHint int) (ns []Nei
|
||||
continue
|
||||
}
|
||||
|
||||
n, err := newNeighbor("", fields[0], fields[4])
|
||||
n := Neighbor{}
|
||||
|
||||
ip, err := netip.ParseAddr(fields[0])
|
||||
if err != nil {
|
||||
logger.Debug("parsing arp output", "line", ln, slogutil.KeyError, err)
|
||||
log.Debug("arpdb: parsing arp output: ip: %s", err)
|
||||
|
||||
continue
|
||||
} else {
|
||||
n.IP = ip
|
||||
}
|
||||
|
||||
ns = append(ns, *n)
|
||||
mac, err := net.ParseMAC(fields[4])
|
||||
if err != nil {
|
||||
log.Debug("arpdb: parsing arp output: mac: %s", err)
|
||||
|
||||
continue
|
||||
} else {
|
||||
n.MAC = mac
|
||||
}
|
||||
|
||||
ns = append(ns, n)
|
||||
}
|
||||
|
||||
return ns
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"testing"
|
||||
"testing/fstest"
|
||||
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -70,10 +69,9 @@ func TestCmdARPDB_linux(t *testing.T) {
|
||||
|
||||
t.Run("wrt", func(t *testing.T) {
|
||||
a := &cmdARPDB{
|
||||
logger: slogutil.NewDiscardLogger(),
|
||||
parse: parseArpAWrt,
|
||||
cmd: "arp",
|
||||
args: []string{"-a"},
|
||||
parse: parseArpAWrt,
|
||||
cmd: "arp",
|
||||
args: []string{"-a"},
|
||||
ns: &neighs{
|
||||
mu: &sync.RWMutex{},
|
||||
ns: make([]Neighbor, 0),
|
||||
@@ -88,10 +86,9 @@ func TestCmdARPDB_linux(t *testing.T) {
|
||||
|
||||
t.Run("ip_neigh", func(t *testing.T) {
|
||||
a := &cmdARPDB{
|
||||
logger: slogutil.NewDiscardLogger(),
|
||||
parse: parseIPNeigh,
|
||||
cmd: "ip",
|
||||
args: []string{"neigh"},
|
||||
parse: parseIPNeigh,
|
||||
cmd: "ip",
|
||||
args: []string{"neigh"},
|
||||
ns: &neighs{
|
||||
mu: &sync.RWMutex{},
|
||||
ns: make([]Neighbor, 0),
|
||||
|
||||
@@ -4,17 +4,17 @@ package arpdb
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
func newARPDB(logger *slog.Logger) (arp *cmdARPDB) {
|
||||
func newARPDB() (arp *cmdARPDB) {
|
||||
return &cmdARPDB{
|
||||
logger: logger,
|
||||
parse: parseArpA,
|
||||
parse: parseArpA,
|
||||
ns: &neighs{
|
||||
mu: &sync.RWMutex{},
|
||||
ns: make([]Neighbor, 0),
|
||||
@@ -34,7 +34,7 @@ func newARPDB(logger *slog.Logger) (arp *cmdARPDB) {
|
||||
//
|
||||
// Host Ethernet Address Netif Expire Flags
|
||||
// 192.168.1.1 ab:cd:ef:ab:cd:ef em0 19m59s
|
||||
func parseArpA(logger *slog.Logger, sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
|
||||
func parseArpA(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
|
||||
// Skip the header.
|
||||
if !sc.Scan() {
|
||||
return nil
|
||||
@@ -49,14 +49,27 @@ func parseArpA(logger *slog.Logger, sc *bufio.Scanner, lenHint int) (ns []Neighb
|
||||
continue
|
||||
}
|
||||
|
||||
n, err := newNeighbor("", fields[0], fields[1])
|
||||
n := Neighbor{}
|
||||
|
||||
ip, err := netip.ParseAddr(fields[0])
|
||||
if err != nil {
|
||||
logger.Debug("parsing arp output", "line", ln, slogutil.KeyError, err)
|
||||
log.Debug("arpdb: parsing arp output: ip: %s", err)
|
||||
|
||||
continue
|
||||
} else {
|
||||
n.IP = ip
|
||||
}
|
||||
|
||||
ns = append(ns, *n)
|
||||
mac, err := net.ParseMAC(fields[1])
|
||||
if err != nil {
|
||||
log.Debug("arpdb: parsing arp output: mac: %s", err)
|
||||
|
||||
continue
|
||||
} else {
|
||||
n.MAC = mac
|
||||
}
|
||||
|
||||
ns = append(ns, n)
|
||||
}
|
||||
|
||||
return ns
|
||||
|
||||
@@ -4,17 +4,17 @@ package arpdb
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
func newARPDB(logger *slog.Logger) (arp *cmdARPDB) {
|
||||
func newARPDB() (arp *cmdARPDB) {
|
||||
return &cmdARPDB{
|
||||
logger: logger,
|
||||
parse: parseArpA,
|
||||
parse: parseArpA,
|
||||
ns: &neighs{
|
||||
mu: &sync.RWMutex{},
|
||||
ns: make([]Neighbor, 0),
|
||||
@@ -31,7 +31,7 @@ func newARPDB(logger *slog.Logger) (arp *cmdARPDB) {
|
||||
// Internet Address Physical Address Type
|
||||
// 192.168.56.1 0a-00-27-00-00-00 dynamic
|
||||
// 192.168.56.255 ff-ff-ff-ff-ff-ff static
|
||||
func parseArpA(logger *slog.Logger, sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
|
||||
func parseArpA(sc *bufio.Scanner, lenHint int) (ns []Neighbor) {
|
||||
ns = make([]Neighbor, 0, lenHint)
|
||||
for sc.Scan() {
|
||||
ln := sc.Text()
|
||||
@@ -44,14 +44,24 @@ func parseArpA(logger *slog.Logger, sc *bufio.Scanner, lenHint int) (ns []Neighb
|
||||
continue
|
||||
}
|
||||
|
||||
n, err := newNeighbor("", fields[0], fields[1])
|
||||
ip, err := netip.ParseAddr(fields[0])
|
||||
if err != nil {
|
||||
logger.Debug("parsing arp output", "line", ln, slogutil.KeyError, err)
|
||||
log.Debug("arpdb: parsing arp output: ip: %s", err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
ns = append(ns, *n)
|
||||
mac, err := net.ParseMAC(fields[1])
|
||||
if err != nil {
|
||||
log.Debug("arpdb: parsing arp output: mac: %s", err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
ns = append(ns, Neighbor{
|
||||
IP: ip,
|
||||
MAC: mac,
|
||||
})
|
||||
}
|
||||
|
||||
return ns
|
||||
|
||||
@@ -2,7 +2,6 @@ package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -12,7 +11,6 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
)
|
||||
|
||||
@@ -40,10 +38,6 @@ func (EmptyAddrProc) Close() (_ error) { return nil }
|
||||
|
||||
// DefaultAddrProcConfig is the configuration structure for address processors.
|
||||
type DefaultAddrProcConfig struct {
|
||||
// BaseLogger is used to create loggers with custom prefixes for sources of
|
||||
// information about runtime clients. It must not be nil.
|
||||
BaseLogger *slog.Logger
|
||||
|
||||
// DialContext is used to create TCP connections to WHOIS servers.
|
||||
// DialContext must not be nil if [DefaultAddrProcConfig.UseWHOIS] is true.
|
||||
DialContext aghnet.DialContextFunc
|
||||
@@ -153,7 +147,6 @@ func NewDefaultAddrProc(c *DefaultAddrProcConfig) (p *DefaultAddrProc) {
|
||||
|
||||
if c.UseRDNS {
|
||||
p.rdns = rdns.New(&rdns.Config{
|
||||
Logger: c.BaseLogger.With(slogutil.KeyPrefix, "rdns"),
|
||||
Exchanger: c.Exchanger,
|
||||
CacheSize: defaultCacheSize,
|
||||
CacheTTL: defaultIPTTL,
|
||||
@@ -161,7 +154,7 @@ func NewDefaultAddrProc(c *DefaultAddrProcConfig) (p *DefaultAddrProc) {
|
||||
}
|
||||
|
||||
if c.UseWHOIS {
|
||||
p.whois = newWHOIS(c.BaseLogger.With(slogutil.KeyPrefix, "whois"), c.DialContext)
|
||||
p.whois = newWHOIS(c.DialContext)
|
||||
}
|
||||
|
||||
go p.process(c.CatchPanics)
|
||||
@@ -175,7 +168,7 @@ func NewDefaultAddrProc(c *DefaultAddrProcConfig) (p *DefaultAddrProc) {
|
||||
|
||||
// newWHOIS returns a whois.Interface instance using the given function for
|
||||
// dialing.
|
||||
func newWHOIS(logger *slog.Logger, dialFunc aghnet.DialContextFunc) (w whois.Interface) {
|
||||
func newWHOIS(dialFunc aghnet.DialContextFunc) (w whois.Interface) {
|
||||
// TODO(s.chzhen): Consider making configurable.
|
||||
const (
|
||||
// defaultTimeout is the timeout for WHOIS requests.
|
||||
@@ -193,7 +186,6 @@ func newWHOIS(logger *slog.Logger, dialFunc aghnet.DialContextFunc) (w whois.Int
|
||||
)
|
||||
|
||||
return whois.New(&whois.Config{
|
||||
Logger: logger,
|
||||
DialContext: dialFunc,
|
||||
ServerAddr: whois.DefaultServer,
|
||||
Port: whois.DefaultPort,
|
||||
@@ -235,11 +227,9 @@ func (p *DefaultAddrProc) process(catchPanics bool) {
|
||||
|
||||
log.Info("clients: processing addresses")
|
||||
|
||||
ctx := context.TODO()
|
||||
|
||||
for ip := range p.clientIPs {
|
||||
host := p.processRDNS(ctx, ip)
|
||||
info := p.processWHOIS(ctx, ip)
|
||||
host := p.processRDNS(ip)
|
||||
info := p.processWHOIS(ip)
|
||||
|
||||
p.addrUpdater.UpdateAddress(ip, host, info)
|
||||
}
|
||||
@@ -249,7 +239,7 @@ func (p *DefaultAddrProc) process(catchPanics bool) {
|
||||
|
||||
// processRDNS resolves the clients' IP addresses using reverse DNS. host is
|
||||
// empty if there were errors or if the information hasn't changed.
|
||||
func (p *DefaultAddrProc) processRDNS(ctx context.Context, ip netip.Addr) (host string) {
|
||||
func (p *DefaultAddrProc) processRDNS(ip netip.Addr) (host string) {
|
||||
start := time.Now()
|
||||
log.Debug("clients: processing %s with rdns", ip)
|
||||
defer func() {
|
||||
@@ -261,7 +251,7 @@ func (p *DefaultAddrProc) processRDNS(ctx context.Context, ip netip.Addr) (host
|
||||
return
|
||||
}
|
||||
|
||||
host, changed := p.rdns.Process(ctx, ip)
|
||||
host, changed := p.rdns.Process(ip)
|
||||
if !changed {
|
||||
host = ""
|
||||
}
|
||||
@@ -278,7 +268,7 @@ func (p *DefaultAddrProc) shouldResolve(ip netip.Addr) (ok bool) {
|
||||
// processWHOIS looks up the information about clients' IP addresses in the
|
||||
// WHOIS databases. info is nil if there were errors or if the information
|
||||
// hasn't changed.
|
||||
func (p *DefaultAddrProc) processWHOIS(ctx context.Context, ip netip.Addr) (info *whois.Info) {
|
||||
func (p *DefaultAddrProc) processWHOIS(ip netip.Addr) (info *whois.Info) {
|
||||
start := time.Now()
|
||||
log.Debug("clients: processing %s with whois", ip)
|
||||
defer func() {
|
||||
@@ -287,7 +277,7 @@ func (p *DefaultAddrProc) processWHOIS(ctx context.Context, ip netip.Addr) (info
|
||||
|
||||
// TODO(s.chzhen): Move the timeout logic from WHOIS configuration to the
|
||||
// context.
|
||||
info, changed := p.whois.Process(ctx, ip)
|
||||
info, changed := p.whois.Process(context.Background(), ip)
|
||||
if !changed {
|
||||
info = nil
|
||||
}
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/AdguardTeam/golibs/testutil/fakenet"
|
||||
@@ -100,7 +99,6 @@ func TestDefaultAddrProc_Process_rDNS(t *testing.T) {
|
||||
updInfoCh := make(chan *whois.Info, 1)
|
||||
|
||||
p := client.NewDefaultAddrProc(&client.DefaultAddrProcConfig{
|
||||
BaseLogger: slogutil.NewDiscardLogger(),
|
||||
DialContext: func(_ context.Context, _, _ string) (conn net.Conn, err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
@@ -210,7 +208,6 @@ func TestDefaultAddrProc_Process_WHOIS(t *testing.T) {
|
||||
updInfoCh := make(chan *whois.Info, 1)
|
||||
|
||||
p := client.NewDefaultAddrProc(&client.DefaultAddrProcConfig{
|
||||
BaseLogger: slogutil.NewDiscardLogger(),
|
||||
DialContext: func(_ context.Context, _, _ string) (conn net.Conn, err error) {
|
||||
return whoisConn, nil
|
||||
},
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"encoding"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||
)
|
||||
@@ -121,7 +120,6 @@ func (r *Runtime) Info() (cs Source, host string) {
|
||||
|
||||
// SetInfo sets a host as a client information from the cs.
|
||||
func (r *Runtime) SetInfo(cs Source, hosts []string) {
|
||||
// TODO(s.chzhen): Use contract where hosts must contain non-empty host.
|
||||
if len(hosts) == 1 && hosts[0] == "" {
|
||||
hosts = []string{}
|
||||
}
|
||||
@@ -177,15 +175,3 @@ func (r *Runtime) isEmpty() (ok bool) {
|
||||
func (r *Runtime) Addr() (ip netip.Addr) {
|
||||
return r.ip
|
||||
}
|
||||
|
||||
// Clone returns a deep copy of the runtime client.
|
||||
func (r *Runtime) Clone() (c *Runtime) {
|
||||
return &Runtime{
|
||||
ip: r.ip,
|
||||
whois: r.whois.Clone(),
|
||||
arp: slices.Clone(r.arp),
|
||||
rdns: slices.Clone(r.rdns),
|
||||
dhcp: slices.Clone(r.dhcp),
|
||||
hostsFile: slices.Clone(r.hostsFile),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"sync"
|
||||
|
||||
"github.com/AdguardTeam/golibs/container"
|
||||
@@ -32,6 +31,8 @@ type Storage struct {
|
||||
index *index
|
||||
|
||||
// runtimeIndex contains information about runtime clients.
|
||||
//
|
||||
// TODO(s.chzhen): Use it.
|
||||
runtimeIndex *RuntimeIndex
|
||||
}
|
||||
|
||||
@@ -235,75 +236,20 @@ func (s *Storage) ClientRuntime(ip netip.Addr) (rc *Runtime) {
|
||||
return s.runtimeIndex.Client(ip)
|
||||
}
|
||||
|
||||
// UpdateRuntime updates the stored runtime client with information from rc. If
|
||||
// no such client exists, saves the copy of rc in storage. rc must not be nil.
|
||||
func (s *Storage) UpdateRuntime(rc *Runtime) (added bool) {
|
||||
// AddRuntime saves the runtime client information in the storage. IP address
|
||||
// of a client must be unique. rc must not be nil.
|
||||
//
|
||||
// TODO(s.chzhen): Use it.
|
||||
func (s *Storage) AddRuntime(rc *Runtime) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
return s.updateRuntimeLocked(rc)
|
||||
}
|
||||
|
||||
// updateRuntimeLocked updates the stored runtime client with information from
|
||||
// rc. rc must not be nil. Storage.mu is expected to be locked.
|
||||
func (s *Storage) updateRuntimeLocked(rc *Runtime) (added bool) {
|
||||
stored := s.runtimeIndex.Client(rc.ip)
|
||||
if stored == nil {
|
||||
s.runtimeIndex.Add(rc.Clone())
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
if rc.whois != nil {
|
||||
stored.whois = rc.whois.Clone()
|
||||
}
|
||||
|
||||
if rc.arp != nil {
|
||||
stored.arp = slices.Clone(rc.arp)
|
||||
}
|
||||
|
||||
if rc.rdns != nil {
|
||||
stored.rdns = slices.Clone(rc.rdns)
|
||||
}
|
||||
|
||||
if rc.dhcp != nil {
|
||||
stored.dhcp = slices.Clone(rc.dhcp)
|
||||
}
|
||||
|
||||
if rc.hostsFile != nil {
|
||||
stored.hostsFile = slices.Clone(rc.hostsFile)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// BatchUpdateBySource updates the stored runtime clients information from the
|
||||
// specified source and returns the number of added and removed clients.
|
||||
func (s *Storage) BatchUpdateBySource(src Source, rcs []*Runtime) (added, removed int) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
for _, rc := range s.runtimeIndex.index {
|
||||
rc.unset(src)
|
||||
}
|
||||
|
||||
for _, rc := range rcs {
|
||||
if s.updateRuntimeLocked(rc) {
|
||||
added++
|
||||
}
|
||||
}
|
||||
|
||||
for ip, rc := range s.runtimeIndex.index {
|
||||
if rc.isEmpty() {
|
||||
delete(s.runtimeIndex.index, ip)
|
||||
removed++
|
||||
}
|
||||
}
|
||||
|
||||
return added, removed
|
||||
s.runtimeIndex.Add(rc)
|
||||
}
|
||||
|
||||
// SizeRuntime returns the number of the runtime clients.
|
||||
//
|
||||
// TODO(s.chzhen): Use it.
|
||||
func (s *Storage) SizeRuntime() (n int) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
@@ -312,6 +258,8 @@ func (s *Storage) SizeRuntime() (n int) {
|
||||
}
|
||||
|
||||
// RangeRuntime calls f for each runtime client in an undefined order.
|
||||
//
|
||||
// TODO(s.chzhen): Use it.
|
||||
func (s *Storage) RangeRuntime(f func(rc *Runtime) (cont bool)) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
@@ -319,6 +267,16 @@ func (s *Storage) RangeRuntime(f func(rc *Runtime) (cont bool)) {
|
||||
s.runtimeIndex.Range(f)
|
||||
}
|
||||
|
||||
// DeleteRuntime removes the runtime client by ip.
|
||||
//
|
||||
// TODO(s.chzhen): Use it.
|
||||
func (s *Storage) DeleteRuntime(ip netip.Addr) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.runtimeIndex.Delete(ip)
|
||||
}
|
||||
|
||||
// DeleteBySource removes all runtime clients that have information only from
|
||||
// the specified source and returns the number of removed clients.
|
||||
//
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -26,19 +25,9 @@ func newStorage(tb testing.TB, m []*client.Persistent) (s *client.Storage) {
|
||||
require.NoError(tb, s.Add(c))
|
||||
}
|
||||
|
||||
require.Equal(tb, len(m), s.Size())
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// newRuntimeClient is a helper function that returns a new runtime client.
|
||||
func newRuntimeClient(ip netip.Addr, source client.Source, host string) (rc *client.Runtime) {
|
||||
rc = client.NewRuntime(ip)
|
||||
rc.SetInfo(source, []string{host})
|
||||
|
||||
return rc
|
||||
}
|
||||
|
||||
// mustParseMAC is wrapper around [net.ParseMAC] that panics if there is an
|
||||
// error.
|
||||
func mustParseMAC(s string) (mac net.HardwareAddr) {
|
||||
@@ -54,9 +43,6 @@ func TestStorage_Add(t *testing.T) {
|
||||
const (
|
||||
existingName = "existing_name"
|
||||
existingClientID = "existing_client_id"
|
||||
|
||||
allowedTag = "tag"
|
||||
notAllowedTag = "not_allowed_tag"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -74,7 +60,7 @@ func TestStorage_Add(t *testing.T) {
|
||||
}
|
||||
|
||||
s := client.NewStorage(&client.Config{
|
||||
AllowedTags: []string{allowedTag},
|
||||
AllowedTags: nil,
|
||||
})
|
||||
err := s.Add(existingClient)
|
||||
require.NoError(t, err)
|
||||
@@ -133,15 +119,6 @@ func TestStorage_Add(t *testing.T) {
|
||||
},
|
||||
wantErrMsg: `adding client: another client "existing_name" ` +
|
||||
`uses the same ClientID "existing_client_id"`,
|
||||
}, {
|
||||
name: "not_allowed_tag",
|
||||
cli: &client.Persistent{
|
||||
Name: "nont_allowed_tag",
|
||||
Tags: []string{notAllowedTag},
|
||||
IPs: []netip.Addr{netip.MustParseAddr("4.4.4.4")},
|
||||
UID: client.MustNewUID(),
|
||||
},
|
||||
wantErrMsg: `adding client: invalid tag: "not_allowed_tag"`,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
@@ -364,127 +341,6 @@ func TestStorage_FindLoose(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestStorage_FindByName(t *testing.T) {
|
||||
const (
|
||||
cliIP1 = "1.1.1.1"
|
||||
cliIP2 = "2.2.2.2"
|
||||
)
|
||||
|
||||
const (
|
||||
clientExistingName = "client_existing"
|
||||
clientAnotherExistingName = "client_another_existing"
|
||||
nonExistingClientName = "client_non_existing"
|
||||
)
|
||||
|
||||
var (
|
||||
clientExisting = &client.Persistent{
|
||||
Name: clientExistingName,
|
||||
IPs: []netip.Addr{netip.MustParseAddr(cliIP1)},
|
||||
}
|
||||
|
||||
clientAnotherExisting = &client.Persistent{
|
||||
Name: clientAnotherExistingName,
|
||||
IPs: []netip.Addr{netip.MustParseAddr(cliIP2)},
|
||||
}
|
||||
)
|
||||
|
||||
clients := []*client.Persistent{
|
||||
clientExisting,
|
||||
clientAnotherExisting,
|
||||
}
|
||||
s := newStorage(t, clients)
|
||||
|
||||
testCases := []struct {
|
||||
want *client.Persistent
|
||||
name string
|
||||
clientName string
|
||||
}{{
|
||||
name: "existing",
|
||||
clientName: clientExistingName,
|
||||
want: clientExisting,
|
||||
}, {
|
||||
name: "another_existing",
|
||||
clientName: clientAnotherExistingName,
|
||||
want: clientAnotherExisting,
|
||||
}, {
|
||||
name: "non_existing",
|
||||
clientName: nonExistingClientName,
|
||||
want: nil,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
c, ok := s.FindByName(tc.clientName)
|
||||
if tc.want == nil {
|
||||
assert.False(t, ok)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, tc.want, c)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStorage_FindByMAC(t *testing.T) {
|
||||
var (
|
||||
cliMAC = mustParseMAC("11:11:11:11:11:11")
|
||||
cliAnotherMAC = mustParseMAC("22:22:22:22:22:22")
|
||||
nonExistingClientMAC = mustParseMAC("33:33:33:33:33:33")
|
||||
)
|
||||
|
||||
var (
|
||||
clientExisting = &client.Persistent{
|
||||
Name: "client",
|
||||
MACs: []net.HardwareAddr{cliMAC},
|
||||
}
|
||||
|
||||
clientAnotherExisting = &client.Persistent{
|
||||
Name: "another_client",
|
||||
MACs: []net.HardwareAddr{cliAnotherMAC},
|
||||
}
|
||||
)
|
||||
|
||||
clients := []*client.Persistent{
|
||||
clientExisting,
|
||||
clientAnotherExisting,
|
||||
}
|
||||
s := newStorage(t, clients)
|
||||
|
||||
testCases := []struct {
|
||||
want *client.Persistent
|
||||
name string
|
||||
clientMAC net.HardwareAddr
|
||||
}{{
|
||||
name: "existing",
|
||||
clientMAC: cliMAC,
|
||||
want: clientExisting,
|
||||
}, {
|
||||
name: "another_existing",
|
||||
clientMAC: cliAnotherMAC,
|
||||
want: clientAnotherExisting,
|
||||
}, {
|
||||
name: "non_existing",
|
||||
clientMAC: nonExistingClientMAC,
|
||||
want: nil,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
c, ok := s.FindByMAC(tc.clientMAC)
|
||||
if tc.want == nil {
|
||||
assert.False(t, ok)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, tc.want, c)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStorage_Update(t *testing.T) {
|
||||
const (
|
||||
clientName = "client_name"
|
||||
@@ -623,157 +479,3 @@ func TestStorage_RangeByName(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStorage_UpdateRuntime(t *testing.T) {
|
||||
const (
|
||||
addedARP = "added_arp"
|
||||
addedSecondARP = "added_arp"
|
||||
|
||||
updatedARP = "updated_arp"
|
||||
|
||||
cliCity = "City"
|
||||
cliCountry = "Country"
|
||||
cliOrgname = "Orgname"
|
||||
)
|
||||
|
||||
var (
|
||||
ip = netip.MustParseAddr("1.1.1.1")
|
||||
ip2 = netip.MustParseAddr("2.2.2.2")
|
||||
)
|
||||
|
||||
updated := client.NewRuntime(ip)
|
||||
updated.SetInfo(client.SourceARP, []string{updatedARP})
|
||||
|
||||
info := &whois.Info{
|
||||
City: cliCity,
|
||||
Country: cliCountry,
|
||||
Orgname: cliOrgname,
|
||||
}
|
||||
updated.SetWHOIS(info)
|
||||
|
||||
s := client.NewStorage(&client.Config{
|
||||
AllowedTags: nil,
|
||||
})
|
||||
|
||||
t.Run("add_arp_client", func(t *testing.T) {
|
||||
added := client.NewRuntime(ip)
|
||||
added.SetInfo(client.SourceARP, []string{addedARP})
|
||||
|
||||
require.True(t, s.UpdateRuntime(added))
|
||||
require.Equal(t, 1, s.SizeRuntime())
|
||||
|
||||
got := s.ClientRuntime(ip)
|
||||
source, host := got.Info()
|
||||
assert.Equal(t, client.SourceARP, source)
|
||||
assert.Equal(t, addedARP, host)
|
||||
})
|
||||
|
||||
t.Run("add_second_arp_client", func(t *testing.T) {
|
||||
added := client.NewRuntime(ip2)
|
||||
added.SetInfo(client.SourceARP, []string{addedSecondARP})
|
||||
|
||||
require.True(t, s.UpdateRuntime(added))
|
||||
require.Equal(t, 2, s.SizeRuntime())
|
||||
|
||||
got := s.ClientRuntime(ip2)
|
||||
source, host := got.Info()
|
||||
assert.Equal(t, client.SourceARP, source)
|
||||
assert.Equal(t, addedSecondARP, host)
|
||||
})
|
||||
|
||||
t.Run("update_first_client", func(t *testing.T) {
|
||||
require.False(t, s.UpdateRuntime(updated))
|
||||
got := s.ClientRuntime(ip)
|
||||
require.Equal(t, 2, s.SizeRuntime())
|
||||
|
||||
source, host := got.Info()
|
||||
assert.Equal(t, client.SourceARP, source)
|
||||
assert.Equal(t, updatedARP, host)
|
||||
})
|
||||
|
||||
t.Run("remove_arp_info", func(t *testing.T) {
|
||||
n := s.DeleteBySource(client.SourceARP)
|
||||
require.Equal(t, 1, n)
|
||||
require.Equal(t, 1, s.SizeRuntime())
|
||||
|
||||
got := s.ClientRuntime(ip)
|
||||
source, _ := got.Info()
|
||||
assert.Equal(t, client.SourceWHOIS, source)
|
||||
assert.Equal(t, info, got.WHOIS())
|
||||
})
|
||||
|
||||
t.Run("remove_whois_info", func(t *testing.T) {
|
||||
n := s.DeleteBySource(client.SourceWHOIS)
|
||||
require.Equal(t, 1, n)
|
||||
require.Equal(t, 0, s.SizeRuntime())
|
||||
})
|
||||
}
|
||||
|
||||
func TestStorage_BatchUpdateBySource(t *testing.T) {
|
||||
const (
|
||||
defSrc = client.SourceARP
|
||||
|
||||
cliFirstHost1 = "host1"
|
||||
cliFirstHost2 = "host2"
|
||||
cliUpdatedHost3 = "host3"
|
||||
cliUpdatedHost4 = "host4"
|
||||
cliUpdatedHost5 = "host5"
|
||||
)
|
||||
|
||||
var (
|
||||
cliFirstIP1 = netip.MustParseAddr("1.1.1.1")
|
||||
cliFirstIP2 = netip.MustParseAddr("2.2.2.2")
|
||||
cliUpdatedIP3 = netip.MustParseAddr("3.3.3.3")
|
||||
cliUpdatedIP4 = netip.MustParseAddr("4.4.4.4")
|
||||
cliUpdatedIP5 = netip.MustParseAddr("5.5.5.5")
|
||||
)
|
||||
|
||||
firstClients := []*client.Runtime{
|
||||
newRuntimeClient(cliFirstIP1, defSrc, cliFirstHost1),
|
||||
newRuntimeClient(cliFirstIP2, defSrc, cliFirstHost2),
|
||||
}
|
||||
|
||||
updatedClients := []*client.Runtime{
|
||||
newRuntimeClient(cliUpdatedIP3, defSrc, cliUpdatedHost3),
|
||||
newRuntimeClient(cliUpdatedIP4, defSrc, cliUpdatedHost4),
|
||||
newRuntimeClient(cliUpdatedIP5, defSrc, cliUpdatedHost5),
|
||||
}
|
||||
|
||||
s := client.NewStorage(&client.Config{
|
||||
AllowedTags: nil,
|
||||
})
|
||||
|
||||
t.Run("populate_storage_with_first_clients", func(t *testing.T) {
|
||||
added, removed := s.BatchUpdateBySource(defSrc, firstClients)
|
||||
require.Equal(t, len(firstClients), added)
|
||||
require.Equal(t, 0, removed)
|
||||
require.Equal(t, len(firstClients), s.SizeRuntime())
|
||||
|
||||
rc := s.ClientRuntime(cliFirstIP1)
|
||||
src, host := rc.Info()
|
||||
assert.Equal(t, defSrc, src)
|
||||
assert.Equal(t, cliFirstHost1, host)
|
||||
})
|
||||
|
||||
t.Run("update_storage", func(t *testing.T) {
|
||||
added, removed := s.BatchUpdateBySource(defSrc, updatedClients)
|
||||
require.Equal(t, len(updatedClients), added)
|
||||
require.Equal(t, len(firstClients), removed)
|
||||
require.Equal(t, len(updatedClients), s.SizeRuntime())
|
||||
|
||||
rc := s.ClientRuntime(cliUpdatedIP3)
|
||||
src, host := rc.Info()
|
||||
assert.Equal(t, defSrc, src)
|
||||
assert.Equal(t, cliUpdatedHost3, host)
|
||||
|
||||
rc = s.ClientRuntime(cliFirstIP1)
|
||||
assert.Nil(t, rc)
|
||||
})
|
||||
|
||||
t.Run("remove_all", func(t *testing.T) {
|
||||
added, removed := s.BatchUpdateBySource(defSrc, []*client.Runtime{})
|
||||
require.Equal(t, 0, added)
|
||||
require.Equal(t, len(updatedClients), removed)
|
||||
require.Equal(t, 0, s.SizeRuntime())
|
||||
})
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package dhcpsvc
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
@@ -24,8 +23,7 @@ type Config struct {
|
||||
// clients' hostnames.
|
||||
LocalDomainName string
|
||||
|
||||
// DBFilePath is the path to the database file containing the DHCP leases.
|
||||
DBFilePath string
|
||||
// TODO(e.burkov): Add DB path.
|
||||
|
||||
// ICMPTimeout is the timeout for checking another DHCP server's presence.
|
||||
ICMPTimeout time.Duration
|
||||
@@ -66,12 +64,6 @@ func (conf *Config) Validate() (err error) {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
|
||||
// This is a best-effort check for the file accessibility. The file will be
|
||||
// checked again when it is opened later.
|
||||
if _, err = os.Stat(conf.DBFilePath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
errs = append(errs, fmt.Errorf("db file path %q: %w", conf.DBFilePath, err))
|
||||
}
|
||||
|
||||
if len(conf.Interfaces) == 0 {
|
||||
errs = append(errs, errNoInterfaces)
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package dhcpsvc_test
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
|
||||
@@ -9,8 +8,6 @@ import (
|
||||
)
|
||||
|
||||
func TestConfig_Validate(t *testing.T) {
|
||||
leasesPath := filepath.Join(t.TempDir(), "leases.json")
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
conf *dhcpsvc.Config
|
||||
@@ -28,7 +25,6 @@ func TestConfig_Validate(t *testing.T) {
|
||||
conf: &dhcpsvc.Config{
|
||||
Enabled: true,
|
||||
Interfaces: testInterfaceConf,
|
||||
DBFilePath: leasesPath,
|
||||
},
|
||||
wantErrMsg: `bad domain name "": domain name is empty`,
|
||||
}, {
|
||||
@@ -36,7 +32,6 @@ func TestConfig_Validate(t *testing.T) {
|
||||
Enabled: true,
|
||||
LocalDomainName: testLocalTLD,
|
||||
Interfaces: nil,
|
||||
DBFilePath: leasesPath,
|
||||
},
|
||||
name: "no_interfaces",
|
||||
wantErrMsg: "no interfaces specified",
|
||||
@@ -45,7 +40,6 @@ func TestConfig_Validate(t *testing.T) {
|
||||
Enabled: true,
|
||||
LocalDomainName: testLocalTLD,
|
||||
Interfaces: nil,
|
||||
DBFilePath: leasesPath,
|
||||
},
|
||||
name: "no_interfaces",
|
||||
wantErrMsg: "no interfaces specified",
|
||||
@@ -56,7 +50,6 @@ func TestConfig_Validate(t *testing.T) {
|
||||
Interfaces: map[string]*dhcpsvc.InterfaceConfig{
|
||||
"eth0": nil,
|
||||
},
|
||||
DBFilePath: leasesPath,
|
||||
},
|
||||
name: "nil_interface",
|
||||
wantErrMsg: `interface "eth0": config is nil`,
|
||||
@@ -70,7 +63,6 @@ func TestConfig_Validate(t *testing.T) {
|
||||
IPv6: &dhcpsvc.IPv6Config{Enabled: false},
|
||||
},
|
||||
},
|
||||
DBFilePath: leasesPath,
|
||||
},
|
||||
name: "nil_ipv4",
|
||||
wantErrMsg: `interface "eth0": ipv4: config is nil`,
|
||||
@@ -84,7 +76,6 @@ func TestConfig_Validate(t *testing.T) {
|
||||
IPv6: nil,
|
||||
},
|
||||
},
|
||||
DBFilePath: leasesPath,
|
||||
},
|
||||
name: "nil_ipv6",
|
||||
wantErrMsg: `interface "eth0": ipv6: config is nil`,
|
||||
|
||||
@@ -1,195 +0,0 @@
|
||||
package dhcpsvc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/google/renameio/v2/maybe"
|
||||
)
|
||||
|
||||
// dataVersion is the current version of the stored DHCP leases structure.
|
||||
const dataVersion = 1
|
||||
|
||||
// databasePerm is the permissions for the database file.
|
||||
const databasePerm fs.FileMode = 0o640
|
||||
|
||||
// dataLeases is the structure of the stored DHCP leases.
|
||||
type dataLeases struct {
|
||||
// Leases is the list containing stored DHCP leases.
|
||||
Leases []*dbLease `json:"leases"`
|
||||
|
||||
// Version is the current version of the structure.
|
||||
Version int `json:"version"`
|
||||
}
|
||||
|
||||
// dbLease is the structure of stored lease.
|
||||
type dbLease struct {
|
||||
Expiry string `json:"expires"`
|
||||
IP netip.Addr `json:"ip"`
|
||||
Hostname string `json:"hostname"`
|
||||
HWAddr string `json:"mac"`
|
||||
IsStatic bool `json:"static"`
|
||||
}
|
||||
|
||||
// compareNames returns the result of comparing the hostnames of dl and other
|
||||
// lexicographically.
|
||||
func (dl *dbLease) compareNames(other *dbLease) (res int) {
|
||||
return strings.Compare(dl.Hostname, other.Hostname)
|
||||
}
|
||||
|
||||
// toDBLease converts *Lease to *dbLease.
|
||||
func toDBLease(l *Lease) (dl *dbLease) {
|
||||
var expiryStr string
|
||||
if !l.IsStatic {
|
||||
// The front-end is waiting for RFC 3999 format of the time value. It
|
||||
// also shouldn't got an Expiry field for static leases.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/2692.
|
||||
expiryStr = l.Expiry.Format(time.RFC3339)
|
||||
}
|
||||
|
||||
return &dbLease{
|
||||
Expiry: expiryStr,
|
||||
Hostname: l.Hostname,
|
||||
HWAddr: l.HWAddr.String(),
|
||||
IP: l.IP,
|
||||
IsStatic: l.IsStatic,
|
||||
}
|
||||
}
|
||||
|
||||
// toInternal converts dl to *Lease.
|
||||
func (dl *dbLease) toInternal() (l *Lease, err error) {
|
||||
mac, err := net.ParseMAC(dl.HWAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing hardware address: %w", err)
|
||||
}
|
||||
|
||||
expiry := time.Time{}
|
||||
if !dl.IsStatic {
|
||||
expiry, err = time.Parse(time.RFC3339, dl.Expiry)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing expiry time: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return &Lease{
|
||||
Expiry: expiry,
|
||||
IP: dl.IP,
|
||||
Hostname: dl.Hostname,
|
||||
HWAddr: mac,
|
||||
IsStatic: dl.IsStatic,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// dbLoad loads stored leases. It must only be called before the service has
|
||||
// been started.
|
||||
func (srv *DHCPServer) dbLoad(ctx context.Context) (err error) {
|
||||
defer func() { err = errors.Annotate(err, "loading db: %w") }()
|
||||
|
||||
file, err := os.Open(srv.dbFilePath)
|
||||
if err != nil {
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
return fmt.Errorf("reading db: %w", err)
|
||||
}
|
||||
|
||||
srv.logger.DebugContext(ctx, "no db file found")
|
||||
|
||||
return nil
|
||||
}
|
||||
defer func() {
|
||||
err = errors.WithDeferred(err, file.Close())
|
||||
}()
|
||||
|
||||
dl := &dataLeases{}
|
||||
err = json.NewDecoder(file).Decode(dl)
|
||||
if err != nil {
|
||||
return fmt.Errorf("decoding db: %w", err)
|
||||
}
|
||||
|
||||
srv.resetLeases()
|
||||
srv.addDBLeases(ctx, dl.Leases)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addDBLeases adds leases to the server.
|
||||
func (srv *DHCPServer) addDBLeases(ctx context.Context, leases []*dbLease) {
|
||||
var v4, v6 uint
|
||||
for i, l := range leases {
|
||||
lease, err := l.toInternal()
|
||||
if err != nil {
|
||||
srv.logger.WarnContext(ctx, "converting lease", "idx", i, slogutil.KeyError, err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
iface, err := srv.ifaceForAddr(l.IP)
|
||||
if err != nil {
|
||||
srv.logger.WarnContext(ctx, "searching lease iface", "idx", i, slogutil.KeyError, err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
err = srv.leases.add(lease, iface)
|
||||
if err != nil {
|
||||
srv.logger.WarnContext(ctx, "adding lease", "idx", i, slogutil.KeyError, err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if l.IP.Is4() {
|
||||
v4++
|
||||
} else {
|
||||
v6++
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(e.burkov): Group by interface.
|
||||
srv.logger.InfoContext(ctx, "loaded leases", "v4", v4, "v6", v6, "total", len(leases))
|
||||
}
|
||||
|
||||
// writeDB writes leases to the database file. It expects the
|
||||
// [DHCPServer.leasesMu] to be locked.
|
||||
func (srv *DHCPServer) dbStore(ctx context.Context) (err error) {
|
||||
defer func() { err = errors.Annotate(err, "writing db: %w") }()
|
||||
|
||||
dl := &dataLeases{
|
||||
// Avoid writing null into the database file if there are no leases.
|
||||
Leases: make([]*dbLease, 0, srv.leases.len()),
|
||||
Version: dataVersion,
|
||||
}
|
||||
|
||||
srv.leases.rangeLeases(func(l *Lease) (cont bool) {
|
||||
lease := toDBLease(l)
|
||||
i, _ := slices.BinarySearchFunc(dl.Leases, lease, (*dbLease).compareNames)
|
||||
dl.Leases = slices.Insert(dl.Leases, i, lease)
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
buf, err := json.Marshal(dl)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
err = maybe.WriteFile(srv.dbFilePath, buf, databasePerm)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
srv.logger.InfoContext(ctx, "stored leases", "num", len(dl.Leases), "file", srv.dbFilePath)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,4 +0,0 @@
|
||||
package dhcpsvc
|
||||
|
||||
// DatabasePerm is the permissions for the test database file.
|
||||
const DatabasePerm = databasePerm
|
||||
@@ -50,7 +50,7 @@ type Interface interface {
|
||||
IPByHost(host string) (ip netip.Addr)
|
||||
|
||||
// Leases returns all the active DHCP leases. The returned slice should be
|
||||
// a clone. The order of leases is undefined.
|
||||
// a clone.
|
||||
//
|
||||
// TODO(e.burkov): Consider implementing iterating methods with appropriate
|
||||
// signatures instead of cloning the whole list.
|
||||
|
||||
@@ -1,66 +0,0 @@
|
||||
package dhcpsvc_test
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// testLocalTLD is a common local TLD for tests.
|
||||
const testLocalTLD = "local"
|
||||
|
||||
// testTimeout is a common timeout for tests and contexts.
|
||||
const testTimeout time.Duration = 10 * time.Second
|
||||
|
||||
// discardLog is a logger to discard test output.
|
||||
var discardLog = slogutil.NewDiscardLogger()
|
||||
|
||||
// testInterfaceConf is a common set of interface configurations for tests.
|
||||
var testInterfaceConf = map[string]*dhcpsvc.InterfaceConfig{
|
||||
"eth0": {
|
||||
IPv4: &dhcpsvc.IPv4Config{
|
||||
Enabled: true,
|
||||
GatewayIP: netip.MustParseAddr("192.168.0.1"),
|
||||
SubnetMask: netip.MustParseAddr("255.255.255.0"),
|
||||
RangeStart: netip.MustParseAddr("192.168.0.2"),
|
||||
RangeEnd: netip.MustParseAddr("192.168.0.254"),
|
||||
LeaseDuration: 1 * time.Hour,
|
||||
},
|
||||
IPv6: &dhcpsvc.IPv6Config{
|
||||
Enabled: true,
|
||||
RangeStart: netip.MustParseAddr("2001:db8::1"),
|
||||
LeaseDuration: 1 * time.Hour,
|
||||
RAAllowSLAAC: true,
|
||||
RASLAACOnly: true,
|
||||
},
|
||||
},
|
||||
"eth1": {
|
||||
IPv4: &dhcpsvc.IPv4Config{
|
||||
Enabled: true,
|
||||
GatewayIP: netip.MustParseAddr("172.16.0.1"),
|
||||
SubnetMask: netip.MustParseAddr("255.255.255.0"),
|
||||
RangeStart: netip.MustParseAddr("172.16.0.2"),
|
||||
RangeEnd: netip.MustParseAddr("172.16.0.255"),
|
||||
LeaseDuration: 1 * time.Hour,
|
||||
},
|
||||
IPv6: &dhcpsvc.IPv6Config{
|
||||
Enabled: true,
|
||||
RangeStart: netip.MustParseAddr("2001:db9::1"),
|
||||
LeaseDuration: 1 * time.Hour,
|
||||
RAAllowSLAAC: true,
|
||||
RASLAACOnly: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// mustParseMAC parses a hardware address from s and requires no errors.
|
||||
func mustParseMAC(t require.TestingT, s string) (mac net.HardwareAddr) {
|
||||
mac, err := net.ParseMAC(s)
|
||||
require.NoError(t, err)
|
||||
|
||||
return mac
|
||||
}
|
||||
@@ -3,74 +3,42 @@ package dhcpsvc
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"slices"
|
||||
"time"
|
||||
)
|
||||
|
||||
// macKey contains hardware address as byte array of 6, 8, or 20 bytes.
|
||||
//
|
||||
// TODO(e.burkov): Move to aghnet or even to netutil.
|
||||
type macKey any
|
||||
|
||||
// macToKey converts mac into macKey, which is used as the key for the lease
|
||||
// maps. mac must be a valid hardware address of length 6, 8, or 20 bytes, see
|
||||
// [netutil.ValidateMAC].
|
||||
func macToKey(mac net.HardwareAddr) (key macKey) {
|
||||
switch len(mac) {
|
||||
case 6:
|
||||
return [6]byte(mac)
|
||||
case 8:
|
||||
return [8]byte(mac)
|
||||
case 20:
|
||||
return [20]byte(mac)
|
||||
default:
|
||||
panic(fmt.Errorf("invalid mac address %#v", mac))
|
||||
}
|
||||
}
|
||||
|
||||
// netInterface is a common part of any interface within the DHCP server.
|
||||
// netInterface is a common part of any network interface within the DHCP
|
||||
// server.
|
||||
//
|
||||
// TODO(e.burkov): Add other methods as [DHCPServer] evolves.
|
||||
type netInterface struct {
|
||||
// logger logs the events related to the network interface.
|
||||
logger *slog.Logger
|
||||
|
||||
// leases is the set of DHCP leases assigned to this interface.
|
||||
leases map[macKey]*Lease
|
||||
|
||||
// name is the name of the network interface.
|
||||
name string
|
||||
|
||||
// leases is a set of leases sorted by hardware address.
|
||||
leases []*Lease
|
||||
|
||||
// leaseTTL is the default Time-To-Live value for leases.
|
||||
leaseTTL time.Duration
|
||||
}
|
||||
|
||||
// newNetInterface creates a new netInterface with the given name, leaseTTL, and
|
||||
// logger.
|
||||
func newNetInterface(name string, l *slog.Logger, leaseTTL time.Duration) (iface *netInterface) {
|
||||
return &netInterface{
|
||||
logger: l,
|
||||
leases: map[macKey]*Lease{},
|
||||
name: name,
|
||||
leaseTTL: leaseTTL,
|
||||
}
|
||||
}
|
||||
|
||||
// reset clears all the slices in iface for reuse.
|
||||
func (iface *netInterface) reset() {
|
||||
clear(iface.leases)
|
||||
iface.leases = iface.leases[:0]
|
||||
}
|
||||
|
||||
// addLease inserts the given lease into iface. It returns an error if the
|
||||
// insertLease inserts the given lease into iface. It returns an error if the
|
||||
// lease can't be inserted.
|
||||
func (iface *netInterface) addLease(l *Lease) (err error) {
|
||||
mk := macToKey(l.HWAddr)
|
||||
_, found := iface.leases[mk]
|
||||
func (iface *netInterface) insertLease(l *Lease) (err error) {
|
||||
i, found := slices.BinarySearchFunc(iface.leases, l, compareLeaseMAC)
|
||||
if found {
|
||||
return fmt.Errorf("lease for mac %s already exists", l.HWAddr)
|
||||
}
|
||||
|
||||
iface.leases[mk] = l
|
||||
iface.leases = slices.Insert(iface.leases, i, l)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -78,13 +46,12 @@ func (iface *netInterface) addLease(l *Lease) (err error) {
|
||||
// updateLease replaces an existing lease within iface with the given one. It
|
||||
// returns an error if there is no lease with such hardware address.
|
||||
func (iface *netInterface) updateLease(l *Lease) (prev *Lease, err error) {
|
||||
mk := macToKey(l.HWAddr)
|
||||
prev, found := iface.leases[mk]
|
||||
i, found := slices.BinarySearchFunc(iface.leases, l, compareLeaseMAC)
|
||||
if !found {
|
||||
return nil, fmt.Errorf("no lease for mac %s", l.HWAddr)
|
||||
}
|
||||
|
||||
iface.leases[mk] = l
|
||||
prev, iface.leases[i] = iface.leases[i], l
|
||||
|
||||
return prev, nil
|
||||
}
|
||||
@@ -92,13 +59,12 @@ func (iface *netInterface) updateLease(l *Lease) (prev *Lease, err error) {
|
||||
// removeLease removes an existing lease from iface. It returns an error if
|
||||
// there is no lease equal to l.
|
||||
func (iface *netInterface) removeLease(l *Lease) (err error) {
|
||||
mk := macToKey(l.HWAddr)
|
||||
_, found := iface.leases[mk]
|
||||
i, found := slices.BinarySearchFunc(iface.leases, l, compareLeaseMAC)
|
||||
if !found {
|
||||
return fmt.Errorf("no lease for mac %s", l.HWAddr)
|
||||
}
|
||||
|
||||
delete(iface.leases, mk)
|
||||
iface.leases = slices.Delete(iface.leases, i, i+1)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package dhcpsvc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
@@ -44,3 +45,8 @@ func (l *Lease) Clone() (clone *Lease) {
|
||||
IsStatic: l.IsStatic,
|
||||
}
|
||||
}
|
||||
|
||||
// compareLeaseMAC compares two [Lease]s by hardware address.
|
||||
func compareLeaseMAC(a, b *Lease) (res int) {
|
||||
return bytes.Compare(a.HWAddr, b.HWAddr)
|
||||
}
|
||||
|
||||
@@ -61,7 +61,7 @@ func (idx *leaseIndex) add(l *Lease, iface *netInterface) (err error) {
|
||||
return fmt.Errorf("lease for hostname %s already exists", l.Hostname)
|
||||
}
|
||||
|
||||
err = iface.addLease(l)
|
||||
err = iface.insertLease(l)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -124,18 +124,3 @@ func (idx *leaseIndex) update(l *Lease, iface *netInterface) (err error) {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// rangeLeases calls f for each lease in idx in an unspecified order until f
|
||||
// returns false.
|
||||
func (idx *leaseIndex) rangeLeases(f func(l *Lease) (cont bool)) {
|
||||
for _, l := range idx.byName {
|
||||
if !f(l) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// len returns the number of leases in idx.
|
||||
func (idx *leaseIndex) len() (l uint) {
|
||||
return uint(len(idx.byAddr))
|
||||
}
|
||||
|
||||
@@ -27,13 +27,6 @@ type DHCPServer struct {
|
||||
// hostnames.
|
||||
localTLD string
|
||||
|
||||
// dbFilePath is the path to the database file containing the DHCP leases.
|
||||
//
|
||||
// TODO(e.burkov): Consider extracting the database logic into a separate
|
||||
// interface to prevent packages that only need lease data from depending on
|
||||
// the entire server and to simplify testing.
|
||||
dbFilePath string
|
||||
|
||||
// leasesMu protects the leases index as well as leases in the interfaces.
|
||||
leasesMu *sync.RWMutex
|
||||
|
||||
@@ -41,10 +34,10 @@ type DHCPServer struct {
|
||||
leases *leaseIndex
|
||||
|
||||
// interfaces4 is the set of IPv4 interfaces sorted by interface name.
|
||||
interfaces4 dhcpInterfacesV4
|
||||
interfaces4 netInterfacesV4
|
||||
|
||||
// interfaces6 is the set of IPv6 interfaces sorted by interface name.
|
||||
interfaces6 dhcpInterfacesV6
|
||||
interfaces6 netInterfacesV6
|
||||
|
||||
// icmpTimeout is the timeout for checking another DHCP server's presence.
|
||||
icmpTimeout time.Duration
|
||||
@@ -63,9 +56,28 @@ func New(ctx context.Context, conf *Config) (srv *DHCPServer, err error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
ifaces4, ifaces6, err := newInterfaces(ctx, l, conf.Interfaces)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
// TODO(e.burkov): Add validations scoped to the network interfaces set.
|
||||
ifaces4 := make(netInterfacesV4, 0, len(conf.Interfaces))
|
||||
ifaces6 := make(netInterfacesV6, 0, len(conf.Interfaces))
|
||||
var errs []error
|
||||
|
||||
mapsutil.SortedRange(conf.Interfaces, func(name string, iface *InterfaceConfig) (cont bool) {
|
||||
var i4 *netInterfaceV4
|
||||
i4, err = newNetInterfaceV4(ctx, l, name, iface.IPv4)
|
||||
if err != nil {
|
||||
errs = append(errs, fmt.Errorf("interface %q: ipv4: %w", name, err))
|
||||
} else if i4 != nil {
|
||||
ifaces4 = append(ifaces4, i4)
|
||||
}
|
||||
|
||||
i6 := newNetInterfaceV6(ctx, l, name, iface.IPv6)
|
||||
if i6 != nil {
|
||||
ifaces6 = append(ifaces6, i6)
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
if err = errors.Join(errs...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -81,55 +93,13 @@ func New(ctx context.Context, conf *Config) (srv *DHCPServer, err error) {
|
||||
interfaces4: ifaces4,
|
||||
interfaces6: ifaces6,
|
||||
icmpTimeout: conf.ICMPTimeout,
|
||||
dbFilePath: conf.DBFilePath,
|
||||
}
|
||||
|
||||
err = srv.dbLoad(ctx)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
// TODO(e.burkov): Load leases.
|
||||
|
||||
return srv, nil
|
||||
}
|
||||
|
||||
// newInterfaces creates interfaces for the given map of interface names to
|
||||
// their configurations.
|
||||
func newInterfaces(
|
||||
ctx context.Context,
|
||||
l *slog.Logger,
|
||||
ifaces map[string]*InterfaceConfig,
|
||||
) (v4 dhcpInterfacesV4, v6 dhcpInterfacesV6, err error) {
|
||||
defer func() { err = errors.Annotate(err, "creating interfaces: %w") }()
|
||||
|
||||
// TODO(e.burkov): Add validations scoped to the network interfaces set.
|
||||
v4 = make(dhcpInterfacesV4, 0, len(ifaces))
|
||||
v6 = make(dhcpInterfacesV6, 0, len(ifaces))
|
||||
|
||||
var errs []error
|
||||
mapsutil.SortedRange(ifaces, func(name string, iface *InterfaceConfig) (cont bool) {
|
||||
var i4 *dhcpInterfaceV4
|
||||
i4, err = newDHCPInterfaceV4(ctx, l, name, iface.IPv4)
|
||||
if err != nil {
|
||||
errs = append(errs, fmt.Errorf("interface %q: ipv4: %w", name, err))
|
||||
} else if i4 != nil {
|
||||
v4 = append(v4, i4)
|
||||
}
|
||||
|
||||
i6 := newDHCPInterfaceV6(ctx, l, name, iface.IPv6)
|
||||
if i6 != nil {
|
||||
v6 = append(v6, i6)
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
if err = errors.Join(errs...); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return v4, v6, nil
|
||||
}
|
||||
|
||||
// type check
|
||||
//
|
||||
// TODO(e.burkov): Uncomment when the [Interface] interface is implemented.
|
||||
@@ -145,11 +115,16 @@ func (srv *DHCPServer) Leases() (leases []*Lease) {
|
||||
srv.leasesMu.RLock()
|
||||
defer srv.leasesMu.RUnlock()
|
||||
|
||||
srv.leases.rangeLeases(func(l *Lease) (cont bool) {
|
||||
leases = append(leases, l.Clone())
|
||||
|
||||
return true
|
||||
})
|
||||
for _, iface := range srv.interfaces4 {
|
||||
for _, lease := range iface.leases {
|
||||
leases = append(leases, lease.Clone())
|
||||
}
|
||||
}
|
||||
for _, iface := range srv.interfaces6 {
|
||||
for _, lease := range iface.leases {
|
||||
leases = append(leases, lease.Clone())
|
||||
}
|
||||
}
|
||||
|
||||
return leases
|
||||
}
|
||||
@@ -192,35 +167,22 @@ func (srv *DHCPServer) IPByHost(host string) (ip netip.Addr) {
|
||||
|
||||
// Reset implements the [Interface] interface for *DHCPServer.
|
||||
func (srv *DHCPServer) Reset(ctx context.Context) (err error) {
|
||||
defer func() { err = errors.Annotate(err, "resetting leases: %w") }()
|
||||
|
||||
srv.leasesMu.Lock()
|
||||
defer srv.leasesMu.Unlock()
|
||||
|
||||
srv.resetLeases()
|
||||
err = srv.dbStore(ctx)
|
||||
if err != nil {
|
||||
// Don't wrap the error since there is already an annotation deferred.
|
||||
return err
|
||||
for _, iface := range srv.interfaces4 {
|
||||
iface.reset()
|
||||
}
|
||||
for _, iface := range srv.interfaces6 {
|
||||
iface.reset()
|
||||
}
|
||||
srv.leases.clear()
|
||||
|
||||
srv.logger.DebugContext(ctx, "reset leases")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// resetLeases resets the leases for all network interfaces of the server. It
|
||||
// expects the DHCPServer.leasesMu to be locked.
|
||||
func (srv *DHCPServer) resetLeases() {
|
||||
for _, iface := range srv.interfaces4 {
|
||||
iface.common.reset()
|
||||
}
|
||||
for _, iface := range srv.interfaces6 {
|
||||
iface.common.reset()
|
||||
}
|
||||
srv.leases.clear()
|
||||
}
|
||||
|
||||
// AddLease implements the [Interface] interface for *DHCPServer.
|
||||
func (srv *DHCPServer) AddLease(ctx context.Context, l *Lease) (err error) {
|
||||
defer func() { err = errors.Annotate(err, "adding lease: %w") }()
|
||||
@@ -228,7 +190,7 @@ func (srv *DHCPServer) AddLease(ctx context.Context, l *Lease) (err error) {
|
||||
addr := l.IP
|
||||
iface, err := srv.ifaceForAddr(addr)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's already informative enough as is.
|
||||
// Don't wrap the error since there is already an annotation deferred.
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -241,12 +203,6 @@ func (srv *DHCPServer) AddLease(ctx context.Context, l *Lease) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
err = srv.dbStore(ctx)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's already informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
iface.logger.DebugContext(
|
||||
ctx, "added lease",
|
||||
"hostname", l.Hostname,
|
||||
@@ -267,7 +223,7 @@ func (srv *DHCPServer) UpdateStaticLease(ctx context.Context, l *Lease) (err err
|
||||
addr := l.IP
|
||||
iface, err := srv.ifaceForAddr(addr)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's already informative enough as is.
|
||||
// Don't wrap the error since there is already an annotation deferred.
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -280,12 +236,6 @@ func (srv *DHCPServer) UpdateStaticLease(ctx context.Context, l *Lease) (err err
|
||||
return err
|
||||
}
|
||||
|
||||
err = srv.dbStore(ctx)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's already informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
iface.logger.DebugContext(
|
||||
ctx, "updated lease",
|
||||
"hostname", l.Hostname,
|
||||
@@ -304,7 +254,7 @@ func (srv *DHCPServer) RemoveLease(ctx context.Context, l *Lease) (err error) {
|
||||
addr := l.IP
|
||||
iface, err := srv.ifaceForAddr(addr)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's already informative enough as is.
|
||||
// Don't wrap the error since there is already an annotation deferred.
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -317,12 +267,6 @@ func (srv *DHCPServer) RemoveLease(ctx context.Context, l *Lease) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
err = srv.dbStore(ctx)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's already informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
iface.logger.DebugContext(
|
||||
ctx, "removed lease",
|
||||
"hostname", l.Hostname,
|
||||
|
||||
@@ -1,41 +1,72 @@
|
||||
package dhcpsvc_test
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// testdata is a filesystem containing data for tests.
|
||||
var testdata = os.DirFS("testdata")
|
||||
// testLocalTLD is a common local TLD for tests.
|
||||
const testLocalTLD = "local"
|
||||
|
||||
// newTempDB copies the leases database file located in the testdata FS, under
|
||||
// tb.Name()/leases.json, to a temporary directory and returns the path to the
|
||||
// copied file.
|
||||
func newTempDB(tb testing.TB) (dst string) {
|
||||
tb.Helper()
|
||||
// testTimeout is a common timeout for tests and contexts.
|
||||
const testTimeout time.Duration = 10 * time.Second
|
||||
|
||||
const filename = "leases.json"
|
||||
// discardLog is a logger to discard test output.
|
||||
var discardLog = slogutil.NewDiscardLogger()
|
||||
|
||||
data, err := fs.ReadFile(testdata, path.Join(tb.Name(), filename))
|
||||
require.NoError(tb, err)
|
||||
// testInterfaceConf is a common set of interface configurations for tests.
|
||||
var testInterfaceConf = map[string]*dhcpsvc.InterfaceConfig{
|
||||
"eth0": {
|
||||
IPv4: &dhcpsvc.IPv4Config{
|
||||
Enabled: true,
|
||||
GatewayIP: netip.MustParseAddr("192.168.0.1"),
|
||||
SubnetMask: netip.MustParseAddr("255.255.255.0"),
|
||||
RangeStart: netip.MustParseAddr("192.168.0.2"),
|
||||
RangeEnd: netip.MustParseAddr("192.168.0.254"),
|
||||
LeaseDuration: 1 * time.Hour,
|
||||
},
|
||||
IPv6: &dhcpsvc.IPv6Config{
|
||||
Enabled: true,
|
||||
RangeStart: netip.MustParseAddr("2001:db8::1"),
|
||||
LeaseDuration: 1 * time.Hour,
|
||||
RAAllowSLAAC: true,
|
||||
RASLAACOnly: true,
|
||||
},
|
||||
},
|
||||
"eth1": {
|
||||
IPv4: &dhcpsvc.IPv4Config{
|
||||
Enabled: true,
|
||||
GatewayIP: netip.MustParseAddr("172.16.0.1"),
|
||||
SubnetMask: netip.MustParseAddr("255.255.255.0"),
|
||||
RangeStart: netip.MustParseAddr("172.16.0.2"),
|
||||
RangeEnd: netip.MustParseAddr("172.16.0.255"),
|
||||
LeaseDuration: 1 * time.Hour,
|
||||
},
|
||||
IPv6: &dhcpsvc.IPv6Config{
|
||||
Enabled: true,
|
||||
RangeStart: netip.MustParseAddr("2001:db9::1"),
|
||||
LeaseDuration: 1 * time.Hour,
|
||||
RAAllowSLAAC: true,
|
||||
RASLAACOnly: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
dst = filepath.Join(tb.TempDir(), filename)
|
||||
// mustParseMAC parses a hardware address from s and requires no errors.
|
||||
func mustParseMAC(t require.TestingT, s string) (mac net.HardwareAddr) {
|
||||
mac, err := net.ParseMAC(s)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.WriteFile(dst, data, dhcpsvc.DatabasePerm)
|
||||
require.NoError(tb, err)
|
||||
|
||||
return dst
|
||||
return mac
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
@@ -72,8 +103,6 @@ func TestNew(t *testing.T) {
|
||||
RASLAACOnly: true,
|
||||
}
|
||||
|
||||
leasesPath := filepath.Join(t.TempDir(), "leases.json")
|
||||
|
||||
testCases := []struct {
|
||||
conf *dhcpsvc.Config
|
||||
name string
|
||||
@@ -89,7 +118,6 @@ func TestNew(t *testing.T) {
|
||||
IPv6: validIPv6Conf,
|
||||
},
|
||||
},
|
||||
DBFilePath: leasesPath,
|
||||
},
|
||||
name: "valid",
|
||||
wantErrMsg: "",
|
||||
@@ -104,7 +132,6 @@ func TestNew(t *testing.T) {
|
||||
IPv6: &dhcpsvc.IPv6Config{Enabled: false},
|
||||
},
|
||||
},
|
||||
DBFilePath: leasesPath,
|
||||
},
|
||||
name: "disabled_interfaces",
|
||||
wantErrMsg: "",
|
||||
@@ -119,10 +146,9 @@ func TestNew(t *testing.T) {
|
||||
IPv6: validIPv6Conf,
|
||||
},
|
||||
},
|
||||
DBFilePath: leasesPath,
|
||||
},
|
||||
name: "gateway_within_range",
|
||||
wantErrMsg: `creating interfaces: interface "eth0": ipv4: ` +
|
||||
wantErrMsg: `interface "eth0": ipv4: ` +
|
||||
`gateway ip 192.168.0.100 in the ip range 192.168.0.1-192.168.0.254`,
|
||||
}, {
|
||||
conf: &dhcpsvc.Config{
|
||||
@@ -135,10 +161,9 @@ func TestNew(t *testing.T) {
|
||||
IPv6: validIPv6Conf,
|
||||
},
|
||||
},
|
||||
DBFilePath: leasesPath,
|
||||
},
|
||||
name: "bad_start",
|
||||
wantErrMsg: `creating interfaces: interface "eth0": ipv4: ` +
|
||||
wantErrMsg: `interface "eth0": ipv4: ` +
|
||||
`range start 127.0.0.1 is not within 192.168.0.1/24`,
|
||||
}}
|
||||
|
||||
@@ -155,36 +180,32 @@ func TestNew(t *testing.T) {
|
||||
func TestDHCPServer_AddLease(t *testing.T) {
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
|
||||
leasesPath := filepath.Join(t.TempDir(), "leases.json")
|
||||
srv, err := dhcpsvc.New(ctx, &dhcpsvc.Config{
|
||||
Enabled: true,
|
||||
Logger: discardLog,
|
||||
LocalDomainName: testLocalTLD,
|
||||
Interfaces: testInterfaceConf,
|
||||
DBFilePath: leasesPath,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
const (
|
||||
existHost = "host1"
|
||||
newHost = "host2"
|
||||
ipv6Host = "host3"
|
||||
host1 = "host1"
|
||||
host2 = "host2"
|
||||
host3 = "host3"
|
||||
)
|
||||
|
||||
var (
|
||||
existIP = netip.MustParseAddr("192.168.0.2")
|
||||
newIP = netip.MustParseAddr("192.168.0.3")
|
||||
newIPv6 = netip.MustParseAddr("2001:db8::2")
|
||||
ip1 := netip.MustParseAddr("192.168.0.2")
|
||||
ip2 := netip.MustParseAddr("192.168.0.3")
|
||||
ip3 := netip.MustParseAddr("2001:db8::2")
|
||||
|
||||
existMAC = mustParseMAC(t, "01:02:03:04:05:06")
|
||||
newMAC = mustParseMAC(t, "06:05:04:03:02:01")
|
||||
ipv6MAC = mustParseMAC(t, "02:03:04:05:06:07")
|
||||
)
|
||||
mac1 := mustParseMAC(t, "01:02:03:04:05:06")
|
||||
mac2 := mustParseMAC(t, "06:05:04:03:02:01")
|
||||
mac3 := mustParseMAC(t, "02:03:04:05:06:07")
|
||||
|
||||
require.NoError(t, srv.AddLease(ctx, &dhcpsvc.Lease{
|
||||
Hostname: existHost,
|
||||
IP: existIP,
|
||||
HWAddr: existMAC,
|
||||
Hostname: host1,
|
||||
IP: ip1,
|
||||
HWAddr: mac1,
|
||||
IsStatic: true,
|
||||
}))
|
||||
|
||||
@@ -195,61 +216,61 @@ func TestDHCPServer_AddLease(t *testing.T) {
|
||||
}{{
|
||||
name: "outside_range",
|
||||
lease: &dhcpsvc.Lease{
|
||||
Hostname: newHost,
|
||||
Hostname: host2,
|
||||
IP: netip.MustParseAddr("1.2.3.4"),
|
||||
HWAddr: newMAC,
|
||||
HWAddr: mac2,
|
||||
},
|
||||
wantErrMsg: "adding lease: no interface for ip 1.2.3.4",
|
||||
}, {
|
||||
name: "duplicate_ip",
|
||||
lease: &dhcpsvc.Lease{
|
||||
Hostname: newHost,
|
||||
IP: existIP,
|
||||
HWAddr: newMAC,
|
||||
Hostname: host2,
|
||||
IP: ip1,
|
||||
HWAddr: mac2,
|
||||
},
|
||||
wantErrMsg: "adding lease: lease for ip " + existIP.String() +
|
||||
wantErrMsg: "adding lease: lease for ip " + ip1.String() +
|
||||
" already exists",
|
||||
}, {
|
||||
name: "duplicate_hostname",
|
||||
lease: &dhcpsvc.Lease{
|
||||
Hostname: existHost,
|
||||
IP: newIP,
|
||||
HWAddr: newMAC,
|
||||
Hostname: host1,
|
||||
IP: ip2,
|
||||
HWAddr: mac2,
|
||||
},
|
||||
wantErrMsg: "adding lease: lease for hostname " + existHost +
|
||||
wantErrMsg: "adding lease: lease for hostname " + host1 +
|
||||
" already exists",
|
||||
}, {
|
||||
name: "duplicate_hostname_case",
|
||||
lease: &dhcpsvc.Lease{
|
||||
Hostname: strings.ToUpper(existHost),
|
||||
IP: newIP,
|
||||
HWAddr: newMAC,
|
||||
Hostname: strings.ToUpper(host1),
|
||||
IP: ip2,
|
||||
HWAddr: mac2,
|
||||
},
|
||||
wantErrMsg: "adding lease: lease for hostname " +
|
||||
strings.ToUpper(existHost) + " already exists",
|
||||
strings.ToUpper(host1) + " already exists",
|
||||
}, {
|
||||
name: "duplicate_mac",
|
||||
lease: &dhcpsvc.Lease{
|
||||
Hostname: newHost,
|
||||
IP: newIP,
|
||||
HWAddr: existMAC,
|
||||
Hostname: host2,
|
||||
IP: ip2,
|
||||
HWAddr: mac1,
|
||||
},
|
||||
wantErrMsg: "adding lease: lease for mac " + existMAC.String() +
|
||||
wantErrMsg: "adding lease: lease for mac " + mac1.String() +
|
||||
" already exists",
|
||||
}, {
|
||||
name: "valid",
|
||||
lease: &dhcpsvc.Lease{
|
||||
Hostname: newHost,
|
||||
IP: newIP,
|
||||
HWAddr: newMAC,
|
||||
Hostname: host2,
|
||||
IP: ip2,
|
||||
HWAddr: mac2,
|
||||
},
|
||||
wantErrMsg: "",
|
||||
}, {
|
||||
name: "valid_v6",
|
||||
lease: &dhcpsvc.Lease{
|
||||
Hostname: ipv6Host,
|
||||
IP: newIPv6,
|
||||
HWAddr: ipv6MAC,
|
||||
Hostname: host3,
|
||||
IP: ip3,
|
||||
HWAddr: mac3,
|
||||
},
|
||||
wantErrMsg: "",
|
||||
}}
|
||||
@@ -259,21 +280,16 @@ func TestDHCPServer_AddLease(t *testing.T) {
|
||||
testutil.AssertErrorMsg(t, tc.wantErrMsg, srv.AddLease(ctx, tc.lease))
|
||||
})
|
||||
}
|
||||
|
||||
assert.NotEmpty(t, srv.Leases())
|
||||
assert.FileExists(t, leasesPath)
|
||||
}
|
||||
|
||||
func TestDHCPServer_index(t *testing.T) {
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
|
||||
leasesPath := newTempDB(t)
|
||||
srv, err := dhcpsvc.New(ctx, &dhcpsvc.Config{
|
||||
Enabled: true,
|
||||
Logger: discardLog,
|
||||
LocalDomainName: testLocalTLD,
|
||||
Interfaces: testInterfaceConf,
|
||||
DBFilePath: leasesPath,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -285,23 +301,46 @@ func TestDHCPServer_index(t *testing.T) {
|
||||
host5 = "host5"
|
||||
)
|
||||
|
||||
var (
|
||||
ip1 = netip.MustParseAddr("192.168.0.2")
|
||||
ip2 = netip.MustParseAddr("192.168.0.3")
|
||||
ip3 = netip.MustParseAddr("172.16.0.3")
|
||||
ip4 = netip.MustParseAddr("172.16.0.4")
|
||||
ip1 := netip.MustParseAddr("192.168.0.2")
|
||||
ip2 := netip.MustParseAddr("192.168.0.3")
|
||||
ip3 := netip.MustParseAddr("172.16.0.3")
|
||||
ip4 := netip.MustParseAddr("172.16.0.4")
|
||||
|
||||
mac1 = mustParseMAC(t, "01:02:03:04:05:06")
|
||||
mac2 = mustParseMAC(t, "06:05:04:03:02:01")
|
||||
mac3 = mustParseMAC(t, "02:03:04:05:06:07")
|
||||
)
|
||||
mac1 := mustParseMAC(t, "01:02:03:04:05:06")
|
||||
mac2 := mustParseMAC(t, "06:05:04:03:02:01")
|
||||
mac3 := mustParseMAC(t, "02:03:04:05:06:07")
|
||||
|
||||
leases := []*dhcpsvc.Lease{{
|
||||
Hostname: host1,
|
||||
IP: ip1,
|
||||
HWAddr: mac1,
|
||||
IsStatic: true,
|
||||
}, {
|
||||
Hostname: host2,
|
||||
IP: ip2,
|
||||
HWAddr: mac2,
|
||||
IsStatic: true,
|
||||
}, {
|
||||
Hostname: host3,
|
||||
IP: ip3,
|
||||
HWAddr: mac3,
|
||||
IsStatic: true,
|
||||
}, {
|
||||
Hostname: host4,
|
||||
IP: ip4,
|
||||
HWAddr: mac1,
|
||||
IsStatic: true,
|
||||
}}
|
||||
for _, l := range leases {
|
||||
require.NoError(t, srv.AddLease(ctx, l))
|
||||
}
|
||||
|
||||
t.Run("ip_idx", func(t *testing.T) {
|
||||
assert.Equal(t, ip1, srv.IPByHost(host1))
|
||||
assert.Equal(t, ip2, srv.IPByHost(host2))
|
||||
assert.Equal(t, ip3, srv.IPByHost(host3))
|
||||
assert.Equal(t, ip4, srv.IPByHost(host4))
|
||||
assert.Zero(t, srv.IPByHost(host5))
|
||||
assert.Equal(t, netip.Addr{}, srv.IPByHost(host5))
|
||||
})
|
||||
|
||||
t.Run("name_idx", func(t *testing.T) {
|
||||
@@ -309,7 +348,7 @@ func TestDHCPServer_index(t *testing.T) {
|
||||
assert.Equal(t, host2, srv.HostByIP(ip2))
|
||||
assert.Equal(t, host3, srv.HostByIP(ip3))
|
||||
assert.Equal(t, host4, srv.HostByIP(ip4))
|
||||
assert.Zero(t, srv.HostByIP(netip.Addr{}))
|
||||
assert.Equal(t, "", srv.HostByIP(netip.Addr{}))
|
||||
})
|
||||
|
||||
t.Run("mac_idx", func(t *testing.T) {
|
||||
@@ -317,20 +356,18 @@ func TestDHCPServer_index(t *testing.T) {
|
||||
assert.Equal(t, mac2, srv.MACByIP(ip2))
|
||||
assert.Equal(t, mac3, srv.MACByIP(ip3))
|
||||
assert.Equal(t, mac1, srv.MACByIP(ip4))
|
||||
assert.Zero(t, srv.MACByIP(netip.Addr{}))
|
||||
assert.Nil(t, srv.MACByIP(netip.Addr{}))
|
||||
})
|
||||
}
|
||||
|
||||
func TestDHCPServer_UpdateStaticLease(t *testing.T) {
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
|
||||
leasesPath := newTempDB(t)
|
||||
srv, err := dhcpsvc.New(ctx, &dhcpsvc.Config{
|
||||
Enabled: true,
|
||||
Logger: discardLog,
|
||||
LocalDomainName: testLocalTLD,
|
||||
Interfaces: testInterfaceConf,
|
||||
DBFilePath: leasesPath,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -343,16 +380,36 @@ func TestDHCPServer_UpdateStaticLease(t *testing.T) {
|
||||
host6 = "host6"
|
||||
)
|
||||
|
||||
var (
|
||||
ip1 = netip.MustParseAddr("192.168.0.2")
|
||||
ip2 = netip.MustParseAddr("192.168.0.3")
|
||||
ip3 = netip.MustParseAddr("192.168.0.4")
|
||||
ip4 = netip.MustParseAddr("2001:db8::3")
|
||||
ip1 := netip.MustParseAddr("192.168.0.2")
|
||||
ip2 := netip.MustParseAddr("192.168.0.3")
|
||||
ip3 := netip.MustParseAddr("192.168.0.4")
|
||||
ip4 := netip.MustParseAddr("2001:db8::2")
|
||||
ip5 := netip.MustParseAddr("2001:db8::3")
|
||||
|
||||
mac1 = mustParseMAC(t, "01:02:03:04:05:06")
|
||||
mac2 = mustParseMAC(t, "06:05:04:03:02:01")
|
||||
mac3 = mustParseMAC(t, "06:05:04:03:02:02")
|
||||
)
|
||||
mac1 := mustParseMAC(t, "01:02:03:04:05:06")
|
||||
mac2 := mustParseMAC(t, "01:02:03:04:05:07")
|
||||
mac3 := mustParseMAC(t, "06:05:04:03:02:01")
|
||||
mac4 := mustParseMAC(t, "06:05:04:03:02:02")
|
||||
|
||||
leases := []*dhcpsvc.Lease{{
|
||||
Hostname: host1,
|
||||
IP: ip1,
|
||||
HWAddr: mac1,
|
||||
IsStatic: true,
|
||||
}, {
|
||||
Hostname: host2,
|
||||
IP: ip2,
|
||||
HWAddr: mac2,
|
||||
IsStatic: true,
|
||||
}, {
|
||||
Hostname: host4,
|
||||
IP: ip4,
|
||||
HWAddr: mac4,
|
||||
IsStatic: true,
|
||||
}}
|
||||
for _, l := range leases {
|
||||
require.NoError(t, srv.AddLease(ctx, l))
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
@@ -371,9 +428,9 @@ func TestDHCPServer_UpdateStaticLease(t *testing.T) {
|
||||
lease: &dhcpsvc.Lease{
|
||||
Hostname: host3,
|
||||
IP: ip3,
|
||||
HWAddr: mac2,
|
||||
HWAddr: mac3,
|
||||
},
|
||||
wantErrMsg: "updating static lease: no lease for mac " + mac2.String(),
|
||||
wantErrMsg: "updating static lease: no lease for mac " + mac3.String(),
|
||||
}, {
|
||||
name: "duplicate_ip",
|
||||
lease: &dhcpsvc.Lease{
|
||||
@@ -413,8 +470,8 @@ func TestDHCPServer_UpdateStaticLease(t *testing.T) {
|
||||
name: "valid_v6",
|
||||
lease: &dhcpsvc.Lease{
|
||||
Hostname: host6,
|
||||
IP: ip4,
|
||||
HWAddr: mac3,
|
||||
IP: ip5,
|
||||
HWAddr: mac4,
|
||||
},
|
||||
wantErrMsg: "",
|
||||
}}
|
||||
@@ -424,20 +481,16 @@ func TestDHCPServer_UpdateStaticLease(t *testing.T) {
|
||||
testutil.AssertErrorMsg(t, tc.wantErrMsg, srv.UpdateStaticLease(ctx, tc.lease))
|
||||
})
|
||||
}
|
||||
|
||||
assert.FileExists(t, leasesPath)
|
||||
}
|
||||
|
||||
func TestDHCPServer_RemoveLease(t *testing.T) {
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
|
||||
leasesPath := newTempDB(t)
|
||||
srv, err := dhcpsvc.New(ctx, &dhcpsvc.Config{
|
||||
Enabled: true,
|
||||
Logger: discardLog,
|
||||
LocalDomainName: testLocalTLD,
|
||||
Interfaces: testInterfaceConf,
|
||||
DBFilePath: leasesPath,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -447,15 +500,28 @@ func TestDHCPServer_RemoveLease(t *testing.T) {
|
||||
host3 = "host3"
|
||||
)
|
||||
|
||||
var (
|
||||
existIP = netip.MustParseAddr("192.168.0.2")
|
||||
newIP = netip.MustParseAddr("192.168.0.3")
|
||||
newIPv6 = netip.MustParseAddr("2001:db8::2")
|
||||
ip1 := netip.MustParseAddr("192.168.0.2")
|
||||
ip2 := netip.MustParseAddr("192.168.0.3")
|
||||
ip3 := netip.MustParseAddr("2001:db8::2")
|
||||
|
||||
existMAC = mustParseMAC(t, "01:02:03:04:05:06")
|
||||
newMAC = mustParseMAC(t, "02:03:04:05:06:07")
|
||||
ipv6MAC = mustParseMAC(t, "06:05:04:03:02:01")
|
||||
)
|
||||
mac1 := mustParseMAC(t, "01:02:03:04:05:06")
|
||||
mac2 := mustParseMAC(t, "02:03:04:05:06:07")
|
||||
mac3 := mustParseMAC(t, "06:05:04:03:02:01")
|
||||
|
||||
leases := []*dhcpsvc.Lease{{
|
||||
Hostname: host1,
|
||||
IP: ip1,
|
||||
HWAddr: mac1,
|
||||
IsStatic: true,
|
||||
}, {
|
||||
Hostname: host3,
|
||||
IP: ip3,
|
||||
HWAddr: mac3,
|
||||
IsStatic: true,
|
||||
}}
|
||||
for _, l := range leases {
|
||||
require.NoError(t, srv.AddLease(ctx, l))
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
@@ -465,40 +531,40 @@ func TestDHCPServer_RemoveLease(t *testing.T) {
|
||||
name: "not_found_mac",
|
||||
lease: &dhcpsvc.Lease{
|
||||
Hostname: host1,
|
||||
IP: existIP,
|
||||
HWAddr: newMAC,
|
||||
IP: ip1,
|
||||
HWAddr: mac2,
|
||||
},
|
||||
wantErrMsg: "removing lease: no lease for mac " + newMAC.String(),
|
||||
wantErrMsg: "removing lease: no lease for mac " + mac2.String(),
|
||||
}, {
|
||||
name: "not_found_ip",
|
||||
lease: &dhcpsvc.Lease{
|
||||
Hostname: host1,
|
||||
IP: newIP,
|
||||
HWAddr: existMAC,
|
||||
IP: ip2,
|
||||
HWAddr: mac1,
|
||||
},
|
||||
wantErrMsg: "removing lease: no lease for ip " + newIP.String(),
|
||||
wantErrMsg: "removing lease: no lease for ip " + ip2.String(),
|
||||
}, {
|
||||
name: "not_found_host",
|
||||
lease: &dhcpsvc.Lease{
|
||||
Hostname: host2,
|
||||
IP: existIP,
|
||||
HWAddr: existMAC,
|
||||
IP: ip1,
|
||||
HWAddr: mac1,
|
||||
},
|
||||
wantErrMsg: "removing lease: no lease for hostname " + host2,
|
||||
}, {
|
||||
name: "valid",
|
||||
lease: &dhcpsvc.Lease{
|
||||
Hostname: host1,
|
||||
IP: existIP,
|
||||
HWAddr: existMAC,
|
||||
IP: ip1,
|
||||
HWAddr: mac1,
|
||||
},
|
||||
wantErrMsg: "",
|
||||
}, {
|
||||
name: "valid_v6",
|
||||
lease: &dhcpsvc.Lease{
|
||||
Hostname: host3,
|
||||
IP: newIPv6,
|
||||
HWAddr: ipv6MAC,
|
||||
IP: ip3,
|
||||
HWAddr: mac3,
|
||||
},
|
||||
wantErrMsg: "",
|
||||
}}
|
||||
@@ -509,64 +575,49 @@ func TestDHCPServer_RemoveLease(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
assert.FileExists(t, leasesPath)
|
||||
assert.Empty(t, srv.Leases())
|
||||
}
|
||||
|
||||
func TestDHCPServer_Reset(t *testing.T) {
|
||||
leasesPath := newTempDB(t)
|
||||
conf := &dhcpsvc.Config{
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
|
||||
srv, err := dhcpsvc.New(ctx, &dhcpsvc.Config{
|
||||
Enabled: true,
|
||||
Logger: discardLog,
|
||||
LocalDomainName: testLocalTLD,
|
||||
Interfaces: testInterfaceConf,
|
||||
DBFilePath: leasesPath,
|
||||
}
|
||||
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
srv, err := dhcpsvc.New(ctx, conf)
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
const leasesNum = 4
|
||||
leases := []*dhcpsvc.Lease{{
|
||||
Hostname: "host1",
|
||||
IP: netip.MustParseAddr("192.168.0.2"),
|
||||
HWAddr: mustParseMAC(t, "01:02:03:04:05:06"),
|
||||
IsStatic: true,
|
||||
}, {
|
||||
Hostname: "host2",
|
||||
IP: netip.MustParseAddr("192.168.0.3"),
|
||||
HWAddr: mustParseMAC(t, "06:05:04:03:02:01"),
|
||||
IsStatic: true,
|
||||
}, {
|
||||
Hostname: "host3",
|
||||
IP: netip.MustParseAddr("2001:db8::2"),
|
||||
HWAddr: mustParseMAC(t, "02:03:04:05:06:07"),
|
||||
IsStatic: true,
|
||||
}, {
|
||||
Hostname: "host4",
|
||||
IP: netip.MustParseAddr("2001:db8::3"),
|
||||
HWAddr: mustParseMAC(t, "06:05:04:03:02:02"),
|
||||
IsStatic: true,
|
||||
}}
|
||||
|
||||
require.Len(t, srv.Leases(), leasesNum)
|
||||
for _, l := range leases {
|
||||
require.NoError(t, srv.AddLease(ctx, l))
|
||||
}
|
||||
|
||||
require.Len(t, srv.Leases(), len(leases))
|
||||
|
||||
require.NoError(t, srv.Reset(ctx))
|
||||
|
||||
assert.FileExists(t, leasesPath)
|
||||
assert.Empty(t, srv.Leases())
|
||||
}
|
||||
|
||||
func TestServer_Leases(t *testing.T) {
|
||||
leasesPath := newTempDB(t)
|
||||
conf := &dhcpsvc.Config{
|
||||
Enabled: true,
|
||||
Logger: discardLog,
|
||||
LocalDomainName: testLocalTLD,
|
||||
Interfaces: testInterfaceConf,
|
||||
DBFilePath: leasesPath,
|
||||
}
|
||||
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
|
||||
srv, err := dhcpsvc.New(ctx, conf)
|
||||
require.NoError(t, err)
|
||||
|
||||
expiry, err := time.Parse(time.RFC3339, "2042-01-02T03:04:05Z")
|
||||
require.NoError(t, err)
|
||||
|
||||
wantLeases := []*dhcpsvc.Lease{{
|
||||
Expiry: expiry,
|
||||
IP: netip.MustParseAddr("192.168.0.3"),
|
||||
Hostname: "example.host",
|
||||
HWAddr: mustParseMAC(t, "AA:AA:AA:AA:AA:AA"),
|
||||
IsStatic: false,
|
||||
}, {
|
||||
Expiry: time.Time{},
|
||||
IP: netip.MustParseAddr("192.168.0.4"),
|
||||
Hostname: "example.static.host",
|
||||
HWAddr: mustParseMAC(t, "BB:BB:BB:BB:BB:BB"),
|
||||
IsStatic: true,
|
||||
}}
|
||||
assert.ElementsMatch(t, wantLeases, srv.Leases())
|
||||
}
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
{
|
||||
"leases": [
|
||||
{
|
||||
"expires": "",
|
||||
"ip": "192.168.0.2",
|
||||
"hostname": "host1",
|
||||
"mac": "01:02:03:04:05:06",
|
||||
"static": true
|
||||
},
|
||||
{
|
||||
"expires": "",
|
||||
"ip": "2001:db8::2",
|
||||
"hostname": "host3",
|
||||
"mac": "06:05:04:03:02:01",
|
||||
"static": true
|
||||
}
|
||||
],
|
||||
"version": 1
|
||||
}
|
||||
@@ -1,33 +0,0 @@
|
||||
{
|
||||
"leases": [
|
||||
{
|
||||
"expires": "",
|
||||
"ip": "192.168.0.2",
|
||||
"hostname": "host1",
|
||||
"mac": "01:02:03:04:05:06",
|
||||
"static": true
|
||||
},
|
||||
{
|
||||
"expires": "",
|
||||
"ip": "192.168.0.3",
|
||||
"hostname": "host2",
|
||||
"mac": "06:05:04:03:02:01",
|
||||
"static": true
|
||||
},
|
||||
{
|
||||
"expires": "",
|
||||
"ip": "2001:db8::2",
|
||||
"hostname": "host3",
|
||||
"mac": "02:03:04:05:06:07",
|
||||
"static": true
|
||||
},
|
||||
{
|
||||
"expires": "",
|
||||
"ip": "2001:db8::3",
|
||||
"hostname": "host4",
|
||||
"mac": "06:05:04:03:02:02",
|
||||
"static": true
|
||||
}
|
||||
],
|
||||
"version": 1
|
||||
}
|
||||
@@ -1,26 +0,0 @@
|
||||
{
|
||||
"leases": [
|
||||
{
|
||||
"expires": "",
|
||||
"ip": "192.168.0.2",
|
||||
"hostname": "host1",
|
||||
"mac": "01:02:03:04:05:06",
|
||||
"static": true
|
||||
},
|
||||
{
|
||||
"expires": "",
|
||||
"ip": "192.168.0.3",
|
||||
"hostname": "host2",
|
||||
"mac": "01:02:03:04:05:07",
|
||||
"static": true
|
||||
},
|
||||
{
|
||||
"expires": "",
|
||||
"ip": "2001:db8::2",
|
||||
"hostname": "host4",
|
||||
"mac": "06:05:04:03:02:02",
|
||||
"static": true
|
||||
}
|
||||
],
|
||||
"version": 1
|
||||
}
|
||||
@@ -1,33 +0,0 @@
|
||||
{
|
||||
"leases": [
|
||||
{
|
||||
"expires": "",
|
||||
"ip": "192.168.0.2",
|
||||
"hostname": "host1",
|
||||
"mac": "01:02:03:04:05:06",
|
||||
"static": true
|
||||
},
|
||||
{
|
||||
"expires": "",
|
||||
"ip": "192.168.0.3",
|
||||
"hostname": "host2",
|
||||
"mac": "06:05:04:03:02:01",
|
||||
"static": true
|
||||
},
|
||||
{
|
||||
"expires": "",
|
||||
"ip": "172.16.0.3",
|
||||
"hostname": "host3",
|
||||
"mac": "02:03:04:05:06:07",
|
||||
"static": true
|
||||
},
|
||||
{
|
||||
"expires": "",
|
||||
"ip": "172.16.0.4",
|
||||
"hostname": "host4",
|
||||
"mac": "01:02:03:04:05:06",
|
||||
"static": true
|
||||
}
|
||||
],
|
||||
"version": 1
|
||||
}
|
||||
@@ -1,15 +0,0 @@
|
||||
{
|
||||
"leases": [{
|
||||
"expires": "2042-01-02T03:04:05Z",
|
||||
"ip": "192.168.0.3",
|
||||
"hostname": "example.host",
|
||||
"mac": "AA:AA:AA:AA:AA:AA",
|
||||
"static": false
|
||||
}, {
|
||||
"ip": "192.168.0.4",
|
||||
"hostname": "example.static.host",
|
||||
"mac": "BB:BB:BB:BB:BB:BB",
|
||||
"static": true
|
||||
}],
|
||||
"version": 1
|
||||
}
|
||||
@@ -82,12 +82,8 @@ func (c *IPv4Config) validate() (err error) {
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
// dhcpInterfaceV4 is a DHCP interface for IPv4 address family.
|
||||
type dhcpInterfaceV4 struct {
|
||||
// common is the common part of any network interface within the DHCP
|
||||
// server.
|
||||
common *netInterface
|
||||
|
||||
// netInterfaceV4 is a DHCP interface for IPv4 address family.
|
||||
type netInterfaceV4 struct {
|
||||
// gateway is the IP address of the network gateway.
|
||||
gateway netip.Addr
|
||||
|
||||
@@ -105,22 +101,25 @@ type dhcpInterfaceV4 struct {
|
||||
// explicitOpts are the user-configured options. It must not have
|
||||
// intersections with implicitOpts.
|
||||
explicitOpts layers.DHCPOptions
|
||||
|
||||
// netInterface is embedded here to provide some common network interface
|
||||
// logic.
|
||||
netInterface
|
||||
}
|
||||
|
||||
// newDHCPInterfaceV4 creates a new DHCP interface for IPv4 address family with
|
||||
// newNetInterfaceV4 creates a new DHCP interface for IPv4 address family with
|
||||
// the given configuration. It returns an error if the given configuration
|
||||
// can't be used.
|
||||
func newDHCPInterfaceV4(
|
||||
func newNetInterfaceV4(
|
||||
ctx context.Context,
|
||||
l *slog.Logger,
|
||||
name string,
|
||||
conf *IPv4Config,
|
||||
) (i *dhcpInterfaceV4, err error) {
|
||||
) (i *netInterfaceV4, err error) {
|
||||
l = l.With(
|
||||
keyInterface, name,
|
||||
keyFamily, netutil.AddrFamilyIPv4,
|
||||
)
|
||||
|
||||
if !conf.Enabled {
|
||||
l.DebugContext(ctx, "disabled")
|
||||
|
||||
@@ -144,31 +143,35 @@ func newDHCPInterfaceV4(
|
||||
return nil, fmt.Errorf("gateway ip %s in the ip range %s", conf.GatewayIP, addrSpace)
|
||||
}
|
||||
|
||||
i = &dhcpInterfaceV4{
|
||||
i = &netInterfaceV4{
|
||||
gateway: conf.GatewayIP,
|
||||
subnet: subnet,
|
||||
addrSpace: addrSpace,
|
||||
common: newNetInterface(name, l, conf.LeaseDuration),
|
||||
netInterface: netInterface{
|
||||
name: name,
|
||||
leaseTTL: conf.LeaseDuration,
|
||||
logger: l,
|
||||
},
|
||||
}
|
||||
i.implicitOpts, i.explicitOpts = conf.options(ctx, l)
|
||||
|
||||
return i, nil
|
||||
}
|
||||
|
||||
// dhcpInterfacesV4 is a slice of network interfaces of IPv4 address family.
|
||||
type dhcpInterfacesV4 []*dhcpInterfaceV4
|
||||
// netInterfacesV4 is a slice of network interfaces of IPv4 address family.
|
||||
type netInterfacesV4 []*netInterfaceV4
|
||||
|
||||
// find returns the first network interface within ifaces containing ip. It
|
||||
// returns false if there is no such interface.
|
||||
func (ifaces dhcpInterfacesV4) find(ip netip.Addr) (iface4 *netInterface, ok bool) {
|
||||
i := slices.IndexFunc(ifaces, func(iface *dhcpInterfaceV4) (contains bool) {
|
||||
func (ifaces netInterfacesV4) find(ip netip.Addr) (iface4 *netInterface, ok bool) {
|
||||
i := slices.IndexFunc(ifaces, func(iface *netInterfaceV4) (contains bool) {
|
||||
return iface.subnet.Contains(ip)
|
||||
})
|
||||
if i < 0 {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return ifaces[i].common, true
|
||||
return &ifaces[i].netInterface, true
|
||||
}
|
||||
|
||||
// options returns the implicit and explicit options for the interface. The two
|
||||
|
||||
@@ -62,12 +62,10 @@ func (c *IPv6Config) validate() (err error) {
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
// dhcpInterfaceV6 is a DHCP interface for IPv6 address family.
|
||||
type dhcpInterfaceV6 struct {
|
||||
// common is the common part of any network interface within the DHCP
|
||||
// server.
|
||||
common *netInterface
|
||||
|
||||
// netInterfaceV6 is a DHCP interface for IPv6 address family.
|
||||
//
|
||||
// TODO(e.burkov): Add options.
|
||||
type netInterfaceV6 struct {
|
||||
// rangeStart is the first IP address in the range.
|
||||
rangeStart netip.Addr
|
||||
|
||||
@@ -80,6 +78,10 @@ type dhcpInterfaceV6 struct {
|
||||
// intersections with implicitOpts.
|
||||
explicitOpts layers.DHCPv6Options
|
||||
|
||||
// netInterface is embedded here to provide some common network interface
|
||||
// logic.
|
||||
netInterface
|
||||
|
||||
// raSLAACOnly defines if DHCP should send ICMPv6.RA packets without MO
|
||||
// flags.
|
||||
raSLAACOnly bool
|
||||
@@ -88,16 +90,16 @@ type dhcpInterfaceV6 struct {
|
||||
raAllowSLAAC bool
|
||||
}
|
||||
|
||||
// newDHCPInterfaceV6 creates a new DHCP interface for IPv6 address family with
|
||||
// newNetInterfaceV6 creates a new DHCP interface for IPv6 address family with
|
||||
// the given configuration.
|
||||
//
|
||||
// TODO(e.burkov): Validate properly.
|
||||
func newDHCPInterfaceV6(
|
||||
func newNetInterfaceV6(
|
||||
ctx context.Context,
|
||||
l *slog.Logger,
|
||||
name string,
|
||||
conf *IPv6Config,
|
||||
) (i *dhcpInterfaceV6) {
|
||||
) (i *netInterfaceV6) {
|
||||
l = l.With(keyInterface, name, keyFamily, netutil.AddrFamilyIPv6)
|
||||
if !conf.Enabled {
|
||||
l.DebugContext(ctx, "disabled")
|
||||
@@ -105,9 +107,13 @@ func newDHCPInterfaceV6(
|
||||
return nil
|
||||
}
|
||||
|
||||
i = &dhcpInterfaceV6{
|
||||
rangeStart: conf.RangeStart,
|
||||
common: newNetInterface(name, l, conf.LeaseDuration),
|
||||
i = &netInterfaceV6{
|
||||
rangeStart: conf.RangeStart,
|
||||
netInterface: netInterface{
|
||||
name: name,
|
||||
leaseTTL: conf.LeaseDuration,
|
||||
logger: l,
|
||||
},
|
||||
raSLAACOnly: conf.RASLAACOnly,
|
||||
raAllowSLAAC: conf.RAAllowSLAAC,
|
||||
}
|
||||
@@ -116,12 +122,12 @@ func newDHCPInterfaceV6(
|
||||
return i
|
||||
}
|
||||
|
||||
// dhcpInterfacesV6 is a slice of network interfaces of IPv6 address family.
|
||||
type dhcpInterfacesV6 []*dhcpInterfaceV6
|
||||
// netInterfacesV4 is a slice of network interfaces of IPv4 address family.
|
||||
type netInterfacesV6 []*netInterfaceV6
|
||||
|
||||
// find returns the first network interface within ifaces containing ip. It
|
||||
// returns false if there is no such interface.
|
||||
func (ifaces dhcpInterfacesV6) find(ip netip.Addr) (iface6 *netInterface, ok bool) {
|
||||
func (ifaces netInterfacesV6) find(ip netip.Addr) (iface6 *netInterface, ok bool) {
|
||||
// prefLen is the length of prefix to match ip against.
|
||||
//
|
||||
// TODO(e.burkov): DHCPv6 inherits the weird behavior of legacy
|
||||
@@ -130,7 +136,7 @@ func (ifaces dhcpInterfacesV6) find(ip netip.Addr) (iface6 *netInterface, ok boo
|
||||
// be used instead.
|
||||
const prefLen = netutil.IPv6BitLen - 8
|
||||
|
||||
i := slices.IndexFunc(ifaces, func(iface *dhcpInterfaceV6) (contains bool) {
|
||||
i := slices.IndexFunc(ifaces, func(iface *netInterfaceV6) (contains bool) {
|
||||
return !ip.Less(iface.rangeStart) &&
|
||||
netip.PrefixFrom(iface.rangeStart, prefLen).Contains(ip)
|
||||
})
|
||||
@@ -138,7 +144,7 @@ func (ifaces dhcpInterfacesV6) find(ip netip.Addr) (iface6 *netInterface, ok boo
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return ifaces[i].common, true
|
||||
return &ifaces[i].netInterface, true
|
||||
}
|
||||
|
||||
// options returns the implicit and explicit options for the interface. The two
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -218,8 +217,7 @@ func TestServer_clientIDFromDNSContext(t *testing.T) {
|
||||
}
|
||||
|
||||
srv := &Server{
|
||||
conf: ServerConfig{TLSConfig: tlsConf},
|
||||
baseLogger: slogutil.NewDiscardLogger(),
|
||||
conf: ServerConfig{TLSConfig: tlsConf},
|
||||
}
|
||||
|
||||
var (
|
||||
|
||||
@@ -22,7 +22,6 @@ import (
|
||||
"github.com/AdguardTeam/golibs/container"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
@@ -159,7 +158,7 @@ type Config struct {
|
||||
// IpsetList is the ipset configuration that allows AdGuard Home to add IP
|
||||
// addresses of the specified domain names to an ipset list. Syntax:
|
||||
//
|
||||
// DOMAIN[,DOMAIN].../IPSET_NAME[,IPSET_NAME]...
|
||||
// DOMAIN[,DOMAIN].../IPSET_NAME
|
||||
//
|
||||
// This field is ignored if [IpsetListFileName] is set.
|
||||
IpsetList []string `yaml:"ipset"`
|
||||
@@ -302,8 +301,6 @@ type ServerConfig struct {
|
||||
|
||||
// UpstreamMode is a enumeration of upstream mode representations. See
|
||||
// [proxy.UpstreamModeType].
|
||||
//
|
||||
// TODO(d.kolyshev): Consider using [proxy.UpstreamMode].
|
||||
type UpstreamMode string
|
||||
|
||||
const (
|
||||
@@ -318,7 +315,6 @@ func (s *Server) newProxyConfig() (conf *proxy.Config, err error) {
|
||||
trustedPrefixes := netutil.UnembedPrefixes(srvConf.TrustedProxies)
|
||||
|
||||
conf = &proxy.Config{
|
||||
Logger: s.baseLogger.With(slogutil.KeyPrefix, "dnsproxy"),
|
||||
HTTP3: srvConf.ServeHTTP3,
|
||||
Ratelimit: int(srvConf.Ratelimit),
|
||||
RatelimitSubnetLenIPv4: srvConf.RatelimitSubnetLenIPv4,
|
||||
@@ -424,6 +420,8 @@ func parseBogusNXDOMAIN(confBogusNXDOMAIN []string) (subnets []netip.Prefix, err
|
||||
return subnets, nil
|
||||
}
|
||||
|
||||
const defaultBlockedResponseTTL = 3600
|
||||
|
||||
// initDefaultSettings initializes default settings if nothing
|
||||
// is configured
|
||||
func (s *Server) initDefaultSettings() {
|
||||
@@ -454,24 +452,24 @@ func (s *Server) initDefaultSettings() {
|
||||
|
||||
// prepareIpsetListSettings reads and prepares the ipset configuration either
|
||||
// from a file or from the data in the configuration file.
|
||||
func (s *Server) prepareIpsetListSettings() (ipsets []string, err error) {
|
||||
func (s *Server) prepareIpsetListSettings() (err error) {
|
||||
fn := s.conf.IpsetListFileName
|
||||
if fn == "" {
|
||||
return s.conf.IpsetList, nil
|
||||
return s.ipset.init(s.conf.IpsetList)
|
||||
}
|
||||
|
||||
// #nosec G304 -- Trust the path explicitly given by the user.
|
||||
data, err := os.ReadFile(fn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
ipsets = stringutil.SplitTrimmed(string(data), "\n")
|
||||
ipsets = slices.DeleteFunc(ipsets, IsCommentOrEmpty)
|
||||
ipsets := stringutil.SplitTrimmed(string(data), "\n")
|
||||
ipsets = stringutil.FilterOut(ipsets, IsCommentOrEmpty)
|
||||
|
||||
log.Debug("dns: using %d ipset rules from file %q", len(ipsets), fn)
|
||||
|
||||
return ipsets, nil
|
||||
return s.ipset.init(ipsets)
|
||||
}
|
||||
|
||||
// loadUpstreams parses upstream DNS servers from the configured file or from
|
||||
@@ -692,7 +690,7 @@ func matchesDomainWildcard(host, pat string) (ok bool) {
|
||||
// the DNS names and patterns from certificate. dnsNames must be sorted.
|
||||
func anyNameMatches(dnsNames []string, sni string) (ok bool) {
|
||||
// Check sni is either a valid hostname or a valid IP address.
|
||||
if !netutil.IsValidHostname(sni) && !netutil.IsValidIPString(sni) {
|
||||
if netutil.ValidateHostname(sni) != nil && net.ParseIP(sni) == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
)
|
||||
|
||||
// DialContext is an [aghnet.DialContextFunc] that uses s to resolve hostnames.
|
||||
@@ -29,7 +28,7 @@ func (s *Server) DialContext(ctx context.Context, network, addr string) (conn ne
|
||||
Timeout: time.Minute * 5,
|
||||
}
|
||||
|
||||
if netutil.IsValidIPString(host) {
|
||||
if net.ParseIP(host) != nil {
|
||||
return dialer.DialContext(ctx, network, addr)
|
||||
}
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
@@ -28,7 +27,6 @@ import (
|
||||
"github.com/AdguardTeam/golibs/cache"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/netutil/sysresolv"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
@@ -123,17 +121,12 @@ type Server struct {
|
||||
// access drops disallowed clients.
|
||||
access *accessManager
|
||||
|
||||
// baseLogger is used to create loggers for other entities. It should not
|
||||
// have a prefix and must not be nil.
|
||||
baseLogger *slog.Logger
|
||||
|
||||
// localDomainSuffix is the suffix used to detect internal hosts. It
|
||||
// must be a valid domain name plus dots on each side.
|
||||
localDomainSuffix string
|
||||
|
||||
// ipset processes DNS requests using ipset data. It must not be nil after
|
||||
// initialization. See [newIpsetHandler].
|
||||
ipset *ipsetHandler
|
||||
// ipset processes DNS requests using ipset data.
|
||||
ipset ipsetCtx
|
||||
|
||||
// privateNets is the configured set of IP networks considered private.
|
||||
privateNets netutil.SubnetSet
|
||||
@@ -204,10 +197,6 @@ type DNSCreateParams struct {
|
||||
PrivateNets netutil.SubnetSet
|
||||
Anonymizer *aghnet.IPMut
|
||||
EtcHosts *aghnet.HostsContainer
|
||||
|
||||
// Logger is used as a base logger. It must not be nil.
|
||||
Logger *slog.Logger
|
||||
|
||||
LocalDomain string
|
||||
}
|
||||
|
||||
@@ -244,7 +233,6 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
|
||||
stats: p.Stats,
|
||||
queryLog: p.QueryLog,
|
||||
privateNets: p.PrivateNets,
|
||||
baseLogger: p.Logger,
|
||||
// TODO(e.burkov): Use some case-insensitive string comparison.
|
||||
localDomainSuffix: strings.ToLower(localDomainSuffix),
|
||||
etcHosts: etcHosts,
|
||||
@@ -608,18 +596,11 @@ func (s *Server) prepareLocalResolvers() (uc *proxy.UpstreamConfig, err error) {
|
||||
// the primary DNS proxy instance. It assumes s.serverLock is locked or the
|
||||
// Server not running.
|
||||
func (s *Server) prepareInternalDNS() (err error) {
|
||||
ipsetList, err := s.prepareIpsetListSettings()
|
||||
err = s.prepareIpsetListSettings()
|
||||
if err != nil {
|
||||
return fmt.Errorf("preparing ipset settings: %w", err)
|
||||
}
|
||||
|
||||
ipsetLogger := s.baseLogger.With(slogutil.KeyPrefix, "ipset")
|
||||
s.ipset, err = newIpsetHandler(context.TODO(), ipsetLogger, ipsetList)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
bootOpts := &upstream.Options{
|
||||
Timeout: DefaultTimeout,
|
||||
HTTPVersions: UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams),
|
||||
@@ -683,7 +664,6 @@ func (s *Server) setupAddrProc() {
|
||||
s.addrProc = client.EmptyAddrProc{}
|
||||
} else {
|
||||
c := s.conf.AddrProcConf
|
||||
c.BaseLogger = s.baseLogger
|
||||
c.DialContext = s.DialContext
|
||||
c.PrivateSubnets = s.privateNets
|
||||
c.UsePrivateRDNS = s.conf.UsePrivateRDNS
|
||||
@@ -727,7 +707,6 @@ func validateBlockingMode(
|
||||
func (s *Server) prepareInternalProxy() (err error) {
|
||||
srvConf := s.conf
|
||||
conf := &proxy.Config{
|
||||
Logger: s.baseLogger.With(slogutil.KeyPrefix, "dnsproxy"),
|
||||
CacheEnabled: true,
|
||||
CacheSizeBytes: 4096,
|
||||
PrivateRDNSUpstreamConfig: srvConf.PrivateRDNSUpstreamConfig,
|
||||
|
||||
@@ -28,7 +28,6 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
@@ -100,7 +99,6 @@ func createTestServer(
|
||||
DHCPServer: dhcp,
|
||||
DNSFilter: f,
|
||||
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
Logger: slogutil.NewDiscardLogger(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -341,10 +339,7 @@ func TestServer_timeout(t *testing.T) {
|
||||
ServePlainDNS: true,
|
||||
}
|
||||
|
||||
s, err := NewServer(DNSCreateParams{
|
||||
DNSFilter: createTestDNSFilter(t),
|
||||
Logger: slogutil.NewDiscardLogger(),
|
||||
})
|
||||
s, err := NewServer(DNSCreateParams{DNSFilter: createTestDNSFilter(t)})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = s.Prepare(srvConf)
|
||||
@@ -354,10 +349,7 @@ func TestServer_timeout(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("default", func(t *testing.T) {
|
||||
s, err := NewServer(DNSCreateParams{
|
||||
DNSFilter: createTestDNSFilter(t),
|
||||
Logger: slogutil.NewDiscardLogger(),
|
||||
})
|
||||
s, err := NewServer(DNSCreateParams{DNSFilter: createTestDNSFilter(t)})
|
||||
require.NoError(t, err)
|
||||
|
||||
s.conf.Config.UpstreamMode = UpstreamModeLoadBalance
|
||||
@@ -384,9 +376,7 @@ func TestServer_Prepare_fallbacks(t *testing.T) {
|
||||
ServePlainDNS: true,
|
||||
}
|
||||
|
||||
s, err := NewServer(DNSCreateParams{
|
||||
Logger: slogutil.NewDiscardLogger(),
|
||||
})
|
||||
s, err := NewServer(DNSCreateParams{})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = s.Prepare(srvConf)
|
||||
@@ -972,7 +962,6 @@ func TestBlockedCustomIP(t *testing.T) {
|
||||
DHCPServer: dhcp,
|
||||
DNSFilter: f,
|
||||
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
Logger: slogutil.NewDiscardLogger(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1138,7 +1127,6 @@ func TestRewrite(t *testing.T) {
|
||||
DHCPServer: dhcp,
|
||||
DNSFilter: f,
|
||||
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
Logger: slogutil.NewDiscardLogger(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1268,7 +1256,6 @@ func TestPTRResponseFromDHCPLeases(t *testing.T) {
|
||||
},
|
||||
},
|
||||
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
Logger: slogutil.NewDiscardLogger(),
|
||||
LocalDomain: localDomain,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
@@ -1354,7 +1341,6 @@ func TestPTRResponseFromHosts(t *testing.T) {
|
||||
DHCPServer: dhcp,
|
||||
DNSFilter: flt,
|
||||
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
Logger: slogutil.NewDiscardLogger(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1406,29 +1392,24 @@ func TestNewServer(t *testing.T) {
|
||||
in DNSCreateParams
|
||||
wantErrMsg string
|
||||
}{{
|
||||
name: "success",
|
||||
in: DNSCreateParams{
|
||||
Logger: slogutil.NewDiscardLogger(),
|
||||
},
|
||||
name: "success",
|
||||
in: DNSCreateParams{},
|
||||
wantErrMsg: "",
|
||||
}, {
|
||||
name: "success_local_tld",
|
||||
in: DNSCreateParams{
|
||||
Logger: slogutil.NewDiscardLogger(),
|
||||
LocalDomain: "mynet",
|
||||
},
|
||||
wantErrMsg: "",
|
||||
}, {
|
||||
name: "success_local_domain",
|
||||
in: DNSCreateParams{
|
||||
Logger: slogutil.NewDiscardLogger(),
|
||||
LocalDomain: "my.local.net",
|
||||
},
|
||||
wantErrMsg: "",
|
||||
}, {
|
||||
name: "bad_local_domain",
|
||||
in: DNSCreateParams{
|
||||
Logger: slogutil.NewDiscardLogger(),
|
||||
LocalDomain: "!!!",
|
||||
},
|
||||
wantErrMsg: `local domain: bad domain name "!!!": ` +
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -58,7 +57,6 @@ func TestHandleDNSRequest_handleDNSRequest(t *testing.T) {
|
||||
},
|
||||
DNSFilter: f,
|
||||
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
Logger: slogutil.NewDiscardLogger(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -231,7 +229,6 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
|
||||
DHCPServer: &testDHCP{},
|
||||
DNSFilter: f,
|
||||
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
Logger: slogutil.NewDiscardLogger(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
@@ -228,7 +228,7 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
|
||||
}, {
|
||||
name: "upstream_dns_bad",
|
||||
wantSet: `validating dns config: upstream servers: parsing error at index 0: ` +
|
||||
`cannot prepare the upstream: invalid address !!!: bad domain name "!!!": ` +
|
||||
`cannot prepare the upstream: invalid address !!!: bad hostname "!!!": ` +
|
||||
`bad top-level domain name label "!!!": bad top-level domain name label rune '!'`,
|
||||
}, {
|
||||
name: "bootstraps_bad",
|
||||
|
||||
@@ -1,43 +1,28 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/ipset"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// ipsetHandler is the ipset context. ipsetMgr can be nil.
|
||||
type ipsetHandler struct {
|
||||
// ipsetCtx is the ipset context. ipsetMgr can be nil.
|
||||
type ipsetCtx struct {
|
||||
ipsetMgr ipset.Manager
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// newIpsetHandler returns a new initialized [ipsetHandler]. It is not safe for
|
||||
// concurrent use.
|
||||
func newIpsetHandler(
|
||||
ctx context.Context,
|
||||
logger *slog.Logger,
|
||||
ipsetList []string,
|
||||
) (h *ipsetHandler, err error) {
|
||||
h = &ipsetHandler{
|
||||
logger: logger,
|
||||
}
|
||||
conf := &ipset.Config{
|
||||
Logger: logger,
|
||||
Lines: ipsetList,
|
||||
}
|
||||
h.ipsetMgr, err = ipset.NewManager(ctx, conf)
|
||||
if errors.Is(err, os.ErrInvalid) ||
|
||||
errors.Is(err, os.ErrPermission) ||
|
||||
errors.Is(err, errors.ErrUnsupported) {
|
||||
// init initializes the ipset context. It is not safe for concurrent use.
|
||||
//
|
||||
// TODO(a.garipov): Rewrite into a simple constructor?
|
||||
func (c *ipsetCtx) init(ipsetConf []string) (err error) {
|
||||
c.ipsetMgr, err = ipset.NewManager(ipsetConf)
|
||||
if errors.Is(err, os.ErrInvalid) || errors.Is(err, os.ErrPermission) {
|
||||
// ipset cannot currently be initialized if the server was installed
|
||||
// from Snap or when the user or the binary doesn't have the required
|
||||
// permissions, or when the kernel doesn't support netfilter.
|
||||
@@ -46,28 +31,30 @@ func newIpsetHandler(
|
||||
//
|
||||
// TODO(a.garipov): The Snap problem can probably be solved if we add
|
||||
// the netlink-connector interface plug.
|
||||
logger.WarnContext(ctx, "cannot initialize", slogutil.KeyError, err)
|
||||
log.Info("ipset: warning: cannot initialize: %s", err)
|
||||
|
||||
return h, nil
|
||||
return nil
|
||||
} else if errors.Is(err, errors.ErrUnsupported) {
|
||||
log.Info("ipset: warning: %s", err)
|
||||
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return nil, fmt.Errorf("initializing ipset: %w", err)
|
||||
}
|
||||
|
||||
return h, nil
|
||||
}
|
||||
|
||||
// close closes the Linux Netfilter connections. close can be called on a nil
|
||||
// handler.
|
||||
func (h *ipsetHandler) close() (err error) {
|
||||
if h != nil && h.ipsetMgr != nil {
|
||||
return h.ipsetMgr.Close()
|
||||
return fmt.Errorf("initializing ipset: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// dctxIsFilled returns true if dctx has enough information to process.
|
||||
func dctxIsFilled(dctx *dnsContext) (ok bool) {
|
||||
// close closes the Linux Netfilter connections.
|
||||
func (c *ipsetCtx) close() (err error) {
|
||||
if c.ipsetMgr != nil {
|
||||
return c.ipsetMgr.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ipsetCtx) dctxIsfilled(dctx *dnsContext) (ok bool) {
|
||||
return dctx != nil &&
|
||||
dctx.responseFromUpstream &&
|
||||
dctx.proxyCtx != nil &&
|
||||
@@ -78,8 +65,8 @@ func dctxIsFilled(dctx *dnsContext) (ok bool) {
|
||||
|
||||
// skipIpsetProcessing returns true when the ipset processing can be skipped for
|
||||
// this request.
|
||||
func (h *ipsetHandler) skipIpsetProcessing(dctx *dnsContext) (ok bool) {
|
||||
if h == nil || h.ipsetMgr == nil || !dctxIsFilled(dctx) {
|
||||
func (c *ipsetCtx) skipIpsetProcessing(dctx *dnsContext) (ok bool) {
|
||||
if c == nil || c.ipsetMgr == nil || !c.dctxIsfilled(dctx) {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -121,31 +108,31 @@ func ipsFromAnswer(ans []dns.RR) (ip4s, ip6s []net.IP) {
|
||||
}
|
||||
|
||||
// process adds the resolved IP addresses to the domain's ipsets, if any.
|
||||
func (h *ipsetHandler) process(dctx *dnsContext) (rc resultCode) {
|
||||
// TODO(s.chzhen): Use passed context.
|
||||
ctx := context.TODO()
|
||||
h.logger.DebugContext(ctx, "started processing")
|
||||
defer h.logger.DebugContext(ctx, "finished processing")
|
||||
func (c *ipsetCtx) process(dctx *dnsContext) (rc resultCode) {
|
||||
log.Debug("dnsforward: ipset: started processing")
|
||||
defer log.Debug("dnsforward: ipset: finished processing")
|
||||
|
||||
if h.skipIpsetProcessing(dctx) {
|
||||
if c.skipIpsetProcessing(dctx) {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
log.Debug("ipset: starting processing")
|
||||
|
||||
req := dctx.proxyCtx.Req
|
||||
host := req.Question[0].Name
|
||||
host = strings.TrimSuffix(host, ".")
|
||||
host = strings.ToLower(host)
|
||||
|
||||
ip4s, ip6s := ipsFromAnswer(dctx.proxyCtx.Res.Answer)
|
||||
n, err := h.ipsetMgr.Add(ctx, host, ip4s, ip6s)
|
||||
n, err := c.ipsetMgr.Add(host, ip4s, ip6s)
|
||||
if err != nil {
|
||||
// Consider ipset errors non-critical to the request.
|
||||
h.logger.ErrorContext(ctx, "adding host ips", slogutil.KeyError, err)
|
||||
log.Error("dnsforward: ipset: adding host ips: %s", err)
|
||||
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
h.logger.DebugContext(ctx, "added new ipset entries", "num", n)
|
||||
log.Debug("dnsforward: ipset: added %d new ipset entries", n)
|
||||
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
@@ -18,7 +16,7 @@ type fakeIpsetMgr struct {
|
||||
}
|
||||
|
||||
// Add implements the aghnet.IpsetManager interface for *fakeIpsetMgr.
|
||||
func (m *fakeIpsetMgr) Add(_ context.Context, host string, ip4s, ip6s []net.IP) (n int, err error) {
|
||||
func (m *fakeIpsetMgr) Add(host string, ip4s, ip6s []net.IP) (n int, err error) {
|
||||
m.ip4s = append(m.ip4s, ip4s...)
|
||||
m.ip6s = append(m.ip6s, ip6s...)
|
||||
|
||||
@@ -60,9 +58,7 @@ func TestIpsetCtx_process(t *testing.T) {
|
||||
responseFromUpstream: true,
|
||||
}
|
||||
|
||||
ictx := &ipsetHandler{
|
||||
logger: slogutil.NewDiscardLogger(),
|
||||
}
|
||||
ictx := &ipsetCtx{}
|
||||
rc := ictx.process(dctx)
|
||||
assert.Equal(t, resultCodeSuccess, rc)
|
||||
|
||||
@@ -81,9 +77,8 @@ func TestIpsetCtx_process(t *testing.T) {
|
||||
}
|
||||
|
||||
m := &fakeIpsetMgr{}
|
||||
ictx := &ipsetHandler{
|
||||
ictx := &ipsetCtx{
|
||||
ipsetMgr: m,
|
||||
logger: slogutil.NewDiscardLogger(),
|
||||
}
|
||||
|
||||
rc := ictx.process(dctx)
|
||||
@@ -106,9 +101,8 @@ func TestIpsetCtx_process(t *testing.T) {
|
||||
}
|
||||
|
||||
m := &fakeIpsetMgr{}
|
||||
ictx := &ipsetHandler{
|
||||
ictx := &ipsetCtx{
|
||||
ipsetMgr: m,
|
||||
logger: slogutil.NewDiscardLogger(),
|
||||
}
|
||||
|
||||
rc := ictx.process(dctx)
|
||||
@@ -130,9 +124,8 @@ func TestIpsetCtx_SkipIpsetProcessing(t *testing.T) {
|
||||
}
|
||||
|
||||
m := &fakeIpsetMgr{}
|
||||
ictx := &ipsetHandler{
|
||||
ictx := &ipsetCtx{
|
||||
ipsetMgr: m,
|
||||
logger: slogutil.NewDiscardLogger(),
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
|
||||
@@ -58,7 +58,7 @@ func (s *Server) genDNSFilterMessage(
|
||||
return s.replyCompressed(req)
|
||||
}
|
||||
|
||||
return s.NewMsgNODATA(req)
|
||||
return s.newMsgNODATA(req)
|
||||
}
|
||||
|
||||
switch res.Reason {
|
||||
@@ -344,6 +344,51 @@ func (s *Server) makeResponseREFUSED(req *dns.Msg) *dns.Msg {
|
||||
return s.reply(req, dns.RcodeRefused)
|
||||
}
|
||||
|
||||
// newMsgNODATA returns a properly initialized NODATA response.
|
||||
//
|
||||
// See https://www.rfc-editor.org/rfc/rfc2308#section-2.2.
|
||||
func (s *Server) newMsgNODATA(req *dns.Msg) (resp *dns.Msg) {
|
||||
resp = s.reply(req, dns.RcodeSuccess)
|
||||
resp.Ns = s.genSOA(req)
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
func (s *Server) genSOA(request *dns.Msg) []dns.RR {
|
||||
zone := ""
|
||||
if len(request.Question) > 0 {
|
||||
zone = request.Question[0].Name
|
||||
}
|
||||
|
||||
soa := dns.SOA{
|
||||
// values copied from verisign's nonexistent .com domain
|
||||
// their exact values are not important in our use case because they are used for domain transfers between primary/secondary DNS servers
|
||||
Refresh: 1800,
|
||||
Retry: 900,
|
||||
Expire: 604800,
|
||||
Minttl: 86400,
|
||||
// copied from AdGuard DNS
|
||||
Ns: "fake-for-negative-caching.adguard.com.",
|
||||
Serial: 100500,
|
||||
// rest is request-specific
|
||||
Hdr: dns.RR_Header{
|
||||
Name: zone,
|
||||
Rrtype: dns.TypeSOA,
|
||||
Ttl: s.dnsFilter.BlockedResponseTTL(),
|
||||
Class: dns.ClassINET,
|
||||
},
|
||||
Mbox: "hostmaster.", // zone will be appended later if it's not empty or "."
|
||||
}
|
||||
if soa.Hdr.Ttl == 0 {
|
||||
soa.Hdr.Ttl = defaultBlockedResponseTTL
|
||||
}
|
||||
if len(zone) > 0 && zone[0] != '.' {
|
||||
soa.Mbox += zone
|
||||
}
|
||||
|
||||
return []dns.RR{&soa}
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ proxy.MessageConstructor = (*Server)(nil)
|
||||
|
||||
@@ -380,52 +425,3 @@ func (s *Server) NewMsgNOTIMPLEMENTED(req *dns.Msg) (resp *dns.Msg) {
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
// NewMsgNODATA implements the [proxy.MessageConstructor] interface for *Server.
|
||||
func (s *Server) NewMsgNODATA(req *dns.Msg) (resp *dns.Msg) {
|
||||
resp = s.reply(req, dns.RcodeSuccess)
|
||||
resp.Ns = s.genSOA(req)
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
func (s *Server) genSOA(req *dns.Msg) []dns.RR {
|
||||
zone := ""
|
||||
if len(req.Question) > 0 {
|
||||
zone = req.Question[0].Name
|
||||
}
|
||||
|
||||
const defaultBlockedResponseTTL = 3600
|
||||
|
||||
soa := dns.SOA{
|
||||
// Values copied from verisign's nonexistent.com domain.
|
||||
//
|
||||
// Their exact values are not important in our use case because they are
|
||||
// used for domain transfers between primary/secondary DNS servers.
|
||||
Refresh: 1800,
|
||||
Retry: 900,
|
||||
Expire: 604800,
|
||||
Minttl: 86400,
|
||||
// copied from AdGuard DNS
|
||||
Ns: "fake-for-negative-caching.adguard.com.",
|
||||
Serial: 100500,
|
||||
// rest is request-specific
|
||||
Hdr: dns.RR_Header{
|
||||
Name: zone,
|
||||
Rrtype: dns.TypeSOA,
|
||||
Ttl: s.dnsFilter.BlockedResponseTTL(),
|
||||
Class: dns.ClassINET,
|
||||
},
|
||||
// zone will be appended later if it's not ".".
|
||||
Mbox: "hostmaster.",
|
||||
}
|
||||
if soa.Hdr.Ttl == 0 {
|
||||
soa.Hdr.Ttl = defaultBlockedResponseTTL
|
||||
}
|
||||
|
||||
if zone != "." {
|
||||
soa.Mbox += zone
|
||||
}
|
||||
|
||||
return []dns.RR{&soa}
|
||||
}
|
||||
|
||||
@@ -159,7 +159,7 @@ func (s *Server) processInitial(dctx *dnsContext) (rc resultCode) {
|
||||
q := pctx.Req.Question[0]
|
||||
qt := q.Qtype
|
||||
if s.conf.AAAADisabled && qt == dns.TypeAAAA {
|
||||
pctx.Res = s.NewMsgNODATA(pctx.Req)
|
||||
_ = proxy.CheckDisabledAAAARequest(pctx, true)
|
||||
|
||||
return resultCodeFinish
|
||||
}
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
@@ -431,7 +430,6 @@ func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) {
|
||||
dnsFilter: createTestDNSFilter(t),
|
||||
dhcpServer: dhcp,
|
||||
localDomainSuffix: localDomainSuffix,
|
||||
baseLogger: slogutil.NewDiscardLogger(),
|
||||
}
|
||||
|
||||
req := &dns.Msg{
|
||||
@@ -567,7 +565,6 @@ func TestServer_ProcessDHCPHosts(t *testing.T) {
|
||||
dnsFilter: createTestDNSFilter(t),
|
||||
dhcpServer: testDHCP,
|
||||
localDomainSuffix: tc.suffix,
|
||||
baseLogger: slogutil.NewDiscardLogger(),
|
||||
}
|
||||
|
||||
req := &dns.Msg{
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -203,7 +202,6 @@ func TestServer_ProcessQueryLogsAndStats(t *testing.T) {
|
||||
ql := &testQueryLog{}
|
||||
st := &testStats{}
|
||||
srv := &Server{
|
||||
baseLogger: slogutil.NewDiscardLogger(),
|
||||
queryLog: ql,
|
||||
stats: st,
|
||||
anonymizer: aghnet.NewIPMut(nil),
|
||||
|
||||
@@ -150,12 +150,12 @@ func setProxyUpstreamMode(
|
||||
) (err error) {
|
||||
switch upstreamMode {
|
||||
case UpstreamModeParallel:
|
||||
conf.UpstreamMode = proxy.UpstreamModeParallel
|
||||
conf.UpstreamMode = proxy.UModeParallel
|
||||
case UpstreamModeFastestAddr:
|
||||
conf.UpstreamMode = proxy.UpstreamModeFastestAddr
|
||||
conf.UpstreamMode = proxy.UModeFastestAddr
|
||||
conf.FastestPingTimeout = fastestTimeout
|
||||
case UpstreamModeLoadBalance:
|
||||
conf.UpstreamMode = proxy.UpstreamModeLoadBalance
|
||||
conf.UpstreamMode = proxy.UModeLoadBalance
|
||||
default:
|
||||
return fmt.Errorf("unexpected value %q", upstreamMode)
|
||||
}
|
||||
|
||||
@@ -47,6 +47,7 @@ func fromCacheItem(item *cacheItem) (data []byte) {
|
||||
data = binary.BigEndian.AppendUint64(data, uint64(expiry))
|
||||
|
||||
for _, v := range item.hashes {
|
||||
// nolint:looppointer // The subslice of v is used for a copy.
|
||||
data = append(data, v[:]...)
|
||||
}
|
||||
|
||||
@@ -62,6 +63,7 @@ func (c *Checker) findInCache(
|
||||
|
||||
i := 0
|
||||
for _, hash := range hashes {
|
||||
// nolint:looppointer // The has subslice is used for a cache lookup.
|
||||
data := c.cache.Get(hash[:prefixLen])
|
||||
if data == nil {
|
||||
hashes[i] = hash
|
||||
@@ -96,6 +98,7 @@ func (c *Checker) storeInCache(hashesToRequest, respHashes []hostnameHash) {
|
||||
|
||||
for _, hash := range respHashes {
|
||||
var pref prefix
|
||||
// nolint:looppointer // The hash subslice is used for a copy.
|
||||
copy(pref[:], hash[:])
|
||||
|
||||
hashToStore[pref] = append(hashToStore[pref], hash)
|
||||
@@ -106,9 +109,11 @@ func (c *Checker) storeInCache(hashesToRequest, respHashes []hostnameHash) {
|
||||
}
|
||||
|
||||
for _, hash := range hashesToRequest {
|
||||
// nolint:looppointer // The hash subslice is used for a cache lookup.
|
||||
val := c.cache.Get(hash[:prefixLen])
|
||||
if val == nil {
|
||||
var pref prefix
|
||||
// nolint:looppointer // The hash subslice is used for a copy.
|
||||
copy(pref[:], hash[:])
|
||||
|
||||
c.setCache(pref, nil)
|
||||
|
||||
@@ -173,6 +173,7 @@ func (c *Checker) getQuestion(hashes []hostnameHash) (q string) {
|
||||
b := &strings.Builder{}
|
||||
|
||||
for _, hash := range hashes {
|
||||
// nolint:looppointer // The hash subslice is used for hex encoding.
|
||||
stringutil.WriteToBuilder(b, hex.EncodeToString(hash[:prefixLen]), ".")
|
||||
}
|
||||
|
||||
|
||||
@@ -22,7 +22,6 @@ type SafeSearchConfig struct {
|
||||
|
||||
Bing bool `yaml:"bing" json:"bing"`
|
||||
DuckDuckGo bool `yaml:"duckduckgo" json:"duckduckgo"`
|
||||
Ecosia bool `yaml:"ecosia" json:"ecosia"`
|
||||
Google bool `yaml:"google" json:"google"`
|
||||
Pixabay bool `yaml:"pixabay" json:"pixabay"`
|
||||
Yandex bool `yaml:"yandex" json:"yandex"`
|
||||
|
||||
@@ -14,9 +14,6 @@ var pixabay string
|
||||
//go:embed rules/duckduckgo.txt
|
||||
var duckduckgo string
|
||||
|
||||
//go:embed rules/ecosia.txt
|
||||
var ecosia string
|
||||
|
||||
//go:embed rules/yandex.txt
|
||||
var yandex string
|
||||
|
||||
@@ -30,7 +27,6 @@ var youtube string
|
||||
var safeSearchRules = map[Service]string{
|
||||
Bing: bing,
|
||||
DuckDuckGo: duckduckgo,
|
||||
Ecosia: ecosia,
|
||||
Google: google,
|
||||
Pixabay: pixabay,
|
||||
Yandex: yandex,
|
||||
|
||||
@@ -1,2 +1 @@
|
||||
|www.bing.com^$dnsrewrite=NOERROR;CNAME;strict.bing.com
|
||||
|edgeservices.bing.com^$dnsrewrite=NOERROR;CNAME;strict.bing.com
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
|www.ecosia.org^$dnsrewrite=NOERROR;CNAME;strict-safe-search.ecosia.org
|
||||
@@ -46,9 +46,6 @@
|
||||
|www.google.co.uz^$dnsrewrite=NOERROR;CNAME;forcesafesearch.google.com
|
||||
|www.google.co.ve^$dnsrewrite=NOERROR;CNAME;forcesafesearch.google.com
|
||||
|www.google.co.vi^$dnsrewrite=NOERROR;CNAME;forcesafesearch.google.com
|
||||
|www.google.co.za^$dnsrewrite=NOERROR;CNAME;forcesafesearch.google.com
|
||||
|www.google.co.zm^$dnsrewrite=NOERROR;CNAME;forcesafesearch.google.com
|
||||
|www.google.co.zw^$dnsrewrite=NOERROR;CNAME;forcesafesearch.google.com
|
||||
|www.google.com.af^$dnsrewrite=NOERROR;CNAME;forcesafesearch.google.com
|
||||
|www.google.com.ag^$dnsrewrite=NOERROR;CNAME;forcesafesearch.google.com
|
||||
|www.google.com.ai^$dnsrewrite=NOERROR;CNAME;forcesafesearch.google.com
|
||||
|
||||
@@ -28,7 +28,6 @@ type Service string
|
||||
const (
|
||||
Bing Service = "bing"
|
||||
DuckDuckGo Service = "duckduckgo"
|
||||
Ecosia Service = "ecosia"
|
||||
Google Service = "google"
|
||||
Pixabay Service = "pixabay"
|
||||
Yandex Service = "yandex"
|
||||
@@ -42,8 +41,6 @@ func isServiceProtected(s filtering.SafeSearchConfig, service Service) (ok bool)
|
||||
return s.Bing
|
||||
case DuckDuckGo:
|
||||
return s.DuckDuckGo
|
||||
case Ecosia:
|
||||
return s.Ecosia
|
||||
case Google:
|
||||
return s.Google
|
||||
case Pixabay:
|
||||
|
||||
@@ -25,7 +25,6 @@ var defaultSafeSearchConf = filtering.SafeSearchConfig{
|
||||
Enabled: true,
|
||||
Bing: true,
|
||||
DuckDuckGo: true,
|
||||
Ecosia: true,
|
||||
Google: true,
|
||||
Pixabay: true,
|
||||
Yandex: true,
|
||||
|
||||
@@ -34,7 +34,6 @@ var testConf = filtering.SafeSearchConfig{
|
||||
|
||||
Bing: true,
|
||||
DuckDuckGo: true,
|
||||
Ecosia: true,
|
||||
Google: true,
|
||||
Pixabay: true,
|
||||
Yandex: true,
|
||||
|
||||
@@ -97,7 +97,6 @@ func glGetTokenDate(file string) uint32 {
|
||||
|
||||
buf := bytes.NewBuffer(bs)
|
||||
|
||||
// TODO(a.garipov): Get rid of github.com/josharian/native dependency.
|
||||
err = binary.Read(buf, native.Endian, &dateToken)
|
||||
if err != nil {
|
||||
log.Error("decoding token: %s", err)
|
||||
|
||||
@@ -47,6 +47,9 @@ type clientsContainer struct {
|
||||
// storage stores information about persistent clients.
|
||||
storage *client.Storage
|
||||
|
||||
// runtimeIndex stores information about runtime clients.
|
||||
runtimeIndex *client.RuntimeIndex
|
||||
|
||||
// dhcp is the DHCP service implementation.
|
||||
dhcp DHCP
|
||||
|
||||
@@ -102,6 +105,8 @@ func (clients *clientsContainer) Init(
|
||||
return errors.Error("clients container already initialized")
|
||||
}
|
||||
|
||||
clients.runtimeIndex = client.NewRuntimeIndex()
|
||||
|
||||
clients.storage = client.NewStorage(&client.Config{
|
||||
AllowedTags: clientTags,
|
||||
})
|
||||
@@ -353,7 +358,7 @@ func (clients *clientsContainer) clientSource(ip netip.Addr) (src client.Source)
|
||||
return client.SourcePersistent
|
||||
}
|
||||
|
||||
rc := clients.storage.ClientRuntime(ip)
|
||||
rc := clients.runtimeIndex.Client(ip)
|
||||
if rc != nil {
|
||||
src, _ = rc.Info()
|
||||
}
|
||||
@@ -534,9 +539,22 @@ func (clients *clientsContainer) findDHCP(ip netip.Addr) (c *client.Persistent,
|
||||
return clients.storage.FindByMAC(foundMAC)
|
||||
}
|
||||
|
||||
// runtimeClient returns a runtime client from internal index. Note that it
|
||||
// doesn't include DHCP clients.
|
||||
func (clients *clientsContainer) runtimeClient(ip netip.Addr) (rc *client.Runtime) {
|
||||
if ip == (netip.Addr{}) {
|
||||
return nil
|
||||
}
|
||||
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
return clients.runtimeIndex.Client(ip)
|
||||
}
|
||||
|
||||
// findRuntimeClient finds a runtime client by their IP.
|
||||
func (clients *clientsContainer) findRuntimeClient(ip netip.Addr) (rc *client.Runtime) {
|
||||
rc = clients.storage.ClientRuntime(ip)
|
||||
rc = clients.runtimeClient(ip)
|
||||
host := clients.dhcp.HostByIP(ip)
|
||||
|
||||
if host != "" {
|
||||
@@ -562,11 +580,20 @@ func (clients *clientsContainer) setWHOISInfo(ip netip.Addr, wi *whois.Info) {
|
||||
return
|
||||
}
|
||||
|
||||
rc := client.NewRuntime(ip)
|
||||
rc.SetWHOIS(wi)
|
||||
clients.storage.UpdateRuntime(rc)
|
||||
rc := clients.runtimeIndex.Client(ip)
|
||||
if rc == nil {
|
||||
// Create a RuntimeClient implicitly so that we don't do this check
|
||||
// again.
|
||||
rc = client.NewRuntime(ip)
|
||||
clients.runtimeIndex.Add(rc)
|
||||
|
||||
log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi)
|
||||
log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi)
|
||||
} else {
|
||||
host, _ := rc.Info()
|
||||
log.Debug("clients: set whois info for runtime client %s: %+v", host, wi)
|
||||
}
|
||||
|
||||
rc.SetWHOIS(wi)
|
||||
}
|
||||
|
||||
// addHost adds a new IP-hostname pairing. The priorities of the sources are
|
||||
@@ -617,20 +644,26 @@ func (clients *clientsContainer) addHostLocked(
|
||||
host string,
|
||||
src client.Source,
|
||||
) (ok bool) {
|
||||
rc := client.NewRuntime(ip)
|
||||
rc.SetInfo(src, []string{host})
|
||||
if dhcpHost := clients.dhcp.HostByIP(ip); dhcpHost != "" {
|
||||
rc.SetInfo(client.SourceDHCP, []string{dhcpHost})
|
||||
rc := clients.runtimeIndex.Client(ip)
|
||||
if rc == nil {
|
||||
if src < client.SourceDHCP {
|
||||
if clients.dhcp.HostByIP(ip) != "" {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
rc = client.NewRuntime(ip)
|
||||
clients.runtimeIndex.Add(rc)
|
||||
}
|
||||
|
||||
clients.storage.UpdateRuntime(rc)
|
||||
rc.SetInfo(src, []string{host})
|
||||
|
||||
log.Debug(
|
||||
"clients: adding client info %s -> %q %q [%d]",
|
||||
ip,
|
||||
src,
|
||||
host,
|
||||
clients.storage.SizeRuntime(),
|
||||
clients.runtimeIndex.Size(),
|
||||
)
|
||||
|
||||
return true
|
||||
@@ -642,22 +675,23 @@ func (clients *clientsContainer) addFromHostsFile(hosts *hostsfile.DefaultStorag
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
var rcs []*client.Runtime
|
||||
deleted := clients.runtimeIndex.DeleteBySource(client.SourceHostsFile)
|
||||
log.Debug("clients: removed %d client aliases from system hosts file", deleted)
|
||||
|
||||
added := 0
|
||||
hosts.RangeNames(func(addr netip.Addr, names []string) (cont bool) {
|
||||
// Only the first name of the first record is considered a canonical
|
||||
// hostname for the IP address.
|
||||
//
|
||||
// TODO(e.burkov): Consider using all the names from all the records.
|
||||
rc := client.NewRuntime(addr)
|
||||
rc.SetInfo(client.SourceHostsFile, []string{names[0]})
|
||||
|
||||
rcs = append(rcs, rc)
|
||||
if clients.addHostLocked(addr, names[0], client.SourceHostsFile) {
|
||||
added++
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
added, removed := clients.storage.BatchUpdateBySource(client.SourceHostsFile, rcs)
|
||||
log.Debug("clients: added %d, removed %d client aliases from system hosts file", added, removed)
|
||||
log.Debug("clients: added %d client aliases from system hosts file", added)
|
||||
}
|
||||
|
||||
// addFromSystemARP adds the IP-hostname pairings from the output of the arp -a
|
||||
@@ -681,16 +715,17 @@ func (clients *clientsContainer) addFromSystemARP() {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
var rcs []*client.Runtime
|
||||
for _, n := range ns {
|
||||
rc := client.NewRuntime(n.IP)
|
||||
rc.SetInfo(client.SourceARP, []string{n.Name})
|
||||
deleted := clients.runtimeIndex.DeleteBySource(client.SourceARP)
|
||||
log.Debug("clients: removed %d client aliases from arp neighborhood", deleted)
|
||||
|
||||
rcs = append(rcs, rc)
|
||||
added := 0
|
||||
for _, n := range ns {
|
||||
if clients.addHostLocked(n.IP, n.Name, client.SourceARP) {
|
||||
added++
|
||||
}
|
||||
}
|
||||
|
||||
added, removed := clients.storage.BatchUpdateBySource(client.SourceARP, rcs)
|
||||
log.Debug("clients: added %d, removed %d client aliases from arp neighborhood", added, removed)
|
||||
log.Debug("clients: added %d client aliases from arp neighborhood", added)
|
||||
}
|
||||
|
||||
// close gracefully closes all the client-specific upstream configurations of
|
||||
|
||||
@@ -240,7 +240,7 @@ func TestClientsWHOIS(t *testing.T) {
|
||||
t.Run("new_client", func(t *testing.T) {
|
||||
ip := netip.MustParseAddr("1.1.1.255")
|
||||
clients.setWHOISInfo(ip, whois)
|
||||
rc := clients.storage.ClientRuntime(ip)
|
||||
rc := clients.runtimeIndex.Client(ip)
|
||||
require.NotNil(t, rc)
|
||||
|
||||
assert.Equal(t, whois, rc.WHOIS())
|
||||
@@ -252,7 +252,7 @@ func TestClientsWHOIS(t *testing.T) {
|
||||
assert.True(t, ok)
|
||||
|
||||
clients.setWHOISInfo(ip, whois)
|
||||
rc := clients.storage.ClientRuntime(ip)
|
||||
rc := clients.runtimeIndex.Client(ip)
|
||||
require.NotNil(t, rc)
|
||||
|
||||
assert.Equal(t, whois, rc.WHOIS())
|
||||
@@ -269,7 +269,7 @@ func TestClientsWHOIS(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
clients.setWHOISInfo(ip, whois)
|
||||
rc := clients.storage.ClientRuntime(ip)
|
||||
rc := clients.runtimeIndex.Client(ip)
|
||||
require.Nil(t, rc)
|
||||
|
||||
assert.True(t, clients.storage.RemoveByName("client1"))
|
||||
|
||||
@@ -103,7 +103,7 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
|
||||
return true
|
||||
})
|
||||
|
||||
clients.storage.RangeRuntime(func(rc *client.Runtime) (cont bool) {
|
||||
clients.runtimeIndex.Range(func(rc *client.Runtime) (cont bool) {
|
||||
src, host := rc.Info()
|
||||
cj := runtimeClientJSON{
|
||||
WHOIS: whoisOrEmpty(rc),
|
||||
@@ -248,7 +248,6 @@ func copySafeSearch(
|
||||
if conf.Enabled {
|
||||
conf.Bing = true
|
||||
conf.DuckDuckGo = true
|
||||
conf.Ecosia = true
|
||||
conf.Google = true
|
||||
conf.Pixabay = true
|
||||
conf.Yandex = true
|
||||
|
||||
@@ -423,7 +423,6 @@ var config = &configuration{
|
||||
Enabled: false,
|
||||
Bing: true,
|
||||
DuckDuckGo: true,
|
||||
Ecosia: true,
|
||||
Google: true,
|
||||
Pixabay: true,
|
||||
Yandex: true,
|
||||
|
||||
@@ -433,7 +433,7 @@ func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request
|
||||
// moment we'll allow setting up TLS in the initial configuration or the
|
||||
// configuration itself will use HTTPS protocol, because the underlying
|
||||
// functions potentially restart the HTTPS server.
|
||||
err = startMods(web.logger)
|
||||
err = startMods()
|
||||
if err != nil {
|
||||
Context.firstRun = true
|
||||
copyInstallSettings(config, curConfig)
|
||||
|
||||
@@ -2,7 +2,6 @@ package home
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
@@ -20,7 +19,6 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/ameshkov/dnscrypt/v2"
|
||||
yaml "gopkg.in/yaml.v3"
|
||||
@@ -45,8 +43,8 @@ func onConfigModified() {
|
||||
|
||||
// initDNS updates all the fields of the [Context] needed to initialize the DNS
|
||||
// server and initializes it at last. It also must not be called unless
|
||||
// [config] and [Context] are initialized. l must not be nil.
|
||||
func initDNS(l *slog.Logger) (err error) {
|
||||
// [config] and [Context] are initialized.
|
||||
func initDNS() (err error) {
|
||||
anonymizer := config.anonymizer()
|
||||
|
||||
statsDir, querylogDir, err := checkStatsAndQuerylogDirs(&Context, config)
|
||||
@@ -55,7 +53,6 @@ func initDNS(l *slog.Logger) (err error) {
|
||||
}
|
||||
|
||||
statsConf := stats.Config{
|
||||
Logger: l.With(slogutil.KeyPrefix, "stats"),
|
||||
Filename: filepath.Join(statsDir, "stats.db"),
|
||||
Limit: config.Stats.Interval.Duration,
|
||||
ConfigModified: onConfigModified,
|
||||
@@ -116,16 +113,13 @@ func initDNS(l *slog.Logger) (err error) {
|
||||
anonymizer,
|
||||
httpRegister,
|
||||
tlsConf,
|
||||
l,
|
||||
)
|
||||
}
|
||||
|
||||
// initDNSServer initializes the [context.dnsServer]. To only use the internal
|
||||
// proxy, none of the arguments are required, but tlsConf and l still must not
|
||||
// be nil, in other cases all the arguments also must not be nil. It also must
|
||||
// not be called unless [config] and [Context] are initialized.
|
||||
//
|
||||
// TODO(e.burkov): Use [dnsforward.DNSCreateParams] as a parameter.
|
||||
// proxy, none of the arguments are required, but tlsConf still must not be nil,
|
||||
// in other cases all the arguments also must not be nil. It also must not be
|
||||
// called unless [config] and [Context] are initialized.
|
||||
func initDNSServer(
|
||||
filters *filtering.DNSFilter,
|
||||
sts stats.Interface,
|
||||
@@ -134,10 +128,8 @@ func initDNSServer(
|
||||
anonymizer *aghnet.IPMut,
|
||||
httpReg aghhttp.RegisterFunc,
|
||||
tlsConf *tlsConfigSettings,
|
||||
l *slog.Logger,
|
||||
) (err error) {
|
||||
Context.dnsServer, err = dnsforward.NewServer(dnsforward.DNSCreateParams{
|
||||
Logger: l,
|
||||
DNSFilter: filters,
|
||||
Stats: sts,
|
||||
QueryLog: qlog,
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
@@ -39,7 +38,6 @@ import (
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/hostsfile"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/osutil"
|
||||
)
|
||||
@@ -92,8 +90,6 @@ func (c *homeContext) getDataDir() string {
|
||||
}
|
||||
|
||||
// Context - a global context object
|
||||
//
|
||||
// TODO(a.garipov): Refactor.
|
||||
var Context homeContext
|
||||
|
||||
// Main is the entry point
|
||||
@@ -277,7 +273,7 @@ func setupOpts(opts options) (err error) {
|
||||
}
|
||||
|
||||
// initContextClients initializes Context clients and related fields.
|
||||
func initContextClients(logger *slog.Logger) (err error) {
|
||||
func initContextClients() (err error) {
|
||||
err = setupDNSFilteringConf(config.Filtering)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
@@ -301,7 +297,7 @@ func initContextClients(logger *slog.Logger) (err error) {
|
||||
|
||||
var arpDB arpdb.Interface
|
||||
if config.Clients.Sources.ARP {
|
||||
arpDB = arpdb.New(logger.With(slogutil.KeyError, "arpdb"))
|
||||
arpDB = arpdb.New()
|
||||
}
|
||||
|
||||
return Context.clients.Init(
|
||||
@@ -486,12 +482,7 @@ func checkPorts() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func initWeb(
|
||||
opts options,
|
||||
clientBuildFS fs.FS,
|
||||
upd *updater.Updater,
|
||||
l *slog.Logger,
|
||||
) (web *webAPI, err error) {
|
||||
func initWeb(opts options, clientBuildFS fs.FS, upd *updater.Updater) (web *webAPI, err error) {
|
||||
var clientFS fs.FS
|
||||
if opts.localFrontend {
|
||||
log.Info("warning: using local frontend files")
|
||||
@@ -533,7 +524,7 @@ func initWeb(
|
||||
serveHTTP3: config.DNS.ServeHTTP3,
|
||||
}
|
||||
|
||||
web = newWebAPI(webConf, l)
|
||||
web = newWebAPI(webConf)
|
||||
if web == nil {
|
||||
return nil, fmt.Errorf("initializing web: %w", err)
|
||||
}
|
||||
@@ -556,15 +547,10 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
|
||||
// Configure config filename.
|
||||
initConfigFilename(opts)
|
||||
|
||||
ls := getLogSettings(opts)
|
||||
|
||||
// Configure log level and output.
|
||||
err = configureLogger(ls)
|
||||
err = configureLogger(opts)
|
||||
fatalOnError(err)
|
||||
|
||||
// TODO(a.garipov): Use slog everywhere.
|
||||
slogLogger := newSlogLogger(ls)
|
||||
|
||||
// Print the first message after logger is configured.
|
||||
log.Info(version.Full())
|
||||
log.Debug("current working directory is %s", Context.workDir)
|
||||
@@ -583,7 +569,7 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
|
||||
// data first, but also to avoid relying on automatic Go init() function.
|
||||
filtering.InitModule()
|
||||
|
||||
err = initContextClients(slogLogger)
|
||||
err = initContextClients()
|
||||
fatalOnError(err)
|
||||
|
||||
err = setupOpts(opts)
|
||||
@@ -618,7 +604,7 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
|
||||
|
||||
// TODO(e.burkov): This could be made earlier, probably as the option's
|
||||
// effect.
|
||||
cmdlineUpdate(opts, upd, slogLogger)
|
||||
cmdlineUpdate(opts, upd)
|
||||
|
||||
if !Context.firstRun {
|
||||
// Save the updated config.
|
||||
@@ -646,11 +632,11 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
|
||||
onConfigModified()
|
||||
}
|
||||
|
||||
Context.web, err = initWeb(opts, clientBuildFS, upd, slogLogger)
|
||||
Context.web, err = initWeb(opts, clientBuildFS, upd)
|
||||
fatalOnError(err)
|
||||
|
||||
if !Context.firstRun {
|
||||
err = initDNS(slogLogger)
|
||||
err = initDNS()
|
||||
fatalOnError(err)
|
||||
|
||||
Context.tls.start()
|
||||
@@ -711,10 +697,9 @@ func (c *configuration) anonymizer() (ipmut *aghnet.IPMut) {
|
||||
return aghnet.NewIPMut(anonFunc)
|
||||
}
|
||||
|
||||
// startMods initializes and starts the DNS server after installation. l must
|
||||
// not be nil.
|
||||
func startMods(l *slog.Logger) (err error) {
|
||||
err = initDNS(l)
|
||||
// startMods initializes and starts the DNS server after installation.
|
||||
func startMods() (err error) {
|
||||
err = initDNS()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -974,8 +959,8 @@ type jsonError struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// cmdlineUpdate updates current application and exits. l must not be nil.
|
||||
func cmdlineUpdate(opts options, upd *updater.Updater, l *slog.Logger) {
|
||||
// cmdlineUpdate updates current application and exits.
|
||||
func cmdlineUpdate(opts options, upd *updater.Updater) {
|
||||
if !opts.performUpdate {
|
||||
return
|
||||
}
|
||||
@@ -985,7 +970,7 @@ func cmdlineUpdate(opts options, upd *updater.Updater, l *slog.Logger) {
|
||||
//
|
||||
// TODO(e.burkov): We could probably initialize the internal resolver
|
||||
// separately.
|
||||
err := initDNSServer(nil, nil, nil, nil, nil, nil, &tlsConfigSettings{}, l)
|
||||
err := initDNSServer(nil, nil, nil, nil, nil, nil, &tlsConfigSettings{})
|
||||
fatalOnError(err)
|
||||
|
||||
log.Info("cmdline update: performing update")
|
||||
|
||||
@@ -3,13 +3,11 @@ package home
|
||||
import (
|
||||
"cmp"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"gopkg.in/natefinch/lumberjack.v2"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
@@ -18,21 +16,10 @@ import (
|
||||
// for logger output.
|
||||
const configSyslog = "syslog"
|
||||
|
||||
// newSlogLogger returns new [*slog.Logger] configured with the given settings.
|
||||
func newSlogLogger(ls *logSettings) (l *slog.Logger) {
|
||||
if !ls.Enabled {
|
||||
return slogutil.NewDiscardLogger()
|
||||
}
|
||||
|
||||
return slogutil.New(&slogutil.Config{
|
||||
Format: slogutil.FormatAdGuardLegacy,
|
||||
AddTimestamp: true,
|
||||
Verbose: ls.Verbose,
|
||||
})
|
||||
}
|
||||
|
||||
// configureLogger configures logger level and output.
|
||||
func configureLogger(ls *logSettings) (err error) {
|
||||
func configureLogger(opts options) (err error) {
|
||||
ls := getLogSettings(opts)
|
||||
|
||||
// Configure logger level.
|
||||
if !ls.Enabled {
|
||||
log.SetLevel(log.OFF)
|
||||
@@ -73,7 +60,7 @@ func configureLogger(ls *logSettings) (err error) {
|
||||
MaxAge: ls.MaxAge,
|
||||
})
|
||||
|
||||
return err
|
||||
return nil
|
||||
}
|
||||
|
||||
// getLogSettings returns a log settings object properly initialized from opts.
|
||||
|
||||
@@ -5,15 +5,12 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/AdguardTeam/golibs/ioutil"
|
||||
"github.com/c2h5oh/datasize"
|
||||
)
|
||||
|
||||
// middlerware is a wrapper function signature.
|
||||
type middleware func(http.Handler) http.Handler
|
||||
|
||||
// withMiddlewares consequently wraps h with all the middlewares.
|
||||
//
|
||||
// TODO(e.burkov): Use [httputil.Wrap].
|
||||
func withMiddlewares(h http.Handler, middlewares ...middleware) (wrapped http.Handler) {
|
||||
wrapped = h
|
||||
|
||||
@@ -26,11 +23,11 @@ func withMiddlewares(h http.Handler, middlewares ...middleware) (wrapped http.Ha
|
||||
|
||||
const (
|
||||
// defaultReqBodySzLim is the default maximum request body size.
|
||||
defaultReqBodySzLim datasize.ByteSize = 64 * datasize.KB
|
||||
defaultReqBodySzLim = 64 * 1024
|
||||
|
||||
// largerReqBodySzLim is the maximum request body size for APIs expecting
|
||||
// larger requests.
|
||||
largerReqBodySzLim datasize.ByteSize = 4 * datasize.MB
|
||||
largerReqBodySzLim = 4 * 1024 * 1024
|
||||
)
|
||||
|
||||
// expectsLargerRequests shows if this request should use a larger body size
|
||||
@@ -41,28 +38,26 @@ const (
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/2666 and
|
||||
// https://github.com/AdguardTeam/AdGuardHome/issues/2675.
|
||||
func expectsLargerRequests(r *http.Request) (ok bool) {
|
||||
if r.Method != http.MethodPost {
|
||||
m := r.Method
|
||||
if m != http.MethodPost {
|
||||
return false
|
||||
}
|
||||
|
||||
switch r.URL.Path {
|
||||
case "/control/access/set", "/control/filtering/set_rules":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
p := r.URL.Path
|
||||
return p == "/control/access/set" ||
|
||||
p == "/control/filtering/set_rules"
|
||||
}
|
||||
|
||||
// limitRequestBody wraps underlying handler h, making it's request's body Read
|
||||
// method limited.
|
||||
func limitRequestBody(h http.Handler) (limited http.Handler) {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
szLim := defaultReqBodySzLim
|
||||
var szLim uint64 = defaultReqBodySzLim
|
||||
if expectsLargerRequests(r) {
|
||||
szLim = largerReqBodySzLim
|
||||
}
|
||||
|
||||
reader := ioutil.LimitReader(r.Body, szLim.Bytes())
|
||||
reader := ioutil.LimitReader(r.Body, szLim)
|
||||
|
||||
// HTTP handlers aren't supposed to call r.Body.Close(), so just
|
||||
// replace the body in a clone.
|
||||
|
||||
@@ -14,29 +14,29 @@ import (
|
||||
|
||||
func TestLimitRequestBody(t *testing.T) {
|
||||
errReqLimitReached := &ioutil.LimitError{
|
||||
Limit: defaultReqBodySzLim.Bytes(),
|
||||
Limit: defaultReqBodySzLim,
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
wantErr error
|
||||
name string
|
||||
body string
|
||||
want []byte
|
||||
wantErr error
|
||||
}{{
|
||||
wantErr: nil,
|
||||
name: "not_so_big",
|
||||
body: "somestr",
|
||||
want: []byte("somestr"),
|
||||
wantErr: nil,
|
||||
}, {
|
||||
wantErr: errReqLimitReached,
|
||||
name: "so_big",
|
||||
body: string(make([]byte, defaultReqBodySzLim+1)),
|
||||
want: make([]byte, defaultReqBodySzLim),
|
||||
wantErr: errReqLimitReached,
|
||||
}, {
|
||||
wantErr: nil,
|
||||
name: "empty",
|
||||
body: "",
|
||||
want: []byte(nil),
|
||||
wantErr: nil,
|
||||
}}
|
||||
|
||||
makeHandler := func(t *testing.T, err *error) http.HandlerFunc {
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
@@ -17,7 +16,7 @@ import (
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/netutil/httputil"
|
||||
"github.com/AdguardTeam/golibs/pprofutil"
|
||||
"github.com/NYTimes/gziphandler"
|
||||
"github.com/quic-go/quic-go/http3"
|
||||
"golang.org/x/net/http2"
|
||||
@@ -91,22 +90,17 @@ type webAPI struct {
|
||||
// TODO(a.garipov): Refactor all these servers.
|
||||
httpServer *http.Server
|
||||
|
||||
// logger is a slog logger used in webAPI. It must not be nil.
|
||||
logger *slog.Logger
|
||||
|
||||
// httpsServer is the server that handles HTTPS traffic. If it is not nil,
|
||||
// [Web.http3Server] must also not be nil.
|
||||
httpsServer httpsServer
|
||||
}
|
||||
|
||||
// newWebAPI creates a new instance of the web UI and API server. l must not be
|
||||
// nil.
|
||||
func newWebAPI(conf *webConfig, l *slog.Logger) (w *webAPI) {
|
||||
// newWebAPI creates a new instance of the web UI and API server.
|
||||
func newWebAPI(conf *webConfig) (w *webAPI) {
|
||||
log.Info("web: initializing")
|
||||
|
||||
w = &webAPI{
|
||||
conf: conf,
|
||||
logger: l,
|
||||
conf: conf,
|
||||
}
|
||||
|
||||
clientFS := http.FileServer(http.FS(conf.clientFS))
|
||||
@@ -333,7 +327,7 @@ func startPprof(port uint16) {
|
||||
runtime.SetMutexProfileFraction(1)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
httputil.RoutePprof(mux)
|
||||
pprofutil.RoutePprof(mux)
|
||||
|
||||
go func() {
|
||||
defer log.OnPanic("pprof server")
|
||||
|
||||
@@ -2,8 +2,6 @@
|
||||
package ipset
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"net"
|
||||
)
|
||||
|
||||
@@ -12,33 +10,24 @@ import (
|
||||
// TODO(a.garipov): Perhaps generalize this into some kind of a NetFilter type,
|
||||
// since ipset is exclusive to Linux?
|
||||
type Manager interface {
|
||||
Add(ctx context.Context, host string, ip4s, ip6s []net.IP) (n int, err error)
|
||||
Add(host string, ip4s, ip6s []net.IP) (n int, err error)
|
||||
Close() (err error)
|
||||
}
|
||||
|
||||
// Config is the configuration structure for the ipset manager.
|
||||
type Config struct {
|
||||
// Logger is used for logging the operation of the ipset manager. It must
|
||||
// not be nil.
|
||||
Logger *slog.Logger
|
||||
|
||||
// Lines is the ipset configuration with the following syntax:
|
||||
//
|
||||
// DOMAIN[,DOMAIN].../IPSET_NAME[,IPSET_NAME]...
|
||||
//
|
||||
// Lines must not contain any blank lines or comments.
|
||||
Lines []string
|
||||
}
|
||||
|
||||
// NewManager returns a new ipset manager. IPv4 addresses are added to an ipset
|
||||
// with an ipv4 family; IPv6 addresses, to an ipv6 ipset. ipset must exist.
|
||||
// NewManager returns a new ipset manager. IPv4 addresses are added to an
|
||||
// ipset with an ipv4 family; IPv6 addresses, to an ipv6 ipset. ipset must
|
||||
// exist.
|
||||
//
|
||||
// If conf.Lines is empty, mgr and err are nil. The error's chain contains
|
||||
// The syntax of the ipsetConf is:
|
||||
//
|
||||
// DOMAIN[,DOMAIN].../IPSET_NAME[,IPSET_NAME]...
|
||||
//
|
||||
// If ipsetConf is empty, msg and err are nil. The error's chain contains
|
||||
// [errors.ErrUnsupported] if current OS is not supported.
|
||||
func NewManager(ctx context.Context, conf *Config) (mgr Manager, err error) {
|
||||
if len(conf.Lines) == 0 {
|
||||
func NewManager(ipsetConf []string) (mgr Manager, err error) {
|
||||
if len(ipsetConf) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return newManager(ctx, conf)
|
||||
return newManager(ipsetConf)
|
||||
}
|
||||
|
||||
@@ -4,16 +4,14 @@ package ipset
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/AdguardTeam/golibs/container"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/digineo/go-ipset/v2"
|
||||
"github.com/mdlayher/netlink"
|
||||
"github.com/ti-mo/netfilter"
|
||||
@@ -36,8 +34,8 @@ import (
|
||||
// resolved IP addresses.
|
||||
|
||||
// newManager returns a new Linux ipset manager.
|
||||
func newManager(ctx context.Context, conf *Config) (set Manager, err error) {
|
||||
return newManagerWithDialer(ctx, conf, defaultDial)
|
||||
func newManager(ipsetConf []string) (set Manager, err error) {
|
||||
return newManagerWithDialer(ipsetConf, defaultDial)
|
||||
}
|
||||
|
||||
// defaultDial is the default netfilter dialing function.
|
||||
@@ -182,8 +180,6 @@ type manager struct {
|
||||
nameToIpset map[string]props
|
||||
domainToIpsets map[string][]props
|
||||
|
||||
logger *slog.Logger
|
||||
|
||||
dial dialer
|
||||
|
||||
// mu protects all properties below.
|
||||
@@ -258,7 +254,7 @@ func parseIpsetConfigLine(confStr string) (hosts, ipsetNames []string, err error
|
||||
|
||||
// parseIpsetConfig parses the ipset configuration and stores ipsets. It
|
||||
// returns an error if the configuration can't be used.
|
||||
func (m *manager) parseIpsetConfig(ctx context.Context, ipsetConf []string) (err error) {
|
||||
func (m *manager) parseIpsetConfig(ipsetConf []string) (err error) {
|
||||
// The family doesn't seem to matter when we use a header query, so query
|
||||
// only the IPv4 one.
|
||||
//
|
||||
@@ -282,7 +278,7 @@ func (m *manager) parseIpsetConfig(ctx context.Context, ipsetConf []string) (err
|
||||
}
|
||||
|
||||
var ipsets []props
|
||||
ipsets, err = m.ipsets(ctx, ipsetNames, currentlyKnown)
|
||||
ipsets, err = m.ipsets(ipsetNames, currentlyKnown)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting ipsets from config line at idx %d: %w", i, err)
|
||||
}
|
||||
@@ -332,11 +328,7 @@ func (m *manager) ipsetProps(name string) (p props, err error) {
|
||||
|
||||
// ipsets returns ipset properties of currently known ipsets. It also makes an
|
||||
// additional ipset header data query if needed.
|
||||
func (m *manager) ipsets(
|
||||
ctx context.Context,
|
||||
names []string,
|
||||
currentlyKnown map[string]props,
|
||||
) (sets []props, err error) {
|
||||
func (m *manager) ipsets(names []string, currentlyKnown map[string]props) (sets []props, err error) {
|
||||
for _, n := range names {
|
||||
p, ok := currentlyKnown[n]
|
||||
if !ok {
|
||||
@@ -344,12 +336,10 @@ func (m *manager) ipsets(
|
||||
}
|
||||
|
||||
if p.family != netfilter.ProtoIPv4 && p.family != netfilter.ProtoIPv6 {
|
||||
m.logger.DebugContext(
|
||||
ctx,
|
||||
"got unexpected ipset family while getting set properties",
|
||||
"set_name", p.name,
|
||||
"set_type", p.typeName,
|
||||
"set_family", p.family,
|
||||
log.Debug("ipset: getting properties: %q %q unexpected ipset family %q",
|
||||
p.name,
|
||||
p.typeName,
|
||||
p.family,
|
||||
)
|
||||
|
||||
p, err = m.ipsetProps(n)
|
||||
@@ -367,7 +357,7 @@ func (m *manager) ipsets(
|
||||
|
||||
// newManagerWithDialer returns a new Linux ipset manager using the provided
|
||||
// dialer.
|
||||
func newManagerWithDialer(ctx context.Context, conf *Config, dial dialer) (mgr Manager, err error) {
|
||||
func newManagerWithDialer(ipsetConf []string, dial dialer) (mgr Manager, err error) {
|
||||
defer func() { err = errors.Annotate(err, "ipset: %w") }()
|
||||
|
||||
m := &manager{
|
||||
@@ -376,8 +366,6 @@ func newManagerWithDialer(ctx context.Context, conf *Config, dial dialer) (mgr M
|
||||
nameToIpset: make(map[string]props),
|
||||
domainToIpsets: make(map[string][]props),
|
||||
|
||||
logger: conf.Logger,
|
||||
|
||||
dial: dial,
|
||||
|
||||
addedIPs: container.NewMapSet[ipInIpsetEntry](),
|
||||
@@ -388,7 +376,7 @@ func newManagerWithDialer(ctx context.Context, conf *Config, dial dialer) (mgr M
|
||||
if errors.Is(err, unix.EPROTONOSUPPORT) {
|
||||
// The implementation doesn't support this protocol version. Just
|
||||
// issue a warning.
|
||||
m.logger.WarnContext(ctx, "dialing netfilter", slogutil.KeyError, err)
|
||||
log.Info("ipset: dialing netfilter: warning: %s", err)
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
@@ -396,12 +384,12 @@ func newManagerWithDialer(ctx context.Context, conf *Config, dial dialer) (mgr M
|
||||
return nil, fmt.Errorf("dialing netfilter: %w", err)
|
||||
}
|
||||
|
||||
err = m.parseIpsetConfig(ctx, conf.Lines)
|
||||
err = m.parseIpsetConfig(ipsetConf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting ipsets: %w", err)
|
||||
}
|
||||
|
||||
m.logger.DebugContext(ctx, "initialized")
|
||||
log.Debug("ipset: initialized")
|
||||
|
||||
return m, nil
|
||||
}
|
||||
@@ -488,7 +476,6 @@ func (m *manager) addIPs(host string, set props, ips []net.IP) (n int, err error
|
||||
|
||||
// addToSets adds the IP addresses to the corresponding ipset.
|
||||
func (m *manager) addToSets(
|
||||
ctx context.Context,
|
||||
host string,
|
||||
ip4s []net.IP,
|
||||
ip6s []net.IP,
|
||||
@@ -511,13 +498,7 @@ func (m *manager) addToSets(
|
||||
return n, fmt.Errorf("%q %q unexpected family %q", set.name, set.typeName, set.family)
|
||||
}
|
||||
|
||||
m.logger.DebugContext(
|
||||
ctx,
|
||||
"added ips to set",
|
||||
"ips_num", nn,
|
||||
"set_name", set.name,
|
||||
"set_type", set.typeName,
|
||||
)
|
||||
log.Debug("ipset: added %d ips to set %q %q", nn, set.name, set.typeName)
|
||||
|
||||
n += nn
|
||||
}
|
||||
@@ -526,7 +507,7 @@ func (m *manager) addToSets(
|
||||
}
|
||||
|
||||
// Add implements the [Manager] interface for *manager.
|
||||
func (m *manager) Add(ctx context.Context, host string, ip4s, ip6s []net.IP) (n int, err error) {
|
||||
func (m *manager) Add(host string, ip4s, ip6s []net.IP) (n int, err error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
@@ -535,9 +516,9 @@ func (m *manager) Add(ctx context.Context, host string, ip4s, ip6s []net.IP) (n
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
m.logger.DebugContext(ctx, "found sets", "set_num", len(sets))
|
||||
log.Debug("ipset: found %d sets", len(sets))
|
||||
|
||||
return m.addToSets(ctx, host, ip4s, ip6s, sets)
|
||||
return m.addToSets(host, ip4s, ip6s, sets)
|
||||
}
|
||||
|
||||
// Close implements the [Manager] interface for *manager.
|
||||
|
||||
@@ -6,11 +6,8 @@ import (
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/digineo/go-ipset/v2"
|
||||
"github.com/mdlayher/netlink"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -18,9 +15,6 @@ import (
|
||||
"github.com/ti-mo/netfilter"
|
||||
)
|
||||
|
||||
// testTimeout is a common timeout for tests and contexts.
|
||||
const testTimeout = 1 * time.Second
|
||||
|
||||
// fakeConn is a fake ipsetConn for tests.
|
||||
type fakeConn struct {
|
||||
ipv4Header *ipset.HeaderPolicy
|
||||
@@ -64,7 +58,7 @@ func (c *fakeConn) listAll() (sets []props, err error) {
|
||||
}
|
||||
|
||||
func TestManager_Add(t *testing.T) {
|
||||
ipsetList := []string{
|
||||
ipsetConf := []string{
|
||||
"example.com,example.net/ipv4set",
|
||||
"example.org,example.biz/ipv6set",
|
||||
}
|
||||
@@ -95,11 +89,7 @@ func TestManager_Add(t *testing.T) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
conf := &Config{
|
||||
Logger: slogutil.NewDiscardLogger(),
|
||||
Lines: ipsetList,
|
||||
}
|
||||
m, err := newManagerWithDialer(testutil.ContextWithTimeout(t, testTimeout), conf, fakeDial)
|
||||
m, err := newManagerWithDialer(ipsetConf, fakeDial)
|
||||
require.NoError(t, err)
|
||||
|
||||
ip4 := net.IP{1, 2, 3, 4}
|
||||
@@ -110,7 +100,7 @@ func TestManager_Add(t *testing.T) {
|
||||
0x00, 0x00, 0x56, 0x78,
|
||||
}
|
||||
|
||||
n, err := m.Add(testutil.ContextWithTimeout(t, testTimeout), "example.net", []net.IP{ip4}, nil)
|
||||
n, err := m.Add("example.net", []net.IP{ip4}, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 1, n)
|
||||
@@ -120,7 +110,7 @@ func TestManager_Add(t *testing.T) {
|
||||
gotIP4 := ipv4Entries[0].IP.Value
|
||||
assert.Equal(t, ip4, gotIP4)
|
||||
|
||||
n, err = m.Add(testutil.ContextWithTimeout(t, testTimeout), "example.biz", nil, []net.IP{ip6})
|
||||
n, err = m.Add("example.biz", nil, []net.IP{ip6})
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 1, n)
|
||||
|
||||
@@ -3,11 +3,9 @@
|
||||
package ipset
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
)
|
||||
|
||||
func newManager(_ context.Context, _ *Config) (mgr Manager, err error) {
|
||||
func newManager(_ []string) (mgr Manager, err error) {
|
||||
return nil, aghos.Unsupported("ipset")
|
||||
}
|
||||
|
||||
@@ -24,7 +24,7 @@ import (
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/mathutil"
|
||||
"github.com/AdguardTeam/golibs/netutil/httputil"
|
||||
"github.com/AdguardTeam/golibs/pprofutil"
|
||||
httptreemux "github.com/dimfeld/httptreemux/v5"
|
||||
)
|
||||
|
||||
@@ -107,7 +107,7 @@ func (svc *Service) setupPprof(c *PprofConfig) {
|
||||
runtime.SetMutexProfileFraction(1)
|
||||
|
||||
pprofMux := http.NewServeMux()
|
||||
httputil.RoutePprof(pprofMux)
|
||||
pprofutil.RoutePprof(pprofMux)
|
||||
|
||||
svc.pprofPort = c.Port
|
||||
addr := netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), c.Port)
|
||||
|
||||
@@ -7,9 +7,9 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/golibs/container"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
@@ -32,7 +32,7 @@ type queryLog struct {
|
||||
|
||||
// buffer contains recent log entries. The entries in this buffer must not
|
||||
// be modified.
|
||||
buffer *container.RingBuffer[*logEntry]
|
||||
buffer *aghalg.RingBuffer[*logEntry]
|
||||
|
||||
// logFile is the path to the log file.
|
||||
logFile string
|
||||
@@ -225,7 +225,7 @@ func (l *queryLog) Add(params *AddParams) {
|
||||
l.bufferLock.Lock()
|
||||
defer l.bufferLock.Unlock()
|
||||
|
||||
l.buffer.Push(entry)
|
||||
l.buffer.Append(entry)
|
||||
|
||||
if !l.flushPending && fileIsEnabled && l.buffer.Len() >= memSize {
|
||||
l.flushPending = true
|
||||
|
||||
@@ -7,10 +7,10 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/golibs/container"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
@@ -153,7 +153,7 @@ func newQueryLog(conf Config) (l *queryLog, err error) {
|
||||
l = &queryLog{
|
||||
findClient: findClient,
|
||||
|
||||
buffer: container.NewRingBuffer[*logEntry](memSize),
|
||||
buffer: aghalg.NewRingBuffer[*logEntry](memSize),
|
||||
|
||||
conf: &Config{},
|
||||
confMu: &sync.RWMutex{},
|
||||
|
||||
@@ -2,13 +2,11 @@
|
||||
package rdns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/bluele/gcache"
|
||||
)
|
||||
|
||||
@@ -16,7 +14,7 @@ import (
|
||||
type Interface interface {
|
||||
// Process makes rDNS request and returns domain name. changed indicates
|
||||
// that domain name was updated since last request.
|
||||
Process(ctx context.Context, ip netip.Addr) (host string, changed bool)
|
||||
Process(ip netip.Addr) (host string, changed bool)
|
||||
}
|
||||
|
||||
// Empty is an empty [Interface] implementation which does nothing.
|
||||
@@ -26,7 +24,7 @@ type Empty struct{}
|
||||
var _ Interface = (*Empty)(nil)
|
||||
|
||||
// Process implements the [Interface] interface for Empty.
|
||||
func (Empty) Process(_ context.Context, _ netip.Addr) (host string, changed bool) {
|
||||
func (Empty) Process(_ netip.Addr) (host string, changed bool) {
|
||||
return "", false
|
||||
}
|
||||
|
||||
@@ -39,10 +37,6 @@ type Exchanger interface {
|
||||
|
||||
// Config is the configuration structure for Default.
|
||||
type Config struct {
|
||||
// Logger is used for logging the operation of the reverse DNS lookup
|
||||
// queries. It must not be nil.
|
||||
Logger *slog.Logger
|
||||
|
||||
// Exchanger resolves IP addresses to domain names.
|
||||
Exchanger Exchanger
|
||||
|
||||
@@ -56,10 +50,6 @@ type Config struct {
|
||||
|
||||
// Default is the default rDNS query processor.
|
||||
type Default struct {
|
||||
// logger is used for logging the operation of the reverse DNS lookup
|
||||
// queries. It must not be nil.
|
||||
logger *slog.Logger
|
||||
|
||||
// cache is the cache containing IP addresses of clients. An active IP
|
||||
// address is resolved once again after it expires. If IP address couldn't
|
||||
// be resolved, it stays here for some time to prevent further attempts to
|
||||
@@ -76,7 +66,6 @@ type Default struct {
|
||||
// New returns a new default rDNS query processor. conf must not be nil.
|
||||
func New(conf *Config) (r *Default) {
|
||||
return &Default{
|
||||
logger: conf.Logger,
|
||||
cache: gcache.New(conf.CacheSize).LRU().Build(),
|
||||
exchanger: conf.Exchanger,
|
||||
cacheTTL: conf.CacheTTL,
|
||||
@@ -87,15 +76,15 @@ func New(conf *Config) (r *Default) {
|
||||
var _ Interface = (*Default)(nil)
|
||||
|
||||
// Process implements the [Interface] interface for Default.
|
||||
func (r *Default) Process(ctx context.Context, ip netip.Addr) (host string, changed bool) {
|
||||
fromCache, expired := r.findInCache(ctx, ip)
|
||||
func (r *Default) Process(ip netip.Addr) (host string, changed bool) {
|
||||
fromCache, expired := r.findInCache(ip)
|
||||
if !expired {
|
||||
return fromCache, false
|
||||
}
|
||||
|
||||
host, ttl, err := r.exchanger.Exchange(ip)
|
||||
if err != nil {
|
||||
r.logger.DebugContext(ctx, "resolving", "ip", ip, slogutil.KeyError, err)
|
||||
log.Debug("rdns: resolving %q: %s", ip, err)
|
||||
}
|
||||
|
||||
ttl = max(ttl, r.cacheTTL)
|
||||
@@ -107,7 +96,7 @@ func (r *Default) Process(ctx context.Context, ip netip.Addr) (host string, chan
|
||||
|
||||
err = r.cache.Set(ip, item)
|
||||
if err != nil {
|
||||
r.logger.DebugContext(ctx, "adding item to cache", "key", ip, slogutil.KeyError, err)
|
||||
log.Debug("rdns: cache: adding item %q: %s", ip, err)
|
||||
}
|
||||
|
||||
// TODO(e.burkov): The name doesn't change if it's neither stored in cache
|
||||
@@ -117,22 +106,22 @@ func (r *Default) Process(ctx context.Context, ip netip.Addr) (host string, chan
|
||||
|
||||
// findInCache finds domain name in the cache. expired is true if host is not
|
||||
// valid anymore.
|
||||
func (r *Default) findInCache(ctx context.Context, ip netip.Addr) (host string, expired bool) {
|
||||
func (r *Default) findInCache(ip netip.Addr) (host string, expired bool) {
|
||||
val, err := r.cache.Get(ip)
|
||||
if err != nil {
|
||||
if !errors.Is(err, gcache.KeyNotFoundError) {
|
||||
r.logger.DebugContext(
|
||||
ctx,
|
||||
"retrieving item from cache",
|
||||
"key", ip,
|
||||
slogutil.KeyError, err,
|
||||
)
|
||||
log.Debug("rdns: cache: retrieving %q: %s", ip, err)
|
||||
}
|
||||
|
||||
return "", true
|
||||
}
|
||||
|
||||
item := val.(*cacheItem)
|
||||
item, ok := val.(*cacheItem)
|
||||
if !ok {
|
||||
log.Debug("rdns: cache: %q bad type %T", ip, val)
|
||||
|
||||
return "", true
|
||||
}
|
||||
|
||||
return item.host, time.Now().After(item.expiry)
|
||||
}
|
||||
|
||||
@@ -8,14 +8,10 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/rdns"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// testTimeout is a common timeout for tests and contexts.
|
||||
const testTimeout = 1 * time.Second
|
||||
|
||||
func TestDefault_Process(t *testing.T) {
|
||||
ip1 := netip.MustParseAddr("1.2.3.4")
|
||||
revAddr1, err := netutil.IPToReversedAddr(ip1.AsSlice())
|
||||
@@ -75,14 +71,14 @@ func TestDefault_Process(t *testing.T) {
|
||||
Exchanger: &aghtest.Exchanger{OnExchange: onExchange},
|
||||
})
|
||||
|
||||
got, changed := r.Process(testutil.ContextWithTimeout(t, testTimeout), tc.addr)
|
||||
got, changed := r.Process(tc.addr)
|
||||
require.True(t, changed)
|
||||
|
||||
assert.Equal(t, tc.want, got)
|
||||
assert.Equal(t, 1, hit)
|
||||
|
||||
// From cache.
|
||||
got, changed = r.Process(testutil.ContextWithTimeout(t, testTimeout), tc.addr)
|
||||
got, changed = r.Process(tc.addr)
|
||||
require.False(t, changed)
|
||||
|
||||
assert.Equal(t, tc.want, got)
|
||||
@@ -105,7 +101,7 @@ func TestDefault_Process(t *testing.T) {
|
||||
Exchanger: zeroTTLExchanger,
|
||||
})
|
||||
|
||||
got, changed := r.Process(testutil.ContextWithTimeout(t, testTimeout), ip1)
|
||||
got, changed := r.Process(ip1)
|
||||
require.True(t, changed)
|
||||
assert.Equal(t, revAddr1, got)
|
||||
|
||||
@@ -113,15 +109,14 @@ func TestDefault_Process(t *testing.T) {
|
||||
return revAddr2, time.Hour, nil
|
||||
}
|
||||
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
require.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
got, changed = r.Process(ctx, ip1)
|
||||
got, changed = r.Process(ip1)
|
||||
assert.True(t, changed)
|
||||
assert.Equal(t, revAddr2, got)
|
||||
}, 2*cacheTTL, time.Millisecond*100)
|
||||
|
||||
assert.Never(t, func() (changed bool) {
|
||||
_, changed = r.Process(testutil.ContextWithTimeout(t, testTimeout), ip1)
|
||||
_, changed = r.Process(ip1)
|
||||
|
||||
return changed
|
||||
}, 2*cacheTTL, time.Millisecond*100)
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
)
|
||||
|
||||
@@ -50,8 +51,6 @@ type StatsResp struct {
|
||||
func (s *StatsCtx) handleStats(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
var (
|
||||
resp *StatsResp
|
||||
ok bool
|
||||
@@ -63,17 +62,12 @@ func (s *StatsCtx) handleStats(w http.ResponseWriter, r *http.Request) {
|
||||
resp, ok = s.getData(uint32(s.limit.Hours()))
|
||||
}()
|
||||
|
||||
s.logger.DebugContext(
|
||||
ctx,
|
||||
"prepared data",
|
||||
"elapsed", timeutil.Duration{Duration: time.Since(start)},
|
||||
)
|
||||
log.Debug("stats: prepared data in %v", time.Since(start))
|
||||
|
||||
if !ok {
|
||||
// Don't bring the message to the lower case since it's a part of UI
|
||||
// text for the moment.
|
||||
const msg = "Couldn't get statistics data"
|
||||
aghhttp.ErrorAndLog(ctx, s.logger, r, w, http.StatusInternalServerError, msg)
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't get statistics data")
|
||||
|
||||
return
|
||||
}
|
||||
@@ -152,18 +146,16 @@ func (s *StatsCtx) handleGetStatsConfig(w http.ResponseWriter, r *http.Request)
|
||||
//
|
||||
// Deprecated: Remove it when migration to the new API is over.
|
||||
func (s *StatsCtx) handleStatsConfig(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
reqData := configResp{}
|
||||
err := json.NewDecoder(r.Body).Decode(&reqData)
|
||||
if err != nil {
|
||||
aghhttp.ErrorAndLog(ctx, s.logger, r, w, http.StatusBadRequest, "json decode: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "json decode: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if !checkInterval(reqData.IntervalDays) {
|
||||
aghhttp.ErrorAndLog(ctx, s.logger, r, w, http.StatusBadRequest, "Unsupported interval")
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "Unsupported interval")
|
||||
|
||||
return
|
||||
}
|
||||
@@ -181,19 +173,17 @@ func (s *StatsCtx) handleStatsConfig(w http.ResponseWriter, r *http.Request) {
|
||||
// handlePutStatsConfig is the handler for the PUT /control/stats/config/update
|
||||
// HTTP API.
|
||||
func (s *StatsCtx) handlePutStatsConfig(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
reqData := getConfigResp{}
|
||||
err := json.NewDecoder(r.Body).Decode(&reqData)
|
||||
if err != nil {
|
||||
aghhttp.ErrorAndLog(ctx, s.logger, r, w, http.StatusBadRequest, "json decode: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "json decode: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
engine, err := aghnet.NewIgnoreEngine(reqData.Ignored)
|
||||
if err != nil {
|
||||
aghhttp.ErrorAndLog(ctx, s.logger, r, w, http.StatusUnprocessableEntity, "ignored: %s", err)
|
||||
aghhttp.Error(r, w, http.StatusUnprocessableEntity, "ignored: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -201,21 +191,13 @@ func (s *StatsCtx) handlePutStatsConfig(w http.ResponseWriter, r *http.Request)
|
||||
ivl := time.Duration(reqData.Interval) * time.Millisecond
|
||||
err = validateIvl(ivl)
|
||||
if err != nil {
|
||||
aghhttp.ErrorAndLog(
|
||||
ctx,
|
||||
s.logger,
|
||||
r,
|
||||
w,
|
||||
http.StatusUnprocessableEntity,
|
||||
"unsupported interval: %s",
|
||||
err,
|
||||
)
|
||||
aghhttp.Error(r, w, http.StatusUnprocessableEntity, "unsupported interval: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if reqData.Enabled == aghalg.NBNull {
|
||||
aghhttp.ErrorAndLog(ctx, s.logger, r, w, http.StatusUnprocessableEntity, "enabled is null")
|
||||
aghhttp.Error(r, w, http.StatusUnprocessableEntity, "enabled is null")
|
||||
|
||||
return
|
||||
}
|
||||
@@ -234,15 +216,7 @@ func (s *StatsCtx) handlePutStatsConfig(w http.ResponseWriter, r *http.Request)
|
||||
func (s *StatsCtx) handleStatsReset(w http.ResponseWriter, r *http.Request) {
|
||||
err := s.clear()
|
||||
if err != nil {
|
||||
aghhttp.ErrorAndLog(
|
||||
r.Context(),
|
||||
s.logger,
|
||||
r,
|
||||
w,
|
||||
http.StatusInternalServerError,
|
||||
"stats: %s",
|
||||
err,
|
||||
)
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "stats: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -25,7 +24,6 @@ func TestHandleStatsConfig(t *testing.T) {
|
||||
)
|
||||
|
||||
conf := Config{
|
||||
Logger: slogutil.NewDiscardLogger(),
|
||||
UnitID: func() (id uint32) { return 0 },
|
||||
ConfigModified: func() {},
|
||||
ShouldCountClient: func([]string) bool { return true },
|
||||
|
||||
@@ -3,10 +3,8 @@
|
||||
package stats
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync"
|
||||
@@ -16,7 +14,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"go.etcd.io/bbolt"
|
||||
)
|
||||
@@ -45,10 +43,6 @@ func validateIvl(ivl time.Duration) (err error) {
|
||||
//
|
||||
// Do not alter any fields of this structure after using it.
|
||||
type Config struct {
|
||||
// Logger is used for logging the operation of the statistics management.
|
||||
// It must not be nil.
|
||||
Logger *slog.Logger
|
||||
|
||||
// UnitID is the function to generate the identifier for current unit. If
|
||||
// nil, the default function is used, see newUnitID.
|
||||
UnitID UnitIDGenFunc
|
||||
@@ -102,10 +96,6 @@ type Interface interface {
|
||||
// StatsCtx collects the statistics and flushes it to the database. Its default
|
||||
// flushing interval is one hour.
|
||||
type StatsCtx struct {
|
||||
// logger is used for logging the operation of the statistics management.
|
||||
// It must not be nil.
|
||||
logger *slog.Logger
|
||||
|
||||
// currMu protects curr.
|
||||
currMu *sync.RWMutex
|
||||
// curr is the actual statistics collection result.
|
||||
@@ -160,7 +150,6 @@ func New(conf Config) (s *StatsCtx, err error) {
|
||||
}
|
||||
|
||||
s = &StatsCtx{
|
||||
logger: conf.Logger,
|
||||
currMu: &sync.RWMutex{},
|
||||
httpRegister: conf.HTTPRegister,
|
||||
configModified: conf.ConfigModified,
|
||||
@@ -189,21 +178,21 @@ func New(conf Config) (s *StatsCtx, err error) {
|
||||
|
||||
tx, err := s.db.Load().Begin(true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("opening a transaction: %w", err)
|
||||
return nil, fmt.Errorf("stats: opening a transaction: %w", err)
|
||||
}
|
||||
|
||||
deleted := s.deleteOldUnits(tx, id-uint32(s.limit.Hours())-1)
|
||||
udb = s.loadUnitFromDB(tx, id)
|
||||
deleted := deleteOldUnits(tx, id-uint32(s.limit.Hours())-1)
|
||||
udb = loadUnitFromDB(tx, id)
|
||||
|
||||
err = finishTxn(tx, deleted > 0)
|
||||
if err != nil {
|
||||
s.logger.Error("finishing transacation", slogutil.KeyError, err)
|
||||
log.Error("stats: %s", err)
|
||||
}
|
||||
|
||||
s.curr = newUnit(id)
|
||||
s.curr.deserialize(udb)
|
||||
|
||||
s.logger.Debug("initialized")
|
||||
log.Debug("stats: initialized")
|
||||
|
||||
return s, nil
|
||||
}
|
||||
@@ -239,6 +228,8 @@ func (s *StatsCtx) Start() {
|
||||
|
||||
// Close implements the [io.Closer] interface for *StatsCtx.
|
||||
func (s *StatsCtx) Close() (err error) {
|
||||
defer func() { err = errors.Annotate(err, "stats: closing: %w") }()
|
||||
|
||||
db := s.db.Swap(nil)
|
||||
if db == nil {
|
||||
return nil
|
||||
@@ -246,7 +237,7 @@ func (s *StatsCtx) Close() (err error) {
|
||||
defer func() {
|
||||
cerr := db.Close()
|
||||
if cerr == nil {
|
||||
s.logger.Debug("database closed")
|
||||
log.Debug("stats: database closed")
|
||||
}
|
||||
|
||||
err = errors.WithDeferred(err, cerr)
|
||||
@@ -263,7 +254,7 @@ func (s *StatsCtx) Close() (err error) {
|
||||
|
||||
udb := s.curr.serialize()
|
||||
|
||||
return s.flushUnitToDB(udb, tx, s.curr.id)
|
||||
return udb.flushUnitToDB(tx, s.curr.id)
|
||||
}
|
||||
|
||||
// Update implements the [Interface] interface for *StatsCtx. e must not be
|
||||
@@ -278,7 +269,7 @@ func (s *StatsCtx) Update(e *Entry) {
|
||||
|
||||
err := e.validate()
|
||||
if err != nil {
|
||||
s.logger.Debug("validating entry", slogutil.KeyError, err)
|
||||
log.Debug("stats: updating: validating entry: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
@@ -287,7 +278,7 @@ func (s *StatsCtx) Update(e *Entry) {
|
||||
defer s.currMu.Unlock()
|
||||
|
||||
if s.curr == nil {
|
||||
s.logger.Error("current unit is nil")
|
||||
log.Error("stats: current unit is nil")
|
||||
|
||||
return
|
||||
}
|
||||
@@ -342,8 +333,8 @@ func (s *StatsCtx) TopClientsIP(maxCount uint) (ips []netip.Addr) {
|
||||
|
||||
// deleteOldUnits walks the buckets available to tx and deletes old units. It
|
||||
// returns the number of deletions performed.
|
||||
func (s *StatsCtx) deleteOldUnits(tx *bbolt.Tx, firstID uint32) (deleted int) {
|
||||
s.logger.Debug("deleting old units up to", "unit", firstID)
|
||||
func deleteOldUnits(tx *bbolt.Tx, firstID uint32) (deleted int) {
|
||||
log.Debug("stats: deleting old units until id %d", firstID)
|
||||
|
||||
// TODO(a.garipov): See if this is actually necessary. Looks like a rather
|
||||
// bizarre solution.
|
||||
@@ -357,12 +348,12 @@ func (s *StatsCtx) deleteOldUnits(tx *bbolt.Tx, firstID uint32) (deleted int) {
|
||||
|
||||
err = tx.DeleteBucket(name)
|
||||
if err != nil {
|
||||
s.logger.Debug("deleting bucket", slogutil.KeyError, err)
|
||||
log.Debug("stats: deleting bucket: %s", err)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
s.logger.Debug("deleted unit", "name_id", nameID, "name", fmt.Sprintf("%x", name))
|
||||
log.Debug("stats: deleted unit %d (name %x)", nameID, name)
|
||||
|
||||
deleted++
|
||||
|
||||
@@ -371,7 +362,7 @@ func (s *StatsCtx) deleteOldUnits(tx *bbolt.Tx, firstID uint32) (deleted int) {
|
||||
|
||||
err := tx.ForEach(walk)
|
||||
if err != nil && !errors.Is(err, errStop) {
|
||||
s.logger.Debug("deleting units", slogutil.KeyError, err)
|
||||
log.Debug("stats: deleting units: %s", err)
|
||||
}
|
||||
|
||||
return deleted
|
||||
@@ -380,29 +371,20 @@ func (s *StatsCtx) deleteOldUnits(tx *bbolt.Tx, firstID uint32) (deleted int) {
|
||||
// openDB returns an error if the database can't be opened from the specified
|
||||
// file. It's safe for concurrent use.
|
||||
func (s *StatsCtx) openDB() (err error) {
|
||||
s.logger.Debug("opening database")
|
||||
log.Debug("stats: opening database")
|
||||
|
||||
var db *bbolt.DB
|
||||
db, err = bbolt.Open(s.filename, 0o644, nil)
|
||||
if err != nil {
|
||||
if err.Error() == "invalid argument" {
|
||||
const lines = `AdGuard Home cannot be initialized due to an incompatible file system.
|
||||
Please read the explanation here: https://github.com/AdguardTeam/AdGuardHome/wiki/Getting-Started#limitations`
|
||||
|
||||
// TODO(s.chzhen): Use passed context.
|
||||
slogutil.PrintLines(
|
||||
context.TODO(),
|
||||
s.logger,
|
||||
slog.LevelError,
|
||||
"opening database",
|
||||
lines,
|
||||
)
|
||||
log.Error("AdGuard Home cannot be initialized due to an incompatible file system.\nPlease read the explanation here: https://github.com/AdguardTeam/AdGuardHome/wiki/Getting-Started#limitations")
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
defer s.logger.Debug("database opened")
|
||||
// Use defer to unlock the mutex as soon as possible.
|
||||
defer log.Debug("stats: database opened")
|
||||
|
||||
s.db.Store(db)
|
||||
|
||||
@@ -442,37 +424,34 @@ func (s *StatsCtx) flushDB(id, limit uint32, ptr *unit) (cont bool, sleepFor tim
|
||||
isCommitable := true
|
||||
tx, err := db.Begin(true)
|
||||
if err != nil {
|
||||
s.logger.Error("opening transaction", slogutil.KeyError, err)
|
||||
log.Error("stats: opening transaction: %s", err)
|
||||
|
||||
return true, 0
|
||||
}
|
||||
defer func() {
|
||||
if err = finishTxn(tx, isCommitable); err != nil {
|
||||
s.logger.Error("finishing transaction", slogutil.KeyError, err)
|
||||
log.Error("stats: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
s.curr = newUnit(id)
|
||||
|
||||
udb := ptr.serialize()
|
||||
flushErr := s.flushUnitToDB(udb, tx, ptr.id)
|
||||
flushErr := ptr.serialize().flushUnitToDB(tx, ptr.id)
|
||||
if flushErr != nil {
|
||||
s.logger.Error("flushing unit", slogutil.KeyError, flushErr)
|
||||
log.Error("stats: flushing unit: %s", flushErr)
|
||||
isCommitable = false
|
||||
}
|
||||
|
||||
delErr := tx.DeleteBucket(idToUnitName(id - limit))
|
||||
|
||||
if delErr != nil {
|
||||
// TODO(e.burkov): Improve the algorithm of deleting the oldest bucket
|
||||
// to avoid the error.
|
||||
lvl := slog.LevelWarn
|
||||
if !errors.Is(delErr, bbolt.ErrBucketNotFound) {
|
||||
if errors.Is(delErr, bbolt.ErrBucketNotFound) {
|
||||
log.Debug("stats: warning: deleting unit: %s", delErr)
|
||||
} else {
|
||||
isCommitable = false
|
||||
lvl = slog.LevelError
|
||||
log.Error("stats: deleting unit: %s", delErr)
|
||||
}
|
||||
|
||||
s.logger.Log(context.TODO(), lvl, "deleting bucket", slogutil.KeyError, delErr)
|
||||
}
|
||||
|
||||
return true, 0
|
||||
@@ -488,7 +467,7 @@ func (s *StatsCtx) periodicFlush() {
|
||||
cont, sleepFor = s.flush()
|
||||
}
|
||||
|
||||
s.logger.Debug("periodic flushing finished")
|
||||
log.Debug("periodic flushing finished")
|
||||
}
|
||||
|
||||
// setLimit sets the limit. s.lock is expected to be locked.
|
||||
@@ -498,16 +477,16 @@ func (s *StatsCtx) setLimit(limit time.Duration) {
|
||||
if limit != 0 {
|
||||
s.enabled = true
|
||||
s.limit = limit
|
||||
s.logger.Debug("setting limit in days", "num", limit/timeutil.Day)
|
||||
log.Debug("stats: set limit: %d days", limit/timeutil.Day)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
s.enabled = false
|
||||
s.logger.Debug("disabled")
|
||||
log.Debug("stats: disabled")
|
||||
|
||||
if err := s.clear(); err != nil {
|
||||
s.logger.Error("clearing", slogutil.KeyError, err)
|
||||
log.Error("stats: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -520,7 +499,7 @@ func (s *StatsCtx) clear() (err error) {
|
||||
var tx *bbolt.Tx
|
||||
tx, err = db.Begin(true)
|
||||
if err != nil {
|
||||
s.logger.Error("opening transaction", slogutil.KeyError, err)
|
||||
log.Error("stats: opening a transaction: %s", err)
|
||||
} else if err = finishTxn(tx, false); err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return err
|
||||
@@ -534,21 +513,21 @@ func (s *StatsCtx) clear() (err error) {
|
||||
}
|
||||
|
||||
// All active transactions are now closed.
|
||||
s.logger.Debug("database closed")
|
||||
log.Debug("stats: database closed")
|
||||
}
|
||||
|
||||
err = os.Remove(s.filename)
|
||||
if err != nil {
|
||||
s.logger.Error("removing", slogutil.KeyError, err)
|
||||
log.Error("stats: %s", err)
|
||||
}
|
||||
|
||||
err = s.openDB()
|
||||
if err != nil {
|
||||
s.logger.Error("opening database", slogutil.KeyError, err)
|
||||
log.Error("stats: opening database: %s", err)
|
||||
}
|
||||
|
||||
// Use defer to unlock the mutex as soon as possible.
|
||||
defer s.logger.Debug("cleared")
|
||||
defer log.Debug("stats: cleared")
|
||||
|
||||
s.currMu.Lock()
|
||||
defer s.currMu.Unlock()
|
||||
@@ -569,7 +548,7 @@ func (s *StatsCtx) loadUnits(limit uint32) (units []*unitDB, curID uint32) {
|
||||
// taken into account.
|
||||
tx, err := db.Begin(true)
|
||||
if err != nil {
|
||||
s.logger.Error("opening transaction", slogutil.KeyError, err)
|
||||
log.Error("stats: opening transaction: %s", err)
|
||||
|
||||
return nil, 0
|
||||
}
|
||||
@@ -589,7 +568,7 @@ func (s *StatsCtx) loadUnits(limit uint32) (units []*unitDB, curID uint32) {
|
||||
units = make([]*unitDB, 0, limit)
|
||||
firstID := curID - limit + 1
|
||||
for i := firstID; i != curID; i++ {
|
||||
u := s.loadUnitFromDB(tx, i)
|
||||
u := loadUnitFromDB(tx, i)
|
||||
if u == nil {
|
||||
u = &unitDB{NResult: make([]uint64, resultLast)}
|
||||
}
|
||||
@@ -598,7 +577,7 @@ func (s *StatsCtx) loadUnits(limit uint32) (units []*unitDB, curID uint32) {
|
||||
|
||||
err = finishTxn(tx, false)
|
||||
if err != nil {
|
||||
s.logger.Error("finishing transaction", slogutil.KeyError, err)
|
||||
log.Error("stats: %s", err)
|
||||
}
|
||||
|
||||
if cur != nil {
|
||||
@@ -606,8 +585,7 @@ func (s *StatsCtx) loadUnits(limit uint32) (units []*unitDB, curID uint32) {
|
||||
}
|
||||
|
||||
if unitsLen := len(units); unitsLen != int(limit) {
|
||||
// Should not happen.
|
||||
panic(fmt.Errorf("loaded %d units when the desired number is %d", unitsLen, limit))
|
||||
log.Fatalf("loaded %d units whilst the desired number is %d", unitsLen, limit)
|
||||
}
|
||||
|
||||
return units, curID
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -19,7 +18,6 @@ func TestStats_races(t *testing.T) {
|
||||
var r uint32
|
||||
idGen := func() (id uint32) { return atomic.LoadUint32(&r) }
|
||||
conf := Config{
|
||||
Logger: slogutil.NewDiscardLogger(),
|
||||
ShouldCountClient: func([]string) bool { return true },
|
||||
UnitID: idGen,
|
||||
Filename: filepath.Join(t.TempDir(), "./stats.db"),
|
||||
@@ -96,7 +94,6 @@ func TestStatsCtx_FillCollectedStats_daily(t *testing.T) {
|
||||
)
|
||||
|
||||
s, err := New(Config{
|
||||
Logger: slogutil.NewDiscardLogger(),
|
||||
ShouldCountClient: func([]string) bool { return true },
|
||||
Filename: filepath.Join(t.TempDir(), "./stats.db"),
|
||||
Limit: time.Hour,
|
||||
@@ -154,7 +151,6 @@ func TestStatsCtx_DataFromUnits_month(t *testing.T) {
|
||||
const hoursInMonth = 720
|
||||
|
||||
s, err := New(Config{
|
||||
Logger: slogutil.NewDiscardLogger(),
|
||||
ShouldCountClient: func([]string) bool { return true },
|
||||
Filename: filepath.Join(t.TempDir(), "./stats.db"),
|
||||
Limit: time.Hour,
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
@@ -22,6 +21,10 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
testutil.DiscardLogOutput(m)
|
||||
}
|
||||
|
||||
// constUnitID is the UnitIDGenFunc which always return 0.
|
||||
func constUnitID() (id uint32) { return 0 }
|
||||
|
||||
@@ -52,7 +55,6 @@ func TestStats(t *testing.T) {
|
||||
|
||||
handlers := map[string]http.Handler{}
|
||||
conf := stats.Config{
|
||||
Logger: slogutil.NewDiscardLogger(),
|
||||
ShouldCountClient: func([]string) bool { return true },
|
||||
Filename: filepath.Join(t.TempDir(), "stats.db"),
|
||||
Limit: timeutil.Day,
|
||||
@@ -169,7 +171,6 @@ func TestLargeNumbers(t *testing.T) {
|
||||
handlers := map[string]http.Handler{}
|
||||
|
||||
conf := stats.Config{
|
||||
Logger: slogutil.NewDiscardLogger(),
|
||||
ShouldCountClient: func([]string) bool { return true },
|
||||
Filename: filepath.Join(t.TempDir(), "stats.db"),
|
||||
Limit: timeutil.Day,
|
||||
@@ -221,7 +222,6 @@ func TestShouldCount(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
s, err := stats.New(stats.Config{
|
||||
Logger: slogutil.NewDiscardLogger(),
|
||||
Enabled: true,
|
||||
Filename: filepath.Join(t.TempDir(), "stats.db"),
|
||||
Limit: timeutil.Day,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user