diff --git a/src/dns_cache.c b/src/dns_cache.c index 1cbeb99..4df641f 100644 --- a/src/dns_cache.c +++ b/src/dns_cache.c @@ -77,7 +77,8 @@ static void _dns_cache_delete(struct dns_cache *dns_cache) hash_del(&dns_cache->node); list_del_init(&dns_cache->list); atomic_dec(&dns_cache_head.num); - dns_cache_data_free(dns_cache->cache_data); + dns_cache_data_put(dns_cache->cache_data); + dns_cache->cache_data = NULL; free(dns_cache); } @@ -125,15 +126,6 @@ const char *dns_cache_get_dns_group_name(struct dns_cache *dns_cache) return dns_cache->info.dns_group_name; } -void dns_cache_data_free(struct dns_cache_data *data) -{ - if (data == NULL) { - return; - } - - free(data); -} - struct dns_cache_data *dns_cache_new_data_addr(void) { struct dns_cache_addr *cache_addr = malloc(sizeof(struct dns_cache_addr)); @@ -145,6 +137,7 @@ struct dns_cache_data *dns_cache_new_data_addr(void) cache_addr->head.cache_type = CACHE_TYPE_NONE; cache_addr->head.size = sizeof(struct dns_cache_addr) - sizeof(struct dns_cache_data_head); cache_addr->head.magic = MAGIC_CACHE_DATA; + atomic_set(&cache_addr->head.ref, 1); return (struct dns_cache_data *)cache_addr; } @@ -230,6 +223,7 @@ struct dns_cache_data *dns_cache_new_data_packet(void *packet, size_t packet_len cache_packet->head.cache_type = CACHE_TYPE_PACKET; cache_packet->head.size = packet_len; cache_packet->head.magic = MAGIC_CACHE_DATA; + atomic_set(&cache_packet->head.ref, 1); return (struct dns_cache_data *)cache_packet; } @@ -292,9 +286,6 @@ static int _dns_cache_replace(struct dns_cache_key *cache_key, int ttl, int spee if (cache_data) { old_cache_data = dns_cache->cache_data; dns_cache->cache_data = cache_data; - if (old_cache_data == cache_data) { - old_cache_data = NULL; - } } if (update_time) { time(&dns_cache->info.insert_time); @@ -306,7 +297,7 @@ static int _dns_cache_replace(struct dns_cache_key *cache_key, int ttl, int spee pthread_mutex_unlock(&dns_cache_head.lock); if (old_cache_data) { - dns_cache_data_free(old_cache_data); + dns_cache_data_put(old_cache_data); } dns_cache_release(dns_cache); return 0; @@ -398,9 +389,10 @@ static int _dns_cache_insert(struct dns_cache_info *info, struct dns_cache_data _dns_cache_remove(del_cache); } } - pthread_mutex_unlock(&dns_cache_head.lock); + dns_cache_get(dns_cache); dns_timer_add(&dns_cache->timer); + pthread_mutex_unlock(&dns_cache_head.lock); return 0; errout: @@ -421,7 +413,7 @@ int dns_cache_insert(struct dns_cache_key *cache_key, int ttl, int speed, int ti } if (dns_cache_head.size <= 0) { - dns_cache_data_free(cache_data); + dns_cache_data_put(cache_data); return 0; } @@ -517,14 +509,20 @@ int dns_cache_get_cname_ttl(struct dns_cache *dns_cache) time(&now); struct dns_cache_addr *cache_addr = (struct dns_cache_addr *)dns_cache_get_data(dns_cache); + if (cache_addr == NULL) { + ttl = 0; + goto out; + } if (cache_addr->head.cache_type != CACHE_TYPE_ADDR) { - return 0; + ttl = 0; + goto out; } ttl = dns_cache->info.insert_time + cache_addr->addr_data.cname_ttl - now; if (ttl < 0) { - return 0; + ttl = 0; + goto out; } int addr_ttl = dns_cache_get_ttl(dns_cache); @@ -533,7 +531,13 @@ int dns_cache_get_cname_ttl(struct dns_cache *dns_cache) } if (ttl < 0) { - return 0; + ttl = 0; + goto out; + } + +out: + if (cache_addr) { + dns_cache_data_put((struct dns_cache_data *)cache_addr); } return ttl; @@ -554,7 +558,35 @@ int dns_cache_is_soa(struct dns_cache *dns_cache) struct dns_cache_data *dns_cache_get_data(struct dns_cache *dns_cache) { - return dns_cache->cache_data; + struct dns_cache_data *cache_data; + pthread_mutex_lock(&dns_cache_head.lock); + dns_cache_data_get(dns_cache->cache_data); + cache_data = dns_cache->cache_data; + pthread_mutex_unlock(&dns_cache_head.lock); + return cache_data; +} + +void dns_cache_data_get(struct dns_cache_data *cache_data) +{ + if (atomic_inc_return(&cache_data->head.ref) == 1) { + tlog(TLOG_ERROR, "BUG: dns_cache data is invalid."); + return; + } + + return; +} + +void dns_cache_data_put(struct dns_cache_data *cache_data) +{ + if (cache_data == NULL) { + return; + } + + if (!atomic_dec_and_test(&cache_data->head.ref)) { + return; + } + + free(cache_data); } int dns_cache_is_visited(struct dns_cache *dns_cache) @@ -625,10 +657,11 @@ static int _dns_cache_read_to_cache(struct dns_cache_record *cache_record, struc goto errout; } + dns_cache_data_get(cache_data); + daemon_keepalive(); - /* keep cache_data */ - return -2; + return 0; errout: return -1; } @@ -676,6 +709,7 @@ static int _dns_cache_read_record(int fd, uint32_t cache_number, dns_cache_read_ } memcpy(&cache_data->head, &data_head, sizeof(data_head)); + atomic_set(&cache_data->head.ref, 1); ret = read(fd, cache_data->data, data_head.size); if (ret != data_head.size) { tlog(TLOG_ERROR, "read cache data failed, %s", strerror(errno)); @@ -687,20 +721,17 @@ static int _dns_cache_read_record(int fd, uint32_t cache_number, dns_cache_read_ cache_record.info.domain[DNS_MAX_CNAME_LEN - 1] = '\0'; cache_record.info.dns_group_name[DNS_GROUP_NAME_LEN - 1] = '\0'; ret = callback(&cache_record, cache_data); - if (ret == -2) { - cache_data = NULL; - } else if (ret != 0) { + dns_cache_data_put(cache_data); + cache_data = NULL; + if (ret != 0) { goto errout; - } else { - free(cache_data); - cache_data = NULL; } } return 0; errout: if (cache_data) { - free(cache_data); + dns_cache_data_put(cache_data); } return -1; } @@ -763,6 +794,8 @@ static int _dns_cache_write_record(int fd, uint32_t *cache_number, struct list_h struct dns_cache *tmp = NULL; struct dns_cache_record cache_record; + memset(&cache_record, 0, sizeof(cache_record)); + pthread_mutex_lock(&dns_cache_head.lock); list_for_each_entry_safe(dns_cache, tmp, head, list) { diff --git a/src/dns_cache.h b/src/dns_cache.h index a74f782..8e0dfb8 100644 --- a/src/dns_cache.h +++ b/src/dns_cache.h @@ -48,6 +48,7 @@ enum CACHE_TYPE { struct dns_cache_data_head { enum CACHE_TYPE cache_type; + atomic_t ref; int is_soa; ssize_t size; uint32_t magic; @@ -130,8 +131,6 @@ uint32_t dns_cache_get_query_flag(struct dns_cache *dns_cache); const char *dns_cache_get_dns_group_name(struct dns_cache *dns_cache); -void dns_cache_data_free(struct dns_cache_data *data); - struct dns_cache_data *dns_cache_new_data_packet(void *packet, size_t packet_len); typedef int (*dns_cache_callback)(struct dns_cache *dns_cache); @@ -167,6 +166,10 @@ struct dns_cache_data *dns_cache_new_data_addr(void); struct dns_cache_data *dns_cache_get_data(struct dns_cache *dns_cache); +void dns_cache_data_get(struct dns_cache_data *cache_data); + +void dns_cache_data_put(struct dns_cache_data *cache_data); + void dns_cache_set_data_addr(struct dns_cache_data *dns_cache, char *cname, int cname_ttl, unsigned char *addr, int addr_len); diff --git a/src/dns_server.c b/src/dns_server.c index 0964ac9..e9483fd 100644 --- a/src/dns_server.c +++ b/src/dns_server.c @@ -1355,7 +1355,7 @@ static int _dns_server_request_update_cache(struct dns_request *request, dns_typ return 0; errout: if (cache_data) { - dns_cache_data_free(cache_data); + dns_cache_data_put(cache_data); } return -1; } @@ -1499,7 +1499,7 @@ static int _dns_cache_cname_packet(struct dns_server_post_context *context) return 0; errout: if (cache_packet) { - dns_cache_data_free(cache_packet); + dns_cache_data_put((struct dns_cache_data *)cache_packet); } return -1; @@ -1539,7 +1539,7 @@ static int _dns_cache_packet(struct dns_server_post_context *context) return 0; errout: if (cache_packet) { - dns_cache_data_free(cache_packet); + dns_cache_data_put((struct dns_cache_data *)cache_packet); } return -1; @@ -4889,11 +4889,16 @@ errout: static int _dns_server_process_cache_packet(struct dns_request *request, struct dns_cache *dns_cache) { + int ret = -1; struct dns_cache_packet *cache_packet = (struct dns_cache_packet *)dns_cache_get_data(dns_cache); + if (cache_packet == NULL) { + goto out; + } + int do_ipset = (dns_cache_get_ttl(dns_cache) == 0); if (cache_packet->head.cache_type != CACHE_TYPE_PACKET) { - return -1; + goto out; } if (dns_cache_is_visited(dns_cache) == 0) { @@ -4901,7 +4906,7 @@ static int _dns_server_process_cache_packet(struct dns_request *request, struct } if (dns_cache->info.qtype != request->qtype) { - return -1; + goto out; } struct dns_server_post_context context; @@ -4912,7 +4917,7 @@ static int _dns_server_process_cache_packet(struct dns_request *request, struct if (dns_decode(context.packet, context.packet_maxlen, cache_packet->data, cache_packet->head.size) != 0) { tlog(TLOG_ERROR, "decode cache failed, %d, %d", context.packet_maxlen, context.inpacket_len); - return -1; + goto out; } request->rcode = context.packet->head.rcode; @@ -4921,7 +4926,13 @@ static int _dns_server_process_cache_packet(struct dns_request *request, struct context.do_audit = 1; context.do_reply = 1; context.reply_ttl = _dns_server_get_expired_ttl_reply(dns_cache); - return _dns_server_reply_passthrough(&context); + ret = _dns_server_reply_passthrough(&context); +out: + if (cache_packet) { + dns_cache_data_put((struct dns_cache_data *)cache_packet); + } + + return ret; } static int _dns_server_process_cache_data(struct dns_request *request, struct dns_cache *dns_cache)