From b8cb84fc4a8faf1bd7ade1de3efe33d168fb8065 Mon Sep 17 00:00:00 2001 From: Nick Peng Date: Sat, 12 May 2018 00:24:18 +0800 Subject: [PATCH] Update code --- dns.c | 482 ++++++++++++++++++++++++++++++++++++--------------- dns.h | 30 ++-- dns_server.c | 1 + 3 files changed, 364 insertions(+), 149 deletions(-) diff --git a/dns.c b/dns.c index af1dcaf..06c893c 100644 --- a/dns.c +++ b/dns.c @@ -6,217 +6,425 @@ short dns_read_short(unsigned char **buffer) { - unsigned short value; + unsigned short value; - value = *((unsigned short *)(*buffer)); - *buffer += 2; + value = *((unsigned short *)(*buffer)); + *buffer += 2; - return ntohs(value); + return ntohs(value); } void dns_write_char(unsigned char **buffer, unsigned char value) { - **buffer = value; - *buffer += 1; + **buffer = value; + *buffer += 1; } unsigned char dns_read_char(unsigned char **buffer) { - unsigned char value = **buffer; - *buffer += 1; - return value; + unsigned char value = **buffer; + *buffer += 1; + return value; } void dns_write_short(unsigned char **buffer, unsigned short value) { - value = htons(value); - *((unsigned short *)(*buffer)) = value; - *buffer += 2; + value = htons(value); + *((unsigned short *)(*buffer)) = value; + *buffer += 2; } void dns_write_int(unsigned char **buffer, unsigned int value) { - value = htons(value); - *((unsigned int *)(*buffer)) = value; - *buffer += 4; + value = htons(value); + *((unsigned int *)(*buffer)) = value; + *buffer += 4; } unsigned int dns_read_int(unsigned char **buffer) { - unsigned int value; + unsigned int value; - value = *((unsigned int *)(*buffer)); - *buffer += 4; + value = *((unsigned int *)(*buffer)); + *buffer += 4; - return ntohs(value); + return ntohs(value); +} + +struct dns_rrs *dns_rr_get_start(struct dns_packet *packet, int type, int *count) +{ + unsigned short start; + struct dns_head *head = &packet->head; + + switch (type) { + case DNS_RR_QD: + *count = head->qdcount; + start = packet->questions; + break; + case DNS_RR_AN: + *count = head->ancount; + start = packet->answers; + break; + case DNS_RR_NS: + *count = head->nscount; + start = packet->nameservers; + break; + case DNS_RR_NR: + *count = head->nrcount; + start = packet->additional; + break; + default: + return NULL; + break; + } + + return (struct dns_rrs *)(packet->data + start); +} + +struct dns_rrs *dns_rr_get_next(struct dns_packet *packet, struct dns_rrs *rrs) +{ + if (rrs->next == 0) { + return NULL; + } + + return (struct dns_rrs *)(packet->data + rrs->next); +} + +unsigned char *dns_rr_add_start(struct dns_packet *packet, int *maxlen) +{ + struct dns_rrs *rrs; + unsigned char *end = packet->data + packet->len; + rrs = (struct dns_rrs *)end; + *maxlen = packet->size - packet->len - sizeof(*packet); + if (packet->len >= packet->size - sizeof(*packet)) { + return NULL; + } + return rrs->data; +} + +int dns_rr_add_end(struct dns_packet *packet, int type, int len) +{ + struct dns_rrs *rrs; + struct dns_head *head = &packet->head; + unsigned char *end = packet->data + packet->len; + rrs = (struct dns_rrs *)end; + unsigned short *count; + unsigned short *start; + + if (packet->len + len > packet->size - sizeof(*packet)) { + return -1; + } + + switch (type) { + case DNS_RR_QD: + count = &head->qdcount; + start = &packet->questions; + break; + case DNS_RR_AN: + count = &head->ancount; + start = &packet->answers; + break; + case DNS_RR_NS: + count = &head->nscount; + start = &packet->nameservers; + break; + case DNS_RR_NR: + count = &head->nrcount; + start = &packet->additional; + break; + default: + return -1; + break; + } + + *count += 1; + rrs->next = *start; + rrs->len = len; + *start = packet->len; + packet->len += len; + return 0; } int dns_decode_head(struct dns_head *head, unsigned char **data) { - unsigned int fields; - unsigned char *start = *data; - unsigned char *end = start; + unsigned int fields; + unsigned char *start = *data; + unsigned char *end = start; - head->id = dns_read_short(data); - fields = dns_read_short(data); - head->qr = (fields & QR_MASK) >> 15; - head->opcode = (fields & OPCODE_MASK) >> 11; - head->aa = (fields & AA_MASK) >> 10; - head->tc = (fields & TC_MASK) >> 9; - head->rd = (fields & RD_MASK) >> 8; - head->ra = (fields & RA_MASK) >> 7; - head->rcode = (fields & RCODE_MASK) >> 0; - head->qdcount = dns_read_short(data); - head->ancount = dns_read_short(data); - head->nscount = dns_read_short(data); - head->nrcount = dns_read_short(data); + head->id = dns_read_short(data); + fields = dns_read_short(data); + head->qr = (fields & QR_MASK) >> 15; + head->opcode = (fields & OPCODE_MASK) >> 11; + head->aa = (fields & AA_MASK) >> 10; + head->tc = (fields & TC_MASK) >> 9; + head->rd = (fields & RD_MASK) >> 8; + head->ra = (fields & RA_MASK) >> 7; + head->rcode = (fields & RCODE_MASK) >> 0; + head->qdcount = dns_read_short(data); + head->ancount = dns_read_short(data); + head->nscount = dns_read_short(data); + head->nrcount = dns_read_short(data); - end = *data; - return start - end; + end = *data; + return start - end; } int dns_encode_head(unsigned char **data, struct dns_head *head) { - dns_write_short(data, head->id); + dns_write_short(data, head->id); - int fields = 0; - fields |= (head->qr << 15) & QR_MASK; - fields |= (head->rcode << 0) & RCODE_MASK; - dns_write_short(data, fields); + int fields = 0; + fields |= (head->qr << 15) & QR_MASK; + fields |= (head->rcode << 0) & RCODE_MASK; + dns_write_short(data, fields); - dns_write_short(data, head->qdcount); - dns_write_short(data, head->ancount); - dns_write_short(data, head->nscount); - dns_write_short(data, head->nrcount); - return 0; + dns_write_short(data, head->qdcount); + dns_write_short(data, head->ancount); + dns_write_short(data, head->nscount); + dns_write_short(data, head->nrcount); + return 0; } -int dns_get_domain(char *output, int size, unsigned char *data) +int dns_decode_domain(char *output, int size, unsigned char *data) { - int i = 0; - int output_len = 0; - int copy_len = 0; - int total_len = 0; + int i = 0; + int output_len = 0; + int copy_len = 0; + int total_len = 0; - while (data[i]) { - int len = data[i]; + while (data[i]) { + int len = data[i]; - if (i != 0) { - *output = '.'; - output++; - } + if (i != 0) { + *output = '.'; + output++; + } - i++; - total_len++; - if (output_len < size - 1) { - copy_len = (len < size - output_len) ? len : size - 1 - output_len; - memcpy(output, data + i, copy_len); - } - i += len; - output += len; - output_len += len; - total_len += len; - } + i++; + total_len++; + if (output_len < size - 1) { + copy_len = (len < size - output_len) ? len : size - 1 - output_len; + memcpy(output, data + i, copy_len); + } + i += len; + output += len; + output_len += len; + total_len += len; + } - *output = 0; - total_len++; - return total_len; + *output = 0; + total_len++; + return total_len; } int dns_encode_domain(unsigned char *output, int size, char *domain) { - int i = 0; - int num = 0; - int total_len = 0; - unsigned char *ptr_num = output++; - total_len++; - while (i < size && *domain != 0) { - if (*domain == '.') { - *ptr_num = num; - num = 0; - ptr_num = output; - domain++; - output++; - total_len++; - continue; - } - *output = *domain; - num++; - output++; - domain++; - total_len++; - } - *ptr_num = num; - *output = 0; - total_len++; - return total_len; + int i = 0; + int num = 0; + int total_len = 0; + unsigned char *ptr_num = output++; + total_len++; + while (i < size && *domain != 0) { + if (*domain == '.') { + *ptr_num = num; + num = 0; + ptr_num = output; + domain++; + output++; + total_len++; + continue; + } + *output = *domain; + num++; + output++; + domain++; + total_len++; + } + *ptr_num = num; + *output = 0; + total_len++; + return total_len; } int dns_decode_qd(unsigned char *data, int size, char *domain, int domain_size, int *qtype, int *qclass) { - int len = 0; - len = dns_get_domain(domain, domain_size, data); - data += len; - *qtype = dns_read_short(&data); - *qclass = dns_read_short(&data); + int len = 0; + len = dns_decode_domain(domain, domain_size, data); + data += len; + *qtype = dns_read_short(&data); + *qclass = dns_read_short(&data); - return len; + return len; +} + +int dns_add_domain(struct dns_packet *packet, char *domain, int qtype, int qclass) +{ + int maxlen = 0; + int i; + int len = 0; + unsigned char *data = dns_rr_add_start(packet, &maxlen); + unsigned char *data_ptr = data; + if (data == NULL) { + return -1; + } + + for (i = 0; i < maxlen - 4; i++) { + *data = *domain; + if (*domain == '\0') { + data++; + i++; + break; + } + data++; + domain++; + } + len += i; + *((unsigned short *)(data)) = qtype; + data += 2; + len += 2; + + *((unsigned short *)(data)) = qclass; + data += 2; + len += 2; + + return dns_rr_add_end(packet, DNS_RR_QD, len); +} + +int dns_add_A(struct dns_packet *packet, char *domain, int qtype, int qclass) +{ + int maxlen = 0; + int i; + int len = 0; + unsigned char *data = dns_rr_add_start(packet, &maxlen); + unsigned char *data_ptr = data; + if (data == NULL) { + return -1; + } + + + return dns_rr_add_end(packet, DNS_RR_AN, len); +} + +int dns_add_AAAA(struct dns_packet *packet, char *domain, int qtype, int qclass) +{ + int maxlen = 0; + int i; + int len = 0; + unsigned char *data = dns_rr_add_start(packet, &maxlen); + unsigned char *data_ptr = data; + if (data == NULL) { + return -1; + } + + + return dns_rr_add_end(packet, DNS_RR_AN, len); +} + +int dns_get_domain(struct dns_rrs *rrs, char *domain, int maxsize, int *qtype, int *qclass) +{ + int i = 0; + unsigned char *data = rrs->data; + for (i = 0; i < maxsize; i++) { + *domain = *data; + if (*data == '\0') { + domain++; + data++; + break; + } + *domain = *data; + domain++; + data++; + } + + *qtype = *((unsigned short *)(data)); + data += 2; + + *qclass = *((unsigned short *)(data)); + data += 2; + + return 0; } int dns_decode_body(struct dns_packet *packet, unsigned char *data, int size) { - struct dns_head *head = &packet->head; - int i = 0; - int len = 0; - int decode_len = 0; - int qtype = 0; - int qclass = 0; - char name[DNS_MAX_CNAME_LEN]; + struct dns_head *head = &packet->head; + int i = 0; + int len = 0; + int decode_len = 0; + int qtype = 0; + int qclass = 0; + char name[DNS_MAX_CNAME_LEN]; - if (head->nrcount || head->nscount || head->ancount) { - return -1; - } + if (head->nrcount || head->nscount || head->ancount) { + return -1; + } - for (i = 0; i < head->qdcount; i++) { - len = dns_decode_qd(data, size - decode_len, name, DNS_MAX_CNAME_LEN, &qtype, &qclass); - printf("QR: %d, domain: %s, qtype = %d, qclass = %d\n", head->qr, name, qtype, qclass); - decode_len += len; - data += len; - } + for (i = 0; i < head->qdcount; i++) { + len = dns_decode_qd(data, size - decode_len, name, DNS_MAX_CNAME_LEN, &qtype, &qclass); + if (dns_add_domain(packet, name, qtype, qclass) != 0) { + return -1; + } + head->qdcount--; + decode_len += len; + data += len; + } - return 0; + return 0; } int dns_decode(struct dns_packet *packet, unsigned char *data, int size) { - struct dns_head *head = &packet->head; - int decode_len = 0; - int ret = 0; + struct dns_head *head = &packet->head; + int decode_len = 0; + int ret = 0; - decode_len = dns_decode_head(head, &data); - ret = dns_decode_body(packet, data, size - decode_len); - return ret; + decode_len = dns_decode_head(head, &data); + ret = dns_decode_body(packet, data, size - decode_len); + + struct dns_rrs *rrs; + int count = 0; + int i = 0; + + rrs = dns_rr_get_start(packet, DNS_RR_QD, &count); + for (i = 0; i < count && rrs; i++, rrs = dns_rr_get_next(packet, rrs)) { + char name[128]; + int qclass = 0; + int qtype = 0; + + dns_get_domain(rrs, name, 128, &qtype, &qclass); + + printf("QR: %d, domain: %s, qtype = %d, qclass = %d\n", head->qr, name, qtype, qclass); + } + + return ret; } +int dns_packet_init(struct dns_packet *packet, int size) +{ + memset(packet, 0, size); + packet->size = size; + + return 0; +} int dns_encode(unsigned char *data, int size, struct dns_packet *packet) { - int rc; - int len = 0; + int rc; + int len = 0; - len = dns_encode_head(&data, &packet->head); + len = dns_encode_head(&data, &packet->head); - while (1) { - len = dns_encode_domain(data, size, "www.baidu.com"); - data += len; - dns_write_short(&data, /*qType*/12); - dns_write_short(&data, /*qClass*/ 1); - } + while (1) { + len = dns_encode_domain(data, size, "www.baidu.com"); + data += len; + dns_write_short(&data, /*qType*/ 12); + dns_write_short(&data, /*qClass*/ 1); + } - /* + /* rc |= dns_encode_resource_records(packet->answers, data); rc |= dns_encode_resource_records(packet->nameservers, data); rc |= dns_encode_resource_records(packet->additional, data); */ - return rc; + return rc; } \ No newline at end of file diff --git a/dns.h b/dns.h index 24625e5..2f0c5b5 100644 --- a/dns.h +++ b/dns.h @@ -79,14 +79,6 @@ struct dns_qds { unsigned short classes; }; -struct dns_rrs { - unsigned short type; - unsigned short classes; - unsigned int ttl; - unsigned short rd_length; - char rd_data[0]; -}; - typedef uint32_t TTL; typedef struct dns_question_t /* RFC-1035 */ @@ -148,12 +140,26 @@ typedef union dns_answer_t { dns_aaaa_t aaaa; } dns_answer_t; +#define DNS_RR_QD 0 +#define DNS_RR_AN 1 +#define DNS_RR_NS 2 +#define DNS_RR_NR 3 + +struct dns_rrs { + unsigned short next; + unsigned short len; + unsigned char data[0]; +}; + struct dns_packet { struct dns_head head; - dns_question_t *questions; - dns_answer_t *answers; - dns_answer_t *nameservers; - dns_answer_t *additional; + unsigned short questions; + unsigned short answers; + unsigned short nameservers; + unsigned short additional; + int size; + int len; + unsigned char data[0]; }; int dns_decode(struct dns_packet *packet, unsigned char *data, int size); diff --git a/dns_server.c b/dns_server.c index e248c29..fa8dbea 100644 --- a/dns_server.c +++ b/dns_server.c @@ -63,6 +63,7 @@ static int _dns_server_process(struct timeval *now) goto errout; } + dns_packet_init(packet, sizeof(rsppacket)); dns_decode(packet, inpacket, len); printf("head.id = %d\n", packet->head.id);