diff --git a/dhcpd/db.go b/dhcpd/db.go index 36893afa..08a924f8 100644 --- a/dhcpd/db.go +++ b/dhcpd/db.go @@ -103,10 +103,10 @@ func (s *Server) dbLoad() { s.reserveIP(lease.IP, lease.HWAddr) } - s.v6Leases = normalizeLeases(v6StaticLeases, []*Lease{}) + s.srv6.leases = normalizeLeases(v6StaticLeases, []*Lease{}) log.Info("DHCP: loaded leases v4:%d v6:%d total-read:%d from DB", - len(s.leases), len(s.v6Leases), numLeases) + len(s.leases), len(s.srv6.leases), numLeases) } // Skip duplicate leases @@ -153,7 +153,7 @@ func (s *Server) dbStore() { leases = append(leases, lease) } - for _, l := range s.v6Leases { + for _, l := range s.srv6.leases { if l.Expiry.Unix() == 0 { continue } diff --git a/dhcpd/dhcp_http.go b/dhcpd/dhcp_http.go index 2a1ad40e..d6926e08 100644 --- a/dhcpd/dhcp_http.go +++ b/dhcpd/dhcp_http.go @@ -259,7 +259,7 @@ func (s *Server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request HWAddr: mac, } - err = s.v6AddStaticLease(lease) + err = s.srv6.AddStaticLease(lease) if err != nil { httpError(r, w, http.StatusBadRequest, "%s", err) return @@ -313,7 +313,7 @@ func (s *Server) handleDHCPRemoveStaticLease(w http.ResponseWriter, r *http.Requ HWAddr: mac, } - err = s.v6RemoveStaticLease(lease) + err = s.srv6.RemoveStaticLease(lease) if err != nil { httpError(r, w, http.StatusBadRequest, "%s", err) return diff --git a/dhcpd/dhcpd.go b/dhcpd/dhcpd.go index d3204a7c..7aee7113 100644 --- a/dhcpd/dhcpd.go +++ b/dhcpd/dhcpd.go @@ -47,7 +47,7 @@ type ServerConfig struct { // 0: disable ICMPTimeout uint32 `json:"icmp_timeout_msec" yaml:"icmp_timeout_msec"` - EnableV6 bool `yaml:"enable_v6"` + Conf6 V6ServerConf `yaml:"dhcpv6"` WorkDir string `json:"-" yaml:"-"` DBFilePath string `json:"-" yaml:"-"` // path to DB file @@ -91,8 +91,7 @@ type Server struct { // IP address pool -- if entry is in the pool, then it's attached to a lease IPpool map[[4]byte]net.HardwareAddr - v6Leases []*Lease - v6LeasesLock sync.RWMutex + srv6 *V6Server conf ServerConfig @@ -134,6 +133,8 @@ func Create(config ServerConfig) *Server { s.registerHandlers() } + s.srv6 = v6Create(config.Conf6) + // we can't delay database loading until DHCP server is started, // because we need static leases functionality available beforehand s.dbLoad() @@ -262,8 +263,8 @@ func (s *Server) Start() error { }() } - if s.conf.EnableV6 { - err := s.v6Start(*iface) + if s.conf.Conf6.Enabled { + err := s.srv6.Start(*iface) if err != nil { return err } @@ -751,7 +752,7 @@ func (s *Server) Leases(flags int) []Lease { } s.leasesLock.RUnlock() - v6leases := s.v6GetLeases(flags) + v6leases := s.srv6.GetLeases(flags) result = append(result, v6leases...) return result diff --git a/dhcpd/v6.go b/dhcpd/v6.go index 745bff8b..299f0760 100644 --- a/dhcpd/v6.go +++ b/dhcpd/v6.go @@ -4,6 +4,7 @@ import ( "bytes" "fmt" "net" + "sync" "time" "github.com/AdguardTeam/golibs/log" @@ -14,19 +15,39 @@ import ( const valIAID = "ADGH" -func (s *Server) v6GetLeases(flags int) []Lease { +// V6Server - DHCPv6 server +type V6Server struct { + s4 *Server // for dbStore() + srv *server6.Server + leases []*Lease + leasesLock sync.Mutex + + conf V6ServerConf +} + +// V6ServerConf - server configuration +type V6ServerConf struct { + Enabled bool `yaml:"enabled"` + RangeStart string `yaml:"range_start"` + LeaseDuration uint32 `yaml:"lease_duration"` // in seconds + leaseTime time.Duration +} + +// GetLeases - get current leases +func (s *V6Server) GetLeases(flags int) []Lease { var result []Lease - s.v6LeasesLock.Lock() - for _, lease := range s.v6Leases { + s.leasesLock.Lock() + for _, lease := range s.leases { if (flags&LeasesStatic) != 0 && lease.Expiry.Unix() == leaseExpireStatic { result = append(result, *lease) } } - s.v6LeasesLock.Unlock() + s.leasesLock.Unlock() return result } -func (s *Server) v6AddStaticLease(l Lease) error { +// AddStaticLease - add a static lease +func (s *V6Server) AddStaticLease(l Lease) error { if len(l.IP) != 16 { return fmt.Errorf("invalid IP") } @@ -36,18 +57,49 @@ func (s *Server) v6AddStaticLease(l Lease) error { l.Expiry = time.Unix(leaseExpireStatic, 0) - s.v6LeasesLock.Lock() - s.v6Leases = append(s.v6Leases, &l) - s.dbStore() - s.v6LeasesLock.Unlock() + s.leasesLock.Lock() + err := s.addLease(l) + if err != nil { + s.leasesLock.Unlock() + return err + } + s.s4.dbStore() + s.leasesLock.Unlock() // s.notify(LeaseChangedAddedStatic) return nil } +// RemoveStaticLease - remove a static lease +func (s *V6Server) RemoveStaticLease(l Lease) error { + if len(l.IP) != 16 { + return fmt.Errorf("invalid IP") + } + if len(l.HWAddr) != 6 { + return fmt.Errorf("invalid MAC") + } + + s.leasesLock.Lock() + err := s.rmLease(l) + if err != nil { + s.leasesLock.Unlock() + return err + } + s.s4.dbStore() + s.leasesLock.Unlock() + // s.notify(LeaseChangedRemovedStatic) + return nil +} + +// Add a lease +func (s *V6Server) addLease(l Lease) error { + s.leases = append(s.leases, &l) + return nil +} + // Remove a lease -func (s *Server) v6RmLease(l Lease) error { +func (s *V6Server) rmLease(l Lease) error { var newLeases []*Lease - for _, lease := range s.v6Leases { + for _, lease := range s.leases { if net.IP.Equal(lease.IP, l.IP) { if !bytes.Equal(lease.HWAddr, l.HWAddr) { return fmt.Errorf("Lease not found") @@ -57,54 +109,34 @@ func (s *Server) v6RmLease(l Lease) error { newLeases = append(newLeases, lease) } - if len(newLeases) == len(s.v6Leases) { + if len(newLeases) == len(s.leases) { return fmt.Errorf("Lease not found: %s", l.IP) } - s.v6Leases = newLeases + s.leases = newLeases return nil } -func (s *Server) v6RemoveStaticLease(l Lease) error { - if len(l.IP) != 16 { - return fmt.Errorf("invalid IP") - } - if len(l.HWAddr) != 6 { - return fmt.Errorf("invalid MAC") - } +func (s *V6Server) findLease(mac net.HardwareAddr) *Lease { + s.leasesLock.Lock() + defer s.leasesLock.Unlock() - s.v6LeasesLock.Lock() - err := s.v6RmLease(l) - if err != nil { - s.v6LeasesLock.Unlock() - return err - } - s.dbStore() - s.v6LeasesLock.Unlock() - // s.notify(LeaseChangedRemovedStatic) - return nil -} - -func (s *Server) v6FindLease(mac net.HardwareAddr) *Lease { - s.v6LeasesLock.Lock() - defer s.v6LeasesLock.Unlock() - - for i := range s.v6Leases { - if bytes.Equal(mac, s.v6Leases[i].HWAddr) { - return s.v6Leases[i] + for i := range s.leases { + if bytes.Equal(mac, s.leases[i].HWAddr) { + return s.leases[i] } } return nil } -func (s *Server) v6Process(req dhcpv6.DHCPv6, resp dhcpv6.DHCPv6) { +func (s *V6Server) v6Process(req dhcpv6.DHCPv6, resp dhcpv6.DHCPv6) { mac, err := dhcpv6.ExtractMAC(req) if err != nil { log.Debug("DHCPv6: dhcpv6.ExtractMAC: %s", err) return } - lease := s.v6FindLease(mac) + lease := s.findLease(mac) if lease == nil { log.Debug("DHCPv6: no lease for: %s", mac) return @@ -122,14 +154,14 @@ func (s *Server) v6Process(req dhcpv6.DHCPv6, resp dhcpv6.DHCPv6) { oia.Options = dhcpv6.IdentityOptions{Options: []dhcpv6.Option{ &dhcpv6.OptIAAddress{ IPv6Addr: lease.IP, - PreferredLifetime: s.leaseTime, - ValidLifetime: s.leaseTime, + PreferredLifetime: s.conf.leaseTime, + ValidLifetime: s.conf.leaseTime, }, }} resp.AddOption(oia) } -func (s *Server) v6PacketHandler(conn net.PacketConn, peer net.Addr, req dhcpv6.DHCPv6) { +func (s *V6Server) packetHandler(conn net.PacketConn, peer net.Addr, req dhcpv6.DHCPv6) { msg, err := req.GetInnerMessage() if err != nil { log.Error("DHCPv6: %s", err) @@ -181,12 +213,13 @@ func (s *Server) v6PacketHandler(conn net.PacketConn, peer net.Addr, req dhcpv6. } } -func (s *Server) v6Start(iface net.Interface) error { +// Start - start server +func (s *V6Server) Start(iface net.Interface) error { laddr := &net.UDPAddr{ IP: net.ParseIP("::"), Port: dhcpv6.DefaultServerPort, } - server, err := server6.NewServer(iface.Name, laddr, s.v6PacketHandler, server6.WithDebugLogger()) + server, err := server6.NewServer(iface.Name, laddr, s.packetHandler, server6.WithDebugLogger()) if err != nil { return err } @@ -197,3 +230,16 @@ func (s *Server) v6Start(iface net.Interface) error { }() return nil } + +func v6Create(conf V6ServerConf) *V6Server { + s := &V6Server{} + s.conf = conf + + if conf.LeaseDuration == 0 { + s.conf.leaseTime = time.Hour * 2 + } else { + s.conf.leaseTime = time.Second * time.Duration(conf.LeaseDuration) + } + + return s +}