dns_client: make DNS query ID random

This commit is contained in:
Nick Peng
2023-01-31 22:51:08 +08:00
parent 1e29f1fa63
commit 26d16eb9dc
2 changed files with 18 additions and 7 deletions

View File

@@ -47,6 +47,7 @@
#include <stdlib.h> #include <stdlib.h>
#include <string.h> #include <string.h>
#include <sys/epoll.h> #include <sys/epoll.h>
#include <sys/random.h>
#include <sys/socket.h> #include <sys/socket.h>
#include <sys/stat.h> #include <sys/stat.h>
#include <sys/types.h> #include <sys/types.h>
@@ -254,7 +255,6 @@ struct dns_query_struct {
}; };
static struct dns_client client; static struct dns_client client;
static atomic_t dns_client_sid = ATOMIC_INIT(0);
static LIST_HEAD(pending_servers); static LIST_HEAD(pending_servers);
static pthread_mutex_t pending_server_mutex = PTHREAD_MUTEX_INITIALIZER; static pthread_mutex_t pending_server_mutex = PTHREAD_MUTEX_INITIALIZER;
static int dns_client_has_bootstrap_dns = 0; static int dns_client_has_bootstrap_dns = 0;
@@ -1525,7 +1525,7 @@ static void _dns_client_check_tcp(void)
pthread_mutex_unlock(&client.server_list_lock); pthread_mutex_unlock(&client.server_list_lock);
} }
static struct dns_query_struct *_dns_client_get_request(unsigned short sid, char *domain) static struct dns_query_struct *_dns_client_get_request(char *domain, int qtype, unsigned short sid)
{ {
struct dns_query_struct *query = NULL; struct dns_query_struct *query = NULL;
struct dns_query_struct *query_result = NULL; struct dns_query_struct *query_result = NULL;
@@ -1535,6 +1535,7 @@ static struct dns_query_struct *_dns_client_get_request(unsigned short sid, char
/* get query by hash key : id + domain */ /* get query by hash key : id + domain */
key = hash_string(domain); key = hash_string(domain);
key = jhash(&sid, sizeof(sid), key); key = jhash(&sid, sizeof(sid), key);
key = jhash(&qtype, sizeof(qtype), key);
pthread_mutex_lock(&client.domain_map_lock); pthread_mutex_lock(&client.domain_map_lock);
hash_for_each_possible_safe(client.domain_map, query, tmp, domain_node, key) hash_for_each_possible_safe(client.domain_map, query, tmp, domain_node, key)
{ {
@@ -1542,6 +1543,10 @@ static struct dns_query_struct *_dns_client_get_request(unsigned short sid, char
continue; continue;
} }
if (qtype != query->qtype) {
continue;
}
if (strncmp(query->domain, domain, DNS_MAX_CNAME_LEN) != 0) { if (strncmp(query->domain, domain, DNS_MAX_CNAME_LEN) != 0) {
continue; continue;
} }
@@ -1643,7 +1648,7 @@ static int _dns_client_recv(struct dns_server_info *server_info, unsigned char *
} }
/* get query reference */ /* get query reference */
query = _dns_client_get_request(packet->head.id, domain); query = _dns_client_get_request(domain, qtype, packet->head.id);
if (query == NULL) { if (query == NULL) {
return 0; return 0;
} }
@@ -3495,6 +3500,7 @@ int dns_client_query(const char *domain, int qtype, dns_client_callback callback
struct dns_query_struct *query = NULL; struct dns_query_struct *query = NULL;
int ret = 0; int ret = 0;
uint32_t key = 0; uint32_t key = 0;
int unused __attribute__((unused));
if (domain == NULL) { if (domain == NULL) {
goto errout; goto errout;
@@ -3518,7 +3524,9 @@ int dns_client_query(const char *domain, int qtype, dns_client_callback callback
query->qtype = qtype; query->qtype = qtype;
query->send_tick = 0; query->send_tick = 0;
query->has_result = 0; query->has_result = 0;
query->sid = atomic_inc_return(&dns_client_sid); if (getrandom(&query->sid, sizeof(query->sid), GRND_NONBLOCK) != sizeof(query->sid)) {
query->sid = random();
}
query->server_group = _dns_client_get_dnsserver_group(group_name); query->server_group = _dns_client_get_dnsserver_group(group_name);
if (query->server_group == NULL) { if (query->server_group == NULL) {
tlog(TLOG_ERROR, "get dns server group %s failed.", group_name); tlog(TLOG_ERROR, "get dns server group %s failed.", group_name);
@@ -3534,6 +3542,7 @@ int dns_client_query(const char *domain, int qtype, dns_client_callback callback
/* add query to hashtable */ /* add query to hashtable */
key = hash_string(domain); key = hash_string(domain);
key = jhash(&query->sid, sizeof(query->sid), key); key = jhash(&query->sid, sizeof(query->sid), key);
key = jhash(&query->qtype, sizeof(query->qtype), key);
pthread_mutex_lock(&client.domain_map_lock); pthread_mutex_lock(&client.domain_map_lock);
hash_add(client.domain_map, &query->domain_node, key); hash_add(client.domain_map, &query->domain_node, key);
pthread_mutex_unlock(&client.domain_map_lock); pthread_mutex_unlock(&client.domain_map_lock);
@@ -3946,6 +3955,8 @@ int dns_client_init(void)
return -1; return -1;
} }
srandom(time(NULL));
memset(&client, 0, sizeof(client)); memset(&client, 0, sizeof(client));
pthread_attr_init(&attr); pthread_attr_init(&attr);
atomic_set(&client.dns_server_num, 0); atomic_set(&client.dns_server_num, 0);

View File

@@ -569,7 +569,7 @@ static int _fast_ping_sendping_v4(struct ping_host_struct *ping_host)
len = sendto(ping.fd_icmp, packet, sizeof(struct fast_ping_packet), 0, &ping_host->addr, ping_host->addr_len); len = sendto(ping.fd_icmp, packet, sizeof(struct fast_ping_packet), 0, &ping_host->addr, ping_host->addr_len);
if (len < 0 || len != sizeof(struct fast_ping_packet)) { if (len < 0 || len != sizeof(struct fast_ping_packet)) {
int err = errno; int err = errno;
if (errno == ENETUNREACH || errno == EINVAL || errno == EADDRNOTAVAIL) { if (errno == ENETUNREACH || errno == EINVAL || errno == EADDRNOTAVAIL || errno == EPERM || errno == EACCES) {
goto errout; goto errout;
} }
char ping_host_name[PING_MAX_HOSTLEN]; char ping_host_name[PING_MAX_HOSTLEN];
@@ -621,7 +621,7 @@ static int _fast_ping_sendping_udp(struct ping_host_struct *ping_host)
len = sendto(fd, &dns_head, sizeof(dns_head), 0, &ping_host->addr, ping_host->addr_len); len = sendto(fd, &dns_head, sizeof(dns_head), 0, &ping_host->addr, ping_host->addr_len);
if (len < 0 || len != sizeof(dns_head)) { if (len < 0 || len != sizeof(dns_head)) {
int err = errno; int err = errno;
if (errno == ENETUNREACH || errno == EINVAL || errno == EADDRNOTAVAIL) { if (errno == ENETUNREACH || errno == EINVAL || errno == EADDRNOTAVAIL || errno == EPERM || errno == EACCES) {
goto errout; goto errout;
} }
char ping_host_name[PING_MAX_HOSTLEN]; char ping_host_name[PING_MAX_HOSTLEN];
@@ -672,7 +672,7 @@ static int _fast_ping_sendping_tcp(struct ping_host_struct *ping_host)
goto errout; goto errout;
} }
if (errno == EACCES) { if (errno == EACCES || errno == EPERM) {
if (bool_print_log == 0) { if (bool_print_log == 0) {
goto errout; goto errout;
} }