Added persistent connections cache
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package upstream
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/net/context"
|
||||
"time"
|
||||
@@ -8,24 +9,40 @@ import (
|
||||
|
||||
// DnsUpstream is a very simple upstream implementation for plain DNS
|
||||
type DnsUpstream struct {
|
||||
nameServer string // IP:port
|
||||
timeout time.Duration // Max read and write timeout
|
||||
endpoint string // IP:port
|
||||
timeout time.Duration // Max read and write timeout
|
||||
proto string // Protocol (tcp, tcp-tls, or udp)
|
||||
transport *Transport // Persistent connections cache
|
||||
}
|
||||
|
||||
// NewDnsUpstream creates a new plain-DNS upstream
|
||||
func NewDnsUpstream(nameServer string) (Upstream, error) {
|
||||
return &DnsUpstream{nameServer: nameServer, timeout: defaultTimeout}, nil
|
||||
// NewDnsUpstream creates a new DNS upstream
|
||||
func NewDnsUpstream(endpoint string, proto string, tlsServerName string) (Upstream, error) {
|
||||
|
||||
u := &DnsUpstream{
|
||||
endpoint: endpoint,
|
||||
timeout: defaultTimeout,
|
||||
proto: proto,
|
||||
}
|
||||
|
||||
var tlsConfig *tls.Config
|
||||
|
||||
if tlsServerName != "" {
|
||||
tlsConfig = new(tls.Config)
|
||||
tlsConfig.ServerName = tlsServerName
|
||||
}
|
||||
|
||||
// Initialize the connections cache
|
||||
u.transport = NewTransport(endpoint)
|
||||
u.transport.tlsConfig = tlsConfig
|
||||
u.transport.Start()
|
||||
|
||||
return u, nil
|
||||
}
|
||||
|
||||
// Exchange provides an implementation for the Upstream interface
|
||||
func (u *DnsUpstream) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) {
|
||||
|
||||
dnsClient := &dns.Client{
|
||||
ReadTimeout: u.timeout,
|
||||
WriteTimeout: u.timeout,
|
||||
}
|
||||
|
||||
resp, _, err := dnsClient.Exchange(query, u.nameServer)
|
||||
resp, err := u.exchange(query)
|
||||
|
||||
if err != nil {
|
||||
resp = &dns.Msg{}
|
||||
@@ -34,3 +51,42 @@ func (u *DnsUpstream) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, e
|
||||
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// Clear resources
|
||||
func (u *DnsUpstream) Close() error {
|
||||
|
||||
// Close active connections
|
||||
u.transport.Stop()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Performs a synchronous query. It sends the message m via the conn
|
||||
// c and waits for a reply. The conn c is not closed.
|
||||
func (u *DnsUpstream) exchange(query *dns.Msg) (r *dns.Msg, err error) {
|
||||
|
||||
// Establish a connection if needed (or reuse cached)
|
||||
conn, err := u.transport.Dial(u.proto)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Write the request with a timeout
|
||||
conn.SetWriteDeadline(time.Now().Add(u.timeout))
|
||||
if err = conn.WriteMsg(query); err != nil {
|
||||
conn.Close() // Not giving it back
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Write response with a timeout
|
||||
conn.SetReadDeadline(time.Now().Add(u.timeout))
|
||||
r, err = conn.ReadMsg()
|
||||
if err != nil {
|
||||
conn.Close() // Not giving it back
|
||||
} else if err == nil && r.Id != query.Id {
|
||||
err = dns.ErrId
|
||||
conn.Close() // Not giving it back
|
||||
}
|
||||
|
||||
u.transport.Yield(conn)
|
||||
return r, err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user