Fix rule search issue.

This commit is contained in:
Nick Peng
2019-10-19 18:46:22 +08:00
parent c3501923db
commit 64abad4077
3 changed files with 153 additions and 63 deletions

View File

@@ -32,6 +32,8 @@
#include <errno.h>
#include <fcntl.h>
#include <ifaddrs.h>
#include <netinet/ip.h>
#include <netinet/tcp.h>
#include <pthread.h>
#include <stdio.h>
#include <stdlib.h>
@@ -39,8 +41,6 @@
#include <sys/epoll.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <netinet/ip.h>
#include <netinet/tcp.h>
#define DNS_MAX_EVENTS 256
#define DNS_SERVER_TMOUT_TTL (5 * 60)
@@ -65,6 +65,12 @@ typedef enum {
DNS_CONN_TYPE_TLS_CLIENT,
} DNS_CONN_TYPE;
struct rule_walk_args {
void *args;
unsigned char *key[DOMAIN_RULE_MAX];
uint32_t key_len[DOMAIN_RULE_MAX];
};
struct dns_conn_buf {
char buf[DNS_CONN_BUFF_SIZE];
int buffsize;
@@ -190,7 +196,7 @@ struct dns_request {
int ip_map_num;
DECLARE_HASHTABLE(ip_map, 4);
struct dns_domain_rule *domain_rule;
struct dns_domain_rule domain_rule;
struct dns_domain_check_order *check_order_list;
};
@@ -256,21 +262,19 @@ static int _dns_server_is_return_soa(struct dns_request *request)
return 1;
}
if (request->domain_rule) {
rule_flag = request->domain_rule->rules[DOMAIN_RULE_FLAGS];
if (rule_flag) {
flags = rule_flag->flags;
if (flags & DOMAIN_FLAG_ADDR_SOA) {
return 1;
}
rule_flag = request->domain_rule.rules[DOMAIN_RULE_FLAGS];
if (rule_flag) {
flags = rule_flag->flags;
if (flags & DOMAIN_FLAG_ADDR_SOA) {
return 1;
}
if ((flags & DOMAIN_FLAG_ADDR_IPV4_SOA) && (request->qtype == DNS_T_A)) {
return 1;
}
if ((flags & DOMAIN_FLAG_ADDR_IPV4_SOA) && (request->qtype == DNS_T_A)) {
return 1;
}
if ((flags & DOMAIN_FLAG_ADDR_IPV6_SOA) && (request->qtype == DNS_T_AAAA)) {
return 1;
}
if ((flags & DOMAIN_FLAG_ADDR_IPV6_SOA) && (request->qtype == DNS_T_AAAA)) {
return 1;
}
}
@@ -618,23 +622,19 @@ static int _dns_setup_ipset(struct dns_request *request)
struct dns_rule_flags *rule_flags = NULL;
int ret = 0;
if (request->domain_rule == NULL) {
return 0;
}
if (_dns_server_has_bind_flag(request, BIND_FLAG_NO_RULE_IPSET) == 0) {
return 0;
}
/* check ipset rule */
rule_flags = request->domain_rule->rules[DOMAIN_RULE_FLAGS];
rule_flags = request->domain_rule.rules[DOMAIN_RULE_FLAGS];
if (rule_flags) {
if (rule_flags->flags & DOMAIN_FLAG_IPSET_IGNORE) {
return 0;
}
}
ipset_rule = request->domain_rule->rules[DOMAIN_RULE_IPSET];
ipset_rule = request->domain_rule.rules[DOMAIN_RULE_IPSET];
if (ipset_rule == NULL) {
return 0;
}
@@ -1704,7 +1704,7 @@ errout:
return -1;
}
static void _dns_server_log_rule(const char *domain, unsigned char *rule_key, int rule_key_len)
static void _dns_server_log_rule(const char *domain, enum domain_rule rule_type, unsigned char *rule_key, int rule_key_len)
{
char rule_name[DNS_MAX_CNAME_LEN];
@@ -1714,55 +1714,85 @@ static void _dns_server_log_rule(const char *domain, unsigned char *rule_key, in
reverse_string(rule_name, (char *)rule_key, rule_key_len, 1);
rule_name[rule_key_len] = 0;
tlog(TLOG_INFO, "RULE-MATCH, domain: %s, rule: %s", domain, rule_name);
tlog(TLOG_INFO, "RULE-MATCH, type: %d, domain: %s, rule: %s", rule_type, domain, rule_name);
}
static struct dns_domain_rule *_dns_server_get_domain_rule(const char *domain)
int _dns_server_get_rules(unsigned char *key, uint32_t key_len, void *value, void *arg)
{
struct rule_walk_args *walk_args = arg;
struct dns_request *request = walk_args->args;
struct dns_domain_rule *domain_rule = value;
int i = 0;
if (domain_rule == NULL) {
return 0;
}
for (i = 0; i < DOMAIN_RULE_MAX; i++) {
if (domain_rule->rules[i] == NULL) {
continue;
}
request->domain_rule.rules[i] = domain_rule->rules[i];
walk_args->key[i] = key;
walk_args->key_len[i] = key_len;
}
return 0;
}
void _dns_server_get_domain_rule(struct dns_request *request)
{
int domain_len;
char domain_key[DNS_MAX_CNAME_LEN];
int matched_key_len = DNS_MAX_CNAME_LEN;
unsigned char matched_key[DNS_MAX_CNAME_LEN];
struct dns_domain_rule *domain_rule = NULL;
struct rule_walk_args walk_args;
int i = 0;
memset(&walk_args, 0, sizeof(walk_args));
walk_args.args = request;
/* reverse domain string */
domain_len = strlen(domain);
reverse_string(domain_key, domain, domain_len, 1);
domain_len = strlen(request->domain);
reverse_string(domain_key, request->domain, domain_len, 1);
domain_key[domain_len] = '.';
domain_len++;
domain_key[domain_len] = 0;
/* find domain rule */
if (likely(dns_conf_log_level > TLOG_INFO)) {
return art_substring(&dns_conf_domain_rule, (unsigned char *)domain_key, domain_len, NULL, NULL);
art_substring_walk(&dns_conf_domain_rule, (unsigned char *)domain_key, domain_len, _dns_server_get_rules, &walk_args);
if (likely(dns_conf_log_level > TLOG_DEBUG)) {
return;
}
domain_rule = art_substring(&dns_conf_domain_rule, (unsigned char *)domain_key, domain_len, matched_key, &matched_key_len);
if (domain_rule == NULL) {
return NULL;
/* output log rule */
for (i = 0; i < DOMAIN_RULE_MAX; i++) {
if (walk_args.key[i] == NULL) {
continue;
}
matched_key_len = walk_args.key_len[i];
if (walk_args.key_len[i] >= sizeof(matched_key)) {
continue;
}
memcpy(matched_key, walk_args.key[i], walk_args.key_len[i]);
matched_key_len--;
matched_key[matched_key_len] = 0;
_dns_server_log_rule(request->domain, i, matched_key, matched_key_len);
}
if (matched_key_len <= 0) {
return NULL;
}
matched_key_len--;
matched_key[matched_key_len] = 0;
_dns_server_log_rule(domain, matched_key, matched_key_len);
return domain_rule;
return;
}
static int _dns_server_pre_process_rule_flags(struct dns_request *request)
{
struct dns_rule_flags *rule_flag = NULL;
unsigned int flags = 0;
if (request->domain_rule == NULL) {
goto out;
}
/* get domain rule flag */
rule_flag = request->domain_rule->rules[DOMAIN_RULE_FLAGS];
rule_flag = request->domain_rule.rules[DOMAIN_RULE_FLAGS];
if (rule_flag == NULL) {
goto out;
}
@@ -1826,10 +1856,6 @@ static int _dns_server_process_address(struct dns_request *request)
struct dns_address_IPV4 *address_ipv4 = NULL;
struct dns_address_IPV6 *address_ipv6 = NULL;
if (request->domain_rule == NULL) {
goto errout;
}
if (_dns_server_has_bind_flag(request, BIND_FLAG_NO_RULE_ADDR) == 0) {
goto errout;
}
@@ -1837,19 +1863,19 @@ static int _dns_server_process_address(struct dns_request *request)
/* address /domain/ rule */
switch (request->qtype) {
case DNS_T_A:
if (request->domain_rule->rules[DOMAIN_RULE_ADDRESS_IPV4] == NULL) {
if (request->domain_rule.rules[DOMAIN_RULE_ADDRESS_IPV4] == NULL) {
goto errout;
}
address_ipv4 = request->domain_rule->rules[DOMAIN_RULE_ADDRESS_IPV4];
address_ipv4 = request->domain_rule.rules[DOMAIN_RULE_ADDRESS_IPV4];
memcpy(request->ipv4_addr, address_ipv4->ipv4_addr, DNS_RR_A_LEN);
request->ttl_v4 = 600;
request->has_ipv4 = 1;
break;
case DNS_T_AAAA:
if (request->domain_rule->rules[DOMAIN_RULE_ADDRESS_IPV6] == NULL) {
if (request->domain_rule.rules[DOMAIN_RULE_ADDRESS_IPV6] == NULL) {
goto errout;
}
address_ipv6 = request->domain_rule->rules[DOMAIN_RULE_ADDRESS_IPV6];
address_ipv6 = request->domain_rule.rules[DOMAIN_RULE_ADDRESS_IPV6];
memcpy(request->ipv6_addr, address_ipv6->ipv6_addr, DNS_RR_AAAA_LEN);
request->ttl_v6 = 600;
request->has_ipv6 = 1;
@@ -2032,12 +2058,10 @@ static const char *_dns_server_get_request_groupname(struct dns_request *request
return NULL;
}
if (request->domain_rule) {
/* Get the nameserver rule */
if (request->domain_rule->rules[DOMAIN_RULE_NAMESERVER]) {
struct dns_nameserver_rule *nameserver_rule = request->domain_rule->rules[DOMAIN_RULE_NAMESERVER];
return nameserver_rule->group_name;
}
/* Get the nameserver rule */
if (request->domain_rule.rules[DOMAIN_RULE_NAMESERVER]) {
struct dns_nameserver_rule *nameserver_rule = request->domain_rule.rules[DOMAIN_RULE_NAMESERVER];
return nameserver_rule->group_name;
}
return NULL;
@@ -2053,10 +2077,11 @@ static int _dns_server_do_query(struct dns_request *request, const char *domain,
dns_group = request->conn->dns_group;
}
/* lookup domain rule */
request->domain_rule = _dns_server_get_domain_rule(domain);
request->qtype = qtype;
safe_strncpy(request->domain, domain, sizeof(request->domain));
request->qtype = qtype;
/* lookup domain rule */
_dns_server_get_domain_rule(request);
group_name = _dns_server_get_request_groupname(request);
if (group_name == NULL) {
group_name = dns_group;

View File

@@ -195,6 +195,17 @@ void* art_search(const art_tree *t, const unsigned char *key, int key_len);
*/
void *art_substring(const art_tree *t, const unsigned char *str, int str_len, unsigned char *key, int *key_len);
/**
* Wakk substring for a value in the ART tree
* @arg t The tree
* @arg str The key
* @arg str_len The length of the key
* @return NULL if the item was not found, otherwise
* the value pointer is returned.
*/
typedef int (*walk_func)(unsigned char *key, uint32_t key_len, void *value, void *arg);
void art_substring_walk(const art_tree *t, const unsigned char *str, int str_len, walk_func func, void *arg);
/**
* Returns the minimum valued leaf
* @return The minimum leaf or NULL

View File

@@ -1073,3 +1073,57 @@ void *art_substring(const art_tree *t, const unsigned char *str, int str_len, un
return found->value;
}
void art_substring_walk(const art_tree *t, const unsigned char *str, int str_len, walk_func func, void *arg)
{
art_node **child;
art_node *n = t->root;
art_node *m;
art_leaf *found = NULL;
int prefix_len, depth = 0;
int stop_search = 0;
while (n && stop_search == 0) {
// Might be a leaf
if (IS_LEAF(n)) {
n = (art_node*)LEAF_RAW(n);
// Check if the expanded path matches
if (!str_prefix_matches((art_leaf*)n, str, str_len)) {
found = (art_leaf*)n;
stop_search = func(found->key, found->key_len, found->value, arg);
}
break;
}
// Check if current is leaf
child = find_child(n, 0);
m = (child) ? *child : NULL;
if (m && IS_LEAF(m)) {
m = (art_node*)LEAF_RAW(m);
// Check if the expanded path matches
if (!str_prefix_matches((art_leaf*)m, str, str_len)) {
found = (art_leaf*)m;
stop_search = func(found->key, found->key_len, found->value, arg);
}
}
// Bail if the prefix does not match
if (n->partial_len) {
prefix_len = check_prefix(n, str, str_len, depth);
if (prefix_len != min(MAX_PREFIX_LEN, n->partial_len))
break;
depth = depth + n->partial_len;
}
// Recursively search
child = find_child(n, str[depth]);
n = (child) ? *child : NULL;
depth++;
}
if (found == NULL) {
return ;
}
return ;
}