From ef806ecc9c16a39c178d86c33b6fe893e53717b2 Mon Sep 17 00:00:00 2001 From: Nick Peng Date: Sat, 11 Nov 2023 09:58:11 +0800 Subject: [PATCH] feature: simple support DOH server --- etc/smartdns/smartdns.conf | 2 + src/dns_client.c | 2 +- src/dns_conf.c | 8 +- src/dns_server.c | 184 +++++++++++++++++++++++++++++++------ src/http_parse.c | 6 +- src/include/hash.h | 18 ++++ test/cases/test-bind.cc | 30 ++++++ 7 files changed, 219 insertions(+), 31 deletions(-) diff --git a/etc/smartdns/smartdns.conf b/etc/smartdns/smartdns.conf index 7f9f740..d8a2845 100644 --- a/etc/smartdns/smartdns.conf +++ b/etc/smartdns/smartdns.conf @@ -30,6 +30,8 @@ # tls cert file # bind-cert-key-pass [password] # tls private key password +# bind-https server +# bind-https [IP]:[port][@device] [-group [group]] [-no-rule-addr] [-no-rule-nameserver] [-no-rule-ipset] [-no-speed-check] [-no-cache] [-no-rule-soa] [-no-dualstack-selection] # option: # -group: set domain request to use the appropriate server group. # -no-rule-addr: skip address rule. diff --git a/src/dns_client.c b/src/dns_client.c index 6334581..19a1303 100644 --- a/src/dns_client.c +++ b/src/dns_client.c @@ -3328,7 +3328,7 @@ static int _dns_client_send_https(struct dns_server_info *server_info, void *pac "POST %s HTTP/1.1\r\n" "Host: %s\r\n" "User-Agent: smartdns\r\n" - "content-type: application/dns-message\r\n" + "Content-Type: application/dns-message\r\n" "Content-Length: %d\r\n" "\r\n", https_flag->path, https_flag->httphost, len); diff --git a/src/dns_conf.c b/src/dns_conf.c index 613c6cb..d92287d 100644 --- a/src/dns_conf.c +++ b/src/dns_conf.c @@ -2218,7 +2218,7 @@ static int _config_bind_ip(int argc, char *argv[], DNS_BIND_TYPE type) bind_ip->flags = server_flag; bind_ip->group = group; dns_conf_bind_ip_num++; - if (bind_ip->type == DNS_BIND_TYPE_TLS) { + if (bind_ip->type == DNS_BIND_TYPE_TLS || bind_ip->type == DNS_BIND_TYPE_HTTPS) { if (bind_ip->ssl_cert_file == NULL || bind_ip->ssl_cert_key_file == NULL) { bind_ip->ssl_cert_file = dns_conf_bind_ca_file; bind_ip->ssl_cert_key_file = dns_conf_bind_ca_key_file; @@ -2249,6 +2249,11 @@ static int _config_bind_ip_tls(void *data, int argc, char *argv[]) return _config_bind_ip(argc, argv, DNS_BIND_TYPE_TLS); } +static int _config_bind_ip_https(void *data, int argc, char *argv[]) +{ + return _config_bind_ip(argc, argv, DNS_BIND_TYPE_HTTPS); +} + static int _config_option_parser_filepath(void *data, int argc, char *argv[]) { if (argc <= 1) { @@ -4098,6 +4103,7 @@ static struct config_item _config_item[] = { CONF_CUSTOM("bind", _config_bind_ip_udp, NULL), CONF_CUSTOM("bind-tcp", _config_bind_ip_tcp, NULL), CONF_CUSTOM("bind-tls", _config_bind_ip_tls, NULL), + CONF_CUSTOM("bind-https", _config_bind_ip_https, NULL), CONF_CUSTOM("bind-cert-file", _config_option_parser_filepath, &dns_conf_bind_ca_file), CONF_CUSTOM("bind-cert-key-file", _config_option_parser_filepath, &dns_conf_bind_ca_key_file), CONF_STRING("bind-cert-key-pass", dns_conf_bind_ca_key_pass, DNS_MAX_PATH), diff --git a/src/dns_server.c b/src/dns_server.c index aad04c1..351c5d2 100644 --- a/src/dns_server.c +++ b/src/dns_server.c @@ -27,6 +27,7 @@ #include "dns_conf.h" #include "fast_ping.h" #include "hashtable.h" +#include "http_parse.h" #include "list.h" #include "nftset.h" #include "tlog.h" @@ -1101,7 +1102,7 @@ static void _dns_server_conn_release(struct dns_server_conn_head *conn) tls_client->ssl = NULL; } pthread_mutex_destroy(&tls_client->ssl_lock); - } else if (conn->type == DNS_CONN_TYPE_TLS_SERVER) { + } else if (conn->type == DNS_CONN_TYPE_TLS_SERVER || conn->type == DNS_CONN_TYPE_HTTPS_SERVER) { struct dns_server_conn_tls_server *tls_server = (struct dns_server_conn_tls_server *)conn; if (tls_server->ssl_ctx != NULL) { SSL_CTX_free(tls_server->ssl_ctx); @@ -1141,6 +1142,71 @@ static int _dns_server_reply_tcp_to_buffer(struct dns_server_conn_tcp_client *tc return 0; } +static int _dns_server_reply_http_error(struct dns_server_conn_tcp_client *tcpclient, int code, const char *code_msg, + const char *message) +{ + int send_len = 0; + int http_len = 0; + unsigned char data[DNS_IN_PACKSIZE]; + + http_len = snprintf((char *)data, DNS_IN_PACKSIZE, + "HTTP/1.1 %d %s\r\n" + "\r\n" + "%s", + code, code_msg, message); + + send_len = _dns_server_tcp_socket_send(tcpclient, data, http_len); + if (send_len < 0) { + if (errno == EAGAIN) { + /* save data to buffer, and retry when EPOLLOUT is available */ + return _dns_server_reply_tcp_to_buffer(tcpclient, data, http_len); + } + return -1; + } else if (send_len < http_len) { + /* save remain data to buffer, and retry when EPOLLOUT is available */ + return _dns_server_reply_tcp_to_buffer(tcpclient, data + send_len, http_len - send_len); + } + + return 0; +} + +static int _dns_server_reply_https(struct dns_request *request, struct dns_server_conn_tcp_client *tcpclient, + void *packet, unsigned short len) +{ + int send_len = 0; + int http_len = 0; + unsigned char inpacket_data[DNS_IN_PACKSIZE]; + unsigned char *inpacket = inpacket_data; + + if (len > sizeof(inpacket_data)) { + tlog(TLOG_ERROR, "packet size is invalid."); + return -1; + } + + http_len = snprintf((char *)inpacket, DNS_IN_PACKSIZE, + "HTTP/1.1 200 OK\r\n" + "content-type: application/dns-message\r\n" + "Content-Length: %d\r\n" + "\r\n", + len); + memcpy(inpacket + http_len, packet, len); + http_len += len; + + send_len = _dns_server_tcp_socket_send(tcpclient, inpacket, http_len); + if (send_len < 0) { + if (errno == EAGAIN) { + /* save data to buffer, and retry when EPOLLOUT is available */ + return _dns_server_reply_tcp_to_buffer(tcpclient, inpacket, http_len); + } + return -1; + } else if (send_len < http_len) { + /* save remain data to buffer, and retry when EPOLLOUT is available */ + return _dns_server_reply_tcp_to_buffer(tcpclient, inpacket + send_len, http_len - send_len); + } + + return 0; +} + static int _dns_server_reply_tcp(struct dns_request *request, struct dns_server_conn_tcp_client *tcpclient, void *packet, unsigned short len) { @@ -1255,6 +1321,8 @@ static int _dns_reply_inpacket(struct dns_request *request, unsigned char *inpac ret = _dns_server_reply_tcp(request, (struct dns_server_conn_tcp_client *)conn, inpacket, inpacket_len); } else if (conn->type == DNS_CONN_TYPE_TLS_CLIENT) { ret = _dns_server_reply_tcp(request, (struct dns_server_conn_tcp_client *)conn, inpacket, inpacket_len); + } else if (conn->type == DNS_CONN_TYPE_HTTPS_CLIENT) { + ret = _dns_server_reply_https(request, (struct dns_server_conn_tcp_client *)conn, inpacket, inpacket_len); } else { ret = -1; } @@ -6112,47 +6180,104 @@ static int _dns_server_tcp_process_one_request(struct dns_server_conn_tcp_client int total_len = tcpclient->recvbuff.size; int proceed_len = 0; unsigned char *request_data = NULL; - int ret = 0; + int ret = RECV_ERROR_FAIL; + int len = 0; + struct http_head *http_head = NULL; /* Handling multiple requests */ for (;;) { - if ((total_len - proceed_len) <= (int)sizeof(unsigned short)) { - ret = RECV_ERROR_AGAIN; - break; + ret = RECV_ERROR_FAIL; + if (tcpclient->head.type == DNS_CONN_TYPE_HTTPS_CLIENT) { + if ((total_len - proceed_len) <= 0) { + ret = RECV_ERROR_AGAIN; + goto out; + } + + http_head = http_head_init(4096); + if (http_head == NULL) { + goto out; + } + + len = http_head_parse(http_head, (char *)tcpclient->recvbuff.buf, tcpclient->recvbuff.size); + if (len < 0) { + if (len == -1) { + ret = 0; + goto out; + } + + tlog(TLOG_DEBUG, "remote server not supported."); + goto errout; + } + + if (http_head_get_method(http_head) != HTTP_METHOD_POST) { + tlog(TLOG_DEBUG, "remote server not supported."); + goto errout; + } + + const char *content_type = http_head_get_fields_value(http_head, "Content-Type"); + if (content_type == NULL || + strncmp(content_type, "application/dns-message", sizeof("application/dns-message")) != 0) { + tlog(TLOG_DEBUG, "content type not supported, %s", content_type); + goto errout; + } + + request_len = http_head_get_data_len(http_head); + if (request_len >= len) { + tlog(TLOG_DEBUG, "request length is invalid."); + + goto errout; + } + request_data = (unsigned char *)http_head_get_data(http_head); + proceed_len += len; + } else { + if ((total_len - proceed_len) <= (int)sizeof(unsigned short)) { + ret = RECV_ERROR_AGAIN; + break; + } + + /* Get record length */ + request_data = (unsigned char *)(tcpclient->recvbuff.buf + proceed_len); + request_len = ntohs(*((unsigned short *)(request_data))); + + if (request_len >= sizeof(tcpclient->recvbuff.buf)) { + tlog(TLOG_DEBUG, "request length is invalid."); + return RECV_ERROR_FAIL; + } + + if (request_len > (total_len - proceed_len - sizeof(unsigned short))) { + ret = RECV_ERROR_AGAIN; + break; + } + + request_data = (unsigned char *)(tcpclient->recvbuff.buf + proceed_len + sizeof(unsigned short)); + proceed_len += sizeof(unsigned short) + request_len; } - /* Get record length */ - request_data = (unsigned char *)(tcpclient->recvbuff.buf + proceed_len); - request_len = ntohs(*((unsigned short *)(request_data))); - - if (request_len >= sizeof(tcpclient->recvbuff.buf)) { - tlog(TLOG_DEBUG, "request length is invalid."); - return RECV_ERROR_FAIL; - } - - if (request_len > (total_len - proceed_len - sizeof(unsigned short))) { - ret = RECV_ERROR_AGAIN; - break; - } - - request_data = (unsigned char *)(tcpclient->recvbuff.buf + proceed_len + sizeof(unsigned short)); - /* process one record */ ret = _dns_server_recv(&tcpclient->head, request_data, request_len, &tcpclient->localaddr, tcpclient->localaddr_len, &tcpclient->addr, tcpclient->addr_len); if (ret != 0) { return ret; } - - proceed_len += sizeof(unsigned short) + request_len; } +out: + if (total_len > proceed_len && proceed_len > 0) { memmove(tcpclient->recvbuff.buf, tcpclient->recvbuff.buf + proceed_len, total_len - proceed_len); } tcpclient->recvbuff.size -= proceed_len; +errout: + if (http_head) { + http_head_destroy(http_head); + } + + if (ret == RECV_ERROR_FAIL && tcpclient->head.type == DNS_CONN_TYPE_HTTPS_CLIENT) { + _dns_server_reply_http_error(tcpclient, 400, "Bad Request", "Bad Request"); + } + return ret; } @@ -6273,7 +6398,14 @@ static int _dns_server_tls_accept(struct dns_server_conn_tls_server *tls_server, memset(tls_client, 0, sizeof(*tls_client)); tls_client->head.fd = fd; - tls_client->head.type = DNS_CONN_TYPE_TLS_CLIENT; + if (tls_server->head.type == DNS_CONN_TYPE_TLS_SERVER) { + tls_client->head.type = DNS_CONN_TYPE_TLS_CLIENT; + } else if (tls_server->head.type == DNS_CONN_TYPE_HTTPS_SERVER) { + tls_client->head.type = DNS_CONN_TYPE_HTTPS_CLIENT; + } else { + tlog(TLOG_ERROR, "invalid http server type."); + goto errout; + } tls_client->head.server_flags = tls_server->head.server_flags; tls_client->head.dns_group = tls_server->head.dns_group; tls_client->head.ipset_nftset_rule = tls_server->head.ipset_nftset_rule; @@ -6402,10 +6534,10 @@ static int _dns_server_process(struct dns_server_conn_head *conn, struct epoll_e tlog(TLOG_DEBUG, "process TCP packet from %s failed.", get_host_by_addr(name, sizeof(name), (struct sockaddr *)&tcpclient->addr)); } - } else if (conn->type == DNS_CONN_TYPE_TLS_SERVER) { + } else if (conn->type == DNS_CONN_TYPE_TLS_SERVER || conn->type == DNS_CONN_TYPE_HTTPS_SERVER) { struct dns_server_conn_tls_server *tls_server = (struct dns_server_conn_tls_server *)conn; ret = _dns_server_tls_accept(tls_server, event, now); - } else if (conn->type == DNS_CONN_TYPE_TLS_CLIENT) { + } else if (conn->type == DNS_CONN_TYPE_TLS_CLIENT || conn->type == DNS_CONN_TYPE_HTTPS_CLIENT) { struct dns_server_conn_tls_client *tls_client = (struct dns_server_conn_tls_client *)conn; ret = _dns_server_process_tls(tls_client, event, now); if (ret != 0) { diff --git a/src/http_parse.c b/src/http_parse.c index 850f1ad..afd8217 100644 --- a/src/http_parse.c +++ b/src/http_parse.c @@ -111,10 +111,10 @@ const char *http_head_get_fields_value(struct http_head *http_head, const char * uint32_t key; struct http_head_fields *filed; - key = hash_string(name); + key = hash_string_case(name); hash_for_each_possible(http_head->field_map, filed, node, key) { - if (strncmp(filed->name, name, 128) == 0) { + if (strncasecmp(filed->name, name, 128) == 0) { return filed->value; } } @@ -205,7 +205,7 @@ static int _http_head_add_fields(struct http_head *http_head, char *name, char * fields->value = value; list_add_tail(&fields->list, &http_head->field_head.list); - key = hash_string(name); + key = hash_string_case(name); hash_add(http_head->field_map, &fields->node, key); return 0; diff --git a/src/include/hash.h b/src/include/hash.h index 76d4081..cf280b3 100644 --- a/src/include/hash.h +++ b/src/include/hash.h @@ -21,6 +21,7 @@ #include "bitmap.h" #include "jhash.h" +#include /* Fast hashing routine for ints, longs and pointers. (C) 2002 Nadia Yvette Chambers, IBM */ @@ -223,11 +224,28 @@ static inline uint32_t hash_string_initval(const char *s, uint32_t initval) return h; } +static inline uint32_t hash_string_case_initval(const char *s, uint32_t initval) +{ + uint32_t h = initval; + + while (*s) { + h = h * 31 + tolower(*s); + s++; + } + + return h; +} + static inline uint32_t hash_string(const char *s) { return hash_string_initval(s, 0); } +static inline uint32_t hash_string_case(const char *s) +{ + return hash_string_case_initval(s, 0); +} + static inline uint32_t hash_string_array(const char **a) { uint32_t h = 0; diff --git a/test/cases/test-bind.cc b/test/cases/test-bind.cc index 64f1dfa..df4be5f 100644 --- a/test/cases/test-bind.cc +++ b/test/cases/test-bind.cc @@ -56,6 +56,36 @@ cache-persist no)"""); EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4"); } +TEST(Bind, https) +{ + Defer + { + unlink("/tmp/smartdns-cert.pem"); + unlink("/tmp/smartdns-key.pem"); + }; + + smartdns::Server server_wrap; + smartdns::Server server; + + server.Start(R"""(bind [::]:61053 +server https://127.0.0.1:60053 -no-check-certificate +log-num 0 +log-console yes +log-level debug +cache-persist no)"""); + server_wrap.Start(R"""(bind-https [::]:60053 +address /example.com/1.2.3.4 +log-num 0 +log-console yes +log-level debug +cache-persist no)"""); + smartdns::Client client; + ASSERT_TRUE(client.Query("example.com", 61053)); + ASSERT_EQ(client.GetAnswerNum(), 1); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4"); +} + TEST(Bind, udp_tcp) { smartdns::MockServer server_upstream;