Pull request: * all: move internal Go packages to internal/
Merge in DNS/adguard-home from 2234-move-to-internal to master Squashed commit of the following: commit d26a288cabeac86f9483fab307677b1027c78524 Author: Eugene Burkov <e.burkov@adguard.com> Date: Fri Oct 30 12:44:18 2020 +0300 * all: move internal Go packages to internal/ Closes #2234.
This commit is contained in:
64
internal/dhcpd/README.md
Normal file
64
internal/dhcpd/README.md
Normal file
@@ -0,0 +1,64 @@
|
||||
# DHCP server
|
||||
|
||||
Contents:
|
||||
* [Test setup with Virtual Box](#vbox)
|
||||
|
||||
<a id="vbox"></a>
|
||||
## Test setup with Virtual Box
|
||||
|
||||
To set up a test environment for DHCP server you need:
|
||||
|
||||
* Linux host machine
|
||||
* Virtual Box
|
||||
* Virtual machine (guest OS doesn't matter)
|
||||
|
||||
### Configure client
|
||||
|
||||
1. Install Virtual Box and run the following command to create a Host-Only network:
|
||||
|
||||
$ VBoxManage hostonlyif create
|
||||
|
||||
You can check its status by `ip a` command.
|
||||
|
||||
You can also set up Host-Only network using Virtual Box menu:
|
||||
|
||||
File -> Host Network Manager...
|
||||
|
||||
2. Create your virtual machine and set up its network:
|
||||
|
||||
VM Settings -> Network -> Host-only Adapter
|
||||
|
||||
3. Start your VM, install an OS. Configure your network interface to use DHCP and the OS should ask for a IP address from our DHCP server.
|
||||
|
||||
4. To see the current IP address on client OS you can use `ip a` command on Linux or `ipconfig` on Windows.
|
||||
|
||||
5. To force the client OS to request an IP from DHCP server again, you can use `dhclient` on Linux or `ipconfig /release` on Windows.
|
||||
|
||||
### Configure server
|
||||
|
||||
1. Edit server configuration file 'AdGuardHome.yaml', for example:
|
||||
|
||||
dhcp:
|
||||
enabled: true
|
||||
interface_name: vboxnet0
|
||||
dhcpv4:
|
||||
gateway_ip: 192.168.56.1
|
||||
subnet_mask: 255.255.255.0
|
||||
range_start: 192.168.56.2
|
||||
range_end: 192.168.56.2
|
||||
lease_duration: 86400
|
||||
icmp_timeout_msec: 1000
|
||||
options: []
|
||||
dhcpv6:
|
||||
range_start: 2001::1
|
||||
lease_duration: 86400
|
||||
ra_slaac_only: false
|
||||
ra_allow_slaac: false
|
||||
|
||||
2. Start the server
|
||||
|
||||
./AdGuardHome
|
||||
|
||||
There should be a message in log which shows that DHCP server is ready:
|
||||
|
||||
[info] DHCP: listening on 0.0.0.0:67
|
||||
215
internal/dhcpd/check_other_dhcp.go
Normal file
215
internal/dhcpd/check_other_dhcp.go
Normal file
@@ -0,0 +1,215 @@
|
||||
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
|
||||
|
||||
package dhcpd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd/nclient4"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/insomniacslk/dhcp/dhcpv4"
|
||||
"github.com/insomniacslk/dhcp/dhcpv6"
|
||||
"github.com/insomniacslk/dhcp/dhcpv6/nclient6"
|
||||
"github.com/insomniacslk/dhcp/iana"
|
||||
)
|
||||
|
||||
// CheckIfOtherDHCPServersPresentV4 sends a DHCP request to the specified network interface,
|
||||
// and waits for a response for a period defined by defaultDiscoverTime
|
||||
func CheckIfOtherDHCPServersPresentV4(ifaceName string) (bool, error) {
|
||||
iface, err := net.InterfaceByName(ifaceName)
|
||||
if err != nil {
|
||||
return false, wrapErrPrint(err, "Couldn't find interface by name %s", ifaceName)
|
||||
}
|
||||
|
||||
// get ipv4 address of an interface
|
||||
ifaceIPNet := getIfaceIPv4(*iface)
|
||||
if len(ifaceIPNet) == 0 {
|
||||
return false, fmt.Errorf("couldn't find IPv4 address of interface %s %+v", ifaceName, iface)
|
||||
}
|
||||
|
||||
if runtime.GOOS == "darwin" {
|
||||
return false, fmt.Errorf("can't find DHCP server: not supported on macOS")
|
||||
}
|
||||
|
||||
srcIP := ifaceIPNet[0]
|
||||
src := net.JoinHostPort(srcIP.String(), "68")
|
||||
dst := "255.255.255.255:67"
|
||||
|
||||
hostname, _ := os.Hostname()
|
||||
|
||||
req, err := dhcpv4.NewDiscovery(iface.HardwareAddr)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("dhcpv4.NewDiscovery: %s", err)
|
||||
}
|
||||
req.Options.Update(dhcpv4.OptClientIdentifier(iface.HardwareAddr))
|
||||
req.Options.Update(dhcpv4.OptHostName(hostname))
|
||||
|
||||
// resolve 0.0.0.0:68
|
||||
udpAddr, err := net.ResolveUDPAddr("udp4", src)
|
||||
if err != nil {
|
||||
return false, wrapErrPrint(err, "Couldn't resolve UDP address %s", src)
|
||||
}
|
||||
|
||||
if !udpAddr.IP.To4().Equal(srcIP) {
|
||||
return false, wrapErrPrint(err, "Resolved UDP address is not %s", src)
|
||||
}
|
||||
|
||||
// resolve 255.255.255.255:67
|
||||
dstAddr, err := net.ResolveUDPAddr("udp4", dst)
|
||||
if err != nil {
|
||||
return false, wrapErrPrint(err, "Couldn't resolve UDP address %s", dst)
|
||||
}
|
||||
|
||||
// bind to 0.0.0.0:68
|
||||
log.Tracef("Listening to udp4 %+v", udpAddr)
|
||||
c, err := nclient4.NewRawUDPConn(ifaceName, 68)
|
||||
if err != nil {
|
||||
return false, wrapErrPrint(err, "Couldn't listen on :68")
|
||||
}
|
||||
if c != nil {
|
||||
defer c.Close()
|
||||
}
|
||||
|
||||
// send to 255.255.255.255:67
|
||||
_, err = c.WriteTo(req.ToBytes(), dstAddr)
|
||||
if err != nil {
|
||||
return false, wrapErrPrint(err, "Couldn't send a packet to %s", dst)
|
||||
}
|
||||
|
||||
for {
|
||||
// wait for answer
|
||||
log.Tracef("Waiting %v for an answer", defaultDiscoverTime)
|
||||
// TODO: replicate dhclient's behaviour of retrying several times with progressively bigger timeouts
|
||||
b := make([]byte, 1500)
|
||||
_ = c.SetReadDeadline(time.Now().Add(defaultDiscoverTime))
|
||||
n, _, err := c.ReadFrom(b)
|
||||
if isTimeout(err) {
|
||||
// timed out -- no DHCP servers
|
||||
log.Debug("DHCPv4: didn't receive DHCP response")
|
||||
return false, nil
|
||||
}
|
||||
if err != nil {
|
||||
return false, wrapErrPrint(err, "Couldn't receive packet")
|
||||
}
|
||||
|
||||
log.Tracef("Received packet (%v bytes)", n)
|
||||
|
||||
response, err := dhcpv4.FromBytes(b[:n])
|
||||
if err != nil {
|
||||
log.Debug("DHCPv4: dhcpv4.FromBytes: %s", err)
|
||||
continue
|
||||
}
|
||||
|
||||
log.Debug("DHCPv4: received message from server: %s", response.Summary())
|
||||
|
||||
if !(response.OpCode == dhcpv4.OpcodeBootReply &&
|
||||
response.HWType == iana.HWTypeEthernet &&
|
||||
bytes.Equal(response.ClientHWAddr, iface.HardwareAddr) &&
|
||||
bytes.Equal(response.TransactionID[:], req.TransactionID[:]) &&
|
||||
response.Options.Has(dhcpv4.OptionDHCPMessageType)) {
|
||||
log.Debug("DHCPv4: received message from server doesn't match our request")
|
||||
continue
|
||||
}
|
||||
|
||||
log.Tracef("The packet is from an active DHCP server")
|
||||
// that's a DHCP server there
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
// CheckIfOtherDHCPServersPresentV6 sends a DHCP request to the specified network interface,
|
||||
// and waits for a response for a period defined by defaultDiscoverTime
|
||||
func CheckIfOtherDHCPServersPresentV6(ifaceName string) (bool, error) {
|
||||
iface, err := net.InterfaceByName(ifaceName)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("DHCPv6: net.InterfaceByName: %s: %s", ifaceName, err)
|
||||
}
|
||||
|
||||
ifaceIPNet := getIfaceIPv6(*iface)
|
||||
if len(ifaceIPNet) == 0 {
|
||||
return false, fmt.Errorf("DHCPv6: couldn't find IPv6 address of interface %s %+v", ifaceName, iface)
|
||||
}
|
||||
|
||||
srcIP := ifaceIPNet[0]
|
||||
src := net.JoinHostPort(srcIP.String(), "546")
|
||||
dst := "[ff02::1:2]:547"
|
||||
|
||||
req, err := dhcpv6.NewSolicit(iface.HardwareAddr)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("DHCPv6: dhcpv6.NewSolicit: %s", err)
|
||||
}
|
||||
|
||||
udpAddr, err := net.ResolveUDPAddr("udp6", src)
|
||||
if err != nil {
|
||||
return false, wrapErrPrint(err, "DHCPv6: Couldn't resolve UDP address %s", src)
|
||||
}
|
||||
|
||||
if !udpAddr.IP.To16().Equal(srcIP) {
|
||||
return false, wrapErrPrint(err, "DHCPv6: Resolved UDP address is not %s", src)
|
||||
}
|
||||
|
||||
dstAddr, err := net.ResolveUDPAddr("udp6", dst)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("DHCPv6: Couldn't resolve UDP address %s: %s", dst, err)
|
||||
}
|
||||
|
||||
log.Debug("DHCPv6: Listening to udp6 %+v", udpAddr)
|
||||
c, err := nclient6.NewIPv6UDPConn(ifaceName, dhcpv6.DefaultClientPort)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("DHCPv6: Couldn't listen on :546: %s", err)
|
||||
}
|
||||
if c != nil {
|
||||
defer c.Close()
|
||||
}
|
||||
|
||||
_, err = c.WriteTo(req.ToBytes(), dstAddr)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("DHCPv6: Couldn't send a packet to %s: %s", dst, err)
|
||||
}
|
||||
|
||||
for {
|
||||
log.Debug("DHCPv6: Waiting %v for an answer", defaultDiscoverTime)
|
||||
b := make([]byte, 4096)
|
||||
_ = c.SetReadDeadline(time.Now().Add(defaultDiscoverTime))
|
||||
n, _, err := c.ReadFrom(b)
|
||||
if isTimeout(err) {
|
||||
log.Debug("DHCPv6: didn't receive DHCP response")
|
||||
return false, nil
|
||||
}
|
||||
if err != nil {
|
||||
return false, wrapErrPrint(err, "Couldn't receive packet")
|
||||
}
|
||||
|
||||
log.Debug("DHCPv6: Received packet (%v bytes)", n)
|
||||
|
||||
resp, err := dhcpv6.FromBytes(b[:n])
|
||||
if err != nil {
|
||||
log.Debug("DHCPv6: dhcpv6.FromBytes: %s", err)
|
||||
continue
|
||||
}
|
||||
|
||||
log.Debug("DHCPv6: received message from server: %s", resp.Summary())
|
||||
|
||||
cid := req.Options.ClientID()
|
||||
msg, err := resp.GetInnerMessage()
|
||||
if err != nil {
|
||||
log.Debug("DHCPv6: resp.GetInnerMessage: %s", err)
|
||||
continue
|
||||
}
|
||||
rcid := msg.Options.ClientID()
|
||||
if resp.Type() == dhcpv6.MessageTypeAdvertise &&
|
||||
msg.TransactionID == req.TransactionID &&
|
||||
rcid != nil &&
|
||||
cid.Equal(*rcid) {
|
||||
log.Debug("DHCPv6: The packet is from an active DHCP server")
|
||||
return true, nil
|
||||
}
|
||||
|
||||
log.Debug("DHCPv6: received message from server doesn't match our request")
|
||||
}
|
||||
}
|
||||
11
internal/dhcpd/check_other_dhcp_windows.go
Normal file
11
internal/dhcpd/check_other_dhcp_windows.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package dhcpd
|
||||
|
||||
import "fmt"
|
||||
|
||||
func CheckIfOtherDHCPServersPresentV4(ifaceName string) (bool, error) {
|
||||
return false, fmt.Errorf("not supported")
|
||||
}
|
||||
|
||||
func CheckIfOtherDHCPServersPresentV6(ifaceName string) (bool, error) {
|
||||
return false, fmt.Errorf("not supported")
|
||||
}
|
||||
173
internal/dhcpd/db.go
Normal file
173
internal/dhcpd/db.go
Normal file
@@ -0,0 +1,173 @@
|
||||
// On-disk database for lease table
|
||||
|
||||
package dhcpd
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/file"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
const dbFilename = "leases.db"
|
||||
|
||||
type leaseJSON struct {
|
||||
HWAddr []byte `json:"mac"`
|
||||
IP []byte `json:"ip"`
|
||||
Hostname string `json:"host"`
|
||||
Expiry int64 `json:"exp"`
|
||||
}
|
||||
|
||||
func normalizeIP(ip net.IP) net.IP {
|
||||
ip4 := ip.To4()
|
||||
if ip4 != nil {
|
||||
return ip4
|
||||
}
|
||||
return ip
|
||||
}
|
||||
|
||||
// Load lease table from DB
|
||||
func (s *Server) dbLoad() {
|
||||
dynLeases := []*Lease{}
|
||||
staticLeases := []*Lease{}
|
||||
v6StaticLeases := []*Lease{}
|
||||
v6DynLeases := []*Lease{}
|
||||
|
||||
data, err := ioutil.ReadFile(s.conf.DBFilePath)
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
log.Error("DHCP: can't read file %s: %v", s.conf.DBFilePath, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
obj := []leaseJSON{}
|
||||
err = json.Unmarshal(data, &obj)
|
||||
if err != nil {
|
||||
log.Error("DHCP: invalid DB: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
numLeases := len(obj)
|
||||
for i := range obj {
|
||||
obj[i].IP = normalizeIP(obj[i].IP)
|
||||
|
||||
if !(len(obj[i].IP) == 4 || len(obj[i].IP) == 16) {
|
||||
log.Info("DHCP: invalid IP: %s", obj[i].IP)
|
||||
continue
|
||||
}
|
||||
|
||||
lease := Lease{
|
||||
HWAddr: obj[i].HWAddr,
|
||||
IP: obj[i].IP,
|
||||
Hostname: obj[i].Hostname,
|
||||
Expiry: time.Unix(obj[i].Expiry, 0),
|
||||
}
|
||||
|
||||
if len(obj[i].IP) == 16 {
|
||||
if obj[i].Expiry == leaseExpireStatic {
|
||||
v6StaticLeases = append(v6StaticLeases, &lease)
|
||||
} else {
|
||||
v6DynLeases = append(v6DynLeases, &lease)
|
||||
}
|
||||
|
||||
} else {
|
||||
if obj[i].Expiry == leaseExpireStatic {
|
||||
staticLeases = append(staticLeases, &lease)
|
||||
} else {
|
||||
dynLeases = append(dynLeases, &lease)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
leases4 := normalizeLeases(staticLeases, dynLeases)
|
||||
s.srv4.ResetLeases(leases4)
|
||||
|
||||
leases6 := normalizeLeases(v6StaticLeases, v6DynLeases)
|
||||
if s.srv6 != nil {
|
||||
s.srv6.ResetLeases(leases6)
|
||||
}
|
||||
|
||||
log.Info("DHCP: loaded leases v4:%d v6:%d total-read:%d from DB",
|
||||
len(leases4), len(leases6), numLeases)
|
||||
}
|
||||
|
||||
// Skip duplicate leases
|
||||
// Static leases have a priority over dynamic leases
|
||||
func normalizeLeases(staticLeases, dynLeases []*Lease) []*Lease {
|
||||
leases := []*Lease{}
|
||||
index := map[string]int{}
|
||||
|
||||
for i, lease := range staticLeases {
|
||||
_, ok := index[lease.HWAddr.String()]
|
||||
if ok {
|
||||
continue // skip the lease with the same HW address
|
||||
}
|
||||
index[lease.HWAddr.String()] = i
|
||||
leases = append(leases, lease)
|
||||
}
|
||||
|
||||
for i, lease := range dynLeases {
|
||||
_, ok := index[lease.HWAddr.String()]
|
||||
if ok {
|
||||
continue // skip the lease with the same HW address
|
||||
}
|
||||
index[lease.HWAddr.String()] = i
|
||||
leases = append(leases, lease)
|
||||
}
|
||||
|
||||
return leases
|
||||
}
|
||||
|
||||
// Store lease table in DB
|
||||
func (s *Server) dbStore() {
|
||||
var leases []leaseJSON
|
||||
|
||||
leases4 := s.srv4.GetLeasesRef()
|
||||
for _, l := range leases4 {
|
||||
if l.Expiry.Unix() == 0 {
|
||||
continue
|
||||
}
|
||||
lease := leaseJSON{
|
||||
HWAddr: l.HWAddr,
|
||||
IP: l.IP,
|
||||
Hostname: l.Hostname,
|
||||
Expiry: l.Expiry.Unix(),
|
||||
}
|
||||
leases = append(leases, lease)
|
||||
}
|
||||
|
||||
if s.srv6 != nil {
|
||||
leases6 := s.srv6.GetLeasesRef()
|
||||
for _, l := range leases6 {
|
||||
if l.Expiry.Unix() == 0 {
|
||||
continue
|
||||
}
|
||||
lease := leaseJSON{
|
||||
HWAddr: l.HWAddr,
|
||||
IP: l.IP,
|
||||
Hostname: l.Hostname,
|
||||
Expiry: l.Expiry.Unix(),
|
||||
}
|
||||
leases = append(leases, lease)
|
||||
}
|
||||
}
|
||||
|
||||
data, err := json.Marshal(leases)
|
||||
if err != nil {
|
||||
log.Error("json.Marshal: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = file.SafeWrite(s.conf.DBFilePath, data)
|
||||
if err != nil {
|
||||
log.Error("DHCP: can't store lease table on disk: %v filename: %s",
|
||||
err, s.conf.DBFilePath)
|
||||
return
|
||||
}
|
||||
log.Info("DHCP: stored %d leases in DB", len(leases))
|
||||
}
|
||||
511
internal/dhcpd/dhcp_http.go
Normal file
511
internal/dhcpd/dhcp_http.go
Normal file
@@ -0,0 +1,511 @@
|
||||
package dhcpd
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/util"
|
||||
|
||||
"github.com/AdguardTeam/golibs/jsonutil"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
func httpError(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) {
|
||||
text := fmt.Sprintf(format, args...)
|
||||
log.Info("DHCP: %s %s: %s", r.Method, r.URL, text)
|
||||
http.Error(w, text, code)
|
||||
}
|
||||
|
||||
// []Lease -> JSON
|
||||
func convertLeases(inputLeases []Lease, includeExpires bool) []map[string]string {
|
||||
leases := []map[string]string{}
|
||||
for _, l := range inputLeases {
|
||||
lease := map[string]string{
|
||||
"mac": l.HWAddr.String(),
|
||||
"ip": l.IP.String(),
|
||||
"hostname": l.Hostname,
|
||||
}
|
||||
|
||||
if includeExpires {
|
||||
lease["expires"] = l.Expiry.Format(time.RFC3339)
|
||||
}
|
||||
|
||||
leases = append(leases, lease)
|
||||
}
|
||||
return leases
|
||||
}
|
||||
|
||||
type v4ServerConfJSON struct {
|
||||
GatewayIP string `json:"gateway_ip"`
|
||||
SubnetMask string `json:"subnet_mask"`
|
||||
RangeStart string `json:"range_start"`
|
||||
RangeEnd string `json:"range_end"`
|
||||
LeaseDuration uint32 `json:"lease_duration"`
|
||||
}
|
||||
|
||||
func v4ServerConfToJSON(c V4ServerConf) v4ServerConfJSON {
|
||||
return v4ServerConfJSON{
|
||||
GatewayIP: c.GatewayIP,
|
||||
SubnetMask: c.SubnetMask,
|
||||
RangeStart: c.RangeStart,
|
||||
RangeEnd: c.RangeEnd,
|
||||
LeaseDuration: c.LeaseDuration,
|
||||
}
|
||||
}
|
||||
|
||||
func v4JSONToServerConf(j v4ServerConfJSON) V4ServerConf {
|
||||
return V4ServerConf{
|
||||
GatewayIP: j.GatewayIP,
|
||||
SubnetMask: j.SubnetMask,
|
||||
RangeStart: j.RangeStart,
|
||||
RangeEnd: j.RangeEnd,
|
||||
LeaseDuration: j.LeaseDuration,
|
||||
}
|
||||
}
|
||||
|
||||
type v6ServerConfJSON struct {
|
||||
RangeStart string `json:"range_start"`
|
||||
LeaseDuration uint32 `json:"lease_duration"`
|
||||
}
|
||||
|
||||
func v6ServerConfToJSON(c V6ServerConf) v6ServerConfJSON {
|
||||
return v6ServerConfJSON{
|
||||
RangeStart: c.RangeStart,
|
||||
LeaseDuration: c.LeaseDuration,
|
||||
}
|
||||
}
|
||||
|
||||
func v6JSONToServerConf(j v6ServerConfJSON) V6ServerConf {
|
||||
return V6ServerConf{
|
||||
RangeStart: j.RangeStart,
|
||||
LeaseDuration: j.LeaseDuration,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleDHCPStatus(w http.ResponseWriter, r *http.Request) {
|
||||
leases := convertLeases(s.Leases(LeasesDynamic), true)
|
||||
staticLeases := convertLeases(s.Leases(LeasesStatic), false)
|
||||
|
||||
v4conf := V4ServerConf{}
|
||||
s.srv4.WriteDiskConfig4(&v4conf)
|
||||
|
||||
v6conf := V6ServerConf{}
|
||||
s.srv6.WriteDiskConfig6(&v6conf)
|
||||
|
||||
status := map[string]interface{}{
|
||||
"enabled": s.conf.Enabled,
|
||||
"interface_name": s.conf.InterfaceName,
|
||||
"v4": v4ServerConfToJSON(v4conf),
|
||||
"v6": v6ServerConfToJSON(v6conf),
|
||||
"leases": leases,
|
||||
"static_leases": staticLeases,
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err := json.NewEncoder(w).Encode(status)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusInternalServerError, "Unable to marshal DHCP status json: %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
type staticLeaseJSON struct {
|
||||
HWAddr string `json:"mac"`
|
||||
IP string `json:"ip"`
|
||||
Hostname string `json:"hostname"`
|
||||
}
|
||||
|
||||
type dhcpServerConfigJSON struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
InterfaceName string `json:"interface_name"`
|
||||
V4 v4ServerConfJSON `json:"v4"`
|
||||
V6 v6ServerConfJSON `json:"v6"`
|
||||
}
|
||||
|
||||
func (s *Server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
newconfig := dhcpServerConfigJSON{}
|
||||
newconfig.Enabled = s.conf.Enabled
|
||||
newconfig.InterfaceName = s.conf.InterfaceName
|
||||
|
||||
js, err := jsonutil.DecodeObject(&newconfig, r.Body)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "Failed to parse new DHCP config json: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
var s4 DHCPServer
|
||||
var s6 DHCPServer
|
||||
v4Enabled := false
|
||||
v6Enabled := false
|
||||
|
||||
if js.Exists("v4") {
|
||||
v4conf := v4JSONToServerConf(newconfig.V4)
|
||||
v4conf.Enabled = newconfig.Enabled
|
||||
if len(v4conf.RangeStart) == 0 {
|
||||
v4conf.Enabled = false
|
||||
}
|
||||
v4Enabled = v4conf.Enabled
|
||||
v4conf.InterfaceName = newconfig.InterfaceName
|
||||
|
||||
c4 := V4ServerConf{}
|
||||
s.srv4.WriteDiskConfig4(&c4)
|
||||
v4conf.notify = c4.notify
|
||||
v4conf.ICMPTimeout = c4.ICMPTimeout
|
||||
|
||||
s4, err = v4Create(v4conf)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "Invalid DHCPv4 configuration: %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if js.Exists("v6") {
|
||||
v6conf := v6JSONToServerConf(newconfig.V6)
|
||||
v6conf.Enabled = newconfig.Enabled
|
||||
if len(v6conf.RangeStart) == 0 {
|
||||
v6conf.Enabled = false
|
||||
}
|
||||
v6Enabled = v6conf.Enabled
|
||||
v6conf.InterfaceName = newconfig.InterfaceName
|
||||
v6conf.notify = s.onNotify
|
||||
s6, err = v6Create(v6conf)
|
||||
if s6 == nil {
|
||||
httpError(r, w, http.StatusBadRequest, "Invalid DHCPv6 configuration: %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if newconfig.Enabled && !v4Enabled && !v6Enabled {
|
||||
httpError(r, w, http.StatusBadRequest, "DHCPv4 or DHCPv6 configuration must be complete")
|
||||
return
|
||||
}
|
||||
|
||||
s.Stop()
|
||||
|
||||
if js.Exists("enabled") {
|
||||
s.conf.Enabled = newconfig.Enabled
|
||||
}
|
||||
|
||||
if js.Exists("interface_name") {
|
||||
s.conf.InterfaceName = newconfig.InterfaceName
|
||||
}
|
||||
|
||||
if s4 != nil {
|
||||
s.srv4 = s4
|
||||
}
|
||||
if s6 != nil {
|
||||
s.srv6 = s6
|
||||
}
|
||||
s.conf.ConfigModified()
|
||||
s.dbLoad()
|
||||
|
||||
if s.conf.Enabled {
|
||||
staticIP, err := HasStaticIP(newconfig.InterfaceName)
|
||||
if !staticIP && err == nil {
|
||||
err = SetStaticIP(newconfig.InterfaceName)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusInternalServerError, "Failed to configure static IP: %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
err = s.Start()
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "Failed to start DHCP server: %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type netInterfaceJSON struct {
|
||||
Name string `json:"name"`
|
||||
GatewayIP string `json:"gateway_ip"`
|
||||
HardwareAddr string `json:"hardware_address"`
|
||||
Addrs4 []string `json:"ipv4_addresses"`
|
||||
Addrs6 []string `json:"ipv6_addresses"`
|
||||
Flags string `json:"flags"`
|
||||
}
|
||||
|
||||
func (s *Server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
|
||||
response := map[string]interface{}{}
|
||||
|
||||
ifaces, err := util.GetValidNetInterfaces()
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusInternalServerError, "Couldn't get interfaces: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, iface := range ifaces {
|
||||
if iface.Flags&net.FlagLoopback != 0 {
|
||||
// it's a loopback, skip it
|
||||
continue
|
||||
}
|
||||
if iface.Flags&net.FlagBroadcast == 0 {
|
||||
// this interface doesn't support broadcast, skip it
|
||||
continue
|
||||
}
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusInternalServerError, "Failed to get addresses for interface %s: %s", iface.Name, err)
|
||||
return
|
||||
}
|
||||
|
||||
jsonIface := netInterfaceJSON{
|
||||
Name: iface.Name,
|
||||
HardwareAddr: iface.HardwareAddr.String(),
|
||||
}
|
||||
|
||||
if iface.Flags != 0 {
|
||||
jsonIface.Flags = iface.Flags.String()
|
||||
}
|
||||
// we don't want link-local addresses in json, so skip them
|
||||
for _, addr := range addrs {
|
||||
ipnet, ok := addr.(*net.IPNet)
|
||||
if !ok {
|
||||
// not an IPNet, should not happen
|
||||
httpError(r, w, http.StatusInternalServerError, "SHOULD NOT HAPPEN: got iface.Addrs() element %s that is not net.IPNet, it is %T", addr, addr)
|
||||
return
|
||||
}
|
||||
// ignore link-local
|
||||
if ipnet.IP.IsLinkLocalUnicast() {
|
||||
continue
|
||||
}
|
||||
if ipnet.IP.To4() != nil {
|
||||
jsonIface.Addrs4 = append(jsonIface.Addrs4, ipnet.IP.String())
|
||||
} else {
|
||||
jsonIface.Addrs6 = append(jsonIface.Addrs6, ipnet.IP.String())
|
||||
}
|
||||
}
|
||||
if len(jsonIface.Addrs4)+len(jsonIface.Addrs6) != 0 {
|
||||
jsonIface.GatewayIP = getGatewayIP(iface.Name)
|
||||
response[iface.Name] = jsonIface
|
||||
}
|
||||
}
|
||||
|
||||
err = json.NewEncoder(w).Encode(response)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusInternalServerError, "Failed to marshal json with available interfaces: %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Perform the following tasks:
|
||||
// . Search for another DHCP server running
|
||||
// . Check if a static IP is configured for the network interface
|
||||
// Respond with results
|
||||
func (s *Server) handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
errorText := fmt.Sprintf("failed to read request body: %s", err)
|
||||
log.Error(errorText)
|
||||
http.Error(w, errorText, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
interfaceName := strings.TrimSpace(string(body))
|
||||
if interfaceName == "" {
|
||||
errorText := fmt.Sprintf("empty interface name specified")
|
||||
log.Error(errorText)
|
||||
http.Error(w, errorText, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
found4, err4 := CheckIfOtherDHCPServersPresentV4(interfaceName)
|
||||
|
||||
staticIP := map[string]interface{}{}
|
||||
isStaticIP, err := HasStaticIP(interfaceName)
|
||||
staticIPStatus := "yes"
|
||||
if err != nil {
|
||||
staticIPStatus = "error"
|
||||
staticIP["error"] = err.Error()
|
||||
} else if !isStaticIP {
|
||||
staticIPStatus = "no"
|
||||
staticIP["ip"] = util.GetSubnet(interfaceName)
|
||||
}
|
||||
staticIP["static"] = staticIPStatus
|
||||
|
||||
v4 := map[string]interface{}{}
|
||||
othSrv := map[string]interface{}{}
|
||||
foundVal := "no"
|
||||
if found4 {
|
||||
foundVal = "yes"
|
||||
} else if err != nil {
|
||||
foundVal = "error"
|
||||
othSrv["error"] = err4.Error()
|
||||
}
|
||||
othSrv["found"] = foundVal
|
||||
v4["other_server"] = othSrv
|
||||
v4["static_ip"] = staticIP
|
||||
|
||||
found6, err6 := CheckIfOtherDHCPServersPresentV6(interfaceName)
|
||||
|
||||
v6 := map[string]interface{}{}
|
||||
othSrv = map[string]interface{}{}
|
||||
foundVal = "no"
|
||||
if found6 {
|
||||
foundVal = "yes"
|
||||
} else if err6 != nil {
|
||||
foundVal = "error"
|
||||
othSrv["error"] = err6.Error()
|
||||
}
|
||||
othSrv["found"] = foundVal
|
||||
v6["other_server"] = othSrv
|
||||
|
||||
result := map[string]interface{}{}
|
||||
result["v4"] = v4
|
||||
result["v6"] = v6
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w).Encode(result)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusInternalServerError, "Failed to marshal DHCP found json: %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
lj := staticLeaseJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&lj)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
ip := net.ParseIP(lj.IP)
|
||||
if ip != nil && ip.To4() == nil {
|
||||
mac, err := net.ParseMAC(lj.HWAddr)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "invalid MAC")
|
||||
return
|
||||
}
|
||||
|
||||
lease := Lease{
|
||||
IP: ip,
|
||||
HWAddr: mac,
|
||||
}
|
||||
|
||||
err = s.srv6.AddStaticLease(lease)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "%s", err)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
ip, _ = parseIPv4(lj.IP)
|
||||
if ip == nil {
|
||||
httpError(r, w, http.StatusBadRequest, "invalid IP")
|
||||
return
|
||||
}
|
||||
|
||||
mac, err := net.ParseMAC(lj.HWAddr)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "invalid MAC")
|
||||
return
|
||||
}
|
||||
|
||||
lease := Lease{
|
||||
IP: ip,
|
||||
HWAddr: mac,
|
||||
Hostname: lj.Hostname,
|
||||
}
|
||||
err = s.srv4.AddStaticLease(lease)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "%s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleDHCPRemoveStaticLease(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
lj := staticLeaseJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&lj)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
ip := net.ParseIP(lj.IP)
|
||||
if ip != nil && ip.To4() == nil {
|
||||
mac, err := net.ParseMAC(lj.HWAddr)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "invalid MAC")
|
||||
return
|
||||
}
|
||||
|
||||
lease := Lease{
|
||||
IP: ip,
|
||||
HWAddr: mac,
|
||||
}
|
||||
|
||||
err = s.srv6.RemoveStaticLease(lease)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "%s", err)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
ip, _ = parseIPv4(lj.IP)
|
||||
if ip == nil {
|
||||
httpError(r, w, http.StatusBadRequest, "invalid IP")
|
||||
return
|
||||
}
|
||||
|
||||
mac, _ := net.ParseMAC(lj.HWAddr)
|
||||
|
||||
lease := Lease{
|
||||
IP: ip,
|
||||
HWAddr: mac,
|
||||
Hostname: lj.Hostname,
|
||||
}
|
||||
err = s.srv4.RemoveStaticLease(lease)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "%s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleReset(w http.ResponseWriter, r *http.Request) {
|
||||
s.Stop()
|
||||
|
||||
err := os.Remove(s.conf.DBFilePath)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
log.Error("DHCP: os.Remove: %s: %s", s.conf.DBFilePath, err)
|
||||
}
|
||||
|
||||
oldconf := s.conf
|
||||
s.conf = ServerConfig{}
|
||||
s.conf.WorkDir = oldconf.WorkDir
|
||||
s.conf.HTTPRegister = oldconf.HTTPRegister
|
||||
s.conf.ConfigModified = oldconf.ConfigModified
|
||||
s.conf.DBFilePath = oldconf.DBFilePath
|
||||
|
||||
v4conf := V4ServerConf{}
|
||||
v4conf.ICMPTimeout = 1000
|
||||
v4conf.notify = s.onNotify
|
||||
s.srv4, _ = v4Create(v4conf)
|
||||
|
||||
v6conf := V6ServerConf{}
|
||||
v6conf.notify = s.onNotify
|
||||
s.srv6, _ = v6Create(v6conf)
|
||||
|
||||
s.conf.ConfigModified()
|
||||
}
|
||||
|
||||
func (s *Server) registerHandlers() {
|
||||
s.conf.HTTPRegister("GET", "/control/dhcp/status", s.handleDHCPStatus)
|
||||
s.conf.HTTPRegister("GET", "/control/dhcp/interfaces", s.handleDHCPInterfaces)
|
||||
s.conf.HTTPRegister("POST", "/control/dhcp/set_config", s.handleDHCPSetConfig)
|
||||
s.conf.HTTPRegister("POST", "/control/dhcp/find_active_dhcp", s.handleDHCPFindActiveServer)
|
||||
s.conf.HTTPRegister("POST", "/control/dhcp/add_static_lease", s.handleDHCPAddStaticLease)
|
||||
s.conf.HTTPRegister("POST", "/control/dhcp/remove_static_lease", s.handleDHCPRemoveStaticLease)
|
||||
s.conf.HTTPRegister("POST", "/control/dhcp/reset", s.handleReset)
|
||||
}
|
||||
259
internal/dhcpd/dhcpd.go
Normal file
259
internal/dhcpd/dhcpd.go
Normal file
@@ -0,0 +1,259 @@
|
||||
package dhcpd
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"net"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/util"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
const defaultDiscoverTime = time.Second * 3
|
||||
const leaseExpireStatic = 1
|
||||
|
||||
var webHandlersRegistered = false
|
||||
|
||||
// Lease contains the necessary information about a DHCP lease
|
||||
type Lease struct {
|
||||
HWAddr net.HardwareAddr `json:"mac"`
|
||||
IP net.IP `json:"ip"`
|
||||
Hostname string `json:"hostname"`
|
||||
|
||||
// Lease expiration time
|
||||
// 1: static lease
|
||||
Expiry time.Time `json:"expires"`
|
||||
}
|
||||
|
||||
// ServerConfig - DHCP server configuration
|
||||
// field ordering is important -- yaml fields will mirror ordering from here
|
||||
type ServerConfig struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
InterfaceName string `yaml:"interface_name"`
|
||||
|
||||
Conf4 V4ServerConf `yaml:"dhcpv4"`
|
||||
Conf6 V6ServerConf `yaml:"dhcpv6"`
|
||||
|
||||
WorkDir string `yaml:"-"`
|
||||
DBFilePath string `yaml:"-"` // path to DB file
|
||||
|
||||
// Called when the configuration is changed by HTTP request
|
||||
ConfigModified func() `yaml:"-"`
|
||||
|
||||
// Register an HTTP handler
|
||||
HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request)) `yaml:"-"`
|
||||
}
|
||||
|
||||
type OnLeaseChangedT func(flags int)
|
||||
|
||||
// flags for onLeaseChanged()
|
||||
const (
|
||||
LeaseChangedAdded = iota
|
||||
LeaseChangedAddedStatic
|
||||
LeaseChangedRemovedStatic
|
||||
|
||||
LeaseChangedDBStore
|
||||
)
|
||||
|
||||
// Server - the current state of the DHCP server
|
||||
type Server struct {
|
||||
srv4 DHCPServer
|
||||
srv6 DHCPServer
|
||||
|
||||
conf ServerConfig
|
||||
|
||||
// Called when the leases DB is modified
|
||||
onLeaseChanged []OnLeaseChangedT
|
||||
}
|
||||
|
||||
type ServerInterface interface {
|
||||
Leases(flags int) []Lease
|
||||
SetOnLeaseChanged(onLeaseChanged OnLeaseChangedT)
|
||||
}
|
||||
|
||||
// CheckConfig checks the configuration
|
||||
func (s *Server) CheckConfig(config ServerConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create - create object
|
||||
func Create(config ServerConfig) *Server {
|
||||
s := Server{}
|
||||
s.conf.Enabled = config.Enabled
|
||||
s.conf.InterfaceName = config.InterfaceName
|
||||
s.conf.HTTPRegister = config.HTTPRegister
|
||||
s.conf.ConfigModified = config.ConfigModified
|
||||
s.conf.DBFilePath = filepath.Join(config.WorkDir, dbFilename)
|
||||
|
||||
if !webHandlersRegistered && s.conf.HTTPRegister != nil {
|
||||
webHandlersRegistered = true
|
||||
s.registerHandlers()
|
||||
}
|
||||
|
||||
var err4, err6 error
|
||||
v4conf := config.Conf4
|
||||
v4conf.Enabled = s.conf.Enabled
|
||||
if len(v4conf.RangeStart) == 0 {
|
||||
v4conf.Enabled = false
|
||||
}
|
||||
v4conf.InterfaceName = s.conf.InterfaceName
|
||||
v4conf.notify = s.onNotify
|
||||
s.srv4, err4 = v4Create(v4conf)
|
||||
|
||||
v6conf := config.Conf6
|
||||
v6conf.Enabled = s.conf.Enabled
|
||||
if len(v6conf.RangeStart) == 0 {
|
||||
v6conf.Enabled = false
|
||||
}
|
||||
v6conf.InterfaceName = s.conf.InterfaceName
|
||||
v6conf.notify = s.onNotify
|
||||
s.srv6, err6 = v6Create(v6conf)
|
||||
|
||||
if err4 != nil {
|
||||
log.Error("%s", err4)
|
||||
return nil
|
||||
}
|
||||
if err6 != nil {
|
||||
log.Error("%s", err6)
|
||||
return nil
|
||||
}
|
||||
|
||||
if s.conf.Enabled && !v4conf.Enabled && !v6conf.Enabled {
|
||||
log.Error("Can't enable DHCP server because neither DHCPv4 nor DHCPv6 servers are configured")
|
||||
return nil
|
||||
}
|
||||
|
||||
// we can't delay database loading until DHCP server is started,
|
||||
// because we need static leases functionality available beforehand
|
||||
s.dbLoad()
|
||||
return &s
|
||||
}
|
||||
|
||||
// server calls this function after DB is updated
|
||||
func (s *Server) onNotify(flags uint32) {
|
||||
if flags == LeaseChangedDBStore {
|
||||
s.dbStore()
|
||||
return
|
||||
}
|
||||
|
||||
s.notify(int(flags))
|
||||
}
|
||||
|
||||
// SetOnLeaseChanged - set callback
|
||||
func (s *Server) SetOnLeaseChanged(onLeaseChanged OnLeaseChangedT) {
|
||||
s.onLeaseChanged = append(s.onLeaseChanged, onLeaseChanged)
|
||||
}
|
||||
|
||||
func (s *Server) notify(flags int) {
|
||||
if len(s.onLeaseChanged) == 0 {
|
||||
return
|
||||
}
|
||||
for _, f := range s.onLeaseChanged {
|
||||
f(flags)
|
||||
}
|
||||
}
|
||||
|
||||
// WriteDiskConfig - write configuration
|
||||
func (s *Server) WriteDiskConfig(c *ServerConfig) {
|
||||
c.Enabled = s.conf.Enabled
|
||||
c.InterfaceName = s.conf.InterfaceName
|
||||
s.srv4.WriteDiskConfig4(&c.Conf4)
|
||||
s.srv6.WriteDiskConfig6(&c.Conf6)
|
||||
}
|
||||
|
||||
// Start will listen on port 67 and serve DHCP requests.
|
||||
func (s *Server) Start() error {
|
||||
err := s.srv4.Start()
|
||||
if err != nil {
|
||||
log.Error("DHCPv4: start: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
err = s.srv6.Start()
|
||||
if err != nil {
|
||||
log.Error("DHCPv6: start: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop closes the listening UDP socket
|
||||
func (s *Server) Stop() {
|
||||
s.srv4.Stop()
|
||||
s.srv6.Stop()
|
||||
}
|
||||
|
||||
// flags for Leases() function
|
||||
const (
|
||||
LeasesDynamic = 1
|
||||
LeasesStatic = 2
|
||||
LeasesAll = LeasesDynamic | LeasesStatic
|
||||
)
|
||||
|
||||
// Leases returns the list of current DHCP leases (thread-safe)
|
||||
func (s *Server) Leases(flags int) []Lease {
|
||||
result := s.srv4.GetLeases(flags)
|
||||
|
||||
v6leases := s.srv6.GetLeases(flags)
|
||||
result = append(result, v6leases...)
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// FindMACbyIP - find a MAC address by IP address in the currently active DHCP leases
|
||||
func (s *Server) FindMACbyIP(ip net.IP) net.HardwareAddr {
|
||||
if ip.To4() != nil {
|
||||
return s.srv4.FindMACbyIP(ip)
|
||||
}
|
||||
return s.srv6.FindMACbyIP(ip)
|
||||
}
|
||||
|
||||
// AddStaticLease - add static v4 lease
|
||||
func (s *Server) AddStaticLease(lease Lease) error {
|
||||
return s.srv4.AddStaticLease(lease)
|
||||
}
|
||||
|
||||
// Parse option string
|
||||
// Format:
|
||||
// CODE TYPE VALUE
|
||||
func parseOptionString(s string) (uint8, []byte) {
|
||||
s = strings.TrimSpace(s)
|
||||
scode := util.SplitNext(&s, ' ')
|
||||
t := util.SplitNext(&s, ' ')
|
||||
sval := util.SplitNext(&s, ' ')
|
||||
|
||||
code, err := strconv.Atoi(scode)
|
||||
if err != nil || code <= 0 || code > 255 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
var val []byte
|
||||
|
||||
switch t {
|
||||
case "hex":
|
||||
val, err = hex.DecodeString(sval)
|
||||
if err != nil {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
case "ip":
|
||||
ip := net.ParseIP(sval)
|
||||
if ip == nil {
|
||||
return 0, nil
|
||||
}
|
||||
val = ip
|
||||
if ip.To4() != nil {
|
||||
val = ip.To4()
|
||||
}
|
||||
|
||||
default:
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
return uint8(code), val
|
||||
}
|
||||
132
internal/dhcpd/dhcpd_test.go
Normal file
132
internal/dhcpd/dhcpd_test.go
Normal file
@@ -0,0 +1,132 @@
|
||||
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
|
||||
|
||||
package dhcpd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func check(t *testing.T, result bool, msg string) {
|
||||
if !result {
|
||||
t.Fatal(msg)
|
||||
}
|
||||
}
|
||||
|
||||
func testNotify(flags uint32) {
|
||||
}
|
||||
|
||||
// Leases database store/load
|
||||
func TestDB(t *testing.T) {
|
||||
var err error
|
||||
s := Server{}
|
||||
s.conf.DBFilePath = dbFilename
|
||||
|
||||
conf := V4ServerConf{
|
||||
Enabled: true,
|
||||
RangeStart: "192.168.10.100",
|
||||
RangeEnd: "192.168.10.200",
|
||||
GatewayIP: "192.168.10.1",
|
||||
SubnetMask: "255.255.255.0",
|
||||
notify: testNotify,
|
||||
}
|
||||
s.srv4, err = v4Create(conf)
|
||||
assert.True(t, err == nil)
|
||||
|
||||
s.srv6, err = v6Create(V6ServerConf{})
|
||||
assert.True(t, err == nil)
|
||||
|
||||
l := Lease{}
|
||||
l.IP = net.ParseIP("192.168.10.100").To4()
|
||||
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
|
||||
exp1 := time.Now().Add(time.Hour)
|
||||
l.Expiry = exp1
|
||||
s.srv4.(*v4Server).addLease(&l)
|
||||
|
||||
l2 := Lease{}
|
||||
l2.IP = net.ParseIP("192.168.10.101").To4()
|
||||
l2.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:bb")
|
||||
s.srv4.AddStaticLease(l2)
|
||||
|
||||
_ = os.Remove("leases.db")
|
||||
s.dbStore()
|
||||
s.srv4.ResetLeases(nil)
|
||||
|
||||
s.dbLoad()
|
||||
|
||||
ll := s.srv4.GetLeases(LeasesAll)
|
||||
|
||||
assert.Equal(t, "aa:aa:aa:aa:aa:bb", ll[0].HWAddr.String())
|
||||
assert.Equal(t, "192.168.10.101", ll[0].IP.String())
|
||||
assert.Equal(t, int64(leaseExpireStatic), ll[0].Expiry.Unix())
|
||||
|
||||
assert.Equal(t, "aa:aa:aa:aa:aa:aa", ll[1].HWAddr.String())
|
||||
assert.Equal(t, "192.168.10.100", ll[1].IP.String())
|
||||
assert.Equal(t, exp1.Unix(), ll[1].Expiry.Unix())
|
||||
|
||||
_ = os.Remove("leases.db")
|
||||
}
|
||||
|
||||
func TestIsValidSubnetMask(t *testing.T) {
|
||||
assert.True(t, isValidSubnetMask([]byte{255, 255, 255, 0}))
|
||||
assert.True(t, isValidSubnetMask([]byte{255, 255, 254, 0}))
|
||||
assert.True(t, isValidSubnetMask([]byte{255, 255, 252, 0}))
|
||||
assert.True(t, !isValidSubnetMask([]byte{255, 255, 253, 0}))
|
||||
assert.True(t, !isValidSubnetMask([]byte{255, 255, 255, 1}))
|
||||
}
|
||||
|
||||
func TestNormalizeLeases(t *testing.T) {
|
||||
dynLeases := []*Lease{}
|
||||
staticLeases := []*Lease{}
|
||||
leases := []*Lease{}
|
||||
|
||||
lease := &Lease{}
|
||||
lease.HWAddr = []byte{1, 2, 3, 4}
|
||||
dynLeases = append(dynLeases, lease)
|
||||
lease = new(Lease)
|
||||
lease.HWAddr = []byte{1, 2, 3, 5}
|
||||
dynLeases = append(dynLeases, lease)
|
||||
|
||||
lease = new(Lease)
|
||||
lease.HWAddr = []byte{1, 2, 3, 4}
|
||||
lease.IP = []byte{0, 2, 3, 4}
|
||||
staticLeases = append(staticLeases, lease)
|
||||
lease = new(Lease)
|
||||
lease.HWAddr = []byte{2, 2, 3, 4}
|
||||
staticLeases = append(staticLeases, lease)
|
||||
|
||||
leases = normalizeLeases(staticLeases, dynLeases)
|
||||
|
||||
assert.True(t, len(leases) == 3)
|
||||
assert.True(t, bytes.Equal(leases[0].HWAddr, []byte{1, 2, 3, 4}))
|
||||
assert.True(t, bytes.Equal(leases[0].IP, []byte{0, 2, 3, 4}))
|
||||
assert.True(t, bytes.Equal(leases[1].HWAddr, []byte{2, 2, 3, 4}))
|
||||
assert.True(t, bytes.Equal(leases[2].HWAddr, []byte{1, 2, 3, 5}))
|
||||
}
|
||||
|
||||
func TestOptions(t *testing.T) {
|
||||
code, val := parseOptionString(" 12 hex abcdef ")
|
||||
assert.Equal(t, uint8(12), code)
|
||||
assert.True(t, bytes.Equal([]byte{0xab, 0xcd, 0xef}, val))
|
||||
|
||||
code, _ = parseOptionString(" 12 hex abcdef1 ")
|
||||
assert.Equal(t, uint8(0), code)
|
||||
|
||||
code, val = parseOptionString("123 ip 1.2.3.4")
|
||||
assert.Equal(t, uint8(123), code)
|
||||
assert.Equal(t, "1.2.3.4", net.IP(string(val)).String())
|
||||
|
||||
code, _ = parseOptionString("256 ip 1.1.1.1")
|
||||
assert.Equal(t, uint8(0), code)
|
||||
code, _ = parseOptionString("-1 ip 1.1.1.1")
|
||||
assert.Equal(t, uint8(0), code)
|
||||
code, _ = parseOptionString("12 ip 1.1.1.1x")
|
||||
assert.Equal(t, uint8(0), code)
|
||||
code, _ = parseOptionString("12 x 1.1.1.1")
|
||||
assert.Equal(t, uint8(0), code)
|
||||
}
|
||||
76
internal/dhcpd/helpers.go
Normal file
76
internal/dhcpd/helpers.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package dhcpd
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/joomcode/errorx"
|
||||
)
|
||||
|
||||
func isTimeout(err error) bool {
|
||||
operr, ok := err.(*net.OpError)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return operr.Timeout()
|
||||
}
|
||||
|
||||
// Get IPv4 address list
|
||||
func getIfaceIPv4(iface net.Interface) []net.IP {
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var res []net.IP
|
||||
for _, a := range addrs {
|
||||
ipnet, ok := a.(*net.IPNet)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if ipnet.IP.To4() != nil {
|
||||
res = append(res, ipnet.IP.To4())
|
||||
}
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func wrapErrPrint(err error, message string, args ...interface{}) error {
|
||||
var errx error
|
||||
if err == nil {
|
||||
errx = fmt.Errorf(message, args...)
|
||||
} else {
|
||||
errx = errorx.Decorate(err, message, args...)
|
||||
}
|
||||
log.Println(errx.Error())
|
||||
return errx
|
||||
}
|
||||
|
||||
func parseIPv4(text string) (net.IP, error) {
|
||||
result := net.ParseIP(text)
|
||||
if result == nil {
|
||||
return nil, fmt.Errorf("%s is not an IP address", text)
|
||||
}
|
||||
if result.To4() == nil {
|
||||
return nil, fmt.Errorf("%s is not an IPv4 address", text)
|
||||
}
|
||||
return result.To4(), nil
|
||||
}
|
||||
|
||||
// Return TRUE if subnet mask is correct (e.g. 255.255.255.0)
|
||||
func isValidSubnetMask(mask net.IP) bool {
|
||||
var n uint32
|
||||
n = binary.BigEndian.Uint32(mask)
|
||||
for i := 0; i != 32; i++ {
|
||||
if n == 0 {
|
||||
break
|
||||
}
|
||||
if (n & 0x80000000) == 0 {
|
||||
return false
|
||||
}
|
||||
n <<= 1
|
||||
}
|
||||
return true
|
||||
}
|
||||
586
internal/dhcpd/nclient4/client.go
Normal file
586
internal/dhcpd/nclient4/client.go
Normal file
@@ -0,0 +1,586 @@
|
||||
// Copyright 2018 the u-root Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build darwin dragonfly freebsd linux netbsd openbsd solaris
|
||||
// +build go1.12
|
||||
|
||||
// Package nclient4 is a small, minimum-functionality client for DHCPv4.
|
||||
//
|
||||
// It only supports the 4-way DHCPv4 Discover-Offer-Request-Ack handshake as
|
||||
// well as the Request-Ack renewal process.
|
||||
// Originally from here: github.com/insomniacslk/dhcp/dhcpv4/nclient4
|
||||
// with the difference that this package can be built on UNIX (not just Linux),
|
||||
// because github.com/mdlayher/raw package supports it.
|
||||
package nclient4
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/insomniacslk/dhcp/dhcpv4"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultBufferCap = 5
|
||||
|
||||
// DefaultTimeout is the default value for read-timeout if option WithTimeout is not set
|
||||
DefaultTimeout = 5 * time.Second
|
||||
|
||||
// DefaultRetries is amount of retries will be done if no answer was received within read-timeout amount of time
|
||||
DefaultRetries = 3
|
||||
|
||||
// MaxMessageSize is the value to be used for DHCP option "MaxMessageSize".
|
||||
MaxMessageSize = 1500
|
||||
|
||||
// ClientPort is the port that DHCP clients listen on.
|
||||
ClientPort = 68
|
||||
|
||||
// ServerPort is the port that DHCP servers and relay agents listen on.
|
||||
ServerPort = 67
|
||||
)
|
||||
|
||||
var (
|
||||
// DefaultServers is the address of all link-local DHCP servers and
|
||||
// relay agents.
|
||||
DefaultServers = &net.UDPAddr{
|
||||
IP: net.IPv4bcast,
|
||||
Port: ServerPort,
|
||||
}
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrNoResponse is returned when no response packet is received.
|
||||
ErrNoResponse = errors.New("no matching response packet received")
|
||||
|
||||
// ErrNoConn is returned when NewWithConn is called with nil-value as conn.
|
||||
ErrNoConn = errors.New("conn is nil")
|
||||
|
||||
// ErrNoIfaceHWAddr is returned when NewWithConn is called with nil-value as ifaceHWAddr
|
||||
ErrNoIfaceHWAddr = errors.New("ifaceHWAddr is nil")
|
||||
)
|
||||
|
||||
// pendingCh is a channel associated with a pending TransactionID.
|
||||
type pendingCh struct {
|
||||
// SendAndRead closes done to indicate that it wishes for no more
|
||||
// messages for this particular XID.
|
||||
done <-chan struct{}
|
||||
|
||||
// ch is used by the receive loop to distribute DHCP messages.
|
||||
ch chan<- *dhcpv4.DHCPv4
|
||||
}
|
||||
|
||||
// Logger is a handler which will be used to output logging messages
|
||||
type Logger interface {
|
||||
// PrintMessage print _all_ DHCP messages
|
||||
PrintMessage(prefix string, message *dhcpv4.DHCPv4)
|
||||
|
||||
// Printf is use to print the rest debugging information
|
||||
Printf(format string, v ...interface{})
|
||||
}
|
||||
|
||||
// EmptyLogger prints nothing
|
||||
type EmptyLogger struct{}
|
||||
|
||||
// Printf is just a dummy function that does nothing
|
||||
func (e EmptyLogger) Printf(format string, v ...interface{}) {}
|
||||
|
||||
// PrintMessage is just a dummy function that does nothing
|
||||
func (e EmptyLogger) PrintMessage(prefix string, message *dhcpv4.DHCPv4) {}
|
||||
|
||||
// Printfer is used for actual output of the logger. For example *log.Logger is a Printfer.
|
||||
type Printfer interface {
|
||||
// Printf is the function for logging output. Arguments are handled in the manner of fmt.Printf.
|
||||
Printf(format string, v ...interface{})
|
||||
}
|
||||
|
||||
// ShortSummaryLogger is a wrapper for Printfer to implement interface Logger.
|
||||
// DHCP messages are printed in the short format.
|
||||
type ShortSummaryLogger struct {
|
||||
// Printfer is used for actual output of the logger
|
||||
Printfer
|
||||
}
|
||||
|
||||
// Printf prints a log message as-is via predefined Printfer
|
||||
func (s ShortSummaryLogger) Printf(format string, v ...interface{}) {
|
||||
s.Printfer.Printf(format, v...)
|
||||
}
|
||||
|
||||
// PrintMessage prints a DHCP message in the short format via predefined Printfer
|
||||
func (s ShortSummaryLogger) PrintMessage(prefix string, message *dhcpv4.DHCPv4) {
|
||||
s.Printf("%s: %s", prefix, message)
|
||||
}
|
||||
|
||||
// DebugLogger is a wrapper for Printfer to implement interface Logger.
|
||||
// DHCP messages are printed in the long format.
|
||||
type DebugLogger struct {
|
||||
// Printfer is used for actual output of the logger
|
||||
Printfer
|
||||
}
|
||||
|
||||
// Printf prints a log message as-is via predefined Printfer
|
||||
func (d DebugLogger) Printf(format string, v ...interface{}) {
|
||||
d.Printfer.Printf(format, v...)
|
||||
}
|
||||
|
||||
// PrintMessage prints a DHCP message in the long format via predefined Printfer
|
||||
func (d DebugLogger) PrintMessage(prefix string, message *dhcpv4.DHCPv4) {
|
||||
d.Printf("%s: %s", prefix, message.Summary())
|
||||
}
|
||||
|
||||
// Client is an IPv4 DHCP client.
|
||||
type Client struct {
|
||||
ifaceHWAddr net.HardwareAddr
|
||||
conn net.PacketConn
|
||||
timeout time.Duration
|
||||
retry int
|
||||
logger Logger
|
||||
|
||||
// bufferCap is the channel capacity for each TransactionID.
|
||||
bufferCap int
|
||||
|
||||
// serverAddr is the UDP address to send all packets to.
|
||||
//
|
||||
// This may be an actual broadcast address, or a unicast address.
|
||||
serverAddr *net.UDPAddr
|
||||
|
||||
// closed is an atomic bool set to 1 when done is closed.
|
||||
closed uint32
|
||||
|
||||
// done is closed to unblock the receive loop.
|
||||
done chan struct{}
|
||||
|
||||
// wg protects any spawned goroutines, namely the receiveLoop.
|
||||
wg sync.WaitGroup
|
||||
|
||||
pendingMu sync.Mutex
|
||||
// pending stores the distribution channels for each pending
|
||||
// TransactionID. receiveLoop uses this map to determine which channel
|
||||
// to send a new DHCP message to.
|
||||
pending map[dhcpv4.TransactionID]*pendingCh
|
||||
}
|
||||
|
||||
// New returns a client usable with an unconfigured interface.
|
||||
func New(iface string, opts ...ClientOpt) (*Client, error) {
|
||||
return new(iface, nil, nil, opts...)
|
||||
}
|
||||
|
||||
// NewWithConn creates a new DHCP client that sends and receives packets on the
|
||||
// given interface.
|
||||
func NewWithConn(conn net.PacketConn, ifaceHWAddr net.HardwareAddr, opts ...ClientOpt) (*Client, error) {
|
||||
return new(``, conn, ifaceHWAddr, opts...)
|
||||
}
|
||||
|
||||
func new(iface string, conn net.PacketConn, ifaceHWAddr net.HardwareAddr, opts ...ClientOpt) (*Client, error) {
|
||||
c := &Client{
|
||||
ifaceHWAddr: ifaceHWAddr,
|
||||
timeout: DefaultTimeout,
|
||||
retry: DefaultRetries,
|
||||
serverAddr: DefaultServers,
|
||||
bufferCap: defaultBufferCap,
|
||||
conn: conn,
|
||||
logger: EmptyLogger{},
|
||||
|
||||
done: make(chan struct{}),
|
||||
pending: make(map[dhcpv4.TransactionID]*pendingCh),
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
err := opt(c)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to apply option: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if c.ifaceHWAddr == nil {
|
||||
if iface == `` {
|
||||
return nil, ErrNoIfaceHWAddr
|
||||
}
|
||||
|
||||
i, err := net.InterfaceByName(iface)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to get interface information: %w", err)
|
||||
}
|
||||
|
||||
c.ifaceHWAddr = i.HardwareAddr
|
||||
}
|
||||
|
||||
if c.conn == nil {
|
||||
var err error
|
||||
if iface == `` {
|
||||
return nil, ErrNoConn
|
||||
}
|
||||
c.conn, err = NewRawUDPConn(iface, ClientPort) // broadcast
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to open a broadcasting socket: %w", err)
|
||||
}
|
||||
}
|
||||
c.wg.Add(1)
|
||||
go c.receiveLoop()
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// Close closes the underlying connection.
|
||||
func (c *Client) Close() error {
|
||||
// Make sure not to close done twice.
|
||||
if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := c.conn.Close()
|
||||
|
||||
// Closing c.done sets off a chain reaction:
|
||||
//
|
||||
// Any SendAndRead unblocks trying to receive more messages, which
|
||||
// means rem() gets called.
|
||||
//
|
||||
// rem() should be unblocking receiveLoop if it is blocked.
|
||||
//
|
||||
// receiveLoop should then exit gracefully.
|
||||
close(c.done)
|
||||
|
||||
// Wait for receiveLoop to stop.
|
||||
c.wg.Wait()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Client) isClosed() bool {
|
||||
return atomic.LoadUint32(&c.closed) != 0
|
||||
}
|
||||
|
||||
func (c *Client) receiveLoop() {
|
||||
defer c.wg.Done()
|
||||
for {
|
||||
// TODO: Clients can send a "max packet size" option in their
|
||||
// packets, IIRC. Choose a reasonable size and set it.
|
||||
b := make([]byte, MaxMessageSize)
|
||||
n, _, err := c.conn.ReadFrom(b)
|
||||
if err != nil {
|
||||
if !c.isClosed() {
|
||||
c.logger.Printf("error reading from UDP connection: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
msg, err := dhcpv4.FromBytes(b[:n])
|
||||
if err != nil {
|
||||
// Not a valid DHCP packet; keep listening.
|
||||
continue
|
||||
}
|
||||
|
||||
if msg.OpCode != dhcpv4.OpcodeBootReply {
|
||||
// Not a response message.
|
||||
continue
|
||||
}
|
||||
|
||||
// This is a somewhat non-standard check, by the looks
|
||||
// of RFC 2131. It should work as long as the DHCP
|
||||
// server is spec-compliant for the HWAddr field.
|
||||
if c.ifaceHWAddr != nil && !bytes.Equal(c.ifaceHWAddr, msg.ClientHWAddr) {
|
||||
// Not for us.
|
||||
continue
|
||||
}
|
||||
|
||||
c.pendingMu.Lock()
|
||||
p, ok := c.pending[msg.TransactionID]
|
||||
if ok {
|
||||
select {
|
||||
case <-p.done:
|
||||
close(p.ch)
|
||||
delete(c.pending, msg.TransactionID)
|
||||
|
||||
// This send may block.
|
||||
case p.ch <- msg:
|
||||
}
|
||||
}
|
||||
c.pendingMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// ClientOpt is a function that configures the Client.
|
||||
type ClientOpt func(c *Client) error
|
||||
|
||||
// WithTimeout configures the retransmission timeout.
|
||||
//
|
||||
// Default is 5 seconds.
|
||||
func WithTimeout(d time.Duration) ClientOpt {
|
||||
return func(c *Client) (err error) {
|
||||
c.timeout = d
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// WithSummaryLogger logs one-line DHCPv4 message summaries when sent & received.
|
||||
func WithSummaryLogger() ClientOpt {
|
||||
return func(c *Client) (err error) {
|
||||
c.logger = ShortSummaryLogger{
|
||||
Printfer: log.New(os.Stderr, "[dhcpv4] ", log.LstdFlags),
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// WithDebugLogger logs multi-line full DHCPv4 messages when sent & received.
|
||||
func WithDebugLogger() ClientOpt {
|
||||
return func(c *Client) (err error) {
|
||||
c.logger = DebugLogger{
|
||||
Printfer: log.New(os.Stderr, "[dhcpv4] ", log.LstdFlags),
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// WithLogger set the logger (see interface Logger).
|
||||
func WithLogger(newLogger Logger) ClientOpt {
|
||||
return func(c *Client) (err error) {
|
||||
c.logger = newLogger
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// WithUnicast forces client to send messages as unicast frames.
|
||||
// By default client sends messages as broadcast frames even if server address is defined.
|
||||
//
|
||||
// srcAddr is both:
|
||||
// * The source address of outgoing frames.
|
||||
// * The address to be listened for incoming frames.
|
||||
func WithUnicast(srcAddr *net.UDPAddr) ClientOpt {
|
||||
return func(c *Client) (err error) {
|
||||
if srcAddr == nil {
|
||||
srcAddr = &net.UDPAddr{Port: ServerPort}
|
||||
}
|
||||
c.conn, err = net.ListenUDP("udp4", srcAddr)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("unable to start listening UDP port: %w", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// WithHWAddr tells to the Client to receive messages destinated to selected
|
||||
// hardware address
|
||||
func WithHWAddr(hwAddr net.HardwareAddr) ClientOpt {
|
||||
return func(c *Client) (err error) {
|
||||
c.ifaceHWAddr = hwAddr
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// nolint
|
||||
func withBufferCap(n int) ClientOpt {
|
||||
return func(c *Client) (err error) {
|
||||
c.bufferCap = n
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// WithRetry configures the number of retransmissions to attempt.
|
||||
//
|
||||
// Default is 3.
|
||||
func WithRetry(r int) ClientOpt {
|
||||
return func(c *Client) (err error) {
|
||||
c.retry = r
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// WithServerAddr configures the address to send messages to.
|
||||
func WithServerAddr(n *net.UDPAddr) ClientOpt {
|
||||
return func(c *Client) (err error) {
|
||||
c.serverAddr = n
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Matcher matches DHCP packets.
|
||||
type Matcher func(*dhcpv4.DHCPv4) bool
|
||||
|
||||
// IsMessageType returns a matcher that checks for the message type.
|
||||
//
|
||||
// If t is MessageTypeNone, all packets are matched.
|
||||
func IsMessageType(t dhcpv4.MessageType) Matcher {
|
||||
return func(p *dhcpv4.DHCPv4) bool {
|
||||
return p.MessageType() == t || t == dhcpv4.MessageTypeNone
|
||||
}
|
||||
}
|
||||
|
||||
// DiscoverOffer sends a DHCPDiscover message and returns the first valid offer
|
||||
// received.
|
||||
func (c *Client) DiscoverOffer(ctx context.Context, modifiers ...dhcpv4.Modifier) (offer *dhcpv4.DHCPv4, err error) {
|
||||
// RFC 2131, Section 4.4.1, Table 5 details what a DISCOVER packet should
|
||||
// contain.
|
||||
discover, err := dhcpv4.NewDiscovery(c.ifaceHWAddr, dhcpv4.PrependModifiers(modifiers,
|
||||
dhcpv4.WithOption(dhcpv4.OptMaxMessageSize(MaxMessageSize)))...)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("unable to create a discovery request: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
offer, err = c.SendAndRead(ctx, c.serverAddr, discover, IsMessageType(dhcpv4.MessageTypeOffer))
|
||||
if err != nil {
|
||||
err = fmt.Errorf("got an error while the discovery request: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Request completes the 4-way Discover-Offer-Request-Ack handshake.
|
||||
//
|
||||
// Note that modifiers will be applied *both* to Discover and Request packets.
|
||||
func (c *Client) Request(ctx context.Context, modifiers ...dhcpv4.Modifier) (offer, ack *dhcpv4.DHCPv4, err error) {
|
||||
offer, err = c.DiscoverOffer(ctx, modifiers...)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("unable to receive an offer: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
// TODO(chrisko): should this be unicast to the server?
|
||||
request, err := dhcpv4.NewRequestFromOffer(offer, dhcpv4.PrependModifiers(modifiers,
|
||||
dhcpv4.WithOption(dhcpv4.OptMaxMessageSize(MaxMessageSize)))...)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("unable to create a request: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
ack, err = c.SendAndRead(ctx, c.serverAddr, request, nil)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("got an error while processing the request: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// ErrTransactionIDInUse is returned if there were an attempt to send a message
|
||||
// with the same TransactionID as we are already waiting an answer for.
|
||||
type ErrTransactionIDInUse struct {
|
||||
// TransactionID is the transaction ID of the message which the error is related to.
|
||||
TransactionID dhcpv4.TransactionID
|
||||
}
|
||||
|
||||
// Error is just the method to comply interface "error".
|
||||
func (err *ErrTransactionIDInUse) Error() string {
|
||||
return fmt.Sprintf("transaction ID %s already in use", err.TransactionID)
|
||||
}
|
||||
|
||||
// send sends p to destination and returns a response channel.
|
||||
//
|
||||
// Responses will be matched by transaction ID and ClientHWAddr.
|
||||
//
|
||||
// The returned lambda function must be called after all desired responses have
|
||||
// been received in order to return the Transaction ID to the usable pool.
|
||||
func (c *Client) send(dest *net.UDPAddr, msg *dhcpv4.DHCPv4) (resp <-chan *dhcpv4.DHCPv4, cancel func(), err error) {
|
||||
c.pendingMu.Lock()
|
||||
if _, ok := c.pending[msg.TransactionID]; ok {
|
||||
c.pendingMu.Unlock()
|
||||
return nil, nil, &ErrTransactionIDInUse{msg.TransactionID}
|
||||
}
|
||||
|
||||
ch := make(chan *dhcpv4.DHCPv4, c.bufferCap)
|
||||
done := make(chan struct{})
|
||||
c.pending[msg.TransactionID] = &pendingCh{done: done, ch: ch}
|
||||
c.pendingMu.Unlock()
|
||||
|
||||
cancel = func() {
|
||||
// Why can't we just close ch here?
|
||||
//
|
||||
// Because receiveLoop may potentially be blocked trying to
|
||||
// send on ch. We gotta unblock it first, and then we can take
|
||||
// the lock and remove the XID from the pending transaction
|
||||
// map.
|
||||
close(done)
|
||||
|
||||
c.pendingMu.Lock()
|
||||
if p, ok := c.pending[msg.TransactionID]; ok {
|
||||
close(p.ch)
|
||||
delete(c.pending, msg.TransactionID)
|
||||
}
|
||||
c.pendingMu.Unlock()
|
||||
}
|
||||
|
||||
if _, err := c.conn.WriteTo(msg.ToBytes(), dest); err != nil {
|
||||
cancel()
|
||||
return nil, nil, fmt.Errorf("error writing packet to connection: %w", err)
|
||||
}
|
||||
return ch, cancel, nil
|
||||
}
|
||||
|
||||
// This error should never be visible to users.
|
||||
// It is used only to increase the timeout in retryFn.
|
||||
var errDeadlineExceeded = errors.New("INTERNAL ERROR: deadline exceeded")
|
||||
|
||||
// SendAndRead sends a packet p to a destination dest and waits for the first
|
||||
// response matching `match` as well as its Transaction ID and ClientHWAddr.
|
||||
//
|
||||
// If match is nil, the first packet matching the Transaction ID and
|
||||
// ClientHWAddr is returned.
|
||||
func (c *Client) SendAndRead(ctx context.Context, dest *net.UDPAddr, p *dhcpv4.DHCPv4, match Matcher) (*dhcpv4.DHCPv4, error) {
|
||||
var response *dhcpv4.DHCPv4
|
||||
err := c.retryFn(func(timeout time.Duration) error {
|
||||
ch, rem, err := c.send(dest, p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.logger.PrintMessage("sent message", p)
|
||||
defer rem()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.done:
|
||||
return ErrNoResponse
|
||||
|
||||
case <-time.After(timeout):
|
||||
return errDeadlineExceeded
|
||||
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
|
||||
case packet := <-ch:
|
||||
if match == nil || match(packet) {
|
||||
c.logger.PrintMessage("received message", packet)
|
||||
response = packet
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
if err == errDeadlineExceeded {
|
||||
return nil, ErrNoResponse
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (c *Client) retryFn(fn func(timeout time.Duration) error) error {
|
||||
timeout := c.timeout
|
||||
|
||||
// Each retry takes the amount of timeout at worst.
|
||||
for i := 0; i < c.retry || c.retry < 0; i++ { // TODO: why is this called "retry" if this is "tries" ("retries"+1)?
|
||||
switch err := fn(timeout); err {
|
||||
case nil:
|
||||
// Got it!
|
||||
return nil
|
||||
|
||||
case errDeadlineExceeded:
|
||||
// Double timeout, then retry.
|
||||
timeout *= 2
|
||||
|
||||
default:
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return errDeadlineExceeded
|
||||
}
|
||||
333
internal/dhcpd/nclient4/client_test.go
Normal file
333
internal/dhcpd/nclient4/client_test.go
Normal file
@@ -0,0 +1,333 @@
|
||||
// Copyright 2018 the u-root Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build linux
|
||||
// github.com/hugelgupf/socketpair is Linux-only
|
||||
// +build go1.12
|
||||
|
||||
package nclient4
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hugelgupf/socketpair"
|
||||
"github.com/insomniacslk/dhcp/dhcpv4"
|
||||
"github.com/insomniacslk/dhcp/dhcpv4/server4"
|
||||
)
|
||||
|
||||
type handler struct {
|
||||
mu sync.Mutex
|
||||
received []*dhcpv4.DHCPv4
|
||||
|
||||
// Each received packet can have more than one response (in theory,
|
||||
// from different servers sending different Advertise, for example).
|
||||
responses [][]*dhcpv4.DHCPv4
|
||||
}
|
||||
|
||||
func (h *handler) handle(conn net.PacketConn, peer net.Addr, m *dhcpv4.DHCPv4) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
h.received = append(h.received, m)
|
||||
|
||||
if len(h.responses) > 0 {
|
||||
for _, resp := range h.responses[0] {
|
||||
_, _ = conn.WriteTo(resp.ToBytes(), peer)
|
||||
}
|
||||
h.responses = h.responses[1:]
|
||||
}
|
||||
}
|
||||
|
||||
func serveAndClient(ctx context.Context, responses [][]*dhcpv4.DHCPv4, opts ...ClientOpt) (*Client, net.PacketConn) {
|
||||
// Fake PacketConn connection.
|
||||
clientRawConn, serverRawConn, err := socketpair.PacketSocketPair()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
clientConn := NewBroadcastUDPConn(clientRawConn, &net.UDPAddr{Port: ClientPort})
|
||||
serverConn := NewBroadcastUDPConn(serverRawConn, &net.UDPAddr{Port: ServerPort})
|
||||
|
||||
o := []ClientOpt{WithRetry(1), WithTimeout(2 * time.Second)}
|
||||
o = append(o, opts...)
|
||||
mc, err := NewWithConn(clientConn, net.HardwareAddr{0xa, 0xb, 0xc, 0xd, 0xe, 0xf}, o...)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
h := &handler{responses: responses}
|
||||
s, err := server4.NewServer("", nil, h.handle, server4.WithConn(serverConn))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
go func() {
|
||||
_ = s.Serve()
|
||||
}()
|
||||
|
||||
return mc, serverConn
|
||||
}
|
||||
|
||||
func ComparePacket(got *dhcpv4.DHCPv4, want *dhcpv4.DHCPv4) error {
|
||||
if got == nil && got == want {
|
||||
return nil
|
||||
}
|
||||
if (want == nil || got == nil) && (got != want) {
|
||||
return fmt.Errorf("packet got %v, want %v", got, want)
|
||||
}
|
||||
if !bytes.Equal(got.ToBytes(), want.ToBytes()) {
|
||||
return fmt.Errorf("packet got %v, want %v", got, want)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func pktsExpected(got []*dhcpv4.DHCPv4, want []*dhcpv4.DHCPv4) error {
|
||||
if len(got) != len(want) {
|
||||
return fmt.Errorf("got %d packets, want %d packets", len(got), len(want))
|
||||
}
|
||||
|
||||
for i := range got {
|
||||
if err := ComparePacket(got[i], want[i]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func newPacketWeirdHWAddr(op dhcpv4.OpcodeType, xid dhcpv4.TransactionID) *dhcpv4.DHCPv4 {
|
||||
p, err := dhcpv4.New()
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("newpacket: %v", err))
|
||||
}
|
||||
p.OpCode = op
|
||||
p.TransactionID = xid
|
||||
p.ClientHWAddr = net.HardwareAddr{0xa, 0xb, 0xc, 0xd, 0xe, 0xf, 1, 2, 3, 4, 5, 6}
|
||||
return p
|
||||
}
|
||||
|
||||
func newPacket(op dhcpv4.OpcodeType, xid dhcpv4.TransactionID) *dhcpv4.DHCPv4 {
|
||||
p, err := dhcpv4.New()
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("newpacket: %v", err))
|
||||
}
|
||||
p.OpCode = op
|
||||
p.TransactionID = xid
|
||||
p.ClientHWAddr = net.HardwareAddr{0xa, 0xb, 0xc, 0xd, 0xe, 0xf}
|
||||
return p
|
||||
}
|
||||
|
||||
func TestSendAndRead(t *testing.T) {
|
||||
for _, tt := range []struct {
|
||||
desc string
|
||||
send *dhcpv4.DHCPv4
|
||||
server []*dhcpv4.DHCPv4
|
||||
|
||||
// If want is nil, we assume server[0] contains what is wanted.
|
||||
want *dhcpv4.DHCPv4
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
desc: "two response packets",
|
||||
send: newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x33, 0x33, 0x33, 0x33}),
|
||||
server: []*dhcpv4.DHCPv4{
|
||||
newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}),
|
||||
newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}),
|
||||
newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}),
|
||||
newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}),
|
||||
newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}),
|
||||
},
|
||||
want: newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}),
|
||||
},
|
||||
{
|
||||
desc: "one response packet",
|
||||
send: newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x33, 0x33, 0x33, 0x33}),
|
||||
server: []*dhcpv4.DHCPv4{
|
||||
newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}),
|
||||
},
|
||||
want: newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}),
|
||||
},
|
||||
{
|
||||
desc: "one response packet, one invalid XID, one invalid opcode, one invalid hwaddr",
|
||||
send: newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x33, 0x33, 0x33, 0x33}),
|
||||
server: []*dhcpv4.DHCPv4{
|
||||
newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x77, 0x33, 0x33, 0x33}),
|
||||
newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x33, 0x33, 0x33, 0x33}),
|
||||
newPacketWeirdHWAddr(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}),
|
||||
newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}),
|
||||
},
|
||||
want: newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}),
|
||||
},
|
||||
{
|
||||
desc: "discard wrong XID",
|
||||
send: newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x33, 0x33, 0x33, 0x33}),
|
||||
server: []*dhcpv4.DHCPv4{
|
||||
newPacket(dhcpv4.OpcodeBootReply, [4]byte{0, 0, 0, 0}),
|
||||
},
|
||||
want: nil, // Explicitly empty.
|
||||
wantErr: ErrNoResponse,
|
||||
},
|
||||
{
|
||||
desc: "no response, timeout",
|
||||
send: newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x33, 0x33, 0x33, 0x33}),
|
||||
wantErr: ErrNoResponse,
|
||||
},
|
||||
} {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
// Both server and client only get 2 seconds.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
mc, _ := serveAndClient(ctx, [][]*dhcpv4.DHCPv4{tt.server},
|
||||
// Use an unbuffered channel to make sure we
|
||||
// have no deadlocks.
|
||||
withBufferCap(0))
|
||||
defer mc.Close()
|
||||
|
||||
rcvd, err := mc.SendAndRead(context.Background(), DefaultServers, tt.send, nil)
|
||||
if err != tt.wantErr {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if err := ComparePacket(rcvd, tt.want); err != nil {
|
||||
t.Errorf("got unexpected packets: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParallelSendAndRead(t *testing.T) {
|
||||
pkt := newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x33, 0x33, 0x33, 0x33})
|
||||
|
||||
// Both the server and client only get 2 seconds.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
mc, _ := serveAndClient(ctx, [][]*dhcpv4.DHCPv4{},
|
||||
WithTimeout(10*time.Second),
|
||||
// Use an unbuffered channel to make sure nothing blocks.
|
||||
withBufferCap(0))
|
||||
defer mc.Close()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if _, err := mc.SendAndRead(context.Background(), DefaultServers, pkt, nil); err != ErrNoResponse {
|
||||
t.Errorf("SendAndRead(%v) = %v, want %v", pkt, err, ErrNoResponse)
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
time.Sleep(4 * time.Second)
|
||||
|
||||
if err := mc.Close(); err != nil {
|
||||
t.Errorf("closing failed: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestReuseXID(t *testing.T) {
|
||||
pkt := newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x33, 0x33, 0x33, 0x33})
|
||||
|
||||
// Both the server and client only get 2 seconds.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
mc, _ := serveAndClient(ctx, [][]*dhcpv4.DHCPv4{})
|
||||
defer mc.Close()
|
||||
|
||||
if _, err := mc.SendAndRead(context.Background(), DefaultServers, pkt, nil); err != ErrNoResponse {
|
||||
t.Errorf("SendAndRead(%v) = %v, want %v", pkt, err, ErrNoResponse)
|
||||
}
|
||||
|
||||
if _, err := mc.SendAndRead(context.Background(), DefaultServers, pkt, nil); err != ErrNoResponse {
|
||||
t.Errorf("SendAndRead(%v) = %v, want %v", pkt, err, ErrNoResponse)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSimpleSendAndReadDiscardGarbage(t *testing.T) {
|
||||
pkt := newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x33, 0x33, 0x33, 0x33})
|
||||
|
||||
responses := newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33})
|
||||
|
||||
// Both the server and client only get 2 seconds.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
mc, udpConn := serveAndClient(ctx, [][]*dhcpv4.DHCPv4{{responses}})
|
||||
defer mc.Close()
|
||||
|
||||
// Too short for valid DHCPv4 packet.
|
||||
_, _ = udpConn.WriteTo([]byte{0x01}, nil)
|
||||
_, _ = udpConn.WriteTo([]byte{0x01, 0x2}, nil)
|
||||
|
||||
rcvd, err := mc.SendAndRead(ctx, DefaultServers, pkt, nil)
|
||||
if err != nil {
|
||||
t.Errorf("SendAndRead(%v) = %v, want nil", pkt, err)
|
||||
}
|
||||
|
||||
if err := ComparePacket(rcvd, responses); err != nil {
|
||||
t.Errorf("got unexpected packets: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultipleSendAndRead(t *testing.T) {
|
||||
for _, tt := range []struct {
|
||||
desc string
|
||||
send []*dhcpv4.DHCPv4
|
||||
server [][]*dhcpv4.DHCPv4
|
||||
wantErr []error
|
||||
}{
|
||||
{
|
||||
desc: "two requests, two responses",
|
||||
send: []*dhcpv4.DHCPv4{
|
||||
newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x33, 0x33, 0x33, 0x33}),
|
||||
newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x44, 0x44, 0x44, 0x44}),
|
||||
},
|
||||
server: [][]*dhcpv4.DHCPv4{
|
||||
[]*dhcpv4.DHCPv4{ // Response for first packet.
|
||||
newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}),
|
||||
},
|
||||
[]*dhcpv4.DHCPv4{ // Response for second packet.
|
||||
newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x44, 0x44, 0x44, 0x44}),
|
||||
},
|
||||
},
|
||||
wantErr: []error{
|
||||
nil,
|
||||
nil,
|
||||
},
|
||||
},
|
||||
} {
|
||||
// Both server and client only get 2 seconds.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
mc, _ := serveAndClient(ctx, tt.server)
|
||||
defer mc.Close()
|
||||
|
||||
for i, send := range tt.send {
|
||||
ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
rcvd, err := mc.SendAndRead(ctx, DefaultServers, send, nil)
|
||||
|
||||
if wantErr := tt.wantErr[i]; err != wantErr {
|
||||
t.Errorf("SendAndReadOne(%v): got %v, want %v", send, err, wantErr)
|
||||
}
|
||||
if err := pktsExpected([]*dhcpv4.DHCPv4{rcvd}, tt.server[i]); err != nil {
|
||||
t.Errorf("got unexpected packets: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
144
internal/dhcpd/nclient4/conn_unix.go
Normal file
144
internal/dhcpd/nclient4/conn_unix.go
Normal file
@@ -0,0 +1,144 @@
|
||||
// Copyright 2018 the u-root Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build darwin dragonfly freebsd linux netbsd openbsd solaris
|
||||
// +build go1.12
|
||||
|
||||
package nclient4
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/mdlayher/ethernet"
|
||||
"github.com/mdlayher/raw"
|
||||
"github.com/u-root/u-root/pkg/uio"
|
||||
)
|
||||
|
||||
var (
|
||||
// BroadcastMac is the broadcast MAC address.
|
||||
//
|
||||
// Any UDP packet sent to this address is broadcast on the subnet.
|
||||
BroadcastMac = net.HardwareAddr([]byte{255, 255, 255, 255, 255, 255})
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrUDPAddrIsRequired is an error used when a passed argument is not of type "*net.UDPAddr".
|
||||
ErrUDPAddrIsRequired = errors.New("must supply UDPAddr")
|
||||
)
|
||||
|
||||
// NewRawUDPConn returns a UDP connection bound to the interface and port
|
||||
// given based on a raw packet socket. All packets are broadcasted.
|
||||
//
|
||||
// The interface can be completely unconfigured.
|
||||
func NewRawUDPConn(iface string, port int) (net.PacketConn, error) {
|
||||
ifc, err := net.InterfaceByName(iface)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rawConn, err := raw.ListenPacket(ifc, uint16(ethernet.EtherTypeIPv4), &raw.Config{LinuxSockDGRAM: true})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewBroadcastUDPConn(rawConn, &net.UDPAddr{Port: port}), nil
|
||||
}
|
||||
|
||||
// BroadcastRawUDPConn uses a raw socket to send UDP packets to the broadcast
|
||||
// MAC address.
|
||||
type BroadcastRawUDPConn struct {
|
||||
// PacketConn is a raw DGRAM socket.
|
||||
net.PacketConn
|
||||
|
||||
// boundAddr is the address this RawUDPConn is "bound" to.
|
||||
//
|
||||
// Calls to ReadFrom will only return packets destined to this address.
|
||||
boundAddr *net.UDPAddr
|
||||
}
|
||||
|
||||
// NewBroadcastUDPConn returns a PacketConn that marshals and unmarshals UDP
|
||||
// packets, sending them to the broadcast MAC at on rawPacketConn.
|
||||
//
|
||||
// Calls to ReadFrom will only return packets destined to boundAddr.
|
||||
func NewBroadcastUDPConn(rawPacketConn net.PacketConn, boundAddr *net.UDPAddr) net.PacketConn {
|
||||
return &BroadcastRawUDPConn{
|
||||
PacketConn: rawPacketConn,
|
||||
boundAddr: boundAddr,
|
||||
}
|
||||
}
|
||||
|
||||
func udpMatch(addr *net.UDPAddr, bound *net.UDPAddr) bool {
|
||||
if bound == nil {
|
||||
return true
|
||||
}
|
||||
if bound.IP != nil && !bound.IP.Equal(addr.IP) {
|
||||
return false
|
||||
}
|
||||
return bound.Port == addr.Port
|
||||
}
|
||||
|
||||
// ReadFrom implements net.PacketConn.ReadFrom.
|
||||
//
|
||||
// ReadFrom reads raw IP packets and will try to match them against
|
||||
// upc.boundAddr. Any matching packets are returned via the given buffer.
|
||||
func (upc *BroadcastRawUDPConn) ReadFrom(b []byte) (int, net.Addr, error) {
|
||||
ipHdrMaxLen := IPv4MaximumHeaderSize
|
||||
udpHdrLen := UDPMinimumSize
|
||||
|
||||
for {
|
||||
pkt := make([]byte, ipHdrMaxLen+udpHdrLen+len(b))
|
||||
n, _, err := upc.PacketConn.ReadFrom(pkt)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
if n == 0 {
|
||||
return 0, nil, io.EOF
|
||||
}
|
||||
pkt = pkt[:n]
|
||||
buf := uio.NewBigEndianBuffer(pkt)
|
||||
|
||||
// To read the header length, access data directly.
|
||||
ipHdr := IPv4(buf.Data())
|
||||
ipHdr = IPv4(buf.Consume(int(ipHdr.HeaderLength())))
|
||||
|
||||
if ipHdr.TransportProtocol() != UDPProtocolNumber {
|
||||
continue
|
||||
}
|
||||
udpHdr := UDP(buf.Consume(udpHdrLen))
|
||||
|
||||
addr := &net.UDPAddr{
|
||||
IP: ipHdr.DestinationAddress(),
|
||||
Port: int(udpHdr.DestinationPort()),
|
||||
}
|
||||
if !udpMatch(addr, upc.boundAddr) {
|
||||
continue
|
||||
}
|
||||
srcAddr := &net.UDPAddr{
|
||||
IP: ipHdr.SourceAddress(),
|
||||
Port: int(udpHdr.SourcePort()),
|
||||
}
|
||||
// Extra padding after end of IP packet should be ignored,
|
||||
// if not dhcp option parsing will fail.
|
||||
dhcpLen := int(ipHdr.PayloadLength()) - udpHdrLen
|
||||
return copy(b, buf.Consume(dhcpLen)), srcAddr, nil
|
||||
}
|
||||
}
|
||||
|
||||
// WriteTo implements net.PacketConn.WriteTo and broadcasts all packets at the
|
||||
// raw socket level.
|
||||
//
|
||||
// WriteTo wraps the given packet in the appropriate UDP and IP header before
|
||||
// sending it on the packet conn.
|
||||
func (upc *BroadcastRawUDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
|
||||
udpAddr, ok := addr.(*net.UDPAddr)
|
||||
if !ok {
|
||||
return 0, ErrUDPAddrIsRequired
|
||||
}
|
||||
|
||||
// Using the boundAddr is not quite right here, but it works.
|
||||
packet := udp4pkt(b, udpAddr, upc.boundAddr)
|
||||
|
||||
// Broadcasting is not always right, but hell, what the ARP do I know.
|
||||
return upc.PacketConn.WriteTo(packet, &raw.Addr{HardwareAddr: BroadcastMac})
|
||||
}
|
||||
377
internal/dhcpd/nclient4/ipv4.go
Normal file
377
internal/dhcpd/nclient4/ipv4.go
Normal file
@@ -0,0 +1,377 @@
|
||||
// Copyright 2018 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
// This file contains code taken from gVisor.
|
||||
|
||||
// +build darwin dragonfly freebsd linux netbsd openbsd solaris
|
||||
// +build go1.12
|
||||
|
||||
package nclient4
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net"
|
||||
|
||||
"github.com/u-root/u-root/pkg/uio"
|
||||
)
|
||||
|
||||
const (
|
||||
versIHL = 0
|
||||
tos = 1
|
||||
totalLen = 2
|
||||
id = 4
|
||||
flagsFO = 6
|
||||
ttl = 8
|
||||
protocol = 9
|
||||
checksum = 10
|
||||
srcAddr = 12
|
||||
dstAddr = 16
|
||||
)
|
||||
|
||||
// TransportProtocolNumber is the number of a transport protocol.
|
||||
type TransportProtocolNumber uint32
|
||||
|
||||
// IPv4Fields contains the fields of an IPv4 packet. It is used to describe the
|
||||
// fields of a packet that needs to be encoded.
|
||||
type IPv4Fields struct {
|
||||
// IHL is the "internet header length" field of an IPv4 packet.
|
||||
IHL uint8
|
||||
|
||||
// TOS is the "type of service" field of an IPv4 packet.
|
||||
TOS uint8
|
||||
|
||||
// TotalLength is the "total length" field of an IPv4 packet.
|
||||
TotalLength uint16
|
||||
|
||||
// ID is the "identification" field of an IPv4 packet.
|
||||
ID uint16
|
||||
|
||||
// Flags is the "flags" field of an IPv4 packet.
|
||||
Flags uint8
|
||||
|
||||
// FragmentOffset is the "fragment offset" field of an IPv4 packet.
|
||||
FragmentOffset uint16
|
||||
|
||||
// TTL is the "time to live" field of an IPv4 packet.
|
||||
TTL uint8
|
||||
|
||||
// Protocol is the "protocol" field of an IPv4 packet.
|
||||
Protocol uint8
|
||||
|
||||
// Checksum is the "checksum" field of an IPv4 packet.
|
||||
Checksum uint16
|
||||
|
||||
// SrcAddr is the "source ip address" of an IPv4 packet.
|
||||
SrcAddr net.IP
|
||||
|
||||
// DstAddr is the "destination ip address" of an IPv4 packet.
|
||||
DstAddr net.IP
|
||||
}
|
||||
|
||||
// IPv4 represents an ipv4 header stored in a byte array.
|
||||
// Most of the methods of IPv4 access to the underlying slice without
|
||||
// checking the boundaries and could panic because of 'index out of range'.
|
||||
// Always call IsValid() to validate an instance of IPv4 before using other methods.
|
||||
type IPv4 []byte
|
||||
|
||||
const (
|
||||
// IPv4MinimumSize is the minimum size of a valid IPv4 packet.
|
||||
IPv4MinimumSize = 20
|
||||
|
||||
// IPv4MaximumHeaderSize is the maximum size of an IPv4 header. Given
|
||||
// that there are only 4 bits to represents the header length in 32-bit
|
||||
// units, the header cannot exceed 15*4 = 60 bytes.
|
||||
IPv4MaximumHeaderSize = 60
|
||||
|
||||
// IPv4AddressSize is the size, in bytes, of an IPv4 address.
|
||||
IPv4AddressSize = 4
|
||||
|
||||
// IPv4Version is the version of the ipv4 protocol.
|
||||
IPv4Version = 4
|
||||
)
|
||||
|
||||
var (
|
||||
// IPv4Broadcast is the broadcast address of the IPv4 protocol.
|
||||
IPv4Broadcast = net.IP{0xff, 0xff, 0xff, 0xff}
|
||||
|
||||
// IPv4Any is the non-routable IPv4 "any" meta address.
|
||||
IPv4Any = net.IP{0, 0, 0, 0}
|
||||
)
|
||||
|
||||
// Flags that may be set in an IPv4 packet.
|
||||
const (
|
||||
IPv4FlagMoreFragments = 1 << iota
|
||||
IPv4FlagDontFragment
|
||||
)
|
||||
|
||||
// HeaderLength returns the value of the "header length" field of the ipv4
|
||||
// header.
|
||||
func (b IPv4) HeaderLength() uint8 {
|
||||
return (b[versIHL] & 0xf) * 4
|
||||
}
|
||||
|
||||
// Protocol returns the value of the protocol field of the ipv4 header.
|
||||
func (b IPv4) Protocol() uint8 {
|
||||
return b[protocol]
|
||||
}
|
||||
|
||||
// SourceAddress returns the "source address" field of the ipv4 header.
|
||||
func (b IPv4) SourceAddress() net.IP {
|
||||
return net.IP(b[srcAddr : srcAddr+IPv4AddressSize])
|
||||
}
|
||||
|
||||
// DestinationAddress returns the "destination address" field of the ipv4
|
||||
// header.
|
||||
func (b IPv4) DestinationAddress() net.IP {
|
||||
return net.IP(b[dstAddr : dstAddr+IPv4AddressSize])
|
||||
}
|
||||
|
||||
// TransportProtocol implements Network.TransportProtocol.
|
||||
func (b IPv4) TransportProtocol() TransportProtocolNumber {
|
||||
return TransportProtocolNumber(b.Protocol())
|
||||
}
|
||||
|
||||
// Payload implements Network.Payload.
|
||||
func (b IPv4) Payload() []byte {
|
||||
return b[b.HeaderLength():][:b.PayloadLength()]
|
||||
}
|
||||
|
||||
// PayloadLength returns the length of the payload portion of the ipv4 packet.
|
||||
func (b IPv4) PayloadLength() uint16 {
|
||||
return b.TotalLength() - uint16(b.HeaderLength())
|
||||
}
|
||||
|
||||
// TotalLength returns the "total length" field of the ipv4 header.
|
||||
func (b IPv4) TotalLength() uint16 {
|
||||
return binary.BigEndian.Uint16(b[totalLen:])
|
||||
}
|
||||
|
||||
// SetTotalLength sets the "total length" field of the ipv4 header.
|
||||
func (b IPv4) SetTotalLength(totalLength uint16) {
|
||||
binary.BigEndian.PutUint16(b[totalLen:], totalLength)
|
||||
}
|
||||
|
||||
// SetChecksum sets the checksum field of the ipv4 header.
|
||||
func (b IPv4) SetChecksum(v uint16) {
|
||||
binary.BigEndian.PutUint16(b[checksum:], v)
|
||||
}
|
||||
|
||||
// SetFlagsFragmentOffset sets the "flags" and "fragment offset" fields of the
|
||||
// ipv4 header.
|
||||
func (b IPv4) SetFlagsFragmentOffset(flags uint8, offset uint16) {
|
||||
v := (uint16(flags) << 13) | (offset >> 3)
|
||||
binary.BigEndian.PutUint16(b[flagsFO:], v)
|
||||
}
|
||||
|
||||
// SetSourceAddress sets the "source address" field of the ipv4 header.
|
||||
func (b IPv4) SetSourceAddress(addr net.IP) {
|
||||
copy(b[srcAddr:srcAddr+IPv4AddressSize], addr.To4())
|
||||
}
|
||||
|
||||
// SetDestinationAddress sets the "destination address" field of the ipv4
|
||||
// header.
|
||||
func (b IPv4) SetDestinationAddress(addr net.IP) {
|
||||
copy(b[dstAddr:dstAddr+IPv4AddressSize], addr.To4())
|
||||
}
|
||||
|
||||
// CalculateChecksum calculates the checksum of the ipv4 header.
|
||||
func (b IPv4) CalculateChecksum() uint16 {
|
||||
return Checksum(b[:b.HeaderLength()], 0)
|
||||
}
|
||||
|
||||
// Encode encodes all the fields of the ipv4 header.
|
||||
func (b IPv4) Encode(i *IPv4Fields) {
|
||||
b[versIHL] = (4 << 4) | ((i.IHL / 4) & 0xf)
|
||||
b[tos] = i.TOS
|
||||
b.SetTotalLength(i.TotalLength)
|
||||
binary.BigEndian.PutUint16(b[id:], i.ID)
|
||||
b.SetFlagsFragmentOffset(i.Flags, i.FragmentOffset)
|
||||
b[ttl] = i.TTL
|
||||
b[protocol] = i.Protocol
|
||||
b.SetChecksum(i.Checksum)
|
||||
copy(b[srcAddr:srcAddr+IPv4AddressSize], i.SrcAddr)
|
||||
copy(b[dstAddr:dstAddr+IPv4AddressSize], i.DstAddr)
|
||||
}
|
||||
|
||||
const (
|
||||
udpSrcPort = 0
|
||||
udpDstPort = 2
|
||||
udpLength = 4
|
||||
udpChecksum = 6
|
||||
)
|
||||
|
||||
// UDPFields contains the fields of a UDP packet. It is used to describe the
|
||||
// fields of a packet that needs to be encoded.
|
||||
type UDPFields struct {
|
||||
// SrcPort is the "source port" field of a UDP packet.
|
||||
SrcPort uint16
|
||||
|
||||
// DstPort is the "destination port" field of a UDP packet.
|
||||
DstPort uint16
|
||||
|
||||
// Length is the "length" field of a UDP packet.
|
||||
Length uint16
|
||||
|
||||
// Checksum is the "checksum" field of a UDP packet.
|
||||
Checksum uint16
|
||||
}
|
||||
|
||||
// UDP represents a UDP header stored in a byte array.
|
||||
type UDP []byte
|
||||
|
||||
const (
|
||||
// UDPMinimumSize is the minimum size of a valid UDP packet.
|
||||
UDPMinimumSize = 8
|
||||
|
||||
// UDPProtocolNumber is UDP's transport protocol number.
|
||||
UDPProtocolNumber TransportProtocolNumber = 17
|
||||
)
|
||||
|
||||
// SourcePort returns the "source port" field of the udp header.
|
||||
func (b UDP) SourcePort() uint16 {
|
||||
return binary.BigEndian.Uint16(b[udpSrcPort:])
|
||||
}
|
||||
|
||||
// DestinationPort returns the "destination port" field of the udp header.
|
||||
func (b UDP) DestinationPort() uint16 {
|
||||
return binary.BigEndian.Uint16(b[udpDstPort:])
|
||||
}
|
||||
|
||||
// Length returns the "length" field of the udp header.
|
||||
func (b UDP) Length() uint16 {
|
||||
return binary.BigEndian.Uint16(b[udpLength:])
|
||||
}
|
||||
|
||||
// SetSourcePort sets the "source port" field of the udp header.
|
||||
func (b UDP) SetSourcePort(port uint16) {
|
||||
binary.BigEndian.PutUint16(b[udpSrcPort:], port)
|
||||
}
|
||||
|
||||
// SetDestinationPort sets the "destination port" field of the udp header.
|
||||
func (b UDP) SetDestinationPort(port uint16) {
|
||||
binary.BigEndian.PutUint16(b[udpDstPort:], port)
|
||||
}
|
||||
|
||||
// SetChecksum sets the "checksum" field of the udp header.
|
||||
func (b UDP) SetChecksum(checksum uint16) {
|
||||
binary.BigEndian.PutUint16(b[udpChecksum:], checksum)
|
||||
}
|
||||
|
||||
// Payload returns the data contained in the UDP datagram.
|
||||
func (b UDP) Payload() []byte {
|
||||
return b[UDPMinimumSize:]
|
||||
}
|
||||
|
||||
// Checksum returns the "checksum" field of the udp header.
|
||||
func (b UDP) Checksum() uint16 {
|
||||
return binary.BigEndian.Uint16(b[udpChecksum:])
|
||||
}
|
||||
|
||||
// CalculateChecksum calculates the checksum of the udp packet, given the total
|
||||
// length of the packet and the checksum of the network-layer pseudo-header
|
||||
// (excluding the total length) and the checksum of the payload.
|
||||
func (b UDP) CalculateChecksum(partialChecksum uint16, totalLen uint16) uint16 {
|
||||
// Add the length portion of the checksum to the pseudo-checksum.
|
||||
tmp := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(tmp, totalLen)
|
||||
checksum := Checksum(tmp, partialChecksum)
|
||||
|
||||
// Calculate the rest of the checksum.
|
||||
return Checksum(b[:UDPMinimumSize], checksum)
|
||||
}
|
||||
|
||||
// Encode encodes all the fields of the udp header.
|
||||
func (b UDP) Encode(u *UDPFields) {
|
||||
binary.BigEndian.PutUint16(b[udpSrcPort:], u.SrcPort)
|
||||
binary.BigEndian.PutUint16(b[udpDstPort:], u.DstPort)
|
||||
binary.BigEndian.PutUint16(b[udpLength:], u.Length)
|
||||
binary.BigEndian.PutUint16(b[udpChecksum:], u.Checksum)
|
||||
}
|
||||
|
||||
func calculateChecksum(buf []byte, initial uint32) uint16 {
|
||||
v := initial
|
||||
|
||||
l := len(buf)
|
||||
if l&1 != 0 {
|
||||
l--
|
||||
v += uint32(buf[l]) << 8
|
||||
}
|
||||
|
||||
for i := 0; i < l; i += 2 {
|
||||
v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
|
||||
}
|
||||
|
||||
return ChecksumCombine(uint16(v), uint16(v>>16))
|
||||
}
|
||||
|
||||
// Checksum calculates the checksum (as defined in RFC 1071) of the bytes in the
|
||||
// given byte array.
|
||||
//
|
||||
// The initial checksum must have been computed on an even number of bytes.
|
||||
func Checksum(buf []byte, initial uint16) uint16 {
|
||||
return calculateChecksum(buf, uint32(initial))
|
||||
}
|
||||
|
||||
// ChecksumCombine combines the two uint16 to form their checksum. This is done
|
||||
// by adding them and the carry.
|
||||
//
|
||||
// Note that checksum a must have been computed on an even number of bytes.
|
||||
func ChecksumCombine(a, b uint16) uint16 {
|
||||
v := uint32(a) + uint32(b)
|
||||
return uint16(v + v>>16)
|
||||
}
|
||||
|
||||
// PseudoHeaderChecksum calculates the pseudo-header checksum for the
|
||||
// given destination protocol and network address, ignoring the length
|
||||
// field. Pseudo-headers are needed by transport layers when calculating
|
||||
// their own checksum.
|
||||
func PseudoHeaderChecksum(protocol TransportProtocolNumber, srcAddr net.IP, dstAddr net.IP) uint16 {
|
||||
xsum := Checksum([]byte(srcAddr), 0)
|
||||
xsum = Checksum([]byte(dstAddr), xsum)
|
||||
return Checksum([]byte{0, uint8(protocol)}, xsum)
|
||||
}
|
||||
|
||||
func udp4pkt(packet []byte, dest *net.UDPAddr, src *net.UDPAddr) []byte {
|
||||
ipLen := IPv4MinimumSize
|
||||
udpLen := UDPMinimumSize
|
||||
|
||||
h := make([]byte, 0, ipLen+udpLen+len(packet))
|
||||
hdr := uio.NewBigEndianBuffer(h)
|
||||
|
||||
ipv4fields := &IPv4Fields{
|
||||
IHL: IPv4MinimumSize,
|
||||
TotalLength: uint16(ipLen + udpLen + len(packet)),
|
||||
TTL: 64, // Per RFC 1700's recommendation for IP time to live
|
||||
Protocol: uint8(UDPProtocolNumber),
|
||||
SrcAddr: src.IP.To4(),
|
||||
DstAddr: dest.IP.To4(),
|
||||
}
|
||||
ipv4hdr := IPv4(hdr.WriteN(ipLen))
|
||||
ipv4hdr.Encode(ipv4fields)
|
||||
ipv4hdr.SetChecksum(^ipv4hdr.CalculateChecksum())
|
||||
|
||||
udphdr := UDP(hdr.WriteN(udpLen))
|
||||
udphdr.Encode(&UDPFields{
|
||||
SrcPort: uint16(src.Port),
|
||||
DstPort: uint16(dest.Port),
|
||||
Length: uint16(udpLen + len(packet)),
|
||||
})
|
||||
|
||||
xsum := Checksum(packet, PseudoHeaderChecksum(
|
||||
ipv4hdr.TransportProtocol(), ipv4fields.SrcAddr, ipv4fields.DstAddr))
|
||||
udphdr.SetChecksum(^udphdr.CalculateChecksum(xsum, udphdr.Length()))
|
||||
|
||||
hdr.WriteBytes(packet)
|
||||
return hdr.Data()
|
||||
}
|
||||
312
internal/dhcpd/network_utils.go
Normal file
312
internal/dhcpd/network_utils.go
Normal file
@@ -0,0 +1,312 @@
|
||||
package dhcpd
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/util"
|
||||
|
||||
"github.com/AdguardTeam/golibs/file"
|
||||
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// Check if network interface has a static IP configured
|
||||
// Supports: Raspbian.
|
||||
func HasStaticIP(ifaceName string) (bool, error) {
|
||||
if runtime.GOOS == "linux" {
|
||||
body, err := ioutil.ReadFile("/etc/dhcpcd.conf")
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return hasStaticIPDhcpcdConf(string(body), ifaceName), nil
|
||||
}
|
||||
|
||||
if runtime.GOOS == "darwin" {
|
||||
return hasStaticIPDarwin(ifaceName)
|
||||
}
|
||||
|
||||
return false, fmt.Errorf("cannot check if IP is static: not supported on %s", runtime.GOOS)
|
||||
}
|
||||
|
||||
// Set a static IP for the specified network interface
|
||||
func SetStaticIP(ifaceName string) error {
|
||||
if runtime.GOOS == "linux" {
|
||||
return setStaticIPDhcpdConf(ifaceName)
|
||||
}
|
||||
|
||||
if runtime.GOOS == "darwin" {
|
||||
return setStaticIPDarwin(ifaceName)
|
||||
}
|
||||
|
||||
return fmt.Errorf("cannot set static IP on %s", runtime.GOOS)
|
||||
}
|
||||
|
||||
// for dhcpcd.conf
|
||||
func hasStaticIPDhcpcdConf(dhcpConf, ifaceName string) bool {
|
||||
lines := strings.Split(dhcpConf, "\n")
|
||||
nameLine := fmt.Sprintf("interface %s", ifaceName)
|
||||
withinInterfaceCtx := false
|
||||
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
|
||||
if withinInterfaceCtx && len(line) == 0 {
|
||||
// an empty line resets our state
|
||||
withinInterfaceCtx = false
|
||||
}
|
||||
|
||||
if len(line) == 0 || line[0] == '#' {
|
||||
continue
|
||||
}
|
||||
line = strings.TrimSpace(line)
|
||||
|
||||
if !withinInterfaceCtx {
|
||||
if line == nameLine {
|
||||
// we found our interface
|
||||
withinInterfaceCtx = true
|
||||
}
|
||||
|
||||
} else {
|
||||
if strings.HasPrefix(line, "interface ") {
|
||||
// we found another interface - reset our state
|
||||
withinInterfaceCtx = false
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(line, "static ip_address=") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Get gateway IP address
|
||||
func getGatewayIP(ifaceName string) string {
|
||||
cmd := exec.Command("ip", "route", "show", "dev", ifaceName)
|
||||
log.Tracef("executing %s %v", cmd.Path, cmd.Args)
|
||||
d, err := cmd.Output()
|
||||
if err != nil || cmd.ProcessState.ExitCode() != 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
fields := strings.Fields(string(d))
|
||||
if len(fields) < 3 || fields[0] != "default" {
|
||||
return ""
|
||||
}
|
||||
|
||||
ip := net.ParseIP(fields[2])
|
||||
if ip == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return fields[2]
|
||||
}
|
||||
|
||||
// setStaticIPDhcpdConf - updates /etc/dhcpd.conf and sets the current IP address to be static
|
||||
func setStaticIPDhcpdConf(ifaceName string) error {
|
||||
ip := util.GetSubnet(ifaceName)
|
||||
if len(ip) == 0 {
|
||||
return errors.New("can't get IP address")
|
||||
}
|
||||
|
||||
ip4, _, err := net.ParseCIDR(ip)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
gatewayIP := getGatewayIP(ifaceName)
|
||||
add := updateStaticIPDhcpcdConf(ifaceName, ip, gatewayIP, ip4.String())
|
||||
|
||||
body, err := ioutil.ReadFile("/etc/dhcpcd.conf")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
body = append(body, []byte(add)...)
|
||||
err = file.SafeWrite("/etc/dhcpcd.conf", body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// updates dhcpd.conf content -- sets static IP address there
|
||||
// for dhcpcd.conf
|
||||
func updateStaticIPDhcpcdConf(ifaceName, ip, gatewayIP, dnsIP string) string {
|
||||
var body []byte
|
||||
|
||||
add := fmt.Sprintf("\ninterface %s\nstatic ip_address=%s\n",
|
||||
ifaceName, ip)
|
||||
body = append(body, []byte(add)...)
|
||||
|
||||
if len(gatewayIP) != 0 {
|
||||
add = fmt.Sprintf("static routers=%s\n",
|
||||
gatewayIP)
|
||||
body = append(body, []byte(add)...)
|
||||
}
|
||||
|
||||
add = fmt.Sprintf("static domain_name_servers=%s\n\n",
|
||||
dnsIP)
|
||||
body = append(body, []byte(add)...)
|
||||
|
||||
return string(body)
|
||||
}
|
||||
|
||||
// Check if network interface has a static IP configured
|
||||
// Supports: MacOS.
|
||||
func hasStaticIPDarwin(ifaceName string) (bool, error) {
|
||||
portInfo, err := getCurrentHardwarePortInfo(ifaceName)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return portInfo.static, nil
|
||||
}
|
||||
|
||||
// setStaticIPDarwin - uses networksetup util to set the current IP address to be static
|
||||
// Additionally it configures the current DNS servers as well
|
||||
func setStaticIPDarwin(ifaceName string) error {
|
||||
portInfo, err := getCurrentHardwarePortInfo(ifaceName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if portInfo.static {
|
||||
return errors.New("IP address is already static")
|
||||
}
|
||||
|
||||
dnsAddrs, err := getEtcResolvConfServers()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
args := make([]string, 0)
|
||||
args = append(args, "-setdnsservers", portInfo.name)
|
||||
args = append(args, dnsAddrs...)
|
||||
|
||||
// Setting DNS servers is necessary when configuring a static IP
|
||||
code, _, err := util.RunCommand("networksetup", args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if code != 0 {
|
||||
return fmt.Errorf("failed to set DNS servers, code=%d", code)
|
||||
}
|
||||
|
||||
// Actually configures hardware port to have static IP
|
||||
code, _, err = util.RunCommand("networksetup", "-setmanual",
|
||||
portInfo.name, portInfo.ip, portInfo.subnet, portInfo.gatewayIP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if code != 0 {
|
||||
return fmt.Errorf("failed to set DNS servers, code=%d", code)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getCurrentHardwarePortInfo gets information the specified network interface
|
||||
func getCurrentHardwarePortInfo(ifaceName string) (hardwarePortInfo, error) {
|
||||
// First of all we should find hardware port name
|
||||
m := getNetworkSetupHardwareReports()
|
||||
hardwarePort, ok := m[ifaceName]
|
||||
if !ok {
|
||||
return hardwarePortInfo{}, fmt.Errorf("could not find hardware port for %s", ifaceName)
|
||||
}
|
||||
|
||||
return getHardwarePortInfo(hardwarePort)
|
||||
}
|
||||
|
||||
// getNetworkSetupHardwareReports parses the output of the `networksetup -listallhardwareports` command
|
||||
// it returns a map where the key is the interface name, and the value is the "hardware port"
|
||||
// returns nil if it fails to parse the output
|
||||
func getNetworkSetupHardwareReports() map[string]string {
|
||||
_, out, err := util.RunCommand("networksetup", "-listallhardwareports")
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
re, err := regexp.Compile("Hardware Port: (.*?)\nDevice: (.*?)\n")
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
m := make(map[string]string, 0)
|
||||
|
||||
matches := re.FindAllStringSubmatch(out, -1)
|
||||
for i := range matches {
|
||||
port := matches[i][1]
|
||||
device := matches[i][2]
|
||||
m[device] = port
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// hardwarePortInfo - information obtained using MacOS networksetup
|
||||
// about the current state of the internet connection
|
||||
type hardwarePortInfo struct {
|
||||
name string
|
||||
ip string
|
||||
subnet string
|
||||
gatewayIP string
|
||||
static bool
|
||||
}
|
||||
|
||||
func getHardwarePortInfo(hardwarePort string) (hardwarePortInfo, error) {
|
||||
h := hardwarePortInfo{}
|
||||
|
||||
_, out, err := util.RunCommand("networksetup", "-getinfo", hardwarePort)
|
||||
if err != nil {
|
||||
return h, err
|
||||
}
|
||||
|
||||
re := regexp.MustCompile("IP address: (.*?)\nSubnet mask: (.*?)\nRouter: (.*?)\n")
|
||||
|
||||
match := re.FindStringSubmatch(out)
|
||||
if len(match) == 0 {
|
||||
return h, errors.New("could not find hardware port info")
|
||||
}
|
||||
|
||||
h.name = hardwarePort
|
||||
h.ip = match[1]
|
||||
h.subnet = match[2]
|
||||
h.gatewayIP = match[3]
|
||||
|
||||
if strings.Index(out, "Manual Configuration") == 0 {
|
||||
h.static = true
|
||||
}
|
||||
|
||||
return h, nil
|
||||
}
|
||||
|
||||
// Gets a list of nameservers currently configured in the /etc/resolv.conf
|
||||
func getEtcResolvConfServers() ([]string, error) {
|
||||
body, err := ioutil.ReadFile("/etc/resolv.conf")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
re := regexp.MustCompile("nameserver ([a-zA-Z0-9.:]+)")
|
||||
|
||||
matches := re.FindAllStringSubmatch(string(body), -1)
|
||||
if len(matches) == 0 {
|
||||
return nil, errors.New("found no DNS servers in /etc/resolv.conf")
|
||||
}
|
||||
|
||||
addrs := make([]string, 0)
|
||||
for i := range matches {
|
||||
addrs = append(addrs, matches[i][1])
|
||||
}
|
||||
|
||||
return addrs, nil
|
||||
}
|
||||
63
internal/dhcpd/network_utils_test.go
Normal file
63
internal/dhcpd/network_utils_test.go
Normal file
@@ -0,0 +1,63 @@
|
||||
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
|
||||
|
||||
package dhcpd
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestHasStaticIPDhcpcdConf(t *testing.T) {
|
||||
dhcpdConf := `#comment
|
||||
# comment
|
||||
|
||||
interface eth0
|
||||
static ip_address=192.168.0.1/24
|
||||
|
||||
# interface wlan0
|
||||
static ip_address=192.168.1.1/24
|
||||
|
||||
# comment
|
||||
`
|
||||
assert.True(t, !hasStaticIPDhcpcdConf(dhcpdConf, "wlan0"))
|
||||
|
||||
dhcpdConf = `#comment
|
||||
# comment
|
||||
|
||||
interface eth0
|
||||
static ip_address=192.168.0.1/24
|
||||
|
||||
# interface wlan0
|
||||
static ip_address=192.168.1.1/24
|
||||
|
||||
# comment
|
||||
|
||||
interface wlan0
|
||||
# comment
|
||||
static ip_address=192.168.2.1/24
|
||||
`
|
||||
assert.True(t, hasStaticIPDhcpcdConf(dhcpdConf, "wlan0"))
|
||||
}
|
||||
|
||||
func TestSetStaticIPDhcpcdConf(t *testing.T) {
|
||||
dhcpcdConf := `
|
||||
interface wlan0
|
||||
static ip_address=192.168.0.2/24
|
||||
static routers=192.168.0.1
|
||||
static domain_name_servers=192.168.0.2
|
||||
|
||||
`
|
||||
s := updateStaticIPDhcpcdConf("wlan0", "192.168.0.2/24", "192.168.0.1", "192.168.0.2")
|
||||
assert.Equal(t, dhcpcdConf, s)
|
||||
|
||||
// without gateway
|
||||
dhcpcdConf = `
|
||||
interface wlan0
|
||||
static ip_address=192.168.0.2/24
|
||||
static domain_name_servers=192.168.0.2
|
||||
|
||||
`
|
||||
s = updateStaticIPDhcpcdConf("wlan0", "192.168.0.2/24", "", "192.168.0.2")
|
||||
assert.Equal(t, dhcpcdConf, s)
|
||||
}
|
||||
13
internal/dhcpd/os_windows.go
Normal file
13
internal/dhcpd/os_windows.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package dhcpd
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
|
||||
"golang.org/x/net/ipv4"
|
||||
)
|
||||
|
||||
// Create a socket for receiving broadcast packets
|
||||
func newBroadcastPacketConn(bindAddr net.IP, port int, ifname string) (*ipv4.PacketConn, error) {
|
||||
return nil, errors.New("newBroadcastPacketConn(): not supported on Windows")
|
||||
}
|
||||
239
internal/dhcpd/router_adv.go
Normal file
239
internal/dhcpd/router_adv.go
Normal file
@@ -0,0 +1,239 @@
|
||||
package dhcpd
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"golang.org/x/net/icmp"
|
||||
"golang.org/x/net/ipv6"
|
||||
)
|
||||
|
||||
type raCtx struct {
|
||||
raAllowSlaac bool // send RA packets without MO flags
|
||||
raSlaacOnly bool // send RA packets with MO flags
|
||||
ipAddr net.IP // source IP address (link-local-unicast)
|
||||
dnsIPAddr net.IP // IP address for DNS Server option
|
||||
prefixIPAddr net.IP // IP address for Prefix option
|
||||
ifaceName string
|
||||
iface *net.Interface
|
||||
packetSendPeriod time.Duration // how often RA packets are sent
|
||||
|
||||
conn *icmp.PacketConn // ICMPv6 socket
|
||||
stop atomic.Value // stop the packet sending loop
|
||||
}
|
||||
|
||||
type icmpv6RA struct {
|
||||
managedAddressConfiguration bool
|
||||
otherConfiguration bool
|
||||
prefix net.IP
|
||||
prefixLen int
|
||||
sourceLinkLayerAddress net.HardwareAddr
|
||||
recursiveDNSServer net.IP
|
||||
mtu uint32
|
||||
}
|
||||
|
||||
// Create an ICMPv6.RouterAdvertisement packet with all necessary options.
|
||||
//
|
||||
// ICMPv6:
|
||||
// type[1]
|
||||
// code[1]
|
||||
// chksum[2]
|
||||
// body (RouterAdvertisement):
|
||||
// Cur Hop Limit[1]
|
||||
// Flags[1]: MO......
|
||||
// Router Lifetime[2]
|
||||
// Reachable Time[4]
|
||||
// Retrans Timer[4]
|
||||
// Option=Prefix Information(3):
|
||||
// Type[1]
|
||||
// Length * 8bytes[1]
|
||||
// Prefix Length[1]
|
||||
// Flags[1]: LA......
|
||||
// Valid Lifetime[4]
|
||||
// Preferred Lifetime[4]
|
||||
// Reserved[4]
|
||||
// Prefix[16]
|
||||
// Option=MTU(5):
|
||||
// Type[1]
|
||||
// Length * 8bytes[1]
|
||||
// Reserved[2]
|
||||
// MTU[4]
|
||||
// Option=Source link-layer address(1):
|
||||
// Link-Layer Address[6]
|
||||
// Option=Recursive DNS Server(25):
|
||||
// Type[1]
|
||||
// Length * 8bytes[1]
|
||||
// Reserved[2]
|
||||
// Lifetime[4]
|
||||
// Addresses of IPv6 Recursive DNS Servers[16]
|
||||
func createICMPv6RAPacket(params icmpv6RA) []byte {
|
||||
data := make([]byte, 88)
|
||||
i := 0
|
||||
|
||||
// ICMPv6:
|
||||
|
||||
data[i] = 134 // type
|
||||
data[i+1] = 0 // code
|
||||
data[i+2] = 0 // chksum
|
||||
data[i+3] = 0
|
||||
i += 4
|
||||
|
||||
// RouterAdvertisement:
|
||||
|
||||
data[i] = 64 // Cur Hop Limit[1]
|
||||
i++
|
||||
|
||||
data[i] = 0 // Flags[1]: MO......
|
||||
if params.managedAddressConfiguration {
|
||||
data[i] |= 0x80
|
||||
}
|
||||
if params.otherConfiguration {
|
||||
data[i] |= 0x40
|
||||
}
|
||||
i++
|
||||
|
||||
binary.BigEndian.PutUint16(data[i:], 1800) // Router Lifetime[2]
|
||||
i += 2
|
||||
binary.BigEndian.PutUint32(data[i:], 0) // Reachable Time[4]
|
||||
i += 4
|
||||
binary.BigEndian.PutUint32(data[i:], 0) // Retrans Timer[4]
|
||||
i += 4
|
||||
|
||||
// Option=Prefix Information:
|
||||
|
||||
data[i] = 3 // Type
|
||||
data[i+1] = 4 // Length
|
||||
i += 2
|
||||
data[i] = byte(params.prefixLen) // Prefix Length[1]
|
||||
i++
|
||||
data[i] = 0xc0 // Flags[1]
|
||||
i++
|
||||
binary.BigEndian.PutUint32(data[i:], 3600) // Valid Lifetime[4]
|
||||
i += 4
|
||||
binary.BigEndian.PutUint32(data[i:], 3600) // Preferred Lifetime[4]
|
||||
i += 4
|
||||
binary.BigEndian.PutUint32(data[i:], 0) // Reserved[4]
|
||||
i += 4
|
||||
copy(data[i:], params.prefix[:8]) // Prefix[16]
|
||||
binary.BigEndian.PutUint32(data[i+8:], 0)
|
||||
binary.BigEndian.PutUint32(data[i+12:], 0)
|
||||
i += 16
|
||||
|
||||
// Option=MTU:
|
||||
|
||||
data[i] = 5 // Type
|
||||
data[i+1] = 1 // Length
|
||||
i += 2
|
||||
binary.BigEndian.PutUint16(data[i:], 0) // Reserved[2]
|
||||
i += 2
|
||||
binary.BigEndian.PutUint32(data[i:], params.mtu) // MTU[4]
|
||||
i += 4
|
||||
|
||||
// Option=Source link-layer address:
|
||||
|
||||
data[i] = 1 // Type
|
||||
data[i+1] = 1 // Length
|
||||
i += 2
|
||||
copy(data[i:], params.sourceLinkLayerAddress) // Link-Layer Address[6]
|
||||
i += 6
|
||||
|
||||
// Option=Recursive DNS Server:
|
||||
|
||||
data[i] = 25 // Type
|
||||
data[i+1] = 3 // Length
|
||||
i += 2
|
||||
binary.BigEndian.PutUint16(data[i:], 0) // Reserved[2]
|
||||
i += 2
|
||||
binary.BigEndian.PutUint32(data[i:], 3600) // Lifetime[4]
|
||||
i += 4
|
||||
copy(data[i:], params.recursiveDNSServer) // Addresses of IPv6 Recursive DNS Servers[16]
|
||||
|
||||
return data
|
||||
}
|
||||
|
||||
// Init - initialize RA module
|
||||
func (ra *raCtx) Init() error {
|
||||
ra.stop.Store(0)
|
||||
ra.conn = nil
|
||||
if !(ra.raAllowSlaac || ra.raSlaacOnly) {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Debug("DHCPv6 RA: source IP address: %s DNS IP address: %s",
|
||||
ra.ipAddr, ra.dnsIPAddr)
|
||||
|
||||
params := icmpv6RA{
|
||||
managedAddressConfiguration: !ra.raSlaacOnly,
|
||||
otherConfiguration: !ra.raSlaacOnly,
|
||||
mtu: uint32(ra.iface.MTU),
|
||||
prefixLen: 64,
|
||||
recursiveDNSServer: ra.dnsIPAddr,
|
||||
sourceLinkLayerAddress: ra.iface.HardwareAddr,
|
||||
}
|
||||
params.prefix = make([]byte, 16)
|
||||
copy(params.prefix, ra.prefixIPAddr[:8]) // /64
|
||||
|
||||
data := createICMPv6RAPacket(params)
|
||||
|
||||
var err error
|
||||
ipAndScope := ra.ipAddr.String() + "%" + ra.ifaceName
|
||||
ra.conn, err = icmp.ListenPacket("ip6:ipv6-icmp", ipAndScope)
|
||||
if err != nil {
|
||||
return fmt.Errorf("DHCPv6 RA: icmp.ListenPacket: %s", err)
|
||||
}
|
||||
success := false
|
||||
defer func() {
|
||||
if !success {
|
||||
ra.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
con6 := ra.conn.IPv6PacketConn()
|
||||
|
||||
if err := con6.SetHopLimit(255); err != nil {
|
||||
return fmt.Errorf("DHCPv6 RA: SetHopLimit: %s", err)
|
||||
}
|
||||
|
||||
if err := con6.SetMulticastHopLimit(255); err != nil {
|
||||
return fmt.Errorf("DHCPv6 RA: SetMulticastHopLimit: %s", err)
|
||||
}
|
||||
|
||||
msg := &ipv6.ControlMessage{
|
||||
HopLimit: 255,
|
||||
Src: ra.ipAddr,
|
||||
IfIndex: ra.iface.Index,
|
||||
}
|
||||
addr := &net.UDPAddr{
|
||||
IP: net.ParseIP("ff02::1"),
|
||||
}
|
||||
|
||||
go func() {
|
||||
log.Debug("DHCPv6 RA: starting to send periodic RouterAdvertisement packets")
|
||||
for ra.stop.Load() == 0 {
|
||||
_, err = con6.WriteTo(data, msg, addr)
|
||||
if err != nil {
|
||||
log.Error("DHCPv6 RA: WriteTo: %s", err)
|
||||
}
|
||||
time.Sleep(ra.packetSendPeriod)
|
||||
}
|
||||
log.Debug("DHCPv6 RA: loop exit")
|
||||
}()
|
||||
|
||||
success = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close - close module
|
||||
func (ra *raCtx) Close() {
|
||||
log.Debug("DHCPv6 RA: closing")
|
||||
|
||||
ra.stop.Store(1)
|
||||
|
||||
if ra.conn != nil {
|
||||
ra.conn.Close()
|
||||
}
|
||||
}
|
||||
31
internal/dhcpd/router_adv_test.go
Normal file
31
internal/dhcpd/router_adv_test.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package dhcpd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestRA(t *testing.T) {
|
||||
ra := icmpv6RA{
|
||||
managedAddressConfiguration: false,
|
||||
otherConfiguration: true,
|
||||
mtu: 1500,
|
||||
prefix: net.ParseIP("1234::"),
|
||||
prefixLen: 64,
|
||||
recursiveDNSServer: net.ParseIP("fe80::800:27ff:fe00:0"),
|
||||
sourceLinkLayerAddress: []byte{0x0a, 0x00, 0x27, 0x00, 0x00, 0x00},
|
||||
}
|
||||
data := createICMPv6RAPacket(ra)
|
||||
dataCorrect := []byte{
|
||||
0x86, 0x00, 0x00, 0x00, 0x40, 0x40, 0x07, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x03, 0x04, 0x40, 0xc0, 0x00, 0x00, 0x0e, 0x10, 0x00, 0x00, 0x0e, 0x10, 0x00, 0x00, 0x00, 0x00,
|
||||
0x12, 0x34, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x05, 0x01, 0x00, 0x00, 0x00, 0x00, 0x05, 0xdc, 0x01, 0x01, 0x0a, 0x00, 0x27, 0x00, 0x00, 0x00,
|
||||
0x19, 0x03, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x10, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x08, 0x00, 0x27, 0xff, 0xfe, 0x00, 0x00, 0x00,
|
||||
}
|
||||
assert.True(t, bytes.Equal(data, dataCorrect))
|
||||
}
|
||||
100
internal/dhcpd/server.go
Normal file
100
internal/dhcpd/server.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package dhcpd
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DHCPServer - DHCP server interface
|
||||
type DHCPServer interface {
|
||||
// ResetLeases - reset leases
|
||||
ResetLeases(leases []*Lease)
|
||||
// GetLeases - get leases
|
||||
GetLeases(flags int) []Lease
|
||||
// GetLeasesRef - get reference to leases array
|
||||
GetLeasesRef() []*Lease
|
||||
// AddStaticLease - add a static lease
|
||||
AddStaticLease(lease Lease) error
|
||||
// RemoveStaticLease - remove a static lease
|
||||
RemoveStaticLease(l Lease) error
|
||||
// FindMACbyIP - find a MAC address by IP address in the currently active DHCP leases
|
||||
FindMACbyIP(ip net.IP) net.HardwareAddr
|
||||
|
||||
// WriteDiskConfig4 - copy disk configuration
|
||||
WriteDiskConfig4(c *V4ServerConf)
|
||||
// WriteDiskConfig6 - copy disk configuration
|
||||
WriteDiskConfig6(c *V6ServerConf)
|
||||
|
||||
// Start - start server
|
||||
Start() error
|
||||
// Stop - stop server
|
||||
Stop()
|
||||
}
|
||||
|
||||
// V4ServerConf - server configuration
|
||||
type V4ServerConf struct {
|
||||
Enabled bool `yaml:"-"`
|
||||
InterfaceName string `yaml:"-"`
|
||||
|
||||
GatewayIP string `yaml:"gateway_ip"`
|
||||
SubnetMask string `yaml:"subnet_mask"`
|
||||
|
||||
// The first & the last IP address for dynamic leases
|
||||
// Bytes [0..2] of the last allowed IP address must match the first IP
|
||||
RangeStart string `yaml:"range_start"`
|
||||
RangeEnd string `yaml:"range_end"`
|
||||
|
||||
LeaseDuration uint32 `yaml:"lease_duration"` // in seconds
|
||||
|
||||
// IP conflict detector: time (ms) to wait for ICMP reply
|
||||
// 0: disable
|
||||
ICMPTimeout uint32 `yaml:"icmp_timeout_msec"`
|
||||
|
||||
// Custom Options.
|
||||
//
|
||||
// Option with arbitrary hexadecimal data:
|
||||
// DEC_CODE hex HEX_DATA
|
||||
// where DEC_CODE is a decimal DHCPv4 option code in range [1..255]
|
||||
//
|
||||
// Option with IP data (only 1 IP is supported):
|
||||
// DEC_CODE ip IP_ADDR
|
||||
Options []string `yaml:"options"`
|
||||
|
||||
ipStart net.IP // starting IP address for dynamic leases
|
||||
ipEnd net.IP // ending IP address for dynamic leases
|
||||
leaseTime time.Duration // the time during which a dynamic lease is considered valid
|
||||
dnsIPAddrs []net.IP // IPv4 addresses to return to DHCP clients as DNS server addresses
|
||||
routerIP net.IP // value for Option Router
|
||||
subnetMask net.IPMask // value for Option SubnetMask
|
||||
options []dhcpOption
|
||||
|
||||
// Server calls this function when leases data changes
|
||||
notify func(uint32)
|
||||
}
|
||||
|
||||
// V6ServerConf - server configuration
|
||||
type V6ServerConf struct {
|
||||
Enabled bool `yaml:"-"`
|
||||
InterfaceName string `yaml:"-"`
|
||||
|
||||
// The first IP address for dynamic leases
|
||||
// The last allowed IP address ends with 0xff byte
|
||||
RangeStart string `yaml:"range_start"`
|
||||
|
||||
LeaseDuration uint32 `yaml:"lease_duration"` // in seconds
|
||||
|
||||
RaSlaacOnly bool `yaml:"ra_slaac_only"` // send ICMPv6.RA packets without MO flags
|
||||
RaAllowSlaac bool `yaml:"ra_allow_slaac"` // send ICMPv6.RA packets with MO flags
|
||||
|
||||
ipStart net.IP // starting IP address for dynamic leases
|
||||
leaseTime time.Duration // the time during which a dynamic lease is considered valid
|
||||
dnsIPAddrs []net.IP // IPv6 addresses to return to DHCP clients as DNS server addresses
|
||||
|
||||
// Server calls this function when leases data changes
|
||||
notify func(uint32)
|
||||
}
|
||||
|
||||
type dhcpOption struct {
|
||||
code uint8
|
||||
val []byte
|
||||
}
|
||||
673
internal/dhcpd/v4.go
Normal file
673
internal/dhcpd/v4.go
Normal file
@@ -0,0 +1,673 @@
|
||||
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
|
||||
|
||||
package dhcpd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/insomniacslk/dhcp/dhcpv4"
|
||||
"github.com/insomniacslk/dhcp/dhcpv4/server4"
|
||||
"github.com/sparrc/go-ping"
|
||||
)
|
||||
|
||||
// v4Server - DHCPv4 server
|
||||
type v4Server struct {
|
||||
srv *server4.Server
|
||||
leasesLock sync.Mutex
|
||||
leases []*Lease
|
||||
ipAddrs [256]byte
|
||||
|
||||
conf V4ServerConf
|
||||
}
|
||||
|
||||
// WriteDiskConfig4 - write configuration
|
||||
func (s *v4Server) WriteDiskConfig4(c *V4ServerConf) {
|
||||
*c = s.conf
|
||||
}
|
||||
|
||||
// WriteDiskConfig6 - write configuration
|
||||
func (s *v4Server) WriteDiskConfig6(c *V6ServerConf) {
|
||||
}
|
||||
|
||||
// Return TRUE if IP address is within range [start..stop]
|
||||
func ip4InRange(start net.IP, stop net.IP, ip net.IP) bool {
|
||||
if len(start) != 4 || len(stop) != 4 {
|
||||
return false
|
||||
}
|
||||
from := binary.BigEndian.Uint32(start)
|
||||
to := binary.BigEndian.Uint32(stop)
|
||||
check := binary.BigEndian.Uint32(ip)
|
||||
return from <= check && check <= to
|
||||
}
|
||||
|
||||
// ResetLeases - reset leases
|
||||
func (s *v4Server) ResetLeases(leases []*Lease) {
|
||||
s.leases = nil
|
||||
|
||||
for _, l := range leases {
|
||||
|
||||
if l.Expiry.Unix() != leaseExpireStatic &&
|
||||
!ip4InRange(s.conf.ipStart, s.conf.ipEnd, l.IP) {
|
||||
|
||||
log.Debug("DHCPv4: skipping a lease with IP %v: not within current IP range", l.IP)
|
||||
continue
|
||||
}
|
||||
|
||||
s.addLease(l)
|
||||
}
|
||||
}
|
||||
|
||||
// GetLeasesRef - get leases
|
||||
func (s *v4Server) GetLeasesRef() []*Lease {
|
||||
return s.leases
|
||||
}
|
||||
|
||||
// Return TRUE if this lease holds a blacklisted IP
|
||||
func (s *v4Server) blacklisted(l *Lease) bool {
|
||||
return l.HWAddr.String() == "00:00:00:00:00:00"
|
||||
}
|
||||
|
||||
// GetLeases returns the list of current DHCP leases (thread-safe)
|
||||
func (s *v4Server) GetLeases(flags int) []Lease {
|
||||
var result []Lease
|
||||
now := time.Now().Unix()
|
||||
|
||||
s.leasesLock.Lock()
|
||||
for _, lease := range s.leases {
|
||||
if ((flags&LeasesDynamic) != 0 && lease.Expiry.Unix() > now && !s.blacklisted(lease)) ||
|
||||
((flags&LeasesStatic) != 0 && lease.Expiry.Unix() == leaseExpireStatic) {
|
||||
result = append(result, *lease)
|
||||
}
|
||||
}
|
||||
s.leasesLock.Unlock()
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// FindMACbyIP - find a MAC address by IP address in the currently active DHCP leases
|
||||
func (s *v4Server) FindMACbyIP(ip net.IP) net.HardwareAddr {
|
||||
now := time.Now().Unix()
|
||||
|
||||
s.leasesLock.Lock()
|
||||
defer s.leasesLock.Unlock()
|
||||
|
||||
ip4 := ip.To4()
|
||||
if ip4 == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, l := range s.leases {
|
||||
if l.IP.Equal(ip4) {
|
||||
unix := l.Expiry.Unix()
|
||||
if unix > now || unix == leaseExpireStatic {
|
||||
return l.HWAddr
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Add the specified IP to the black list for a time period
|
||||
func (s *v4Server) blacklistLease(lease *Lease) {
|
||||
hw := make(net.HardwareAddr, 6)
|
||||
lease.HWAddr = hw
|
||||
lease.Hostname = ""
|
||||
lease.Expiry = time.Now().Add(s.conf.leaseTime)
|
||||
}
|
||||
|
||||
// Remove (swap) lease by index
|
||||
func (s *v4Server) leaseRemoveSwapByIndex(i int) {
|
||||
s.ipAddrs[s.leases[i].IP[3]] = 0
|
||||
log.Debug("DHCPv4: removed lease %s", s.leases[i].HWAddr)
|
||||
|
||||
n := len(s.leases)
|
||||
if i != n-1 {
|
||||
s.leases[i] = s.leases[n-1] // swap with the last element
|
||||
}
|
||||
s.leases = s.leases[:n-1]
|
||||
}
|
||||
|
||||
// Remove a dynamic lease with the same properties
|
||||
// Return error if a static lease is found
|
||||
func (s *v4Server) rmDynamicLease(lease Lease) error {
|
||||
for i := 0; i < len(s.leases); i++ {
|
||||
l := s.leases[i]
|
||||
|
||||
if bytes.Equal(l.HWAddr, lease.HWAddr) {
|
||||
|
||||
if l.Expiry.Unix() == leaseExpireStatic {
|
||||
return fmt.Errorf("static lease already exists")
|
||||
}
|
||||
|
||||
s.leaseRemoveSwapByIndex(i)
|
||||
if i == len(s.leases) {
|
||||
break
|
||||
}
|
||||
l = s.leases[i]
|
||||
}
|
||||
|
||||
if net.IP.Equal(l.IP, lease.IP) {
|
||||
|
||||
if l.Expiry.Unix() == leaseExpireStatic {
|
||||
return fmt.Errorf("static lease already exists")
|
||||
}
|
||||
|
||||
s.leaseRemoveSwapByIndex(i)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Add a lease
|
||||
func (s *v4Server) addLease(l *Lease) {
|
||||
s.leases = append(s.leases, l)
|
||||
s.ipAddrs[l.IP[3]] = 1
|
||||
log.Debug("DHCPv4: added lease %s <-> %s", l.IP, l.HWAddr)
|
||||
}
|
||||
|
||||
// Remove a lease with the same properties
|
||||
func (s *v4Server) rmLease(lease Lease) error {
|
||||
for i, l := range s.leases {
|
||||
if net.IP.Equal(l.IP, lease.IP) {
|
||||
|
||||
if !bytes.Equal(l.HWAddr, lease.HWAddr) ||
|
||||
l.Hostname != lease.Hostname {
|
||||
|
||||
return fmt.Errorf("Lease not found")
|
||||
}
|
||||
|
||||
s.leaseRemoveSwapByIndex(i)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("lease not found")
|
||||
}
|
||||
|
||||
// AddStaticLease adds a static lease (thread-safe)
|
||||
func (s *v4Server) AddStaticLease(lease Lease) error {
|
||||
if len(lease.IP) != 4 {
|
||||
return fmt.Errorf("invalid IP")
|
||||
}
|
||||
if len(lease.HWAddr) != 6 {
|
||||
return fmt.Errorf("invalid MAC")
|
||||
}
|
||||
lease.Expiry = time.Unix(leaseExpireStatic, 0)
|
||||
|
||||
s.leasesLock.Lock()
|
||||
err := s.rmDynamicLease(lease)
|
||||
if err != nil {
|
||||
s.leasesLock.Unlock()
|
||||
return err
|
||||
}
|
||||
s.addLease(&lease)
|
||||
s.conf.notify(LeaseChangedDBStore)
|
||||
s.leasesLock.Unlock()
|
||||
|
||||
s.conf.notify(LeaseChangedAddedStatic)
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveStaticLease removes a static lease (thread-safe)
|
||||
func (s *v4Server) RemoveStaticLease(l Lease) error {
|
||||
if len(l.IP) != 4 {
|
||||
return fmt.Errorf("invalid IP")
|
||||
}
|
||||
if len(l.HWAddr) != 6 {
|
||||
return fmt.Errorf("invalid MAC")
|
||||
}
|
||||
|
||||
s.leasesLock.Lock()
|
||||
err := s.rmLease(l)
|
||||
if err != nil {
|
||||
s.leasesLock.Unlock()
|
||||
return err
|
||||
}
|
||||
s.conf.notify(LeaseChangedDBStore)
|
||||
s.leasesLock.Unlock()
|
||||
|
||||
s.conf.notify(LeaseChangedRemovedStatic)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Send ICMP to the specified machine
|
||||
// Return TRUE if it doesn't reply, which probably means that the IP is available
|
||||
func (s *v4Server) addrAvailable(target net.IP) bool {
|
||||
|
||||
if s.conf.ICMPTimeout == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
pinger, err := ping.NewPinger(target.String())
|
||||
if err != nil {
|
||||
log.Error("ping.NewPinger(): %v", err)
|
||||
return true
|
||||
}
|
||||
|
||||
pinger.SetPrivileged(true)
|
||||
pinger.Timeout = time.Duration(s.conf.ICMPTimeout) * time.Millisecond
|
||||
pinger.Count = 1
|
||||
reply := false
|
||||
pinger.OnRecv = func(pkt *ping.Packet) {
|
||||
reply = true
|
||||
}
|
||||
log.Debug("DHCPv4: Sending ICMP Echo to %v", target)
|
||||
pinger.Run()
|
||||
|
||||
if reply {
|
||||
log.Info("DHCPv4: IP conflict: %v is already used by another device", target)
|
||||
return false
|
||||
}
|
||||
|
||||
log.Debug("DHCPv4: ICMP procedure is complete: %v", target)
|
||||
return true
|
||||
}
|
||||
|
||||
// Find lease by MAC
|
||||
func (s *v4Server) findLease(mac net.HardwareAddr) *Lease {
|
||||
for i := range s.leases {
|
||||
if bytes.Equal(mac, s.leases[i].HWAddr) {
|
||||
return s.leases[i]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get next free IP
|
||||
func (s *v4Server) findFreeIP() net.IP {
|
||||
for i := s.conf.ipStart[3]; ; i++ {
|
||||
if s.ipAddrs[i] == 0 {
|
||||
ip := make([]byte, 4)
|
||||
copy(ip, s.conf.ipStart)
|
||||
ip[3] = i
|
||||
return ip
|
||||
}
|
||||
if i == s.conf.ipEnd[3] {
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Find an expired lease and return its index or -1
|
||||
func (s *v4Server) findExpiredLease() int {
|
||||
now := time.Now().Unix()
|
||||
for i, lease := range s.leases {
|
||||
if lease.Expiry.Unix() != leaseExpireStatic &&
|
||||
lease.Expiry.Unix() <= now {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// Reserve lease for MAC
|
||||
func (s *v4Server) reserveLease(mac net.HardwareAddr) *Lease {
|
||||
l := Lease{}
|
||||
l.HWAddr = make([]byte, 6)
|
||||
copy(l.HWAddr, mac)
|
||||
|
||||
l.IP = s.findFreeIP()
|
||||
if l.IP == nil {
|
||||
i := s.findExpiredLease()
|
||||
if i < 0 {
|
||||
return nil
|
||||
}
|
||||
copy(s.leases[i].HWAddr, mac)
|
||||
return s.leases[i]
|
||||
}
|
||||
|
||||
s.addLease(&l)
|
||||
return &l
|
||||
}
|
||||
|
||||
func (s *v4Server) commitLease(l *Lease) {
|
||||
l.Expiry = time.Now().Add(s.conf.leaseTime)
|
||||
|
||||
s.leasesLock.Lock()
|
||||
s.conf.notify(LeaseChangedDBStore)
|
||||
s.leasesLock.Unlock()
|
||||
|
||||
s.conf.notify(LeaseChangedAdded)
|
||||
}
|
||||
|
||||
// Process Discover request and return lease
|
||||
func (s *v4Server) processDiscover(req *dhcpv4.DHCPv4, resp *dhcpv4.DHCPv4) *Lease {
|
||||
mac := req.ClientHWAddr
|
||||
|
||||
s.leasesLock.Lock()
|
||||
defer s.leasesLock.Unlock()
|
||||
|
||||
lease := s.findLease(mac)
|
||||
if lease == nil {
|
||||
toStore := false
|
||||
for lease == nil {
|
||||
lease = s.reserveLease(mac)
|
||||
if lease == nil {
|
||||
log.Debug("DHCPv4: No more IP addresses")
|
||||
if toStore {
|
||||
s.conf.notify(LeaseChangedDBStore)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
toStore = true
|
||||
|
||||
if !s.addrAvailable(lease.IP) {
|
||||
s.blacklistLease(lease)
|
||||
lease = nil
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
s.conf.notify(LeaseChangedDBStore)
|
||||
|
||||
} else {
|
||||
reqIP := req.Options.Get(dhcpv4.OptionRequestedIPAddress)
|
||||
if len(reqIP) != 0 &&
|
||||
!bytes.Equal(reqIP, lease.IP) {
|
||||
log.Debug("DHCPv4: different RequestedIP: %v != %v", reqIP, lease.IP)
|
||||
}
|
||||
}
|
||||
|
||||
resp.UpdateOption(dhcpv4.OptMessageType(dhcpv4.MessageTypeOffer))
|
||||
return lease
|
||||
}
|
||||
|
||||
type optFQDN struct {
|
||||
name string
|
||||
}
|
||||
|
||||
func (o *optFQDN) String() string {
|
||||
return "optFQDN"
|
||||
}
|
||||
|
||||
// flags[1]
|
||||
// A-RR[1]
|
||||
// PTR-RR[1]
|
||||
// name[]
|
||||
func (o *optFQDN) ToBytes() []byte {
|
||||
b := make([]byte, 3+len(o.name))
|
||||
i := 0
|
||||
|
||||
b[i] = 0x03 // f_server_overrides | f_server
|
||||
i++
|
||||
|
||||
b[i] = 255 // A-RR
|
||||
i++
|
||||
|
||||
b[i] = 255 // PTR-RR
|
||||
i++
|
||||
|
||||
copy(b[i:], []byte(o.name))
|
||||
return b
|
||||
|
||||
}
|
||||
|
||||
// Process Request request and return lease
|
||||
// Return false if we don't need to reply
|
||||
func (s *v4Server) processRequest(req *dhcpv4.DHCPv4, resp *dhcpv4.DHCPv4) (*Lease, bool) {
|
||||
var lease *Lease
|
||||
mac := req.ClientHWAddr
|
||||
hostname := req.Options.Get(dhcpv4.OptionHostName)
|
||||
reqIP := req.Options.Get(dhcpv4.OptionRequestedIPAddress)
|
||||
if reqIP == nil {
|
||||
reqIP = req.ClientIPAddr
|
||||
}
|
||||
|
||||
sid := req.Options.Get(dhcpv4.OptionServerIdentifier)
|
||||
if len(sid) != 0 &&
|
||||
!bytes.Equal(sid, s.conf.dnsIPAddrs[0]) {
|
||||
log.Debug("DHCPv4: Bad OptionServerIdentifier in Request message for %s", mac)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if len(reqIP) != 4 {
|
||||
log.Debug("DHCPv4: Bad OptionRequestedIPAddress in Request message for %s", mac)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
s.leasesLock.Lock()
|
||||
for _, l := range s.leases {
|
||||
if bytes.Equal(l.HWAddr, mac) {
|
||||
if !bytes.Equal(l.IP, reqIP) {
|
||||
s.leasesLock.Unlock()
|
||||
log.Debug("DHCPv4: Mismatched OptionRequestedIPAddress in Request message for %s", mac)
|
||||
return nil, true
|
||||
}
|
||||
|
||||
lease = l
|
||||
break
|
||||
}
|
||||
}
|
||||
s.leasesLock.Unlock()
|
||||
|
||||
if lease == nil {
|
||||
log.Debug("DHCPv4: No lease for %s", mac)
|
||||
return nil, true
|
||||
}
|
||||
|
||||
if lease.Expiry.Unix() != leaseExpireStatic {
|
||||
lease.Hostname = string(hostname)
|
||||
s.commitLease(lease)
|
||||
} else if len(lease.Hostname) != 0 {
|
||||
o := &optFQDN{
|
||||
name: lease.Hostname,
|
||||
}
|
||||
fqdn := dhcpv4.Option{
|
||||
Code: dhcpv4.OptionFQDN,
|
||||
Value: o,
|
||||
}
|
||||
resp.UpdateOption(fqdn)
|
||||
}
|
||||
|
||||
resp.UpdateOption(dhcpv4.OptMessageType(dhcpv4.MessageTypeAck))
|
||||
return lease, true
|
||||
}
|
||||
|
||||
// Find a lease associated with MAC and prepare response
|
||||
// Return 1: OK
|
||||
// Return 0: error; reply with Nak
|
||||
// Return -1: error; don't reply
|
||||
func (s *v4Server) process(req *dhcpv4.DHCPv4, resp *dhcpv4.DHCPv4) int {
|
||||
|
||||
var lease *Lease
|
||||
|
||||
resp.UpdateOption(dhcpv4.OptServerIdentifier(s.conf.dnsIPAddrs[0]))
|
||||
|
||||
switch req.MessageType() {
|
||||
|
||||
case dhcpv4.MessageTypeDiscover:
|
||||
lease = s.processDiscover(req, resp)
|
||||
if lease == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
case dhcpv4.MessageTypeRequest:
|
||||
var toReply bool
|
||||
lease, toReply = s.processRequest(req, resp)
|
||||
if lease == nil {
|
||||
if toReply {
|
||||
return 0
|
||||
}
|
||||
return -1 // drop packet
|
||||
}
|
||||
}
|
||||
|
||||
resp.YourIPAddr = make([]byte, 4)
|
||||
copy(resp.YourIPAddr, lease.IP)
|
||||
|
||||
resp.UpdateOption(dhcpv4.OptIPAddressLeaseTime(s.conf.leaseTime))
|
||||
resp.UpdateOption(dhcpv4.OptRouter(s.conf.routerIP))
|
||||
resp.UpdateOption(dhcpv4.OptSubnetMask(s.conf.subnetMask))
|
||||
resp.UpdateOption(dhcpv4.OptDNS(s.conf.dnsIPAddrs...))
|
||||
|
||||
for _, opt := range s.conf.options {
|
||||
resp.Options[opt.code] = opt.val
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
// client(0.0.0.0:68) -> (Request:ClientMAC,Type=Discover,ClientID,ReqIP,HostName) -> server(255.255.255.255:67)
|
||||
// client(255.255.255.255:68) <- (Reply:YourIP,ClientMAC,Type=Offer,ServerID,SubnetMask,LeaseTime) <- server(<IP>:67)
|
||||
// client(0.0.0.0:68) -> (Request:ClientMAC,Type=Request,ClientID,ReqIP||ClientIP,HostName,ServerID,ParamReqList) -> server(255.255.255.255:67)
|
||||
// client(255.255.255.255:68) <- (Reply:YourIP,ClientMAC,Type=ACK,ServerID,SubnetMask,LeaseTime) <- server(<IP>:67)
|
||||
func (s *v4Server) packetHandler(conn net.PacketConn, peer net.Addr, req *dhcpv4.DHCPv4) {
|
||||
log.Debug("DHCPv4: received message: %s", req.Summary())
|
||||
|
||||
switch req.MessageType() {
|
||||
case dhcpv4.MessageTypeDiscover,
|
||||
dhcpv4.MessageTypeRequest:
|
||||
//
|
||||
|
||||
default:
|
||||
log.Debug("DHCPv4: unsupported message type %d", req.MessageType())
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := dhcpv4.NewReplyFromRequest(req)
|
||||
if err != nil {
|
||||
log.Debug("DHCPv4: dhcpv4.New: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.ClientHWAddr) != 6 {
|
||||
log.Debug("DHCPv4: Invalid ClientHWAddr")
|
||||
return
|
||||
}
|
||||
|
||||
r := s.process(req, resp)
|
||||
if r < 0 {
|
||||
return
|
||||
} else if r == 0 {
|
||||
resp.Options.Update(dhcpv4.OptMessageType(dhcpv4.MessageTypeNak))
|
||||
}
|
||||
|
||||
log.Debug("DHCPv4: sending: %s", resp.Summary())
|
||||
|
||||
_, err = conn.WriteTo(resp.ToBytes(), peer)
|
||||
if err != nil {
|
||||
log.Error("DHCPv4: conn.Write to %s failed: %s", peer, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Start - start server
|
||||
func (s *v4Server) Start() error {
|
||||
if !s.conf.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
iface, err := net.InterfaceByName(s.conf.InterfaceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("DHCPv4: Couldn't find interface by name %s: %s", s.conf.InterfaceName, err)
|
||||
}
|
||||
|
||||
log.Debug("DHCPv4: starting...")
|
||||
s.conf.dnsIPAddrs = getIfaceIPv4(*iface)
|
||||
if len(s.conf.dnsIPAddrs) == 0 {
|
||||
log.Debug("DHCPv4: no IPv6 address for interface %s", iface.Name)
|
||||
return nil
|
||||
}
|
||||
|
||||
laddr := &net.UDPAddr{
|
||||
IP: net.ParseIP("0.0.0.0"),
|
||||
Port: dhcpv4.ServerPort,
|
||||
}
|
||||
s.srv, err = server4.NewServer(iface.Name, laddr, s.packetHandler, server4.WithDebugLogger())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info("DHCPv4: listening")
|
||||
|
||||
go func() {
|
||||
err = s.srv.Serve()
|
||||
log.Debug("DHCPv4: srv.Serve: %s", err)
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop - stop server
|
||||
func (s *v4Server) Stop() {
|
||||
if s.srv == nil {
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug("DHCPv4: stopping")
|
||||
err := s.srv.Close()
|
||||
if err != nil {
|
||||
log.Error("DHCPv4: srv.Close: %s", err)
|
||||
}
|
||||
// now s.srv.Serve() will return
|
||||
s.srv = nil
|
||||
}
|
||||
|
||||
// Create DHCPv4 server
|
||||
func v4Create(conf V4ServerConf) (DHCPServer, error) {
|
||||
s := &v4Server{}
|
||||
s.conf = conf
|
||||
|
||||
if !conf.Enabled {
|
||||
return s, nil
|
||||
}
|
||||
|
||||
var err error
|
||||
s.conf.routerIP, err = parseIPv4(s.conf.GatewayIP)
|
||||
if err != nil {
|
||||
return s, fmt.Errorf("DHCPv4: %s", err)
|
||||
}
|
||||
|
||||
subnet, err := parseIPv4(s.conf.SubnetMask)
|
||||
if err != nil || !isValidSubnetMask(subnet) {
|
||||
return s, fmt.Errorf("DHCPv4: invalid subnet mask: %s", s.conf.SubnetMask)
|
||||
}
|
||||
s.conf.subnetMask = make([]byte, 4)
|
||||
copy(s.conf.subnetMask, subnet)
|
||||
|
||||
s.conf.ipStart, err = parseIPv4(conf.RangeStart)
|
||||
if s.conf.ipStart == nil {
|
||||
return s, fmt.Errorf("DHCPv4: %s", err)
|
||||
}
|
||||
if s.conf.ipStart[0] == 0 {
|
||||
return s, fmt.Errorf("DHCPv4: invalid range start IP")
|
||||
}
|
||||
|
||||
s.conf.ipEnd, err = parseIPv4(conf.RangeEnd)
|
||||
if s.conf.ipEnd == nil {
|
||||
return s, fmt.Errorf("DHCPv4: %s", err)
|
||||
}
|
||||
if !net.IP.Equal(s.conf.ipStart[:3], s.conf.ipEnd[:3]) ||
|
||||
s.conf.ipStart[3] > s.conf.ipEnd[3] {
|
||||
return s, fmt.Errorf("DHCPv4: range end IP should match range start IP")
|
||||
}
|
||||
|
||||
if conf.LeaseDuration == 0 {
|
||||
s.conf.leaseTime = time.Hour * 24
|
||||
s.conf.LeaseDuration = uint32(s.conf.leaseTime.Seconds())
|
||||
} else {
|
||||
s.conf.leaseTime = time.Second * time.Duration(conf.LeaseDuration)
|
||||
}
|
||||
|
||||
for _, o := range conf.Options {
|
||||
code, val := parseOptionString(o)
|
||||
if code == 0 {
|
||||
log.Debug("DHCPv4: bad option string: %s", o)
|
||||
continue
|
||||
}
|
||||
|
||||
opt := dhcpOption{
|
||||
code: code,
|
||||
val: val,
|
||||
}
|
||||
s.conf.options = append(s.conf.options, opt)
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
47
internal/dhcpd/v46_windows.go
Normal file
47
internal/dhcpd/v46_windows.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package dhcpd
|
||||
|
||||
// 'u-root/u-root' package, a dependency of 'insomniacslk/dhcp' package, doesn't build on Windows
|
||||
|
||||
import "net"
|
||||
|
||||
type winServer struct {
|
||||
}
|
||||
|
||||
func (s *winServer) ResetLeases(leases []*Lease) {
|
||||
}
|
||||
func (s *winServer) GetLeases(flags int) []Lease {
|
||||
return nil
|
||||
}
|
||||
func (s *winServer) GetLeasesRef() []*Lease {
|
||||
return nil
|
||||
}
|
||||
func (s *winServer) AddStaticLease(lease Lease) error {
|
||||
return nil
|
||||
}
|
||||
func (s *winServer) RemoveStaticLease(l Lease) error {
|
||||
return nil
|
||||
}
|
||||
func (s *winServer) FindMACbyIP(ip net.IP) net.HardwareAddr {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *winServer) WriteDiskConfig4(c *V4ServerConf) {
|
||||
}
|
||||
func (s *winServer) WriteDiskConfig6(c *V6ServerConf) {
|
||||
}
|
||||
|
||||
func (s *winServer) Start() error {
|
||||
return nil
|
||||
}
|
||||
func (s *winServer) Stop() {
|
||||
}
|
||||
func (s *winServer) Reset() {
|
||||
}
|
||||
|
||||
func v4Create(conf V4ServerConf) (DHCPServer, error) {
|
||||
return &winServer{}, nil
|
||||
}
|
||||
|
||||
func v6Create(conf V6ServerConf) (DHCPServer, error) {
|
||||
return &winServer{}, nil
|
||||
}
|
||||
238
internal/dhcpd/v4_test.go
Normal file
238
internal/dhcpd/v4_test.go
Normal file
@@ -0,0 +1,238 @@
|
||||
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
|
||||
|
||||
package dhcpd
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/insomniacslk/dhcp/dhcpv4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func notify4(flags uint32) {
|
||||
}
|
||||
|
||||
func TestV4StaticLeaseAddRemove(t *testing.T) {
|
||||
conf := V4ServerConf{
|
||||
Enabled: true,
|
||||
RangeStart: "192.168.10.100",
|
||||
RangeEnd: "192.168.10.200",
|
||||
GatewayIP: "192.168.10.1",
|
||||
SubnetMask: "255.255.255.0",
|
||||
notify: notify4,
|
||||
}
|
||||
s, err := v4Create(conf)
|
||||
assert.True(t, err == nil)
|
||||
|
||||
ls := s.GetLeases(LeasesStatic)
|
||||
assert.Equal(t, 0, len(ls))
|
||||
|
||||
// add static lease
|
||||
l := Lease{}
|
||||
l.IP = net.ParseIP("192.168.10.150").To4()
|
||||
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
|
||||
assert.True(t, s.AddStaticLease(l) == nil)
|
||||
|
||||
// try to add the same static lease - fail
|
||||
assert.True(t, s.AddStaticLease(l) != nil)
|
||||
|
||||
// check
|
||||
ls = s.GetLeases(LeasesStatic)
|
||||
assert.Equal(t, 1, len(ls))
|
||||
assert.Equal(t, "192.168.10.150", ls[0].IP.String())
|
||||
assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
|
||||
assert.True(t, ls[0].Expiry.Unix() == leaseExpireStatic)
|
||||
|
||||
// try to remove static lease - fail
|
||||
l.IP = net.ParseIP("192.168.10.110").To4()
|
||||
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
|
||||
assert.True(t, s.RemoveStaticLease(l) != nil)
|
||||
|
||||
// remove static lease
|
||||
l.IP = net.ParseIP("192.168.10.150").To4()
|
||||
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
|
||||
assert.True(t, s.RemoveStaticLease(l) == nil)
|
||||
|
||||
// check
|
||||
ls = s.GetLeases(LeasesStatic)
|
||||
assert.Equal(t, 0, len(ls))
|
||||
}
|
||||
|
||||
func TestV4StaticLeaseAddReplaceDynamic(t *testing.T) {
|
||||
conf := V4ServerConf{
|
||||
Enabled: true,
|
||||
RangeStart: "192.168.10.100",
|
||||
RangeEnd: "192.168.10.200",
|
||||
GatewayIP: "192.168.10.1",
|
||||
SubnetMask: "255.255.255.0",
|
||||
notify: notify4,
|
||||
}
|
||||
sIface, err := v4Create(conf)
|
||||
s := sIface.(*v4Server)
|
||||
assert.True(t, err == nil)
|
||||
|
||||
// add dynamic lease
|
||||
ld := Lease{}
|
||||
ld.IP = net.ParseIP("192.168.10.150").To4()
|
||||
ld.HWAddr, _ = net.ParseMAC("11:aa:aa:aa:aa:aa")
|
||||
s.addLease(&ld)
|
||||
|
||||
// add dynamic lease
|
||||
{
|
||||
ld := Lease{}
|
||||
ld.IP = net.ParseIP("192.168.10.151").To4()
|
||||
ld.HWAddr, _ = net.ParseMAC("22:aa:aa:aa:aa:aa")
|
||||
s.addLease(&ld)
|
||||
}
|
||||
|
||||
// add static lease with the same IP
|
||||
l := Lease{}
|
||||
l.IP = net.ParseIP("192.168.10.150").To4()
|
||||
l.HWAddr, _ = net.ParseMAC("33:aa:aa:aa:aa:aa")
|
||||
assert.True(t, s.AddStaticLease(l) == nil)
|
||||
|
||||
// add static lease with the same MAC
|
||||
l = Lease{}
|
||||
l.IP = net.ParseIP("192.168.10.152").To4()
|
||||
l.HWAddr, _ = net.ParseMAC("22:aa:aa:aa:aa:aa")
|
||||
assert.True(t, s.AddStaticLease(l) == nil)
|
||||
|
||||
// check
|
||||
ls := s.GetLeases(LeasesStatic)
|
||||
assert.Equal(t, 2, len(ls))
|
||||
|
||||
assert.Equal(t, "192.168.10.150", ls[0].IP.String())
|
||||
assert.Equal(t, "33:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
|
||||
assert.True(t, ls[0].Expiry.Unix() == leaseExpireStatic)
|
||||
|
||||
assert.Equal(t, "192.168.10.152", ls[1].IP.String())
|
||||
assert.Equal(t, "22:aa:aa:aa:aa:aa", ls[1].HWAddr.String())
|
||||
assert.True(t, ls[1].Expiry.Unix() == leaseExpireStatic)
|
||||
}
|
||||
|
||||
func TestV4StaticLeaseGet(t *testing.T) {
|
||||
conf := V4ServerConf{
|
||||
Enabled: true,
|
||||
RangeStart: "192.168.10.100",
|
||||
RangeEnd: "192.168.10.200",
|
||||
GatewayIP: "192.168.10.1",
|
||||
SubnetMask: "255.255.255.0",
|
||||
notify: notify4,
|
||||
}
|
||||
sIface, err := v4Create(conf)
|
||||
s := sIface.(*v4Server)
|
||||
assert.True(t, err == nil)
|
||||
s.conf.dnsIPAddrs = []net.IP{net.ParseIP("192.168.10.1").To4()}
|
||||
|
||||
l := Lease{}
|
||||
l.IP = net.ParseIP("192.168.10.150").To4()
|
||||
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
|
||||
assert.True(t, s.AddStaticLease(l) == nil)
|
||||
|
||||
// "Discover"
|
||||
mac, _ := net.ParseMAC("aa:aa:aa:aa:aa:aa")
|
||||
req, _ := dhcpv4.NewDiscovery(mac)
|
||||
resp, _ := dhcpv4.NewReplyFromRequest(req)
|
||||
assert.Equal(t, 1, s.process(req, resp))
|
||||
|
||||
// check "Offer"
|
||||
assert.Equal(t, dhcpv4.MessageTypeOffer, resp.MessageType())
|
||||
assert.Equal(t, "aa:aa:aa:aa:aa:aa", resp.ClientHWAddr.String())
|
||||
assert.Equal(t, "192.168.10.150", resp.YourIPAddr.String())
|
||||
assert.Equal(t, "192.168.10.1", resp.Router()[0].String())
|
||||
assert.Equal(t, "192.168.10.1", resp.ServerIdentifier().String())
|
||||
assert.Equal(t, "255.255.255.0", net.IP(resp.SubnetMask()).String())
|
||||
assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds())
|
||||
|
||||
// "Request"
|
||||
req, _ = dhcpv4.NewRequestFromOffer(resp)
|
||||
resp, _ = dhcpv4.NewReplyFromRequest(req)
|
||||
assert.Equal(t, 1, s.process(req, resp))
|
||||
|
||||
// check "Ack"
|
||||
assert.Equal(t, dhcpv4.MessageTypeAck, resp.MessageType())
|
||||
assert.Equal(t, "aa:aa:aa:aa:aa:aa", resp.ClientHWAddr.String())
|
||||
assert.Equal(t, "192.168.10.150", resp.YourIPAddr.String())
|
||||
assert.Equal(t, "192.168.10.1", resp.Router()[0].String())
|
||||
assert.Equal(t, "192.168.10.1", resp.ServerIdentifier().String())
|
||||
assert.Equal(t, "255.255.255.0", net.IP(resp.SubnetMask()).String())
|
||||
assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds())
|
||||
|
||||
dnsAddrs := resp.DNS()
|
||||
assert.Equal(t, 1, len(dnsAddrs))
|
||||
assert.Equal(t, "192.168.10.1", dnsAddrs[0].String())
|
||||
|
||||
// check lease
|
||||
ls := s.GetLeases(LeasesStatic)
|
||||
assert.Equal(t, 1, len(ls))
|
||||
assert.Equal(t, "192.168.10.150", ls[0].IP.String())
|
||||
assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
|
||||
}
|
||||
|
||||
func TestV4DynamicLeaseGet(t *testing.T) {
|
||||
conf := V4ServerConf{
|
||||
Enabled: true,
|
||||
RangeStart: "192.168.10.100",
|
||||
RangeEnd: "192.168.10.200",
|
||||
GatewayIP: "192.168.10.1",
|
||||
SubnetMask: "255.255.255.0",
|
||||
notify: notify4,
|
||||
Options: []string{
|
||||
"81 hex 303132",
|
||||
"82 ip 1.2.3.4",
|
||||
},
|
||||
}
|
||||
sIface, err := v4Create(conf)
|
||||
s := sIface.(*v4Server)
|
||||
assert.True(t, err == nil)
|
||||
s.conf.dnsIPAddrs = []net.IP{net.ParseIP("192.168.10.1").To4()}
|
||||
|
||||
// "Discover"
|
||||
mac, _ := net.ParseMAC("aa:aa:aa:aa:aa:aa")
|
||||
req, _ := dhcpv4.NewDiscovery(mac)
|
||||
resp, _ := dhcpv4.NewReplyFromRequest(req)
|
||||
assert.Equal(t, 1, s.process(req, resp))
|
||||
|
||||
// check "Offer"
|
||||
assert.Equal(t, dhcpv4.MessageTypeOffer, resp.MessageType())
|
||||
assert.Equal(t, "aa:aa:aa:aa:aa:aa", resp.ClientHWAddr.String())
|
||||
assert.Equal(t, "192.168.10.100", resp.YourIPAddr.String())
|
||||
assert.Equal(t, "192.168.10.1", resp.Router()[0].String())
|
||||
assert.Equal(t, "192.168.10.1", resp.ServerIdentifier().String())
|
||||
assert.Equal(t, "255.255.255.0", net.IP(resp.SubnetMask()).String())
|
||||
assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds())
|
||||
assert.Equal(t, []byte("012"), resp.Options[uint8(dhcpv4.OptionFQDN)])
|
||||
assert.Equal(t, "1.2.3.4", net.IP(resp.Options[uint8(dhcpv4.OptionRelayAgentInformation)]).String())
|
||||
|
||||
// "Request"
|
||||
req, _ = dhcpv4.NewRequestFromOffer(resp)
|
||||
resp, _ = dhcpv4.NewReplyFromRequest(req)
|
||||
assert.Equal(t, 1, s.process(req, resp))
|
||||
|
||||
// check "Ack"
|
||||
assert.Equal(t, dhcpv4.MessageTypeAck, resp.MessageType())
|
||||
assert.Equal(t, "aa:aa:aa:aa:aa:aa", resp.ClientHWAddr.String())
|
||||
assert.Equal(t, "192.168.10.100", resp.YourIPAddr.String())
|
||||
assert.Equal(t, "192.168.10.1", resp.Router()[0].String())
|
||||
assert.Equal(t, "192.168.10.1", resp.ServerIdentifier().String())
|
||||
assert.Equal(t, "255.255.255.0", net.IP(resp.SubnetMask()).String())
|
||||
assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds())
|
||||
|
||||
dnsAddrs := resp.DNS()
|
||||
assert.Equal(t, 1, len(dnsAddrs))
|
||||
assert.Equal(t, "192.168.10.1", dnsAddrs[0].String())
|
||||
|
||||
// check lease
|
||||
ls := s.GetLeases(LeasesDynamic)
|
||||
assert.Equal(t, 1, len(ls))
|
||||
assert.Equal(t, "192.168.10.100", ls[0].IP.String())
|
||||
assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
|
||||
|
||||
start := net.ParseIP("192.168.10.100").To4()
|
||||
stop := net.ParseIP("192.168.10.200").To4()
|
||||
assert.True(t, !ip4InRange(start, stop, net.ParseIP("192.168.10.99").To4()))
|
||||
assert.True(t, !ip4InRange(start, stop, net.ParseIP("192.168.11.100").To4()))
|
||||
assert.True(t, !ip4InRange(start, stop, net.ParseIP("192.168.11.201").To4()))
|
||||
assert.True(t, ip4InRange(start, stop, net.ParseIP("192.168.10.100").To4()))
|
||||
}
|
||||
680
internal/dhcpd/v6.go
Normal file
680
internal/dhcpd/v6.go
Normal file
@@ -0,0 +1,680 @@
|
||||
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
|
||||
|
||||
package dhcpd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/insomniacslk/dhcp/dhcpv6"
|
||||
"github.com/insomniacslk/dhcp/dhcpv6/server6"
|
||||
"github.com/insomniacslk/dhcp/iana"
|
||||
)
|
||||
|
||||
const valueIAID = "ADGH" // value for IANA.ID
|
||||
|
||||
// v6Server - DHCPv6 server
|
||||
type v6Server struct {
|
||||
srv *server6.Server
|
||||
leasesLock sync.Mutex
|
||||
leases []*Lease
|
||||
ipAddrs [256]byte
|
||||
sid dhcpv6.Duid
|
||||
|
||||
ra raCtx // RA module
|
||||
|
||||
conf V6ServerConf
|
||||
}
|
||||
|
||||
// WriteDiskConfig4 - write configuration
|
||||
func (s *v6Server) WriteDiskConfig4(c *V4ServerConf) {
|
||||
}
|
||||
|
||||
// WriteDiskConfig6 - write configuration
|
||||
func (s *v6Server) WriteDiskConfig6(c *V6ServerConf) {
|
||||
*c = s.conf
|
||||
}
|
||||
|
||||
// Return TRUE if IP address is within range [start..0xff]
|
||||
// nolint(staticcheck)
|
||||
func ip6InRange(start net.IP, ip net.IP) bool {
|
||||
if len(start) != 16 {
|
||||
return false
|
||||
}
|
||||
if !bytes.Equal(start[:15], ip[:15]) {
|
||||
return false
|
||||
}
|
||||
return start[15] <= ip[15]
|
||||
}
|
||||
|
||||
// ResetLeases - reset leases
|
||||
func (s *v6Server) ResetLeases(ll []*Lease) {
|
||||
s.leases = nil
|
||||
for _, l := range ll {
|
||||
|
||||
if l.Expiry.Unix() != leaseExpireStatic &&
|
||||
!ip6InRange(s.conf.ipStart, l.IP) {
|
||||
|
||||
log.Debug("DHCPv6: skipping a lease with IP %v: not within current IP range", l.IP)
|
||||
continue
|
||||
}
|
||||
|
||||
s.addLease(l)
|
||||
}
|
||||
}
|
||||
|
||||
// GetLeases - get current leases
|
||||
func (s *v6Server) GetLeases(flags int) []Lease {
|
||||
var result []Lease
|
||||
s.leasesLock.Lock()
|
||||
for _, lease := range s.leases {
|
||||
|
||||
if lease.Expiry.Unix() == leaseExpireStatic {
|
||||
if (flags & LeasesStatic) != 0 {
|
||||
result = append(result, *lease)
|
||||
}
|
||||
|
||||
} else {
|
||||
if (flags & LeasesDynamic) != 0 {
|
||||
result = append(result, *lease)
|
||||
}
|
||||
}
|
||||
}
|
||||
s.leasesLock.Unlock()
|
||||
return result
|
||||
}
|
||||
|
||||
// GetLeasesRef - get leases
|
||||
func (s *v6Server) GetLeasesRef() []*Lease {
|
||||
return s.leases
|
||||
}
|
||||
|
||||
// FindMACbyIP - find a MAC address by IP address in the currently active DHCP leases
|
||||
func (s *v6Server) FindMACbyIP(ip net.IP) net.HardwareAddr {
|
||||
now := time.Now().Unix()
|
||||
|
||||
s.leasesLock.Lock()
|
||||
defer s.leasesLock.Unlock()
|
||||
|
||||
for _, l := range s.leases {
|
||||
if l.IP.Equal(ip) {
|
||||
unix := l.Expiry.Unix()
|
||||
if unix > now || unix == leaseExpireStatic {
|
||||
return l.HWAddr
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Remove (swap) lease by index
|
||||
func (s *v6Server) leaseRemoveSwapByIndex(i int) {
|
||||
s.ipAddrs[s.leases[i].IP[15]] = 0
|
||||
log.Debug("DHCPv6: removed lease %s", s.leases[i].HWAddr)
|
||||
|
||||
n := len(s.leases)
|
||||
if i != n-1 {
|
||||
s.leases[i] = s.leases[n-1] // swap with the last element
|
||||
}
|
||||
s.leases = s.leases[:n-1]
|
||||
}
|
||||
|
||||
// Remove a dynamic lease with the same properties
|
||||
// Return error if a static lease is found
|
||||
func (s *v6Server) rmDynamicLease(lease Lease) error {
|
||||
for i := 0; i < len(s.leases); i++ {
|
||||
l := s.leases[i]
|
||||
|
||||
if bytes.Equal(l.HWAddr, lease.HWAddr) {
|
||||
|
||||
if l.Expiry.Unix() == leaseExpireStatic {
|
||||
return fmt.Errorf("static lease already exists")
|
||||
}
|
||||
|
||||
s.leaseRemoveSwapByIndex(i)
|
||||
if i == len(s.leases) {
|
||||
break
|
||||
}
|
||||
l = s.leases[i]
|
||||
}
|
||||
|
||||
if net.IP.Equal(l.IP, lease.IP) {
|
||||
|
||||
if l.Expiry.Unix() == leaseExpireStatic {
|
||||
return fmt.Errorf("static lease already exists")
|
||||
}
|
||||
|
||||
s.leaseRemoveSwapByIndex(i)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddStaticLease - add a static lease
|
||||
func (s *v6Server) AddStaticLease(l Lease) error {
|
||||
if len(l.IP) != 16 {
|
||||
return fmt.Errorf("invalid IP")
|
||||
}
|
||||
if len(l.HWAddr) != 6 {
|
||||
return fmt.Errorf("invalid MAC")
|
||||
}
|
||||
|
||||
l.Expiry = time.Unix(leaseExpireStatic, 0)
|
||||
|
||||
s.leasesLock.Lock()
|
||||
err := s.rmDynamicLease(l)
|
||||
if err != nil {
|
||||
s.leasesLock.Unlock()
|
||||
return err
|
||||
}
|
||||
s.addLease(&l)
|
||||
s.conf.notify(LeaseChangedDBStore)
|
||||
s.leasesLock.Unlock()
|
||||
|
||||
s.conf.notify(LeaseChangedAddedStatic)
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveStaticLease - remove a static lease
|
||||
func (s *v6Server) RemoveStaticLease(l Lease) error {
|
||||
if len(l.IP) != 16 {
|
||||
return fmt.Errorf("invalid IP")
|
||||
}
|
||||
if len(l.HWAddr) != 6 {
|
||||
return fmt.Errorf("invalid MAC")
|
||||
}
|
||||
|
||||
s.leasesLock.Lock()
|
||||
err := s.rmLease(l)
|
||||
if err != nil {
|
||||
s.leasesLock.Unlock()
|
||||
return err
|
||||
}
|
||||
s.conf.notify(LeaseChangedDBStore)
|
||||
s.leasesLock.Unlock()
|
||||
s.conf.notify(LeaseChangedRemovedStatic)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Add a lease
|
||||
func (s *v6Server) addLease(l *Lease) {
|
||||
s.leases = append(s.leases, l)
|
||||
s.ipAddrs[l.IP[15]] = 1
|
||||
log.Debug("DHCPv6: added lease %s <-> %s", l.IP, l.HWAddr)
|
||||
}
|
||||
|
||||
// Remove a lease with the same properties
|
||||
func (s *v6Server) rmLease(lease Lease) error {
|
||||
for i, l := range s.leases {
|
||||
if net.IP.Equal(l.IP, lease.IP) {
|
||||
|
||||
if !bytes.Equal(l.HWAddr, lease.HWAddr) ||
|
||||
l.Hostname != lease.Hostname {
|
||||
|
||||
return fmt.Errorf("Lease not found")
|
||||
}
|
||||
|
||||
s.leaseRemoveSwapByIndex(i)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("lease not found")
|
||||
}
|
||||
|
||||
// Find lease by MAC
|
||||
func (s *v6Server) findLease(mac net.HardwareAddr) *Lease {
|
||||
s.leasesLock.Lock()
|
||||
defer s.leasesLock.Unlock()
|
||||
|
||||
for i := range s.leases {
|
||||
if bytes.Equal(mac, s.leases[i].HWAddr) {
|
||||
return s.leases[i]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Find an expired lease and return its index or -1
|
||||
func (s *v6Server) findExpiredLease() int {
|
||||
now := time.Now().Unix()
|
||||
for i, lease := range s.leases {
|
||||
if lease.Expiry.Unix() != leaseExpireStatic &&
|
||||
lease.Expiry.Unix() <= now {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// Get next free IP
|
||||
func (s *v6Server) findFreeIP() net.IP {
|
||||
for i := s.conf.ipStart[15]; ; i++ {
|
||||
if s.ipAddrs[i] == 0 {
|
||||
ip := make([]byte, 16)
|
||||
copy(ip, s.conf.ipStart)
|
||||
ip[15] = i
|
||||
return ip
|
||||
}
|
||||
if i == 0xff {
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reserve lease for MAC
|
||||
func (s *v6Server) reserveLease(mac net.HardwareAddr) *Lease {
|
||||
l := Lease{}
|
||||
l.HWAddr = make([]byte, 6)
|
||||
copy(l.HWAddr, mac)
|
||||
|
||||
s.leasesLock.Lock()
|
||||
defer s.leasesLock.Unlock()
|
||||
|
||||
copy(l.IP, s.conf.ipStart)
|
||||
l.IP = s.findFreeIP()
|
||||
if l.IP == nil {
|
||||
i := s.findExpiredLease()
|
||||
if i < 0 {
|
||||
return nil
|
||||
}
|
||||
copy(s.leases[i].HWAddr, mac)
|
||||
return s.leases[i]
|
||||
}
|
||||
|
||||
s.addLease(&l)
|
||||
return &l
|
||||
}
|
||||
|
||||
func (s *v6Server) commitDynamicLease(l *Lease) {
|
||||
l.Expiry = time.Now().Add(s.conf.leaseTime)
|
||||
|
||||
s.leasesLock.Lock()
|
||||
s.conf.notify(LeaseChangedDBStore)
|
||||
s.leasesLock.Unlock()
|
||||
s.conf.notify(LeaseChangedAdded)
|
||||
}
|
||||
|
||||
// Check Client ID
|
||||
func (s *v6Server) checkCID(msg *dhcpv6.Message) error {
|
||||
if msg.Options.ClientID() == nil {
|
||||
return fmt.Errorf("DHCPv6: no ClientID option in request")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check ServerID policy
|
||||
func (s *v6Server) checkSID(msg *dhcpv6.Message) error {
|
||||
sid := msg.Options.ServerID()
|
||||
|
||||
switch msg.Type() {
|
||||
case dhcpv6.MessageTypeSolicit,
|
||||
dhcpv6.MessageTypeConfirm,
|
||||
dhcpv6.MessageTypeRebind:
|
||||
|
||||
if sid != nil {
|
||||
return fmt.Errorf("DHCPv6: drop packet: ServerID option in message %s", msg.Type().String())
|
||||
}
|
||||
|
||||
case dhcpv6.MessageTypeRequest,
|
||||
dhcpv6.MessageTypeRenew,
|
||||
dhcpv6.MessageTypeRelease,
|
||||
dhcpv6.MessageTypeDecline:
|
||||
|
||||
if sid == nil {
|
||||
return fmt.Errorf("DHCPv6: drop packet: no ServerID option in message %s", msg.Type().String())
|
||||
}
|
||||
if !sid.Equal(s.sid) {
|
||||
return fmt.Errorf("DHCPv6: drop packet: mismatched ServerID option in message %s: %s",
|
||||
msg.Type().String(), sid.String())
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// . IAAddress must be equal to the lease's IP
|
||||
func (s *v6Server) checkIA(msg *dhcpv6.Message, lease *Lease) error {
|
||||
switch msg.Type() {
|
||||
case dhcpv6.MessageTypeRequest,
|
||||
dhcpv6.MessageTypeConfirm,
|
||||
dhcpv6.MessageTypeRenew,
|
||||
dhcpv6.MessageTypeRebind:
|
||||
|
||||
oia := msg.Options.OneIANA()
|
||||
if oia == nil {
|
||||
return fmt.Errorf("no IANA option in %s", msg.Type().String())
|
||||
}
|
||||
|
||||
oiaAddr := oia.Options.OneAddress()
|
||||
if oiaAddr == nil {
|
||||
return fmt.Errorf("no IANA.Addr option in %s", msg.Type().String())
|
||||
}
|
||||
|
||||
if !oiaAddr.IPv6Addr.Equal(lease.IP) {
|
||||
return fmt.Errorf("invalid IANA.Addr option in %s", msg.Type().String())
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Store lease in DB (if necessary) and return lease life time
|
||||
func (s *v6Server) commitLease(msg *dhcpv6.Message, lease *Lease) time.Duration {
|
||||
lifetime := s.conf.leaseTime
|
||||
|
||||
switch msg.Type() {
|
||||
case dhcpv6.MessageTypeSolicit:
|
||||
//
|
||||
|
||||
case dhcpv6.MessageTypeConfirm:
|
||||
lifetime = lease.Expiry.Sub(time.Now())
|
||||
|
||||
case dhcpv6.MessageTypeRequest,
|
||||
dhcpv6.MessageTypeRenew,
|
||||
dhcpv6.MessageTypeRebind:
|
||||
|
||||
if lease.Expiry.Unix() != leaseExpireStatic {
|
||||
s.commitDynamicLease(lease)
|
||||
}
|
||||
}
|
||||
return lifetime
|
||||
}
|
||||
|
||||
// Find a lease associated with MAC and prepare response
|
||||
func (s *v6Server) process(msg *dhcpv6.Message, req dhcpv6.DHCPv6, resp dhcpv6.DHCPv6) bool {
|
||||
switch msg.Type() {
|
||||
case dhcpv6.MessageTypeSolicit,
|
||||
dhcpv6.MessageTypeRequest,
|
||||
dhcpv6.MessageTypeConfirm,
|
||||
dhcpv6.MessageTypeRenew,
|
||||
dhcpv6.MessageTypeRebind:
|
||||
// continue
|
||||
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
mac, err := dhcpv6.ExtractMAC(req)
|
||||
if err != nil {
|
||||
log.Debug("DHCPv6: dhcpv6.ExtractMAC: %s", err)
|
||||
return false
|
||||
}
|
||||
|
||||
lease := s.findLease(mac)
|
||||
if lease == nil {
|
||||
log.Debug("DHCPv6: no lease for: %s", mac)
|
||||
|
||||
switch msg.Type() {
|
||||
|
||||
case dhcpv6.MessageTypeSolicit:
|
||||
lease = s.reserveLease(mac)
|
||||
if lease == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
err = s.checkIA(msg, lease)
|
||||
if err != nil {
|
||||
log.Debug("DHCPv6: %s", err)
|
||||
return false
|
||||
}
|
||||
|
||||
lifetime := s.commitLease(msg, lease)
|
||||
|
||||
oia := &dhcpv6.OptIANA{
|
||||
T1: lifetime / 2,
|
||||
T2: time.Duration(float32(lifetime) / 1.5),
|
||||
}
|
||||
roia := msg.Options.OneIANA()
|
||||
if roia != nil {
|
||||
copy(oia.IaId[:], roia.IaId[:])
|
||||
} else {
|
||||
copy(oia.IaId[:], []byte(valueIAID))
|
||||
}
|
||||
oiaAddr := &dhcpv6.OptIAAddress{
|
||||
IPv6Addr: lease.IP,
|
||||
PreferredLifetime: lifetime,
|
||||
ValidLifetime: lifetime,
|
||||
}
|
||||
oia.Options = dhcpv6.IdentityOptions{
|
||||
Options: []dhcpv6.Option{oiaAddr},
|
||||
}
|
||||
resp.AddOption(oia)
|
||||
|
||||
if msg.IsOptionRequested(dhcpv6.OptionDNSRecursiveNameServer) {
|
||||
resp.UpdateOption(dhcpv6.OptDNS(s.conf.dnsIPAddrs...))
|
||||
}
|
||||
|
||||
fqdn := msg.GetOneOption(dhcpv6.OptionFQDN)
|
||||
if fqdn != nil {
|
||||
resp.AddOption(fqdn)
|
||||
}
|
||||
|
||||
resp.AddOption(&dhcpv6.OptStatusCode{
|
||||
StatusCode: iana.StatusSuccess,
|
||||
StatusMessage: "success",
|
||||
})
|
||||
return true
|
||||
}
|
||||
|
||||
// 1.
|
||||
// fe80::* (client) --(Solicit + ClientID+IANA())-> ff02::1:2
|
||||
// server -(Advertise + ClientID+ServerID+IANA(IAAddress)> fe80::*
|
||||
// fe80::* --(Request + ClientID+ServerID+IANA(IAAddress))-> ff02::1:2
|
||||
// server -(Reply + ClientID+ServerID+IANA(IAAddress)+DNS)> fe80::*
|
||||
//
|
||||
// 2.
|
||||
// fe80::* --(Confirm|Renew|Rebind + ClientID+IANA(IAAddress))-> ff02::1:2
|
||||
// server -(Reply + ClientID+ServerID+IANA(IAAddress)+DNS)> fe80::*
|
||||
//
|
||||
// 3.
|
||||
// fe80::* --(Release + ClientID+ServerID+IANA(IAAddress))-> ff02::1:2
|
||||
func (s *v6Server) packetHandler(conn net.PacketConn, peer net.Addr, req dhcpv6.DHCPv6) {
|
||||
msg, err := req.GetInnerMessage()
|
||||
if err != nil {
|
||||
log.Error("DHCPv6: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug("DHCPv6: received: %s", req.Summary())
|
||||
|
||||
err = s.checkCID(msg)
|
||||
if err != nil {
|
||||
log.Debug("%s", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = s.checkSID(msg)
|
||||
if err != nil {
|
||||
log.Debug("%s", err)
|
||||
return
|
||||
}
|
||||
|
||||
var resp dhcpv6.DHCPv6
|
||||
|
||||
switch msg.Type() {
|
||||
case dhcpv6.MessageTypeSolicit:
|
||||
if msg.GetOneOption(dhcpv6.OptionRapidCommit) == nil {
|
||||
resp, err = dhcpv6.NewAdvertiseFromSolicit(msg)
|
||||
break
|
||||
}
|
||||
fallthrough
|
||||
|
||||
case dhcpv6.MessageTypeRequest,
|
||||
dhcpv6.MessageTypeConfirm,
|
||||
dhcpv6.MessageTypeRenew,
|
||||
dhcpv6.MessageTypeRebind,
|
||||
dhcpv6.MessageTypeRelease,
|
||||
dhcpv6.MessageTypeInformationRequest:
|
||||
resp, err = dhcpv6.NewReplyFromMessage(msg)
|
||||
|
||||
default:
|
||||
log.Error("DHCPv6: message type %d not supported", msg.Type())
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Error("DHCPv6: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
resp.AddOption(dhcpv6.OptServerID(s.sid))
|
||||
|
||||
_ = s.process(msg, req, resp)
|
||||
|
||||
log.Debug("DHCPv6: sending: %s", resp.Summary())
|
||||
|
||||
_, err = conn.WriteTo(resp.ToBytes(), peer)
|
||||
if err != nil {
|
||||
log.Error("DHCPv6: conn.Write to %s failed: %s", peer, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Get IPv6 address list
|
||||
func getIfaceIPv6(iface net.Interface) []net.IP {
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var res []net.IP
|
||||
for _, a := range addrs {
|
||||
ipnet, ok := a.(*net.IPNet)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if ipnet.IP.To4() == nil {
|
||||
res = append(res, ipnet.IP)
|
||||
}
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
// initialize RA module
|
||||
func (s *v6Server) initRA(iface *net.Interface) error {
|
||||
// choose the source IP address - should be link-local-unicast
|
||||
s.ra.ipAddr = s.conf.dnsIPAddrs[0]
|
||||
for _, ip := range s.conf.dnsIPAddrs {
|
||||
if ip.IsLinkLocalUnicast() {
|
||||
s.ra.ipAddr = ip
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
s.ra.raAllowSlaac = s.conf.RaAllowSlaac
|
||||
s.ra.raSlaacOnly = s.conf.RaSlaacOnly
|
||||
s.ra.dnsIPAddr = s.ra.ipAddr
|
||||
s.ra.prefixIPAddr = s.conf.ipStart
|
||||
s.ra.ifaceName = s.conf.InterfaceName
|
||||
s.ra.iface = iface
|
||||
s.ra.packetSendPeriod = 1 * time.Second
|
||||
return s.ra.Init()
|
||||
}
|
||||
|
||||
// Start - start server
|
||||
func (s *v6Server) Start() error {
|
||||
if !s.conf.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
iface, err := net.InterfaceByName(s.conf.InterfaceName)
|
||||
if err != nil {
|
||||
return wrapErrPrint(err, "Couldn't find interface by name %s", s.conf.InterfaceName)
|
||||
}
|
||||
|
||||
s.conf.dnsIPAddrs = getIfaceIPv6(*iface)
|
||||
if len(s.conf.dnsIPAddrs) == 0 {
|
||||
log.Debug("DHCPv6: no IPv6 address for interface %s", iface.Name)
|
||||
return nil
|
||||
}
|
||||
|
||||
err = s.initRA(iface)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// don't initialize DHCPv6 server if we must force the clients to use SLAAC
|
||||
if s.conf.RaSlaacOnly {
|
||||
log.Debug("DHCPv6: not starting DHCPv6 server due to ra_slaac_only=true")
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Debug("DHCPv6: starting...")
|
||||
|
||||
if len(iface.HardwareAddr) != 6 {
|
||||
return fmt.Errorf("DHCPv6: invalid MAC %s", iface.HardwareAddr)
|
||||
}
|
||||
s.sid = dhcpv6.Duid{
|
||||
Type: dhcpv6.DUID_LLT,
|
||||
HwType: iana.HWTypeEthernet,
|
||||
LinkLayerAddr: iface.HardwareAddr,
|
||||
Time: dhcpv6.GetTime(),
|
||||
}
|
||||
|
||||
laddr := &net.UDPAddr{
|
||||
IP: net.ParseIP("::"),
|
||||
Port: dhcpv6.DefaultServerPort,
|
||||
}
|
||||
s.srv, err = server6.NewServer(iface.Name, laddr, s.packetHandler, server6.WithDebugLogger())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
go func() {
|
||||
err = s.srv.Serve()
|
||||
log.Debug("DHCPv6: srv.Serve: %s", err)
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop - stop server
|
||||
func (s *v6Server) Stop() {
|
||||
s.ra.Close()
|
||||
|
||||
// DHCPv6 server may not be initialized if ra_slaac_only=true
|
||||
if s.srv == nil {
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug("DHCPv6: stopping")
|
||||
err := s.srv.Close()
|
||||
if err != nil {
|
||||
log.Error("DHCPv6: srv.Close: %s", err)
|
||||
}
|
||||
// now server.Serve() will return
|
||||
s.srv = nil
|
||||
}
|
||||
|
||||
// Create DHCPv6 server
|
||||
func v6Create(conf V6ServerConf) (DHCPServer, error) {
|
||||
s := &v6Server{}
|
||||
s.conf = conf
|
||||
|
||||
if !conf.Enabled {
|
||||
return s, nil
|
||||
}
|
||||
|
||||
s.conf.ipStart = net.ParseIP(conf.RangeStart)
|
||||
if s.conf.ipStart == nil || s.conf.ipStart.To16() == nil {
|
||||
return s, fmt.Errorf("DHCPv6: invalid range-start IP: %s", conf.RangeStart)
|
||||
}
|
||||
|
||||
if conf.LeaseDuration == 0 {
|
||||
s.conf.leaseTime = time.Hour * 24
|
||||
s.conf.LeaseDuration = uint32(s.conf.leaseTime.Seconds())
|
||||
} else {
|
||||
s.conf.leaseTime = time.Second * time.Duration(conf.LeaseDuration)
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
225
internal/dhcpd/v6_test.go
Normal file
225
internal/dhcpd/v6_test.go
Normal file
@@ -0,0 +1,225 @@
|
||||
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
|
||||
|
||||
package dhcpd
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/insomniacslk/dhcp/dhcpv6"
|
||||
"github.com/insomniacslk/dhcp/iana"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func notify6(flags uint32) {
|
||||
}
|
||||
|
||||
func TestV6StaticLeaseAddRemove(t *testing.T) {
|
||||
conf := V6ServerConf{
|
||||
Enabled: true,
|
||||
RangeStart: "2001::1",
|
||||
notify: notify6,
|
||||
}
|
||||
s, err := v6Create(conf)
|
||||
assert.True(t, err == nil)
|
||||
|
||||
ls := s.GetLeases(LeasesStatic)
|
||||
assert.Equal(t, 0, len(ls))
|
||||
|
||||
// add static lease
|
||||
l := Lease{}
|
||||
l.IP = net.ParseIP("2001::1")
|
||||
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
|
||||
assert.True(t, s.AddStaticLease(l) == nil)
|
||||
|
||||
// try to add static lease - fail
|
||||
assert.True(t, s.AddStaticLease(l) != nil)
|
||||
|
||||
// check
|
||||
ls = s.GetLeases(LeasesStatic)
|
||||
assert.Equal(t, 1, len(ls))
|
||||
assert.Equal(t, "2001::1", ls[0].IP.String())
|
||||
assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
|
||||
assert.True(t, ls[0].Expiry.Unix() == leaseExpireStatic)
|
||||
|
||||
// try to remove static lease - fail
|
||||
l.IP = net.ParseIP("2001::2")
|
||||
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
|
||||
assert.True(t, s.RemoveStaticLease(l) != nil)
|
||||
|
||||
// remove static lease
|
||||
l.IP = net.ParseIP("2001::1")
|
||||
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
|
||||
assert.True(t, s.RemoveStaticLease(l) == nil)
|
||||
|
||||
// check
|
||||
ls = s.GetLeases(LeasesStatic)
|
||||
assert.Equal(t, 0, len(ls))
|
||||
}
|
||||
|
||||
func TestV6StaticLeaseAddReplaceDynamic(t *testing.T) {
|
||||
conf := V6ServerConf{
|
||||
Enabled: true,
|
||||
RangeStart: "2001::1",
|
||||
notify: notify6,
|
||||
}
|
||||
sIface, err := v6Create(conf)
|
||||
s := sIface.(*v6Server)
|
||||
assert.True(t, err == nil)
|
||||
|
||||
// add dynamic lease
|
||||
ld := Lease{}
|
||||
ld.IP = net.ParseIP("2001::1")
|
||||
ld.HWAddr, _ = net.ParseMAC("11:aa:aa:aa:aa:aa")
|
||||
s.addLease(&ld)
|
||||
|
||||
// add dynamic lease
|
||||
{
|
||||
ld := Lease{}
|
||||
ld.IP = net.ParseIP("2001::2")
|
||||
ld.HWAddr, _ = net.ParseMAC("22:aa:aa:aa:aa:aa")
|
||||
s.addLease(&ld)
|
||||
}
|
||||
|
||||
// add static lease with the same IP
|
||||
l := Lease{}
|
||||
l.IP = net.ParseIP("2001::1")
|
||||
l.HWAddr, _ = net.ParseMAC("33:aa:aa:aa:aa:aa")
|
||||
assert.True(t, s.AddStaticLease(l) == nil)
|
||||
|
||||
// add static lease with the same MAC
|
||||
l = Lease{}
|
||||
l.IP = net.ParseIP("2001::3")
|
||||
l.HWAddr, _ = net.ParseMAC("22:aa:aa:aa:aa:aa")
|
||||
assert.True(t, s.AddStaticLease(l) == nil)
|
||||
|
||||
// check
|
||||
ls := s.GetLeases(LeasesStatic)
|
||||
assert.Equal(t, 2, len(ls))
|
||||
|
||||
assert.Equal(t, "2001::1", ls[0].IP.String())
|
||||
assert.Equal(t, "33:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
|
||||
assert.True(t, ls[0].Expiry.Unix() == leaseExpireStatic)
|
||||
|
||||
assert.Equal(t, "2001::3", ls[1].IP.String())
|
||||
assert.Equal(t, "22:aa:aa:aa:aa:aa", ls[1].HWAddr.String())
|
||||
assert.True(t, ls[1].Expiry.Unix() == leaseExpireStatic)
|
||||
}
|
||||
|
||||
func TestV6GetLease(t *testing.T) {
|
||||
conf := V6ServerConf{
|
||||
Enabled: true,
|
||||
RangeStart: "2001::1",
|
||||
notify: notify6,
|
||||
}
|
||||
sIface, err := v6Create(conf)
|
||||
s := sIface.(*v6Server)
|
||||
assert.True(t, err == nil)
|
||||
s.conf.dnsIPAddrs = []net.IP{net.ParseIP("2000::1")}
|
||||
s.sid = dhcpv6.Duid{
|
||||
Type: dhcpv6.DUID_LLT,
|
||||
HwType: iana.HWTypeEthernet,
|
||||
}
|
||||
s.sid.LinkLayerAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
|
||||
|
||||
l := Lease{}
|
||||
l.IP = net.ParseIP("2001::1")
|
||||
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
|
||||
assert.True(t, s.AddStaticLease(l) == nil)
|
||||
|
||||
// "Solicit"
|
||||
mac, _ := net.ParseMAC("aa:aa:aa:aa:aa:aa")
|
||||
req, _ := dhcpv6.NewSolicit(mac)
|
||||
msg, _ := req.GetInnerMessage()
|
||||
resp, _ := dhcpv6.NewAdvertiseFromSolicit(msg)
|
||||
assert.True(t, s.process(msg, req, resp))
|
||||
resp.AddOption(dhcpv6.OptServerID(s.sid))
|
||||
|
||||
// check "Advertise"
|
||||
assert.Equal(t, dhcpv6.MessageTypeAdvertise, resp.Type())
|
||||
oia := resp.Options.OneIANA()
|
||||
oiaAddr := oia.Options.OneAddress()
|
||||
assert.Equal(t, "2001::1", oiaAddr.IPv6Addr.String())
|
||||
assert.Equal(t, s.conf.leaseTime.Seconds(), oiaAddr.ValidLifetime.Seconds())
|
||||
|
||||
// "Request"
|
||||
req, _ = dhcpv6.NewRequestFromAdvertise(resp)
|
||||
msg, _ = req.GetInnerMessage()
|
||||
resp, _ = dhcpv6.NewReplyFromMessage(msg)
|
||||
assert.True(t, s.process(msg, req, resp))
|
||||
|
||||
// check "Reply"
|
||||
assert.Equal(t, dhcpv6.MessageTypeReply, resp.Type())
|
||||
oia = resp.Options.OneIANA()
|
||||
oiaAddr = oia.Options.OneAddress()
|
||||
assert.Equal(t, "2001::1", oiaAddr.IPv6Addr.String())
|
||||
assert.Equal(t, s.conf.leaseTime.Seconds(), oiaAddr.ValidLifetime.Seconds())
|
||||
|
||||
dnsAddrs := resp.Options.DNS()
|
||||
assert.Equal(t, 1, len(dnsAddrs))
|
||||
assert.Equal(t, "2000::1", dnsAddrs[0].String())
|
||||
|
||||
// check lease
|
||||
ls := s.GetLeases(LeasesStatic)
|
||||
assert.Equal(t, 1, len(ls))
|
||||
assert.Equal(t, "2001::1", ls[0].IP.String())
|
||||
assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
|
||||
}
|
||||
|
||||
func TestV6GetDynamicLease(t *testing.T) {
|
||||
conf := V6ServerConf{
|
||||
Enabled: true,
|
||||
RangeStart: "2001::2",
|
||||
notify: notify6,
|
||||
}
|
||||
sIface, err := v6Create(conf)
|
||||
s := sIface.(*v6Server)
|
||||
assert.True(t, err == nil)
|
||||
s.conf.dnsIPAddrs = []net.IP{net.ParseIP("2000::1")}
|
||||
s.sid = dhcpv6.Duid{
|
||||
Type: dhcpv6.DUID_LLT,
|
||||
HwType: iana.HWTypeEthernet,
|
||||
}
|
||||
s.sid.LinkLayerAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
|
||||
|
||||
// "Solicit"
|
||||
mac, _ := net.ParseMAC("aa:aa:aa:aa:aa:aa")
|
||||
req, _ := dhcpv6.NewSolicit(mac)
|
||||
msg, _ := req.GetInnerMessage()
|
||||
resp, _ := dhcpv6.NewAdvertiseFromSolicit(msg)
|
||||
assert.True(t, s.process(msg, req, resp))
|
||||
resp.AddOption(dhcpv6.OptServerID(s.sid))
|
||||
|
||||
// check "Advertise"
|
||||
assert.Equal(t, dhcpv6.MessageTypeAdvertise, resp.Type())
|
||||
oia := resp.Options.OneIANA()
|
||||
oiaAddr := oia.Options.OneAddress()
|
||||
assert.Equal(t, "2001::2", oiaAddr.IPv6Addr.String())
|
||||
|
||||
// "Request"
|
||||
req, _ = dhcpv6.NewRequestFromAdvertise(resp)
|
||||
msg, _ = req.GetInnerMessage()
|
||||
resp, _ = dhcpv6.NewReplyFromMessage(msg)
|
||||
assert.True(t, s.process(msg, req, resp))
|
||||
|
||||
// check "Reply"
|
||||
assert.Equal(t, dhcpv6.MessageTypeReply, resp.Type())
|
||||
oia = resp.Options.OneIANA()
|
||||
oiaAddr = oia.Options.OneAddress()
|
||||
assert.Equal(t, "2001::2", oiaAddr.IPv6Addr.String())
|
||||
|
||||
dnsAddrs := resp.Options.DNS()
|
||||
assert.Equal(t, 1, len(dnsAddrs))
|
||||
assert.Equal(t, "2000::1", dnsAddrs[0].String())
|
||||
|
||||
// check lease
|
||||
ls := s.GetLeases(LeasesDynamic)
|
||||
assert.Equal(t, 1, len(ls))
|
||||
assert.Equal(t, "2001::2", ls[0].IP.String())
|
||||
assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
|
||||
|
||||
assert.True(t, !ip6InRange(net.ParseIP("2001::2"), net.ParseIP("2001::1")))
|
||||
assert.True(t, !ip6InRange(net.ParseIP("2001::2"), net.ParseIP("2002::2")))
|
||||
assert.True(t, ip6InRange(net.ParseIP("2001::2"), net.ParseIP("2001::2")))
|
||||
assert.True(t, ip6InRange(net.ParseIP("2001::2"), net.ParseIP("2001::3")))
|
||||
}
|
||||
Reference in New Issue
Block a user