Compare commits

..

54 Commits

Author SHA1 Message Date
Nick Peng
27c1aedd3b http-svcb: add https svcb RR suppport 2023-10-23 02:43:28 +00:00
zxlhhyccc
62171f2a4d add choose update time 2023-10-22 21:58:52 +08:00
Nick Peng
37a87e864e dns_conf: fix wildcard match issue 2023-10-06 22:51:36 +08:00
Nick Peng
a1d067f9eb dns_server: Fix address sub-rule issue. 2023-10-05 17:10:14 +08:00
Nick Peng
96d37332e4 dns_cache: Fix possible crash issue 2023-10-05 16:32:20 +08:00
Nick Peng
3916ea570a dns_cache: fix cache timeout issue 2023-09-27 23:19:45 +08:00
Nick Peng
51c81513ab dns_conf: add short option for server options 2023-09-24 22:55:10 +08:00
Nick Peng
1dd01ff4bd dns_cache: optimize timer wheel for DNS cache 2023-09-24 13:16:33 +08:00
Nick Peng
bfacad33ae dns_cache: Replace cache timeout mechanism with time wheel algorithm to reduce CPU usage 2023-09-23 23:31:19 +08:00
Nick Peng
b7fb501be9 dns_cache: fix insert issue. 2023-09-19 22:15:21 +08:00
Nick Peng
28139d2020 smartdns: Fixe coredump issue caused by running smartdns --help 2023-09-14 21:08:10 +08:00
Nick Peng
f7ede1b7d0 lint: clear lint warnings 2023-09-13 23:39:54 +08:00
Nick Peng
875100f5c1 dns_cache: optimize insertion performance 2023-09-12 21:49:14 +08:00
Frand Ren
1a492f7dc0 add domain rule "root or sub" 2023-09-12 09:22:59 +08:00
Nick Peng
1ff7829b49 dns: simple make DDR request SOA. 2023-09-09 17:20:42 +08:00
Nick Peng
8befd9d5d2 optware: fix init script restart dnsmasq failure. 2023-09-07 23:12:16 +08:00
Nick Peng
5658d72b3b dns_conf: update smartdns.conf and add -no-ip-alias for bind 2023-09-07 23:04:15 +08:00
Nick Peng
1b12709451 feature: add ip-rules and ip-set options 2023-09-06 23:25:13 +08:00
Nick Peng
c39a7b9b41 dns_cache: optimize dns cache. 2023-08-31 22:32:35 +08:00
Nick Peng
901baf80c0 ip-alias: add option -no-ipalias for domain-rules 2023-08-30 00:34:27 +08:00
Nick Peng
45e3455932 dns_cache: reduce cpu usage. 2023-08-29 23:55:04 +08:00
Nick Peng
887ef7b20e dns_cache: some cpu usage optimize for inactive cache 2023-08-24 23:42:39 +08:00
Nick Peng
9307855f7c dns_conf: fix ip-alias issue 2023-08-23 00:13:00 +08:00
Nick Peng
fb7b747f9f cache: Optimize cache memory usage 2023-08-22 23:18:31 +08:00
Nick Peng
7eb9d5d42f action: add docker build CI and update openssl for docker image. 2023-08-22 23:08:58 +08:00
Nick Peng
1054229efb feature: add ip-alias option. 2023-08-16 22:47:47 +08:00
Brainos
c19a39a447 Add nss-lookup.target as dependency for service 2023-08-14 23:27:51 +08:00
MoetaYuko
1ba6ee7cb9 openwrt: add missing EOF to custom.conf 2023-08-09 23:31:15 +08:00
Nick Peng
601ebd590e ssl: output error message when handshake failed. 2023-08-06 21:15:04 +08:00
Nick Peng
b133ce408a dns_conf: fix memory corruption issue when ip number greater than 8. 2023-07-28 22:42:36 +08:00
Nick Peng
8d3a62c568 dns_server: fix bogus-nxdomain issue. 2023-07-26 22:40:26 +08:00
Nick Peng
93a8b87c17 dns_server: fix memory corrupt bug. 2023-07-17 21:47:14 +08:00
Nick Peng
ffc331af21 dns-client: fix bootstrap retry failure issue when os startup. 2023-07-15 21:04:27 +08:00
Nick Peng
89e958abfa dns_client: avoid false re-creation of udp sockets causing retries. 2023-07-14 20:44:10 +08:00
Nick Peng
2576fdb02f dns_client: fix bootstrap dns retry issue. 2023-07-12 22:37:22 +08:00
Nick Peng
7ff6ae3ea0 dns_server: fix edns subnet not working issue. 2023-07-12 19:28:37 +08:00
Nick Peng
c2b072b523 conf: add ddns-domain options 2023-07-12 19:28:32 +08:00
Nick Peng
1df9d624b4 conf: add host-ip option for server. 2023-07-12 19:13:11 +08:00
Nick Peng
6b021946aa conf: support prefix wildcard match. 2023-07-05 00:08:29 +08:00
Nick Peng
087c9f5df2 conf: fix address issue when configuring multiple IPs 2023-07-01 09:20:42 +08:00
Nick Peng
e66928f27f ecs: Optimize ecs-subnet configuration method 2023-06-28 14:23:27 +08:00
Nick Peng
8a9a11d6d9 log: enable output log to console when run as daemon. 2023-06-16 21:57:39 +08:00
Nick Peng
a6e5ceb675 conf: trim prefix space for multiline option 2023-06-15 21:18:08 +08:00
Nick Peng
08567c458b address: support multiple ip addresses 2023-06-14 22:41:53 +08:00
Nick Peng
234c721011 test: fix test case failure issue 2023-06-14 22:40:12 +08:00
Chongyun Lee
45346705d8 tlog: fix declaration of tlog_set_permission 2023-06-11 07:02:37 +08:00
Nick Peng
9b7b2ad12d openwrt: fix adblock not working issue 2023-06-07 23:54:26 +08:00
Nick Peng
f072ff3412 dns_server: optimize result callback and update tlog. 2023-06-07 21:15:59 +08:00
Nick Peng
ad43c796cf force-qtype-SOA: support qtype range. 2023-06-05 22:55:47 +08:00
Nick Peng
f5c8d3ce57 dns_server: improve code readability 2023-06-02 22:52:46 +08:00
Nick Peng
f621b424e2 lint: add clang-tidy linter 2023-05-30 23:26:05 +08:00
Nick Peng
d59c148a28 smartdns: follow sysv daemon initialize steps 2023-05-30 23:25:25 +08:00
Nick Peng
8ea34ab176 dns_conf: A little bit of performance optimization 2023-05-27 22:51:00 +08:00
Nick Peng
0340d272c3 dns_server: fix max ttl reply issue. 2023-05-09 23:22:08 +08:00
49 changed files with 5026 additions and 1122 deletions

43
.clang-tidy Normal file
View File

@@ -0,0 +1,43 @@
Checks: >
-*,
modernize-*,
bugprone-*,
concurrency-*,
misc-*,
readability-*,
performance-*,
portability-*,
google-*,
linuxkernel-*,
-bugprone-narrowing-conversions,
-bugprone-branch-clone,
-bugprone-reserved-identifier,
-bugprone-easily-swappable-parameters,
-bugprone-sizeof-expression,
-bugprone-implicit-widening-of-multiplication-result,
-bugprone-suspicious-memory-comparison,
-bugprone-not-null-terminated-result,
-bugprone-signal-handler,
-concurrency-mt-unsafe,
-misc-unused-parameters,
-misc-misplaced-widening-cast,
-misc-no-recursion,
-readability-magic-numbers,
-readability-use-anyofallof,
-readability-identifier-length,
-readability-function-cognitive-complexity,
-readability-named-parameter,
-readability-isolate-declaration,
-readability-else-after-return,
-readability-redundant-control-flow,
-readability-suspicious-call-argument,
-google-readability-casting,
-google-readability-todo,
-performance-no-int-to-ptr,
# clang-analyzer-*,
# clang-analyzer-deadcode.DeadStores,
# clang-analyzer-optin.performance.Padding,
# -clang-analyzer-security.insecureAPI.*
# Turn all the warnings from the checks above into errors.
FormatStyle: file

35
.github/workflows/docker.yml vendored Normal file
View File

@@ -0,0 +1,35 @@
name: Publish Docker Image
on:
workflow_dispatch:
inputs:
version:
description: 'new image tag(e.g. v1.1.0)'
required: true
default: 'latest'
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs:
docker:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v3
- name: Set up QEMU
uses: docker/setup-qemu-action@v2
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
- name: Login to DockerHub
uses: docker/login-action@v2
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Build and push
uses: docker/build-push-action@v3
with:
platforms: linux/amd64,linux/arm64
push: true
tags: ${{vars.DOCKERHUB_REPO}}:${{ github.event.inputs.version }}

View File

@@ -2,7 +2,7 @@ FROM ubuntu:latest as smartdns-builder
LABEL previous-stage=smartdns-builder
# prepare builder
ARG OPENSSL_VER=1.1.1f
ARG OPENSSL_VER=3.0.10
RUN apt update && \
apt install -y perl curl make musl-tools musl-dev && \
ln -s /usr/include/linux /usr/include/$(uname -m)-linux-musl && \
@@ -27,7 +27,7 @@ COPY . /build/smartdns/
RUN cd /build/smartdns && \
export CC=musl-gcc && \
export CFLAGS="-I /opt/build/include" && \
export LDFLAGS="-L /opt/build/lib" && \
export LDFLAGS="-L /opt/build/lib -L /opt/build/lib64" && \
sh ./package/build-pkg.sh --platform linux --arch `dpkg --print-architecture` --static && \
\
( cd package && tar -xvf *.tar.gz && chmod a+x smartdns/etc/init.d/smartdns ) && \
@@ -37,9 +37,9 @@ RUN cd /build/smartdns && \
cp package/smartdns/usr /release/ -a && \
cd / && rm -rf /build
FROM busybox:latest
FROM busybox:stable-musl
COPY --from=smartdns-builder /release/ /
EXPOSE 53/udp
VOLUME "/etc/smartdns/"
VOLUME ["/etc/smartdns/"]
CMD ["/usr/sbin/smartdns", "-f", "-x"]

View File

@@ -39,6 +39,7 @@
# -no-cache: skip cache.
# -no-rule-soa: Skip address SOA(#) rules.
# -no-dualstack-selection: Disable dualstack ip selection.
# -no-ip-alias: ignore ip alias.
# -force-aaaa-soa: force AAAA query return SOA.
# -ipset ipsetname: use ipset rule.
# -nftset nftsetname: use nftset rule.
@@ -100,6 +101,10 @@ bind [::]:53
# List of IPs that will be ignored
# ignore-ip [ip/subnet]
# alias of IPs
# ip-alias [ip/subnet] [ip1[,ip2]...]
# ip-alias 192.168.0.1/24 10.9.0.1,10.9.0.2
# speed check mode
# speed-check-mode [ping|tcp:port|none|,]
# example:
@@ -112,7 +117,7 @@ bind [::]:53
# force specific qtype return soa
# force-qtype-SOA [qtypeid |...]
# force-qtype-SOA [qtypeid,...]
# force-qtype-SOA [qtypeid|start_id-end_id|,...]
# force-qtype-SOA 65 28
# force-qtype-SOA 65,28
force-qtype-SOA 65
@@ -188,12 +193,13 @@ log-level info
# -blacklist-ip: filter result with blacklist ip
# -whitelist-ip: filter result with whitelist ip, result in whitelist-ip will be accepted.
# -check-edns: result must exist edns RR, or discard result.
# -group [group]: set server to group, use with nameserver /domain/group.
# -exclude-default-group: exclude this server from default group.
# -proxy [proxy-name]: use proxy to connect to server.
# g|-group [group]: set server to group, use with nameserver /domain/group.
# e|-exclude-default-group: exclude this server from default group.
# p|-proxy [proxy-name]: use proxy to connect to server.
# -bootstrap-dns: set as bootstrap dns server.
# -set-mark: set mark on packets.
# -subnet [ip/subnet]: set edns client subnet.
# -host-ip [ip]: set dns server host ip.
# server 8.8.8.8 -blacklist-ip -check-edns -group g1 -group g2
# server tls://dns.google:853
# server https://dns.google/dns-query
@@ -208,8 +214,8 @@ log-level info
# -spki-pin: TLS spki pin to verify.
# -tls-host-verify: cert hostname to verify.
# -host-name: TLS sni hostname.
# -no-check-certificate: no check certificate.
# -proxy [proxy-name]: use proxy to connect to server.
# k|-no-check-certificate: no check certificate.
# p|-proxy [proxy-name]: use proxy to connect to server.
# -bootstrap-dns: set as bootstrap dns server.
# Get SPKI with this command:
# echo | openssl s_client -connect '[ip]:853' | openssl x509 -pubkey -noout | openssl pkey -pubin -outform der | openssl dgst -sha256 -binary | openssl enc -base64
@@ -223,8 +229,8 @@ log-level info
# -tls-host-verify: cert hostname to verify.
# -host-name: TLS sni hostname.
# -http-host: http host.
# -no-check-certificate: no check certificate.
# -proxy [proxy-name]: use proxy to connect to server.
# k|-no-check-certificate: no check certificate.
# p|-proxy [proxy-name]: use proxy to connect to server.
# -bootstrap-dns: set as bootstrap dns server.
# default port is 443
# server-https https://cloudflare-dns.com/dns-query
@@ -247,8 +253,9 @@ log-level info
# expand-ptr-from-address yes
# specific address to domain
# address /domain/[ip|-|-4|-6|#|#4|#6]
# address /domain/[ip1,ip2|-|-4|-6|#|#4|#6]
# address /www.example.com/1.2.3.4, return ip 1.2.3.4 to client
# address /www.example.com/1.2.3.4,5.6.7.8, return multiple ip addresses
# address /www.example.com/-, ignore address, query from upstream, suffix 4, for ipv4, 6 for ipv6, none for all
# address /www.example.com/#, return SOA to client, suffix 4, for ipv4, 6 for ipv6, none for all
@@ -289,6 +296,9 @@ log-level info
# nftset /www.example.com/-, ignore this domain
# nftset /www.example.com/#6:-, ignore ipv6
# set ddns domain
# ddns-domain domain
# set domain rules
# domain-rules /domain/ [-speed-check-mode [...]]
# rules:
@@ -301,6 +311,8 @@ log-level info
# [-d] -dualstack-ip-selection [yes|no]: same as dualstack-ip-selection option
# -no-serve-expired: ignore expired domain
# -delete: delete domain rule
# -no-ip-alias: ignore ip alias
# -no-cache: ignore cache
# collection of domains
# the domain-set can be used with /domain/ for address, nameserver, ipset, etc.
@@ -315,3 +327,24 @@ log-level info
# nameserver /domain-set:domain-list/server-group
# ipset /domain-set:domain-list/ipset
# domain-rules /domain-set:domain-list/ -speed-check-mode ping
# set ip rules
# ip-rules ip-cidrs [-ip-alias [...]]
# rules:
# [-c] -ip-alias [ip1,ip2]: same as ip-alias option
# [-a] -whitelist-ip: same as whitelist-ip option
# [-n] -blacklist-ip: same as blacklist-ip option
# [-p] -bogus-nxdomain: same as bogus-nxdomain option
# [-t] -ignore-ip: same as ignore-ip option
# collection of IPs
# the ip-set can be used with /ip-cidr/ for ip-alias, ignore-ip, etc.
# ip-set -name [set-name] -type list -file [/path/to/file]
# [-n] -name [set name]: ip set name
# [-t] -type [list]: ip set type, list only now
# [-f] -file [path/to/set]: file path of ip set
#
# example:
# ip-set -name ip-list -file /etc/smartdns/ip-list.conf
# bogus-nxdomain ip-set:ip-list
# ip-alias ip-set:ip-list 1.2.3.4

View File

@@ -1,2 +1,2 @@
# Add custom settings here.
# please read https://pymumu.github.io/smartdns/config/basic-config/
# please read https://pymumu.github.io/smartdns/config/basic-config/

View File

@@ -260,13 +260,13 @@ disable_auto_update()
enable_auto_update()
{
grep "0 5 * * * /etc/init.d/smartdns updatefiles" /etc/crontabs/root 2>/dev/null
grep "0 $auto_update_day_time * * * /etc/init.d/smartdns updatefiles" /etc/crontabs/root 2>/dev/null
if [ $? -eq 0 ]; then
return
fi
disable_auto_update 1
echo "0 5 * * * /etc/init.d/smartdns updatefiles" >> /etc/crontabs/root
echo "0 $auto_update_day_time * * * /etc/init.d/smartdns updatefiles" >> /etc/crontabs/root
restart_crond
}
@@ -308,7 +308,7 @@ load_domain_rules()
config_get block_domain_set_file "$section" "block_domain_set_file"
[ ! -z "$block_domain_set_file" ] && {
conf_append "domain-set" "-name ${domain_set_name}-block-file -file '$block_domain_set_file'"
conf_append "domain-rules" "/domain-set:${domain_set_name}-block-file/ -group block"
conf_append "domain-rules" "/domain-set:${domain_set_name}-block-file/ --address #"
}
conf_append "domain-set" "-name ${domain_set_name}-block-list -file /etc/smartdns/domain-block.list"
@@ -498,6 +498,8 @@ load_service()
config_get tcp_server "$section" "tcp_server" "1"
config_get server_flags "$section" "server_flags" ""
config_get auto_update_day_time "$section" "auto_update_day_time" "5"
config_get speed_check_mode "$section" "speed_check_mode" ""
[ ! -z "$speed_check_mode" ] && conf_append "speed-check-mode" "$speed_check_mode"

View File

@@ -121,8 +121,7 @@ restart_dnsmasq()
PID2="$(echo "$CMD" | awk 'NR==2{print $1}')"
PID2_PPID="$(grep 'PPid:' /proc/$PID2/status | awk '{print $2}' 2>/dev/null)"
if [ "$PID2_PPID" != "$PID1" ]; then
echo "find multiple dnsmasq, but not started by the same process"
return 1
kill -9 "$PID2"
fi
PID=$PID1
else

View File

@@ -15,8 +15,9 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
BIN=smartdns
OBJS_LIB=lib/rbtree.o lib/art.o lib/bitops.o lib/radix.o lib/conf.o lib/nftset.o
OBJS=smartdns.o fast_ping.o dns_client.o dns_server.o dns.o util.o tlog.o dns_conf.o dns_cache.o http_parse.o proxy.o $(OBJS_LIB)
OBJS_LIB=lib/rbtree.o lib/art.o lib/bitops.o lib/radix.o lib/timer_wheel.o
OBJS_MAIN=smartdns.o fast_ping.o dns_client.o dns_server.o dns.o util.o tlog.o dns_conf.o dns_cache.o http_parse.o proxy.o timer.o lib/conf.o lib/nftset.o
OBJS=$(OBJS_MAIN) $(OBJS_LIB)
# cflags
ifndef CFLAGS
@@ -51,5 +52,8 @@ all: $(BIN)
$(BIN) : $(OBJS)
$(CC) $(OBJS) -o $@ $(LDFLAGS)
clang-tidy:
clang-tidy -p=. $(OBJS_MAIN:.o=.c) -- $(CFLAGS)
clean:
$(RM) $(OBJS) $(BIN)

View File

@@ -759,8 +759,6 @@ static int _dns_get_opt_RAW(struct dns_rrs *rrs, char *domain, int maxsize, int
static int __attribute__((unused)) _dns_add_OPT(struct dns_packet *packet, dns_rr_type type, unsigned short opt_code,
unsigned short opt_len, struct dns_opt *opt)
{
// TODO
int ret = 0;
int len = 0;
struct dns_context context;
@@ -806,8 +804,6 @@ static int __attribute__((unused)) _dns_add_OPT(struct dns_packet *packet, dns_r
static int __attribute__((unused)) _dns_get_OPT(struct dns_rrs *rrs, unsigned short *opt_code, unsigned short *opt_len,
struct dns_opt *opt, int *opt_maxlen)
{
// TODO
int qtype = 0;
int qclass = 0;
int rr_len = 0;
@@ -875,6 +871,27 @@ int dns_get_PTR(struct dns_rrs *rrs, char *domain, int maxsize, int *ttl, char *
return _dns_get_RAW(rrs, domain, maxsize, ttl, cname, &len);
}
int dns_add_TXT(struct dns_packet *packet, dns_rr_type type, const char *domain, int ttl, const char *text)
{
int rr_len = strnlen(text, DNS_MAX_CNAME_LEN);
char data[DNS_MAX_CNAME_LEN];
if (rr_len > DNS_MAX_CNAME_LEN - 2) {
return -1;
}
data[0] = rr_len;
rr_len++;
memcpy(data + 1, text, rr_len);
data[rr_len] = 0;
return _dns_add_RAW(packet, type, DNS_T_TXT, domain, ttl, data, rr_len);
}
int dns_get_TXT(struct dns_rrs *rrs, char *domain, int maxsize, int *ttl, char *text, int txt_size)
{
return -1;
}
int dns_add_NS(struct dns_packet *packet, dns_rr_type type, const char *domain, int ttl, const char *cname)
{
int rr_len = strnlen(cname, DNS_MAX_CNAME_LEN) + 1;
@@ -1687,8 +1704,6 @@ static int _dns_encode_SOA(struct dns_context *context, struct dns_rrs *rrs)
static int _dns_decode_opt_ecs(struct dns_context *context, struct dns_opt_ecs *ecs, int opt_len)
{
// TODO
int len = 0;
if (opt_len < 4) {
return -1;
@@ -1716,7 +1731,6 @@ static int _dns_decode_opt_ecs(struct dns_context *context, struct dns_opt_ecs *
static int _dns_decode_opt_cookie(struct dns_context *context, struct dns_opt_cookie *cookie, int opt_len)
{
// TODO
if (opt_len < (int)member_size(struct dns_opt_cookie, client_cookie)) {
return -1;
}

View File

@@ -72,6 +72,7 @@ typedef enum dns_type {
DNS_T_SRV = 33,
DNS_T_OPT = 41,
DNS_T_SSHFP = 44,
DNS_T_SVCB = 64,
DNS_T_HTTPS = 65,
DNS_T_SPF = 99,
DNS_T_AXFR = 252,
@@ -262,6 +263,9 @@ int dns_get_A(struct dns_rrs *rrs, char *domain, int maxsize, int *ttl, unsigned
int dns_add_PTR(struct dns_packet *packet, dns_rr_type type, const char *domain, int ttl, const char *cname);
int dns_get_PTR(struct dns_rrs *rrs, char *domain, int maxsize, int *ttl, char *cname, int cname_size);
int dns_add_TXT(struct dns_packet *packet, dns_rr_type type, const char *domain, int ttl, const char *text);
int dns_get_TXT(struct dns_rrs *rrs, char *domain, int maxsize, int *ttl, char *text, int txt_size);
int dns_add_AAAA(struct dns_packet *packet, dns_rr_type type, const char *domain, int ttl,
unsigned char addr[DNS_RR_AAAA_LEN]);
int dns_get_AAAA(struct dns_rrs *rrs, char *domain, int maxsize, int *ttl, unsigned char addr[DNS_RR_AAAA_LEN]);
@@ -302,7 +306,7 @@ int dns_add_HTTPS_end(struct dns_rr_nested *svcparam);
int dns_get_HTTPS_svcparm_start(struct dns_rrs *rrs, struct dns_https_param **https_param, char *domain, int maxsize,
int *ttl, int *priority, char *target, int target_size);
struct dns_https_param *dns_get_HTTPS_svcparm_next(struct dns_rrs *rrs, struct dns_https_param *parm);
struct dns_https_param *dns_get_HTTPS_svcparm_next(struct dns_rrs *rrs, struct dns_https_param *param);
/*
* Packet operation

View File

@@ -18,66 +18,57 @@
#include "dns_cache.h"
#include "stringutil.h"
#include "timer.h"
#include "tlog.h"
#include "util.h"
#include <errno.h>
#include <fcntl.h>
#include <pthread.h>
#include <stdio.h>
#include <string.h>
#include <sys/types.h>
#define DNS_CACHE_MAX_HITNUM 5000
#define DNS_CACHE_HITNUM_STEP 2
#define DNS_CACHE_MAX_HITNUM 6000
#define DNS_CACHE_HITNUM_STEP 3
#define DNS_CACHE_HITNUM_STEP_MAX 6
#define DNS_CACHE_READ_TIMEOUT 60
struct dns_cache_head {
DECLARE_HASHTABLE(cache_hash, 16);
struct hash_table cache_hash;
struct list_head cache_list;
struct list_head inactive_list;
atomic_t num;
int size;
int enable_inactive;
int inactive_list_expired;
pthread_mutex_t lock;
dns_cache_callback timeout_callback;
};
typedef int (*dns_cache_read_callback)(struct dns_cache_record *cache_record, struct dns_cache_data *cache_data);
static struct dns_cache_head dns_cache_head;
int dns_cache_init(int size, int enable_inactive, int inactive_list_expired)
int dns_cache_init(int size, dns_cache_callback timeout_callback)
{
int bits = 0;
INIT_LIST_HEAD(&dns_cache_head.cache_list);
INIT_LIST_HEAD(&dns_cache_head.inactive_list);
hash_init(dns_cache_head.cache_hash);
bits = ilog2(size) - 1;
if (bits >= 20) {
bits = 20;
} else if (bits < 12) {
bits = 12;
}
hash_table_init(dns_cache_head.cache_hash, bits, malloc);
atomic_set(&dns_cache_head.num, 0);
dns_cache_head.size = size;
dns_cache_head.enable_inactive = enable_inactive;
dns_cache_head.inactive_list_expired = inactive_list_expired;
dns_cache_head.timeout_callback = timeout_callback;
pthread_mutex_init(&dns_cache_head.lock, NULL);
return 0;
}
static __attribute__((unused)) struct dns_cache *_dns_cache_last(void)
static struct dns_cache *_dns_cache_first(void)
{
struct dns_cache *dns_cache = NULL;
dns_cache = list_last_entry(&dns_cache_head.inactive_list, struct dns_cache, list);
if (dns_cache) {
return dns_cache;
}
return list_last_entry(&dns_cache_head.cache_list, struct dns_cache, list);
}
static struct dns_cache *_dns_inactive_cache_first(void)
{
struct dns_cache *dns_cache = NULL;
dns_cache = list_first_entry_or_null(&dns_cache_head.inactive_list, struct dns_cache, list);
if (dns_cache) {
return dns_cache;
}
return list_first_entry_or_null(&dns_cache_head.cache_list, struct dns_cache, list);
}
@@ -86,7 +77,8 @@ static void _dns_cache_delete(struct dns_cache *dns_cache)
hash_del(&dns_cache->node);
list_del_init(&dns_cache->list);
atomic_dec(&dns_cache_head.num);
dns_cache_data_free(dns_cache->cache_data);
dns_cache_data_put(dns_cache->cache_data);
dns_cache->cache_data = NULL;
free(dns_cache);
}
@@ -103,6 +95,7 @@ void dns_cache_release(struct dns_cache *dns_cache)
if (dns_cache == NULL) {
return;
}
if (!atomic_dec_and_test(&dns_cache->ref)) {
return;
}
@@ -114,15 +107,10 @@ static void _dns_cache_remove(struct dns_cache *dns_cache)
{
hash_del(&dns_cache->node);
list_del_init(&dns_cache->list);
dns_timer_del(&dns_cache->timer);
dns_cache_release(dns_cache);
}
static void _dns_cache_move_inactive(struct dns_cache *dns_cache)
{
list_del_init(&dns_cache->list);
list_add_tail(&dns_cache->list, &dns_cache_head.inactive_list);
}
enum CACHE_TYPE dns_cache_data_type(struct dns_cache_data *cache_data)
{
return cache_data->head.cache_type;
@@ -138,15 +126,6 @@ const char *dns_cache_get_dns_group_name(struct dns_cache *dns_cache)
return dns_cache->info.dns_group_name;
}
void dns_cache_data_free(struct dns_cache_data *data)
{
if (data == NULL) {
return;
}
free(data);
}
struct dns_cache_data *dns_cache_new_data_addr(void)
{
struct dns_cache_addr *cache_addr = malloc(sizeof(struct dns_cache_addr));
@@ -158,6 +137,7 @@ struct dns_cache_data *dns_cache_new_data_addr(void)
cache_addr->head.cache_type = CACHE_TYPE_NONE;
cache_addr->head.size = sizeof(struct dns_cache_addr) - sizeof(struct dns_cache_data_head);
cache_addr->head.magic = MAGIC_CACHE_DATA;
atomic_set(&cache_addr->head.ref, 1);
return (struct dns_cache_data *)cache_addr;
}
@@ -243,11 +223,38 @@ struct dns_cache_data *dns_cache_new_data_packet(void *packet, size_t packet_len
cache_packet->head.cache_type = CACHE_TYPE_PACKET;
cache_packet->head.size = packet_len;
cache_packet->head.magic = MAGIC_CACHE_DATA;
atomic_set(&cache_packet->head.ref, 1);
return (struct dns_cache_data *)cache_packet;
}
static int _dns_cache_replace(struct dns_cache_key *cache_key, int ttl, int speed, int no_inactive, int inactive,
static void dns_cache_timer_relase(struct tw_timer_list *timer, void *data)
{
struct dns_cache *dns_cache = data;
dns_cache_release(dns_cache);
}
static void dns_cache_expired(struct tw_timer_list *timer, void *data, unsigned long timestamp)
{
struct dns_cache *dns_cache = data;
if (dns_cache->del_pending == 1) {
dns_cache_release(dns_cache);
return;
}
if (dns_cache_head.timeout_callback) {
if (dns_cache_head.timeout_callback(dns_cache) != 0) {
dns_cache_release(dns_cache);
return;
}
}
dns_cache->del_pending = 1;
dns_timer_mod(&dns_cache->timer, 5);
}
static int _dns_cache_replace(struct dns_cache_key *cache_key, int ttl, int speed, int timeout, int update_time,
struct dns_cache_data *cache_data)
{
struct dns_cache *dns_cache = NULL;
@@ -260,7 +267,7 @@ static int _dns_cache_replace(struct dns_cache_key *cache_key, int ttl, int spee
/* lookup existing cache */
dns_cache = dns_cache_lookup(cache_key);
if (dns_cache == NULL) {
return dns_cache_insert(cache_key, ttl, speed, no_inactive, cache_data);
return dns_cache_insert(cache_key, ttl, speed, timeout, cache_data);
}
if (ttl < DNS_CACHE_TTL_MIN) {
@@ -270,43 +277,36 @@ static int _dns_cache_replace(struct dns_cache_key *cache_key, int ttl, int spee
/* update cache data */
pthread_mutex_lock(&dns_cache_head.lock);
dns_cache->del_pending = 0;
dns_cache->info.ttl = ttl;
dns_cache->info.qtype = cache_key->qtype;
dns_cache->info.query_flag = cache_key->query_flag;
dns_cache->info.ttl = ttl;
dns_cache->info.speed = speed;
dns_cache->info.no_inactive = no_inactive;
dns_cache->info.timeout = timeout;
dns_cache->info.is_visited = 1;
old_cache_data = dns_cache->cache_data;
dns_cache->cache_data = cache_data;
list_del_init(&dns_cache->list);
if (inactive == 0) {
time(&dns_cache->info.insert_time);
time(&dns_cache->info.replace_time);
list_add_tail(&dns_cache->list, &dns_cache_head.cache_list);
} else {
time(&dns_cache->info.replace_time);
list_add_tail(&dns_cache->list, &dns_cache_head.inactive_list);
if (cache_data) {
old_cache_data = dns_cache->cache_data;
dns_cache->cache_data = cache_data;
}
if (update_time) {
time(&dns_cache->info.insert_time);
}
time(&dns_cache->info.replace_time);
list_del(&dns_cache->list);
list_add_tail(&dns_cache->list, &dns_cache_head.cache_list);
dns_timer_mod(&dns_cache->timer, timeout);
pthread_mutex_unlock(&dns_cache_head.lock);
dns_cache_data_free(old_cache_data);
if (old_cache_data) {
dns_cache_data_put(old_cache_data);
}
dns_cache_release(dns_cache);
return 0;
}
int dns_cache_replace(struct dns_cache_key *cache_key, int ttl, int speed, int no_inactive,
int dns_cache_replace(struct dns_cache_key *cache_key, int ttl, int speed, int timeout, int update_time,
struct dns_cache_data *cache_data)
{
return _dns_cache_replace(cache_key, ttl, speed, no_inactive, 0, cache_data);
}
int dns_cache_replace_inactive(struct dns_cache_key *cache_key, int ttl, int speed, int no_inactive,
struct dns_cache_data *cache_data)
{
return _dns_cache_replace(cache_key, ttl, speed, no_inactive, 1, cache_data);
return _dns_cache_replace(cache_key, ttl, speed, timeout, update_time, cache_data);
}
static void _dns_cache_remove_by_domain(struct dns_cache_key *cache_key)
@@ -320,7 +320,7 @@ static void _dns_cache_remove_by_domain(struct dns_cache_key *cache_key)
key = jhash(&cache_key->query_flag, sizeof(cache_key->query_flag), key);
pthread_mutex_lock(&dns_cache_head.lock);
hash_for_each_possible(dns_cache_head.cache_hash, dns_cache, node, key)
hash_table_for_each_possible(dns_cache_head.cache_hash, dns_cache, node, key)
{
if (dns_cache->info.qtype != cache_key->qtype) {
continue;
@@ -372,19 +372,26 @@ static int _dns_cache_insert(struct dns_cache_info *info, struct dns_cache_data
memcpy(&dns_cache->info, info, sizeof(*info));
dns_cache->del_pending = 0;
dns_cache->cache_data = cache_data;
dns_cache->timer.function = dns_cache_expired;
dns_cache->timer.del_function = dns_cache_timer_relase;
dns_cache->timer.expires = info->timeout;
dns_cache->timer.data = dns_cache;
pthread_mutex_lock(&dns_cache_head.lock);
hash_add(dns_cache_head.cache_hash, &dns_cache->node, key);
hash_table_add(dns_cache_head.cache_hash, &dns_cache->node, key);
list_add_tail(&dns_cache->list, head);
INIT_LIST_HEAD(&dns_cache->check_list);
/* Release extra cache, remove oldest cache record */
if (atomic_inc_return(&dns_cache_head.num) > dns_cache_head.size) {
struct dns_cache *del_cache = NULL;
del_cache = _dns_inactive_cache_first();
del_cache = _dns_cache_first();
if (del_cache) {
_dns_cache_remove(del_cache);
}
}
dns_cache_get(dns_cache);
dns_timer_add(&dns_cache->timer);
pthread_mutex_unlock(&dns_cache_head.lock);
return 0;
@@ -396,7 +403,7 @@ errout:
return -1;
}
int dns_cache_insert(struct dns_cache_key *cache_key, int ttl, int speed, int no_inactive,
int dns_cache_insert(struct dns_cache_key *cache_key, int ttl, int speed, int timeout,
struct dns_cache_data *cache_data)
{
struct dns_cache_info info;
@@ -406,7 +413,7 @@ int dns_cache_insert(struct dns_cache_key *cache_key, int ttl, int speed, int no
}
if (dns_cache_head.size <= 0) {
dns_cache_data_free(cache_data);
dns_cache_data_put(cache_data);
return 0;
}
@@ -423,7 +430,7 @@ int dns_cache_insert(struct dns_cache_key *cache_key, int ttl, int speed, int no
info.ttl = ttl;
info.hitnum_update_add = DNS_CACHE_HITNUM_STEP;
info.speed = speed;
info.no_inactive = no_inactive;
info.timeout = timeout;
info.is_visited = 1;
time(&info.insert_time);
time(&info.replace_time);
@@ -450,7 +457,7 @@ struct dns_cache *dns_cache_lookup(struct dns_cache_key *cache_key)
time(&now);
/* find cache */
pthread_mutex_lock(&dns_cache_head.lock);
hash_for_each_possible(dns_cache_head.cache_hash, dns_cache, node, key)
hash_table_for_each_possible(dns_cache_head.cache_hash, dns_cache, node, key)
{
if (dns_cache->info.qtype != cache_key->qtype) {
continue;
@@ -473,13 +480,7 @@ struct dns_cache *dns_cache_lookup(struct dns_cache_key *cache_key)
}
if (dns_cache_ret) {
/* Return NULL if the cache times out */
if (dns_cache_head.enable_inactive == 0 && (now - dns_cache_ret->info.insert_time > dns_cache_ret->info.ttl)) {
_dns_cache_remove(dns_cache_ret);
dns_cache_ret = NULL;
} else {
dns_cache_get(dns_cache_ret);
}
dns_cache_get(dns_cache_ret);
}
pthread_mutex_unlock(&dns_cache_head.lock);
@@ -508,14 +509,20 @@ int dns_cache_get_cname_ttl(struct dns_cache *dns_cache)
time(&now);
struct dns_cache_addr *cache_addr = (struct dns_cache_addr *)dns_cache_get_data(dns_cache);
if (cache_addr == NULL) {
ttl = 0;
goto out;
}
if (cache_addr->head.cache_type != CACHE_TYPE_ADDR) {
return 0;
ttl = 0;
goto out;
}
ttl = dns_cache->info.insert_time + cache_addr->addr_data.cname_ttl - now;
if (ttl < 0) {
return 0;
ttl = 0;
goto out;
}
int addr_ttl = dns_cache_get_ttl(dns_cache);
@@ -524,7 +531,13 @@ int dns_cache_get_cname_ttl(struct dns_cache *dns_cache)
}
if (ttl < 0) {
return 0;
ttl = 0;
goto out;
}
out:
if (cache_addr) {
dns_cache_data_put((struct dns_cache_data *)cache_addr);
}
return ttl;
@@ -545,7 +558,35 @@ int dns_cache_is_soa(struct dns_cache *dns_cache)
struct dns_cache_data *dns_cache_get_data(struct dns_cache *dns_cache)
{
return dns_cache->cache_data;
struct dns_cache_data *cache_data;
pthread_mutex_lock(&dns_cache_head.lock);
dns_cache_data_get(dns_cache->cache_data);
cache_data = dns_cache->cache_data;
pthread_mutex_unlock(&dns_cache_head.lock);
return cache_data;
}
void dns_cache_data_get(struct dns_cache_data *cache_data)
{
if (atomic_inc_return(&cache_data->head.ref) == 1) {
tlog(TLOG_ERROR, "BUG: dns_cache data is invalid.");
return;
}
return;
}
void dns_cache_data_put(struct dns_cache_data *cache_data)
{
if (cache_data == NULL) {
return;
}
if (!atomic_dec_and_test(&cache_data->head.ref)) {
return;
}
free(cache_data);
}
int dns_cache_is_visited(struct dns_cache *dns_cache)
@@ -591,127 +632,47 @@ void dns_cache_update(struct dns_cache *dns_cache)
pthread_mutex_unlock(&dns_cache_head.lock);
}
static void _dns_cache_remove_expired_ttl(dns_cache_callback inactive_precallback, int ttl_inactive_pre,
unsigned int max_callback_num, const time_t *now)
static int _dns_cache_read_to_cache(struct dns_cache_record *cache_record, struct dns_cache_data *cache_data)
{
struct dns_cache *dns_cache = NULL;
struct dns_cache *tmp = NULL;
unsigned int callback_num = 0;
int ttl = 0;
LIST_HEAD(checklist);
struct list_head *head = NULL;
head = &dns_cache_head.cache_list;
struct dns_cache_info *info = &cache_record->info;
pthread_mutex_lock(&dns_cache_head.lock);
list_for_each_entry_safe(dns_cache, tmp, &dns_cache_head.inactive_list, list)
{
ttl = dns_cache->info.insert_time + dns_cache->info.ttl - *now;
if (ttl > 0) {
continue;
}
if (dns_cache_head.inactive_list_expired + ttl < 0) {
_dns_cache_remove(dns_cache);
continue;
}
ttl = *now - dns_cache->info.replace_time;
if (ttl < ttl_inactive_pre || inactive_precallback == NULL) {
continue;
}
if (callback_num >= max_callback_num) {
continue;
}
if (dns_cache->del_pending == 1) {
continue;
}
/* If the TTL time is in the pre-timeout range, call callback function */
dns_cache_get(dns_cache);
list_add_tail(&dns_cache->check_list, &checklist);
dns_cache->del_pending = 1;
callback_num++;
time_t now = time(NULL);
unsigned int seed_tmp = now;
int passed_time = now - info->replace_time;
int timeout = info->timeout - passed_time;
if (timeout < DNS_CACHE_READ_TIMEOUT * 2) {
timeout = DNS_CACHE_READ_TIMEOUT + (rand_r(&seed_tmp) % DNS_CACHE_READ_TIMEOUT);
}
pthread_mutex_unlock(&dns_cache_head.lock);
list_for_each_entry_safe(dns_cache, tmp, &checklist, check_list)
{
/* run inactive_precallback */
if (inactive_precallback) {
inactive_precallback(dns_cache);
}
dns_cache_release(dns_cache);
if (timeout > dns_conf_serve_expired_ttl && dns_conf_serve_expired_ttl >= 0) {
timeout = dns_conf_serve_expired_ttl;
}
info->timeout = timeout;
if (_dns_cache_insert(&cache_record->info, cache_data, head) != 0) {
tlog(TLOG_ERROR, "insert cache data failed.");
cache_data = NULL;
goto errout;
}
dns_cache_data_get(cache_data);
daemon_keepalive();
return 0;
errout:
return -1;
}
void dns_cache_invalidate(dns_cache_callback precallback, int ttl_pre, unsigned int max_callback_num,
dns_cache_callback inactive_precallback, int ttl_inactive_pre)
static int _dns_cache_read_record(int fd, uint32_t cache_number, dns_cache_read_callback callback)
{
struct dns_cache *dns_cache = NULL;
struct dns_cache *tmp = NULL;
time_t now = 0;
int ttl = 0;
LIST_HEAD(checklist);
unsigned int callback_num = 0;
if (max_callback_num <= 0) {
max_callback_num = -1;
}
if (dns_cache_head.size <= 0) {
return;
}
time(&now);
pthread_mutex_lock(&dns_cache_head.lock);
list_for_each_entry_safe(dns_cache, tmp, &dns_cache_head.cache_list, list)
{
ttl = dns_cache->info.insert_time + dns_cache->info.ttl - now;
if (ttl > 0 && ttl < ttl_pre) {
/* If the TTL time is in the pre-timeout range, call callback function */
if (precallback && dns_cache->del_pending == 0 && callback_num < max_callback_num) {
list_add_tail(&dns_cache->check_list, &checklist);
dns_cache_get(dns_cache);
dns_cache->del_pending = 1;
callback_num++;
continue;
}
}
if (ttl < 0) {
if (dns_cache_head.enable_inactive && dns_cache->info.no_inactive == 0) {
_dns_cache_move_inactive(dns_cache);
} else {
_dns_cache_remove(dns_cache);
}
}
}
pthread_mutex_unlock(&dns_cache_head.lock);
if (dns_cache_head.enable_inactive && dns_cache_head.inactive_list_expired != 0) {
_dns_cache_remove_expired_ttl(inactive_precallback, ttl_inactive_pre, max_callback_num, &now);
}
list_for_each_entry_safe(dns_cache, tmp, &checklist, check_list)
{
/* run callback */
if (precallback) {
precallback(dns_cache);
}
list_del(&dns_cache->check_list);
dns_cache_release(dns_cache);
}
}
static int _dns_cache_read_record(int fd, uint32_t cache_number)
{
unsigned int i = 0;
ssize_t ret = 0;
struct dns_cache_record cache_record;
struct dns_cache_data_head data_head;
struct dns_cache_data *cache_data = NULL;
struct list_head *head = NULL;
for (i = 0; i < cache_number; i++) {
ret = read(fd, &cache_record, sizeof(cache_record));
@@ -725,15 +686,6 @@ static int _dns_cache_read_record(int fd, uint32_t cache_number)
goto errout;
}
if (cache_record.type == CACHE_RECORD_TYPE_ACTIVE) {
head = &dns_cache_head.cache_list;
} else if (cache_record.type == CACHE_RECORD_TYPE_INACTIVE) {
head = &dns_cache_head.inactive_list;
} else {
tlog(TLOG_ERROR, "read cache record type is invalid.");
goto errout;
}
ret = read(fd, &data_head, sizeof(data_head));
if (ret != sizeof(data_head)) {
tlog(TLOG_ERROR, "read data head failed, %s", strerror(errno));
@@ -757,6 +709,7 @@ static int _dns_cache_read_record(int fd, uint32_t cache_number)
}
memcpy(&cache_data->head, &data_head, sizeof(data_head));
atomic_set(&cache_data->head.ref, 1);
ret = read(fd, cache_data->data, data_head.size);
if (ret != data_head.size) {
tlog(TLOG_ERROR, "read cache data failed, %s", strerror(errno));
@@ -767,29 +720,23 @@ static int _dns_cache_read_record(int fd, uint32_t cache_number)
cache_record.info.is_visited = 0;
cache_record.info.domain[DNS_MAX_CNAME_LEN - 1] = '\0';
cache_record.info.dns_group_name[DNS_GROUP_NAME_LEN - 1] = '\0';
if (cache_record.type >= CACHE_RECORD_TYPE_END) {
tlog(TLOG_ERROR, "read cache record type is invalid.");
goto errout;
}
if (_dns_cache_insert(&cache_record.info, cache_data, head) != 0) {
tlog(TLOG_ERROR, "insert cache data failed.");
cache_data = NULL;
goto errout;
}
ret = callback(&cache_record, cache_data);
dns_cache_data_put(cache_data);
cache_data = NULL;
if (ret != 0) {
goto errout;
}
}
return 0;
errout:
if (cache_data) {
free(cache_data);
dns_cache_data_put(cache_data);
}
return -1;
}
int dns_cache_load(const char *file)
static int _dns_cache_file_read(const char *file, dns_cache_read_callback callback)
{
int fd = -1;
ssize_t ret = 0;
@@ -822,7 +769,7 @@ int dns_cache_load(const char *file)
}
tlog(TLOG_INFO, "load cache file %s, total %d records", file, cache_file.cache_number);
if (_dns_cache_read_record(fd, cache_file.cache_number) != 0) {
if (_dns_cache_read_record(fd, cache_file.cache_number, callback) != 0) {
goto errout;
}
@@ -836,17 +783,23 @@ errout:
return -1;
}
static int _dns_cache_write_record(int fd, uint32_t *cache_number, enum CACHE_RECORD_TYPE type, struct list_head *head)
int dns_cache_load(const char *file)
{
return _dns_cache_file_read(file, _dns_cache_read_to_cache);
}
static int _dns_cache_write_record(int fd, uint32_t *cache_number, struct list_head *head)
{
struct dns_cache *dns_cache = NULL;
struct dns_cache *tmp = NULL;
struct dns_cache_record cache_record;
memset(&cache_record, 0, sizeof(cache_record));
pthread_mutex_lock(&dns_cache_head.lock);
list_for_each_entry_safe_reverse(dns_cache, tmp, head, list)
list_for_each_entry_safe(dns_cache, tmp, head, list)
{
cache_record.magic = MAGIC_RECORD;
cache_record.type = type;
memcpy(&cache_record.info, &dns_cache->info, sizeof(struct dns_cache_info));
ssize_t ret = write(fd, &cache_record, sizeof(cache_record));
if (ret != sizeof(cache_record)) {
@@ -874,12 +827,7 @@ errout:
static int _dns_cache_write_records(int fd, uint32_t *cache_number)
{
if (_dns_cache_write_record(fd, cache_number, CACHE_RECORD_TYPE_ACTIVE, &dns_cache_head.cache_list) != 0) {
return -1;
}
if (_dns_cache_write_record(fd, cache_number, CACHE_RECORD_TYPE_INACTIVE, &dns_cache_head.inactive_list) != 0) {
if (_dns_cache_write_record(fd, cache_number, &dns_cache_head.cache_list) != 0) {
return -1;
}
@@ -945,17 +893,29 @@ errout:
return -1;
}
static int _dns_cache_print(struct dns_cache_record *cache_record, struct dns_cache_data *cache_data)
{
printf("domain: %s, qtype: %d, ttl: %d, speed: %.1fms\n", cache_record->info.domain, cache_record->info.qtype,
cache_record->info.ttl, (float)cache_record->info.speed / 10);
return 0;
}
int dns_cache_print(const char *file)
{
if (access(file, F_OK) != 0) {
tlog(TLOG_ERROR, "cache file %s not exist.", file);
return -1;
}
return _dns_cache_file_read(file, _dns_cache_print);
}
void dns_cache_destroy(void)
{
struct dns_cache *dns_cache = NULL;
struct dns_cache *tmp = NULL;
pthread_mutex_lock(&dns_cache_head.lock);
list_for_each_entry_safe(dns_cache, tmp, &dns_cache_head.inactive_list, list)
{
_dns_cache_delete(dns_cache);
}
list_for_each_entry_safe(dns_cache, tmp, &dns_cache_head.cache_list, list)
{
_dns_cache_delete(dns_cache);
@@ -963,10 +923,11 @@ void dns_cache_destroy(void)
pthread_mutex_unlock(&dns_cache_head.lock);
pthread_mutex_destroy(&dns_cache_head.lock);
hash_table_free(dns_cache_head.cache_hash, free);
}
const char *dns_cache_file_version(void)
{
const char *version = "cache ver 1.0";
const char *version = "cache ver 1.2";
return version;
}

View File

@@ -25,6 +25,7 @@
#include "hash.h"
#include "hashtable.h"
#include "list.h"
#include "timer.h"
#include <stdlib.h>
#include <time.h>
@@ -45,14 +46,9 @@ enum CACHE_TYPE {
CACHE_TYPE_PACKET,
};
enum CACHE_RECORD_TYPE {
CACHE_RECORD_TYPE_ACTIVE,
CACHE_RECORD_TYPE_INACTIVE,
CACHE_RECORD_TYPE_END,
};
struct dns_cache_data_head {
enum CACHE_TYPE cache_type;
atomic_t ref;
int is_soa;
ssize_t size;
uint32_t magic;
@@ -90,7 +86,7 @@ struct dns_cache_info {
int ttl;
int hitnum;
int speed;
int no_inactive;
int timeout;
int hitnum_update_add;
int is_visited;
time_t insert_time;
@@ -99,7 +95,6 @@ struct dns_cache_info {
struct dns_cache_record {
uint32_t magic;
enum CACHE_RECORD_TYPE type;
struct dns_cache_info info;
};
@@ -113,6 +108,8 @@ struct dns_cache {
struct dns_cache_info info;
struct dns_cache_data *cache_data;
struct tw_timer_list timer;
};
struct dns_cache_file {
@@ -134,19 +131,16 @@ uint32_t dns_cache_get_query_flag(struct dns_cache *dns_cache);
const char *dns_cache_get_dns_group_name(struct dns_cache *dns_cache);
void dns_cache_data_free(struct dns_cache_data *data);
struct dns_cache_data *dns_cache_new_data_packet(void *packet, size_t packet_len);
int dns_cache_init(int size, int enable_inactive, int inactive_list_expired);
typedef int (*dns_cache_callback)(struct dns_cache *dns_cache);
int dns_cache_replace(struct dns_cache_key *key, int ttl, int speed, int no_inactive,
int dns_cache_init(int size, dns_cache_callback timeout_callback);
int dns_cache_replace(struct dns_cache_key *key, int ttl, int speed, int tiemout, int update_time,
struct dns_cache_data *cache_data);
int dns_cache_replace_inactive(struct dns_cache_key *key, int ttl, int speed, int no_inactive,
struct dns_cache_data *cache_data);
int dns_cache_insert(struct dns_cache_key *key, int ttl, int speed, int no_inactive, struct dns_cache_data *cache_data);
int dns_cache_insert(struct dns_cache_key *key, int ttl, int speed, int timeout, struct dns_cache_data *cache_data);
struct dns_cache *dns_cache_lookup(struct dns_cache_key *key);
@@ -162,11 +156,6 @@ int dns_cache_is_visited(struct dns_cache *dns_cache);
void dns_cache_update(struct dns_cache *dns_cache);
typedef void dns_cache_callback(struct dns_cache *dns_cache);
void dns_cache_invalidate(dns_cache_callback precallback, int ttl_pre, unsigned int max_callback_num,
dns_cache_callback inactive_precallback, int ttl_inactive_pre);
int dns_cache_get_ttl(struct dns_cache *dns_cache);
int dns_cache_get_cname_ttl(struct dns_cache *dns_cache);
@@ -177,6 +166,10 @@ struct dns_cache_data *dns_cache_new_data_addr(void);
struct dns_cache_data *dns_cache_get_data(struct dns_cache *dns_cache);
void dns_cache_data_get(struct dns_cache_data *cache_data);
void dns_cache_data_put(struct dns_cache_data *cache_data);
void dns_cache_set_data_addr(struct dns_cache_data *dns_cache, char *cname, int cname_ttl, unsigned char *addr,
int addr_len);
@@ -188,6 +181,8 @@ int dns_cache_load(const char *file);
int dns_cache_save(const char *file, int check_lock);
int dns_cache_print(const char *file);
const char *dns_cache_file_version(void);
#ifdef __cplusplus

View File

@@ -63,7 +63,7 @@
#define DNS_TCP_CONNECT_TIMEOUT (5)
#define DNS_QUERY_TIMEOUT (500)
#define DNS_QUERY_RETRY (4)
#define DNS_PENDING_SERVER_RETRY 40
#define DNS_PENDING_SERVER_RETRY 60
#define SOCKET_PRIORITY (6)
#define SOCKET_IP_TOS (IPTOS_LOWDELAY | IPTOS_RELIABILITY)
@@ -161,7 +161,8 @@ struct dns_server_pending {
unsigned int has_v6;
unsigned int query_v4;
unsigned int query_v6;
unsigned int has_soa;
unsigned int has_soa_v4;
unsigned int has_soa_v6;
/* server type */
dns_server_type_t type;
@@ -442,8 +443,8 @@ static struct addrinfo *_dns_client_getaddr(const char *host, char *port, int ty
ret = getaddrinfo(host, port, &hints, &result);
if (ret != 0) {
tlog(TLOG_ERROR, "get addr info failed. %s\n", gai_strerror(ret));
tlog(TLOG_ERROR, "host = %s, port = %s, type = %d, protocol = %d", host, port, type, protocol);
tlog(TLOG_WARN, "get addr info failed. %s\n", gai_strerror(ret));
tlog(TLOG_WARN, "host = %s, port = %s, type = %d, protocol = %d", host, port, type, protocol);
goto errout;
}
@@ -1074,7 +1075,7 @@ static int _dns_client_server_add(char *server_ip, char *server_host, int port,
return 0;
}
snprintf(port_s, 8, "%d", port);
snprintf(port_s, sizeof(port_s), "%d", port);
gai = _dns_client_getaddr(server_ip, port_s, sock_type, 0);
if (gai == NULL) {
tlog(TLOG_DEBUG, "get address failed, %s:%d", server_ip, port);
@@ -1402,6 +1403,8 @@ static int _dns_client_add_server_pending(char *server_ip, char *server_host, in
struct client_dns_server_flags *flags, int is_pending)
{
int ret = 0;
struct addrinfo *gai = NULL;
char server_ip_tmp[DNS_HOSTNAME_LEN] = {0};
if (server_type >= DNS_SERVER_TYPE_END) {
tlog(TLOG_ERROR, "server type is invalid.");
@@ -1414,6 +1417,22 @@ static int _dns_client_add_server_pending(char *server_ip, char *server_host, in
tlog(TLOG_INFO, "add pending server %s", server_ip);
return 0;
}
} else if (check_is_ipaddr(server_ip) && is_pending == 0) {
gai = _dns_client_getaddr(server_ip, NULL, SOCK_STREAM, 0);
if (gai == NULL) {
return -1;
}
if (get_host_by_addr(server_ip_tmp, sizeof(server_ip_tmp), gai->ai_addr) != NULL) {
tlog(TLOG_INFO, "resolve %s to %s.", server_ip, server_ip_tmp);
server_ip = server_ip_tmp;
} else {
tlog(TLOG_INFO, "resolve %s failed.", server_ip);
freeaddrinfo(gai);
return -1;
}
freeaddrinfo(gai);
}
/* add server */
@@ -1783,13 +1802,8 @@ static int _dns_client_create_socket_udp_proxy(struct dns_server_info *server_in
ret = proxy_conn_connect(proxy);
if (ret != 0) {
if (errno == ENETUNREACH || errno == EHOSTUNREACH || errno == EPERM || errno == EACCES) {
tlog(TLOG_DEBUG, "connect %s failed, %s", server_info->ip, strerror(errno));
goto errout;
}
if (errno != EINPROGRESS) {
tlog(TLOG_ERROR, "connect %s failed, %s", server_info->ip, strerror(errno));
tlog(TLOG_DEBUG, "connect %s failed, %s", server_info->ip, strerror(errno));
goto errout;
}
}
@@ -1843,14 +1857,8 @@ static int _dns_client_create_socket_udp(struct dns_server_info *server_info)
server_info->status = DNS_SERVER_STATUS_CONNECTIONLESS;
if (connect(fd, &server_info->addr, server_info->ai_addrlen) != 0) {
if (errno == ENETUNREACH || errno == EHOSTUNREACH || errno == ECONNREFUSED || errno == EPERM ||
errno == EACCES) {
tlog(TLOG_INFO, "connect %s failed, %s", server_info->ip, strerror(errno));
goto errout;
}
if (errno != EINPROGRESS) {
tlog(TLOG_ERROR, "connect %s failed, %s", server_info->ip, strerror(errno));
tlog(TLOG_DEBUG, "connect %s failed, %s", server_info->ip, strerror(errno));
goto errout;
}
}
@@ -1950,13 +1958,8 @@ static int _DNS_client_create_socket_tcp(struct dns_server_info *server_info)
}
if (ret != 0) {
if (errno == ENETUNREACH || errno == EHOSTUNREACH || errno == ECONNREFUSED) {
tlog(TLOG_DEBUG, "connect %s failed, %s", server_info->ip, strerror(errno));
goto errout;
}
if (errno != EINPROGRESS) {
tlog(TLOG_ERROR, "connect %s failed, %s", server_info->ip, strerror(errno));
tlog(TLOG_DEBUG, "connect %s failed, %s", server_info->ip, strerror(errno));
goto errout;
}
}
@@ -2064,13 +2067,8 @@ static int _DNS_client_create_socket_tls(struct dns_server_info *server_info, ch
}
if (ret != 0) {
if (errno == ENETUNREACH || errno == EHOSTUNREACH || errno == ECONNREFUSED) {
tlog(TLOG_DEBUG, "connect %s failed, %s", server_info->ip, strerror(errno));
goto errout;
}
if (errno != EINPROGRESS) {
tlog(TLOG_ERROR, "connect %s failed, %s", server_info->ip, strerror(errno));
tlog(TLOG_DEBUG, "connect %s failed, %s", server_info->ip, strerror(errno));
goto errout;
}
}
@@ -2239,13 +2237,14 @@ static int _dns_client_process_udp_proxy(struct dns_server_info *server_info, st
}
int latency = get_tick_count() - server_info->send_tick;
tlog(TLOG_DEBUG, "recv udp packet from %s, len: %d, latency: %d",
get_host_by_addr(from_host, sizeof(from_host), (struct sockaddr *)&from), len, latency);
if (latency < server_info->drop_packet_latency_ms) {
tlog(TLOG_DEBUG, "drop packet from %s, latency: %d", from_host, latency);
return 0;
}
tlog(TLOG_DEBUG, "recv udp packet from %s, len: %d",
get_host_by_addr(from_host, sizeof(from_host), (struct sockaddr *)&from), len);
/* update recv time */
time(&server_info->last_recv);
@@ -2322,14 +2321,15 @@ static int _dns_client_process_udp(struct dns_server_info *server_info, struct e
}
}
tlog(TLOG_DEBUG, "recv udp packet from %s, len: %d, ttl: %d",
get_host_by_addr(from_host, sizeof(from_host), (struct sockaddr *)&from), len, ttl);
int latency = get_tick_count() - server_info->send_tick;
tlog(TLOG_DEBUG, "recv udp packet from %s, len: %d, ttl: %d, latency: %d",
get_host_by_addr(from_host, sizeof(from_host), (struct sockaddr *)&from), len, ttl, latency);
/* update recv time */
time(&server_info->last_recv);
int latency = get_tick_count() - server_info->send_tick;
if (latency < server_info->drop_packet_latency_ms) {
tlog(TLOG_DEBUG, "drop packet from %s, latency: %d", from_host, latency);
return 0;
}
@@ -2818,7 +2818,7 @@ static int _dns_client_verify_common_name(struct dns_server_info *server_info, X
tlog(TLOG_DEBUG, "peer SAN: %s", dns->data);
if (_dns_client_tls_matchName(tls_host_verify, (char *)dns->data, dns->length) == 0) {
tlog(TLOG_INFO, "peer SAN match: %s", dns->data);
tlog(TLOG_DEBUG, "peer SAN match: %s", dns->data);
return 0;
}
} break;
@@ -2969,9 +2969,7 @@ static int _dns_client_process_tls(struct dns_server_info *server_info, struct e
if (server_info->status == DNS_SERVER_STATUS_CONNECTING) {
/* do SSL hand shake */
ret = _ssl_do_handshake(server_info);
if (ret == 0) {
goto errout;
} else if (ret < 0) {
if (ret <= 0) {
memset(&fd_event, 0, sizeof(fd_event));
ssl_ret = _ssl_get_error(server_info, ret);
if (ssl_ret == SSL_ERROR_WANT_READ) {
@@ -3370,21 +3368,22 @@ static int _dns_client_setup_server_packet(struct dns_server_info *server_info,
struct dns_packet *packet = (struct dns_packet *)packet_buff;
struct dns_head head;
int encode_len = 0;
int repack = 0;
int hitchhiking = 0;
*packet_data = default_packet;
*packet_data_len = default_packet_len;
if (query->qtype != DNS_T_AAAA && query->qtype != DNS_T_A) {
/* no need to encode packet */
return 0;
if (server_info->ecs_ipv4.enable == true || server_info->ecs_ipv6.enable == true) {
repack = 1;
}
if (server_info->ecs_ipv4.enable == false && query->qtype == DNS_T_A) {
/* no need to encode packet */
return 0;
if ((server_info->flags.server_flag & SERVER_FLAG_HITCHHIKING) != 0) {
hitchhiking = 1;
repack = 1;
}
if (server_info->ecs_ipv6.enable == false && query->qtype == DNS_T_AAAA) {
if (repack == 0) {
/* no need to encode packet */
return 0;
}
@@ -3404,6 +3403,11 @@ static int _dns_client_setup_server_packet(struct dns_server_info *server_info,
return -1;
}
if (hitchhiking != 0 && dns_add_domain(packet, "-", query->qtype, DNS_C_IN) != 0) {
tlog(TLOG_ERROR, "add domain to packet failed.");
return -1;
}
/* add question */
if (dns_add_domain(packet, query->domain, query->qtype, DNS_C_IN) != 0) {
tlog(TLOG_ERROR, "add domain to packet failed.");
@@ -3412,10 +3416,16 @@ static int _dns_client_setup_server_packet(struct dns_server_info *server_info,
dns_set_OPT_payload_size(packet, DNS_IN_PACKSIZE);
/* dns_add_OPT_TCP_KEEPALIVE(packet, 600); */
if (query->qtype == DNS_T_A && server_info->ecs_ipv4.enable) {
if ((query->qtype == DNS_T_A && server_info->ecs_ipv4.enable)) {
dns_add_OPT_ECS(packet, &server_info->ecs_ipv4.ecs);
} else if (query->qtype == DNS_T_AAAA && server_info->ecs_ipv6.enable) {
} else if ((query->qtype == DNS_T_AAAA && server_info->ecs_ipv6.enable)) {
dns_add_OPT_ECS(packet, &server_info->ecs_ipv6.ecs);
} else {
if (server_info->ecs_ipv6.enable) {
dns_add_OPT_ECS(packet, &server_info->ecs_ipv6.ecs);
} else if (server_info->ecs_ipv4.enable) {
dns_add_OPT_ECS(packet, &server_info->ecs_ipv4.ecs);
}
}
/* encode packet */
@@ -3496,6 +3506,8 @@ static int _dns_client_send_packet(struct dns_query_struct *query, void *packet,
atomic_inc(&query->dns_request_sent);
send_count++;
errno = 0;
server_info->send_tick = get_tick_count();
switch (server_info->type) {
case DNS_SERVER_UDP:
/* udp query */
@@ -3548,7 +3560,6 @@ static int _dns_client_send_packet(struct dns_query_struct *query, void *packet,
continue;
}
time(&server_info->last_send);
server_info->send_tick = get_tick_count();
}
pthread_mutex_unlock(&client.server_list_lock);
@@ -3651,20 +3662,20 @@ static int _dns_client_query_setup_default_ecs(struct dns_query_struct *query)
if (client.ecs_ipv4.enable) {
add_ipv4_ecs = 1;
} else if (client.ecs_ipv6.enable) {
add_ipv4_ecs = 1;
add_ipv6_ecs = 1;
}
}
if (add_ipv4_ecs) {
memcpy(&query->ecs, &client.ecs_ipv4, sizeof(query->ecs));
return 0;
}
if (add_ipv6_ecs) {
memcpy(&query->ecs, &client.ecs_ipv6, sizeof(query->ecs));
return 0;
}
if (add_ipv4_ecs) {
memcpy(&query->ecs, &client.ecs_ipv4, sizeof(query->ecs));
return 0;
}
return 0;
}
@@ -3889,31 +3900,40 @@ static void _dns_client_check_servers(void)
pthread_mutex_unlock(&client.server_list_lock);
}
static int _dns_client_pending_server_resolve(const char *domain, dns_rtcode_t rtcode, dns_type_t addr_type, char *ip,
unsigned int ping_time, void *user_ptr)
static int _dns_client_pending_server_resolve(const struct dns_result *result, void *user_ptr)
{
struct dns_server_pending *pending = user_ptr;
int ret = 0;
int has_soa = 0;
if (rtcode == DNS_RC_NXDOMAIN) {
pending->has_soa = 1;
if (result->rtcode == DNS_RC_NXDOMAIN || result->has_soa == 1 || result->rtcode == DNS_RC_REFUSED ||
(result->rtcode == DNS_RC_NOERROR && result->ip_num == 0)) {
has_soa = 1;
}
if (addr_type == DNS_T_A) {
if (result->addr_type == DNS_T_A) {
pending->ping_time_v4 = -1;
if (rtcode == DNS_RC_NOERROR) {
if (result->rtcode == DNS_RC_NOERROR && result->ip_num > 0) {
pending->has_v4 = 1;
pending->ping_time_v4 = ping_time;
pending->has_soa = 0;
safe_strncpy(pending->ipv4, ip, DNS_HOSTNAME_LEN);
pending->ping_time_v4 = result->ping_time;
pending->has_soa_v4 = 0;
safe_strncpy(pending->ipv4, result->ip, DNS_HOSTNAME_LEN);
} else if (has_soa) {
pending->has_v4 = 0;
pending->ping_time_v4 = -1;
pending->has_soa_v4 = 1;
}
} else if (addr_type == DNS_T_AAAA) {
} else if (result->addr_type == DNS_T_AAAA) {
pending->ping_time_v6 = -1;
if (rtcode == DNS_RC_NOERROR) {
if (result->rtcode == DNS_RC_NOERROR && result->ip_num > 0) {
pending->has_v6 = 1;
pending->ping_time_v6 = ping_time;
pending->has_soa = 0;
safe_strncpy(pending->ipv6, ip, DNS_HOSTNAME_LEN);
pending->ping_time_v6 = result->ping_time;
pending->has_soa_v6 = 0;
safe_strncpy(pending->ipv6, result->ip, DNS_HOSTNAME_LEN);
} else if (has_soa) {
pending->has_v6 = 0;
pending->ping_time_v6 = -1;
pending->has_soa_v6 = 1;
}
} else {
ret = -1;
@@ -4007,10 +4027,29 @@ static void _dns_client_add_pending_servers(void)
int add_success = 0;
char *dnsserver_ip = NULL;
/* if has no bootstrap DNS, just call getaddrinfo to get address */
if (dns_client_has_bootstrap_dns == 0) {
list_del_init(&pending->retry_list);
_dns_client_server_pending_release(pending);
pending->retry_cnt++;
if (_dns_client_add_pendings(pending, pending->host) != 0) {
pthread_mutex_unlock(&pending_server_mutex);
tlog(TLOG_INFO, "add pending DNS server %s from resolv.conf failed, retry %d...", pending->host,
pending->retry_cnt - 1);
if (pending->retry_cnt - 1 > DNS_PENDING_SERVER_RETRY) {
tlog(TLOG_WARN, "add pending DNS server %s from resolv.conf failed, exit...", pending->host);
exit(1);
}
continue;
}
_dns_client_server_pending_release(pending);
continue;
}
if (pending->query_v4 == 0) {
pending->query_v4 = 1;
_dns_client_server_pending_get(pending);
if (dns_server_query(pending->host, DNS_T_A, 0, _dns_client_pending_server_resolve, pending) != 0) {
if (dns_server_query(pending->host, DNS_T_A, NULL, _dns_client_pending_server_resolve, pending) != 0) {
_dns_client_server_pending_release(pending);
pending->query_v4 = 0;
}
@@ -4019,9 +4058,9 @@ static void _dns_client_add_pending_servers(void)
if (pending->query_v6 == 0) {
pending->query_v6 = 1;
_dns_client_server_pending_get(pending);
if (dns_server_query(pending->host, DNS_T_AAAA, 0, _dns_client_pending_server_resolve, pending) != 0) {
if (dns_server_query(pending->host, DNS_T_AAAA, NULL, _dns_client_pending_server_resolve, pending) != 0) {
_dns_client_server_pending_release(pending);
pending->query_v4 = 0;
pending->query_v6 = 0;
}
}
@@ -4052,7 +4091,7 @@ static void _dns_client_add_pending_servers(void)
continue;
}
if (pending->has_soa && dnsserver_ip == NULL) {
if (dnsserver_ip == NULL && pending->has_soa_v4 && pending->has_soa_v6) {
tlog(TLOG_WARN, "add pending DNS server %s failed, no such host.", pending->host);
_dns_client_server_pending_remove(pending);
continue;
@@ -4068,22 +4107,6 @@ static void _dns_client_add_pending_servers(void)
pending->query_v4 = 0;
pending->query_v6 = 0;
}
/* if has no bootstrap DNS, just call getaddrinfo to get address */
if (dns_client_has_bootstrap_dns == 0) {
if (_dns_client_add_pendings(pending, pending->host) != 0) {
pthread_mutex_unlock(&pending_server_mutex);
tlog(TLOG_INFO, "add pending DNS server %s from resolv.conf failed, retry %d...", pending->host,
pending->retry_cnt - 1);
if (pending->retry_cnt - 1 > DNS_PENDING_SERVER_RETRY) {
tlog(TLOG_WARN, "add pending DNS server %s from resolv.conf failed, exit...", pending->host);
exit(1);
}
continue;
}
_dns_client_server_pending_release(pending);
}
}
}
@@ -4117,14 +4140,19 @@ static void _dns_client_period_run(unsigned int msec)
{
/* free timed out query, and notify caller */
list_del_init(&query->period_list);
_dns_client_check_udp_nat(query);
/* check udp nat after retrying. */
if (atomic_read(&query->retry_count) == 1) {
_dns_client_check_udp_nat(query);
}
if (atomic_dec_and_test(&query->retry_count) || (query->has_result != 0)) {
_dns_client_query_remove(query);
if (query->has_result == 0) {
tlog(TLOG_INFO, "retry query %s, type: %d, id: %d failed", query->domain, query->qtype, query->sid);
tlog(TLOG_DEBUG, "retry query %s, type: %d, id: %d failed", query->domain, query->qtype, query->sid);
}
} else {
tlog(TLOG_INFO, "retry query %s, type: %d, id: %d", query->domain, query->qtype, query->sid);
tlog(TLOG_DEBUG, "retry query %s, type: %d, id: %d", query->domain, query->qtype, query->sid);
_dns_client_send_query(query);
}
_dns_client_query_release(query);

File diff suppressed because it is too large Load Diff

View File

@@ -57,6 +57,9 @@ extern "C" {
#define DEFAULT_DNS_TLS_PORT 853
#define DEFAULT_DNS_HTTPS_PORT 443
#define DNS_MAX_CONF_CNAME_LEN 256
#define MAX_QTYPE_NUM 65535
#define DNS_MAX_REPLY_IP_NUM 8
#define SMARTDNS_CONF_FILE "/etc/smartdns/smartdns.conf"
#define SMARTDNS_LOG_FILE "/var/log/smartdns/smartdns.log"
#define SMARTDNS_AUDIT_FILE "/var/log/smartdns/smartdns-audit.log"
@@ -82,6 +85,12 @@ enum domain_rule {
DOMAIN_RULE_MAX,
};
enum ip_rule {
IP_RULE_FLAGS = 0,
IP_RULE_ALIAS = 1,
IP_RULE_MAX,
};
typedef enum {
DNS_BIND_TYPE_UDP,
DNS_BIND_TYPE_TCP,
@@ -114,8 +123,15 @@ typedef enum {
#define DOMAIN_FLAG_NO_SERVE_EXPIRED (1 << 15)
#define DOMAIN_FLAG_CNAME_IGN (1 << 16)
#define DOMAIN_FLAG_NO_CACHE (1 << 17)
#define DOMAIN_FLAG_NO_IPALIAS (1 << 18)
#define IP_RULE_FLAG_BLACKLIST (1 << 0)
#define IP_RULE_FLAG_WHITELIST (1 << 1)
#define IP_RULE_FLAG_BOGUS (1 << 2)
#define IP_RULE_FLAG_IP_IGNORE (1 << 3)
#define SERVER_FLAG_EXCLUDE_DEFAULT (1 << 0)
#define SERVER_FLAG_HITCHHIKING (1 << 1)
#define BIND_FLAG_NO_RULE_ADDR (1 << 0)
#define BIND_FLAG_NO_RULE_NAMESERVER (1 << 1)
@@ -128,6 +144,7 @@ typedef enum {
#define BIND_FLAG_FORCE_AAAA_SOA (1 << 8)
#define BIND_FLAG_NO_RULE_CNAME (1 << 9)
#define BIND_FLAG_NO_RULE_NFTSET (1 << 10)
#define BIND_FLAG_NO_IP_ALIAS (1 << 11)
enum response_mode_type {
DNS_RESPONSE_MODE_FIRST_PING_IP = 0,
@@ -148,12 +165,14 @@ struct dns_rule_flags {
struct dns_rule_address_IPV4 {
struct dns_rule head;
unsigned char ipv4_addr[DNS_RR_A_LEN];
char addr_num;
unsigned char ipv4_addr[][DNS_RR_A_LEN];
};
struct dns_rule_address_IPV6 {
struct dns_rule head;
unsigned char ipv6_addr[DNS_RR_AAAA_LEN];
char addr_num;
unsigned char ipv6_addr[][DNS_RR_AAAA_LEN];
};
struct dns_ipset_name {
@@ -212,8 +231,9 @@ extern struct dns_nftset_names dns_conf_nftset_no_speed;
struct dns_domain_rule {
struct dns_rule head;
unsigned char sub_rule_only : 1;
unsigned char root_rule_only : 1;
struct dns_rule *rules[DOMAIN_RULE_MAX];
int is_sub_rule[DOMAIN_RULE_MAX];
};
struct dns_nameserver_rule {
@@ -343,18 +363,18 @@ struct dns_bogus_ip_address {
};
};
enum address_rule {
ADDRESS_RULE_BLACKLIST = 1,
ADDRESS_RULE_WHITELIST = 2,
ADDRESS_RULE_BOGUS = 3,
ADDRESS_RULE_IP_IGNORE = 4,
struct dns_iplist_ip_address {
int addr_len;
union {
unsigned char ipv4_addr[DNS_RR_A_LEN];
unsigned char ipv6_addr[DNS_RR_AAAA_LEN];
unsigned char addr[0];
};
};
struct dns_ip_address_rule {
unsigned int blacklist : 1;
unsigned int whitelist : 1;
unsigned int bogus : 1;
unsigned int ip_ignore : 1;
struct dns_iplist_ip_addresses {
int ipaddr_num;
struct dns_iplist_ip_address *ipaddr;
};
struct dns_conf_address_rule {
@@ -381,15 +401,7 @@ struct dns_bind_ip {
struct nftset_ipset_rules nftset_ipset_rule;
};
struct dns_qtype_soa_list {
struct hlist_node node;
uint32_t qtypeid;
};
struct dns_qtype_soa_table {
DECLARE_HASHTABLE(qtype, 8);
};
extern struct dns_qtype_soa_table dns_qtype_soa_table;
extern uint8_t *dns_qtype_soa_table;
struct dns_domain_set_rule {
struct list_head list;
@@ -420,8 +432,48 @@ struct dns_domain_set_name_table {
};
extern struct dns_domain_set_name_table dns_domain_set_name_table;
struct dns_ip_rule {
atomic_t refcnt;
enum ip_rule rule;
};
enum dns_ip_set_type {
DNS_IP_SET_LIST = 0,
};
struct dns_ip_rules {
struct dns_ip_rule *rules[IP_RULE_MAX];
};
struct ip_rule_flags {
struct dns_ip_rule head;
unsigned int flags;
unsigned int is_flag_set;
};
struct ip_rule_alias {
struct dns_ip_rule head;
struct dns_iplist_ip_addresses ip_alias;
};
struct dns_ip_set_name {
struct list_head list;
enum dns_ip_set_type type;
char file[DNS_MAX_PATH];
};
struct dns_ip_set_name_list {
struct hlist_node node;
char name[DNS_MAX_CNAME_LEN];
struct list_head set_name_list;
};
struct dns_ip_set_name_table {
DECLARE_HASHTABLE(names, 4);
};
extern struct dns_ip_set_name_table dns_ip_set_name_table;
struct dns_set_rule_add_callback_args {
enum domain_rule type;
int type;
void *rule;
};
@@ -520,6 +572,9 @@ extern int dns_save_fail_packet;
extern char dns_save_fail_packet_dir[DNS_MAX_PATH];
extern char dns_resolv_file[DNS_MAX_PATH];
extern int dns_no_pidfile;
extern int dns_no_daemon;
void dns_server_load_exit(void);
int dns_server_load_conf(const char *file);

File diff suppressed because it is too large Load Diff

View File

@@ -49,9 +49,21 @@ void dns_server_stop(void);
void dns_server_exit(void);
#define MAX_IP_NUM 16
struct dns_result {
const char *domain;
dns_rtcode_t rtcode;
dns_type_t addr_type;
const char *ip;
const unsigned char *ip_addr[MAX_IP_NUM];
int ip_num;
int has_soa;
unsigned int ping_time;
};
/* query result notify function */
typedef int (*dns_result_callback)(const char *domain, dns_rtcode_t rtcode, dns_type_t addr_type, char *ip,
unsigned int ping_time, void *user_ptr);
typedef int (*dns_result_callback)(const struct dns_result *result, void *user_ptr);
/* query domain */
int dns_server_query(const char *domain, int qtype, struct dns_server_query_option *server_query_option,

View File

@@ -726,7 +726,7 @@ static int _fast_ping_sendping_v6(struct ping_host_struct *ping_host)
len = sendto(ping.fd_icmp6, &ping_host->packet, sizeof(struct fast_ping_packet), 0, &ping_host->addr,
ping_host->addr_len);
if (len < 0 || len != sizeof(struct fast_ping_packet)) {
if (len != sizeof(struct fast_ping_packet)) {
int err = errno;
if (errno == ENETUNREACH || errno == EINVAL || errno == EADDRNOTAVAIL || errno == EHOSTUNREACH) {
goto errout;
@@ -806,7 +806,7 @@ static int _fast_ping_sendping_v4(struct ping_host_struct *ping_host)
icmp->icmp_cksum = _fast_ping_checksum((void *)packet, sizeof(struct fast_ping_packet));
len = sendto(ping.fd_icmp, packet, sizeof(struct fast_ping_packet), 0, &ping_host->addr, ping_host->addr_len);
if (len < 0 || len != sizeof(struct fast_ping_packet)) {
if (len != sizeof(struct fast_ping_packet)) {
int err = errno;
if (errno == ENETUNREACH || errno == EINVAL || errno == EADDRNOTAVAIL || errno == EPERM || errno == EACCES) {
goto errout;
@@ -858,7 +858,7 @@ static int _fast_ping_sendping_udp(struct ping_host_struct *ping_host)
gettimeofday(&ping_host->last, NULL);
len = sendto(fd, &dns_head, sizeof(dns_head), 0, &ping_host->addr, ping_host->addr_len);
if (len < 0 || len != sizeof(dns_head)) {
if (len != sizeof(dns_head)) {
int err = errno;
if (errno == ENETUNREACH || errno == EINVAL || errno == EADDRNOTAVAIL || errno == EPERM || errno == EACCES) {
goto errout;

View File

@@ -189,6 +189,8 @@ int load_conf(const char *file, struct config_item items[], conf_error_handler h
void load_exit(void);
int conf_get_current_lineno(void);
const char *conf_get_conf_file(void);
const char *conf_get_conf_fullpath(const char *path, char *fullpath, size_t path_len);

View File

@@ -29,6 +29,11 @@
#define __must_be_array(a) BUILD_BUG_ON_ZERO(__same_type((a), &(a)[0]))
#define ARRAY_SIZE(arr) (sizeof(arr) / sizeof((arr)[0]) + __must_be_array(arr))
struct hash_table {
struct hlist_head *table;
unsigned int size;
};
#define DEFINE_HASHTABLE(name, bits) \
struct hlist_head name[1 << (bits)] = \
{ [0 ... ((1 << (bits)) - 1)] = HLIST_HEAD_INIT }
@@ -38,6 +43,8 @@
#define HASH_SIZE(name) (ARRAY_SIZE(name))
#define HASH_BITS(name) ilog2(HASH_SIZE(name))
#define HASH_TABLE_SIZE(name) (1 << ((name).size))
#define HASH_TABLE_BITS(name) ((name).size)
/* Use hash_32 when possible to allow for fast 32bit hashing in 64bit kernels. */
#define hash_min(val, bits) \
@@ -63,6 +70,16 @@ static inline void __hash_init(struct hlist_head *ht, unsigned int sz)
*/
#define hash_init(hashtable) __hash_init(hashtable, HASH_SIZE(hashtable))
#define hash_table_init(hashtable, bits, malloc_func) \
(hashtable).size = bits; \
(hashtable).table = malloc_func(sizeof(struct hlist_head) * HASH_TABLE_SIZE((hashtable))); \
__hash_init((hashtable).table, HASH_TABLE_SIZE((hashtable)))
#define hash_table_free(hashtable, free_func) \
free_func((hashtable).table); \
(hashtable).table = NULL; \
(hashtable).size = 0;
/**
* hash_add - add an object to a hashtable
* @hashtable: hashtable to add to
@@ -72,6 +89,9 @@ static inline void __hash_init(struct hlist_head *ht, unsigned int sz)
#define hash_add(hashtable, node, key) \
hlist_add_head(node, &hashtable[hash_min(key, HASH_BITS(hashtable))])
#define hash_table_add(hashtable, node, key) \
hlist_add_head(node, &(hashtable).table[hash_min(key, HASH_TABLE_BITS(hashtable))])
/**
* hash_hashed - check whether an object is in any hashtable
* @node: the &struct hlist_node of the object to be checked
@@ -101,6 +121,8 @@ static inline bool __hash_empty(struct hlist_head *ht, unsigned int sz)
*/
#define hash_empty(hashtable) __hash_empty(hashtable, HASH_SIZE(hashtable))
#define hash_table_empty(hashtable) __hash_empty((hashtable).table, HASH_TABLE_SIZE(hashtable))
/**
* hash_del - remove an object from a hashtable
* @node: &struct hlist_node of the object to remove
@@ -122,6 +144,11 @@ static inline void hash_del(struct hlist_node *node)
(bkt)++)\
hlist_for_each_entry(obj, &name[bkt], member)
#define hash_table_for_each(name, bkt, obj, member) \
for ((bkt) = 0, obj = NULL; obj == NULL && (bkt) < (HASH_TABLE_SIZE(name));\
(bkt)++)\
hlist_for_each_entry(obj, &((name).table)[bkt], member)
/**
* hash_for_each_safe - iterate over a hashtable safe against removal of
* hash entry
@@ -136,6 +163,11 @@ static inline void hash_del(struct hlist_node *node)
(bkt)++)\
hlist_for_each_entry_safe(obj, tmp, &name[bkt], member)
#define hash_table_for_each_safe(name, bkt, tmp, obj, member) \
for ((bkt) = 0, obj = NULL; obj == NULL && (bkt) < (HASH_TABLE_SIZE(name));\
(bkt)++)\
hlist_for_each_entry_safe(obj, tmp, &((name).table)[bkt], member)
/**
* hash_for_each_possible - iterate over all possible objects hashing to the
* same bucket
@@ -147,6 +179,9 @@ static inline void hash_del(struct hlist_node *node)
#define hash_for_each_possible(name, obj, member, key) \
hlist_for_each_entry(obj, &name[hash_min(key, HASH_BITS(name))], member)
#define hash_table_for_each_possible(name, obj, member, key) \
hlist_for_each_entry(obj, &((name).table)[hash_min(key, HASH_TABLE_BITS(name))], member)
/**
* hash_for_each_possible_safe - iterate over all possible objects hashing to the
* same bucket safe against removals
@@ -160,4 +195,8 @@ static inline void hash_del(struct hlist_node *node)
hlist_for_each_entry_safe(obj, tmp,\
&name[hash_min(key, HASH_BITS(name))], member)
#define hash_table_for_each_possible_safe(name, obj, tmp, member, key) \
hlist_for_each_entry_safe(obj, tmp,\
&((name).table)[hash_min(key, HASH_TABLE_BITS(name))], member)
#endif

50
src/include/timer_wheel.h Normal file
View File

@@ -0,0 +1,50 @@
/*************************************************************************
*
* Copyright (C) 2018-2023 Ruilin Peng (Nick) <pymumu@gmail.com>.
*
* smartdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* smartdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef __TIMER_WHEEL_H
#define __TIMER_WHEEL_H
#include "list.h"
struct tw_base;
struct tw_timer_list;
typedef void (*tw_func)(struct tw_timer_list *, void *, unsigned long);
typedef void (*tw_del_func)(struct tw_timer_list *, void *);
struct tw_timer_list {
void *data;
unsigned long expires;
tw_func function;
tw_del_func del_function;
struct list_head entry;
};
struct tw_base *tw_init_timers(void);
int tw_cleanup_timers(struct tw_base *);
void tw_add_timer(struct tw_base *, struct tw_timer_list *);
int tw_del_timer(struct tw_base *, struct tw_timer_list *);
int tw_mod_timer_pending(struct tw_base *, struct tw_timer_list *, unsigned long);
int tw_mod_timer(struct tw_base *, struct tw_timer_list *, unsigned long);
#endif

View File

@@ -26,12 +26,18 @@
#include <unistd.h>
static const char *current_conf_file = NULL;
static int current_conf_lineno = 0;
const char *conf_get_conf_file(void)
{
return current_conf_file;
}
int conf_get_current_lineno(void)
{
return current_conf_lineno;
}
static char *get_dir_name(char *path)
{
if (strstr(path, "/") == NULL) {
@@ -347,6 +353,7 @@ static int load_conf_file(const char *file, struct config_item *items, conf_erro
char value[MAX_LINE_LEN];
int filed_num = 0;
int i = 0;
int last_item_index = -1;
int argc = 0;
char *argv[1024];
int ret = 0;
@@ -354,6 +361,9 @@ static int load_conf_file(const char *file, struct config_item *items, conf_erro
int line_no = 0;
int line_len = 0;
int read_len = 0;
int is_last_line_wrap = 0;
int current_line_wrap = 0;
int is_func_found = 0;
const char *last_file = NULL;
if (handler == NULL) {
@@ -367,16 +377,46 @@ static int load_conf_file(const char *file, struct config_item *items, conf_erro
line_no = 0;
while (fgets(line + line_len, MAX_LINE_LEN - line_len, fp)) {
current_line_wrap = 0;
line_no++;
read_len = strnlen(line + line_len, sizeof(line));
if (read_len >= 2 && *(line + line_len + read_len - 2) == '\\') {
line_len += read_len - 2;
line[line_len] = '\0';
continue;
read_len -= 1;
current_line_wrap = 1;
}
line_len = 0;
filed_num = sscanf(line, "%63s %8192[^\r\n]s", key, value);
/* comment in wrap line, skip */
if (is_last_line_wrap && read_len > 0) {
if (*(line + line_len) == '#') {
continue;
}
}
/* trim prefix spaces in wrap line */
if ((current_line_wrap == 1 || is_last_line_wrap == 1) && read_len > 0) {
is_last_line_wrap = current_line_wrap;
read_len -= 1;
for (i = 0; i < read_len; i++) {
char *ptr = line + line_len + i;
if (*ptr == ' ' || *ptr == '\t') {
continue;
}
memmove(line + line_len, ptr, read_len - i + 1);
line_len += read_len - i;
break;
}
line[line_len] = '\0';
if (current_line_wrap) {
continue;
}
}
line_len = 0;
is_last_line_wrap = 0;
filed_num = sscanf(line, "%63s %8191[^\r\n]s", key, value);
if (filed_num <= 0) {
continue;
}
@@ -392,13 +432,23 @@ static int load_conf_file(const char *file, struct config_item *items, conf_erro
goto errout;
}
for (i = 0;; i++) {
is_func_found = 0;
for (i = last_item_index;; i++) {
if (i < 0) {
continue;
}
if (items[i].item == NULL) {
handler(file, line_no, CONF_RET_NOENT);
break;
}
if (strncmp(items[i].item, key, MAX_KEY_LEN) != 0) {
if (last_item_index >= 0) {
i = -1;
last_item_index = -1;
}
continue;
}
@@ -410,6 +460,7 @@ static int load_conf_file(const char *file, struct config_item *items, conf_erro
/* call item function */
last_file = current_conf_file;
current_conf_file = file;
current_conf_lineno = line_no;
call_ret = items[i].item_func(items[i].item, items[i].data, argc, argv);
ret = handler(file, line_no, call_ret);
if (ret != 0) {
@@ -422,8 +473,14 @@ static int load_conf_file(const char *file, struct config_item *items, conf_erro
current_conf_file = last_file;
}
last_item_index = i;
is_func_found = 1;
break;
}
if (is_func_found == 0) {
handler(file, line_no, CONF_RET_NOENT);
}
}
fclose(fp);

419
src/lib/timer_wheel.c Normal file
View File

@@ -0,0 +1,419 @@
/*************************************************************************
*
* Copyright (C) 2018-2023 Ruilin Peng (Nick) <pymumu@gmail.com>.
*
* smartdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* smartdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#include "bitops.h"
#include <pthread.h>
#include <stdio.h>
#include <stdlib.h>
#include <sys/select.h>
#include <sys/time.h>
#include <sys/types.h>
#include <unistd.h>
#include "timer_wheel.h"
#define TVR_BITS 10
#define TVN_BITS 6
#define TVR_SIZE (1 << TVR_BITS)
#define TVN_SIZE (1 << TVN_BITS)
#define TVR_MASK (TVR_SIZE - 1)
#define TVN_MASK (TVN_SIZE - 1)
#define INDEX(N) ((base->jiffies >> (TVR_BITS + N * TVN_BITS)) & TVN_MASK)
struct tvec {
struct list_head vec[TVN_SIZE];
};
struct tvec_root {
struct list_head vec[TVR_SIZE];
};
struct tw_base {
pthread_spinlock_t lock;
pthread_t runner;
unsigned long jiffies;
struct tvec_root tv1;
struct tvec tv2;
struct tvec tv3;
struct tvec tv4;
struct tvec tv5;
};
static inline void _tw_add_timer(struct tw_base *base, struct tw_timer_list *timer)
{
int i;
unsigned long idx;
unsigned long expires;
struct list_head *vec;
expires = timer->expires;
idx = expires - base->jiffies;
if (idx < TVR_SIZE) {
i = expires & TVR_MASK;
vec = base->tv1.vec + i;
} else if (idx < 1 << (TVR_BITS + TVN_BITS)) {
i = (expires >> TVR_BITS) & TVN_MASK;
vec = base->tv2.vec + i;
} else if (idx < 1 << (TVR_BITS + 2 * TVN_BITS)) {
i = (expires >> (TVR_BITS + TVN_BITS)) & TVN_MASK;
vec = base->tv3.vec + i;
} else if (idx < 1 << (TVR_BITS + 3 * TVN_BITS)) {
i = (expires >> (TVR_BITS + 2 * TVN_BITS)) & TVN_MASK;
vec = base->tv4.vec + i;
} else if ((long)idx < 0) {
vec = base->tv1.vec + (base->jiffies & TVR_MASK);
} else {
i = (expires >> (TVR_BITS + 3 * TVN_BITS)) & TVN_MASK;
vec = base->tv5.vec + i;
}
list_add_tail(&timer->entry, vec);
}
static inline unsigned long _apply_slack(struct tw_base *base, struct tw_timer_list *timer)
{
long delta;
unsigned long mask, expires, expires_limit;
expires = timer->expires;
delta = expires - base->jiffies;
if (delta < 256) {
return expires;
}
expires_limit = expires + delta / 256;
mask = expires ^ expires_limit;
if (mask == 0) {
return expires;
}
int bit = fls_long(mask);
mask = (1UL << bit) - 1;
expires_limit = expires_limit & ~(mask);
return expires_limit;
}
static inline void _tw_detach_timer(struct tw_timer_list *timer)
{
struct list_head *entry = &timer->entry;
list_del(entry);
entry->next = NULL;
}
static inline int _tw_cascade(struct tw_base *base, struct tvec *tv, int index)
{
struct tw_timer_list *timer, *tmp;
struct list_head tv_list;
list_replace_init(tv->vec + index, &tv_list);
list_for_each_entry_safe(timer, tmp, &tv_list, entry)
{
_tw_add_timer(base, timer);
}
return index;
}
static inline int timer_pending(struct tw_timer_list *timer)
{
struct list_head *entry = &timer->entry;
return (entry->next != NULL);
}
static inline int __detach_if_pending(struct tw_timer_list *timer)
{
if (!timer_pending(timer)) {
return 0;
}
_tw_detach_timer(timer);
return 1;
}
static inline int __mod_timer(struct tw_base *base, struct tw_timer_list *timer, int pending_only)
{
int ret = 0;
ret = __detach_if_pending(timer);
if (!ret && pending_only) {
goto done;
}
ret = 1;
_tw_add_timer(base, timer);
done:
return ret;
}
void tw_add_timer(struct tw_base *base, struct tw_timer_list *timer)
{
if (timer->function == NULL) {
return;
}
pthread_spin_lock(&base->lock);
{
timer->expires += base->jiffies;
timer->expires = _apply_slack(base, timer);
_tw_add_timer(base, timer);
}
pthread_spin_unlock(&base->lock);
}
int tw_del_timer(struct tw_base *base, struct tw_timer_list *timer)
{
int ret = 0;
pthread_spin_lock(&base->lock);
{
if (timer_pending(timer)) {
ret = 1;
_tw_detach_timer(timer);
}
}
pthread_spin_unlock(&base->lock);
if (ret == 1 && timer->del_function) {
timer->del_function(timer, timer->data);
}
return ret;
}
int tw_mod_timer_pending(struct tw_base *base, struct tw_timer_list *timer, unsigned long expires)
{
int ret = 1;
pthread_spin_lock(&base->lock);
{
timer->expires = expires + base->jiffies;
timer->expires = _apply_slack(base, timer);
ret = __mod_timer(base, timer, 1);
}
pthread_spin_unlock(&base->lock);
return ret;
}
int tw_mod_timer(struct tw_base *base, struct tw_timer_list *timer, unsigned long expires)
{
int ret = 1;
pthread_spin_lock(&base->lock);
{
if (timer_pending(timer) && timer->expires == expires) {
goto unblock;
}
timer->expires = expires + base->jiffies;
timer->expires = _apply_slack(base, timer);
ret = __mod_timer(base, timer, 0);
}
unblock:
pthread_spin_unlock(&base->lock);
return ret;
}
int tw_cleanup_timers(struct tw_base *base)
{
int ret = 0;
void *res = NULL;
ret = pthread_cancel(base->runner);
if (ret != 0) {
goto errout;
}
ret = pthread_join(base->runner, &res);
if (ret != 0) {
goto errout;
}
if (res != PTHREAD_CANCELED) {
goto errout;
}
ret = pthread_spin_destroy(&base->lock);
if (ret != 0) {
goto errout;
}
free(base);
return 0;
errout:
return -1;
}
static inline void run_timers(struct tw_base *base)
{
unsigned long index, call_time;
struct tw_timer_list *timer;
struct list_head work_list;
struct list_head *head = &work_list;
pthread_spin_lock(&base->lock);
{
index = base->jiffies & TVR_MASK;
if (!index && (!_tw_cascade(base, &base->tv2, INDEX(0))) && (!_tw_cascade(base, &base->tv3, INDEX(1))) &&
(!_tw_cascade(base, &base->tv4, INDEX(2))))
_tw_cascade(base, &base->tv5, INDEX(3));
call_time = base->jiffies++;
list_replace_init(base->tv1.vec + index, head);
while (!list_empty(head)) {
tw_func fn;
void *data;
timer = list_first_entry(head, struct tw_timer_list, entry);
fn = timer->function;
data = timer->data;
_tw_detach_timer(timer);
pthread_spin_unlock(&base->lock);
{
fn(timer, data, call_time);
}
pthread_spin_lock(&base->lock);
if ( (timer_pending(timer) == 0 && timer->del_function) ) {
pthread_spin_unlock(&base->lock);
timer->del_function(timer, timer->data);
pthread_spin_lock(&base->lock);
}
}
}
pthread_spin_unlock(&base->lock);
}
static unsigned long _tw_tick_count(void)
{
struct timespec ts;
clock_gettime(CLOCK_MONOTONIC, &ts);
return (ts.tv_sec * 1000 + ts.tv_nsec / 1000000);
}
static void *timer_work(void *arg)
{
struct tw_base *base = arg;
int sleep = 1000;
int sleep_time = 0;
unsigned long now = {0};
unsigned long last = {0};
unsigned long expect_time = 0;
sleep_time = sleep;
now = _tw_tick_count() - sleep;
last = now;
expect_time = now + sleep;
while (1) {
run_timers(base);
now = _tw_tick_count();
if (sleep_time > 0) {
sleep_time -= now - last;
if (sleep_time <= 0) {
sleep_time = 0;
}
int cnt = sleep_time / sleep;
expect_time -= cnt * sleep;
sleep_time -= cnt * sleep;
}
if (now >= expect_time) {
sleep_time = sleep - (now - expect_time);
if (sleep_time < 0) {
sleep_time = 0;
expect_time = now;
}
expect_time += sleep;
}
last = now;
usleep(sleep_time * 1000);
}
return NULL;
}
struct tw_base *tw_init_timers(void)
{
int j = 0;
int ret = 0;
struct timeval tv = {
0,
};
struct tw_base *base = NULL;
base = malloc(sizeof(*base));
if (!base) {
goto errout;
}
ret = pthread_spin_init(&base->lock, 0);
if (ret != 0) {
goto errout2;
}
for (j = 0; j < TVN_SIZE; j++) {
INIT_LIST_HEAD(base->tv5.vec + j);
INIT_LIST_HEAD(base->tv4.vec + j);
INIT_LIST_HEAD(base->tv3.vec + j);
INIT_LIST_HEAD(base->tv2.vec + j);
}
for (j = 0; j < TVR_SIZE; j++) {
INIT_LIST_HEAD(base->tv1.vec + j);
}
ret = gettimeofday(&tv, 0);
if (ret < 0) {
goto errout1;
}
base->jiffies = tv.tv_sec;
ret = pthread_create(&base->runner, NULL, timer_work, base);
if (ret != 0) {
goto errout1;
}
return base;
errout1:
(void)pthread_spin_destroy(&base->lock);
errout2:
free(base);
errout:
return NULL;
}

View File

@@ -96,7 +96,7 @@ struct proxy_struct {
static struct proxy_struct proxy;
const char *proxy_socks5_status_code[] = {
static const char *proxy_socks5_status_code[] = {
"success",
"general SOCKS server failure",
"connection not allowed by ruleset",
@@ -234,7 +234,7 @@ int proxy_remove(const char *proxy_name)
static void _proxy_remove_all(void)
{
struct proxy_server_info *server_info;
struct proxy_server_info *server_info = NULL;
struct hlist_node *tmp = NULL;
unsigned int i = 0;
@@ -957,7 +957,7 @@ int proxy_conn_recvfrom(struct proxy_conn *proxy_conn, void *buf, size_t len, in
return -1;
}
ret = recvfrom(proxy_conn->udp_fd, buffer, sizeof(buffer), MSG_NOSIGNAL, NULL, 0);
ret = recvfrom(proxy_conn->udp_fd, buffer, sizeof(buffer), MSG_NOSIGNAL, NULL, NULL);
if (ret <= 0) {
return -1;
}
@@ -1043,14 +1043,14 @@ int proxy_conn_is_udp(struct proxy_conn *proxy_conn)
return proxy_conn->is_udp;
}
int proxy_init()
int proxy_init(void)
{
memset(&proxy, 0, sizeof(proxy));
hash_init(proxy.proxy_server);
return 0;
}
int proxy_exit()
int proxy_exit(void)
{
_proxy_remove_all();
return 0;

View File

@@ -20,6 +20,7 @@
#include "smartdns.h"
#include "art.h"
#include "atomic.h"
#include "dns_cache.h"
#include "dns_client.h"
#include "dns_conf.h"
#include "dns_server.h"
@@ -29,8 +30,10 @@
#include "rbtree.h"
#include "tlog.h"
#include "util.h"
#include "timer.h"
#include <errno.h>
#include <fcntl.h>
#include <getopt.h>
#include <libgen.h>
#include <linux/capability.h>
#include <openssl/err.h>
@@ -261,7 +264,6 @@ static int _smartdns_prepare_server_flags(struct client_dns_server_flags *flags,
safe_strncpy(flag_tls->hostname, server->hostname, sizeof(flag_tls->hostname));
safe_strncpy(flag_tls->tls_host_verify, server->tls_host_verify, sizeof(flag_tls->tls_host_verify));
flag_tls->skip_check_cert = server->skip_check_cert;
} break;
case DNS_SERVER_TCP:
break;
@@ -459,14 +461,21 @@ static int _smartdns_init(void)
int ret = 0;
const char *logfile = _smartdns_log_path();
int i = 0;
char logdir[PATH_MAX] = {0};
int logbuffersize = 0;
ret = tlog_init(logfile, dns_conf_log_size, dns_conf_log_num, 0, 0);
if (get_system_mem_size() > 1024 * 1024 * 1024) {
logbuffersize = 1024 * 1024;
}
ret = tlog_init(logfile, dns_conf_log_size, dns_conf_log_num, logbuffersize, TLOG_NONBLOCK);
if (ret != 0) {
tlog(TLOG_ERROR, "start tlog failed.\n");
goto errout;
}
if (verbose_screen != 0 || dns_conf_log_console != 0) {
safe_strncpy(logdir, _smartdns_log_path(), PATH_MAX);
if (verbose_screen != 0 || dns_conf_log_console != 0 || access(dir_name(logdir), W_OK) != 0) {
tlog_setlogscreen(1);
}
@@ -478,6 +487,11 @@ static int _smartdns_init(void)
tlog(TLOG_NOTICE, "smartdns starting...(Copyright (C) Nick Peng <pymumu@gmail.com>, build: %s %s)", __DATE__,
__TIME__);
if (dns_timer_init() != 0) {
tlog(TLOG_ERROR, "init timer failed.");
goto errout;
}
if (_smartdns_init_ssl() != 0) {
tlog(TLOG_ERROR, "init ssl failed.");
goto errout;
@@ -565,6 +579,7 @@ static void _smartdns_exit(void)
fast_ping_exit();
dns_server_exit();
_smartdns_destroy_ssl();
dns_timer_destroy();
tlog_exit();
dns_server_load_exit();
}
@@ -736,14 +751,20 @@ static void smartdns_test_notify_func(int fd_notify, uint64_t retval)
}
}
int smartdns_main(int argc, char *argv[], int fd_notify)
#define smartdns_close_allfds() \
if (no_close_allfds == 0) { \
close_all_fd(fd_notify); \
}
int smartdns_main(int argc, char *argv[], int fd_notify, int no_close_allfds)
#else
#define smartdns_test_notify(retval)
#define smartdns_close_allfds() close_all_fd(-1)
int main(int argc, char *argv[])
#endif
{
int ret = 0;
int is_foreground = 0;
int is_run_as_daemon = 1;
int opt = 0;
char config_file[MAX_LINE_LEN];
char pid_file[MAX_LINE_LEN];
@@ -751,6 +772,9 @@ int main(int argc, char *argv[])
sigset_t empty_sigblock;
struct stat sb;
static struct option long_options[] = {
{"cache-print", required_argument, 0, 256}, {"help", no_argument, 0, 'h'}, {NULL, 0, 0, 0}};
safe_strncpy(config_file, SMARTDNS_CONF_FILE, MAX_LINE_LEN);
if (stat("/run", &sb) == 0 && S_ISDIR(sb.st_mode)) {
@@ -762,17 +786,22 @@ int main(int argc, char *argv[])
/* patch for Asus router: unblock all signal*/
sigemptyset(&empty_sigblock);
sigprocmask(SIG_SETMASK, &empty_sigblock, NULL);
smartdns_close_allfds();
while ((opt = getopt(argc, argv, "fhc:p:SvxN:")) != -1) {
while ((opt = getopt_long(argc, argv, "fhc:p:SvxN:", long_options, 0)) != -1) {
switch (opt) {
case 'f':
is_foreground = 1;
is_run_as_daemon = 0;
break;
case 'c':
snprintf(config_file, sizeof(config_file), "%s", optarg);
if (full_path(config_file, sizeof(config_file), optarg) != 0) {
snprintf(config_file, sizeof(config_file), "%s", optarg);
}
break;
case 'p':
snprintf(pid_file, sizeof(pid_file), "%s", optarg);
if (strncmp(optarg, "-", 2) == 0 || full_path(pid_file, sizeof(pid_file), optarg) != 0) {
snprintf(pid_file, sizeof(pid_file), "%s", optarg);
}
break;
case 'S':
signal_ignore = 1;
@@ -790,19 +819,37 @@ int main(int argc, char *argv[])
#endif
case 'h':
_help();
return 0;
case 256:
return dns_cache_print(optarg);
break;
default:
fprintf(stderr, "unknown option, please run %s -h for help.\n", argv[0]);
return 1;
}
}
if (dns_server_load_conf(config_file) != 0) {
ret = dns_server_load_conf(config_file);
if (ret != 0) {
fprintf(stderr, "load config failed.\n");
goto errout;
}
if (is_foreground == 0) {
if (daemon(0, 0) < 0) {
fprintf(stderr, "run daemon process failed, %s\n", strerror(errno));
return 1;
if (dns_no_daemon) {
is_run_as_daemon = 0;
}
if (is_run_as_daemon) {
int daemon_ret = daemon_run();
if (daemon_ret != -2) {
char buff[4096];
char *log_path = realpath(_smartdns_log_path(), buff);
if (log_path != NULL && access(log_path, F_OK) == 0 && daemon_ret != -3 && daemon_ret != 0) {
fprintf(stderr, "run daemon failed, please check log at %s\n", log_path);
}
return daemon_ret;
}
}
@@ -810,7 +857,8 @@ int main(int argc, char *argv[])
_reg_signal();
}
if (strncmp(pid_file, "-", 2) != 0 && create_pid_file(pid_file) != 0) {
if (strncmp(pid_file, "-", 2) != 0 && dns_no_pidfile == 0 && create_pid_file(pid_file) != 0) {
ret = -3;
goto errout;
}
@@ -818,9 +866,10 @@ int main(int argc, char *argv[])
signal(SIGINT, _sig_exit);
signal(SIGTERM, _sig_exit);
if (_smartdns_init_pre() != 0) {
ret = _smartdns_init_pre();
if (ret != 0) {
fprintf(stderr, "init failed.\n");
return 1;
goto errout;
}
drop_root_privilege();
@@ -831,11 +880,21 @@ int main(int argc, char *argv[])
goto errout;
}
if (is_run_as_daemon) {
ret = daemon_kickoff(0, dns_conf_log_console | verbose_screen);
if (ret != 0) {
goto errout;
}
}
smartdns_test_notify(1);
ret = _smartdns_run();
_smartdns_exit();
return ret;
errout:
if (is_run_as_daemon) {
daemon_kickoff(ret, dns_conf_log_console | verbose_screen);
}
smartdns_test_notify(2);
return 1;
}

View File

@@ -29,7 +29,7 @@ typedef void (*smartdns_post_func)(void *arg);
int smartdns_reg_post_func(smartdns_post_func func, void *arg);
int smartdns_main(int argc, char *argv[], int fd_notify);
int smartdns_main(int argc, char *argv[], int fd_notify, int no_close_allfds);
#endif

70
src/timer.c Normal file
View File

@@ -0,0 +1,70 @@
/*************************************************************************
*
* Copyright (C) 2018-2023 Ruilin Peng (Nick) <pymumu@gmail.com>.
*
* smartdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* smartdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#include "timer.h"
#include "timer_wheel.h"
static struct tw_base *dns_timer_base = NULL;
int dns_timer_init(void)
{
struct tw_base *tw = tw_init_timers();
if (tw == NULL) {
return -1;
}
dns_timer_base = tw;
return 0;
}
void dns_timer_destroy(void)
{
if (dns_timer_base != NULL) {
tw_cleanup_timers(dns_timer_base);
dns_timer_base = NULL;
}
}
void dns_timer_add(struct tw_timer_list *timer)
{
if (dns_timer_base == NULL) {
return;
}
tw_add_timer(dns_timer_base, timer);
}
int dns_timer_del(struct tw_timer_list *timer)
{
if (dns_timer_base == NULL) {
return 0;
}
return tw_del_timer(dns_timer_base, timer);
}
int dns_timer_mod(struct tw_timer_list *timer, unsigned long expires)
{
if (dns_timer_base == NULL) {
return 0;
}
return tw_mod_timer(dns_timer_base, timer, expires);
}

41
src/timer.h Normal file
View File

@@ -0,0 +1,41 @@
/*************************************************************************
*
* Copyright (C) 2018-2023 Ruilin Peng (Nick) <pymumu@gmail.com>.
*
* smartdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* smartdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef SMART_DNS_TIMER_H
#define SMART_DNS_TIMER_H
#include "timer_wheel.h"
#ifdef __cplusplus
extern "C" {
#endif /*__cplusplus */
int dns_timer_init(void);
void dns_timer_add(struct tw_timer_list *timer);
int dns_timer_del(struct tw_timer_list *timer);
int dns_timer_mod(struct tw_timer_list *timer, unsigned long expires);
void dns_timer_destroy(void);
#ifdef __cplusplus
}
#endif /*__cplusplus */
#endif

View File

@@ -187,7 +187,7 @@ static int _tlog_mkdir(const char *path)
return 0;
}
while (*path == ' ' && *path != '\0') {
while (*path == ' ') {
path++;
}
@@ -330,7 +330,7 @@ void tlog_logcount(struct tlog_log *log, int count)
log->logcount = count;
}
void tlog_set_permission(struct tlog_log *log, unsigned int file, unsigned int archive)
void tlog_set_permission(struct tlog_log *log, mode_t file, mode_t archive)
{
log->file_perm = file;
log->archive_perm = archive;
@@ -748,7 +748,7 @@ static int _tlog_list_dir(const char *path, list_callback callback, void *userpt
dir = opendir(path);
if (dir == NULL) {
fprintf(stderr, "open directory failed, %s\n", strerror(errno));
fprintf(stderr, "tlog: open directory failed, %s\n", strerror(errno));
goto errout;
}
@@ -859,7 +859,7 @@ static int _tlog_remove_oldlog(struct tlog_log *log)
/* get total log file number */
if (_tlog_list_dir(log->logdir, _tlog_count_log_callback, &count_log) != 0) {
fprintf(stderr, "get log file count failed.\n");
fprintf(stderr, "tlog: get log file count failed.\n");
return -1;
}
@@ -896,7 +896,7 @@ static int _tlog_log_lock(struct tlog_log *log)
snprintf(lock_file, sizeof(lock_file), "%s/%s.lock", log->logdir, log->logname);
fd = open(lock_file, O_RDWR | O_CREAT | O_CLOEXEC, S_IRUSR | S_IWUSR);
if (fd == -1) {
fprintf(stderr, "create pid file failed, %s", strerror(errno));
fprintf(stderr, "tlog: create lock file failed, %s", strerror(errno));
return -1;
}
@@ -1061,8 +1061,14 @@ static int _tlog_archive_log_compressed(struct tlog_log *log)
if (pid == 0) {
_tlog_close_all_fd();
execl(tlog.gzip_cmd, tlog.gzip_cmd, "-1", pending_file, NULL);
fprintf(stderr, "tlog: execl gzip failed, no compress\n");
log->nocompress = 1;
_exit(1);
} else if (pid < 0) {
if (errno == EPERM || errno == EACCES) {
fprintf(stderr, "tlog: vfork failed, errno: %d, no compress\n", errno);
log->nocompress = 1;
}
goto errout;
}
log->zip_pid = pid;
@@ -1195,9 +1201,9 @@ static int _tlog_write(struct tlog_log *log, const char *buff, int bufflen)
return -1;
}
log->print_errmsg = 0;
fprintf(stderr, "create log dir %s failed, %s\n", log->logdir, strerror(errno));
fprintf(stderr, "tlog: create log dir %s failed, %s\n", log->logdir, strerror(errno));
if (errno == EACCES && log->logscreen == 0) {
fprintf(stderr, "no permission to write log file, output log to console\n");
fprintf(stderr, "tlog: no permission to write log file, output log to console\n");
tlog_logscreen(log, 1);
tlog_logcount(log, 0);
}
@@ -1211,7 +1217,7 @@ static int _tlog_write(struct tlog_log *log, const char *buff, int bufflen)
return -1;
}
fprintf(stderr, "open log file %s failed, %s\n", logfile, strerror(errno));
fprintf(stderr, "tlog: open log file %s failed, %s\n", logfile, strerror(errno));
log->print_errmsg = 0;
return -1;
}
@@ -1752,13 +1758,13 @@ tlog_log *tlog_open(const char *logfile, int maxlogsize, int maxlogcount, int bu
struct tlog_log *log = NULL;
if (tlog.run == 0) {
fprintf(stderr, "tlog is not initialized.");
fprintf(stderr, "tlog: tlog is not initialized.\n");
return NULL;
}
log = (struct tlog_log *)malloc(sizeof(*log));
if (log == NULL) {
fprintf(stderr, "malloc log failed.");
fprintf(stderr, "tlog: malloc log failed.\n");
return NULL;
}
@@ -1800,7 +1806,7 @@ tlog_log *tlog_open(const char *logfile, int maxlogsize, int maxlogcount, int bu
log->buff = (char *)malloc(log->buffsize);
if (log->buff == NULL) {
fprintf(stderr, "malloc log buffer failed, %s\n", strerror(errno));
fprintf(stderr, "tlog: malloc log buffer failed, %s\n", strerror(errno));
goto errout;
}
@@ -1888,7 +1894,7 @@ static void tlog_fork_child(void)
pthread_attr_init(&attr);
int ret = pthread_create(&tlog.tid, &attr, _tlog_work, NULL);
if (ret != 0) {
fprintf(stderr, "create tlog work thread failed, %s\n", strerror(errno));
fprintf(stderr, "tlog: create tlog work thread failed, %s\n", strerror(errno));
goto errout;
}
@@ -1910,12 +1916,12 @@ int tlog_init(const char *logfile, int maxlogsize, int maxlogcount, int buffsize
struct tlog_log *log = NULL;
if (tlog_format != NULL) {
fprintf(stderr, "tlog already initialized.\n");
fprintf(stderr, "tlog: already initialized.\n");
return -1;
}
if (buffsize > 0 && buffsize < TLOG_MAX_LINE_SIZE_SET * 2) {
fprintf(stderr, "buffer size is invalid.\n");
fprintf(stderr, "tlog: buffer size is invalid.\n");
return -1;
}
@@ -1932,19 +1938,19 @@ int tlog_init(const char *logfile, int maxlogsize, int maxlogcount, int buffsize
log = tlog_open(logfile, maxlogsize, maxlogcount, buffsize, flag);
if (log == NULL) {
fprintf(stderr, "init tlog root failed.\n");
fprintf(stderr, "tlog: init tlog root failed.\n");
goto errout;
}
tlog_reg_output_func(log, _tlog_root_write_log);
if ((flag & TLOG_NOCOMPRESS) == 0 && tlog.gzip_cmd[0] == '\0') {
fprintf(stderr, "can not find gzip command, disable compress.\n");
fprintf(stderr, "tlog: can not find gzip command, disable compress.\n");
}
tlog.root = log;
ret = pthread_create(&tlog.tid, &attr, _tlog_work, NULL);
if (ret != 0) {
fprintf(stderr, "create tlog work thread failed, %s\n", strerror(errno));
fprintf(stderr, "tlog: create tlog work thread failed, %s\n", strerror(errno));
goto errout;
}

View File

@@ -1,6 +1,6 @@
/*
* tinylog
* Copyright (C) 2018-2021 Ruilin Peng (Nick) <pymumu@gmail.com>
* Copyright (C) 2018-2023 Ruilin Peng (Nick) <pymumu@gmail.com>
* https://github.com/pymumu/tinylog
*/
@@ -139,7 +139,7 @@ steps:
read _tlog_format for example.
*/
typedef int (*tlog_format_func)(char *buff, int maxlen, struct tlog_loginfo *info, void *userptr, const char *format, va_list ap);
extern int tlog_reg_format_func(tlog_format_func func);
extern int tlog_reg_format_func(tlog_format_func callback);
/* register log output callback
Note: info is invalid when flag TLOG_SEGMENT is not set.

View File

@@ -25,6 +25,7 @@
#include "util.h"
#include <arpa/inet.h>
#include <ctype.h>
#include <dirent.h>
#include <dlfcn.h>
#include <errno.h>
#include <fcntl.h>
@@ -38,11 +39,13 @@
#include <openssl/crypto.h>
#include <openssl/ssl.h>
#include <openssl/x509v3.h>
#include <poll.h>
#include <pthread.h>
#include <signal.h>
#include <stdlib.h>
#include <string.h>
#include <sys/prctl.h>
#include <sys/resource.h>
#include <sys/stat.h>
#include <sys/statvfs.h>
#include <sys/sysinfo.h>
@@ -100,8 +103,20 @@ struct ipset_netlink_msg {
__be16 res_id;
};
enum daemon_msg_type {
DAEMON_MSG_KICKOFF,
DAEMON_MSG_KEEPALIVE,
DAEMON_MSG_DAEMON_PID,
};
struct daemon_msg {
enum daemon_msg_type type;
int value;
};
static int ipset_fd;
static int pidfile_fd;
static int daemon_fd;
unsigned long get_tick_count(void)
{
@@ -153,6 +168,55 @@ errout:
return NULL;
}
int generate_random_addr(unsigned char *addr, int addr_len, int mask)
{
if (mask / 8 > addr_len) {
return -1;
}
int offset = mask / 8;
int bit = 0;
for (int i = offset; i < addr_len; i++) {
bit = 0xFF;
if (i == offset) {
bit = ~(0xFF << (8 - mask % 8)) & 0xFF;
}
addr[i] = jhash(&addr[i], 1, 0) & bit;
}
return 0;
}
int generate_addr_map(const unsigned char *addr_from, const unsigned char *addr_to, unsigned char *addr_out, int addr_len, int mask)
{
if ((mask / 8) >= addr_len) {
if (mask % 8 != 0) {
return -1;
}
}
int offset = mask / 8;
int bit = mask % 8;
for (int i = 0; i < offset; i++) {
addr_out[i] = addr_to[i];
}
if (bit != 0) {
int mask1 = 0xFF >> bit;
int mask2 = (0xFF << (8 - bit)) & 0xFF;
addr_out[offset] = addr_from[offset] & mask1;
addr_out[offset] |= addr_to[offset] & mask2;
offset = offset + 1;
}
for (int i = offset; i < addr_len; i++) {
addr_out[i] = addr_from[i];
}
return 0;
}
int getaddr_by_host(const char *host, struct sockaddr *addr, socklen_t *addr_len)
{
struct addrinfo hints;
@@ -806,7 +870,11 @@ int create_pid_file(const char *pid_file)
}
if (lockf(fd, F_TLOCK, 0) < 0) {
fprintf(stderr, "Server is already running.\n");
memset(buff, 0, TMP_BUFF_LEN_32);
if (read(fd, buff, TMP_BUFF_LEN_32) <= 0) {
buff[0] = '\0';
}
fprintf(stderr, "Server is already running, pid is %s", buff);
goto errout;
}
@@ -831,6 +899,27 @@ errout:
return -1;
}
int full_path(char *normalized_path, int normalized_path_len, const char *path)
{
const char *p = path;
if (path == NULL || normalized_path == NULL) {
return -1;
}
while (*p == ' ') {
p++;
}
if (*p == '\0' || *p == '/') {
return -1;
}
char buf[PATH_MAX];
snprintf(normalized_path, normalized_path_len, "%s/%s", getcwd(buf, sizeof(buf)), path);
return 0;
}
int generate_cert_key(const char *key_path, const char *cert_path, const char *san, int days)
{
int ret = -1;
@@ -1479,6 +1568,210 @@ out:
return ret;
}
static void _close_all_fd_by_res(void)
{
struct rlimit lim;
int maxfd = 0;
int i = 0;
getrlimit(RLIMIT_NOFILE, &lim);
maxfd = lim.rlim_cur;
if (maxfd > 4096) {
maxfd = 4096;
}
for (i = 3; i < maxfd; i++) {
close(i);
}
}
void close_all_fd(int keepfd)
{
DIR *dirp;
int dir_fd = -1;
struct dirent *dentp;
dirp = opendir("/proc/self/fd");
if (dirp == NULL) {
goto errout;
}
dir_fd = dirfd(dirp);
while ((dentp = readdir(dirp)) != NULL) {
int fd = atol(dentp->d_name);
if (fd < 0) {
continue;
}
if (fd == dir_fd || fd == STDIN_FILENO || fd == STDOUT_FILENO || fd == STDERR_FILENO || fd == keepfd) {
continue;
}
close(fd);
}
closedir(dirp);
return;
errout:
if (dirp) {
closedir(dirp);
}
_close_all_fd_by_res();
return;
}
int daemon_kickoff(int status, int no_close)
{
struct daemon_msg msg;
if (daemon_fd <= 0) {
return -1;
}
msg.type = DAEMON_MSG_KICKOFF;
msg.value = status;
int ret = write(daemon_fd, &msg, sizeof(msg));
if (ret != sizeof(msg)) {
fprintf(stderr, "notify parent process failed, %s\n", strerror(errno));
return -1;
}
if (no_close == 0) {
int fd_null = open("/dev/null", O_RDWR);
if (fd_null < 0) {
fprintf(stderr, "open /dev/null failed, %s\n", strerror(errno));
return -1;
}
dup2(fd_null, STDIN_FILENO);
dup2(fd_null, STDOUT_FILENO);
dup2(fd_null, STDERR_FILENO);
if (fd_null > 2) {
close(fd_null);
}
}
close(daemon_fd);
daemon_fd = -1;
return 0;
}
int daemon_keepalive(void)
{
struct daemon_msg msg;
static time_t last = 0;
time_t now = time(NULL);
if (daemon_fd <= 0) {
return -1;
}
if (now == last) {
return 0;
}
last = now;
msg.type = DAEMON_MSG_KEEPALIVE;
msg.value = 0;
int ret = write(daemon_fd, &msg, sizeof(msg));
if (ret != sizeof(msg)) {
return -1;
}
return 0;
}
int daemon_run(void)
{
pid_t pid = 0;
int fds[2] = {0};
if (pipe(fds) != 0) {
fprintf(stderr, "run daemon process failed, pipe failed, %s\n", strerror(errno));
return -1;
}
pid = fork();
if (pid < 0) {
fprintf(stderr, "run daemon process failed, fork failed, %s\n", strerror(errno));
close(fds[0]);
close(fds[1]);
return -1;
} else if (pid > 0) {
struct pollfd pfd;
int ret = 0;
close(fds[1]);
pfd.fd = fds[0];
pfd.events = POLLIN;
pfd.revents = 0;
do {
ret = poll(&pfd, 1, 3000);
if (ret <= 0) {
fprintf(stderr, "run daemon process failed, wait child timeout, kill child.\n");
goto errout;
}
if (!(pfd.revents & POLLIN)) {
goto errout;
}
struct daemon_msg msg;
ret = read(fds[0], &msg, sizeof(msg));
if (ret != sizeof(msg)) {
goto errout;
}
if (msg.type == DAEMON_MSG_KEEPALIVE) {
continue;
} else if (msg.type == DAEMON_MSG_DAEMON_PID) {
pid = msg.value;
continue;
} else if (msg.type == DAEMON_MSG_KICKOFF) {
return msg.value;
} else {
goto errout;
}
} while (true);
}
setsid();
pid = fork();
if (pid < 0) {
fprintf(stderr, "double fork failed, %s\n", strerror(errno));
_exit(1);
} else if (pid > 0) {
struct daemon_msg msg;
int unused __attribute__((unused));
msg.type = DAEMON_MSG_DAEMON_PID;
msg.value = pid;
unused = write(fds[1], &msg, sizeof(msg));
_exit(0);
}
umask(0);
if (chdir("/") != 0) {
goto errout;
}
close(fds[0]);
daemon_fd = fds[1];
return -2;
errout:
kill(pid, SIGKILL);
return -1;
}
#ifdef DEBUG
struct _dns_read_packet_info {
int data_len;
@@ -1604,7 +1897,7 @@ static int _dns_debug_display(struct dns_packet *packet)
int ret = 0;
ret = dns_get_HTTPS_svcparm_start(rrs, &p, name, DNS_MAX_CNAME_LEN, &ttl, &priority, target,
DNS_MAX_CNAME_LEN);
DNS_MAX_CNAME_LEN);
if (ret != 0) {
printf("get HTTPS svcparm failed\n");
break;

View File

@@ -59,6 +59,11 @@ char *dir_name(char *path);
char *get_host_by_addr(char *host, int maxsize, struct sockaddr *addr);
int generate_random_addr(unsigned char *addr, int addr_len, int mask);
int generate_addr_map(const unsigned char *addr_from, const unsigned char *addr_to, unsigned char *addr_out,
int addr_len, int mask);
int getaddr_by_host(const char *host, struct sockaddr *addr, socklen_t *addr_len);
int getsocket_inet(int fd, struct sockaddr *addr, socklen_t *addr_len);
@@ -105,6 +110,8 @@ int generate_cert_key(const char *key_path, const char *cert_path, const char *s
int create_pid_file(const char *pid_file);
int full_path(char *normalized_path, int normalized_path_len, const char *path);
/* Parse a TLS packet for the Server Name Indication extension in the client
* hello handshake, returning the first server name found (pointer to static
* array)
@@ -138,6 +145,14 @@ uint64_t get_free_space(const char *path);
void print_stack(void);
void close_all_fd(int keepfd);
int daemon_run(void);
int daemon_kickoff(int status, int no_close);
int daemon_keepalive(void);
int write_file(const char *filename, void *data, int data_len);
int dns_packet_save(const char *dir, const char *type, const char *from, const void *packet, int packet_len);

View File

@@ -1,6 +1,9 @@
[Unit]
Description=SmartDNS Server
After=network.target
After=network.target
Before=network-online.target
Before=nss-lookup.target
Wants=nss-lookup.target
StartLimitBurst=0
StartLimitIntervalSec=60
@@ -8,7 +11,7 @@ StartLimitIntervalSec=60
Type=forking
PIDFile=@RUNSTATEDIR@/smartdns.pid
EnvironmentFile=@SYSCONFDIR@/default/smartdns
ExecStart=@SBINDIR@/smartdns -p @RUNSTATEDIR@/smartdns.pid $SMART_DNS_OPTS
ExecStart=@SBINDIR@/smartdns -p @RUNSTATEDIR@/smartdns.pid $SMART_DNS_OPTS
Restart=always
RestartSec=2
TimeoutStopSec=15

View File

@@ -23,8 +23,8 @@ CXXFLAGS += -g
CXXFLAGS += -DTEST
CXXFLAGS += -I./ -I../src -I../src/include
SMARTDNS_OBJS = lib/rbtree.o lib/art.o lib/bitops.o lib/radix.o lib/conf.o lib/nftset.o
SMARTDNS_OBJS += smartdns.o fast_ping.o dns_client.o dns_server.o dns.o util.o tlog.o dns_conf.o dns_cache.o http_parse.o proxy.o
SMARTDNS_OBJS = lib/rbtree.o lib/art.o lib/bitops.o lib/radix.o lib/conf.o lib/nftset.o lib/timer_wheel.o
SMARTDNS_OBJS += smartdns.o fast_ping.o dns_client.o dns_server.o dns.o util.o tlog.o dns_conf.o dns_cache.o http_parse.o proxy.o timer.o
OBJS = $(addprefix ../src/, $(SMARTDNS_OBJS))
TEST_SOURCES := $(wildcard *.cc) $(wildcard */*.cc) $(wildcard */*/*.cc)

View File

@@ -161,3 +161,140 @@ cache-persist no)""");
EXPECT_EQ(client.GetAnswer()[0].GetType(), "AAAA");
EXPECT_EQ(client.GetAnswer()[0].GetData(), "64:ff9b::1010:1010");
}
TEST_F(Address, multiaddress)
{
smartdns::MockServer server_upstream;
smartdns::Server server;
server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) {
if (request->qtype == DNS_T_A) {
smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4", 700);
return smartdns::SERVER_REQUEST_OK;
} else if (request->qtype == DNS_T_AAAA) {
smartdns::MockServer::AddIP(request, request->domain.c_str(), "64:ff9b::102:304", 700);
return smartdns::SERVER_REQUEST_OK;
}
return smartdns::SERVER_REQUEST_SOA;
});
server.Start(R"""(bind [::]:60053
server 127.0.0.1:61053
log-num 0
log-console yes
log-level debug
speed-check-mode none
address /a.com/10.10.10.10,11.11.11.11,22.22.22.22
address /a.com/64:ff9b::1010:1010,64:ff9b::1111:1111,64:ff9b::2222:2222
cache-persist no)""");
smartdns::Client client;
ASSERT_TRUE(client.Query("a.com A", 60053));
std::cout << client.GetResult() << std::endl;
std::map<std::string, smartdns::DNSRecord *> result;
ASSERT_EQ(client.GetAnswerNum(), 3);
EXPECT_EQ(client.GetStatus(), "NOERROR");
auto answers = client.GetAnswer();
for (int i = 0; i < client.GetAnswerNum(); i++) {
result[client.GetAnswer()[i].GetData()] = &answers[i];
}
ASSERT_NE(result.find("10.10.10.10"), result.end());
auto check_result = result["10.10.10.10"];
EXPECT_EQ(check_result->GetName(), "a.com");
EXPECT_EQ(check_result->GetTTL(), 600);
EXPECT_EQ(check_result->GetType(), "A");
EXPECT_EQ(check_result->GetData(), "10.10.10.10");
ASSERT_NE(result.find("11.11.11.11"), result.end());
check_result = result["11.11.11.11"];
EXPECT_EQ(check_result->GetName(), "a.com");
EXPECT_EQ(check_result->GetTTL(), 600);
EXPECT_EQ(check_result->GetType(), "A");
EXPECT_EQ(check_result->GetData(), "11.11.11.11");
ASSERT_NE(result.find("22.22.22.22"), result.end());
check_result = result["22.22.22.22"];
EXPECT_EQ(check_result->GetName(), "a.com");
EXPECT_EQ(check_result->GetTTL(), 600);
EXPECT_EQ(check_result->GetType(), "A");
EXPECT_EQ(check_result->GetData(), "22.22.22.22");
result.clear();
ASSERT_TRUE(client.Query("a.com AAAA", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 3);
EXPECT_EQ(client.GetStatus(), "NOERROR");
answers = client.GetAnswer();
for (int i = 0; i < client.GetAnswerNum(); i++) {
result[client.GetAnswer()[i].GetData()] = &answers[i];
}
ASSERT_NE(result.find("64:ff9b::1010:1010"), result.end());
check_result = result["64:ff9b::1010:1010"];
EXPECT_EQ(check_result->GetName(), "a.com");
EXPECT_EQ(check_result->GetTTL(), 600);
EXPECT_EQ(check_result->GetType(), "AAAA");
EXPECT_EQ(check_result->GetData(), "64:ff9b::1010:1010");
ASSERT_NE(result.find("64:ff9b::1111:1111"), result.end());
check_result = result["64:ff9b::1111:1111"];
EXPECT_EQ(check_result->GetName(), "a.com");
EXPECT_EQ(check_result->GetTTL(), 600);
EXPECT_EQ(check_result->GetType(), "AAAA");
EXPECT_EQ(check_result->GetData(), "64:ff9b::1111:1111");
ASSERT_NE(result.find("64:ff9b::2222:2222"), result.end());
check_result = result["64:ff9b::2222:2222"];
EXPECT_EQ(check_result->GetName(), "a.com");
EXPECT_EQ(check_result->GetTTL(), 600);
EXPECT_EQ(check_result->GetType(), "AAAA");
EXPECT_EQ(check_result->GetData(), "64:ff9b::2222:2222");
}
TEST_F(Address, soa_sub_ip)
{
smartdns::MockServer server_upstream;
smartdns::Server server;
server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) {
if (request->qtype == DNS_T_A) {
smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4", 700);
return smartdns::SERVER_REQUEST_OK;
} else if (request->qtype == DNS_T_AAAA) {
smartdns::MockServer::AddIP(request, request->domain.c_str(), "64:ff9b::102:304", 700);
return smartdns::SERVER_REQUEST_OK;
}
return smartdns::SERVER_REQUEST_SOA;
});
server.Start(R"""(bind [::]:60053
server 127.0.0.1:61053
log-num 0
log-console yes
log-level debug
speed-check-mode none
address /a.com/192.168.1.1
address /com/#
cache-persist no)""");
smartdns::Client client;
ASSERT_TRUE(client.Query("a.com A", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 1);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 600);
EXPECT_EQ(client.GetAnswer()[0].GetType(), "A");
EXPECT_EQ(client.GetAnswer()[0].GetData(), "192.168.1.1");
ASSERT_TRUE(client.Query("a.com AAAA", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 0);
EXPECT_EQ(client.GetStatus(), "NXDOMAIN");
ASSERT_TRUE(client.Query("b.com A", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 0);
EXPECT_EQ(client.GetStatus(), "NXDOMAIN");
}

View File

@@ -127,6 +127,16 @@ cache-persist no)""");
EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 5);
EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
sleep(1);
ASSERT_TRUE(client.Query("a.com", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 1);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 5);
EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
}
TEST_F(Cache, max_reply_ttl_expired)

View File

@@ -77,7 +77,6 @@ TEST_F(Cname, subdomain1)
if (request->domain == "s.a.com") {
smartdns::MockServer::AddIP(request, request->domain.c_str(), "4.5.6.7", 700);
return smartdns::SERVER_REQUEST_OK;
}
smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4", 611);
@@ -114,7 +113,6 @@ TEST_F(Cname, subdomain2)
if (request->domain == "a.s.a.com") {
smartdns::MockServer::AddIP(request, request->domain.c_str(), "4.5.6.7", 700);
return smartdns::SERVER_REQUEST_OK;
}
smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4", 611);
@@ -139,7 +137,6 @@ cache-persist no)""");
EXPECT_EQ(client.GetAnswer()[1].GetData(), "4.5.6.7");
}
TEST_F(Cname, loop)
{
smartdns::MockServer server_upstream;
@@ -153,7 +150,6 @@ TEST_F(Cname, loop)
if (request->domain == "s.a.com") {
smartdns::MockServer::AddIP(request, request->domain.c_str(), "4.5.6.7", 700);
return smartdns::SERVER_REQUEST_OK;
}
smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4", 611);

124
test/cases/test-ddns.cc Normal file
View File

@@ -0,0 +1,124 @@
/*************************************************************************
*
* Copyright (C) 2018-2023 Ruilin Peng (Nick) <pymumu@gmail.com>.
*
* smartdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* smartdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#include "client.h"
#include "dns.h"
#include "include/utils.h"
#include "server.h"
#include "util.h"
#include "gtest/gtest.h"
#include <fstream>
class DDNS : public ::testing::Test
{
protected:
virtual void SetUp() {}
virtual void TearDown() {}
};
TEST_F(DDNS, smartdns)
{
smartdns::MockServer server_upstream;
smartdns::Server server;
server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) {
return smartdns::SERVER_REQUEST_SOA;
});
server.Start(R"""(bind [::]:60053
server 127.0.0.1:61053
log-num 0
log-console yes
dualstack-ip-selection no
log-level debug
cache-persist no)""");
smartdns::Client client;
ASSERT_TRUE(client.Query("smartdns A", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 1);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAnswer()[0].GetName(), "smartdns");
EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 600);
EXPECT_EQ(client.GetAnswer()[0].GetType(), "A");
EXPECT_EQ(client.GetAnswer()[0].GetData(), "127.0.0.1");
ASSERT_TRUE(client.Query("smartdns AAAA", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 1);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAnswer()[0].GetName(), "smartdns");
EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 600);
EXPECT_EQ(client.GetAnswer()[0].GetType(), "AAAA");
EXPECT_EQ(client.GetAnswer()[0].GetData(), "::1");
}
TEST_F(DDNS, ddns)
{
smartdns::MockServer server_upstream;
smartdns::Server server;
server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) {
return smartdns::SERVER_REQUEST_SOA;
});
server.Start(R"""(bind [::]:60053
server 127.0.0.1:61053
log-num 0
ddns-domain test.ddns.com
ddns-domain test.ddns.org
log-console yes
dualstack-ip-selection no
log-level debug
cache-persist no)""");
smartdns::Client client;
ASSERT_TRUE(client.Query("test.ddns.com A", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 1);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAnswer()[0].GetName(), "test.ddns.com");
EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 600);
EXPECT_EQ(client.GetAnswer()[0].GetType(), "A");
EXPECT_EQ(client.GetAnswer()[0].GetData(), "127.0.0.1");
ASSERT_TRUE(client.Query("test.ddns.com AAAA", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 1);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAnswer()[0].GetName(), "test.ddns.com");
EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 600);
EXPECT_EQ(client.GetAnswer()[0].GetType(), "AAAA");
EXPECT_EQ(client.GetAnswer()[0].GetData(), "::1");
ASSERT_TRUE(client.Query("test.ddns.org A", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 1);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAnswer()[0].GetName(), "test.ddns.org");
EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 600);
EXPECT_EQ(client.GetAnswer()[0].GetType(), "A");
EXPECT_EQ(client.GetAnswer()[0].GetData(), "127.0.0.1");
ASSERT_TRUE(client.Query("test.ddns.org AAAA", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 1);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAnswer()[0].GetName(), "test.ddns.org");
EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 600);
EXPECT_EQ(client.GetAnswer()[0].GetType(), "AAAA");
EXPECT_EQ(client.GetAnswer()[0].GetData(), "::1");
}

View File

@@ -39,17 +39,17 @@ TEST_F(DomainRule, bogus_nxdomain)
if (request->qtype != DNS_T_A) {
return smartdns::SERVER_REQUEST_SOA;
}
if (request->domain == "a.com") {
smartdns::MockServer::AddIP(request, request->domain.c_str(), "10.11.12.13", 611);
return smartdns::SERVER_REQUEST_OK;
}
if (request->domain == "a.com") {
smartdns::MockServer::AddIP(request, request->domain.c_str(), "10.11.12.13", 611);
return smartdns::SERVER_REQUEST_OK;
}
smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4", 611);
return smartdns::SERVER_REQUEST_OK;
});
/* this ip will be discard, but is reachable */
server.MockPing(PING_TYPE_ICMP, "1.2.3.4", 60, 10);
/* this ip will be discard, but is reachable */
server.MockPing(PING_TYPE_ICMP, "1.2.3.4", 60, 10);
server.Start(R"""(bind [::]:60053
server udp://127.0.0.1:61053 -blacklist-ip
@@ -64,7 +64,7 @@ cache-persist no)""");
ASSERT_EQ(client.GetAuthorityNum(), 1);
EXPECT_EQ(client.GetStatus(), "NXDOMAIN");
ASSERT_TRUE(client.Query("b.com", 60053));
ASSERT_TRUE(client.Query("b.com", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 1);
EXPECT_EQ(client.GetStatus(), "NOERROR");

266
test/cases/test-https.cc Normal file
View File

@@ -0,0 +1,266 @@
/*************************************************************************
*
* Copyright (C) 2018-2023 Ruilin Peng (Nick) <pymumu@gmail.com>.
*
* smartdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* smartdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#include "client.h"
#include "dns.h"
#include "include/utils.h"
#include "server.h"
#include "util.h"
#include "gtest/gtest.h"
#include <fstream>
class HTTPS : public ::testing::Test
{
protected:
virtual void SetUp() {}
virtual void TearDown() {}
};
TEST_F(HTTPS, ipv4_speed_prefer)
{
smartdns::MockServer server_upstream;
smartdns::Server server;
server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) {
if (request->qtype != DNS_T_HTTPS) {
return smartdns::SERVER_REQUEST_SOA;
}
struct dns_packet *packet = request->response_packet;
struct dns_rr_nested svcparam_buffer;
dns_add_HTTPS_start(&svcparam_buffer, packet, DNS_RRS_AN, request->domain.c_str(), 3, 1, "b.com");
const char alph[] = "\x02h2\x05h3-19";
int alph_len = sizeof(alph) - 1;
dns_HTTPS_add_alpn(&svcparam_buffer, alph, alph_len);
dns_HTTPS_add_port(&svcparam_buffer, 443);
unsigned char add_v4[] = {1, 2, 3, 4};
unsigned char *addr[1] = {add_v4};
dns_HTTPS_add_ipv4hint(&svcparam_buffer, addr, 1);
unsigned char ech[] = {0x00, 0x45, 0xfe, 0x0d, 0x00};
dns_HTTPS_add_ech(&svcparam_buffer, (void *)ech, sizeof(ech));
unsigned char add_v6[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
addr[0] = add_v6;
dns_HTTPS_add_ipv6hint(&svcparam_buffer, addr, 1);
dns_add_HTTPS_end(&svcparam_buffer);
return smartdns::SERVER_REQUEST_OK;
});
server.MockPing(PING_TYPE_ICMP, "1.2.3.4", 60, 10);
server.Start(R"""(bind [::]:60053
server 127.0.0.1:61053
log-console yes
dualstack-ip-selection no
log-level debug
cache-persist no)""");
smartdns::Client client;
ASSERT_TRUE(client.Query("a.com HTTPS", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 1);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3);
EXPECT_EQ(client.GetAnswer()[0].GetType(), "HTTPS");
EXPECT_EQ(client.GetAnswer()[0].GetData(), "1 b.com. alpn=\"h2,h3-19\" port=443 ipv4hint=1.2.3.4 ech=AEX+DQA=");
}
TEST_F(HTTPS, ipv6_speed_prefer)
{
smartdns::MockServer server_upstream;
smartdns::Server server;
server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) {
if (request->qtype != DNS_T_HTTPS) {
return smartdns::SERVER_REQUEST_SOA;
}
struct dns_packet *packet = request->response_packet;
struct dns_rr_nested svcparam_buffer;
dns_add_HTTPS_start(&svcparam_buffer, packet, DNS_RRS_AN, request->domain.c_str(), 3, 1, "b.com");
const char alph[] = "\x02h2\x05h3-19";
int alph_len = sizeof(alph) - 1;
dns_HTTPS_add_alpn(&svcparam_buffer, alph, alph_len);
dns_HTTPS_add_port(&svcparam_buffer, 443);
unsigned char add_v4[] = {1, 2, 3, 4};
unsigned char *addr[1] = {add_v4};
dns_HTTPS_add_ipv4hint(&svcparam_buffer, addr, 1);
unsigned char ech[] = {0x00, 0x45, 0xfe, 0x0d, 0x00};
dns_HTTPS_add_ech(&svcparam_buffer, (void *)ech, sizeof(ech));
unsigned char add_v6[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
addr[0] = add_v6;
dns_HTTPS_add_ipv6hint(&svcparam_buffer, addr, 1);
dns_add_HTTPS_end(&svcparam_buffer);
return smartdns::SERVER_REQUEST_OK;
});
server.MockPing(PING_TYPE_ICMP, "102:304:506:708:90a:b0c:d0e:f10", 60, 10);
server.Start(R"""(bind [::]:60053
server 127.0.0.1:61053
log-console yes
dualstack-ip-selection no
log-level debug
cache-persist no)""");
smartdns::Client client;
ASSERT_TRUE(client.Query("a.com HTTPS", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 1);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3);
EXPECT_EQ(client.GetAnswer()[0].GetType(), "HTTPS");
EXPECT_EQ(client.GetAnswer()[0].GetData(),
"1 b.com. alpn=\"h2,h3-19\" port=443 ech=AEX+DQA= ipv6hint=102:304:506:708:90a:b0c:d0e:f10");
}
TEST_F(HTTPS, ipv4_SOA)
{
smartdns::MockServer server_upstream;
smartdns::Server server;
server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) {
if (request->qtype != DNS_T_HTTPS) {
return smartdns::SERVER_REQUEST_SOA;
}
struct dns_packet *packet = request->response_packet;
struct dns_rr_nested svcparam_buffer;
dns_add_HTTPS_start(&svcparam_buffer, packet, DNS_RRS_AN, request->domain.c_str(), 3, 1, "a.com");
const char alph[] = "\x02h2\x05h3-19";
int alph_len = sizeof(alph) - 1;
dns_HTTPS_add_alpn(&svcparam_buffer, alph, alph_len);
dns_HTTPS_add_port(&svcparam_buffer, 443);
unsigned char add_v4[] = {1, 2, 3, 4};
unsigned char *addr[1] = {add_v4};
dns_HTTPS_add_ipv4hint(&svcparam_buffer, addr, 1);
unsigned char ech[] = {0x00, 0x45, 0xfe, 0x0d, 0x00};
dns_HTTPS_add_ech(&svcparam_buffer, (void *)ech, sizeof(ech));
unsigned char add_v6[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
addr[0] = add_v6;
dns_HTTPS_add_ipv6hint(&svcparam_buffer, addr, 1);
dns_add_HTTPS_end(&svcparam_buffer);
return smartdns::SERVER_REQUEST_OK;
});
server.Start(R"""(bind [::]:60053
server 127.0.0.1:61053
log-console yes
dualstack-ip-selection no
address /a.com/#4
log-level debug
cache-persist no)""");
smartdns::Client client;
ASSERT_TRUE(client.Query("a.com HTTPS", 61053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 1);
auto result_check = client.GetAnswer()[0].GetData();
ASSERT_TRUE(client.Query("a.com HTTPS", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 1);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3);
EXPECT_EQ(client.GetAnswer()[0].GetType(), "HTTPS");
EXPECT_EQ(client.GetAnswer()[0].GetData(),
"1 a.com. alpn=\"h2,h3-19\" port=443 ech=AEX+DQA= ipv6hint=102:304:506:708:90a:b0c:d0e:f10");
}
TEST_F(HTTPS, ipv6_SOA)
{
smartdns::MockServer server_upstream;
smartdns::Server server;
server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) {
if (request->qtype != DNS_T_HTTPS) {
return smartdns::SERVER_REQUEST_SOA;
}
struct dns_packet *packet = request->response_packet;
struct dns_rr_nested svcparam_buffer;
dns_add_HTTPS_start(&svcparam_buffer, packet, DNS_RRS_AN, request->domain.c_str(), 3, 1, "a.com");
const char alph[] = "\x02h2\x05h3-19";
int alph_len = sizeof(alph) - 1;
dns_HTTPS_add_alpn(&svcparam_buffer, alph, alph_len);
dns_HTTPS_add_port(&svcparam_buffer, 443);
unsigned char add_v4[] = {1, 2, 3, 4};
unsigned char *addr[1] = {add_v4};
dns_HTTPS_add_ipv4hint(&svcparam_buffer, addr, 1);
unsigned char ech[] = {0x00, 0x45, 0xfe, 0x0d, 0x00};
dns_HTTPS_add_ech(&svcparam_buffer, (void *)ech, sizeof(ech));
unsigned char add_v6[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
addr[0] = add_v6;
dns_HTTPS_add_ipv6hint(&svcparam_buffer, addr, 1);
dns_add_HTTPS_end(&svcparam_buffer);
return smartdns::SERVER_REQUEST_OK;
});
server.Start(R"""(bind [::]:60053
server 127.0.0.1:61053
log-console yes
dualstack-ip-selection no
address /a.com/#6
log-level debug
cache-persist no)""");
smartdns::Client client;
ASSERT_TRUE(client.Query("a.com HTTPS", 61053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 1);
auto result_check = client.GetAnswer()[0].GetData();
ASSERT_TRUE(client.Query("a.com HTTPS", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 1);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 3);
EXPECT_EQ(client.GetAnswer()[0].GetType(), "HTTPS");
EXPECT_EQ(client.GetAnswer()[0].GetData(), "1 a.com. alpn=\"h2,h3-19\" port=443 ipv4hint=1.2.3.4 ech=AEX+DQA=");
}
TEST_F(HTTPS, SOA)
{
smartdns::MockServer server_upstream;
smartdns::Server server;
server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) {
return smartdns::SERVER_REQUEST_SOA;
});
server.Start(R"""(bind [::]:60053
server 127.0.0.1:61053
log-console yes
dualstack-ip-selection no
address /a.com/#6
log-level debug
cache-persist no)""");
smartdns::Client client;
ASSERT_TRUE(client.Query("a.com HTTPS", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAuthorityNum(), 1);
EXPECT_EQ(client.GetStatus(), "NXDOMAIN");
EXPECT_EQ(client.GetAuthority()[0].GetName(), "a.com");
EXPECT_EQ(client.GetAuthority()[0].GetTTL(), 60);
EXPECT_EQ(client.GetAuthority()[0].GetType(), "SOA");
}

318
test/cases/test-ip-alias.cc Normal file
View File

@@ -0,0 +1,318 @@
/*************************************************************************
*
* Copyright (C) 2018-2023 Ruilin Peng (Nick) <pymumu@gmail.com>.
*
* smartdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* smartdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#include "client.h"
#include "dns.h"
#include "include/utils.h"
#include "server.h"
#include "gtest/gtest.h"
#include <fstream>
class IPAlias : public ::testing::Test
{
protected:
virtual void SetUp() {}
virtual void TearDown() {}
};
TEST(IPAlias, map_multiip_nospeed_check)
{
smartdns::MockServer server_upstream;
smartdns::Server server;
server_upstream.Start("udp://0.0.0.0:61053", [](struct smartdns::ServerRequestContext *request) {
std::string domain = request->domain;
if (request->domain.length() == 0) {
return smartdns::SERVER_REQUEST_ERROR;
}
if (request->qtype == DNS_T_A) {
unsigned char addr[][4] = {{1, 2, 3, 1}, {1, 2, 3, 2}, {1, 2, 3, 3}};
dns_add_A(request->response_packet, DNS_RRS_AN, domain.c_str(), 61, addr[0]);
dns_add_A(request->response_packet, DNS_RRS_AN, domain.c_str(), 61, addr[1]);
dns_add_A(request->response_packet, DNS_RRS_AN, domain.c_str(), 61, addr[2]);
} else if (request->qtype == DNS_T_AAAA) {
unsigned char addr[][16] = {{1, 2, 3, 4, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1},
{1, 2, 3, 4, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2},
{1, 2, 3, 4, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3}};
dns_add_AAAA(request->response_packet, DNS_RRS_AN, domain.c_str(), 61, addr[0]);
dns_add_AAAA(request->response_packet, DNS_RRS_AN, domain.c_str(), 61, addr[1]);
dns_add_AAAA(request->response_packet, DNS_RRS_AN, domain.c_str(), 61, addr[2]);
} else {
return smartdns::SERVER_REQUEST_ERROR;
}
request->response_packet->head.rcode = DNS_RC_NOERROR;
return smartdns::SERVER_REQUEST_OK;
});
server.MockPing(PING_TYPE_ICMP, "1.2.3.4", 60, 100);
server.MockPing(PING_TYPE_ICMP, "5.6.7.8", 60, 110);
server.MockPing(PING_TYPE_ICMP, "9.10.11.12", 60, 140);
server.MockPing(PING_TYPE_ICMP, "10.10.10.10", 60, 120);
server.MockPing(PING_TYPE_ICMP, "11.11.11.11", 60, 150);
server.MockPing(PING_TYPE_ICMP, "0102:0304:0500::", 60, 100);
server.MockPing(PING_TYPE_ICMP, "0506:0708:0900::", 60, 110);
server.MockPing(PING_TYPE_ICMP, "0a0b:0c0d:0e00::", 60, 140);
server.MockPing(PING_TYPE_ICMP, "ffff::1", 60, 120);
server.MockPing(PING_TYPE_ICMP, "ffff::2", 60, 150);
server.Start(R"""(bind [::]:60053
server 127.0.0.1:61053
log-num 0
log-console yes
log-level debug
dualstack-ip-selection no
speed-check-mode none
ip-alias 1.2.3.0/24 10.10.10.10,12.12.12.12,13.13.13.13,15.15.15.15
ip-alias 0102::/16 FFFF::0001,FFFF::0002,FFFF::0003,FFFF::0004
cache-persist no)""");
smartdns::Client client;
ASSERT_TRUE(client.Query("a.com", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 4);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
EXPECT_EQ(client.GetAnswer()[0].GetData(), "10.10.10.10");
EXPECT_EQ(client.GetAnswer()[1].GetData(), "12.12.12.12");
EXPECT_EQ(client.GetAnswer()[2].GetData(), "15.15.15.15");
EXPECT_EQ(client.GetAnswer()[3].GetData(), "13.13.13.13");
ASSERT_TRUE(client.Query("a.com AAAA", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 4);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
EXPECT_EQ(client.GetAnswer()[0].GetData(), "ffff::1");
EXPECT_EQ(client.GetAnswer()[1].GetData(), "ffff::3");
EXPECT_EQ(client.GetAnswer()[2].GetData(), "ffff::2");
EXPECT_EQ(client.GetAnswer()[3].GetData(), "ffff::4");
}
TEST(IPAlias, map_single_ip_nospeed_check)
{
smartdns::MockServer server_upstream;
smartdns::Server server;
server_upstream.Start("udp://0.0.0.0:61053", [](struct smartdns::ServerRequestContext *request) {
std::string domain = request->domain;
if (request->domain.length() == 0) {
return smartdns::SERVER_REQUEST_ERROR;
}
if (request->qtype == DNS_T_A) {
unsigned char addr[][4] = {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}};
dns_add_A(request->response_packet, DNS_RRS_AN, domain.c_str(), 61, addr[0]);
dns_add_A(request->response_packet, DNS_RRS_AN, domain.c_str(), 61, addr[1]);
dns_add_A(request->response_packet, DNS_RRS_AN, domain.c_str(), 61, addr[2]);
} else if (request->qtype == DNS_T_AAAA) {
unsigned char addr[][16] = {{1, 2, 3, 4, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
{5, 6, 7, 8, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
{10, 11, 12, 13, 14, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}};
dns_add_AAAA(request->response_packet, DNS_RRS_AN, domain.c_str(), 61, addr[0]);
dns_add_AAAA(request->response_packet, DNS_RRS_AN, domain.c_str(), 61, addr[1]);
dns_add_AAAA(request->response_packet, DNS_RRS_AN, domain.c_str(), 61, addr[2]);
} else {
return smartdns::SERVER_REQUEST_ERROR;
}
request->response_packet->head.rcode = DNS_RC_NOERROR;
return smartdns::SERVER_REQUEST_OK;
});
server.MockPing(PING_TYPE_ICMP, "1.2.3.4", 60, 100);
server.MockPing(PING_TYPE_ICMP, "5.6.7.8", 60, 110);
server.MockPing(PING_TYPE_ICMP, "9.10.11.12", 60, 140);
server.MockPing(PING_TYPE_ICMP, "10.10.10.10", 60, 120);
server.MockPing(PING_TYPE_ICMP, "11.11.11.11", 60, 150);
server.MockPing(PING_TYPE_ICMP, "0102:0304:0500::", 60, 100);
server.MockPing(PING_TYPE_ICMP, "0506:0708:0900::", 60, 110);
server.MockPing(PING_TYPE_ICMP, "0a0b:0c0d:0e00::", 60, 140);
server.MockPing(PING_TYPE_ICMP, "ffff::1", 60, 120);
server.MockPing(PING_TYPE_ICMP, "ffff::2", 60, 150);
server.Start(R"""(bind [::]:60053
server 127.0.0.1:61053
log-num 0
log-console yes
log-level debug
dualstack-ip-selection no
speed-check-mode none
ip-alias 1.2.3.4 10.10.10.10
ip-alias 5.6.7.8/32 11.11.11.11
ip-alias 0102:0304:0500:: ffff::1
ip-alias 0506:0708:0900:: ffff::2
cache-persist no)""");
smartdns::Client client;
ASSERT_TRUE(client.Query("a.com", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 3);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
EXPECT_EQ(client.GetAnswer()[0].GetData(), "10.10.10.10");
EXPECT_EQ(client.GetAnswer()[1].GetData(), "11.11.11.11");
EXPECT_EQ(client.GetAnswer()[2].GetData(), "9.10.11.12");
ASSERT_TRUE(client.Query("a.com AAAA", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 3);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
EXPECT_EQ(client.GetAnswer()[0].GetData(), "ffff::1");
EXPECT_EQ(client.GetAnswer()[1].GetData(), "a0b:c0d:e00::");
EXPECT_EQ(client.GetAnswer()[2].GetData(), "ffff::2");
}
TEST(IPAlias, mapip_withspeed_check)
{
smartdns::MockServer server_upstream;
smartdns::Server server;
server_upstream.Start("udp://0.0.0.0:61053", [](struct smartdns::ServerRequestContext *request) {
std::string domain = request->domain;
if (request->domain.length() == 0) {
return smartdns::SERVER_REQUEST_ERROR;
}
if (request->qtype == DNS_T_A) {
unsigned char addr[][4] = {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}};
dns_add_A(request->response_packet, DNS_RRS_AN, domain.c_str(), 61, addr[0]);
dns_add_A(request->response_packet, DNS_RRS_AN, domain.c_str(), 61, addr[1]);
dns_add_A(request->response_packet, DNS_RRS_AN, domain.c_str(), 61, addr[2]);
} else if (request->qtype == DNS_T_AAAA) {
unsigned char addr[][16] = {{1, 2, 3, 4, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
{5, 6, 7, 8, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
{10, 11, 12, 13, 14, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}};
dns_add_AAAA(request->response_packet, DNS_RRS_AN, domain.c_str(), 61, addr[0]);
dns_add_AAAA(request->response_packet, DNS_RRS_AN, domain.c_str(), 61, addr[1]);
dns_add_AAAA(request->response_packet, DNS_RRS_AN, domain.c_str(), 61, addr[2]);
} else {
return smartdns::SERVER_REQUEST_ERROR;
}
request->response_packet->head.rcode = DNS_RC_NOERROR;
return smartdns::SERVER_REQUEST_OK;
});
server.MockPing(PING_TYPE_ICMP, "1.2.3.4", 60, 100);
server.MockPing(PING_TYPE_ICMP, "5.6.7.8", 60, 110);
server.MockPing(PING_TYPE_ICMP, "9.10.11.12", 60, 140);
server.MockPing(PING_TYPE_ICMP, "10.10.10.10", 60, 120);
server.MockPing(PING_TYPE_ICMP, "11.11.11.11", 60, 150);
server.MockPing(PING_TYPE_ICMP, "0102:0304:0500::", 60, 100);
server.MockPing(PING_TYPE_ICMP, "0506:0708:0900::", 60, 110);
server.MockPing(PING_TYPE_ICMP, "0a0b:0c0d:0e00::", 60, 140);
server.MockPing(PING_TYPE_ICMP, "ffff::1", 60, 120);
server.MockPing(PING_TYPE_ICMP, "ffff::2", 60, 150);
server.Start(R"""(bind [::]:60053
server 127.0.0.1:61053
log-num 0
log-console yes
log-level debug
dualstack-ip-selection no
ip-alias 1.2.3.4 10.10.10.10
ip-alias 5.6.7.8/32 11.11.11.11
ip-alias 0102::/16 ffff::1
ip-alias 0506::/16 ffff::2
cache-persist no)""");
smartdns::Client client;
ASSERT_TRUE(client.Query("a.com", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 1);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
EXPECT_EQ(client.GetAnswer()[0].GetData(), "10.10.10.10");
ASSERT_TRUE(client.Query("a.com AAAA", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 1);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
EXPECT_EQ(client.GetAnswer()[0].GetData(), "ffff::1");
}
TEST(IPAlias, no_ip_alias)
{
smartdns::MockServer server_upstream;
smartdns::Server server;
server_upstream.Start("udp://0.0.0.0:61053", [](struct smartdns::ServerRequestContext *request) {
std::string domain = request->domain;
if (request->domain.length() == 0) {
return smartdns::SERVER_REQUEST_ERROR;
}
if (request->qtype == DNS_T_A) {
unsigned char addr[][4] = {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}};
dns_add_A(request->response_packet, DNS_RRS_AN, domain.c_str(), 61, addr[0]);
dns_add_A(request->response_packet, DNS_RRS_AN, domain.c_str(), 61, addr[1]);
dns_add_A(request->response_packet, DNS_RRS_AN, domain.c_str(), 61, addr[2]);
} else if (request->qtype == DNS_T_AAAA) {
unsigned char addr[][16] = {{1, 2, 3, 4, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
{5, 6, 7, 8, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
{10, 11, 12, 13, 14, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}};
dns_add_AAAA(request->response_packet, DNS_RRS_AN, domain.c_str(), 61, addr[0]);
dns_add_AAAA(request->response_packet, DNS_RRS_AN, domain.c_str(), 61, addr[1]);
dns_add_AAAA(request->response_packet, DNS_RRS_AN, domain.c_str(), 61, addr[2]);
} else {
return smartdns::SERVER_REQUEST_ERROR;
}
request->response_packet->head.rcode = DNS_RC_NOERROR;
return smartdns::SERVER_REQUEST_OK;
});
server.MockPing(PING_TYPE_ICMP, "1.2.3.4", 60, 100);
server.MockPing(PING_TYPE_ICMP, "5.6.7.8", 60, 110);
server.MockPing(PING_TYPE_ICMP, "9.10.11.12", 60, 140);
server.MockPing(PING_TYPE_ICMP, "10.10.10.10", 60, 120);
server.MockPing(PING_TYPE_ICMP, "11.11.11.11", 60, 150);
server.MockPing(PING_TYPE_ICMP, "0102:0304:0500::", 60, 100);
server.MockPing(PING_TYPE_ICMP, "0506:0708:0900::", 60, 110);
server.MockPing(PING_TYPE_ICMP, "0a0b:0c0d:0e00::", 60, 140);
server.MockPing(PING_TYPE_ICMP, "ffff::1", 60, 120);
server.MockPing(PING_TYPE_ICMP, "ffff::2", 60, 150);
server.Start(R"""(bind [::]:60053
server 127.0.0.1:61053
log-num 0
log-console yes
log-level debug
dualstack-ip-selection no
ip-alias 1.2.3.4 10.10.10.10
ip-alias 5.6.7.8/32 11.11.11.11
ip-alias 0102::/16 ffff::1
ip-alias 0506::/16 ffff::2
domain-rules /a.com/ -no-ip-alias
cache-persist no)""");
smartdns::Client client;
ASSERT_TRUE(client.Query("a.com", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 1);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
ASSERT_TRUE(client.Query("a.com AAAA", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 1);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
EXPECT_EQ(client.GetAnswer()[0].GetData(), "102:304:500::");
}

View File

@@ -53,8 +53,8 @@ TEST_F(IPRule, white_list)
return smartdns::SERVER_REQUEST_OK;
});
/* this ip will be discard, but is reachable */
server.MockPing(PING_TYPE_ICMP, "1.2.3.4", 60, 10);
/* this ip will be discard, but is reachable */
server.MockPing(PING_TYPE_ICMP, "1.2.3.4", 60, 10);
server.Start(R"""(bind [::]:60053
server udp://127.0.0.1:61053 -whitelist-ip
@@ -97,8 +97,8 @@ TEST_F(IPRule, black_list)
return smartdns::SERVER_REQUEST_OK;
});
/* this ip will be discard, but is reachable */
server.MockPing(PING_TYPE_ICMP, "4.5.6.7", 60, 10);
/* this ip will be discard, but is reachable */
server.MockPing(PING_TYPE_ICMP, "4.5.6.7", 60, 10);
server.Start(R"""(bind [::]:60053
server udp://127.0.0.1:61053 -blacklist-ip
@@ -134,10 +134,10 @@ TEST_F(IPRule, ignore_ip)
return smartdns::SERVER_REQUEST_OK;
});
/* this ip will be discard, but is reachable */
server.MockPing(PING_TYPE_ICMP, "1.2.3.4", 60, 10);
server.MockPing(PING_TYPE_ICMP, "4.5.6.7", 60, 90);
server.MockPing(PING_TYPE_ICMP, "7.8.9.10", 60, 40);
/* this ip will be discard, but is reachable */
server.MockPing(PING_TYPE_ICMP, "1.2.3.4", 60, 10);
server.MockPing(PING_TYPE_ICMP, "4.5.6.7", 60, 90);
server.MockPing(PING_TYPE_ICMP, "7.8.9.10", 60, 40);
server.Start(R"""(bind [::]:60053
server udp://127.0.0.1:61053 -blacklist-ip
@@ -154,3 +154,59 @@ cache-persist no)""");
EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
EXPECT_EQ(client.GetAnswer()[0].GetData(), "7.8.9.10");
}
TEST_F(IPRule, ignore_ip_set)
{
smartdns::MockServer server_upstream;
smartdns::MockServer server_upstream2;
smartdns::Server server;
std::string file = "/tmp/smartdns_test_ip_set.list" + smartdns::GenerateRandomString(5);
std::ofstream ofs(file);
ASSERT_TRUE(ofs.is_open());
Defer
{
ofs.close();
unlink(file.c_str());
};
server_upstream.Start("udp://0.0.0.0:61053", [](struct smartdns::ServerRequestContext *request) {
if (request->qtype != DNS_T_A) {
return smartdns::SERVER_REQUEST_SOA;
}
smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4", 611);
smartdns::MockServer::AddIP(request, request->domain.c_str(), "4.5.6.7", 611);
smartdns::MockServer::AddIP(request, request->domain.c_str(), "7.8.9.10", 611);
return smartdns::SERVER_REQUEST_OK;
});
/* this ip will be discard, but is reachable */
server.MockPing(PING_TYPE_ICMP, "1.2.3.4", 60, 10);
server.MockPing(PING_TYPE_ICMP, "4.5.6.7", 60, 90);
server.MockPing(PING_TYPE_ICMP, "7.8.9.10", 60, 40);
std::string ipset_list = R"""(
1.2.3.0/24
4.5.6.0/24
)""";
ofs.write(ipset_list.c_str(), ipset_list.length());
ofs.flush();
server.Start(R"""(bind [::]:60053
server udp://127.0.0.1:61053 -blacklist-ip
ip-set -name ip-list -file )""" +
file + R"""(
ignore-ip ip-set:ip-list
log-num 0
speed-check-mode none
log-console yes
log-level debug
cache-persist no)""");
smartdns::Client client;
ASSERT_TRUE(client.Query("a.com", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 1);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
EXPECT_EQ(client.GetAnswer()[0].GetData(), "7.8.9.10");
}

View File

@@ -42,11 +42,11 @@ TEST_F(Ptr, query)
return smartdns::SERVER_REQUEST_OK;
}
if (request->qtype == DNS_T_PTR) {
dns_add_PTR(request->response_packet, DNS_RRS_AN, request->domain.c_str(), 30, "my-hostname");
request->response_packet->head.rcode = DNS_RC_NOERROR;
return smartdns::SERVER_REQUEST_OK;
}
if (request->qtype == DNS_T_PTR) {
dns_add_PTR(request->response_packet, DNS_RRS_AN, request->domain.c_str(), 30, "my-hostname");
request->response_packet->head.rcode = DNS_RC_NOERROR;
return smartdns::SERVER_REQUEST_OK;
}
return smartdns::SERVER_REQUEST_SOA;
});
@@ -127,11 +127,11 @@ TEST_F(Ptr, smartdns)
return smartdns::SERVER_REQUEST_OK;
}
if (request->qtype == DNS_T_PTR) {
dns_add_PTR(request->response_packet, DNS_RRS_AN, request->domain.c_str(), 30, "my-hostname");
request->response_packet->head.rcode = DNS_RC_NOERROR;
return smartdns::SERVER_REQUEST_OK;
}
if (request->qtype == DNS_T_PTR) {
dns_add_PTR(request->response_packet, DNS_RRS_AN, request->domain.c_str(), 30, "my-hostname");
request->response_packet->head.rcode = DNS_RC_NOERROR;
return smartdns::SERVER_REQUEST_OK;
}
return smartdns::SERVER_REQUEST_SOA;
});

View File

@@ -255,3 +255,55 @@ cache-persist no)""");
EXPECT_EQ(client.GetAuthority()[0].GetTTL(), 30);
EXPECT_EQ(client.GetAuthority()[0].GetType(), "SOA");
}
TEST_F(QtypeSOA, HTTPS_SOA)
{
smartdns::MockServer server_upstream;
smartdns::Server server;
std::map<int, int> qid_map;
server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) {
if (request->qtype != DNS_T_HTTPS) {
return smartdns::SERVER_REQUEST_SOA;
}
struct dns_packet *packet = request->response_packet;
struct dns_rr_nested svcparam_buffer;
dns_add_HTTPS_start(&svcparam_buffer, packet, DNS_RRS_AN, request->domain.c_str(), 3, 1, "a.com");
const char alph[] = "\x02h2\x05h3-19";
int alph_len = sizeof(alph) - 1;
dns_HTTPS_add_alpn(&svcparam_buffer, alph, alph_len);
dns_HTTPS_add_port(&svcparam_buffer, 443);
unsigned char add_v4[] = {1, 2, 3, 4};
unsigned char *addr[1] = {add_v4};
dns_HTTPS_add_ipv4hint(&svcparam_buffer, addr, 1);
unsigned char ech[] = {0x00, 0x45, 0xfe, 0x0d, 0x00};
dns_HTTPS_add_ech(&svcparam_buffer, (void *)ech, sizeof(ech));
unsigned char add_v6[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
addr[0] = add_v6;
dns_HTTPS_add_ipv6hint(&svcparam_buffer, addr, 1);
dns_add_HTTPS_end(&svcparam_buffer);
return smartdns::SERVER_REQUEST_OK;
});
server.Start(R"""(bind [::]:60053
server 127.0.0.1:61053
log-console yes
dualstack-ip-selection no
speed-check-mode none
address /a.com/#
log-level debug
cache-persist no)""");
smartdns::Client client;
ASSERT_TRUE(client.Query("a.com HTTPS", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAuthorityNum(), 1);
EXPECT_EQ(client.GetStatus(), "NXDOMAIN");
EXPECT_EQ(client.GetAuthority()[0].GetName(), "a.com");
EXPECT_EQ(client.GetAuthority()[0].GetTTL(), 30);
EXPECT_EQ(client.GetAuthority()[0].GetType(), "SOA");
}

317
test/cases/test-rule.cc Normal file
View File

@@ -0,0 +1,317 @@
/*************************************************************************
*
* Copyright (C) 2018-2023 Ruilin Peng (Nick) <pymumu@gmail.com>.
*
* smartdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* smartdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#include "client.h"
#include "dns.h"
#include "include/utils.h"
#include "server.h"
#include "util.h"
#include "gtest/gtest.h"
#include <fstream>
class Rule : public ::testing::Test
{
protected:
virtual void SetUp() {}
virtual void TearDown() {}
};
TEST_F(Rule, Match)
{
smartdns::MockServer server_upstream;
smartdns::Server server;
server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) {
if (request->qtype == DNS_T_A) {
smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4", 700);
return smartdns::SERVER_REQUEST_OK;
} else if (request->qtype == DNS_T_AAAA) {
smartdns::MockServer::AddIP(request, request->domain.c_str(), "64:ff9b::102:304", 700);
return smartdns::SERVER_REQUEST_OK;
}
return smartdns::SERVER_REQUEST_SOA;
});
server.Start(R"""(bind [::]:60053
server 127.0.0.1:61053
log-num 0
log-console yes
log-level debug
speed-check-mode none
address /a.com/5.6.7.8
cache-persist no)""");
smartdns::Client client;
ASSERT_TRUE(client.Query("a.com A", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 1);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 600);
EXPECT_EQ(client.GetAnswer()[0].GetType(), "A");
EXPECT_EQ(client.GetAnswer()[0].GetData(), "5.6.7.8");
ASSERT_TRUE(client.Query("a.a.com A", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 1);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.a.com");
EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 600);
EXPECT_EQ(client.GetAnswer()[0].GetType(), "A");
EXPECT_EQ(client.GetAnswer()[0].GetData(), "5.6.7.8");
ASSERT_TRUE(client.Query("aa.com A", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 1);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAnswer()[0].GetName(), "aa.com");
EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 700);
EXPECT_EQ(client.GetAnswer()[0].GetType(), "A");
EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
}
TEST_F(Rule, PrefixWildcardMatch)
{
smartdns::MockServer server_upstream;
smartdns::Server server;
server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) {
if (request->qtype == DNS_T_A) {
smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4", 700);
return smartdns::SERVER_REQUEST_OK;
} else if (request->qtype == DNS_T_AAAA) {
smartdns::MockServer::AddIP(request, request->domain.c_str(), "64:ff9b::102:304", 700);
return smartdns::SERVER_REQUEST_OK;
}
return smartdns::SERVER_REQUEST_SOA;
});
server.Start(R"""(bind [::]:60053
server 127.0.0.1:61053
log-num 0
log-console yes
log-level debug
speed-check-mode none
address /*a.com/5.6.7.8
cache-persist no)""");
smartdns::Client client;
ASSERT_TRUE(client.Query("a.com A", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 1);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 600);
EXPECT_EQ(client.GetAnswer()[0].GetType(), "A");
EXPECT_EQ(client.GetAnswer()[0].GetData(), "5.6.7.8");
ASSERT_TRUE(client.Query("a.a.com A", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 1);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.a.com");
EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 600);
EXPECT_EQ(client.GetAnswer()[0].GetType(), "A");
EXPECT_EQ(client.GetAnswer()[0].GetData(), "5.6.7.8");
ASSERT_TRUE(client.Query("aa.com A", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 1);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAnswer()[0].GetName(), "aa.com");
EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 600);
EXPECT_EQ(client.GetAnswer()[0].GetType(), "A");
EXPECT_EQ(client.GetAnswer()[0].GetData(), "5.6.7.8");
ASSERT_TRUE(client.Query("ab.com A", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 1);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAnswer()[0].GetName(), "ab.com");
EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 700);
EXPECT_EQ(client.GetAnswer()[0].GetType(), "A");
EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
}
TEST_F(Rule, SubDomainMatchOnly)
{
smartdns::MockServer server_upstream;
smartdns::Server server;
server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) {
if (request->qtype == DNS_T_A) {
smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4", 700);
return smartdns::SERVER_REQUEST_OK;
} else if (request->qtype == DNS_T_AAAA) {
smartdns::MockServer::AddIP(request, request->domain.c_str(), "64:ff9b::102:304", 700);
return smartdns::SERVER_REQUEST_OK;
}
return smartdns::SERVER_REQUEST_SOA;
});
server.Start(R"""(bind [::]:60053
server 127.0.0.1:61053
log-num 0
log-console yes
log-level debug
speed-check-mode none
address /*.a.com/5.6.7.8
cache-persist no)""");
smartdns::Client client;
ASSERT_TRUE(client.Query("a.com A", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 1);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 700);
EXPECT_EQ(client.GetAnswer()[0].GetType(), "A");
EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
ASSERT_TRUE(client.Query("a.a.com A", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 1);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.a.com");
EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 600);
EXPECT_EQ(client.GetAnswer()[0].GetType(), "A");
EXPECT_EQ(client.GetAnswer()[0].GetData(), "5.6.7.8");
ASSERT_TRUE(client.Query("aa.com A", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 1);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAnswer()[0].GetName(), "aa.com");
EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 700);
EXPECT_EQ(client.GetAnswer()[0].GetType(), "A");
EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
}
TEST_F(Rule, RootDomainMatchOnly)
{
smartdns::MockServer server_upstream;
smartdns::Server server;
server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) {
if (request->qtype == DNS_T_A) {
smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4", 700);
return smartdns::SERVER_REQUEST_OK;
} else if (request->qtype == DNS_T_AAAA) {
smartdns::MockServer::AddIP(request, request->domain.c_str(), "64:ff9b::102:304", 700);
return smartdns::SERVER_REQUEST_OK;
}
return smartdns::SERVER_REQUEST_SOA;
});
server.Start(R"""(bind [::]:60053
server 127.0.0.1:61053
log-num 0
log-console yes
log-level debug
speed-check-mode none
address /-.a.com/5.6.7.8
cache-persist no)""");
smartdns::Client client;
ASSERT_TRUE(client.Query("a.com A", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 1);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 600);
EXPECT_EQ(client.GetAnswer()[0].GetType(), "A");
EXPECT_EQ(client.GetAnswer()[0].GetData(), "5.6.7.8");
ASSERT_TRUE(client.Query("a.a.com A", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 1);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.a.com");
EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 700);
EXPECT_EQ(client.GetAnswer()[0].GetType(), "A");
EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
ASSERT_TRUE(client.Query("b.a.com A", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 1);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAnswer()[0].GetName(), "b.a.com");
EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 700);
EXPECT_EQ(client.GetAnswer()[0].GetType(), "A");
EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
ASSERT_TRUE(client.Query("ba.com A", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 1);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAnswer()[0].GetName(), "ba.com");
EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 700);
EXPECT_EQ(client.GetAnswer()[0].GetType(), "A");
EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4");
}
TEST_F(Rule, AAAA_SOA)
{
smartdns::MockServer server_upstream;
smartdns::Server server;
server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) {
if (request->qtype == DNS_T_A) {
smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4", 700);
return smartdns::SERVER_REQUEST_OK;
} else if (request->qtype == DNS_T_AAAA) {
smartdns::MockServer::AddIP(request, request->domain.c_str(), "64:ff9b::102:304", 700);
return smartdns::SERVER_REQUEST_OK;
}
return smartdns::SERVER_REQUEST_SOA;
});
server.Start(R"""(bind [::]:60053
server 127.0.0.1:61053
log-num 0
log-console yes
log-level debug
speed-check-mode none
address /-.a.com/#6
address /*.b.com/#6
cache-persist no)""");
smartdns::Client client;
ASSERT_TRUE(client.Query("a.com AAAA", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 0);
EXPECT_EQ(client.GetStatus(), "NOERROR");
ASSERT_TRUE(client.Query("a.a.com AAAA", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 1);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.a.com");
EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 700);
EXPECT_EQ(client.GetAnswer()[0].GetType(), "AAAA");
EXPECT_EQ(client.GetAnswer()[0].GetData(), "64:ff9b::102:304");
ASSERT_TRUE(client.Query("a.b.com AAAA", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 0);
EXPECT_EQ(client.GetStatus(), "NOERROR");
ASSERT_TRUE(client.Query("b.com AAAA", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 1);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAnswer()[0].GetName(), "b.com");
EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 700);
EXPECT_EQ(client.GetAnswer()[0].GetType(), "AAAA");
EXPECT_EQ(client.GetAnswer()[0].GetData(), "64:ff9b::102:304");
}

View File

@@ -44,7 +44,7 @@ TEST_F(Server, all_unreach)
return smartdns::SERVER_REQUEST_OK;
});
server.MockPing(PING_TYPE_ICMP, "2001::", 128, 10000);
server.MockPing(PING_TYPE_ICMP, "2001::", 128, 10000);
server.Start(R"""(bind [::]:60053
bind-tcp [::]:60053
server tls://255.255.255.255
@@ -58,11 +58,11 @@ cache-persist no)""");
ASSERT_TRUE(client.Query("a.com", 60053));
std::cout << client.GetResult() << std::endl;
EXPECT_EQ(client.GetStatus(), "SERVFAIL");
EXPECT_EQ(client.GetAnswerNum(), 0);
EXPECT_EQ(client.GetAnswerNum(), 0);
/* server should not crash */
ASSERT_TRUE(client.Query("a.com +tcp", 60053));
/* server should not crash */
ASSERT_TRUE(client.Query("a.com +tcp", 60053));
std::cout << client.GetResult() << std::endl;
EXPECT_EQ(client.GetStatus(), "SERVFAIL");
EXPECT_EQ(client.GetAnswerNum(), 0);
EXPECT_EQ(client.GetAnswerNum(), 0);
}

View File

@@ -220,7 +220,6 @@ cache-persist no)""");
EXPECT_EQ(client.GetAnswer()[0].GetData(), "5.6.7.8");
}
TEST_F(SpeedCheck, tcp_faster_than_ping)
{
smartdns::MockServer server_upstream;

View File

@@ -245,6 +245,148 @@ cache-persist no)""");
EXPECT_EQ(client.GetAnswer()[0].GetData(), "2001:db8::1");
}
TEST_F(SubNet, v4_server_subnet_txt)
{
smartdns::MockServer server_upstream;
smartdns::Server server;
server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) {
if (request->qtype != DNS_T_TXT) {
return smartdns::SERVER_REQUEST_SOA;
}
struct dns_opt_ecs ecs;
struct dns_rrs *rrs = NULL;
int rr_count = 0;
int i = 0;
int ret = 0;
int has_ecs = 0;
rr_count = 0;
rrs = dns_get_rrs_start(request->packet, DNS_RRS_OPT, &rr_count);
if (rr_count <= 0) {
return smartdns::SERVER_REQUEST_ERROR;
}
for (i = 0; i < rr_count && rrs; i++, rrs = dns_get_rrs_next(request->packet, rrs)) {
memset(&ecs, 0, sizeof(ecs));
ret = dns_get_OPT_ECS(rrs, NULL, NULL, &ecs);
if (ret != 0) {
continue;
}
has_ecs = 1;
break;
}
if (has_ecs == 0) {
return smartdns::SERVER_REQUEST_ERROR;
}
if (ecs.family != DNS_OPT_ECS_FAMILY_IPV4) {
return smartdns::SERVER_REQUEST_ERROR;
}
if (memcmp(ecs.addr, "\x08\x08\x08\x00", 4) != 0) {
return smartdns::SERVER_REQUEST_ERROR;
}
if (ecs.source_prefix != 24) {
return smartdns::SERVER_REQUEST_ERROR;
}
dns_add_TXT(request->response_packet, DNS_RRS_AN, request->domain.c_str(), 6, "hello world");
return smartdns::SERVER_REQUEST_OK;
});
server.Start(R"""(bind [::]:60053
server 127.0.0.1:61053 -subnet 8.8.8.8/24
log-num 0
log-console yes
dualstack-ip-selection no
log-level debug
rr-ttl-min 0
cache-persist no)""");
smartdns::Client client;
ASSERT_TRUE(client.Query("a.com TXT", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 1);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 6);
EXPECT_EQ(client.GetAnswer()[0].GetType(), "TXT");
EXPECT_EQ(client.GetAnswer()[0].GetData(), "\"hello world\"");
}
TEST_F(SubNet, v6_default_subnet_txt)
{
smartdns::MockServer server_upstream;
smartdns::Server server;
server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) {
if (request->qtype != DNS_T_TXT) {
return smartdns::SERVER_REQUEST_SOA;
}
struct dns_opt_ecs ecs;
struct dns_rrs *rrs = NULL;
int rr_count = 0;
int i = 0;
int ret = 0;
int has_ecs = 0;
rr_count = 0;
rrs = dns_get_rrs_start(request->packet, DNS_RRS_OPT, &rr_count);
if (rr_count <= 0) {
return smartdns::SERVER_REQUEST_ERROR;
}
for (i = 0; i < rr_count && rrs; i++, rrs = dns_get_rrs_next(request->packet, rrs)) {
memset(&ecs, 0, sizeof(ecs));
ret = dns_get_OPT_ECS(rrs, NULL, NULL, &ecs);
if (ret != 0) {
continue;
}
has_ecs = 1;
break;
}
if (has_ecs == 0) {
return smartdns::SERVER_REQUEST_ERROR;
}
if (ecs.family != DNS_OPT_ECS_FAMILY_IPV6) {
return smartdns::SERVER_REQUEST_ERROR;
}
if (memcmp(ecs.addr, "\xff\xff\xff\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00", 16) != 0) {
return smartdns::SERVER_REQUEST_ERROR;
}
if (ecs.source_prefix != 64) {
return smartdns::SERVER_REQUEST_ERROR;
}
dns_add_TXT(request->response_packet, DNS_RRS_AN, request->domain.c_str(), 6, "hello world");
return smartdns::SERVER_REQUEST_OK;
});
server.Start(R"""(bind [::]:60053
server 127.0.0.1:61053
log-num 0
log-console yes
dualstack-ip-selection no
rr-ttl-min 0
edns-client-subnet ffff:ffff:ffff:ffff:ffff::/64
log-level debug
cache-persist no)""");
smartdns::Client client;
ASSERT_TRUE(client.Query("a.com TXT", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 1);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com");
EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 6);
EXPECT_EQ(client.GetAnswer()[0].GetType(), "TXT");
EXPECT_EQ(client.GetAnswer()[0].GetData(), "\"hello world\"");
}
TEST_F(SubNet, per_server)
{
smartdns::MockServer server_upstream1;

View File

@@ -376,23 +376,21 @@ bool Server::Start(const std::string &conf, enum CONF_TYPE type)
}
smartdns_reg_post_func(Server::StartPost, this);
smartdns_main(args.size(), argv, fds[1]);
smartdns_main(args.size(), argv, fds[1], 0);
_exit(1);
} else if (pid < 0) {
return false;
}
} else if (mode_ == CREATE_MODE_THREAD) {
thread_ = std::thread([&]() {
std::vector<std::string> args = {
"smartdns", "-f", "-x", "-c", conf_file_, "-p", "-",
};
std::vector<std::string> args = {"smartdns", "-f", "-x", "-c", conf_file_, "-p", "-", "-S"};
char *argv[args.size() + 1];
for (size_t i = 0; i < args.size(); i++) {
argv[i] = (char *)args[i].c_str();
}
smartdns_reg_post_func(Server::StartPost, this);
smartdns_main(args.size(), argv, fds[1]);
smartdns_main(args.size(), argv, fds[1], 1);
smartdns_reg_post_func(nullptr, nullptr);
});
} else {