Compare commits
2 Commits
master
...
AGDNS-2743
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6109e3575f | ||
|
|
88706e9cf2 |
@@ -20,14 +20,10 @@ NOTE: Add new changes BELOW THIS COMMENT.
|
|||||||
|
|
||||||
### Fixed
|
### Fixed
|
||||||
|
|
||||||
- Command line option `--update` when the `dns.serve_plain_dns` configuration property was disabled ([7801]).
|
|
||||||
|
|
||||||
- DNS cache not working for custom upstream configurations.
|
- DNS cache not working for custom upstream configurations.
|
||||||
|
|
||||||
- Validation process for the DNS-over-TLS, DNS-over-QUIC, and HTTPS ports on the *Encryption Settings* page.
|
- Validation process for the DNS-over-TLS, DNS-over-QUIC, and HTTPS ports on the *Encryption Settings* page.
|
||||||
|
|
||||||
[#7801]: https://github.com/AdguardTeam/AdGuardHome/issues/7801
|
|
||||||
|
|
||||||
<!--
|
<!--
|
||||||
NOTE: Add new changes ABOVE THIS COMMENT.
|
NOTE: Add new changes ABOVE THIS COMMENT.
|
||||||
-->
|
-->
|
||||||
|
|||||||
12
go.mod
12
go.mod
@@ -3,8 +3,8 @@ module github.com/AdguardTeam/AdGuardHome
|
|||||||
go 1.24.2
|
go 1.24.2
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/AdguardTeam/dnsproxy v0.75.5
|
github.com/AdguardTeam/dnsproxy v0.75.4
|
||||||
github.com/AdguardTeam/golibs v0.32.9
|
github.com/AdguardTeam/golibs v0.32.8
|
||||||
github.com/AdguardTeam/urlfilter v0.20.0
|
github.com/AdguardTeam/urlfilter v0.20.0
|
||||||
github.com/NYTimes/gziphandler v1.1.1
|
github.com/NYTimes/gziphandler v1.1.1
|
||||||
github.com/ameshkov/dnscrypt/v2 v2.4.0
|
github.com/ameshkov/dnscrypt/v2 v2.4.0
|
||||||
@@ -36,7 +36,7 @@ require (
|
|||||||
golang.org/x/crypto v0.37.0
|
golang.org/x/crypto v0.37.0
|
||||||
golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0
|
golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0
|
||||||
golang.org/x/net v0.39.0
|
golang.org/x/net v0.39.0
|
||||||
golang.org/x/sys v0.33.0
|
golang.org/x/sys v0.32.0
|
||||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
||||||
gopkg.in/yaml.v3 v3.0.1
|
gopkg.in/yaml.v3 v3.0.1
|
||||||
howett.net/plist v1.0.1
|
howett.net/plist v1.0.1
|
||||||
@@ -61,7 +61,7 @@ require (
|
|||||||
github.com/go-task/slim-sprig/v3 v3.0.0 // indirect
|
github.com/go-task/slim-sprig/v3 v3.0.0 // indirect
|
||||||
github.com/golangci/misspell v0.6.0 // indirect
|
github.com/golangci/misspell v0.6.0 // indirect
|
||||||
github.com/google/generative-ai-go v0.19.0 // indirect
|
github.com/google/generative-ai-go v0.19.0 // indirect
|
||||||
github.com/google/pprof v0.0.0-20250501235452-c0086092b71a // indirect
|
github.com/google/pprof v0.0.0-20250403155104-27863c87afa6 // indirect
|
||||||
github.com/google/s2a-go v0.1.9 // indirect
|
github.com/google/s2a-go v0.1.9 // indirect
|
||||||
github.com/googleapis/enterprise-certificate-proxy v0.3.6 // indirect
|
github.com/googleapis/enterprise-certificate-proxy v0.3.6 // indirect
|
||||||
github.com/googleapis/gax-go/v2 v2.14.1 // indirect
|
github.com/googleapis/gax-go/v2 v2.14.1 // indirect
|
||||||
@@ -89,11 +89,11 @@ require (
|
|||||||
go.opentelemetry.io/otel/metric v1.35.0 // indirect
|
go.opentelemetry.io/otel/metric v1.35.0 // indirect
|
||||||
go.opentelemetry.io/otel/trace v1.35.0 // indirect
|
go.opentelemetry.io/otel/trace v1.35.0 // indirect
|
||||||
go.uber.org/automaxprocs v1.6.0 // indirect
|
go.uber.org/automaxprocs v1.6.0 // indirect
|
||||||
go.uber.org/mock v0.5.2 // indirect
|
go.uber.org/mock v0.5.1 // indirect
|
||||||
golang.org/x/exp/typeparams v0.0.0-20250408133849-7e4ce0ab07d0 // indirect
|
golang.org/x/exp/typeparams v0.0.0-20250408133849-7e4ce0ab07d0 // indirect
|
||||||
golang.org/x/mod v0.24.0 // indirect
|
golang.org/x/mod v0.24.0 // indirect
|
||||||
golang.org/x/oauth2 v0.29.0 // indirect
|
golang.org/x/oauth2 v0.29.0 // indirect
|
||||||
golang.org/x/sync v0.14.0 // indirect
|
golang.org/x/sync v0.13.0 // indirect
|
||||||
golang.org/x/telemetry v0.0.0-20250417124945-06ef541f3fa3 // indirect
|
golang.org/x/telemetry v0.0.0-20250417124945-06ef541f3fa3 // indirect
|
||||||
golang.org/x/term v0.31.0 // indirect
|
golang.org/x/term v0.31.0 // indirect
|
||||||
golang.org/x/text v0.24.0 // indirect
|
golang.org/x/text v0.24.0 // indirect
|
||||||
|
|||||||
24
go.sum
24
go.sum
@@ -10,10 +10,10 @@ cloud.google.com/go/compute/metadata v0.6.0 h1:A6hENjEsCDtC1k8byVsgwvVcioamEHvZ4
|
|||||||
cloud.google.com/go/compute/metadata v0.6.0/go.mod h1:FjyFAW1MW0C203CEOMDTu3Dk1FlqW3Rga40jzHL4hfg=
|
cloud.google.com/go/compute/metadata v0.6.0/go.mod h1:FjyFAW1MW0C203CEOMDTu3Dk1FlqW3Rga40jzHL4hfg=
|
||||||
cloud.google.com/go/longrunning v0.6.7 h1:IGtfDWHhQCgCjwQjV9iiLnUta9LBCo8R9QmAFsS/PrE=
|
cloud.google.com/go/longrunning v0.6.7 h1:IGtfDWHhQCgCjwQjV9iiLnUta9LBCo8R9QmAFsS/PrE=
|
||||||
cloud.google.com/go/longrunning v0.6.7/go.mod h1:EAFV3IZAKmM56TyiE6VAP3VoTzhZzySwI/YI1s/nRsY=
|
cloud.google.com/go/longrunning v0.6.7/go.mod h1:EAFV3IZAKmM56TyiE6VAP3VoTzhZzySwI/YI1s/nRsY=
|
||||||
github.com/AdguardTeam/dnsproxy v0.75.5 h1:/P7+Ku4bjl+sVC/FW3PbT7pabgCjKTcrAOHqsZe2e60=
|
github.com/AdguardTeam/dnsproxy v0.75.4 h1:hTnHh9HoTYKKhKqePpIxCzfecl7dAXykZTw2gcj0I5U=
|
||||||
github.com/AdguardTeam/dnsproxy v0.75.5/go.mod h1:fdwtHhrDkTueDagDCasYKZbXdppkkBXW7RGPBNH+pis=
|
github.com/AdguardTeam/dnsproxy v0.75.4/go.mod h1:50OyTHao+uQzUJiXay08hgfvWQ3o2Q2WV99W8u8ypDE=
|
||||||
github.com/AdguardTeam/golibs v0.32.9 h1:/6luT0aMOn05/s9eh1yA4lbcHgl0d1iEEvEBbIMMUk0=
|
github.com/AdguardTeam/golibs v0.32.8 h1:O3mc3kYcPkW3kbmd+gqzFNgUka13a+iBgFLThwOYSQE=
|
||||||
github.com/AdguardTeam/golibs v0.32.9/go.mod h1:McV1QFFlKLElKa306V4OL/T2kr7564PhsayfvTWYBVs=
|
github.com/AdguardTeam/golibs v0.32.8/go.mod h1:McV1QFFlKLElKa306V4OL/T2kr7564PhsayfvTWYBVs=
|
||||||
github.com/AdguardTeam/urlfilter v0.20.0 h1:X32qiuVCVd8WDYCEsbdZKfXMzwdVqrdulamtUi4rmzs=
|
github.com/AdguardTeam/urlfilter v0.20.0 h1:X32qiuVCVd8WDYCEsbdZKfXMzwdVqrdulamtUi4rmzs=
|
||||||
github.com/AdguardTeam/urlfilter v0.20.0/go.mod h1:gjrywLTxfJh6JOkwi9SU+frhP7kVVEZ5exFGkR99qpk=
|
github.com/AdguardTeam/urlfilter v0.20.0/go.mod h1:gjrywLTxfJh6JOkwi9SU+frhP7kVVEZ5exFGkR99qpk=
|
||||||
github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg=
|
github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg=
|
||||||
@@ -72,8 +72,8 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
|||||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||||
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
|
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
|
||||||
github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
|
github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
|
||||||
github.com/google/pprof v0.0.0-20250501235452-c0086092b71a h1:rDA3FfmxwXR+BVKKdz55WwMJ1pD2hJQNW31d+l3mPk4=
|
github.com/google/pprof v0.0.0-20250403155104-27863c87afa6 h1:BHT72Gu3keYf3ZEu2J0b1vyeLSOYI8bm5wbJM/8yDe8=
|
||||||
github.com/google/pprof v0.0.0-20250501235452-c0086092b71a/go.mod h1:5hDyRhoBCxViHszMt12TnOpEI4VVi+U8Gm9iphldiMA=
|
github.com/google/pprof v0.0.0-20250403155104-27863c87afa6/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
||||||
github.com/google/renameio v0.1.0 h1:GOZbcHa3HfsPKPlmyPyN2KEohoMXOhdMbHrvbpl2QaA=
|
github.com/google/renameio v0.1.0 h1:GOZbcHa3HfsPKPlmyPyN2KEohoMXOhdMbHrvbpl2QaA=
|
||||||
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
|
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
|
||||||
github.com/google/renameio/v2 v2.0.0 h1:UifI23ZTGY8Tt29JbYFiuyIU3eX+RNFtUwefq9qAhxg=
|
github.com/google/renameio/v2 v2.0.0 h1:UifI23ZTGY8Tt29JbYFiuyIU3eX+RNFtUwefq9qAhxg=
|
||||||
@@ -199,8 +199,8 @@ go.opentelemetry.io/otel/trace v1.35.0 h1:dPpEfJu1sDIqruz7BHFG3c7528f6ddfSWfFDVt
|
|||||||
go.opentelemetry.io/otel/trace v1.35.0/go.mod h1:WUk7DtFp1Aw2MkvqGdwiXYDZZNvA/1J8o6xRXLrIkyc=
|
go.opentelemetry.io/otel/trace v1.35.0/go.mod h1:WUk7DtFp1Aw2MkvqGdwiXYDZZNvA/1J8o6xRXLrIkyc=
|
||||||
go.uber.org/automaxprocs v1.6.0 h1:O3y2/QNTOdbF+e/dpXNNW7Rx2hZ4sTIPyybbxyNqTUs=
|
go.uber.org/automaxprocs v1.6.0 h1:O3y2/QNTOdbF+e/dpXNNW7Rx2hZ4sTIPyybbxyNqTUs=
|
||||||
go.uber.org/automaxprocs v1.6.0/go.mod h1:ifeIMSnPZuznNm6jmdzmU3/bfk01Fe2fotchwEFJ8r8=
|
go.uber.org/automaxprocs v1.6.0/go.mod h1:ifeIMSnPZuznNm6jmdzmU3/bfk01Fe2fotchwEFJ8r8=
|
||||||
go.uber.org/mock v0.5.2 h1:LbtPTcP8A5k9WPXj54PPPbjcI4Y6lhyOZXn+VS7wNko=
|
go.uber.org/mock v0.5.1 h1:ASgazW/qBmR+A32MYFDB6E2POoTgOwT509VP0CT/fjs=
|
||||||
go.uber.org/mock v0.5.2/go.mod h1:wLlUxC2vVTPTaE3UD51E0BGOAElKrILxhVSDYQLld5o=
|
go.uber.org/mock v0.5.1/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM=
|
||||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||||
golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=
|
golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=
|
||||||
@@ -227,8 +227,8 @@ golang.org/x/oauth2 v0.29.0 h1:WdYw2tdTK1S8olAzWHdgeqfy+Mtm9XNhv/xJsY65d98=
|
|||||||
golang.org/x/oauth2 v0.29.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8=
|
golang.org/x/oauth2 v0.29.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8=
|
||||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ=
|
golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610=
|
||||||
golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20190322080309-f49334f85ddc/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20190322080309-f49334f85ddc/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
@@ -241,8 +241,8 @@ golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBc
|
|||||||
golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.0.0-20220209214540-3681064d5158/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220209214540-3681064d5158/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.4.1-0.20230131160137-e7d7f63158de/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.4.1-0.20230131160137-e7d7f63158de/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
|
golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20=
|
||||||
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||||
golang.org/x/telemetry v0.0.0-20250417124945-06ef541f3fa3 h1:RXY2+rSHXvxO2Y+gKrPjYVaEoGOqh3VEXFhnWAt1Irg=
|
golang.org/x/telemetry v0.0.0-20250417124945-06ef541f3fa3 h1:RXY2+rSHXvxO2Y+gKrPjYVaEoGOqh3VEXFhnWAt1Irg=
|
||||||
golang.org/x/telemetry v0.0.0-20250417124945-06ef541f3fa3/go.mod h1:RoaXAWDwS90j6FxVKwJdBV+0HCU+llrKUGgJaxiKl6M=
|
golang.org/x/telemetry v0.0.0-20250417124945-06ef541f3fa3/go.mod h1:RoaXAWDwS90j6FxVKwJdBV+0HCU+llrKUGgJaxiKl6M=
|
||||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"slices"
|
"slices"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -477,7 +478,7 @@ const ErrBadIdentifier errors.Error = "bad client identifier"
|
|||||||
func (p *FindParams) Set(id string) (err error) {
|
func (p *FindParams) Set(id string) (err error) {
|
||||||
*p = FindParams{}
|
*p = FindParams{}
|
||||||
|
|
||||||
isFound := false
|
isClientID := true
|
||||||
|
|
||||||
if netutil.IsValidIPString(id) {
|
if netutil.IsValidIPString(id) {
|
||||||
// It is safe to use [netip.MustParseAddr] because it has already been
|
// It is safe to use [netip.MustParseAddr] because it has already been
|
||||||
@@ -487,27 +488,24 @@ func (p *FindParams) Set(id string) (err error) {
|
|||||||
|
|
||||||
// Even if id can be parsed as an IP address, it may be a MAC address.
|
// Even if id can be parsed as an IP address, it may be a MAC address.
|
||||||
// So do not return prematurely, continue parsing.
|
// So do not return prematurely, continue parsing.
|
||||||
isFound = true
|
isClientID = false
|
||||||
}
|
}
|
||||||
|
|
||||||
if netutil.IsValidMACString(id) {
|
if canBeValidIPPrefixString(id) {
|
||||||
p.MAC, err = net.ParseMAC(id)
|
p.Subnet, err = netip.ParsePrefix(id)
|
||||||
if err != nil {
|
if err == nil {
|
||||||
panic(fmt.Errorf("parsing mac from %q: %w", id, err))
|
isClientID = false
|
||||||
}
|
}
|
||||||
|
|
||||||
isFound = true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if isFound {
|
if canBeMACString(id) {
|
||||||
return nil
|
p.MAC, err = net.ParseMAC(id)
|
||||||
|
if err == nil {
|
||||||
|
isClientID = false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if netutil.IsValidIPPrefixString(id) {
|
if !isClientID {
|
||||||
// It is safe to use [netip.MustParsePrefix] because it has already been
|
|
||||||
// validated that id contains the string representation of IP prefix.
|
|
||||||
p.Subnet = netip.MustParsePrefix(id)
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -520,6 +518,57 @@ func (p *FindParams) Set(id string) (err error) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// canBeValidIPPrefixString is a best-effort check to determine if s is a valid
|
||||||
|
// CIDR before using [netip.ParsePrefix], aimed at reducing allocations.
|
||||||
|
//
|
||||||
|
// TODO(s.chzhen): Replace this implementation with the more robust version
|
||||||
|
// from golibs.
|
||||||
|
func canBeValidIPPrefixString(s string) (ok bool) {
|
||||||
|
ipStr, bitStr, ok := strings.Cut(s, "/")
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if bitStr == "" || len(bitStr) > 3 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
bits := 0
|
||||||
|
for _, c := range bitStr {
|
||||||
|
if c < '0' || c > '9' {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
bits = bits*10 + int(c-'0')
|
||||||
|
}
|
||||||
|
|
||||||
|
if bits > 128 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return netutil.IsValidIPString(ipStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// canBeMACString is a best-effort check to determine if s is a valid MAC
|
||||||
|
// address before using [net.ParseMAC], aimed at reducing allocations.
|
||||||
|
//
|
||||||
|
// TODO(s.chzhen): Replace this implementation with the more robust version
|
||||||
|
// from golibs.
|
||||||
|
func canBeMACString(s string) (ok bool) {
|
||||||
|
switch len(s) {
|
||||||
|
case
|
||||||
|
len("0000.0000.0000"),
|
||||||
|
len("00:00:00:00:00:00"),
|
||||||
|
len("0000.0000.0000.0000"),
|
||||||
|
len("00:00:00:00:00:00:00:00"),
|
||||||
|
len("0000.0000.0000.0000.0000.0000.0000.0000.0000.0000"),
|
||||||
|
len("00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00"):
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Find represents the parameters for searching a client. params must not be
|
// Find represents the parameters for searching a client. params must not be
|
||||||
// nil and must have at least one non-empty field.
|
// nil and must have at least one non-empty field.
|
||||||
func (s *Storage) Find(params *FindParams) (p *Persistent, ok bool) {
|
func (s *Storage) Find(params *FindParams) (p *Persistent, ok bool) {
|
||||||
|
|||||||
@@ -1,19 +1,355 @@
|
|||||||
package home
|
package home
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"net/textproto"
|
"net/textproto"
|
||||||
"net/url"
|
"net/url"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||||
"github.com/AdguardTeam/golibs/httphdr"
|
"github.com/AdguardTeam/golibs/httphdr"
|
||||||
|
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||||
"github.com/AdguardTeam/golibs/testutil"
|
"github.com/AdguardTeam/golibs/testutil"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// TODO(s.chzhen): !! Add more tests.
|
||||||
|
func TestAuth_ServeHTTP_first_run(t *testing.T) {
|
||||||
|
storeGlobals(t)
|
||||||
|
|
||||||
|
globalContext.firstRun = true
|
||||||
|
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
globalContext.mux = mux
|
||||||
|
|
||||||
|
var (
|
||||||
|
logger = slogutil.NewDiscardLogger()
|
||||||
|
ctx = testutil.ContextWithTimeout(t, testTimeout)
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
web, err := initWeb(ctx, options{}, nil, nil, logger, nil, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
globalContext.web = web
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
url string
|
||||||
|
method string
|
||||||
|
code int
|
||||||
|
}{{
|
||||||
|
url: "/",
|
||||||
|
method: http.MethodGet,
|
||||||
|
code: http.StatusFound,
|
||||||
|
}, {
|
||||||
|
url: "/apple/doh.mobileconfig",
|
||||||
|
method: http.MethodGet,
|
||||||
|
code: http.StatusFound,
|
||||||
|
}, {
|
||||||
|
url: "/apple/dot.mobileconfig",
|
||||||
|
method: http.MethodGet,
|
||||||
|
code: http.StatusFound,
|
||||||
|
}, {
|
||||||
|
url: "/control/i18n/change_language",
|
||||||
|
method: http.MethodGet,
|
||||||
|
code: http.StatusFound,
|
||||||
|
}, {
|
||||||
|
url: "/control/i18n/current_language",
|
||||||
|
method: http.MethodGet,
|
||||||
|
code: http.StatusFound,
|
||||||
|
}, {
|
||||||
|
url: "/control/install/check_config",
|
||||||
|
method: http.MethodPost,
|
||||||
|
code: http.StatusBadRequest,
|
||||||
|
}, {
|
||||||
|
url: "/control/install/configure",
|
||||||
|
method: http.MethodPost,
|
||||||
|
code: http.StatusBadRequest,
|
||||||
|
}, {
|
||||||
|
url: "/control/install/get_addresses",
|
||||||
|
method: http.MethodGet,
|
||||||
|
code: http.StatusOK,
|
||||||
|
}, {
|
||||||
|
url: "/control/login",
|
||||||
|
method: http.MethodPost,
|
||||||
|
code: http.StatusFound,
|
||||||
|
}, {
|
||||||
|
url: "/control/logout",
|
||||||
|
method: http.MethodGet,
|
||||||
|
code: http.StatusFound,
|
||||||
|
}, {
|
||||||
|
url: "/control/profile",
|
||||||
|
method: http.MethodGet,
|
||||||
|
code: http.StatusFound,
|
||||||
|
}, {
|
||||||
|
url: "/control/profile/update",
|
||||||
|
method: http.MethodGet,
|
||||||
|
code: http.StatusFound,
|
||||||
|
}, {
|
||||||
|
url: "/control/status",
|
||||||
|
method: http.MethodGet,
|
||||||
|
code: http.StatusFound,
|
||||||
|
}, {
|
||||||
|
url: "/control/update",
|
||||||
|
method: http.MethodGet,
|
||||||
|
code: http.StatusFound,
|
||||||
|
}, {
|
||||||
|
url: "/control/version.json",
|
||||||
|
method: http.MethodGet,
|
||||||
|
code: http.StatusFound,
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.url, func(t *testing.T) {
|
||||||
|
r := httptest.NewRequest(tc.method, tc.url, nil)
|
||||||
|
|
||||||
|
h, pattern := mux.Handler(r)
|
||||||
|
require.NotEmpty(t, pattern)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
assert.Equal(t, tc.code, w.Code)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuth_ServeHTTP(t *testing.T) {
|
||||||
|
storeGlobals(t)
|
||||||
|
|
||||||
|
const (
|
||||||
|
authNone = iota
|
||||||
|
authBasic
|
||||||
|
authCookie
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
testTTL = 60
|
||||||
|
userName = "name"
|
||||||
|
userPassword = "password"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
logger = slogutil.NewDiscardLogger()
|
||||||
|
ctx = testutil.ContextWithTimeout(t, testTimeout)
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
passwordHash, err := bcrypt.GenerateFromPassword([]byte(userPassword), bcrypt.DefaultCost)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
sessionsDB := filepath.Join(t.TempDir(), "sessions.db")
|
||||||
|
|
||||||
|
users := []webUser{{
|
||||||
|
Name: userName,
|
||||||
|
PasswordHash: string(passwordHash),
|
||||||
|
}}
|
||||||
|
auth := InitAuth(sessionsDB, users, testTTL, nil, nil)
|
||||||
|
globalContext.auth = auth
|
||||||
|
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
globalContext.mux = mux
|
||||||
|
|
||||||
|
tlsMgr, err := newTLSManager(ctx, &tlsManagerConfig{
|
||||||
|
logger: logger,
|
||||||
|
configModified: func() {},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
web, err := initWeb(ctx, options{}, nil, nil, logger, tlsMgr, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
globalContext.web = web
|
||||||
|
|
||||||
|
creds, err := json.Marshal(&loginJSON{Name: userName, Password: userPassword})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
r := httptest.NewRequest(http.MethodPost, "/control/login", bytes.NewReader(creds))
|
||||||
|
r.Header.Set(httphdr.ContentType, aghhttp.HdrValApplicationJSON)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
mux.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
var loginCookie *http.Cookie
|
||||||
|
for _, c := range w.Result().Cookies() {
|
||||||
|
if c.Name == sessionCookieName {
|
||||||
|
loginCookie = c
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.NotNil(t, loginCookie)
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
url string
|
||||||
|
method string
|
||||||
|
authMethod int
|
||||||
|
wantCode int
|
||||||
|
}{{
|
||||||
|
url: "/",
|
||||||
|
method: http.MethodGet,
|
||||||
|
authMethod: authNone,
|
||||||
|
wantCode: http.StatusFound,
|
||||||
|
}, {
|
||||||
|
url: "/control/i18n/change_language",
|
||||||
|
method: http.MethodPost,
|
||||||
|
authMethod: authNone,
|
||||||
|
wantCode: http.StatusForbidden,
|
||||||
|
}, {
|
||||||
|
url: "/control/i18n/change_language",
|
||||||
|
method: http.MethodPost,
|
||||||
|
authMethod: authBasic,
|
||||||
|
wantCode: http.StatusInternalServerError,
|
||||||
|
}, {
|
||||||
|
url: "/control/i18n/change_language",
|
||||||
|
method: http.MethodPost,
|
||||||
|
authMethod: authCookie,
|
||||||
|
wantCode: http.StatusInternalServerError,
|
||||||
|
}, {
|
||||||
|
url: "/control/i18n/current_language",
|
||||||
|
method: http.MethodGet,
|
||||||
|
authMethod: authNone,
|
||||||
|
wantCode: http.StatusForbidden,
|
||||||
|
}, {
|
||||||
|
url: "/control/i18n/current_language",
|
||||||
|
method: http.MethodGet,
|
||||||
|
authMethod: authBasic,
|
||||||
|
wantCode: http.StatusOK,
|
||||||
|
}, {
|
||||||
|
url: "/control/i18n/current_language",
|
||||||
|
method: http.MethodGet,
|
||||||
|
authMethod: authCookie,
|
||||||
|
wantCode: http.StatusOK,
|
||||||
|
}, {
|
||||||
|
url: "/control/logout",
|
||||||
|
method: http.MethodGet,
|
||||||
|
authMethod: authNone,
|
||||||
|
wantCode: http.StatusForbidden,
|
||||||
|
}, {
|
||||||
|
url: "/control/logout",
|
||||||
|
method: http.MethodGet,
|
||||||
|
authMethod: authBasic,
|
||||||
|
wantCode: http.StatusFound,
|
||||||
|
}, {
|
||||||
|
url: "/control/profile",
|
||||||
|
method: http.MethodGet,
|
||||||
|
authMethod: authNone,
|
||||||
|
wantCode: http.StatusForbidden,
|
||||||
|
}, {
|
||||||
|
url: "/control/profile",
|
||||||
|
method: http.MethodGet,
|
||||||
|
authMethod: authBasic,
|
||||||
|
wantCode: http.StatusOK,
|
||||||
|
}, {
|
||||||
|
url: "/control/profile",
|
||||||
|
method: http.MethodGet,
|
||||||
|
authMethod: authCookie,
|
||||||
|
wantCode: http.StatusOK,
|
||||||
|
}, {
|
||||||
|
url: "/control/profile/update",
|
||||||
|
method: http.MethodPut,
|
||||||
|
authMethod: authNone,
|
||||||
|
wantCode: http.StatusForbidden,
|
||||||
|
}, {
|
||||||
|
url: "/control/profile/update",
|
||||||
|
method: http.MethodPut,
|
||||||
|
authMethod: authBasic,
|
||||||
|
wantCode: http.StatusBadRequest,
|
||||||
|
}, {
|
||||||
|
url: "/control/profile/update",
|
||||||
|
method: http.MethodPut,
|
||||||
|
authMethod: authCookie,
|
||||||
|
wantCode: http.StatusBadRequest,
|
||||||
|
}, {
|
||||||
|
url: "/control/status",
|
||||||
|
method: http.MethodGet,
|
||||||
|
authMethod: authNone,
|
||||||
|
wantCode: http.StatusForbidden,
|
||||||
|
}, {
|
||||||
|
url: "/control/status",
|
||||||
|
method: http.MethodGet,
|
||||||
|
authMethod: authBasic,
|
||||||
|
wantCode: http.StatusOK,
|
||||||
|
}, {
|
||||||
|
url: "/control/status",
|
||||||
|
method: http.MethodGet,
|
||||||
|
authMethod: authCookie,
|
||||||
|
wantCode: http.StatusOK,
|
||||||
|
}, {
|
||||||
|
url: "/control/update",
|
||||||
|
method: http.MethodPost,
|
||||||
|
authMethod: authNone,
|
||||||
|
wantCode: http.StatusForbidden,
|
||||||
|
}, {
|
||||||
|
url: "/control/version.json",
|
||||||
|
method: http.MethodGet,
|
||||||
|
authMethod: authNone,
|
||||||
|
wantCode: http.StatusForbidden,
|
||||||
|
}, {
|
||||||
|
url: "/control/version.json",
|
||||||
|
method: http.MethodGet,
|
||||||
|
authMethod: authBasic,
|
||||||
|
wantCode: http.StatusOK,
|
||||||
|
}, {
|
||||||
|
url: "/control/version.json",
|
||||||
|
method: http.MethodGet,
|
||||||
|
authMethod: authCookie,
|
||||||
|
wantCode: http.StatusOK,
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.url, func(t *testing.T) {
|
||||||
|
r = httptest.NewRequest(tc.method, tc.url, nil)
|
||||||
|
switch tc.authMethod {
|
||||||
|
case authNone:
|
||||||
|
// Go on.
|
||||||
|
case authBasic:
|
||||||
|
r.SetBasicAuth(userName, userPassword)
|
||||||
|
case authCookie:
|
||||||
|
r.AddCookie(loginCookie)
|
||||||
|
default:
|
||||||
|
panic("unrecognized auth method")
|
||||||
|
}
|
||||||
|
|
||||||
|
h, pattern := mux.Handler(r)
|
||||||
|
require.NotEmpty(t, pattern)
|
||||||
|
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
assert.Equal(t, tc.wantCode, w.Code)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("logout", func(t *testing.T) {
|
||||||
|
r = httptest.NewRequest(http.MethodGet, "/control/status", nil)
|
||||||
|
r.AddCookie(loginCookie)
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
|
||||||
|
mux.ServeHTTP(w, r)
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
|
||||||
|
r = httptest.NewRequest(http.MethodGet, "/control/logout", nil)
|
||||||
|
r.AddCookie(loginCookie)
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
|
||||||
|
mux.ServeHTTP(w, r)
|
||||||
|
assert.Equal(t, http.StatusFound, w.Code)
|
||||||
|
|
||||||
|
r = httptest.NewRequest(http.MethodGet, "/control/status", nil)
|
||||||
|
r.AddCookie(loginCookie)
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
|
||||||
|
mux.ServeHTTP(w, r)
|
||||||
|
assert.Equal(t, http.StatusForbidden, w.Code)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// implements http.ResponseWriter
|
// implements http.ResponseWriter
|
||||||
type testResponseWriter struct {
|
type testResponseWriter struct {
|
||||||
hdr http.Header
|
hdr http.Header
|
||||||
|
|||||||
@@ -82,7 +82,7 @@ func (web *webAPI) requestVersionInfo(
|
|||||||
) (err error) {
|
) (err error) {
|
||||||
updater := web.conf.updater
|
updater := web.conf.updater
|
||||||
for range 3 {
|
for range 3 {
|
||||||
resp.VersionInfo, err = updater.VersionInfo(ctx, recheck)
|
resp.VersionInfo, err = updater.VersionInfo(recheck)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -133,7 +133,7 @@ func (web *webAPI) handleUpdate(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = updater.Update(r.Context(), false)
|
err = updater.Update(false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
|
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
|
||||||
|
|
||||||
|
|||||||
@@ -119,15 +119,16 @@ func initDNS(
|
|||||||
globalContext.dhcpServer,
|
globalContext.dhcpServer,
|
||||||
anonymizer,
|
anonymizer,
|
||||||
httpRegister,
|
httpRegister,
|
||||||
|
tlsMgr.config(),
|
||||||
tlsMgr,
|
tlsMgr,
|
||||||
baseLogger,
|
baseLogger,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// initDNSServer initializes the [context.dnsServer]. To only use the internal
|
// initDNSServer initializes the [context.dnsServer]. To only use the internal
|
||||||
// proxy, none of the arguments are required, but tlsMgr and l still must not be
|
// proxy, none of the arguments are required, but tlsConf, tlsMgr and l still
|
||||||
// nil, in other cases all the arguments also must not be nil. It also must not
|
// must not be nil, in other cases all the arguments also must not be nil. It
|
||||||
// be called unless [config] and [globalContext] are initialized.
|
// also must not be called unless [config] and [globalContext] are initialized.
|
||||||
//
|
//
|
||||||
// TODO(e.burkov): Use [dnsforward.DNSCreateParams] as a parameter.
|
// TODO(e.burkov): Use [dnsforward.DNSCreateParams] as a parameter.
|
||||||
func initDNSServer(
|
func initDNSServer(
|
||||||
@@ -137,6 +138,7 @@ func initDNSServer(
|
|||||||
dhcpSrv dnsforward.DHCP,
|
dhcpSrv dnsforward.DHCP,
|
||||||
anonymizer *aghnet.IPMut,
|
anonymizer *aghnet.IPMut,
|
||||||
httpReg aghhttp.RegisterFunc,
|
httpReg aghhttp.RegisterFunc,
|
||||||
|
tlsConf *tlsConfigSettings,
|
||||||
tlsMgr *tlsManager,
|
tlsMgr *tlsManager,
|
||||||
l *slog.Logger,
|
l *slog.Logger,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
@@ -165,7 +167,7 @@ func initDNSServer(
|
|||||||
dnsConf, err := newServerConfig(
|
dnsConf, err := newServerConfig(
|
||||||
&config.DNS,
|
&config.DNS,
|
||||||
config.Clients.Sources,
|
config.Clients.Sources,
|
||||||
tlsMgr.config(),
|
tlsConf,
|
||||||
tlsMgr,
|
tlsMgr,
|
||||||
httpReg,
|
httpReg,
|
||||||
globalContext.clients.storage,
|
globalContext.clients.storage,
|
||||||
|
|||||||
@@ -487,14 +487,9 @@ func checkPorts() (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// isUpdateEnabled returns true if the update is enabled for current
|
// isUpdateEnabled returns true if the update is enabled for current
|
||||||
// configuration. It also logs the decision. isCustomURL should be true if the
|
// configuration. It also logs the decision. customURL should be true if the
|
||||||
// updater is using a custom URL.
|
// updater is using a custom URL.
|
||||||
func isUpdateEnabled(
|
func isUpdateEnabled(ctx context.Context, l *slog.Logger, opts *options, customURL bool) (ok bool) {
|
||||||
ctx context.Context,
|
|
||||||
l *slog.Logger,
|
|
||||||
opts *options,
|
|
||||||
isCustomURL bool,
|
|
||||||
) (ok bool) {
|
|
||||||
if opts.disableUpdate {
|
if opts.disableUpdate {
|
||||||
l.DebugContext(ctx, "updates are disabled by command-line option")
|
l.DebugContext(ctx, "updates are disabled by command-line option")
|
||||||
|
|
||||||
@@ -505,13 +500,13 @@ func isUpdateEnabled(
|
|||||||
case
|
case
|
||||||
version.ChannelDevelopment,
|
version.ChannelDevelopment,
|
||||||
version.ChannelCandidate:
|
version.ChannelCandidate:
|
||||||
if isCustomURL {
|
if customURL {
|
||||||
l.DebugContext(ctx, "updates are enabled because custom url is used")
|
l.DebugContext(ctx, "updates are enabled because custom url is used")
|
||||||
} else {
|
} else {
|
||||||
l.DebugContext(ctx, "updates are disabled for development and candidate builds")
|
l.DebugContext(ctx, "updates are disabled for development and candidate builds")
|
||||||
}
|
}
|
||||||
|
|
||||||
return isCustomURL
|
return customURL
|
||||||
default:
|
default:
|
||||||
l.DebugContext(ctx, "updates are enabled")
|
l.DebugContext(ctx, "updates are enabled")
|
||||||
|
|
||||||
@@ -519,7 +514,7 @@ func isUpdateEnabled(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// initWeb initializes the web module. upd, baseLogger, and tlsMgr must not be
|
// initWeb initializes the web module. upd, baseLogger, and tlsMgr must not be
|
||||||
// nil.
|
// nil.
|
||||||
func initWeb(
|
func initWeb(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
@@ -528,7 +523,7 @@ func initWeb(
|
|||||||
upd *updater.Updater,
|
upd *updater.Updater,
|
||||||
baseLogger *slog.Logger,
|
baseLogger *slog.Logger,
|
||||||
tlsMgr *tlsManager,
|
tlsMgr *tlsManager,
|
||||||
isCustomUpdURL bool,
|
customURL bool,
|
||||||
) (web *webAPI, err error) {
|
) (web *webAPI, err error) {
|
||||||
logger := baseLogger.With(slogutil.KeyPrefix, "webapi")
|
logger := baseLogger.With(slogutil.KeyPrefix, "webapi")
|
||||||
|
|
||||||
@@ -544,7 +539,7 @@ func initWeb(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
disableUpdate := !isUpdateEnabled(ctx, baseLogger, &opts, isCustomUpdURL)
|
disableUpdate := !isUpdateEnabled(ctx, baseLogger, &opts, customURL)
|
||||||
|
|
||||||
webConf := &webConfig{
|
webConf := &webConfig{
|
||||||
updater: upd,
|
updater: upd,
|
||||||
@@ -650,12 +645,11 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}, sigHdlr *signalH
|
|||||||
|
|
||||||
confPath := configFilePath()
|
confPath := configFilePath()
|
||||||
|
|
||||||
updLogger := slogLogger.With(slogutil.KeyPrefix, "updater")
|
upd, customURL := newUpdater(ctx, slogLogger, globalContext.workDir, confPath, execPath, config)
|
||||||
upd, isCustomURL := newUpdater(ctx, updLogger, config, globalContext.workDir, confPath, execPath)
|
|
||||||
|
|
||||||
// TODO(e.burkov): This could be made earlier, probably as the option's
|
// TODO(e.burkov): This could be made earlier, probably as the option's
|
||||||
// effect.
|
// effect.
|
||||||
cmdlineUpdate(ctx, updLogger, opts, upd, tlsMgr)
|
cmdlineUpdate(ctx, slogLogger, opts, upd, tlsMgr)
|
||||||
|
|
||||||
if !globalContext.firstRun {
|
if !globalContext.firstRun {
|
||||||
// Save the updated config.
|
// Save the updated config.
|
||||||
@@ -677,7 +671,7 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}, sigHdlr *signalH
|
|||||||
globalContext.auth, err = initUsers()
|
globalContext.auth, err = initUsers()
|
||||||
fatalOnError(err)
|
fatalOnError(err)
|
||||||
|
|
||||||
web, err := initWeb(ctx, opts, clientBuildFS, upd, slogLogger, tlsMgr, isCustomURL)
|
web, err := initWeb(ctx, opts, clientBuildFS, upd, slogLogger, tlsMgr, customURL)
|
||||||
fatalOnError(err)
|
fatalOnError(err)
|
||||||
|
|
||||||
globalContext.web = web
|
globalContext.web = web
|
||||||
@@ -720,17 +714,16 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}, sigHdlr *signalH
|
|||||||
<-done
|
<-done
|
||||||
}
|
}
|
||||||
|
|
||||||
// newUpdater creates a new AdGuard Home updater. l and conf must not be nil.
|
// newUpdater creates a new AdGuard Home updater. customURL is true if the user
|
||||||
// workDir, confPath, and execPath must not be empty. isCustomURL is true if
|
// has specified a custom version announcement URL.
|
||||||
// the user has specified a custom version announcement URL.
|
|
||||||
func newUpdater(
|
func newUpdater(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
l *slog.Logger,
|
l *slog.Logger,
|
||||||
conf *configuration,
|
|
||||||
workDir string,
|
workDir string,
|
||||||
confPath string,
|
confPath string,
|
||||||
execPath string,
|
execPath string,
|
||||||
) (upd *updater.Updater, isCustomURL bool) {
|
config *configuration,
|
||||||
|
) (upd *updater.Updater, customURL bool) {
|
||||||
// envName is the name of the environment variable that can be used to
|
// envName is the name of the environment variable that can be used to
|
||||||
// override the default version check URL.
|
// override the default version check URL.
|
||||||
const envName = "ADGUARD_HOME_TEST_UPDATE_VERSION_URL"
|
const envName = "ADGUARD_HOME_TEST_UPDATE_VERSION_URL"
|
||||||
@@ -742,14 +735,14 @@ func newUpdater(
|
|||||||
case version.Channel() == version.ChannelRelease:
|
case version.Channel() == version.ChannelRelease:
|
||||||
// Only enable custom version URL for development builds.
|
// Only enable custom version URL for development builds.
|
||||||
l.DebugContext(ctx, "custom version url is disabled for release builds")
|
l.DebugContext(ctx, "custom version url is disabled for release builds")
|
||||||
case !conf.UnsafeUseCustomUpdateIndexURL:
|
case !config.UnsafeUseCustomUpdateIndexURL:
|
||||||
l.DebugContext(ctx, "custom version url is disabled in config")
|
l.DebugContext(ctx, "custom version url is disabled in config")
|
||||||
default:
|
default:
|
||||||
versionURL, _ = url.Parse(customURLStr)
|
versionURL, _ = url.Parse(customURLStr)
|
||||||
}
|
}
|
||||||
|
|
||||||
err := urlutil.ValidateHTTPURL(versionURL)
|
err := urlutil.ValidateHTTPURL(versionURL)
|
||||||
if isCustomURL = err == nil; !isCustomURL {
|
if customURL = err == nil; !customURL {
|
||||||
l.DebugContext(ctx, "parsing custom version url", slogutil.KeyError, err)
|
l.DebugContext(ctx, "parsing custom version url", slogutil.KeyError, err)
|
||||||
|
|
||||||
versionURL = updater.DefaultVersionURL()
|
versionURL = updater.DefaultVersionURL()
|
||||||
@@ -758,8 +751,7 @@ func newUpdater(
|
|||||||
l.DebugContext(ctx, "creating updater", "config_path", confPath)
|
l.DebugContext(ctx, "creating updater", "config_path", confPath)
|
||||||
|
|
||||||
return updater.NewUpdater(&updater.Config{
|
return updater.NewUpdater(&updater.Config{
|
||||||
Client: conf.Filtering.HTTPClient,
|
Client: config.Filtering.HTTPClient,
|
||||||
Logger: l,
|
|
||||||
Version: version.Version(),
|
Version: version.Version(),
|
||||||
Channel: version.Channel(),
|
Channel: version.Channel(),
|
||||||
GOARCH: runtime.GOARCH,
|
GOARCH: runtime.GOARCH,
|
||||||
@@ -770,7 +762,7 @@ func newUpdater(
|
|||||||
ConfName: confPath,
|
ConfName: confPath,
|
||||||
ExecPath: execPath,
|
ExecPath: execPath,
|
||||||
VersionCheckURL: versionURL,
|
VersionCheckURL: versionURL,
|
||||||
}), isCustomURL
|
}), customURL
|
||||||
}
|
}
|
||||||
|
|
||||||
// checkPermissions checks and migrates permissions of the files and directories
|
// checkPermissions checks and migrates permissions of the files and directories
|
||||||
@@ -1086,12 +1078,12 @@ func cmdlineUpdate(
|
|||||||
//
|
//
|
||||||
// TODO(e.burkov): We could probably initialize the internal resolver
|
// TODO(e.burkov): We could probably initialize the internal resolver
|
||||||
// separately.
|
// separately.
|
||||||
err := initDNSServer(nil, nil, nil, nil, nil, nil, tlsMgr, l)
|
err := initDNSServer(nil, nil, nil, nil, nil, nil, &tlsConfigSettings{}, tlsMgr, l)
|
||||||
fatalOnError(err)
|
fatalOnError(err)
|
||||||
|
|
||||||
l.InfoContext(ctx, "performing update via cli")
|
l.InfoContext(ctx, "performing update via cli")
|
||||||
|
|
||||||
info, err := upd.VersionInfo(ctx, true)
|
info, err := upd.VersionInfo(true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.ErrorContext(ctx, "getting version info", slogutil.KeyError, err)
|
l.ErrorContext(ctx, "getting version info", slogutil.KeyError, err)
|
||||||
|
|
||||||
@@ -1104,7 +1096,7 @@ func cmdlineUpdate(
|
|||||||
os.Exit(osutil.ExitCodeSuccess)
|
os.Exit(osutil.ExitCodeSuccess)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = upd.Update(ctx, globalContext.firstRun)
|
err = upd.Update(globalContext.firstRun)
|
||||||
fatalOnError(err)
|
fatalOnError(err)
|
||||||
|
|
||||||
err = restartService()
|
err = restartService()
|
||||||
|
|||||||
@@ -193,10 +193,7 @@ func (m *tlsManager) start(_ context.Context) {
|
|||||||
m.web.tlsConfigChanged(context.Background(), m.conf)
|
m.web.tlsConfigChanged(context.Background(), m.conf)
|
||||||
}
|
}
|
||||||
|
|
||||||
// reload updates the configuration and restarts the TLS manager. It logs any
|
// reload updates the configuration and restarts the TLS manager.
|
||||||
// encountered errors.
|
|
||||||
//
|
|
||||||
// TODO(s.chzhen): Consider returning an error.
|
|
||||||
func (m *tlsManager) reload(ctx context.Context) {
|
func (m *tlsManager) reload(ctx context.Context) {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
|
|||||||
@@ -113,10 +113,13 @@ func TestValidateCertificates(t *testing.T) {
|
|||||||
// restores them once the test is complete.
|
// restores them once the test is complete.
|
||||||
//
|
//
|
||||||
// The global variables are:
|
// The global variables are:
|
||||||
// - [configuration.dns]
|
// - [configuration]
|
||||||
|
// - [homeContext.auth]
|
||||||
// - [homeContext.clients.storage]
|
// - [homeContext.clients.storage]
|
||||||
// - [homeContext.dnsServer]
|
// - [homeContext.dnsServer]
|
||||||
|
// - [homeContext.firstRun]
|
||||||
// - [homeContext.mux]
|
// - [homeContext.mux]
|
||||||
|
// - [homeContext.web]
|
||||||
//
|
//
|
||||||
// TODO(s.chzhen): Remove this once the TLS manager no longer accesses global
|
// TODO(s.chzhen): Remove this once the TLS manager no longer accesses global
|
||||||
// variables. Make tests that use this helper concurrent.
|
// variables. Make tests that use this helper concurrent.
|
||||||
@@ -124,15 +127,21 @@ func storeGlobals(tb testing.TB) {
|
|||||||
tb.Helper()
|
tb.Helper()
|
||||||
|
|
||||||
prevConfig := config
|
prevConfig := config
|
||||||
|
auth := globalContext.auth
|
||||||
storage := globalContext.clients.storage
|
storage := globalContext.clients.storage
|
||||||
dnsServer := globalContext.dnsServer
|
dnsServer := globalContext.dnsServer
|
||||||
|
firstRun := globalContext.firstRun
|
||||||
mux := globalContext.mux
|
mux := globalContext.mux
|
||||||
|
web := globalContext.web
|
||||||
|
|
||||||
tb.Cleanup(func() {
|
tb.Cleanup(func() {
|
||||||
config = prevConfig
|
config = prevConfig
|
||||||
|
globalContext.auth = auth
|
||||||
globalContext.clients.storage = storage
|
globalContext.clients.storage = storage
|
||||||
globalContext.dnsServer = dnsServer
|
globalContext.dnsServer = dnsServer
|
||||||
|
globalContext.firstRun = firstRun
|
||||||
globalContext.mux = mux
|
globalContext.mux = mux
|
||||||
|
globalContext.web = web
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package updater
|
package updater
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@@ -13,6 +12,7 @@ import (
|
|||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/ioutil"
|
"github.com/AdguardTeam/golibs/ioutil"
|
||||||
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/c2h5oh/datasize"
|
"github.com/c2h5oh/datasize"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -35,7 +35,7 @@ const maxVersionRespSize datasize.ByteSize = 64 * datasize.KB
|
|||||||
|
|
||||||
// VersionInfo downloads the latest version information. If forceRecheck is
|
// VersionInfo downloads the latest version information. If forceRecheck is
|
||||||
// false and there are cached results, those results are returned.
|
// false and there are cached results, those results are returned.
|
||||||
func (u *Updater) VersionInfo(ctx context.Context, forceRecheck bool) (vi VersionInfo, err error) {
|
func (u *Updater) VersionInfo(forceRecheck bool) (vi VersionInfo, err error) {
|
||||||
u.mu.Lock()
|
u.mu.Lock()
|
||||||
defer u.mu.Unlock()
|
defer u.mu.Unlock()
|
||||||
|
|
||||||
@@ -45,17 +45,11 @@ func (u *Updater) VersionInfo(ctx context.Context, forceRecheck bool) (vi Versio
|
|||||||
return u.prevCheckResult, u.prevCheckError
|
return u.prevCheckResult, u.prevCheckError
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var resp *http.Response
|
||||||
vcu := u.versionCheckURL
|
vcu := u.versionCheckURL
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, vcu, nil)
|
resp, err = u.client.Get(vcu)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return VersionInfo{}, fmt.Errorf("constructing request to %s: %w", vcu, err)
|
return VersionInfo{}, fmt.Errorf("updater: HTTP GET %s: %w", vcu, err)
|
||||||
}
|
|
||||||
|
|
||||||
u.logger.DebugContext(ctx, "requesting version data", "url", vcu)
|
|
||||||
|
|
||||||
resp, err := u.client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return VersionInfo{}, fmt.Errorf("requesting %s: %w", vcu, err)
|
|
||||||
}
|
}
|
||||||
defer func() { err = errors.WithDeferred(err, resp.Body.Close()) }()
|
defer func() { err = errors.WithDeferred(err, resp.Body.Close()) }()
|
||||||
|
|
||||||
@@ -65,16 +59,16 @@ func (u *Updater) VersionInfo(ctx context.Context, forceRecheck bool) (vi Versio
|
|||||||
// ReadCloser.
|
// ReadCloser.
|
||||||
body, err := io.ReadAll(r)
|
body, err := io.ReadAll(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return VersionInfo{}, fmt.Errorf("reading response from %s: %w", vcu, err)
|
return VersionInfo{}, fmt.Errorf("updater: HTTP GET %s: %w", vcu, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
u.prevCheckTime = now
|
u.prevCheckTime = now
|
||||||
u.prevCheckResult, u.prevCheckError = u.parseVersionResponse(ctx, body)
|
u.prevCheckResult, u.prevCheckError = u.parseVersionResponse(body)
|
||||||
|
|
||||||
return u.prevCheckResult, u.prevCheckError
|
return u.prevCheckResult, u.prevCheckError
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *Updater) parseVersionResponse(ctx context.Context, data []byte) (VersionInfo, error) {
|
func (u *Updater) parseVersionResponse(data []byte) (VersionInfo, error) {
|
||||||
info := VersionInfo{
|
info := VersionInfo{
|
||||||
CanAutoUpdate: aghalg.NBFalse,
|
CanAutoUpdate: aghalg.NBFalse,
|
||||||
}
|
}
|
||||||
@@ -98,7 +92,7 @@ func (u *Updater) parseVersionResponse(ctx context.Context, data []byte) (Versio
|
|||||||
info.Announcement = versionJSON["announcement"]
|
info.Announcement = versionJSON["announcement"]
|
||||||
info.AnnouncementURL = versionJSON["announcement_url"]
|
info.AnnouncementURL = versionJSON["announcement_url"]
|
||||||
|
|
||||||
packageURL, key, found := u.downloadURL(ctx, versionJSON)
|
packageURL, key, found := u.downloadURL(versionJSON)
|
||||||
if !found {
|
if !found {
|
||||||
return info, fmt.Errorf("version.json: no package URL: key %q not found in object", key)
|
return info, fmt.Errorf("version.json: no package URL: key %q not found in object", key)
|
||||||
}
|
}
|
||||||
@@ -114,10 +108,7 @@ func (u *Updater) parseVersionResponse(ctx context.Context, data []byte) (Versio
|
|||||||
// downloadURL returns the download URL for current build as well as its key in
|
// downloadURL returns the download URL for current build as well as its key in
|
||||||
// versionObj. If the key is not found, it additionally prints an informative
|
// versionObj. If the key is not found, it additionally prints an informative
|
||||||
// log message.
|
// log message.
|
||||||
func (u *Updater) downloadURL(
|
func (u *Updater) downloadURL(versionObj map[string]string) (dlURL, key string, ok bool) {
|
||||||
ctx context.Context,
|
|
||||||
versionObj map[string]string,
|
|
||||||
) (dlURL, key string, ok bool) {
|
|
||||||
if u.goarch == "arm" && u.goarm != "" {
|
if u.goarch == "arm" && u.goarm != "" {
|
||||||
key = fmt.Sprintf("download_%s_%sv%s", u.goos, u.goarch, u.goarm)
|
key = fmt.Sprintf("download_%s_%sv%s", u.goos, u.goarch, u.goarm)
|
||||||
} else if isMIPS(u.goarch) && u.gomips != "" {
|
} else if isMIPS(u.goarch) && u.gomips != "" {
|
||||||
@@ -133,7 +124,7 @@ func (u *Updater) downloadURL(
|
|||||||
|
|
||||||
keys := slices.Sorted(maps.Keys(versionObj))
|
keys := slices.Sorted(maps.Keys(versionObj))
|
||||||
|
|
||||||
u.logger.ErrorContext(ctx, "key not found", "missing", key, "got", keys)
|
log.Error("updater: key %q not found; got keys %q", key, keys)
|
||||||
|
|
||||||
return "", key, false
|
return "", key, false
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ import (
|
|||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/updater"
|
"github.com/AdguardTeam/AdGuardHome/internal/updater"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
||||||
"github.com/AdguardTeam/golibs/testutil"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
@@ -59,7 +58,6 @@ func TestUpdater_VersionInfo(t *testing.T) {
|
|||||||
|
|
||||||
u := updater.NewUpdater(&updater.Config{
|
u := updater.NewUpdater(&updater.Config{
|
||||||
Client: srv.Client(),
|
Client: srv.Client(),
|
||||||
Logger: testLogger,
|
|
||||||
Version: "v0.103.0-beta.1",
|
Version: "v0.103.0-beta.1",
|
||||||
Channel: version.ChannelBeta,
|
Channel: version.ChannelBeta,
|
||||||
GOARCH: "arm",
|
GOARCH: "arm",
|
||||||
@@ -67,8 +65,7 @@ func TestUpdater_VersionInfo(t *testing.T) {
|
|||||||
VersionCheckURL: fakeURL,
|
VersionCheckURL: fakeURL,
|
||||||
})
|
})
|
||||||
|
|
||||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
info, err := u.VersionInfo(false)
|
||||||
info, err := u.VersionInfo(ctx, false)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, counter, 1)
|
assert.Equal(t, counter, 1)
|
||||||
@@ -78,14 +75,14 @@ func TestUpdater_VersionInfo(t *testing.T) {
|
|||||||
assert.Equal(t, aghalg.NBTrue, info.CanAutoUpdate)
|
assert.Equal(t, aghalg.NBTrue, info.CanAutoUpdate)
|
||||||
|
|
||||||
t.Run("cache_check", func(t *testing.T) {
|
t.Run("cache_check", func(t *testing.T) {
|
||||||
_, err = u.VersionInfo(testutil.ContextWithTimeout(t, testTimeout), false)
|
_, err = u.VersionInfo(false)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, counter, 1)
|
assert.Equal(t, counter, 1)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("force_check", func(t *testing.T) {
|
t.Run("force_check", func(t *testing.T) {
|
||||||
_, err = u.VersionInfo(testutil.ContextWithTimeout(t, testTimeout), true)
|
_, err = u.VersionInfo(true)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, counter, 2)
|
assert.Equal(t, counter, 2)
|
||||||
@@ -94,7 +91,7 @@ func TestUpdater_VersionInfo(t *testing.T) {
|
|||||||
t.Run("api_fail", func(t *testing.T) {
|
t.Run("api_fail", func(t *testing.T) {
|
||||||
srv.Close()
|
srv.Close()
|
||||||
|
|
||||||
_, err = u.VersionInfo(testutil.ContextWithTimeout(t, testTimeout), true)
|
_, err = u.VersionInfo(true)
|
||||||
var urlErr *url.Error
|
var urlErr *url.Error
|
||||||
assert.ErrorAs(t, err, &urlErr)
|
assert.ErrorAs(t, err, &urlErr)
|
||||||
})
|
})
|
||||||
@@ -133,7 +130,6 @@ func TestUpdater_VersionInfo_others(t *testing.T) {
|
|||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
u := updater.NewUpdater(&updater.Config{
|
u := updater.NewUpdater(&updater.Config{
|
||||||
Client: fakeClient,
|
Client: fakeClient,
|
||||||
Logger: testLogger,
|
|
||||||
Version: "v0.103.0-beta.1",
|
Version: "v0.103.0-beta.1",
|
||||||
Channel: version.ChannelBeta,
|
Channel: version.ChannelBeta,
|
||||||
GOOS: "linux",
|
GOOS: "linux",
|
||||||
@@ -143,8 +139,7 @@ func TestUpdater_VersionInfo_others(t *testing.T) {
|
|||||||
VersionCheckURL: fakeURL,
|
VersionCheckURL: fakeURL,
|
||||||
})
|
})
|
||||||
|
|
||||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
info, err := u.VersionInfo(false)
|
||||||
info, err := u.VersionInfo(ctx, false)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, "v0.103.0-beta.2", info.NewVersion)
|
assert.Equal(t, "v0.103.0-beta.2", info.NewVersion)
|
||||||
|
|||||||
@@ -5,11 +5,9 @@ import (
|
|||||||
"archive/tar"
|
"archive/tar"
|
||||||
"archive/zip"
|
"archive/zip"
|
||||||
"compress/gzip"
|
"compress/gzip"
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"log/slog"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
@@ -24,14 +22,13 @@ import (
|
|||||||
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/ioutil"
|
"github.com/AdguardTeam/golibs/ioutil"
|
||||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/AdguardTeam/golibs/netutil/urlutil"
|
"github.com/AdguardTeam/golibs/netutil/urlutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Updater is the AdGuard Home updater.
|
// Updater is the AdGuard Home updater.
|
||||||
type Updater struct {
|
type Updater struct {
|
||||||
client *http.Client
|
client *http.Client
|
||||||
logger *slog.Logger
|
|
||||||
|
|
||||||
version string
|
version string
|
||||||
channel string
|
channel string
|
||||||
@@ -78,48 +75,27 @@ func DefaultVersionURL() *url.URL {
|
|||||||
|
|
||||||
// Config is the AdGuard Home updater configuration.
|
// Config is the AdGuard Home updater configuration.
|
||||||
type Config struct {
|
type Config struct {
|
||||||
// Client is used to perform HTTP requests. It must not be nil.
|
|
||||||
Client *http.Client
|
Client *http.Client
|
||||||
|
|
||||||
// Logger is used for logging the update process. It must not be nil.
|
|
||||||
Logger *slog.Logger
|
|
||||||
|
|
||||||
// VersionCheckURL is URL to the latest version announcement. It must not
|
// VersionCheckURL is URL to the latest version announcement. It must not
|
||||||
// be nil, see [DefaultVersionURL].
|
// be nil, see [DefaultVersionURL].
|
||||||
VersionCheckURL *url.URL
|
VersionCheckURL *url.URL
|
||||||
|
|
||||||
// Version is the current AdGuard Home version. It must not be empty.
|
|
||||||
Version string
|
Version string
|
||||||
|
|
||||||
// Channel is the current AdGuard Home update channel. It must be a valid
|
|
||||||
// channel, see [version.ChannelBeta] and the related constants.
|
|
||||||
Channel string
|
Channel string
|
||||||
|
GOARCH string
|
||||||
|
GOOS string
|
||||||
|
GOARM string
|
||||||
|
GOMIPS string
|
||||||
|
|
||||||
// GOARCH is the current CPU architecture. It must not be empty and must be
|
// ConfName is the name of the current configuration file. Typically,
|
||||||
// one of the supported architectures.
|
// "AdGuardHome.yaml".
|
||||||
GOARCH string
|
|
||||||
|
|
||||||
// GOOS is the current operating system. It must not be empty and must be
|
|
||||||
// one of the supported OSs.
|
|
||||||
GOOS string
|
|
||||||
|
|
||||||
// GOARM is the current ARM variant, if any. It must either be empty or be
|
|
||||||
// a valid and supported GOARM value.
|
|
||||||
GOARM string
|
|
||||||
|
|
||||||
// GOMIPS is the current MIPS variant, if any. It must either be empty or
|
|
||||||
// be a valid and supported GOMIPS value.
|
|
||||||
GOMIPS string
|
|
||||||
|
|
||||||
// ConfName is the name of the current configuration file. It must not be
|
|
||||||
// empty.
|
|
||||||
ConfName string
|
ConfName string
|
||||||
|
|
||||||
// WorkDir is the working directory that is used for temporary files. It
|
// WorkDir is the working directory that is used for temporary files.
|
||||||
// must not be empty.
|
|
||||||
WorkDir string
|
WorkDir string
|
||||||
|
|
||||||
// ExecPath is path to the executable file. It must not be empty.
|
// ExecPath is path to the executable file.
|
||||||
ExecPath string
|
ExecPath string
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -127,7 +103,6 @@ type Config struct {
|
|||||||
func NewUpdater(conf *Config) *Updater {
|
func NewUpdater(conf *Config) *Updater {
|
||||||
return &Updater{
|
return &Updater{
|
||||||
client: conf.Client,
|
client: conf.Client,
|
||||||
logger: conf.Logger,
|
|
||||||
|
|
||||||
version: conf.Version,
|
version: conf.Version,
|
||||||
channel: conf.Channel,
|
channel: conf.Channel,
|
||||||
@@ -147,49 +122,49 @@ func NewUpdater(conf *Config) *Updater {
|
|||||||
|
|
||||||
// Update performs the auto-update. It returns an error if the update failed.
|
// Update performs the auto-update. It returns an error if the update failed.
|
||||||
// If firstRun is true, it assumes the configuration file doesn't exist.
|
// If firstRun is true, it assumes the configuration file doesn't exist.
|
||||||
func (u *Updater) Update(ctx context.Context, firstRun bool) (err error) {
|
func (u *Updater) Update(firstRun bool) (err error) {
|
||||||
u.mu.Lock()
|
u.mu.Lock()
|
||||||
defer u.mu.Unlock()
|
defer u.mu.Unlock()
|
||||||
|
|
||||||
u.logger.InfoContext(ctx, "staring update", "first_run", firstRun)
|
log.Info("updater: updating")
|
||||||
defer func() {
|
defer func() {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
u.logger.ErrorContext(ctx, "update failed", slogutil.KeyError, err)
|
log.Info("updater: failed")
|
||||||
} else {
|
} else {
|
||||||
u.logger.InfoContext(ctx, "update finished")
|
log.Info("updater: finished successfully")
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
err = u.prepare(ctx)
|
err = u.prepare()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("preparing: %w", err)
|
return fmt.Errorf("preparing: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
defer u.clean(ctx)
|
defer u.clean()
|
||||||
|
|
||||||
err = u.downloadPackageFile(ctx)
|
err = u.downloadPackageFile()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("downloading package file: %w", err)
|
return fmt.Errorf("downloading package file: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = u.unpack(ctx)
|
err = u.unpack()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unpacking: %w", err)
|
return fmt.Errorf("unpacking: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !firstRun {
|
if !firstRun {
|
||||||
err = u.check(ctx)
|
err = u.check()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("checking config: %w", err)
|
return fmt.Errorf("checking config: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = u.backup(ctx, firstRun)
|
err = u.backup(firstRun)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("making backup: %w", err)
|
return fmt.Errorf("making backup: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = u.replace(ctx)
|
err = u.replace()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("replacing: %w", err)
|
return fmt.Errorf("replacing: %w", err)
|
||||||
}
|
}
|
||||||
@@ -206,7 +181,7 @@ func (u *Updater) NewVersion() (nv string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// prepare fills all necessary fields in Updater object.
|
// prepare fills all necessary fields in Updater object.
|
||||||
func (u *Updater) prepare(ctx context.Context) (err error) {
|
func (u *Updater) prepare() (err error) {
|
||||||
u.updateDir = filepath.Join(u.workDir, fmt.Sprintf("agh-update-%s", u.newVersion))
|
u.updateDir = filepath.Join(u.workDir, fmt.Sprintf("agh-update-%s", u.newVersion))
|
||||||
|
|
||||||
_, pkgNameOnly := filepath.Split(u.packageURL)
|
_, pkgNameOnly := filepath.Split(u.packageURL)
|
||||||
@@ -225,12 +200,11 @@ func (u *Updater) prepare(ctx context.Context) (err error) {
|
|||||||
u.backupExeName = filepath.Join(u.backupDir, filepath.Base(u.execPath))
|
u.backupExeName = filepath.Join(u.backupDir, filepath.Base(u.execPath))
|
||||||
u.updateExeName = filepath.Join(u.updateDir, updateExeName)
|
u.updateExeName = filepath.Join(u.updateDir, updateExeName)
|
||||||
|
|
||||||
u.logger.InfoContext(
|
log.Debug(
|
||||||
ctx,
|
"updater: updating from %s to %s using url: %s",
|
||||||
"updating",
|
version.Version(),
|
||||||
"from", version.Version(),
|
u.newVersion,
|
||||||
"to", u.newVersion,
|
u.packageURL,
|
||||||
"package_url", u.packageURL,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
u.currentExeName = u.execPath
|
u.currentExeName = u.execPath
|
||||||
@@ -243,20 +217,23 @@ func (u *Updater) prepare(ctx context.Context) (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// unpack extracts the files from the downloaded archive.
|
// unpack extracts the files from the downloaded archive.
|
||||||
func (u *Updater) unpack(ctx context.Context) (err error) {
|
func (u *Updater) unpack() error {
|
||||||
|
var err error
|
||||||
_, pkgNameOnly := filepath.Split(u.packageURL)
|
_, pkgNameOnly := filepath.Split(u.packageURL)
|
||||||
|
|
||||||
u.logger.InfoContext(ctx, "unpacking package", "package_name", pkgNameOnly)
|
log.Debug("updater: unpacking package")
|
||||||
if strings.HasSuffix(pkgNameOnly, ".zip") {
|
if strings.HasSuffix(pkgNameOnly, ".zip") {
|
||||||
u.unpackedFiles, err = u.unpackZip(ctx, u.packageName, u.updateDir)
|
u.unpackedFiles, err = zipFileUnpack(u.packageName, u.updateDir)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf(".zip unpack failed: %w", err)
|
return fmt.Errorf(".zip unpack failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
} else if strings.HasSuffix(pkgNameOnly, ".tar.gz") {
|
} else if strings.HasSuffix(pkgNameOnly, ".tar.gz") {
|
||||||
u.unpackedFiles, err = u.unpackTarGz(ctx, u.packageName, u.updateDir)
|
u.unpackedFiles, err = tarGzFileUnpack(u.packageName, u.updateDir)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf(".tar.gz unpack failed: %w", err)
|
return fmt.Errorf(".tar.gz unpack failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
return fmt.Errorf("unknown package extension")
|
return fmt.Errorf("unknown package extension")
|
||||||
}
|
}
|
||||||
@@ -266,8 +243,8 @@ func (u *Updater) unpack(ctx context.Context) (err error) {
|
|||||||
|
|
||||||
// check returns an error if the configuration file couldn't be used with the
|
// check returns an error if the configuration file couldn't be used with the
|
||||||
// version of AdGuard Home just downloaded.
|
// version of AdGuard Home just downloaded.
|
||||||
func (u *Updater) check(ctx context.Context) (err error) {
|
func (u *Updater) check() (err error) {
|
||||||
u.logger.InfoContext(ctx, "checking configuration")
|
log.Debug("updater: checking configuration")
|
||||||
|
|
||||||
err = copyFile(u.confName, filepath.Join(u.updateDir, "AdGuardHome.yaml"), aghos.DefaultPermFile)
|
err = copyFile(u.confName, filepath.Join(u.updateDir, "AdGuardHome.yaml"), aghos.DefaultPermFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -291,9 +268,8 @@ func (u *Updater) check(ctx context.Context) (err error) {
|
|||||||
|
|
||||||
// backup makes a backup of the current configuration and supporting files. It
|
// backup makes a backup of the current configuration and supporting files. It
|
||||||
// ignores the configuration file if firstRun is true.
|
// ignores the configuration file if firstRun is true.
|
||||||
func (u *Updater) backup(ctx context.Context, firstRun bool) (err error) {
|
func (u *Updater) backup(firstRun bool) (err error) {
|
||||||
u.logger.InfoContext(ctx, "backing up current configuration")
|
log.Debug("updater: backing up current configuration")
|
||||||
|
|
||||||
_ = os.Mkdir(u.backupDir, aghos.DefaultPermDir)
|
_ = os.Mkdir(u.backupDir, aghos.DefaultPermDir)
|
||||||
if !firstRun {
|
if !firstRun {
|
||||||
err = copyFile(u.confName, filepath.Join(u.backupDir, "AdGuardHome.yaml"), aghos.DefaultPermFile)
|
err = copyFile(u.confName, filepath.Join(u.backupDir, "AdGuardHome.yaml"), aghos.DefaultPermFile)
|
||||||
@@ -303,7 +279,7 @@ func (u *Updater) backup(ctx context.Context, firstRun bool) (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
wd := u.workDir
|
wd := u.workDir
|
||||||
err = u.copySupportingFiles(ctx, u.unpackedFiles, wd, u.backupDir)
|
err = copySupportingFiles(u.unpackedFiles, wd, u.backupDir)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("copySupportingFiles(%s, %s) failed: %w", wd, u.backupDir, err)
|
return fmt.Errorf("copySupportingFiles(%s, %s) failed: %w", wd, u.backupDir, err)
|
||||||
}
|
}
|
||||||
@@ -313,18 +289,13 @@ func (u *Updater) backup(ctx context.Context, firstRun bool) (err error) {
|
|||||||
|
|
||||||
// replace moves the current executable with the updated one and also copies the
|
// replace moves the current executable with the updated one and also copies the
|
||||||
// supporting files.
|
// supporting files.
|
||||||
func (u *Updater) replace(ctx context.Context) (err error) {
|
func (u *Updater) replace() error {
|
||||||
err = u.copySupportingFiles(ctx, u.unpackedFiles, u.updateDir, u.workDir)
|
err := copySupportingFiles(u.unpackedFiles, u.updateDir, u.workDir)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("copySupportingFiles(%s, %s) failed: %w", u.updateDir, u.workDir, err)
|
return fmt.Errorf("copySupportingFiles(%s, %s) failed: %w", u.updateDir, u.workDir, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
u.logger.InfoContext(
|
log.Debug("updater: renaming: %s to %s", u.currentExeName, u.backupExeName)
|
||||||
ctx,
|
|
||||||
"backing up current executable",
|
|
||||||
"from", u.currentExeName,
|
|
||||||
"to", u.backupExeName,
|
|
||||||
)
|
|
||||||
err = os.Rename(u.currentExeName, u.backupExeName)
|
err = os.Rename(u.currentExeName, u.backupExeName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -340,22 +311,14 @@ func (u *Updater) replace(ctx context.Context) (err error) {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
u.logger.InfoContext(
|
log.Debug("updater: renamed: %s to %s", u.updateExeName, u.currentExeName)
|
||||||
ctx,
|
|
||||||
"replacing current executable",
|
|
||||||
"from", u.updateExeName,
|
|
||||||
"to", u.currentExeName,
|
|
||||||
)
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// clean removes the temporary directory itself and all it's contents.
|
// clean removes the temporary directory itself and all it's contents.
|
||||||
func (u *Updater) clean(ctx context.Context) {
|
func (u *Updater) clean() {
|
||||||
err := os.RemoveAll(u.updateDir)
|
_ = os.RemoveAll(u.updateDir)
|
||||||
if err != nil {
|
|
||||||
u.logger.WarnContext(ctx, "removing update dir", slogutil.KeyError, err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// MaxPackageFileSize is a maximum package file length in bytes. The largest
|
// MaxPackageFileSize is a maximum package file length in bytes. The largest
|
||||||
@@ -364,52 +327,34 @@ func (u *Updater) clean(ctx context.Context) {
|
|||||||
const MaxPackageFileSize = 32 * 1024 * 1024
|
const MaxPackageFileSize = 32 * 1024 * 1024
|
||||||
|
|
||||||
// Download package file and save it to disk
|
// Download package file and save it to disk
|
||||||
func (u *Updater) downloadPackageFile(ctx context.Context) (err error) {
|
func (u *Updater) downloadPackageFile() (err error) {
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.packageURL, nil)
|
var resp *http.Response
|
||||||
|
resp, err = u.client.Get(u.packageURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("constructing package request: %w", err)
|
return fmt.Errorf("http request failed: %w", err)
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := u.client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("requesting package: %w", err)
|
|
||||||
}
|
}
|
||||||
defer func() { err = errors.WithDeferred(err, resp.Body.Close()) }()
|
defer func() { err = errors.WithDeferred(err, resp.Body.Close()) }()
|
||||||
|
|
||||||
r := ioutil.LimitReader(resp.Body, MaxPackageFileSize)
|
r := ioutil.LimitReader(resp.Body, MaxPackageFileSize)
|
||||||
|
|
||||||
u.logger.InfoContext(ctx, "reading http body")
|
log.Debug("updater: reading http body")
|
||||||
|
|
||||||
// This use of ReadAll is now safe, because we limited body's Reader.
|
// This use of ReadAll is now safe, because we limited body's Reader.
|
||||||
body, err := io.ReadAll(r)
|
body, err := io.ReadAll(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("io.ReadAll() failed: %w", err)
|
return fmt.Errorf("io.ReadAll() failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = os.Mkdir(u.updateDir, aghos.DefaultPermDir)
|
_ = os.Mkdir(u.updateDir, aghos.DefaultPermDir)
|
||||||
if err != nil {
|
|
||||||
// TODO(a.garipov): Consider returning this error.
|
|
||||||
u.logger.WarnContext(ctx, "creating update dir", slogutil.KeyError, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
u.logger.InfoContext(ctx, "saving package", "to", u.packageName)
|
|
||||||
|
|
||||||
|
log.Debug("updater: saving package to file")
|
||||||
err = os.WriteFile(u.packageName, body, aghos.DefaultPermFile)
|
err = os.WriteFile(u.packageName, body, aghos.DefaultPermFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("writing package file: %w", err)
|
return fmt.Errorf("writing package file: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// unpackTarGzFile unpacks one file from a .tar.gz archive into outDir. All
|
func tarGzFileUnpackOne(outDir string, tr *tar.Reader, hdr *tar.Header) (name string, err error) {
|
||||||
// arguments must not be empty.
|
|
||||||
func (u *Updater) unpackTarGzFile(
|
|
||||||
ctx context.Context,
|
|
||||||
outDir string,
|
|
||||||
tr *tar.Reader,
|
|
||||||
hdr *tar.Header,
|
|
||||||
) (name string, err error) {
|
|
||||||
name = filepath.Base(hdr.Name)
|
name = filepath.Base(hdr.Name)
|
||||||
if name == "" {
|
if name == "" {
|
||||||
return "", nil
|
return "", nil
|
||||||
@@ -432,18 +377,13 @@ func (u *Updater) unpackTarGzFile(
|
|||||||
return "", fmt.Errorf("creating directory %q: %w", outName, err)
|
return "", fmt.Errorf("creating directory %q: %w", outName, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
u.logger.InfoContext(ctx, "created directory", "name", outName)
|
log.Debug("updater: created directory %q", outName)
|
||||||
|
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if hdr.Typeflag != tar.TypeReg {
|
if hdr.Typeflag != tar.TypeReg {
|
||||||
u.logger.WarnContext(
|
log.Info("updater: %s: unknown file type %d, skipping", name, hdr.Typeflag)
|
||||||
ctx,
|
|
||||||
"unknown file type; skipping",
|
|
||||||
"file_name", name,
|
|
||||||
"type", hdr.Typeflag,
|
|
||||||
)
|
|
||||||
|
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
@@ -460,19 +400,16 @@ func (u *Updater) unpackTarGzFile(
|
|||||||
return "", fmt.Errorf("io.Copy(): %w", err)
|
return "", fmt.Errorf("io.Copy(): %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
u.logger.InfoContext(ctx, "created file", "name", outName)
|
log.Debug("updater: created file %q", outName)
|
||||||
|
|
||||||
return name, nil
|
return name, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// unpackTarGz unpack all files from a .tar.gz archive to outDir. Existing
|
// Unpack all files from .tar.gz file to the specified directory
|
||||||
// files are overwritten. All files are created inside outDir. files are the
|
// Existing files are overwritten
|
||||||
// list of created files.
|
// All files are created inside outDir, subdirectories are not created
|
||||||
func (u *Updater) unpackTarGz(
|
// Return the list of files (not directories) written
|
||||||
ctx context.Context,
|
func tarGzFileUnpack(tarfile, outDir string) (files []string, err error) {
|
||||||
tarfile string,
|
|
||||||
outDir string,
|
|
||||||
) (files []string, err error) {
|
|
||||||
f, err := os.Open(tarfile)
|
f, err := os.Open(tarfile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("os.Open(): %w", err)
|
return nil, fmt.Errorf("os.Open(): %w", err)
|
||||||
@@ -500,7 +437,7 @@ func (u *Updater) unpackTarGz(
|
|||||||
}
|
}
|
||||||
|
|
||||||
var name string
|
var name string
|
||||||
name, err = u.unpackTarGzFile(ctx, outDir, tarReader, hdr)
|
name, err = tarGzFileUnpackOne(outDir, tarReader, hdr)
|
||||||
|
|
||||||
if name != "" {
|
if name != "" {
|
||||||
files = append(files, name)
|
files = append(files, name)
|
||||||
@@ -510,13 +447,7 @@ func (u *Updater) unpackTarGz(
|
|||||||
return files, err
|
return files, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// unpackZipFile unpacks one file from a .zip archive into outDir. All
|
func zipFileUnpackOne(outDir string, zf *zip.File) (name string, err error) {
|
||||||
// arguments must not be empty.
|
|
||||||
func (u *Updater) unpackZipFile(
|
|
||||||
ctx context.Context,
|
|
||||||
outDir string,
|
|
||||||
zf *zip.File,
|
|
||||||
) (name string, err error) {
|
|
||||||
var rc io.ReadCloser
|
var rc io.ReadCloser
|
||||||
rc, err = zf.Open()
|
rc, err = zf.Open()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -535,8 +466,7 @@ func (u *Updater) unpackZipFile(
|
|||||||
if name == "AdGuardHome" {
|
if name == "AdGuardHome" {
|
||||||
// Top-level AdGuardHome/. Skip it.
|
// Top-level AdGuardHome/. Skip it.
|
||||||
//
|
//
|
||||||
// TODO(a.garipov): See the similar TODO in
|
// TODO(a.garipov): See the similar todo in tarGzFileUnpack.
|
||||||
// [Updater.unpackTarGzFile].
|
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -545,7 +475,7 @@ func (u *Updater) unpackZipFile(
|
|||||||
return "", fmt.Errorf("creating directory %q: %w", outputName, err)
|
return "", fmt.Errorf("creating directory %q: %w", outputName, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
u.logger.InfoContext(ctx, "created directory", "name", outputName)
|
log.Debug("updater: created directory %q", outputName)
|
||||||
|
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
@@ -562,19 +492,16 @@ func (u *Updater) unpackZipFile(
|
|||||||
return "", fmt.Errorf("io.Copy(): %w", err)
|
return "", fmt.Errorf("io.Copy(): %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
u.logger.InfoContext(ctx, "created file", "name", outputName)
|
log.Debug("updater: created file %q", outputName)
|
||||||
|
|
||||||
return name, nil
|
return name, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// unpackZip unpack all files from a .zip archive to outDir. Existing files are
|
// Unpack all files from .zip file to the specified directory
|
||||||
// overwritten. All files are created inside outDir. files are the list of
|
// Existing files are overwritten
|
||||||
// created files.
|
// All files are created inside 'outDir', subdirectories are not created
|
||||||
func (u *Updater) unpackZip(
|
// Return the list of files (not directories) written
|
||||||
ctx context.Context,
|
func zipFileUnpack(zipfile, outDir string) (files []string, err error) {
|
||||||
zipfile string,
|
|
||||||
outDir string,
|
|
||||||
) (files []string, err error) {
|
|
||||||
zrc, err := zip.OpenReader(zipfile)
|
zrc, err := zip.OpenReader(zipfile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("zip.OpenReader(): %w", err)
|
return nil, fmt.Errorf("zip.OpenReader(): %w", err)
|
||||||
@@ -583,7 +510,7 @@ func (u *Updater) unpackZip(
|
|||||||
|
|
||||||
for _, zf := range zrc.File {
|
for _, zf := range zrc.File {
|
||||||
var name string
|
var name string
|
||||||
name, err = u.unpackZipFile(ctx, outDir, zf)
|
name, err = zipFileUnpackOne(outDir, zf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -616,12 +543,7 @@ func copyFile(src, dst string, perm fs.FileMode) (err error) {
|
|||||||
// copySupportingFiles copies each file specified in files from srcdir to
|
// copySupportingFiles copies each file specified in files from srcdir to
|
||||||
// dstdir. If a file specified as a path, only the name of the file is used.
|
// dstdir. If a file specified as a path, only the name of the file is used.
|
||||||
// It skips AdGuardHome, AdGuardHome.exe, and AdGuardHome.yaml.
|
// It skips AdGuardHome, AdGuardHome.exe, and AdGuardHome.yaml.
|
||||||
func (u *Updater) copySupportingFiles(
|
func copySupportingFiles(files []string, srcdir, dstdir string) error {
|
||||||
ctx context.Context,
|
|
||||||
files []string,
|
|
||||||
srcdir string,
|
|
||||||
dstdir string,
|
|
||||||
) (err error) {
|
|
||||||
for _, f := range files {
|
for _, f := range files {
|
||||||
_, name := filepath.Split(f)
|
_, name := filepath.Split(f)
|
||||||
if name == "AdGuardHome" || name == "AdGuardHome.exe" || name == "AdGuardHome.yaml" {
|
if name == "AdGuardHome" || name == "AdGuardHome.exe" || name == "AdGuardHome.yaml" {
|
||||||
@@ -631,12 +553,12 @@ func (u *Updater) copySupportingFiles(
|
|||||||
src := filepath.Join(srcdir, name)
|
src := filepath.Join(srcdir, name)
|
||||||
dst := filepath.Join(dstdir, name)
|
dst := filepath.Join(dstdir, name)
|
||||||
|
|
||||||
err = copyFile(src, dst, aghos.DefaultPermFile)
|
err := copyFile(src, dst, aghos.DefaultPermFile)
|
||||||
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
u.logger.InfoContext(ctx, "copied", "from", src, "to", dst)
|
log.Debug("updater: copied: %q to %q", src, dst)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -1,16 +1,12 @@
|
|||||||
package updater
|
package updater
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
|
||||||
"github.com/AdguardTeam/golibs/testutil"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
@@ -59,7 +55,6 @@ func TestUpdater_internal(t *testing.T) {
|
|||||||
|
|
||||||
u := NewUpdater(&Config{
|
u := NewUpdater(&Config{
|
||||||
Client: fakeClient,
|
Client: fakeClient,
|
||||||
Logger: slogutil.NewDiscardLogger(),
|
|
||||||
GOOS: tc.os,
|
GOOS: tc.os,
|
||||||
Version: "v0.103.0",
|
Version: "v0.103.0",
|
||||||
ExecPath: exePath,
|
ExecPath: exePath,
|
||||||
@@ -73,13 +68,13 @@ func TestUpdater_internal(t *testing.T) {
|
|||||||
u.newVersion = "v0.103.1"
|
u.newVersion = "v0.103.1"
|
||||||
u.packageURL = fakeURL.String()
|
u.packageURL = fakeURL.String()
|
||||||
|
|
||||||
require.NoError(t, u.prepare(newCtx(t)))
|
require.NoError(t, u.prepare())
|
||||||
require.NoError(t, u.downloadPackageFile(newCtx(t)))
|
require.NoError(t, u.downloadPackageFile())
|
||||||
require.NoError(t, u.unpack(newCtx(t)))
|
require.NoError(t, u.unpack())
|
||||||
require.NoError(t, u.backup(newCtx(t), false))
|
require.NoError(t, u.backup(false))
|
||||||
require.NoError(t, u.replace(newCtx(t)))
|
require.NoError(t, u.replace())
|
||||||
|
|
||||||
u.clean(newCtx(t))
|
u.clean()
|
||||||
|
|
||||||
require.True(t, t.Run("backup", func(t *testing.T) {
|
require.True(t, t.Run("backup", func(t *testing.T) {
|
||||||
var d []byte
|
var d []byte
|
||||||
@@ -118,8 +113,3 @@ func TestUpdater_internal(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// newCtx is a helper that returns a new context with a timeout.
|
|
||||||
func newCtx(tb testing.TB) (ctx context.Context) {
|
|
||||||
return testutil.ContextWithTimeout(tb, 1*time.Second)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -10,21 +10,17 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/updater"
|
"github.com/AdguardTeam/AdGuardHome/internal/updater"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
||||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
|
||||||
"github.com/AdguardTeam/golibs/testutil"
|
"github.com/AdguardTeam/golibs/testutil"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
// testTimeout is the common timeout for tests.
|
func TestMain(m *testing.M) {
|
||||||
const testTimeout = 1 * time.Second
|
testutil.DiscardLogOutput(m)
|
||||||
|
}
|
||||||
// testLogger is the common logger for tests.
|
|
||||||
var testLogger = slogutil.NewDiscardLogger()
|
|
||||||
|
|
||||||
func TestUpdater_Update(t *testing.T) {
|
func TestUpdater_Update(t *testing.T) {
|
||||||
const jsonData = `{
|
const jsonData = `{
|
||||||
@@ -77,7 +73,6 @@ func TestUpdater_Update(t *testing.T) {
|
|||||||
|
|
||||||
u := updater.NewUpdater(&updater.Config{
|
u := updater.NewUpdater(&updater.Config{
|
||||||
Client: srv.Client(),
|
Client: srv.Client(),
|
||||||
Logger: testLogger,
|
|
||||||
GOARCH: "amd64",
|
GOARCH: "amd64",
|
||||||
GOOS: "linux",
|
GOOS: "linux",
|
||||||
Version: "v0.103.0",
|
Version: "v0.103.0",
|
||||||
@@ -87,12 +82,10 @@ func TestUpdater_Update(t *testing.T) {
|
|||||||
VersionCheckURL: versionCheckURL,
|
VersionCheckURL: versionCheckURL,
|
||||||
})
|
})
|
||||||
|
|
||||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
_, err = u.VersionInfo(false)
|
||||||
_, err = u.VersionInfo(ctx, false)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
ctx = testutil.ContextWithTimeout(t, testTimeout)
|
err = u.Update(true)
|
||||||
err = u.Update(ctx, true)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// check backup files
|
// check backup files
|
||||||
@@ -131,15 +124,14 @@ func TestUpdater_Update(t *testing.T) {
|
|||||||
t.Skip("skipping config check test on windows")
|
t.Skip("skipping config check test on windows")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = u.Update(testutil.ContextWithTimeout(t, testTimeout), false)
|
err = u.Update(false)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("api_fail", func(t *testing.T) {
|
t.Run("api_fail", func(t *testing.T) {
|
||||||
srv.Close()
|
srv.Close()
|
||||||
|
|
||||||
err = u.Update(testutil.ContextWithTimeout(t, testTimeout), true)
|
err = u.Update(true)
|
||||||
|
|
||||||
var urlErr *url.Error
|
var urlErr *url.Error
|
||||||
assert.ErrorAs(t, err, &urlErr)
|
assert.ErrorAs(t, err, &urlErr)
|
||||||
})
|
})
|
||||||
|
|||||||
Reference in New Issue
Block a user