diff --git a/dns.c b/dns.c index 4bf568a..7f430cb 100644 --- a/dns.c +++ b/dns.c @@ -6,256 +6,359 @@ short dns_read_short(unsigned char **buffer) { - unsigned short value; + unsigned short value; - value = *((unsigned short *)(*buffer)); - *buffer += 2; + value = *((unsigned short *)(*buffer)); + *buffer += 2; - return ntohs(value); + return ntohs(value); } void dns_write_char(unsigned char **buffer, unsigned char value) { - **buffer = value; - *buffer += 1; + **buffer = value; + *buffer += 1; } unsigned char dns_read_char(unsigned char **buffer) { - unsigned char value = **buffer; - *buffer += 1; - return value; + unsigned char value = **buffer; + *buffer += 1; + return value; } void dns_write_short(unsigned char **buffer, unsigned short value) { - value = htons(value); - *((unsigned short *)(*buffer)) = value; - *buffer += 2; + value = htons(value); + *((unsigned short *)(*buffer)) = value; + *buffer += 2; } void dns_write_int(unsigned char **buffer, unsigned int value) { - value = htons(value); - *((unsigned int *)(*buffer)) = value; - *buffer += 4; + value = htons(value); + *((unsigned int *)(*buffer)) = value; + *buffer += 4; } unsigned int dns_read_int(unsigned char **buffer) { - unsigned int value; + unsigned int value; - value = *((unsigned int *)(*buffer)); - *buffer += 4; + value = *((unsigned int *)(*buffer)); + *buffer += 4; - return ntohs(value); + return ntohs(value); } -struct dns_rrs *dns_rr_get_start(struct dns_packet *packet, int type, int *count) +struct dns_rrs *dns_get_rrs_start(struct dns_packet *packet, int type, int *count) { - unsigned short start; - struct dns_head *head = &packet->head; + unsigned short start; + struct dns_head *head = &packet->head; - switch (type) { - case DNS_RR_QD: - *count = head->qdcount; - start = packet->questions; - break; - case DNS_RR_AN: - *count = head->ancount; - start = packet->answers; - break; - case DNS_RR_NS: - *count = head->nscount; - start = packet->nameservers; - break; - case DNS_RR_NR: - *count = head->nrcount; - start = packet->additional; - break; - default: - return NULL; - break; - } + switch (type) { + case DNS_RRS_QD: + *count = head->qdcount; + start = packet->questions; + break; + case DNS_RRS_AN: + *count = head->ancount; + start = packet->answers; + break; + case DNS_RRS_NS: + *count = head->nscount; + start = packet->nameservers; + break; + case DNS_RRS_NR: + *count = head->nrcount; + start = packet->additional; + break; + default: + return NULL; + break; + } - return (struct dns_rrs *)(packet->data + start); + return (struct dns_rrs *)(packet->data + start); } -struct dns_rrs *dns_rr_get_next(struct dns_packet *packet, struct dns_rrs *rrs) +struct dns_rrs *dns_get_rrs_next(struct dns_packet *packet, struct dns_rrs *rrs) { - if (rrs->next == 0) { - return NULL; - } + if (rrs->next == DNS_RR_END) { + return NULL; + } - return (struct dns_rrs *)(packet->data + rrs->next); + return (struct dns_rrs *)(packet->data + rrs->next); } -unsigned char *dns_rr_add_start(struct dns_packet *packet, int *maxlen) +unsigned char *_dns_add_rrs_start(struct dns_packet *packet, int *maxlen) { - struct dns_rrs *rrs; - unsigned char *end = packet->data + packet->len; - rrs = (struct dns_rrs *)end; - *maxlen = packet->size - packet->len - sizeof(*packet); - if (packet->len >= packet->size - sizeof(*packet)) { - return NULL; - } - return rrs->data; + struct dns_rrs *rrs; + unsigned char *end = packet->data + packet->len; + rrs = (struct dns_rrs *)end; + *maxlen = packet->size - packet->len - sizeof(*packet); + if (packet->len >= packet->size - sizeof(*packet)) { + return NULL; + } + return rrs->data; } int dns_rr_add_end(struct dns_packet *packet, int type, dns_type_t rrtype, int len) { - struct dns_rrs *rrs; - struct dns_head *head = &packet->head; - unsigned char *end = packet->data + packet->len; - rrs = (struct dns_rrs *)end; - unsigned short *count; - unsigned short *start; + struct dns_rrs *rrs; + struct dns_head *head = &packet->head; + unsigned char *end = packet->data + packet->len; + rrs = (struct dns_rrs *)end; + unsigned short *count; + unsigned short *start; - if (packet->len + len > packet->size - sizeof(*packet)) { - return -1; - } + len += sizeof(*rrs); + if (packet->len + len > packet->size - sizeof(*packet)) { + return -1; + } - switch (type) { - case DNS_RR_QD: - count = &head->qdcount; - start = &packet->questions; - break; - case DNS_RR_AN: - count = &head->ancount; - start = &packet->answers; - break; - case DNS_RR_NS: - count = &head->nscount; - start = &packet->nameservers; - break; - case DNS_RR_NR: - count = &head->nrcount; - start = &packet->additional; - break; - default: - return -1; - break; - } + switch (type) { + case DNS_RRS_QD: + count = &head->qdcount; + start = &packet->questions; + break; + case DNS_RRS_AN: + count = &head->ancount; + start = &packet->answers; + break; + case DNS_RRS_NS: + count = &head->nscount; + start = &packet->nameservers; + break; + case DNS_RRS_NR: + count = &head->nrcount; + start = &packet->additional; + break; + default: + return -1; + break; + } - *count += 1; - rrs->next = *start; - rrs->len = len; + *count += 1; + rrs->next = *start; + rrs->len = len; rrs->type = rrtype; *start = packet->len; - packet->len += len; - return 0; + packet->len += len; + return sizeof(*rrs) + len; } -int dns_decode_head(struct dns_head *head, unsigned char **data) +int _dns_add_qr_head(unsigned char *data, int maxlen, char *domain, int qtype, int qclass) { - unsigned int fields; - unsigned char *start = *data; - unsigned char *end = start; + int i; + int len = 0; - head->id = dns_read_short(data); - fields = dns_read_short(data); - head->qr = (fields & QR_MASK) >> 15; - head->opcode = (fields & OPCODE_MASK) >> 11; - head->aa = (fields & AA_MASK) >> 10; - head->tc = (fields & TC_MASK) >> 9; - head->rd = (fields & RD_MASK) >> 8; - head->ra = (fields & RA_MASK) >> 7; - head->rcode = (fields & RCODE_MASK) >> 0; - head->qdcount = dns_read_short(data); - head->ancount = dns_read_short(data); - head->nscount = dns_read_short(data); - head->nrcount = dns_read_short(data); + for (i = 0; i < maxlen; i++) { + *data = *domain; + if (*domain == '\0') { + data++; + i++; + break; + } + data++; + domain++; + } + len += i; - end = *data; - return start - end; + if (maxlen - len < 4) { + return -1; + } + + *((unsigned short *)(data)) = qtype; + data += 2; + len += 2; + + *((unsigned short *)(data)) = qclass; + data += 2; + len += 2; + + return len; } -int dns_encode_head(unsigned char **data, struct dns_head *head) +int _dns_get_qr_head(unsigned char *data, char *domain, int maxsize, int *qtype, int *qclass) { - dns_write_short(data, head->id); + int i; + int len = 0; + for (i = 0; i < maxsize; i++) { + *domain = *data; + if (*data == '\0') { + domain++; + data++; + i++; + break; + } + *domain = *data; + domain++; + data++; + } + len += i; + if (len >= maxsize) { + return -1; + } - int fields = 0; - fields |= (head->qr << 15) & QR_MASK; - fields |= (head->rcode << 0) & RCODE_MASK; - dns_write_short(data, fields); + *qtype = *((unsigned short *)(data)); + data += 2; + len += 2; - dns_write_short(data, head->qdcount); - dns_write_short(data, head->ancount); - dns_write_short(data, head->nscount); - dns_write_short(data, head->nrcount); - return 0; + *qclass = *((unsigned short *)(data)); + data += 2; + len += 2; + + return len; } -int dns_decode_domain(char *output, int size, unsigned char *data) +int _dns_add_rr_head(unsigned char *data, int maxlen, char *domain, int qtype, int qclass, int ttl, int rr_len) { - int i = 0; - int output_len = 0; - int copy_len = 0; - int total_len = 0; + int len = 0; - while (data[i]) { - int len = data[i]; + len = _dns_add_qr_head(data, maxlen, domain, qtype, qclass); + if (len < 0) { + return -1; + } + data += len; + if (maxlen - len < 6) { + return -1; + } - if (i != 0) { - *output = '.'; - output++; - } + *((unsigned int *)(data)) = ttl; + data += 4; + len += 4; - i++; - total_len++; - if (output_len < size - 1) { - copy_len = (len < size - output_len) ? len : size - 1 - output_len; - memcpy(output, data + i, copy_len); - } - i += len; - output += len; - output_len += len; - total_len += len; - } + *((unsigned short *)(data)) = rr_len; + data += 2; + len += 2; - *output = 0; - total_len++; - return total_len; + return len; } -int dns_encode_domain(unsigned char *output, int size, char *domain) +int _dns_get_rr_head(unsigned char *data, char *domain, int maxsize, int *qtype, int *qclass, int *ttl, int *rr_len) { - int i = 0; - int num = 0; - int total_len = 0; - unsigned char *ptr_num = output++; - total_len++; - while (i < size && *domain != 0) { - if (*domain == '.') { - *ptr_num = num; - num = 0; - ptr_num = output; - domain++; - output++; - total_len++; - continue; - } - *output = *domain; - num++; - output++; - domain++; - total_len++; - } - *ptr_num = num; - *output = 0; - total_len++; - return total_len; + int len = 0; + + len = _dns_get_qr_head(data, domain, maxsize, qtype, qclass); + data += len; + + *ttl = *((unsigned int *)(data)); + data += 4; + len += 4; + + *rr_len = *((unsigned short *)(data)); + data += 2; + len += 2; + + return len; } -int dns_decode_qd(unsigned char *data, int size, char *domain, int domain_size, int *qtype, int *qclass) +int dns_add_A(struct dns_packet *packet, char *domain, int ttl, unsigned char addr[4]) { - int len = 0; - len = dns_decode_domain(domain, domain_size, data); - data += len; - *qtype = dns_read_short(&data); - *qclass = dns_read_short(&data); + int maxlen = 0; + int len = 0; - return len; + unsigned char *data = _dns_add_rrs_start(packet, &maxlen); + if (data == NULL) { + return -1; + } + + len = _dns_add_rr_head(data, maxlen, domain, DNS_T_A, DNS_C_IN, ttl, DNS_RR_A_LEN); + if (len < 0) { + return -1; + } + data += len; + + memcpy(data, addr, DNS_RR_A_LEN); + data += DNS_RR_A_LEN; + len += DNS_RR_A_LEN; + + return dns_rr_add_end(packet, DNS_RRS_AN, DNS_T_A, len); +} + +int dns_get_A(struct dns_rrs *rrs, char *domain, int maxsize, int *ttl, unsigned char addr[4]) +{ + int qtype = 0; + int qclass = 0; + int rr_len = 0; + int len = 0; + int total_len = 0; + + unsigned char *data = rrs->data; + + if (rrs->type != DNS_T_A) { + return -1; + } + + len = _dns_get_rr_head(data, domain, maxsize, &qtype, &qclass, ttl, &rr_len); + if (len <= 0) { + return -1; + } + data += len; + total_len += len; + + if (qtype != DNS_T_A || rr_len != DNS_RR_A_LEN) { + return -1; + } + + memcpy(addr, data, DNS_RR_A_LEN); + total_len += rr_len; + data += rr_len; + + return total_len; +} + +int dns_add_AAAA(struct dns_packet *packet, char *domain, int ttl, unsigned char addr[16]) +{ + int maxlen = 0; + int len = 0; + unsigned char *data = _dns_add_rrs_start(packet, &maxlen); + if (data == NULL) { + return -1; + } + + len = _dns_add_rr_head(data, maxlen, domain, DNS_T_AAAA, DNS_C_IN, ttl, DNS_RR_AAAA_LEN); + if (len < 0) { + return -1; + } + data += len; + + memcpy(data, addr, DNS_RR_AAAA_LEN); + data += DNS_RR_AAAA_LEN; + len += DNS_RR_AAAA_LEN; + + return dns_rr_add_end(packet, DNS_RRS_AN, DNS_T_AAAA, len); +} + +int dns_get_AAAA(struct dns_rrs *rrs, char *domain, int maxsize, int *ttl, unsigned char addr[16]) +{ + int qtype = 0; + int qclass = 0; + int rr_len = 0; + int len = 0; + int total_len = 0; + + if (rrs->type != DNS_T_AAAA) { + return -1; + } + + unsigned char *data = rrs->data; + + len = _dns_get_rr_head(data, domain, maxsize, &qtype, &qclass, ttl, &rr_len); + if (len <= 0) { + return -1; + } + data += len; + total_len += len; + + if (qtype != DNS_T_AAAA || rr_len != DNS_RR_AAAA_LEN) { + return -1; + } + + memcpy(addr, rrs->data, DNS_RR_AAAA_LEN); + total_len += DNS_RR_AAAA_LEN; + + return total_len; } /* @@ -264,175 +367,572 @@ int dns_decode_qd(unsigned char *data, int size, char *domain, int domain_size, */ int dns_add_domain(struct dns_packet *packet, char *domain, int qtype, int qclass) { - int maxlen = 0; - int i; - int len = 0; - unsigned char *data = dns_rr_add_start(packet, &maxlen); + int len = 0; + int maxlen = 0; + unsigned char *data = _dns_add_rrs_start(packet, &maxlen); - if (data == NULL) { - return -1; - } + if (data == NULL) { + return -1; + } - for (i = 0; i < maxlen - 4; i++) { - *data = *domain; - if (*domain == '\0') { - data++; - i++; - break; - } - data++; - domain++; - } - len += i; - *((unsigned short *)(data)) = qtype; - data += 2; - len += 2; + len = _dns_add_qr_head(data, maxlen, domain, qtype, qclass); + if (len < 0) { + return -1; + } - *((unsigned short *)(data)) = qclass; - data += 2; - len += 2; - - return dns_rr_add_end(packet, DNS_RR_QD, DNS_T_CNAME, len); -} - -int dns_add_A(struct dns_packet *packet, unsigned char addr[4]) -{ - int maxlen = 0; - int len = 0; - unsigned char *data = dns_rr_add_start(packet, &maxlen); - unsigned char *data_ptr = data; - if (data == NULL) { - return -1; - } - - memcpy(data, addr, 4); - data += 4; - len += 4; - - return dns_rr_add_end(packet, DNS_RR_AN, DNS_T_A, len); -} - -int dns_add_AAAA(struct dns_packet *packet, unsigned char addr[16]) -{ - int maxlen = 0; - int len = 0; - unsigned char *data = dns_rr_add_start(packet, &maxlen); - if (data == NULL) { - return -1; - } - - memcpy(data, addr, 4); - data += 4; - len += 4; - - return dns_rr_add_end(packet, DNS_RR_AN, DNS_T_AAAA, len); + return dns_rr_add_end(packet, DNS_RRS_QD, DNS_T_CNAME, len); } int dns_get_domain(struct dns_rrs *rrs, char *domain, int maxsize, int *qtype, int *qclass) { - int i = 0; - unsigned char *data = rrs->data; - for (i = 0; i < maxsize; i++) { - *domain = *data; - if (*data == '\0') { - domain++; - data++; - break; - } - *domain = *data; - domain++; - data++; - } + if (rrs->type != DNS_T_CNAME) { + return -1; + } - *qtype = *((unsigned short *)(data)); - data += 2; - - *qclass = *((unsigned short *)(data)); - data += 2; - - return 0; + return _dns_get_qr_head(rrs->data, domain, maxsize, qtype, qclass); } -int dns_decode_body(struct dns_packet *packet, unsigned char *data, int size) +int _dns_decode_head(struct dns_head *head, unsigned char *data) { - struct dns_head *head = &packet->head; - int i = 0; - int len = 0; - int decode_len = 0; - int qtype = 0; - int qclass = 0; - char name[DNS_MAX_CNAME_LEN]; + unsigned int fields; + unsigned char *start = data; + unsigned char *end = data; - if (head->nrcount || head->nscount || head->ancount) { - return -1; - } + head->id = dns_read_short(&data); + fields = dns_read_short(&data); + head->qr = (fields & QR_MASK) >> 15; + head->opcode = (fields & OPCODE_MASK) >> 11; + head->aa = (fields & AA_MASK) >> 10; + head->tc = (fields & TC_MASK) >> 9; + head->rd = (fields & RD_MASK) >> 8; + head->ra = (fields & RA_MASK) >> 7; + head->rcode = (fields & RCODE_MASK) >> 0; + head->qdcount = dns_read_short(&data); + head->ancount = dns_read_short(&data); + head->nscount = dns_read_short(&data); + head->nrcount = dns_read_short(&data); - for (i = 0; i < head->qdcount; i++) { - len = dns_decode_qd(data, size - decode_len, name, DNS_MAX_CNAME_LEN, &qtype, &qclass); - if (dns_add_domain(packet, name, qtype, qclass) != 0) { - return -1; - } - head->qdcount--; - decode_len += len; - data += len; - } - - return 0; + end = data; + return end - start; } -int dns_decode(struct dns_packet *packet, unsigned char *data, int size) +int _dns_encode_head(unsigned char *data, int size, struct dns_head *head) { - struct dns_head *head = &packet->head; - int decode_len = 0; - int ret = 0; + int len = 12; - decode_len = dns_decode_head(head, &data); - ret = dns_decode_body(packet, data, size - decode_len); + if (size < len) { + return -1; + } - struct dns_rrs *rrs; - int count = 0; - int i = 0; + dns_write_short(&data, head->id); - rrs = dns_rr_get_start(packet, DNS_RR_QD, &count); - for (i = 0; i < count && rrs; i++, rrs = dns_rr_get_next(packet, rrs)) { - char name[128]; - int qclass = 0; - int qtype = 0; + int fields = 0; + fields |= (head->qr << 15) & QR_MASK; + fields |= (head->rcode << 0) & RCODE_MASK; + dns_write_short(&data, fields); - dns_get_domain(rrs, name, 128, &qtype, &qclass); - - printf("QR: %d, domain: %s, qtype = %d, qclass = %d\n", head->qr, name, qtype, qclass); - } - - return ret; + dns_write_short(&data, head->qdcount); + dns_write_short(&data, head->ancount); + dns_write_short(&data, head->nscount); + dns_write_short(&data, head->nrcount); + return len; } -int dns_packet_init(struct dns_packet *packet, int size) +int _dns_decode_domain(char *output, int size, unsigned char *data) { - memset(packet, 0, size); - packet->size = size; + int i = 0; + int output_len = 0; + int copy_len = 0; + int total_len = 0; - return 0; + while (data[i]) { + int len = data[i]; + + if (i != 0) { + *output = '.'; + output++; + } + + i++; + total_len++; + if (output_len < size - 1) { + copy_len = (len < size - output_len) ? len : size - 1 - output_len; + memcpy(output, data + i, copy_len); + } + i += len; + output += len; + output_len += len; + total_len += len; + } + + *output = 0; + total_len++; + return total_len; +} + +int _dns_encode_domain(unsigned char *output, int size, char *domain) +{ + int i = 0; + int num = 0; + int total_len = 0; + unsigned char *ptr_num = output++; + total_len++; + while (i < size && *domain != 0) { + if (*domain == '.') { + *ptr_num = num; + num = 0; + ptr_num = output; + domain++; + output++; + total_len++; + continue; + } + *output = *domain; + num++; + output++; + domain++; + total_len++; + } + *ptr_num = num; + *output = 0; + total_len++; + return total_len; +} + +int _dns_decode_qr_head(unsigned char *data, int size, char *domain, int domain_size, int *qtype, int *qclass) +{ + int len = 0; + len = _dns_decode_domain(domain, domain_size, data); + if (len <= 0) { + return -1; + } + + data += len; + *qtype = dns_read_short(&data); + len += 2; + *qclass = dns_read_short(&data); + len += 2; + + return len; +} + +int _dns_encode_qr_head(unsigned char *data, int size, char *domain, int qtype, int qclass) +{ + int len = 0; + len = _dns_encode_domain(data, size, domain); + if (len <= 0) { + return -1; + } + data += len; + + if (size - len < 4) { + return -1; + } + + dns_write_short(&data, qtype); + len += 2; + dns_write_short(&data, qclass); + len += 2; + + return len; +} + +int _dns_decode_rr_head(unsigned char *data, int size, char *domain, int domain_size, int *qtype, int *qclass, int *ttl, int *rr_len) +{ + int len = 0; + int total_len = 0; + + len = _dns_decode_qr_head(data, size, domain, domain_size, qtype, qclass); + if (len <= 0) { + return -1; + } + + data += len; + total_len += len; + + *ttl = dns_read_int(&data); + len += 4; + total_len += 4; + + *rr_len = dns_read_short(&data); + len += 2; + total_len += 2; + + return total_len; +} + +int _dns_encode_rr_head(unsigned char *data, int size, char *domain, int qtype, int qclass, int ttl, int rr_len) +{ + int len = 0; + int total_len = 0; + len = _dns_encode_qr_head(data, size, domain, qtype, qclass); + if (len <= 0) { + return -1; + } + + data += len; + total_len += len; + + if (size - len < 6) { + return -1; + } + + dns_write_int(&data, ttl); + len += 4; + total_len += 4; + + dns_write_short(&data, rr_len); + len += 2; + total_len += 2; + + return total_len; +} + +int _dns_decode_A(unsigned char addr[4], unsigned char *data) +{ + memcpy(addr, data, DNS_RR_A_LEN); + return DNS_RR_A_LEN; +} + +int _dns_encode_A(unsigned char *output, int size, struct dns_rrs *rrs) +{ + int len; + int len_rrs; + int total_len = 0; + int qtype = 0; + int qclass = 0; + int ttl = 0; + char domain[DNS_MAX_CNAME_LEN]; + unsigned char *data_rrs; + int rr_len; + + data_rrs = rrs->data; + len_rrs = _dns_get_rr_head(data_rrs, domain, DNS_MAX_CNAME_LEN, &qtype, &qclass, &ttl, &rr_len); + if (len_rrs <= 0) { + return -1; + } + data_rrs += len_rrs; + + if (rr_len != DNS_RR_A_LEN) { + return -1; + } + + len = _dns_encode_rr_head(output, size, domain, qtype, qclass, ttl, DNS_RR_A_LEN); + if (len <= 0) { + return -1; + } + output += len; + total_len += len; + + if (size - total_len < rr_len + DNS_RR_A_LEN) { + return -1; + } + + memcpy(output, data_rrs, DNS_RR_A_LEN); + output += DNS_RR_A_LEN; + data_rrs += DNS_RR_A_LEN; + total_len += DNS_RR_A_LEN; + + return total_len; +} + +int _dns_decode_AAAA(unsigned char addr[DNS_RR_AAAA_LEN], unsigned char *data) +{ + memcpy(addr, data, DNS_RR_AAAA_LEN); + return DNS_RR_AAAA_LEN; +} + +int _dns_encode_AAAA(unsigned char *output, int size, struct dns_rrs *rrs) +{ + int len; + int len_rrs; + int total_len = 0; + int qtype = 0; + int qclass = 0; + int ttl = 0; + char domain[DNS_MAX_CNAME_LEN]; + unsigned char *data_rrs; + int rr_len; + + data_rrs = rrs->data; + len_rrs = _dns_get_rr_head(data_rrs, domain, DNS_MAX_CNAME_LEN, &qtype, &qclass, &ttl, &rr_len); + if (len_rrs <= 0) { + return -1; + } + data_rrs += len_rrs; + + if (rr_len != DNS_RR_AAAA_LEN) { + return -1; + } + + len = _dns_encode_rr_head(output, size, domain, qtype, qclass, ttl, DNS_RR_AAAA_LEN); + if (len <= 0) { + return -1; + } + output += len; + total_len += len; + + if (size - total_len < rr_len + DNS_RR_AAAA_LEN) { + return -1; + } + + memcpy(output, data_rrs, DNS_RR_AAAA_LEN); + output += DNS_RR_AAAA_LEN; + data_rrs += DNS_RR_AAAA_LEN; + total_len += DNS_RR_AAAA_LEN; + + return total_len; +} + +int _dns_decode_qd(struct dns_packet *packet, unsigned char *data, int size) +{ + int len; + int decode_len = 0; + int qtype = 0; + int qclass = 0; + char domain[DNS_MAX_CNAME_LEN]; + + int ttl; + int rr_len; + len = _dns_decode_qr_head(data, size, domain, DNS_MAX_CNAME_LEN, &qtype, &qclass); + if (len <= 0) { + return -1; + } + decode_len += len; + + len = dns_add_domain(packet, domain, qtype, qclass); + if ( len <= 0 ) { + return -1; + } + + return decode_len; +} + +int _dns_decode_an(struct dns_packet *packet, unsigned char *data, int size) +{ + int len; + int qtype = 0; + int qclass = 0; + int ttl; + int rr_len = 0; + char domain[DNS_MAX_CNAME_LEN]; + int decode_len = 0; + + len = _dns_decode_rr_head(data, size, domain, DNS_MAX_CNAME_LEN, &qtype, &qclass, &ttl, &rr_len); + if (len <= 0) { + return -1; + } + + data += len; + decode_len += len; + switch (qtype) { + case DNS_T_A: { + unsigned char addr[DNS_RR_A_LEN]; + len = _dns_decode_A(addr, data); + if (len < 0) { + return -1; + } + data += len; + decode_len += len; + len = dns_add_A(packet, domain, ttl, addr); + if (len < 0) { + return -1; + } + } break; + case DNS_T_AAAA: { + unsigned char addr[DNS_RR_AAAA_LEN]; + len = _dns_decode_AAAA(addr, data); + if (len < 0) { + return -1; + } + data += len; + decode_len += len; + len = dns_add_AAAA(packet, domain, ttl, addr); + if (len < 0) { + return -1; + } + } break; + default: + break; + } + + return decode_len; +} + +int _dns_encode_qd(unsigned char *data, int size, struct dns_rrs *rrs) +{ + int len; + int len_rrs; + int qtype = 0; + int qclass = 0; + int total_len = 0; + char domain[DNS_MAX_CNAME_LEN]; + unsigned char *data_rrs = rrs->data; + len_rrs = _dns_get_qr_head(data_rrs, domain, DNS_MAX_CNAME_LEN, &qtype, &qclass); + if (len_rrs <= 0) { + return -1; + } + + len = _dns_encode_qr_head(data, size, domain, qtype, qclass); + if (len <= 0) { + return -1; + } + total_len += len; + + return total_len; +} + +int _dns_encode_an(unsigned char *data, int size, struct dns_rrs *rrs) +{ + int len; + int total_len = 0; + switch (rrs->type) { + case DNS_T_A: { + len = _dns_encode_A(data, size, rrs); + if (len < 0) { + return -1; + } + total_len += len; + } break; + case DNS_T_AAAA: + len = _dns_encode_AAAA(data, size, rrs); + if (len < 0) { + return -1; + } + total_len += len; + break; + default: + break; + } + + return total_len; +} + +int _dns_decode_body(struct dns_packet *packet, unsigned char *data, int size) +{ + struct dns_head *head = &packet->head; + int i = 0; + int len = 0; + int decode_len = 0; + + for (i = 0; i < head->qdcount; i++) { + len = _dns_decode_qd(packet, data, size - decode_len); + if (len <= 0) { + return -1; + } + head->qdcount--; + decode_len += len; + data += len; + } + + for (i = 0; i < head->ancount; i++) { + len = _dns_decode_an(packet, data, size - decode_len); + if (len <= 0) { + return -1; + } + head->ancount--; + decode_len += len; + data += len; + } + + return decode_len; +} + +int _dns_encode_body(unsigned char *data, int size, struct dns_packet *packet) +{ + struct dns_head *head = &packet->head; + int i = 0; + int len = 0; + int encode_len = 0; + struct dns_rrs *rrs; + int count; + + rrs = dns_get_rrs_start(packet, DNS_RRS_QD, &count); + head->qdcount = count; + for (i = 0; i < count && rrs; i++, rrs = dns_get_rrs_next(packet, rrs)) { + len = _dns_encode_qd(data, size - encode_len, rrs); + if (len <= 0) { + return -1; + } + encode_len += len; + data += len; + } + + rrs = dns_get_rrs_start(packet, DNS_RRS_AN, &count); + head->ancount = count; + for (i = 0; i < count && rrs; i++, rrs = dns_get_rrs_next(packet, rrs)) { + len = _dns_encode_an(data, size - encode_len, rrs); + if (len <= 0) { + return -1; + } + encode_len += len; + data += len; + } + + return encode_len; +} + +int dns_packet_init(struct dns_packet *packet, int size, struct dns_head *head) +{ + struct dns_head *init_head = &packet->head; + memset(packet, 0, size); + packet->size = size; + init_head->id = head->id; + init_head->qr = head->qr; + init_head->opcode = head->opcode; + init_head->aa = head->aa; + init_head->tc = 0; + init_head->rd = head->rd; + init_head->ra = head->ra; + init_head->rcode = head->rcode; + packet->questions = DNS_RR_END; + packet->answers = DNS_RR_END; + packet->nameservers = DNS_RR_END; + packet->additional = DNS_RR_END; + + return 0; +} + +int dns_decode(struct dns_packet *packet, int maxsize, unsigned char *data, int size) +{ + struct dns_head *head = &packet->head; + int decode_len = 0; + int len = 0; + + memset(packet, 0, sizeof(*packet)); + dns_packet_init(packet, maxsize, head); + len = _dns_decode_head(head, data); + if (len < 0) { + return -1; + } + data += len; + decode_len += len; + + len = _dns_decode_body(packet, data, size - decode_len); + if (len < 0) { + return -1; + } + decode_len += len; + + return decode_len; } int dns_encode(unsigned char *data, int size, struct dns_packet *packet) { - int rc; - int len = 0; + int len = 0; + int total_len = 0; - len = dns_encode_head(&data, &packet->head); + len = _dns_encode_head(data, size, &packet->head); + if (len <= 0) { + return -1; + } + data += len; + total_len += len; - while (1) { - len = dns_encode_domain(data, size, "www.baidu.com"); - data += len; - dns_write_short(&data, /*qType*/ 12); - dns_write_short(&data, /*qClass*/ 1); - } - - /* - rc |= dns_encode_resource_records(packet->answers, data); - rc |= dns_encode_resource_records(packet->nameservers, data); - rc |= dns_encode_resource_records(packet->additional, data); - */ - return rc; -} \ No newline at end of file + len = _dns_encode_body(data, size - len, packet); + if (len <= 0) { + return -1; + } + total_len += len; + return total_len; +} diff --git a/dns.h b/dns.h index f03df77..04cb04c 100644 --- a/dns.h +++ b/dns.h @@ -14,7 +14,15 @@ #define RA_MASK 0x8000 #define RCODE_MASK 0x000F -typedef enum dns_section { DNS_S_QD = 0x01, DNS_S_AN = 0x02, DNS_S_NS = 0x04, DNS_S_AR = 0x08, DNS_S_ALL = 0x0f } dns_section_t; +#define DNS_RR_A_LEN 4 +#define DNS_RR_AAAA_LEN 16 + +#define DNS_RRS_QD 0 +#define DNS_RRS_AN 1 +#define DNS_RRS_NS 2 +#define DNS_RRS_NR 3 + +#define DNS_RR_END (-1) typedef enum dns_class { DNS_C_IN = 1, DNS_C_ANY = 255 } dns_class_t; @@ -63,10 +71,10 @@ struct dns_head { unsigned short id; // identification number unsigned short qr; /* Query/Response Flag */ unsigned short opcode; /* Operation Code */ - unsigned short aa; /* Authoritative Answer Flag */ - unsigned short tc; /* Truncation Flag */ - unsigned short rd; /* Recursion Desired */ - unsigned short ra; /* Recursion Available */ + unsigned char aa; /* Authoritative Answer Flag */ + unsigned char tc; /* Truncation Flag */ + unsigned char rd; /* Recursion Desired */ + unsigned char ra; /* Recursion Available */ unsigned short rcode; /* Response Code */ unsigned short qdcount; // number of question entries unsigned short ancount; // number of answer entries @@ -74,77 +82,6 @@ struct dns_head { unsigned short nrcount; // number of addititional resource entries } __attribute__((packed)); -struct dns_qds { - unsigned short type; - unsigned short classes; -}; - -typedef uint32_t TTL; - -typedef struct dns_question_t /* RFC-1035 */ -{ - const char *name; - dns_type_t type; - dns_class_t class; -} dns_question_t; - -typedef struct dns_generic_t /* RFC-1035 */ -{ - const char *name; - dns_type_t type; - dns_class_t class; - TTL ttl; -} dns_generic_t; - -typedef struct dns_a_t /* RFC-1035 */ -{ - const char *name; - dns_type_t type; - dns_class_t class; - TTL ttl; - in_addr_t address; -} dns_a_t; - -typedef struct dns_aaaa_t /* RFC-1886 */ -{ - const char *name; - dns_type_t type; - dns_class_t class; - TTL ttl; - struct in6_addr address; -} dns_aaaa_t; - -typedef struct dns_cname_t /* RFC-1035 */ -{ - const char *name; - dns_type_t type; - dns_class_t class; - TTL ttl; - const char *cname; -} dns_cname_t; - -typedef struct dns_ptr_t /* RFC-1035 */ -{ - const char *name; - dns_type_t type; - dns_class_t class; - TTL ttl; - const char *ptr; -} dns_ptr_t; - -typedef union dns_answer_t { - dns_generic_t generic; - dns_a_t a; - dns_cname_t cname; - dns_ptr_t ptr; - dns_aaaa_t aaaa; -} dns_answer_t; - -#define DNS_RR_QD 0 -#define DNS_RR_AN 1 -#define DNS_RR_NS 2 -#define DNS_RR_NR 3 - struct dns_rrs { unsigned short next; unsigned short len; @@ -163,18 +100,26 @@ struct dns_packet { unsigned char data[0]; }; -int dns_decode(struct dns_packet *packet, unsigned char *data, int size); +struct dns_rrs *dns_get_rrs_next(struct dns_packet *packet, struct dns_rrs *rrs); -int dns_encode(unsigned char *data, int size, struct dns_packet *packet); +struct dns_rrs *dns_get_rrs_start(struct dns_packet *packet, int type, int *count); -int dns_packet_init(struct dns_packet *packet, int size); +int dns_add_A(struct dns_packet *packet, char *domain, int ttl, unsigned char addr[4]); + +int dns_get_A(struct dns_rrs *rrs, char *domain, int maxsize, int *ttl, unsigned char addr[4]); + +int dns_add_AAAA(struct dns_packet *packet, char *domain, int ttl, unsigned char addr[16]); + +int dns_get_AAAA(struct dns_rrs *rrs, char *domain, int maxsize, int *ttl, unsigned char addr[16]); int dns_get_domain(struct dns_rrs *rrs, char *domain, int maxsize, int *qtype, int *qclass); int dns_add_domain(struct dns_packet *packet, char *domain, int qtype, int qclass); -struct dns_rrs *dns_rr_get_next(struct dns_packet *packet, struct dns_rrs *rrs); +int dns_decode(struct dns_packet *packet, int maxsize, unsigned char *data, int size); -struct dns_rrs *dns_rr_get_start(struct dns_packet *packet, int type, int *count); +int dns_encode(unsigned char *data, int size, struct dns_packet *packet); + +int dns_packet_init(struct dns_packet *packet, int size, struct dns_head *head); #endif \ No newline at end of file diff --git a/dns_server.c b/dns_server.c index fa8dbea..b4d65b3 100644 --- a/dns_server.c +++ b/dns_server.c @@ -1,6 +1,6 @@ #include "dns_server.h" -#include "hashtable.h" #include "dns.h" +#include "hashtable.h" #include #include #include @@ -24,130 +24,174 @@ #define DNS_INPACKET_SIZE 512 struct dns_server { - int run; - int epoll_fd; + int run; + int epoll_fd; - int fd; + int fd; - pthread_mutex_t map_lock; - DECLARE_HASHTABLE(hostmap, 6); + pthread_mutex_t map_lock; + DECLARE_HASHTABLE(hostmap, 6); }; static struct dns_server server; static void tv_sub(struct timeval *out, struct timeval *in) { - if ((out->tv_usec -= in->tv_usec) < 0) { /* out -= in */ - --out->tv_sec; - out->tv_usec += 1000000; - } - out->tv_sec -= in->tv_sec; + if ((out->tv_usec -= in->tv_usec) < 0) { /* out -= in */ + --out->tv_sec; + out->tv_usec += 1000000; + } + out->tv_sec -= in->tv_sec; } void _dns_server_period_run() { + unsigned char packet_data[DNS_INPACKET_SIZE]; + unsigned char data[DNS_INPACKET_SIZE]; + + struct dns_packet *packet = (struct dns_packet *)packet_data; + + struct dns_head head; + memset(&head, 0, sizeof(head)); + head.rcode = 0; + head.qr = 0; + head.ra = 1; + head.id = 1; + + int len; + struct sockaddr_in to; + socklen_t to_len = sizeof(to); + + dns_packet_init(packet, DNS_INPACKET_SIZE, &head); + dns_add_domain(packet, "www.baidu.com", 1, 1); + len = dns_encode(data, DNS_INPACKET_SIZE, packet); + + memset(&to, 0, sizeof(to)); + to.sin_addr.s_addr = inet_addr("192.168.1.1"); + to.sin_port = htons(53); + + len = sendto(server.fd, data, len, 0, (struct sockaddr *)&to, to_len); + if (len < 0) { + printf("send failed."); + } + + printf("send %d\n", len); } static int _dns_server_process(struct timeval *now) { - int len; - unsigned char inpacket[DNS_INPACKET_SIZE]; - unsigned char rsppacket[DNS_INPACKET_SIZE]; + int len; + unsigned char inpacket[DNS_INPACKET_SIZE]; + unsigned char rsppacket[DNS_INPACKET_SIZE]; struct dns_packet *packet = (struct dns_packet *)rsppacket; struct sockaddr_storage from; socklen_t from_len = sizeof(from); - len = recvfrom(server.fd, inpacket, sizeof(inpacket), 0, (struct sockaddr *)&from, (socklen_t *)&from_len); - if (len < 0) { - fprintf(stderr, "recvfrom failed, %s\n", strerror(errno)); - goto errout; - } + len = recvfrom(server.fd, inpacket, sizeof(inpacket), 0, (struct sockaddr *)&from, (socklen_t *)&from_len); + if (len < 0) { + fprintf(stderr, "recvfrom failed, %s\n", strerror(errno)); + goto errout; + } - dns_packet_init(packet, sizeof(rsppacket)); - dns_decode(packet, inpacket, len); + dns_decode(packet, DNS_INPACKET_SIZE, inpacket, len); - printf("head.id = %d\n", packet->head.id); - printf("head.an_count = %d\n", packet->head.ancount); - printf("head.qd_count = %d\n", packet->head.qdcount); + int count; + struct dns_rrs *rrs; + char name[128]; + int i = 0; + int ttl; + + rrs = dns_get_rrs_start(packet, DNS_RRS_AN, &count); + for (i = 0; i < count && rrs; i++, rrs = dns_get_rrs_next(packet, rrs)) { + switch (rrs->type) { + case DNS_T_A: { + unsigned char addr[4]; + dns_get_A(rrs, name, 128, &ttl, addr); + printf("%s %d : %d.%d.%d.%d\n", name, ttl, addr[0], addr[1], addr[2], addr[3]); + } break; + default: + break; + } + } return 0; errout: - return -1; + return -1; } int dns_server_run(void) { - struct epoll_event events[DNS_MAX_EVENTS + 1]; - int num; - int i; - struct timeval last = { 0 }; - struct timeval now = { 0 }; - struct timeval diff = { 0 }; - uint millisec = 0; + struct epoll_event events[DNS_MAX_EVENTS + 1]; + int num; + int i; + struct timeval last = {0}; + struct timeval now = {0}; + struct timeval diff = {0}; + uint millisec = 0; - while (server.run) { - diff = now; - tv_sub(&diff, &last); - millisec = diff.tv_sec * 1000 + diff.tv_usec / 1000; - if (millisec >= 100) { - _dns_server_period_run(); - last = now; - } + while (server.run) { + diff = now; + tv_sub(&diff, &last); + millisec = diff.tv_sec * 1000 + diff.tv_usec / 1000; + if (millisec >= 100) { + _dns_server_period_run(); + last = now; + } - num = epoll_wait(server.epoll_fd, events, DNS_MAX_EVENTS, 100); - if (num < 0) { - gettimeofday(&now, 0); - usleep(100000); - continue; - } + num = epoll_wait(server.epoll_fd, events, DNS_MAX_EVENTS, 100); + if (num < 0) { + gettimeofday(&now, 0); + usleep(100000); + continue; + } - if (num == 0) { - gettimeofday(&now, 0); - continue; - } + if (num == 0) { + gettimeofday(&now, 0); + continue; + } - gettimeofday(&now, 0); - for (i = 0; i < num; i++) { - struct epoll_event *event = &events[i]; - if (event->data.fd != server.fd) { + gettimeofday(&now, 0); + for (i = 0; i < num; i++) { + struct epoll_event *event = &events[i]; + if (event->data.fd != server.fd) { fprintf(stderr, "invalid fd\n"); continue; } _dns_server_process(&now); - } - } + } + } - close(server.epoll_fd); - server.epoll_fd = -1; + close(server.epoll_fd); + server.epoll_fd = -1; - return 0; + return 0; } static struct addrinfo *_dns_server_getaddr(const char *host, const char *port, int type, int protocol) { - struct addrinfo hints; - struct addrinfo *result = NULL; + struct addrinfo hints; + struct addrinfo *result = NULL; - memset(&hints, 0, sizeof(hints)); - hints.ai_family = AF_UNSPEC; - hints.ai_socktype = type; - hints.ai_protocol = protocol; + memset(&hints, 0, sizeof(hints)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = type; + hints.ai_protocol = protocol; hints.ai_flags = AI_PASSIVE; if (getaddrinfo(host, port, &hints, &result) != 0) { fprintf(stderr, "get addr info failed. %s\n", strerror(errno)); - goto errout; - } + goto errout; + } - return result; + return result; errout: - if (result) { - freeaddrinfo(result); - } - return NULL; + if (result) { + freeaddrinfo(result); + } + return NULL; } -int dns_server_start(void) +int dns_server_start(void) { struct epoll_event event; event.events = EPOLLIN; @@ -162,102 +206,102 @@ int dns_server_start(void) int dns_server_socket(void) { - int fd = -1; - struct addrinfo *gai = NULL; + int fd = -1; + struct addrinfo *gai = NULL; - gai = _dns_server_getaddr(NULL, "53", SOCK_DGRAM, 0); - if (gai == NULL) { + gai = _dns_server_getaddr(NULL, "53", SOCK_DGRAM, 0); + if (gai == NULL) { fprintf(stderr, "get address failed.\n"); goto errout; } fd = socket(gai->ai_family, gai->ai_socktype, gai->ai_protocol); - if (fd < 0) { - fprintf(stderr, "create socket failed.\n"); - goto errout; - } + if (fd < 0) { + fprintf(stderr, "create socket failed.\n"); + goto errout; + } - if (bind(fd, gai->ai_addr, gai->ai_addrlen) != 0) { - fprintf(stderr, "bind failed.\n"); - goto errout; - } + if (bind(fd, gai->ai_addr, gai->ai_addrlen) != 0) { + fprintf(stderr, "bind failed.\n"); + goto errout; + } - server.fd = fd; - freeaddrinfo(gai); + server.fd = fd; + freeaddrinfo(gai); - return fd; + return fd; errout: - if (fd > 0) { - close(fd); - } + if (fd > 0) { + close(fd); + } - if (gai) { - freeaddrinfo(gai); - } - return -1; + if (gai) { + freeaddrinfo(gai); + } + return -1; } int dns_server_init(void) { - pthread_attr_t attr; - int epollfd = -1; - int fd = -1; + pthread_attr_t attr; + int epollfd = -1; + int fd = -1; - if (server.epoll_fd > 0) { - return -1; - } + if (server.epoll_fd > 0) { + return -1; + } - memset(&server, 0, sizeof(server)); - pthread_attr_init(&attr); + memset(&server, 0, sizeof(server)); + pthread_attr_init(&attr); - epollfd = epoll_create1(EPOLL_CLOEXEC); - if (epollfd < 0) { - fprintf(stderr, "create epoll failed, %s\n", strerror(errno)); - goto errout; - } + epollfd = epoll_create1(EPOLL_CLOEXEC); + if (epollfd < 0) { + fprintf(stderr, "create epoll failed, %s\n", strerror(errno)); + goto errout; + } fd = dns_server_socket(); - if (fd < 0) { + if (fd < 0) { fprintf(stderr, "create server socket failed.\n"); goto errout; } pthread_mutex_init(&server.map_lock, 0); - hash_init(server.hostmap); - server.epoll_fd = epollfd; - server.fd = fd; - server.run = 1; + hash_init(server.hostmap); + server.epoll_fd = epollfd; + server.fd = fd; + server.run = 1; - if (dns_server_start() != 0) { + if (dns_server_start() != 0) { fprintf(stderr, "start service failed.\n"); goto errout; } return 0; errout: - server.run = 0; + server.run = 0; - if (fd > 0) { - close(fd); - } + if (fd > 0) { + close(fd); + } - if (epollfd) { - close(epollfd); - } + if (epollfd) { + close(epollfd); + } - pthread_mutex_destroy(&server.map_lock); + pthread_mutex_destroy(&server.map_lock); - return -1; + return -1; } void dns_server_exit(void) { - server.run = 0; + server.run = 0; - if (server.fd > 0) { - close(server.fd); - server.fd = -1; - } + if (server.fd > 0) { + close(server.fd); + server.fd = -1; + } - pthread_mutex_destroy(&server.map_lock); -} \ No newline at end of file + pthread_mutex_destroy(&server.map_lock); +}