tls: refactor tls send recv code.
This commit is contained in:
108
src/dns_server.c
108
src/dns_server.c
@@ -188,6 +188,7 @@ struct dns_server_conn_tcp_client {
|
||||
struct dns_server_conn_tls_client {
|
||||
struct dns_server_conn_tcp_client tcp;
|
||||
SSL *ssl;
|
||||
int ssl_want_write;
|
||||
pthread_mutex_t ssl_lock;
|
||||
};
|
||||
|
||||
@@ -6135,16 +6136,16 @@ static int _dns_server_socket_ssl_send(struct dns_server_conn_tls_client *tls_cl
|
||||
ssl_ret = _ssl_get_error(tls_client, ret);
|
||||
switch (ssl_ret) {
|
||||
case SSL_ERROR_NONE:
|
||||
case SSL_ERROR_ZERO_RETURN:
|
||||
return 0;
|
||||
break;
|
||||
case SSL_ERROR_ZERO_RETURN:
|
||||
case SSL_ERROR_WANT_READ:
|
||||
errno = EAGAIN;
|
||||
ret = -1;
|
||||
ret = -SSL_ERROR_WANT_READ;
|
||||
break;
|
||||
case SSL_ERROR_WANT_WRITE:
|
||||
errno = EAGAIN;
|
||||
ret = -1;
|
||||
ret = -SSL_ERROR_WANT_WRITE;
|
||||
break;
|
||||
case SSL_ERROR_SSL:
|
||||
ssl_err = ERR_get_error();
|
||||
@@ -6184,7 +6185,7 @@ static int _dns_server_socket_ssl_recv(struct dns_server_conn_tls_client *tls_cl
|
||||
}
|
||||
|
||||
ret = _ssl_read(tls_client, buf, num);
|
||||
if (ret >= 0) {
|
||||
if (ret > 0) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
@@ -6196,11 +6197,11 @@ static int _dns_server_socket_ssl_recv(struct dns_server_conn_tls_client *tls_cl
|
||||
break;
|
||||
case SSL_ERROR_WANT_READ:
|
||||
errno = EAGAIN;
|
||||
ret = -1;
|
||||
ret = -SSL_ERROR_WANT_READ;
|
||||
break;
|
||||
case SSL_ERROR_WANT_WRITE:
|
||||
errno = EAGAIN;
|
||||
ret = -1;
|
||||
ret = -SSL_ERROR_WANT_WRITE;
|
||||
break;
|
||||
case SSL_ERROR_SSL:
|
||||
ssl_err = ERR_get_error();
|
||||
@@ -6214,7 +6215,14 @@ static int _dns_server_socket_ssl_recv(struct dns_server_conn_tls_client *tls_cl
|
||||
return 0;
|
||||
}
|
||||
|
||||
tlog(TLOG_INFO, "SSL read fail error no: %s(%lx), len: %d\n", ERR_reason_error_string(ssl_err), ssl_err, num);
|
||||
#ifdef SSL_R_UNEXPECTED_EOF_WHILE_READING
|
||||
if (ssl_reason == SSL_R_UNEXPECTED_EOF_WHILE_READING) {
|
||||
return 0;
|
||||
}
|
||||
#endif
|
||||
|
||||
tlog(TLOG_DEBUG, "SSL read fail error no: %s(%lx), reason: %d\n", ERR_reason_error_string(ssl_err), ssl_err,
|
||||
ssl_reason);
|
||||
errno = EFAULT;
|
||||
ret = -1;
|
||||
break;
|
||||
@@ -6223,9 +6231,6 @@ static int _dns_server_socket_ssl_recv(struct dns_server_conn_tls_client *tls_cl
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (errno != ECONNRESET) {
|
||||
tlog(TLOG_INFO, "SSL syscall failed, %s ", strerror(errno));
|
||||
}
|
||||
ret = -1;
|
||||
return ret;
|
||||
default:
|
||||
@@ -6237,13 +6242,46 @@ static int _dns_server_socket_ssl_recv(struct dns_server_conn_tls_client *tls_cl
|
||||
return ret;
|
||||
}
|
||||
|
||||
static int _dns_server_ssl_poll_event(struct dns_server_conn_tls_client *tls_client, int ssl_ret)
|
||||
{
|
||||
struct epoll_event fd_event;
|
||||
|
||||
memset(&fd_event, 0, sizeof(fd_event));
|
||||
|
||||
if (ssl_ret == SSL_ERROR_WANT_READ) {
|
||||
fd_event.events = EPOLLIN;
|
||||
} else if (ssl_ret == SSL_ERROR_WANT_WRITE) {
|
||||
fd_event.events = EPOLLOUT | EPOLLIN;
|
||||
} else {
|
||||
goto errout;
|
||||
}
|
||||
|
||||
fd_event.data.ptr = tls_client;
|
||||
if (epoll_ctl(server.epoll_fd, EPOLL_CTL_MOD, tls_client->tcp.head.fd, &fd_event) != 0) {
|
||||
tlog(TLOG_ERROR, "epoll ctl failed, %s", strerror(errno));
|
||||
goto errout;
|
||||
}
|
||||
|
||||
return 0;
|
||||
|
||||
errout:
|
||||
return -1;
|
||||
}
|
||||
|
||||
static int _dns_server_tcp_socket_send(struct dns_server_conn_tcp_client *tcp_client, void *data, int data_len)
|
||||
{
|
||||
if (tcp_client->head.type == DNS_CONN_TYPE_TCP_CLIENT) {
|
||||
return send(tcp_client->head.fd, data, data_len, MSG_NOSIGNAL);
|
||||
} else if (tcp_client->head.type == DNS_CONN_TYPE_TLS_CLIENT ||
|
||||
tcp_client->head.type == DNS_CONN_TYPE_HTTPS_CLIENT) {
|
||||
int ret = _dns_server_socket_ssl_send((struct dns_server_conn_tls_client *)tcp_client, data, data_len);
|
||||
struct dns_server_conn_tls_client *tls_client = (struct dns_server_conn_tls_client *)tcp_client;
|
||||
tls_client->ssl_want_write = 0;
|
||||
int ret = _dns_server_socket_ssl_send(tls_client, data, data_len);
|
||||
if (ret < 0 && errno == EAGAIN) {
|
||||
if (_dns_server_ssl_poll_event(tls_client, SSL_ERROR_WANT_WRITE) == 0) {
|
||||
errno = EAGAIN;
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
} else {
|
||||
return -1;
|
||||
@@ -6256,7 +6294,16 @@ static int _dns_server_tcp_socket_recv(struct dns_server_conn_tcp_client *tcp_cl
|
||||
return recv(tcp_client->head.fd, data, data_len, MSG_NOSIGNAL);
|
||||
} else if (tcp_client->head.type == DNS_CONN_TYPE_TLS_CLIENT ||
|
||||
tcp_client->head.type == DNS_CONN_TYPE_HTTPS_CLIENT) {
|
||||
return _dns_server_socket_ssl_recv((struct dns_server_conn_tls_client *)tcp_client, data, data_len);
|
||||
struct dns_server_conn_tls_client *tls_client = (struct dns_server_conn_tls_client *)tcp_client;
|
||||
int ret = _dns_server_socket_ssl_recv(tls_client, data, data_len);
|
||||
if (ret == -SSL_ERROR_WANT_WRITE && errno == EAGAIN) {
|
||||
if (_dns_server_ssl_poll_event(tls_client, SSL_ERROR_WANT_WRITE) == 0) {
|
||||
errno = EAGAIN;
|
||||
tls_client->ssl_want_write = 1;
|
||||
}
|
||||
}
|
||||
|
||||
return ret;
|
||||
} else {
|
||||
return -1;
|
||||
}
|
||||
@@ -6287,7 +6334,7 @@ static int _dns_server_tcp_recv(struct dns_server_conn_tcp_client *tcpclient)
|
||||
return RECV_ERROR_CLOSE;
|
||||
}
|
||||
|
||||
tlog(TLOG_ERROR, "recv failed, %s\n", strerror(errno));
|
||||
tlog(TLOG_DEBUG, "recv failed, %s\n", strerror(errno));
|
||||
return RECV_ERROR_FAIL;
|
||||
} else if (len == 0) {
|
||||
return RECV_ERROR_CLOSE;
|
||||
@@ -6451,10 +6498,22 @@ static int _dns_server_tcp_process_requests(struct dns_server_conn_tcp_client *t
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int _dns_server_tls_want_write(struct dns_server_conn_tcp_client *tcpclient)
|
||||
{
|
||||
if (tcpclient->head.type == DNS_CONN_TYPE_TLS_CLIENT || tcpclient->head.type == DNS_CONN_TYPE_HTTPS_CLIENT) {
|
||||
struct dns_server_conn_tls_client *tls_client = (struct dns_server_conn_tls_client *)tcpclient;
|
||||
if (tls_client->ssl_want_write == 1) {
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int _dns_server_tcp_send(struct dns_server_conn_tcp_client *tcpclient)
|
||||
{
|
||||
int len = 0;
|
||||
while (tcpclient->sndbuff.size > 0) {
|
||||
while (tcpclient->sndbuff.size > 0 || _dns_server_tls_want_write(tcpclient) == 1) {
|
||||
len = _dns_server_tcp_socket_send(tcpclient, tcpclient->sndbuff.buf, tcpclient->sndbuff.size);
|
||||
if (len < 0) {
|
||||
if (errno == EAGAIN) {
|
||||
@@ -6605,13 +6664,11 @@ static int _dns_server_process_tls(struct dns_server_conn_tls_client *tls_client
|
||||
if (ret <= 0) {
|
||||
memset(&fd_event, 0, sizeof(fd_event));
|
||||
ssl_ret = _ssl_get_error(tls_client, ret);
|
||||
if (ssl_ret == SSL_ERROR_WANT_READ) {
|
||||
fd_event.events = EPOLLIN;
|
||||
} else if (ssl_ret == SSL_ERROR_WANT_WRITE) {
|
||||
fd_event.events = EPOLLOUT | EPOLLIN;
|
||||
} else if (ssl_ret == SSL_ERROR_SYSCALL) {
|
||||
goto errout;
|
||||
} else {
|
||||
if (_dns_server_ssl_poll_event(tls_client, ssl_ret) == 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (ssl_ret != SSL_ERROR_SYSCALL) {
|
||||
unsigned long ssl_err = ERR_get_error();
|
||||
int ssl_reason = ERR_GET_REASON(ssl_err);
|
||||
char name[DNS_MAX_CNAME_LEN];
|
||||
@@ -6619,16 +6676,9 @@ static int _dns_server_process_tls(struct dns_server_conn_tls_client *tls_client
|
||||
get_host_by_addr(name, sizeof(name), (struct sockaddr *)&tls_client->tcp.addr),
|
||||
ERR_reason_error_string(ssl_err), ret, ssl_ret, ssl_reason);
|
||||
ret = 0;
|
||||
goto errout;
|
||||
}
|
||||
|
||||
fd_event.data.ptr = tls_client;
|
||||
if (epoll_ctl(server.epoll_fd, EPOLL_CTL_MOD, tls_client->tcp.head.fd, &fd_event) != 0) {
|
||||
tlog(TLOG_ERROR, "epoll ctl failed, %s", strerror(errno));
|
||||
goto errout;
|
||||
}
|
||||
|
||||
return 0;
|
||||
goto errout;
|
||||
}
|
||||
|
||||
tls_client->tcp.status = DNS_SERVER_CLIENT_STATUS_CONNECTED;
|
||||
|
||||
Reference in New Issue
Block a user