tls: refactor tls send recv code.

This commit is contained in:
Nick Peng
2023-12-08 23:12:29 +08:00
parent 7b1ea2c43d
commit c4bffbb1dd
2 changed files with 145 additions and 62 deletions

View File

@@ -107,6 +107,7 @@ struct dns_server_info {
int ttl_range;
SSL *ssl;
int ssl_write_len;
int ssl_want_write;
SSL_CTX *ssl_ctx;
SSL_SESSION *ssl_session;
@@ -2374,16 +2375,16 @@ static int _dns_client_socket_ssl_send(struct dns_server_info *server, const voi
ssl_ret = _ssl_get_error(server, ret);
switch (ssl_ret) {
case SSL_ERROR_NONE:
case SSL_ERROR_ZERO_RETURN:
return 0;
break;
case SSL_ERROR_ZERO_RETURN:
case SSL_ERROR_WANT_READ:
errno = EAGAIN;
ret = -1;
ret = -SSL_ERROR_WANT_READ;
break;
case SSL_ERROR_WANT_WRITE:
errno = EAGAIN;
ret = -1;
ret = -SSL_ERROR_WANT_WRITE;
break;
case SSL_ERROR_SSL:
ssl_err = ERR_get_error();
@@ -2423,7 +2424,7 @@ static int _dns_client_socket_ssl_recv(struct dns_server_info *server, void *buf
}
ret = _ssl_read(server, buf, num);
if (ret >= 0) {
if (ret > 0) {
return ret;
}
@@ -2435,11 +2436,11 @@ static int _dns_client_socket_ssl_recv(struct dns_server_info *server, void *buf
break;
case SSL_ERROR_WANT_READ:
errno = EAGAIN;
ret = -1;
ret = -SSL_ERROR_WANT_READ;
break;
case SSL_ERROR_WANT_WRITE:
errno = EAGAIN;
ret = -1;
ret = -SSL_ERROR_WANT_WRITE;
break;
case SSL_ERROR_SSL:
ssl_err = ERR_get_error();
@@ -2453,7 +2454,13 @@ static int _dns_client_socket_ssl_recv(struct dns_server_info *server, void *buf
return 0;
}
tlog(TLOG_INFO, "SSL read fail error no: %s(%lx), len: %d\n", ERR_reason_error_string(ssl_err), ssl_err, num);
#ifdef SSL_R_UNEXPECTED_EOF_WHILE_READING
if (ssl_reason == SSL_R_UNEXPECTED_EOF_WHILE_READING) {
return 0;
}
#endif
tlog(TLOG_WARN, "SSL read fail error no: %s(%lx), reason: %d\n", ERR_reason_error_string(ssl_err), ssl_err, ssl_reason);
errno = EFAULT;
ret = -1;
break;
@@ -2462,9 +2469,6 @@ static int _dns_client_socket_ssl_recv(struct dns_server_info *server, void *buf
return 0;
}
if (errno != ECONNRESET) {
tlog(TLOG_INFO, "SSL syscall failed, %s ", strerror(errno));
}
ret = -1;
return ret;
default:
@@ -2476,6 +2480,32 @@ static int _dns_client_socket_ssl_recv(struct dns_server_info *server, void *buf
return ret;
}
static int _dns_client_ssl_poll_event(struct dns_server_info *server_info, int ssl_ret)
{
struct epoll_event fd_event;
memset(&fd_event, 0, sizeof(fd_event));
if (ssl_ret == SSL_ERROR_WANT_READ) {
fd_event.events = EPOLLIN;
} else if (ssl_ret == SSL_ERROR_WANT_WRITE) {
fd_event.events = EPOLLOUT | EPOLLIN;
} else {
goto errout;
}
fd_event.data.ptr = server_info;
if (epoll_ctl(client.epoll_fd, EPOLL_CTL_MOD, server_info->fd, &fd_event) != 0) {
tlog(TLOG_ERROR, "epoll ctl failed, %s", strerror(errno));
goto errout;
}
return 0;
errout:
return -1;
}
static int _dns_client_socket_send(struct dns_server_info *server_info)
{
if (server_info->type == DNS_SERVER_UDP) {
@@ -2488,10 +2518,13 @@ static int _dns_client_socket_send(struct dns_server_info *server_info)
write_len = server_info->ssl_write_len;
server_info->ssl_write_len = -1;
}
server_info->ssl_want_write = 0;
int ret = _dns_client_socket_ssl_send(server_info, server_info->send_buff.data, write_len);
if (ret != 0) {
if (errno == EAGAIN) {
server_info->ssl_write_len = write_len;
if (ret < 0 && errno == EAGAIN) {
server_info->ssl_write_len = write_len;
if (_dns_client_ssl_poll_event(server_info, SSL_ERROR_WANT_WRITE) == 0) {
errno = EAGAIN;
}
}
return ret;
@@ -2508,8 +2541,16 @@ static int _dns_client_socket_recv(struct dns_server_info *server_info)
return recv(server_info->fd, server_info->recv_buff.data + server_info->recv_buff.len,
DNS_TCP_BUFFER - server_info->recv_buff.len, 0);
} else if (server_info->type == DNS_SERVER_TLS || server_info->type == DNS_SERVER_HTTPS) {
return _dns_client_socket_ssl_recv(server_info, server_info->recv_buff.data + server_info->recv_buff.len,
DNS_TCP_BUFFER - server_info->recv_buff.len);
int ret = _dns_client_socket_ssl_recv(server_info, server_info->recv_buff.data + server_info->recv_buff.len,
DNS_TCP_BUFFER - server_info->recv_buff.len);
if (ret == -SSL_ERROR_WANT_WRITE && errno == EAGAIN) {
if (_dns_client_ssl_poll_event(server_info, SSL_ERROR_WANT_WRITE) == 0) {
errno = EAGAIN;
server_info->ssl_want_write = 1;
}
}
return ret;
} else {
return -1;
}
@@ -2632,7 +2673,7 @@ static int _dns_client_process_tcp(struct dns_server_info *server_info, struct e
goto errout;
}
tlog(TLOG_ERROR, "recv failed, server %s:%d, %s\n", server_info->ip, server_info->port, strerror(errno));
tlog(TLOG_WARN, "recv failed, server %s:%d, %s\n", server_info->ip, server_info->port, strerror(errno));
goto errout;
}
@@ -2674,7 +2715,7 @@ static int _dns_client_process_tcp(struct dns_server_info *server_info, struct e
server_info->status = DNS_SERVER_STATUS_DISCONNECTED;
}
if (server_info->send_buff.len > 0) {
if (server_info->send_buff.len > 0 || server_info->ssl_want_write == 1) {
/* send existing send_buffer data */
len = _dns_client_socket_send(server_info);
if (len < 0) {
@@ -2972,16 +3013,11 @@ static int _dns_client_process_tls(struct dns_server_info *server_info, struct e
if (ret <= 0) {
memset(&fd_event, 0, sizeof(fd_event));
ssl_ret = _ssl_get_error(server_info, ret);
if (ssl_ret == SSL_ERROR_WANT_READ) {
fd_event.events = EPOLLIN;
} else if (ssl_ret == SSL_ERROR_WANT_WRITE) {
fd_event.events = EPOLLOUT | EPOLLIN;
} else if (ssl_ret == SSL_ERROR_SYSCALL) {
if (errno != ENETUNREACH) {
tlog(TLOG_WARN, "Handshake with %s failed, %s", server_info->ip, strerror(errno));
}
goto errout;
} else {
if (_dns_client_ssl_poll_event(server_info, ssl_ret) == 0) {
return 0;
}
if (ssl_ret != SSL_ERROR_SYSCALL) {
unsigned long ssl_err = ERR_get_error();
int ssl_reason = ERR_GET_REASON(ssl_err);
tlog(TLOG_WARN, "Handshake with %s failed, error no: %s(%d, %d, %d)\n", server_info->ip,
@@ -2989,13 +3025,10 @@ static int _dns_client_process_tls(struct dns_server_info *server_info, struct e
goto errout;
}
fd_event.data.ptr = server_info;
if (epoll_ctl(client.epoll_fd, EPOLL_CTL_MOD, server_info->fd, &fd_event) != 0) {
tlog(TLOG_ERROR, "epoll ctl failed, %s", strerror(errno));
goto errout;
if (errno != ENETUNREACH) {
tlog(TLOG_WARN, "Handshake with %s failed, %s", server_info->ip, strerror(errno));
}
return 0;
goto errout;
}
tlog(TLOG_DEBUG, "tls server %s connected.\n", server_info->ip);

View File

@@ -188,6 +188,7 @@ struct dns_server_conn_tcp_client {
struct dns_server_conn_tls_client {
struct dns_server_conn_tcp_client tcp;
SSL *ssl;
int ssl_want_write;
pthread_mutex_t ssl_lock;
};
@@ -6135,16 +6136,16 @@ static int _dns_server_socket_ssl_send(struct dns_server_conn_tls_client *tls_cl
ssl_ret = _ssl_get_error(tls_client, ret);
switch (ssl_ret) {
case SSL_ERROR_NONE:
case SSL_ERROR_ZERO_RETURN:
return 0;
break;
case SSL_ERROR_ZERO_RETURN:
case SSL_ERROR_WANT_READ:
errno = EAGAIN;
ret = -1;
ret = -SSL_ERROR_WANT_READ;
break;
case SSL_ERROR_WANT_WRITE:
errno = EAGAIN;
ret = -1;
ret = -SSL_ERROR_WANT_WRITE;
break;
case SSL_ERROR_SSL:
ssl_err = ERR_get_error();
@@ -6184,7 +6185,7 @@ static int _dns_server_socket_ssl_recv(struct dns_server_conn_tls_client *tls_cl
}
ret = _ssl_read(tls_client, buf, num);
if (ret >= 0) {
if (ret > 0) {
return ret;
}
@@ -6196,11 +6197,11 @@ static int _dns_server_socket_ssl_recv(struct dns_server_conn_tls_client *tls_cl
break;
case SSL_ERROR_WANT_READ:
errno = EAGAIN;
ret = -1;
ret = -SSL_ERROR_WANT_READ;
break;
case SSL_ERROR_WANT_WRITE:
errno = EAGAIN;
ret = -1;
ret = -SSL_ERROR_WANT_WRITE;
break;
case SSL_ERROR_SSL:
ssl_err = ERR_get_error();
@@ -6214,7 +6215,14 @@ static int _dns_server_socket_ssl_recv(struct dns_server_conn_tls_client *tls_cl
return 0;
}
tlog(TLOG_INFO, "SSL read fail error no: %s(%lx), len: %d\n", ERR_reason_error_string(ssl_err), ssl_err, num);
#ifdef SSL_R_UNEXPECTED_EOF_WHILE_READING
if (ssl_reason == SSL_R_UNEXPECTED_EOF_WHILE_READING) {
return 0;
}
#endif
tlog(TLOG_DEBUG, "SSL read fail error no: %s(%lx), reason: %d\n", ERR_reason_error_string(ssl_err), ssl_err,
ssl_reason);
errno = EFAULT;
ret = -1;
break;
@@ -6223,9 +6231,6 @@ static int _dns_server_socket_ssl_recv(struct dns_server_conn_tls_client *tls_cl
return 0;
}
if (errno != ECONNRESET) {
tlog(TLOG_INFO, "SSL syscall failed, %s ", strerror(errno));
}
ret = -1;
return ret;
default:
@@ -6237,13 +6242,46 @@ static int _dns_server_socket_ssl_recv(struct dns_server_conn_tls_client *tls_cl
return ret;
}
static int _dns_server_ssl_poll_event(struct dns_server_conn_tls_client *tls_client, int ssl_ret)
{
struct epoll_event fd_event;
memset(&fd_event, 0, sizeof(fd_event));
if (ssl_ret == SSL_ERROR_WANT_READ) {
fd_event.events = EPOLLIN;
} else if (ssl_ret == SSL_ERROR_WANT_WRITE) {
fd_event.events = EPOLLOUT | EPOLLIN;
} else {
goto errout;
}
fd_event.data.ptr = tls_client;
if (epoll_ctl(server.epoll_fd, EPOLL_CTL_MOD, tls_client->tcp.head.fd, &fd_event) != 0) {
tlog(TLOG_ERROR, "epoll ctl failed, %s", strerror(errno));
goto errout;
}
return 0;
errout:
return -1;
}
static int _dns_server_tcp_socket_send(struct dns_server_conn_tcp_client *tcp_client, void *data, int data_len)
{
if (tcp_client->head.type == DNS_CONN_TYPE_TCP_CLIENT) {
return send(tcp_client->head.fd, data, data_len, MSG_NOSIGNAL);
} else if (tcp_client->head.type == DNS_CONN_TYPE_TLS_CLIENT ||
tcp_client->head.type == DNS_CONN_TYPE_HTTPS_CLIENT) {
int ret = _dns_server_socket_ssl_send((struct dns_server_conn_tls_client *)tcp_client, data, data_len);
struct dns_server_conn_tls_client *tls_client = (struct dns_server_conn_tls_client *)tcp_client;
tls_client->ssl_want_write = 0;
int ret = _dns_server_socket_ssl_send(tls_client, data, data_len);
if (ret < 0 && errno == EAGAIN) {
if (_dns_server_ssl_poll_event(tls_client, SSL_ERROR_WANT_WRITE) == 0) {
errno = EAGAIN;
}
}
return ret;
} else {
return -1;
@@ -6256,7 +6294,16 @@ static int _dns_server_tcp_socket_recv(struct dns_server_conn_tcp_client *tcp_cl
return recv(tcp_client->head.fd, data, data_len, MSG_NOSIGNAL);
} else if (tcp_client->head.type == DNS_CONN_TYPE_TLS_CLIENT ||
tcp_client->head.type == DNS_CONN_TYPE_HTTPS_CLIENT) {
return _dns_server_socket_ssl_recv((struct dns_server_conn_tls_client *)tcp_client, data, data_len);
struct dns_server_conn_tls_client *tls_client = (struct dns_server_conn_tls_client *)tcp_client;
int ret = _dns_server_socket_ssl_recv(tls_client, data, data_len);
if (ret == -SSL_ERROR_WANT_WRITE && errno == EAGAIN) {
if (_dns_server_ssl_poll_event(tls_client, SSL_ERROR_WANT_WRITE) == 0) {
errno = EAGAIN;
tls_client->ssl_want_write = 1;
}
}
return ret;
} else {
return -1;
}
@@ -6287,7 +6334,7 @@ static int _dns_server_tcp_recv(struct dns_server_conn_tcp_client *tcpclient)
return RECV_ERROR_CLOSE;
}
tlog(TLOG_ERROR, "recv failed, %s\n", strerror(errno));
tlog(TLOG_DEBUG, "recv failed, %s\n", strerror(errno));
return RECV_ERROR_FAIL;
} else if (len == 0) {
return RECV_ERROR_CLOSE;
@@ -6451,10 +6498,22 @@ static int _dns_server_tcp_process_requests(struct dns_server_conn_tcp_client *t
return 0;
}
static int _dns_server_tls_want_write(struct dns_server_conn_tcp_client *tcpclient)
{
if (tcpclient->head.type == DNS_CONN_TYPE_TLS_CLIENT || tcpclient->head.type == DNS_CONN_TYPE_HTTPS_CLIENT) {
struct dns_server_conn_tls_client *tls_client = (struct dns_server_conn_tls_client *)tcpclient;
if (tls_client->ssl_want_write == 1) {
return 1;
}
}
return 0;
}
static int _dns_server_tcp_send(struct dns_server_conn_tcp_client *tcpclient)
{
int len = 0;
while (tcpclient->sndbuff.size > 0) {
while (tcpclient->sndbuff.size > 0 || _dns_server_tls_want_write(tcpclient) == 1) {
len = _dns_server_tcp_socket_send(tcpclient, tcpclient->sndbuff.buf, tcpclient->sndbuff.size);
if (len < 0) {
if (errno == EAGAIN) {
@@ -6605,13 +6664,11 @@ static int _dns_server_process_tls(struct dns_server_conn_tls_client *tls_client
if (ret <= 0) {
memset(&fd_event, 0, sizeof(fd_event));
ssl_ret = _ssl_get_error(tls_client, ret);
if (ssl_ret == SSL_ERROR_WANT_READ) {
fd_event.events = EPOLLIN;
} else if (ssl_ret == SSL_ERROR_WANT_WRITE) {
fd_event.events = EPOLLOUT | EPOLLIN;
} else if (ssl_ret == SSL_ERROR_SYSCALL) {
goto errout;
} else {
if (_dns_server_ssl_poll_event(tls_client, ssl_ret) == 0) {
return 0;
}
if (ssl_ret != SSL_ERROR_SYSCALL) {
unsigned long ssl_err = ERR_get_error();
int ssl_reason = ERR_GET_REASON(ssl_err);
char name[DNS_MAX_CNAME_LEN];
@@ -6619,16 +6676,9 @@ static int _dns_server_process_tls(struct dns_server_conn_tls_client *tls_client
get_host_by_addr(name, sizeof(name), (struct sockaddr *)&tls_client->tcp.addr),
ERR_reason_error_string(ssl_err), ret, ssl_ret, ssl_reason);
ret = 0;
goto errout;
}
fd_event.data.ptr = tls_client;
if (epoll_ctl(server.epoll_fd, EPOLL_CTL_MOD, tls_client->tcp.head.fd, &fd_event) != 0) {
tlog(TLOG_ERROR, "epoll ctl failed, %s", strerror(errno));
goto errout;
}
return 0;
goto errout;
}
tls_client->tcp.status = DNS_SERVER_CLIENT_STATUS_CONNECTED;