config: add option -no-cache and -response-mode for domain-rules and add some test cases.

This commit is contained in:
Nick Peng
2023-03-24 21:58:41 +08:00
parent 4f2867b7f4
commit dd23c5fc31
16 changed files with 1004 additions and 66 deletions

View File

@@ -202,6 +202,9 @@ static void *_new_dns_rule(enum domain_rule domain_rule)
case DOMAIN_RULE_CHECKSPEED:
size = sizeof(struct dns_domain_check_orders);
break;
case DOMAIN_RULE_RESPONSE_MODE:
size = sizeof(struct dns_response_mode_rule);
break;
case DOMAIN_RULE_CNAME:
size = sizeof(struct dns_cname_rule);
break;
@@ -2388,6 +2391,38 @@ errout:
return 0;
}
static int _conf_domain_rule_response_mode(char *domain, const char *mode)
{
enum response_mode_type response_mode_type = DNS_RESPONSE_MODE_FIRST_PING_IP;
struct dns_response_mode_rule *response_mode = NULL;
for (int i = 0; dns_conf_response_mode_enum[i].name != NULL; i++) {
if (strcmp(mode, dns_conf_response_mode_enum[i].name) == 0) {
response_mode_type = dns_conf_response_mode_enum[i].id;
break;
}
}
response_mode = _new_dns_rule(DOMAIN_RULE_RESPONSE_MODE);
if (response_mode == NULL) {
goto errout;
}
response_mode->mode = response_mode_type;
if (_config_domain_rule_add(domain, DOMAIN_RULE_RESPONSE_MODE, response_mode) != 0) {
goto errout;
}
_dns_rule_put(&response_mode->head);
return 0;
errout:
if (response_mode) {
_dns_rule_put(&response_mode->head);
}
return 0;
}
static int _conf_domain_set(void *data, int argc, char *argv[])
{
int opt = 0;
@@ -2527,6 +2562,11 @@ static int _conf_domain_rule_delete(const char *domain)
return _config_domain_rule_delete(domain);
}
static int _conf_domain_rule_no_cache(const char *domain)
{
return _config_domain_rule_flag_set(domain, DOMAIN_FLAG_NO_CACHE, 0);
}
static int _conf_domain_rules(void *data, int argc, char *argv[])
{
int opt = 0;
@@ -2539,6 +2579,7 @@ static int _conf_domain_rules(void *data, int argc, char *argv[])
/* clang-format off */
static struct option long_options[] = {
{"speed-check-mode", required_argument, NULL, 'c'},
{"response-mode", required_argument, NULL, 'r'},
{"address", required_argument, NULL, 'a'},
{"ipset", required_argument, NULL, 'p'},
{"nftset", required_argument, NULL, 't'},
@@ -2550,6 +2591,7 @@ static int _conf_domain_rules(void *data, int argc, char *argv[])
{"rr-ttl-max", required_argument, NULL, 253},
{"no-serve-expired", no_argument, NULL, 254},
{"delete", no_argument, NULL, 255},
{"no-cache", no_argument, NULL, 256},
{NULL, no_argument, NULL, 0}
};
/* clang-format on */
@@ -2566,7 +2608,7 @@ static int _conf_domain_rules(void *data, int argc, char *argv[])
/* process extra options */
optind = 1;
while (1) {
opt = getopt_long_only(argc, argv, "c:a:p:t:n:d:A:", long_options, NULL);
opt = getopt_long_only(argc, argv, "c:a:p:t:n:d:A:r:", long_options, NULL);
if (opt == -1) {
break;
}
@@ -2585,6 +2627,19 @@ static int _conf_domain_rules(void *data, int argc, char *argv[])
break;
}
case 'r': {
const char *response_mode = optarg;
if (response_mode == NULL) {
goto errout;
}
if (_conf_domain_rule_response_mode(domain, response_mode) != 0) {
tlog(TLOG_ERROR, "add response-mode rule failed.");
goto errout;
}
break;
}
case 'a': {
const char *address = optarg;
if (address == NULL) {
@@ -2684,6 +2739,14 @@ static int _conf_domain_rules(void *data, int argc, char *argv[])
return 0;
}
case 256: {
if (_conf_domain_rule_no_cache(domain) != 0) {
tlog(TLOG_ERROR, "set no-cache rule failed.");
goto errout;
}
break;
}
default:
break;
}

View File

@@ -74,6 +74,7 @@ enum domain_rule {
DOMAIN_RULE_NFTSET_IP6,
DOMAIN_RULE_NAMESERVER,
DOMAIN_RULE_CHECKSPEED,
DOMAIN_RULE_RESPONSE_MODE,
DOMAIN_RULE_CNAME,
DOMAIN_RULE_TTL,
DOMAIN_RULE_MAX,
@@ -110,6 +111,7 @@ typedef enum {
#define DOMAIN_FLAG_NFTSET_IP6_IGN (1 << 14)
#define DOMAIN_FLAG_NO_SERVE_EXPIRED (1 << 15)
#define DOMAIN_FLAG_CNAME_IGN (1 << 16)
#define DOMAIN_FLAG_NO_CACHE (1 << 17)
#define SERVER_FLAG_EXCLUDE_DEFAULT (1 << 0)
@@ -124,6 +126,12 @@ typedef enum {
#define BIND_FLAG_FORCE_AAAA_SOA (1 << 8)
#define BIND_FLAG_NO_RULE_CNAME (1 << 9)
enum response_mode_type {
DNS_RESPONSE_MODE_FIRST_PING_IP = 0,
DNS_RESPONSE_MODE_FASTEST_IP,
DNS_RESPONSE_MODE_FASTEST_RESPONSE,
};
struct dns_rule {
atomic_t refcnt;
enum domain_rule rule;
@@ -227,6 +235,11 @@ struct dns_domain_check_orders {
struct dns_domain_check_order orders[DOMAIN_CHECK_NUM];
};
struct dns_response_mode_rule {
struct dns_rule head;
enum response_mode_type mode;
};
struct dns_group_table {
DECLARE_HASHTABLE(group, 8);
};
@@ -464,11 +477,6 @@ extern int dns_conf_dualstack_ip_allow_force_AAAA;
extern int dns_conf_dualstack_ip_selection_threshold;
extern int dns_conf_max_reply_ip_num;
enum response_mode_type {
DNS_RESPONSE_MODE_FIRST_PING_IP = 0,
DNS_RESPONSE_MODE_FASTEST_IP,
DNS_RESPONSE_MODE_FASTEST_RESPONSE,
};
extern enum response_mode_type dns_conf_response_mode;
extern int dns_conf_rr_ttl;

View File

@@ -304,10 +304,13 @@ struct dns_request {
struct dns_domain_check_orders *check_order_list;
int check_order;
enum response_mode_type response_mode;
struct dns_request_pending_list *request_pending_list;
int no_select_possible_ip;
int no_cache_cname;
int no_cache;
};
/* dns server data */
@@ -1269,7 +1272,7 @@ static int _dns_cache_cname_packet(struct dns_server_post_context *context)
struct dns_request *request = context->request;
if (request->has_cname == 0 || request->no_cache_cname == 1) {
if (request->has_cname == 0 || request->no_cache_cname == 1 || request->no_cache == 1) {
return 0;
}
@@ -1514,7 +1517,7 @@ static int _dns_cache_reply_packet(struct dns_server_post_context *context)
{
struct dns_request *request = context->request;
int has_soa = request->has_soa;
if (context->do_cache == 0 || _dns_server_has_bind_flag(request, BIND_FLAG_NO_CACHE) == 0) {
if (context->do_cache == 0 || request->no_cache == 1) {
return 0;
}
@@ -2433,6 +2436,7 @@ static struct dns_request *_dns_server_new_request(void)
request->qclass = DNS_C_IN;
request->result_callback = NULL;
request->check_order_list = &dns_conf_check_orders;
request->response_mode = dns_conf_response_mode;
INIT_LIST_HEAD(&request->list);
INIT_LIST_HEAD(&request->pending_list);
INIT_LIST_HEAD(&request->check_list);
@@ -2585,7 +2589,7 @@ out:
}
/* Get first ping result */
if (dns_conf_response_mode == DNS_RESPONSE_MODE_FIRST_PING_IP && last_rtt == -1 && request->ping_time > 0) {
if (request->response_mode == DNS_RESPONSE_MODE_FIRST_PING_IP && last_rtt == -1 && request->ping_time > 0) {
may_complete = 1;
}
@@ -3422,7 +3426,7 @@ static int dns_server_resolve_callback(const char *domain, dns_result_type rtype
return _dns_server_reply_passthrough(&context);
}
if (request->prefetch == 0 && dns_conf_response_mode == DNS_RESPONSE_MODE_FASTEST_RESPONSE &&
if (request->prefetch == 0 && request->response_mode == DNS_RESPONSE_MODE_FASTEST_RESPONSE &&
atomic_read(&request->notified) == 0) {
struct dns_server_post_context context;
int ttl = 0;
@@ -3936,6 +3940,10 @@ static int _dns_server_pre_process_rule_flags(struct dns_request *request)
request->no_serve_expired = 1;
}
if ((flags & DOMAIN_FLAG_NO_CACHE) || (_dns_server_has_bind_flag(request, BIND_FLAG_NO_CACHE) == 0)) {
request->no_cache = 1;
}
if (flags & DOMAIN_FLAG_ADDR_IGN) {
/* ignore this domain */
goto out;
@@ -4432,17 +4440,22 @@ static int _dns_server_qtype_soa(struct dns_request *request)
return -1;
}
static void _dns_server_process_speed_check_rule(struct dns_request *request)
static void _dns_server_process_speed_rule(struct dns_request *request)
{
struct dns_domain_check_orders *check_order = NULL;
struct dns_response_mode_rule *response_mode = NULL;
/* get domain rule flag */
/* get speed check mode */
check_order = _dns_server_get_dns_rule(request, DOMAIN_RULE_CHECKSPEED);
if (check_order == NULL) {
return;
if (check_order != NULL) {
request->check_order_list = check_order;
}
request->check_order_list = check_order;
/* get response mode */
response_mode = _dns_server_get_dns_rule(request, DOMAIN_RULE_RESPONSE_MODE);
if (response_mode != NULL) {
request->response_mode = response_mode->mode;
}
}
static int _dns_server_get_expired_ttl_reply(struct dns_cache *dns_cache)
@@ -5073,7 +5086,7 @@ static int _dns_server_do_query(struct dns_request *request, int skip_notify_eve
}
/* process speed check rule */
_dns_server_process_speed_check_rule(request);
_dns_server_process_speed_rule(request);
/* check and set passthrough */
_dns_server_check_set_passthrough(request);

165
test/cases/test-address.cc Normal file
View File

@@ -0,0 +1,165 @@
/*************************************************************************
*
* Copyright (C) 2018-2023 Ruilin Peng (Nick) <pymumu@gmail.com>.
*
* smartdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* smartdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#include "client.h"
#include "dns.h"
#include "include/utils.h"
#include "server.h"
#include "util.h"
#include "gtest/gtest.h"
#include <fstream>
class Address : public ::testing::Test
{
protected:
virtual void SetUp() {}
virtual void TearDown() {}
};
TEST_F(Address, soa)
{
smartdns::MockServer server_upstream;
smartdns::Server server;
std::map<int, int> qid_map;
server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) {
if (request->qtype == DNS_T_A) {
smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4", 700);
return smartdns::SERVER_REQUEST_OK;
} else if (request->qtype == DNS_T_AAAA) {
smartdns::MockServer::AddIP(request, request->domain.c_str(), "64:ff9b::102:304", 700);
return smartdns::SERVER_REQUEST_OK;
}
return smartdns::SERVER_REQUEST_SOA;
});
server.Start(R"""(bind [::]:60053
server 127.0.0.1:61053
log-num 0
log-console yes
log-level debug
speed-check-mode none
address /a.com/#4
address /b.com/#6
address /c.com/#
cache-persist no)""");
smartdns::Client client;
ASSERT_TRUE(client.Query("a.com A", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 0);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAuthority()[0].GetName(), "a.com");
EXPECT_EQ(client.GetAuthority()[0].GetTTL(), 30);
EXPECT_EQ(client.GetAuthority()[0].GetType(), "SOA");
EXPECT_EQ(client.GetAuthority()[0].GetData(),
"a.gtld-servers.net. nstld.verisign-grs.com. 1800 1800 900 604800 86400");
ASSERT_TRUE(client.Query("a.com AAAA", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 1);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 700);
EXPECT_EQ(client.GetAnswer()[0].GetType(), "AAAA");
EXPECT_EQ(client.GetAnswer()[0].GetData(), "64:ff9b::102:304");
ASSERT_TRUE(client.Query("b.com A", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 1);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAnswer()[0].GetName(), "b.com");
EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 700);
EXPECT_EQ(client.GetAnswer()[0].GetType(), "A");
EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
ASSERT_TRUE(client.Query("b.com AAAA", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 0);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAuthority()[0].GetName(), "b.com");
EXPECT_EQ(client.GetAuthority()[0].GetTTL(), 30);
EXPECT_EQ(client.GetAuthority()[0].GetType(), "SOA");
EXPECT_EQ(client.GetAuthority()[0].GetData(),
"a.gtld-servers.net. nstld.verisign-grs.com. 1800 1800 900 604800 86400");
ASSERT_TRUE(client.Query("c.com A", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 0);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAuthority()[0].GetName(), "c.com");
EXPECT_EQ(client.GetAuthority()[0].GetTTL(), 30);
EXPECT_EQ(client.GetAuthority()[0].GetType(), "SOA");
EXPECT_EQ(client.GetAuthority()[0].GetData(),
"a.gtld-servers.net. nstld.verisign-grs.com. 1800 1800 900 604800 86400");
ASSERT_TRUE(client.Query("c.com AAAA", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 0);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAuthority()[0].GetName(), "c.com");
EXPECT_EQ(client.GetAuthority()[0].GetTTL(), 30);
EXPECT_EQ(client.GetAuthority()[0].GetType(), "SOA");
EXPECT_EQ(client.GetAuthority()[0].GetData(),
"a.gtld-servers.net. nstld.verisign-grs.com. 1800 1800 900 604800 86400");
}
TEST_F(Address, ip)
{
smartdns::MockServer server_upstream;
smartdns::Server server;
std::map<int, int> qid_map;
server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) {
if (request->qtype == DNS_T_A) {
smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4", 700);
return smartdns::SERVER_REQUEST_OK;
} else if (request->qtype == DNS_T_AAAA) {
smartdns::MockServer::AddIP(request, request->domain.c_str(), "64:ff9b::102:304", 700);
return smartdns::SERVER_REQUEST_OK;
}
return smartdns::SERVER_REQUEST_SOA;
});
server.Start(R"""(bind [::]:60053
server 127.0.0.1:61053
log-num 0
log-console yes
log-level debug
speed-check-mode none
address /a.com/10.10.10.10
address /a.com/64:ff9b::1010:1010
cache-persist no)""");
smartdns::Client client;
ASSERT_TRUE(client.Query("a.com A", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 1);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 600);
EXPECT_EQ(client.GetAnswer()[0].GetType(), "A");
EXPECT_EQ(client.GetAnswer()[0].GetData(), "10.10.10.10");
ASSERT_TRUE(client.Query("a.com AAAA", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 1);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 600);
EXPECT_EQ(client.GetAnswer()[0].GetType(), "AAAA");
EXPECT_EQ(client.GetAnswer()[0].GetData(), "64:ff9b::1010:1010");
}

View File

@@ -84,7 +84,7 @@ cache-persist no)""");
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 1);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_GT(client.GetAnswer()[0].GetTTL(), 609);
EXPECT_GE(client.GetAnswer()[0].GetTTL(), 609);
EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
}
@@ -119,4 +119,45 @@ cache-persist no)""");
EXPECT_LT(client.GetQueryTime(), 100);
EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3);
EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
}
TEST(Bind, nocache)
{
smartdns::MockServer server_upstream;
smartdns::MockServer server_upstream2;
smartdns::Server server;
server_upstream.Start("udp://0.0.0.0:61053", [](struct smartdns::ServerRequestContext *request) {
unsigned char addr[4] = {1, 2, 3, 4};
usleep(15 * 1000);
dns_add_A(request->response_packet, DNS_RRS_AN, request->domain.c_str(), 611, addr);
request->response_packet->head.rcode = DNS_RC_NOERROR;
return smartdns::SERVER_REQUEST_OK;
});
server.Start(R"""(
bind [::]:60053 --no-cache
bind-tcp [::]:60053
server 127.0.0.1:61053
log-num 0
log-console yes
log-level debug
cache-persist no)""");
smartdns::Client client;
ASSERT_TRUE(client.Query("a.com", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 1);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3);
EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
ASSERT_TRUE(client.Query("a.com", 60053));
EXPECT_GT(client.GetQueryTime(), 10);
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 1);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3);
EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
}

View File

@@ -151,7 +151,7 @@ server 127.0.0.1:61053
log-num 0
cache-size 1
rr-ttl-min 600
rr-ttl-reply-max 5
rr-ttl-reply-max 6
log-console yes
log-level debug
cache-persist no)""");
@@ -169,7 +169,62 @@ cache-persist no)""");
ASSERT_EQ(client.GetAnswerNum(), 1);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 5);
EXPECT_GE(client.GetAnswer()[0].GetTTL(), 5);
EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
}
TEST_F(Cache, nocache)
{
smartdns::MockServer server_upstream;
smartdns::Server server;
server_upstream.Start("udp://0.0.0.0:61053", [](struct smartdns::ServerRequestContext *request) {
std::string domain = request->domain;
if (request->domain.length() == 0) {
return smartdns::SERVER_REQUEST_ERROR;
}
usleep(15000);
if (request->qtype == DNS_T_A) {
unsigned char addr[4] = {1, 2, 3, 4};
dns_add_A(request->response_packet, DNS_RRS_AN, domain.c_str(), 0, addr);
} else if (request->qtype == DNS_T_AAAA) {
unsigned char addr[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
dns_add_AAAA(request->response_packet, DNS_RRS_AN, domain.c_str(), 0, addr);
} else {
return smartdns::SERVER_REQUEST_ERROR;
}
request->response_packet->head.rcode = DNS_RC_NOERROR;
return smartdns::SERVER_REQUEST_OK;
});
server.Start(R"""(bind [::]:60053
server 127.0.0.1:61053
log-num 0
cache-size 100
rr-ttl-min 600
rr-ttl-reply-max 5
log-console yes
log-level debug
domain-rules /a.com/ --no-cache
cache-persist no)""");
smartdns::Client client;
ASSERT_TRUE(client.Query("a.com", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 1);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3);
EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
ASSERT_TRUE(client.Query("a.com", 60053));
EXPECT_GT(client.GetQueryTime(), 10);
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 1);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3);
EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
}

View File

@@ -28,22 +28,13 @@ TEST(server, cname)
smartdns::Server server;
server_upstream.Start("udp://0.0.0.0:61053", [](struct smartdns::ServerRequestContext *request) {
std::string domain = request->domain;
if (request->domain.length() == 0) {
return smartdns::SERVER_REQUEST_ERROR;
if (request->qtype != DNS_T_A) {
return smartdns::SERVER_REQUEST_SOA;
}
if (request->qtype == DNS_T_A) {
unsigned char addr[4] = {1, 2, 3, 4};
dns_add_A(request->response_packet, DNS_RRS_AN, domain.c_str(), 61, addr);
} else if (request->qtype == DNS_T_AAAA) {
unsigned char addr[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
dns_add_AAAA(request->response_packet, DNS_RRS_AN, domain.c_str(), 61, addr);
} else {
return smartdns::SERVER_REQUEST_ERROR;
}
EXPECT_EQ(domain, "e.com");
unsigned char addr[4] = {1, 2, 3, 4};
dns_add_A(request->response_packet, DNS_RRS_AN, request->domain.c_str(), 611, addr);
EXPECT_EQ(request->domain, "e.com");
request->response_packet->head.rcode = DNS_RC_NOERROR;
return smartdns::SERVER_REQUEST_OK;

65
test/cases/test-dns64.cc Normal file
View File

@@ -0,0 +1,65 @@
/*************************************************************************
*
* Copyright (C) 2018-2023 Ruilin Peng (Nick) <pymumu@gmail.com>.
*
* smartdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* smartdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#include "client.h"
#include "dns.h"
#include "include/utils.h"
#include "server.h"
#include "util.h"
#include "gtest/gtest.h"
#include <fstream>
class DNS64 : public ::testing::Test
{
protected:
virtual void SetUp() {}
virtual void TearDown() {}
};
TEST_F(DNS64, no_dualstack)
{
smartdns::MockServer server_upstream;
smartdns::Server server;
std::map<int, int> qid_map;
server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) {
if (request->qtype == DNS_T_A) {
smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4");
return smartdns::SERVER_REQUEST_OK;
}
return smartdns::SERVER_REQUEST_SOA;
});
server.Start(R"""(bind [::]:60053
server 127.0.0.1:61053
log-num 0
dns64 64:ff9b::/96
log-console yes
dualstack-ip-selection no
log-level debug
cache-persist no)""");
smartdns::Client client;
ASSERT_TRUE(client.Query("a.com AAAA", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 1);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3);
EXPECT_EQ(client.GetAnswer()[0].GetType(), "AAAA");
EXPECT_EQ(client.GetAnswer()[0].GetData(), "64:ff9b::102:304");
}

View File

@@ -0,0 +1,88 @@
/*************************************************************************
*
* Copyright (C) 2018-2023 Ruilin Peng (Nick) <pymumu@gmail.com>.
*
* smartdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* smartdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#include "client.h"
#include "dns.h"
#include "include/utils.h"
#include "server.h"
#include "util.h"
#include "gtest/gtest.h"
#include <fstream>
class DomainSet : public ::testing::Test
{
protected:
virtual void SetUp() {}
virtual void TearDown() {}
};
TEST_F(DomainSet, set_add)
{
smartdns::MockServer server_upstream;
smartdns::Server server;
std::map<int, int> qid_map;
smartdns::TempFile file_set;
std::vector<std::string> domain_list;
int count = 16;
std::string config = "domain-set -name test-set -file " + file_set.GetPath() + "\n";
config += R"""(bind [::]:60053
server 127.0.0.1:61053
log-num 0
log-console yes
log-level info
cache-persist no
domain-rules /domain-set:test-set/ -c none --dualstack-ip-selection no -a 9.9.9.9
)""";
server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) {
if (request->qtype == DNS_T_A) {
smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4");
return smartdns::SERVER_REQUEST_OK;
}
return smartdns::SERVER_REQUEST_SOA;
});
for (int i = 0; i < count; i++) {
auto domain = smartdns::GenerateRandomString(10) + "." + smartdns::GenerateRandomString(3);
file_set.Write(domain);
file_set.Write("\n");
domain_list.emplace_back(domain);
}
std::cout << config << std::endl;
server.Start(config);
smartdns::Client client;
for (auto &domain : domain_list) {
ASSERT_TRUE(client.Query(domain, 60053));
ASSERT_EQ(client.GetAnswerNum(), 1);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAnswer()[0].GetName(), domain);
EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 600);
EXPECT_EQ(client.GetAnswer()[0].GetType(), "A");
EXPECT_EQ(client.GetAnswer()[0].GetData(), "9.9.9.9");
}
ASSERT_TRUE(client.Query("a.com", 60053));
ASSERT_EQ(client.GetAnswerNum(), 1);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3);
EXPECT_EQ(client.GetAnswer()[0].GetType(), "A");
EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
}

View File

@@ -0,0 +1,191 @@
/*************************************************************************
*
* Copyright (C) 2018-2023 Ruilin Peng (Nick) <pymumu@gmail.com>.
*
* smartdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* smartdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#include "client.h"
#include "dns.h"
#include "include/utils.h"
#include "server.h"
#include "util.h"
#include "gtest/gtest.h"
#include <fstream>
class SpeedCheck : public ::testing::Test
{
protected:
virtual void SetUp() {}
virtual void TearDown() {}
};
TEST_F(SpeedCheck, response_mode)
{
smartdns::MockServer server_upstream;
smartdns::Server server;
std::map<int, int> qid_map;
server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) {
if (request->qtype == DNS_T_A) {
smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4");
smartdns::MockServer::AddIP(request, request->domain.c_str(), "5.6.7.8");
return smartdns::SERVER_REQUEST_OK;
}
return smartdns::SERVER_REQUEST_SOA;
});
server.Start(R"""(bind [::]:60053
server 127.0.0.1:61053
log-num 0
log-console yes
response-mode first-ping
domain-rules /a.com/ -r fastest-response
log-level debug
cache-persist no)""");
smartdns::Client client;
ASSERT_TRUE(client.Query("b.com", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 1);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_GT(client.GetQueryTime(), 100);
EXPECT_EQ(client.GetAnswer()[0].GetName(), "b.com");
EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 600);
ASSERT_TRUE(client.Query("a.com", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 2);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_LT(client.GetQueryTime(), 10);
EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3);
EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
EXPECT_EQ(client.GetAnswer()[1].GetData(), "5.6.7.8");
}
TEST_F(SpeedCheck, none)
{
smartdns::MockServer server_upstream;
smartdns::Server server;
std::map<int, int> qid_map;
server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) {
if (request->qtype == DNS_T_A) {
smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4");
smartdns::MockServer::AddIP(request, request->domain.c_str(), "5.6.7.8");
return smartdns::SERVER_REQUEST_OK;
}
return smartdns::SERVER_REQUEST_SOA;
});
server.Start(R"""(bind [::]:60053
server 127.0.0.1:61053
log-num 0
log-console yes
speed-check-mode none
log-level debug
cache-persist no)""");
smartdns::Client client;
ASSERT_TRUE(client.Query("b.com", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 2);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_LT(client.GetQueryTime(), 20);
EXPECT_EQ(client.GetAnswer()[0].GetName(), "b.com");
EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 600);
ASSERT_TRUE(client.Query("a.com", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 2);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_LT(client.GetQueryTime(), 20);
EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 600);
EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
EXPECT_EQ(client.GetAnswer()[1].GetData(), "5.6.7.8");
}
TEST_F(SpeedCheck, domain_rules_none)
{
smartdns::MockServer server_upstream;
smartdns::Server server;
std::map<int, int> qid_map;
server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) {
if (request->qtype == DNS_T_A) {
smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4");
smartdns::MockServer::AddIP(request, request->domain.c_str(), "5.6.7.8");
return smartdns::SERVER_REQUEST_OK;
}
return smartdns::SERVER_REQUEST_SOA;
});
server.Start(R"""(bind [::]:60053
server 127.0.0.1:61053
log-num 0
log-console yes
domain-rules /a.com/ -c none
log-level debug
cache-persist no)""");
smartdns::Client client;
ASSERT_TRUE(client.Query("b.com", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 1);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_GT(client.GetQueryTime(), 200);
EXPECT_EQ(client.GetAnswer()[0].GetName(), "b.com");
EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 600);
ASSERT_TRUE(client.Query("a.com", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 2);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_LT(client.GetQueryTime(), 20);
EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 600);
EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
EXPECT_EQ(client.GetAnswer()[1].GetData(), "5.6.7.8");
}
TEST_F(SpeedCheck, only_ping)
{
smartdns::MockServer server_upstream;
smartdns::Server server;
std::map<int, int> qid_map;
server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) {
if (request->qtype == DNS_T_A) {
smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4");
smartdns::MockServer::AddIP(request, request->domain.c_str(), "5.6.7.8");
return smartdns::SERVER_REQUEST_OK;
}
return smartdns::SERVER_REQUEST_SOA;
});
server.Start(R"""(bind [::]:60053
server 127.0.0.1:61053
log-num 0
log-console yes
speed-check-mode ping
log-level debug
cache-persist no)""");
smartdns::Client client;
ASSERT_TRUE(client.Query("b.com", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 1);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_LT(client.GetQueryTime(), 1200);
EXPECT_EQ(client.GetAnswer()[0].GetName(), "b.com");
EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 600);
}

View File

@@ -255,6 +255,11 @@ bool Client::ParserResult()
answer_num_ = std::stoi(match[1]);
}
std::regex reg_authority_num(", AUTHORITY: ([0-9]+),");
if (std::regex_search(result_, match, reg_authority_num)) {
records_authority_num_ = std::stoi(match[1]);
}
std::regex reg_status(", status: ([A-Z]+),");
if (std::regex_search(result_, match, reg_status)) {
status_ = match[1];
@@ -301,6 +306,19 @@ bool Client::ParserResult()
}
}
std::regex reg_authority(";; AUTHORITY SECTION:\\n((?:.|\\n|\\r\\n)+?)\\n{2,}",
std::regex::ECMAScript | std::regex::optimize);
if (std::regex_search(result_, match, reg_authority)) {
if (ParserRecord(match[1], records_authority_) == false) {
return false;
}
if (records_authority_num_ != records_authority_.size()) {
std::cout << "DIG FAILED: Num Not Match\n" << result_ << std::endl;
return false;
}
}
std::regex reg_addition(";; ADDITIONAL SECTION:\\n((?:.|\\n|\\r\\n)+?)\\n{2,}",
std::regex::ECMAScript | std::regex::optimize);
if (std::regex_search(result_, match, reg_answer)) {

View File

@@ -90,6 +90,7 @@ class Client
bool ParserRecord(const std::string &record_str, std::vector<DNSRecord> &record);
std::string result_;
int answer_num_{0};
int records_authority_num_{0};
std::string status_;
std::string server_;
int query_time_{0};

View File

@@ -19,8 +19,10 @@
#ifndef _SMARTDNS_TEST_UTILS_
#define _SMARTDNS_TEST_UTILS_
#include <fstream>
#include <functional>
#include <string>
#include <vector>
namespace smartdns
{
@@ -56,9 +58,54 @@ class DeferGuard
#define SMARTDNS_CONCAT(a, b) SMARTDNS_CONCAT_(a, b)
#define Defer ::smartdns::DeferGuard SMARTDNS_CONCAT(__defer__, __LINE__) = [&]()
class TempFile
{
public:
TempFile();
TempFile(const std::string &line);
virtual ~TempFile();
bool Write(const std::string &line);
std::string GetPath();
void SetPattern(const std::string &pattern);
private:
bool NewTempFile();
std::string path_;
std::ofstream ofs_;
std::string pattern_;
};
class Commander
{
public:
Commander();
virtual ~Commander();
bool Run(const std::vector<std::string> &cmds);
bool Run(const std::string &cmd);
void Kill();
void Terminate();
int ExitCode();
int GetPid();
private:
pid_t pid_{-1};
int exit_code_ = {-1};
};
bool IsCommandExists(const std::string &cmd);
std::string GenerateRandomString(int len);
int ParserArg(const std::string &cmd, std::vector<std::string> &args);
} // namespace smartdns
#endif // _SMARTDNS_TEST_UTILS_

View File

@@ -283,7 +283,8 @@ bool MockServer::Start(const std::string &url, ServerRequest callback)
return true;
}
Server::Server() {
Server::Server()
{
mode_ = Server::CREATE_MODE_FORK;
}
@@ -312,26 +313,9 @@ bool Server::Start(const std::string &conf, enum CONF_TYPE type)
};
if (type == CONF_TYPE_STRING) {
char filename[128];
strncpy(filename, "/tmp/smartdns_conf.XXXXXX", sizeof(filename));
int fd = mkstemp(filename);
if (fd < 0) {
return false;
}
Defer
{
close(fd);
};
std::ofstream ofs(filename);
if (ofs.is_open() == false) {
return false;
}
ofs.write(conf.data(), conf.size());
ofs.flush();
ofs.close();
conf_file = filename;
clean_conf_file_ = true;
conf_temp_file_.SetPattern("/tmp/smartdns_conf.XXXXXX");
conf_temp_file_.Write(conf);
conf_file = conf_temp_file_.GetPath();
} else if (type == CONF_TYPE_FILE) {
conf_file = conf;
} else {
@@ -418,11 +402,6 @@ void Server::Stop(bool graceful)
waitpid(pid_, nullptr, 0);
pid_ = 0;
if (clean_conf_file_ == true) {
unlink(conf_file_.c_str());
conf_file_.clear();
clean_conf_file_ = false;
}
}
bool Server::IsRunning()

View File

@@ -20,6 +20,7 @@
#define _SMARTDNS_SERVER_
#include "dns.h"
#include "include/utils.h"
#include <functional>
#include <string>
#include <sys/socket.h>
@@ -53,7 +54,8 @@ class Server
std::thread thread_;
int fd_;
std::string conf_file_;
bool clean_conf_file_{false};
TempFile conf_temp_file_;
enum CREATE_MODE mode_;
};

View File

@@ -1,12 +1,179 @@
#include "include/utils.h"
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/wait.h>
#include <unistd.h>
namespace smartdns
{
TempFile::TempFile()
{
pattern_ = "/tmp/smartdns-test-tmp.XXXXXX";
}
TempFile::TempFile(const std::string &line)
{
pattern_ = "/tmp/smartdns-test-tmp.XXXXXX";
}
TempFile::~TempFile()
{
if (ofs_.is_open()) {
ofs_.close();
ofs_.clear();
}
if (path_.length() > 0) {
unlink(path_.c_str());
}
}
void TempFile::SetPattern(const std::string &pattern)
{
pattern_ = pattern;
}
bool TempFile::Write(const std::string &line)
{
if (ofs_.is_open() == false) {
if (NewTempFile() == false) {
return false;
}
}
ofs_.write(line.data(), line.size());
if (ofs_.fail()) {
return false;
}
ofs_.flush();
return true;
}
bool TempFile::NewTempFile()
{
char filename[128];
strncpy(filename, "/tmp/smartdns-test-tmp.XXXXXX", sizeof(filename));
int fd = mkstemp(filename);
if (fd < 0) {
return false;
}
Defer
{
close(fd);
};
std::ofstream ofs(filename);
if (ofs.is_open() == false) {
return false;
}
ofs_ = std::move(ofs);
path_ = filename;
return true;
}
std::string TempFile::GetPath()
{
if (ofs_.is_open() == false) {
if (NewTempFile() == false) {
return "";
}
}
return path_;
}
Commander::Commander() {}
Commander::~Commander()
{
Kill();
}
bool Commander::Run(const std::string &cmd)
{
std::vector<std::string> args;
if (ParserArg(cmd, args) != 0) {
return false;
}
return Run(args);
}
bool Commander::Run(const std::vector<std::string> &cmds)
{
pid_t pid;
if (pid_ > 0) {
return false;
}
pid = fork();
if (pid < 0) {
return false;
}
if (pid == 0) {
char *argv[cmds.size() + 1];
for (int i = 0; i < cmds.size(); i++) {
argv[i] = (char *)cmds[i].c_str();
}
argv[cmds.size()] = nullptr;
execvp(argv[0], argv);
_exit(1);
}
pid_ = pid;
return true;
}
void Commander::Kill()
{
if (pid_ <= 0) {
return;
}
kill(pid_, SIGKILL);
}
void Commander::Terminate()
{
if (pid_ <= 0) {
return;
}
kill(pid_, SIGTERM);
}
int Commander::ExitCode()
{
int wstatus = 0;
if (exit_code_ >= 0) {
return exit_code_;
}
if (pid_ <= 0) {
return -1;
}
if (waitpid(pid_, &wstatus, 0) == -1) {
return -1;
}
exit_code_ = WEXITSTATUS(wstatus);
return exit_code_;
}
int Commander::GetPid()
{
return pid_;
}
bool IsCommandExists(const std::string &cmd)
{
char *copy_path = nullptr;
@@ -42,17 +209,61 @@ bool IsCommandExists(const std::string &cmd)
std::string GenerateRandomString(int len)
{
std::string result;
std::string result;
static const char alphanum[] = "0123456789"
"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
"abcdefghijklmnopqrstuvwxyz";
result.resize(len);
result.resize(len);
for (int i = 0; i < len; ++i) {
result[i] = alphanum[rand() % (sizeof(alphanum) - 1)];
}
for (int i = 0; i < len; ++i) {
result[i] = alphanum[rand() % (sizeof(alphanum) - 1)];
}
return result;
return result;
}
int ParserArg(const std::string &cmd, std::vector<std::string> &args)
{
std::string arg;
char quoteChar = 0;
for (char ch : cmd) {
if (quoteChar == '\\') {
arg.push_back(ch);
quoteChar = 0;
continue;
}
if (quoteChar && ch != quoteChar) {
arg.push_back(ch);
continue;
}
switch (ch) {
case '\'':
case '\"':
case '\\':
quoteChar = quoteChar ? 0 : ch;
break;
case ' ':
case '\t':
case '\n':
if (!arg.empty()) {
args.push_back(arg);
arg.clear();
}
break;
default:
arg.push_back(ch);
break;
}
}
if (!arg.empty()) {
args.push_back(arg);
}
return 0;
}
} // namespace smartdns