Add support for bootstrapping upstream DNS servers by hostname.
This commit is contained in:
@@ -29,7 +29,7 @@ type Upstream interface {
|
||||
// plain DNS
|
||||
//
|
||||
type plainDNS struct {
|
||||
address string
|
||||
boot bootstrapper
|
||||
preferTCP bool
|
||||
}
|
||||
|
||||
@@ -44,19 +44,25 @@ var defaultTCPClient = dns.Client{
|
||||
Timeout: defaultTimeout,
|
||||
}
|
||||
|
||||
func (p *plainDNS) Address() string { return p.address }
|
||||
// Address returns the original address that we've put in initially, not resolved one
|
||||
func (p *plainDNS) Address() string { return p.boot.address }
|
||||
|
||||
func (p *plainDNS) Exchange(m *dns.Msg) (*dns.Msg, error) {
|
||||
addr, _, err := p.boot.get()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if p.preferTCP {
|
||||
reply, _, err := defaultTCPClient.Exchange(m, p.address)
|
||||
reply, _, err := defaultTCPClient.Exchange(m, addr)
|
||||
return reply, err
|
||||
}
|
||||
|
||||
reply, _, err := defaultUDPClient.Exchange(m, p.address)
|
||||
reply, _, err := defaultUDPClient.Exchange(m, addr)
|
||||
if err != nil && reply != nil && reply.Truncated {
|
||||
log.Printf("Truncated message was received, retrying over TCP, question: %s", m.Question[0].String())
|
||||
reply, _, err = defaultTCPClient.Exchange(m, p.address)
|
||||
reply, _, err = defaultTCPClient.Exchange(m, addr)
|
||||
}
|
||||
|
||||
return reply, err
|
||||
}
|
||||
|
||||
@@ -64,8 +70,8 @@ func (p *plainDNS) Exchange(m *dns.Msg) (*dns.Msg, error) {
|
||||
// DNS-over-TLS
|
||||
//
|
||||
type dnsOverTLS struct {
|
||||
address string
|
||||
pool *TLSPool
|
||||
boot bootstrapper
|
||||
pool *TLSPool
|
||||
|
||||
sync.RWMutex // protects pool
|
||||
}
|
||||
@@ -77,7 +83,7 @@ var defaultTLSClient = dns.Client{
|
||||
TLSConfig: &tls.Config{},
|
||||
}
|
||||
|
||||
func (p *dnsOverTLS) Address() string { return p.address }
|
||||
func (p *dnsOverTLS) Address() string { return p.boot.address }
|
||||
|
||||
func (p *dnsOverTLS) Exchange(m *dns.Msg) (*dns.Msg, error) {
|
||||
var pool *TLSPool
|
||||
@@ -87,7 +93,7 @@ func (p *dnsOverTLS) Exchange(m *dns.Msg) (*dns.Msg, error) {
|
||||
if pool == nil {
|
||||
p.Lock()
|
||||
// lazy initialize it
|
||||
p.pool = &TLSPool{Address: p.address}
|
||||
p.pool = &TLSPool{boot: &p.boot}
|
||||
p.Unlock()
|
||||
}
|
||||
|
||||
@@ -95,19 +101,19 @@ func (p *dnsOverTLS) Exchange(m *dns.Msg) (*dns.Msg, error) {
|
||||
poolConn, err := p.pool.Get()
|
||||
p.RUnlock()
|
||||
if err != nil {
|
||||
return nil, errorx.Decorate(err, "Failed to get a connection from TLSPool to %s", p.address)
|
||||
return nil, errorx.Decorate(err, "Failed to get a connection from TLSPool to %s", p.Address())
|
||||
}
|
||||
c := dns.Conn{Conn: poolConn}
|
||||
err = c.WriteMsg(m)
|
||||
if err != nil {
|
||||
poolConn.Close()
|
||||
return nil, errorx.Decorate(err, "Failed to send a request to %s", p.address)
|
||||
return nil, errorx.Decorate(err, "Failed to send a request to %s", p.Address())
|
||||
}
|
||||
|
||||
reply, err := c.ReadMsg()
|
||||
if err != nil {
|
||||
poolConn.Close()
|
||||
return nil, errorx.Decorate(err, "Failed to read a request from %s", p.address)
|
||||
return nil, errorx.Decorate(err, "Failed to read a request from %s", p.Address())
|
||||
}
|
||||
p.RLock()
|
||||
p.pool.Put(poolConn)
|
||||
@@ -119,7 +125,7 @@ func (p *dnsOverTLS) Exchange(m *dns.Msg) (*dns.Msg, error) {
|
||||
// DNS-over-https
|
||||
//
|
||||
type dnsOverHTTPS struct {
|
||||
address string
|
||||
boot bootstrapper
|
||||
}
|
||||
|
||||
var defaultHTTPSTransport = http.Transport{}
|
||||
@@ -129,35 +135,59 @@ var defaultHTTPSClient = http.Client{
|
||||
Timeout: defaultTimeout,
|
||||
}
|
||||
|
||||
func (p *dnsOverHTTPS) Address() string { return p.address }
|
||||
func (p *dnsOverHTTPS) Address() string { return p.boot.address }
|
||||
|
||||
func (p *dnsOverHTTPS) Exchange(m *dns.Msg) (*dns.Msg, error) {
|
||||
addr, tlsConfig, err := p.boot.get()
|
||||
if err != nil {
|
||||
return nil, errorx.Decorate(err, "Couldn't bootstrap %s", p.boot.address)
|
||||
}
|
||||
|
||||
buf, err := m.Pack()
|
||||
if err != nil {
|
||||
return nil, errorx.Decorate(err, "Couldn't pack request msg")
|
||||
}
|
||||
bb := bytes.NewBuffer(buf)
|
||||
resp, err := http.Post(p.address, "application/dns-message", bb)
|
||||
|
||||
// set up a custom request with custom URL
|
||||
url, err := url.Parse(p.boot.address)
|
||||
if err != nil {
|
||||
return nil, errorx.Decorate(err, "Couldn't parse URL %s", p.boot.address)
|
||||
}
|
||||
req := http.Request{
|
||||
Method: "POST",
|
||||
URL: url,
|
||||
Body: ioutil.NopCloser(bb),
|
||||
Header: make(http.Header),
|
||||
Host: url.Host,
|
||||
}
|
||||
url.Host = addr
|
||||
req.Header.Set("Content-Type", "application/dns-message")
|
||||
client := http.Client{
|
||||
Transport: &http.Transport{TLSClientConfig: tlsConfig},
|
||||
}
|
||||
resp, err := client.Do(&req)
|
||||
if resp != nil && resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
if err != nil {
|
||||
return nil, errorx.Decorate(err, "Couldn't do a POST request to '%s'", p.address)
|
||||
return nil, errorx.Decorate(err, "Couldn't do a POST request to '%s'", addr)
|
||||
}
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, errorx.Decorate(err, "Couldn't read body contents for '%s'", p.address)
|
||||
return nil, errorx.Decorate(err, "Couldn't read body contents for '%s'", addr)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("Got an unexpected HTTP status code %d from '%s'", resp.StatusCode, p.address)
|
||||
return nil, fmt.Errorf("Got an unexpected HTTP status code %d from '%s'", resp.StatusCode, addr)
|
||||
}
|
||||
if len(body) == 0 {
|
||||
return nil, fmt.Errorf("Got an unexpected empty body from '%s'", p.address)
|
||||
return nil, fmt.Errorf("Got an unexpected empty body from '%s'", addr)
|
||||
}
|
||||
response := dns.Msg{}
|
||||
err = response.Unpack(body)
|
||||
if err != nil {
|
||||
return nil, errorx.Decorate(err, "Couldn't unpack DNS response from '%s': body is %s", p.address, string(body))
|
||||
return nil, errorx.Decorate(err, "Couldn't unpack DNS response from '%s': body is %s", addr, string(body))
|
||||
}
|
||||
return &response, nil
|
||||
}
|
||||
@@ -178,7 +208,7 @@ func (s *Server) chooseUpstream() Upstream {
|
||||
return upstream
|
||||
}
|
||||
|
||||
func GetUpstream(address string) (Upstream, error) {
|
||||
func AddressToUpstream(address string, bootstrap string) (Upstream, error) {
|
||||
if strings.Contains(address, "://") {
|
||||
url, err := url.Parse(address)
|
||||
if err != nil {
|
||||
@@ -189,25 +219,28 @@ func GetUpstream(address string) (Upstream, error) {
|
||||
if url.Port() == "" {
|
||||
url.Host += ":53"
|
||||
}
|
||||
return &plainDNS{address: url.Host}, nil
|
||||
return &plainDNS{boot: toBoot(url.Host, bootstrap)}, nil
|
||||
case "tcp":
|
||||
if url.Port() == "" {
|
||||
url.Host += ":53"
|
||||
}
|
||||
return &plainDNS{address: url.Host, preferTCP: true}, nil
|
||||
return &plainDNS{boot: toBoot(url.Host, bootstrap), preferTCP: true}, nil
|
||||
case "tls":
|
||||
if url.Port() == "" {
|
||||
url.Host += ":853"
|
||||
}
|
||||
return &dnsOverTLS{address: url.String()}, nil
|
||||
return &dnsOverTLS{boot: toBoot(url.String(), bootstrap)}, nil
|
||||
case "https":
|
||||
return &dnsOverHTTPS{address: url.String()}, nil
|
||||
if url.Port() == "" {
|
||||
url.Host += ":443"
|
||||
}
|
||||
return &dnsOverHTTPS{boot: toBoot(url.String(), bootstrap)}, nil
|
||||
default:
|
||||
// assume it's plain DNS
|
||||
if url.Port() == "" {
|
||||
url.Host += ":53"
|
||||
}
|
||||
return &plainDNS{address: url.String()}, nil
|
||||
return &plainDNS{boot: toBoot(url.String(), bootstrap)}, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -217,5 +250,5 @@ func GetUpstream(address string) (Upstream, error) {
|
||||
// doesn't have port, default to 53
|
||||
address = net.JoinHostPort(address, "53")
|
||||
}
|
||||
return &plainDNS{address: address}, nil
|
||||
return &plainDNS{boot: toBoot(address, bootstrap)}, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user