Add support for bootstrapping upstream DNS servers by hostname.
This commit is contained in:
@@ -2,9 +2,7 @@ package dnsforward
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"sync"
|
||||
|
||||
"github.com/joomcode/errorx"
|
||||
@@ -27,51 +25,29 @@ import (
|
||||
// log.Println(r)
|
||||
// pool.Put(c.Conn)
|
||||
type TLSPool struct {
|
||||
Address string
|
||||
parsedAddress *url.URL
|
||||
parsedAddressMutex sync.RWMutex
|
||||
boot *bootstrapper
|
||||
|
||||
// connections
|
||||
conns []net.Conn
|
||||
sync.Mutex // protects conns
|
||||
}
|
||||
|
||||
func (n *TLSPool) getHost() (string, error) {
|
||||
n.parsedAddressMutex.RLock()
|
||||
if n.parsedAddress != nil {
|
||||
n.parsedAddressMutex.RUnlock()
|
||||
return n.parsedAddress.Host, nil
|
||||
}
|
||||
n.parsedAddressMutex.RUnlock()
|
||||
|
||||
n.parsedAddressMutex.Lock()
|
||||
defer n.parsedAddressMutex.Unlock()
|
||||
url, err := url.Parse(n.Address)
|
||||
if err != nil {
|
||||
return "", errorx.Decorate(err, "Failed to parse %s", n.Address)
|
||||
}
|
||||
if url.Scheme != "tls" {
|
||||
return "", fmt.Errorf("TLSPool only supports TLS")
|
||||
}
|
||||
n.parsedAddress = url
|
||||
return n.parsedAddress.Host, nil
|
||||
connsMutex sync.Mutex // protects conns
|
||||
}
|
||||
|
||||
func (n *TLSPool) Get() (net.Conn, error) {
|
||||
host, err := n.getHost()
|
||||
address, tlsConfig, err := n.boot.get()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// get the connection from the slice inside the lock
|
||||
var c net.Conn
|
||||
n.Lock()
|
||||
n.connsMutex.Lock()
|
||||
num := len(n.conns)
|
||||
if num > 0 {
|
||||
last := num - 1
|
||||
c = n.conns[last]
|
||||
n.conns = n.conns[:last]
|
||||
}
|
||||
n.Unlock()
|
||||
n.connsMutex.Unlock()
|
||||
|
||||
// if we got connection from the slice, return it
|
||||
if c != nil {
|
||||
@@ -80,10 +56,10 @@ func (n *TLSPool) Get() (net.Conn, error) {
|
||||
}
|
||||
|
||||
// we'll need a new connection, dial now
|
||||
// log.Printf("Dialing to %s", host)
|
||||
conn, err := tls.Dial("tcp", host, nil)
|
||||
// log.Printf("Dialing to %s", address)
|
||||
conn, err := tls.Dial("tcp", address, tlsConfig)
|
||||
if err != nil {
|
||||
return nil, errorx.Decorate(err, "Failed to connect to %s", host)
|
||||
return nil, errorx.Decorate(err, "Failed to connect to %s", address)
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
@@ -92,7 +68,7 @@ func (n *TLSPool) Put(c net.Conn) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
n.Lock()
|
||||
n.connsMutex.Lock()
|
||||
n.conns = append(n.conns, c)
|
||||
n.Unlock()
|
||||
n.connsMutex.Unlock()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user