diff --git a/.github/workflows/c-cpp.yml b/.github/workflows/c-cpp.yml index 23aa0bf..25cd580 100644 --- a/.github/workflows/c-cpp.yml +++ b/.github/workflows/c-cpp.yml @@ -13,5 +13,14 @@ jobs: steps: - uses: actions/checkout@v2 + - name: prepare + run: | + sudo apt update + sudo apt install libgtest-dev - name: make - run: make + run: | + make all -j4 + make clean + - name: test + run: | + make -C test test -j8 diff --git a/.gitignore b/.gitignore index 04f4115..60da01c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,7 @@ .vscode *.o +*.pem .DS_Store *.swp. systemd/smartdns.service +test.bin diff --git a/src/dns.h b/src/dns.h index 65173bd..c94dba9 100644 --- a/src/dns.h +++ b/src/dns.h @@ -19,6 +19,10 @@ #ifndef _DNS_HEAD_H #define _DNS_HEAD_H +#ifdef __cplusplus +extern "C" { +#endif /*__cplusplus */ + #define DNS_RR_A_LEN 4 #define DNS_RR_AAAA_LEN 16 #define DNS_MAX_CNAME_LEN 256 @@ -310,4 +314,7 @@ struct dns_update_param { int dns_packet_update(unsigned char *data, int size, struct dns_update_param *param); +#ifdef __cplusplus +} +#endif /*__cplusplus */ #endif diff --git a/src/smartdns.c b/src/smartdns.c index 9f3ba5c..d864f1f 100644 --- a/src/smartdns.c +++ b/src/smartdns.c @@ -646,7 +646,21 @@ static int _smartdns_init_pre(void) return 0; } +#ifdef TEST +#define smartdns_test_notify(retval) smartdns_test_notify_func(fd_notify, retval) +static void smartdns_test_notify_func(int fd_notify, uint64_t retval) { + /* notify parent kickoff */ + if (fd_notify > 0) { + write(fd_notify, &retval, sizeof(retval)); + } +} + +int smartdns_main(int argc, char *argv[], int fd_notify); +int smartdns_main(int argc, char *argv[], int fd_notify) +#else +#define smartdns_test_notify(retval) int main(int argc, char *argv[]) +#endif { int ret = 0; int is_foreground = 0; @@ -732,10 +746,11 @@ int main(int argc, char *argv[]) } atexit(_smartdns_exit); + smartdns_test_notify(1); return _smartdns_run(); errout: - + smartdns_test_notify(2); return 1; } diff --git a/src/util.c b/src/util.c index 035a145..6b0f4f2 100644 --- a/src/util.c +++ b/src/util.c @@ -401,7 +401,7 @@ int check_is_ipaddr(const char *ip) return -1; } -int parse_uri(char *value, char *scheme, char *host, int *port, char *path) +int parse_uri(const char *value, char *scheme, char *host, int *port, char *path) { return parse_uri_ext(value, scheme, NULL, NULL, host, port, path); } @@ -442,16 +442,16 @@ void urldecode(char *dst, const char *src) *dst++ = '\0'; } -int parse_uri_ext(char *value, char *scheme, char *user, char *password, char *host, int *port, char *path) +int parse_uri_ext(const char *value, char *scheme, char *user, char *password, char *host, int *port, char *path) { char *scheme_end = NULL; int field_len = 0; - char *process_ptr = value; + const char *process_ptr = value; char user_pass_host_part[PATH_MAX]; char *user_password = NULL; char *host_part = NULL; - char *host_end = NULL; + const char *host_end = NULL; scheme_end = strstr(value, "://"); if (scheme_end) { diff --git a/src/util.h b/src/util.h index 20663f6..9bff444 100644 --- a/src/util.h +++ b/src/util.h @@ -69,9 +69,9 @@ int parse_ip(const char *value, char *ip, int *port); int check_is_ipaddr(const char *ip); -int parse_uri(char *value, char *scheme, char *host, int *port, char *path); +int parse_uri(const char *value, char *scheme, char *host, int *port, char *path); -int parse_uri_ext(char *value, char *scheme, char *user, char *password, char *host, int *port, char *path); +int parse_uri_ext(const char *value, char *scheme, char *user, char *password, char *host, int *port, char *path); void urldecode(char *dst, const char *src); diff --git a/test/Makefile b/test/Makefile new file mode 100644 index 0000000..8fa1f73 --- /dev/null +++ b/test/Makefile @@ -0,0 +1,46 @@ + +# Copyright (C) 2018-2023 Ruilin Peng (Nick) . +# +# smartdns is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# smartdns is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +BIN=test.bin +CFLAGS += -I../src -I../src/include +CFLAGS += -DTEST +CFLAGS += -g -Wall -Wstrict-prototypes -fno-omit-frame-pointer -Wstrict-aliasing -funwind-tables -Wmissing-prototypes -Wshadow -Wextra -Wno-unused-parameter -Wno-implicit-fallthrough + +CXXFLAGS += -g +CXXFLAGS += -I./ -I../src -I../src/include + +SMARTDNS_OBJS = lib/rbtree.o lib/art.o lib/bitops.o lib/radix.o lib/conf.o lib/nftset.o +SMARTDNS_OBJS += smartdns.o fast_ping.o dns_client.o dns_server.o dns.o util.o tlog.o dns_conf.o dns_cache.o http_parse.o proxy.o +OBJS = $(addprefix ../src/, $(SMARTDNS_OBJS)) + +TEST_SOURCES := $(wildcard *.cc) $(wildcard */*.cc) $(wildcard */*/*.cc) +TEST_OBJECTS := $(patsubst %.cc, %.o, $(TEST_SOURCES)) +OBJS += $(TEST_OBJECTS) + +LDFLAGS += -lssl -lcrypto -lpthread -ldl -lgtest -lstdc++ -lm + +.PHONY: all clean test + +all: $(BIN) + +$(BIN) : $(OBJS) + $(CC) $(OBJS) -o $@ $(LDFLAGS) + +test: $(BIN) + ./$(BIN) + +clean: + $(RM) $(OBJS) $(BIN) diff --git a/test/cases/test-mock-server.cc b/test/cases/test-mock-server.cc new file mode 100644 index 0000000..ccbb438 --- /dev/null +++ b/test/cases/test-mock-server.cc @@ -0,0 +1,17 @@ +#include "client.h" +#include "include/utils.h" +#include "server.h" +#include "gtest/gtest.h" + +TEST(server, mock) +{ + smartdns::MockServer server; + smartdns::Client client; + server.Start("udp://0.0.0.0:7053", [](struct smartdns::ServerRequestContext *request) { + request->response_data_len = 0; + return false; + }); + + ASSERT_TRUE(client.Query("example.com", 7053)); + EXPECT_EQ(client.GetStatus(), "SERVFAIL"); +} diff --git a/test/cases/test-tls-server.cc b/test/cases/test-tls-server.cc new file mode 100644 index 0000000..b42fc34 --- /dev/null +++ b/test/cases/test-tls-server.cc @@ -0,0 +1,34 @@ +#include "client.h" +#include "include/utils.h" +#include "server.h" +#include "gtest/gtest.h" + +TEST(server, TLSServer) +{ + Defer + { + unlink("/tmp/smartdns-cert.pem"); + unlink("/tmp/smartdns-key.pem"); + }; + + smartdns::Server server_wrap; + smartdns::Server server; + + server.Start(R"""(bind [::]:61053 +server-tls 127.0.0.1:60053 -no-check-certificate +log-num 0 +log-console yes +log-level debug +cache-persist no)"""); + server_wrap.Start(R"""(bind-tls [::]:60053 +address /example.com/1.2.3.4 +log-num 0 +log-console yes +log-level debug +cache-persist no)"""); + smartdns::Client client; + ASSERT_TRUE(client.Query("example.com", 61053)); + ASSERT_EQ(client.GetAnswerNum(), 1); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4"); +} diff --git a/test/client.cc b/test/client.cc new file mode 100644 index 0000000..fa6cb3c --- /dev/null +++ b/test/client.cc @@ -0,0 +1,301 @@ +/************************************************************************* + * + * Copyright (C) 2018-2023 Ruilin Peng (Nick) . + * + * smartdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * smartdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#include "client.h" +#include +#include +#include +#include +#include +#include +#include +#include + +namespace smartdns +{ + +std::vector StringSplit(const std::string &s, const char delim) +{ + std::vector ret; + std::string::size_type lastPos = s.find_first_not_of(delim, 0); + std::string::size_type pos = s.find_first_of(delim, lastPos); + while (std::string::npos != pos || std::string::npos != lastPos) { + ret.push_back(s.substr(lastPos, pos - lastPos)); + lastPos = s.find_first_not_of(delim, pos); + pos = s.find_first_of(delim, lastPos); + } + + return ret; +} + +DNSRecord::DNSRecord() {} + +DNSRecord::~DNSRecord() {} + +bool DNSRecord::Parser(const std::string &line) +{ + std::vector fields = StringSplit(line, '\t'); + if (fields.size() < 3) { + std::cerr << "Invalid DNS record: " << line << ", size: " << fields.size() << std::endl; + return false; + } + + if (fields.size() == 3) { + name_ = fields[0]; + if (name_.size() > 1) { + name_.resize(name_.size() - 1); + } + class_ = fields[1]; + type_ = fields[2]; + return true; + } + + name_ = fields[0]; + if (name_.size() > 1) { + name_.resize(name_.size() - 1); + } + ttl_ = std::stoi(fields[1]); + class_ = fields[2]; + type_ = fields[3]; + data_ = fields[4]; + + for (int i = 5; i < fields.size(); i++) { + data_ += " " + fields[i]; + } + + return true; +} + +std::string DNSRecord::GetName() +{ + return name_; +} + +std::string DNSRecord::GetType() +{ + return type_; +} + +std::string DNSRecord::GetClass() +{ + return class_; +} + +int DNSRecord::GetTTL() +{ + return ttl_; +} + +std::string DNSRecord::GetData() +{ + return data_; +} + +Client::Client() {} + +bool Client::Query(const std::string &dig_cmds, int port, const std::string &ip) +{ + std::string cmd = "dig "; + if (port > 0) { + cmd += "-p " + std::to_string(port); + } + + if (ip.length() > 0) { + cmd += " @" + ip; + } else { + cmd += " @127.0.0.1"; + } + + cmd += " " + dig_cmds; + cmd += " +tries=1"; + FILE *fp = NULL; + + fp = popen(cmd.c_str(), "r"); + if (fp == NULL) { + return false; + } + + std::shared_ptr pipe(fp, pclose); + result_.clear(); + char buffer[4096]; + usleep(10000); + while (fgets(buffer, 4096, pipe.get())) { + result_ += buffer; + } + + if (ParserResult() == false) { + Clear(); + } + + return true; +} + +std::vector Client::GetQuery() +{ + return records_query_; +} + +std::vector Client::GetAnswer() +{ + return records_answer_; +} + +std::vector Client::GetAuthority() +{ + return records_authority_; +} + +std::vector Client::GetAdditional() +{ + return records_additional_; +} + +int Client::GetAnswerNum() +{ + return answer_num_; +} + +std::string Client::GetStatus() +{ + return status_; +} + +std::string Client::GetServer() +{ + return server_; +} + +int Client::GetQueryTime() +{ + return query_time_; +} + +int Client::GetMsgSize() +{ + return msg_size_; +} + +std::string Client::GetFlags() +{ + return flags_; +} + +std::string Client::GetResult() +{ + return result_; +} + +void Client::Clear() +{ + result_.clear(); + answer_num_ = 0; + status_.clear(); + server_.clear(); + query_time_ = 0; + msg_size_ = 0; + flags_.clear(); + records_query_.clear(); + records_answer_.clear(); + records_authority_.clear(); + records_additional_.clear(); +} + +void Client::PrintResult() +{ + std::cout << result_ << std::endl; +} + +bool Client::ParserRecord(const std::string &record_str, std::vector &record) +{ + DNSRecord r; + + if (r.Parser(record_str) == false) { + return false; + } + + record.push_back(r); + return true; +} + +bool Client::ParserResult() +{ + std::smatch match; + + std::regex reg_goanswer(";; Got answer:"); + if (std::regex_search(result_, match, reg_goanswer) == false) { + std::cout << "DIG FAILED:\n" << result_ << std::endl; + return false; + } + + std::regex reg_answer_num(", ANSWER: ([0-9]+),"); + if (std::regex_search(result_, match, reg_answer_num)) { + answer_num_ = std::stoi(match[1]); + } + + std::regex reg_status(", status: ([A-Z]+),"); + if (std::regex_search(result_, match, reg_status)) { + status_ = match[1]; + } + + std::regex reg_server(";; SERVER: ([0-9.]+)#"); + if (std::regex_search(result_, match, reg_server)) { + server_ = match[1]; + } + + std::regex reg_querytime(";; Query time: ([0-9]+) msec"); + if (std::regex_search(result_, match, reg_querytime)) { + query_time_ = std::stoi(match[1]); + } + + std::regex reg_msg_size(";; MSG SIZE rcvd: ([0-9]+)"); + if (std::regex_search(result_, match, reg_msg_size)) { + msg_size_ = std::stoi(match[1]); + } + + std::regex reg_flags(";; flags: ([a-z A-Z]+);"); + if (std::regex_search(result_, match, reg_flags)) { + flags_ = match[1]; + } + + std::regex reg_question(";; QUESTION SECTION:\n.(.*\n)+?\n"); + if (std::regex_search(result_, match, reg_question)) { + if (ParserRecord(match[1], records_query_) == false) { + return false; + } + } + + std::regex reg_answer(";; ANSWER SECTION:\n(.*)\n"); + if (std::regex_search(result_, match, reg_answer)) { + if (ParserRecord(match[1], records_answer_) == false) { + return false; + } + } + + std::regex reg_addition(";; ADDITIONAL SECTION:\n(.*)\n"); + if (std::regex_search(result_, match, reg_answer)) { + if (ParserRecord(match[1], records_additional_) == false) { + return false; + } + } + + return true; +} + +Client::~Client() {} + +} // namespace smartdns \ No newline at end of file diff --git a/test/client.h b/test/client.h new file mode 100644 index 0000000..797ba03 --- /dev/null +++ b/test/client.h @@ -0,0 +1,106 @@ +/************************************************************************* + * + * Copyright (C) 2018-2023 Ruilin Peng (Nick) . + * + * smartdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * smartdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#ifndef _SMARTDNS_CLIENT_ +#define _SMARTDNS_CLIENT_ + +#include +#include +#include + +namespace smartdns +{ + +class DNSRecord +{ + public: + DNSRecord(); + virtual ~DNSRecord(); + + bool Parser(const std::string &line); + + std::string GetName(); + + std::string GetType(); + + std::string GetClass(); + + int GetTTL(); + + std::string GetData(); + + private: + std::string name_; + std::string type_; + std::string class_; + int ttl_; + std::string data_; +}; + +class Client +{ + public: + Client(); + virtual ~Client(); + bool Query(const std::string &dig_cmds, int port = 0, const std::string &ip = ""); + + std::string GetResult(); + + std::vector GetQuery(); + + std::vector GetAnswer(); + + std::vector GetAuthority(); + + std::vector GetAdditional(); + + int GetAnswerNum(); + + std::string GetStatus(); + + std::string GetServer(); + + int GetQueryTime(); + + int GetMsgSize(); + + std::string GetFlags(); + + void Clear(); + + void PrintResult(); + + private: + bool ParserResult(); + bool ParserRecord(const std::string &record_str, std::vector &record); + std::string result_; + int answer_num_{0}; + std::string status_; + std::string server_; + int query_time_{0}; + int msg_size_{0}; + std::string flags_; + + std::vector records_query_; + std::vector records_answer_; + std::vector records_authority_; + std::vector records_additional_; +}; + +} // namespace smartdns +#endif // _SMARTDNS_CLIENT_ \ No newline at end of file diff --git a/test/include/utils.h b/test/include/utils.h new file mode 100644 index 0000000..a69c421 --- /dev/null +++ b/test/include/utils.h @@ -0,0 +1,59 @@ +/************************************************************************* + * + * Copyright (C) 2018-2023 Ruilin Peng (Nick) . + * + * smartdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * smartdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#ifndef _SMARTDNS_TEST_UTILS_ +#define _SMARTDNS_TEST_UTILS_ + +#include + +namespace smartdns +{ + +class DeferGuard +{ + public: + template + + DeferGuard(Callable &&fn) noexcept : fn_(std::forward(fn)) + { + } + DeferGuard(DeferGuard &&other) noexcept + { + fn_ = std::move(other.fn_); + other.fn_ = nullptr; + } + + virtual ~DeferGuard() + { + if (fn_) { + fn_(); + } + }; + DeferGuard(const DeferGuard &) = delete; + void operator=(const DeferGuard &) = delete; + + private: + std::function fn_; +}; + +#define SMARTDNS_CONCAT_(a, b) a##b +#define SMARTDNS_CONCAT(a, b) SMARTDNS_CONCAT_(a, b) +#define Defer ::smartdns::DeferGuard SMARTDNS_CONCAT(__defer__, __LINE__) = [&]() + +} // namespace smartdns +#endif // _SMARTDNS_TEST_UTILS_ diff --git a/test/server.cc b/test/server.cc new file mode 100644 index 0000000..7ac8efb --- /dev/null +++ b/test/server.cc @@ -0,0 +1,355 @@ +/************************************************************************* + * + * Copyright (C) 2018-2023 Ruilin Peng (Nick) . + * + * smartdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * smartdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#include "server.h" +#include "include/utils.h" +#include "util.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace smartdns +{ + +extern "C" int smartdns_main(int argc, char *argv[], int fd_notify); + +MockServer::MockServer() {} + +MockServer::~MockServer() +{ + Stop(); +} + +bool MockServer::IsRunning() +{ + if (fd_ > 0) { + return true; + } + + return false; +} + +void MockServer::Stop() +{ + if (run_ == true) { + run_ = false; + if (thread_.joinable()) { + thread_.join(); + } + } + + if (fd_ > 0) { + close(fd_); + fd_; + } +} + +void MockServer::Run() +{ + while (run_ == true) { + struct pollfd fds[1]; + fds[0].fd = fd_; + fds[0].events = POLLIN; + fds[0].revents = 0; + int ret = poll(fds, 1, 100); + if (ret == 0) { + continue; + } else if (ret < 0) { + sleep(1); + continue; + } + + if (fds[0].revents & POLLIN) { + struct sockaddr_storage from; + socklen_t addrlen = sizeof(from); + unsigned char in_buff[4096]; + int len = recvfrom(fd_, in_buff, sizeof(in_buff), 0, (struct sockaddr *)&from, &addrlen); + if (len < 0) { + continue; + } + + char packet_buff[4096]; + unsigned char out_buff[4096]; + memset(packet_buff, 0, sizeof(packet_buff)); + struct dns_packet *packet = (struct dns_packet *)packet_buff; + struct ServerRequestContext request; + memset(&request, 0, sizeof(request)); + + int ret = dns_decode(packet, sizeof(packet_buff), in_buff, len); + if (ret == 0) { + request.packet = packet; + if (packet->head.qr == DNS_QR_QUERY) { + struct dns_rrs *rrs = NULL; + int rr_count = 0; + int qtype = 0; + int qclass = 0; + char domain[256]; + + rrs = dns_get_rrs_start(packet, DNS_RRS_QD, &rr_count); + for (int i = 0; i < rr_count && rrs; i++, rrs = dns_get_rrs_next(packet, rrs)) { + ret = dns_get_domain(rrs, domain, sizeof(domain), &qtype, &qclass); + if (ret == 0) { + request.domain = domain; + request.qtype = (dns_type)qtype; + request.qclass = qclass; + break; + } + } + } + } + + request.from = (struct sockaddr_storage *)&from; + request.fromlen = addrlen; + request.request_data = in_buff; + request.request_data_len = len; + request.response_data = out_buff; + request.response_data_len = 0; + request.response_data_max_len = sizeof(out_buff); + + auto callback_ret = callback_(&request); + if (callback_ret == false) { + unsigned char out_packet_buff[4096]; + struct dns_packet *out_packet = (struct dns_packet *)out_packet_buff; + struct dns_head head; + memset(&head, 0, sizeof(head)); + head.id = packet->head.id; + head.qr = DNS_QR_ANSWER; + head.opcode = DNS_OP_QUERY; + head.aa = 0; + head.rd = 1; + head.ra = 0; + head.rcode = DNS_RC_SERVFAIL; + + dns_packet_init(out_packet, sizeof(out_packet_buff), &head); + request.response_data_len = + dns_encode(request.response_data, request.response_data_max_len, out_packet); + } + + sendto(fd_, request.response_data, request.response_data_len, MSG_NOSIGNAL, (struct sockaddr *)&from, + addrlen); + } + } +} + +bool MockServer::GetAddr(const std::string &host, const std::string port, int type, int protocol, + struct sockaddr_storage *addr, socklen_t *addrlen) + +{ + struct addrinfo hints; + struct addrinfo *result = NULL; + + memset(&hints, 0, sizeof(hints)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = type; + hints.ai_protocol = protocol; + hints.ai_flags = AI_PASSIVE; + if (getaddrinfo(host.c_str(), port.c_str(), &hints, &result) != 0) { + goto errout; + } + + memcpy(addr, result->ai_addr, result->ai_addrlen); + *addrlen = result->ai_addrlen; + return true; +errout: + if (result) { + freeaddrinfo(result); + } + return NULL; +} + +bool MockServer::Start(const std::string &url, ServerRequest callback) +{ + char c_scheme[256]; + char c_host[256]; + int port; + char c_path[256]; + int fd; + struct sockaddr_storage addr; + socklen_t addrlen; + + if (callback == nullptr) { + return false; + } + + if (parse_uri(url.c_str(), c_scheme, c_host, &port, c_path) != 0) { + return false; + } + + std::string scheme(c_scheme); + std::string host(c_host); + std::string path(c_path); + + if (scheme != "udp") { + return false; + } + + if (GetAddr(host, std::to_string(port), SOCK_DGRAM, IPPROTO_UDP, &addr, &addrlen) == false) { + return false; + } + + fd = socket(addr.ss_family, SOCK_DGRAM | SOCK_CLOEXEC, 0); + if (fd < 0) { + return false; + } + + if (bind(fd, (struct sockaddr *)&addr, addrlen) != 0) { + close(fd); + return false; + } + + run_ = true; + thread_ = std::thread(&MockServer::Run, this); + fd_ = fd; + callback_ = callback; + return true; +} + +Server::Server() {} + +bool Server::Start(const std::string &conf, enum CONF_TYPE type) +{ + int fds[2]; + std::string conf_file; + + fds[0] = 0; + fds[1] = 0; + Defer + { + if (fds[0] > 0) { + close(fds[0]); + } + + if (fds[0] > 0) { + close(fds[1]); + } + }; + + if (type == CONF_TYPE_STRING) { + char filename[128]; + strncpy(filename, "/tmp/smartdns_conf.XXXXXX", sizeof(filename)); + int fd = mkstemp(filename); + if (fd < 0) { + return false; + } + Defer { + close(fd); + }; + + std::ofstream ofs(filename); + if (ofs.is_open() == false) { + return false; + } + ofs.write(conf.data(), conf.size()); + ofs.flush(); + ofs.close(); + conf_file = filename; + clean_conf_file_ = true; + } else if (type == CONF_TYPE_FILE) { + conf_file = conf; + } else { + return false; + } + + if (access(conf_file.c_str(), F_OK) != 0) { + return false; + } + + conf_file_ = conf_file; + + if (pipe2(fds, O_CLOEXEC | O_NONBLOCK) != 0) { + return false; + } + + pid_t pid = fork(); + if (pid == 0) { + std::vector args = { + "smartdns", "-f", "-x", "-c", conf_file, "-p", "-", + }; + char *argv[args.size() + 1]; + for (size_t i = 0; i < args.size(); i++) { + argv[i] = (char *)args[i].c_str(); + } + + smartdns_main(args.size(), argv, fds[1]); + _exit(1); + } else if (pid < 0) { + return false; + } + + struct pollfd pfd[1]; + pfd[0].fd = fds[0]; + pfd[0].events = POLLIN; + + int ret = poll(pfd, 1, 10000); + if (ret == 0) { + kill(pid, SIGKILL); + return false; + } + + pid_ = pid; + return pid > 0; +} + +void Server::Stop(bool graceful) +{ + if (pid_ > 0) { + if (graceful) { + kill(pid_, SIGTERM); + } else { + kill(pid_, SIGKILL); + } + } + + waitpid(pid_, NULL, 0); + + pid_ = 0; + if (clean_conf_file_ == true) { + unlink(conf_file_.c_str()); + conf_file_.clear(); + clean_conf_file_ = false; + } +} + +bool Server::IsRunning() +{ + if (pid_ <= 0) { + return false; + } + + if (waitpid(pid_, NULL, WNOHANG) == 0) { + return true; + } + + return kill(pid_, 0) == 0; +} + +Server::~Server() +{ + Stop(false); +} + +} // namespace smartdns \ No newline at end of file diff --git a/test/server.h b/test/server.h new file mode 100644 index 0000000..148df42 --- /dev/null +++ b/test/server.h @@ -0,0 +1,91 @@ +/************************************************************************* + * + * Copyright (C) 2018-2023 Ruilin Peng (Nick) . + * + * smartdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * smartdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#ifndef _SMARTDNS_SERVER_ +#define _SMARTDNS_SERVER_ + +#include "dns.h" +#include +#include +#include +#include +#include + +namespace smartdns +{ + +class Server +{ + public: + enum CONF_TYPE { + CONF_TYPE_STRING, + CONF_TYPE_FILE, + }; + Server(); + virtual ~Server(); + + bool Start(const std::string &conf, enum CONF_TYPE type = CONF_TYPE_STRING); + void Stop(bool graceful = true); + bool IsRunning(); + + private: + pid_t pid_; + int fd_; + std::string conf_file_; + bool clean_conf_file_{false}; +}; + +struct ServerRequestContext { + std::string domain; + dns_type qtype; + int qclass; + struct sockaddr_storage *from; + socklen_t fromlen; + struct dns_packet *packet; + uint8_t *request_data; + int request_data_len; + uint8_t *response_data; + int response_data_max_len; + int response_data_len; +}; + +using ServerRequest = std::function; + +class MockServer +{ + public: + MockServer(); + virtual ~MockServer(); + + bool Start(const std::string &url, ServerRequest callback); + void Stop(); + bool IsRunning(); + + private: + void Run(); + + bool GetAddr(const std::string &host, const std::string port, int type, int protocol, struct sockaddr_storage *addr, + socklen_t *addrlen); + int fd_; + std::thread thread_; + bool run_; + ServerRequest callback_; +}; + +} // namespace smartdns +#endif // _SMARTDNS_SERVER_ \ No newline at end of file diff --git a/test/test.cc b/test/test.cc new file mode 100644 index 0000000..0ca8008 --- /dev/null +++ b/test/test.cc @@ -0,0 +1,25 @@ +/************************************************************************* + * + * Copyright (C) 2018-2023 Ruilin Peng (Nick) . + * + * smartdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * smartdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#include "gtest/gtest.h" + +int main(int argc, char **argv) +{ + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} \ No newline at end of file