diff --git a/dnsforward/ipset.go b/dnsforward/ipset.go index c6b3de5a..f5dadac0 100644 --- a/dnsforward/ipset.go +++ b/dnsforward/ipset.go @@ -89,6 +89,31 @@ func (c *ipsetCtx) getIP(rr dns.RR) net.IP { } } +// Find the ipsets for a given host (accounting for subdomain wildcards) +func (c *ipsetCtx) getIpsetNames(host string) ([]string, bool) { + var ipsetNames []string + var found bool + + // search for matching ipset hosts starting with most specific subdomain + i := 0 + for i != -1 { + host = host[i:] + + ipsetNames, found = c.ipsetList[host] + if found { + break + } + + // move slice up to the parent domain + i = strings.Index(host, ".") + if i != -1 { + i++ + } + } + + return ipsetNames, found +} + func addToIpset(host string, ipsetName string, ipStr string) { code, out, err := util.RunCommand("ipset", "add", ipsetName, ipStr) if err != nil { @@ -115,7 +140,7 @@ func (c *ipsetCtx) processMembers(ctx *dnsContext, addMember func(string, string host := req.Question[0].Name host = strings.TrimSuffix(host, ".") host = strings.ToLower(host) - ipsetNames, found := c.ipsetList[host] + ipsetNames, found := c.getIpsetNames(host) if !found { return resultDone } diff --git a/dnsforward/ipset_test.go b/dnsforward/ipset_test.go index c6794790..bf8f362d 100644 --- a/dnsforward/ipset_test.go +++ b/dnsforward/ipset_test.go @@ -158,6 +158,22 @@ func TestIpsetSubdomainOverride(t *testing.T) { assert.Equal(t, 1, len(b)) } +func TestIpsetSubdomainWildcard(t *testing.T) { + setup() + + ctx.proxyCtx.Req = makeReqA("sub.host.com.") + ctx.proxyCtx.Res = &dns.Msg{ + Answer: []dns.RR{ + makeA("sub.host.com.", net.IPv4(127, 0, 0, 1)), + }, + } + + doProcess(t) + + assert.Equal(t, 1, b[Binding{"sub.host.com", "name", "127.0.0.1"}]) + assert.Equal(t, 1, len(b)) +} + func TestIpsetCnameThirdParty(t *testing.T) { setup()