Support TLS query.

This commit is contained in:
Nick Peng
2018-10-14 22:52:45 +08:00
parent f68e4eda1c
commit 4e92267f24
9 changed files with 478 additions and 16 deletions

View File

@@ -42,6 +42,8 @@
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
#include <openssl/ssl.h>
#include <openssl/err.h>
#define DNS_MAX_HOSTNAME 256
#define DNS_MAX_EVENTS 64
@@ -73,6 +75,14 @@ struct dns_server_buff {
unsigned short len;
};
typedef enum dns_server_status {
DNS_SERVER_STATUS_INIT = 0,
DNS_SERVER_STATUS_CONNECTING,
DNS_SERVER_STATUS_CONNECTIONLESS,
DNS_SERVER_STATUS_CONNECTED,
DNS_SERVER_STATUS_DISCONNECTED,
} dns_server_status;
/* dns server information */
struct dns_server_info {
struct list_head list;
@@ -84,6 +94,9 @@ struct dns_server_info {
/* client socket */
int fd;
SSL *ssl;
SSL_CTX *ssl_ctx;
dns_server_status status;
struct dns_server_buff send_buff;
struct dns_server_buff recv_buff;
@@ -142,6 +155,7 @@ struct dns_query_struct {
static struct dns_client client;
static atomic_t dns_client_sid = ATOMIC_INIT(0);
/* get addr info */
static struct addrinfo *_dns_client_getaddr(const char *host, char *port, int type, int protocol)
{
@@ -214,6 +228,7 @@ int _dns_client_server_add(char *server_ip, struct addrinfo *gai, dns_server_typ
server_info->ai_addrlen = gai->ai_addrlen;
server_info->type = server_type;
server_info->fd = 0;
server_info->status = DNS_SERVER_STATUS_INIT;
if (gai->ai_addrlen > sizeof(server_info->in6)) {
tlog(TLOG_ERROR, "addr len invalid, %d, %zd, %d", gai->ai_addrlen, sizeof(server_info->addr), server_info->ai_family);
@@ -253,9 +268,22 @@ static void _dns_client_close_socket(struct dns_server_info *server_info)
if (server_info->fd <= 0) {
return;
}
if (server_info->ssl) {
SSL_shutdown(server_info->ssl);
SSL_free(server_info->ssl);
server_info->ssl = NULL;
}
if (server_info->ssl_ctx) {
SSL_CTX_free(server_info->ssl_ctx);
server_info->ssl_ctx = NULL;
}
epoll_ctl(client.epoll_fd, EPOLL_CTL_DEL, server_info->fd, NULL);
close(server_info->fd);
server_info->fd = -1;
server_info->status = DNS_SERVER_STATUS_DISCONNECTED;
}
/* remove all servers information */
@@ -320,10 +348,17 @@ int _dns_client_server_operate(char *server_ip, int port, dns_server_type_t serv
return -1;
}
if (server_type == DNS_SERVER_UDP) {
switch (server_type) {
case DNS_SERVER_UDP:
sock_type = SOCK_DGRAM;
} else {
break;
case DNS_SERVER_TLS:
case DNS_SERVER_TCP:
sock_type = SOCK_STREAM;
break;
default:
return -1;
break;
}
/* get addr info */
@@ -641,6 +676,7 @@ static int _dns_client_create_socket_udp(struct dns_server_info *server_info)
}
server_info->fd = fd;
server_info->status = DNS_SERVER_STATUS_CONNECTIONLESS;
return 0;
errout:
@@ -683,6 +719,7 @@ static int _DNS_client_create_socket_tcp(struct dns_server_info *server_info)
}
server_info->fd = fd;
server_info->status = DNS_SERVER_STATUS_CONNECTING;
return 0;
errout:
@@ -693,12 +730,91 @@ errout:
return -1;
}
static int _DNS_client_create_socket_tls(struct dns_server_info *server_info)
{
int fd = 0;
struct epoll_event event;
SSL_CTX *ctx = NULL;
SSL *ssl = NULL;
ctx = SSL_CTX_new(SSLv23_client_method());
if (ctx == NULL) {
tlog(TLOG_ERROR, "create ssl ctx failed.");
goto errout;
}
ssl = SSL_new(ctx);
if (ssl == NULL) {
tlog(TLOG_ERROR, "new ssl failed.");
goto errout;
}
fd = socket(server_info->ai_family, SOCK_STREAM, 0);
if (fd < 0) {
tlog(TLOG_ERROR, "create socket failed.");
goto errout;
}
if (set_fd_nonblock(fd, 1) != 0) {
tlog(TLOG_ERROR, "set socket non block failed, %s", strerror(errno));
goto errout;
}
if (connect(fd, (struct sockaddr *)&server_info->addr, server_info->ai_addrlen) != 0) {
if (errno != EINPROGRESS) {
tlog(TLOG_ERROR, "connect failed.");
goto errout;
}
}
if(SSL_set_fd(ssl, fd) == 0) {
tlog(TLOG_ERROR, "ssl set fd failed.");
goto errout;
}
memset(&event, 0, sizeof(event));
event.events = EPOLLIN | EPOLLOUT;
event.data.ptr = server_info;
if (epoll_ctl(client.epoll_fd, EPOLL_CTL_ADD, fd, &event) != 0) {
tlog(TLOG_ERROR, "epoll ctl failed.");
goto errout;
}
server_info->fd = fd;
server_info->ssl = ssl;
server_info->ssl_ctx = ctx;
server_info->status = DNS_SERVER_STATUS_CONNECTING;
tlog(TLOG_DEBUG, "TLS server connecting.\n");
return 0;
errout:
if (fd > 0) {
close(fd);
}
if (ssl) {
SSL_free(ssl);
}
if (ctx) {
SSL_CTX_free(ctx);
}
return -1;
}
static int _dns_client_create_socket(struct dns_server_info *server_info)
{
time(&server_info->last_send);
time(&server_info->last_recv);
if (server_info->type == DNS_SERVER_UDP) {
return _dns_client_create_socket_udp(server_info);
} else if (server_info->type == DNS_SERVER_TCP) {
return _DNS_client_create_socket_tcp(server_info);
} else if (server_info->type == DNS_SERVER_TLS) {
return _DNS_client_create_socket_tls(server_info);
} else {
return -1;
}
@@ -741,6 +857,10 @@ static int _dns_client_process_tcp(struct dns_server_info *server_info, struct e
/* when connected */
if (event->events & EPOLLOUT) {
struct epoll_event event;
if (server_info->status != DNS_SERVER_STATUS_CONNECTED) {
server_info->status = DNS_SERVER_STATUS_DISCONNECTED;
}
pthread_mutex_lock(&client.server_list_lock);
if (server_info->send_buff.len > 0) {
/* send data in send_buffer */
@@ -861,6 +981,279 @@ errout:
return -1;
}
static int _dns_client_socket_send(SSL *ssl, const void *buf, int num)
{
int ret = 0;
int ssl_ret = 0;
unsigned long ssl_err = 0;
if (ssl == NULL) {
return -1;
}
ret = SSL_write(ssl, buf, num);
if (ret > 0) {
return ret;
}
ssl_ret = SSL_get_error(ssl, ret);
switch (ssl_ret) {
case SSL_ERROR_NONE:
errno = EAGAIN;
return -1;
break;
case SSL_ERROR_ZERO_RETURN:
return 0;
break;
case SSL_ERROR_WANT_READ:
errno = EAGAIN;
ret = -1;
break;
case SSL_ERROR_WANT_WRITE:
errno = EAGAIN;
ret = -1;
break;
case SSL_ERROR_SSL:
ssl_err = ERR_get_error();
if (ERR_GET_REASON(ssl_err) == SSL_R_UNINITIALIZED) {
errno = EAGAIN;
return -1;
}
tlog(TLOG_ERROR, "SSL write fail error no: %s(%ld)\n", ERR_reason_error_string(ssl_err), ssl_err);
errno = EFAULT;
ret = -1;
break;
case SSL_ERROR_SYSCALL:
tlog(TLOG_ERROR, "SSL syscall failed, %s", strerror(errno));
return ret;
default:
errno = EFAULT;
ret = -1;
break;
}
return ret;
}
static int _dns_client_socket_recv(SSL *ssl, void *buf, int num)
{
int ret = 0;
int ssl_ret = 0;
unsigned long ssl_err = 0;
ret = SSL_read(ssl, buf, num);
if (ret > 0) {
return ret;
}
ssl_ret = SSL_get_error(ssl, ret);
switch (ssl_ret) {
case SSL_ERROR_NONE:
errno = EAGAIN;
return -1;
break;
case SSL_ERROR_ZERO_RETURN:
return 0;
break;
case SSL_ERROR_WANT_READ:
errno = EAGAIN;
ret = -1;
break;
case SSL_ERROR_WANT_WRITE:
errno = EAGAIN;
ret = -1;
break;
case SSL_ERROR_SSL:
ssl_err = ERR_get_error();
if (ERR_GET_REASON(ssl_err) == SSL_R_UNINITIALIZED) {
errno = EAGAIN;
return -1;
}
tlog(TLOG_ERROR, "SSL read fail error no: %s(%ld)\n", ERR_reason_error_string(ssl_err), ssl_err);
errno = EFAULT;
ret = -1;
break;
case SSL_ERROR_SYSCALL:
tlog(TLOG_ERROR, "SSL syscall failed, %s", strerror(errno));
return ret;
default:
errno = EFAULT;
ret = -1;
break;
}
return ret;
}
static int _dns_client_process_tls(struct dns_server_info *server_info, struct epoll_event *event, unsigned long now)
{
int len;
int ret = -1;
unsigned char *inpacket_data = server_info->recv_buff.data;
char from_host[DNS_MAX_CNAME_LEN];
struct epoll_event fd_event;
int ssl_ret;
if (server_info->status == DNS_SERVER_STATUS_CONNECTING) {
ret = SSL_connect(server_info->ssl);
if (ret == 0) {
goto errout;
} else if (ret < 0) {
memset(&fd_event, 0, sizeof(fd_event));
ssl_ret = SSL_get_error(server_info->ssl, ret);
if (ssl_ret == SSL_ERROR_WANT_READ) {
fd_event.events = EPOLLIN;
} else if (ssl_ret == SSL_ERROR_WANT_WRITE) {
fd_event.events = EPOLLOUT;
} else {
goto errout;
}
fd_event.data.ptr = server_info;
if (epoll_ctl(client.epoll_fd, EPOLL_CTL_MOD, server_info->fd, &fd_event) != 0) {
tlog(TLOG_ERROR, "epoll ctl failed.");
goto errout;
}
return 0;
}
tlog(TLOG_DEBUG, "TLS server connected.\n");
server_info->status = DNS_SERVER_STATUS_CONNECTED;
memset(&fd_event, 0, sizeof(fd_event));
fd_event.events = EPOLLIN | EPOLLOUT;
fd_event.data.ptr = server_info;
if (epoll_ctl(client.epoll_fd, EPOLL_CTL_MOD, server_info->fd, &fd_event) != 0) {
tlog(TLOG_ERROR, "epoll ctl failed.");
goto errout;
}
}
/* when connected */
if (event->events & EPOLLOUT) {
pthread_mutex_lock(&client.server_list_lock);
if (server_info->send_buff.len > 0) {
/* send data in send_buffer */
len = _dns_client_socket_send(server_info->ssl, server_info->send_buff.data, server_info->send_buff.len);
if (len < 0) {
pthread_mutex_unlock(&client.server_list_lock);
goto errout;
}
server_info->send_buff.len -= len;
if (server_info->send_buff.len > 0) {
memmove(server_info->send_buff.data, server_info->send_buff.data + len, server_info->send_buff.len);
}
}
pthread_mutex_unlock(&client.server_list_lock);
/* still remain data, retry */
if (server_info->send_buff.len > 0) {
return 0;
}
/* clear epllout event */
memset(&fd_event, 0, sizeof(fd_event));
fd_event.events = EPOLLIN;
fd_event.data.ptr = server_info;
if (epoll_ctl(client.epoll_fd, EPOLL_CTL_MOD, server_info->fd, &fd_event) != 0) {
tlog(TLOG_ERROR, "epoll ctl failed.");
return -1;
}
return 0;
}
/* receive from tcp */
len = _dns_client_socket_recv(server_info->ssl, server_info->recv_buff.data + server_info->recv_buff.len, DNS_TCP_BUFFER - server_info->recv_buff.len);
if (len < 0) {
/* no data to recv, try again */
if (errno == EAGAIN) {
return 0;
}
/* FOR GFW */
if (errno == ECONNRESET) {
goto errout;
}
tlog(TLOG_ERROR, "recv failed, %s, %d\n", strerror(errno), errno);
goto errout;
}
/* peer server close */
if (len == 0) {
pthread_mutex_lock(&client.server_list_lock);
_dns_client_close_socket(server_info);
server_info->recv_buff.len = 0;
if (server_info->send_buff.len > 0) {
/* still remain request data, reconnect and send*/
ret = _dns_client_create_socket(server_info);
} else {
ret = 0;
}
pthread_mutex_unlock(&client.server_list_lock);
tlog(TLOG_DEBUG, "peer close, left = %d", server_info->send_buff.len);
return ret;
}
time(&server_info->last_recv);
server_info->recv_buff.len += len;
if (server_info->recv_buff.len < 2) {
/* wait and recv */
return 0;
}
while (1) {
/* tcp result format
* | len (short) | dns query result |
*/
inpacket_data = server_info->recv_buff.data;
len = ntohs(*((unsigned short *)(inpacket_data)));
if (len <= 0 || len >= DNS_IN_PACKSIZE) {
/* data len is invalid */
goto errout;
}
if (len > server_info->recv_buff.len - 2) {
/* len is not expceded, wait and recv */
break;
}
inpacket_data = server_info->recv_buff.data + 2;
tlog(TLOG_DEBUG, "recv tcp from %s, len = %d", gethost_by_addr(from_host, (struct sockaddr *)&server_info->addr, server_info->ai_addrlen), len);
/* process result */
if (_dns_client_recv(inpacket_data, len, &server_info->addr, server_info->ai_addrlen) != 0) {
goto errout;
}
len += 2;
server_info->recv_buff.len -= len;
/* move to next result */
if (server_info->recv_buff.len > 0) {
memmove(server_info->recv_buff.data, server_info->recv_buff.data + len, server_info->recv_buff.len);
} else {
break;
}
}
return 0;
errout:
pthread_mutex_lock(&client.server_list_lock);
server_info->recv_buff.len = 0;
server_info->send_buff.len = 0;
_dns_client_close_socket(server_info);
pthread_mutex_unlock(&client.server_list_lock);
return -1;
}
static int _dns_client_process(struct dns_server_info *server_info, struct epoll_event *event, unsigned long now)
{
if (server_info->type == DNS_SERVER_UDP) {
@@ -869,6 +1262,9 @@ static int _dns_client_process(struct dns_server_info *server_info, struct epoll
} else if (server_info->type == DNS_SERVER_TCP) {
/* receive from tcp */
return _dns_client_process_tcp(server_info, event, now);
} else if (server_info->type == DNS_SERVER_TLS) {
/* recive from tls */
return _dns_client_process_tls(server_info, event, now);
} else {
return -1;
}
@@ -981,6 +1377,34 @@ static int _dns_client_send_tcp(struct dns_server_info *server_info, void *packe
return 0;
}
static int _dns_client_send_tls(struct dns_server_info *server_info, void *packet, unsigned short len)
{
int send_len = 0;
unsigned char inpacket_data[DNS_IN_PACKSIZE];
unsigned char *inpacket = inpacket_data;
/* TCP query format
* | len (short) | dns query data |
*/
*((unsigned short *)(inpacket)) = htons(len);
memcpy(inpacket + 2, packet, len);
len += 2;
send_len = _dns_client_socket_send(server_info->ssl, inpacket, len);
if (send_len < 0) {
if (errno == EAGAIN || server_info->ssl == NULL) {
/* save data to buffer, and retry when EPOLLOUT is available */
return _dns_client_send_data_to_buffer(server_info, inpacket, len);
}
return -1;
} else if (send_len < len) {
/* save remain data to buffer, and retry when EPOLLOUT is available */
return _dns_client_send_data_to_buffer(server_info, inpacket + send_len, len - send_len);
}
return 0;
}
static int _dns_client_send_packet(struct dns_query_struct *query, void *packet, int len)
{
struct dns_server_info *server_info, *tmp;
@@ -1012,6 +1436,11 @@ static int _dns_client_send_packet(struct dns_query_struct *query, void *packet,
ret = _dns_client_send_tcp(server_info, packet, len);
send_err = errno;
break;
case DNS_SERVER_TLS:
/* tls query */
ret = _dns_client_send_tls(server_info, packet, len);
send_err = errno;
break;
default:
/* unsupport query type */
ret = -1;