From 64abad407737a2842479af19d4caef94c24da093 Mon Sep 17 00:00:00 2001 From: Nick Peng Date: Sat, 19 Oct 2019 18:46:22 +0800 Subject: [PATCH] Fix rule search issue. --- src/dns_server.c | 151 +++++++++++++++++++++++++++------------------- src/include/art.h | 11 ++++ src/lib/art.c | 54 +++++++++++++++++ 3 files changed, 153 insertions(+), 63 deletions(-) diff --git a/src/dns_server.c b/src/dns_server.c index 4c9eb2a..ee5a183 100644 --- a/src/dns_server.c +++ b/src/dns_server.c @@ -32,6 +32,8 @@ #include #include #include +#include +#include #include #include #include @@ -39,8 +41,6 @@ #include #include #include -#include -#include #define DNS_MAX_EVENTS 256 #define DNS_SERVER_TMOUT_TTL (5 * 60) @@ -65,6 +65,12 @@ typedef enum { DNS_CONN_TYPE_TLS_CLIENT, } DNS_CONN_TYPE; +struct rule_walk_args { + void *args; + unsigned char *key[DOMAIN_RULE_MAX]; + uint32_t key_len[DOMAIN_RULE_MAX]; +}; + struct dns_conn_buf { char buf[DNS_CONN_BUFF_SIZE]; int buffsize; @@ -190,7 +196,7 @@ struct dns_request { int ip_map_num; DECLARE_HASHTABLE(ip_map, 4); - struct dns_domain_rule *domain_rule; + struct dns_domain_rule domain_rule; struct dns_domain_check_order *check_order_list; }; @@ -256,21 +262,19 @@ static int _dns_server_is_return_soa(struct dns_request *request) return 1; } - if (request->domain_rule) { - rule_flag = request->domain_rule->rules[DOMAIN_RULE_FLAGS]; - if (rule_flag) { - flags = rule_flag->flags; - if (flags & DOMAIN_FLAG_ADDR_SOA) { - return 1; - } + rule_flag = request->domain_rule.rules[DOMAIN_RULE_FLAGS]; + if (rule_flag) { + flags = rule_flag->flags; + if (flags & DOMAIN_FLAG_ADDR_SOA) { + return 1; + } - if ((flags & DOMAIN_FLAG_ADDR_IPV4_SOA) && (request->qtype == DNS_T_A)) { - return 1; - } + if ((flags & DOMAIN_FLAG_ADDR_IPV4_SOA) && (request->qtype == DNS_T_A)) { + return 1; + } - if ((flags & DOMAIN_FLAG_ADDR_IPV6_SOA) && (request->qtype == DNS_T_AAAA)) { - return 1; - } + if ((flags & DOMAIN_FLAG_ADDR_IPV6_SOA) && (request->qtype == DNS_T_AAAA)) { + return 1; } } @@ -618,23 +622,19 @@ static int _dns_setup_ipset(struct dns_request *request) struct dns_rule_flags *rule_flags = NULL; int ret = 0; - if (request->domain_rule == NULL) { - return 0; - } - if (_dns_server_has_bind_flag(request, BIND_FLAG_NO_RULE_IPSET) == 0) { return 0; } /* check ipset rule */ - rule_flags = request->domain_rule->rules[DOMAIN_RULE_FLAGS]; + rule_flags = request->domain_rule.rules[DOMAIN_RULE_FLAGS]; if (rule_flags) { if (rule_flags->flags & DOMAIN_FLAG_IPSET_IGNORE) { return 0; } } - ipset_rule = request->domain_rule->rules[DOMAIN_RULE_IPSET]; + ipset_rule = request->domain_rule.rules[DOMAIN_RULE_IPSET]; if (ipset_rule == NULL) { return 0; } @@ -1704,7 +1704,7 @@ errout: return -1; } -static void _dns_server_log_rule(const char *domain, unsigned char *rule_key, int rule_key_len) +static void _dns_server_log_rule(const char *domain, enum domain_rule rule_type, unsigned char *rule_key, int rule_key_len) { char rule_name[DNS_MAX_CNAME_LEN]; @@ -1714,55 +1714,85 @@ static void _dns_server_log_rule(const char *domain, unsigned char *rule_key, in reverse_string(rule_name, (char *)rule_key, rule_key_len, 1); rule_name[rule_key_len] = 0; - tlog(TLOG_INFO, "RULE-MATCH, domain: %s, rule: %s", domain, rule_name); + tlog(TLOG_INFO, "RULE-MATCH, type: %d, domain: %s, rule: %s", rule_type, domain, rule_name); } -static struct dns_domain_rule *_dns_server_get_domain_rule(const char *domain) +int _dns_server_get_rules(unsigned char *key, uint32_t key_len, void *value, void *arg) +{ + struct rule_walk_args *walk_args = arg; + struct dns_request *request = walk_args->args; + struct dns_domain_rule *domain_rule = value; + int i = 0; + if (domain_rule == NULL) { + return 0; + } + + for (i = 0; i < DOMAIN_RULE_MAX; i++) { + if (domain_rule->rules[i] == NULL) { + continue; + } + + request->domain_rule.rules[i] = domain_rule->rules[i]; + walk_args->key[i] = key; + walk_args->key_len[i] = key_len; + } + + return 0; +} + +void _dns_server_get_domain_rule(struct dns_request *request) { int domain_len; char domain_key[DNS_MAX_CNAME_LEN]; int matched_key_len = DNS_MAX_CNAME_LEN; unsigned char matched_key[DNS_MAX_CNAME_LEN]; - struct dns_domain_rule *domain_rule = NULL; + struct rule_walk_args walk_args; + int i = 0; + + memset(&walk_args, 0, sizeof(walk_args)); + walk_args.args = request; /* reverse domain string */ - domain_len = strlen(domain); - reverse_string(domain_key, domain, domain_len, 1); + domain_len = strlen(request->domain); + reverse_string(domain_key, request->domain, domain_len, 1); domain_key[domain_len] = '.'; domain_len++; domain_key[domain_len] = 0; /* find domain rule */ - if (likely(dns_conf_log_level > TLOG_INFO)) { - return art_substring(&dns_conf_domain_rule, (unsigned char *)domain_key, domain_len, NULL, NULL); + art_substring_walk(&dns_conf_domain_rule, (unsigned char *)domain_key, domain_len, _dns_server_get_rules, &walk_args); + if (likely(dns_conf_log_level > TLOG_DEBUG)) { + return; } - domain_rule = art_substring(&dns_conf_domain_rule, (unsigned char *)domain_key, domain_len, matched_key, &matched_key_len); - if (domain_rule == NULL) { - return NULL; + /* output log rule */ + for (i = 0; i < DOMAIN_RULE_MAX; i++) { + if (walk_args.key[i] == NULL) { + continue; + } + + matched_key_len = walk_args.key_len[i]; + if (walk_args.key_len[i] >= sizeof(matched_key)) { + continue; + } + + memcpy(matched_key, walk_args.key[i], walk_args.key_len[i]); + + matched_key_len--; + matched_key[matched_key_len] = 0; + _dns_server_log_rule(request->domain, i, matched_key, matched_key_len); } - if (matched_key_len <= 0) { - return NULL; - } - - matched_key_len--; - matched_key[matched_key_len] = 0; - _dns_server_log_rule(domain, matched_key, matched_key_len); - - return domain_rule; + return; } static int _dns_server_pre_process_rule_flags(struct dns_request *request) { struct dns_rule_flags *rule_flag = NULL; unsigned int flags = 0; - if (request->domain_rule == NULL) { - goto out; - } /* get domain rule flag */ - rule_flag = request->domain_rule->rules[DOMAIN_RULE_FLAGS]; + rule_flag = request->domain_rule.rules[DOMAIN_RULE_FLAGS]; if (rule_flag == NULL) { goto out; } @@ -1826,10 +1856,6 @@ static int _dns_server_process_address(struct dns_request *request) struct dns_address_IPV4 *address_ipv4 = NULL; struct dns_address_IPV6 *address_ipv6 = NULL; - if (request->domain_rule == NULL) { - goto errout; - } - if (_dns_server_has_bind_flag(request, BIND_FLAG_NO_RULE_ADDR) == 0) { goto errout; } @@ -1837,19 +1863,19 @@ static int _dns_server_process_address(struct dns_request *request) /* address /domain/ rule */ switch (request->qtype) { case DNS_T_A: - if (request->domain_rule->rules[DOMAIN_RULE_ADDRESS_IPV4] == NULL) { + if (request->domain_rule.rules[DOMAIN_RULE_ADDRESS_IPV4] == NULL) { goto errout; } - address_ipv4 = request->domain_rule->rules[DOMAIN_RULE_ADDRESS_IPV4]; + address_ipv4 = request->domain_rule.rules[DOMAIN_RULE_ADDRESS_IPV4]; memcpy(request->ipv4_addr, address_ipv4->ipv4_addr, DNS_RR_A_LEN); request->ttl_v4 = 600; request->has_ipv4 = 1; break; case DNS_T_AAAA: - if (request->domain_rule->rules[DOMAIN_RULE_ADDRESS_IPV6] == NULL) { + if (request->domain_rule.rules[DOMAIN_RULE_ADDRESS_IPV6] == NULL) { goto errout; } - address_ipv6 = request->domain_rule->rules[DOMAIN_RULE_ADDRESS_IPV6]; + address_ipv6 = request->domain_rule.rules[DOMAIN_RULE_ADDRESS_IPV6]; memcpy(request->ipv6_addr, address_ipv6->ipv6_addr, DNS_RR_AAAA_LEN); request->ttl_v6 = 600; request->has_ipv6 = 1; @@ -2032,12 +2058,10 @@ static const char *_dns_server_get_request_groupname(struct dns_request *request return NULL; } - if (request->domain_rule) { - /* Get the nameserver rule */ - if (request->domain_rule->rules[DOMAIN_RULE_NAMESERVER]) { - struct dns_nameserver_rule *nameserver_rule = request->domain_rule->rules[DOMAIN_RULE_NAMESERVER]; - return nameserver_rule->group_name; - } + /* Get the nameserver rule */ + if (request->domain_rule.rules[DOMAIN_RULE_NAMESERVER]) { + struct dns_nameserver_rule *nameserver_rule = request->domain_rule.rules[DOMAIN_RULE_NAMESERVER]; + return nameserver_rule->group_name; } return NULL; @@ -2053,10 +2077,11 @@ static int _dns_server_do_query(struct dns_request *request, const char *domain, dns_group = request->conn->dns_group; } - /* lookup domain rule */ - request->domain_rule = _dns_server_get_domain_rule(domain); - request->qtype = qtype; safe_strncpy(request->domain, domain, sizeof(request->domain)); + request->qtype = qtype; + + /* lookup domain rule */ + _dns_server_get_domain_rule(request); group_name = _dns_server_get_request_groupname(request); if (group_name == NULL) { group_name = dns_group; diff --git a/src/include/art.h b/src/include/art.h index e39745f..eb9ccf4 100644 --- a/src/include/art.h +++ b/src/include/art.h @@ -195,6 +195,17 @@ void* art_search(const art_tree *t, const unsigned char *key, int key_len); */ void *art_substring(const art_tree *t, const unsigned char *str, int str_len, unsigned char *key, int *key_len); +/** + * Wakk substring for a value in the ART tree + * @arg t The tree + * @arg str The key + * @arg str_len The length of the key + * @return NULL if the item was not found, otherwise + * the value pointer is returned. + */ +typedef int (*walk_func)(unsigned char *key, uint32_t key_len, void *value, void *arg); +void art_substring_walk(const art_tree *t, const unsigned char *str, int str_len, walk_func func, void *arg); + /** * Returns the minimum valued leaf * @return The minimum leaf or NULL diff --git a/src/lib/art.c b/src/lib/art.c index 93d9cbd..be5a672 100644 --- a/src/lib/art.c +++ b/src/lib/art.c @@ -1073,3 +1073,57 @@ void *art_substring(const art_tree *t, const unsigned char *str, int str_len, un return found->value; } + +void art_substring_walk(const art_tree *t, const unsigned char *str, int str_len, walk_func func, void *arg) +{ + art_node **child; + art_node *n = t->root; + art_node *m; + art_leaf *found = NULL; + int prefix_len, depth = 0; + int stop_search = 0; + + while (n && stop_search == 0) { + // Might be a leaf + if (IS_LEAF(n)) { + n = (art_node*)LEAF_RAW(n); + // Check if the expanded path matches + if (!str_prefix_matches((art_leaf*)n, str, str_len)) { + found = (art_leaf*)n; + stop_search = func(found->key, found->key_len, found->value, arg); + } + break; + } + + // Check if current is leaf + child = find_child(n, 0); + m = (child) ? *child : NULL; + if (m && IS_LEAF(m)) { + m = (art_node*)LEAF_RAW(m); + // Check if the expanded path matches + if (!str_prefix_matches((art_leaf*)m, str, str_len)) { + found = (art_leaf*)m; + stop_search = func(found->key, found->key_len, found->value, arg); + } + } + + // Bail if the prefix does not match + if (n->partial_len) { + prefix_len = check_prefix(n, str, str_len, depth); + if (prefix_len != min(MAX_PREFIX_LEN, n->partial_len)) + break; + depth = depth + n->partial_len; + } + + // Recursively search + child = find_child(n, str[depth]); + n = (child) ? *child : NULL; + depth++; + } + + if (found == NULL) { + return ; + } + + return ; +}