Compare commits
217 Commits
6399-fix-u
...
release-v0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b21e19a223 | ||
|
|
c6aed4eb57 | ||
|
|
760d466b38 | ||
|
|
258eecc55b | ||
|
|
7b93f5d7cf | ||
|
|
3be7676970 | ||
|
|
48ee2f8a42 | ||
|
|
ec83d0eb86 | ||
|
|
19347d263a | ||
|
|
b22b16d98c | ||
|
|
cadb765b7d | ||
|
|
1116da8b83 | ||
|
|
c65700923a | ||
|
|
7030c7c24c | ||
|
|
09718a2170 | ||
|
|
77cda2c2c5 | ||
|
|
d9c57cdd9a | ||
|
|
0dad53b5f7 | ||
|
|
9a7315dbea | ||
|
|
a21558f418 | ||
|
|
4f928be393 | ||
|
|
f543b47261 | ||
|
|
66b831072c | ||
|
|
80eb339896 | ||
|
|
c69639c013 | ||
|
|
5f6fbe8e08 | ||
|
|
b40bbf0260 | ||
|
|
a11c8e91ab | ||
|
|
618d0e596c | ||
|
|
fde9ea5cb1 | ||
|
|
03d9803238 | ||
|
|
bd64b8b014 | ||
|
|
67fe064fcf | ||
|
|
471668d19a | ||
|
|
42762dfe54 | ||
|
|
c9314610d4 | ||
|
|
16755c37d8 | ||
|
|
73fcbd6ea2 | ||
|
|
30244f361f | ||
|
|
083991fb21 | ||
|
|
e3200d5046 | ||
|
|
21f6ed36fe | ||
|
|
77d04d44eb | ||
|
|
b34d119255 | ||
|
|
63bd71a10c | ||
|
|
faf2b32389 | ||
|
|
d23da1b757 | ||
|
|
beb8e36eee | ||
|
|
fe70161c01 | ||
|
|
39fa4b1f8e | ||
|
|
c7a8883201 | ||
|
|
3fd467413c | ||
|
|
9728dd856f | ||
|
|
ecadf78d60 | ||
|
|
eba4612d72 | ||
|
|
9200163f85 | ||
|
|
3c17853344 | ||
|
|
993a3fc42c | ||
|
|
7bb9b2416b | ||
|
|
2de321ce24 | ||
|
|
30b2b85ff1 | ||
|
|
6ea4788f56 | ||
|
|
3c52a021b9 | ||
|
|
0ceea9af5f | ||
|
|
39b404be19 | ||
|
|
56dc3eab02 | ||
|
|
554a38eeb1 | ||
|
|
c8d3afe869 | ||
|
|
44222c604c | ||
|
|
cbf221585e | ||
|
|
48322f6d0d | ||
|
|
d5a213c639 | ||
|
|
8166c4bc33 | ||
|
|
133cd9ef6b | ||
|
|
11146f73ed | ||
|
|
1beb18db47 | ||
|
|
f7bc2273a7 | ||
|
|
d1e735a003 | ||
|
|
af4ff5c748 | ||
|
|
fc951c1226 | ||
|
|
f81fd42472 | ||
|
|
1029ea5966 | ||
|
|
c0abdb4bc7 | ||
|
|
6681178ad3 | ||
|
|
e73605c4c5 | ||
|
|
c7017d49aa | ||
|
|
191d3bde49 | ||
|
|
18876a8e5c | ||
|
|
aa4a0d9880 | ||
|
|
d03d731d65 | ||
|
|
33b58a42fe | ||
|
|
2e9e708647 | ||
|
|
8ad22841ab | ||
|
|
32cf02264c | ||
|
|
0e8445b38f | ||
|
|
cb27ecd6c0 | ||
|
|
535220b3df | ||
|
|
7b9cfa94f8 | ||
|
|
b3f2e88e9c | ||
|
|
aa7a8d45e4 | ||
|
|
49cdef3d6a | ||
|
|
fecd146552 | ||
|
|
b01efd8c98 | ||
|
|
bd4dfb261c | ||
|
|
e754e4d2f6 | ||
|
|
b220e35c99 | ||
|
|
4f5131f423 | ||
|
|
dcb043df5f | ||
|
|
86e5756262 | ||
|
|
ba0cf5739b | ||
|
|
c4a13b92d2 | ||
|
|
723279121a | ||
|
|
3ad7649f7d | ||
|
|
2898a49d86 | ||
|
|
1547f9d35e | ||
|
|
adadd55c42 | ||
|
|
33b0225aa4 | ||
|
|
97d4058d80 | ||
|
|
86207e719d | ||
|
|
113f94ff46 | ||
|
|
5673deb391 | ||
|
|
3548a393ed | ||
|
|
254515f274 | ||
|
|
bccbecc6ea | ||
|
|
66f53803af | ||
|
|
faef005ce7 | ||
|
|
941cd2a562 | ||
|
|
6a4a9a0239 | ||
|
|
b9dbe6f1b6 | ||
|
|
7fec111ef8 | ||
|
|
5e1bd99718 | ||
|
|
9d75f72ceb | ||
|
|
d98d96db1a | ||
|
|
6a0ef2df15 | ||
|
|
75c2eb4c8a | ||
|
|
d021a67d66 | ||
|
|
4ed97cab12 | ||
|
|
a38742eed7 | ||
|
|
5efa95ed26 | ||
|
|
04db7db607 | ||
|
|
d17c6c6bb3 | ||
|
|
b2052f2ef1 | ||
|
|
cddcf852c2 | ||
|
|
1def426b45 | ||
|
|
b114fd5279 | ||
|
|
d27c3284f6 | ||
|
|
ba24a26b53 | ||
|
|
3e6678b6b4 | ||
|
|
83fd6f9782 | ||
|
|
52bc1b3f10 | ||
|
|
dd2153b7ac | ||
|
|
dd96a34861 | ||
|
|
daf26ee25a | ||
|
|
7e140eaaac | ||
|
|
d07a712988 | ||
|
|
95863288bf | ||
|
|
ea12be658b | ||
|
|
faa7c9aae5 | ||
|
|
e3653e8c25 | ||
|
|
b40cb24822 | ||
|
|
74004c1aa0 | ||
|
|
3e240741f1 | ||
|
|
6cfdbef1a5 | ||
|
|
d9bde6425b | ||
|
|
e2ae9e1591 | ||
|
|
5ebcbfa9ad | ||
|
|
e276bd7a31 | ||
|
|
659b2529bf | ||
|
|
97b3ed43ab | ||
|
|
767d6d3f28 | ||
|
|
31fc9bfc52 | ||
|
|
3f06b02409 | ||
|
|
5bf958ec6b | ||
|
|
959d9ff9a0 | ||
|
|
4813b4de25 | ||
|
|
119100924c | ||
|
|
bd584de4ee | ||
|
|
ede85ab2f2 | ||
|
|
12c20288e4 | ||
|
|
5bbbf89c10 | ||
|
|
d55393ecd5 | ||
|
|
2b5927306f | ||
|
|
4f016b6ed7 | ||
|
|
3a2a6d10ec | ||
|
|
2491426b09 | ||
|
|
5ebdd1390e | ||
|
|
b7f0247575 | ||
|
|
e28186a28a | ||
|
|
de1a7ce48f | ||
|
|
48480fb33b | ||
|
|
f41332fe6b | ||
|
|
1f8b340b8f | ||
|
|
fdaf1d09d3 | ||
|
|
b9682c4f10 | ||
|
|
69dcb4effd | ||
|
|
d50fd0ba91 | ||
|
|
c2c7b4c731 | ||
|
|
952d5f3a3d | ||
|
|
3f126c9ec9 | ||
|
|
0be58ef918 | ||
|
|
8f9053e2fc | ||
|
|
68452e5330 | ||
|
|
2eacc46eaa | ||
|
|
74dcc91ea7 | ||
|
|
dd7bf61323 | ||
|
|
2819d6cace | ||
|
|
75355a6883 | ||
|
|
e9c007d56b | ||
|
|
84c9085516 | ||
|
|
9f36e57c1e | ||
|
|
7528699fc2 | ||
|
|
d280151c18 | ||
|
|
b44c755d25 | ||
|
|
e4078e87a1 | ||
|
|
be36204756 | ||
|
|
b5409d6d00 | ||
|
|
f3d6bce03e |
2
.github/workflows/build.yml
vendored
2
.github/workflows/build.yml
vendored
@@ -1,7 +1,7 @@
|
||||
'name': 'build'
|
||||
|
||||
'env':
|
||||
'GO_VERSION': '1.20.11'
|
||||
'GO_VERSION': '1.20.10'
|
||||
'NODE_VERSION': '16'
|
||||
|
||||
'on':
|
||||
|
||||
2
.github/workflows/lint.yml
vendored
2
.github/workflows/lint.yml
vendored
@@ -1,7 +1,7 @@
|
||||
'name': 'lint'
|
||||
|
||||
'env':
|
||||
'GO_VERSION': '1.20.11'
|
||||
'GO_VERSION': '1.20.10'
|
||||
|
||||
'on':
|
||||
'push':
|
||||
|
||||
28
CHANGELOG.md
28
CHANGELOG.md
@@ -23,34 +23,6 @@ See also the [v0.107.41 GitHub milestone][ms-v0.107.41].
|
||||
NOTE: Add new changes BELOW THIS COMMENT.
|
||||
-->
|
||||
|
||||
### Added
|
||||
|
||||
- Ability to specify multiple domain specific upstreams per line, e.g.
|
||||
`[/domain1/../domain2/]upstream1 upstream2 .. upstreamN` ([#4977]).
|
||||
|
||||
### Changed
|
||||
|
||||
- The height of ready-to-use filter lists has been increased ([#6358]).
|
||||
- Improved authentication failure logging ([#6357]).
|
||||
|
||||
### Fixed
|
||||
|
||||
- Redundant shortening long client names in the Top Clients table ([#6338]).
|
||||
- Scrolling column headers in the tables ([#6337]).
|
||||
- `$important,dnsrewrite` rules do not take precedence over allowlist rules
|
||||
([#6204]).
|
||||
- Dark mode DNS rewrite background ([#6329]).
|
||||
- Issues with QUIC and HTTP/3 upstreams on Linux ([#6335]).
|
||||
|
||||
[#4977]: https://github.com/AdguardTeam/AdGuardHome/issues/4977
|
||||
[#6204]: https://github.com/AdguardTeam/AdGuardHome/issues/6204
|
||||
[#6329]: https://github.com/AdguardTeam/AdGuardHome/issues/6329
|
||||
[#6335]: https://github.com/AdguardTeam/AdGuardHome/issues/6335
|
||||
[#6337]: https://github.com/AdguardTeam/AdGuardHome/issues/6337
|
||||
[#6338]: https://github.com/AdguardTeam/AdGuardHome/issues/6338
|
||||
[#6357]: https://github.com/AdguardTeam/AdGuardHome/issues/6357
|
||||
[#6358]: https://github.com/AdguardTeam/AdGuardHome/issues/6358
|
||||
|
||||
<!--
|
||||
NOTE: Add new changes ABOVE THIS COMMENT.
|
||||
-->
|
||||
|
||||
@@ -201,7 +201,7 @@ opinion, this cannot be legitimately counted as a Pi-Hole's feature.
|
||||
| Cross-platform | ✅ | ❌ (not natively, only via Docker) |
|
||||
| Running as a DNS-over-HTTPS or DNS-over-TLS server | ✅ | ❌ (requires additional software) |
|
||||
| Blocking phishing and malware domains | ✅ | ❌ (requires non-default blocklists) |
|
||||
| Parental control (blocking adult domains) | ✅ | ❌ (requires non-default blocklists) |
|
||||
| Parental control (blocking adult domains) | ✅ | ❌ |
|
||||
| Force Safe search on search engines | ✅ | ❌ |
|
||||
| Per-client (device) configuration | ✅ | ✅ |
|
||||
| Access settings (choose who can use AGH DNS) | ✅ | ❌ |
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
# Make sure to sync any changes with the branch overrides below.
|
||||
'variables':
|
||||
'channel': 'edge'
|
||||
'dockerGo': 'adguard/golang-ubuntu:7.5'
|
||||
'dockerGo': 'adguard/golang-ubuntu:7.4'
|
||||
|
||||
'stages':
|
||||
- 'Build frontend':
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
# Make sure to sync any changes with the branch overrides below.
|
||||
'variables':
|
||||
'channel': 'edge'
|
||||
'dockerGo': 'adguard/golang-ubuntu:7.5'
|
||||
'dockerGo': 'adguard/golang-ubuntu:7.4'
|
||||
'snapcraftChannel': 'edge'
|
||||
|
||||
'stages':
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
'key': 'AHBRTSPECS'
|
||||
'name': 'AdGuard Home - Build and run tests'
|
||||
'variables':
|
||||
'dockerGo': 'adguard/golang-ubuntu:7.5'
|
||||
'dockerGo': 'adguard/golang-ubuntu:7.4'
|
||||
|
||||
'stages':
|
||||
- 'Tests':
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
{
|
||||
"client_settings": "Client settings",
|
||||
"example_upstream_reserved": "an upstream <0>for specific domains</0>;",
|
||||
"example_multiple_upstreams_reserved": "multiple upstreams <0>for specific domains</0>;",
|
||||
"example_upstream_comment": "a comment.",
|
||||
"upstream_parallel": "Use parallel queries to speed up resolving by querying all upstream servers simultaneously.",
|
||||
"parallel_requests": "Parallel requests",
|
||||
|
||||
@@ -118,11 +118,6 @@ body {
|
||||
overflow-y: auto;
|
||||
}
|
||||
|
||||
.modal-body--filters {
|
||||
max-height: 600px;
|
||||
overflow-y: auto;
|
||||
}
|
||||
|
||||
.modal-body__item:not(:first-child) {
|
||||
padding-top: 1.5rem;
|
||||
}
|
||||
|
||||
@@ -28,7 +28,7 @@ const renderIcons = (iconsData) => iconsData.map(({
|
||||
}) => <a key={iconName} href={href} target="_blank" rel="noopener noreferrer"
|
||||
className={classNames('d-flex align-items-center', className)}
|
||||
>
|
||||
<svg className="icon icon--15 mr-1 icon--gray">
|
||||
<svg className="nav-icon nav-icon--gray">
|
||||
<use xlinkHref={`#${iconName}`} />
|
||||
</svg>
|
||||
</a>);
|
||||
@@ -110,7 +110,7 @@ const Form = (props) => {
|
||||
const openAddFiltersModal = () => openModal(MODAL_TYPE.ADD_FILTERS);
|
||||
|
||||
return <form onSubmit={handleSubmit}>
|
||||
<div className="modal-body modal-body--filters">
|
||||
<div className="modal-body modal-body--medium">
|
||||
{modalType === MODAL_TYPE.SELECT_MODAL_TYPE
|
||||
&& <div className="d-flex justify-content-around">
|
||||
<button onClick={openFilteringListModal}
|
||||
|
||||
@@ -80,7 +80,7 @@
|
||||
color: var(--gray-f3);
|
||||
}
|
||||
|
||||
.logs__table .logs__text--client {
|
||||
.logs__text--client {
|
||||
padding-right: 32px;
|
||||
}
|
||||
|
||||
|
||||
@@ -137,22 +137,6 @@ const Examples = (props) => (
|
||||
example_upstream_reserved
|
||||
</Trans>
|
||||
</li>
|
||||
<li>
|
||||
<code>[/example.local/]94.140.14.140 2a10:50c0::1:ff</code>: <Trans
|
||||
components={[
|
||||
<a
|
||||
href="https://github.com/AdguardTeam/AdGuardHome/wiki/Configuration#upstreams-for-domains"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
key="0"
|
||||
>
|
||||
Link
|
||||
</a>,
|
||||
]}
|
||||
>
|
||||
example_multiple_upstreams_reserved
|
||||
</Trans>
|
||||
</li>
|
||||
<li>
|
||||
<code>{COMMENT_LINE_DEFAULT_TOKEN} comment</code>: <Trans>
|
||||
example_upstream_comment
|
||||
|
||||
@@ -149,7 +149,3 @@
|
||||
.card .logs__row--blue {
|
||||
background-color: #ecf7ff;
|
||||
}
|
||||
|
||||
[data-theme="dark"] .card .logs__row--blue {
|
||||
background-color: var(--logs__row--blue-bgcolor);
|
||||
}
|
||||
|
||||
@@ -24,13 +24,6 @@
|
||||
height: var(--size);
|
||||
}
|
||||
|
||||
.icon--15 {
|
||||
--size: 0.95rem;
|
||||
|
||||
width: var(--size);
|
||||
height: var(--size);
|
||||
}
|
||||
|
||||
.icon--gray {
|
||||
color: var(--gray-a5);
|
||||
}
|
||||
|
||||
@@ -9,6 +9,10 @@
|
||||
overflow: visible;
|
||||
}
|
||||
|
||||
.ReactTable .rt-tbody {
|
||||
overflow: visible;
|
||||
}
|
||||
|
||||
.ReactTable .rt-noData {
|
||||
color: var(--rt-nodata-color);
|
||||
background-color: var(--rt-nodata-bgcolor);
|
||||
|
||||
12
go.mod
12
go.mod
@@ -3,9 +3,9 @@ module github.com/AdguardTeam/AdGuardHome
|
||||
go 1.20
|
||||
|
||||
require (
|
||||
github.com/AdguardTeam/dnsproxy v0.56.3
|
||||
github.com/AdguardTeam/golibs v0.17.2
|
||||
github.com/AdguardTeam/urlfilter v0.17.3
|
||||
github.com/AdguardTeam/dnsproxy v0.56.2
|
||||
github.com/AdguardTeam/golibs v0.17.1
|
||||
github.com/AdguardTeam/urlfilter v0.17.0
|
||||
github.com/NYTimes/gziphandler v1.1.1
|
||||
github.com/ameshkov/dnscrypt/v2 v2.2.7
|
||||
github.com/bluele/gcache v0.0.2
|
||||
@@ -17,7 +17,7 @@ require (
|
||||
github.com/google/gopacket v1.1.19
|
||||
github.com/google/renameio/v2 v2.0.0
|
||||
github.com/google/uuid v1.3.1
|
||||
github.com/insomniacslk/dhcp v0.0.0-20231016090811-6a2c8fbdcc1c
|
||||
github.com/insomniacslk/dhcp v0.0.0-20230908212754-65c27093e38a
|
||||
github.com/josharian/native v1.1.1-0.20230202152459-5c7d0dd6ab86
|
||||
github.com/kardianos/service v1.2.2
|
||||
github.com/mdlayher/ethernet v0.0.0-20220221185849-529eae5b6118
|
||||
@@ -27,9 +27,9 @@ require (
|
||||
// own code for that. Perhaps, use gopacket.
|
||||
github.com/mdlayher/raw v0.1.0
|
||||
github.com/miekg/dns v1.1.56
|
||||
github.com/quic-go/quic-go v0.39.2
|
||||
github.com/quic-go/quic-go v0.39.1
|
||||
github.com/stretchr/testify v1.8.4
|
||||
github.com/ti-mo/netfilter v0.5.1
|
||||
github.com/ti-mo/netfilter v0.5.0
|
||||
go.etcd.io/bbolt v1.3.7
|
||||
golang.org/x/crypto v0.14.0
|
||||
golang.org/x/exp v0.0.0-20231006140011-7918f672742d
|
||||
|
||||
24
go.sum
24
go.sum
@@ -1,9 +1,9 @@
|
||||
github.com/AdguardTeam/dnsproxy v0.56.3 h1:WP1FooLfZQPHEH2SuwMtJsOurDt32rubGx0OddcsKT0=
|
||||
github.com/AdguardTeam/dnsproxy v0.56.3/go.mod h1:ZvkbM71HwpilgkCnTubDiR4Ba6x5Qvnhy2iasMWaTDM=
|
||||
github.com/AdguardTeam/golibs v0.17.2 h1:vg6wHMjUKscnyPGRvxS5kAt7Uw4YxcJiITZliZ476W8=
|
||||
github.com/AdguardTeam/golibs v0.17.2/go.mod h1:DKhCIXHcUYtBhU8ibTLKh1paUL96n5zhQBlx763sj+U=
|
||||
github.com/AdguardTeam/urlfilter v0.17.3 h1:fg/ObbnO0Cv6aw0tW6N/ETDMhhNvmcUUOZ7HlmKC3rw=
|
||||
github.com/AdguardTeam/urlfilter v0.17.3/go.mod h1:Jru7jFfeH2CoDf150uDs+rRYcZBzHHBz05r9REyDKyE=
|
||||
github.com/AdguardTeam/dnsproxy v0.56.2 h1:+k1iUmp05QIqkgXWyPn70fki4FouHe6vHIyHguelKao=
|
||||
github.com/AdguardTeam/dnsproxy v0.56.2/go.mod h1:ZvkbM71HwpilgkCnTubDiR4Ba6x5Qvnhy2iasMWaTDM=
|
||||
github.com/AdguardTeam/golibs v0.17.1 h1:j3Ehhld5GI/amcHYG+CF0sJ4OOzAQ06BY3N/iBYJZ1M=
|
||||
github.com/AdguardTeam/golibs v0.17.1/go.mod h1:DKhCIXHcUYtBhU8ibTLKh1paUL96n5zhQBlx763sj+U=
|
||||
github.com/AdguardTeam/urlfilter v0.17.0 h1:tUzhtR9wMx704GIP3cibsDQJrixlMHfwoQbYJfPdFow=
|
||||
github.com/AdguardTeam/urlfilter v0.17.0/go.mod h1:bbuZjPUzm/Ip+nz5qPPbwIP+9rZyQbQad8Lt/0fCulU=
|
||||
github.com/NYTimes/gziphandler v1.1.1 h1:ZUDjpQae29j0ryrS0u/B8HZfJBtBQHjqw2rQ2cqUQ3I=
|
||||
github.com/NYTimes/gziphandler v1.1.1/go.mod h1:n/CVRwUEOgIxrgPvAQhUUr9oeUtvrhMomdKFjzJNB0c=
|
||||
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY=
|
||||
@@ -49,8 +49,8 @@ github.com/google/uuid v1.2.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+
|
||||
github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4=
|
||||
github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/hugelgupf/socketpair v0.0.0-20190730060125-05d35a94e714 h1:/jC7qQFrv8CrSJVmaolDVOxTfS9kc36uB6H40kdbQq8=
|
||||
github.com/insomniacslk/dhcp v0.0.0-20231016090811-6a2c8fbdcc1c h1:PgxFEySCI41sH0mB7/2XswdXbUykQsRUGod8Rn+NubM=
|
||||
github.com/insomniacslk/dhcp v0.0.0-20231016090811-6a2c8fbdcc1c/go.mod h1:3A9PQ1cunSDF/1rbTq99Ts4pVnycWg+vlPkfeD2NLFI=
|
||||
github.com/insomniacslk/dhcp v0.0.0-20230908212754-65c27093e38a h1:S33o3djA1nPRd+d/bf7jbbXytXuK/EoXow7+aa76grQ=
|
||||
github.com/insomniacslk/dhcp v0.0.0-20230908212754-65c27093e38a/go.mod h1:zmdm3sTSDP3vOOX3CEWRkkRHtKr1DxBx+J1OQFoDQQs=
|
||||
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
|
||||
github.com/josharian/native v1.0.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
|
||||
github.com/josharian/native v1.0.1-0.20221213033349-c1e37c09b531/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
|
||||
@@ -94,8 +94,8 @@ github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo=
|
||||
github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A=
|
||||
github.com/quic-go/qtls-go1-20 v0.3.4 h1:MfFAPULvst4yoMgY9QmtpYmfij/em7O8UUi+bNVm7Cg=
|
||||
github.com/quic-go/qtls-go1-20 v0.3.4/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k=
|
||||
github.com/quic-go/quic-go v0.39.2 h1:hmwAf8zAHlvan0Y5PXxeeBFZEW17IW99sXLry8I2kjk=
|
||||
github.com/quic-go/quic-go v0.39.2/go.mod h1:T09QsDQWjLiQ74ZmacDfqZmhY/NLnw5BC40MANNNZ1Q=
|
||||
github.com/quic-go/quic-go v0.39.1 h1:d/m3oaN/SD2c+f7/yEjZxe2zEVotXprnrCCJ2y/ZZFE=
|
||||
github.com/quic-go/quic-go v0.39.1/go.mod h1:T09QsDQWjLiQ74ZmacDfqZmhY/NLnw5BC40MANNNZ1Q=
|
||||
github.com/shirou/gopsutil/v3 v3.23.7 h1:C+fHO8hfIppoJ1WdsVm1RoI0RwXoNdfTK7yWXV0wVj4=
|
||||
github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFtM=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
@@ -105,8 +105,8 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
|
||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/ti-mo/netfilter v0.2.0/go.mod h1:8GbBGsY/8fxtyIdfwy29JiluNcPK4K7wIT+x42ipqUU=
|
||||
github.com/ti-mo/netfilter v0.5.1 h1:cqamEd1c1zmpfpqvInLOro0Znq/RAfw2QL5wL2rAR/8=
|
||||
github.com/ti-mo/netfilter v0.5.1/go.mod h1:h9UPQ3ZrTZGBitay+LETMxZvNgWGK/efTUcqES2YiLw=
|
||||
github.com/ti-mo/netfilter v0.5.0 h1:MZmsUw5bFRecOb0AeyjOPxTHg4UxYzyEs0Ek/6Lxoy8=
|
||||
github.com/ti-mo/netfilter v0.5.0/go.mod h1:nt+8B9hx/QpqHr7Hazq+2qMCCA8u2OTkyc/7+U9ARz8=
|
||||
github.com/tklauser/go-sysconf v0.3.11 h1:89WgdJhk5SNwJfu+GKyYveZ4IaJ7xAkecBo+KdJV0CM=
|
||||
github.com/tklauser/numcpus v0.6.0 h1:kebhY2Qt+3U6RNK7UqpYNA+tJ23IBEGKkB7JQBfDYms=
|
||||
github.com/u-root/uio v0.0.0-20230305220412-3e8cd9d6bf63 h1:YcojQL98T/OO+rybuzn2+5KrD5dBwXIvYBvQ2cD3Avg=
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
package dhcpsvc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.org/x/exp/slices"
|
||||
"github.com/google/gopacket/layers"
|
||||
)
|
||||
|
||||
// Config is the configuration for the DHCP service.
|
||||
@@ -35,58 +33,54 @@ type InterfaceConfig struct {
|
||||
IPv6 *IPv6Config
|
||||
}
|
||||
|
||||
// Validate returns an error in conf if any.
|
||||
func (conf *Config) Validate() (err error) {
|
||||
switch {
|
||||
case conf == nil:
|
||||
return errNilConfig
|
||||
case !conf.Enabled:
|
||||
return nil
|
||||
case conf.ICMPTimeout < 0:
|
||||
return fmt.Errorf("icmp timeout %s must be non-negative", conf.ICMPTimeout)
|
||||
}
|
||||
// IPv4Config is the interface-specific configuration for DHCPv4.
|
||||
type IPv4Config struct {
|
||||
// GatewayIP is the IPv4 address of the network's gateway. It is used as
|
||||
// the default gateway for DHCP clients and also used in calculating the
|
||||
// network-specific broadcast address.
|
||||
GatewayIP netip.Addr
|
||||
|
||||
err = netutil.ValidateDomainName(conf.LocalDomainName)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
// SubnetMask is the IPv4 subnet mask of the network. It should be a valid
|
||||
// IPv4 subnet mask (i.e. all 1s followed by all 0s).
|
||||
SubnetMask netip.Addr
|
||||
|
||||
if len(conf.Interfaces) == 0 {
|
||||
return errNoInterfaces
|
||||
}
|
||||
// RangeStart is the first address in the range to assign to DHCP clients.
|
||||
RangeStart netip.Addr
|
||||
|
||||
ifaces := maps.Keys(conf.Interfaces)
|
||||
slices.Sort(ifaces)
|
||||
// RangeEnd is the last address in the range to assign to DHCP clients.
|
||||
RangeEnd netip.Addr
|
||||
|
||||
for _, iface := range ifaces {
|
||||
if err = conf.Interfaces[iface].validate(); err != nil {
|
||||
return fmt.Errorf("interface %q: %w", iface, err)
|
||||
}
|
||||
}
|
||||
// Options is the list of DHCP options to send to DHCP clients.
|
||||
Options layers.DHCPOptions
|
||||
|
||||
return nil
|
||||
// LeaseDuration is the TTL of a DHCP lease.
|
||||
LeaseDuration time.Duration
|
||||
|
||||
// Enabled is the state of the DHCPv4 service, whether it is enabled or not
|
||||
// on the specific interface.
|
||||
Enabled bool
|
||||
}
|
||||
|
||||
// mustBeErr returns an error that indicates that valName must be as must
|
||||
// describes.
|
||||
func mustBeErr(valName, must string, val fmt.Stringer) (err error) {
|
||||
return fmt.Errorf("%s %s must %s", valName, val, must)
|
||||
}
|
||||
|
||||
// validate returns an error in ic, if any.
|
||||
func (ic *InterfaceConfig) validate() (err error) {
|
||||
if ic == nil {
|
||||
return errNilConfig
|
||||
}
|
||||
|
||||
if err = ic.IPv4.validate(); err != nil {
|
||||
return fmt.Errorf("ipv4: %w", err)
|
||||
}
|
||||
|
||||
if err = ic.IPv6.validate(); err != nil {
|
||||
return fmt.Errorf("ipv6: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
// IPv6Config is the interface-specific configuration for DHCPv6.
|
||||
type IPv6Config struct {
|
||||
// RangeStart is the first address in the range to assign to DHCP clients.
|
||||
RangeStart netip.Addr
|
||||
|
||||
// Options is the list of DHCP options to send to DHCP clients.
|
||||
Options layers.DHCPOptions
|
||||
|
||||
// LeaseDuration is the TTL of a DHCP lease.
|
||||
LeaseDuration time.Duration
|
||||
|
||||
// RASlaacOnly defines whether the DHCP clients should only use SLAAC for
|
||||
// address assignment.
|
||||
RASLAACOnly bool
|
||||
|
||||
// RAAllowSlaac defines whether the DHCP clients may use SLAAC for address
|
||||
// assignment.
|
||||
RAAllowSLAAC bool
|
||||
|
||||
// Enabled is the state of the DHCPv6 service, whether it is enabled or not
|
||||
// on the specific interface.
|
||||
Enabled bool
|
||||
}
|
||||
|
||||
@@ -1,88 +0,0 @@
|
||||
package dhcpsvc_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
)
|
||||
|
||||
func TestConfig_Validate(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
conf *dhcpsvc.Config
|
||||
wantErrMsg string
|
||||
}{{
|
||||
name: "nil_config",
|
||||
conf: nil,
|
||||
wantErrMsg: "config is nil",
|
||||
}, {
|
||||
name: "disabled",
|
||||
conf: &dhcpsvc.Config{},
|
||||
wantErrMsg: "",
|
||||
}, {
|
||||
name: "empty",
|
||||
conf: &dhcpsvc.Config{
|
||||
Enabled: true,
|
||||
},
|
||||
wantErrMsg: `bad domain name "": domain name is empty`,
|
||||
}, {
|
||||
conf: &dhcpsvc.Config{
|
||||
Enabled: true,
|
||||
LocalDomainName: testLocalTLD,
|
||||
Interfaces: nil,
|
||||
},
|
||||
name: "no_interfaces",
|
||||
wantErrMsg: "no interfaces specified",
|
||||
}, {
|
||||
conf: &dhcpsvc.Config{
|
||||
Enabled: true,
|
||||
LocalDomainName: testLocalTLD,
|
||||
Interfaces: nil,
|
||||
},
|
||||
name: "no_interfaces",
|
||||
wantErrMsg: "no interfaces specified",
|
||||
}, {
|
||||
conf: &dhcpsvc.Config{
|
||||
Enabled: true,
|
||||
LocalDomainName: testLocalTLD,
|
||||
Interfaces: map[string]*dhcpsvc.InterfaceConfig{
|
||||
"eth0": nil,
|
||||
},
|
||||
},
|
||||
name: "nil_interface",
|
||||
wantErrMsg: `interface "eth0": config is nil`,
|
||||
}, {
|
||||
conf: &dhcpsvc.Config{
|
||||
Enabled: true,
|
||||
LocalDomainName: testLocalTLD,
|
||||
Interfaces: map[string]*dhcpsvc.InterfaceConfig{
|
||||
"eth0": {
|
||||
IPv4: nil,
|
||||
IPv6: &dhcpsvc.IPv6Config{Enabled: false},
|
||||
},
|
||||
},
|
||||
},
|
||||
name: "nil_ipv4",
|
||||
wantErrMsg: `interface "eth0": ipv4: config is nil`,
|
||||
}, {
|
||||
conf: &dhcpsvc.Config{
|
||||
Enabled: true,
|
||||
LocalDomainName: testLocalTLD,
|
||||
Interfaces: map[string]*dhcpsvc.InterfaceConfig{
|
||||
"eth0": {
|
||||
IPv4: &dhcpsvc.IPv4Config{Enabled: false},
|
||||
IPv6: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
name: "nil_ipv6",
|
||||
wantErrMsg: `interface "eth0": ipv6: config is nil`,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
testutil.AssertErrorMsg(t, tc.wantErrMsg, tc.conf.Validate())
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -56,17 +56,16 @@ type Interface interface {
|
||||
// hostname, either set or generated.
|
||||
IPByHost(host string) (ip netip.Addr)
|
||||
|
||||
// Leases returns all the active DHCP leases.
|
||||
Leases() (ls []*Lease)
|
||||
// Leases returns all the DHCP leases.
|
||||
Leases() (leases []*Lease)
|
||||
|
||||
// AddLease adds a new DHCP lease. It returns an error if the lease is
|
||||
// invalid or already exists.
|
||||
AddLease(l *Lease) (err error)
|
||||
|
||||
// UpdateStaticLease changes an existing DHCP lease. It returns an error if
|
||||
// there is no lease with such hardware addressor if new values are invalid
|
||||
// or already exist.
|
||||
UpdateStaticLease(l *Lease) (err error)
|
||||
// EditLease changes an existing DHCP lease. It returns an error if there
|
||||
// is no lease equal to old or if new is invalid or already exists.
|
||||
EditLease(old, new *Lease) (err error)
|
||||
|
||||
// RemoveLease removes an existing DHCP lease. It returns an error if there
|
||||
// is no lease equal to l.
|
||||
@@ -80,7 +79,7 @@ type Interface interface {
|
||||
type Empty struct{}
|
||||
|
||||
// type check
|
||||
var _ agh.ServiceWithConfig[*Config] = Empty{}
|
||||
var _ Interface = Empty{}
|
||||
|
||||
// Start implements the [Service] interface for Empty.
|
||||
func (Empty) Start() (err error) { return nil }
|
||||
@@ -88,6 +87,8 @@ func (Empty) Start() (err error) { return nil }
|
||||
// Shutdown implements the [Service] interface for Empty.
|
||||
func (Empty) Shutdown(_ context.Context) (err error) { return nil }
|
||||
|
||||
var _ agh.ServiceWithConfig[*Config] = Empty{}
|
||||
|
||||
// Config implements the [ServiceWithConfig] interface for Empty.
|
||||
func (Empty) Config() (conf *Config) { return nil }
|
||||
|
||||
@@ -112,8 +113,8 @@ func (Empty) Leases() (leases []*Lease) { return nil }
|
||||
// AddLease implements the [Interface] interface for Empty.
|
||||
func (Empty) AddLease(_ *Lease) (err error) { return nil }
|
||||
|
||||
// UpdateStaticLease implements the [Interface] interface for Empty.
|
||||
func (Empty) UpdateStaticLease(_ *Lease) (err error) { return nil }
|
||||
// EditLease implements the [Interface] interface for Empty.
|
||||
func (Empty) EditLease(_, _ *Lease) (err error) { return nil }
|
||||
|
||||
// RemoveLease implements the [Interface] interface for Empty.
|
||||
func (Empty) RemoveLease(_ *Lease) (err error) { return nil }
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
package dhcpsvc
|
||||
|
||||
import "github.com/AdguardTeam/golibs/errors"
|
||||
|
||||
const (
|
||||
// errNilConfig is returned when a nil config met.
|
||||
errNilConfig errors.Error = "config is nil"
|
||||
|
||||
// errNoInterfaces is returned when no interfaces found in configuration.
|
||||
errNoInterfaces errors.Error = "no interfaces specified"
|
||||
)
|
||||
@@ -1,98 +0,0 @@
|
||||
package dhcpsvc
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"math"
|
||||
"math/big"
|
||||
"net/netip"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
)
|
||||
|
||||
// ipRange is an inclusive range of IP addresses. A zero range doesn't contain
|
||||
// any IP addresses.
|
||||
//
|
||||
// It is safe for concurrent use.
|
||||
type ipRange struct {
|
||||
start netip.Addr
|
||||
end netip.Addr
|
||||
}
|
||||
|
||||
// maxRangeLen is the maximum IP range length. The bitsets used in servers only
|
||||
// accept uints, which can have the size of 32 bit.
|
||||
//
|
||||
// TODO(a.garipov, e.burkov): Reconsider the value for IPv6.
|
||||
const maxRangeLen = math.MaxUint32
|
||||
|
||||
// newIPRange creates a new IP address range. start must be less than end. The
|
||||
// resulting range must not be greater than maxRangeLen.
|
||||
func newIPRange(start, end netip.Addr) (r ipRange, err error) {
|
||||
defer func() { err = errors.Annotate(err, "invalid ip range: %w") }()
|
||||
|
||||
switch false {
|
||||
case start.Is4() == end.Is4():
|
||||
return ipRange{}, fmt.Errorf("%s and %s must be within the same address family", start, end)
|
||||
case start.Less(end):
|
||||
return ipRange{}, fmt.Errorf("start %s is greater than or equal to end %s", start, end)
|
||||
default:
|
||||
diff := (&big.Int{}).Sub(
|
||||
(&big.Int{}).SetBytes(end.AsSlice()),
|
||||
(&big.Int{}).SetBytes(start.AsSlice()),
|
||||
)
|
||||
|
||||
if !diff.IsUint64() || diff.Uint64() > maxRangeLen {
|
||||
return ipRange{}, fmt.Errorf("range length must be within %d", uint32(maxRangeLen))
|
||||
}
|
||||
}
|
||||
|
||||
return ipRange{
|
||||
start: start,
|
||||
end: end,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// contains returns true if r contains ip.
|
||||
func (r ipRange) contains(ip netip.Addr) (ok bool) {
|
||||
// Assume that the end was checked to be within the same address family as
|
||||
// the start during construction.
|
||||
return r.start.Is4() == ip.Is4() && !ip.Less(r.start) && !r.end.Less(ip)
|
||||
}
|
||||
|
||||
// ipPredicate is a function that is called on every IP address in
|
||||
// [ipRange.find].
|
||||
type ipPredicate func(ip netip.Addr) (ok bool)
|
||||
|
||||
// find finds the first IP address in r for which p returns true. It returns an
|
||||
// empty [netip.Addr] if there are no addresses that satisfy p.
|
||||
//
|
||||
// TODO(e.burkov): Use.
|
||||
func (r ipRange) find(p ipPredicate) (ip netip.Addr) {
|
||||
for ip = r.start; !r.end.Less(ip); ip = ip.Next() {
|
||||
if p(ip) {
|
||||
return ip
|
||||
}
|
||||
}
|
||||
|
||||
return netip.Addr{}
|
||||
}
|
||||
|
||||
// offset returns the offset of ip from the beginning of r. It returns 0 and
|
||||
// false if ip is not in r.
|
||||
func (r ipRange) offset(ip netip.Addr) (offset uint64, ok bool) {
|
||||
if !r.contains(ip) {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
startData, ipData := r.start.As16(), ip.As16()
|
||||
be := binary.BigEndian
|
||||
|
||||
// Assume that the range length was checked against maxRangeLen during
|
||||
// construction.
|
||||
return be.Uint64(ipData[8:]) - be.Uint64(startData[8:]), true
|
||||
}
|
||||
|
||||
// String implements the fmt.Stringer interface for *ipRange.
|
||||
func (r ipRange) String() (s string) {
|
||||
return fmt.Sprintf("%s-%s", r.start, r.end)
|
||||
}
|
||||
@@ -1,204 +0,0 @@
|
||||
package dhcpsvc
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewIPRange(t *testing.T) {
|
||||
start4 := netip.MustParseAddr("0.0.0.1")
|
||||
end4 := netip.MustParseAddr("0.0.0.3")
|
||||
start6 := netip.MustParseAddr("1::1")
|
||||
end6 := netip.MustParseAddr("1::3")
|
||||
end6Large := netip.MustParseAddr("2::3")
|
||||
|
||||
testCases := []struct {
|
||||
start netip.Addr
|
||||
end netip.Addr
|
||||
name string
|
||||
wantErrMsg string
|
||||
}{{
|
||||
start: start4,
|
||||
end: end4,
|
||||
name: "success_ipv4",
|
||||
wantErrMsg: "",
|
||||
}, {
|
||||
start: start6,
|
||||
end: end6,
|
||||
name: "success_ipv6",
|
||||
wantErrMsg: "",
|
||||
}, {
|
||||
start: end4,
|
||||
end: start4,
|
||||
name: "start_gt_end",
|
||||
wantErrMsg: "invalid ip range: start 0.0.0.3 is greater than or equal to end 0.0.0.1",
|
||||
}, {
|
||||
start: start4,
|
||||
end: start4,
|
||||
name: "start_eq_end",
|
||||
wantErrMsg: "invalid ip range: start 0.0.0.1 is greater than or equal to end 0.0.0.1",
|
||||
}, {
|
||||
start: start6,
|
||||
end: end6Large,
|
||||
name: "too_large",
|
||||
wantErrMsg: "invalid ip range: range length must be within " +
|
||||
strconv.FormatUint(maxRangeLen, 10),
|
||||
}, {
|
||||
start: start4,
|
||||
end: end6,
|
||||
name: "different_family",
|
||||
wantErrMsg: "invalid ip range: 0.0.0.1 and 1::3 must be within the same address family",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, err := newIPRange(tc.start, tc.end)
|
||||
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPRange_Contains(t *testing.T) {
|
||||
start, end := netip.MustParseAddr("0.0.0.1"), netip.MustParseAddr("0.0.0.3")
|
||||
r, err := newIPRange(start, end)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
in netip.Addr
|
||||
want assert.BoolAssertionFunc
|
||||
name string
|
||||
}{{
|
||||
in: start,
|
||||
want: assert.True,
|
||||
name: "start",
|
||||
}, {
|
||||
in: end,
|
||||
want: assert.True,
|
||||
name: "end",
|
||||
}, {
|
||||
in: start.Next(),
|
||||
want: assert.True,
|
||||
name: "within",
|
||||
}, {
|
||||
in: netip.MustParseAddr("0.0.0.0"),
|
||||
want: assert.False,
|
||||
name: "before",
|
||||
}, {
|
||||
in: netip.MustParseAddr("0.0.0.4"),
|
||||
want: assert.False,
|
||||
name: "after",
|
||||
}, {
|
||||
in: netip.MustParseAddr("::"),
|
||||
want: assert.False,
|
||||
name: "another_family",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tc.want(t, r.contains(tc.in))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPRange_Find(t *testing.T) {
|
||||
start, end := netip.MustParseAddr("0.0.0.1"), netip.MustParseAddr("0.0.0.5")
|
||||
r, err := newIPRange(start, end)
|
||||
require.NoError(t, err)
|
||||
|
||||
num, ok := r.offset(end)
|
||||
require.True(t, ok)
|
||||
|
||||
testCases := []struct {
|
||||
predicate ipPredicate
|
||||
want netip.Addr
|
||||
name string
|
||||
}{{
|
||||
predicate: func(ip netip.Addr) (ok bool) {
|
||||
ipData := ip.AsSlice()
|
||||
|
||||
return ipData[len(ipData)-1]%2 == 0
|
||||
},
|
||||
want: netip.MustParseAddr("0.0.0.2"),
|
||||
name: "even",
|
||||
}, {
|
||||
predicate: func(ip netip.Addr) (ok bool) {
|
||||
ipData := ip.AsSlice()
|
||||
|
||||
return ipData[len(ipData)-1]%10 == 0
|
||||
},
|
||||
want: netip.Addr{},
|
||||
name: "none",
|
||||
}, {
|
||||
predicate: func(ip netip.Addr) (ok bool) {
|
||||
return true
|
||||
},
|
||||
want: start,
|
||||
name: "first",
|
||||
}, {
|
||||
predicate: func(ip netip.Addr) (ok bool) {
|
||||
off, _ := r.offset(ip)
|
||||
|
||||
return off == num
|
||||
},
|
||||
want: end,
|
||||
name: "last",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := r.find(tc.predicate)
|
||||
assert.Equal(t, tc.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPRange_Offset(t *testing.T) {
|
||||
start, end := netip.MustParseAddr("0.0.0.1"), netip.MustParseAddr("0.0.0.5")
|
||||
r, err := newIPRange(start, end)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
in netip.Addr
|
||||
name string
|
||||
wantOffset uint64
|
||||
wantOK bool
|
||||
}{{
|
||||
in: netip.MustParseAddr("0.0.0.2"),
|
||||
name: "in",
|
||||
wantOffset: 1,
|
||||
wantOK: true,
|
||||
}, {
|
||||
in: start,
|
||||
name: "in_start",
|
||||
wantOffset: 0,
|
||||
wantOK: true,
|
||||
}, {
|
||||
in: end,
|
||||
name: "in_end",
|
||||
wantOffset: 4,
|
||||
wantOK: true,
|
||||
}, {
|
||||
in: netip.MustParseAddr("0.0.0.6"),
|
||||
name: "out_after",
|
||||
wantOffset: 0,
|
||||
wantOK: false,
|
||||
}, {
|
||||
in: netip.MustParseAddr("0.0.0.0"),
|
||||
name: "out_before",
|
||||
wantOffset: 0,
|
||||
wantOK: false,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
offset, ok := r.offset(tc.in)
|
||||
assert.Equal(t, tc.wantOffset, offset)
|
||||
assert.Equal(t, tc.wantOK, ok)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,77 +0,0 @@
|
||||
package dhcpsvc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
// DHCPServer is a DHCP server for both IPv4 and IPv6 address families.
|
||||
type DHCPServer struct {
|
||||
// enabled indicates whether the DHCP server is enabled and can provide
|
||||
// information about its clients.
|
||||
enabled *atomic.Bool
|
||||
|
||||
// localTLD is the top-level domain name to use for resolving DHCP
|
||||
// clients' hostnames.
|
||||
localTLD string
|
||||
|
||||
// interfaces4 is the set of IPv4 interfaces sorted by interface name.
|
||||
interfaces4 []*iface4
|
||||
|
||||
// interfaces6 is the set of IPv6 interfaces sorted by interface name.
|
||||
interfaces6 []*iface6
|
||||
|
||||
// icmpTimeout is the timeout for checking another DHCP server's presence.
|
||||
icmpTimeout time.Duration
|
||||
}
|
||||
|
||||
// New creates a new DHCP server with the given configuration. It returns an
|
||||
// error if the given configuration can't be used.
|
||||
//
|
||||
// TODO(e.burkov): Use.
|
||||
func New(conf *Config) (srv *DHCPServer, err error) {
|
||||
if !conf.Enabled {
|
||||
// TODO(e.burkov): Perhaps return [Empty]?
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
ifaces4 := make([]*iface4, len(conf.Interfaces))
|
||||
ifaces6 := make([]*iface6, len(conf.Interfaces))
|
||||
|
||||
ifaceNames := maps.Keys(conf.Interfaces)
|
||||
slices.Sort(ifaceNames)
|
||||
|
||||
var i4 *iface4
|
||||
var i6 *iface6
|
||||
|
||||
for _, ifaceName := range ifaceNames {
|
||||
iface := conf.Interfaces[ifaceName]
|
||||
|
||||
i4, err = newIface4(ifaceName, iface.IPv4)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("interface %q: ipv4: %w", ifaceName, err)
|
||||
} else if i4 != nil {
|
||||
ifaces4 = append(ifaces4, i4)
|
||||
}
|
||||
|
||||
i6 = newIface6(ifaceName, iface.IPv6)
|
||||
if i6 != nil {
|
||||
ifaces6 = append(ifaces6, i6)
|
||||
}
|
||||
}
|
||||
|
||||
enabled := &atomic.Bool{}
|
||||
enabled.Store(conf.Enabled)
|
||||
|
||||
return &DHCPServer{
|
||||
enabled: enabled,
|
||||
interfaces4: ifaces4,
|
||||
interfaces6: ifaces6,
|
||||
localTLD: conf.LocalDomainName,
|
||||
icmpTimeout: conf.ICMPTimeout,
|
||||
}, nil
|
||||
}
|
||||
@@ -1,115 +0,0 @@
|
||||
package dhcpsvc_test
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
)
|
||||
|
||||
// testLocalTLD is a common local TLD for tests.
|
||||
const testLocalTLD = "local"
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
validIPv4Conf := &dhcpsvc.IPv4Config{
|
||||
Enabled: true,
|
||||
GatewayIP: netip.MustParseAddr("192.168.0.1"),
|
||||
SubnetMask: netip.MustParseAddr("255.255.255.0"),
|
||||
RangeStart: netip.MustParseAddr("192.168.0.2"),
|
||||
RangeEnd: netip.MustParseAddr("192.168.0.254"),
|
||||
LeaseDuration: 1 * time.Hour,
|
||||
}
|
||||
gwInRangeConf := &dhcpsvc.IPv4Config{
|
||||
Enabled: true,
|
||||
GatewayIP: netip.MustParseAddr("192.168.0.100"),
|
||||
SubnetMask: netip.MustParseAddr("255.255.255.0"),
|
||||
RangeStart: netip.MustParseAddr("192.168.0.1"),
|
||||
RangeEnd: netip.MustParseAddr("192.168.0.254"),
|
||||
LeaseDuration: 1 * time.Hour,
|
||||
}
|
||||
badStartConf := &dhcpsvc.IPv4Config{
|
||||
Enabled: true,
|
||||
GatewayIP: netip.MustParseAddr("192.168.0.1"),
|
||||
SubnetMask: netip.MustParseAddr("255.255.255.0"),
|
||||
RangeStart: netip.MustParseAddr("127.0.0.1"),
|
||||
RangeEnd: netip.MustParseAddr("192.168.0.254"),
|
||||
LeaseDuration: 1 * time.Hour,
|
||||
}
|
||||
|
||||
validIPv6Conf := &dhcpsvc.IPv6Config{
|
||||
Enabled: true,
|
||||
RangeStart: netip.MustParseAddr("2001:db8::1"),
|
||||
LeaseDuration: 1 * time.Hour,
|
||||
RAAllowSLAAC: true,
|
||||
RASLAACOnly: true,
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
conf *dhcpsvc.Config
|
||||
name string
|
||||
wantErrMsg string
|
||||
}{{
|
||||
conf: &dhcpsvc.Config{
|
||||
Enabled: true,
|
||||
LocalDomainName: testLocalTLD,
|
||||
Interfaces: map[string]*dhcpsvc.InterfaceConfig{
|
||||
"eth0": {
|
||||
IPv4: validIPv4Conf,
|
||||
IPv6: validIPv6Conf,
|
||||
},
|
||||
},
|
||||
},
|
||||
name: "valid",
|
||||
wantErrMsg: "",
|
||||
}, {
|
||||
conf: &dhcpsvc.Config{
|
||||
Enabled: true,
|
||||
LocalDomainName: testLocalTLD,
|
||||
Interfaces: map[string]*dhcpsvc.InterfaceConfig{
|
||||
"eth0": {
|
||||
IPv4: &dhcpsvc.IPv4Config{Enabled: false},
|
||||
IPv6: &dhcpsvc.IPv6Config{Enabled: false},
|
||||
},
|
||||
},
|
||||
},
|
||||
name: "disabled_interfaces",
|
||||
wantErrMsg: "",
|
||||
}, {
|
||||
conf: &dhcpsvc.Config{
|
||||
Enabled: true,
|
||||
LocalDomainName: testLocalTLD,
|
||||
Interfaces: map[string]*dhcpsvc.InterfaceConfig{
|
||||
"eth0": {
|
||||
IPv4: gwInRangeConf,
|
||||
IPv6: validIPv6Conf,
|
||||
},
|
||||
},
|
||||
},
|
||||
name: "gateway_within_range",
|
||||
wantErrMsg: `interface "eth0": ipv4: ` +
|
||||
`gateway ip 192.168.0.100 in the ip range 192.168.0.1-192.168.0.254`,
|
||||
}, {
|
||||
conf: &dhcpsvc.Config{
|
||||
Enabled: true,
|
||||
LocalDomainName: testLocalTLD,
|
||||
Interfaces: map[string]*dhcpsvc.InterfaceConfig{
|
||||
"eth0": {
|
||||
IPv4: badStartConf,
|
||||
IPv6: validIPv6Conf,
|
||||
},
|
||||
},
|
||||
},
|
||||
name: "bad_start",
|
||||
wantErrMsg: `interface "eth0": ipv4: ` +
|
||||
`range start 127.0.0.1 is not within 192.168.0.1/24`,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, err := dhcpsvc.New(tc.conf)
|
||||
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,113 +0,0 @@
|
||||
package dhcpsvc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/google/gopacket/layers"
|
||||
)
|
||||
|
||||
// IPv4Config is the interface-specific configuration for DHCPv4.
|
||||
type IPv4Config struct {
|
||||
// GatewayIP is the IPv4 address of the network's gateway. It is used as
|
||||
// the default gateway for DHCP clients and also used in calculating the
|
||||
// network-specific broadcast address.
|
||||
GatewayIP netip.Addr
|
||||
|
||||
// SubnetMask is the IPv4 subnet mask of the network. It should be a valid
|
||||
// IPv4 CIDR (i.e. all 1s followed by all 0s).
|
||||
SubnetMask netip.Addr
|
||||
|
||||
// RangeStart is the first address in the range to assign to DHCP clients.
|
||||
RangeStart netip.Addr
|
||||
|
||||
// RangeEnd is the last address in the range to assign to DHCP clients.
|
||||
RangeEnd netip.Addr
|
||||
|
||||
// Options is the list of DHCP options to send to DHCP clients.
|
||||
Options layers.DHCPOptions
|
||||
|
||||
// LeaseDuration is the TTL of a DHCP lease.
|
||||
LeaseDuration time.Duration
|
||||
|
||||
// Enabled is the state of the DHCPv4 service, whether it is enabled or not
|
||||
// on the specific interface.
|
||||
Enabled bool
|
||||
}
|
||||
|
||||
// validate returns an error in conf if any.
|
||||
func (conf *IPv4Config) validate() (err error) {
|
||||
switch {
|
||||
case conf == nil:
|
||||
return errNilConfig
|
||||
case !conf.Enabled:
|
||||
return nil
|
||||
case !conf.GatewayIP.Is4():
|
||||
return mustBeErr("gateway ip", "be a valid ipv4", conf.GatewayIP)
|
||||
case !conf.SubnetMask.Is4():
|
||||
return mustBeErr("subnet mask", "be a valid ipv4 cidr mask", conf.SubnetMask)
|
||||
case !conf.RangeStart.Is4():
|
||||
return mustBeErr("range start", "be a valid ipv4", conf.RangeStart)
|
||||
case !conf.RangeEnd.Is4():
|
||||
return mustBeErr("range end", "be a valid ipv4", conf.RangeEnd)
|
||||
case conf.LeaseDuration <= 0:
|
||||
return mustBeErr("lease duration", "be less than %d", conf.LeaseDuration)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// iface4 is a DHCP interface for IPv4 address family.
|
||||
type iface4 struct {
|
||||
// gateway is the IP address of the network gateway.
|
||||
gateway netip.Addr
|
||||
|
||||
// subnet is the network subnet.
|
||||
subnet netip.Prefix
|
||||
|
||||
// addrSpace is the IPv4 address space allocated for leasing.
|
||||
addrSpace ipRange
|
||||
|
||||
// name is the name of the interface.
|
||||
name string
|
||||
|
||||
// TODO(e.burkov): Add options.
|
||||
|
||||
// leaseTTL is the time-to-live of dynamic leases on this interface.
|
||||
leaseTTL time.Duration
|
||||
}
|
||||
|
||||
// newIface4 creates a new DHCP interface for IPv4 address family with the given
|
||||
// configuration. It returns an error if the given configuration can't be used.
|
||||
func newIface4(name string, conf *IPv4Config) (i *iface4, err error) {
|
||||
if !conf.Enabled {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
maskLen, _ := net.IPMask(conf.SubnetMask.AsSlice()).Size()
|
||||
subnet := netip.PrefixFrom(conf.GatewayIP, maskLen)
|
||||
|
||||
switch {
|
||||
case !subnet.Contains(conf.RangeStart):
|
||||
return nil, fmt.Errorf("range start %s is not within %s", conf.RangeStart, subnet)
|
||||
case !subnet.Contains(conf.RangeEnd):
|
||||
return nil, fmt.Errorf("range end %s is not within %s", conf.RangeEnd, subnet)
|
||||
}
|
||||
|
||||
addrSpace, err := newIPRange(conf.RangeStart, conf.RangeEnd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if addrSpace.contains(conf.GatewayIP) {
|
||||
return nil, fmt.Errorf("gateway ip %s in the ip range %s", conf.GatewayIP, addrSpace)
|
||||
}
|
||||
|
||||
return &iface4{
|
||||
name: name,
|
||||
gateway: conf.GatewayIP,
|
||||
subnet: subnet,
|
||||
addrSpace: addrSpace,
|
||||
leaseTTL: conf.LeaseDuration,
|
||||
}, nil
|
||||
}
|
||||
@@ -1,88 +0,0 @@
|
||||
package dhcpsvc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/google/gopacket/layers"
|
||||
)
|
||||
|
||||
// IPv6Config is the interface-specific configuration for DHCPv6.
|
||||
type IPv6Config struct {
|
||||
// RangeStart is the first address in the range to assign to DHCP clients.
|
||||
RangeStart netip.Addr
|
||||
|
||||
// Options is the list of DHCP options to send to DHCP clients.
|
||||
Options layers.DHCPOptions
|
||||
|
||||
// LeaseDuration is the TTL of a DHCP lease.
|
||||
LeaseDuration time.Duration
|
||||
|
||||
// RASlaacOnly defines whether the DHCP clients should only use SLAAC for
|
||||
// address assignment.
|
||||
RASLAACOnly bool
|
||||
|
||||
// RAAllowSlaac defines whether the DHCP clients may use SLAAC for address
|
||||
// assignment.
|
||||
RAAllowSLAAC bool
|
||||
|
||||
// Enabled is the state of the DHCPv6 service, whether it is enabled or not
|
||||
// on the specific interface.
|
||||
Enabled bool
|
||||
}
|
||||
|
||||
// validate returns an error in conf if any.
|
||||
func (conf *IPv6Config) validate() (err error) {
|
||||
switch {
|
||||
case conf == nil:
|
||||
return errNilConfig
|
||||
case !conf.Enabled:
|
||||
return nil
|
||||
case !conf.RangeStart.Is6():
|
||||
return fmt.Errorf("range start %s should be a valid ipv6", conf.RangeStart)
|
||||
case conf.LeaseDuration <= 0:
|
||||
return fmt.Errorf("lease duration %s must be positive", conf.LeaseDuration)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// iface6 is a DHCP interface for IPv6 address family.
|
||||
//
|
||||
// TODO(e.burkov): Add options.
|
||||
type iface6 struct {
|
||||
// rangeStart is the first IP address in the range.
|
||||
rangeStart netip.Addr
|
||||
|
||||
// name is the name of the interface.
|
||||
name string
|
||||
|
||||
// leaseTTL is the time-to-live of dynamic leases on this interface.
|
||||
leaseTTL time.Duration
|
||||
|
||||
// raSLAACOnly defines if DHCP should send ICMPv6.RA packets without MO
|
||||
// flags.
|
||||
raSLAACOnly bool
|
||||
|
||||
// raAllowSLAAC defines if DHCP should send ICMPv6.RA packets with MO flags.
|
||||
raAllowSLAAC bool
|
||||
}
|
||||
|
||||
// newIface6 creates a new DHCP interface for IPv6 address family with the given
|
||||
// configuration.
|
||||
//
|
||||
// TODO(e.burkov): Validate properly.
|
||||
func newIface6(name string, conf *IPv6Config) (i *iface6) {
|
||||
if !conf.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &iface6{
|
||||
name: name,
|
||||
rangeStart: conf.RangeStart,
|
||||
leaseTTL: conf.LeaseDuration,
|
||||
raSLAACOnly: conf.RASLAACOnly,
|
||||
raAllowSLAAC: conf.RAAllowSLAAC,
|
||||
}
|
||||
}
|
||||
@@ -182,7 +182,6 @@ func (s *Server) accessListJSON() (j accessListJSON) {
|
||||
}
|
||||
}
|
||||
|
||||
// handleAccessList handles requests to the GET /control/access/list endpoint.
|
||||
func (s *Server) handleAccessList(w http.ResponseWriter, r *http.Request) {
|
||||
aghhttp.WriteJSONResponseOK(w, r, s.accessListJSON())
|
||||
}
|
||||
@@ -225,7 +224,6 @@ func validateStrUniq(clients []string) (uc aghalg.UniqChecker[string], err error
|
||||
return uc, uc.Validate()
|
||||
}
|
||||
|
||||
// handleAccessSet handles requests to the POST /control/access/set endpoint.
|
||||
func (s *Server) handleAccessSet(w http.ResponseWriter, r *http.Request) {
|
||||
list := &accessListJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&list)
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/quic-go/quic-go"
|
||||
)
|
||||
@@ -152,8 +151,6 @@ func (s *Server) clientIDFromDNSContext(pctx *proxy.DNSContext) (clientID string
|
||||
// DNS-over-HTTPS requests, it will return the hostname part of the Host header
|
||||
// if there is one.
|
||||
func clientServerName(pctx *proxy.DNSContext, proto proxy.Proto) (srvName string, err error) {
|
||||
from := "tls conn"
|
||||
|
||||
switch proto {
|
||||
case proxy.ProtoHTTPS:
|
||||
r := pctx.HTTPRequest
|
||||
@@ -167,7 +164,6 @@ func clientServerName(pctx *proxy.DNSContext, proto proxy.Proto) (srvName string
|
||||
}
|
||||
|
||||
srvName = host
|
||||
from = "host header"
|
||||
}
|
||||
case proxy.ProtoQUIC:
|
||||
qConn := pctx.QUICConnection
|
||||
@@ -187,7 +183,5 @@ func clientServerName(pctx *proxy.DNSContext, proto proxy.Proto) (srvName string
|
||||
srvName = tc.ConnectionState().ServerName
|
||||
}
|
||||
|
||||
log.Debug("dnsforward: got client server name %q from %s", srvName, from)
|
||||
|
||||
return srvName, nil
|
||||
}
|
||||
|
||||
@@ -7,8 +7,6 @@ import (
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
@@ -446,10 +444,19 @@ func newUpstreamConfig(upstreams []string) (conf *proxy.UpstreamConfig, err erro
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
err = validateUpstreamConfig(upstreams)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return nil, err
|
||||
for _, u := range upstreams {
|
||||
var ups string
|
||||
var domains []string
|
||||
ups, domains, err = separateUpstream(u)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = validateUpstream(ups, domains)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("validating upstream %q: %w", u, err)
|
||||
}
|
||||
}
|
||||
|
||||
conf, err = proxy.ParseUpstreamsConfig(
|
||||
@@ -460,7 +467,6 @@ func newUpstreamConfig(upstreams []string) (conf *proxy.UpstreamConfig, err erro
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return nil, err
|
||||
} else if len(conf.Upstreams) == 0 {
|
||||
return nil, errors.Error("no default upstreams specified")
|
||||
@@ -469,31 +475,6 @@ func newUpstreamConfig(upstreams []string) (conf *proxy.UpstreamConfig, err erro
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
// validateUpstreamConfig validates each upstream from the upstream
|
||||
// configuration and returns an error if any upstream is invalid.
|
||||
//
|
||||
// TODO(e.burkov): Move into aghnet or even into dnsproxy.
|
||||
func validateUpstreamConfig(conf []string) (err error) {
|
||||
for _, u := range conf {
|
||||
var ups []string
|
||||
var domains []string
|
||||
ups, domains, err = separateUpstream(u)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
for _, addr := range ups {
|
||||
_, err = validateUpstream(addr, len(domains) > 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("validating upstream %q: %w", addr, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateUpstreams validates each upstream and returns an error if any
|
||||
// upstream is invalid or if there are no default upstreams specified.
|
||||
//
|
||||
@@ -553,14 +534,14 @@ var protocols = []string{
|
||||
}
|
||||
|
||||
// validateUpstream returns an error if u alongside with domains is not a valid
|
||||
// upstream configuration. usesDefault is true if the upstream is
|
||||
// upstream configuration. useDefault is true if the upstream is
|
||||
// domain-specific and is configured to point at the default upstream server
|
||||
// which is validated separately. specific reflects if the upstream is
|
||||
// domain-specific.
|
||||
func validateUpstream(u string, specific bool) (usesDefault bool, err error) {
|
||||
// which is validated separately. The upstream is considered domain-specific
|
||||
// only if domains is at least not nil.
|
||||
func validateUpstream(u string, domains []string) (useDefault bool, err error) {
|
||||
// The special server address '#' means that default server must be used.
|
||||
if u == "#" && specific {
|
||||
return true, nil
|
||||
if useDefault = u == "#" && domains != nil; useDefault {
|
||||
return useDefault, nil
|
||||
}
|
||||
|
||||
// Check if the upstream has a valid protocol prefix.
|
||||
@@ -586,12 +567,12 @@ func validateUpstream(u string, specific bool) (usesDefault bool, err error) {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// separateUpstream returns the upstreams and the specified domains. domains
|
||||
// is nil when the upstream is not domains-specific. Otherwise it may also be
|
||||
// separateUpstream returns the upstream and the specified domains. domains is
|
||||
// nil when the upstream is not domains-specific. Otherwise it may also be
|
||||
// empty.
|
||||
func separateUpstream(upstreamStr string) (upstreams, domains []string, err error) {
|
||||
func separateUpstream(upstreamStr string) (ups string, domains []string, err error) {
|
||||
if !strings.HasPrefix(upstreamStr, "[/") {
|
||||
return []string{upstreamStr}, nil, nil
|
||||
return upstreamStr, nil, nil
|
||||
}
|
||||
|
||||
defer func() { err = errors.Annotate(err, "bad upstream for domain %q: %w", upstreamStr) }()
|
||||
@@ -601,9 +582,9 @@ func separateUpstream(upstreamStr string) (upstreams, domains []string, err erro
|
||||
case 2:
|
||||
// Go on.
|
||||
case 1:
|
||||
return nil, nil, errors.Error("missing separator")
|
||||
return "", nil, errors.Error("missing separator")
|
||||
default:
|
||||
return nil, nil, errors.Error("duplicated separator")
|
||||
return "", []string{}, errors.Error("duplicated separator")
|
||||
}
|
||||
|
||||
for i, host := range strings.Split(parts[0], "/") {
|
||||
@@ -613,22 +594,21 @@ func separateUpstream(upstreamStr string) (upstreams, domains []string, err erro
|
||||
|
||||
err = netutil.ValidateDomainName(strings.TrimPrefix(host, "*."))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("domain at index %d: %w", i, err)
|
||||
return "", domains, fmt.Errorf("domain at index %d: %w", i, err)
|
||||
}
|
||||
|
||||
domains = append(domains, host)
|
||||
}
|
||||
|
||||
return strings.Fields(parts[1]), domains, nil
|
||||
return parts[1], domains, nil
|
||||
}
|
||||
|
||||
// healthCheckFunc is a signature of function to check if upstream exchanges
|
||||
// properly.
|
||||
type healthCheckFunc func(u upstream.Upstream) (err error)
|
||||
|
||||
// checkExchange is a [healthCheckFunc] that checks if the DNS upstream
|
||||
// exchanges correctly.
|
||||
func checkExchange(u upstream.Upstream) (err error) {
|
||||
// checkDNSUpstreamExc checks if the DNS upstream exchanges correctly.
|
||||
func checkDNSUpstreamExc(u upstream.Upstream) (err error) {
|
||||
// testTLD is the special-use fully-qualified domain name for testing the
|
||||
// DNS server reachability.
|
||||
//
|
||||
@@ -658,11 +638,11 @@ func checkExchange(u upstream.Upstream) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkPrivateExchange is a [healthCheckFunc] that checks if the upstream for
|
||||
// resolving private addresses exchanges correctly.
|
||||
// checkPrivateUpstreamExc checks if the upstream for resolving private
|
||||
// addresses exchanges correctly.
|
||||
//
|
||||
// TODO(e.burkov): Think about testing the ip6.arpa. as well.
|
||||
func checkPrivateExchange(u upstream.Upstream) (err error) {
|
||||
func checkPrivateUpstreamExc(u upstream.Upstream) (err error) {
|
||||
// inAddrArpaTLD is the special-use fully-qualified domain name for PTR IP
|
||||
// address resolution.
|
||||
//
|
||||
@@ -703,153 +683,75 @@ func (err domainSpecificTestError) Error() (msg string) {
|
||||
return fmt.Sprintf("WARNING: %s", err.error)
|
||||
}
|
||||
|
||||
// checkUpstreamAddr creates the upstream using opts and, possibly, information
|
||||
// from system hosts files, then checks if the DNS upstream exchanges correctly.
|
||||
// It returns an error if addr is not valid DNS upstream address or the upstream
|
||||
// is not exchanging correctly.
|
||||
//
|
||||
// TODO(e.burkov): Remove the receiver.
|
||||
func (s *Server) checkUpstreamAddr(
|
||||
addr string,
|
||||
specific bool,
|
||||
basicOpts *upstream.Options,
|
||||
check healthCheckFunc,
|
||||
) (err error) {
|
||||
usesDefault, err := validateUpstream(addr, specific)
|
||||
// parseUpstreamLine parses line and creates the [upstream.Upstream] using opts
|
||||
// and information from [s.dnsFilter.EtcHosts]. It returns an error if the line
|
||||
// is not a valid upstream line, see [upstream.AddressToUpstream]. It's a
|
||||
// caller's responsibility to close u.
|
||||
func (s *Server) parseUpstreamLine(
|
||||
line string,
|
||||
opts *upstream.Options,
|
||||
) (u upstream.Upstream, specific bool, err error) {
|
||||
// Separate upstream from domains list.
|
||||
upstreamAddr, domains, err := separateUpstream(line)
|
||||
if err != nil {
|
||||
return fmt.Errorf("wrong upstream format: %w", err)
|
||||
} else if usesDefault {
|
||||
return nil, false, fmt.Errorf("wrong upstream format: %w", err)
|
||||
}
|
||||
|
||||
specific = len(domains) > 0
|
||||
|
||||
useDefault, err := validateUpstream(upstreamAddr, domains)
|
||||
if err != nil {
|
||||
return nil, specific, fmt.Errorf("wrong upstream format: %w", err)
|
||||
} else if useDefault {
|
||||
return nil, specific, nil
|
||||
}
|
||||
|
||||
log.Debug("dnsforward: checking if upstream %q works", upstreamAddr)
|
||||
|
||||
opts = &upstream.Options{
|
||||
Bootstrap: opts.Bootstrap,
|
||||
Timeout: opts.Timeout,
|
||||
PreferIPv6: opts.PreferIPv6,
|
||||
}
|
||||
|
||||
// dnsFilter can be nil during application update.
|
||||
if s.dnsFilter != nil {
|
||||
recs := s.dnsFilter.EtcHostsRecords(extractUpstreamHost(upstreamAddr))
|
||||
for _, rec := range recs {
|
||||
opts.ServerIPAddrs = append(opts.ServerIPAddrs, rec.Addr.AsSlice())
|
||||
}
|
||||
sortNetIPAddrs(opts.ServerIPAddrs, opts.PreferIPv6)
|
||||
}
|
||||
u, err = upstream.AddressToUpstream(upstreamAddr, opts)
|
||||
if err != nil {
|
||||
return nil, specific, fmt.Errorf("creating upstream for %q: %w", upstreamAddr, err)
|
||||
}
|
||||
|
||||
return u, specific, nil
|
||||
}
|
||||
|
||||
func (s *Server) checkDNS(line string, opts *upstream.Options, check healthCheckFunc) (err error) {
|
||||
if IsCommentOrEmpty(line) {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Debug("dnsforward: checking if upstream %q works", addr)
|
||||
|
||||
var u upstream.Upstream
|
||||
var specific bool
|
||||
defer func() {
|
||||
if err != nil && specific {
|
||||
err = domainSpecificTestError{error: err}
|
||||
}
|
||||
}()
|
||||
|
||||
opts := &upstream.Options{
|
||||
Bootstrap: basicOpts.Bootstrap,
|
||||
Timeout: basicOpts.Timeout,
|
||||
PreferIPv6: basicOpts.PreferIPv6,
|
||||
}
|
||||
|
||||
// dnsFilter can be nil during application update.
|
||||
//
|
||||
// TODO(e.burkov): Remove when update dnsproxy.
|
||||
if s.dnsFilter != nil {
|
||||
recs := s.dnsFilter.EtcHostsRecords(extractUpstreamHost(addr))
|
||||
for _, rec := range recs {
|
||||
opts.ServerIPAddrs = append(opts.ServerIPAddrs, rec.Addr.AsSlice())
|
||||
}
|
||||
sortNetIPAddrs(opts.ServerIPAddrs, opts.PreferIPv6)
|
||||
}
|
||||
|
||||
u, err := upstream.AddressToUpstream(addr, opts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating upstream for %q: %w", addr, err)
|
||||
u, specific, err = s.parseUpstreamLine(line, opts)
|
||||
if err != nil || u == nil {
|
||||
return err
|
||||
}
|
||||
defer func() { err = errors.WithDeferred(err, u.Close()) }()
|
||||
|
||||
return check(u)
|
||||
}
|
||||
|
||||
// checkResult is a result of checking an upstream server.
|
||||
type checkResult = struct {
|
||||
// status is an error message if the upstream server is not working. It's
|
||||
// nil for working upstreams.
|
||||
status error
|
||||
|
||||
// address is the upstream server address as given in the request. It may
|
||||
// appear to be a whole line if it's incorrect itself.
|
||||
address string
|
||||
}
|
||||
|
||||
// checkDNS parses an upstream configuration line using opts and checks if the
|
||||
// specified upstreams are working using check. countWG is decremented when the
|
||||
// expected number of results added to resNum, then results are sent to resCh.
|
||||
//
|
||||
// TODO(e.burkov): Remove the receiver.
|
||||
func (s *Server) checkDNS(
|
||||
line string,
|
||||
opts *upstream.Options,
|
||||
check healthCheckFunc,
|
||||
countWG *sync.WaitGroup,
|
||||
resNum *atomic.Int32,
|
||||
resCh chan<- checkResult,
|
||||
) {
|
||||
defer log.OnPanic("dnsforward: checking upstreams")
|
||||
|
||||
addrs, domains, err := separateUpstream(line)
|
||||
if err != nil {
|
||||
resNum.Add(1)
|
||||
countWG.Done()
|
||||
|
||||
resCh <- checkResult{
|
||||
address: line,
|
||||
status: fmt.Errorf("wrong upstream format: %w", err),
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
resNum.Add(int32(len(addrs)))
|
||||
countWG.Done()
|
||||
|
||||
specific := len(domains) > 0
|
||||
for _, addr := range addrs {
|
||||
resCh <- checkResult{
|
||||
address: addr,
|
||||
status: s.checkUpstreamAddr(addr, specific, opts, check),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check returns the mapping of upstream addresses to their check results.
|
||||
func (s *Server) check(req *upstreamJSON, opts *upstream.Options) (result map[string]string) {
|
||||
req.Upstreams = stringutil.FilterOut(req.Upstreams, IsCommentOrEmpty)
|
||||
req.FallbackDNS = stringutil.FilterOut(req.FallbackDNS, IsCommentOrEmpty)
|
||||
req.PrivateUpstreams = stringutil.FilterOut(req.PrivateUpstreams, IsCommentOrEmpty)
|
||||
|
||||
countWG := &sync.WaitGroup{}
|
||||
countWG.Add(len(req.Upstreams) + len(req.FallbackDNS) + len(req.PrivateUpstreams))
|
||||
|
||||
resNum := &atomic.Int32{}
|
||||
resCh := make(chan checkResult)
|
||||
|
||||
for _, addr := range req.Upstreams {
|
||||
go s.checkDNS(addr, opts, checkExchange, countWG, resNum, resCh)
|
||||
}
|
||||
for _, addr := range req.FallbackDNS {
|
||||
go s.checkDNS(addr, opts, checkExchange, countWG, resNum, resCh)
|
||||
}
|
||||
for _, addr := range req.PrivateUpstreams {
|
||||
go s.checkDNS(addr, opts, checkPrivateExchange, countWG, resNum, resCh)
|
||||
}
|
||||
|
||||
// Wait until all the servers are counted and enqueued.
|
||||
countWG.Wait()
|
||||
n := resNum.Load()
|
||||
|
||||
result = make(map[string]string, n)
|
||||
for i := int32(0); i < n; i++ {
|
||||
// TODO(e.burkov): Upstreams intended for different purposes should
|
||||
// be distinguished in the result, even if specified equally.
|
||||
res := <-resCh
|
||||
if res.status != nil {
|
||||
result[res.address] = res.status.Error()
|
||||
} else {
|
||||
result[res.address] = "OK"
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// handleTestUpstreamDNS handles requests to the POST /control/test_upstream_dns
|
||||
// endpoint.
|
||||
func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||
req := &upstreamJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(req)
|
||||
@@ -859,18 +761,65 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
bootstrapAddrs := stringutil.FilterOut(req.BootstrapDNS, IsCommentOrEmpty)
|
||||
if len(bootstrapAddrs) == 0 {
|
||||
bootstrapAddrs = defaultBootstrap
|
||||
}
|
||||
|
||||
opts := &upstream.Options{
|
||||
Bootstrap: bootstrapAddrs,
|
||||
Bootstrap: req.BootstrapDNS,
|
||||
Timeout: s.conf.UpstreamTimeout,
|
||||
PreferIPv6: s.conf.BootstrapPreferIPv6,
|
||||
}
|
||||
if len(opts.Bootstrap) == 0 {
|
||||
opts.Bootstrap = defaultBootstrap
|
||||
}
|
||||
|
||||
aghhttp.WriteJSONResponseOK(w, r, s.check(req, opts))
|
||||
type upsCheckResult = struct {
|
||||
err error
|
||||
host string
|
||||
}
|
||||
|
||||
req.Upstreams = stringutil.FilterOut(req.Upstreams, IsCommentOrEmpty)
|
||||
req.FallbackDNS = stringutil.FilterOut(req.FallbackDNS, IsCommentOrEmpty)
|
||||
req.PrivateUpstreams = stringutil.FilterOut(req.PrivateUpstreams, IsCommentOrEmpty)
|
||||
|
||||
upsNum := len(req.Upstreams) + len(req.FallbackDNS) + len(req.PrivateUpstreams)
|
||||
result := make(map[string]string, upsNum)
|
||||
resCh := make(chan upsCheckResult, upsNum)
|
||||
|
||||
for _, ups := range req.Upstreams {
|
||||
go func(ups string) {
|
||||
resCh <- upsCheckResult{
|
||||
host: ups,
|
||||
err: s.checkDNS(ups, opts, checkDNSUpstreamExc),
|
||||
}
|
||||
}(ups)
|
||||
}
|
||||
for _, ups := range req.FallbackDNS {
|
||||
go func(ups string) {
|
||||
resCh <- upsCheckResult{
|
||||
host: ups,
|
||||
err: s.checkDNS(ups, opts, checkDNSUpstreamExc),
|
||||
}
|
||||
}(ups)
|
||||
}
|
||||
for _, ups := range req.PrivateUpstreams {
|
||||
go func(ups string) {
|
||||
resCh <- upsCheckResult{
|
||||
host: ups,
|
||||
err: s.checkDNS(ups, opts, checkPrivateUpstreamExc),
|
||||
}
|
||||
}(ups)
|
||||
}
|
||||
|
||||
for i := 0; i < upsNum; i++ {
|
||||
// TODO(e.burkov): The upstreams used for both common and private
|
||||
// resolving should be reported separately.
|
||||
pair := <-resCh
|
||||
if pair.err != nil {
|
||||
result[pair.host] = pair.err.Error()
|
||||
} else {
|
||||
result[pair.host] = "OK"
|
||||
}
|
||||
}
|
||||
|
||||
aghhttp.WriteJSONResponseOK(w, r, result)
|
||||
}
|
||||
|
||||
// handleCacheClear is the handler for the POST /control/cache_clear HTTP API.
|
||||
|
||||
@@ -49,18 +49,13 @@ func loadTestData(t *testing.T, casesFileName string, cases any) {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
const (
|
||||
jsonExt = ".json"
|
||||
|
||||
// testBlockedRespTTL is the TTL for blocked responses to use in tests.
|
||||
testBlockedRespTTL = 10
|
||||
)
|
||||
const jsonExt = ".json"
|
||||
|
||||
func TestDNSForwardHTTP_handleGetConfig(t *testing.T) {
|
||||
filterConf := &filtering.Config{
|
||||
ProtectionEnabled: true,
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
BlockedResponseTTL: testBlockedRespTTL,
|
||||
BlockedResponseTTL: 10,
|
||||
SafeBrowsingEnabled: true,
|
||||
SafeBrowsingCacheSize: 1000,
|
||||
SafeSearchConf: filtering.SafeSearchConfig{Enabled: true},
|
||||
@@ -138,7 +133,7 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
|
||||
filterConf := &filtering.Config{
|
||||
ProtectionEnabled: true,
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
BlockedResponseTTL: testBlockedRespTTL,
|
||||
BlockedResponseTTL: 10,
|
||||
SafeBrowsingEnabled: true,
|
||||
SafeBrowsingCacheSize: 1000,
|
||||
SafeSearchConf: filtering.SafeSearchConfig{Enabled: true},
|
||||
@@ -234,9 +229,6 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
|
||||
}, {
|
||||
name: "blocked_response_ttl",
|
||||
wantSet: "",
|
||||
}, {
|
||||
name: "multiple_domain_specific_upstreams",
|
||||
wantSet: "",
|
||||
}}
|
||||
|
||||
var data map[string]struct {
|
||||
@@ -258,7 +250,6 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
|
||||
s.dnsFilter.SetBlockingMode(filtering.BlockingModeDefault, netip.Addr{}, netip.Addr{})
|
||||
s.conf = defaultConf
|
||||
s.conf.Config.EDNSClientSubnet = &EDNSClientSubnet{}
|
||||
s.dnsFilter.SetBlockedResponseTTL(testBlockedRespTTL)
|
||||
})
|
||||
|
||||
rBody := io.NopCloser(bytes.NewReader(caseData.Req))
|
||||
@@ -479,8 +470,6 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) {
|
||||
Host: newLocalUpstreamListener(t, 0, badHandler).String(),
|
||||
}).String()
|
||||
|
||||
goodAndBadUps := strings.Join([]string{goodUps, badUps}, " ")
|
||||
|
||||
const (
|
||||
upsTimeout = 100 * time.Millisecond
|
||||
|
||||
@@ -558,7 +547,7 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) {
|
||||
"upstream_dns": []string{"[/domain.example/]" + badUps},
|
||||
},
|
||||
wantResp: map[string]any{
|
||||
badUps: `WARNING: couldn't communicate ` +
|
||||
"[/domain.example/]" + badUps: `WARNING: couldn't communicate ` +
|
||||
`with upstream: exchanging with ` + badUps + ` over tcp: ` +
|
||||
`dns: id mismatch`,
|
||||
},
|
||||
@@ -596,40 +585,6 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) {
|
||||
goodUps: "OK",
|
||||
},
|
||||
name: "fallback_comment_mix",
|
||||
}, {
|
||||
body: map[string]any{
|
||||
"upstream_dns": []string{"[/domain.example/]" + goodUps + " " + badUps},
|
||||
},
|
||||
wantResp: map[string]any{
|
||||
goodUps: "OK",
|
||||
badUps: `WARNING: couldn't communicate ` +
|
||||
`with upstream: exchanging with ` + badUps + ` over tcp: ` +
|
||||
`dns: id mismatch`,
|
||||
},
|
||||
name: "multiple_domain_specific_upstreams",
|
||||
}, {
|
||||
body: map[string]any{
|
||||
"upstream_dns": []string{"[/domain.example/]/]1.2.3.4"},
|
||||
},
|
||||
wantResp: map[string]any{
|
||||
"[/domain.example/]/]1.2.3.4": `wrong upstream format: ` +
|
||||
`bad upstream for domain "[/domain.example/]/]1.2.3.4": ` +
|
||||
`duplicated separator`,
|
||||
},
|
||||
name: "bad_specification",
|
||||
}, {
|
||||
body: map[string]any{
|
||||
"upstream_dns": []string{"[/domain.example/]" + goodAndBadUps},
|
||||
"fallback_dns": []string{"[/domain.example/]" + goodAndBadUps},
|
||||
"private_upstream": []string{"[/domain.example/]" + goodAndBadUps},
|
||||
},
|
||||
wantResp: map[string]any{
|
||||
goodUps: "OK",
|
||||
badUps: `WARNING: couldn't communicate ` +
|
||||
`with upstream: exchanging with ` + badUps + ` over tcp: ` +
|
||||
`dns: id mismatch`,
|
||||
},
|
||||
name: "all_different",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
|
||||
@@ -839,47 +839,5 @@
|
||||
"edns_cs_use_custom": false,
|
||||
"edns_cs_custom_ip": ""
|
||||
}
|
||||
},
|
||||
"multiple_domain_specific_upstreams": {
|
||||
"req": {
|
||||
"upstream_dns": [
|
||||
"8.8.8.8:77",
|
||||
"[/example.com/]8.8.4.4:77 9.9.9.10 https://1.1.1.1"
|
||||
]
|
||||
},
|
||||
"want": {
|
||||
"upstream_dns": [
|
||||
"8.8.8.8:77",
|
||||
"[/example.com/]8.8.4.4:77 9.9.9.10 https://1.1.1.1"
|
||||
],
|
||||
"upstream_dns_file": "",
|
||||
"bootstrap_dns": [
|
||||
"9.9.9.10",
|
||||
"149.112.112.10",
|
||||
"2620:fe::10",
|
||||
"2620:fe::fe:10"
|
||||
],
|
||||
"fallback_dns": [],
|
||||
"protection_enabled": true,
|
||||
"protection_disabled_until": null,
|
||||
"ratelimit": 0,
|
||||
"blocking_mode": "default",
|
||||
"blocking_ipv4": "",
|
||||
"blocking_ipv6": "",
|
||||
"blocked_response_ttl": 10,
|
||||
"edns_cs_enabled": false,
|
||||
"dnssec_enabled": false,
|
||||
"disable_ipv6": false,
|
||||
"upstream_mode": "",
|
||||
"cache_size": 0,
|
||||
"cache_ttl_min": 0,
|
||||
"cache_ttl_max": 0,
|
||||
"cache_optimistic": false,
|
||||
"resolve_clients": false,
|
||||
"use_private_ptr_resolvers": false,
|
||||
"local_ptr_upstreams": [],
|
||||
"edns_cs_use_custom": false,
|
||||
"edns_cs_custom_ip": ""
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -263,6 +263,30 @@ func assignUniqueFilterID() int64 {
|
||||
return value
|
||||
}
|
||||
|
||||
// Sets up a timer that will be checking for filters updates periodically
|
||||
func (d *DNSFilter) periodicallyRefreshFilters() {
|
||||
const maxInterval = 1 * 60 * 60
|
||||
ivl := 5 // use a dynamically increasing time interval
|
||||
for {
|
||||
isNetErr, ok := false, false
|
||||
if d.conf.FiltersUpdateIntervalHours != 0 {
|
||||
_, isNetErr, ok = d.tryRefreshFilters(true, true, false)
|
||||
if ok && !isNetErr {
|
||||
ivl = maxInterval
|
||||
}
|
||||
}
|
||||
|
||||
if isNetErr {
|
||||
ivl *= 2
|
||||
if ivl > maxInterval {
|
||||
ivl = maxInterval
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(time.Duration(ivl) * time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
// tryRefreshFilters is like [refreshFilters], but backs down if the update is
|
||||
// already going on.
|
||||
//
|
||||
|
||||
@@ -257,9 +257,6 @@ type DNSFilter struct {
|
||||
// conf contains filtering parameters.
|
||||
conf *Config
|
||||
|
||||
// done is the channel to signal to stop running filters updates loop.
|
||||
done chan struct{}
|
||||
|
||||
// Channel for passing data to filters-initializer goroutine
|
||||
filtersInitializerChan chan filtersInitializerParams
|
||||
filtersInitializerLock sync.Mutex
|
||||
@@ -427,15 +424,24 @@ func (d *DNSFilter) setFilters(blockFilters, allowFilters []Filter, async bool)
|
||||
return d.initFiltering(allowFilters, blockFilters)
|
||||
}
|
||||
|
||||
// Starts initializing new filters by signal from channel
|
||||
func (d *DNSFilter) filtersInitializer() {
|
||||
for {
|
||||
params := <-d.filtersInitializerChan
|
||||
err := d.initFiltering(params.allowFilters, params.blockFilters)
|
||||
if err != nil {
|
||||
log.Error("filtering: initializing: %s", err)
|
||||
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close - close the object
|
||||
func (d *DNSFilter) Close() {
|
||||
d.engineLock.Lock()
|
||||
defer d.engineLock.Unlock()
|
||||
|
||||
if d.done != nil {
|
||||
d.done <- struct{}{}
|
||||
}
|
||||
|
||||
d.reset()
|
||||
}
|
||||
|
||||
@@ -1125,64 +1131,19 @@ func New(c *Config, blockFilters []Filter) (d *DNSFilter, err error) {
|
||||
return d, nil
|
||||
}
|
||||
|
||||
// Start registers web handlers and starts filters updates loop.
|
||||
// Start - start the module:
|
||||
// . start async filtering initializer goroutine
|
||||
// . register web handlers
|
||||
func (d *DNSFilter) Start() {
|
||||
d.filtersInitializerChan = make(chan filtersInitializerParams, 1)
|
||||
d.done = make(chan struct{}, 1)
|
||||
go d.filtersInitializer()
|
||||
|
||||
d.RegisterFilteringHandlers()
|
||||
|
||||
go d.updatesLoop()
|
||||
}
|
||||
|
||||
// updatesLoop initializes new filters and checks for filters updates in a loop.
|
||||
func (d *DNSFilter) updatesLoop() {
|
||||
defer log.OnPanic("filtering: updates loop")
|
||||
|
||||
ivl := time.Second * 5
|
||||
t := time.NewTimer(ivl)
|
||||
|
||||
for {
|
||||
select {
|
||||
case params := <-d.filtersInitializerChan:
|
||||
err := d.initFiltering(params.allowFilters, params.blockFilters)
|
||||
if err != nil {
|
||||
log.Error("filtering: initializing: %s", err)
|
||||
|
||||
continue
|
||||
}
|
||||
case <-t.C:
|
||||
ivl = d.periodicallyRefreshFilters(ivl)
|
||||
t.Reset(ivl)
|
||||
case <-d.done:
|
||||
t.Stop()
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// periodicallyRefreshFilters checks for filters updates and returns time
|
||||
// interval for the next update.
|
||||
func (d *DNSFilter) periodicallyRefreshFilters(ivl time.Duration) (nextIvl time.Duration) {
|
||||
const maxInterval = time.Hour
|
||||
|
||||
if d.conf.FiltersUpdateIntervalHours == 0 {
|
||||
return ivl
|
||||
}
|
||||
|
||||
isNetErr, ok := false, false
|
||||
_, isNetErr, ok = d.tryRefreshFilters(true, true, false)
|
||||
|
||||
if ok && !isNetErr {
|
||||
ivl = maxInterval
|
||||
} else if isNetErr {
|
||||
ivl *= 2
|
||||
// TODO(s.chzhen): Use built-in function max in Go 1.21.
|
||||
ivl = mathutil.Max(ivl, maxInterval)
|
||||
}
|
||||
|
||||
return ivl
|
||||
// Here we should start updating filters,
|
||||
// but currently we can't wake up the periodic task to do so.
|
||||
// So for now we just start this periodic task from here.
|
||||
go d.periodicallyRefreshFilters()
|
||||
}
|
||||
|
||||
// Safe browsing and parental control methods.
|
||||
|
||||
@@ -4,17 +4,32 @@ import (
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"path"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/httphdr"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"go.etcd.io/bbolt"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// cookieTTL is the time-to-live of the session cookie.
|
||||
const cookieTTL = 365 * timeutil.Day
|
||||
|
||||
// sessionCookieName is the name of the session cookie.
|
||||
const sessionCookieName = "agh_session"
|
||||
|
||||
// sessionTokenSize is the length of session token in bytes.
|
||||
const sessionTokenSize = 16
|
||||
|
||||
@@ -54,7 +69,7 @@ func (s *session) deserialize(data []byte) bool {
|
||||
// Auth - global object
|
||||
type Auth struct {
|
||||
db *bbolt.DB
|
||||
rateLimiter *authRateLimiter
|
||||
raleLimiter *authRateLimiter
|
||||
sessions map[string]*session
|
||||
users []webUser
|
||||
lock sync.Mutex
|
||||
@@ -62,8 +77,6 @@ type Auth struct {
|
||||
}
|
||||
|
||||
// webUser represents a user of the Web UI.
|
||||
//
|
||||
// TODO(s.chzhen): Improve naming.
|
||||
type webUser struct {
|
||||
Name string `yaml:"name"`
|
||||
PasswordHash string `yaml:"password"`
|
||||
@@ -75,7 +88,7 @@ func InitAuth(dbFilename string, users []webUser, sessionTTL uint32, rateLimiter
|
||||
|
||||
a := &Auth{
|
||||
sessionTTL: sessionTTL,
|
||||
rateLimiter: rateLimiter,
|
||||
raleLimiter: rateLimiter,
|
||||
sessions: make(map[string]*session),
|
||||
users: users,
|
||||
}
|
||||
@@ -203,8 +216,8 @@ func (a *Auth) storeSession(data []byte, s *session) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// removeSessionFromFile removes a stored session from the DB file on disk.
|
||||
func (a *Auth) removeSessionFromFile(sess []byte) {
|
||||
// remove session from file
|
||||
func (a *Auth) removeSession(sess []byte) {
|
||||
tx, err := a.db.Begin(true)
|
||||
if err != nil {
|
||||
log.Error("auth: bbolt.Begin: %s", err)
|
||||
@@ -266,7 +279,7 @@ func (a *Auth) checkSession(sess string) (res checkSessionResult) {
|
||||
if s.expire <= now {
|
||||
delete(a.sessions, sess)
|
||||
key, _ := hex.DecodeString(sess)
|
||||
a.removeSessionFromFile(key)
|
||||
a.removeSession(key)
|
||||
|
||||
return checkSessionExpired
|
||||
}
|
||||
@@ -288,17 +301,351 @@ func (a *Auth) checkSession(sess string) (res checkSessionResult) {
|
||||
return checkSessionOK
|
||||
}
|
||||
|
||||
// removeSession removes the session from the active sessions and the disk.
|
||||
func (a *Auth) removeSession(sess string) {
|
||||
// RemoveSession - remove session
|
||||
func (a *Auth) RemoveSession(sess string) {
|
||||
key, _ := hex.DecodeString(sess)
|
||||
a.lock.Lock()
|
||||
delete(a.sessions, sess)
|
||||
a.lock.Unlock()
|
||||
a.removeSessionFromFile(key)
|
||||
a.removeSession(key)
|
||||
}
|
||||
|
||||
// addUser adds a new user with the given password.
|
||||
func (a *Auth) addUser(u *webUser, password string) (err error) {
|
||||
type loginJSON struct {
|
||||
Name string `json:"name"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
// newSessionToken returns cryptographically secure randomly generated slice of
|
||||
// bytes of sessionTokenSize length.
|
||||
//
|
||||
// TODO(e.burkov): Think about using byte array instead of byte slice.
|
||||
func newSessionToken() (data []byte, err error) {
|
||||
randData := make([]byte, sessionTokenSize)
|
||||
|
||||
_, err = rand.Read(randData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return randData, nil
|
||||
}
|
||||
|
||||
// newCookie creates a new authentication cookie.
|
||||
func (a *Auth) newCookie(req loginJSON, addr string) (c *http.Cookie, err error) {
|
||||
rateLimiter := a.raleLimiter
|
||||
u, ok := a.findUser(req.Name, req.Password)
|
||||
if !ok {
|
||||
if rateLimiter != nil {
|
||||
rateLimiter.inc(addr)
|
||||
}
|
||||
|
||||
return nil, errors.Error("invalid username or password")
|
||||
}
|
||||
|
||||
if rateLimiter != nil {
|
||||
rateLimiter.remove(addr)
|
||||
}
|
||||
|
||||
sess, err := newSessionToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generating token: %w", err)
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
a.addSession(sess, &session{
|
||||
userName: u.Name,
|
||||
expire: uint32(now.Unix()) + a.sessionTTL,
|
||||
})
|
||||
|
||||
return &http.Cookie{
|
||||
Name: sessionCookieName,
|
||||
Value: hex.EncodeToString(sess),
|
||||
Path: "/",
|
||||
Expires: now.Add(cookieTTL),
|
||||
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// realIP extracts the real IP address of the client from an HTTP request using
|
||||
// the known HTTP headers.
|
||||
//
|
||||
// TODO(a.garipov): Currently, this is basically a copy of a similar function in
|
||||
// module dnsproxy. This should really become a part of module golibs and be
|
||||
// replaced both here and there. Or be replaced in both places by
|
||||
// a well-maintained third-party module.
|
||||
//
|
||||
// TODO(a.garipov): Support header Forwarded from RFC 7329.
|
||||
func realIP(r *http.Request) (ip net.IP, err error) {
|
||||
proxyHeaders := []string{
|
||||
httphdr.CFConnectingIP,
|
||||
httphdr.TrueClientIP,
|
||||
httphdr.XRealIP,
|
||||
}
|
||||
|
||||
for _, h := range proxyHeaders {
|
||||
v := r.Header.Get(h)
|
||||
ip = net.ParseIP(v)
|
||||
if ip != nil {
|
||||
return ip, nil
|
||||
}
|
||||
}
|
||||
|
||||
// If none of the above yielded any results, get the leftmost IP address
|
||||
// from the X-Forwarded-For header.
|
||||
s := r.Header.Get(httphdr.XForwardedFor)
|
||||
ipStrs := strings.SplitN(s, ", ", 2)
|
||||
ip = net.ParseIP(ipStrs[0])
|
||||
if ip != nil {
|
||||
return ip, nil
|
||||
}
|
||||
|
||||
// When everything else fails, just return the remote address as understood
|
||||
// by the stdlib.
|
||||
ipStr, err := netutil.SplitHost(r.RemoteAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting ip from client addr: %w", err)
|
||||
}
|
||||
|
||||
return net.ParseIP(ipStr), nil
|
||||
}
|
||||
|
||||
// writeErrorWithIP is like [aghhttp.Error], but includes the remote IP address
|
||||
// when it writes to the log.
|
||||
func writeErrorWithIP(
|
||||
r *http.Request,
|
||||
w http.ResponseWriter,
|
||||
code int,
|
||||
remoteIP string,
|
||||
format string,
|
||||
args ...any,
|
||||
) {
|
||||
text := fmt.Sprintf(format, args...)
|
||||
log.Error("%s %s %s: from ip %s: %s", r.Method, r.Host, r.URL, remoteIP, text)
|
||||
http.Error(w, text, code)
|
||||
}
|
||||
|
||||
func handleLogin(w http.ResponseWriter, r *http.Request) {
|
||||
req := loginJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "json decode: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
var remoteIP string
|
||||
// realIP cannot be used here without taking TrustedProxies into account due
|
||||
// to security issues.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/2799.
|
||||
//
|
||||
// TODO(e.burkov): Use realIP when the issue will be fixed.
|
||||
if remoteIP, err = netutil.SplitHost(r.RemoteAddr); err != nil {
|
||||
writeErrorWithIP(
|
||||
r,
|
||||
w,
|
||||
http.StatusBadRequest,
|
||||
r.RemoteAddr,
|
||||
"auth: getting remote address: %s",
|
||||
err,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if rateLimiter := Context.auth.raleLimiter; rateLimiter != nil {
|
||||
if left := rateLimiter.check(remoteIP); left > 0 {
|
||||
w.Header().Set(httphdr.RetryAfter, strconv.Itoa(int(left.Seconds())))
|
||||
writeErrorWithIP(
|
||||
r,
|
||||
w,
|
||||
http.StatusTooManyRequests,
|
||||
remoteIP,
|
||||
"auth: blocked for %s",
|
||||
left,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
cookie, err := Context.auth.newCookie(req, remoteIP)
|
||||
if err != nil {
|
||||
writeErrorWithIP(r, w, http.StatusForbidden, remoteIP, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Use realIP here, since this IP address is only used for logging.
|
||||
ip, err := realIP(r)
|
||||
if err != nil {
|
||||
log.Error("auth: getting real ip from request with remote ip %s: %s", remoteIP, err)
|
||||
}
|
||||
|
||||
log.Info("auth: user %q successfully logged in from ip %v", req.Name, ip)
|
||||
|
||||
http.SetCookie(w, cookie)
|
||||
|
||||
h := w.Header()
|
||||
h.Set(httphdr.CacheControl, "no-store, no-cache, must-revalidate, proxy-revalidate")
|
||||
h.Set(httphdr.Pragma, "no-cache")
|
||||
h.Set(httphdr.Expires, "0")
|
||||
|
||||
aghhttp.OK(w)
|
||||
}
|
||||
|
||||
func handleLogout(w http.ResponseWriter, r *http.Request) {
|
||||
respHdr := w.Header()
|
||||
c, err := r.Cookie(sessionCookieName)
|
||||
if err != nil {
|
||||
// The only error that is returned from r.Cookie is [http.ErrNoCookie].
|
||||
// The user is already logged out.
|
||||
respHdr.Set(httphdr.Location, "/login.html")
|
||||
w.WriteHeader(http.StatusFound)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
Context.auth.RemoveSession(c.Value)
|
||||
|
||||
c = &http.Cookie{
|
||||
Name: sessionCookieName,
|
||||
Value: "",
|
||||
Path: "/",
|
||||
Expires: time.Unix(0, 0),
|
||||
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
}
|
||||
|
||||
respHdr.Set(httphdr.Location, "/login.html")
|
||||
respHdr.Set(httphdr.SetCookie, c.String())
|
||||
w.WriteHeader(http.StatusFound)
|
||||
}
|
||||
|
||||
// RegisterAuthHandlers - register handlers
|
||||
func RegisterAuthHandlers() {
|
||||
Context.mux.Handle("/control/login", postInstallHandler(ensureHandler(http.MethodPost, handleLogin)))
|
||||
httpRegister(http.MethodGet, "/control/logout", handleLogout)
|
||||
}
|
||||
|
||||
// optionalAuthThird return true if user should authenticate first.
|
||||
func optionalAuthThird(w http.ResponseWriter, r *http.Request) (mustAuth bool) {
|
||||
if glProcessCookie(r) {
|
||||
log.Debug("auth: authentication is handled by GL-Inet submodule")
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// redirect to login page if not authenticated
|
||||
isAuthenticated := false
|
||||
cookie, err := r.Cookie(sessionCookieName)
|
||||
if err != nil {
|
||||
// The only error that is returned from r.Cookie is [http.ErrNoCookie].
|
||||
// Check Basic authentication.
|
||||
user, pass, hasBasic := r.BasicAuth()
|
||||
if hasBasic {
|
||||
_, isAuthenticated = Context.auth.findUser(user, pass)
|
||||
if !isAuthenticated {
|
||||
log.Info("auth: invalid Basic Authorization value")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
res := Context.auth.checkSession(cookie.Value)
|
||||
isAuthenticated = res == checkSessionOK
|
||||
if !isAuthenticated {
|
||||
log.Debug("auth: invalid cookie value: %s", cookie)
|
||||
}
|
||||
}
|
||||
|
||||
if isAuthenticated {
|
||||
return false
|
||||
}
|
||||
|
||||
if p := r.URL.Path; p == "/" || p == "/index.html" {
|
||||
if glProcessRedirect(w, r) {
|
||||
log.Debug("auth: redirected to login page by GL-Inet submodule")
|
||||
} else {
|
||||
log.Debug("auth: redirected to login page")
|
||||
http.Redirect(w, r, "login.html", http.StatusFound)
|
||||
}
|
||||
} else {
|
||||
log.Debug("auth: responded with forbidden to %s %s", r.Method, p)
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
_, _ = w.Write([]byte("Forbidden"))
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// TODO(a.garipov): Use [http.Handler] consistently everywhere throughout the
|
||||
// project.
|
||||
func optionalAuth(
|
||||
h func(http.ResponseWriter, *http.Request),
|
||||
) (wrapped func(http.ResponseWriter, *http.Request)) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
p := r.URL.Path
|
||||
authRequired := Context.auth != nil && Context.auth.AuthRequired()
|
||||
if p == "/login.html" {
|
||||
cookie, err := r.Cookie(sessionCookieName)
|
||||
if authRequired && err == nil {
|
||||
// Redirect to the dashboard if already authenticated.
|
||||
res := Context.auth.checkSession(cookie.Value)
|
||||
if res == checkSessionOK {
|
||||
http.Redirect(w, r, "", http.StatusFound)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug("auth: invalid cookie value: %s", cookie)
|
||||
}
|
||||
} else if isPublicResource(p) {
|
||||
// Process as usual, no additional auth requirements.
|
||||
} else if authRequired {
|
||||
if optionalAuthThird(w, r) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
h(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// isPublicResource returns true if p is a path to a public resource.
|
||||
func isPublicResource(p string) (ok bool) {
|
||||
isAsset, err := path.Match("/assets/*", p)
|
||||
if err != nil {
|
||||
// The only error that is returned from path.Match is
|
||||
// [path.ErrBadPattern]. This is a programmer error.
|
||||
panic(fmt.Errorf("bad asset pattern: %w", err))
|
||||
}
|
||||
|
||||
isLogin, err := path.Match("/login.*", p)
|
||||
if err != nil {
|
||||
// Same as above.
|
||||
panic(fmt.Errorf("bad login pattern: %w", err))
|
||||
}
|
||||
|
||||
return isAsset || isLogin
|
||||
}
|
||||
|
||||
type authHandler struct {
|
||||
handler http.Handler
|
||||
}
|
||||
|
||||
func (a *authHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
optionalAuth(a.handler.ServeHTTP)(w, r)
|
||||
}
|
||||
|
||||
func optionalAuthHandler(handler http.Handler) http.Handler {
|
||||
return &authHandler{handler}
|
||||
}
|
||||
|
||||
// Add adds a new user with the given password.
|
||||
func (a *Auth) Add(u *webUser, password string) (err error) {
|
||||
if len(password) == 0 {
|
||||
return errors.Error("empty password")
|
||||
}
|
||||
@@ -368,40 +715,22 @@ func (a *Auth) getCurrentUser(r *http.Request) (u webUser) {
|
||||
return webUser{}
|
||||
}
|
||||
|
||||
// usersList returns a copy of a users list.
|
||||
func (a *Auth) usersList() (users []webUser) {
|
||||
// GetUsers - get users
|
||||
func (a *Auth) GetUsers() []webUser {
|
||||
a.lock.Lock()
|
||||
defer a.lock.Unlock()
|
||||
|
||||
users = make([]webUser, len(a.users))
|
||||
copy(users, a.users)
|
||||
|
||||
users := a.users
|
||||
a.lock.Unlock()
|
||||
return users
|
||||
}
|
||||
|
||||
// authRequired returns true if a authentication is required.
|
||||
func (a *Auth) authRequired() bool {
|
||||
// AuthRequired - if authentication is required
|
||||
func (a *Auth) AuthRequired() bool {
|
||||
if GLMode {
|
||||
return true
|
||||
}
|
||||
|
||||
a.lock.Lock()
|
||||
defer a.lock.Unlock()
|
||||
|
||||
return len(a.users) != 0
|
||||
}
|
||||
|
||||
// newSessionToken returns cryptographically secure randomly generated slice of
|
||||
// bytes of sessionTokenSize length.
|
||||
//
|
||||
// TODO(e.burkov): Think about using byte array instead of byte slice.
|
||||
func newSessionToken() (data []byte, err error) {
|
||||
randData := make([]byte, sessionTokenSize)
|
||||
|
||||
_, err = rand.Read(randData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return randData, nil
|
||||
r := (len(a.users) != 0)
|
||||
a.lock.Unlock()
|
||||
return r
|
||||
}
|
||||
|
||||
@@ -1,89 +0,0 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewSessionToken(t *testing.T) {
|
||||
// Successful case.
|
||||
token, err := newSessionToken()
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, token, sessionTokenSize)
|
||||
|
||||
// Break the rand.Reader.
|
||||
prevReader := rand.Reader
|
||||
t.Cleanup(func() { rand.Reader = prevReader })
|
||||
rand.Reader = &bytes.Buffer{}
|
||||
|
||||
// Unsuccessful case.
|
||||
token, err = newSessionToken()
|
||||
require.Error(t, err)
|
||||
assert.Empty(t, token)
|
||||
}
|
||||
|
||||
func TestAuth(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
fn := filepath.Join(dir, "sessions.db")
|
||||
|
||||
users := []webUser{{
|
||||
Name: "name",
|
||||
PasswordHash: "$2y$05$..vyzAECIhJPfaQiOK17IukcQnqEgKJHy0iETyYqxn3YXJl8yZuo2",
|
||||
}}
|
||||
a := InitAuth(fn, nil, 60, nil)
|
||||
s := session{}
|
||||
|
||||
user := webUser{Name: "name"}
|
||||
err := a.addUser(&user, "password")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, checkSessionNotFound, a.checkSession("notfound"))
|
||||
a.removeSession("notfound")
|
||||
|
||||
sess, err := newSessionToken()
|
||||
require.NoError(t, err)
|
||||
sessStr := hex.EncodeToString(sess)
|
||||
|
||||
now := time.Now().UTC().Unix()
|
||||
// check expiration
|
||||
s.expire = uint32(now)
|
||||
a.addSession(sess, &s)
|
||||
assert.Equal(t, checkSessionExpired, a.checkSession(sessStr))
|
||||
|
||||
// add session with TTL = 2 sec
|
||||
s = session{}
|
||||
s.expire = uint32(time.Now().UTC().Unix() + 2)
|
||||
a.addSession(sess, &s)
|
||||
assert.Equal(t, checkSessionOK, a.checkSession(sessStr))
|
||||
|
||||
a.Close()
|
||||
|
||||
// load saved session
|
||||
a = InitAuth(fn, users, 60, nil)
|
||||
|
||||
// the session is still alive
|
||||
assert.Equal(t, checkSessionOK, a.checkSession(sessStr))
|
||||
// reset our expiration time because checkSession() has just updated it
|
||||
s.expire = uint32(time.Now().UTC().Unix() + 2)
|
||||
a.storeSession(sess, &s)
|
||||
a.Close()
|
||||
|
||||
u, ok := a.findUser("name", "password")
|
||||
assert.True(t, ok)
|
||||
assert.NotEmpty(t, u.Name)
|
||||
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
// load and remove expired sessions
|
||||
a = InitAuth(fn, users, 60, nil)
|
||||
assert.Equal(t, checkSessionNotFound, a.checkSession(sessStr))
|
||||
|
||||
a.Close()
|
||||
}
|
||||
@@ -1,12 +1,16 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/httphdr"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
@@ -14,6 +18,82 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewSessionToken(t *testing.T) {
|
||||
// Successful case.
|
||||
token, err := newSessionToken()
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, token, sessionTokenSize)
|
||||
|
||||
// Break the rand.Reader.
|
||||
prevReader := rand.Reader
|
||||
t.Cleanup(func() { rand.Reader = prevReader })
|
||||
rand.Reader = &bytes.Buffer{}
|
||||
|
||||
// Unsuccessful case.
|
||||
token, err = newSessionToken()
|
||||
require.Error(t, err)
|
||||
assert.Empty(t, token)
|
||||
}
|
||||
|
||||
func TestAuth(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
fn := filepath.Join(dir, "sessions.db")
|
||||
|
||||
users := []webUser{{
|
||||
Name: "name",
|
||||
PasswordHash: "$2y$05$..vyzAECIhJPfaQiOK17IukcQnqEgKJHy0iETyYqxn3YXJl8yZuo2",
|
||||
}}
|
||||
a := InitAuth(fn, nil, 60, nil)
|
||||
s := session{}
|
||||
|
||||
user := webUser{Name: "name"}
|
||||
err := a.Add(&user, "password")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, checkSessionNotFound, a.checkSession("notfound"))
|
||||
a.RemoveSession("notfound")
|
||||
|
||||
sess, err := newSessionToken()
|
||||
require.NoError(t, err)
|
||||
sessStr := hex.EncodeToString(sess)
|
||||
|
||||
now := time.Now().UTC().Unix()
|
||||
// check expiration
|
||||
s.expire = uint32(now)
|
||||
a.addSession(sess, &s)
|
||||
assert.Equal(t, checkSessionExpired, a.checkSession(sessStr))
|
||||
|
||||
// add session with TTL = 2 sec
|
||||
s = session{}
|
||||
s.expire = uint32(time.Now().UTC().Unix() + 2)
|
||||
a.addSession(sess, &s)
|
||||
assert.Equal(t, checkSessionOK, a.checkSession(sessStr))
|
||||
|
||||
a.Close()
|
||||
|
||||
// load saved session
|
||||
a = InitAuth(fn, users, 60, nil)
|
||||
|
||||
// the session is still alive
|
||||
assert.Equal(t, checkSessionOK, a.checkSession(sessStr))
|
||||
// reset our expiration time because checkSession() has just updated it
|
||||
s.expire = uint32(time.Now().UTC().Unix() + 2)
|
||||
a.storeSession(sess, &s)
|
||||
a.Close()
|
||||
|
||||
u, ok := a.findUser("name", "password")
|
||||
assert.True(t, ok)
|
||||
assert.NotEmpty(t, u.Name)
|
||||
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
// load and remove expired sessions
|
||||
a = InitAuth(fn, users, 60, nil)
|
||||
assert.Equal(t, checkSessionNotFound, a.checkSession(sessStr))
|
||||
|
||||
a.Close()
|
||||
}
|
||||
|
||||
// implements http.ResponseWriter
|
||||
type testResponseWriter struct {
|
||||
hdr http.Header
|
||||
@@ -1,352 +0,0 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"path"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/httphdr"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
)
|
||||
|
||||
// cookieTTL is the time-to-live of the session cookie.
|
||||
const cookieTTL = 365 * timeutil.Day
|
||||
|
||||
// sessionCookieName is the name of the session cookie.
|
||||
const sessionCookieName = "agh_session"
|
||||
|
||||
// loginJSON is the JSON structure for authentication.
|
||||
type loginJSON struct {
|
||||
Name string `json:"name"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
// newCookie creates a new authentication cookie.
|
||||
func (a *Auth) newCookie(req loginJSON, addr string) (c *http.Cookie, err error) {
|
||||
rateLimiter := a.rateLimiter
|
||||
u, ok := a.findUser(req.Name, req.Password)
|
||||
if !ok {
|
||||
if rateLimiter != nil {
|
||||
rateLimiter.inc(addr)
|
||||
}
|
||||
|
||||
return nil, errors.Error("invalid username or password")
|
||||
}
|
||||
|
||||
if rateLimiter != nil {
|
||||
rateLimiter.remove(addr)
|
||||
}
|
||||
|
||||
sess, err := newSessionToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generating token: %w", err)
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
a.addSession(sess, &session{
|
||||
userName: u.Name,
|
||||
expire: uint32(now.Unix()) + a.sessionTTL,
|
||||
})
|
||||
|
||||
return &http.Cookie{
|
||||
Name: sessionCookieName,
|
||||
Value: hex.EncodeToString(sess),
|
||||
Path: "/",
|
||||
Expires: now.Add(cookieTTL),
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// realIP extracts the real IP address of the client from an HTTP request using
|
||||
// the known HTTP headers.
|
||||
//
|
||||
// TODO(a.garipov): Currently, this is basically a copy of a similar function in
|
||||
// module dnsproxy. This should really become a part of module golibs and be
|
||||
// replaced both here and there. Or be replaced in both places by
|
||||
// a well-maintained third-party module.
|
||||
//
|
||||
// TODO(a.garipov): Support header Forwarded from RFC 7329.
|
||||
func realIP(r *http.Request) (ip net.IP, err error) {
|
||||
proxyHeaders := []string{
|
||||
httphdr.CFConnectingIP,
|
||||
httphdr.TrueClientIP,
|
||||
httphdr.XRealIP,
|
||||
}
|
||||
|
||||
for _, h := range proxyHeaders {
|
||||
v := r.Header.Get(h)
|
||||
ip = net.ParseIP(v)
|
||||
if ip != nil {
|
||||
return ip, nil
|
||||
}
|
||||
}
|
||||
|
||||
// If none of the above yielded any results, get the leftmost IP address
|
||||
// from the X-Forwarded-For header.
|
||||
s := r.Header.Get(httphdr.XForwardedFor)
|
||||
ipStrs := strings.SplitN(s, ", ", 2)
|
||||
ip = net.ParseIP(ipStrs[0])
|
||||
if ip != nil {
|
||||
return ip, nil
|
||||
}
|
||||
|
||||
// When everything else fails, just return the remote address as understood
|
||||
// by the stdlib.
|
||||
ipStr, err := netutil.SplitHost(r.RemoteAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting ip from client addr: %w", err)
|
||||
}
|
||||
|
||||
return net.ParseIP(ipStr), nil
|
||||
}
|
||||
|
||||
// writeErrorWithIP is like [aghhttp.Error], but includes the remote IP address
|
||||
// when it writes to the log.
|
||||
func writeErrorWithIP(
|
||||
r *http.Request,
|
||||
w http.ResponseWriter,
|
||||
code int,
|
||||
remoteIP string,
|
||||
format string,
|
||||
args ...any,
|
||||
) {
|
||||
text := fmt.Sprintf(format, args...)
|
||||
log.Error("%s %s %s: from ip %s: %s", r.Method, r.Host, r.URL, remoteIP, text)
|
||||
http.Error(w, text, code)
|
||||
}
|
||||
|
||||
// handleLogin is the handler for the POST /control/login HTTP API.
|
||||
func handleLogin(w http.ResponseWriter, r *http.Request) {
|
||||
req := loginJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "json decode: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
var remoteIP string
|
||||
// realIP cannot be used here without taking TrustedProxies into account due
|
||||
// to security issues.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/2799.
|
||||
//
|
||||
// TODO(e.burkov): Use realIP when the issue will be fixed.
|
||||
if remoteIP, err = netutil.SplitHost(r.RemoteAddr); err != nil {
|
||||
writeErrorWithIP(
|
||||
r,
|
||||
w,
|
||||
http.StatusBadRequest,
|
||||
r.RemoteAddr,
|
||||
"auth: getting remote address: %s",
|
||||
err,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if rateLimiter := Context.auth.rateLimiter; rateLimiter != nil {
|
||||
if left := rateLimiter.check(remoteIP); left > 0 {
|
||||
w.Header().Set(httphdr.RetryAfter, strconv.Itoa(int(left.Seconds())))
|
||||
writeErrorWithIP(
|
||||
r,
|
||||
w,
|
||||
http.StatusTooManyRequests,
|
||||
remoteIP,
|
||||
"auth: blocked for %s",
|
||||
left,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
cookie, err := Context.auth.newCookie(req, remoteIP)
|
||||
if err != nil {
|
||||
writeErrorWithIP(r, w, http.StatusForbidden, remoteIP, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Use realIP here, since this IP address is only used for logging.
|
||||
ip, err := realIP(r)
|
||||
if err != nil {
|
||||
log.Error("auth: getting real ip from request with remote ip %s: %s", remoteIP, err)
|
||||
}
|
||||
|
||||
log.Info("auth: user %q successfully logged in from ip %v", req.Name, ip)
|
||||
|
||||
http.SetCookie(w, cookie)
|
||||
|
||||
h := w.Header()
|
||||
h.Set(httphdr.CacheControl, "no-store, no-cache, must-revalidate, proxy-revalidate")
|
||||
h.Set(httphdr.Pragma, "no-cache")
|
||||
h.Set(httphdr.Expires, "0")
|
||||
|
||||
aghhttp.OK(w)
|
||||
}
|
||||
|
||||
// handleLogout is the handler for the GET /control/logout HTTP API.
|
||||
func handleLogout(w http.ResponseWriter, r *http.Request) {
|
||||
respHdr := w.Header()
|
||||
c, err := r.Cookie(sessionCookieName)
|
||||
if err != nil {
|
||||
// The only error that is returned from r.Cookie is [http.ErrNoCookie].
|
||||
// The user is already logged out.
|
||||
respHdr.Set(httphdr.Location, "/login.html")
|
||||
w.WriteHeader(http.StatusFound)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
Context.auth.removeSession(c.Value)
|
||||
|
||||
c = &http.Cookie{
|
||||
Name: sessionCookieName,
|
||||
Value: "",
|
||||
Path: "/",
|
||||
Expires: time.Unix(0, 0),
|
||||
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
}
|
||||
|
||||
respHdr.Set(httphdr.Location, "/login.html")
|
||||
respHdr.Set(httphdr.SetCookie, c.String())
|
||||
w.WriteHeader(http.StatusFound)
|
||||
}
|
||||
|
||||
// RegisterAuthHandlers - register handlers
|
||||
func RegisterAuthHandlers() {
|
||||
Context.mux.Handle("/control/login", postInstallHandler(ensureHandler(http.MethodPost, handleLogin)))
|
||||
httpRegister(http.MethodGet, "/control/logout", handleLogout)
|
||||
}
|
||||
|
||||
// optionalAuthThird returns true if a user should authenticate first.
|
||||
func optionalAuthThird(w http.ResponseWriter, r *http.Request) (mustAuth bool) {
|
||||
pref := fmt.Sprintf("auth: raddr %s", r.RemoteAddr)
|
||||
|
||||
if glProcessCookie(r) {
|
||||
log.Debug("%s: authentication is handled by gl-inet submodule", pref)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// redirect to login page if not authenticated
|
||||
isAuthenticated := false
|
||||
cookie, err := r.Cookie(sessionCookieName)
|
||||
if err != nil {
|
||||
// The only error that is returned from r.Cookie is [http.ErrNoCookie].
|
||||
// Check Basic authentication.
|
||||
user, pass, hasBasic := r.BasicAuth()
|
||||
if hasBasic {
|
||||
_, isAuthenticated = Context.auth.findUser(user, pass)
|
||||
if !isAuthenticated {
|
||||
log.Info("%s: invalid basic authorization value", pref)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
res := Context.auth.checkSession(cookie.Value)
|
||||
isAuthenticated = res == checkSessionOK
|
||||
if !isAuthenticated {
|
||||
log.Debug("%s: invalid cookie value: %q", pref, cookie)
|
||||
}
|
||||
}
|
||||
|
||||
if isAuthenticated {
|
||||
return false
|
||||
}
|
||||
|
||||
if p := r.URL.Path; p == "/" || p == "/index.html" {
|
||||
if glProcessRedirect(w, r) {
|
||||
log.Debug("%s: redirected to login page by gl-inet submodule", pref)
|
||||
} else {
|
||||
log.Debug("%s: redirected to login page", pref)
|
||||
http.Redirect(w, r, "login.html", http.StatusFound)
|
||||
}
|
||||
} else {
|
||||
log.Debug("%s: responded with forbidden to %s %s", pref, r.Method, p)
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
_, _ = w.Write([]byte("Forbidden"))
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// TODO(a.garipov): Use [http.Handler] consistently everywhere throughout the
|
||||
// project.
|
||||
func optionalAuth(
|
||||
h func(http.ResponseWriter, *http.Request),
|
||||
) (wrapped func(http.ResponseWriter, *http.Request)) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
p := r.URL.Path
|
||||
authRequired := Context.auth != nil && Context.auth.authRequired()
|
||||
if p == "/login.html" {
|
||||
cookie, err := r.Cookie(sessionCookieName)
|
||||
if authRequired && err == nil {
|
||||
// Redirect to the dashboard if already authenticated.
|
||||
res := Context.auth.checkSession(cookie.Value)
|
||||
if res == checkSessionOK {
|
||||
http.Redirect(w, r, "", http.StatusFound)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug("auth: raddr %s: invalid cookie value: %q", r.RemoteAddr, cookie)
|
||||
}
|
||||
} else if isPublicResource(p) {
|
||||
// Process as usual, no additional auth requirements.
|
||||
} else if authRequired {
|
||||
if optionalAuthThird(w, r) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
h(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// isPublicResource returns true if p is a path to a public resource.
|
||||
func isPublicResource(p string) (ok bool) {
|
||||
isAsset, err := path.Match("/assets/*", p)
|
||||
if err != nil {
|
||||
// The only error that is returned from path.Match is
|
||||
// [path.ErrBadPattern]. This is a programmer error.
|
||||
panic(fmt.Errorf("bad asset pattern: %w", err))
|
||||
}
|
||||
|
||||
isLogin, err := path.Match("/login.*", p)
|
||||
if err != nil {
|
||||
// Same as above.
|
||||
panic(fmt.Errorf("bad login pattern: %w", err))
|
||||
}
|
||||
|
||||
return isAsset || isLogin
|
||||
}
|
||||
|
||||
// authHandler is a helper structure that implements [http.Handler].
|
||||
type authHandler struct {
|
||||
handler http.Handler
|
||||
}
|
||||
|
||||
// ServeHTTP implements the [http.Handler] interface for *authHandler.
|
||||
func (a *authHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
optionalAuth(a.handler.ServeHTTP)(w, r)
|
||||
}
|
||||
|
||||
// optionalAuthHandler returns a authentication handler.
|
||||
func optionalAuthHandler(handler http.Handler) http.Handler {
|
||||
return &authHandler{handler}
|
||||
}
|
||||
@@ -587,7 +587,7 @@ func (c *configuration) write() (err error) {
|
||||
defer c.Unlock()
|
||||
|
||||
if Context.auth != nil {
|
||||
config.Users = Context.auth.usersList()
|
||||
config.Users = Context.auth.GetUsers()
|
||||
}
|
||||
|
||||
if Context.tls != nil {
|
||||
|
||||
@@ -420,7 +420,7 @@ func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request
|
||||
u := &webUser{
|
||||
Name: req.Username,
|
||||
}
|
||||
err = Context.auth.addUser(u, req.Password)
|
||||
err = Context.auth.Add(u, req.Password)
|
||||
if err != nil {
|
||||
Context.firstRun = true
|
||||
copyInstallSettings(config, curConfig)
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
package ipset
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
@@ -39,69 +38,19 @@ func newManager(ipsetConf []string) (set Manager, err error) {
|
||||
|
||||
// defaultDial is the default netfilter dialing function.
|
||||
func defaultDial(pf netfilter.ProtoFamily, conf *netlink.Config) (conn ipsetConn, err error) {
|
||||
c, err := ipset.Dial(pf, conf)
|
||||
conn, err = ipset.Dial(pf, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &queryConn{c}, nil
|
||||
}
|
||||
|
||||
// queryConn is the [ipsetConn] implementation with listAll method, which
|
||||
// returns the list of properties of all available ipsets.
|
||||
type queryConn struct {
|
||||
*ipset.Conn
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ ipsetConn = (*queryConn)(nil)
|
||||
|
||||
// listAll returns the list of properties of all available ipsets.
|
||||
//
|
||||
// TODO(s.chzhen): Use https://github.com/vishvananda/netlink.
|
||||
func (qc *queryConn) listAll() (sets []props, err error) {
|
||||
msg, err := netfilter.MarshalNetlink(
|
||||
netfilter.Header{
|
||||
// The family doesn't seem to matter. See TODO on parseIpsetConfig.
|
||||
Family: qc.Conn.Family,
|
||||
SubsystemID: netfilter.NFSubsysIPSet,
|
||||
MessageType: netfilter.MessageType(ipset.CmdList),
|
||||
Flags: netlink.Request | netlink.Dump,
|
||||
},
|
||||
[]netfilter.Attribute{{
|
||||
Type: uint16(ipset.AttrProtocol),
|
||||
Data: []byte{ipset.Protocol},
|
||||
}},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshaling netlink msg: %w", err)
|
||||
}
|
||||
|
||||
// We assume it's OK to call a method of an unexported type
|
||||
// [ipset.connector], since there is no negative effects.
|
||||
ms, err := qc.Conn.Conn.Query(msg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("querying netlink msg: %w", err)
|
||||
}
|
||||
|
||||
for i, s := range ms {
|
||||
p := props{}
|
||||
err = p.unmarshalMessage(s)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unmarshaling netlink msg at index %d: %w", i, err)
|
||||
}
|
||||
|
||||
sets = append(sets, p)
|
||||
}
|
||||
|
||||
return sets, nil
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// ipsetConn is the ipset conn interface.
|
||||
type ipsetConn interface {
|
||||
Add(name string, entries ...*ipset.Entry) (err error)
|
||||
Close() (err error)
|
||||
listAll() (sets []props, err error)
|
||||
Header(name string) (p *ipset.HeaderPolicy, err error)
|
||||
}
|
||||
|
||||
// dialer creates an ipsetConn.
|
||||
@@ -109,75 +58,8 @@ type dialer func(pf netfilter.ProtoFamily, conf *netlink.Config) (conn ipsetConn
|
||||
|
||||
// props contains one Linux Netfilter ipset properties.
|
||||
type props struct {
|
||||
// name of the ipset.
|
||||
name string
|
||||
|
||||
// family of the IP addresses in the ipset.
|
||||
name string
|
||||
family netfilter.ProtoFamily
|
||||
|
||||
// isPersistent indicates that ipset has no timeout parameter and all
|
||||
// entries are added permanently.
|
||||
isPersistent bool
|
||||
}
|
||||
|
||||
// unmarshalMessage unmarshals netlink message and sets the properties of the
|
||||
// ipset.
|
||||
func (p *props) unmarshalMessage(msg netlink.Message) (err error) {
|
||||
_, attrs, err := netfilter.UnmarshalNetlink(msg)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
// By default ipset has no timeout parameter.
|
||||
p.isPersistent = true
|
||||
|
||||
for _, a := range attrs {
|
||||
p.parseAttribute(a)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseAttribute parses netfilter attribute and sets the name and family of
|
||||
// the ipset.
|
||||
func (p *props) parseAttribute(a netfilter.Attribute) {
|
||||
switch ipset.AttributeType(a.Type) {
|
||||
case ipset.AttrData:
|
||||
p.parseAttrData(a)
|
||||
case ipset.AttrSetName:
|
||||
// Trim the null character.
|
||||
p.name = string(bytes.Trim(a.Data, "\x00"))
|
||||
case ipset.AttrFamily:
|
||||
p.family = netfilter.ProtoFamily(a.Data[0])
|
||||
default:
|
||||
// Go on.
|
||||
}
|
||||
}
|
||||
|
||||
// parseAttrData parses attribute data and sets the timeout of the ipset.
|
||||
func (p *props) parseAttrData(a netfilter.Attribute) {
|
||||
for _, a := range a.Children {
|
||||
switch ipset.AttributeType(a.Type) {
|
||||
case ipset.AttrTimeout:
|
||||
timeout := a.Uint32()
|
||||
p.isPersistent = timeout == 0
|
||||
default:
|
||||
// Go on.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// unit is a convenient alias for struct{}.
|
||||
type unit = struct{}
|
||||
|
||||
// ipsInIpset is the type of a set of IP-address-to-ipset mappings.
|
||||
type ipsInIpset map[ipInIpsetEntry]unit
|
||||
|
||||
// ipInIpsetEntry is the type for entries in an ipsInIpset set.
|
||||
type ipInIpsetEntry struct {
|
||||
ipsetName string
|
||||
ipArr [net.IPv6len]byte
|
||||
}
|
||||
|
||||
// manager is the Linux Netfilter ipset manager.
|
||||
@@ -190,13 +72,6 @@ type manager struct {
|
||||
// mu protects all properties below.
|
||||
mu *sync.Mutex
|
||||
|
||||
// TODO(a.garipov): Currently, the ipset list is static, and we don't
|
||||
// read the IPs already in sets, so we can assume that all incoming IPs
|
||||
// are either added to all corresponding ipsets or not. When that stops
|
||||
// being the case, for example if we add dynamic reconfiguration of
|
||||
// ipsets, this map will need to become a per-ipset-name one.
|
||||
addedIPs ipsInIpset
|
||||
|
||||
ipv4Conn ipsetConn
|
||||
ipv6Conn ipsetConn
|
||||
}
|
||||
@@ -221,8 +96,8 @@ func (m *manager) dialNetfilter(conf *netlink.Config) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseIpsetConfigLine parses one ipset configuration line.
|
||||
func parseIpsetConfigLine(confStr string) (hosts, ipsetNames []string, err error) {
|
||||
// parseIpsetConfig parses one ipset configuration string.
|
||||
func parseIpsetConfig(confStr string) (hosts, ipsetNames []string, err error) {
|
||||
confStr = strings.TrimSpace(confStr)
|
||||
hostsAndNames := strings.Split(confStr, "/")
|
||||
if len(hostsAndNames) != 2 {
|
||||
@@ -250,53 +125,50 @@ func parseIpsetConfigLine(confStr string) (hosts, ipsetNames []string, err error
|
||||
return hosts, ipsetNames, nil
|
||||
}
|
||||
|
||||
// parseIpsetConfig parses the ipset configuration and stores ipsets. It
|
||||
// returns an error if the configuration can't be used.
|
||||
func (m *manager) parseIpsetConfig(ipsetConf []string) (err error) {
|
||||
// The family doesn't seem to matter when we use a header query, so query
|
||||
// only the IPv4 one.
|
||||
// ipsetProps returns the properties of an ipset with the given name.
|
||||
func (m *manager) ipsetProps(name string) (set props, err error) {
|
||||
// The family doesn't seem to matter when we use a header query, so
|
||||
// query only the IPv4 one.
|
||||
//
|
||||
// TODO(a.garipov): Find out if this is a bug or a feature.
|
||||
all, err := m.ipv4Conn.listAll()
|
||||
var res *ipset.HeaderPolicy
|
||||
res, err = m.ipv4Conn.Header(name)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return err
|
||||
return set, err
|
||||
}
|
||||
|
||||
for _, p := range all {
|
||||
m.nameToIpset[p.name] = p
|
||||
if res == nil || res.Family == nil {
|
||||
return set, errors.Error("empty response or no family data")
|
||||
}
|
||||
|
||||
for i, confStr := range ipsetConf {
|
||||
var hosts, ipsetNames []string
|
||||
hosts, ipsetNames, err = parseIpsetConfigLine(confStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("config line at idx %d: %w", i, err)
|
||||
}
|
||||
|
||||
var ipsets []props
|
||||
ipsets, err = m.ipsets(ipsetNames)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting ipsets from config line at idx %d: %w", i, err)
|
||||
}
|
||||
|
||||
for _, host := range hosts {
|
||||
m.domainToIpsets[host] = append(m.domainToIpsets[host], ipsets...)
|
||||
}
|
||||
family := netfilter.ProtoFamily(res.Family.Value)
|
||||
if family != netfilter.ProtoIPv4 && family != netfilter.ProtoIPv6 {
|
||||
return set, fmt.Errorf("unexpected ipset family %d", family)
|
||||
}
|
||||
|
||||
return nil
|
||||
return props{
|
||||
name: name,
|
||||
family: family,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ipsets returns currently known ipsets.
|
||||
func (m *manager) ipsets(names []string) (sets []props, err error) {
|
||||
for _, n := range names {
|
||||
p, ok := m.nameToIpset[n]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown ipset %q", n)
|
||||
for _, name := range names {
|
||||
set, ok := m.nameToIpset[name]
|
||||
if ok {
|
||||
sets = append(sets, set)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
sets = append(sets, p)
|
||||
set, err = m.ipsetProps(name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("querying ipset %q: %w", name, err)
|
||||
}
|
||||
|
||||
m.nameToIpset[name] = set
|
||||
sets = append(sets, set)
|
||||
}
|
||||
|
||||
return sets, nil
|
||||
@@ -314,8 +186,6 @@ func newManagerWithDialer(ipsetConf []string, dial dialer) (mgr Manager, err err
|
||||
domainToIpsets: make(map[string][]props),
|
||||
|
||||
dial: dial,
|
||||
|
||||
addedIPs: make(ipsInIpset),
|
||||
}
|
||||
|
||||
err = m.dialNetfilter(&netlink.Config{})
|
||||
@@ -331,9 +201,26 @@ func newManagerWithDialer(ipsetConf []string, dial dialer) (mgr Manager, err err
|
||||
return nil, fmt.Errorf("dialing netfilter: %w", err)
|
||||
}
|
||||
|
||||
err = m.parseIpsetConfig(ipsetConf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting ipsets: %w", err)
|
||||
for i, confStr := range ipsetConf {
|
||||
var hosts, ipsetNames []string
|
||||
hosts, ipsetNames, err = parseIpsetConfig(confStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("config line at idx %d: %w", i, err)
|
||||
}
|
||||
|
||||
var ipsets []props
|
||||
ipsets, err = m.ipsets(ipsetNames)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"getting ipsets from config line at idx %d: %w",
|
||||
i,
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
for _, host := range hosts {
|
||||
m.domainToIpsets[host] = append(m.domainToIpsets[host], ipsets...)
|
||||
}
|
||||
}
|
||||
|
||||
return m, nil
|
||||
@@ -372,19 +259,8 @@ func (m *manager) addIPs(host string, set props, ips []net.IP) (n int, err error
|
||||
}
|
||||
|
||||
var entries []*ipset.Entry
|
||||
var newAddedEntries []ipInIpsetEntry
|
||||
for _, ip := range ips {
|
||||
e := ipInIpsetEntry{
|
||||
ipsetName: set.name,
|
||||
}
|
||||
copy(e.ipArr[:], ip.To16())
|
||||
|
||||
if _, added := m.addedIPs[e]; added {
|
||||
continue
|
||||
}
|
||||
|
||||
entries = append(entries, ipset.NewEntry(ipset.EntryIP(ip)))
|
||||
newAddedEntries = append(newAddedEntries, e)
|
||||
}
|
||||
|
||||
n = len(entries)
|
||||
@@ -407,15 +283,6 @@ func (m *manager) addIPs(host string, set props, ips []net.IP) (n int, err error
|
||||
return 0, fmt.Errorf("adding %q%s to ipset %q: %w", host, ips, set.name, err)
|
||||
}
|
||||
|
||||
// Only add these to the cache once we're sure that all of them were
|
||||
// actually sent to the ipset.
|
||||
for _, e := range newAddedEntries {
|
||||
s := m.nameToIpset[e.ipsetName]
|
||||
if s.isPersistent {
|
||||
m.addedIPs[e] = unit{}
|
||||
}
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -21,12 +21,8 @@ type fakeConn struct {
|
||||
ipv4Entries *[]*ipset.Entry
|
||||
ipv6Header *ipset.HeaderPolicy
|
||||
ipv6Entries *[]*ipset.Entry
|
||||
sets []props
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ ipsetConn = (*fakeConn)(nil)
|
||||
|
||||
// Add implements the [ipsetConn] interface for *fakeConn.
|
||||
func (c *fakeConn) Add(name string, entries ...*ipset.Entry) (err error) {
|
||||
if strings.Contains(name, "ipv4") {
|
||||
@@ -47,9 +43,15 @@ func (c *fakeConn) Close() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// listAll implements the [ipsetConn] interface for *fakeConn.
|
||||
func (c *fakeConn) listAll() (sets []props, err error) {
|
||||
return c.sets, nil
|
||||
// Header implements the [ipsetConn] interface for *fakeConn.
|
||||
func (c *fakeConn) Header(name string) (p *ipset.HeaderPolicy, err error) {
|
||||
if strings.Contains(name, "ipv4") {
|
||||
return c.ipv4Header, nil
|
||||
} else if strings.Contains(name, "ipv6") {
|
||||
return c.ipv6Header, nil
|
||||
}
|
||||
|
||||
return nil, errors.Error("test: ipset not found")
|
||||
}
|
||||
|
||||
func TestManager_Add(t *testing.T) {
|
||||
@@ -74,13 +76,6 @@ func TestManager_Add(t *testing.T) {
|
||||
Family: ipset.NewUInt8Box(uint8(netfilter.ProtoIPv6)),
|
||||
},
|
||||
ipv6Entries: &ipv6Entries,
|
||||
sets: []props{{
|
||||
name: "ipv4set",
|
||||
family: netfilter.ProtoIPv4,
|
||||
}, {
|
||||
name: "ipv6set",
|
||||
family: netfilter.ProtoIPv6,
|
||||
}},
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1,30 +0,0 @@
|
||||
# This is a file showing example configuration for AdGuard Home.
|
||||
#
|
||||
# TODO(a.garipov): Move to the top level once the rewrite is over.
|
||||
|
||||
dns:
|
||||
addresses:
|
||||
- '0.0.0.0:53'
|
||||
bootstrap_dns:
|
||||
- '9.9.9.10'
|
||||
- '149.112.112.10'
|
||||
- '2620:fe::10'
|
||||
- '2620:fe::fe:10'
|
||||
upstream_dns:
|
||||
- '8.8.8.8'
|
||||
dns64_prefixes:
|
||||
- '1234::/64'
|
||||
upstream_timeout: 1s
|
||||
bootstrap_prefer_ipv6: true
|
||||
use_dns64: true
|
||||
http:
|
||||
pprof:
|
||||
enabled: true
|
||||
port: 6060
|
||||
addresses:
|
||||
- '0.0.0.0:3000'
|
||||
secure_addresses: []
|
||||
timeout: 5s
|
||||
force_https: true
|
||||
log:
|
||||
verbose: true
|
||||
@@ -1,42 +0,0 @@
|
||||
# AdGuard Home v0.108.0 Changelog DRAFT
|
||||
|
||||
This changelog should be merged into the main one once the next API matures
|
||||
enough.
|
||||
|
||||
## [v0.108.0] - TODO
|
||||
|
||||
### Added
|
||||
|
||||
- The ability to change the port of the pprof debug API.
|
||||
- The ability to log to stderr using `--logFile=stderr`.
|
||||
- The new `--web-addr` flag to set the Web UI address in a `host:port` form.
|
||||
- `SIGHUP` now reloads all configuration from the configuration file ([#5676]).
|
||||
|
||||
### Changed
|
||||
|
||||
#### New HTTP API
|
||||
|
||||
**TODO(a.garipov):** Describe the new API and add a link to the new OpenAPI doc.
|
||||
|
||||
#### Other changes
|
||||
|
||||
- `-h` is now an alias for `--help` instead of the removed `--host`, see below.
|
||||
Use `--web-addr=host:port` to set an address on which to serve the Web UI.
|
||||
|
||||
### Fixed
|
||||
|
||||
- `--check-config` breaking the configuration file ([#4067]).
|
||||
- Inconsistent application of `--work-dir/-w` ([#2598], [#2902]).
|
||||
- The order of `-v/--verbose` and `--version` being significant ([#2893]).
|
||||
|
||||
### Removed
|
||||
|
||||
- The deprecated `--no-mem-optimization` and `--no-etc-hosts` flags.
|
||||
- `--host` and `-p/--port` flags. Use `--web-addr=host:port` to set an address
|
||||
on which to serve the Web UI. `-h` is now an alias for `--help`, see above.
|
||||
|
||||
[#2598]: https://github.com/AdguardTeam/AdGuardHome/issues/2598
|
||||
[#2893]: https://github.com/AdguardTeam/AdGuardHome/issues/2893
|
||||
[#2902]: https://github.com/AdguardTeam/AdGuardHome/issues/2902
|
||||
[#4067]: https://github.com/AdguardTeam/AdGuardHome/issues/4067
|
||||
[#5676]: https://github.com/AdguardTeam/AdGuardHome/issues/5676
|
||||
@@ -1,95 +0,0 @@
|
||||
// Package cmd is the AdGuard Home entry point. It assembles the configuration
|
||||
// file manager, sets up signal processing logic, and so on.
|
||||
//
|
||||
// TODO(a.garipov): Move to the upper-level internal/.
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io/fs"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/configmgr"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// Main is the entry point of AdGuard Home.
|
||||
func Main(embeddedFrontend fs.FS) {
|
||||
start := time.Now()
|
||||
|
||||
cmdName := os.Args[0]
|
||||
opts, err := parseOptions(cmdName, os.Args[1:])
|
||||
exitCode, needExit := processOptions(opts, cmdName, err)
|
||||
if needExit {
|
||||
os.Exit(exitCode)
|
||||
}
|
||||
|
||||
err = setLog(opts)
|
||||
check(err)
|
||||
|
||||
log.Info("starting adguard home, version %s, pid %d", version.Version(), os.Getpid())
|
||||
|
||||
if opts.workDir != "" {
|
||||
log.Info("changing working directory to %q", opts.workDir)
|
||||
err = os.Chdir(opts.workDir)
|
||||
check(err)
|
||||
}
|
||||
|
||||
frontend, err := frontendFromOpts(opts, embeddedFrontend)
|
||||
check(err)
|
||||
|
||||
confMgrConf := &configmgr.Config{
|
||||
Frontend: frontend,
|
||||
WebAddr: opts.webAddr,
|
||||
Start: start,
|
||||
FileName: opts.confFile,
|
||||
}
|
||||
|
||||
confMgr, err := newConfigMgr(confMgrConf)
|
||||
check(err)
|
||||
|
||||
web := confMgr.Web()
|
||||
err = web.Start()
|
||||
check(err)
|
||||
|
||||
dns := confMgr.DNS()
|
||||
err = dns.Start()
|
||||
check(err)
|
||||
|
||||
sigHdlr := newSignalHandler(
|
||||
confMgrConf,
|
||||
opts.pidFile,
|
||||
web,
|
||||
dns,
|
||||
)
|
||||
|
||||
sigHdlr.handle()
|
||||
}
|
||||
|
||||
// defaultTimeout is the timeout used for some operations where another timeout
|
||||
// hasn't been defined yet.
|
||||
const defaultTimeout = 5 * time.Second
|
||||
|
||||
// ctxWithDefaultTimeout is a helper function that returns a context with
|
||||
// timeout set to defaultTimeout.
|
||||
func ctxWithDefaultTimeout() (ctx context.Context, cancel context.CancelFunc) {
|
||||
return context.WithTimeout(context.Background(), defaultTimeout)
|
||||
}
|
||||
|
||||
// newConfigMgr returns a new configuration manager using defaultTimeout as the
|
||||
// context timeout.
|
||||
func newConfigMgr(c *configmgr.Config) (m *configmgr.Manager, err error) {
|
||||
ctx, cancel := ctxWithDefaultTimeout()
|
||||
defer cancel()
|
||||
|
||||
return configmgr.New(ctx, c)
|
||||
}
|
||||
|
||||
// check is a simple error-checking helper. It must only be used within Main.
|
||||
func check(err error) {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
@@ -1,39 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// syslogServiceName is the name of the AdGuard Home service used for writing
|
||||
// logs to the system log.
|
||||
const syslogServiceName = "AdGuardHome"
|
||||
|
||||
// setLog sets up the text logging.
|
||||
//
|
||||
// TODO(a.garipov): Add parameters from configuration file.
|
||||
func setLog(opts *options) (err error) {
|
||||
switch opts.confFile {
|
||||
case "stdout":
|
||||
log.SetOutput(os.Stdout)
|
||||
case "stderr":
|
||||
log.SetOutput(os.Stderr)
|
||||
case "syslog":
|
||||
err = aghos.ConfigureSyslog(syslogServiceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("initializing syslog: %w", err)
|
||||
}
|
||||
default:
|
||||
// TODO(a.garipov): Use the path.
|
||||
}
|
||||
|
||||
if opts.verbose {
|
||||
log.SetLevel(log.DEBUG)
|
||||
log.Debug("verbose logging enabled")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,418 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"encoding"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/configmgr"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
// options contains all command-line options for the AdGuardHome(.exe) binary.
|
||||
type options struct {
|
||||
// confFile is the path to the configuration file.
|
||||
confFile string
|
||||
|
||||
// logFile is the path to the log file. Special values:
|
||||
//
|
||||
// - "stdout": Write to stdout (the default).
|
||||
// - "stderr": Write to stderr.
|
||||
// - "syslog": Write to the system log.
|
||||
logFile string
|
||||
|
||||
// pidFile is the path to the file where to store the PID.
|
||||
pidFile string
|
||||
|
||||
// serviceAction is the service control action to perform:
|
||||
//
|
||||
// - "install": Installs AdGuard Home as a system service.
|
||||
// - "uninstall": Uninstalls it.
|
||||
// - "status": Prints the service status.
|
||||
// - "start": Starts the previously installed service.
|
||||
// - "stop": Stops the previously installed service.
|
||||
// - "restart": Restarts the previously installed service.
|
||||
// - "reload": Reloads the configuration.
|
||||
// - "run": This is a special command that is not supposed to be used
|
||||
// directly it is specified when we register a service, and it indicates
|
||||
// to the app that it is being run as a service.
|
||||
//
|
||||
// TODO(a.garipov): Use.
|
||||
serviceAction string
|
||||
|
||||
// workDir is the path to the working directory. It is applied before all
|
||||
// other configuration is read, so all relative paths are relative to it.
|
||||
workDir string
|
||||
|
||||
// webAddr contains the address on which to serve the web UI.
|
||||
webAddr netip.AddrPort
|
||||
|
||||
// checkConfig, if true, instructs AdGuard Home to check the configuration
|
||||
// file, optionally print an error message to stdout, and exit with a
|
||||
// corresponding exit code.
|
||||
checkConfig bool
|
||||
|
||||
// disableUpdate, if true, prevents AdGuard Home from automatically checking
|
||||
// for updates.
|
||||
//
|
||||
// TODO(a.garipov): Use.
|
||||
disableUpdate bool
|
||||
|
||||
// glinetMode enables the GL-Inet compatibility mode.
|
||||
//
|
||||
// TODO(a.garipov): Use.
|
||||
glinetMode bool
|
||||
|
||||
// help, if true, instructs AdGuard Home to print the command-line option
|
||||
// help message and quit with a successful exit-code.
|
||||
help bool
|
||||
|
||||
// localFrontend, if true, instructs AdGuard Home to use the local frontend
|
||||
// directory instead of the files compiled into the binary.
|
||||
//
|
||||
// TODO(a.garipov): Use.
|
||||
localFrontend bool
|
||||
|
||||
// performUpdate, if true, instructs AdGuard Home to update the current
|
||||
// binary and restart the service in case it's installed.
|
||||
//
|
||||
// TODO(a.garipov): Use.
|
||||
performUpdate bool
|
||||
|
||||
// verbose, if true, instructs AdGuard Home to enable verbose logging.
|
||||
verbose bool
|
||||
|
||||
// version, if true, instructs AdGuard Home to print the version to stdout
|
||||
// and quit with a successful exit-code. If verbose is also true, print a
|
||||
// more detailed version description.
|
||||
version bool
|
||||
}
|
||||
|
||||
// Indexes to help with the [commandLineOptions] initialization.
|
||||
const (
|
||||
confFileIdx = iota
|
||||
logFileIdx
|
||||
pidFileIdx
|
||||
serviceActionIdx
|
||||
workDirIdx
|
||||
webAddrIdx
|
||||
checkConfigIdx
|
||||
disableUpdateIdx
|
||||
glinetModeIdx
|
||||
helpIdx
|
||||
localFrontend
|
||||
performUpdateIdx
|
||||
verboseIdx
|
||||
versionIdx
|
||||
)
|
||||
|
||||
// commandLineOption contains information about a command-line option: its long
|
||||
// and, if there is one, short forms, the value type, the description, and the
|
||||
// default value.
|
||||
type commandLineOption struct {
|
||||
defaultValue any
|
||||
description string
|
||||
long string
|
||||
short string
|
||||
valueType string
|
||||
}
|
||||
|
||||
// commandLineOptions are all command-line options currently supported by
|
||||
// AdGuard Home.
|
||||
var commandLineOptions = []*commandLineOption{
|
||||
confFileIdx: {
|
||||
// TODO(a.garipov): Remove the directory when the new code is ready.
|
||||
defaultValue: "internal/next/AdGuardHome.yaml",
|
||||
description: "Path to the config file.",
|
||||
long: "config",
|
||||
short: "c",
|
||||
valueType: "path",
|
||||
},
|
||||
|
||||
logFileIdx: {
|
||||
defaultValue: "stdout",
|
||||
description: `Path to log file. Special values include "stdout", "stderr", and "syslog".`,
|
||||
long: "logfile",
|
||||
short: "l",
|
||||
valueType: "path",
|
||||
},
|
||||
|
||||
pidFileIdx: {
|
||||
defaultValue: "",
|
||||
description: "Path to the file where to store the PID.",
|
||||
long: "pidfile",
|
||||
short: "",
|
||||
valueType: "path",
|
||||
},
|
||||
|
||||
serviceActionIdx: {
|
||||
defaultValue: "",
|
||||
description: `Service control action: "status", "install" (as a service), ` +
|
||||
`"uninstall" (as a service), "start", "stop", "restart", "reload" (configuration).`,
|
||||
long: "service",
|
||||
short: "s",
|
||||
valueType: "action",
|
||||
},
|
||||
|
||||
workDirIdx: {
|
||||
defaultValue: "",
|
||||
description: `Path to the working directory. ` +
|
||||
`It is applied before all other configuration is read, ` +
|
||||
`so all relative paths are relative to it.`,
|
||||
long: "work-dir",
|
||||
short: "w",
|
||||
valueType: "path",
|
||||
},
|
||||
|
||||
webAddrIdx: {
|
||||
defaultValue: netip.AddrPort{},
|
||||
description: `Address to serve the web UI on, in the host:port format.`,
|
||||
long: "web-addr",
|
||||
short: "",
|
||||
valueType: "host:port",
|
||||
},
|
||||
|
||||
checkConfigIdx: {
|
||||
defaultValue: false,
|
||||
description: "Check configuration, print errors to stdout, and quit.",
|
||||
long: "check-config",
|
||||
short: "",
|
||||
valueType: "",
|
||||
},
|
||||
|
||||
disableUpdateIdx: {
|
||||
defaultValue: false,
|
||||
description: "Disable automatic update checking.",
|
||||
long: "no-check-update",
|
||||
short: "",
|
||||
valueType: "",
|
||||
},
|
||||
|
||||
glinetModeIdx: {
|
||||
defaultValue: false,
|
||||
description: "Run in GL-Inet compatibility mode.",
|
||||
long: "glinet",
|
||||
short: "",
|
||||
valueType: "",
|
||||
},
|
||||
|
||||
helpIdx: {
|
||||
defaultValue: false,
|
||||
description: "Print this help message and quit.",
|
||||
long: "help",
|
||||
short: "h",
|
||||
valueType: "",
|
||||
},
|
||||
|
||||
localFrontend: {
|
||||
defaultValue: false,
|
||||
description: "Use local frontend directories.",
|
||||
long: "local-frontend",
|
||||
short: "",
|
||||
valueType: "",
|
||||
},
|
||||
|
||||
performUpdateIdx: {
|
||||
defaultValue: false,
|
||||
description: "Update the current binary and restart the service in case it's installed.",
|
||||
long: "update",
|
||||
short: "",
|
||||
valueType: "",
|
||||
},
|
||||
|
||||
verboseIdx: {
|
||||
defaultValue: false,
|
||||
description: "Enable verbose logging.",
|
||||
long: "verbose",
|
||||
short: "v",
|
||||
valueType: "",
|
||||
},
|
||||
|
||||
versionIdx: {
|
||||
defaultValue: false,
|
||||
description: `Print the version to stdout and quit. ` +
|
||||
`Print a more detailed version description with -v.`,
|
||||
long: "version",
|
||||
short: "",
|
||||
valueType: "",
|
||||
},
|
||||
}
|
||||
|
||||
// parseOptions parses the command-line options for AdGuardHome.
|
||||
func parseOptions(cmdName string, args []string) (opts *options, err error) {
|
||||
flags := flag.NewFlagSet(cmdName, flag.ContinueOnError)
|
||||
|
||||
opts = &options{}
|
||||
for i, fieldPtr := range []any{
|
||||
confFileIdx: &opts.confFile,
|
||||
logFileIdx: &opts.logFile,
|
||||
pidFileIdx: &opts.pidFile,
|
||||
serviceActionIdx: &opts.serviceAction,
|
||||
workDirIdx: &opts.workDir,
|
||||
webAddrIdx: &opts.webAddr,
|
||||
checkConfigIdx: &opts.checkConfig,
|
||||
disableUpdateIdx: &opts.disableUpdate,
|
||||
glinetModeIdx: &opts.glinetMode,
|
||||
helpIdx: &opts.help,
|
||||
localFrontend: &opts.localFrontend,
|
||||
performUpdateIdx: &opts.performUpdate,
|
||||
verboseIdx: &opts.verbose,
|
||||
versionIdx: &opts.version,
|
||||
} {
|
||||
addOption(flags, fieldPtr, commandLineOptions[i])
|
||||
}
|
||||
|
||||
flags.Usage = func() { usage(cmdName, os.Stderr) }
|
||||
|
||||
err = flags.Parse(args)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
// addOption adds the command-line option described by o to flags using fieldPtr
|
||||
// as the pointer to the value.
|
||||
func addOption(flags *flag.FlagSet, fieldPtr any, o *commandLineOption) {
|
||||
switch fieldPtr := fieldPtr.(type) {
|
||||
case *string:
|
||||
flags.StringVar(fieldPtr, o.long, o.defaultValue.(string), o.description)
|
||||
if o.short != "" {
|
||||
flags.StringVar(fieldPtr, o.short, o.defaultValue.(string), o.description)
|
||||
}
|
||||
case *bool:
|
||||
flags.BoolVar(fieldPtr, o.long, o.defaultValue.(bool), o.description)
|
||||
if o.short != "" {
|
||||
flags.BoolVar(fieldPtr, o.short, o.defaultValue.(bool), o.description)
|
||||
}
|
||||
case encoding.TextUnmarshaler:
|
||||
flags.TextVar(fieldPtr, o.long, o.defaultValue.(encoding.TextMarshaler), o.description)
|
||||
if o.short != "" {
|
||||
flags.TextVar(fieldPtr, o.short, o.defaultValue.(encoding.TextMarshaler), o.description)
|
||||
}
|
||||
default:
|
||||
panic(fmt.Errorf("unexpected field pointer type %T", fieldPtr))
|
||||
}
|
||||
}
|
||||
|
||||
// usage prints a usage message similar to the one printed by package flag but
|
||||
// taking long vs. short versions into account as well as using more informative
|
||||
// value hints.
|
||||
func usage(cmdName string, output io.Writer) {
|
||||
options := slices.Clone(commandLineOptions)
|
||||
slices.SortStableFunc(options, func(a, b *commandLineOption) (res int) {
|
||||
return strings.Compare(a.long, b.long)
|
||||
})
|
||||
|
||||
b := &strings.Builder{}
|
||||
_, _ = fmt.Fprintf(b, "Usage of %s:\n", cmdName)
|
||||
|
||||
for _, o := range options {
|
||||
writeUsageLine(b, o)
|
||||
|
||||
// Use four spaces before the tab to trigger good alignment for both 4-
|
||||
// and 8-space tab stops.
|
||||
if shouldIncludeDefault(o.defaultValue) {
|
||||
_, _ = fmt.Fprintf(b, " \t%s (Default value: %q)\n", o.description, o.defaultValue)
|
||||
} else {
|
||||
_, _ = fmt.Fprintf(b, " \t%s\n", o.description)
|
||||
}
|
||||
}
|
||||
|
||||
_, _ = io.WriteString(output, b.String())
|
||||
}
|
||||
|
||||
// shouldIncludeDefault returns true if this default value should be printed.
|
||||
func shouldIncludeDefault(v any) (ok bool) {
|
||||
switch v := v.(type) {
|
||||
case bool:
|
||||
return v
|
||||
case string:
|
||||
return v != ""
|
||||
default:
|
||||
return v == nil
|
||||
}
|
||||
}
|
||||
|
||||
// writeUsageLine writes the usage line for the provided command-line option.
|
||||
func writeUsageLine(b *strings.Builder, o *commandLineOption) {
|
||||
if o.short == "" {
|
||||
if o.valueType == "" {
|
||||
_, _ = fmt.Fprintf(b, " --%s\n", o.long)
|
||||
} else {
|
||||
_, _ = fmt.Fprintf(b, " --%s=%s\n", o.long, o.valueType)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if o.valueType == "" {
|
||||
_, _ = fmt.Fprintf(b, " --%s/-%s\n", o.long, o.short)
|
||||
} else {
|
||||
_, _ = fmt.Fprintf(b, " --%[1]s=%[3]s/-%[2]s %[3]s\n", o.long, o.short, o.valueType)
|
||||
}
|
||||
}
|
||||
|
||||
// processOptions decides if AdGuard Home should exit depending on the results
|
||||
// of command-line option parsing.
|
||||
func processOptions(
|
||||
opts *options,
|
||||
cmdName string,
|
||||
parseErr error,
|
||||
) (exitCode int, needExit bool) {
|
||||
if parseErr != nil {
|
||||
// Assume that usage has already been printed.
|
||||
return statusArgumentError, true
|
||||
}
|
||||
|
||||
if opts.help {
|
||||
usage(cmdName, os.Stdout)
|
||||
|
||||
return statusSuccess, true
|
||||
}
|
||||
|
||||
if opts.version {
|
||||
if opts.verbose {
|
||||
fmt.Println(version.Verbose())
|
||||
} else {
|
||||
fmt.Printf("AdGuard Home %s\n", version.Version())
|
||||
}
|
||||
|
||||
return statusSuccess, true
|
||||
}
|
||||
|
||||
if opts.checkConfig {
|
||||
err := configmgr.Validate(opts.confFile)
|
||||
if err != nil {
|
||||
_, _ = io.WriteString(os.Stdout, err.Error()+"\n")
|
||||
|
||||
return statusError, true
|
||||
}
|
||||
|
||||
return statusSuccess, true
|
||||
}
|
||||
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// frontendFromOpts returns the frontend to use based on the options.
|
||||
func frontendFromOpts(opts *options, embeddedFrontend fs.FS) (frontend fs.FS, err error) {
|
||||
const frontendSubdir = "build/static"
|
||||
|
||||
if opts.localFrontend {
|
||||
log.Info("warning: using local frontend files")
|
||||
|
||||
return os.DirFS(frontendSubdir), nil
|
||||
}
|
||||
|
||||
return fs.Sub(embeddedFrontend, frontendSubdir)
|
||||
}
|
||||
@@ -1,167 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/configmgr"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/google/renameio/v2/maybe"
|
||||
)
|
||||
|
||||
// signalHandler processes incoming signals and shuts services down.
|
||||
type signalHandler struct {
|
||||
// confMgrConf contains the configuration parameters for the configuration
|
||||
// manager.
|
||||
confMgrConf *configmgr.Config
|
||||
|
||||
// signal is the channel to which OS signals are sent.
|
||||
signal chan os.Signal
|
||||
|
||||
// pidFile is the path to the file where to store the PID, if any.
|
||||
pidFile string
|
||||
|
||||
// services are the services that are shut down before application exiting.
|
||||
services []agh.Service
|
||||
}
|
||||
|
||||
// handle processes OS signals.
|
||||
func (h *signalHandler) handle() {
|
||||
defer log.OnPanic("signalHandler.handle")
|
||||
|
||||
h.writePID()
|
||||
|
||||
for sig := range h.signal {
|
||||
log.Info("sighdlr: received signal %q", sig)
|
||||
|
||||
if aghos.IsReconfigureSignal(sig) {
|
||||
h.reconfigure()
|
||||
} else if aghos.IsShutdownSignal(sig) {
|
||||
status := h.shutdown()
|
||||
h.removePID()
|
||||
|
||||
log.Info("sighdlr: exiting with status %d", status)
|
||||
|
||||
os.Exit(status)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// reconfigure rereads the configuration file and updates and restarts services.
|
||||
func (h *signalHandler) reconfigure() {
|
||||
log.Info("sighdlr: reconfiguring adguard home")
|
||||
|
||||
status := h.shutdown()
|
||||
if status != statusSuccess {
|
||||
log.Info("sighdlr: reconfiguring: exiting with status %d", status)
|
||||
|
||||
os.Exit(status)
|
||||
}
|
||||
|
||||
// TODO(a.garipov): This is a very rough way to do it. Some services can be
|
||||
// reconfigured without the full shutdown, and the error handling is
|
||||
// currently not the best.
|
||||
|
||||
confMgr, err := newConfigMgr(h.confMgrConf)
|
||||
check(err)
|
||||
|
||||
web := confMgr.Web()
|
||||
err = web.Start()
|
||||
check(err)
|
||||
|
||||
dns := confMgr.DNS()
|
||||
err = dns.Start()
|
||||
check(err)
|
||||
|
||||
h.services = []agh.Service{
|
||||
dns,
|
||||
web,
|
||||
}
|
||||
|
||||
log.Info("sighdlr: successfully reconfigured adguard home")
|
||||
}
|
||||
|
||||
// Exit status constants.
|
||||
const (
|
||||
statusSuccess = 0
|
||||
statusError = 1
|
||||
statusArgumentError = 2
|
||||
)
|
||||
|
||||
// shutdown gracefully shuts down all services.
|
||||
func (h *signalHandler) shutdown() (status int) {
|
||||
ctx, cancel := ctxWithDefaultTimeout()
|
||||
defer cancel()
|
||||
|
||||
status = statusSuccess
|
||||
|
||||
log.Info("sighdlr: shutting down services")
|
||||
for i, service := range h.services {
|
||||
err := service.Shutdown(ctx)
|
||||
if err != nil {
|
||||
log.Error("sighdlr: shutting down service at index %d: %s", i, err)
|
||||
status = statusError
|
||||
}
|
||||
}
|
||||
|
||||
return status
|
||||
}
|
||||
|
||||
// newSignalHandler returns a new signalHandler that shuts down svcs.
|
||||
func newSignalHandler(
|
||||
confMgrConf *configmgr.Config,
|
||||
pidFile string,
|
||||
svcs ...agh.Service,
|
||||
) (h *signalHandler) {
|
||||
h = &signalHandler{
|
||||
confMgrConf: confMgrConf,
|
||||
signal: make(chan os.Signal, 1),
|
||||
pidFile: pidFile,
|
||||
services: svcs,
|
||||
}
|
||||
|
||||
aghos.NotifyShutdownSignal(h.signal)
|
||||
aghos.NotifyReconfigureSignal(h.signal)
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// writePID writes the PID to the file, if needed. Any errors are reported to
|
||||
// log.
|
||||
func (h *signalHandler) writePID() {
|
||||
if h.pidFile == "" {
|
||||
return
|
||||
}
|
||||
|
||||
// Use 8, since most PIDs will fit.
|
||||
data := make([]byte, 0, 8)
|
||||
data = strconv.AppendInt(data, int64(os.Getpid()), 10)
|
||||
data = append(data, '\n')
|
||||
|
||||
err := maybe.WriteFile(h.pidFile, data, 0o644)
|
||||
if err != nil {
|
||||
log.Error("sighdlr: writing pidfile: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug("sighdlr: wrote pid to %q", h.pidFile)
|
||||
}
|
||||
|
||||
// removePID removes the PID file, if any.
|
||||
func (h *signalHandler) removePID() {
|
||||
if h.pidFile == "" {
|
||||
return
|
||||
}
|
||||
|
||||
err := os.Remove(h.pidFile)
|
||||
if err != nil {
|
||||
log.Error("sighdlr: removing pidfile: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug("sighdlr: removed pid at %q", h.pidFile)
|
||||
}
|
||||
@@ -1,138 +0,0 @@
|
||||
package configmgr
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
)
|
||||
|
||||
// Configuration Structures
|
||||
|
||||
// config is the top-level on-disk configuration structure.
|
||||
type config struct {
|
||||
DNS *dnsConfig `yaml:"dns"`
|
||||
HTTP *httpConfig `yaml:"http"`
|
||||
Log *logConfig `yaml:"log"`
|
||||
// TODO(a.garipov): Use.
|
||||
SchemaVersion int `yaml:"schema_version"`
|
||||
}
|
||||
|
||||
const errNoConf errors.Error = "configuration not found"
|
||||
|
||||
// validate returns an error if the configuration structure is invalid.
|
||||
func (c *config) validate() (err error) {
|
||||
if c == nil {
|
||||
return errNoConf
|
||||
}
|
||||
|
||||
// TODO(a.garipov): Add more validations.
|
||||
|
||||
// Keep this in the same order as the fields in the config.
|
||||
validators := []struct {
|
||||
validate func() (err error)
|
||||
name string
|
||||
}{{
|
||||
validate: c.DNS.validate,
|
||||
name: "dns",
|
||||
}, {
|
||||
validate: c.HTTP.validate,
|
||||
name: "http",
|
||||
}, {
|
||||
validate: c.Log.validate,
|
||||
name: "log",
|
||||
}}
|
||||
|
||||
for _, v := range validators {
|
||||
err = v.validate()
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", v.name, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// dnsConfig is the on-disk DNS configuration.
|
||||
type dnsConfig struct {
|
||||
Addresses []netip.AddrPort `yaml:"addresses"`
|
||||
BootstrapDNS []string `yaml:"bootstrap_dns"`
|
||||
UpstreamDNS []string `yaml:"upstream_dns"`
|
||||
DNS64Prefixes []netip.Prefix `yaml:"dns64_prefixes"`
|
||||
UpstreamTimeout timeutil.Duration `yaml:"upstream_timeout"`
|
||||
BootstrapPreferIPv6 bool `yaml:"bootstrap_prefer_ipv6"`
|
||||
UseDNS64 bool `yaml:"use_dns64"`
|
||||
}
|
||||
|
||||
// validate returns an error if the DNS configuration structure is invalid.
|
||||
//
|
||||
// TODO(a.garipov): Add more validations.
|
||||
func (c *dnsConfig) validate() (err error) {
|
||||
// TODO(a.garipov): Add more validations.
|
||||
switch {
|
||||
case c == nil:
|
||||
return errNoConf
|
||||
case c.UpstreamTimeout.Duration <= 0:
|
||||
return newMustBePositiveError("upstream_timeout", c.UpstreamTimeout)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// httpConfig is the on-disk web API configuration.
|
||||
type httpConfig struct {
|
||||
Pprof *httpPprofConfig `yaml:"pprof"`
|
||||
|
||||
// TODO(a.garipov): Document the configuration change.
|
||||
Addresses []netip.AddrPort `yaml:"addresses"`
|
||||
SecureAddresses []netip.AddrPort `yaml:"secure_addresses"`
|
||||
Timeout timeutil.Duration `yaml:"timeout"`
|
||||
ForceHTTPS bool `yaml:"force_https"`
|
||||
}
|
||||
|
||||
// validate returns an error if the HTTP configuration structure is invalid.
|
||||
//
|
||||
// TODO(a.garipov): Add more validations.
|
||||
func (c *httpConfig) validate() (err error) {
|
||||
switch {
|
||||
case c == nil:
|
||||
return errNoConf
|
||||
case c.Timeout.Duration <= 0:
|
||||
return newMustBePositiveError("timeout", c.Timeout)
|
||||
default:
|
||||
return c.Pprof.validate()
|
||||
}
|
||||
}
|
||||
|
||||
// httpPprofConfig is the on-disk pprof configuration.
|
||||
type httpPprofConfig struct {
|
||||
Port uint16 `yaml:"port"`
|
||||
Enabled bool `yaml:"enabled"`
|
||||
}
|
||||
|
||||
// validate returns an error if the pprof configuration structure is invalid.
|
||||
func (c *httpPprofConfig) validate() (err error) {
|
||||
if c == nil {
|
||||
return errNoConf
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// logConfig is the on-disk web API configuration.
|
||||
type logConfig struct {
|
||||
// TODO(a.garipov): Use.
|
||||
Verbose bool `yaml:"verbose"`
|
||||
}
|
||||
|
||||
// validate returns an error if the HTTP configuration structure is invalid.
|
||||
//
|
||||
// TODO(a.garipov): Add more validations.
|
||||
func (c *logConfig) validate() (err error) {
|
||||
if c == nil {
|
||||
return errNoConf
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,301 +0,0 @@
|
||||
// Package configmgr defines the AdGuard Home on-disk configuration entities and
|
||||
// configuration manager.
|
||||
//
|
||||
// TODO(a.garipov): Add tests.
|
||||
package configmgr
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/websvc"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"github.com/google/renameio/v2/maybe"
|
||||
"golang.org/x/exp/slices"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// Configuration Manager
|
||||
|
||||
// Manager handles full and partial changes in the configuration, persisting
|
||||
// them to disk if necessary.
|
||||
//
|
||||
// TODO(a.garipov): Support missing configs and default values.
|
||||
type Manager struct {
|
||||
// updMu makes sure that at most one reconfiguration is performed at a time.
|
||||
// updMu protects all fields below.
|
||||
updMu *sync.RWMutex
|
||||
|
||||
// dns is the DNS service.
|
||||
dns *dnssvc.Service
|
||||
|
||||
// Web is the Web API service.
|
||||
web *websvc.Service
|
||||
|
||||
// current is the current configuration.
|
||||
current *config
|
||||
|
||||
// fileName is the name of the configuration file.
|
||||
fileName string
|
||||
}
|
||||
|
||||
// Validate returns an error if the configuration file with the given name does
|
||||
// not exist or is invalid.
|
||||
func Validate(fileName string) (err error) {
|
||||
conf, err := read(fileName)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return conf.validate()
|
||||
}
|
||||
|
||||
// Config contains the configuration parameters for the configuration manager.
|
||||
type Config struct {
|
||||
// Frontend is the filesystem with the frontend files.
|
||||
Frontend fs.FS
|
||||
|
||||
// WebAddr is the initial or override address for the Web UI. It is not
|
||||
// written to the configuration file.
|
||||
WebAddr netip.AddrPort
|
||||
|
||||
// Start is the time of start of AdGuard Home.
|
||||
Start time.Time
|
||||
|
||||
// FileName is the path to the configuration file.
|
||||
FileName string
|
||||
}
|
||||
|
||||
// New creates a new *Manager that persists changes to the file pointed to by
|
||||
// c.FileName. It reads the configuration file and populates the service
|
||||
// fields. c must not be nil.
|
||||
func New(ctx context.Context, c *Config) (m *Manager, err error) {
|
||||
conf, err := read(c.FileName)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = conf.validate()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("validating config: %w", err)
|
||||
}
|
||||
|
||||
m = &Manager{
|
||||
updMu: &sync.RWMutex{},
|
||||
current: conf,
|
||||
fileName: c.FileName,
|
||||
}
|
||||
|
||||
err = m.assemble(ctx, conf, c.Frontend, c.WebAddr, c.Start)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating config manager: %w", err)
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// read reads and decodes configuration from the provided filename.
|
||||
func read(fileName string) (conf *config, err error) {
|
||||
defer func() { err = errors.Annotate(err, "reading config: %w") }()
|
||||
|
||||
conf = &config{}
|
||||
f, err := os.Open(fileName)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
defer func() { err = errors.WithDeferred(err, f.Close()) }()
|
||||
|
||||
err = yaml.NewDecoder(f).Decode(conf)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
// assemble creates all services and puts them into the corresponding fields.
|
||||
// The fields of conf must not be modified after calling assemble.
|
||||
func (m *Manager) assemble(
|
||||
ctx context.Context,
|
||||
conf *config,
|
||||
frontend fs.FS,
|
||||
webAddr netip.AddrPort,
|
||||
start time.Time,
|
||||
) (err error) {
|
||||
dnsConf := &dnssvc.Config{
|
||||
Addresses: conf.DNS.Addresses,
|
||||
BootstrapServers: conf.DNS.BootstrapDNS,
|
||||
UpstreamServers: conf.DNS.UpstreamDNS,
|
||||
DNS64Prefixes: conf.DNS.DNS64Prefixes,
|
||||
UpstreamTimeout: conf.DNS.UpstreamTimeout.Duration,
|
||||
BootstrapPreferIPv6: conf.DNS.BootstrapPreferIPv6,
|
||||
UseDNS64: conf.DNS.UseDNS64,
|
||||
}
|
||||
err = m.updateDNS(ctx, dnsConf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("assembling dnssvc: %w", err)
|
||||
}
|
||||
|
||||
webSvcConf := &websvc.Config{
|
||||
Pprof: &websvc.PprofConfig{
|
||||
Port: conf.HTTP.Pprof.Port,
|
||||
Enabled: conf.HTTP.Pprof.Enabled,
|
||||
},
|
||||
ConfigManager: m,
|
||||
Frontend: frontend,
|
||||
// TODO(a.garipov): Fill from config file.
|
||||
TLS: nil,
|
||||
Start: start,
|
||||
Addresses: conf.HTTP.Addresses,
|
||||
SecureAddresses: conf.HTTP.SecureAddresses,
|
||||
OverrideAddress: webAddr,
|
||||
Timeout: conf.HTTP.Timeout.Duration,
|
||||
ForceHTTPS: conf.HTTP.ForceHTTPS,
|
||||
}
|
||||
|
||||
err = m.updateWeb(ctx, webSvcConf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("assembling websvc: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// write writes the current configuration to disk.
|
||||
func (m *Manager) write() (err error) {
|
||||
b, err := yaml.Marshal(m.current)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encoding: %w", err)
|
||||
}
|
||||
|
||||
err = maybe.WriteFile(m.fileName, b, 0o644)
|
||||
if err != nil {
|
||||
return fmt.Errorf("writing: %w", err)
|
||||
}
|
||||
|
||||
log.Info("configmgr: written to %q", m.fileName)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DNS returns the current DNS service. It is safe for concurrent use.
|
||||
func (m *Manager) DNS() (dns agh.ServiceWithConfig[*dnssvc.Config]) {
|
||||
m.updMu.RLock()
|
||||
defer m.updMu.RUnlock()
|
||||
|
||||
return m.dns
|
||||
}
|
||||
|
||||
// UpdateDNS implements the [websvc.ConfigManager] interface for *Manager. The
|
||||
// fields of c must not be modified after calling UpdateDNS.
|
||||
func (m *Manager) UpdateDNS(ctx context.Context, c *dnssvc.Config) (err error) {
|
||||
m.updMu.Lock()
|
||||
defer m.updMu.Unlock()
|
||||
|
||||
// TODO(a.garipov): Update and write the configuration file. Return an
|
||||
// error if something went wrong.
|
||||
|
||||
err = m.updateDNS(ctx, c)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reassembling dnssvc: %w", err)
|
||||
}
|
||||
|
||||
m.updateCurrentDNS(c)
|
||||
|
||||
return m.write()
|
||||
}
|
||||
|
||||
// updateDNS recreates the DNS service. m.updMu is expected to be locked.
|
||||
func (m *Manager) updateDNS(ctx context.Context, c *dnssvc.Config) (err error) {
|
||||
if prev := m.dns; prev != nil {
|
||||
err = prev.Shutdown(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("shutting down dns svc: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
svc, err := dnssvc.New(c)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating dns svc: %w", err)
|
||||
}
|
||||
|
||||
m.dns = svc
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateCurrentDNS updates the DNS configuration in the current config.
|
||||
func (m *Manager) updateCurrentDNS(c *dnssvc.Config) {
|
||||
m.current.DNS.Addresses = slices.Clone(c.Addresses)
|
||||
m.current.DNS.BootstrapDNS = slices.Clone(c.BootstrapServers)
|
||||
m.current.DNS.UpstreamDNS = slices.Clone(c.UpstreamServers)
|
||||
m.current.DNS.DNS64Prefixes = slices.Clone(c.DNS64Prefixes)
|
||||
m.current.DNS.UpstreamTimeout = timeutil.Duration{Duration: c.UpstreamTimeout}
|
||||
m.current.DNS.BootstrapPreferIPv6 = c.BootstrapPreferIPv6
|
||||
m.current.DNS.UseDNS64 = c.UseDNS64
|
||||
}
|
||||
|
||||
// Web returns the current web service. It is safe for concurrent use.
|
||||
func (m *Manager) Web() (web agh.ServiceWithConfig[*websvc.Config]) {
|
||||
m.updMu.RLock()
|
||||
defer m.updMu.RUnlock()
|
||||
|
||||
return m.web
|
||||
}
|
||||
|
||||
// UpdateWeb implements the [websvc.ConfigManager] interface for *Manager. The
|
||||
// fields of c must not be modified after calling UpdateWeb.
|
||||
func (m *Manager) UpdateWeb(ctx context.Context, c *websvc.Config) (err error) {
|
||||
m.updMu.Lock()
|
||||
defer m.updMu.Unlock()
|
||||
|
||||
err = m.updateWeb(ctx, c)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reassembling websvc: %w", err)
|
||||
}
|
||||
|
||||
m.updateCurrentWeb(c)
|
||||
|
||||
return m.write()
|
||||
}
|
||||
|
||||
// updateWeb recreates the web service. m.upd is expected to be locked.
|
||||
func (m *Manager) updateWeb(ctx context.Context, c *websvc.Config) (err error) {
|
||||
if prev := m.web; prev != nil {
|
||||
err = prev.Shutdown(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("shutting down web svc: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
m.web, err = websvc.New(c)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating web svc: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateCurrentWeb updates the web configuration in the current config.
|
||||
func (m *Manager) updateCurrentWeb(c *websvc.Config) {
|
||||
// TODO(a.garipov): Update pprof from API?
|
||||
|
||||
m.current.HTTP.Addresses = slices.Clone(c.Addresses)
|
||||
m.current.HTTP.SecureAddresses = slices.Clone(c.SecureAddresses)
|
||||
m.current.HTTP.Timeout = timeutil.Duration{Duration: c.Timeout}
|
||||
m.current.HTTP.ForceHTTPS = c.ForceHTTPS
|
||||
}
|
||||
@@ -1,27 +0,0 @@
|
||||
package configmgr
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"golang.org/x/exp/constraints"
|
||||
)
|
||||
|
||||
// numberOrDuration is the constraint for integer types along with
|
||||
// timeutil.Duration.
|
||||
type numberOrDuration interface {
|
||||
constraints.Integer | timeutil.Duration
|
||||
}
|
||||
|
||||
// newMustBePositiveError returns an error about the value that must be positive
|
||||
// but isn't. prop is the name of the property to mention in the error message.
|
||||
//
|
||||
// TODO(a.garipov): Consider moving such helpers to golibs and use in AdGuardDNS
|
||||
// as well.
|
||||
func newMustBePositiveError[T numberOrDuration](prop string, v T) (err error) {
|
||||
if s, ok := any(v).(fmt.Stringer); ok {
|
||||
return fmt.Errorf("%s must be positive, got %s", prop, s)
|
||||
}
|
||||
|
||||
return fmt.Errorf("%s must be positive, got %d", prop, v)
|
||||
}
|
||||
@@ -1,35 +0,0 @@
|
||||
package dnssvc
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Config is the AdGuard Home DNS service configuration structure.
|
||||
//
|
||||
// TODO(a.garipov): Add timeout for incoming requests.
|
||||
type Config struct {
|
||||
// Addresses are the addresses on which to serve plain DNS queries.
|
||||
Addresses []netip.AddrPort
|
||||
|
||||
// BootstrapServers are the addresses of DNS servers used for bootstrapping
|
||||
// the upstream DNS server addresses.
|
||||
BootstrapServers []string
|
||||
|
||||
// UpstreamServers are the upstream DNS server addresses to use.
|
||||
UpstreamServers []string
|
||||
|
||||
// DNS64Prefixes is a slice of NAT64 prefixes to be used for DNS64. See
|
||||
// also [Config.UseDNS64].
|
||||
DNS64Prefixes []netip.Prefix
|
||||
|
||||
// UpstreamTimeout is the timeout for upstream requests.
|
||||
UpstreamTimeout time.Duration
|
||||
|
||||
// BootstrapPreferIPv6, if true, instructs the bootstrapper to prefer IPv6
|
||||
// addresses to IPv4 ones when bootstrapping.
|
||||
BootstrapPreferIPv6 bool
|
||||
|
||||
// UseDNS64, if true, enables DNS64 protection for incoming requests.
|
||||
UseDNS64 bool
|
||||
}
|
||||
@@ -1,202 +0,0 @@
|
||||
// Package dnssvc contains the AdGuard Home DNS service.
|
||||
//
|
||||
// TODO(a.garipov): Define, if all methods of a *Service should work with a nil
|
||||
// receiver.
|
||||
package dnssvc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
|
||||
// TODO(a.garipov): Add a “dnsproxy proxy” package to shield us from changes
|
||||
// and replacement of module dnsproxy.
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
)
|
||||
|
||||
// Service is the AdGuard Home DNS service. A nil *Service is a valid
|
||||
// [agh.Service] that does nothing.
|
||||
//
|
||||
// TODO(a.garipov): Consider saving a [*proxy.Config] instance for those
|
||||
// fields that are only used in [New] and [Service.Config].
|
||||
type Service struct {
|
||||
proxy *proxy.Proxy
|
||||
bootstraps []string
|
||||
upstreams []string
|
||||
dns64Prefixes []netip.Prefix
|
||||
upsTimeout time.Duration
|
||||
running atomic.Bool
|
||||
bootstrapPreferIPv6 bool
|
||||
useDNS64 bool
|
||||
}
|
||||
|
||||
// New returns a new properly initialized *Service. If c is nil, svc is a nil
|
||||
// *Service that does nothing. The fields of c must not be modified after
|
||||
// calling New.
|
||||
func New(c *Config) (svc *Service, err error) {
|
||||
if c == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
svc = &Service{
|
||||
bootstraps: c.BootstrapServers,
|
||||
upstreams: c.UpstreamServers,
|
||||
dns64Prefixes: c.DNS64Prefixes,
|
||||
upsTimeout: c.UpstreamTimeout,
|
||||
bootstrapPreferIPv6: c.BootstrapPreferIPv6,
|
||||
useDNS64: c.UseDNS64,
|
||||
}
|
||||
|
||||
upstreams, err := addressesToUpstreams(
|
||||
c.UpstreamServers,
|
||||
c.BootstrapServers,
|
||||
c.UpstreamTimeout,
|
||||
c.BootstrapPreferIPv6,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("converting upstreams: %w", err)
|
||||
}
|
||||
|
||||
svc.proxy = &proxy.Proxy{
|
||||
Config: proxy.Config{
|
||||
UDPListenAddr: udpAddrs(c.Addresses),
|
||||
TCPListenAddr: tcpAddrs(c.Addresses),
|
||||
UpstreamConfig: &proxy.UpstreamConfig{
|
||||
Upstreams: upstreams,
|
||||
},
|
||||
UseDNS64: c.UseDNS64,
|
||||
DNS64Prefs: c.DNS64Prefixes,
|
||||
},
|
||||
}
|
||||
|
||||
err = svc.proxy.Init()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("proxy: %w", err)
|
||||
}
|
||||
|
||||
return svc, nil
|
||||
}
|
||||
|
||||
// addressesToUpstreams is a wrapper around [upstream.AddressToUpstream]. It
|
||||
// accepts a slice of addresses and other upstream parameters, and returns a
|
||||
// slice of upstreams.
|
||||
func addressesToUpstreams(
|
||||
upsStrs []string,
|
||||
bootstraps []string,
|
||||
timeout time.Duration,
|
||||
preferIPv6 bool,
|
||||
) (upstreams []upstream.Upstream, err error) {
|
||||
upstreams = make([]upstream.Upstream, len(upsStrs))
|
||||
for i, upsStr := range upsStrs {
|
||||
upstreams[i], err = upstream.AddressToUpstream(upsStr, &upstream.Options{
|
||||
Bootstrap: bootstraps,
|
||||
Timeout: timeout,
|
||||
PreferIPv6: preferIPv6,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("upstream at index %d: %w", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
return upstreams, nil
|
||||
}
|
||||
|
||||
// tcpAddrs converts []netip.AddrPort into []*net.TCPAddr.
|
||||
func tcpAddrs(addrPorts []netip.AddrPort) (tcpAddrs []*net.TCPAddr) {
|
||||
if addrPorts == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
tcpAddrs = make([]*net.TCPAddr, len(addrPorts))
|
||||
for i, a := range addrPorts {
|
||||
tcpAddrs[i] = net.TCPAddrFromAddrPort(a)
|
||||
}
|
||||
|
||||
return tcpAddrs
|
||||
}
|
||||
|
||||
// udpAddrs converts []netip.AddrPort into []*net.UDPAddr.
|
||||
func udpAddrs(addrPorts []netip.AddrPort) (udpAddrs []*net.UDPAddr) {
|
||||
if addrPorts == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
udpAddrs = make([]*net.UDPAddr, len(addrPorts))
|
||||
for i, a := range addrPorts {
|
||||
udpAddrs[i] = net.UDPAddrFromAddrPort(a)
|
||||
}
|
||||
|
||||
return udpAddrs
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ agh.Service = (*Service)(nil)
|
||||
|
||||
// Start implements the [agh.Service] interface for *Service. svc may be nil.
|
||||
// After Start exits, all DNS servers have tried to start, but there is no
|
||||
// guarantee that they did. Errors from the servers are written to the log.
|
||||
func (svc *Service) Start() (err error) {
|
||||
if svc == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
defer func() {
|
||||
// TODO(a.garipov): [proxy.Proxy.Start] doesn't actually have any way to
|
||||
// tell when all servers are actually up, so at best this is merely an
|
||||
// assumption.
|
||||
svc.running.Store(err == nil)
|
||||
}()
|
||||
|
||||
return svc.proxy.Start()
|
||||
}
|
||||
|
||||
// Shutdown implements the [agh.Service] interface for *Service. svc may be
|
||||
// nil.
|
||||
func (svc *Service) Shutdown(ctx context.Context) (err error) {
|
||||
if svc == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return svc.proxy.Stop()
|
||||
}
|
||||
|
||||
// Config returns the current configuration of the web service. Config must not
|
||||
// be called simultaneously with Start. If svc was initialized with ":0"
|
||||
// addresses, addrs will not return the actual bound ports until Start is
|
||||
// finished.
|
||||
func (svc *Service) Config() (c *Config) {
|
||||
// TODO(a.garipov): Do we need to get the TCP addresses separately?
|
||||
|
||||
var addrs []netip.AddrPort
|
||||
if svc.running.Load() {
|
||||
udpAddrs := svc.proxy.Addrs(proxy.ProtoUDP)
|
||||
addrs = make([]netip.AddrPort, len(udpAddrs))
|
||||
for i, a := range udpAddrs {
|
||||
addrs[i] = a.(*net.UDPAddr).AddrPort()
|
||||
}
|
||||
} else {
|
||||
conf := svc.proxy.Config
|
||||
udpAddrs := conf.UDPListenAddr
|
||||
addrs = make([]netip.AddrPort, len(udpAddrs))
|
||||
for i, a := range udpAddrs {
|
||||
addrs[i] = a.AddrPort()
|
||||
}
|
||||
}
|
||||
|
||||
c = &Config{
|
||||
Addresses: addrs,
|
||||
BootstrapServers: svc.bootstraps,
|
||||
UpstreamServers: svc.upstreams,
|
||||
DNS64Prefixes: svc.dns64Prefixes,
|
||||
UpstreamTimeout: svc.upsTimeout,
|
||||
BootstrapPreferIPv6: svc.bootstrapPreferIPv6,
|
||||
UseDNS64: svc.useDNS64,
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
@@ -1,125 +0,0 @@
|
||||
package dnssvc_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
testutil.DiscardLogOutput(m)
|
||||
}
|
||||
|
||||
// testTimeout is the common timeout for tests.
|
||||
const testTimeout = 1 * time.Second
|
||||
|
||||
func TestService(t *testing.T) {
|
||||
const (
|
||||
listenAddr = "127.0.0.1:0"
|
||||
bootstrapAddr = "127.0.0.1:0"
|
||||
upstreamAddr = "upstream.example"
|
||||
)
|
||||
|
||||
upstreamErrCh := make(chan error, 1)
|
||||
upstreamStartedCh := make(chan struct{})
|
||||
upstreamSrv := &dns.Server{
|
||||
Addr: bootstrapAddr,
|
||||
Net: "udp",
|
||||
Handler: dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
|
||||
pt := testutil.PanicT{}
|
||||
|
||||
resp := (&dns.Msg{}).SetReply(req)
|
||||
resp.Answer = append(resp.Answer, &dns.A{
|
||||
Hdr: dns.RR_Header{},
|
||||
A: netip.MustParseAddrPort(bootstrapAddr).Addr().AsSlice(),
|
||||
})
|
||||
|
||||
writeErr := w.WriteMsg(resp)
|
||||
require.NoError(pt, writeErr)
|
||||
}),
|
||||
NotifyStartedFunc: func() { close(upstreamStartedCh) },
|
||||
}
|
||||
|
||||
go func() {
|
||||
listenErr := upstreamSrv.ListenAndServe()
|
||||
if listenErr != nil {
|
||||
// Log these immediately to see what happens.
|
||||
t.Logf("upstream listen error: %s", listenErr)
|
||||
}
|
||||
|
||||
upstreamErrCh <- listenErr
|
||||
}()
|
||||
|
||||
_, _ = testutil.RequireReceive(t, upstreamStartedCh, testTimeout)
|
||||
|
||||
c := &dnssvc.Config{
|
||||
Addresses: []netip.AddrPort{netip.MustParseAddrPort(listenAddr)},
|
||||
BootstrapServers: []string{upstreamSrv.PacketConn.LocalAddr().String()},
|
||||
UpstreamServers: []string{upstreamAddr},
|
||||
DNS64Prefixes: nil,
|
||||
UpstreamTimeout: testTimeout,
|
||||
BootstrapPreferIPv6: false,
|
||||
UseDNS64: false,
|
||||
}
|
||||
|
||||
svc, err := dnssvc.New(c)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = svc.Start()
|
||||
require.NoError(t, err)
|
||||
|
||||
gotConf := svc.Config()
|
||||
require.NotNil(t, gotConf)
|
||||
require.Len(t, gotConf.Addresses, 1)
|
||||
|
||||
addr := gotConf.Addresses[0]
|
||||
|
||||
t.Run("dns", func(t *testing.T) {
|
||||
req := &dns.Msg{
|
||||
MsgHdr: dns.MsgHdr{
|
||||
Id: dns.Id(),
|
||||
RecursionDesired: true,
|
||||
},
|
||||
Question: []dns.Question{{
|
||||
Name: "example.com.",
|
||||
Qtype: dns.TypeA,
|
||||
Qclass: dns.ClassINET,
|
||||
}},
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||
defer cancel()
|
||||
|
||||
cli := &dns.Client{}
|
||||
|
||||
var resp *dns.Msg
|
||||
require.Eventually(t, func() (ok bool) {
|
||||
var excErr error
|
||||
resp, _, excErr = cli.ExchangeContext(ctx, req, addr.String())
|
||||
|
||||
return excErr == nil
|
||||
}, testTimeout, testTimeout/10)
|
||||
|
||||
assert.NotNil(t, resp)
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||
defer cancel()
|
||||
|
||||
err = svc.Shutdown(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = upstreamSrv.Shutdown()
|
||||
require.NoError(t, err)
|
||||
|
||||
err, ok := testutil.RequireReceive(t, upstreamErrCh, testTimeout)
|
||||
require.True(t, ok)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
@@ -1,79 +0,0 @@
|
||||
package websvc
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"io/fs"
|
||||
"net/netip"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Config is the AdGuard Home web service configuration structure.
|
||||
type Config struct {
|
||||
// Pprof is the configuration for the pprof debug API. It must not be nil.
|
||||
Pprof *PprofConfig
|
||||
|
||||
// ConfigManager is used to show information about services as well as
|
||||
// dynamically reconfigure them.
|
||||
ConfigManager ConfigManager
|
||||
|
||||
// Frontend is the filesystem with the frontend and other statically
|
||||
// compiled files.
|
||||
Frontend fs.FS
|
||||
|
||||
// TLS is the optional TLS configuration. If TLS is not nil,
|
||||
// SecureAddresses must not be empty.
|
||||
TLS *tls.Config
|
||||
|
||||
// Start is the time of start of AdGuard Home.
|
||||
Start time.Time
|
||||
|
||||
// OverrideAddress is the initial or override address for the HTTP API. If
|
||||
// set, it is used instead of [Addresses] and [SecureAddresses].
|
||||
OverrideAddress netip.AddrPort
|
||||
|
||||
// Addresses are the addresses on which to serve the plain HTTP API.
|
||||
Addresses []netip.AddrPort
|
||||
|
||||
// SecureAddresses are the addresses on which to serve the HTTPS API. If
|
||||
// SecureAddresses is not empty, TLS must not be nil.
|
||||
SecureAddresses []netip.AddrPort
|
||||
|
||||
// Timeout is the timeout for all server operations.
|
||||
Timeout time.Duration
|
||||
|
||||
// ForceHTTPS tells if all requests to Addresses should be redirected to a
|
||||
// secure address instead.
|
||||
//
|
||||
// TODO(a.garipov): Use; define rules, which address to redirect to.
|
||||
ForceHTTPS bool
|
||||
}
|
||||
|
||||
// PprofConfig is the configuration for the pprof debug API.
|
||||
type PprofConfig struct {
|
||||
Port uint16 `yaml:"port"`
|
||||
Enabled bool `yaml:"enabled"`
|
||||
}
|
||||
|
||||
// Config returns the current configuration of the web service. Config must not
|
||||
// be called simultaneously with Start. If svc was initialized with ":0"
|
||||
// addresses, addrs will not return the actual bound ports until Start is
|
||||
// finished.
|
||||
func (svc *Service) Config() (c *Config) {
|
||||
c = &Config{
|
||||
Pprof: &PprofConfig{
|
||||
Port: svc.pprofPort,
|
||||
Enabled: svc.pprof != nil,
|
||||
},
|
||||
ConfigManager: svc.confMgr,
|
||||
TLS: svc.tls,
|
||||
// Leave Addresses and SecureAddresses empty and get the actual
|
||||
// addresses that include the :0 ones later.
|
||||
Start: svc.start,
|
||||
Timeout: svc.timeout,
|
||||
ForceHTTPS: svc.forceHTTPS,
|
||||
}
|
||||
|
||||
c.Addresses, c.SecureAddresses = svc.addrs()
|
||||
|
||||
return c
|
||||
}
|
||||
@@ -1,97 +0,0 @@
|
||||
package websvc
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
|
||||
)
|
||||
|
||||
// DNS Settings Handlers
|
||||
|
||||
// ReqPatchSettingsDNS describes the request to the PATCH /api/v1/settings/dns
|
||||
// HTTP API.
|
||||
type ReqPatchSettingsDNS struct {
|
||||
// TODO(a.garipov): Add more as we go.
|
||||
|
||||
Addresses []netip.AddrPort `json:"addresses"`
|
||||
BootstrapServers []string `json:"bootstrap_servers"`
|
||||
UpstreamServers []string `json:"upstream_servers"`
|
||||
DNS64Prefixes []netip.Prefix `json:"dns64_prefixes"`
|
||||
UpstreamTimeout aghhttp.JSONDuration `json:"upstream_timeout"`
|
||||
BootstrapPreferIPv6 bool `json:"bootstrap_prefer_ipv6"`
|
||||
UseDNS64 bool `json:"use_dns64"`
|
||||
}
|
||||
|
||||
// HTTPAPIDNSSettings are the DNS settings as used by the HTTP API. See the
|
||||
// DnsSettings object in the OpenAPI specification.
|
||||
type HTTPAPIDNSSettings struct {
|
||||
// TODO(a.garipov): Add more as we go.
|
||||
|
||||
Addresses []netip.AddrPort `json:"addresses"`
|
||||
BootstrapServers []string `json:"bootstrap_servers"`
|
||||
UpstreamServers []string `json:"upstream_servers"`
|
||||
DNS64Prefixes []netip.Prefix `json:"dns64_prefixes"`
|
||||
UpstreamTimeout aghhttp.JSONDuration `json:"upstream_timeout"`
|
||||
BootstrapPreferIPv6 bool `json:"bootstrap_prefer_ipv6"`
|
||||
UseDNS64 bool `json:"use_dns64"`
|
||||
}
|
||||
|
||||
// handlePatchSettingsDNS is the handler for the PATCH /api/v1/settings/dns HTTP
|
||||
// API.
|
||||
func (svc *Service) handlePatchSettingsDNS(w http.ResponseWriter, r *http.Request) {
|
||||
req := &ReqPatchSettingsDNS{
|
||||
Addresses: []netip.AddrPort{},
|
||||
BootstrapServers: []string{},
|
||||
UpstreamServers: []string{},
|
||||
}
|
||||
|
||||
// TODO(a.garipov): Validate nulls and proper JSON patch.
|
||||
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
aghhttp.WriteJSONResponseError(w, r, fmt.Errorf("decoding: %w", err))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
newConf := &dnssvc.Config{
|
||||
Addresses: req.Addresses,
|
||||
BootstrapServers: req.BootstrapServers,
|
||||
UpstreamServers: req.UpstreamServers,
|
||||
DNS64Prefixes: req.DNS64Prefixes,
|
||||
UpstreamTimeout: time.Duration(req.UpstreamTimeout),
|
||||
BootstrapPreferIPv6: req.BootstrapPreferIPv6,
|
||||
UseDNS64: req.UseDNS64,
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
err = svc.confMgr.UpdateDNS(ctx, newConf)
|
||||
if err != nil {
|
||||
aghhttp.WriteJSONResponseError(w, r, fmt.Errorf("updating: %w", err))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
newSvc := svc.confMgr.DNS()
|
||||
err = newSvc.Start()
|
||||
if err != nil {
|
||||
aghhttp.WriteJSONResponseError(w, r, fmt.Errorf("starting new service: %w", err))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
aghhttp.WriteJSONResponseOK(w, r, &HTTPAPIDNSSettings{
|
||||
Addresses: newConf.Addresses,
|
||||
BootstrapServers: newConf.BootstrapServers,
|
||||
UpstreamServers: newConf.UpstreamServers,
|
||||
DNS64Prefixes: newConf.DNS64Prefixes,
|
||||
UpstreamTimeout: aghhttp.JSONDuration(newConf.UpstreamTimeout),
|
||||
BootstrapPreferIPv6: newConf.BootstrapPreferIPv6,
|
||||
UseDNS64: newConf.UseDNS64,
|
||||
})
|
||||
}
|
||||
@@ -1,75 +0,0 @@
|
||||
package websvc_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/websvc"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestService_HandlePatchSettingsDNS(t *testing.T) {
|
||||
wantDNS := &websvc.HTTPAPIDNSSettings{
|
||||
Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.1.1:53")},
|
||||
BootstrapServers: []string{"1.0.0.1"},
|
||||
UpstreamServers: []string{"1.1.1.1"},
|
||||
DNS64Prefixes: []netip.Prefix{netip.MustParsePrefix("1234::/64")},
|
||||
UpstreamTimeout: aghhttp.JSONDuration(2 * time.Second),
|
||||
BootstrapPreferIPv6: true,
|
||||
UseDNS64: true,
|
||||
}
|
||||
|
||||
var started atomic.Bool
|
||||
confMgr := newConfigManager()
|
||||
confMgr.onDNS = func() (s agh.ServiceWithConfig[*dnssvc.Config]) {
|
||||
return &aghtest.ServiceWithConfig[*dnssvc.Config]{
|
||||
OnStart: func() (err error) {
|
||||
started.Store(true)
|
||||
|
||||
return nil
|
||||
},
|
||||
OnShutdown: func(_ context.Context) (err error) { panic("not implemented") },
|
||||
OnConfig: func() (c *dnssvc.Config) { panic("not implemented") },
|
||||
}
|
||||
}
|
||||
confMgr.onUpdateDNS = func(ctx context.Context, c *dnssvc.Config) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, addr := newTestServer(t, confMgr)
|
||||
u := &url.URL{
|
||||
Scheme: "http",
|
||||
Host: addr.String(),
|
||||
Path: websvc.PathV1SettingsDNS,
|
||||
}
|
||||
|
||||
req := jobj{
|
||||
"addresses": wantDNS.Addresses,
|
||||
"bootstrap_servers": wantDNS.BootstrapServers,
|
||||
"upstream_servers": wantDNS.UpstreamServers,
|
||||
"dns64_prefixes": wantDNS.DNS64Prefixes,
|
||||
"upstream_timeout": wantDNS.UpstreamTimeout,
|
||||
"bootstrap_prefer_ipv6": wantDNS.BootstrapPreferIPv6,
|
||||
"use_dns64": wantDNS.UseDNS64,
|
||||
}
|
||||
|
||||
respBody := httpPatch(t, u, req, http.StatusOK)
|
||||
resp := &websvc.HTTPAPIDNSSettings{}
|
||||
err := json.Unmarshal(respBody, resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, started.Load())
|
||||
assert.Equal(t, wantDNS, resp)
|
||||
assert.Equal(t, wantDNS, resp)
|
||||
}
|
||||
@@ -1,123 +0,0 @@
|
||||
package websvc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// HTTP Settings Handlers
|
||||
|
||||
// ReqPatchSettingsHTTP describes the request to the PATCH /api/v1/settings/http
|
||||
// HTTP API.
|
||||
type ReqPatchSettingsHTTP struct {
|
||||
// TODO(a.garipov): Add more as we go.
|
||||
//
|
||||
// TODO(a.garipov): Add wait time.
|
||||
|
||||
Addresses []netip.AddrPort `json:"addresses"`
|
||||
SecureAddresses []netip.AddrPort `json:"secure_addresses"`
|
||||
Timeout aghhttp.JSONDuration `json:"timeout"`
|
||||
}
|
||||
|
||||
// HTTPAPIHTTPSettings are the HTTP settings as used by the HTTP API. See the
|
||||
// HttpSettings object in the OpenAPI specification.
|
||||
type HTTPAPIHTTPSettings struct {
|
||||
// TODO(a.garipov): Add more as we go.
|
||||
|
||||
Addresses []netip.AddrPort `json:"addresses"`
|
||||
SecureAddresses []netip.AddrPort `json:"secure_addresses"`
|
||||
Timeout aghhttp.JSONDuration `json:"timeout"`
|
||||
ForceHTTPS bool `json:"force_https"`
|
||||
}
|
||||
|
||||
// handlePatchSettingsHTTP is the handler for the PATCH /api/v1/settings/http
|
||||
// HTTP API.
|
||||
func (svc *Service) handlePatchSettingsHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
req := &ReqPatchSettingsHTTP{}
|
||||
|
||||
// TODO(a.garipov): Validate nulls and proper JSON patch.
|
||||
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
aghhttp.WriteJSONResponseError(w, r, fmt.Errorf("decoding: %w", err))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
newConf := &Config{
|
||||
Pprof: &PprofConfig{
|
||||
Port: svc.pprofPort,
|
||||
Enabled: svc.pprof != nil,
|
||||
},
|
||||
ConfigManager: svc.confMgr,
|
||||
Frontend: svc.frontend,
|
||||
TLS: svc.tls,
|
||||
Addresses: req.Addresses,
|
||||
SecureAddresses: req.SecureAddresses,
|
||||
Timeout: time.Duration(req.Timeout),
|
||||
ForceHTTPS: svc.forceHTTPS,
|
||||
}
|
||||
|
||||
aghhttp.WriteJSONResponseOK(w, r, &HTTPAPIHTTPSettings{
|
||||
Addresses: newConf.Addresses,
|
||||
SecureAddresses: newConf.SecureAddresses,
|
||||
Timeout: aghhttp.JSONDuration(newConf.Timeout),
|
||||
ForceHTTPS: newConf.ForceHTTPS,
|
||||
})
|
||||
|
||||
cancelUpd := func() {}
|
||||
updCtx := context.Background()
|
||||
|
||||
ctx := r.Context()
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
updCtx, cancelUpd = context.WithDeadline(updCtx, deadline)
|
||||
}
|
||||
|
||||
// Launch the new HTTP service in a separate goroutine to let this handler
|
||||
// finish and thus, this server to shutdown.
|
||||
go svc.relaunch(updCtx, cancelUpd, newConf)
|
||||
}
|
||||
|
||||
// relaunch updates the web service in the configuration manager and starts it.
|
||||
// It is intended to be used as a goroutine.
|
||||
func (svc *Service) relaunch(ctx context.Context, cancel context.CancelFunc, newConf *Config) {
|
||||
defer log.OnPanic("websvc: relaunching")
|
||||
|
||||
defer cancel()
|
||||
|
||||
err := svc.confMgr.UpdateWeb(ctx, newConf)
|
||||
if err != nil {
|
||||
log.Error("websvc: updating web: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// TODO(a.garipov): Consider better ways to do this.
|
||||
const maxUpdDur = 5 * time.Second
|
||||
updStart := time.Now()
|
||||
var newSvc agh.ServiceWithConfig[*Config]
|
||||
for newSvc = svc.confMgr.Web(); newSvc == svc; {
|
||||
if time.Since(updStart) >= maxUpdDur {
|
||||
log.Error("websvc: failed to update svc after %s", maxUpdDur)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug("websvc: waiting for new websvc to be configured")
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
err = newSvc.Start()
|
||||
if err != nil {
|
||||
log.Error("websvc: new svc failed to start with error: %s", err)
|
||||
}
|
||||
}
|
||||
@@ -1,66 +0,0 @@
|
||||
package websvc_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/websvc"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestService_HandlePatchSettingsHTTP(t *testing.T) {
|
||||
wantWeb := &websvc.HTTPAPIHTTPSettings{
|
||||
Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.1.1:80")},
|
||||
SecureAddresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.1.1:443")},
|
||||
Timeout: aghhttp.JSONDuration(10 * time.Second),
|
||||
ForceHTTPS: false,
|
||||
}
|
||||
|
||||
svc, err := websvc.New(&websvc.Config{
|
||||
Pprof: &websvc.PprofConfig{
|
||||
Enabled: false,
|
||||
},
|
||||
TLS: &tls.Config{
|
||||
Certificates: []tls.Certificate{{}},
|
||||
},
|
||||
Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:80")},
|
||||
SecureAddresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:443")},
|
||||
Timeout: 5 * time.Second,
|
||||
ForceHTTPS: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
confMgr := newConfigManager()
|
||||
confMgr.onWeb = func() (s agh.ServiceWithConfig[*websvc.Config]) { return svc }
|
||||
confMgr.onUpdateWeb = func(ctx context.Context, c *websvc.Config) (err error) { return nil }
|
||||
|
||||
_, addr := newTestServer(t, confMgr)
|
||||
u := &url.URL{
|
||||
Scheme: "http",
|
||||
Host: addr.String(),
|
||||
Path: websvc.PathV1SettingsHTTP,
|
||||
}
|
||||
|
||||
req := jobj{
|
||||
"addresses": wantWeb.Addresses,
|
||||
"secure_addresses": wantWeb.SecureAddresses,
|
||||
"timeout": wantWeb.Timeout,
|
||||
"force_https": wantWeb.ForceHTTPS,
|
||||
}
|
||||
|
||||
respBody := httpPatch(t, u, req, http.StatusOK)
|
||||
resp := &websvc.HTTPAPIHTTPSettings{}
|
||||
err = json.Unmarshal(respBody, resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, wantWeb, resp)
|
||||
}
|
||||
@@ -1,38 +0,0 @@
|
||||
package websvc
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/golibs/httphdr"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// Middlewares
|
||||
|
||||
// jsonMw sets the content type of the response to application/json.
|
||||
func jsonMw(h http.Handler) (wrapped http.HandlerFunc) {
|
||||
f := func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set(httphdr.ContentType, aghhttp.HdrValApplicationJSON)
|
||||
|
||||
h.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
return http.HandlerFunc(f)
|
||||
}
|
||||
|
||||
// logMw logs the queries with level debug.
|
||||
func logMw(h http.Handler) (wrapped http.HandlerFunc) {
|
||||
f := func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
m, u := r.Method, r.RequestURI
|
||||
|
||||
log.Debug("websvc: %s %s started", m, u)
|
||||
defer func() { log.Debug("websvc: %s %s finished in %s", m, u, time.Since(start)) }()
|
||||
|
||||
h.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
return http.HandlerFunc(f)
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
package websvc
|
||||
|
||||
// Path constants
|
||||
const (
|
||||
PathRoot = "/"
|
||||
PathFrontend = "/*filepath"
|
||||
|
||||
PathHealthCheck = "/health-check"
|
||||
|
||||
PathV1SettingsAll = "/api/v1/settings/all"
|
||||
PathV1SettingsDNS = "/api/v1/settings/dns"
|
||||
PathV1SettingsHTTP = "/api/v1/settings/http"
|
||||
PathV1SystemInfo = "/api/v1/system/info"
|
||||
)
|
||||
@@ -1,47 +0,0 @@
|
||||
package websvc
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
)
|
||||
|
||||
// All Settings Handlers
|
||||
|
||||
// RespGetV1SettingsAll describes the response of the GET /api/v1/settings/all
|
||||
// HTTP API.
|
||||
type RespGetV1SettingsAll struct {
|
||||
// TODO(a.garipov): Add more as we go.
|
||||
|
||||
DNS *HTTPAPIDNSSettings `json:"dns"`
|
||||
HTTP *HTTPAPIHTTPSettings `json:"http"`
|
||||
}
|
||||
|
||||
// handleGetSettingsAll is the handler for the GET /api/v1/settings/all HTTP
|
||||
// API.
|
||||
func (svc *Service) handleGetSettingsAll(w http.ResponseWriter, r *http.Request) {
|
||||
dnsSvc := svc.confMgr.DNS()
|
||||
dnsConf := dnsSvc.Config()
|
||||
|
||||
webSvc := svc.confMgr.Web()
|
||||
httpConf := webSvc.Config()
|
||||
|
||||
// TODO(a.garipov): Add all currently supported parameters.
|
||||
aghhttp.WriteJSONResponseOK(w, r, &RespGetV1SettingsAll{
|
||||
DNS: &HTTPAPIDNSSettings{
|
||||
Addresses: dnsConf.Addresses,
|
||||
BootstrapServers: dnsConf.BootstrapServers,
|
||||
UpstreamServers: dnsConf.UpstreamServers,
|
||||
DNS64Prefixes: dnsConf.DNS64Prefixes,
|
||||
UpstreamTimeout: aghhttp.JSONDuration(dnsConf.UpstreamTimeout),
|
||||
BootstrapPreferIPv6: dnsConf.BootstrapPreferIPv6,
|
||||
UseDNS64: dnsConf.UseDNS64,
|
||||
},
|
||||
HTTP: &HTTPAPIHTTPSettings{
|
||||
Addresses: httpConf.Addresses,
|
||||
SecureAddresses: httpConf.SecureAddresses,
|
||||
Timeout: aghhttp.JSONDuration(httpConf.Timeout),
|
||||
ForceHTTPS: httpConf.ForceHTTPS,
|
||||
},
|
||||
})
|
||||
}
|
||||
@@ -1,84 +0,0 @@
|
||||
package websvc_test
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/websvc"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestService_HandleGetSettingsAll(t *testing.T) {
|
||||
// TODO(a.garipov): Add all currently supported parameters.
|
||||
|
||||
wantDNS := &websvc.HTTPAPIDNSSettings{
|
||||
Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:53")},
|
||||
BootstrapServers: []string{"94.140.14.140", "94.140.14.141"},
|
||||
UpstreamServers: []string{"94.140.14.14", "1.1.1.1"},
|
||||
UpstreamTimeout: aghhttp.JSONDuration(1 * time.Second),
|
||||
BootstrapPreferIPv6: true,
|
||||
}
|
||||
|
||||
wantWeb := &websvc.HTTPAPIHTTPSettings{
|
||||
Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:80")},
|
||||
SecureAddresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:443")},
|
||||
Timeout: aghhttp.JSONDuration(5 * time.Second),
|
||||
ForceHTTPS: true,
|
||||
}
|
||||
|
||||
confMgr := newConfigManager()
|
||||
confMgr.onDNS = func() (s agh.ServiceWithConfig[*dnssvc.Config]) {
|
||||
c, err := dnssvc.New(&dnssvc.Config{
|
||||
Addresses: wantDNS.Addresses,
|
||||
UpstreamServers: wantDNS.UpstreamServers,
|
||||
BootstrapServers: wantDNS.BootstrapServers,
|
||||
UpstreamTimeout: time.Duration(wantDNS.UpstreamTimeout),
|
||||
BootstrapPreferIPv6: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
svc, err := websvc.New(&websvc.Config{
|
||||
Pprof: &websvc.PprofConfig{
|
||||
Enabled: false,
|
||||
},
|
||||
TLS: &tls.Config{
|
||||
Certificates: []tls.Certificate{{}},
|
||||
},
|
||||
Addresses: wantWeb.Addresses,
|
||||
SecureAddresses: wantWeb.SecureAddresses,
|
||||
Timeout: time.Duration(wantWeb.Timeout),
|
||||
ForceHTTPS: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
confMgr.onWeb = func() (s agh.ServiceWithConfig[*websvc.Config]) {
|
||||
return svc
|
||||
}
|
||||
|
||||
_, addr := newTestServer(t, confMgr)
|
||||
u := &url.URL{
|
||||
Scheme: "http",
|
||||
Host: addr.String(),
|
||||
Path: websvc.PathV1SettingsAll,
|
||||
}
|
||||
|
||||
body := httpGet(t, u, http.StatusOK)
|
||||
resp := &websvc.RespGetV1SettingsAll{}
|
||||
err = json.Unmarshal(body, resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, wantDNS, resp.DNS)
|
||||
assert.Equal(t, wantWeb, resp.HTTP)
|
||||
}
|
||||
@@ -1,36 +0,0 @@
|
||||
package websvc
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"runtime"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
||||
)
|
||||
|
||||
// System Handlers
|
||||
|
||||
// RespGetV1SystemInfo describes the response of the GET /api/v1/system/info
|
||||
// HTTP API.
|
||||
type RespGetV1SystemInfo struct {
|
||||
Arch string `json:"arch"`
|
||||
Channel string `json:"channel"`
|
||||
OS string `json:"os"`
|
||||
NewVersion string `json:"new_version,omitempty"`
|
||||
Start aghhttp.JSONTime `json:"start"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
// handleGetV1SystemInfo is the handler for the GET /api/v1/system/info HTTP
|
||||
// API.
|
||||
func (svc *Service) handleGetV1SystemInfo(w http.ResponseWriter, r *http.Request) {
|
||||
aghhttp.WriteJSONResponseOK(w, r, &RespGetV1SystemInfo{
|
||||
Arch: runtime.GOARCH,
|
||||
Channel: version.Channel(),
|
||||
OS: runtime.GOOS,
|
||||
// TODO(a.garipov): Fill this when we have an updater.
|
||||
NewVersion: "",
|
||||
Start: aghhttp.JSONTime(svc.start),
|
||||
Version: version.Version(),
|
||||
})
|
||||
}
|
||||
@@ -1,37 +0,0 @@
|
||||
package websvc_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/websvc"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestService_handleGetV1SystemInfo(t *testing.T) {
|
||||
confMgr := newConfigManager()
|
||||
_, addr := newTestServer(t, confMgr)
|
||||
u := &url.URL{
|
||||
Scheme: "http",
|
||||
Host: addr.String(),
|
||||
Path: websvc.PathV1SystemInfo,
|
||||
}
|
||||
|
||||
body := httpGet(t, u, http.StatusOK)
|
||||
resp := &websvc.RespGetV1SystemInfo{}
|
||||
err := json.Unmarshal(body, resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
// TODO(a.garipov): Consider making version.Channel and version.Version
|
||||
// testable and test these better.
|
||||
assert.NotEmpty(t, resp.Channel)
|
||||
|
||||
assert.Equal(t, resp.Arch, runtime.GOARCH)
|
||||
assert.Equal(t, resp.OS, runtime.GOOS)
|
||||
assert.Equal(t, testStart, time.Time(resp.Start))
|
||||
}
|
||||
@@ -1,31 +0,0 @@
|
||||
package websvc
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Wait Listener
|
||||
|
||||
// waitListener is a wrapper around a listener that also calls wg.Done() on the
|
||||
// first call to Accept. It is useful in situations where it is important to
|
||||
// catch the precise moment of the first call to Accept, for example when
|
||||
// starting an HTTP server.
|
||||
//
|
||||
// TODO(a.garipov): Move to aghnet?
|
||||
type waitListener struct {
|
||||
net.Listener
|
||||
|
||||
firstAcceptWG *sync.WaitGroup
|
||||
firstAcceptOnce sync.Once
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ net.Listener = (*waitListener)(nil)
|
||||
|
||||
// Accept implements the [net.Listener] interface for *waitListener.
|
||||
func (l *waitListener) Accept() (conn net.Conn, err error) {
|
||||
l.firstAcceptOnce.Do(l.firstAcceptWG.Done)
|
||||
|
||||
return l.Listener.Accept()
|
||||
}
|
||||
@@ -1,40 +0,0 @@
|
||||
package websvc
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/golibs/testutil/fakenet"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestWaitListener_Accept(t *testing.T) {
|
||||
var accepted atomic.Bool
|
||||
var l net.Listener = &fakenet.Listener{
|
||||
OnAccept: func() (conn net.Conn, err error) {
|
||||
accepted.Store(true)
|
||||
|
||||
return nil, nil
|
||||
},
|
||||
OnAddr: func() (addr net.Addr) { panic("not implemented") },
|
||||
OnClose: func() (err error) { panic("not implemented") },
|
||||
}
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
|
||||
go func() {
|
||||
var wrapper net.Listener = &waitListener{
|
||||
Listener: l,
|
||||
firstAcceptWG: wg,
|
||||
}
|
||||
|
||||
_, _ = wrapper.Accept()
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
assert.Eventually(t, accepted.Load, testTimeout, testTimeout/10)
|
||||
}
|
||||
@@ -1,329 +0,0 @@
|
||||
// Package websvc contains the AdGuard Home HTTP API service.
|
||||
//
|
||||
// NOTE: Packages other than cmd must not import this package, as it imports
|
||||
// most other packages.
|
||||
//
|
||||
// TODO(a.garipov): Add tests.
|
||||
package websvc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/mathutil"
|
||||
"github.com/AdguardTeam/golibs/pprofutil"
|
||||
httptreemux "github.com/dimfeld/httptreemux/v5"
|
||||
)
|
||||
|
||||
// ConfigManager is the configuration manager interface.
|
||||
type ConfigManager interface {
|
||||
DNS() (svc agh.ServiceWithConfig[*dnssvc.Config])
|
||||
Web() (svc agh.ServiceWithConfig[*Config])
|
||||
|
||||
UpdateDNS(ctx context.Context, c *dnssvc.Config) (err error)
|
||||
UpdateWeb(ctx context.Context, c *Config) (err error)
|
||||
}
|
||||
|
||||
// Service is the AdGuard Home web service. A nil *Service is a valid
|
||||
// [agh.Service] that does nothing.
|
||||
type Service struct {
|
||||
confMgr ConfigManager
|
||||
frontend fs.FS
|
||||
tls *tls.Config
|
||||
pprof *http.Server
|
||||
start time.Time
|
||||
overrideAddr netip.AddrPort
|
||||
servers []*http.Server
|
||||
timeout time.Duration
|
||||
pprofPort uint16
|
||||
forceHTTPS bool
|
||||
}
|
||||
|
||||
// New returns a new properly initialized *Service. If c is nil, svc is a nil
|
||||
// *Service that does nothing. The fields of c must not be modified after
|
||||
// calling New.
|
||||
//
|
||||
// TODO(a.garipov): Get rid of this special handling of nil or explain it
|
||||
// better.
|
||||
func New(c *Config) (svc *Service, err error) {
|
||||
if c == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
svc = &Service{
|
||||
confMgr: c.ConfigManager,
|
||||
frontend: c.Frontend,
|
||||
tls: c.TLS,
|
||||
start: c.Start,
|
||||
overrideAddr: c.OverrideAddress,
|
||||
timeout: c.Timeout,
|
||||
forceHTTPS: c.ForceHTTPS,
|
||||
}
|
||||
|
||||
mux := newMux(svc)
|
||||
|
||||
if svc.overrideAddr != (netip.AddrPort{}) {
|
||||
svc.servers = []*http.Server{newSrv(svc.overrideAddr, nil, mux, c.Timeout)}
|
||||
} else {
|
||||
for _, a := range c.Addresses {
|
||||
svc.servers = append(svc.servers, newSrv(a, nil, mux, c.Timeout))
|
||||
}
|
||||
|
||||
for _, a := range c.SecureAddresses {
|
||||
svc.servers = append(svc.servers, newSrv(a, c.TLS, mux, c.Timeout))
|
||||
}
|
||||
}
|
||||
|
||||
svc.setupPprof(c.Pprof)
|
||||
|
||||
return svc, nil
|
||||
}
|
||||
|
||||
// setupPprof sets the pprof properties of svc.
|
||||
func (svc *Service) setupPprof(c *PprofConfig) {
|
||||
if !c.Enabled {
|
||||
// Set to zero explicitly in case pprof used to be enabled before a
|
||||
// reconfiguration took place.
|
||||
runtime.SetBlockProfileRate(0)
|
||||
runtime.SetMutexProfileFraction(0)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
runtime.SetBlockProfileRate(1)
|
||||
runtime.SetMutexProfileFraction(1)
|
||||
|
||||
pprofMux := http.NewServeMux()
|
||||
pprofutil.RoutePprof(pprofMux)
|
||||
|
||||
svc.pprofPort = c.Port
|
||||
addr := netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), c.Port)
|
||||
|
||||
// TODO(a.garipov): Consider making pprof timeout configurable.
|
||||
svc.pprof = newSrv(addr, nil, pprofMux, 10*time.Minute)
|
||||
}
|
||||
|
||||
// newSrv returns a new *http.Server with the given parameters.
|
||||
func newSrv(
|
||||
addr netip.AddrPort,
|
||||
tlsConf *tls.Config,
|
||||
h http.Handler,
|
||||
timeout time.Duration,
|
||||
) (srv *http.Server) {
|
||||
addrStr := addr.String()
|
||||
srv = &http.Server{
|
||||
Addr: addrStr,
|
||||
Handler: h,
|
||||
TLSConfig: tlsConf,
|
||||
ReadTimeout: timeout,
|
||||
WriteTimeout: timeout,
|
||||
IdleTimeout: timeout,
|
||||
ReadHeaderTimeout: timeout,
|
||||
}
|
||||
|
||||
if tlsConf == nil {
|
||||
srv.ErrorLog = log.StdLog("websvc: plain http: "+addrStr, log.ERROR)
|
||||
} else {
|
||||
srv.ErrorLog = log.StdLog("websvc: https: "+addrStr, log.ERROR)
|
||||
}
|
||||
|
||||
return srv
|
||||
}
|
||||
|
||||
// newMux returns a new HTTP request multiplexer for the AdGuard Home web
|
||||
// service.
|
||||
func newMux(svc *Service) (mux *httptreemux.ContextMux) {
|
||||
mux = httptreemux.NewContextMux()
|
||||
|
||||
routes := []struct {
|
||||
handler http.HandlerFunc
|
||||
method string
|
||||
pattern string
|
||||
isJSON bool
|
||||
}{{
|
||||
handler: svc.handleGetHealthCheck,
|
||||
method: http.MethodGet,
|
||||
pattern: PathHealthCheck,
|
||||
isJSON: false,
|
||||
}, {
|
||||
handler: http.FileServer(http.FS(svc.frontend)).ServeHTTP,
|
||||
method: http.MethodGet,
|
||||
pattern: PathFrontend,
|
||||
isJSON: false,
|
||||
}, {
|
||||
handler: http.FileServer(http.FS(svc.frontend)).ServeHTTP,
|
||||
method: http.MethodGet,
|
||||
pattern: PathRoot,
|
||||
isJSON: false,
|
||||
}, {
|
||||
handler: svc.handleGetSettingsAll,
|
||||
method: http.MethodGet,
|
||||
pattern: PathV1SettingsAll,
|
||||
isJSON: true,
|
||||
}, {
|
||||
handler: svc.handlePatchSettingsDNS,
|
||||
method: http.MethodPatch,
|
||||
pattern: PathV1SettingsDNS,
|
||||
isJSON: true,
|
||||
}, {
|
||||
handler: svc.handlePatchSettingsHTTP,
|
||||
method: http.MethodPatch,
|
||||
pattern: PathV1SettingsHTTP,
|
||||
isJSON: true,
|
||||
}, {
|
||||
handler: svc.handleGetV1SystemInfo,
|
||||
method: http.MethodGet,
|
||||
pattern: PathV1SystemInfo,
|
||||
isJSON: true,
|
||||
}}
|
||||
|
||||
for _, r := range routes {
|
||||
var hdlr http.Handler
|
||||
if r.isJSON {
|
||||
hdlr = jsonMw(r.handler)
|
||||
} else {
|
||||
hdlr = r.handler
|
||||
}
|
||||
|
||||
mux.Handle(r.method, r.pattern, logMw(hdlr))
|
||||
}
|
||||
|
||||
return mux
|
||||
}
|
||||
|
||||
// addrs returns all addresses on which this server serves the HTTP API. addrs
|
||||
// must not be called simultaneously with Start. If svc was initialized with
|
||||
// ":0" addresses, addrs will not return the actual bound ports until Start is
|
||||
// finished.
|
||||
func (svc *Service) addrs() (addrs, secureAddrs []netip.AddrPort) {
|
||||
if svc.overrideAddr != (netip.AddrPort{}) {
|
||||
return []netip.AddrPort{svc.overrideAddr}, nil
|
||||
}
|
||||
|
||||
for _, srv := range svc.servers {
|
||||
// Use MustParseAddrPort, since no errors should technically happen
|
||||
// here, because all servers must have a valid address.
|
||||
addrPort := netip.MustParseAddrPort(srv.Addr)
|
||||
|
||||
// [srv.Serve] will set TLSConfig to an almost empty value, so, instead
|
||||
// of relying only on the nilness of TLSConfig, check the length of the
|
||||
// certificates field as well.
|
||||
if srv.TLSConfig == nil || len(srv.TLSConfig.Certificates) == 0 {
|
||||
addrs = append(addrs, addrPort)
|
||||
} else {
|
||||
secureAddrs = append(secureAddrs, addrPort)
|
||||
}
|
||||
}
|
||||
|
||||
return addrs, secureAddrs
|
||||
}
|
||||
|
||||
// handleGetHealthCheck is the handler for the GET /health-check HTTP API.
|
||||
func (svc *Service) handleGetHealthCheck(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = io.WriteString(w, "OK")
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ agh.Service = (*Service)(nil)
|
||||
|
||||
// Start implements the [agh.Service] interface for *Service. svc may be nil.
|
||||
// After Start exits, all HTTP servers have tried to start, possibly failing and
|
||||
// writing error messages to the log.
|
||||
func (svc *Service) Start() (err error) {
|
||||
if svc == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
pprofEnabled := svc.pprof != nil
|
||||
srvNum := len(svc.servers) + mathutil.BoolToNumber[int](pprofEnabled)
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(srvNum)
|
||||
for _, srv := range svc.servers {
|
||||
go serve(srv, wg)
|
||||
}
|
||||
|
||||
if pprofEnabled {
|
||||
go serve(svc.pprof, wg)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// serve starts and runs srv and writes all errors into its log.
|
||||
func serve(srv *http.Server, wg *sync.WaitGroup) {
|
||||
addr := srv.Addr
|
||||
defer log.OnPanic(addr)
|
||||
|
||||
var proto string
|
||||
var l net.Listener
|
||||
var err error
|
||||
if srv.TLSConfig == nil {
|
||||
proto = "http"
|
||||
l, err = net.Listen("tcp", addr)
|
||||
} else {
|
||||
proto = "https"
|
||||
l, err = tls.Listen("tcp", addr, srv.TLSConfig)
|
||||
}
|
||||
if err != nil {
|
||||
srv.ErrorLog.Printf("starting srv %s: binding: %s", addr, err)
|
||||
}
|
||||
|
||||
// Update the server's address in case the address had the port zero, which
|
||||
// would mean that a random available port was automatically chosen.
|
||||
srv.Addr = l.Addr().String()
|
||||
|
||||
log.Info("websvc: starting srv %s://%s", proto, srv.Addr)
|
||||
|
||||
l = &waitListener{
|
||||
Listener: l,
|
||||
firstAcceptWG: wg,
|
||||
}
|
||||
|
||||
err = srv.Serve(l)
|
||||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
srv.ErrorLog.Printf("starting srv %s: %s", addr, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown implements the [agh.Service] interface for *Service. svc may be
|
||||
// nil.
|
||||
func (svc *Service) Shutdown(ctx context.Context) (err error) {
|
||||
if svc == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
defer func() { err = errors.Annotate(err, "shutting down: %w") }()
|
||||
|
||||
var errs []error
|
||||
for _, srv := range svc.servers {
|
||||
shutdownErr := srv.Shutdown(ctx)
|
||||
if shutdownErr != nil {
|
||||
errs = append(errs, fmt.Errorf("srv %s: %w", srv.Addr, shutdownErr))
|
||||
}
|
||||
}
|
||||
|
||||
if svc.pprof != nil {
|
||||
shutdownErr := svc.pprof.Shutdown(ctx)
|
||||
if shutdownErr != nil {
|
||||
errs = append(errs, fmt.Errorf("pprof srv %s: %w", svc.pprof.Addr, shutdownErr))
|
||||
}
|
||||
}
|
||||
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
@@ -1,6 +0,0 @@
|
||||
package websvc
|
||||
|
||||
import "time"
|
||||
|
||||
// testTimeout is the common timeout for tests.
|
||||
const testTimeout = 1 * time.Second
|
||||
@@ -1,196 +0,0 @@
|
||||
package websvc_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/websvc"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/AdguardTeam/golibs/testutil/fakefs"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
testutil.DiscardLogOutput(m)
|
||||
}
|
||||
|
||||
// testTimeout is the common timeout for tests.
|
||||
const testTimeout = 1 * time.Second
|
||||
|
||||
// testStart is the server start value for tests.
|
||||
var testStart = time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
|
||||
// type check
|
||||
var _ websvc.ConfigManager = (*configManager)(nil)
|
||||
|
||||
// configManager is a [websvc.ConfigManager] for tests.
|
||||
type configManager struct {
|
||||
onDNS func() (svc agh.ServiceWithConfig[*dnssvc.Config])
|
||||
onWeb func() (svc agh.ServiceWithConfig[*websvc.Config])
|
||||
|
||||
onUpdateDNS func(ctx context.Context, c *dnssvc.Config) (err error)
|
||||
onUpdateWeb func(ctx context.Context, c *websvc.Config) (err error)
|
||||
}
|
||||
|
||||
// DNS implements the [websvc.ConfigManager] interface for *configManager.
|
||||
func (m *configManager) DNS() (svc agh.ServiceWithConfig[*dnssvc.Config]) {
|
||||
return m.onDNS()
|
||||
}
|
||||
|
||||
// Web implements the [websvc.ConfigManager] interface for *configManager.
|
||||
func (m *configManager) Web() (svc agh.ServiceWithConfig[*websvc.Config]) {
|
||||
return m.onWeb()
|
||||
}
|
||||
|
||||
// UpdateDNS implements the [websvc.ConfigManager] interface for *configManager.
|
||||
func (m *configManager) UpdateDNS(ctx context.Context, c *dnssvc.Config) (err error) {
|
||||
return m.onUpdateDNS(ctx, c)
|
||||
}
|
||||
|
||||
// UpdateWeb implements the [websvc.ConfigManager] interface for *configManager.
|
||||
func (m *configManager) UpdateWeb(ctx context.Context, c *websvc.Config) (err error) {
|
||||
return m.onUpdateWeb(ctx, c)
|
||||
}
|
||||
|
||||
// newConfigManager returns a *configManager all methods of which panic.
|
||||
func newConfigManager() (m *configManager) {
|
||||
return &configManager{
|
||||
onDNS: func() (svc agh.ServiceWithConfig[*dnssvc.Config]) { panic("not implemented") },
|
||||
onWeb: func() (svc agh.ServiceWithConfig[*websvc.Config]) { panic("not implemented") },
|
||||
onUpdateDNS: func(_ context.Context, _ *dnssvc.Config) (err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
onUpdateWeb: func(_ context.Context, _ *websvc.Config) (err error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// newTestServer creates and starts a new web service instance as well as its
|
||||
// sole address. It also registers a cleanup procedure, which shuts the
|
||||
// instance down.
|
||||
//
|
||||
// TODO(a.garipov): Use svc or remove it.
|
||||
func newTestServer(
|
||||
t testing.TB,
|
||||
confMgr websvc.ConfigManager,
|
||||
) (svc *websvc.Service, addr netip.AddrPort) {
|
||||
t.Helper()
|
||||
|
||||
c := &websvc.Config{
|
||||
Pprof: &websvc.PprofConfig{
|
||||
Enabled: false,
|
||||
},
|
||||
ConfigManager: confMgr,
|
||||
Frontend: &fakefs.FS{
|
||||
OnOpen: func(_ string) (_ fs.File, _ error) { return nil, fs.ErrNotExist },
|
||||
},
|
||||
TLS: nil,
|
||||
Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:0")},
|
||||
SecureAddresses: nil,
|
||||
Timeout: testTimeout,
|
||||
Start: testStart,
|
||||
ForceHTTPS: false,
|
||||
}
|
||||
|
||||
svc, err := websvc.New(c)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = svc.Start()
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
err = svc.Shutdown(ctx)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
c = svc.Config()
|
||||
require.NotNil(t, c)
|
||||
require.Len(t, c.Addresses, 1)
|
||||
|
||||
return svc, c.Addresses[0]
|
||||
}
|
||||
|
||||
// jobj is a utility alias for JSON objects.
|
||||
type jobj map[string]any
|
||||
|
||||
// httpGet is a helper that performs an HTTP GET request and returns the body of
|
||||
// the response as well as checks that the status code is correct.
|
||||
//
|
||||
// TODO(a.garipov): Add helpers for other methods.
|
||||
func httpGet(t testing.TB, u *url.URL, wantCode int) (body []byte) {
|
||||
t.Helper()
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, u.String(), nil)
|
||||
require.NoErrorf(t, err, "creating req")
|
||||
|
||||
httpCli := &http.Client{
|
||||
Timeout: testTimeout,
|
||||
}
|
||||
resp, err := httpCli.Do(req)
|
||||
require.NoErrorf(t, err, "performing req")
|
||||
require.Equal(t, wantCode, resp.StatusCode)
|
||||
|
||||
testutil.CleanupAndRequireSuccess(t, resp.Body.Close)
|
||||
|
||||
body, err = io.ReadAll(resp.Body)
|
||||
require.NoErrorf(t, err, "reading body")
|
||||
|
||||
return body
|
||||
}
|
||||
|
||||
// httpPatch is a helper that performs an HTTP PATCH request with JSON-encoded
|
||||
// reqBody as the request body and returns the body of the response as well as
|
||||
// checks that the status code is correct.
|
||||
//
|
||||
// TODO(a.garipov): Add helpers for other methods.
|
||||
func httpPatch(t testing.TB, u *url.URL, reqBody any, wantCode int) (body []byte) {
|
||||
t.Helper()
|
||||
|
||||
b, err := json.Marshal(reqBody)
|
||||
require.NoErrorf(t, err, "marshaling reqBody")
|
||||
|
||||
req, err := http.NewRequest(http.MethodPatch, u.String(), bytes.NewReader(b))
|
||||
require.NoErrorf(t, err, "creating req")
|
||||
|
||||
httpCli := &http.Client{
|
||||
Timeout: testTimeout,
|
||||
}
|
||||
resp, err := httpCli.Do(req)
|
||||
require.NoErrorf(t, err, "performing req")
|
||||
require.Equal(t, wantCode, resp.StatusCode)
|
||||
|
||||
testutil.CleanupAndRequireSuccess(t, resp.Body.Close)
|
||||
|
||||
body, err = io.ReadAll(resp.Body)
|
||||
require.NoErrorf(t, err, "reading body")
|
||||
|
||||
return body
|
||||
}
|
||||
|
||||
func TestService_Start_getHealthCheck(t *testing.T) {
|
||||
confMgr := newConfigManager()
|
||||
_, addr := newTestServer(t, confMgr)
|
||||
u := &url.URL{
|
||||
Scheme: "http",
|
||||
Host: addr.String(),
|
||||
Path: websvc.PathHealthCheck,
|
||||
}
|
||||
|
||||
body := httpGet(t, u, http.StatusOK)
|
||||
|
||||
assert.Equal(t, []byte("OK"), body)
|
||||
}
|
||||
20
main_next.go
20
main_next.go
@@ -1,20 +0,0 @@
|
||||
//go:build next
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"embed"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/cmd"
|
||||
)
|
||||
|
||||
// Embed the prebuilt client here since we strive to keep .go files inside the
|
||||
// internal directory and the embed package is unable to embed files located
|
||||
// outside of the same or underlying directory.
|
||||
|
||||
//go:embed build
|
||||
var frontend embed.FS
|
||||
|
||||
func main() {
|
||||
cmd.Main(frontend)
|
||||
}
|
||||
@@ -35,7 +35,7 @@ set -f -u
|
||||
go_version="$( "${GO:-go}" version )"
|
||||
readonly go_version
|
||||
|
||||
go_min_version='go1.20.11'
|
||||
go_min_version='go1.20.10'
|
||||
go_version_msg="
|
||||
warning: your go version (${go_version}) is different from the recommended minimal one (${go_min_version}).
|
||||
if you have the version installed, please set the GO environment variable.
|
||||
|
||||
Reference in New Issue
Block a user