From 582cdfb879067b3f3c8827edfb6efe54e2b7aaf9 Mon Sep 17 00:00:00 2001 From: Nick Peng Date: Tue, 28 Mar 2023 23:28:28 +0800 Subject: [PATCH] cache: modify cache ver check method, add ipset, nftset after restart. --- src/dns_cache.c | 54 ++++++++++++++++++++---- src/dns_cache.h | 18 ++++++-- src/dns_client.c | 21 ++++++++++ src/dns_client.h | 2 + src/dns_server.c | 33 ++++++++++----- test/cases/test-address.cc | 2 - test/cases/test-cache.cc | 77 ++++++++++++++++++++++++++++++++++ test/cases/test-dns64.cc | 1 - test/cases/test-domain-set.cc | 1 - test/cases/test-dualstack.cc | 1 - test/cases/test-qtype-soa.cc | 2 - test/cases/test-speed-check.cc | 73 ++++++++++++++++++++++++++++++-- 12 files changed, 250 insertions(+), 35 deletions(-) diff --git a/src/dns_cache.c b/src/dns_cache.c index f8a6eb8..caa3138 100644 --- a/src/dns_cache.c +++ b/src/dns_cache.c @@ -147,7 +147,7 @@ void dns_cache_data_free(struct dns_cache_data *data) free(data); } -struct dns_cache_data *dns_cache_new_data(void) +struct dns_cache_data *dns_cache_new_data_addr(void) { struct dns_cache_addr *cache_addr = malloc(sizeof(struct dns_cache_addr)); memset(cache_addr, 0, sizeof(struct dns_cache_addr)); @@ -157,6 +157,7 @@ struct dns_cache_data *dns_cache_new_data(void) cache_addr->head.cache_type = CACHE_TYPE_NONE; cache_addr->head.size = sizeof(struct dns_cache_addr) - sizeof(struct dns_cache_data_head); + cache_addr->head.magic = MAGIC_CACHE_DATA; return (struct dns_cache_data *)cache_addr; } @@ -241,6 +242,7 @@ struct dns_cache_data *dns_cache_new_data_packet(void *packet, size_t packet_len cache_packet->head.cache_type = CACHE_TYPE_PACKET; cache_packet->head.size = packet_len; + cache_packet->head.magic = MAGIC_CACHE_DATA; return (struct dns_cache_data *)cache_packet; } @@ -274,6 +276,7 @@ static int _dns_cache_replace(struct dns_cache_key *cache_key, int ttl, int spee dns_cache->info.ttl = ttl; dns_cache->info.speed = speed; dns_cache->info.no_inactive = no_inactive; + dns_cache->info.is_visited = 1; old_cache_data = dns_cache->cache_data; dns_cache->cache_data = cache_data; list_del_init(&dns_cache->list); @@ -294,12 +297,14 @@ static int _dns_cache_replace(struct dns_cache_key *cache_key, int ttl, int spee return 0; } -int dns_cache_replace(struct dns_cache_key *cache_key, int ttl, int speed, int no_inactive, struct dns_cache_data *cache_data) +int dns_cache_replace(struct dns_cache_key *cache_key, int ttl, int speed, int no_inactive, + struct dns_cache_data *cache_data) { return _dns_cache_replace(cache_key, ttl, speed, no_inactive, 0, cache_data); } -int dns_cache_replace_inactive(struct dns_cache_key *cache_key, int ttl, int speed, int no_inactive, struct dns_cache_data *cache_data) +int dns_cache_replace_inactive(struct dns_cache_key *cache_key, int ttl, int speed, int no_inactive, + struct dns_cache_data *cache_data) { return _dns_cache_replace(cache_key, ttl, speed, no_inactive, 1, cache_data); } @@ -391,7 +396,8 @@ errout: return -1; } -int dns_cache_insert(struct dns_cache_key *cache_key, int ttl, int speed, int no_inactive, struct dns_cache_data *cache_data) +int dns_cache_insert(struct dns_cache_key *cache_key, int ttl, int speed, int no_inactive, + struct dns_cache_data *cache_data) { struct dns_cache_info info; @@ -418,6 +424,7 @@ int dns_cache_insert(struct dns_cache_key *cache_key, int ttl, int speed, int no info.hitnum_update_add = DNS_CACHE_HITNUM_STEP; info.speed = speed; info.no_inactive = no_inactive; + info.is_visited = 1; time(&info.insert_time); time(&info.replace_time); @@ -541,6 +548,11 @@ struct dns_cache_data *dns_cache_get_data(struct dns_cache *dns_cache) return dns_cache->cache_data; } +int dns_cache_is_visited(struct dns_cache *dns_cache) +{ + return dns_cache->info.is_visited; +} + void dns_cache_delete(struct dns_cache *dns_cache) { pthread_mutex_lock(&dns_cache_head.lock); @@ -574,6 +586,7 @@ void dns_cache_update(struct dns_cache *dns_cache) if (dns_cache->info.hitnum_update_add < DNS_CACHE_HITNUM_STEP_MAX) { dns_cache->info.hitnum_update_add++; } + dns_cache->info.is_visited = 1; } pthread_mutex_unlock(&dns_cache_head.lock); } @@ -707,15 +720,18 @@ static int _dns_cache_read_record(int fd, uint32_t cache_number) goto errout; } - if (cache_record.magic != MAGIC_CACHE_DATA) { + if (cache_record.magic != MAGIC_RECORD) { tlog(TLOG_ERROR, "magic is invalid."); goto errout; } if (cache_record.type == CACHE_RECORD_TYPE_ACTIVE) { head = &dns_cache_head.cache_list; - } else { + } else if (cache_record.type == CACHE_RECORD_TYPE_INACTIVE) { head = &dns_cache_head.inactive_list; + } else { + tlog(TLOG_ERROR, "read cache record type is invalid."); + goto errout; } ret = read(fd, &data_head, sizeof(data_head)); @@ -724,6 +740,11 @@ static int _dns_cache_read_record(int fd, uint32_t cache_number) goto errout; } + if (data_head.magic != MAGIC_CACHE_DATA) { + tlog(TLOG_ERROR, "data magic is invalid."); + goto errout; + } + if (data_head.size > 1024 * 8) { tlog(TLOG_ERROR, "data may invalid, skip load cache."); goto errout; @@ -742,6 +763,15 @@ static int _dns_cache_read_record(int fd, uint32_t cache_number) goto errout; } + /* set cache unvisited, so that when refreshing ipset/nftset, reload ipset list by restarting smartdns */ + cache_record.info.is_visited = 0; + cache_record.info.domain[DNS_MAX_CNAME_LEN - 1] = '\0'; + cache_record.info.dns_group_name[DNS_GROUP_NAME_LEN - 1] = '\0'; + if (cache_record.type >= CACHE_RECORD_TYPE_END) { + tlog(TLOG_ERROR, "read cache record type is invalid."); + goto errout; + } + if (_dns_cache_insert(&cache_record.info, cache_data, head) != 0) { tlog(TLOG_ERROR, "insert cache data failed."); cache_data = NULL; @@ -786,7 +816,7 @@ int dns_cache_load(const char *file) goto errout; } - if (strncmp(cache_file.version, __TIMESTAMP__, DNS_CACHE_VERSION_LEN - 1) != 0) { + if (strncmp(cache_file.version, dns_cache_file_version(), DNS_CACHE_VERSION_LEN) != 0) { tlog(TLOG_WARN, "cache version is different, skip load cache."); goto errout; } @@ -815,7 +845,7 @@ static int _dns_cache_write_record(int fd, uint32_t *cache_number, enum CACHE_RE pthread_mutex_lock(&dns_cache_head.lock); list_for_each_entry_safe_reverse(dns_cache, tmp, head, list) { - cache_record.magic = MAGIC_CACHE_DATA; + cache_record.magic = MAGIC_RECORD; cache_record.type = type; memcpy(&cache_record.info, &dns_cache->info, sizeof(struct dns_cache_info)); ssize_t ret = write(fd, &cache_record, sizeof(cache_record)); @@ -871,7 +901,7 @@ int dns_cache_save(const char *file) struct dns_cache_file cache_file; memset(&cache_file, 0, sizeof(cache_file)); cache_file.magic = MAGIC_NUMBER; - safe_strncpy(cache_file.version, __TIMESTAMP__, DNS_CACHE_VERSION_LEN); + safe_strncpy(cache_file.version, dns_cache_file_version(), DNS_CACHE_VERSION_LEN); cache_file.cache_number = 0; if (lseek(fd, sizeof(cache_file), SEEK_SET) < 0) { @@ -926,3 +956,9 @@ void dns_cache_destroy(void) pthread_mutex_destroy(&dns_cache_head.lock); } + +const char *dns_cache_file_version(void) +{ + const char *version = "cache ver 1.0"; + return version; +} diff --git a/src/dns_cache.h b/src/dns_cache.h index 7b1d29a..62c7194 100644 --- a/src/dns_cache.h +++ b/src/dns_cache.h @@ -36,7 +36,8 @@ extern "C" { #define DNS_CACHE_VERSION_LEN 32 #define DNS_CACHE_GROUP_NAME_LEN 32 #define MAGIC_NUMBER 0x6548634163536e44 -#define MAGIC_CACHE_DATA 0x44615461 +#define MAGIC_CACHE_DATA 0x61546144 +#define MAGIC_RECORD 0x64526352 enum CACHE_TYPE { CACHE_TYPE_NONE, @@ -47,12 +48,14 @@ enum CACHE_TYPE { enum CACHE_RECORD_TYPE { CACHE_RECORD_TYPE_ACTIVE, CACHE_RECORD_TYPE_INACTIVE, + CACHE_RECORD_TYPE_END, }; struct dns_cache_data_head { enum CACHE_TYPE cache_type; int is_soa; ssize_t size; + uint32_t magic; }; struct dns_cache_data { @@ -89,6 +92,7 @@ struct dns_cache_info { int speed; int no_inactive; int hitnum_update_add; + int is_visited; time_t insert_time; time_t replace_time; }; @@ -136,9 +140,11 @@ struct dns_cache_data *dns_cache_new_data_packet(void *packet, size_t packet_len int dns_cache_init(int size, int enable_inactive, int inactive_list_expired); -int dns_cache_replace(struct dns_cache_key *key, int ttl, int speed, int no_inactive, struct dns_cache_data *cache_data); +int dns_cache_replace(struct dns_cache_key *key, int ttl, int speed, int no_inactive, + struct dns_cache_data *cache_data); -int dns_cache_replace_inactive(struct dns_cache_key *key, int ttl, int speed, int no_inactive, struct dns_cache_data *cache_data); +int dns_cache_replace_inactive(struct dns_cache_key *key, int ttl, int speed, int no_inactive, + struct dns_cache_data *cache_data); int dns_cache_insert(struct dns_cache_key *key, int ttl, int speed, int no_inactive, struct dns_cache_data *cache_data); @@ -152,6 +158,8 @@ void dns_cache_release(struct dns_cache *dns_cache); int dns_cache_hitnum_dec_get(struct dns_cache *dns_cache); +int dns_cache_is_visited(struct dns_cache *dns_cache); + void dns_cache_update(struct dns_cache *dns_cache); typedef void dns_cache_callback(struct dns_cache *dns_cache); @@ -165,7 +173,7 @@ int dns_cache_get_cname_ttl(struct dns_cache *dns_cache); int dns_cache_is_soa(struct dns_cache *dns_cache); -struct dns_cache_data *dns_cache_new_data(void); +struct dns_cache_data *dns_cache_new_data_addr(void); struct dns_cache_data *dns_cache_get_data(struct dns_cache *dns_cache); @@ -180,6 +188,8 @@ int dns_cache_load(const char *file); int dns_cache_save(const char *file); +const char *dns_cache_file_version(void); + #ifdef __cplusplus } #endif diff --git a/src/dns_client.c b/src/dns_client.c index 3447447..d0f7eb9 100644 --- a/src/dns_client.c +++ b/src/dns_client.c @@ -123,6 +123,7 @@ struct dns_server_info { time_t last_recv; unsigned long send_tick; int prohibit; + int is_already_prohibit; /* server addr info */ unsigned short ai_family; @@ -200,6 +201,7 @@ struct dns_client { struct list_head dns_request_list; atomic_t run_period; atomic_t dns_server_num; + atomic_t dns_server_prohibit_num; /* ECS */ struct dns_client_ecs ecs_ipv4; @@ -1413,6 +1415,11 @@ int dns_server_num(void) return atomic_read(&client.dns_server_num); } +int dns_server_alive_num(void) +{ + return atomic_read(&client.dns_server_num) - atomic_read(&client.dns_server_prohibit_num); +} + static void _dns_client_query_get(struct dns_query_struct *query) { if (atomic_inc_return(&query->refcnt) <= 0) { @@ -3338,6 +3345,7 @@ static int _dns_client_send_packet(struct dns_query_struct *query, void *packet, query->send_tick = get_tick_count(); /* send query to all dns servers */ + atomic_inc(&query->dns_request_sent); for (i = 0; i < 2; i++) { total_server = 0; pthread_mutex_lock(&client.server_list_lock); @@ -3345,12 +3353,19 @@ static int _dns_client_send_packet(struct dns_query_struct *query, void *packet, { server_info = group_member->server; if (server_info->prohibit) { + if (server_info->is_already_prohibit == 0) { + server_info->is_already_prohibit = 1; + atomic_inc(&client.dns_server_prohibit_num); + } + time_t now = 0; time(&now); if ((now - 60 < server_info->last_send) && (now - 5 > server_info->last_recv)) { continue; } server_info->prohibit = 0; + server_info->is_already_prohibit = 0; + atomic_dec(&client.dns_server_prohibit_num); if (now - 60 > server_info->last_send) { _dns_client_close_socket(server_info); } @@ -3428,6 +3443,11 @@ static int _dns_client_send_packet(struct dns_query_struct *query, void *packet, } } + int num = atomic_dec_return(&query->dns_request_sent); + if (num == 0) { + _dns_client_query_remove(query); + } + if (send_count <= 0) { tlog(TLOG_WARN, "Send query to upstream server failed, total server number %d", total_server); return -1; @@ -4194,6 +4214,7 @@ int dns_client_init(void) memset(&client, 0, sizeof(client)); pthread_attr_init(&attr); atomic_set(&client.dns_server_num, 0); + atomic_set(&client.dns_server_prohibit_num, 0); atomic_set(&client.run_period, 0); epollfd = epoll_create1(EPOLL_CLOEXEC); diff --git a/src/dns_client.h b/src/dns_client.h index 6530e7f..3586834 100644 --- a/src/dns_client.h +++ b/src/dns_client.h @@ -141,6 +141,8 @@ int dns_client_remove_from_group(const char *group_name, char *server_ip, int po int dns_client_remove_group(const char *group_name); +int dns_server_alive_num(void); + int dns_server_num(void); #ifdef __cplusplus diff --git a/src/dns_server.c b/src/dns_server.c index 9400865..6246e64 100644 --- a/src/dns_server.c +++ b/src/dns_server.c @@ -1576,6 +1576,7 @@ static int _dns_server_setup_ipset_nftset_packet(struct dns_server_post_context struct dns_request *request = context->request; char name[DNS_MAX_CNAME_LEN] = {0}; int rr_count = 0; + int timeout_value = 0; int i = 0; int j = 0; struct dns_rrs *rrs = NULL; @@ -1642,6 +1643,11 @@ static int _dns_server_setup_ipset_nftset_packet(struct dns_server_post_context return 0; } + timeout_value = request->ip_ttl * 3; + if (timeout_value == 0) { + timeout_value = _dns_server_get_conf_ttl(request, 0) * 3; + } + for (j = 1; j < DNS_RRS_END; j++) { rrs = dns_get_rrs_start(context->packet, j, &rr_count); for (i = 0; i < rr_count && rrs; i++, rrs = dns_get_rrs_next(context->packet, rrs)) { @@ -1659,7 +1665,7 @@ static int _dns_server_setup_ipset_nftset_packet(struct dns_server_post_context /* add IPV4 to ipset */ tlog(TLOG_DEBUG, "IPSET-MATCH: domain: %s, ipset: %s, IP: %d.%d.%d.%d", request->domain, rule->ipsetname, addr[0], addr[1], addr[2], addr[3]); - ipset_add(rule->ipsetname, addr, DNS_RR_A_LEN, request->ip_ttl * 2); + ipset_add(rule->ipsetname, addr, DNS_RR_A_LEN, timeout_value); } if (nftset_ip != NULL) { @@ -1668,7 +1674,7 @@ static int _dns_server_setup_ipset_nftset_packet(struct dns_server_post_context nftset_ip->familyname, nftset_ip->nfttablename, nftset_ip->nftsetname, addr[0], addr[1], addr[2], addr[3]); nftset_add(nftset_ip->familyname, nftset_ip->nfttablename, nftset_ip->nftsetname, addr, - DNS_RR_A_LEN, request->ip_ttl * 2); + DNS_RR_A_LEN, timeout_value); } } break; case DNS_T_AAAA: { @@ -1687,7 +1693,7 @@ static int _dns_server_setup_ipset_nftset_packet(struct dns_server_post_context request->domain, rule->ipsetname, addr[0], addr[1], addr[2], addr[3], addr[4], addr[5], addr[6], addr[7], addr[8], addr[9], addr[10], addr[11], addr[12], addr[13], addr[14], addr[15]); - ipset_add(rule->ipsetname, addr, DNS_RR_AAAA_LEN, request->ip_ttl * 2); + ipset_add(rule->ipsetname, addr, DNS_RR_AAAA_LEN, timeout_value); } if (nftset_ip6 != NULL) { @@ -1699,7 +1705,7 @@ static int _dns_server_setup_ipset_nftset_packet(struct dns_server_post_context addr[0], addr[1], addr[2], addr[3], addr[4], addr[5], addr[6], addr[7], addr[8], addr[9], addr[10], addr[11], addr[12], addr[13], addr[14], addr[15]); nftset_add(nftset_ip6->familyname, nftset_ip6->nfttablename, nftset_ip6->nftsetname, addr, - DNS_RR_AAAA_LEN, request->ip_ttl * 2); + DNS_RR_AAAA_LEN, timeout_value); } } break; default: @@ -2803,7 +2809,7 @@ static int _dns_server_process_answer_A(struct dns_rrs *rrs, struct dns_request /* Ad blocking result */ if (addr[0] == 0 || addr[0] == 127) { /* If half of the servers return the same result, then ignore this address */ - if (atomic_inc_return(&request->adblock) <= (dns_server_num() / 2 + dns_server_num() % 2)) { + if (atomic_inc_return(&request->adblock) <= (dns_server_alive_num() / 2 + dns_server_alive_num() % 2)) { request->rcode = DNS_RC_NOERROR; _dns_server_request_release(request); return -1; @@ -2880,7 +2886,7 @@ static int _dns_server_process_answer_AAAA(struct dns_rrs *rrs, struct dns_reque /* Ad blocking result */ if (_dns_server_is_adblock_ipv6(addr) == 0) { /* If half of the servers return the same result, then ignore this address */ - if (atomic_inc_return(&request->adblock) <= (dns_server_num() / 2 + dns_server_num() % 2)) { + if (atomic_inc_return(&request->adblock) <= (dns_server_alive_num() / 2 + dns_server_alive_num() % 2)) { request->rcode = DNS_RC_NOERROR; _dns_server_request_release(request); return -1; @@ -2989,7 +2995,8 @@ static int _dns_server_process_answer(struct dns_request *request, const char *d request->soa.refresh, request->soa.retry, request->soa.expire, request->soa.minimum); int soa_num = atomic_inc_return(&request->soa_num); - if ((soa_num >= (dns_server_num() / 3) + 1 || soa_num > 4) && atomic_read(&request->ip_map_num) <= 0) { + if ((soa_num >= (dns_server_alive_num() / 3) + 1 || soa_num > 4) && + atomic_read(&request->ip_map_num) <= 0) { request->ip_ttl = ttl; _dns_server_request_complete(request); } @@ -3072,7 +3079,7 @@ static int _dns_server_passthrough_rule_check(struct dns_request *request, const /* Ad blocking result */ if (addr[0] == 0 || addr[0] == 127) { /* If half of the servers return the same result, then ignore this address */ - if (atomic_read(&request->adblock) <= (dns_server_num() / 2 + dns_server_num() % 2)) { + if (atomic_read(&request->adblock) <= (dns_server_alive_num() / 2 + dns_server_alive_num() % 2)) { _dns_server_request_release(request); return 0; } @@ -3116,7 +3123,7 @@ static int _dns_server_passthrough_rule_check(struct dns_request *request, const /* Ad blocking result */ if (_dns_server_is_adblock_ipv6(addr) == 0) { /* If half of the servers return the same result, then ignore this address */ - if (atomic_read(&request->adblock) <= (dns_server_num() / 2 + dns_server_num() % 2)) { + if (atomic_read(&request->adblock) <= (dns_server_alive_num() / 2 + dns_server_alive_num() % 2)) { _dns_server_request_release(request); return 0; } @@ -3384,7 +3391,7 @@ static void _dns_server_passthrough_may_complete(struct dns_request *request) addr = request->ip_addr; if (addr[0] == 0 || addr[0] == 127) { /* If half of the servers return the same result, then ignore this address */ - if (atomic_read(&request->adblock) <= (dns_server_num() / 2 + dns_server_num() % 2)) { + if (atomic_read(&request->adblock) <= (dns_server_alive_num() / 2 + dns_server_alive_num() % 2)) { return; } } @@ -3394,7 +3401,7 @@ static void _dns_server_passthrough_may_complete(struct dns_request *request) addr = request->ip_addr; if (_dns_server_is_adblock_ipv6(addr) == 0) { /* If half of the servers return the same result, then ignore this address */ - if (atomic_read(&request->adblock) <= (dns_server_num() / 2 + dns_server_num() % 2)) { + if (atomic_read(&request->adblock) <= (dns_server_alive_num() / 2 + dns_server_alive_num() % 2)) { return; } } @@ -4551,6 +4558,10 @@ static int _dns_server_process_cache_packet(struct dns_request *request, struct return -1; } + if (dns_cache_is_visited(dns_cache) == 0) { + do_ipset = 1; + } + if (dns_cache->info.qtype != request->qtype) { return -1; } diff --git a/test/cases/test-address.cc b/test/cases/test-address.cc index cb55ad2..6f63797 100644 --- a/test/cases/test-address.cc +++ b/test/cases/test-address.cc @@ -35,7 +35,6 @@ TEST_F(Address, soa) { smartdns::MockServer server_upstream; smartdns::Server server; - std::map qid_map; server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) { if (request->qtype == DNS_T_A) { @@ -122,7 +121,6 @@ TEST_F(Address, ip) { smartdns::MockServer server_upstream; smartdns::Server server; - std::map qid_map; server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) { if (request->qtype == DNS_T_A) { diff --git a/test/cases/test-cache.cc b/test/cases/test-cache.cc index a564c0e..5188061 100644 --- a/test/cases/test-cache.cc +++ b/test/cases/test-cache.cc @@ -21,7 +21,10 @@ #include "include/utils.h" #include "server.h" #include "gtest/gtest.h" +#include #include +#include +#include /* clang-format off */ #include "dns_cache.h" @@ -286,3 +289,77 @@ dualstack-ip-selection no EXPECT_EQ(head.magic, MAGIC_NUMBER); EXPECT_EQ(head.cache_number, 1); } + +TEST_F(Cache, corrupt_file) +{ + smartdns::MockServer server_upstream; + auto cache_file = "/tmp/smartdns_cache." + smartdns::GenerateRandomString(10); + std::string conf = R"""( +bind [::]:60053@lo +server 127.0.0.1:62053 +log-num 0 +log-console yes +log-level debug +dualstack-ip-selection no +cache-persist yes +)"""; + + conf += "cache-file " + cache_file; + Defer + { + unlink(cache_file.c_str()); + }; + + server_upstream.Start("udp://0.0.0.0:62053", [](struct smartdns::ServerRequestContext *request) { + if (request->qtype == DNS_T_A) { + smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4"); + return smartdns::SERVER_REQUEST_OK; + } + return smartdns::SERVER_REQUEST_SOA; + }); + { + smartdns::Server server; + server.Start(conf); + smartdns::Client client; + + ASSERT_TRUE(client.Query("a.com", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 1); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_LT(client.GetQueryTime(), 100); + EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3); + EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4"); + server.Stop(); + usleep(200 * 1000); + } + + ASSERT_EQ(access(cache_file.c_str(), F_OK), 0); + + int fd = open(cache_file.c_str(), O_RDWR); + ASSERT_NE(fd, -1); + srandom(time(NULL)); + off_t file_size = lseek(fd, 0, SEEK_END); + off_t offset = random() % (file_size - 300); + std::cout << "try make corrupt at " << offset << ", file size: " << file_size << std::endl; + lseek(fd, offset, SEEK_SET); + for (int i = 0; i < 300; i++) { + unsigned char c = random() % 256; + write(fd, &c, 1); + } + close(fd); + { + smartdns::Server server; + server.Start(conf); + smartdns::Client client; + + ASSERT_TRUE(client.Query("a.com", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 1); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_LT(client.GetQueryTime(), 100); + EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3); + EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4"); + server.Stop(); + usleep(200 * 1000); + } +} \ No newline at end of file diff --git a/test/cases/test-dns64.cc b/test/cases/test-dns64.cc index 8649281..ed1c8f1 100644 --- a/test/cases/test-dns64.cc +++ b/test/cases/test-dns64.cc @@ -35,7 +35,6 @@ TEST_F(DNS64, no_dualstack) { smartdns::MockServer server_upstream; smartdns::Server server; - std::map qid_map; server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) { if (request->qtype == DNS_T_A) { diff --git a/test/cases/test-domain-set.cc b/test/cases/test-domain-set.cc index 7799536..bc3cb0e 100644 --- a/test/cases/test-domain-set.cc +++ b/test/cases/test-domain-set.cc @@ -35,7 +35,6 @@ TEST_F(DomainSet, set_add) { smartdns::MockServer server_upstream; smartdns::Server server; - std::map qid_map; smartdns::TempFile file_set; std::vector domain_list; int count = 16; diff --git a/test/cases/test-dualstack.cc b/test/cases/test-dualstack.cc index c7cc0c7..6fb6f68 100644 --- a/test/cases/test-dualstack.cc +++ b/test/cases/test-dualstack.cc @@ -35,7 +35,6 @@ TEST_F(DualStack, ipv4_prefer) { smartdns::MockServer server_upstream; smartdns::Server server; - std::map qid_map; server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) { if (request->qtype == DNS_T_A) { diff --git a/test/cases/test-qtype-soa.cc b/test/cases/test-qtype-soa.cc index fdc4c6f..82298e2 100644 --- a/test/cases/test-qtype-soa.cc +++ b/test/cases/test-qtype-soa.cc @@ -155,7 +155,6 @@ TEST_F(QtypeSOA, force_AAAA_SOA) { smartdns::MockServer server_upstream; smartdns::Server server; - std::map qid_map; server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) { if (request->qtype == DNS_T_A) { @@ -199,7 +198,6 @@ TEST_F(QtypeSOA, bind_force_AAAA_SOA) { smartdns::MockServer server_upstream; smartdns::Server server; - std::map qid_map; server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) { if (request->qtype == DNS_T_A) { diff --git a/test/cases/test-speed-check.cc b/test/cases/test-speed-check.cc index 86b1794..d8e31b7 100644 --- a/test/cases/test-speed-check.cc +++ b/test/cases/test-speed-check.cc @@ -35,7 +35,6 @@ TEST_F(SpeedCheck, response_mode) { smartdns::MockServer server_upstream; smartdns::Server server; - std::map qid_map; server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) { if (request->qtype == DNS_T_A) { @@ -78,7 +77,6 @@ TEST_F(SpeedCheck, none) { smartdns::MockServer server_upstream; smartdns::Server server; - std::map qid_map; server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) { if (request->qtype == DNS_T_A) { @@ -120,7 +118,6 @@ TEST_F(SpeedCheck, domain_rules_none) { smartdns::MockServer server_upstream; smartdns::Server server; - std::map qid_map; server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) { if (request->qtype == DNS_T_A) { @@ -162,7 +159,6 @@ TEST_F(SpeedCheck, only_ping) { smartdns::MockServer server_upstream; smartdns::Server server; - std::map qid_map; server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) { if (request->qtype == DNS_T_A) { @@ -190,6 +186,75 @@ cache-persist no)"""); EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 600); } +TEST_F(SpeedCheck, no_ping_fallback_tcp) +{ + smartdns::MockServer server_upstream; + smartdns::Server server; + + server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) { + if (request->qtype == DNS_T_A) { + smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4"); + smartdns::MockServer::AddIP(request, request->domain.c_str(), "5.6.7.8"); + return smartdns::SERVER_REQUEST_OK; + } + return smartdns::SERVER_REQUEST_SOA; + }); + + server.MockPing(PING_TYPE_ICMP, "1.2.3.4", 60, 1000); + server.MockPing(PING_TYPE_TCP, "5.6.7.8:80", 60, 100); + server.Start(R"""(bind [::]:60053 +server 127.0.0.1:61053 +log-num 0 +log-console yes +speed-check-mode ping,tcp:80 +log-level debug +cache-persist no)"""); + smartdns::Client client; + ASSERT_TRUE(client.Query("a.com", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 1); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_LT(client.GetQueryTime(), 500); + EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com"); + EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3); + EXPECT_EQ(client.GetAnswer()[0].GetData(), "5.6.7.8"); +} + + +TEST_F(SpeedCheck, tcp_faster_than_ping) +{ + smartdns::MockServer server_upstream; + smartdns::Server server; + + server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) { + if (request->qtype == DNS_T_A) { + smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4"); + smartdns::MockServer::AddIP(request, request->domain.c_str(), "5.6.7.8"); + return smartdns::SERVER_REQUEST_OK; + } + return smartdns::SERVER_REQUEST_SOA; + }); + + server.MockPing(PING_TYPE_ICMP, "1.2.3.4", 60, 300); + server.MockPing(PING_TYPE_TCP, "5.6.7.8:80", 60, 10); + server.Start(R"""(bind [::]:60053 +server 127.0.0.1:61053 +log-num 0 +log-console yes +speed-check-mode ping,tcp:80 +log-level debug +cache-persist no)"""); + smartdns::Client client; + ASSERT_TRUE(client.Query("a.com", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 1); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_LT(client.GetQueryTime(), 500); + EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com"); + EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3); + EXPECT_EQ(client.GetAnswer()[0].GetData(), "5.6.7.8"); +} + TEST_F(SpeedCheck, fastest_ip) { smartdns::MockServer server_upstream;