Support TLS SPKI verify

This commit is contained in:
Nick Peng
2019-02-22 00:40:30 +08:00
parent 4465ce798a
commit 85b0eed3a2
10 changed files with 159 additions and 29 deletions

View File

@@ -93,6 +93,9 @@ struct dns_server_info {
/* server type */
dns_server_type_t type;
unsigned char *spki;
int spki_len;
/* client socket */
int fd;
int ttl;
@@ -535,9 +538,25 @@ void _dns_client_group_remove_all(void)
}
/* add dns server information */
int _dns_client_server_add(char *server_ip, struct addrinfo *gai, dns_server_type_t server_type, unsigned int server_flag, unsigned int result_flag, int ttl)
int _dns_client_server_add(char *server_ip, struct addrinfo *gai, dns_server_type_t server_type, unsigned int server_flag, unsigned int result_flag, int ttl, char *spki)
{
struct dns_server_info *server_info = NULL;
unsigned char *spki_data = NULL;
int spki_data_len = 0;
/* read SPKI */
if (spki && (strlen(spki) < DNS_MAX_SPKI_LEN)) {
spki_data = malloc(DNS_MAX_SPKI_LEN);
if (spki_data) {
memset(spki_data, 0, DNS_MAX_SPKI_LEN);
spki_data_len = SSL_base64_decode(spki, spki_data);
if (spki_data_len != SHA256_DIGEST_LENGTH) {
free(spki_data);
spki_data = NULL;
spki_data_len = 0;
}
}
}
if (_dns_client_server_exist(gai, server_type) == 0) {
return 0;
@@ -562,6 +581,8 @@ int _dns_client_server_add(char *server_ip, struct addrinfo *gai, dns_server_typ
server_info->result_flag = result_flag;
server_info->ttl = ttl;
server_info->ttl_range = 0;
server_info->spki = spki_data;
server_info->spki_len = spki_data_len;
if ((server_flag & SERVER_FLAG_EXCLUDE_DEFAULT) == 0) {
if (_dns_client_add_to_group(DNS_SERVER_GROUP_DEFAULT, server_info) != 0) {
@@ -605,6 +626,10 @@ int _dns_client_server_add(char *server_ip, struct addrinfo *gai, dns_server_typ
atomic_inc(&client.dns_server_num);
return 0;
errout:
if (spki_data) {
free(spki_data);
}
if (server_info) {
if (server_info->ssl_ctx) {
SSL_CTX_free(server_info->ssl_ctx);
@@ -673,6 +698,10 @@ void _dns_client_server_remove_all(void)
{
list_del(&server_info->list);
_dns_client_server_close(server_info);
if (server_info->spki) {
free(server_info->spki);
server_info->spki = NULL;
}
free(server_info);
}
pthread_mutex_unlock(&client.server_list_lock);
@@ -707,7 +736,7 @@ int _dns_client_server_remove(char *server_ip, struct addrinfo *gai, dns_server_
return -1;
}
int _dns_client_server_operate(char *server_ip, int port, dns_server_type_t server_type, unsigned int server_flag, unsigned int result_flag, int ttl, int operate)
int _dns_client_server_operate(char *server_ip, int port, dns_server_type_t server_type, unsigned int server_flag, unsigned int result_flag, int ttl, char *spki, int operate)
{
char port_s[8];
int sock_type;
@@ -741,7 +770,7 @@ int _dns_client_server_operate(char *server_ip, int port, dns_server_type_t serv
}
if (operate == 0) {
ret = _dns_client_server_add(server_ip, gai, server_type, server_flag, result_flag, ttl);
ret = _dns_client_server_add(server_ip, gai, server_type, server_flag, result_flag, ttl, spki);
if (ret != 0) {
goto errout;
}
@@ -760,14 +789,14 @@ errout:
return -1;
}
int dns_client_add_server(char *server_ip, int port, dns_server_type_t server_type, unsigned server_flag, unsigned int result_flag, int ttl)
int dns_client_add_server(char *server_ip, int port, dns_server_type_t server_type, unsigned server_flag, unsigned int result_flag, int ttl, char *spki)
{
return _dns_client_server_operate(server_ip, port, server_type, server_flag, result_flag, ttl, 0);
return _dns_client_server_operate(server_ip, port, server_type, server_flag, result_flag, ttl, spki, 0);
}
int dns_client_remove_server(char *server_ip, int port, dns_server_type_t server_type)
{
return _dns_client_server_operate(server_ip, port, server_type, 0, 0, 0, 1);
return _dns_client_server_operate(server_ip, port, server_type, 0, 0, 0, NULL, 1);
}
int dns_server_num(void)
@@ -1653,11 +1682,12 @@ static int _dns_client_tls_verify(struct dns_server_info *server_info)
{
X509 *cert = NULL;
char peer_CN[256];
const EVP_MD *digest;
unsigned char md[EVP_MAX_MD_SIZE];
unsigned int n;
char cert_fingerprint[256];
int i = 0;
int key_len = 0;
unsigned char *key_data = NULL;
unsigned char *key_data_tmp = NULL;
unsigned char *key_sha256 = NULL;
cert = SSL_get_peer_certificate(server_info->ssl);
if (cert == NULL) {
@@ -1666,28 +1696,66 @@ static int _dns_client_tls_verify(struct dns_server_info *server_info)
}
X509_NAME_get_text_by_NID(X509_get_subject_name(cert), NID_commonName, peer_CN, 256);
tlog(TLOG_DEBUG, "peer CN: %s", peer_CN);
digest = EVP_get_digestbyname("sha256");
X509_digest(cert, digest, md, &n);
/* get spki pin */
key_len = i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), NULL);
if (key_len <= 0 ) {
tlog(TLOG_ERROR, "get x509 public key failed.");
goto errout;
}
key_data = OPENSSL_malloc(key_len);
key_data_tmp = key_data;
if (key_data == NULL) {
tlog(TLOG_ERROR, "malloc memory failed.");
goto errout;
}
i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), &key_data_tmp);
key_sha256 = SSL_SHA256(key_data, key_len, 0);
if (key_sha256 == NULL) {
tlog(TLOG_ERROR, "get sha256 failed.");
goto errout;
}
char *ptr = cert_fingerprint;
for (i = 0; i < 32; i++) {
*ptr = _dns_client_to_hex(md[i] >> 4 & 0xF);
for (i = 0; i < SHA256_DIGEST_LENGTH; i++) {
*ptr = _dns_client_to_hex(key_sha256[i] >> 4 & 0xF);
ptr++;
*ptr = _dns_client_to_hex(md[i] & 0xF);
*ptr = _dns_client_to_hex(key_sha256[i] & 0xF);
ptr++;
*ptr = ':';
ptr++;
}
ptr--;
*ptr = 0;
tlog(TLOG_DEBUG, "cert fingerprint(%s): %s", "sha256", cert_fingerprint);
tlog(TLOG_DEBUG, "cert SPKI pin(%s): %s", "sha256", cert_fingerprint);
if (server_info->spki) {
if (memcmp(server_info->spki, key_sha256, server_info->spki_len) != 0) {
tlog(TLOG_INFO, "server %s cert spki is invalid", server_info->ip);
goto errout;
} else {
tlog(TLOG_DEBUG, "server %s cert spki verify succeed", server_info->ip);
}
}
OPENSSL_free(key_data);
X509_free(cert);
return 0;
errout:
if (key_data) {
OPENSSL_free(key_data);
}
if (cert) {
X509_free(cert);
}
return -1;
}
static int _dns_client_process_tls(struct dns_server_info *server_info, struct epoll_event *event, unsigned long now)

View File

@@ -36,7 +36,7 @@ int dns_client_query(char *domain, int qtype, dns_client_callback callback, void
void dns_client_exit(void);
/* add remote dns server */
int dns_client_add_server(char *server_ip, int port, dns_server_type_t server_type, unsigned int server_flag, unsigned int result_flag, int ttl);
int dns_client_add_server(char *server_ip, int port, dns_server_type_t server_type, unsigned int server_flag, unsigned int result_flag, int ttl, char *spki);
/* remove remote dns server */
int dns_client_remove_server(char *server_ip, int port, dns_server_type_t server_type);

View File

@@ -142,11 +142,14 @@ int config_server(int argc, char *argv[], dns_server_type_t type, int default_po
int opt = 0;
unsigned int result_flag = 0;
unsigned int server_flag = 0;
unsigned char *spki = NULL;
int ttl = 0;
/* clang-format off */
static struct option long_options[] = {
{"blacklist-ip", 0, 0, 'b'},
{"check-edns", 0, 0, 'e'},
{"spki-pin", required_argument, 0, 'p'},
{"check-ttl", required_argument, 0, 't'},
{"group", required_argument, 0, 'g'},
{"exclude-default-group", 0, 0, 'E'},
@@ -164,6 +167,7 @@ int config_server(int argc, char *argv[], dns_server_type_t type, int default_po
}
server = &dns_conf_servers[index];
server->spki[0] = '\0';
ip = argv[1];
/* parse ip, port from ip */
@@ -200,7 +204,7 @@ int config_server(int argc, char *argv[], dns_server_type_t type, int default_po
ttl = atoi(optarg);
if (ttl < -255 || ttl > 255) {
tlog(TLOG_ERROR, "ttl value is invalid.");
return -1;
goto errout;
}
result_flag |= DNSSERVER_FLAG_CHECK_TTL;
break;
@@ -212,10 +216,14 @@ int config_server(int argc, char *argv[], dns_server_type_t type, int default_po
case 'g': {
if (dns_conf_get_group_set(optarg, server) != 0) {
tlog(TLOG_ERROR, "add group failed.");
return -1;
goto errout;
}
break;
}
case 'p': {
strncpy(server->spki, optarg, DNS_MAX_SPKI_LEN);
break;
}
default:
break;
}
@@ -230,6 +238,13 @@ int config_server(int argc, char *argv[], dns_server_type_t type, int default_po
tlog(TLOG_DEBUG, "add server %s, flag: %X, ttl: %d", ip, result_flag, ttl);
return 0;
errout:
if (spki) {
free(spki);
}
return -1;
}
int config_domain_iter_cb(void *data, const unsigned char *key, uint32_t key_len, void *value)

View File

@@ -15,6 +15,7 @@
#define DNS_GROUP_NAME_LEN 32
#define DNS_NAX_GROUP_NUMBER 16
#define DNS_MAX_IPLEN 64
#define DNS_MAX_SPKI_LEN 64
#define DNS_MAX_PATH 1024
#define DEFAULT_DNS_PORT 53
#define DEFAULT_DNS_TLS_PORT 853
@@ -91,6 +92,7 @@ struct dns_servers {
unsigned int server_flag;
int ttl;
dns_server_type_t type;
char spki[DNS_MAX_SPKI_LEN];
};
/* ip address lists of domain */

View File

@@ -130,7 +130,7 @@ int smartdns_add_servers(void)
for (i = 0; i < dns_conf_server_num; i++) {
ret = dns_client_add_server(dns_conf_servers[i].server, dns_conf_servers[i].port, dns_conf_servers[i].type, dns_conf_servers[i].server_flag, dns_conf_servers[i].result_flag,
dns_conf_servers[i].ttl);
dns_conf_servers[i].ttl, dns_conf_servers[i].spki);
if (ret != 0) {
tlog(TLOG_ERROR, "add server failed, %s:%d", dns_conf_servers[i].server, dns_conf_servers[i].port);
return -1;

View File

@@ -353,6 +353,44 @@ int ipset_del(const char *ipsetname, const unsigned char addr[], int addr_len)
return _ipset_operate(ipsetname, addr, addr_len, 0, IPSET_DEL);
}
unsigned char *SSL_SHA256(const unsigned char *d, size_t n, unsigned char *md)
{
SHA256_CTX c;
static unsigned char m[SHA256_DIGEST_LENGTH];
if (md == NULL)
md = m;
SHA256_Init(&c);
SHA256_Update(&c, d, n);
SHA256_Final(md, &c);
OPENSSL_cleanse(&c, sizeof(c));
return (md);
}
int SSL_base64_decode(const char *in, unsigned char *out)
{
size_t inlen = strlen(in);
int outlen;
if (inlen == 0) {
return 0;
}
outlen = EVP_DecodeBlock(out, (unsigned char *)in, inlen);
if (outlen < 0) {
goto errout;
}
/* Subtract padding bytes from |outlen| */
while (in[--inlen] == '=') {
--outlen;
}
return outlen;
errout:
return -1;
}
#define THREAD_STACK_SIZE (16*1024)
static pthread_mutex_t *lock_cs;
static long *lock_count;

View File

@@ -30,4 +30,8 @@ void SSL_CRYPTO_thread_setup(void);
void SSL_CRYPTO_thread_cleanup(void);
unsigned char *SSL_SHA256(const unsigned char *d, size_t n, unsigned char *md);
int SSL_base64_decode(const char *in, unsigned char *out);
#endif