Compare commits
21 Commits
v0.108.0-b
...
6399-fix-u
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
24a62d0638 | ||
|
|
f81a94eb94 | ||
|
|
ca898fe74e | ||
|
|
366ec81621 | ||
|
|
f9ee511094 | ||
|
|
deedc490e1 | ||
|
|
f8fe9bfc8b | ||
|
|
6cff5865d2 | ||
|
|
cbcc17a58b | ||
|
|
6a3906aa95 | ||
|
|
ffdebc7b2d | ||
|
|
f3817e4411 | ||
|
|
52713a2600 | ||
|
|
62ec0d5adc | ||
|
|
2a56c78f26 | ||
|
|
c0588146e7 | ||
|
|
f6e34adee7 | ||
|
|
e3cc3b0642 | ||
|
|
cd09ba63b6 | ||
|
|
1d1de1bfb5 | ||
|
|
763bbb5e6b |
2
.github/workflows/build.yml
vendored
2
.github/workflows/build.yml
vendored
@@ -1,7 +1,7 @@
|
|||||||
'name': 'build'
|
'name': 'build'
|
||||||
|
|
||||||
'env':
|
'env':
|
||||||
'GO_VERSION': '1.20.10'
|
'GO_VERSION': '1.20.11'
|
||||||
'NODE_VERSION': '16'
|
'NODE_VERSION': '16'
|
||||||
|
|
||||||
'on':
|
'on':
|
||||||
|
|||||||
2
.github/workflows/lint.yml
vendored
2
.github/workflows/lint.yml
vendored
@@ -1,7 +1,7 @@
|
|||||||
'name': 'lint'
|
'name': 'lint'
|
||||||
|
|
||||||
'env':
|
'env':
|
||||||
'GO_VERSION': '1.20.10'
|
'GO_VERSION': '1.20.11'
|
||||||
|
|
||||||
'on':
|
'on':
|
||||||
'push':
|
'push':
|
||||||
|
|||||||
67
CHANGELOG.md
67
CHANGELOG.md
@@ -14,37 +14,73 @@ and this project adheres to
|
|||||||
<!--
|
<!--
|
||||||
## [v0.108.0] - TBA
|
## [v0.108.0] - TBA
|
||||||
|
|
||||||
## [v0.107.40] - 2023-10-25 (APPROX.)
|
## [v0.107.41] - 2023-11-01 (APPROX.)
|
||||||
|
|
||||||
See also the [v0.107.40 GitHub milestone][ms-v0.107.40].
|
See also the [v0.107.41 GitHub milestone][ms-v0.107.41].
|
||||||
|
|
||||||
[ms-v0.107.40]: https://github.com/AdguardTeam/AdGuardHome/milestone/75?closed=1
|
[ms-v0.107.41]: https://github.com/AdguardTeam/AdGuardHome/milestone/76?closed=1
|
||||||
|
|
||||||
NOTE: Add new changes BELOW THIS COMMENT.
|
NOTE: Add new changes BELOW THIS COMMENT.
|
||||||
-->
|
-->
|
||||||
|
|
||||||
|
### Added
|
||||||
|
|
||||||
|
- Ability to specify multiple domain specific upstreams per line, e.g.
|
||||||
|
`[/domain1/../domain2/]upstream1 upstream2 .. upstreamN` ([#4977]).
|
||||||
|
|
||||||
### Changed
|
### Changed
|
||||||
|
|
||||||
- "Block" and "Unblock" buttons of the query log moved to the tooltip menu ([#684]).
|
- The height of ready-to-use filter lists has been increased ([#6358]).
|
||||||
|
- Improved authentication failure logging ([#6357]).
|
||||||
|
|
||||||
|
### Fixed
|
||||||
|
|
||||||
|
- Redundant shortening long client names in the Top Clients table ([#6338]).
|
||||||
|
- Scrolling column headers in the tables ([#6337]).
|
||||||
|
- `$important,dnsrewrite` rules do not take precedence over allowlist rules
|
||||||
|
([#6204]).
|
||||||
|
- Dark mode DNS rewrite background ([#6329]).
|
||||||
|
- Issues with QUIC and HTTP/3 upstreams on Linux ([#6335]).
|
||||||
|
|
||||||
|
[#4977]: https://github.com/AdguardTeam/AdGuardHome/issues/4977
|
||||||
|
[#6204]: https://github.com/AdguardTeam/AdGuardHome/issues/6204
|
||||||
|
[#6329]: https://github.com/AdguardTeam/AdGuardHome/issues/6329
|
||||||
|
[#6335]: https://github.com/AdguardTeam/AdGuardHome/issues/6335
|
||||||
|
[#6337]: https://github.com/AdguardTeam/AdGuardHome/issues/6337
|
||||||
|
[#6338]: https://github.com/AdguardTeam/AdGuardHome/issues/6338
|
||||||
|
[#6357]: https://github.com/AdguardTeam/AdGuardHome/issues/6357
|
||||||
|
[#6358]: https://github.com/AdguardTeam/AdGuardHome/issues/6358
|
||||||
|
|
||||||
|
<!--
|
||||||
|
NOTE: Add new changes ABOVE THIS COMMENT.
|
||||||
|
-->
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## [v0.107.40] - 2023-10-18
|
||||||
|
|
||||||
|
See also the [v0.107.40 GitHub milestone][ms-v0.107.40].
|
||||||
|
|
||||||
|
### Changed
|
||||||
|
|
||||||
|
- *Block* and *Unblock* buttons of the query log moved to the tooltip menu
|
||||||
|
([#684]).
|
||||||
|
|
||||||
### Fixed
|
### Fixed
|
||||||
|
|
||||||
- Dashboard tables scroll issue ([#6180]).
|
- Dashboard tables scroll issue ([#6180]).
|
||||||
|
- The time shown in the statistics is one hour less than the current time
|
||||||
|
([#6296]).
|
||||||
- Issues with QUIC and HTTP/3 upstreams on FreeBSD ([#6301]).
|
- Issues with QUIC and HTTP/3 upstreams on FreeBSD ([#6301]).
|
||||||
- Panic on clearing query log ([#6304]).
|
- Panic on clearing the query log ([#6304]).
|
||||||
- The time shown in the statistics is one hour less than the current time ([#6296]).
|
|
||||||
- Issues with QUIC and HTTP/3 upstreams on FreeBSD ([#6301]).
|
|
||||||
- Panic on clearing query log ([#6304]).
|
|
||||||
|
|
||||||
[#684]: https://github.com/AdguardTeam/AdGuardHome/issues/684
|
[#684]: https://github.com/AdguardTeam/AdGuardHome/issues/684
|
||||||
[#6180]: https://github.com/AdguardTeam/AdGuardHome/issues/6180
|
[#6180]: https://github.com/AdguardTeam/AdGuardHome/issues/6180
|
||||||
[#6296]: https://github.com/AdguardTeam/AdGuardHome/issues/6296
|
[#6296]: https://github.com/AdguardTeam/AdGuardHome/issues/6296
|
||||||
[#6301]: https://github.com/AdguardTeam/AdGuardHome/issues/6301
|
[#6301]: https://github.com/AdguardTeam/AdGuardHome/issues/6301
|
||||||
[#6304]: https://github.com/AdguardTeam/AdGuardHome/issues/6304
|
[#6304]: https://github.com/AdguardTeam/AdGuardHome/issues/6304
|
||||||
|
|
||||||
<!--
|
[ms-v0.107.40]: https://github.com/AdguardTeam/AdGuardHome/milestone/75?closed=1
|
||||||
NOTE: Add new changes ABOVE THIS COMMENT.
|
|
||||||
-->
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -2549,11 +2585,12 @@ See also the [v0.104.2 GitHub milestone][ms-v0.104.2].
|
|||||||
|
|
||||||
|
|
||||||
<!--
|
<!--
|
||||||
[Unreleased]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.40...HEAD
|
[Unreleased]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.41...HEAD
|
||||||
[v0.107.40]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.39...v0.107.40
|
[v0.107.41]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.40...v0.107.41
|
||||||
-->
|
-->
|
||||||
|
|
||||||
[Unreleased]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.39...HEAD
|
[Unreleased]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.40...HEAD
|
||||||
|
[v0.107.40]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.39...v0.107.40
|
||||||
[v0.107.39]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.38...v0.107.39
|
[v0.107.39]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.38...v0.107.39
|
||||||
[v0.107.38]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.37...v0.107.38
|
[v0.107.38]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.37...v0.107.38
|
||||||
[v0.107.37]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.36...v0.107.37
|
[v0.107.37]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.36...v0.107.37
|
||||||
|
|||||||
@@ -201,7 +201,7 @@ opinion, this cannot be legitimately counted as a Pi-Hole's feature.
|
|||||||
| Cross-platform | ✅ | ❌ (not natively, only via Docker) |
|
| Cross-platform | ✅ | ❌ (not natively, only via Docker) |
|
||||||
| Running as a DNS-over-HTTPS or DNS-over-TLS server | ✅ | ❌ (requires additional software) |
|
| Running as a DNS-over-HTTPS or DNS-over-TLS server | ✅ | ❌ (requires additional software) |
|
||||||
| Blocking phishing and malware domains | ✅ | ❌ (requires non-default blocklists) |
|
| Blocking phishing and malware domains | ✅ | ❌ (requires non-default blocklists) |
|
||||||
| Parental control (blocking adult domains) | ✅ | ❌ |
|
| Parental control (blocking adult domains) | ✅ | ❌ (requires non-default blocklists) |
|
||||||
| Force Safe search on search engines | ✅ | ❌ |
|
| Force Safe search on search engines | ✅ | ❌ |
|
||||||
| Per-client (device) configuration | ✅ | ✅ |
|
| Per-client (device) configuration | ✅ | ✅ |
|
||||||
| Access settings (choose who can use AGH DNS) | ✅ | ❌ |
|
| Access settings (choose who can use AGH DNS) | ✅ | ❌ |
|
||||||
|
|||||||
@@ -7,7 +7,7 @@
|
|||||||
# Make sure to sync any changes with the branch overrides below.
|
# Make sure to sync any changes with the branch overrides below.
|
||||||
'variables':
|
'variables':
|
||||||
'channel': 'edge'
|
'channel': 'edge'
|
||||||
'dockerGo': 'adguard/golang-ubuntu:7.4'
|
'dockerGo': 'adguard/golang-ubuntu:7.5'
|
||||||
|
|
||||||
'stages':
|
'stages':
|
||||||
- 'Build frontend':
|
- 'Build frontend':
|
||||||
|
|||||||
@@ -10,7 +10,7 @@
|
|||||||
# Make sure to sync any changes with the branch overrides below.
|
# Make sure to sync any changes with the branch overrides below.
|
||||||
'variables':
|
'variables':
|
||||||
'channel': 'edge'
|
'channel': 'edge'
|
||||||
'dockerGo': 'adguard/golang-ubuntu:7.4'
|
'dockerGo': 'adguard/golang-ubuntu:7.5'
|
||||||
'snapcraftChannel': 'edge'
|
'snapcraftChannel': 'edge'
|
||||||
|
|
||||||
'stages':
|
'stages':
|
||||||
|
|||||||
@@ -5,7 +5,7 @@
|
|||||||
'key': 'AHBRTSPECS'
|
'key': 'AHBRTSPECS'
|
||||||
'name': 'AdGuard Home - Build and run tests'
|
'name': 'AdGuard Home - Build and run tests'
|
||||||
'variables':
|
'variables':
|
||||||
'dockerGo': 'adguard/golang-ubuntu:7.4'
|
'dockerGo': 'adguard/golang-ubuntu:7.5'
|
||||||
|
|
||||||
'stages':
|
'stages':
|
||||||
- 'Tests':
|
- 'Tests':
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
{
|
{
|
||||||
"client_settings": "Client settings",
|
"client_settings": "Client settings",
|
||||||
"example_upstream_reserved": "an upstream <0>for specific domains</0>;",
|
"example_upstream_reserved": "an upstream <0>for specific domains</0>;",
|
||||||
|
"example_multiple_upstreams_reserved": "multiple upstreams <0>for specific domains</0>;",
|
||||||
"example_upstream_comment": "a comment.",
|
"example_upstream_comment": "a comment.",
|
||||||
"upstream_parallel": "Use parallel queries to speed up resolving by querying all upstream servers simultaneously.",
|
"upstream_parallel": "Use parallel queries to speed up resolving by querying all upstream servers simultaneously.",
|
||||||
"parallel_requests": "Parallel requests",
|
"parallel_requests": "Parallel requests",
|
||||||
|
|||||||
@@ -118,6 +118,11 @@ body {
|
|||||||
overflow-y: auto;
|
overflow-y: auto;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.modal-body--filters {
|
||||||
|
max-height: 600px;
|
||||||
|
overflow-y: auto;
|
||||||
|
}
|
||||||
|
|
||||||
.modal-body__item:not(:first-child) {
|
.modal-body__item:not(:first-child) {
|
||||||
padding-top: 1.5rem;
|
padding-top: 1.5rem;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ const renderIcons = (iconsData) => iconsData.map(({
|
|||||||
}) => <a key={iconName} href={href} target="_blank" rel="noopener noreferrer"
|
}) => <a key={iconName} href={href} target="_blank" rel="noopener noreferrer"
|
||||||
className={classNames('d-flex align-items-center', className)}
|
className={classNames('d-flex align-items-center', className)}
|
||||||
>
|
>
|
||||||
<svg className="nav-icon nav-icon--gray">
|
<svg className="icon icon--15 mr-1 icon--gray">
|
||||||
<use xlinkHref={`#${iconName}`} />
|
<use xlinkHref={`#${iconName}`} />
|
||||||
</svg>
|
</svg>
|
||||||
</a>);
|
</a>);
|
||||||
@@ -110,7 +110,7 @@ const Form = (props) => {
|
|||||||
const openAddFiltersModal = () => openModal(MODAL_TYPE.ADD_FILTERS);
|
const openAddFiltersModal = () => openModal(MODAL_TYPE.ADD_FILTERS);
|
||||||
|
|
||||||
return <form onSubmit={handleSubmit}>
|
return <form onSubmit={handleSubmit}>
|
||||||
<div className="modal-body modal-body--medium">
|
<div className="modal-body modal-body--filters">
|
||||||
{modalType === MODAL_TYPE.SELECT_MODAL_TYPE
|
{modalType === MODAL_TYPE.SELECT_MODAL_TYPE
|
||||||
&& <div className="d-flex justify-content-around">
|
&& <div className="d-flex justify-content-around">
|
||||||
<button onClick={openFilteringListModal}
|
<button onClick={openFilteringListModal}
|
||||||
|
|||||||
@@ -80,7 +80,7 @@
|
|||||||
color: var(--gray-f3);
|
color: var(--gray-f3);
|
||||||
}
|
}
|
||||||
|
|
||||||
.logs__text--client {
|
.logs__table .logs__text--client {
|
||||||
padding-right: 32px;
|
padding-right: 32px;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -137,6 +137,22 @@ const Examples = (props) => (
|
|||||||
example_upstream_reserved
|
example_upstream_reserved
|
||||||
</Trans>
|
</Trans>
|
||||||
</li>
|
</li>
|
||||||
|
<li>
|
||||||
|
<code>[/example.local/]94.140.14.140 2a10:50c0::1:ff</code>: <Trans
|
||||||
|
components={[
|
||||||
|
<a
|
||||||
|
href="https://github.com/AdguardTeam/AdGuardHome/wiki/Configuration#upstreams-for-domains"
|
||||||
|
target="_blank"
|
||||||
|
rel="noopener noreferrer"
|
||||||
|
key="0"
|
||||||
|
>
|
||||||
|
Link
|
||||||
|
</a>,
|
||||||
|
]}
|
||||||
|
>
|
||||||
|
example_multiple_upstreams_reserved
|
||||||
|
</Trans>
|
||||||
|
</li>
|
||||||
<li>
|
<li>
|
||||||
<code>{COMMENT_LINE_DEFAULT_TOKEN} comment</code>: <Trans>
|
<code>{COMMENT_LINE_DEFAULT_TOKEN} comment</code>: <Trans>
|
||||||
example_upstream_comment
|
example_upstream_comment
|
||||||
|
|||||||
@@ -149,3 +149,7 @@
|
|||||||
.card .logs__row--blue {
|
.card .logs__row--blue {
|
||||||
background-color: #ecf7ff;
|
background-color: #ecf7ff;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
[data-theme="dark"] .card .logs__row--blue {
|
||||||
|
background-color: var(--logs__row--blue-bgcolor);
|
||||||
|
}
|
||||||
|
|||||||
@@ -24,6 +24,13 @@
|
|||||||
height: var(--size);
|
height: var(--size);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.icon--15 {
|
||||||
|
--size: 0.95rem;
|
||||||
|
|
||||||
|
width: var(--size);
|
||||||
|
height: var(--size);
|
||||||
|
}
|
||||||
|
|
||||||
.icon--gray {
|
.icon--gray {
|
||||||
color: var(--gray-a5);
|
color: var(--gray-a5);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,10 +9,6 @@
|
|||||||
overflow: visible;
|
overflow: visible;
|
||||||
}
|
}
|
||||||
|
|
||||||
.ReactTable .rt-tbody {
|
|
||||||
overflow: visible;
|
|
||||||
}
|
|
||||||
|
|
||||||
.ReactTable .rt-noData {
|
.ReactTable .rt-noData {
|
||||||
color: var(--rt-nodata-color);
|
color: var(--rt-nodata-color);
|
||||||
background-color: var(--rt-nodata-bgcolor);
|
background-color: var(--rt-nodata-bgcolor);
|
||||||
|
|||||||
12
go.mod
12
go.mod
@@ -3,9 +3,9 @@ module github.com/AdguardTeam/AdGuardHome
|
|||||||
go 1.20
|
go 1.20
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/AdguardTeam/dnsproxy v0.56.2
|
github.com/AdguardTeam/dnsproxy v0.56.3
|
||||||
github.com/AdguardTeam/golibs v0.17.1
|
github.com/AdguardTeam/golibs v0.17.2
|
||||||
github.com/AdguardTeam/urlfilter v0.17.0
|
github.com/AdguardTeam/urlfilter v0.17.3
|
||||||
github.com/NYTimes/gziphandler v1.1.1
|
github.com/NYTimes/gziphandler v1.1.1
|
||||||
github.com/ameshkov/dnscrypt/v2 v2.2.7
|
github.com/ameshkov/dnscrypt/v2 v2.2.7
|
||||||
github.com/bluele/gcache v0.0.2
|
github.com/bluele/gcache v0.0.2
|
||||||
@@ -17,7 +17,7 @@ require (
|
|||||||
github.com/google/gopacket v1.1.19
|
github.com/google/gopacket v1.1.19
|
||||||
github.com/google/renameio/v2 v2.0.0
|
github.com/google/renameio/v2 v2.0.0
|
||||||
github.com/google/uuid v1.3.1
|
github.com/google/uuid v1.3.1
|
||||||
github.com/insomniacslk/dhcp v0.0.0-20230908212754-65c27093e38a
|
github.com/insomniacslk/dhcp v0.0.0-20231016090811-6a2c8fbdcc1c
|
||||||
github.com/josharian/native v1.1.1-0.20230202152459-5c7d0dd6ab86
|
github.com/josharian/native v1.1.1-0.20230202152459-5c7d0dd6ab86
|
||||||
github.com/kardianos/service v1.2.2
|
github.com/kardianos/service v1.2.2
|
||||||
github.com/mdlayher/ethernet v0.0.0-20220221185849-529eae5b6118
|
github.com/mdlayher/ethernet v0.0.0-20220221185849-529eae5b6118
|
||||||
@@ -27,9 +27,9 @@ require (
|
|||||||
// own code for that. Perhaps, use gopacket.
|
// own code for that. Perhaps, use gopacket.
|
||||||
github.com/mdlayher/raw v0.1.0
|
github.com/mdlayher/raw v0.1.0
|
||||||
github.com/miekg/dns v1.1.56
|
github.com/miekg/dns v1.1.56
|
||||||
github.com/quic-go/quic-go v0.39.1
|
github.com/quic-go/quic-go v0.39.2
|
||||||
github.com/stretchr/testify v1.8.4
|
github.com/stretchr/testify v1.8.4
|
||||||
github.com/ti-mo/netfilter v0.5.0
|
github.com/ti-mo/netfilter v0.5.1
|
||||||
go.etcd.io/bbolt v1.3.7
|
go.etcd.io/bbolt v1.3.7
|
||||||
golang.org/x/crypto v0.14.0
|
golang.org/x/crypto v0.14.0
|
||||||
golang.org/x/exp v0.0.0-20231006140011-7918f672742d
|
golang.org/x/exp v0.0.0-20231006140011-7918f672742d
|
||||||
|
|||||||
24
go.sum
24
go.sum
@@ -1,9 +1,9 @@
|
|||||||
github.com/AdguardTeam/dnsproxy v0.56.2 h1:+k1iUmp05QIqkgXWyPn70fki4FouHe6vHIyHguelKao=
|
github.com/AdguardTeam/dnsproxy v0.56.3 h1:WP1FooLfZQPHEH2SuwMtJsOurDt32rubGx0OddcsKT0=
|
||||||
github.com/AdguardTeam/dnsproxy v0.56.2/go.mod h1:ZvkbM71HwpilgkCnTubDiR4Ba6x5Qvnhy2iasMWaTDM=
|
github.com/AdguardTeam/dnsproxy v0.56.3/go.mod h1:ZvkbM71HwpilgkCnTubDiR4Ba6x5Qvnhy2iasMWaTDM=
|
||||||
github.com/AdguardTeam/golibs v0.17.1 h1:j3Ehhld5GI/amcHYG+CF0sJ4OOzAQ06BY3N/iBYJZ1M=
|
github.com/AdguardTeam/golibs v0.17.2 h1:vg6wHMjUKscnyPGRvxS5kAt7Uw4YxcJiITZliZ476W8=
|
||||||
github.com/AdguardTeam/golibs v0.17.1/go.mod h1:DKhCIXHcUYtBhU8ibTLKh1paUL96n5zhQBlx763sj+U=
|
github.com/AdguardTeam/golibs v0.17.2/go.mod h1:DKhCIXHcUYtBhU8ibTLKh1paUL96n5zhQBlx763sj+U=
|
||||||
github.com/AdguardTeam/urlfilter v0.17.0 h1:tUzhtR9wMx704GIP3cibsDQJrixlMHfwoQbYJfPdFow=
|
github.com/AdguardTeam/urlfilter v0.17.3 h1:fg/ObbnO0Cv6aw0tW6N/ETDMhhNvmcUUOZ7HlmKC3rw=
|
||||||
github.com/AdguardTeam/urlfilter v0.17.0/go.mod h1:bbuZjPUzm/Ip+nz5qPPbwIP+9rZyQbQad8Lt/0fCulU=
|
github.com/AdguardTeam/urlfilter v0.17.3/go.mod h1:Jru7jFfeH2CoDf150uDs+rRYcZBzHHBz05r9REyDKyE=
|
||||||
github.com/NYTimes/gziphandler v1.1.1 h1:ZUDjpQae29j0ryrS0u/B8HZfJBtBQHjqw2rQ2cqUQ3I=
|
github.com/NYTimes/gziphandler v1.1.1 h1:ZUDjpQae29j0ryrS0u/B8HZfJBtBQHjqw2rQ2cqUQ3I=
|
||||||
github.com/NYTimes/gziphandler v1.1.1/go.mod h1:n/CVRwUEOgIxrgPvAQhUUr9oeUtvrhMomdKFjzJNB0c=
|
github.com/NYTimes/gziphandler v1.1.1/go.mod h1:n/CVRwUEOgIxrgPvAQhUUr9oeUtvrhMomdKFjzJNB0c=
|
||||||
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY=
|
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY=
|
||||||
@@ -49,8 +49,8 @@ github.com/google/uuid v1.2.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+
|
|||||||
github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4=
|
github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4=
|
||||||
github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
github.com/hugelgupf/socketpair v0.0.0-20190730060125-05d35a94e714 h1:/jC7qQFrv8CrSJVmaolDVOxTfS9kc36uB6H40kdbQq8=
|
github.com/hugelgupf/socketpair v0.0.0-20190730060125-05d35a94e714 h1:/jC7qQFrv8CrSJVmaolDVOxTfS9kc36uB6H40kdbQq8=
|
||||||
github.com/insomniacslk/dhcp v0.0.0-20230908212754-65c27093e38a h1:S33o3djA1nPRd+d/bf7jbbXytXuK/EoXow7+aa76grQ=
|
github.com/insomniacslk/dhcp v0.0.0-20231016090811-6a2c8fbdcc1c h1:PgxFEySCI41sH0mB7/2XswdXbUykQsRUGod8Rn+NubM=
|
||||||
github.com/insomniacslk/dhcp v0.0.0-20230908212754-65c27093e38a/go.mod h1:zmdm3sTSDP3vOOX3CEWRkkRHtKr1DxBx+J1OQFoDQQs=
|
github.com/insomniacslk/dhcp v0.0.0-20231016090811-6a2c8fbdcc1c/go.mod h1:3A9PQ1cunSDF/1rbTq99Ts4pVnycWg+vlPkfeD2NLFI=
|
||||||
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
|
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
|
||||||
github.com/josharian/native v1.0.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
|
github.com/josharian/native v1.0.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
|
||||||
github.com/josharian/native v1.0.1-0.20221213033349-c1e37c09b531/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
|
github.com/josharian/native v1.0.1-0.20221213033349-c1e37c09b531/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
|
||||||
@@ -94,8 +94,8 @@ github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo=
|
|||||||
github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A=
|
github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A=
|
||||||
github.com/quic-go/qtls-go1-20 v0.3.4 h1:MfFAPULvst4yoMgY9QmtpYmfij/em7O8UUi+bNVm7Cg=
|
github.com/quic-go/qtls-go1-20 v0.3.4 h1:MfFAPULvst4yoMgY9QmtpYmfij/em7O8UUi+bNVm7Cg=
|
||||||
github.com/quic-go/qtls-go1-20 v0.3.4/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k=
|
github.com/quic-go/qtls-go1-20 v0.3.4/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k=
|
||||||
github.com/quic-go/quic-go v0.39.1 h1:d/m3oaN/SD2c+f7/yEjZxe2zEVotXprnrCCJ2y/ZZFE=
|
github.com/quic-go/quic-go v0.39.2 h1:hmwAf8zAHlvan0Y5PXxeeBFZEW17IW99sXLry8I2kjk=
|
||||||
github.com/quic-go/quic-go v0.39.1/go.mod h1:T09QsDQWjLiQ74ZmacDfqZmhY/NLnw5BC40MANNNZ1Q=
|
github.com/quic-go/quic-go v0.39.2/go.mod h1:T09QsDQWjLiQ74ZmacDfqZmhY/NLnw5BC40MANNNZ1Q=
|
||||||
github.com/shirou/gopsutil/v3 v3.23.7 h1:C+fHO8hfIppoJ1WdsVm1RoI0RwXoNdfTK7yWXV0wVj4=
|
github.com/shirou/gopsutil/v3 v3.23.7 h1:C+fHO8hfIppoJ1WdsVm1RoI0RwXoNdfTK7yWXV0wVj4=
|
||||||
github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFtM=
|
github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFtM=
|
||||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
@@ -105,8 +105,8 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
|
|||||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||||
github.com/ti-mo/netfilter v0.2.0/go.mod h1:8GbBGsY/8fxtyIdfwy29JiluNcPK4K7wIT+x42ipqUU=
|
github.com/ti-mo/netfilter v0.2.0/go.mod h1:8GbBGsY/8fxtyIdfwy29JiluNcPK4K7wIT+x42ipqUU=
|
||||||
github.com/ti-mo/netfilter v0.5.0 h1:MZmsUw5bFRecOb0AeyjOPxTHg4UxYzyEs0Ek/6Lxoy8=
|
github.com/ti-mo/netfilter v0.5.1 h1:cqamEd1c1zmpfpqvInLOro0Znq/RAfw2QL5wL2rAR/8=
|
||||||
github.com/ti-mo/netfilter v0.5.0/go.mod h1:nt+8B9hx/QpqHr7Hazq+2qMCCA8u2OTkyc/7+U9ARz8=
|
github.com/ti-mo/netfilter v0.5.1/go.mod h1:h9UPQ3ZrTZGBitay+LETMxZvNgWGK/efTUcqES2YiLw=
|
||||||
github.com/tklauser/go-sysconf v0.3.11 h1:89WgdJhk5SNwJfu+GKyYveZ4IaJ7xAkecBo+KdJV0CM=
|
github.com/tklauser/go-sysconf v0.3.11 h1:89WgdJhk5SNwJfu+GKyYveZ4IaJ7xAkecBo+KdJV0CM=
|
||||||
github.com/tklauser/numcpus v0.6.0 h1:kebhY2Qt+3U6RNK7UqpYNA+tJ23IBEGKkB7JQBfDYms=
|
github.com/tklauser/numcpus v0.6.0 h1:kebhY2Qt+3U6RNK7UqpYNA+tJ23IBEGKkB7JQBfDYms=
|
||||||
github.com/u-root/uio v0.0.0-20230305220412-3e8cd9d6bf63 h1:YcojQL98T/OO+rybuzn2+5KrD5dBwXIvYBvQ2cD3Avg=
|
github.com/u-root/uio v0.0.0-20230305220412-3e8cd9d6bf63 h1:YcojQL98T/OO+rybuzn2+5KrD5dBwXIvYBvQ2cD3Avg=
|
||||||
|
|||||||
@@ -182,6 +182,7 @@ func (s *Server) accessListJSON() (j accessListJSON) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// handleAccessList handles requests to the GET /control/access/list endpoint.
|
||||||
func (s *Server) handleAccessList(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) handleAccessList(w http.ResponseWriter, r *http.Request) {
|
||||||
aghhttp.WriteJSONResponseOK(w, r, s.accessListJSON())
|
aghhttp.WriteJSONResponseOK(w, r, s.accessListJSON())
|
||||||
}
|
}
|
||||||
@@ -224,6 +225,7 @@ func validateStrUniq(clients []string) (uc aghalg.UniqChecker[string], err error
|
|||||||
return uc, uc.Validate()
|
return uc, uc.Validate()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// handleAccessSet handles requests to the POST /control/access/set endpoint.
|
||||||
func (s *Server) handleAccessSet(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) handleAccessSet(w http.ResponseWriter, r *http.Request) {
|
||||||
list := &accessListJSON{}
|
list := &accessListJSON{}
|
||||||
err := json.NewDecoder(r.Body).Decode(&list)
|
err := json.NewDecoder(r.Body).Decode(&list)
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
|
|
||||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/AdguardTeam/golibs/netutil"
|
"github.com/AdguardTeam/golibs/netutil"
|
||||||
"github.com/quic-go/quic-go"
|
"github.com/quic-go/quic-go"
|
||||||
)
|
)
|
||||||
@@ -151,6 +152,8 @@ func (s *Server) clientIDFromDNSContext(pctx *proxy.DNSContext) (clientID string
|
|||||||
// DNS-over-HTTPS requests, it will return the hostname part of the Host header
|
// DNS-over-HTTPS requests, it will return the hostname part of the Host header
|
||||||
// if there is one.
|
// if there is one.
|
||||||
func clientServerName(pctx *proxy.DNSContext, proto proxy.Proto) (srvName string, err error) {
|
func clientServerName(pctx *proxy.DNSContext, proto proxy.Proto) (srvName string, err error) {
|
||||||
|
from := "tls conn"
|
||||||
|
|
||||||
switch proto {
|
switch proto {
|
||||||
case proxy.ProtoHTTPS:
|
case proxy.ProtoHTTPS:
|
||||||
r := pctx.HTTPRequest
|
r := pctx.HTTPRequest
|
||||||
@@ -164,6 +167,7 @@ func clientServerName(pctx *proxy.DNSContext, proto proxy.Proto) (srvName string
|
|||||||
}
|
}
|
||||||
|
|
||||||
srvName = host
|
srvName = host
|
||||||
|
from = "host header"
|
||||||
}
|
}
|
||||||
case proxy.ProtoQUIC:
|
case proxy.ProtoQUIC:
|
||||||
qConn := pctx.QUICConnection
|
qConn := pctx.QUICConnection
|
||||||
@@ -183,5 +187,7 @@ func clientServerName(pctx *proxy.DNSContext, proto proxy.Proto) (srvName string
|
|||||||
srvName = tc.ConnectionState().ServerName
|
srvName = tc.ConnectionState().ServerName
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Debug("dnsforward: got client server name %q from %s", srvName, from)
|
||||||
|
|
||||||
return srvName, nil
|
return srvName, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,6 +7,8 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||||
@@ -444,19 +446,10 @@ func newUpstreamConfig(upstreams []string) (conf *proxy.UpstreamConfig, err erro
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, u := range upstreams {
|
err = validateUpstreamConfig(upstreams)
|
||||||
var ups string
|
if err != nil {
|
||||||
var domains []string
|
// Don't wrap the error since it's informative enough as is.
|
||||||
ups, domains, err = separateUpstream(u)
|
return nil, err
|
||||||
if err != nil {
|
|
||||||
// Don't wrap the error since it's informative enough as is.
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = validateUpstream(ups, domains)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("validating upstream %q: %w", u, err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
conf, err = proxy.ParseUpstreamsConfig(
|
conf, err = proxy.ParseUpstreamsConfig(
|
||||||
@@ -467,6 +460,7 @@ func newUpstreamConfig(upstreams []string) (conf *proxy.UpstreamConfig, err erro
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
// Don't wrap the error since it's informative enough as is.
|
||||||
return nil, err
|
return nil, err
|
||||||
} else if len(conf.Upstreams) == 0 {
|
} else if len(conf.Upstreams) == 0 {
|
||||||
return nil, errors.Error("no default upstreams specified")
|
return nil, errors.Error("no default upstreams specified")
|
||||||
@@ -475,6 +469,31 @@ func newUpstreamConfig(upstreams []string) (conf *proxy.UpstreamConfig, err erro
|
|||||||
return conf, nil
|
return conf, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// validateUpstreamConfig validates each upstream from the upstream
|
||||||
|
// configuration and returns an error if any upstream is invalid.
|
||||||
|
//
|
||||||
|
// TODO(e.burkov): Move into aghnet or even into dnsproxy.
|
||||||
|
func validateUpstreamConfig(conf []string) (err error) {
|
||||||
|
for _, u := range conf {
|
||||||
|
var ups []string
|
||||||
|
var domains []string
|
||||||
|
ups, domains, err = separateUpstream(u)
|
||||||
|
if err != nil {
|
||||||
|
// Don't wrap the error since it's informative enough as is.
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, addr := range ups {
|
||||||
|
_, err = validateUpstream(addr, len(domains) > 0)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("validating upstream %q: %w", addr, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// ValidateUpstreams validates each upstream and returns an error if any
|
// ValidateUpstreams validates each upstream and returns an error if any
|
||||||
// upstream is invalid or if there are no default upstreams specified.
|
// upstream is invalid or if there are no default upstreams specified.
|
||||||
//
|
//
|
||||||
@@ -534,14 +553,14 @@ var protocols = []string{
|
|||||||
}
|
}
|
||||||
|
|
||||||
// validateUpstream returns an error if u alongside with domains is not a valid
|
// validateUpstream returns an error if u alongside with domains is not a valid
|
||||||
// upstream configuration. useDefault is true if the upstream is
|
// upstream configuration. usesDefault is true if the upstream is
|
||||||
// domain-specific and is configured to point at the default upstream server
|
// domain-specific and is configured to point at the default upstream server
|
||||||
// which is validated separately. The upstream is considered domain-specific
|
// which is validated separately. specific reflects if the upstream is
|
||||||
// only if domains is at least not nil.
|
// domain-specific.
|
||||||
func validateUpstream(u string, domains []string) (useDefault bool, err error) {
|
func validateUpstream(u string, specific bool) (usesDefault bool, err error) {
|
||||||
// The special server address '#' means that default server must be used.
|
// The special server address '#' means that default server must be used.
|
||||||
if useDefault = u == "#" && domains != nil; useDefault {
|
if u == "#" && specific {
|
||||||
return useDefault, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if the upstream has a valid protocol prefix.
|
// Check if the upstream has a valid protocol prefix.
|
||||||
@@ -567,12 +586,12 @@ func validateUpstream(u string, domains []string) (useDefault bool, err error) {
|
|||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// separateUpstream returns the upstream and the specified domains. domains is
|
// separateUpstream returns the upstreams and the specified domains. domains
|
||||||
// nil when the upstream is not domains-specific. Otherwise it may also be
|
// is nil when the upstream is not domains-specific. Otherwise it may also be
|
||||||
// empty.
|
// empty.
|
||||||
func separateUpstream(upstreamStr string) (ups string, domains []string, err error) {
|
func separateUpstream(upstreamStr string) (upstreams, domains []string, err error) {
|
||||||
if !strings.HasPrefix(upstreamStr, "[/") {
|
if !strings.HasPrefix(upstreamStr, "[/") {
|
||||||
return upstreamStr, nil, nil
|
return []string{upstreamStr}, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
defer func() { err = errors.Annotate(err, "bad upstream for domain %q: %w", upstreamStr) }()
|
defer func() { err = errors.Annotate(err, "bad upstream for domain %q: %w", upstreamStr) }()
|
||||||
@@ -582,9 +601,9 @@ func separateUpstream(upstreamStr string) (ups string, domains []string, err err
|
|||||||
case 2:
|
case 2:
|
||||||
// Go on.
|
// Go on.
|
||||||
case 1:
|
case 1:
|
||||||
return "", nil, errors.Error("missing separator")
|
return nil, nil, errors.Error("missing separator")
|
||||||
default:
|
default:
|
||||||
return "", []string{}, errors.Error("duplicated separator")
|
return nil, nil, errors.Error("duplicated separator")
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, host := range strings.Split(parts[0], "/") {
|
for i, host := range strings.Split(parts[0], "/") {
|
||||||
@@ -594,21 +613,22 @@ func separateUpstream(upstreamStr string) (ups string, domains []string, err err
|
|||||||
|
|
||||||
err = netutil.ValidateDomainName(strings.TrimPrefix(host, "*."))
|
err = netutil.ValidateDomainName(strings.TrimPrefix(host, "*."))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", domains, fmt.Errorf("domain at index %d: %w", i, err)
|
return nil, nil, fmt.Errorf("domain at index %d: %w", i, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
domains = append(domains, host)
|
domains = append(domains, host)
|
||||||
}
|
}
|
||||||
|
|
||||||
return parts[1], domains, nil
|
return strings.Fields(parts[1]), domains, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// healthCheckFunc is a signature of function to check if upstream exchanges
|
// healthCheckFunc is a signature of function to check if upstream exchanges
|
||||||
// properly.
|
// properly.
|
||||||
type healthCheckFunc func(u upstream.Upstream) (err error)
|
type healthCheckFunc func(u upstream.Upstream) (err error)
|
||||||
|
|
||||||
// checkDNSUpstreamExc checks if the DNS upstream exchanges correctly.
|
// checkExchange is a [healthCheckFunc] that checks if the DNS upstream
|
||||||
func checkDNSUpstreamExc(u upstream.Upstream) (err error) {
|
// exchanges correctly.
|
||||||
|
func checkExchange(u upstream.Upstream) (err error) {
|
||||||
// testTLD is the special-use fully-qualified domain name for testing the
|
// testTLD is the special-use fully-qualified domain name for testing the
|
||||||
// DNS server reachability.
|
// DNS server reachability.
|
||||||
//
|
//
|
||||||
@@ -638,11 +658,11 @@ func checkDNSUpstreamExc(u upstream.Upstream) (err error) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// checkPrivateUpstreamExc checks if the upstream for resolving private
|
// checkPrivateExchange is a [healthCheckFunc] that checks if the upstream for
|
||||||
// addresses exchanges correctly.
|
// resolving private addresses exchanges correctly.
|
||||||
//
|
//
|
||||||
// TODO(e.burkov): Think about testing the ip6.arpa. as well.
|
// TODO(e.burkov): Think about testing the ip6.arpa. as well.
|
||||||
func checkPrivateUpstreamExc(u upstream.Upstream) (err error) {
|
func checkPrivateExchange(u upstream.Upstream) (err error) {
|
||||||
// inAddrArpaTLD is the special-use fully-qualified domain name for PTR IP
|
// inAddrArpaTLD is the special-use fully-qualified domain name for PTR IP
|
||||||
// address resolution.
|
// address resolution.
|
||||||
//
|
//
|
||||||
@@ -683,75 +703,153 @@ func (err domainSpecificTestError) Error() (msg string) {
|
|||||||
return fmt.Sprintf("WARNING: %s", err.error)
|
return fmt.Sprintf("WARNING: %s", err.error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseUpstreamLine parses line and creates the [upstream.Upstream] using opts
|
// checkUpstreamAddr creates the upstream using opts and, possibly, information
|
||||||
// and information from [s.dnsFilter.EtcHosts]. It returns an error if the line
|
// from system hosts files, then checks if the DNS upstream exchanges correctly.
|
||||||
// is not a valid upstream line, see [upstream.AddressToUpstream]. It's a
|
// It returns an error if addr is not valid DNS upstream address or the upstream
|
||||||
// caller's responsibility to close u.
|
// is not exchanging correctly.
|
||||||
func (s *Server) parseUpstreamLine(
|
//
|
||||||
line string,
|
// TODO(e.burkov): Remove the receiver.
|
||||||
opts *upstream.Options,
|
func (s *Server) checkUpstreamAddr(
|
||||||
) (u upstream.Upstream, specific bool, err error) {
|
addr string,
|
||||||
// Separate upstream from domains list.
|
specific bool,
|
||||||
upstreamAddr, domains, err := separateUpstream(line)
|
basicOpts *upstream.Options,
|
||||||
|
check healthCheckFunc,
|
||||||
|
) (err error) {
|
||||||
|
usesDefault, err := validateUpstream(addr, specific)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, false, fmt.Errorf("wrong upstream format: %w", err)
|
return fmt.Errorf("wrong upstream format: %w", err)
|
||||||
}
|
} else if usesDefault {
|
||||||
|
|
||||||
specific = len(domains) > 0
|
|
||||||
|
|
||||||
useDefault, err := validateUpstream(upstreamAddr, domains)
|
|
||||||
if err != nil {
|
|
||||||
return nil, specific, fmt.Errorf("wrong upstream format: %w", err)
|
|
||||||
} else if useDefault {
|
|
||||||
return nil, specific, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debug("dnsforward: checking if upstream %q works", upstreamAddr)
|
|
||||||
|
|
||||||
opts = &upstream.Options{
|
|
||||||
Bootstrap: opts.Bootstrap,
|
|
||||||
Timeout: opts.Timeout,
|
|
||||||
PreferIPv6: opts.PreferIPv6,
|
|
||||||
}
|
|
||||||
|
|
||||||
// dnsFilter can be nil during application update.
|
|
||||||
if s.dnsFilter != nil {
|
|
||||||
recs := s.dnsFilter.EtcHostsRecords(extractUpstreamHost(upstreamAddr))
|
|
||||||
for _, rec := range recs {
|
|
||||||
opts.ServerIPAddrs = append(opts.ServerIPAddrs, rec.Addr.AsSlice())
|
|
||||||
}
|
|
||||||
sortNetIPAddrs(opts.ServerIPAddrs, opts.PreferIPv6)
|
|
||||||
}
|
|
||||||
u, err = upstream.AddressToUpstream(upstreamAddr, opts)
|
|
||||||
if err != nil {
|
|
||||||
return nil, specific, fmt.Errorf("creating upstream for %q: %w", upstreamAddr, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return u, specific, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) checkDNS(line string, opts *upstream.Options, check healthCheckFunc) (err error) {
|
|
||||||
if IsCommentOrEmpty(line) {
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var u upstream.Upstream
|
log.Debug("dnsforward: checking if upstream %q works", addr)
|
||||||
var specific bool
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if err != nil && specific {
|
if err != nil && specific {
|
||||||
err = domainSpecificTestError{error: err}
|
err = domainSpecificTestError{error: err}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
u, specific, err = s.parseUpstreamLine(line, opts)
|
opts := &upstream.Options{
|
||||||
if err != nil || u == nil {
|
Bootstrap: basicOpts.Bootstrap,
|
||||||
return err
|
Timeout: basicOpts.Timeout,
|
||||||
|
PreferIPv6: basicOpts.PreferIPv6,
|
||||||
|
}
|
||||||
|
|
||||||
|
// dnsFilter can be nil during application update.
|
||||||
|
//
|
||||||
|
// TODO(e.burkov): Remove when update dnsproxy.
|
||||||
|
if s.dnsFilter != nil {
|
||||||
|
recs := s.dnsFilter.EtcHostsRecords(extractUpstreamHost(addr))
|
||||||
|
for _, rec := range recs {
|
||||||
|
opts.ServerIPAddrs = append(opts.ServerIPAddrs, rec.Addr.AsSlice())
|
||||||
|
}
|
||||||
|
sortNetIPAddrs(opts.ServerIPAddrs, opts.PreferIPv6)
|
||||||
|
}
|
||||||
|
|
||||||
|
u, err := upstream.AddressToUpstream(addr, opts)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("creating upstream for %q: %w", addr, err)
|
||||||
}
|
}
|
||||||
defer func() { err = errors.WithDeferred(err, u.Close()) }()
|
defer func() { err = errors.WithDeferred(err, u.Close()) }()
|
||||||
|
|
||||||
return check(u)
|
return check(u)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// checkResult is a result of checking an upstream server.
|
||||||
|
type checkResult = struct {
|
||||||
|
// status is an error message if the upstream server is not working. It's
|
||||||
|
// nil for working upstreams.
|
||||||
|
status error
|
||||||
|
|
||||||
|
// address is the upstream server address as given in the request. It may
|
||||||
|
// appear to be a whole line if it's incorrect itself.
|
||||||
|
address string
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkDNS parses an upstream configuration line using opts and checks if the
|
||||||
|
// specified upstreams are working using check. countWG is decremented when the
|
||||||
|
// expected number of results added to resNum, then results are sent to resCh.
|
||||||
|
//
|
||||||
|
// TODO(e.burkov): Remove the receiver.
|
||||||
|
func (s *Server) checkDNS(
|
||||||
|
line string,
|
||||||
|
opts *upstream.Options,
|
||||||
|
check healthCheckFunc,
|
||||||
|
countWG *sync.WaitGroup,
|
||||||
|
resNum *atomic.Int32,
|
||||||
|
resCh chan<- checkResult,
|
||||||
|
) {
|
||||||
|
defer log.OnPanic("dnsforward: checking upstreams")
|
||||||
|
|
||||||
|
addrs, domains, err := separateUpstream(line)
|
||||||
|
if err != nil {
|
||||||
|
resNum.Add(1)
|
||||||
|
countWG.Done()
|
||||||
|
|
||||||
|
resCh <- checkResult{
|
||||||
|
address: line,
|
||||||
|
status: fmt.Errorf("wrong upstream format: %w", err),
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resNum.Add(int32(len(addrs)))
|
||||||
|
countWG.Done()
|
||||||
|
|
||||||
|
specific := len(domains) > 0
|
||||||
|
for _, addr := range addrs {
|
||||||
|
resCh <- checkResult{
|
||||||
|
address: addr,
|
||||||
|
status: s.checkUpstreamAddr(addr, specific, opts, check),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// check returns the mapping of upstream addresses to their check results.
|
||||||
|
func (s *Server) check(req *upstreamJSON, opts *upstream.Options) (result map[string]string) {
|
||||||
|
req.Upstreams = stringutil.FilterOut(req.Upstreams, IsCommentOrEmpty)
|
||||||
|
req.FallbackDNS = stringutil.FilterOut(req.FallbackDNS, IsCommentOrEmpty)
|
||||||
|
req.PrivateUpstreams = stringutil.FilterOut(req.PrivateUpstreams, IsCommentOrEmpty)
|
||||||
|
|
||||||
|
countWG := &sync.WaitGroup{}
|
||||||
|
countWG.Add(len(req.Upstreams) + len(req.FallbackDNS) + len(req.PrivateUpstreams))
|
||||||
|
|
||||||
|
resNum := &atomic.Int32{}
|
||||||
|
resCh := make(chan checkResult)
|
||||||
|
|
||||||
|
for _, addr := range req.Upstreams {
|
||||||
|
go s.checkDNS(addr, opts, checkExchange, countWG, resNum, resCh)
|
||||||
|
}
|
||||||
|
for _, addr := range req.FallbackDNS {
|
||||||
|
go s.checkDNS(addr, opts, checkExchange, countWG, resNum, resCh)
|
||||||
|
}
|
||||||
|
for _, addr := range req.PrivateUpstreams {
|
||||||
|
go s.checkDNS(addr, opts, checkPrivateExchange, countWG, resNum, resCh)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait until all the servers are counted and enqueued.
|
||||||
|
countWG.Wait()
|
||||||
|
n := resNum.Load()
|
||||||
|
|
||||||
|
result = make(map[string]string, n)
|
||||||
|
for i := int32(0); i < n; i++ {
|
||||||
|
// TODO(e.burkov): Upstreams intended for different purposes should
|
||||||
|
// be distinguished in the result, even if specified equally.
|
||||||
|
res := <-resCh
|
||||||
|
if res.status != nil {
|
||||||
|
result[res.address] = res.status.Error()
|
||||||
|
} else {
|
||||||
|
result[res.address] = "OK"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleTestUpstreamDNS handles requests to the POST /control/test_upstream_dns
|
||||||
|
// endpoint.
|
||||||
func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||||
req := &upstreamJSON{}
|
req := &upstreamJSON{}
|
||||||
err := json.NewDecoder(r.Body).Decode(req)
|
err := json.NewDecoder(r.Body).Decode(req)
|
||||||
@@ -761,65 +859,18 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bootstrapAddrs := stringutil.FilterOut(req.BootstrapDNS, IsCommentOrEmpty)
|
||||||
|
if len(bootstrapAddrs) == 0 {
|
||||||
|
bootstrapAddrs = defaultBootstrap
|
||||||
|
}
|
||||||
|
|
||||||
opts := &upstream.Options{
|
opts := &upstream.Options{
|
||||||
Bootstrap: req.BootstrapDNS,
|
Bootstrap: bootstrapAddrs,
|
||||||
Timeout: s.conf.UpstreamTimeout,
|
Timeout: s.conf.UpstreamTimeout,
|
||||||
PreferIPv6: s.conf.BootstrapPreferIPv6,
|
PreferIPv6: s.conf.BootstrapPreferIPv6,
|
||||||
}
|
}
|
||||||
if len(opts.Bootstrap) == 0 {
|
|
||||||
opts.Bootstrap = defaultBootstrap
|
|
||||||
}
|
|
||||||
|
|
||||||
type upsCheckResult = struct {
|
aghhttp.WriteJSONResponseOK(w, r, s.check(req, opts))
|
||||||
err error
|
|
||||||
host string
|
|
||||||
}
|
|
||||||
|
|
||||||
req.Upstreams = stringutil.FilterOut(req.Upstreams, IsCommentOrEmpty)
|
|
||||||
req.FallbackDNS = stringutil.FilterOut(req.FallbackDNS, IsCommentOrEmpty)
|
|
||||||
req.PrivateUpstreams = stringutil.FilterOut(req.PrivateUpstreams, IsCommentOrEmpty)
|
|
||||||
|
|
||||||
upsNum := len(req.Upstreams) + len(req.FallbackDNS) + len(req.PrivateUpstreams)
|
|
||||||
result := make(map[string]string, upsNum)
|
|
||||||
resCh := make(chan upsCheckResult, upsNum)
|
|
||||||
|
|
||||||
for _, ups := range req.Upstreams {
|
|
||||||
go func(ups string) {
|
|
||||||
resCh <- upsCheckResult{
|
|
||||||
host: ups,
|
|
||||||
err: s.checkDNS(ups, opts, checkDNSUpstreamExc),
|
|
||||||
}
|
|
||||||
}(ups)
|
|
||||||
}
|
|
||||||
for _, ups := range req.FallbackDNS {
|
|
||||||
go func(ups string) {
|
|
||||||
resCh <- upsCheckResult{
|
|
||||||
host: ups,
|
|
||||||
err: s.checkDNS(ups, opts, checkDNSUpstreamExc),
|
|
||||||
}
|
|
||||||
}(ups)
|
|
||||||
}
|
|
||||||
for _, ups := range req.PrivateUpstreams {
|
|
||||||
go func(ups string) {
|
|
||||||
resCh <- upsCheckResult{
|
|
||||||
host: ups,
|
|
||||||
err: s.checkDNS(ups, opts, checkPrivateUpstreamExc),
|
|
||||||
}
|
|
||||||
}(ups)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 0; i < upsNum; i++ {
|
|
||||||
// TODO(e.burkov): The upstreams used for both common and private
|
|
||||||
// resolving should be reported separately.
|
|
||||||
pair := <-resCh
|
|
||||||
if pair.err != nil {
|
|
||||||
result[pair.host] = pair.err.Error()
|
|
||||||
} else {
|
|
||||||
result[pair.host] = "OK"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
aghhttp.WriteJSONResponseOK(w, r, result)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleCacheClear is the handler for the POST /control/cache_clear HTTP API.
|
// handleCacheClear is the handler for the POST /control/cache_clear HTTP API.
|
||||||
|
|||||||
@@ -49,13 +49,18 @@ func loadTestData(t *testing.T, casesFileName string, cases any) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
const jsonExt = ".json"
|
const (
|
||||||
|
jsonExt = ".json"
|
||||||
|
|
||||||
|
// testBlockedRespTTL is the TTL for blocked responses to use in tests.
|
||||||
|
testBlockedRespTTL = 10
|
||||||
|
)
|
||||||
|
|
||||||
func TestDNSForwardHTTP_handleGetConfig(t *testing.T) {
|
func TestDNSForwardHTTP_handleGetConfig(t *testing.T) {
|
||||||
filterConf := &filtering.Config{
|
filterConf := &filtering.Config{
|
||||||
ProtectionEnabled: true,
|
ProtectionEnabled: true,
|
||||||
BlockingMode: filtering.BlockingModeDefault,
|
BlockingMode: filtering.BlockingModeDefault,
|
||||||
BlockedResponseTTL: 10,
|
BlockedResponseTTL: testBlockedRespTTL,
|
||||||
SafeBrowsingEnabled: true,
|
SafeBrowsingEnabled: true,
|
||||||
SafeBrowsingCacheSize: 1000,
|
SafeBrowsingCacheSize: 1000,
|
||||||
SafeSearchConf: filtering.SafeSearchConfig{Enabled: true},
|
SafeSearchConf: filtering.SafeSearchConfig{Enabled: true},
|
||||||
@@ -133,7 +138,7 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
|
|||||||
filterConf := &filtering.Config{
|
filterConf := &filtering.Config{
|
||||||
ProtectionEnabled: true,
|
ProtectionEnabled: true,
|
||||||
BlockingMode: filtering.BlockingModeDefault,
|
BlockingMode: filtering.BlockingModeDefault,
|
||||||
BlockedResponseTTL: 10,
|
BlockedResponseTTL: testBlockedRespTTL,
|
||||||
SafeBrowsingEnabled: true,
|
SafeBrowsingEnabled: true,
|
||||||
SafeBrowsingCacheSize: 1000,
|
SafeBrowsingCacheSize: 1000,
|
||||||
SafeSearchConf: filtering.SafeSearchConfig{Enabled: true},
|
SafeSearchConf: filtering.SafeSearchConfig{Enabled: true},
|
||||||
@@ -229,6 +234,9 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
|
|||||||
}, {
|
}, {
|
||||||
name: "blocked_response_ttl",
|
name: "blocked_response_ttl",
|
||||||
wantSet: "",
|
wantSet: "",
|
||||||
|
}, {
|
||||||
|
name: "multiple_domain_specific_upstreams",
|
||||||
|
wantSet: "",
|
||||||
}}
|
}}
|
||||||
|
|
||||||
var data map[string]struct {
|
var data map[string]struct {
|
||||||
@@ -250,6 +258,7 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
|
|||||||
s.dnsFilter.SetBlockingMode(filtering.BlockingModeDefault, netip.Addr{}, netip.Addr{})
|
s.dnsFilter.SetBlockingMode(filtering.BlockingModeDefault, netip.Addr{}, netip.Addr{})
|
||||||
s.conf = defaultConf
|
s.conf = defaultConf
|
||||||
s.conf.Config.EDNSClientSubnet = &EDNSClientSubnet{}
|
s.conf.Config.EDNSClientSubnet = &EDNSClientSubnet{}
|
||||||
|
s.dnsFilter.SetBlockedResponseTTL(testBlockedRespTTL)
|
||||||
})
|
})
|
||||||
|
|
||||||
rBody := io.NopCloser(bytes.NewReader(caseData.Req))
|
rBody := io.NopCloser(bytes.NewReader(caseData.Req))
|
||||||
@@ -470,6 +479,8 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) {
|
|||||||
Host: newLocalUpstreamListener(t, 0, badHandler).String(),
|
Host: newLocalUpstreamListener(t, 0, badHandler).String(),
|
||||||
}).String()
|
}).String()
|
||||||
|
|
||||||
|
goodAndBadUps := strings.Join([]string{goodUps, badUps}, " ")
|
||||||
|
|
||||||
const (
|
const (
|
||||||
upsTimeout = 100 * time.Millisecond
|
upsTimeout = 100 * time.Millisecond
|
||||||
|
|
||||||
@@ -547,7 +558,7 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) {
|
|||||||
"upstream_dns": []string{"[/domain.example/]" + badUps},
|
"upstream_dns": []string{"[/domain.example/]" + badUps},
|
||||||
},
|
},
|
||||||
wantResp: map[string]any{
|
wantResp: map[string]any{
|
||||||
"[/domain.example/]" + badUps: `WARNING: couldn't communicate ` +
|
badUps: `WARNING: couldn't communicate ` +
|
||||||
`with upstream: exchanging with ` + badUps + ` over tcp: ` +
|
`with upstream: exchanging with ` + badUps + ` over tcp: ` +
|
||||||
`dns: id mismatch`,
|
`dns: id mismatch`,
|
||||||
},
|
},
|
||||||
@@ -585,6 +596,40 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) {
|
|||||||
goodUps: "OK",
|
goodUps: "OK",
|
||||||
},
|
},
|
||||||
name: "fallback_comment_mix",
|
name: "fallback_comment_mix",
|
||||||
|
}, {
|
||||||
|
body: map[string]any{
|
||||||
|
"upstream_dns": []string{"[/domain.example/]" + goodUps + " " + badUps},
|
||||||
|
},
|
||||||
|
wantResp: map[string]any{
|
||||||
|
goodUps: "OK",
|
||||||
|
badUps: `WARNING: couldn't communicate ` +
|
||||||
|
`with upstream: exchanging with ` + badUps + ` over tcp: ` +
|
||||||
|
`dns: id mismatch`,
|
||||||
|
},
|
||||||
|
name: "multiple_domain_specific_upstreams",
|
||||||
|
}, {
|
||||||
|
body: map[string]any{
|
||||||
|
"upstream_dns": []string{"[/domain.example/]/]1.2.3.4"},
|
||||||
|
},
|
||||||
|
wantResp: map[string]any{
|
||||||
|
"[/domain.example/]/]1.2.3.4": `wrong upstream format: ` +
|
||||||
|
`bad upstream for domain "[/domain.example/]/]1.2.3.4": ` +
|
||||||
|
`duplicated separator`,
|
||||||
|
},
|
||||||
|
name: "bad_specification",
|
||||||
|
}, {
|
||||||
|
body: map[string]any{
|
||||||
|
"upstream_dns": []string{"[/domain.example/]" + goodAndBadUps},
|
||||||
|
"fallback_dns": []string{"[/domain.example/]" + goodAndBadUps},
|
||||||
|
"private_upstream": []string{"[/domain.example/]" + goodAndBadUps},
|
||||||
|
},
|
||||||
|
wantResp: map[string]any{
|
||||||
|
goodUps: "OK",
|
||||||
|
badUps: `WARNING: couldn't communicate ` +
|
||||||
|
`with upstream: exchanging with ` + badUps + ` over tcp: ` +
|
||||||
|
`dns: id mismatch`,
|
||||||
|
},
|
||||||
|
name: "all_different",
|
||||||
}}
|
}}
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
|
|||||||
@@ -839,5 +839,47 @@
|
|||||||
"edns_cs_use_custom": false,
|
"edns_cs_use_custom": false,
|
||||||
"edns_cs_custom_ip": ""
|
"edns_cs_custom_ip": ""
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
"multiple_domain_specific_upstreams": {
|
||||||
|
"req": {
|
||||||
|
"upstream_dns": [
|
||||||
|
"8.8.8.8:77",
|
||||||
|
"[/example.com/]8.8.4.4:77 9.9.9.10 https://1.1.1.1"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"want": {
|
||||||
|
"upstream_dns": [
|
||||||
|
"8.8.8.8:77",
|
||||||
|
"[/example.com/]8.8.4.4:77 9.9.9.10 https://1.1.1.1"
|
||||||
|
],
|
||||||
|
"upstream_dns_file": "",
|
||||||
|
"bootstrap_dns": [
|
||||||
|
"9.9.9.10",
|
||||||
|
"149.112.112.10",
|
||||||
|
"2620:fe::10",
|
||||||
|
"2620:fe::fe:10"
|
||||||
|
],
|
||||||
|
"fallback_dns": [],
|
||||||
|
"protection_enabled": true,
|
||||||
|
"protection_disabled_until": null,
|
||||||
|
"ratelimit": 0,
|
||||||
|
"blocking_mode": "default",
|
||||||
|
"blocking_ipv4": "",
|
||||||
|
"blocking_ipv6": "",
|
||||||
|
"blocked_response_ttl": 10,
|
||||||
|
"edns_cs_enabled": false,
|
||||||
|
"dnssec_enabled": false,
|
||||||
|
"disable_ipv6": false,
|
||||||
|
"upstream_mode": "",
|
||||||
|
"cache_size": 0,
|
||||||
|
"cache_ttl_min": 0,
|
||||||
|
"cache_ttl_max": 0,
|
||||||
|
"cache_optimistic": false,
|
||||||
|
"resolve_clients": false,
|
||||||
|
"use_private_ptr_resolvers": false,
|
||||||
|
"local_ptr_upstreams": [],
|
||||||
|
"edns_cs_use_custom": false,
|
||||||
|
"edns_cs_custom_ip": ""
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -263,30 +263,6 @@ func assignUniqueFilterID() int64 {
|
|||||||
return value
|
return value
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sets up a timer that will be checking for filters updates periodically
|
|
||||||
func (d *DNSFilter) periodicallyRefreshFilters() {
|
|
||||||
const maxInterval = 1 * 60 * 60
|
|
||||||
ivl := 5 // use a dynamically increasing time interval
|
|
||||||
for {
|
|
||||||
isNetErr, ok := false, false
|
|
||||||
if d.conf.FiltersUpdateIntervalHours != 0 {
|
|
||||||
_, isNetErr, ok = d.tryRefreshFilters(true, true, false)
|
|
||||||
if ok && !isNetErr {
|
|
||||||
ivl = maxInterval
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if isNetErr {
|
|
||||||
ivl *= 2
|
|
||||||
if ivl > maxInterval {
|
|
||||||
ivl = maxInterval
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
time.Sleep(time.Duration(ivl) * time.Second)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// tryRefreshFilters is like [refreshFilters], but backs down if the update is
|
// tryRefreshFilters is like [refreshFilters], but backs down if the update is
|
||||||
// already going on.
|
// already going on.
|
||||||
//
|
//
|
||||||
|
|||||||
@@ -257,6 +257,9 @@ type DNSFilter struct {
|
|||||||
// conf contains filtering parameters.
|
// conf contains filtering parameters.
|
||||||
conf *Config
|
conf *Config
|
||||||
|
|
||||||
|
// done is the channel to signal to stop running filters updates loop.
|
||||||
|
done chan struct{}
|
||||||
|
|
||||||
// Channel for passing data to filters-initializer goroutine
|
// Channel for passing data to filters-initializer goroutine
|
||||||
filtersInitializerChan chan filtersInitializerParams
|
filtersInitializerChan chan filtersInitializerParams
|
||||||
filtersInitializerLock sync.Mutex
|
filtersInitializerLock sync.Mutex
|
||||||
@@ -424,24 +427,15 @@ func (d *DNSFilter) setFilters(blockFilters, allowFilters []Filter, async bool)
|
|||||||
return d.initFiltering(allowFilters, blockFilters)
|
return d.initFiltering(allowFilters, blockFilters)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Starts initializing new filters by signal from channel
|
|
||||||
func (d *DNSFilter) filtersInitializer() {
|
|
||||||
for {
|
|
||||||
params := <-d.filtersInitializerChan
|
|
||||||
err := d.initFiltering(params.allowFilters, params.blockFilters)
|
|
||||||
if err != nil {
|
|
||||||
log.Error("filtering: initializing: %s", err)
|
|
||||||
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close - close the object
|
// Close - close the object
|
||||||
func (d *DNSFilter) Close() {
|
func (d *DNSFilter) Close() {
|
||||||
d.engineLock.Lock()
|
d.engineLock.Lock()
|
||||||
defer d.engineLock.Unlock()
|
defer d.engineLock.Unlock()
|
||||||
|
|
||||||
|
if d.done != nil {
|
||||||
|
d.done <- struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
d.reset()
|
d.reset()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1131,19 +1125,64 @@ func New(c *Config, blockFilters []Filter) (d *DNSFilter, err error) {
|
|||||||
return d, nil
|
return d, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start - start the module:
|
// Start registers web handlers and starts filters updates loop.
|
||||||
// . start async filtering initializer goroutine
|
|
||||||
// . register web handlers
|
|
||||||
func (d *DNSFilter) Start() {
|
func (d *DNSFilter) Start() {
|
||||||
d.filtersInitializerChan = make(chan filtersInitializerParams, 1)
|
d.filtersInitializerChan = make(chan filtersInitializerParams, 1)
|
||||||
go d.filtersInitializer()
|
d.done = make(chan struct{}, 1)
|
||||||
|
|
||||||
d.RegisterFilteringHandlers()
|
d.RegisterFilteringHandlers()
|
||||||
|
|
||||||
// Here we should start updating filters,
|
go d.updatesLoop()
|
||||||
// but currently we can't wake up the periodic task to do so.
|
}
|
||||||
// So for now we just start this periodic task from here.
|
|
||||||
go d.periodicallyRefreshFilters()
|
// updatesLoop initializes new filters and checks for filters updates in a loop.
|
||||||
|
func (d *DNSFilter) updatesLoop() {
|
||||||
|
defer log.OnPanic("filtering: updates loop")
|
||||||
|
|
||||||
|
ivl := time.Second * 5
|
||||||
|
t := time.NewTimer(ivl)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case params := <-d.filtersInitializerChan:
|
||||||
|
err := d.initFiltering(params.allowFilters, params.blockFilters)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("filtering: initializing: %s", err)
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
case <-t.C:
|
||||||
|
ivl = d.periodicallyRefreshFilters(ivl)
|
||||||
|
t.Reset(ivl)
|
||||||
|
case <-d.done:
|
||||||
|
t.Stop()
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// periodicallyRefreshFilters checks for filters updates and returns time
|
||||||
|
// interval for the next update.
|
||||||
|
func (d *DNSFilter) periodicallyRefreshFilters(ivl time.Duration) (nextIvl time.Duration) {
|
||||||
|
const maxInterval = time.Hour
|
||||||
|
|
||||||
|
if d.conf.FiltersUpdateIntervalHours == 0 {
|
||||||
|
return ivl
|
||||||
|
}
|
||||||
|
|
||||||
|
isNetErr, ok := false, false
|
||||||
|
_, isNetErr, ok = d.tryRefreshFilters(true, true, false)
|
||||||
|
|
||||||
|
if ok && !isNetErr {
|
||||||
|
ivl = maxInterval
|
||||||
|
} else if isNetErr {
|
||||||
|
ivl *= 2
|
||||||
|
// TODO(s.chzhen): Use built-in function max in Go 1.21.
|
||||||
|
ivl = mathutil.Max(ivl, maxInterval)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ivl
|
||||||
}
|
}
|
||||||
|
|
||||||
// Safe browsing and parental control methods.
|
// Safe browsing and parental control methods.
|
||||||
|
|||||||
@@ -4,32 +4,17 @@ import (
|
|||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"path"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/httphdr"
|
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/AdguardTeam/golibs/netutil"
|
|
||||||
"github.com/AdguardTeam/golibs/timeutil"
|
|
||||||
"go.etcd.io/bbolt"
|
"go.etcd.io/bbolt"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
)
|
)
|
||||||
|
|
||||||
// cookieTTL is the time-to-live of the session cookie.
|
|
||||||
const cookieTTL = 365 * timeutil.Day
|
|
||||||
|
|
||||||
// sessionCookieName is the name of the session cookie.
|
|
||||||
const sessionCookieName = "agh_session"
|
|
||||||
|
|
||||||
// sessionTokenSize is the length of session token in bytes.
|
// sessionTokenSize is the length of session token in bytes.
|
||||||
const sessionTokenSize = 16
|
const sessionTokenSize = 16
|
||||||
|
|
||||||
@@ -69,7 +54,7 @@ func (s *session) deserialize(data []byte) bool {
|
|||||||
// Auth - global object
|
// Auth - global object
|
||||||
type Auth struct {
|
type Auth struct {
|
||||||
db *bbolt.DB
|
db *bbolt.DB
|
||||||
raleLimiter *authRateLimiter
|
rateLimiter *authRateLimiter
|
||||||
sessions map[string]*session
|
sessions map[string]*session
|
||||||
users []webUser
|
users []webUser
|
||||||
lock sync.Mutex
|
lock sync.Mutex
|
||||||
@@ -77,6 +62,8 @@ type Auth struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// webUser represents a user of the Web UI.
|
// webUser represents a user of the Web UI.
|
||||||
|
//
|
||||||
|
// TODO(s.chzhen): Improve naming.
|
||||||
type webUser struct {
|
type webUser struct {
|
||||||
Name string `yaml:"name"`
|
Name string `yaml:"name"`
|
||||||
PasswordHash string `yaml:"password"`
|
PasswordHash string `yaml:"password"`
|
||||||
@@ -88,7 +75,7 @@ func InitAuth(dbFilename string, users []webUser, sessionTTL uint32, rateLimiter
|
|||||||
|
|
||||||
a := &Auth{
|
a := &Auth{
|
||||||
sessionTTL: sessionTTL,
|
sessionTTL: sessionTTL,
|
||||||
raleLimiter: rateLimiter,
|
rateLimiter: rateLimiter,
|
||||||
sessions: make(map[string]*session),
|
sessions: make(map[string]*session),
|
||||||
users: users,
|
users: users,
|
||||||
}
|
}
|
||||||
@@ -216,8 +203,8 @@ func (a *Auth) storeSession(data []byte, s *session) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// remove session from file
|
// removeSessionFromFile removes a stored session from the DB file on disk.
|
||||||
func (a *Auth) removeSession(sess []byte) {
|
func (a *Auth) removeSessionFromFile(sess []byte) {
|
||||||
tx, err := a.db.Begin(true)
|
tx, err := a.db.Begin(true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("auth: bbolt.Begin: %s", err)
|
log.Error("auth: bbolt.Begin: %s", err)
|
||||||
@@ -279,7 +266,7 @@ func (a *Auth) checkSession(sess string) (res checkSessionResult) {
|
|||||||
if s.expire <= now {
|
if s.expire <= now {
|
||||||
delete(a.sessions, sess)
|
delete(a.sessions, sess)
|
||||||
key, _ := hex.DecodeString(sess)
|
key, _ := hex.DecodeString(sess)
|
||||||
a.removeSession(key)
|
a.removeSessionFromFile(key)
|
||||||
|
|
||||||
return checkSessionExpired
|
return checkSessionExpired
|
||||||
}
|
}
|
||||||
@@ -301,351 +288,17 @@ func (a *Auth) checkSession(sess string) (res checkSessionResult) {
|
|||||||
return checkSessionOK
|
return checkSessionOK
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveSession - remove session
|
// removeSession removes the session from the active sessions and the disk.
|
||||||
func (a *Auth) RemoveSession(sess string) {
|
func (a *Auth) removeSession(sess string) {
|
||||||
key, _ := hex.DecodeString(sess)
|
key, _ := hex.DecodeString(sess)
|
||||||
a.lock.Lock()
|
a.lock.Lock()
|
||||||
delete(a.sessions, sess)
|
delete(a.sessions, sess)
|
||||||
a.lock.Unlock()
|
a.lock.Unlock()
|
||||||
a.removeSession(key)
|
a.removeSessionFromFile(key)
|
||||||
}
|
}
|
||||||
|
|
||||||
type loginJSON struct {
|
// addUser adds a new user with the given password.
|
||||||
Name string `json:"name"`
|
func (a *Auth) addUser(u *webUser, password string) (err error) {
|
||||||
Password string `json:"password"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// newSessionToken returns cryptographically secure randomly generated slice of
|
|
||||||
// bytes of sessionTokenSize length.
|
|
||||||
//
|
|
||||||
// TODO(e.burkov): Think about using byte array instead of byte slice.
|
|
||||||
func newSessionToken() (data []byte, err error) {
|
|
||||||
randData := make([]byte, sessionTokenSize)
|
|
||||||
|
|
||||||
_, err = rand.Read(randData)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return randData, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// newCookie creates a new authentication cookie.
|
|
||||||
func (a *Auth) newCookie(req loginJSON, addr string) (c *http.Cookie, err error) {
|
|
||||||
rateLimiter := a.raleLimiter
|
|
||||||
u, ok := a.findUser(req.Name, req.Password)
|
|
||||||
if !ok {
|
|
||||||
if rateLimiter != nil {
|
|
||||||
rateLimiter.inc(addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, errors.Error("invalid username or password")
|
|
||||||
}
|
|
||||||
|
|
||||||
if rateLimiter != nil {
|
|
||||||
rateLimiter.remove(addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
sess, err := newSessionToken()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("generating token: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
now := time.Now().UTC()
|
|
||||||
|
|
||||||
a.addSession(sess, &session{
|
|
||||||
userName: u.Name,
|
|
||||||
expire: uint32(now.Unix()) + a.sessionTTL,
|
|
||||||
})
|
|
||||||
|
|
||||||
return &http.Cookie{
|
|
||||||
Name: sessionCookieName,
|
|
||||||
Value: hex.EncodeToString(sess),
|
|
||||||
Path: "/",
|
|
||||||
Expires: now.Add(cookieTTL),
|
|
||||||
|
|
||||||
HttpOnly: true,
|
|
||||||
SameSite: http.SameSiteLaxMode,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// realIP extracts the real IP address of the client from an HTTP request using
|
|
||||||
// the known HTTP headers.
|
|
||||||
//
|
|
||||||
// TODO(a.garipov): Currently, this is basically a copy of a similar function in
|
|
||||||
// module dnsproxy. This should really become a part of module golibs and be
|
|
||||||
// replaced both here and there. Or be replaced in both places by
|
|
||||||
// a well-maintained third-party module.
|
|
||||||
//
|
|
||||||
// TODO(a.garipov): Support header Forwarded from RFC 7329.
|
|
||||||
func realIP(r *http.Request) (ip net.IP, err error) {
|
|
||||||
proxyHeaders := []string{
|
|
||||||
httphdr.CFConnectingIP,
|
|
||||||
httphdr.TrueClientIP,
|
|
||||||
httphdr.XRealIP,
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, h := range proxyHeaders {
|
|
||||||
v := r.Header.Get(h)
|
|
||||||
ip = net.ParseIP(v)
|
|
||||||
if ip != nil {
|
|
||||||
return ip, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If none of the above yielded any results, get the leftmost IP address
|
|
||||||
// from the X-Forwarded-For header.
|
|
||||||
s := r.Header.Get(httphdr.XForwardedFor)
|
|
||||||
ipStrs := strings.SplitN(s, ", ", 2)
|
|
||||||
ip = net.ParseIP(ipStrs[0])
|
|
||||||
if ip != nil {
|
|
||||||
return ip, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// When everything else fails, just return the remote address as understood
|
|
||||||
// by the stdlib.
|
|
||||||
ipStr, err := netutil.SplitHost(r.RemoteAddr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("getting ip from client addr: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return net.ParseIP(ipStr), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// writeErrorWithIP is like [aghhttp.Error], but includes the remote IP address
|
|
||||||
// when it writes to the log.
|
|
||||||
func writeErrorWithIP(
|
|
||||||
r *http.Request,
|
|
||||||
w http.ResponseWriter,
|
|
||||||
code int,
|
|
||||||
remoteIP string,
|
|
||||||
format string,
|
|
||||||
args ...any,
|
|
||||||
) {
|
|
||||||
text := fmt.Sprintf(format, args...)
|
|
||||||
log.Error("%s %s %s: from ip %s: %s", r.Method, r.Host, r.URL, remoteIP, text)
|
|
||||||
http.Error(w, text, code)
|
|
||||||
}
|
|
||||||
|
|
||||||
func handleLogin(w http.ResponseWriter, r *http.Request) {
|
|
||||||
req := loginJSON{}
|
|
||||||
err := json.NewDecoder(r.Body).Decode(&req)
|
|
||||||
if err != nil {
|
|
||||||
aghhttp.Error(r, w, http.StatusBadRequest, "json decode: %s", err)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var remoteIP string
|
|
||||||
// realIP cannot be used here without taking TrustedProxies into account due
|
|
||||||
// to security issues.
|
|
||||||
//
|
|
||||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/2799.
|
|
||||||
//
|
|
||||||
// TODO(e.burkov): Use realIP when the issue will be fixed.
|
|
||||||
if remoteIP, err = netutil.SplitHost(r.RemoteAddr); err != nil {
|
|
||||||
writeErrorWithIP(
|
|
||||||
r,
|
|
||||||
w,
|
|
||||||
http.StatusBadRequest,
|
|
||||||
r.RemoteAddr,
|
|
||||||
"auth: getting remote address: %s",
|
|
||||||
err,
|
|
||||||
)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if rateLimiter := Context.auth.raleLimiter; rateLimiter != nil {
|
|
||||||
if left := rateLimiter.check(remoteIP); left > 0 {
|
|
||||||
w.Header().Set(httphdr.RetryAfter, strconv.Itoa(int(left.Seconds())))
|
|
||||||
writeErrorWithIP(
|
|
||||||
r,
|
|
||||||
w,
|
|
||||||
http.StatusTooManyRequests,
|
|
||||||
remoteIP,
|
|
||||||
"auth: blocked for %s",
|
|
||||||
left,
|
|
||||||
)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
cookie, err := Context.auth.newCookie(req, remoteIP)
|
|
||||||
if err != nil {
|
|
||||||
writeErrorWithIP(r, w, http.StatusForbidden, remoteIP, "%s", err)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Use realIP here, since this IP address is only used for logging.
|
|
||||||
ip, err := realIP(r)
|
|
||||||
if err != nil {
|
|
||||||
log.Error("auth: getting real ip from request with remote ip %s: %s", remoteIP, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Info("auth: user %q successfully logged in from ip %v", req.Name, ip)
|
|
||||||
|
|
||||||
http.SetCookie(w, cookie)
|
|
||||||
|
|
||||||
h := w.Header()
|
|
||||||
h.Set(httphdr.CacheControl, "no-store, no-cache, must-revalidate, proxy-revalidate")
|
|
||||||
h.Set(httphdr.Pragma, "no-cache")
|
|
||||||
h.Set(httphdr.Expires, "0")
|
|
||||||
|
|
||||||
aghhttp.OK(w)
|
|
||||||
}
|
|
||||||
|
|
||||||
func handleLogout(w http.ResponseWriter, r *http.Request) {
|
|
||||||
respHdr := w.Header()
|
|
||||||
c, err := r.Cookie(sessionCookieName)
|
|
||||||
if err != nil {
|
|
||||||
// The only error that is returned from r.Cookie is [http.ErrNoCookie].
|
|
||||||
// The user is already logged out.
|
|
||||||
respHdr.Set(httphdr.Location, "/login.html")
|
|
||||||
w.WriteHeader(http.StatusFound)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
Context.auth.RemoveSession(c.Value)
|
|
||||||
|
|
||||||
c = &http.Cookie{
|
|
||||||
Name: sessionCookieName,
|
|
||||||
Value: "",
|
|
||||||
Path: "/",
|
|
||||||
Expires: time.Unix(0, 0),
|
|
||||||
|
|
||||||
HttpOnly: true,
|
|
||||||
SameSite: http.SameSiteLaxMode,
|
|
||||||
}
|
|
||||||
|
|
||||||
respHdr.Set(httphdr.Location, "/login.html")
|
|
||||||
respHdr.Set(httphdr.SetCookie, c.String())
|
|
||||||
w.WriteHeader(http.StatusFound)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RegisterAuthHandlers - register handlers
|
|
||||||
func RegisterAuthHandlers() {
|
|
||||||
Context.mux.Handle("/control/login", postInstallHandler(ensureHandler(http.MethodPost, handleLogin)))
|
|
||||||
httpRegister(http.MethodGet, "/control/logout", handleLogout)
|
|
||||||
}
|
|
||||||
|
|
||||||
// optionalAuthThird return true if user should authenticate first.
|
|
||||||
func optionalAuthThird(w http.ResponseWriter, r *http.Request) (mustAuth bool) {
|
|
||||||
if glProcessCookie(r) {
|
|
||||||
log.Debug("auth: authentication is handled by GL-Inet submodule")
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// redirect to login page if not authenticated
|
|
||||||
isAuthenticated := false
|
|
||||||
cookie, err := r.Cookie(sessionCookieName)
|
|
||||||
if err != nil {
|
|
||||||
// The only error that is returned from r.Cookie is [http.ErrNoCookie].
|
|
||||||
// Check Basic authentication.
|
|
||||||
user, pass, hasBasic := r.BasicAuth()
|
|
||||||
if hasBasic {
|
|
||||||
_, isAuthenticated = Context.auth.findUser(user, pass)
|
|
||||||
if !isAuthenticated {
|
|
||||||
log.Info("auth: invalid Basic Authorization value")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
res := Context.auth.checkSession(cookie.Value)
|
|
||||||
isAuthenticated = res == checkSessionOK
|
|
||||||
if !isAuthenticated {
|
|
||||||
log.Debug("auth: invalid cookie value: %s", cookie)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if isAuthenticated {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if p := r.URL.Path; p == "/" || p == "/index.html" {
|
|
||||||
if glProcessRedirect(w, r) {
|
|
||||||
log.Debug("auth: redirected to login page by GL-Inet submodule")
|
|
||||||
} else {
|
|
||||||
log.Debug("auth: redirected to login page")
|
|
||||||
http.Redirect(w, r, "login.html", http.StatusFound)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
log.Debug("auth: responded with forbidden to %s %s", r.Method, p)
|
|
||||||
w.WriteHeader(http.StatusForbidden)
|
|
||||||
_, _ = w.Write([]byte("Forbidden"))
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO(a.garipov): Use [http.Handler] consistently everywhere throughout the
|
|
||||||
// project.
|
|
||||||
func optionalAuth(
|
|
||||||
h func(http.ResponseWriter, *http.Request),
|
|
||||||
) (wrapped func(http.ResponseWriter, *http.Request)) {
|
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
p := r.URL.Path
|
|
||||||
authRequired := Context.auth != nil && Context.auth.AuthRequired()
|
|
||||||
if p == "/login.html" {
|
|
||||||
cookie, err := r.Cookie(sessionCookieName)
|
|
||||||
if authRequired && err == nil {
|
|
||||||
// Redirect to the dashboard if already authenticated.
|
|
||||||
res := Context.auth.checkSession(cookie.Value)
|
|
||||||
if res == checkSessionOK {
|
|
||||||
http.Redirect(w, r, "", http.StatusFound)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debug("auth: invalid cookie value: %s", cookie)
|
|
||||||
}
|
|
||||||
} else if isPublicResource(p) {
|
|
||||||
// Process as usual, no additional auth requirements.
|
|
||||||
} else if authRequired {
|
|
||||||
if optionalAuthThird(w, r) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
h(w, r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// isPublicResource returns true if p is a path to a public resource.
|
|
||||||
func isPublicResource(p string) (ok bool) {
|
|
||||||
isAsset, err := path.Match("/assets/*", p)
|
|
||||||
if err != nil {
|
|
||||||
// The only error that is returned from path.Match is
|
|
||||||
// [path.ErrBadPattern]. This is a programmer error.
|
|
||||||
panic(fmt.Errorf("bad asset pattern: %w", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
isLogin, err := path.Match("/login.*", p)
|
|
||||||
if err != nil {
|
|
||||||
// Same as above.
|
|
||||||
panic(fmt.Errorf("bad login pattern: %w", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
return isAsset || isLogin
|
|
||||||
}
|
|
||||||
|
|
||||||
type authHandler struct {
|
|
||||||
handler http.Handler
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *authHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
||||||
optionalAuth(a.handler.ServeHTTP)(w, r)
|
|
||||||
}
|
|
||||||
|
|
||||||
func optionalAuthHandler(handler http.Handler) http.Handler {
|
|
||||||
return &authHandler{handler}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add adds a new user with the given password.
|
|
||||||
func (a *Auth) Add(u *webUser, password string) (err error) {
|
|
||||||
if len(password) == 0 {
|
if len(password) == 0 {
|
||||||
return errors.Error("empty password")
|
return errors.Error("empty password")
|
||||||
}
|
}
|
||||||
@@ -715,22 +368,40 @@ func (a *Auth) getCurrentUser(r *http.Request) (u webUser) {
|
|||||||
return webUser{}
|
return webUser{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUsers - get users
|
// usersList returns a copy of a users list.
|
||||||
func (a *Auth) GetUsers() []webUser {
|
func (a *Auth) usersList() (users []webUser) {
|
||||||
a.lock.Lock()
|
a.lock.Lock()
|
||||||
users := a.users
|
defer a.lock.Unlock()
|
||||||
a.lock.Unlock()
|
|
||||||
|
users = make([]webUser, len(a.users))
|
||||||
|
copy(users, a.users)
|
||||||
|
|
||||||
return users
|
return users
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthRequired - if authentication is required
|
// authRequired returns true if a authentication is required.
|
||||||
func (a *Auth) AuthRequired() bool {
|
func (a *Auth) authRequired() bool {
|
||||||
if GLMode {
|
if GLMode {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
a.lock.Lock()
|
a.lock.Lock()
|
||||||
r := (len(a.users) != 0)
|
defer a.lock.Unlock()
|
||||||
a.lock.Unlock()
|
|
||||||
return r
|
return len(a.users) != 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// newSessionToken returns cryptographically secure randomly generated slice of
|
||||||
|
// bytes of sessionTokenSize length.
|
||||||
|
//
|
||||||
|
// TODO(e.burkov): Think about using byte array instead of byte slice.
|
||||||
|
func newSessionToken() (data []byte, err error) {
|
||||||
|
randData := make([]byte, sessionTokenSize)
|
||||||
|
|
||||||
|
_, err = rand.Read(randData)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return randData, nil
|
||||||
}
|
}
|
||||||
|
|||||||
89
internal/home/auth_internal_test.go
Normal file
89
internal/home/auth_internal_test.go
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
package home
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/hex"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewSessionToken(t *testing.T) {
|
||||||
|
// Successful case.
|
||||||
|
token, err := newSessionToken()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, token, sessionTokenSize)
|
||||||
|
|
||||||
|
// Break the rand.Reader.
|
||||||
|
prevReader := rand.Reader
|
||||||
|
t.Cleanup(func() { rand.Reader = prevReader })
|
||||||
|
rand.Reader = &bytes.Buffer{}
|
||||||
|
|
||||||
|
// Unsuccessful case.
|
||||||
|
token, err = newSessionToken()
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Empty(t, token)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuth(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
fn := filepath.Join(dir, "sessions.db")
|
||||||
|
|
||||||
|
users := []webUser{{
|
||||||
|
Name: "name",
|
||||||
|
PasswordHash: "$2y$05$..vyzAECIhJPfaQiOK17IukcQnqEgKJHy0iETyYqxn3YXJl8yZuo2",
|
||||||
|
}}
|
||||||
|
a := InitAuth(fn, nil, 60, nil)
|
||||||
|
s := session{}
|
||||||
|
|
||||||
|
user := webUser{Name: "name"}
|
||||||
|
err := a.addUser(&user, "password")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, checkSessionNotFound, a.checkSession("notfound"))
|
||||||
|
a.removeSession("notfound")
|
||||||
|
|
||||||
|
sess, err := newSessionToken()
|
||||||
|
require.NoError(t, err)
|
||||||
|
sessStr := hex.EncodeToString(sess)
|
||||||
|
|
||||||
|
now := time.Now().UTC().Unix()
|
||||||
|
// check expiration
|
||||||
|
s.expire = uint32(now)
|
||||||
|
a.addSession(sess, &s)
|
||||||
|
assert.Equal(t, checkSessionExpired, a.checkSession(sessStr))
|
||||||
|
|
||||||
|
// add session with TTL = 2 sec
|
||||||
|
s = session{}
|
||||||
|
s.expire = uint32(time.Now().UTC().Unix() + 2)
|
||||||
|
a.addSession(sess, &s)
|
||||||
|
assert.Equal(t, checkSessionOK, a.checkSession(sessStr))
|
||||||
|
|
||||||
|
a.Close()
|
||||||
|
|
||||||
|
// load saved session
|
||||||
|
a = InitAuth(fn, users, 60, nil)
|
||||||
|
|
||||||
|
// the session is still alive
|
||||||
|
assert.Equal(t, checkSessionOK, a.checkSession(sessStr))
|
||||||
|
// reset our expiration time because checkSession() has just updated it
|
||||||
|
s.expire = uint32(time.Now().UTC().Unix() + 2)
|
||||||
|
a.storeSession(sess, &s)
|
||||||
|
a.Close()
|
||||||
|
|
||||||
|
u, ok := a.findUser("name", "password")
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.NotEmpty(t, u.Name)
|
||||||
|
|
||||||
|
time.Sleep(3 * time.Second)
|
||||||
|
|
||||||
|
// load and remove expired sessions
|
||||||
|
a = InitAuth(fn, users, 60, nil)
|
||||||
|
assert.Equal(t, checkSessionNotFound, a.checkSession(sessStr))
|
||||||
|
|
||||||
|
a.Close()
|
||||||
|
}
|
||||||
352
internal/home/authhttp.go
Normal file
352
internal/home/authhttp.go
Normal file
@@ -0,0 +1,352 @@
|
|||||||
|
package home
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"path"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||||
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
|
"github.com/AdguardTeam/golibs/httphdr"
|
||||||
|
"github.com/AdguardTeam/golibs/log"
|
||||||
|
"github.com/AdguardTeam/golibs/netutil"
|
||||||
|
"github.com/AdguardTeam/golibs/timeutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
// cookieTTL is the time-to-live of the session cookie.
|
||||||
|
const cookieTTL = 365 * timeutil.Day
|
||||||
|
|
||||||
|
// sessionCookieName is the name of the session cookie.
|
||||||
|
const sessionCookieName = "agh_session"
|
||||||
|
|
||||||
|
// loginJSON is the JSON structure for authentication.
|
||||||
|
type loginJSON struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Password string `json:"password"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// newCookie creates a new authentication cookie.
|
||||||
|
func (a *Auth) newCookie(req loginJSON, addr string) (c *http.Cookie, err error) {
|
||||||
|
rateLimiter := a.rateLimiter
|
||||||
|
u, ok := a.findUser(req.Name, req.Password)
|
||||||
|
if !ok {
|
||||||
|
if rateLimiter != nil {
|
||||||
|
rateLimiter.inc(addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, errors.Error("invalid username or password")
|
||||||
|
}
|
||||||
|
|
||||||
|
if rateLimiter != nil {
|
||||||
|
rateLimiter.remove(addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
sess, err := newSessionToken()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("generating token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now().UTC()
|
||||||
|
|
||||||
|
a.addSession(sess, &session{
|
||||||
|
userName: u.Name,
|
||||||
|
expire: uint32(now.Unix()) + a.sessionTTL,
|
||||||
|
})
|
||||||
|
|
||||||
|
return &http.Cookie{
|
||||||
|
Name: sessionCookieName,
|
||||||
|
Value: hex.EncodeToString(sess),
|
||||||
|
Path: "/",
|
||||||
|
Expires: now.Add(cookieTTL),
|
||||||
|
HttpOnly: true,
|
||||||
|
SameSite: http.SameSiteLaxMode,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// realIP extracts the real IP address of the client from an HTTP request using
|
||||||
|
// the known HTTP headers.
|
||||||
|
//
|
||||||
|
// TODO(a.garipov): Currently, this is basically a copy of a similar function in
|
||||||
|
// module dnsproxy. This should really become a part of module golibs and be
|
||||||
|
// replaced both here and there. Or be replaced in both places by
|
||||||
|
// a well-maintained third-party module.
|
||||||
|
//
|
||||||
|
// TODO(a.garipov): Support header Forwarded from RFC 7329.
|
||||||
|
func realIP(r *http.Request) (ip net.IP, err error) {
|
||||||
|
proxyHeaders := []string{
|
||||||
|
httphdr.CFConnectingIP,
|
||||||
|
httphdr.TrueClientIP,
|
||||||
|
httphdr.XRealIP,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, h := range proxyHeaders {
|
||||||
|
v := r.Header.Get(h)
|
||||||
|
ip = net.ParseIP(v)
|
||||||
|
if ip != nil {
|
||||||
|
return ip, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If none of the above yielded any results, get the leftmost IP address
|
||||||
|
// from the X-Forwarded-For header.
|
||||||
|
s := r.Header.Get(httphdr.XForwardedFor)
|
||||||
|
ipStrs := strings.SplitN(s, ", ", 2)
|
||||||
|
ip = net.ParseIP(ipStrs[0])
|
||||||
|
if ip != nil {
|
||||||
|
return ip, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// When everything else fails, just return the remote address as understood
|
||||||
|
// by the stdlib.
|
||||||
|
ipStr, err := netutil.SplitHost(r.RemoteAddr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("getting ip from client addr: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return net.ParseIP(ipStr), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeErrorWithIP is like [aghhttp.Error], but includes the remote IP address
|
||||||
|
// when it writes to the log.
|
||||||
|
func writeErrorWithIP(
|
||||||
|
r *http.Request,
|
||||||
|
w http.ResponseWriter,
|
||||||
|
code int,
|
||||||
|
remoteIP string,
|
||||||
|
format string,
|
||||||
|
args ...any,
|
||||||
|
) {
|
||||||
|
text := fmt.Sprintf(format, args...)
|
||||||
|
log.Error("%s %s %s: from ip %s: %s", r.Method, r.Host, r.URL, remoteIP, text)
|
||||||
|
http.Error(w, text, code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleLogin is the handler for the POST /control/login HTTP API.
|
||||||
|
func handleLogin(w http.ResponseWriter, r *http.Request) {
|
||||||
|
req := loginJSON{}
|
||||||
|
err := json.NewDecoder(r.Body).Decode(&req)
|
||||||
|
if err != nil {
|
||||||
|
aghhttp.Error(r, w, http.StatusBadRequest, "json decode: %s", err)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var remoteIP string
|
||||||
|
// realIP cannot be used here without taking TrustedProxies into account due
|
||||||
|
// to security issues.
|
||||||
|
//
|
||||||
|
// See https://github.com/AdguardTeam/AdGuardHome/issues/2799.
|
||||||
|
//
|
||||||
|
// TODO(e.burkov): Use realIP when the issue will be fixed.
|
||||||
|
if remoteIP, err = netutil.SplitHost(r.RemoteAddr); err != nil {
|
||||||
|
writeErrorWithIP(
|
||||||
|
r,
|
||||||
|
w,
|
||||||
|
http.StatusBadRequest,
|
||||||
|
r.RemoteAddr,
|
||||||
|
"auth: getting remote address: %s",
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if rateLimiter := Context.auth.rateLimiter; rateLimiter != nil {
|
||||||
|
if left := rateLimiter.check(remoteIP); left > 0 {
|
||||||
|
w.Header().Set(httphdr.RetryAfter, strconv.Itoa(int(left.Seconds())))
|
||||||
|
writeErrorWithIP(
|
||||||
|
r,
|
||||||
|
w,
|
||||||
|
http.StatusTooManyRequests,
|
||||||
|
remoteIP,
|
||||||
|
"auth: blocked for %s",
|
||||||
|
left,
|
||||||
|
)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cookie, err := Context.auth.newCookie(req, remoteIP)
|
||||||
|
if err != nil {
|
||||||
|
writeErrorWithIP(r, w, http.StatusForbidden, remoteIP, "%s", err)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use realIP here, since this IP address is only used for logging.
|
||||||
|
ip, err := realIP(r)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("auth: getting real ip from request with remote ip %s: %s", remoteIP, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("auth: user %q successfully logged in from ip %v", req.Name, ip)
|
||||||
|
|
||||||
|
http.SetCookie(w, cookie)
|
||||||
|
|
||||||
|
h := w.Header()
|
||||||
|
h.Set(httphdr.CacheControl, "no-store, no-cache, must-revalidate, proxy-revalidate")
|
||||||
|
h.Set(httphdr.Pragma, "no-cache")
|
||||||
|
h.Set(httphdr.Expires, "0")
|
||||||
|
|
||||||
|
aghhttp.OK(w)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleLogout is the handler for the GET /control/logout HTTP API.
|
||||||
|
func handleLogout(w http.ResponseWriter, r *http.Request) {
|
||||||
|
respHdr := w.Header()
|
||||||
|
c, err := r.Cookie(sessionCookieName)
|
||||||
|
if err != nil {
|
||||||
|
// The only error that is returned from r.Cookie is [http.ErrNoCookie].
|
||||||
|
// The user is already logged out.
|
||||||
|
respHdr.Set(httphdr.Location, "/login.html")
|
||||||
|
w.WriteHeader(http.StatusFound)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
Context.auth.removeSession(c.Value)
|
||||||
|
|
||||||
|
c = &http.Cookie{
|
||||||
|
Name: sessionCookieName,
|
||||||
|
Value: "",
|
||||||
|
Path: "/",
|
||||||
|
Expires: time.Unix(0, 0),
|
||||||
|
|
||||||
|
HttpOnly: true,
|
||||||
|
SameSite: http.SameSiteLaxMode,
|
||||||
|
}
|
||||||
|
|
||||||
|
respHdr.Set(httphdr.Location, "/login.html")
|
||||||
|
respHdr.Set(httphdr.SetCookie, c.String())
|
||||||
|
w.WriteHeader(http.StatusFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterAuthHandlers - register handlers
|
||||||
|
func RegisterAuthHandlers() {
|
||||||
|
Context.mux.Handle("/control/login", postInstallHandler(ensureHandler(http.MethodPost, handleLogin)))
|
||||||
|
httpRegister(http.MethodGet, "/control/logout", handleLogout)
|
||||||
|
}
|
||||||
|
|
||||||
|
// optionalAuthThird returns true if a user should authenticate first.
|
||||||
|
func optionalAuthThird(w http.ResponseWriter, r *http.Request) (mustAuth bool) {
|
||||||
|
pref := fmt.Sprintf("auth: raddr %s", r.RemoteAddr)
|
||||||
|
|
||||||
|
if glProcessCookie(r) {
|
||||||
|
log.Debug("%s: authentication is handled by gl-inet submodule", pref)
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// redirect to login page if not authenticated
|
||||||
|
isAuthenticated := false
|
||||||
|
cookie, err := r.Cookie(sessionCookieName)
|
||||||
|
if err != nil {
|
||||||
|
// The only error that is returned from r.Cookie is [http.ErrNoCookie].
|
||||||
|
// Check Basic authentication.
|
||||||
|
user, pass, hasBasic := r.BasicAuth()
|
||||||
|
if hasBasic {
|
||||||
|
_, isAuthenticated = Context.auth.findUser(user, pass)
|
||||||
|
if !isAuthenticated {
|
||||||
|
log.Info("%s: invalid basic authorization value", pref)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
res := Context.auth.checkSession(cookie.Value)
|
||||||
|
isAuthenticated = res == checkSessionOK
|
||||||
|
if !isAuthenticated {
|
||||||
|
log.Debug("%s: invalid cookie value: %q", pref, cookie)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if isAuthenticated {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if p := r.URL.Path; p == "/" || p == "/index.html" {
|
||||||
|
if glProcessRedirect(w, r) {
|
||||||
|
log.Debug("%s: redirected to login page by gl-inet submodule", pref)
|
||||||
|
} else {
|
||||||
|
log.Debug("%s: redirected to login page", pref)
|
||||||
|
http.Redirect(w, r, "login.html", http.StatusFound)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.Debug("%s: responded with forbidden to %s %s", pref, r.Method, p)
|
||||||
|
w.WriteHeader(http.StatusForbidden)
|
||||||
|
_, _ = w.Write([]byte("Forbidden"))
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(a.garipov): Use [http.Handler] consistently everywhere throughout the
|
||||||
|
// project.
|
||||||
|
func optionalAuth(
|
||||||
|
h func(http.ResponseWriter, *http.Request),
|
||||||
|
) (wrapped func(http.ResponseWriter, *http.Request)) {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
p := r.URL.Path
|
||||||
|
authRequired := Context.auth != nil && Context.auth.authRequired()
|
||||||
|
if p == "/login.html" {
|
||||||
|
cookie, err := r.Cookie(sessionCookieName)
|
||||||
|
if authRequired && err == nil {
|
||||||
|
// Redirect to the dashboard if already authenticated.
|
||||||
|
res := Context.auth.checkSession(cookie.Value)
|
||||||
|
if res == checkSessionOK {
|
||||||
|
http.Redirect(w, r, "", http.StatusFound)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug("auth: raddr %s: invalid cookie value: %q", r.RemoteAddr, cookie)
|
||||||
|
}
|
||||||
|
} else if isPublicResource(p) {
|
||||||
|
// Process as usual, no additional auth requirements.
|
||||||
|
} else if authRequired {
|
||||||
|
if optionalAuthThird(w, r) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
h(w, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// isPublicResource returns true if p is a path to a public resource.
|
||||||
|
func isPublicResource(p string) (ok bool) {
|
||||||
|
isAsset, err := path.Match("/assets/*", p)
|
||||||
|
if err != nil {
|
||||||
|
// The only error that is returned from path.Match is
|
||||||
|
// [path.ErrBadPattern]. This is a programmer error.
|
||||||
|
panic(fmt.Errorf("bad asset pattern: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
isLogin, err := path.Match("/login.*", p)
|
||||||
|
if err != nil {
|
||||||
|
// Same as above.
|
||||||
|
panic(fmt.Errorf("bad login pattern: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return isAsset || isLogin
|
||||||
|
}
|
||||||
|
|
||||||
|
// authHandler is a helper structure that implements [http.Handler].
|
||||||
|
type authHandler struct {
|
||||||
|
handler http.Handler
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServeHTTP implements the [http.Handler] interface for *authHandler.
|
||||||
|
func (a *authHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
optionalAuth(a.handler.ServeHTTP)(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// optionalAuthHandler returns a authentication handler.
|
||||||
|
func optionalAuthHandler(handler http.Handler) http.Handler {
|
||||||
|
return &authHandler{handler}
|
||||||
|
}
|
||||||
@@ -1,16 +1,12 @@
|
|||||||
package home
|
package home
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"crypto/rand"
|
|
||||||
"encoding/hex"
|
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/textproto"
|
"net/textproto"
|
||||||
"net/url"
|
"net/url"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/AdguardTeam/golibs/httphdr"
|
"github.com/AdguardTeam/golibs/httphdr"
|
||||||
"github.com/AdguardTeam/golibs/testutil"
|
"github.com/AdguardTeam/golibs/testutil"
|
||||||
@@ -18,82 +14,6 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewSessionToken(t *testing.T) {
|
|
||||||
// Successful case.
|
|
||||||
token, err := newSessionToken()
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Len(t, token, sessionTokenSize)
|
|
||||||
|
|
||||||
// Break the rand.Reader.
|
|
||||||
prevReader := rand.Reader
|
|
||||||
t.Cleanup(func() { rand.Reader = prevReader })
|
|
||||||
rand.Reader = &bytes.Buffer{}
|
|
||||||
|
|
||||||
// Unsuccessful case.
|
|
||||||
token, err = newSessionToken()
|
|
||||||
require.Error(t, err)
|
|
||||||
assert.Empty(t, token)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAuth(t *testing.T) {
|
|
||||||
dir := t.TempDir()
|
|
||||||
fn := filepath.Join(dir, "sessions.db")
|
|
||||||
|
|
||||||
users := []webUser{{
|
|
||||||
Name: "name",
|
|
||||||
PasswordHash: "$2y$05$..vyzAECIhJPfaQiOK17IukcQnqEgKJHy0iETyYqxn3YXJl8yZuo2",
|
|
||||||
}}
|
|
||||||
a := InitAuth(fn, nil, 60, nil)
|
|
||||||
s := session{}
|
|
||||||
|
|
||||||
user := webUser{Name: "name"}
|
|
||||||
err := a.Add(&user, "password")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
assert.Equal(t, checkSessionNotFound, a.checkSession("notfound"))
|
|
||||||
a.RemoveSession("notfound")
|
|
||||||
|
|
||||||
sess, err := newSessionToken()
|
|
||||||
require.NoError(t, err)
|
|
||||||
sessStr := hex.EncodeToString(sess)
|
|
||||||
|
|
||||||
now := time.Now().UTC().Unix()
|
|
||||||
// check expiration
|
|
||||||
s.expire = uint32(now)
|
|
||||||
a.addSession(sess, &s)
|
|
||||||
assert.Equal(t, checkSessionExpired, a.checkSession(sessStr))
|
|
||||||
|
|
||||||
// add session with TTL = 2 sec
|
|
||||||
s = session{}
|
|
||||||
s.expire = uint32(time.Now().UTC().Unix() + 2)
|
|
||||||
a.addSession(sess, &s)
|
|
||||||
assert.Equal(t, checkSessionOK, a.checkSession(sessStr))
|
|
||||||
|
|
||||||
a.Close()
|
|
||||||
|
|
||||||
// load saved session
|
|
||||||
a = InitAuth(fn, users, 60, nil)
|
|
||||||
|
|
||||||
// the session is still alive
|
|
||||||
assert.Equal(t, checkSessionOK, a.checkSession(sessStr))
|
|
||||||
// reset our expiration time because checkSession() has just updated it
|
|
||||||
s.expire = uint32(time.Now().UTC().Unix() + 2)
|
|
||||||
a.storeSession(sess, &s)
|
|
||||||
a.Close()
|
|
||||||
|
|
||||||
u, ok := a.findUser("name", "password")
|
|
||||||
assert.True(t, ok)
|
|
||||||
assert.NotEmpty(t, u.Name)
|
|
||||||
|
|
||||||
time.Sleep(3 * time.Second)
|
|
||||||
|
|
||||||
// load and remove expired sessions
|
|
||||||
a = InitAuth(fn, users, 60, nil)
|
|
||||||
assert.Equal(t, checkSessionNotFound, a.checkSession(sessStr))
|
|
||||||
|
|
||||||
a.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
// implements http.ResponseWriter
|
// implements http.ResponseWriter
|
||||||
type testResponseWriter struct {
|
type testResponseWriter struct {
|
||||||
hdr http.Header
|
hdr http.Header
|
||||||
@@ -587,7 +587,7 @@ func (c *configuration) write() (err error) {
|
|||||||
defer c.Unlock()
|
defer c.Unlock()
|
||||||
|
|
||||||
if Context.auth != nil {
|
if Context.auth != nil {
|
||||||
config.Users = Context.auth.GetUsers()
|
config.Users = Context.auth.usersList()
|
||||||
}
|
}
|
||||||
|
|
||||||
if Context.tls != nil {
|
if Context.tls != nil {
|
||||||
|
|||||||
@@ -420,7 +420,7 @@ func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request
|
|||||||
u := &webUser{
|
u := &webUser{
|
||||||
Name: req.Username,
|
Name: req.Username,
|
||||||
}
|
}
|
||||||
err = Context.auth.Add(u, req.Password)
|
err = Context.auth.addUser(u, req.Password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
Context.firstRun = true
|
Context.firstRun = true
|
||||||
copyInstallSettings(config, curConfig)
|
copyInstallSettings(config, curConfig)
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
package ipset
|
package ipset
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -38,19 +39,69 @@ func newManager(ipsetConf []string) (set Manager, err error) {
|
|||||||
|
|
||||||
// defaultDial is the default netfilter dialing function.
|
// defaultDial is the default netfilter dialing function.
|
||||||
func defaultDial(pf netfilter.ProtoFamily, conf *netlink.Config) (conn ipsetConn, err error) {
|
func defaultDial(pf netfilter.ProtoFamily, conf *netlink.Config) (conn ipsetConn, err error) {
|
||||||
conn, err = ipset.Dial(pf, conf)
|
c, err := ipset.Dial(pf, conf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return conn, nil
|
return &queryConn{c}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// queryConn is the [ipsetConn] implementation with listAll method, which
|
||||||
|
// returns the list of properties of all available ipsets.
|
||||||
|
type queryConn struct {
|
||||||
|
*ipset.Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
// type check
|
||||||
|
var _ ipsetConn = (*queryConn)(nil)
|
||||||
|
|
||||||
|
// listAll returns the list of properties of all available ipsets.
|
||||||
|
//
|
||||||
|
// TODO(s.chzhen): Use https://github.com/vishvananda/netlink.
|
||||||
|
func (qc *queryConn) listAll() (sets []props, err error) {
|
||||||
|
msg, err := netfilter.MarshalNetlink(
|
||||||
|
netfilter.Header{
|
||||||
|
// The family doesn't seem to matter. See TODO on parseIpsetConfig.
|
||||||
|
Family: qc.Conn.Family,
|
||||||
|
SubsystemID: netfilter.NFSubsysIPSet,
|
||||||
|
MessageType: netfilter.MessageType(ipset.CmdList),
|
||||||
|
Flags: netlink.Request | netlink.Dump,
|
||||||
|
},
|
||||||
|
[]netfilter.Attribute{{
|
||||||
|
Type: uint16(ipset.AttrProtocol),
|
||||||
|
Data: []byte{ipset.Protocol},
|
||||||
|
}},
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("marshaling netlink msg: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// We assume it's OK to call a method of an unexported type
|
||||||
|
// [ipset.connector], since there is no negative effects.
|
||||||
|
ms, err := qc.Conn.Conn.Query(msg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("querying netlink msg: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, s := range ms {
|
||||||
|
p := props{}
|
||||||
|
err = p.unmarshalMessage(s)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unmarshaling netlink msg at index %d: %w", i, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sets = append(sets, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
return sets, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ipsetConn is the ipset conn interface.
|
// ipsetConn is the ipset conn interface.
|
||||||
type ipsetConn interface {
|
type ipsetConn interface {
|
||||||
Add(name string, entries ...*ipset.Entry) (err error)
|
Add(name string, entries ...*ipset.Entry) (err error)
|
||||||
Close() (err error)
|
Close() (err error)
|
||||||
Header(name string) (p *ipset.HeaderPolicy, err error)
|
listAll() (sets []props, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// dialer creates an ipsetConn.
|
// dialer creates an ipsetConn.
|
||||||
@@ -58,8 +109,75 @@ type dialer func(pf netfilter.ProtoFamily, conf *netlink.Config) (conn ipsetConn
|
|||||||
|
|
||||||
// props contains one Linux Netfilter ipset properties.
|
// props contains one Linux Netfilter ipset properties.
|
||||||
type props struct {
|
type props struct {
|
||||||
name string
|
// name of the ipset.
|
||||||
|
name string
|
||||||
|
|
||||||
|
// family of the IP addresses in the ipset.
|
||||||
family netfilter.ProtoFamily
|
family netfilter.ProtoFamily
|
||||||
|
|
||||||
|
// isPersistent indicates that ipset has no timeout parameter and all
|
||||||
|
// entries are added permanently.
|
||||||
|
isPersistent bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// unmarshalMessage unmarshals netlink message and sets the properties of the
|
||||||
|
// ipset.
|
||||||
|
func (p *props) unmarshalMessage(msg netlink.Message) (err error) {
|
||||||
|
_, attrs, err := netfilter.UnmarshalNetlink(msg)
|
||||||
|
if err != nil {
|
||||||
|
// Don't wrap the error since it's informative enough as is.
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// By default ipset has no timeout parameter.
|
||||||
|
p.isPersistent = true
|
||||||
|
|
||||||
|
for _, a := range attrs {
|
||||||
|
p.parseAttribute(a)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseAttribute parses netfilter attribute and sets the name and family of
|
||||||
|
// the ipset.
|
||||||
|
func (p *props) parseAttribute(a netfilter.Attribute) {
|
||||||
|
switch ipset.AttributeType(a.Type) {
|
||||||
|
case ipset.AttrData:
|
||||||
|
p.parseAttrData(a)
|
||||||
|
case ipset.AttrSetName:
|
||||||
|
// Trim the null character.
|
||||||
|
p.name = string(bytes.Trim(a.Data, "\x00"))
|
||||||
|
case ipset.AttrFamily:
|
||||||
|
p.family = netfilter.ProtoFamily(a.Data[0])
|
||||||
|
default:
|
||||||
|
// Go on.
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseAttrData parses attribute data and sets the timeout of the ipset.
|
||||||
|
func (p *props) parseAttrData(a netfilter.Attribute) {
|
||||||
|
for _, a := range a.Children {
|
||||||
|
switch ipset.AttributeType(a.Type) {
|
||||||
|
case ipset.AttrTimeout:
|
||||||
|
timeout := a.Uint32()
|
||||||
|
p.isPersistent = timeout == 0
|
||||||
|
default:
|
||||||
|
// Go on.
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// unit is a convenient alias for struct{}.
|
||||||
|
type unit = struct{}
|
||||||
|
|
||||||
|
// ipsInIpset is the type of a set of IP-address-to-ipset mappings.
|
||||||
|
type ipsInIpset map[ipInIpsetEntry]unit
|
||||||
|
|
||||||
|
// ipInIpsetEntry is the type for entries in an ipsInIpset set.
|
||||||
|
type ipInIpsetEntry struct {
|
||||||
|
ipsetName string
|
||||||
|
ipArr [net.IPv6len]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
// manager is the Linux Netfilter ipset manager.
|
// manager is the Linux Netfilter ipset manager.
|
||||||
@@ -72,6 +190,13 @@ type manager struct {
|
|||||||
// mu protects all properties below.
|
// mu protects all properties below.
|
||||||
mu *sync.Mutex
|
mu *sync.Mutex
|
||||||
|
|
||||||
|
// TODO(a.garipov): Currently, the ipset list is static, and we don't
|
||||||
|
// read the IPs already in sets, so we can assume that all incoming IPs
|
||||||
|
// are either added to all corresponding ipsets or not. When that stops
|
||||||
|
// being the case, for example if we add dynamic reconfiguration of
|
||||||
|
// ipsets, this map will need to become a per-ipset-name one.
|
||||||
|
addedIPs ipsInIpset
|
||||||
|
|
||||||
ipv4Conn ipsetConn
|
ipv4Conn ipsetConn
|
||||||
ipv6Conn ipsetConn
|
ipv6Conn ipsetConn
|
||||||
}
|
}
|
||||||
@@ -96,8 +221,8 @@ func (m *manager) dialNetfilter(conf *netlink.Config) (err error) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseIpsetConfig parses one ipset configuration string.
|
// parseIpsetConfigLine parses one ipset configuration line.
|
||||||
func parseIpsetConfig(confStr string) (hosts, ipsetNames []string, err error) {
|
func parseIpsetConfigLine(confStr string) (hosts, ipsetNames []string, err error) {
|
||||||
confStr = strings.TrimSpace(confStr)
|
confStr = strings.TrimSpace(confStr)
|
||||||
hostsAndNames := strings.Split(confStr, "/")
|
hostsAndNames := strings.Split(confStr, "/")
|
||||||
if len(hostsAndNames) != 2 {
|
if len(hostsAndNames) != 2 {
|
||||||
@@ -125,50 +250,53 @@ func parseIpsetConfig(confStr string) (hosts, ipsetNames []string, err error) {
|
|||||||
return hosts, ipsetNames, nil
|
return hosts, ipsetNames, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ipsetProps returns the properties of an ipset with the given name.
|
// parseIpsetConfig parses the ipset configuration and stores ipsets. It
|
||||||
func (m *manager) ipsetProps(name string) (set props, err error) {
|
// returns an error if the configuration can't be used.
|
||||||
// The family doesn't seem to matter when we use a header query, so
|
func (m *manager) parseIpsetConfig(ipsetConf []string) (err error) {
|
||||||
// query only the IPv4 one.
|
// The family doesn't seem to matter when we use a header query, so query
|
||||||
|
// only the IPv4 one.
|
||||||
//
|
//
|
||||||
// TODO(a.garipov): Find out if this is a bug or a feature.
|
// TODO(a.garipov): Find out if this is a bug or a feature.
|
||||||
var res *ipset.HeaderPolicy
|
all, err := m.ipv4Conn.listAll()
|
||||||
res, err = m.ipv4Conn.Header(name)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return set, err
|
// Don't wrap the error since it's informative enough as is.
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if res == nil || res.Family == nil {
|
for _, p := range all {
|
||||||
return set, errors.Error("empty response or no family data")
|
m.nameToIpset[p.name] = p
|
||||||
}
|
}
|
||||||
|
|
||||||
family := netfilter.ProtoFamily(res.Family.Value)
|
for i, confStr := range ipsetConf {
|
||||||
if family != netfilter.ProtoIPv4 && family != netfilter.ProtoIPv6 {
|
var hosts, ipsetNames []string
|
||||||
return set, fmt.Errorf("unexpected ipset family %d", family)
|
hosts, ipsetNames, err = parseIpsetConfigLine(confStr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("config line at idx %d: %w", i, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var ipsets []props
|
||||||
|
ipsets, err = m.ipsets(ipsetNames)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("getting ipsets from config line at idx %d: %w", i, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, host := range hosts {
|
||||||
|
m.domainToIpsets[host] = append(m.domainToIpsets[host], ipsets...)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return props{
|
return nil
|
||||||
name: name,
|
|
||||||
family: family,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ipsets returns currently known ipsets.
|
// ipsets returns currently known ipsets.
|
||||||
func (m *manager) ipsets(names []string) (sets []props, err error) {
|
func (m *manager) ipsets(names []string) (sets []props, err error) {
|
||||||
for _, name := range names {
|
for _, n := range names {
|
||||||
set, ok := m.nameToIpset[name]
|
p, ok := m.nameToIpset[n]
|
||||||
if ok {
|
if !ok {
|
||||||
sets = append(sets, set)
|
return nil, fmt.Errorf("unknown ipset %q", n)
|
||||||
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
set, err = m.ipsetProps(name)
|
sets = append(sets, p)
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("querying ipset %q: %w", name, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
m.nameToIpset[name] = set
|
|
||||||
sets = append(sets, set)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return sets, nil
|
return sets, nil
|
||||||
@@ -186,6 +314,8 @@ func newManagerWithDialer(ipsetConf []string, dial dialer) (mgr Manager, err err
|
|||||||
domainToIpsets: make(map[string][]props),
|
domainToIpsets: make(map[string][]props),
|
||||||
|
|
||||||
dial: dial,
|
dial: dial,
|
||||||
|
|
||||||
|
addedIPs: make(ipsInIpset),
|
||||||
}
|
}
|
||||||
|
|
||||||
err = m.dialNetfilter(&netlink.Config{})
|
err = m.dialNetfilter(&netlink.Config{})
|
||||||
@@ -201,26 +331,9 @@ func newManagerWithDialer(ipsetConf []string, dial dialer) (mgr Manager, err err
|
|||||||
return nil, fmt.Errorf("dialing netfilter: %w", err)
|
return nil, fmt.Errorf("dialing netfilter: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, confStr := range ipsetConf {
|
err = m.parseIpsetConfig(ipsetConf)
|
||||||
var hosts, ipsetNames []string
|
if err != nil {
|
||||||
hosts, ipsetNames, err = parseIpsetConfig(confStr)
|
return nil, fmt.Errorf("getting ipsets: %w", err)
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("config line at idx %d: %w", i, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var ipsets []props
|
|
||||||
ipsets, err = m.ipsets(ipsetNames)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf(
|
|
||||||
"getting ipsets from config line at idx %d: %w",
|
|
||||||
i,
|
|
||||||
err,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, host := range hosts {
|
|
||||||
m.domainToIpsets[host] = append(m.domainToIpsets[host], ipsets...)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return m, nil
|
return m, nil
|
||||||
@@ -259,8 +372,19 @@ func (m *manager) addIPs(host string, set props, ips []net.IP) (n int, err error
|
|||||||
}
|
}
|
||||||
|
|
||||||
var entries []*ipset.Entry
|
var entries []*ipset.Entry
|
||||||
|
var newAddedEntries []ipInIpsetEntry
|
||||||
for _, ip := range ips {
|
for _, ip := range ips {
|
||||||
|
e := ipInIpsetEntry{
|
||||||
|
ipsetName: set.name,
|
||||||
|
}
|
||||||
|
copy(e.ipArr[:], ip.To16())
|
||||||
|
|
||||||
|
if _, added := m.addedIPs[e]; added {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
entries = append(entries, ipset.NewEntry(ipset.EntryIP(ip)))
|
entries = append(entries, ipset.NewEntry(ipset.EntryIP(ip)))
|
||||||
|
newAddedEntries = append(newAddedEntries, e)
|
||||||
}
|
}
|
||||||
|
|
||||||
n = len(entries)
|
n = len(entries)
|
||||||
@@ -283,6 +407,15 @@ func (m *manager) addIPs(host string, set props, ips []net.IP) (n int, err error
|
|||||||
return 0, fmt.Errorf("adding %q%s to ipset %q: %w", host, ips, set.name, err)
|
return 0, fmt.Errorf("adding %q%s to ipset %q: %w", host, ips, set.name, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Only add these to the cache once we're sure that all of them were
|
||||||
|
// actually sent to the ipset.
|
||||||
|
for _, e := range newAddedEntries {
|
||||||
|
s := m.nameToIpset[e.ipsetName]
|
||||||
|
if s.isPersistent {
|
||||||
|
m.addedIPs[e] = unit{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return n, nil
|
return n, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -21,8 +21,12 @@ type fakeConn struct {
|
|||||||
ipv4Entries *[]*ipset.Entry
|
ipv4Entries *[]*ipset.Entry
|
||||||
ipv6Header *ipset.HeaderPolicy
|
ipv6Header *ipset.HeaderPolicy
|
||||||
ipv6Entries *[]*ipset.Entry
|
ipv6Entries *[]*ipset.Entry
|
||||||
|
sets []props
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// type check
|
||||||
|
var _ ipsetConn = (*fakeConn)(nil)
|
||||||
|
|
||||||
// Add implements the [ipsetConn] interface for *fakeConn.
|
// Add implements the [ipsetConn] interface for *fakeConn.
|
||||||
func (c *fakeConn) Add(name string, entries ...*ipset.Entry) (err error) {
|
func (c *fakeConn) Add(name string, entries ...*ipset.Entry) (err error) {
|
||||||
if strings.Contains(name, "ipv4") {
|
if strings.Contains(name, "ipv4") {
|
||||||
@@ -43,15 +47,9 @@ func (c *fakeConn) Close() (err error) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Header implements the [ipsetConn] interface for *fakeConn.
|
// listAll implements the [ipsetConn] interface for *fakeConn.
|
||||||
func (c *fakeConn) Header(name string) (p *ipset.HeaderPolicy, err error) {
|
func (c *fakeConn) listAll() (sets []props, err error) {
|
||||||
if strings.Contains(name, "ipv4") {
|
return c.sets, nil
|
||||||
return c.ipv4Header, nil
|
|
||||||
} else if strings.Contains(name, "ipv6") {
|
|
||||||
return c.ipv6Header, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, errors.Error("test: ipset not found")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestManager_Add(t *testing.T) {
|
func TestManager_Add(t *testing.T) {
|
||||||
@@ -76,6 +74,13 @@ func TestManager_Add(t *testing.T) {
|
|||||||
Family: ipset.NewUInt8Box(uint8(netfilter.ProtoIPv6)),
|
Family: ipset.NewUInt8Box(uint8(netfilter.ProtoIPv6)),
|
||||||
},
|
},
|
||||||
ipv6Entries: &ipv6Entries,
|
ipv6Entries: &ipv6Entries,
|
||||||
|
sets: []props{{
|
||||||
|
name: "ipv4set",
|
||||||
|
family: netfilter.ProtoIPv4,
|
||||||
|
}, {
|
||||||
|
name: "ipv6set",
|
||||||
|
family: netfilter.ProtoIPv6,
|
||||||
|
}},
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ set -f -u
|
|||||||
go_version="$( "${GO:-go}" version )"
|
go_version="$( "${GO:-go}" version )"
|
||||||
readonly go_version
|
readonly go_version
|
||||||
|
|
||||||
go_min_version='go1.20.10'
|
go_min_version='go1.20.11'
|
||||||
go_version_msg="
|
go_version_msg="
|
||||||
warning: your go version (${go_version}) is different from the recommended minimal one (${go_min_version}).
|
warning: your go version (${go_version}) is different from the recommended minimal one (${go_min_version}).
|
||||||
if you have the version installed, please set the GO environment variable.
|
if you have the version installed, please set the GO environment variable.
|
||||||
|
|||||||
Reference in New Issue
Block a user