diff --git a/dns.c b/dns.c index 7f430cb..3e9310c 100644 --- a/dns.c +++ b/dns.c @@ -144,111 +144,112 @@ int dns_rr_add_end(struct dns_packet *packet, int type, dns_type_t rrtype, int l rrs->type = rrtype; *start = packet->len; packet->len += len; - return sizeof(*rrs) + len; + return 0; } -int _dns_add_qr_head(unsigned char *data, int maxlen, char *domain, int qtype, int qclass) +static inline int _dns_data_left_len(struct dns_data_context *data_context) { - int i; - int len = 0; + return data_context->maxsize - (data_context->ptr - data_context->data); +} - for (i = 0; i < maxlen; i++) { - *data = *domain; +int _dns_add_qr_head(struct dns_data_context *data_context, char *domain, int qtype, int qclass) +{ + while (1) { + if (_dns_data_left_len(data_context) < 1) { + return -1; + } + *data_context->ptr = *domain; if (*domain == '\0') { - data++; - i++; + data_context->ptr++; break; } - data++; + data_context->ptr++; domain++; } - len += i; - if (maxlen - len < 4) { + if (_dns_data_left_len(data_context) < 4) { return -1; } - *((unsigned short *)(data)) = qtype; - data += 2; - len += 2; + *((unsigned short *)(data_context->ptr)) = qtype; + data_context->ptr += 2; - *((unsigned short *)(data)) = qclass; - data += 2; - len += 2; + *((unsigned short *)(data_context->ptr)) = qclass; + data_context->ptr += 2; - return len; + return 0; } -int _dns_get_qr_head(unsigned char *data, char *domain, int maxsize, int *qtype, int *qclass) +int _dns_get_qr_head(struct dns_data_context *data_context, char *domain, int maxsize, int *qtype, int *qclass) { int i; - int len = 0; + for (i = 0; i < maxsize; i++) { - *domain = *data; - if (*data == '\0') { + if (_dns_data_left_len(data_context) < 1) { + return -1; + } + *domain = *data_context->ptr; + if (*data_context->ptr == '\0') { domain++; - data++; + data_context->ptr++; i++; break; } - *domain = *data; + *domain = *data_context->ptr; domain++; - data++; + data_context->ptr++; } - len += i; - if (len >= maxsize) { + + if (_dns_data_left_len(data_context) < 4) { return -1; } - *qtype = *((unsigned short *)(data)); - data += 2; - len += 2; + *qtype = *((unsigned short *)(data_context->ptr)); + data_context->ptr += 2; - *qclass = *((unsigned short *)(data)); - data += 2; - len += 2; + *qclass = *((unsigned short *)(data_context->ptr)); + data_context->ptr += 2; - return len; + return 0; } -int _dns_add_rr_head(unsigned char *data, int maxlen, char *domain, int qtype, int qclass, int ttl, int rr_len) +int _dns_add_rr_head(struct dns_data_context *data_context, char *domain, int qtype, int qclass, int ttl, int rr_len) { int len = 0; - len = _dns_add_qr_head(data, maxlen, domain, qtype, qclass); + len = _dns_add_qr_head(data_context, domain, qtype, qclass); if (len < 0) { return -1; } - data += len; - if (maxlen - len < 6) { + + if (_dns_data_left_len(data_context) < 6) { return -1; } - *((unsigned int *)(data)) = ttl; - data += 4; - len += 4; + *((unsigned int *)(data_context->ptr)) = ttl; + data_context->ptr += 4; - *((unsigned short *)(data)) = rr_len; - data += 2; - len += 2; + *((unsigned short *)(data_context->ptr)) = rr_len; + data_context->ptr += 2; - return len; + return 0; } -int _dns_get_rr_head(unsigned char *data, char *domain, int maxsize, int *qtype, int *qclass, int *ttl, int *rr_len) +int _dns_get_rr_head(struct dns_data_context *data_context, char *domain, int maxsize, int *qtype, int *qclass, int *ttl, int *rr_len) { int len = 0; - len = _dns_get_qr_head(data, domain, maxsize, qtype, qclass); - data += len; + len = _dns_get_qr_head(data_context, domain, maxsize, qtype, qclass); - *ttl = *((unsigned int *)(data)); - data += 4; - len += 4; + if (_dns_data_left_len(data_context) < 6) { + return -1; + } - *rr_len = *((unsigned short *)(data)); - data += 2; - len += 2; + *ttl = *((unsigned int *)(data_context->ptr)); + data_context->ptr += 4; + + *rr_len = *((unsigned short *)(data_context->ptr)); + data_context->ptr += 2; return len; } @@ -257,21 +258,25 @@ int dns_add_A(struct dns_packet *packet, char *domain, int ttl, unsigned char ad { int maxlen = 0; int len = 0; + struct dns_data_context data_context; unsigned char *data = _dns_add_rrs_start(packet, &maxlen); if (data == NULL) { return -1; } - len = _dns_add_rr_head(data, maxlen, domain, DNS_T_A, DNS_C_IN, ttl, DNS_RR_A_LEN); + data_context.data = data; + data_context.ptr = data; + data_context.maxsize = maxlen; + + len = _dns_add_rr_head(&data_context, domain, DNS_T_A, DNS_C_IN, ttl, DNS_RR_A_LEN); if (len < 0) { return -1; } - data += len; - memcpy(data, addr, DNS_RR_A_LEN); - data += DNS_RR_A_LEN; - len += DNS_RR_A_LEN; + memcpy(data_context.ptr, addr, DNS_RR_A_LEN); + data_context.ptr += DNS_RR_A_LEN; + len = data_context.ptr - data_context.data; return dns_rr_add_end(packet, DNS_RRS_AN, DNS_T_A, len); } @@ -281,8 +286,8 @@ int dns_get_A(struct dns_rrs *rrs, char *domain, int maxsize, int *ttl, unsigned int qtype = 0; int qclass = 0; int rr_len = 0; - int len = 0; - int total_len = 0; + int ret = 0; + struct dns_data_context data_context; unsigned char *data = rrs->data; @@ -290,42 +295,47 @@ int dns_get_A(struct dns_rrs *rrs, char *domain, int maxsize, int *ttl, unsigned return -1; } - len = _dns_get_rr_head(data, domain, maxsize, &qtype, &qclass, ttl, &rr_len); - if (len <= 0) { + data_context.data = data; + data_context.ptr = data; + data_context.maxsize = rrs->len; + + ret = _dns_get_rr_head(&data_context, domain, maxsize, &qtype, &qclass, ttl, &rr_len); + if (ret < 0) { return -1; } - data += len; - total_len += len; if (qtype != DNS_T_A || rr_len != DNS_RR_A_LEN) { return -1; } - memcpy(addr, data, DNS_RR_A_LEN); - total_len += rr_len; - data += rr_len; + memcpy(addr, rrs->data, DNS_RR_A_LEN); - return total_len; + return 0; } int dns_add_AAAA(struct dns_packet *packet, char *domain, int ttl, unsigned char addr[16]) { int maxlen = 0; int len = 0; + struct dns_data_context data_context; + unsigned char *data = _dns_add_rrs_start(packet, &maxlen); if (data == NULL) { return -1; } - len = _dns_add_rr_head(data, maxlen, domain, DNS_T_AAAA, DNS_C_IN, ttl, DNS_RR_AAAA_LEN); + data_context.data = data; + data_context.ptr = data; + data_context.maxsize = maxlen; + + len = _dns_add_rr_head(&data_context, domain, DNS_T_AAAA, DNS_C_IN, ttl, DNS_RR_AAAA_LEN); if (len < 0) { return -1; } - data += len; - memcpy(data, addr, DNS_RR_AAAA_LEN); - data += DNS_RR_AAAA_LEN; - len += DNS_RR_AAAA_LEN; + memcpy(data_context.ptr, addr, DNS_RR_AAAA_LEN); + data_context.ptr += DNS_RR_AAAA_LEN; + len = data_context.ptr - data_context.data; return dns_rr_add_end(packet, DNS_RRS_AN, DNS_T_AAAA, len); } @@ -335,8 +345,8 @@ int dns_get_AAAA(struct dns_rrs *rrs, char *domain, int maxsize, int *ttl, unsig int qtype = 0; int qclass = 0; int rr_len = 0; - int len = 0; - int total_len = 0; + int ret = 0; + struct dns_data_context data_context; if (rrs->type != DNS_T_AAAA) { return -1; @@ -344,21 +354,26 @@ int dns_get_AAAA(struct dns_rrs *rrs, char *domain, int maxsize, int *ttl, unsig unsigned char *data = rrs->data; - len = _dns_get_rr_head(data, domain, maxsize, &qtype, &qclass, ttl, &rr_len); - if (len <= 0) { + if (rrs->type != DNS_T_AAAA) { + return -1; + } + + data_context.data = data; + data_context.ptr = data; + data_context.maxsize = rrs->len; + + ret = _dns_get_rr_head(&data_context, domain, maxsize, &qtype, &qclass, ttl, &rr_len); + if (ret <= 0) { return -1; } - data += len; - total_len += len; if (qtype != DNS_T_AAAA || rr_len != DNS_RR_AAAA_LEN) { return -1; } memcpy(addr, rrs->data, DNS_RR_AAAA_LEN); - total_len += DNS_RR_AAAA_LEN; - return total_len; + return 0; } /* @@ -370,36 +385,59 @@ int dns_add_domain(struct dns_packet *packet, char *domain, int qtype, int qclas int len = 0; int maxlen = 0; unsigned char *data = _dns_add_rrs_start(packet, &maxlen); + struct dns_data_context data_context; if (data == NULL) { return -1; } - len = _dns_add_qr_head(data, maxlen, domain, qtype, qclass); + data_context.data = data; + data_context.ptr = data; + data_context.maxsize = maxlen; + + len = _dns_add_qr_head(&data_context, domain, qtype, qclass); if (len < 0) { return -1; } + len = data_context.ptr - data_context.data; + return dns_rr_add_end(packet, DNS_RRS_QD, DNS_T_CNAME, len); } int dns_get_domain(struct dns_rrs *rrs, char *domain, int maxsize, int *qtype, int *qclass) { + struct dns_data_context data_context; + unsigned char *data = rrs->data; + if (rrs->type != DNS_T_CNAME) { return -1; } - return _dns_get_qr_head(rrs->data, domain, maxsize, qtype, qclass); + data_context.data = data; + data_context.ptr = data; + data_context.maxsize = rrs->len; + + return _dns_get_qr_head(&data_context, domain, maxsize, qtype, qclass); } -int _dns_decode_head(struct dns_head *head, unsigned char *data) +static inline int _dns_left_len(struct dns_context *context) +{ + return context->maxsize - (context->ptr - context->data); +} + +int _dns_decode_head(struct dns_context *context) { unsigned int fields; - unsigned char *start = data; - unsigned char *end = data; + int len = 12; + struct dns_head *head = &context->packet->head; - head->id = dns_read_short(&data); - fields = dns_read_short(&data); + if (_dns_left_len(context) < len) { + return -1; + } + + head->id = dns_read_short(&context->ptr); + fields = dns_read_short(&context->ptr); head->qr = (fields & QR_MASK) >> 15; head->opcode = (fields & OPCODE_MASK) >> 11; head->aa = (fields & AA_MASK) >> 10; @@ -407,348 +445,330 @@ int _dns_decode_head(struct dns_head *head, unsigned char *data) head->rd = (fields & RD_MASK) >> 8; head->ra = (fields & RA_MASK) >> 7; head->rcode = (fields & RCODE_MASK) >> 0; - head->qdcount = dns_read_short(&data); - head->ancount = dns_read_short(&data); - head->nscount = dns_read_short(&data); - head->nrcount = dns_read_short(&data); + head->qdcount = dns_read_short(&context->ptr); + head->ancount = dns_read_short(&context->ptr); + head->nscount = dns_read_short(&context->ptr); + head->nrcount = dns_read_short(&context->ptr); - end = data; - return end - start; + return 0; } -int _dns_encode_head(unsigned char *data, int size, struct dns_head *head) +int _dns_encode_head(struct dns_context *context) { int len = 12; + struct dns_head *head = &context->packet->head; - if (size < len) { + if (_dns_left_len(context) < len) { return -1; } - dns_write_short(&data, head->id); + dns_write_short(&context->ptr, head->id); int fields = 0; fields |= (head->qr << 15) & QR_MASK; fields |= (head->rcode << 0) & RCODE_MASK; - dns_write_short(&data, fields); + dns_write_short(&context->ptr, fields); - dns_write_short(&data, head->qdcount); - dns_write_short(&data, head->ancount); - dns_write_short(&data, head->nscount); - dns_write_short(&data, head->nrcount); + dns_write_short(&context->ptr, head->qdcount); + dns_write_short(&context->ptr, head->ancount); + dns_write_short(&context->ptr, head->nscount); + dns_write_short(&context->ptr, head->nrcount); return len; } -int _dns_decode_domain(char *output, int size, unsigned char *data) +int _dns_decode_domain(struct dns_context *context, char *output, int size) { - int i = 0; int output_len = 0; int copy_len = 0; - int total_len = 0; + int len = *(context->ptr); - while (data[i]) { - int len = data[i]; - - if (i != 0) { - *output = '.'; - output++; + while (*(context->ptr)) { + if (_dns_left_len(context) < 1) { + return -1; } - i++; - total_len++; + context->ptr++; if (output_len < size - 1) { copy_len = (len < size - output_len) ? len : size - 1 - output_len; - memcpy(output, data + i, copy_len); + if (_dns_left_len(context) < copy_len) { + return -1; + } + memcpy(output, context->ptr, copy_len); } - i += len; + + context->ptr += len; output += len; output_len += len; - total_len += len; + + len = *(context->ptr); + if (len == 0) { + break; + } + *output = '.'; + output++; + } *output = 0; - total_len++; - return total_len; + context->ptr++; + + return 0; } -int _dns_encode_domain(unsigned char *output, int size, char *domain) +int _dns_encode_domain(struct dns_context *context, char *domain) { - int i = 0; int num = 0; - int total_len = 0; - unsigned char *ptr_num = output++; - total_len++; - while (i < size && *domain != 0) { + unsigned char *ptr_num = context->ptr++; + + while (_dns_left_len(context) > 1 && *domain != 0) { if (*domain == '.') { *ptr_num = num; num = 0; - ptr_num = output; + ptr_num = context->ptr; domain++; - output++; - total_len++; + context->ptr++; continue; } - *output = *domain; + *context->ptr = *domain; num++; - output++; + context->ptr++; domain++; - total_len++; } + *ptr_num = num; - *output = 0; - total_len++; - return total_len; + *context->ptr = 0; + return 0; } -int _dns_decode_qr_head(unsigned char *data, int size, char *domain, int domain_size, int *qtype, int *qclass) +int _dns_decode_qr_head(struct dns_context *context, char *domain, int domain_size, int *qtype, int *qclass) +{ + int ret = 0; + ret = _dns_decode_domain(context, domain, domain_size); + if (ret < 0) { + return -1; + } + + if (_dns_left_len(context) < 4) { + return -1; + } + + *qtype = dns_read_short(&context->ptr); + *qclass = dns_read_short(&context->ptr); + + return 0; +} + +int _dns_encode_qr_head(struct dns_context *context, char *domain, int qtype, int qclass) +{ + int ret = 0; + ret = _dns_encode_domain(context, domain); + if (ret <= 0) { + return -1; + } + + if (_dns_left_len(context) < 4) { + return -1; + } + + dns_write_short(&context->ptr, qtype); + dns_write_short(&context->ptr, qclass); + + return 0; +} + +int _dns_decode_rr_head(struct dns_context *context, char *domain, int domain_size, int *qtype, int *qclass, int *ttl, int *rr_len) { int len = 0; - len = _dns_decode_domain(domain, domain_size, data); + + len = _dns_decode_qr_head(context, domain, domain_size, qtype, qclass); if (len <= 0) { return -1; } - data += len; - *qtype = dns_read_short(&data); - len += 2; - *qclass = dns_read_short(&data); - len += 2; - - return len; -} - -int _dns_encode_qr_head(unsigned char *data, int size, char *domain, int qtype, int qclass) -{ - int len = 0; - len = _dns_encode_domain(data, size, domain); - if (len <= 0) { - return -1; - } - data += len; - - if (size - len < 4) { + if (_dns_left_len(context) < 6) { return -1; } - dns_write_short(&data, qtype); - len += 2; - dns_write_short(&data, qclass); - len += 2; + *ttl = dns_read_int(&context->ptr); + *rr_len = dns_read_short(&context->ptr); - return len; + return 0; } -int _dns_decode_rr_head(unsigned char *data, int size, char *domain, int domain_size, int *qtype, int *qclass, int *ttl, int *rr_len) +int _dns_encode_rr_head(struct dns_context *context, char *domain, int qtype, int qclass, int ttl, int rr_len) { - int len = 0; - int total_len = 0; - - len = _dns_decode_qr_head(data, size, domain, domain_size, qtype, qclass); - if (len <= 0) { + int ret = 0; + ret = _dns_encode_qr_head(context, domain, qtype, qclass); + if (ret <= 0) { return -1; } - data += len; - total_len += len; - - *ttl = dns_read_int(&data); - len += 4; - total_len += 4; - - *rr_len = dns_read_short(&data); - len += 2; - total_len += 2; - - return total_len; -} - -int _dns_encode_rr_head(unsigned char *data, int size, char *domain, int qtype, int qclass, int ttl, int rr_len) -{ - int len = 0; - int total_len = 0; - len = _dns_encode_qr_head(data, size, domain, qtype, qclass); - if (len <= 0) { + if (_dns_left_len(context) < 6) { return -1; } - data += len; - total_len += len; + dns_write_int(&context->ptr, ttl); + dns_write_short(&context->ptr, rr_len); - if (size - len < 6) { + return 0; +} + +int _dns_decode_A(struct dns_context *context, unsigned char addr[4]) +{ + if (_dns_left_len(context) < DNS_RR_A_LEN) { return -1; } - dns_write_int(&data, ttl); - len += 4; - total_len += 4; - - dns_write_short(&data, rr_len); - len += 2; - total_len += 2; - - return total_len; + memcpy(addr, context->ptr, DNS_RR_A_LEN); + context->ptr += DNS_RR_A_LEN; + return 0; } -int _dns_decode_A(unsigned char addr[4], unsigned char *data) +int _dns_encode_A(struct dns_context *context, struct dns_rrs *rrs) { - memcpy(addr, data, DNS_RR_A_LEN); - return DNS_RR_A_LEN; -} - -int _dns_encode_A(unsigned char *output, int size, struct dns_rrs *rrs) -{ - int len; - int len_rrs; - int total_len = 0; + int ret; int qtype = 0; int qclass = 0; int ttl = 0; char domain[DNS_MAX_CNAME_LEN]; - unsigned char *data_rrs; int rr_len; + struct dns_data_context data_context; - data_rrs = rrs->data; - len_rrs = _dns_get_rr_head(data_rrs, domain, DNS_MAX_CNAME_LEN, &qtype, &qclass, &ttl, &rr_len); - if (len_rrs <= 0) { + data_context.data = rrs->data; + data_context.ptr = rrs->data; + data_context.maxsize = rrs->len; + + ret = _dns_get_rr_head(&data_context, domain, DNS_MAX_CNAME_LEN, &qtype, &qclass, &ttl, &rr_len); + if (ret < 0) { return -1; } - data_rrs += len_rrs; if (rr_len != DNS_RR_A_LEN) { return -1; } - len = _dns_encode_rr_head(output, size, domain, qtype, qclass, ttl, DNS_RR_A_LEN); - if (len <= 0) { - return -1; - } - output += len; - total_len += len; - - if (size - total_len < rr_len + DNS_RR_A_LEN) { + ret = _dns_encode_rr_head(context, domain, qtype, qclass, ttl, DNS_RR_A_LEN); + if (ret <= 0) { return -1; } - memcpy(output, data_rrs, DNS_RR_A_LEN); - output += DNS_RR_A_LEN; - data_rrs += DNS_RR_A_LEN; - total_len += DNS_RR_A_LEN; + if (_dns_left_len(context) < DNS_RR_A_LEN) { + return -1; + } - return total_len; + memcpy(context->ptr, rrs->data, DNS_RR_A_LEN); + context->ptr += DNS_RR_A_LEN; + + return 0; } -int _dns_decode_AAAA(unsigned char addr[DNS_RR_AAAA_LEN], unsigned char *data) +int _dns_decode_AAAA(struct dns_context *context, unsigned char addr[DNS_RR_AAAA_LEN]) { - memcpy(addr, data, DNS_RR_AAAA_LEN); - return DNS_RR_AAAA_LEN; + if (_dns_left_len(context) < DNS_RR_AAAA_LEN) { + return -1; + } + + memcpy(addr, context->ptr, DNS_RR_AAAA_LEN); + context->ptr += DNS_RR_AAAA_LEN; + return 0; } -int _dns_encode_AAAA(unsigned char *output, int size, struct dns_rrs *rrs) +int _dns_encode_AAAA(struct dns_context *context, struct dns_rrs *rrs) { - int len; - int len_rrs; - int total_len = 0; + int ret; int qtype = 0; int qclass = 0; int ttl = 0; char domain[DNS_MAX_CNAME_LEN]; - unsigned char *data_rrs; int rr_len; + struct dns_data_context data_context; - data_rrs = rrs->data; - len_rrs = _dns_get_rr_head(data_rrs, domain, DNS_MAX_CNAME_LEN, &qtype, &qclass, &ttl, &rr_len); - if (len_rrs <= 0) { + data_context.data = rrs->data; + data_context.ptr = rrs->data; + data_context.maxsize = rrs->len; + + ret = _dns_get_rr_head(&data_context, domain, DNS_MAX_CNAME_LEN, &qtype, &qclass, &ttl, &rr_len); + if (ret <= 0) { return -1; } - data_rrs += len_rrs; if (rr_len != DNS_RR_AAAA_LEN) { return -1; } - len = _dns_encode_rr_head(output, size, domain, qtype, qclass, ttl, DNS_RR_AAAA_LEN); - if (len <= 0) { - return -1; - } - output += len; - total_len += len; - - if (size - total_len < rr_len + DNS_RR_AAAA_LEN) { + ret = _dns_encode_rr_head(context, domain, qtype, qclass, ttl, DNS_RR_AAAA_LEN); + if (ret <= 0) { return -1; } - memcpy(output, data_rrs, DNS_RR_AAAA_LEN); - output += DNS_RR_AAAA_LEN; - data_rrs += DNS_RR_AAAA_LEN; - total_len += DNS_RR_AAAA_LEN; + if (_dns_left_len(context) < DNS_RR_AAAA_LEN) { + return -1; + } - return total_len; + memcpy(context->ptr, rrs->data, DNS_RR_AAAA_LEN); + context->ptr += DNS_RR_AAAA_LEN; + + return 0; } -int _dns_decode_qd(struct dns_packet *packet, unsigned char *data, int size) +int _dns_decode_qd(struct dns_context *context) { + struct dns_packet *packet = context->packet; int len; - int decode_len = 0; int qtype = 0; int qclass = 0; char domain[DNS_MAX_CNAME_LEN]; - int ttl; - int rr_len; - len = _dns_decode_qr_head(data, size, domain, DNS_MAX_CNAME_LEN, &qtype, &qclass); + len = _dns_decode_qr_head(context, domain, DNS_MAX_CNAME_LEN, &qtype, &qclass); + if (len < 0) { + return -1; + } + + len = dns_add_domain(packet, domain, qtype, qclass); if (len <= 0) { return -1; } - decode_len += len; - len = dns_add_domain(packet, domain, qtype, qclass); - if ( len <= 0 ) { - return -1; - } - - return decode_len; + return 0; } -int _dns_decode_an(struct dns_packet *packet, unsigned char *data, int size) +int _dns_decode_an(struct dns_context *context) { - int len; + int ret; int qtype = 0; int qclass = 0; int ttl; int rr_len = 0; char domain[DNS_MAX_CNAME_LEN]; - int decode_len = 0; + struct dns_packet *packet = context->packet; - len = _dns_decode_rr_head(data, size, domain, DNS_MAX_CNAME_LEN, &qtype, &qclass, &ttl, &rr_len); - if (len <= 0) { + ret = _dns_decode_rr_head(context, domain, DNS_MAX_CNAME_LEN, &qtype, &qclass, &ttl, &rr_len); + if (ret < 0) { return -1; } - data += len; - decode_len += len; switch (qtype) { case DNS_T_A: { unsigned char addr[DNS_RR_A_LEN]; - len = _dns_decode_A(addr, data); - if (len < 0) { + ret = _dns_decode_A(context, addr); + if (ret < 0) { return -1; } - data += len; - decode_len += len; - len = dns_add_A(packet, domain, ttl, addr); - if (len < 0) { + + ret = dns_add_A(packet, domain, ttl, addr); + if (ret < 0) { return -1; } } break; case DNS_T_AAAA: { unsigned char addr[DNS_RR_AAAA_LEN]; - len = _dns_decode_AAAA(addr, data); - if (len < 0) { + ret = _dns_decode_AAAA(context, addr); + if (ret < 0) { return -1; } - data += len; - decode_len += len; - len = dns_add_AAAA(packet, domain, ttl, addr); - if (len < 0) { + + ret = dns_add_AAAA(packet, domain, ttl, addr); + if (ret < 0) { return -1; } } break; @@ -756,120 +776,111 @@ int _dns_decode_an(struct dns_packet *packet, unsigned char *data, int size) break; } - return decode_len; + return 0; } -int _dns_encode_qd(unsigned char *data, int size, struct dns_rrs *rrs) +int _dns_encode_qd(struct dns_context *context, struct dns_rrs *rrs) { - int len; - int len_rrs; + int ret; int qtype = 0; int qclass = 0; - int total_len = 0; char domain[DNS_MAX_CNAME_LEN]; - unsigned char *data_rrs = rrs->data; - len_rrs = _dns_get_qr_head(data_rrs, domain, DNS_MAX_CNAME_LEN, &qtype, &qclass); - if (len_rrs <= 0) { + struct dns_data_context data_context; + + data_context.data = rrs->data; + data_context.ptr = rrs->data; + data_context.maxsize = rrs->len; + + ret = _dns_get_qr_head(&data_context, domain, DNS_MAX_CNAME_LEN, &qtype, &qclass); + if (ret <= 0) { return -1; } - len = _dns_encode_qr_head(data, size, domain, qtype, qclass); - if (len <= 0) { + ret = _dns_encode_qr_head(context, domain, qtype, qclass); + if (ret <= 0) { return -1; } - total_len += len; - return total_len; + return 0; } -int _dns_encode_an(unsigned char *data, int size, struct dns_rrs *rrs) +int _dns_encode_an(struct dns_context *context, struct dns_rrs *rrs) { - int len; - int total_len = 0; + int ret; switch (rrs->type) { case DNS_T_A: { - len = _dns_encode_A(data, size, rrs); - if (len < 0) { + ret = _dns_encode_A(context, rrs); + if (ret < 0) { return -1; } - total_len += len; } break; case DNS_T_AAAA: - len = _dns_encode_AAAA(data, size, rrs); - if (len < 0) { + ret = _dns_encode_AAAA(context, rrs); + if (ret < 0) { return -1; } - total_len += len; break; default: break; } - return total_len; + return 0; } -int _dns_decode_body(struct dns_packet *packet, unsigned char *data, int size) +int _dns_decode_body(struct dns_context *context) { + struct dns_packet *packet = context->packet; struct dns_head *head = &packet->head; int i = 0; - int len = 0; - int decode_len = 0; + int ret = 0; for (i = 0; i < head->qdcount; i++) { - len = _dns_decode_qd(packet, data, size - decode_len); - if (len <= 0) { + ret = _dns_decode_qd(context); + if (ret <= 0) { return -1; } head->qdcount--; - decode_len += len; - data += len; } for (i = 0; i < head->ancount; i++) { - len = _dns_decode_an(packet, data, size - decode_len); - if (len <= 0) { + ret = _dns_decode_an(context); + if (ret <= 0) { return -1; } head->ancount--; - decode_len += len; - data += len; } - return decode_len; + return 0; } -int _dns_encode_body(unsigned char *data, int size, struct dns_packet *packet) +int _dns_encode_body(struct dns_context *context) { + struct dns_packet *packet = context->packet; struct dns_head *head = &packet->head; int i = 0; int len = 0; - int encode_len = 0; struct dns_rrs *rrs; int count; rrs = dns_get_rrs_start(packet, DNS_RRS_QD, &count); head->qdcount = count; for (i = 0; i < count && rrs; i++, rrs = dns_get_rrs_next(packet, rrs)) { - len = _dns_encode_qd(data, size - encode_len, rrs); + len = _dns_encode_qd(context, rrs); if (len <= 0) { return -1; } - encode_len += len; - data += len; } rrs = dns_get_rrs_start(packet, DNS_RRS_AN, &count); head->ancount = count; for (i = 0; i < count && rrs; i++, rrs = dns_get_rrs_next(packet, rrs)) { - len = _dns_encode_an(data, size - encode_len, rrs); + len = _dns_encode_an(context, rrs); if (len <= 0) { return -1; } - encode_len += len; - data += len; } - return encode_len; + return 0; } int dns_packet_init(struct dns_packet *packet, int size, struct dns_head *head) @@ -896,43 +907,51 @@ int dns_packet_init(struct dns_packet *packet, int size, struct dns_head *head) int dns_decode(struct dns_packet *packet, int maxsize, unsigned char *data, int size) { struct dns_head *head = &packet->head; - int decode_len = 0; - int len = 0; + struct dns_context context; + int ret = 0; + memset(&context, 0, sizeof(context)); memset(packet, 0, sizeof(*packet)); + + context.data = data; + context.packet = packet; + context.ptr = data; + context.maxsize = size; + dns_packet_init(packet, maxsize, head); - len = _dns_decode_head(head, data); - if (len < 0) { + ret = _dns_decode_head(&context); + if (ret < 0) { return -1; } - data += len; - decode_len += len; - len = _dns_decode_body(packet, data, size - decode_len); - if (len < 0) { + ret = _dns_decode_body(&context); + if (ret < 0) { return -1; } - decode_len += len; - return decode_len; + return 0; } int dns_encode(unsigned char *data, int size, struct dns_packet *packet) { - int len = 0; - int total_len = 0; + int ret = 0; + struct dns_context context; - len = _dns_encode_head(data, size, &packet->head); - if (len <= 0) { + memset(&context, 0, sizeof(context)); + context.data = data; + context.packet = packet; + context.ptr = data; + context.maxsize = size; + + ret = _dns_encode_head(&context); + if (ret <= 0) { return -1; } - data += len; - total_len += len; - len = _dns_encode_body(data, size - len, packet); - if (len <= 0) { + ret = _dns_encode_body(&context); + if (ret <= 0) { return -1; } - total_len += len; - return total_len; + + return context.ptr - context.data; } diff --git a/dns.h b/dns.h index 04cb04c..b1de1ea 100644 --- a/dns.h +++ b/dns.h @@ -71,10 +71,10 @@ struct dns_head { unsigned short id; // identification number unsigned short qr; /* Query/Response Flag */ unsigned short opcode; /* Operation Code */ - unsigned char aa; /* Authoritative Answer Flag */ - unsigned char tc; /* Truncation Flag */ - unsigned char rd; /* Recursion Desired */ - unsigned char ra; /* Recursion Available */ + unsigned char aa; /* Authoritative Answer Flag */ + unsigned char tc; /* Truncation Flag */ + unsigned char rd; /* Recursion Desired */ + unsigned char ra; /* Recursion Available */ unsigned short rcode; /* Response Code */ unsigned short qdcount; // number of question entries unsigned short ancount; // number of answer entries @@ -100,6 +100,20 @@ struct dns_packet { unsigned char data[0]; }; +struct dns_data_context +{ + unsigned char *data; + unsigned char *ptr; + unsigned int maxsize; +}; + +struct dns_context { + struct dns_packet *packet; + unsigned char *data; + unsigned int maxsize; + unsigned char *ptr; +}; + struct dns_rrs *dns_get_rrs_next(struct dns_packet *packet, struct dns_rrs *rrs); struct dns_rrs *dns_get_rrs_start(struct dns_packet *packet, int type, int *count); diff --git a/dns_server.c b/dns_server.c index b4d65b3..8e4e6d4 100644 --- a/dns_server.c +++ b/dns_server.c @@ -46,6 +46,7 @@ static void tv_sub(struct timeval *out, struct timeval *in) void _dns_server_period_run() { + return; unsigned char packet_data[DNS_INPACKET_SIZE]; unsigned char data[DNS_INPACKET_SIZE]; @@ -93,13 +94,19 @@ static int _dns_server_process(struct timeval *now) goto errout; } - dns_decode(packet, DNS_INPACKET_SIZE, inpacket, len); + len = dns_decode(packet, DNS_INPACKET_SIZE, inpacket, len); + if (len) { + printf("decode failed.\n"); + goto errout; + } int count; struct dns_rrs *rrs; char name[128]; int i = 0; int ttl; + int qtype; + int qclass; rrs = dns_get_rrs_start(packet, DNS_RRS_AN, &count); for (i = 0; i < count && rrs; i++, rrs = dns_get_rrs_next(packet, rrs)) { @@ -114,6 +121,18 @@ static int _dns_server_process(struct timeval *now) } } + rrs = dns_get_rrs_start(packet, DNS_RRS_QD, &count); + for (i = 0; i < count && rrs; i++, rrs = dns_get_rrs_next(packet, rrs)) { + switch (rrs->type) { + case DNS_T_CNAME: { + dns_get_domain(rrs, name, 128, &qtype, &qclass); + printf("domain: %s qtype: %d qclass: %d\n", name, qtype, qclass); + } break; + default: + break; + } + } + return 0; errout: return -1;