Add source IP filtering/allow list feature
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
||||
"flag"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/gorilla/handlers"
|
||||
)
|
||||
@@ -19,10 +20,38 @@ func invalidHandler(httpW http.ResponseWriter, httpR *http.Request) {
|
||||
httpW.Write([]byte("Invalid Request\n"))
|
||||
}
|
||||
|
||||
// Access handler, check to see if client IP in allowed IPs, continue if it is, send to invalidHandler if not
|
||||
func accessHandler(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(httpW http.ResponseWriter, httpR *http.Request) {
|
||||
|
||||
// setting.allowedIPs will always have at least one element because of how it's defined
|
||||
if setting.allowedIPs[0] == "" {
|
||||
next.ServeHTTP(httpW, httpR)
|
||||
}
|
||||
|
||||
IPPort := httpR.RemoteAddr
|
||||
|
||||
// Remove port from IP and remove brackets that are around IPv6 addresses
|
||||
requestIp := IPPort[0:strings.LastIndex(IPPort, ":")]
|
||||
requestIp = strings.Replace(requestIp, "[", "", -1)
|
||||
requestIp = strings.Replace(requestIp, "]", "", -1)
|
||||
|
||||
for _, allowedIP := range setting.allowedIPs {
|
||||
if requestIp == allowedIP {
|
||||
next.ServeHTTP(httpW, httpR)
|
||||
}
|
||||
}
|
||||
|
||||
invalidHandler(httpW, httpR)
|
||||
return
|
||||
})
|
||||
}
|
||||
|
||||
type settingType struct {
|
||||
birdSocket string
|
||||
bird6Socket string
|
||||
listen string
|
||||
allowedIPs []string
|
||||
}
|
||||
|
||||
var setting settingType
|
||||
@@ -34,6 +63,7 @@ func main() {
|
||||
"/var/run/bird/bird.ctl",
|
||||
"/var/run/bird/bird6.ctl",
|
||||
":8000",
|
||||
[]string{""},
|
||||
}
|
||||
|
||||
if birdSocketEnv := os.Getenv("BIRD_SOCKET"); birdSocketEnv != "" {
|
||||
@@ -45,16 +75,21 @@ func main() {
|
||||
if listenEnv := os.Getenv("BIRDLG_LISTEN"); listenEnv != "" {
|
||||
settingDefault.listen = listenEnv
|
||||
}
|
||||
if AllowedIPsEnv := os.Getenv("ALLOWED_IPS"); AllowedIPsEnv != "" {
|
||||
settingDefault.allowedIPs = strings.Split(AllowedIPsEnv, ",")
|
||||
}
|
||||
|
||||
// Allow parameters to override environment variables
|
||||
birdParam := flag.String("bird", settingDefault.birdSocket, "socket file for bird, set either in parameter or environment variable BIRD_SOCKET")
|
||||
bird6Param := flag.String("bird6", settingDefault.bird6Socket, "socket file for bird6, set either in parameter or environment variable BIRD6_SOCKET")
|
||||
listenParam := flag.String("listen", settingDefault.listen, "listen address, set either in parameter or environment variable BIRDLG_LISTEN")
|
||||
AllowedIPsParam := flag.String("allowed", strings.Join(settingDefault.allowedIPs, ","), "IPs allowed to access this proxy, separated by commas. Don't set to allow all IPs.")
|
||||
flag.Parse()
|
||||
|
||||
setting.birdSocket = *birdParam
|
||||
setting.bird6Socket = *bird6Param
|
||||
setting.listen = *listenParam
|
||||
setting.allowedIPs = strings.Split(*AllowedIPsParam, ",")
|
||||
|
||||
// Start HTTP server
|
||||
http.HandleFunc("/", invalidHandler)
|
||||
@@ -62,5 +97,6 @@ func main() {
|
||||
http.HandleFunc("/bird6", bird6Handler)
|
||||
http.HandleFunc("/traceroute", tracerouteIPv4Wrapper)
|
||||
http.HandleFunc("/traceroute6", tracerouteIPv6Wrapper)
|
||||
http.ListenAndServe(*listenParam, handlers.LoggingHandler(os.Stdout, http.DefaultServeMux))
|
||||
http.ListenAndServe(*listenParam, handlers.LoggingHandler(os.Stdout, accessHandler(http.DefaultServeMux)))
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user