From dd23c5fc3165a8f1cd582c3e1387c98320592bfb Mon Sep 17 00:00:00 2001 From: Nick Peng Date: Fri, 24 Mar 2023 21:58:41 +0800 Subject: [PATCH] config: add option -no-cache and -response-mode for domain-rules and add some test cases. --- src/dns_conf.c | 65 +++++++++- src/dns_conf.h | 18 ++- src/dns_server.c | 33 +++-- test/cases/test-address.cc | 165 ++++++++++++++++++++++++ test/cases/test-bind.cc | 43 ++++++- test/cases/test-cache.cc | 59 ++++++++- test/cases/test-cname.cc | 19 +-- test/cases/test-dns64.cc | 65 ++++++++++ test/cases/test-domain-set.cc | 88 +++++++++++++ test/cases/test-speed-check.cc | 191 ++++++++++++++++++++++++++++ test/client.cc | 18 +++ test/client.h | 1 + test/include/utils.h | 47 +++++++ test/server.cc | 31 +---- test/server.h | 4 +- test/utils.cc | 223 ++++++++++++++++++++++++++++++++- 16 files changed, 1004 insertions(+), 66 deletions(-) create mode 100644 test/cases/test-address.cc create mode 100644 test/cases/test-dns64.cc create mode 100644 test/cases/test-domain-set.cc create mode 100644 test/cases/test-speed-check.cc diff --git a/src/dns_conf.c b/src/dns_conf.c index 022c303..60920ad 100644 --- a/src/dns_conf.c +++ b/src/dns_conf.c @@ -202,6 +202,9 @@ static void *_new_dns_rule(enum domain_rule domain_rule) case DOMAIN_RULE_CHECKSPEED: size = sizeof(struct dns_domain_check_orders); break; + case DOMAIN_RULE_RESPONSE_MODE: + size = sizeof(struct dns_response_mode_rule); + break; case DOMAIN_RULE_CNAME: size = sizeof(struct dns_cname_rule); break; @@ -2388,6 +2391,38 @@ errout: return 0; } +static int _conf_domain_rule_response_mode(char *domain, const char *mode) +{ + enum response_mode_type response_mode_type = DNS_RESPONSE_MODE_FIRST_PING_IP; + struct dns_response_mode_rule *response_mode = NULL; + + for (int i = 0; dns_conf_response_mode_enum[i].name != NULL; i++) { + if (strcmp(mode, dns_conf_response_mode_enum[i].name) == 0) { + response_mode_type = dns_conf_response_mode_enum[i].id; + break; + } + } + + response_mode = _new_dns_rule(DOMAIN_RULE_RESPONSE_MODE); + if (response_mode == NULL) { + goto errout; + } + response_mode->mode = response_mode_type; + + if (_config_domain_rule_add(domain, DOMAIN_RULE_RESPONSE_MODE, response_mode) != 0) { + goto errout; + } + + _dns_rule_put(&response_mode->head); + return 0; +errout: + if (response_mode) { + _dns_rule_put(&response_mode->head); + } + + return 0; +} + static int _conf_domain_set(void *data, int argc, char *argv[]) { int opt = 0; @@ -2527,6 +2562,11 @@ static int _conf_domain_rule_delete(const char *domain) return _config_domain_rule_delete(domain); } +static int _conf_domain_rule_no_cache(const char *domain) +{ + return _config_domain_rule_flag_set(domain, DOMAIN_FLAG_NO_CACHE, 0); +} + static int _conf_domain_rules(void *data, int argc, char *argv[]) { int opt = 0; @@ -2539,6 +2579,7 @@ static int _conf_domain_rules(void *data, int argc, char *argv[]) /* clang-format off */ static struct option long_options[] = { {"speed-check-mode", required_argument, NULL, 'c'}, + {"response-mode", required_argument, NULL, 'r'}, {"address", required_argument, NULL, 'a'}, {"ipset", required_argument, NULL, 'p'}, {"nftset", required_argument, NULL, 't'}, @@ -2550,6 +2591,7 @@ static int _conf_domain_rules(void *data, int argc, char *argv[]) {"rr-ttl-max", required_argument, NULL, 253}, {"no-serve-expired", no_argument, NULL, 254}, {"delete", no_argument, NULL, 255}, + {"no-cache", no_argument, NULL, 256}, {NULL, no_argument, NULL, 0} }; /* clang-format on */ @@ -2566,7 +2608,7 @@ static int _conf_domain_rules(void *data, int argc, char *argv[]) /* process extra options */ optind = 1; while (1) { - opt = getopt_long_only(argc, argv, "c:a:p:t:n:d:A:", long_options, NULL); + opt = getopt_long_only(argc, argv, "c:a:p:t:n:d:A:r:", long_options, NULL); if (opt == -1) { break; } @@ -2585,6 +2627,19 @@ static int _conf_domain_rules(void *data, int argc, char *argv[]) break; } + case 'r': { + const char *response_mode = optarg; + if (response_mode == NULL) { + goto errout; + } + + if (_conf_domain_rule_response_mode(domain, response_mode) != 0) { + tlog(TLOG_ERROR, "add response-mode rule failed."); + goto errout; + } + + break; + } case 'a': { const char *address = optarg; if (address == NULL) { @@ -2684,6 +2739,14 @@ static int _conf_domain_rules(void *data, int argc, char *argv[]) return 0; } + case 256: { + if (_conf_domain_rule_no_cache(domain) != 0) { + tlog(TLOG_ERROR, "set no-cache rule failed."); + goto errout; + } + + break; + } default: break; } diff --git a/src/dns_conf.h b/src/dns_conf.h index fd4fb2e..3abb71b 100644 --- a/src/dns_conf.h +++ b/src/dns_conf.h @@ -74,6 +74,7 @@ enum domain_rule { DOMAIN_RULE_NFTSET_IP6, DOMAIN_RULE_NAMESERVER, DOMAIN_RULE_CHECKSPEED, + DOMAIN_RULE_RESPONSE_MODE, DOMAIN_RULE_CNAME, DOMAIN_RULE_TTL, DOMAIN_RULE_MAX, @@ -110,6 +111,7 @@ typedef enum { #define DOMAIN_FLAG_NFTSET_IP6_IGN (1 << 14) #define DOMAIN_FLAG_NO_SERVE_EXPIRED (1 << 15) #define DOMAIN_FLAG_CNAME_IGN (1 << 16) +#define DOMAIN_FLAG_NO_CACHE (1 << 17) #define SERVER_FLAG_EXCLUDE_DEFAULT (1 << 0) @@ -124,6 +126,12 @@ typedef enum { #define BIND_FLAG_FORCE_AAAA_SOA (1 << 8) #define BIND_FLAG_NO_RULE_CNAME (1 << 9) +enum response_mode_type { + DNS_RESPONSE_MODE_FIRST_PING_IP = 0, + DNS_RESPONSE_MODE_FASTEST_IP, + DNS_RESPONSE_MODE_FASTEST_RESPONSE, +}; + struct dns_rule { atomic_t refcnt; enum domain_rule rule; @@ -227,6 +235,11 @@ struct dns_domain_check_orders { struct dns_domain_check_order orders[DOMAIN_CHECK_NUM]; }; +struct dns_response_mode_rule { + struct dns_rule head; + enum response_mode_type mode; +}; + struct dns_group_table { DECLARE_HASHTABLE(group, 8); }; @@ -464,11 +477,6 @@ extern int dns_conf_dualstack_ip_allow_force_AAAA; extern int dns_conf_dualstack_ip_selection_threshold; extern int dns_conf_max_reply_ip_num; -enum response_mode_type { - DNS_RESPONSE_MODE_FIRST_PING_IP = 0, - DNS_RESPONSE_MODE_FASTEST_IP, - DNS_RESPONSE_MODE_FASTEST_RESPONSE, -}; extern enum response_mode_type dns_conf_response_mode; extern int dns_conf_rr_ttl; diff --git a/src/dns_server.c b/src/dns_server.c index fd002fd..d99c38f 100644 --- a/src/dns_server.c +++ b/src/dns_server.c @@ -304,10 +304,13 @@ struct dns_request { struct dns_domain_check_orders *check_order_list; int check_order; + enum response_mode_type response_mode; + struct dns_request_pending_list *request_pending_list; int no_select_possible_ip; int no_cache_cname; + int no_cache; }; /* dns server data */ @@ -1269,7 +1272,7 @@ static int _dns_cache_cname_packet(struct dns_server_post_context *context) struct dns_request *request = context->request; - if (request->has_cname == 0 || request->no_cache_cname == 1) { + if (request->has_cname == 0 || request->no_cache_cname == 1 || request->no_cache == 1) { return 0; } @@ -1514,7 +1517,7 @@ static int _dns_cache_reply_packet(struct dns_server_post_context *context) { struct dns_request *request = context->request; int has_soa = request->has_soa; - if (context->do_cache == 0 || _dns_server_has_bind_flag(request, BIND_FLAG_NO_CACHE) == 0) { + if (context->do_cache == 0 || request->no_cache == 1) { return 0; } @@ -2433,6 +2436,7 @@ static struct dns_request *_dns_server_new_request(void) request->qclass = DNS_C_IN; request->result_callback = NULL; request->check_order_list = &dns_conf_check_orders; + request->response_mode = dns_conf_response_mode; INIT_LIST_HEAD(&request->list); INIT_LIST_HEAD(&request->pending_list); INIT_LIST_HEAD(&request->check_list); @@ -2585,7 +2589,7 @@ out: } /* Get first ping result */ - if (dns_conf_response_mode == DNS_RESPONSE_MODE_FIRST_PING_IP && last_rtt == -1 && request->ping_time > 0) { + if (request->response_mode == DNS_RESPONSE_MODE_FIRST_PING_IP && last_rtt == -1 && request->ping_time > 0) { may_complete = 1; } @@ -3422,7 +3426,7 @@ static int dns_server_resolve_callback(const char *domain, dns_result_type rtype return _dns_server_reply_passthrough(&context); } - if (request->prefetch == 0 && dns_conf_response_mode == DNS_RESPONSE_MODE_FASTEST_RESPONSE && + if (request->prefetch == 0 && request->response_mode == DNS_RESPONSE_MODE_FASTEST_RESPONSE && atomic_read(&request->notified) == 0) { struct dns_server_post_context context; int ttl = 0; @@ -3936,6 +3940,10 @@ static int _dns_server_pre_process_rule_flags(struct dns_request *request) request->no_serve_expired = 1; } + if ((flags & DOMAIN_FLAG_NO_CACHE) || (_dns_server_has_bind_flag(request, BIND_FLAG_NO_CACHE) == 0)) { + request->no_cache = 1; + } + if (flags & DOMAIN_FLAG_ADDR_IGN) { /* ignore this domain */ goto out; @@ -4432,17 +4440,22 @@ static int _dns_server_qtype_soa(struct dns_request *request) return -1; } -static void _dns_server_process_speed_check_rule(struct dns_request *request) +static void _dns_server_process_speed_rule(struct dns_request *request) { struct dns_domain_check_orders *check_order = NULL; + struct dns_response_mode_rule *response_mode = NULL; - /* get domain rule flag */ + /* get speed check mode */ check_order = _dns_server_get_dns_rule(request, DOMAIN_RULE_CHECKSPEED); - if (check_order == NULL) { - return; + if (check_order != NULL) { + request->check_order_list = check_order; } - request->check_order_list = check_order; + /* get response mode */ + response_mode = _dns_server_get_dns_rule(request, DOMAIN_RULE_RESPONSE_MODE); + if (response_mode != NULL) { + request->response_mode = response_mode->mode; + } } static int _dns_server_get_expired_ttl_reply(struct dns_cache *dns_cache) @@ -5073,7 +5086,7 @@ static int _dns_server_do_query(struct dns_request *request, int skip_notify_eve } /* process speed check rule */ - _dns_server_process_speed_check_rule(request); + _dns_server_process_speed_rule(request); /* check and set passthrough */ _dns_server_check_set_passthrough(request); diff --git a/test/cases/test-address.cc b/test/cases/test-address.cc new file mode 100644 index 0000000..733f6e6 --- /dev/null +++ b/test/cases/test-address.cc @@ -0,0 +1,165 @@ +/************************************************************************* + * + * Copyright (C) 2018-2023 Ruilin Peng (Nick) . + * + * smartdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * smartdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#include "client.h" +#include "dns.h" +#include "include/utils.h" +#include "server.h" +#include "util.h" +#include "gtest/gtest.h" +#include + +class Address : public ::testing::Test +{ + protected: + virtual void SetUp() {} + virtual void TearDown() {} +}; + +TEST_F(Address, soa) +{ + smartdns::MockServer server_upstream; + smartdns::Server server; + std::map qid_map; + + server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) { + if (request->qtype == DNS_T_A) { + smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4", 700); + return smartdns::SERVER_REQUEST_OK; + } else if (request->qtype == DNS_T_AAAA) { + smartdns::MockServer::AddIP(request, request->domain.c_str(), "64:ff9b::102:304", 700); + return smartdns::SERVER_REQUEST_OK; + } + return smartdns::SERVER_REQUEST_SOA; + }); + + server.Start(R"""(bind [::]:60053 +server 127.0.0.1:61053 +log-num 0 +log-console yes +log-level debug +speed-check-mode none +address /a.com/#4 +address /b.com/#6 +address /c.com/# +cache-persist no)"""); + smartdns::Client client; + ASSERT_TRUE(client.Query("a.com A", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 0); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_EQ(client.GetAuthority()[0].GetName(), "a.com"); + EXPECT_EQ(client.GetAuthority()[0].GetTTL(), 30); + EXPECT_EQ(client.GetAuthority()[0].GetType(), "SOA"); + EXPECT_EQ(client.GetAuthority()[0].GetData(), + "a.gtld-servers.net. nstld.verisign-grs.com. 1800 1800 900 604800 86400"); + + ASSERT_TRUE(client.Query("a.com AAAA", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 1); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com"); + EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 700); + EXPECT_EQ(client.GetAnswer()[0].GetType(), "AAAA"); + EXPECT_EQ(client.GetAnswer()[0].GetData(), "64:ff9b::102:304"); + + ASSERT_TRUE(client.Query("b.com A", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 1); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_EQ(client.GetAnswer()[0].GetName(), "b.com"); + EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 700); + EXPECT_EQ(client.GetAnswer()[0].GetType(), "A"); + EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4"); + + ASSERT_TRUE(client.Query("b.com AAAA", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 0); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_EQ(client.GetAuthority()[0].GetName(), "b.com"); + EXPECT_EQ(client.GetAuthority()[0].GetTTL(), 30); + EXPECT_EQ(client.GetAuthority()[0].GetType(), "SOA"); + EXPECT_EQ(client.GetAuthority()[0].GetData(), + "a.gtld-servers.net. nstld.verisign-grs.com. 1800 1800 900 604800 86400"); + + ASSERT_TRUE(client.Query("c.com A", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 0); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_EQ(client.GetAuthority()[0].GetName(), "c.com"); + EXPECT_EQ(client.GetAuthority()[0].GetTTL(), 30); + EXPECT_EQ(client.GetAuthority()[0].GetType(), "SOA"); + EXPECT_EQ(client.GetAuthority()[0].GetData(), + "a.gtld-servers.net. nstld.verisign-grs.com. 1800 1800 900 604800 86400"); + + ASSERT_TRUE(client.Query("c.com AAAA", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 0); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_EQ(client.GetAuthority()[0].GetName(), "c.com"); + EXPECT_EQ(client.GetAuthority()[0].GetTTL(), 30); + EXPECT_EQ(client.GetAuthority()[0].GetType(), "SOA"); + EXPECT_EQ(client.GetAuthority()[0].GetData(), + "a.gtld-servers.net. nstld.verisign-grs.com. 1800 1800 900 604800 86400"); +} + +TEST_F(Address, ip) +{ + smartdns::MockServer server_upstream; + smartdns::Server server; + std::map qid_map; + + server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) { + if (request->qtype == DNS_T_A) { + smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4", 700); + return smartdns::SERVER_REQUEST_OK; + } else if (request->qtype == DNS_T_AAAA) { + smartdns::MockServer::AddIP(request, request->domain.c_str(), "64:ff9b::102:304", 700); + return smartdns::SERVER_REQUEST_OK; + } + return smartdns::SERVER_REQUEST_SOA; + }); + + server.Start(R"""(bind [::]:60053 +server 127.0.0.1:61053 +log-num 0 +log-console yes +log-level debug +speed-check-mode none +address /a.com/10.10.10.10 +address /a.com/64:ff9b::1010:1010 +cache-persist no)"""); + smartdns::Client client; + ASSERT_TRUE(client.Query("a.com A", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 1); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com"); + EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 600); + EXPECT_EQ(client.GetAnswer()[0].GetType(), "A"); + EXPECT_EQ(client.GetAnswer()[0].GetData(), "10.10.10.10"); + + ASSERT_TRUE(client.Query("a.com AAAA", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 1); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com"); + EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 600); + EXPECT_EQ(client.GetAnswer()[0].GetType(), "AAAA"); + EXPECT_EQ(client.GetAnswer()[0].GetData(), "64:ff9b::1010:1010"); +} diff --git a/test/cases/test-bind.cc b/test/cases/test-bind.cc index a9cc0ea..a5b35fd 100644 --- a/test/cases/test-bind.cc +++ b/test/cases/test-bind.cc @@ -84,7 +84,7 @@ cache-persist no)"""); std::cout << client.GetResult() << std::endl; ASSERT_EQ(client.GetAnswerNum(), 1); EXPECT_EQ(client.GetStatus(), "NOERROR"); - EXPECT_GT(client.GetAnswer()[0].GetTTL(), 609); + EXPECT_GE(client.GetAnswer()[0].GetTTL(), 609); EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4"); } @@ -119,4 +119,45 @@ cache-persist no)"""); EXPECT_LT(client.GetQueryTime(), 100); EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3); EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4"); +} + +TEST(Bind, nocache) +{ + smartdns::MockServer server_upstream; + smartdns::MockServer server_upstream2; + smartdns::Server server; + + server_upstream.Start("udp://0.0.0.0:61053", [](struct smartdns::ServerRequestContext *request) { + unsigned char addr[4] = {1, 2, 3, 4}; + usleep(15 * 1000); + dns_add_A(request->response_packet, DNS_RRS_AN, request->domain.c_str(), 611, addr); + request->response_packet->head.rcode = DNS_RC_NOERROR; + return smartdns::SERVER_REQUEST_OK; + }); + + server.Start(R"""( +bind [::]:60053 --no-cache +bind-tcp [::]:60053 +server 127.0.0.1:61053 +log-num 0 +log-console yes +log-level debug +cache-persist no)"""); + smartdns::Client client; + ASSERT_TRUE(client.Query("a.com", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 1); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com"); + EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3); + EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4"); + + ASSERT_TRUE(client.Query("a.com", 60053)); + EXPECT_GT(client.GetQueryTime(), 10); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 1); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com"); + EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3); + EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4"); } \ No newline at end of file diff --git a/test/cases/test-cache.cc b/test/cases/test-cache.cc index 5d13e59..a70cde6 100644 --- a/test/cases/test-cache.cc +++ b/test/cases/test-cache.cc @@ -151,7 +151,7 @@ server 127.0.0.1:61053 log-num 0 cache-size 1 rr-ttl-min 600 -rr-ttl-reply-max 5 +rr-ttl-reply-max 6 log-console yes log-level debug cache-persist no)"""); @@ -169,7 +169,62 @@ cache-persist no)"""); ASSERT_EQ(client.GetAnswerNum(), 1); EXPECT_EQ(client.GetStatus(), "NOERROR"); EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com"); - EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 5); + EXPECT_GE(client.GetAnswer()[0].GetTTL(), 5); EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4"); } +TEST_F(Cache, nocache) +{ + smartdns::MockServer server_upstream; + smartdns::Server server; + + server_upstream.Start("udp://0.0.0.0:61053", [](struct smartdns::ServerRequestContext *request) { + std::string domain = request->domain; + if (request->domain.length() == 0) { + return smartdns::SERVER_REQUEST_ERROR; + } + + usleep(15000); + + if (request->qtype == DNS_T_A) { + unsigned char addr[4] = {1, 2, 3, 4}; + dns_add_A(request->response_packet, DNS_RRS_AN, domain.c_str(), 0, addr); + } else if (request->qtype == DNS_T_AAAA) { + unsigned char addr[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + dns_add_AAAA(request->response_packet, DNS_RRS_AN, domain.c_str(), 0, addr); + } else { + return smartdns::SERVER_REQUEST_ERROR; + } + + request->response_packet->head.rcode = DNS_RC_NOERROR; + return smartdns::SERVER_REQUEST_OK; + }); + + server.Start(R"""(bind [::]:60053 +server 127.0.0.1:61053 +log-num 0 +cache-size 100 +rr-ttl-min 600 +rr-ttl-reply-max 5 +log-console yes +log-level debug +domain-rules /a.com/ --no-cache +cache-persist no)"""); + smartdns::Client client; + ASSERT_TRUE(client.Query("a.com", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 1); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com"); + EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3); + EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4"); + + ASSERT_TRUE(client.Query("a.com", 60053)); + EXPECT_GT(client.GetQueryTime(), 10); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 1); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com"); + EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3); + EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4"); +} diff --git a/test/cases/test-cname.cc b/test/cases/test-cname.cc index b6df617..b7aadda 100644 --- a/test/cases/test-cname.cc +++ b/test/cases/test-cname.cc @@ -28,22 +28,13 @@ TEST(server, cname) smartdns::Server server; server_upstream.Start("udp://0.0.0.0:61053", [](struct smartdns::ServerRequestContext *request) { - std::string domain = request->domain; - if (request->domain.length() == 0) { - return smartdns::SERVER_REQUEST_ERROR; + if (request->qtype != DNS_T_A) { + return smartdns::SERVER_REQUEST_SOA; } - if (request->qtype == DNS_T_A) { - unsigned char addr[4] = {1, 2, 3, 4}; - dns_add_A(request->response_packet, DNS_RRS_AN, domain.c_str(), 61, addr); - } else if (request->qtype == DNS_T_AAAA) { - unsigned char addr[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; - dns_add_AAAA(request->response_packet, DNS_RRS_AN, domain.c_str(), 61, addr); - } else { - return smartdns::SERVER_REQUEST_ERROR; - } - - EXPECT_EQ(domain, "e.com"); + unsigned char addr[4] = {1, 2, 3, 4}; + dns_add_A(request->response_packet, DNS_RRS_AN, request->domain.c_str(), 611, addr); + EXPECT_EQ(request->domain, "e.com"); request->response_packet->head.rcode = DNS_RC_NOERROR; return smartdns::SERVER_REQUEST_OK; diff --git a/test/cases/test-dns64.cc b/test/cases/test-dns64.cc new file mode 100644 index 0000000..8649281 --- /dev/null +++ b/test/cases/test-dns64.cc @@ -0,0 +1,65 @@ +/************************************************************************* + * + * Copyright (C) 2018-2023 Ruilin Peng (Nick) . + * + * smartdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * smartdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#include "client.h" +#include "dns.h" +#include "include/utils.h" +#include "server.h" +#include "util.h" +#include "gtest/gtest.h" +#include + +class DNS64 : public ::testing::Test +{ + protected: + virtual void SetUp() {} + virtual void TearDown() {} +}; + +TEST_F(DNS64, no_dualstack) +{ + smartdns::MockServer server_upstream; + smartdns::Server server; + std::map qid_map; + + server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) { + if (request->qtype == DNS_T_A) { + smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4"); + return smartdns::SERVER_REQUEST_OK; + } + return smartdns::SERVER_REQUEST_SOA; + }); + + server.Start(R"""(bind [::]:60053 +server 127.0.0.1:61053 +log-num 0 +dns64 64:ff9b::/96 +log-console yes +dualstack-ip-selection no +log-level debug +cache-persist no)"""); + smartdns::Client client; + ASSERT_TRUE(client.Query("a.com AAAA", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 1); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com"); + EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3); + EXPECT_EQ(client.GetAnswer()[0].GetType(), "AAAA"); + EXPECT_EQ(client.GetAnswer()[0].GetData(), "64:ff9b::102:304"); +} diff --git a/test/cases/test-domain-set.cc b/test/cases/test-domain-set.cc new file mode 100644 index 0000000..7799536 --- /dev/null +++ b/test/cases/test-domain-set.cc @@ -0,0 +1,88 @@ +/************************************************************************* + * + * Copyright (C) 2018-2023 Ruilin Peng (Nick) . + * + * smartdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * smartdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#include "client.h" +#include "dns.h" +#include "include/utils.h" +#include "server.h" +#include "util.h" +#include "gtest/gtest.h" +#include + +class DomainSet : public ::testing::Test +{ + protected: + virtual void SetUp() {} + virtual void TearDown() {} +}; + +TEST_F(DomainSet, set_add) +{ + smartdns::MockServer server_upstream; + smartdns::Server server; + std::map qid_map; + smartdns::TempFile file_set; + std::vector domain_list; + int count = 16; + std::string config = "domain-set -name test-set -file " + file_set.GetPath() + "\n"; + config += R"""(bind [::]:60053 +server 127.0.0.1:61053 +log-num 0 +log-console yes +log-level info +cache-persist no +domain-rules /domain-set:test-set/ -c none --dualstack-ip-selection no -a 9.9.9.9 +)"""; + + server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) { + if (request->qtype == DNS_T_A) { + smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4"); + return smartdns::SERVER_REQUEST_OK; + } + return smartdns::SERVER_REQUEST_SOA; + }); + + for (int i = 0; i < count; i++) { + auto domain = smartdns::GenerateRandomString(10) + "." + smartdns::GenerateRandomString(3); + file_set.Write(domain); + file_set.Write("\n"); + domain_list.emplace_back(domain); + } + + std::cout << config << std::endl; + server.Start(config); + smartdns::Client client; + + for (auto &domain : domain_list) { + ASSERT_TRUE(client.Query(domain, 60053)); + ASSERT_EQ(client.GetAnswerNum(), 1); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_EQ(client.GetAnswer()[0].GetName(), domain); + EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 600); + EXPECT_EQ(client.GetAnswer()[0].GetType(), "A"); + EXPECT_EQ(client.GetAnswer()[0].GetData(), "9.9.9.9"); + } + + ASSERT_TRUE(client.Query("a.com", 60053)); + ASSERT_EQ(client.GetAnswerNum(), 1); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com"); + EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3); + EXPECT_EQ(client.GetAnswer()[0].GetType(), "A"); + EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4"); +} diff --git a/test/cases/test-speed-check.cc b/test/cases/test-speed-check.cc new file mode 100644 index 0000000..ee10e75 --- /dev/null +++ b/test/cases/test-speed-check.cc @@ -0,0 +1,191 @@ +/************************************************************************* + * + * Copyright (C) 2018-2023 Ruilin Peng (Nick) . + * + * smartdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * smartdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#include "client.h" +#include "dns.h" +#include "include/utils.h" +#include "server.h" +#include "util.h" +#include "gtest/gtest.h" +#include + +class SpeedCheck : public ::testing::Test +{ + protected: + virtual void SetUp() {} + virtual void TearDown() {} +}; + +TEST_F(SpeedCheck, response_mode) +{ + smartdns::MockServer server_upstream; + smartdns::Server server; + std::map qid_map; + + server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) { + if (request->qtype == DNS_T_A) { + smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4"); + smartdns::MockServer::AddIP(request, request->domain.c_str(), "5.6.7.8"); + return smartdns::SERVER_REQUEST_OK; + } + return smartdns::SERVER_REQUEST_SOA; + }); + + server.Start(R"""(bind [::]:60053 +server 127.0.0.1:61053 +log-num 0 +log-console yes +response-mode first-ping +domain-rules /a.com/ -r fastest-response +log-level debug +cache-persist no)"""); + smartdns::Client client; + ASSERT_TRUE(client.Query("b.com", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 1); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_GT(client.GetQueryTime(), 100); + EXPECT_EQ(client.GetAnswer()[0].GetName(), "b.com"); + EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 600); + + ASSERT_TRUE(client.Query("a.com", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 2); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_LT(client.GetQueryTime(), 10); + EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com"); + EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3); + EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4"); + EXPECT_EQ(client.GetAnswer()[1].GetData(), "5.6.7.8"); +} + +TEST_F(SpeedCheck, none) +{ + smartdns::MockServer server_upstream; + smartdns::Server server; + std::map qid_map; + + server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) { + if (request->qtype == DNS_T_A) { + smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4"); + smartdns::MockServer::AddIP(request, request->domain.c_str(), "5.6.7.8"); + return smartdns::SERVER_REQUEST_OK; + } + return smartdns::SERVER_REQUEST_SOA; + }); + + server.Start(R"""(bind [::]:60053 +server 127.0.0.1:61053 +log-num 0 +log-console yes +speed-check-mode none +log-level debug +cache-persist no)"""); + smartdns::Client client; + ASSERT_TRUE(client.Query("b.com", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 2); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_LT(client.GetQueryTime(), 20); + EXPECT_EQ(client.GetAnswer()[0].GetName(), "b.com"); + EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 600); + + ASSERT_TRUE(client.Query("a.com", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 2); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_LT(client.GetQueryTime(), 20); + EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com"); + EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 600); + EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4"); + EXPECT_EQ(client.GetAnswer()[1].GetData(), "5.6.7.8"); +} + +TEST_F(SpeedCheck, domain_rules_none) +{ + smartdns::MockServer server_upstream; + smartdns::Server server; + std::map qid_map; + + server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) { + if (request->qtype == DNS_T_A) { + smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4"); + smartdns::MockServer::AddIP(request, request->domain.c_str(), "5.6.7.8"); + return smartdns::SERVER_REQUEST_OK; + } + return smartdns::SERVER_REQUEST_SOA; + }); + + server.Start(R"""(bind [::]:60053 +server 127.0.0.1:61053 +log-num 0 +log-console yes +domain-rules /a.com/ -c none +log-level debug +cache-persist no)"""); + smartdns::Client client; + ASSERT_TRUE(client.Query("b.com", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 1); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_GT(client.GetQueryTime(), 200); + EXPECT_EQ(client.GetAnswer()[0].GetName(), "b.com"); + EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 600); + + ASSERT_TRUE(client.Query("a.com", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 2); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_LT(client.GetQueryTime(), 20); + EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com"); + EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 600); + EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4"); + EXPECT_EQ(client.GetAnswer()[1].GetData(), "5.6.7.8"); +} + +TEST_F(SpeedCheck, only_ping) +{ + smartdns::MockServer server_upstream; + smartdns::Server server; + std::map qid_map; + + server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) { + if (request->qtype == DNS_T_A) { + smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4"); + smartdns::MockServer::AddIP(request, request->domain.c_str(), "5.6.7.8"); + return smartdns::SERVER_REQUEST_OK; + } + return smartdns::SERVER_REQUEST_SOA; + }); + + server.Start(R"""(bind [::]:60053 +server 127.0.0.1:61053 +log-num 0 +log-console yes +speed-check-mode ping +log-level debug +cache-persist no)"""); + smartdns::Client client; + ASSERT_TRUE(client.Query("b.com", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 1); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_LT(client.GetQueryTime(), 1200); + EXPECT_EQ(client.GetAnswer()[0].GetName(), "b.com"); + EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 600); +} diff --git a/test/client.cc b/test/client.cc index ec43412..9a9e952 100644 --- a/test/client.cc +++ b/test/client.cc @@ -255,6 +255,11 @@ bool Client::ParserResult() answer_num_ = std::stoi(match[1]); } + std::regex reg_authority_num(", AUTHORITY: ([0-9]+),"); + if (std::regex_search(result_, match, reg_authority_num)) { + records_authority_num_ = std::stoi(match[1]); + } + std::regex reg_status(", status: ([A-Z]+),"); if (std::regex_search(result_, match, reg_status)) { status_ = match[1]; @@ -301,6 +306,19 @@ bool Client::ParserResult() } } + std::regex reg_authority(";; AUTHORITY SECTION:\\n((?:.|\\n|\\r\\n)+?)\\n{2,}", + std::regex::ECMAScript | std::regex::optimize); + if (std::regex_search(result_, match, reg_authority)) { + if (ParserRecord(match[1], records_authority_) == false) { + return false; + } + + if (records_authority_num_ != records_authority_.size()) { + std::cout << "DIG FAILED: Num Not Match\n" << result_ << std::endl; + return false; + } + } + std::regex reg_addition(";; ADDITIONAL SECTION:\\n((?:.|\\n|\\r\\n)+?)\\n{2,}", std::regex::ECMAScript | std::regex::optimize); if (std::regex_search(result_, match, reg_answer)) { diff --git a/test/client.h b/test/client.h index 797ba03..ae2ca60 100644 --- a/test/client.h +++ b/test/client.h @@ -90,6 +90,7 @@ class Client bool ParserRecord(const std::string &record_str, std::vector &record); std::string result_; int answer_num_{0}; + int records_authority_num_{0}; std::string status_; std::string server_; int query_time_{0}; diff --git a/test/include/utils.h b/test/include/utils.h index d932fa5..ef228df 100644 --- a/test/include/utils.h +++ b/test/include/utils.h @@ -19,8 +19,10 @@ #ifndef _SMARTDNS_TEST_UTILS_ #define _SMARTDNS_TEST_UTILS_ +#include #include #include +#include namespace smartdns { @@ -56,9 +58,54 @@ class DeferGuard #define SMARTDNS_CONCAT(a, b) SMARTDNS_CONCAT_(a, b) #define Defer ::smartdns::DeferGuard SMARTDNS_CONCAT(__defer__, __LINE__) = [&]() +class TempFile +{ + public: + TempFile(); + TempFile(const std::string &line); + virtual ~TempFile(); + + bool Write(const std::string &line); + + std::string GetPath(); + + void SetPattern(const std::string &pattern); + + private: + bool NewTempFile(); + std::string path_; + std::ofstream ofs_; + std::string pattern_; +}; + +class Commander +{ + public: + Commander(); + virtual ~Commander(); + + bool Run(const std::vector &cmds); + + bool Run(const std::string &cmd); + + void Kill(); + + void Terminate(); + + int ExitCode(); + + int GetPid(); + + private: + pid_t pid_{-1}; + int exit_code_ = {-1}; +}; + bool IsCommandExists(const std::string &cmd); std::string GenerateRandomString(int len); +int ParserArg(const std::string &cmd, std::vector &args); + } // namespace smartdns #endif // _SMARTDNS_TEST_UTILS_ diff --git a/test/server.cc b/test/server.cc index 0224546..67b7995 100644 --- a/test/server.cc +++ b/test/server.cc @@ -283,7 +283,8 @@ bool MockServer::Start(const std::string &url, ServerRequest callback) return true; } -Server::Server() { +Server::Server() +{ mode_ = Server::CREATE_MODE_FORK; } @@ -312,26 +313,9 @@ bool Server::Start(const std::string &conf, enum CONF_TYPE type) }; if (type == CONF_TYPE_STRING) { - char filename[128]; - strncpy(filename, "/tmp/smartdns_conf.XXXXXX", sizeof(filename)); - int fd = mkstemp(filename); - if (fd < 0) { - return false; - } - Defer - { - close(fd); - }; - - std::ofstream ofs(filename); - if (ofs.is_open() == false) { - return false; - } - ofs.write(conf.data(), conf.size()); - ofs.flush(); - ofs.close(); - conf_file = filename; - clean_conf_file_ = true; + conf_temp_file_.SetPattern("/tmp/smartdns_conf.XXXXXX"); + conf_temp_file_.Write(conf); + conf_file = conf_temp_file_.GetPath(); } else if (type == CONF_TYPE_FILE) { conf_file = conf; } else { @@ -418,11 +402,6 @@ void Server::Stop(bool graceful) waitpid(pid_, nullptr, 0); pid_ = 0; - if (clean_conf_file_ == true) { - unlink(conf_file_.c_str()); - conf_file_.clear(); - clean_conf_file_ = false; - } } bool Server::IsRunning() diff --git a/test/server.h b/test/server.h index d21f36b..e29a58d 100644 --- a/test/server.h +++ b/test/server.h @@ -20,6 +20,7 @@ #define _SMARTDNS_SERVER_ #include "dns.h" +#include "include/utils.h" #include #include #include @@ -53,7 +54,8 @@ class Server std::thread thread_; int fd_; std::string conf_file_; - bool clean_conf_file_{false}; + TempFile conf_temp_file_; + enum CREATE_MODE mode_; }; diff --git a/test/utils.cc b/test/utils.cc index d77c6c9..7409beb 100644 --- a/test/utils.cc +++ b/test/utils.cc @@ -1,12 +1,179 @@ #include "include/utils.h" +#include #include #include #include +#include #include namespace smartdns { +TempFile::TempFile() +{ + pattern_ = "/tmp/smartdns-test-tmp.XXXXXX"; +} + +TempFile::TempFile(const std::string &line) +{ + pattern_ = "/tmp/smartdns-test-tmp.XXXXXX"; +} + +TempFile::~TempFile() +{ + if (ofs_.is_open()) { + ofs_.close(); + ofs_.clear(); + } + + if (path_.length() > 0) { + unlink(path_.c_str()); + } +} + +void TempFile::SetPattern(const std::string &pattern) +{ + pattern_ = pattern; +} + +bool TempFile::Write(const std::string &line) +{ + if (ofs_.is_open() == false) { + if (NewTempFile() == false) { + return false; + } + } + + ofs_.write(line.data(), line.size()); + if (ofs_.fail()) { + return false; + } + ofs_.flush(); + + return true; +} + +bool TempFile::NewTempFile() +{ + char filename[128]; + strncpy(filename, "/tmp/smartdns-test-tmp.XXXXXX", sizeof(filename)); + int fd = mkstemp(filename); + if (fd < 0) { + return false; + } + Defer + { + close(fd); + }; + + std::ofstream ofs(filename); + if (ofs.is_open() == false) { + return false; + } + ofs_ = std::move(ofs); + path_ = filename; + + return true; +} + +std::string TempFile::GetPath() +{ + if (ofs_.is_open() == false) { + if (NewTempFile() == false) { + return ""; + } + } + + return path_; +} + +Commander::Commander() {} +Commander::~Commander() +{ + Kill(); +} + +bool Commander::Run(const std::string &cmd) +{ + std::vector args; + if (ParserArg(cmd, args) != 0) { + return false; + } + + return Run(args); +} + +bool Commander::Run(const std::vector &cmds) +{ + pid_t pid; + + if (pid_ > 0) { + return false; + } + + pid = fork(); + if (pid < 0) { + return false; + } + + if (pid == 0) { + char *argv[cmds.size() + 1]; + for (int i = 0; i < cmds.size(); i++) { + argv[i] = (char *)cmds[i].c_str(); + } + argv[cmds.size()] = nullptr; + execvp(argv[0], argv); + _exit(1); + } + + pid_ = pid; + + return true; +} + +void Commander::Kill() +{ + if (pid_ <= 0) { + return; + } + + kill(pid_, SIGKILL); +} + +void Commander::Terminate() +{ + if (pid_ <= 0) { + return; + } + + kill(pid_, SIGTERM); +} + +int Commander::ExitCode() +{ + int wstatus = 0; + if (exit_code_ >= 0) { + return exit_code_; + } + + if (pid_ <= 0) { + return -1; + } + + if (waitpid(pid_, &wstatus, 0) == -1) { + return -1; + } + + exit_code_ = WEXITSTATUS(wstatus); + + return exit_code_; +} + +int Commander::GetPid() +{ + return pid_; +} + bool IsCommandExists(const std::string &cmd) { char *copy_path = nullptr; @@ -42,17 +209,61 @@ bool IsCommandExists(const std::string &cmd) std::string GenerateRandomString(int len) { - std::string result; + std::string result; static const char alphanum[] = "0123456789" "ABCDEFGHIJKLMNOPQRSTUVWXYZ" "abcdefghijklmnopqrstuvwxyz"; - result.resize(len); + result.resize(len); - for (int i = 0; i < len; ++i) { - result[i] = alphanum[rand() % (sizeof(alphanum) - 1)]; - } + for (int i = 0; i < len; ++i) { + result[i] = alphanum[rand() % (sizeof(alphanum) - 1)]; + } - return result; + return result; +} + +int ParserArg(const std::string &cmd, std::vector &args) +{ + std::string arg; + char quoteChar = 0; + + for (char ch : cmd) { + if (quoteChar == '\\') { + arg.push_back(ch); + quoteChar = 0; + continue; + } + + if (quoteChar && ch != quoteChar) { + arg.push_back(ch); + continue; + } + + switch (ch) { + case '\'': + case '\"': + case '\\': + quoteChar = quoteChar ? 0 : ch; + break; + case ' ': + case '\t': + case '\n': + if (!arg.empty()) { + args.push_back(arg); + arg.clear(); + } + break; + default: + arg.push_back(ch); + break; + } + } + + if (!arg.empty()) { + args.push_back(arg); + } + + return 0; } } // namespace smartdns \ No newline at end of file