diff --git a/src/dns_client.c b/src/dns_client.c index f73bae0..37fdddb 100644 --- a/src/dns_client.c +++ b/src/dns_client.c @@ -253,6 +253,9 @@ static int dns_client_has_bootstrap_dns = 0; int _ssl_read(struct dns_server_info *server, void *buff, int num) { int ret = 0; + if (server == NULL || buff == NULL) { + return SSL_ERROR_SYSCALL; + } pthread_mutex_lock(&server->lock); ret = SSL_read(server->ssl, buff, num); pthread_mutex_unlock(&server->lock); @@ -262,6 +265,10 @@ int _ssl_read(struct dns_server_info *server, void *buff, int num) int _ssl_write(struct dns_server_info *server, const void *buff, int num) { int ret = 0; + if (server == NULL || buff == NULL || server->ssl == NULL) { + return SSL_ERROR_SYSCALL; + } + pthread_mutex_lock(&server->lock); ret = SSL_write(server->ssl, buff, num); pthread_mutex_unlock(&server->lock); @@ -271,6 +278,10 @@ int _ssl_write(struct dns_server_info *server, const void *buff, int num) int _ssl_shutdown(struct dns_server_info *server) { int ret = 0; + if (server == NULL || server->ssl == NULL) { + return SSL_ERROR_SYSCALL; + } + pthread_mutex_lock(&server->lock); ret = SSL_shutdown(server->ssl); pthread_mutex_unlock(&server->lock); @@ -280,6 +291,10 @@ int _ssl_shutdown(struct dns_server_info *server) int _ssl_get_error(struct dns_server_info *server, int ret) { int err = 0; + if (server == NULL || server->ssl == NULL) { + return SSL_ERROR_SYSCALL; + } + pthread_mutex_lock(&server->lock); err = SSL_get_error(server->ssl, ret); pthread_mutex_unlock(&server->lock); @@ -289,6 +304,10 @@ int _ssl_get_error(struct dns_server_info *server, int ret) int _ssl_do_handshake(struct dns_server_info *server) { int err = 0; + if (server == NULL || server->ssl == NULL) { + return SSL_ERROR_SYSCALL; + } + pthread_mutex_lock(&server->lock); err = SSL_do_handshake(server->ssl); pthread_mutex_unlock(&server->lock); @@ -298,6 +317,10 @@ int _ssl_do_handshake(struct dns_server_info *server) int _ssl_session_reused(struct dns_server_info *server) { int err = 0; + if (server == NULL || server->ssl == NULL) { + return SSL_ERROR_SYSCALL; + } + pthread_mutex_lock(&server->lock); err = SSL_session_reused(server->ssl); pthread_mutex_unlock(&server->lock); @@ -307,6 +330,10 @@ int _ssl_session_reused(struct dns_server_info *server) SSL_SESSION *_ssl_get1_session(struct dns_server_info *server) { SSL_SESSION *ret = 0; + if (server == NULL || server->ssl == NULL) { + return NULL; + } + pthread_mutex_lock(&server->lock); ret = SSL_get1_session(server->ssl); pthread_mutex_unlock(&server->lock); @@ -407,6 +434,10 @@ static struct dns_server_info *_dns_client_get_server(char *server_ip, int port, struct dns_server_info *server_info, *tmp; struct dns_server_info *server_info_return = NULL; + if (server_ip == NULL) { + return NULL; + } + pthread_mutex_lock(&client.server_list_lock); list_for_each_entry_safe(server_info, tmp, &client.dns_server_list, list) { @@ -509,6 +540,10 @@ static int _dns_client_add_to_pending_group(char *group_name, char *server_ip, i struct dns_server_pending *pending = NULL; struct dns_server_pending_group *group = NULL; + if (group_name == NULL || server_ip == NULL) { + goto errout; + } + pthread_mutex_lock(&pending_server_mutex); list_for_each_entry_safe(item, tmp, &pending_servers, list) { @@ -550,6 +585,10 @@ static int _dns_client_add_to_group_pending(char *group_name, char *server_ip, i { struct dns_server_info *server_info = NULL; + if (group_name == NULL || server_ip == NULL) { + return -1; + } + server_info = _dns_client_get_server(server_ip, port, server_type); if (server_info == NULL) { if (ispending == 0) { @@ -630,6 +669,10 @@ int dns_client_add_group(char *group_name) unsigned long key; struct dns_server_group *group = NULL; + if (group_name == NULL) { + return -1; + } + if (_dns_client_get_group(group_name) != NULL) { tlog(TLOG_ERROR, "add group %s failed, group already exists", group_name); return -1; @@ -661,6 +704,10 @@ static int _dns_client_remove_group(struct dns_server_group *group) struct dns_server_group_member *group_member; struct dns_server_group_member *tmp; + if (group == NULL) { + return 0; + } + list_for_each_entry_safe(group_member, tmp, &group->head, list) { _dns_client_remove_member(group_member); @@ -678,6 +725,10 @@ int dns_client_remove_group(char *group_name) struct dns_server_group *group = NULL; struct hlist_node *tmp = NULL; + if (group_name == NULL) { + return -1; + } + key = hash_string(group_name); hash_for_each_possible_safe(client.group, group, tmp, node, key) { @@ -786,6 +837,10 @@ static int _dns_client_set_trusted_cert(SSL_CTX *ssl_ctx) char *capath = NULL; int cert_path_set = 0; + if (ssl_ctx == NULL) { + return -1; + } + if (dns_conf_ca_file[0]) { cafile = dns_conf_ca_file; } @@ -795,7 +850,7 @@ static int _dns_client_set_trusted_cert(SSL_CTX *ssl_ctx) } if (cafile == NULL && capath == NULL) { - if (SSL_CTX_set_default_verify_paths(ssl_ctx)) { + if (SSL_CTX_set_default_verify_paths(ssl_ctx) == 0) { cafile = "/etc/ssl/certs/ca-certificates.crt"; capath = "/etc/ssl/certs"; } else { @@ -804,7 +859,7 @@ static int _dns_client_set_trusted_cert(SSL_CTX *ssl_ctx) } if (cert_path_set == 0) { - if (!SSL_CTX_load_verify_locations(ssl_ctx, cafile, capath)) { + if (SSL_CTX_load_verify_locations(ssl_ctx, cafile, capath) == 0) { tlog(TLOG_WARN, "load certificate from %s:%s failed.", cafile, capath); return -1; } @@ -2855,6 +2910,10 @@ int dns_client_query(char *domain, int qtype, dns_client_callback callback, void int ret = 0; uint32_t key = 0; + if (domain == NULL) { + goto errout; + } + query = malloc(sizeof(*query)); if (query == NULL) { goto errout; diff --git a/src/dns_conf.c b/src/dns_conf.c index 83b4cc8..be2113a 100644 --- a/src/dns_conf.c +++ b/src/dns_conf.c @@ -1454,8 +1454,14 @@ int config_addtional_file(void *data, int argc, char *argv[]) if (conf_file[0] != '/') { safe_strncpy(file_path_dir, conf_get_conf_file(), DNS_MAX_PATH); dirname(file_path_dir); - if (snprintf(file_path, DNS_MAX_PATH, "%s/%s", file_path_dir, conf_file) < 0) { - return -1; + if (strncmp(file_path_dir, conf_get_conf_file(), sizeof(file_path_dir)) == 0) { + if (snprintf(file_path, DNS_MAX_PATH, "%s", conf_file) < 0) { + return -1; + } + } else { + if (snprintf(file_path, DNS_MAX_PATH, "%s/%s", file_path_dir, conf_file) < 0) { + return -1; + } } } else { safe_strncpy(file_path, conf_file, DNS_MAX_PATH);