diff --git a/etc/smartdns/smartdns.conf b/etc/smartdns/smartdns.conf index 32f1a25..0f9c997 100644 --- a/etc/smartdns/smartdns.conf +++ b/etc/smartdns/smartdns.conf @@ -58,7 +58,8 @@ bind [::]:53 # dns cache size # cache-size [number] # 0: for no cache -cache-size 32768 +# -1: auto set cache size +# cache-size 32768 # enable persist cache when restart # cache-persist no diff --git a/src/dns_cache.h b/src/dns_cache.h index 3a7ce48..7b1d29a 100644 --- a/src/dns_cache.h +++ b/src/dns_cache.h @@ -28,7 +28,7 @@ #include #include -#ifdef __cpluscplus +#ifdef __cplusplus extern "C" { #endif @@ -180,7 +180,7 @@ int dns_cache_load(const char *file); int dns_cache_save(const char *file); -#ifdef __cpluscplus +#ifdef __cplusplus } #endif #endif // !_SMARTDNS_CACHE_H diff --git a/src/dns_client.c b/src/dns_client.c index 5b01afb..9294012 100644 --- a/src/dns_client.c +++ b/src/dns_client.c @@ -558,13 +558,9 @@ static struct dns_server_group *_dns_client_get_dnsserver_group(const char *grou if (group == NULL) { group = client.default_group; - tlog(TLOG_DEBUG, "send query to group %s", DNS_SERVER_GROUP_DEFAULT); } else { if (list_empty(&group->head)) { group = client.default_group; - tlog(TLOG_DEBUG, "send query to group %s", DNS_SERVER_GROUP_DEFAULT); - } else { - tlog(TLOG_DEBUG, "send query to group %s", group_name); } } @@ -3602,7 +3598,7 @@ static int _dns_client_add_hashmap(struct dns_query_struct *query) int is_exists = 0; int loop = 0; - while (loop ++ <= 32) { + while (loop++ <= 32) { if (RAND_bytes((unsigned char *)&query->sid, sizeof(query->sid)) != 1) { query->sid = random(); } @@ -3634,7 +3630,7 @@ static int _dns_client_add_hashmap(struct dns_query_struct *query) pthread_mutex_unlock(&client.domain_map_lock); continue; } - + hash_add(client.domain_map, &query->domain_node, key); pthread_mutex_unlock(&client.domain_map_lock); break; @@ -3708,7 +3704,7 @@ int dns_client_query(const char *domain, int qtype, dns_client_callback callback } pthread_mutex_unlock(&client.domain_map_lock); - tlog(TLOG_INFO, "send request %s, qtype %d, id %d\n", domain, qtype, query->sid); + tlog(TLOG_INFO, "request: %s, qtype: %d, id: %d, group: %s", domain, qtype, query->sid, query->server_group->group_name); _dns_client_query_release(query); return 0; diff --git a/src/dns_client.h b/src/dns_client.h index 4301ad7..6530e7f 100644 --- a/src/dns_client.h +++ b/src/dns_client.h @@ -21,7 +21,7 @@ #include "dns.h" -#ifdef __cpluscplus +#ifdef __cplusplus extern "C" { #endif @@ -143,7 +143,7 @@ int dns_client_remove_group(const char *group_name); int dns_server_num(void); -#ifdef __cpluscplus +#ifdef __cplusplus } #endif #endif diff --git a/src/dns_conf.c b/src/dns_conf.c index 6a32bc3..022c303 100644 --- a/src/dns_conf.c +++ b/src/dns_conf.c @@ -30,7 +30,6 @@ #include #include -#define DEFAULT_DNS_CACHE_SIZE 512 #define DNS_MAX_REPLY_IP_NUM 8 #define DNS_RESOLV_FILE "/etc/resolv.conf" @@ -84,7 +83,7 @@ static struct config_enum_list dns_conf_response_mode_enum[] = { enum response_mode_type dns_conf_response_mode; /* cache */ -int dns_conf_cachesize = DEFAULT_DNS_CACHE_SIZE; +int dns_conf_cachesize = -1; int dns_conf_prefetch = 0; int dns_conf_serve_expired = 1; int dns_conf_serve_expired_ttl = 24 * 3600; /* 1 day */ @@ -3134,7 +3133,7 @@ static struct config_item _config_item[] = { CONF_CUSTOM("speed-check-mode", _config_speed_check_mode, NULL), CONF_INT("tcp-idle-time", &dns_conf_tcp_idle_time, 0, 3600), CONF_INT("cache-size", &dns_conf_cachesize, 0, CONF_INT_MAX), - CONF_STRING("cache-file", (char *)&dns_conf_cache_file, DNS_MAX_PATH), + CONF_CUSTOM("cache-file", _config_option_parser_filepath, (char *)&dns_conf_cache_file), CONF_YESNO("cache-persist", &dns_conf_cache_persist), CONF_YESNO("prefetch-domain", &dns_conf_prefetch), CONF_YESNO("serve-expired", &dns_conf_serve_expired), @@ -3341,12 +3340,37 @@ errout: return -1; } +static void _dns_conf_auto_set_cache_size(void) +{ + uint64_t memsize = get_system_mem_size(); + if (dns_conf_cachesize >= 0) { + return; + } + + if (memsize <= 16 * 1024 * 1024) { + dns_conf_cachesize = 2048; /* 1MB memory */ + } else if (memsize <= 32 * 1024 * 1024) { + dns_conf_cachesize = 8192; /* 4MB memory*/ + } else if (memsize <= 64 * 1024 * 1024) { + dns_conf_cachesize = 16384; /* 8MB memory*/ + } else if (memsize <= 128 * 1024 * 1024) { + dns_conf_cachesize = 32768; /* 16MB memory*/ + } else if (memsize <= 256 * 1024 * 1024) { + dns_conf_cachesize = 65536; /* 32MB memory*/ + } else if (memsize <= 512 * 1024 * 1024) { + dns_conf_cachesize = 131072; /* 64MB memory*/ + } else { + dns_conf_cachesize = 262144; /* 128MB memory*/ + } +} static int _dns_conf_load_post(void) { _config_setup_smartdns_domain(); _dns_conf_speed_check_mode_verify(); + _dns_conf_auto_set_cache_size(); + if (dns_conf_cachesize == 0 && dns_conf_response_mode == DNS_RESPONSE_MODE_FASTEST_RESPONSE) { dns_conf_response_mode = DNS_RESPONSE_MODE_FASTEST_IP; tlog(TLOG_WARN, "force set response to %s as cache size is 0", diff --git a/src/dns_conf.h b/src/dns_conf.h index d3d9cd7..fd4fb2e 100644 --- a/src/dns_conf.h +++ b/src/dns_conf.h @@ -30,7 +30,7 @@ #include "proxy.h" #include "radix.h" -#ifdef __cpluscplus +#ifdef __cplusplus extern "C" { #endif @@ -503,7 +503,7 @@ int dns_server_check_update_hosts(void); struct dns_proxy_names *dns_server_get_proxy_nams(const char *proxyname); extern int config_additional_file(void *data, int argc, char *argv[]); -#ifdef __cpluscplus +#ifdef __cplusplus } #endif #endif // !_DNS_CONF diff --git a/src/dns_server.c b/src/dns_server.c index 4349540..fd002fd 100644 --- a/src/dns_server.c +++ b/src/dns_server.c @@ -1847,9 +1847,10 @@ static int _dns_request_post(struct dns_server_post_context *context) return -1; } - tlog(TLOG_INFO, "reply %s to %s, qtype: %d, id: %d, group: %s", request->domain, + tlog(TLOG_INFO, "result: %s, client: %s, qtype: %d, id: %d, group: %s, time: %lums", request->domain, get_host_by_addr(clientip, sizeof(clientip), (struct sockaddr *)&request->addr), request->qtype, request->id, - request->dns_group_name[0] != '\0' ? request->dns_group_name : "default"); + request->dns_group_name[0] != '\0' ? request->dns_group_name : "default", + get_tick_count() - request->send_tick); ret = _dns_reply_inpacket(request, context->inpacket, context->inpacket_len); if (ret != 0) { @@ -3287,9 +3288,10 @@ static int _dns_server_reply_passthrough(struct dns_server_post_context *context } _dns_reply_inpacket(request, context->inpacket, context->inpacket_len); - tlog(TLOG_INFO, "reply %s to %s, qtype: %d, id: %d, group: %s", request->domain, + tlog(TLOG_INFO, "result: %s, client: %s, qtype: %d, id: %d, group: %s, time: %lums", request->domain, get_host_by_addr(clientip, sizeof(clientip), (struct sockaddr *)&request->addr), request->qtype, - request->id, request->dns_group_name[0] != '\0' ? request->dns_group_name : "default"); + request->id, request->dns_group_name[0] != '\0' ? request->dns_group_name : "default", + get_tick_count() - request->send_tick); } return _dns_server_reply_all_pending_list(request, context); @@ -5246,7 +5248,7 @@ static int _dns_server_recv(struct dns_server_conn_head *conn, unsigned char *in goto errout; } - tlog(TLOG_INFO, "query %s from %s, qtype: %d, id: %d\n", request->domain, name, request->qtype, request->id); + tlog(TLOG_DEBUG, "query %s from %s, qtype: %d, id: %d\n", request->domain, name, request->qtype, request->id); ret = _dns_server_do_query(request, 1); if (ret != 0) { diff --git a/src/dns_server.h b/src/dns_server.h index 12894db..3c40b5d 100644 --- a/src/dns_server.h +++ b/src/dns_server.h @@ -23,7 +23,7 @@ #include #include "dns_client.h" -#ifdef __cpluscplus +#ifdef __cplusplus extern "C" { #endif @@ -55,7 +55,7 @@ typedef int (*dns_result_callback)(const char *domain, dns_rtcode_t rtcode, dns_ int dns_server_query(const char *domain, int qtype, struct dns_server_query_option *server_query_option, dns_result_callback callback, void *user_ptr); -#ifdef __cpluscplus +#ifdef __cplusplus } #endif #endif diff --git a/src/http_parse.h b/src/http_parse.h index 6d9d40a..d8c87b8 100644 --- a/src/http_parse.h +++ b/src/http_parse.h @@ -19,7 +19,7 @@ #ifndef HTTP_PARSER_H #define HTTP_PARSER_H -#ifdef __cpluscplus +#ifdef __cplusplus extern "C" { #endif @@ -80,7 +80,7 @@ int http_head_parse(struct http_head *http_head, const char *data, int data_len) void http_head_destroy(struct http_head *http_head); -#ifdef __cpluscplus +#ifdef __cplusplus } #endif diff --git a/src/include/nftset.h b/src/include/nftset.h index 18de0da..037d356 100644 --- a/src/include/nftset.h +++ b/src/include/nftset.h @@ -19,7 +19,7 @@ #ifndef _NFTSET_H #define _NFTSET_H -#ifdef __cpluscplus +#ifdef __cplusplus extern "C" { #endif @@ -29,7 +29,7 @@ int nftset_add(const char *familyname, const char *tablename, const char *setnam int nftset_del(const char *familyname, const char *tablename, const char *setname, const unsigned char addr[], int addr_len); -#ifdef __cpluscplus +#ifdef __cplusplus } #endif diff --git a/src/util.c b/src/util.c index 6b0f4f2..6d5582c 100644 --- a/src/util.c +++ b/src/util.c @@ -50,6 +50,7 @@ #include #include #include +#include #define TMP_BUFF_LEN_32 32 @@ -1225,6 +1226,16 @@ void get_compiled_time(struct tm *tm) tm->tm_sec = sec; } +unsigned long get_system_mem_size(void) +{ + struct sysinfo memInfo; + sysinfo (&memInfo); + long long totalMem = memInfo.totalram; + totalMem *= memInfo.mem_unit; + + return totalMem; +} + int is_numeric(const char *str) { while (*str != '\0') { diff --git a/src/util.h b/src/util.h index 9bff444..570b047 100644 --- a/src/util.h +++ b/src/util.h @@ -118,6 +118,8 @@ int parse_tls_header(const char *data, size_t data_len, char *hostname, const ch void get_compiled_time(struct tm *tm); +unsigned long get_system_mem_size(void); + int is_numeric(const char *str); int has_network_raw_cap(void); diff --git a/test/server.cc b/test/server.cc index 4469541..0224546 100644 --- a/test/server.cc +++ b/test/server.cc @@ -17,6 +17,7 @@ */ #include "server.h" +#include "dns_server.h" #include "include/utils.h" #include "util.h" #include @@ -282,10 +283,18 @@ bool MockServer::Start(const std::string &url, ServerRequest callback) return true; } -Server::Server() {} +Server::Server() { + mode_ = Server::CREATE_MODE_FORK; +} + +Server::Server(enum Server::CREATE_MODE mode) +{ + mode_ = mode; +} bool Server::Start(const std::string &conf, enum CONF_TYPE type) { + pid_t pid = 0; int fds[2]; std::string conf_file; @@ -339,19 +348,35 @@ bool Server::Start(const std::string &conf, enum CONF_TYPE type) return false; } - pid_t pid = fork(); - if (pid == 0) { - std::vector args = { - "smartdns", "-f", "-x", "-c", conf_file, "-p", "-", - }; - char *argv[args.size() + 1]; - for (size_t i = 0; i < args.size(); i++) { - argv[i] = (char *)args[i].c_str(); - } + if (mode_ == CREATE_MODE_FORK) { + pid = fork(); + if (pid == 0) { + std::vector args = { + "smartdns", "-f", "-x", "-c", conf_file, "-p", "-", + }; + char *argv[args.size() + 1]; + for (size_t i = 0; i < args.size(); i++) { + argv[i] = (char *)args[i].c_str(); + } - smartdns_main(args.size(), argv, fds[1]); - _exit(1); - } else if (pid < 0) { + smartdns_main(args.size(), argv, fds[1]); + _exit(1); + } else if (pid < 0) { + return false; + } + } else if (mode_ == CREATE_MODE_THREAD) { + thread_ = std::thread([&]() { + std::vector args = { + "smartdns", "-f", "-x", "-c", conf_file_, "-p", "-", + }; + char *argv[args.size() + 1]; + for (size_t i = 0; i < args.size(); i++) { + argv[i] = (char *)args[i].c_str(); + } + + smartdns_main(args.size(), argv, fds[1]); + }); + } else { return false; } @@ -361,7 +386,13 @@ bool Server::Start(const std::string &conf, enum CONF_TYPE type) int ret = poll(pfd, 1, 10000); if (ret == 0) { - kill(pid, SIGKILL); + if (thread_.joinable()) { + thread_.join(); + } + + if (pid > 0) { + kill(pid, SIGKILL); + } return false; } @@ -371,6 +402,11 @@ bool Server::Start(const std::string &conf, enum CONF_TYPE type) void Server::Stop(bool graceful) { + if (thread_.joinable()) { + dns_server_stop(); + thread_.join(); + } + if (pid_ > 0) { if (graceful) { kill(pid_, SIGTERM); diff --git a/test/server.h b/test/server.h index 9971708..d21f36b 100644 --- a/test/server.h +++ b/test/server.h @@ -36,7 +36,12 @@ class Server CONF_TYPE_STRING, CONF_TYPE_FILE, }; + enum CREATE_MODE { + CREATE_MODE_FORK, + CREATE_MODE_THREAD, + }; Server(); + Server(enum CREATE_MODE mode); virtual ~Server(); bool Start(const std::string &conf, enum CONF_TYPE type = CONF_TYPE_STRING); @@ -45,9 +50,11 @@ class Server private: pid_t pid_; + std::thread thread_; int fd_; std::string conf_file_; bool clean_conf_file_{false}; + enum CREATE_MODE mode_; }; struct ServerRequestContext {