Compare commits

..

3 Commits

Author SHA1 Message Date
Ainar Garipov
ce9bb588ed all: sync with master 2024-03-19 16:20:32 +03:00
Ainar Garipov
55fb914537 all: sync rc fix with master 2024-03-13 16:25:51 +03:00
Ainar Garipov
6f7bfd6c9c all: sync with master 2024-03-12 18:15:58 +03:00
114 changed files with 2980 additions and 1340 deletions

View File

@@ -14,15 +14,25 @@ and this project adheres to
<!--
## [v0.108.0] - TBA
## [v0.107.46] - 2024-03-13 (APPROX.)
## [v0.107.47] - 2024-04-03 (APPROX.)
See also the [v0.107.46 GitHub milestone][ms-v0.107.46].
See also the [v0.107.47 GitHub milestone][ms-v0.107.47].
[ms-v0.107.46]: https://github.com/AdguardTeam/AdGuardHome/milestone/81?closed=1
[ms-v0.107.47]: https://github.com/AdguardTeam/AdGuardHome/milestone/82?closed=1
NOTE: Add new changes BELOW THIS COMMENT.
-->
<!--
NOTE: Add new changes ABOVE THIS COMMENT.
-->
## [v0.107.46] - 2024-03-20
See also the [v0.107.46 GitHub milestone][ms-v0.107.46].
### Added
- Ability to disable the use of system hosts file information for query
@@ -30,17 +40,29 @@ NOTE: Add new changes BELOW THIS COMMENT.
- Ability to define custom directories for storage of query log files and
statistics ([#5992]).
### Changed
- Private RDNS resolution (`dns.use_private_ptr_resolvers` in YAML
configuration) now requires a valid "Private reverse DNS servers", when
enabled ([#6820]).
**NOTE:** Disabling private RDNS resolution behaves effectively the same as if
no private reverse DNS servers provided by user and by the OS.
### Fixed
- Statistics for 7 days displayed by day on the dashboard graph ([#6712]).
- Missing "served from cache" label on long DNS server strings ([#6740]).
- Incorrect tracking of the system hosts file's changes ([#6711]).
[#5992]: https://github.com/AdguardTeam/AdGuardHome/issues/5992
[#6610]: https://github.com/AdguardTeam/AdGuardHome/issues/6610
[#6711]: https://github.com/AdguardTeam/AdGuardHome/issues/6711
[#6712]: https://github.com/AdguardTeam/AdGuardHome/issues/6712
[#6740]: https://github.com/AdguardTeam/AdGuardHome/issues/6740
[#6820]: https://github.com/AdguardTeam/AdGuardHome/issues/6820
<!--
NOTE: Add new changes ABOVE THIS COMMENT.
-->
[ms-v0.107.46]: https://github.com/AdguardTeam/AdGuardHome/milestone/81?closed=1
@@ -2831,11 +2853,12 @@ See also the [v0.104.2 GitHub milestone][ms-v0.104.2].
<!--
[Unreleased]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.46...HEAD
[v0.107.46]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.45...v0.107.46
[Unreleased]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.47...HEAD
[v0.107.47]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.46...v0.107.46
-->
[Unreleased]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.45...HEAD
[Unreleased]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.46...HEAD
[v0.107.46]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.45...v0.107.46
[v0.107.45]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.44...v0.107.45
[v0.107.44]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.43...v0.107.44
[v0.107.43]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.42...v0.107.43

View File

@@ -473,6 +473,9 @@ bug or implementing the feature.
[@kongfl888](https://github.com/kongfl888) (originally by
[@rufengsuixing](https://github.com/rufengsuixing)).
* [AdGuardHome sync](https://github.com/bakito/adguardhome-sync) by
[@bakito](https://github.com/bakito).
* [Terminal-based, real-time traffic monitoring and statistics for your AdGuard Home
instance](https://github.com/Lissy93/AdGuardian-Term) by
[@Lissy93](https://github.com/Lissy93)

View File

@@ -40,6 +40,8 @@
'jobs':
- 'Publish to GitHub Releases'
# TODO(e.burkov): In jobs below find out why the explicit checkout is
# performed.
'Build frontend':
'docker':
'image': '${bamboo.dockerGo}'

View File

@@ -68,9 +68,6 @@
set -e -f -u -x
# Explicitly checkout the revision that we need.
git checkout "${bamboo.repository.revision.number}"
make\
ARCH="amd64"\
OS="windows darwin linux"\
@@ -122,8 +119,6 @@
# from the release branch and are used to build the release candidate
# images.
- '^rc-v[0-9]+\.[0-9]+\.[0-9]+':
# Build betas on release branches manually.
'triggers': []
# Set the default release channel on the release branch to beta, as we
# may need to build a few of these.
'variables':

View File

@@ -678,7 +678,7 @@
"use_saved_key": "Použít dříve uložený klíče",
"parental_control": "Rodičovská ochrana",
"safe_browsing": "Bezpečné prohlížení",
"served_from_cache": "{{value}} <i>(převzato z mezipaměti)</i>",
"served_from_cache_label": "Převzato z mezipaměti",
"form_error_password_length": "Heslo musí obsahovat od {{min}} do {{max}} znaků",
"anonymizer_notification": "<0>Poznámka:</0> Anonymizace IP je zapnuta. Můžete ji vypnout v <1>Obecných nastaveních</1>.",
"confirm_dns_cache_clear": "Opravdu chcete vymazat mezipaměť DNS?",

View File

@@ -678,7 +678,7 @@
"use_saved_key": "Brug den tidligere gemte nøgle",
"parental_control": "Forældrekontrol",
"safe_browsing": "Sikker Browsing",
"served_from_cache": "{{value}} <i>(leveret fra cache)</i>",
"served_from_cache_label": "Leveret fra cache",
"form_error_password_length": "Adgangskode skal udgøre fra {{min}} til {{max}} tegn",
"anonymizer_notification": "<0>Bemærk:</0> IP-anonymisering er aktiveret. Det kan deaktiveres via <1>Generelle indstillinger</1>.",
"confirm_dns_cache_clear": "Sikker på, at DNS-cache skal ryddes?",

View File

@@ -678,7 +678,7 @@
"use_saved_key": "Zuvor gespeicherten Schlüssel verwenden",
"parental_control": "Kindersicherung",
"safe_browsing": "Internetsicherheit",
"served_from_cache": "{{value}} <i>(aus dem Cache abgerufen)</i>",
"served_from_cache_label": "Aus dem Cache abgerufen",
"form_error_password_length": "Das Passwort muss zwischen {{min}} und {{max}} Zeichen enthalten",
"anonymizer_notification": "<0>Hinweis:</0> Die IP-Anonymisierung ist aktiviert. Sie können sie in den <1>Allgemeinen Einstellungen</1> deaktivieren.",
"confirm_dns_cache_clear": "Möchten Sie den DNS-Cache wirklich leeren?",

View File

@@ -678,7 +678,7 @@
"use_saved_key": "Use the previously saved key",
"parental_control": "Parental Control",
"safe_browsing": "Safe Browsing",
"served_from_cache": "{{value}} <i>(served from cache)</i>",
"served_from_cache_label": "Served from cache",
"form_error_password_length": "Password must be {{min}} to {{max}} characters long",
"anonymizer_notification": "<0>Note:</0> IP anonymization is enabled. You can disable it in <1>General settings</1>.",
"confirm_dns_cache_clear": "Are you sure you want to clear DNS cache?",

View File

@@ -678,7 +678,7 @@
"use_saved_key": "Usar la clave guardada previamente",
"parental_control": "Control parental",
"safe_browsing": "Navegación segura",
"served_from_cache": "{{value}} <i>(servido desde la caché)</i>",
"served_from_cache_label": "Servido desde la caché",
"form_error_password_length": "La contraseña debe tener entre {{min}} y {{max}} caracteres",
"anonymizer_notification": "<0>Nota:</0> La anonimización de IP está habilitada. Puedes deshabilitarla en <1>Configuración general</1>.",
"confirm_dns_cache_clear": "¿Estás seguro de que deseas borrar la caché DNS?",

View File

@@ -678,7 +678,7 @@
"use_saved_key": "Utiliser la clef précédemment enregistrée",
"parental_control": "Contrôle parental",
"safe_browsing": "Navigation sécurisée",
"served_from_cache": "{{value}} <i>(depuis le cache)</i>",
"served_from_cache_label": "Servi depuis le cache",
"form_error_password_length": "Le mot de passe doit comporter entre {{min}} et {{max}}  caractères",
"anonymizer_notification": "<0>Note :</0> L'anonymisation IP est activée. Vous pouvez la désactiver dans les <1>paramètres généraux</1>.",
"confirm_dns_cache_clear": "Voulez-vous vraiment vider le cache DNS ?",

View File

@@ -678,7 +678,7 @@
"use_saved_key": "Utilizza la chiave salvata in precedenza",
"parental_control": "Controllo Parentale",
"safe_browsing": "Navigazione Sicura",
"served_from_cache": "{{value}} <i>(fornito dalla cache)</i>",
"served_from_cache_label": "Servito dalla cache",
"form_error_password_length": "La password deve contenere da {{min}} a {{max}} caratteri",
"anonymizer_notification": "<0>Attenzione:</0> L'anonimizzazione dell'IP è abilitata. Puoi disabilitarla in <1>Impostazioni generali</1>.",
"confirm_dns_cache_clear": "Sei sicuro di voler cancellare la cache DNS?",

View File

@@ -678,7 +678,7 @@
"use_saved_key": "以前に保存したキーを使用する",
"parental_control": "ペアレンタルコントロール",
"safe_browsing": "セーフブラウジング",
"served_from_cache": "{{value}} <i>(キャッシュから応答)</i>",
"served_from_cache_label": "キャッシュからの配信:",
"form_error_password_length": "パスワードの長さは{{min}}〜{{max}}文字にしてください。",
"anonymizer_notification": "【<0>注意</0>】IPの匿名化が有効になっています。 <1>一般設定</1>で無効にできます。",
"confirm_dns_cache_clear": "DNS キャッシュをクリアしてもよろしいですか?",

View File

@@ -678,7 +678,7 @@
"use_saved_key": "이전에 저장했던 키 사용하기",
"parental_control": "자녀 보호",
"safe_browsing": "세이프 브라우징",
"served_from_cache": "{{value}} <i>(캐시에서 제공)</i>",
"served_from_cache_label": "캐시에서 가져옴",
"form_error_password_length": "비밀번호는 {{min}}~{{max}}자 길이여야 합니다.",
"anonymizer_notification": "<0>참고:</0> IP 익명화가 활성화되었습니다. <1>일반 설정</1>에서 비활성화할 수 있습니다.",
"confirm_dns_cache_clear": "정말로 DNS 캐시를 지우시겠습니까?",

View File

@@ -678,7 +678,7 @@
"use_saved_key": "De eerder opgeslagen sleutel gebruiken",
"parental_control": "Ouderlijk toezicht",
"safe_browsing": "Veilig browsen",
"served_from_cache": "{{value}} <i>(geleverd vanuit cache)</i>",
"served_from_cache_label": "Geleverd vanuit cache",
"form_error_password_length": "Wachtwoord moet {{min}} tot {{max}} tekens lang zijn",
"anonymizer_notification": "<0>Opmerking:</0> IP-anonimisering is ingeschakeld. Je kunt het uitschakelen in <1>Algemene instellingen</1>.",
"confirm_dns_cache_clear": "Weet je zeker dat je de DNS-cache wilt wissen?",

View File

@@ -678,7 +678,7 @@
"use_saved_key": "Use a chave salva anteriormente",
"parental_control": "Controle parental",
"safe_browsing": "Navegação segura",
"served_from_cache": "{{value}} <i>(servido do cache)</i>",
"served_from_cache_label": "Servido a partir do cache",
"form_error_password_length": "A senha deve ter entre {{min}} e {{max}} caracteres",
"anonymizer_notification": "<0>Observação:</0> A anonimização de IP está ativada. Você pode desativá-lo em <1>Configurações gerais</1>.",
"confirm_dns_cache_clear": "Tem certeza de que deseja limpar o cache DNS?",

View File

@@ -678,7 +678,7 @@
"use_saved_key": "Use a chave guardada anteriormente",
"parental_control": "Controlo parental",
"safe_browsing": "Navegação segura",
"served_from_cache": "{{value}} <i>(servido do cache)</i>",
"served_from_cache_label": "Servido a partir do cache",
"form_error_password_length": "A palavra-passe deve ter {{min}} a {{max}} caracteres",
"anonymizer_notification": "<0>Observação:</0> A anonimização de IP está ativada. Você pode desativá-la em <1>Definições gerais</1>.",
"confirm_dns_cache_clear": "Tem certeza de que quer limpar a cache DNS?",

View File

@@ -678,7 +678,7 @@
"use_saved_key": "Использовать сохранённый ранее ключ",
"parental_control": "Родительский контроль",
"safe_browsing": "Безопасный интернет",
"served_from_cache": "{{value}} <i>(получено из кеша)</i>",
"served_from_cache_label": "Получено из кеша",
"form_error_password_length": "Пароль должен содержать от {{min}} до {{max}} символов",
"anonymizer_notification": "<0>Внимание:</0> включена анонимизация IP-адресов. Вы можете отключить её в разделе <1>Основные настройки</1>.",
"confirm_dns_cache_clear": "Вы уверены, что хотите очистить кеш DNS?",

View File

@@ -678,7 +678,7 @@
"use_saved_key": "Použiť predtým uložený kľúč",
"parental_control": "Rodičovská kontrola",
"safe_browsing": "Bezpečné prehliadanie",
"served_from_cache": "{{value}} <i>(prevzatá z cache pamäte)</i>",
"served_from_cache_label": "Prevzaté z cache pamäte",
"form_error_password_length": "Heslo musí mať od {{min}} do {{max}} znakov",
"anonymizer_notification": "<0>Poznámka:</0> Anonymizácia IP je zapnutá. Môžete ju vypnúť vo <1>Všeobecných nastaveniach</1>.",
"confirm_dns_cache_clear": "Naozaj chcete vymazať vyrovnávaciu pamäť DNS?",

View File

@@ -678,7 +678,7 @@
"use_saved_key": "Uporabi prej shranjeni ključ",
"parental_control": "Starševski nadzor",
"safe_browsing": "Varno brskanje",
"served_from_cache": "{{value}} <i>(postreženo iz predpomnilnika)</i>",
"served_from_cache_label": "Dostavljeno iz predpomnilnika",
"form_error_password_length": "Geslo mora vsebovati od {{min}} do {{max}} znakov",
"anonymizer_notification": "<0>Opomba:</0> Anonimizacija IP je omogočena. Onemogočite ga lahko v <1>Splošnih nastavitvah</1>.",
"confirm_dns_cache_clear": "Ali ste prepričani, da želite počistiti predpomnilnik DNS?",

View File

@@ -678,7 +678,7 @@
"use_saved_key": "Önceden kaydedilmiş anahtarı kullan",
"parental_control": "Ebeveyn Denetimi",
"safe_browsing": "Güvenli Gezinti",
"served_from_cache": "{{value}} <i>(önbellekten kullanıldı)</i>",
"served_from_cache_label": "Önbellekten kullanıldı",
"form_error_password_length": "Parola {{min}} ila {{max}} karakter uzunluğunda olmalıdır",
"anonymizer_notification": "<0>Not:</0> IP anonimleştirme etkinleştirildi. Bunu <1>Genel ayarlardan</1> devre dışı bırakabilirsiniz.",
"confirm_dns_cache_clear": "DNS önbelleğini temizlemek istediğinizden emin misiniz?",

View File

@@ -678,7 +678,7 @@
"use_saved_key": "Використати раніше збережений ключ",
"parental_control": "Батьківський контроль",
"safe_browsing": "Безпечний перегляд",
"served_from_cache": "{{value}} <i>(отримано з кешу)</i>",
"served_from_cache_label": "Отримано з кешу",
"form_error_password_length": "Пароль має містити від {{min}} до {{max}} символів",
"anonymizer_notification": "<0>Примітка:</0> IP-анонімізацію ввімкнено. Ви можете вимкнути його в <1>Загальні налаштування</1> .",
"confirm_dns_cache_clear": "Ви впевнені, що бажаєте очистити кеш DNS?",

View File

@@ -678,7 +678,7 @@
"use_saved_key": "使用之前保存的密钥",
"parental_control": "家长控制",
"safe_browsing": "安全浏览",
"served_from_cache": "{{value}}<i>(由缓存提供)</i>",
"served_from_cache_label": "从缓存中",
"form_error_password_length": "密码长度必须为 {{min}} 到 {{max}} 个字符",
"anonymizer_notification": "<0>注意:</0> IP 匿名化已启用。您可以在<1>常规设置</1>中禁用它。",
"confirm_dns_cache_clear": "您确定要清除 DNS 缓存吗?",

View File

@@ -678,7 +678,7 @@
"use_saved_key": "使用該先前已儲存的金鑰",
"parental_control": "家長控制",
"safe_browsing": "安全瀏覽",
"served_from_cache": "{{value}} <i>(由快取提供)</i>",
"served_from_cache_label": "從快取中",
"form_error_password_length": "密碼長度必須為 {{min}} 到 {{max}} 個字符",
"anonymizer_notification": "<0>注意:</0>IP 匿名化被啟用。您可在<1>一般設定</1>中禁用它。",
"confirm_dns_cache_clear": "您確定您想要清除 DNS 快取嗎?",

View File

@@ -55,6 +55,12 @@ const Dashboard = ({
return t('stats_disabled_short');
}
const msIn7Days = 604800000;
if (stats.timeUnits === TIME_UNITS.HOURS && stats.interval === msIn7Days) {
return t('for_last_days', { count: msToDays(stats.interval) });
}
return stats.timeUnits === TIME_UNITS.HOURS
? t('for_last_hours', { count: msToHours(stats.interval) })
: t('for_last_days', { count: msToDays(stats.interval) });

View File

@@ -38,9 +38,6 @@ const ResponseCell = ({
const statusLabel = t(isBlockedByResponse ? 'blocked_by_cname_or_ip' : FILTERED_STATUS_TO_META_MAP[reason]?.LABEL || reason);
const boldStatusLabel = <span className="font-weight-bold">{statusLabel}</span>;
const upstreamString = cached
? t('served_from_cache', { value: upstream, i: <i /> })
: upstream;
const renderResponses = (responseArr) => {
if (!responseArr || responseArr.length === 0) {
@@ -58,7 +55,16 @@ const ResponseCell = ({
const COMMON_CONTENT = {
encryption_status: boldStatusLabel,
install_settings_dns: upstreamString,
install_settings_dns: upstream,
...(cached
&& {
served_from_cache_label: (
<svg className="icons icon--20 icon--green mb-1">
<use xlinkHref="#check" />
</svg>
),
}
),
elapsed: formattedElapsedMs,
response_code: status,
...(service_name && services.allServices

View File

@@ -118,9 +118,6 @@ const Row = memo(({
const blockingForClientKey = isFiltered ? 'unblock_for_this_client_only' : 'block_for_this_client_only';
const clientNameBlockingFor = getBlockingClientName(clients, client);
const upstreamString = cached
? t('served_from_cache', { value: upstream, i: <i /> })
: upstream;
const onBlockingForClientClick = () => {
dispatch(toggleBlockingForClient(buttonType, domain, clientNameBlockingFor));
@@ -192,7 +189,16 @@ const Row = memo(({
className="link--green">{sourceData.name}
</a>,
response_details: 'title',
install_settings_dns: upstreamString,
install_settings_dns: upstream,
...(cached
&& {
served_from_cache_label: (
<svg className="icons icon--20 icon--green">
<use xlinkHref="#check" />
</svg>
),
}
),
elapsed: formattedElapsedMs,
...(rules.length > 0
&& { rule_label: getRulesToFilterList(rules, filters, whitelistFilters) }

View File

@@ -245,6 +245,10 @@ const Icons = () => (
<path fillRule="evenodd" clipRule="evenodd" d="M12 13.5C11.1716 13.5 10.5 12.8284 10.5 12C10.5 11.1716 11.1716 10.5 12 10.5C12.8284 10.5 13.5 11.1716 13.5 12C13.5 12.8284 12.8284 13.5 12 13.5Z" fill="currentColor" />
<path fillRule="evenodd" clipRule="evenodd" d="M12 20C11.1716 20 10.5 19.3284 10.5 18.5C10.5 17.6716 11.1716 17 12 17C12.8284 17 13.5 17.6716 13.5 18.5C13.5 19.3284 12.8284 20 12 20Z" fill="currentColor" />
</symbol>
<symbol id="check" fill="none" viewBox="0 0 24 24" stroke="currentColor" strokeWidth="2" strokeLinecap="round" strokeLinejoin="round">
<path d="M5 11.7665L10.5878 17L19 8" />
</symbol>
</svg>
);

26
go.mod
View File

@@ -3,8 +3,8 @@ module github.com/AdguardTeam/AdGuardHome
go 1.21.8
require (
github.com/AdguardTeam/dnsproxy v0.65.2
github.com/AdguardTeam/golibs v0.20.1
github.com/AdguardTeam/dnsproxy v0.66.0
github.com/AdguardTeam/golibs v0.20.2
github.com/AdguardTeam/urlfilter v0.18.0
github.com/NYTimes/gziphandler v1.1.1
github.com/ameshkov/dnscrypt/v2 v2.2.7
@@ -18,7 +18,7 @@ require (
github.com/google/gopacket v1.1.19
github.com/google/renameio/v2 v2.0.0
github.com/google/uuid v1.6.0
github.com/insomniacslk/dhcp v0.0.0-20240204152450-ca2dc33955c1
github.com/insomniacslk/dhcp v0.0.0-20240227161007-c728f5dd21c8
github.com/josharian/native v1.1.1-0.20230202152459-5c7d0dd6ab86
github.com/kardianos/service v1.2.2
github.com/mdlayher/ethernet v0.0.0-20220221185849-529eae5b6118
@@ -31,11 +31,11 @@ require (
github.com/quic-go/quic-go v0.41.0
github.com/stretchr/testify v1.8.4
github.com/ti-mo/netfilter v0.5.1
go.etcd.io/bbolt v1.3.8
golang.org/x/crypto v0.19.0
golang.org/x/exp v0.0.0-20240205201215-2c58cdc269a3
golang.org/x/net v0.21.0
golang.org/x/sys v0.17.0
go.etcd.io/bbolt v1.3.9
golang.org/x/crypto v0.21.0
golang.org/x/exp v0.0.0-20240222234643-814bf88cf225
golang.org/x/net v0.22.0
golang.org/x/sys v0.18.0
gopkg.in/natefinch/lumberjack.v2 v2.2.1
gopkg.in/yaml.v3 v3.0.1
howett.net/plist v1.0.1
@@ -48,19 +48,19 @@ require (
github.com/beefsack/go-rate v0.0.0-20220214233405-116f4ca011a0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
github.com/google/pprof v0.0.0-20240207164012-fb44976bdcd5 // indirect
github.com/google/pprof v0.0.0-20240227163752-401108e1b7e7 // indirect
github.com/mdlayher/socket v0.5.0 // indirect
github.com/onsi/ginkgo/v2 v2.15.0 // indirect
github.com/onsi/ginkgo/v2 v2.16.0 // indirect
github.com/patrickmn/go-cache v2.1.0+incompatible // indirect
github.com/pierrec/lz4/v4 v4.1.21 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/quic-go/qpack v0.4.0 // indirect
github.com/u-root/uio v0.0.0-20240207234124-abbebccef0fd // indirect
github.com/u-root/uio v0.0.0-20240224005618-d2acac8f3701 // indirect
go.uber.org/mock v0.4.0 // indirect
golang.org/x/mod v0.15.0 // indirect
golang.org/x/mod v0.16.0 // indirect
golang.org/x/sync v0.6.0 // indirect
golang.org/x/text v0.14.0 // indirect
golang.org/x/tools v0.18.0 // indirect
golang.org/x/tools v0.19.0 // indirect
gonum.org/v1/gonum v0.14.0 // indirect
)

56
go.sum
View File

@@ -1,7 +1,7 @@
github.com/AdguardTeam/dnsproxy v0.65.2 h1:D+BMw0Vu2lbQrYpoPctG2Xr+24KdfhgkzZb6QgPZheM=
github.com/AdguardTeam/dnsproxy v0.65.2/go.mod h1:8NQTTNZY+qR9O1Fzgz3WQv30knfSgms68SRlzSnX74A=
github.com/AdguardTeam/golibs v0.20.1 h1:ol8qLjWGZhU9paMMwN+OLWVTUigGsXa29iVTyd62VKY=
github.com/AdguardTeam/golibs v0.20.1/go.mod h1:bgcMgRviCKyU6mkrX+RtT/OsKPFzyppelfRsksMG3KU=
github.com/AdguardTeam/dnsproxy v0.66.0 h1:RyUbyDxRSXBFjVG1l2/4HV3I98DtfIgpnZkgXkgHKnc=
github.com/AdguardTeam/dnsproxy v0.66.0/go.mod h1:ZThEXbMUlP1RxfwtNW30ItPAHE6OF4YFygK8qjU/cvY=
github.com/AdguardTeam/golibs v0.20.2 h1:9gThBFyuELf2ohRnUNeQGQsVBYI7YslaRLUFwVaUj8E=
github.com/AdguardTeam/golibs v0.20.2/go.mod h1:/votX6WK1PdcZ3T2kBOPjPCGmfhlKixhI6ljYrFRPvI=
github.com/AdguardTeam/urlfilter v0.18.0 h1:ZZzwODC/ADpjJSODxySrrUnt/fvOCfGFaCW6j+wsGfQ=
github.com/AdguardTeam/urlfilter v0.18.0/go.mod h1:IXxBwedLiZA2viyHkaFxY/8mjub0li2PXRg8a3d9Z1s=
github.com/NYTimes/gziphandler v1.1.1 h1:ZUDjpQae29j0ryrS0u/B8HZfJBtBQHjqw2rQ2cqUQ3I=
@@ -29,8 +29,8 @@ github.com/dimfeld/httptreemux/v5 v5.5.0 h1:p8jkiMrCuZ0CmhwYLcbNbl7DDo21fozhKHQ2
github.com/dimfeld/httptreemux/v5 v5.5.0/go.mod h1:QeEylH57C0v3VO0tkKraVz9oD3Uu93CKPnTLbsidvSw=
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
github.com/go-logr/logr v1.3.0 h1:2y3SDp0ZXuc6/cjLSZ+Q3ir+QB9T/iG5yYRXqsagWSY=
github.com/go-logr/logr v1.3.0/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/logr v1.4.1 h1:pKouT5E8xu9zeFC39JXRDukb6JFQPXM5p5I91188VAQ=
github.com/go-logr/logr v1.4.1/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/go-ping/ping v1.1.0 h1:3MCGhVX4fyEUuhsfwPrsEdQw6xspHkv5zHsiSoDFZYw=
@@ -46,8 +46,8 @@ github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
github.com/google/pprof v0.0.0-20240207164012-fb44976bdcd5 h1:E/LAvt58di64hlYjx7AsNS6C/ysHWYo+2qPCZKTQhRo=
github.com/google/pprof v0.0.0-20240207164012-fb44976bdcd5/go.mod h1:czg5+yv1E0ZGTi6S6vVK1mke0fV+FaUhNGcd6VRS9Ik=
github.com/google/pprof v0.0.0-20240227163752-401108e1b7e7 h1:y3N7Bm7Y9/CtpiVkw/ZWj6lSlDF3F74SfKwfTCer72Q=
github.com/google/pprof v0.0.0-20240227163752-401108e1b7e7/go.mod h1:czg5+yv1E0ZGTi6S6vVK1mke0fV+FaUhNGcd6VRS9Ik=
github.com/google/renameio/v2 v2.0.0 h1:UifI23ZTGY8Tt29JbYFiuyIU3eX+RNFtUwefq9qAhxg=
github.com/google/renameio/v2 v2.0.0/go.mod h1:BtmJXm5YlszgC+TD4HOEEUFgkJP3nLxehU6hfe7jRt4=
github.com/google/uuid v1.2.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
@@ -55,8 +55,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/hugelgupf/socketpair v0.0.0-20190730060125-05d35a94e714 h1:/jC7qQFrv8CrSJVmaolDVOxTfS9kc36uB6H40kdbQq8=
github.com/hugelgupf/socketpair v0.0.0-20190730060125-05d35a94e714/go.mod h1:2Goc3h8EklBH5mspfHFxBnEoURQCGzQQH1ga9Myjvis=
github.com/insomniacslk/dhcp v0.0.0-20240204152450-ca2dc33955c1 h1:L3pm9Kf2G6gJVYawz2SrI5QnV1wzHYbqmKnSHHXJAb8=
github.com/insomniacslk/dhcp v0.0.0-20240204152450-ca2dc33955c1/go.mod h1:izxuNQZeFrbx2nK2fAyN5iNUB34Fe9j0nK4PwLzAkKw=
github.com/insomniacslk/dhcp v0.0.0-20240227161007-c728f5dd21c8 h1:V3plQrMHRWOB5zMm3yNqvBxDQVW1+/wHBSok5uPdmVs=
github.com/insomniacslk/dhcp v0.0.0-20240227161007-c728f5dd21c8/go.mod h1:izxuNQZeFrbx2nK2fAyN5iNUB34Fe9j0nK4PwLzAkKw=
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
github.com/josharian/native v1.0.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
github.com/josharian/native v1.1.1-0.20230202152459-5c7d0dd6ab86 h1:elKwZS1OcdQ0WwEDBeqxKwb7WB62QX8bvZ/FJnVXIfk=
@@ -84,8 +84,8 @@ github.com/miekg/dns v1.1.58 h1:ca2Hdkz+cDg/7eNF6V56jjzuZ4aCAE+DbVkILdQWG/4=
github.com/miekg/dns v1.1.58/go.mod h1:Ypv+3b/KadlvW9vJfXOTf300O4UqaHFzFCuHz+rPkBY=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
github.com/onsi/ginkgo/v2 v2.15.0 h1:79HwNRBAZHOEwrczrgSOPy+eFTTlIGELKy5as+ClttY=
github.com/onsi/ginkgo/v2 v2.15.0/go.mod h1:HlxMHtYF57y6Dpf+mc5529KKmSq9h2FpCF+/ZkwUxKM=
github.com/onsi/ginkgo/v2 v2.16.0 h1:7q1w9frJDzninhXxjZd+Y/x54XNjG/UlRLIYPZafsPM=
github.com/onsi/ginkgo/v2 v2.16.0/go.mod h1:llBI3WDLL9Z6taip6f33H76YcWtJv+7R3HigUjbIBOs=
github.com/onsi/gomega v1.30.0 h1:hvMK7xYz4D3HapigLTeGdId/NcfQx1VHMJc60ew99+8=
github.com/onsi/gomega v1.30.0/go.mod h1:9sxs+SwGrKI0+PWe4Fxa9tFQQBG5xSsSbMXOI8PPpoQ=
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
@@ -121,32 +121,32 @@ github.com/tklauser/go-sysconf v0.3.11 h1:89WgdJhk5SNwJfu+GKyYveZ4IaJ7xAkecBo+Kd
github.com/tklauser/go-sysconf v0.3.11/go.mod h1:GqXfhXY3kiPa0nAXPDIQIWzJbMCB7AmcWpGR8lSZfqI=
github.com/tklauser/numcpus v0.6.0 h1:kebhY2Qt+3U6RNK7UqpYNA+tJ23IBEGKkB7JQBfDYms=
github.com/tklauser/numcpus v0.6.0/go.mod h1:FEZLMke0lhOUG6w2JadTzp0a+Nl8PF/GFkQ5UVIcaL4=
github.com/u-root/uio v0.0.0-20240207234124-abbebccef0fd h1:BQJh5fdHsPa/YuMVrbcSxQKuowGCHYh0GD7hvLaHBK0=
github.com/u-root/uio v0.0.0-20240207234124-abbebccef0fd/go.mod h1:P3a5rG4X7tI17Nn3aOIAYr5HbIMukwXG0urG0WuL8OA=
github.com/u-root/uio v0.0.0-20240224005618-d2acac8f3701 h1:pyC9PaHYZFgEKFdlp3G8RaCKgVpHZnecvArXvPXcFkM=
github.com/u-root/uio v0.0.0-20240224005618-d2acac8f3701/go.mod h1:P3a5rG4X7tI17Nn3aOIAYr5HbIMukwXG0urG0WuL8OA=
github.com/yusufpapurcu/wmi v1.2.3 h1:E1ctvB7uKFMOJw3fdOW32DwGE9I7t++CRUEMKvFoFiw=
github.com/yusufpapurcu/wmi v1.2.3/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
go.etcd.io/bbolt v1.3.8 h1:xs88BrvEv273UsB79e0hcVrlUWmS0a8upikMFhSyAtA=
go.etcd.io/bbolt v1.3.8/go.mod h1:N9Mkw9X8x5fupy0IKsmuqVtoGDyxsaDlbk4Rd05IAQw=
go.etcd.io/bbolt v1.3.9 h1:8x7aARPEXiXbHmtUwAIv7eV2fQFHrLLavdiJ3uzJXoI=
go.etcd.io/bbolt v1.3.9/go.mod h1:zaO32+Ti0PK1ivdPtgMESzuzL2VPoIG1PCQNvOdo/dE=
go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU=
go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo=
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/exp v0.0.0-20240205201215-2c58cdc269a3 h1:/RIbNt/Zr7rVhIkQhooTxCxFcdWLGIKnZA4IXNFSrvo=
golang.org/x/exp v0.0.0-20240205201215-2c58cdc269a3/go.mod h1:idGWGoKP1toJGkd5/ig9ZLuPcZBC3ewk7SzmH0uou08=
golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 h1:LfspQV/FYTatPTr/3HzIcmiUFH7PGP+OQ6mgDYo3yuQ=
golang.org/x/exp v0.0.0-20240222234643-814bf88cf225/go.mod h1:CxmFvTBINI24O/j8iY7H1xHzx2i4OsyguNBmN/uPtqc=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/mod v0.15.0 h1:SernR4v+D55NyBH2QiEQrlBAnj1ECL6AGrA5+dPaMY8=
golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/mod v0.16.0 h1:QX4fJ0Rr5cPQCF7O9lh9Se4pmwfwskqZfq5moyldzic=
golang.org/x/mod v0.16.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210316092652-d523dce5a7f4/go.mod h1:RBQZq4jEuRlivfhVLdyRGr576XBO4/greRjx4P4O3yc=
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/net v0.22.0 h1:9sGLhx7iRIHEiX0oAJ3MRZMUCElJgy7Br1nO+AMN3Tc=
golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ=
@@ -161,8 +161,8 @@ golang.org/x/sys v0.0.0-20210315160823-c6e025ad8005/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220209214540-3681064d5158/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.4.1-0.20230131160137-e7d7f63158de/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
@@ -170,8 +170,8 @@ golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/tools v0.18.0 h1:k8NLag8AGHnn+PHbl7g43CtqZAwG60vZkLqgyZgIHgQ=
golang.org/x/tools v0.18.0/go.mod h1:GL7B4CwcLLeo59yx/9UWWuNOW1n3VZ4f5axWfML7Lcg=
golang.org/x/tools v0.19.0 h1:tfGCXNR1OsFG+sVdLAitlpjAvD/I6dHDKnYrpEZUHkw=
golang.org/x/tools v0.19.0/go.mod h1:qoJWxmGSIBmAeriMx19ogtrEPrGtDbPK634QFIcLAhc=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gonum.org/v1/gonum v0.14.0 h1:2NiG67LD1tEH0D7kM+ps2V+fXmsAnpUeec7n8tcr4S0=

View File

@@ -5,9 +5,9 @@ package aghalg
import (
"fmt"
"slices"
"golang.org/x/exp/constraints"
"golang.org/x/exp/slices"
)
// Coalesce returns the first non-zero value. It is named after function

View File

@@ -1,11 +1,11 @@
package aghalg_test
import (
"slices"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/stretchr/testify/assert"
"golang.org/x/exp/slices"
)
// elements is a helper function that returns n elements of the buffer.

View File

@@ -0,0 +1,86 @@
package aghalg
import (
"slices"
)
// SortedMap is a map that keeps elements in order with internal sorting
// function. Must be initialised by the [NewSortedMap].
type SortedMap[K comparable, V any] struct {
vals map[K]V
cmp func(a, b K) (res int)
keys []K
}
// NewSortedMap initializes the new instance of sorted map. cmp is a sort
// function to keep elements in order.
//
// TODO(s.chzhen): Use cmp.Compare in Go 1.21.
func NewSortedMap[K comparable, V any](cmp func(a, b K) (res int)) SortedMap[K, V] {
return SortedMap[K, V]{
vals: map[K]V{},
cmp: cmp,
}
}
// Set adds val with key to the sorted map. It panics if the m is nil.
func (m *SortedMap[K, V]) Set(key K, val V) {
m.vals[key] = val
i, has := slices.BinarySearchFunc(m.keys, key, m.cmp)
if has {
m.keys[i] = key
} else {
m.keys = slices.Insert(m.keys, i, key)
}
}
// Get returns val by key from the sorted map.
func (m *SortedMap[K, V]) Get(key K) (val V, ok bool) {
if m == nil {
return
}
val, ok = m.vals[key]
return val, ok
}
// Del removes the value by key from the sorted map.
func (m *SortedMap[K, V]) Del(key K) {
if m == nil {
return
}
if _, has := m.vals[key]; !has {
return
}
delete(m.vals, key)
i, _ := slices.BinarySearchFunc(m.keys, key, m.cmp)
m.keys = slices.Delete(m.keys, i, i+1)
}
// Clear removes all elements from the sorted map.
func (m *SortedMap[K, V]) Clear() {
if m == nil {
return
}
m.keys = nil
clear(m.vals)
}
// Range calls cb for each element of the map, sorted by m.cmp. If cb returns
// false it stops.
func (m *SortedMap[K, V]) Range(cb func(K, V) (cont bool)) {
if m == nil {
return
}
for _, k := range m.keys {
if !cb(k, m.vals[k]) {
return
}
}
}

View File

@@ -0,0 +1,95 @@
package aghalg
import (
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestNewSortedMap(t *testing.T) {
var m SortedMap[string, int]
letters := []string{}
for i := 0; i < 10; i++ {
r := string('a' + rune(i))
letters = append(letters, r)
}
t.Run("create_and_fill", func(t *testing.T) {
m = NewSortedMap[string, int](strings.Compare)
nums := []int{}
for i, r := range letters {
m.Set(r, i)
nums = append(nums, i)
}
gotLetters := []string{}
gotNums := []int{}
m.Range(func(k string, v int) bool {
gotLetters = append(gotLetters, k)
gotNums = append(gotNums, v)
return true
})
assert.Equal(t, letters, gotLetters)
assert.Equal(t, nums, gotNums)
n, ok := m.Get(letters[0])
assert.True(t, ok)
assert.Equal(t, nums[0], n)
})
t.Run("clear", func(t *testing.T) {
lastLetter := letters[len(letters)-1]
m.Del(lastLetter)
_, ok := m.Get(lastLetter)
assert.False(t, ok)
m.Clear()
gotLetters := []string{}
m.Range(func(k string, _ int) bool {
gotLetters = append(gotLetters, k)
return true
})
assert.Len(t, gotLetters, 0)
})
}
func TestNewSortedMap_nil(t *testing.T) {
const (
key = "key"
val = "val"
)
var m SortedMap[string, string]
assert.Panics(t, func() {
m.Set(key, val)
})
assert.NotPanics(t, func() {
_, ok := m.Get(key)
assert.False(t, ok)
})
assert.NotPanics(t, func() {
m.Range(func(_, _ string) (cont bool) {
return true
})
})
assert.NotPanics(t, func() {
m.Del(key)
})
assert.NotPanics(t, func() {
m.Clear()
})
}

View File

@@ -154,8 +154,8 @@ func pathsToPatterns(fsys fs.FS, paths []string) (patterns []string, err error)
}
// handleEvents concurrently handles the file system events. It closes the
// update channel of HostsContainer when finishes. It's used to be called
// within a separate goroutine.
// update channel of HostsContainer when finishes. It is intended to be used as
// a goroutine.
func (hc *HostsContainer) handleEvents() {
defer log.OnPanic(fmt.Sprintf("%s: handling events", hostsContainerPrefix))

View File

@@ -67,6 +67,7 @@ func TestNewHostsContainer(t *testing.T) {
}
hc, err := aghnet.NewHostsContainer(testFS, &aghtest.FSWatcher{
OnStart: func() (_ error) { panic("not implemented") },
OnEvents: onEvents,
OnAdd: onAdd,
OnClose: func() (err error) { return nil },
@@ -93,6 +94,7 @@ func TestNewHostsContainer(t *testing.T) {
t.Run("nil_fs", func(t *testing.T) {
require.Panics(t, func() {
_, _ = aghnet.NewHostsContainer(nil, &aghtest.FSWatcher{
OnStart: func() (_ error) { panic("not implemented") },
// Those shouldn't panic.
OnEvents: func() (e <-chan struct{}) { return nil },
OnAdd: func(name string) (err error) { return nil },
@@ -111,6 +113,7 @@ func TestNewHostsContainer(t *testing.T) {
const errOnAdd errors.Error = "error"
errWatcher := &aghtest.FSWatcher{
OnStart: func() (_ error) { panic("not implemented") },
OnEvents: func() (e <-chan struct{}) { panic("not implemented") },
OnAdd: func(name string) (err error) { return errOnAdd },
OnClose: func() (err error) { return nil },
@@ -155,6 +158,7 @@ func TestHostsContainer_refresh(t *testing.T) {
t.Cleanup(func() { close(eventsCh) })
w := &aghtest.FSWatcher{
OnStart: func() (_ error) { panic("not implemented") },
OnEvents: func() (e <-chan event) { return eventsCh },
OnAdd: func(name string) (err error) {
assert.Equal(t, "dir", name)

View File

@@ -1,11 +1,11 @@
package aghnet
import (
"slices"
"strings"
"github.com/AdguardTeam/urlfilter"
"github.com/AdguardTeam/urlfilter/filterlist"
"golang.org/x/exp/slices"
)
// IgnoreEngine contains the list of rules for ignoring hostnames and matches

View File

@@ -17,6 +17,7 @@ import (
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/osutil"
)
// DialContextFunc is the semantic alias for dialing functions, such as
@@ -32,7 +33,7 @@ var (
netInterfaceAddrs = net.InterfaceAddrs
// rootDirFS is the filesystem pointing to the root directory.
rootDirFS = aghos.RootDirFS()
rootDirFS = osutil.RootDirFS()
)
// ErrNoStaticIPInfo is returned by IfaceHasStaticIP when no information about

View File

@@ -8,6 +8,8 @@ import (
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/osutil"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/fsnotify/fsnotify"
)
@@ -18,31 +20,38 @@ type event = struct{}
// FSWatcher tracks all the fyle system events and notifies about those.
//
// TODO(e.burkov, a.garipov): Move into another package like aghfs.
//
// TODO(e.burkov): Add tests.
type FSWatcher interface {
// Start starts watching the added files.
Start() (err error)
// Close stops watching the files and closes an update channel.
io.Closer
// Events should return a read-only channel which notifies about events.
// Events returns the channel to notify about the file system events.
Events() (e <-chan event)
// Add should check if the file named name is accessible and starts tracking
// it.
// Add starts tracking the file. It returns an error if the file can't be
// tracked. It must not be called after Start.
Add(name string) (err error)
}
// osWatcher tracks the file system provided by the OS.
type osWatcher struct {
// w is the actual notifier that is handled by osWatcher.
w *fsnotify.Watcher
// watcher is the actual notifier that is handled by osWatcher.
watcher *fsnotify.Watcher
// events is the channel to notify.
events chan event
// files is the set of tracked files.
files *stringutil.Set
}
const (
// osWatcherPref is a prefix for logging and wrapping errors in osWathcer's
// methods.
osWatcherPref = "os watcher"
)
// osWatcherPref is a prefix for logging and wrapping errors in osWathcer's
// methods.
const osWatcherPref = "os watcher"
// NewOSWritesWatcher creates FSWatcher that tracks the real file system of the
// OS and notifies only about writing events.
@@ -55,25 +64,27 @@ func NewOSWritesWatcher() (w FSWatcher, err error) {
return nil, fmt.Errorf("creating watcher: %w", err)
}
fsw := &osWatcher{
w: watcher,
events: make(chan event, 1),
}
go fsw.handleErrors()
go fsw.handleEvents()
return fsw, nil
return &osWatcher{
watcher: watcher,
events: make(chan event, 1),
files: stringutil.NewSet(),
}, nil
}
// handleErrors handles accompanying errors. It used to be called in a separate
// goroutine.
func (w *osWatcher) handleErrors() {
defer log.OnPanic(fmt.Sprintf("%s: handling errors", osWatcherPref))
// type check
var _ FSWatcher = (*osWatcher)(nil)
for err := range w.w.Errors {
log.Error("%s: %s", osWatcherPref, err)
}
// Start implements the FSWatcher interface for *osWatcher.
func (w *osWatcher) Start() (err error) {
go w.handleErrors()
go w.handleEvents()
return nil
}
// Close implements the FSWatcher interface for *osWatcher.
func (w *osWatcher) Close() (err error) {
return w.watcher.Close()
}
// Events implements the FSWatcher interface for *osWatcher.
@@ -81,34 +92,42 @@ func (w *osWatcher) Events() (e <-chan event) {
return w.events
}
// Add implements the FSWatcher interface for *osWatcher.
// Add implements the [FSWatcher] interface for *osWatcher.
//
// TODO(e.burkov): Make it accept non-existing files to detect it's creating.
func (w *osWatcher) Add(name string) (err error) {
defer func() { err = errors.Annotate(err, "%s: %w", osWatcherPref) }()
if _, err = fs.Stat(RootDirFS(), name); err != nil {
fi, err := fs.Stat(osutil.RootDirFS(), name)
if err != nil {
return fmt.Errorf("checking file %q: %w", name, err)
}
return w.w.Add(filepath.Join("/", name))
}
name = filepath.Join("/", name)
w.files.Add(name)
// Close implements the FSWatcher interface for *osWatcher.
func (w *osWatcher) Close() (err error) {
return w.w.Close()
// Watch the directory and filter the events by the file name, since the
// common recomendation to the fsnotify package is to watch the directory
// instead of the file itself.
//
// See https://pkg.go.dev/github.com/fsnotify/fsnotify@v1.7.0#readme-watching-a-file-doesn-t-work-well.
if !fi.IsDir() {
name = filepath.Dir(name)
}
return w.watcher.Add(name)
}
// handleEvents notifies about the received file system's event if needed. It
// used to be called in a separate goroutine.
// is intended to be used as a goroutine.
func (w *osWatcher) handleEvents() {
defer log.OnPanic(fmt.Sprintf("%s: handling events", osWatcherPref))
defer close(w.events)
ch := w.w.Events
ch := w.watcher.Events
for e := range ch {
if e.Op&fsnotify.Write == 0 {
if e.Op&fsnotify.Write == 0 || !w.files.Has(e.Name) {
continue
}
@@ -131,3 +150,13 @@ func (w *osWatcher) handleEvents() {
}
}
}
// handleErrors handles accompanying errors. It used to be called in a separate
// goroutine.
func (w *osWatcher) handleErrors() {
defer log.OnPanic(fmt.Sprintf("%s: handling errors", osWatcherPref))
for err := range w.watcher.Errors {
log.Error("%s: %s", osWatcherPref, err)
}
}

View File

@@ -7,17 +7,16 @@ import (
"bufio"
"fmt"
"io"
"io/fs"
"os"
"os/exec"
"path"
"runtime"
"slices"
"strconv"
"strings"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"golang.org/x/exp/slices"
)
// UnsupportedError is returned by functions and methods when a particular
@@ -155,13 +154,6 @@ func IsOpenWrt() (ok bool) {
return isOpenWrt()
}
// RootDirFS returns the [fs.FS] rooted at the operating system's root. On
// Windows it returns the fs.FS rooted at the volume of the system directory
// (usually, C:).
func RootDirFS() (fsys fs.FS) {
return rootDirFS()
}
// NotifyReconfigureSignal notifies c on receiving reconfigure signals.
func NotifyReconfigureSignal(c chan<- os.Signal) {
notifyReconfigureSignal(c)

View File

@@ -7,6 +7,7 @@ import (
"os"
"syscall"
"github.com/AdguardTeam/golibs/osutil"
"github.com/AdguardTeam/golibs/stringutil"
)
@@ -40,7 +41,7 @@ func isOpenWrt() (ok bool) {
}
return nil, !stringutil.ContainsFold(string(data), osNameData), nil
}).Walk(RootDirFS(), etcReleasePattern)
}).Walk(osutil.RootDirFS(), etcReleasePattern)
return err == nil && ok
}

View File

@@ -3,17 +3,12 @@
package aghos
import (
"io/fs"
"os"
"os/signal"
"golang.org/x/sys/unix"
)
func rootDirFS() (fsys fs.FS) {
return os.DirFS("/")
}
func notifyReconfigureSignal(c chan<- os.Signal) {
signal.Notify(c, unix.SIGHUP)
}

View File

@@ -3,29 +3,13 @@
package aghos
import (
"io/fs"
"os"
"os/signal"
"path/filepath"
"syscall"
"github.com/AdguardTeam/golibs/log"
"golang.org/x/sys/windows"
)
func rootDirFS() (fsys fs.FS) {
// TODO(a.garipov): Use a better way if golang/go#44279 is ever resolved.
sysDir, err := windows.GetSystemDirectory()
if err != nil {
log.Error("aghos: getting root filesystem: %s; using C:", err)
// Assume that C: is the safe default.
return os.DirFS("C:")
}
return os.DirFS(filepath.VolumeName(sysDir))
}
func setRlimit(val uint64) (err error) {
return Unsupported("setrlimit")
}

View File

@@ -9,8 +9,13 @@ import (
"net/netip"
"net/url"
"testing"
"time"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/require"
)
@@ -71,3 +76,49 @@ func StartHTTPServer(t testing.TB, data []byte) (c *http.Client, u *url.URL) {
return srv.Client(), u
}
// testTimeout is a timeout for tests.
//
// TODO(e.burkov): Move into agdctest.
const testTimeout = 1 * time.Second
// StartLocalhostUpstream is a test helper that starts a DNS server on
// localhost.
func StartLocalhostUpstream(t *testing.T, h dns.Handler) (addr *url.URL) {
t.Helper()
startCh := make(chan netip.AddrPort)
defer close(startCh)
errCh := make(chan error)
srv := &dns.Server{
Addr: "127.0.0.1:0",
Net: string(proxy.ProtoTCP),
Handler: h,
ReadTimeout: testTimeout,
WriteTimeout: testTimeout,
}
srv.NotifyStartedFunc = func() {
addrPort := srv.Listener.Addr()
startCh <- netutil.NetAddrToAddrPort(addrPort)
}
go func() { errCh <- srv.ListenAndServe() }()
select {
case addrPort := <-startCh:
addr = &url.URL{
Scheme: string(proxy.ProtoTCP),
Host: addrPort.String(),
}
testutil.CleanupAndRequireSuccess(t, func() (err error) { return <-errCh })
testutil.CleanupAndRequireSuccess(t, srv.Shutdown)
case err := <-errCh:
require.NoError(t, err)
case <-time.After(testTimeout):
require.FailNow(t, "timeout exceeded")
}
return addr
}

View File

@@ -7,7 +7,6 @@ import (
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
"github.com/AdguardTeam/AdGuardHome/internal/rdns"
"github.com/AdguardTeam/AdGuardHome/internal/whois"
@@ -26,14 +25,25 @@ import (
// FSWatcher is a fake [aghos.FSWatcher] implementation for tests.
type FSWatcher struct {
OnStart func() (err error)
OnClose func() (err error)
OnEvents func() (e <-chan struct{})
OnAdd func(name string) (err error)
OnClose func() (err error)
}
// type check
var _ aghos.FSWatcher = (*FSWatcher)(nil)
// Start implements the [aghos.FSWatcher] interface for *FSWatcher.
func (w *FSWatcher) Start() (err error) {
return w.OnStart()
}
// Close implements the [aghos.FSWatcher] interface for *FSWatcher.
func (w *FSWatcher) Close() (err error) {
return w.OnClose()
}
// Events implements the [aghos.FSWatcher] interface for *FSWatcher.
func (w *FSWatcher) Events() (e <-chan struct{}) {
return w.OnEvents()
@@ -44,11 +54,6 @@ func (w *FSWatcher) Add(name string) (err error) {
return w.OnAdd(name)
}
// Close implements the [aghos.FSWatcher] interface for *FSWatcher.
func (w *FSWatcher) Close() (err error) {
return w.OnClose()
}
// Package agh
// ServiceWithConfig is a fake [agh.ServiceWithConfig] implementation for tests.
@@ -88,9 +93,6 @@ type AddressProcessor struct {
OnClose func() (err error)
}
// type check
var _ client.AddressProcessor = (*AddressProcessor)(nil)
// Process implements the [client.AddressProcessor] interface for
// *AddressProcessor.
func (p *AddressProcessor) Process(ip netip.Addr) {
@@ -108,9 +110,6 @@ type AddressUpdater struct {
OnUpdateAddress func(ip netip.Addr, host string, info *whois.Info)
}
// type check
var _ client.AddressUpdater = (*AddressUpdater)(nil)
// UpdateAddress implements the [client.AddressUpdater] interface for
// *AddressUpdater.
func (p *AddressUpdater) UpdateAddress(ip netip.Addr, host string, info *whois.Info) {

View File

@@ -2,6 +2,7 @@ package aghtest_test
import (
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
)
@@ -13,3 +14,13 @@ var _ filtering.Resolver = (*aghtest.Resolver)(nil)
// type check
var _ dnsforward.ClientsContainer = (*aghtest.ClientsContainer)(nil)
// type check
//
// TODO(s.chzhen): It's here to avoid the import cycle. Remove it.
var _ client.AddressProcessor = (*aghtest.AddressProcessor)(nil)
// type check
//
// TODO(s.chzhen): It's here to avoid the import cycle. Remove it.
var _ client.AddressUpdater = (*aghtest.AddressUpdater)(nil)

View File

@@ -7,13 +7,14 @@ import (
"fmt"
"net"
"net/netip"
"slices"
"sync"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"golang.org/x/exp/slices"
"github.com/AdguardTeam/golibs/osutil"
)
// Variables and functions to substitute in tests.
@@ -22,7 +23,7 @@ var (
aghosRunCommand = aghos.RunCommand
// rootDirFS is the filesystem pointing to the root directory.
rootDirFS = aghos.RootDirFS()
rootDirFS = osutil.RootDirFS()
)
// Interface stores and refreshes the network neighborhood reported by ARP

249
internal/client/index.go Normal file
View File

@@ -0,0 +1,249 @@
package client
import (
"fmt"
"net"
"net/netip"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
)
// macKey contains MAC as byte array of 6, 8, or 20 bytes.
type macKey any
// macToKey converts mac into key of type macKey, which is used as the key of
// the [clientIndex.macToUID]. mac must be valid MAC address.
func macToKey(mac net.HardwareAddr) (key macKey) {
switch len(mac) {
case 6:
return [6]byte(mac)
case 8:
return [8]byte(mac)
case 20:
return [20]byte(mac)
default:
panic(fmt.Errorf("invalid mac address %#v", mac))
}
}
// Index stores all information about persistent clients.
type Index struct {
// clientIDToUID maps client ID to UID.
clientIDToUID map[string]UID
// ipToUID maps IP address to UID.
ipToUID map[netip.Addr]UID
// macToUID maps MAC address to UID.
macToUID map[macKey]UID
// uidToClient maps UID to the persistent client.
uidToClient map[UID]*Persistent
// subnetToUID maps subnet to UID.
subnetToUID aghalg.SortedMap[netip.Prefix, UID]
}
// NewIndex initializes the new instance of client index.
func NewIndex() (ci *Index) {
return &Index{
clientIDToUID: map[string]UID{},
ipToUID: map[netip.Addr]UID{},
subnetToUID: aghalg.NewSortedMap[netip.Prefix, UID](subnetCompare),
macToUID: map[macKey]UID{},
uidToClient: map[UID]*Persistent{},
}
}
// Add stores information about a persistent client in the index. c must be
// non-nil and contain UID.
func (ci *Index) Add(c *Persistent) {
if (c.UID == UID{}) {
panic("client must contain uid")
}
for _, id := range c.ClientIDs {
ci.clientIDToUID[id] = c.UID
}
for _, ip := range c.IPs {
ci.ipToUID[ip] = c.UID
}
for _, pref := range c.Subnets {
ci.subnetToUID.Set(pref, c.UID)
}
for _, mac := range c.MACs {
k := macToKey(mac)
ci.macToUID[k] = c.UID
}
ci.uidToClient[c.UID] = c
}
// Clashes returns an error if the index contains a different persistent client
// with at least a single identifier contained by c. c must be non-nil.
func (ci *Index) Clashes(c *Persistent) (err error) {
for _, id := range c.ClientIDs {
existing, ok := ci.clientIDToUID[id]
if ok && existing != c.UID {
p := ci.uidToClient[existing]
return fmt.Errorf("another client %q uses the same ID %q", p.Name, id)
}
}
p, ip := ci.clashesIP(c)
if p != nil {
return fmt.Errorf("another client %q uses the same IP %q", p.Name, ip)
}
p, s := ci.clashesSubnet(c)
if p != nil {
return fmt.Errorf("another client %q uses the same subnet %q", p.Name, s)
}
p, mac := ci.clashesMAC(c)
if p != nil {
return fmt.Errorf("another client %q uses the same MAC %q", p.Name, mac)
}
return nil
}
// clashesIP returns a previous client with the same IP address as c. c must be
// non-nil.
func (ci *Index) clashesIP(c *Persistent) (p *Persistent, ip netip.Addr) {
for _, ip := range c.IPs {
existing, ok := ci.ipToUID[ip]
if ok && existing != c.UID {
return ci.uidToClient[existing], ip
}
}
return nil, netip.Addr{}
}
// clashesSubnet returns a previous client with the same subnet as c. c must be
// non-nil.
func (ci *Index) clashesSubnet(c *Persistent) (p *Persistent, s netip.Prefix) {
for _, s = range c.Subnets {
var existing UID
var ok bool
ci.subnetToUID.Range(func(p netip.Prefix, uid UID) (cont bool) {
if s == p {
existing = uid
ok = true
return false
}
return true
})
if ok && existing != c.UID {
return ci.uidToClient[existing], s
}
}
return nil, netip.Prefix{}
}
// clashesMAC returns a previous client with the same MAC address as c. c must
// be non-nil.
func (ci *Index) clashesMAC(c *Persistent) (p *Persistent, mac net.HardwareAddr) {
for _, mac = range c.MACs {
k := macToKey(mac)
existing, ok := ci.macToUID[k]
if ok && existing != c.UID {
return ci.uidToClient[existing], mac
}
}
return nil, nil
}
// Find finds persistent client by string representation of the client ID, IP
// address, or MAC.
func (ci *Index) Find(id string) (c *Persistent, ok bool) {
uid, found := ci.clientIDToUID[id]
if found {
return ci.uidToClient[uid], true
}
ip, err := netip.ParseAddr(id)
if err == nil {
// MAC addresses can be successfully parsed as IP addresses.
c, found = ci.findByIP(ip)
if found {
return c, true
}
}
mac, err := net.ParseMAC(id)
if err == nil {
return ci.findByMAC(mac)
}
return nil, false
}
// find finds persistent client by IP address.
func (ci *Index) findByIP(ip netip.Addr) (c *Persistent, found bool) {
uid, found := ci.ipToUID[ip]
if found {
return ci.uidToClient[uid], true
}
ci.subnetToUID.Range(func(pref netip.Prefix, id UID) (cont bool) {
if pref.Contains(ip) {
uid, found = id, true
return false
}
return true
})
if found {
return ci.uidToClient[uid], true
}
return nil, false
}
// find finds persistent client by MAC.
func (ci *Index) findByMAC(mac net.HardwareAddr) (c *Persistent, found bool) {
k := macToKey(mac)
uid, found := ci.macToUID[k]
if found {
return ci.uidToClient[uid], true
}
return nil, false
}
// Delete removes information about persistent client from the index. c must be
// non-nil.
func (ci *Index) Delete(c *Persistent) {
for _, id := range c.ClientIDs {
delete(ci.clientIDToUID, id)
}
for _, ip := range c.IPs {
delete(ci.ipToUID, ip)
}
for _, pref := range c.Subnets {
ci.subnetToUID.Del(pref)
}
for _, mac := range c.MACs {
k := macToKey(mac)
delete(ci.macToUID, k)
}
delete(ci.uidToClient, c.UID)
}

View File

@@ -0,0 +1,223 @@
package client
import (
"net"
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// newIDIndex is a helper function that returns a client index filled with
// persistent clients from the m. It also generates a UID for each client.
func newIDIndex(m []*Persistent) (ci *Index) {
ci = NewIndex()
for _, c := range m {
c.UID = MustNewUID()
ci.Add(c)
}
return ci
}
func TestClientIndex(t *testing.T) {
const (
cliIPNone = "1.2.3.4"
cliIP1 = "1.1.1.1"
cliIP2 = "2.2.2.2"
cliIPv6 = "1:2:3::4"
cliSubnet = "2.2.2.0/24"
cliSubnetIP = "2.2.2.222"
cliID = "client-id"
cliMAC = "11:11:11:11:11:11"
)
clients := []*Persistent{{
Name: "client1",
IPs: []netip.Addr{
netip.MustParseAddr(cliIP1),
netip.MustParseAddr(cliIPv6),
},
}, {
Name: "client2",
IPs: []netip.Addr{netip.MustParseAddr(cliIP2)},
Subnets: []netip.Prefix{netip.MustParsePrefix(cliSubnet)},
}, {
Name: "client_with_mac",
MACs: []net.HardwareAddr{mustParseMAC(cliMAC)},
}, {
Name: "client_with_id",
ClientIDs: []string{cliID},
}}
ci := newIDIndex(clients)
testCases := []struct {
want *Persistent
name string
ids []string
}{{
name: "ipv4_ipv6",
ids: []string{cliIP1, cliIPv6},
want: clients[0],
}, {
name: "ipv4_subnet",
ids: []string{cliIP2, cliSubnetIP},
want: clients[1],
}, {
name: "mac",
ids: []string{cliMAC},
want: clients[2],
}, {
name: "client_id",
ids: []string{cliID},
want: clients[3],
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
for _, id := range tc.ids {
c, ok := ci.Find(id)
require.True(t, ok)
assert.Equal(t, tc.want, c)
}
})
}
t.Run("not_found", func(t *testing.T) {
_, ok := ci.Find(cliIPNone)
assert.False(t, ok)
})
}
func TestClientIndex_Clashes(t *testing.T) {
const (
cliIP1 = "1.1.1.1"
cliSubnet = "2.2.2.0/24"
cliSubnetIP = "2.2.2.222"
cliID = "client-id"
cliMAC = "11:11:11:11:11:11"
)
clients := []*Persistent{{
Name: "client_with_ip",
IPs: []netip.Addr{netip.MustParseAddr(cliIP1)},
}, {
Name: "client_with_subnet",
Subnets: []netip.Prefix{netip.MustParsePrefix(cliSubnet)},
}, {
Name: "client_with_mac",
MACs: []net.HardwareAddr{mustParseMAC(cliMAC)},
}, {
Name: "client_with_id",
ClientIDs: []string{cliID},
}}
ci := newIDIndex(clients)
testCases := []struct {
client *Persistent
name string
}{{
name: "ipv4",
client: clients[0],
}, {
name: "subnet",
client: clients[1],
}, {
name: "mac",
client: clients[2],
}, {
name: "client_id",
client: clients[3],
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
clone := tc.client.ShallowClone()
clone.UID = MustNewUID()
err := ci.Clashes(clone)
require.Error(t, err)
ci.Delete(tc.client)
err = ci.Clashes(clone)
require.NoError(t, err)
})
}
}
// mustParseMAC is wrapper around [net.ParseMAC] that panics if there is an
// error.
func mustParseMAC(s string) (mac net.HardwareAddr) {
mac, err := net.ParseMAC(s)
if err != nil {
panic(err)
}
return mac
}
func TestMACToKey(t *testing.T) {
testCases := []struct {
want any
name string
in string
}{{
name: "column6",
in: "00:00:5e:00:53:01",
want: [6]byte(mustParseMAC("00:00:5e:00:53:01")),
}, {
name: "column8",
in: "02:00:5e:10:00:00:00:01",
want: [8]byte(mustParseMAC("02:00:5e:10:00:00:00:01")),
}, {
name: "column20",
in: "00:00:00:00:fe:80:00:00:00:00:00:00:02:00:5e:10:00:00:00:01",
want: [20]byte(mustParseMAC("00:00:00:00:fe:80:00:00:00:00:00:00:02:00:5e:10:00:00:00:01")),
}, {
name: "hyphen6",
in: "00-00-5e-00-53-01",
want: [6]byte(mustParseMAC("00-00-5e-00-53-01")),
}, {
name: "hyphen8",
in: "02-00-5e-10-00-00-00-01",
want: [8]byte(mustParseMAC("02-00-5e-10-00-00-00-01")),
}, {
name: "hyphen20",
in: "00-00-00-00-fe-80-00-00-00-00-00-00-02-00-5e-10-00-00-00-01",
want: [20]byte(mustParseMAC("00-00-00-00-fe-80-00-00-00-00-00-00-02-00-5e-10-00-00-00-01")),
}, {
name: "dot6",
in: "0000.5e00.5301",
want: [6]byte(mustParseMAC("0000.5e00.5301")),
}, {
name: "dot8",
in: "0200.5e10.0000.0001",
want: [8]byte(mustParseMAC("0200.5e10.0000.0001")),
}, {
name: "dot20",
in: "0000.0000.fe80.0000.0000.0000.0200.5e10.0000.0001",
want: [20]byte(mustParseMAC("0000.0000.fe80.0000.0000.0000.0200.5e10.0000.0001")),
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
mac := mustParseMAC(tc.in)
key := macToKey(mac)
assert.Equal(t, tc.want, key)
})
}
assert.Panics(t, func() {
mac := net.HardwareAddr([]byte{1, 2, 3})
_ = macToKey(mac)
})
}

View File

@@ -1,22 +1,22 @@
package home
package client
import (
"encoding"
"fmt"
"net"
"net/netip"
"slices"
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/google/uuid"
"golang.org/x/exp/slices"
)
// UID is the type for the unique IDs of persistent clients.
@@ -30,6 +30,16 @@ func NewUID() (uid UID, err error) {
return UID(uuidv7), err
}
// MustNewUID is a wrapper around [NewUID] that panics if there is an error.
func MustNewUID() (uid UID) {
uid, err := NewUID()
if err != nil {
panic(fmt.Errorf("unexpected uuidv7 error: %w", err))
}
return uid
}
// type check
var _ encoding.TextMarshaler = UID{}
@@ -46,16 +56,16 @@ func (uid *UID) UnmarshalText(data []byte) error {
return (*uuid.UUID)(uid).UnmarshalText(data)
}
// persistentClient contains information about persistent clients.
type persistentClient struct {
// upstreamConfig is the custom upstream configuration for this client. If
// Persistent contains information about persistent clients.
type Persistent struct {
// UpstreamConfig is the custom upstream configuration for this client. If
// it's nil, it has not been initialized yet. If it's non-nil and empty,
// there are no valid upstreams. If it's non-nil and non-empty, these
// upstream must be used.
upstreamConfig *proxy.CustomUpstreamConfig
UpstreamConfig *proxy.CustomUpstreamConfig
// TODO(d.kolyshev): Make safeSearchConf a pointer.
safeSearchConf filtering.SafeSearchConfig
// TODO(d.kolyshev): Make SafeSearchConf a pointer.
SafeSearchConf filtering.SafeSearchConfig
SafeSearch filtering.SafeSearch
// BlockedServices is the configuration of blocked services of a client.
@@ -87,8 +97,8 @@ type persistentClient struct {
IgnoreStatistics bool
}
// setTags sets the tags if they are known, otherwise logs an unknown tag.
func (c *persistentClient) setTags(tags []string, known *stringutil.Set) {
// SetTags sets the tags if they are known, otherwise logs an unknown tag.
func (c *Persistent) SetTags(tags []string, known *stringutil.Set) {
for _, t := range tags {
if !known.Has(t) {
log.Info("skipping unknown tag %q", t)
@@ -102,9 +112,9 @@ func (c *persistentClient) setTags(tags []string, known *stringutil.Set) {
slices.Sort(c.Tags)
}
// setIDs parses a list of strings into typed fields and returns an error if
// SetIDs parses a list of strings into typed fields and returns an error if
// there is one.
func (c *persistentClient) setIDs(ids []string) (err error) {
func (c *Persistent) SetIDs(ids []string) (err error) {
for _, id := range ids {
err = c.setID(id)
if err != nil {
@@ -144,7 +154,7 @@ func subnetCompare(x, y netip.Prefix) (cmp int) {
}
// setID parses id into typed field if there is no error.
func (c *persistentClient) setID(id string) (err error) {
func (c *Persistent) setID(id string) (err error) {
if id == "" {
return errors.Error("clientid is empty")
}
@@ -170,7 +180,7 @@ func (c *persistentClient) setID(id string) (err error) {
return nil
}
err = dnsforward.ValidateClientID(id)
err = ValidateClientID(id)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
@@ -181,9 +191,23 @@ func (c *persistentClient) setID(id string) (err error) {
return nil
}
// ids returns a list of client ids containing at least one element.
func (c *persistentClient) ids() (ids []string) {
ids = make([]string, 0, c.idsLen())
// ValidateClientID returns an error if id is not a valid ClientID.
//
// TODO(s.chzhen): It's an exact copy of the [dnsforward.ValidateClientID] to
// avoid the import cycle. Remove it.
func ValidateClientID(id string) (err error) {
err = netutil.ValidateHostnameLabel(id)
if err != nil {
// Replace the domain name label wrapper with our own.
return fmt.Errorf("invalid clientid %q: %w", id, errors.Unwrap(err))
}
return nil
}
// IDs returns a list of client IDs containing at least one element.
func (c *Persistent) IDs() (ids []string) {
ids = make([]string, 0, c.IDsLen())
for _, ip := range c.IPs {
ids = append(ids, ip.String())
@@ -200,24 +224,24 @@ func (c *persistentClient) ids() (ids []string) {
return append(ids, c.ClientIDs...)
}
// idsLen returns a length of client ids.
func (c *persistentClient) idsLen() (n int) {
// IDsLen returns a length of client ids.
func (c *Persistent) IDsLen() (n int) {
return len(c.IPs) + len(c.Subnets) + len(c.MACs) + len(c.ClientIDs)
}
// equalIDs returns true if the ids of the current and previous clients are the
// EqualIDs returns true if the ids of the current and previous clients are the
// same.
func (c *persistentClient) equalIDs(prev *persistentClient) (equal bool) {
func (c *Persistent) EqualIDs(prev *Persistent) (equal bool) {
return slices.Equal(c.IPs, prev.IPs) &&
slices.Equal(c.Subnets, prev.Subnets) &&
slices.EqualFunc(c.MACs, prev.MACs, slices.Equal[net.HardwareAddr]) &&
slices.Equal(c.ClientIDs, prev.ClientIDs)
}
// shallowClone returns a deep copy of the client, except upstreamConfig,
// ShallowClone returns a deep copy of the client, except upstreamConfig,
// safeSearchConf, SafeSearch fields, because it's difficult to copy them.
func (c *persistentClient) shallowClone() (clone *persistentClient) {
clone = &persistentClient{}
func (c *Persistent) ShallowClone() (clone *Persistent) {
clone = &Persistent{}
*clone = *c
clone.BlockedServices = c.BlockedServices.Clone()
@@ -232,10 +256,10 @@ func (c *persistentClient) shallowClone() (clone *persistentClient) {
return clone
}
// closeUpstreams closes the client-specific upstream config of c if any.
func (c *persistentClient) closeUpstreams() (err error) {
if c.upstreamConfig != nil {
if err = c.upstreamConfig.Close(); err != nil {
// CloseUpstreams closes the client-specific upstream config of c if any.
func (c *Persistent) CloseUpstreams() (err error) {
if c.UpstreamConfig != nil {
if err = c.UpstreamConfig.Close(); err != nil {
return fmt.Errorf("closing upstreams of client %q: %w", c.Name, err)
}
}
@@ -243,8 +267,8 @@ func (c *persistentClient) closeUpstreams() (err error) {
return nil
}
// setSafeSearch initializes and sets the safe search filter for this client.
func (c *persistentClient) setSafeSearch(
// SetSafeSearch initializes and sets the safe search filter for this client.
func (c *Persistent) SetSafeSearch(
conf filtering.SafeSearchConfig,
cacheSize uint,
cacheTTL time.Duration,

View File

@@ -1,4 +1,4 @@
package home
package client
import (
"testing"
@@ -27,10 +27,10 @@ func TestPersistentClient_EqualIDs(t *testing.T) {
)
testCases := []struct {
want assert.BoolAssertionFunc
name string
ids []string
prevIDs []string
want assert.BoolAssertionFunc
}{{
name: "single_ip",
ids: []string{ip1},
@@ -110,15 +110,15 @@ func TestPersistentClient_EqualIDs(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
c := &persistentClient{}
err := c.setIDs(tc.ids)
c := &Persistent{}
err := c.SetIDs(tc.ids)
require.NoError(t, err)
prev := &persistentClient{}
err = prev.setIDs(tc.prevIDs)
prev := &Persistent{}
err = prev.SetIDs(tc.prevIDs)
require.NoError(t, err)
tc.want(t, c.equalIDs(prev))
tc.want(t, c.EqualIDs(prev))
})
}
}

View File

@@ -8,6 +8,7 @@ import (
"net"
"net/netip"
"os"
"slices"
"strings"
"time"
@@ -15,7 +16,6 @@ import (
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/google/renameio/v2/maybe"
"golang.org/x/exp/slices"
)
const (

View File

@@ -6,6 +6,7 @@ import (
"net"
"net/netip"
"path/filepath"
"slices"
"testing"
"time"
@@ -13,7 +14,6 @@ import (
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/exp/slices"
)
func TestMain(m *testing.M) {

View File

@@ -10,6 +10,7 @@ import (
"net/http"
"net/netip"
"os"
"slices"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
@@ -19,7 +20,6 @@ import (
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"golang.org/x/exp/slices"
)
type v4ServerConfJSON struct {
@@ -592,7 +592,7 @@ func setOtherDHCPResult(ifaceName string, result *dhcpSearchResult) {
}
// parseLease parses a lease from r. If there is no error returns DHCPServer
// and *Lease. r must be non-nil.
// and *Lease. r must be non-nil.
func (s *server) parseLease(r io.Reader) (srv DHCPServer, lease *dhcpsvc.Lease, err error) {
l := &leaseStatic{}
err = json.NewDecoder(r).Decode(l)

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"net"
"net/netip"
"slices"
"strings"
"sync"
"time"
@@ -20,7 +21,6 @@ import (
"github.com/go-ping/ping"
"github.com/insomniacslk/dhcp/dhcpv4"
"github.com/insomniacslk/dhcp/dhcpv4/server4"
"golang.org/x/exp/slices"
)
// v4Server is a DHCPv4 server.

View File

@@ -2,11 +2,11 @@ package dhcpsvc
import (
"fmt"
"slices"
"time"
"github.com/AdguardTeam/golibs/netutil"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
)
// Config is the configuration for the DHCP service.
@@ -19,6 +19,8 @@ type Config struct {
// clients' hostnames.
LocalDomainName string
// TODO(e.burkov): Add DB path.
// ICMPTimeout is the timeout for checking another DHCP server's presence.
ICMPTimeout time.Duration
@@ -68,12 +70,6 @@ func (conf *Config) Validate() (err error) {
return nil
}
// newMustErr returns an error that indicates that valName must be as must
// describes.
func newMustErr(valName, must string, val fmt.Stringer) (err error) {
return fmt.Errorf("%s %s must %s", valName, val, must)
}
// validate returns an error in ic, if any.
func (ic *InterfaceConfig) validate() (err error) {
if ic == nil {

View File

@@ -7,48 +7,16 @@ import (
"context"
"net"
"net/netip"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
"golang.org/x/exp/slices"
)
// Lease is a DHCP lease.
// Interface is a DHCP service.
//
// TODO(e.burkov): Consider moving it to [agh], since it also may be needed in
// [websvc].
type Lease struct {
// IP is the IP address leased to the client.
IP netip.Addr
// Expiry is the expiration time of the lease.
Expiry time.Time
// Hostname of the client.
Hostname string
// HWAddr is the physical hardware address (MAC address).
HWAddr net.HardwareAddr
// IsStatic defines if the lease is static.
IsStatic bool
}
// Clone returns a deep copy of l.
func (l *Lease) Clone() (clone *Lease) {
if l == nil {
return nil
}
return &Lease{
Expiry: l.Expiry,
Hostname: l.Hostname,
HWAddr: slices.Clone(l.HWAddr),
IP: l.IP,
IsStatic: l.IsStatic,
}
}
// TODO(e.burkov): Separate HostByIP, MACByIP, IPByHost into a separate
// interface. This is also applicable to Enabled method.
//
// TODO(e.burkov): Reconsider the requirements for the leases validity.
type Interface interface {
agh.ServiceWithConfig[*Config]
@@ -63,6 +31,8 @@ type Interface interface {
// MACByIP returns the MAC address for the given IP address leased. It
// returns nil if there is no such client, due to an assumption that a DHCP
// client must always have a MAC address.
//
// TODO(e.burkov): Think of a contract for the returned value.
MACByIP(ip netip.Addr) (mac net.HardwareAddr)
// IPByHost returns the IP address of the DHCP client with the given
@@ -71,26 +41,29 @@ type Interface interface {
// hostname, either set or generated.
IPByHost(host string) (ip netip.Addr)
// Leases returns all the active DHCP leases.
// Leases returns all the active DHCP leases. The returned slice should be
// a clone.
//
// TODO(e.burkov): Consider implementing iterating methods with appropriate
// signatures instead of cloning the whole list.
Leases() (ls []*Lease)
// AddLease adds a new DHCP lease. It returns an error if the lease is
// invalid or already exists.
// AddLease adds a new DHCP lease. l must be valid. It returns an error if
// l already exists.
AddLease(l *Lease) (err error)
// UpdateStaticLease changes an existing DHCP lease. It returns an error if
// there is no lease with such hardware addressor if new values are invalid
// or already exist.
// UpdateStaticLease replaces an existing static DHCP lease. l must be
// valid. It returns an error if the lease with the given hardware address
// doesn't exist or if other values match another existing lease.
UpdateStaticLease(l *Lease) (err error)
// RemoveLease removes an existing DHCP lease. It returns an error if there
// is no lease equal to l.
// RemoveLease removes an existing DHCP lease. l must be valid. It returns
// an error if there is no lease equal to l.
RemoveLease(l *Lease) (err error)
// Reset removes all the DHCP leases.
//
// TODO(e.burkov): If it's really needed?
Reset() (err error)
}

View File

@@ -1,6 +1,10 @@
package dhcpsvc
import "github.com/AdguardTeam/golibs/errors"
import (
"fmt"
"github.com/AdguardTeam/golibs/errors"
)
const (
// errNilConfig is returned when a nil config met.
@@ -9,3 +13,9 @@ const (
// errNoInterfaces is returned when no interfaces found in configuration.
errNoInterfaces errors.Error = "no interfaces specified"
)
// newMustErr returns an error that indicates that valName must be as must
// describes.
func newMustErr(valName, must string, val fmt.Stringer) (err error) {
return fmt.Errorf("%s %s must %s", valName, val, must)
}

View File

@@ -0,0 +1,66 @@
package dhcpsvc
import (
"fmt"
"slices"
"time"
)
// netInterface is a common part of any network interface within the DHCP
// server.
//
// TODO(e.burkov): Add other methods as [DHCPServer] evolves.
type netInterface struct {
// name is the name of the network interface.
name string
// leases is a set of leases sorted by hardware address.
leases []*Lease
// leaseTTL is the default Time-To-Live value for leases.
leaseTTL time.Duration
}
// reset clears all the slices in iface for reuse.
func (iface *netInterface) reset() {
iface.leases = iface.leases[:0]
}
// insertLease inserts the given lease into iface. It returns an error if the
// lease can't be inserted.
func (iface *netInterface) insertLease(l *Lease) (err error) {
i, found := slices.BinarySearchFunc(iface.leases, l, compareLeaseMAC)
if found {
return fmt.Errorf("lease for mac %s already exists", l.HWAddr)
}
iface.leases = slices.Insert(iface.leases, i, l)
return nil
}
// updateLease replaces an existing lease within iface with the given one. It
// returns an error if there is no lease with such hardware address.
func (iface *netInterface) updateLease(l *Lease) (prev *Lease, err error) {
i, found := slices.BinarySearchFunc(iface.leases, l, compareLeaseMAC)
if !found {
return nil, fmt.Errorf("no lease for mac %s", l.HWAddr)
}
prev, iface.leases[i] = iface.leases[i], l
return prev, nil
}
// removeLease removes an existing lease from iface. It returns an error if
// there is no lease equal to l.
func (iface *netInterface) removeLease(l *Lease) (err error) {
i, found := slices.BinarySearchFunc(iface.leases, l, compareLeaseMAC)
if !found {
return fmt.Errorf("no lease for mac %s", l.HWAddr)
}
iface.leases = slices.Delete(iface.leases, i, i+1)
return nil
}

52
internal/dhcpsvc/lease.go Normal file
View File

@@ -0,0 +1,52 @@
package dhcpsvc
import (
"bytes"
"net"
"net/netip"
"slices"
"time"
)
// Lease is a DHCP lease.
//
// TODO(e.burkov): Consider moving it to [agh], since it also may be needed in
// [websvc].
//
// TODO(e.burkov): Add validation method.
type Lease struct {
// IP is the IP address leased to the client.
IP netip.Addr
// Expiry is the expiration time of the lease.
Expiry time.Time
// Hostname of the client.
Hostname string
// HWAddr is the physical hardware address (MAC address).
HWAddr net.HardwareAddr
// IsStatic defines if the lease is static.
IsStatic bool
}
// Clone returns a deep copy of l.
func (l *Lease) Clone() (clone *Lease) {
if l == nil {
return nil
}
return &Lease{
Expiry: l.Expiry,
Hostname: l.Hostname,
HWAddr: slices.Clone(l.HWAddr),
IP: l.IP,
IsStatic: l.IsStatic,
}
}
// compareLeaseMAC compares two [Lease]s by hardware address.
func compareLeaseMAC(a, b *Lease) (res int) {
return bytes.Compare(a.HWAddr, b.HWAddr)
}

View File

@@ -0,0 +1,126 @@
package dhcpsvc
import (
"fmt"
"net/netip"
"slices"
"strings"
)
// leaseIndex is the set of leases indexed by their identifiers for quick
// lookup.
type leaseIndex struct {
// byAddr is a lookup shortcut for leases by their IP addresses.
byAddr map[netip.Addr]*Lease
// byName is a lookup shortcut for leases by their hostnames.
//
// TODO(e.burkov): Use a slice of leases with the same hostname?
byName map[string]*Lease
}
// newLeaseIndex returns a new index for [Lease]s.
func newLeaseIndex() *leaseIndex {
return &leaseIndex{
byAddr: map[netip.Addr]*Lease{},
byName: map[string]*Lease{},
}
}
// leaseByAddr returns a lease by its IP address.
func (idx *leaseIndex) leaseByAddr(addr netip.Addr) (l *Lease, ok bool) {
l, ok = idx.byAddr[addr]
return l, ok
}
// leaseByName returns a lease by its hostname.
func (idx *leaseIndex) leaseByName(name string) (l *Lease, ok bool) {
// TODO(e.burkov): Probably, use a case-insensitive comparison and store in
// slice. This would require a benchmark.
l, ok = idx.byName[strings.ToLower(name)]
return l, ok
}
// clear removes all leases from idx.
func (idx *leaseIndex) clear() {
clear(idx.byAddr)
clear(idx.byName)
}
// add adds l into idx and into iface. l must be valid, iface should be
// responsible for l's IP. It returns an error if l duplicates at least a
// single value of another lease.
func (idx *leaseIndex) add(l *Lease, iface *netInterface) (err error) {
loweredName := strings.ToLower(l.Hostname)
if _, ok := idx.byAddr[l.IP]; ok {
return fmt.Errorf("lease for ip %s already exists", l.IP)
} else if _, ok = idx.byName[loweredName]; ok {
return fmt.Errorf("lease for hostname %s already exists", l.Hostname)
}
err = iface.insertLease(l)
if err != nil {
return err
}
idx.byAddr[l.IP] = l
idx.byName[loweredName] = l
return nil
}
// remove removes l from idx and from iface. l must be valid, iface should
// contain the same lease or the lease itself. It returns an error if the lease
// not found.
func (idx *leaseIndex) remove(l *Lease, iface *netInterface) (err error) {
loweredName := strings.ToLower(l.Hostname)
if _, ok := idx.byAddr[l.IP]; !ok {
return fmt.Errorf("no lease for ip %s", l.IP)
} else if _, ok = idx.byName[loweredName]; !ok {
return fmt.Errorf("no lease for hostname %s", l.Hostname)
}
err = iface.removeLease(l)
if err != nil {
return err
}
delete(idx.byAddr, l.IP)
delete(idx.byName, loweredName)
return nil
}
// update updates l in idx and in iface. l must be valid, iface should be
// responsible for l's IP. It returns an error if l duplicates at least a
// single value of another lease, except for the updated lease itself.
func (idx *leaseIndex) update(l *Lease, iface *netInterface) (err error) {
loweredName := strings.ToLower(l.Hostname)
existing, ok := idx.byAddr[l.IP]
if ok && !slices.Equal(l.HWAddr, existing.HWAddr) {
return fmt.Errorf("lease for ip %s already exists", l.IP)
}
existing, ok = idx.byName[loweredName]
if ok && !slices.Equal(l.HWAddr, existing.HWAddr) {
return fmt.Errorf("lease for hostname %s already exists", l.Hostname)
}
prev, err := iface.updateLease(l)
if err != nil {
return err
}
delete(idx.byAddr, prev.IP)
delete(idx.byName, strings.ToLower(prev.Hostname))
idx.byAddr[l.IP] = l
idx.byName[loweredName] = l
return nil
}

View File

@@ -2,11 +2,15 @@ package dhcpsvc
import (
"fmt"
"net"
"net/netip"
"slices"
"sync"
"sync/atomic"
"time"
"github.com/AdguardTeam/golibs/errors"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
)
// DHCPServer is a DHCP server for both IPv4 and IPv6 address families.
@@ -15,18 +19,21 @@ type DHCPServer struct {
// information about its clients.
enabled *atomic.Bool
// localTLD is the top-level domain name to use for resolving DHCP
// clients' hostnames.
// localTLD is the top-level domain name to use for resolving DHCP clients'
// hostnames.
localTLD string
// leasesMu protects the leases index as well as leases in the interfaces.
leasesMu *sync.RWMutex
// leases stores the DHCP leases for quick lookups.
leases *leaseIndex
// interfaces4 is the set of IPv4 interfaces sorted by interface name.
interfaces4 []*iface4
interfaces4 netInterfacesV4
// interfaces6 is the set of IPv6 interfaces sorted by interface name.
interfaces6 []*iface6
// leases is the set of active DHCP leases.
leases []*Lease
interfaces6 netInterfacesV6
// icmpTimeout is the timeout for checking another DHCP server's presence.
icmpTimeout time.Duration
@@ -42,26 +49,27 @@ func New(conf *Config) (srv *DHCPServer, err error) {
return nil, nil
}
ifaces4 := make([]*iface4, len(conf.Interfaces))
ifaces6 := make([]*iface6, len(conf.Interfaces))
// TODO(e.burkov): Add validations scoped to the network interfaces set.
ifaces4 := make(netInterfacesV4, 0, len(conf.Interfaces))
ifaces6 := make(netInterfacesV6, 0, len(conf.Interfaces))
ifaceNames := maps.Keys(conf.Interfaces)
slices.Sort(ifaceNames)
var i4 *iface4
var i6 *iface6
var i4 *netInterfaceV4
var i6 *netInterfaceV6
for _, ifaceName := range ifaceNames {
iface := conf.Interfaces[ifaceName]
i4, err = newIface4(ifaceName, iface.IPv4)
i4, err = newNetInterfaceV4(ifaceName, iface.IPv4)
if err != nil {
return nil, fmt.Errorf("interface %q: ipv4: %w", ifaceName, err)
} else if i4 != nil {
ifaces4 = append(ifaces4, i4)
}
i6 = newIface6(ifaceName, iface.IPv6)
i6 = newNetInterfaceV6(ifaceName, iface.IPv6)
if i6 != nil {
ifaces6 = append(ifaces6, i6)
}
@@ -70,13 +78,19 @@ func New(conf *Config) (srv *DHCPServer, err error) {
enabled := &atomic.Bool{}
enabled.Store(conf.Enabled)
return &DHCPServer{
srv = &DHCPServer{
enabled: enabled,
localTLD: conf.LocalDomainName,
leasesMu: &sync.RWMutex{},
leases: newLeaseIndex(),
interfaces4: ifaces4,
interfaces6: ifaces6,
localTLD: conf.LocalDomainName,
icmpTimeout: conf.ICMPTimeout,
}, nil
}
// TODO(e.burkov): Load leases.
return srv, nil
}
// type check
@@ -91,10 +105,140 @@ func (srv *DHCPServer) Enabled() (ok bool) {
// Leases implements the [Interface] interface for *DHCPServer.
func (srv *DHCPServer) Leases() (leases []*Lease) {
leases = make([]*Lease, 0, len(srv.leases))
for _, lease := range srv.leases {
leases = append(leases, lease.Clone())
srv.leasesMu.RLock()
defer srv.leasesMu.RUnlock()
for _, iface := range srv.interfaces4 {
for _, lease := range iface.leases {
leases = append(leases, lease.Clone())
}
}
for _, iface := range srv.interfaces6 {
for _, lease := range iface.leases {
leases = append(leases, lease.Clone())
}
}
return leases
}
// HostByIP implements the [Interface] interface for *DHCPServer.
func (srv *DHCPServer) HostByIP(ip netip.Addr) (host string) {
srv.leasesMu.RLock()
defer srv.leasesMu.RUnlock()
if l, ok := srv.leases.leaseByAddr(ip); ok {
return l.Hostname
}
return ""
}
// MACByIP implements the [Interface] interface for *DHCPServer.
func (srv *DHCPServer) MACByIP(ip netip.Addr) (mac net.HardwareAddr) {
srv.leasesMu.RLock()
defer srv.leasesMu.RUnlock()
if l, ok := srv.leases.leaseByAddr(ip); ok {
return l.HWAddr
}
return nil
}
// IPByHost implements the [Interface] interface for *DHCPServer.
func (srv *DHCPServer) IPByHost(host string) (ip netip.Addr) {
srv.leasesMu.RLock()
defer srv.leasesMu.RUnlock()
if l, ok := srv.leases.leaseByName(host); ok {
return l.IP
}
return netip.Addr{}
}
// Reset implements the [Interface] interface for *DHCPServer.
func (srv *DHCPServer) Reset() (err error) {
srv.leasesMu.Lock()
defer srv.leasesMu.Unlock()
for _, iface := range srv.interfaces4 {
iface.reset()
}
for _, iface := range srv.interfaces6 {
iface.reset()
}
srv.leases.clear()
return nil
}
// AddLease implements the [Interface] interface for *DHCPServer.
func (srv *DHCPServer) AddLease(l *Lease) (err error) {
defer func() { err = errors.Annotate(err, "adding lease: %w") }()
addr := l.IP
iface, err := srv.ifaceForAddr(addr)
if err != nil {
// Don't wrap the error since there is already an annotation deferred.
return err
}
srv.leasesMu.Lock()
defer srv.leasesMu.Unlock()
return srv.leases.add(l, iface)
}
// UpdateStaticLease implements the [Interface] interface for *DHCPServer.
//
// TODO(e.burkov): Support moving leases between interfaces.
func (srv *DHCPServer) UpdateStaticLease(l *Lease) (err error) {
defer func() { err = errors.Annotate(err, "updating static lease: %w") }()
addr := l.IP
iface, err := srv.ifaceForAddr(addr)
if err != nil {
// Don't wrap the error since there is already an annotation deferred.
return err
}
srv.leasesMu.Lock()
defer srv.leasesMu.Unlock()
return srv.leases.update(l, iface)
}
// RemoveLease implements the [Interface] interface for *DHCPServer.
func (srv *DHCPServer) RemoveLease(l *Lease) (err error) {
defer func() { err = errors.Annotate(err, "removing lease: %w") }()
addr := l.IP
iface, err := srv.ifaceForAddr(addr)
if err != nil {
// Don't wrap the error since there is already an annotation deferred.
return err
}
srv.leasesMu.Lock()
defer srv.leasesMu.Unlock()
return srv.leases.remove(l, iface)
}
// ifaceForAddr returns the handled network interface for the given IP address,
// or an error if no such interface exists.
func (srv *DHCPServer) ifaceForAddr(addr netip.Addr) (iface *netInterface, err error) {
var ok bool
if addr.Is4() {
iface, ok = srv.interfaces4.find(addr)
} else {
iface, ok = srv.interfaces6.find(addr)
}
if !ok {
return nil, fmt.Errorf("no interface for ip %s", addr)
}
return iface, nil
}

View File

@@ -1,17 +1,67 @@
package dhcpsvc_test
import (
"net"
"net/netip"
"strings"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// testLocalTLD is a common local TLD for tests.
const testLocalTLD = "local"
// testInterfaceConf is a common set of interface configurations for tests.
var testInterfaceConf = map[string]*dhcpsvc.InterfaceConfig{
"eth0": {
IPv4: &dhcpsvc.IPv4Config{
Enabled: true,
GatewayIP: netip.MustParseAddr("192.168.0.1"),
SubnetMask: netip.MustParseAddr("255.255.255.0"),
RangeStart: netip.MustParseAddr("192.168.0.2"),
RangeEnd: netip.MustParseAddr("192.168.0.254"),
LeaseDuration: 1 * time.Hour,
},
IPv6: &dhcpsvc.IPv6Config{
Enabled: true,
RangeStart: netip.MustParseAddr("2001:db8::1"),
LeaseDuration: 1 * time.Hour,
RAAllowSLAAC: true,
RASLAACOnly: true,
},
},
"eth1": {
IPv4: &dhcpsvc.IPv4Config{
Enabled: true,
GatewayIP: netip.MustParseAddr("172.16.0.1"),
SubnetMask: netip.MustParseAddr("255.255.255.0"),
RangeStart: netip.MustParseAddr("172.16.0.2"),
RangeEnd: netip.MustParseAddr("172.16.0.255"),
LeaseDuration: 1 * time.Hour,
},
IPv6: &dhcpsvc.IPv6Config{
Enabled: true,
RangeStart: netip.MustParseAddr("2001:db9::1"),
LeaseDuration: 1 * time.Hour,
RAAllowSLAAC: true,
RASLAACOnly: true,
},
},
}
// mustParseMAC parses a hardware address from s and requires no errors.
func mustParseMAC(t require.TestingT, s string) (mac net.HardwareAddr) {
mac, err := net.ParseMAC(s)
require.NoError(t, err)
return mac
}
func TestNew(t *testing.T) {
validIPv4Conf := &dhcpsvc.IPv4Config{
Enabled: true,
@@ -113,3 +163,433 @@ func TestNew(t *testing.T) {
})
}
}
func TestDHCPServer_AddLease(t *testing.T) {
srv, err := dhcpsvc.New(&dhcpsvc.Config{
Enabled: true,
LocalDomainName: testLocalTLD,
Interfaces: testInterfaceConf,
})
require.NoError(t, err)
const (
host1 = "host1"
host2 = "host2"
host3 = "host3"
)
ip1 := netip.MustParseAddr("192.168.0.2")
ip2 := netip.MustParseAddr("192.168.0.3")
ip3 := netip.MustParseAddr("2001:db8::2")
mac1 := mustParseMAC(t, "01:02:03:04:05:06")
mac2 := mustParseMAC(t, "06:05:04:03:02:01")
mac3 := mustParseMAC(t, "02:03:04:05:06:07")
require.NoError(t, srv.AddLease(&dhcpsvc.Lease{
Hostname: host1,
IP: ip1,
HWAddr: mac1,
IsStatic: true,
}))
testCases := []struct {
name string
lease *dhcpsvc.Lease
wantErrMsg string
}{{
name: "outside_range",
lease: &dhcpsvc.Lease{
Hostname: host2,
IP: netip.MustParseAddr("1.2.3.4"),
HWAddr: mac2,
},
wantErrMsg: "adding lease: no interface for ip 1.2.3.4",
}, {
name: "duplicate_ip",
lease: &dhcpsvc.Lease{
Hostname: host2,
IP: ip1,
HWAddr: mac2,
},
wantErrMsg: "adding lease: lease for ip " + ip1.String() +
" already exists",
}, {
name: "duplicate_hostname",
lease: &dhcpsvc.Lease{
Hostname: host1,
IP: ip2,
HWAddr: mac2,
},
wantErrMsg: "adding lease: lease for hostname " + host1 +
" already exists",
}, {
name: "duplicate_hostname_case",
lease: &dhcpsvc.Lease{
Hostname: strings.ToUpper(host1),
IP: ip2,
HWAddr: mac2,
},
wantErrMsg: "adding lease: lease for hostname " +
strings.ToUpper(host1) + " already exists",
}, {
name: "duplicate_mac",
lease: &dhcpsvc.Lease{
Hostname: host2,
IP: ip2,
HWAddr: mac1,
},
wantErrMsg: "adding lease: lease for mac " + mac1.String() +
" already exists",
}, {
name: "valid",
lease: &dhcpsvc.Lease{
Hostname: host2,
IP: ip2,
HWAddr: mac2,
},
wantErrMsg: "",
}, {
name: "valid_v6",
lease: &dhcpsvc.Lease{
Hostname: host3,
IP: ip3,
HWAddr: mac3,
},
wantErrMsg: "",
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
testutil.AssertErrorMsg(t, tc.wantErrMsg, srv.AddLease(tc.lease))
})
}
}
func TestDHCPServer_index(t *testing.T) {
srv, err := dhcpsvc.New(&dhcpsvc.Config{
Enabled: true,
LocalDomainName: testLocalTLD,
Interfaces: testInterfaceConf,
})
require.NoError(t, err)
const (
host1 = "host1"
host2 = "host2"
host3 = "host3"
host4 = "host4"
host5 = "host5"
)
ip1 := netip.MustParseAddr("192.168.0.2")
ip2 := netip.MustParseAddr("192.168.0.3")
ip3 := netip.MustParseAddr("172.16.0.3")
ip4 := netip.MustParseAddr("172.16.0.4")
mac1 := mustParseMAC(t, "01:02:03:04:05:06")
mac2 := mustParseMAC(t, "06:05:04:03:02:01")
mac3 := mustParseMAC(t, "02:03:04:05:06:07")
leases := []*dhcpsvc.Lease{{
Hostname: host1,
IP: ip1,
HWAddr: mac1,
IsStatic: true,
}, {
Hostname: host2,
IP: ip2,
HWAddr: mac2,
IsStatic: true,
}, {
Hostname: host3,
IP: ip3,
HWAddr: mac3,
IsStatic: true,
}, {
Hostname: host4,
IP: ip4,
HWAddr: mac1,
IsStatic: true,
}}
for _, l := range leases {
require.NoError(t, srv.AddLease(l))
}
t.Run("ip_idx", func(t *testing.T) {
assert.Equal(t, ip1, srv.IPByHost(host1))
assert.Equal(t, ip2, srv.IPByHost(host2))
assert.Equal(t, ip3, srv.IPByHost(host3))
assert.Equal(t, ip4, srv.IPByHost(host4))
assert.Equal(t, netip.Addr{}, srv.IPByHost(host5))
})
t.Run("name_idx", func(t *testing.T) {
assert.Equal(t, host1, srv.HostByIP(ip1))
assert.Equal(t, host2, srv.HostByIP(ip2))
assert.Equal(t, host3, srv.HostByIP(ip3))
assert.Equal(t, host4, srv.HostByIP(ip4))
assert.Equal(t, "", srv.HostByIP(netip.Addr{}))
})
t.Run("mac_idx", func(t *testing.T) {
assert.Equal(t, mac1, srv.MACByIP(ip1))
assert.Equal(t, mac2, srv.MACByIP(ip2))
assert.Equal(t, mac3, srv.MACByIP(ip3))
assert.Equal(t, mac1, srv.MACByIP(ip4))
assert.Nil(t, srv.MACByIP(netip.Addr{}))
})
}
func TestDHCPServer_UpdateStaticLease(t *testing.T) {
srv, err := dhcpsvc.New(&dhcpsvc.Config{
Enabled: true,
LocalDomainName: testLocalTLD,
Interfaces: testInterfaceConf,
})
require.NoError(t, err)
const (
host1 = "host1"
host2 = "host2"
host3 = "host3"
host4 = "host4"
host5 = "host5"
host6 = "host6"
)
ip1 := netip.MustParseAddr("192.168.0.2")
ip2 := netip.MustParseAddr("192.168.0.3")
ip3 := netip.MustParseAddr("192.168.0.4")
ip4 := netip.MustParseAddr("2001:db8::2")
ip5 := netip.MustParseAddr("2001:db8::3")
mac1 := mustParseMAC(t, "01:02:03:04:05:06")
mac2 := mustParseMAC(t, "01:02:03:04:05:07")
mac3 := mustParseMAC(t, "06:05:04:03:02:01")
mac4 := mustParseMAC(t, "06:05:04:03:02:02")
leases := []*dhcpsvc.Lease{{
Hostname: host1,
IP: ip1,
HWAddr: mac1,
IsStatic: true,
}, {
Hostname: host2,
IP: ip2,
HWAddr: mac2,
IsStatic: true,
}, {
Hostname: host4,
IP: ip4,
HWAddr: mac4,
IsStatic: true,
}}
for _, l := range leases {
require.NoError(t, srv.AddLease(l))
}
testCases := []struct {
name string
lease *dhcpsvc.Lease
wantErrMsg string
}{{
name: "outside_range",
lease: &dhcpsvc.Lease{
Hostname: host1,
IP: netip.MustParseAddr("1.2.3.4"),
HWAddr: mac1,
},
wantErrMsg: "updating static lease: no interface for ip 1.2.3.4",
}, {
name: "not_found",
lease: &dhcpsvc.Lease{
Hostname: host3,
IP: ip3,
HWAddr: mac3,
},
wantErrMsg: "updating static lease: no lease for mac " + mac3.String(),
}, {
name: "duplicate_ip",
lease: &dhcpsvc.Lease{
Hostname: host1,
IP: ip2,
HWAddr: mac1,
},
wantErrMsg: "updating static lease: lease for ip " + ip2.String() +
" already exists",
}, {
name: "duplicate_hostname",
lease: &dhcpsvc.Lease{
Hostname: host2,
IP: ip1,
HWAddr: mac1,
},
wantErrMsg: "updating static lease: lease for hostname " + host2 +
" already exists",
}, {
name: "duplicate_hostname_case",
lease: &dhcpsvc.Lease{
Hostname: strings.ToUpper(host2),
IP: ip1,
HWAddr: mac1,
},
wantErrMsg: "updating static lease: lease for hostname " +
strings.ToUpper(host2) + " already exists",
}, {
name: "valid",
lease: &dhcpsvc.Lease{
Hostname: host3,
IP: ip3,
HWAddr: mac1,
},
wantErrMsg: "",
}, {
name: "valid_v6",
lease: &dhcpsvc.Lease{
Hostname: host6,
IP: ip5,
HWAddr: mac4,
},
wantErrMsg: "",
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
testutil.AssertErrorMsg(t, tc.wantErrMsg, srv.UpdateStaticLease(tc.lease))
})
}
}
func TestDHCPServer_RemoveLease(t *testing.T) {
srv, err := dhcpsvc.New(&dhcpsvc.Config{
Enabled: true,
LocalDomainName: testLocalTLD,
Interfaces: testInterfaceConf,
})
require.NoError(t, err)
const (
host1 = "host1"
host2 = "host2"
host3 = "host3"
)
ip1 := netip.MustParseAddr("192.168.0.2")
ip2 := netip.MustParseAddr("192.168.0.3")
ip3 := netip.MustParseAddr("2001:db8::2")
mac1 := mustParseMAC(t, "01:02:03:04:05:06")
mac2 := mustParseMAC(t, "02:03:04:05:06:07")
mac3 := mustParseMAC(t, "06:05:04:03:02:01")
leases := []*dhcpsvc.Lease{{
Hostname: host1,
IP: ip1,
HWAddr: mac1,
IsStatic: true,
}, {
Hostname: host3,
IP: ip3,
HWAddr: mac3,
IsStatic: true,
}}
for _, l := range leases {
require.NoError(t, srv.AddLease(l))
}
testCases := []struct {
name string
lease *dhcpsvc.Lease
wantErrMsg string
}{{
name: "not_found_mac",
lease: &dhcpsvc.Lease{
Hostname: host1,
IP: ip1,
HWAddr: mac2,
},
wantErrMsg: "removing lease: no lease for mac " + mac2.String(),
}, {
name: "not_found_ip",
lease: &dhcpsvc.Lease{
Hostname: host1,
IP: ip2,
HWAddr: mac1,
},
wantErrMsg: "removing lease: no lease for ip " + ip2.String(),
}, {
name: "not_found_host",
lease: &dhcpsvc.Lease{
Hostname: host2,
IP: ip1,
HWAddr: mac1,
},
wantErrMsg: "removing lease: no lease for hostname " + host2,
}, {
name: "valid",
lease: &dhcpsvc.Lease{
Hostname: host1,
IP: ip1,
HWAddr: mac1,
},
wantErrMsg: "",
}, {
name: "valid_v6",
lease: &dhcpsvc.Lease{
Hostname: host3,
IP: ip3,
HWAddr: mac3,
},
wantErrMsg: "",
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
testutil.AssertErrorMsg(t, tc.wantErrMsg, srv.RemoveLease(tc.lease))
})
}
assert.Empty(t, srv.Leases())
}
func TestDHCPServer_Reset(t *testing.T) {
srv, err := dhcpsvc.New(&dhcpsvc.Config{
Enabled: true,
LocalDomainName: testLocalTLD,
Interfaces: testInterfaceConf,
})
require.NoError(t, err)
leases := []*dhcpsvc.Lease{{
Hostname: "host1",
IP: netip.MustParseAddr("192.168.0.2"),
HWAddr: mustParseMAC(t, "01:02:03:04:05:06"),
IsStatic: true,
}, {
Hostname: "host2",
IP: netip.MustParseAddr("192.168.0.3"),
HWAddr: mustParseMAC(t, "06:05:04:03:02:01"),
IsStatic: true,
}, {
Hostname: "host3",
IP: netip.MustParseAddr("2001:db8::2"),
HWAddr: mustParseMAC(t, "02:03:04:05:06:07"),
IsStatic: true,
}, {
Hostname: "host4",
IP: netip.MustParseAddr("2001:db8::3"),
HWAddr: mustParseMAC(t, "06:05:04:03:02:02"),
IsStatic: true,
}}
for _, l := range leases {
require.NoError(t, srv.AddLease(l))
}
require.Len(t, srv.Leases(), len(leases))
require.NoError(t, srv.Reset())
assert.Empty(t, srv.Leases())
}

View File

@@ -4,12 +4,12 @@ import (
"fmt"
"net"
"net/netip"
"slices"
"time"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/google/gopacket/layers"
"golang.org/x/exp/slices"
)
// IPv4Config is the interface-specific configuration for DHCPv4.
@@ -64,69 +64,6 @@ func (conf *IPv4Config) validate() (err error) {
}
}
// iface4 is a DHCP interface for IPv4 address family.
type iface4 struct {
// gateway is the IP address of the network gateway.
gateway netip.Addr
// subnet is the network subnet.
subnet netip.Prefix
// addrSpace is the IPv4 address space allocated for leasing.
addrSpace ipRange
// name is the name of the interface.
name string
// implicitOpts are the options listed in Appendix A of RFC 2131 and
// initialized with default values. It must not have intersections with
// explicitOpts.
implicitOpts layers.DHCPOptions
// explicitOpts are the user-configured options. It must not have
// intersections with implicitOpts.
explicitOpts layers.DHCPOptions
// leaseTTL is the time-to-live of dynamic leases on this interface.
leaseTTL time.Duration
}
// newIface4 creates a new DHCP interface for IPv4 address family with the given
// configuration. It returns an error if the given configuration can't be used.
func newIface4(name string, conf *IPv4Config) (i *iface4, err error) {
if !conf.Enabled {
return nil, nil
}
maskLen, _ := net.IPMask(conf.SubnetMask.AsSlice()).Size()
subnet := netip.PrefixFrom(conf.GatewayIP, maskLen)
switch {
case !subnet.Contains(conf.RangeStart):
return nil, fmt.Errorf("range start %s is not within %s", conf.RangeStart, subnet)
case !subnet.Contains(conf.RangeEnd):
return nil, fmt.Errorf("range end %s is not within %s", conf.RangeEnd, subnet)
}
addrSpace, err := newIPRange(conf.RangeStart, conf.RangeEnd)
if err != nil {
return nil, err
} else if addrSpace.contains(conf.GatewayIP) {
return nil, fmt.Errorf("gateway ip %s in the ip range %s", conf.GatewayIP, addrSpace)
}
i = &iface4{
name: name,
gateway: conf.GatewayIP,
subnet: subnet,
addrSpace: addrSpace,
leaseTTL: conf.LeaseDuration,
}
i.implicitOpts, i.explicitOpts = conf.options()
return i, nil
}
// options returns the implicit and explicit options for the interface. The two
// lists are disjoint and the implicit options are initialized with default
// values.
@@ -318,3 +255,83 @@ func (conf *IPv4Config) options() (implicit, explicit layers.DHCPOptions) {
func compareV4OptionCodes(a, b layers.DHCPOption) (res int) {
return int(a.Type) - int(b.Type)
}
// netInterfaceV4 is a DHCP interface for IPv4 address family.
type netInterfaceV4 struct {
// gateway is the IP address of the network gateway.
gateway netip.Addr
// subnet is the network subnet.
subnet netip.Prefix
// addrSpace is the IPv4 address space allocated for leasing.
addrSpace ipRange
// implicitOpts are the options listed in Appendix A of RFC 2131 and
// initialized with default values. It must not have intersections with
// explicitOpts.
implicitOpts layers.DHCPOptions
// explicitOpts are the user-configured options. It must not have
// intersections with implicitOpts.
explicitOpts layers.DHCPOptions
// netInterface is embedded here to provide some common network interface
// logic.
netInterface
}
// newNetInterfaceV4 creates a new DHCP interface for IPv4 address family with
// the given configuration. It returns an error if the given configuration
// can't be used.
func newNetInterfaceV4(name string, conf *IPv4Config) (i *netInterfaceV4, err error) {
if !conf.Enabled {
return nil, nil
}
maskLen, _ := net.IPMask(conf.SubnetMask.AsSlice()).Size()
subnet := netip.PrefixFrom(conf.GatewayIP, maskLen)
switch {
case !subnet.Contains(conf.RangeStart):
return nil, fmt.Errorf("range start %s is not within %s", conf.RangeStart, subnet)
case !subnet.Contains(conf.RangeEnd):
return nil, fmt.Errorf("range end %s is not within %s", conf.RangeEnd, subnet)
}
addrSpace, err := newIPRange(conf.RangeStart, conf.RangeEnd)
if err != nil {
return nil, err
} else if addrSpace.contains(conf.GatewayIP) {
return nil, fmt.Errorf("gateway ip %s in the ip range %s", conf.GatewayIP, addrSpace)
}
i = &netInterfaceV4{
gateway: conf.GatewayIP,
subnet: subnet,
addrSpace: addrSpace,
netInterface: netInterface{
name: name,
leaseTTL: conf.LeaseDuration,
},
}
i.implicitOpts, i.explicitOpts = conf.options()
return i, nil
}
// netInterfacesV4 is a slice of network interfaces of IPv4 address family.
type netInterfacesV4 []*netInterfaceV4
// find returns the first network interface within ifaces containing ip. It
// returns false if there is no such interface.
func (ifaces netInterfacesV4) find(ip netip.Addr) (iface4 *netInterface, ok bool) {
i := slices.IndexFunc(ifaces, func(iface *netInterfaceV4) (contains bool) {
return iface.subnet.Contains(ip)
})
if i < 0 {
return nil, false
}
return &ifaces[i].netInterface, true
}

View File

@@ -3,11 +3,12 @@ package dhcpsvc
import (
"fmt"
"net/netip"
"slices"
"time"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/google/gopacket/layers"
"golang.org/x/exp/slices"
)
// IPv6Config is the interface-specific configuration for DHCPv6.
@@ -52,57 +53,6 @@ func (conf *IPv6Config) validate() (err error) {
}
}
// iface6 is a DHCP interface for IPv6 address family.
//
// TODO(e.burkov): Add options.
type iface6 struct {
// rangeStart is the first IP address in the range.
rangeStart netip.Addr
// name is the name of the interface.
name string
// implicitOpts are the DHCPv6 options listed in RFC 8415 (and others) and
// initialized with default values. It must not have intersections with
// explicitOpts.
implicitOpts layers.DHCPv6Options
// explicitOpts are the user-configured options. It must not have
// intersections with implicitOpts.
explicitOpts layers.DHCPv6Options
// leaseTTL is the time-to-live of dynamic leases on this interface.
leaseTTL time.Duration
// raSLAACOnly defines if DHCP should send ICMPv6.RA packets without MO
// flags.
raSLAACOnly bool
// raAllowSLAAC defines if DHCP should send ICMPv6.RA packets with MO flags.
raAllowSLAAC bool
}
// newIface6 creates a new DHCP interface for IPv6 address family with the given
// configuration.
//
// TODO(e.burkov): Validate properly.
func newIface6(name string, conf *IPv6Config) (i *iface6) {
if !conf.Enabled {
return nil
}
i = &iface6{
name: name,
rangeStart: conf.RangeStart,
leaseTTL: conf.LeaseDuration,
raSLAACOnly: conf.RASLAACOnly,
raAllowSLAAC: conf.RAAllowSLAAC,
}
i.implicitOpts, i.explicitOpts = conf.options()
return i
}
// options returns the implicit and explicit options for the interface. The two
// lists are disjoint and the implicit options are initialized with default
// values.
@@ -133,3 +83,79 @@ func (conf *IPv6Config) options() (implicit, explicit layers.DHCPv6Options) {
func compareV6OptionCodes(a, b layers.DHCPv6Option) (res int) {
return int(a.Code) - int(b.Code)
}
// netInterfaceV6 is a DHCP interface for IPv6 address family.
//
// TODO(e.burkov): Add options.
type netInterfaceV6 struct {
// rangeStart is the first IP address in the range.
rangeStart netip.Addr
// implicitOpts are the DHCPv6 options listed in RFC 8415 (and others) and
// initialized with default values. It must not have intersections with
// explicitOpts.
implicitOpts layers.DHCPv6Options
// explicitOpts are the user-configured options. It must not have
// intersections with implicitOpts.
explicitOpts layers.DHCPv6Options
// netInterface is embedded here to provide some common network interface
// logic.
netInterface
// raSLAACOnly defines if DHCP should send ICMPv6.RA packets without MO
// flags.
raSLAACOnly bool
// raAllowSLAAC defines if DHCP should send ICMPv6.RA packets with MO flags.
raAllowSLAAC bool
}
// newNetInterfaceV6 creates a new DHCP interface for IPv6 address family with
// the given configuration.
//
// TODO(e.burkov): Validate properly.
func newNetInterfaceV6(name string, conf *IPv6Config) (i *netInterfaceV6) {
if !conf.Enabled {
return nil
}
i = &netInterfaceV6{
rangeStart: conf.RangeStart,
netInterface: netInterface{
name: name,
leaseTTL: conf.LeaseDuration,
},
raSLAACOnly: conf.RASLAACOnly,
raAllowSLAAC: conf.RAAllowSLAAC,
}
i.implicitOpts, i.explicitOpts = conf.options()
return i
}
// netInterfacesV4 is a slice of network interfaces of IPv4 address family.
type netInterfacesV6 []*netInterfaceV6
// find returns the first network interface within ifaces containing ip. It
// returns false if there is no such interface.
func (ifaces netInterfacesV6) find(ip netip.Addr) (iface6 *netInterface, ok bool) {
// prefLen is the length of prefix to match ip against.
//
// TODO(e.burkov): DHCPv6 inherits the weird behavior of legacy
// implementation where the allocated range constrained by the first address
// and the first address with last byte set to 0xff. Proper prefixes should
// be used instead.
const prefLen = netutil.IPv6BitLen - 8
i := slices.IndexFunc(ifaces, func(iface *netInterfaceV6) (contains bool) {
return !ip.Less(iface.rangeStart) &&
netip.PrefixFrom(iface.rangeStart, prefLen).Contains(ip)
})
if i < 0 {
return nil, false
}
return &ifaces[i].netInterface, true
}

View File

@@ -14,6 +14,8 @@ import (
)
// ValidateClientID returns an error if id is not a valid ClientID.
//
// Keep in sync with [client.ValidateClientID].
func ValidateClientID(id string) (err error) {
err = netutil.ValidateHostnameLabel(id)
if err != nil {

View File

@@ -7,6 +7,7 @@ import (
"net"
"net/netip"
"os"
"slices"
"strings"
"time"
@@ -24,7 +25,6 @@ import (
"github.com/AdguardTeam/golibs/stringutil"
"github.com/AdguardTeam/golibs/timeutil"
"github.com/ameshkov/dnscrypt/v2"
"golang.org/x/exp/slices"
)
// ClientsContainer provides information about preconfigured DNS clients.
@@ -40,7 +40,7 @@ type ClientsContainer interface {
) (conf *proxy.CustomUpstreamConfig, err error)
}
// Config represents the DNS filtering configuration of AdGuard Home. The zero
// Config represents the DNS filtering configuration of AdGuard Home. The zero
// Config is empty and ready for use.
type Config struct {
// Callbacks for other modules
@@ -357,10 +357,6 @@ func (s *Server) newProxyConfig() (conf *proxy.Config, err error) {
conf.DNSCryptResolverCert = c.ResolverCert
}
if conf.UpstreamConfig == nil || len(conf.UpstreamConfig.Upstreams) == 0 {
return nil, errors.Error("no default upstream servers configured")
}
conf, err = prepareCacheConfig(conf,
srvConf.CacheSize,
srvConf.CacheMinTTL,

View File

@@ -1,10 +1,10 @@
package dnsforward
import (
"slices"
"testing"
"github.com/stretchr/testify/assert"
"golang.org/x/exp/slices"
)
func TestAnyNameMatches(t *testing.T) {

View File

@@ -2,54 +2,56 @@ package dnsforward
import (
"fmt"
"strings"
"sync"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
"golang.org/x/exp/slices"
)
// upstreamConfigValidator parses the [*proxy.UpstreamConfig] and checks the
// actual DNS availability of each upstream.
// upstreamConfigValidator parses each section of an upstream configuration into
// a corresponding [*proxy.UpstreamConfig] and checks the actual DNS
// availability of each upstream.
type upstreamConfigValidator struct {
// general is the general upstream configuration.
general []*upstreamResult
// generalUpstreamResults contains upstream results of a general section.
generalUpstreamResults map[string]*upstreamResult
// fallback is the fallback upstream configuration.
fallback []*upstreamResult
// fallbackUpstreamResults contains upstream results of a fallback section.
fallbackUpstreamResults map[string]*upstreamResult
// private is the private upstream configuration.
private []*upstreamResult
// privateUpstreamResults contains upstream results of a private section.
privateUpstreamResults map[string]*upstreamResult
// generalParseResults contains parsing results of a general section.
generalParseResults []*parseResult
// fallbackParseResults contains parsing results of a fallback section.
fallbackParseResults []*parseResult
// privateParseResults contains parsing results of a private section.
privateParseResults []*parseResult
}
// upstreamResult is a result of validation of an [upstream.Upstream] within an
// upstreamResult is a result of parsing of an [upstream.Upstream] within an
// [proxy.UpstreamConfig].
type upstreamResult struct {
// server is the parsed upstream. It is nil when there was an error during
// parsing.
// server is the parsed upstream.
server upstream.Upstream
// err is the error either from parsing or from checking the upstream.
// err is the upstream check error.
err error
// original is the piece of configuration that have either been turned to an
// upstream or caused an error.
original string
// isSpecific is true if the upstream is domain-specific.
isSpecific bool
}
// compare compares two [upstreamResult]s. It returns 0 if they are equal, -1
// if ur should be sorted before other, and 1 otherwise.
//
// TODO(e.burkov): Perhaps it makes sense to sort the results with errors near
// the end.
func (ur *upstreamResult) compare(other *upstreamResult) (res int) {
return strings.Compare(ur.original, other.original)
// parseResult contains a original piece of upstream configuration and a
// corresponding error.
type parseResult struct {
err *proxy.ParseError
original string
}
// newUpstreamConfigValidator parses the upstream configuration and returns a
@@ -61,97 +63,99 @@ func newUpstreamConfigValidator(
private []string,
opts *upstream.Options,
) (cv *upstreamConfigValidator) {
cv = &upstreamConfigValidator{}
cv = &upstreamConfigValidator{
generalUpstreamResults: map[string]*upstreamResult{},
fallbackUpstreamResults: map[string]*upstreamResult{},
privateUpstreamResults: map[string]*upstreamResult{},
}
for _, line := range general {
cv.general = cv.insertLineResults(cv.general, line, opts)
}
for _, line := range fallback {
cv.fallback = cv.insertLineResults(cv.fallback, line, opts)
}
for _, line := range private {
cv.private = cv.insertLineResults(cv.private, line, opts)
}
conf, err := proxy.ParseUpstreamsConfig(general, opts)
cv.generalParseResults = collectErrResults(general, err)
insertConfResults(conf, cv.generalUpstreamResults)
conf, err = proxy.ParseUpstreamsConfig(fallback, opts)
cv.fallbackParseResults = collectErrResults(fallback, err)
insertConfResults(conf, cv.fallbackUpstreamResults)
conf, err = proxy.ParseUpstreamsConfig(private, opts)
cv.privateParseResults = collectErrResults(private, err)
insertConfResults(conf, cv.privateUpstreamResults)
return cv
}
// insertLineResults parses line and inserts the result into s. It can insert
// multiple results as well as none.
func (cv *upstreamConfigValidator) insertLineResults(
s []*upstreamResult,
line string,
opts *upstream.Options,
) (result []*upstreamResult) {
upstreams, isSpecific, err := splitUpstreamLine(line)
if err != nil {
return cv.insert(s, &upstreamResult{
err: err,
original: line,
})
// collectErrResults parses err and returns parsing results containing the
// original upstream configuration line and the corresponding error. err can be
// nil.
func collectErrResults(lines []string, err error) (results []*parseResult) {
if err == nil {
return nil
}
for _, upstreamAddr := range upstreams {
var res *upstreamResult
if upstreamAddr != "#" {
res = cv.parseUpstream(upstreamAddr, opts)
} else if !isSpecific {
res = &upstreamResult{
err: errNotDomainSpecific,
original: upstreamAddr,
}
} else {
// limit is a maximum length for upstream configuration lines.
const limit = 80
wrapper, ok := err.(errors.WrapperSlice)
if !ok {
log.Debug("dnsforward: configvalidator: unwrapping: %s", err)
return nil
}
errs := wrapper.Unwrap()
results = make([]*parseResult, 0, len(errs))
for i, e := range errs {
var parseErr *proxy.ParseError
if !errors.As(e, &parseErr) {
log.Debug("dnsforward: configvalidator: inserting unexpected error %d: %s", i, err)
continue
}
res.isSpecific = isSpecific
s = cv.insert(s, res)
}
return s
}
// insert inserts r into slice in a sorted order, except duplicates. slice must
// not be nil.
func (cv *upstreamConfigValidator) insert(
s []*upstreamResult,
r *upstreamResult,
) (result []*upstreamResult) {
i, has := slices.BinarySearchFunc(s, r, (*upstreamResult).compare)
if has {
log.Debug("dnsforward: duplicate configuration %q", r.original)
return s
}
return slices.Insert(s, i, r)
}
// parseUpstream parses addr and returns the result of parsing. It returns nil
// if the specified server points at the default upstream server which is
// validated separately.
func (cv *upstreamConfigValidator) parseUpstream(
addr string,
opts *upstream.Options,
) (r *upstreamResult) {
// Check if the upstream has a valid protocol prefix.
//
// TODO(e.burkov): Validate the domain name.
if proto, _, ok := strings.Cut(addr, "://"); ok {
if !slices.Contains(protocols, proto) {
return &upstreamResult{
err: fmt.Errorf("bad protocol %q", proto),
original: addr,
}
idx := parseErr.Idx
line := []rune(lines[idx])
if len(line) > limit {
line = line[:limit]
line[limit-1] = '…'
}
results = append(results, &parseResult{
original: string(line),
err: parseErr,
})
}
ups, err := upstream.AddressToUpstream(addr, opts)
return results
}
return &upstreamResult{
server: ups,
err: err,
original: addr,
// insertConfResults parses conf and inserts the upstream result into results.
// It can insert multiple results as well as none.
func insertConfResults(conf *proxy.UpstreamConfig, results map[string]*upstreamResult) {
insertListResults(conf.Upstreams, results, false)
for _, ups := range conf.DomainReservedUpstreams {
insertListResults(ups, results, true)
}
for _, ups := range conf.SpecifiedDomainUpstreams {
insertListResults(ups, results, true)
}
}
// insertListResults constructs upstream results from the upstream list and
// inserts them into results. It can insert multiple results as well as none.
func insertListResults(ups []upstream.Upstream, results map[string]*upstreamResult, specific bool) {
for _, u := range ups {
addr := u.Address()
_, ok := results[addr]
if ok {
continue
}
results[addr] = &upstreamResult{
server: u,
isSpecific: specific,
}
}
}
@@ -187,35 +191,30 @@ func (cv *upstreamConfigValidator) check() {
}
wg := &sync.WaitGroup{}
wg.Add(len(cv.general) + len(cv.fallback) + len(cv.private))
wg.Add(len(cv.generalUpstreamResults) +
len(cv.fallbackUpstreamResults) +
len(cv.privateUpstreamResults))
for _, res := range cv.general {
go cv.checkSrv(res, wg, commonChecker)
for _, res := range cv.generalUpstreamResults {
go checkSrv(res, wg, commonChecker)
}
for _, res := range cv.fallback {
go cv.checkSrv(res, wg, commonChecker)
for _, res := range cv.fallbackUpstreamResults {
go checkSrv(res, wg, commonChecker)
}
for _, res := range cv.private {
go cv.checkSrv(res, wg, arpaChecker)
for _, res := range cv.privateUpstreamResults {
go checkSrv(res, wg, arpaChecker)
}
wg.Wait()
}
// checkSrv runs hc on the server from res, if any, and stores any occurred
// error in res. wg is always marked done in the end. It used to be called in
// a separate goroutine.
func (cv *upstreamConfigValidator) checkSrv(
res *upstreamResult,
wg *sync.WaitGroup,
hc *healthchecker,
) {
// error in res. wg is always marked done in the end. It is intended to be
// used as a goroutine.
func checkSrv(res *upstreamResult, wg *sync.WaitGroup, hc *healthchecker) {
defer log.OnPanic(fmt.Sprintf("dnsforward: checking upstream %s", res.server.Address()))
defer wg.Done()
if res.server == nil {
return
}
res.err = hc.check(res.server)
if res.err != nil && res.isSpecific {
res.err = domainSpecificTestError{Err: res.err}
@@ -225,65 +224,126 @@ func (cv *upstreamConfigValidator) checkSrv(
// close closes all the upstreams that were successfully parsed. It enriches
// the results with deferred closing errors.
func (cv *upstreamConfigValidator) close() {
for _, slice := range [][]*upstreamResult{cv.general, cv.fallback, cv.private} {
for _, r := range slice {
if r.server != nil {
r.err = errors.WithDeferred(r.err, r.server.Close())
}
all := []map[string]*upstreamResult{
cv.generalUpstreamResults,
cv.fallbackUpstreamResults,
cv.privateUpstreamResults,
}
for _, m := range all {
for _, r := range m {
r.err = errors.WithDeferred(r.err, r.server.Close())
}
}
}
// sections of the upstream configuration according to the text label of the
// localization.
//
// Keep in sync with client/src/__locales/en.json.
//
// TODO(s.chzhen): Refactor.
const (
generalTextLabel = "upstream_dns"
fallbackTextLabel = "fallback_dns_title"
privateTextLabel = "local_ptr_title"
)
// status returns all the data collected during parsing, healthcheck, and
// closing of the upstreams. The returned map is keyed by the original upstream
// configuration piece and contains the corresponding error or "OK" if there was
// no error.
func (cv *upstreamConfigValidator) status() (results map[string]string) {
result := map[string]string{}
// Names of the upstream configuration sections for logging.
const (
generalSection = "general"
fallbackSection = "fallback"
privateSection = "private"
)
for _, res := range cv.general {
resultToStatus("general", res, result)
results = map[string]string{}
for original, res := range cv.generalUpstreamResults {
upstreamResultToStatus(generalSection, string(original), res, results)
}
for _, res := range cv.fallback {
resultToStatus("fallback", res, result)
for original, res := range cv.fallbackUpstreamResults {
upstreamResultToStatus(fallbackSection, string(original), res, results)
}
for _, res := range cv.private {
resultToStatus("private", res, result)
for original, res := range cv.privateUpstreamResults {
upstreamResultToStatus(privateSection, string(original), res, results)
}
return result
parseResultToStatus(generalTextLabel, generalSection, cv.generalParseResults, results)
parseResultToStatus(fallbackTextLabel, fallbackSection, cv.fallbackParseResults, results)
parseResultToStatus(privateTextLabel, privateSection, cv.privateParseResults, results)
return results
}
// resultToStatus puts "OK" or an error message from res into resMap. section
// is the name of the upstream configuration section, i.e. "general",
// upstreamResultToStatus puts "OK" or an error message from res into resMap.
// section is the name of the upstream configuration section, i.e. "general",
// "fallback", or "private", and only used for logging.
//
// TODO(e.burkov): Currently, the HTTP handler expects that all the results are
// put together in a single map, which may lead to collisions, see AG-27539.
// Improve the results compilation.
func resultToStatus(section string, res *upstreamResult, resMap map[string]string) {
func upstreamResultToStatus(
section string,
original string,
res *upstreamResult,
resMap map[string]string,
) {
val := "OK"
if res.err != nil {
val = res.err.Error()
}
prevVal := resMap[res.original]
prevVal := resMap[original]
switch prevVal {
case "":
resMap[res.original] = val
resMap[original] = val
case val:
log.Debug("dnsforward: duplicating %s config line %q", section, res.original)
log.Debug("dnsforward: duplicating %s config line %q", section, original)
default:
log.Debug(
"dnsforward: warning: %s config line %q (%v) had different result %v",
section,
val,
res.original,
original,
prevVal,
)
}
}
// parseResultToStatus puts parsing error messages from results into resMap.
// section is the name of the upstream configuration section, i.e. "general",
// "fallback", or "private", and only used for logging.
//
// Parsing error message has the following format:
//
// sectionTextLabel line: parsing error
//
// Where sectionTextLabel is a section text label of a localization and line is
// a line number.
func parseResultToStatus(
textLabel string,
section string,
results []*parseResult,
resMap map[string]string,
) {
for _, res := range results {
original := res.original
_, ok := resMap[original]
if ok {
log.Debug("dnsforward: duplicating %s parsing error %q", section, original)
continue
}
resMap[original] = fmt.Sprintf("%s %d: parsing error", textLabel, res.err.Idx+1)
}
}
// domainSpecificTestError is a wrapper for errors returned by checkDNS to mark
// the tested upstream domain-specific and therefore consider its errors
// non-critical.
@@ -342,7 +402,7 @@ func (h *healthchecker) check(u upstream.Upstream) (err error) {
if err != nil {
return fmt.Errorf("couldn't communicate with upstream: %w", err)
} else if h.ansEmpty && len(reply.Answer) > 0 {
return errWrongResponse
return errors.Error("wrong response")
}
return nil

View File

@@ -8,7 +8,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
@@ -101,21 +100,6 @@ func TestServer_HandleDNSRequest_dns64(t *testing.T) {
type answerMap = map[uint16][sectionsNum][]dns.RR
pt := testutil.PanicT{}
newUps := func(answers answerMap) (u upstream.Upstream) {
return aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) {
q := req.Question[0]
require.Contains(pt, answers, q.Qtype)
answer := answers[q.Qtype]
resp = (&dns.Msg{}).SetReply(req)
resp.Answer = answer[sectionAnswer]
resp.Ns = answer[sectionAuthority]
resp.Extra = answer[sectionAdditional]
return resp, nil
})
}
testCases := []struct {
name string
@@ -265,13 +249,16 @@ func TestServer_HandleDNSRequest_dns64(t *testing.T) {
}}
localRR := newRR(t, ptr64Domain, dns.TypePTR, 3600, pointedDomain)
localUps := aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) {
require.Equal(pt, req.Question[0].Name, ptr64Domain)
resp = (&dns.Msg{}).SetReply(req)
resp.Answer = []dns.RR{localRR}
localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
require.Len(pt, m.Question, 1)
require.Equal(pt, m.Question[0].Name, ptr64Domain)
resp := (&dns.Msg{
Answer: []dns.RR{localRR},
}).SetReply(m)
return resp, nil
require.NoError(t, w.WriteMsg(resp))
})
localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String()
client := &dns.Client{
Net: "tcp",
@@ -279,25 +266,44 @@ func TestServer_HandleDNSRequest_dns64(t *testing.T) {
}
for _, tc := range testCases {
// TODO(e.burkov): It seems [proxy.Proxy] isn't intended to be reused
// right after stop, due to a data race in [proxy.Proxy.Init] method
// when setting an OOB size. As a temporary workaround, recreate the
// whole server for each test case.
s := createTestServer(t, &filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
}, ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
UseDNS64: true,
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
ServePlainDNS: true,
}, localUps)
tc := tc
t.Run(tc.name, func(t *testing.T) {
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{newUps(tc.upsAns)}
upsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
q := req.Question[0]
require.Contains(pt, tc.upsAns, q.Qtype)
answer := tc.upsAns[q.Qtype]
resp := (&dns.Msg{
Answer: answer[sectionAnswer],
Ns: answer[sectionAuthority],
Extra: answer[sectionAdditional],
}).SetReply(req)
require.NoError(pt, w.WriteMsg(resp))
})
upsAddr := aghtest.StartLocalhostUpstream(t, upsHdlr).String()
// TODO(e.burkov): It seems [proxy.Proxy] isn't intended to be
// reused right after stop, due to a data race in [proxy.Proxy.Init]
// method when setting an OOB size. As a temporary workaround,
// recreate the whole server for each test case.
s := createTestServer(t, &filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
}, ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
UseDNS64: true,
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
UpstreamDNS: []string{upsAddr},
},
UsePrivateRDNS: true,
LocalPTRResolvers: []string{localUpsAddr},
ServePlainDNS: true,
})
startDeferStop(t, s)
req := (&dns.Msg{}).SetQuestion(tc.qname, tc.qtype)

View File

@@ -9,6 +9,7 @@ import (
"net/http"
"net/netip"
"runtime"
"slices"
"strings"
"sync"
"sync/atomic"
@@ -30,7 +31,6 @@ import (
"github.com/AdguardTeam/golibs/netutil/sysresolv"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/miekg/dns"
"golang.org/x/exp/slices"
)
// DefaultTimeout is the default upstream timeout
@@ -464,7 +464,8 @@ func (s *Server) Start() error {
// startLocked starts the DNS server without locking. s.serverLock is expected
// to be locked.
func (s *Server) startLocked() error {
err := s.dnsProxy.Start()
// TODO(e.burkov): Use context properly.
err := s.dnsProxy.Start(context.Background())
if err == nil {
s.isRunning = true
}
@@ -517,35 +518,56 @@ func (s *Server) prepareLocalResolvers(
return uc, nil
}
// LocalResolversError is an error type for errors during local resolvers setup.
// This is only needed to distinguish these errors from errors returned by
// creating the proxy.
type LocalResolversError struct {
Err error
}
// type check
var _ error = (*LocalResolversError)(nil)
// Error implements the error interface for *LocalResolversError.
func (err *LocalResolversError) Error() (s string) {
return fmt.Sprintf("creating local resolvers: %s", err.Err)
}
// type check
var _ errors.Wrapper = (*LocalResolversError)(nil)
// Unwrap implements the [errors.Wrapper] interface for *LocalResolversError.
func (err *LocalResolversError) Unwrap() error {
return err.Err
}
// setupLocalResolvers initializes and sets the resolvers for local addresses.
// It assumes s.serverLock is locked or s not running.
func (s *Server) setupLocalResolvers(boot upstream.Resolver) (err error) {
uc, err := s.prepareLocalResolvers(boot)
// It assumes s.serverLock is locked or s not running. It returns the upstream
// configuration used for private PTR resolving, or nil if it's disabled. Note,
// that it's safe to put nil into [proxy.Config.PrivateRDNSUpstreamConfig].
func (s *Server) setupLocalResolvers(boot upstream.Resolver) (uc *proxy.UpstreamConfig, err error) {
if !s.conf.UsePrivateRDNS {
// It's safe to put nil into [proxy.Config.PrivateRDNSUpstreamConfig].
return nil, nil
}
uc, err = s.prepareLocalResolvers(boot)
if err != nil {
// Don't wrap the error because it's informative enough as is.
return err
return nil, err
}
s.localResolvers = &proxy.Proxy{
Config: proxy.Config{
UpstreamConfig: uc,
},
}
err = s.localResolvers.Init()
localResolvers, err := proxy.New(&proxy.Config{
UpstreamConfig: uc,
})
if err != nil {
return fmt.Errorf("initializing proxy: %w", err)
return nil, &LocalResolversError{Err: err}
}
s.localResolvers = localResolvers
// TODO(e.burkov): Should we also consider the DNS64 usage?
if s.conf.UsePrivateRDNS &&
// Only set the upstream config if there are any upstreams. It's safe
// to put nil into [proxy.Config.PrivateRDNSUpstreamConfig].
len(uc.Upstreams)+len(uc.DomainReservedUpstreams)+len(uc.SpecifiedDomainUpstreams) > 0 {
s.dnsProxy.PrivateRDNSUpstreamConfig = uc
}
return nil
return uc, nil
}
// Prepare initializes parameters of s using data from conf. conf must not be
@@ -586,21 +608,24 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
return fmt.Errorf("preparing access: %w", err)
}
// Set the proxy here because [setupLocalResolvers] sets its values.
//
// TODO(e.burkov): Remove once the local resolvers logic moved to dnsproxy.
s.dnsProxy = &proxy.Proxy{Config: *proxyConfig}
err = s.setupLocalResolvers(boot)
proxyConfig.PrivateRDNSUpstreamConfig, err = s.setupLocalResolvers(boot)
if err != nil {
return fmt.Errorf("setting up resolvers: %w", err)
}
err = s.setupFallbackDNS()
proxyConfig.Fallbacks, err = s.setupFallbackDNS()
if err != nil {
return fmt.Errorf("setting up fallback dns servers: %w", err)
}
dnsProxy, err := proxy.New(proxyConfig)
if err != nil {
return fmt.Errorf("creating proxy: %w", err)
}
s.dnsProxy = dnsProxy
s.recDetector.clear()
s.setupAddrProc()
@@ -643,26 +668,25 @@ func (s *Server) prepareInternalDNS() (boot upstream.Resolver, err error) {
}
// setupFallbackDNS initializes the fallback DNS servers.
func (s *Server) setupFallbackDNS() (err error) {
func (s *Server) setupFallbackDNS() (uc *proxy.UpstreamConfig, err error) {
fallbacks := s.conf.FallbackDNS
fallbacks = stringutil.FilterOut(fallbacks, IsCommentOrEmpty)
if len(fallbacks) == 0 {
return nil
return nil, nil
}
uc, err := proxy.ParseUpstreamsConfig(fallbacks, &upstream.Options{
uc, err = proxy.ParseUpstreamsConfig(fallbacks, &upstream.Options{
// TODO(s.chzhen): Investigate if other options are needed.
Timeout: s.conf.UpstreamTimeout,
PreferIPv6: s.conf.BootstrapPreferIPv6,
// TODO(e.burkov): Use bootstrap.
})
if err != nil {
// Do not wrap the error because it's informative enough as is.
return err
return nil, err
}
s.dnsProxy.Fallbacks = uc
return nil
return uc, nil
}
// setupAddrProc initializes the address processor. It assumes s.serverLock is
@@ -730,19 +754,9 @@ func (s *Server) prepareInternalProxy() (err error) {
return fmt.Errorf("invalid upstream mode: %w", err)
}
// TODO(a.garipov): Make a proper constructor for proxy.Proxy.
p := &proxy.Proxy{
Config: *conf,
}
s.internalProxy, err = proxy.New(conf)
err = p.Init()
if err != nil {
return err
}
s.internalProxy = p
return nil
return err
}
// Stop stops the DNS server.
@@ -761,14 +775,17 @@ func (s *Server) stopLocked() (err error) {
// [upstream.Upstream] implementations.
if s.dnsProxy != nil {
err = s.dnsProxy.Stop()
// TODO(e.burkov): Use context properly.
err = s.dnsProxy.Shutdown(context.Background())
if err != nil {
log.Error("dnsforward: closing primary resolvers: %s", err)
}
}
logCloserErr(s.internalProxy.UpstreamConfig, "dnsforward: closing internal resolvers: %s")
logCloserErr(s.localResolvers.UpstreamConfig, "dnsforward: closing local resolvers: %s")
if s.localResolvers != nil {
logCloserErr(s.localResolvers.UpstreamConfig, "dnsforward: closing local resolvers: %s")
}
for _, b := range s.bootResolvers {
logCloserErr(b, "dnsforward: closing bootstrap %s: %s", b.Address())
@@ -841,6 +858,8 @@ func (s *Server) Reconfigure(conf *ServerConfig) error {
}
}
// TODO(e.burkov): It seems an error here brings the server down, which is
// not reliable enough.
err = s.Prepare(conf)
if err != nil {
return fmt.Errorf("could not reconfigure the server: %w", err)

View File

@@ -5,9 +5,11 @@ import (
"crypto/ecdsa"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/hex"
"encoding/pem"
"fmt"
"math/big"
@@ -63,8 +65,7 @@ func startDeferStop(t *testing.T, s *Server) {
t.Helper()
err := s.Start()
require.NoErrorf(t, err, "failed to start server: %s", err)
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, s.Stop)
}
@@ -72,7 +73,6 @@ func createTestServer(
t *testing.T,
filterConf *filtering.Config,
forwardConf ServerConfig,
localUps upstream.Upstream,
) (s *Server) {
t.Helper()
@@ -82,7 +82,8 @@ func createTestServer(
@@||whitelist.example.org^
||127.0.0.255`
filters := []filtering.Filter{{
ID: 0, Data: []byte(rules),
ID: 0,
Data: []byte(rules),
}}
f, err := filtering.New(filterConf, filters)
@@ -105,19 +106,6 @@ func createTestServer(
err = s.Prepare(&forwardConf)
require.NoError(t, err)
s.serverLock.Lock()
defer s.serverLock.Unlock()
// TODO(e.burkov): Try to move it higher.
if localUps != nil {
ups := []upstream.Upstream{localUps}
s.localResolvers.UpstreamConfig.Upstreams = ups
s.conf.UsePrivateRDNS = true
s.dnsProxy.PrivateRDNSUpstreamConfig = &proxy.UpstreamConfig{
Upstreams: ups,
}
}
return s
}
@@ -181,7 +169,7 @@ func createTestTLS(t *testing.T, tlsConf TLSConfig) (s *Server, certPem []byte)
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
ServePlainDNS: true,
}, nil)
})
tlsConf.CertificateChainData, tlsConf.PrivateKeyData = certPem, keyPem
s.conf.TLSConfig = tlsConf
@@ -310,7 +298,7 @@ func TestServer(t *testing.T) {
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
ServePlainDNS: true,
}, nil)
})
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{newGoogleUpstream()}
startDeferStop(t, s)
@@ -410,7 +398,7 @@ func TestServerWithProtectionDisabled(t *testing.T) {
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
ServePlainDNS: true,
}, nil)
})
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{newGoogleUpstream()}
startDeferStop(t, s)
@@ -490,7 +478,7 @@ func TestServerRace(t *testing.T) {
ConfigModified: func() {},
ServePlainDNS: true,
}
s := createTestServer(t, filterConf, forwardConf, nil)
s := createTestServer(t, filterConf, forwardConf)
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{newGoogleUpstream()}
startDeferStop(t, s)
@@ -545,7 +533,7 @@ func TestSafeSearch(t *testing.T) {
},
ServePlainDNS: true,
}
s := createTestServer(t, filterConf, forwardConf, nil)
s := createTestServer(t, filterConf, forwardConf)
startDeferStop(t, s)
addr := s.dnsProxy.Addr(proxy.ProtoUDP).String()
@@ -628,7 +616,7 @@ func TestInvalidRequest(t *testing.T) {
},
},
ServePlainDNS: true,
}, nil)
})
startDeferStop(t, s)
addr := s.dnsProxy.Addr(proxy.ProtoUDP).String()
@@ -662,7 +650,7 @@ func TestBlockedRequest(t *testing.T) {
s := createTestServer(t, &filtering.Config{
ProtectionEnabled: true,
BlockingMode: filtering.BlockingModeDefault,
}, forwardConf, nil)
}, forwardConf)
startDeferStop(t, s)
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
@@ -698,7 +686,7 @@ func TestServerCustomClientUpstream(t *testing.T) {
}
s := createTestServer(t, &filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
}, forwardConf, nil)
}, forwardConf)
ups := aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) {
atomic.AddUint32(&upsCalledCounter, 1)
@@ -773,7 +761,7 @@ func TestBlockCNAMEProtectionEnabled(t *testing.T) {
},
},
ServePlainDNS: true,
}, nil)
})
testUpstm := &aghtest.Upstream{
CName: testCNAMEs,
IPv4: testIPv4,
@@ -811,7 +799,7 @@ func TestBlockCNAME(t *testing.T) {
s := createTestServer(t, &filtering.Config{
ProtectionEnabled: true,
BlockingMode: filtering.BlockingModeDefault,
}, forwardConf, nil)
}, forwardConf)
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
&aghtest.Upstream{
CName: testCNAMEs,
@@ -886,7 +874,7 @@ func TestClientRulesForCNAMEMatching(t *testing.T) {
}
s := createTestServer(t, &filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
}, forwardConf, nil)
}, forwardConf)
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
&aghtest.Upstream{
CName: testCNAMEs,
@@ -933,7 +921,7 @@ func TestNullBlockedRequest(t *testing.T) {
s := createTestServer(t, &filtering.Config{
ProtectionEnabled: true,
BlockingMode: filtering.BlockingModeNullIP,
}, forwardConf, nil)
}, forwardConf)
startDeferStop(t, s)
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
@@ -1054,7 +1042,7 @@ func TestBlockedByHosts(t *testing.T) {
s := createTestServer(t, &filtering.Config{
ProtectionEnabled: true,
BlockingMode: filtering.BlockingModeDefault,
}, forwardConf, nil)
}, forwardConf)
startDeferStop(t, s)
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
@@ -1102,7 +1090,7 @@ func TestBlockedBySafeBrowsing(t *testing.T) {
},
ServePlainDNS: true,
}
s := createTestServer(t, filterConf, forwardConf, nil)
s := createTestServer(t, filterConf, forwardConf)
startDeferStop(t, s)
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
@@ -1330,6 +1318,7 @@ func TestPTRResponseFromHosts(t *testing.T) {
var eventsCalledCounter uint32
hc, err := aghnet.NewHostsContainer(testFS, &aghtest.FSWatcher{
OnStart: func() (_ error) { panic("not implemented") },
OnEvents: func() (e <-chan struct{}) {
assert.Equal(t, uint32(1), atomic.AddUint32(&eventsCalledCounter, 1))
@@ -1481,6 +1470,8 @@ func TestServer_Exchange(t *testing.T) {
onesIP = netip.MustParseAddr("1.1.1.1")
twosIP = netip.MustParseAddr("2.2.2.2")
localIP = netip.MustParseAddr("192.168.1.1")
pt = testutil.PanicT{}
)
onesRevExtIPv4, err := netutil.IPToReversedAddr(onesIP.AsSlice())
@@ -1489,72 +1480,73 @@ func TestServer_Exchange(t *testing.T) {
twosRevExtIPv4, err := netutil.IPToReversedAddr(twosIP.AsSlice())
require.NoError(t, err)
extUpstream := &aghtest.UpstreamMock{
OnAddress: func() (addr string) { return "external.upstream.example" },
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
return aghalg.Coalesce(
aghtest.MatchedResponse(req, dns.TypePTR, onesRevExtIPv4, onesHost),
doubleTTL(aghtest.MatchedResponse(req, dns.TypePTR, twosRevExtIPv4, twosHost)),
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
), nil
},
}
extUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
resp := aghalg.Coalesce(
aghtest.MatchedResponse(req, dns.TypePTR, onesRevExtIPv4, dns.Fqdn(onesHost)),
doubleTTL(aghtest.MatchedResponse(req, dns.TypePTR, twosRevExtIPv4, dns.Fqdn(twosHost))),
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
)
require.NoError(pt, w.WriteMsg(resp))
})
upsAddr := aghtest.StartLocalhostUpstream(t, extUpsHdlr).String()
revLocIPv4, err := netutil.IPToReversedAddr(localIP.AsSlice())
require.NoError(t, err)
locUpstream := &aghtest.UpstreamMock{
OnAddress: func() (addr string) { return "local.upstream.example" },
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
return aghalg.Coalesce(
aghtest.MatchedResponse(req, dns.TypePTR, revLocIPv4, localDomainHost),
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
), nil
},
}
locUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
resp := aghalg.Coalesce(
aghtest.MatchedResponse(req, dns.TypePTR, revLocIPv4, dns.Fqdn(localDomainHost)),
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
)
errUpstream := aghtest.NewErrorUpstream()
nonPtrUpstream := aghtest.NewBlockUpstream("some-host", true)
refusingUpstream := aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) {
return new(dns.Msg).SetRcode(req, dns.RcodeRefused), nil
require.NoError(pt, w.WriteMsg(resp))
})
zeroTTLUps := &aghtest.UpstreamMock{
OnAddress: func() (addr string) { return "zero.ttl.example" },
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
resp = new(dns.Msg).SetReply(req)
hdr := dns.RR_Header{
Name: req.Question[0].Name,
Rrtype: dns.TypePTR,
Class: dns.ClassINET,
Ttl: 0,
}
resp.Answer = []dns.RR{&dns.PTR{
Hdr: hdr,
Ptr: localDomainHost,
}}
return resp, nil
},
}
errUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
require.NoError(pt, w.WriteMsg(new(dns.Msg).SetRcode(req, dns.RcodeServerFailure)))
})
srv := &Server{
recDetector: newRecursionDetector(0, 1),
internalProxy: &proxy.Proxy{
Config: proxy.Config{
UpstreamConfig: &proxy.UpstreamConfig{
Upstreams: []upstream.Upstream{extUpstream},
nonPtrHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
hash := sha256.Sum256([]byte("some-host"))
resp := (&dns.Msg{
Answer: []dns.RR{&dns.TXT{
Hdr: dns.RR_Header{
Name: req.Question[0].Name,
Rrtype: dns.TypeTXT,
Class: dns.ClassINET,
Ttl: 60,
},
},
},
}
srv.conf.UsePrivateRDNS = true
srv.privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed)
require.NoError(t, srv.internalProxy.Init())
Txt: []string{hex.EncodeToString(hash[:])},
}},
}).SetReply(req)
require.NoError(pt, w.WriteMsg(resp))
})
refusingHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
require.NoError(pt, w.WriteMsg(new(dns.Msg).SetRcode(req, dns.RcodeRefused)))
})
zeroTTLHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
resp := (&dns.Msg{
Answer: []dns.RR{&dns.PTR{
Hdr: dns.RR_Header{
Name: req.Question[0].Name,
Rrtype: dns.TypePTR,
Class: dns.ClassINET,
Ttl: 0,
},
Ptr: dns.Fqdn(localDomainHost),
}},
}).SetReply(req)
require.NoError(pt, w.WriteMsg(resp))
})
testCases := []struct {
req netip.Addr
wantErr error
locUpstream upstream.Upstream
locUpstream dns.Handler
name string
want string
wantTTL time.Duration
@@ -1569,35 +1561,35 @@ func TestServer_Exchange(t *testing.T) {
name: "local_good",
want: localDomainHost,
wantErr: nil,
locUpstream: locUpstream,
locUpstream: locUpsHdlr,
req: localIP,
wantTTL: defaultTTL,
}, {
name: "upstream_error",
want: "",
wantErr: aghtest.ErrUpstream,
locUpstream: errUpstream,
wantErr: ErrRDNSFailed,
locUpstream: errUpsHdlr,
req: localIP,
wantTTL: 0,
}, {
name: "empty_answer_error",
want: "",
wantErr: ErrRDNSNoData,
locUpstream: locUpstream,
locUpstream: locUpsHdlr,
req: netip.MustParseAddr("192.168.1.2"),
wantTTL: 0,
}, {
name: "invalid_answer",
want: "",
wantErr: ErrRDNSNoData,
locUpstream: nonPtrUpstream,
locUpstream: nonPtrHdlr,
req: localIP,
wantTTL: 0,
}, {
name: "refused",
want: "",
wantErr: ErrRDNSFailed,
locUpstream: refusingUpstream,
locUpstream: refusingHdlr,
req: localIP,
wantTTL: 0,
}, {
@@ -1611,23 +1603,28 @@ func TestServer_Exchange(t *testing.T) {
name: "zero_ttl",
want: localDomainHost,
wantErr: nil,
locUpstream: zeroTTLUps,
locUpstream: zeroTTLHdlr,
req: localIP,
wantTTL: 0,
}}
for _, tc := range testCases {
pcfg := proxy.Config{
UpstreamConfig: &proxy.UpstreamConfig{
Upstreams: []upstream.Upstream{tc.locUpstream},
},
}
srv.localResolvers = &proxy.Proxy{
Config: pcfg,
}
require.NoError(t, srv.localResolvers.Init())
localUpsAddr := aghtest.StartLocalhostUpstream(t, tc.locUpstream).String()
t.Run(tc.name, func(t *testing.T) {
srv := createTestServer(t, &filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
}, ServerConfig{
Config: Config{
UpstreamDNS: []string{upsAddr},
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
LocalPTRResolvers: []string{localUpsAddr},
UsePrivateRDNS: true,
ServePlainDNS: true,
})
host, ttl, eerr := srv.Exchange(tc.req)
require.ErrorIs(t, eerr, tc.wantErr)
@@ -1637,8 +1634,17 @@ func TestServer_Exchange(t *testing.T) {
}
t.Run("resolving_disabled", func(t *testing.T) {
srv.conf.UsePrivateRDNS = false
t.Cleanup(func() { srv.conf.UsePrivateRDNS = true })
srv := createTestServer(t, &filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
}, ServerConfig{
Config: Config{
UpstreamDNS: []string{upsAddr},
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
LocalPTRResolvers: []string{},
ServePlainDNS: true,
})
host, _, eerr := srv.Exchange(localIP)

View File

@@ -42,7 +42,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
ServePlainDNS: true,
}, nil)
})
makeQ := func(qtype rules.RRType) (req *dns.Msg) {
return &dns.Msg{

View File

@@ -4,6 +4,7 @@ import (
"encoding/binary"
"fmt"
"net"
"slices"
"strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
@@ -12,7 +13,6 @@ import (
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns"
"golang.org/x/exp/slices"
)
// beforeRequestHandler is the handler that is called before any other

View File

@@ -6,16 +6,17 @@ import (
"io"
"net/http"
"net/netip"
"slices"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil"
"golang.org/x/exp/slices"
)
// jsonDNSConfig is the JSON representation of the DNS server configuration.
@@ -294,7 +295,7 @@ func (req *jsonDNSConfig) checkFallbacks() (err error) {
return nil
}
err = ValidateUpstreams(*req.Fallbacks)
_, err = proxy.ParseUpstreamsConfig(*req.Fallbacks, &upstream.Options{})
if err != nil {
return fmt.Errorf("fallback servers: %w", err)
}
@@ -344,7 +345,7 @@ func (req *jsonDNSConfig) validate(privateNets netutil.SubnetSet) (err error) {
// validateUpstreamDNSServers returns an error if any field of req is invalid.
func (req *jsonDNSConfig) validateUpstreamDNSServers(privateNets netutil.SubnetSet) (err error) {
if req.Upstreams != nil {
err = ValidateUpstreams(*req.Upstreams)
_, err = proxy.ParseUpstreamsConfig(*req.Upstreams, &upstream.Options{})
if err != nil {
return fmt.Errorf("upstream servers: %w", err)
}
@@ -580,9 +581,6 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
return
}
req.Upstreams = stringutil.FilterOut(req.Upstreams, IsCommentOrEmpty)
req.FallbackDNS = stringutil.FilterOut(req.FallbackDNS, IsCommentOrEmpty)
req.PrivateUpstreams = stringutil.FilterOut(req.PrivateUpstreams, IsCommentOrEmpty)
req.BootstrapDNS = stringutil.FilterOut(req.BootstrapDNS, IsCommentOrEmpty)
opts := &upstream.Options{

View File

@@ -83,7 +83,7 @@ func TestDNSForwardHTTP_handleGetConfig(t *testing.T) {
ConfigModified: func() {},
ServePlainDNS: true,
}
s := createTestServer(t, filterConf, forwardConf, nil)
s := createTestServer(t, filterConf, forwardConf)
s.sysResolvers = &emptySysResolvers{}
require.NoError(t, s.Start())
@@ -164,7 +164,7 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
ConfigModified: func() {},
ServePlainDNS: true,
}
s := createTestServer(t, filterConf, forwardConf, nil)
s := createTestServer(t, filterConf, forwardConf)
s.sysResolvers = &emptySysResolvers{}
defaultConf := s.conf
@@ -223,8 +223,9 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
wantSet: "",
}, {
name: "upstream_dns_bad",
wantSet: `validating dns config: ` +
`upstream servers: validating upstream "!!!": not an ip:port`,
wantSet: `validating dns config: upstream servers: parsing error at index 0: ` +
`cannot prepare the upstream: invalid address !!!: bad hostname "!!!": ` +
`bad top-level domain name label "!!!": bad top-level domain name label rune '!'`,
}, {
name: "bootstraps_bad",
wantSet: `validating dns config: checking bootstrap a: not a bootstrap: ParseAddr("a"): ` +
@@ -313,98 +314,6 @@ func TestIsCommentOrEmpty(t *testing.T) {
}
}
func TestValidateUpstreams(t *testing.T) {
const sdnsStamp = `sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_J` +
`S3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczE` +
`uYWRndWFyZC5jb20`
testCases := []struct {
name string
wantErr string
set []string
}{{
name: "empty",
wantErr: ``,
set: nil,
}, {
name: "comment",
wantErr: ``,
set: []string{"# comment"},
}, {
name: "no_default",
wantErr: `no default upstreams specified`,
set: []string{
"[/host.com/]1.1.1.1",
"[//]tls://1.1.1.1",
"[/www.host.com/]#",
"[/host.com/google.com/]8.8.8.8",
"[/host/]" + sdnsStamp,
},
}, {
name: "with_default",
wantErr: ``,
set: []string{
"[/host.com/]1.1.1.1",
"[//]tls://1.1.1.1",
"[/www.host.com/]#",
"[/host.com/google.com/]8.8.8.8",
"[/host/]" + sdnsStamp,
"8.8.8.8",
},
}, {
name: "invalid",
wantErr: `validating upstream "dhcp://fake.dns": bad protocol "dhcp"`,
set: []string{"dhcp://fake.dns"},
}, {
name: "invalid",
wantErr: `validating upstream "1.2.3.4.5": not an ip:port`,
set: []string{"1.2.3.4.5"},
}, {
name: "invalid",
wantErr: `validating upstream "123.3.7m": not an ip:port`,
set: []string{"123.3.7m"},
}, {
name: "invalid",
wantErr: `splitting upstream line "[/host.com]tls://dns.adguard.com": ` +
`missing separator`,
set: []string{"[/host.com]tls://dns.adguard.com"},
}, {
name: "invalid",
wantErr: `validating upstream "[host.ru]#": not an ip:port`,
set: []string{"[host.ru]#"},
}, {
name: "valid_default",
wantErr: ``,
set: []string{
"1.1.1.1",
"tls://1.1.1.1",
"https://dns.adguard.com/dns-query",
sdnsStamp,
"udp://dns.google",
"udp://8.8.8.8",
"[/host.com/]1.1.1.1",
"[//]tls://1.1.1.1",
"[/www.host.com/]#",
"[/host.com/google.com/]8.8.8.8",
"[/host/]" + sdnsStamp,
"[/пример.рф/]8.8.8.8",
},
}, {
name: "bad_domain",
wantErr: `splitting upstream line "[/!/]8.8.8.8": domain at index 0: ` +
`bad domain name "!": bad top-level domain name label "!": ` +
`bad top-level domain name label rune '!'`,
set: []string{"[/!/]8.8.8.8"},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := ValidateUpstreams(tc.set)
testutil.AssertErrorMsg(t, tc.wantErr, err)
})
}
}
func TestValidateUpstreamsPrivate(t *testing.T) {
ss := netutil.SubnetSetFunc(netutil.IsLocallyServed)
@@ -509,6 +418,7 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) {
},
},
&aghtest.FSWatcher{
OnStart: func() (_ error) { panic("not implemented") },
OnEvents: func() (e <-chan struct{}) { return nil },
OnAdd: func(_ string) (err error) { return nil },
OnClose: func() (err error) { return nil },
@@ -529,7 +439,7 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) {
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
ServePlainDNS: true,
}, nil)
})
srv.etcHosts = upstream.NewHostsResolver(hc)
startDeferStop(t, srv)

View File

@@ -2,13 +2,13 @@ package dnsforward
import (
"net/netip"
"slices"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns"
"golang.org/x/exp/slices"
)
// makeResponse creates a DNS response by req and sets necessary flags. It also

View File

@@ -9,7 +9,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/urlfilter/rules"
@@ -87,7 +86,7 @@ func TestServer_ProcessInitial(t *testing.T) {
s := createTestServer(t, &filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
}, c, nil)
}, c)
var gotAddr netip.Addr
s.addrProc = &aghtest.AddressProcessor{
@@ -188,7 +187,7 @@ func TestServer_ProcessFilteringAfterResponse(t *testing.T) {
s := createTestServer(t, &filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
}, c, nil)
}, c)
resp := newResp(dns.RcodeSuccess, tc.req, tc.respAns)
dctx := &dnsContext{
@@ -248,9 +247,9 @@ func TestServer_ProcessDDRQuery(t *testing.T) {
host string
want []*dns.SVCB
wantRes resultCode
portDoH int
portDoT int
portDoQ int
addrsDoH []*net.TCPAddr
addrsDoT []*net.TCPAddr
addrsDoQ []*net.UDPAddr
qtype uint16
ddrEnabled bool
}{{
@@ -259,14 +258,14 @@ func TestServer_ProcessDDRQuery(t *testing.T) {
host: testQuestionTarget,
qtype: dns.TypeSVCB,
ddrEnabled: true,
portDoH: 8043,
addrsDoH: []*net.TCPAddr{{Port: 8043}},
}, {
name: "pass_qtype",
wantRes: resultCodeFinish,
host: ddrHostFQDN,
qtype: dns.TypeA,
ddrEnabled: true,
portDoH: 8043,
addrsDoH: []*net.TCPAddr{{Port: 8043}},
}, {
name: "pass_disabled_tls",
wantRes: resultCodeFinish,
@@ -279,7 +278,7 @@ func TestServer_ProcessDDRQuery(t *testing.T) {
host: ddrHostFQDN,
qtype: dns.TypeSVCB,
ddrEnabled: false,
portDoH: 8043,
addrsDoH: []*net.TCPAddr{{Port: 8043}},
}, {
name: "dot",
wantRes: resultCodeFinish,
@@ -287,7 +286,7 @@ func TestServer_ProcessDDRQuery(t *testing.T) {
host: ddrHostFQDN,
qtype: dns.TypeSVCB,
ddrEnabled: true,
portDoT: 8043,
addrsDoT: []*net.TCPAddr{{Port: 8043}},
}, {
name: "doh",
wantRes: resultCodeFinish,
@@ -295,7 +294,7 @@ func TestServer_ProcessDDRQuery(t *testing.T) {
host: ddrHostFQDN,
qtype: dns.TypeSVCB,
ddrEnabled: true,
portDoH: 8044,
addrsDoH: []*net.TCPAddr{{Port: 8044}},
}, {
name: "doq",
wantRes: resultCodeFinish,
@@ -303,7 +302,7 @@ func TestServer_ProcessDDRQuery(t *testing.T) {
host: ddrHostFQDN,
qtype: dns.TypeSVCB,
ddrEnabled: true,
portDoQ: 8042,
addrsDoQ: []*net.UDPAddr{{Port: 8042}},
}, {
name: "dot_doh",
wantRes: resultCodeFinish,
@@ -311,13 +310,35 @@ func TestServer_ProcessDDRQuery(t *testing.T) {
host: ddrHostFQDN,
qtype: dns.TypeSVCB,
ddrEnabled: true,
portDoT: 8043,
portDoH: 8044,
addrsDoT: []*net.TCPAddr{{Port: 8043}},
addrsDoH: []*net.TCPAddr{{Port: 8044}},
}}
_, certPem, keyPem := createServerTLSConfig(t)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
s := prepareTestServer(t, tc.portDoH, tc.portDoT, tc.portDoQ, tc.ddrEnabled)
s := createTestServer(t, &filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
}, ServerConfig{
Config: Config{
HandleDDR: tc.ddrEnabled,
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
TLSConfig: TLSConfig{
ServerName: ddrTestDomainName,
CertificateChainData: certPem,
PrivateKeyData: keyPem,
TLSListenAddrs: tc.addrsDoT,
HTTPSListenAddrs: tc.addrsDoH,
QUICListenAddrs: tc.addrsDoQ,
},
ServePlainDNS: true,
})
// TODO(e.burkov): Generate a certificate actually containing the
// IP addresses.
s.conf.hasIPAddrs = true
req := createTestMessageWithType(tc.host, tc.qtype)
@@ -358,41 +379,6 @@ func createTestDNSFilter(t *testing.T) (f *filtering.DNSFilter) {
return f
}
func prepareTestServer(t *testing.T, portDoH, portDoT, portDoQ int, ddrEnabled bool) (s *Server) {
t.Helper()
s = &Server{
dnsFilter: createTestDNSFilter(t),
dnsProxy: &proxy.Proxy{
Config: proxy.Config{},
},
conf: ServerConfig{
Config: Config{
HandleDDR: ddrEnabled,
},
TLSConfig: TLSConfig{
ServerName: ddrTestDomainName,
},
ServePlainDNS: true,
},
}
if portDoT > 0 {
s.dnsProxy.TLSListenAddr = []*net.TCPAddr{{Port: portDoT}}
s.conf.hasIPAddrs = true
}
if portDoQ > 0 {
s.dnsProxy.QUICListenAddr = []*net.UDPAddr{{Port: portDoQ}}
}
if portDoH > 0 {
s.conf.HTTPSListenAddrs = []*net.TCPAddr{{Port: portDoH}}
}
return s
}
func TestServer_ProcessDetermineLocal(t *testing.T) {
s := &Server{
privateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
@@ -680,13 +666,16 @@ func TestServer_ProcessRestrictLocal(t *testing.T) {
intPTRAnswer = "some.local-client."
)
ups := aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) {
return aghalg.Coalesce(
localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
resp := aghalg.Coalesce(
aghtest.MatchedResponse(req, dns.TypePTR, extPTRQuestion, extPTRAnswer),
aghtest.MatchedResponse(req, dns.TypePTR, intPTRQuestion, intPTRAnswer),
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
), nil
)
require.NoError(testutil.PanicT{}, w.WriteMsg(resp))
})
localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String()
s := createTestServer(t, &filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
@@ -696,12 +685,14 @@ func TestServer_ProcessRestrictLocal(t *testing.T) {
// TODO(s.chzhen): Add tests where EDNSClientSubnet.Enabled is true.
// Improve Config declaration for tests.
Config: Config{
UpstreamDNS: []string{localUpsAddr},
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
ServePlainDNS: true,
}, ups)
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ups}
UsePrivateRDNS: true,
LocalPTRResolvers: []string{localUpsAddr},
ServePlainDNS: true,
})
startDeferStop(t, s)
testCases := []struct {
@@ -764,6 +755,16 @@ func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) {
const locDomain = "some.local."
const reqAddr = "1.1.168.192.in-addr.arpa."
localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
resp := aghalg.Coalesce(
aghtest.MatchedResponse(req, dns.TypePTR, reqAddr, locDomain),
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
)
require.NoError(testutil.PanicT{}, w.WriteMsg(resp))
})
localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String()
s := createTestServer(
t,
&filtering.Config{
@@ -776,14 +777,10 @@ func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) {
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
ServePlainDNS: true,
UsePrivateRDNS: true,
LocalPTRResolvers: []string{localUpsAddr},
ServePlainDNS: true,
},
aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) {
return aghalg.Coalesce(
aghtest.MatchedResponse(req, dns.TypePTR, reqAddr, locDomain),
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
), nil
}),
)
var proxyCtx *proxy.DNSContext

View File

@@ -21,7 +21,7 @@ func TestGenAnswerHTTPS_andSVCB(t *testing.T) {
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
ServePlainDNS: true,
}, nil)
})
req := &dns.Msg{
Question: []dns.Question{{

View File

@@ -2,10 +2,9 @@ package dnsforward
import (
"fmt"
"net"
"net/netip"
"os"
"strings"
"slices"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
@@ -16,29 +15,6 @@ import (
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
)
const (
// errNotDomainSpecific is returned when the upstream should be
// domain-specific, but isn't.
errNotDomainSpecific errors.Error = "not a domain-specific upstream"
// errMissingSeparator is returned when the domain-specific part of the
// upstream configuration line isn't closed.
errMissingSeparator errors.Error = "missing separator"
// errDupSeparator is returned when the domain-specific part of the upstream
// configuration line contains more than one ending separator.
errDupSeparator errors.Error = "duplicated separator"
// errNoDefaultUpstreams is returned when there are no default upstreams
// specified in the upstream configuration.
errNoDefaultUpstreams errors.Error = "no default upstreams specified"
// errWrongResponse is returned when the checked upstream replies in an
// unexpected way.
errWrongResponse errors.Error = "wrong response"
)
// loadUpstreams parses upstream DNS servers from the configured file or from
@@ -199,84 +175,12 @@ func IsCommentOrEmpty(s string) (ok bool) {
return len(s) == 0 || s[0] == '#'
}
// newUpstreamConfig validates upstreams and returns an appropriate upstream
// configuration or nil if it can't be built.
//
// TODO(e.burkov): Perhaps proxy.ParseUpstreamsConfig should validate upstreams
// slice already so that this function may be considered useless.
func newUpstreamConfig(upstreams []string) (conf *proxy.UpstreamConfig, err error) {
// No need to validate comments and empty lines.
upstreams = stringutil.FilterOut(upstreams, IsCommentOrEmpty)
if len(upstreams) == 0 {
// Consider this case valid since it means the default server should be
// used.
return nil, nil
}
err = validateUpstreamConfig(upstreams)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return nil, err
}
conf, err = proxy.ParseUpstreamsConfig(
upstreams,
&upstream.Options{
Bootstrap: net.DefaultResolver,
Timeout: DefaultTimeout,
},
)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return nil, err
} else if len(conf.Upstreams) == 0 {
return nil, errNoDefaultUpstreams
}
return conf, nil
}
// validateUpstreamConfig validates each upstream from the upstream
// configuration and returns an error if any upstream is invalid.
//
// TODO(e.burkov): Merge with [upstreamConfigValidator] somehow.
func validateUpstreamConfig(conf []string) (err error) {
for _, u := range conf {
var ups []string
var isSpecific bool
ups, isSpecific, err = splitUpstreamLine(u)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return err
}
for _, addr := range ups {
_, err = validateUpstream(addr, isSpecific)
if err != nil {
return fmt.Errorf("validating upstream %q: %w", addr, err)
}
}
}
return nil
}
// ValidateUpstreams validates each upstream and returns an error if any
// upstream is invalid or if there are no default upstreams specified.
//
// TODO(e.burkov): Merge with [upstreamConfigValidator] somehow.
func ValidateUpstreams(upstreams []string) (err error) {
_, err = newUpstreamConfig(upstreams)
return err
}
// ValidateUpstreamsPrivate validates each upstream and returns an error if any
// upstream is invalid or if there are no default upstreams specified. It also
// checks each domain of domain-specific upstreams for being ARPA pointing to
// a locally-served network. privateNets must not be nil.
func ValidateUpstreamsPrivate(upstreams []string, privateNets netutil.SubnetSet) (err error) {
conf, err := newUpstreamConfig(upstreams)
conf, err := proxy.ParseUpstreamsConfig(upstreams, &upstream.Options{})
if err != nil {
return fmt.Errorf("creating config: %w", err)
}
@@ -308,66 +212,3 @@ func ValidateUpstreamsPrivate(upstreams []string, privateNets netutil.SubnetSet)
return errors.Annotate(errors.Join(errs...), "checking domain-specific upstreams: %w")
}
// protocols are the supported URL schemes for upstreams.
var protocols = []string{"h3", "https", "quic", "sdns", "tcp", "tls", "udp"}
// validateUpstream returns an error if u alongside with domains is not a valid
// upstream configuration. useDefault is true if the upstream is
// domain-specific and is configured to point at the default upstream server
// which is validated separately. The upstream is considered domain-specific
// only if domains is at least not nil.
func validateUpstream(u string, isSpecific bool) (useDefault bool, err error) {
// The special server address '#' means that default server must be used.
if useDefault = u == "#" && isSpecific; useDefault {
return useDefault, nil
}
// Check if the upstream has a valid protocol prefix.
//
// TODO(e.burkov): Validate the domain name.
if proto, _, ok := strings.Cut(u, "://"); ok {
if !slices.Contains(protocols, proto) {
return false, fmt.Errorf("bad protocol %q", proto)
}
} else if _, err = netip.ParseAddr(u); err == nil {
return false, nil
} else if _, err = netip.ParseAddrPort(u); err == nil {
return false, nil
}
return false, err
}
// splitUpstreamLine returns the upstreams and the specified domains. domains
// is nil when the upstream is not domains-specific. Otherwise it may also be
// empty.
func splitUpstreamLine(upstreamStr string) (upstreams []string, isSpecific bool, err error) {
if !strings.HasPrefix(upstreamStr, "[/") {
return []string{upstreamStr}, false, nil
}
defer func() { err = errors.Annotate(err, "splitting upstream line %q: %w", upstreamStr) }()
doms, ups, found := strings.Cut(upstreamStr[2:], "/]")
if !found {
return nil, false, errMissingSeparator
} else if strings.Contains(ups, "/]") {
return nil, false, errDupSeparator
}
for i, host := range strings.Split(doms, "/") {
if host == "" {
continue
}
err = netutil.ValidateDomainName(strings.TrimPrefix(host, "*."))
if err != nil {
return nil, false, fmt.Errorf("domain at index %d: %w", i, err)
}
isSpecific = true
}
return strings.Fields(ups), isSpecific, nil
}

View File

@@ -100,8 +100,7 @@ func TestUpstreamConfigValidator(t *testing.T) {
name: "bad_specification",
general: []string{"[/domain.example/]/]1.2.3.4"},
want: map[string]string{
"[/domain.example/]/]1.2.3.4": `splitting upstream line ` +
`"[/domain.example/]/]1.2.3.4": duplicated separator`,
"[/domain.example/]/]1.2.3.4": generalTextLabel + " 1: parsing error",
},
}, {
name: "all_different",
@@ -120,23 +119,9 @@ func TestUpstreamConfigValidator(t *testing.T) {
fallback: []string{"[/example/" + goodUps},
private: []string{"[/example//bad.123/]" + goodUps},
want: map[string]string{
`[/example/]/]` + goodUps: `splitting upstream line ` +
`"[/example/]/]` + goodUps + `": duplicated separator`,
`[/example/` + goodUps: `splitting upstream line ` +
`"[/example/` + goodUps + `": missing separator`,
`[/example//bad.123/]` + goodUps: `splitting upstream line ` +
`"[/example//bad.123/]` + goodUps + `": domain at index 2: ` +
`bad domain name "bad.123": ` +
`bad top-level domain name label "123": all octets are numeric`,
},
}, {
name: "non-specific_default",
general: []string{
"#",
"[/example/]#",
},
want: map[string]string{
"#": "not a domain-specific upstream",
"[/example/]/]" + goodUps: generalTextLabel + " 1: parsing error",
"[/example/" + goodUps: fallbackTextLabel + " 1: parsing error",
"[/example//bad.123/]" + goodUps: privateTextLabel + " 1: parsing error",
},
}, {
name: "bad_proto",
@@ -144,7 +129,15 @@ func TestUpstreamConfigValidator(t *testing.T) {
"bad://1.2.3.4",
},
want: map[string]string{
"bad://1.2.3.4": `bad protocol "bad"`,
"bad://1.2.3.4": generalTextLabel + " 1: parsing error",
},
}, {
name: "truncated_line",
general: []string{
"This is a very long line. It will cause a parsing error and will be truncated here.",
},
want: map[string]string{
"This is a very long line. It will cause a parsing error and will be truncated …": "upstream_dns 1: parsing error",
},
}}

View File

@@ -4,13 +4,13 @@ import (
"encoding/json"
"fmt"
"net/http"
"slices"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/urlfilter/rules"
"golang.org/x/exp/slices"
)
// serviceRules maps a service ID to its filtering rules.

View File

@@ -6,6 +6,7 @@ import (
"net/http"
"os"
"path/filepath"
"slices"
"strconv"
"strings"
"time"
@@ -15,7 +16,6 @@ import (
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/stringutil"
"golang.org/x/exp/slices"
)
// filterDir is the subdirectory of a data directory to store downloaded

View File

@@ -12,6 +12,7 @@ import (
"path/filepath"
"runtime"
"runtime/debug"
"slices"
"strings"
"sync"
"sync/atomic"
@@ -29,7 +30,6 @@ import (
"github.com/AdguardTeam/urlfilter/filterlist"
"github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns"
"golang.org/x/exp/slices"
)
// The IDs of built-in filter lists.

View File

@@ -5,6 +5,7 @@ import (
"crypto/sha256"
"encoding/hex"
"fmt"
"slices"
"strings"
"time"
@@ -14,7 +15,6 @@ import (
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/miekg/dns"
"golang.org/x/exp/slices"
"golang.org/x/net/publicsuffix"
)

View File

@@ -3,6 +3,7 @@ package hashprefix
import (
"crypto/sha256"
"encoding/hex"
"slices"
"strings"
"testing"
"time"
@@ -12,7 +13,6 @@ import (
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/exp/slices"
)
const (

View File

@@ -40,6 +40,7 @@ func TestDNSFilter_CheckHost_hostsContainer(t *testing.T) {
},
}
watcher := &aghtest.FSWatcher{
OnStart: func() (_ error) { panic("not implemented") },
OnEvents: func() (e <-chan struct{}) { return nil },
OnAdd: func(name string) (err error) { return nil },
OnClose: func() (err error) { return nil },

View File

@@ -8,6 +8,7 @@ import (
"net/url"
"os"
"path/filepath"
"slices"
"sync"
"time"
@@ -15,7 +16,6 @@ import (
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
"golang.org/x/exp/slices"
)
// validateFilterURL validates the filter list URL or file name.

View File

@@ -3,6 +3,7 @@ package rewrite
import (
"fmt"
"slices"
"strings"
"sync"
@@ -12,7 +13,6 @@ import (
"github.com/AdguardTeam/urlfilter/filterlist"
"github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns"
"golang.org/x/exp/slices"
)
// Storage is a storage for rewrite rules.

View File

@@ -3,10 +3,10 @@ package filtering
import (
"encoding/json"
"net/http"
"slices"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/golibs/log"
"golang.org/x/exp/slices"
)
// TODO(d.kolyshev): Use [rewrite.Item] instead.

View File

@@ -3,12 +3,12 @@ package filtering
import (
"fmt"
"net/netip"
"slices"
"strings"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
"golang.org/x/exp/slices"
)
// Legacy DNS rewrites

View File

@@ -6,9 +6,9 @@ import (
"fmt"
"hash/crc32"
"io"
"slices"
"github.com/AdguardTeam/golibs/errors"
"golang.org/x/exp/slices"
)
// Parser is a filtering-rule parser that collects data, such as the checksum

View File

@@ -1698,6 +1698,14 @@ var blockedServices = []blockedService{{
Rules: []string{
"||kik.com^",
},
}, {
ID: "kook",
Name: "KOOK",
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" fill=\"currentColor\" viewBox=\"18 18 220 220\"><path d=\"M32 87c25.86-.19 51.72-.33 77.58-.41 12-.05 24.01-.1 36.02-.2 10.46-.07 20.93-.12 31.4-.14 5.53-.01 11.07-.04 16.62-.1 5.21-.05 10.43-.07 15.64-.06a387.2 387.2 0 0 0 5.75-.04c14.35-.22 14.35-.22 18.22 2.55A17.33 17.33 0 0 1 236 95l.88 3.6c.12 3.5-.2 6.37-.94 9.78l-.78 3.69-.87 3.92-.87 4.07a1749.43 1749.43 0 0 1-1.84 8.47c-.8 3.63-1.58 7.26-2.36 10.89l-2.24 10.38-.44 2.01c-.4 1.86-.8 3.73-1.22 5.59l-.7 3.21C224 163 224 163 223 164a233.75 233.75 0 0 1-8.63.1h-2.62l-8.31-.04A6730.23 6730.23 0 0 1 184 164l-3 16h-11l-8-16H25c-6.84-12.54-6.84-12.54-5-20l1.22-6.25 1.5-7.1.8-3.8 2.11-9.9 2.15-10.13L32 87Zm7 8c-1.35 3.37-2.33 6-3.05 9.44l-.55 2.55-.56 2.73-.6 2.83-1.24 5.92c-.62 3.02-1.26 6.03-1.9 9.05l-1.2 5.76-.58 2.73-.53 2.55-.46 2.24c-.45 2.65-.45 2.65-.33 7.2h17l3-12 .97 2.62c1.7 3.74 3.14 6.44 6.03 9.38 5.17 1.54 10.74.68 16 0 1.18-4.5 1.38-6.94-.94-11.04a187.05 187.05 0 0 0-2.62-4.15c-.9-1.38-1.77-2.77-2.63-4.17l-1.2-1.83c-.61-1.81-.61-1.81.16-3.88a71.73 71.73 0 0 1 5.42-7.18l2.42-2.94a157.62 157.62 0 0 1 8.08-8.76c1.85-2.9 1.99-5.68 2.31-9.05-10.17-1.5-10.17-1.5-19.72 1.08-3.04 2.88-5.15 6.35-7.28 9.92h-1l2-11H39Zm44 7c-1.15 2.9-1.8 5.8-2.37 8.86l-.52 2.5c-.36 1.72-.7 3.45-1.04 5.18a558.08 558.08 0 0 1-1.6 7.94l-1.03 5.06-.49 2.4c-1.1 5.76-.68 8.96 2.05 14.06 5.18.15 10.35.26 15.53.33l5.29.12c2.53.08 5.06.1 7.6.14l2.37.09c4.28 0 5.75-.28 9.1-3.19 2.58-3.7 3.5-7.6 4.35-11.96l.51-2.42c.36-1.68.7-3.36 1.04-5.04a665.71 665.71 0 0 1 1.61-7.68l1.03-4.92.48-2.3c1.57-7.33 1.57-7.33.09-14.5-3.04-2.53-6.18-2.3-9.96-2.3l-2.36-.04a398.83 398.83 0 0 0-4.96-.02c-2.52 0-5.03-.06-7.54-.12l-4.83-.01-2.26-.08c-6.16.11-9.06 2.62-12.09 7.9Zm48.96-2.85c-1.75 3.37-2.51 6.83-3.23 10.54l-.53 2.48-1.05 5.19a606 606 0 0 1-1.64 7.9l-1.03 5.05-.5 2.37c-1.13 5.82-1.4 10.19 2.02 15.32a61 61 0 0 0 7.86.5l2.37.03c1.66.02 3.31.03 4.97.03 2.53 0 5.05.06 7.57.11l4.84.02 2.27.07c3.99-.06 6.47-.57 9.5-3.26 3.18-3.67 4.1-7.5 5-12.19l.51-2.49c.36-1.72.7-3.45 1.04-5.18.5-2.64 1.05-5.27 1.6-7.9l1.03-5.05.49-2.37c1.1-5.81 1.36-10.2-2.05-15.32a66.9 66.9 0 0 0-8.11-.5l-2.45-.03a558.66 558.66 0 0 0-5.12-.03c-2.61 0-5.22-.06-7.83-.11l-4.98-.02-2.35-.07c-4.75.06-7.42.93-10.2 4.91ZM181 95a1692.08 1692.08 0 0 0-8.88 39.44l-.65 3.08-.59 2.85-.52 2.5c-.44 2.26-.44 2.26-.36 5.13h17l3-12c5.87 9.75 5.87 9.75 7 12l7.94.06 2.28.03c1.93 0 3.85-.04 5.78-.09 1.34-1.14 1.34-1.14 1.32-4.12-.37-4.5-1.96-7.32-4.38-11-3.93-5.99-3.93-5.99-3.98-9.47 1.39-3.22 3.3-5.26 5.73-7.79l2.72-2.91a409.78 409.78 0 0 1 4.55-4.7l1.6-1.7 1.38-1.38c1.72-3.14 1.73-6.39 2.06-9.93-9.68-1.42-9.68-1.42-18.79 1.08-3.29 2.87-5.8 6.28-8.21 9.92l1-11h-17Z\"/><path d=\"M100 112h7c-1.88 16.88-1.88 16.88-4 19-2 .04-4 .04-6 0 .46-6.55 1.21-12.68 3-19Zm48 0h7c-1 6.46-2.02 12.77-4 19h-7c.47-2.8.95-5.58 1.44-8.38l.4-2.4c1.05-6 1.05-6 2.16-8.22Z\"/></svg>"),
Rules: []string{
"||kaiheila.cn^",
"||kookapp.cn^",
},
}, {
ID: "lazada",
Name: "Lazada",
@@ -2864,6 +2872,13 @@ var blockedServices = []blockedService{{
"||yt.be^",
"||ytimg.com^",
},
}, {
ID: "yy",
Name: "YY",
IconSVG: []byte("<svg xmlns=\"http://www.w3.org/2000/svg\" fill=\"currentColor\" viewBox=\"33 41 194 194\"><path d=\"M36.7 100.1c-2.3 1.3-3.9 5.6-3.2 8.1.4 1 8.2 10.9 17.5 22l16.9 20.1-2.2 2.7c-12.1 14.7-12.6 16.2-8.2 20.5 2.6 2.7 7.1 3.3 9.7 1.2 2.3-1.8 53.3-62.9 54.7-65.4 2.5-4.9-1.4-10.3-7.6-10.3-2.9 0-4.5 1.6-19.4 19.2-8.9 10.6-16.5 19.3-16.9 19.3-.5 0-7.2-7.7-15-17-7.8-9.4-15.2-18-16.6-19.3-2.7-2.3-6.7-2.8-9.7-1.1zm6.2 3.9c1 .5 8.5 8.8 16.6 18.4 8 9.6 15.5 17.8 16.6 18.1 1 .4 2.8.4 3.8 0 1.1-.3 8.6-8.5 16.7-18.1 8.1-9.6 15.6-17.9 16.6-18.5 2.3-1.2 4.8.4 4.8 3.1 0 .9-3.6 5.9-7.9 11.1-4.4 5.2-15.6 18.6-25.1 29.9-18.1 21.7-20.8 24.4-23.4 23.4-2.9-1.1-1.7-5 3.4-11.4 2.8-3.5 5.3-7.3 5.6-8.6.7-2.9-.1-4.1-18.2-25.6C44.5 116.4 38 107.9 38 107c0-1.4 1.8-4 2.8-4 .2 0 1.1.4 2.1 1zm96.3-2.7c-1.2 1.3-2.2 3.5-2.2 4.9 0 3 .1 3.2 19.2 25.9 8.2 9.7 14.8 18 14.8 18.4 0 .4-2.7 3.8-6 7.7-4.8 5.5-6 7.6-6 10.3 0 5.6 5.1 8.9 10.4 6.9 1.3-.5 46.3-53.3 54.4-63.8 1.2-1.6 2.2-4.1 2.2-5.6 0-3.5-3.8-7-7.6-7-3.6 0-6.1 2.4-23.8 23.7-7 8.4-12.8 15.3-13 15.3-.2 0-7.3-8.2-15.7-18.3-8.3-10-15.9-18.8-16.7-19.5-2.6-2-7.6-1.5-10 1.1zm11.4 6.9c2.8 2.9 9.6 10.7 15.1 17.3 12.1 14.4 13.5 15.8 16.4 15.3 2.3-.3 9.1-7.5 25.4-27.1 8.6-10.4 11-12 13.2-9.1.7 1 1.3 2.1 1.3 2.6 0 .7-20 24.9-41.8 50.6-11.2 13.2-12.8 14.6-15.5 13.1-3-1.7-2.7-2.4 6.1-13.9 1.7-2.4 3.2-5.5 3.2-7 0-2-4.1-7.5-15.9-21.6-8.7-10.4-16.1-19.9-16.5-21.1-.5-1.5-.2-2.5 1.1-3.2 2.6-1.5 2.4-1.6 7.9 4.1z\"/></svg>"),
Rules: []string{
"||yy.com^",
},
}, {
ID: "zhihu",
Name: "Zhihu",

View File

@@ -4,6 +4,7 @@ import (
"fmt"
"net"
"net/netip"
"slices"
"strings"
"sync"
"time"
@@ -23,7 +24,6 @@ import (
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/stringutil"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
)
// DHCP is an interface for accessing DHCP lease data the [clientsContainer]
@@ -47,8 +47,9 @@ type DHCP interface {
type clientsContainer struct {
// TODO(a.garipov): Perhaps use a number of separate indices for different
// types (string, netip.Addr, and so on).
list map[string]*persistentClient // name -> client
idIndex map[string]*persistentClient // ID -> client
list map[string]*client.Persistent // name -> client
clientIndex *client.Index
// ipToRC maps IP addresses to runtime client information.
ipToRC map[netip.Addr]*client.Runtime
@@ -102,10 +103,11 @@ func (clients *clientsContainer) Init(
log.Fatal("clients.list != nil")
}
clients.list = map[string]*persistentClient{}
clients.idIndex = map[string]*persistentClient{}
clients.list = map[string]*client.Persistent{}
clients.ipToRC = map[netip.Addr]*client.Runtime{}
clients.clientIndex = client.NewIndex()
clients.allTags = stringutil.NewSet(clientTags...)
// TODO(e.burkov): Use [dhcpsvc] implementation when it's ready.
@@ -140,8 +142,7 @@ func (clients *clientsContainer) Init(
}
// handleHostsUpdates receives the updates from the hosts container and adds
// them to the clients container. It's used to be called in a separate
// goroutine.
// them to the clients container. It is intended to be used as a goroutine.
func (clients *clientsContainer) handleHostsUpdates() {
for upd := range clients.etcHosts.Upd() {
clients.addFromHostsFile(upd)
@@ -189,7 +190,7 @@ type clientObject struct {
Upstreams []string `yaml:"upstreams"`
// UID is the unique identifier of the persistent client.
UID UID `yaml:"uid"`
UID client.UID `yaml:"uid"`
// UpstreamsCacheSize is the DNS cache size (in bytes).
//
@@ -213,8 +214,8 @@ type clientObject struct {
func (o *clientObject) toPersistent(
filteringConf *filtering.Config,
allTags *stringutil.Set,
) (cli *persistentClient, err error) {
cli = &persistentClient{
) (cli *client.Persistent, err error) {
cli = &client.Persistent{
Name: o.Name,
Upstreams: o.Upstreams,
@@ -224,7 +225,7 @@ func (o *clientObject) toPersistent(
UseOwnSettings: !o.UseGlobalSettings,
FilteringEnabled: o.FilteringEnabled,
ParentalEnabled: o.ParentalEnabled,
safeSearchConf: o.SafeSearchConf,
SafeSearchConf: o.SafeSearchConf,
SafeBrowsingEnabled: o.SafeBrowsingEnabled,
UseOwnBlockedServices: !o.UseGlobalBlockedServices,
IgnoreQueryLog: o.IgnoreQueryLog,
@@ -233,13 +234,13 @@ func (o *clientObject) toPersistent(
UpstreamsCacheSize: o.UpstreamsCacheSize,
}
err = cli.setIDs(o.IDs)
err = cli.SetIDs(o.IDs)
if err != nil {
return nil, fmt.Errorf("parsing ids: %w", err)
}
if (cli.UID == UID{}) {
cli.UID, err = NewUID()
if (cli.UID == client.UID{}) {
cli.UID, err = client.NewUID()
if err != nil {
return nil, fmt.Errorf("generating uid: %w", err)
}
@@ -248,7 +249,7 @@ func (o *clientObject) toPersistent(
if o.SafeSearchConf.Enabled {
o.SafeSearchConf.CustomResolver = safeSearchResolver{}
err = cli.setSafeSearch(
err = cli.SetSafeSearch(
o.SafeSearchConf,
filteringConf.SafeSearchCacheSize,
time.Minute*time.Duration(filteringConf.CacheTime),
@@ -265,7 +266,7 @@ func (o *clientObject) toPersistent(
cli.BlockedServices = o.BlockedServices.Clone()
cli.setTags(o.Tags, allTags)
cli.SetTags(o.Tags, allTags)
return cli, nil
}
@@ -277,7 +278,7 @@ func (clients *clientsContainer) addFromConfig(
filteringConf *filtering.Config,
) (err error) {
for i, o := range objects {
var cli *persistentClient
var cli *client.Persistent
cli, err = o.toPersistent(filteringConf, clients.allTags)
if err != nil {
return fmt.Errorf("clients: init persistent client at index %d: %w", i, err)
@@ -305,7 +306,7 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) {
BlockedServices: cli.BlockedServices.Clone(),
IDs: cli.ids(),
IDs: cli.IDs(),
Tags: stringutil.CloneSlice(cli.Tags),
Upstreams: stringutil.CloneSlice(cli.Upstreams),
@@ -314,7 +315,7 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) {
UseGlobalSettings: !cli.UseOwnSettings,
FilteringEnabled: cli.FilteringEnabled,
ParentalEnabled: cli.ParentalEnabled,
SafeSearchConf: cli.safeSearchConf,
SafeSearchConf: cli.SafeSearchConf,
SafeBrowsingEnabled: cli.SafeBrowsingEnabled,
UseGlobalBlockedServices: !cli.UseOwnBlockedServices,
IgnoreQueryLog: cli.IgnoreQueryLog,
@@ -435,7 +436,7 @@ func (clients *clientsContainer) clientOrArtificial(
}
// find returns a shallow copy of the client if there is one found.
func (clients *clientsContainer) find(id string) (c *persistentClient, ok bool) {
func (clients *clientsContainer) find(id string) (c *client.Persistent, ok bool) {
clients.lock.Lock()
defer clients.lock.Unlock()
@@ -444,7 +445,7 @@ func (clients *clientsContainer) find(id string) (c *persistentClient, ok bool)
return nil, false
}
return c.shallowClone(), true
return c.ShallowClone(), true
}
// shouldCountClient is a wrapper around [clientsContainer.find] to make it a
@@ -480,8 +481,8 @@ func (clients *clientsContainer) UpstreamConfigByID(
c, ok := clients.findLocked(id)
if !ok {
return nil, nil
} else if c.upstreamConfig != nil {
return c.upstreamConfig, nil
} else if c.UpstreamConfig != nil {
return c.UpstreamConfig, nil
}
upstreams := stringutil.FilterOut(c.Upstreams, dnsforward.IsCommentOrEmpty)
@@ -510,15 +511,15 @@ func (clients *clientsContainer) UpstreamConfigByID(
int(c.UpstreamsCacheSize),
config.DNS.EDNSClientSubnet.Enabled,
)
c.upstreamConfig = conf
c.UpstreamConfig = conf
return conf, nil
}
// findLocked searches for a client by its ID. clients.lock is expected to be
// locked.
func (clients *clientsContainer) findLocked(id string) (c *persistentClient, ok bool) {
c, ok = clients.idIndex[id]
func (clients *clientsContainer) findLocked(id string) (c *client.Persistent, ok bool) {
c, ok = clients.clientIndex.Find(id)
if ok {
return c, true
}
@@ -528,21 +529,13 @@ func (clients *clientsContainer) findLocked(id string) (c *persistentClient, ok
return nil, false
}
for _, c = range clients.list {
for _, subnet := range c.Subnets {
if subnet.Contains(ip) {
return c, true
}
}
}
// TODO(e.burkov): Iterate through clients.list only once.
return clients.findDHCP(ip)
}
// findDHCP searches for a client by its MAC, if the DHCP server is active and
// there is such client. clients.lock is expected to be locked.
func (clients *clientsContainer) findDHCP(ip netip.Addr) (c *persistentClient, ok bool) {
func (clients *clientsContainer) findDHCP(ip netip.Addr) (c *client.Persistent, ok bool) {
foundMAC := clients.dhcp.MACByIP(ip)
if foundMAC == nil {
return nil, false
@@ -592,13 +585,13 @@ func (clients *clientsContainer) findRuntimeClient(ip netip.Addr) (rc *client.Ru
}
// check validates the client. It also sorts the client tags.
func (clients *clientsContainer) check(c *persistentClient) (err error) {
func (clients *clientsContainer) check(c *client.Persistent) (err error) {
switch {
case c == nil:
return errors.Error("client is nil")
case c.Name == "":
return errors.Error("invalid name")
case c.idsLen() == 0:
case c.IDsLen() == 0:
return errors.Error("id required")
default:
// Go on.
@@ -613,7 +606,7 @@ func (clients *clientsContainer) check(c *persistentClient) (err error) {
// TODO(s.chzhen): Move to the constructor.
slices.Sort(c.Tags)
err = dnsforward.ValidateUpstreams(c.Upstreams)
_, err = proxy.ParseUpstreamsConfig(c.Upstreams, &upstream.Options{})
if err != nil {
return fmt.Errorf("invalid upstream servers: %w", err)
}
@@ -623,7 +616,7 @@ func (clients *clientsContainer) check(c *persistentClient) (err error) {
// add adds a new client object. ok is false if such client already exists or
// if an error occurred.
func (clients *clientsContainer) add(c *persistentClient) (ok bool, err error) {
func (clients *clientsContainer) add(c *client.Persistent) (ok bool, err error) {
err = clients.check(c)
if err != nil {
return false, err
@@ -639,31 +632,26 @@ func (clients *clientsContainer) add(c *persistentClient) (ok bool, err error) {
}
// check ID index
ids := c.ids()
for _, id := range ids {
var c2 *persistentClient
c2, ok = clients.idIndex[id]
if ok {
return false, fmt.Errorf("another client uses the same ID (%q): %q", id, c2.Name)
}
err = clients.clientIndex.Clashes(c)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return false, err
}
clients.addLocked(c)
log.Debug("clients: added %q: ID:%q [%d]", c.Name, ids, len(clients.list))
log.Debug("clients: added %q: ID:%q [%d]", c.Name, c.IDs(), len(clients.list))
return true, nil
}
// addLocked c to the indexes. clients.lock is expected to be locked.
func (clients *clientsContainer) addLocked(c *persistentClient) {
func (clients *clientsContainer) addLocked(c *client.Persistent) {
// update Name index
clients.list[c.Name] = c
// update ID index
for _, id := range c.ids() {
clients.idIndex[id] = c
}
clients.clientIndex.Add(c)
}
// remove removes a client. ok is false if there is no such client.
@@ -671,7 +659,7 @@ func (clients *clientsContainer) remove(name string) (ok bool) {
clients.lock.Lock()
defer clients.lock.Unlock()
var c *persistentClient
var c *client.Persistent
c, ok = clients.list[name]
if !ok {
return false
@@ -684,8 +672,8 @@ func (clients *clientsContainer) remove(name string) (ok bool) {
// removeLocked removes c from the indexes. clients.lock is expected to be
// locked.
func (clients *clientsContainer) removeLocked(c *persistentClient) {
if err := c.closeUpstreams(); err != nil {
func (clients *clientsContainer) removeLocked(c *client.Persistent) {
if err := c.CloseUpstreams(); err != nil {
log.Error("client container: removing client %s: %s", c.Name, err)
}
@@ -693,13 +681,11 @@ func (clients *clientsContainer) removeLocked(c *persistentClient) {
delete(clients.list, c.Name)
// Update the ID index.
for _, id := range c.ids() {
delete(clients.idIndex, id)
}
clients.clientIndex.Delete(c)
}
// update updates a client by its name.
func (clients *clientsContainer) update(prev, c *persistentClient) (err error) {
func (clients *clientsContainer) update(prev, c *client.Persistent) (err error) {
err = clients.check(c)
if err != nil {
// Don't wrap the error since it's informative enough as is.
@@ -717,7 +703,7 @@ func (clients *clientsContainer) update(prev, c *persistentClient) (err error) {
}
}
if c.equalIDs(prev) {
if c.EqualIDs(prev) {
clients.removeLocked(prev)
clients.addLocked(c)
@@ -725,11 +711,10 @@ func (clients *clientsContainer) update(prev, c *persistentClient) (err error) {
}
// Check the ID index.
for _, id := range c.ids() {
existing, ok := clients.idIndex[id]
if ok && existing != prev {
return fmt.Errorf("id %q is used by client with name %q", id, existing.Name)
}
err = clients.clientIndex.Clashes(c)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return err
}
clients.removeLocked(prev)
@@ -906,14 +891,14 @@ func (clients *clientsContainer) addFromSystemARP() {
// the persistent clients.
func (clients *clientsContainer) close() (err error) {
persistent := maps.Values(clients.list)
slices.SortFunc(persistent, func(a, b *persistentClient) (res int) {
slices.SortFunc(persistent, func(a, b *client.Persistent) (res int) {
return strings.Compare(a.Name, b.Name)
})
var errs []error
for _, cli := range persistent {
if err = cli.closeUpstreams(); err != nil {
if err = cli.CloseUpstreams(); err != nil {
errs = append(errs, err)
}
}

View File

@@ -66,8 +66,9 @@ func TestClients(t *testing.T) {
cliIPv6 = netip.MustParseAddr("1:2:3::4")
)
c := &persistentClient{
c := &client.Persistent{
Name: "client1",
UID: client.MustNewUID(),
IPs: []netip.Addr{cli1IP, cliIPv6},
}
@@ -76,8 +77,9 @@ func TestClients(t *testing.T) {
assert.True(t, ok)
c = &persistentClient{
c = &client.Persistent{
Name: "client2",
UID: client.MustNewUID(),
IPs: []netip.Addr{cli2IP},
}
@@ -109,8 +111,9 @@ func TestClients(t *testing.T) {
})
t.Run("add_fail_name", func(t *testing.T) {
ok, err := clients.add(&persistentClient{
ok, err := clients.add(&client.Persistent{
Name: "client1",
UID: client.MustNewUID(),
IPs: []netip.Addr{netip.MustParseAddr("1.2.3.5")},
})
require.NoError(t, err)
@@ -118,16 +121,18 @@ func TestClients(t *testing.T) {
})
t.Run("add_fail_ip", func(t *testing.T) {
ok, err := clients.add(&persistentClient{
ok, err := clients.add(&client.Persistent{
Name: "client3",
UID: client.MustNewUID(),
})
require.Error(t, err)
assert.False(t, ok)
})
t.Run("update_fail_ip", func(t *testing.T) {
err := clients.update(&persistentClient{Name: "client1"}, &persistentClient{
err := clients.update(&client.Persistent{Name: "client1"}, &client.Persistent{
Name: "client1",
UID: client.MustNewUID(),
})
assert.Error(t, err)
})
@@ -143,8 +148,9 @@ func TestClients(t *testing.T) {
prev, ok := clients.list["client1"]
require.True(t, ok)
err := clients.update(prev, &persistentClient{
err := clients.update(prev, &client.Persistent{
Name: "client1",
UID: client.MustNewUID(),
IPs: []netip.Addr{cliNewIP},
})
require.NoError(t, err)
@@ -157,8 +163,9 @@ func TestClients(t *testing.T) {
prev, ok = clients.list["client1"]
require.True(t, ok)
err = clients.update(prev, &persistentClient{
err = clients.update(prev, &client.Persistent{
Name: "client1-renamed",
UID: client.MustNewUID(),
IPs: []netip.Addr{cliNewIP},
UseOwnSettings: true,
})
@@ -175,7 +182,7 @@ func TestClients(t *testing.T) {
assert.Nil(t, nilCli)
require.Len(t, c.ids(), 1)
require.Len(t, c.IDs(), 1)
assert.Equal(t, cliNewIP, c.IPs[0])
})
@@ -258,8 +265,9 @@ func TestClientsWHOIS(t *testing.T) {
t.Run("can't_set_manually-added", func(t *testing.T) {
ip := netip.MustParseAddr("1.1.1.2")
ok, err := clients.add(&persistentClient{
ok, err := clients.add(&client.Persistent{
Name: "client1",
UID: client.MustNewUID(),
IPs: []netip.Addr{netip.MustParseAddr("1.1.1.2")},
})
require.NoError(t, err)
@@ -280,8 +288,9 @@ func TestClientsAddExisting(t *testing.T) {
ip := netip.MustParseAddr("1.1.1.1")
// Add a client.
ok, err := clients.add(&persistentClient{
ok, err := clients.add(&client.Persistent{
Name: "client1",
UID: client.MustNewUID(),
IPs: []netip.Addr{ip, netip.MustParseAddr("1:2:3::4")},
Subnets: []netip.Prefix{netip.MustParsePrefix("2.2.2.0/24")},
MACs: []net.HardwareAddr{{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}},
@@ -330,16 +339,18 @@ func TestClientsAddExisting(t *testing.T) {
require.NoError(t, err)
// Add a new client with the same IP as for a client with MAC.
ok, err := clients.add(&persistentClient{
ok, err := clients.add(&client.Persistent{
Name: "client2",
UID: client.MustNewUID(),
IPs: []netip.Addr{ip},
})
require.NoError(t, err)
assert.True(t, ok)
// Add a new client with the IP from the first client's IP range.
ok, err = clients.add(&persistentClient{
ok, err = clients.add(&client.Persistent{
Name: "client3",
UID: client.MustNewUID(),
IPs: []netip.Addr{netip.MustParseAddr("2.2.2.2")},
})
require.NoError(t, err)
@@ -351,8 +362,9 @@ func TestClientsCustomUpstream(t *testing.T) {
clients := newClientsContainer(t)
// Add client with upstreams.
ok, err := clients.add(&persistentClient{
ok, err := clients.add(&client.Persistent{
Name: "client1",
UID: client.MustNewUID(),
IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1"), netip.MustParseAddr("1:2:3::4")},
Upstreams: []string{
"1.1.1.1",

View File

@@ -131,9 +131,9 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
// initPrev initializes the persistent client with the default or previous
// client properties.
func initPrev(cj clientJSON, prev *persistentClient) (c *persistentClient, err error) {
func initPrev(cj clientJSON, prev *client.Persistent) (c *client.Persistent, err error) {
var (
uid UID
uid client.UID
ignoreQueryLog bool
ignoreStatistics bool
upsCacheEnabled bool
@@ -166,14 +166,14 @@ func initPrev(cj clientJSON, prev *persistentClient) (c *persistentClient, err e
return nil, fmt.Errorf("invalid blocked services: %w", err)
}
if (uid == UID{}) {
uid, err = NewUID()
if (uid == client.UID{}) {
uid, err = client.NewUID()
if err != nil {
return nil, fmt.Errorf("generating uid: %w", err)
}
}
return &persistentClient{
return &client.Persistent{
BlockedServices: svcs,
UID: uid,
IgnoreQueryLog: ignoreQueryLog,
@@ -187,21 +187,21 @@ func initPrev(cj clientJSON, prev *persistentClient) (c *persistentClient, err e
// errors.
func (clients *clientsContainer) jsonToClient(
cj clientJSON,
prev *persistentClient,
) (c *persistentClient, err error) {
prev *client.Persistent,
) (c *client.Persistent, err error) {
c, err = initPrev(cj, prev)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return nil, err
}
err = c.setIDs(cj.IDs)
err = c.SetIDs(cj.IDs)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return nil, err
}
c.safeSearchConf = copySafeSearch(cj.SafeSearchConf, cj.SafeSearchEnabled)
c.SafeSearchConf = copySafeSearch(cj.SafeSearchConf, cj.SafeSearchEnabled)
c.Name = cj.Name
c.Tags = cj.Tags
c.Upstreams = cj.Upstreams
@@ -211,9 +211,9 @@ func (clients *clientsContainer) jsonToClient(
c.SafeBrowsingEnabled = cj.SafeBrowsingEnabled
c.UseOwnBlockedServices = !cj.UseGlobalBlockedServices
if c.safeSearchConf.Enabled {
err = c.setSafeSearch(
c.safeSearchConf,
if c.SafeSearchConf.Enabled {
err = c.SetSafeSearch(
c.SafeSearchConf,
clients.safeSearchCacheSize,
clients.safeSearchCacheTTL,
)
@@ -258,7 +258,7 @@ func copySafeSearch(
func copyBlockedServices(
sch *schedule.Weekly,
svcStrs []string,
prev *persistentClient,
prev *client.Persistent,
) (svcs *filtering.BlockedServices, err error) {
var weekly *schedule.Weekly
if sch != nil {
@@ -283,15 +283,15 @@ func copyBlockedServices(
}
// clientToJSON converts persistent client object to JSON object.
func clientToJSON(c *persistentClient) (cj *clientJSON) {
func clientToJSON(c *client.Persistent) (cj *clientJSON) {
// TODO(d.kolyshev): Remove after cleaning the deprecated
// [clientJSON.SafeSearchEnabled] field.
cloneVal := c.safeSearchConf
cloneVal := c.SafeSearchConf
safeSearchConf := &cloneVal
return &clientJSON{
Name: c.Name,
IDs: c.ids(),
IDs: c.IDs(),
Tags: c.Tags,
UseGlobalSettings: !c.UseOwnSettings,
FilteringEnabled: c.FilteringEnabled,
@@ -397,7 +397,7 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht
return
}
var prev *persistentClient
var prev *client.Persistent
var ok bool
func() {

View File

@@ -232,6 +232,10 @@ type dnsConfig struct {
// ServePlainDNS defines if plain DNS is allowed for incoming requests.
ServePlainDNS bool `yaml:"serve_plain_dns"`
// HostsFileEnabled defines whether to use information from the system hosts
// file to resolve queries.
HostsFileEnabled bool `yaml:"hostsfile_enabled"`
}
type tlsConfigSettings struct {
@@ -259,6 +263,10 @@ type tlsConfigSettings struct {
}
type queryLogConfig struct {
// DirPath is the custom directory for logs. If it's empty the default
// directory will be used. See [homeContext.getDataDir].
DirPath string `yaml:"dir_path"`
// Ignored is the list of host names, which should not be written to log.
// "." is considered to be the root domain.
Ignored []string `yaml:"ignored"`
@@ -278,6 +286,10 @@ type queryLogConfig struct {
}
type statsConfig struct {
// DirPath is the custom directory for statistics. If it's empty the
// default directory is used. See [homeContext.getDataDir].
DirPath string `yaml:"dir_path"`
// Ignored is the list of host names, which should not be counted.
Ignored []string `yaml:"ignored"`
@@ -341,9 +353,10 @@ var config = &configuration{
// was later increased to 300 due to https://github.com/AdguardTeam/AdGuardHome/issues/2257
MaxGoroutines: 300,
},
UpstreamTimeout: timeutil.Duration{Duration: dnsforward.DefaultTimeout},
UsePrivateRDNS: true,
ServePlainDNS: true,
UpstreamTimeout: timeutil.Duration{Duration: dnsforward.DefaultTimeout},
UsePrivateRDNS: true,
ServePlainDNS: true,
HostsFileEnabled: true,
},
TLS: tlsConfigSettings{
PortHTTPS: defaultPortHTTPS,
@@ -443,20 +456,25 @@ var config = &configuration{
Theme: ThemeAuto,
}
// getConfigFilename returns path to the current config file
func (c *configuration) getConfigFilename() string {
configFile, err := filepath.EvalSymlinks(Context.configFilename)
// configFilePath returns the absolute path to the symlink-evaluated path to the
// current config file.
func configFilePath() (confPath string) {
confPath, err := filepath.EvalSymlinks(Context.confFilePath)
if err != nil {
if !errors.Is(err, os.ErrNotExist) {
log.Error("unexpected error while config file path evaluation: %s", err)
confPath = Context.confFilePath
logFunc := log.Error
if errors.Is(err, os.ErrNotExist) {
logFunc = log.Debug
}
configFile = Context.configFilename
}
if !filepath.IsAbs(configFile) {
configFile = filepath.Join(Context.workDir, configFile)
logFunc("evaluating config path: %s; using %q", err, confPath)
}
return configFile
if !filepath.IsAbs(confPath) {
confPath = filepath.Join(Context.workDir, confPath)
}
return confPath
}
// validateBindHosts returns error if any of binding hosts from configuration is
@@ -497,7 +515,10 @@ func parseConfig() (err error) {
// Don't wrap the error, because it's informative enough as is.
return err
} else if upgraded {
err = maybe.WriteFile(config.getConfigFilename(), config.fileData, 0o644)
confPath := configFilePath()
log.Debug("writing config file %q after config upgrade", confPath)
err = maybe.WriteFile(confPath, config.fileData, 0o644)
if err != nil {
return fmt.Errorf("writing new config: %w", err)
}
@@ -518,12 +539,8 @@ func parseConfig() (err error) {
config.DNS.UpstreamTimeout = timeutil.Duration{Duration: dnsforward.DefaultTimeout}
}
err = setContextTLSCipherIDs()
if err != nil {
return err
}
return nil
// Do not wrap the error because it's informative enough as is.
return setContextTLSCipherIDs()
}
// validateConfig returns error if the configuration is invalid.
@@ -587,11 +604,11 @@ func readConfigFile() (fileData []byte, err error) {
return config.fileData, nil
}
name := config.getConfigFilename()
log.Debug("reading config file: %s", name)
confPath := configFilePath()
log.Debug("reading config file %q", confPath)
// Do not wrap the error because it's informative enough as is.
return os.ReadFile(name)
return os.ReadFile(confPath)
}
// Saves configuration to the YAML file and also saves the user filter contents to a file
@@ -655,8 +672,8 @@ func (c *configuration) write() (err error) {
config.Clients.Persistent = Context.clients.forConfig()
configFile := config.getConfigFilename()
log.Debug("writing config file %q", configFile)
confPath := configFilePath()
log.Debug("writing config file %q", confPath)
buf := &bytes.Buffer{}
enc := yaml.NewEncoder(buf)
@@ -667,7 +684,7 @@ func (c *configuration) write() (err error) {
return fmt.Errorf("generating config file: %w", err)
}
err = maybe.WriteFile(configFile, buf.Bytes(), 0o644)
err = maybe.WriteFile(confPath, buf.Bytes(), 0o644)
if err != nil {
return fmt.Errorf("writing config file: %w", err)
}

View File

@@ -144,10 +144,7 @@ func handleStatus(w http.ResponseWriter, r *http.Request) {
// Make sure that we don't send negative numbers to the frontend,
// since enough time might have passed to make the difference less
// than zero.
protectionDisabledDuration = max(
0,
time.Until(*protectionDisabledUntil).Milliseconds(),
)
protectionDisabledDuration = max(0, time.Until(*protectionDisabledUntil).Milliseconds())
}
resp = statusResponse{

View File

@@ -18,6 +18,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
"github.com/AdguardTeam/AdGuardHome/internal/stats"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
@@ -46,12 +47,15 @@ func onConfigModified() {
// server and initializes it at last. It also must not be called unless
// [config] and [Context] are initialized.
func initDNS() (err error) {
baseDir := Context.getDataDir()
anonymizer := config.anonymizer()
statsDir, querylogDir, err := checkStatsAndQuerylogDirs(&Context, config)
if err != nil {
return err
}
statsConf := stats.Config{
Filename: filepath.Join(baseDir, "stats.db"),
Filename: filepath.Join(statsDir, "stats.db"),
Limit: config.Stats.Interval.Duration,
ConfigModified: onConfigModified,
HTTPRegister: httpRegister,
@@ -75,7 +79,7 @@ func initDNS() (err error) {
ConfigModified: onConfigModified,
HTTPRegister: httpRegister,
FindClient: Context.clients.findMultiple,
BaseDir: baseDir,
BaseDir: querylogDir,
AnonymizeClientIP: config.DNS.AnonymizeClientIP,
RotationIvl: config.QueryLog.Interval.Duration,
MemSize: config.QueryLog.MemSize,
@@ -154,6 +158,17 @@ func initDNSServer(
}
err = Context.dnsServer.Prepare(dnsConf)
// TODO(e.burkov): Recreate the server with private RDNS disabled. This
// should go away once the private RDNS resolution is moved to the proxy.
var locResErr *dnsforward.LocalResolversError
if errors.As(err, &locResErr) && errors.Is(locResErr.Err, upstream.ErrNoUpstreams) {
log.Info("WARNING: no local resolvers configured while private RDNS " +
"resolution enabled, trying to disable")
dnsConf.UsePrivateRDNS = false
err = Context.dnsServer.Prepare(dnsConf)
}
if err != nil {
return fmt.Errorf("dnsServer.Prepare: %w", err)
}
@@ -424,7 +439,7 @@ func applyAdditionalFiltering(clientIP netip.Addr, clientID string, setts *filte
}
setts.FilteringEnabled = c.FilteringEnabled
setts.SafeSearchEnabled = c.safeSearchConf.Enabled
setts.SafeSearchEnabled = c.SafeSearchConf.Enabled
setts.ClientSafeSearch = c.SafeSearch
setts.SafeBrowsingEnabled = c.SafeBrowsingEnabled
setts.ParentalEnabled = c.ParentalEnabled
@@ -545,3 +560,50 @@ func (r safeSearchResolver) LookupIP(
return ips, nil
}
// checkStatsAndQuerylogDirs checks and returns directory paths to store
// statistics and query log.
func checkStatsAndQuerylogDirs(
ctx *homeContext,
conf *configuration,
) (statsDir, querylogDir string, err error) {
baseDir := ctx.getDataDir()
statsDir = conf.Stats.DirPath
if statsDir == "" {
statsDir = baseDir
} else {
err = checkDir(statsDir)
if err != nil {
return "", "", fmt.Errorf("statistics: custom directory: %w", err)
}
}
querylogDir = conf.QueryLog.DirPath
if querylogDir == "" {
querylogDir = baseDir
} else {
err = checkDir(querylogDir)
if err != nil {
return "", "", fmt.Errorf("querylog: custom directory: %w", err)
}
}
return statsDir, querylogDir, nil
}
// checkDir checks if the path is a directory. It's used to check for
// misconfiguration at startup.
func checkDir(path string) (err error) {
var fi os.FileInfo
if fi, err = os.Stat(path); err != nil {
// Don't wrap the error, since it's informative enough as is.
return err
}
if !fi.IsDir() {
return fmt.Errorf("%q is not a directory", path)
}
return nil
}

View File

@@ -4,6 +4,7 @@ import (
"net/netip"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
"github.com/stretchr/testify/assert"
@@ -12,6 +13,19 @@ import (
var testIPv4 = netip.AddrFrom4([4]byte{1, 2, 3, 4})
// newIDIndex is a helper function that returns a client index filled with
// persistent clients from the m. It also generates a UID for each client.
func newIDIndex(m []*client.Persistent) (ci *client.Index) {
ci = client.NewIndex()
for _, c := range m {
c.UID = client.MustNewUID()
ci.Add(c)
}
return ci
}
func TestApplyAdditionalFiltering(t *testing.T) {
var err error
@@ -22,29 +36,28 @@ func TestApplyAdditionalFiltering(t *testing.T) {
}, nil)
require.NoError(t, err)
Context.clients.idIndex = map[string]*persistentClient{
"default": {
UseOwnSettings: false,
safeSearchConf: filtering.SafeSearchConfig{Enabled: false},
FilteringEnabled: false,
SafeBrowsingEnabled: false,
ParentalEnabled: false,
},
"custom_filtering": {
UseOwnSettings: true,
safeSearchConf: filtering.SafeSearchConfig{Enabled: true},
FilteringEnabled: true,
SafeBrowsingEnabled: true,
ParentalEnabled: true,
},
"partial_custom_filtering": {
UseOwnSettings: true,
safeSearchConf: filtering.SafeSearchConfig{Enabled: true},
FilteringEnabled: true,
SafeBrowsingEnabled: false,
ParentalEnabled: false,
},
}
Context.clients.clientIndex = newIDIndex([]*client.Persistent{{
ClientIDs: []string{"default"},
UseOwnSettings: false,
SafeSearchConf: filtering.SafeSearchConfig{Enabled: false},
FilteringEnabled: false,
SafeBrowsingEnabled: false,
ParentalEnabled: false,
}, {
ClientIDs: []string{"custom_filtering"},
UseOwnSettings: true,
SafeSearchConf: filtering.SafeSearchConfig{Enabled: true},
FilteringEnabled: true,
SafeBrowsingEnabled: true,
ParentalEnabled: true,
}, {
ClientIDs: []string{"partial_custom_filtering"},
UseOwnSettings: true,
SafeSearchConf: filtering.SafeSearchConfig{Enabled: true},
FilteringEnabled: true,
SafeBrowsingEnabled: false,
ParentalEnabled: false,
}})
testCases := []struct {
name string
@@ -108,38 +121,37 @@ func TestApplyAdditionalFiltering_blockedServices(t *testing.T) {
}, nil)
require.NoError(t, err)
Context.clients.idIndex = map[string]*persistentClient{
"default": {
UseOwnBlockedServices: false,
Context.clients.clientIndex = newIDIndex([]*client.Persistent{{
ClientIDs: []string{"default"},
UseOwnBlockedServices: false,
}, {
ClientIDs: []string{"no_services"},
BlockedServices: &filtering.BlockedServices{
Schedule: schedule.EmptyWeekly(),
},
"no_services": {
BlockedServices: &filtering.BlockedServices{
Schedule: schedule.EmptyWeekly(),
},
UseOwnBlockedServices: true,
UseOwnBlockedServices: true,
}, {
ClientIDs: []string{"services"},
BlockedServices: &filtering.BlockedServices{
Schedule: schedule.EmptyWeekly(),
IDs: clientBlockedServices,
},
"services": {
BlockedServices: &filtering.BlockedServices{
Schedule: schedule.EmptyWeekly(),
IDs: clientBlockedServices,
},
UseOwnBlockedServices: true,
UseOwnBlockedServices: true,
}, {
ClientIDs: []string{"invalid_services"},
BlockedServices: &filtering.BlockedServices{
Schedule: schedule.EmptyWeekly(),
IDs: invalidBlockedServices,
},
"invalid_services": {
BlockedServices: &filtering.BlockedServices{
Schedule: schedule.EmptyWeekly(),
IDs: invalidBlockedServices,
},
UseOwnBlockedServices: true,
UseOwnBlockedServices: true,
}, {
ClientIDs: []string{"allow_all"},
BlockedServices: &filtering.BlockedServices{
Schedule: schedule.FullWeekly(),
IDs: clientBlockedServices,
},
"allow_all": {
BlockedServices: &filtering.BlockedServices{
Schedule: schedule.FullWeekly(),
IDs: clientBlockedServices,
},
UseOwnBlockedServices: true,
},
}
UseOwnBlockedServices: true,
}})
testCases := []struct {
name string

View File

@@ -14,6 +14,7 @@ import (
"path"
"path/filepath"
"runtime"
"slices"
"sync"
"syscall"
"time"
@@ -39,8 +40,6 @@ import (
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/osutil"
"github.com/AdguardTeam/golibs/stringutil"
"golang.org/x/exp/slices"
)
// Global context
@@ -68,11 +67,14 @@ type homeContext struct {
// Runtime properties
// --
configFilename string // Config filename (can be overridden via the command line arguments)
workDir string // Location of our directory, used to protect against CWD being somewhere else
pidFileName string // PID file name. Empty if no PID file was created.
controlLock sync.Mutex
tlsRoots *x509.CertPool // list of root CAs for TLSv1.2
// confFilePath is the configuration file path as set by default or from the
// command-line options.
confFilePath string
workDir string // Location of our directory, used to protect against CWD being somewhere else
pidFileName string // PID file name. Empty if no PID file was created.
controlLock sync.Mutex
tlsRoots *x509.CertPool // list of root CAs for TLSv1.2
// tlsCipherIDs are the ID of the cipher suites that AdGuard Home must use.
tlsCipherIDs []uint16
@@ -250,7 +252,7 @@ func setupHostsContainer() (err error) {
return errors.Join(fmt.Errorf("initializing hosts container: %w", err), closeErr)
}
return nil
return hostsWatcher.Start()
}
// setupOpts sets up command-line options.
@@ -361,7 +363,7 @@ func setupDNSFilteringConf(conf *filtering.Config) (err error) {
conf.EtcHosts = Context.etcHosts
// TODO(s.chzhen): Use empty interface.
if Context.etcHosts == nil {
if Context.etcHosts == nil || !config.DNS.HostsFileEnabled {
conf.EtcHosts = nil
}
@@ -492,7 +494,14 @@ func initWeb(opts options, clientBuildFS fs.FS, upd *updater.Updater) (web *webA
}
}
disableUpdate := opts.disableUpdate || version.Channel() == version.ChannelDevelopment
disableUpdate := opts.disableUpdate
switch version.Channel() {
case
version.ChannelDevelopment,
version.ChannelCandidate:
disableUpdate = true
}
if disableUpdate {
log.Info("AdGuard Home updates are disabled")
}
@@ -575,6 +584,9 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
Path: path.Join("adguardhome", version.Channel(), "version.json"),
}
confPath := configFilePath()
log.Debug("using config path %q for updater", confPath)
upd := updater.NewUpdater(&updater.Config{
Client: config.Filtering.HTTPClient,
Version: version.Version(),
@@ -584,7 +596,7 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
GOARM: version.GOARM(),
GOMIPS: version.GOMIPS(),
WorkDir: Context.workDir,
ConfName: config.getConfigFilename(),
ConfName: confPath,
ExecPath: execPath,
VersionCheckURL: u.String(),
})
@@ -748,7 +760,16 @@ func writePIDFile(fn string) bool {
// initConfigFilename sets up context config file path. This file path can be
// overridden by command-line arguments, or is set to default.
func initConfigFilename(opts options) {
Context.configFilename = stringutil.Coalesce(opts.confFilename, "AdGuardHome.yaml")
confPath := opts.confFilename
if confPath == "" {
Context.confFilePath = "AdGuardHome.yaml"
return
}
log.Debug("config path overridden to %q from cmdline", confPath)
Context.confFilePath = confPath
}
// initWorkingDir initializes the workDir. If no command-line arguments are
@@ -906,16 +927,23 @@ func printHTTPAddresses(proto string) {
}
}
// -------------------
// first run / install
// -------------------
func detectFirstRun() bool {
configfile := Context.configFilename
if !filepath.IsAbs(configfile) {
configfile = filepath.Join(Context.workDir, Context.configFilename)
// detectFirstRun returns true if this is the first run of AdGuard Home.
func detectFirstRun() (ok bool) {
confPath := Context.confFilePath
if !filepath.IsAbs(confPath) {
confPath = filepath.Join(Context.workDir, Context.confFilePath)
}
_, err := os.Stat(configfile)
return errors.Is(err, os.ErrNotExist)
_, err := os.Stat(confPath)
if err == nil {
return false
} else if errors.Is(err, os.ErrNotExist) {
return true
}
log.Error("detecting first run: %s; considering first run", err)
return true
}
// jsonError is a generic JSON error response.

Some files were not shown because too many files have changed in this diff Show More