tls: refactor tls send recv code.

This commit is contained in:
Nick Peng
2023-12-08 23:12:29 +08:00
parent 7b1ea2c43d
commit c4bffbb1dd
2 changed files with 145 additions and 62 deletions

View File

@@ -107,6 +107,7 @@ struct dns_server_info {
int ttl_range; int ttl_range;
SSL *ssl; SSL *ssl;
int ssl_write_len; int ssl_write_len;
int ssl_want_write;
SSL_CTX *ssl_ctx; SSL_CTX *ssl_ctx;
SSL_SESSION *ssl_session; SSL_SESSION *ssl_session;
@@ -2374,16 +2375,16 @@ static int _dns_client_socket_ssl_send(struct dns_server_info *server, const voi
ssl_ret = _ssl_get_error(server, ret); ssl_ret = _ssl_get_error(server, ret);
switch (ssl_ret) { switch (ssl_ret) {
case SSL_ERROR_NONE: case SSL_ERROR_NONE:
case SSL_ERROR_ZERO_RETURN:
return 0; return 0;
break; break;
case SSL_ERROR_ZERO_RETURN:
case SSL_ERROR_WANT_READ: case SSL_ERROR_WANT_READ:
errno = EAGAIN; errno = EAGAIN;
ret = -1; ret = -SSL_ERROR_WANT_READ;
break; break;
case SSL_ERROR_WANT_WRITE: case SSL_ERROR_WANT_WRITE:
errno = EAGAIN; errno = EAGAIN;
ret = -1; ret = -SSL_ERROR_WANT_WRITE;
break; break;
case SSL_ERROR_SSL: case SSL_ERROR_SSL:
ssl_err = ERR_get_error(); ssl_err = ERR_get_error();
@@ -2423,7 +2424,7 @@ static int _dns_client_socket_ssl_recv(struct dns_server_info *server, void *buf
} }
ret = _ssl_read(server, buf, num); ret = _ssl_read(server, buf, num);
if (ret >= 0) { if (ret > 0) {
return ret; return ret;
} }
@@ -2435,11 +2436,11 @@ static int _dns_client_socket_ssl_recv(struct dns_server_info *server, void *buf
break; break;
case SSL_ERROR_WANT_READ: case SSL_ERROR_WANT_READ:
errno = EAGAIN; errno = EAGAIN;
ret = -1; ret = -SSL_ERROR_WANT_READ;
break; break;
case SSL_ERROR_WANT_WRITE: case SSL_ERROR_WANT_WRITE:
errno = EAGAIN; errno = EAGAIN;
ret = -1; ret = -SSL_ERROR_WANT_WRITE;
break; break;
case SSL_ERROR_SSL: case SSL_ERROR_SSL:
ssl_err = ERR_get_error(); ssl_err = ERR_get_error();
@@ -2453,7 +2454,13 @@ static int _dns_client_socket_ssl_recv(struct dns_server_info *server, void *buf
return 0; return 0;
} }
tlog(TLOG_INFO, "SSL read fail error no: %s(%lx), len: %d\n", ERR_reason_error_string(ssl_err), ssl_err, num); #ifdef SSL_R_UNEXPECTED_EOF_WHILE_READING
if (ssl_reason == SSL_R_UNEXPECTED_EOF_WHILE_READING) {
return 0;
}
#endif
tlog(TLOG_WARN, "SSL read fail error no: %s(%lx), reason: %d\n", ERR_reason_error_string(ssl_err), ssl_err, ssl_reason);
errno = EFAULT; errno = EFAULT;
ret = -1; ret = -1;
break; break;
@@ -2462,9 +2469,6 @@ static int _dns_client_socket_ssl_recv(struct dns_server_info *server, void *buf
return 0; return 0;
} }
if (errno != ECONNRESET) {
tlog(TLOG_INFO, "SSL syscall failed, %s ", strerror(errno));
}
ret = -1; ret = -1;
return ret; return ret;
default: default:
@@ -2476,6 +2480,32 @@ static int _dns_client_socket_ssl_recv(struct dns_server_info *server, void *buf
return ret; return ret;
} }
static int _dns_client_ssl_poll_event(struct dns_server_info *server_info, int ssl_ret)
{
struct epoll_event fd_event;
memset(&fd_event, 0, sizeof(fd_event));
if (ssl_ret == SSL_ERROR_WANT_READ) {
fd_event.events = EPOLLIN;
} else if (ssl_ret == SSL_ERROR_WANT_WRITE) {
fd_event.events = EPOLLOUT | EPOLLIN;
} else {
goto errout;
}
fd_event.data.ptr = server_info;
if (epoll_ctl(client.epoll_fd, EPOLL_CTL_MOD, server_info->fd, &fd_event) != 0) {
tlog(TLOG_ERROR, "epoll ctl failed, %s", strerror(errno));
goto errout;
}
return 0;
errout:
return -1;
}
static int _dns_client_socket_send(struct dns_server_info *server_info) static int _dns_client_socket_send(struct dns_server_info *server_info)
{ {
if (server_info->type == DNS_SERVER_UDP) { if (server_info->type == DNS_SERVER_UDP) {
@@ -2488,10 +2518,13 @@ static int _dns_client_socket_send(struct dns_server_info *server_info)
write_len = server_info->ssl_write_len; write_len = server_info->ssl_write_len;
server_info->ssl_write_len = -1; server_info->ssl_write_len = -1;
} }
server_info->ssl_want_write = 0;
int ret = _dns_client_socket_ssl_send(server_info, server_info->send_buff.data, write_len); int ret = _dns_client_socket_ssl_send(server_info, server_info->send_buff.data, write_len);
if (ret != 0) { if (ret < 0 && errno == EAGAIN) {
if (errno == EAGAIN) { server_info->ssl_write_len = write_len;
server_info->ssl_write_len = write_len; if (_dns_client_ssl_poll_event(server_info, SSL_ERROR_WANT_WRITE) == 0) {
errno = EAGAIN;
} }
} }
return ret; return ret;
@@ -2508,8 +2541,16 @@ static int _dns_client_socket_recv(struct dns_server_info *server_info)
return recv(server_info->fd, server_info->recv_buff.data + server_info->recv_buff.len, return recv(server_info->fd, server_info->recv_buff.data + server_info->recv_buff.len,
DNS_TCP_BUFFER - server_info->recv_buff.len, 0); DNS_TCP_BUFFER - server_info->recv_buff.len, 0);
} else if (server_info->type == DNS_SERVER_TLS || server_info->type == DNS_SERVER_HTTPS) { } else if (server_info->type == DNS_SERVER_TLS || server_info->type == DNS_SERVER_HTTPS) {
return _dns_client_socket_ssl_recv(server_info, server_info->recv_buff.data + server_info->recv_buff.len, int ret = _dns_client_socket_ssl_recv(server_info, server_info->recv_buff.data + server_info->recv_buff.len,
DNS_TCP_BUFFER - server_info->recv_buff.len); DNS_TCP_BUFFER - server_info->recv_buff.len);
if (ret == -SSL_ERROR_WANT_WRITE && errno == EAGAIN) {
if (_dns_client_ssl_poll_event(server_info, SSL_ERROR_WANT_WRITE) == 0) {
errno = EAGAIN;
server_info->ssl_want_write = 1;
}
}
return ret;
} else { } else {
return -1; return -1;
} }
@@ -2632,7 +2673,7 @@ static int _dns_client_process_tcp(struct dns_server_info *server_info, struct e
goto errout; goto errout;
} }
tlog(TLOG_ERROR, "recv failed, server %s:%d, %s\n", server_info->ip, server_info->port, strerror(errno)); tlog(TLOG_WARN, "recv failed, server %s:%d, %s\n", server_info->ip, server_info->port, strerror(errno));
goto errout; goto errout;
} }
@@ -2674,7 +2715,7 @@ static int _dns_client_process_tcp(struct dns_server_info *server_info, struct e
server_info->status = DNS_SERVER_STATUS_DISCONNECTED; server_info->status = DNS_SERVER_STATUS_DISCONNECTED;
} }
if (server_info->send_buff.len > 0) { if (server_info->send_buff.len > 0 || server_info->ssl_want_write == 1) {
/* send existing send_buffer data */ /* send existing send_buffer data */
len = _dns_client_socket_send(server_info); len = _dns_client_socket_send(server_info);
if (len < 0) { if (len < 0) {
@@ -2972,16 +3013,11 @@ static int _dns_client_process_tls(struct dns_server_info *server_info, struct e
if (ret <= 0) { if (ret <= 0) {
memset(&fd_event, 0, sizeof(fd_event)); memset(&fd_event, 0, sizeof(fd_event));
ssl_ret = _ssl_get_error(server_info, ret); ssl_ret = _ssl_get_error(server_info, ret);
if (ssl_ret == SSL_ERROR_WANT_READ) { if (_dns_client_ssl_poll_event(server_info, ssl_ret) == 0) {
fd_event.events = EPOLLIN; return 0;
} else if (ssl_ret == SSL_ERROR_WANT_WRITE) { }
fd_event.events = EPOLLOUT | EPOLLIN;
} else if (ssl_ret == SSL_ERROR_SYSCALL) { if (ssl_ret != SSL_ERROR_SYSCALL) {
if (errno != ENETUNREACH) {
tlog(TLOG_WARN, "Handshake with %s failed, %s", server_info->ip, strerror(errno));
}
goto errout;
} else {
unsigned long ssl_err = ERR_get_error(); unsigned long ssl_err = ERR_get_error();
int ssl_reason = ERR_GET_REASON(ssl_err); int ssl_reason = ERR_GET_REASON(ssl_err);
tlog(TLOG_WARN, "Handshake with %s failed, error no: %s(%d, %d, %d)\n", server_info->ip, tlog(TLOG_WARN, "Handshake with %s failed, error no: %s(%d, %d, %d)\n", server_info->ip,
@@ -2989,13 +3025,10 @@ static int _dns_client_process_tls(struct dns_server_info *server_info, struct e
goto errout; goto errout;
} }
fd_event.data.ptr = server_info; if (errno != ENETUNREACH) {
if (epoll_ctl(client.epoll_fd, EPOLL_CTL_MOD, server_info->fd, &fd_event) != 0) { tlog(TLOG_WARN, "Handshake with %s failed, %s", server_info->ip, strerror(errno));
tlog(TLOG_ERROR, "epoll ctl failed, %s", strerror(errno));
goto errout;
} }
goto errout;
return 0;
} }
tlog(TLOG_DEBUG, "tls server %s connected.\n", server_info->ip); tlog(TLOG_DEBUG, "tls server %s connected.\n", server_info->ip);

View File

@@ -188,6 +188,7 @@ struct dns_server_conn_tcp_client {
struct dns_server_conn_tls_client { struct dns_server_conn_tls_client {
struct dns_server_conn_tcp_client tcp; struct dns_server_conn_tcp_client tcp;
SSL *ssl; SSL *ssl;
int ssl_want_write;
pthread_mutex_t ssl_lock; pthread_mutex_t ssl_lock;
}; };
@@ -6135,16 +6136,16 @@ static int _dns_server_socket_ssl_send(struct dns_server_conn_tls_client *tls_cl
ssl_ret = _ssl_get_error(tls_client, ret); ssl_ret = _ssl_get_error(tls_client, ret);
switch (ssl_ret) { switch (ssl_ret) {
case SSL_ERROR_NONE: case SSL_ERROR_NONE:
case SSL_ERROR_ZERO_RETURN:
return 0; return 0;
break; break;
case SSL_ERROR_ZERO_RETURN:
case SSL_ERROR_WANT_READ: case SSL_ERROR_WANT_READ:
errno = EAGAIN; errno = EAGAIN;
ret = -1; ret = -SSL_ERROR_WANT_READ;
break; break;
case SSL_ERROR_WANT_WRITE: case SSL_ERROR_WANT_WRITE:
errno = EAGAIN; errno = EAGAIN;
ret = -1; ret = -SSL_ERROR_WANT_WRITE;
break; break;
case SSL_ERROR_SSL: case SSL_ERROR_SSL:
ssl_err = ERR_get_error(); ssl_err = ERR_get_error();
@@ -6184,7 +6185,7 @@ static int _dns_server_socket_ssl_recv(struct dns_server_conn_tls_client *tls_cl
} }
ret = _ssl_read(tls_client, buf, num); ret = _ssl_read(tls_client, buf, num);
if (ret >= 0) { if (ret > 0) {
return ret; return ret;
} }
@@ -6196,11 +6197,11 @@ static int _dns_server_socket_ssl_recv(struct dns_server_conn_tls_client *tls_cl
break; break;
case SSL_ERROR_WANT_READ: case SSL_ERROR_WANT_READ:
errno = EAGAIN; errno = EAGAIN;
ret = -1; ret = -SSL_ERROR_WANT_READ;
break; break;
case SSL_ERROR_WANT_WRITE: case SSL_ERROR_WANT_WRITE:
errno = EAGAIN; errno = EAGAIN;
ret = -1; ret = -SSL_ERROR_WANT_WRITE;
break; break;
case SSL_ERROR_SSL: case SSL_ERROR_SSL:
ssl_err = ERR_get_error(); ssl_err = ERR_get_error();
@@ -6214,7 +6215,14 @@ static int _dns_server_socket_ssl_recv(struct dns_server_conn_tls_client *tls_cl
return 0; return 0;
} }
tlog(TLOG_INFO, "SSL read fail error no: %s(%lx), len: %d\n", ERR_reason_error_string(ssl_err), ssl_err, num); #ifdef SSL_R_UNEXPECTED_EOF_WHILE_READING
if (ssl_reason == SSL_R_UNEXPECTED_EOF_WHILE_READING) {
return 0;
}
#endif
tlog(TLOG_DEBUG, "SSL read fail error no: %s(%lx), reason: %d\n", ERR_reason_error_string(ssl_err), ssl_err,
ssl_reason);
errno = EFAULT; errno = EFAULT;
ret = -1; ret = -1;
break; break;
@@ -6223,9 +6231,6 @@ static int _dns_server_socket_ssl_recv(struct dns_server_conn_tls_client *tls_cl
return 0; return 0;
} }
if (errno != ECONNRESET) {
tlog(TLOG_INFO, "SSL syscall failed, %s ", strerror(errno));
}
ret = -1; ret = -1;
return ret; return ret;
default: default:
@@ -6237,13 +6242,46 @@ static int _dns_server_socket_ssl_recv(struct dns_server_conn_tls_client *tls_cl
return ret; return ret;
} }
static int _dns_server_ssl_poll_event(struct dns_server_conn_tls_client *tls_client, int ssl_ret)
{
struct epoll_event fd_event;
memset(&fd_event, 0, sizeof(fd_event));
if (ssl_ret == SSL_ERROR_WANT_READ) {
fd_event.events = EPOLLIN;
} else if (ssl_ret == SSL_ERROR_WANT_WRITE) {
fd_event.events = EPOLLOUT | EPOLLIN;
} else {
goto errout;
}
fd_event.data.ptr = tls_client;
if (epoll_ctl(server.epoll_fd, EPOLL_CTL_MOD, tls_client->tcp.head.fd, &fd_event) != 0) {
tlog(TLOG_ERROR, "epoll ctl failed, %s", strerror(errno));
goto errout;
}
return 0;
errout:
return -1;
}
static int _dns_server_tcp_socket_send(struct dns_server_conn_tcp_client *tcp_client, void *data, int data_len) static int _dns_server_tcp_socket_send(struct dns_server_conn_tcp_client *tcp_client, void *data, int data_len)
{ {
if (tcp_client->head.type == DNS_CONN_TYPE_TCP_CLIENT) { if (tcp_client->head.type == DNS_CONN_TYPE_TCP_CLIENT) {
return send(tcp_client->head.fd, data, data_len, MSG_NOSIGNAL); return send(tcp_client->head.fd, data, data_len, MSG_NOSIGNAL);
} else if (tcp_client->head.type == DNS_CONN_TYPE_TLS_CLIENT || } else if (tcp_client->head.type == DNS_CONN_TYPE_TLS_CLIENT ||
tcp_client->head.type == DNS_CONN_TYPE_HTTPS_CLIENT) { tcp_client->head.type == DNS_CONN_TYPE_HTTPS_CLIENT) {
int ret = _dns_server_socket_ssl_send((struct dns_server_conn_tls_client *)tcp_client, data, data_len); struct dns_server_conn_tls_client *tls_client = (struct dns_server_conn_tls_client *)tcp_client;
tls_client->ssl_want_write = 0;
int ret = _dns_server_socket_ssl_send(tls_client, data, data_len);
if (ret < 0 && errno == EAGAIN) {
if (_dns_server_ssl_poll_event(tls_client, SSL_ERROR_WANT_WRITE) == 0) {
errno = EAGAIN;
}
}
return ret; return ret;
} else { } else {
return -1; return -1;
@@ -6256,7 +6294,16 @@ static int _dns_server_tcp_socket_recv(struct dns_server_conn_tcp_client *tcp_cl
return recv(tcp_client->head.fd, data, data_len, MSG_NOSIGNAL); return recv(tcp_client->head.fd, data, data_len, MSG_NOSIGNAL);
} else if (tcp_client->head.type == DNS_CONN_TYPE_TLS_CLIENT || } else if (tcp_client->head.type == DNS_CONN_TYPE_TLS_CLIENT ||
tcp_client->head.type == DNS_CONN_TYPE_HTTPS_CLIENT) { tcp_client->head.type == DNS_CONN_TYPE_HTTPS_CLIENT) {
return _dns_server_socket_ssl_recv((struct dns_server_conn_tls_client *)tcp_client, data, data_len); struct dns_server_conn_tls_client *tls_client = (struct dns_server_conn_tls_client *)tcp_client;
int ret = _dns_server_socket_ssl_recv(tls_client, data, data_len);
if (ret == -SSL_ERROR_WANT_WRITE && errno == EAGAIN) {
if (_dns_server_ssl_poll_event(tls_client, SSL_ERROR_WANT_WRITE) == 0) {
errno = EAGAIN;
tls_client->ssl_want_write = 1;
}
}
return ret;
} else { } else {
return -1; return -1;
} }
@@ -6287,7 +6334,7 @@ static int _dns_server_tcp_recv(struct dns_server_conn_tcp_client *tcpclient)
return RECV_ERROR_CLOSE; return RECV_ERROR_CLOSE;
} }
tlog(TLOG_ERROR, "recv failed, %s\n", strerror(errno)); tlog(TLOG_DEBUG, "recv failed, %s\n", strerror(errno));
return RECV_ERROR_FAIL; return RECV_ERROR_FAIL;
} else if (len == 0) { } else if (len == 0) {
return RECV_ERROR_CLOSE; return RECV_ERROR_CLOSE;
@@ -6451,10 +6498,22 @@ static int _dns_server_tcp_process_requests(struct dns_server_conn_tcp_client *t
return 0; return 0;
} }
static int _dns_server_tls_want_write(struct dns_server_conn_tcp_client *tcpclient)
{
if (tcpclient->head.type == DNS_CONN_TYPE_TLS_CLIENT || tcpclient->head.type == DNS_CONN_TYPE_HTTPS_CLIENT) {
struct dns_server_conn_tls_client *tls_client = (struct dns_server_conn_tls_client *)tcpclient;
if (tls_client->ssl_want_write == 1) {
return 1;
}
}
return 0;
}
static int _dns_server_tcp_send(struct dns_server_conn_tcp_client *tcpclient) static int _dns_server_tcp_send(struct dns_server_conn_tcp_client *tcpclient)
{ {
int len = 0; int len = 0;
while (tcpclient->sndbuff.size > 0) { while (tcpclient->sndbuff.size > 0 || _dns_server_tls_want_write(tcpclient) == 1) {
len = _dns_server_tcp_socket_send(tcpclient, tcpclient->sndbuff.buf, tcpclient->sndbuff.size); len = _dns_server_tcp_socket_send(tcpclient, tcpclient->sndbuff.buf, tcpclient->sndbuff.size);
if (len < 0) { if (len < 0) {
if (errno == EAGAIN) { if (errno == EAGAIN) {
@@ -6605,13 +6664,11 @@ static int _dns_server_process_tls(struct dns_server_conn_tls_client *tls_client
if (ret <= 0) { if (ret <= 0) {
memset(&fd_event, 0, sizeof(fd_event)); memset(&fd_event, 0, sizeof(fd_event));
ssl_ret = _ssl_get_error(tls_client, ret); ssl_ret = _ssl_get_error(tls_client, ret);
if (ssl_ret == SSL_ERROR_WANT_READ) { if (_dns_server_ssl_poll_event(tls_client, ssl_ret) == 0) {
fd_event.events = EPOLLIN; return 0;
} else if (ssl_ret == SSL_ERROR_WANT_WRITE) { }
fd_event.events = EPOLLOUT | EPOLLIN;
} else if (ssl_ret == SSL_ERROR_SYSCALL) { if (ssl_ret != SSL_ERROR_SYSCALL) {
goto errout;
} else {
unsigned long ssl_err = ERR_get_error(); unsigned long ssl_err = ERR_get_error();
int ssl_reason = ERR_GET_REASON(ssl_err); int ssl_reason = ERR_GET_REASON(ssl_err);
char name[DNS_MAX_CNAME_LEN]; char name[DNS_MAX_CNAME_LEN];
@@ -6619,16 +6676,9 @@ static int _dns_server_process_tls(struct dns_server_conn_tls_client *tls_client
get_host_by_addr(name, sizeof(name), (struct sockaddr *)&tls_client->tcp.addr), get_host_by_addr(name, sizeof(name), (struct sockaddr *)&tls_client->tcp.addr),
ERR_reason_error_string(ssl_err), ret, ssl_ret, ssl_reason); ERR_reason_error_string(ssl_err), ret, ssl_ret, ssl_reason);
ret = 0; ret = 0;
goto errout;
} }
fd_event.data.ptr = tls_client; goto errout;
if (epoll_ctl(server.epoll_fd, EPOLL_CTL_MOD, tls_client->tcp.head.fd, &fd_event) != 0) {
tlog(TLOG_ERROR, "epoll ctl failed, %s", strerror(errno));
goto errout;
}
return 0;
} }
tls_client->tcp.status = DNS_SERVER_CLIENT_STATUS_CONNECTED; tls_client->tcp.status = DNS_SERVER_CLIENT_STATUS_CONNECTED;