diff --git a/dns.c b/dns.c index 7a2921b..576f478 100644 --- a/dns.c +++ b/dns.c @@ -78,6 +78,10 @@ struct dns_rrs *dns_get_rrs_start(struct dns_packet *packet, int type, int *coun break; } + if (start == DNS_RR_END) { + return NULL; + } + return (struct dns_rrs *)(packet->data + start); } @@ -111,7 +115,6 @@ int dns_rr_add_end(struct dns_packet *packet, int type, dns_type_t rrtype, int l unsigned short *count; unsigned short *start; - len += sizeof(*rrs); if (packet->len + len > packet->size - sizeof(*packet)) { return -1; } @@ -143,7 +146,7 @@ int dns_rr_add_end(struct dns_packet *packet, int type, dns_type_t rrtype, int l rrs->len = len; rrs->type = rrtype; *start = packet->len; - packet->len += len; + packet->len += len + sizeof(*rrs); return 0; } @@ -254,6 +257,71 @@ int _dns_get_rr_head(struct dns_data_context *data_context, char *domain, int ma return len; } +int dns_add_CNAME(struct dns_packet *packet, char *domain, int ttl, char *cname) +{ + int maxlen = 0; + int len = 0; + struct dns_data_context data_context; + int rr_len = 0; + + rr_len = strnlen(cname, maxlen - 1) + 1; + unsigned char *data = _dns_add_rrs_start(packet, &maxlen); + if (data == NULL) { + return -1; + } + + if (rr_len >= maxlen) { + return -1; + } + + data_context.data = data; + data_context.ptr = data; + data_context.maxsize = maxlen; + + len = _dns_add_rr_head(&data_context, domain, DNS_T_CNAME, DNS_C_IN, ttl, rr_len); + if (len < 0) { + return -1; + } + + memcpy(data_context.ptr, cname, rr_len); + data_context.ptr += rr_len; + len = data_context.ptr - data_context.data; + + return dns_rr_add_end(packet, DNS_RRS_AN, DNS_T_CNAME, len); +} + +int dns_get_CNAME(struct dns_rrs *rrs, char *domain, int maxsize, int *ttl, char *cname, int cname_size) +{ + int qtype = 0; + int qclass = 0; + int rr_len = 0; + int ret = 0; + struct dns_data_context data_context; + + unsigned char *data = rrs->data; + + if (rrs->type != DNS_T_CNAME) { + return -1; + } + + data_context.data = data; + data_context.ptr = data; + data_context.maxsize = rrs->len; + + ret = _dns_get_rr_head(&data_context, domain, maxsize, &qtype, &qclass, ttl, &rr_len); + if (ret < 0) { + return -1; + } + + if (qtype != DNS_T_CNAME || rr_len > cname_size) { + return -1; + } + + memcpy(cname, data_context.ptr, rr_len); + + return 0; +} + int dns_add_A(struct dns_packet *packet, char *domain, int ttl, unsigned char addr[4]) { int maxlen = 0; @@ -308,7 +376,7 @@ int dns_get_A(struct dns_rrs *rrs, char *domain, int maxsize, int *ttl, unsigned return -1; } - memcpy(addr, rrs->data, DNS_RR_A_LEN); + memcpy(addr, data_context.ptr, DNS_RR_A_LEN); return 0; } @@ -371,7 +439,7 @@ int dns_get_AAAA(struct dns_rrs *rrs, char *domain, int maxsize, int *ttl, unsig return -1; } - memcpy(addr, rrs->data, DNS_RR_AAAA_LEN); + memcpy(addr, data_context.ptr, DNS_RR_AAAA_LEN); return 0; } @@ -476,41 +544,108 @@ int _dns_encode_head(struct dns_context *context) return len; } +int dns_parse_domain(char *dns, int offset, char *domain, int *space) +{ + unsigned char val, *pval; + unsigned short len; + + int sp = 0; + int domain_len = 0; + int org = 1; + + while (1) { + pval = (unsigned char *)(dns + offset); + val = *pval; + + if (val == 0) { + domain[domain_len - 1] = 0; + domain_len--; + if (org) + sp++; + break; + } else if (val <= 63) { + memcpy(domain + domain_len, dns + offset + 1, val); + domain_len += val; + domain[domain_len] = '.'; + domain_len++; + + offset += (val + 1); + + if (org) + sp += (val + 1); + } else { + len = *(unsigned short *)(dns + offset); + len = ntohs(len); + len = len & (~0xc000); + + if (org) + sp += 2; + + org = 0; + offset = len; + } + } + + *space = sp; + + return 0; +} + int _dns_decode_domain(struct dns_context *context, char *output, int size) { int output_len = 0; int copy_len = 0; int len = *(context->ptr); + unsigned char *ptr = context->ptr; + int is_compressed = 0; - while (*(context->ptr)) { - if (_dns_left_len(context) < 1) { + + int sp = 0; + dns_parse_domain(context->data, context->ptr - context->data, output, &sp); + context->ptr += sp; + return 0; + + + while (1) { + len = *ptr; + if (len == 0) { + *output = 0; + ptr++; + break; + } + + if (len >= 0xC0) { + ptr = context->data + (dns_read_short(&ptr) & 0x3FFF); + is_compressed = 1; + context->ptr += 2; + continue; + } + + if (context->maxsize - (ptr - context->data) < 1) { return -1; } - context->ptr++; + ptr++; if (output_len < size - 1) { copy_len = (len < size - output_len) ? len : size - 1 - output_len; - if (_dns_left_len(context) < copy_len) { + if (context->maxsize - (ptr - context->data) < 1) { return -1; } - memcpy(output, context->ptr, copy_len); + memcpy(output, ptr, copy_len); } - context->ptr += len; + ptr += len; output += len; output_len += len; - - len = *(context->ptr); - if (len == 0) { - break; - } *output = '.'; output++; - } - *output = 0; - context->ptr++; + if (is_compressed == 0) { + context->ptr = ptr; + } + + printf("--%p\n", context->ptr); return 0; } @@ -536,13 +671,15 @@ int _dns_encode_domain(struct dns_context *context, char *domain) } *ptr_num = num; - *context->ptr = 0; + *(context->ptr) = 0; + context->ptr++; return 0; } int _dns_decode_qr_head(struct dns_context *context, char *domain, int domain_size, int *qtype, int *qclass) { int ret = 0; + ret = _dns_decode_domain(context, domain, domain_size); if (ret < 0) { return -1; @@ -613,6 +750,53 @@ int _dns_encode_rr_head(struct dns_context *context, char *domain, int qtype, in return 0; } +int _dns_decode_CNAME(struct dns_context *context, char *cname, int cname_size) +{ + int ret = 0; + ret = _dns_decode_domain(context, cname, cname_size); + if (ret < 0) { + return -1; + } + + return 0; +} + +int _dns_encode_CNAME(struct dns_context *context, struct dns_rrs *rrs) +{ + int ret; + int qtype = 0; + int qclass = 0; + int ttl = 0; + char domain[DNS_MAX_CNAME_LEN]; + int rr_len; + struct dns_data_context data_context; + + data_context.data = rrs->data; + data_context.ptr = rrs->data; + data_context.maxsize = rrs->len; + + ret = _dns_get_rr_head(&data_context, domain, DNS_MAX_CNAME_LEN, &qtype, &qclass, &ttl, &rr_len); + if (ret < 0) { + return -1; + } + + if (rr_len > rrs->len) { + return -1; + } + + ret = _dns_encode_rr_head(context, domain, qtype, qclass, ttl, rr_len); + if (ret < 0) { + return -1; + } + + ret = _dns_encode_domain(context, (char *)rrs->data); + if (ret < 0) { + return -1; + } + + return 0; +} + int _dns_decode_A(struct dns_context *context, unsigned char addr[4]) { if (_dns_left_len(context) < DNS_RR_A_LEN) { @@ -748,6 +932,18 @@ int _dns_decode_an(struct dns_context *context) } switch (qtype) { + case DNS_T_CNAME: { + char cname[DNS_MAX_CNAME_LEN]; + ret = _dns_decode_CNAME(context, cname, DNS_MAX_CNAME_LEN); + if (ret < 0) { + return -1; + } + + ret = dns_add_CNAME(packet, domain, ttl, cname); + if (ret < 0) { + return -1; + } + } break; case DNS_T_A: { unsigned char addr[DNS_RR_A_LEN]; ret = _dns_decode_A(context, addr); @@ -773,6 +969,7 @@ int _dns_decode_an(struct dns_context *context) } } break; default: + context->ptr += rr_len; break; } diff --git a/dns.h b/dns.h index b1de1ea..64662e3 100644 --- a/dns.h +++ b/dns.h @@ -22,7 +22,7 @@ #define DNS_RRS_NS 2 #define DNS_RRS_NR 3 -#define DNS_RR_END (-1) +#define DNS_RR_END (0XFFFF) typedef enum dns_class { DNS_C_IN = 1, DNS_C_ANY = 255 } dns_class_t; @@ -118,6 +118,10 @@ struct dns_rrs *dns_get_rrs_next(struct dns_packet *packet, struct dns_rrs *rrs) struct dns_rrs *dns_get_rrs_start(struct dns_packet *packet, int type, int *count); +int dns_add_CNAME(struct dns_packet *packet, char *domain, int ttl, char *cname); + +int dns_get_CNAME(struct dns_rrs *rrs, char *domain, int maxsize, int *ttl, char *cname, int cname_size); + int dns_add_A(struct dns_packet *packet, char *domain, int ttl, unsigned char addr[4]); int dns_get_A(struct dns_rrs *rrs, char *domain, int maxsize, int *ttl, unsigned char addr[4]); diff --git a/dns_server.c b/dns_server.c index 8e4e6d4..245938f 100644 --- a/dns_server.c +++ b/dns_server.c @@ -46,7 +46,7 @@ static void tv_sub(struct timeval *out, struct timeval *in) void _dns_server_period_run() { - return; + unsigned char packet_data[DNS_INPACKET_SIZE]; unsigned char data[DNS_INPACKET_SIZE]; @@ -64,7 +64,7 @@ void _dns_server_period_run() socklen_t to_len = sizeof(to); dns_packet_init(packet, DNS_INPACKET_SIZE, &head); - dns_add_domain(packet, "www.baidu.com", 1, 1); + dns_add_domain(packet, "www.huawei.com", 1, 1); len = dns_encode(data, DNS_INPACKET_SIZE, packet); memset(&to, 0, sizeof(to)); @@ -75,8 +75,6 @@ void _dns_server_period_run() if (len < 0) { printf("send failed."); } - - printf("send %d\n", len); } static int _dns_server_process(struct timeval *now) @@ -97,6 +95,7 @@ static int _dns_server_process(struct timeval *now) len = dns_decode(packet, DNS_INPACKET_SIZE, inpacket, len); if (len) { printf("decode failed.\n"); + return 0; goto errout; } @@ -116,6 +115,11 @@ static int _dns_server_process(struct timeval *now) dns_get_A(rrs, name, 128, &ttl, addr); printf("%s %d : %d.%d.%d.%d\n", name, ttl, addr[0], addr[1], addr[2], addr[3]); } break; + case DNS_T_CNAME: { + char cname[128]; + dns_get_CNAME(rrs, name, 128, &ttl, cname, 128); + printf("%s %d : %s\n", name, ttl, cname); + } break; default: break; } @@ -133,6 +137,7 @@ static int _dns_server_process(struct timeval *now) } } + printf("\n"); return 0; errout: return -1; @@ -157,7 +162,7 @@ int dns_server_run(void) last = now; } - num = epoll_wait(server.epoll_fd, events, DNS_MAX_EVENTS, 100); + num = epoll_wait(server.epoll_fd, events, DNS_MAX_EVENTS, 1000); if (num < 0) { gettimeofday(&now, 0); usleep(100000);