Fix rule search issue.
This commit is contained in:
151
src/dns_server.c
151
src/dns_server.c
@@ -32,6 +32,8 @@
|
||||
#include <errno.h>
|
||||
#include <fcntl.h>
|
||||
#include <ifaddrs.h>
|
||||
#include <netinet/ip.h>
|
||||
#include <netinet/tcp.h>
|
||||
#include <pthread.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
@@ -39,8 +41,6 @@
|
||||
#include <sys/epoll.h>
|
||||
#include <sys/socket.h>
|
||||
#include <sys/types.h>
|
||||
#include <netinet/ip.h>
|
||||
#include <netinet/tcp.h>
|
||||
|
||||
#define DNS_MAX_EVENTS 256
|
||||
#define DNS_SERVER_TMOUT_TTL (5 * 60)
|
||||
@@ -65,6 +65,12 @@ typedef enum {
|
||||
DNS_CONN_TYPE_TLS_CLIENT,
|
||||
} DNS_CONN_TYPE;
|
||||
|
||||
struct rule_walk_args {
|
||||
void *args;
|
||||
unsigned char *key[DOMAIN_RULE_MAX];
|
||||
uint32_t key_len[DOMAIN_RULE_MAX];
|
||||
};
|
||||
|
||||
struct dns_conn_buf {
|
||||
char buf[DNS_CONN_BUFF_SIZE];
|
||||
int buffsize;
|
||||
@@ -190,7 +196,7 @@ struct dns_request {
|
||||
int ip_map_num;
|
||||
DECLARE_HASHTABLE(ip_map, 4);
|
||||
|
||||
struct dns_domain_rule *domain_rule;
|
||||
struct dns_domain_rule domain_rule;
|
||||
struct dns_domain_check_order *check_order_list;
|
||||
};
|
||||
|
||||
@@ -256,21 +262,19 @@ static int _dns_server_is_return_soa(struct dns_request *request)
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (request->domain_rule) {
|
||||
rule_flag = request->domain_rule->rules[DOMAIN_RULE_FLAGS];
|
||||
if (rule_flag) {
|
||||
flags = rule_flag->flags;
|
||||
if (flags & DOMAIN_FLAG_ADDR_SOA) {
|
||||
return 1;
|
||||
}
|
||||
rule_flag = request->domain_rule.rules[DOMAIN_RULE_FLAGS];
|
||||
if (rule_flag) {
|
||||
flags = rule_flag->flags;
|
||||
if (flags & DOMAIN_FLAG_ADDR_SOA) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
if ((flags & DOMAIN_FLAG_ADDR_IPV4_SOA) && (request->qtype == DNS_T_A)) {
|
||||
return 1;
|
||||
}
|
||||
if ((flags & DOMAIN_FLAG_ADDR_IPV4_SOA) && (request->qtype == DNS_T_A)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
if ((flags & DOMAIN_FLAG_ADDR_IPV6_SOA) && (request->qtype == DNS_T_AAAA)) {
|
||||
return 1;
|
||||
}
|
||||
if ((flags & DOMAIN_FLAG_ADDR_IPV6_SOA) && (request->qtype == DNS_T_AAAA)) {
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -618,23 +622,19 @@ static int _dns_setup_ipset(struct dns_request *request)
|
||||
struct dns_rule_flags *rule_flags = NULL;
|
||||
int ret = 0;
|
||||
|
||||
if (request->domain_rule == NULL) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (_dns_server_has_bind_flag(request, BIND_FLAG_NO_RULE_IPSET) == 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
/* check ipset rule */
|
||||
rule_flags = request->domain_rule->rules[DOMAIN_RULE_FLAGS];
|
||||
rule_flags = request->domain_rule.rules[DOMAIN_RULE_FLAGS];
|
||||
if (rule_flags) {
|
||||
if (rule_flags->flags & DOMAIN_FLAG_IPSET_IGNORE) {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
ipset_rule = request->domain_rule->rules[DOMAIN_RULE_IPSET];
|
||||
ipset_rule = request->domain_rule.rules[DOMAIN_RULE_IPSET];
|
||||
if (ipset_rule == NULL) {
|
||||
return 0;
|
||||
}
|
||||
@@ -1704,7 +1704,7 @@ errout:
|
||||
return -1;
|
||||
}
|
||||
|
||||
static void _dns_server_log_rule(const char *domain, unsigned char *rule_key, int rule_key_len)
|
||||
static void _dns_server_log_rule(const char *domain, enum domain_rule rule_type, unsigned char *rule_key, int rule_key_len)
|
||||
{
|
||||
char rule_name[DNS_MAX_CNAME_LEN];
|
||||
|
||||
@@ -1714,55 +1714,85 @@ static void _dns_server_log_rule(const char *domain, unsigned char *rule_key, in
|
||||
|
||||
reverse_string(rule_name, (char *)rule_key, rule_key_len, 1);
|
||||
rule_name[rule_key_len] = 0;
|
||||
tlog(TLOG_INFO, "RULE-MATCH, domain: %s, rule: %s", domain, rule_name);
|
||||
tlog(TLOG_INFO, "RULE-MATCH, type: %d, domain: %s, rule: %s", rule_type, domain, rule_name);
|
||||
}
|
||||
|
||||
static struct dns_domain_rule *_dns_server_get_domain_rule(const char *domain)
|
||||
int _dns_server_get_rules(unsigned char *key, uint32_t key_len, void *value, void *arg)
|
||||
{
|
||||
struct rule_walk_args *walk_args = arg;
|
||||
struct dns_request *request = walk_args->args;
|
||||
struct dns_domain_rule *domain_rule = value;
|
||||
int i = 0;
|
||||
if (domain_rule == NULL) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
for (i = 0; i < DOMAIN_RULE_MAX; i++) {
|
||||
if (domain_rule->rules[i] == NULL) {
|
||||
continue;
|
||||
}
|
||||
|
||||
request->domain_rule.rules[i] = domain_rule->rules[i];
|
||||
walk_args->key[i] = key;
|
||||
walk_args->key_len[i] = key_len;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
void _dns_server_get_domain_rule(struct dns_request *request)
|
||||
{
|
||||
int domain_len;
|
||||
char domain_key[DNS_MAX_CNAME_LEN];
|
||||
int matched_key_len = DNS_MAX_CNAME_LEN;
|
||||
unsigned char matched_key[DNS_MAX_CNAME_LEN];
|
||||
struct dns_domain_rule *domain_rule = NULL;
|
||||
struct rule_walk_args walk_args;
|
||||
int i = 0;
|
||||
|
||||
memset(&walk_args, 0, sizeof(walk_args));
|
||||
walk_args.args = request;
|
||||
|
||||
/* reverse domain string */
|
||||
domain_len = strlen(domain);
|
||||
reverse_string(domain_key, domain, domain_len, 1);
|
||||
domain_len = strlen(request->domain);
|
||||
reverse_string(domain_key, request->domain, domain_len, 1);
|
||||
domain_key[domain_len] = '.';
|
||||
domain_len++;
|
||||
domain_key[domain_len] = 0;
|
||||
|
||||
/* find domain rule */
|
||||
if (likely(dns_conf_log_level > TLOG_INFO)) {
|
||||
return art_substring(&dns_conf_domain_rule, (unsigned char *)domain_key, domain_len, NULL, NULL);
|
||||
art_substring_walk(&dns_conf_domain_rule, (unsigned char *)domain_key, domain_len, _dns_server_get_rules, &walk_args);
|
||||
if (likely(dns_conf_log_level > TLOG_DEBUG)) {
|
||||
return;
|
||||
}
|
||||
|
||||
domain_rule = art_substring(&dns_conf_domain_rule, (unsigned char *)domain_key, domain_len, matched_key, &matched_key_len);
|
||||
if (domain_rule == NULL) {
|
||||
return NULL;
|
||||
/* output log rule */
|
||||
for (i = 0; i < DOMAIN_RULE_MAX; i++) {
|
||||
if (walk_args.key[i] == NULL) {
|
||||
continue;
|
||||
}
|
||||
|
||||
matched_key_len = walk_args.key_len[i];
|
||||
if (walk_args.key_len[i] >= sizeof(matched_key)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
memcpy(matched_key, walk_args.key[i], walk_args.key_len[i]);
|
||||
|
||||
matched_key_len--;
|
||||
matched_key[matched_key_len] = 0;
|
||||
_dns_server_log_rule(request->domain, i, matched_key, matched_key_len);
|
||||
}
|
||||
|
||||
if (matched_key_len <= 0) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
matched_key_len--;
|
||||
matched_key[matched_key_len] = 0;
|
||||
_dns_server_log_rule(domain, matched_key, matched_key_len);
|
||||
|
||||
return domain_rule;
|
||||
return;
|
||||
}
|
||||
|
||||
static int _dns_server_pre_process_rule_flags(struct dns_request *request)
|
||||
{
|
||||
struct dns_rule_flags *rule_flag = NULL;
|
||||
unsigned int flags = 0;
|
||||
if (request->domain_rule == NULL) {
|
||||
goto out;
|
||||
}
|
||||
|
||||
/* get domain rule flag */
|
||||
rule_flag = request->domain_rule->rules[DOMAIN_RULE_FLAGS];
|
||||
rule_flag = request->domain_rule.rules[DOMAIN_RULE_FLAGS];
|
||||
if (rule_flag == NULL) {
|
||||
goto out;
|
||||
}
|
||||
@@ -1826,10 +1856,6 @@ static int _dns_server_process_address(struct dns_request *request)
|
||||
struct dns_address_IPV4 *address_ipv4 = NULL;
|
||||
struct dns_address_IPV6 *address_ipv6 = NULL;
|
||||
|
||||
if (request->domain_rule == NULL) {
|
||||
goto errout;
|
||||
}
|
||||
|
||||
if (_dns_server_has_bind_flag(request, BIND_FLAG_NO_RULE_ADDR) == 0) {
|
||||
goto errout;
|
||||
}
|
||||
@@ -1837,19 +1863,19 @@ static int _dns_server_process_address(struct dns_request *request)
|
||||
/* address /domain/ rule */
|
||||
switch (request->qtype) {
|
||||
case DNS_T_A:
|
||||
if (request->domain_rule->rules[DOMAIN_RULE_ADDRESS_IPV4] == NULL) {
|
||||
if (request->domain_rule.rules[DOMAIN_RULE_ADDRESS_IPV4] == NULL) {
|
||||
goto errout;
|
||||
}
|
||||
address_ipv4 = request->domain_rule->rules[DOMAIN_RULE_ADDRESS_IPV4];
|
||||
address_ipv4 = request->domain_rule.rules[DOMAIN_RULE_ADDRESS_IPV4];
|
||||
memcpy(request->ipv4_addr, address_ipv4->ipv4_addr, DNS_RR_A_LEN);
|
||||
request->ttl_v4 = 600;
|
||||
request->has_ipv4 = 1;
|
||||
break;
|
||||
case DNS_T_AAAA:
|
||||
if (request->domain_rule->rules[DOMAIN_RULE_ADDRESS_IPV6] == NULL) {
|
||||
if (request->domain_rule.rules[DOMAIN_RULE_ADDRESS_IPV6] == NULL) {
|
||||
goto errout;
|
||||
}
|
||||
address_ipv6 = request->domain_rule->rules[DOMAIN_RULE_ADDRESS_IPV6];
|
||||
address_ipv6 = request->domain_rule.rules[DOMAIN_RULE_ADDRESS_IPV6];
|
||||
memcpy(request->ipv6_addr, address_ipv6->ipv6_addr, DNS_RR_AAAA_LEN);
|
||||
request->ttl_v6 = 600;
|
||||
request->has_ipv6 = 1;
|
||||
@@ -2032,12 +2058,10 @@ static const char *_dns_server_get_request_groupname(struct dns_request *request
|
||||
return NULL;
|
||||
}
|
||||
|
||||
if (request->domain_rule) {
|
||||
/* Get the nameserver rule */
|
||||
if (request->domain_rule->rules[DOMAIN_RULE_NAMESERVER]) {
|
||||
struct dns_nameserver_rule *nameserver_rule = request->domain_rule->rules[DOMAIN_RULE_NAMESERVER];
|
||||
return nameserver_rule->group_name;
|
||||
}
|
||||
/* Get the nameserver rule */
|
||||
if (request->domain_rule.rules[DOMAIN_RULE_NAMESERVER]) {
|
||||
struct dns_nameserver_rule *nameserver_rule = request->domain_rule.rules[DOMAIN_RULE_NAMESERVER];
|
||||
return nameserver_rule->group_name;
|
||||
}
|
||||
|
||||
return NULL;
|
||||
@@ -2053,10 +2077,11 @@ static int _dns_server_do_query(struct dns_request *request, const char *domain,
|
||||
dns_group = request->conn->dns_group;
|
||||
}
|
||||
|
||||
/* lookup domain rule */
|
||||
request->domain_rule = _dns_server_get_domain_rule(domain);
|
||||
request->qtype = qtype;
|
||||
safe_strncpy(request->domain, domain, sizeof(request->domain));
|
||||
request->qtype = qtype;
|
||||
|
||||
/* lookup domain rule */
|
||||
_dns_server_get_domain_rule(request);
|
||||
group_name = _dns_server_get_request_groupname(request);
|
||||
if (group_name == NULL) {
|
||||
group_name = dns_group;
|
||||
|
||||
@@ -195,6 +195,17 @@ void* art_search(const art_tree *t, const unsigned char *key, int key_len);
|
||||
*/
|
||||
void *art_substring(const art_tree *t, const unsigned char *str, int str_len, unsigned char *key, int *key_len);
|
||||
|
||||
/**
|
||||
* Wakk substring for a value in the ART tree
|
||||
* @arg t The tree
|
||||
* @arg str The key
|
||||
* @arg str_len The length of the key
|
||||
* @return NULL if the item was not found, otherwise
|
||||
* the value pointer is returned.
|
||||
*/
|
||||
typedef int (*walk_func)(unsigned char *key, uint32_t key_len, void *value, void *arg);
|
||||
void art_substring_walk(const art_tree *t, const unsigned char *str, int str_len, walk_func func, void *arg);
|
||||
|
||||
/**
|
||||
* Returns the minimum valued leaf
|
||||
* @return The minimum leaf or NULL
|
||||
|
||||
@@ -1073,3 +1073,57 @@ void *art_substring(const art_tree *t, const unsigned char *str, int str_len, un
|
||||
|
||||
return found->value;
|
||||
}
|
||||
|
||||
void art_substring_walk(const art_tree *t, const unsigned char *str, int str_len, walk_func func, void *arg)
|
||||
{
|
||||
art_node **child;
|
||||
art_node *n = t->root;
|
||||
art_node *m;
|
||||
art_leaf *found = NULL;
|
||||
int prefix_len, depth = 0;
|
||||
int stop_search = 0;
|
||||
|
||||
while (n && stop_search == 0) {
|
||||
// Might be a leaf
|
||||
if (IS_LEAF(n)) {
|
||||
n = (art_node*)LEAF_RAW(n);
|
||||
// Check if the expanded path matches
|
||||
if (!str_prefix_matches((art_leaf*)n, str, str_len)) {
|
||||
found = (art_leaf*)n;
|
||||
stop_search = func(found->key, found->key_len, found->value, arg);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
// Check if current is leaf
|
||||
child = find_child(n, 0);
|
||||
m = (child) ? *child : NULL;
|
||||
if (m && IS_LEAF(m)) {
|
||||
m = (art_node*)LEAF_RAW(m);
|
||||
// Check if the expanded path matches
|
||||
if (!str_prefix_matches((art_leaf*)m, str, str_len)) {
|
||||
found = (art_leaf*)m;
|
||||
stop_search = func(found->key, found->key_len, found->value, arg);
|
||||
}
|
||||
}
|
||||
|
||||
// Bail if the prefix does not match
|
||||
if (n->partial_len) {
|
||||
prefix_len = check_prefix(n, str, str_len, depth);
|
||||
if (prefix_len != min(MAX_PREFIX_LEN, n->partial_len))
|
||||
break;
|
||||
depth = depth + n->partial_len;
|
||||
}
|
||||
|
||||
// Recursively search
|
||||
child = find_child(n, str[depth]);
|
||||
n = (child) ? *child : NULL;
|
||||
depth++;
|
||||
}
|
||||
|
||||
if (found == NULL) {
|
||||
return ;
|
||||
}
|
||||
|
||||
return ;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user