diff --git a/src/dns_conf.c b/src/dns_conf.c index 70945b0..09db101 100644 --- a/src/dns_conf.c +++ b/src/dns_conf.c @@ -864,20 +864,63 @@ static int _config_domain_rule_add_callback(const char *domain, void *priv) return _config_domain_rule_add(domain, args->type, args->rule); } +static int _config_setup_domain_key(const char *domain, char *domain_key, int domain_key_max_len, int *domain_key_len, + int *root_rule_only, int *sub_rule_only) +{ + int tmp_root_rule_only = 0; + int tmp_sub_rule_only = 0; + + int len = strlen(domain); + if (len >= domain_key_max_len - 2) { + tlog(TLOG_ERROR, "domain %s too long", domain); + return -1; + } + + reverse_string(domain_key, domain, len, 1); + if (domain[0] == '*') { + /* prefix wildcard */ + len--; + if (domain[1] == '.') { + tmp_sub_rule_only = 1; + } else if ((domain[1] == '-') && (domain[2] == '.')) { + len--; + tmp_sub_rule_only = 1; + tmp_root_rule_only = 1; + } + } else if (domain[0] == '-') { + /* root match only */ + len--; + if (domain[1] == '.') { + tmp_root_rule_only = 1; + } + } else { + /* suffix match */ + domain_key[len] = '.'; + len++; + } + domain_key[len] = 0; + + *domain_key_len = len; + if (root_rule_only) { + *root_rule_only = tmp_root_rule_only; + } + + if (sub_rule_only) { + *sub_rule_only = tmp_sub_rule_only; + } + + return 0; +} + static struct dns_domain_rule *_config_domain_rule_get(const char *domain) { char domain_key[DNS_MAX_CONF_CNAME_LEN]; int len = 0; - len = strlen(domain); - if (len >= (int)sizeof(domain_key) - 2) { + if (_config_setup_domain_key(domain, domain_key, sizeof(domain_key), &len, NULL, NULL) != 0) { return NULL; } - reverse_string(domain_key, domain, len, 1); - domain_key[len] = '.'; - len++; - domain_key[len] = 0; return art_search(&dns_conf_domain_rule, (unsigned char *)domain_key, len); } @@ -892,18 +935,6 @@ static int _config_domain_rule_add(const char *domain, enum domain_rule type, vo int sub_rule_only = 0; int root_rule_only = 0; - /* Reverse string, for suffix match */ - len = strlen(domain); - if (len >= (int)sizeof(domain_key) - 2) { - tlog(TLOG_ERROR, "domain name %s too long", domain); - goto errout; - } - - if (len <= 0) { - tlog(TLOG_ERROR, "domain name %s too short", domain); - goto errout; - } - if (strncmp(domain, "domain-set:", sizeof("domain-set:") - 1) == 0) { struct dns_set_rule_add_callback_args args; args.type = type; @@ -912,29 +943,10 @@ static int _config_domain_rule_add(const char *domain, enum domain_rule type, vo &args); } - reverse_string(domain_key, domain, len, 1); - if (domain[0] == '*') { - /* prefix wildcard */ - len--; - if (domain[1] == '.') { - sub_rule_only = 1; - } else if ((domain[1] == '-') && (domain[2] == '.')) { - len--; - sub_rule_only = 1; - root_rule_only = 1; - } - } else if (domain[0] == '-') { - /* root match only */ - len--; - if (domain[1] == '.') { - root_rule_only = 1; - } - } else { - /* suffix match */ - domain_key[len] = '.'; - len++; + /* Reverse string, for suffix match */ + if (_config_setup_domain_key(domain, domain_key, sizeof(domain_key), &len, &root_rule_only, &sub_rule_only) != 0) { + goto errout; } - domain_key[len] = 0; if (type >= DOMAIN_RULE_MAX) { goto errout; @@ -991,22 +1003,15 @@ static int _config_domain_rule_delete(const char *domain) char domain_key[DNS_MAX_CONF_CNAME_LEN]; int len = 0; - /* Reverse string, for suffix match */ - len = strlen(domain); - if (len >= (int)sizeof(domain_key)) { - tlog(TLOG_ERROR, "domain name %s too long", domain); - goto errout; - } - if (strncmp(domain, "domain-set:", sizeof("domain-set:") - 1) == 0) { return _config_domain_rule_set_each(domain + sizeof("domain-set:") - 1, _config_domain_rule_delete_callback, NULL); } + /* Reverse string, for suffix match */ - reverse_string(domain_key, domain, len, 1); - domain_key[len] = '.'; - len++; - domain_key[len] = 0; + if (_config_setup_domain_key(domain, domain_key, sizeof(domain_key), &len, NULL, NULL) != 0) { + goto errout; + } /* delete existing rules */ void *rule = art_delete(&dns_conf_domain_rule, (unsigned char *)domain_key, len); @@ -1036,6 +1041,8 @@ static int _config_domain_rule_flag_set(const char *domain, unsigned int flag, u char domain_key[DNS_MAX_CONF_CNAME_LEN]; int len = 0; + int sub_rule_only = 0; + int root_rule_only = 0; if (strncmp(domain, "domain-set:", sizeof("domain-set:") - 1) == 0) { struct dns_set_rule_flags_callback_args args; @@ -1045,15 +1052,9 @@ static int _config_domain_rule_flag_set(const char *domain, unsigned int flag, u &args); } - len = strlen(domain); - if (len >= (int)sizeof(domain_key)) { - tlog(TLOG_ERROR, "domain %s too long", domain); - return -1; + if (_config_setup_domain_key(domain, domain_key, sizeof(domain_key), &len, &root_rule_only, &sub_rule_only) != 0) { + goto errout; } - reverse_string(domain_key, domain, len, 1); - domain_key[len] = '.'; - len++; - domain_key[len] = 0; /* Get existing or create domain rule */ domain_rule = art_search(&dns_conf_domain_rule, (unsigned char *)domain_key, len); @@ -1073,6 +1074,9 @@ static int _config_domain_rule_flag_set(const char *domain, unsigned int flag, u domain_rule->rules[DOMAIN_RULE_FLAGS] = (struct dns_rule *)rule_flags; } + domain_rule->sub_rule_only = sub_rule_only; + domain_rule->root_rule_only = root_rule_only; + rule_flags = (struct dns_rule_flags *)domain_rule->rules[DOMAIN_RULE_FLAGS]; if (is_clear == false) { rule_flags->flags |= flag; diff --git a/test/cases/test-rule.cc b/test/cases/test-rule.cc index 0a50080..ececea2 100644 --- a/test/cases/test-rule.cc +++ b/test/cases/test-rule.cc @@ -259,4 +259,59 @@ cache-persist no)"""); EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 700); EXPECT_EQ(client.GetAnswer()[0].GetType(), "A"); EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4"); -} \ No newline at end of file +} + +TEST_F(Rule, AAAA_SOA) +{ + smartdns::MockServer server_upstream; + smartdns::Server server; + + server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) { + if (request->qtype == DNS_T_A) { + smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4", 700); + return smartdns::SERVER_REQUEST_OK; + } else if (request->qtype == DNS_T_AAAA) { + smartdns::MockServer::AddIP(request, request->domain.c_str(), "64:ff9b::102:304", 700); + return smartdns::SERVER_REQUEST_OK; + } + return smartdns::SERVER_REQUEST_SOA; + }); + + server.Start(R"""(bind [::]:60053 +server 127.0.0.1:61053 +log-num 0 +log-console yes +log-level debug +speed-check-mode none +address /-.a.com/#6 +address /*.b.com/#6 +cache-persist no)"""); + smartdns::Client client; + ASSERT_TRUE(client.Query("a.com AAAA", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 0); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + + ASSERT_TRUE(client.Query("a.a.com AAAA", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 1); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.a.com"); + EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 700); + EXPECT_EQ(client.GetAnswer()[0].GetType(), "AAAA"); + EXPECT_EQ(client.GetAnswer()[0].GetData(), "64:ff9b::102:304"); + + ASSERT_TRUE(client.Query("a.b.com AAAA", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 0); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + + ASSERT_TRUE(client.Query("b.com AAAA", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 1); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_EQ(client.GetAnswer()[0].GetName(), "b.com"); + EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 700); + EXPECT_EQ(client.GetAnswer()[0].GetType(), "AAAA"); + EXPECT_EQ(client.GetAnswer()[0].GetData(), "64:ff9b::102:304"); +}