Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5f447d4e31 | ||
|
|
347667a2bd | ||
|
|
393c29bd2d | ||
|
|
9c0893c512 | ||
|
|
ae34b4856a | ||
|
|
d7737e9211 | ||
|
|
dd9ecc3dd7 |
@@ -18,8 +18,8 @@ Telegram グループ: https://t.me/OpGFW
|
|||||||
## 特徴
|
## 特徴
|
||||||
|
|
||||||
- フル IP/TCP 再アセンブル、各種プロトコルアナライザー
|
- フル IP/TCP 再アセンブル、各種プロトコルアナライザー
|
||||||
- HTTP、TLS、QUIC、DNS、SSH、SOCKS4/5、WireGuard、その他多数
|
- HTTP、TLS、QUIC、DNS、SSH、SOCKS4/5、WireGuard、OpenVPN、その他多数
|
||||||
- Shadowsocks の「完全に暗号化されたトラフィック」の検出など (https://gfw.report/publications/usenixsecurity23/en/)
|
- Shadowsocks、VMess の「完全に暗号化されたトラフィック」の検出など (https://gfw.report/publications/usenixsecurity23/en/)
|
||||||
- Trojan プロキシプロトコルの検出
|
- Trojan プロキシプロトコルの検出
|
||||||
- [WIP] 機械学習に基づくトラフィック分類
|
- [WIP] 機械学習に基づくトラフィック分類
|
||||||
- IPv4 と IPv6 をフルサポート
|
- IPv4 と IPv6 をフルサポート
|
||||||
|
|||||||
@@ -21,8 +21,8 @@ Telegram group: https://t.me/OpGFW
|
|||||||
## Features
|
## Features
|
||||||
|
|
||||||
- Full IP/TCP reassembly, various protocol analyzers
|
- Full IP/TCP reassembly, various protocol analyzers
|
||||||
- HTTP, TLS, QUIC, DNS, SSH, SOCKS4/5, WireGuard, and many more to come
|
- HTTP, TLS, QUIC, DNS, SSH, SOCKS4/5, WireGuard, OpenVPN, and many more to come
|
||||||
- "Fully encrypted traffic" detection for Shadowsocks,
|
- "Fully encrypted traffic" detection for Shadowsocks, VMess,
|
||||||
etc. (https://gfw.report/publications/usenixsecurity23/en/)
|
etc. (https://gfw.report/publications/usenixsecurity23/en/)
|
||||||
- Trojan (proxy protocol) detection
|
- Trojan (proxy protocol) detection
|
||||||
- [WIP] Machine learning based traffic classification
|
- [WIP] Machine learning based traffic classification
|
||||||
|
|||||||
@@ -18,8 +18,8 @@ Telegram 群组: https://t.me/OpGFW
|
|||||||
## 功能
|
## 功能
|
||||||
|
|
||||||
- 完整的 IP/TCP 重组,各种协议解析器
|
- 完整的 IP/TCP 重组,各种协议解析器
|
||||||
- HTTP, TLS, QUIC, DNS, SSH, SOCKS4/5, WireGuard, 更多协议正在开发中
|
- HTTP, TLS, QUIC, DNS, SSH, SOCKS4/5, WireGuard, OpenVPN, 更多协议正在开发中
|
||||||
- Shadowsocks 等 "全加密流量" 检测 (https://gfw.report/publications/usenixsecurity23/zh/)
|
- Shadowsocks, VMess 等 "全加密流量" 检测 (https://gfw.report/publications/usenixsecurity23/zh/)
|
||||||
- Trojan 协议检测
|
- Trojan 协议检测
|
||||||
- [开发中] 基于机器学习的流量分类
|
- [开发中] 基于机器学习的流量分类
|
||||||
- 同等支持 IPv4 和 IPv6
|
- 同等支持 IPv4 和 IPv6
|
||||||
|
|||||||
74
cmd/root.go
74
cmd/root.go
@@ -7,6 +7,7 @@ import (
|
|||||||
"os/signal"
|
"os/signal"
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/apernet/OpenGFW/analyzer"
|
"github.com/apernet/OpenGFW/analyzer"
|
||||||
"github.com/apernet/OpenGFW/analyzer/tcp"
|
"github.com/apernet/OpenGFW/analyzer/tcp"
|
||||||
@@ -176,11 +177,12 @@ type cliConfigIO struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type cliConfigWorkers struct {
|
type cliConfigWorkers struct {
|
||||||
Count int `mapstructure:"count"`
|
Count int `mapstructure:"count"`
|
||||||
QueueSize int `mapstructure:"queueSize"`
|
QueueSize int `mapstructure:"queueSize"`
|
||||||
TCPMaxBufferedPagesTotal int `mapstructure:"tcpMaxBufferedPagesTotal"`
|
TCPMaxBufferedPagesTotal int `mapstructure:"tcpMaxBufferedPagesTotal"`
|
||||||
TCPMaxBufferedPagesPerConn int `mapstructure:"tcpMaxBufferedPagesPerConn"`
|
TCPMaxBufferedPagesPerConn int `mapstructure:"tcpMaxBufferedPagesPerConn"`
|
||||||
UDPMaxStreams int `mapstructure:"udpMaxStreams"`
|
TCPTimeout time.Duration `mapstructure:"tcpTimeout"`
|
||||||
|
UDPMaxStreams int `mapstructure:"udpMaxStreams"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type cliConfigRuleset struct {
|
type cliConfigRuleset struct {
|
||||||
@@ -204,7 +206,7 @@ func (c *cliConfig) fillIO(config *engine.Config) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return configError{Field: "io", Err: err}
|
return configError{Field: "io", Err: err}
|
||||||
}
|
}
|
||||||
config.IOs = []io.PacketIO{nfio}
|
config.IO = nfio
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -213,6 +215,7 @@ func (c *cliConfig) fillWorkers(config *engine.Config) error {
|
|||||||
config.WorkerQueueSize = c.Workers.QueueSize
|
config.WorkerQueueSize = c.Workers.QueueSize
|
||||||
config.WorkerTCPMaxBufferedPagesTotal = c.Workers.TCPMaxBufferedPagesTotal
|
config.WorkerTCPMaxBufferedPagesTotal = c.Workers.TCPMaxBufferedPagesTotal
|
||||||
config.WorkerTCPMaxBufferedPagesPerConn = c.Workers.TCPMaxBufferedPagesPerConn
|
config.WorkerTCPMaxBufferedPagesPerConn = c.Workers.TCPMaxBufferedPagesPerConn
|
||||||
|
config.WorkerTCPTimeout = c.Workers.TCPTimeout
|
||||||
config.WorkerUDPMaxStreams = c.Workers.UDPMaxStreams
|
config.WorkerUDPMaxStreams = c.Workers.UDPMaxStreams
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -247,12 +250,7 @@ func runMain(cmd *cobra.Command, args []string) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal("failed to parse config", zap.Error(err))
|
logger.Fatal("failed to parse config", zap.Error(err))
|
||||||
}
|
}
|
||||||
defer func() {
|
defer engineConfig.IO.Close() // Make sure to close IO on exit
|
||||||
// Make sure to close all IOs on exit
|
|
||||||
for _, i := range engineConfig.IOs {
|
|
||||||
_ = i.Close()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Ruleset
|
// Ruleset
|
||||||
rawRs, err := ruleset.ExprRulesFromYAML(args[0])
|
rawRs, err := ruleset.ExprRulesFromYAML(args[0])
|
||||||
@@ -260,9 +258,10 @@ func runMain(cmd *cobra.Command, args []string) {
|
|||||||
logger.Fatal("failed to load rules", zap.Error(err))
|
logger.Fatal("failed to load rules", zap.Error(err))
|
||||||
}
|
}
|
||||||
rsConfig := &ruleset.BuiltinConfig{
|
rsConfig := &ruleset.BuiltinConfig{
|
||||||
Logger: &rulesetLogger{},
|
Logger: &rulesetLogger{},
|
||||||
GeoSiteFilename: config.Ruleset.GeoSite,
|
GeoSiteFilename: config.Ruleset.GeoSite,
|
||||||
GeoIpFilename: config.Ruleset.GeoIp,
|
GeoIpFilename: config.Ruleset.GeoIp,
|
||||||
|
ProtectedDialContext: engineConfig.IO.ProtectedDialContext,
|
||||||
}
|
}
|
||||||
rs, err := ruleset.CompileExprRules(rawRs, analyzers, modifiers, rsConfig)
|
rs, err := ruleset.CompileExprRules(rawRs, analyzers, modifiers, rsConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -344,12 +343,26 @@ func (l *engineLogger) TCPStreamPropUpdate(info ruleset.StreamInfo, close bool)
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (l *engineLogger) TCPStreamAction(info ruleset.StreamInfo, action ruleset.Action, noMatch bool) {
|
func (l *engineLogger) TCPStreamAction(info ruleset.StreamInfo, action ruleset.Action, noMatch bool) {
|
||||||
logger.Info("TCP stream action",
|
if noMatch {
|
||||||
zap.Int64("id", info.ID),
|
logger.Debug("TCP stream no match",
|
||||||
zap.String("src", info.SrcString()),
|
zap.Int64("id", info.ID),
|
||||||
zap.String("dst", info.DstString()),
|
zap.String("src", info.SrcString()),
|
||||||
zap.String("action", action.String()),
|
zap.String("dst", info.DstString()),
|
||||||
zap.Bool("noMatch", noMatch))
|
zap.String("action", action.String()))
|
||||||
|
} else {
|
||||||
|
logger.Info("TCP stream action",
|
||||||
|
zap.Int64("id", info.ID),
|
||||||
|
zap.String("src", info.SrcString()),
|
||||||
|
zap.String("dst", info.DstString()),
|
||||||
|
zap.String("action", action.String()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *engineLogger) TCPFlush(workerID, flushed, closed int) {
|
||||||
|
logger.Debug("TCP flush",
|
||||||
|
zap.Int("workerID", workerID),
|
||||||
|
zap.Int("flushed", flushed),
|
||||||
|
zap.Int("closed", closed))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *engineLogger) UDPStreamNew(workerID int, info ruleset.StreamInfo) {
|
func (l *engineLogger) UDPStreamNew(workerID int, info ruleset.StreamInfo) {
|
||||||
@@ -370,12 +383,19 @@ func (l *engineLogger) UDPStreamPropUpdate(info ruleset.StreamInfo, close bool)
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (l *engineLogger) UDPStreamAction(info ruleset.StreamInfo, action ruleset.Action, noMatch bool) {
|
func (l *engineLogger) UDPStreamAction(info ruleset.StreamInfo, action ruleset.Action, noMatch bool) {
|
||||||
logger.Info("UDP stream action",
|
if noMatch {
|
||||||
zap.Int64("id", info.ID),
|
logger.Debug("UDP stream no match",
|
||||||
zap.String("src", info.SrcString()),
|
zap.Int64("id", info.ID),
|
||||||
zap.String("dst", info.DstString()),
|
zap.String("src", info.SrcString()),
|
||||||
zap.String("action", action.String()),
|
zap.String("dst", info.DstString()),
|
||||||
zap.Bool("noMatch", noMatch))
|
zap.String("action", action.String()))
|
||||||
|
} else {
|
||||||
|
logger.Info("UDP stream action",
|
||||||
|
zap.Int64("id", info.ID),
|
||||||
|
zap.String("src", info.SrcString()),
|
||||||
|
zap.String("dst", info.DstString()),
|
||||||
|
zap.String("action", action.String()))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *engineLogger) ModifyError(info ruleset.StreamInfo, err error) {
|
func (l *engineLogger) ModifyError(info ruleset.StreamInfo, err error) {
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ var _ Engine = (*engine)(nil)
|
|||||||
|
|
||||||
type engine struct {
|
type engine struct {
|
||||||
logger Logger
|
logger Logger
|
||||||
ioList []io.PacketIO
|
io io.PacketIO
|
||||||
workers []*worker
|
workers []*worker
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -34,6 +34,7 @@ func NewEngine(config Config) (Engine, error) {
|
|||||||
Ruleset: config.Ruleset,
|
Ruleset: config.Ruleset,
|
||||||
TCPMaxBufferedPagesTotal: config.WorkerTCPMaxBufferedPagesTotal,
|
TCPMaxBufferedPagesTotal: config.WorkerTCPMaxBufferedPagesTotal,
|
||||||
TCPMaxBufferedPagesPerConn: config.WorkerTCPMaxBufferedPagesPerConn,
|
TCPMaxBufferedPagesPerConn: config.WorkerTCPMaxBufferedPagesPerConn,
|
||||||
|
TCPTimeout: config.WorkerTCPTimeout,
|
||||||
UDPMaxStreams: config.WorkerUDPMaxStreams,
|
UDPMaxStreams: config.WorkerUDPMaxStreams,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -42,7 +43,7 @@ func NewEngine(config Config) (Engine, error) {
|
|||||||
}
|
}
|
||||||
return &engine{
|
return &engine{
|
||||||
logger: config.Logger,
|
logger: config.Logger,
|
||||||
ioList: config.IOs,
|
io: config.IO,
|
||||||
workers: workers,
|
workers: workers,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
@@ -58,27 +59,24 @@ func (e *engine) UpdateRuleset(r ruleset.Ruleset) error {
|
|||||||
|
|
||||||
func (e *engine) Run(ctx context.Context) error {
|
func (e *engine) Run(ctx context.Context) error {
|
||||||
ioCtx, ioCancel := context.WithCancel(ctx)
|
ioCtx, ioCancel := context.WithCancel(ctx)
|
||||||
defer ioCancel() // Stop workers & IOs
|
defer ioCancel() // Stop workers & IO
|
||||||
|
|
||||||
// Start workers
|
// Start workers
|
||||||
for _, w := range e.workers {
|
for _, w := range e.workers {
|
||||||
go w.Run(ioCtx)
|
go w.Run(ioCtx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register callbacks
|
// Register IO callback
|
||||||
errChan := make(chan error, len(e.ioList))
|
errChan := make(chan error, 1)
|
||||||
for _, i := range e.ioList {
|
err := e.io.Register(ioCtx, func(p io.Packet, err error) bool {
|
||||||
ioEntry := i // Make sure dispatch() uses the correct ioEntry
|
|
||||||
err := ioEntry.Register(ioCtx, func(p io.Packet, err error) bool {
|
|
||||||
if err != nil {
|
|
||||||
errChan <- err
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return e.dispatch(ioEntry, p)
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
errChan <- err
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
return e.dispatch(p)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Block until IO errors or context is cancelled
|
// Block until IO errors or context is cancelled
|
||||||
@@ -91,8 +89,7 @@ func (e *engine) Run(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// dispatch dispatches a packet to a worker.
|
// dispatch dispatches a packet to a worker.
|
||||||
// This must be safe for concurrent use, as it may be called from multiple IOs.
|
func (e *engine) dispatch(p io.Packet) bool {
|
||||||
func (e *engine) dispatch(ioEntry io.PacketIO, p io.Packet) bool {
|
|
||||||
data := p.Data()
|
data := p.Data()
|
||||||
ipVersion := data[0] >> 4
|
ipVersion := data[0] >> 4
|
||||||
var layerType gopacket.LayerType
|
var layerType gopacket.LayerType
|
||||||
@@ -102,7 +99,7 @@ func (e *engine) dispatch(ioEntry io.PacketIO, p io.Packet) bool {
|
|||||||
layerType = layers.LayerTypeIPv6
|
layerType = layers.LayerTypeIPv6
|
||||||
} else {
|
} else {
|
||||||
// Unsupported network layer
|
// Unsupported network layer
|
||||||
_ = ioEntry.SetVerdict(p, io.VerdictAcceptStream, nil)
|
_ = e.io.SetVerdict(p, io.VerdictAcceptStream, nil)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
// Load balance by stream ID
|
// Load balance by stream ID
|
||||||
@@ -112,7 +109,7 @@ func (e *engine) dispatch(ioEntry io.PacketIO, p io.Packet) bool {
|
|||||||
StreamID: p.StreamID(),
|
StreamID: p.StreamID(),
|
||||||
Packet: packet,
|
Packet: packet,
|
||||||
SetVerdict: func(v io.Verdict, b []byte) error {
|
SetVerdict: func(v io.Verdict, b []byte) error {
|
||||||
return ioEntry.SetVerdict(p, v, b)
|
return e.io.SetVerdict(p, v, b)
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
return true
|
return true
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package engine
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/apernet/OpenGFW/io"
|
"github.com/apernet/OpenGFW/io"
|
||||||
"github.com/apernet/OpenGFW/ruleset"
|
"github.com/apernet/OpenGFW/ruleset"
|
||||||
@@ -18,13 +19,14 @@ type Engine interface {
|
|||||||
// Config is the configuration for the engine.
|
// Config is the configuration for the engine.
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Logger Logger
|
Logger Logger
|
||||||
IOs []io.PacketIO
|
IO io.PacketIO
|
||||||
Ruleset ruleset.Ruleset
|
Ruleset ruleset.Ruleset
|
||||||
|
|
||||||
Workers int // Number of workers. Zero or negative means auto (number of CPU cores).
|
Workers int // Number of workers. Zero or negative means auto (number of CPU cores).
|
||||||
WorkerQueueSize int
|
WorkerQueueSize int
|
||||||
WorkerTCPMaxBufferedPagesTotal int
|
WorkerTCPMaxBufferedPagesTotal int
|
||||||
WorkerTCPMaxBufferedPagesPerConn int
|
WorkerTCPMaxBufferedPagesPerConn int
|
||||||
|
WorkerTCPTimeout time.Duration
|
||||||
WorkerUDPMaxStreams int
|
WorkerUDPMaxStreams int
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -36,6 +38,7 @@ type Logger interface {
|
|||||||
TCPStreamNew(workerID int, info ruleset.StreamInfo)
|
TCPStreamNew(workerID int, info ruleset.StreamInfo)
|
||||||
TCPStreamPropUpdate(info ruleset.StreamInfo, close bool)
|
TCPStreamPropUpdate(info ruleset.StreamInfo, close bool)
|
||||||
TCPStreamAction(info ruleset.StreamInfo, action ruleset.Action, noMatch bool)
|
TCPStreamAction(info ruleset.StreamInfo, action ruleset.Action, noMatch bool)
|
||||||
|
TCPFlush(workerID, flushed, closed int)
|
||||||
|
|
||||||
UDPStreamNew(workerID int, info ruleset.StreamInfo)
|
UDPStreamNew(workerID int, info ruleset.StreamInfo)
|
||||||
UDPStreamPropUpdate(info ruleset.StreamInfo, close bool)
|
UDPStreamPropUpdate(info ruleset.StreamInfo, close bool)
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package engine
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/apernet/OpenGFW/io"
|
"github.com/apernet/OpenGFW/io"
|
||||||
"github.com/apernet/OpenGFW/ruleset"
|
"github.com/apernet/OpenGFW/ruleset"
|
||||||
@@ -14,9 +15,12 @@ import (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
defaultChanSize = 64
|
defaultChanSize = 64
|
||||||
defaultTCPMaxBufferedPagesTotal = 4096
|
defaultTCPMaxBufferedPagesTotal = 65536
|
||||||
defaultTCPMaxBufferedPagesPerConnection = 64
|
defaultTCPMaxBufferedPagesPerConnection = 16
|
||||||
|
defaultTCPTimeout = 10 * time.Minute
|
||||||
defaultUDPMaxStreams = 4096
|
defaultUDPMaxStreams = 4096
|
||||||
|
|
||||||
|
tcpFlushInterval = 1 * time.Minute
|
||||||
)
|
)
|
||||||
|
|
||||||
type workerPacket struct {
|
type workerPacket struct {
|
||||||
@@ -33,6 +37,7 @@ type worker struct {
|
|||||||
tcpStreamFactory *tcpStreamFactory
|
tcpStreamFactory *tcpStreamFactory
|
||||||
tcpStreamPool *reassembly.StreamPool
|
tcpStreamPool *reassembly.StreamPool
|
||||||
tcpAssembler *reassembly.Assembler
|
tcpAssembler *reassembly.Assembler
|
||||||
|
tcpTimeout time.Duration
|
||||||
|
|
||||||
udpStreamFactory *udpStreamFactory
|
udpStreamFactory *udpStreamFactory
|
||||||
udpStreamManager *udpStreamManager
|
udpStreamManager *udpStreamManager
|
||||||
@@ -47,6 +52,7 @@ type workerConfig struct {
|
|||||||
Ruleset ruleset.Ruleset
|
Ruleset ruleset.Ruleset
|
||||||
TCPMaxBufferedPagesTotal int
|
TCPMaxBufferedPagesTotal int
|
||||||
TCPMaxBufferedPagesPerConn int
|
TCPMaxBufferedPagesPerConn int
|
||||||
|
TCPTimeout time.Duration
|
||||||
UDPMaxStreams int
|
UDPMaxStreams int
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -60,6 +66,9 @@ func (c *workerConfig) fillDefaults() {
|
|||||||
if c.TCPMaxBufferedPagesPerConn <= 0 {
|
if c.TCPMaxBufferedPagesPerConn <= 0 {
|
||||||
c.TCPMaxBufferedPagesPerConn = defaultTCPMaxBufferedPagesPerConnection
|
c.TCPMaxBufferedPagesPerConn = defaultTCPMaxBufferedPagesPerConnection
|
||||||
}
|
}
|
||||||
|
if c.TCPTimeout <= 0 {
|
||||||
|
c.TCPTimeout = defaultTCPTimeout
|
||||||
|
}
|
||||||
if c.UDPMaxStreams <= 0 {
|
if c.UDPMaxStreams <= 0 {
|
||||||
c.UDPMaxStreams = defaultUDPMaxStreams
|
c.UDPMaxStreams = defaultUDPMaxStreams
|
||||||
}
|
}
|
||||||
@@ -98,6 +107,7 @@ func newWorker(config workerConfig) (*worker, error) {
|
|||||||
tcpStreamFactory: tcpSF,
|
tcpStreamFactory: tcpSF,
|
||||||
tcpStreamPool: tcpStreamPool,
|
tcpStreamPool: tcpStreamPool,
|
||||||
tcpAssembler: tcpAssembler,
|
tcpAssembler: tcpAssembler,
|
||||||
|
tcpTimeout: config.TCPTimeout,
|
||||||
udpStreamFactory: udpSF,
|
udpStreamFactory: udpSF,
|
||||||
udpStreamManager: udpSM,
|
udpStreamManager: udpSM,
|
||||||
modSerializeBuffer: gopacket.NewSerializeBuffer(),
|
modSerializeBuffer: gopacket.NewSerializeBuffer(),
|
||||||
@@ -111,6 +121,10 @@ func (w *worker) Feed(p *workerPacket) {
|
|||||||
func (w *worker) Run(ctx context.Context) {
|
func (w *worker) Run(ctx context.Context) {
|
||||||
w.logger.WorkerStart(w.id)
|
w.logger.WorkerStart(w.id)
|
||||||
defer w.logger.WorkerStop(w.id)
|
defer w.logger.WorkerStop(w.id)
|
||||||
|
|
||||||
|
tcpFlushTicker := time.NewTicker(tcpFlushInterval)
|
||||||
|
defer tcpFlushTicker.Stop()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
@@ -122,6 +136,8 @@ func (w *worker) Run(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
v, b := w.handle(wPkt.StreamID, wPkt.Packet)
|
v, b := w.handle(wPkt.StreamID, wPkt.Packet)
|
||||||
_ = wPkt.SetVerdict(v, b)
|
_ = wPkt.SetVerdict(v, b)
|
||||||
|
case <-tcpFlushTicker.C:
|
||||||
|
w.flushTCP(w.tcpTimeout)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -176,6 +192,11 @@ func (w *worker) handleTCP(ipFlow gopacket.Flow, pMeta *gopacket.PacketMetadata,
|
|||||||
return io.Verdict(ctx.Verdict)
|
return io.Verdict(ctx.Verdict)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *worker) flushTCP(timeout time.Duration) {
|
||||||
|
flushed, closed := w.tcpAssembler.FlushCloseOlderThan(time.Now().Add(-timeout))
|
||||||
|
w.logger.TCPFlush(w.id, flushed, closed)
|
||||||
|
}
|
||||||
|
|
||||||
func (w *worker) handleUDP(streamID uint32, ipFlow gopacket.Flow, udp *layers.UDP) (io.Verdict, []byte) {
|
func (w *worker) handleUDP(streamID uint32, ipFlow gopacket.Flow, udp *layers.UDP) (io.Verdict, []byte) {
|
||||||
ctx := &udpContext{
|
ctx := &udpContext{
|
||||||
Verdict: udpVerdictAccept,
|
Verdict: udpVerdictAccept,
|
||||||
|
|||||||
2
go.mod
2
go.mod
@@ -5,7 +5,7 @@ go 1.21
|
|||||||
require (
|
require (
|
||||||
github.com/bwmarrin/snowflake v0.3.0
|
github.com/bwmarrin/snowflake v0.3.0
|
||||||
github.com/coreos/go-iptables v0.7.0
|
github.com/coreos/go-iptables v0.7.0
|
||||||
github.com/expr-lang/expr v1.15.7
|
github.com/expr-lang/expr v1.16.3
|
||||||
github.com/florianl/go-nfqueue v1.3.2-0.20231218173729-f2bdeb033acf
|
github.com/florianl/go-nfqueue v1.3.2-0.20231218173729-f2bdeb033acf
|
||||||
github.com/google/gopacket v1.1.20-0.20220810144506-32ee38206866
|
github.com/google/gopacket v1.1.20-0.20220810144506-32ee38206866
|
||||||
github.com/hashicorp/golang-lru/v2 v2.0.7
|
github.com/hashicorp/golang-lru/v2 v2.0.7
|
||||||
|
|||||||
4
go.sum
4
go.sum
@@ -7,8 +7,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs
|
|||||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
|
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
|
||||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/expr-lang/expr v1.15.7 h1:BK0JcWUkoW6nrbLBo6xCKhz4BvH5DSOOu1Gx5lucyZo=
|
github.com/expr-lang/expr v1.16.3 h1:NLldf786GffptcXNxxJx5dQ+FzeWDKChBDqOOwyK8to=
|
||||||
github.com/expr-lang/expr v1.15.7/go.mod h1:uCkhfG+x7fcZ5A5sXHKuQ07jGZRl6J0FCAaf2k4PtVQ=
|
github.com/expr-lang/expr v1.16.3/go.mod h1:uCkhfG+x7fcZ5A5sXHKuQ07jGZRl6J0FCAaf2k4PtVQ=
|
||||||
github.com/florianl/go-nfqueue v1.3.2-0.20231218173729-f2bdeb033acf h1:NqGS3vTHzVENbIfd87cXZwdpO6MB2R1PjHMJLi4Z3ow=
|
github.com/florianl/go-nfqueue v1.3.2-0.20231218173729-f2bdeb033acf h1:NqGS3vTHzVENbIfd87cXZwdpO6MB2R1PjHMJLi4Z3ow=
|
||||||
github.com/florianl/go-nfqueue v1.3.2-0.20231218173729-f2bdeb033acf/go.mod h1:eSnAor2YCfMCVYrVNEhkLGN/r1L+J4uDjc0EUy0tfq4=
|
github.com/florianl/go-nfqueue v1.3.2-0.20231218173729-f2bdeb033acf/go.mod h1:eSnAor2YCfMCVYrVNEhkLGN/r1L+J4uDjc0EUy0tfq4=
|
||||||
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
|
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package io
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"net"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Verdict int
|
type Verdict int
|
||||||
@@ -29,7 +30,6 @@ type Packet interface {
|
|||||||
|
|
||||||
// PacketCallback is called for each packet received.
|
// PacketCallback is called for each packet received.
|
||||||
// Return false to "unregister" and stop receiving packets.
|
// Return false to "unregister" and stop receiving packets.
|
||||||
// It must be safe for concurrent use.
|
|
||||||
type PacketCallback func(Packet, error) bool
|
type PacketCallback func(Packet, error) bool
|
||||||
|
|
||||||
type PacketIO interface {
|
type PacketIO interface {
|
||||||
@@ -39,6 +39,10 @@ type PacketIO interface {
|
|||||||
Register(context.Context, PacketCallback) error
|
Register(context.Context, PacketCallback) error
|
||||||
// SetVerdict sets the verdict for a packet.
|
// SetVerdict sets the verdict for a packet.
|
||||||
SetVerdict(Packet, Verdict, []byte) error
|
SetVerdict(Packet, Verdict, []byte) error
|
||||||
|
// ProtectedDialContext is like net.DialContext, but the connection is "protected"
|
||||||
|
// in the sense that the packets sent/received through the connection must bypass
|
||||||
|
// the packet IO and not be processed by the callback.
|
||||||
|
ProtectedDialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||||
// Close closes the packet IO.
|
// Close closes the packet IO.
|
||||||
Close() error
|
Close() error
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,9 +5,11 @@ import (
|
|||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
"github.com/coreos/go-iptables/iptables"
|
"github.com/coreos/go-iptables/iptables"
|
||||||
"github.com/florianl/go-nfqueue"
|
"github.com/florianl/go-nfqueue"
|
||||||
@@ -50,6 +52,7 @@ func generateNftRules(local, rst bool) (*nftTableSpec, error) {
|
|||||||
}
|
}
|
||||||
for i := range table.Chains {
|
for i := range table.Chains {
|
||||||
c := &table.Chains[i]
|
c := &table.Chains[i]
|
||||||
|
c.Rules = append(c.Rules, "meta mark $ACCEPT_CTMARK ct mark set $ACCEPT_CTMARK") // Bypass protected connections
|
||||||
c.Rules = append(c.Rules, "ct mark $ACCEPT_CTMARK counter accept")
|
c.Rules = append(c.Rules, "ct mark $ACCEPT_CTMARK counter accept")
|
||||||
if rst {
|
if rst {
|
||||||
c.Rules = append(c.Rules, "ip protocol tcp ct mark $DROP_CTMARK counter reject with tcp reset")
|
c.Rules = append(c.Rules, "ip protocol tcp ct mark $DROP_CTMARK counter reject with tcp reset")
|
||||||
@@ -72,6 +75,8 @@ func generateIptRules(local, rst bool) ([]iptRule, error) {
|
|||||||
}
|
}
|
||||||
rules := make([]iptRule, 0, 4*len(chains))
|
rules := make([]iptRule, 0, 4*len(chains))
|
||||||
for _, chain := range chains {
|
for _, chain := range chains {
|
||||||
|
// Bypass protected connections
|
||||||
|
rules = append(rules, iptRule{"filter", chain, []string{"-m", "mark", "--mark", strconv.Itoa(nfqueueConnMarkAccept), "-j", "CONNMARK", "--set-mark", strconv.Itoa(nfqueueConnMarkAccept)}})
|
||||||
rules = append(rules, iptRule{"filter", chain, []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkAccept), "-j", "ACCEPT"}})
|
rules = append(rules, iptRule{"filter", chain, []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkAccept), "-j", "ACCEPT"}})
|
||||||
if rst {
|
if rst {
|
||||||
rules = append(rules, iptRule{"filter", chain, []string{"-p", "tcp", "-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkDrop), "-j", "REJECT", "--reject-with", "tcp-reset"}})
|
rules = append(rules, iptRule{"filter", chain, []string{"-p", "tcp", "-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkDrop), "-j", "REJECT", "--reject-with", "tcp-reset"}})
|
||||||
@@ -96,6 +101,8 @@ type nfqueuePacketIO struct {
|
|||||||
// iptables not nil = use iptables instead of nftables
|
// iptables not nil = use iptables instead of nftables
|
||||||
ipt4 *iptables.IPTables
|
ipt4 *iptables.IPTables
|
||||||
ipt6 *iptables.IPTables
|
ipt6 *iptables.IPTables
|
||||||
|
|
||||||
|
protectedDialer *net.Dialer
|
||||||
}
|
}
|
||||||
|
|
||||||
type NFQueuePacketIOConfig struct {
|
type NFQueuePacketIOConfig struct {
|
||||||
@@ -153,6 +160,18 @@ func NewNFQueuePacketIO(config NFQueuePacketIOConfig) (PacketIO, error) {
|
|||||||
rst: config.RST,
|
rst: config.RST,
|
||||||
ipt4: ipt4,
|
ipt4: ipt4,
|
||||||
ipt6: ipt6,
|
ipt6: ipt6,
|
||||||
|
protectedDialer: &net.Dialer{
|
||||||
|
Control: func(network, address string, c syscall.RawConn) error {
|
||||||
|
var err error
|
||||||
|
cErr := c.Control(func(fd uintptr) {
|
||||||
|
err = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, nfqueueConnMarkAccept)
|
||||||
|
})
|
||||||
|
if cErr != nil {
|
||||||
|
return cErr
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
},
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -239,6 +258,10 @@ func (n *nfqueuePacketIO) SetVerdict(p Packet, v Verdict, newPacket []byte) erro
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (n *nfqueuePacketIO) ProtectedDialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
return n.protectedDialer.DialContext(ctx, network, address)
|
||||||
|
}
|
||||||
|
|
||||||
func (n *nfqueuePacketIO) Close() error {
|
func (n *nfqueuePacketIO) Close() error {
|
||||||
if n.rSet {
|
if n.rSet {
|
||||||
if n.ipt4 != nil {
|
if n.ipt4 != nil {
|
||||||
|
|||||||
@@ -14,14 +14,12 @@ type GeoMatcher struct {
|
|||||||
ipMatcherLock sync.Mutex
|
ipMatcherLock sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewGeoMatcher(geoSiteFilename, geoIpFilename string) (*GeoMatcher, error) {
|
func NewGeoMatcher(geoSiteFilename, geoIpFilename string) *GeoMatcher {
|
||||||
geoLoader := NewDefaultGeoLoader(geoSiteFilename, geoIpFilename)
|
|
||||||
|
|
||||||
return &GeoMatcher{
|
return &GeoMatcher{
|
||||||
geoLoader: geoLoader,
|
geoLoader: NewDefaultGeoLoader(geoSiteFilename, geoIpFilename),
|
||||||
geoSiteMatcher: make(map[string]hostMatcher),
|
geoSiteMatcher: make(map[string]hostMatcher),
|
||||||
geoIpMatcher: make(map[string]hostMatcher),
|
geoIpMatcher: make(map[string]hostMatcher),
|
||||||
}, nil
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GeoMatcher) MatchGeoIp(ip, condition string) bool {
|
func (g *GeoMatcher) MatchGeoIp(ip, condition string) bool {
|
||||||
|
|||||||
188
ruleset/expr.go
188
ruleset/expr.go
@@ -1,11 +1,15 @@
|
|||||||
package ruleset
|
package ruleset
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/expr-lang/expr/builtin"
|
||||||
|
|
||||||
"github.com/expr-lang/expr"
|
"github.com/expr-lang/expr"
|
||||||
"github.com/expr-lang/expr/ast"
|
"github.com/expr-lang/expr/ast"
|
||||||
@@ -55,10 +59,9 @@ type compiledExprRule struct {
|
|||||||
var _ Ruleset = (*exprRuleset)(nil)
|
var _ Ruleset = (*exprRuleset)(nil)
|
||||||
|
|
||||||
type exprRuleset struct {
|
type exprRuleset struct {
|
||||||
Rules []compiledExprRule
|
Rules []compiledExprRule
|
||||||
Ans []analyzer.Analyzer
|
Ans []analyzer.Analyzer
|
||||||
Logger Logger
|
Logger Logger
|
||||||
GeoMatcher *geo.GeoMatcher
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *exprRuleset) Analyzers(info StreamInfo) []analyzer.Analyzer {
|
func (r *exprRuleset) Analyzers(info StreamInfo) []analyzer.Analyzer {
|
||||||
@@ -100,10 +103,7 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier
|
|||||||
fullAnMap := analyzersToMap(ans)
|
fullAnMap := analyzersToMap(ans)
|
||||||
fullModMap := modifiersToMap(mods)
|
fullModMap := modifiersToMap(mods)
|
||||||
depAnMap := make(map[string]analyzer.Analyzer)
|
depAnMap := make(map[string]analyzer.Analyzer)
|
||||||
geoMatcher, err := geo.NewGeoMatcher(config.GeoSiteFilename, config.GeoIpFilename)
|
funcMap := buildFunctionMap(config)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
// Compile all rules and build a map of analyzers that are used by the rules.
|
// Compile all rules and build a map of analyzers that are used by the rules.
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if rule.Action == "" && !rule.Log {
|
if rule.Action == "" && !rule.Log {
|
||||||
@@ -118,13 +118,19 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier
|
|||||||
action = &a
|
action = &a
|
||||||
}
|
}
|
||||||
visitor := &idVisitor{Variables: make(map[string]bool), Identifiers: make(map[string]bool)}
|
visitor := &idVisitor{Variables: make(map[string]bool), Identifiers: make(map[string]bool)}
|
||||||
patcher := &idPatcher{}
|
patcher := &idPatcher{FuncMap: funcMap}
|
||||||
program, err := expr.Compile(rule.Expr,
|
program, err := expr.Compile(rule.Expr,
|
||||||
func(c *conf.Config) {
|
func(c *conf.Config) {
|
||||||
c.Strict = false
|
c.Strict = false
|
||||||
c.Expect = reflect.Bool
|
c.Expect = reflect.Bool
|
||||||
c.Visitors = append(c.Visitors, visitor, patcher)
|
c.Visitors = append(c.Visitors, visitor, patcher)
|
||||||
registerBuiltinFunctions(c.Functions, geoMatcher)
|
for name, f := range funcMap {
|
||||||
|
c.Functions[name] = &builtin.Function{
|
||||||
|
Name: name,
|
||||||
|
Func: f.Func,
|
||||||
|
Types: f.Types,
|
||||||
|
}
|
||||||
|
}
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -138,24 +144,15 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier
|
|||||||
if isBuiltInAnalyzer(name) || visitor.Variables[name] {
|
if isBuiltInAnalyzer(name) || visitor.Variables[name] {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
// Check if it's one of the built-in functions, and if so,
|
if f, ok := funcMap[name]; ok {
|
||||||
// skip it as an analyzer & do initialization if necessary.
|
// Built-in function, initialize if necessary
|
||||||
switch name {
|
if f.InitFunc != nil {
|
||||||
case "geoip":
|
if err := f.InitFunc(); err != nil {
|
||||||
if err := geoMatcher.LoadGeoIP(); err != nil {
|
return nil, fmt.Errorf("rule %q failed to initialize function %q: %w", rule.Name, name, err)
|
||||||
return nil, fmt.Errorf("rule %q failed to load geoip: %w", rule.Name, err)
|
}
|
||||||
}
|
|
||||||
case "geosite":
|
|
||||||
if err := geoMatcher.LoadGeoSite(); err != nil {
|
|
||||||
return nil, fmt.Errorf("rule %q failed to load geosite: %w", rule.Name, err)
|
|
||||||
}
|
|
||||||
case "cidr":
|
|
||||||
// No initialization needed for CIDR.
|
|
||||||
default:
|
|
||||||
a, ok := fullAnMap[name]
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("rule %q uses unknown analyzer %q", rule.Name, name)
|
|
||||||
}
|
}
|
||||||
|
} else if a, ok := fullAnMap[name]; ok {
|
||||||
|
// Analyzer, add to dependency map
|
||||||
depAnMap[name] = a
|
depAnMap[name] = a
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -184,37 +181,12 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier
|
|||||||
depAns = append(depAns, a)
|
depAns = append(depAns, a)
|
||||||
}
|
}
|
||||||
return &exprRuleset{
|
return &exprRuleset{
|
||||||
Rules: compiledRules,
|
Rules: compiledRules,
|
||||||
Ans: depAns,
|
Ans: depAns,
|
||||||
Logger: config.Logger,
|
Logger: config.Logger,
|
||||||
GeoMatcher: geoMatcher,
|
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func registerBuiltinFunctions(funcMap map[string]*ast.Function, geoMatcher *geo.GeoMatcher) {
|
|
||||||
funcMap["geoip"] = &ast.Function{
|
|
||||||
Name: "geoip",
|
|
||||||
Func: func(params ...any) (any, error) {
|
|
||||||
return geoMatcher.MatchGeoIp(params[0].(string), params[1].(string)), nil
|
|
||||||
},
|
|
||||||
Types: []reflect.Type{reflect.TypeOf(geoMatcher.MatchGeoIp)},
|
|
||||||
}
|
|
||||||
funcMap["geosite"] = &ast.Function{
|
|
||||||
Name: "geosite",
|
|
||||||
Func: func(params ...any) (any, error) {
|
|
||||||
return geoMatcher.MatchGeoSite(params[0].(string), params[1].(string)), nil
|
|
||||||
},
|
|
||||||
Types: []reflect.Type{reflect.TypeOf(geoMatcher.MatchGeoSite)},
|
|
||||||
}
|
|
||||||
funcMap["cidr"] = &ast.Function{
|
|
||||||
Name: "cidr",
|
|
||||||
Func: func(params ...any) (any, error) {
|
|
||||||
return builtins.MatchCIDR(params[0].(string), params[1].(*net.IPNet)), nil
|
|
||||||
},
|
|
||||||
Types: []reflect.Type{reflect.TypeOf((func(string, string) bool)(nil)), reflect.TypeOf(builtins.MatchCIDR)},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func streamInfoToExprEnv(info StreamInfo) map[string]interface{} {
|
func streamInfoToExprEnv(info StreamInfo) map[string]interface{} {
|
||||||
m := map[string]interface{}{
|
m := map[string]interface{}{
|
||||||
"id": info.ID,
|
"id": info.ID,
|
||||||
@@ -299,29 +271,109 @@ func (v *idVisitor) Visit(node *ast.Node) {
|
|||||||
// idPatcher patches the AST during expr compilation, replacing certain values with
|
// idPatcher patches the AST during expr compilation, replacing certain values with
|
||||||
// their internal representations for better runtime performance.
|
// their internal representations for better runtime performance.
|
||||||
type idPatcher struct {
|
type idPatcher struct {
|
||||||
Err error
|
FuncMap map[string]*Function
|
||||||
|
Err error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *idPatcher) Visit(node *ast.Node) {
|
func (p *idPatcher) Visit(node *ast.Node) {
|
||||||
switch (*node).(type) {
|
switch (*node).(type) {
|
||||||
case *ast.CallNode:
|
case *ast.CallNode:
|
||||||
callNode := (*node).(*ast.CallNode)
|
callNode := (*node).(*ast.CallNode)
|
||||||
if callNode.Func == nil {
|
if callNode.Callee == nil {
|
||||||
// Ignore invalid call nodes
|
// Ignore invalid call nodes
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
switch callNode.Func.Name {
|
if f, ok := p.FuncMap[callNode.Callee.String()]; ok {
|
||||||
case "cidr":
|
if f.PatchFunc != nil {
|
||||||
cidrStringNode, ok := callNode.Arguments[1].(*ast.StringNode)
|
if err := f.PatchFunc(&callNode.Arguments); err != nil {
|
||||||
if !ok {
|
p.Err = err
|
||||||
return
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
cidr, err := builtins.CompileCIDR(cidrStringNode.Value)
|
|
||||||
if err != nil {
|
|
||||||
p.Err = err
|
|
||||||
return
|
|
||||||
}
|
|
||||||
callNode.Arguments[1] = &ast.ConstantNode{Value: cidr}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Function struct {
|
||||||
|
InitFunc func() error
|
||||||
|
PatchFunc func(args *[]ast.Node) error
|
||||||
|
Func func(params ...any) (any, error)
|
||||||
|
Types []reflect.Type
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildFunctionMap(config *BuiltinConfig) map[string]*Function {
|
||||||
|
geoMatcher := geo.NewGeoMatcher(config.GeoSiteFilename, config.GeoIpFilename)
|
||||||
|
return map[string]*Function{
|
||||||
|
"geoip": {
|
||||||
|
InitFunc: geoMatcher.LoadGeoIP,
|
||||||
|
PatchFunc: nil,
|
||||||
|
Func: func(params ...any) (any, error) {
|
||||||
|
return geoMatcher.MatchGeoIp(params[0].(string), params[1].(string)), nil
|
||||||
|
},
|
||||||
|
Types: []reflect.Type{reflect.TypeOf(geoMatcher.MatchGeoIp)},
|
||||||
|
},
|
||||||
|
"geosite": {
|
||||||
|
InitFunc: geoMatcher.LoadGeoSite,
|
||||||
|
PatchFunc: nil,
|
||||||
|
Func: func(params ...any) (any, error) {
|
||||||
|
return geoMatcher.MatchGeoSite(params[0].(string), params[1].(string)), nil
|
||||||
|
},
|
||||||
|
Types: []reflect.Type{reflect.TypeOf(geoMatcher.MatchGeoSite)},
|
||||||
|
},
|
||||||
|
"cidr": {
|
||||||
|
InitFunc: nil,
|
||||||
|
PatchFunc: func(args *[]ast.Node) error {
|
||||||
|
cidrStringNode, ok := (*args)[1].(*ast.StringNode)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("cidr: invalid argument type")
|
||||||
|
}
|
||||||
|
cidr, err := builtins.CompileCIDR(cidrStringNode.Value)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
(*args)[1] = &ast.ConstantNode{Value: cidr}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
Func: func(params ...any) (any, error) {
|
||||||
|
return builtins.MatchCIDR(params[0].(string), params[1].(*net.IPNet)), nil
|
||||||
|
},
|
||||||
|
Types: []reflect.Type{reflect.TypeOf(builtins.MatchCIDR)},
|
||||||
|
},
|
||||||
|
"lookup": {
|
||||||
|
InitFunc: nil,
|
||||||
|
PatchFunc: func(args *[]ast.Node) error {
|
||||||
|
var serverStr *ast.StringNode
|
||||||
|
if len(*args) > 1 {
|
||||||
|
// Has the optional server argument
|
||||||
|
var ok bool
|
||||||
|
serverStr, ok = (*args)[1].(*ast.StringNode)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("lookup: invalid argument type")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
r := &net.Resolver{
|
||||||
|
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
if serverStr != nil {
|
||||||
|
address = serverStr.Value
|
||||||
|
}
|
||||||
|
return config.ProtectedDialContext(ctx, network, address)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if len(*args) > 1 {
|
||||||
|
(*args)[1] = &ast.ConstantNode{Value: r}
|
||||||
|
} else {
|
||||||
|
*args = append(*args, &ast.ConstantNode{Value: r})
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
Func: func(params ...any) (any, error) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
return params[1].(*net.Resolver).LookupHost(ctx, params[0].(string))
|
||||||
|
},
|
||||||
|
Types: []reflect.Type{
|
||||||
|
reflect.TypeOf((func(string, *net.Resolver) []string)(nil)),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package ruleset
|
package ruleset
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
@@ -100,7 +101,8 @@ type Logger interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type BuiltinConfig struct {
|
type BuiltinConfig struct {
|
||||||
Logger Logger
|
Logger Logger
|
||||||
GeoSiteFilename string
|
GeoSiteFilename string
|
||||||
GeoIpFilename string
|
GeoIpFilename string
|
||||||
|
ProtectedDialContext func(ctx context.Context, network, address string) (net.Conn, error)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user