_dns_server_recv function refactoring

This commit is contained in:
Nick Peng
2019-07-08 23:30:14 +08:00
parent 301a60f6ed
commit 44c854dfec
3 changed files with 297 additions and 235 deletions

View File

@@ -65,6 +65,9 @@ struct dns_server_conn {
socklen_t addr_len; socklen_t addr_len;
struct sockaddr_storage addr; struct sockaddr_storage addr;
socklen_t localaddr_len;
struct sockaddr_storage localaddr;
time_t last_request_time; time_t last_request_time;
}; };
@@ -106,7 +109,6 @@ struct dns_request {
/* dns query */ /* dns query */
char domain[DNS_MAX_CNAME_LEN]; char domain[DNS_MAX_CNAME_LEN];
struct dns_head head;
unsigned long send_tick; unsigned long send_tick;
unsigned short qtype; unsigned short qtype;
unsigned short id; unsigned short id;
@@ -203,25 +205,6 @@ static void _dns_server_audit_log(struct dns_request *request)
tlog_printf(dns_audit, "%s %s query %s, type %d, result %s\n", req_time, req_host, request->domain, request->qtype, req_result); tlog_printf(dns_audit, "%s %s query %s, type %d, result %s\n", req_time, req_host, request->domain, request->qtype, req_result);
} }
static int _dns_recv_addr(struct dns_request *request, struct sockaddr_storage *from, socklen_t from_len)
{
switch (from->ss_family) {
case AF_INET:
memcpy(&request->in, from, from_len);
request->addr_len = from_len;
break;
case AF_INET6:
memcpy(&request->in6, from, from_len);
request->addr_len = from_len;
break;
default:
return -1;
break;
}
return 0;
}
static int _dns_add_rrs(struct dns_packet *packet, struct dns_request *request) static int _dns_add_rrs(struct dns_packet *packet, struct dns_request *request)
{ {
int ret = 0; int ret = 0;
@@ -458,7 +441,7 @@ static int _dns_reply(struct dns_request *request)
return _dns_reply_inpacket(request, inpacket, encode_len); return _dns_reply_inpacket(request, inpacket, encode_len);
} }
static int _dns_server_reply_SOA(int rcode, struct dns_request *request, struct dns_packet *packet) static int _dns_server_reply_SOA(int rcode, struct dns_request *request)
{ {
struct dns_soa *soa; struct dns_soa *soa;
@@ -639,7 +622,7 @@ static int _dns_server_request_complete(struct dns_request *request)
dns_cache_insert(request->domain, cname, cname_ttl, request->ttl_v4, DNS_T_A, request->ipv4_addr, DNS_RR_A_LEN, request->ping_ttl_v4); dns_cache_insert(request->domain, cname, cname_ttl, request->ttl_v4, DNS_T_A, request->ipv4_addr, DNS_RR_A_LEN, request->ping_ttl_v4);
} }
return _dns_server_reply_SOA(DNS_RC_NOERROR, request, NULL); return _dns_server_reply_SOA(DNS_RC_NOERROR, request);
} }
} }
@@ -743,6 +726,45 @@ static void _dns_server_select_possible_ipaddress(struct dns_request *request)
} }
} }
static struct dns_request *_dns_server_new_request(void)
{
struct dns_request *request = NULL;
request = malloc(sizeof(*request));
if (request == NULL) {
tlog(TLOG_ERROR, "malloc failed.\n");
goto errout;
}
memset(request, 0, sizeof(*request));
pthread_mutex_init(&request->ip_map_lock, NULL);
atomic_set(&request->adblock, 0);
atomic_set(&request->soa_num, 0);
atomic_set(&request->refcnt, 0);
request->ping_ttl_v4 = -1;
request->ping_ttl_v6 = -1;
request->prefetch = 0;
request->rcode = DNS_RC_SERVFAIL;
request->client = NULL;
request->result_callback = NULL;
INIT_LIST_HEAD(&request->list);
hash_init(request->ip_map);
return request;
errout:
return NULL;
}
static void _dns_server_delete_request(struct dns_request *request)
{
if (request->client) {
_dns_server_client_release(request->client);
}
pthread_mutex_destroy(&request->ip_map_lock);
memset(request, 0, sizeof(*request));
free(request);
}
static void _dns_server_request_release(struct dns_request *request) static void _dns_server_request_release(struct dns_request *request)
{ {
struct dns_ip_address *addr_map; struct dns_ip_address *addr_map;
@@ -771,10 +793,8 @@ static void _dns_server_request_release(struct dns_request *request)
hash_del(&addr_map->node); hash_del(&addr_map->node);
free(addr_map); free(addr_map);
} }
pthread_mutex_destroy(&request->ip_map_lock);
_dns_server_client_release(request->client); _dns_server_delete_request(request);
memset(request, 0, sizeof(*request));
free(request);
} }
static void _dns_server_request_get(struct dns_request *request) static void _dns_server_request_get(struct dns_request *request)
@@ -1292,7 +1312,7 @@ static int dns_server_resolve_callback(char *domain, dns_result_type rtype, unsi
return 0; return 0;
} }
static int _dns_server_process_ptr(struct dns_request *request, struct dns_packet *packet) static int _dns_server_process_ptr(struct dns_request *request)
{ {
struct ifaddrs *ifaddr = NULL; struct ifaddrs *ifaddr = NULL;
struct ifaddrs *ifa = NULL; struct ifaddrs *ifa = NULL;
@@ -1367,7 +1387,7 @@ errout:
return -1; return -1;
} }
static void _dns_server_log_rule(char *domain, unsigned char *rule_key, int rule_key_len) static void _dns_server_log_rule(const char *domain, unsigned char *rule_key, int rule_key_len)
{ {
char rule_name[DNS_MAX_CNAME_LEN]; char rule_name[DNS_MAX_CNAME_LEN];
@@ -1380,7 +1400,7 @@ static void _dns_server_log_rule(char *domain, unsigned char *rule_key, int rule
tlog(TLOG_INFO, "RULE-MATCH, domain: %s, rule: %s", domain, rule_name); tlog(TLOG_INFO, "RULE-MATCH, domain: %s, rule: %s", domain, rule_name);
} }
static struct dns_domain_rule *_dns_server_get_domain_rule(char *domain) static struct dns_domain_rule *_dns_server_get_domain_rule(const char *domain)
{ {
int domain_len; int domain_len;
char domain_key[DNS_MAX_CNAME_LEN]; char domain_key[DNS_MAX_CNAME_LEN];
@@ -1416,7 +1436,7 @@ static struct dns_domain_rule *_dns_server_get_domain_rule(char *domain)
return domain_rule; return domain_rule;
} }
static int _dns_server_pre_process_rule_flags(struct dns_request *request, struct dns_packet *packet) static int _dns_server_pre_process_rule_flags(struct dns_request *request)
{ {
struct dns_rule_flags *rule_flag = NULL; struct dns_rule_flags *rule_flag = NULL;
unsigned int flags = 0; unsigned int flags = 0;
@@ -1438,7 +1458,7 @@ static int _dns_server_pre_process_rule_flags(struct dns_request *request, struc
if (flags & DOMAIN_FLAG_ADDR_SOA) { if (flags & DOMAIN_FLAG_ADDR_SOA) {
/* return SOA */ /* return SOA */
_dns_server_reply_SOA(DNS_RC_NOERROR, request, packet); _dns_server_reply_SOA(DNS_RC_NOERROR, request);
return 0; return 0;
} }
@@ -1452,7 +1472,7 @@ static int _dns_server_pre_process_rule_flags(struct dns_request *request, struc
if (flags & DOMAIN_FLAG_ADDR_IPV4_SOA) { if (flags & DOMAIN_FLAG_ADDR_IPV4_SOA) {
/* return SOA for A request */ /* return SOA for A request */
_dns_server_reply_SOA(DNS_RC_NOERROR, request, packet); _dns_server_reply_SOA(DNS_RC_NOERROR, request);
return 0; return 0;
} }
break; break;
@@ -1464,7 +1484,7 @@ static int _dns_server_pre_process_rule_flags(struct dns_request *request, struc
if (flags & DOMAIN_FLAG_ADDR_IPV6_SOA) { if (flags & DOMAIN_FLAG_ADDR_IPV6_SOA) {
/* return SOA for A request */ /* return SOA for A request */
_dns_server_reply_SOA(DNS_RC_NOERROR, request, packet); _dns_server_reply_SOA(DNS_RC_NOERROR, request);
return 0; return 0;
} }
break; break;
@@ -1477,7 +1497,7 @@ errout:
return -1; return -1;
} }
static int _dns_server_process_address(struct dns_request *request, struct dns_packet *packet) static int _dns_server_process_address(struct dns_request *request)
{ {
struct dns_address_IPV4 *address_ipv4 = NULL; struct dns_address_IPV4 *address_ipv4 = NULL;
struct dns_address_IPV6 *address_ipv6 = NULL; struct dns_address_IPV6 *address_ipv6 = NULL;
@@ -1519,7 +1539,7 @@ errout:
return -1; return -1;
} }
static int _dns_server_process_cache(struct dns_request *request, struct dns_packet *packet) static int _dns_server_process_cache(struct dns_request *request)
{ {
struct dns_cache *dns_cache = NULL; struct dns_cache *dns_cache = NULL;
struct dns_cache *dns_cache_A = NULL; struct dns_cache *dns_cache_A = NULL;
@@ -1540,7 +1560,7 @@ static int _dns_server_process_cache(struct dns_request *request, struct dns_pac
tlog(TLOG_DEBUG, "Force IPV4 perfered."); tlog(TLOG_DEBUG, "Force IPV4 perfered.");
dns_cache_release(dns_cache_A); dns_cache_release(dns_cache_A);
dns_cache_release(dns_cache); dns_cache_release(dns_cache);
return _dns_server_reply_SOA(DNS_RC_NOERROR, request, NULL); return _dns_server_reply_SOA(DNS_RC_NOERROR, request);
} }
} }
} }
@@ -1596,91 +1616,55 @@ errout:
return -1; return -1;
} }
static int _dns_server_recv(struct dns_server_conn *client, unsigned char *inpacket, int inpacket_len, struct sockaddr_storage *from, socklen_t from_len) static void _dns_server_request_set_client(struct dns_request *request, struct dns_server_conn *client)
{ {
int decode_len;
int ret = -1;
unsigned char packet_buff[DNS_PACKSIZE];
char name[DNS_MAX_CNAME_LEN];
struct dns_packet *packet = (struct dns_packet *)packet_buff;
struct dns_request *request = NULL;
struct dns_rrs *rrs;
const char *group_name = NULL;
int rr_count = 0;
int i = 0;
int qclass;
int qtype = DNS_T_ALL;
_dns_server_client_get(client);
/* decode packet */
tlog(TLOG_DEBUG, "recv query packet from %s, len = %d", gethost_by_addr(name, sizeof(name), (struct sockaddr *)from), inpacket_len);
decode_len = dns_decode(packet, DNS_PACKSIZE, inpacket, inpacket_len);
if (decode_len < 0) {
tlog(TLOG_ERROR, "decode failed.\n");
goto errout;
}
tlog(TLOG_DEBUG, "request qdcount = %d, ancount = %d, nscount = %d, nrcount = %d, len = %d, id = %d, tc = %d, rd = %d, ra = %d, rcode = %d\n",
packet->head.qdcount, packet->head.ancount, packet->head.nscount, packet->head.nrcount, inpacket_len, packet->head.id, packet->head.tc,
packet->head.rd, packet->head.ra, packet->head.rcode);
if (packet->head.qr != DNS_QR_QUERY) {
goto errout;
}
request = malloc(sizeof(*request));
if (request == NULL) {
tlog(TLOG_ERROR, "malloc failed.\n");
goto errout;
}
memset(request, 0, sizeof(*request));
pthread_mutex_init(&request->ip_map_lock, NULL);
atomic_set(&request->adblock, 0);
atomic_set(&request->soa_num, 0);
atomic_set(&request->refcnt, 0);
request->ping_ttl_v4 = -1;
request->ping_ttl_v6 = -1;
request->prefetch = 0;
request->rcode = DNS_RC_SERVFAIL;
request->client = client; request->client = client;
INIT_LIST_HEAD(&request->list); _dns_server_client_get(client);
}
/* get client request address type */ static void _dns_server_request_set_id(struct dns_request *request, unsigned short id)
if (_dns_recv_addr(request, from, from_len) != 0) { {
tlog(TLOG_ERROR, "get client address failed."); request->id = id;
goto errout; }
static void _dns_server_request_set_enable_prefetch(struct dns_request *request)
{
request->prefetch = 1;
}
static int _dns_server_request_set_client_addr(struct dns_request *request, struct sockaddr_storage *from, socklen_t from_len)
{
switch (from->ss_family) {
case AF_INET:
memcpy(&request->in, from, from_len);
request->addr_len = from_len;
break;
case AF_INET6:
memcpy(&request->in6, from, from_len);
request->addr_len = from_len;
break;
default:
return -1;
break;
} }
request->id = packet->head.id; return 0;
memcpy(&request->head, &packet->head, sizeof(struct dns_head)); }
hash_init(request->ip_map);
/* get request domain and request qtype */ static void _dns_server_request_set_callback(struct dns_request *request, dns_result_callback callback, void *user_ptr)
rrs = dns_get_rrs_start(packet, DNS_RRS_QD, &rr_count); {
if (rr_count > 1) { request->result_callback = callback;
goto errout; request->user_ptr = user_ptr;
} }
for (i = 0; i < rr_count && rrs; i++, rrs = dns_get_rrs_next(packet, rrs)) { static int _dns_server_process_special_query(struct dns_request *request)
ret = dns_get_domain(rrs, request->domain, sizeof(request->domain), &qtype, &qclass); {
if (ret != 0) { int ret = 0;
goto errout;
}
request->qtype = qtype; switch (request->qtype) {
}
tlog(TLOG_INFO, "query server %s from %s, qtype = %d\n", request->domain, name, qtype);
/* lookup domain rule */
request->domain_rule = _dns_server_get_domain_rule(request->domain);
switch (qtype) {
case DNS_T_PTR: case DNS_T_PTR:
/* return PTR record */ /* return PTR record */
ret = _dns_server_process_ptr(request, packet); ret = _dns_server_process_ptr(request);
if (ret == 0) { if (ret == 0) {
goto clean_exit; goto clean_exit;
} else { } else {
@@ -1693,38 +1677,65 @@ static int _dns_server_recv(struct dns_server_conn *client, unsigned char *inpac
case DNS_T_AAAA: case DNS_T_AAAA:
/* force return SOA */ /* force return SOA */
if (dns_conf_force_AAAA_SOA == 1) { if (dns_conf_force_AAAA_SOA == 1) {
_dns_server_reply_SOA(DNS_RC_NOERROR, request, packet); _dns_server_reply_SOA(DNS_RC_NOERROR, request);
goto clean_exit; goto clean_exit;
} }
break; break;
default: default:
tlog(TLOG_DEBUG, "unsupport qtype: %d, domain: %s", qtype, request->domain); tlog(TLOG_DEBUG, "unsupport qtype: %d, domain: %s", request->qtype, request->domain);
request->passthrough = 1; request->passthrough = 1;
/* pass request to upstream server */ /* pass request to upstream server */
break; break;
} }
/* process domain flag */ return -1;
if (_dns_server_pre_process_rule_flags(request, packet) == 0) { clean_exit:
goto clean_exit; return 0;
} }
/* process domain address */
if (_dns_server_process_address(request, packet) == 0) {
goto clean_exit;
}
/* process cache */
if (_dns_server_process_cache(request, packet) == 0) {
goto clean_exit;
}
static const char *_dns_server_get_request_groupname(struct dns_request *request)
{
if (request->domain_rule) { if (request->domain_rule) {
/* Get the nameserver rule */ /* Get the nameserver rule */
if (request->domain_rule->rules[DOMAIN_RULE_NAMESERVER]) { if (request->domain_rule->rules[DOMAIN_RULE_NAMESERVER]) {
struct dns_nameserver_rule *nameserver_rule = request->domain_rule->rules[DOMAIN_RULE_NAMESERVER]; struct dns_nameserver_rule *nameserver_rule = request->domain_rule->rules[DOMAIN_RULE_NAMESERVER];
group_name = nameserver_rule->group_name; return nameserver_rule->group_name;
}
}
return NULL;
}
static int _dns_server_do_query(struct dns_request *request, const char *domain, int qtype)
{
int ret = -1;
const char *group_name = NULL;
/* lookup domain rule */
request->domain_rule = _dns_server_get_domain_rule(domain);
request->qtype = qtype;
safe_strncpy(request->domain, domain, sizeof(request->domain));
group_name = _dns_server_get_request_groupname(request);
if (_dns_server_process_special_query(request) == 0) {
goto clean_exit;
}
/* process domain flag */
if (_dns_server_pre_process_rule_flags(request) == 0) {
goto clean_exit;
}
/* process domain address */
if (_dns_server_process_address(request) == 0) {
goto clean_exit;
}
/* process cache */
if (request->prefetch == 0) {
if (_dns_server_process_cache(request) == 0) {
goto clean_exit;
} }
} }
@@ -1749,8 +1760,6 @@ static int _dns_server_recv(struct dns_server_conn *client, unsigned char *inpac
request->request_wait++; request->request_wait++;
if (dns_client_query(request->domain, qtype, dns_server_resolve_callback, request, group_name) != 0) { if (dns_client_query(request->domain, qtype, dns_server_resolve_callback, request, group_name) != 0) {
_dns_server_request_release(request); _dns_server_request_release(request);
_dns_server_request_remove(request);
request = NULL;
tlog(TLOG_ERROR, "send dns request failed."); tlog(TLOG_ERROR, "send dns request failed.");
goto errout; goto errout;
} }
@@ -1758,16 +1767,87 @@ static int _dns_server_recv(struct dns_server_conn *client, unsigned char *inpac
return 0; return 0;
clean_exit: clean_exit:
if (request) { if (request) {
free(request); _dns_server_delete_request(request);
} }
_dns_server_client_release(client);
return 0; return 0;
errout: errout:
_dns_server_request_remove(request);
request = NULL;
return ret;
}
static int _dns_server_recv(struct dns_server_conn *client, unsigned char *inpacket, int inpacket_len, struct sockaddr_storage *from, socklen_t from_len)
{
int decode_len;
int ret = -1;
unsigned char packet_buff[DNS_PACKSIZE];
char name[DNS_MAX_CNAME_LEN];
char domain[DNS_MAX_CNAME_LEN];
struct dns_packet *packet = (struct dns_packet *)packet_buff;
struct dns_request *request = NULL;
struct dns_rrs *rrs;
int rr_count = 0;
int i = 0;
int qclass;
int qtype = DNS_T_ALL;
_dns_server_client_get(client);
/* decode packet */
tlog(TLOG_DEBUG, "recv query packet from %s, len = %d", gethost_by_addr(name, sizeof(name), (struct sockaddr *)from), inpacket_len);
decode_len = dns_decode(packet, DNS_PACKSIZE, inpacket, inpacket_len);
if (decode_len < 0) {
tlog(TLOG_ERROR, "decode failed.\n");
goto errout;
}
tlog(TLOG_DEBUG, "request qdcount = %d, ancount = %d, nscount = %d, nrcount = %d, len = %d, id = %d, tc = %d, rd = %d, ra = %d, rcode = %d\n",
packet->head.qdcount, packet->head.ancount, packet->head.nscount, packet->head.nrcount, inpacket_len, packet->head.id, packet->head.tc,
packet->head.rd, packet->head.ra, packet->head.rcode);
if (packet->head.qr != DNS_QR_QUERY) {
goto errout;
}
/* get request domain and request qtype */
rrs = dns_get_rrs_start(packet, DNS_RRS_QD, &rr_count);
if (rr_count > 1) {
goto errout;
}
for (i = 0; i < rr_count && rrs; i++, rrs = dns_get_rrs_next(packet, rrs)) {
ret = dns_get_domain(rrs, domain, sizeof(domain), &qtype, &qclass);
if (ret != 0) {
goto errout;
}
// Only support one question.
break;
}
tlog(TLOG_INFO, "query server %s from %s, qtype = %d\n", domain, name, qtype);
request = _dns_server_new_request();
if (request == NULL) {
tlog(TLOG_ERROR, "malloc failed.\n");
goto errout;
}
_dns_server_request_set_client(request, client);
_dns_server_request_set_client_addr(request, from, from_len);
_dns_server_request_set_id(request, packet->head.id);
ret = _dns_server_do_query(request, domain, qtype);
if (ret != 0) {
tlog(TLOG_ERROR, "do query %s failed.\n", domain);
goto errout;
}
_dns_server_client_release(client);
return ret;
errout:
if (request) { if (request) {
ret = _dns_server_forward_request(inpacket, inpacket_len); ret = _dns_server_forward_request(inpacket, inpacket_len);
free(request); _dns_server_delete_request(request);
} }
_dns_server_client_release(client); _dns_server_client_release(client);
@@ -1775,56 +1855,35 @@ errout:
} }
static int _dns_server_prefetch_request(char *domain, dns_type_t qtype) static int _dns_server_prefetch_request(char *domain, dns_type_t qtype)
{
return dns_server_query(domain, qtype, NULL, NULL);
}
int dns_server_query(char *domain, int qtype, dns_result_callback callback, void *user_ptr)
{ {
int ret = -1; int ret = -1;
struct dns_request *request = NULL; struct dns_request *request = NULL;
const char *group_name = NULL;
request = malloc(sizeof(*request)); request = _dns_server_new_request();
if (request == NULL) { if (request == NULL) {
tlog(TLOG_ERROR, "malloc failed.\n"); tlog(TLOG_ERROR, "malloc failed.\n");
goto errout; goto errout;
} }
memset(request, 0, sizeof(*request));
pthread_mutex_init(&request->ip_map_lock, NULL);
atomic_set(&request->adblock, 0);
request->ping_ttl_v4 = -1;
request->ping_ttl_v6 = -1;
request->prefetch = 1;
request->qtype = qtype;
request->rcode = DNS_RC_SERVFAIL;
request->id = 0; _dns_server_request_set_callback(request, callback, user_ptr);
hash_init(request->ip_map); _dns_server_request_set_enable_prefetch(request);
safe_strncpy(request->domain, domain, DNS_MAX_CNAME_LEN); ret = _dns_server_do_query(request, domain, qtype);
if (ret != 0) {
/* lookup domain rule */ tlog(TLOG_ERROR, "do query %s failed.\n", domain);
request->domain_rule = _dns_server_get_domain_rule(request->domain); goto errout;
tlog(TLOG_INFO, "prefetch domain %s, qtype = %d\n", request->domain, qtype);
_dns_server_request_get(request);
pthread_mutex_lock(&server.request_list_lock);
list_add_tail(&request->list, &server.request_list);
pthread_mutex_unlock(&server.request_list_lock);
_dns_server_request_get(request);
request->send_tick = get_tick_count();
request->request_wait++;
if (request->domain_rule) {
/* get nameserver rule */
if (request->domain_rule->rules[DOMAIN_RULE_NAMESERVER]) {
struct dns_nameserver_rule *nameserver_rule = request->domain_rule->rules[DOMAIN_RULE_NAMESERVER];
group_name = nameserver_rule->group_name;
}
} }
/* send request */ return ret;
dns_client_query(request->domain, qtype, dns_server_resolve_callback, request, group_name);
return 0;
errout: errout:
if (request) {
_dns_server_delete_request(request);
}
return ret; return ret;
} }
@@ -1834,76 +1893,41 @@ static int _dns_server_process_udp(struct dns_server_conn *dnsserver, struct epo
unsigned char inpacket[DNS_IN_PACKSIZE]; unsigned char inpacket[DNS_IN_PACKSIZE];
struct sockaddr_storage from; struct sockaddr_storage from;
socklen_t from_len = sizeof(from); socklen_t from_len = sizeof(from);
struct msghdr msg;
struct iovec iov;
char ans_data[4096];
struct cmsghdr *cmsg;
len = recvfrom(dnsserver->fd, inpacket, sizeof(inpacket), 0, (struct sockaddr *)&from, (socklen_t *)&from_len); memset(&msg, 0, sizeof(msg));
iov.iov_base = (char *)inpacket;
iov.iov_len = sizeof(inpacket);
msg.msg_name = &from;
msg.msg_namelen = sizeof(from);
msg.msg_iov = &iov;
msg.msg_iovlen = 1;
msg.msg_control = ans_data;
msg.msg_controllen = sizeof(ans_data);
len = recvmsg(dnsserver->fd, &msg, MSG_DONTWAIT);
if (len < 0) { if (len < 0) {
tlog(TLOG_ERROR, "recvfrom failed, %s\n", strerror(errno)); tlog(TLOG_ERROR, "recvfrom failed, %s\n", strerror(errno));
return -1; return -1;
} }
from_len = msg.msg_namelen;
return _dns_server_recv(dnsserver, inpacket, len, &from, from_len); for (cmsg = CMSG_FIRSTHDR(&msg); cmsg; cmsg = CMSG_NXTHDR(&msg, cmsg)) {
} if (cmsg->cmsg_level == IPPROTO_IP && cmsg->cmsg_type == IP_PKTINFO) {
const struct in_pktinfo *pktinfo = (struct in_pktinfo *)CMSG_DATA(cmsg);
int dns_server_query(char *domain, int qtype, dns_result_callback callback, void *user_ptr) unsigned char *addr = (unsigned char *)&pktinfo->ipi_addr.s_addr;
{ fill_sockaddr_by_ip(addr, sizeof(in_addr_t), 0, (struct sockaddr *)&dnsserver->localaddr, &dnsserver->localaddr_len);
int ret = -1; } else if (cmsg->cmsg_level == IPPROTO_IPV6 && cmsg->cmsg_type == IPV6_PKTINFO) {
struct dns_request *request = NULL; const struct in6_pktinfo *pktinfo = (struct in6_pktinfo *)CMSG_DATA(cmsg);
const char *group_name = NULL; unsigned char *addr = (unsigned char *)pktinfo->ipi6_addr.s6_addr;
fill_sockaddr_by_ip(addr, sizeof(struct in6_addr), 0, (struct sockaddr *)&dnsserver->localaddr, &dnsserver->localaddr_len);
request = malloc(sizeof(*request));
if (request == NULL) {
tlog(TLOG_ERROR, "malloc failed.\n");
goto errout;
}
memset(request, 0, sizeof(*request));
pthread_mutex_init(&request->ip_map_lock, NULL);
atomic_set(&request->adblock, 0);
request->ping_ttl_v4 = -1;
request->ping_ttl_v6 = -1;
request->prefetch = 1;
request->qtype = qtype;
request->rcode = DNS_RC_SERVFAIL;
request->result_callback = callback;
request->user_ptr = user_ptr;
request->id = 0;
hash_init(request->ip_map);
safe_strncpy(request->domain, domain, DNS_MAX_CNAME_LEN);
/* lookup domain rule */
request->domain_rule = _dns_server_get_domain_rule(request->domain);
tlog(TLOG_INFO, "query domain %s, qtype = %d\n", request->domain, qtype);
/* process cache */
if (_dns_server_process_cache(request, NULL) == 0) {
ret = 0;
goto clean_exit;
}
_dns_server_request_get(request);
pthread_mutex_lock(&server.request_list_lock);
list_add_tail(&request->list, &server.request_list);
pthread_mutex_unlock(&server.request_list_lock);
_dns_server_request_get(request);
request->send_tick = get_tick_count();
request->request_wait++;
if (request->domain_rule) {
/* get nameserver rule */
if (request->domain_rule->rules[DOMAIN_RULE_NAMESERVER]) {
struct dns_nameserver_rule *nameserver_rule = request->domain_rule->rules[DOMAIN_RULE_NAMESERVER];
group_name = nameserver_rule->group_name;
} }
} }
/* send request */ return _dns_server_recv(dnsserver, inpacket, len, &from, from_len);
ret = dns_client_query(request->domain, qtype, dns_server_resolve_callback, request, group_name);
clean_exit:
return ret;
errout:
return ret;
} }
static void _dns_server_client_touch(struct dns_server_conn *client) static void _dns_server_client_touch(struct dns_server_conn *client)
@@ -1960,6 +1984,12 @@ static int _dns_server_accept(struct dns_server_conn *dnsserver, struct epoll_ev
atomic_set(&client->refcnt, 0); atomic_set(&client->refcnt, 0);
memcpy(&client->addr, &addr, addr_len); memcpy(&client->addr, &addr, addr_len);
client->addr_len = addr_len; client->addr_len = addr_len;
client->localaddr_len = sizeof(struct sockaddr_storage);
if (getsockname(client->fd, (struct sockaddr *)&client->localaddr, &client->localaddr_len) != 0) {
tlog(TLOG_ERROR, "get local addr failed, %s", strerror(errno));
goto errout;
}
_dns_server_client_touch(client); _dns_server_client_touch(client);
@@ -2479,6 +2509,9 @@ static int _dns_create_socket(const char *host_ip, int type)
tlog(TLOG_ERROR, "set socket opt failed."); tlog(TLOG_ERROR, "set socket opt failed.");
goto errout; goto errout;
} }
} else {
setsockopt(fd, IPPROTO_IP, IP_PKTINFO, &optval, sizeof(optval));
setsockopt(fd, IPPROTO_IPV6, IPV6_RECVPKTINFO, &optval, sizeof(optval));
} }
if (bind(fd, gai->ai_addr, gai->ai_addrlen) != 0) { if (bind(fd, gai->ai_addr, gai->ai_addrlen) != 0) {

View File

@@ -138,6 +138,33 @@ errout:
return -1; return -1;
} }
int fill_sockaddr_by_ip(unsigned char *ip, int ip_len, int port, struct sockaddr *addr, socklen_t *addr_len)
{
if (ip == NULL || addr == NULL || addr_len == NULL) {
return -1;
}
if (ip_len == IPV4_ADDR_LEN) {
struct sockaddr_in *addr_in = NULL;
addr->sa_family = AF_INET;
addr_in = (struct sockaddr_in *)addr;
addr_in->sin_port = htons(port);
addr_in->sin_family = AF_INET;
memcpy(&addr_in->sin_addr.s_addr, ip, ip_len);
*addr_len = 16;
} else if (ip_len == IPV6_ADDR_LEN) {
struct sockaddr_in6 *addr_in6 = NULL;
addr->sa_family = AF_INET6;
addr_in6 = (struct sockaddr_in6 *)addr;
addr_in6->sin6_port = htons(port);
addr_in6->sin6_family = AF_INET6;
memcpy(addr_in6->sin6_addr.s6_addr, ip, ip_len);
*addr_len = 28;
}
return -1;
}
int parse_ip(const char *value, char *ip, int *port) int parse_ip(const char *value, char *ip, int *port)
{ {
int offset = 0; int offset = 0;
@@ -350,7 +377,7 @@ int set_fd_nonblock(int fd, int nonblock)
return 0; return 0;
} }
char *reverse_string(char *output, char *input, int len, int to_lower_case) char *reverse_string(char *output, const char *input, int len, int to_lower_case)
{ {
char *begin = output; char *begin = output;
if (len <= 0) { if (len <= 0) {

View File

@@ -16,6 +16,8 @@ char *gethost_by_addr(char *host, int maxsize, struct sockaddr *addr);
int getaddr_by_host(char *host, struct sockaddr *addr, socklen_t *addr_len); int getaddr_by_host(char *host, struct sockaddr *addr, socklen_t *addr_len);
int fill_sockaddr_by_ip(unsigned char *ip, int ip_len, int port, struct sockaddr *addr, socklen_t *addr_len);
int parse_ip(const char *value, char *ip, int *port); int parse_ip(const char *value, char *ip, int *port);
int check_is_ipaddr(const char *ip); int check_is_ipaddr(const char *ip);
@@ -24,7 +26,7 @@ int parse_uri(char *value, char *scheme, char *host, int *port, char *path);
int set_fd_nonblock(int fd, int nonblock); int set_fd_nonblock(int fd, int nonblock);
char *reverse_string(char *output, char *input, int len, int to_lower_case); char *reverse_string(char *output, const char *input, int len, int to_lower_case);
void print_stack(void); void print_stack(void);