diff --git a/etc/smartdns/smartdns.conf b/etc/smartdns/smartdns.conf index 8fe4cee..d31abb0 100644 --- a/etc/smartdns/smartdns.conf +++ b/etc/smartdns/smartdns.conf @@ -187,6 +187,7 @@ log-level info # -proxy [proxy-name]: use proxy to connect to server. # -bootstrap-dns: set as bootstrap dns server. # -set-mark: set mark on packets. +# -subnet [ip/subnet]: set edns client subnet. # server 8.8.8.8 -blacklist-ip -check-edns -group g1 -group g2 # server tls://dns.google:853 # server https://dns.google/dns-query diff --git a/src/dns.h b/src/dns.h index 49a7ed8..ad24a13 100644 --- a/src/dns.h +++ b/src/dns.h @@ -280,17 +280,14 @@ int dns_add_OPT_TCP_KEEPALIVE(struct dns_packet *packet, unsigned short timeout) int dns_get_OPT_TCP_KEEPALIVE(struct dns_rrs *rrs, unsigned short *opt_code, unsigned short *opt_len, unsigned short *timeout); -int dns_add_HTTPS_start(struct dns_rr_nested *svcparam_buffer, struct dns_packet *packet, - dns_rr_type type, const char *domain, int ttl, int priority, - const char *target); +int dns_add_HTTPS_start(struct dns_rr_nested *svcparam_buffer, struct dns_packet *packet, dns_rr_type type, + const char *domain, int ttl, int priority, const char *target); int dns_HTTPS_add_raw(struct dns_rr_nested *svcparam, unsigned short key, unsigned char *value, unsigned short len); int dns_HTTPS_add_port(struct dns_rr_nested *svcparam, unsigned short port); int dns_HTTPS_add_alpn(struct dns_rr_nested *svcparam, const char *alpn); int dns_HTTPS_add_no_default_alpn(struct dns_rr_nested *svcparam); -int dns_HTTPS_add_ipv4hint(struct dns_rr_nested *svcparam, unsigned char addr[][DNS_RR_A_LEN], - int addr_num); -int dns_HTTPS_add_ipv6hint(struct dns_rr_nested *svcparam, unsigned char addr[][DNS_RR_AAAA_LEN], - int addr_num); +int dns_HTTPS_add_ipv4hint(struct dns_rr_nested *svcparam, unsigned char addr[][DNS_RR_A_LEN], int addr_num); +int dns_HTTPS_add_ipv6hint(struct dns_rr_nested *svcparam, unsigned char addr[][DNS_RR_AAAA_LEN], int addr_num); int dns_HTTPS_add_ech(struct dns_rr_nested *svcparam, void *ech, int ech_len); int dns_add_HTTPS_end(struct dns_rr_nested *svcparam); diff --git a/src/dns_client.c b/src/dns_client.c index 48f84a2..68e5b92 100644 --- a/src/dns_client.c +++ b/src/dns_client.c @@ -136,6 +136,10 @@ struct dns_server_info { }; struct client_dns_server_flags flags; + + /* ECS */ + struct dns_client_ecs ecs_ipv4; + struct dns_client_ecs ecs_ipv6; }; struct dns_server_pending_group { @@ -271,6 +275,8 @@ static int dns_client_has_bootstrap_dns = 0; static int _dns_client_send_udp(struct dns_server_info *server_info, void *packet, int len); static void _dns_client_clear_wakeup_event(void); static void _dns_client_do_wakeup_event(void); +static int _dns_client_setup_ecs(char *ip, int subnet, struct dns_client_ecs *ecs_ipv4, + struct dns_client_ecs *ecs_ipv6); static ssize_t _ssl_read(struct dns_server_info *server, void *buff, int num) { @@ -988,6 +994,25 @@ errout: return NULL; } +static int _dns_client_server_add_ecs(struct dns_server_info *server_info, struct client_dns_server_flags *flags) +{ + int ret = 0; + + if (flags == NULL) { + return 0; + } + + if (flags->ipv4_ecs.enable) { + ret = _dns_client_setup_ecs(flags->ipv4_ecs.ip, flags->ipv4_ecs.subnet, &server_info->ecs_ipv4, NULL); + } + + if (flags->ipv6_ecs.enable) { + ret |= _dns_client_setup_ecs(flags->ipv6_ecs.ip, flags->ipv6_ecs.subnet, NULL, &server_info->ecs_ipv6); + } + + return ret; +} + /* add dns server information */ static int _dns_client_server_add(char *server_ip, char *server_host, int port, dns_server_type_t server_type, struct client_dns_server_flags *flags) @@ -1083,10 +1108,15 @@ static int _dns_client_server_add(char *server_ip, char *server_host, int port, pthread_mutex_init(&server_info->lock, NULL); memcpy(&server_info->flags, flags, sizeof(server_info->flags)); + if (_dns_client_server_add_ecs(server_info, flags) != 0) { + tlog(TLOG_ERROR, "add %s ecs failed.", server_ip); + goto errout; + } + /* exclude this server from default group */ if ((server_info->flags.server_flag & SERVER_FLAG_EXCLUDE_DEFAULT) == 0) { if (_dns_client_add_to_group(DNS_SERVER_GROUP_DEFAULT, server_info) != 0) { - tlog(TLOG_ERROR, "add server to default group failed."); + tlog(TLOG_ERROR, "add server %s to default group failed.", server_ip); goto errout; } } @@ -3331,6 +3361,80 @@ static int _dns_client_send_https(struct dns_server_info *server_info, void *pac return 0; } +static int _dns_client_setup_server_packet(struct dns_server_info *server_info, struct dns_query_struct *query, + void *default_packet, int default_packet_len, + unsigned char *packet_data_buffer, void **packet_data, int *packet_data_len) +{ + unsigned char packet_buff[DNS_PACKSIZE]; + struct dns_packet *packet = (struct dns_packet *)packet_buff; + struct dns_head head; + int encode_len = 0; + + *packet_data = default_packet; + *packet_data_len = default_packet_len; + + if (query->qtype != DNS_T_AAAA && query->qtype != DNS_T_A) { + /* no need to encode packet */ + return 0; + } + + if (server_info->ecs_ipv4.enable == false && query->qtype == DNS_T_A) { + /* no need to encode packet */ + return 0; + } + + if (server_info->ecs_ipv6.enable == false && query->qtype == DNS_T_AAAA) { + /* no need to encode packet */ + return 0; + } + + /* init dns packet head */ + memset(&head, 0, sizeof(head)); + head.id = query->sid; + head.qr = DNS_QR_QUERY; + head.opcode = DNS_OP_QUERY; + head.aa = 0; + head.rd = 1; + head.ra = 0; + head.rcode = 0; + + if (dns_packet_init(packet, DNS_PACKSIZE, &head) != 0) { + tlog(TLOG_ERROR, "init packet failed."); + return -1; + } + + /* add question */ + if (dns_add_domain(packet, query->domain, query->qtype, DNS_C_IN) != 0) { + tlog(TLOG_ERROR, "add domain to packet failed."); + return -1; + } + + dns_set_OPT_payload_size(packet, DNS_IN_PACKSIZE); + /* dns_add_OPT_TCP_KEEPALIVE(packet, 600); */ + if (query->qtype == DNS_T_A && server_info->ecs_ipv4.enable) { + dns_add_OPT_ECS(packet, &server_info->ecs_ipv4.ecs); + } else if (query->qtype == DNS_T_AAAA && server_info->ecs_ipv6.enable) { + dns_add_OPT_ECS(packet, &server_info->ecs_ipv6.ecs); + } + + /* encode packet */ + encode_len = dns_encode(packet_data_buffer, DNS_IN_PACKSIZE, packet); + if (encode_len <= 0) { + tlog(TLOG_ERROR, "encode query failed."); + return -1; + } + + if (encode_len > DNS_IN_PACKSIZE) { + BUG("size is invalid."); + return -1; + } + + *packet_data = packet_data_buffer; + *packet_data_len = encode_len; + + return 0; +} + static int _dns_client_send_packet(struct dns_query_struct *query, void *packet, int len) { struct dns_server_info *server_info = NULL; @@ -3341,6 +3445,9 @@ static int _dns_client_send_packet(struct dns_query_struct *query, void *packet, int i = 0; int total_server = 0; int send_count = 0; + void *packet_data = NULL; + int packet_data_len = 0; + unsigned char packet_data_buffer[DNS_IN_PACKSIZE]; query->send_tick = get_tick_count(); @@ -3357,7 +3464,7 @@ static int _dns_client_send_packet(struct dns_query_struct *query, void *packet, server_info->is_already_prohibit = 1; atomic_inc(&client.dns_server_prohibit_num); } - + time_t now = 0; time(&now); if ((now - 60 < server_info->last_send) && (now - 5 > server_info->last_recv)) { @@ -3380,28 +3487,33 @@ static int _dns_client_send_packet(struct dns_query_struct *query, void *packet, } } + if (_dns_client_setup_server_packet(server_info, query, packet, len, packet_data_buffer, &packet_data, + &packet_data_len) != 0) { + continue; + } + atomic_inc(&query->dns_request_sent); send_count++; errno = 0; switch (server_info->type) { case DNS_SERVER_UDP: /* udp query */ - ret = _dns_client_send_udp(server_info, packet, len); + ret = _dns_client_send_udp(server_info, packet_data, packet_data_len); send_err = errno; break; case DNS_SERVER_TCP: /* tcp query */ - ret = _dns_client_send_tcp(server_info, packet, len); + ret = _dns_client_send_tcp(server_info, packet_data, packet_data_len); send_err = errno; break; case DNS_SERVER_TLS: /* tls query */ - ret = _dns_client_send_tls(server_info, packet, len); + ret = _dns_client_send_tls(server_info, packet_data, packet_data_len); send_err = errno; break; case DNS_SERVER_HTTPS: /* https query */ - ret = _dns_client_send_https(server_info, packet, len); + ret = _dns_client_send_https(server_info, packet_data, packet_data_len); send_err = errno; break; default: @@ -3444,7 +3556,7 @@ static int _dns_client_send_packet(struct dns_query_struct *query, void *packet, } } - int num = atomic_dec_return(&query->dns_request_sent); + int num = atomic_dec_return(&query->dns_request_sent); if (num == 0 && send_count > 0) { _dns_client_query_remove(query); } @@ -3466,7 +3578,7 @@ static int _dns_client_dns_add_ecs(struct dns_query_struct *query, struct dns_pa return dns_add_OPT_ECS(packet, &query->ecs.ecs); } -static int _dns_client_send_query(struct dns_query_struct *query, const char *domain) +static int _dns_client_send_query(struct dns_query_struct *query) { unsigned char packet_buff[DNS_PACKSIZE]; unsigned char inpacket[DNS_IN_PACKSIZE]; @@ -3490,7 +3602,7 @@ static int _dns_client_send_query(struct dns_query_struct *query, const char *do } /* add question */ - if (dns_add_domain(packet, domain, query->qtype, DNS_C_IN) != 0) { + if (dns_add_domain(packet, query->domain, query->qtype, DNS_C_IN) != 0) { tlog(TLOG_ERROR, "add domain to packet failed."); return -1; } @@ -3709,7 +3821,7 @@ int dns_client_query(const char *domain, int qtype, dns_client_callback callback /* send query */ _dns_client_query_get(query); - ret = _dns_client_send_query(query, domain); + ret = _dns_client_send_query(query); if (ret != 0) { _dns_client_query_release(query); goto errout_del_list; @@ -4005,7 +4117,7 @@ static void _dns_client_period_run(unsigned int msec) } } else { tlog(TLOG_INFO, "retry query %s, type: %d, id: %d", query->domain, query->qtype, query->sid); - _dns_client_send_query(query, query->domain); + _dns_client_send_query(query); } _dns_client_query_release(query); } @@ -4101,36 +4213,38 @@ static void *_dns_client_work(void *arg) return NULL; } -int dns_client_set_ecs(char *ip, int subnet) +static int _dns_client_setup_ecs(char *ip, int subnet, struct dns_client_ecs *ecs_ipv4, struct dns_client_ecs *ecs_ipv6) { struct sockaddr_storage addr; socklen_t addr_len = sizeof(addr); - getaddr_by_host(ip, (struct sockaddr *)&addr, &addr_len); + if (getaddr_by_host(ip, (struct sockaddr *)&addr, &addr_len) != 0) { + return -1; + } switch (addr.ss_family) { case AF_INET: { struct sockaddr_in *addr_in = NULL; addr_in = (struct sockaddr_in *)&addr; - memcpy(&client.ecs_ipv4.ecs.addr, &addr_in->sin_addr.s_addr, 4); - client.ecs_ipv4.ecs.source_prefix = subnet; - client.ecs_ipv4.ecs.scope_prefix = 0; - client.ecs_ipv4.ecs.family = DNS_OPT_ECS_FAMILY_IPV4; - client.ecs_ipv4.enable = 1; + memcpy(&ecs_ipv4->ecs.addr, &addr_in->sin_addr.s_addr, 4); + ecs_ipv4->ecs.source_prefix = subnet; + ecs_ipv4->ecs.scope_prefix = 0; + ecs_ipv4->ecs.family = DNS_OPT_ECS_FAMILY_IPV4; + ecs_ipv4->enable = 1; } break; case AF_INET6: { struct sockaddr_in6 *addr_in6 = NULL; addr_in6 = (struct sockaddr_in6 *)&addr; if (IN6_IS_ADDR_V4MAPPED(&addr_in6->sin6_addr)) { - client.ecs_ipv4.ecs.source_prefix = subnet; - client.ecs_ipv4.ecs.scope_prefix = 0; - client.ecs_ipv4.ecs.family = DNS_OPT_ECS_FAMILY_IPV4; - client.ecs_ipv4.enable = 1; + ecs_ipv4->ecs.source_prefix = subnet; + ecs_ipv4->ecs.scope_prefix = 0; + ecs_ipv4->ecs.family = DNS_OPT_ECS_FAMILY_IPV4; + ecs_ipv4->enable = 1; } else { - memcpy(&client.ecs_ipv6.ecs.addr, addr_in6->sin6_addr.s6_addr, 16); - client.ecs_ipv6.ecs.source_prefix = subnet; - client.ecs_ipv6.ecs.scope_prefix = 0; - client.ecs_ipv6.ecs.family = DNS_ADDR_FAMILY_IPV6; - client.ecs_ipv6.enable = 1; + memcpy(&ecs_ipv6->ecs.addr, addr_in6->sin6_addr.s6_addr, 16); + ecs_ipv6->ecs.source_prefix = subnet; + ecs_ipv6->ecs.scope_prefix = 0; + ecs_ipv6->ecs.family = DNS_ADDR_FAMILY_IPV6; + ecs_ipv6->enable = 1; } } break; default: @@ -4139,6 +4253,11 @@ int dns_client_set_ecs(char *ip, int subnet) return 0; } +int dns_client_set_ecs(char *ip, int subnet) +{ + return _dns_client_setup_ecs(ip, subnet, &client.ecs_ipv4, &client.ecs_ipv6); +} + static int _dns_client_create_wakeup_event(void) { int fd_wakeup = -1; diff --git a/src/dns_client.h b/src/dns_client.h index 3586834..f3a2232 100644 --- a/src/dns_client.h +++ b/src/dns_client.h @@ -108,6 +108,12 @@ struct client_dns_server_flag_https { char skip_check_cert; }; +struct client_dns_server_flag_ecs { + int enable; + char ip[DNS_MAX_CNAME_LEN]; + int subnet; +}; + struct client_dns_server_flags { dns_server_type_t type; unsigned int server_flag; @@ -115,6 +121,9 @@ struct client_dns_server_flags { long long set_mark; int drop_packet_latency_ms; char proxyname[DNS_MAX_CNAME_LEN]; + struct client_dns_server_flag_ecs ipv4_ecs; + struct client_dns_server_flag_ecs ipv6_ecs; + union { struct client_dns_server_flag_udp udp; struct client_dns_server_flag_tls tls; diff --git a/src/dns_conf.c b/src/dns_conf.c index dbd9911..9a9bb92 100644 --- a/src/dns_conf.c +++ b/src/dns_conf.c @@ -170,6 +170,8 @@ char dns_conf_sni_proxy_ip[DNS_MAX_IPLEN]; static int _conf_domain_rule_nameserver(char *domain, const char *group_name); static int _conf_ptr_add(const char *hostname, const char *ip, int is_dynamic); +static int _conf_client_subnet(char *subnet, struct dns_edns_client_subnet *ipv4_ecs, + struct dns_edns_client_subnet *ipv6_ecs); static void *_new_dns_rule(enum domain_rule domain_rule) { @@ -499,6 +501,7 @@ static int _config_server(int argc, char *argv[], dns_server_type_t type, int de {"exclude-default-group", no_argument, NULL, 'E'}, /* exclude this from default group */ {"set-mark", required_argument, NULL, 254}, /* set mark */ {"bootstrap-dns", no_argument, NULL, 255}, /* set as bootstrap dns */ + {"subnet", required_argument, NULL, 256}, /* set subnet */ {NULL, no_argument, NULL, 0} }; /* clang-format on */ @@ -634,6 +637,10 @@ static int _config_server(int argc, char *argv[], dns_server_type_t type, int de is_bootstrap_dns = 1; break; } + case 256: { + _conf_client_subnet(optarg, &server->ipv4_ecs, &server->ipv6_ecs); + break; + } default: break; } @@ -2327,48 +2334,52 @@ static int _conf_whitelist_ip(void *data, int argc, char *argv[]) return _config_iplist_rule(argv[1], ADDRESS_RULE_WHITELIST); } -static int _conf_edns_client_subnet(void *data, int argc, char *argv[]) +static int _conf_client_subnet(char *subnet, struct dns_edns_client_subnet *ipv4_ecs, + struct dns_edns_client_subnet *ipv6_ecs) { char *slash = NULL; - char *value = NULL; - int subnet = 0; + int subnet_len = 0; struct dns_edns_client_subnet *ecs = NULL; struct sockaddr_storage addr; socklen_t addr_len = sizeof(addr); + char str_subnet[128]; - if (argc <= 1) { + if (subnet == NULL) { return -1; } - value = argv[1]; - - slash = strstr(value, "/"); + safe_strncpy(str_subnet, subnet, sizeof(str_subnet)); + slash = strstr(str_subnet, "/"); if (slash) { *slash = 0; slash++; - subnet = atoi(slash); - if (subnet < 0 || subnet > 128) { + subnet_len = atoi(slash); + if (subnet_len < 0 || subnet_len > 128) { return -1; } } - if (getaddr_by_host(value, (struct sockaddr *)&addr, &addr_len) != 0) { + if (getaddr_by_host(str_subnet, (struct sockaddr *)&addr, &addr_len) != 0) { goto errout; } switch (addr.ss_family) { case AF_INET: - ecs = &dns_conf_ipv4_ecs; + ecs = ipv4_ecs; break; case AF_INET6: - ecs = &dns_conf_ipv6_ecs; + ecs = ipv6_ecs; break; default: goto errout; } - safe_strncpy(ecs->ip, value, DNS_MAX_IPLEN); - ecs->subnet = subnet; + if (ecs == NULL) { + return 0; + } + + safe_strncpy(ecs->ip, str_subnet, DNS_MAX_IPLEN); + ecs->subnet = subnet_len; ecs->enable = 1; return 0; @@ -2377,6 +2388,16 @@ errout: return -1; } +static int _conf_edns_client_subnet(void *data, int argc, char *argv[]) +{ + + if (argc <= 1) { + return -1; + } + + return _conf_client_subnet(argv[1], &dns_conf_ipv4_ecs, &dns_conf_ipv6_ecs); +} + static int _conf_domain_rule_speed_check(char *domain, const char *mode) { struct dns_domain_check_orders *check_orders = NULL; diff --git a/src/dns_conf.h b/src/dns_conf.h index 658a1c4..eb8b69c 100644 --- a/src/dns_conf.h +++ b/src/dns_conf.h @@ -293,6 +293,12 @@ struct dns_proxy_table { }; extern struct dns_proxy_table dns_proxy_table; +struct dns_edns_client_subnet { + int enable; + char ip[DNS_MAX_IPLEN]; + int subnet; +}; + struct dns_servers { char server[DNS_MAX_IPLEN]; unsigned short port; @@ -309,6 +315,8 @@ struct dns_servers { char tls_host_verify[DNS_MAX_CNAME_LEN]; char path[DNS_MAX_URL_LEN]; char proxyname[PROXY_NAME_LEN]; + struct dns_edns_client_subnet ipv4_ecs; + struct dns_edns_client_subnet ipv6_ecs; }; struct dns_proxy_servers { @@ -346,12 +354,6 @@ struct dns_ip_address_rule { unsigned int ip_ignore : 1; }; -struct dns_edns_client_subnet { - int enable; - char ip[DNS_MAX_IPLEN]; - int subnet; -}; - struct dns_conf_address_rule { radix_tree_t *ipv4; radix_tree_t *ipv6; diff --git a/src/dns_server.c b/src/dns_server.c index 6246e64..77557cf 100644 --- a/src/dns_server.c +++ b/src/dns_server.c @@ -5152,7 +5152,7 @@ static int _dns_server_do_query(struct dns_request *request, int skip_notify_eve 0) { request->request_wait--; _dns_server_request_release(request); - tlog(TLOG_WARN, "send dns request failed."); + tlog(TLOG_DEBUG, "send dns request failed."); goto errout; } @@ -5345,7 +5345,7 @@ static int _dns_server_prefetch_request(char *domain, dns_type_t qtype, int expi _dns_server_request_set_enable_prefetch(request, expired_domain); ret = _dns_server_do_query(request, 0); if (ret != 0) { - tlog(TLOG_WARN, "do query %s failed.\n", request->domain); + tlog(TLOG_DEBUG, "prefetch do query %s failed.\n", request->domain); goto errout; } diff --git a/src/dns_server.h b/src/dns_server.h index 70cebdc..a45c866 100644 --- a/src/dns_server.h +++ b/src/dns_server.h @@ -20,8 +20,8 @@ #define _SMART_DNS_SERVER_H #include "dns.h" -#include #include "dns_client.h" +#include #ifdef __cplusplus extern "C" { diff --git a/src/http_parse.c b/src/http_parse.c index 89f5210..850f1ad 100644 --- a/src/http_parse.c +++ b/src/http_parse.c @@ -19,9 +19,9 @@ #include "http_parse.h" #include "hash.h" #include "hashtable.h" -#include "util.h" #include "jhash.h" #include "list.h" +#include "util.h" #include #include #include @@ -230,7 +230,7 @@ static int _http_head_parse_response(struct http_head *http_head, char *key, cha if (*tmp_ptr != ' ') { continue; } - + *tmp_ptr = '\0'; ret_code = field_start; ret_msg = tmp_ptr + 1; diff --git a/src/proxy.c b/src/proxy.c index 5eb1b51..84c606f 100644 --- a/src/proxy.c +++ b/src/proxy.c @@ -802,7 +802,7 @@ static int _proxy_handshake_http(struct proxy_conn *proxy_conn) if (errno == EAGAIN || errno == EWOULDBLOCK) { return PROXY_HANDSHAKE_WANT_READ; } - + if (len == 0) { tlog(TLOG_ERROR, "remote server %s closed.", proxy_conn->server_info->proxy_name); } else { diff --git a/src/smartdns.c b/src/smartdns.c index dfcd173..b52efda 100644 --- a/src/smartdns.c +++ b/src/smartdns.c @@ -263,6 +263,18 @@ static int _smartdns_prepare_server_flags(struct client_dns_server_flags *flags, flags->set_mark = server->set_mark; flags->drop_packet_latency_ms = server->drop_packet_latency_ms; safe_strncpy(flags->proxyname, server->proxyname, sizeof(flags->proxyname)); + if (server->ipv4_ecs.enable) { + flags->ipv4_ecs.enable = 1; + safe_strncpy(flags->ipv4_ecs.ip, server->ipv4_ecs.ip, sizeof(flags->ipv4_ecs.ip)); + flags->ipv4_ecs.subnet = server->ipv4_ecs.subnet; + } + + if (server->ipv6_ecs.enable) { + flags->ipv6_ecs.enable = 1; + safe_strncpy(flags->ipv6_ecs.ip, server->ipv6_ecs.ip, sizeof(flags->ipv6_ecs.ip)); + flags->ipv6_ecs.subnet = server->ipv6_ecs.subnet; + } + return 0; } diff --git a/src/util.c b/src/util.c index d022ad9..d7546e8 100644 --- a/src/util.c +++ b/src/util.c @@ -45,12 +45,12 @@ #include #include #include +#include #include #include #include #include #include -#include #define TMP_BUFF_LEN_32 32 @@ -1230,9 +1230,9 @@ void get_compiled_time(struct tm *tm) unsigned long get_system_mem_size(void) { struct sysinfo memInfo; - sysinfo (&memInfo); - long long totalMem = memInfo.totalram; - totalMem *= memInfo.mem_unit; + sysinfo(&memInfo); + long long totalMem = memInfo.totalram; + totalMem *= memInfo.mem_unit; return totalMem; } diff --git a/test/cases/test-subnet.cc b/test/cases/test-subnet.cc new file mode 100644 index 0000000..77b5abd --- /dev/null +++ b/test/cases/test-subnet.cc @@ -0,0 +1,447 @@ +/************************************************************************* + * + * Copyright (C) 2018-2023 Ruilin Peng (Nick) . + * + * smartdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * smartdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#include "client.h" +#include "dns.h" +#include "include/utils.h" +#include "server.h" +#include "util.h" +#include "gtest/gtest.h" +#include + +class SubNet : public ::testing::Test +{ + protected: + virtual void SetUp() {} + virtual void TearDown() {} +}; + +TEST_F(SubNet, pass_subnet) +{ + smartdns::MockServer server_upstream; + smartdns::Server server; + + server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) { + if (request->qtype != DNS_T_A) { + return smartdns::SERVER_REQUEST_SOA; + } + struct dns_opt_ecs ecs; + struct dns_rrs *rrs = NULL; + int rr_count = 0; + int i = 0; + int ret = 0; + int has_ecs = 0; + + rr_count = 0; + rrs = dns_get_rrs_start(request->packet, DNS_RRS_OPT, &rr_count); + if (rr_count <= 0) { + return smartdns::SERVER_REQUEST_ERROR; + } + + for (i = 0; i < rr_count && rrs; i++, rrs = dns_get_rrs_next(request->packet, rrs)) { + memset(&ecs, 0, sizeof(ecs)); + ret = dns_get_OPT_ECS(rrs, NULL, NULL, &ecs); + if (ret != 0) { + continue; + } + has_ecs = 1; + break; + } + + if (has_ecs == 0) { + return smartdns::SERVER_REQUEST_ERROR; + } + + if (ecs.family != 1) { + return smartdns::SERVER_REQUEST_ERROR; + } + + if (memcmp(ecs.addr, "\x08\x08\x08\x00", 4) != 0) { + return smartdns::SERVER_REQUEST_ERROR; + } + + if (ecs.source_prefix != 24) { + return smartdns::SERVER_REQUEST_ERROR; + } + + smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4"); + return smartdns::SERVER_REQUEST_OK; + }); + + server.Start(R"""(bind [::]:60053 +server 127.0.0.1:61053 +log-num 0 +log-console yes +dualstack-ip-selection no +log-level debug +cache-persist no)"""); + smartdns::Client client; + ASSERT_TRUE(client.Query("a.com A +subnet=8.8.8.8/24", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 1); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com"); + EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3); + EXPECT_EQ(client.GetAnswer()[0].GetType(), "A"); + EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4"); +} + +TEST_F(SubNet, conf) +{ + smartdns::MockServer server_upstream; + smartdns::Server server; + + server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) { + if (request->qtype != DNS_T_A) { + return smartdns::SERVER_REQUEST_SOA; + } + struct dns_opt_ecs ecs; + struct dns_rrs *rrs = NULL; + int rr_count = 0; + int i = 0; + int ret = 0; + int has_ecs = 0; + + rr_count = 0; + rrs = dns_get_rrs_start(request->packet, DNS_RRS_OPT, &rr_count); + if (rr_count <= 0) { + return smartdns::SERVER_REQUEST_ERROR; + } + + for (i = 0; i < rr_count && rrs; i++, rrs = dns_get_rrs_next(request->packet, rrs)) { + memset(&ecs, 0, sizeof(ecs)); + ret = dns_get_OPT_ECS(rrs, NULL, NULL, &ecs); + if (ret != 0) { + continue; + } + has_ecs = 1; + break; + } + + if (has_ecs == 0) { + return smartdns::SERVER_REQUEST_ERROR; + } + + if (ecs.family != DNS_OPT_ECS_FAMILY_IPV4) { + return smartdns::SERVER_REQUEST_ERROR; + } + + if (memcmp(ecs.addr, "\x08\x08\x08\x00", 4) != 0) { + return smartdns::SERVER_REQUEST_ERROR; + } + + if (ecs.source_prefix != 24) { + return smartdns::SERVER_REQUEST_ERROR; + } + + smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4"); + return smartdns::SERVER_REQUEST_OK; + }); + + server.Start(R"""(bind [::]:60053 +server 127.0.0.1:61053 +log-num 0 +log-console yes +dualstack-ip-selection no +edns-client-subnet 8.8.8.8/24 +log-level debug +cache-persist no)"""); + smartdns::Client client; + ASSERT_TRUE(client.Query("a.com A", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 1); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com"); + EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3); + EXPECT_EQ(client.GetAnswer()[0].GetType(), "A"); + EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4"); +} + +TEST_F(SubNet, conf_v6) +{ + smartdns::MockServer server_upstream; + smartdns::Server server; + + server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) { + if (request->qtype != DNS_T_AAAA) { + return smartdns::SERVER_REQUEST_SOA; + } + struct dns_opt_ecs ecs; + struct dns_rrs *rrs = NULL; + int rr_count = 0; + int i = 0; + int ret = 0; + int has_ecs = 0; + + rr_count = 0; + rrs = dns_get_rrs_start(request->packet, DNS_RRS_OPT, &rr_count); + if (rr_count <= 0) { + return smartdns::SERVER_REQUEST_ERROR; + } + + for (i = 0; i < rr_count && rrs; i++, rrs = dns_get_rrs_next(request->packet, rrs)) { + memset(&ecs, 0, sizeof(ecs)); + ret = dns_get_OPT_ECS(rrs, NULL, NULL, &ecs); + if (ret != 0) { + continue; + } + has_ecs = 1; + break; + } + + if (has_ecs == 0) { + return smartdns::SERVER_REQUEST_ERROR; + } + + if (ecs.family != DNS_OPT_ECS_FAMILY_IPV6) { + return smartdns::SERVER_REQUEST_ERROR; + } + + if (memcmp(ecs.addr, "\xff\xff\xff\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00", 16) != 0) { + return smartdns::SERVER_REQUEST_ERROR; + } + + if (ecs.source_prefix != 64) { + return smartdns::SERVER_REQUEST_ERROR; + } + + smartdns::MockServer::AddIP(request, request->domain.c_str(), "2001:db8::1"); + return smartdns::SERVER_REQUEST_OK; + }); + + server.MockPing(PING_TYPE_ICMP, "2001:db8::1", 60, 70); + + server.Start(R"""(bind [::]:60053 +server 127.0.0.1:61053 +log-num 0 +log-console yes +dualstack-ip-selection no +edns-client-subnet ffff:ffff:ffff:ffff:ffff::/64 +log-level debug +cache-persist no)"""); + smartdns::Client client; + ASSERT_TRUE(client.Query("a.com AAAA", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 1); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com"); + EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3); + EXPECT_EQ(client.GetAnswer()[0].GetType(), "AAAA"); + EXPECT_EQ(client.GetAnswer()[0].GetData(), "2001:db8::1"); +} + +TEST_F(SubNet, per_server) +{ + smartdns::MockServer server_upstream1; + smartdns::MockServer server_upstream2; + smartdns::Server server; + + server_upstream1.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) { + if (request->qtype == DNS_T_A) { + struct dns_opt_ecs ecs; + struct dns_rrs *rrs = NULL; + int rr_count = 0; + int i = 0; + int ret = 0; + int has_ecs = 0; + + rr_count = 0; + rrs = dns_get_rrs_start(request->packet, DNS_RRS_OPT, &rr_count); + if (rr_count <= 0) { + return smartdns::SERVER_REQUEST_ERROR; + } + + for (i = 0; i < rr_count && rrs; i++, rrs = dns_get_rrs_next(request->packet, rrs)) { + memset(&ecs, 0, sizeof(ecs)); + ret = dns_get_OPT_ECS(rrs, NULL, NULL, &ecs); + if (ret != 0) { + continue; + } + has_ecs = 1; + break; + } + + if (has_ecs == 1) { + return smartdns::SERVER_REQUEST_ERROR; + } + + smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4"); + return smartdns::SERVER_REQUEST_OK; + } + + if (request->qtype == DNS_T_AAAA) { + struct dns_opt_ecs ecs; + struct dns_rrs *rrs = NULL; + int rr_count = 0; + int i = 0; + int ret = 0; + int has_ecs = 0; + + rr_count = 0; + rrs = dns_get_rrs_start(request->packet, DNS_RRS_OPT, &rr_count); + if (rr_count <= 0) { + return smartdns::SERVER_REQUEST_ERROR; + } + + for (i = 0; i < rr_count && rrs; i++, rrs = dns_get_rrs_next(request->packet, rrs)) { + memset(&ecs, 0, sizeof(ecs)); + ret = dns_get_OPT_ECS(rrs, NULL, NULL, &ecs); + if (ret != 0) { + continue; + } + has_ecs = 1; + break; + } + + if (has_ecs == 1) { + return smartdns::SERVER_REQUEST_ERROR; + } + + smartdns::MockServer::AddIP(request, request->domain.c_str(), "2001:db8::1"); + return smartdns::SERVER_REQUEST_OK; + } + + return smartdns::SERVER_REQUEST_SOA; + }); + + server_upstream2.Start("udp://0.0.0.0:62053", [&](struct smartdns::ServerRequestContext *request) { + if (request->qtype == DNS_T_A) { + + struct dns_opt_ecs ecs; + struct dns_rrs *rrs = NULL; + int rr_count = 0; + int i = 0; + int ret = 0; + int has_ecs = 0; + + rr_count = 0; + rrs = dns_get_rrs_start(request->packet, DNS_RRS_OPT, &rr_count); + if (rr_count <= 0) { + return smartdns::SERVER_REQUEST_ERROR; + } + + for (i = 0; i < rr_count && rrs; i++, rrs = dns_get_rrs_next(request->packet, rrs)) { + memset(&ecs, 0, sizeof(ecs)); + ret = dns_get_OPT_ECS(rrs, NULL, NULL, &ecs); + if (ret != 0) { + continue; + } + has_ecs = 1; + break; + } + + if (has_ecs == 0) { + smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4"); + return smartdns::SERVER_REQUEST_OK; + } + + if (ecs.family != DNS_OPT_ECS_FAMILY_IPV4) { + return smartdns::SERVER_REQUEST_ERROR; + } + + if (memcmp(ecs.addr, "\x08\x08\x08\x00", 4) != 0) { + return smartdns::SERVER_REQUEST_ERROR; + } + + if (ecs.source_prefix != 24) { + return smartdns::SERVER_REQUEST_ERROR; + } + + smartdns::MockServer::AddIP(request, request->domain.c_str(), "5.6.7.8"); + return smartdns::SERVER_REQUEST_OK; + } + + if (request->qtype = DNS_T_AAAA) { + struct dns_opt_ecs ecs; + struct dns_rrs *rrs = NULL; + int rr_count = 0; + int i = 0; + int ret = 0; + int has_ecs = 0; + + rr_count = 0; + rrs = dns_get_rrs_start(request->packet, DNS_RRS_OPT, &rr_count); + if (rr_count <= 0) { + return smartdns::SERVER_REQUEST_ERROR; + } + + for (i = 0; i < rr_count && rrs; i++, rrs = dns_get_rrs_next(request->packet, rrs)) { + memset(&ecs, 0, sizeof(ecs)); + ret = dns_get_OPT_ECS(rrs, NULL, NULL, &ecs); + if (ret != 0) { + continue; + } + has_ecs = 1; + break; + } + + if (has_ecs == 0) { + smartdns::MockServer::AddIP(request, request->domain.c_str(), "2001:db8::1"); + return smartdns::SERVER_REQUEST_ERROR; + } + + if (ecs.family != DNS_OPT_ECS_FAMILY_IPV6) { + return smartdns::SERVER_REQUEST_ERROR; + } + + if (memcmp(ecs.addr, "\xff\xff\xff\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00", 16) != 0) { + return smartdns::SERVER_REQUEST_ERROR; + } + + if (ecs.source_prefix != 64) { + return smartdns::SERVER_REQUEST_ERROR; + } + + smartdns::MockServer::AddIP(request, request->domain.c_str(), "2001:db8::2"); + return smartdns::SERVER_REQUEST_OK; + } + + return smartdns::SERVER_REQUEST_SOA; + }); + + server.MockPing(PING_TYPE_ICMP, "1.2.3.4", 60, 100); + server.MockPing(PING_TYPE_ICMP, "5.6.7.8", 60, 10); + server.MockPing(PING_TYPE_ICMP, "2001:db8::1", 60, 100); + server.MockPing(PING_TYPE_ICMP, "2001:db8::2", 60, 10); + server.Start(R"""(bind [::]:60053 +server 127.0.0.1:62053 -subnet=8.8.8.8/24 -subnet=ffff:ffff:ffff:ffff:ffff::/64 +server 127.0.0.1:61053 +log-num 0 +log-console yes +dualstack-ip-selection no +log-level debug +cache-persist no)"""); + smartdns::Client client; + ASSERT_TRUE(client.Query("a.com A", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 1); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com"); + EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3); + EXPECT_EQ(client.GetAnswer()[0].GetType(), "A"); + EXPECT_EQ(client.GetAnswer()[0].GetData(), "5.6.7.8"); + + ASSERT_TRUE(client.Query("a.com AAAA", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 1); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com"); + EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3); + EXPECT_EQ(client.GetAnswer()[0].GetType(), "AAAA"); + EXPECT_EQ(client.GetAnswer()[0].GetData(), "2001:db8::2"); +}