Compare commits
11 Commits
v0.108.0-b
...
AGDNS-2743
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
53cb84efc0 | ||
|
|
c7c62ad3b6 | ||
|
|
003e7ce0d5 | ||
|
|
a8fdf1c553 | ||
|
|
7d479baba6 | ||
|
|
feb9c886d8 | ||
|
|
3521e8ed9f | ||
|
|
4d258972d1 | ||
|
|
9726171f0f | ||
|
|
6d282ae716 | ||
|
|
6a99c39d11 |
46
CHANGELOG.md
46
CHANGELOG.md
@@ -9,19 +9,39 @@ The format is based on [*Keep a Changelog*](https://keepachangelog.com/en/1.0.0/
|
||||
<!--
|
||||
## [v0.108.0] – TBA
|
||||
|
||||
## [v0.107.60] - 2025-04-01 (APPROX.)
|
||||
## [v0.107.61] - 2025-04-22 (APPROX.)
|
||||
|
||||
See also the [v0.107.60 GitHub milestone][ms-v0.107.60].
|
||||
See also the [v0.107.61 GitHub milestone][ms-v0.107.61].
|
||||
|
||||
[ms-v0.107.60]: https://github.com/AdguardTeam/AdGuardHome/milestone/95?closed=1
|
||||
[ms-v0.107.61]: https://github.com/AdguardTeam/AdGuardHome/milestone/96?closed=1
|
||||
|
||||
NOTE: Add new changes BELOW THIS COMMENT.
|
||||
-->
|
||||
|
||||
### Security
|
||||
|
||||
- Any simultaneous requests that are considered duplicates will now only result in a single request to upstreams, reducing the chance of a cache poisoning attack succeeding. This is controlled by the new configuration object `pending_requests`, which has a single `enabled` property, set to `true` by default.
|
||||
|
||||
**NOTE:** We thank [Xiang Li][mr-xiang-li] for reporting this security issue. It's strongly recommended to leave it enabled, otherwise AdGuard Home will be vulnerable to untrusted clients.
|
||||
|
||||
[mr-xiang-li]: https://lixiang521.com/
|
||||
|
||||
<!--
|
||||
NOTE: Add new changes ABOVE THIS COMMENT.
|
||||
-->
|
||||
|
||||
## [v0.107.60] - 2025-04-14
|
||||
|
||||
See also the [v0.107.60 GitHub milestone][ms-v0.107.60].
|
||||
|
||||
### Security
|
||||
|
||||
- Go version has been updated to prevent the possibility of exploiting the Go vulnerabilities fixed in [1.24.2][go-1.24.2].
|
||||
|
||||
### Changed
|
||||
|
||||
- Alpine Linux version in `Dockerfile` has been updated to 3.21 ([#7588]).
|
||||
|
||||
### Deprecated
|
||||
|
||||
- Node 20 support, Node 22 will be required in future releases.
|
||||
@@ -32,19 +52,20 @@ NOTE: Add new changes BELOW THIS COMMENT.
|
||||
|
||||
- Filtering for DHCP clients ([#7734]).
|
||||
|
||||
- Incorrect label on login page ([#7729]).
|
||||
|
||||
- Validation process for the HTTPS port on the *Encryption Settings* page.
|
||||
|
||||
### Removed
|
||||
|
||||
- Node 18 support.
|
||||
|
||||
[#7588]: https://github.com/AdguardTeam/AdGuardHome/issues/7588
|
||||
[#7729]: https://github.com/AdguardTeam/AdGuardHome/issues/7729
|
||||
[#7734]: https://github.com/AdguardTeam/AdGuardHome/issues/7734
|
||||
|
||||
[go-1.24.2]: https://groups.google.com/g/golang-announce/c/Y2uBTVKjBQk
|
||||
|
||||
<!--
|
||||
NOTE: Add new changes ABOVE THIS COMMENT.
|
||||
-->
|
||||
[go-1.24.2]: https://groups.google.com/g/golang-announce/c/Y2uBTVKjBQk
|
||||
[ms-v0.107.60]: https://github.com/AdguardTeam/AdGuardHome/milestone/95?closed=1
|
||||
|
||||
## [v0.107.59] - 2025-03-21
|
||||
|
||||
@@ -52,6 +73,8 @@ See also the [v0.107.59 GitHub milestone][ms-v0.107.59].
|
||||
|
||||
### Fixed
|
||||
|
||||
- Validation process for the DNS-over-TLS, DNS-over-QUIC, and HTTPS ports on the *Encryption Settings* page.
|
||||
|
||||
- Rules with the `client` modifier not working ([#7708]).
|
||||
|
||||
- The search form not working in the query log ([#7704]).
|
||||
@@ -3092,11 +3115,12 @@ See also the [v0.104.2 GitHub milestone][ms-v0.104.2].
|
||||
[ms-v0.104.2]: https://github.com/AdguardTeam/AdGuardHome/milestone/28?closed=1
|
||||
|
||||
<!--
|
||||
[Unreleased]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.60...HEAD
|
||||
[v0.107.60]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.59...v0.107.60
|
||||
[Unreleased]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.61...HEAD
|
||||
[v0.107.61]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.60...v0.107.61
|
||||
-->
|
||||
|
||||
[Unreleased]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.59...HEAD
|
||||
[Unreleased]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.60...HEAD
|
||||
[v0.107.60]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.59...v0.107.60
|
||||
[v0.107.59]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.58...v0.107.59
|
||||
[v0.107.58]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.57...v0.107.58
|
||||
[v0.107.57]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.56...v0.107.57
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
# Make sure to sync any changes with the branch overrides below.
|
||||
'variables':
|
||||
'channel': 'edge'
|
||||
'dockerFrontend': 'adguard/home-js-builder:3.0'
|
||||
'dockerFrontend': 'adguard/home-js-builder:3.1'
|
||||
'dockerGo': 'adguard/go-builder:1.24.2--1'
|
||||
|
||||
'stages':
|
||||
@@ -157,6 +157,7 @@
|
||||
|
||||
# Print Docker info.
|
||||
docker info
|
||||
docker buildx version
|
||||
|
||||
# Prepare and push the build.
|
||||
env \
|
||||
@@ -277,7 +278,7 @@
|
||||
# need to build a few of these.
|
||||
'variables':
|
||||
'channel': 'beta'
|
||||
'dockerFrontend': 'adguard/home-js-builder:3.0'
|
||||
'dockerFrontend': 'adguard/home-js-builder:3.1'
|
||||
'dockerGo': 'adguard/go-builder:1.24.2--1'
|
||||
# release-vX.Y.Z branches are the branches from which the actual final
|
||||
# release is built.
|
||||
@@ -293,5 +294,5 @@
|
||||
# are the ones that actually get released.
|
||||
'variables':
|
||||
'channel': 'release'
|
||||
'dockerFrontend': 'adguard/home-js-builder:3.0'
|
||||
'dockerFrontend': 'adguard/home-js-builder:3.1'
|
||||
'dockerGo': 'adguard/go-builder:1.24.2--1'
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
'key': 'AHBRTSPECS'
|
||||
'name': 'AdGuard Home - Build and run tests'
|
||||
'variables':
|
||||
'dockerFrontend': 'adguard/home-js-builder:3.0'
|
||||
'dockerFrontend': 'adguard/home-js-builder:3.1'
|
||||
'dockerGo': 'adguard/go-builder:1.24.2--1'
|
||||
'channel': 'development'
|
||||
|
||||
@@ -233,6 +233,6 @@
|
||||
# Set the default release channel on the release branch to beta, as we
|
||||
# may need to build a few of these.
|
||||
'variables':
|
||||
'dockerFrontend': 'adguard/home-js-builder:3.0'
|
||||
'dockerFrontend': 'adguard/home-js-builder:3.1'
|
||||
'dockerGo': 'adguard/go-builder:1.24.2--1'
|
||||
'channel': 'candidate'
|
||||
|
||||
1207
client/package-lock.json
generated
vendored
1207
client/package-lock.json
generated
vendored
File diff suppressed because it is too large
Load Diff
4
client/package.json
vendored
4
client/package.json
vendored
@@ -66,7 +66,7 @@
|
||||
"@babel/preset-react": "^7.24.1",
|
||||
"@playwright/test": "1.50.1",
|
||||
"@types/lodash": "^4.17.4",
|
||||
"@types/node": "^22.10.2",
|
||||
"@types/node": "^22.13.10",
|
||||
"@types/react": "^17.0.80",
|
||||
"@types/react-dom": "^18.3.0",
|
||||
"@types/react-redux": "^7.1.33",
|
||||
@@ -99,7 +99,7 @@
|
||||
"stylelint": "^16.5.0",
|
||||
"ts-loader": "^9.5.1",
|
||||
"url-loader": "^4.1.1",
|
||||
"vitest": "^3.0.4",
|
||||
"vitest": "^3.1.1",
|
||||
"webpack": "^5.91.0",
|
||||
"webpack-cli": "^5.1.4",
|
||||
"webpack-dev-server": "^5.0.4",
|
||||
|
||||
@@ -45,6 +45,7 @@
|
||||
"filter": "Филтър",
|
||||
"query_log": "История на заявките",
|
||||
"compact": "Compact",
|
||||
"nothing_found": "Нищо не е намерено",
|
||||
"faq": "ЧЗВ",
|
||||
"version": "версия",
|
||||
"address": "Адрес",
|
||||
@@ -65,14 +66,12 @@
|
||||
"stats_malware_phishing": "вируси/атаки",
|
||||
"stats_adult": "сайтове за възрастни",
|
||||
"stats_query_domain": "Най-отваряни страници",
|
||||
"for_last_24_hours": "за последните 24 часа",
|
||||
"no_domains_found": "Няма намерени резултати",
|
||||
"requests_count": "Сума на заявките",
|
||||
"top_blocked_domains": "Най-блокирани страници",
|
||||
"top_clients": "Най-активни IP адреси",
|
||||
"no_clients_found": "Нямa намерени адреси",
|
||||
"general_statistics": "Обща статисика",
|
||||
"number_of_dns_query_24_hours": "Сума на DNS заявки за последните 24 часа",
|
||||
"number_of_dns_query_blocked_24_hours": "Сума на блокирани DNS заявки от филтрите за реклама и местни",
|
||||
"number_of_dns_query_blocked_24_hours_by_sec": "Сума на блокирани DNS заявки от AdGuard свързани със сигурността",
|
||||
"number_of_dns_query_blocked_24_hours_adult": "Сума на блокирани сайтове за възрастни",
|
||||
@@ -156,6 +155,7 @@
|
||||
"rule_added_to_custom_filtering_toast": "Добавено до местни правила за филтриране: {{rule}}",
|
||||
"default": "По подразбиране",
|
||||
"custom_ip": "Персонализиран IP",
|
||||
"dnscrypt": "DNSCrypt",
|
||||
"dns_over_https": "DNS-пред-HTTPS",
|
||||
"dns_over_quic": "DNS-over-QUIC",
|
||||
"plain_dns": "Обикновен DNS",
|
||||
|
||||
@@ -110,9 +110,9 @@
|
||||
"homepage": "Startpagina",
|
||||
"report_an_issue": "Rapporteer een probleem",
|
||||
"privacy_policy": "Privacybeleid",
|
||||
"enable_protection": "Schakel bescherming in",
|
||||
"enable_protection": "Bescherming inschakelen",
|
||||
"enabled_protection": "Bescherming ingeschakeld",
|
||||
"disable_protection": "Schakel bescherming uit",
|
||||
"disable_protection": "Bescherming uitschakelen",
|
||||
"disabled_protection": "Bescherming uitgeschakeld",
|
||||
"refresh_statics": "Ververs statistieken",
|
||||
"dns_query": "DNS-queries",
|
||||
@@ -702,13 +702,13 @@
|
||||
"disable_for_hours": "Voor {{count}} uur",
|
||||
"disable_for_hours_plural": "Voor {{count}} uren",
|
||||
"disable_until_tomorrow": "Tot morgen",
|
||||
"disable_notify_for_seconds": "Beveiliging uitschakelen voor {{count}} seconde",
|
||||
"disable_notify_for_seconds_plural": "Beveiliging uitschakelen voor {{count}} seconden",
|
||||
"disable_notify_for_minutes": "Beveiliging uitschakelen voor {{count}} minuut",
|
||||
"disable_notify_for_minutes_plural": "Beveiliging uitschakelen voor {{count}} minuten",
|
||||
"disable_notify_for_hours": "Beveiliging uitschakelen voor {{count}} uur",
|
||||
"disable_notify_for_hours_plural": "Beveiliging uitschakelen voor {{count}} uren",
|
||||
"disable_notify_until_tomorrow": "Beveiliging uitschakelen tot morgen",
|
||||
"disable_notify_for_seconds": "Bescherming uitschakelen voor {{count}} seconde",
|
||||
"disable_notify_for_seconds_plural": "Bescherming uitschakelen voor {{count}} seconden",
|
||||
"disable_notify_for_minutes": "Bescherming uitschakelen voor {{count}} minuut",
|
||||
"disable_notify_for_minutes_plural": "Bescherming uitschakelen voor {{count}} minuten",
|
||||
"disable_notify_for_hours": "Bescherming uitschakelen voor {{count}} uur",
|
||||
"disable_notify_for_hours_plural": "Bescherming uitschakelen voor {{count}} uren",
|
||||
"disable_notify_until_tomorrow": "Bescherming uitschakelen tot morgen",
|
||||
"enable_protection_timer": "Bescherming wordt ingeschakeld over {{time}}",
|
||||
"custom_retention_input": "Voer retentie in uren in",
|
||||
"custom_rotation_input": "Voer rotatie in uren in",
|
||||
|
||||
@@ -264,7 +264,7 @@
|
||||
"custom_ip": "Tilpasset IP",
|
||||
"blocking_ipv4": "IPv4-blokkering",
|
||||
"blocking_ipv6": "IPv6-blokkering",
|
||||
"blocked_response_ttl": "Blokkert svar TTL",
|
||||
"blocked_response_ttl": "Blokkerte svars TTL",
|
||||
"dnscrypt": "DNSCrypt",
|
||||
"dns_over_https": "DNS-over-HTTPS",
|
||||
"dns_over_tls": "DNS-over-TLS",
|
||||
|
||||
@@ -78,6 +78,7 @@ class CustomRules extends Component<CustomRulesProps> {
|
||||
<form onSubmit={this.handleSubmit}>
|
||||
<div className="text-edit-container mb-4">
|
||||
<textarea
|
||||
data-testid="custom_rule_textarea"
|
||||
className="form-control font-monospace text-input"
|
||||
value={userRules}
|
||||
onChange={this.handleChange}
|
||||
@@ -91,6 +92,7 @@ class CustomRules extends Component<CustomRulesProps> {
|
||||
|
||||
<div className="card-actions">
|
||||
<button
|
||||
data-testid="apply_custom_rule"
|
||||
className="btn btn-success btn-standard btn-large"
|
||||
type="submit"
|
||||
onClick={this.handleSubmit}>
|
||||
|
||||
@@ -59,7 +59,7 @@ const Header = () => {
|
||||
<div className="header__column">
|
||||
<div className="header__right">
|
||||
{!processingProfile && name && (
|
||||
<a href="control/logout" className="btn btn-sm btn-outline-secondary">
|
||||
<a href="control/logout" className="btn btn-sm btn-outline-secondary" data-testid="sign_out">
|
||||
{t('sign_out')}
|
||||
</a>
|
||||
)}
|
||||
|
||||
@@ -288,7 +288,7 @@ const Row = memo(
|
||||
);
|
||||
|
||||
return (
|
||||
<div style={style} className={className} onClick={onClick} role="row">
|
||||
<div style={style} className={className} onClick={onClick} role="row" data-testid="querylog_cell">
|
||||
<DateCell {...rowProps} />
|
||||
|
||||
<DomainCell {...rowProps} />
|
||||
|
||||
@@ -84,6 +84,7 @@ export const Form = ({ className, setIsLoading }: Props) => {
|
||||
}}>
|
||||
<div className="field__search">
|
||||
<SearchField
|
||||
data-testid="querylog_search"
|
||||
value={searchValue}
|
||||
handleChange={(val) => setValue('search', val)}
|
||||
onKeyDown={onEnterPress}
|
||||
|
||||
@@ -27,12 +27,14 @@ const SETTINGS = {
|
||||
enabled: false,
|
||||
title: i18next.t('use_adguard_browsing_sec'),
|
||||
subtitle: i18next.t('use_adguard_browsing_sec_hint'),
|
||||
testId: 'safebrowsing',
|
||||
[ORDER_KEY]: 0,
|
||||
},
|
||||
parental: {
|
||||
enabled: false,
|
||||
title: i18next.t('use_adguard_parental'),
|
||||
subtitle: i18next.t('use_adguard_parental_hint'),
|
||||
testId: 'parental',
|
||||
[ORDER_KEY]: 1,
|
||||
},
|
||||
};
|
||||
@@ -90,11 +92,12 @@ class Settings extends Component<SettingsProps> {
|
||||
renderSettings = (settings: any) =>
|
||||
getObjectKeysSorted(SETTINGS, ORDER_KEY).map((key: any) => {
|
||||
const setting = settings[key];
|
||||
const { enabled, title, subtitle } = setting;
|
||||
const { enabled, title, subtitle, testId } = setting;
|
||||
|
||||
return (
|
||||
<div key={key} className="form__group form__group--checkbox">
|
||||
<Checkbox
|
||||
data-testid={testId}
|
||||
value={enabled}
|
||||
title={title}
|
||||
subtitle={subtitle}
|
||||
@@ -118,6 +121,7 @@ class Settings extends Component<SettingsProps> {
|
||||
<>
|
||||
<div className="form__group form__group--checkbox">
|
||||
<Checkbox
|
||||
data-testid="safesearch"
|
||||
value={enabled}
|
||||
title={i18next.t('enforce_safe_search')}
|
||||
subtitle={i18next.t('enforce_save_search_hint')}
|
||||
|
||||
@@ -94,14 +94,17 @@ const Footer = () => {
|
||||
auto: {
|
||||
desc: t('theme_auto_desc'),
|
||||
icon: '#auto',
|
||||
testId: 'theme_auto',
|
||||
},
|
||||
dark: {
|
||||
desc: t('theme_dark_desc'),
|
||||
icon: '#dark',
|
||||
testId: 'theme_dark',
|
||||
},
|
||||
light: {
|
||||
desc: t('theme_light_desc'),
|
||||
icon: '#light',
|
||||
testId: 'theme_light',
|
||||
},
|
||||
};
|
||||
|
||||
@@ -113,7 +116,9 @@ const Footer = () => {
|
||||
type="button"
|
||||
className="btn btn-sm btn-secondary footer__theme-button"
|
||||
onClick={() => onThemeChange(theme)}
|
||||
title={content[theme].desc}>
|
||||
title={content[theme].desc}
|
||||
data-testid={content[theme].testId}
|
||||
>
|
||||
<svg className={cn('footer__theme-icon', { 'footer__theme-icon--active': currentValue === theme })}>
|
||||
<use xlinkHref={content[theme].icon} />
|
||||
</svg>
|
||||
|
||||
@@ -28,6 +28,12 @@ export default {
|
||||
"homepage": "https://badmojr.github.io/1Hosts/",
|
||||
"source": "https://adguardteam.github.io/HostlistsRegistry/assets/filter_24.txt"
|
||||
},
|
||||
"1hosts_pro": {
|
||||
"name": "1Hosts (Pro)",
|
||||
"categoryId": "general",
|
||||
"homepage": "https://badmojr.github.io/1Hosts/",
|
||||
"source": "https://adguardteam.github.io/HostlistsRegistry/assets/filter_64.txt"
|
||||
},
|
||||
"CHN_adrules": {
|
||||
"name": "CHN: AdRules DNS List",
|
||||
"categoryId": "regional",
|
||||
|
||||
34
client/tests/e2e/control-panel.spec.ts
Normal file
34
client/tests/e2e/control-panel.spec.ts
Normal file
@@ -0,0 +1,34 @@
|
||||
import { test, expect } from '@playwright/test';
|
||||
import { ADMIN_USERNAME, ADMIN_PASSWORD } from '../constants';
|
||||
|
||||
test.describe('Control Panel', () => {
|
||||
test.beforeEach(async ({ page }) => {
|
||||
await page.goto('/login.html');
|
||||
await page.getByTestId('username').click();
|
||||
await page.getByTestId('username').fill(ADMIN_USERNAME);
|
||||
await page.getByTestId('password').click();
|
||||
await page.getByTestId('password').fill(ADMIN_PASSWORD);
|
||||
await page.keyboard.press('Tab');
|
||||
await page.getByTestId('sign_in').click();
|
||||
await page.waitForURL((url) => !url.href.endsWith('/login.html'));
|
||||
});
|
||||
|
||||
test('should sign out successfully', async ({ page }) => {
|
||||
await page.getByTestId('sign_out').click();
|
||||
|
||||
await page.waitForURL((url) => url.href.endsWith('/login.html'));
|
||||
|
||||
await expect(page.getByTestId('sign_in')).toBeVisible();
|
||||
});
|
||||
|
||||
test('should change theme to dark and then light', async ({ page }) => {
|
||||
await page.getByTestId('theme_dark').click();
|
||||
|
||||
await expect(page.locator('body[data-theme="dark"]')).toBeVisible();
|
||||
|
||||
|
||||
await page.getByTestId('theme_light').click();
|
||||
|
||||
await expect(page.locator('body:not([data-theme="dark"])')).toBeVisible();
|
||||
});
|
||||
});
|
||||
52
client/tests/e2e/dns-settings.spec.ts
Normal file
52
client/tests/e2e/dns-settings.spec.ts
Normal file
@@ -0,0 +1,52 @@
|
||||
import { test, expect, type Page } from '@playwright/test';
|
||||
import { ADMIN_USERNAME, ADMIN_PASSWORD } from '../constants';
|
||||
|
||||
test.describe('DNS Settings', () => {
|
||||
test.beforeEach(async ({ page }) => {
|
||||
// Login before each test
|
||||
await page.goto('/login.html');
|
||||
await page.getByTestId('username').click();
|
||||
await page.getByTestId('username').fill(ADMIN_USERNAME);
|
||||
await page.getByTestId('password').click();
|
||||
await page.getByTestId('password').fill(ADMIN_PASSWORD);
|
||||
await page.keyboard.press('Tab');
|
||||
await page.getByTestId('sign_in').click();
|
||||
await page.waitForURL((url) => !url.href.endsWith('/login.html'));
|
||||
});
|
||||
|
||||
const runDNSSettingsTest = async (page: Page, address: string) => {
|
||||
await page.goto('/#dns');
|
||||
|
||||
const currentDns = await page.getByTestId('upstream_dns').inputValue();
|
||||
|
||||
await page.getByTestId('upstream_dns').fill(address);
|
||||
await page.getByTestId('dns_upstream_test').click();
|
||||
|
||||
await page.waitForTimeout(2000);
|
||||
|
||||
await expect(page.getByTestId('upstream_dns')).toHaveValue(address);
|
||||
|
||||
await page.getByTestId('upstream_dns').fill(currentDns);
|
||||
await page.getByTestId('dns_upstream_save').click({ force: true });
|
||||
};
|
||||
|
||||
test('test for Default DNS', async ({ page }) => {
|
||||
await runDNSSettingsTest(page, 'https://dns10.quad9.net/dns-query');
|
||||
});
|
||||
|
||||
test('test for Plain DNS', async ({ page }) => {
|
||||
await runDNSSettingsTest(page, '94.140.14.140');
|
||||
});
|
||||
|
||||
test('test for DNS-over-HTTPS', async ({ page }) => {
|
||||
await runDNSSettingsTest(page, 'https://unfiltered.adguard-dns.com/dns-query');
|
||||
});
|
||||
|
||||
test('test for DNS-over-TLS', async ({ page }) => {
|
||||
await runDNSSettingsTest(page, 'tls://unfiltered.adguard-dns.com');
|
||||
});
|
||||
|
||||
test('test for DNS-over-QUIC', async ({ page }) => {
|
||||
await runDNSSettingsTest(page, 'quic://unfiltered.adguard-dns.com');
|
||||
});
|
||||
});
|
||||
73
client/tests/e2e/filtering.spec.ts
Normal file
73
client/tests/e2e/filtering.spec.ts
Normal file
@@ -0,0 +1,73 @@
|
||||
import { test, expect, type Page } from '@playwright/test';
|
||||
import { execSync } from 'child_process';
|
||||
import { ADMIN_USERNAME, ADMIN_PASSWORD } from '../constants';
|
||||
|
||||
test.describe('Filtering', () => {
|
||||
test.beforeEach(async ({ page }) => {
|
||||
// Login before each test
|
||||
await page.goto('/login.html');
|
||||
await page.getByTestId('username').click();
|
||||
await page.getByTestId('username').fill(ADMIN_USERNAME);
|
||||
await page.getByTestId('password').click();
|
||||
await page.getByTestId('password').fill(ADMIN_PASSWORD);
|
||||
await page.keyboard.press('Tab');
|
||||
await page.getByTestId('sign_in').click();
|
||||
await page.waitForURL((url) => !url.href.endsWith('/login.html'));
|
||||
});
|
||||
|
||||
const runTerminalCommand = (command: string) => {
|
||||
try {
|
||||
console.info(`Executing command: ${command}`);
|
||||
|
||||
const output = execSync(command, { encoding: 'utf-8', stdio: 'pipe' }).trim();
|
||||
|
||||
console.info('Command executed successfully.');
|
||||
console.debug(`Command output:\n${output}`);
|
||||
|
||||
return output;
|
||||
} catch (error: any) {
|
||||
console.error(`Command execution failed with error:\n${error.message}`);
|
||||
throw new Error(`Failed to execute command: ${command}\nError: ${error.message}`);
|
||||
}
|
||||
}
|
||||
|
||||
const runCustomRuleTest = async (page: Page, domain_to_block: string) => {
|
||||
await page.goto('/#custom_rules');
|
||||
|
||||
await page.getByTestId('custom_rule_textarea').fill(domain_to_block);
|
||||
await page.getByTestId('apply_custom_rule').click();
|
||||
|
||||
const nslookupBlockedResult = await runTerminalCommand(`nslookup ${domain_to_block} 127.0.0.1`).toString();
|
||||
|
||||
console.info(`nslookup blocked CNAME result: '${nslookupBlockedResult}'`);
|
||||
|
||||
const currentRules = await page.getByTestId('custom_rule_textarea').inputValue();
|
||||
console.debug(`Current rules before removal:\n${currentRules}`);
|
||||
|
||||
if (currentRules.includes(domain_to_block)) {
|
||||
const updatedRules = currentRules
|
||||
.split('\n')
|
||||
.filter((line) => line.trim() !== domain_to_block.trim())
|
||||
.join('\n');
|
||||
|
||||
await page.getByTestId('custom_rule_textarea').fill(updatedRules);
|
||||
console.info(`Rule '${domain_to_block}' removed successfully.`);
|
||||
|
||||
console.info('Applying the updated filtering rules after removal.');
|
||||
await page.getByTestId('apply_custom_rule').click();
|
||||
|
||||
await page.waitForLoadState('domcontentloaded');
|
||||
|
||||
console.info(`Filtering rules successfully updated after removing '${domain_to_block}'.`);
|
||||
} else {
|
||||
console.warn(`Rule '${domain_to_block}' not found. No changes were made.`);
|
||||
}
|
||||
|
||||
const nslookupUnblockedResult = await runTerminalCommand(`nslookup ${domain_to_block} 127.0.0.1`).toString();
|
||||
console.info(`nslookup unblocked CNAME result: '${nslookupUnblockedResult}'`);
|
||||
};
|
||||
|
||||
test('Test blocking rule for apple.com', async ({ page }) => {
|
||||
await runCustomRuleTest(page, 'apple.com');
|
||||
});
|
||||
});
|
||||
89
client/tests/e2e/general-settings.spec.ts
Normal file
89
client/tests/e2e/general-settings.spec.ts
Normal file
@@ -0,0 +1,89 @@
|
||||
import { test, expect } from '@playwright/test';
|
||||
import { execSync } from 'child_process';
|
||||
import { ADMIN_USERNAME, ADMIN_PASSWORD } from '../constants';
|
||||
|
||||
test.describe('General Settings', () => {
|
||||
test.beforeEach(async ({ page }) => {
|
||||
await page.goto('/login.html');
|
||||
await page.getByTestId('username').click();
|
||||
await page.getByTestId('username').fill(ADMIN_USERNAME);
|
||||
await page.getByTestId('password').click();
|
||||
await page.getByTestId('password').fill(ADMIN_PASSWORD);
|
||||
await page.keyboard.press('Tab');
|
||||
await page.getByTestId('sign_in').click();
|
||||
await page.waitForURL((url) => !url.href.endsWith('/login.html'));
|
||||
});
|
||||
|
||||
test('should toggle browsing security feature and verify DNS changes', async ({ page }) => {
|
||||
await page.goto('/#settings');
|
||||
|
||||
const browsingSecurity = await page.getByTestId('safebrowsing');
|
||||
const browsingSecurityLabel = await browsingSecurity.locator('xpath=following-sibling::*[1]');
|
||||
|
||||
const initialState = await browsingSecurity.isChecked();
|
||||
|
||||
if (!initialState) {
|
||||
await browsingSecurityLabel.click();
|
||||
await expect(browsingSecurity).toBeChecked();
|
||||
}
|
||||
|
||||
const resultEnabled = execSync('nslookup totalvirus.com 127.0.0.1').toString();
|
||||
|
||||
await browsingSecurityLabel.click();
|
||||
await expect(browsingSecurity).not.toBeChecked();
|
||||
|
||||
const resultDisabled = execSync('nslookup totalvirus.com 127.0.0.1').toString();
|
||||
|
||||
expect(resultEnabled).not.toEqual(resultDisabled);
|
||||
|
||||
if (initialState) {
|
||||
await browsingSecurityLabel.click();
|
||||
await expect(browsingSecurity).toBeChecked();
|
||||
}
|
||||
});
|
||||
|
||||
test('should toggle parental control feature and verify DNS changes', async ({ page }) => {
|
||||
await page.goto('/#settings');
|
||||
|
||||
const parentalControl = page.getByTestId('parental');
|
||||
const parentalControlLabel = await parentalControl.locator('xpath=following-sibling::*[1]');
|
||||
|
||||
const initialState = await parentalControl.isChecked();
|
||||
|
||||
if (!initialState) {
|
||||
await parentalControlLabel.click();
|
||||
await expect(parentalControl).toBeChecked();
|
||||
}
|
||||
|
||||
const resultEnabled = execSync('nslookup pornhub.com 127.0.0.1').toString();
|
||||
|
||||
await parentalControlLabel.click();
|
||||
await expect(parentalControl).not.toBeChecked();
|
||||
|
||||
const resultDisabled = execSync('nslookup pornhub.com 127.0.0.1').toString();
|
||||
|
||||
expect(resultEnabled).not.toEqual(resultDisabled);
|
||||
|
||||
if (initialState) {
|
||||
await parentalControlLabel.click();
|
||||
await expect(parentalControl).toBeChecked();
|
||||
}
|
||||
});
|
||||
|
||||
test('should toggle safe search feature', async ({ page }) => {
|
||||
await page.goto('/#settings');
|
||||
|
||||
const safeSearch = page.getByTestId('safesearch');
|
||||
const safeSearchLabel = await safeSearch.locator('xpath=following-sibling::*[1]');
|
||||
|
||||
const initialState = await safeSearch.isChecked();
|
||||
|
||||
await safeSearchLabel.click();
|
||||
|
||||
await expect(safeSearch).not.toBeChecked({ checked: initialState });
|
||||
|
||||
await safeSearchLabel.click();
|
||||
|
||||
await expect(safeSearch).toBeChecked({ checked: initialState });
|
||||
});
|
||||
});
|
||||
124
client/tests/e2e/querylog.spec.ts
Normal file
124
client/tests/e2e/querylog.spec.ts
Normal file
@@ -0,0 +1,124 @@
|
||||
import { test, expect } from '@playwright/test';
|
||||
import { ADMIN_USERNAME, ADMIN_PASSWORD } from '../constants';
|
||||
|
||||
test.describe('QueryLog', () => {
|
||||
test.beforeEach(async ({ page }) => {
|
||||
await page.goto('/login.html');
|
||||
await page.getByTestId('username').click();
|
||||
await page.getByTestId('username').fill(ADMIN_USERNAME);
|
||||
await page.getByTestId('password').click();
|
||||
await page.getByTestId('password').fill(ADMIN_PASSWORD);
|
||||
await page.keyboard.press('Tab');
|
||||
await page.getByTestId('sign_in').click();
|
||||
await page.waitForURL((url) => !url.href.endsWith('/login.html'));
|
||||
});
|
||||
|
||||
test('Search of queryLog should work correctly', async ({ page }) => {
|
||||
await page.route('/control/querylog', async (route) => {
|
||||
await route.fulfill({
|
||||
status: 200,
|
||||
contentType: 'application/json',
|
||||
body: JSON.stringify(
|
||||
{
|
||||
"data": [
|
||||
{
|
||||
"answer": [
|
||||
{
|
||||
"type": "A",
|
||||
"value": "77.88.44.242",
|
||||
"ttl": 294
|
||||
},
|
||||
{
|
||||
"type": "A",
|
||||
"value": "5.255.255.242",
|
||||
"ttl": 294
|
||||
},
|
||||
{
|
||||
"type": "A",
|
||||
"value": "77.88.55.242",
|
||||
"ttl": 294
|
||||
}
|
||||
],
|
||||
"answer_dnssec": false,
|
||||
"cached": false,
|
||||
"client": "127.0.0.1",
|
||||
"client_info": {
|
||||
"whois": {},
|
||||
"name": "localhost",
|
||||
"disallowed_rule": "127.0.0.1",
|
||||
"disallowed": false
|
||||
},
|
||||
"client_proto": "",
|
||||
"elapsedMs": "78.163167",
|
||||
"question": {
|
||||
"class": "IN",
|
||||
"name": "ya.ru",
|
||||
"type": "A"
|
||||
},
|
||||
"reason": "NotFilteredNotFound",
|
||||
"rules": [],
|
||||
"status": "NOERROR",
|
||||
"time": "2024-07-17T16:02:37.500662+02:00",
|
||||
"upstream": "https://dns10.quad9.net:443/dns-query"
|
||||
},
|
||||
{
|
||||
"answer": [
|
||||
{
|
||||
"type": "A",
|
||||
"value": "77.88.55.242",
|
||||
"ttl": 351
|
||||
},
|
||||
{
|
||||
"type": "A",
|
||||
"value": "77.88.44.242",
|
||||
"ttl": 351
|
||||
},
|
||||
{
|
||||
"type": "A",
|
||||
"value": "5.255.255.242",
|
||||
"ttl": 351
|
||||
}
|
||||
],
|
||||
"answer_dnssec": false,
|
||||
"cached": false,
|
||||
"client": "127.0.0.1",
|
||||
"client_info": {
|
||||
"whois": {},
|
||||
"name": "localhost",
|
||||
"disallowed_rule": "127.0.0.1",
|
||||
"disallowed": false
|
||||
},
|
||||
"client_proto": "",
|
||||
"elapsedMs": "5051.070708",
|
||||
"question": {
|
||||
"class": "IN",
|
||||
"name": "ya.ru",
|
||||
"type": "A"
|
||||
},
|
||||
"reason": "NotFilteredNotFound",
|
||||
"rules": [],
|
||||
"status": "NOERROR",
|
||||
"time": "2024-07-17T16:02:37.4983+02:00",
|
||||
"upstream": "https://dns10.quad9.net:443/dns-query"
|
||||
}
|
||||
],
|
||||
"oldest": "2024-07-17T16:02:37.4983+02:00"
|
||||
}
|
||||
),
|
||||
});
|
||||
});
|
||||
|
||||
await page.goto('/#logs');
|
||||
|
||||
await page.getByTestId('querylog_search').fill('127.0.0.1');
|
||||
|
||||
const [request] = await Promise.all([
|
||||
page.waitForRequest((req) => req.url().includes('/control/querylog')),
|
||||
]);
|
||||
|
||||
if (request) {
|
||||
expect(request.url()).toContain('search=127.0.0.1');
|
||||
expect(await page.getByTestId('querylog_cell').first().isVisible()).toBe(true);
|
||||
}
|
||||
});
|
||||
});
|
||||
@@ -1,12 +1,12 @@
|
||||
# A docker file for scripts/make/build-docker.sh.
|
||||
|
||||
FROM alpine:3.18
|
||||
FROM alpine:3.21
|
||||
|
||||
ARG BUILD_DATE
|
||||
ARG VERSION
|
||||
ARG VCS_REF
|
||||
|
||||
LABEL\
|
||||
LABEL \
|
||||
maintainer="AdGuard Team <devteam@adguard.com>" \
|
||||
org.opencontainers.image.authors="AdGuard Team <devteam@adguard.com>" \
|
||||
org.opencontainers.image.created=$BUILD_DATE \
|
||||
@@ -30,8 +30,8 @@ ARG TARGETARCH
|
||||
ARG TARGETOS
|
||||
ARG TARGETVARIANT
|
||||
|
||||
COPY --chown=nobody:nogroup\
|
||||
./${DIST_DIR}/docker/AdGuardHome_${TARGETOS}_${TARGETARCH}_${TARGETVARIANT}\
|
||||
COPY --chown=nobody:nogroup \
|
||||
./${DIST_DIR}/docker/AdGuardHome_${TARGETOS}_${TARGETARCH}_${TARGETVARIANT} \
|
||||
/opt/adguardhome/AdGuardHome
|
||||
|
||||
RUN setcap 'cap_net_bind_service=+eip' /opt/adguardhome/AdGuardHome
|
||||
@@ -45,8 +45,15 @@ RUN setcap 'cap_net_bind_service=+eip' /opt/adguardhome/AdGuardHome
|
||||
# 3000 : TCP, UDP : HTTP(S) (alt, incl. HTTP/3)
|
||||
# 5443 : TCP, UDP : DNSCrypt (alt)
|
||||
# 6060 : TCP : HTTP (pprof)
|
||||
EXPOSE 53/tcp 53/udp 67/udp 68/udp 80/tcp 443/tcp 443/udp 853/tcp\
|
||||
853/udp 3000/tcp 3000/udp 5443/tcp 5443/udp 6060/tcp
|
||||
EXPOSE 53/tcp 53/udp \
|
||||
67/udp \
|
||||
68/udp \
|
||||
80/tcp \
|
||||
443/tcp 443/udp \
|
||||
853/tcp 853/udp \
|
||||
3000/tcp 3000/udp \
|
||||
5443/tcp 5443/udp \
|
||||
6060/tcp
|
||||
|
||||
WORKDIR /opt/adguardhome/work
|
||||
|
||||
|
||||
26
go.mod
26
go.mod
@@ -3,8 +3,8 @@ module github.com/AdguardTeam/AdGuardHome
|
||||
go 1.24.2
|
||||
|
||||
require (
|
||||
github.com/AdguardTeam/dnsproxy v0.75.2
|
||||
github.com/AdguardTeam/golibs v0.32.7
|
||||
github.com/AdguardTeam/dnsproxy v0.75.3
|
||||
github.com/AdguardTeam/golibs v0.32.8
|
||||
github.com/AdguardTeam/urlfilter v0.20.0
|
||||
github.com/NYTimes/gziphandler v1.1.1
|
||||
github.com/ameshkov/dnscrypt/v2 v2.4.0
|
||||
@@ -34,7 +34,7 @@ require (
|
||||
github.com/ti-mo/netfilter v0.5.2
|
||||
go.etcd.io/bbolt v1.4.0
|
||||
golang.org/x/crypto v0.37.0
|
||||
golang.org/x/exp v0.0.0-20250305212735-054e65f0b394
|
||||
golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0
|
||||
golang.org/x/net v0.39.0
|
||||
golang.org/x/sys v0.32.0
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
||||
@@ -43,12 +43,12 @@ require (
|
||||
)
|
||||
|
||||
require (
|
||||
cloud.google.com/go v0.120.0 // indirect
|
||||
cloud.google.com/go/ai v0.10.1 // indirect
|
||||
cloud.google.com/go/auth v0.15.0 // indirect
|
||||
cloud.google.com/go v0.120.1 // indirect
|
||||
cloud.google.com/go/ai v0.10.2 // indirect
|
||||
cloud.google.com/go/auth v0.16.0 // indirect
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect
|
||||
cloud.google.com/go/compute/metadata v0.6.0 // indirect
|
||||
cloud.google.com/go/longrunning v0.6.6 // indirect
|
||||
cloud.google.com/go/longrunning v0.6.7 // indirect
|
||||
github.com/BurntSushi/toml v1.5.0 // indirect
|
||||
github.com/ameshkov/dnsstamps v1.0.3 // indirect
|
||||
github.com/beefsack/go-rate v0.0.0-20220214233405-116f4ca011a0 // indirect
|
||||
@@ -90,25 +90,25 @@ require (
|
||||
go.opentelemetry.io/otel/trace v1.35.0 // indirect
|
||||
go.uber.org/automaxprocs v1.6.0 // indirect
|
||||
go.uber.org/mock v0.5.1 // indirect
|
||||
golang.org/x/exp/typeparams v0.0.0-20250305212735-054e65f0b394 // indirect
|
||||
golang.org/x/exp/typeparams v0.0.0-20250408133849-7e4ce0ab07d0 // indirect
|
||||
golang.org/x/mod v0.24.0 // indirect
|
||||
golang.org/x/oauth2 v0.29.0 // indirect
|
||||
golang.org/x/sync v0.13.0 // indirect
|
||||
golang.org/x/telemetry v0.0.0-20250406004356-f593adaf3fc1 // indirect
|
||||
golang.org/x/telemetry v0.0.0-20250417124945-06ef541f3fa3 // indirect
|
||||
golang.org/x/term v0.31.0 // indirect
|
||||
golang.org/x/text v0.24.0 // indirect
|
||||
golang.org/x/time v0.11.0 // indirect
|
||||
golang.org/x/tools v0.32.0 // indirect
|
||||
golang.org/x/vuln v1.1.4 // indirect
|
||||
gonum.org/v1/gonum v0.16.0 // indirect
|
||||
google.golang.org/api v0.228.0 // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250407143221-ac9807e6c755 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250407143221-ac9807e6c755 // indirect
|
||||
google.golang.org/api v0.229.0 // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250414145226-207652e42e2e // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250414145226-207652e42e2e // indirect
|
||||
google.golang.org/grpc v1.71.1 // indirect
|
||||
google.golang.org/protobuf v1.36.6 // indirect
|
||||
honnef.co/go/tools v0.6.1 // indirect
|
||||
mvdan.cc/editorconfig v0.3.0 // indirect
|
||||
mvdan.cc/gofumpt v0.7.0 // indirect
|
||||
mvdan.cc/gofumpt v0.8.0 // indirect
|
||||
mvdan.cc/sh/v3 v3.11.0 // indirect
|
||||
mvdan.cc/unparam v0.0.0-20250301125049-0df0534333a4 // indirect
|
||||
)
|
||||
|
||||
52
go.sum
52
go.sum
@@ -1,19 +1,19 @@
|
||||
cloud.google.com/go v0.120.0 h1:wc6bgG9DHyKqF5/vQvX1CiZrtHnxJjBlKUyF9nP6meA=
|
||||
cloud.google.com/go v0.120.0/go.mod h1:/beW32s8/pGRuj4IILWQNd4uuebeT4dkOhKmkfit64Q=
|
||||
cloud.google.com/go/ai v0.10.1 h1:EU93KqYmMeOKgaBXAz2DshH2C/BzAT1P+iJORksLIic=
|
||||
cloud.google.com/go/ai v0.10.1/go.mod h1:sWWHZvmJ83BjuxAQtYEiA0SFTpijtbH+SXWFO14ri5A=
|
||||
cloud.google.com/go/auth v0.15.0 h1:Ly0u4aA5vG/fsSsxu98qCQBemXtAtJf+95z9HK+cxps=
|
||||
cloud.google.com/go/auth v0.15.0/go.mod h1:WJDGqZ1o9E9wKIL+IwStfyn/+s59zl4Bi+1KQNVXLZ8=
|
||||
cloud.google.com/go v0.120.1 h1:Z+5V7yd383+9617XDCyszmK5E4wJRJL+tquMfDj9hLM=
|
||||
cloud.google.com/go v0.120.1/go.mod h1:56Vs7sf/i2jYM6ZL9NYlC82r04PThNcPS5YgFmb0rp8=
|
||||
cloud.google.com/go/ai v0.10.2 h1:5NHzmZlRs+3kvlsVdjT0cTnLrjQdROJ/8VOljVfs+8o=
|
||||
cloud.google.com/go/ai v0.10.2/go.mod h1:xZuZuE9d3RgsR132meCnPadiU9XV0qXjpLr+P4J46eE=
|
||||
cloud.google.com/go/auth v0.16.0 h1:Pd8P1s9WkcrBE2n/PhAwKsdrR35V3Sg2II9B+ndM3CU=
|
||||
cloud.google.com/go/auth v0.16.0/go.mod h1:1howDHJ5IETh/LwYs3ZxvlkXF48aSqqJUM+5o02dNOI=
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc=
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c=
|
||||
cloud.google.com/go/compute/metadata v0.6.0 h1:A6hENjEsCDtC1k8byVsgwvVcioamEHvZ4j01OwKxG9I=
|
||||
cloud.google.com/go/compute/metadata v0.6.0/go.mod h1:FjyFAW1MW0C203CEOMDTu3Dk1FlqW3Rga40jzHL4hfg=
|
||||
cloud.google.com/go/longrunning v0.6.6 h1:XJNDo5MUfMM05xK3ewpbSdmt7R2Zw+aQEMbdQR65Rbw=
|
||||
cloud.google.com/go/longrunning v0.6.6/go.mod h1:hyeGJUrPHcx0u2Uu1UFSoYZLn4lkMrccJig0t4FI7yw=
|
||||
github.com/AdguardTeam/dnsproxy v0.75.2 h1:bciOkzQh/GG8vcZGdFn6+rS3pu+2Npt9tbA4bNA/rsc=
|
||||
github.com/AdguardTeam/dnsproxy v0.75.2/go.mod h1:U/ouLftmXMIrkTAf8JepqbPuoQzsbXJo0Vxxn+LAdgA=
|
||||
github.com/AdguardTeam/golibs v0.32.7 h1:3dmGlAVgmvquCCwHsvEl58KKcRAK3z1UnjMnwSIeDH4=
|
||||
github.com/AdguardTeam/golibs v0.32.7/go.mod h1:bE8KV1zqTzgZjmjFyBJ9f9O5DEKO717r7e57j1HclJA=
|
||||
cloud.google.com/go/longrunning v0.6.7 h1:IGtfDWHhQCgCjwQjV9iiLnUta9LBCo8R9QmAFsS/PrE=
|
||||
cloud.google.com/go/longrunning v0.6.7/go.mod h1:EAFV3IZAKmM56TyiE6VAP3VoTzhZzySwI/YI1s/nRsY=
|
||||
github.com/AdguardTeam/dnsproxy v0.75.3 h1:pxlMNO+cP1A3px40PY/old6SAE82pkdLPUA2P3KY8u0=
|
||||
github.com/AdguardTeam/dnsproxy v0.75.3/go.mod h1:50OyTHao+uQzUJiXay08hgfvWQ3o2Q2WV99W8u8ypDE=
|
||||
github.com/AdguardTeam/golibs v0.32.8 h1:O3mc3kYcPkW3kbmd+gqzFNgUka13a+iBgFLThwOYSQE=
|
||||
github.com/AdguardTeam/golibs v0.32.8/go.mod h1:McV1QFFlKLElKa306V4OL/T2kr7564PhsayfvTWYBVs=
|
||||
github.com/AdguardTeam/urlfilter v0.20.0 h1:X32qiuVCVd8WDYCEsbdZKfXMzwdVqrdulamtUi4rmzs=
|
||||
github.com/AdguardTeam/urlfilter v0.20.0/go.mod h1:gjrywLTxfJh6JOkwi9SU+frhP7kVVEZ5exFGkR99qpk=
|
||||
github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg=
|
||||
@@ -205,10 +205,10 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=
|
||||
golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
|
||||
golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 h1:nDVHiLt8aIbd/VzvPWN6kSOPE7+F/fNFDSXLVYkE/Iw=
|
||||
golang.org/x/exp v0.0.0-20250305212735-054e65f0b394/go.mod h1:sIifuuw/Yco/y6yb6+bDNfyeQ/MdPUy/hKEMYQV17cM=
|
||||
golang.org/x/exp/typeparams v0.0.0-20250305212735-054e65f0b394 h1:VI4qDpTkfFaCXEPrbojidLgVQhj2x4nzTccG0hjaLlU=
|
||||
golang.org/x/exp/typeparams v0.0.0-20250305212735-054e65f0b394/go.mod h1:LKZHyeOpPuZcMgxeHjJp4p5yvxrCX1xDvH10zYHhjjQ=
|
||||
golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 h1:R84qjqJb5nVJMxqWYb3np9L5ZsaDtB+a39EqjV0JSUM=
|
||||
golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0/go.mod h1:S9Xr4PYopiDyqSyp5NjCrhFrqg6A5zA2E/iPHPhqnS8=
|
||||
golang.org/x/exp/typeparams v0.0.0-20250408133849-7e4ce0ab07d0 h1:oMe07YcizemJ09rs2kRkFYAp0pt4e1lYLwPWiEGMpXE=
|
||||
golang.org/x/exp/typeparams v0.0.0-20250408133849-7e4ce0ab07d0/go.mod h1:LKZHyeOpPuZcMgxeHjJp4p5yvxrCX1xDvH10zYHhjjQ=
|
||||
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
|
||||
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
|
||||
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
@@ -243,8 +243,8 @@ golang.org/x/sys v0.0.0-20220209214540-3681064d5158/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.4.1-0.20230131160137-e7d7f63158de/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20=
|
||||
golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/telemetry v0.0.0-20250406004356-f593adaf3fc1 h1:LxyDqgHX2VuimV2UQSNFpQxz+NRUUsh8ulNcP3WvNG0=
|
||||
golang.org/x/telemetry v0.0.0-20250406004356-f593adaf3fc1/go.mod h1:RoaXAWDwS90j6FxVKwJdBV+0HCU+llrKUGgJaxiKl6M=
|
||||
golang.org/x/telemetry v0.0.0-20250417124945-06ef541f3fa3 h1:RXY2+rSHXvxO2Y+gKrPjYVaEoGOqh3VEXFhnWAt1Irg=
|
||||
golang.org/x/telemetry v0.0.0-20250417124945-06ef541f3fa3/go.mod h1:RoaXAWDwS90j6FxVKwJdBV+0HCU+llrKUGgJaxiKl6M=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.31.0 h1:erwDkOK1Msy6offm1mOgvspSkslFnIGsFnxOKoufg3o=
|
||||
golang.org/x/term v0.31.0/go.mod h1:R4BeIy7D95HzImkxGkTW1UQTtP54tio2RyHz7PwK0aw=
|
||||
@@ -268,12 +268,12 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
|
||||
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
|
||||
google.golang.org/api v0.228.0 h1:X2DJ/uoWGnY5obVjewbp8icSL5U4FzuCfy9OjbLSnLs=
|
||||
google.golang.org/api v0.228.0/go.mod h1:wNvRS1Pbe8r4+IfBIniV8fwCpGwTrYa+kMUDiC5z5a4=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250407143221-ac9807e6c755 h1:AMLTAunltONNuzWgVPZXrjLWtXpsG6A3yLLPEoJ/IjU=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250407143221-ac9807e6c755/go.mod h1:2R6XrVC8Oc08GlNh8ujEpc7HkLiEZ16QeY7FxIs20ac=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250407143221-ac9807e6c755 h1:TwXJCGVREgQ/cl18iY0Z4wJCTL/GmW+Um2oSwZiZPnc=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250407143221-ac9807e6c755/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A=
|
||||
google.golang.org/api v0.229.0 h1:p98ymMtqeJ5i3lIBMj5MpR9kzIIgzpHHh8vQ+vgAzx8=
|
||||
google.golang.org/api v0.229.0/go.mod h1:wyDfmq5g1wYJWn29O22FDWN48P7Xcz0xz+LBpptYvB0=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250414145226-207652e42e2e h1:UdXH7Kzbj+Vzastr5nVfccbmFsmYNygVLSPk1pEfDoY=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250414145226-207652e42e2e/go.mod h1:085qFyf2+XaZlRdCgKNCIZ3afY2p4HHZdoIRpId8F4A=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250414145226-207652e42e2e h1:ztQaXfzEXTmCBvbtWYRhJxW+0iJcz2qXfd38/e9l7bA=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250414145226-207652e42e2e/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A=
|
||||
google.golang.org/grpc v1.71.1 h1:ffsFWr7ygTUscGPI0KKK6TLrGz0476KUvvsbqWK0rPI=
|
||||
google.golang.org/grpc v1.71.1/go.mod h1:H0GRtasmQOh9LkFoCPDu3ZrwUtD1YGE+b2vYBYd/8Ec=
|
||||
google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY=
|
||||
@@ -292,8 +292,8 @@ howett.net/plist v1.0.1 h1:37GdZ8tP09Q35o9ych3ehygcsL+HqKSwzctveSlarvM=
|
||||
howett.net/plist v1.0.1/go.mod h1:lqaXoTrLY4hg8tnEzNru53gicrbv7rrk+2xJA/7hw9g=
|
||||
mvdan.cc/editorconfig v0.3.0 h1:D1D2wLYEYGpawWT5SpM5pRivgEgXjtEXwC9MWhEY0gQ=
|
||||
mvdan.cc/editorconfig v0.3.0/go.mod h1:NcJHuDtNOTEJ6251indKiWuzK6+VcrMuLzGMLKBFupQ=
|
||||
mvdan.cc/gofumpt v0.7.0 h1:bg91ttqXmi9y2xawvkuMXyvAA/1ZGJqYAEGjXuP0JXU=
|
||||
mvdan.cc/gofumpt v0.7.0/go.mod h1:txVFJy/Sc/mvaycET54pV8SW8gWxTlUuGHVEcncmNUo=
|
||||
mvdan.cc/gofumpt v0.8.0 h1:nZUCeC2ViFaerTcYKstMmfysj6uhQrA2vJe+2vwGU6k=
|
||||
mvdan.cc/gofumpt v0.8.0/go.mod h1:vEYnSzyGPmjvFkqJWtXkh79UwPWP9/HMxQdGEXZHjpg=
|
||||
mvdan.cc/sh/v3 v3.11.0 h1:q5h+XMDRfUGUedCqFFsjoFjrhwf2Mvtt1rkMvVz0blw=
|
||||
mvdan.cc/sh/v3 v3.11.0/go.mod h1:LRM+1NjoYCzuq/WZ6y44x14YNAI0NK7FLPeQSaFagGg=
|
||||
mvdan.cc/unparam v0.0.0-20250301125049-0df0534333a4 h1:WjUu4yQoT5BHT1w8Zu56SP8367OuBV5jvo+4Ulppyf8=
|
||||
|
||||
@@ -10,7 +10,8 @@ import (
|
||||
// Login is the type for web user logins.
|
||||
type Login string
|
||||
|
||||
// NewLogin returns a web user login.
|
||||
// NewLogin returns a web user login. The length of s must not be greater than
|
||||
// [math.MaxUint16].
|
||||
//
|
||||
// TODO(s.chzhen): Add more constraints as needed.
|
||||
func NewLogin(s string) (l Login, err error) {
|
||||
|
||||
35
internal/aghuser/session.go
Normal file
35
internal/aghuser/session.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package aghuser
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SessionToken is the type for the web user session token.
|
||||
type SessionToken [16]byte
|
||||
|
||||
// NewSessionToken returns a cryptographically secure randomly generated web
|
||||
// user session token. If an error occurs during random generation, it will
|
||||
// cause the program to crash.
|
||||
func NewSessionToken() (t SessionToken) {
|
||||
_, _ = rand.Read(t[:])
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
// Session represents a web user session.
|
||||
type Session struct {
|
||||
// Expire indicates when the session will expire.
|
||||
Expire time.Time
|
||||
|
||||
// UserLogin is the login of the web user associated with the session.
|
||||
//
|
||||
// TODO(s.chzhen): Remove this field and associate the user by UserID.
|
||||
UserLogin Login
|
||||
|
||||
// Token is the session token.
|
||||
Token SessionToken
|
||||
|
||||
// UserID is the identifier of the web user associated with the session.
|
||||
UserID UserID
|
||||
}
|
||||
453
internal/aghuser/sessionstorage.go
Normal file
453
internal/aghuser/sessionstorage.go
Normal file
@@ -0,0 +1,453 @@
|
||||
package aghuser
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"go.etcd.io/bbolt"
|
||||
berrors "go.etcd.io/bbolt/errors"
|
||||
)
|
||||
|
||||
// SessionStorage is an interface that defines methods for handling web user
|
||||
// sessions. All methods must be safe for concurrent use.
|
||||
//
|
||||
// TODO(s.chzhen): Add DeleteAll method.
|
||||
type SessionStorage interface {
|
||||
// New creates a new session for the web user.
|
||||
New(ctx context.Context, u *User) (s *Session, err error)
|
||||
|
||||
// FindByToken returns the stored session for the web user based on the session
|
||||
// token.
|
||||
//
|
||||
// TODO(s.chzhen): Consider function signature change to reflect the
|
||||
// in-memory implementation, as it currently always returns nil for error.
|
||||
FindByToken(ctx context.Context, t SessionToken) (s *Session, err error)
|
||||
|
||||
// DeleteByToken removes a stored web user session by the provided token.
|
||||
DeleteByToken(ctx context.Context, t SessionToken) (err error)
|
||||
|
||||
// Close releases the web user sessions database resources.
|
||||
Close() (err error)
|
||||
}
|
||||
|
||||
// DefaultSessionStorageConfig represents the web user session storage
|
||||
// configuration structure.
|
||||
type DefaultSessionStorageConfig struct {
|
||||
// Logger is used for logging the operation of the session storage. It must
|
||||
// not be nil.
|
||||
Logger *slog.Logger
|
||||
|
||||
// Clock is used to get the current time. It must not be nil.
|
||||
Clock timeutil.Clock
|
||||
|
||||
// UserDB contains the web user information such as ID, login, and password.
|
||||
// It must not be nil.
|
||||
UserDB DB
|
||||
|
||||
// DBPath is the path to the database file where session data is stored. It
|
||||
// must not be empty.
|
||||
DBPath string
|
||||
|
||||
// SessionTTL is the default Time-To-Live duration for web user sessions.
|
||||
// It specifies how long a session should last and is a required field.
|
||||
SessionTTL time.Duration
|
||||
}
|
||||
|
||||
// DefaultSessionStorage is the default bbolt database implementation of the
|
||||
// [SessionStorage] interface.
|
||||
type DefaultSessionStorage struct {
|
||||
// db is an instance of the bbolt database where web user sessions are
|
||||
// stored by [SessionToken] in the [bucketNameSessions] bucket.
|
||||
db *bbolt.DB
|
||||
|
||||
// logger is used for logging the operation of the session storage.
|
||||
logger *slog.Logger
|
||||
|
||||
// mu protects sessions.
|
||||
mu *sync.Mutex
|
||||
|
||||
// clock is used to get the current time.
|
||||
clock timeutil.Clock
|
||||
|
||||
// userDB contains the web user information such as ID, login, and password.
|
||||
userDB DB
|
||||
|
||||
// sessions maps a session token to a web user session.
|
||||
sessions map[SessionToken]*Session
|
||||
|
||||
// sessionTTL is the default Time-To-Live value for web user sessions.
|
||||
sessionTTL time.Duration
|
||||
}
|
||||
|
||||
// NewDefaultSessionStorage returns the new properly initialized
|
||||
// *DefaultSessionStorage.
|
||||
func NewDefaultSessionStorage(
|
||||
ctx context.Context,
|
||||
conf *DefaultSessionStorageConfig,
|
||||
) (ds *DefaultSessionStorage, err error) {
|
||||
ds = &DefaultSessionStorage{
|
||||
clock: conf.Clock,
|
||||
userDB: conf.UserDB,
|
||||
logger: conf.Logger,
|
||||
mu: &sync.Mutex{},
|
||||
sessions: map[SessionToken]*Session{},
|
||||
sessionTTL: conf.SessionTTL,
|
||||
}
|
||||
|
||||
dbFilename := conf.DBPath
|
||||
// TODO(s.chzhen): Pass logger with options.
|
||||
ds.db, err = bbolt.Open(dbFilename, aghos.DefaultPermFile, nil)
|
||||
if err != nil {
|
||||
ds.logger.ErrorContext(ctx, "opening db %q: %w", dbFilename, err)
|
||||
if errors.Is(err, berrors.ErrInvalid) {
|
||||
const s = "AdGuard Home cannot be initialized due to an incompatible file system.\n" +
|
||||
"Please read the explanation here: https://adguard-dns.io/kb/adguard-home/getting-started/#limitations"
|
||||
slogutil.PrintLines(ctx, ds.logger, slog.LevelError, "", s)
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = ds.loadSessions(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading sessions: %w", err)
|
||||
}
|
||||
|
||||
return ds, nil
|
||||
}
|
||||
|
||||
// loadSessions loads web user sessions from the bbolt database.
|
||||
func (ds *DefaultSessionStorage) loadSessions(ctx context.Context) (err error) {
|
||||
tx, err := ds.db.Begin(true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("starting transaction: %w", err)
|
||||
}
|
||||
|
||||
needRollback := true
|
||||
defer func() {
|
||||
if needRollback {
|
||||
err = errors.WithDeferred(err, tx.Rollback())
|
||||
}
|
||||
}()
|
||||
|
||||
bkt := tx.Bucket([]byte(bboltBucketSessions))
|
||||
if bkt == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
removed, err := ds.processSessions(ctx, bkt)
|
||||
if err != nil {
|
||||
return fmt.Errorf("processing sessions: %w", err)
|
||||
}
|
||||
|
||||
if removed == 0 {
|
||||
ds.logger.DebugContext(ctx, "loading sessions from db", "stored", len(ds.sessions))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
needRollback = false
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return fmt.Errorf("committing transaction: %w", err)
|
||||
}
|
||||
|
||||
ds.logger.DebugContext(
|
||||
ctx,
|
||||
"loading sessions from db",
|
||||
"stored", len(ds.sessions),
|
||||
"removed", removed,
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// processSessions iterates over the sessions bucket and loads or removes
|
||||
// sessions as needed.
|
||||
func (ds *DefaultSessionStorage) processSessions(
|
||||
ctx context.Context,
|
||||
bkt *bbolt.Bucket,
|
||||
) (removed int, err error) {
|
||||
invalidSessions := [][]byte{}
|
||||
|
||||
err = bkt.ForEach(ds.bboltSessionHandler(ctx, &invalidSessions))
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("iterating over sessions: %w", err)
|
||||
}
|
||||
|
||||
var errs []error
|
||||
for _, s := range invalidSessions {
|
||||
if err = bkt.Delete(s); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err = errors.Join(errs...); err != nil {
|
||||
return 0, fmt.Errorf("deleting sessions: %w", err)
|
||||
}
|
||||
|
||||
return len(invalidSessions), nil
|
||||
}
|
||||
|
||||
// bboltSessionHandler returns a function for [bbolt.Bucket.ForEach] that
|
||||
// iterates over stored sessions, deserializes them, and logs any errors
|
||||
// encountered. The returned error is always nil, as these errors are
|
||||
// considered non-critical to stop the iteration process.
|
||||
func (ds *DefaultSessionStorage) bboltSessionHandler(
|
||||
ctx context.Context,
|
||||
invalidSessions *[][]byte,
|
||||
) (fn func(k, v []byte) (err error)) {
|
||||
now := ds.clock.Now()
|
||||
|
||||
return func(k, v []byte) (err error) {
|
||||
s, err := bboltDecode(v)
|
||||
if err != nil {
|
||||
*invalidSessions = append(*invalidSessions, k)
|
||||
ds.logger.DebugContext(ctx, "deserializing session", slogutil.KeyError, err)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if now.After(s.Expire) {
|
||||
*invalidSessions = append(*invalidSessions, k)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
u, err := ds.userDB.ByLogin(ctx, s.UserLogin)
|
||||
if err != nil {
|
||||
// Should not happen, as it currently always returns nil for error.
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if u == nil {
|
||||
*invalidSessions = append(*invalidSessions, k)
|
||||
ds.logger.DebugContext(ctx, "no saved user by name", "name", s.UserLogin)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
t := SessionToken(k)
|
||||
s.Token = t
|
||||
s.UserID = u.ID
|
||||
ds.sessions[t] = s
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// bboltBucketSessions is the name of the bucket storing web user sessions in
|
||||
// the bbolt database.
|
||||
const bboltBucketSessions = "sessions-2"
|
||||
|
||||
const (
|
||||
// bboltSessionExpireLen is the length of the expire field in the binary
|
||||
// entry stored in bbolt.
|
||||
bboltSessionExpireLen = 4
|
||||
|
||||
// bboltSessionNameLen is the length of the name field in the binary entry
|
||||
// stored in bbolt.
|
||||
bboltSessionNameLen = 2
|
||||
)
|
||||
|
||||
// bboltDecode deserializes decodes a binary data into a session.
|
||||
func bboltDecode(data []byte) (s *Session, err error) {
|
||||
if len(data) < bboltSessionExpireLen+bboltSessionNameLen {
|
||||
return nil, fmt.Errorf("length of the data is less than expected: got %d", len(data))
|
||||
}
|
||||
|
||||
expireData := data[:bboltSessionExpireLen]
|
||||
nameLenData := data[bboltSessionExpireLen : bboltSessionExpireLen+bboltSessionNameLen]
|
||||
nameData := data[bboltSessionExpireLen+bboltSessionNameLen:]
|
||||
|
||||
nameLen := binary.BigEndian.Uint16(nameLenData)
|
||||
if len(nameData) != int(nameLen) {
|
||||
return nil, fmt.Errorf("login: expected length %d, got %d", nameLen, len(nameData))
|
||||
}
|
||||
|
||||
expire := binary.BigEndian.Uint32(expireData)
|
||||
|
||||
return &Session{
|
||||
Expire: time.Unix(int64(expire), 0),
|
||||
UserLogin: Login(nameData),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// bboltEncode serializes a session properties into a binary data.
|
||||
func bboltEncode(s *Session) (data []byte) {
|
||||
data = make([]byte, bboltSessionExpireLen+bboltSessionNameLen+len(s.UserLogin))
|
||||
|
||||
expireData := data[:bboltSessionExpireLen]
|
||||
nameLenData := data[bboltSessionExpireLen : bboltSessionExpireLen+bboltSessionNameLen]
|
||||
nameData := data[bboltSessionExpireLen+bboltSessionNameLen:]
|
||||
|
||||
expire := uint32(s.Expire.Unix())
|
||||
binary.BigEndian.PutUint32(expireData, expire)
|
||||
binary.BigEndian.PutUint16(nameLenData, uint16(len(s.UserLogin)))
|
||||
copy(nameData, []byte(s.UserLogin))
|
||||
|
||||
return data
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ SessionStorage = (*DefaultSessionStorage)(nil)
|
||||
|
||||
// New implements the [SessionStorage] interface for *DefaultSessionStorage.
|
||||
func (ds *DefaultSessionStorage) New(ctx context.Context, u *User) (s *Session, err error) {
|
||||
s = &Session{
|
||||
Token: NewSessionToken(),
|
||||
UserID: u.ID,
|
||||
UserLogin: u.Login,
|
||||
Expire: ds.clock.Now().Add(ds.sessionTTL),
|
||||
}
|
||||
|
||||
err = ds.store(s)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("storing session: %w", err)
|
||||
}
|
||||
|
||||
ds.mu.Lock()
|
||||
defer ds.mu.Unlock()
|
||||
|
||||
ds.sessions[s.Token] = s
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// store saves a web user session in the bbolt database.
|
||||
func (ds *DefaultSessionStorage) store(s *Session) (err error) {
|
||||
tx, err := ds.db.Begin(true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("starting transaction: %w", err)
|
||||
}
|
||||
|
||||
needRollback := true
|
||||
defer func() {
|
||||
if needRollback {
|
||||
err = errors.WithDeferred(err, tx.Rollback())
|
||||
}
|
||||
}()
|
||||
|
||||
bkt, err := tx.CreateBucketIfNotExists([]byte(bboltBucketSessions))
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating bucket: %w", err)
|
||||
}
|
||||
|
||||
err = bkt.Put(s.Token[:], bboltEncode(s))
|
||||
if err != nil {
|
||||
return fmt.Errorf("putting data: %w", err)
|
||||
}
|
||||
|
||||
needRollback = false
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return fmt.Errorf("committing transaction: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// FindByToken implements the [SessionStorage] interface for
|
||||
// *DefaultSessionStorage.
|
||||
func (ds *DefaultSessionStorage) FindByToken(
|
||||
ctx context.Context,
|
||||
t SessionToken,
|
||||
) (s *Session, err error) {
|
||||
ds.mu.Lock()
|
||||
defer ds.mu.Unlock()
|
||||
|
||||
s, ok := ds.sessions[t]
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
now := ds.clock.Now()
|
||||
if now.After(s.Expire) {
|
||||
err = ds.deleteByToken(ctx, t)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("expired session: %w", err)
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// DeleteByToken implements the [SessionStorage] interface for
|
||||
// *DefaultSessionStorage.
|
||||
func (ds *DefaultSessionStorage) DeleteByToken(ctx context.Context, t SessionToken) (err error) {
|
||||
ds.mu.Lock()
|
||||
defer ds.mu.Unlock()
|
||||
|
||||
// Don't wrap the error because it's informative enough as is.
|
||||
return ds.deleteByToken(ctx, t)
|
||||
}
|
||||
|
||||
// deleteByToken removes stored session by token. ds.mu is expected to be
|
||||
// locked.
|
||||
func (ds *DefaultSessionStorage) deleteByToken(ctx context.Context, t SessionToken) (err error) {
|
||||
err = ds.remove(ctx, t)
|
||||
if err != nil {
|
||||
ds.logger.ErrorContext(ctx, "deleting session", slogutil.KeyError, err)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
delete(ds.sessions, t)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// remove deletes a web user session from the bbolt database.
|
||||
func (ds *DefaultSessionStorage) remove(ctx context.Context, t SessionToken) (err error) {
|
||||
tx, err := ds.db.Begin(true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("starting transaction: %w", err)
|
||||
}
|
||||
|
||||
needRollback := true
|
||||
defer func() {
|
||||
if needRollback {
|
||||
err = errors.WithDeferred(err, tx.Rollback())
|
||||
}
|
||||
}()
|
||||
|
||||
bkt := tx.Bucket([]byte(bboltBucketSessions))
|
||||
if bkt == nil {
|
||||
return errors.Error("no bucket")
|
||||
}
|
||||
|
||||
err = bkt.Delete(t[:])
|
||||
if err != nil {
|
||||
return fmt.Errorf("removing data: %w", err)
|
||||
}
|
||||
|
||||
needRollback = false
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return fmt.Errorf("committing transaction: %w", err)
|
||||
}
|
||||
|
||||
ds.logger.DebugContext(ctx, "removed session from db")
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Close implements the [SessionStorage] interface for *DefaultSessionStorage.
|
||||
func (ds *DefaultSessionStorage) Close() (err error) {
|
||||
err = ds.db.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("closing db: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
162
internal/aghuser/sessionstorage_test.go
Normal file
162
internal/aghuser/sessionstorage_test.go
Normal file
@@ -0,0 +1,162 @@
|
||||
package aghuser_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghuser"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/AdguardTeam/golibs/testutil/faketime"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// addSession is a helper function that saves and returns a session for a newly
|
||||
// generated [aghuser.User] by login.
|
||||
func addSession(
|
||||
tb testing.TB,
|
||||
ctx context.Context,
|
||||
ds aghuser.SessionStorage,
|
||||
login aghuser.Login,
|
||||
) (s *aghuser.Session) {
|
||||
tb.Helper()
|
||||
|
||||
s, err := ds.New(ctx, &aghuser.User{
|
||||
ID: aghuser.MustNewUserID(),
|
||||
Login: login,
|
||||
})
|
||||
require.NoError(tb, err)
|
||||
require.NotNil(tb, s)
|
||||
|
||||
var got *aghuser.Session
|
||||
got, err = ds.FindByToken(ctx, s.Token)
|
||||
require.NoError(tb, err)
|
||||
require.NotNil(tb, got)
|
||||
|
||||
assert.Equal(tb, login, got.UserLogin)
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func TestDefaultSessionStorage(t *testing.T) {
|
||||
const (
|
||||
userLoginFirst aghuser.Login = "user_one"
|
||||
userLoginSecond aghuser.Login = "user_two"
|
||||
)
|
||||
|
||||
var (
|
||||
ctx = testutil.ContextWithTimeout(t, testTimeout)
|
||||
logger = slogutil.NewDiscardLogger()
|
||||
)
|
||||
|
||||
const (
|
||||
sessionTTL = time.Minute
|
||||
timeStep = time.Second
|
||||
)
|
||||
|
||||
// Set up a mock clock to test expired sessions. Each call to [clock.Now]
|
||||
// will return the [date] incremented by [timeStep].
|
||||
date := time.Now()
|
||||
clock := &faketime.Clock{
|
||||
OnNow: func() (now time.Time) {
|
||||
date = date.Add(timeStep)
|
||||
|
||||
return date
|
||||
},
|
||||
}
|
||||
|
||||
dbFile, err := os.CreateTemp(t.TempDir(), "sessions.db")
|
||||
require.NoError(t, err)
|
||||
testutil.CleanupAndRequireSuccess(t, dbFile.Close)
|
||||
|
||||
userDB := aghuser.NewDefaultDB()
|
||||
|
||||
err = userDB.Create(ctx, &aghuser.User{
|
||||
Login: userLoginFirst,
|
||||
ID: aghuser.MustNewUserID(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = userDB.Create(ctx, &aghuser.User{
|
||||
Login: userLoginSecond,
|
||||
ID: aghuser.MustNewUserID(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
var (
|
||||
ds *aghuser.DefaultSessionStorage
|
||||
|
||||
sessionFirst *aghuser.Session
|
||||
sessionSecond *aghuser.Session
|
||||
)
|
||||
|
||||
require.True(t, t.Run("prepare_session_storage", func(t *testing.T) {
|
||||
ds, err = aghuser.NewDefaultSessionStorage(ctx, &aghuser.DefaultSessionStorageConfig{
|
||||
Clock: clock,
|
||||
UserDB: userDB,
|
||||
Logger: logger,
|
||||
DBPath: dbFile.Name(),
|
||||
SessionTTL: sessionTTL,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
sessionFirst = addSession(t, ctx, ds, userLoginFirst)
|
||||
|
||||
// Advance time to ensure the first session expires before creating the
|
||||
// second session.
|
||||
date = date.Add(time.Hour)
|
||||
|
||||
sessionSecond = addSession(t, ctx, ds, userLoginSecond)
|
||||
|
||||
err = ds.Close()
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
|
||||
require.True(t, t.Run("load_sessions", func(t *testing.T) {
|
||||
ds, err = aghuser.NewDefaultSessionStorage(ctx, &aghuser.DefaultSessionStorageConfig{
|
||||
Clock: clock,
|
||||
UserDB: userDB,
|
||||
Logger: logger,
|
||||
DBPath: dbFile.Name(),
|
||||
SessionTTL: sessionTTL,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
var got *aghuser.Session
|
||||
got, err = ds.FindByToken(ctx, sessionFirst.Token)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Nil(t, got)
|
||||
|
||||
got, err = ds.FindByToken(ctx, sessionSecond.Token)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got)
|
||||
|
||||
assert.Equal(t, userLoginSecond, got.UserLogin)
|
||||
|
||||
err = ds.DeleteByToken(ctx, sessionSecond.Token)
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err = ds.FindByToken(ctx, sessionSecond.Token)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Nil(t, got)
|
||||
}))
|
||||
|
||||
require.True(t, t.Run("expired_session", func(t *testing.T) {
|
||||
testutil.CleanupAndRequireSuccess(t, ds.Close)
|
||||
|
||||
sessionFirst = addSession(t, ctx, ds, userLoginFirst)
|
||||
|
||||
date = date.Add(time.Hour)
|
||||
|
||||
var got *aghuser.Session
|
||||
got, err = ds.FindByToken(ctx, sessionFirst.Token)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Nil(t, got)
|
||||
}))
|
||||
}
|
||||
@@ -32,13 +32,13 @@ func MustNewUserID() (uid UserID) {
|
||||
|
||||
// User represents a web user.
|
||||
type User struct {
|
||||
// ID is the unique identifier for the web user. It must not be empty.
|
||||
ID UserID
|
||||
// Password stores the password information for the web user. It must not
|
||||
// be nil.
|
||||
Password Password
|
||||
|
||||
// Login is the login name of the web user. It must not be empty.
|
||||
Login Login
|
||||
|
||||
// Password stores the password information for the web user. It must not
|
||||
// be nil.
|
||||
Password Password
|
||||
// ID is the unique identifier for the web user. It must not be empty.
|
||||
ID UserID
|
||||
}
|
||||
|
||||
@@ -74,7 +74,7 @@ func (s *Server) clientIDFromDNSContext(pctx *proxy.DNSContext) (clientID string
|
||||
return "", nil
|
||||
}
|
||||
|
||||
hostSrvName := s.conf.ServerName
|
||||
hostSrvName := s.conf.TLSConf.ServerName
|
||||
if hostSrvName == "" {
|
||||
return "", nil
|
||||
}
|
||||
@@ -87,7 +87,7 @@ func (s *Server) clientIDFromDNSContext(pctx *proxy.DNSContext) (clientID string
|
||||
clientID, err = clientIDFromClientServerName(
|
||||
hostSrvName,
|
||||
cliSrvName,
|
||||
s.conf.StrictSNICheck,
|
||||
s.conf.TLSConf.StrictSNICheck,
|
||||
)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("clientid check: %w", err)
|
||||
|
||||
@@ -121,7 +121,7 @@ func TestServer_HandleBefore_tls(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
s, _ := createTestTLS(t, TLSConfig{
|
||||
s, _ := createTestTLS(t, &TLSConfig{
|
||||
TLSListenAddrs: []*net.TCPAddr{{}},
|
||||
ServerName: tlsServerName,
|
||||
})
|
||||
@@ -259,6 +259,7 @@ func TestServer_HandleBefore_udp(t *testing.T) {
|
||||
}, ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
TLSConf: &TLSConfig{},
|
||||
Config: Config{
|
||||
AllowedClients: tc.allowedClients,
|
||||
DisallowedClients: tc.disallowedClients,
|
||||
|
||||
@@ -212,13 +212,13 @@ func TestServer_clientIDFromDNSContext(t *testing.T) {
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tlsConf := TLSConfig{
|
||||
tlsConf := &TLSConfig{
|
||||
ServerName: tc.confSrvName,
|
||||
StrictSNICheck: tc.strictSNI,
|
||||
}
|
||||
|
||||
srv := &Server{
|
||||
conf: ServerConfig{TLSConfig: tlsConf},
|
||||
conf: ServerConfig{TLSConf: tlsConf},
|
||||
baseLogger: slogutil.NewDiscardLogger(),
|
||||
}
|
||||
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
|
||||
@@ -168,43 +167,34 @@ type EDNSClientSubnet struct {
|
||||
UseCustom bool `yaml:"use_custom"`
|
||||
}
|
||||
|
||||
// TLSConfig is the TLS configuration for HTTPS, DNS-over-HTTPS, and DNS-over-TLS
|
||||
// TLSConfig contains the TLS configuration settings for DNS-over-HTTPS (DoH),
|
||||
// DNS-over-TLS (DoT), DNS-over-QUIC (DoQ), and Discovery of Designated
|
||||
// Resolvers (DDR).
|
||||
type TLSConfig struct {
|
||||
cert tls.Certificate
|
||||
// Cert is the TLS certificate used for TLS connections. It is nil if
|
||||
// encryption is disabled.
|
||||
Cert *tls.Certificate
|
||||
|
||||
TLSListenAddrs []*net.TCPAddr `yaml:"-" json:"-"`
|
||||
QUICListenAddrs []*net.UDPAddr `yaml:"-" json:"-"`
|
||||
HTTPSListenAddrs []*net.TCPAddr `yaml:"-" json:"-"`
|
||||
// TLSListenAddrs are the addresses to listen on for DoT connections. Each
|
||||
// item in the list must be non-nil if Cert is not nil.
|
||||
TLSListenAddrs []*net.TCPAddr
|
||||
|
||||
// PEM-encoded certificates chain
|
||||
CertificateChain string `yaml:"certificate_chain" json:"certificate_chain"`
|
||||
// PEM-encoded private key
|
||||
PrivateKey string `yaml:"private_key" json:"private_key"`
|
||||
// QUICListenAddrs are the addresses to listen on for DoQ connections. Each
|
||||
// item in the list must be non-nil if Cert is not nil.
|
||||
QUICListenAddrs []*net.UDPAddr
|
||||
|
||||
CertificatePath string `yaml:"certificate_path" json:"certificate_path"`
|
||||
PrivateKeyPath string `yaml:"private_key_path" json:"private_key_path"`
|
||||
|
||||
CertificateChainData []byte `yaml:"-" json:"-"`
|
||||
PrivateKeyData []byte `yaml:"-" json:"-"`
|
||||
// HTTPSListenAddrs should be the addresses AdGuard Home is listening on for
|
||||
// DoH connections. These addresses are announced with DDR. Each item in
|
||||
// the list must be non-nil.
|
||||
HTTPSListenAddrs []*net.TCPAddr
|
||||
|
||||
// ServerName is the hostname of the server. Currently, it is only being
|
||||
// used for ClientID checking and Discovery of Designated Resolvers (DDR).
|
||||
ServerName string `yaml:"-" json:"-"`
|
||||
|
||||
// DNS names from certificate (SAN) or CN value from Subject
|
||||
dnsNames []string
|
||||
|
||||
// OverrideTLSCiphers, when set, contains the names of the cipher suites to
|
||||
// use. If the slice is empty, the default safe suites are used.
|
||||
OverrideTLSCiphers []string `yaml:"override_tls_ciphers,omitempty" json:"-"`
|
||||
ServerName string
|
||||
|
||||
// StrictSNICheck controls if the connections with SNI mismatching the
|
||||
// certificate's ones should be rejected.
|
||||
StrictSNICheck bool `yaml:"strict_sni_check" json:"-"`
|
||||
|
||||
// hasIPAddrs is set during the certificate parsing and is true if the
|
||||
// configured certificate contains at least a single IP address.
|
||||
hasIPAddrs bool
|
||||
StrictSNICheck bool
|
||||
}
|
||||
|
||||
// DNSCryptConfig is the DNSCrypt server configuration struct.
|
||||
@@ -239,8 +229,11 @@ type ServerConfig struct {
|
||||
// Remove that.
|
||||
AddrProcConf *client.DefaultAddrProcConfig
|
||||
|
||||
// TLSConf is the TLS configuration for DNS-over-TLS, DNS-over-QUIC, and
|
||||
// HTTPS. It must not be nil.
|
||||
TLSConf *TLSConfig
|
||||
|
||||
Config
|
||||
TLSConfig
|
||||
DNSCryptConfig
|
||||
TLSAllowUnencryptedDoH bool
|
||||
|
||||
@@ -281,6 +274,10 @@ type ServerConfig struct {
|
||||
|
||||
// ServePlainDNS defines if plain DNS is allowed for incoming requests.
|
||||
ServePlainDNS bool
|
||||
|
||||
// PendingRequestsEnabled defines if duplicate requests should be forwarded
|
||||
// to upstreams along with the original one.
|
||||
PendingRequestsEnabled bool
|
||||
}
|
||||
|
||||
// UpstreamMode is a enumeration of upstream mode representations. See
|
||||
@@ -324,6 +321,9 @@ func (s *Server) newProxyConfig() (conf *proxy.Config, err error) {
|
||||
UsePrivateRDNS: srvConf.UsePrivateRDNS,
|
||||
PrivateSubnets: s.privateNets,
|
||||
MessageConstructor: s,
|
||||
PendingRequests: &proxy.PendingRequestsConfig{
|
||||
Enabled: srvConf.PendingRequestsEnabled,
|
||||
},
|
||||
}
|
||||
|
||||
if srvConf.EDNSClientSubnet.UseCustom {
|
||||
@@ -608,45 +608,33 @@ func (conf *ServerConfig) ourAddrsSet() (m addrPortSet, err error) {
|
||||
}
|
||||
}
|
||||
|
||||
// prepareTLS - prepares TLS configuration for the DNS proxy
|
||||
// prepareTLS sets up the TLS configuration for the DNS proxy.
|
||||
func (s *Server) prepareTLS(proxyConfig *proxy.Config) (err error) {
|
||||
if len(s.conf.CertificateChainData) == 0 || len(s.conf.PrivateKeyData) == 0 {
|
||||
if s.conf.TLSConf.Cert == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if s.conf.TLSConf.TLSListenAddrs == nil && s.conf.TLSConf.QUICListenAddrs == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if s.conf.TLSListenAddrs == nil && s.conf.QUICListenAddrs == nil {
|
||||
return nil
|
||||
}
|
||||
proxyConfig.TLSListenAddr = s.conf.TLSConf.TLSListenAddrs
|
||||
proxyConfig.QUICListenAddr = s.conf.TLSConf.QUICListenAddrs
|
||||
|
||||
proxyConfig.TLSListenAddr = aghalg.CoalesceSlice(
|
||||
s.conf.TLSListenAddrs,
|
||||
proxyConfig.TLSListenAddr,
|
||||
)
|
||||
|
||||
proxyConfig.QUICListenAddr = aghalg.CoalesceSlice(
|
||||
s.conf.QUICListenAddrs,
|
||||
proxyConfig.QUICListenAddr,
|
||||
)
|
||||
|
||||
s.conf.cert, err = tls.X509KeyPair(s.conf.CertificateChainData, s.conf.PrivateKeyData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse TLS keypair: %w", err)
|
||||
}
|
||||
|
||||
cert, err := x509.ParseCertificate(s.conf.cert.Certificate[0])
|
||||
cert, err := x509.ParseCertificate(s.conf.TLSConf.Cert.Certificate[0])
|
||||
if err != nil {
|
||||
return fmt.Errorf("x509.ParseCertificate(): %w", err)
|
||||
}
|
||||
|
||||
s.conf.hasIPAddrs = aghtls.CertificateHasIP(cert)
|
||||
s.hasIPAddrs = aghtls.CertificateHasIP(cert)
|
||||
|
||||
if s.conf.StrictSNICheck {
|
||||
if s.conf.TLSConf.StrictSNICheck {
|
||||
if len(cert.DNSNames) != 0 {
|
||||
s.conf.dnsNames = cert.DNSNames
|
||||
s.dnsNames = cert.DNSNames
|
||||
log.Debug("dns: using certificate's SAN as DNS names: %v", cert.DNSNames)
|
||||
slices.Sort(s.conf.dnsNames)
|
||||
slices.Sort(s.dnsNames)
|
||||
} else {
|
||||
s.conf.dnsNames = append(s.conf.dnsNames, cert.Subject.CommonName)
|
||||
s.dnsNames = []string{cert.Subject.CommonName}
|
||||
log.Debug("dns: using certificate's CN as DNS name: %s", cert.Subject.CommonName)
|
||||
}
|
||||
}
|
||||
@@ -695,11 +683,11 @@ func anyNameMatches(dnsNames []string, sni string) (ok bool) {
|
||||
// Called by 'tls' package when Client Hello is received
|
||||
// If the server name (from SNI) supplied by client is incorrect - we terminate the ongoing TLS handshake.
|
||||
func (s *Server) onGetCertificate(ch *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
if s.conf.StrictSNICheck && !anyNameMatches(s.conf.dnsNames, ch.ServerName) {
|
||||
if s.conf.TLSConf.StrictSNICheck && !anyNameMatches(s.dnsNames, ch.ServerName) {
|
||||
log.Info("dns: tls: unknown SNI in Client Hello: %s", ch.ServerName)
|
||||
return nil, fmt.Errorf("invalid SNI")
|
||||
}
|
||||
return &s.conf.cert, nil
|
||||
return s.conf.TLSConf.Cert, nil
|
||||
}
|
||||
|
||||
// preparePlain prepares the plain-DNS configuration for the DNS proxy.
|
||||
|
||||
@@ -296,6 +296,7 @@ func TestServer_HandleDNSRequest_dns64(t *testing.T) {
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
UseDNS64: true,
|
||||
TLSConf: &TLSConfig{},
|
||||
Config: Config{
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
@@ -335,6 +336,7 @@ func TestServer_dns64WithDisabledRDNS(t *testing.T) {
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
UseDNS64: true,
|
||||
TLSConf: &TLSConfig{},
|
||||
Config: Config{
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
|
||||
@@ -103,16 +103,26 @@ type SystemResolvers interface {
|
||||
//
|
||||
// The zero Server is empty and ready for use.
|
||||
type Server struct {
|
||||
// dnsProxy is the DNS proxy for forwarding client's DNS requests.
|
||||
dnsProxy *proxy.Proxy
|
||||
// addrProc, if not nil, is used to process clients' IP addresses with rDNS,
|
||||
// WHOIS, etc.
|
||||
addrProc client.AddressProcessor
|
||||
|
||||
// dnsFilter is the DNS filter for filtering client's DNS requests and
|
||||
// responses.
|
||||
dnsFilter *filtering.DNSFilter
|
||||
// bootstrap is the resolver for upstreams' hostnames.
|
||||
bootstrap upstream.Resolver
|
||||
|
||||
// clientIDCache is a temporary storage for ClientIDs that were extracted
|
||||
// during the BeforeRequestHandler stage.
|
||||
clientIDCache cache.Cache
|
||||
|
||||
// dhcpServer is the DHCP server for accessing lease data.
|
||||
dhcpServer DHCP
|
||||
|
||||
// etcHosts contains the current data from the system's hosts files.
|
||||
etcHosts upstream.Resolver
|
||||
|
||||
// privateNets is the configured set of IP networks considered private.
|
||||
privateNets netutil.SubnetSet
|
||||
|
||||
// queryLog is the query log for client's DNS requests, responses and
|
||||
// filtering results.
|
||||
queryLog querylog.QueryLog
|
||||
@@ -120,37 +130,43 @@ type Server struct {
|
||||
// stats is the statistics collector for client's DNS usage data.
|
||||
stats stats.Interface
|
||||
|
||||
// sysResolvers used to fetch system resolvers to use by default for private
|
||||
// PTR resolving.
|
||||
sysResolvers SystemResolvers
|
||||
|
||||
// access drops disallowed clients.
|
||||
access *accessManager
|
||||
|
||||
// anonymizer masks the client's IP addresses if needed.
|
||||
anonymizer *aghnet.IPMut
|
||||
|
||||
// baseLogger is used to create loggers for other entities. It should not
|
||||
// have a prefix and must not be nil.
|
||||
baseLogger *slog.Logger
|
||||
|
||||
// localDomainSuffix is the suffix used to detect internal hosts. It
|
||||
// must be a valid domain name plus dots on each side.
|
||||
localDomainSuffix string
|
||||
// dnsFilter is the DNS filter for filtering client's DNS requests and
|
||||
// responses.
|
||||
dnsFilter *filtering.DNSFilter
|
||||
|
||||
// dnsProxy is the DNS proxy for forwarding client's DNS requests.
|
||||
dnsProxy *proxy.Proxy
|
||||
|
||||
// internalProxy resolves internal requests from the application itself. It
|
||||
// isn't started and so no listen ports are required.
|
||||
internalProxy *proxy.Proxy
|
||||
|
||||
// ipset processes DNS requests using ipset data. It must not be nil after
|
||||
// initialization. See [newIpsetHandler].
|
||||
ipset *ipsetHandler
|
||||
|
||||
// privateNets is the configured set of IP networks considered private.
|
||||
privateNets netutil.SubnetSet
|
||||
// dns64Pref is the NAT64 prefix used for DNS64 response mapping. The major
|
||||
// part of DNS64 happens inside the [proxy] package, but there still are
|
||||
// some places where response mapping is needed (e.g. DHCP).
|
||||
dns64Pref netip.Prefix
|
||||
|
||||
// addrProc, if not nil, is used to process clients' IP addresses with rDNS,
|
||||
// WHOIS, etc.
|
||||
addrProc client.AddressProcessor
|
||||
|
||||
// sysResolvers used to fetch system resolvers to use by default for private
|
||||
// PTR resolving.
|
||||
sysResolvers SystemResolvers
|
||||
|
||||
// etcHosts contains the current data from the system's hosts files.
|
||||
etcHosts upstream.Resolver
|
||||
|
||||
// bootstrap is the resolver for upstreams' hostnames.
|
||||
bootstrap upstream.Resolver
|
||||
// localDomainSuffix is the suffix used to detect internal hosts. It
|
||||
// must be a valid domain name plus dots on each side.
|
||||
localDomainSuffix string
|
||||
|
||||
// bootResolvers are the resolvers that should be used for
|
||||
// bootstrapping along with [etcHosts].
|
||||
@@ -159,34 +175,26 @@ type Server struct {
|
||||
// [upstream.Resolver] interface.
|
||||
bootResolvers []*upstream.UpstreamResolver
|
||||
|
||||
// dns64Pref is the NAT64 prefix used for DNS64 response mapping. The major
|
||||
// part of DNS64 happens inside the [proxy] package, but there still are
|
||||
// some places where response mapping is needed (e.g. DHCP).
|
||||
dns64Pref netip.Prefix
|
||||
|
||||
// anonymizer masks the client's IP addresses if needed.
|
||||
anonymizer *aghnet.IPMut
|
||||
|
||||
// clientIDCache is a temporary storage for ClientIDs that were extracted
|
||||
// during the BeforeRequestHandler stage.
|
||||
clientIDCache cache.Cache
|
||||
|
||||
// internalProxy resolves internal requests from the application itself. It
|
||||
// isn't started and so no listen ports are required.
|
||||
internalProxy *proxy.Proxy
|
||||
|
||||
// isRunning is true if the DNS server is running.
|
||||
isRunning bool
|
||||
|
||||
// protectionUpdateInProgress is used to make sure that only one goroutine
|
||||
// updating the protection configuration after a pause is running at a time.
|
||||
protectionUpdateInProgress atomic.Bool
|
||||
// dnsNames are the DNS names from certificate (SAN) or CN value from
|
||||
// Subject.
|
||||
dnsNames []string
|
||||
|
||||
// conf is the current configuration of the server.
|
||||
conf ServerConfig
|
||||
|
||||
// serverLock protects Server.
|
||||
serverLock sync.RWMutex
|
||||
|
||||
// protectionUpdateInProgress is used to make sure that only one goroutine
|
||||
// updating the protection configuration after a pause is running at a time.
|
||||
protectionUpdateInProgress atomic.Bool
|
||||
|
||||
// isRunning is true if the DNS server is running.
|
||||
isRunning bool
|
||||
|
||||
// hasIPAddrs is set during the certificate parsing and is true if the
|
||||
// configured certificate contains at least a single IP address.
|
||||
hasIPAddrs bool
|
||||
}
|
||||
|
||||
// defaultLocalDomainSuffix is the default suffix used to detect internal hosts
|
||||
|
||||
@@ -213,17 +213,23 @@ func createServerTLSConfig(t *testing.T) (*tls.Config, []byte, []byte) {
|
||||
}, certPem, keyPem
|
||||
}
|
||||
|
||||
func createTestTLS(t *testing.T, tlsConf TLSConfig) (s *Server, certPem []byte) {
|
||||
func createTestTLS(t *testing.T, tlsConf *TLSConfig) (s *Server, certPem []byte) {
|
||||
t.Helper()
|
||||
|
||||
var keyPem []byte
|
||||
_, certPem, keyPem = createServerTLSConfig(t)
|
||||
|
||||
cert, err := tls.X509KeyPair(certPem, keyPem)
|
||||
require.NoError(t, err)
|
||||
|
||||
tlsConf.Cert = &cert
|
||||
|
||||
s = createTestServer(t, &filtering.Config{
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
}, ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
TLSConf: tlsConf,
|
||||
Config: Config{
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
@@ -232,10 +238,7 @@ func createTestTLS(t *testing.T, tlsConf TLSConfig) (s *Server, certPem []byte)
|
||||
ServePlainDNS: true,
|
||||
})
|
||||
|
||||
tlsConf.CertificateChainData, tlsConf.PrivateKeyData = certPem, keyPem
|
||||
s.conf.TLSConfig = tlsConf
|
||||
|
||||
err := s.Prepare(&s.conf)
|
||||
err = s.Prepare(&s.conf)
|
||||
require.NoErrorf(t, err, "failed to prepare server: %s", err)
|
||||
|
||||
return s, certPem
|
||||
@@ -354,6 +357,7 @@ func TestServer(t *testing.T) {
|
||||
}, ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
TLSConf: &TLSConfig{},
|
||||
Config: Config{
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
@@ -395,6 +399,7 @@ func TestServer_timeout(t *testing.T) {
|
||||
t.Run("custom", func(t *testing.T) {
|
||||
srvConf := &ServerConfig{
|
||||
UpstreamTimeout: testTimeout,
|
||||
TLSConf: &TLSConfig{},
|
||||
Config: Config{
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
@@ -422,6 +427,7 @@ func TestServer_timeout(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
s.conf.TLSConf = &TLSConfig{}
|
||||
s.conf.Config.UpstreamMode = UpstreamModeLoadBalance
|
||||
s.conf.Config.EDNSClientSubnet = &EDNSClientSubnet{
|
||||
Enabled: false,
|
||||
@@ -436,6 +442,7 @@ func TestServer_timeout(t *testing.T) {
|
||||
|
||||
func TestServer_Prepare_fallbacks(t *testing.T) {
|
||||
srvConf := &ServerConfig{
|
||||
TLSConf: &TLSConfig{},
|
||||
Config: Config{
|
||||
FallbackDNS: []string{
|
||||
"#tls://1.1.1.1",
|
||||
@@ -466,6 +473,7 @@ func TestServerWithProtectionDisabled(t *testing.T) {
|
||||
}, ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
TLSConf: &TLSConfig{},
|
||||
Config: Config{
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
@@ -487,7 +495,7 @@ func TestServerWithProtectionDisabled(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDoTServer(t *testing.T) {
|
||||
s, certPem := createTestTLS(t, TLSConfig{
|
||||
s, certPem := createTestTLS(t, &TLSConfig{
|
||||
TLSListenAddrs: []*net.TCPAddr{{}},
|
||||
})
|
||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{newGoogleUpstream()}
|
||||
@@ -511,7 +519,7 @@ func TestDoTServer(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDoQServer(t *testing.T) {
|
||||
s, _ := createTestTLS(t, TLSConfig{
|
||||
s, _ := createTestTLS(t, &TLSConfig{
|
||||
QUICListenAddrs: []*net.UDPAddr{{IP: net.IP{127, 0, 0, 1}}},
|
||||
})
|
||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{newGoogleUpstream()}
|
||||
@@ -596,6 +604,7 @@ func TestSafeSearch(t *testing.T) {
|
||||
forwardConf := ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
TLSConf: &TLSConfig{},
|
||||
Config: Config{
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{
|
||||
@@ -690,6 +699,7 @@ func TestInvalidRequest(t *testing.T) {
|
||||
}, ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
TLSConf: &TLSConfig{},
|
||||
Config: Config{
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{
|
||||
@@ -721,6 +731,7 @@ func TestBlockedRequest(t *testing.T) {
|
||||
forwardConf := ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
TLSConf: &TLSConfig{},
|
||||
Config: Config{
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{
|
||||
@@ -758,6 +769,7 @@ func TestServerCustomClientUpstream(t *testing.T) {
|
||||
forwardConf := ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
TLSConf: &TLSConfig{},
|
||||
Config: Config{
|
||||
CacheSize: defaultCacheSize,
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
@@ -838,6 +850,7 @@ func TestBlockCNAMEProtectionEnabled(t *testing.T) {
|
||||
}, ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
TLSConf: &TLSConfig{},
|
||||
Config: Config{
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{
|
||||
@@ -873,6 +886,7 @@ func TestBlockCNAME(t *testing.T) {
|
||||
forwardConf := ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
TLSConf: &TLSConfig{},
|
||||
Config: Config{
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{
|
||||
@@ -947,6 +961,7 @@ func TestClientRulesForCNAMEMatching(t *testing.T) {
|
||||
forwardConf := ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
TLSConf: &TLSConfig{},
|
||||
Config: Config{
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{
|
||||
@@ -994,6 +1009,7 @@ func TestNullBlockedRequest(t *testing.T) {
|
||||
forwardConf := ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
TLSConf: &TLSConfig{},
|
||||
Config: Config{
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{
|
||||
@@ -1064,6 +1080,7 @@ func TestBlockedCustomIP(t *testing.T) {
|
||||
conf := &ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
TLSConf: &TLSConfig{},
|
||||
Config: Config{
|
||||
UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"},
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
@@ -1119,6 +1136,7 @@ func TestBlockedByHosts(t *testing.T) {
|
||||
forwardConf := ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
TLSConf: &TLSConfig{},
|
||||
Config: Config{
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{
|
||||
@@ -1172,6 +1190,7 @@ func TestBlockedBySafeBrowsing(t *testing.T) {
|
||||
forwardConf := ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
TLSConf: &TLSConfig{},
|
||||
Config: Config{
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{
|
||||
@@ -1235,6 +1254,7 @@ func TestRewrite(t *testing.T) {
|
||||
assert.NoError(t, s.Prepare(&ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
TLSConf: &TLSConfig{},
|
||||
Config: Config{
|
||||
UpstreamDNS: []string{"8.8.8.8:53"},
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
@@ -1369,6 +1389,7 @@ func TestPTRResponseFromDHCPLeases(t *testing.T) {
|
||||
s.conf.UDPListenAddrs = []*net.UDPAddr{{}}
|
||||
s.conf.TCPListenAddrs = []*net.TCPAddr{{}}
|
||||
s.conf.UpstreamDNS = []string{"127.0.0.1:53"}
|
||||
s.conf.TLSConf = &TLSConfig{}
|
||||
s.conf.Config.EDNSClientSubnet = &EDNSClientSubnet{Enabled: false}
|
||||
s.conf.Config.ClientsContainer = EmptyClientsContainer{}
|
||||
s.conf.Config.UpstreamMode = UpstreamModeLoadBalance
|
||||
@@ -1457,6 +1478,7 @@ func TestPTRResponseFromHosts(t *testing.T) {
|
||||
s.conf.UDPListenAddrs = []*net.UDPAddr{{}}
|
||||
s.conf.TCPListenAddrs = []*net.TCPAddr{{}}
|
||||
s.conf.UpstreamDNS = []string{"127.0.0.1:53"}
|
||||
s.conf.TLSConf = &TLSConfig{}
|
||||
s.conf.Config.EDNSClientSubnet = &EDNSClientSubnet{Enabled: false}
|
||||
s.conf.Config.ClientsContainer = EmptyClientsContainer{}
|
||||
s.conf.Config.UpstreamMode = UpstreamModeLoadBalance
|
||||
@@ -1723,6 +1745,7 @@ func TestServer_Exchange(t *testing.T) {
|
||||
srv := createTestServer(t, &filtering.Config{
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
}, ServerConfig{
|
||||
TLSConf: &TLSConfig{},
|
||||
Config: Config{
|
||||
UpstreamDNS: []string{upsAddr},
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
@@ -1746,6 +1769,7 @@ func TestServer_Exchange(t *testing.T) {
|
||||
srv := createTestServer(t, &filtering.Config{
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
}, ServerConfig{
|
||||
TLSConf: &TLSConfig{},
|
||||
Config: Config{
|
||||
UpstreamDNS: []string{upsAddr},
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
|
||||
@@ -37,6 +37,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||
srv := createTestServer(t, &filtering.Config{
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
}, ServerConfig{
|
||||
TLSConf: &TLSConfig{},
|
||||
Config: Config{
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
|
||||
@@ -31,6 +31,7 @@ func TestHandleDNSRequest_handleDNSRequest(t *testing.T) {
|
||||
forwardConf := ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
TLSConf: &TLSConfig{},
|
||||
Config: Config{
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{
|
||||
|
||||
@@ -76,6 +76,7 @@ func TestDNSForwardHTTP_handleGetConfig(t *testing.T) {
|
||||
forwardConf := ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{},
|
||||
TCPListenAddrs: []*net.TCPAddr{},
|
||||
TLSConf: &TLSConfig{},
|
||||
Config: Config{
|
||||
UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"},
|
||||
FallbackDNS: []string{"9.9.9.10"},
|
||||
@@ -159,6 +160,7 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
|
||||
forwardConf := ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{},
|
||||
TCPListenAddrs: []*net.TCPAddr{},
|
||||
TLSConf: &TLSConfig{},
|
||||
Config: Config{
|
||||
UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"},
|
||||
RatelimitSubnetLenIPv4: 24,
|
||||
@@ -369,6 +371,7 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) {
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
UpstreamTimeout: upsTimeout,
|
||||
TLSConf: &TLSConfig{},
|
||||
Config: Config{
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
|
||||
@@ -246,9 +246,9 @@ func (s *Server) makeDDRResponse(req *dns.Msg) (resp *dns.Msg) {
|
||||
|
||||
// TODO(e.burkov): Think about storing the FQDN version of the server's
|
||||
// name somewhere.
|
||||
domainName := dns.Fqdn(s.conf.ServerName)
|
||||
domainName := dns.Fqdn(s.conf.TLSConf.ServerName)
|
||||
|
||||
for _, addr := range s.conf.HTTPSListenAddrs {
|
||||
for _, addr := range s.conf.TLSConf.HTTPSListenAddrs {
|
||||
values := []dns.SVCBKeyValue{
|
||||
&dns.SVCBAlpn{Alpn: []string{"h2"}},
|
||||
&dns.SVCBPort{Port: uint16(addr.Port)},
|
||||
@@ -265,7 +265,7 @@ func (s *Server) makeDDRResponse(req *dns.Msg) (resp *dns.Msg) {
|
||||
resp.Answer = append(resp.Answer, ans)
|
||||
}
|
||||
|
||||
if s.conf.hasIPAddrs {
|
||||
if s.hasIPAddrs {
|
||||
// Only add DNS-over-TLS resolvers in case the certificate contains IP
|
||||
// addresses.
|
||||
//
|
||||
|
||||
@@ -3,6 +3,7 @@ package dnsforward
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
@@ -77,6 +78,7 @@ func TestServer_ProcessInitial(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c := ServerConfig{
|
||||
TLSConf: &TLSConfig{},
|
||||
Config: Config{
|
||||
AAAADisabled: tc.aaaaDisabled,
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
@@ -177,6 +179,7 @@ func TestServer_ProcessFilteringAfterResponse(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c := ServerConfig{
|
||||
TLSConf: &TLSConfig{},
|
||||
Config: Config{
|
||||
AAAADisabled: tc.aaaaDisabled,
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
@@ -316,6 +319,8 @@ func TestServer_ProcessDDRQuery(t *testing.T) {
|
||||
}}
|
||||
|
||||
_, certPem, keyPem := createServerTLSConfig(t)
|
||||
cert, err := tls.X509KeyPair(certPem, keyPem)
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
@@ -328,19 +333,18 @@ func TestServer_ProcessDDRQuery(t *testing.T) {
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
ClientsContainer: EmptyClientsContainer{},
|
||||
},
|
||||
TLSConfig: TLSConfig{
|
||||
ServerName: ddrTestDomainName,
|
||||
CertificateChainData: certPem,
|
||||
PrivateKeyData: keyPem,
|
||||
TLSListenAddrs: tc.addrsDoT,
|
||||
HTTPSListenAddrs: tc.addrsDoH,
|
||||
QUICListenAddrs: tc.addrsDoQ,
|
||||
TLSConf: &TLSConfig{
|
||||
ServerName: ddrTestDomainName,
|
||||
Cert: &cert,
|
||||
TLSListenAddrs: tc.addrsDoT,
|
||||
HTTPSListenAddrs: tc.addrsDoH,
|
||||
QUICListenAddrs: tc.addrsDoQ,
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
})
|
||||
// TODO(e.burkov): Generate a certificate actually containing the
|
||||
// IP addresses.
|
||||
s.conf.hasIPAddrs = true
|
||||
s.hasIPAddrs = true
|
||||
|
||||
req := createTestMessageWithType(tc.host, tc.qtype)
|
||||
|
||||
@@ -657,6 +661,7 @@ func TestServer_HandleDNSRequest_restrictLocal(t *testing.T) {
|
||||
}, ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
TLSConf: &TLSConfig{},
|
||||
// TODO(s.chzhen): Add tests where EDNSClientSubnet.Enabled is true.
|
||||
// Improve Config declaration for tests.
|
||||
Config: Config{
|
||||
@@ -789,6 +794,7 @@ func TestServer_ProcessUpstream_localPTR(t *testing.T) {
|
||||
ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
TLSConf: &TLSConfig{},
|
||||
Config: Config{
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
@@ -818,6 +824,7 @@ func TestServer_ProcessUpstream_localPTR(t *testing.T) {
|
||||
ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
TLSConf: &TLSConfig{},
|
||||
Config: Config{
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
|
||||
@@ -16,6 +16,7 @@ func TestGenAnswerHTTPS_andSVCB(t *testing.T) {
|
||||
s := createTestServer(t, &filtering.Config{
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
}, ServerConfig{
|
||||
TLSConf: &TLSConfig{},
|
||||
Config: Config{
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
|
||||
@@ -1,317 +1,131 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghuser"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"go.etcd.io/bbolt"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// sessionTokenSize is the length of session token in bytes.
|
||||
const sessionTokenSize = 16
|
||||
// webUser represents a user of the Web UI.
|
||||
type webUser struct {
|
||||
// Name represents the login name of the web user.
|
||||
Name string `yaml:"name"`
|
||||
|
||||
type session struct {
|
||||
userName string
|
||||
// expire is the expiration time, in seconds.
|
||||
expire uint32
|
||||
// PasswordHash is the hashed representation of the web user password.
|
||||
PasswordHash string `yaml:"password"`
|
||||
|
||||
// UserID is the unique identifier of the web user.
|
||||
//
|
||||
// TODO(s.chzhen): !! Use this.
|
||||
UserID aghuser.UserID `yaml:"-"`
|
||||
}
|
||||
|
||||
func (s *session) serialize() []byte {
|
||||
const (
|
||||
expireLen = 4
|
||||
nameLen = 2
|
||||
)
|
||||
data := make([]byte, expireLen+nameLen+len(s.userName))
|
||||
binary.BigEndian.PutUint32(data[0:4], s.expire)
|
||||
binary.BigEndian.PutUint16(data[4:6], uint16(len(s.userName)))
|
||||
copy(data[6:], []byte(s.userName))
|
||||
return data
|
||||
}
|
||||
|
||||
func (s *session) deserialize(data []byte) bool {
|
||||
if len(data) < 4+2 {
|
||||
return false
|
||||
// toUser returns the new properly initialized *aghuser.User using stored
|
||||
// properties. It panics if there is an error generating the user ID.
|
||||
func (wu *webUser) toUser() (u *aghuser.User) {
|
||||
uid := wu.UserID
|
||||
if uid == (aghuser.UserID{}) {
|
||||
uid = aghuser.MustNewUserID()
|
||||
}
|
||||
s.expire = binary.BigEndian.Uint32(data[0:4])
|
||||
nameLen := binary.BigEndian.Uint16(data[4:6])
|
||||
data = data[6:]
|
||||
|
||||
if len(data) < int(nameLen) {
|
||||
return false
|
||||
return &aghuser.User{
|
||||
Password: aghuser.NewDefaultPassword(wu.PasswordHash),
|
||||
Login: aghuser.Login(wu.Name),
|
||||
ID: uid,
|
||||
}
|
||||
s.userName = string(data)
|
||||
return true
|
||||
}
|
||||
|
||||
// Auth is the global authentication object.
|
||||
type Auth struct {
|
||||
trustedProxies netutil.SubnetSet
|
||||
db *bbolt.DB
|
||||
logger *slog.Logger
|
||||
rateLimiter *authRateLimiter
|
||||
sessions map[string]*session
|
||||
users []webUser
|
||||
lock sync.Mutex
|
||||
sessionTTL uint32
|
||||
sessions aghuser.SessionStorage
|
||||
trustedProxies netutil.SubnetSet
|
||||
users aghuser.DB
|
||||
}
|
||||
|
||||
// webUser represents a user of the Web UI.
|
||||
//
|
||||
// TODO(s.chzhen): Improve naming.
|
||||
type webUser struct {
|
||||
Name string `yaml:"name"`
|
||||
PasswordHash string `yaml:"password"`
|
||||
}
|
||||
|
||||
// InitAuth initializes the global authentication object.
|
||||
// InitAuth initializes the global authentication object. baseLogger,
|
||||
// rateLimiter, trustedProxies must not be nil. dbFilename and sessionTTL
|
||||
// should not be empty.
|
||||
func InitAuth(
|
||||
ctx context.Context,
|
||||
baseLogger *slog.Logger,
|
||||
dbFilename string,
|
||||
users []webUser,
|
||||
sessionTTL uint32,
|
||||
sessionTTL time.Duration,
|
||||
rateLimiter *authRateLimiter,
|
||||
trustedProxies netutil.SubnetSet,
|
||||
) (a *Auth) {
|
||||
log.Info("Initializing auth module: %s", dbFilename)
|
||||
|
||||
a = &Auth{
|
||||
sessionTTL: sessionTTL,
|
||||
rateLimiter: rateLimiter,
|
||||
sessions: make(map[string]*session),
|
||||
users: users,
|
||||
trustedProxies: trustedProxies,
|
||||
}
|
||||
var err error
|
||||
|
||||
a.db, err = bbolt.Open(dbFilename, aghos.DefaultPermFile, nil)
|
||||
if err != nil {
|
||||
log.Error("auth: open DB: %s: %s", dbFilename, err)
|
||||
if err.Error() == "invalid argument" {
|
||||
log.Error("AdGuard Home cannot be initialized due to an incompatible file system.\nPlease read the explanation here: https://github.com/AdguardTeam/AdGuardHome/wiki/Getting-Started#limitations")
|
||||
) (a *Auth, err error) {
|
||||
userDB := aghuser.NewDefaultDB()
|
||||
for i, u := range users {
|
||||
err = userDB.Create(ctx, u.toUser())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("users: at index %d: %w", i, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
a.loadSessions()
|
||||
log.Info("auth: initialized. users:%d sessions:%d", len(a.users), len(a.sessions))
|
||||
|
||||
return a
|
||||
s, err := aghuser.NewDefaultSessionStorage(ctx, &aghuser.DefaultSessionStorageConfig{
|
||||
Logger: baseLogger.With(slogutil.KeyPrefix, "session_storage"),
|
||||
Clock: timeutil.SystemClock{},
|
||||
UserDB: aghuser.NewDefaultDB(),
|
||||
DBPath: dbFilename,
|
||||
SessionTTL: sessionTTL,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating session storage: %w", err)
|
||||
}
|
||||
|
||||
return &Auth{
|
||||
logger: baseLogger.With(slogutil.KeyPrefix, "auth"),
|
||||
rateLimiter: rateLimiter,
|
||||
trustedProxies: trustedProxies,
|
||||
sessions: s,
|
||||
users: userDB,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Close closes the authentication database.
|
||||
func (a *Auth) Close() {
|
||||
_ = a.db.Close()
|
||||
}
|
||||
|
||||
func bucketName() []byte {
|
||||
return []byte("sessions-2")
|
||||
}
|
||||
|
||||
// loadSessions loads sessions from the database file and removes expired
|
||||
// sessions.
|
||||
func (a *Auth) loadSessions() {
|
||||
tx, err := a.db.Begin(true)
|
||||
func (a *Auth) Close(ctx context.Context) {
|
||||
err := a.sessions.Close()
|
||||
if err != nil {
|
||||
log.Error("auth: bbolt.Begin: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = tx.Rollback()
|
||||
}()
|
||||
|
||||
bkt := tx.Bucket(bucketName())
|
||||
if bkt == nil {
|
||||
return
|
||||
}
|
||||
|
||||
removed := 0
|
||||
|
||||
if tx.Bucket([]byte("sessions")) != nil {
|
||||
_ = tx.DeleteBucket([]byte("sessions"))
|
||||
removed = 1
|
||||
}
|
||||
|
||||
now := uint32(time.Now().UTC().Unix())
|
||||
forEach := func(k, v []byte) error {
|
||||
s := session{}
|
||||
if !s.deserialize(v) || s.expire <= now {
|
||||
err = bkt.Delete(k)
|
||||
if err != nil {
|
||||
log.Error("auth: bbolt.Delete: %s", err)
|
||||
} else {
|
||||
removed++
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
a.sessions[hex.EncodeToString(k)] = &s
|
||||
return nil
|
||||
}
|
||||
_ = bkt.ForEach(forEach)
|
||||
if removed != 0 {
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
log.Error("bolt.Commit(): %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug("auth: loaded %d sessions from DB (removed %d expired)", len(a.sessions), removed)
|
||||
}
|
||||
|
||||
// addSession adds a new session to the list of sessions and saves it in the
|
||||
// database file.
|
||||
func (a *Auth) addSession(data []byte, s *session) {
|
||||
name := hex.EncodeToString(data)
|
||||
a.lock.Lock()
|
||||
a.sessions[name] = s
|
||||
a.lock.Unlock()
|
||||
if a.storeSession(data, s) {
|
||||
log.Debug("auth: created session %s: expire=%d", name, s.expire)
|
||||
a.logger.ErrorContext(ctx, "closing session storage", slogutil.KeyError, err)
|
||||
}
|
||||
}
|
||||
|
||||
// storeSession saves a session in the database file.
|
||||
func (a *Auth) storeSession(data []byte, s *session) bool {
|
||||
tx, err := a.db.Begin(true)
|
||||
// isValidSession returns true if the session is valid.
|
||||
func (a *Auth) isValidSession(ctx context.Context, cookieSess string) (ok bool) {
|
||||
sess, err := hex.DecodeString(cookieSess)
|
||||
if err != nil {
|
||||
log.Error("auth: bbolt.Begin: %s", err)
|
||||
|
||||
return false
|
||||
}
|
||||
defer func() {
|
||||
_ = tx.Rollback()
|
||||
}()
|
||||
|
||||
bkt, err := tx.CreateBucketIfNotExists(bucketName())
|
||||
if err != nil {
|
||||
log.Error("auth: bbolt.CreateBucketIfNotExists: %s", err)
|
||||
a.logger.ErrorContext(ctx, "checking session: decoding cookie", slogutil.KeyError, err)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
err = bkt.Put(data, s.serialize())
|
||||
var t aghuser.SessionToken
|
||||
copy(t[:], sess)
|
||||
|
||||
s, err := a.sessions.FindByToken(ctx, t)
|
||||
if err != nil {
|
||||
log.Error("auth: bbolt.Put: %s", err)
|
||||
a.logger.ErrorContext(ctx, "checking session", slogutil.KeyError, err)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
log.Error("auth: bbolt.Commit: %s", err)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
return s != nil
|
||||
}
|
||||
|
||||
// removeSessionFromFile removes a stored session from the DB file on disk.
|
||||
func (a *Auth) removeSessionFromFile(sess []byte) {
|
||||
tx, err := a.db.Begin(true)
|
||||
if err != nil {
|
||||
log.Error("auth: bbolt.Begin: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
_ = tx.Rollback()
|
||||
}()
|
||||
|
||||
bkt := tx.Bucket(bucketName())
|
||||
if bkt == nil {
|
||||
log.Error("auth: bbolt.Bucket")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = bkt.Delete(sess)
|
||||
if err != nil {
|
||||
log.Error("auth: bbolt.Put: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
log.Error("auth: bbolt.Commit: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug("auth: removed session from DB")
|
||||
}
|
||||
|
||||
// checkSessionResult is the result of checking a session.
|
||||
type checkSessionResult int
|
||||
|
||||
// checkSessionResult constants.
|
||||
const (
|
||||
checkSessionOK checkSessionResult = 0
|
||||
checkSessionNotFound checkSessionResult = -1
|
||||
checkSessionExpired checkSessionResult = 1
|
||||
)
|
||||
|
||||
// checkSession checks if the session is valid.
|
||||
func (a *Auth) checkSession(sess string) (res checkSessionResult) {
|
||||
now := uint32(time.Now().UTC().Unix())
|
||||
update := false
|
||||
|
||||
a.lock.Lock()
|
||||
defer a.lock.Unlock()
|
||||
|
||||
s, ok := a.sessions[sess]
|
||||
if !ok {
|
||||
return checkSessionNotFound
|
||||
}
|
||||
|
||||
if s.expire <= now {
|
||||
delete(a.sessions, sess)
|
||||
key, _ := hex.DecodeString(sess)
|
||||
a.removeSessionFromFile(key)
|
||||
|
||||
return checkSessionExpired
|
||||
}
|
||||
|
||||
newExpire := now + a.sessionTTL
|
||||
if s.expire/(24*60*60) != newExpire/(24*60*60) {
|
||||
// update expiration time once a day
|
||||
update = true
|
||||
s.expire = newExpire
|
||||
}
|
||||
|
||||
if update {
|
||||
key, _ := hex.DecodeString(sess)
|
||||
if a.storeSession(key, s) {
|
||||
log.Debug("auth: updated session %s: expire=%d", sess, s.expire)
|
||||
}
|
||||
}
|
||||
|
||||
return checkSessionOK
|
||||
}
|
||||
|
||||
// removeSession removes the session from the active sessions and the disk.
|
||||
func (a *Auth) removeSession(sess string) {
|
||||
key, _ := hex.DecodeString(sess)
|
||||
a.lock.Lock()
|
||||
delete(a.sessions, sess)
|
||||
a.lock.Unlock()
|
||||
a.removeSessionFromFile(key)
|
||||
}
|
||||
|
||||
// addUser adds a new user with the given password.
|
||||
func (a *Auth) addUser(u *webUser, password string) (err error) {
|
||||
// addUser adds a new user with the given password. u must not be nil.
|
||||
func (a *Auth) addUser(ctx context.Context, u *webUser, password string) (err error) {
|
||||
if len(password) == 0 {
|
||||
return errors.Error("empty password")
|
||||
}
|
||||
@@ -323,97 +137,129 @@ func (a *Auth) addUser(u *webUser, password string) (err error) {
|
||||
|
||||
u.PasswordHash = string(hash)
|
||||
|
||||
a.lock.Lock()
|
||||
defer a.lock.Unlock()
|
||||
err = a.users.Create(ctx, u.toUser())
|
||||
if err != nil {
|
||||
// Should not happen.
|
||||
panic(err)
|
||||
}
|
||||
|
||||
a.users = append(a.users, *u)
|
||||
|
||||
log.Debug("auth: added user with login %q", u.Name)
|
||||
a.logger.DebugContext(ctx, "added user", "login", u.Name)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// findUser returns a user if there is one.
|
||||
func (a *Auth) findUser(login, password string) (u webUser, ok bool) {
|
||||
a.lock.Lock()
|
||||
defer a.lock.Unlock()
|
||||
|
||||
for _, u = range a.users {
|
||||
if u.Name == login &&
|
||||
bcrypt.CompareHashAndPassword([]byte(u.PasswordHash), []byte(password)) == nil {
|
||||
return u, true
|
||||
}
|
||||
// findUser returns a user if one exists with the provided login and the
|
||||
// password matches.
|
||||
func (a *Auth) findUser(ctx context.Context, login, password string) (user *aghuser.User) {
|
||||
user, err := a.users.ByLogin(ctx, aghuser.Login(login))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return webUser{}, false
|
||||
ok := user.Password.Authenticate(ctx, password)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
return user
|
||||
}
|
||||
|
||||
// getCurrentUser returns the current user. It returns an empty User if the
|
||||
// user is not found.
|
||||
func (a *Auth) getCurrentUser(r *http.Request) (u webUser) {
|
||||
// getCurrentUser searches for a user using a cookie or credentials from basic
|
||||
// authentication.
|
||||
func (a *Auth) getCurrentUser(r *http.Request) (user *aghuser.User) {
|
||||
ctx := r.Context()
|
||||
cookie, err := r.Cookie(sessionCookieName)
|
||||
if err != nil {
|
||||
// There's no Cookie, check Basic authentication.
|
||||
user, pass, ok := r.BasicAuth()
|
||||
if ok {
|
||||
u, _ = globalContext.auth.findUser(user, pass)
|
||||
|
||||
return u
|
||||
return a.findUser(ctx, user, pass)
|
||||
}
|
||||
|
||||
return webUser{}
|
||||
return nil
|
||||
}
|
||||
|
||||
a.lock.Lock()
|
||||
defer a.lock.Unlock()
|
||||
sess, err := hex.DecodeString(cookie.Value)
|
||||
if err != nil {
|
||||
a.logger.ErrorContext(
|
||||
ctx,
|
||||
"searching for user: decoding cookie value",
|
||||
slogutil.KeyError, err,
|
||||
)
|
||||
|
||||
s, ok := a.sessions[cookie.Value]
|
||||
if !ok {
|
||||
return webUser{}
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, u = range a.users {
|
||||
if u.Name == s.userName {
|
||||
return u
|
||||
}
|
||||
var t aghuser.SessionToken
|
||||
copy(t[:], sess)
|
||||
|
||||
s, err := a.sessions.FindByToken(ctx, t)
|
||||
if err != nil {
|
||||
a.logger.ErrorContext(ctx, "searching for user", slogutil.KeyError, err)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return webUser{}
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &aghuser.User{
|
||||
Login: s.UserLogin,
|
||||
ID: s.UserID,
|
||||
}
|
||||
}
|
||||
|
||||
// removeSession deletes the session from the active sessions and the disk. It
|
||||
// also logs any occurring errors.
|
||||
func (a *Auth) removeSession(ctx context.Context, cookieSess string) {
|
||||
sess, err := hex.DecodeString(cookieSess)
|
||||
if err != nil {
|
||||
a.logger.ErrorContext(ctx, "removing session: decoding cookie", slogutil.KeyError, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
var t aghuser.SessionToken
|
||||
copy(t[:], sess)
|
||||
|
||||
err = a.sessions.DeleteByToken(ctx, t)
|
||||
if err != nil {
|
||||
a.logger.ErrorContext(ctx, "removing session by token", slogutil.KeyError, err)
|
||||
}
|
||||
}
|
||||
|
||||
// usersList returns a copy of a users list.
|
||||
func (a *Auth) usersList() (users []webUser) {
|
||||
a.lock.Lock()
|
||||
defer a.lock.Unlock()
|
||||
func (a *Auth) usersList(ctx context.Context) (webUsers []webUser) {
|
||||
users, err := a.users.All(ctx)
|
||||
if err != nil {
|
||||
// Should not happen.
|
||||
panic(err)
|
||||
}
|
||||
|
||||
users = make([]webUser, len(a.users))
|
||||
copy(users, a.users)
|
||||
webUsers = make([]webUser, 0, len(users))
|
||||
for _, u := range users {
|
||||
webUsers = append(webUsers, webUser{
|
||||
Name: string(u.Login),
|
||||
PasswordHash: string(u.Password.Hash()),
|
||||
UserID: u.ID,
|
||||
})
|
||||
}
|
||||
|
||||
return users
|
||||
return webUsers
|
||||
}
|
||||
|
||||
// authRequired returns true if a authentication is required.
|
||||
func (a *Auth) authRequired() bool {
|
||||
func (a *Auth) authRequired(ctx context.Context) (ok bool) {
|
||||
if GLMode {
|
||||
return true
|
||||
}
|
||||
|
||||
a.lock.Lock()
|
||||
defer a.lock.Unlock()
|
||||
users, err := a.users.All(ctx)
|
||||
if err != nil {
|
||||
// Should not happen.
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return len(a.users) != 0
|
||||
}
|
||||
|
||||
// newSessionToken returns cryptographically secure randomly generated slice of
|
||||
// bytes of sessionTokenSize length.
|
||||
//
|
||||
// TODO(e.burkov): Think about using byte array instead of byte slice.
|
||||
func newSessionToken() (data []byte) {
|
||||
randData := make([]byte, sessionTokenSize)
|
||||
|
||||
// Since Go 1.24, crypto/rand.Read doesn't return an error and crashes
|
||||
// unrecoverably instead.
|
||||
_, _ = rand.Read(randData)
|
||||
|
||||
return randData
|
||||
return len(users) != 0
|
||||
}
|
||||
|
||||
@@ -1,69 +0,0 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAuth(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
fn := filepath.Join(dir, "sessions.db")
|
||||
|
||||
users := []webUser{{
|
||||
Name: "name",
|
||||
PasswordHash: "$2y$05$..vyzAECIhJPfaQiOK17IukcQnqEgKJHy0iETyYqxn3YXJl8yZuo2",
|
||||
}}
|
||||
a := InitAuth(fn, nil, 60, nil, nil)
|
||||
s := session{}
|
||||
|
||||
user := webUser{Name: "name"}
|
||||
err := a.addUser(&user, "password")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, checkSessionNotFound, a.checkSession("notfound"))
|
||||
a.removeSession("notfound")
|
||||
|
||||
sess := newSessionToken()
|
||||
sessStr := hex.EncodeToString(sess)
|
||||
|
||||
now := time.Now().UTC().Unix()
|
||||
// check expiration
|
||||
s.expire = uint32(now)
|
||||
a.addSession(sess, &s)
|
||||
assert.Equal(t, checkSessionExpired, a.checkSession(sessStr))
|
||||
|
||||
// add session with TTL = 2 sec
|
||||
s = session{}
|
||||
s.expire = uint32(time.Now().UTC().Unix() + 2)
|
||||
a.addSession(sess, &s)
|
||||
assert.Equal(t, checkSessionOK, a.checkSession(sessStr))
|
||||
|
||||
a.Close()
|
||||
|
||||
// load saved session
|
||||
a = InitAuth(fn, users, 60, nil, nil)
|
||||
|
||||
// the session is still alive
|
||||
assert.Equal(t, checkSessionOK, a.checkSession(sessStr))
|
||||
// reset our expiration time because checkSession() has just updated it
|
||||
s.expire = uint32(time.Now().UTC().Unix() + 2)
|
||||
a.storeSession(sess, &s)
|
||||
a.Close()
|
||||
|
||||
u, ok := a.findUser("name", "password")
|
||||
assert.True(t, ok)
|
||||
assert.NotEmpty(t, u.Name)
|
||||
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
// load and remove expired sessions
|
||||
a = InitAuth(fn, users, 60, nil, nil)
|
||||
assert.Equal(t, checkSessionNotFound, a.checkSession(sessStr))
|
||||
|
||||
a.Close()
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@@ -32,10 +33,14 @@ type loginJSON struct {
|
||||
}
|
||||
|
||||
// newCookie creates a new authentication cookie.
|
||||
func (a *Auth) newCookie(req loginJSON, addr string) (c *http.Cookie, err error) {
|
||||
func (a *Auth) newCookie(
|
||||
ctx context.Context,
|
||||
req loginJSON,
|
||||
addr string,
|
||||
) (c *http.Cookie, err error) {
|
||||
rateLimiter := a.rateLimiter
|
||||
u, ok := a.findUser(req.Name, req.Password)
|
||||
if !ok {
|
||||
u := a.findUser(ctx, req.Name, req.Password)
|
||||
if u == nil {
|
||||
if rateLimiter != nil {
|
||||
rateLimiter.inc(addr)
|
||||
}
|
||||
@@ -47,19 +52,16 @@ func (a *Auth) newCookie(req loginJSON, addr string) (c *http.Cookie, err error)
|
||||
rateLimiter.remove(addr)
|
||||
}
|
||||
|
||||
sess := newSessionToken()
|
||||
now := time.Now().UTC()
|
||||
|
||||
a.addSession(sess, &session{
|
||||
userName: u.Name,
|
||||
expire: uint32(now.Unix()) + a.sessionTTL,
|
||||
})
|
||||
s, err := a.sessions.New(ctx, u)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating session: %w", err)
|
||||
}
|
||||
|
||||
return &http.Cookie{
|
||||
Name: sessionCookieName,
|
||||
Value: hex.EncodeToString(sess),
|
||||
Value: hex.EncodeToString(s.Token[:]),
|
||||
Path: "/",
|
||||
Expires: now.Add(cookieTTL),
|
||||
Expires: time.Now().Add(cookieTTL),
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
}, nil
|
||||
@@ -172,7 +174,7 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
|
||||
log.Error("auth: getting real ip from request with remote ip %s: %s", remoteIP, err)
|
||||
}
|
||||
|
||||
cookie, err := globalContext.auth.newCookie(req, remoteIP)
|
||||
cookie, err := globalContext.auth.newCookie(r.Context(), req, remoteIP)
|
||||
if err != nil {
|
||||
logIP := remoteIP
|
||||
if globalContext.auth.trustedProxies.Contains(ip.Unmap()) {
|
||||
@@ -209,7 +211,7 @@ func handleLogout(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
globalContext.auth.removeSession(c.Value)
|
||||
globalContext.auth.removeSession(r.Context(), c.Value)
|
||||
|
||||
c = &http.Cookie{
|
||||
Name: sessionCookieName,
|
||||
@@ -242,28 +244,7 @@ func optionalAuthThird(w http.ResponseWriter, r *http.Request) (mustAuth bool) {
|
||||
return false
|
||||
}
|
||||
|
||||
// redirect to login page if not authenticated
|
||||
isAuthenticated := false
|
||||
cookie, err := r.Cookie(sessionCookieName)
|
||||
if err != nil {
|
||||
// The only error that is returned from r.Cookie is [http.ErrNoCookie].
|
||||
// Check Basic authentication.
|
||||
user, pass, hasBasic := r.BasicAuth()
|
||||
if hasBasic {
|
||||
_, isAuthenticated = globalContext.auth.findUser(user, pass)
|
||||
if !isAuthenticated {
|
||||
log.Info("%s: invalid basic authorization value", pref)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
res := globalContext.auth.checkSession(cookie.Value)
|
||||
isAuthenticated = res == checkSessionOK
|
||||
if !isAuthenticated {
|
||||
log.Debug("%s: invalid cookie value: %q", pref, cookie)
|
||||
}
|
||||
}
|
||||
|
||||
if isAuthenticated {
|
||||
if u := globalContext.auth.getCurrentUser(r); u != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -289,14 +270,14 @@ func optionalAuth(
|
||||
h func(http.ResponseWriter, *http.Request),
|
||||
) (wrapped func(http.ResponseWriter, *http.Request)) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
p := r.URL.Path
|
||||
authRequired := globalContext.auth != nil && globalContext.auth.authRequired()
|
||||
authRequired := globalContext.auth != nil && globalContext.auth.authRequired(ctx)
|
||||
if p == "/login.html" {
|
||||
cookie, err := r.Cookie(sessionCookieName)
|
||||
if authRequired && err == nil {
|
||||
// Redirect to the dashboard if already authenticated.
|
||||
res := globalContext.auth.checkSession(cookie.Value)
|
||||
if res == checkSessionOK {
|
||||
if globalContext.auth.isValidSession(ctx, cookie.Value) {
|
||||
http.Redirect(w, r, "", http.StatusFound)
|
||||
|
||||
return
|
||||
|
||||
@@ -7,8 +7,10 @@ import (
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/httphdr"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -33,13 +35,20 @@ func (w *testResponseWriter) WriteHeader(statusCode int) {
|
||||
}
|
||||
|
||||
func TestAuthHTTP(t *testing.T) {
|
||||
var (
|
||||
ctx = testutil.ContextWithTimeout(t, testTimeout)
|
||||
logger = slogutil.NewDiscardLogger()
|
||||
err error
|
||||
)
|
||||
|
||||
dir := t.TempDir()
|
||||
fn := filepath.Join(dir, "sessions.db")
|
||||
|
||||
users := []webUser{
|
||||
{Name: "name", PasswordHash: "$2y$05$..vyzAECIhJPfaQiOK17IukcQnqEgKJHy0iETyYqxn3YXJl8yZuo2"},
|
||||
}
|
||||
globalContext.auth = InitAuth(fn, users, 60, nil, nil)
|
||||
globalContext.auth, err = InitAuth(ctx, logger, fn, users, time.Minute, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
handlerCalled := false
|
||||
handler := func(_ http.ResponseWriter, _ *http.Request) {
|
||||
@@ -68,7 +77,11 @@ func TestAuthHTTP(t *testing.T) {
|
||||
assert.True(t, handlerCalled)
|
||||
|
||||
// perform login
|
||||
cookie, err := globalContext.auth.newCookie(loginJSON{Name: "name", Password: "password"}, "")
|
||||
cookie, err := globalContext.auth.newCookie(
|
||||
ctx,
|
||||
loginJSON{Name: "name", Password: "password"},
|
||||
"",
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cookie)
|
||||
|
||||
@@ -114,7 +127,7 @@ func TestAuthHTTP(t *testing.T) {
|
||||
assert.True(t, handlerCalled)
|
||||
r.Header.Del(httphdr.Cookie)
|
||||
|
||||
globalContext.auth.Close()
|
||||
globalContext.auth.Close(ctx)
|
||||
}
|
||||
|
||||
func TestRealIP(t *testing.T) {
|
||||
|
||||
@@ -2,10 +2,12 @@ package home
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"sync"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
@@ -23,6 +25,7 @@ import (
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/renameio/v2/maybe"
|
||||
yaml "gopkg.in/yaml.v3"
|
||||
)
|
||||
@@ -261,30 +264,128 @@ type dnsConfig struct {
|
||||
// HostsFileEnabled defines whether to use information from the system hosts
|
||||
// file to resolve queries.
|
||||
HostsFileEnabled bool `yaml:"hostsfile_enabled"`
|
||||
|
||||
// PendingRequests configures duplicate requests policy.
|
||||
PendingRequests *pendingRequests `yaml:"pending_requests"`
|
||||
}
|
||||
|
||||
type tlsConfigSettings struct {
|
||||
Enabled bool `yaml:"enabled" json:"enabled"` // Enabled is the encryption (DoT/DoH/HTTPS) status
|
||||
ServerName string `yaml:"server_name" json:"server_name,omitempty"` // ServerName is the hostname of your HTTPS/TLS server
|
||||
ForceHTTPS bool `yaml:"force_https" json:"force_https"` // ForceHTTPS: if true, forces HTTP->HTTPS redirect
|
||||
PortHTTPS uint16 `yaml:"port_https" json:"port_https,omitempty"` // HTTPS port. If 0, HTTPS will be disabled
|
||||
PortDNSOverTLS uint16 `yaml:"port_dns_over_tls" json:"port_dns_over_tls,omitempty"` // DNS-over-TLS port. If 0, DoT will be disabled
|
||||
PortDNSOverQUIC uint16 `yaml:"port_dns_over_quic" json:"port_dns_over_quic,omitempty"` // DNS-over-QUIC port. If 0, DoQ will be disabled
|
||||
// pendingRequests is a block with pending requests configuration.
|
||||
type pendingRequests struct {
|
||||
// Enabled controls if duplicate requests should be sent to the upstreams
|
||||
// along with the original one.
|
||||
Enabled bool `yaml:"enabled"`
|
||||
}
|
||||
|
||||
// PortDNSCrypt is the port for DNSCrypt requests. If it's zero,
|
||||
// DNSCrypt is disabled.
|
||||
// tlsConfigSettings is the TLS configuration for DNS-over-TLS, DNS-over-QUIC,
|
||||
// and HTTPS. When adding new properties, update the [tlsConfigSettings.clone]
|
||||
// and [tlsConfigSettings.setPrivateFieldsAndCompare] methods as necessary.
|
||||
type tlsConfigSettings struct {
|
||||
// Enabled indicates whether encryption (DoT/DoH/HTTPS) is enabled.
|
||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||
|
||||
// ServerName is the hostname of the HTTPS/TLS server.
|
||||
ServerName string `yaml:"server_name" json:"server_name,omitempty"`
|
||||
|
||||
// ForceHTTPS, if true, forces an HTTP to HTTPS redirect.
|
||||
ForceHTTPS bool `yaml:"force_https" json:"force_https"`
|
||||
|
||||
// PortHTTPS is the HTTPS port. If 0, HTTPS will be disabled.
|
||||
PortHTTPS uint16 `yaml:"port_https" json:"port_https,omitempty"`
|
||||
|
||||
// PortDNSOverTLS is the DNS-over-TLS port. If 0, DoT will be disabled.
|
||||
PortDNSOverTLS uint16 `yaml:"port_dns_over_tls" json:"port_dns_over_tls,omitempty"`
|
||||
|
||||
// PortDNSOverQUIC is the DNS-over-QUIC port. If 0, DoQ will be disabled.
|
||||
PortDNSOverQUIC uint16 `yaml:"port_dns_over_quic" json:"port_dns_over_quic,omitempty"`
|
||||
|
||||
// PortDNSCrypt is the port for DNSCrypt requests. If it's zero, DNSCrypt
|
||||
// is disabled.
|
||||
PortDNSCrypt uint16 `yaml:"port_dnscrypt" json:"port_dnscrypt"`
|
||||
// DNSCryptConfigFile is the path to the DNSCrypt config file. Must be
|
||||
// set if PortDNSCrypt is not zero.
|
||||
|
||||
// DNSCryptConfigFile is the path to the DNSCrypt config file. Must be set
|
||||
// if PortDNSCrypt is not zero.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/dnsproxy and
|
||||
// https://github.com/ameshkov/dnscrypt.
|
||||
DNSCryptConfigFile string `yaml:"dnscrypt_config_file" json:"dnscrypt_config_file"`
|
||||
|
||||
// Allow DoH queries via unencrypted HTTP (e.g. for reverse proxying)
|
||||
// AllowUnencryptedDoH allows DoH queries via unencrypted HTTP (e.g. for
|
||||
// reverse proxying).
|
||||
//
|
||||
// TODO(s.chzhen): Add this option into the Web UI.
|
||||
AllowUnencryptedDoH bool `yaml:"allow_unencrypted_doh" json:"allow_unencrypted_doh"`
|
||||
|
||||
dnsforward.TLSConfig `yaml:",inline" json:",inline"`
|
||||
// CertificateChain is the PEM-encoded certificate chain. Must be empty if
|
||||
// [tlsConfigSettings.CertificatePath] is provided.
|
||||
CertificateChain string `yaml:"certificate_chain" json:"certificate_chain"`
|
||||
|
||||
// PrivateKey is the PEM-encoded private key. Must be empty if
|
||||
// [tlsConfigSettings.PrivateKeyPath] is provided.
|
||||
PrivateKey string `yaml:"private_key" json:"private_key"`
|
||||
|
||||
// CertificatePath is the path to the certificate file. Must be empty if
|
||||
// [tlsConfigSettings.CertificateChain] is provided.
|
||||
CertificatePath string `yaml:"certificate_path" json:"certificate_path"`
|
||||
|
||||
// PrivateKeyPath is the path to the private key file. Must be empty if
|
||||
// [tlsConfigSettings.PrivateKey] is provided.
|
||||
PrivateKeyPath string `yaml:"private_key_path" json:"private_key_path"`
|
||||
|
||||
// OverrideTLSCiphers, when set, contains the names of the cipher suites to
|
||||
// use. If the slice is empty, the default safe suites are used.
|
||||
OverrideTLSCiphers []string `yaml:"override_tls_ciphers,omitempty" json:"-"`
|
||||
|
||||
// CertificateChainData is the PEM-encoded byte data for the certificate
|
||||
// chain.
|
||||
CertificateChainData []byte `yaml:"-" json:"-"`
|
||||
|
||||
// PrivateKeyData is the PEM-encoded byte data for the private key.
|
||||
PrivateKeyData []byte `yaml:"-" json:"-"`
|
||||
|
||||
// StrictSNICheck controls if the connections with SNI mismatching the
|
||||
// certificate's ones should be rejected.
|
||||
StrictSNICheck bool `yaml:"strict_sni_check" json:"-"`
|
||||
}
|
||||
|
||||
// clone returns a deep copy of c.
|
||||
func (c *tlsConfigSettings) clone() (clone *tlsConfigSettings) {
|
||||
clone = &tlsConfigSettings{}
|
||||
*clone = *c
|
||||
|
||||
clone.OverrideTLSCiphers = slices.Clone(c.OverrideTLSCiphers)
|
||||
clone.CertificateChainData = slices.Clone(c.CertificateChainData)
|
||||
clone.PrivateKeyData = slices.Clone(c.PrivateKeyData)
|
||||
|
||||
return clone
|
||||
}
|
||||
|
||||
// setPrivateFieldsAndCompare sets any missing properties in conf to match those
|
||||
// in c and returns true if TLS configurations are equal. conf must not be be
|
||||
// nil.
|
||||
// It sets the following properties because these are not accepted from the
|
||||
// frontend:
|
||||
//
|
||||
// [tlsConfigSettings.AllowUnencryptedDoH]
|
||||
// [tlsConfigSettings.DNSCryptConfigFile]
|
||||
// [tlsConfigSettings.OverrideTLSCiphers]
|
||||
// [tlsConfigSettings.PortDNSCrypt]
|
||||
//
|
||||
// The following properties are skipped as they are set by
|
||||
// [tlsManager.loadTLSConfig]:
|
||||
//
|
||||
// [tlsConfigSettings.CertificateChainData]
|
||||
// [tlsConfigSettings.PrivateKeyData]
|
||||
func (c *tlsConfigSettings) setPrivateFieldsAndCompare(conf *tlsConfigSettings) (equal bool) {
|
||||
conf.OverrideTLSCiphers = slices.Clone(c.OverrideTLSCiphers)
|
||||
|
||||
// TODO(s.chzhen): Remove this once the frontend supports it.
|
||||
conf.AllowUnencryptedDoH = c.AllowUnencryptedDoH
|
||||
|
||||
conf.DNSCryptConfigFile = c.DNSCryptConfigFile
|
||||
conf.PortDNSCrypt = c.PortDNSCrypt
|
||||
|
||||
// TODO(a.garipov): Define a custom comparer.
|
||||
return cmp.Equal(c, conf)
|
||||
}
|
||||
|
||||
type queryLogConfig struct {
|
||||
@@ -380,6 +481,9 @@ var config = &configuration{
|
||||
UsePrivateRDNS: true,
|
||||
ServePlainDNS: true,
|
||||
HostsFileEnabled: true,
|
||||
PendingRequests: &pendingRequests{
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
TLS: tlsConfigSettings{
|
||||
PortHTTPS: defaultPortHTTPS,
|
||||
@@ -645,13 +749,13 @@ func (c *configuration) write(tlsMgr *tlsManager) (err error) {
|
||||
defer c.Unlock()
|
||||
|
||||
if globalContext.auth != nil {
|
||||
config.Users = globalContext.auth.usersList()
|
||||
// TODO(s.chzhen): Pass context.
|
||||
config.Users = globalContext.auth.usersList(context.TODO())
|
||||
}
|
||||
|
||||
if tlsMgr != nil {
|
||||
tlsConf := tlsConfigSettings{}
|
||||
tlsMgr.WriteDiskConfig(&tlsConf)
|
||||
config.TLS = tlsConf
|
||||
tlsConf := tlsMgr.config()
|
||||
config.TLS = *tlsConf
|
||||
}
|
||||
|
||||
if globalContext.stats != nil {
|
||||
|
||||
@@ -392,6 +392,8 @@ const PasswordMinRunes = 8
|
||||
|
||||
// Apply new configuration, start DNS server, restart Web server
|
||||
func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
req, restartHTTP, err := decodeApplyConfigReq(r.Body)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
@@ -439,7 +441,7 @@ func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request
|
||||
u := &webUser{
|
||||
Name: req.Username,
|
||||
}
|
||||
err = globalContext.auth.addUser(u, req.Password)
|
||||
err = globalContext.auth.addUser(ctx, u, req.Password)
|
||||
if err != nil {
|
||||
globalContext.firstRun = true
|
||||
copyInstallSettings(config, curConfig)
|
||||
@@ -452,7 +454,7 @@ func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request
|
||||
// moment we'll allow setting up TLS in the initial configuration or the
|
||||
// configuration itself will use HTTPS protocol, because the underlying
|
||||
// functions potentially restart the HTTPS server.
|
||||
err = startMods(r.Context(), web.baseLogger, web.tlsManager)
|
||||
err = startMods(ctx, web.baseLogger, web.tlsManager)
|
||||
if err != nil {
|
||||
globalContext.firstRun = true
|
||||
copyInstallSettings(config, curConfig)
|
||||
@@ -488,11 +490,11 @@ func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request
|
||||
// and with its own context, because it waits until all requests are handled
|
||||
// and will be blocked by it's own caller.
|
||||
go func(timeout time.Duration) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer slogutil.RecoverAndLog(ctx, web.logger)
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer slogutil.RecoverAndLog(shutdownCtx, web.logger)
|
||||
defer cancel()
|
||||
|
||||
shutdownSrv(ctx, web.logger, web.httpServer)
|
||||
shutdownSrv(shutdownCtx, web.logger, web.httpServer)
|
||||
}(shutdownTimeout)
|
||||
}
|
||||
|
||||
|
||||
@@ -164,11 +164,8 @@ func (vr *versionResponse) setAllowedToAutoUpdate(tlsMgr *tlsManager) (err error
|
||||
return nil
|
||||
}
|
||||
|
||||
tlsConf := &tlsConfigSettings{}
|
||||
tlsMgr.WriteDiskConfig(tlsConf)
|
||||
|
||||
canUpdate := true
|
||||
if tlsConfUsesPrivilegedPorts(tlsConf) ||
|
||||
if tlsConfUsesPrivilegedPorts(tlsMgr.config()) ||
|
||||
config.HTTPConfig.Address.Port() < 1024 ||
|
||||
config.DNS.Port < 1024 {
|
||||
canUpdate, err = aghnet.CanBindPrivilegedPorts()
|
||||
|
||||
@@ -2,6 +2,7 @@ package home
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
@@ -111,9 +112,6 @@ func initDNS(
|
||||
return err
|
||||
}
|
||||
|
||||
tlsConf := &tlsConfigSettings{}
|
||||
tlsMgr.WriteDiskConfig(tlsConf)
|
||||
|
||||
return initDNSServer(
|
||||
globalContext.filters,
|
||||
globalContext.stats,
|
||||
@@ -121,7 +119,7 @@ func initDNS(
|
||||
globalContext.dhcpServer,
|
||||
anonymizer,
|
||||
httpRegister,
|
||||
tlsConf,
|
||||
tlsMgr.config(),
|
||||
tlsMgr,
|
||||
baseLogger,
|
||||
)
|
||||
@@ -255,11 +253,16 @@ func newServerConfig(
|
||||
fwdConf := dnsConf.Config
|
||||
fwdConf.ClientsContainer = clientsContainer
|
||||
|
||||
intTLSConf, err := newDNSTLSConfig(tlsConf, hosts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("constructing tls config: %w", err)
|
||||
}
|
||||
|
||||
newConf = &dnsforward.ServerConfig{
|
||||
UDPListenAddrs: ipsToUDPAddrs(hosts, dnsConf.Port),
|
||||
TCPListenAddrs: ipsToTCPAddrs(hosts, dnsConf.Port),
|
||||
Config: fwdConf,
|
||||
TLSConfig: newDNSTLSConfig(tlsConf, hosts),
|
||||
TLSConf: intTLSConf,
|
||||
TLSAllowUnencryptedDoH: tlsConf.AllowUnencryptedDoH,
|
||||
UpstreamTimeout: time.Duration(dnsConf.UpstreamTimeout),
|
||||
TLSv12Roots: tlsMgr.rootCerts,
|
||||
@@ -272,6 +275,7 @@ func newServerConfig(
|
||||
ServeHTTP3: dnsConf.ServeHTTP3,
|
||||
UseHTTP3Upstreams: dnsConf.UseHTTP3Upstreams,
|
||||
ServePlainDNS: dnsConf.ServePlainDNS,
|
||||
PendingRequestsEnabled: dnsConf.PendingRequests.Enabled,
|
||||
}
|
||||
|
||||
var initialAddresses []netip.Addr
|
||||
@@ -304,14 +308,19 @@ func newServerConfig(
|
||||
}
|
||||
|
||||
// newDNSTLSConfig converts values from the configuration file into the internal
|
||||
// TLS settings for the DNS server. tlsConf must not be nil.
|
||||
func newDNSTLSConfig(conf *tlsConfigSettings, addrs []netip.Addr) (dnsConf dnsforward.TLSConfig) {
|
||||
// TLS settings for the DNS server. conf must not be nil.
|
||||
func newDNSTLSConfig(
|
||||
conf *tlsConfigSettings,
|
||||
addrs []netip.Addr,
|
||||
) (dnsConf *dnsforward.TLSConfig, err error) {
|
||||
if !conf.Enabled {
|
||||
return dnsforward.TLSConfig{}
|
||||
return &dnsforward.TLSConfig{}, nil
|
||||
}
|
||||
|
||||
dnsConf = conf.TLSConfig
|
||||
dnsConf.ServerName = conf.ServerName
|
||||
dnsConf = &dnsforward.TLSConfig{
|
||||
ServerName: conf.ServerName,
|
||||
StrictSNICheck: conf.StrictSNICheck,
|
||||
}
|
||||
|
||||
if conf.PortHTTPS != 0 {
|
||||
dnsConf.HTTPSListenAddrs = ipsToTCPAddrs(addrs, conf.PortHTTPS)
|
||||
@@ -325,7 +334,29 @@ func newDNSTLSConfig(conf *tlsConfigSettings, addrs []netip.Addr) (dnsConf dnsfo
|
||||
dnsConf.QUICListenAddrs = ipsToUDPAddrs(addrs, conf.PortDNSOverQUIC)
|
||||
}
|
||||
|
||||
return dnsConf
|
||||
cert, err := tls.X509KeyPair(conf.CertificateChainData, conf.PrivateKeyData)
|
||||
if err != nil {
|
||||
const format = "parsing tls key pair: %w"
|
||||
if conf.AllowUnencryptedDoH {
|
||||
// TODO(s.chzhen): Use [slog.Logger].
|
||||
log.Info("warning: %s: %s", format, err)
|
||||
|
||||
return dnsConf, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf(format, err)
|
||||
}
|
||||
|
||||
// Unencrypted DoH is managed by AdGuard Home itself, not by dnsproxy.
|
||||
// Therefore, avoid setting the certificate property to prevent dnsproxy
|
||||
// from starting encrypted listeners. See [dnsforward.Server.prepareTLS].
|
||||
if conf.AllowUnencryptedDoH {
|
||||
return dnsConf, nil
|
||||
}
|
||||
|
||||
dnsConf.Cert = &cert
|
||||
|
||||
return dnsConf, nil
|
||||
}
|
||||
|
||||
// newDNSCryptConfig converts values from the configuration file into the
|
||||
@@ -378,8 +409,7 @@ type dnsEncryption struct {
|
||||
// getDNSEncryption returns the TLS encryption addresses that AdGuard Home
|
||||
// listens on. tlsMgr must not be nil.
|
||||
func getDNSEncryption(tlsMgr *tlsManager) (de dnsEncryption) {
|
||||
tlsConf := tlsConfigSettings{}
|
||||
tlsMgr.WriteDiskConfig(&tlsConf)
|
||||
tlsConf := tlsMgr.config()
|
||||
|
||||
if !tlsConf.Enabled || len(tlsConf.ServerName) == 0 {
|
||||
return dnsEncryption{}
|
||||
|
||||
@@ -668,7 +668,7 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}, sigHdlr *signalH
|
||||
GLMode = opts.glinetMode
|
||||
|
||||
// Init auth module.
|
||||
globalContext.auth, err = initUsers()
|
||||
globalContext.auth, err = initUsers(ctx, slogLogger)
|
||||
fatalOnError(err)
|
||||
|
||||
web, err := initWeb(ctx, opts, clientBuildFS, upd, slogLogger, tlsMgr, customURL)
|
||||
@@ -786,7 +786,8 @@ func checkPermissions(
|
||||
}
|
||||
|
||||
// initUsers initializes context auth module. Clears config users field.
|
||||
func initUsers() (auth *Auth, err error) {
|
||||
// baseLogger must not be nil.
|
||||
func initUsers(ctx context.Context, baseLogger *slog.Logger) (auth *Auth, err error) {
|
||||
sessFilename := filepath.Join(globalContext.getDataDir(), "sessions.db")
|
||||
|
||||
var rateLimiter *authRateLimiter
|
||||
@@ -799,10 +800,17 @@ func initUsers() (auth *Auth, err error) {
|
||||
|
||||
trustedProxies := netutil.SliceSubnetSet(netutil.UnembedPrefixes(config.DNS.TrustedProxies))
|
||||
|
||||
sessionTTL := time.Duration(config.HTTPConfig.SessionTTL).Seconds()
|
||||
auth = InitAuth(sessFilename, config.Users, uint32(sessionTTL), rateLimiter, trustedProxies)
|
||||
if auth == nil {
|
||||
return nil, errors.Error("initializing auth module failed")
|
||||
auth, err = InitAuth(
|
||||
ctx,
|
||||
baseLogger,
|
||||
sessFilename,
|
||||
config.Users,
|
||||
time.Duration(config.HTTPConfig.SessionTTL),
|
||||
rateLimiter,
|
||||
trustedProxies,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("initializing auth module: %w", err)
|
||||
}
|
||||
|
||||
config.Users = nil
|
||||
@@ -916,7 +924,7 @@ func cleanup(ctx context.Context) {
|
||||
globalContext.web = nil
|
||||
}
|
||||
if globalContext.auth != nil {
|
||||
globalContext.auth.Close()
|
||||
globalContext.auth.Close(ctx)
|
||||
globalContext.auth = nil
|
||||
}
|
||||
|
||||
@@ -991,9 +999,9 @@ func printWebAddrs(proto, addr string, port uint16) {
|
||||
//
|
||||
// TODO(s.chzhen): Implement separate functions for HTTP and HTTPS.
|
||||
func printHTTPAddresses(proto string, tlsMgr *tlsManager) {
|
||||
tlsConf := tlsConfigSettings{}
|
||||
var tlsConf *tlsConfigSettings
|
||||
if tlsMgr != nil {
|
||||
tlsMgr.WriteDiskConfig(&tlsConf)
|
||||
tlsConf = tlsMgr.config()
|
||||
}
|
||||
|
||||
port := config.HTTPConfig.Address.Port()
|
||||
|
||||
@@ -47,7 +47,11 @@ type profileJSON struct {
|
||||
|
||||
// handleGetProfile is the handler for GET /control/profile endpoint.
|
||||
func handleGetProfile(w http.ResponseWriter, r *http.Request) {
|
||||
name := ""
|
||||
u := globalContext.auth.getCurrentUser(r)
|
||||
if u != nil {
|
||||
name = string(u.Login)
|
||||
}
|
||||
|
||||
var resp profileJSON
|
||||
func() {
|
||||
@@ -55,7 +59,7 @@ func handleGetProfile(w http.ResponseWriter, r *http.Request) {
|
||||
defer config.RUnlock()
|
||||
|
||||
resp = profileJSON{
|
||||
Name: u.Name,
|
||||
Name: name,
|
||||
Language: config.Language,
|
||||
Theme: config.Theme,
|
||||
}
|
||||
|
||||
@@ -24,11 +24,9 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/c2h5oh/datasize"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
// tlsManager contains the current configuration and state of AdGuard Home TLS
|
||||
@@ -37,6 +35,9 @@ type tlsManager struct {
|
||||
// logger is used for logging the operation of the TLS Manager.
|
||||
logger *slog.Logger
|
||||
|
||||
// mu protects status, certLastMod, conf, and servePlainDNS.
|
||||
mu *sync.Mutex
|
||||
|
||||
// status is the current status of the configuration. It is never nil.
|
||||
status *tlsConfigStatus
|
||||
|
||||
@@ -52,6 +53,9 @@ type tlsManager struct {
|
||||
// Resolve it.
|
||||
web *webAPI
|
||||
|
||||
// conf contains the TLS configuration settings. It must not be nil.
|
||||
conf *tlsConfigSettings
|
||||
|
||||
// configModified is called when the TLS configuration is changed via an
|
||||
// HTTP request.
|
||||
configModified func()
|
||||
@@ -59,9 +63,6 @@ type tlsManager struct {
|
||||
// customCipherIDs are the ID of the cipher suites that AdGuard Home must use.
|
||||
customCipherIDs []uint16
|
||||
|
||||
confLock sync.Mutex
|
||||
conf tlsConfigSettings
|
||||
|
||||
// servePlainDNS defines if plain DNS is allowed for incoming requests.
|
||||
servePlainDNS bool
|
||||
}
|
||||
@@ -91,9 +92,10 @@ type tlsManagerConfig struct {
|
||||
func newTLSManager(ctx context.Context, conf *tlsManagerConfig) (m *tlsManager, err error) {
|
||||
m = &tlsManager{
|
||||
logger: conf.logger,
|
||||
mu: &sync.Mutex{},
|
||||
configModified: conf.configModified,
|
||||
status: &tlsConfigStatus{},
|
||||
conf: conf.tlsSettings,
|
||||
conf: &conf.tlsSettings,
|
||||
servePlainDNS: conf.servePlainDNS,
|
||||
}
|
||||
|
||||
@@ -112,17 +114,22 @@ func newTLSManager(ctx context.Context, conf *tlsManagerConfig) (m *tlsManager,
|
||||
m.logger.InfoContext(ctx, "using default ciphers")
|
||||
}
|
||||
|
||||
if m.conf.Enabled {
|
||||
err = m.load(ctx)
|
||||
if err != nil {
|
||||
m.conf.Enabled = false
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
return m, err
|
||||
}
|
||||
|
||||
m.setCertFileTime(ctx)
|
||||
if !m.conf.Enabled {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
err = m.load(ctx)
|
||||
if err != nil {
|
||||
m.conf.Enabled = false
|
||||
|
||||
return m, err
|
||||
}
|
||||
|
||||
m.setCertFileTime(ctx)
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
@@ -136,8 +143,9 @@ func (m *tlsManager) setWebAPI(webAPI *webAPI) {
|
||||
}
|
||||
|
||||
// load reloads the TLS configuration from files or data from the config file.
|
||||
// m.mu is expected to be locked.
|
||||
func (m *tlsManager) load(ctx context.Context) (err error) {
|
||||
err = m.loadTLSConf(ctx, &m.conf, m.status)
|
||||
err = m.loadTLSConfig(ctx, m.conf, m.status)
|
||||
if err != nil {
|
||||
return fmt.Errorf("loading config: %w", err)
|
||||
}
|
||||
@@ -145,15 +153,16 @@ func (m *tlsManager) load(ctx context.Context) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteDiskConfig - write config
|
||||
func (m *tlsManager) WriteDiskConfig(conf *tlsConfigSettings) {
|
||||
m.confLock.Lock()
|
||||
*conf = m.conf
|
||||
m.confLock.Unlock()
|
||||
// config returns a deep copy of the stored TLS configuration.
|
||||
func (m *tlsManager) config() (conf *tlsConfigSettings) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
return m.conf.clone()
|
||||
}
|
||||
|
||||
// setCertFileTime sets [tlsManager.certLastMod] from the certificate. If there
|
||||
// are errors, setCertFileTime logs them.
|
||||
// are errors, setCertFileTime logs them. m.mu is expected to be locked.
|
||||
func (m *tlsManager) setCertFileTime(ctx context.Context) {
|
||||
if len(m.conf.CertificatePath) == 0 {
|
||||
return
|
||||
@@ -175,21 +184,21 @@ func (m *tlsManager) setCertFileTime(ctx context.Context) {
|
||||
func (m *tlsManager) start(_ context.Context) {
|
||||
m.registerWebHandlers()
|
||||
|
||||
m.confLock.Lock()
|
||||
tlsConf := m.conf
|
||||
m.confLock.Unlock()
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// The background context is used because the TLSConfigChanged wraps context
|
||||
// with timeout on its own and shuts down the server, which handles current
|
||||
// request.
|
||||
m.web.tlsConfigChanged(context.Background(), tlsConf)
|
||||
m.web.tlsConfigChanged(context.Background(), m.conf)
|
||||
}
|
||||
|
||||
// reload updates the configuration and restarts the TLS manager.
|
||||
func (m *tlsManager) reload(ctx context.Context) {
|
||||
m.confLock.Lock()
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
tlsConf := m.conf
|
||||
m.confLock.Unlock()
|
||||
|
||||
if !tlsConf.Enabled || len(tlsConf.CertificatePath) == 0 {
|
||||
return
|
||||
@@ -211,9 +220,7 @@ func (m *tlsManager) reload(ctx context.Context) {
|
||||
|
||||
m.logger.InfoContext(ctx, "certificate file is modified")
|
||||
|
||||
m.confLock.Lock()
|
||||
err = m.load(ctx)
|
||||
m.confLock.Unlock()
|
||||
if err != nil {
|
||||
m.logger.ErrorContext(ctx, "reloading", slogutil.KeyError, err)
|
||||
|
||||
@@ -227,10 +234,6 @@ func (m *tlsManager) reload(ctx context.Context) {
|
||||
m.logger.ErrorContext(ctx, "reconfiguring dns server", slogutil.KeyError, err)
|
||||
}
|
||||
|
||||
m.confLock.Lock()
|
||||
tlsConf = m.conf
|
||||
m.confLock.Unlock()
|
||||
|
||||
// The background context is used because the TLSConfigChanged wraps context
|
||||
// with timeout on its own and shuts down the server, which handles current
|
||||
// request.
|
||||
@@ -238,15 +241,12 @@ func (m *tlsManager) reload(ctx context.Context) {
|
||||
}
|
||||
|
||||
// reconfigureDNSServer updates the DNS server configuration using the stored
|
||||
// TLS settings.
|
||||
// TLS settings. m.mu is expected to be locked.
|
||||
func (m *tlsManager) reconfigureDNSServer() (err error) {
|
||||
tlsConf := &tlsConfigSettings{}
|
||||
m.WriteDiskConfig(tlsConf)
|
||||
|
||||
newConf, err := newServerConfig(
|
||||
&config.DNS,
|
||||
config.Clients.Sources,
|
||||
tlsConf,
|
||||
m.conf,
|
||||
m,
|
||||
httpRegister,
|
||||
globalContext.clients.storage,
|
||||
@@ -263,9 +263,11 @@ func (m *tlsManager) reconfigureDNSServer() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadTLSConf loads and validates the TLS configuration. The returned error is
|
||||
// also set in status.WarningValidation.
|
||||
func (m *tlsManager) loadTLSConf(
|
||||
// loadTLSConfig loads and validates the TLS configuration. It also sets
|
||||
// [tlsConfigSettings.CertificateChainData] and
|
||||
// [tlsConfigSettings.PrivateKeyData] properties. The returned error is also
|
||||
// set in status.WarningValidation.
|
||||
func (m *tlsManager) loadTLSConfig(
|
||||
ctx context.Context,
|
||||
tlsConf *tlsConfigSettings,
|
||||
status *tlsConfigStatus,
|
||||
@@ -357,10 +359,10 @@ type tlsConfigStatus struct {
|
||||
KeyType string `json:"key_type,omitempty"`
|
||||
|
||||
// NotBefore is the NotBefore field of the first certificate in the chain.
|
||||
NotBefore time.Time `json:"not_before,omitempty"`
|
||||
NotBefore time.Time `json:"not_before"`
|
||||
|
||||
// NotAfter is the NotAfter field of the first certificate in the chain.
|
||||
NotAfter time.Time `json:"not_after,omitempty"`
|
||||
NotAfter time.Time `json:"not_after"`
|
||||
|
||||
// WarningValidation is a validation warning message with the issue
|
||||
// description.
|
||||
@@ -410,15 +412,23 @@ type tlsConfigSettingsExt struct {
|
||||
|
||||
// handleTLSStatus is the handler for the GET /control/tls/status HTTP API.
|
||||
func (m *tlsManager) handleTLSStatus(w http.ResponseWriter, r *http.Request) {
|
||||
m.confLock.Lock()
|
||||
var tlsConf *tlsConfigSettings
|
||||
var servePlainDNS bool
|
||||
func() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
tlsConf = m.conf.clone()
|
||||
servePlainDNS = m.servePlainDNS
|
||||
}()
|
||||
|
||||
data := tlsConfig{
|
||||
tlsConfigSettingsExt: tlsConfigSettingsExt{
|
||||
tlsConfigSettings: m.conf,
|
||||
ServePlainDNS: aghalg.BoolToNullBool(m.servePlainDNS),
|
||||
tlsConfigSettings: *tlsConf,
|
||||
ServePlainDNS: aghalg.BoolToNullBool(servePlainDNS),
|
||||
},
|
||||
tlsConfigStatus: m.status,
|
||||
}
|
||||
m.confLock.Unlock()
|
||||
|
||||
marshalTLS(w, r, data)
|
||||
}
|
||||
@@ -434,6 +444,9 @@ func (m *tlsManager) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if setts.PrivateKeySaved {
|
||||
setts.PrivateKey = m.conf.PrivateKey
|
||||
}
|
||||
@@ -449,7 +462,7 @@ func (m *tlsManager) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
|
||||
// Skip the error check, since we are only interested in the value of
|
||||
// status.WarningValidation.
|
||||
status := &tlsConfigStatus{}
|
||||
_ = m.loadTLSConf(ctx, &setts.tlsConfigSettings, status)
|
||||
_ = m.loadTLSConfig(ctx, &setts.tlsConfigSettings, status)
|
||||
resp := tlsConfig{
|
||||
tlsConfigSettingsExt: setts,
|
||||
tlsConfigStatus: status,
|
||||
@@ -458,42 +471,23 @@ func (m *tlsManager) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
|
||||
marshalTLS(w, r, resp)
|
||||
}
|
||||
|
||||
// setConfig updates manager conf with the given one.
|
||||
// setConfig updates manager TLS configuration with the given one. m.mu is
|
||||
// expected to be locked.
|
||||
func (m *tlsManager) setConfig(
|
||||
ctx context.Context,
|
||||
newConf tlsConfigSettings,
|
||||
status *tlsConfigStatus,
|
||||
servePlain aghalg.NullBool,
|
||||
) (restartHTTPS bool) {
|
||||
m.confLock.Lock()
|
||||
defer m.confLock.Unlock()
|
||||
|
||||
// Reset the DNSCrypt data before comparing, since we currently do not
|
||||
// accept these from the frontend.
|
||||
//
|
||||
// TODO(a.garipov): Define a custom comparer for dnsforward.TLSConfig.
|
||||
newConf.DNSCryptConfigFile = m.conf.DNSCryptConfigFile
|
||||
newConf.PortDNSCrypt = m.conf.PortDNSCrypt
|
||||
if !cmp.Equal(m.conf, newConf, cmp.AllowUnexported(dnsforward.TLSConfig{})) {
|
||||
if !m.conf.setPrivateFieldsAndCompare(&newConf) {
|
||||
m.logger.InfoContext(ctx, "config has changed, restarting https server")
|
||||
restartHTTPS = true
|
||||
} else {
|
||||
m.logger.InfoContext(ctx, "config has not changed")
|
||||
}
|
||||
|
||||
// Note: don't do just `t.conf = data` because we must preserve all other members of t.conf
|
||||
m.conf.Enabled = newConf.Enabled
|
||||
m.conf.ServerName = newConf.ServerName
|
||||
m.conf.ForceHTTPS = newConf.ForceHTTPS
|
||||
m.conf.PortHTTPS = newConf.PortHTTPS
|
||||
m.conf.PortDNSOverTLS = newConf.PortDNSOverTLS
|
||||
m.conf.PortDNSOverQUIC = newConf.PortDNSOverQUIC
|
||||
m.conf.CertificateChain = newConf.CertificateChain
|
||||
m.conf.CertificatePath = newConf.CertificatePath
|
||||
m.conf.CertificateChainData = newConf.CertificateChainData
|
||||
m.conf.PrivateKey = newConf.PrivateKey
|
||||
m.conf.PrivateKeyPath = newConf.PrivateKeyPath
|
||||
m.conf.PrivateKeyData = newConf.PrivateKeyData
|
||||
m.conf = &newConf
|
||||
|
||||
m.status = status
|
||||
|
||||
if servePlain != aghalg.NBNull {
|
||||
@@ -515,6 +509,16 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request)
|
||||
return
|
||||
}
|
||||
|
||||
var restartHTTPS bool
|
||||
defer func() {
|
||||
if restartHTTPS {
|
||||
m.configModified()
|
||||
}
|
||||
}()
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if req.PrivateKeySaved {
|
||||
req.PrivateKey = m.conf.PrivateKey
|
||||
}
|
||||
@@ -526,7 +530,7 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
status := &tlsConfigStatus{}
|
||||
err = m.loadTLSConf(ctx, &req.tlsConfigSettings, status)
|
||||
err = m.loadTLSConfig(ctx, &req.tlsConfigSettings, status)
|
||||
if err != nil {
|
||||
resp := tlsConfig{
|
||||
tlsConfigSettingsExt: req,
|
||||
@@ -538,20 +542,18 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request)
|
||||
return
|
||||
}
|
||||
|
||||
restartHTTPS := m.setConfig(ctx, req.tlsConfigSettings, status, req.ServePlainDNS)
|
||||
restartHTTPS = m.setConfig(ctx, req.tlsConfigSettings, status, req.ServePlainDNS)
|
||||
m.setCertFileTime(ctx)
|
||||
|
||||
if req.ServePlainDNS != aghalg.NBNull {
|
||||
func() {
|
||||
m.confLock.Lock()
|
||||
defer m.confLock.Unlock()
|
||||
config.Lock()
|
||||
defer config.Unlock()
|
||||
|
||||
config.DNS.ServePlainDNS = req.ServePlainDNS == aghalg.NBTrue
|
||||
}()
|
||||
}
|
||||
|
||||
m.configModified()
|
||||
|
||||
err = m.reconfigureDNSServer()
|
||||
if err != nil {
|
||||
m.logger.ErrorContext(ctx, "reconfiguring dns server", slogutil.KeyError, err)
|
||||
@@ -567,18 +569,18 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
marshalTLS(w, r, resp)
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
rc := http.NewResponseController(w)
|
||||
err = rc.Flush()
|
||||
if err != nil {
|
||||
m.logger.ErrorContext(ctx, "flushing response", slogutil.KeyError, err)
|
||||
}
|
||||
|
||||
// The background context is used because the TLSConfigChanged wraps context
|
||||
// with timeout on its own and shuts down the server, which handles current
|
||||
// request. It is also should be done in a separate goroutine due to the
|
||||
// request. It is also should be done in a separate goroutine due to the
|
||||
// same reason.
|
||||
if restartHTTPS {
|
||||
go func() {
|
||||
m.web.tlsConfigChanged(context.Background(), req.tlsConfigSettings)
|
||||
}()
|
||||
go m.web.tlsConfigChanged(context.Background(), &req.tlsConfigSettings)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -239,11 +239,9 @@ func TestTLSManager_Reload(t *testing.T) {
|
||||
logger: logger,
|
||||
configModified: func() {},
|
||||
tlsSettings: tlsConfigSettings{
|
||||
Enabled: true,
|
||||
TLSConfig: dnsforward.TLSConfig{
|
||||
CertificatePath: certPath,
|
||||
PrivateKeyPath: keyPath,
|
||||
},
|
||||
Enabled: true,
|
||||
CertificatePath: certPath,
|
||||
PrivateKeyPath: keyPath,
|
||||
},
|
||||
servePlainDNS: false,
|
||||
})
|
||||
@@ -254,8 +252,7 @@ func TestTLSManager_Reload(t *testing.T) {
|
||||
|
||||
m.setWebAPI(web)
|
||||
|
||||
conf := &tlsConfigSettings{}
|
||||
m.WriteDiskConfig(conf)
|
||||
conf := m.config()
|
||||
assertCertSerialNumber(t, conf, snBefore)
|
||||
|
||||
certDER, key = newCertAndKey(t, snAfter)
|
||||
@@ -263,7 +260,7 @@ func TestTLSManager_Reload(t *testing.T) {
|
||||
|
||||
m.reload(ctx)
|
||||
|
||||
m.WriteDiskConfig(conf)
|
||||
conf = m.config()
|
||||
assertCertSerialNumber(t, conf, snAfter)
|
||||
}
|
||||
|
||||
@@ -278,11 +275,9 @@ func TestTLSManager_HandleTLSStatus(t *testing.T) {
|
||||
logger: logger,
|
||||
configModified: func() {},
|
||||
tlsSettings: tlsConfigSettings{
|
||||
Enabled: true,
|
||||
TLSConfig: dnsforward.TLSConfig{
|
||||
CertificateChain: string(testCertChainData),
|
||||
PrivateKey: string(testPrivateKeyData),
|
||||
},
|
||||
Enabled: true,
|
||||
CertificateChain: string(testCertChainData),
|
||||
PrivateKey: string(testPrivateKeyData),
|
||||
},
|
||||
servePlainDNS: false,
|
||||
})
|
||||
@@ -342,47 +337,49 @@ func TestValidateTLSSettings(t *testing.T) {
|
||||
busyUDPPort := udpAddr.Port
|
||||
|
||||
testCases := []struct {
|
||||
setts tlsConfigSettingsExt
|
||||
name string
|
||||
wantErr string
|
||||
setts tlsConfigSettingsExt
|
||||
}{{
|
||||
name: "basic",
|
||||
setts: tlsConfigSettingsExt{},
|
||||
wantErr: "",
|
||||
setts: tlsConfigSettingsExt{},
|
||||
}, {
|
||||
name: "disabled_all",
|
||||
wantErr: "plain DNS is required in case encryption protocols are disabled",
|
||||
setts: tlsConfigSettingsExt{
|
||||
ServePlainDNS: aghalg.NBFalse,
|
||||
},
|
||||
name: "disabled_all",
|
||||
wantErr: "plain DNS is required in case encryption protocols are disabled",
|
||||
}, {
|
||||
name: "busy_https_port",
|
||||
wantErr: fmt.Sprintf("port %d for HTTPS is not available", busyTCPPort),
|
||||
setts: tlsConfigSettingsExt{
|
||||
tlsConfigSettings: tlsConfigSettings{
|
||||
Enabled: true,
|
||||
PortHTTPS: uint16(busyTCPPort),
|
||||
},
|
||||
},
|
||||
name: "busy_https_port",
|
||||
wantErr: fmt.Sprintf("port %d for HTTPS is not available", busyTCPPort),
|
||||
}, {
|
||||
name: "busy_dot_port",
|
||||
wantErr: fmt.Sprintf("port %d for DNS-over-TLS is not available", busyTCPPort),
|
||||
setts: tlsConfigSettingsExt{
|
||||
tlsConfigSettings: tlsConfigSettings{
|
||||
Enabled: true,
|
||||
PortDNSOverTLS: uint16(busyTCPPort),
|
||||
},
|
||||
},
|
||||
name: "busy_dot_port",
|
||||
wantErr: fmt.Sprintf("port %d for DNS-over-TLS is not available", busyTCPPort),
|
||||
}, {
|
||||
name: "busy_doq_port",
|
||||
wantErr: fmt.Sprintf("port %d for DNS-over-QUIC is not available", busyUDPPort),
|
||||
setts: tlsConfigSettingsExt{
|
||||
tlsConfigSettings: tlsConfigSettings{
|
||||
Enabled: true,
|
||||
PortDNSOverQUIC: uint16(busyUDPPort),
|
||||
},
|
||||
},
|
||||
name: "busy_doq_port",
|
||||
wantErr: fmt.Sprintf("port %d for DNS-over-QUIC is not available", busyUDPPort),
|
||||
}, {
|
||||
name: "duplicate_port",
|
||||
wantErr: "validating tcp ports: duplicated values: [4433]",
|
||||
setts: tlsConfigSettingsExt{
|
||||
tlsConfigSettings: tlsConfigSettings{
|
||||
Enabled: true,
|
||||
@@ -390,8 +387,6 @@ func TestValidateTLSSettings(t *testing.T) {
|
||||
PortDNSOverTLS: 4433,
|
||||
},
|
||||
},
|
||||
name: "duplicate_port",
|
||||
wantErr: "validating tcp ports: duplicated values: [4433]",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
@@ -417,11 +412,9 @@ func TestTLSManager_HandleTLSValidate(t *testing.T) {
|
||||
logger: logger,
|
||||
configModified: func() {},
|
||||
tlsSettings: tlsConfigSettings{
|
||||
Enabled: true,
|
||||
TLSConfig: dnsforward.TLSConfig{
|
||||
CertificateChain: string(testCertChainData),
|
||||
PrivateKey: string(testPrivateKeyData),
|
||||
},
|
||||
Enabled: true,
|
||||
CertificateChain: string(testCertChainData),
|
||||
PrivateKey: string(testPrivateKeyData),
|
||||
},
|
||||
servePlainDNS: false,
|
||||
})
|
||||
@@ -434,11 +427,9 @@ func TestTLSManager_HandleTLSValidate(t *testing.T) {
|
||||
|
||||
setts := &tlsConfigSettingsExt{
|
||||
tlsConfigSettings: tlsConfigSettings{
|
||||
Enabled: true,
|
||||
TLSConfig: dnsforward.TLSConfig{
|
||||
CertificateChain: base64.StdEncoding.EncodeToString(testCertChainData),
|
||||
PrivateKey: base64.StdEncoding.EncodeToString(testPrivateKeyData),
|
||||
},
|
||||
Enabled: true,
|
||||
CertificateChain: base64.StdEncoding.EncodeToString(testCertChainData),
|
||||
PrivateKey: base64.StdEncoding.EncodeToString(testPrivateKeyData),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -476,6 +467,7 @@ func TestTLSManager_HandleTLSConfigure(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
err = globalContext.dnsServer.Prepare(&dnsforward.ServerConfig{
|
||||
TLSConf: &dnsforward.TLSConfig{},
|
||||
Config: dnsforward.Config{
|
||||
UpstreamMode: dnsforward.UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &dnsforward.EDNSClientSubnet{Enabled: false},
|
||||
@@ -511,11 +503,9 @@ func TestTLSManager_HandleTLSConfigure(t *testing.T) {
|
||||
logger: logger,
|
||||
configModified: func() {},
|
||||
tlsSettings: tlsConfigSettings{
|
||||
Enabled: true,
|
||||
TLSConfig: dnsforward.TLSConfig{
|
||||
CertificatePath: certPath,
|
||||
PrivateKeyPath: keyPath,
|
||||
},
|
||||
Enabled: true,
|
||||
CertificatePath: certPath,
|
||||
PrivateKeyPath: keyPath,
|
||||
},
|
||||
servePlainDNS: true,
|
||||
})
|
||||
@@ -526,19 +516,16 @@ func TestTLSManager_HandleTLSConfigure(t *testing.T) {
|
||||
|
||||
m.setWebAPI(web)
|
||||
|
||||
conf := &tlsConfigSettings{}
|
||||
m.WriteDiskConfig(conf)
|
||||
conf := m.config()
|
||||
assertCertSerialNumber(t, conf, wantSerialNumber)
|
||||
|
||||
// Prepare a request with the new TLS configuration.
|
||||
setts := &tlsConfigSettingsExt{
|
||||
tlsConfigSettings: tlsConfigSettings{
|
||||
Enabled: true,
|
||||
PortHTTPS: 4433,
|
||||
TLSConfig: dnsforward.TLSConfig{
|
||||
CertificateChain: base64.StdEncoding.EncodeToString(testCertChainData),
|
||||
PrivateKey: base64.StdEncoding.EncodeToString(testPrivateKeyData),
|
||||
},
|
||||
Enabled: true,
|
||||
PortHTTPS: 4433,
|
||||
CertificateChain: base64.StdEncoding.EncodeToString(testCertChainData),
|
||||
PrivateKey: base64.StdEncoding.EncodeToString(testPrivateKeyData),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -157,8 +157,8 @@ func newWebAPI(ctx context.Context, conf *webConfig) (w *webAPI) {
|
||||
}
|
||||
|
||||
// tlsConfigChanged updates the TLS configuration and restarts the HTTPS server
|
||||
// if necessary.
|
||||
func (web *webAPI) tlsConfigChanged(ctx context.Context, tlsConf tlsConfigSettings) {
|
||||
// if necessary. tlsConf must not be nil.
|
||||
func (web *webAPI) tlsConfigChanged(ctx context.Context, tlsConf *tlsConfigSettings) {
|
||||
defer slogutil.RecoverAndExit(ctx, web.logger, osutil.ExitCodeFailure)
|
||||
|
||||
web.logger.DebugContext(ctx, "applying new tls configuration")
|
||||
|
||||
@@ -64,7 +64,7 @@ type Entry struct {
|
||||
Domain string
|
||||
|
||||
// UpstreamStats contains the DNS query statistics for both the upstream and
|
||||
// fallback DNS servers.
|
||||
// fallback DNS servers. Don't modify items in the slice.
|
||||
UpstreamStats []*proxy.UpstreamStatistics
|
||||
|
||||
// Result is the result of processing the request.
|
||||
|
||||
@@ -119,4 +119,5 @@ $sudo_cmd docker "$debug_flags" \
|
||||
--build-arg VERSION="$version" \
|
||||
--output "$docker_output" \
|
||||
--platform "$docker_platforms" \
|
||||
--progress 'plain' \
|
||||
$docker_version_tag $docker_channel_tag -f ./docker/Dockerfile .
|
||||
|
||||
@@ -199,6 +199,7 @@ run_linter gocognit --over='10' \
|
||||
./internal/aghhttp/ \
|
||||
./internal/aghrenameio/ \
|
||||
./internal/aghtest/ \
|
||||
./internal/aghuser/ \
|
||||
./internal/arpdb/ \
|
||||
./internal/client/ \
|
||||
./internal/configmigrate/ \
|
||||
@@ -250,6 +251,7 @@ run_linter fieldalignment \
|
||||
./internal/aghrenameio/ \
|
||||
./internal/aghtest/ \
|
||||
./internal/aghtls/ \
|
||||
./internal/aghuser/ \
|
||||
./internal/arpdb/ \
|
||||
./internal/client/ \
|
||||
./internal/configmigrate/ \
|
||||
@@ -280,6 +282,7 @@ run_linter gosec --exclude G115 --quiet \
|
||||
./internal/aghos/ \
|
||||
./internal/aghrenameio/ \
|
||||
./internal/aghtest/ \
|
||||
./internal/aghuser/ \
|
||||
./internal/arpdb/ \
|
||||
./internal/client/ \
|
||||
./internal/configmigrate/ \
|
||||
|
||||
Reference in New Issue
Block a user