diff --git a/src/dns_client.c b/src/dns_client.c index 9294012..3447447 100644 --- a/src/dns_client.c +++ b/src/dns_client.c @@ -3704,7 +3704,8 @@ int dns_client_query(const char *domain, int qtype, dns_client_callback callback } pthread_mutex_unlock(&client.domain_map_lock); - tlog(TLOG_INFO, "request: %s, qtype: %d, id: %d, group: %s", domain, qtype, query->sid, query->server_group->group_name); + tlog(TLOG_INFO, "request: %s, qtype: %d, id: %d, group: %s", domain, qtype, query->sid, + query->server_group->group_name); _dns_client_query_release(query); return 0; @@ -3829,13 +3830,18 @@ static void _dns_client_remove_all_pending_servers(void) static void _dns_client_add_pending_servers(void) { +#ifdef TEST + const int delay_value = 1; +#else + const int delay_value = 3; +#endif struct dns_server_pending *pending = NULL; struct dns_server_pending *tmp = NULL; - static int delay = 0; + static int delay = delay_value; LIST_HEAD(retry_list); /* add pending server after 3 seconds */ - if (++delay < 3) { + if (++delay < delay_value) { return; } delay = 0; diff --git a/src/dns_conf.c b/src/dns_conf.c index 07fffc9..b823c90 100644 --- a/src/dns_conf.c +++ b/src/dns_conf.c @@ -3336,6 +3336,8 @@ void dns_server_load_exit(void) _config_host_table_destroy(); _config_qtype_soa_table_destroy(); _config_proxy_table_destroy(); + + dns_conf_server_num = 0; } static int _dns_conf_speed_check_mode_verify(void) diff --git a/src/dns_server.c b/src/dns_server.c index d99c38f..0a36bfc 100644 --- a/src/dns_server.c +++ b/src/dns_server.c @@ -844,9 +844,12 @@ static int _dns_rrs_add_all_best_ip(struct dns_server_post_context *context) continue; } - int ttl_range = request->ping_time + request->ping_time / 10; - if ((ttl_range < addr_map->ping_time) && addr_map->ping_time >= 100 && ignore_speed == 0) { - continue; + /* if ping time is larger than 5ms, check again. */ + if (addr_map->ping_time - request->ping_time >= 50) { + int ttl_range = request->ping_time + request->ping_time / 10 + 5; + if ((ttl_range < addr_map->ping_time) && addr_map->ping_time >= 100 && ignore_speed == 0) { + continue; + } } context->ip_num++; @@ -4700,7 +4703,7 @@ out: return ret; } -static void _dns_server_check_ipv6_ready(void) +void dns_server_check_ipv6_ready(void) { static int do_get_conf = 0; static int is_icmp_check_set; @@ -6231,7 +6234,7 @@ static void _dns_server_period_run_second(void) _dns_server_check_need_exit(); if (sec % IPV6_READY_CHECK_TIME == 0 && is_ipv6_ready == 0) { - _dns_server_check_ipv6_ready(); + dns_server_check_ipv6_ready(); } if (sec % 60 == 0) { @@ -6931,7 +6934,7 @@ int dns_server_init(void) goto errout; } - _dns_server_check_ipv6_ready(); + dns_server_check_ipv6_ready(); tlog(TLOG_INFO, "%s", (is_ipv6_ready) ? "IPV6 is ready, enable IPV6 features" : "IPV6 is not ready, disable IPV6 features"); diff --git a/src/dns_server.h b/src/dns_server.h index 3c40b5d..70cebdc 100644 --- a/src/dns_server.h +++ b/src/dns_server.h @@ -37,6 +37,8 @@ struct dns_server_query_option { int dns_is_ipv6_ready(void); +void dns_server_check_ipv6_ready(void); + int dns_server_init(void); int dns_server_run(void); diff --git a/src/fast_ping.c b/src/fast_ping.c index c9a657c..5850ef5 100644 --- a/src/fast_ping.c +++ b/src/fast_ping.c @@ -40,6 +40,7 @@ #include #include #include +#include #include #include @@ -95,6 +96,18 @@ struct fast_ping_packet { struct fast_ping_packet_msg msg; }; +struct fast_ping_fake_ip { + struct hlist_node node; + atomic_t ref; + PING_TYPE type; + FAST_PING_TYPE ping_type; + char host[PING_MAX_HOSTLEN]; + int ttl; + float time; + struct sockaddr_storage addr; + int addr_len; +}; + struct ping_host_struct { atomic_t ref; atomic_t notified; @@ -127,6 +140,9 @@ struct ping_host_struct { }; socklen_t addr_len; struct fast_ping_packet packet; + + struct fast_ping_fake_ip *fake; + int fake_time_fd; }; struct fast_ping_notify_event { @@ -163,6 +179,8 @@ struct fast_ping_struct { pthread_mutex_t map_lock; DECLARE_HASHTABLE(addrmap, 6); + DECLARE_HASHTABLE(fake, 6); + int fake_ip_num; }; static struct fast_ping_struct ping; @@ -170,6 +188,8 @@ static atomic_t ping_sid = ATOMIC_INIT(0); static int bool_print_log = 1; static void _fast_ping_host_put(struct ping_host_struct *ping_host); +static int _fast_ping_get_addr_by_type(PING_TYPE type, const char *ip_str, int port, struct addrinfo **out_gai, + FAST_PING_TYPE *out_ping_type); static void _fast_ping_wakeup_thread(void) { @@ -376,6 +396,179 @@ errout: return -1; } +static void _fast_ping_fake_put(struct fast_ping_fake_ip *fake) +{ + int ref_cnt = atomic_dec_and_test(&fake->ref); + if (!ref_cnt) { + if (ref_cnt < 0) { + tlog(TLOG_ERROR, "invalid refcount of fake ping %s", fake->host); + abort(); + } + return; + } + + pthread_mutex_lock(&ping.map_lock); + if (hash_hashed(&fake->node)) { + hash_del(&fake->node); + } + pthread_mutex_unlock(&ping.map_lock); + + free(fake); +} + +static void _fast_ping_fake_remove(struct fast_ping_fake_ip *fake) +{ + pthread_mutex_lock(&ping.map_lock); + if (hash_hashed(&fake->node)) { + hash_del(&fake->node); + } + pthread_mutex_unlock(&ping.map_lock); + + _fast_ping_fake_put(fake); +} + +static void _fast_ping_fake_get(struct fast_ping_fake_ip *fake) +{ + atomic_inc(&fake->ref); +} + +static struct fast_ping_fake_ip *_fast_ping_fake_find(FAST_PING_TYPE ping_type, struct sockaddr *addr, int addr_len) +{ + struct fast_ping_fake_ip *fake = NULL; + struct fast_ping_fake_ip *ret = NULL; + uint32_t key = 0; + + if (ping.fake_ip_num == 0) { + return NULL; + } + + key = jhash(addr, addr_len, 0); + key = jhash(&ping_type, sizeof(ping_type), key); + pthread_mutex_lock(&ping.map_lock); + hash_for_each_possible(ping.fake, fake, node, key) + { + if (fake->ping_type != ping_type) { + continue; + } + + if (fake->addr_len != addr_len) { + continue; + } + + if (memcmp(&fake->addr, addr, fake->addr_len) != 0) { + continue; + } + + ret = fake; + _fast_ping_fake_get(fake); + break; + } + pthread_mutex_unlock(&ping.map_lock); + return ret; +} + +int fast_ping_fake_ip_add(PING_TYPE type, const char *host, int ttl, float time) +{ + struct fast_ping_fake_ip *fake = NULL; + struct fast_ping_fake_ip *fake_old = NULL; + char ip_str[PING_MAX_HOSTLEN]; + int port = -1; + FAST_PING_TYPE ping_type = FAST_PING_END; + uint32_t key = 0; + int ret = -1; + struct addrinfo *gai = NULL; + + if (parse_ip(host, ip_str, &port) != 0) { + goto errout; + } + + ret = _fast_ping_get_addr_by_type(type, ip_str, port, &gai, &ping_type); + if (ret != 0) { + goto errout; + } + + fake_old = _fast_ping_fake_find(ping_type, gai->ai_addr, gai->ai_addrlen); + fake = malloc(sizeof(*fake)); + if (fake == NULL) { + goto errout; + } + memset(fake, 0, sizeof(*fake)); + + safe_strncpy(fake->host, ip_str, PING_MAX_HOSTLEN); + fake->ttl = ttl; + fake->time = time; + fake->type = type; + fake->ping_type = ping_type; + memcpy(&fake->addr, gai->ai_addr, gai->ai_addrlen); + fake->addr_len = gai->ai_addrlen; + INIT_HLIST_NODE(&fake->node); + atomic_set(&fake->ref, 1); + + key = jhash(&fake->addr, fake->addr_len, 0); + key = jhash(&ping_type, sizeof(ping_type), key); + pthread_mutex_lock(&ping.map_lock); + hash_add(ping.fake, &fake->node, key); + pthread_mutex_unlock(&ping.map_lock); + ping.fake_ip_num++; + + if (fake_old != NULL) { + _fast_ping_fake_put(fake_old); + _fast_ping_fake_remove(fake_old); + } + + freeaddrinfo(gai); + return 0; +errout: + if (fake != NULL) { + free(fake); + } + + if (fake_old != NULL) { + _fast_ping_fake_put(fake_old); + } + + if (gai != NULL) { + freeaddrinfo(gai); + } + + return -1; +} + +int fast_ping_fake_ip_remove(PING_TYPE type, const char *host) +{ + struct fast_ping_fake_ip *fake = NULL; + char ip_str[PING_MAX_HOSTLEN]; + int port = -1; + int ret = -1; + FAST_PING_TYPE ping_type = FAST_PING_END; + struct addrinfo *gai = NULL; + + if (parse_ip(host, ip_str, &port) != 0) { + return -1; + } + + ret = _fast_ping_get_addr_by_type(type, ip_str, port, &gai, &ping_type); + if (ret != 0) { + goto errout; + } + + fake = _fast_ping_fake_find(ping_type, gai->ai_addr, gai->ai_addrlen); + if (fake == NULL) { + goto errout; + } + + _fast_ping_fake_remove(fake); + _fast_ping_fake_put(fake); + ping.fake_ip_num--; + freeaddrinfo(gai); + return 0; +errout: + if (gai != NULL) { + freeaddrinfo(gai); + } + return -1; +} + static void _fast_ping_host_get(struct ping_host_struct *ping_host) { if (atomic_inc_return(&ping_host->ref) <= 0) { @@ -386,6 +579,15 @@ static void _fast_ping_host_get(struct ping_host_struct *ping_host) static void _fast_ping_close_host_sock(struct ping_host_struct *ping_host) { + if (ping_host->fake_time_fd > 0) { + struct epoll_event *event = NULL; + event = (struct epoll_event *)1; + epoll_ctl(ping.epoll_fd, EPOLL_CTL_DEL, ping_host->fake_time_fd, event); + + close(ping_host->fake_time_fd); + ping_host->fake_time_fd = -1; + } + if (ping_host->fd < 0) { return; } @@ -455,6 +657,10 @@ static void _fast_ping_host_put(struct ping_host_struct *ping_host) } _fast_ping_close_host_sock(ping_host); + if (ping_host->fake != NULL) { + _fast_ping_fake_put(ping_host->fake); + ping_host->fake = NULL; + } pthread_mutex_lock(&ping.map_lock); hash_del(&ping_host->addr_node); @@ -546,6 +752,38 @@ errout: return -1; } +static int _fast_ping_send_fake(struct ping_host_struct *ping_host, struct fast_ping_fake_ip *fake) +{ + struct itimerspec its; + int sec = fake->time / 1000; + int cent_usec = ((long)(fake->time * 10)) % 10000; + its.it_value.tv_sec = sec; + its.it_value.tv_nsec = cent_usec * 1000 * 100; + its.it_interval.tv_sec = 0; + its.it_interval.tv_nsec = 0; + + if (timerfd_settime(ping_host->fake_time_fd, 0, &its, NULL) < 0) { + tlog(TLOG_ERROR, "timerfd_settime failed, %s", strerror(errno)); + goto errout; + } + + struct epoll_event ev; + ev.events = EPOLLIN; + ev.data.ptr = ping_host; + if (epoll_ctl(ping.epoll_fd, EPOLL_CTL_ADD, ping_host->fake_time_fd, &ev) == -1) { + if (errno != EEXIST) { + goto errout; + } + } + + ping_host->seq++; + + return 0; + +errout: + return -1; +} + static int _fast_ping_sendping_v4(struct ping_host_struct *ping_host) { struct fast_ping_packet *packet = &ping_host->packet; @@ -710,8 +948,16 @@ errout: static int _fast_ping_sendping(struct ping_host_struct *ping_host) { int ret = -1; + struct fast_ping_fake_ip *fake = NULL; gettimeofday(&ping_host->last, NULL); + fake = _fast_ping_fake_find(ping_host->type, &ping_host->addr, ping_host->addr_len); + if (fake) { + ret = _fast_ping_send_fake(ping_host, fake); + _fast_ping_fake_put(fake); + return ret; + } + if (ping_host->type == FAST_PING_ICMP) { ret = _fast_ping_sendping_v4(ping_host); } else if (ping_host->type == FAST_PING_ICMP6) { @@ -1010,13 +1256,18 @@ static int _fast_ping_get_addr_by_icmp(const char *ip_str, int port, struct addr goto errout; } - gai = _fast_ping_getaddr(ip_str, service, socktype, sockproto); - if (gai == NULL) { - goto errout; + if (out_gai != NULL) { + gai = _fast_ping_getaddr(ip_str, service, socktype, sockproto); + if (gai == NULL) { + goto errout; + } + + *out_gai = gai; } - *out_gai = gai; - *out_ping_type = ping_type; + if (out_ping_type != NULL) { + *out_ping_type = ping_type; + } return 0; errout: @@ -1150,6 +1401,8 @@ struct ping_host_struct *fast_ping_start(PING_TYPE type, const char *host, int c FAST_PING_TYPE ping_type = FAST_PING_END; unsigned int seed = 0; int ret = 0; + struct fast_ping_fake_ip *fake = NULL; + int fake_time_fd = -1; if (parse_ip(host, ip_str, &port) != 0) { goto errout; @@ -1194,6 +1447,19 @@ struct ping_host_struct *fast_ping_start(PING_TYPE type, const char *host, int c tlog(TLOG_DEBUG, "ping %s, id = %d", host, ping_host->sid); + fake = _fast_ping_fake_find(ping_host->type, gai->ai_addr, gai->ai_addrlen); + if (fake) { + fake_time_fd = timerfd_create(CLOCK_MONOTONIC, TFD_NONBLOCK | TFD_CLOEXEC); + if (fake_time_fd < 0) { + tlog(TLOG_ERROR, "timerfd_create failed, %s", strerror(errno)); + goto errout; + } + /* already take ownership by find. */ + ping_host->fake = fake; + ping_host->fake_time_fd = fake_time_fd; + fake = NULL; + } + addrkey = _fast_ping_hash_key(ping_host->sid, &ping_host->addr); _fast_ping_host_get(ping_host); @@ -1229,6 +1495,14 @@ errout: free(ping_host); } + if (fake_time_fd > 0) { + close(fake_time_fd); + } + + if (fake) { + _fast_ping_fake_put(fake); + } + return NULL; } @@ -1365,6 +1639,33 @@ errout: return NULL; } +static int _fast_ping_process_fake(struct ping_host_struct *ping_host, struct timeval *now) +{ + struct timeval tvresult = *now; + struct timeval *tvsend = &ping_host->last; + uint64_t exp; + int ret; + + ret = read(ping_host->fake_time_fd, &exp, sizeof(uint64_t)); + if (ret < 0) { + return -1; + } + + ping_host->ttl = ping_host->fake->ttl; + tv_sub(&tvresult, tvsend); + if (ping_host->ping_callback) { + _fast_ping_send_notify_event(ping_host, PING_RESULT_RESPONSE, ping_host->seq, ping_host->ttl, &tvresult); + } + + ping_host->send = 0; + + if (ping_host->count == 1) { + _fast_ping_host_remove(ping_host); + } + + return 0; +} + static int _fast_ping_process_icmp(struct ping_host_struct *ping_host, struct timeval *now) { int len = 0; @@ -1592,6 +1893,11 @@ static int _fast_ping_process(struct ping_host_struct *ping_host, struct epoll_e { int ret = -1; + if (ping_host->fake != NULL) { + ret = _fast_ping_process_fake(ping_host, now); + return ret; + } + switch (ping_host->type) { case FAST_PING_ICMP6: case FAST_PING_ICMP: @@ -1635,6 +1941,18 @@ static void _fast_ping_remove_all(void) } } +static void _fast_ping_remove_all_fake_ip(void) +{ + struct fast_ping_fake_ip *fake = NULL; + struct hlist_node *tmp = NULL; + unsigned long i = 0; + + hash_for_each_safe(ping.fake, i, tmp, fake, node) + { + _fast_ping_fake_put(fake); + } +} + static void _fast_ping_period_run(void) { struct ping_host_struct *ping_host = NULL; @@ -1890,6 +2208,7 @@ int fast_ping_init(void) INIT_LIST_HEAD(&ping.notify_event_list); hash_init(ping.addrmap); + hash_init(ping.fake); ping.no_unprivileged_ping = !has_unprivileged_ping(); ping.ident = (getpid() & 0XFFFF); atomic_set(&ping.run, 1); @@ -1998,6 +2317,7 @@ void fast_ping_exit(void) _fast_ping_close_fds(); _fast_ping_remove_all(); + _fast_ping_remove_all_fake_ip(); _fast_ping_remove_all_notify_event(); pthread_cond_destroy(&ping.notify_cond); diff --git a/src/fast_ping.h b/src/fast_ping.h index 8c2237a..33231f8 100644 --- a/src/fast_ping.h +++ b/src/fast_ping.h @@ -47,6 +47,10 @@ typedef void (*fast_ping_result)(struct ping_host_struct *ping_host, const char struct ping_host_struct *fast_ping_start(PING_TYPE type, const char *host, int count, int interval, int timeout, fast_ping_result ping_callback, void *userptr); +int fast_ping_fake_ip_add(PING_TYPE type, const char *host, int ttl, float time); + +int fast_ping_fake_ip_remove(PING_TYPE type, const char *host); + /* stop ping */ int fast_ping_stop(struct ping_host_struct *ping_host); diff --git a/src/smartdns.c b/src/smartdns.c index 837578b..dfcd173 100644 --- a/src/smartdns.c +++ b/src/smartdns.c @@ -17,6 +17,7 @@ */ #define _GNU_SOURCE +#include "smartdns.h" #include "art.h" #include "atomic.h" #include "dns_client.h" @@ -304,8 +305,7 @@ static int _smartdns_add_servers(void) } if (_smartdns_prepare_server_flags(&flags, server) != 0) { - tlog(TLOG_ERROR, "prepare server flags failed, %s:%d", server->server, - server->port); + tlog(TLOG_ERROR, "prepare server flags failed, %s:%d", server->server, server->port); return -1; } @@ -647,16 +647,31 @@ static int _smartdns_init_pre(void) } #ifdef TEST + +static smartdns_post_func _smartdns_post = NULL; +static void *_smartdns_post_arg = NULL; + +int smartdns_reg_post_func(smartdns_post_func func, void *arg) +{ + _smartdns_post = func; + _smartdns_post_arg = arg; + return 0; +} + #define smartdns_test_notify(retval) smartdns_test_notify_func(fd_notify, retval) -static void smartdns_test_notify_func(int fd_notify, uint64_t retval) { +static void smartdns_test_notify_func(int fd_notify, uint64_t retval) +{ /* notify parent kickoff */ if (fd_notify > 0) { write(fd_notify, &retval, sizeof(retval)); } + + if (_smartdns_post != NULL) { + _smartdns_post(_smartdns_post_arg); + } } -int smartdns_main(int argc, char *argv[], int fd_notify); -int smartdns_main(int argc, char *argv[], int fd_notify) +int smartdns_main(int argc, char *argv[], int fd_notify) #else #define smartdns_test_notify(retval) int main(int argc, char *argv[]) diff --git a/src/smartdns.h b/src/smartdns.h new file mode 100644 index 0000000..519d646 --- /dev/null +++ b/src/smartdns.h @@ -0,0 +1,39 @@ +/************************************************************************* + * + * Copyright (C) 2018-2023 Ruilin Peng (Nick) . + * + * smartdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * smartdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#ifndef SMART_DNS_H +#define SMART_DNS_H + +#ifdef __cplusplus +extern "C" { +#endif /*__cplusplus */ + +#ifdef TEST + +typedef void (*smartdns_post_func)(void *arg); + +int smartdns_reg_post_func(smartdns_post_func func, void *arg); + +int smartdns_main(int argc, char *argv[], int fd_notify); + +#endif + +#ifdef __cplusplus +} +#endif /*__cplusplus */ +#endif diff --git a/src/util.c b/src/util.c index 6d5582c..d022ad9 100644 --- a/src/util.c +++ b/src/util.c @@ -310,7 +310,7 @@ int parse_ip(const char *value, char *ip, int *port) return 0; } -static int _check_is_ipv4(const char *ip) +int check_is_ipv4(const char *ip) { const char *ptr = ip; char c = 0; @@ -344,7 +344,8 @@ static int _check_is_ipv4(const char *ip) return 0; } -static int _check_is_ipv6(const char *ip) + +int check_is_ipv6(const char *ip) { const char *ptr = ip; char c = 0; @@ -394,10 +395,10 @@ int check_is_ipaddr(const char *ip) { if (strstr(ip, ".")) { /* IPV4 */ - return _check_is_ipv4(ip); + return check_is_ipv4(ip); } else if (strstr(ip, ":")) { /* IPV6 */ - return _check_is_ipv6(ip); + return check_is_ipv6(ip); } return -1; } diff --git a/src/util.h b/src/util.h index 570b047..b7d9522 100644 --- a/src/util.h +++ b/src/util.h @@ -69,6 +69,10 @@ int parse_ip(const char *value, char *ip, int *port); int check_is_ipaddr(const char *ip); +int check_is_ipv4(const char *ip); + +int check_is_ipv6(const char *ip); + int parse_uri(const char *value, char *scheme, char *host, int *port, char *path); int parse_uri_ext(const char *value, char *scheme, char *user, char *password, char *host, int *port, char *path); diff --git a/test/Makefile b/test/Makefile index 8fa1f73..d89847e 100644 --- a/test/Makefile +++ b/test/Makefile @@ -20,6 +20,7 @@ CFLAGS += -DTEST CFLAGS += -g -Wall -Wstrict-prototypes -fno-omit-frame-pointer -Wstrict-aliasing -funwind-tables -Wmissing-prototypes -Wshadow -Wextra -Wno-unused-parameter -Wno-implicit-fallthrough CXXFLAGS += -g +CXXFLAGS += -DTEST CXXFLAGS += -I./ -I../src -I../src/include SMARTDNS_OBJS = lib/rbtree.o lib/art.o lib/bitops.o lib/radix.o lib/conf.o lib/nftset.o diff --git a/test/cases/test-address.cc b/test/cases/test-address.cc index 733f6e6..15f51fd 100644 --- a/test/cases/test-address.cc +++ b/test/cases/test-address.cc @@ -61,7 +61,7 @@ cache-persist no)"""); smartdns::Client client; ASSERT_TRUE(client.Query("a.com A", 60053)); std::cout << client.GetResult() << std::endl; - ASSERT_EQ(client.GetAnswerNum(), 0); + ASSERT_EQ(client.GetAuthorityNum(), 1); EXPECT_EQ(client.GetStatus(), "NOERROR"); EXPECT_EQ(client.GetAuthority()[0].GetName(), "a.com"); EXPECT_EQ(client.GetAuthority()[0].GetTTL(), 30); diff --git a/test/cases/test-bind.cc b/test/cases/test-bind.cc index a5b35fd..92ea314 100644 --- a/test/cases/test-bind.cc +++ b/test/cases/test-bind.cc @@ -160,4 +160,35 @@ cache-persist no)"""); EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com"); EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3); EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4"); -} \ No newline at end of file +} + +TEST(Bind, device) +{ + smartdns::MockServer server_upstream; + smartdns::Server server; + + server_upstream.Start("udp://0.0.0.0:62053", [](struct smartdns::ServerRequestContext *request) { + if (request->qtype == DNS_T_A) { + smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4"); + return smartdns::SERVER_REQUEST_OK; + } + return smartdns::SERVER_REQUEST_SOA; + }); + + server.Start(R"""( +bind [::]:60053@lo +server 127.0.0.1:62053 +log-num 0 +log-console yes +log-level info +cache-persist no)"""); + smartdns::Client client; + + ASSERT_TRUE(client.Query("a.com", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 1); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_LT(client.GetQueryTime(), 100); + EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3); + EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4"); +} diff --git a/test/cases/test-bootstrap.cc b/test/cases/test-bootstrap.cc new file mode 100644 index 0000000..bf3310b --- /dev/null +++ b/test/cases/test-bootstrap.cc @@ -0,0 +1,71 @@ +/************************************************************************* + * + * Copyright (C) 2018-2023 Ruilin Peng (Nick) . + * + * smartdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * smartdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#include "client.h" +#include "dns.h" +#include "include/utils.h" +#include "server.h" +#include "gtest/gtest.h" + +class BootStrap : public ::testing::Test +{ + protected: + virtual void SetUp() {} + virtual void TearDown() {} +}; + +TEST_F(BootStrap, bootstrap) +{ + smartdns::MockServer server_upstream; + smartdns::MockServer server_upstream2; + smartdns::Server server; + + server_upstream.Start("udp://0.0.0.0:61053", [](struct smartdns::ServerRequestContext *request) { + if (request->qtype != DNS_T_A) { + return smartdns::SERVER_REQUEST_SOA; + } + + smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4", 611); + return smartdns::SERVER_REQUEST_OK; + }); + + server_upstream2.Start("udp://0.0.0.0:62053", [](struct smartdns::ServerRequestContext *request) { + if (request->qtype != DNS_T_A) { + return smartdns::SERVER_REQUEST_SOA; + } + + smartdns::MockServer::AddIP(request, request->domain.c_str(), "127.0.0.1", 611); + return smartdns::SERVER_REQUEST_OK; + }); + + server.Start(R"""(bind [::]:60053 +server udp://127.0.0.1:62053 -bootstrap-dns +server udp://example.com:61053 +log-num 0 +log-console yes +log-level debug +cache-persist no)"""); + smartdns::Client client; + usleep(2500000); + ASSERT_TRUE(client.Query("a.com", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 1); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com"); + EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4"); +} diff --git a/test/cases/test-cache.cc b/test/cases/test-cache.cc index a70cde6..a564c0e 100644 --- a/test/cases/test-cache.cc +++ b/test/cases/test-cache.cc @@ -21,6 +21,11 @@ #include "include/utils.h" #include "server.h" #include "gtest/gtest.h" +#include + +/* clang-format off */ +#include "dns_cache.h" +/* clang-format on */ class Cache : public ::testing::Test { @@ -228,3 +233,56 @@ cache-persist no)"""); EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3); EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4"); } + +TEST_F(Cache, save_file) +{ + smartdns::MockServer server_upstream; + auto cache_file = "/tmp/smartdns_cache." + smartdns::GenerateRandomString(10); + std::string conf = R"""( +bind [::]:60053@lo +server 127.0.0.1:62053 +log-num 0 +log-console yes +log-level debug +cache-persist yes +dualstack-ip-selection no +)"""; + + conf += "cache-file " + cache_file; + Defer + { + unlink(cache_file.c_str()); + }; + + server_upstream.Start("udp://0.0.0.0:62053", [](struct smartdns::ServerRequestContext *request) { + if (request->qtype == DNS_T_A) { + smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4"); + return smartdns::SERVER_REQUEST_OK; + } + return smartdns::SERVER_REQUEST_SOA; + }); + { + smartdns::Server server; + server.Start(conf); + smartdns::Client client; + + ASSERT_TRUE(client.Query("a.com", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 1); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_LT(client.GetQueryTime(), 100); + EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3); + EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4"); + server.Stop(); + usleep(200 * 1000); + } + + ASSERT_EQ(access(cache_file.c_str(), F_OK), 0); + + std::fstream fs(cache_file, std::ios::in); + struct dns_cache_file head; + memset(&head, 0, sizeof(head)); + fs.read((char *)&head, sizeof(head)); + EXPECT_EQ(head.magic, MAGIC_NUMBER); + EXPECT_EQ(head.cache_number, 1); +} diff --git a/test/cases/test-cname.cc b/test/cases/test-cname.cc index b7aadda..5407407 100644 --- a/test/cases/test-cname.cc +++ b/test/cases/test-cname.cc @@ -22,7 +22,14 @@ #include "server.h" #include "gtest/gtest.h" -TEST(server, cname) +class Cname : public ::testing::Test +{ + protected: + virtual void SetUp() {} + virtual void TearDown() {} +}; + +TEST_F(Cname, cname) { smartdns::MockServer server_upstream; smartdns::Server server; @@ -32,11 +39,8 @@ TEST(server, cname) return smartdns::SERVER_REQUEST_SOA; } - unsigned char addr[4] = {1, 2, 3, 4}; - dns_add_A(request->response_packet, DNS_RRS_AN, request->domain.c_str(), 611, addr); + smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4", 611); EXPECT_EQ(request->domain, "e.com"); - - request->response_packet->head.rcode = DNS_RC_NOERROR; return smartdns::SERVER_REQUEST_OK; }); diff --git a/test/cases/test-dualstack.cc b/test/cases/test-dualstack.cc new file mode 100644 index 0000000..c7cc0c7 --- /dev/null +++ b/test/cases/test-dualstack.cc @@ -0,0 +1,187 @@ +/************************************************************************* + * + * Copyright (C) 2018-2023 Ruilin Peng (Nick) . + * + * smartdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * smartdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#include "client.h" +#include "dns.h" +#include "include/utils.h" +#include "server.h" +#include "util.h" +#include "gtest/gtest.h" +#include + +class DualStack : public ::testing::Test +{ + protected: + virtual void SetUp() {} + virtual void TearDown() {} +}; + +TEST_F(DualStack, ipv4_prefer) +{ + smartdns::MockServer server_upstream; + smartdns::Server server; + std::map qid_map; + + server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) { + if (request->qtype == DNS_T_A) { + smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4"); + smartdns::MockServer::AddIP(request, request->domain.c_str(), "5.6.7.8"); + return smartdns::SERVER_REQUEST_OK; + } else if (request->qtype == DNS_T_AAAA) { + smartdns::MockServer::AddIP(request, request->domain.c_str(), "2001:db8::1"); + smartdns::MockServer::AddIP(request, request->domain.c_str(), "2001:db8::2"); + return smartdns::SERVER_REQUEST_OK; + } + return smartdns::SERVER_REQUEST_SOA; + }); + + server.MockPing(PING_TYPE_ICMP, "1.2.3.4", 60, 100); + server.MockPing(PING_TYPE_ICMP, "5.6.7.8", 60, 110); + server.MockPing(PING_TYPE_ICMP, "2001:db8::1", 60, 150); + server.MockPing(PING_TYPE_ICMP, "2001:db8::2", 60, 150); + + server.Start(R"""(bind [::]:60053 +server 127.0.0.1:61053 +log-num 0 +log-console yes +speed-check-mode ping +log-level debug +cache-persist no)"""); + smartdns::Client client; + ASSERT_TRUE(client.Query("a.com AAAA", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAuthorityNum(), 1); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_EQ(client.GetAuthority()[0].GetName(), "a.com"); + EXPECT_EQ(client.GetAuthority()[0].GetTTL(), 3); + EXPECT_EQ(client.GetAuthority()[0].GetType(), "SOA"); + + usleep(220 * 1000); + ASSERT_TRUE(client.Query("a.com", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 2); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_LT(client.GetQueryTime(), 20); + EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com"); + EXPECT_GT(client.GetAnswer()[0].GetTTL(), 597); + EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4"); + EXPECT_EQ(client.GetAnswer()[1].GetData(), "5.6.7.8"); +} + +TEST_F(DualStack, ipv6_prefer_allow_force_AAAA) +{ + smartdns::MockServer server_upstream; + smartdns::Server server; + + server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) { + if (request->qtype == DNS_T_A) { + smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4"); + smartdns::MockServer::AddIP(request, request->domain.c_str(), "5.6.7.8"); + return smartdns::SERVER_REQUEST_OK; + } else if (request->qtype == DNS_T_AAAA) { + smartdns::MockServer::AddIP(request, request->domain.c_str(), "2001:db8::1"); + smartdns::MockServer::AddIP(request, request->domain.c_str(), "2001:db8::2"); + return smartdns::SERVER_REQUEST_OK; + } + return smartdns::SERVER_REQUEST_SOA; + }); + + server.MockPing(PING_TYPE_ICMP, "1.2.3.4", 60, 100); + server.MockPing(PING_TYPE_ICMP, "5.6.7.8", 60, 110); + server.MockPing(PING_TYPE_ICMP, "2001:db8::1", 60, 70); + server.MockPing(PING_TYPE_ICMP, "2001:db8::2", 60, 75); + + server.Start(R"""(bind [::]:60053 +server 127.0.0.1:61053 +log-num 0 +log-console yes +speed-check-mode ping +dualstack-ip-allow-force-AAAA yes +log-level debug +cache-persist no)"""); + smartdns::Client client; + ASSERT_TRUE(client.Query("a.com A", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAuthorityNum(), 1); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_EQ(client.GetAuthority()[0].GetName(), "a.com"); + EXPECT_EQ(client.GetAuthority()[0].GetTTL(), 3); + EXPECT_EQ(client.GetAuthority()[0].GetType(), "SOA"); + + usleep(220 * 1000); + ASSERT_TRUE(client.Query("a.com AAAA", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 2); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_LT(client.GetQueryTime(), 20); + EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com"); + EXPECT_GT(client.GetAnswer()[0].GetTTL(), 597); + EXPECT_EQ(client.GetAnswer()[0].GetData(), "2001:db8::1"); + EXPECT_EQ(client.GetAnswer()[1].GetData(), "2001:db8::2"); +} + +TEST_F(DualStack, ipv6_prefer_must_exist_ipv4) +{ + smartdns::MockServer server_upstream; + smartdns::Server server; + + server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) { + if (request->qtype == DNS_T_A) { + smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4"); + smartdns::MockServer::AddIP(request, request->domain.c_str(), "5.6.7.8"); + return smartdns::SERVER_REQUEST_OK; + } else if (request->qtype == DNS_T_AAAA) { + smartdns::MockServer::AddIP(request, request->domain.c_str(), "2001:db8::1"); + smartdns::MockServer::AddIP(request, request->domain.c_str(), "2001:db8::2"); + return smartdns::SERVER_REQUEST_OK; + } + return smartdns::SERVER_REQUEST_SOA; + }); + + server.MockPing(PING_TYPE_ICMP, "1.2.3.4", 60, 100); + server.MockPing(PING_TYPE_ICMP, "5.6.7.8", 60, 110); + server.MockPing(PING_TYPE_ICMP, "2001:db8::1", 60, 70); + server.MockPing(PING_TYPE_ICMP, "2001:db8::2", 60, 100); + + server.Start(R"""(bind [::]:60053 +server 127.0.0.1:61053 +log-num 0 +log-console yes +speed-check-mode ping +dualstack-ip-allow-force-AAAA yes +log-level debug +cache-persist no)"""); + smartdns::Client client; + ASSERT_TRUE(client.Query("a.com A", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAuthorityNum(), 1); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_EQ(client.GetAuthority()[0].GetName(), "a.com"); + EXPECT_EQ(client.GetAuthority()[0].GetTTL(), 3); + EXPECT_EQ(client.GetAuthority()[0].GetType(), "SOA"); + + usleep(220 * 1000); + ASSERT_TRUE(client.Query("a.com AAAA", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 1); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_LT(client.GetQueryTime(), 20); + EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com"); + EXPECT_GT(client.GetAnswer()[0].GetTTL(), 597); + EXPECT_EQ(client.GetAnswer()[0].GetData(), "2001:db8::1"); +} diff --git a/test/cases/test-nameserver.cc b/test/cases/test-nameserver.cc new file mode 100644 index 0000000..9f12138 --- /dev/null +++ b/test/cases/test-nameserver.cc @@ -0,0 +1,97 @@ +/************************************************************************* + * + * Copyright (C) 2018-2023 Ruilin Peng (Nick) . + * + * smartdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * smartdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#include "client.h" +#include "dns.h" +#include "include/utils.h" +#include "server.h" +#include "gtest/gtest.h" + +class NameServer : public ::testing::Test +{ + protected: + virtual void SetUp() {} + virtual void TearDown() {} +}; + +TEST_F(NameServer, cname) +{ + smartdns::MockServer server_upstream; + smartdns::MockServer server_upstream1; + smartdns::MockServer server_upstream2; + smartdns::Server server; + + server_upstream.Start("udp://0.0.0.0:61053", [](struct smartdns::ServerRequestContext *request) { + if (request->qtype != DNS_T_A) { + return smartdns::SERVER_REQUEST_SOA; + } + + smartdns::MockServer::AddIP(request, request->domain.c_str(), "9.10.11.12", 611); + return smartdns::SERVER_REQUEST_OK; + }); + + server_upstream1.Start("udp://0.0.0.0:62053", [](struct smartdns::ServerRequestContext *request) { + if (request->qtype != DNS_T_A) { + return smartdns::SERVER_REQUEST_SOA; + } + + smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4", 611); + return smartdns::SERVER_REQUEST_OK; + }); + + server_upstream2.Start("udp://0.0.0.0:63053", [](struct smartdns::ServerRequestContext *request) { + if (request->qtype != DNS_T_A) { + return smartdns::SERVER_REQUEST_SOA; + } + + smartdns::MockServer::AddIP(request, request->domain.c_str(), "5.6.7.8", 611); + return smartdns::SERVER_REQUEST_OK; + }); + + server.Start(R"""(bind [::]:60053 +server 127.0.0.1:61053 +server 127.0.0.1:62053 -group g1 -exclude-default-group +server 127.0.0.1:63053 -group g2 -exclude-default-group +nameserver /a.com/g1 +nameserver /b.com/g2 +log-num 0 +log-console yes +log-level debug +cache-persist no)"""); + smartdns::Client client; + ASSERT_TRUE(client.Query("a.com", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 1); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com"); + EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4"); + + ASSERT_TRUE(client.Query("b.com", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 1); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_EQ(client.GetAnswer()[0].GetName(), "b.com"); + EXPECT_EQ(client.GetAnswer()[0].GetData(), "5.6.7.8"); + + ASSERT_TRUE(client.Query("c.com", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 1); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_EQ(client.GetAnswer()[0].GetName(), "c.com"); + EXPECT_EQ(client.GetAnswer()[0].GetData(), "9.10.11.12"); +} diff --git a/test/cases/test-ping.cc b/test/cases/test-ping.cc index 278955f..e2a3d17 100644 --- a/test/cases/test-ping.cc +++ b/test/cases/test-ping.cc @@ -71,3 +71,27 @@ TEST_F(Ping, tcp) fast_ping_stop(ping_host); EXPECT_EQ(count, 1); } + +void fake_ping_result_callback(struct ping_host_struct *ping_host, const char *host, FAST_PING_RESULT result, + struct sockaddr *addr, socklen_t addr_len, int seqno, int ttl, struct timeval *tv, + int error, void *userptr) +{ + if (result == PING_RESULT_RESPONSE) { + int *count = (int *)userptr; + double rtt = tv->tv_sec * 1000.0 + tv->tv_usec / 1000.0; + tlog(TLOG_INFO, "from %15s: seq=%d ttl=%d time=%.3f\n", host, seqno, ttl, rtt); + *count = (int)rtt; + } +} + +TEST_F(Ping, fake_icmp) +{ + struct ping_host_struct *ping_host; + int count = 0; + fast_ping_fake_ip_add(PING_TYPE_ICMP, "1.2.3.4", 60, 5); + ping_host = fast_ping_start(PING_TYPE_ICMP, "1.2.3.4", 1, 1000, 200, fake_ping_result_callback, &count); + ASSERT_NE(ping_host, nullptr); + usleep(100000); + fast_ping_stop(ping_host); + EXPECT_GE(count, 5); +} diff --git a/test/cases/test-qtype-soa.cc b/test/cases/test-qtype-soa.cc new file mode 100644 index 0000000..bd14793 --- /dev/null +++ b/test/cases/test-qtype-soa.cc @@ -0,0 +1,78 @@ +/************************************************************************* + * + * Copyright (C) 2018-2023 Ruilin Peng (Nick) . + * + * smartdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * smartdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#include "client.h" +#include "dns.h" +#include "include/utils.h" +#include "server.h" +#include "gtest/gtest.h" + +class QtypeSOA : public ::testing::Test +{ + protected: + virtual void SetUp() {} + virtual void TearDown() {} +}; + +TEST_F(QtypeSOA, AAAA_HTTPS) +{ + smartdns::MockServer server_upstream; + smartdns::Server server; + + server_upstream.Start("udp://0.0.0.0:61053", [](struct smartdns::ServerRequestContext *request) { + if (request->qtype == DNS_T_A) { + smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4"); + smartdns::MockServer::AddIP(request, request->domain.c_str(), "5.6.7.8"); + return smartdns::SERVER_REQUEST_OK; + } else if (request->qtype == DNS_T_AAAA) { + smartdns::MockServer::AddIP(request, request->domain.c_str(), "2001:db8::1"); + smartdns::MockServer::AddIP(request, request->domain.c_str(), "2001:db8::2"); + return smartdns::SERVER_REQUEST_OK; + } + return smartdns::SERVER_REQUEST_SOA; + }); + + server.Start(R"""(bind [::]:60053 +server 127.0.0.1:61053 +log-num 0 +log-console yes +log-level debug +force-qtype-SOA 28 65 +cache-persist no)"""); + smartdns::Client client; + ASSERT_TRUE(client.Query("a.com AAAA", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAuthorityNum(), 1); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_EQ(client.GetAuthority()[0].GetName(), "a.com"); + EXPECT_EQ(client.GetAuthority()[0].GetTTL(), 30); + EXPECT_EQ(client.GetAuthority()[0].GetType(), "SOA"); + + ASSERT_TRUE(client.Query("a.com -t HTTPS", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAuthorityNum(), 1); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_EQ(client.GetAuthority()[0].GetName(), "a.com"); + EXPECT_EQ(client.GetAuthority()[0].GetTTL(), 30); + EXPECT_EQ(client.GetAuthority()[0].GetType(), "SOA"); + + ASSERT_TRUE(client.Query("a.com A", 60053)); + ASSERT_EQ(client.GetAnswerNum(), 1); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com"); +} diff --git a/test/cases/test-speed-check.cc b/test/cases/test-speed-check.cc index ee10e75..86b1794 100644 --- a/test/cases/test-speed-check.cc +++ b/test/cases/test-speed-check.cc @@ -189,3 +189,146 @@ cache-persist no)"""); EXPECT_EQ(client.GetAnswer()[0].GetName(), "b.com"); EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 600); } + +TEST_F(SpeedCheck, fastest_ip) +{ + smartdns::MockServer server_upstream; + smartdns::Server server; + + server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) { + if (request->qtype == DNS_T_A) { + smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4"); + smartdns::MockServer::AddIP(request, request->domain.c_str(), "5.6.7.8"); + return smartdns::SERVER_REQUEST_OK; + } + return smartdns::SERVER_REQUEST_SOA; + }); + + server.MockPing(PING_TYPE_ICMP, "1.2.3.4", 60, 100); + server.MockPing(PING_TYPE_ICMP, "5.6.7.8", 60, 110); + + server.Start(R"""(bind [::]:60053 +server 127.0.0.1:61053 +log-num 0 +log-console yes +speed-check-mode ping +dualstack-ip-selection no +log-level debug +cache-persist no)"""); + smartdns::Client client; + ASSERT_TRUE(client.Query("b.com", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 1); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_LT(client.GetQueryTime(), 200); + EXPECT_EQ(client.GetAnswer()[0].GetName(), "b.com"); + EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3); + EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4"); + + usleep(220 * 1000); + ASSERT_TRUE(client.Query("b.com", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 2); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_LT(client.GetQueryTime(), 20); + EXPECT_EQ(client.GetAnswer()[0].GetName(), "b.com"); + EXPECT_GT(client.GetAnswer()[0].GetTTL(), 597); + EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4"); + EXPECT_EQ(client.GetAnswer()[1].GetData(), "5.6.7.8"); +} + +TEST_F(SpeedCheck, unreach_best_ipv4) +{ + smartdns::MockServer server_upstream; + smartdns::MockServer server_upstream2; + smartdns::Server server; + + server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) { + if (request->qtype == DNS_T_A) { + smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4"); + smartdns::MockServer::AddIP(request, request->domain.c_str(), "5.6.7.8"); + return smartdns::SERVER_REQUEST_OK; + } + return smartdns::SERVER_REQUEST_SOA; + }); + + server_upstream2.Start("udp://0.0.0.0:62053", [&](struct smartdns::ServerRequestContext *request) { + if (request->qtype == DNS_T_A) { + smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4"); + smartdns::MockServer::AddIP(request, request->domain.c_str(), "9.10.11.12"); + return smartdns::SERVER_REQUEST_OK; + } + return smartdns::SERVER_REQUEST_SOA; + }); + + server.MockPing(PING_TYPE_ICMP, "1.2.3.4", 60, 10000); + server.MockPing(PING_TYPE_ICMP, "5.6.7.8", 60, 10000); + server.MockPing(PING_TYPE_ICMP, "9.10.11.12", 60, 10000); + + server.Start(R"""(bind [::]:60053 +server 127.0.0.1:61053 +server 127.0.0.1:62053 +log-num 0 +log-console yes +speed-check-mode ping +dualstack-ip-selection no +log-level debug +cache-persist no)"""); + smartdns::Client client; + ASSERT_TRUE(client.Query("a.com", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 1); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_LT(client.GetQueryTime(), 1200); + EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com"); + EXPECT_GT(client.GetAnswer()[0].GetTTL(), 597); + EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4"); +} + +TEST_F(SpeedCheck, unreach_best_ipv6) +{ + smartdns::MockServer server_upstream; + smartdns::MockServer server_upstream2; + smartdns::Server server; + + server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) { + if (request->qtype == DNS_T_AAAA) { + smartdns::MockServer::AddIP(request, request->domain.c_str(), "2001:db8::1"); + smartdns::MockServer::AddIP(request, request->domain.c_str(), "2001:db8::2"); + return smartdns::SERVER_REQUEST_OK; + } + return smartdns::SERVER_REQUEST_SOA; + }); + + server_upstream2.Start("udp://0.0.0.0:62053", [&](struct smartdns::ServerRequestContext *request) { + if (request->qtype == DNS_T_AAAA) { + smartdns::MockServer::AddIP(request, request->domain.c_str(), "2001:db8::2"); + smartdns::MockServer::AddIP(request, request->domain.c_str(), "2001:db8::3"); + return smartdns::SERVER_REQUEST_OK; + } + return smartdns::SERVER_REQUEST_SOA; + }); + + server.MockPing(PING_TYPE_ICMP, "2001:db8::1", 60, 10000); + server.MockPing(PING_TYPE_ICMP, "2001:db8::2", 60, 10000); + server.MockPing(PING_TYPE_ICMP, "2001:db8::3", 60, 10000); + + server.Start(R"""(bind [::]:60053 +server 127.0.0.1:61053 +server 127.0.0.1:62053 +log-num 0 +log-console yes +speed-check-mode ping +dualstack-ip-selection no +log-level debug +cache-persist no)"""); + smartdns::Client client; + ASSERT_TRUE(client.Query("a.com AAAA", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 1); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_LT(client.GetQueryTime(), 1200); + EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com"); + EXPECT_GT(client.GetAnswer()[0].GetTTL(), 597); + EXPECT_EQ(client.GetAnswer()[0].GetData(), "2001:db8::2"); +} diff --git a/test/client.cc b/test/client.cc index 9a9e952..8aff929 100644 --- a/test/client.cc +++ b/test/client.cc @@ -173,6 +173,11 @@ int Client::GetAnswerNum() return answer_num_; } +int Client::GetAuthorityNum() +{ + return authority_num_; +} + std::string Client::GetStatus() { return status_; @@ -257,7 +262,7 @@ bool Client::ParserResult() std::regex reg_authority_num(", AUTHORITY: ([0-9]+),"); if (std::regex_search(result_, match, reg_authority_num)) { - records_authority_num_ = std::stoi(match[1]); + authority_num_ = std::stoi(match[1]); } std::regex reg_status(", status: ([A-Z]+),"); @@ -313,7 +318,7 @@ bool Client::ParserResult() return false; } - if (records_authority_num_ != records_authority_.size()) { + if (authority_num_ != records_authority_.size()) { std::cout << "DIG FAILED: Num Not Match\n" << result_ << std::endl; return false; } diff --git a/test/client.h b/test/client.h index ae2ca60..2c59de9 100644 --- a/test/client.h +++ b/test/client.h @@ -71,6 +71,8 @@ class Client int GetAnswerNum(); + int GetAuthorityNum(); + std::string GetStatus(); std::string GetServer(); @@ -90,7 +92,7 @@ class Client bool ParserRecord(const std::string &record_str, std::vector &record); std::string result_; int answer_num_{0}; - int records_authority_num_{0}; + int authority_num_{0}; std::string status_; std::string server_; int query_time_{0}; diff --git a/test/server.cc b/test/server.cc index 67b7995..af1829c 100644 --- a/test/server.cc +++ b/test/server.cc @@ -18,7 +18,9 @@ #include "server.h" #include "dns_server.h" +#include "fast_ping.h" #include "include/utils.h" +#include "smartdns.h" #include "util.h" #include #include @@ -36,8 +38,6 @@ namespace smartdns { -extern "C" int smartdns_main(int argc, char *argv[], int fd_notify); - MockServer::MockServer() {} MockServer::~MockServer() @@ -293,6 +293,34 @@ Server::Server(enum Server::CREATE_MODE mode) mode_ = mode; } +void Server::MockPing(PING_TYPE type, const std::string &host, int ttl, float time) +{ + struct MockPingIP ping_ip; + ping_ip.type = type; + ping_ip.host = host; + ping_ip.ttl = ttl; + ping_ip.time = time; + mock_ping_ips_.push_back(ping_ip); +} + +void Server::StartPost(void *arg) +{ + Server *server = (Server *)arg; + bool has_ipv6 = false; + for (auto &it : server->mock_ping_ips_) { + if (has_ipv6 == false && check_is_ipv6(it.host.c_str()) == 0) { + has_ipv6 = true; + } + + fast_ping_fake_ip_add(it.type, it.host.c_str(), it.ttl, it.time); + } + + if (has_ipv6 == true) { + fast_ping_fake_ip_add(PING_TYPE_ICMP, "2001::", 64, 10); + dns_server_check_ipv6_ready(); + } +} + bool Server::Start(const std::string &conf, enum CONF_TYPE type) { pid_t pid = 0; @@ -343,6 +371,7 @@ bool Server::Start(const std::string &conf, enum CONF_TYPE type) argv[i] = (char *)args[i].c_str(); } + smartdns_reg_post_func(Server::StartPost, this); smartdns_main(args.size(), argv, fds[1]); _exit(1); } else if (pid < 0) { @@ -358,7 +387,9 @@ bool Server::Start(const std::string &conf, enum CONF_TYPE type) argv[i] = (char *)args[i].c_str(); } + smartdns_reg_post_func(Server::StartPost, this); smartdns_main(args.size(), argv, fds[1]); + smartdns_reg_post_func(nullptr, nullptr); }); } else { return false; diff --git a/test/server.h b/test/server.h index e29a58d..3809721 100644 --- a/test/server.h +++ b/test/server.h @@ -20,12 +20,14 @@ #define _SMARTDNS_SERVER_ #include "dns.h" +#include "fast_ping.h" #include "include/utils.h" #include #include #include #include #include +#include namespace smartdns { @@ -33,6 +35,12 @@ namespace smartdns class Server { public: + struct MockPingIP { + PING_TYPE type; + std::string host; + int ttl; + float time; + }; enum CONF_TYPE { CONF_TYPE_STRING, CONF_TYPE_FILE, @@ -45,17 +53,19 @@ class Server Server(enum CREATE_MODE mode); virtual ~Server(); + void MockPing(PING_TYPE type, const std::string &host, int ttl, float time); bool Start(const std::string &conf, enum CONF_TYPE type = CONF_TYPE_STRING); void Stop(bool graceful = true); bool IsRunning(); private: + static void StartPost(void *arg); pid_t pid_; std::thread thread_; int fd_; std::string conf_file_; TempFile conf_temp_file_; - + std::vector mock_ping_ips_; enum CREATE_MODE mode_; }; @@ -100,10 +110,10 @@ class MockServer static bool GetAddr(const std::string &host, const std::string port, int type, int protocol, struct sockaddr_storage *addr, socklen_t *addrlen); - int fd_; + int fd_{0}; std::thread thread_; - bool run_; - ServerRequest callback_; + bool run_{false}; + ServerRequest callback_{nullptr}; }; } // namespace smartdns