From 22e13b40db1345d7dc85b4d47ab6b7bf0f8ab11c Mon Sep 17 00:00:00 2001 From: Nick Peng Date: Tue, 26 Apr 2022 20:43:11 +0800 Subject: [PATCH] dns_server: fix passthrouth ipset issue --- src/dns.c | 3 ++ src/dns_server.c | 94 +++++++++++++++++++++++++++--------------------- 2 files changed, 57 insertions(+), 40 deletions(-) diff --git a/src/dns.c b/src/dns.c index 959271e..4fe2c59 100644 --- a/src/dns.c +++ b/src/dns.c @@ -2087,6 +2087,9 @@ static int _dns_update_an(struct dns_context *context, dns_rr_type type, struct break; default: { unsigned char *ttl_ptr = start - sizeof(int) - sizeof(short); + if (param->ip_ttl < 0) { + break; + } _dns_write_int(&ttl_ptr, param->ip_ttl); } break; } diff --git a/src/dns_server.c b/src/dns_server.c index 3d6ed27..805abdd 100644 --- a/src/dns_server.c +++ b/src/dns_server.c @@ -102,6 +102,8 @@ struct dns_server_post_context { int qtype; int do_cache; int do_reply; + int do_ipset; + int reply_ttl; int do_audit; int do_force_soa; int skip_notify_count; @@ -367,7 +369,7 @@ static void _dns_server_post_context_init_from(struct dns_server_post_context *c struct dns_packet *packet, unsigned char *inpacket, int inpacket_len) { memset(context, 0, sizeof(*context)); - context->packet = (struct dns_packet *)(context->packet_buff); + context->packet = packet; context->packet_maxlen = sizeof(context->packet_buff); context->inpacket = inpacket; context->inpacket_len = inpacket_len; @@ -958,6 +960,10 @@ static int _dns_server_setup_ipset_packet(struct dns_server_post_context *contex return 0; } + if (context->do_ipset == 0) { + return 0; + } + if (context->ip_num <= 0) { return 0; } @@ -1178,6 +1184,7 @@ static void _dns_server_dualstack_selection_cache_A(struct dns_request *request) _dns_server_post_context_init(&context, request); context.qtype = DNS_T_A; context.do_cache = 1; + context.do_ipset = 1; context.skip_notify_count = 1; _dns_request_post(&context); } @@ -1293,6 +1300,7 @@ out: struct dns_server_post_context context; _dns_server_post_context_init(&context, request); context.do_cache = 1; + context.do_ipset = 1; context.do_force_soa = force_A; context.do_audit = 1; context.do_reply = 1; @@ -1472,6 +1480,7 @@ static void _dns_server_complete_with_multi_ipaddress(struct dns_request *reques _dns_server_post_context_init(&context, request); context.do_cache = 1; + context.do_ipset = 1; context.do_reply = do_reply; context.select_all_best_ip = 1; context.skip_notify_count = 1; @@ -1481,6 +1490,7 @@ static void _dns_server_complete_with_multi_ipaddress(struct dns_request *reques _dns_server_post_context_init(&context, request); context.qtype = DNS_T_A; context.do_cache = 1; + context.do_ipset = 1; context.select_all_best_ip = 1; context.skip_notify_count = 1; _dns_request_post(&context); @@ -1527,7 +1537,8 @@ static void _dns_server_request_release_complete(struct dns_request *request, in static void _dns_server_request_release(struct dns_request *request) { - _dns_server_request_release_complete(request, 1); + + _dns_server_request_release_complete(request, request->passthrough == 0); } static void _dns_server_request_get(struct dns_request *request) @@ -2173,7 +2184,7 @@ static int _dns_server_passthrough_rule_check(struct dns_request *request, char return -1; } -static int _dns_server_get_answer(struct dns_request *request, struct dns_packet *packet) +static int _dns_server_get_answer(struct dns_server_post_context *context) { int i = 0; int j = 0; @@ -2181,6 +2192,8 @@ static int _dns_server_get_answer(struct dns_request *request, struct dns_packet struct dns_rrs *rrs = NULL; int rr_count = 0; char name[DNS_MAX_CNAME_LEN] = {0}; + struct dns_request *request = context->request; + struct dns_packet *packet = context->packet; for (j = 1; j < DNS_RRS_END; j++) { rrs = dns_get_rrs_start(packet, j, &rr_count); @@ -2200,6 +2213,7 @@ static int _dns_server_get_answer(struct dns_request *request, struct dns_packet request->ttl_v4 = _dns_server_get_conf_ttl(ttl); request->has_ipv4 = 1; request->rcode = packet->head.rcode; + context->ip_num++; } break; case DNS_T_AAAA: { unsigned char addr[16]; @@ -2214,6 +2228,7 @@ static int _dns_server_get_answer(struct dns_request *request, struct dns_packet request->ttl_v6 = _dns_server_get_conf_ttl(ttl); request->has_ipv6 = 1; request->rcode = packet->head.rcode; + context->ip_num++; } break; case DNS_T_NS: { char cname[DNS_MAX_CNAME_LEN]; @@ -2260,18 +2275,17 @@ static int _dns_server_reply_passthrouth(struct dns_server_post_context *context return 0; } - _dns_server_get_answer(request, context->packet); + _dns_server_get_answer(context); if (request->result_callback) { _dns_result_callback(request); } - if (request->conn) { - /* When passthrough, modify the id to be the id of the client request. */ - dns_server_update_reply_packet_id(request, context->inpacket, context->inpacket_len); - ret = _dns_reply_inpacket(request, context->inpacket, context->inpacket_len); - } - if (context->packet->head.rcode != DNS_RC_NOERROR && context->packet->head.rcode != DNS_RC_NXDOMAIN) { + if (request->conn && context->do_reply == 1) { + /* When passthrough, modify the id to be the id of the client request. */ + dns_server_update_reply_packet_id(request, context->inpacket, context->inpacket_len); + ret = _dns_reply_inpacket(request, context->inpacket, context->inpacket_len); + } return ret; } @@ -2283,7 +2297,20 @@ static int _dns_server_reply_passthrouth(struct dns_server_post_context *context _dns_server_audit_log(context); - return 0; + if (request->conn && context->do_reply == 1) { + /* When passthrough, modify the id to be the id of the client request. */ + dns_server_update_reply_packet_id(request, context->inpacket, context->inpacket_len); + struct dns_update_param param; + param.id = request->id; + param.ip_ttl = context->reply_ttl; + if (dns_packet_update(context->inpacket, context->inpacket_len, ¶m) != 0) { + tlog(TLOG_ERROR, "update cache info failed."); + return -1; + } + ret = _dns_reply_inpacket(request, context->inpacket, context->inpacket_len); + } + + return _dns_server_reply_all_pending_list(request, context); } static int dns_server_resolve_callback(char *domain, dns_result_type rtype, unsigned int result_flag, @@ -2308,6 +2335,11 @@ static int dns_server_resolve_callback(char *domain, dns_result_type rtype, unsi } _dns_server_post_context_init_from(&context, request, packet, inpacket, inpacket_len); + context.do_cache = 1; + context.do_audit = 1; + context.do_reply = 1; + context.do_ipset = 1; + context.reply_ttl = -1; return _dns_server_reply_passthrouth(&context); } _dns_server_process_answer(request, domain, packet, result_flag); @@ -2767,47 +2799,29 @@ static int _dns_server_process_cache_packet(struct dns_request *request, struct struct dns_cache_packet *cache_packet = (struct dns_cache_packet *)dns_cache_get_data(dns_cache); if (cache_packet->head.cache_type != CACHE_TYPE_PACKET) { - goto errout; + return -1; } if (dns_cache->info.qtype != request->qtype) { - goto errout; - } - - if (atomic_inc_return(&request->notified) != 1) { - return 0; + return -1; } struct dns_server_post_context context; _dns_server_post_context_init(&context, request); + context.inpacket = cache_packet->data; + context.inpacket_len = cache_packet->head.size; - context.do_audit = 1; if (dns_decode(context.packet, context.packet_maxlen, cache_packet->data, cache_packet->head.size) != 0) { - goto errout; + return -1; } - _dns_server_audit_log(&context); - if (request->result_callback) { - _dns_result_callback(request); - } + context.do_cache = 0; + context.do_ipset = 0; + context.do_audit = 1; + context.do_reply = 1; + context.reply_ttl = _dns_server_get_expired_ttl_reply(dns_cache); - if (request->conn == NULL) { - return 0; - } - - /* When passthrough, modify the id to be the id of the client request. */ - struct dns_update_param param; - param.id = request->id; - param.cname_ttl = _dns_server_get_expired_ttl_reply(dns_cache); - param.ip_ttl = _dns_server_get_expired_ttl_reply(dns_cache); - if (dns_packet_update(cache_packet->data, cache_packet->head.size, ¶m) != 0) { - tlog(TLOG_ERROR, "update cache info failed."); - goto errout; - } - - return _dns_reply_inpacket(request, cache_packet->data, cache_packet->head.size); -errout: - return -1; + return _dns_server_reply_passthrouth(&context); } static int _dns_server_process_cache_data(struct dns_request *request, struct dns_cache *dns_cache)