feature: add option srv-record support SRV.

This commit is contained in:
Nick Peng
2023-12-07 23:15:11 +08:00
parent 61a6e676bc
commit 5df4364809
6 changed files with 335 additions and 2 deletions

View File

@@ -265,6 +265,11 @@ log-level info
# specific cname to domain
# cname /domain/target
# add srv record, support multiple srv record.
# srv-record /domain/[target][,port][,priority][,weight]
# srv-record /_ldap._tcp.example.com/ldapserver.example.com,389
# srv-record /_ldap._tcp.example.com/
# enalbe DNS64 feature
# dns64 [ip/subnet]
# dns64 64:ff9b::/96

View File

@@ -66,6 +66,9 @@ int dns_hosts_record_num;
/* DNS64 */
struct dns_dns64 dns_conf_dns_dns64;
/* SRV-HOST */
struct dns_srv_record_table dns_conf_srv_record_table;
/* server ip/port */
struct dns_bind_ip dns_conf_bind_ip[DNS_MAX_BIND_IP];
int dns_conf_bind_ip_num = 0;
@@ -505,6 +508,26 @@ static void _config_proxy_table_destroy(void)
}
}
static void _config_srv_record_table_destroy(void)
{
struct dns_srv_records *srv_records = NULL;
struct dns_srv_record *srv_record, *tmp1 = NULL;
struct hlist_node *tmp = NULL;
unsigned int i;
hash_for_each_safe(dns_conf_srv_record_table.srv, i, tmp, srv_records, node)
{
list_for_each_entry_safe(srv_record, tmp1, &srv_records->list, list)
{
list_del(&srv_record->list);
free(srv_record);
}
hlist_del_init(&srv_records->node);
free(srv_records);
}
}
static int _config_server(int argc, char *argv[], dns_server_type_t type, int default_port)
{
int index = dns_conf_server_num;
@@ -1860,6 +1883,122 @@ errout:
return 0;
}
struct dns_srv_records *dns_server_get_srv_record(const char *domain)
{
uint32_t key = 0;
key = hash_string(domain);
struct dns_srv_records *srv_records = NULL;
hash_for_each_possible(dns_conf_srv_record_table.srv, srv_records, node, key)
{
if (strncmp(srv_records->domain, domain, DNS_MAX_CONF_CNAME_LEN) == 0) {
return srv_records;
}
}
return NULL;
}
static int _confg_srv_record_add(const char *domain, const char *host, unsigned short priority, unsigned short weight,
unsigned short port)
{
struct dns_srv_records *srv_records = NULL;
struct dns_srv_record *srv_record = NULL;
uint32_t key = 0;
srv_records = dns_server_get_srv_record(domain);
if (srv_records == NULL) {
srv_records = malloc(sizeof(*srv_records));
if (srv_records == NULL) {
goto errout;
}
memset(srv_records, 0, sizeof(*srv_records));
safe_strncpy(srv_records->domain, domain, DNS_MAX_CONF_CNAME_LEN);
INIT_LIST_HEAD(&srv_records->list);
key = hash_string(domain);
hash_add(dns_conf_srv_record_table.srv, &srv_records->node, key);
}
srv_record = malloc(sizeof(*srv_record));
if (srv_record == NULL) {
goto errout;
}
memset(srv_record, 0, sizeof(*srv_record));
safe_strncpy(srv_record->host, host, DNS_MAX_CONF_CNAME_LEN);
srv_record->priority = priority;
srv_record->weight = weight;
srv_record->port = port;
list_add_tail(&srv_record->list, &srv_records->list);
return 0;
errout:
if (srv_record != NULL) {
free(srv_record);
}
return -1;
}
static int _config_srv_record(void *data, int argc, char *argv[])
{
char *value = NULL;
char domain[DNS_MAX_CONF_CNAME_LEN];
char buff[DNS_MAX_CONF_CNAME_LEN];
char *ptr = NULL;
int ret = -1;
char *host_s;
char *priority_s;
char *weight_s;
char *port_s;
unsigned short priority = 0;
unsigned short weight = 0;
unsigned short port = 1;
if (argc < 2) {
goto errout;
}
value = argv[1];
if (_get_domain(value, domain, DNS_MAX_CONF_CNAME_LEN, &value) != 0) {
goto errout;
}
safe_strncpy(buff, value, sizeof(buff));
host_s = strtok_r(buff, ",", &ptr);
if (host_s == NULL) {
host_s = "";
goto out;
}
port_s = strtok_r(NULL, ",", &ptr);
if (port_s != NULL) {
port = atoi(port_s);
}
priority_s = strtok_r(NULL, ",", &ptr);
if (priority_s != NULL) {
priority = atoi(priority_s);
}
weight_s = strtok_r(NULL, ",", &ptr);
if (weight_s != NULL) {
weight = atoi(weight_s);
}
out:
ret = _confg_srv_record_add(domain, host_s, priority, weight, port);
if (ret != 0) {
goto errout;
}
return 0;
errout:
tlog(TLOG_ERROR, "add srv-record %s:%s failed", domain, value);
return -1;
}
static void _config_speed_check_mode_clear(struct dns_domain_check_orders *check_orders)
{
memset(check_orders->orders, 0, sizeof(check_orders->orders));
@@ -4154,6 +4293,7 @@ static struct config_item _config_item[] = {
CONF_YESNO("expand-ptr-from-address", &dns_conf_expand_ptr_from_address),
CONF_CUSTOM("address", _config_address, NULL),
CONF_CUSTOM("cname", _config_cname, NULL),
CONF_CUSTOM("srv-record", _config_srv_record, NULL),
CONF_CUSTOM("proxy-server", _config_proxy_server, NULL),
CONF_YESNO("ipset-timeout", &dns_conf_ipset_timeout_enable),
CONF_CUSTOM("ipset", _config_ipset, NULL),
@@ -4406,6 +4546,7 @@ static int _dns_server_load_conf_init(void)
hash_init(dns_ptr_table.ptr);
hash_init(dns_domain_set_name_table.names);
hash_init(dns_ip_set_name_table.names);
hash_init(dns_conf_srv_record_table.srv);
return 0;
}
@@ -4456,6 +4597,7 @@ void dns_server_load_exit(void)
_config_host_table_destroy(0);
_config_qtype_soa_table_destroy();
_config_proxy_table_destroy();
_config_srv_record_table_destroy();
dns_conf_server_num = 0;
dns_server_bind_destroy();

View File

@@ -488,6 +488,25 @@ struct dns_dns64 {
uint32_t prefix_len;
};
struct dns_srv_record {
struct list_head list;
char host[DNS_MAX_CNAME_LEN];
unsigned short priority;
unsigned short weight;
unsigned short port;
};
struct dns_srv_records {
char domain[DNS_MAX_CNAME_LEN];
struct hlist_node node;
struct list_head list;
};
struct dns_srv_record_table {
DECLARE_HASHTABLE(srv, 4);
};
extern struct dns_srv_record_table dns_conf_srv_record_table;
extern struct dns_dns64 dns_conf_dns_dns64;
extern struct dns_bind_ip dns_conf_bind_ip[DNS_MAX_BIND_IP];
@@ -584,6 +603,8 @@ int dns_server_check_update_hosts(void);
struct dns_proxy_names *dns_server_get_proxy_nams(const char *proxyname);
struct dns_srv_records *dns_server_get_srv_record(const char *domain);
extern int config_additional_file(void *data, int argc, char *argv[]);
const char *dns_conf_get_cache_dir(void);

View File

@@ -276,6 +276,8 @@ struct dns_request {
int has_soa;
int force_soa;
struct dns_srv_records *srv_records;
atomic_t notified;
atomic_t do_callback;
atomic_t adblock;
@@ -949,6 +951,29 @@ static void _dns_server_setup_soa(struct dns_request *request)
soa->minimum = 86400;
}
static int _dns_server_add_srv(struct dns_server_post_context *context)
{
struct dns_request *request = context->request;
struct dns_srv_records *srv_records = request->srv_records;
struct dns_srv_record *srv_record = NULL;
int ret = 0;
if (srv_records == NULL) {
return 0;
}
list_for_each_entry(srv_record, &srv_records->list, list)
{
ret = dns_add_SRV(context->packet, DNS_RRS_AN, request->domain, request->ip_ttl, srv_record->priority,
srv_record->weight, srv_record->port, srv_record->host);
if (ret != 0) {
return -1;
}
}
return 0;
}
static int _dns_add_rrs(struct dns_server_post_context *context)
{
struct dns_request *request = context->request;
@@ -1011,6 +1036,10 @@ static int _dns_add_rrs(struct dns_server_post_context *context)
ret |= dns_add_OPT_ECS(context->packet, &request->ecs);
}
if (request->srv_records != NULL) {
ret |= _dns_server_add_srv(context);
}
if (request->rcode != DNS_RC_NOERROR) {
tlog(TLOG_INFO, "result: %s, qtype: %d, rtcode: %d, id: %d", domain, context->qtype, request->rcode,
request->id);
@@ -4159,6 +4188,28 @@ static int _dns_server_process_DDR(struct dns_request *request)
}
static int _dns_server_process_srv(struct dns_request *request)
{
struct dns_srv_records *srv_records = dns_server_get_srv_record(request->domain);
if (srv_records == NULL) {
return -1;
}
request->rcode = DNS_RC_NOERROR;
request->ip_ttl = _dns_server_get_local_ttl(request);
request->srv_records = srv_records;
struct dns_server_post_context context;
_dns_server_post_context_init(&context, request);
context.do_audit = 1;
context.do_reply = 1;
context.do_cache = 0;
context.do_force_soa = 0;
_dns_request_post(&context);
return 0;
}
static int _dns_server_process_svcb(struct dns_request *request)
{
if (strncmp("_dns.resolver.arpa", request->domain, DNS_MAX_CNAME_LEN) == 0) {
return _dns_server_process_DDR(request);
@@ -5268,7 +5319,7 @@ static int _dns_server_process_special_query(struct dns_request *request)
switch (request->qtype) {
case DNS_T_PTR:
break;
case DNS_T_SVCB:
case DNS_T_SRV:
ret = _dns_server_process_srv(request);
if (ret == 0) {
goto clean_exit;
@@ -5277,6 +5328,15 @@ static int _dns_server_process_special_query(struct dns_request *request)
request->passthrough = 1;
}
break;
case DNS_T_SVCB:
ret = _dns_server_process_svcb(request);
if (ret == 0) {
goto clean_exit;
} else {
/* pass to upstream server */
request->passthrough = 1;
}
break;
case DNS_T_A:
break;
case DNS_T_AAAA:

View File

@@ -1112,7 +1112,7 @@ int main(int argc, char *argv[])
errout:
if (is_run_as_daemon) {
daemon_kickoff(ret, dns_conf_log_console | verbose_screen);
} else {
} else if (dns_conf_log_console == 0 && verbose_screen == 0) {
_smartdns_print_error_tip();
}
smartdns_test_notify(2);

105
test/cases/test-srv.cc Normal file
View File

@@ -0,0 +1,105 @@
/*************************************************************************
*
* Copyright (C) 2018-2023 Ruilin Peng (Nick) <pymumu@gmail.com>.
*
* smartdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* smartdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#include "client.h"
#include "dns.h"
#include "include/utils.h"
#include "server.h"
#include "util.h"
#include "gtest/gtest.h"
#include <fstream>
class SRV : public ::testing::Test
{
protected:
virtual void SetUp() {}
virtual void TearDown() {}
};
TEST_F(SRV, query)
{
smartdns::MockServer server_upstream;
smartdns::Server server;
server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) {
if (request->qtype != DNS_T_SRV) {
return smartdns::SERVER_REQUEST_SOA;
}
struct dns_packet *packet = request->response_packet;
dns_add_SRV(packet, DNS_RRS_AN, request->domain.c_str(), 603, 1, 1, 443, "www.example.com");
dns_add_SRV(packet, DNS_RRS_AN, request->domain.c_str(), 603, 1, 1, 443, "www1.example.com");
return smartdns::SERVER_REQUEST_OK;
});
server.Start(R"""(bind [::]:60053
server 127.0.0.1:61053
log-num 0
log-console yes
log-level debug
speed-check-mode none
cache-persist no)""");
smartdns::Client client;
ASSERT_TRUE(client.Query("_ldap._tcp.local.com SRV", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 2);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAnswer()[0].GetName(), "_ldap._tcp.local.com");
EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 603);
EXPECT_EQ(client.GetAnswer()[0].GetType(), "SRV");
EXPECT_EQ(client.GetAnswer()[0].GetData(), "1 1 443 www.example.com.");
}
TEST_F(SRV, match)
{
smartdns::MockServer server_upstream;
smartdns::Server server;
server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) {
if (request->qtype != DNS_T_SRV) {
return smartdns::SERVER_REQUEST_SOA;
}
struct dns_packet *packet = request->response_packet;
dns_add_SRV(packet, DNS_RRS_AN, request->domain.c_str(), 603, 1, 1, 443, "www.example.com");
dns_add_SRV(packet, DNS_RRS_AN, request->domain.c_str(), 603, 1, 1, 443, "www1.example.com");
return smartdns::SERVER_REQUEST_OK;
});
server.Start(R"""(bind [::]:60053
server 127.0.0.1:61053
log-num 0
log-console yes
log-level debug
srv-record /_ldap._tcp.local.com/www.a.com,443,1,1
srv-record /_ldap._tcp.local.com/www1.a.com,443,1,1
srv-record /_ldap._tcp.local.com/www2.a.com,443,1,1
speed-check-mode none
cache-persist no)""");
smartdns::Client client;
ASSERT_TRUE(client.Query("_ldap._tcp.local.com SRV", 60053));
std::cout << client.GetResult() << std::endl;
ASSERT_EQ(client.GetAnswerNum(), 3);
EXPECT_EQ(client.GetStatus(), "NOERROR");
EXPECT_EQ(client.GetAnswer()[0].GetName(), "_ldap._tcp.local.com");
EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 600);
EXPECT_EQ(client.GetAnswer()[0].GetType(), "SRV");
EXPECT_EQ(client.GetAnswer()[0].GetData(), "1 1 443 www.a.com.");
}