diff --git a/src/dns.c b/src/dns.c index da37327..03b88d1 100644 --- a/src/dns.c +++ b/src/dns.c @@ -887,11 +887,12 @@ int dns_add_OPT_ECS(struct dns_packet *packet, struct dns_opt_ecs *ecs) int dns_get_OPT_ECS(struct dns_rrs *rrs, unsigned short *opt_code, unsigned short *opt_len, struct dns_opt_ecs *ecs) { unsigned char opt_data[DNS_MAX_OPT_LEN]; + char domain[DNS_MAX_CNAME_LEN] = {0}; struct dns_opt *opt = (struct dns_opt *)opt_data; int len = DNS_MAX_OPT_LEN; int ttl = 0; - if (_dns_get_RAW(rrs, NULL, 0, &ttl, opt_data, &len) != 0) { + if (_dns_get_RAW(rrs, domain, DNS_MAX_CNAME_LEN, &ttl, opt_data, &len) != 0) { return -1; } diff --git a/src/dns_client.c b/src/dns_client.c index 5314020..3a873c6 100644 --- a/src/dns_client.c +++ b/src/dns_client.c @@ -66,13 +66,7 @@ /* ECS info */ struct dns_client_ecs { int enable; - unsigned int family; - unsigned int bitlen; - union { - unsigned char ipv4_addr[DNS_RR_A_LEN]; - unsigned char ipv6_addr[DNS_RR_AAAA_LEN]; - unsigned char addr[0]; - }; + struct dns_opt_ecs ecs; }; /* TCP/TLS buffer */ @@ -244,6 +238,9 @@ struct dns_query_struct { /* has result */ int has_result; + /* ECS */ + struct dns_client_ecs ecs; + /* replied hash table */ DECLARE_HASHTABLE(replied_map, 4); }; @@ -2873,42 +2870,13 @@ static int _dns_client_send_packet(struct dns_query_struct *query, void *packet, return 0; } -static int _dns_client_dns_add_ecs(struct dns_packet *packet, int qtype) +static int _dns_client_dns_add_ecs(struct dns_query_struct *query, struct dns_packet *packet) { - int add_ipv4_ecs = 0; - int add_ipv6_ecs = 0; - - if (qtype == DNS_T_A && client.ecs_ipv4.enable) { - add_ipv4_ecs = 1; - } else if (qtype == DNS_T_AAAA && client.ecs_ipv6.enable) { - add_ipv6_ecs = 1; - } else { - if (client.ecs_ipv4.enable) { - add_ipv4_ecs = 1; - } else if (client.ecs_ipv6.enable) { - add_ipv4_ecs = 1; - } + if (query->ecs.enable == 0) { + return 0; } - if (add_ipv4_ecs) { - struct dns_opt_ecs ecs; - ecs.family = DNS_ADDR_FAMILY_IP; - ecs.source_prefix = client.ecs_ipv4.bitlen; - ecs.scope_prefix = 0; - memcpy(ecs.addr, client.ecs_ipv4.ipv4_addr, DNS_RR_A_LEN); - return dns_add_OPT_ECS(packet, &ecs); - } - - if (add_ipv6_ecs) { - struct dns_opt_ecs ecs; - ecs.family = DNS_ADDR_FAMILY_IPV6; - ecs.source_prefix = client.ecs_ipv6.bitlen; - ecs.scope_prefix = 0; - memcpy(ecs.addr, client.ecs_ipv6.ipv6_addr, DNS_RR_AAAA_LEN); - return dns_add_OPT_ECS(packet, &ecs); - } - - return 0; + return dns_add_OPT_ECS(packet, &query->ecs.ecs); } static int _dns_client_send_query(struct dns_query_struct *query, char *doamin) @@ -2942,7 +2910,7 @@ static int _dns_client_send_query(struct dns_query_struct *query, char *doamin) dns_set_OPT_payload_size(packet, DNS_IN_PACKSIZE); /* dns_add_OPT_TCP_KEEYALIVE(packet, 600); */ - if (_dns_client_dns_add_ecs(packet, query->qtype) != 0) { + if (_dns_client_dns_add_ecs(query, packet) != 0) { tlog(TLOG_ERROR, "add ecs failed."); return -1; } @@ -2964,7 +2932,102 @@ static int _dns_client_send_query(struct dns_query_struct *query, char *doamin) return _dns_client_send_packet(query, inpacket, encode_len); } -int dns_client_query(char *domain, int qtype, dns_client_callback callback, void *user_ptr, const char *group_name) +int _dns_client_query_setup_default_ecs(struct dns_query_struct *query) +{ + int add_ipv4_ecs = 0; + int add_ipv6_ecs = 0; + + if (query->qtype == DNS_T_A && client.ecs_ipv4.enable) { + add_ipv4_ecs = 1; + } else if (query->qtype == DNS_T_AAAA && client.ecs_ipv6.enable) { + add_ipv6_ecs = 1; + } else { + if (client.ecs_ipv4.enable) { + add_ipv4_ecs = 1; + } else if (client.ecs_ipv6.enable) { + add_ipv4_ecs = 1; + } + } + + if (add_ipv4_ecs) { + memcpy(&query->ecs, &client.ecs_ipv4, sizeof(query->ecs)); + return 0; + } + + if (add_ipv6_ecs) { + memcpy(&query->ecs, &client.ecs_ipv6, sizeof(query->ecs)); + return 0; + } + + return 0; +} + +int _dns_client_query_parser_options(struct dns_query_struct *query, struct dns_query_options *options) +{ + if (options == NULL) { + _dns_client_query_setup_default_ecs(query); + return 0; + } + + if (options->enable_flag & DNS_QUEY_OPTION_ECS_IP) { + struct sockaddr_storage addr; + socklen_t addr_len = sizeof(addr); + struct dns_opt_ecs *ecs; + + ecs = &query->ecs.ecs; + getaddr_by_host(options->ecs_ip.ip, (struct sockaddr *)&addr, &addr_len); + + query->ecs.enable = 1; + ecs->source_prefix = options->ecs_ip.subnet; + ecs->scope_prefix = 0; + + switch (addr.ss_family) { + case AF_INET: { + struct sockaddr_in *addr_in; + addr_in = (struct sockaddr_in *)&addr; + ecs->family = DNS_OPT_ECS_FAMILY_IPV4; + memcpy(&ecs->addr, &addr_in->sin_addr.s_addr, 4); + } break; + case AF_INET6: { + struct sockaddr_in6 *addr_in6; + addr_in6 = (struct sockaddr_in6 *)&addr; + if (IN6_IS_ADDR_V4MAPPED(&addr_in6->sin6_addr)) { + memcpy(&ecs->addr, addr_in6->sin6_addr.s6_addr + 12, 4); + ecs->family = DNS_OPT_ECS_FAMILY_IPV4; + } else { + memcpy(&ecs->addr, addr_in6->sin6_addr.s6_addr, 16); + ecs->family = DNS_OPT_ECS_FAMILY_IPV6; + } + } break; + default: + tlog(TLOG_WARN, "ECS set failure."); + break; + } + } + + if (options->enable_flag & DNS_QUEY_OPTION_ECS_DNS) { + struct dns_opt_ecs *ecs = &options->ecs_dns; + if (ecs->family != DNS_OPT_ECS_FAMILY_IPV6 && ecs->family != DNS_OPT_ECS_FAMILY_IPV4) { + return -1; + } + + if (ecs->family == DNS_OPT_ECS_FAMILY_IPV4 && ecs->source_prefix > 32) { + return -1; + } + + if (ecs->family == DNS_OPT_ECS_FAMILY_IPV6 && ecs->source_prefix > 128) { + return -1; + } + + memcpy(&query->ecs.ecs, ecs, sizeof(query->ecs.ecs)); + query->ecs.enable = 1; + } + + return 0; +} + +int dns_client_query(char *domain, int qtype, dns_client_callback callback, void *user_ptr, const char *group_name, + struct dns_query_options *options) { struct dns_query_struct *query = NULL; int ret = 0; @@ -2999,6 +3062,11 @@ int dns_client_query(char *domain, int qtype, dns_client_callback callback, void goto errout; } + if (_dns_client_query_parser_options(query, options) != 0) { + tlog(TLOG_ERROR, "parser options for %s failed.", domain); + goto errout; + } + _dns_client_query_get(query); /* add query to hashtable */ key = hash_string(domain); @@ -3299,20 +3367,25 @@ int dns_client_set_ecs(char *ip, int subnet) case AF_INET: { struct sockaddr_in *addr_in; addr_in = (struct sockaddr_in *)&addr; - memcpy(&client.ecs_ipv4.ipv4_addr, &addr_in->sin_addr.s_addr, 4); - client.ecs_ipv4.bitlen = subnet; + memcpy(&client.ecs_ipv4.ecs.addr, &addr_in->sin_addr.s_addr, 4); + client.ecs_ipv4.ecs.source_prefix = subnet; + client.ecs_ipv4.ecs.scope_prefix = 0; + client.ecs_ipv4.ecs.family = DNS_OPT_ECS_FAMILY_IPV4; client.ecs_ipv4.enable = 1; } break; case AF_INET6: { struct sockaddr_in6 *addr_in6; addr_in6 = (struct sockaddr_in6 *)&addr; if (IN6_IS_ADDR_V4MAPPED(&addr_in6->sin6_addr)) { - memcpy(&client.ecs_ipv4.ipv4_addr, addr_in6->sin6_addr.s6_addr + 12, 4); - client.ecs_ipv4.bitlen = subnet; + client.ecs_ipv4.ecs.source_prefix = subnet; + client.ecs_ipv4.ecs.scope_prefix = 0; + client.ecs_ipv4.ecs.family = DNS_OPT_ECS_FAMILY_IPV4; client.ecs_ipv4.enable = 1; } else { - memcpy(&client.ecs_ipv6.ipv6_addr, addr_in6->sin6_addr.s6_addr, 16); - client.ecs_ipv6.bitlen = subnet; + memcpy(&client.ecs_ipv6.ecs.addr, addr_in6->sin6_addr.s6_addr, 16); + client.ecs_ipv6.ecs.source_prefix = subnet; + client.ecs_ipv6.ecs.scope_prefix = 0; + client.ecs_ipv6.ecs.family = DNS_ADDR_FAMILY_IPV6; client.ecs_ipv6.enable = 1; } } break; diff --git a/src/dns_client.h b/src/dns_client.h index c04a5d8..a41b94c 100644 --- a/src/dns_client.h +++ b/src/dns_client.h @@ -47,6 +47,9 @@ typedef enum dns_result_type { #define DNSSERVER_FLAG_CHECK_EDNS (0x1 << 2) #define DNSSERVER_FLAG_CHECK_TTL (0x1 << 3) +#define DNS_QUEY_OPTION_ECS_DNS (1 << 0) +#define DNS_QUEY_OPTION_ECS_IP (1 << 1) + int dns_client_init(void); int dns_client_set_ecs(char *ip, int subnet); @@ -56,8 +59,20 @@ typedef int (*dns_client_callback)(char *domain, dns_result_type rtype, unsigned struct dns_packet *packet, unsigned char *inpacket, int inpacket_len, void *user_ptr); +struct dns_query_ecs_ip { + char ip[DNS_MAX_CNAME_LEN]; + int subnet; +}; + +struct dns_query_options { + unsigned long long enable_flag; + struct dns_opt_ecs ecs_dns; + struct dns_query_ecs_ip ecs_ip; +}; + /* query domain */ -int dns_client_query(char *domain, int qtype, dns_client_callback callback, void *user_ptr, const char *group_name); +int dns_client_query(char *domain, int qtype, dns_client_callback callback, void *user_ptr, const char *group_name, + struct dns_query_options *options); void dns_client_exit(void); diff --git a/src/dns_server.c b/src/dns_server.c index bd44b4f..64e4e1e 100644 --- a/src/dns_server.c +++ b/src/dns_server.c @@ -184,6 +184,8 @@ struct dns_request { struct sockaddr addr; }; struct sockaddr_storage localaddr; + int has_ecs; + struct dns_opt_ecs ecs; dns_result_callback result_callback; void *user_ptr; @@ -252,7 +254,7 @@ static tlog_log *dns_audit; static int is_ipv6_ready; -static int _dns_server_prefetch_request(char *domain, dns_type_t qtype, uint32_t server_flags); +static int _dns_server_prefetch_request(char *domain, dns_type_t qtype, uint32_t server_flags, struct dns_query_options *options); static int _dns_server_forward_request(unsigned char *inpacket, int inpacket_len) { @@ -685,6 +687,10 @@ static int _dns_add_rrs(struct dns_server_post_context *context) _dns_server_setup_soa(request); ret |= dns_add_SOA(context->packet, DNS_RRS_NS, domain, 0, &request->soa); } + + if (request->has_ecs) { + ret |= dns_add_OPT_ECS(context->packet, &request->ecs); + } return ret; } @@ -3346,10 +3352,17 @@ static int _dns_server_process_cache(struct dns_request *request) out_update_cache: if (dns_cache_get_ttl(dns_cache) == 0) { uint32_t server_flags = request->server_flags; + struct dns_query_options options; if (request->conn == NULL) { server_flags = dns_cache_get_cache_flag(dns_cache->cache_data); } - _dns_server_prefetch_request(request->domain, request->qtype, server_flags); + + options.enable_flag = 0; + if (request->has_ecs) { + options.enable_flag |= DNS_QUEY_OPTION_ECS_DNS; + memcpy(&options.ecs_dns, &request->ecs, sizeof(options.ecs_dns)); + } + _dns_server_prefetch_request(request->domain, request->qtype, server_flags, &options); } else { dns_cache_update(dns_cache); } @@ -3623,19 +3636,29 @@ errout: return -1; } -static int _dns_server_do_query(struct dns_request *request, const char *domain, int qtype) +static int _dns_server_setup_query_option(struct dns_request *request, struct dns_query_options *options) +{ + options->enable_flag = 0; + + if (request->has_ecs) { + memcpy(&options->ecs_dns, &request->ecs, sizeof(options->ecs_dns)); + options->enable_flag |= DNS_QUEY_OPTION_ECS_DNS; + } + + return 0; +} + +static int _dns_server_do_query(struct dns_request *request) { int ret = -1; const char *group_name = NULL; const char *dns_group = NULL; + struct dns_query_options options; if (request->conn) { dns_group = request->conn->dns_group; } - safe_strncpy(request->domain, domain, sizeof(request->domain)); - request->qtype = qtype; - /* lookup domain rule */ _dns_server_get_domain_rule(request); group_name = _dns_server_get_request_groupname(request); @@ -3686,6 +3709,9 @@ static int _dns_server_do_query(struct dns_request *request, const char *domain, goto clean_exit; } + // setup options + _dns_server_setup_query_option(request, &options); + // Get reference for server thread _dns_server_request_get(request); pthread_mutex_lock(&server.request_list_lock); @@ -3694,11 +3720,11 @@ static int _dns_server_do_query(struct dns_request *request, const char *domain, request->send_tick = get_tick_count(); /* When the dual stack ip preference is enabled, both A and AAAA records are requested. */ - if (qtype == DNS_T_AAAA && request->dualstack_selection) { + if (request->qtype == DNS_T_AAAA && request->dualstack_selection) { // Get reference for AAAA query _dns_server_request_get(request); request->request_wait++; - if (dns_client_query(request->domain, DNS_T_A, dns_server_resolve_callback, request, group_name) != 0) { + if (dns_client_query(request->domain, DNS_T_A, dns_server_resolve_callback, request, group_name, &options) != 0) { request->request_wait--; _dns_server_request_release(request); } @@ -3707,7 +3733,7 @@ static int _dns_server_do_query(struct dns_request *request, const char *domain, // Get reference for DNS query request->request_wait++; _dns_server_request_get(request); - if (dns_client_query(request->domain, qtype, dns_server_resolve_callback, request, group_name) != 0) { + if (dns_client_query(request->domain, request->qtype, dns_server_resolve_callback, request, group_name, &options) != 0) { request->request_wait--; _dns_server_request_release(request); tlog(TLOG_ERROR, "send dns request failed."); @@ -3722,6 +3748,60 @@ errout: return ret; } +static int _dns_server_parser_request(struct dns_request *request, struct dns_packet *packet) +{ + struct dns_rrs *rrs; + int rr_count = 0; + int i = 0; + int ret = 0; + int qclass; + int qtype = DNS_T_ALL; + char domain[DNS_MAX_CNAME_LEN]; + + if (packet->head.qr != DNS_QR_QUERY) { + goto errout; + } + + /* get request domain and request qtype */ + rrs = dns_get_rrs_start(packet, DNS_RRS_QD, &rr_count); + if (rr_count > 1 || rr_count <= 0) { + goto errout; + } + + for (i = 0; i < rr_count && rrs; i++, rrs = dns_get_rrs_next(packet, rrs)) { + ret = dns_get_domain(rrs, domain, sizeof(domain), &qtype, &qclass); + if (ret != 0) { + goto errout; + } + + // Only support one question. + safe_strncpy(request->domain, domain, sizeof(request->domain)); + request->qtype = qtype; + break; + } + + + /* get request opts */ + rr_count = 0; + rrs = dns_get_rrs_start(packet, DNS_RRS_OPT, &rr_count); + if (rr_count <= 0) { + return 0; + } + + for (i = 0; i < rr_count && rrs; i++, rrs = dns_get_rrs_next(packet, rrs)) { + ret = dns_get_OPT_ECS(rrs, NULL, NULL, &request->ecs); + if (ret != 0) { + continue; + } + request->has_ecs = 1; + break; + } + + return 0; +errout: + return -1; +} + static int _dns_server_recv(struct dns_server_conn_head *conn, unsigned char *inpacket, int inpacket_len, struct sockaddr_storage *local, socklen_t local_len, struct sockaddr_storage *from, socklen_t from_len) @@ -3730,14 +3810,9 @@ static int _dns_server_recv(struct dns_server_conn_head *conn, unsigned char *in int ret = -1; unsigned char packet_buff[DNS_PACKSIZE]; char name[DNS_MAX_CNAME_LEN]; - char domain[DNS_MAX_CNAME_LEN]; struct dns_packet *packet = (struct dns_packet *)packet_buff; struct dns_request *request = NULL; - struct dns_rrs *rrs; - int rr_count = 0; - int i = 0; - int qclass; - int qtype = DNS_T_ALL; + /* decode packet */ tlog(TLOG_DEBUG, "recv query packet from %s, len = %d", @@ -3754,40 +3829,25 @@ static int _dns_server_recv(struct dns_server_conn_head *conn, unsigned char *in packet->head.qdcount, packet->head.ancount, packet->head.nscount, packet->head.nrcount, inpacket_len, packet->head.id, packet->head.tc, packet->head.rd, packet->head.ra, packet->head.rcode); - if (packet->head.qr != DNS_QR_QUERY) { - goto errout; - } - - /* get request domain and request qtype */ - rrs = dns_get_rrs_start(packet, DNS_RRS_QD, &rr_count); - if (rr_count > 1) { - goto errout; - } - - for (i = 0; i < rr_count && rrs; i++, rrs = dns_get_rrs_next(packet, rrs)) { - ret = dns_get_domain(rrs, domain, sizeof(domain), &qtype, &qclass); - if (ret != 0) { - goto errout; - } - - // Only support one question. - break; - } - tlog(TLOG_INFO, "query server %s from %s, qtype = %d\n", domain, name, qtype); - request = _dns_server_new_request(); if (request == NULL) { tlog(TLOG_ERROR, "malloc failed.\n"); goto errout; } + if (_dns_server_parser_request(request, packet) != 0) { + goto errout; + } + + tlog(TLOG_INFO, "query server %s from %s, qtype = %d\n", request->domain, name, request->qtype); + memcpy(&request->localaddr, local, local_len); _dns_server_request_set_client(request, conn); _dns_server_request_set_client_addr(request, from, from_len); _dns_server_request_set_id(request, packet->head.id); - ret = _dns_server_do_query(request, domain, qtype); + ret = _dns_server_do_query(request); if (ret != 0) { - tlog(TLOG_ERROR, "do query %s failed.\n", domain); + tlog(TLOG_ERROR, "do query %s failed.\n", request->domain); goto errout; } _dns_server_request_release_complete(request, 0); @@ -3801,7 +3861,21 @@ errout: return ret; } -static int _dns_server_prefetch_request(char *domain, dns_type_t qtype, uint32_t server_flags) +static int _dns_server_prefetch_setup_options(struct dns_request *request, struct dns_query_options *options) +{ + if (options == NULL) { + return 0; + } + + if (options->enable_flag & DNS_QUEY_OPTION_ECS_DNS) { + request->has_ecs = 1; + memcpy(&request->ecs, &options->ecs_dns, sizeof(request->ecs)); + } + + return 0; +} + +static int _dns_server_prefetch_request(char *domain, dns_type_t qtype, uint32_t server_flags, struct dns_query_options *options) { int ret = -1; struct dns_request *request = NULL; @@ -3812,11 +3886,14 @@ static int _dns_server_prefetch_request(char *domain, dns_type_t qtype, uint32_t goto errout; } + safe_strncpy(request->domain, domain, sizeof(request->domain)); + request->qtype = qtype; request->server_flags = server_flags; + _dns_server_prefetch_setup_options(request, options); _dns_server_request_set_enable_prefetch(request); - ret = _dns_server_do_query(request, domain, qtype); + ret = _dns_server_do_query(request); if (ret != 0) { - tlog(TLOG_ERROR, "do query %s failed.\n", domain); + tlog(TLOG_ERROR, "do query %s failed.\n", request->domain); goto errout; } @@ -3842,8 +3919,10 @@ int dns_server_query(char *domain, int qtype, uint32_t server_flags, dns_result_ } request->server_flags = server_flags; + safe_strncpy(request->domain, domain, sizeof(request->domain)); + request->qtype = qtype; _dns_server_request_set_callback(request, callback, user_ptr); - ret = _dns_server_do_query(request, domain, qtype); + ret = _dns_server_do_query(request); if (ret != 0) { tlog(TLOG_ERROR, "do query %s failed.\n", domain); goto errout; @@ -4239,7 +4318,7 @@ static void _dns_server_prefetch_domain(struct dns_cache *dns_cache) tlog(TLOG_DEBUG, "prefetch by cache %s, qtype %d, ttl %d, hitnum %d", dns_cache->info.domain, dns_cache->info.qtype, dns_cache->info.ttl, hitnum); if (_dns_server_prefetch_request(dns_cache->info.domain, dns_cache->info.qtype, - dns_cache_get_cache_flag(dns_cache->cache_data)) != 0) { + dns_cache_get_cache_flag(dns_cache->cache_data), NULL) != 0) { tlog(TLOG_ERROR, "prefetch domain %s, qtype %d, failed.", dns_cache->info.domain, dns_cache->info.qtype); } }