dns_server: support bind tls server

This commit is contained in:
Nick Peng
2023-02-28 23:52:05 +08:00
parent 8405d14625
commit c42f98979c
9 changed files with 751 additions and 17 deletions

View File

@@ -68,6 +68,10 @@ struct dns_dns64 dns_conf_dns_dns64;
struct dns_bind_ip dns_conf_bind_ip[DNS_MAX_BIND_IP];
int dns_conf_bind_ip_num = 0;
int dns_conf_tcp_idle_time = 120;
char dns_conf_bind_ca_file[DNS_MAX_PATH];
char dns_conf_bind_ca_key_file[DNS_MAX_PATH];
char dns_conf_bind_ca_key_pass[DNS_MAX_PATH];
char dns_conf_need_cert = 0;
int dns_conf_max_reply_ip_num = DNS_MAX_REPLY_IP_NUM;
@@ -1357,7 +1361,7 @@ static int _config_nftset_no_speed(void *data, int argc, char *argv[])
goto errout;
}
for (char *tok = strtok(copied_name, ","); tok && nftset_num <=2 ; tok = strtok(NULL, ",")) {
for (char *tok = strtok(copied_name, ","); tok && nftset_num <= 2; tok = strtok(NULL, ",")) {
char *saveptr = NULL;
char *tok_set = NULL;
@@ -1853,6 +1857,14 @@ static int _config_bind_ip(int argc, char *argv[], DNS_BIND_TYPE type)
bind_ip->flags = server_flag;
bind_ip->group = group;
dns_conf_bind_ip_num++;
if (bind_ip->type == DNS_BIND_TYPE_TLS) {
if (bind_ip->ssl_cert_file == NULL || bind_ip->ssl_cert_key_file == NULL) {
bind_ip->ssl_cert_file = dns_conf_bind_ca_file;
bind_ip->ssl_cert_key_file = dns_conf_bind_ca_key_file;
bind_ip->ssl_cert_key_pass = dns_conf_bind_ca_key_pass;
dns_conf_need_cert = 1;
}
}
tlog(TLOG_DEBUG, "bind ip %s, type: %d, flag: %X", ip, type, server_flag);
return 0;
@@ -1871,6 +1883,23 @@ static int _config_bind_ip_tcp(void *data, int argc, char *argv[])
return _config_bind_ip(argc, argv, DNS_BIND_TYPE_TCP);
}
static int _config_bind_ip_tls(void *data, int argc, char *argv[])
{
return _config_bind_ip(argc, argv, DNS_BIND_TYPE_TLS);
}
static int _config_option_parser_filepath(void *data, int argc, char *argv[])
{
if (argc <= 1) {
tlog(TLOG_ERROR, "invalid parameter.");
return -1;
}
conf_get_conf_fullpath(argv[1], data, DNS_MAX_PATH);
return 0;
}
static int _config_server_udp(void *data, int argc, char *argv[])
{
return _config_server(argc, argv, DNS_SERVER_UDP, DEFAULT_DNS_PORT);
@@ -3041,6 +3070,10 @@ static struct config_item _config_item[] = {
CONF_YESNO("resolv-hostname", &dns_conf_resolv_hostname),
CONF_CUSTOM("bind", _config_bind_ip_udp, NULL),
CONF_CUSTOM("bind-tcp", _config_bind_ip_tcp, NULL),
CONF_CUSTOM("bind-tls", _config_bind_ip_tls, NULL),
CONF_CUSTOM("bind-cert-file", _config_option_parser_filepath, &dns_conf_bind_ca_file),
CONF_CUSTOM("bind-cert-key-file", _config_option_parser_filepath, &dns_conf_bind_ca_key_file),
CONF_STRING("bind-cert-key-pass", dns_conf_bind_ca_key_pass, DNS_MAX_PATH),
CONF_CUSTOM("server", _config_server_udp, NULL),
CONF_CUSTOM("server-tcp", _config_server_tcp, NULL),
CONF_CUSTOM("server-tls", _config_server_tls, NULL),
@@ -3071,13 +3104,13 @@ static struct config_item _config_item[] = {
CONF_INT("dualstack-ip-selection-threshold", &dns_conf_dualstack_ip_selection_threshold, 0, 1000),
CONF_CUSTOM("dns64", _config_dns64, NULL),
CONF_CUSTOM("log-level", _config_log_level, NULL),
CONF_STRING("log-file", (char *)dns_conf_log_file, DNS_MAX_PATH),
CONF_CUSTOM("log-file", _config_option_parser_filepath, (char *)dns_conf_log_file),
CONF_SIZE("log-size", &dns_conf_log_size, 0, 1024 * 1024 * 1024),
CONF_INT("log-num", &dns_conf_log_num, 0, 1024),
CONF_INT_BASE("log-file-mode", &dns_conf_log_file_mode, 0, 511, 8),
CONF_YESNO("audit-enable", &dns_conf_audit_enable),
CONF_YESNO("audit-SOA", &dns_conf_audit_log_SOA),
CONF_STRING("audit-file", (char *)&dns_conf_audit_file, DNS_MAX_PATH),
CONF_CUSTOM("audit-file", _config_option_parser_filepath, (char *)&dns_conf_audit_file),
CONF_INT_BASE("audit-file-mode", &dns_conf_audit_file_mode, 0, 511, 8),
CONF_SIZE("audit-size", &dns_conf_audit_size, 0, 1024 * 1024 * 1024),
CONF_INT("audit-num", &dns_conf_audit_num, 0, 1024),
@@ -3264,6 +3297,31 @@ errout:
return -1;
}
static int _check_and_create_cert(void)
{
if (dns_conf_need_cert == 0) {
return 0;
}
if (dns_conf_bind_ca_file[0] != 0 && dns_conf_bind_ca_key_file[0] != 0) {
return -1;
}
conf_get_conf_fullpath("smartdns-cert.pem", dns_conf_bind_ca_file, sizeof(dns_conf_bind_ca_file));
conf_get_conf_fullpath("smartdns-key.pem", dns_conf_bind_ca_key_file, sizeof(dns_conf_bind_ca_key_file));
if (access(dns_conf_bind_ca_file, F_OK) == 0 && access(dns_conf_bind_ca_key_file, F_OK) == 0) {
return 0;
}
tlog(TLOG_INFO, "Generate default ssl cert and key file.");
if (generate_cert_key(dns_conf_bind_ca_key_file, dns_conf_bind_ca_file, NULL, 365 * 3) != 0) {
tlog(TLOG_WARN, "Generate default ssl cert and key file failed.");
return -1;
}
return 0;
}
static int _dns_conf_load_post(void)
{
_config_setup_smartdns_domain();
@@ -3289,16 +3347,26 @@ static int _dns_conf_load_post(void)
_config_domain_set_name_table_destroy();
_check_and_create_cert();
return 0;
}
int dns_server_load_conf(const char *file)
{
int ret = 0;
_dns_conf_load_pre();
ret = _dns_conf_load_pre();
if (ret != 0) {
return ret;
}
openlog("smartdns", LOG_CONS | LOG_NDELAY, LOG_LOCAL1);
ret = load_conf(file, _config_item, _conf_printf);
closelog();
_dns_conf_load_post();
if (ret != 0) {
return ret;
}
ret = _dns_conf_load_post();
return ret;
}

View File

@@ -83,6 +83,7 @@ typedef enum {
DNS_BIND_TYPE_UDP,
DNS_BIND_TYPE_TCP,
DNS_BIND_TYPE_TLS,
DNS_BIND_TYPE_HTTPS,
} DNS_BIND_TYPE;
typedef enum {
@@ -344,6 +345,9 @@ struct dns_bind_ip {
DNS_BIND_TYPE type;
uint32_t flags;
char ip[DNS_MAX_IPLEN];
const char *ssl_cert_file;
const char *ssl_cert_key_file;
const char *ssl_cert_key_pass;
const char *group;
};
@@ -406,6 +410,10 @@ extern struct dns_dns64 dns_conf_dns_dns64;
extern struct dns_bind_ip dns_conf_bind_ip[DNS_MAX_BIND_IP];
extern int dns_conf_bind_ip_num;
extern char dns_conf_bind_ca_file[DNS_MAX_PATH];
extern char dns_conf_bind_ca_key_file[DNS_MAX_PATH];
extern char dns_conf_bind_ca_key_pass[DNS_MAX_PATH];
extern int dns_conf_tcp_idle_time;
extern int dns_conf_cachesize;
extern int dns_conf_prefetch;

View File

@@ -37,6 +37,8 @@
#include <net/if.h>
#include <netinet/ip.h>
#include <netinet/tcp.h>
#include <openssl/err.h>
#include <openssl/ssl.h>
#include <pthread.h>
#include <stdio.h>
#include <stdlib.h>
@@ -46,6 +48,10 @@
#include <sys/ioctl.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <openssl/evp.h>
#include <openssl/rsa.h>
#include <openssl/pem.h>
#include <openssl/x509.h>
#define DNS_MAX_EVENTS 256
#define IPV6_READY_CHECK_TIME 180
@@ -76,6 +82,8 @@ typedef enum {
DNS_CONN_TYPE_TCP_CLIENT,
DNS_CONN_TYPE_TLS_SERVER,
DNS_CONN_TYPE_TLS_CLIENT,
DNS_CONN_TYPE_HTTPS_SERVER,
DNS_CONN_TYPE_HTTPS_CLIENT,
} DNS_CONN_TYPE;
typedef enum DNS_CHILD_POST_RESULT {
@@ -132,6 +140,14 @@ struct dns_server_post_context {
int no_release_parent;
};
typedef enum dns_server_client_status {
DNS_SERVER_CLIENT_STATUS_INIT = 0,
DNS_SERVER_CLIENT_STATUS_CONNECTING,
DNS_SERVER_CLIENT_STATUS_CONNECTIONLESS,
DNS_SERVER_CLIENT_STATUS_CONNECTED,
DNS_SERVER_CLIENT_STATUS_DISCONNECTED,
} dns_server_client_status;
struct dns_server_conn_udp {
struct dns_server_conn_head head;
socklen_t addr_len;
@@ -142,6 +158,11 @@ struct dns_server_conn_tcp_server {
struct dns_server_conn_head head;
};
struct dns_server_conn_tls_server {
struct dns_server_conn_head head;
SSL_CTX *ssl_ctx;
};
struct dns_server_conn_tcp_client {
struct dns_server_conn_head head;
struct dns_conn_buf recvbuff;
@@ -151,6 +172,23 @@ struct dns_server_conn_tcp_client {
socklen_t localaddr_len;
struct sockaddr_storage localaddr;
dns_server_client_status status;
};
struct dns_server_conn_tls_client {
struct dns_server_conn_head head;
struct dns_conn_buf recvbuff;
struct dns_conn_buf sndbuff;
socklen_t addr_len;
struct sockaddr_storage addr;
socklen_t localaddr_len;
struct sockaddr_storage localaddr;
dns_server_client_status status;
SSL *ssl;
pthread_mutex_t ssl_lock;
};
/* ip address lists of domain */
@@ -306,6 +344,7 @@ static int _dns_request_post(struct dns_server_post_context *context);
static int _dns_server_reply_all_pending_list(struct dns_request *request, struct dns_server_post_context *context);
static void *_dns_server_get_dns_rule(struct dns_request *request, enum domain_rule rule);
static const char *_dns_server_get_request_groupname(struct dns_request *request);
static int _dns_server_tcp_socket_send(struct dns_server_conn_tcp_client *tcp_client, void *data, int data_len);
static void _dns_server_wakeup_thread(void)
{
@@ -460,7 +499,7 @@ static int _dns_server_is_return_soa(struct dns_request *request)
unsigned int flags = 0;
if (_dns_server_has_bind_flag(request, BIND_FLAG_NO_RULE_SOA) == 0) {
/* when both has no rule SOA and force AAAA soa, foce AAAA soa has high priority */
/* when both has no rule SOA and force AAAA soa, force AAAA soa has high priority */
if (request->qtype == DNS_T_AAAA && _dns_server_has_bind_flag(request, BIND_FLAG_FORCE_AAAA_SOA) == 0) {
return 1;
}
@@ -961,6 +1000,21 @@ static void _dns_server_conn_release(struct dns_server_conn_head *conn)
conn->fd = -1;
}
if (conn->type == DNS_CONN_TYPE_TLS_CLIENT || conn->type == DNS_CONN_TYPE_HTTPS_CLIENT) {
struct dns_server_conn_tls_client *tls_client = (struct dns_server_conn_tls_client *)conn;
if (tls_client->ssl != NULL) {
SSL_free(tls_client->ssl);
tls_client->ssl = NULL;
}
pthread_mutex_destroy(&tls_client->ssl_lock);
} else if (conn->type == DNS_CONN_TYPE_TLS_SERVER) {
struct dns_server_conn_tls_server *tls_server = (struct dns_server_conn_tls_server *)conn;
if (tls_server->ssl_ctx != NULL) {
SSL_CTX_free(tls_server->ssl_ctx);
tls_server->ssl_ctx = NULL;
}
}
list_del_init(&conn->list);
free(conn);
}
@@ -1007,7 +1061,7 @@ static int _dns_server_reply_tcp(struct dns_request *request, struct dns_server_
memcpy(inpacket + 2, packet, len);
len += 2;
send_len = send(tcpclient->head.fd, inpacket, len, MSG_NOSIGNAL);
send_len = _dns_server_tcp_socket_send(tcpclient, inpacket, len);
if (send_len < 0) {
if (errno == EAGAIN) {
/* save data to buffer, and retry when EPOLLOUT is available */
@@ -1101,7 +1155,7 @@ static int _dns_reply_inpacket(struct dns_request *request, unsigned char *inpac
} else if (conn->type == DNS_CONN_TYPE_TCP_CLIENT) {
ret = _dns_server_reply_tcp(request, (struct dns_server_conn_tcp_client *)conn, inpacket, inpacket_len);
} else if (conn->type == DNS_CONN_TYPE_TLS_CLIENT) {
ret = -1;
ret = _dns_server_reply_tcp(request, (struct dns_server_conn_tcp_client *)conn, inpacket, inpacket_len);
} else {
ret = -1;
}
@@ -5389,6 +5443,209 @@ errout:
return -1;
}
static ssize_t _ssl_read(struct dns_server_conn_tls_client *conn, void *buff, int num)
{
ssize_t ret = 0;
if (conn == NULL || buff == NULL) {
return SSL_ERROR_SYSCALL;
}
pthread_mutex_lock(&conn->ssl_lock);
ret = SSL_read(conn->ssl, buff, num);
pthread_mutex_unlock(&conn->ssl_lock);
return ret;
}
static ssize_t _ssl_write(struct dns_server_conn_tls_client *conn, const void *buff, int num)
{
ssize_t ret = 0;
if (conn == NULL || buff == NULL || conn->ssl == NULL) {
return SSL_ERROR_SYSCALL;
}
pthread_mutex_lock(&conn->ssl_lock);
ret = SSL_write(conn->ssl, buff, num);
pthread_mutex_unlock(&conn->ssl_lock);
return ret;
}
static int _ssl_get_error(struct dns_server_conn_tls_client *conn, int ret)
{
int err = 0;
if (conn == NULL || conn->ssl == NULL) {
return SSL_ERROR_SYSCALL;
}
pthread_mutex_lock(&conn->ssl_lock);
err = SSL_get_error(conn->ssl, ret);
pthread_mutex_unlock(&conn->ssl_lock);
return err;
}
static int _ssl_do_accept(struct dns_server_conn_tls_client *conn)
{
int err = 0;
if (conn == NULL || conn->ssl == NULL) {
return SSL_ERROR_SYSCALL;
}
pthread_mutex_lock(&conn->ssl_lock);
err = SSL_accept(conn->ssl);
pthread_mutex_unlock(&conn->ssl_lock);
return err;
}
static int _dns_server_socket_ssl_send(struct dns_server_conn_tls_client *tls_client, const void *buf, int num)
{
int ret = 0;
int ssl_ret = 0;
unsigned long ssl_err = 0;
if (tls_client->ssl == NULL) {
errno = EINVAL;
return -1;
}
if (num < 0) {
errno = EINVAL;
return -1;
}
ret = _ssl_write(tls_client, buf, num);
if (ret > 0) {
return ret;
}
ssl_ret = _ssl_get_error(tls_client, ret);
switch (ssl_ret) {
case SSL_ERROR_NONE:
return 0;
break;
case SSL_ERROR_ZERO_RETURN:
case SSL_ERROR_WANT_READ:
errno = EAGAIN;
ret = -1;
break;
case SSL_ERROR_WANT_WRITE:
errno = EAGAIN;
ret = -1;
break;
case SSL_ERROR_SSL:
ssl_err = ERR_get_error();
int ssl_reason = ERR_GET_REASON(ssl_err);
if (ssl_reason == SSL_R_UNINITIALIZED || ssl_reason == SSL_R_PROTOCOL_IS_SHUTDOWN ||
ssl_reason == SSL_R_BAD_LENGTH || ssl_reason == SSL_R_SHUTDOWN_WHILE_IN_INIT ||
ssl_reason == SSL_R_BAD_WRITE_RETRY) {
errno = EAGAIN;
return -1;
}
tlog(TLOG_ERROR, "SSL write fail error no: %s(%d)\n", ERR_reason_error_string(ssl_err), ssl_reason);
errno = EFAULT;
ret = -1;
break;
case SSL_ERROR_SYSCALL:
tlog(TLOG_DEBUG, "SSL syscall failed, %s", strerror(errno));
return ret;
default:
errno = EFAULT;
ret = -1;
break;
}
return ret;
}
static int _dns_server_socket_ssl_recv(struct dns_server_conn_tls_client *tls_client, void *buf, int num)
{
ssize_t ret = 0;
int ssl_ret = 0;
unsigned long ssl_err = 0;
if (tls_client->ssl == NULL) {
errno = EFAULT;
return -1;
}
ret = _ssl_read(tls_client, buf, num);
if (ret >= 0) {
return ret;
}
ssl_ret = _ssl_get_error(tls_client, ret);
switch (ssl_ret) {
case SSL_ERROR_NONE:
case SSL_ERROR_ZERO_RETURN:
return 0;
break;
case SSL_ERROR_WANT_READ:
errno = EAGAIN;
ret = -1;
break;
case SSL_ERROR_WANT_WRITE:
errno = EAGAIN;
ret = -1;
break;
case SSL_ERROR_SSL:
ssl_err = ERR_get_error();
int ssl_reason = ERR_GET_REASON(ssl_err);
if (ssl_reason == SSL_R_UNINITIALIZED) {
errno = EAGAIN;
return -1;
}
if (ssl_reason == SSL_R_SHUTDOWN_WHILE_IN_INIT || ssl_reason == SSL_R_PROTOCOL_IS_SHUTDOWN) {
return 0;
}
tlog(TLOG_INFO, "SSL read fail error no: %s(%lx), len: %d\n", ERR_reason_error_string(ssl_err), ssl_err, num);
errno = EFAULT;
ret = -1;
break;
case SSL_ERROR_SYSCALL:
if (errno == 0) {
return 0;
}
if (errno != ECONNRESET) {
tlog(TLOG_INFO, "SSL syscall failed, %s ", strerror(errno));
}
ret = -1;
return ret;
default:
errno = EFAULT;
ret = -1;
break;
}
return ret;
}
static int _dns_server_tcp_socket_send(struct dns_server_conn_tcp_client *tcp_client, void *data, int data_len)
{
if (tcp_client->head.type == DNS_CONN_TYPE_TCP_CLIENT) {
return send(tcp_client->head.fd, data, data_len, MSG_NOSIGNAL);
} else if (tcp_client->head.type == DNS_CONN_TYPE_TLS_CLIENT ||
tcp_client->head.type == DNS_CONN_TYPE_HTTPS_CLIENT) {
int ret = _dns_server_socket_ssl_send((struct dns_server_conn_tls_client *)tcp_client, data, data_len);
return ret;
} else {
return -1;
}
}
static int _dns_server_tcp_socket_recv(struct dns_server_conn_tcp_client *tcp_client, void *data, int data_len)
{
if (tcp_client->head.type == DNS_CONN_TYPE_TCP_CLIENT) {
return recv(tcp_client->head.fd, data, data_len, MSG_NOSIGNAL);
} else if (tcp_client->head.type == DNS_CONN_TYPE_TLS_CLIENT ||
tcp_client->head.type == DNS_CONN_TYPE_HTTPS_CLIENT) {
return _dns_server_socket_ssl_recv((struct dns_server_conn_tls_client *)tcp_client, data, data_len);
} else {
return -1;
}
}
static int _dns_server_tcp_recv(struct dns_server_conn_tcp_client *tcpclient)
{
ssize_t len = 0;
@@ -5399,8 +5656,8 @@ static int _dns_server_tcp_recv(struct dns_server_conn_tcp_client *tcpclient)
return 0;
}
len = recv(tcpclient->head.fd, tcpclient->recvbuff.buf + tcpclient->recvbuff.size,
sizeof(tcpclient->recvbuff.buf) - tcpclient->recvbuff.size, 0);
len = _dns_server_tcp_socket_recv(tcpclient, tcpclient->recvbuff.buf + tcpclient->recvbuff.size,
sizeof(tcpclient->recvbuff.buf) - tcpclient->recvbuff.size);
if (len < 0) {
if (errno == EAGAIN) {
return RECV_ERROR_AGAIN;
@@ -5517,7 +5774,7 @@ static int _dns_server_tcp_send(struct dns_server_conn_tcp_client *tcpclient)
{
int len = 0;
while (tcpclient->sndbuff.size > 0) {
len = send(tcpclient->head.fd, tcpclient->sndbuff.buf, tcpclient->sndbuff.size, MSG_NOSIGNAL);
len = _dns_server_tcp_socket_send(tcpclient, tcpclient->sndbuff.buf, tcpclient->sndbuff.size);
if (len < 0) {
if (errno == EAGAIN) {
return RECV_ERROR_AGAIN;
@@ -5566,6 +5823,137 @@ static int _dns_server_process_tcp(struct dns_server_conn_tcp_client *dnsserver,
return 0;
}
static int _dns_server_tls_accept(struct dns_server_conn_tls_server *tls_server, struct epoll_event *event,
unsigned long now)
{
struct sockaddr_storage addr;
struct dns_server_conn_tls_client *tls_client = NULL;
socklen_t addr_len = sizeof(addr);
int fd = -1;
SSL *ssl = NULL;
fd = accept4(tls_server->head.fd, (struct sockaddr *)&addr, &addr_len, SOCK_NONBLOCK | SOCK_CLOEXEC);
if (fd < 0) {
tlog(TLOG_ERROR, "accept failed, %s", strerror(errno));
return -1;
}
tls_client = malloc(sizeof(*tls_client));
if (tls_client == NULL) {
tlog(TLOG_ERROR, "malloc for tls_client failed.");
goto errout;
}
memset(tls_client, 0, sizeof(*tls_client));
tls_client->head.fd = fd;
tls_client->head.type = DNS_CONN_TYPE_TLS_CLIENT;
tls_client->head.server_flags = tls_server->head.server_flags;
tls_client->head.dns_group = tls_server->head.dns_group;
atomic_set(&tls_client->head.refcnt, 0);
memcpy(&tls_client->addr, &addr, addr_len);
tls_client->addr_len = addr_len;
tls_client->localaddr_len = sizeof(struct sockaddr_storage);
if (_dns_server_epoll_ctl(&tls_client->head, EPOLL_CTL_ADD, EPOLLIN) != 0) {
tlog(TLOG_ERROR, "epoll ctl failed.");
return -1;
}
if (getsocket_inet(tls_client->head.fd, (struct sockaddr *)&tls_client->localaddr, &tls_client->localaddr_len) !=
0) {
tlog(TLOG_ERROR, "get local addr failed, %s", strerror(errno));
goto errout;
}
ssl = SSL_new(tls_server->ssl_ctx);
if (ssl == NULL) {
tlog(TLOG_ERROR, "SSL_new failed.");
goto errout;
}
if (SSL_set_fd(ssl, fd) != 1) {
tlog(TLOG_ERROR, "SSL_set_fd failed.");
goto errout;
}
tls_client->ssl = ssl;
tls_client->status = DNS_SERVER_CLIENT_STATUS_CONNECTING;
pthread_mutex_init(&tls_client->ssl_lock, NULL);
_dns_server_client_touch(&tls_client->head);
list_add(&tls_client->head.list, &server.conn_list);
_dns_server_conn_get(&tls_client->head);
return 0;
errout:
if (fd > 0) {
close(fd);
}
if (ssl) {
SSL_free(ssl);
}
if (tls_client) {
free(tls_client);
}
return -1;
}
static int _dns_server_process_tls(struct dns_server_conn_tls_client *tls_client, struct epoll_event *event,
unsigned long now)
{
int ret = 0;
int ssl_ret = 0;
struct epoll_event fd_event;
if (tls_client->status == DNS_SERVER_CLIENT_STATUS_CONNECTING) {
/* do SSL hand shake */
ret = _ssl_do_accept(tls_client);
if (ret == 0) {
goto errout;
} else if (ret < 0) {
memset(&fd_event, 0, sizeof(fd_event));
ssl_ret = _ssl_get_error(tls_client, ret);
if (ssl_ret == SSL_ERROR_WANT_READ) {
fd_event.events = EPOLLIN;
} else if (ssl_ret == SSL_ERROR_WANT_WRITE) {
fd_event.events = EPOLLOUT | EPOLLIN;
} else if (ssl_ret == SSL_ERROR_SYSCALL) {
goto errout;
} else {
unsigned long ssl_err = ERR_get_error();
int ssl_reason = ERR_GET_REASON(ssl_err);
tlog(TLOG_DEBUG, "Handshake with %s failed, error no: %s(%d, %d, %d)\n", "",
ERR_reason_error_string(ssl_err), ret, ssl_ret, ssl_reason);
ret = 0;
goto errout;
}
fd_event.data.ptr = tls_client;
if (epoll_ctl(server.epoll_fd, EPOLL_CTL_MOD, tls_client->head.fd, &fd_event) != 0) {
tlog(TLOG_ERROR, "epoll ctl failed, %s", strerror(errno));
goto errout;
}
return 0;
}
tls_client->status = DNS_SERVER_CLIENT_STATUS_CONNECTED;
memset(&fd_event, 0, sizeof(fd_event));
fd_event.events = EPOLLIN | EPOLLOUT;
fd_event.data.ptr = tls_client;
if (epoll_ctl(server.epoll_fd, EPOLL_CTL_MOD, tls_client->head.fd, &fd_event) != 0) {
tlog(TLOG_ERROR, "epoll ctl failed, %s", strerror(errno));
goto errout;
}
}
return _dns_server_process_tcp((struct dns_server_conn_tcp_client *)tls_client, event, now);
errout:
_dns_server_client_close(&tls_client->head);
return ret;
}
static int _dns_server_process(struct dns_server_conn_head *conn, struct epoll_event *event, unsigned long now)
{
int ret = 0;
@@ -5586,10 +5974,19 @@ static int _dns_server_process(struct dns_server_conn_head *conn, struct epoll_e
get_host_by_addr(name, sizeof(name), (struct sockaddr *)&tcpclient->addr));
}
} else if (conn->type == DNS_CONN_TYPE_TLS_SERVER) {
tlog(TLOG_ERROR, "unsupported dns server type %d", conn->type);
ret = -1;
struct dns_server_conn_tls_server *tls_server = (struct dns_server_conn_tls_server *)conn;
ret = _dns_server_tls_accept(tls_server, event, now);
} else if (conn->type == DNS_CONN_TYPE_TLS_CLIENT) {
struct dns_server_conn_tls_client *tls_client = (struct dns_server_conn_tls_client *)conn;
ret = _dns_server_process_tls(tls_client, event, now);
if (ret != 0) {
char name[DNS_MAX_CNAME_LEN];
tlog(TLOG_DEBUG, "process TLS packet from %s failed.",
get_host_by_addr(name, sizeof(name), (struct sockaddr *)&tls_client->addr));
}
} else {
tlog(TLOG_ERROR, "unsupported dns server type %d", conn->type);
_dns_server_client_close(conn);
ret = -1;
}
_dns_server_conn_release(conn);
@@ -5818,9 +6215,18 @@ static void _dns_server_close_socket_server(void)
list_for_each_entry_safe(conn, tmp, &server.conn_list, list)
{
switch (conn->type) {
case DNS_CONN_TYPE_HTTPS_SERVER:
case DNS_CONN_TYPE_TLS_SERVER: {
struct dns_server_conn_tls_server *tls_server = (struct dns_server_conn_tls_server *)conn;
if (tls_server->ssl_ctx) {
SSL_CTX_free(tls_server->ssl_ctx);
tls_server->ssl_ctx = NULL;
}
_dns_server_client_close(conn);
break;
}
case DNS_CONN_TYPE_UDP_SERVER:
case DNS_CONN_TYPE_TCP_SERVER:
case DNS_CONN_TYPE_TLS_SERVER:
_dns_server_client_close(conn);
break;
default:
@@ -6113,6 +6519,7 @@ static int _dns_server_socket_tcp(struct dns_bind_ip *bind_ip)
const char *host_ip = NULL;
struct dns_server_conn_tcp_server *conn = NULL;
int fd = -1;
const int on = 1;
host_ip = bind_ip->ip;
conn = malloc(sizeof(struct dns_server_conn_tcp_server));
@@ -6126,6 +6533,8 @@ static int _dns_server_socket_tcp(struct dns_bind_ip *bind_ip)
goto errout;
}
setsockopt(fd, SOL_TCP, TCP_FASTOPEN, &on, sizeof(on));
conn->head.type = DNS_CONN_TYPE_TCP_SERVER;
conn->head.fd = fd;
_dns_server_set_flags(&conn->head, bind_ip);
@@ -6144,6 +6553,110 @@ errout:
return -1;
}
static int _dns_server_socket_tls_ssl_pass_callback(char *buf, int size, int rwflag, void *userdata)
{
struct dns_bind_ip *bind_ip = userdata;
if (bind_ip->ssl_cert_key_pass == NULL || bind_ip->ssl_cert_key_pass[0] == '\0') {
return 0;
}
safe_strncpy(buf, bind_ip->ssl_cert_key_pass, size);
return strlen(buf);
}
static int _dns_server_socket_tls(struct dns_bind_ip *bind_ip, DNS_CONN_TYPE conn_type)
{
const char *host_ip = NULL;
const char *ssl_cert_file = NULL;
const char *ssl_cert_key_file = NULL;
struct dns_server_conn_tls_server *conn = NULL;
int fd = -1;
const SSL_METHOD *method = NULL;
SSL_CTX *ssl_ctx = NULL;
const int on = 1;
host_ip = bind_ip->ip;
ssl_cert_file = bind_ip->ssl_cert_file;
ssl_cert_key_file = bind_ip->ssl_cert_key_file;
if (ssl_cert_file == NULL || ssl_cert_key_file == NULL) {
tlog(TLOG_WARN, "no cert or cert key file");
goto errout;
}
if (ssl_cert_file[0] == '\0' || ssl_cert_key_file[0] == '\0') {
tlog(TLOG_WARN, "no cert or cert key file");
goto errout;
}
conn = malloc(sizeof(struct dns_server_conn_tls_server));
if (conn == NULL) {
goto errout;
}
INIT_LIST_HEAD(&conn->head.list);
fd = _dns_create_socket(host_ip, SOCK_STREAM);
if (fd <= 0) {
goto errout;
}
setsockopt(fd, SOL_TCP, TCP_FASTOPEN, &on, sizeof(on));
#if (OPENSSL_VERSION_NUMBER >= 0x10100000L)
method = TLS_server_method();
if (method == NULL) {
goto errout;
}
#else
method = SSLv23_server_method();
#endif
ssl_ctx = SSL_CTX_new(method);
if (ssl_ctx == NULL) {
goto errout;
}
SSL_CTX_set_session_cache_mode(ssl_ctx,
SSL_SESS_CACHE_SERVER | SSL_SESS_CACHE_NO_INTERNAL | SSL_SESS_CACHE_NO_AUTO_CLEAR);
SSL_CTX_set_default_passwd_cb(ssl_ctx, _dns_server_socket_tls_ssl_pass_callback);
SSL_CTX_set_default_passwd_cb_userdata(ssl_ctx, bind_ip);
/* Set the key and cert */
if (ssl_cert_file[0] != '\0' && SSL_CTX_use_certificate_file(ssl_ctx, ssl_cert_file, SSL_FILETYPE_PEM) <= 0) {
tlog(TLOG_ERROR, "load cert %s failed, %s", ssl_cert_file, ERR_error_string(ERR_get_error(), NULL));
goto errout;
}
if (ssl_cert_key_file[0] != '\0' &&
SSL_CTX_use_PrivateKey_file(ssl_ctx, ssl_cert_key_file, SSL_FILETYPE_PEM) <= 0) {
tlog(TLOG_ERROR, "load cert key %s failed, %s", ssl_cert_key_file, ERR_error_string(ERR_get_error(), NULL));
goto errout;
}
conn->head.type = conn_type;
conn->head.fd = fd;
conn->ssl_ctx = ssl_ctx;
_dns_server_set_flags(&conn->head, bind_ip);
_dns_server_conn_get(&conn->head);
return 0;
errout:
if (ssl_ctx) {
SSL_CTX_free(ssl_ctx);
ssl_ctx = NULL;
}
if (conn) {
free(conn);
conn = NULL;
}
if (fd > 0) {
close(fd);
}
return -1;
}
static int _dns_server_socket(void)
{
int i = 0;
@@ -6161,7 +6674,15 @@ static int _dns_server_socket(void)
goto errout;
}
break;
case DNS_BIND_TYPE_HTTPS:
if (_dns_server_socket_tls(bind_ip, DNS_CONN_TYPE_HTTPS_SERVER) != 0) {
goto errout;
}
break;
case DNS_BIND_TYPE_TLS:
if (_dns_server_socket_tls(bind_ip, DNS_CONN_TYPE_TLS_SERVER) != 0) {
goto errout;
}
break;
default:
break;

View File

@@ -159,7 +159,7 @@ extern int conf_enum(const char *item, void *data, int argc, char *argv[]);
* Example:
* int num = 0;
*
* struct config_item itmes [] = {
* struct config_item items [] = {
* CONF_INT("CONF_NUM", &num, -1, 10),
* CONF_END();
* }

View File

@@ -828,6 +828,123 @@ errout:
return -1;
}
int generate_cert_key(const char *key_path, const char *cert_path, const char *san, int days)
{
int ret = -1;
#if (OPENSSL_VERSION_NUMBER <= 0x30000000L)
RSA *rsa = NULL;
BIGNUM *bn = NULL;
#endif
X509_EXTENSION *cert_ext = NULL;
BIO *cert_file = NULL;
BIO *key_file = NULL;
X509 *cert = NULL;
EVP_PKEY *pkey = NULL;
const int RSA_KEY_LENGTH = 2048;
if (key_path == NULL || cert_path == NULL) {
return ret;
}
key_file = BIO_new_file(key_path, "wb");
cert_file = BIO_new_file(cert_path, "wb");
cert = X509_new();
#if (OPENSSL_VERSION_NUMBER >= 0x30000000L)
pkey = EVP_RSA_gen(RSA_KEY_LENGTH);
#else
bn = BN_new();
rsa = RSA_new();
pkey = EVP_PKEY_new();
if (rsa == NULL || pkey == NULL || bn == NULL) {
goto out;
}
EVP_PKEY_assign(pkey, EVP_PKEY_RSA, rsa);
BN_set_word(bn, RSA_F4);
if (RSA_generate_key_ex(rsa, RSA_KEY_LENGTH, bn, NULL) != 1) {
goto out;
}
#endif
if (key_file == NULL || cert_file == NULL || cert == NULL || pkey == NULL) {
goto out;
}
ASN1_INTEGER_set(X509_get_serialNumber(cert), 1); // serial number
X509_gmtime_adj(X509_get_notBefore(cert), 0); // now
X509_gmtime_adj(X509_get_notAfter(cert), days * 24 * 3600); // accepts secs
X509_set_pubkey(cert, pkey);
X509_NAME *name = X509_get_subject_name(cert);
const unsigned char *country = (unsigned char *)"smartdns";
const unsigned char *company = (unsigned char *)"smartdns";
const unsigned char *common_name = (unsigned char *)"smartdns";
X509_NAME_add_entry_by_txt(name, "C", MBSTRING_ASC, country, -1, -1, 0);
X509_NAME_add_entry_by_txt(name, "O", MBSTRING_ASC, company, -1, -1, 0);
X509_NAME_add_entry_by_txt(name, "CN", MBSTRING_ASC, common_name, -1, -1, 0);
if (san != NULL) {
cert_ext = X509V3_EXT_conf_nid(NULL, NULL, NID_subject_alt_name, san);
if (cert_ext == NULL) {
goto out;
}
X509_add_ext(cert, cert_ext, -1);
}
X509_set_issuer_name(cert, name);
X509_sign(cert, pkey, EVP_sha256());
ret = PEM_write_bio_PrivateKey(key_file, pkey, NULL, NULL, 0, NULL, NULL);
if (ret != 1) {
goto out;
}
ret = PEM_write_bio_X509(cert_file, cert);
if (ret != 1) {
goto out;
}
chmod(key_path, S_IRUSR);
chmod(cert_path, S_IRUSR);
ret = 0;
out:
if (cert_ext) {
X509_EXTENSION_free(cert_ext);
}
if (pkey) {
EVP_PKEY_free(pkey);
}
#if (OPENSSL_VERSION_NUMBER <= 0x30000000L)
if (rsa && pkey == NULL) {
RSA_free(rsa);
}
if (bn) {
BN_free(bn);
}
#endif
if (cert_file) {
BIO_free_all(cert_file);
}
if (key_file) {
BIO_free_all(key_file);
}
if (cert) {
X509_free(cert);
}
return ret;
}
#if OPENSSL_API_COMPAT < 0x10100000
#define THREAD_STACK_SIZE (16 * 1024)
static pthread_mutex_t *lock_cs;

View File

@@ -97,6 +97,8 @@ int SSL_base64_decode(const char *in, unsigned char *out);
int SSL_base64_encode(const void *in, int in_len, char *out);
int generate_cert_key(const char *key_path, const char *cert_path, const char *san, int days);
int create_pid_file(const char *pid_file);
/* Parse a TLS packet for the Server Name Indication extension in the client