Compare commits
30 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
278d731b6f | ||
|
|
0e97c9f086 | ||
|
|
5f4df7e806 | ||
|
|
d8d7c5b477 | ||
|
|
d3f1785ac9 | ||
|
|
1de95ed53e | ||
|
|
1934c065ec | ||
|
|
301f9af3d4 | ||
|
|
cb0427bfbb | ||
|
|
7456e5907e | ||
|
|
8cab86b924 | ||
|
|
3ec5456e86 | ||
|
|
b51ea5fa07 | ||
|
|
2ac8783eb6 | ||
|
|
5014523ae0 | ||
|
|
dabcc9566c | ||
|
|
c453020349 | ||
|
|
0daaa32fc6 | ||
|
|
5e15fd6dd9 | ||
|
|
76c0f47832 | ||
|
|
70fee14103 | ||
|
|
abd7725fed | ||
|
|
f01b79e625 | ||
|
|
94387450cf | ||
|
|
5723490a6c | ||
|
|
d7506264ad | ||
|
|
245ac46b65 | ||
|
|
107e29ee20 | ||
|
|
5f447d4e31 | ||
|
|
347667a2bd |
118
cmd/root.go
118
cmd/root.go
@@ -7,6 +7,7 @@ import (
|
|||||||
"os/signal"
|
"os/signal"
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/apernet/OpenGFW/analyzer"
|
"github.com/apernet/OpenGFW/analyzer"
|
||||||
"github.com/apernet/OpenGFW/analyzer/tcp"
|
"github.com/apernet/OpenGFW/analyzer/tcp"
|
||||||
@@ -16,6 +17,7 @@ import (
|
|||||||
"github.com/apernet/OpenGFW/modifier"
|
"github.com/apernet/OpenGFW/modifier"
|
||||||
modUDP "github.com/apernet/OpenGFW/modifier/udp"
|
modUDP "github.com/apernet/OpenGFW/modifier/udp"
|
||||||
"github.com/apernet/OpenGFW/ruleset"
|
"github.com/apernet/OpenGFW/ruleset"
|
||||||
|
"github.com/apernet/OpenGFW/ruleset/builtins/geo"
|
||||||
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"github.com/spf13/viper"
|
"github.com/spf13/viper"
|
||||||
@@ -41,6 +43,7 @@ var logger *zap.Logger
|
|||||||
// Flags
|
// Flags
|
||||||
var (
|
var (
|
||||||
cfgFile string
|
cfgFile string
|
||||||
|
pcapFile string
|
||||||
logLevel string
|
logLevel string
|
||||||
logFormat string
|
logFormat string
|
||||||
)
|
)
|
||||||
@@ -116,6 +119,7 @@ func init() {
|
|||||||
|
|
||||||
func initFlags() {
|
func initFlags() {
|
||||||
rootCmd.PersistentFlags().StringVarP(&cfgFile, "config", "c", "", "config file")
|
rootCmd.PersistentFlags().StringVarP(&cfgFile, "config", "c", "", "config file")
|
||||||
|
rootCmd.PersistentFlags().StringVarP(&pcapFile, "pcap", "p", "", "pcap file (optional)")
|
||||||
rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", envOrDefaultString(appLogLevelEnv, "info"), "log level")
|
rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", envOrDefaultString(appLogLevelEnv, "info"), "log level")
|
||||||
rootCmd.PersistentFlags().StringVarP(&logFormat, "log-format", "f", envOrDefaultString(appLogFormatEnv, "console"), "log format")
|
rootCmd.PersistentFlags().StringVarP(&logFormat, "log-format", "f", envOrDefaultString(appLogFormatEnv, "console"), "log format")
|
||||||
}
|
}
|
||||||
@@ -165,22 +169,33 @@ type cliConfig struct {
|
|||||||
IO cliConfigIO `mapstructure:"io"`
|
IO cliConfigIO `mapstructure:"io"`
|
||||||
Workers cliConfigWorkers `mapstructure:"workers"`
|
Workers cliConfigWorkers `mapstructure:"workers"`
|
||||||
Ruleset cliConfigRuleset `mapstructure:"ruleset"`
|
Ruleset cliConfigRuleset `mapstructure:"ruleset"`
|
||||||
|
Replay cliConfigReplay `mapstructure:"replay"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type cliConfigIO struct {
|
type cliConfigIO struct {
|
||||||
QueueSize uint32 `mapstructure:"queueSize"`
|
QueueSize uint32 `mapstructure:"queueSize"`
|
||||||
ReadBuffer int `mapstructure:"rcvBuf"`
|
QueueNum *uint16 `mapstructure:"queueNum"`
|
||||||
WriteBuffer int `mapstructure:"sndBuf"`
|
Table string `mapstructure:"table"`
|
||||||
Local bool `mapstructure:"local"`
|
ConnMarkAccept uint32 `mapstructure:"connMarkAccept"`
|
||||||
RST bool `mapstructure:"rst"`
|
ConnMarkDrop uint32 `mapstructure:"connMarkDrop"`
|
||||||
|
|
||||||
|
ReadBuffer int `mapstructure:"rcvBuf"`
|
||||||
|
WriteBuffer int `mapstructure:"sndBuf"`
|
||||||
|
Local bool `mapstructure:"local"`
|
||||||
|
RST bool `mapstructure:"rst"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type cliConfigReplay struct {
|
||||||
|
Realtime bool `mapstructure:"realtime"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type cliConfigWorkers struct {
|
type cliConfigWorkers struct {
|
||||||
Count int `mapstructure:"count"`
|
Count int `mapstructure:"count"`
|
||||||
QueueSize int `mapstructure:"queueSize"`
|
QueueSize int `mapstructure:"queueSize"`
|
||||||
TCPMaxBufferedPagesTotal int `mapstructure:"tcpMaxBufferedPagesTotal"`
|
TCPMaxBufferedPagesTotal int `mapstructure:"tcpMaxBufferedPagesTotal"`
|
||||||
TCPMaxBufferedPagesPerConn int `mapstructure:"tcpMaxBufferedPagesPerConn"`
|
TCPMaxBufferedPagesPerConn int `mapstructure:"tcpMaxBufferedPagesPerConn"`
|
||||||
UDPMaxStreams int `mapstructure:"udpMaxStreams"`
|
TCPTimeout time.Duration `mapstructure:"tcpTimeout"`
|
||||||
|
UDPMaxStreams int `mapstructure:"udpMaxStreams"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type cliConfigRuleset struct {
|
type cliConfigRuleset struct {
|
||||||
@@ -194,17 +209,35 @@ func (c *cliConfig) fillLogger(config *engine.Config) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *cliConfig) fillIO(config *engine.Config) error {
|
func (c *cliConfig) fillIO(config *engine.Config) error {
|
||||||
nfio, err := io.NewNFQueuePacketIO(io.NFQueuePacketIOConfig{
|
var ioImpl io.PacketIO
|
||||||
QueueSize: c.IO.QueueSize,
|
var err error
|
||||||
ReadBuffer: c.IO.ReadBuffer,
|
if pcapFile != "" {
|
||||||
WriteBuffer: c.IO.WriteBuffer,
|
// Setup IO for pcap file replay
|
||||||
Local: c.IO.Local,
|
logger.Info("replaying from pcap file", zap.String("pcap file", pcapFile))
|
||||||
RST: c.IO.RST,
|
ioImpl, err = io.NewPcapPacketIO(io.PcapPacketIOConfig{
|
||||||
})
|
PcapFile: pcapFile,
|
||||||
|
Realtime: c.Replay.Realtime,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
// Setup IO for nfqueue
|
||||||
|
ioImpl, err = io.NewNFQueuePacketIO(io.NFQueuePacketIOConfig{
|
||||||
|
QueueSize: c.IO.QueueSize,
|
||||||
|
QueueNum: c.IO.QueueNum,
|
||||||
|
Table: c.IO.Table,
|
||||||
|
ConnMarkAccept: c.IO.ConnMarkAccept,
|
||||||
|
ConnMarkDrop: c.IO.ConnMarkDrop,
|
||||||
|
|
||||||
|
ReadBuffer: c.IO.ReadBuffer,
|
||||||
|
WriteBuffer: c.IO.WriteBuffer,
|
||||||
|
Local: c.IO.Local,
|
||||||
|
RST: c.IO.RST,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return configError{Field: "io", Err: err}
|
return configError{Field: "io", Err: err}
|
||||||
}
|
}
|
||||||
config.IO = nfio
|
config.IO = ioImpl
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -213,6 +246,7 @@ func (c *cliConfig) fillWorkers(config *engine.Config) error {
|
|||||||
config.WorkerQueueSize = c.Workers.QueueSize
|
config.WorkerQueueSize = c.Workers.QueueSize
|
||||||
config.WorkerTCPMaxBufferedPagesTotal = c.Workers.TCPMaxBufferedPagesTotal
|
config.WorkerTCPMaxBufferedPagesTotal = c.Workers.TCPMaxBufferedPagesTotal
|
||||||
config.WorkerTCPMaxBufferedPagesPerConn = c.Workers.TCPMaxBufferedPagesPerConn
|
config.WorkerTCPMaxBufferedPagesPerConn = c.Workers.TCPMaxBufferedPagesPerConn
|
||||||
|
config.WorkerTCPTimeout = c.Workers.TCPTimeout
|
||||||
config.WorkerUDPMaxStreams = c.Workers.UDPMaxStreams
|
config.WorkerUDPMaxStreams = c.Workers.UDPMaxStreams
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -256,8 +290,7 @@ func runMain(cmd *cobra.Command, args []string) {
|
|||||||
}
|
}
|
||||||
rsConfig := &ruleset.BuiltinConfig{
|
rsConfig := &ruleset.BuiltinConfig{
|
||||||
Logger: &rulesetLogger{},
|
Logger: &rulesetLogger{},
|
||||||
GeoSiteFilename: config.Ruleset.GeoSite,
|
GeoMatcher: geo.NewGeoMatcher(config.Ruleset.GeoSite, config.Ruleset.GeoIp),
|
||||||
GeoIpFilename: config.Ruleset.GeoIp,
|
|
||||||
ProtectedDialContext: engineConfig.IO.ProtectedDialContext,
|
ProtectedDialContext: engineConfig.IO.ProtectedDialContext,
|
||||||
}
|
}
|
||||||
rs, err := ruleset.CompileExprRules(rawRs, analyzers, modifiers, rsConfig)
|
rs, err := ruleset.CompileExprRules(rawRs, analyzers, modifiers, rsConfig)
|
||||||
@@ -340,12 +373,26 @@ func (l *engineLogger) TCPStreamPropUpdate(info ruleset.StreamInfo, close bool)
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (l *engineLogger) TCPStreamAction(info ruleset.StreamInfo, action ruleset.Action, noMatch bool) {
|
func (l *engineLogger) TCPStreamAction(info ruleset.StreamInfo, action ruleset.Action, noMatch bool) {
|
||||||
logger.Info("TCP stream action",
|
if noMatch {
|
||||||
zap.Int64("id", info.ID),
|
logger.Debug("TCP stream no match",
|
||||||
zap.String("src", info.SrcString()),
|
zap.Int64("id", info.ID),
|
||||||
zap.String("dst", info.DstString()),
|
zap.String("src", info.SrcString()),
|
||||||
zap.String("action", action.String()),
|
zap.String("dst", info.DstString()),
|
||||||
zap.Bool("noMatch", noMatch))
|
zap.String("action", action.String()))
|
||||||
|
} else {
|
||||||
|
logger.Info("TCP stream action",
|
||||||
|
zap.Int64("id", info.ID),
|
||||||
|
zap.String("src", info.SrcString()),
|
||||||
|
zap.String("dst", info.DstString()),
|
||||||
|
zap.String("action", action.String()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *engineLogger) TCPFlush(workerID, flushed, closed int) {
|
||||||
|
logger.Debug("TCP flush",
|
||||||
|
zap.Int("workerID", workerID),
|
||||||
|
zap.Int("flushed", flushed),
|
||||||
|
zap.Int("closed", closed))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *engineLogger) UDPStreamNew(workerID int, info ruleset.StreamInfo) {
|
func (l *engineLogger) UDPStreamNew(workerID int, info ruleset.StreamInfo) {
|
||||||
@@ -366,12 +413,19 @@ func (l *engineLogger) UDPStreamPropUpdate(info ruleset.StreamInfo, close bool)
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (l *engineLogger) UDPStreamAction(info ruleset.StreamInfo, action ruleset.Action, noMatch bool) {
|
func (l *engineLogger) UDPStreamAction(info ruleset.StreamInfo, action ruleset.Action, noMatch bool) {
|
||||||
logger.Info("UDP stream action",
|
if noMatch {
|
||||||
zap.Int64("id", info.ID),
|
logger.Debug("UDP stream no match",
|
||||||
zap.String("src", info.SrcString()),
|
zap.Int64("id", info.ID),
|
||||||
zap.String("dst", info.DstString()),
|
zap.String("src", info.SrcString()),
|
||||||
zap.String("action", action.String()),
|
zap.String("dst", info.DstString()),
|
||||||
zap.Bool("noMatch", noMatch))
|
zap.String("action", action.String()))
|
||||||
|
} else {
|
||||||
|
logger.Info("UDP stream action",
|
||||||
|
zap.Int64("id", info.ID),
|
||||||
|
zap.String("src", info.SrcString()),
|
||||||
|
zap.String("dst", info.DstString()),
|
||||||
|
zap.String("action", action.String()))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *engineLogger) ModifyError(info ruleset.StreamInfo, err error) {
|
func (l *engineLogger) ModifyError(info ruleset.StreamInfo, err error) {
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ func NewEngine(config Config) (Engine, error) {
|
|||||||
Ruleset: config.Ruleset,
|
Ruleset: config.Ruleset,
|
||||||
TCPMaxBufferedPagesTotal: config.WorkerTCPMaxBufferedPagesTotal,
|
TCPMaxBufferedPagesTotal: config.WorkerTCPMaxBufferedPagesTotal,
|
||||||
TCPMaxBufferedPagesPerConn: config.WorkerTCPMaxBufferedPagesPerConn,
|
TCPMaxBufferedPagesPerConn: config.WorkerTCPMaxBufferedPagesPerConn,
|
||||||
|
TCPTimeout: config.WorkerTCPTimeout,
|
||||||
UDPMaxStreams: config.WorkerUDPMaxStreams,
|
UDPMaxStreams: config.WorkerUDPMaxStreams,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -57,12 +58,17 @@ func (e *engine) UpdateRuleset(r ruleset.Ruleset) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *engine) Run(ctx context.Context) error {
|
func (e *engine) Run(ctx context.Context) error {
|
||||||
|
workerCtx, workerCancel := context.WithCancel(ctx)
|
||||||
|
defer workerCancel() // Stop workers
|
||||||
|
|
||||||
|
// Register IO shutdown
|
||||||
ioCtx, ioCancel := context.WithCancel(ctx)
|
ioCtx, ioCancel := context.WithCancel(ctx)
|
||||||
defer ioCancel() // Stop workers & IO
|
e.io.SetCancelFunc(ioCancel)
|
||||||
|
defer ioCancel() // Stop IO
|
||||||
|
|
||||||
// Start workers
|
// Start workers
|
||||||
for _, w := range e.workers {
|
for _, w := range e.workers {
|
||||||
go w.Run(ioCtx)
|
go w.Run(workerCtx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register IO callback
|
// Register IO callback
|
||||||
@@ -84,6 +90,8 @@ func (e *engine) Run(ctx context.Context) error {
|
|||||||
return err
|
return err
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return nil
|
return nil
|
||||||
|
case <-ioCtx.Done():
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -101,9 +109,11 @@ func (e *engine) dispatch(p io.Packet) bool {
|
|||||||
_ = e.io.SetVerdict(p, io.VerdictAcceptStream, nil)
|
_ = e.io.SetVerdict(p, io.VerdictAcceptStream, nil)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
// Convert to gopacket.Packet
|
||||||
|
packet := gopacket.NewPacket(data, layerType, gopacket.DecodeOptions{Lazy: true, NoCopy: true})
|
||||||
|
packet.Metadata().Timestamp = p.Timestamp()
|
||||||
// Load balance by stream ID
|
// Load balance by stream ID
|
||||||
index := p.StreamID() % uint32(len(e.workers))
|
index := p.StreamID() % uint32(len(e.workers))
|
||||||
packet := gopacket.NewPacket(data, layerType, gopacket.DecodeOptions{Lazy: true, NoCopy: true})
|
|
||||||
e.workers[index].Feed(&workerPacket{
|
e.workers[index].Feed(&workerPacket{
|
||||||
StreamID: p.StreamID(),
|
StreamID: p.StreamID(),
|
||||||
Packet: packet,
|
Packet: packet,
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package engine
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/apernet/OpenGFW/io"
|
"github.com/apernet/OpenGFW/io"
|
||||||
"github.com/apernet/OpenGFW/ruleset"
|
"github.com/apernet/OpenGFW/ruleset"
|
||||||
@@ -25,6 +26,7 @@ type Config struct {
|
|||||||
WorkerQueueSize int
|
WorkerQueueSize int
|
||||||
WorkerTCPMaxBufferedPagesTotal int
|
WorkerTCPMaxBufferedPagesTotal int
|
||||||
WorkerTCPMaxBufferedPagesPerConn int
|
WorkerTCPMaxBufferedPagesPerConn int
|
||||||
|
WorkerTCPTimeout time.Duration
|
||||||
WorkerUDPMaxStreams int
|
WorkerUDPMaxStreams int
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -36,6 +38,7 @@ type Logger interface {
|
|||||||
TCPStreamNew(workerID int, info ruleset.StreamInfo)
|
TCPStreamNew(workerID int, info ruleset.StreamInfo)
|
||||||
TCPStreamPropUpdate(info ruleset.StreamInfo, close bool)
|
TCPStreamPropUpdate(info ruleset.StreamInfo, close bool)
|
||||||
TCPStreamAction(info ruleset.StreamInfo, action ruleset.Action, noMatch bool)
|
TCPStreamAction(info ruleset.StreamInfo, action ruleset.Action, noMatch bool)
|
||||||
|
TCPFlush(workerID, flushed, closed int)
|
||||||
|
|
||||||
UDPStreamNew(workerID int, info ruleset.StreamInfo)
|
UDPStreamNew(workerID int, info ruleset.StreamInfo)
|
||||||
UDPStreamPropUpdate(info ruleset.StreamInfo, close bool)
|
UDPStreamPropUpdate(info ruleset.StreamInfo, close bool)
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package engine
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/apernet/OpenGFW/io"
|
"github.com/apernet/OpenGFW/io"
|
||||||
"github.com/apernet/OpenGFW/ruleset"
|
"github.com/apernet/OpenGFW/ruleset"
|
||||||
@@ -14,9 +15,12 @@ import (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
defaultChanSize = 64
|
defaultChanSize = 64
|
||||||
defaultTCPMaxBufferedPagesTotal = 4096
|
defaultTCPMaxBufferedPagesTotal = 65536
|
||||||
defaultTCPMaxBufferedPagesPerConnection = 64
|
defaultTCPMaxBufferedPagesPerConnection = 16
|
||||||
|
defaultTCPTimeout = 10 * time.Minute
|
||||||
defaultUDPMaxStreams = 4096
|
defaultUDPMaxStreams = 4096
|
||||||
|
|
||||||
|
tcpFlushInterval = 1 * time.Minute
|
||||||
)
|
)
|
||||||
|
|
||||||
type workerPacket struct {
|
type workerPacket struct {
|
||||||
@@ -33,6 +37,7 @@ type worker struct {
|
|||||||
tcpStreamFactory *tcpStreamFactory
|
tcpStreamFactory *tcpStreamFactory
|
||||||
tcpStreamPool *reassembly.StreamPool
|
tcpStreamPool *reassembly.StreamPool
|
||||||
tcpAssembler *reassembly.Assembler
|
tcpAssembler *reassembly.Assembler
|
||||||
|
tcpTimeout time.Duration
|
||||||
|
|
||||||
udpStreamFactory *udpStreamFactory
|
udpStreamFactory *udpStreamFactory
|
||||||
udpStreamManager *udpStreamManager
|
udpStreamManager *udpStreamManager
|
||||||
@@ -47,6 +52,7 @@ type workerConfig struct {
|
|||||||
Ruleset ruleset.Ruleset
|
Ruleset ruleset.Ruleset
|
||||||
TCPMaxBufferedPagesTotal int
|
TCPMaxBufferedPagesTotal int
|
||||||
TCPMaxBufferedPagesPerConn int
|
TCPMaxBufferedPagesPerConn int
|
||||||
|
TCPTimeout time.Duration
|
||||||
UDPMaxStreams int
|
UDPMaxStreams int
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -60,6 +66,9 @@ func (c *workerConfig) fillDefaults() {
|
|||||||
if c.TCPMaxBufferedPagesPerConn <= 0 {
|
if c.TCPMaxBufferedPagesPerConn <= 0 {
|
||||||
c.TCPMaxBufferedPagesPerConn = defaultTCPMaxBufferedPagesPerConnection
|
c.TCPMaxBufferedPagesPerConn = defaultTCPMaxBufferedPagesPerConnection
|
||||||
}
|
}
|
||||||
|
if c.TCPTimeout <= 0 {
|
||||||
|
c.TCPTimeout = defaultTCPTimeout
|
||||||
|
}
|
||||||
if c.UDPMaxStreams <= 0 {
|
if c.UDPMaxStreams <= 0 {
|
||||||
c.UDPMaxStreams = defaultUDPMaxStreams
|
c.UDPMaxStreams = defaultUDPMaxStreams
|
||||||
}
|
}
|
||||||
@@ -98,6 +107,7 @@ func newWorker(config workerConfig) (*worker, error) {
|
|||||||
tcpStreamFactory: tcpSF,
|
tcpStreamFactory: tcpSF,
|
||||||
tcpStreamPool: tcpStreamPool,
|
tcpStreamPool: tcpStreamPool,
|
||||||
tcpAssembler: tcpAssembler,
|
tcpAssembler: tcpAssembler,
|
||||||
|
tcpTimeout: config.TCPTimeout,
|
||||||
udpStreamFactory: udpSF,
|
udpStreamFactory: udpSF,
|
||||||
udpStreamManager: udpSM,
|
udpStreamManager: udpSM,
|
||||||
modSerializeBuffer: gopacket.NewSerializeBuffer(),
|
modSerializeBuffer: gopacket.NewSerializeBuffer(),
|
||||||
@@ -111,6 +121,10 @@ func (w *worker) Feed(p *workerPacket) {
|
|||||||
func (w *worker) Run(ctx context.Context) {
|
func (w *worker) Run(ctx context.Context) {
|
||||||
w.logger.WorkerStart(w.id)
|
w.logger.WorkerStart(w.id)
|
||||||
defer w.logger.WorkerStop(w.id)
|
defer w.logger.WorkerStop(w.id)
|
||||||
|
|
||||||
|
tcpFlushTicker := time.NewTicker(tcpFlushInterval)
|
||||||
|
defer tcpFlushTicker.Stop()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
@@ -122,6 +136,8 @@ func (w *worker) Run(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
v, b := w.handle(wPkt.StreamID, wPkt.Packet)
|
v, b := w.handle(wPkt.StreamID, wPkt.Packet)
|
||||||
_ = wPkt.SetVerdict(v, b)
|
_ = wPkt.SetVerdict(v, b)
|
||||||
|
case <-tcpFlushTicker.C:
|
||||||
|
w.flushTCP(w.tcpTimeout)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -176,6 +192,11 @@ func (w *worker) handleTCP(ipFlow gopacket.Flow, pMeta *gopacket.PacketMetadata,
|
|||||||
return io.Verdict(ctx.Verdict)
|
return io.Verdict(ctx.Verdict)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *worker) flushTCP(timeout time.Duration) {
|
||||||
|
flushed, closed := w.tcpAssembler.FlushCloseOlderThan(time.Now().Add(-timeout))
|
||||||
|
w.logger.TCPFlush(w.id, flushed, closed)
|
||||||
|
}
|
||||||
|
|
||||||
func (w *worker) handleUDP(streamID uint32, ipFlow gopacket.Flow, udp *layers.UDP) (io.Verdict, []byte) {
|
func (w *worker) handleUDP(streamID uint32, ipFlow gopacket.Flow, udp *layers.UDP) (io.Verdict, []byte) {
|
||||||
ctx := &udpContext{
|
ctx := &udpContext{
|
||||||
Verdict: udpVerdictAccept,
|
Verdict: udpVerdictAccept,
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package io
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
"net"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Verdict int
|
type Verdict int
|
||||||
@@ -24,6 +25,8 @@ const (
|
|||||||
type Packet interface {
|
type Packet interface {
|
||||||
// StreamID is the ID of the stream the packet belongs to.
|
// StreamID is the ID of the stream the packet belongs to.
|
||||||
StreamID() uint32
|
StreamID() uint32
|
||||||
|
// Timestamp is the time the packet was received.
|
||||||
|
Timestamp() time.Time
|
||||||
// Data is the raw packet data, starting with the IP header.
|
// Data is the raw packet data, starting with the IP header.
|
||||||
Data() []byte
|
Data() []byte
|
||||||
}
|
}
|
||||||
@@ -45,6 +48,9 @@ type PacketIO interface {
|
|||||||
ProtectedDialContext(ctx context.Context, network, address string) (net.Conn, error)
|
ProtectedDialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||||
// Close closes the packet IO.
|
// Close closes the packet IO.
|
||||||
Close() error
|
Close() error
|
||||||
|
// SetCancelFunc gives packet IO access to context cancel function, enabling it to
|
||||||
|
// trigger a shutdown
|
||||||
|
SetCancelFunc(cancelFunc context.CancelFunc) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type ErrInvalidPacket struct {
|
type ErrInvalidPacket struct {
|
||||||
|
|||||||
147
io/nfqueue.go
147
io/nfqueue.go
@@ -10,6 +10,7 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/coreos/go-iptables/iptables"
|
"github.com/coreos/go-iptables/iptables"
|
||||||
"github.com/florianl/go-nfqueue"
|
"github.com/florianl/go-nfqueue"
|
||||||
@@ -18,29 +19,28 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
nfqueueNum = 100
|
nfqueueDefaultQueueNum = 100
|
||||||
nfqueueMaxPacketLen = 0xFFFF
|
nfqueueMaxPacketLen = 0xFFFF
|
||||||
nfqueueDefaultQueueSize = 128
|
nfqueueDefaultQueueSize = 128
|
||||||
|
|
||||||
nfqueueConnMarkAccept = 1001
|
nfqueueDefaultConnMarkAccept = 1001
|
||||||
nfqueueConnMarkDrop = 1002
|
|
||||||
|
|
||||||
nftFamily = "inet"
|
nftFamily = "inet"
|
||||||
nftTable = "opengfw"
|
nftDefaultTable = "opengfw"
|
||||||
)
|
)
|
||||||
|
|
||||||
func generateNftRules(local, rst bool) (*nftTableSpec, error) {
|
func (n *nfqueuePacketIO) generateNftRules() (*nftTableSpec, error) {
|
||||||
if local && rst {
|
if n.local && n.rst {
|
||||||
return nil, errors.New("tcp rst is not supported in local mode")
|
return nil, errors.New("tcp rst is not supported in local mode")
|
||||||
}
|
}
|
||||||
table := &nftTableSpec{
|
table := &nftTableSpec{
|
||||||
Family: nftFamily,
|
Family: nftFamily,
|
||||||
Table: nftTable,
|
Table: n.table,
|
||||||
}
|
}
|
||||||
table.Defines = append(table.Defines, fmt.Sprintf("define ACCEPT_CTMARK=%d", nfqueueConnMarkAccept))
|
table.Defines = append(table.Defines, fmt.Sprintf("define ACCEPT_CTMARK=%d", n.connMarkAccept))
|
||||||
table.Defines = append(table.Defines, fmt.Sprintf("define DROP_CTMARK=%d", nfqueueConnMarkDrop))
|
table.Defines = append(table.Defines, fmt.Sprintf("define DROP_CTMARK=%d", n.connMarkDrop))
|
||||||
table.Defines = append(table.Defines, fmt.Sprintf("define QUEUE_NUM=%d", nfqueueNum))
|
table.Defines = append(table.Defines, fmt.Sprintf("define QUEUE_NUM=%d", n.queueNum))
|
||||||
if local {
|
if n.local {
|
||||||
table.Chains = []nftChainSpec{
|
table.Chains = []nftChainSpec{
|
||||||
{Chain: "INPUT", Header: "type filter hook input priority filter; policy accept;"},
|
{Chain: "INPUT", Header: "type filter hook input priority filter; policy accept;"},
|
||||||
{Chain: "OUTPUT", Header: "type filter hook output priority filter; policy accept;"},
|
{Chain: "OUTPUT", Header: "type filter hook output priority filter; policy accept;"},
|
||||||
@@ -54,7 +54,7 @@ func generateNftRules(local, rst bool) (*nftTableSpec, error) {
|
|||||||
c := &table.Chains[i]
|
c := &table.Chains[i]
|
||||||
c.Rules = append(c.Rules, "meta mark $ACCEPT_CTMARK ct mark set $ACCEPT_CTMARK") // Bypass protected connections
|
c.Rules = append(c.Rules, "meta mark $ACCEPT_CTMARK ct mark set $ACCEPT_CTMARK") // Bypass protected connections
|
||||||
c.Rules = append(c.Rules, "ct mark $ACCEPT_CTMARK counter accept")
|
c.Rules = append(c.Rules, "ct mark $ACCEPT_CTMARK counter accept")
|
||||||
if rst {
|
if n.rst {
|
||||||
c.Rules = append(c.Rules, "ip protocol tcp ct mark $DROP_CTMARK counter reject with tcp reset")
|
c.Rules = append(c.Rules, "ip protocol tcp ct mark $DROP_CTMARK counter reject with tcp reset")
|
||||||
}
|
}
|
||||||
c.Rules = append(c.Rules, "ct mark $DROP_CTMARK counter drop")
|
c.Rules = append(c.Rules, "ct mark $DROP_CTMARK counter drop")
|
||||||
@@ -63,12 +63,12 @@ func generateNftRules(local, rst bool) (*nftTableSpec, error) {
|
|||||||
return table, nil
|
return table, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateIptRules(local, rst bool) ([]iptRule, error) {
|
func (n *nfqueuePacketIO) generateIptRules() ([]iptRule, error) {
|
||||||
if local && rst {
|
if n.local && n.rst {
|
||||||
return nil, errors.New("tcp rst is not supported in local mode")
|
return nil, errors.New("tcp rst is not supported in local mode")
|
||||||
}
|
}
|
||||||
var chains []string
|
var chains []string
|
||||||
if local {
|
if n.local {
|
||||||
chains = []string{"INPUT", "OUTPUT"}
|
chains = []string{"INPUT", "OUTPUT"}
|
||||||
} else {
|
} else {
|
||||||
chains = []string{"FORWARD"}
|
chains = []string{"FORWARD"}
|
||||||
@@ -76,13 +76,13 @@ func generateIptRules(local, rst bool) ([]iptRule, error) {
|
|||||||
rules := make([]iptRule, 0, 4*len(chains))
|
rules := make([]iptRule, 0, 4*len(chains))
|
||||||
for _, chain := range chains {
|
for _, chain := range chains {
|
||||||
// Bypass protected connections
|
// Bypass protected connections
|
||||||
rules = append(rules, iptRule{"filter", chain, []string{"-m", "mark", "--mark", strconv.Itoa(nfqueueConnMarkAccept), "-j", "CONNMARK", "--set-mark", strconv.Itoa(nfqueueConnMarkAccept)}})
|
rules = append(rules, iptRule{"filter", chain, []string{"-m", "mark", "--mark", strconv.Itoa(n.connMarkAccept), "-j", "CONNMARK", "--set-mark", strconv.Itoa(n.connMarkAccept)}})
|
||||||
rules = append(rules, iptRule{"filter", chain, []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkAccept), "-j", "ACCEPT"}})
|
rules = append(rules, iptRule{"filter", chain, []string{"-m", "connmark", "--mark", strconv.Itoa(n.connMarkAccept), "-j", "ACCEPT"}})
|
||||||
if rst {
|
if n.rst {
|
||||||
rules = append(rules, iptRule{"filter", chain, []string{"-p", "tcp", "-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkDrop), "-j", "REJECT", "--reject-with", "tcp-reset"}})
|
rules = append(rules, iptRule{"filter", chain, []string{"-p", "tcp", "-m", "connmark", "--mark", strconv.Itoa(n.connMarkDrop), "-j", "REJECT", "--reject-with", "tcp-reset"}})
|
||||||
}
|
}
|
||||||
rules = append(rules, iptRule{"filter", chain, []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkDrop), "-j", "DROP"}})
|
rules = append(rules, iptRule{"filter", chain, []string{"-m", "connmark", "--mark", strconv.Itoa(n.connMarkDrop), "-j", "DROP"}})
|
||||||
rules = append(rules, iptRule{"filter", chain, []string{"-j", "NFQUEUE", "--queue-num", strconv.Itoa(nfqueueNum), "--queue-bypass"}})
|
rules = append(rules, iptRule{"filter", chain, []string{"-j", "NFQUEUE", "--queue-num", strconv.Itoa(n.queueNum), "--queue-bypass"}})
|
||||||
}
|
}
|
||||||
|
|
||||||
return rules, nil
|
return rules, nil
|
||||||
@@ -93,10 +93,14 @@ var _ PacketIO = (*nfqueuePacketIO)(nil)
|
|||||||
var errNotNFQueuePacket = errors.New("not an NFQueue packet")
|
var errNotNFQueuePacket = errors.New("not an NFQueue packet")
|
||||||
|
|
||||||
type nfqueuePacketIO struct {
|
type nfqueuePacketIO struct {
|
||||||
n *nfqueue.Nfqueue
|
n *nfqueue.Nfqueue
|
||||||
local bool
|
local bool
|
||||||
rst bool
|
rst bool
|
||||||
rSet bool // whether the nftables/iptables rules have been set
|
rSet bool // whether the nftables/iptables rules have been set
|
||||||
|
queueNum int
|
||||||
|
table string // nftable name
|
||||||
|
connMarkAccept int
|
||||||
|
connMarkDrop int
|
||||||
|
|
||||||
// iptables not nil = use iptables instead of nftables
|
// iptables not nil = use iptables instead of nftables
|
||||||
ipt4 *iptables.IPTables
|
ipt4 *iptables.IPTables
|
||||||
@@ -106,7 +110,12 @@ type nfqueuePacketIO struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type NFQueuePacketIOConfig struct {
|
type NFQueuePacketIOConfig struct {
|
||||||
QueueSize uint32
|
QueueSize uint32
|
||||||
|
QueueNum *uint16
|
||||||
|
Table string
|
||||||
|
ConnMarkAccept uint32
|
||||||
|
ConnMarkDrop uint32
|
||||||
|
|
||||||
ReadBuffer int
|
ReadBuffer int
|
||||||
WriteBuffer int
|
WriteBuffer int
|
||||||
Local bool
|
Local bool
|
||||||
@@ -117,6 +126,26 @@ func NewNFQueuePacketIO(config NFQueuePacketIOConfig) (PacketIO, error) {
|
|||||||
if config.QueueSize == 0 {
|
if config.QueueSize == 0 {
|
||||||
config.QueueSize = nfqueueDefaultQueueSize
|
config.QueueSize = nfqueueDefaultQueueSize
|
||||||
}
|
}
|
||||||
|
if config.QueueNum == nil {
|
||||||
|
queueNum := uint16(nfqueueDefaultQueueNum)
|
||||||
|
config.QueueNum = &queueNum
|
||||||
|
}
|
||||||
|
if config.Table == "" {
|
||||||
|
config.Table = nftDefaultTable
|
||||||
|
}
|
||||||
|
if config.ConnMarkAccept == 0 {
|
||||||
|
config.ConnMarkAccept = nfqueueDefaultConnMarkAccept
|
||||||
|
}
|
||||||
|
if config.ConnMarkDrop == 0 {
|
||||||
|
config.ConnMarkDrop = config.ConnMarkAccept + 1
|
||||||
|
if config.ConnMarkDrop == 0 {
|
||||||
|
// Overflow
|
||||||
|
config.ConnMarkDrop = 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if config.ConnMarkAccept == config.ConnMarkDrop {
|
||||||
|
return nil, errors.New("connMarkAccept and connMarkDrop cannot be the same")
|
||||||
|
}
|
||||||
var ipt4, ipt6 *iptables.IPTables
|
var ipt4, ipt6 *iptables.IPTables
|
||||||
var err error
|
var err error
|
||||||
if nftCheck() != nil {
|
if nftCheck() != nil {
|
||||||
@@ -131,7 +160,7 @@ func NewNFQueuePacketIO(config NFQueuePacketIOConfig) (PacketIO, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
n, err := nfqueue.Open(&nfqueue.Config{
|
n, err := nfqueue.Open(&nfqueue.Config{
|
||||||
NfQueue: nfqueueNum,
|
NfQueue: *config.QueueNum,
|
||||||
MaxPacketLen: nfqueueMaxPacketLen,
|
MaxPacketLen: nfqueueMaxPacketLen,
|
||||||
MaxQueueLen: config.QueueSize,
|
MaxQueueLen: config.QueueSize,
|
||||||
Copymode: nfqueue.NfQnlCopyPacket,
|
Copymode: nfqueue.NfQnlCopyPacket,
|
||||||
@@ -155,16 +184,20 @@ func NewNFQueuePacketIO(config NFQueuePacketIOConfig) (PacketIO, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
return &nfqueuePacketIO{
|
return &nfqueuePacketIO{
|
||||||
n: n,
|
n: n,
|
||||||
local: config.Local,
|
local: config.Local,
|
||||||
rst: config.RST,
|
rst: config.RST,
|
||||||
ipt4: ipt4,
|
queueNum: int(*config.QueueNum),
|
||||||
ipt6: ipt6,
|
table: config.Table,
|
||||||
|
connMarkAccept: int(config.ConnMarkAccept),
|
||||||
|
connMarkDrop: int(config.ConnMarkDrop),
|
||||||
|
ipt4: ipt4,
|
||||||
|
ipt6: ipt6,
|
||||||
protectedDialer: &net.Dialer{
|
protectedDialer: &net.Dialer{
|
||||||
Control: func(network, address string, c syscall.RawConn) error {
|
Control: func(network, address string, c syscall.RawConn) error {
|
||||||
var err error
|
var err error
|
||||||
cErr := c.Control(func(fd uintptr) {
|
cErr := c.Control(func(fd uintptr) {
|
||||||
err = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, nfqueueConnMarkAccept)
|
err = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, int(config.ConnMarkAccept))
|
||||||
})
|
})
|
||||||
if cErr != nil {
|
if cErr != nil {
|
||||||
return cErr
|
return cErr
|
||||||
@@ -189,6 +222,12 @@ func (n *nfqueuePacketIO) Register(ctx context.Context, cb PacketCallback) error
|
|||||||
streamID: ctIDFromCtBytes(*a.Ct),
|
streamID: ctIDFromCtBytes(*a.Ct),
|
||||||
data: *a.Payload,
|
data: *a.Payload,
|
||||||
}
|
}
|
||||||
|
// Use timestamp from attribute if available, otherwise use current time as fallback
|
||||||
|
if a.Timestamp != nil {
|
||||||
|
p.timestamp = *a.Timestamp
|
||||||
|
} else {
|
||||||
|
p.timestamp = time.Now()
|
||||||
|
}
|
||||||
return okBoolToInt(cb(p, nil))
|
return okBoolToInt(cb(p, nil))
|
||||||
},
|
},
|
||||||
func(e error) int {
|
func(e error) int {
|
||||||
@@ -205,9 +244,9 @@ func (n *nfqueuePacketIO) Register(ctx context.Context, cb PacketCallback) error
|
|||||||
}
|
}
|
||||||
if !n.rSet {
|
if !n.rSet {
|
||||||
if n.ipt4 != nil {
|
if n.ipt4 != nil {
|
||||||
err = n.setupIpt(n.local, n.rst, false)
|
err = n.setupIpt(false)
|
||||||
} else {
|
} else {
|
||||||
err = n.setupNft(n.local, n.rst, false)
|
err = n.setupNft(false)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -247,11 +286,11 @@ func (n *nfqueuePacketIO) SetVerdict(p Packet, v Verdict, newPacket []byte) erro
|
|||||||
case VerdictAcceptModify:
|
case VerdictAcceptModify:
|
||||||
return n.n.SetVerdictModPacket(nP.id, nfqueue.NfAccept, newPacket)
|
return n.n.SetVerdictModPacket(nP.id, nfqueue.NfAccept, newPacket)
|
||||||
case VerdictAcceptStream:
|
case VerdictAcceptStream:
|
||||||
return n.n.SetVerdictWithConnMark(nP.id, nfqueue.NfAccept, nfqueueConnMarkAccept)
|
return n.n.SetVerdictWithConnMark(nP.id, nfqueue.NfAccept, n.connMarkAccept)
|
||||||
case VerdictDrop:
|
case VerdictDrop:
|
||||||
return n.n.SetVerdict(nP.id, nfqueue.NfDrop)
|
return n.n.SetVerdict(nP.id, nfqueue.NfDrop)
|
||||||
case VerdictDropStream:
|
case VerdictDropStream:
|
||||||
return n.n.SetVerdictWithConnMark(nP.id, nfqueue.NfDrop, nfqueueConnMarkDrop)
|
return n.n.SetVerdictWithConnMark(nP.id, nfqueue.NfDrop, n.connMarkDrop)
|
||||||
default:
|
default:
|
||||||
// Invalid verdict, ignore for now
|
// Invalid verdict, ignore for now
|
||||||
return nil
|
return nil
|
||||||
@@ -265,26 +304,31 @@ func (n *nfqueuePacketIO) ProtectedDialContext(ctx context.Context, network, add
|
|||||||
func (n *nfqueuePacketIO) Close() error {
|
func (n *nfqueuePacketIO) Close() error {
|
||||||
if n.rSet {
|
if n.rSet {
|
||||||
if n.ipt4 != nil {
|
if n.ipt4 != nil {
|
||||||
_ = n.setupIpt(n.local, n.rst, true)
|
_ = n.setupIpt(true)
|
||||||
} else {
|
} else {
|
||||||
_ = n.setupNft(n.local, n.rst, true)
|
_ = n.setupNft(true)
|
||||||
}
|
}
|
||||||
n.rSet = false
|
n.rSet = false
|
||||||
}
|
}
|
||||||
return n.n.Close()
|
return n.n.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *nfqueuePacketIO) setupNft(local, rst, remove bool) error {
|
// nfqueue IO does not issue shutdown
|
||||||
rules, err := generateNftRules(local, rst)
|
func (n *nfqueuePacketIO) SetCancelFunc(cancelFunc context.CancelFunc) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *nfqueuePacketIO) setupNft(remove bool) error {
|
||||||
|
rules, err := n.generateNftRules()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
rulesText := rules.String()
|
rulesText := rules.String()
|
||||||
if remove {
|
if remove {
|
||||||
err = nftDelete(nftFamily, nftTable)
|
err = nftDelete(nftFamily, n.table)
|
||||||
} else {
|
} else {
|
||||||
// Delete first to make sure no leftover rules
|
// Delete first to make sure no leftover rules
|
||||||
_ = nftDelete(nftFamily, nftTable)
|
_ = nftDelete(nftFamily, n.table)
|
||||||
err = nftAdd(rulesText)
|
err = nftAdd(rulesText)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -293,8 +337,8 @@ func (n *nfqueuePacketIO) setupNft(local, rst, remove bool) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *nfqueuePacketIO) setupIpt(local, rst, remove bool) error {
|
func (n *nfqueuePacketIO) setupIpt(remove bool) error {
|
||||||
rules, err := generateIptRules(local, rst)
|
rules, err := n.generateIptRules()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -312,15 +356,20 @@ func (n *nfqueuePacketIO) setupIpt(local, rst, remove bool) error {
|
|||||||
var _ Packet = (*nfqueuePacket)(nil)
|
var _ Packet = (*nfqueuePacket)(nil)
|
||||||
|
|
||||||
type nfqueuePacket struct {
|
type nfqueuePacket struct {
|
||||||
id uint32
|
id uint32
|
||||||
streamID uint32
|
streamID uint32
|
||||||
data []byte
|
timestamp time.Time
|
||||||
|
data []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *nfqueuePacket) StreamID() uint32 {
|
func (p *nfqueuePacket) StreamID() uint32 {
|
||||||
return p.streamID
|
return p.streamID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *nfqueuePacket) Timestamp() time.Time {
|
||||||
|
return p.timestamp
|
||||||
|
}
|
||||||
|
|
||||||
func (p *nfqueuePacket) Data() []byte {
|
func (p *nfqueuePacket) Data() []byte {
|
||||||
return p.data
|
return p.data
|
||||||
}
|
}
|
||||||
|
|||||||
136
io/pcap.go
Normal file
136
io/pcap.go
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
package io
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"hash/crc32"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/gopacket"
|
||||||
|
"github.com/google/gopacket/pcapgo"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ PacketIO = (*pcapPacketIO)(nil)
|
||||||
|
|
||||||
|
type pcapPacketIO struct {
|
||||||
|
pcapFile io.ReadCloser
|
||||||
|
pcap *pcapgo.Reader
|
||||||
|
timeOffset *time.Duration
|
||||||
|
ioCancel context.CancelFunc
|
||||||
|
config PcapPacketIOConfig
|
||||||
|
|
||||||
|
dialer *net.Dialer
|
||||||
|
}
|
||||||
|
|
||||||
|
type PcapPacketIOConfig struct {
|
||||||
|
PcapFile string
|
||||||
|
Realtime bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPcapPacketIO(config PcapPacketIOConfig) (PacketIO, error) {
|
||||||
|
pcapFile, err := os.Open(config.PcapFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
handle, err := pcapgo.NewReader(pcapFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &pcapPacketIO{
|
||||||
|
pcapFile: pcapFile,
|
||||||
|
pcap: handle,
|
||||||
|
timeOffset: nil,
|
||||||
|
ioCancel: nil,
|
||||||
|
config: config,
|
||||||
|
dialer: &net.Dialer{},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *pcapPacketIO) Register(ctx context.Context, cb PacketCallback) error {
|
||||||
|
go func() {
|
||||||
|
packetSource := gopacket.NewPacketSource(p.pcap, p.pcap.LinkType())
|
||||||
|
for packet := range packetSource.Packets() {
|
||||||
|
p.wait(packet)
|
||||||
|
|
||||||
|
networkLayer := packet.NetworkLayer()
|
||||||
|
if networkLayer != nil {
|
||||||
|
src, dst := networkLayer.NetworkFlow().Endpoints()
|
||||||
|
endpoints := []string{src.String(), dst.String()}
|
||||||
|
sort.Strings(endpoints)
|
||||||
|
id := crc32.Checksum([]byte(strings.Join(endpoints, ",")), crc32.IEEETable)
|
||||||
|
|
||||||
|
cb(&pcapPacket{
|
||||||
|
streamID: id,
|
||||||
|
timestamp: packet.Metadata().Timestamp,
|
||||||
|
data: packet.LinkLayer().LayerPayload(),
|
||||||
|
}, nil)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Give the workers a chance to finish everything
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
// Stop the engine when all packets are finished
|
||||||
|
p.ioCancel()
|
||||||
|
}()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// A normal dialer is sufficient as pcap IO does not mess up with the networking
|
||||||
|
func (p *pcapPacketIO) ProtectedDialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
return p.dialer.DialContext(ctx, network, address)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *pcapPacketIO) SetVerdict(pkt Packet, v Verdict, newPacket []byte) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *pcapPacketIO) SetCancelFunc(cancelFunc context.CancelFunc) error {
|
||||||
|
p.ioCancel = cancelFunc
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *pcapPacketIO) Close() error {
|
||||||
|
return p.pcapFile.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Intentionally slow down the replay
|
||||||
|
// In realtime mode, this is to match the timestamps in the capture
|
||||||
|
func (p *pcapPacketIO) wait(packet gopacket.Packet) {
|
||||||
|
if !p.config.Realtime {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.timeOffset == nil {
|
||||||
|
offset := time.Since(packet.Metadata().Timestamp)
|
||||||
|
p.timeOffset = &offset
|
||||||
|
} else {
|
||||||
|
t := time.Until(packet.Metadata().Timestamp.Add(*p.timeOffset))
|
||||||
|
time.Sleep(t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ Packet = (*pcapPacket)(nil)
|
||||||
|
|
||||||
|
type pcapPacket struct {
|
||||||
|
streamID uint32
|
||||||
|
timestamp time.Time
|
||||||
|
data []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *pcapPacket) StreamID() uint32 {
|
||||||
|
return p.streamID
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *pcapPacket) Timestamp() time.Time {
|
||||||
|
return p.timestamp
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *pcapPacket) Data() []byte {
|
||||||
|
return p.data
|
||||||
|
}
|
||||||
@@ -20,7 +20,6 @@ import (
|
|||||||
"github.com/apernet/OpenGFW/analyzer"
|
"github.com/apernet/OpenGFW/analyzer"
|
||||||
"github.com/apernet/OpenGFW/modifier"
|
"github.com/apernet/OpenGFW/modifier"
|
||||||
"github.com/apernet/OpenGFW/ruleset/builtins"
|
"github.com/apernet/OpenGFW/ruleset/builtins"
|
||||||
"github.com/apernet/OpenGFW/ruleset/builtins/geo"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ExprRule is the external representation of an expression rule.
|
// ExprRule is the external representation of an expression rule.
|
||||||
@@ -302,23 +301,22 @@ type Function struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func buildFunctionMap(config *BuiltinConfig) map[string]*Function {
|
func buildFunctionMap(config *BuiltinConfig) map[string]*Function {
|
||||||
geoMatcher := geo.NewGeoMatcher(config.GeoSiteFilename, config.GeoIpFilename)
|
|
||||||
return map[string]*Function{
|
return map[string]*Function{
|
||||||
"geoip": {
|
"geoip": {
|
||||||
InitFunc: geoMatcher.LoadGeoIP,
|
InitFunc: config.GeoMatcher.LoadGeoIP,
|
||||||
PatchFunc: nil,
|
PatchFunc: nil,
|
||||||
Func: func(params ...any) (any, error) {
|
Func: func(params ...any) (any, error) {
|
||||||
return geoMatcher.MatchGeoIp(params[0].(string), params[1].(string)), nil
|
return config.GeoMatcher.MatchGeoIp(params[0].(string), params[1].(string)), nil
|
||||||
},
|
},
|
||||||
Types: []reflect.Type{reflect.TypeOf(geoMatcher.MatchGeoIp)},
|
Types: []reflect.Type{reflect.TypeOf(config.GeoMatcher.MatchGeoIp)},
|
||||||
},
|
},
|
||||||
"geosite": {
|
"geosite": {
|
||||||
InitFunc: geoMatcher.LoadGeoSite,
|
InitFunc: config.GeoMatcher.LoadGeoSite,
|
||||||
PatchFunc: nil,
|
PatchFunc: nil,
|
||||||
Func: func(params ...any) (any, error) {
|
Func: func(params ...any) (any, error) {
|
||||||
return geoMatcher.MatchGeoSite(params[0].(string), params[1].(string)), nil
|
return config.GeoMatcher.MatchGeoSite(params[0].(string), params[1].(string)), nil
|
||||||
},
|
},
|
||||||
Types: []reflect.Type{reflect.TypeOf(geoMatcher.MatchGeoSite)},
|
Types: []reflect.Type{reflect.TypeOf(config.GeoMatcher.MatchGeoSite)},
|
||||||
},
|
},
|
||||||
"cidr": {
|
"cidr": {
|
||||||
InitFunc: nil,
|
InitFunc: nil,
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
|
|
||||||
"github.com/apernet/OpenGFW/analyzer"
|
"github.com/apernet/OpenGFW/analyzer"
|
||||||
"github.com/apernet/OpenGFW/modifier"
|
"github.com/apernet/OpenGFW/modifier"
|
||||||
|
"github.com/apernet/OpenGFW/ruleset/builtins/geo"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Action int
|
type Action int
|
||||||
@@ -102,7 +103,6 @@ type Logger interface {
|
|||||||
|
|
||||||
type BuiltinConfig struct {
|
type BuiltinConfig struct {
|
||||||
Logger Logger
|
Logger Logger
|
||||||
GeoSiteFilename string
|
GeoMatcher *geo.GeoMatcher
|
||||||
GeoIpFilename string
|
|
||||||
ProtectedDialContext func(ctx context.Context, network, address string) (net.Conn, error)
|
ProtectedDialContext func(ctx context.Context, network, address string) (net.Conn, error)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user