diff --git a/go.mod b/go.mod index 74068e28..f3bfe7ac 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.24.2 require ( github.com/AdguardTeam/dnsproxy v0.75.4 - github.com/AdguardTeam/golibs v0.32.8 + github.com/AdguardTeam/golibs v0.32.9 github.com/AdguardTeam/urlfilter v0.20.0 github.com/NYTimes/gziphandler v1.1.1 github.com/ameshkov/dnscrypt/v2 v2.4.0 diff --git a/go.sum b/go.sum index ab50edbd..3ba97fc9 100644 --- a/go.sum +++ b/go.sum @@ -12,8 +12,8 @@ cloud.google.com/go/longrunning v0.6.7 h1:IGtfDWHhQCgCjwQjV9iiLnUta9LBCo8R9QmAFs cloud.google.com/go/longrunning v0.6.7/go.mod h1:EAFV3IZAKmM56TyiE6VAP3VoTzhZzySwI/YI1s/nRsY= github.com/AdguardTeam/dnsproxy v0.75.4 h1:hTnHh9HoTYKKhKqePpIxCzfecl7dAXykZTw2gcj0I5U= github.com/AdguardTeam/dnsproxy v0.75.4/go.mod h1:50OyTHao+uQzUJiXay08hgfvWQ3o2Q2WV99W8u8ypDE= -github.com/AdguardTeam/golibs v0.32.8 h1:O3mc3kYcPkW3kbmd+gqzFNgUka13a+iBgFLThwOYSQE= -github.com/AdguardTeam/golibs v0.32.8/go.mod h1:McV1QFFlKLElKa306V4OL/T2kr7564PhsayfvTWYBVs= +github.com/AdguardTeam/golibs v0.32.9 h1:/6luT0aMOn05/s9eh1yA4lbcHgl0d1iEEvEBbIMMUk0= +github.com/AdguardTeam/golibs v0.32.9/go.mod h1:McV1QFFlKLElKa306V4OL/T2kr7564PhsayfvTWYBVs= github.com/AdguardTeam/urlfilter v0.20.0 h1:X32qiuVCVd8WDYCEsbdZKfXMzwdVqrdulamtUi4rmzs= github.com/AdguardTeam/urlfilter v0.20.0/go.mod h1:gjrywLTxfJh6JOkwi9SU+frhP7kVVEZ5exFGkR99qpk= github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg= diff --git a/internal/client/storage.go b/internal/client/storage.go index ef7d5209..734394d7 100644 --- a/internal/client/storage.go +++ b/internal/client/storage.go @@ -7,7 +7,6 @@ import ( "net" "net/netip" "slices" - "strings" "sync" "time" @@ -478,7 +477,7 @@ const ErrBadIdentifier errors.Error = "bad client identifier" func (p *FindParams) Set(id string) (err error) { *p = FindParams{} - isClientID := true + isFound := false if netutil.IsValidIPString(id) { // It is safe to use [netip.MustParseAddr] because it has already been @@ -488,24 +487,27 @@ func (p *FindParams) Set(id string) (err error) { // Even if id can be parsed as an IP address, it may be a MAC address. // So do not return prematurely, continue parsing. - isClientID = false + isFound = true } - if canBeValidIPPrefixString(id) { - p.Subnet, err = netip.ParsePrefix(id) - if err == nil { - isClientID = false - } - } - - if canBeMACString(id) { + if netutil.IsValidMACString(id) { p.MAC, err = net.ParseMAC(id) - if err == nil { - isClientID = false + if err != nil { + panic(fmt.Errorf("parsing mac from %q: %w", id, err)) } + + isFound = true } - if !isClientID { + if isFound { + return nil + } + + if netutil.IsValidIPPrefixString(id) { + // It is safe to use [netip.MustParsePrefix] because it has already been + // validated that id contains the string representation of IP prefix. + p.Subnet = netip.MustParsePrefix(id) + return nil } @@ -518,57 +520,6 @@ func (p *FindParams) Set(id string) (err error) { return nil } -// canBeValidIPPrefixString is a best-effort check to determine if s is a valid -// CIDR before using [netip.ParsePrefix], aimed at reducing allocations. -// -// TODO(s.chzhen): Replace this implementation with the more robust version -// from golibs. -func canBeValidIPPrefixString(s string) (ok bool) { - ipStr, bitStr, ok := strings.Cut(s, "/") - if !ok { - return false - } - - if bitStr == "" || len(bitStr) > 3 { - return false - } - - bits := 0 - for _, c := range bitStr { - if c < '0' || c > '9' { - return false - } - - bits = bits*10 + int(c-'0') - } - - if bits > 128 { - return false - } - - return netutil.IsValidIPString(ipStr) -} - -// canBeMACString is a best-effort check to determine if s is a valid MAC -// address before using [net.ParseMAC], aimed at reducing allocations. -// -// TODO(s.chzhen): Replace this implementation with the more robust version -// from golibs. -func canBeMACString(s string) (ok bool) { - switch len(s) { - case - len("0000.0000.0000"), - len("00:00:00:00:00:00"), - len("0000.0000.0000.0000"), - len("00:00:00:00:00:00:00:00"), - len("0000.0000.0000.0000.0000.0000.0000.0000.0000.0000"), - len("00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00"): - return true - default: - return false - } -} - // Find represents the parameters for searching a client. params must not be // nil and must have at least one non-empty field. func (s *Storage) Find(params *FindParams) (p *Persistent, ok bool) {