From 40740f531b66137e8fe405627dca494492b53251 Mon Sep 17 00:00:00 2001 From: Nick Peng Date: Thu, 20 Dec 2018 00:16:32 +0800 Subject: [PATCH] ADD EDNS check feature, TCP server bugfix --- etc/smartdns/smartdns.conf | 8 +- src/dns.c | 175 ++++++++++++++++++++++++++++++++----- src/dns.h | 11 ++- src/dns_client.c | 28 +++++- src/dns_client.h | 1 + src/dns_conf.c | 5 ++ src/dns_server.c | 10 ++- src/smartdns.c | 44 ++++++---- 8 files changed, 230 insertions(+), 52 deletions(-) diff --git a/etc/smartdns/smartdns.conf b/etc/smartdns/smartdns.conf index d0f85c0..9d8d53e 100644 --- a/etc/smartdns/smartdns.conf +++ b/etc/smartdns/smartdns.conf @@ -66,17 +66,17 @@ log-level info # audit-num 2 # remote udp dns server list -# server [IP]:[PORT] [-blacklist-ip] +# server [IP]:[PORT] [-blacklist-ip] [-check-edns] # default port is 53 -# server 8.8.8.8 -blacklist-ip +# server 8.8.8.8 -blacklist-ip -check-edns # remote tcp dns server list -# server-tcp [IP]:[PORT] [-blacklist-ip] +# server-tcp [IP]:[PORT] [-blacklist-ip] [-check-edns] # default port is 53 # server-tcp 8.8.8.8 # remote tls dns server list -# server-tls [IP]:[PORT] [-blacklist-ip] +# server-tls [IP]:[PORT] [-blacklist-ip] [-check-edns] # default port is 853 # server-tls 8.8.8.8 # server-tls 1.0.0.1 diff --git a/src/dns.c b/src/dns.c index 70172d3..7ef23bb 100644 --- a/src/dns.c +++ b/src/dns.c @@ -184,6 +184,9 @@ int dns_rr_add_end(struct dns_packet *packet, int type, dns_type_t rtype, int le count = &head->nrcount; start = &packet->additional; break; + case DNS_RRS_OPT: + count = &packet->optcount; + start = &packet->optional; default: return -1; break; @@ -638,10 +641,23 @@ int dns_get_SOA(struct dns_rrs *rrs, char *domain, int maxsize, int *ttl, struct return 0; } +int dns_set_OPT_payload_size(struct dns_packet *packet, int payload_size) +{ + if (payload_size < 512) { + payload_size = 512; + } + + packet->payloadsize = payload_size; + return 0; +} + +int dns_get_OPT_payload_size(struct dns_packet *packet) +{ + return packet->payloadsize; +} + int dns_add_OPT_ECS(struct dns_packet *packet, dns_rr_type type, struct dns_opt_ecs *ecs) { - // TODO - unsigned char opt_data[DNS_MAX_OPT_LEN]; struct dns_opt *opt = (struct dns_opt *)opt_data; int len = 0; @@ -655,18 +671,21 @@ int dns_add_OPT_ECS(struct dns_packet *packet, dns_rr_type type, struct dns_opt_ len += (ecs->source_prefix / 8); len += (ecs->source_prefix % 8 > 0) ? 1 : 0; - return dns_add_OPT(packet, type, DNS_OPT_T_ECS, len, opt); + return dns_add_RAW(packet, DNS_RRS_OPT, DNS_OPT_T_ECS, "", 0, opt_data, len); } int dns_get_OPT_ECS(struct dns_rrs *rrs, unsigned short *opt_code, unsigned short *opt_len, struct dns_opt_ecs *ecs) { - // TODO - unsigned char opt_data[DNS_MAX_OPT_LEN]; struct dns_opt *opt = (struct dns_opt *)opt_data; - int len = sizeof(opt_data); + int len = DNS_MAX_OPT_LEN; + int ttl = 0; - if (dns_get_OPT(rrs, opt_code, opt_len, opt, &len) != 0) { + if (dns_get_RAW(rrs, 0, 0, &ttl, opt_data, &len) != 0) { + return -1; + } + + if (len < sizeof(*opt)) { return -1; } @@ -801,6 +820,20 @@ static int _dns_encode_head(struct dns_context *context) return len; } +static int _dns_encode_head_count(struct dns_context *context) +{ + int len = 12; + struct dns_head *head = &context->packet->head; + unsigned char *ptr = context->data; + + ptr += 4; + dns_write_short(&ptr, head->qdcount); + dns_write_short(&ptr, head->ancount); + dns_write_short(&ptr, head->nscount); + dns_write_short(&ptr, head->nrcount); + return len; +} + static int _dns_decode_domain(struct dns_context *context, char *output, int size) { int output_len = 0; @@ -880,7 +913,6 @@ static int _dns_decode_domain(struct dns_context *context, char *output, int siz static int _dns_encode_domain(struct dns_context *context, char *domain) { int num = 0; - int total_len = 0; unsigned char *ptr_num = context->ptr++; /*[len]string[len]string...[0]0 */ @@ -897,15 +929,12 @@ static int _dns_encode_domain(struct dns_context *context, char *domain) num++; context->ptr++; domain++; - total_len++; } *ptr_num = num; - if (total_len > 0) { - /* if domain is '\0', [domain] is '\0' */ - *(context->ptr) = 0; - context->ptr++; - } + /* if domain is '\0', [domain] is '\0' */ + *(context->ptr) = 0; + context->ptr++; return 0; } @@ -1238,8 +1267,97 @@ static int _dns_decode_opt_ecs(struct dns_context *context, struct dns_opt_ecs * int _dns_encode_OPT(struct dns_context *context, struct dns_rrs *rrs) { - // TODO - + int ret; + int opt_code = 0; + int qclass = 0; + char domain[DNS_MAX_CNAME_LEN]; + struct dns_data_context data_context; + int rr_len = 0; + int ttl; + + data_context.data = rrs->data; + data_context.ptr = rrs->data; + data_context.maxsize = rrs->len; + + ret = _dns_get_rr_head(&data_context, domain, DNS_MAX_CNAME_LEN, &opt_code, &qclass, &ttl, &rr_len); + if (ret < 0) { + return -1; + } + + if (_dns_left_len(context) < (4 + rr_len)) { + return -1; + } + + dns_write_short(&context->ptr, opt_code); + dns_write_short(&context->ptr, rr_len); + memcpy(context->ptr, data_context.ptr, rr_len); + context->ptr += rr_len; + + return 0; +} + +int _dns_get_opts_data_len(struct dns_packet *packet, struct dns_rrs *rrs, int count) +{ + int i = 0; + int len = 0; + int opt_code = 0; + int qclass = 0; + int ttl; + int ret; + char domain[DNS_MAX_CNAME_LEN]; + struct dns_data_context data_context; + int rr_len = 0; + + for (i = 0; i < count && rrs; i++, rrs = dns_get_rrs_next(packet, rrs)) { + data_context.data = rrs->data; + data_context.ptr = rrs->data; + data_context.maxsize = rrs->len; + + ret = _dns_get_rr_head(&data_context, domain, DNS_MAX_CNAME_LEN, &opt_code, &qclass, &ttl, &rr_len); + if (ret < 0) { + return -1; + } + + len += rr_len; + } + + return len; +} + +int _dns_encode_opts(struct dns_packet *packet, struct dns_context *context, struct dns_rrs *rrs, int count) +{ + int i = 0; + int len = 0; + int ret = 0; + unsigned int rcode = 0; + int rr_len = 0; + int payloadsize = packet->payloadsize; + + rr_len = _dns_get_opts_data_len(packet, rrs, count); + if (rr_len < 0) { + return -1; + } + + if (payloadsize < DNS_DEFAULT_PACKET_SIZE) { + payloadsize = DNS_DEFAULT_PACKET_SIZE; + } + + ret = _dns_encode_rr_head(context, "0", DNS_T_OPT, payloadsize, rcode, rr_len); + if (ret < 0) { + return -1; + } + + if (_dns_left_len(context) < rr_len) { + return -1; + } + + for (i = 0; i < count && rrs; i++, rrs = dns_get_rrs_next(packet, rrs)) { + len = _dns_encode_OPT(context, rrs); + if (len < 0) { + return -1; + } + } + return 0; } @@ -1453,6 +1571,8 @@ static int _dns_decode_an(struct dns_context *context, dns_rr_type type) tlog(TLOG_ERROR, "opt length mitchmatch, %s\n", domain); return -1; } + + dns_set_OPT_payload_size(packet, qclass); } break; default: context->ptr += rr_len; @@ -1517,12 +1637,6 @@ static int _dns_encode_an(struct dns_context *context, struct dns_rrs *rrs) return -1; } break; - case DNS_T_OPT: - ret = _dns_encode_OPT(context, rrs); - if (ret < 0) { - return -1; - } - break; default: break; } @@ -1621,6 +1735,15 @@ static int _dns_encode_body(struct dns_context *context) } } + rrs = dns_get_rrs_start(packet, DNS_RRS_OPT, &count); + if (count > 0 || packet->payloadsize > 0) { + len = _dns_encode_opts(packet, context, rrs, count); + if (len < 0) { + return -1; + } + head->nrcount++; + } + return 0; } @@ -1641,6 +1764,9 @@ int dns_packet_init(struct dns_packet *packet, int size, struct dns_head *head) packet->answers = DNS_RR_END; packet->nameservers = DNS_RR_END; packet->additional = DNS_RR_END; + packet->optional = DNS_RR_END; + packet->optcount = 0; + packet->payloadsize = 0; return 0; } @@ -1695,6 +1821,11 @@ int dns_encode(unsigned char *data, int size, struct dns_packet *packet) return -1; } + ret = _dns_encode_head_count(&context); + if (ret < 0) { + return -1; + } + return context.ptr - context.data; } diff --git a/src/dns.h b/src/dns.h index 266f94a..c962096 100644 --- a/src/dns.h +++ b/src/dns.h @@ -12,6 +12,7 @@ #define DNS_MAX_OPT_LEN 256 #define DNS_IN_PACKSIZE (512 * 4) #define DNS_PACKSIZE (512 * 8) +#define DNS_DEFAULT_PACKET_SIZE 512 typedef enum dns_qr { DNS_QR_QUERY = 0, @@ -23,7 +24,8 @@ typedef enum dns_rr_type { DNS_RRS_AN = 1, DNS_RRS_NS = 2, DNS_RRS_NR = 3, - DNS_RRS_END = 4, + DNS_RRS_OPT = 4, + DNS_RRS_END, } dns_rr_type; typedef enum dns_class { DNS_C_IN = 1, DNS_C_ANY = 255 } dns_class_t; @@ -104,6 +106,9 @@ struct dns_packet { unsigned short answers; unsigned short nameservers; unsigned short additional; + unsigned short optcount; + unsigned short optional; + unsigned short payloadsize; int size; int len; unsigned char data[0]; @@ -180,8 +185,8 @@ int dns_get_AAAA(struct dns_rrs *rrs, char *domain, int maxsize, int *ttl, unsig int dns_add_SOA(struct dns_packet *packet, dns_rr_type type, char *domain, int ttl, struct dns_soa *soa); int dns_get_SOA(struct dns_rrs *rrs, char *domain, int maxsize, int *ttl, struct dns_soa *soa); -int dns_add_OPT(struct dns_packet *packet, dns_rr_type type, unsigned short opt_code, unsigned short opt_len, struct dns_opt *opt); -int dns_get_OPT(struct dns_rrs *rrs, unsigned short *opt_code, unsigned short *opt_len, struct dns_opt *opt, int *opt_maxlen); +int dns_set_OPT_payload_size(struct dns_packet *packet, int payload_size); +int dns_get_OPT_payload_size(struct dns_packet *packet); int dns_add_OPT_ECS(struct dns_packet *packet, dns_rr_type type, struct dns_opt_ecs *ecs); int dns_get_OPT_ECS(struct dns_rrs *rrs, unsigned short *opt_code, unsigned short *opt_len, struct dns_opt_ecs *ecs); diff --git a/src/dns_client.c b/src/dns_client.c index 24bdf91..eb1aa87 100644 --- a/src/dns_client.c +++ b/src/dns_client.c @@ -48,7 +48,7 @@ #define DNS_MAX_HOSTNAME 256 #define DNS_MAX_EVENTS 64 #define DNS_HOSTNAME_LEN 128 -#define DNS_TCP_BUFFER (8 * 1024) +#define DNS_TCP_BUFFER (16 * 1024) /* dns client */ struct dns_client { @@ -602,6 +602,7 @@ static int _dns_client_recv(struct dns_server_info *server_info, unsigned char * int ret = 0; struct dns_query_struct *query; int request_num = 0; + int has_opt = 0; packet->head.tc = 0; @@ -620,9 +621,9 @@ static int _dns_client_recv(struct dns_server_info *server_info, unsigned char * return -1; } - tlog(TLOG_DEBUG, "qdcount = %d, ancount = %d, nscount = %d, nrcount = %d, len = %d, id = %d, tc = %d, rd = %d, ra = %d, rcode = %d\n", packet->head.qdcount, + tlog(TLOG_DEBUG, "qdcount = %d, ancount = %d, nscount = %d, nrcount = %d, len = %d, id = %d, tc = %d, rd = %d, ra = %d, rcode = %d, payloadsize = %d\n", packet->head.qdcount, packet->head.ancount, packet->head.nscount, packet->head.nrcount, inpacket_len, packet->head.id, packet->head.tc, packet->head.rd, packet->head.ra, - packet->head.rcode); + packet->head.rcode, dns_get_OPT_payload_size(packet)); /* get question */ rrs = dns_get_rrs_start(packet, DNS_RRS_QD, &rr_count); @@ -631,9 +632,13 @@ static int _dns_client_recv(struct dns_server_info *server_info, unsigned char * tlog(TLOG_DEBUG, "domain: %s qtype: %d qclass: %d\n", domain, qtype, qclass); } + if (dns_get_OPT_payload_size(packet) > 0) { + has_opt = 1; + } + /* get query reference */ query = _dns_client_get_request(packet->head.id, domain); - if (query == NULL) { + if (query == NULL || (query && has_opt == 0 && server_info->result_flag & DNSSERVER_FLAG_CHECK_EDNS)) { return 0; } @@ -813,6 +818,10 @@ static int _dns_client_create_socket(struct dns_server_info *server_info) time(&server_info->last_send); time(&server_info->last_recv); + if (server_info->fd > 0) { + return -1; + } + if (server_info->type == DNS_SERVER_UDP) { return _dns_client_create_socket_udp(server_info); } else if (server_info->type == DNS_SERVER_TCP) { @@ -1372,6 +1381,11 @@ static void *_dns_client_work(void *arg) for (i = 0; i < num; i++) { struct epoll_event *event = &events[i]; struct dns_server_info *server_info = (struct dns_server_info *)event->data.ptr; + if (server_info == NULL) { + tlog(TLOG_WARN, "server info is invalid."); + continue; + } + _dns_client_process(server_info, event, now); } } @@ -1398,6 +1412,7 @@ static int _dns_client_send_data_to_buffer(struct dns_server_info *server_info, struct epoll_event event; if (DNS_TCP_BUFFER - server_info->send_buff.len < len) { + errno = ENOMEM; return -1; } @@ -1434,6 +1449,10 @@ static int _dns_client_send_tcp(struct dns_server_info *server_info, void *packe /* save data to buffer, and retry when EPOLLOUT is available */ return _dns_client_send_data_to_buffer(server_info, inpacket, len); } + + if (errno == EPIPE) { + shutdown(server_info->fd, SHUT_RDWR); + } return -1; } else if (send_len < len) { /* save remain data to buffer, and retry when EPOLLOUT is available */ @@ -1552,6 +1571,7 @@ static int _dns_client_send_query(struct dns_query_struct *query, char *doamin) /* add question */ dns_add_domain(packet, doamin, query->qtype, DNS_C_IN); + dns_set_OPT_payload_size(packet, 1024); /* encode packet */ encode_len = dns_encode(inpacket, DNS_IN_PACKSIZE, packet); if (encode_len <= 0) { diff --git a/src/dns_client.h b/src/dns_client.h index d4af9c6..2ef63b1 100644 --- a/src/dns_client.h +++ b/src/dns_client.h @@ -17,6 +17,7 @@ typedef enum dns_result_type { } dns_result_type; #define DNSSERVER_FLAG_BLACKLIST_IP (0x1 << 0) +#define DNSSERVER_FLAG_CHECK_EDNS (0x1 << 1) int dns_client_init(void); diff --git a/src/dns_conf.c b/src/dns_conf.c index 7d29f6b..9e82d47 100644 --- a/src/dns_conf.c +++ b/src/dns_conf.c @@ -52,6 +52,7 @@ int config_server(int argc, char *argv[], dns_server_type_t type, int default_po /* clang-format off */ static struct option long_options[] = { {"blacklist-ip", 0, 0, 'b'}, + {"check-edns", 0, 0, 'e'}, {0, 0, 0, 0} }; /* clang-format on */ @@ -73,6 +74,10 @@ int config_server(int argc, char *argv[], dns_server_type_t type, int default_po result_flag |= DNSSERVER_FLAG_BLACKLIST_IP; break; } + case 'e': { + result_flag |= DNSSERVER_FLAG_CHECK_EDNS; + break; + } } } diff --git a/src/dns_server.c b/src/dns_server.c index 7105a7f..b2e5742 100644 --- a/src/dns_server.c +++ b/src/dns_server.c @@ -551,6 +551,7 @@ void _dns_server_request_release(struct dns_request *request) free(addr_map); } pthread_mutex_destroy(&request->ip_map_lock); + _dns_server_client_release(request->client); memset(request, 0, sizeof(*request)); free(request); } @@ -903,7 +904,6 @@ static int dns_server_resolve_callback(char *domain, dns_result_type rtype, unsi void *user_ptr) { struct dns_request *request = user_ptr; - struct dns_server_conn *client = request->client; int ip_num = 0; if (request == NULL) { @@ -936,7 +936,6 @@ static int dns_server_resolve_callback(char *domain, dns_result_type rtype, unsi _dns_server_request_remove(request); } _dns_server_request_release(request); - _dns_server_client_release(client); } return 0; @@ -1944,6 +1943,13 @@ errout: void _dns_server_close_socket(void) { + struct dns_server_conn *client, *tmp; + + list_for_each_entry_safe(client, tmp, &server.client_list, list) + { + _dns_server_client_close(client); + } + if (server.udp_server.fd > 0) { close(server.udp_server.fd); server.udp_server.fd = 0; diff --git a/src/smartdns.c b/src/smartdns.c index b61c379..7fcdbb3 100644 --- a/src/smartdns.c +++ b/src/smartdns.c @@ -272,21 +272,36 @@ void smartdns_exit(void) dns_server_load_exit(); } -void sig_handle(int sig) +void sig_exit(int signo) { - switch (sig) { - case SIGINT: - dns_server_stop(); - return; - break; - default: - break; - } - tlog(TLOG_ERROR, "process exit with signal %d\n", sig); + dns_server_stop(); +} + +void sig_error_exit(int signo, siginfo_t *siginfo, void *context) +{ + tlog(TLOG_ERROR, "process exit with signal %d, code = %d, errno = %d, pid = %d, self = %d, addr = %p\n", signo, + siginfo->si_code, siginfo->si_errno, siginfo->si_pid, getpid(), siginfo->si_addr); sleep(1); _exit(0); } +int sig_list[] = {SIGSEGV, SIGABRT, SIGPIPE, SIGBUS, SIGILL, SIGFPE}; + +int sig_num = sizeof(sig_list) / sizeof(int); + +void reg_signal(void) +{ + struct sigaction act, old; + int i = 0; + act.sa_sigaction = sig_error_exit; + sigemptyset(&act.sa_mask); + act.sa_flags = SA_RESTART | SA_SIGINFO; + + for (i = 0; i < sig_num; i++) { + sigaction(sig_list[i], &act, &old); + } +} + int main(int argc, char *argv[]) { int ret; @@ -330,12 +345,7 @@ int main(int argc, char *argv[]) } if (signal_ignore == 0) { - signal(SIGABRT, sig_handle); - signal(SIGPIPE, SIG_IGN); - signal(SIGBUS, sig_handle); - signal(SIGSEGV, sig_handle); - signal(SIGILL, sig_handle); - signal(SIGFPE, sig_handle); + reg_signal(); } if (dns_server_load_conf(config_file) != 0) { @@ -351,7 +361,7 @@ int main(int argc, char *argv[]) goto errout; } - signal(SIGINT, sig_handle); + signal(SIGINT, sig_exit); atexit(smartdns_exit); return smartdns_run();