7 Commits

Author SHA1 Message Date
Toby
5f447d4e31 Merge pull request #124 from apernet/wip-tcp-flush
feat: TCP timeout flush
2024-04-09 11:07:20 -07:00
Toby
347667a2bd feat: TCP timeout flush 2024-04-08 11:54:35 -07:00
Toby
393c29bd2d Merge pull request #123 from apernet/wip-lookup
feat: dns lookup function
2024-04-07 17:49:33 -07:00
Toby
9c0893c512 feat: added protected dial support, removed multi-IO support for simplicity 2024-04-06 14:42:45 -07:00
Toby
ae34b4856a feat: dns lookup function 2024-04-03 20:02:57 -07:00
Toby
d7737e9211 Merge pull request #119 from apernet/update-readme
docs: update readme feature list
2024-04-01 21:50:07 -07:00
Toby
dd9ecc3dd7 docs: update readme feature list 2024-04-01 21:49:06 -07:00
14 changed files with 255 additions and 135 deletions

View File

@@ -18,8 +18,8 @@ Telegram グループ: https://t.me/OpGFW
## 特徴 ## 特徴
- フル IP/TCP 再アセンブル、各種プロトコルアナライザー - フル IP/TCP 再アセンブル、各種プロトコルアナライザー
- HTTP、TLS、QUIC、DNS、SSH、SOCKS4/5、WireGuard、その他多数 - HTTP、TLS、QUIC、DNS、SSH、SOCKS4/5、WireGuard、OpenVPN、その他多数
- Shadowsocks の「完全に暗号化されたトラフィック」の検出など (https://gfw.report/publications/usenixsecurity23/en/) - Shadowsocks、VMess の「完全に暗号化されたトラフィック」の検出など (https://gfw.report/publications/usenixsecurity23/en/)
- Trojan プロキシプロトコルの検出 - Trojan プロキシプロトコルの検出
- [WIP] 機械学習に基づくトラフィック分類 - [WIP] 機械学習に基づくトラフィック分類
- IPv4 と IPv6 をフルサポート - IPv4 と IPv6 をフルサポート

View File

@@ -21,8 +21,8 @@ Telegram group: https://t.me/OpGFW
## Features ## Features
- Full IP/TCP reassembly, various protocol analyzers - Full IP/TCP reassembly, various protocol analyzers
- HTTP, TLS, QUIC, DNS, SSH, SOCKS4/5, WireGuard, and many more to come - HTTP, TLS, QUIC, DNS, SSH, SOCKS4/5, WireGuard, OpenVPN, and many more to come
- "Fully encrypted traffic" detection for Shadowsocks, - "Fully encrypted traffic" detection for Shadowsocks, VMess,
etc. (https://gfw.report/publications/usenixsecurity23/en/) etc. (https://gfw.report/publications/usenixsecurity23/en/)
- Trojan (proxy protocol) detection - Trojan (proxy protocol) detection
- [WIP] Machine learning based traffic classification - [WIP] Machine learning based traffic classification

View File

@@ -18,8 +18,8 @@ Telegram 群组: https://t.me/OpGFW
## 功能 ## 功能
- 完整的 IP/TCP 重组,各种协议解析器 - 完整的 IP/TCP 重组,各种协议解析器
- HTTP, TLS, QUIC, DNS, SSH, SOCKS4/5, WireGuard, 更多协议正在开发中 - HTTP, TLS, QUIC, DNS, SSH, SOCKS4/5, WireGuard, OpenVPN, 更多协议正在开发中
- Shadowsocks 等 "全加密流量" 检测 (https://gfw.report/publications/usenixsecurity23/zh/) - Shadowsocks, VMess 等 "全加密流量" 检测 (https://gfw.report/publications/usenixsecurity23/zh/)
- Trojan 协议检测 - Trojan 协议检测
- [开发中] 基于机器学习的流量分类 - [开发中] 基于机器学习的流量分类
- 同等支持 IPv4 和 IPv6 - 同等支持 IPv4 和 IPv6

View File

@@ -7,6 +7,7 @@ import (
"os/signal" "os/signal"
"strings" "strings"
"syscall" "syscall"
"time"
"github.com/apernet/OpenGFW/analyzer" "github.com/apernet/OpenGFW/analyzer"
"github.com/apernet/OpenGFW/analyzer/tcp" "github.com/apernet/OpenGFW/analyzer/tcp"
@@ -176,11 +177,12 @@ type cliConfigIO struct {
} }
type cliConfigWorkers struct { type cliConfigWorkers struct {
Count int `mapstructure:"count"` Count int `mapstructure:"count"`
QueueSize int `mapstructure:"queueSize"` QueueSize int `mapstructure:"queueSize"`
TCPMaxBufferedPagesTotal int `mapstructure:"tcpMaxBufferedPagesTotal"` TCPMaxBufferedPagesTotal int `mapstructure:"tcpMaxBufferedPagesTotal"`
TCPMaxBufferedPagesPerConn int `mapstructure:"tcpMaxBufferedPagesPerConn"` TCPMaxBufferedPagesPerConn int `mapstructure:"tcpMaxBufferedPagesPerConn"`
UDPMaxStreams int `mapstructure:"udpMaxStreams"` TCPTimeout time.Duration `mapstructure:"tcpTimeout"`
UDPMaxStreams int `mapstructure:"udpMaxStreams"`
} }
type cliConfigRuleset struct { type cliConfigRuleset struct {
@@ -204,7 +206,7 @@ func (c *cliConfig) fillIO(config *engine.Config) error {
if err != nil { if err != nil {
return configError{Field: "io", Err: err} return configError{Field: "io", Err: err}
} }
config.IOs = []io.PacketIO{nfio} config.IO = nfio
return nil return nil
} }
@@ -213,6 +215,7 @@ func (c *cliConfig) fillWorkers(config *engine.Config) error {
config.WorkerQueueSize = c.Workers.QueueSize config.WorkerQueueSize = c.Workers.QueueSize
config.WorkerTCPMaxBufferedPagesTotal = c.Workers.TCPMaxBufferedPagesTotal config.WorkerTCPMaxBufferedPagesTotal = c.Workers.TCPMaxBufferedPagesTotal
config.WorkerTCPMaxBufferedPagesPerConn = c.Workers.TCPMaxBufferedPagesPerConn config.WorkerTCPMaxBufferedPagesPerConn = c.Workers.TCPMaxBufferedPagesPerConn
config.WorkerTCPTimeout = c.Workers.TCPTimeout
config.WorkerUDPMaxStreams = c.Workers.UDPMaxStreams config.WorkerUDPMaxStreams = c.Workers.UDPMaxStreams
return nil return nil
} }
@@ -247,12 +250,7 @@ func runMain(cmd *cobra.Command, args []string) {
if err != nil { if err != nil {
logger.Fatal("failed to parse config", zap.Error(err)) logger.Fatal("failed to parse config", zap.Error(err))
} }
defer func() { defer engineConfig.IO.Close() // Make sure to close IO on exit
// Make sure to close all IOs on exit
for _, i := range engineConfig.IOs {
_ = i.Close()
}
}()
// Ruleset // Ruleset
rawRs, err := ruleset.ExprRulesFromYAML(args[0]) rawRs, err := ruleset.ExprRulesFromYAML(args[0])
@@ -260,9 +258,10 @@ func runMain(cmd *cobra.Command, args []string) {
logger.Fatal("failed to load rules", zap.Error(err)) logger.Fatal("failed to load rules", zap.Error(err))
} }
rsConfig := &ruleset.BuiltinConfig{ rsConfig := &ruleset.BuiltinConfig{
Logger: &rulesetLogger{}, Logger: &rulesetLogger{},
GeoSiteFilename: config.Ruleset.GeoSite, GeoSiteFilename: config.Ruleset.GeoSite,
GeoIpFilename: config.Ruleset.GeoIp, GeoIpFilename: config.Ruleset.GeoIp,
ProtectedDialContext: engineConfig.IO.ProtectedDialContext,
} }
rs, err := ruleset.CompileExprRules(rawRs, analyzers, modifiers, rsConfig) rs, err := ruleset.CompileExprRules(rawRs, analyzers, modifiers, rsConfig)
if err != nil { if err != nil {
@@ -344,12 +343,26 @@ func (l *engineLogger) TCPStreamPropUpdate(info ruleset.StreamInfo, close bool)
} }
func (l *engineLogger) TCPStreamAction(info ruleset.StreamInfo, action ruleset.Action, noMatch bool) { func (l *engineLogger) TCPStreamAction(info ruleset.StreamInfo, action ruleset.Action, noMatch bool) {
logger.Info("TCP stream action", if noMatch {
zap.Int64("id", info.ID), logger.Debug("TCP stream no match",
zap.String("src", info.SrcString()), zap.Int64("id", info.ID),
zap.String("dst", info.DstString()), zap.String("src", info.SrcString()),
zap.String("action", action.String()), zap.String("dst", info.DstString()),
zap.Bool("noMatch", noMatch)) zap.String("action", action.String()))
} else {
logger.Info("TCP stream action",
zap.Int64("id", info.ID),
zap.String("src", info.SrcString()),
zap.String("dst", info.DstString()),
zap.String("action", action.String()))
}
}
func (l *engineLogger) TCPFlush(workerID, flushed, closed int) {
logger.Debug("TCP flush",
zap.Int("workerID", workerID),
zap.Int("flushed", flushed),
zap.Int("closed", closed))
} }
func (l *engineLogger) UDPStreamNew(workerID int, info ruleset.StreamInfo) { func (l *engineLogger) UDPStreamNew(workerID int, info ruleset.StreamInfo) {
@@ -370,12 +383,19 @@ func (l *engineLogger) UDPStreamPropUpdate(info ruleset.StreamInfo, close bool)
} }
func (l *engineLogger) UDPStreamAction(info ruleset.StreamInfo, action ruleset.Action, noMatch bool) { func (l *engineLogger) UDPStreamAction(info ruleset.StreamInfo, action ruleset.Action, noMatch bool) {
logger.Info("UDP stream action", if noMatch {
zap.Int64("id", info.ID), logger.Debug("UDP stream no match",
zap.String("src", info.SrcString()), zap.Int64("id", info.ID),
zap.String("dst", info.DstString()), zap.String("src", info.SrcString()),
zap.String("action", action.String()), zap.String("dst", info.DstString()),
zap.Bool("noMatch", noMatch)) zap.String("action", action.String()))
} else {
logger.Info("UDP stream action",
zap.Int64("id", info.ID),
zap.String("src", info.SrcString()),
zap.String("dst", info.DstString()),
zap.String("action", action.String()))
}
} }
func (l *engineLogger) ModifyError(info ruleset.StreamInfo, err error) { func (l *engineLogger) ModifyError(info ruleset.StreamInfo, err error) {

View File

@@ -15,7 +15,7 @@ var _ Engine = (*engine)(nil)
type engine struct { type engine struct {
logger Logger logger Logger
ioList []io.PacketIO io io.PacketIO
workers []*worker workers []*worker
} }
@@ -34,6 +34,7 @@ func NewEngine(config Config) (Engine, error) {
Ruleset: config.Ruleset, Ruleset: config.Ruleset,
TCPMaxBufferedPagesTotal: config.WorkerTCPMaxBufferedPagesTotal, TCPMaxBufferedPagesTotal: config.WorkerTCPMaxBufferedPagesTotal,
TCPMaxBufferedPagesPerConn: config.WorkerTCPMaxBufferedPagesPerConn, TCPMaxBufferedPagesPerConn: config.WorkerTCPMaxBufferedPagesPerConn,
TCPTimeout: config.WorkerTCPTimeout,
UDPMaxStreams: config.WorkerUDPMaxStreams, UDPMaxStreams: config.WorkerUDPMaxStreams,
}) })
if err != nil { if err != nil {
@@ -42,7 +43,7 @@ func NewEngine(config Config) (Engine, error) {
} }
return &engine{ return &engine{
logger: config.Logger, logger: config.Logger,
ioList: config.IOs, io: config.IO,
workers: workers, workers: workers,
}, nil }, nil
} }
@@ -58,27 +59,24 @@ func (e *engine) UpdateRuleset(r ruleset.Ruleset) error {
func (e *engine) Run(ctx context.Context) error { func (e *engine) Run(ctx context.Context) error {
ioCtx, ioCancel := context.WithCancel(ctx) ioCtx, ioCancel := context.WithCancel(ctx)
defer ioCancel() // Stop workers & IOs defer ioCancel() // Stop workers & IO
// Start workers // Start workers
for _, w := range e.workers { for _, w := range e.workers {
go w.Run(ioCtx) go w.Run(ioCtx)
} }
// Register callbacks // Register IO callback
errChan := make(chan error, len(e.ioList)) errChan := make(chan error, 1)
for _, i := range e.ioList { err := e.io.Register(ioCtx, func(p io.Packet, err error) bool {
ioEntry := i // Make sure dispatch() uses the correct ioEntry
err := ioEntry.Register(ioCtx, func(p io.Packet, err error) bool {
if err != nil {
errChan <- err
return false
}
return e.dispatch(ioEntry, p)
})
if err != nil { if err != nil {
return err errChan <- err
return false
} }
return e.dispatch(p)
})
if err != nil {
return err
} }
// Block until IO errors or context is cancelled // Block until IO errors or context is cancelled
@@ -91,8 +89,7 @@ func (e *engine) Run(ctx context.Context) error {
} }
// dispatch dispatches a packet to a worker. // dispatch dispatches a packet to a worker.
// This must be safe for concurrent use, as it may be called from multiple IOs. func (e *engine) dispatch(p io.Packet) bool {
func (e *engine) dispatch(ioEntry io.PacketIO, p io.Packet) bool {
data := p.Data() data := p.Data()
ipVersion := data[0] >> 4 ipVersion := data[0] >> 4
var layerType gopacket.LayerType var layerType gopacket.LayerType
@@ -102,7 +99,7 @@ func (e *engine) dispatch(ioEntry io.PacketIO, p io.Packet) bool {
layerType = layers.LayerTypeIPv6 layerType = layers.LayerTypeIPv6
} else { } else {
// Unsupported network layer // Unsupported network layer
_ = ioEntry.SetVerdict(p, io.VerdictAcceptStream, nil) _ = e.io.SetVerdict(p, io.VerdictAcceptStream, nil)
return true return true
} }
// Load balance by stream ID // Load balance by stream ID
@@ -112,7 +109,7 @@ func (e *engine) dispatch(ioEntry io.PacketIO, p io.Packet) bool {
StreamID: p.StreamID(), StreamID: p.StreamID(),
Packet: packet, Packet: packet,
SetVerdict: func(v io.Verdict, b []byte) error { SetVerdict: func(v io.Verdict, b []byte) error {
return ioEntry.SetVerdict(p, v, b) return e.io.SetVerdict(p, v, b)
}, },
}) })
return true return true

View File

@@ -2,6 +2,7 @@ package engine
import ( import (
"context" "context"
"time"
"github.com/apernet/OpenGFW/io" "github.com/apernet/OpenGFW/io"
"github.com/apernet/OpenGFW/ruleset" "github.com/apernet/OpenGFW/ruleset"
@@ -18,13 +19,14 @@ type Engine interface {
// Config is the configuration for the engine. // Config is the configuration for the engine.
type Config struct { type Config struct {
Logger Logger Logger Logger
IOs []io.PacketIO IO io.PacketIO
Ruleset ruleset.Ruleset Ruleset ruleset.Ruleset
Workers int // Number of workers. Zero or negative means auto (number of CPU cores). Workers int // Number of workers. Zero or negative means auto (number of CPU cores).
WorkerQueueSize int WorkerQueueSize int
WorkerTCPMaxBufferedPagesTotal int WorkerTCPMaxBufferedPagesTotal int
WorkerTCPMaxBufferedPagesPerConn int WorkerTCPMaxBufferedPagesPerConn int
WorkerTCPTimeout time.Duration
WorkerUDPMaxStreams int WorkerUDPMaxStreams int
} }
@@ -36,6 +38,7 @@ type Logger interface {
TCPStreamNew(workerID int, info ruleset.StreamInfo) TCPStreamNew(workerID int, info ruleset.StreamInfo)
TCPStreamPropUpdate(info ruleset.StreamInfo, close bool) TCPStreamPropUpdate(info ruleset.StreamInfo, close bool)
TCPStreamAction(info ruleset.StreamInfo, action ruleset.Action, noMatch bool) TCPStreamAction(info ruleset.StreamInfo, action ruleset.Action, noMatch bool)
TCPFlush(workerID, flushed, closed int)
UDPStreamNew(workerID int, info ruleset.StreamInfo) UDPStreamNew(workerID int, info ruleset.StreamInfo)
UDPStreamPropUpdate(info ruleset.StreamInfo, close bool) UDPStreamPropUpdate(info ruleset.StreamInfo, close bool)

View File

@@ -2,6 +2,7 @@ package engine
import ( import (
"context" "context"
"time"
"github.com/apernet/OpenGFW/io" "github.com/apernet/OpenGFW/io"
"github.com/apernet/OpenGFW/ruleset" "github.com/apernet/OpenGFW/ruleset"
@@ -14,9 +15,12 @@ import (
const ( const (
defaultChanSize = 64 defaultChanSize = 64
defaultTCPMaxBufferedPagesTotal = 4096 defaultTCPMaxBufferedPagesTotal = 65536
defaultTCPMaxBufferedPagesPerConnection = 64 defaultTCPMaxBufferedPagesPerConnection = 16
defaultTCPTimeout = 10 * time.Minute
defaultUDPMaxStreams = 4096 defaultUDPMaxStreams = 4096
tcpFlushInterval = 1 * time.Minute
) )
type workerPacket struct { type workerPacket struct {
@@ -33,6 +37,7 @@ type worker struct {
tcpStreamFactory *tcpStreamFactory tcpStreamFactory *tcpStreamFactory
tcpStreamPool *reassembly.StreamPool tcpStreamPool *reassembly.StreamPool
tcpAssembler *reassembly.Assembler tcpAssembler *reassembly.Assembler
tcpTimeout time.Duration
udpStreamFactory *udpStreamFactory udpStreamFactory *udpStreamFactory
udpStreamManager *udpStreamManager udpStreamManager *udpStreamManager
@@ -47,6 +52,7 @@ type workerConfig struct {
Ruleset ruleset.Ruleset Ruleset ruleset.Ruleset
TCPMaxBufferedPagesTotal int TCPMaxBufferedPagesTotal int
TCPMaxBufferedPagesPerConn int TCPMaxBufferedPagesPerConn int
TCPTimeout time.Duration
UDPMaxStreams int UDPMaxStreams int
} }
@@ -60,6 +66,9 @@ func (c *workerConfig) fillDefaults() {
if c.TCPMaxBufferedPagesPerConn <= 0 { if c.TCPMaxBufferedPagesPerConn <= 0 {
c.TCPMaxBufferedPagesPerConn = defaultTCPMaxBufferedPagesPerConnection c.TCPMaxBufferedPagesPerConn = defaultTCPMaxBufferedPagesPerConnection
} }
if c.TCPTimeout <= 0 {
c.TCPTimeout = defaultTCPTimeout
}
if c.UDPMaxStreams <= 0 { if c.UDPMaxStreams <= 0 {
c.UDPMaxStreams = defaultUDPMaxStreams c.UDPMaxStreams = defaultUDPMaxStreams
} }
@@ -98,6 +107,7 @@ func newWorker(config workerConfig) (*worker, error) {
tcpStreamFactory: tcpSF, tcpStreamFactory: tcpSF,
tcpStreamPool: tcpStreamPool, tcpStreamPool: tcpStreamPool,
tcpAssembler: tcpAssembler, tcpAssembler: tcpAssembler,
tcpTimeout: config.TCPTimeout,
udpStreamFactory: udpSF, udpStreamFactory: udpSF,
udpStreamManager: udpSM, udpStreamManager: udpSM,
modSerializeBuffer: gopacket.NewSerializeBuffer(), modSerializeBuffer: gopacket.NewSerializeBuffer(),
@@ -111,6 +121,10 @@ func (w *worker) Feed(p *workerPacket) {
func (w *worker) Run(ctx context.Context) { func (w *worker) Run(ctx context.Context) {
w.logger.WorkerStart(w.id) w.logger.WorkerStart(w.id)
defer w.logger.WorkerStop(w.id) defer w.logger.WorkerStop(w.id)
tcpFlushTicker := time.NewTicker(tcpFlushInterval)
defer tcpFlushTicker.Stop()
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
@@ -122,6 +136,8 @@ func (w *worker) Run(ctx context.Context) {
} }
v, b := w.handle(wPkt.StreamID, wPkt.Packet) v, b := w.handle(wPkt.StreamID, wPkt.Packet)
_ = wPkt.SetVerdict(v, b) _ = wPkt.SetVerdict(v, b)
case <-tcpFlushTicker.C:
w.flushTCP(w.tcpTimeout)
} }
} }
} }
@@ -176,6 +192,11 @@ func (w *worker) handleTCP(ipFlow gopacket.Flow, pMeta *gopacket.PacketMetadata,
return io.Verdict(ctx.Verdict) return io.Verdict(ctx.Verdict)
} }
func (w *worker) flushTCP(timeout time.Duration) {
flushed, closed := w.tcpAssembler.FlushCloseOlderThan(time.Now().Add(-timeout))
w.logger.TCPFlush(w.id, flushed, closed)
}
func (w *worker) handleUDP(streamID uint32, ipFlow gopacket.Flow, udp *layers.UDP) (io.Verdict, []byte) { func (w *worker) handleUDP(streamID uint32, ipFlow gopacket.Flow, udp *layers.UDP) (io.Verdict, []byte) {
ctx := &udpContext{ ctx := &udpContext{
Verdict: udpVerdictAccept, Verdict: udpVerdictAccept,

2
go.mod
View File

@@ -5,7 +5,7 @@ go 1.21
require ( require (
github.com/bwmarrin/snowflake v0.3.0 github.com/bwmarrin/snowflake v0.3.0
github.com/coreos/go-iptables v0.7.0 github.com/coreos/go-iptables v0.7.0
github.com/expr-lang/expr v1.15.7 github.com/expr-lang/expr v1.16.3
github.com/florianl/go-nfqueue v1.3.2-0.20231218173729-f2bdeb033acf github.com/florianl/go-nfqueue v1.3.2-0.20231218173729-f2bdeb033acf
github.com/google/gopacket v1.1.20-0.20220810144506-32ee38206866 github.com/google/gopacket v1.1.20-0.20220810144506-32ee38206866
github.com/hashicorp/golang-lru/v2 v2.0.7 github.com/hashicorp/golang-lru/v2 v2.0.7

4
go.sum
View File

@@ -7,8 +7,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/expr-lang/expr v1.15.7 h1:BK0JcWUkoW6nrbLBo6xCKhz4BvH5DSOOu1Gx5lucyZo= github.com/expr-lang/expr v1.16.3 h1:NLldf786GffptcXNxxJx5dQ+FzeWDKChBDqOOwyK8to=
github.com/expr-lang/expr v1.15.7/go.mod h1:uCkhfG+x7fcZ5A5sXHKuQ07jGZRl6J0FCAaf2k4PtVQ= github.com/expr-lang/expr v1.16.3/go.mod h1:uCkhfG+x7fcZ5A5sXHKuQ07jGZRl6J0FCAaf2k4PtVQ=
github.com/florianl/go-nfqueue v1.3.2-0.20231218173729-f2bdeb033acf h1:NqGS3vTHzVENbIfd87cXZwdpO6MB2R1PjHMJLi4Z3ow= github.com/florianl/go-nfqueue v1.3.2-0.20231218173729-f2bdeb033acf h1:NqGS3vTHzVENbIfd87cXZwdpO6MB2R1PjHMJLi4Z3ow=
github.com/florianl/go-nfqueue v1.3.2-0.20231218173729-f2bdeb033acf/go.mod h1:eSnAor2YCfMCVYrVNEhkLGN/r1L+J4uDjc0EUy0tfq4= github.com/florianl/go-nfqueue v1.3.2-0.20231218173729-f2bdeb033acf/go.mod h1:eSnAor2YCfMCVYrVNEhkLGN/r1L+J4uDjc0EUy0tfq4=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=

View File

@@ -2,6 +2,7 @@ package io
import ( import (
"context" "context"
"net"
) )
type Verdict int type Verdict int
@@ -29,7 +30,6 @@ type Packet interface {
// PacketCallback is called for each packet received. // PacketCallback is called for each packet received.
// Return false to "unregister" and stop receiving packets. // Return false to "unregister" and stop receiving packets.
// It must be safe for concurrent use.
type PacketCallback func(Packet, error) bool type PacketCallback func(Packet, error) bool
type PacketIO interface { type PacketIO interface {
@@ -39,6 +39,10 @@ type PacketIO interface {
Register(context.Context, PacketCallback) error Register(context.Context, PacketCallback) error
// SetVerdict sets the verdict for a packet. // SetVerdict sets the verdict for a packet.
SetVerdict(Packet, Verdict, []byte) error SetVerdict(Packet, Verdict, []byte) error
// ProtectedDialContext is like net.DialContext, but the connection is "protected"
// in the sense that the packets sent/received through the connection must bypass
// the packet IO and not be processed by the callback.
ProtectedDialContext(ctx context.Context, network, address string) (net.Conn, error)
// Close closes the packet IO. // Close closes the packet IO.
Close() error Close() error
} }

View File

@@ -5,9 +5,11 @@ import (
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
"net"
"os/exec" "os/exec"
"strconv" "strconv"
"strings" "strings"
"syscall"
"github.com/coreos/go-iptables/iptables" "github.com/coreos/go-iptables/iptables"
"github.com/florianl/go-nfqueue" "github.com/florianl/go-nfqueue"
@@ -50,6 +52,7 @@ func generateNftRules(local, rst bool) (*nftTableSpec, error) {
} }
for i := range table.Chains { for i := range table.Chains {
c := &table.Chains[i] c := &table.Chains[i]
c.Rules = append(c.Rules, "meta mark $ACCEPT_CTMARK ct mark set $ACCEPT_CTMARK") // Bypass protected connections
c.Rules = append(c.Rules, "ct mark $ACCEPT_CTMARK counter accept") c.Rules = append(c.Rules, "ct mark $ACCEPT_CTMARK counter accept")
if rst { if rst {
c.Rules = append(c.Rules, "ip protocol tcp ct mark $DROP_CTMARK counter reject with tcp reset") c.Rules = append(c.Rules, "ip protocol tcp ct mark $DROP_CTMARK counter reject with tcp reset")
@@ -72,6 +75,8 @@ func generateIptRules(local, rst bool) ([]iptRule, error) {
} }
rules := make([]iptRule, 0, 4*len(chains)) rules := make([]iptRule, 0, 4*len(chains))
for _, chain := range chains { for _, chain := range chains {
// Bypass protected connections
rules = append(rules, iptRule{"filter", chain, []string{"-m", "mark", "--mark", strconv.Itoa(nfqueueConnMarkAccept), "-j", "CONNMARK", "--set-mark", strconv.Itoa(nfqueueConnMarkAccept)}})
rules = append(rules, iptRule{"filter", chain, []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkAccept), "-j", "ACCEPT"}}) rules = append(rules, iptRule{"filter", chain, []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkAccept), "-j", "ACCEPT"}})
if rst { if rst {
rules = append(rules, iptRule{"filter", chain, []string{"-p", "tcp", "-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkDrop), "-j", "REJECT", "--reject-with", "tcp-reset"}}) rules = append(rules, iptRule{"filter", chain, []string{"-p", "tcp", "-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkDrop), "-j", "REJECT", "--reject-with", "tcp-reset"}})
@@ -96,6 +101,8 @@ type nfqueuePacketIO struct {
// iptables not nil = use iptables instead of nftables // iptables not nil = use iptables instead of nftables
ipt4 *iptables.IPTables ipt4 *iptables.IPTables
ipt6 *iptables.IPTables ipt6 *iptables.IPTables
protectedDialer *net.Dialer
} }
type NFQueuePacketIOConfig struct { type NFQueuePacketIOConfig struct {
@@ -153,6 +160,18 @@ func NewNFQueuePacketIO(config NFQueuePacketIOConfig) (PacketIO, error) {
rst: config.RST, rst: config.RST,
ipt4: ipt4, ipt4: ipt4,
ipt6: ipt6, ipt6: ipt6,
protectedDialer: &net.Dialer{
Control: func(network, address string, c syscall.RawConn) error {
var err error
cErr := c.Control(func(fd uintptr) {
err = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, nfqueueConnMarkAccept)
})
if cErr != nil {
return cErr
}
return err
},
},
}, nil }, nil
} }
@@ -239,6 +258,10 @@ func (n *nfqueuePacketIO) SetVerdict(p Packet, v Verdict, newPacket []byte) erro
} }
} }
func (n *nfqueuePacketIO) ProtectedDialContext(ctx context.Context, network, address string) (net.Conn, error) {
return n.protectedDialer.DialContext(ctx, network, address)
}
func (n *nfqueuePacketIO) Close() error { func (n *nfqueuePacketIO) Close() error {
if n.rSet { if n.rSet {
if n.ipt4 != nil { if n.ipt4 != nil {

View File

@@ -14,14 +14,12 @@ type GeoMatcher struct {
ipMatcherLock sync.Mutex ipMatcherLock sync.Mutex
} }
func NewGeoMatcher(geoSiteFilename, geoIpFilename string) (*GeoMatcher, error) { func NewGeoMatcher(geoSiteFilename, geoIpFilename string) *GeoMatcher {
geoLoader := NewDefaultGeoLoader(geoSiteFilename, geoIpFilename)
return &GeoMatcher{ return &GeoMatcher{
geoLoader: geoLoader, geoLoader: NewDefaultGeoLoader(geoSiteFilename, geoIpFilename),
geoSiteMatcher: make(map[string]hostMatcher), geoSiteMatcher: make(map[string]hostMatcher),
geoIpMatcher: make(map[string]hostMatcher), geoIpMatcher: make(map[string]hostMatcher),
}, nil }
} }
func (g *GeoMatcher) MatchGeoIp(ip, condition string) bool { func (g *GeoMatcher) MatchGeoIp(ip, condition string) bool {

View File

@@ -1,11 +1,15 @@
package ruleset package ruleset
import ( import (
"context"
"fmt" "fmt"
"net" "net"
"os" "os"
"reflect" "reflect"
"strings" "strings"
"time"
"github.com/expr-lang/expr/builtin"
"github.com/expr-lang/expr" "github.com/expr-lang/expr"
"github.com/expr-lang/expr/ast" "github.com/expr-lang/expr/ast"
@@ -55,10 +59,9 @@ type compiledExprRule struct {
var _ Ruleset = (*exprRuleset)(nil) var _ Ruleset = (*exprRuleset)(nil)
type exprRuleset struct { type exprRuleset struct {
Rules []compiledExprRule Rules []compiledExprRule
Ans []analyzer.Analyzer Ans []analyzer.Analyzer
Logger Logger Logger Logger
GeoMatcher *geo.GeoMatcher
} }
func (r *exprRuleset) Analyzers(info StreamInfo) []analyzer.Analyzer { func (r *exprRuleset) Analyzers(info StreamInfo) []analyzer.Analyzer {
@@ -100,10 +103,7 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier
fullAnMap := analyzersToMap(ans) fullAnMap := analyzersToMap(ans)
fullModMap := modifiersToMap(mods) fullModMap := modifiersToMap(mods)
depAnMap := make(map[string]analyzer.Analyzer) depAnMap := make(map[string]analyzer.Analyzer)
geoMatcher, err := geo.NewGeoMatcher(config.GeoSiteFilename, config.GeoIpFilename) funcMap := buildFunctionMap(config)
if err != nil {
return nil, err
}
// Compile all rules and build a map of analyzers that are used by the rules. // Compile all rules and build a map of analyzers that are used by the rules.
for _, rule := range rules { for _, rule := range rules {
if rule.Action == "" && !rule.Log { if rule.Action == "" && !rule.Log {
@@ -118,13 +118,19 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier
action = &a action = &a
} }
visitor := &idVisitor{Variables: make(map[string]bool), Identifiers: make(map[string]bool)} visitor := &idVisitor{Variables: make(map[string]bool), Identifiers: make(map[string]bool)}
patcher := &idPatcher{} patcher := &idPatcher{FuncMap: funcMap}
program, err := expr.Compile(rule.Expr, program, err := expr.Compile(rule.Expr,
func(c *conf.Config) { func(c *conf.Config) {
c.Strict = false c.Strict = false
c.Expect = reflect.Bool c.Expect = reflect.Bool
c.Visitors = append(c.Visitors, visitor, patcher) c.Visitors = append(c.Visitors, visitor, patcher)
registerBuiltinFunctions(c.Functions, geoMatcher) for name, f := range funcMap {
c.Functions[name] = &builtin.Function{
Name: name,
Func: f.Func,
Types: f.Types,
}
}
}, },
) )
if err != nil { if err != nil {
@@ -138,24 +144,15 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier
if isBuiltInAnalyzer(name) || visitor.Variables[name] { if isBuiltInAnalyzer(name) || visitor.Variables[name] {
continue continue
} }
// Check if it's one of the built-in functions, and if so, if f, ok := funcMap[name]; ok {
// skip it as an analyzer & do initialization if necessary. // Built-in function, initialize if necessary
switch name { if f.InitFunc != nil {
case "geoip": if err := f.InitFunc(); err != nil {
if err := geoMatcher.LoadGeoIP(); err != nil { return nil, fmt.Errorf("rule %q failed to initialize function %q: %w", rule.Name, name, err)
return nil, fmt.Errorf("rule %q failed to load geoip: %w", rule.Name, err) }
}
case "geosite":
if err := geoMatcher.LoadGeoSite(); err != nil {
return nil, fmt.Errorf("rule %q failed to load geosite: %w", rule.Name, err)
}
case "cidr":
// No initialization needed for CIDR.
default:
a, ok := fullAnMap[name]
if !ok {
return nil, fmt.Errorf("rule %q uses unknown analyzer %q", rule.Name, name)
} }
} else if a, ok := fullAnMap[name]; ok {
// Analyzer, add to dependency map
depAnMap[name] = a depAnMap[name] = a
} }
} }
@@ -184,37 +181,12 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier
depAns = append(depAns, a) depAns = append(depAns, a)
} }
return &exprRuleset{ return &exprRuleset{
Rules: compiledRules, Rules: compiledRules,
Ans: depAns, Ans: depAns,
Logger: config.Logger, Logger: config.Logger,
GeoMatcher: geoMatcher,
}, nil }, nil
} }
func registerBuiltinFunctions(funcMap map[string]*ast.Function, geoMatcher *geo.GeoMatcher) {
funcMap["geoip"] = &ast.Function{
Name: "geoip",
Func: func(params ...any) (any, error) {
return geoMatcher.MatchGeoIp(params[0].(string), params[1].(string)), nil
},
Types: []reflect.Type{reflect.TypeOf(geoMatcher.MatchGeoIp)},
}
funcMap["geosite"] = &ast.Function{
Name: "geosite",
Func: func(params ...any) (any, error) {
return geoMatcher.MatchGeoSite(params[0].(string), params[1].(string)), nil
},
Types: []reflect.Type{reflect.TypeOf(geoMatcher.MatchGeoSite)},
}
funcMap["cidr"] = &ast.Function{
Name: "cidr",
Func: func(params ...any) (any, error) {
return builtins.MatchCIDR(params[0].(string), params[1].(*net.IPNet)), nil
},
Types: []reflect.Type{reflect.TypeOf((func(string, string) bool)(nil)), reflect.TypeOf(builtins.MatchCIDR)},
}
}
func streamInfoToExprEnv(info StreamInfo) map[string]interface{} { func streamInfoToExprEnv(info StreamInfo) map[string]interface{} {
m := map[string]interface{}{ m := map[string]interface{}{
"id": info.ID, "id": info.ID,
@@ -299,29 +271,109 @@ func (v *idVisitor) Visit(node *ast.Node) {
// idPatcher patches the AST during expr compilation, replacing certain values with // idPatcher patches the AST during expr compilation, replacing certain values with
// their internal representations for better runtime performance. // their internal representations for better runtime performance.
type idPatcher struct { type idPatcher struct {
Err error FuncMap map[string]*Function
Err error
} }
func (p *idPatcher) Visit(node *ast.Node) { func (p *idPatcher) Visit(node *ast.Node) {
switch (*node).(type) { switch (*node).(type) {
case *ast.CallNode: case *ast.CallNode:
callNode := (*node).(*ast.CallNode) callNode := (*node).(*ast.CallNode)
if callNode.Func == nil { if callNode.Callee == nil {
// Ignore invalid call nodes // Ignore invalid call nodes
return return
} }
switch callNode.Func.Name { if f, ok := p.FuncMap[callNode.Callee.String()]; ok {
case "cidr": if f.PatchFunc != nil {
cidrStringNode, ok := callNode.Arguments[1].(*ast.StringNode) if err := f.PatchFunc(&callNode.Arguments); err != nil {
if !ok { p.Err = err
return return
}
} }
cidr, err := builtins.CompileCIDR(cidrStringNode.Value)
if err != nil {
p.Err = err
return
}
callNode.Arguments[1] = &ast.ConstantNode{Value: cidr}
} }
} }
} }
type Function struct {
InitFunc func() error
PatchFunc func(args *[]ast.Node) error
Func func(params ...any) (any, error)
Types []reflect.Type
}
func buildFunctionMap(config *BuiltinConfig) map[string]*Function {
geoMatcher := geo.NewGeoMatcher(config.GeoSiteFilename, config.GeoIpFilename)
return map[string]*Function{
"geoip": {
InitFunc: geoMatcher.LoadGeoIP,
PatchFunc: nil,
Func: func(params ...any) (any, error) {
return geoMatcher.MatchGeoIp(params[0].(string), params[1].(string)), nil
},
Types: []reflect.Type{reflect.TypeOf(geoMatcher.MatchGeoIp)},
},
"geosite": {
InitFunc: geoMatcher.LoadGeoSite,
PatchFunc: nil,
Func: func(params ...any) (any, error) {
return geoMatcher.MatchGeoSite(params[0].(string), params[1].(string)), nil
},
Types: []reflect.Type{reflect.TypeOf(geoMatcher.MatchGeoSite)},
},
"cidr": {
InitFunc: nil,
PatchFunc: func(args *[]ast.Node) error {
cidrStringNode, ok := (*args)[1].(*ast.StringNode)
if !ok {
return fmt.Errorf("cidr: invalid argument type")
}
cidr, err := builtins.CompileCIDR(cidrStringNode.Value)
if err != nil {
return err
}
(*args)[1] = &ast.ConstantNode{Value: cidr}
return nil
},
Func: func(params ...any) (any, error) {
return builtins.MatchCIDR(params[0].(string), params[1].(*net.IPNet)), nil
},
Types: []reflect.Type{reflect.TypeOf(builtins.MatchCIDR)},
},
"lookup": {
InitFunc: nil,
PatchFunc: func(args *[]ast.Node) error {
var serverStr *ast.StringNode
if len(*args) > 1 {
// Has the optional server argument
var ok bool
serverStr, ok = (*args)[1].(*ast.StringNode)
if !ok {
return fmt.Errorf("lookup: invalid argument type")
}
}
r := &net.Resolver{
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
if serverStr != nil {
address = serverStr.Value
}
return config.ProtectedDialContext(ctx, network, address)
},
}
if len(*args) > 1 {
(*args)[1] = &ast.ConstantNode{Value: r}
} else {
*args = append(*args, &ast.ConstantNode{Value: r})
}
return nil
},
Func: func(params ...any) (any, error) {
ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second)
defer cancel()
return params[1].(*net.Resolver).LookupHost(ctx, params[0].(string))
},
Types: []reflect.Type{
reflect.TypeOf((func(string, *net.Resolver) []string)(nil)),
},
},
}
}

View File

@@ -1,6 +1,7 @@
package ruleset package ruleset
import ( import (
"context"
"net" "net"
"strconv" "strconv"
@@ -100,7 +101,8 @@ type Logger interface {
} }
type BuiltinConfig struct { type BuiltinConfig struct {
Logger Logger Logger Logger
GeoSiteFilename string GeoSiteFilename string
GeoIpFilename string GeoIpFilename string
ProtectedDialContext func(ctx context.Context, network, address string) (net.Conn, error)
} }