Compare commits

..

52 Commits

Author SHA1 Message Date
刘祥超
17af07cce0 edge-node gc命令增加耗时和gc pause时长 2023-12-25 09:24:20 +08:00
刘祥超
cfa57fac66 优化计数器相关测试用例 2023-12-24 16:08:57 +08:00
刘祥超
47523eaa73 优化计数器性能 2023-12-24 15:11:09 +08:00
刘祥超
27a24c6a8a 版本号修改为1.3.2 2023-12-24 11:14:45 +08:00
刘祥超
9bc2b1a651 WAF参数中增加“请求来源” 2023-12-24 10:03:24 +08:00
刘祥超
4f24b7f39c 增加Websocket连接数统计 2023-12-20 11:43:00 +08:00
刘祥超
4607a1f4e7 版本号修改为1.3.1.2 2023-12-18 08:51:22 +08:00
刘祥超
0f2068b161 优化TCP源站错误提示 2023-12-15 18:38:09 +08:00
刘祥超
c039691a71 缓存设置中可以设置缓存主域名,用来复用多域名下的缓存 2023-12-13 18:41:51 +08:00
刘祥超
930ee44065 根据系统环境调整WebP转换线程数 2023-12-12 09:55:18 +08:00
刘祥超
8a9aac7d72 优化代码 2023-12-11 20:35:48 +08:00
刘祥超
e50bbb962d WebP策略变化时只更新相关配置 2023-12-11 11:09:12 +08:00
刘祥超
9ff936d0c1 WebP转换质量转移到WebP策略配置 2023-12-11 10:17:17 +08:00
刘祥超
f53727b09c WebP转换限制为单线程,防止占用系统资源过高 2023-12-11 09:33:04 +08:00
刘祥超
525ce1f923 优化WAF XSS检测,减少对图片内容的误判 2023-12-10 19:40:29 +08:00
刘祥超
16e7cd800c WAF SQL注入检测和XSS注入检测自动进行URL解码 2023-12-10 16:52:54 +08:00
刘祥超
3f34bfc0b0 节点进程停止时,自动保存WAF临时白名单,并在进程重新启动后恢复 2023-12-10 15:41:31 +08:00
刘祥超
548cd1002b 增加WAF相关测试用例 2023-12-10 09:27:29 +08:00
刘祥超
3423865868 优化测试用例 2023-12-10 08:54:39 +08:00
刘祥超
037bc8e0de 优化WAF单词匹配性能 2023-12-09 19:19:29 +08:00
刘祥超
e03292de28 WAF规则模板中XSS注入检测规则使用“包含XSS注入”操作符替代以往的正则表达式 2023-12-09 17:00:21 +08:00
刘祥超
ee2565905e 优化WAF动作“显示网页”显示 2023-12-09 15:55:40 +08:00
刘祥超
05881b457d WAF规则模板中SQL注入规则使用“包含SQL注入”操作符替代以往的正则表达式 2023-12-09 15:28:07 +08:00
刘祥超
b116effc6c WAF SQL注入和XSS检测增加缓存/优化部分WAF相关测试用例 2023-12-09 11:46:50 +08:00
刘祥超
536efeeb9c 提升单词匹配性能 2023-12-09 10:06:07 +08:00
刘祥超
e8638e4bec WAF检查项增加“所有报头名称” 2023-12-08 15:39:23 +08:00
刘祥超
c9db722129 WAF增加“包含XSS注入”操作符 2023-12-08 10:15:18 +08:00
刘祥超
90de472bd5 增加测试用例 2023-12-07 20:47:25 +08:00
刘祥超
50c6c60abf WAF SQL注入检测时支持 (http|https):// 开头的URL 2023-12-07 20:38:06 +08:00
刘祥超
cc10372fe1 WAF增加“包含SQL注入”操作符 2023-12-07 20:25:35 +08:00
刘祥超
05c98a0656 修复一处单词错误 2023-12-07 12:14:04 +08:00
刘祥超
1a790fe391 优化代码 2023-12-07 12:07:06 +08:00
刘祥超
7dbd73cb59 优化WAF中前缀和后缀相关操作符性能 2023-12-07 12:05:08 +08:00
刘祥超
4dfa571547 WAF操作符增加包含任一单词、包含所有单词、不包含任一单词 2023-12-07 11:42:59 +08:00
刘祥超
9f77f62308 WAF checkpoint返回值支持[][]byte 2023-12-05 17:18:53 +08:00
刘祥超
facea1ed96 优化代码 2023-12-05 16:28:10 +08:00
刘祥超
e367814db3 内容压缩级别允许为0 2023-12-05 10:48:17 +08:00
刘祥超
3a15408c98 修复缓存命中率统计测试用例 2023-12-03 14:55:09 +08:00
刘祥超
c504b37118 WAF相关跳转不计入统计 2023-12-03 14:41:11 +08:00
刘祥超
74708dc02f 默认不启用内存分片管理 2023-12-03 14:26:51 +08:00
刘祥超
0c097498bb 优化链表相关代码 2023-12-03 11:27:47 +08:00
刘祥超
981c063eff 优化验证码性能 2023-11-30 17:25:41 +08:00
刘祥超
5e35c50113 页面优化增加例外URL和限制URL 2023-11-30 15:48:50 +08:00
刘祥超
e6c2869ff2 增加“极验-行为验”验证码集成支持 2023-11-29 17:00:06 +08:00
刘祥超
358bec2e9b WAF验证码验证后返回时判断是否已通过验证 2023-11-28 20:39:42 +08:00
刘祥超
1cd644f2eb 优化验证码加载方式,减少不必要的图片生成 2023-11-28 18:07:27 +08:00
刘祥超
f783e5c331 将版本号修改为1.3.1 2023-11-23 17:19:41 +08:00
刘祥超
c39b1c794f 修复清空文件索引Map时产生并发异常 2023-11-23 17:14:50 +08:00
刘祥超
2633d43897 增加最大内存用量 2023-11-22 17:03:42 +08:00
刘祥超
88dca006c4 优化日志 2023-11-22 16:44:06 +08:00
刘祥超
98feb26b79 优化brotli压缩和解压缩性能 2023-11-21 20:18:37 +08:00
刘祥超
ac6683e79d GRPC增加Keepalive参数 2023-11-20 09:56:50 +08:00
87 changed files with 17417 additions and 1070 deletions

View File

@@ -10,7 +10,7 @@ function build() {
# for macOS users: precompiled gcc can be downloaded from https://github.com/messense/homebrew-macos-cross-toolchains
GCC_X86_64_DIR="/usr/local/gcc/x86_64-unknown-linux-gnu/bin"
GCC_ARM64_DIR="//usr/local/gcc/aarch64-unknown-linux-gnu/bin"
GCC_ARM64_DIR="/usr/local/gcc/aarch64-unknown-linux-gnu/bin"
OS=${1}
ARCH=${2}
@@ -123,8 +123,8 @@ function build() {
# libpcap
if [ "$OS" == "linux" ] && [[ "$ARCH" == "amd64" || "$ARCH" == "arm64" ]] && [ "$TAG" == "plus" ]; then
CGO_LDFLAGS="-L${SRCDIR}/libs/libpcap/${ARCH} -lpcap"
CGO_CFLAGS="-I${SRCDIR}/libs/libpcap/src/libpcap -I${SRCDIR}/libs/libpcap/src/libpcap/pcap"
CGO_LDFLAGS="-L${SRCDIR}/libs/libpcap/${ARCH} -lpcap -L${SRCDIR}/libs/libbrotli/${ARCH} -lbrotlienc -lbrotlidec -lbrotlicommon"
CGO_CFLAGS="-I${SRCDIR}/libs/libpcap/src/libpcap -I${SRCDIR}/libs/libpcap/src/libpcap/pcap -I${SRCDIR}/libs/libbrotli/src/brotli/c/include"
fi
if [ ! -z $CC_PATH ]; then

View File

@@ -228,11 +228,18 @@ func main() {
})
app.On("gc", func() {
var sock = gosock.NewTmpSock(teaconst.ProcessName)
_, err := sock.Send(&gosock.Command{Code: "gc"})
reply, err := sock.Send(&gosock.Command{Code: "gc"})
if err != nil {
fmt.Println("[ERROR]" + err.Error())
} else {
fmt.Println("ok")
if reply == nil {
fmt.Println("ok")
} else {
var paramMap = maps.NewMap(reply.Params)
var pauseMS = paramMap.GetFloat64("pauseMS")
var costMS = paramMap.GetFloat64("costMS")
fmt.Printf("ok, cost: %.4fms, pause: %.4fms", costMS, pauseMS)
}
}
})
app.On("ip.drop", func() {

View File

@@ -133,7 +133,10 @@ func (this *FileListHashMap) Clean() {
this.lockers[i].Lock()
}
this.m = make([]map[uint64]zero.Zero, HashMapSharding)
// 这里不能简单清空 this.m ,避免导致别的数据无法写入 map 而产生 panic
for i := 0; i < HashMapSharding; i++ {
this.m[i] = map[uint64]zero.Zero{}
}
for i := HashMapSharding - 1; i >= 0; i-- {
this.lockers[i].Unlock()

View File

@@ -125,6 +125,13 @@ func TestFileListHashMap_Delete(t *testing.T) {
a.IsTrue(m.Len() == 0)
}
func TestFileListHashMap_Clean(t *testing.T) {
var m = caches.NewFileListHashMap()
m.SetIsAvailable(true)
m.Clean()
m.Add("a")
}
func Benchmark_BigInt(b *testing.B) {
var hash = stringutil.Md5("123456")
b.ResetTimer()

View File

@@ -15,6 +15,7 @@ import (
)
const (
enableFragmentPool = false
minMemoryFragmentPoolItemSize = 8 << 10
maxMemoryFragmentPoolItemSize = 128 << 20
maxItemsInMemoryFragmentPoolBucket = 1024

View File

@@ -517,7 +517,7 @@ func (this *MemoryStorage) flushItem(key string) {
_ = this.Delete(key)
// 重用内存,前提是确保内存不再被引用
if ok && item.IsDone && !item.isReferring && len(item.BodyValue) > 0 {
if enableFragmentPool && ok && item.IsDone && !item.isReferring && len(item.BodyValue) > 0 {
SharedFragmentMemoryPool.Put(item.BodyValue)
}
}()

View File

@@ -32,7 +32,9 @@ func NewMemoryWriter(memoryStorage *MemoryStorage, key string, expiredAt int64,
ModifiedAt: fasttime.Now().Unix(),
Status: status,
}
if expectedBodySize > 0 && expectedBodySize <= maxMemoryFragmentPoolItemSize {
if enableFragmentPool &&
expectedBodySize > 0 &&
expectedBodySize <= maxMemoryFragmentPoolItemSize {
bodyBytes, ok := SharedFragmentMemoryPool.Get(expectedBodySize) // try to reuse memory
if ok {
valueItem.BodyValue = bodyBytes
@@ -168,7 +170,8 @@ func (this *MemoryWriter) Discard() error {
this.storage.locker.Lock()
delete(this.storage.valuesMap, this.hash)
if this.item != nil &&
if enableFragmentPool &&
this.item != nil &&
!this.item.isReferring &&
cap(this.item.BodyValue) >= minMemoryFragmentPoolItemSize {
SharedFragmentMemoryPool.Put(this.item.BodyValue)

View File

@@ -1,4 +1,5 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build !plus || !linux
package compressions

View File

@@ -1,4 +1,5 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build !plus || !linux
package compressions
@@ -27,7 +28,7 @@ func newBrotliWriter(writer io.Writer, level int) (*BrotliWriter, error) {
return &BrotliWriter{
writer: brotli.NewWriterOptions(writer, brotli.WriterOptions{
Quality: level,
LGWin: 13, // TODO 在全局设置里可以设置此值
LGWin: 14, // TODO 在全局设置里可以设置此值
}),
level: level,
}, nil

View File

@@ -19,6 +19,10 @@ func NewZSTDWriter(writer io.Writer, level int) (Writer, error) {
}
func newZSTDWriter(writer io.Writer, level int) (Writer, error) {
if level < 0 {
level = 0
}
var zstdLevel = zstd.EncoderLevelFromZstd(level)
zstdWriter, err := zstd.NewWriter(writer, zstd.WithEncoderLevel(zstdLevel))

View File

@@ -9,6 +9,24 @@ import (
"testing"
)
func TestNewZSTDWriter_Level0(t *testing.T) {
var buf = &bytes.Buffer{}
writer, err := compressions.NewZSTDWriter(buf, 0)
if err != nil {
t.Fatal(err)
}
var originData = []byte(strings.Repeat("Hello", 1024))
_, err = writer.Write(originData)
if err != nil {
t.Fatal(err)
}
err = writer.Close()
if err != nil {
t.Fatal(err)
}
t.Log("origin data:", len(originData), "result:", buf.Len())
}
func TestNewZSTDWriter(t *testing.T) {
var buf = &bytes.Buffer{}
writer, err := compressions.NewZSTDWriter(buf, 10)

View File

@@ -1,7 +1,7 @@
package teaconst
const (
Version = "1.3.0"
Version = "1.3.2"
ProductName = "Edge Node"
ProcessName = "edge-node"

View File

@@ -90,13 +90,13 @@ func (this *APIStream) loop() error {
break
}
message, err := nodeStream.Recv()
if err != nil {
message, streamErr := nodeStream.Recv()
if streamErr != nil {
if this.isQuiting {
remotelogs.Println("API_STREAM", "quit")
return nil
}
return err
return streamErr
}
// 处理消息

View File

@@ -85,6 +85,8 @@ type HTTPRequest struct {
isAttack bool // 是否是攻击请求
requestBodyData []byte // 读取的Body内容
isWebsocketResponse bool // 是否为Websocket响应非请求
// WAF相关
firewallPolicyId int64
firewallRuleGroupId int64
@@ -410,6 +412,8 @@ func (this *HTTPRequest) doEnd() {
var countAttacks int64 = 0
var attackBytes int64 = 0
var countWebsocketConnections int64 = 0
if this.isCached {
countCached = 1
cachedBytes = totalBytes
@@ -421,8 +425,11 @@ func (this *HTTPRequest) doEnd() {
attackBytes = totalBytes
}
}
if this.isWebsocketResponse {
countWebsocketConnections = 1
}
stats.SharedTrafficStatManager.Add(this.ReqServer.UserId, this.ReqServer.Id, this.ReqHost, totalBytes, cachedBytes, 1, countCached, countAttacks, attackBytes, this.ReqServer.ShouldCheckTrafficLimit(), this.ReqServer.PlanId())
stats.SharedTrafficStatManager.Add(this.ReqServer.UserId, this.ReqServer.Id, this.ReqHost, totalBytes, cachedBytes, 1, countCached, countAttacks, attackBytes, countWebsocketConnections, this.ReqServer.ShouldCheckTrafficLimit(), this.ReqServer.PlanId())
// 指标
if metrics.SharedManager.HasHTTPMetrics() {

View File

@@ -3,6 +3,7 @@ package nodes
import (
"bytes"
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
"github.com/TeaOSLab/EdgeNode/internal/caches"
"github.com/TeaOSLab/EdgeNode/internal/compressions"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
@@ -130,7 +131,22 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
var tags = []string{}
// 检查是否有缓存
var key = this.Format(this.cacheRef.Key)
var key string
if this.web.Cache.Key != nil && this.web.Cache.Key.IsOn && len(this.web.Cache.Key.Host) > 0 {
key = configutils.ParseVariables(this.cacheRef.Key, func(varName string) (value string) {
switch varName {
case "scheme":
return this.web.Cache.Key.Scheme
case "host":
return this.web.Cache.Key.Host
default:
return this.Format("${" + varName + "}")
}
})
} else {
key = this.Format(this.cacheRef.Key)
}
if len(key) == 0 {
this.cacheRef = nil
cacheBypassDescription = "BYPASS, empty key"

View File

@@ -434,7 +434,7 @@ func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeId
// Page optimization
if this.web.Optimization != nil && resp.Body != nil && this.cacheRef != nil /** must under cache **/ {
err := this.web.Optimization.FilterResponse(resp)
err := this.web.Optimization.FilterResponse(this.URL(), resp)
if err != nil {
this.write50x(err, http.StatusBadGateway, "Page Optimization: Fail to read content from origin", "内容优化:从源站读取内容失败", false)
return

View File

@@ -61,6 +61,9 @@ func (this *HTTPRequest) doWebsocket(requestHost string, isLastRetry bool) (shou
}
}
// 标记
this.isWebsocketResponse = true
// 设置指定的来源域
if !this.web.Websocket.RequestSameOrigin && len(this.web.Websocket.RequestOrigin) > 0 {
var newRequestOrigin = this.web.Websocket.RequestOrigin
@@ -77,7 +80,6 @@ func (this *HTTPRequest) doWebsocket(requestHost string, isLastRetry bool) (shou
}
// 连接源站
// TODO 增加N次错误重试重试的时候需要尝试不同的源站
originConn, _, err := OriginConnect(this.origin, this.requestServerPort(), this.RawReq.RemoteAddr, requestHost)
if err != nil {
if isLastRetry {

View File

@@ -11,7 +11,6 @@ import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeNode/internal/caches"
"github.com/TeaOSLab/EdgeNode/internal/compressions"
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
@@ -34,22 +33,19 @@ import (
"net/textproto"
"os"
"path/filepath"
"runtime"
"strings"
"sync/atomic"
)
var webpMaxBufferSize int64 = 1_000_000_000
var webpTotalBufferSize int64 = 0
var webpIgnoreURLSet = setutils.NewFixedSet(131072)
var webPThreads int32
var webPMaxThreads int32 = 1
var webPIgnoreURLSet = setutils.NewFixedSet(131072)
func init() {
if !teaconst.IsMain {
return
}
var systemMemory = utils.SystemMemoryGB() / 8
if systemMemory > 0 {
webpMaxBufferSize = int64(systemMemory) << 30
webPMaxThreads = int32(runtime.NumCPU() / 4)
if webPMaxThreads < 1 {
webPMaxThreads = 1
}
}
@@ -80,6 +76,7 @@ type HTTPWriter struct {
// WebP
webpIsEncoding bool
webpOriginContentType string
webpQuality int
// Compression
compressionConfig *serverconfigs.HTTPCompressionConfig
@@ -483,8 +480,8 @@ func (this *HTTPWriter) PrepareCache(resp *http.Response, size int64) {
contentTypeWritten = true
}
err := cacheWriter.WriteAt(start, data)
if err != nil {
writeErr := cacheWriter.WriteAt(start, data)
if writeErr != nil {
hasError = true
this.cacheIsFinished = false
}
@@ -531,6 +528,7 @@ func (this *HTTPWriter) PrepareWebP(resp *http.Response, size int64) {
if policy.RequireCache && this.req.cacheRef == nil {
return
}
this.webpQuality = policy.Quality
// 限制最小和最大尺寸
// TODO 需要将reader修改为LimitReader
@@ -550,7 +548,7 @@ func (this *HTTPWriter) PrepareWebP(resp *http.Response, size int64) {
this.req.web.WebP.MatchResponse(contentType, size, filepath.Ext(this.req.Path()), this.req.Format) &&
this.req.web.WebP.MatchAccept(this.req.requestHeader("Accept")) {
// 检查是否已经因为尺寸过大而忽略
if webpIgnoreURLSet.Has(this.req.URL()) {
if webPIgnoreURLSet.Has(this.req.URL()) {
return
}
@@ -560,8 +558,8 @@ func (this *HTTPWriter) PrepareWebP(resp *http.Response, size int64) {
return
}
// 检查内存
if atomic.LoadInt64(&webpTotalBufferSize) >= webpMaxBufferSize {
// 检查当前是否正在转换
if atomic.LoadInt32(&webPThreads) >= webPMaxThreads {
return
}
@@ -622,7 +620,7 @@ func (this *HTTPWriter) PrepareCompression(resp *http.Response, size int64) {
return
}
if this.compressionConfig.Level <= 0 {
if this.compressionConfig.Level < 0 {
return
}
@@ -1020,6 +1018,11 @@ func (this *HTTPWriter) calculateStaleLife() int {
func (this *HTTPWriter) finishWebP() {
// 处理WebP
if this.webpIsEncoding {
atomic.AddInt32(&webPThreads, 1)
defer func() {
atomic.AddInt32(&webPThreads, -1)
}()
var webpCacheWriter caches.Writer
// 准备WebP Cache
@@ -1080,7 +1083,7 @@ func (this *HTTPWriter) finishWebP() {
if isGif {
gifImage, err = gif.DecodeAll(reader)
if gifImage != nil && (gifImage.Config.Width > gowebp.WebPMaxDimension || gifImage.Config.Height > gowebp.WebPMaxDimension) {
webpIgnoreURLSet.Push(this.req.URL())
webPIgnoreURLSet.Push(this.req.URL())
return
}
} else {
@@ -1088,7 +1091,7 @@ func (this *HTTPWriter) finishWebP() {
if imageData != nil {
var bound = imageData.Bounds()
if bound.Max.X > gowebp.WebPMaxDimension || bound.Max.Y > gowebp.WebPMaxDimension {
webpIgnoreURLSet.Push(this.req.URL())
webPIgnoreURLSet.Push(this.req.URL())
return
}
}
@@ -1096,19 +1099,21 @@ func (this *HTTPWriter) finishWebP() {
if err != nil {
// 发生了错误终止处理
webpIgnoreURLSet.Push(this.req.URL())
webPIgnoreURLSet.Push(this.req.URL())
return
}
var totalBytes = reader.TotalBytes()
atomic.AddInt64(&webpTotalBufferSize, totalBytes)
defer func() {
atomic.AddInt64(&webpTotalBufferSize, -totalBytes)
}()
var f = types.Float32(this.req.web.WebP.Quality)
if f > 100 {
f = 100
var f = types.Float32(this.webpQuality)
if f <= 0 || f > 100 {
if this.size > (8<<20) || this.size <= 0 {
f = 30
} else if this.size > (1 << 20) {
f = 50
} else if this.size > (128 << 10) {
f = 60
} else {
f = 75
}
}
if imageData != nil {

View File

@@ -47,9 +47,13 @@ func (this *TCPListener) Serve() error {
atomic.AddInt64(&this.countActiveConnections, 1)
go func(conn net.Conn) {
err = this.handleConn(conn)
var server = this.Group.FirstServer()
if server == nil {
return
}
err = this.handleConn(server, conn)
if err != nil {
remotelogs.Error("TCP_LISTENER", err.Error())
remotelogs.ServerError(server.Id, "TCP_LISTENER", err.Error(), "", nil)
}
atomic.AddInt64(&this.countActiveConnections, -1)
}(conn)
@@ -63,8 +67,7 @@ func (this *TCPListener) Reload(group *serverconfigs.ServerAddressGroup) {
this.Reset()
}
func (this *TCPListener) handleConn(conn net.Conn) error {
var server = this.Group.FirstServer()
func (this *TCPListener) handleConn(server *serverconfigs.ServerConfig, conn net.Conn) error {
if server == nil {
return errors.New("no server available")
}
@@ -132,14 +135,14 @@ func (this *TCPListener) handleConn(conn net.Conn) error {
serverName = tlsConn.ConnectionState().ServerName
if len(serverName) > 0 {
// 统计
stats.SharedTrafficStatManager.Add(server.UserId, server.Id, serverName, 0, 0, 1, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
stats.SharedTrafficStatManager.Add(server.UserId, server.Id, serverName, 0, 0, 1, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
recordStat = true
}
}
// 统计
if !recordStat {
stats.SharedTrafficStatManager.Add(server.UserId, server.Id, "", 0, 0, 1, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
stats.SharedTrafficStatManager.Add(server.UserId, server.Id, "", 0, 0, 1, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
}
originConn, err := this.connectOrigin(server.Id, serverName, server.ReverseProxy, conn.RemoteAddr().String())
@@ -194,7 +197,7 @@ func (this *TCPListener) handleConn(conn net.Conn) error {
// 记录流量
if server != nil {
stats.SharedTrafficStatManager.Add(server.UserId, server.Id, "", int64(n), 0, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
stats.SharedTrafficStatManager.Add(server.UserId, server.Id, "", int64(n), 0, 0, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
}
}
if err != nil {

View File

@@ -370,7 +370,7 @@ func NewUDPConn(server *serverconfigs.ServerConfig, addr net.Addr, proxyListener
// 统计
if server != nil {
stats.SharedTrafficStatManager.Add(server.UserId, server.Id, "", 0, 0, 1, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
stats.SharedTrafficStatManager.Add(server.UserId, server.Id, "", 0, 0, 1, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
}
// 处理ControlMessage
@@ -401,7 +401,7 @@ func NewUDPConn(server *serverconfigs.ServerConfig, addr net.Addr, proxyListener
// 记录流量和带宽
if server != nil {
// 流量
stats.SharedTrafficStatManager.Add(server.UserId, server.Id, "", int64(n), 0, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
stats.SharedTrafficStatManager.Add(server.UserId, server.Id, "", int64(n), 0, 0, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
// 带宽
var userPlanId int64

View File

@@ -777,9 +777,19 @@ func (this *Node) listenSock() error {
_ = cmd.ReplyOk()
}
case "gc":
var before = time.Now()
runtime.GC()
debug.FreeOSMemory()
_ = cmd.ReplyOk()
var costSeconds = time.Since(before).Seconds()
var gcStats = &debug.GCStats{}
debug.ReadGCStats(gcStats)
_ = cmd.Reply(&gosock.Command{
Params: map[string]any{
"pauseMS": gcStats.PauseTotal.Seconds() * 1000,
"costMS": costSeconds * 1000,
},
})
case "reload":
err := this.syncConfig(0)
if err != nil {
@@ -1039,7 +1049,7 @@ func (this *Node) reloadServer() {
for serverId, serverConfig := range updatingServerMap {
if serverConfig != nil {
if countUpdatingServers < maxPrintServers {
remotelogs.Debug("NODE", "load server '"+types.String(serverId)+"'")
remotelogs.Debug("NODE", "reload server '"+types.String(serverId)+"'")
}
newNodeConfig.AddServer(serverConfig)
} else {

View File

@@ -100,6 +100,8 @@ func (this *Node) execTask(rpcClient *rpc.RPCClient, task *pb.NodeTask) error {
err = this.execTOAChangedTask()
case "networkSecurityPolicyChanged":
err = this.execNetworkSecurityPolicyChangedTask(rpcClient)
case "webPPolicyChanged":
err = this.execWebPPolicyChangedTask(rpcClient)
default:
// 特殊任务
if strings.HasPrefix(task.Type, "ipListDeleted") { // 删除IP名单
@@ -325,6 +327,34 @@ func (this *Node) execDeleteIPList(taskType string) error {
return nil
}
// WebP策略变更
func (this *Node) execWebPPolicyChangedTask(rpcClient *rpc.RPCClient) error {
remotelogs.Println("NODE", "updating webp policies ...")
resp, err := rpcClient.NodeRPC.FindNodeWebPPolicies(rpcClient.Context(), &pb.FindNodeWebPPoliciesRequest{})
if err != nil {
return err
}
var webPPolicyMap = map[int64]*nodeconfigs.WebPImagePolicy{}
for _, policy := range resp.WebPPolicies {
if len(policy.WebPPolicyJSON) > 0 {
var webPPolicy = nodeconfigs.NewWebPImagePolicy()
err = json.Unmarshal(policy.WebPPolicyJSON, webPPolicy)
if err != nil {
remotelogs.Error("NODE", "decode webp policy failed: "+err.Error())
continue
}
err = webPPolicy.Init()
if err != nil {
remotelogs.Error("NODE", "initialize webp policy failed: "+err.Error())
continue
}
webPPolicyMap[policy.NodeClusterId] = webPPolicy
}
}
sharedNodeConfig.UpdateWebPImagePolicies(webPPolicyMap)
return nil
}
// 标记任务完成
func (this *Node) finishTask(taskId int64, taskVersion int64, taskErr error) (success bool) {
if taskId <= 0 {

View File

@@ -4,7 +4,7 @@ package re
type RuneMap map[rune]*RuneTree
func (this *RuneMap) Lookup(s string, caseInsensitive bool) bool {
func (this RuneMap) Lookup(s string, caseInsensitive bool) bool {
return this.lookup([]rune(s), caseInsensitive, 0)
}

View File

@@ -18,6 +18,7 @@ import (
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/encoding/gzip"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata"
"net/url"
"sync"
@@ -240,12 +241,15 @@ func (this *RPCClient) init() error {
grpc.MaxCallSendMsgSize(512<<20),
grpc.UseCompressor(gzip.Name),
)
var keepaliveParams = grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 30 * time.Second,
})
if u.Scheme == "http" {
conn, err = grpc.Dial(u.Host, grpc.WithTransportCredentials(insecure.NewCredentials()), callOptions)
conn, err = grpc.Dial(u.Host, grpc.WithTransportCredentials(insecure.NewCredentials()), callOptions, keepaliveParams)
} else if u.Scheme == "https" {
conn, err = grpc.Dial(u.Host, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
InsecureSkipVerify: true,
})), callOptions)
})), callOptions, keepaliveParams)
} else {
return errors.New("parse endpoint failed: invalid scheme '" + u.Scheme + "'")
}

View File

@@ -57,12 +57,13 @@ type BandwidthStat struct {
MaxBytes int64 `json:"maxBytes"`
TotalBytes int64 `json:"totalBytes"`
CachedBytes int64 `json:"cachedBytes"`
AttackBytes int64 `json:"attackBytes"`
CountRequests int64 `json:"countRequests"`
CountCachedRequests int64 `json:"countCachedRequests"`
CountAttackRequests int64 `json:"countAttackRequests"`
UserPlanId int64 `json:"userPlanId"`
CachedBytes int64 `json:"cachedBytes"`
AttackBytes int64 `json:"attackBytes"`
CountRequests int64 `json:"countRequests"`
CountCachedRequests int64 `json:"countCachedRequests"`
CountAttackRequests int64 `json:"countAttackRequests"`
CountWebsocketConnections int64 `json:"countWebsocketConnections"`
UserPlanId int64 `json:"userPlanId"`
}
// BandwidthStatManager 服务带宽统计
@@ -142,20 +143,21 @@ func (this *BandwidthStatManager) Loop() error {
}
pbStats = append(pbStats, &pb.ServerBandwidthStat{
Id: 0,
UserId: stat.UserId,
ServerId: stat.ServerId,
Day: stat.Day,
TimeAt: stat.TimeAt,
Bytes: stat.MaxBytes / bandwidthTimestampDelim,
TotalBytes: stat.TotalBytes,
CachedBytes: stat.CachedBytes,
AttackBytes: stat.AttackBytes,
CountRequests: stat.CountRequests,
CountCachedRequests: stat.CountCachedRequests,
CountAttackRequests: stat.CountAttackRequests,
UserPlanId: stat.UserPlanId,
NodeRegionId: regionId,
Id: 0,
UserId: stat.UserId,
ServerId: stat.ServerId,
Day: stat.Day,
TimeAt: stat.TimeAt,
Bytes: stat.MaxBytes / bandwidthTimestampDelim,
TotalBytes: stat.TotalBytes,
CachedBytes: stat.CachedBytes,
AttackBytes: stat.AttackBytes,
CountRequests: stat.CountRequests,
CountCachedRequests: stat.CountCachedRequests,
CountAttackRequests: stat.CountAttackRequests,
CountWebsocketConnections: stat.CountWebsocketConnections,
UserPlanId: stat.UserPlanId,
NodeRegionId: regionId,
})
delete(this.m, key)
}
@@ -231,7 +233,7 @@ func (this *BandwidthStatManager) AddBandwidth(userId int64, userPlanId int64, s
}
// AddTraffic 添加请求数据
func (this *BandwidthStatManager) AddTraffic(serverId int64, cachedBytes int64, countRequests int64, countCachedRequests int64, countAttacks int64, attackBytes int64) {
func (this *BandwidthStatManager) AddTraffic(serverId int64, cachedBytes int64, countRequests int64, countCachedRequests int64, countAttacks int64, attackBytes int64, countWebsocketConnections int64) {
var now = fasttime.Now()
var day = now.Ymd()
var timeAt = now.Round5Hi()
@@ -245,6 +247,7 @@ func (this *BandwidthStatManager) AddTraffic(serverId int64, cachedBytes int64,
stat.CountCachedRequests += countCachedRequests
stat.CountAttackRequests += countAttacks
stat.AttackBytes += attackBytes
stat.CountWebsocketConnections += countWebsocketConnections
}
this.locker.Unlock()
}

View File

@@ -53,19 +53,20 @@ func BenchmarkBandwidthStatManager_Slice(b *testing.B) {
for j := 0; j < 100; j++ {
var stat = &stats.BandwidthStat{}
pbStats = append(pbStats, &pb.ServerBandwidthStat{
Id: 0,
UserId: stat.UserId,
ServerId: stat.ServerId,
Day: stat.Day,
TimeAt: stat.TimeAt,
Bytes: stat.MaxBytes / 2,
TotalBytes: stat.TotalBytes,
CachedBytes: stat.CachedBytes,
AttackBytes: stat.AttackBytes,
CountRequests: stat.CountRequests,
CountCachedRequests: stat.CountCachedRequests,
CountAttackRequests: stat.CountAttackRequests,
NodeRegionId: 1,
Id: 0,
UserId: stat.UserId,
ServerId: stat.ServerId,
Day: stat.Day,
TimeAt: stat.TimeAt,
Bytes: stat.MaxBytes / 2,
TotalBytes: stat.TotalBytes,
CachedBytes: stat.CachedBytes,
AttackBytes: stat.AttackBytes,
CountRequests: stat.CountRequests,
CountCachedRequests: stat.CountCachedRequests,
CountAttackRequests: stat.CountAttackRequests,
CountWebsocketConnections: stat.CountWebsocketConnections,
NodeRegionId: 1,
})
}
_ = pbStats

View File

@@ -106,13 +106,13 @@ func (this *TrafficStatManager) Start() {
}
// Add 添加流量
func (this *TrafficStatManager) Add(userId int64, serverId int64, domain string, bytes int64, cachedBytes int64, countRequests int64, countCachedRequests int64, countAttacks int64, attackBytes int64, checkingTrafficLimit bool, planId int64) {
func (this *TrafficStatManager) Add(userId int64, serverId int64, domain string, bytes int64, cachedBytes int64, countRequests int64, countCachedRequests int64, countAttacks int64, attackBytes int64, countWebsocketConnections int64, checkingTrafficLimit bool, planId int64) {
if serverId == 0 {
return
}
// 添加到带宽
SharedBandwidthStatManager.AddTraffic(serverId, cachedBytes, countRequests, countCachedRequests, countAttacks, attackBytes)
SharedBandwidthStatManager.AddTraffic(serverId, cachedBytes, countRequests, countCachedRequests, countAttacks, attackBytes, countWebsocketConnections)
if bytes == 0 && countRequests == 0 {
return

View File

@@ -11,7 +11,7 @@ import (
func TestTrafficStatManager_Add(t *testing.T) {
manager := NewTrafficStatManager()
for i := 0; i < 100; i++ {
manager.Add(1, 1, "goedge.cn", 1, 0, 0, 0, 0, 0, false, 0)
manager.Add(1, 1, "goedge.cn", 1, 0, 0, 0, 0, 0, 0, false, 0)
}
t.Log(manager.itemMap)
}
@@ -19,7 +19,7 @@ func TestTrafficStatManager_Add(t *testing.T) {
func TestTrafficStatManager_Upload(t *testing.T) {
manager := NewTrafficStatManager()
for i := 0; i < 100; i++ {
manager.Add(1, 1, "goedge.cn"+types.String(rands.Int(0, 10)), 1, 0, 1, 0, 0, 0, false, 0)
manager.Add(1, 1, "goedge.cn"+types.String(rands.Int(0, 10)), 1, 0, 1, 0, 0, 0, 0, false, 0)
}
err := manager.Upload()
if err != nil {
@@ -36,7 +36,7 @@ func BenchmarkTrafficStatManager_Add(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
manager.Add(1, 1, "goedge.cn"+types.String(rand.Int63()%10), 1024, 1, 0, 0, 0, 0, false, 0)
manager.Add(1, 1, "goedge.cn"+types.String(rand.Int63()%10), 1024, 1, 0, 0, 0, 0, 0, false, 0)
}
})
}

View File

@@ -6,6 +6,7 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/goman"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
"github.com/iwind/TeaGo/Tea"
"sync"
"sync/atomic"
"time"
@@ -138,7 +139,7 @@ func (this *Stat) IsGood(category string) bool {
return true
}
if item.countCached > countSamples && item.timestamp < fasttime.Now().Unix()-600 /** 10 minutes ago **/ {
if item.countCached > countSamples && (Tea.IsTesting() || item.timestamp < fasttime.Now().Unix()-600) /** 10 minutes ago **/ {
var isGood = item.countHits*100/item.countCached >= this.goodRatio
if isGood {
item.isGood = true

View File

@@ -58,10 +58,10 @@ func TestNewStat(t *testing.T) {
{
var stat = cachehits.NewStat(5)
for i := 0; i < 10001; i++ {
for i := 0; i < 100001; i++ {
stat.IncreaseCached("a")
}
for i := 0; i < 499; i++ {
for i := 0; i < 4999; i++ {
stat.IncreaseHit("a")
}

View File

@@ -54,6 +54,7 @@ func TestCounter_GC(t *testing.T) {
time.Sleep(1 * time.Second)
counter.Increase(1, 20)
counter.GC()
t.Log(counter.Get(1))
}
func TestCounter_GC2(t *testing.T) {
@@ -62,7 +63,7 @@ func TestCounter_GC2(t *testing.T) {
}
var counter = counters.NewCounter[uint32]().WithGC()
for i := 0; i < 1e5; i++ {
for i := 0; i < 100_000; i++ {
counter.Increase(uint64(i), rands.Int(10, 300))
}
@@ -90,9 +91,22 @@ func TestCounterMemory(t *testing.T) {
var stat1 = &runtime.MemStats{}
runtime.ReadMemStats(stat1)
t.Log((stat1.TotalAlloc-stat.TotalAlloc)/(1<<20), "MB")
t.Log((stat1.HeapInuse-stat.HeapInuse)/(1<<20), "MB")
t.Log(counter.TotalItems())
var gcPause = func() {
var before = time.Now()
runtime.GC()
var costSeconds = time.Since(before).Seconds()
var stats = &debug.GCStats{}
debug.ReadGCStats(stats)
t.Log("GC pause:", stats.PauseTotal.Seconds()*1000, "ms", "cost:", costSeconds*1000, "ms")
}
gcPause()
_ = counter.TotalItems()
}
func BenchmarkCounter_Increase(b *testing.B) {

View File

@@ -7,9 +7,10 @@ import (
)
const spanMaxValue = 10_000_000
const maxSpans = 10
type Item[T SupportedUIntType] struct {
spans []T
spans [maxSpans + 1]T
lastUpdateTime int64
lifeSeconds int64
spanSeconds int64
@@ -19,16 +20,16 @@ func NewItem[T SupportedUIntType](lifeSeconds int) *Item[T] {
if lifeSeconds <= 0 {
lifeSeconds = 60
}
var spanSeconds = lifeSeconds / 10
var spanSeconds = lifeSeconds / maxSpans
if spanSeconds < 1 {
spanSeconds = 1
} else if lifeSeconds > maxSpans && lifeSeconds%maxSpans != 0 {
spanSeconds++
}
var countSpans = lifeSeconds/spanSeconds + 1 /** prevent index out of bounds **/
return &Item[T]{
lifeSeconds: int64(lifeSeconds),
spanSeconds: int64(spanSeconds),
spans: make([]T, countSpans),
lastUpdateTime: fasttime.Now().Unix(),
}
}
@@ -119,5 +120,9 @@ func (this *Item[T]) IsExpired(currentTime int64) bool {
}
func (this *Item[T]) calculateSpanIndex(timestamp int64) int {
return int(timestamp % this.lifeSeconds / this.spanSeconds)
var index = int(timestamp % this.lifeSeconds / this.spanSeconds)
if index > maxSpans-1 {
return maxSpans - 1
}
return index
}

View File

@@ -6,6 +6,7 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/utils/counters"
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
"github.com/iwind/TeaGo/assert"
"github.com/iwind/TeaGo/types"
timeutil "github.com/iwind/TeaGo/utils/time"
"runtime"
"testing"
@@ -41,9 +42,9 @@ func TestItem_Increase2(t *testing.T) {
var a = assert.NewAssertion(t)
var item = counters.NewItem[uint32](20)
var item = counters.NewItem[uint32](23)
for i := 0; i < 100; i++ {
t.Log(item.Increase(), item.Sum(), timeutil.Format("H:i:s"))
t.Log("round "+types.String(i)+":", item.Increase(), item.Sum(), timeutil.Format("H:i:s"))
time.Sleep(2 * time.Second)
}
@@ -56,14 +57,14 @@ func TestItem_IsExpired(t *testing.T) {
return
}
var currentTime = time.Now().Unix()
var item = counters.NewItem[uint32](10)
t.Log(item.IsExpired(currentTime))
t.Log(item.IsExpired(time.Now().Unix()))
time.Sleep(10 * time.Second)
t.Log(item.IsExpired(currentTime))
t.Log(item.IsExpired(time.Now().Unix()))
time.Sleep(2 * time.Second)
t.Log(item.IsExpired(currentTime))
t.Log(item.IsExpired(time.Now().Unix()))
time.Sleep(2 * time.Second)
t.Log(item.IsExpired(time.Now().Unix()))
}
func BenchmarkItem_Increase(b *testing.B) {

View File

@@ -2,7 +2,7 @@
package linkedlist
type List[T any] struct {
type List[T any] struct {
head *Item[T]
end *Item[T]
count int
@@ -36,6 +36,15 @@ func (this *List[T]) Push(item *Item[T]) {
this.add(item)
}
func (this *List[T]) Shift() *Item[T] {
if this.head != nil {
var old = this.head
this.Remove(this.head)
return old
}
return nil
}
func (this *List[T]) Remove(item *Item[T]) {
if item == nil {
return
@@ -71,6 +80,15 @@ func (this *List[T]) Range(f func(item *Item[T]) (goNext bool)) {
}
}
func (this *List[T]) RangeReverse(f func(item *Item[T]) (goNext bool)) {
for e := this.end; e != nil; e = e.prev {
goNext := f(e)
if !goNext {
break
}
}
}
func (this *List[T]) Reset() {
this.head = nil
this.end = nil

View File

@@ -4,6 +4,7 @@ package linkedlist_test
import (
"github.com/TeaOSLab/EdgeNode/internal/utils/linkedlist"
"github.com/iwind/TeaGo/types"
"runtime"
"strconv"
"testing"
@@ -95,6 +96,48 @@ func TestList_Push(t *testing.T) {
})
}
func TestList_Shift(t *testing.T) {
var list = linkedlist.NewList[int]()
list.Push(linkedlist.NewItem(1))
list.Push(linkedlist.NewItem(2))
list.Push(linkedlist.NewItem(3))
list.Push(linkedlist.NewItem(4))
for i := 0; i < 10; i++ {
t.Log("=== before shift " + types.String(i) + " ===")
list.Range(func(item *linkedlist.Item[int]) (goNext bool) {
t.Log(item.Value)
return true
})
t.Logf("shift: %+v", list.Shift())
t.Log("=== after shift " + types.String(i) + " ===")
list.Range(func(item *linkedlist.Item[int]) (goNext bool) {
t.Log(item.Value)
return true
})
}
}
func TestList_RangeReverse(t *testing.T) {
var list = linkedlist.NewList[int]()
list.Push(linkedlist.NewItem(1))
list.Push(linkedlist.NewItem(2))
var item3 = linkedlist.NewItem(3)
list.Push(item3)
list.Push(linkedlist.NewItem(4))
//list.Push(item3)
//list.Remove(item3)
list.RangeReverse(func(item *linkedlist.Item[int]) (goNext bool) {
t.Log(item.Value)
return true
})
}
func BenchmarkList_Add(b *testing.B) {
var list = linkedlist.NewList[int]()
for i := 0; i < b.N; i++ {

View File

@@ -0,0 +1,170 @@
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package runes
// ContainsAnyWordRunes 直接使用rune检查字符串是否包含任一单词
func ContainsAnyWordRunes(s string, words [][]rune, isCaseInsensitive bool) bool {
var allRunes = []rune(s)
if len(allRunes) == 0 || len(words) == 0 {
return false
}
var lastRune rune // last searching rune in s
var lastIndex = -2 // -2: not started, -1: not found, >=0: rune index
for _, wordRunes := range words {
if len(wordRunes) == 0 {
continue
}
if lastIndex > -2 && lastRune == wordRunes[0] {
if lastIndex >= 0 {
result, _ := ContainsWordRunes(allRunes[lastIndex:], wordRunes, isCaseInsensitive)
if result {
return true
}
}
continue
} else {
result, firstIndex := ContainsWordRunes(allRunes, wordRunes, isCaseInsensitive)
lastIndex = firstIndex
if result {
return true
}
}
lastRune = wordRunes[0]
}
return false
}
// ContainsAnyWord 检查字符串是否包含任一单词
func ContainsAnyWord(s string, words []string, isCaseInsensitive bool) bool {
var allRunes = []rune(s)
if len(allRunes) == 0 || len(words) == 0 {
return false
}
var lastRune rune // last searching rune in s
var lastIndex = -2 // -2: not started, -1: not found, >=0: rune index
for _, word := range words {
var wordRunes = []rune(word)
if len(wordRunes) == 0 {
continue
}
if lastIndex > -2 && lastRune == wordRunes[0] {
if lastIndex >= 0 {
result, _ := ContainsWordRunes(allRunes[lastIndex:], wordRunes, isCaseInsensitive)
if result {
return true
}
}
continue
} else {
result, firstIndex := ContainsWordRunes(allRunes, wordRunes, isCaseInsensitive)
lastIndex = firstIndex
if result {
return true
}
}
lastRune = wordRunes[0]
}
return false
}
// ContainsAllWords 检查字符串是否包含所有单词
func ContainsAllWords(s string, words []string, isCaseInsensitive bool) bool {
var allRunes = []rune(s)
if len(allRunes) == 0 || len(words) == 0 {
return false
}
for _, word := range words {
if result, _ := ContainsWordRunes(allRunes, []rune(word), isCaseInsensitive); !result {
return false
}
}
return true
}
// ContainsWordRunes 检查字符列表是否包含某个单词子字符列表
func ContainsWordRunes(allRunes []rune, subRunes []rune, isCaseInsensitive bool) (result bool, firstIndex int) {
firstIndex = -1
var l = len(subRunes)
if l == 0 {
return false, 0
}
var al = len(allRunes)
for index, r := range allRunes {
if EqualRune(r, subRunes[0], isCaseInsensitive) && (index == 0 || !isChar(allRunes[index-1]) /**boundary check **/) {
if firstIndex < 0 {
firstIndex = index
}
var found = true
if l > 1 {
for i := 1; i < l; i++ {
var subIndex = index + i
if subIndex > al-1 || !EqualRune(allRunes[subIndex], subRunes[i], isCaseInsensitive) {
found = false
break
}
}
}
// check after charset
if found && (al <= index+l || !isChar(allRunes[index+l]) /**boundary check **/) {
return true, firstIndex
}
}
}
return false, firstIndex
}
// ContainsSubRunes 检查字符列表是否包含某个子子字符列表
// 与 ContainsWordRunes 不同,这里不需要检查边界符号
func ContainsSubRunes(allRunes []rune, subRunes []rune, isCaseInsensitive bool) bool {
var l = len(subRunes)
if l == 0 {
return false
}
var al = len(allRunes)
for index, r := range allRunes {
if EqualRune(r, subRunes[0], isCaseInsensitive) {
var found = true
if l > 1 {
for i := 1; i < l; i++ {
var subIndex = index + i
if subIndex > al-1 || !EqualRune(allRunes[subIndex], subRunes[i], isCaseInsensitive) {
found = false
break
}
}
}
// check after charset
if found {
return true
}
}
}
return false
}
// EqualRune 判断两个rune是否相同
func EqualRune(r1 rune, r2 rune, isCaseInsensitive bool) bool {
const d = 'a' - 'A'
return r1 == r2 ||
(isCaseInsensitive && r1 >= 'a' && r1 <= 'z' && r1-r2 == d) ||
(isCaseInsensitive && r1 >= 'A' && r1 <= 'Z' && r1-r2 == -d)
}
func isChar(r rune) bool {
return r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z' || r >= '0' && r <= '9'
}

View File

@@ -0,0 +1,172 @@
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package runes_test
import (
"github.com/TeaOSLab/EdgeNode/internal/re"
"github.com/TeaOSLab/EdgeNode/internal/utils/runes"
"github.com/iwind/TeaGo/assert"
"regexp"
"runtime"
"sort"
"strings"
"testing"
)
func TestContainsAllWords(t *testing.T) {
var a = assert.NewAssertion(t)
a.IsTrue(runes.ContainsAllWords("How are you?", []string{"are", "you"}, false))
a.IsFalse(runes.ContainsAllWords("How are you?", []string{"how", "are", "you"}, false))
a.IsTrue(runes.ContainsAllWords("How are you?", []string{"how", "are", "you"}, true))
}
func TestContainsAnyWord(t *testing.T) {
var a = assert.NewAssertion(t)
a.IsTrue(runes.ContainsAnyWord("How are you?", []string{"are", "you"}, false))
a.IsTrue(runes.ContainsAnyWord("How are you?", []string{"are", "you", "ok"}, false))
a.IsFalse(runes.ContainsAnyWord("How are you?", []string{"how", "ok"}, false))
a.IsTrue(runes.ContainsAnyWord("How are you?", []string{"how"}, true))
a.IsTrue(runes.ContainsAnyWord("How are you?", []string{"how", "ok"}, true))
a.IsTrue(runes.ContainsAnyWord("How-are you?", []string{"how", "ok"}, true))
}
func TestContainsAnyWord_Sort(t *testing.T) {
var a = assert.NewAssertion(t)
a.IsTrue(runes.ContainsAnyWord("How are you?", []string{"abc", "ant", "arm", "Hit", "Hi", "Pet", "pie", "are"}, false))
}
func TestContainsWordRunes(t *testing.T) {
var a = assert.NewAssertion(t)
a.IsFalse(runes.ContainsWordRunes([]rune(""), []rune("How"), true))
a.IsFalse(runes.ContainsWordRunes([]rune("How are you?"), []rune(""), true))
a.IsTrue(runes.ContainsWordRunes([]rune("How are you?"), []rune("How"), true))
a.IsFalse(runes.ContainsWordRunes([]rune("How are you?"), []rune("how"), false))
a.IsTrue(runes.ContainsWordRunes([]rune("How are you?"), []rune("you"), false))
a.IsTrue(runes.ContainsWordRunes([]rune("How are you?"), []rune("are"), false))
a.IsFalse(runes.ContainsWordRunes([]rune("How are you?"), []rune("re"), false))
a.IsTrue(runes.ContainsWordRunes([]rune("How are you w?"), []rune("w"), false))
a.IsTrue(runes.ContainsWordRunes([]rune("w How are you?"), []rune("w"), false))
a.IsTrue(runes.ContainsWordRunes([]rune("How are w you?"), []rune("w"), false))
a.IsTrue(runes.ContainsWordRunes([]rune("How are how you?"), []rune("how"), false))
a.IsTrue(runes.ContainsWordRunes([]rune("How are you?"), []rune("how"), true))
a.IsTrue(runes.ContainsWordRunes([]rune("How are you?"), []rune("ARE"), true))
a.IsTrue(runes.ContainsWordRunes([]rune("How are you"), []rune("you"), false))
a.IsTrue(runes.ContainsWordRunes([]rune("How are you"), []rune("YOU"), true))
a.IsTrue(runes.ContainsWordRunes([]rune("How are you?"), []rune("YOU"), true))
a.IsFalse(runes.ContainsWordRunes([]rune("How are you1?"), []rune("YOU"), true))
a.IsFalse(runes.ContainsWordRunes([]rune("How are you1?"), []rune("YOU YOU YOU YOU YOU YOU YOU"), true))
}
func TestContainsSubRunes(t *testing.T) {
var a = assert.NewAssertion(t)
a.IsFalse(runes.ContainsSubRunes([]rune(""), []rune("How"), true))
a.IsFalse(runes.ContainsSubRunes([]rune("How are you?"), []rune(""), true))
a.IsTrue(runes.ContainsSubRunes([]rune("How are you1?"), []rune("YOU"), true))
a.IsTrue(runes.ContainsSubRunes([]rune("How are you1?"), []rune("ow"), false))
a.IsTrue(runes.ContainsSubRunes([]rune("How are you1?"), []rune("H"), false))
a.IsTrue(runes.ContainsSubRunes([]rune("How are you1?"), []rune("How"), false))
a.IsTrue(runes.ContainsSubRunes([]rune("How are you doing"), []rune("oi"), false))
a.IsTrue(runes.ContainsSubRunes([]rune("How are you doing"), []rune("g"), false))
a.IsTrue(runes.ContainsSubRunes([]rune("How are you doing"), []rune("ing"), false))
a.IsFalse(runes.ContainsSubRunes([]rune("How are you doing"), []rune("int"), false))
}
func TestEqualRune(t *testing.T) {
var a = assert.NewAssertion(t)
a.IsTrue(runes.EqualRune('a', 'a', false))
a.IsTrue(runes.EqualRune('a', 'a', true))
a.IsFalse(runes.EqualRune('a', 'A', false))
a.IsTrue(runes.EqualRune('a', 'A', true))
a.IsFalse(runes.EqualRune('c', 'C', false))
a.IsTrue(runes.EqualRune('c', 'C', true))
a.IsTrue(runes.EqualRune('C', 'C', true))
a.IsTrue(runes.EqualRune('C', 'c', true))
a.IsTrue(runes.EqualRune('Z', 'z', true))
a.IsTrue(runes.EqualRune('z', 'Z', true))
a.IsFalse(runes.EqualRune('z', 'z'+('a'-'A'), true))
}
func BenchmarkContainsWordRunes(b *testing.B) {
runtime.GOMAXPROCS(4)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_, _ = runes.ContainsWordRunes([]rune("How are you"), []rune("YOU"), true)
}
})
}
func BenchmarkContainsAnyWord(b *testing.B) {
runtime.GOMAXPROCS(4)
var words = strings.Split("python\npycurl\nhttp-client\nhttpclient\napachebench\nnethttp\nhttp_request\njava\nperl\nruby\nscrapy\nphp\nrust", "\n")
sort.Strings(words)
var wordRunes = [][]rune{}
for _, word := range words {
wordRunes = append(wordRunes, []rune(word))
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = runes.ContainsAnyWord("Mozilla/5.0 (Macintosh; Intel Mac OS X 10_0_0) AppleWebKit/500.00 (KHTML, like Gecko) Chrome/100.0.0.0", words, true)
}
})
}
func BenchmarkContainsAnyWordRunes(b *testing.B) {
runtime.GOMAXPROCS(4)
var words = strings.Split("python\npycurl\nhttp-client\nhttpclient\napachebench\nnethttp\nhttp_request\njava\nperl\nruby\nscrapy\nphp\nrust", "\n")
sort.Strings(words)
var wordRunes = [][]rune{}
for _, word := range words {
wordRunes = append(wordRunes, []rune(word))
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = runes.ContainsAnyWordRunes("Mozilla/5.0 (Macintosh; Intel Mac OS X 10_0_0) AppleWebKit/500.00 (KHTML, like Gecko) Chrome/100.0.0.0", wordRunes, true)
}
})
}
func BenchmarkContainsAnyWord_Regexp(b *testing.B) {
runtime.GOMAXPROCS(4)
var reg = regexp.MustCompile("(?i)" + strings.ReplaceAll("python\npycurl\nhttp-client\nhttpclient\napachebench\nnethttp\nhttp_request\njava\nperl\nruby\nscrapy\nphp\nrust", "\n", "|"))
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = reg.MatchString("Mozilla/5.0 (Macintosh; Intel Mac OS X 10_0_0) AppleWebKit/500.00 (KHTML, like Gecko) Chrome/100.0.0.0")
}
})
}
func BenchmarkContainsAnyWord_Re(b *testing.B) {
runtime.GOMAXPROCS(4)
var reg = re.MustCompile("(?i)" + strings.ReplaceAll("python\npycurl\nhttp-client\nhttpclient\napachebench\nnethttp\nhttp_request\njava\nperl\nruby\nscrapy\nphp\nrust", "\n", "|"))
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = reg.MatchString("Mozilla/5.0 (Macintosh; Intel Mac OS X 10_0_0) AppleWebKit/500.00 (KHTML, like Gecko) Chrome/100.0.0.0")
}
})
}
func BenchmarkContainsSubRunes(b *testing.B) {
runtime.GOMAXPROCS(4)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = runes.ContainsSubRunes([]rune("How are you"), []rune("YOU"), true)
}
})
}

View File

@@ -13,6 +13,6 @@ func setMaxMemory(memoryGB int) {
memoryGB = 1
}
var maxMemoryBytes = (int64(memoryGB) << 30) * 75 / 100 // 默认 75%
var maxMemoryBytes = (int64(memoryGB) << 30) * 80 / 100 // 默认 80%
debug.SetMemoryLimit(maxMemoryBytes)
}

View File

@@ -55,6 +55,8 @@ type CaptchaAction struct {
SlideUIFooter string `yaml:"slideUIFooter" json:"slideUIFooter"` // 页脚
SlideUIBody string `yaml:"slideUIBody" json:"slideUIBody"` // 内容轮廓
GeeTestConfig *firewallconfigs.GeeTestConfig `yaml:"geeTestConfig" json:"geeTestConfig"` // 极验设置 MUST be struct
Lang string `yaml:"lang" json:"lang"` // 语言zh-CN, en-US ...
AddToWhiteList bool `yaml:"addToWhiteList" json:"addToWhiteList"` // 是否加入到白名单
Scope string `yaml:"scope" json:"scope"`
@@ -157,6 +159,7 @@ func (this *CaptchaAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req
// 占用一次失败次数
CaptchaIncreaseFails(req, this, waf.Id, group.Id, set.Id, CaptchaPageCodeInit)
req.DisableStat()
req.ProcessResponseHeaders(writer.Header(), http.StatusTemporaryRedirect)
http.Redirect(writer, req.WAFRaw(), CaptchaPath+"?info="+url.QueryEscape(info), http.StatusTemporaryRedirect)

View File

@@ -67,6 +67,7 @@ func (this *Get302Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, requ
return true, false
}
request.DisableStat()
request.ProcessResponseHeaders(writer.Header(), http.StatusFound)
http.Redirect(writer, request.WAFRaw(), Get302Path+"?info="+url.QueryEscape(info), http.StatusFound)

View File

@@ -36,10 +36,30 @@ func (this *PageAction) WillChange() bool {
// Perform the action
func (this *PageAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) {
if writer == nil {
return
}
request.ProcessResponseHeaders(writer.Header(), this.Status)
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
writer.WriteHeader(this.Status)
_, _ = writer.Write([]byte(request.Format(this.Body)))
var body = this.Body
if len(body) == 0 {
body = `<!DOCTYPE html>
<html lang="en">
<title>403 Forbidden</title>
<style>
address { line-height: 1.8; }
</style>
<body>
<h1>403 Forbidden By WAF</h1>
<address>Connection: ${remoteAddr} (Client) -&gt; ${serverAddr} (Server)</address>
<address>Request ID: ${requestId}</address>
</body>
</html>`
}
_, _ = writer.Write([]byte(request.Format(body)))
return false, false
}

View File

@@ -92,6 +92,7 @@ func (this *Post307Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req
Value: info,
})
request.DisableStat()
request.ProcessResponseHeaders(writer.Header(), http.StatusTemporaryRedirect)
http.Redirect(writer, request.WAFRaw(), request.WAFRaw().URL.String(), http.StatusTemporaryRedirect)

View File

@@ -1,6 +1,7 @@
package waf
package waf_test
import (
"github.com/TeaOSLab/EdgeNode/internal/waf"
"github.com/iwind/TeaGo/assert"
"github.com/iwind/TeaGo/logs"
"github.com/iwind/TeaGo/maps"
@@ -11,22 +12,22 @@ import (
func TestFindActionInstance(t *testing.T) {
a := assert.NewAssertion(t)
t.Logf("ActionBlock: %p", FindActionInstance(ActionBlock, nil))
t.Logf("ActionBlock: %p", FindActionInstance(ActionBlock, nil))
t.Logf("ActionGoGroup: %p", FindActionInstance(ActionGoGroup, nil))
t.Logf("ActionGoGroup: %p", FindActionInstance(ActionGoGroup, nil))
t.Logf("ActionGoSet: %p", FindActionInstance(ActionGoSet, nil))
t.Logf("ActionGoSet: %p", FindActionInstance(ActionGoSet, nil))
t.Logf("ActionGoSet: %#v", FindActionInstance(ActionGoSet, maps.Map{"groupId": "a", "setId": "b"}))
t.Logf("ActionBlock: %p", waf.FindActionInstance(waf.ActionBlock, nil))
t.Logf("ActionBlock: %p", waf.FindActionInstance(waf.ActionBlock, nil))
t.Logf("ActionGoGroup: %p", waf.FindActionInstance(waf.ActionGoGroup, nil))
t.Logf("ActionGoGroup: %p", waf.FindActionInstance(waf.ActionGoGroup, nil))
t.Logf("ActionGoSet: %p", waf.FindActionInstance(waf.ActionGoSet, nil))
t.Logf("ActionGoSet: %p", waf.FindActionInstance(waf.ActionGoSet, nil))
t.Logf("ActionGoSet: %#v", waf.FindActionInstance(waf.ActionGoSet, maps.Map{"groupId": "a", "setId": "b"}))
a.IsTrue(FindActionInstance(ActionGoSet, nil) != FindActionInstance(ActionGoSet, nil))
a.IsTrue(waf.FindActionInstance(waf.ActionGoSet, nil) != waf.FindActionInstance(waf.ActionGoSet, nil))
}
func TestFindActionInstance_Options(t *testing.T) {
//t.Logf("%p", FindActionInstance(ActionBlock, maps.Map{}))
//t.Logf("%p", FindActionInstance(ActionBlock, maps.Map{}))
//logs.PrintAsJSON(FindActionInstance(ActionBlock, maps.Map{}), t)
logs.PrintAsJSON(FindActionInstance(ActionBlock, maps.Map{
logs.PrintAsJSON(waf.FindActionInstance(waf.ActionBlock, maps.Map{
"timeout": 3600,
}), t)
}
@@ -34,6 +35,6 @@ func TestFindActionInstance_Options(t *testing.T) {
func BenchmarkFindActionInstance(b *testing.B) {
runtime.GOMAXPROCS(1)
for i := 0; i < b.N; i++ {
FindActionInstance(ActionGoSet, nil)
waf.FindActionInstance(waf.ActionGoSet, nil)
}
}

View File

@@ -15,6 +15,7 @@ type CaptchaPageCode = string
const (
CaptchaPageCodeInit CaptchaPageCode = "init"
CaptchaPageCodeShow CaptchaPageCode = "show"
CaptchaPageCodeImage CaptchaPageCode = "image"
CaptchaPageCodeSubmit CaptchaPageCode = "submit"
)
@@ -39,6 +40,7 @@ func CaptchaIncreaseFails(req requests.Request, actionConfig *CaptchaAction, pol
func CaptchaDeleteCacheKey(req requests.Request) {
counters.SharedCounter.ResetKey(CaptchaCacheKey(req, CaptchaPageCodeInit))
counters.SharedCounter.ResetKey(CaptchaCacheKey(req, CaptchaPageCodeShow))
counters.SharedCounter.ResetKey(CaptchaCacheKey(req, CaptchaPageCodeImage))
counters.SharedCounter.ResetKey(CaptchaCacheKey(req, CaptchaPageCodeSubmit))
}

View File

@@ -0,0 +1,71 @@
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package waf
import (
"bytes"
"github.com/dchest/captcha"
"github.com/iwind/TeaGo/rands"
"io"
"time"
)
// CaptchaGenerator captcha generator
type CaptchaGenerator struct {
store captcha.Store
}
func NewCaptchaGenerator() *CaptchaGenerator {
return &CaptchaGenerator{
store: captcha.NewMemoryStore(100_000, 5*time.Minute),
}
}
// NewCaptcha create new captcha
func (this *CaptchaGenerator) NewCaptcha(length int) (captchaId string) {
captchaId = rands.HexString(16)
if length <= 0 || length > 20 {
length = 4
}
this.store.Set(captchaId, captcha.RandomDigits(length))
return
}
// WriteImage write image to front writer
func (this *CaptchaGenerator) WriteImage(w io.Writer, id string, width, height int) error {
var d = this.store.Get(id, false)
if d == nil {
return captcha.ErrNotFound
}
_, err := captcha.NewImage(id, d, width, height).WriteTo(w)
return err
}
// Verify user input
func (this *CaptchaGenerator) Verify(id string, digits string) bool {
var countDigits = len(digits)
if countDigits == 0 {
return false
}
var value = this.store.Get(id, true)
if len(value) != countDigits {
return false
}
var nb = make([]byte, countDigits)
for i := 0; i < countDigits; i++ {
var d = digits[i]
if d >= '0' && d <= '9' {
nb[i] = d - '0'
}
}
return bytes.Equal(nb, value)
}
// Get captcha data
func (this *CaptchaGenerator) Get(id string) []byte {
return this.store.Get(id, false)
}

View File

@@ -0,0 +1,87 @@
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package waf_test
import (
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
"github.com/TeaOSLab/EdgeNode/internal/waf"
"github.com/iwind/TeaGo/assert"
"github.com/iwind/TeaGo/types"
"runtime"
"strings"
"testing"
"time"
)
func TestCaptchaGenerator_NewCaptcha(t *testing.T) {
var a = assert.NewAssertion(t)
var generator = waf.NewCaptchaGenerator()
var captchaId = generator.NewCaptcha(6)
t.Log("captchaId:", captchaId)
var digits = generator.Get(captchaId)
var s []string
for _, digit := range digits {
s = append(s, types.String(digit))
}
t.Log(strings.Join(s, " "))
a.IsTrue(generator.Verify(captchaId, strings.Join(s, "")))
a.IsFalse(generator.Verify(captchaId, strings.Join(s, "")))
}
func TestCaptchaGenerator_NewCaptcha_UTF8(t *testing.T) {
var a = assert.NewAssertion(t)
var generator = waf.NewCaptchaGenerator()
var captchaId = generator.NewCaptcha(6)
t.Log("captchaId:", captchaId)
var digits = generator.Get(captchaId)
var s []string
for _, digit := range digits {
s = append(s, types.String(digit))
}
t.Log(strings.Join(s, " "))
a.IsFalse(generator.Verify(captchaId, "中文真的很长"))
}
func TestCaptchaGenerator_NewCaptcha_Memory(t *testing.T) {
runtime.GC()
var stat1 = &runtime.MemStats{}
runtime.ReadMemStats(stat1)
var generator = waf.NewCaptchaGenerator()
for i := 0; i < 1_000_000; i++ {
generator.NewCaptcha(6)
}
if testutils.IsSingleTesting() {
time.Sleep(1 * time.Second)
}
runtime.GC()
var stat2 = &runtime.MemStats{}
runtime.ReadMemStats(stat2)
t.Log((stat2.HeapInuse-stat1.HeapInuse)>>10, "KiB")
_ = generator
}
func BenchmarkNewCaptchaGenerator(b *testing.B) {
runtime.GOMAXPROCS(4)
var generator = waf.NewCaptchaGenerator()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
generator.NewCaptcha(6)
}
})
}

View File

@@ -0,0 +1,70 @@
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package waf_test
import (
"bytes"
"fmt"
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
"github.com/dchest/captcha"
"runtime"
"testing"
"time"
)
func TestCaptchaMemory(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
var stat1 = &runtime.MemStats{}
runtime.ReadMemStats(stat1)
var count = 5_000
var before = time.Now()
for i := 0; i < count; i++ {
var id = captcha.NewLen(6)
var writer = &bytes.Buffer{}
err := captcha.WriteImage(writer, id, 200, 100)
if err != nil {
t.Fatal(err)
}
captcha.VerifyString(id, "abc")
}
var stat2 = &runtime.MemStats{}
runtime.ReadMemStats(stat2)
t.Log((stat2.HeapInuse-stat1.HeapInuse)>>20, "MB", fmt.Sprintf("%.0f QPS", float64(count)/time.Since(before).Seconds()))
}
func BenchmarkCaptcha_VerifyCode_100_50(b *testing.B) {
runtime.GOMAXPROCS(4)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
var id = captcha.NewLen(6)
var writer = &bytes.Buffer{}
err := captcha.WriteImage(writer, id, 100, 50)
if err != nil {
b.Fatal(err)
}
}
})
}
func BenchmarkCaptcha_VerifyCode_200_100(b *testing.B) {
runtime.GOMAXPROCS(4)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
var id = captcha.NewLen(6)
var writer = &bytes.Buffer{}
err := captcha.WriteImage(writer, id, 200, 100)
if err != nil {
b.Fatal(err)
}
_ = id
}
})
}

File diff suppressed because one or more lines are too long

View File

@@ -12,11 +12,11 @@ type RequestAllCheckpoint struct {
}
func (this *RequestAllCheckpoint) RequestValue(req requests.Request, param string, options maps.Map, ruleId int64) (value any, hasRequestBody bool, sysErr error, userErr error) {
var valueBytes = []byte{}
var valueBytes = [][]byte{}
if len(req.WAFRaw().RequestURI) > 0 {
valueBytes = append(valueBytes, req.WAFRaw().RequestURI...)
valueBytes = append(valueBytes, []byte(req.WAFRaw().RequestURI))
} else if req.WAFRaw().URL != nil {
valueBytes = append(valueBytes, req.WAFRaw().URL.RequestURI()...)
valueBytes = append(valueBytes, []byte(req.WAFRaw().URL.RequestURI()))
}
if this.RequestBodyIsEmpty(req) {
@@ -25,8 +25,6 @@ func (this *RequestAllCheckpoint) RequestValue(req requests.Request, param strin
}
if req.WAFRaw().Body != nil {
valueBytes = append(valueBytes, ' ')
var bodyData = req.WAFGetCacheBody()
hasRequestBody = true
if len(bodyData) == 0 {
@@ -39,7 +37,9 @@ func (this *RequestAllCheckpoint) RequestValue(req requests.Request, param strin
req.WAFSetCacheBody(data)
req.WAFRestoreBody(data)
}
valueBytes = append(valueBytes, bodyData...)
if len(bodyData) > 0 {
valueBytes = append(valueBytes, bodyData)
}
}
value = valueBytes

View File

@@ -25,8 +25,14 @@ func TestRequestAllCheckpoint_RequestValue(t *testing.T) {
if userErr != nil {
t.Fatal(userErr)
}
t.Log(v)
t.Log(types.String(v))
if v != nil {
vv, ok := v.([][]byte)
if ok {
for _, v2 := range vv {
t.Log(string(v2), ":", v2)
}
}
}
body, err := io.ReadAll(req.Body)
if err != nil {

View File

@@ -0,0 +1,32 @@
package checkpoints
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"github.com/TeaOSLab/EdgeNode/internal/waf/utils"
"github.com/iwind/TeaGo/maps"
"strings"
)
type RequestHeaderNamesCheckpoint struct {
Checkpoint
}
func (this *RequestHeaderNamesCheckpoint) RequestValue(req requests.Request, param string, options maps.Map, ruleId int64) (value any, hasRequestBody bool, sysErr error, userErr error) {
var headerNames = []string{}
for k := range req.WAFRaw().Header {
headerNames = append(headerNames, k)
}
value = strings.Join(headerNames, "\n")
return
}
func (this *RequestHeaderNamesCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map, ruleId int64) (value any, hasRequestBody bool, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options, ruleId)
}
return
}
func (this *RequestHeaderNamesCheckpoint) CacheLife() utils.CacheLife {
return utils.CacheShortLife
}

View File

@@ -0,0 +1,23 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package checkpoints_test
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/checkpoints"
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"net/http"
"testing"
)
func TestRequestHeaderNamesCheckpoint_RequestValue(t *testing.T) {
var checkpoint = &checkpoints.RequestHeaderNamesCheckpoint{}
rawReq, err := http.NewRequest(http.MethodGet, "https://example.com", nil)
if err != nil {
t.Fatal(err)
}
rawReq.Header.Set("Accept", "text/html")
rawReq.Header.Set("User-Agent", "Chrome")
rawReq.Header.Set("Accept-Encoding", "br, gzip")
var req = requests.NewTestRequest(rawReq)
t.Log(checkpoint.RequestValue(req, "", nil, 0))
}

View File

@@ -23,5 +23,5 @@ func (this *RequestRefererCheckpoint) ResponseValue(req requests.Request, resp *
}
func (this *RequestRefererCheckpoint) CacheLife() utils.CacheLife {
return utils.CacheShortLife
return utils.CacheMiddleLife
}

View File

@@ -0,0 +1,44 @@
package checkpoints
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"github.com/TeaOSLab/EdgeNode/internal/waf/utils"
"github.com/iwind/TeaGo/maps"
)
type RequestRefererOriginCheckpoint struct {
Checkpoint
}
func (this *RequestRefererOriginCheckpoint) RequestValue(req requests.Request, param string, options maps.Map, ruleId int64) (value any, hasRequestBody bool, sysErr error, userErr error) {
var s []string
var referer = req.WAFRaw().Referer()
if len(referer) > 0 {
s = append(s, referer)
}
var origin = req.WAFRaw().Header.Get("Origin")
if len(origin) > 0 {
s = append(s, origin)
}
if len(s) > 0 {
value = s
} else {
value = ""
}
return
}
func (this *RequestRefererOriginCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map, ruleId int64) (value any, hasRequestBody bool, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options, ruleId)
}
return
}
func (this *RequestRefererOriginCheckpoint) CacheLife() utils.CacheLife {
return utils.CacheMiddleLife
}

View File

@@ -0,0 +1,38 @@
package checkpoints_test
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/checkpoints"
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"net/http"
"testing"
)
func TestRequestRefererOriginCheckpoint_RequestValue(t *testing.T) {
rawReq, err := http.NewRequest(http.MethodGet, "https://example.com", nil)
if err != nil {
t.Fatal(err)
}
var req = requests.NewTestRequest(rawReq)
var checkpoint = &checkpoints.RequestRefererOriginCheckpoint{}
{
t.Log(checkpoint.RequestValue(req, "", nil, 0))
}
{
rawReq.Header.Set("Referer", "https://example.com/hello.yaml")
t.Log(checkpoint.RequestValue(req, "", nil, 0))
}
{
rawReq.Header.Set("Origin", "https://example.com/world.yaml")
t.Log(checkpoint.RequestValue(req, "", nil, 0))
}
{
rawReq.Header.Del("Referer")
rawReq.Header.Set("Origin", "https://example.com/world.yaml")
t.Log(checkpoint.RequestValue(req, "", nil, 0))
}
}

View File

@@ -163,7 +163,15 @@ var AllCheckpoints = []*CheckpointDefinition{
Priority: 100,
},
{
Name: "请求来源URL",
Name: "请求来源",
Prefix: "refererOrigin",
Description: "请求报头中的Referer或Origin值",
HasParams: false,
Instance: new(RequestRefererOriginCheckpoint),
Priority: 100,
},
{
Name: "请求来源Referer",
Prefix: "referer",
Description: "请求Header中的Referer值",
HasParams: false,
@@ -226,6 +234,14 @@ var AllCheckpoints = []*CheckpointDefinition{
Instance: new(RequestHeadersCheckpoint),
Priority: 100,
},
{
Name: "所有请求报头名称",
Prefix: "headerNames",
Description: "使用换行符(\\n隔开的报头名称字符串每行一个名称",
HasParams: false,
Instance: new(RequestHeaderNamesCheckpoint),
Priority: 100,
},
{
Name: "单个Header值",
Prefix: "header",

View File

@@ -0,0 +1,32 @@
Copyright (c) 2012-2016, Nick Galbreath
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
https://github.com/client9/libinjection
http://opensource.org/licenses/BSD-3-Clause

View File

@@ -0,0 +1 @@
copy from https://github.com/libinjection/libinjection

View File

@@ -0,0 +1,65 @@
/**
* Copyright 2012-2016 Nick Galbreath
* nickg@client9.com
* BSD License -- see COPYING.txt for details
*
* https://libinjection.client9.com/
*
*/
#ifndef LIBINJECTION_H
#define LIBINJECTION_H
#ifdef __cplusplus
# define LIBINJECTION_BEGIN_DECLS extern "C" {
# define LIBINJECTION_END_DECLS }
#else
# define LIBINJECTION_BEGIN_DECLS
# define LIBINJECTION_END_DECLS
#endif
LIBINJECTION_BEGIN_DECLS
/*
* Pull in size_t
*/
#include <string.h>
/*
* Version info.
*
* This is moved into a function to allow SWIG and other auto-generated
* binding to not be modified during minor release changes. We change
* change the version number in the c source file, and not regenerated
* the binding
*
* See python's normalized version
* http://www.python.org/dev/peps/pep-0386/#normalizedversion
*/
const char* libinjection_version(void);
/**
* Simple API for SQLi detection - returns a SQLi fingerprint or NULL
* is benign input
*
* \param[in] s input string, may contain nulls, does not need to be null-terminated
* \param[in] slen input string length
* \param[out] fingerprint buffer of 8+ characters. c-string,
* \return 1 if SQLi, 0 if benign. fingerprint will be set or set to empty string.
*/
int libinjection_sqli(const char* s, size_t slen, char fingerprint[]);
/** ALPHA version of xss detector.
*
* NOT DONE.
*
* \param[in] s input string, may contain nulls, does not need to be null-terminated
* \param[in] slen input string length
* \return 1 if XSS found, 0 if benign
*
*/
int libinjection_xss(const char* s, size_t slen);
LIBINJECTION_END_DECLS
#endif /* LIBINJECTION_H */

View File

@@ -0,0 +1,868 @@
#include "libinjection_html5.h"
#include <string.h>
#include <assert.h>
#ifdef DEBUG
#include <stdio.h>
#define TRACE() printf("%s:%d\n", __FUNCTION__, __LINE__)
#else
#define TRACE()
#endif
#define CHAR_EOF -1
#define CHAR_NULL 0
#define CHAR_BANG 33
#define CHAR_DOUBLE 34
#define CHAR_PERCENT 37
#define CHAR_SINGLE 39
#define CHAR_DASH 45
#define CHAR_SLASH 47
#define CHAR_LT 60
#define CHAR_EQUALS 61
#define CHAR_GT 62
#define CHAR_QUESTION 63
#define CHAR_RIGHTB 93
#define CHAR_TICK 96
/* prototypes */
static int h5_skip_white(h5_state_t* hs);
static int h5_is_white(char ch);
static int h5_state_eof(h5_state_t* hs);
static int h5_state_data(h5_state_t* hs);
static int h5_state_tag_open(h5_state_t* hs);
static int h5_state_tag_name(h5_state_t* hs);
static int h5_state_tag_name_close(h5_state_t* hs);
static int h5_state_end_tag_open(h5_state_t* hs);
static int h5_state_self_closing_start_tag(h5_state_t* hs);
static int h5_state_attribute_name(h5_state_t* hs);
static int h5_state_after_attribute_name(h5_state_t* hs);
static int h5_state_before_attribute_name(h5_state_t* hs);
static int h5_state_before_attribute_value(h5_state_t* hs);
static int h5_state_attribute_value_double_quote(h5_state_t* hs);
static int h5_state_attribute_value_single_quote(h5_state_t* hs);
static int h5_state_attribute_value_back_quote(h5_state_t* hs);
static int h5_state_attribute_value_no_quote(h5_state_t* hs);
static int h5_state_after_attribute_value_quoted_state(h5_state_t* hs);
static int h5_state_comment(h5_state_t* hs);
static int h5_state_cdata(h5_state_t* hs);
/* 12.2.4.44 */
static int h5_state_bogus_comment(h5_state_t* hs);
static int h5_state_bogus_comment2(h5_state_t* hs);
/* 12.2.4.45 */
static int h5_state_markup_declaration_open(h5_state_t* hs);
/* 8.2.4.52 */
static int h5_state_doctype(h5_state_t* hs);
/**
* public function
*/
void libinjection_h5_init(h5_state_t* hs, const char* s, size_t len, enum html5_flags flags)
{
memset(hs, 0, sizeof(h5_state_t));
hs->s = s;
hs->len = len;
switch (flags) {
case DATA_STATE:
hs->state = h5_state_data;
break;
case VALUE_NO_QUOTE:
hs->state = h5_state_before_attribute_name;
break;
case VALUE_SINGLE_QUOTE:
hs->state = h5_state_attribute_value_single_quote;
break;
case VALUE_DOUBLE_QUOTE:
hs->state = h5_state_attribute_value_double_quote;
break;
case VALUE_BACK_QUOTE:
hs->state = h5_state_attribute_value_back_quote;
break;
}
}
/**
* public function
*/
int libinjection_h5_next(h5_state_t* hs)
{
assert(hs->state != NULL);
return (*hs->state)(hs);
}
/**
* Everything below here is private
*
*/
static int h5_is_white(char ch)
{
/*
* \t = horizontal tab = 0x09
* \n = newline = 0x0A
* \v = vertical tab = 0x0B
* \f = form feed = 0x0C
* \r = cr = 0x0D
*/
return strchr(" \t\n\v\f\r", ch) != NULL;
}
static int h5_skip_white(h5_state_t* hs)
{
char ch;
while (hs->pos < hs->len) {
ch = hs->s[hs->pos];
switch (ch) {
case 0x00: /* IE only */
case 0x20:
case 0x09:
case 0x0A:
case 0x0B: /* IE only */
case 0x0C:
case 0x0D: /* IE only */
hs->pos += 1;
break;
default:
return ch;
}
}
return CHAR_EOF;
}
static int h5_state_eof(h5_state_t* hs)
{
/* eliminate unused function argument warning */
(void)hs;
return 0;
}
static int h5_state_data(h5_state_t* hs)
{
const char* idx;
TRACE();
assert(hs->len >= hs->pos);
idx = (const char*) memchr(hs->s + hs->pos, CHAR_LT, hs->len - hs->pos);
if (idx == NULL) {
hs->token_start = hs->s + hs->pos;
hs->token_len = hs->len - hs->pos;
hs->token_type = DATA_TEXT;
hs->state = h5_state_eof;
if (hs->token_len == 0) {
return 0;
}
} else {
hs->token_start = hs->s + hs->pos;
hs->token_type = DATA_TEXT;
hs->token_len = (size_t)(idx - hs->s) - hs->pos;
hs->pos = (size_t)(idx - hs->s) + 1;
hs->state = h5_state_tag_open;
if (hs->token_len == 0) {
return h5_state_tag_open(hs);
}
}
return 1;
}
/**
* 12 2.4.8
*/
static int h5_state_tag_open(h5_state_t* hs)
{
char ch;
TRACE();
if (hs->pos >= hs->len) {
return 0;
}
ch = hs->s[hs->pos];
if (ch == CHAR_BANG) {
hs->pos += 1;
return h5_state_markup_declaration_open(hs);
} else if (ch == CHAR_SLASH) {
hs->pos += 1;
hs->is_close = 1;
return h5_state_end_tag_open(hs);
} else if (ch == CHAR_QUESTION) {
hs->pos += 1;
return h5_state_bogus_comment(hs);
} else if (ch == CHAR_PERCENT) {
/* this is not in spec.. alternative comment format used
by IE <= 9 and Safari < 4.0.3 */
hs->pos += 1;
return h5_state_bogus_comment2(hs);
} else if ((ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z')) {
return h5_state_tag_name(hs);
} else if (ch == CHAR_NULL) {
/* IE-ism NULL characters are ignored */
return h5_state_tag_name(hs);
} else {
/* user input mistake in configuring state */
if (hs->pos == 0) {
return h5_state_data(hs);
}
hs->token_start = hs->s + hs->pos - 1;
hs->token_len = 1;
hs->token_type = DATA_TEXT;
hs->state = h5_state_data;
return 1;
}
}
/**
* 12.2.4.9
*/
static int h5_state_end_tag_open(h5_state_t* hs)
{
char ch;
TRACE();
if (hs->pos >= hs->len) {
return 0;
}
ch = hs->s[hs->pos];
if (ch == CHAR_GT) {
return h5_state_data(hs);
} else if ((ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z')) {
return h5_state_tag_name(hs);
}
hs->is_close = 0;
return h5_state_bogus_comment(hs);
}
/*
*
*/
static int h5_state_tag_name_close(h5_state_t* hs)
{
TRACE();
hs->is_close = 0;
hs->token_start = hs->s + hs->pos;
hs->token_len = 1;
hs->token_type = TAG_NAME_CLOSE;
hs->pos += 1;
if (hs->pos < hs->len) {
hs->state = h5_state_data;
} else {
hs->state = h5_state_eof;
}
return 1;
}
/**
* 12.2.4.10
*/
static int h5_state_tag_name(h5_state_t* hs)
{
char ch;
size_t pos;
TRACE();
pos = hs->pos;
while (pos < hs->len) {
ch = hs->s[pos];
if (ch == 0) {
/* special non-standard case */
/* allow nulls in tag name */
/* some old browsers apparently allow and ignore them */
pos += 1;
} else if (h5_is_white(ch)) {
hs->token_start = hs->s + hs->pos;
hs->token_len = pos - hs->pos;
hs->token_type = TAG_NAME_OPEN;
hs->pos = pos + 1;
hs->state = h5_state_before_attribute_name;
return 1;
} else if (ch == CHAR_SLASH) {
hs->token_start = hs->s + hs->pos;
hs->token_len = pos - hs->pos;
hs->token_type = TAG_NAME_OPEN;
hs->pos = pos + 1;
hs->state = h5_state_self_closing_start_tag;
return 1;
} else if (ch == CHAR_GT) {
hs->token_start = hs->s + hs->pos;
hs->token_len = pos - hs->pos;
if (hs->is_close) {
hs->pos = pos + 1;
hs->is_close = 0;
hs->token_type = TAG_CLOSE;
hs->state = h5_state_data;
} else {
hs->pos = pos;
hs->token_type = TAG_NAME_OPEN;
hs->state = h5_state_tag_name_close;
}
return 1;
} else {
pos += 1;
}
}
hs->token_start = hs->s + hs->pos;
hs->token_len = hs->len - hs->pos;
hs->token_type = TAG_NAME_OPEN;
hs->state = h5_state_eof;
return 1;
}
/**
* 12.2.4.34
*/
static int h5_state_before_attribute_name(h5_state_t* hs)
{
int ch;
TRACE();
/* for manual tail call optimization, see comment below */
tail_call:;
ch = h5_skip_white(hs);
switch (ch) {
case CHAR_EOF: {
return 0;
}
case CHAR_SLASH: {
hs->pos += 1;
/* Logically, We want to call h5_state_self_closing_start_tag(hs) here.
As this function may call us back and the compiler
might not implement automatic tail call optimization,
this might result in a deep recursion.
We detect this case here and start over with the current state.
*/
if (hs->pos < hs->len && hs->s[hs->pos] != CHAR_GT) {
goto tail_call;
}
return h5_state_self_closing_start_tag(hs);
}
case CHAR_GT: {
hs->state = h5_state_data;
hs->token_start = hs->s + hs->pos;
hs->token_len = 1;
hs->token_type = TAG_NAME_CLOSE;
hs->pos += 1;
return 1;
}
default: {
return h5_state_attribute_name(hs);
}
}
}
static int h5_state_attribute_name(h5_state_t* hs)
{
char ch;
size_t pos;
TRACE();
pos = hs->pos + 1;
while (pos < hs->len) {
ch = hs->s[pos];
if (h5_is_white(ch)) {
hs->token_start = hs->s + hs->pos;
hs->token_len = pos - hs->pos;
hs->token_type = ATTR_NAME;
hs->state = h5_state_after_attribute_name;
hs->pos = pos + 1;
return 1;
} else if (ch == CHAR_SLASH) {
hs->token_start = hs->s + hs->pos;
hs->token_len = pos - hs->pos;
hs->token_type = ATTR_NAME;
hs->state = h5_state_self_closing_start_tag;
hs->pos = pos + 1;
return 1;
} else if (ch == CHAR_EQUALS) {
hs->token_start = hs->s + hs->pos;
hs->token_len = pos - hs->pos;
hs->token_type = ATTR_NAME;
hs->state = h5_state_before_attribute_value;
hs->pos = pos + 1;
return 1;
} else if (ch == CHAR_GT) {
hs->token_start = hs->s + hs->pos;
hs->token_len = pos - hs->pos;
hs->token_type = ATTR_NAME;
hs->state = h5_state_tag_name_close;
hs->pos = pos;
return 1;
} else {
pos += 1;
}
}
/* EOF */
hs->token_start = hs->s + hs->pos;
hs->token_len = hs->len - hs->pos;
hs->token_type = ATTR_NAME;
hs->state = h5_state_eof;
hs->pos = hs->len;
return 1;
}
/**
* 12.2.4.36
*/
static int h5_state_after_attribute_name(h5_state_t* hs)
{
int c;
TRACE();
c = h5_skip_white(hs);
switch (c) {
case CHAR_EOF: {
return 0;
}
case CHAR_SLASH: {
hs->pos += 1;
return h5_state_self_closing_start_tag(hs);
}
case CHAR_EQUALS: {
hs->pos += 1;
return h5_state_before_attribute_value(hs);
}
case CHAR_GT: {
return h5_state_tag_name_close(hs);
}
default: {
return h5_state_attribute_name(hs);
}
}
}
/**
* 12.2.4.37
*/
static int h5_state_before_attribute_value(h5_state_t* hs)
{
int c;
TRACE();
c = h5_skip_white(hs);
if (c == CHAR_EOF) {
hs->state = h5_state_eof;
return 0;
}
if (c == CHAR_DOUBLE) {
return h5_state_attribute_value_double_quote(hs);
} else if (c == CHAR_SINGLE) {
return h5_state_attribute_value_single_quote(hs);
} else if (c == CHAR_TICK) {
/* NON STANDARD IE */
return h5_state_attribute_value_back_quote(hs);
} else {
return h5_state_attribute_value_no_quote(hs);
}
}
static int h5_state_attribute_value_quote(h5_state_t* hs, char qchar)
{
const char* idx;
TRACE();
/* skip initial quote in normal case.
* don't do this "if (pos == 0)" since it means we have started
* in a non-data state. given an input of '><foo
* we want to make 0-length attribute name
*/
if (hs->pos > 0) {
hs->pos += 1;
}
idx = (const char*) memchr(hs->s + hs->pos, qchar, hs->len - hs->pos);
if (idx == NULL) {
hs->token_start = hs->s + hs->pos;
hs->token_len = hs->len - hs->pos;
hs->token_type = ATTR_VALUE;
hs->state = h5_state_eof;
} else {
hs->token_start = hs->s + hs->pos;
hs->token_len = (size_t)(idx - hs->s) - hs->pos;
hs->token_type = ATTR_VALUE;
hs->state = h5_state_after_attribute_value_quoted_state;
hs->pos += hs->token_len + 1;
}
return 1;
}
static
int h5_state_attribute_value_double_quote(h5_state_t* hs)
{
TRACE();
return h5_state_attribute_value_quote(hs, CHAR_DOUBLE);
}
static
int h5_state_attribute_value_single_quote(h5_state_t* hs)
{
TRACE();
return h5_state_attribute_value_quote(hs, CHAR_SINGLE);
}
static
int h5_state_attribute_value_back_quote(h5_state_t* hs)
{
TRACE();
return h5_state_attribute_value_quote(hs, CHAR_TICK);
}
static int h5_state_attribute_value_no_quote(h5_state_t* hs)
{
char ch;
size_t pos;
TRACE();
pos = hs->pos;
while (pos < hs->len) {
ch = hs->s[pos];
if (h5_is_white(ch)) {
hs->token_type = ATTR_VALUE;
hs->token_start = hs->s + hs->pos;
hs->token_len = pos - hs->pos;
hs->pos = pos + 1;
hs->state = h5_state_before_attribute_name;
return 1;
} else if (ch == CHAR_GT) {
hs->token_type = ATTR_VALUE;
hs->token_start = hs->s + hs->pos;
hs->token_len = pos - hs->pos;
hs->pos = pos;
hs->state = h5_state_tag_name_close;
return 1;
}
pos += 1;
}
TRACE();
/* EOF */
hs->state = h5_state_eof;
hs->token_start = hs->s + hs->pos;
hs->token_len = hs->len - hs->pos;
hs->token_type = ATTR_VALUE;
return 1;
}
/**
* 12.2.4.41
*/
static int h5_state_after_attribute_value_quoted_state(h5_state_t* hs)
{
char ch;
TRACE();
if (hs->pos >= hs->len) {
return 0;
}
ch = hs->s[hs->pos];
if (h5_is_white(ch)) {
hs->pos += 1;
return h5_state_before_attribute_name(hs);
} else if (ch == CHAR_SLASH) {
hs->pos += 1;
return h5_state_self_closing_start_tag(hs);
} else if (ch == CHAR_GT) {
hs->token_start = hs->s + hs->pos;
hs->token_len = 1;
hs->token_type = TAG_NAME_CLOSE;
hs->pos += 1;
hs->state = h5_state_data;
return 1;
} else {
return h5_state_before_attribute_name(hs);
}
}
/**
* 12.2.4.43
*
* WARNING: This function is partially inlined into h5_state_before_attribute_name()
*/
static int h5_state_self_closing_start_tag(h5_state_t* hs)
{
char ch;
TRACE();
if (hs->pos >= hs->len) {
return 0;
}
ch = hs->s[hs->pos];
if (ch == CHAR_GT) {
assert(hs->pos > 0);
hs->token_start = hs->s + hs->pos -1;
hs->token_len = 2;
hs->token_type = TAG_NAME_SELFCLOSE;
hs->state = h5_state_data;
hs->pos += 1;
return 1;
} else {
return h5_state_before_attribute_name(hs);
}
}
/**
* 12.2.4.44
*/
static int h5_state_bogus_comment(h5_state_t* hs)
{
const char* idx;
TRACE();
idx = (const char*) memchr(hs->s + hs->pos, CHAR_GT, hs->len - hs->pos);
if (idx == NULL) {
hs->token_start = hs->s + hs->pos;
hs->token_len = hs->len - hs->pos;
hs->pos = hs->len;
hs->state = h5_state_eof;
} else {
hs->token_start = hs->s + hs->pos;
hs->token_len = (size_t)(idx - hs->s) - hs->pos;
hs->pos = (size_t)(idx - hs->s) + 1;
hs->state = h5_state_data;
}
hs->token_type = TAG_COMMENT;
return 1;
}
/**
* 12.2.4.44 ALT
*/
static int h5_state_bogus_comment2(h5_state_t* hs)
{
const char* idx;
size_t pos;
TRACE();
pos = hs->pos;
while (1) {
idx = (const char*) memchr(hs->s + pos, CHAR_PERCENT, hs->len - pos);
if (idx == NULL || (idx + 1 >= hs->s + hs->len)) {
hs->token_start = hs->s + hs->pos;
hs->token_len = hs->len - hs->pos;
hs->pos = hs->len;
hs->token_type = TAG_COMMENT;
hs->state = h5_state_eof;
return 1;
}
if (*(idx +1) != CHAR_GT) {
pos = (size_t)(idx - hs->s) + 1;
continue;
}
/* ends in %> */
hs->token_start = hs->s + hs->pos;
hs->token_len = (size_t)(idx - hs->s) - hs->pos;
hs->pos = (size_t)(idx - hs->s) + 2;
hs->state = h5_state_data;
hs->token_type = TAG_COMMENT;
return 1;
}
}
/**
* 8.2.4.45
*/
static int h5_state_markup_declaration_open(h5_state_t* hs)
{
size_t remaining;
TRACE();
remaining = hs->len - hs->pos;
if (remaining >= 7 &&
/* case insensitive */
(hs->s[hs->pos + 0] == 'D' || hs->s[hs->pos + 0] == 'd') &&
(hs->s[hs->pos + 1] == 'O' || hs->s[hs->pos + 1] == 'o') &&
(hs->s[hs->pos + 2] == 'C' || hs->s[hs->pos + 2] == 'c') &&
(hs->s[hs->pos + 3] == 'T' || hs->s[hs->pos + 3] == 't') &&
(hs->s[hs->pos + 4] == 'Y' || hs->s[hs->pos + 4] == 'y') &&
(hs->s[hs->pos + 5] == 'P' || hs->s[hs->pos + 5] == 'p') &&
(hs->s[hs->pos + 6] == 'E' || hs->s[hs->pos + 6] == 'e')
) {
return h5_state_doctype(hs);
} else if (remaining >= 7 &&
/* upper case required */
hs->s[hs->pos + 0] == '[' &&
hs->s[hs->pos + 1] == 'C' &&
hs->s[hs->pos + 2] == 'D' &&
hs->s[hs->pos + 3] == 'A' &&
hs->s[hs->pos + 4] == 'T' &&
hs->s[hs->pos + 5] == 'A' &&
hs->s[hs->pos + 6] == '['
) {
hs->pos += 7;
return h5_state_cdata(hs);
} else if (remaining >= 2 &&
hs->s[hs->pos + 0] == '-' &&
hs->s[hs->pos + 1] == '-') {
hs->pos += 2;
return h5_state_comment(hs);
}
return h5_state_bogus_comment(hs);
}
/**
* 12.2.4.48
* 12.2.4.49
* 12.2.4.50
* 12.2.4.51
* state machine spec is confusing since it can only look
* at one character at a time but simply it's comments end by:
* 1) EOF
* 2) ending in -->
* 3) ending in -!>
*/
static int h5_state_comment(h5_state_t* hs)
{
char ch;
const char* idx;
size_t pos;
size_t offset;
const char* end = hs->s + hs->len;
TRACE();
pos = hs->pos;
while (1) {
idx = (const char*) memchr(hs->s + pos, CHAR_DASH, hs->len - pos);
/* did not find anything or has less than 3 chars left */
if (idx == NULL || idx > hs->s + hs->len - 3) {
hs->state = h5_state_eof;
hs->token_start = hs->s + hs->pos;
hs->token_len = hs->len - hs->pos;
hs->token_type = TAG_COMMENT;
return 1;
}
offset = 1;
/* skip all nulls */
while (idx + offset < end && *(idx + offset) == 0) {
offset += 1;
}
if (idx + offset == end) {
hs->state = h5_state_eof;
hs->token_start = hs->s + hs->pos;
hs->token_len = hs->len - hs->pos;
hs->token_type = TAG_COMMENT;
return 1;
}
ch = *(idx + offset);
if (ch != CHAR_DASH && ch != CHAR_BANG) {
pos = (size_t)(idx - hs->s) + 1;
continue;
}
/* need to test */
#if 0
/* skip all nulls */
while (idx + offset < end && *(idx + offset) == 0) {
offset += 1;
}
if (idx + offset == end) {
hs->state = h5_state_eof;
hs->token_start = hs->s + hs->pos;
hs->token_len = hs->len - hs->pos;
hs->token_type = TAG_COMMENT;
return 1;
}
#endif
offset += 1;
if (idx + offset == end) {
hs->state = h5_state_eof;
hs->token_start = hs->s + hs->pos;
hs->token_len = hs->len - hs->pos;
hs->token_type = TAG_COMMENT;
return 1;
}
ch = *(idx + offset);
if (ch != CHAR_GT) {
pos = (size_t)(idx - hs->s) + 1;
continue;
}
offset += 1;
/* ends in --> or -!> */
hs->token_start = hs->s + hs->pos;
hs->token_len = (size_t)(idx - hs->s) - hs->pos;
hs->pos = (size_t)(idx + offset - hs->s);
hs->state = h5_state_data;
hs->token_type = TAG_COMMENT;
return 1;
}
}
static int h5_state_cdata(h5_state_t* hs)
{
const char* idx;
size_t pos;
TRACE();
pos = hs->pos;
while (1) {
idx = (const char*) memchr(hs->s + pos, CHAR_RIGHTB, hs->len - pos);
/* did not find anything or has less than 3 chars left */
if (idx == NULL || idx > hs->s + hs->len - 3) {
hs->state = h5_state_eof;
hs->token_start = hs->s + hs->pos;
hs->token_len = hs->len - hs->pos;
hs->token_type = DATA_TEXT;
return 1;
} else if ( *(idx+1) == CHAR_RIGHTB && *(idx+2) == CHAR_GT) {
hs->state = h5_state_data;
hs->token_start = hs->s + hs->pos;
hs->token_len = (size_t)(idx - hs->s) - hs->pos;
hs->pos = (size_t)(idx - hs->s) + 3;
hs->token_type = DATA_TEXT;
return 1;
} else {
pos = (size_t)(idx - hs->s) + 1;
}
}
}
/**
* 8.2.4.52
* http://www.w3.org/html/wg/drafts/html/master/syntax.html#doctype-state
*/
static int h5_state_doctype(h5_state_t* hs)
{
const char* idx;
TRACE();
hs->token_start = hs->s + hs->pos;
hs->token_type = DOCTYPE;
idx = (const char*) memchr(hs->s + hs->pos, CHAR_GT, hs->len - hs->pos);
if (idx == NULL) {
hs->state = h5_state_eof;
hs->token_len = hs->len - hs->pos;
} else {
hs->state = h5_state_data;
hs->token_len = (size_t)(idx - hs->s) - hs->pos;
hs->pos = (size_t)(idx - hs->s) + 1;
}
return 1;
}

View File

@@ -0,0 +1,54 @@
#ifndef LIBINJECTION_HTML5
#define LIBINJECTION_HTML5
#ifdef __cplusplus
extern "C" {
#endif
/* pull in size_t */
#include <stddef.h>
enum html5_type {
DATA_TEXT
, TAG_NAME_OPEN
, TAG_NAME_CLOSE
, TAG_NAME_SELFCLOSE
, TAG_DATA
, TAG_CLOSE
, ATTR_NAME
, ATTR_VALUE
, TAG_COMMENT
, DOCTYPE
};
enum html5_flags {
DATA_STATE
, VALUE_NO_QUOTE
, VALUE_SINGLE_QUOTE
, VALUE_DOUBLE_QUOTE
, VALUE_BACK_QUOTE
};
struct h5_state;
typedef int (*ptr_html5_state)(struct h5_state*);
typedef struct h5_state {
const char* s;
size_t len;
size_t pos;
int is_close;
ptr_html5_state state;
const char* token_start;
size_t token_len;
enum html5_type token_type;
} h5_state_t;
void libinjection_h5_init(h5_state_t* hs, const char* s, size_t len, enum html5_flags);
int libinjection_h5_next(h5_state_t* hs);
#ifdef __cplusplus
}
#endif
#endif

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,294 @@
/**
* Copyright 2012-2016 Nick Galbreath
* nickg@client9.com
* BSD License -- see `COPYING.txt` for details
*
* https://libinjection.client9.com/
*
*/
#ifndef LIBINJECTION_SQLI_H
#define LIBINJECTION_SQLI_H
#ifdef __cplusplus
extern "C" {
#endif
/*
* Pull in size_t
*/
#include <string.h>
enum sqli_flags {
FLAG_NONE = 0
, FLAG_QUOTE_NONE = 1 /* 1 << 0 */
, FLAG_QUOTE_SINGLE = 2 /* 1 << 1 */
, FLAG_QUOTE_DOUBLE = 4 /* 1 << 2 */
, FLAG_SQL_ANSI = 8 /* 1 << 3 */
, FLAG_SQL_MYSQL = 16 /* 1 << 4 */
};
enum lookup_type {
LOOKUP_WORD = 1
, LOOKUP_TYPE = 2
, LOOKUP_OPERATOR = 3
, LOOKUP_FINGERPRINT = 4
};
struct libinjection_sqli_token {
#ifdef SWIG
%immutable;
#endif
/*
* position and length of token
* in original string
*/
size_t pos;
size_t len;
/* count:
* in type 'v', used for number of opening '@'
* but maybe used in other contexts
*/
int count;
char type;
char str_open;
char str_close;
char val[32];
};
typedef struct libinjection_sqli_token stoken_t;
/**
* Pointer to function, takes c-string input,
* returns '\0' for no match, else a char
*/
struct libinjection_sqli_state;
typedef char (*ptr_lookup_fn)(struct libinjection_sqli_state*, int lookuptype, const char* word, size_t len);
struct libinjection_sqli_state {
#ifdef SWIG
%immutable;
#endif
/*
* input, does not need to be null terminated.
* it is also not modified.
*/
const char *s;
/*
* input length
*/
size_t slen;
/*
* How to lookup a word or fingerprint
*/
ptr_lookup_fn lookup;
void* userdata;
/*
*
*/
int flags;
/*
* pos is the index in the string during tokenization
*/
size_t pos;
#ifndef SWIG
/* for SWIG.. don't use this.. use functional API instead */
/* MAX TOKENS + 1 since we use one extra token
* to determine the type of the previous token
*/
struct libinjection_sqli_token tokenvec[8];
#endif
/*
* Pointer to token position in tokenvec, above
*/
struct libinjection_sqli_token *current;
/*
* fingerprint pattern c-string
* +1 for ending null
* Minimum of 8 bytes to add gcc's -fstack-protector to work
*/
char fingerprint[8];
/*
* Line number of code that said decided if the input was SQLi or
* not. Most of the time it's line that said "it's not a matching
* fingerprint" but there is other logic that sometimes approves
* an input. This is only useful for debugging.
*
*/
int reason;
/* Number of ddw (dash-dash-white) comments
* These comments are in the form of
* '--[whitespace]' or '--[EOF]'
*
* All databases treat this as a comment.
*/
int stats_comment_ddw;
/* Number of ddx (dash-dash-[notwhite]) comments
*
* ANSI SQL treats these are comments, MySQL treats this as
* two unary operators '-' '-'
*
* If you are parsing result returns FALSE and
* stats_comment_dd > 0, you should reparse with
* COMMENT_MYSQL
*
*/
int stats_comment_ddx;
/*
* c-style comments found /x .. x/
*/
int stats_comment_c;
/* '#' operators or MySQL EOL comments found
*
*/
int stats_comment_hash;
/*
* number of tokens folded away
*/
int stats_folds;
/*
* total tokens processed
*/
int stats_tokens;
};
typedef struct libinjection_sqli_state sfilter;
struct libinjection_sqli_token* libinjection_sqli_get_token(
struct libinjection_sqli_state* sql_state, int i);
/*
* Version info.
*
* This is moved into a function to allow SWIG and other auto-generated
* binding to not be modified during minor release changes. We change
* change the version number in the c source file, and not regenerated
* the binding
*
* See python's normalized version
* http://www.python.org/dev/peps/pep-0386/#normalizedversion
*/
const char* libinjection_version(void);
/**
*
*/
void libinjection_sqli_init(struct libinjection_sqli_state *sf,
const char* s, size_t len,
int flags);
/**
* Main API: tests for SQLi in three possible contexts, no quotes,
* single quote and double quote
*
* \param sql_state core data structure
*
* \return 1 (true) if SQLi, 0 (false) if benign
*/
int libinjection_is_sqli(struct libinjection_sqli_state* sql_state);
/* FOR HACKERS ONLY
* provides deep hooks into the decision making process
*/
void libinjection_sqli_callback(struct libinjection_sqli_state *sf,
ptr_lookup_fn fn,
void* userdata);
/*
* Resets state, but keeps initial string and callbacks
*/
void libinjection_sqli_reset(struct libinjection_sqli_state *sf,
int flags);
/**
*
*/
/**
* This detects SQLi in a single context, mostly useful for custom
* logic and debugging.
*
* \param sql_state Main data structure
* \param flags flags to adjust parsing
*
* \returns a pointer to sfilter.fingerprint as convenience
* do not free!
*
*/
const char* libinjection_sqli_fingerprint(struct libinjection_sqli_state *sql_state,
int flags);
/**
* The default "word" to token-type or fingerprint function. This
* uses a ASCII case-insensitive binary tree.
*/
char libinjection_sqli_lookup_word(struct libinjection_sqli_state *sql_state,
int lookup_type,
const char* str,
size_t len);
/* Streaming tokenization interface.
*
* sql_state->current is updated with the current token.
*
* \returns 1, has a token, keep going, or 0 no tokens
*
*/
int libinjection_sqli_tokenize(struct libinjection_sqli_state *sf);
/**
* parses and folds input, up to 5 tokens
*
*/
int libinjection_sqli_fold(struct libinjection_sqli_state *sf);
/** The built-in default function to match fingerprints
* and do false negative/positive analysis. This calls the following
* two functions. With this, you over-ride one part or the other.
*
* return libinjection_sqli_blacklist(sql_state) &&
* libinjection_sqli_not_whitelist(sql_state);
*
* \param sql_state should be filled out after libinjection_sqli_fingerprint is called
*/
int libinjection_sqli_check_fingerprint(struct libinjection_sqli_state * sql_state);
/* Given a pattern determine if it's a SQLi pattern.
*
* \return TRUE if sqli, false otherwise
*/
int libinjection_sqli_blacklist(struct libinjection_sqli_state* sql_state);
/* Given a positive match for a pattern (i.e. pattern is SQLi), this function
* does additional analysis to reduce false positives.
*
* \return TRUE if SQLi, false otherwise
*/
int libinjection_sqli_not_whitelist(struct libinjection_sqli_state * sql_state);
#ifdef __cplusplus
}
#endif
#endif /* LIBINJECTION_SQLI_H */

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,857 @@
#include "libinjection.h"
#include "libinjection_xss.h"
#include "libinjection_html5.h"
#include <assert.h>
#include <stdio.h>
typedef enum attribute {
TYPE_NONE
, TYPE_BLACK /* ban always */
, TYPE_ATTR_URL /* attribute value takes a URL-like object */
, TYPE_STYLE
, TYPE_ATTR_INDIRECT /* attribute *name* is given in *value* */
} attribute_t;
static attribute_t is_black_attr(const char* s, size_t len);
static int is_black_tag(const char* s, size_t len);
static int is_black_url(const char* s, size_t len);
static int cstrcasecmp_with_null(const char *a, const char *b, size_t n);
static int html_decode_char_at(const char* src, size_t len, size_t* consumed);
static int htmlencode_startswith(const char *a/* prefix */, const char *b /* src */, size_t n);
typedef struct stringtype {
const char* name;
attribute_t atype;
} stringtype_t;
static const int gsHexDecodeMap[256] = {
256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256,
256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256,
256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256,
256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256,
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 256, 256,
256, 256, 256, 256, 256, 10, 11, 12, 13, 14, 15, 256,
256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256,
256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256,
256, 10, 11, 12, 13, 14, 15, 256, 256, 256, 256, 256,
256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256,
256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256,
256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256,
256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256,
256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256,
256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256,
256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256,
256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256,
256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256,
256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256,
256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256,
256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256,
256, 256, 256, 256
};
static int html_decode_char_at(const char* src, size_t len, size_t* consumed)
{
int val = 0;
size_t i;
int ch;
if (len == 0 || src == NULL) {
*consumed = 0;
return -1;
}
*consumed = 1;
if (*src != '&' || len < 2) {
return (unsigned char)(*src);
}
if (*(src+1) != '#') {
/* normally this would be for named entities
* but for this case we don't actually care
*/
return '&';
}
if (*(src+2) == 'x' || *(src+2) == 'X') {
ch = (unsigned char) (*(src+3));
ch = gsHexDecodeMap[ch];
if (ch == 256) {
/* degenerate case '&#[?]' */
return '&';
}
val = ch;
i = 4;
while (i < len) {
ch = (unsigned char) src[i];
if (ch == ';') {
*consumed = i + 1;
return val;
}
ch = gsHexDecodeMap[ch];
if (ch == 256) {
*consumed = i;
return val;
}
val = (val * 16) + ch;
if (val > 0x1000FF) {
return '&';
}
++i;
}
*consumed = i;
return val;
} else {
i = 2;
ch = (unsigned char) src[i];
if (ch < '0' || ch > '9') {
return '&';
}
val = ch - '0';
i += 1;
while (i < len) {
ch = (unsigned char) src[i];
if (ch == ';') {
*consumed = i + 1;
return val;
}
if (ch < '0' || ch > '9') {
*consumed = i;
return val;
}
val = (val * 10) + (ch - '0');
if (val > 0x1000FF) {
return '&';
}
++i;
}
*consumed = i;
return val;
}
}
/*
* These were mostly extracted from: https://raw.githubusercontent.com/WebKit/WebKit/main/Source/WebCore/dom/EventNames.h
*
* view-source:
* data:
* javascript:
* events:
*/
static stringtype_t BLACKATTREVENT[] = {
{ "ABORT", TYPE_BLACK }
, { "ACTIVATE", TYPE_BLACK }
, { "ACTIVE", TYPE_BLACK }
, { "ADDSOURCEBUFFER", TYPE_BLACK }
, { "ADDSTREAM", TYPE_BLACK }
, { "ADDTRACK", TYPE_BLACK }
, { "AFTERPRINT", TYPE_BLACK }
, { "ANIMATIONCANCEL", TYPE_BLACK }
, { "ANIMATIONEND", TYPE_BLACK }
, { "ANIMATIONITERATION", TYPE_BLACK }
, { "ANIMATIONSTART", TYPE_BLACK }
, { "AUDIOEND", TYPE_BLACK }
, { "AUDIOPROCESS", TYPE_BLACK }
, { "AUDIOSTART", TYPE_BLACK }
, { "AUTOCOMPLETEERROR", TYPE_BLACK }
, { "AUTOCOMPLETE", TYPE_BLACK }
, { "BEFOREACTIVATE", TYPE_BLACK }
, { "BEFORECOPY", TYPE_BLACK }
, { "BEFORECUT", TYPE_BLACK }
, { "BEFOREINPUT", TYPE_BLACK }
, { "BEFORELOAD", TYPE_BLACK }
, { "BEFOREPASTE", TYPE_BLACK }
, { "BEFOREPRINT", TYPE_BLACK }
, { "BEFOREUNLOAD", TYPE_BLACK }
, { "BEGINEVENT", TYPE_BLACK }
, { "BLOCKED", TYPE_BLACK }
, { "BLUR", TYPE_BLACK }
, { "BOUNDARY", TYPE_BLACK }
, { "BUFFEREDAMOUNTLOW", TYPE_BLACK }
, { "CACHED", TYPE_BLACK }
, { "CANCEL", TYPE_BLACK }
, { "CANPLAYTHROUGH", TYPE_BLACK }
, { "CANPLAY", TYPE_BLACK }
, { "CHANGE", TYPE_BLACK }
, { "CHARGINGCHANGE", TYPE_BLACK }
, { "CHARGINGTIMECHANGE", TYPE_BLACK }
, { "CHECKING", TYPE_BLACK }
, { "CLICK", TYPE_BLACK }
, { "CLOSE", TYPE_BLACK }
, { "COMPLETE", TYPE_BLACK }
, { "COMPOSITIONEND", TYPE_BLACK }
, { "COMPOSITIONSTART", TYPE_BLACK }
, { "COMPOSITIONUPDATE", TYPE_BLACK }
, { "CONNECTING", TYPE_BLACK }
, { "CONNECTIONSTATECHANGE", TYPE_BLACK }
, { "CONNECT", TYPE_BLACK }
, { "CONTEXTMENU", TYPE_BLACK }
, { "CONTROLLERCHANGE", TYPE_BLACK }
, { "COPY", TYPE_BLACK }
, { "CUECHANGE", TYPE_BLACK }
, { "CUT", TYPE_BLACK }
, { "DATAAVAILABLE", TYPE_BLACK }
, { "DATACHANNEL", TYPE_BLACK }
, { "DBLCLICK", TYPE_BLACK }
, { "DEVICECHANGE", TYPE_BLACK }
, { "DEVICEMOTION", TYPE_BLACK }
, { "DEVICEORIENTATION", TYPE_BLACK }
, { "DISCHARGINGTIMECHANGE", TYPE_BLACK }
, { "DISCONNECT", TYPE_BLACK }
, { "DOMACTIVATE", TYPE_BLACK }
, { "DOMCHARACTERDATAMODIFIED", TYPE_BLACK }
, { "DOMCONTENTLOADED", TYPE_BLACK }
, { "DOMFOCUSIN", TYPE_BLACK }
, { "DOMFOCUSOUT", TYPE_BLACK }
, { "DOMNODEINSERTEDINTODOCUMENT", TYPE_BLACK }
, { "DOMNODEINSERTED", TYPE_BLACK }
, { "DOMNODEREMOVEDFROMDOCUMENT", TYPE_BLACK }
, { "DOMNODEREMOVED", TYPE_BLACK }
, { "DOMSUBTREEMODIFIED", TYPE_BLACK }
, { "DOWNLOADING", TYPE_BLACK }
, { "DRAGEND", TYPE_BLACK }
, { "DRAGENTER", TYPE_BLACK }
, { "DRAGLEAVE", TYPE_BLACK }
, { "DRAGOVER", TYPE_BLACK }
, { "DRAGSTART", TYPE_BLACK }
, { "DRAG", TYPE_BLACK }
, { "DROP", TYPE_BLACK }
, { "DURATIONCHANGE", TYPE_BLACK }
, { "EMPTIED", TYPE_BLACK }
, { "ENCRYPTED", TYPE_BLACK }
, { "ENDED", TYPE_BLACK }
, { "ENDEVENT", TYPE_BLACK }
, { "END", TYPE_BLACK }
, { "ENTERPICTUREINPICTURE", TYPE_BLACK }
, { "ENTER", TYPE_BLACK }
, { "ERROR", TYPE_BLACK }
, { "EXIT", TYPE_BLACK }
, { "FETCH", TYPE_BLACK }
, { "FINISH", TYPE_BLACK }
, { "FOCUSIN", TYPE_BLACK }
, { "FOCUSOUT", TYPE_BLACK }
, { "FOCUS", TYPE_BLACK }
, { "FORMCHANGE", TYPE_BLACK }
, { "FORMINPUT", TYPE_BLACK }
, { "GAMEPADCONNECTED", TYPE_BLACK }
, { "GAMEPADDISCONNECTED", TYPE_BLACK }
, { "GESTURECHANGE", TYPE_BLACK }
, { "GESTUREEND", TYPE_BLACK }
, { "GESTURESCROLLEND", TYPE_BLACK }
, { "GESTURESCROLLSTART", TYPE_BLACK }
, { "GESTURESCROLLUPDATE", TYPE_BLACK }
, { "GESTURESTART", TYPE_BLACK }
, { "GESTURETAPDOWN", TYPE_BLACK }
, { "GESTURETAP", TYPE_BLACK }
, { "GOTPOINTERCAPTURE", TYPE_BLACK }
, { "HASHCHANGE", TYPE_BLACK }
, { "ICECANDIDATEERROR", TYPE_BLACK }
, { "ICECANDIDATE", TYPE_BLACK }
, { "ICECONNECTIONSTATECHANGE", TYPE_BLACK }
, { "ICEGATHERINGSTATECHANGE", TYPE_BLACK }
, { "INACTIVE", TYPE_BLACK }
, { "INPUTSOURCESCHANGE", TYPE_BLACK }
, { "INPUT", TYPE_BLACK }
, { "INSTALL", TYPE_BLACK }
, { "INVALID", TYPE_BLACK }
, { "KEYDOWN", TYPE_BLACK }
, { "KEYPRESS", TYPE_BLACK }
, { "KEYSTATUSESCHANGE", TYPE_BLACK }
, { "KEYUP", TYPE_BLACK }
, { "LANGUAGECHANGE", TYPE_BLACK }
, { "LEAVEPICTUREINPICTURE", TYPE_BLACK }
, { "LEVELCHANGE", TYPE_BLACK }
, { "LOADEDDATA", TYPE_BLACK }
, { "LOADEDMETADATA", TYPE_BLACK }
, { "LOADEND", TYPE_BLACK }
, { "LOADINGDONE", TYPE_BLACK }
, { "LOADINGERROR", TYPE_BLACK }
, { "LOADING", TYPE_BLACK }
, { "LOADSTART", TYPE_BLACK }
, { "LOAD", TYPE_BLACK }
, { "LOSTPOINTERCAPTURE", TYPE_BLACK }
, { "MARK", TYPE_BLACK }
, { "MERCHANTVALIDATION", TYPE_BLACK }
, { "MESSAGEERROR", TYPE_BLACK }
, { "MESSAGE", TYPE_BLACK }
, { "MOUSEDOWN", TYPE_BLACK }
, { "MOUSEENTER", TYPE_BLACK }
, { "MOUSELEAVE", TYPE_BLACK }
, { "MOUSEMOVE", TYPE_BLACK }
, { "MOUSEOUT", TYPE_BLACK }
, { "MOUSEOVER", TYPE_BLACK }
, { "MOUSEUP", TYPE_BLACK }
, { "MOUSEWHEEL", TYPE_BLACK }
, { "MUTE", TYPE_BLACK }
, { "NEGOTIATIONNEEDED", TYPE_BLACK }
, { "NEXTTRACK", TYPE_BLACK }
, { "NOMATCH", TYPE_BLACK }
, { "NOUPDATE", TYPE_BLACK }
, { "OBSOLETE", TYPE_BLACK }
, { "OFFLINE", TYPE_BLACK }
, { "ONLINE", TYPE_BLACK }
, { "OPEN", TYPE_BLACK }
, { "ORIENTATIONCHANGE", TYPE_BLACK }
, { "OVERCONSTRAINED", TYPE_BLACK }
, { "OVERFLOWCHANGED", TYPE_BLACK }
, { "PAGEHIDE", TYPE_BLACK }
, { "PAGESHOW", TYPE_BLACK }
, { "PASTE", TYPE_BLACK }
, { "PAUSE", TYPE_BLACK }
, { "PAYERDETAILCHANGE", TYPE_BLACK }
, { "PAYMENTAUTHORIZED", TYPE_BLACK }
, { "PAYMENTMETHODCHANGE", TYPE_BLACK }
, { "PAYMENTMETHODSELECTED", TYPE_BLACK }
, { "PLAYING", TYPE_BLACK }
, { "PLAY", TYPE_BLACK }
, { "POINTERCANCEL", TYPE_BLACK }
, { "POINTERDOWN", TYPE_BLACK }
, { "POINTERENTER", TYPE_BLACK }
, { "POINTERLEAVE", TYPE_BLACK }
, { "POINTERLOCKCHANGE", TYPE_BLACK }
, { "POINTERLOCKERROR", TYPE_BLACK }
, { "POINTERMOVE", TYPE_BLACK }
, { "POINTEROUT", TYPE_BLACK }
, { "POINTEROVER", TYPE_BLACK }
, { "POINTERUP", TYPE_BLACK }
, { "POPSTATE", TYPE_BLACK }
, { "PREVIOUSTRACK", TYPE_BLACK }
, { "PROCESSORERROR", TYPE_BLACK }
, { "PROGRESS", TYPE_BLACK }
, { "PROPERTYCHANGE", TYPE_BLACK }
, { "RATECHANGE", TYPE_BLACK }
, { "READYSTATECHANGE", TYPE_BLACK }
, { "REJECTIONHANDLED", TYPE_BLACK }
, { "REMOVESOURCEBUFFER", TYPE_BLACK }
, { "REMOVESTREAM", TYPE_BLACK }
, { "REMOVETRACK", TYPE_BLACK }
, { "REMOVE", TYPE_BLACK }
, { "RESET", TYPE_BLACK }
, { "RESIZE", TYPE_BLACK }
, { "RESOURCETIMINGBUFFERFULL", TYPE_BLACK }
, { "RESULT", TYPE_BLACK }
, { "RESUME", TYPE_BLACK }
, { "SCROLL", TYPE_BLACK }
, { "SEARCH", TYPE_BLACK }
, { "SECURITYPOLICYVIOLATION", TYPE_BLACK }
, { "SEEKED", TYPE_BLACK }
, { "SEEKING", TYPE_BLACK }
, { "SELECTEND", TYPE_BLACK }
, { "SELECTIONCHANGE", TYPE_BLACK }
, { "SELECTSTART", TYPE_BLACK }
, { "SELECT", TYPE_BLACK }
, { "SHIPPINGADDRESSCHANGE", TYPE_BLACK }
, { "SHIPPINGCONTACTSELECTED", TYPE_BLACK }
, { "SHIPPINGMETHODSELECTED", TYPE_BLACK }
, { "SHIPPINGOPTIONCHANGE", TYPE_BLACK }
, { "SHOW", TYPE_BLACK }
, { "SIGNALINGSTATECHANGE", TYPE_BLACK }
, { "SLOTCHANGE", TYPE_BLACK }
, { "SOUNDEND", TYPE_BLACK }
, { "SOUNDSTART", TYPE_BLACK }
, { "SOURCECLOSE", TYPE_BLACK }
, { "SOURCEENDED", TYPE_BLACK }
, { "SOURCEOPEN", TYPE_BLACK }
, { "SPEECHEND", TYPE_BLACK }
, { "SPEECHSTART", TYPE_BLACK }
, { "SQUEEZEEND", TYPE_BLACK }
, { "SQUEEZESTART", TYPE_BLACK }
, { "SQUEEZE", TYPE_BLACK }
, { "STALLED", TYPE_BLACK }
, { "STARTED", TYPE_BLACK }
, { "START", TYPE_BLACK }
, { "STATECHANGE", TYPE_BLACK }
, { "STOP", TYPE_BLACK }
, { "STORAGE", TYPE_BLACK }
, { "SUBMIT", TYPE_BLACK }
, { "SUCCESS", TYPE_BLACK }
, { "SUSPEND", TYPE_BLACK }
, { "TEXTINPUT", TYPE_BLACK }
, { "TIMEOUT", TYPE_BLACK }
, { "TIMEUPDATE", TYPE_BLACK }
, { "TOGGLE", TYPE_BLACK }
, { "TOGGLE", TYPE_BLACK }
, { "TONECHANGE", TYPE_BLACK }
, { "TOUCHCANCEL", TYPE_BLACK }
, { "TOUCHEND", TYPE_BLACK }
, { "TOUCHFORCECHANGE", TYPE_BLACK }
, { "TOUCHMOVE", TYPE_BLACK }
, { "TOUCHSTART", TYPE_BLACK }
, { "TRACK", TYPE_BLACK }
, { "TRANSITIONCANCEL", TYPE_BLACK }
, { "TRANSITIONEND", TYPE_BLACK }
, { "TRANSITIONRUN", TYPE_BLACK }
, { "TRANSITIONSTART", TYPE_BLACK }
, { "UNCAPTUREDERROR", TYPE_BLACK }
, { "UNHANDLEDREJECTION", TYPE_BLACK }
, { "UNLOAD", TYPE_BLACK }
, { "UNMUTE", TYPE_BLACK }
, { "UPDATEEND", TYPE_BLACK }
, { "UPDATEFOUND", TYPE_BLACK }
, { "UPDATEREADY", TYPE_BLACK }
, { "UPDATESTART", TYPE_BLACK }
, { "UPDATE", TYPE_BLACK }
, { "UPGRADENEEDED", TYPE_BLACK }
, { "VALIDATEMERCHANT", TYPE_BLACK }
, { "VERSIONCHANGE", TYPE_BLACK }
, { "VISIBILITYCHANGE", TYPE_BLACK }
, { "VOLUMECHANGE", TYPE_BLACK }
, { "WAITINGFORKEY", TYPE_BLACK }
, { "WAITING", TYPE_BLACK }
, { "WEBGLCONTEXTCHANGED", TYPE_BLACK }
, { "WEBGLCONTEXTCREATIONERROR", TYPE_BLACK }
, { "WEBGLCONTEXTLOST", TYPE_BLACK }
, { "WEBGLCONTEXTRESTORED", TYPE_BLACK }
, { "WEBKITANIMATIONEND", TYPE_BLACK }
, { "WEBKITANIMATIONITERATION", TYPE_BLACK }
, { "WEBKITANIMATIONSTART", TYPE_BLACK }
, { "WEBKITBEFORETEXTINSERTED", TYPE_BLACK }
, { "WEBKITBEGINFULLSCREEN", TYPE_BLACK }
, { "WEBKITCURRENTPLAYBACKTARGETISWIRELESSCHANGED", TYPE_BLACK }
, { "WEBKITENDFULLSCREEN", TYPE_BLACK }
, { "WEBKITFULLSCREENCHANGE", TYPE_BLACK }
, { "WEBKITFULLSCREENERROR", TYPE_BLACK }
, { "WEBKITKEYADDED", TYPE_BLACK }
, { "WEBKITKEYERROR", TYPE_BLACK }
, { "WEBKITKEYMESSAGE", TYPE_BLACK }
, { "WEBKITMOUSEFORCECHANGED", TYPE_BLACK }
, { "WEBKITMOUSEFORCEDOWN", TYPE_BLACK }
, { "WEBKITMOUSEFORCEUP", TYPE_BLACK }
, { "WEBKITMOUSEFORCEWILLBEGIN", TYPE_BLACK }
, { "WEBKITNEEDKEY", TYPE_BLACK }
, { "WEBKITNETWORKINFOCHANGE", TYPE_BLACK }
, { "WEBKITPLAYBACKTARGETAVAILABILITYCHANGED", TYPE_BLACK }
, { "WEBKITPRESENTATIONMODECHANGED", TYPE_BLACK }
, { "WEBKITREGIONOVERSETCHANGE", TYPE_BLACK }
, { "WEBKITREMOVESOURCEBUFFER", TYPE_BLACK }
, { "WEBKITSOURCECLOSE", TYPE_BLACK }
, { "WEBKITSOURCEENDED", TYPE_BLACK }
, { "WEBKITSOURCEOPEN", TYPE_BLACK }
, { "WEBKITSPEECHCHANGE", TYPE_BLACK }
, { "WEBKITTRANSITIONEND", TYPE_BLACK }
, { "WEBKITWILLREVEALBOTTOM", TYPE_BLACK }
, { "WEBKITWILLREVEALLEFT", TYPE_BLACK }
, { "WEBKITWILLREVEALRIGHT", TYPE_BLACK }
, { "WEBKITWILLREVEALTOP", TYPE_BLACK }
, { "WHEEL", TYPE_BLACK }
, { "WRITEEND", TYPE_BLACK }
, { "WRITESTART", TYPE_BLACK }
, { "WRITE", TYPE_BLACK }
, { "ZOOM", TYPE_BLACK }
, { NULL, TYPE_NONE }
};
/*
* view-source:
* data:
* javascript:
*/
static stringtype_t BLACKATTR[] = {
{ "ACTION", TYPE_ATTR_URL } /* form */
, { "ATTRIBUTENAME", TYPE_ATTR_INDIRECT } /* SVG allow indirection of attribute names */
, { "BY", TYPE_ATTR_URL } /* SVG */
, { "BACKGROUND", TYPE_ATTR_URL } /* IE6, O11 */
, { "DATAFORMATAS", TYPE_BLACK } /* IE */
, { "DATASRC", TYPE_BLACK } /* IE */
, { "DYNSRC", TYPE_ATTR_URL } /* Obsolete img attribute */
, { "FILTER", TYPE_STYLE } /* Opera, SVG inline style */
, { "FORMACTION", TYPE_ATTR_URL } /* HTML 5 */
, { "FOLDER", TYPE_ATTR_URL } /* Only on A tags, IE-only */
, { "FROM", TYPE_ATTR_URL } /* SVG */
, { "HANDLER", TYPE_ATTR_URL } /* SVG Tiny, Opera */
, { "HREF", TYPE_ATTR_URL }
, { "LOWSRC", TYPE_ATTR_URL } /* Obsolete img attribute */
, { "POSTER", TYPE_ATTR_URL } /* Opera 10,11 */
, { "SRC", TYPE_ATTR_URL }
, { "STYLE", TYPE_STYLE }
, { "TO", TYPE_ATTR_URL } /* SVG */
, { "VALUES", TYPE_ATTR_URL } /* SVG */
, { "XLINK:HREF", TYPE_ATTR_URL }
, { NULL, TYPE_NONE }
};
/* xmlns */
/* `xml-stylesheet` > <eval>, <if expr=> */
/*
static const char* BLACKATTR[] = {
"ATTRIBUTENAME",
"BACKGROUND",
"DATAFORMATAS",
"HREF",
"SCROLL",
"SRC",
"STYLE",
"SRCDOC",
NULL
};
*/
static const char* BLACKTAG[] = {
"APPLET"
/* , "AUDIO" */
, "BASE"
, "COMMENT" /* IE http://html5sec.org/#38 */
, "EMBED"
/* , "FORM" */
, "FRAME"
, "FRAMESET"
, "HANDLER" /* Opera SVG, effectively a script tag */
, "IFRAME"
, "IMPORT"
, "ISINDEX"
, "LINK"
, "LISTENER"
/* , "MARQUEE" */
, "META"
, "NOSCRIPT"
, "OBJECT"
, "SCRIPT"
, "STYLE"
/* , "VIDEO" */
, "VMLFRAME"
, "XML"
, "XSS"
, NULL
};
static int cstrcasecmp_with_null(const char *a, const char *b, size_t n)
{
char ca;
char cb;
/* printf("Comparing to %s %.*s\n", a, (int)n, b); */
while (n-- > 0) {
cb = *b++;
if (cb == '\0') continue;
ca = *a++;
if (cb >= 'a' && cb <= 'z') {
cb -= 0x20;
}
/* printf("Comparing %c vs %c with %d left\n", ca, cb, (int)n); */
if (ca != cb) {
return 1;
}
}
if (*a == 0) {
/* printf(" MATCH \n"); */
return 0;
} else {
return 1;
}
}
/*
* Does an HTML encoded binary string (const char*, length) start with
* a all uppercase c-string (null terminated), case insensitive!
*
* also ignore any embedded nulls in the HTML string!
*
* return 1 if match / starts with
* return 0 if not
*/
static int htmlencode_startswith(const char *a, const char *b, size_t n)
{
size_t consumed;
int cb;
int first = 1;
/* printf("Comparing %s with %.*s\n", a,(int)n,b); */
while (n > 0) {
if (*a == 0) {
/* printf("Match EOL!\n"); */
return 1;
}
cb = html_decode_char_at(b, n, &consumed);
b += consumed;
n -= consumed;
if (first && cb <= 32) {
/* ignore all leading whitespace and control characters */
continue;
}
first = 0;
if (cb == 0) {
/* always ignore null characters in user input */
continue;
}
if (cb == 10) {
/* always ignore vertical tab characters in user input */
/* who allows this?? */
continue;
}
if (cb >= 'a' && cb <= 'z') {
/* upcase */
cb -= 0x20;
}
if (*a != (char) cb) {
/* printf(" %c != %c\n", *a, cb); */
/* mismatch */
return 0;
}
a++;
}
return (*a == 0) ? 1 : 0;
}
static int is_black_tag(const char* s, size_t len)
{
const char** black;
if (len < 3) {
return 0;
}
black = BLACKTAG;
while (*black != NULL) {
if (cstrcasecmp_with_null(*black, s, len) == 0) {
/* printf("Got black tag %s\n", *black); */
return 1;
}
black += 1;
}
/* anything SVG related */
if ((s[0] == 's' || s[0] == 'S') &&
(s[1] == 'v' || s[1] == 'V') &&
(s[2] == 'g' || s[2] == 'G')) {
/* printf("Got SVG tag \n"); */
return 1;
}
/* Anything XSL(t) related */
if ((s[0] == 'x' || s[0] == 'X') &&
(s[1] == 's' || s[1] == 'S') &&
(s[2] == 'l' || s[2] == 'L')) {
/* printf("Got XSL tag\n"); */
return 1;
}
return 0;
}
static attribute_t is_black_attr(const char* s, size_t len)
{
stringtype_t* black;
if (len < 2) {
return TYPE_NONE;
}
if (len >= 5) {
/* JavaScript on.* event handlers */
if ((s[0] == 'o' || s[0] == 'O') && (s[1] == 'n' || s[1] == 'N')) {
black = BLACKATTREVENT;
const char *s_without_on = &s[2]; // start comparing from the third char
while (black->name != NULL) {
if (cstrcasecmp_with_null(black->name, s_without_on, strlen(black->name)) == 0) {
/* printf("Got banned attribute name %s\n", black->name); */
return black->atype;
}
black += 1;
}
}
/* XMLNS can be used to create arbitrary tags */
// goedge: commented for photo uploading
//if (cstrcasecmp_with_null("XMLNS", s, 5) == 0 || cstrcasecmp_with_null("XLINK", s, 5) == 0) {
/* printf("Got XMLNS and XLINK tags\n"); */
// return TYPE_BLACK;
//}
}
black = BLACKATTR;
while (black->name != NULL) {
if (cstrcasecmp_with_null(black->name, s, len) == 0) {
/* printf("Got banned attribute name %s\n", black->name); */
return black->atype;
}
black += 1;
}
return TYPE_NONE;
}
static int is_black_url(const char* s, size_t len)
{
static const char* data_url = "DATA";
static const char* viewsource_url = "VIEW-SOURCE";
/* obsolete but interesting signal */
static const char* vbscript_url = "VBSCRIPT";
/* covers JAVA, JAVASCRIPT, + colon */
static const char* javascript_url = "JAVA";
/* skip whitespace */
while (len > 0 && (*s <= 32 || *s >= 127)) {
/*
* HEY: this is a signed character.
* We are intentionally skipping high-bit characters too
* since they are not ASCII, and Opera sometimes uses UTF-8 whitespace.
*
* Also in EUC-JP some of the high bytes are just ignored.
*/
++s;
--len;
}
if (htmlencode_startswith(data_url, s, len)) {
return 1;
}
if (htmlencode_startswith(viewsource_url, s, len)) {
return 1;
}
if (htmlencode_startswith(javascript_url, s, len)) {
return 1;
}
if (htmlencode_startswith(vbscript_url, s, len)) {
return 1;
}
return 0;
}
int libinjection_is_xss(const char* s, size_t len, int flags)
{
h5_state_t h5;
attribute_t attr = TYPE_NONE;
libinjection_h5_init(&h5, s, len, (enum html5_flags) flags);
while (libinjection_h5_next(&h5)) {
if (h5.token_type != ATTR_VALUE) {
attr = TYPE_NONE;
}
if (h5.token_type == DOCTYPE) {
return 1;
} else if (h5.token_type == TAG_NAME_OPEN) {
if (is_black_tag(h5.token_start, h5.token_len)) {
return 1;
}
} else if (h5.token_type == ATTR_NAME) {
attr = is_black_attr(h5.token_start, h5.token_len);
} else if (h5.token_type == ATTR_VALUE) {
/*
* IE6,7,8 parsing works a bit differently so
* a whole <script> or other black tag might be hiding
* inside an attribute value under HTML 5 parsing
* See http://html5sec.org/#102
* to avoid doing a full reparse of the value, just
* look for "<". This probably need adjusting to
* handle escaped characters
*/
/*
if (memchr(h5.token_start, '<', h5.token_len) != NULL) {
return 1;
}
*/
switch (attr) {
case TYPE_NONE:
break;
case TYPE_BLACK:
return 1;
case TYPE_ATTR_URL:
if (is_black_url(h5.token_start, h5.token_len)) {
return 1;
}
break;
case TYPE_STYLE:
return 1;
case TYPE_ATTR_INDIRECT:
/* an attribute name is specified in a _value_ */
if (is_black_attr(h5.token_start, h5.token_len)) {
return 1;
}
break;
/*
default:
assert(0);
*/
}
attr = TYPE_NONE;
} else if (h5.token_type == TAG_COMMENT) {
/* IE uses a "`" as a tag ending char */
// goedge: commented for photo uploading
/**if (memchr(h5.token_start, '`', h5.token_len) != NULL) {
return 1;
}**/
/* IE conditional comment */
if (h5.token_len > 3) {
if (h5.token_start[0] == '[' &&
(h5.token_start[1] == 'i' || h5.token_start[1] == 'I') &&
(h5.token_start[2] == 'f' || h5.token_start[2] == 'F')) {
return 1;
}
if ((h5.token_start[0] == 'x' || h5.token_start[0] == 'X') &&
(h5.token_start[1] == 'm' || h5.token_start[1] == 'M') &&
(h5.token_start[2] == 'l' || h5.token_start[2] == 'L')) {
return 1;
}
}
if (h5.token_len > 5) {
/* IE <?import pseudo-tag */
if (cstrcasecmp_with_null("IMPORT", h5.token_start, 6) == 0) {
return 1;
}
/* XML Entity definition */
if (cstrcasecmp_with_null("ENTITY", h5.token_start, 6) == 0) {
return 1;
}
}
}
}
return 0;
}
/*
* wrapper
*
*
* const char* s: input string, may contain nulls, does not need to be null-terminated.
* size_t len: input string length.
*
*
*/
int libinjection_xss(const char* s, size_t slen)
{
if (libinjection_is_xss(s, slen, DATA_STATE)) {
return 1;
}
if (libinjection_is_xss(s, slen, VALUE_NO_QUOTE)) {
return 1;
}
if (libinjection_is_xss(s, slen, VALUE_SINGLE_QUOTE)) {
return 1;
}
if (libinjection_is_xss(s, slen, VALUE_DOUBLE_QUOTE)) {
return 1;
}
if (libinjection_is_xss(s, slen, VALUE_BACK_QUOTE)) {
return 1;
}
return 0;
}

View File

@@ -0,0 +1,21 @@
#ifndef LIBINJECTION_XSS
#define LIBINJECTION_XSS
#ifdef __cplusplus
extern "C" {
#endif
/**
* HEY THIS ISN'T DONE
*/
/* pull in size_t */
#include <string.h>
int libinjection_is_xss(const char* s, size_t len, int flags);
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,313 @@
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <assert.h>
#include "libinjection.h"
#include "libinjection_sqli.h"
#include "libinjection_xss.h"
#ifndef TRUE
#define TRUE 1
#endif
#ifndef FALSE
#define FALSE 0
#endif
static int g_test_ok = 0;
static int g_test_fail = 0;
typedef enum {
MODE_SQLI,
MODE_XSS
} detect_mode_t;
static void usage(const char* program_name);
size_t modp_rtrim(char* str, size_t len);
void modp_toprint(char* str, size_t len);
void test_positive(FILE * fd, const char *fname, detect_mode_t mode,
int flag_invert, int flag_true, int flag_quiet);
int urlcharmap(char ch);
size_t modp_url_decode(char* dest, const char* s, size_t len);
int urlcharmap(char ch) {
switch (ch) {
case '0': return 0;
case '1': return 1;
case '2': return 2;
case '3': return 3;
case '4': return 4;
case '5': return 5;
case '6': return 6;
case '7': return 7;
case '8': return 8;
case '9': return 9;
case 'a': case 'A': return 10;
case 'b': case 'B': return 11;
case 'c': case 'C': return 12;
case 'd': case 'D': return 13;
case 'e': case 'E': return 14;
case 'f': case 'F': return 15;
default:
return 256;
}
}
size_t modp_url_decode(char* dest, const char* s, size_t len)
{
const char* deststart = dest;
size_t i = 0;
int d = 0;
while (i < len) {
switch (s[i]) {
case '+':
*dest++ = ' ';
i += 1;
break;
case '%':
if (i+2 < len) {
d = (urlcharmap(s[i+1]) << 4) | urlcharmap(s[i+2]);
if ( d < 256) {
*dest = (char) d;
dest++;
i += 3; /* loop will increment one time */
} else {
*dest++ = '%';
i += 1;
}
} else {
*dest++ = '%';
i += 1;
}
break;
default:
*dest++ = s[i];
i += 1;
}
}
*dest = '\0';
return (size_t)(dest - deststart); /* compute "strlen" of dest */
}
void modp_toprint(char* str, size_t len)
{
size_t i;
for (i = 0; i < len; ++i) {
if (str[i] < 32 || str[i] > 126) {
str[i] = '?';
}
}
}
size_t modp_rtrim(char* str, size_t len)
{
while (len) {
char c = str[len -1];
if (c == ' ' || c == '\n' || c == '\t' || c == '\r') {
str[len -1] = '\0';
len -= 1;
} else {
break;
}
}
return len;
}
void test_positive(FILE * fd, const char *fname, detect_mode_t mode,
int flag_invert, int flag_true, int flag_quiet)
{
char linebuf[8192];
int issqli = 0;
int linenum = 0;
size_t len;
sfilter sf;
while (fgets(linebuf, sizeof(linebuf), fd)) {
linenum += 1;
len = modp_rtrim(linebuf, strlen(linebuf));
if (len == 0) {
continue;
}
if (linebuf[0] == '#') {
continue;
}
len = modp_url_decode(linebuf, linebuf, len);
switch (mode) {
case MODE_SQLI: {
libinjection_sqli_init(&sf, linebuf, len, 0);
issqli = libinjection_is_sqli(&sf);
break;
}
case MODE_XSS: {
issqli = libinjection_xss(linebuf, len);
break;
}
default:
assert(0);
}
if (issqli) {
g_test_ok += 1;
} else {
g_test_fail += 1;
}
if (!flag_quiet) {
if ((issqli && flag_true && ! flag_invert) ||
(!issqli && flag_true && flag_invert) ||
!flag_true) {
modp_toprint(linebuf, len);
switch (mode) {
case MODE_SQLI: {
/*
* if we didn't find a SQLi and fingerprint from
* sqlstats is is 'sns' or 'snsns' then redo using
* plain context
*/
if (!issqli && (strcmp(sf.fingerprint, "sns") == 0 ||
strcmp(sf.fingerprint, "snsns") == 0)) {
libinjection_sqli_fingerprint(&sf, 0);
}
fprintf(stdout, "%s\t%d\t%s\t%s\t%s\n",
fname, linenum,
(issqli ? "True" : "False"), sf.fingerprint, linebuf);
break;
}
case MODE_XSS: {
fprintf(stdout, "%s\t%d\t%s\t%s\n",
fname, linenum,
(issqli ? "True" : "False"), linebuf);
break;
}
default:
assert(0);
}
}
}
}
}
static void usage(const char* program_name)
{
fprintf(stdout, "usage: %s [flags] [files...]\n", program_name);
fprintf(stdout, "%s\n", "");
fprintf(stdout, "%s\n", "-q --quiet : quiet mode");
fprintf(stdout, "%s\n", "-m --max-fails : number of failed cases need to fail entire test");
fprintf(stdout, "%s\n", "-s INTEGER : repeat each test N time "
"(for performance testing)");
fprintf(stdout, "%s\n", "-t : only print positive matches");
fprintf(stdout, "%s\n", "-x --mode-xss : test input for XSS");
fprintf(stdout, "%s\n", "-i --invert : invert test logic "
"(input is tested for being safe)");
fprintf(stdout, "%s\n", "");
fprintf(stdout, "%s\n", "-? -h -help --help : this page");
fprintf(stdout, "%s\n", "");
}
int main(int argc, const char *argv[])
{
/*
* invert output, by
*/
int flag_invert = FALSE;
/*
* don't print anything.. useful for
* performance monitors, gprof.
*/
int flag_quiet = FALSE;
/*
* only print positive results
* with invert, only print negative results
*/
int flag_true = FALSE;
detect_mode_t mode = MODE_SQLI;
int flag_slow = 1;
int count = 0;
int max = -1;
int i, j;
int offset = 1;
while (offset < argc) {
if (strcmp(argv[offset], "-?") == 0 ||
strcmp(argv[offset], "-h") == 0 ||
strcmp(argv[offset], "-help") == 0 ||
strcmp(argv[offset], "--help") == 0) {
usage(argv[0]);
exit(0);
}
if (strcmp(argv[offset], "-i") == 0) {
offset += 1;
flag_invert = TRUE;
} else if (strcmp(argv[offset], "-q") == 0 ||
strcmp(argv[offset], "--quiet") == 0) {
offset += 1;
flag_quiet = TRUE;
} else if (strcmp(argv[offset], "-t") == 0) {
offset += 1;
flag_true = TRUE;
} else if (strcmp(argv[offset], "-s") == 0) {
offset += 1;
flag_slow = 100;
} else if (strcmp(argv[offset], "-m") == 0 ||
strcmp(argv[offset], "--max-fails") == 0) {
offset += 1;
max = atoi(argv[offset]);
offset += 1;
} else if (strcmp(argv[offset], "-x") == 0 ||
strcmp(argv[offset], "--mode-xss") == 0) {
mode = MODE_XSS;
offset += 1;
} else {
break;
}
}
if (offset == argc) {
test_positive(stdin, "stdin", mode, flag_invert, flag_true, flag_quiet);
} else {
for (j = 0; j < flag_slow; ++j) {
for (i = offset; i < argc; ++i) {
FILE* fd = fopen(argv[i], "r");
if (fd) {
test_positive(fd, argv[i], mode, flag_invert, flag_true, flag_quiet);
fclose(fd);
}
}
}
}
if (!flag_quiet) {
fprintf(stdout, "%s", "\n");
fprintf(stdout, "SQLI : %d\n", g_test_ok);
fprintf(stdout, "SAFE : %d\n", g_test_fail);
fprintf(stdout, "TOTAL : %d\n", g_test_ok + g_test_fail);
}
if (max == -1) {
return 0;
}
count = g_test_ok;
if (flag_invert) {
count = g_test_fail;
}
if (count > max) {
printf("\nThreshold is %d, got %d, failing.\n", max, count);
return 1;
} else {
printf("\nThreshold is %d, got %d, passing.\n", max, count);
return 0;
}
}

View File

@@ -0,0 +1,165 @@
/**
* Copyright 2012, 2013 Nick Galbreath
* nickg@client9.com
* BSD License -- see COPYING.txt for details
*
* This is for testing against files in ../data/ *.txt
* Reads from stdin or a list of files, and emits if a line
* is a SQLi attack or not, and does basic statistics
*
*/
#include <string.h>
#include <stdlib.h>
#include <stdio.h>
#include "libinjection.h"
#include "libinjection_sqli.h"
void print_string(stoken_t* t);
void print_var(stoken_t* t);
void print_token(stoken_t *t);
void usage(void);
void print_string(stoken_t* t)
{
/* print opening quote */
if (t->str_open != '\0') {
printf("%c", t->str_open);
}
/* print content */
printf("%s", t->val);
/* print closing quote */
if (t->str_close != '\0') {
printf("%c", t->str_close);
}
}
void print_var(stoken_t* t)
{
if (t->count >= 1) {
printf("%c", '@');
}
if (t->count == 2) {
printf("%c", '@');
}
print_string(t);
}
void print_token(stoken_t *t) {
printf("%c ", t->type);
switch (t->type) {
case 's':
print_string(t);
break;
case 'v':
print_var(t);
break;
default:
printf("%s", t->val);
}
printf("%s", "\n");
}
void usage(void) {
printf("\n");
printf("libinjection sqli tester\n");
printf("\n");
printf(" -ca parse as ANSI SQL\n");
printf(" -cm parse as MySQL SQL\n");
printf(" -q0 parse as is\n");
printf(" -q1 parse in single-quote mode\n");
printf(" -q2 parse in doiuble-quote mode\n");
printf("\n");
printf(" -f --fold fold results\n");
printf("\n");
printf(" -d --detect detect SQLI. empty reply = not detected\n");
printf("\n");
}
int main(int argc, const char* argv[])
{
size_t slen;
char* copy;
int flags = 0;
int fold = 0;
int detect = 0;
int i;
int count;
int offset = 1;
int issqli;
sfilter sf;
if (argc < 2) {
usage();
return 1;
}
while (1) {
if (strcmp(argv[offset], "-h") == 0 || strcmp(argv[offset], "-?") == 0 || strcmp(argv[offset], "--help") == 0) {
usage();
return 1;
}
if (strcmp(argv[offset], "-m") == 0) {
flags |= FLAG_SQL_MYSQL;
offset += 1;
}
else if (strcmp(argv[offset], "-f") == 0 || strcmp(argv[offset], "--fold") == 0) {
fold = 1;
offset += 1;
} else if (strcmp(argv[offset], "-d") == 0 || strcmp(argv[offset], "--detect") == 0) {
detect = 1;
offset += 1;
} else if (strcmp(argv[offset], "-ca") == 0) {
flags |= FLAG_SQL_ANSI;
offset += 1;
} else if (strcmp(argv[offset], "-cm") == 0) {
flags |= FLAG_SQL_MYSQL;
offset += 1;
} else if (strcmp(argv[offset], "-q0") == 0) {
flags |= FLAG_QUOTE_NONE;
offset += 1;
} else if (strcmp(argv[offset], "-q1") == 0) {
flags |= FLAG_QUOTE_SINGLE;
offset += 1;
} else if (strcmp(argv[offset], "-q2") == 0) {
flags |= FLAG_QUOTE_DOUBLE;
offset += 1;
} else {
break;
}
}
/* ATTENTION: argv is a C-string, null terminated. We copy this
* to it's own location, WITHOUT null byte. This way, valgrind
* can see if we run past the buffer.
*/
slen = strlen(argv[offset]);
copy = (char* ) malloc(slen);
memcpy(copy, argv[offset], slen);
libinjection_sqli_init(&sf, copy, slen, flags);
if (detect == 1) {
issqli = libinjection_is_sqli(&sf);
if (issqli) {
printf("%s\n", sf.fingerprint);
}
} else if (fold == 1) {
count = libinjection_sqli_fold(&sf);
for (i = 0; i < count; ++i) {
print_token(&(sf.tokenvec[i]));
}
} else {
while (libinjection_sqli_tokenize(&sf)) {
print_token(sf.current);
}
}
free(copy);
return 0;
}

View File

@@ -0,0 +1,132 @@
#!/usr/bin/env python3
#
# Copyright 2012, 2013 Nick Galbreath
# nickg@client9.com
# BSD License -- see COPYING.txt for details
#
"""
Converts a libinjection JSON data file to a C header (.h) file
"""
import sys
def toc(obj):
""" main routine """
print("""
#ifndef LIBINJECTION_SQLI_DATA_H
#define LIBINJECTION_SQLI_DATA_H
#include "libinjection.h"
#include "libinjection_sqli.h"
typedef struct {
const char *word;
char type;
} keyword_t;
static size_t parse_money(sfilter * sf);
static size_t parse_other(sfilter * sf);
static size_t parse_white(sfilter * sf);
static size_t parse_operator1(sfilter *sf);
static size_t parse_char(sfilter *sf);
static size_t parse_hash(sfilter *sf);
static size_t parse_dash(sfilter *sf);
static size_t parse_slash(sfilter *sf);
static size_t parse_backslash(sfilter * sf);
static size_t parse_operator2(sfilter *sf);
static size_t parse_string(sfilter *sf);
static size_t parse_word(sfilter * sf);
static size_t parse_var(sfilter * sf);
static size_t parse_number(sfilter * sf);
static size_t parse_tick(sfilter * sf);
static size_t parse_ustring(sfilter * sf);
static size_t parse_qstring(sfilter * sf);
static size_t parse_nqstring(sfilter * sf);
static size_t parse_xstring(sfilter * sf);
static size_t parse_bstring(sfilter * sf);
static size_t parse_estring(sfilter * sf);
static size_t parse_bword(sfilter * sf);
""")
#
# Mapping of character to function
#
fnmap = {
'CHAR_WORD' : 'parse_word',
'CHAR_WHITE': 'parse_white',
'CHAR_OP1' : 'parse_operator1',
'CHAR_UNARY': 'parse_operator1',
'CHAR_OP2' : 'parse_operator2',
'CHAR_BANG' : 'parse_operator2',
'CHAR_BACK' : 'parse_backslash',
'CHAR_DASH' : 'parse_dash',
'CHAR_STR' : 'parse_string',
'CHAR_HASH' : 'parse_hash',
'CHAR_NUM' : 'parse_number',
'CHAR_SLASH': 'parse_slash',
'CHAR_SEMICOLON' : 'parse_char',
'CHAR_COMMA': 'parse_char',
'CHAR_LEFTPARENS': 'parse_char',
'CHAR_RIGHTPARENS': 'parse_char',
'CHAR_LEFTBRACE': 'parse_char',
'CHAR_RIGHTBRACE': 'parse_char',
'CHAR_VAR' : 'parse_var',
'CHAR_OTHER': 'parse_other',
'CHAR_MONEY': 'parse_money',
'CHAR_TICK' : 'parse_tick',
'CHAR_UNDERSCORE': 'parse_underscore',
'CHAR_USTRING' : 'parse_ustring',
'CHAR_QSTRING' : 'parse_qstring',
'CHAR_NQSTRING' : 'parse_nqstring',
'CHAR_XSTRING' : 'parse_xstring',
'CHAR_BSTRING' : 'parse_bstring',
'CHAR_ESTRING' : 'parse_estring',
'CHAR_BWORD' : 'parse_bword'
}
print()
print("typedef size_t (*pt2Function)(sfilter *sf);")
print("static const pt2Function char_parse_map[] = {")
pos = 0
for character in obj['charmap']:
print(" &%s, /* %d */" % (fnmap[character], pos))
pos += 1
print("};")
print()
# keywords
# load them
keywords = obj['keywords']
for fingerprint in list(obj['fingerprints']):
fingerprint = '0' + fingerprint.upper()
keywords[fingerprint] = 'F'
needhelp = []
for key in keywords.keys():
if key != key.upper():
needhelp.append(key)
for key in needhelp:
tmpv = keywords[key]
del keywords[key]
keywords[key.upper()] = tmpv
print("static const keyword_t sql_keywords[] = {")
for k in sorted(keywords.keys()):
if len(k) > 31:
sys.stderr.write("ERROR: keyword greater than 32 chars\n")
sys.exit(1)
print(" {\"%s\", '%s'}," % (k, keywords[k]))
print("};")
print("static const size_t sql_keywords_sz = %d;" % (len(keywords), ))
print("#endif")
return 0
if __name__ == '__main__':
import json
sys.exit(toc(json.load(sys.stdin)))

View File

@@ -0,0 +1,3 @@
#define LIBINJECTION_VERSION "3.9.1"
#include "libinjection/src/libinjection_sqli.c"

View File

@@ -0,0 +1,6 @@
#define LIBINJECTION_VERSION "3.9.1"
#include "libinjection/src/libinjection_xss.c"
#include "libinjection/src/libinjection_html5.c"
#define GOEDGE_VERSION "23" // last version is for GoEdge change

View File

@@ -0,0 +1,93 @@
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package injectionutils
/*
#cgo CFLAGS: -I./libinjection/src
#include <libinjection.h>
#include <stdlib.h>
*/
import "C"
import (
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
"github.com/TeaOSLab/EdgeNode/internal/waf/utils"
"github.com/cespare/xxhash/v2"
"net/url"
"strconv"
"strings"
"unsafe"
)
// DetectSQLInjectionCache detect sql injection in string with cache
func DetectSQLInjectionCache(input string, cacheLife utils.CacheLife) bool {
var l = len(input)
if l == 0 {
return false
}
if cacheLife <= 0 || l < 128 || l > utils.MaxCacheDataSize {
return DetectSQLInjection(input)
}
var hash = xxhash.Sum64String(input)
var key = "WAF@SQLI@" + strconv.FormatUint(hash, 10)
var item = utils.SharedCache.Read(key)
if item != nil {
return item.Value == 1
}
var result = DetectSQLInjection(input)
if result {
utils.SharedCache.Write(key, 1, fasttime.Now().Unix()+cacheLife)
} else {
utils.SharedCache.Write(key, 0, fasttime.Now().Unix()+cacheLife)
}
return result
}
// DetectSQLInjection detect sql injection in string
func DetectSQLInjection(input string) bool {
if len(input) == 0 {
return false
}
if detectSQLInjectionOne(input) {
return true
}
// 兼容 /PATH?URI
if (input[0] == '/' || strings.HasPrefix(input, "http://") || strings.HasPrefix(input, "https://")) && len(input) < 1024 {
var argsIndex = strings.Index(input, "?")
if argsIndex > 0 {
var args = input[argsIndex+1:]
unescapeArgs, err := url.QueryUnescape(args)
if err == nil && args != unescapeArgs {
return detectSQLInjectionOne(args) || detectSQLInjectionOne(unescapeArgs)
} else {
return detectSQLInjectionOne(args)
}
}
} else {
unescapedInput, err := url.QueryUnescape(input)
if err == nil && input != unescapedInput {
return detectSQLInjectionOne(unescapedInput)
}
}
return false
}
func detectSQLInjectionOne(input string) bool {
if len(input) == 0 {
return false
}
var fingerprint [8]C.char
var fingerprintPtr = (*C.char)(unsafe.Pointer(&fingerprint[0]))
var cInput = C.CString(input)
defer C.free(unsafe.Pointer(cInput))
return C.libinjection_sqli(cInput, C.size_t(len(input)), fingerprintPtr) == 1
}

View File

@@ -0,0 +1,128 @@
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package injectionutils_test
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/injectionutils"
"github.com/TeaOSLab/EdgeNode/internal/waf/utils"
"github.com/iwind/TeaGo/assert"
"github.com/iwind/TeaGo/rands"
"github.com/iwind/TeaGo/types"
"runtime"
"strings"
"testing"
)
func TestDetectSQLInjection(t *testing.T) {
var a = assert.NewAssertion(t)
a.IsTrue(injectionutils.DetectSQLInjection("' UNION SELECT * FROM myTable"))
a.IsTrue(injectionutils.DetectSQLInjection("id=1 ' UNION select * from a"))
a.IsTrue(injectionutils.DetectSQLInjection("asdf asd ; -1' and 1=1 union/* foo */select load_file('/etc/passwd')--"))
a.IsFalse(injectionutils.DetectSQLInjection("' UNION SELECT1 * FROM myTable"))
a.IsFalse(injectionutils.DetectSQLInjection("1234"))
a.IsFalse(injectionutils.DetectSQLInjection(""))
a.IsTrue(injectionutils.DetectSQLInjection("id=123 OR 1=1&b=2"))
a.IsTrue(injectionutils.DetectSQLInjection("id=123&b=456&c=1' or 2=2"))
a.IsFalse(injectionutils.DetectSQLInjection("?"))
a.IsFalse(injectionutils.DetectSQLInjection("/hello?age=22"))
a.IsTrue(injectionutils.DetectSQLInjection("/sql/injection?id=123 or 1=1"))
a.IsTrue(injectionutils.DetectSQLInjection("/sql/injection?id=123%20or%201=1"))
a.IsTrue(injectionutils.DetectSQLInjection("https://example.com/sql/injection?id=123%20or%201=1"))
a.IsTrue(injectionutils.DetectSQLInjection("id=123%20or%201=1"))
a.IsTrue(injectionutils.DetectSQLInjection("https://example.com/' or 1=1"))
}
func BenchmarkDetectSQLInjection(b *testing.B) {
runtime.GOMAXPROCS(4)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = injectionutils.DetectSQLInjection("asdf asd ; -1' and 1=1 union/* foo */select load_file('/etc/passwd')--")
}
})
}
func BenchmarkDetectSQLInjection_URL(b *testing.B) {
runtime.GOMAXPROCS(4)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = injectionutils.DetectSQLInjection("/sql/injection?id=123 or 1=1")
}
})
}
func BenchmarkDetectSQLInjection_Normal_Small(b *testing.B) {
runtime.GOMAXPROCS(4)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = injectionutils.DetectSQLInjection("a/sql/injection?id=1234")
}
})
}
func BenchmarkDetectSQLInjection_URL_Normal_Small(b *testing.B) {
runtime.GOMAXPROCS(4)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = injectionutils.DetectSQLInjection("/sql/injection?id=" + types.String(rands.Int64()%10000))
}
})
}
func BenchmarkDetectSQLInjection_URL_Normal_Middle(b *testing.B) {
runtime.GOMAXPROCS(4)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = injectionutils.DetectSQLInjection("/search?q=libinjection+fingerprint&newwindow=1&sca_esv=589290862&sxsrf=AMwHvKnxuLoejn2XlNniffC12E_xc35M7Q%3A1702090118361&ei=htvzzebfFZfo1e8PvLGggAk&ved=0ahUKEwjTsYmnq4GDAxUWdPOHHbwkCJAQ4ddDCBA&uact=5&oq=libinjection+fingerprint&gs_lp=Egxnd3Mtd2l6LXNlcnAiGIxpYmluamVjdGlvbmBmaW5nKXJwcmludTIEEAAYHjIGVAAYCBgeSiEaUPkRWKFZcAJ4AZABAJgBHgGgAfoEqgwDMC40uAEGyAEA-AEBwgIKEAFYTxjWMuiwA-IDBBgAVteIBgGQBgI&sclient=gws-wiz-serp#ip=1")
}
})
}
func BenchmarkDetectSQLInjection_URL_Normal_Small_Cache(b *testing.B) {
runtime.GOMAXPROCS(4)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = injectionutils.DetectSQLInjectionCache("/sql/injection?id="+types.String(rands.Int64()%10000), utils.CacheMiddleLife)
}
})
}
func BenchmarkDetectSQLInjection_Normal_Large(b *testing.B) {
runtime.GOMAXPROCS(4)
var s = strings.Repeat("A", 512)
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = injectionutils.DetectSQLInjection("a/sql/injection?id=" + types.String(rands.Int64()%10000) + "&s=" + s + "&v=%20")
}
})
}
func BenchmarkDetectSQLInjection_Normal_Large_Cache(b *testing.B) {
runtime.GOMAXPROCS(4)
var s = strings.Repeat("A", 512)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = injectionutils.DetectSQLInjectionCache("a/sql/injection?id="+types.String(rands.Int64()%10000)+"&s="+s, utils.CacheMiddleLife)
}
})
}
func BenchmarkDetectSQLInjection_URL_Unescape(b *testing.B) {
runtime.GOMAXPROCS(4)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = injectionutils.DetectSQLInjection("/sql/injection?id=123%20or%201=1")
}
})
}

View File

@@ -0,0 +1,90 @@
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package injectionutils
/*
#cgo CFLAGS: -I./libinjection/src
#include <libinjection.h>
#include <stdlib.h>
*/
import "C"
import (
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
"github.com/TeaOSLab/EdgeNode/internal/waf/utils"
"github.com/cespare/xxhash/v2"
"net/url"
"strconv"
"strings"
"unsafe"
)
func DetectXSSCache(input string, cacheLife utils.CacheLife) bool {
var l = len(input)
if l == 0 {
return false
}
if cacheLife <= 0 || l < 512 || l > utils.MaxCacheDataSize {
return DetectXSS(input)
}
var hash = xxhash.Sum64String(input)
var key = "WAF@XSS@" + strconv.FormatUint(hash, 10)
var item = utils.SharedCache.Read(key)
if item != nil {
return item.Value == 1
}
var result = DetectXSS(input)
if result {
utils.SharedCache.Write(key, 1, fasttime.Now().Unix()+cacheLife)
} else {
utils.SharedCache.Write(key, 0, fasttime.Now().Unix()+cacheLife)
}
return result
}
// DetectXSS detect XSS in string
func DetectXSS(input string) bool {
if len(input) == 0 {
return false
}
if detectXSSOne(input) {
return true
}
// 兼容 /PATH?URI
if (input[0] == '/' || strings.HasPrefix(input, "http://") || strings.HasPrefix(input, "https://")) && len(input) < 1024 {
var argsIndex = strings.Index(input, "?")
if argsIndex > 0 {
var args = input[argsIndex+1:]
unescapeArgs, err := url.QueryUnescape(args)
if err == nil && args != unescapeArgs {
return detectXSSOne(args) || detectXSSOne(unescapeArgs)
} else {
return detectXSSOne(args)
}
}
} else {
unescapedInput, err := url.QueryUnescape(input)
if err == nil && input != unescapedInput {
return detectXSSOne(unescapedInput)
}
}
return false
}
func detectXSSOne(input string) bool {
if len(input) == 0 {
return false
}
var cInput = C.CString(input)
defer C.free(unsafe.Pointer(cInput))
return C.libinjection_xss(cInput, C.size_t(len(input))) == 1
}

View File

@@ -0,0 +1,80 @@
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package injectionutils_test
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/injectionutils"
"github.com/TeaOSLab/EdgeNode/internal/waf/utils"
"github.com/iwind/TeaGo/assert"
"runtime"
"testing"
)
func TestDetectXSS(t *testing.T) {
var a = assert.NewAssertion(t)
a.IsFalse(injectionutils.DetectXSS(""))
a.IsFalse(injectionutils.DetectXSS("abc"))
a.IsTrue(injectionutils.DetectXSS("<script>"))
a.IsTrue(injectionutils.DetectXSS("<link>"))
a.IsFalse(injectionutils.DetectXSS("<html><span>"))
a.IsFalse(injectionutils.DetectXSS("&lt;script&gt;"))
a.IsTrue(injectionutils.DetectXSS("/path?onmousedown=a"))
a.IsTrue(injectionutils.DetectXSS("/path?onkeyup=a"))
a.IsTrue(injectionutils.DetectXSS("onkeyup=a"))
a.IsTrue(injectionutils.DetectXSS("<iframe scrolling='no'>"))
a.IsFalse(injectionutils.DetectXSS("<html><body><span>RequestId: 1234567890</span></body></html>"))
a.IsTrue(injectionutils.DetectXSS("name=s&description=%3Cscript+src%3D%22a.js%22%3Edddd%3C%2Fscript%3E"))
a.IsFalse(injectionutils.DetectXSS(`<x:xmpmeta xmlns:x="adobe:ns:meta/" x:xmptk="XMP Core 6.0.0">
<rdf:RDF xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#">
<rdf:Description rdf:about=""
xmlns:tiff="http://ns.adobe.com/tiff/1.0/">
<tiff:Orientation>1</tiff:Orientation>
</rdf:Description>
</rdf:RDF>
</x:xmpmeta>`)) // included in some photo files
}
func BenchmarkDetectXSS_MISS(b *testing.B) {
var result = injectionutils.DetectXSS("<html><body><span>RequestId: 1234567890</span></body></html>")
if result {
b.Fatal("'result' should not be 'true'")
}
runtime.GOMAXPROCS(4)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = injectionutils.DetectXSS("<html><body><span>RequestId: 1234567890</span></body></html>")
}
})
}
func BenchmarkDetectXSS_MISS_Cache(b *testing.B) {
var result = injectionutils.DetectXSS("<html><body><span>RequestId: 1234567890</span></body></html>")
if result {
b.Fatal("'result' should not be 'true'")
}
runtime.GOMAXPROCS(4)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = injectionutils.DetectXSSCache("<html><body><span>RequestId: 1234567890</span></body></html>", utils.CacheMiddleLife)
}
})
}
func BenchmarkDetectXSS_HIT(b *testing.B) {
var result = injectionutils.DetectXSS("<html><body><span>RequestId: 1234567890</span><script src=\"\"></script></body></html>")
if !result {
b.Fatal("'result' should not be 'false'")
}
runtime.GOMAXPROCS(4)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = injectionutils.DetectXSS("<html><body><span>RequestId: 1234567890</span><script src=\"\"></script></body></html>")
}
})
}

View File

@@ -3,12 +3,18 @@
package waf
import (
"encoding/json"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeNode/internal/conns"
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/firewalls"
"github.com/TeaOSLab/EdgeNode/internal/utils/expires"
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/types"
"os"
"sync"
"sync/atomic"
)
@@ -25,11 +31,30 @@ const (
const IPTypeAll = "*"
func init() {
if !teaconst.IsMain {
return
}
var cacheFile = Tea.Root + "/data/waf_white_list.cache"
// save
events.On(events.EventTerminated, func() {
_ = SharedIPWhiteList.Save(cacheFile)
})
// load
go func() {
_ = SharedIPWhiteList.Load(cacheFile)
_ = os.Remove(cacheFile)
}()
}
// IPList IP列表管理
type IPList struct {
expireList *expires.List
ipMap map[string]uint64 // ip => id
idMap map[uint64]string // id => ip
ipMap map[string]uint64 // ip info => id
idMap map[uint64]string // id => ip info
listType IPListType
id uint64
@@ -47,7 +72,7 @@ func NewIPList(listType IPListType) *IPList {
listType: listType,
}
e := expires.NewList()
var e = expires.NewList()
list.expireList = e
e.OnGC(func(itemId uint64) {
@@ -206,6 +231,85 @@ func (this *IPList) RemoveIP(ip string, serverId int64, shouldExecute bool) {
}
}
// Save to local file
func (this *IPList) Save(path string) error {
var itemMaps = []maps.Map{} // [ {ip info, expiresAt }, ... ]
this.locker.Lock()
defer this.locker.Unlock()
// prevent too many items
if len(this.ipMap) > 100_000 {
return nil
}
for ipInfo, id := range this.ipMap {
var expiresAt = this.expireList.ExpiresAt(id)
if expiresAt <= 0 {
continue
}
itemMaps = append(itemMaps, maps.Map{
"ip": ipInfo,
"expiresAt": expiresAt,
})
}
itemMapsJSON, err := json.Marshal(itemMaps)
if err != nil {
return err
}
return os.WriteFile(path, itemMapsJSON, 0666)
}
// Load from local file
func (this *IPList) Load(path string) error {
data, err := os.ReadFile(path)
if err != nil {
return err
}
if len(data) == 0 {
return nil
}
var itemMaps = []maps.Map{}
err = json.Unmarshal(data, &itemMaps)
if err != nil {
return err
}
this.locker.Lock()
defer this.locker.Unlock()
for _, itemMap := range itemMaps {
var ip = itemMap.GetString("ip")
var expiresAt = itemMap.GetInt64("expiresAt")
if len(ip) == 0 || expiresAt < fasttime.Now().Unix()+10 /** seconds **/ {
continue
}
var id = this.nextId()
this.expireList.Add(id, expiresAt)
this.ipMap[ip] = id
this.idMap[id] = ip
}
return nil
}
// IPMap get ipMap
func (this *IPList) IPMap() map[string]uint64 {
this.locker.RLock()
defer this.locker.RUnlock()
return this.ipMap
}
// IdMap get idMap
func (this *IPList) IdMap() map[uint64]string {
this.locker.RLock()
defer this.locker.RUnlock()
return this.idMap
}
func (this *IPList) remove(id uint64) {
this.locker.Lock()
ip, ok := this.idMap[id]

View File

@@ -1,12 +1,16 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package waf
package waf_test
import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeNode/internal/waf"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/assert"
_ "github.com/iwind/TeaGo/bootstrap"
"github.com/iwind/TeaGo/logs"
timeutil "github.com/iwind/TeaGo/utils/time"
"os"
"runtime"
"strconv"
"testing"
@@ -14,35 +18,33 @@ import (
)
func TestNewIPList(t *testing.T) {
var list = NewIPList(IPListTypeDeny)
list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.1", time.Now().Unix())
list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.2", time.Now().Unix()+1)
list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.1", time.Now().Unix()+2)
list.Add(IPTypeAll, firewallconfigs.FirewallScopeService, 1, "127.0.0.3", time.Now().Unix()+3)
list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.10", time.Now().Unix()+10)
var list = waf.NewIPList(waf.IPListTypeDeny)
list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.1", time.Now().Unix())
list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.2", time.Now().Unix()+1)
list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.1", time.Now().Unix()+2)
list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeService, 1, "127.0.0.3", time.Now().Unix()+3)
list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.10", time.Now().Unix()+10)
list.RemoveIP("127.0.0.1", 1, false)
logs.PrintAsJSON(list.ipMap, t)
logs.PrintAsJSON(list.idMap, t)
logs.PrintAsJSON(list.IPMap(), t)
logs.PrintAsJSON(list.IdMap(), t)
}
func TestIPList_Expire(t *testing.T) {
var list = NewIPList(IPListTypeDeny)
list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.1", time.Now().Unix())
list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.2", time.Now().Unix()+1)
list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.1", time.Now().Unix()+2)
list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.3", time.Now().Unix()+3)
list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.10", time.Now().Unix()+6)
var list = waf.NewIPList(waf.IPListTypeDeny)
list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.1", time.Now().Unix())
list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.2", time.Now().Unix()+1)
list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.1", time.Now().Unix()+2)
list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.3", time.Now().Unix()+3)
list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.10", time.Now().Unix()+6)
var ticker = time.NewTicker(1 * time.Second)
for range ticker.C {
t.Log("====")
list.locker.Lock()
logs.PrintAsJSON(list.ipMap, t)
logs.PrintAsJSON(list.idMap, t)
list.locker.Unlock()
if len(list.idMap) == 0 {
logs.PrintAsJSON(list.IPMap(), t)
logs.PrintAsJSON(list.IdMap(), t)
if len(list.IdMap()) == 0 {
break
}
}
@@ -51,54 +53,78 @@ func TestIPList_Expire(t *testing.T) {
func TestIPList_Contains(t *testing.T) {
var a = assert.NewAssertion(t)
var list = NewIPList(IPListTypeDeny)
var list = waf.NewIPList(waf.IPListTypeDeny)
for i := 0; i < 1_0000; i++ {
list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600)
list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600)
}
//list.RemoveIP("192.168.1.100")
{
a.IsTrue(list.Contains(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1.100"))
a.IsTrue(list.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1.100"))
}
{
a.IsFalse(list.Contains(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.2.100"))
a.IsFalse(list.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.2.100"))
}
}
func TestIPList_ContainsExpires(t *testing.T) {
var list = NewIPList(IPListTypeDeny)
var list = waf.NewIPList(waf.IPListTypeDeny)
for i := 0; i < 1_0000; i++ {
list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600)
list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600)
}
// list.RemoveIP("192.168.1.100", 1, false)
for _, ip := range []string{"192.168.1.100", "192.168.2.100"} {
expiresAt, ok := list.ContainsExpires(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, ip)
expiresAt, ok := list.ContainsExpires(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, ip)
t.Log(ok, expiresAt, timeutil.FormatTime("Y-m-d H:i:s", expiresAt))
}
}
func TestIPList_Save(t *testing.T) {
var a = assert.NewAssertion(t)
var list = waf.NewIPList(waf.IPListTypeAllow)
list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1.100", time.Now().Unix()+3600)
list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 2, "192.168.1.101", time.Now().Unix()+3600)
var file = Tea.Root + "/data/waf.iplist.json"
err := list.Save(file)
if err != nil {
t.Fatal(err)
}
var newList = waf.NewIPList(waf.IPListTypeAllow)
err = newList.Load(file)
if err != nil {
t.Fatal(err)
}
a.IsTrue(newList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1.100"))
_ = os.Remove(file)
}
func BenchmarkIPList_Add(b *testing.B) {
runtime.GOMAXPROCS(1)
var list = NewIPList(IPListTypeDeny)
var list = waf.NewIPList(waf.IPListTypeDeny)
for i := 0; i < b.N; i++ {
list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600)
list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600)
}
b.Log(len(list.ipMap))
b.Log(len(list.IPMap()))
}
func BenchmarkIPList_Has(b *testing.B) {
runtime.GOMAXPROCS(1)
var list = NewIPList(IPListTypeDeny)
var list = waf.NewIPList(waf.IPListTypeDeny)
b.ResetTimer()
for i := 0; i < 1_0000; i++ {
list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600)
list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600)
}
for i := 0; i < b.N; i++ {
list.Contains(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1.100")
list.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1.100")
}
}

View File

@@ -9,17 +9,20 @@ import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/filterconfigs"
"github.com/TeaOSLab/EdgeNode/internal/re"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/utils/runes"
"github.com/TeaOSLab/EdgeNode/internal/waf/checkpoints"
"github.com/TeaOSLab/EdgeNode/internal/waf/injectionutils"
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"github.com/TeaOSLab/EdgeNode/internal/waf/utils"
"github.com/TeaOSLab/EdgeNode/internal/waf/values"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/types"
"github.com/iwind/TeaGo/utils/string"
stringutil "github.com/iwind/TeaGo/utils/string"
"net"
"reflect"
"regexp"
"sort"
"strings"
)
@@ -29,14 +32,14 @@ var singleParamRegexp = regexp.MustCompile(`^\${[\w.-]+}$`)
type Rule struct {
Id int64
Description string `yaml:"description" json:"description"`
Param string `yaml:"param" json:"param"` // such as ${arg.name} or ${args}, can be composite as ${arg.firstName}${arg.lastName}
ParamFilters []*ParamFilter `yaml:"paramFilters" json:"paramFilters"`
Operator RuleOperator `yaml:"operator" json:"operator"` // such as contains, gt, ...
Value string `yaml:"value" json:"value"` // compared value
IsCaseInsensitive bool `yaml:"isCaseInsensitive" json:"isCaseInsensitive"`
CheckpointOptions map[string]interface{} `yaml:"checkpointOptions" json:"checkpointOptions"`
Priority int `yaml:"priority" json:"priority"`
Description string `yaml:"description" json:"description"`
Param string `yaml:"param" json:"param"` // such as ${arg.name} or ${args}, can be composite as ${arg.firstName}${arg.lastName}
ParamFilters []*ParamFilter `yaml:"paramFilters" json:"paramFilters"`
Operator RuleOperator `yaml:"operator" json:"operator"` // such as contains, gt, ...
Value string `yaml:"value" json:"value"` // compared value
IsCaseInsensitive bool `yaml:"isCaseInsensitive" json:"isCaseInsensitive"`
CheckpointOptions map[string]any `yaml:"checkpointOptions" json:"checkpointOptions"`
Priority int `yaml:"priority" json:"priority"`
checkpointFinder func(prefix string) checkpoints.CheckpointInterface
@@ -50,12 +53,13 @@ type Rule struct {
ipRangeListValue *values.IPRangeList
stringValues []string
stringValueRunes [][]rune
ipList *values.StringList
floatValue float64
reg *re.Regexp
regCacheLife utils.CacheLife
reg *re.Regexp
cacheLife utils.CacheLife
}
func NewRule() *Rule {
@@ -77,7 +81,7 @@ func (this *Rule) Init() error {
this.floatValue = types.Float64(this.Value)
case RuleOperatorNeq:
this.floatValue = types.Float64(this.Value)
case RuleOperatorContainsAny, RuleOperatorContainsAll:
case RuleOperatorContainsAny, RuleOperatorContainsAll, RuleOperatorContainsAnyWord, RuleOperatorContainsAllWords, RuleOperatorNotContainsAnyWord:
this.stringValues = []string{}
if len(this.Value) > 0 {
var lines = strings.Split(this.Value, "\n")
@@ -91,9 +95,17 @@ func (this *Rule) Init() error {
}
}
}
if this.Operator == RuleOperatorContainsAnyWord || this.Operator == RuleOperatorContainsAllWords || this.Operator == RuleOperatorNotContainsAnyWord {
sort.Strings(this.stringValues)
}
this.stringValueRunes = [][]rune{}
for _, line := range this.stringValues {
this.stringValueRunes = append(this.stringValueRunes, []rune(line))
}
}
case RuleOperatorMatch:
v := this.Value
var v = this.Value
if this.IsCaseInsensitive && !strings.HasPrefix(v, "(?i)") {
v = "(?i)" + v
}
@@ -106,7 +118,7 @@ func (this *Rule) Init() error {
}
this.reg = reg
case RuleOperatorNotMatch:
v := this.Value
var v = this.Value
if this.IsCaseInsensitive && !strings.HasPrefix(v, "(?i)") {
v = "(?i)" + v
}
@@ -164,7 +176,7 @@ func (this *Rule) Init() error {
this.singleCheckpoint = checkpoint
this.Priority = checkpoint.Priority()
this.regCacheLife = checkpoint.CacheLife()
this.cacheLife = checkpoint.CacheLife()
} else {
var checkpoint = checkpoints.FindCheckpoint(prefix)
if checkpoint == nil {
@@ -174,7 +186,7 @@ func (this *Rule) Init() error {
this.singleCheckpoint = checkpoint
this.Priority = checkpoint.Priority()
this.regCacheLife = checkpoint.CacheLife()
this.cacheLife = checkpoint.CacheLife()
}
return nil
@@ -193,8 +205,8 @@ func (this *Rule) Init() error {
this.multipleCheckpoints[prefix] = checkpoint
this.Priority = checkpoint.Priority()
if this.regCacheLife <= 0 || checkpoint.CacheLife() < this.regCacheLife {
this.regCacheLife = checkpoint.CacheLife()
if this.cacheLife <= 0 || checkpoint.CacheLife() < this.cacheLife {
this.cacheLife = checkpoint.CacheLife()
}
}
} else {
@@ -206,7 +218,7 @@ func (this *Rule) Init() error {
this.multipleCheckpoints[prefix] = checkpoint
this.Priority = checkpoint.Priority()
this.regCacheLife = checkpoint.CacheLife()
this.cacheLife = checkpoint.CacheLife()
}
}
@@ -239,9 +251,9 @@ func (this *Rule) MatchRequest(req requests.Request) (b bool, hasRequestBody boo
return this.Test(value), hasRequestBody, nil
}
value := configutils.ParseVariables(this.Param, func(varName string) (value string) {
pieces := strings.SplitN(varName, ".", 2)
prefix := pieces[0]
var value = configutils.ParseVariables(this.Param, func(varName string) (value string) {
var pieces = strings.SplitN(varName, ".", 2)
var prefix = pieces[0]
point, ok := this.multipleCheckpoints[prefix]
if !ok {
return ""
@@ -255,7 +267,7 @@ func (this *Rule) MatchRequest(req requests.Request) (b bool, hasRequestBody boo
if err1 != nil {
err = err1
}
return types.String(value1)
return this.stringifyValue(value1)
}
value1, hasCheckRequestBody, err1, _ := point.RequestValue(req, pieces[1], this.CheckpointOptions, this.Id)
@@ -265,7 +277,7 @@ func (this *Rule) MatchRequest(req requests.Request) (b bool, hasRequestBody boo
if err1 != nil {
err = err1
}
return types.String(value1)
return this.stringifyValue(value1)
})
if err != nil {
@@ -312,9 +324,9 @@ func (this *Rule) MatchResponse(req requests.Request, resp *requests.Response) (
return this.Test(value), hasRequestBody, nil
}
value := configutils.ParseVariables(this.Param, func(varName string) (value string) {
pieces := strings.SplitN(varName, ".", 2)
prefix := pieces[0]
var value = configutils.ParseVariables(this.Param, func(varName string) (value string) {
var pieces = strings.SplitN(varName, ".", 2)
var prefix = pieces[0]
point, ok := this.multipleCheckpoints[prefix]
if !ok {
return ""
@@ -329,7 +341,7 @@ func (this *Rule) MatchResponse(req requests.Request, resp *requests.Response) (
if err1 != nil {
err = err1
}
return types.String(value1)
return this.stringifyValue(value1)
} else {
value1, hasCheckRequestBody, err1, _ := point.ResponseValue(req, resp, "", this.CheckpointOptions, this.Id)
if hasCheckRequestBody {
@@ -338,7 +350,7 @@ func (this *Rule) MatchResponse(req requests.Request, resp *requests.Response) (
if err1 != nil {
err = err1
}
return types.String(value1)
return this.stringifyValue(value1)
}
}
@@ -350,7 +362,7 @@ func (this *Rule) MatchResponse(req requests.Request, resp *requests.Response) (
if err1 != nil {
err = err1
}
return types.String(value1)
return this.stringifyValue(value1)
} else {
value1, hasCheckRequestBody, err1, _ := point.ResponseValue(req, resp, pieces[1], this.CheckpointOptions, this.Id)
if hasCheckRequestBody {
@@ -359,7 +371,7 @@ func (this *Rule) MatchResponse(req requests.Request, resp *requests.Response) (
if err1 != nil {
err = err1
}
return types.String(value1)
return this.stringifyValue(value1)
}
})
@@ -387,26 +399,37 @@ func (this *Rule) Test(value any) bool {
return types.Float64(value) != this.floatValue
case RuleOperatorEqString:
if this.IsCaseInsensitive {
return strings.EqualFold(types.String(value), this.Value)
return strings.EqualFold(this.stringifyValue(value), this.Value)
} else {
return types.String(value) == this.Value
return this.stringifyValue(value) == this.Value
}
case RuleOperatorNeqString:
if this.IsCaseInsensitive {
return !strings.EqualFold(types.String(value), this.Value)
return !strings.EqualFold(this.stringifyValue(value), this.Value)
} else {
return types.String(value) != this.Value
return this.stringifyValue(value) != this.Value
}
case RuleOperatorMatch, RuleOperatorWildcardMatch:
if value == nil {
return false
value = ""
}
// strings
stringList, ok := value.([]string)
if ok {
for _, s := range stringList {
if utils.MatchStringCache(this.reg, s, this.regCacheLife) {
if utils.MatchStringCache(this.reg, s, this.cacheLife) {
return true
}
}
return false
}
// bytes list
byteSlices, ok := value.([][]byte)
if ok {
for _, byteSlice := range byteSlices {
if utils.MatchBytesCache(this.reg, byteSlice, this.cacheLife) {
return true
}
}
@@ -416,19 +439,30 @@ func (this *Rule) Test(value any) bool {
// bytes
byteSlice, ok := value.([]byte)
if ok {
return utils.MatchBytesCache(this.reg, byteSlice, this.regCacheLife)
return utils.MatchBytesCache(this.reg, byteSlice, this.cacheLife)
}
// string
return utils.MatchStringCache(this.reg, types.String(value), this.regCacheLife)
return utils.MatchStringCache(this.reg, this.stringifyValue(value), this.cacheLife)
case RuleOperatorNotMatch, RuleOperatorWildcardNotMatch:
if value == nil {
return true
value = ""
}
stringList, ok := value.([]string)
if ok {
for _, s := range stringList {
if utils.MatchStringCache(this.reg, s, this.regCacheLife) {
if utils.MatchStringCache(this.reg, s, this.cacheLife) {
return false
}
}
return true
}
// bytes list
byteSlices, ok := value.([][]byte)
if ok {
for _, byteSlice := range byteSlices {
if utils.MatchBytesCache(this.reg, byteSlice, this.cacheLife) {
return false
}
}
@@ -438,17 +472,17 @@ func (this *Rule) Test(value any) bool {
// bytes
byteSlice, ok := value.([]byte)
if ok {
return !utils.MatchBytesCache(this.reg, byteSlice, this.regCacheLife)
return !utils.MatchBytesCache(this.reg, byteSlice, this.cacheLife)
}
return !utils.MatchStringCache(this.reg, types.String(value), this.regCacheLife)
return !utils.MatchStringCache(this.reg, this.stringifyValue(value), this.cacheLife)
case RuleOperatorContains:
if types.IsSlice(value) {
_, isBytes := value.([]byte)
if !isBytes {
ok := false
var ok = false
lists.Each(value, func(k int, v any) {
if types.String(v) == this.Value {
if this.stringifyValue(v) == this.Value {
ok = true
}
})
@@ -456,17 +490,17 @@ func (this *Rule) Test(value any) bool {
}
}
if types.IsMap(value) {
lowerValue := ""
var lowerValue = ""
if this.IsCaseInsensitive {
lowerValue = strings.ToLower(this.Value)
}
for _, v := range maps.NewMap(value) {
if this.IsCaseInsensitive {
if strings.ToLower(types.String(v)) == lowerValue {
if strings.ToLower(this.stringifyValue(v)) == lowerValue {
return true
}
} else {
if types.String(v) == this.Value {
if this.stringifyValue(v) == this.Value {
return true
}
}
@@ -475,30 +509,44 @@ func (this *Rule) Test(value any) bool {
}
if this.IsCaseInsensitive {
return strings.Contains(strings.ToLower(types.String(value)), strings.ToLower(this.Value))
return strings.Contains(strings.ToLower(this.stringifyValue(value)), strings.ToLower(this.Value))
} else {
return strings.Contains(types.String(value), this.Value)
return strings.Contains(this.stringifyValue(value), this.Value)
}
case RuleOperatorNotContains:
if this.IsCaseInsensitive {
return !strings.Contains(strings.ToLower(types.String(value)), strings.ToLower(this.Value))
return !strings.Contains(strings.ToLower(this.stringifyValue(value)), strings.ToLower(this.Value))
} else {
return !strings.Contains(types.String(value), this.Value)
return !strings.Contains(this.stringifyValue(value), this.Value)
}
case RuleOperatorPrefix:
if this.IsCaseInsensitive {
return strings.HasPrefix(strings.ToLower(types.String(value)), strings.ToLower(this.Value))
var s = this.stringifyValue(value)
var sl = len(s)
var vl = len(this.Value)
if sl < vl {
return false
}
s = s[:vl]
return strings.HasPrefix(strings.ToLower(s), strings.ToLower(this.Value))
} else {
return strings.HasPrefix(types.String(value), this.Value)
return strings.HasPrefix(this.stringifyValue(value), this.Value)
}
case RuleOperatorSuffix:
if this.IsCaseInsensitive {
return strings.HasSuffix(strings.ToLower(types.String(value)), strings.ToLower(this.Value))
var s = this.stringifyValue(value)
var sl = len(s)
var vl = len(this.Value)
if sl < vl {
return false
}
s = s[sl-vl:]
return strings.HasSuffix(strings.ToLower(s), strings.ToLower(this.Value))
} else {
return strings.HasSuffix(types.String(value), this.Value)
return strings.HasSuffix(this.stringifyValue(value), this.Value)
}
case RuleOperatorContainsAny:
var stringValue = types.String(value)
var stringValue = this.stringifyValue(value)
if this.IsCaseInsensitive {
stringValue = strings.ToLower(stringValue)
}
@@ -511,7 +559,7 @@ func (this *Rule) Test(value any) bool {
}
return false
case RuleOperatorContainsAll:
var stringValue = types.String(value)
var stringValue = this.stringifyValue(value)
if this.IsCaseInsensitive {
stringValue = strings.ToLower(stringValue)
}
@@ -524,31 +572,81 @@ func (this *Rule) Test(value any) bool {
return true
}
return false
case RuleOperatorContainsAnyWord:
return runes.ContainsAnyWordRunes(this.stringifyValue(value), this.stringValueRunes, this.IsCaseInsensitive)
case RuleOperatorContainsAllWords:
return runes.ContainsAllWords(this.stringifyValue(value), this.stringValues, this.IsCaseInsensitive)
case RuleOperatorNotContainsAnyWord:
return !runes.ContainsAnyWordRunes(this.stringifyValue(value), this.stringValueRunes, this.IsCaseInsensitive)
case RuleOperatorContainsSQLInjection:
if value == nil {
return false
}
switch xValue := value.(type) {
case []string:
for _, v := range xValue {
if injectionutils.DetectSQLInjectionCache(v, this.cacheLife) {
return true
}
}
return false
case [][]byte:
for _, v := range xValue {
if injectionutils.DetectSQLInjectionCache(string(v), this.cacheLife) {
return true
}
}
return false
default:
return injectionutils.DetectSQLInjectionCache(this.stringifyValue(value), this.cacheLife)
}
case RuleOperatorContainsXSS:
if value == nil {
return false
}
switch xValue := value.(type) {
case []string:
for _, v := range xValue {
if injectionutils.DetectXSSCache(v, this.cacheLife) {
return true
}
}
return false
case [][]byte:
for _, v := range xValue {
if injectionutils.DetectXSSCache(string(v), this.cacheLife) {
return true
}
}
return false
default:
return injectionutils.DetectXSSCache(this.stringifyValue(value), this.cacheLife)
}
case RuleOperatorContainsBinary:
data, _ := base64.StdEncoding.DecodeString(types.String(this.Value))
data, _ := base64.StdEncoding.DecodeString(this.stringifyValue(this.Value))
if this.IsCaseInsensitive {
return bytes.Contains(bytes.ToUpper([]byte(types.String(value))), bytes.ToUpper(data))
return bytes.Contains(bytes.ToUpper([]byte(this.stringifyValue(value))), bytes.ToUpper(data))
} else {
return bytes.Contains([]byte(types.String(value)), data)
return bytes.Contains([]byte(this.stringifyValue(value)), data)
}
case RuleOperatorNotContainsBinary:
data, _ := base64.StdEncoding.DecodeString(types.String(this.Value))
data, _ := base64.StdEncoding.DecodeString(this.stringifyValue(this.Value))
if this.IsCaseInsensitive {
return !bytes.Contains(bytes.ToUpper([]byte(types.String(value))), bytes.ToUpper(data))
return !bytes.Contains(bytes.ToUpper([]byte(this.stringifyValue(value))), bytes.ToUpper(data))
} else {
return !bytes.Contains([]byte(types.String(value)), data)
return !bytes.Contains([]byte(this.stringifyValue(value)), data)
}
case RuleOperatorHasKey:
if types.IsSlice(value) {
index := types.Int(this.Value)
var index = types.Int(this.Value)
if index < 0 {
return false
}
return reflect.ValueOf(value).Len() > index
} else if types.IsMap(value) {
m := maps.NewMap(value)
var m = maps.NewMap(value)
if this.IsCaseInsensitive {
lowerValue := strings.ToLower(this.Value)
var lowerValue = strings.ToLower(this.Value)
for k := range m {
if strings.ToLower(k) == lowerValue {
return true
@@ -567,9 +665,9 @@ func (this *Rule) Test(value any) bool {
return stringutil.VersionCompare(this.Value, types.String(value)) < 0
case RuleOperatorVersionRange:
if strings.Contains(this.Value, ",") {
versions := strings.SplitN(this.Value, ",", 2)
version1 := strings.TrimSpace(versions[0])
version2 := strings.TrimSpace(versions[1])
var versions = strings.SplitN(this.Value, ",", 2)
var version1 = strings.TrimSpace(versions[0])
var version2 = strings.TrimSpace(versions[1])
if len(version1) > 0 && stringutil.VersionCompare(types.String(value), version1) < 0 {
return false
}
@@ -587,25 +685,25 @@ func (this *Rule) Test(value any) bool {
}
return this.isIP && ip.Equal(this.ipValue)
case RuleOperatorGtIP:
ip := net.ParseIP(types.String(value))
var ip = net.ParseIP(types.String(value))
if ip == nil {
return false
}
return this.isIP && bytes.Compare(ip, this.ipValue) > 0
case RuleOperatorGteIP:
ip := net.ParseIP(types.String(value))
var ip = net.ParseIP(types.String(value))
if ip == nil {
return false
}
return this.isIP && bytes.Compare(ip, this.ipValue) >= 0
case RuleOperatorLtIP:
ip := net.ParseIP(types.String(value))
var ip = net.ParseIP(types.String(value))
if ip == nil {
return false
}
return this.isIP && bytes.Compare(ip, this.ipValue) < 0
case RuleOperatorLteIP:
ip := net.ParseIP(types.String(value))
var ip = net.ParseIP(types.String(value))
if ip == nil {
return false
}
@@ -624,7 +722,7 @@ func (this *Rule) Test(value any) bool {
if div == 0 {
return false
}
rem := types.Int64(pieces[1])
var rem = types.Int64(pieces[1])
return this.ipToInt64(net.ParseIP(types.String(value)))%div == rem
case RuleOperatorIPMod10:
return this.ipToInt64(net.ParseIP(types.String(value)))%10 == types.Int64(this.Value)
@@ -737,3 +835,25 @@ func (this *Rule) execFilter(value any) any {
}
return value
}
func (this *Rule) stringifyValue(value any) string {
if value == nil {
return ""
}
switch v := value.(type) {
case string:
return v
case []string:
return strings.Join(v, "")
case []byte:
return string(v)
case [][]byte:
var b = &bytes.Buffer{}
for _, vb := range v {
b.Write(vb)
}
return b.String()
default:
return types.String(v)
}
}

View File

@@ -4,34 +4,40 @@ type RuleOperator = string
type RuleCaseInsensitive = string
const (
RuleOperatorGt RuleOperator = "gt"
RuleOperatorGte RuleOperator = "gte"
RuleOperatorLt RuleOperator = "lt"
RuleOperatorLte RuleOperator = "lte"
RuleOperatorEq RuleOperator = "eq"
RuleOperatorNeq RuleOperator = "neq"
RuleOperatorEqString RuleOperator = "eq string"
RuleOperatorNeqString RuleOperator = "neq string"
RuleOperatorMatch RuleOperator = "match"
RuleOperatorNotMatch RuleOperator = "not match"
RuleOperatorWildcardMatch RuleOperator = "wildcard match"
RuleOperatorWildcardNotMatch RuleOperator = "wildcard not match"
RuleOperatorContains RuleOperator = "contains"
RuleOperatorNotContains RuleOperator = "not contains"
RuleOperatorPrefix RuleOperator = "prefix"
RuleOperatorSuffix RuleOperator = "suffix"
RuleOperatorContainsAny RuleOperator = "contains any"
RuleOperatorContainsAll RuleOperator = "contains all"
RuleOperatorInIPList RuleOperator = "in ip list"
RuleOperatorHasKey RuleOperator = "has key" // has key in slice or map
RuleOperatorVersionGt RuleOperator = "version gt"
RuleOperatorVersionLt RuleOperator = "version lt"
RuleOperatorVersionRange RuleOperator = "version range"
RuleOperatorGt RuleOperator = "gt"
RuleOperatorGte RuleOperator = "gte"
RuleOperatorLt RuleOperator = "lt"
RuleOperatorLte RuleOperator = "lte"
RuleOperatorEq RuleOperator = "eq"
RuleOperatorNeq RuleOperator = "neq"
RuleOperatorEqString RuleOperator = "eq string"
RuleOperatorNeqString RuleOperator = "neq string"
RuleOperatorMatch RuleOperator = "match"
RuleOperatorNotMatch RuleOperator = "not match"
RuleOperatorWildcardMatch RuleOperator = "wildcard match"
RuleOperatorWildcardNotMatch RuleOperator = "wildcard not match"
RuleOperatorContains RuleOperator = "contains"
RuleOperatorNotContains RuleOperator = "not contains"
RuleOperatorPrefix RuleOperator = "prefix"
RuleOperatorSuffix RuleOperator = "suffix"
RuleOperatorContainsAny RuleOperator = "contains any"
RuleOperatorContainsAll RuleOperator = "contains all"
RuleOperatorContainsAnyWord RuleOperator = "contains any word"
RuleOperatorContainsAllWords RuleOperator = "contains all words"
RuleOperatorNotContainsAnyWord RuleOperator = "not contains any word"
RuleOperatorContainsSQLInjection RuleOperator = "contains sql injection"
RuleOperatorContainsXSS RuleOperator = "contains xss"
RuleOperatorInIPList RuleOperator = "in ip list"
RuleOperatorHasKey RuleOperator = "has key" // has key in slice or map
RuleOperatorVersionGt RuleOperator = "version gt"
RuleOperatorVersionLt RuleOperator = "version lt"
RuleOperatorVersionRange RuleOperator = "version range"
RuleOperatorContainsBinary RuleOperator = "contains binary" // contains binary
RuleOperatorNotContainsBinary RuleOperator = "not contains binary" // not contains binary
// ip
RuleOperatorEqIP RuleOperator = "eq ip"
RuleOperatorGtIP RuleOperator = "gt ip"
RuleOperatorGteIP RuleOperator = "gte ip"
@@ -42,10 +48,6 @@ const (
RuleOperatorIPMod10 RuleOperator = "ip mod 10"
RuleOperatorIPMod100 RuleOperator = "ip mod 100"
RuleOperatorIPMod RuleOperator = "ip mod"
RuleCaseInsensitiveNone = "none"
RuleCaseInsensitiveYes = "yes"
RuleCaseInsensitiveNo = "no"
)
type RuleOperatorDefinition struct {
@@ -54,174 +56,3 @@ type RuleOperatorDefinition struct {
Description string
CaseInsensitive RuleCaseInsensitive // default caseInsensitive setting
}
var AllRuleOperators = []*RuleOperatorDefinition{
{
Name: "数值大于",
Code: RuleOperatorGt,
Description: "使用数值对比大于",
CaseInsensitive: RuleCaseInsensitiveNone,
},
{
Name: "数值大于等于",
Code: RuleOperatorGte,
Description: "使用数值对比大于等于",
CaseInsensitive: RuleCaseInsensitiveNone,
},
{
Name: "数值小于",
Code: RuleOperatorLt,
Description: "使用数值对比小于",
CaseInsensitive: RuleCaseInsensitiveNone,
},
{
Name: "数值小于等于",
Code: RuleOperatorLte,
Description: "使用数值对比小于等于",
CaseInsensitive: RuleCaseInsensitiveNone,
},
{
Name: "数值等于",
Code: RuleOperatorEq,
Description: "使用数值对比等于",
CaseInsensitive: RuleCaseInsensitiveNone,
},
{
Name: "数值不等于",
Code: RuleOperatorNeq,
Description: "使用数值对比不等于",
CaseInsensitive: RuleCaseInsensitiveNone,
},
{
Name: "字符串等于",
Code: RuleOperatorEqString,
Description: "使用字符串对比等于",
CaseInsensitive: RuleCaseInsensitiveNo,
},
{
Name: "字符串不等于",
Code: RuleOperatorNeqString,
Description: "使用字符串对比不等于",
CaseInsensitive: RuleCaseInsensitiveNo,
},
{
Name: "正则匹配",
Code: RuleOperatorMatch,
Description: "使用正则表达式匹配,在头部使用(?i)表示不区分大小写,<a href=\"http://teaos.cn/doc/regexp/Regexp.md\" target=\"_blank\">正则表达式语法 &raquo;</a>",
CaseInsensitive: RuleCaseInsensitiveYes,
},
{
Name: "正则不匹配",
Code: RuleOperatorNotMatch,
Description: "使用正则表达式不匹配,在头部使用(?i)表示不区分大小写,<a href=\"http://teaos.cn/doc/regexp/Regexp.md\" target=\"_blank\">正则表达式语法 &raquo;</a>",
CaseInsensitive: RuleCaseInsensitiveYes,
},
{
Name: "包含字符串",
Code: RuleOperatorContains,
Description: "包含某个字符串",
CaseInsensitive: RuleCaseInsensitiveNo,
},
{
Name: "不包含字符串",
Code: RuleOperatorNotContains,
Description: "不包含某个字符串",
CaseInsensitive: RuleCaseInsensitiveNo,
},
{
Name: "包含前缀",
Code: RuleOperatorPrefix,
Description: "包含某个前缀",
CaseInsensitive: RuleCaseInsensitiveNo,
},
{
Name: "包含后缀",
Code: RuleOperatorSuffix,
Description: "包含某个后缀",
CaseInsensitive: RuleCaseInsensitiveNo,
},
{
Name: "包含索引",
Code: RuleOperatorHasKey,
Description: "对于一组数据拥有某个键值或者索引",
CaseInsensitive: RuleCaseInsensitiveNo,
},
{
Name: "版本号大于",
Code: RuleOperatorVersionGt,
Description: "对比版本号大于",
CaseInsensitive: RuleCaseInsensitiveNo,
},
{
Name: "版本号小于",
Code: RuleOperatorVersionLt,
Description: "对比版本号小于",
CaseInsensitive: RuleCaseInsensitiveNo,
},
{
Name: "版本号范围",
Code: RuleOperatorVersionRange,
Description: "判断版本号在某个范围内格式为version1,version2",
CaseInsensitive: RuleCaseInsensitiveNo,
},
{
Name: "IP等于",
Code: RuleOperatorEqIP,
Description: "将参数转换为IP进行对比",
CaseInsensitive: RuleCaseInsensitiveNo,
},
{
Name: "IP大于",
Code: RuleOperatorGtIP,
Description: "将参数转换为IP进行对比",
CaseInsensitive: RuleCaseInsensitiveNo,
},
{
Name: "IP大于等于",
Code: RuleOperatorGteIP,
Description: "将参数转换为IP进行对比",
CaseInsensitive: RuleCaseInsensitiveNo,
},
{
Name: "IP小于",
Code: RuleOperatorLtIP,
Description: "将参数转换为IP进行对比",
CaseInsensitive: RuleCaseInsensitiveNo,
},
{
Name: "IP小于等于",
Code: RuleOperatorLteIP,
Description: "将参数转换为IP进行对比",
CaseInsensitive: RuleCaseInsensitiveNo,
},
{
Name: "IP范围",
Code: RuleOperatorIPRange,
Description: "IP在某个范围之内范围格式可以是英文逗号分隔的ip1,ip2或者CIDR格式的ip/bits",
CaseInsensitive: RuleCaseInsensitiveNo,
},
{
Name: "不在IP范围",
Code: RuleOperatorNotIPRange,
Description: "IP不在某个范围之内范围格式可以是英文逗号分隔的ip1,ip2或者CIDR格式的ip/bits",
CaseInsensitive: RuleCaseInsensitiveNo,
},
{
Name: "IP取模10",
Code: RuleOperatorIPMod10,
Description: "对IP参数值取模除数为10对比值为余数",
CaseInsensitive: RuleCaseInsensitiveNo,
},
{
Name: "IP取模100",
Code: RuleOperatorIPMod100,
Description: "对IP参数值取模除数为100对比值为余数",
CaseInsensitive: RuleCaseInsensitiveNo,
},
{
Name: "IP取模",
Code: RuleOperatorIPMod,
Description: "对IP参数值取模对比值格式为除数,余数比如10,1",
CaseInsensitive: RuleCaseInsensitiveNo,
},
}

View File

@@ -1,7 +1,8 @@
package waf
package waf_test
import (
"bytes"
"github.com/TeaOSLab/EdgeNode/internal/waf"
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"github.com/cespare/xxhash"
"github.com/iwind/TeaGo/assert"
@@ -12,18 +13,18 @@ import (
)
func TestRuleSet_MatchRequest(t *testing.T) {
set := NewRuleSet()
set.Connector = RuleConnectorAnd
var set = waf.NewRuleSet()
set.Connector = waf.RuleConnectorAnd
set.Rules = []*Rule{
set.Rules = []*waf.Rule{
{
Param: "${arg.name}",
Operator: RuleOperatorEqString,
Operator: waf.RuleOperatorEqString,
Value: "lu",
},
{
Param: "${arg.age}",
Operator: RuleOperatorEq,
Operator: waf.RuleOperatorEq,
Value: "20",
},
}
@@ -42,20 +43,20 @@ func TestRuleSet_MatchRequest(t *testing.T) {
}
func TestRuleSet_MatchRequest2(t *testing.T) {
a := assert.NewAssertion(t)
var a = assert.NewAssertion(t)
set := NewRuleSet()
set.Connector = RuleConnectorOr
var set = waf.NewRuleSet()
set.Connector = waf.RuleConnectorOr
set.Rules = []*Rule{
set.Rules = []*waf.Rule{
{
Param: "${arg.name}",
Operator: RuleOperatorEqString,
Operator: waf.RuleOperatorEqString,
Value: "lu",
},
{
Param: "${arg.age}",
Operator: RuleOperatorEq,
Operator: waf.RuleOperatorEq,
Value: "21",
},
}
@@ -76,28 +77,28 @@ func TestRuleSet_MatchRequest2(t *testing.T) {
func BenchmarkRuleSet_MatchRequest(b *testing.B) {
runtime.GOMAXPROCS(1)
set := NewRuleSet()
set.Connector = RuleConnectorOr
var set = waf.NewRuleSet()
set.Connector = waf.RuleConnectorOr
set.Rules = []*Rule{
set.Rules = []*waf.Rule{
{
Param: "${requestAll}",
Operator: RuleOperatorMatch,
Operator: waf.RuleOperatorMatch,
Value: `(onmouseover|onmousemove|onmousedown|onmouseup|onerror|onload|onclick|ondblclick|onkeydown|onkeyup|onkeypress)\s*=`,
},
{
Param: "${requestAll}",
Operator: RuleOperatorMatch,
Operator: waf.RuleOperatorMatch,
Value: `\b(eval|system|exec|execute|passthru|shell_exec|phpinfo)\s*\(`,
},
{
Param: "${arg.name}",
Operator: RuleOperatorEqString,
Operator: waf.RuleOperatorEqString,
Value: "lu",
},
{
Param: "${arg.age}",
Operator: RuleOperatorEq,
Operator: waf.RuleOperatorEq,
Value: "21",
},
}
@@ -120,13 +121,13 @@ func BenchmarkRuleSet_MatchRequest(b *testing.B) {
func BenchmarkRuleSet_MatchRequest_Regexp(b *testing.B) {
runtime.GOMAXPROCS(1)
set := NewRuleSet()
set.Connector = RuleConnectorOr
var set = waf.NewRuleSet()
set.Connector = waf.RuleConnectorOr
set.Rules = []*Rule{
set.Rules = []*waf.Rule{
{
Param: "${requestBody}",
Operator: RuleOperatorMatch,
Operator: waf.RuleOperatorMatch,
Value: `\b(eval|system|exec|execute|passthru|shell_exec|phpinfo)\s*\(`,
IsCaseInsensitive: false,
},

View File

@@ -49,10 +49,10 @@ func TestRule_Init_Composite(t *testing.T) {
}
func TestRule_Test(t *testing.T) {
a := assert.NewAssertion(t)
var a = assert.NewAssertion(t)
{
rule := NewRule()
var rule = NewRule()
rule.Operator = RuleOperatorGt
rule.Value = "123"
err := rule.Init()
@@ -66,7 +66,7 @@ func TestRule_Test(t *testing.T) {
}
{
rule := NewRule()
var rule = NewRule()
rule.Operator = RuleOperatorGte
rule.Value = "123"
err := rule.Init()
@@ -79,7 +79,7 @@ func TestRule_Test(t *testing.T) {
}
{
rule := NewRule()
var rule = NewRule()
rule.Operator = RuleOperatorLt
rule.Value = "123"
err := rule.Init()
@@ -92,7 +92,7 @@ func TestRule_Test(t *testing.T) {
}
{
rule := NewRule()
var rule = NewRule()
rule.Operator = RuleOperatorLte
rule.Value = "123"
err := rule.Init()
@@ -105,7 +105,7 @@ func TestRule_Test(t *testing.T) {
}
{
rule := NewRule()
var rule = NewRule()
rule.Operator = RuleOperatorEq
rule.Value = "123"
err := rule.Init()
@@ -118,7 +118,7 @@ func TestRule_Test(t *testing.T) {
}
{
rule := NewRule()
var rule = NewRule()
rule.Operator = RuleOperatorNeq
rule.Value = "123"
err := rule.Init()
@@ -131,7 +131,7 @@ func TestRule_Test(t *testing.T) {
}
{
rule := NewRule()
var rule = NewRule()
rule.Operator = RuleOperatorEqString
rule.Value = "123"
err := rule.Init()
@@ -144,7 +144,7 @@ func TestRule_Test(t *testing.T) {
}
{
rule := NewRule()
var rule = NewRule()
rule.Operator = RuleOperatorEqString
rule.Value = "abc"
err := rule.Init()
@@ -156,7 +156,7 @@ func TestRule_Test(t *testing.T) {
}
{
rule := NewRule()
var rule = NewRule()
rule.Operator = RuleOperatorEqString
rule.IsCaseInsensitive = true
rule.Value = "abc"
@@ -169,7 +169,7 @@ func TestRule_Test(t *testing.T) {
}
{
rule := NewRule()
var rule = NewRule()
rule.Operator = RuleOperatorNeqString
rule.Value = "abc"
err := rule.Init()
@@ -182,7 +182,7 @@ func TestRule_Test(t *testing.T) {
}
{
rule := NewRule()
var rule = NewRule()
rule.Operator = RuleOperatorNeqString
rule.IsCaseInsensitive = true
rule.Value = "abc"
@@ -194,7 +194,7 @@ func TestRule_Test(t *testing.T) {
}
{
rule := NewRule()
var rule = NewRule()
rule.Operator = RuleOperatorMatch
rule.Value = "^\\d+"
err := rule.Init()
@@ -206,7 +206,31 @@ func TestRule_Test(t *testing.T) {
}
{
rule := NewRule()
var rule = NewRule()
rule.Operator = RuleOperatorMatch
rule.Value = "^\\d+"
err := rule.Init()
if err != nil {
t.Fatal(err)
}
a.IsTrue(rule.Test([]byte("123")))
a.IsFalse(rule.Test([]byte("abc123")))
}
{
var rule = NewRule()
rule.Operator = RuleOperatorMatch
rule.Value = "^\\d+"
err := rule.Init()
if err != nil {
t.Fatal(err)
}
a.IsTrue(rule.Test([][]byte{[]byte("123"), []byte("456")}))
a.IsFalse(rule.Test([][]byte{[]byte("abc123")}))
}
{
var rule = NewRule()
rule.Operator = RuleOperatorMatch
rule.Value = "abc"
rule.IsCaseInsensitive = true
@@ -218,7 +242,7 @@ func TestRule_Test(t *testing.T) {
}
{
rule := NewRule()
var rule = NewRule()
rule.Operator = RuleOperatorMatch
rule.Value = "^\\d+"
err := rule.Init()
@@ -230,7 +254,7 @@ func TestRule_Test(t *testing.T) {
}
{
rule := NewRule()
var rule = NewRule()
rule.Operator = RuleOperatorNotMatch
rule.Value = "\\d+"
err := rule.Init()
@@ -242,7 +266,7 @@ func TestRule_Test(t *testing.T) {
}
{
rule := NewRule()
var rule = NewRule()
rule.Operator = RuleOperatorNotMatch
rule.Value = "abc"
rule.IsCaseInsensitive = true
@@ -254,7 +278,7 @@ func TestRule_Test(t *testing.T) {
}
{
rule := NewRule()
var rule = NewRule()
rule.Operator = RuleOperatorNotMatch
rule.Value = "^\\d+"
err := rule.Init()
@@ -266,7 +290,20 @@ func TestRule_Test(t *testing.T) {
}
{
rule := NewRule()
var rule = NewRule()
rule.Operator = RuleOperatorNotMatch
rule.Value = "^\\d+"
err := rule.Init()
if err != nil {
t.Fatal(err)
}
a.IsFalse(rule.Test([][]byte{[]byte("123"), []byte("456")}))
a.IsFalse(rule.Test([][]byte{[]byte("123"), []byte("abc")}))
a.IsTrue(rule.Test([][]byte{[]byte("abc123")}))
}
{
var rule = NewRule()
rule.Operator = RuleOperatorMatch
rule.Value = "^(?i)[a-z]+$"
err := rule.Init()
@@ -277,7 +314,7 @@ func TestRule_Test(t *testing.T) {
}
{
rule := NewRule()
var rule = NewRule()
rule.Operator = RuleOperatorContains
rule.Value = "Hello"
err := rule.Init()
@@ -288,7 +325,7 @@ func TestRule_Test(t *testing.T) {
}
{
rule := NewRule()
var rule = NewRule()
rule.Operator = RuleOperatorContains
rule.Value = "hello"
rule.IsCaseInsensitive = true
@@ -300,7 +337,7 @@ func TestRule_Test(t *testing.T) {
}
{
rule := NewRule()
var rule = NewRule()
rule.Operator = RuleOperatorContains
rule.Value = "Hello"
err := rule.Init()
@@ -317,7 +354,7 @@ func TestRule_Test(t *testing.T) {
}
{
rule := NewRule()
var rule = NewRule()
rule.Operator = RuleOperatorNotContains
rule.Value = "Hello"
err := rule.Init()
@@ -329,7 +366,7 @@ func TestRule_Test(t *testing.T) {
}
{
rule := NewRule()
var rule = NewRule()
rule.Operator = RuleOperatorNotContains
rule.Value = "hello"
rule.IsCaseInsensitive = true
@@ -342,7 +379,7 @@ func TestRule_Test(t *testing.T) {
}
{
rule := NewRule()
var rule = NewRule()
rule.Operator = RuleOperatorPrefix
rule.Value = "Hello"
err := rule.Init()
@@ -350,11 +387,12 @@ func TestRule_Test(t *testing.T) {
t.Fatal(err)
}
a.IsTrue(rule.Test("Hello, World"))
a.IsFalse(rule.Test("hello"))
a.IsFalse(rule.Test("World, Hello"))
}
{
rule := NewRule()
var rule = NewRule()
rule.Operator = RuleOperatorPrefix
rule.Value = "hello"
rule.IsCaseInsensitive = true
@@ -363,11 +401,13 @@ func TestRule_Test(t *testing.T) {
t.Fatal(err)
}
a.IsTrue(rule.Test("Hello, World"))
a.IsTrue(rule.Test("hello, World"))
a.IsFalse(rule.Test("hell"))
a.IsFalse(rule.Test("World, Hello"))
}
{
rule := NewRule()
var rule = NewRule()
rule.Operator = RuleOperatorSuffix
rule.Value = "Hello"
err := rule.Init()
@@ -379,7 +419,7 @@ func TestRule_Test(t *testing.T) {
}
{
rule := NewRule()
var rule = NewRule()
rule.Operator = RuleOperatorSuffix
rule.Value = "hello"
rule.IsCaseInsensitive = true
@@ -388,11 +428,13 @@ func TestRule_Test(t *testing.T) {
t.Fatal(err)
}
a.IsFalse(rule.Test("Hello, World"))
a.IsTrue(rule.Test("Hello"))
a.IsFalse(rule.Test("llo"))
a.IsTrue(rule.Test("World, Hello"))
}
{
rule := NewRule()
var rule = NewRule()
rule.Operator = RuleOperatorHasKey
rule.Value = "Hello"
err := rule.Init()
@@ -409,7 +451,7 @@ func TestRule_Test(t *testing.T) {
}
{
rule := NewRule()
var rule = NewRule()
rule.Operator = RuleOperatorHasKey
rule.Value = "hello"
rule.IsCaseInsensitive = true
@@ -427,7 +469,7 @@ func TestRule_Test(t *testing.T) {
}
{
rule := NewRule()
var rule = NewRule()
rule.Operator = RuleOperatorHasKey
rule.Value = "3"
err := rule.Init()
@@ -440,6 +482,45 @@ func TestRule_Test(t *testing.T) {
}))
a.IsTrue(rule.Test([]int{1, 2, 3, 4}))
}
{
var rule = NewRule()
rule.Operator = RuleOperatorContainsAnyWord
rule.Value = "How\nare\nyou"
rule.IsCaseInsensitive = true
err := rule.Init()
if err != nil {
t.Fatal(err)
}
a.IsTrue(rule.Test("how"))
a.IsTrue(rule.Test("How doing"))
a.IsFalse(rule.Test("doing"))
}
{
var rule = NewRule()
rule.Operator = RuleOperatorContainsAllWords
rule.Value = "How\nare\nyou"
rule.IsCaseInsensitive = true
err := rule.Init()
if err != nil {
t.Fatal(err)
}
a.IsTrue(rule.Test("how are you"))
a.IsTrue(rule.Test("How are you doing"))
a.IsFalse(rule.Test("How are dare"))
}
{
var rule = NewRule()
rule.Operator = RuleOperatorContainsSQLInjection
err := rule.Init()
if err != nil {
t.Fatal(err)
}
a.IsTrue(rule.Test("id=123 OR 1=1"))
a.IsTrue(rule.Test("id=456 UNION SELECT"))
a.IsTrue(rule.Test("id=456 AND select load_file('') --"))
a.IsFalse(rule.Test("id=123"))
a.IsFalse(rule.Test("id=abc123 hello world '"))
}
}
func TestRule_MatchStar(t *testing.T) {

View File

@@ -1,434 +1,40 @@
package waf
func Template() *WAF {
waf := NewWAF()
waf.Id = 0
waf.IsOn = true
import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeNode/internal/waf/utils"
)
// xss
{
group := NewRuleGroup()
func Template() (*WAF, error) {
var config = firewallconfigs.HTTPFirewallTemplate()
if config.Inbound != nil {
config.Inbound.IsOn = true
}
for _, group := range config.AllRuleGroups() {
if group.Code == "cc" || group.Code == "cc2" {
continue
}
group.IsOn = true
group.IsInbound = true
group.Name = "XSS"
group.Code = "xss"
group.Description = "防跨站脚本攻击Cross Site Scripting"
{
set := NewRuleSet()
for _, set := range group.Sets {
set.IsOn = true
set.Name = "Javascript事件"
set.Code = "1001"
set.Connector = RuleConnectorOr
set.AddAction(ActionBlock, nil)
set.AddRule(&Rule{
Param: "${requestURI}",
Operator: RuleOperatorMatch,
Value: `(onmouseover|onmousemove|onmousedown|onmouseup|onerror|onload|onclick|ondblclick|onkeydown|onkeyup|onkeypress)\s*=`, // TODO more keywords here
IsCaseInsensitive: true,
})
group.AddRuleSet(set)
}
{
set := NewRuleSet()
set.IsOn = true
set.Name = "Javascript函数"
set.Code = "1002"
set.Connector = RuleConnectorOr
set.AddAction(ActionBlock, nil)
set.AddRule(&Rule{
Param: "${requestURI}",
Operator: RuleOperatorMatch,
Value: `(alert|eval|prompt|confirm)\s*\(`, // TODO more keywords here
IsCaseInsensitive: true,
})
group.AddRuleSet(set)
}
{
set := NewRuleSet()
set.IsOn = true
set.Name = "HTML标签"
set.Code = "1003"
set.Connector = RuleConnectorOr
set.AddAction(ActionBlock, nil)
set.AddRule(&Rule{
Param: "${requestURI}",
Operator: RuleOperatorMatch,
Value: `<(script|iframe|link)`, // TODO more keywords here
IsCaseInsensitive: true,
})
group.AddRuleSet(set)
}
waf.AddRuleGroup(group)
}
// upload
{
group := NewRuleGroup()
group.IsOn = true
group.IsInbound = true
group.Name = "文件上传"
group.Code = "upload"
group.Description = "防止上传可执行脚本文件到服务器"
{
set := NewRuleSet()
set.IsOn = true
set.Name = "上传文件扩展名"
set.Code = "2001"
set.Connector = RuleConnectorOr
set.AddAction(ActionBlock, nil)
set.AddRule(&Rule{
Param: "${requestUpload.ext}",
Operator: RuleOperatorMatch,
Value: `\.(php|jsp|aspx|asp|exe|asa|rb|py)\b`, // TODO more keywords here
IsCaseInsensitive: true,
})
group.AddRuleSet(set)
}
waf.AddRuleGroup(group)
instance, err := SharedWAFManager.ConvertWAF(config)
if err != nil {
return nil, err
}
// web shell
{
group := NewRuleGroup()
group.IsOn = true
group.IsInbound = true
group.Name = "Web Shell"
group.Code = "webShell"
group.Description = "防止远程执行服务器命令"
{
set := NewRuleSet()
set.IsOn = true
set.Name = "Web Shell"
set.Code = "3001"
set.Connector = RuleConnectorOr
set.AddAction(ActionBlock, nil)
set.AddRule(&Rule{
Param: "${requestAll}",
Operator: RuleOperatorMatch,
Value: `\b(eval|system|exec|execute|passthru|shell_exec|phpinfo)\s*\(`, // TODO more keywords here
IsCaseInsensitive: true,
})
group.AddRuleSet(set)
for _, group := range instance.Inbound {
for _, set := range group.RuleSets {
for _, rule := range set.Rules {
rule.cacheLife = utils.CacheDisabled // for performance test
_ = rule
}
}
waf.AddRuleGroup(group)
}
// command injection
{
group := NewRuleGroup()
group.IsOn = true
group.IsInbound = true
group.Name = "命令注入"
group.Code = "commandInjection"
{
set := NewRuleSet()
set.IsOn = true
set.Name = "命令注入"
set.Code = "4001"
set.Connector = RuleConnectorOr
set.AddAction(ActionBlock, nil)
set.AddRule(&Rule{
Param: "${requestURI}",
Operator: RuleOperatorMatch,
Value: `\b(pwd|ls|ll|whoami|id|net\s+user)\s*$`, // TODO more keywords here
IsCaseInsensitive: false,
})
set.AddRule(&Rule{
Param: "${requestBody}",
Operator: RuleOperatorMatch,
Value: `\b(pwd|ls|ll|whoami|id|net\s+user)\s*$`, // TODO more keywords here
IsCaseInsensitive: false,
})
group.AddRuleSet(set)
}
waf.AddRuleGroup(group)
}
// path traversal
{
group := NewRuleGroup()
group.IsOn = true
group.IsInbound = true
group.Name = "路径穿越"
group.Code = "pathTraversal"
group.Description = "防止读取网站目录之外的其他系统文件"
{
set := NewRuleSet()
set.IsOn = true
set.Name = "路径穿越"
set.Code = "5001"
set.Connector = RuleConnectorOr
set.AddAction(ActionBlock, nil)
set.AddRule(&Rule{
Param: "${requestURI}",
Operator: RuleOperatorMatch,
Value: `((\.+)(/+)){2,}`, // TODO more keywords here
IsCaseInsensitive: false,
})
group.AddRuleSet(set)
}
waf.AddRuleGroup(group)
}
// special dirs
{
group := NewRuleGroup()
group.IsOn = true
group.IsInbound = true
group.Name = "特殊目录"
group.Code = "denyDirs"
group.Description = "防止通过Web访问到一些特殊目录"
{
set := NewRuleSet()
set.IsOn = true
set.Name = "特殊目录"
set.Code = "6001"
set.Connector = RuleConnectorOr
set.AddAction(ActionBlock, nil)
set.AddRule(&Rule{
Param: "${requestPath}",
Operator: RuleOperatorMatch,
Value: `/\.(git|svn|htaccess|idea)\b`, // TODO more keywords here
IsCaseInsensitive: true,
})
group.AddRuleSet(set)
}
waf.AddRuleGroup(group)
}
// sql injection
{
group := NewRuleGroup()
group.IsOn = true
group.IsInbound = true
group.Name = "SQL注入"
group.Code = "sqlInjection"
group.Description = "防止SQL注入漏洞"
{
set := NewRuleSet()
set.IsOn = true
set.Name = "Union SQL Injection"
set.Code = "7001"
set.Connector = RuleConnectorOr
set.AddAction(ActionBlock, nil)
set.AddRule(&Rule{
Param: "${requestAll}",
Operator: RuleOperatorMatch,
Value: `union[\s/\*]+select`,
IsCaseInsensitive: true,
})
group.AddRuleSet(set)
}
{
set := NewRuleSet()
set.IsOn = true
set.Name = "SQL注释"
set.Code = "7002"
set.Connector = RuleConnectorOr
set.AddAction(ActionBlock, nil)
set.AddRule(&Rule{
Param: "${requestAll}",
Operator: RuleOperatorMatch,
Value: `/\*(!|\x00)`,
IsCaseInsensitive: true,
})
group.AddRuleSet(set)
}
{
set := NewRuleSet()
set.IsOn = true
set.Name = "SQL条件"
set.Code = "7003"
set.Connector = RuleConnectorOr
set.AddAction(ActionBlock, nil)
set.AddRule(&Rule{
Param: "${requestAll}",
Operator: RuleOperatorMatch,
Value: `\s(and|or|rlike)\s+(if|updatexml)\s*\(`,
IsCaseInsensitive: true,
})
set.AddRule(&Rule{
Param: "${requestAll}",
Operator: RuleOperatorMatch,
Value: `\s+(and|or|rlike)\s+(select|case)\s+`,
IsCaseInsensitive: true,
})
set.AddRule(&Rule{
Param: "${requestAll}",
Operator: RuleOperatorMatch,
Value: `\s+(and|or|procedure)\s+[\w\p{L}]+\s*=\s*[\w\p{L}]+(\s|$|--|#)`,
IsCaseInsensitive: true,
})
set.AddRule(&Rule{
Param: "${requestAll}",
Operator: RuleOperatorMatch,
Value: `\(\s*case\s+when\s+[\w\p{L}]+\s*=\s*[\w\p{L}]+\s+then\s+`,
IsCaseInsensitive: true,
})
group.AddRuleSet(set)
}
{
set := NewRuleSet()
set.IsOn = true
set.Name = "SQL函数"
set.Code = "7004"
set.Connector = RuleConnectorOr
set.AddAction(ActionBlock, nil)
set.AddRule(&Rule{
Param: "${requestAll}",
Operator: RuleOperatorMatch,
Value: `(updatexml|extractvalue|ascii|ord|char|chr|count|concat|rand|floor|substr|length|len|user|database|benchmark|analyse)\s*\(`,
IsCaseInsensitive: true,
})
group.AddRuleSet(set)
}
{
set := NewRuleSet()
set.IsOn = true
set.Name = "SQL附加语句"
set.Code = "7005"
set.Connector = RuleConnectorOr
set.AddAction(ActionBlock, nil)
set.AddRule(&Rule{
Param: "${requestAll}",
Operator: RuleOperatorMatch,
Value: `;\s*(declare|use|drop|create|exec|delete|update|insert)\s`,
IsCaseInsensitive: true,
})
group.AddRuleSet(set)
}
waf.AddRuleGroup(group)
}
// bot
{
group := NewRuleGroup()
group.IsOn = false
group.IsInbound = true
group.Name = "网络爬虫"
group.Code = "bot"
group.Description = "禁止一些网络爬虫"
{
set := NewRuleSet()
set.IsOn = true
set.Name = "常见网络爬虫"
set.Code = "20001"
set.Connector = RuleConnectorOr
set.AddAction(ActionBlock, nil)
set.AddRule(&Rule{
Param: "${userAgent}",
Operator: RuleOperatorMatch,
Value: `Googlebot|AdsBot|bingbot|BingPreview|facebookexternalhit|Slurp|Sogou|proximic|Baiduspider|yandex|twitterbot|spider|python`,
IsCaseInsensitive: true,
})
group.AddRuleSet(set)
}
waf.AddRuleGroup(group)
}
// cc
{
group := NewRuleGroup()
group.IsOn = false
group.IsInbound = true
group.Name = "CC攻击"
group.Description = "Challenge Collapsar防止短时间大量请求涌入请谨慎开启和设置"
group.Code = "cc2"
{
set := NewRuleSet()
set.IsOn = true
set.Name = "CC请求数"
set.Description = "限制单IP在一定时间内的请求数"
set.Code = "8001"
set.Connector = RuleConnectorAnd
set.AddAction(ActionBlock, nil)
set.AddRule(&Rule{
Param: "${cc2}",
Operator: RuleOperatorGt,
Value: "1000",
CheckpointOptions: map[string]interface{}{
"period": "60",
"threshold": 1000,
"keys": []string{"${remoteAddr}", "${requestPath}"},
},
IsCaseInsensitive: false,
})
set.AddRule(&Rule{
Param: "${remoteAddr}",
Operator: RuleOperatorNotIPRange,
Value: `127.0.0.1/8`,
IsCaseInsensitive: false,
})
set.AddRule(&Rule{
Param: "${remoteAddr}",
Operator: RuleOperatorNotIPRange,
Value: `192.168.0.1/16`,
IsCaseInsensitive: false,
})
set.AddRule(&Rule{
Param: "${remoteAddr}",
Operator: RuleOperatorNotIPRange,
Value: `10.0.0.1/8`,
IsCaseInsensitive: false,
})
set.AddRule(&Rule{
Param: "${remoteAddr}",
Operator: RuleOperatorNotIPRange,
Value: `172.16.0.1/12`,
IsCaseInsensitive: false,
})
group.AddRuleSet(set)
}
waf.AddRuleGroup(group)
}
// custom
{
group := NewRuleGroup()
group.IsOn = true
group.IsInbound = true
group.Name = "自定义规则分组"
group.Description = "我的自定义规则分组,可以将自定义的规则放在这个分组下"
group.Code = "custom"
waf.AddRuleGroup(group)
}
return waf
return instance, nil
}

View File

@@ -1,56 +1,58 @@
package waf
package waf_test
import (
"bytes"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeNode/internal/waf"
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"github.com/iwind/TeaGo/assert"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/logs"
"github.com/iwind/TeaGo/types"
"math/rand"
"mime/multipart"
"net/http"
"net/url"
"runtime"
"strings"
"testing"
"time"
)
func Test_Template(t *testing.T) {
a := assert.NewAssertion(t)
const testUserAgent = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_0_0) AppleWebKit/500.00 (KHTML, like Gecko) Chrome/100.0.0.0"
template := Template()
err := template.Init()
func Test_Template(t *testing.T) {
var a = assert.NewAssertion(t)
wafInstance, err := waf.Template()
if err != nil {
t.Fatal(err)
}
testTemplate1001(a, t, template)
testTemplate1002(a, t, template)
testTemplate1003(a, t, template)
testTemplate2001(a, t, template)
testTemplate3001(a, t, template)
testTemplate4001(a, t, template)
testTemplate5001(a, t, template)
testTemplate6001(a, t, template)
testTemplate7001(a, t, template)
testTemplate20001(a, t, template)
testTemplate1010(a, t, wafInstance)
testTemplate2001(a, t, wafInstance)
testTemplate3001(a, t, wafInstance)
testTemplate4001(a, t, wafInstance)
testTemplate5001(a, t, wafInstance)
testTemplate6001(a, t, wafInstance)
testTemplate7010(a, t, wafInstance)
testTemplate20001(a, t, wafInstance)
}
func Test_Template2(t *testing.T) {
reader := bytes.NewReader([]byte(strings.Repeat("HELLO", 1024)))
req, err := http.NewRequest(http.MethodGet, "https://example.com/index.php?id=123", reader)
req, err := http.NewRequest(http.MethodPost, "https://example.com/index.php?id=123", reader)
if err != nil {
t.Fatal(err)
}
waf := Template()
var errs = waf.Init()
if len(errs) > 0 {
t.Fatal(errs[0])
wafInstance, err := waf.Template()
if err != nil {
t.Fatal(err)
}
now := time.Now()
goNext, _, _, set, err := waf.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
goNext, _, _, set, err := wafInstance.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
if err != nil {
t.Fatal(err)
}
@@ -65,69 +67,76 @@ func Test_Template2(t *testing.T) {
}
func BenchmarkTemplate(b *testing.B) {
waf := Template()
err := waf.Init()
runtime.GOMAXPROCS(4)
wafInstance, err := waf.Template()
if err != nil {
b.Fatal(err)
}
for i := 0; i < b.N; i++ {
reader := bytes.NewReader([]byte(strings.Repeat("Hello", 1024)))
req, err := http.NewRequest(http.MethodGet, "http://example.com/index.php?id=123", reader)
if err != nil {
b.Fatal(err)
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
req, err := http.NewRequest(http.MethodGet, "https://example.com/index.php?id=123"+types.String(rand.Int()%10000), nil)
if err != nil {
b.Fatal(err)
}
req.Header.Set("User-Agent", testUserAgent)
_, _, _, _, _ = wafInstance.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
}
})
}
_, _, _, _, _ = waf.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
func testTemplate1010(a *assert.Assertion, t *testing.T, template *waf.WAF) {
for _, id := range []string{
"<script",
"<script src=\"123.js\">",
"<script>alert(123)</script>",
"<link",
"<link>",
"1 onfocus='alert(document.cookie)'",
} {
req, err := http.NewRequest(http.MethodGet, "https://example.com/index.php?id="+id, nil)
if err != nil {
t.Fatal(err)
}
req.Header.Set("User-Agent", testUserAgent)
_, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
if err != nil {
t.Fatal(err)
}
a.IsNotNil(result)
if result != nil {
a.IsTrue(result.Code == "1010")
} else {
t.Log("break at:", id)
}
}
for _, id := range []string{
"123",
"abc",
"<html></html>",
} {
req, err := http.NewRequest(http.MethodGet, "https://example.com/index.php?id="+url.QueryEscape(id), nil)
if err != nil {
t.Fatal(err)
}
req.Header.Set("User-Agent", testUserAgent)
_, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
if err != nil {
t.Fatal(err)
}
a.IsNil(result)
if result != nil {
a.IsTrue(result.Code == "1010")
}
}
}
func testTemplate1001(a *assert.Assertion, t *testing.T, template *WAF) {
req, err := http.NewRequest(http.MethodGet, "http://example.com/index.php?id=onmousedown%3D123", nil)
if err != nil {
t.Fatal(err)
}
_, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
if err != nil {
t.Fatal(err)
}
a.IsNotNil(result)
if result != nil {
a.IsTrue(result.Code == "1001")
}
}
func testTemplate1002(a *assert.Assertion, t *testing.T, template *WAF) {
req, err := http.NewRequest(http.MethodGet, "http://example.com/index.php?id=eval%28", nil)
if err != nil {
t.Fatal(err)
}
_, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
if err != nil {
t.Fatal(err)
}
a.IsNotNil(result)
if result != nil {
a.IsTrue(result.Code == "1002")
}
}
func testTemplate1003(a *assert.Assertion, t *testing.T, template *WAF) {
req, err := http.NewRequest(http.MethodGet, "http://example.com/index.php?id=<script src=\"123.js\">", nil)
if err != nil {
t.Fatal(err)
}
_, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
if err != nil {
t.Fatal(err)
}
a.IsNotNil(result)
if result != nil {
a.IsTrue(result.Code == "1003")
}
}
func testTemplate2001(a *assert.Assertion, t *testing.T, template *WAF) {
func testTemplate2001(a *assert.Assertion, t *testing.T, template *waf.WAF) {
body := bytes.NewBuffer([]byte{})
writer := multipart.NewWriter(body)
@@ -193,7 +202,7 @@ func testTemplate2001(a *assert.Assertion, t *testing.T, template *WAF) {
}
}
func testTemplate3001(a *assert.Assertion, t *testing.T, template *WAF) {
func testTemplate3001(a *assert.Assertion, t *testing.T, template *waf.WAF) {
req, err := http.NewRequest(http.MethodPost, "http://example.com/index.php?exec1+(", bytes.NewReader([]byte("exec('rm -rf /hello');")))
if err != nil {
t.Fatal(err)
@@ -208,7 +217,7 @@ func testTemplate3001(a *assert.Assertion, t *testing.T, template *WAF) {
}
}
func testTemplate4001(a *assert.Assertion, t *testing.T, template *WAF) {
func testTemplate4001(a *assert.Assertion, t *testing.T, template *waf.WAF) {
req, err := http.NewRequest(http.MethodPost, "http://example.com/index.php?whoami", nil)
if err != nil {
t.Fatal(err)
@@ -223,7 +232,7 @@ func testTemplate4001(a *assert.Assertion, t *testing.T, template *WAF) {
}
}
func testTemplate5001(a *assert.Assertion, t *testing.T, template *WAF) {
func testTemplate5001(a *assert.Assertion, t *testing.T, template *waf.WAF) {
{
req, err := http.NewRequest(http.MethodPost, "http://example.com/.././..", nil)
if err != nil {
@@ -255,12 +264,13 @@ func testTemplate5001(a *assert.Assertion, t *testing.T, template *WAF) {
}
}
func testTemplate6001(a *assert.Assertion, t *testing.T, template *WAF) {
func testTemplate6001(a *assert.Assertion, t *testing.T, template *waf.WAF) {
{
req, err := http.NewRequest(http.MethodPost, "http://example.com/.svn/123.txt", nil)
if err != nil {
t.Fatal(err)
}
req.Header.Set("User-Agent", testUserAgent)
_, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
if err != nil {
t.Fatal(err)
@@ -280,39 +290,116 @@ func testTemplate6001(a *assert.Assertion, t *testing.T, template *WAF) {
if err != nil {
t.Fatal(err)
}
a.IsNil(result)
a.IsNotNil(result)
}
}
func testTemplate7001(a *assert.Assertion, t *testing.T, template *WAF) {
func testTemplate7010(a *assert.Assertion, t *testing.T, template *waf.WAF) {
for _, id := range []string{
"union select",
" and if(",
"/*!",
" and select ",
" and id=123 ",
"(case when a=1 then ",
"updatexml (",
"; delete from table",
" union all select id from credits",
"' or 1=1",
"' or '1'='1",
"1' or '1'='1')) /*",
"OR 1/** this is comment **/=1",
"AND 1=2",
"; INSERT INTO users (...)",
"order by 10--",
"UNION SELECT 1,null,null--",
"' AND ASCII(SUBSTRING(username, 1, 1))=97 AND '1'='1",
"||UTL_INADDR.GET_HOST_NAME((SELECT user FROM dual) )--",
" AND IF(version() like '5%', sleep(10), 'false')",
"; update tablename set code='javascript code' where 1--",
"AND @@version like '5.0%', ",
"/*!40110 and 1=0*/",
"AND 1=0 UNION SELECT DATABASE()",
"load_file('filename')",
"limit 1 into outfile 'aaa'",
"OR IF(1, BENCHMARK(#ofcicies, action_to_be_performed), 'false')",
"AND 1=CONVERT(int, db_name())",
// PostgresSQL
"and 1::int=1",
} {
req, err := http.NewRequest(http.MethodPost, "http://example.com/?id="+url.QueryEscape(id), nil)
req, err := http.NewRequest(http.MethodPost, "https://example.com/?id=1 "+url.QueryEscape(id), nil)
if err != nil {
t.Fatal(err)
}
req.Header.Set("User-Agent", testUserAgent)
_, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
if err != nil {
t.Fatal(err)
}
a.IsNotNil(result)
if result != nil {
a.IsTrue(lists.ContainsAny([]string{"7001", "7002", "7003", "7004", "7005"}, result.Code))
a.IsTrue(lists.ContainsAny([]string{"7010"}, result.Code))
} else {
t.Log("break:", id)
}
}
}
func testTemplate20001(a *assert.Assertion, t *testing.T, template *WAF) {
func TestTemplateSQLInjection(t *testing.T) {
template, err := waf.Template()
if err != nil {
t.Fatal(err)
}
var group = template.FindRuleGroupWithCode("sqlInjection")
if group == nil {
t.Fatal("group not found")
return
}
//
//for _, set := range group.RuleSets {
// for _, rule := range set.Rules {
// t.Logf("%#v", rule.singleCheckpoint)
// }
//}
req, err := http.NewRequest(http.MethodPost, "https://example.com/?id=1234", nil)
if err != nil {
t.Fatal(err)
}
req.Header.Set("User-Agent", testUserAgent)
_, _, result, err := group.MatchRequest(requests.NewTestRequest(req))
if err != nil {
t.Fatal(err)
}
if result != nil {
t.Log(result)
}
}
func BenchmarkTemplateSQLInjection(b *testing.B) {
template, err := waf.Template()
if err != nil {
b.Fatal(err)
}
var group = template.FindRuleGroupWithCode("sqlInjection")
if group == nil {
b.Fatal("group not found")
return
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
req, err := http.NewRequest(http.MethodPost, "https://example.com/?id=1234"+types.String(rand.Int()%10000), nil)
if err != nil {
b.Fatal(err)
}
req.Header.Set("User-Agent", testUserAgent)
_, _, result, err := group.MatchRequest(requests.NewTestRequest(req))
if err != nil {
b.Fatal(err)
}
_ = result
}
})
}
func testTemplate20001(a *assert.Assertion, t *testing.T, template *waf.WAF) {
// enable bot rule set
for _, g := range template.Inbound {
if g.Code == "bot" {
@@ -323,7 +410,7 @@ func testTemplate20001(a *assert.Assertion, t *testing.T, template *WAF) {
for _, bot := range []string{
"Googlebot",
"AdsBot",
"AdsBot-Google",
"bingbot",
"BingPreview",
"facebookexternalhit",
@@ -348,3 +435,67 @@ func testTemplate20001(a *assert.Assertion, t *testing.T, template *WAF) {
}
}
}
func BenchmarkTemplatePathTraversal(b *testing.B) {
runtime.GOMAXPROCS(4)
template, err := waf.Template()
if err != nil {
b.Fatal(err)
}
var group = template.FindRuleGroupWithCode("pathTraversal")
if group == nil {
b.Fatal("group not found")
return
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
req, err := http.NewRequest(http.MethodPost, "https://example.com/?id=1234"+types.String(rand.Int()%10000)+"&name=lily&time=12345678910", nil)
if err != nil {
b.Fatal(err)
}
req.Header.Set("User-Agent", testUserAgent)
_, _, result, err := group.MatchRequest(requests.NewTestRequest(req))
if err != nil {
b.Fatal(err)
}
_ = result
}
})
}
func BenchmarkTemplateCC2(b *testing.B) {
runtime.GOMAXPROCS(4)
template, err := waf.Template()
if err != nil {
b.Fatal(err)
}
var group = template.FindRuleGroupWithCode("cc2")
if group == nil {
b.Fatal("group not found")
return
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
req, err := http.NewRequest(http.MethodPost, "https://example.com/?id=1234"+types.String(rand.Int()%10000)+"&name=lily&time=12345678910", nil)
if err != nil {
b.Fatal(err)
}
req.Header.Set("User-Agent", testUserAgent)
_, _, result, err := group.MatchRequest(requests.NewTestRequest(req))
if err != nil {
b.Fatal(err)
}
_ = result
}
})
}

View File

@@ -7,13 +7,13 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/utils/cachehits"
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"github.com/cespare/xxhash"
"github.com/cespare/xxhash/v2"
"github.com/iwind/TeaGo/types"
stringutil "github.com/iwind/TeaGo/utils/string"
"strconv"
)
var cache = ttlcache.NewCache[int8]()
var SharedCache = ttlcache.NewCache[int8]()
var cacheHits *cachehits.Stat
func init() {
@@ -24,7 +24,7 @@ func init() {
}
const (
maxCacheDataSize = 1024
MaxCacheDataSize = 1024
)
type CacheLife = int64
@@ -45,22 +45,22 @@ func MatchStringCache(regex *re.Regexp, s string, cacheLife CacheLife) bool {
var regIdString = regex.IdString()
// 如果长度超过一定数量,大概率是不能重用的
if cacheLife <= 0 || len(s) > maxCacheDataSize || !cacheHits.IsGood(regIdString) {
if cacheLife <= 0 || len(s) > MaxCacheDataSize || !cacheHits.IsGood(regIdString) {
return regex.MatchString(s)
}
var hash = xxhash.Sum64String(s)
var key = regIdString + "@" + strconv.FormatUint(hash, 10)
var item = cache.Read(key)
var item = SharedCache.Read(key)
if item != nil {
cacheHits.IncreaseHit(regIdString)
return item.Value == 1
}
var b = regex.MatchString(s)
if b {
cache.Write(key, 1, fasttime.Now().Unix()+cacheLife)
SharedCache.Write(key, 1, fasttime.Now().Unix()+cacheLife)
} else {
cache.Write(key, 0, fasttime.Now().Unix()+cacheLife)
SharedCache.Write(key, 0, fasttime.Now().Unix()+cacheLife)
}
cacheHits.IncreaseCached(regIdString)
return b
@@ -75,25 +75,22 @@ func MatchBytesCache(regex *re.Regexp, byteSlice []byte, cacheLife CacheLife) bo
var regIdString = regex.IdString()
// 如果长度超过一定数量,大概率是不能重用的
if cacheLife <= 0 || len(byteSlice) > maxCacheDataSize || !cacheHits.IsGood(regIdString) {
if cacheLife <= 0 || len(byteSlice) > MaxCacheDataSize || !cacheHits.IsGood(regIdString) {
return regex.Match(byteSlice)
}
var hash = xxhash.Sum64(byteSlice)
var key = regIdString + "@" + strconv.FormatUint(hash, 10)
var item = cache.Read(key)
var item = SharedCache.Read(key)
if item != nil {
cacheHits.IncreaseHit(regIdString)
return item.Value == 1
}
if item != nil {
return item.Value == 1
}
var b = regex.Match(byteSlice)
if b {
cache.Write(key, 1, fasttime.Now().Unix()+cacheLife)
SharedCache.Write(key, 1, fasttime.Now().Unix()+cacheLife)
} else {
cache.Write(key, 0, fasttime.Now().Unix()+cacheLife)
SharedCache.Write(key, 0, fasttime.Now().Unix()+cacheLife)
}
cacheHits.IncreaseCached(regIdString)
return b

View File

@@ -402,7 +402,7 @@ func (this *WAF) Stop() {
}
// MergeTemplate merge with template
func (this *WAF) MergeTemplate() (changedItems []string) {
func (this *WAF) MergeTemplate() (changedItems []string, err error) {
changedItems = []string{}
// compare versions
@@ -411,7 +411,10 @@ func (this *WAF) MergeTemplate() (changedItems []string) {
}
this.CreatedVersion = teaconst.Version
template := Template()
template, err := Template()
if err != nil {
return nil, err
}
groups := []*RuleGroup{}
groups = append(groups, template.Inbound...)
groups = append(groups, template.Outbound...)

View File

@@ -202,6 +202,7 @@ func (this *WAFManager) ConvertWAF(policy *firewallconfigs.HTTPFirewallPolicy) (
UIFooter: policy.CaptchaOptions.UIFooter,
UIBody: policy.CaptchaOptions.UIBody,
Lang: policy.CaptchaOptions.Lang,
GeeTestConfig: &policy.CaptchaOptions.GeeTestConfig,
}
}

View File

@@ -1,7 +1,8 @@
package waf
package waf_test
import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeNode/internal/waf"
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"github.com/iwind/TeaGo/assert"
"net/http"
@@ -9,32 +10,32 @@ import (
)
func TestWAF_MatchRequest(t *testing.T) {
a := assert.NewAssertion(t)
var a = assert.NewAssertion(t)
set := NewRuleSet()
var set = waf.NewRuleSet()
set.Name = "Name_Age"
set.Connector = RuleConnectorAnd
set.Rules = []*Rule{
set.Connector = waf.RuleConnectorAnd
set.Rules = []*waf.Rule{
{
Param: "${arg.name}",
Operator: RuleOperatorEqString,
Operator: waf.RuleOperatorEqString,
Value: "lu",
},
{
Param: "${arg.age}",
Operator: RuleOperatorEq,
Operator: waf.RuleOperatorEq,
Value: "20",
},
}
set.AddAction(ActionBlock, nil)
set.AddAction(waf.ActionBlock, nil)
group := NewRuleGroup()
var group = waf.NewRuleGroup()
group.AddRuleSet(set)
group.IsInbound = true
waf := NewWAF()
waf.AddRuleGroup(group)
errs := waf.Init()
var wafInstance = waf.NewWAF()
wafInstance.AddRuleGroup(group)
errs := wafInstance.Init()
if len(errs) > 0 {
t.Fatal(errs[0])
}
@@ -43,7 +44,7 @@ func TestWAF_MatchRequest(t *testing.T) {
if err != nil {
t.Fatal(err)
}
goNext, _, _, set, err := waf.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
goNext, _, _, set, err := wafInstance.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
if err != nil {
t.Fatal(err)
}