Pull request: * all: move internal Go packages to internal/

Merge in DNS/adguard-home from 2234-move-to-internal to master

Squashed commit of the following:

commit d26a288cabeac86f9483fab307677b1027c78524
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Fri Oct 30 12:44:18 2020 +0300

    * all: move internal Go packages to internal/

    Closes #2234.
This commit is contained in:
Ainar Garipov
2020-10-30 13:32:02 +03:00
parent df3fa595a2
commit ae8de95d89
125 changed files with 85 additions and 85 deletions

64
internal/dhcpd/README.md Normal file
View File

@@ -0,0 +1,64 @@
# DHCP server
Contents:
* [Test setup with Virtual Box](#vbox)
<a id="vbox"></a>
## Test setup with Virtual Box
To set up a test environment for DHCP server you need:
* Linux host machine
* Virtual Box
* Virtual machine (guest OS doesn't matter)
### Configure client
1. Install Virtual Box and run the following command to create a Host-Only network:
$ VBoxManage hostonlyif create
You can check its status by `ip a` command.
You can also set up Host-Only network using Virtual Box menu:
File -> Host Network Manager...
2. Create your virtual machine and set up its network:
VM Settings -> Network -> Host-only Adapter
3. Start your VM, install an OS. Configure your network interface to use DHCP and the OS should ask for a IP address from our DHCP server.
4. To see the current IP address on client OS you can use `ip a` command on Linux or `ipconfig` on Windows.
5. To force the client OS to request an IP from DHCP server again, you can use `dhclient` on Linux or `ipconfig /release` on Windows.
### Configure server
1. Edit server configuration file 'AdGuardHome.yaml', for example:
dhcp:
enabled: true
interface_name: vboxnet0
dhcpv4:
gateway_ip: 192.168.56.1
subnet_mask: 255.255.255.0
range_start: 192.168.56.2
range_end: 192.168.56.2
lease_duration: 86400
icmp_timeout_msec: 1000
options: []
dhcpv6:
range_start: 2001::1
lease_duration: 86400
ra_slaac_only: false
ra_allow_slaac: false
2. Start the server
./AdGuardHome
There should be a message in log which shows that DHCP server is ready:
[info] DHCP: listening on 0.0.0.0:67

View File

@@ -0,0 +1,215 @@
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
package dhcpd
import (
"bytes"
"fmt"
"net"
"os"
"runtime"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd/nclient4"
"github.com/AdguardTeam/golibs/log"
"github.com/insomniacslk/dhcp/dhcpv4"
"github.com/insomniacslk/dhcp/dhcpv6"
"github.com/insomniacslk/dhcp/dhcpv6/nclient6"
"github.com/insomniacslk/dhcp/iana"
)
// CheckIfOtherDHCPServersPresentV4 sends a DHCP request to the specified network interface,
// and waits for a response for a period defined by defaultDiscoverTime
func CheckIfOtherDHCPServersPresentV4(ifaceName string) (bool, error) {
iface, err := net.InterfaceByName(ifaceName)
if err != nil {
return false, wrapErrPrint(err, "Couldn't find interface by name %s", ifaceName)
}
// get ipv4 address of an interface
ifaceIPNet := getIfaceIPv4(*iface)
if len(ifaceIPNet) == 0 {
return false, fmt.Errorf("couldn't find IPv4 address of interface %s %+v", ifaceName, iface)
}
if runtime.GOOS == "darwin" {
return false, fmt.Errorf("can't find DHCP server: not supported on macOS")
}
srcIP := ifaceIPNet[0]
src := net.JoinHostPort(srcIP.String(), "68")
dst := "255.255.255.255:67"
hostname, _ := os.Hostname()
req, err := dhcpv4.NewDiscovery(iface.HardwareAddr)
if err != nil {
return false, fmt.Errorf("dhcpv4.NewDiscovery: %s", err)
}
req.Options.Update(dhcpv4.OptClientIdentifier(iface.HardwareAddr))
req.Options.Update(dhcpv4.OptHostName(hostname))
// resolve 0.0.0.0:68
udpAddr, err := net.ResolveUDPAddr("udp4", src)
if err != nil {
return false, wrapErrPrint(err, "Couldn't resolve UDP address %s", src)
}
if !udpAddr.IP.To4().Equal(srcIP) {
return false, wrapErrPrint(err, "Resolved UDP address is not %s", src)
}
// resolve 255.255.255.255:67
dstAddr, err := net.ResolveUDPAddr("udp4", dst)
if err != nil {
return false, wrapErrPrint(err, "Couldn't resolve UDP address %s", dst)
}
// bind to 0.0.0.0:68
log.Tracef("Listening to udp4 %+v", udpAddr)
c, err := nclient4.NewRawUDPConn(ifaceName, 68)
if err != nil {
return false, wrapErrPrint(err, "Couldn't listen on :68")
}
if c != nil {
defer c.Close()
}
// send to 255.255.255.255:67
_, err = c.WriteTo(req.ToBytes(), dstAddr)
if err != nil {
return false, wrapErrPrint(err, "Couldn't send a packet to %s", dst)
}
for {
// wait for answer
log.Tracef("Waiting %v for an answer", defaultDiscoverTime)
// TODO: replicate dhclient's behaviour of retrying several times with progressively bigger timeouts
b := make([]byte, 1500)
_ = c.SetReadDeadline(time.Now().Add(defaultDiscoverTime))
n, _, err := c.ReadFrom(b)
if isTimeout(err) {
// timed out -- no DHCP servers
log.Debug("DHCPv4: didn't receive DHCP response")
return false, nil
}
if err != nil {
return false, wrapErrPrint(err, "Couldn't receive packet")
}
log.Tracef("Received packet (%v bytes)", n)
response, err := dhcpv4.FromBytes(b[:n])
if err != nil {
log.Debug("DHCPv4: dhcpv4.FromBytes: %s", err)
continue
}
log.Debug("DHCPv4: received message from server: %s", response.Summary())
if !(response.OpCode == dhcpv4.OpcodeBootReply &&
response.HWType == iana.HWTypeEthernet &&
bytes.Equal(response.ClientHWAddr, iface.HardwareAddr) &&
bytes.Equal(response.TransactionID[:], req.TransactionID[:]) &&
response.Options.Has(dhcpv4.OptionDHCPMessageType)) {
log.Debug("DHCPv4: received message from server doesn't match our request")
continue
}
log.Tracef("The packet is from an active DHCP server")
// that's a DHCP server there
return true, nil
}
}
// CheckIfOtherDHCPServersPresentV6 sends a DHCP request to the specified network interface,
// and waits for a response for a period defined by defaultDiscoverTime
func CheckIfOtherDHCPServersPresentV6(ifaceName string) (bool, error) {
iface, err := net.InterfaceByName(ifaceName)
if err != nil {
return false, fmt.Errorf("DHCPv6: net.InterfaceByName: %s: %s", ifaceName, err)
}
ifaceIPNet := getIfaceIPv6(*iface)
if len(ifaceIPNet) == 0 {
return false, fmt.Errorf("DHCPv6: couldn't find IPv6 address of interface %s %+v", ifaceName, iface)
}
srcIP := ifaceIPNet[0]
src := net.JoinHostPort(srcIP.String(), "546")
dst := "[ff02::1:2]:547"
req, err := dhcpv6.NewSolicit(iface.HardwareAddr)
if err != nil {
return false, fmt.Errorf("DHCPv6: dhcpv6.NewSolicit: %s", err)
}
udpAddr, err := net.ResolveUDPAddr("udp6", src)
if err != nil {
return false, wrapErrPrint(err, "DHCPv6: Couldn't resolve UDP address %s", src)
}
if !udpAddr.IP.To16().Equal(srcIP) {
return false, wrapErrPrint(err, "DHCPv6: Resolved UDP address is not %s", src)
}
dstAddr, err := net.ResolveUDPAddr("udp6", dst)
if err != nil {
return false, fmt.Errorf("DHCPv6: Couldn't resolve UDP address %s: %s", dst, err)
}
log.Debug("DHCPv6: Listening to udp6 %+v", udpAddr)
c, err := nclient6.NewIPv6UDPConn(ifaceName, dhcpv6.DefaultClientPort)
if err != nil {
return false, fmt.Errorf("DHCPv6: Couldn't listen on :546: %s", err)
}
if c != nil {
defer c.Close()
}
_, err = c.WriteTo(req.ToBytes(), dstAddr)
if err != nil {
return false, fmt.Errorf("DHCPv6: Couldn't send a packet to %s: %s", dst, err)
}
for {
log.Debug("DHCPv6: Waiting %v for an answer", defaultDiscoverTime)
b := make([]byte, 4096)
_ = c.SetReadDeadline(time.Now().Add(defaultDiscoverTime))
n, _, err := c.ReadFrom(b)
if isTimeout(err) {
log.Debug("DHCPv6: didn't receive DHCP response")
return false, nil
}
if err != nil {
return false, wrapErrPrint(err, "Couldn't receive packet")
}
log.Debug("DHCPv6: Received packet (%v bytes)", n)
resp, err := dhcpv6.FromBytes(b[:n])
if err != nil {
log.Debug("DHCPv6: dhcpv6.FromBytes: %s", err)
continue
}
log.Debug("DHCPv6: received message from server: %s", resp.Summary())
cid := req.Options.ClientID()
msg, err := resp.GetInnerMessage()
if err != nil {
log.Debug("DHCPv6: resp.GetInnerMessage: %s", err)
continue
}
rcid := msg.Options.ClientID()
if resp.Type() == dhcpv6.MessageTypeAdvertise &&
msg.TransactionID == req.TransactionID &&
rcid != nil &&
cid.Equal(*rcid) {
log.Debug("DHCPv6: The packet is from an active DHCP server")
return true, nil
}
log.Debug("DHCPv6: received message from server doesn't match our request")
}
}

View File

@@ -0,0 +1,11 @@
package dhcpd
import "fmt"
func CheckIfOtherDHCPServersPresentV4(ifaceName string) (bool, error) {
return false, fmt.Errorf("not supported")
}
func CheckIfOtherDHCPServersPresentV6(ifaceName string) (bool, error) {
return false, fmt.Errorf("not supported")
}

173
internal/dhcpd/db.go Normal file
View File

@@ -0,0 +1,173 @@
// On-disk database for lease table
package dhcpd
import (
"encoding/json"
"io/ioutil"
"net"
"os"
"time"
"github.com/AdguardTeam/golibs/file"
"github.com/AdguardTeam/golibs/log"
)
const dbFilename = "leases.db"
type leaseJSON struct {
HWAddr []byte `json:"mac"`
IP []byte `json:"ip"`
Hostname string `json:"host"`
Expiry int64 `json:"exp"`
}
func normalizeIP(ip net.IP) net.IP {
ip4 := ip.To4()
if ip4 != nil {
return ip4
}
return ip
}
// Load lease table from DB
func (s *Server) dbLoad() {
dynLeases := []*Lease{}
staticLeases := []*Lease{}
v6StaticLeases := []*Lease{}
v6DynLeases := []*Lease{}
data, err := ioutil.ReadFile(s.conf.DBFilePath)
if err != nil {
if !os.IsNotExist(err) {
log.Error("DHCP: can't read file %s: %v", s.conf.DBFilePath, err)
}
return
}
obj := []leaseJSON{}
err = json.Unmarshal(data, &obj)
if err != nil {
log.Error("DHCP: invalid DB: %v", err)
return
}
numLeases := len(obj)
for i := range obj {
obj[i].IP = normalizeIP(obj[i].IP)
if !(len(obj[i].IP) == 4 || len(obj[i].IP) == 16) {
log.Info("DHCP: invalid IP: %s", obj[i].IP)
continue
}
lease := Lease{
HWAddr: obj[i].HWAddr,
IP: obj[i].IP,
Hostname: obj[i].Hostname,
Expiry: time.Unix(obj[i].Expiry, 0),
}
if len(obj[i].IP) == 16 {
if obj[i].Expiry == leaseExpireStatic {
v6StaticLeases = append(v6StaticLeases, &lease)
} else {
v6DynLeases = append(v6DynLeases, &lease)
}
} else {
if obj[i].Expiry == leaseExpireStatic {
staticLeases = append(staticLeases, &lease)
} else {
dynLeases = append(dynLeases, &lease)
}
}
}
leases4 := normalizeLeases(staticLeases, dynLeases)
s.srv4.ResetLeases(leases4)
leases6 := normalizeLeases(v6StaticLeases, v6DynLeases)
if s.srv6 != nil {
s.srv6.ResetLeases(leases6)
}
log.Info("DHCP: loaded leases v4:%d v6:%d total-read:%d from DB",
len(leases4), len(leases6), numLeases)
}
// Skip duplicate leases
// Static leases have a priority over dynamic leases
func normalizeLeases(staticLeases, dynLeases []*Lease) []*Lease {
leases := []*Lease{}
index := map[string]int{}
for i, lease := range staticLeases {
_, ok := index[lease.HWAddr.String()]
if ok {
continue // skip the lease with the same HW address
}
index[lease.HWAddr.String()] = i
leases = append(leases, lease)
}
for i, lease := range dynLeases {
_, ok := index[lease.HWAddr.String()]
if ok {
continue // skip the lease with the same HW address
}
index[lease.HWAddr.String()] = i
leases = append(leases, lease)
}
return leases
}
// Store lease table in DB
func (s *Server) dbStore() {
var leases []leaseJSON
leases4 := s.srv4.GetLeasesRef()
for _, l := range leases4 {
if l.Expiry.Unix() == 0 {
continue
}
lease := leaseJSON{
HWAddr: l.HWAddr,
IP: l.IP,
Hostname: l.Hostname,
Expiry: l.Expiry.Unix(),
}
leases = append(leases, lease)
}
if s.srv6 != nil {
leases6 := s.srv6.GetLeasesRef()
for _, l := range leases6 {
if l.Expiry.Unix() == 0 {
continue
}
lease := leaseJSON{
HWAddr: l.HWAddr,
IP: l.IP,
Hostname: l.Hostname,
Expiry: l.Expiry.Unix(),
}
leases = append(leases, lease)
}
}
data, err := json.Marshal(leases)
if err != nil {
log.Error("json.Marshal: %v", err)
return
}
err = file.SafeWrite(s.conf.DBFilePath, data)
if err != nil {
log.Error("DHCP: can't store lease table on disk: %v filename: %s",
err, s.conf.DBFilePath)
return
}
log.Info("DHCP: stored %d leases in DB", len(leases))
}

511
internal/dhcpd/dhcp_http.go Normal file
View File

@@ -0,0 +1,511 @@
package dhcpd
import (
"encoding/json"
"fmt"
"io/ioutil"
"net"
"net/http"
"os"
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/util"
"github.com/AdguardTeam/golibs/jsonutil"
"github.com/AdguardTeam/golibs/log"
)
func httpError(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) {
text := fmt.Sprintf(format, args...)
log.Info("DHCP: %s %s: %s", r.Method, r.URL, text)
http.Error(w, text, code)
}
// []Lease -> JSON
func convertLeases(inputLeases []Lease, includeExpires bool) []map[string]string {
leases := []map[string]string{}
for _, l := range inputLeases {
lease := map[string]string{
"mac": l.HWAddr.String(),
"ip": l.IP.String(),
"hostname": l.Hostname,
}
if includeExpires {
lease["expires"] = l.Expiry.Format(time.RFC3339)
}
leases = append(leases, lease)
}
return leases
}
type v4ServerConfJSON struct {
GatewayIP string `json:"gateway_ip"`
SubnetMask string `json:"subnet_mask"`
RangeStart string `json:"range_start"`
RangeEnd string `json:"range_end"`
LeaseDuration uint32 `json:"lease_duration"`
}
func v4ServerConfToJSON(c V4ServerConf) v4ServerConfJSON {
return v4ServerConfJSON{
GatewayIP: c.GatewayIP,
SubnetMask: c.SubnetMask,
RangeStart: c.RangeStart,
RangeEnd: c.RangeEnd,
LeaseDuration: c.LeaseDuration,
}
}
func v4JSONToServerConf(j v4ServerConfJSON) V4ServerConf {
return V4ServerConf{
GatewayIP: j.GatewayIP,
SubnetMask: j.SubnetMask,
RangeStart: j.RangeStart,
RangeEnd: j.RangeEnd,
LeaseDuration: j.LeaseDuration,
}
}
type v6ServerConfJSON struct {
RangeStart string `json:"range_start"`
LeaseDuration uint32 `json:"lease_duration"`
}
func v6ServerConfToJSON(c V6ServerConf) v6ServerConfJSON {
return v6ServerConfJSON{
RangeStart: c.RangeStart,
LeaseDuration: c.LeaseDuration,
}
}
func v6JSONToServerConf(j v6ServerConfJSON) V6ServerConf {
return V6ServerConf{
RangeStart: j.RangeStart,
LeaseDuration: j.LeaseDuration,
}
}
func (s *Server) handleDHCPStatus(w http.ResponseWriter, r *http.Request) {
leases := convertLeases(s.Leases(LeasesDynamic), true)
staticLeases := convertLeases(s.Leases(LeasesStatic), false)
v4conf := V4ServerConf{}
s.srv4.WriteDiskConfig4(&v4conf)
v6conf := V6ServerConf{}
s.srv6.WriteDiskConfig6(&v6conf)
status := map[string]interface{}{
"enabled": s.conf.Enabled,
"interface_name": s.conf.InterfaceName,
"v4": v4ServerConfToJSON(v4conf),
"v6": v6ServerConfToJSON(v6conf),
"leases": leases,
"static_leases": staticLeases,
}
w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(status)
if err != nil {
httpError(r, w, http.StatusInternalServerError, "Unable to marshal DHCP status json: %s", err)
return
}
}
type staticLeaseJSON struct {
HWAddr string `json:"mac"`
IP string `json:"ip"`
Hostname string `json:"hostname"`
}
type dhcpServerConfigJSON struct {
Enabled bool `json:"enabled"`
InterfaceName string `json:"interface_name"`
V4 v4ServerConfJSON `json:"v4"`
V6 v6ServerConfJSON `json:"v6"`
}
func (s *Server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
newconfig := dhcpServerConfigJSON{}
newconfig.Enabled = s.conf.Enabled
newconfig.InterfaceName = s.conf.InterfaceName
js, err := jsonutil.DecodeObject(&newconfig, r.Body)
if err != nil {
httpError(r, w, http.StatusBadRequest, "Failed to parse new DHCP config json: %s", err)
return
}
var s4 DHCPServer
var s6 DHCPServer
v4Enabled := false
v6Enabled := false
if js.Exists("v4") {
v4conf := v4JSONToServerConf(newconfig.V4)
v4conf.Enabled = newconfig.Enabled
if len(v4conf.RangeStart) == 0 {
v4conf.Enabled = false
}
v4Enabled = v4conf.Enabled
v4conf.InterfaceName = newconfig.InterfaceName
c4 := V4ServerConf{}
s.srv4.WriteDiskConfig4(&c4)
v4conf.notify = c4.notify
v4conf.ICMPTimeout = c4.ICMPTimeout
s4, err = v4Create(v4conf)
if err != nil {
httpError(r, w, http.StatusBadRequest, "Invalid DHCPv4 configuration: %s", err)
return
}
}
if js.Exists("v6") {
v6conf := v6JSONToServerConf(newconfig.V6)
v6conf.Enabled = newconfig.Enabled
if len(v6conf.RangeStart) == 0 {
v6conf.Enabled = false
}
v6Enabled = v6conf.Enabled
v6conf.InterfaceName = newconfig.InterfaceName
v6conf.notify = s.onNotify
s6, err = v6Create(v6conf)
if s6 == nil {
httpError(r, w, http.StatusBadRequest, "Invalid DHCPv6 configuration: %s", err)
return
}
}
if newconfig.Enabled && !v4Enabled && !v6Enabled {
httpError(r, w, http.StatusBadRequest, "DHCPv4 or DHCPv6 configuration must be complete")
return
}
s.Stop()
if js.Exists("enabled") {
s.conf.Enabled = newconfig.Enabled
}
if js.Exists("interface_name") {
s.conf.InterfaceName = newconfig.InterfaceName
}
if s4 != nil {
s.srv4 = s4
}
if s6 != nil {
s.srv6 = s6
}
s.conf.ConfigModified()
s.dbLoad()
if s.conf.Enabled {
staticIP, err := HasStaticIP(newconfig.InterfaceName)
if !staticIP && err == nil {
err = SetStaticIP(newconfig.InterfaceName)
if err != nil {
httpError(r, w, http.StatusInternalServerError, "Failed to configure static IP: %s", err)
return
}
}
err = s.Start()
if err != nil {
httpError(r, w, http.StatusBadRequest, "Failed to start DHCP server: %s", err)
return
}
}
}
type netInterfaceJSON struct {
Name string `json:"name"`
GatewayIP string `json:"gateway_ip"`
HardwareAddr string `json:"hardware_address"`
Addrs4 []string `json:"ipv4_addresses"`
Addrs6 []string `json:"ipv6_addresses"`
Flags string `json:"flags"`
}
func (s *Server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
response := map[string]interface{}{}
ifaces, err := util.GetValidNetInterfaces()
if err != nil {
httpError(r, w, http.StatusInternalServerError, "Couldn't get interfaces: %s", err)
return
}
for _, iface := range ifaces {
if iface.Flags&net.FlagLoopback != 0 {
// it's a loopback, skip it
continue
}
if iface.Flags&net.FlagBroadcast == 0 {
// this interface doesn't support broadcast, skip it
continue
}
addrs, err := iface.Addrs()
if err != nil {
httpError(r, w, http.StatusInternalServerError, "Failed to get addresses for interface %s: %s", iface.Name, err)
return
}
jsonIface := netInterfaceJSON{
Name: iface.Name,
HardwareAddr: iface.HardwareAddr.String(),
}
if iface.Flags != 0 {
jsonIface.Flags = iface.Flags.String()
}
// we don't want link-local addresses in json, so skip them
for _, addr := range addrs {
ipnet, ok := addr.(*net.IPNet)
if !ok {
// not an IPNet, should not happen
httpError(r, w, http.StatusInternalServerError, "SHOULD NOT HAPPEN: got iface.Addrs() element %s that is not net.IPNet, it is %T", addr, addr)
return
}
// ignore link-local
if ipnet.IP.IsLinkLocalUnicast() {
continue
}
if ipnet.IP.To4() != nil {
jsonIface.Addrs4 = append(jsonIface.Addrs4, ipnet.IP.String())
} else {
jsonIface.Addrs6 = append(jsonIface.Addrs6, ipnet.IP.String())
}
}
if len(jsonIface.Addrs4)+len(jsonIface.Addrs6) != 0 {
jsonIface.GatewayIP = getGatewayIP(iface.Name)
response[iface.Name] = jsonIface
}
}
err = json.NewEncoder(w).Encode(response)
if err != nil {
httpError(r, w, http.StatusInternalServerError, "Failed to marshal json with available interfaces: %s", err)
return
}
}
// Perform the following tasks:
// . Search for another DHCP server running
// . Check if a static IP is configured for the network interface
// Respond with results
func (s *Server) handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Request) {
body, err := ioutil.ReadAll(r.Body)
if err != nil {
errorText := fmt.Sprintf("failed to read request body: %s", err)
log.Error(errorText)
http.Error(w, errorText, http.StatusBadRequest)
return
}
interfaceName := strings.TrimSpace(string(body))
if interfaceName == "" {
errorText := fmt.Sprintf("empty interface name specified")
log.Error(errorText)
http.Error(w, errorText, http.StatusBadRequest)
return
}
found4, err4 := CheckIfOtherDHCPServersPresentV4(interfaceName)
staticIP := map[string]interface{}{}
isStaticIP, err := HasStaticIP(interfaceName)
staticIPStatus := "yes"
if err != nil {
staticIPStatus = "error"
staticIP["error"] = err.Error()
} else if !isStaticIP {
staticIPStatus = "no"
staticIP["ip"] = util.GetSubnet(interfaceName)
}
staticIP["static"] = staticIPStatus
v4 := map[string]interface{}{}
othSrv := map[string]interface{}{}
foundVal := "no"
if found4 {
foundVal = "yes"
} else if err != nil {
foundVal = "error"
othSrv["error"] = err4.Error()
}
othSrv["found"] = foundVal
v4["other_server"] = othSrv
v4["static_ip"] = staticIP
found6, err6 := CheckIfOtherDHCPServersPresentV6(interfaceName)
v6 := map[string]interface{}{}
othSrv = map[string]interface{}{}
foundVal = "no"
if found6 {
foundVal = "yes"
} else if err6 != nil {
foundVal = "error"
othSrv["error"] = err6.Error()
}
othSrv["found"] = foundVal
v6["other_server"] = othSrv
result := map[string]interface{}{}
result["v4"] = v4
result["v6"] = v6
w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(result)
if err != nil {
httpError(r, w, http.StatusInternalServerError, "Failed to marshal DHCP found json: %s", err)
return
}
}
func (s *Server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request) {
lj := staticLeaseJSON{}
err := json.NewDecoder(r.Body).Decode(&lj)
if err != nil {
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
return
}
ip := net.ParseIP(lj.IP)
if ip != nil && ip.To4() == nil {
mac, err := net.ParseMAC(lj.HWAddr)
if err != nil {
httpError(r, w, http.StatusBadRequest, "invalid MAC")
return
}
lease := Lease{
IP: ip,
HWAddr: mac,
}
err = s.srv6.AddStaticLease(lease)
if err != nil {
httpError(r, w, http.StatusBadRequest, "%s", err)
return
}
return
}
ip, _ = parseIPv4(lj.IP)
if ip == nil {
httpError(r, w, http.StatusBadRequest, "invalid IP")
return
}
mac, err := net.ParseMAC(lj.HWAddr)
if err != nil {
httpError(r, w, http.StatusBadRequest, "invalid MAC")
return
}
lease := Lease{
IP: ip,
HWAddr: mac,
Hostname: lj.Hostname,
}
err = s.srv4.AddStaticLease(lease)
if err != nil {
httpError(r, w, http.StatusBadRequest, "%s", err)
return
}
}
func (s *Server) handleDHCPRemoveStaticLease(w http.ResponseWriter, r *http.Request) {
lj := staticLeaseJSON{}
err := json.NewDecoder(r.Body).Decode(&lj)
if err != nil {
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
return
}
ip := net.ParseIP(lj.IP)
if ip != nil && ip.To4() == nil {
mac, err := net.ParseMAC(lj.HWAddr)
if err != nil {
httpError(r, w, http.StatusBadRequest, "invalid MAC")
return
}
lease := Lease{
IP: ip,
HWAddr: mac,
}
err = s.srv6.RemoveStaticLease(lease)
if err != nil {
httpError(r, w, http.StatusBadRequest, "%s", err)
return
}
return
}
ip, _ = parseIPv4(lj.IP)
if ip == nil {
httpError(r, w, http.StatusBadRequest, "invalid IP")
return
}
mac, _ := net.ParseMAC(lj.HWAddr)
lease := Lease{
IP: ip,
HWAddr: mac,
Hostname: lj.Hostname,
}
err = s.srv4.RemoveStaticLease(lease)
if err != nil {
httpError(r, w, http.StatusBadRequest, "%s", err)
return
}
}
func (s *Server) handleReset(w http.ResponseWriter, r *http.Request) {
s.Stop()
err := os.Remove(s.conf.DBFilePath)
if err != nil && !os.IsNotExist(err) {
log.Error("DHCP: os.Remove: %s: %s", s.conf.DBFilePath, err)
}
oldconf := s.conf
s.conf = ServerConfig{}
s.conf.WorkDir = oldconf.WorkDir
s.conf.HTTPRegister = oldconf.HTTPRegister
s.conf.ConfigModified = oldconf.ConfigModified
s.conf.DBFilePath = oldconf.DBFilePath
v4conf := V4ServerConf{}
v4conf.ICMPTimeout = 1000
v4conf.notify = s.onNotify
s.srv4, _ = v4Create(v4conf)
v6conf := V6ServerConf{}
v6conf.notify = s.onNotify
s.srv6, _ = v6Create(v6conf)
s.conf.ConfigModified()
}
func (s *Server) registerHandlers() {
s.conf.HTTPRegister("GET", "/control/dhcp/status", s.handleDHCPStatus)
s.conf.HTTPRegister("GET", "/control/dhcp/interfaces", s.handleDHCPInterfaces)
s.conf.HTTPRegister("POST", "/control/dhcp/set_config", s.handleDHCPSetConfig)
s.conf.HTTPRegister("POST", "/control/dhcp/find_active_dhcp", s.handleDHCPFindActiveServer)
s.conf.HTTPRegister("POST", "/control/dhcp/add_static_lease", s.handleDHCPAddStaticLease)
s.conf.HTTPRegister("POST", "/control/dhcp/remove_static_lease", s.handleDHCPRemoveStaticLease)
s.conf.HTTPRegister("POST", "/control/dhcp/reset", s.handleReset)
}

259
internal/dhcpd/dhcpd.go Normal file
View File

@@ -0,0 +1,259 @@
package dhcpd
import (
"encoding/hex"
"net"
"net/http"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/util"
"github.com/AdguardTeam/golibs/log"
)
const defaultDiscoverTime = time.Second * 3
const leaseExpireStatic = 1
var webHandlersRegistered = false
// Lease contains the necessary information about a DHCP lease
type Lease struct {
HWAddr net.HardwareAddr `json:"mac"`
IP net.IP `json:"ip"`
Hostname string `json:"hostname"`
// Lease expiration time
// 1: static lease
Expiry time.Time `json:"expires"`
}
// ServerConfig - DHCP server configuration
// field ordering is important -- yaml fields will mirror ordering from here
type ServerConfig struct {
Enabled bool `yaml:"enabled"`
InterfaceName string `yaml:"interface_name"`
Conf4 V4ServerConf `yaml:"dhcpv4"`
Conf6 V6ServerConf `yaml:"dhcpv6"`
WorkDir string `yaml:"-"`
DBFilePath string `yaml:"-"` // path to DB file
// Called when the configuration is changed by HTTP request
ConfigModified func() `yaml:"-"`
// Register an HTTP handler
HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request)) `yaml:"-"`
}
type OnLeaseChangedT func(flags int)
// flags for onLeaseChanged()
const (
LeaseChangedAdded = iota
LeaseChangedAddedStatic
LeaseChangedRemovedStatic
LeaseChangedDBStore
)
// Server - the current state of the DHCP server
type Server struct {
srv4 DHCPServer
srv6 DHCPServer
conf ServerConfig
// Called when the leases DB is modified
onLeaseChanged []OnLeaseChangedT
}
type ServerInterface interface {
Leases(flags int) []Lease
SetOnLeaseChanged(onLeaseChanged OnLeaseChangedT)
}
// CheckConfig checks the configuration
func (s *Server) CheckConfig(config ServerConfig) error {
return nil
}
// Create - create object
func Create(config ServerConfig) *Server {
s := Server{}
s.conf.Enabled = config.Enabled
s.conf.InterfaceName = config.InterfaceName
s.conf.HTTPRegister = config.HTTPRegister
s.conf.ConfigModified = config.ConfigModified
s.conf.DBFilePath = filepath.Join(config.WorkDir, dbFilename)
if !webHandlersRegistered && s.conf.HTTPRegister != nil {
webHandlersRegistered = true
s.registerHandlers()
}
var err4, err6 error
v4conf := config.Conf4
v4conf.Enabled = s.conf.Enabled
if len(v4conf.RangeStart) == 0 {
v4conf.Enabled = false
}
v4conf.InterfaceName = s.conf.InterfaceName
v4conf.notify = s.onNotify
s.srv4, err4 = v4Create(v4conf)
v6conf := config.Conf6
v6conf.Enabled = s.conf.Enabled
if len(v6conf.RangeStart) == 0 {
v6conf.Enabled = false
}
v6conf.InterfaceName = s.conf.InterfaceName
v6conf.notify = s.onNotify
s.srv6, err6 = v6Create(v6conf)
if err4 != nil {
log.Error("%s", err4)
return nil
}
if err6 != nil {
log.Error("%s", err6)
return nil
}
if s.conf.Enabled && !v4conf.Enabled && !v6conf.Enabled {
log.Error("Can't enable DHCP server because neither DHCPv4 nor DHCPv6 servers are configured")
return nil
}
// we can't delay database loading until DHCP server is started,
// because we need static leases functionality available beforehand
s.dbLoad()
return &s
}
// server calls this function after DB is updated
func (s *Server) onNotify(flags uint32) {
if flags == LeaseChangedDBStore {
s.dbStore()
return
}
s.notify(int(flags))
}
// SetOnLeaseChanged - set callback
func (s *Server) SetOnLeaseChanged(onLeaseChanged OnLeaseChangedT) {
s.onLeaseChanged = append(s.onLeaseChanged, onLeaseChanged)
}
func (s *Server) notify(flags int) {
if len(s.onLeaseChanged) == 0 {
return
}
for _, f := range s.onLeaseChanged {
f(flags)
}
}
// WriteDiskConfig - write configuration
func (s *Server) WriteDiskConfig(c *ServerConfig) {
c.Enabled = s.conf.Enabled
c.InterfaceName = s.conf.InterfaceName
s.srv4.WriteDiskConfig4(&c.Conf4)
s.srv6.WriteDiskConfig6(&c.Conf6)
}
// Start will listen on port 67 and serve DHCP requests.
func (s *Server) Start() error {
err := s.srv4.Start()
if err != nil {
log.Error("DHCPv4: start: %s", err)
return err
}
err = s.srv6.Start()
if err != nil {
log.Error("DHCPv6: start: %s", err)
return err
}
return nil
}
// Stop closes the listening UDP socket
func (s *Server) Stop() {
s.srv4.Stop()
s.srv6.Stop()
}
// flags for Leases() function
const (
LeasesDynamic = 1
LeasesStatic = 2
LeasesAll = LeasesDynamic | LeasesStatic
)
// Leases returns the list of current DHCP leases (thread-safe)
func (s *Server) Leases(flags int) []Lease {
result := s.srv4.GetLeases(flags)
v6leases := s.srv6.GetLeases(flags)
result = append(result, v6leases...)
return result
}
// FindMACbyIP - find a MAC address by IP address in the currently active DHCP leases
func (s *Server) FindMACbyIP(ip net.IP) net.HardwareAddr {
if ip.To4() != nil {
return s.srv4.FindMACbyIP(ip)
}
return s.srv6.FindMACbyIP(ip)
}
// AddStaticLease - add static v4 lease
func (s *Server) AddStaticLease(lease Lease) error {
return s.srv4.AddStaticLease(lease)
}
// Parse option string
// Format:
// CODE TYPE VALUE
func parseOptionString(s string) (uint8, []byte) {
s = strings.TrimSpace(s)
scode := util.SplitNext(&s, ' ')
t := util.SplitNext(&s, ' ')
sval := util.SplitNext(&s, ' ')
code, err := strconv.Atoi(scode)
if err != nil || code <= 0 || code > 255 {
return 0, nil
}
var val []byte
switch t {
case "hex":
val, err = hex.DecodeString(sval)
if err != nil {
return 0, nil
}
case "ip":
ip := net.ParseIP(sval)
if ip == nil {
return 0, nil
}
val = ip
if ip.To4() != nil {
val = ip.To4()
}
default:
return 0, nil
}
return uint8(code), val
}

View File

@@ -0,0 +1,132 @@
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
package dhcpd
import (
"bytes"
"net"
"os"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func check(t *testing.T, result bool, msg string) {
if !result {
t.Fatal(msg)
}
}
func testNotify(flags uint32) {
}
// Leases database store/load
func TestDB(t *testing.T) {
var err error
s := Server{}
s.conf.DBFilePath = dbFilename
conf := V4ServerConf{
Enabled: true,
RangeStart: "192.168.10.100",
RangeEnd: "192.168.10.200",
GatewayIP: "192.168.10.1",
SubnetMask: "255.255.255.0",
notify: testNotify,
}
s.srv4, err = v4Create(conf)
assert.True(t, err == nil)
s.srv6, err = v6Create(V6ServerConf{})
assert.True(t, err == nil)
l := Lease{}
l.IP = net.ParseIP("192.168.10.100").To4()
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
exp1 := time.Now().Add(time.Hour)
l.Expiry = exp1
s.srv4.(*v4Server).addLease(&l)
l2 := Lease{}
l2.IP = net.ParseIP("192.168.10.101").To4()
l2.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:bb")
s.srv4.AddStaticLease(l2)
_ = os.Remove("leases.db")
s.dbStore()
s.srv4.ResetLeases(nil)
s.dbLoad()
ll := s.srv4.GetLeases(LeasesAll)
assert.Equal(t, "aa:aa:aa:aa:aa:bb", ll[0].HWAddr.String())
assert.Equal(t, "192.168.10.101", ll[0].IP.String())
assert.Equal(t, int64(leaseExpireStatic), ll[0].Expiry.Unix())
assert.Equal(t, "aa:aa:aa:aa:aa:aa", ll[1].HWAddr.String())
assert.Equal(t, "192.168.10.100", ll[1].IP.String())
assert.Equal(t, exp1.Unix(), ll[1].Expiry.Unix())
_ = os.Remove("leases.db")
}
func TestIsValidSubnetMask(t *testing.T) {
assert.True(t, isValidSubnetMask([]byte{255, 255, 255, 0}))
assert.True(t, isValidSubnetMask([]byte{255, 255, 254, 0}))
assert.True(t, isValidSubnetMask([]byte{255, 255, 252, 0}))
assert.True(t, !isValidSubnetMask([]byte{255, 255, 253, 0}))
assert.True(t, !isValidSubnetMask([]byte{255, 255, 255, 1}))
}
func TestNormalizeLeases(t *testing.T) {
dynLeases := []*Lease{}
staticLeases := []*Lease{}
leases := []*Lease{}
lease := &Lease{}
lease.HWAddr = []byte{1, 2, 3, 4}
dynLeases = append(dynLeases, lease)
lease = new(Lease)
lease.HWAddr = []byte{1, 2, 3, 5}
dynLeases = append(dynLeases, lease)
lease = new(Lease)
lease.HWAddr = []byte{1, 2, 3, 4}
lease.IP = []byte{0, 2, 3, 4}
staticLeases = append(staticLeases, lease)
lease = new(Lease)
lease.HWAddr = []byte{2, 2, 3, 4}
staticLeases = append(staticLeases, lease)
leases = normalizeLeases(staticLeases, dynLeases)
assert.True(t, len(leases) == 3)
assert.True(t, bytes.Equal(leases[0].HWAddr, []byte{1, 2, 3, 4}))
assert.True(t, bytes.Equal(leases[0].IP, []byte{0, 2, 3, 4}))
assert.True(t, bytes.Equal(leases[1].HWAddr, []byte{2, 2, 3, 4}))
assert.True(t, bytes.Equal(leases[2].HWAddr, []byte{1, 2, 3, 5}))
}
func TestOptions(t *testing.T) {
code, val := parseOptionString(" 12 hex abcdef ")
assert.Equal(t, uint8(12), code)
assert.True(t, bytes.Equal([]byte{0xab, 0xcd, 0xef}, val))
code, _ = parseOptionString(" 12 hex abcdef1 ")
assert.Equal(t, uint8(0), code)
code, val = parseOptionString("123 ip 1.2.3.4")
assert.Equal(t, uint8(123), code)
assert.Equal(t, "1.2.3.4", net.IP(string(val)).String())
code, _ = parseOptionString("256 ip 1.1.1.1")
assert.Equal(t, uint8(0), code)
code, _ = parseOptionString("-1 ip 1.1.1.1")
assert.Equal(t, uint8(0), code)
code, _ = parseOptionString("12 ip 1.1.1.1x")
assert.Equal(t, uint8(0), code)
code, _ = parseOptionString("12 x 1.1.1.1")
assert.Equal(t, uint8(0), code)
}

76
internal/dhcpd/helpers.go Normal file
View File

@@ -0,0 +1,76 @@
package dhcpd
import (
"encoding/binary"
"fmt"
"net"
"github.com/AdguardTeam/golibs/log"
"github.com/joomcode/errorx"
)
func isTimeout(err error) bool {
operr, ok := err.(*net.OpError)
if !ok {
return false
}
return operr.Timeout()
}
// Get IPv4 address list
func getIfaceIPv4(iface net.Interface) []net.IP {
addrs, err := iface.Addrs()
if err != nil {
return nil
}
var res []net.IP
for _, a := range addrs {
ipnet, ok := a.(*net.IPNet)
if !ok {
continue
}
if ipnet.IP.To4() != nil {
res = append(res, ipnet.IP.To4())
}
}
return res
}
func wrapErrPrint(err error, message string, args ...interface{}) error {
var errx error
if err == nil {
errx = fmt.Errorf(message, args...)
} else {
errx = errorx.Decorate(err, message, args...)
}
log.Println(errx.Error())
return errx
}
func parseIPv4(text string) (net.IP, error) {
result := net.ParseIP(text)
if result == nil {
return nil, fmt.Errorf("%s is not an IP address", text)
}
if result.To4() == nil {
return nil, fmt.Errorf("%s is not an IPv4 address", text)
}
return result.To4(), nil
}
// Return TRUE if subnet mask is correct (e.g. 255.255.255.0)
func isValidSubnetMask(mask net.IP) bool {
var n uint32
n = binary.BigEndian.Uint32(mask)
for i := 0; i != 32; i++ {
if n == 0 {
break
}
if (n & 0x80000000) == 0 {
return false
}
n <<= 1
}
return true
}

View File

@@ -0,0 +1,586 @@
// Copyright 2018 the u-root Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build darwin dragonfly freebsd linux netbsd openbsd solaris
// +build go1.12
// Package nclient4 is a small, minimum-functionality client for DHCPv4.
//
// It only supports the 4-way DHCPv4 Discover-Offer-Request-Ack handshake as
// well as the Request-Ack renewal process.
// Originally from here: github.com/insomniacslk/dhcp/dhcpv4/nclient4
// with the difference that this package can be built on UNIX (not just Linux),
// because github.com/mdlayher/raw package supports it.
package nclient4
import (
"bytes"
"context"
"errors"
"fmt"
"log"
"net"
"os"
"sync"
"sync/atomic"
"time"
"github.com/insomniacslk/dhcp/dhcpv4"
)
const (
defaultBufferCap = 5
// DefaultTimeout is the default value for read-timeout if option WithTimeout is not set
DefaultTimeout = 5 * time.Second
// DefaultRetries is amount of retries will be done if no answer was received within read-timeout amount of time
DefaultRetries = 3
// MaxMessageSize is the value to be used for DHCP option "MaxMessageSize".
MaxMessageSize = 1500
// ClientPort is the port that DHCP clients listen on.
ClientPort = 68
// ServerPort is the port that DHCP servers and relay agents listen on.
ServerPort = 67
)
var (
// DefaultServers is the address of all link-local DHCP servers and
// relay agents.
DefaultServers = &net.UDPAddr{
IP: net.IPv4bcast,
Port: ServerPort,
}
)
var (
// ErrNoResponse is returned when no response packet is received.
ErrNoResponse = errors.New("no matching response packet received")
// ErrNoConn is returned when NewWithConn is called with nil-value as conn.
ErrNoConn = errors.New("conn is nil")
// ErrNoIfaceHWAddr is returned when NewWithConn is called with nil-value as ifaceHWAddr
ErrNoIfaceHWAddr = errors.New("ifaceHWAddr is nil")
)
// pendingCh is a channel associated with a pending TransactionID.
type pendingCh struct {
// SendAndRead closes done to indicate that it wishes for no more
// messages for this particular XID.
done <-chan struct{}
// ch is used by the receive loop to distribute DHCP messages.
ch chan<- *dhcpv4.DHCPv4
}
// Logger is a handler which will be used to output logging messages
type Logger interface {
// PrintMessage print _all_ DHCP messages
PrintMessage(prefix string, message *dhcpv4.DHCPv4)
// Printf is use to print the rest debugging information
Printf(format string, v ...interface{})
}
// EmptyLogger prints nothing
type EmptyLogger struct{}
// Printf is just a dummy function that does nothing
func (e EmptyLogger) Printf(format string, v ...interface{}) {}
// PrintMessage is just a dummy function that does nothing
func (e EmptyLogger) PrintMessage(prefix string, message *dhcpv4.DHCPv4) {}
// Printfer is used for actual output of the logger. For example *log.Logger is a Printfer.
type Printfer interface {
// Printf is the function for logging output. Arguments are handled in the manner of fmt.Printf.
Printf(format string, v ...interface{})
}
// ShortSummaryLogger is a wrapper for Printfer to implement interface Logger.
// DHCP messages are printed in the short format.
type ShortSummaryLogger struct {
// Printfer is used for actual output of the logger
Printfer
}
// Printf prints a log message as-is via predefined Printfer
func (s ShortSummaryLogger) Printf(format string, v ...interface{}) {
s.Printfer.Printf(format, v...)
}
// PrintMessage prints a DHCP message in the short format via predefined Printfer
func (s ShortSummaryLogger) PrintMessage(prefix string, message *dhcpv4.DHCPv4) {
s.Printf("%s: %s", prefix, message)
}
// DebugLogger is a wrapper for Printfer to implement interface Logger.
// DHCP messages are printed in the long format.
type DebugLogger struct {
// Printfer is used for actual output of the logger
Printfer
}
// Printf prints a log message as-is via predefined Printfer
func (d DebugLogger) Printf(format string, v ...interface{}) {
d.Printfer.Printf(format, v...)
}
// PrintMessage prints a DHCP message in the long format via predefined Printfer
func (d DebugLogger) PrintMessage(prefix string, message *dhcpv4.DHCPv4) {
d.Printf("%s: %s", prefix, message.Summary())
}
// Client is an IPv4 DHCP client.
type Client struct {
ifaceHWAddr net.HardwareAddr
conn net.PacketConn
timeout time.Duration
retry int
logger Logger
// bufferCap is the channel capacity for each TransactionID.
bufferCap int
// serverAddr is the UDP address to send all packets to.
//
// This may be an actual broadcast address, or a unicast address.
serverAddr *net.UDPAddr
// closed is an atomic bool set to 1 when done is closed.
closed uint32
// done is closed to unblock the receive loop.
done chan struct{}
// wg protects any spawned goroutines, namely the receiveLoop.
wg sync.WaitGroup
pendingMu sync.Mutex
// pending stores the distribution channels for each pending
// TransactionID. receiveLoop uses this map to determine which channel
// to send a new DHCP message to.
pending map[dhcpv4.TransactionID]*pendingCh
}
// New returns a client usable with an unconfigured interface.
func New(iface string, opts ...ClientOpt) (*Client, error) {
return new(iface, nil, nil, opts...)
}
// NewWithConn creates a new DHCP client that sends and receives packets on the
// given interface.
func NewWithConn(conn net.PacketConn, ifaceHWAddr net.HardwareAddr, opts ...ClientOpt) (*Client, error) {
return new(``, conn, ifaceHWAddr, opts...)
}
func new(iface string, conn net.PacketConn, ifaceHWAddr net.HardwareAddr, opts ...ClientOpt) (*Client, error) {
c := &Client{
ifaceHWAddr: ifaceHWAddr,
timeout: DefaultTimeout,
retry: DefaultRetries,
serverAddr: DefaultServers,
bufferCap: defaultBufferCap,
conn: conn,
logger: EmptyLogger{},
done: make(chan struct{}),
pending: make(map[dhcpv4.TransactionID]*pendingCh),
}
for _, opt := range opts {
err := opt(c)
if err != nil {
return nil, fmt.Errorf("unable to apply option: %w", err)
}
}
if c.ifaceHWAddr == nil {
if iface == `` {
return nil, ErrNoIfaceHWAddr
}
i, err := net.InterfaceByName(iface)
if err != nil {
return nil, fmt.Errorf("unable to get interface information: %w", err)
}
c.ifaceHWAddr = i.HardwareAddr
}
if c.conn == nil {
var err error
if iface == `` {
return nil, ErrNoConn
}
c.conn, err = NewRawUDPConn(iface, ClientPort) // broadcast
if err != nil {
return nil, fmt.Errorf("unable to open a broadcasting socket: %w", err)
}
}
c.wg.Add(1)
go c.receiveLoop()
return c, nil
}
// Close closes the underlying connection.
func (c *Client) Close() error {
// Make sure not to close done twice.
if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) {
return nil
}
err := c.conn.Close()
// Closing c.done sets off a chain reaction:
//
// Any SendAndRead unblocks trying to receive more messages, which
// means rem() gets called.
//
// rem() should be unblocking receiveLoop if it is blocked.
//
// receiveLoop should then exit gracefully.
close(c.done)
// Wait for receiveLoop to stop.
c.wg.Wait()
return err
}
func (c *Client) isClosed() bool {
return atomic.LoadUint32(&c.closed) != 0
}
func (c *Client) receiveLoop() {
defer c.wg.Done()
for {
// TODO: Clients can send a "max packet size" option in their
// packets, IIRC. Choose a reasonable size and set it.
b := make([]byte, MaxMessageSize)
n, _, err := c.conn.ReadFrom(b)
if err != nil {
if !c.isClosed() {
c.logger.Printf("error reading from UDP connection: %v", err)
}
return
}
msg, err := dhcpv4.FromBytes(b[:n])
if err != nil {
// Not a valid DHCP packet; keep listening.
continue
}
if msg.OpCode != dhcpv4.OpcodeBootReply {
// Not a response message.
continue
}
// This is a somewhat non-standard check, by the looks
// of RFC 2131. It should work as long as the DHCP
// server is spec-compliant for the HWAddr field.
if c.ifaceHWAddr != nil && !bytes.Equal(c.ifaceHWAddr, msg.ClientHWAddr) {
// Not for us.
continue
}
c.pendingMu.Lock()
p, ok := c.pending[msg.TransactionID]
if ok {
select {
case <-p.done:
close(p.ch)
delete(c.pending, msg.TransactionID)
// This send may block.
case p.ch <- msg:
}
}
c.pendingMu.Unlock()
}
}
// ClientOpt is a function that configures the Client.
type ClientOpt func(c *Client) error
// WithTimeout configures the retransmission timeout.
//
// Default is 5 seconds.
func WithTimeout(d time.Duration) ClientOpt {
return func(c *Client) (err error) {
c.timeout = d
return
}
}
// WithSummaryLogger logs one-line DHCPv4 message summaries when sent & received.
func WithSummaryLogger() ClientOpt {
return func(c *Client) (err error) {
c.logger = ShortSummaryLogger{
Printfer: log.New(os.Stderr, "[dhcpv4] ", log.LstdFlags),
}
return
}
}
// WithDebugLogger logs multi-line full DHCPv4 messages when sent & received.
func WithDebugLogger() ClientOpt {
return func(c *Client) (err error) {
c.logger = DebugLogger{
Printfer: log.New(os.Stderr, "[dhcpv4] ", log.LstdFlags),
}
return
}
}
// WithLogger set the logger (see interface Logger).
func WithLogger(newLogger Logger) ClientOpt {
return func(c *Client) (err error) {
c.logger = newLogger
return
}
}
// WithUnicast forces client to send messages as unicast frames.
// By default client sends messages as broadcast frames even if server address is defined.
//
// srcAddr is both:
// * The source address of outgoing frames.
// * The address to be listened for incoming frames.
func WithUnicast(srcAddr *net.UDPAddr) ClientOpt {
return func(c *Client) (err error) {
if srcAddr == nil {
srcAddr = &net.UDPAddr{Port: ServerPort}
}
c.conn, err = net.ListenUDP("udp4", srcAddr)
if err != nil {
err = fmt.Errorf("unable to start listening UDP port: %w", err)
}
return
}
}
// WithHWAddr tells to the Client to receive messages destinated to selected
// hardware address
func WithHWAddr(hwAddr net.HardwareAddr) ClientOpt {
return func(c *Client) (err error) {
c.ifaceHWAddr = hwAddr
return
}
}
// nolint
func withBufferCap(n int) ClientOpt {
return func(c *Client) (err error) {
c.bufferCap = n
return
}
}
// WithRetry configures the number of retransmissions to attempt.
//
// Default is 3.
func WithRetry(r int) ClientOpt {
return func(c *Client) (err error) {
c.retry = r
return
}
}
// WithServerAddr configures the address to send messages to.
func WithServerAddr(n *net.UDPAddr) ClientOpt {
return func(c *Client) (err error) {
c.serverAddr = n
return
}
}
// Matcher matches DHCP packets.
type Matcher func(*dhcpv4.DHCPv4) bool
// IsMessageType returns a matcher that checks for the message type.
//
// If t is MessageTypeNone, all packets are matched.
func IsMessageType(t dhcpv4.MessageType) Matcher {
return func(p *dhcpv4.DHCPv4) bool {
return p.MessageType() == t || t == dhcpv4.MessageTypeNone
}
}
// DiscoverOffer sends a DHCPDiscover message and returns the first valid offer
// received.
func (c *Client) DiscoverOffer(ctx context.Context, modifiers ...dhcpv4.Modifier) (offer *dhcpv4.DHCPv4, err error) {
// RFC 2131, Section 4.4.1, Table 5 details what a DISCOVER packet should
// contain.
discover, err := dhcpv4.NewDiscovery(c.ifaceHWAddr, dhcpv4.PrependModifiers(modifiers,
dhcpv4.WithOption(dhcpv4.OptMaxMessageSize(MaxMessageSize)))...)
if err != nil {
err = fmt.Errorf("unable to create a discovery request: %w", err)
return
}
offer, err = c.SendAndRead(ctx, c.serverAddr, discover, IsMessageType(dhcpv4.MessageTypeOffer))
if err != nil {
err = fmt.Errorf("got an error while the discovery request: %w", err)
return
}
return
}
// Request completes the 4-way Discover-Offer-Request-Ack handshake.
//
// Note that modifiers will be applied *both* to Discover and Request packets.
func (c *Client) Request(ctx context.Context, modifiers ...dhcpv4.Modifier) (offer, ack *dhcpv4.DHCPv4, err error) {
offer, err = c.DiscoverOffer(ctx, modifiers...)
if err != nil {
err = fmt.Errorf("unable to receive an offer: %w", err)
return
}
// TODO(chrisko): should this be unicast to the server?
request, err := dhcpv4.NewRequestFromOffer(offer, dhcpv4.PrependModifiers(modifiers,
dhcpv4.WithOption(dhcpv4.OptMaxMessageSize(MaxMessageSize)))...)
if err != nil {
err = fmt.Errorf("unable to create a request: %w", err)
return
}
ack, err = c.SendAndRead(ctx, c.serverAddr, request, nil)
if err != nil {
err = fmt.Errorf("got an error while processing the request: %w", err)
return
}
return
}
// ErrTransactionIDInUse is returned if there were an attempt to send a message
// with the same TransactionID as we are already waiting an answer for.
type ErrTransactionIDInUse struct {
// TransactionID is the transaction ID of the message which the error is related to.
TransactionID dhcpv4.TransactionID
}
// Error is just the method to comply interface "error".
func (err *ErrTransactionIDInUse) Error() string {
return fmt.Sprintf("transaction ID %s already in use", err.TransactionID)
}
// send sends p to destination and returns a response channel.
//
// Responses will be matched by transaction ID and ClientHWAddr.
//
// The returned lambda function must be called after all desired responses have
// been received in order to return the Transaction ID to the usable pool.
func (c *Client) send(dest *net.UDPAddr, msg *dhcpv4.DHCPv4) (resp <-chan *dhcpv4.DHCPv4, cancel func(), err error) {
c.pendingMu.Lock()
if _, ok := c.pending[msg.TransactionID]; ok {
c.pendingMu.Unlock()
return nil, nil, &ErrTransactionIDInUse{msg.TransactionID}
}
ch := make(chan *dhcpv4.DHCPv4, c.bufferCap)
done := make(chan struct{})
c.pending[msg.TransactionID] = &pendingCh{done: done, ch: ch}
c.pendingMu.Unlock()
cancel = func() {
// Why can't we just close ch here?
//
// Because receiveLoop may potentially be blocked trying to
// send on ch. We gotta unblock it first, and then we can take
// the lock and remove the XID from the pending transaction
// map.
close(done)
c.pendingMu.Lock()
if p, ok := c.pending[msg.TransactionID]; ok {
close(p.ch)
delete(c.pending, msg.TransactionID)
}
c.pendingMu.Unlock()
}
if _, err := c.conn.WriteTo(msg.ToBytes(), dest); err != nil {
cancel()
return nil, nil, fmt.Errorf("error writing packet to connection: %w", err)
}
return ch, cancel, nil
}
// This error should never be visible to users.
// It is used only to increase the timeout in retryFn.
var errDeadlineExceeded = errors.New("INTERNAL ERROR: deadline exceeded")
// SendAndRead sends a packet p to a destination dest and waits for the first
// response matching `match` as well as its Transaction ID and ClientHWAddr.
//
// If match is nil, the first packet matching the Transaction ID and
// ClientHWAddr is returned.
func (c *Client) SendAndRead(ctx context.Context, dest *net.UDPAddr, p *dhcpv4.DHCPv4, match Matcher) (*dhcpv4.DHCPv4, error) {
var response *dhcpv4.DHCPv4
err := c.retryFn(func(timeout time.Duration) error {
ch, rem, err := c.send(dest, p)
if err != nil {
return err
}
c.logger.PrintMessage("sent message", p)
defer rem()
for {
select {
case <-c.done:
return ErrNoResponse
case <-time.After(timeout):
return errDeadlineExceeded
case <-ctx.Done():
return ctx.Err()
case packet := <-ch:
if match == nil || match(packet) {
c.logger.PrintMessage("received message", packet)
response = packet
return nil
}
}
}
})
if err == errDeadlineExceeded {
return nil, ErrNoResponse
}
if err != nil {
return nil, err
}
return response, nil
}
func (c *Client) retryFn(fn func(timeout time.Duration) error) error {
timeout := c.timeout
// Each retry takes the amount of timeout at worst.
for i := 0; i < c.retry || c.retry < 0; i++ { // TODO: why is this called "retry" if this is "tries" ("retries"+1)?
switch err := fn(timeout); err {
case nil:
// Got it!
return nil
case errDeadlineExceeded:
// Double timeout, then retry.
timeout *= 2
default:
return err
}
}
return errDeadlineExceeded
}

View File

@@ -0,0 +1,333 @@
// Copyright 2018 the u-root Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build linux
// github.com/hugelgupf/socketpair is Linux-only
// +build go1.12
package nclient4
import (
"bytes"
"context"
"fmt"
"net"
"sync"
"testing"
"time"
"github.com/hugelgupf/socketpair"
"github.com/insomniacslk/dhcp/dhcpv4"
"github.com/insomniacslk/dhcp/dhcpv4/server4"
)
type handler struct {
mu sync.Mutex
received []*dhcpv4.DHCPv4
// Each received packet can have more than one response (in theory,
// from different servers sending different Advertise, for example).
responses [][]*dhcpv4.DHCPv4
}
func (h *handler) handle(conn net.PacketConn, peer net.Addr, m *dhcpv4.DHCPv4) {
h.mu.Lock()
defer h.mu.Unlock()
h.received = append(h.received, m)
if len(h.responses) > 0 {
for _, resp := range h.responses[0] {
_, _ = conn.WriteTo(resp.ToBytes(), peer)
}
h.responses = h.responses[1:]
}
}
func serveAndClient(ctx context.Context, responses [][]*dhcpv4.DHCPv4, opts ...ClientOpt) (*Client, net.PacketConn) {
// Fake PacketConn connection.
clientRawConn, serverRawConn, err := socketpair.PacketSocketPair()
if err != nil {
panic(err)
}
clientConn := NewBroadcastUDPConn(clientRawConn, &net.UDPAddr{Port: ClientPort})
serverConn := NewBroadcastUDPConn(serverRawConn, &net.UDPAddr{Port: ServerPort})
o := []ClientOpt{WithRetry(1), WithTimeout(2 * time.Second)}
o = append(o, opts...)
mc, err := NewWithConn(clientConn, net.HardwareAddr{0xa, 0xb, 0xc, 0xd, 0xe, 0xf}, o...)
if err != nil {
panic(err)
}
h := &handler{responses: responses}
s, err := server4.NewServer("", nil, h.handle, server4.WithConn(serverConn))
if err != nil {
panic(err)
}
go func() {
_ = s.Serve()
}()
return mc, serverConn
}
func ComparePacket(got *dhcpv4.DHCPv4, want *dhcpv4.DHCPv4) error {
if got == nil && got == want {
return nil
}
if (want == nil || got == nil) && (got != want) {
return fmt.Errorf("packet got %v, want %v", got, want)
}
if !bytes.Equal(got.ToBytes(), want.ToBytes()) {
return fmt.Errorf("packet got %v, want %v", got, want)
}
return nil
}
func pktsExpected(got []*dhcpv4.DHCPv4, want []*dhcpv4.DHCPv4) error {
if len(got) != len(want) {
return fmt.Errorf("got %d packets, want %d packets", len(got), len(want))
}
for i := range got {
if err := ComparePacket(got[i], want[i]); err != nil {
return err
}
}
return nil
}
func newPacketWeirdHWAddr(op dhcpv4.OpcodeType, xid dhcpv4.TransactionID) *dhcpv4.DHCPv4 {
p, err := dhcpv4.New()
if err != nil {
panic(fmt.Sprintf("newpacket: %v", err))
}
p.OpCode = op
p.TransactionID = xid
p.ClientHWAddr = net.HardwareAddr{0xa, 0xb, 0xc, 0xd, 0xe, 0xf, 1, 2, 3, 4, 5, 6}
return p
}
func newPacket(op dhcpv4.OpcodeType, xid dhcpv4.TransactionID) *dhcpv4.DHCPv4 {
p, err := dhcpv4.New()
if err != nil {
panic(fmt.Sprintf("newpacket: %v", err))
}
p.OpCode = op
p.TransactionID = xid
p.ClientHWAddr = net.HardwareAddr{0xa, 0xb, 0xc, 0xd, 0xe, 0xf}
return p
}
func TestSendAndRead(t *testing.T) {
for _, tt := range []struct {
desc string
send *dhcpv4.DHCPv4
server []*dhcpv4.DHCPv4
// If want is nil, we assume server[0] contains what is wanted.
want *dhcpv4.DHCPv4
wantErr error
}{
{
desc: "two response packets",
send: newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x33, 0x33, 0x33, 0x33}),
server: []*dhcpv4.DHCPv4{
newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}),
newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}),
newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}),
newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}),
newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}),
},
want: newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}),
},
{
desc: "one response packet",
send: newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x33, 0x33, 0x33, 0x33}),
server: []*dhcpv4.DHCPv4{
newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}),
},
want: newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}),
},
{
desc: "one response packet, one invalid XID, one invalid opcode, one invalid hwaddr",
send: newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x33, 0x33, 0x33, 0x33}),
server: []*dhcpv4.DHCPv4{
newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x77, 0x33, 0x33, 0x33}),
newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x33, 0x33, 0x33, 0x33}),
newPacketWeirdHWAddr(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}),
newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}),
},
want: newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}),
},
{
desc: "discard wrong XID",
send: newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x33, 0x33, 0x33, 0x33}),
server: []*dhcpv4.DHCPv4{
newPacket(dhcpv4.OpcodeBootReply, [4]byte{0, 0, 0, 0}),
},
want: nil, // Explicitly empty.
wantErr: ErrNoResponse,
},
{
desc: "no response, timeout",
send: newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x33, 0x33, 0x33, 0x33}),
wantErr: ErrNoResponse,
},
} {
t.Run(tt.desc, func(t *testing.T) {
// Both server and client only get 2 seconds.
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
mc, _ := serveAndClient(ctx, [][]*dhcpv4.DHCPv4{tt.server},
// Use an unbuffered channel to make sure we
// have no deadlocks.
withBufferCap(0))
defer mc.Close()
rcvd, err := mc.SendAndRead(context.Background(), DefaultServers, tt.send, nil)
if err != tt.wantErr {
t.Error(err)
}
if err := ComparePacket(rcvd, tt.want); err != nil {
t.Errorf("got unexpected packets: %v", err)
}
})
}
}
func TestParallelSendAndRead(t *testing.T) {
pkt := newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x33, 0x33, 0x33, 0x33})
// Both the server and client only get 2 seconds.
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
mc, _ := serveAndClient(ctx, [][]*dhcpv4.DHCPv4{},
WithTimeout(10*time.Second),
// Use an unbuffered channel to make sure nothing blocks.
withBufferCap(0))
defer mc.Close()
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
if _, err := mc.SendAndRead(context.Background(), DefaultServers, pkt, nil); err != ErrNoResponse {
t.Errorf("SendAndRead(%v) = %v, want %v", pkt, err, ErrNoResponse)
}
}()
wg.Add(1)
go func() {
defer wg.Done()
time.Sleep(4 * time.Second)
if err := mc.Close(); err != nil {
t.Errorf("closing failed: %v", err)
}
}()
wg.Wait()
}
func TestReuseXID(t *testing.T) {
pkt := newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x33, 0x33, 0x33, 0x33})
// Both the server and client only get 2 seconds.
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
mc, _ := serveAndClient(ctx, [][]*dhcpv4.DHCPv4{})
defer mc.Close()
if _, err := mc.SendAndRead(context.Background(), DefaultServers, pkt, nil); err != ErrNoResponse {
t.Errorf("SendAndRead(%v) = %v, want %v", pkt, err, ErrNoResponse)
}
if _, err := mc.SendAndRead(context.Background(), DefaultServers, pkt, nil); err != ErrNoResponse {
t.Errorf("SendAndRead(%v) = %v, want %v", pkt, err, ErrNoResponse)
}
}
func TestSimpleSendAndReadDiscardGarbage(t *testing.T) {
pkt := newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x33, 0x33, 0x33, 0x33})
responses := newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33})
// Both the server and client only get 2 seconds.
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
mc, udpConn := serveAndClient(ctx, [][]*dhcpv4.DHCPv4{{responses}})
defer mc.Close()
// Too short for valid DHCPv4 packet.
_, _ = udpConn.WriteTo([]byte{0x01}, nil)
_, _ = udpConn.WriteTo([]byte{0x01, 0x2}, nil)
rcvd, err := mc.SendAndRead(ctx, DefaultServers, pkt, nil)
if err != nil {
t.Errorf("SendAndRead(%v) = %v, want nil", pkt, err)
}
if err := ComparePacket(rcvd, responses); err != nil {
t.Errorf("got unexpected packets: %v", err)
}
}
func TestMultipleSendAndRead(t *testing.T) {
for _, tt := range []struct {
desc string
send []*dhcpv4.DHCPv4
server [][]*dhcpv4.DHCPv4
wantErr []error
}{
{
desc: "two requests, two responses",
send: []*dhcpv4.DHCPv4{
newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x33, 0x33, 0x33, 0x33}),
newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x44, 0x44, 0x44, 0x44}),
},
server: [][]*dhcpv4.DHCPv4{
[]*dhcpv4.DHCPv4{ // Response for first packet.
newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}),
},
[]*dhcpv4.DHCPv4{ // Response for second packet.
newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x44, 0x44, 0x44, 0x44}),
},
},
wantErr: []error{
nil,
nil,
},
},
} {
// Both server and client only get 2 seconds.
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
mc, _ := serveAndClient(ctx, tt.server)
defer mc.Close()
for i, send := range tt.send {
ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
rcvd, err := mc.SendAndRead(ctx, DefaultServers, send, nil)
if wantErr := tt.wantErr[i]; err != wantErr {
t.Errorf("SendAndReadOne(%v): got %v, want %v", send, err, wantErr)
}
if err := pktsExpected([]*dhcpv4.DHCPv4{rcvd}, tt.server[i]); err != nil {
t.Errorf("got unexpected packets: %v", err)
}
}
}
}

View File

@@ -0,0 +1,144 @@
// Copyright 2018 the u-root Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build darwin dragonfly freebsd linux netbsd openbsd solaris
// +build go1.12
package nclient4
import (
"errors"
"io"
"net"
"github.com/mdlayher/ethernet"
"github.com/mdlayher/raw"
"github.com/u-root/u-root/pkg/uio"
)
var (
// BroadcastMac is the broadcast MAC address.
//
// Any UDP packet sent to this address is broadcast on the subnet.
BroadcastMac = net.HardwareAddr([]byte{255, 255, 255, 255, 255, 255})
)
var (
// ErrUDPAddrIsRequired is an error used when a passed argument is not of type "*net.UDPAddr".
ErrUDPAddrIsRequired = errors.New("must supply UDPAddr")
)
// NewRawUDPConn returns a UDP connection bound to the interface and port
// given based on a raw packet socket. All packets are broadcasted.
//
// The interface can be completely unconfigured.
func NewRawUDPConn(iface string, port int) (net.PacketConn, error) {
ifc, err := net.InterfaceByName(iface)
if err != nil {
return nil, err
}
rawConn, err := raw.ListenPacket(ifc, uint16(ethernet.EtherTypeIPv4), &raw.Config{LinuxSockDGRAM: true})
if err != nil {
return nil, err
}
return NewBroadcastUDPConn(rawConn, &net.UDPAddr{Port: port}), nil
}
// BroadcastRawUDPConn uses a raw socket to send UDP packets to the broadcast
// MAC address.
type BroadcastRawUDPConn struct {
// PacketConn is a raw DGRAM socket.
net.PacketConn
// boundAddr is the address this RawUDPConn is "bound" to.
//
// Calls to ReadFrom will only return packets destined to this address.
boundAddr *net.UDPAddr
}
// NewBroadcastUDPConn returns a PacketConn that marshals and unmarshals UDP
// packets, sending them to the broadcast MAC at on rawPacketConn.
//
// Calls to ReadFrom will only return packets destined to boundAddr.
func NewBroadcastUDPConn(rawPacketConn net.PacketConn, boundAddr *net.UDPAddr) net.PacketConn {
return &BroadcastRawUDPConn{
PacketConn: rawPacketConn,
boundAddr: boundAddr,
}
}
func udpMatch(addr *net.UDPAddr, bound *net.UDPAddr) bool {
if bound == nil {
return true
}
if bound.IP != nil && !bound.IP.Equal(addr.IP) {
return false
}
return bound.Port == addr.Port
}
// ReadFrom implements net.PacketConn.ReadFrom.
//
// ReadFrom reads raw IP packets and will try to match them against
// upc.boundAddr. Any matching packets are returned via the given buffer.
func (upc *BroadcastRawUDPConn) ReadFrom(b []byte) (int, net.Addr, error) {
ipHdrMaxLen := IPv4MaximumHeaderSize
udpHdrLen := UDPMinimumSize
for {
pkt := make([]byte, ipHdrMaxLen+udpHdrLen+len(b))
n, _, err := upc.PacketConn.ReadFrom(pkt)
if err != nil {
return 0, nil, err
}
if n == 0 {
return 0, nil, io.EOF
}
pkt = pkt[:n]
buf := uio.NewBigEndianBuffer(pkt)
// To read the header length, access data directly.
ipHdr := IPv4(buf.Data())
ipHdr = IPv4(buf.Consume(int(ipHdr.HeaderLength())))
if ipHdr.TransportProtocol() != UDPProtocolNumber {
continue
}
udpHdr := UDP(buf.Consume(udpHdrLen))
addr := &net.UDPAddr{
IP: ipHdr.DestinationAddress(),
Port: int(udpHdr.DestinationPort()),
}
if !udpMatch(addr, upc.boundAddr) {
continue
}
srcAddr := &net.UDPAddr{
IP: ipHdr.SourceAddress(),
Port: int(udpHdr.SourcePort()),
}
// Extra padding after end of IP packet should be ignored,
// if not dhcp option parsing will fail.
dhcpLen := int(ipHdr.PayloadLength()) - udpHdrLen
return copy(b, buf.Consume(dhcpLen)), srcAddr, nil
}
}
// WriteTo implements net.PacketConn.WriteTo and broadcasts all packets at the
// raw socket level.
//
// WriteTo wraps the given packet in the appropriate UDP and IP header before
// sending it on the packet conn.
func (upc *BroadcastRawUDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
udpAddr, ok := addr.(*net.UDPAddr)
if !ok {
return 0, ErrUDPAddrIsRequired
}
// Using the boundAddr is not quite right here, but it works.
packet := udp4pkt(b, udpAddr, upc.boundAddr)
// Broadcasting is not always right, but hell, what the ARP do I know.
return upc.PacketConn.WriteTo(packet, &raw.Addr{HardwareAddr: BroadcastMac})
}

View File

@@ -0,0 +1,377 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// This file contains code taken from gVisor.
// +build darwin dragonfly freebsd linux netbsd openbsd solaris
// +build go1.12
package nclient4
import (
"encoding/binary"
"net"
"github.com/u-root/u-root/pkg/uio"
)
const (
versIHL = 0
tos = 1
totalLen = 2
id = 4
flagsFO = 6
ttl = 8
protocol = 9
checksum = 10
srcAddr = 12
dstAddr = 16
)
// TransportProtocolNumber is the number of a transport protocol.
type TransportProtocolNumber uint32
// IPv4Fields contains the fields of an IPv4 packet. It is used to describe the
// fields of a packet that needs to be encoded.
type IPv4Fields struct {
// IHL is the "internet header length" field of an IPv4 packet.
IHL uint8
// TOS is the "type of service" field of an IPv4 packet.
TOS uint8
// TotalLength is the "total length" field of an IPv4 packet.
TotalLength uint16
// ID is the "identification" field of an IPv4 packet.
ID uint16
// Flags is the "flags" field of an IPv4 packet.
Flags uint8
// FragmentOffset is the "fragment offset" field of an IPv4 packet.
FragmentOffset uint16
// TTL is the "time to live" field of an IPv4 packet.
TTL uint8
// Protocol is the "protocol" field of an IPv4 packet.
Protocol uint8
// Checksum is the "checksum" field of an IPv4 packet.
Checksum uint16
// SrcAddr is the "source ip address" of an IPv4 packet.
SrcAddr net.IP
// DstAddr is the "destination ip address" of an IPv4 packet.
DstAddr net.IP
}
// IPv4 represents an ipv4 header stored in a byte array.
// Most of the methods of IPv4 access to the underlying slice without
// checking the boundaries and could panic because of 'index out of range'.
// Always call IsValid() to validate an instance of IPv4 before using other methods.
type IPv4 []byte
const (
// IPv4MinimumSize is the minimum size of a valid IPv4 packet.
IPv4MinimumSize = 20
// IPv4MaximumHeaderSize is the maximum size of an IPv4 header. Given
// that there are only 4 bits to represents the header length in 32-bit
// units, the header cannot exceed 15*4 = 60 bytes.
IPv4MaximumHeaderSize = 60
// IPv4AddressSize is the size, in bytes, of an IPv4 address.
IPv4AddressSize = 4
// IPv4Version is the version of the ipv4 protocol.
IPv4Version = 4
)
var (
// IPv4Broadcast is the broadcast address of the IPv4 protocol.
IPv4Broadcast = net.IP{0xff, 0xff, 0xff, 0xff}
// IPv4Any is the non-routable IPv4 "any" meta address.
IPv4Any = net.IP{0, 0, 0, 0}
)
// Flags that may be set in an IPv4 packet.
const (
IPv4FlagMoreFragments = 1 << iota
IPv4FlagDontFragment
)
// HeaderLength returns the value of the "header length" field of the ipv4
// header.
func (b IPv4) HeaderLength() uint8 {
return (b[versIHL] & 0xf) * 4
}
// Protocol returns the value of the protocol field of the ipv4 header.
func (b IPv4) Protocol() uint8 {
return b[protocol]
}
// SourceAddress returns the "source address" field of the ipv4 header.
func (b IPv4) SourceAddress() net.IP {
return net.IP(b[srcAddr : srcAddr+IPv4AddressSize])
}
// DestinationAddress returns the "destination address" field of the ipv4
// header.
func (b IPv4) DestinationAddress() net.IP {
return net.IP(b[dstAddr : dstAddr+IPv4AddressSize])
}
// TransportProtocol implements Network.TransportProtocol.
func (b IPv4) TransportProtocol() TransportProtocolNumber {
return TransportProtocolNumber(b.Protocol())
}
// Payload implements Network.Payload.
func (b IPv4) Payload() []byte {
return b[b.HeaderLength():][:b.PayloadLength()]
}
// PayloadLength returns the length of the payload portion of the ipv4 packet.
func (b IPv4) PayloadLength() uint16 {
return b.TotalLength() - uint16(b.HeaderLength())
}
// TotalLength returns the "total length" field of the ipv4 header.
func (b IPv4) TotalLength() uint16 {
return binary.BigEndian.Uint16(b[totalLen:])
}
// SetTotalLength sets the "total length" field of the ipv4 header.
func (b IPv4) SetTotalLength(totalLength uint16) {
binary.BigEndian.PutUint16(b[totalLen:], totalLength)
}
// SetChecksum sets the checksum field of the ipv4 header.
func (b IPv4) SetChecksum(v uint16) {
binary.BigEndian.PutUint16(b[checksum:], v)
}
// SetFlagsFragmentOffset sets the "flags" and "fragment offset" fields of the
// ipv4 header.
func (b IPv4) SetFlagsFragmentOffset(flags uint8, offset uint16) {
v := (uint16(flags) << 13) | (offset >> 3)
binary.BigEndian.PutUint16(b[flagsFO:], v)
}
// SetSourceAddress sets the "source address" field of the ipv4 header.
func (b IPv4) SetSourceAddress(addr net.IP) {
copy(b[srcAddr:srcAddr+IPv4AddressSize], addr.To4())
}
// SetDestinationAddress sets the "destination address" field of the ipv4
// header.
func (b IPv4) SetDestinationAddress(addr net.IP) {
copy(b[dstAddr:dstAddr+IPv4AddressSize], addr.To4())
}
// CalculateChecksum calculates the checksum of the ipv4 header.
func (b IPv4) CalculateChecksum() uint16 {
return Checksum(b[:b.HeaderLength()], 0)
}
// Encode encodes all the fields of the ipv4 header.
func (b IPv4) Encode(i *IPv4Fields) {
b[versIHL] = (4 << 4) | ((i.IHL / 4) & 0xf)
b[tos] = i.TOS
b.SetTotalLength(i.TotalLength)
binary.BigEndian.PutUint16(b[id:], i.ID)
b.SetFlagsFragmentOffset(i.Flags, i.FragmentOffset)
b[ttl] = i.TTL
b[protocol] = i.Protocol
b.SetChecksum(i.Checksum)
copy(b[srcAddr:srcAddr+IPv4AddressSize], i.SrcAddr)
copy(b[dstAddr:dstAddr+IPv4AddressSize], i.DstAddr)
}
const (
udpSrcPort = 0
udpDstPort = 2
udpLength = 4
udpChecksum = 6
)
// UDPFields contains the fields of a UDP packet. It is used to describe the
// fields of a packet that needs to be encoded.
type UDPFields struct {
// SrcPort is the "source port" field of a UDP packet.
SrcPort uint16
// DstPort is the "destination port" field of a UDP packet.
DstPort uint16
// Length is the "length" field of a UDP packet.
Length uint16
// Checksum is the "checksum" field of a UDP packet.
Checksum uint16
}
// UDP represents a UDP header stored in a byte array.
type UDP []byte
const (
// UDPMinimumSize is the minimum size of a valid UDP packet.
UDPMinimumSize = 8
// UDPProtocolNumber is UDP's transport protocol number.
UDPProtocolNumber TransportProtocolNumber = 17
)
// SourcePort returns the "source port" field of the udp header.
func (b UDP) SourcePort() uint16 {
return binary.BigEndian.Uint16(b[udpSrcPort:])
}
// DestinationPort returns the "destination port" field of the udp header.
func (b UDP) DestinationPort() uint16 {
return binary.BigEndian.Uint16(b[udpDstPort:])
}
// Length returns the "length" field of the udp header.
func (b UDP) Length() uint16 {
return binary.BigEndian.Uint16(b[udpLength:])
}
// SetSourcePort sets the "source port" field of the udp header.
func (b UDP) SetSourcePort(port uint16) {
binary.BigEndian.PutUint16(b[udpSrcPort:], port)
}
// SetDestinationPort sets the "destination port" field of the udp header.
func (b UDP) SetDestinationPort(port uint16) {
binary.BigEndian.PutUint16(b[udpDstPort:], port)
}
// SetChecksum sets the "checksum" field of the udp header.
func (b UDP) SetChecksum(checksum uint16) {
binary.BigEndian.PutUint16(b[udpChecksum:], checksum)
}
// Payload returns the data contained in the UDP datagram.
func (b UDP) Payload() []byte {
return b[UDPMinimumSize:]
}
// Checksum returns the "checksum" field of the udp header.
func (b UDP) Checksum() uint16 {
return binary.BigEndian.Uint16(b[udpChecksum:])
}
// CalculateChecksum calculates the checksum of the udp packet, given the total
// length of the packet and the checksum of the network-layer pseudo-header
// (excluding the total length) and the checksum of the payload.
func (b UDP) CalculateChecksum(partialChecksum uint16, totalLen uint16) uint16 {
// Add the length portion of the checksum to the pseudo-checksum.
tmp := make([]byte, 2)
binary.BigEndian.PutUint16(tmp, totalLen)
checksum := Checksum(tmp, partialChecksum)
// Calculate the rest of the checksum.
return Checksum(b[:UDPMinimumSize], checksum)
}
// Encode encodes all the fields of the udp header.
func (b UDP) Encode(u *UDPFields) {
binary.BigEndian.PutUint16(b[udpSrcPort:], u.SrcPort)
binary.BigEndian.PutUint16(b[udpDstPort:], u.DstPort)
binary.BigEndian.PutUint16(b[udpLength:], u.Length)
binary.BigEndian.PutUint16(b[udpChecksum:], u.Checksum)
}
func calculateChecksum(buf []byte, initial uint32) uint16 {
v := initial
l := len(buf)
if l&1 != 0 {
l--
v += uint32(buf[l]) << 8
}
for i := 0; i < l; i += 2 {
v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
}
return ChecksumCombine(uint16(v), uint16(v>>16))
}
// Checksum calculates the checksum (as defined in RFC 1071) of the bytes in the
// given byte array.
//
// The initial checksum must have been computed on an even number of bytes.
func Checksum(buf []byte, initial uint16) uint16 {
return calculateChecksum(buf, uint32(initial))
}
// ChecksumCombine combines the two uint16 to form their checksum. This is done
// by adding them and the carry.
//
// Note that checksum a must have been computed on an even number of bytes.
func ChecksumCombine(a, b uint16) uint16 {
v := uint32(a) + uint32(b)
return uint16(v + v>>16)
}
// PseudoHeaderChecksum calculates the pseudo-header checksum for the
// given destination protocol and network address, ignoring the length
// field. Pseudo-headers are needed by transport layers when calculating
// their own checksum.
func PseudoHeaderChecksum(protocol TransportProtocolNumber, srcAddr net.IP, dstAddr net.IP) uint16 {
xsum := Checksum([]byte(srcAddr), 0)
xsum = Checksum([]byte(dstAddr), xsum)
return Checksum([]byte{0, uint8(protocol)}, xsum)
}
func udp4pkt(packet []byte, dest *net.UDPAddr, src *net.UDPAddr) []byte {
ipLen := IPv4MinimumSize
udpLen := UDPMinimumSize
h := make([]byte, 0, ipLen+udpLen+len(packet))
hdr := uio.NewBigEndianBuffer(h)
ipv4fields := &IPv4Fields{
IHL: IPv4MinimumSize,
TotalLength: uint16(ipLen + udpLen + len(packet)),
TTL: 64, // Per RFC 1700's recommendation for IP time to live
Protocol: uint8(UDPProtocolNumber),
SrcAddr: src.IP.To4(),
DstAddr: dest.IP.To4(),
}
ipv4hdr := IPv4(hdr.WriteN(ipLen))
ipv4hdr.Encode(ipv4fields)
ipv4hdr.SetChecksum(^ipv4hdr.CalculateChecksum())
udphdr := UDP(hdr.WriteN(udpLen))
udphdr.Encode(&UDPFields{
SrcPort: uint16(src.Port),
DstPort: uint16(dest.Port),
Length: uint16(udpLen + len(packet)),
})
xsum := Checksum(packet, PseudoHeaderChecksum(
ipv4hdr.TransportProtocol(), ipv4fields.SrcAddr, ipv4fields.DstAddr))
udphdr.SetChecksum(^udphdr.CalculateChecksum(xsum, udphdr.Length()))
hdr.WriteBytes(packet)
return hdr.Data()
}

View File

@@ -0,0 +1,312 @@
package dhcpd
import (
"errors"
"fmt"
"io/ioutil"
"net"
"os/exec"
"regexp"
"runtime"
"strings"
"github.com/AdguardTeam/AdGuardHome/internal/util"
"github.com/AdguardTeam/golibs/file"
"github.com/AdguardTeam/golibs/log"
)
// Check if network interface has a static IP configured
// Supports: Raspbian.
func HasStaticIP(ifaceName string) (bool, error) {
if runtime.GOOS == "linux" {
body, err := ioutil.ReadFile("/etc/dhcpcd.conf")
if err != nil {
return false, err
}
return hasStaticIPDhcpcdConf(string(body), ifaceName), nil
}
if runtime.GOOS == "darwin" {
return hasStaticIPDarwin(ifaceName)
}
return false, fmt.Errorf("cannot check if IP is static: not supported on %s", runtime.GOOS)
}
// Set a static IP for the specified network interface
func SetStaticIP(ifaceName string) error {
if runtime.GOOS == "linux" {
return setStaticIPDhcpdConf(ifaceName)
}
if runtime.GOOS == "darwin" {
return setStaticIPDarwin(ifaceName)
}
return fmt.Errorf("cannot set static IP on %s", runtime.GOOS)
}
// for dhcpcd.conf
func hasStaticIPDhcpcdConf(dhcpConf, ifaceName string) bool {
lines := strings.Split(dhcpConf, "\n")
nameLine := fmt.Sprintf("interface %s", ifaceName)
withinInterfaceCtx := false
for _, line := range lines {
line = strings.TrimSpace(line)
if withinInterfaceCtx && len(line) == 0 {
// an empty line resets our state
withinInterfaceCtx = false
}
if len(line) == 0 || line[0] == '#' {
continue
}
line = strings.TrimSpace(line)
if !withinInterfaceCtx {
if line == nameLine {
// we found our interface
withinInterfaceCtx = true
}
} else {
if strings.HasPrefix(line, "interface ") {
// we found another interface - reset our state
withinInterfaceCtx = false
continue
}
if strings.HasPrefix(line, "static ip_address=") {
return true
}
}
}
return false
}
// Get gateway IP address
func getGatewayIP(ifaceName string) string {
cmd := exec.Command("ip", "route", "show", "dev", ifaceName)
log.Tracef("executing %s %v", cmd.Path, cmd.Args)
d, err := cmd.Output()
if err != nil || cmd.ProcessState.ExitCode() != 0 {
return ""
}
fields := strings.Fields(string(d))
if len(fields) < 3 || fields[0] != "default" {
return ""
}
ip := net.ParseIP(fields[2])
if ip == nil {
return ""
}
return fields[2]
}
// setStaticIPDhcpdConf - updates /etc/dhcpd.conf and sets the current IP address to be static
func setStaticIPDhcpdConf(ifaceName string) error {
ip := util.GetSubnet(ifaceName)
if len(ip) == 0 {
return errors.New("can't get IP address")
}
ip4, _, err := net.ParseCIDR(ip)
if err != nil {
return err
}
gatewayIP := getGatewayIP(ifaceName)
add := updateStaticIPDhcpcdConf(ifaceName, ip, gatewayIP, ip4.String())
body, err := ioutil.ReadFile("/etc/dhcpcd.conf")
if err != nil {
return err
}
body = append(body, []byte(add)...)
err = file.SafeWrite("/etc/dhcpcd.conf", body)
if err != nil {
return err
}
return nil
}
// updates dhcpd.conf content -- sets static IP address there
// for dhcpcd.conf
func updateStaticIPDhcpcdConf(ifaceName, ip, gatewayIP, dnsIP string) string {
var body []byte
add := fmt.Sprintf("\ninterface %s\nstatic ip_address=%s\n",
ifaceName, ip)
body = append(body, []byte(add)...)
if len(gatewayIP) != 0 {
add = fmt.Sprintf("static routers=%s\n",
gatewayIP)
body = append(body, []byte(add)...)
}
add = fmt.Sprintf("static domain_name_servers=%s\n\n",
dnsIP)
body = append(body, []byte(add)...)
return string(body)
}
// Check if network interface has a static IP configured
// Supports: MacOS.
func hasStaticIPDarwin(ifaceName string) (bool, error) {
portInfo, err := getCurrentHardwarePortInfo(ifaceName)
if err != nil {
return false, err
}
return portInfo.static, nil
}
// setStaticIPDarwin - uses networksetup util to set the current IP address to be static
// Additionally it configures the current DNS servers as well
func setStaticIPDarwin(ifaceName string) error {
portInfo, err := getCurrentHardwarePortInfo(ifaceName)
if err != nil {
return err
}
if portInfo.static {
return errors.New("IP address is already static")
}
dnsAddrs, err := getEtcResolvConfServers()
if err != nil {
return err
}
args := make([]string, 0)
args = append(args, "-setdnsservers", portInfo.name)
args = append(args, dnsAddrs...)
// Setting DNS servers is necessary when configuring a static IP
code, _, err := util.RunCommand("networksetup", args...)
if err != nil {
return err
}
if code != 0 {
return fmt.Errorf("failed to set DNS servers, code=%d", code)
}
// Actually configures hardware port to have static IP
code, _, err = util.RunCommand("networksetup", "-setmanual",
portInfo.name, portInfo.ip, portInfo.subnet, portInfo.gatewayIP)
if err != nil {
return err
}
if code != 0 {
return fmt.Errorf("failed to set DNS servers, code=%d", code)
}
return nil
}
// getCurrentHardwarePortInfo gets information the specified network interface
func getCurrentHardwarePortInfo(ifaceName string) (hardwarePortInfo, error) {
// First of all we should find hardware port name
m := getNetworkSetupHardwareReports()
hardwarePort, ok := m[ifaceName]
if !ok {
return hardwarePortInfo{}, fmt.Errorf("could not find hardware port for %s", ifaceName)
}
return getHardwarePortInfo(hardwarePort)
}
// getNetworkSetupHardwareReports parses the output of the `networksetup -listallhardwareports` command
// it returns a map where the key is the interface name, and the value is the "hardware port"
// returns nil if it fails to parse the output
func getNetworkSetupHardwareReports() map[string]string {
_, out, err := util.RunCommand("networksetup", "-listallhardwareports")
if err != nil {
return nil
}
re, err := regexp.Compile("Hardware Port: (.*?)\nDevice: (.*?)\n")
if err != nil {
return nil
}
m := make(map[string]string, 0)
matches := re.FindAllStringSubmatch(out, -1)
for i := range matches {
port := matches[i][1]
device := matches[i][2]
m[device] = port
}
return m
}
// hardwarePortInfo - information obtained using MacOS networksetup
// about the current state of the internet connection
type hardwarePortInfo struct {
name string
ip string
subnet string
gatewayIP string
static bool
}
func getHardwarePortInfo(hardwarePort string) (hardwarePortInfo, error) {
h := hardwarePortInfo{}
_, out, err := util.RunCommand("networksetup", "-getinfo", hardwarePort)
if err != nil {
return h, err
}
re := regexp.MustCompile("IP address: (.*?)\nSubnet mask: (.*?)\nRouter: (.*?)\n")
match := re.FindStringSubmatch(out)
if len(match) == 0 {
return h, errors.New("could not find hardware port info")
}
h.name = hardwarePort
h.ip = match[1]
h.subnet = match[2]
h.gatewayIP = match[3]
if strings.Index(out, "Manual Configuration") == 0 {
h.static = true
}
return h, nil
}
// Gets a list of nameservers currently configured in the /etc/resolv.conf
func getEtcResolvConfServers() ([]string, error) {
body, err := ioutil.ReadFile("/etc/resolv.conf")
if err != nil {
return nil, err
}
re := regexp.MustCompile("nameserver ([a-zA-Z0-9.:]+)")
matches := re.FindAllStringSubmatch(string(body), -1)
if len(matches) == 0 {
return nil, errors.New("found no DNS servers in /etc/resolv.conf")
}
addrs := make([]string, 0)
for i := range matches {
addrs = append(addrs, matches[i][1])
}
return addrs, nil
}

View File

@@ -0,0 +1,63 @@
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
package dhcpd
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestHasStaticIPDhcpcdConf(t *testing.T) {
dhcpdConf := `#comment
# comment
interface eth0
static ip_address=192.168.0.1/24
# interface wlan0
static ip_address=192.168.1.1/24
# comment
`
assert.True(t, !hasStaticIPDhcpcdConf(dhcpdConf, "wlan0"))
dhcpdConf = `#comment
# comment
interface eth0
static ip_address=192.168.0.1/24
# interface wlan0
static ip_address=192.168.1.1/24
# comment
interface wlan0
# comment
static ip_address=192.168.2.1/24
`
assert.True(t, hasStaticIPDhcpcdConf(dhcpdConf, "wlan0"))
}
func TestSetStaticIPDhcpcdConf(t *testing.T) {
dhcpcdConf := `
interface wlan0
static ip_address=192.168.0.2/24
static routers=192.168.0.1
static domain_name_servers=192.168.0.2
`
s := updateStaticIPDhcpcdConf("wlan0", "192.168.0.2/24", "192.168.0.1", "192.168.0.2")
assert.Equal(t, dhcpcdConf, s)
// without gateway
dhcpcdConf = `
interface wlan0
static ip_address=192.168.0.2/24
static domain_name_servers=192.168.0.2
`
s = updateStaticIPDhcpcdConf("wlan0", "192.168.0.2/24", "", "192.168.0.2")
assert.Equal(t, dhcpcdConf, s)
}

View File

@@ -0,0 +1,13 @@
package dhcpd
import (
"errors"
"net"
"golang.org/x/net/ipv4"
)
// Create a socket for receiving broadcast packets
func newBroadcastPacketConn(bindAddr net.IP, port int, ifname string) (*ipv4.PacketConn, error) {
return nil, errors.New("newBroadcastPacketConn(): not supported on Windows")
}

View File

@@ -0,0 +1,239 @@
package dhcpd
import (
"encoding/binary"
"fmt"
"net"
"sync/atomic"
"time"
"github.com/AdguardTeam/golibs/log"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv6"
)
type raCtx struct {
raAllowSlaac bool // send RA packets without MO flags
raSlaacOnly bool // send RA packets with MO flags
ipAddr net.IP // source IP address (link-local-unicast)
dnsIPAddr net.IP // IP address for DNS Server option
prefixIPAddr net.IP // IP address for Prefix option
ifaceName string
iface *net.Interface
packetSendPeriod time.Duration // how often RA packets are sent
conn *icmp.PacketConn // ICMPv6 socket
stop atomic.Value // stop the packet sending loop
}
type icmpv6RA struct {
managedAddressConfiguration bool
otherConfiguration bool
prefix net.IP
prefixLen int
sourceLinkLayerAddress net.HardwareAddr
recursiveDNSServer net.IP
mtu uint32
}
// Create an ICMPv6.RouterAdvertisement packet with all necessary options.
//
// ICMPv6:
// type[1]
// code[1]
// chksum[2]
// body (RouterAdvertisement):
// Cur Hop Limit[1]
// Flags[1]: MO......
// Router Lifetime[2]
// Reachable Time[4]
// Retrans Timer[4]
// Option=Prefix Information(3):
// Type[1]
// Length * 8bytes[1]
// Prefix Length[1]
// Flags[1]: LA......
// Valid Lifetime[4]
// Preferred Lifetime[4]
// Reserved[4]
// Prefix[16]
// Option=MTU(5):
// Type[1]
// Length * 8bytes[1]
// Reserved[2]
// MTU[4]
// Option=Source link-layer address(1):
// Link-Layer Address[6]
// Option=Recursive DNS Server(25):
// Type[1]
// Length * 8bytes[1]
// Reserved[2]
// Lifetime[4]
// Addresses of IPv6 Recursive DNS Servers[16]
func createICMPv6RAPacket(params icmpv6RA) []byte {
data := make([]byte, 88)
i := 0
// ICMPv6:
data[i] = 134 // type
data[i+1] = 0 // code
data[i+2] = 0 // chksum
data[i+3] = 0
i += 4
// RouterAdvertisement:
data[i] = 64 // Cur Hop Limit[1]
i++
data[i] = 0 // Flags[1]: MO......
if params.managedAddressConfiguration {
data[i] |= 0x80
}
if params.otherConfiguration {
data[i] |= 0x40
}
i++
binary.BigEndian.PutUint16(data[i:], 1800) // Router Lifetime[2]
i += 2
binary.BigEndian.PutUint32(data[i:], 0) // Reachable Time[4]
i += 4
binary.BigEndian.PutUint32(data[i:], 0) // Retrans Timer[4]
i += 4
// Option=Prefix Information:
data[i] = 3 // Type
data[i+1] = 4 // Length
i += 2
data[i] = byte(params.prefixLen) // Prefix Length[1]
i++
data[i] = 0xc0 // Flags[1]
i++
binary.BigEndian.PutUint32(data[i:], 3600) // Valid Lifetime[4]
i += 4
binary.BigEndian.PutUint32(data[i:], 3600) // Preferred Lifetime[4]
i += 4
binary.BigEndian.PutUint32(data[i:], 0) // Reserved[4]
i += 4
copy(data[i:], params.prefix[:8]) // Prefix[16]
binary.BigEndian.PutUint32(data[i+8:], 0)
binary.BigEndian.PutUint32(data[i+12:], 0)
i += 16
// Option=MTU:
data[i] = 5 // Type
data[i+1] = 1 // Length
i += 2
binary.BigEndian.PutUint16(data[i:], 0) // Reserved[2]
i += 2
binary.BigEndian.PutUint32(data[i:], params.mtu) // MTU[4]
i += 4
// Option=Source link-layer address:
data[i] = 1 // Type
data[i+1] = 1 // Length
i += 2
copy(data[i:], params.sourceLinkLayerAddress) // Link-Layer Address[6]
i += 6
// Option=Recursive DNS Server:
data[i] = 25 // Type
data[i+1] = 3 // Length
i += 2
binary.BigEndian.PutUint16(data[i:], 0) // Reserved[2]
i += 2
binary.BigEndian.PutUint32(data[i:], 3600) // Lifetime[4]
i += 4
copy(data[i:], params.recursiveDNSServer) // Addresses of IPv6 Recursive DNS Servers[16]
return data
}
// Init - initialize RA module
func (ra *raCtx) Init() error {
ra.stop.Store(0)
ra.conn = nil
if !(ra.raAllowSlaac || ra.raSlaacOnly) {
return nil
}
log.Debug("DHCPv6 RA: source IP address: %s DNS IP address: %s",
ra.ipAddr, ra.dnsIPAddr)
params := icmpv6RA{
managedAddressConfiguration: !ra.raSlaacOnly,
otherConfiguration: !ra.raSlaacOnly,
mtu: uint32(ra.iface.MTU),
prefixLen: 64,
recursiveDNSServer: ra.dnsIPAddr,
sourceLinkLayerAddress: ra.iface.HardwareAddr,
}
params.prefix = make([]byte, 16)
copy(params.prefix, ra.prefixIPAddr[:8]) // /64
data := createICMPv6RAPacket(params)
var err error
ipAndScope := ra.ipAddr.String() + "%" + ra.ifaceName
ra.conn, err = icmp.ListenPacket("ip6:ipv6-icmp", ipAndScope)
if err != nil {
return fmt.Errorf("DHCPv6 RA: icmp.ListenPacket: %s", err)
}
success := false
defer func() {
if !success {
ra.Close()
}
}()
con6 := ra.conn.IPv6PacketConn()
if err := con6.SetHopLimit(255); err != nil {
return fmt.Errorf("DHCPv6 RA: SetHopLimit: %s", err)
}
if err := con6.SetMulticastHopLimit(255); err != nil {
return fmt.Errorf("DHCPv6 RA: SetMulticastHopLimit: %s", err)
}
msg := &ipv6.ControlMessage{
HopLimit: 255,
Src: ra.ipAddr,
IfIndex: ra.iface.Index,
}
addr := &net.UDPAddr{
IP: net.ParseIP("ff02::1"),
}
go func() {
log.Debug("DHCPv6 RA: starting to send periodic RouterAdvertisement packets")
for ra.stop.Load() == 0 {
_, err = con6.WriteTo(data, msg, addr)
if err != nil {
log.Error("DHCPv6 RA: WriteTo: %s", err)
}
time.Sleep(ra.packetSendPeriod)
}
log.Debug("DHCPv6 RA: loop exit")
}()
success = true
return nil
}
// Close - close module
func (ra *raCtx) Close() {
log.Debug("DHCPv6 RA: closing")
ra.stop.Store(1)
if ra.conn != nil {
ra.conn.Close()
}
}

View File

@@ -0,0 +1,31 @@
package dhcpd
import (
"bytes"
"net"
"testing"
"github.com/stretchr/testify/assert"
)
func TestRA(t *testing.T) {
ra := icmpv6RA{
managedAddressConfiguration: false,
otherConfiguration: true,
mtu: 1500,
prefix: net.ParseIP("1234::"),
prefixLen: 64,
recursiveDNSServer: net.ParseIP("fe80::800:27ff:fe00:0"),
sourceLinkLayerAddress: []byte{0x0a, 0x00, 0x27, 0x00, 0x00, 0x00},
}
data := createICMPv6RAPacket(ra)
dataCorrect := []byte{
0x86, 0x00, 0x00, 0x00, 0x40, 0x40, 0x07, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x03, 0x04, 0x40, 0xc0, 0x00, 0x00, 0x0e, 0x10, 0x00, 0x00, 0x0e, 0x10, 0x00, 0x00, 0x00, 0x00,
0x12, 0x34, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x05, 0x01, 0x00, 0x00, 0x00, 0x00, 0x05, 0xdc, 0x01, 0x01, 0x0a, 0x00, 0x27, 0x00, 0x00, 0x00,
0x19, 0x03, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x10, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x08, 0x00, 0x27, 0xff, 0xfe, 0x00, 0x00, 0x00,
}
assert.True(t, bytes.Equal(data, dataCorrect))
}

100
internal/dhcpd/server.go Normal file
View File

@@ -0,0 +1,100 @@
package dhcpd
import (
"net"
"time"
)
// DHCPServer - DHCP server interface
type DHCPServer interface {
// ResetLeases - reset leases
ResetLeases(leases []*Lease)
// GetLeases - get leases
GetLeases(flags int) []Lease
// GetLeasesRef - get reference to leases array
GetLeasesRef() []*Lease
// AddStaticLease - add a static lease
AddStaticLease(lease Lease) error
// RemoveStaticLease - remove a static lease
RemoveStaticLease(l Lease) error
// FindMACbyIP - find a MAC address by IP address in the currently active DHCP leases
FindMACbyIP(ip net.IP) net.HardwareAddr
// WriteDiskConfig4 - copy disk configuration
WriteDiskConfig4(c *V4ServerConf)
// WriteDiskConfig6 - copy disk configuration
WriteDiskConfig6(c *V6ServerConf)
// Start - start server
Start() error
// Stop - stop server
Stop()
}
// V4ServerConf - server configuration
type V4ServerConf struct {
Enabled bool `yaml:"-"`
InterfaceName string `yaml:"-"`
GatewayIP string `yaml:"gateway_ip"`
SubnetMask string `yaml:"subnet_mask"`
// The first & the last IP address for dynamic leases
// Bytes [0..2] of the last allowed IP address must match the first IP
RangeStart string `yaml:"range_start"`
RangeEnd string `yaml:"range_end"`
LeaseDuration uint32 `yaml:"lease_duration"` // in seconds
// IP conflict detector: time (ms) to wait for ICMP reply
// 0: disable
ICMPTimeout uint32 `yaml:"icmp_timeout_msec"`
// Custom Options.
//
// Option with arbitrary hexadecimal data:
// DEC_CODE hex HEX_DATA
// where DEC_CODE is a decimal DHCPv4 option code in range [1..255]
//
// Option with IP data (only 1 IP is supported):
// DEC_CODE ip IP_ADDR
Options []string `yaml:"options"`
ipStart net.IP // starting IP address for dynamic leases
ipEnd net.IP // ending IP address for dynamic leases
leaseTime time.Duration // the time during which a dynamic lease is considered valid
dnsIPAddrs []net.IP // IPv4 addresses to return to DHCP clients as DNS server addresses
routerIP net.IP // value for Option Router
subnetMask net.IPMask // value for Option SubnetMask
options []dhcpOption
// Server calls this function when leases data changes
notify func(uint32)
}
// V6ServerConf - server configuration
type V6ServerConf struct {
Enabled bool `yaml:"-"`
InterfaceName string `yaml:"-"`
// The first IP address for dynamic leases
// The last allowed IP address ends with 0xff byte
RangeStart string `yaml:"range_start"`
LeaseDuration uint32 `yaml:"lease_duration"` // in seconds
RaSlaacOnly bool `yaml:"ra_slaac_only"` // send ICMPv6.RA packets without MO flags
RaAllowSlaac bool `yaml:"ra_allow_slaac"` // send ICMPv6.RA packets with MO flags
ipStart net.IP // starting IP address for dynamic leases
leaseTime time.Duration // the time during which a dynamic lease is considered valid
dnsIPAddrs []net.IP // IPv6 addresses to return to DHCP clients as DNS server addresses
// Server calls this function when leases data changes
notify func(uint32)
}
type dhcpOption struct {
code uint8
val []byte
}

673
internal/dhcpd/v4.go Normal file
View File

@@ -0,0 +1,673 @@
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
package dhcpd
import (
"bytes"
"encoding/binary"
"fmt"
"net"
"sync"
"time"
"github.com/AdguardTeam/golibs/log"
"github.com/insomniacslk/dhcp/dhcpv4"
"github.com/insomniacslk/dhcp/dhcpv4/server4"
"github.com/sparrc/go-ping"
)
// v4Server - DHCPv4 server
type v4Server struct {
srv *server4.Server
leasesLock sync.Mutex
leases []*Lease
ipAddrs [256]byte
conf V4ServerConf
}
// WriteDiskConfig4 - write configuration
func (s *v4Server) WriteDiskConfig4(c *V4ServerConf) {
*c = s.conf
}
// WriteDiskConfig6 - write configuration
func (s *v4Server) WriteDiskConfig6(c *V6ServerConf) {
}
// Return TRUE if IP address is within range [start..stop]
func ip4InRange(start net.IP, stop net.IP, ip net.IP) bool {
if len(start) != 4 || len(stop) != 4 {
return false
}
from := binary.BigEndian.Uint32(start)
to := binary.BigEndian.Uint32(stop)
check := binary.BigEndian.Uint32(ip)
return from <= check && check <= to
}
// ResetLeases - reset leases
func (s *v4Server) ResetLeases(leases []*Lease) {
s.leases = nil
for _, l := range leases {
if l.Expiry.Unix() != leaseExpireStatic &&
!ip4InRange(s.conf.ipStart, s.conf.ipEnd, l.IP) {
log.Debug("DHCPv4: skipping a lease with IP %v: not within current IP range", l.IP)
continue
}
s.addLease(l)
}
}
// GetLeasesRef - get leases
func (s *v4Server) GetLeasesRef() []*Lease {
return s.leases
}
// Return TRUE if this lease holds a blacklisted IP
func (s *v4Server) blacklisted(l *Lease) bool {
return l.HWAddr.String() == "00:00:00:00:00:00"
}
// GetLeases returns the list of current DHCP leases (thread-safe)
func (s *v4Server) GetLeases(flags int) []Lease {
var result []Lease
now := time.Now().Unix()
s.leasesLock.Lock()
for _, lease := range s.leases {
if ((flags&LeasesDynamic) != 0 && lease.Expiry.Unix() > now && !s.blacklisted(lease)) ||
((flags&LeasesStatic) != 0 && lease.Expiry.Unix() == leaseExpireStatic) {
result = append(result, *lease)
}
}
s.leasesLock.Unlock()
return result
}
// FindMACbyIP - find a MAC address by IP address in the currently active DHCP leases
func (s *v4Server) FindMACbyIP(ip net.IP) net.HardwareAddr {
now := time.Now().Unix()
s.leasesLock.Lock()
defer s.leasesLock.Unlock()
ip4 := ip.To4()
if ip4 == nil {
return nil
}
for _, l := range s.leases {
if l.IP.Equal(ip4) {
unix := l.Expiry.Unix()
if unix > now || unix == leaseExpireStatic {
return l.HWAddr
}
}
}
return nil
}
// Add the specified IP to the black list for a time period
func (s *v4Server) blacklistLease(lease *Lease) {
hw := make(net.HardwareAddr, 6)
lease.HWAddr = hw
lease.Hostname = ""
lease.Expiry = time.Now().Add(s.conf.leaseTime)
}
// Remove (swap) lease by index
func (s *v4Server) leaseRemoveSwapByIndex(i int) {
s.ipAddrs[s.leases[i].IP[3]] = 0
log.Debug("DHCPv4: removed lease %s", s.leases[i].HWAddr)
n := len(s.leases)
if i != n-1 {
s.leases[i] = s.leases[n-1] // swap with the last element
}
s.leases = s.leases[:n-1]
}
// Remove a dynamic lease with the same properties
// Return error if a static lease is found
func (s *v4Server) rmDynamicLease(lease Lease) error {
for i := 0; i < len(s.leases); i++ {
l := s.leases[i]
if bytes.Equal(l.HWAddr, lease.HWAddr) {
if l.Expiry.Unix() == leaseExpireStatic {
return fmt.Errorf("static lease already exists")
}
s.leaseRemoveSwapByIndex(i)
if i == len(s.leases) {
break
}
l = s.leases[i]
}
if net.IP.Equal(l.IP, lease.IP) {
if l.Expiry.Unix() == leaseExpireStatic {
return fmt.Errorf("static lease already exists")
}
s.leaseRemoveSwapByIndex(i)
}
}
return nil
}
// Add a lease
func (s *v4Server) addLease(l *Lease) {
s.leases = append(s.leases, l)
s.ipAddrs[l.IP[3]] = 1
log.Debug("DHCPv4: added lease %s <-> %s", l.IP, l.HWAddr)
}
// Remove a lease with the same properties
func (s *v4Server) rmLease(lease Lease) error {
for i, l := range s.leases {
if net.IP.Equal(l.IP, lease.IP) {
if !bytes.Equal(l.HWAddr, lease.HWAddr) ||
l.Hostname != lease.Hostname {
return fmt.Errorf("Lease not found")
}
s.leaseRemoveSwapByIndex(i)
return nil
}
}
return fmt.Errorf("lease not found")
}
// AddStaticLease adds a static lease (thread-safe)
func (s *v4Server) AddStaticLease(lease Lease) error {
if len(lease.IP) != 4 {
return fmt.Errorf("invalid IP")
}
if len(lease.HWAddr) != 6 {
return fmt.Errorf("invalid MAC")
}
lease.Expiry = time.Unix(leaseExpireStatic, 0)
s.leasesLock.Lock()
err := s.rmDynamicLease(lease)
if err != nil {
s.leasesLock.Unlock()
return err
}
s.addLease(&lease)
s.conf.notify(LeaseChangedDBStore)
s.leasesLock.Unlock()
s.conf.notify(LeaseChangedAddedStatic)
return nil
}
// RemoveStaticLease removes a static lease (thread-safe)
func (s *v4Server) RemoveStaticLease(l Lease) error {
if len(l.IP) != 4 {
return fmt.Errorf("invalid IP")
}
if len(l.HWAddr) != 6 {
return fmt.Errorf("invalid MAC")
}
s.leasesLock.Lock()
err := s.rmLease(l)
if err != nil {
s.leasesLock.Unlock()
return err
}
s.conf.notify(LeaseChangedDBStore)
s.leasesLock.Unlock()
s.conf.notify(LeaseChangedRemovedStatic)
return nil
}
// Send ICMP to the specified machine
// Return TRUE if it doesn't reply, which probably means that the IP is available
func (s *v4Server) addrAvailable(target net.IP) bool {
if s.conf.ICMPTimeout == 0 {
return true
}
pinger, err := ping.NewPinger(target.String())
if err != nil {
log.Error("ping.NewPinger(): %v", err)
return true
}
pinger.SetPrivileged(true)
pinger.Timeout = time.Duration(s.conf.ICMPTimeout) * time.Millisecond
pinger.Count = 1
reply := false
pinger.OnRecv = func(pkt *ping.Packet) {
reply = true
}
log.Debug("DHCPv4: Sending ICMP Echo to %v", target)
pinger.Run()
if reply {
log.Info("DHCPv4: IP conflict: %v is already used by another device", target)
return false
}
log.Debug("DHCPv4: ICMP procedure is complete: %v", target)
return true
}
// Find lease by MAC
func (s *v4Server) findLease(mac net.HardwareAddr) *Lease {
for i := range s.leases {
if bytes.Equal(mac, s.leases[i].HWAddr) {
return s.leases[i]
}
}
return nil
}
// Get next free IP
func (s *v4Server) findFreeIP() net.IP {
for i := s.conf.ipStart[3]; ; i++ {
if s.ipAddrs[i] == 0 {
ip := make([]byte, 4)
copy(ip, s.conf.ipStart)
ip[3] = i
return ip
}
if i == s.conf.ipEnd[3] {
break
}
}
return nil
}
// Find an expired lease and return its index or -1
func (s *v4Server) findExpiredLease() int {
now := time.Now().Unix()
for i, lease := range s.leases {
if lease.Expiry.Unix() != leaseExpireStatic &&
lease.Expiry.Unix() <= now {
return i
}
}
return -1
}
// Reserve lease for MAC
func (s *v4Server) reserveLease(mac net.HardwareAddr) *Lease {
l := Lease{}
l.HWAddr = make([]byte, 6)
copy(l.HWAddr, mac)
l.IP = s.findFreeIP()
if l.IP == nil {
i := s.findExpiredLease()
if i < 0 {
return nil
}
copy(s.leases[i].HWAddr, mac)
return s.leases[i]
}
s.addLease(&l)
return &l
}
func (s *v4Server) commitLease(l *Lease) {
l.Expiry = time.Now().Add(s.conf.leaseTime)
s.leasesLock.Lock()
s.conf.notify(LeaseChangedDBStore)
s.leasesLock.Unlock()
s.conf.notify(LeaseChangedAdded)
}
// Process Discover request and return lease
func (s *v4Server) processDiscover(req *dhcpv4.DHCPv4, resp *dhcpv4.DHCPv4) *Lease {
mac := req.ClientHWAddr
s.leasesLock.Lock()
defer s.leasesLock.Unlock()
lease := s.findLease(mac)
if lease == nil {
toStore := false
for lease == nil {
lease = s.reserveLease(mac)
if lease == nil {
log.Debug("DHCPv4: No more IP addresses")
if toStore {
s.conf.notify(LeaseChangedDBStore)
}
return nil
}
toStore = true
if !s.addrAvailable(lease.IP) {
s.blacklistLease(lease)
lease = nil
continue
}
break
}
s.conf.notify(LeaseChangedDBStore)
} else {
reqIP := req.Options.Get(dhcpv4.OptionRequestedIPAddress)
if len(reqIP) != 0 &&
!bytes.Equal(reqIP, lease.IP) {
log.Debug("DHCPv4: different RequestedIP: %v != %v", reqIP, lease.IP)
}
}
resp.UpdateOption(dhcpv4.OptMessageType(dhcpv4.MessageTypeOffer))
return lease
}
type optFQDN struct {
name string
}
func (o *optFQDN) String() string {
return "optFQDN"
}
// flags[1]
// A-RR[1]
// PTR-RR[1]
// name[]
func (o *optFQDN) ToBytes() []byte {
b := make([]byte, 3+len(o.name))
i := 0
b[i] = 0x03 // f_server_overrides | f_server
i++
b[i] = 255 // A-RR
i++
b[i] = 255 // PTR-RR
i++
copy(b[i:], []byte(o.name))
return b
}
// Process Request request and return lease
// Return false if we don't need to reply
func (s *v4Server) processRequest(req *dhcpv4.DHCPv4, resp *dhcpv4.DHCPv4) (*Lease, bool) {
var lease *Lease
mac := req.ClientHWAddr
hostname := req.Options.Get(dhcpv4.OptionHostName)
reqIP := req.Options.Get(dhcpv4.OptionRequestedIPAddress)
if reqIP == nil {
reqIP = req.ClientIPAddr
}
sid := req.Options.Get(dhcpv4.OptionServerIdentifier)
if len(sid) != 0 &&
!bytes.Equal(sid, s.conf.dnsIPAddrs[0]) {
log.Debug("DHCPv4: Bad OptionServerIdentifier in Request message for %s", mac)
return nil, false
}
if len(reqIP) != 4 {
log.Debug("DHCPv4: Bad OptionRequestedIPAddress in Request message for %s", mac)
return nil, false
}
s.leasesLock.Lock()
for _, l := range s.leases {
if bytes.Equal(l.HWAddr, mac) {
if !bytes.Equal(l.IP, reqIP) {
s.leasesLock.Unlock()
log.Debug("DHCPv4: Mismatched OptionRequestedIPAddress in Request message for %s", mac)
return nil, true
}
lease = l
break
}
}
s.leasesLock.Unlock()
if lease == nil {
log.Debug("DHCPv4: No lease for %s", mac)
return nil, true
}
if lease.Expiry.Unix() != leaseExpireStatic {
lease.Hostname = string(hostname)
s.commitLease(lease)
} else if len(lease.Hostname) != 0 {
o := &optFQDN{
name: lease.Hostname,
}
fqdn := dhcpv4.Option{
Code: dhcpv4.OptionFQDN,
Value: o,
}
resp.UpdateOption(fqdn)
}
resp.UpdateOption(dhcpv4.OptMessageType(dhcpv4.MessageTypeAck))
return lease, true
}
// Find a lease associated with MAC and prepare response
// Return 1: OK
// Return 0: error; reply with Nak
// Return -1: error; don't reply
func (s *v4Server) process(req *dhcpv4.DHCPv4, resp *dhcpv4.DHCPv4) int {
var lease *Lease
resp.UpdateOption(dhcpv4.OptServerIdentifier(s.conf.dnsIPAddrs[0]))
switch req.MessageType() {
case dhcpv4.MessageTypeDiscover:
lease = s.processDiscover(req, resp)
if lease == nil {
return 0
}
case dhcpv4.MessageTypeRequest:
var toReply bool
lease, toReply = s.processRequest(req, resp)
if lease == nil {
if toReply {
return 0
}
return -1 // drop packet
}
}
resp.YourIPAddr = make([]byte, 4)
copy(resp.YourIPAddr, lease.IP)
resp.UpdateOption(dhcpv4.OptIPAddressLeaseTime(s.conf.leaseTime))
resp.UpdateOption(dhcpv4.OptRouter(s.conf.routerIP))
resp.UpdateOption(dhcpv4.OptSubnetMask(s.conf.subnetMask))
resp.UpdateOption(dhcpv4.OptDNS(s.conf.dnsIPAddrs...))
for _, opt := range s.conf.options {
resp.Options[opt.code] = opt.val
}
return 1
}
// client(0.0.0.0:68) -> (Request:ClientMAC,Type=Discover,ClientID,ReqIP,HostName) -> server(255.255.255.255:67)
// client(255.255.255.255:68) <- (Reply:YourIP,ClientMAC,Type=Offer,ServerID,SubnetMask,LeaseTime) <- server(<IP>:67)
// client(0.0.0.0:68) -> (Request:ClientMAC,Type=Request,ClientID,ReqIP||ClientIP,HostName,ServerID,ParamReqList) -> server(255.255.255.255:67)
// client(255.255.255.255:68) <- (Reply:YourIP,ClientMAC,Type=ACK,ServerID,SubnetMask,LeaseTime) <- server(<IP>:67)
func (s *v4Server) packetHandler(conn net.PacketConn, peer net.Addr, req *dhcpv4.DHCPv4) {
log.Debug("DHCPv4: received message: %s", req.Summary())
switch req.MessageType() {
case dhcpv4.MessageTypeDiscover,
dhcpv4.MessageTypeRequest:
//
default:
log.Debug("DHCPv4: unsupported message type %d", req.MessageType())
return
}
resp, err := dhcpv4.NewReplyFromRequest(req)
if err != nil {
log.Debug("DHCPv4: dhcpv4.New: %s", err)
return
}
if len(req.ClientHWAddr) != 6 {
log.Debug("DHCPv4: Invalid ClientHWAddr")
return
}
r := s.process(req, resp)
if r < 0 {
return
} else if r == 0 {
resp.Options.Update(dhcpv4.OptMessageType(dhcpv4.MessageTypeNak))
}
log.Debug("DHCPv4: sending: %s", resp.Summary())
_, err = conn.WriteTo(resp.ToBytes(), peer)
if err != nil {
log.Error("DHCPv4: conn.Write to %s failed: %s", peer, err)
return
}
}
// Start - start server
func (s *v4Server) Start() error {
if !s.conf.Enabled {
return nil
}
iface, err := net.InterfaceByName(s.conf.InterfaceName)
if err != nil {
return fmt.Errorf("DHCPv4: Couldn't find interface by name %s: %s", s.conf.InterfaceName, err)
}
log.Debug("DHCPv4: starting...")
s.conf.dnsIPAddrs = getIfaceIPv4(*iface)
if len(s.conf.dnsIPAddrs) == 0 {
log.Debug("DHCPv4: no IPv6 address for interface %s", iface.Name)
return nil
}
laddr := &net.UDPAddr{
IP: net.ParseIP("0.0.0.0"),
Port: dhcpv4.ServerPort,
}
s.srv, err = server4.NewServer(iface.Name, laddr, s.packetHandler, server4.WithDebugLogger())
if err != nil {
return err
}
log.Info("DHCPv4: listening")
go func() {
err = s.srv.Serve()
log.Debug("DHCPv4: srv.Serve: %s", err)
}()
return nil
}
// Stop - stop server
func (s *v4Server) Stop() {
if s.srv == nil {
return
}
log.Debug("DHCPv4: stopping")
err := s.srv.Close()
if err != nil {
log.Error("DHCPv4: srv.Close: %s", err)
}
// now s.srv.Serve() will return
s.srv = nil
}
// Create DHCPv4 server
func v4Create(conf V4ServerConf) (DHCPServer, error) {
s := &v4Server{}
s.conf = conf
if !conf.Enabled {
return s, nil
}
var err error
s.conf.routerIP, err = parseIPv4(s.conf.GatewayIP)
if err != nil {
return s, fmt.Errorf("DHCPv4: %s", err)
}
subnet, err := parseIPv4(s.conf.SubnetMask)
if err != nil || !isValidSubnetMask(subnet) {
return s, fmt.Errorf("DHCPv4: invalid subnet mask: %s", s.conf.SubnetMask)
}
s.conf.subnetMask = make([]byte, 4)
copy(s.conf.subnetMask, subnet)
s.conf.ipStart, err = parseIPv4(conf.RangeStart)
if s.conf.ipStart == nil {
return s, fmt.Errorf("DHCPv4: %s", err)
}
if s.conf.ipStart[0] == 0 {
return s, fmt.Errorf("DHCPv4: invalid range start IP")
}
s.conf.ipEnd, err = parseIPv4(conf.RangeEnd)
if s.conf.ipEnd == nil {
return s, fmt.Errorf("DHCPv4: %s", err)
}
if !net.IP.Equal(s.conf.ipStart[:3], s.conf.ipEnd[:3]) ||
s.conf.ipStart[3] > s.conf.ipEnd[3] {
return s, fmt.Errorf("DHCPv4: range end IP should match range start IP")
}
if conf.LeaseDuration == 0 {
s.conf.leaseTime = time.Hour * 24
s.conf.LeaseDuration = uint32(s.conf.leaseTime.Seconds())
} else {
s.conf.leaseTime = time.Second * time.Duration(conf.LeaseDuration)
}
for _, o := range conf.Options {
code, val := parseOptionString(o)
if code == 0 {
log.Debug("DHCPv4: bad option string: %s", o)
continue
}
opt := dhcpOption{
code: code,
val: val,
}
s.conf.options = append(s.conf.options, opt)
}
return s, nil
}

View File

@@ -0,0 +1,47 @@
package dhcpd
// 'u-root/u-root' package, a dependency of 'insomniacslk/dhcp' package, doesn't build on Windows
import "net"
type winServer struct {
}
func (s *winServer) ResetLeases(leases []*Lease) {
}
func (s *winServer) GetLeases(flags int) []Lease {
return nil
}
func (s *winServer) GetLeasesRef() []*Lease {
return nil
}
func (s *winServer) AddStaticLease(lease Lease) error {
return nil
}
func (s *winServer) RemoveStaticLease(l Lease) error {
return nil
}
func (s *winServer) FindMACbyIP(ip net.IP) net.HardwareAddr {
return nil
}
func (s *winServer) WriteDiskConfig4(c *V4ServerConf) {
}
func (s *winServer) WriteDiskConfig6(c *V6ServerConf) {
}
func (s *winServer) Start() error {
return nil
}
func (s *winServer) Stop() {
}
func (s *winServer) Reset() {
}
func v4Create(conf V4ServerConf) (DHCPServer, error) {
return &winServer{}, nil
}
func v6Create(conf V6ServerConf) (DHCPServer, error) {
return &winServer{}, nil
}

238
internal/dhcpd/v4_test.go Normal file
View File

@@ -0,0 +1,238 @@
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
package dhcpd
import (
"net"
"testing"
"github.com/insomniacslk/dhcp/dhcpv4"
"github.com/stretchr/testify/assert"
)
func notify4(flags uint32) {
}
func TestV4StaticLeaseAddRemove(t *testing.T) {
conf := V4ServerConf{
Enabled: true,
RangeStart: "192.168.10.100",
RangeEnd: "192.168.10.200",
GatewayIP: "192.168.10.1",
SubnetMask: "255.255.255.0",
notify: notify4,
}
s, err := v4Create(conf)
assert.True(t, err == nil)
ls := s.GetLeases(LeasesStatic)
assert.Equal(t, 0, len(ls))
// add static lease
l := Lease{}
l.IP = net.ParseIP("192.168.10.150").To4()
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
assert.True(t, s.AddStaticLease(l) == nil)
// try to add the same static lease - fail
assert.True(t, s.AddStaticLease(l) != nil)
// check
ls = s.GetLeases(LeasesStatic)
assert.Equal(t, 1, len(ls))
assert.Equal(t, "192.168.10.150", ls[0].IP.String())
assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
assert.True(t, ls[0].Expiry.Unix() == leaseExpireStatic)
// try to remove static lease - fail
l.IP = net.ParseIP("192.168.10.110").To4()
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
assert.True(t, s.RemoveStaticLease(l) != nil)
// remove static lease
l.IP = net.ParseIP("192.168.10.150").To4()
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
assert.True(t, s.RemoveStaticLease(l) == nil)
// check
ls = s.GetLeases(LeasesStatic)
assert.Equal(t, 0, len(ls))
}
func TestV4StaticLeaseAddReplaceDynamic(t *testing.T) {
conf := V4ServerConf{
Enabled: true,
RangeStart: "192.168.10.100",
RangeEnd: "192.168.10.200",
GatewayIP: "192.168.10.1",
SubnetMask: "255.255.255.0",
notify: notify4,
}
sIface, err := v4Create(conf)
s := sIface.(*v4Server)
assert.True(t, err == nil)
// add dynamic lease
ld := Lease{}
ld.IP = net.ParseIP("192.168.10.150").To4()
ld.HWAddr, _ = net.ParseMAC("11:aa:aa:aa:aa:aa")
s.addLease(&ld)
// add dynamic lease
{
ld := Lease{}
ld.IP = net.ParseIP("192.168.10.151").To4()
ld.HWAddr, _ = net.ParseMAC("22:aa:aa:aa:aa:aa")
s.addLease(&ld)
}
// add static lease with the same IP
l := Lease{}
l.IP = net.ParseIP("192.168.10.150").To4()
l.HWAddr, _ = net.ParseMAC("33:aa:aa:aa:aa:aa")
assert.True(t, s.AddStaticLease(l) == nil)
// add static lease with the same MAC
l = Lease{}
l.IP = net.ParseIP("192.168.10.152").To4()
l.HWAddr, _ = net.ParseMAC("22:aa:aa:aa:aa:aa")
assert.True(t, s.AddStaticLease(l) == nil)
// check
ls := s.GetLeases(LeasesStatic)
assert.Equal(t, 2, len(ls))
assert.Equal(t, "192.168.10.150", ls[0].IP.String())
assert.Equal(t, "33:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
assert.True(t, ls[0].Expiry.Unix() == leaseExpireStatic)
assert.Equal(t, "192.168.10.152", ls[1].IP.String())
assert.Equal(t, "22:aa:aa:aa:aa:aa", ls[1].HWAddr.String())
assert.True(t, ls[1].Expiry.Unix() == leaseExpireStatic)
}
func TestV4StaticLeaseGet(t *testing.T) {
conf := V4ServerConf{
Enabled: true,
RangeStart: "192.168.10.100",
RangeEnd: "192.168.10.200",
GatewayIP: "192.168.10.1",
SubnetMask: "255.255.255.0",
notify: notify4,
}
sIface, err := v4Create(conf)
s := sIface.(*v4Server)
assert.True(t, err == nil)
s.conf.dnsIPAddrs = []net.IP{net.ParseIP("192.168.10.1").To4()}
l := Lease{}
l.IP = net.ParseIP("192.168.10.150").To4()
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
assert.True(t, s.AddStaticLease(l) == nil)
// "Discover"
mac, _ := net.ParseMAC("aa:aa:aa:aa:aa:aa")
req, _ := dhcpv4.NewDiscovery(mac)
resp, _ := dhcpv4.NewReplyFromRequest(req)
assert.Equal(t, 1, s.process(req, resp))
// check "Offer"
assert.Equal(t, dhcpv4.MessageTypeOffer, resp.MessageType())
assert.Equal(t, "aa:aa:aa:aa:aa:aa", resp.ClientHWAddr.String())
assert.Equal(t, "192.168.10.150", resp.YourIPAddr.String())
assert.Equal(t, "192.168.10.1", resp.Router()[0].String())
assert.Equal(t, "192.168.10.1", resp.ServerIdentifier().String())
assert.Equal(t, "255.255.255.0", net.IP(resp.SubnetMask()).String())
assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds())
// "Request"
req, _ = dhcpv4.NewRequestFromOffer(resp)
resp, _ = dhcpv4.NewReplyFromRequest(req)
assert.Equal(t, 1, s.process(req, resp))
// check "Ack"
assert.Equal(t, dhcpv4.MessageTypeAck, resp.MessageType())
assert.Equal(t, "aa:aa:aa:aa:aa:aa", resp.ClientHWAddr.String())
assert.Equal(t, "192.168.10.150", resp.YourIPAddr.String())
assert.Equal(t, "192.168.10.1", resp.Router()[0].String())
assert.Equal(t, "192.168.10.1", resp.ServerIdentifier().String())
assert.Equal(t, "255.255.255.0", net.IP(resp.SubnetMask()).String())
assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds())
dnsAddrs := resp.DNS()
assert.Equal(t, 1, len(dnsAddrs))
assert.Equal(t, "192.168.10.1", dnsAddrs[0].String())
// check lease
ls := s.GetLeases(LeasesStatic)
assert.Equal(t, 1, len(ls))
assert.Equal(t, "192.168.10.150", ls[0].IP.String())
assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
}
func TestV4DynamicLeaseGet(t *testing.T) {
conf := V4ServerConf{
Enabled: true,
RangeStart: "192.168.10.100",
RangeEnd: "192.168.10.200",
GatewayIP: "192.168.10.1",
SubnetMask: "255.255.255.0",
notify: notify4,
Options: []string{
"81 hex 303132",
"82 ip 1.2.3.4",
},
}
sIface, err := v4Create(conf)
s := sIface.(*v4Server)
assert.True(t, err == nil)
s.conf.dnsIPAddrs = []net.IP{net.ParseIP("192.168.10.1").To4()}
// "Discover"
mac, _ := net.ParseMAC("aa:aa:aa:aa:aa:aa")
req, _ := dhcpv4.NewDiscovery(mac)
resp, _ := dhcpv4.NewReplyFromRequest(req)
assert.Equal(t, 1, s.process(req, resp))
// check "Offer"
assert.Equal(t, dhcpv4.MessageTypeOffer, resp.MessageType())
assert.Equal(t, "aa:aa:aa:aa:aa:aa", resp.ClientHWAddr.String())
assert.Equal(t, "192.168.10.100", resp.YourIPAddr.String())
assert.Equal(t, "192.168.10.1", resp.Router()[0].String())
assert.Equal(t, "192.168.10.1", resp.ServerIdentifier().String())
assert.Equal(t, "255.255.255.0", net.IP(resp.SubnetMask()).String())
assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds())
assert.Equal(t, []byte("012"), resp.Options[uint8(dhcpv4.OptionFQDN)])
assert.Equal(t, "1.2.3.4", net.IP(resp.Options[uint8(dhcpv4.OptionRelayAgentInformation)]).String())
// "Request"
req, _ = dhcpv4.NewRequestFromOffer(resp)
resp, _ = dhcpv4.NewReplyFromRequest(req)
assert.Equal(t, 1, s.process(req, resp))
// check "Ack"
assert.Equal(t, dhcpv4.MessageTypeAck, resp.MessageType())
assert.Equal(t, "aa:aa:aa:aa:aa:aa", resp.ClientHWAddr.String())
assert.Equal(t, "192.168.10.100", resp.YourIPAddr.String())
assert.Equal(t, "192.168.10.1", resp.Router()[0].String())
assert.Equal(t, "192.168.10.1", resp.ServerIdentifier().String())
assert.Equal(t, "255.255.255.0", net.IP(resp.SubnetMask()).String())
assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds())
dnsAddrs := resp.DNS()
assert.Equal(t, 1, len(dnsAddrs))
assert.Equal(t, "192.168.10.1", dnsAddrs[0].String())
// check lease
ls := s.GetLeases(LeasesDynamic)
assert.Equal(t, 1, len(ls))
assert.Equal(t, "192.168.10.100", ls[0].IP.String())
assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
start := net.ParseIP("192.168.10.100").To4()
stop := net.ParseIP("192.168.10.200").To4()
assert.True(t, !ip4InRange(start, stop, net.ParseIP("192.168.10.99").To4()))
assert.True(t, !ip4InRange(start, stop, net.ParseIP("192.168.11.100").To4()))
assert.True(t, !ip4InRange(start, stop, net.ParseIP("192.168.11.201").To4()))
assert.True(t, ip4InRange(start, stop, net.ParseIP("192.168.10.100").To4()))
}

680
internal/dhcpd/v6.go Normal file
View File

@@ -0,0 +1,680 @@
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
package dhcpd
import (
"bytes"
"fmt"
"net"
"sync"
"time"
"github.com/AdguardTeam/golibs/log"
"github.com/insomniacslk/dhcp/dhcpv6"
"github.com/insomniacslk/dhcp/dhcpv6/server6"
"github.com/insomniacslk/dhcp/iana"
)
const valueIAID = "ADGH" // value for IANA.ID
// v6Server - DHCPv6 server
type v6Server struct {
srv *server6.Server
leasesLock sync.Mutex
leases []*Lease
ipAddrs [256]byte
sid dhcpv6.Duid
ra raCtx // RA module
conf V6ServerConf
}
// WriteDiskConfig4 - write configuration
func (s *v6Server) WriteDiskConfig4(c *V4ServerConf) {
}
// WriteDiskConfig6 - write configuration
func (s *v6Server) WriteDiskConfig6(c *V6ServerConf) {
*c = s.conf
}
// Return TRUE if IP address is within range [start..0xff]
// nolint(staticcheck)
func ip6InRange(start net.IP, ip net.IP) bool {
if len(start) != 16 {
return false
}
if !bytes.Equal(start[:15], ip[:15]) {
return false
}
return start[15] <= ip[15]
}
// ResetLeases - reset leases
func (s *v6Server) ResetLeases(ll []*Lease) {
s.leases = nil
for _, l := range ll {
if l.Expiry.Unix() != leaseExpireStatic &&
!ip6InRange(s.conf.ipStart, l.IP) {
log.Debug("DHCPv6: skipping a lease with IP %v: not within current IP range", l.IP)
continue
}
s.addLease(l)
}
}
// GetLeases - get current leases
func (s *v6Server) GetLeases(flags int) []Lease {
var result []Lease
s.leasesLock.Lock()
for _, lease := range s.leases {
if lease.Expiry.Unix() == leaseExpireStatic {
if (flags & LeasesStatic) != 0 {
result = append(result, *lease)
}
} else {
if (flags & LeasesDynamic) != 0 {
result = append(result, *lease)
}
}
}
s.leasesLock.Unlock()
return result
}
// GetLeasesRef - get leases
func (s *v6Server) GetLeasesRef() []*Lease {
return s.leases
}
// FindMACbyIP - find a MAC address by IP address in the currently active DHCP leases
func (s *v6Server) FindMACbyIP(ip net.IP) net.HardwareAddr {
now := time.Now().Unix()
s.leasesLock.Lock()
defer s.leasesLock.Unlock()
for _, l := range s.leases {
if l.IP.Equal(ip) {
unix := l.Expiry.Unix()
if unix > now || unix == leaseExpireStatic {
return l.HWAddr
}
}
}
return nil
}
// Remove (swap) lease by index
func (s *v6Server) leaseRemoveSwapByIndex(i int) {
s.ipAddrs[s.leases[i].IP[15]] = 0
log.Debug("DHCPv6: removed lease %s", s.leases[i].HWAddr)
n := len(s.leases)
if i != n-1 {
s.leases[i] = s.leases[n-1] // swap with the last element
}
s.leases = s.leases[:n-1]
}
// Remove a dynamic lease with the same properties
// Return error if a static lease is found
func (s *v6Server) rmDynamicLease(lease Lease) error {
for i := 0; i < len(s.leases); i++ {
l := s.leases[i]
if bytes.Equal(l.HWAddr, lease.HWAddr) {
if l.Expiry.Unix() == leaseExpireStatic {
return fmt.Errorf("static lease already exists")
}
s.leaseRemoveSwapByIndex(i)
if i == len(s.leases) {
break
}
l = s.leases[i]
}
if net.IP.Equal(l.IP, lease.IP) {
if l.Expiry.Unix() == leaseExpireStatic {
return fmt.Errorf("static lease already exists")
}
s.leaseRemoveSwapByIndex(i)
}
}
return nil
}
// AddStaticLease - add a static lease
func (s *v6Server) AddStaticLease(l Lease) error {
if len(l.IP) != 16 {
return fmt.Errorf("invalid IP")
}
if len(l.HWAddr) != 6 {
return fmt.Errorf("invalid MAC")
}
l.Expiry = time.Unix(leaseExpireStatic, 0)
s.leasesLock.Lock()
err := s.rmDynamicLease(l)
if err != nil {
s.leasesLock.Unlock()
return err
}
s.addLease(&l)
s.conf.notify(LeaseChangedDBStore)
s.leasesLock.Unlock()
s.conf.notify(LeaseChangedAddedStatic)
return nil
}
// RemoveStaticLease - remove a static lease
func (s *v6Server) RemoveStaticLease(l Lease) error {
if len(l.IP) != 16 {
return fmt.Errorf("invalid IP")
}
if len(l.HWAddr) != 6 {
return fmt.Errorf("invalid MAC")
}
s.leasesLock.Lock()
err := s.rmLease(l)
if err != nil {
s.leasesLock.Unlock()
return err
}
s.conf.notify(LeaseChangedDBStore)
s.leasesLock.Unlock()
s.conf.notify(LeaseChangedRemovedStatic)
return nil
}
// Add a lease
func (s *v6Server) addLease(l *Lease) {
s.leases = append(s.leases, l)
s.ipAddrs[l.IP[15]] = 1
log.Debug("DHCPv6: added lease %s <-> %s", l.IP, l.HWAddr)
}
// Remove a lease with the same properties
func (s *v6Server) rmLease(lease Lease) error {
for i, l := range s.leases {
if net.IP.Equal(l.IP, lease.IP) {
if !bytes.Equal(l.HWAddr, lease.HWAddr) ||
l.Hostname != lease.Hostname {
return fmt.Errorf("Lease not found")
}
s.leaseRemoveSwapByIndex(i)
return nil
}
}
return fmt.Errorf("lease not found")
}
// Find lease by MAC
func (s *v6Server) findLease(mac net.HardwareAddr) *Lease {
s.leasesLock.Lock()
defer s.leasesLock.Unlock()
for i := range s.leases {
if bytes.Equal(mac, s.leases[i].HWAddr) {
return s.leases[i]
}
}
return nil
}
// Find an expired lease and return its index or -1
func (s *v6Server) findExpiredLease() int {
now := time.Now().Unix()
for i, lease := range s.leases {
if lease.Expiry.Unix() != leaseExpireStatic &&
lease.Expiry.Unix() <= now {
return i
}
}
return -1
}
// Get next free IP
func (s *v6Server) findFreeIP() net.IP {
for i := s.conf.ipStart[15]; ; i++ {
if s.ipAddrs[i] == 0 {
ip := make([]byte, 16)
copy(ip, s.conf.ipStart)
ip[15] = i
return ip
}
if i == 0xff {
break
}
}
return nil
}
// Reserve lease for MAC
func (s *v6Server) reserveLease(mac net.HardwareAddr) *Lease {
l := Lease{}
l.HWAddr = make([]byte, 6)
copy(l.HWAddr, mac)
s.leasesLock.Lock()
defer s.leasesLock.Unlock()
copy(l.IP, s.conf.ipStart)
l.IP = s.findFreeIP()
if l.IP == nil {
i := s.findExpiredLease()
if i < 0 {
return nil
}
copy(s.leases[i].HWAddr, mac)
return s.leases[i]
}
s.addLease(&l)
return &l
}
func (s *v6Server) commitDynamicLease(l *Lease) {
l.Expiry = time.Now().Add(s.conf.leaseTime)
s.leasesLock.Lock()
s.conf.notify(LeaseChangedDBStore)
s.leasesLock.Unlock()
s.conf.notify(LeaseChangedAdded)
}
// Check Client ID
func (s *v6Server) checkCID(msg *dhcpv6.Message) error {
if msg.Options.ClientID() == nil {
return fmt.Errorf("DHCPv6: no ClientID option in request")
}
return nil
}
// Check ServerID policy
func (s *v6Server) checkSID(msg *dhcpv6.Message) error {
sid := msg.Options.ServerID()
switch msg.Type() {
case dhcpv6.MessageTypeSolicit,
dhcpv6.MessageTypeConfirm,
dhcpv6.MessageTypeRebind:
if sid != nil {
return fmt.Errorf("DHCPv6: drop packet: ServerID option in message %s", msg.Type().String())
}
case dhcpv6.MessageTypeRequest,
dhcpv6.MessageTypeRenew,
dhcpv6.MessageTypeRelease,
dhcpv6.MessageTypeDecline:
if sid == nil {
return fmt.Errorf("DHCPv6: drop packet: no ServerID option in message %s", msg.Type().String())
}
if !sid.Equal(s.sid) {
return fmt.Errorf("DHCPv6: drop packet: mismatched ServerID option in message %s: %s",
msg.Type().String(), sid.String())
}
}
return nil
}
// . IAAddress must be equal to the lease's IP
func (s *v6Server) checkIA(msg *dhcpv6.Message, lease *Lease) error {
switch msg.Type() {
case dhcpv6.MessageTypeRequest,
dhcpv6.MessageTypeConfirm,
dhcpv6.MessageTypeRenew,
dhcpv6.MessageTypeRebind:
oia := msg.Options.OneIANA()
if oia == nil {
return fmt.Errorf("no IANA option in %s", msg.Type().String())
}
oiaAddr := oia.Options.OneAddress()
if oiaAddr == nil {
return fmt.Errorf("no IANA.Addr option in %s", msg.Type().String())
}
if !oiaAddr.IPv6Addr.Equal(lease.IP) {
return fmt.Errorf("invalid IANA.Addr option in %s", msg.Type().String())
}
}
return nil
}
// Store lease in DB (if necessary) and return lease life time
func (s *v6Server) commitLease(msg *dhcpv6.Message, lease *Lease) time.Duration {
lifetime := s.conf.leaseTime
switch msg.Type() {
case dhcpv6.MessageTypeSolicit:
//
case dhcpv6.MessageTypeConfirm:
lifetime = lease.Expiry.Sub(time.Now())
case dhcpv6.MessageTypeRequest,
dhcpv6.MessageTypeRenew,
dhcpv6.MessageTypeRebind:
if lease.Expiry.Unix() != leaseExpireStatic {
s.commitDynamicLease(lease)
}
}
return lifetime
}
// Find a lease associated with MAC and prepare response
func (s *v6Server) process(msg *dhcpv6.Message, req dhcpv6.DHCPv6, resp dhcpv6.DHCPv6) bool {
switch msg.Type() {
case dhcpv6.MessageTypeSolicit,
dhcpv6.MessageTypeRequest,
dhcpv6.MessageTypeConfirm,
dhcpv6.MessageTypeRenew,
dhcpv6.MessageTypeRebind:
// continue
default:
return false
}
mac, err := dhcpv6.ExtractMAC(req)
if err != nil {
log.Debug("DHCPv6: dhcpv6.ExtractMAC: %s", err)
return false
}
lease := s.findLease(mac)
if lease == nil {
log.Debug("DHCPv6: no lease for: %s", mac)
switch msg.Type() {
case dhcpv6.MessageTypeSolicit:
lease = s.reserveLease(mac)
if lease == nil {
return false
}
default:
return false
}
}
err = s.checkIA(msg, lease)
if err != nil {
log.Debug("DHCPv6: %s", err)
return false
}
lifetime := s.commitLease(msg, lease)
oia := &dhcpv6.OptIANA{
T1: lifetime / 2,
T2: time.Duration(float32(lifetime) / 1.5),
}
roia := msg.Options.OneIANA()
if roia != nil {
copy(oia.IaId[:], roia.IaId[:])
} else {
copy(oia.IaId[:], []byte(valueIAID))
}
oiaAddr := &dhcpv6.OptIAAddress{
IPv6Addr: lease.IP,
PreferredLifetime: lifetime,
ValidLifetime: lifetime,
}
oia.Options = dhcpv6.IdentityOptions{
Options: []dhcpv6.Option{oiaAddr},
}
resp.AddOption(oia)
if msg.IsOptionRequested(dhcpv6.OptionDNSRecursiveNameServer) {
resp.UpdateOption(dhcpv6.OptDNS(s.conf.dnsIPAddrs...))
}
fqdn := msg.GetOneOption(dhcpv6.OptionFQDN)
if fqdn != nil {
resp.AddOption(fqdn)
}
resp.AddOption(&dhcpv6.OptStatusCode{
StatusCode: iana.StatusSuccess,
StatusMessage: "success",
})
return true
}
// 1.
// fe80::* (client) --(Solicit + ClientID+IANA())-> ff02::1:2
// server -(Advertise + ClientID+ServerID+IANA(IAAddress)> fe80::*
// fe80::* --(Request + ClientID+ServerID+IANA(IAAddress))-> ff02::1:2
// server -(Reply + ClientID+ServerID+IANA(IAAddress)+DNS)> fe80::*
//
// 2.
// fe80::* --(Confirm|Renew|Rebind + ClientID+IANA(IAAddress))-> ff02::1:2
// server -(Reply + ClientID+ServerID+IANA(IAAddress)+DNS)> fe80::*
//
// 3.
// fe80::* --(Release + ClientID+ServerID+IANA(IAAddress))-> ff02::1:2
func (s *v6Server) packetHandler(conn net.PacketConn, peer net.Addr, req dhcpv6.DHCPv6) {
msg, err := req.GetInnerMessage()
if err != nil {
log.Error("DHCPv6: %s", err)
return
}
log.Debug("DHCPv6: received: %s", req.Summary())
err = s.checkCID(msg)
if err != nil {
log.Debug("%s", err)
return
}
err = s.checkSID(msg)
if err != nil {
log.Debug("%s", err)
return
}
var resp dhcpv6.DHCPv6
switch msg.Type() {
case dhcpv6.MessageTypeSolicit:
if msg.GetOneOption(dhcpv6.OptionRapidCommit) == nil {
resp, err = dhcpv6.NewAdvertiseFromSolicit(msg)
break
}
fallthrough
case dhcpv6.MessageTypeRequest,
dhcpv6.MessageTypeConfirm,
dhcpv6.MessageTypeRenew,
dhcpv6.MessageTypeRebind,
dhcpv6.MessageTypeRelease,
dhcpv6.MessageTypeInformationRequest:
resp, err = dhcpv6.NewReplyFromMessage(msg)
default:
log.Error("DHCPv6: message type %d not supported", msg.Type())
return
}
if err != nil {
log.Error("DHCPv6: %s", err)
return
}
resp.AddOption(dhcpv6.OptServerID(s.sid))
_ = s.process(msg, req, resp)
log.Debug("DHCPv6: sending: %s", resp.Summary())
_, err = conn.WriteTo(resp.ToBytes(), peer)
if err != nil {
log.Error("DHCPv6: conn.Write to %s failed: %s", peer, err)
return
}
}
// Get IPv6 address list
func getIfaceIPv6(iface net.Interface) []net.IP {
addrs, err := iface.Addrs()
if err != nil {
return nil
}
var res []net.IP
for _, a := range addrs {
ipnet, ok := a.(*net.IPNet)
if !ok {
continue
}
if ipnet.IP.To4() == nil {
res = append(res, ipnet.IP)
}
}
return res
}
// initialize RA module
func (s *v6Server) initRA(iface *net.Interface) error {
// choose the source IP address - should be link-local-unicast
s.ra.ipAddr = s.conf.dnsIPAddrs[0]
for _, ip := range s.conf.dnsIPAddrs {
if ip.IsLinkLocalUnicast() {
s.ra.ipAddr = ip
break
}
}
s.ra.raAllowSlaac = s.conf.RaAllowSlaac
s.ra.raSlaacOnly = s.conf.RaSlaacOnly
s.ra.dnsIPAddr = s.ra.ipAddr
s.ra.prefixIPAddr = s.conf.ipStart
s.ra.ifaceName = s.conf.InterfaceName
s.ra.iface = iface
s.ra.packetSendPeriod = 1 * time.Second
return s.ra.Init()
}
// Start - start server
func (s *v6Server) Start() error {
if !s.conf.Enabled {
return nil
}
iface, err := net.InterfaceByName(s.conf.InterfaceName)
if err != nil {
return wrapErrPrint(err, "Couldn't find interface by name %s", s.conf.InterfaceName)
}
s.conf.dnsIPAddrs = getIfaceIPv6(*iface)
if len(s.conf.dnsIPAddrs) == 0 {
log.Debug("DHCPv6: no IPv6 address for interface %s", iface.Name)
return nil
}
err = s.initRA(iface)
if err != nil {
return err
}
// don't initialize DHCPv6 server if we must force the clients to use SLAAC
if s.conf.RaSlaacOnly {
log.Debug("DHCPv6: not starting DHCPv6 server due to ra_slaac_only=true")
return nil
}
log.Debug("DHCPv6: starting...")
if len(iface.HardwareAddr) != 6 {
return fmt.Errorf("DHCPv6: invalid MAC %s", iface.HardwareAddr)
}
s.sid = dhcpv6.Duid{
Type: dhcpv6.DUID_LLT,
HwType: iana.HWTypeEthernet,
LinkLayerAddr: iface.HardwareAddr,
Time: dhcpv6.GetTime(),
}
laddr := &net.UDPAddr{
IP: net.ParseIP("::"),
Port: dhcpv6.DefaultServerPort,
}
s.srv, err = server6.NewServer(iface.Name, laddr, s.packetHandler, server6.WithDebugLogger())
if err != nil {
return err
}
go func() {
err = s.srv.Serve()
log.Debug("DHCPv6: srv.Serve: %s", err)
}()
return nil
}
// Stop - stop server
func (s *v6Server) Stop() {
s.ra.Close()
// DHCPv6 server may not be initialized if ra_slaac_only=true
if s.srv == nil {
return
}
log.Debug("DHCPv6: stopping")
err := s.srv.Close()
if err != nil {
log.Error("DHCPv6: srv.Close: %s", err)
}
// now server.Serve() will return
s.srv = nil
}
// Create DHCPv6 server
func v6Create(conf V6ServerConf) (DHCPServer, error) {
s := &v6Server{}
s.conf = conf
if !conf.Enabled {
return s, nil
}
s.conf.ipStart = net.ParseIP(conf.RangeStart)
if s.conf.ipStart == nil || s.conf.ipStart.To16() == nil {
return s, fmt.Errorf("DHCPv6: invalid range-start IP: %s", conf.RangeStart)
}
if conf.LeaseDuration == 0 {
s.conf.leaseTime = time.Hour * 24
s.conf.LeaseDuration = uint32(s.conf.leaseTime.Seconds())
} else {
s.conf.leaseTime = time.Second * time.Duration(conf.LeaseDuration)
}
return s, nil
}

225
internal/dhcpd/v6_test.go Normal file
View File

@@ -0,0 +1,225 @@
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
package dhcpd
import (
"net"
"testing"
"github.com/insomniacslk/dhcp/dhcpv6"
"github.com/insomniacslk/dhcp/iana"
"github.com/stretchr/testify/assert"
)
func notify6(flags uint32) {
}
func TestV6StaticLeaseAddRemove(t *testing.T) {
conf := V6ServerConf{
Enabled: true,
RangeStart: "2001::1",
notify: notify6,
}
s, err := v6Create(conf)
assert.True(t, err == nil)
ls := s.GetLeases(LeasesStatic)
assert.Equal(t, 0, len(ls))
// add static lease
l := Lease{}
l.IP = net.ParseIP("2001::1")
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
assert.True(t, s.AddStaticLease(l) == nil)
// try to add static lease - fail
assert.True(t, s.AddStaticLease(l) != nil)
// check
ls = s.GetLeases(LeasesStatic)
assert.Equal(t, 1, len(ls))
assert.Equal(t, "2001::1", ls[0].IP.String())
assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
assert.True(t, ls[0].Expiry.Unix() == leaseExpireStatic)
// try to remove static lease - fail
l.IP = net.ParseIP("2001::2")
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
assert.True(t, s.RemoveStaticLease(l) != nil)
// remove static lease
l.IP = net.ParseIP("2001::1")
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
assert.True(t, s.RemoveStaticLease(l) == nil)
// check
ls = s.GetLeases(LeasesStatic)
assert.Equal(t, 0, len(ls))
}
func TestV6StaticLeaseAddReplaceDynamic(t *testing.T) {
conf := V6ServerConf{
Enabled: true,
RangeStart: "2001::1",
notify: notify6,
}
sIface, err := v6Create(conf)
s := sIface.(*v6Server)
assert.True(t, err == nil)
// add dynamic lease
ld := Lease{}
ld.IP = net.ParseIP("2001::1")
ld.HWAddr, _ = net.ParseMAC("11:aa:aa:aa:aa:aa")
s.addLease(&ld)
// add dynamic lease
{
ld := Lease{}
ld.IP = net.ParseIP("2001::2")
ld.HWAddr, _ = net.ParseMAC("22:aa:aa:aa:aa:aa")
s.addLease(&ld)
}
// add static lease with the same IP
l := Lease{}
l.IP = net.ParseIP("2001::1")
l.HWAddr, _ = net.ParseMAC("33:aa:aa:aa:aa:aa")
assert.True(t, s.AddStaticLease(l) == nil)
// add static lease with the same MAC
l = Lease{}
l.IP = net.ParseIP("2001::3")
l.HWAddr, _ = net.ParseMAC("22:aa:aa:aa:aa:aa")
assert.True(t, s.AddStaticLease(l) == nil)
// check
ls := s.GetLeases(LeasesStatic)
assert.Equal(t, 2, len(ls))
assert.Equal(t, "2001::1", ls[0].IP.String())
assert.Equal(t, "33:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
assert.True(t, ls[0].Expiry.Unix() == leaseExpireStatic)
assert.Equal(t, "2001::3", ls[1].IP.String())
assert.Equal(t, "22:aa:aa:aa:aa:aa", ls[1].HWAddr.String())
assert.True(t, ls[1].Expiry.Unix() == leaseExpireStatic)
}
func TestV6GetLease(t *testing.T) {
conf := V6ServerConf{
Enabled: true,
RangeStart: "2001::1",
notify: notify6,
}
sIface, err := v6Create(conf)
s := sIface.(*v6Server)
assert.True(t, err == nil)
s.conf.dnsIPAddrs = []net.IP{net.ParseIP("2000::1")}
s.sid = dhcpv6.Duid{
Type: dhcpv6.DUID_LLT,
HwType: iana.HWTypeEthernet,
}
s.sid.LinkLayerAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
l := Lease{}
l.IP = net.ParseIP("2001::1")
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
assert.True(t, s.AddStaticLease(l) == nil)
// "Solicit"
mac, _ := net.ParseMAC("aa:aa:aa:aa:aa:aa")
req, _ := dhcpv6.NewSolicit(mac)
msg, _ := req.GetInnerMessage()
resp, _ := dhcpv6.NewAdvertiseFromSolicit(msg)
assert.True(t, s.process(msg, req, resp))
resp.AddOption(dhcpv6.OptServerID(s.sid))
// check "Advertise"
assert.Equal(t, dhcpv6.MessageTypeAdvertise, resp.Type())
oia := resp.Options.OneIANA()
oiaAddr := oia.Options.OneAddress()
assert.Equal(t, "2001::1", oiaAddr.IPv6Addr.String())
assert.Equal(t, s.conf.leaseTime.Seconds(), oiaAddr.ValidLifetime.Seconds())
// "Request"
req, _ = dhcpv6.NewRequestFromAdvertise(resp)
msg, _ = req.GetInnerMessage()
resp, _ = dhcpv6.NewReplyFromMessage(msg)
assert.True(t, s.process(msg, req, resp))
// check "Reply"
assert.Equal(t, dhcpv6.MessageTypeReply, resp.Type())
oia = resp.Options.OneIANA()
oiaAddr = oia.Options.OneAddress()
assert.Equal(t, "2001::1", oiaAddr.IPv6Addr.String())
assert.Equal(t, s.conf.leaseTime.Seconds(), oiaAddr.ValidLifetime.Seconds())
dnsAddrs := resp.Options.DNS()
assert.Equal(t, 1, len(dnsAddrs))
assert.Equal(t, "2000::1", dnsAddrs[0].String())
// check lease
ls := s.GetLeases(LeasesStatic)
assert.Equal(t, 1, len(ls))
assert.Equal(t, "2001::1", ls[0].IP.String())
assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
}
func TestV6GetDynamicLease(t *testing.T) {
conf := V6ServerConf{
Enabled: true,
RangeStart: "2001::2",
notify: notify6,
}
sIface, err := v6Create(conf)
s := sIface.(*v6Server)
assert.True(t, err == nil)
s.conf.dnsIPAddrs = []net.IP{net.ParseIP("2000::1")}
s.sid = dhcpv6.Duid{
Type: dhcpv6.DUID_LLT,
HwType: iana.HWTypeEthernet,
}
s.sid.LinkLayerAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
// "Solicit"
mac, _ := net.ParseMAC("aa:aa:aa:aa:aa:aa")
req, _ := dhcpv6.NewSolicit(mac)
msg, _ := req.GetInnerMessage()
resp, _ := dhcpv6.NewAdvertiseFromSolicit(msg)
assert.True(t, s.process(msg, req, resp))
resp.AddOption(dhcpv6.OptServerID(s.sid))
// check "Advertise"
assert.Equal(t, dhcpv6.MessageTypeAdvertise, resp.Type())
oia := resp.Options.OneIANA()
oiaAddr := oia.Options.OneAddress()
assert.Equal(t, "2001::2", oiaAddr.IPv6Addr.String())
// "Request"
req, _ = dhcpv6.NewRequestFromAdvertise(resp)
msg, _ = req.GetInnerMessage()
resp, _ = dhcpv6.NewReplyFromMessage(msg)
assert.True(t, s.process(msg, req, resp))
// check "Reply"
assert.Equal(t, dhcpv6.MessageTypeReply, resp.Type())
oia = resp.Options.OneIANA()
oiaAddr = oia.Options.OneAddress()
assert.Equal(t, "2001::2", oiaAddr.IPv6Addr.String())
dnsAddrs := resp.Options.DNS()
assert.Equal(t, 1, len(dnsAddrs))
assert.Equal(t, "2000::1", dnsAddrs[0].String())
// check lease
ls := s.GetLeases(LeasesDynamic)
assert.Equal(t, 1, len(ls))
assert.Equal(t, "2001::2", ls[0].IP.String())
assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
assert.True(t, !ip6InRange(net.ParseIP("2001::2"), net.ParseIP("2001::1")))
assert.True(t, !ip6InRange(net.ParseIP("2001::2"), net.ParseIP("2002::2")))
assert.True(t, ip6InRange(net.ParseIP("2001::2"), net.ParseIP("2001::2")))
assert.True(t, ip6InRange(net.ParseIP("2001::2"), net.ParseIP("2001::3")))
}