From cf52eeacc98a1e7ec3804a86d593a16bfac37cb5 Mon Sep 17 00:00:00 2001 From: Nick Peng Date: Tue, 14 Mar 2023 23:52:10 +0800 Subject: [PATCH] test: add test case for cname --- src/dns.c | 4 --- src/dns_server.c | 20 ++++++++++++++ src/smartdns.c | 7 ++--- test/cases/test-cname.cc | 52 +++++++++++++++++++++++++++++++++++ test/cases/test-tls-server.cc | 16 +++++++++++ test/client.cc | 25 +++++++++++++---- test/server.cc | 43 +++++++++++++++-------------- test/server.h | 1 + 8 files changed, 134 insertions(+), 34 deletions(-) create mode 100644 test/cases/test-cname.cc diff --git a/src/dns.c b/src/dns.c index 7eedf89..b88051d 100644 --- a/src/dns.c +++ b/src/dns.c @@ -1244,10 +1244,6 @@ int dns_get_domain(struct dns_rrs *rrs, char *domain, int maxsize, int *qtype, i { struct dns_context context; - if (rrs->type != DNS_T_CNAME) { - return -1; - } - _dns_init_context_by_rrs(rrs, &context); return _dns_get_qr_head(&context, domain, maxsize, qtype, qclass); } diff --git a/src/dns_server.c b/src/dns_server.c index dc9e484..6f58847 100644 --- a/src/dns_server.c +++ b/src/dns_server.c @@ -6110,6 +6110,23 @@ static void _dns_server_tcp_idle_check(void) } } +#ifdef TEST +static void _dns_server_check_need_exit(void) +{ + static int parent_pid = 0; + if (parent_pid == 0) { + parent_pid = getppid(); + } + + if (parent_pid != getppid()) { + tlog(TLOG_WARN, "parent process exit, exit too."); + dns_server_stop(); + } +} +#else +#define _dns_server_check_need_exit() +#endif + static void _dns_server_period_run_second(void) { static unsigned int sec = 0; @@ -6152,6 +6169,7 @@ static void _dns_server_period_run_second(void) } _dns_server_tcp_idle_check(); + _dns_server_check_need_exit(); if (sec % IPV6_READY_CHECK_TIME == 0 && is_ipv6_ready == 0) { _dns_server_check_ipv6_ready(); @@ -6812,6 +6830,8 @@ int dns_server_init(void) int epollfd = -1; int ret = -1; + _dns_server_check_need_exit(); + if (server.epoll_fd > 0) { return -1; } diff --git a/src/smartdns.c b/src/smartdns.c index d864f1f..837578b 100644 --- a/src/smartdns.c +++ b/src/smartdns.c @@ -745,11 +745,10 @@ int main(int argc, char *argv[]) goto errout; } - atexit(_smartdns_exit); smartdns_test_notify(1); - - return _smartdns_run(); - + ret = _smartdns_run(); + _smartdns_exit(); + return ret; errout: smartdns_test_notify(2); return 1; diff --git a/test/cases/test-cname.cc b/test/cases/test-cname.cc new file mode 100644 index 0000000..ec3d100 --- /dev/null +++ b/test/cases/test-cname.cc @@ -0,0 +1,52 @@ +#include "client.h" +#include "include/utils.h" +#include "server.h" +#include "gtest/gtest.h" +#include "dns.h" + +TEST(server, cname) +{ + smartdns::MockServer server_upstream; + smartdns::Server server; + + server_upstream.Start("udp://0.0.0.0:61053", [](struct smartdns::ServerRequestContext *request) { + std::string domain = request->domain; + if (request->domain.length() == 0) { + return false; + } + + if (request->qtype == DNS_T_A) { + unsigned char addr[4] = {1, 2, 3, 4}; + dns_add_A(request->response_packet, DNS_RRS_AN, domain.c_str(), 61, addr); + } else if (request->qtype == DNS_T_AAAA) { + unsigned char addr[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + dns_add_AAAA(request->response_packet, DNS_RRS_AN, domain.c_str(), 61, addr); + } else { + return false; + } + + EXPECT_EQ(domain, "e.com"); + + request->response_packet->head.rcode = DNS_RC_NOERROR; + return true; + }); + + server.Start(R"""(bind [::]:60053 +cname /a.com/b.com +cname /b.com/c.com +cname /c.com/d.com +cname /d.com/e.com +server 127.0.0.1:61053 +log-num 0 +log-console yes +log-level debug +cache-persist no)"""); + smartdns::Client client; + ASSERT_TRUE(client.Query("a.com", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 2); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com"); + EXPECT_EQ(client.GetAnswer()[0].GetData(), "b.com."); + EXPECT_EQ(client.GetAnswer()[1].GetData(), "1.2.3.4"); +} diff --git a/test/cases/test-tls-server.cc b/test/cases/test-tls-server.cc index b42fc34..8bb59b8 100644 --- a/test/cases/test-tls-server.cc +++ b/test/cases/test-tls-server.cc @@ -32,3 +32,19 @@ cache-persist no)"""); EXPECT_EQ(client.GetStatus(), "NOERROR"); EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4"); } + +TEST(server, DISABLED_TLSCN) +{ + smartdns::Server server; + + server.Start(R"""(bind [::]:60053 +server tls://1.0.0.1:853 +log-num 0 +log-console yes +log-level debug +cache-persist no)"""); + smartdns::Client client; + ASSERT_TRUE(client.Query("www.example.com", 60053)); + EXPECT_EQ(client.GetStatus(), "NOERROR"); +} + diff --git a/test/client.cc b/test/client.cc index fa6cb3c..afb08f3 100644 --- a/test/client.cc +++ b/test/client.cc @@ -224,11 +224,16 @@ bool Client::ParserRecord(const std::string &record_str, std::vector { DNSRecord r; - if (r.Parser(record_str) == false) { - return false; + std::vector lines = StringSplit(record_str, '\n'); + + for (auto &line : lines) { + if (r.Parser(line) == false) { + return false; + } + + record.push_back(r); } - record.push_back(r); return true; } @@ -272,21 +277,29 @@ bool Client::ParserResult() flags_ = match[1]; } - std::regex reg_question(";; QUESTION SECTION:\n.(.*\n)+?\n"); + std::regex reg_question(";; QUESTION SECTION:\\n((?:.|\\n|\\r\\n)+?)\\n{2,}", + std::regex::ECMAScript | std::regex::optimize); if (std::regex_search(result_, match, reg_question)) { if (ParserRecord(match[1], records_query_) == false) { return false; } } - std::regex reg_answer(";; ANSWER SECTION:\n(.*)\n"); + std::regex reg_answer(";; ANSWER SECTION:\\n((?:.|\\n|\\r\\n)+?)\\n{2,}", + std::regex::ECMAScript | std::regex::optimize); if (std::regex_search(result_, match, reg_answer)) { if (ParserRecord(match[1], records_answer_) == false) { return false; } + + if (answer_num_ != records_answer_.size()) { + std::cout << "DIG FAILED: Num Not Match\n" << result_ << std::endl; + return false; + } } - std::regex reg_addition(";; ADDITIONAL SECTION:\n(.*)\n"); + std::regex reg_addition(";; ADDITIONAL SECTION:\\n((?:.|\\n|\\r\\n)+?)\\n{2,}", + std::regex::ECMAScript | std::regex::optimize); if (std::regex_search(result_, match, reg_answer)) { if (ParserRecord(match[1], records_additional_) == false) { return false; diff --git a/test/server.cc b/test/server.cc index 7ac8efb..9400bd8 100644 --- a/test/server.cc +++ b/test/server.cc @@ -21,6 +21,7 @@ #include "util.h" #include #include +#include #include #include #include @@ -30,7 +31,6 @@ #include #include #include -#include namespace smartdns { @@ -87,13 +87,15 @@ void MockServer::Run() struct sockaddr_storage from; socklen_t addrlen = sizeof(from); unsigned char in_buff[4096]; + int query_id = 0; int len = recvfrom(fd_, in_buff, sizeof(in_buff), 0, (struct sockaddr *)&from, &addrlen); if (len < 0) { continue; } char packet_buff[4096]; - unsigned char out_buff[4096]; + unsigned char response_data_buff[4096]; + unsigned char response_packet_buff[4096]; memset(packet_buff, 0, sizeof(packet_buff)); struct dns_packet *packet = (struct dns_packet *)packet_buff; struct ServerRequestContext request; @@ -102,6 +104,7 @@ void MockServer::Run() int ret = dns_decode(packet, sizeof(packet_buff), in_buff, len); if (ret == 0) { request.packet = packet; + query_id = packet->head.id; if (packet->head.qr == DNS_QR_QUERY) { struct dns_rrs *rrs = NULL; int rr_count = 0; @@ -126,27 +129,26 @@ void MockServer::Run() request.fromlen = addrlen; request.request_data = in_buff; request.request_data_len = len; - request.response_data = out_buff; + request.response_packet = (struct dns_packet *)response_packet_buff; + request.response_data = response_data_buff; request.response_data_len = 0; - request.response_data_max_len = sizeof(out_buff); + request.response_data_max_len = sizeof(response_data_buff); + + struct dns_head head; + memset(&head, 0, sizeof(head)); + head.id = query_id; + head.qr = DNS_QR_ANSWER; + head.opcode = DNS_OP_QUERY; + head.aa = 0; + head.rd = 0; + head.ra = 1; + head.rcode = DNS_RC_SERVFAIL; + dns_packet_init(request.response_packet, sizeof(response_packet_buff), &head); auto callback_ret = callback_(&request); - if (callback_ret == false) { - unsigned char out_packet_buff[4096]; - struct dns_packet *out_packet = (struct dns_packet *)out_packet_buff; - struct dns_head head; - memset(&head, 0, sizeof(head)); - head.id = packet->head.id; - head.qr = DNS_QR_ANSWER; - head.opcode = DNS_OP_QUERY; - head.aa = 0; - head.rd = 1; - head.ra = 0; - head.rcode = DNS_RC_SERVFAIL; - - dns_packet_init(out_packet, sizeof(out_packet_buff), &head); + if (callback_ret == false || request.response_data_len == 0) { request.response_data_len = - dns_encode(request.response_data, request.response_data_max_len, out_packet); + dns_encode(request.response_data, request.response_data_max_len, request.response_packet); } sendto(fd_, request.response_data, request.response_data_len, MSG_NOSIGNAL, (struct sockaddr *)&from, @@ -255,7 +257,8 @@ bool Server::Start(const std::string &conf, enum CONF_TYPE type) if (fd < 0) { return false; } - Defer { + Defer + { close(fd); }; diff --git a/test/server.h b/test/server.h index 148df42..9ee8d7c 100644 --- a/test/server.h +++ b/test/server.h @@ -60,6 +60,7 @@ struct ServerRequestContext { uint8_t *request_data; int request_data_len; uint8_t *response_data; + struct dns_packet *response_packet; int response_data_max_len; int response_data_len; };